summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib')
-rw-r--r--ml/dlib/dlib/CMakeLists.txt841
-rw-r--r--ml/dlib/dlib/LICENSE.txt23
-rw-r--r--ml/dlib/dlib/algs.h1157
-rw-r--r--ml/dlib/dlib/all/source.cpp98
-rw-r--r--ml/dlib/dlib/any.h13
-rw-r--r--ml/dlib/dlib/any/any.h183
-rw-r--r--ml/dlib/dlib/any/any_abstract.h210
-rw-r--r--ml/dlib/dlib/any/any_decision_function.h209
-rw-r--r--ml/dlib/dlib/any/any_decision_function_abstract.h224
-rw-r--r--ml/dlib/dlib/any/any_function.h885
-rw-r--r--ml/dlib/dlib/any/any_function_abstract.h292
-rw-r--r--ml/dlib/dlib/any/any_function_impl.h516
-rw-r--r--ml/dlib/dlib/any/any_function_impl2.h44
-rw-r--r--ml/dlib/dlib/any/any_trainer.h217
-rw-r--r--ml/dlib/dlib/any/any_trainer_abstract.h234
-rw-r--r--ml/dlib/dlib/appveyor/dtest.yml19
-rw-r--r--ml/dlib/dlib/appveyor/dtest_vc2017.yml21
-rw-r--r--ml/dlib/dlib/appveyor/examples.yml16
-rw-r--r--ml/dlib/dlib/appveyor/python.yml33
-rw-r--r--ml/dlib/dlib/array.h10
-rw-r--r--ml/dlib/dlib/array/array_kernel.h810
-rw-r--r--ml/dlib/dlib/array/array_kernel_abstract.h360
-rw-r--r--ml/dlib/dlib/array/array_tools.h38
-rw-r--r--ml/dlib/dlib/array/array_tools_abstract.h33
-rw-r--r--ml/dlib/dlib/array2d.h12
-rw-r--r--ml/dlib/dlib/array2d/array2d_generic_image.h67
-rw-r--r--ml/dlib/dlib/array2d/array2d_kernel.h498
-rw-r--r--ml/dlib/dlib/array2d/array2d_kernel_abstract.h301
-rw-r--r--ml/dlib/dlib/array2d/serialize_pixel_overloads.h371
-rw-r--r--ml/dlib/dlib/assert.h216
-rw-r--r--ml/dlib/dlib/base64.h9
-rw-r--r--ml/dlib/dlib/base64/base64_kernel_1.cpp403
-rw-r--r--ml/dlib/dlib/base64/base64_kernel_1.h92
-rw-r--r--ml/dlib/dlib/base64/base64_kernel_abstract.h121
-rw-r--r--ml/dlib/dlib/bayes_utils.h11
-rw-r--r--ml/dlib/dlib/bayes_utils/bayes_utils.h1678
-rw-r--r--ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h1042
-rw-r--r--ml/dlib/dlib/bigint.h43
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_1.cpp1720
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_1.h544
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_2.cpp1945
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_2.h570
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_abstract.h670
-rw-r--r--ml/dlib/dlib/bigint/bigint_kernel_c.h1141
-rw-r--r--ml/dlib/dlib/binary_search_tree.h50
-rw-r--r--ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h2064
-rw-r--r--ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h1897
-rw-r--r--ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h311
-rw-r--r--ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h235
-rw-r--r--ml/dlib/dlib/bit_stream.h42
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp200
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h120
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h185
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h172
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_multi_1.h103
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h77
-rw-r--r--ml/dlib/dlib/bit_stream/bit_stream_multi_c.h101
-rw-r--r--ml/dlib/dlib/bits/c++config.h1
-rw-r--r--ml/dlib/dlib/bound_function_pointer.h10
-rw-r--r--ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h774
-rw-r--r--ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h456
-rw-r--r--ml/dlib/dlib/bridge.h17
-rw-r--r--ml/dlib/dlib/bridge/bridge.h669
-rw-r--r--ml/dlib/dlib/bridge/bridge_abstract.h347
-rw-r--r--ml/dlib/dlib/bsp.h12
-rw-r--r--ml/dlib/dlib/bsp/bsp.cpp496
-rw-r--r--ml/dlib/dlib/bsp/bsp.h1043
-rw-r--r--ml/dlib/dlib/bsp/bsp_abstract.h912
-rw-r--r--ml/dlib/dlib/byte_orderer.h10
-rw-r--r--ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h176
-rw-r--r--ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h149
-rw-r--r--ml/dlib/dlib/cassert1
-rw-r--r--ml/dlib/dlib/clustering.h13
-rw-r--r--ml/dlib/dlib/clustering/bottom_up_cluster.h253
-rw-r--r--ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h136
-rw-r--r--ml/dlib/dlib/clustering/chinese_whispers.h135
-rw-r--r--ml/dlib/dlib/clustering/chinese_whispers_abstract.h97
-rw-r--r--ml/dlib/dlib/clustering/modularity_clustering.h515
-rw-r--r--ml/dlib/dlib/clustering/modularity_clustering_abstract.h125
-rw-r--r--ml/dlib/dlib/clustering/spectral_cluster.h80
-rw-r--r--ml/dlib/dlib/clustering/spectral_cluster_abstract.h43
-rw-r--r--ml/dlib/dlib/cmake5
-rw-r--r--ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake35
-rw-r--r--ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake20
-rw-r--r--ml/dlib/dlib/cmake_utils/dlib.pc.in9
-rw-r--r--ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in50
-rw-r--r--ml/dlib/dlib/cmake_utils/find_blas.cmake385
-rw-r--r--ml/dlib/dlib/cmake_utils/release_build_by_default9
-rw-r--r--ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake131
-rw-r--r--ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake19
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt17
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp51
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt14
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu21
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt19
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt24
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt6
-rw-r--r--ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp9
-rw-r--r--ml/dlib/dlib/cmake_utils/use_cpp_11.cmake113
-rw-r--r--ml/dlib/dlib/cmd_line_parser.h84
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h580
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h453
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h799
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h673
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h203
-rw-r--r--ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h205
-rw-r--r--ml/dlib/dlib/cmd_line_parser/get_option.h181
-rw-r--r--ml/dlib/dlib/cmd_line_parser/get_option_abstract.h146
-rw-r--r--ml/dlib/dlib/compress_stream.h133
-rw-r--r--ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h252
-rw-r--r--ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h431
-rw-r--r--ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h381
-rw-r--r--ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h94
-rw-r--r--ml/dlib/dlib/conditioning_class.h80
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h333
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h500
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h438
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h533
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h228
-rw-r--r--ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h162
-rw-r--r--ml/dlib/dlib/config.h31
-rw-r--r--ml/dlib/dlib/config.h.in34
-rw-r--r--ml/dlib/dlib/config_reader.h39
-rw-r--r--ml/dlib/dlib/config_reader/config_reader_kernel_1.h738
-rw-r--r--ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h363
-rw-r--r--ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h456
-rw-r--r--ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h45
-rw-r--r--ml/dlib/dlib/console_progress_indicator.h207
-rw-r--r--ml/dlib/dlib/control.h11
-rw-r--r--ml/dlib/dlib/control/approximate_linear_models.h128
-rw-r--r--ml/dlib/dlib/control/approximate_linear_models_abstract.h213
-rw-r--r--ml/dlib/dlib/control/lspi.h188
-rw-r--r--ml/dlib/dlib/control/lspi_abstract.h193
-rw-r--r--ml/dlib/dlib/control/mpc.h370
-rw-r--r--ml/dlib/dlib/control/mpc_abstract.h276
-rw-r--r--ml/dlib/dlib/cpp_pretty_printer.h39
-rw-r--r--ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h583
-rw-r--r--ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h520
-rw-r--r--ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h88
-rw-r--r--ml/dlib/dlib/cpp_tokenizer.h40
-rw-r--r--ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h675
-rw-r--r--ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h224
-rw-r--r--ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h137
-rw-r--r--ml/dlib/dlib/crc32.h10
-rw-r--r--ml/dlib/dlib/crc32/crc32_kernel_1.h262
-rw-r--r--ml/dlib/dlib/crc32/crc32_kernel_abstract.h132
-rw-r--r--ml/dlib/dlib/cstring1
-rw-r--r--ml/dlib/dlib/data_io.h18
-rw-r--r--ml/dlib/dlib/data_io/image_dataset_metadata.cpp411
-rw-r--r--ml/dlib/dlib/data_io/image_dataset_metadata.h174
-rw-r--r--ml/dlib/dlib/data_io/libsvm_io.h276
-rw-r--r--ml/dlib/dlib/data_io/libsvm_io_abstract.h125
-rw-r--r--ml/dlib/dlib/data_io/load_image_dataset.h510
-rw-r--r--ml/dlib/dlib/data_io/load_image_dataset_abstract.h358
-rw-r--r--ml/dlib/dlib/data_io/mnist.cpp133
-rw-r--r--ml/dlib/dlib/data_io/mnist.h32
-rw-r--r--ml/dlib/dlib/data_io/mnist_abstract.h46
-rw-r--r--ml/dlib/dlib/dir_nav.h21
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp121
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_extensions.h172
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h203
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp258
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h634
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp254
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h659
-rw-r--r--ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h515
-rw-r--r--ml/dlib/dlib/dir_nav/posix.h6
-rw-r--r--ml/dlib/dlib/dir_nav/windows.h6
-rw-r--r--ml/dlib/dlib/directed_graph.h37
-rw-r--r--ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h704
-rw-r--r--ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h383
-rw-r--r--ml/dlib/dlib/disjoint_subsets.h12
-rw-r--r--ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h141
-rw-r--r--ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h96
-rw-r--r--ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h130
-rw-r--r--ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h123
-rw-r--r--ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt13
-rw-r--r--ml/dlib/dlib/dlib_include_path_tutorial.txt20
-rw-r--r--ml/dlib/dlib/dnn.h37
-rw-r--r--ml/dlib/dlib/dnn/core.h3599
-rw-r--r--ml/dlib/dlib/dnn/core_abstract.h1700
-rw-r--r--ml/dlib/dlib/dnn/cpu_dlib.cpp2170
-rw-r--r--ml/dlib/dlib/dnn/cpu_dlib.h505
-rw-r--r--ml/dlib/dlib/dnn/cublas_dlibapi.cpp165
-rw-r--r--ml/dlib/dlib/dnn/cublas_dlibapi.h50
-rw-r--r--ml/dlib/dlib/dnn/cuda_data_ptr.cpp71
-rw-r--r--ml/dlib/dlib/dnn/cuda_data_ptr.h184
-rw-r--r--ml/dlib/dlib/dnn/cuda_dlib.cu1630
-rw-r--r--ml/dlib/dlib/dnn/cuda_dlib.h469
-rw-r--r--ml/dlib/dlib/dnn/cuda_errors.h70
-rw-r--r--ml/dlib/dlib/dnn/cuda_utils.h413
-rw-r--r--ml/dlib/dlib/dnn/cudnn_dlibapi.cpp1604
-rw-r--r--ml/dlib/dlib/dnn/cudnn_dlibapi.h518
-rw-r--r--ml/dlib/dlib/dnn/curand_dlibapi.cpp113
-rw-r--r--ml/dlib/dlib/dnn/curand_dlibapi.h75
-rw-r--r--ml/dlib/dlib/dnn/cusolver_dlibapi.cu204
-rw-r--r--ml/dlib/dlib/dnn/cusolver_dlibapi.h75
-rw-r--r--ml/dlib/dlib/dnn/gpu_data.cpp228
-rw-r--r--ml/dlib/dlib/dnn/gpu_data.h266
-rw-r--r--ml/dlib/dlib/dnn/gpu_data_abstract.h266
-rw-r--r--ml/dlib/dlib/dnn/input.h808
-rw-r--r--ml/dlib/dlib/dnn/input_abstract.h467
-rw-r--r--ml/dlib/dlib/dnn/layers.h3244
-rw-r--r--ml/dlib/dlib/dnn/layers_abstract.h2631
-rw-r--r--ml/dlib/dlib/dnn/loss.h2870
-rw-r--r--ml/dlib/dlib/dnn/loss_abstract.h1542
-rw-r--r--ml/dlib/dlib/dnn/solvers.h405
-rw-r--r--ml/dlib/dlib/dnn/solvers_abstract.h204
-rw-r--r--ml/dlib/dlib/dnn/tensor.h686
-rw-r--r--ml/dlib/dlib/dnn/tensor_abstract.h727
-rw-r--r--ml/dlib/dlib/dnn/tensor_tools.cpp985
-rw-r--r--ml/dlib/dlib/dnn/tensor_tools.h1711
-rw-r--r--ml/dlib/dlib/dnn/trainer.h1333
-rw-r--r--ml/dlib/dlib/dnn/trainer_abstract.h765
-rw-r--r--ml/dlib/dlib/dnn/utilities.h281
-rw-r--r--ml/dlib/dlib/dnn/utilities_abstract.h127
-rw-r--r--ml/dlib/dlib/dnn/validation.h122
-rw-r--r--ml/dlib/dlib/enable_if.h62
-rw-r--r--ml/dlib/dlib/entropy_decoder.h44
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp220
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h132
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp224
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h127
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h207
-rw-r--r--ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h123
-rw-r--r--ml/dlib/dlib/entropy_decoder_model.h108
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h173
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h245
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h335
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h622
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h793
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h131
-rw-r--r--ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h116
-rw-r--r--ml/dlib/dlib/entropy_encoder.h43
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp239
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h119
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp233
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h112
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h161
-rw-r--r--ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h112
-rw-r--r--ml/dlib/dlib/entropy_encoder_model.h146
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h167
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h246
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h341
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h553
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h817
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h127
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h118
-rw-r--r--ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h65
-rw-r--r--ml/dlib/dlib/error.h449
-rw-r--r--ml/dlib/dlib/external/cblas/CMakeLists.txt182
-rw-r--r--ml/dlib/dlib/external/cblas/README7
-rw-r--r--ml/dlib/dlib/external/cblas/cblas.h575
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_caxpy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ccopy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cgbmv.c154
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cgemm.c94
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cgemv.c151
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cgerc.c77
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cgeru.c38
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chbmv.c145
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chemm.c91
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chemv.c146
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cher.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cher2.c139
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cher2k.c96
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cherk.c90
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chpmv.c146
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chpr.c102
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_chpr2.c136
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_csscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_cswap.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_csymm.c91
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_csyr2k.c93
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_csyrk.c93
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctbmv.c139
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctbsv.c143
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctpmv.c133
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctpsv.c138
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctrmm.c123
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctrmv.c136
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctrsm.c132
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ctrsv.c137
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dasum.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_daxpy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dcopy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ddot.c25
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dgbmv.c70
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dgemm.c94
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dgemv.c67
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dger.c40
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dnrm2.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_drot.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_drotg.c14
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_drotm.c14
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_drotmg.c15
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsbmv.c66
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsdot.c25
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dspmv.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dspr.c59
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dspr2.c59
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dswap.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsymm.c91
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsymv.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsyr.c60
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsyr2.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsyr2k.c94
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dsyrk.c93
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtbmv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtbsv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtpmv.c98
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtpsv.c99
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtrmm.c125
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtrmv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtrsm.c130
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dtrsv.c102
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dzasum.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_dznrm2.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_f77.h701
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_icamax.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_idamax.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_isamax.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_izamax.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sasum.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_saxpy.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_scasum.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_scnrm2.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_scopy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sdot.c25
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sdsdot.c25
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sgbmv.c72
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sgemm.c95
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sgemv.c67
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sger.c39
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_snrm2.c23
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_srot.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_srotg.c14
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_srotm.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_srotmg.c15
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssbmv.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sspmv.c62
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sspr.c61
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sspr2.c60
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_sswap.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssymm.c93
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssymv.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssyr.c59
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssyr2.c65
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssyr2k.c96
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ssyrk.c95
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_stbmv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_stbsv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_stpmv.c99
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_stpsv.c99
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_strmm.c125
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_strmv.c103
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_strsm.c120
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_strsv.c102
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_xerbla.c66
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zaxpy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zcopy.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c24
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c24
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zdscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zgbmv.c155
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zgemm.c94
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zgemv.c153
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zgerc.c77
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zgeru.c37
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhbmv.c145
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhemm.c91
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhemv.c146
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zher.c99
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zher2.c140
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zher2k.c95
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zherk.c90
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhpmv.c146
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhpr.c102
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zhpr2.c137
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zscal.c21
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zswap.c22
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zsymm.c91
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zsyr2k.c93
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_zsyrk.c92
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztbmv.c139
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztbsv.c143
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztpmv.c133
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztpsv.c138
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztrmm.c126
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztrmv.c137
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztrsm.c132
-rw-r--r--ml/dlib/dlib/external/cblas/cblas_ztrsv.c137
-rw-r--r--ml/dlib/dlib/external/cblas/cdotcsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/cdotusub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/dasumsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/ddotsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/dnrm2sub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/dsdotsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/dzasumsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/dznrm2sub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/icamaxsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/idamaxsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/isamaxsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/izamaxsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/sasumsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/scasumsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/scnrm2sub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/sdotsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/sdsdotsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/snrm2sub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/zdotcsub.f15
-rw-r--r--ml/dlib/dlib/external/cblas/zdotusub.f15
-rw-r--r--ml/dlib/dlib/external/libjpeg/README385
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcapimin.cpp280
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcapistd.cpp161
-rw-r--r--ml/dlib/dlib/external/libjpeg/jccoefct.cpp449
-rw-r--r--ml/dlib/dlib/external/libjpeg/jccolor.cpp459
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp387
-rw-r--r--ml/dlib/dlib/external/libjpeg/jchuff.cpp909
-rw-r--r--ml/dlib/dlib/external/libjpeg/jchuff.h47
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcinit.cpp72
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcmainct.cpp293
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcmarker.cpp664
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcmaster.cpp590
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcomapi.cpp106
-rw-r--r--ml/dlib/dlib/external/libjpeg/jconfig.h45
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcparam.cpp610
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcphuff.cpp833
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcprepct.cpp354
-rw-r--r--ml/dlib/dlib/external/libjpeg/jcsample.cpp519
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdapimin.cpp395
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdapistd.cpp275
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdatadst.cpp151
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdatasrc.cpp212
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdcoefct.cpp736
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdcolor.cpp396
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdct.h176
-rw-r--r--ml/dlib/dlib/external/libjpeg/jddctmgr.cpp269
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdhuff.cpp654
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdhuff.h201
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdinput.cpp381
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdmainct.cpp512
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdmarker.cpp1360
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdmaster.cpp557
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdmerge.cpp400
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdphuff.cpp671
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdpostct.cpp290
-rw-r--r--ml/dlib/dlib/external/libjpeg/jdsample.cpp478
-rw-r--r--ml/dlib/dlib/external/libjpeg/jerror.cpp252
-rw-r--r--ml/dlib/dlib/external/libjpeg/jerror.h291
-rw-r--r--ml/dlib/dlib/external/libjpeg/jfdctflt.cpp168
-rw-r--r--ml/dlib/dlib/external/libjpeg/jfdctfst.cpp224
-rw-r--r--ml/dlib/dlib/external/libjpeg/jfdctint.cpp283
-rw-r--r--ml/dlib/dlib/external/libjpeg/jidctflt.cpp242
-rw-r--r--ml/dlib/dlib/external/libjpeg/jidctfst.cpp368
-rw-r--r--ml/dlib/dlib/external/libjpeg/jidctint.cpp389
-rw-r--r--ml/dlib/dlib/external/libjpeg/jidctred.cpp398
-rw-r--r--ml/dlib/dlib/external/libjpeg/jinclude.h91
-rw-r--r--ml/dlib/dlib/external/libjpeg/jmemmgr.cpp1118
-rw-r--r--ml/dlib/dlib/external/libjpeg/jmemnobs.cpp109
-rw-r--r--ml/dlib/dlib/external/libjpeg/jmemsys.h198
-rw-r--r--ml/dlib/dlib/external/libjpeg/jmorecfg.h356
-rw-r--r--ml/dlib/dlib/external/libjpeg/jpegint.h392
-rw-r--r--ml/dlib/dlib/external/libjpeg/jpeglib.h1096
-rw-r--r--ml/dlib/dlib/external/libjpeg/jquant1.cpp856
-rw-r--r--ml/dlib/dlib/external/libjpeg/jquant2.cpp1310
-rw-r--r--ml/dlib/dlib/external/libjpeg/jutils.cpp179
-rw-r--r--ml/dlib/dlib/external/libjpeg/jversion.h14
-rw-r--r--ml/dlib/dlib/external/libpng/LICENSE111
-rw-r--r--ml/dlib/dlib/external/libpng/README202
-rw-r--r--ml/dlib/dlib/external/libpng/arm/arm_init.c232
-rw-r--r--ml/dlib/dlib/external/libpng/arm/filter_neon.S245
-rw-r--r--ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c372
-rw-r--r--ml/dlib/dlib/external/libpng/png.c4299
-rw-r--r--ml/dlib/dlib/external/libpng/png.h3319
-rw-r--r--ml/dlib/dlib/external/libpng/pngconf.h626
-rw-r--r--ml/dlib/dlib/external/libpng/pngdebug.h157
-rw-r--r--ml/dlib/dlib/external/libpng/pngerror.c932
-rw-r--r--ml/dlib/dlib/external/libpng/pngget.c1177
-rw-r--r--ml/dlib/dlib/external/libpng/pnginfo.h260
-rw-r--r--ml/dlib/dlib/external/libpng/pnglibconf.h211
-rw-r--r--ml/dlib/dlib/external/libpng/pngmem.c277
-rw-r--r--ml/dlib/dlib/external/libpng/pngpread.c1291
-rw-r--r--ml/dlib/dlib/external/libpng/pngpriv.h2047
-rw-r--r--ml/dlib/dlib/external/libpng/pngread.c4000
-rw-r--r--ml/dlib/dlib/external/libpng/pngrio.c118
-rw-r--r--ml/dlib/dlib/external/libpng/pngrtran.c5110
-rw-r--r--ml/dlib/dlib/external/libpng/pngrutil.c4475
-rw-r--r--ml/dlib/dlib/external/libpng/pngset.c1597
-rw-r--r--ml/dlib/dlib/external/libpng/pngstruct.h489
-rw-r--r--ml/dlib/dlib/external/libpng/pngtrans.c841
-rw-r--r--ml/dlib/dlib/external/libpng/pngwio.c164
-rw-r--r--ml/dlib/dlib/external/libpng/pngwrite.c2330
-rw-r--r--ml/dlib/dlib/external/libpng/pngwtran.c637
-rw-r--r--ml/dlib/dlib/external/libpng/pngwutil.c3023
-rw-r--r--ml/dlib/dlib/external/pybind11/CMakeLists.txt155
-rw-r--r--ml/dlib/dlib/external/pybind11/CONTRIBUTING.md47
-rw-r--r--ml/dlib/dlib/external/pybind11/LICENSE29
-rw-r--r--ml/dlib/dlib/external/pybind11/README.md129
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/attr.h489
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h108
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/cast.h2063
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h162
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/common.h2
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/complex.h61
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h626
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h802
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h185
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h335
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h249
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h53
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h612
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/embed.h194
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/eval.h117
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/functional.h85
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h200
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h1600
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/operators.h168
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/options.h65
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h1963
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h1332
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/stl.h370
-rw-r--r--ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h599
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake57
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake81
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake195
-rwxr-xr-xml/dlib/dlib/external/pybind11/tools/check-style.sh70
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/libsize.py38
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/mkdoc.py304
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in100
-rw-r--r--ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake202
-rw-r--r--ml/dlib/dlib/external/zlib/README115
-rw-r--r--ml/dlib/dlib/external/zlib/adler32.c179
-rw-r--r--ml/dlib/dlib/external/zlib/compress.c80
-rw-r--r--ml/dlib/dlib/external/zlib/crc32.c425
-rw-r--r--ml/dlib/dlib/external/zlib/crc32.h441
-rw-r--r--ml/dlib/dlib/external/zlib/deflate.c1967
-rw-r--r--ml/dlib/dlib/external/zlib/deflate.h346
-rw-r--r--ml/dlib/dlib/external/zlib/gzclose.c25
-rw-r--r--ml/dlib/dlib/external/zlib/gzguts.h219
-rw-r--r--ml/dlib/dlib/external/zlib/gzlib.c634
-rw-r--r--ml/dlib/dlib/external/zlib/gzread.c594
-rw-r--r--ml/dlib/dlib/external/zlib/gzwrite.c577
-rw-r--r--ml/dlib/dlib/external/zlib/infback.c640
-rw-r--r--ml/dlib/dlib/external/zlib/inffast.c340
-rw-r--r--ml/dlib/dlib/external/zlib/inffast.h11
-rw-r--r--ml/dlib/dlib/external/zlib/inffixed.h94
-rw-r--r--ml/dlib/dlib/external/zlib/inflate.c1512
-rw-r--r--ml/dlib/dlib/external/zlib/inflate.h122
-rw-r--r--ml/dlib/dlib/external/zlib/inftrees.c306
-rw-r--r--ml/dlib/dlib/external/zlib/inftrees.h62
-rw-r--r--ml/dlib/dlib/external/zlib/trees.c1226
-rw-r--r--ml/dlib/dlib/external/zlib/trees.h128
-rw-r--r--ml/dlib/dlib/external/zlib/uncompr.c59
-rw-r--r--ml/dlib/dlib/external/zlib/zconf.h511
-rw-r--r--ml/dlib/dlib/external/zlib/zlib.h1768
-rw-r--r--ml/dlib/dlib/external/zlib/zutil.c324
-rw-r--r--ml/dlib/dlib/external/zlib/zutil.h253
-rw-r--r--ml/dlib/dlib/filtering.h12
-rw-r--r--ml/dlib/dlib/filtering/kalman_filter.cpp104
-rw-r--r--ml/dlib/dlib/filtering/kalman_filter.h382
-rw-r--r--ml/dlib/dlib/filtering/kalman_filter_abstract.h492
-rw-r--r--ml/dlib/dlib/filtering/rls_filter.h198
-rw-r--r--ml/dlib/dlib/filtering/rls_filter_abstract.h171
-rw-r--r--ml/dlib/dlib/float_details.h161
-rw-r--r--ml/dlib/dlib/fstream1
-rw-r--r--ml/dlib/dlib/general_hash/count_bits.h82
-rw-r--r--ml/dlib/dlib/general_hash/count_bits_abstract.h48
-rw-r--r--ml/dlib/dlib/general_hash/general_hash.h80
-rw-r--r--ml/dlib/dlib/general_hash/hash.h142
-rw-r--r--ml/dlib/dlib/general_hash/hash_abstract.h182
-rw-r--r--ml/dlib/dlib/general_hash/murmur_hash3.h519
-rw-r--r--ml/dlib/dlib/general_hash/murmur_hash3_abstract.h125
-rw-r--r--ml/dlib/dlib/general_hash/random_hashing.h877
-rw-r--r--ml/dlib/dlib/general_hash/random_hashing_abstract.h58
-rw-r--r--ml/dlib/dlib/geometry.h14
-rw-r--r--ml/dlib/dlib/geometry/border_enumerator.h186
-rw-r--r--ml/dlib/dlib/geometry/border_enumerator_abstract.h126
-rw-r--r--ml/dlib/dlib/geometry/drectangle.h488
-rw-r--r--ml/dlib/dlib/geometry/drectangle_abstract.h628
-rw-r--r--ml/dlib/dlib/geometry/point_transforms.h989
-rw-r--r--ml/dlib/dlib/geometry/point_transforms_abstract.h797
-rw-r--r--ml/dlib/dlib/geometry/rectangle.h824
-rw-r--r--ml/dlib/dlib/geometry/rectangle_abstract.h836
-rw-r--r--ml/dlib/dlib/geometry/vector.h1330
-rw-r--r--ml/dlib/dlib/geometry/vector_abstract.h489
-rw-r--r--ml/dlib/dlib/global_optimization.h14
-rw-r--r--ml/dlib/dlib/global_optimization/find_max_global.h511
-rw-r--r--ml/dlib/dlib/global_optimization/find_max_global_abstract.h496
-rw-r--r--ml/dlib/dlib/global_optimization/global_function_search.cpp942
-rw-r--r--ml/dlib/dlib/global_optimization/global_function_search.h245
-rw-r--r--ml/dlib/dlib/global_optimization/global_function_search_abstract.h605
-rw-r--r--ml/dlib/dlib/global_optimization/upper_bound_function.h286
-rw-r--r--ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h212
-rw-r--r--ml/dlib/dlib/graph.h37
-rw-r--r--ml/dlib/dlib/graph/graph_kernel_1.h629
-rw-r--r--ml/dlib/dlib/graph/graph_kernel_abstract.h329
-rw-r--r--ml/dlib/dlib/graph_cuts.h14
-rw-r--r--ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h959
-rw-r--r--ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h636
-rw-r--r--ml/dlib/dlib/graph_cuts/general_flow_graph.h172
-rw-r--r--ml/dlib/dlib/graph_cuts/general_potts_problem.h99
-rw-r--r--ml/dlib/dlib/graph_cuts/graph_labeler.h211
-rw-r--r--ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h185
-rw-r--r--ml/dlib/dlib/graph_cuts/min_cut.h571
-rw-r--r--ml/dlib/dlib/graph_cuts/min_cut_abstract.h476
-rw-r--r--ml/dlib/dlib/graph_utils.h12
-rw-r--r--ml/dlib/dlib/graph_utils/edge_list_graphs.h593
-rw-r--r--ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h358
-rw-r--r--ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h217
-rw-r--r--ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h102
-rw-r--r--ml/dlib/dlib/graph_utils/function_objects.h129
-rw-r--r--ml/dlib/dlib/graph_utils/function_objects_abstract.h209
-rw-r--r--ml/dlib/dlib/graph_utils/graph_utils.h1227
-rw-r--r--ml/dlib/dlib/graph_utils/graph_utils_abstract.h452
-rw-r--r--ml/dlib/dlib/graph_utils/ordered_sample_pair.h125
-rw-r--r--ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h128
-rw-r--r--ml/dlib/dlib/graph_utils/sample_pair.h179
-rw-r--r--ml/dlib/dlib/graph_utils/sample_pair_abstract.h192
-rw-r--r--ml/dlib/dlib/graph_utils_threaded.h12
-rw-r--r--ml/dlib/dlib/gui_core.h20
-rw-r--r--ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp2204
-rw-r--r--ml/dlib/dlib/gui_core/gui_core_kernel_1.h420
-rw-r--r--ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp1996
-rw-r--r--ml/dlib/dlib/gui_core/gui_core_kernel_2.h419
-rw-r--r--ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h792
-rw-r--r--ml/dlib/dlib/gui_core/windows.h6
-rw-r--r--ml/dlib/dlib/gui_core/xlib.h6
-rw-r--r--ml/dlib/dlib/gui_widgets.h18
-rw-r--r--ml/dlib/dlib/gui_widgets/base_widgets.cpp3343
-rw-r--r--ml/dlib/dlib/gui_widgets/base_widgets.h2678
-rw-r--r--ml/dlib/dlib/gui_widgets/base_widgets_abstract.h2290
-rw-r--r--ml/dlib/dlib/gui_widgets/canvas_drawing.cpp101
-rw-r--r--ml/dlib/dlib/gui_widgets/canvas_drawing.h964
-rw-r--r--ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h364
-rw-r--r--ml/dlib/dlib/gui_widgets/drawable.cpp544
-rw-r--r--ml/dlib/dlib/gui_widgets/drawable.h527
-rw-r--r--ml/dlib/dlib/gui_widgets/drawable_abstract.h717
-rw-r--r--ml/dlib/dlib/gui_widgets/fonts.cpp673
-rw-r--r--ml/dlib/dlib/gui_widgets/fonts.h628
-rw-r--r--ml/dlib/dlib/gui_widgets/fonts_abstract.h492
-rw-r--r--ml/dlib/dlib/gui_widgets/nativefont.h612
-rw-r--r--ml/dlib/dlib/gui_widgets/style.cpp998
-rw-r--r--ml/dlib/dlib/gui_widgets/style.h825
-rw-r--r--ml/dlib/dlib/gui_widgets/style_abstract.h777
-rw-r--r--ml/dlib/dlib/gui_widgets/widgets.cpp7341
-rw-r--r--ml/dlib/dlib/gui_widgets/widgets.h4165
-rw-r--r--ml/dlib/dlib/gui_widgets/widgets_abstract.h3461
-rw-r--r--ml/dlib/dlib/hash.h14
-rw-r--r--ml/dlib/dlib/hash_map.h63
-rw-r--r--ml/dlib/dlib/hash_map/hash_map_kernel_1.h460
-rw-r--r--ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h247
-rw-r--r--ml/dlib/dlib/hash_map/hash_map_kernel_c.h276
-rw-r--r--ml/dlib/dlib/hash_set.h63
-rw-r--r--ml/dlib/dlib/hash_set/hash_set_kernel_1.h391
-rw-r--r--ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h207
-rw-r--r--ml/dlib/dlib/hash_set/hash_set_kernel_c.h190
-rw-r--r--ml/dlib/dlib/hash_table.h60
-rw-r--r--ml/dlib/dlib/hash_table/hash_table_kernel_1.h819
-rw-r--r--ml/dlib/dlib/hash_table/hash_table_kernel_2.h612
-rw-r--r--ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h253
-rw-r--r--ml/dlib/dlib/hash_table/hash_table_kernel_c.h194
-rw-r--r--ml/dlib/dlib/http_client/http_client.cpp743
-rw-r--r--ml/dlib/dlib/http_client/http_client.h101
-rw-r--r--ml/dlib/dlib/http_client/http_client_abstract.h218
-rw-r--r--ml/dlib/dlib/image_io.h20
-rw-r--r--ml/dlib/dlib/image_keypoint.h16
-rw-r--r--ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h433
-rw-r--r--ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h287
-rw-r--r--ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h186
-rw-r--r--ml/dlib/dlib/image_keypoint/draw_surf_points.h40
-rw-r--r--ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h30
-rw-r--r--ml/dlib/dlib/image_keypoint/fine_hog_image.h378
-rw-r--r--ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h276
-rw-r--r--ml/dlib/dlib/image_keypoint/hashed_feature_image.h518
-rw-r--r--ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h303
-rw-r--r--ml/dlib/dlib/image_keypoint/hessian_pyramid.h531
-rw-r--r--ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h244
-rw-r--r--ml/dlib/dlib/image_keypoint/hog.h514
-rw-r--r--ml/dlib/dlib/image_keypoint/hog_abstract.h335
-rw-r--r--ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h408
-rw-r--r--ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h254
-rw-r--r--ml/dlib/dlib/image_keypoint/poly_image.h649
-rw-r--r--ml/dlib/dlib/image_keypoint/poly_image_abstract.h335
-rw-r--r--ml/dlib/dlib/image_keypoint/surf.h295
-rw-r--r--ml/dlib/dlib/image_keypoint/surf_abstract.h163
-rw-r--r--ml/dlib/dlib/image_loader/image_loader.h863
-rw-r--r--ml/dlib/dlib/image_loader/image_loader_abstract.h136
-rw-r--r--ml/dlib/dlib/image_loader/jpeg_loader.cpp173
-rw-r--r--ml/dlib/dlib/image_loader/jpeg_loader.h109
-rw-r--r--ml/dlib/dlib/image_loader/jpeg_loader_abstract.h133
-rw-r--r--ml/dlib/dlib/image_loader/load_image.h226
-rw-r--r--ml/dlib/dlib/image_loader/load_image_abstract.h37
-rw-r--r--ml/dlib/dlib/image_loader/png_loader.cpp222
-rw-r--r--ml/dlib/dlib/image_loader/png_loader.h223
-rw-r--r--ml/dlib/dlib/image_loader/png_loader_abstract.h162
-rw-r--r--ml/dlib/dlib/image_processing.h28
-rw-r--r--ml/dlib/dlib/image_processing/box_overlap_testing.h215
-rw-r--r--ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h201
-rw-r--r--ml/dlib/dlib/image_processing/correlation_tracker.h404
-rw-r--r--ml/dlib/dlib/image_processing/correlation_tracker_abstract.h162
-rw-r--r--ml/dlib/dlib/image_processing/detection_template_tools.h113
-rw-r--r--ml/dlib/dlib/image_processing/detection_template_tools_abstract.h95
-rw-r--r--ml/dlib/dlib/image_processing/frontal_face_detector.h2373
-rw-r--r--ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h25
-rw-r--r--ml/dlib/dlib/image_processing/full_object_detection.h191
-rw-r--r--ml/dlib/dlib/image_processing/full_object_detection_abstract.h203
-rw-r--r--ml/dlib/dlib/image_processing/generic_image.h431
-rw-r--r--ml/dlib/dlib/image_processing/object_detector.h626
-rw-r--r--ml/dlib/dlib/image_processing/object_detector_abstract.h404
-rw-r--r--ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h317
-rw-r--r--ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h56
-rw-r--r--ml/dlib/dlib/image_processing/render_face_detections.h99
-rw-r--r--ml/dlib/dlib/image_processing/render_face_detections_abstract.h59
-rw-r--r--ml/dlib/dlib/image_processing/scan_fhog_pyramid.h1348
-rw-r--r--ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h784
-rw-r--r--ml/dlib/dlib/image_processing/scan_image.h368
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_abstract.h227
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_boxes.h630
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h394
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_custom.h401
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_custom_abstract.h390
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_pyramid.h1101
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h495
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h180
-rw-r--r--ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h118
-rw-r--r--ml/dlib/dlib/image_processing/setup_hashed_features.h219
-rw-r--r--ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h210
-rw-r--r--ml/dlib/dlib/image_processing/shape_predictor.h524
-rw-r--r--ml/dlib/dlib/image_processing/shape_predictor_abstract.h195
-rw-r--r--ml/dlib/dlib/image_processing/shape_predictor_trainer.h852
-rw-r--r--ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h418
-rw-r--r--ml/dlib/dlib/image_saver/dng_shared.h288
-rw-r--r--ml/dlib/dlib/image_saver/image_saver.h688
-rw-r--r--ml/dlib/dlib/image_saver/image_saver_abstract.h129
-rw-r--r--ml/dlib/dlib/image_saver/save_jpeg.cpp175
-rw-r--r--ml/dlib/dlib/image_saver/save_jpeg.h82
-rw-r--r--ml/dlib/dlib/image_saver/save_jpeg_abstract.h52
-rw-r--r--ml/dlib/dlib/image_saver/save_png.cpp124
-rw-r--r--ml/dlib/dlib/image_saver/save_png.h162
-rw-r--r--ml/dlib/dlib/image_saver/save_png_abstract.h50
-rw-r--r--ml/dlib/dlib/image_transforms.h31
-rw-r--r--ml/dlib/dlib/image_transforms/assign_image.h385
-rw-r--r--ml/dlib/dlib/image_transforms/assign_image_abstract.h196
-rw-r--r--ml/dlib/dlib/image_transforms/colormaps.h269
-rw-r--r--ml/dlib/dlib/image_transforms/colormaps_abstract.h152
-rw-r--r--ml/dlib/dlib/image_transforms/draw.h396
-rw-r--r--ml/dlib/dlib/image_transforms/draw_abstract.h150
-rw-r--r--ml/dlib/dlib/image_transforms/edge_detector.h302
-rw-r--r--ml/dlib/dlib/image_transforms/edge_detector_abstract.h112
-rw-r--r--ml/dlib/dlib/image_transforms/equalize_histogram.h143
-rw-r--r--ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h91
-rw-r--r--ml/dlib/dlib/image_transforms/fhog.h1404
-rw-r--r--ml/dlib/dlib/image_transforms/fhog_abstract.h346
-rw-r--r--ml/dlib/dlib/image_transforms/hough_transform.h358
-rw-r--r--ml/dlib/dlib/image_transforms/hough_transform_abstract.h145
-rw-r--r--ml/dlib/dlib/image_transforms/image_pyramid.h1238
-rw-r--r--ml/dlib/dlib/image_transforms/image_pyramid_abstract.h384
-rw-r--r--ml/dlib/dlib/image_transforms/integral_image.h190
-rw-r--r--ml/dlib/dlib/image_transforms/integral_image_abstract.h169
-rw-r--r--ml/dlib/dlib/image_transforms/interpolation.h2193
-rw-r--r--ml/dlib/dlib/image_transforms/interpolation_abstract.h1480
-rw-r--r--ml/dlib/dlib/image_transforms/label_connected_blobs.h188
-rw-r--r--ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h199
-rw-r--r--ml/dlib/dlib/image_transforms/lbp.h307
-rw-r--r--ml/dlib/dlib/image_transforms/lbp_abstract.h139
-rw-r--r--ml/dlib/dlib/image_transforms/morphological_operations.h846
-rw-r--r--ml/dlib/dlib/image_transforms/morphological_operations_abstract.h316
-rw-r--r--ml/dlib/dlib/image_transforms/random_color_transform.h157
-rw-r--r--ml/dlib/dlib/image_transforms/random_color_transform_abstract.h94
-rw-r--r--ml/dlib/dlib/image_transforms/random_cropper.h361
-rw-r--r--ml/dlib/dlib/image_transforms/random_cropper_abstract.h346
-rw-r--r--ml/dlib/dlib/image_transforms/segment_image.h730
-rw-r--r--ml/dlib/dlib/image_transforms/segment_image_abstract.h126
-rw-r--r--ml/dlib/dlib/image_transforms/spatial_filtering.h1580
-rw-r--r--ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h487
-rw-r--r--ml/dlib/dlib/image_transforms/thresholding.h340
-rw-r--r--ml/dlib/dlib/image_transforms/thresholding_abstract.h139
-rw-r--r--ml/dlib/dlib/interfaces/cmd_line_parser_option.h107
-rw-r--r--ml/dlib/dlib/interfaces/enumerable.h130
-rw-r--r--ml/dlib/dlib/interfaces/map_pair.h74
-rw-r--r--ml/dlib/dlib/interfaces/remover.h220
-rw-r--r--ml/dlib/dlib/iomanip1
-rw-r--r--ml/dlib/dlib/iosfwd1
-rw-r--r--ml/dlib/dlib/iosockstream.h11
-rw-r--r--ml/dlib/dlib/iosockstream/iosockstream.h171
-rw-r--r--ml/dlib/dlib/iosockstream/iosockstream_abstract.h171
-rw-r--r--ml/dlib/dlib/iostream1
-rw-r--r--ml/dlib/dlib/is_kind.h162
-rw-r--r--ml/dlib/dlib/istream1
-rw-r--r--ml/dlib/dlib/java/CMakeLists.txt32
-rw-r--r--ml/dlib/dlib/java/cmake_swig_jni265
-rw-r--r--ml/dlib/dlib/java/java_array.h605
-rwxr-xr-xml/dlib/dlib/java/run_test.sh17
-rw-r--r--ml/dlib/dlib/java/swig_api.h126
-rw-r--r--ml/dlib/dlib/java/swig_test.java254
-rw-r--r--ml/dlib/dlib/linker.h9
-rw-r--r--ml/dlib/dlib/linker/linker_kernel_1.cpp357
-rw-r--r--ml/dlib/dlib/linker/linker_kernel_1.h141
-rw-r--r--ml/dlib/dlib/linker/linker_kernel_abstract.h141
-rw-r--r--ml/dlib/dlib/locale1
-rw-r--r--ml/dlib/dlib/logger.h11
-rw-r--r--ml/dlib/dlib/logger/extra_logger_headers.cpp40
-rw-r--r--ml/dlib/dlib/logger/extra_logger_headers.h41
-rw-r--r--ml/dlib/dlib/logger/logger_config_file.cpp214
-rw-r--r--ml/dlib/dlib/logger/logger_config_file.h135
-rw-r--r--ml/dlib/dlib/logger/logger_kernel_1.cpp498
-rw-r--r--ml/dlib/dlib/logger/logger_kernel_1.h687
-rw-r--r--ml/dlib/dlib/logger/logger_kernel_abstract.h429
-rw-r--r--ml/dlib/dlib/lsh.h14
-rw-r--r--ml/dlib/dlib/lsh/create_random_projection_hash.h232
-rw-r--r--ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h148
-rw-r--r--ml/dlib/dlib/lsh/hashes.h219
-rw-r--r--ml/dlib/dlib/lsh/hashes_abstract.h286
-rw-r--r--ml/dlib/dlib/lsh/projection_hash.h118
-rw-r--r--ml/dlib/dlib/lsh/projection_hash_abstract.h119
-rw-r--r--ml/dlib/dlib/lz77_buffer.h47
-rw-r--r--ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h263
-rw-r--r--ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h504
-rw-r--r--ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h210
-rw-r--r--ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h169
-rw-r--r--ml/dlib/dlib/lzp_buffer.h46
-rw-r--r--ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h236
-rw-r--r--ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h319
-rw-r--r--ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h130
-rw-r--r--ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h101
-rw-r--r--ml/dlib/dlib/manifold_regularization.h13
-rw-r--r--ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h328
-rw-r--r--ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h137
-rw-r--r--ml/dlib/dlib/map.h59
-rw-r--r--ml/dlib/dlib/map/map_kernel_1.h436
-rw-r--r--ml/dlib/dlib/map/map_kernel_abstract.h235
-rw-r--r--ml/dlib/dlib/map/map_kernel_c.h248
-rw-r--r--ml/dlib/dlib/matlab/CMakeLists.txt22
-rw-r--r--ml/dlib/dlib/matlab/README.txt20
-rw-r--r--ml/dlib/dlib/matlab/call_matlab.h852
-rw-r--r--ml/dlib/dlib/matlab/cmake_mex_wrapper103
-rw-r--r--ml/dlib/dlib/matlab/example.m16
-rw-r--r--ml/dlib/dlib/matlab/example_mex_callback.cpp52
-rw-r--r--ml/dlib/dlib/matlab/example_mex_class.cpp72
-rw-r--r--ml/dlib/dlib/matlab/example_mex_function.cpp84
-rw-r--r--ml/dlib/dlib/matlab/example_mex_struct.cpp55
-rw-r--r--ml/dlib/dlib/matlab/mex_wrapper.cpp5144
-rw-r--r--ml/dlib/dlib/matlab/subprocess_stream.cpp537
-rw-r--r--ml/dlib/dlib/matlab/subprocess_stream.h223
-rw-r--r--ml/dlib/dlib/matrix.h24
-rw-r--r--ml/dlib/dlib/matrix/cblas_constants.h22
-rw-r--r--ml/dlib/dlib/matrix/lapack/fortran_id.h62
-rw-r--r--ml/dlib/dlib/matrix/lapack/gees.h264
-rw-r--r--ml/dlib/dlib/matrix/lapack/geev.h234
-rw-r--r--ml/dlib/dlib/matrix/lapack/geqrf.h168
-rw-r--r--ml/dlib/dlib/matrix/lapack/gesdd.h364
-rw-r--r--ml/dlib/dlib/matrix/lapack/gesvd.h323
-rw-r--r--ml/dlib/dlib/matrix/lapack/getrf.h132
-rw-r--r--ml/dlib/dlib/matrix/lapack/ormqr.h224
-rw-r--r--ml/dlib/dlib/matrix/lapack/pbtrf.h178
-rw-r--r--ml/dlib/dlib/matrix/lapack/potrf.h174
-rw-r--r--ml/dlib/dlib/matrix/lapack/syev.h218
-rw-r--r--ml/dlib/dlib/matrix/lapack/syevr.h445
-rw-r--r--ml/dlib/dlib/matrix/matrix.h2162
-rw-r--r--ml/dlib/dlib/matrix/matrix_abstract.h857
-rw-r--r--ml/dlib/dlib/matrix/matrix_assign.h978
-rw-r--r--ml/dlib/dlib/matrix/matrix_assign_fwd.h413
-rw-r--r--ml/dlib/dlib/matrix/matrix_blas_bindings.h1637
-rw-r--r--ml/dlib/dlib/matrix/matrix_cholesky.h231
-rw-r--r--ml/dlib/dlib/matrix/matrix_conj_trans.h71
-rw-r--r--ml/dlib/dlib/matrix/matrix_conv.h358
-rw-r--r--ml/dlib/dlib/matrix/matrix_conv_abstract.h158
-rw-r--r--ml/dlib/dlib/matrix/matrix_data_layout.h1271
-rw-r--r--ml/dlib/dlib/matrix/matrix_data_layout_abstract.h40
-rw-r--r--ml/dlib/dlib/matrix/matrix_default_mul.h134
-rw-r--r--ml/dlib/dlib/matrix/matrix_eigenvalue.h1379
-rw-r--r--ml/dlib/dlib/matrix/matrix_exp.h271
-rw-r--r--ml/dlib/dlib/matrix/matrix_exp_abstract.h210
-rw-r--r--ml/dlib/dlib/matrix/matrix_expressions.h280
-rw-r--r--ml/dlib/dlib/matrix/matrix_fft.h846
-rw-r--r--ml/dlib/dlib/matrix/matrix_fft_abstract.h118
-rw-r--r--ml/dlib/dlib/matrix/matrix_fwd.h31
-rw-r--r--ml/dlib/dlib/matrix/matrix_generic_image.h110
-rw-r--r--ml/dlib/dlib/matrix/matrix_la.h1807
-rw-r--r--ml/dlib/dlib/matrix/matrix_la_abstract.h1005
-rw-r--r--ml/dlib/dlib/matrix/matrix_lu.h361
-rw-r--r--ml/dlib/dlib/matrix/matrix_mat.h733
-rw-r--r--ml/dlib/dlib/matrix/matrix_mat_abstract.h243
-rw-r--r--ml/dlib/dlib/matrix/matrix_math_functions.h448
-rw-r--r--ml/dlib/dlib/matrix/matrix_math_functions_abstract.h595
-rw-r--r--ml/dlib/dlib/matrix/matrix_op.h479
-rw-r--r--ml/dlib/dlib/matrix/matrix_qr.h466
-rw-r--r--ml/dlib/dlib/matrix/matrix_read_from_istream.h108
-rw-r--r--ml/dlib/dlib/matrix/matrix_subexp.h1566
-rw-r--r--ml/dlib/dlib/matrix/matrix_subexp_abstract.h570
-rw-r--r--ml/dlib/dlib/matrix/matrix_trsm.h654
-rw-r--r--ml/dlib/dlib/matrix/matrix_utilities.h4544
-rw-r--r--ml/dlib/dlib/matrix/matrix_utilities_abstract.h1874
-rw-r--r--ml/dlib/dlib/matrix/symmetric_matrix_cache.h464
-rw-r--r--ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h63
-rw-r--r--ml/dlib/dlib/md5.h3
-rw-r--r--ml/dlib/dlib/md5/md5_kernel_1.cpp617
-rw-r--r--ml/dlib/dlib/md5/md5_kernel_1.h50
-rw-r--r--ml/dlib/dlib/md5/md5_kernel_abstract.h83
-rw-r--r--ml/dlib/dlib/member_function_pointer.h10
-rw-r--r--ml/dlib/dlib/member_function_pointer/make_mfp.h179
-rw-r--r--ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h207
-rw-r--r--ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h498
-rw-r--r--ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h483
-rw-r--r--ml/dlib/dlib/memory_manager.h73
-rw-r--r--ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h305
-rw-r--r--ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h253
-rw-r--r--ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h385
-rw-r--r--ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h146
-rw-r--r--ml/dlib/dlib/memory_manager_global.h38
-rw-r--r--ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h113
-rw-r--r--ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h181
-rw-r--r--ml/dlib/dlib/memory_manager_stateless.h72
-rw-r--r--ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h86
-rw-r--r--ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h119
-rw-r--r--ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h142
-rw-r--r--ml/dlib/dlib/metaprogramming.h71
-rw-r--r--ml/dlib/dlib/misc_api.h20
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp149
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_kernel_1.h110
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp123
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_kernel_2.h81
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h159
-rw-r--r--ml/dlib/dlib/misc_api/misc_api_shared.h57
-rw-r--r--ml/dlib/dlib/misc_api/posix.h6
-rw-r--r--ml/dlib/dlib/misc_api/windows.h6
-rw-r--r--ml/dlib/dlib/mlp.h30
-rw-r--r--ml/dlib/dlib/mlp/mlp_kernel_1.h394
-rw-r--r--ml/dlib/dlib/mlp/mlp_kernel_abstract.h225
-rw-r--r--ml/dlib/dlib/mlp/mlp_kernel_c.h151
-rw-r--r--ml/dlib/dlib/noncopyable.h32
-rw-r--r--ml/dlib/dlib/numeric_constants.h53
-rw-r--r--ml/dlib/dlib/numerical_integration.h8
-rw-r--r--ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h93
-rw-r--r--ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h34
-rw-r--r--ml/dlib/dlib/opencv.h17
-rw-r--r--ml/dlib/dlib/opencv/cv_image.h225
-rw-r--r--ml/dlib/dlib/opencv/cv_image_abstract.h280
-rw-r--r--ml/dlib/dlib/opencv/to_open_cv.h46
-rw-r--r--ml/dlib/dlib/opencv/to_open_cv_abstract.h34
-rw-r--r--ml/dlib/dlib/optimization.h24
-rw-r--r--ml/dlib/dlib/optimization/elastic_net.h389
-rw-r--r--ml/dlib/dlib/optimization/elastic_net_abstract.h190
-rw-r--r--ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h337
-rw-r--r--ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h365
-rw-r--r--ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h232
-rw-r--r--ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h131
-rw-r--r--ml/dlib/dlib/optimization/find_max_parse_cky.h414
-rw-r--r--ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h388
-rw-r--r--ml/dlib/dlib/optimization/find_optimal_parameters.h117
-rw-r--r--ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h58
-rw-r--r--ml/dlib/dlib/optimization/isotonic_regression.h169
-rw-r--r--ml/dlib/dlib/optimization/isotonic_regression_abstract.h128
-rw-r--r--ml/dlib/dlib/optimization/max_cost_assignment.h288
-rw-r--r--ml/dlib/dlib/optimization/max_cost_assignment_abstract.h63
-rw-r--r--ml/dlib/dlib/optimization/max_sum_submatrix.h285
-rw-r--r--ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h49
-rw-r--r--ml/dlib/dlib/optimization/optimization.h714
-rw-r--r--ml/dlib/dlib/optimization/optimization_abstract.h468
-rw-r--r--ml/dlib/dlib/optimization/optimization_bobyqa.h3423
-rw-r--r--ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h120
-rw-r--r--ml/dlib/dlib/optimization/optimization_least_squares.h345
-rw-r--r--ml/dlib/dlib/optimization/optimization_least_squares_abstract.h112
-rw-r--r--ml/dlib/dlib/optimization/optimization_line_search.h888
-rw-r--r--ml/dlib/dlib/optimization/optimization_line_search_abstract.h361
-rw-r--r--ml/dlib/dlib/optimization/optimization_oca.h407
-rw-r--r--ml/dlib/dlib/optimization/optimization_oca_abstract.h334
-rw-r--r--ml/dlib/dlib/optimization/optimization_search_strategies.h324
-rw-r--r--ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h330
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h468
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h150
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h455
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h139
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h937
-rw-r--r--ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h282
-rw-r--r--ml/dlib/dlib/optimization/optimization_stop_strategies.h173
-rw-r--r--ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h157
-rw-r--r--ml/dlib/dlib/optimization/optimization_trust_region.h564
-rw-r--r--ml/dlib/dlib/optimization/optimization_trust_region_abstract.h233
-rw-r--r--ml/dlib/dlib/ostream1
-rw-r--r--ml/dlib/dlib/pipe.h10
-rw-r--r--ml/dlib/dlib/pipe/pipe_kernel_1.h756
-rw-r--r--ml/dlib/dlib/pipe/pipe_kernel_abstract.h323
-rw-r--r--ml/dlib/dlib/pixel.h1649
-rw-r--r--ml/dlib/dlib/platform.h65
-rw-r--r--ml/dlib/dlib/python.h14
-rw-r--r--ml/dlib/dlib/python/numpy.h214
-rw-r--r--ml/dlib/dlib/python/numpy_image.h129
-rw-r--r--ml/dlib/dlib/python/pyassert.h17
-rw-r--r--ml/dlib/dlib/python/pybind_utils.h82
-rw-r--r--ml/dlib/dlib/python/serialize_pickle.h66
-rw-r--r--ml/dlib/dlib/quantum_computing.h12
-rw-r--r--ml/dlib/dlib/quantum_computing/quantum_computing.h863
-rw-r--r--ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h590
-rw-r--r--ml/dlib/dlib/queue.h84
-rw-r--r--ml/dlib/dlib/queue/queue_kernel_1.h554
-rw-r--r--ml/dlib/dlib/queue/queue_kernel_2.h600
-rw-r--r--ml/dlib/dlib/queue/queue_kernel_abstract.h196
-rw-r--r--ml/dlib/dlib/queue/queue_kernel_c.h187
-rw-r--r--ml/dlib/dlib/queue/queue_sort_1.h165
-rw-r--r--ml/dlib/dlib/queue/queue_sort_abstract.h74
-rw-r--r--ml/dlib/dlib/rand.h9
-rw-r--r--ml/dlib/dlib/rand/mersenne_twister.h210
-rw-r--r--ml/dlib/dlib/rand/rand_kernel_1.h354
-rw-r--r--ml/dlib/dlib/rand/rand_kernel_abstract.h218
-rw-r--r--ml/dlib/dlib/random_forest.h10
-rw-r--r--ml/dlib/dlib/random_forest/random_forest_regression.h738
-rw-r--r--ml/dlib/dlib/random_forest/random_forest_regression_abstract.h460
-rw-r--r--ml/dlib/dlib/ref.h84
-rw-r--r--ml/dlib/dlib/reference_counter.h31
-rw-r--r--ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h298
-rw-r--r--ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h141
-rw-r--r--ml/dlib/dlib/revision.h.in6
-rw-r--r--ml/dlib/dlib/sequence.h83
-rw-r--r--ml/dlib/dlib/sequence/sequence_compare_1.h102
-rw-r--r--ml/dlib/dlib/sequence/sequence_compare_abstract.h75
-rw-r--r--ml/dlib/dlib/sequence/sequence_kernel_1.h1340
-rw-r--r--ml/dlib/dlib/sequence/sequence_kernel_2.h682
-rw-r--r--ml/dlib/dlib/sequence/sequence_kernel_abstract.h199
-rw-r--r--ml/dlib/dlib/sequence/sequence_kernel_c.h253
-rw-r--r--ml/dlib/dlib/sequence/sequence_sort_1.h182
-rw-r--r--ml/dlib/dlib/sequence/sequence_sort_2.h65
-rw-r--r--ml/dlib/dlib/sequence/sequence_sort_abstract.h65
-rw-r--r--ml/dlib/dlib/serialize.h1779
-rw-r--r--ml/dlib/dlib/server.h12
-rw-r--r--ml/dlib/dlib/server/server_http.cpp409
-rw-r--r--ml/dlib/dlib/server/server_http.h242
-rw-r--r--ml/dlib/dlib/server/server_http_abstract.h390
-rw-r--r--ml/dlib/dlib/server/server_iostream.cpp14
-rw-r--r--ml/dlib/dlib/server/server_iostream.h155
-rw-r--r--ml/dlib/dlib/server/server_iostream_abstract.h84
-rw-r--r--ml/dlib/dlib/server/server_kernel.cpp595
-rw-r--r--ml/dlib/dlib/server/server_kernel.h234
-rw-r--r--ml/dlib/dlib/server/server_kernel_abstract.h310
-rw-r--r--ml/dlib/dlib/set.h74
-rw-r--r--ml/dlib/dlib/set/set_compare_1.h122
-rw-r--r--ml/dlib/dlib/set/set_compare_abstract.h96
-rw-r--r--ml/dlib/dlib/set/set_kernel_1.h372
-rw-r--r--ml/dlib/dlib/set/set_kernel_abstract.h192
-rw-r--r--ml/dlib/dlib/set/set_kernel_c.h194
-rw-r--r--ml/dlib/dlib/set_utils.h11
-rw-r--r--ml/dlib/dlib/set_utils/set_utils.h246
-rw-r--r--ml/dlib/dlib/set_utils/set_utils_abstract.h98
-rw-r--r--ml/dlib/dlib/simd.h12
-rw-r--r--ml/dlib/dlib/simd/simd4f.h685
-rw-r--r--ml/dlib/dlib/simd/simd4i.h566
-rw-r--r--ml/dlib/dlib/simd/simd8f.h402
-rw-r--r--ml/dlib/dlib/simd/simd8i.h339
-rw-r--r--ml/dlib/dlib/simd/simd_check.h177
-rw-r--r--ml/dlib/dlib/sliding_buffer.h38
-rw-r--r--ml/dlib/dlib/sliding_buffer/circular_buffer.h235
-rw-r--r--ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h257
-rw-r--r--ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h227
-rw-r--r--ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h205
-rw-r--r--ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h222
-rw-r--r--ml/dlib/dlib/smart_pointers.h22
-rw-r--r--ml/dlib/dlib/smart_pointers/scoped_ptr.h16
-rw-r--r--ml/dlib/dlib/smart_pointers/shared_ptr.h492
-rw-r--r--ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h374
-rw-r--r--ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h462
-rw-r--r--ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h352
-rw-r--r--ml/dlib/dlib/smart_pointers/weak_ptr.h225
-rw-r--r--ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h193
-rw-r--r--ml/dlib/dlib/smart_pointers_thread_safe.h21
-rw-r--r--ml/dlib/dlib/sockets.h20
-rw-r--r--ml/dlib/dlib/sockets/posix.h6
-rw-r--r--ml/dlib/dlib/sockets/sockets_extensions.cpp341
-rw-r--r--ml/dlib/dlib/sockets/sockets_extensions.h151
-rw-r--r--ml/dlib/dlib/sockets/sockets_extensions_abstract.h300
-rw-r--r--ml/dlib/dlib/sockets/sockets_kernel_1.cpp979
-rw-r--r--ml/dlib/dlib/sockets/sockets_kernel_1.h351
-rw-r--r--ml/dlib/dlib/sockets/sockets_kernel_2.cpp1109
-rw-r--r--ml/dlib/dlib/sockets/sockets_kernel_2.h396
-rw-r--r--ml/dlib/dlib/sockets/sockets_kernel_abstract.h495
-rw-r--r--ml/dlib/dlib/sockets/windows.h6
-rw-r--r--ml/dlib/dlib/sockstreambuf.h11
-rw-r--r--ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp177
-rw-r--r--ml/dlib/dlib/sockstreambuf/sockstreambuf.h172
-rw-r--r--ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h127
-rw-r--r--ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp168
-rw-r--r--ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h118
-rw-r--r--ml/dlib/dlib/sort.h490
-rw-r--r--ml/dlib/dlib/sparse_vector.h10
-rw-r--r--ml/dlib/dlib/sqlite.h11
-rw-r--r--ml/dlib/dlib/sqlite/sqlite.h625
-rw-r--r--ml/dlib/dlib/sqlite/sqlite_abstract.h506
-rw-r--r--ml/dlib/dlib/sqlite/sqlite_tools.h189
-rw-r--r--ml/dlib/dlib/sqlite/sqlite_tools_abstract.h164
-rw-r--r--ml/dlib/dlib/sstream1
-rw-r--r--ml/dlib/dlib/stack.h34
-rw-r--r--ml/dlib/dlib/stack/stack_kernel_1.h504
-rw-r--r--ml/dlib/dlib/stack/stack_kernel_abstract.h180
-rw-r--r--ml/dlib/dlib/stack/stack_kernel_c.h189
-rw-r--r--ml/dlib/dlib/stack_trace.cpp91
-rw-r--r--ml/dlib/dlib/stack_trace.h118
-rw-r--r--ml/dlib/dlib/static_map.h43
-rw-r--r--ml/dlib/dlib/static_map/static_map_kernel_1.h756
-rw-r--r--ml/dlib/dlib/static_map/static_map_kernel_abstract.h181
-rw-r--r--ml/dlib/dlib/static_map/static_map_kernel_c.h89
-rw-r--r--ml/dlib/dlib/static_set.h49
-rw-r--r--ml/dlib/dlib/static_set/static_set_compare_1.h122
-rw-r--r--ml/dlib/dlib/static_set/static_set_compare_abstract.h93
-rw-r--r--ml/dlib/dlib/static_set/static_set_kernel_1.h446
-rw-r--r--ml/dlib/dlib/static_set/static_set_kernel_abstract.h154
-rw-r--r--ml/dlib/dlib/static_set/static_set_kernel_c.h88
-rw-r--r--ml/dlib/dlib/statistics.h19
-rw-r--r--ml/dlib/dlib/statistics/average_precision.h66
-rw-r--r--ml/dlib/dlib/statistics/average_precision_abstract.h67
-rw-r--r--ml/dlib/dlib/statistics/cca.h186
-rw-r--r--ml/dlib/dlib/statistics/cca_abstract.h191
-rw-r--r--ml/dlib/dlib/statistics/dpca.h541
-rw-r--r--ml/dlib/dlib/statistics/dpca_abstract.h365
-rw-r--r--ml/dlib/dlib/statistics/image_feature_sampling.h82
-rw-r--r--ml/dlib/dlib/statistics/image_feature_sampling_abstract.h45
-rw-r--r--ml/dlib/dlib/statistics/lda.h237
-rw-r--r--ml/dlib/dlib/statistics/lda_abstract.h118
-rw-r--r--ml/dlib/dlib/statistics/random_subset_selector.h372
-rw-r--r--ml/dlib/dlib/statistics/random_subset_selector_abstract.h388
-rw-r--r--ml/dlib/dlib/statistics/running_gradient.h370
-rw-r--r--ml/dlib/dlib/statistics/running_gradient_abstract.h276
-rw-r--r--ml/dlib/dlib/statistics/sammon.h269
-rw-r--r--ml/dlib/dlib/statistics/sammon_abstract.h117
-rw-r--r--ml/dlib/dlib/statistics/statistics.h1890
-rw-r--r--ml/dlib/dlib/statistics/statistics_abstract.h1387
-rw-r--r--ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h618
-rw-r--r--ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h328
-rw-r--r--ml/dlib/dlib/std_allocator.h199
-rw-r--r--ml/dlib/dlib/stl_checked.h10
-rw-r--r--ml/dlib/dlib/stl_checked/std_vector_c.h333
-rw-r--r--ml/dlib/dlib/stl_checked/std_vector_c_abstract.h470
-rw-r--r--ml/dlib/dlib/string.h9
-rw-r--r--ml/dlib/dlib/string/cassert1
-rw-r--r--ml/dlib/dlib/string/iomanip1
-rw-r--r--ml/dlib/dlib/string/iosfwd1
-rw-r--r--ml/dlib/dlib/string/iostream1
-rw-r--r--ml/dlib/dlib/string/locale1
-rw-r--r--ml/dlib/dlib/string/string.h1004
-rw-r--r--ml/dlib/dlib/string/string_abstract.h652
-rw-r--r--ml/dlib/dlib/svm.h60
-rw-r--r--ml/dlib/dlib/svm/active_learning.h162
-rw-r--r--ml/dlib/dlib/svm/active_learning_abstract.h75
-rw-r--r--ml/dlib/dlib/svm/assignment_function.h255
-rw-r--r--ml/dlib/dlib/svm/assignment_function_abstract.h342
-rw-r--r--ml/dlib/dlib/svm/cross_validate_assignment_trainer.h181
-rw-r--r--ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h69
-rw-r--r--ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h258
-rw-r--r--ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h147
-rw-r--r--ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h208
-rw-r--r--ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h99
-rw-r--r--ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h430
-rw-r--r--ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h297
-rw-r--r--ml/dlib/dlib/svm/cross_validate_regression_trainer.h155
-rw-r--r--ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h82
-rw-r--r--ml/dlib/dlib/svm/cross_validate_sequence_labeler.h152
-rw-r--r--ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h83
-rw-r--r--ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h187
-rw-r--r--ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h80
-rw-r--r--ml/dlib/dlib/svm/cross_validate_track_association_trainer.h163
-rw-r--r--ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h69
-rw-r--r--ml/dlib/dlib/svm/empirical_kernel_map.h429
-rw-r--r--ml/dlib/dlib/svm/empirical_kernel_map_abstract.h430
-rw-r--r--ml/dlib/dlib/svm/feature_ranking.h477
-rw-r--r--ml/dlib/dlib/svm/feature_ranking_abstract.h136
-rw-r--r--ml/dlib/dlib/svm/function.h882
-rw-r--r--ml/dlib/dlib/svm/function_abstract.h997
-rw-r--r--ml/dlib/dlib/svm/kcentroid.h614
-rw-r--r--ml/dlib/dlib/svm/kcentroid_abstract.h339
-rw-r--r--ml/dlib/dlib/svm/kcentroid_overloads.h1324
-rw-r--r--ml/dlib/dlib/svm/kernel.h569
-rw-r--r--ml/dlib/dlib/svm/kernel_abstract.h681
-rw-r--r--ml/dlib/dlib/svm/kernel_matrix.h268
-rw-r--r--ml/dlib/dlib/svm/kernel_matrix_abstract.h115
-rw-r--r--ml/dlib/dlib/svm/kkmeans.h654
-rw-r--r--ml/dlib/dlib/svm/kkmeans_abstract.h365
-rw-r--r--ml/dlib/dlib/svm/krls.h358
-rw-r--r--ml/dlib/dlib/svm/krls_abstract.h202
-rw-r--r--ml/dlib/dlib/svm/krr_trainer.h368
-rw-r--r--ml/dlib/dlib/svm/krr_trainer_abstract.h322
-rw-r--r--ml/dlib/dlib/svm/linearly_independent_subset_finder.h540
-rw-r--r--ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h327
-rw-r--r--ml/dlib/dlib/svm/multiclass_tools.h68
-rw-r--r--ml/dlib/dlib/svm/multiclass_tools_abstract.h45
-rw-r--r--ml/dlib/dlib/svm/null_df.h33
-rw-r--r--ml/dlib/dlib/svm/null_trainer.h61
-rw-r--r--ml/dlib/dlib/svm/null_trainer_abstract.h101
-rw-r--r--ml/dlib/dlib/svm/num_nonnegative_weights.h76
-rw-r--r--ml/dlib/dlib/svm/one_vs_all_decision_function.h265
-rw-r--r--ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h214
-rw-r--r--ml/dlib/dlib/svm/one_vs_all_trainer.h234
-rw-r--r--ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h163
-rw-r--r--ml/dlib/dlib/svm/one_vs_one_decision_function.h291
-rw-r--r--ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h213
-rw-r--r--ml/dlib/dlib/svm/one_vs_one_trainer.h249
-rw-r--r--ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h166
-rw-r--r--ml/dlib/dlib/svm/pegasos.h710
-rw-r--r--ml/dlib/dlib/svm/pegasos_abstract.h514
-rw-r--r--ml/dlib/dlib/svm/ranking_tools.h448
-rw-r--r--ml/dlib/dlib/svm/ranking_tools_abstract.h247
-rw-r--r--ml/dlib/dlib/svm/rbf_network.h162
-rw-r--r--ml/dlib/dlib/svm/rbf_network_abstract.h132
-rw-r--r--ml/dlib/dlib/svm/reduced.h613
-rw-r--r--ml/dlib/dlib/svm/reduced_abstract.h267
-rw-r--r--ml/dlib/dlib/svm/rls.h232
-rw-r--r--ml/dlib/dlib/svm/rls_abstract.h175
-rw-r--r--ml/dlib/dlib/svm/roc_trainer.h149
-rw-r--r--ml/dlib/dlib/svm/roc_trainer_abstract.h135
-rw-r--r--ml/dlib/dlib/svm/rr_trainer.h456
-rw-r--r--ml/dlib/dlib/svm/rr_trainer_abstract.h255
-rw-r--r--ml/dlib/dlib/svm/rvm.h1018
-rw-r--r--ml/dlib/dlib/svm/rvm_abstract.h278
-rw-r--r--ml/dlib/dlib/svm/sequence_labeler.h339
-rw-r--r--ml/dlib/dlib/svm/sequence_labeler_abstract.h396
-rw-r--r--ml/dlib/dlib/svm/sequence_segmenter.h468
-rw-r--r--ml/dlib/dlib/svm/sequence_segmenter_abstract.h452
-rw-r--r--ml/dlib/dlib/svm/simplify_linear_decision_function.h110
-rw-r--r--ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h74
-rw-r--r--ml/dlib/dlib/svm/sort_basis_vectors.h224
-rw-r--r--ml/dlib/dlib/svm/sort_basis_vectors_abstract.h59
-rw-r--r--ml/dlib/dlib/svm/sparse_kernel.h384
-rw-r--r--ml/dlib/dlib/svm/sparse_kernel_abstract.h486
-rw-r--r--ml/dlib/dlib/svm/sparse_vector.h1170
-rw-r--r--ml/dlib/dlib/svm/sparse_vector_abstract.h688
-rw-r--r--ml/dlib/dlib/svm/structural_assignment_trainer.h294
-rw-r--r--ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h299
-rw-r--r--ml/dlib/dlib/svm/structural_graph_labeling_trainer.h282
-rw-r--r--ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h265
-rw-r--r--ml/dlib/dlib/svm/structural_object_detection_trainer.h402
-rw-r--r--ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h390
-rw-r--r--ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h271
-rw-r--r--ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h266
-rw-r--r--ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h281
-rw-r--r--ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h264
-rw-r--r--ml/dlib/dlib/svm/structural_svm_assignment_problem.h288
-rw-r--r--ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h87
-rw-r--r--ml/dlib/dlib/svm/structural_svm_distributed.h700
-rw-r--r--ml/dlib/dlib/svm/structural_svm_distributed_abstract.h357
-rw-r--r--ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h542
-rw-r--r--ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h249
-rw-r--r--ml/dlib/dlib/svm/structural_svm_object_detection_problem.h531
-rw-r--r--ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h178
-rw-r--r--ml/dlib/dlib/svm/structural_svm_problem.h649
-rw-r--r--ml/dlib/dlib/svm/structural_svm_problem_abstract.h348
-rw-r--r--ml/dlib/dlib/svm/structural_svm_problem_threaded.h157
-rw-r--r--ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h68
-rw-r--r--ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h281
-rw-r--r--ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h110
-rw-r--r--ml/dlib/dlib/svm/structural_track_association_trainer.h404
-rw-r--r--ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h268
-rw-r--r--ml/dlib/dlib/svm/svm.h1205
-rw-r--r--ml/dlib/dlib/svm/svm_abstract.h604
-rw-r--r--ml/dlib/dlib/svm/svm_c_ekm_trainer.h636
-rw-r--r--ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h384
-rw-r--r--ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h712
-rw-r--r--ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h382
-rw-r--r--ml/dlib/dlib/svm/svm_c_linear_trainer.h706
-rw-r--r--ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h359
-rw-r--r--ml/dlib/dlib/svm/svm_c_trainer.h359
-rw-r--r--ml/dlib/dlib/svm/svm_c_trainer_abstract.h237
-rw-r--r--ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h432
-rw-r--r--ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h275
-rw-r--r--ml/dlib/dlib/svm/svm_nu_trainer.h326
-rw-r--r--ml/dlib/dlib/svm/svm_nu_trainer_abstract.h210
-rw-r--r--ml/dlib/dlib/svm/svm_one_class_trainer.h284
-rw-r--r--ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h201
-rw-r--r--ml/dlib/dlib/svm/svm_rank_trainer.h495
-rw-r--r--ml/dlib/dlib/svm/svm_rank_trainer_abstract.h298
-rw-r--r--ml/dlib/dlib/svm/svm_threaded.h253
-rw-r--r--ml/dlib/dlib/svm/svm_threaded_abstract.h62
-rw-r--r--ml/dlib/dlib/svm/svr_linear_trainer.h424
-rw-r--r--ml/dlib/dlib/svm/svr_linear_trainer_abstract.h269
-rw-r--r--ml/dlib/dlib/svm/svr_trainer.h393
-rw-r--r--ml/dlib/dlib/svm/svr_trainer_abstract.h209
-rw-r--r--ml/dlib/dlib/svm/track_association_function.h154
-rw-r--r--ml/dlib/dlib/svm/track_association_function_abstract.h271
-rw-r--r--ml/dlib/dlib/svm_threaded.h36
-rw-r--r--ml/dlib/dlib/sync_extension.h31
-rw-r--r--ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h67
-rw-r--r--ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h190
-rw-r--r--ml/dlib/dlib/test/CMakeLists.txt181
-rw-r--r--ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat42
-rw-r--r--ml/dlib/dlib/test/active_learning.cpp165
-rw-r--r--ml/dlib/dlib/test/any.cpp139
-rw-r--r--ml/dlib/dlib/test/any_function.cpp253
-rw-r--r--ml/dlib/dlib/test/array.cpp669
-rw-r--r--ml/dlib/dlib/test/array2d.cpp580
-rw-r--r--ml/dlib/dlib/test/assignment_learning.cpp379
-rw-r--r--ml/dlib/dlib/test/base64.cpp208
-rw-r--r--ml/dlib/dlib/test/bayes_nets.cpp411
-rw-r--r--ml/dlib/dlib/test/bigint.cpp522
-rw-r--r--ml/dlib/dlib/test/binary_search_tree.h889
-rw-r--r--ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp47
-rw-r--r--ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp45
-rw-r--r--ml/dlib/dlib/test/binary_search_tree_mm1.cpp66
-rw-r--r--ml/dlib/dlib/test/binary_search_tree_mm2.cpp48
-rw-r--r--ml/dlib/dlib/test/blas_bindings/CMakeLists.txt33
-rw-r--r--ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp314
-rw-r--r--ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp311
-rw-r--r--ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp226
-rw-r--r--ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp200
-rw-r--r--ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp261
-rw-r--r--ml/dlib/dlib/test/blas_bindings/vector.cpp115
-rw-r--r--ml/dlib/dlib/test/bridge.cpp259
-rw-r--r--ml/dlib/dlib/test/bsp.cpp566
-rw-r--r--ml/dlib/dlib/test/byte_orderer.cpp111
-rw-r--r--ml/dlib/dlib/test/cca.cpp460
-rw-r--r--ml/dlib/dlib/test/checkerboard.h55
-rw-r--r--ml/dlib/dlib/test/clustering.cpp410
-rw-r--r--ml/dlib/dlib/test/cmd_line_parser.cpp40
-rw-r--r--ml/dlib/dlib/test/cmd_line_parser.h901
-rw-r--r--ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp40
-rw-r--r--ml/dlib/dlib/test/compress_stream.cpp306
-rw-r--r--ml/dlib/dlib/test/conditioning_class.cpp86
-rw-r--r--ml/dlib/dlib/test/conditioning_class.h841
-rw-r--r--ml/dlib/dlib/test/conditioning_class_c.cpp87
-rw-r--r--ml/dlib/dlib/test/config_reader.cpp509
-rw-r--r--ml/dlib/dlib/test/correlation_tracker.cpp955
-rw-r--r--ml/dlib/dlib/test/crc32.cpp74
-rw-r--r--ml/dlib/dlib/test/create_iris_datafile.cpp65
-rw-r--r--ml/dlib/dlib/test/create_iris_datafile.h19
-rw-r--r--ml/dlib/dlib/test/cublas.cpp198
-rw-r--r--ml/dlib/dlib/test/data_io.cpp227
-rw-r--r--ml/dlib/dlib/test/directed_graph.cpp541
-rw-r--r--ml/dlib/dlib/test/discriminant_pca.cpp365
-rw-r--r--ml/dlib/dlib/test/disjoint_subsets.cpp102
-rw-r--r--ml/dlib/dlib/test/disjoint_subsets_sized.cpp143
-rw-r--r--ml/dlib/dlib/test/dnn.cpp3261
-rw-r--r--ml/dlib/dlib/test/ekm_and_lisf.cpp306
-rw-r--r--ml/dlib/dlib/test/elastic_net.cpp122
-rw-r--r--ml/dlib/dlib/test/empirical_kernel_map.cpp444
-rw-r--r--ml/dlib/dlib/test/entropy_coder.cpp587
-rw-r--r--ml/dlib/dlib/test/entropy_encoder_model.cpp198
-rw-r--r--ml/dlib/dlib/test/example.cpp72
-rw-r--r--ml/dlib/dlib/test/example_args.cpp75
-rw-r--r--ml/dlib/dlib/test/examples/CMakeLists.txt8
-rw-r--r--ml/dlib/dlib/test/face.cpp360
-rw-r--r--ml/dlib/dlib/test/fft.cpp553
-rw-r--r--ml/dlib/dlib/test/fhog.cpp684
-rw-r--r--ml/dlib/dlib/test/filtering.cpp166
-rw-r--r--ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp787
-rw-r--r--ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp217
-rw-r--r--ml/dlib/dlib/test/find_optimal_parameters.cpp58
-rw-r--r--ml/dlib/dlib/test/geometry.cpp883
-rw-r--r--ml/dlib/dlib/test/global_optimization.cpp302
-rw-r--r--ml/dlib/dlib/test/graph.cpp414
-rw-r--r--ml/dlib/dlib/test/graph_cuts.cpp1217
-rw-r--r--ml/dlib/dlib/test/graph_labeler.cpp472
-rw-r--r--ml/dlib/dlib/test/gui/CMakeLists.txt20
-rw-r--r--ml/dlib/dlib/test/gui/main.cpp840
-rw-r--r--ml/dlib/dlib/test/hash.cpp369
-rw-r--r--ml/dlib/dlib/test/hash_map.cpp450
-rw-r--r--ml/dlib/dlib/test/hash_set.cpp387
-rw-r--r--ml/dlib/dlib/test/hash_table.cpp663
-rw-r--r--ml/dlib/dlib/test/hog_image.cpp126
-rw-r--r--ml/dlib/dlib/test/image.cpp1903
-rw-r--r--ml/dlib/dlib/test/iosockstream.cpp181
-rw-r--r--ml/dlib/dlib/test/is_same_object.cpp141
-rw-r--r--ml/dlib/dlib/test/isotonic_regression.cpp103
-rw-r--r--ml/dlib/dlib/test/kcentroid.cpp684
-rw-r--r--ml/dlib/dlib/test/kernel_matrix.cpp161
-rw-r--r--ml/dlib/dlib/test/kmeans.cpp163
-rw-r--r--ml/dlib/dlib/test/learning_to_track.cpp306
-rw-r--r--ml/dlib/dlib/test/least_squares.cpp452
-rw-r--r--ml/dlib/dlib/test/linear_manifold_regularizer.cpp408
-rw-r--r--ml/dlib/dlib/test/lspi.cpp258
-rw-r--r--ml/dlib/dlib/test/lz77_buffer.cpp569
-rw-r--r--ml/dlib/dlib/test/main.cpp217
-rw-r--r--ml/dlib/dlib/test/makefile185
-rw-r--r--ml/dlib/dlib/test/map.cpp441
-rw-r--r--ml/dlib/dlib/test/matrix.cpp1519
-rw-r--r--ml/dlib/dlib/test/matrix2.cpp1158
-rw-r--r--ml/dlib/dlib/test/matrix3.cpp1134
-rw-r--r--ml/dlib/dlib/test/matrix4.cpp1119
-rw-r--r--ml/dlib/dlib/test/matrix_chol.cpp182
-rw-r--r--ml/dlib/dlib/test/matrix_eig.cpp245
-rw-r--r--ml/dlib/dlib/test/matrix_lu.cpp223
-rw-r--r--ml/dlib/dlib/test/matrix_qr.cpp208
-rw-r--r--ml/dlib/dlib/test/max_cost_assignment.cpp157
-rw-r--r--ml/dlib/dlib/test/max_sum_submatrix.cpp177
-rw-r--r--ml/dlib/dlib/test/md5.cpp71
-rw-r--r--ml/dlib/dlib/test/member_function_pointer.cpp553
-rw-r--r--ml/dlib/dlib/test/metaprogramming.cpp94
-rw-r--r--ml/dlib/dlib/test/mpc.cpp346
-rw-r--r--ml/dlib/dlib/test/multithreaded_object.cpp321
-rw-r--r--ml/dlib/dlib/test/numerical_integration.cpp228
-rw-r--r--ml/dlib/dlib/test/object_detector.cpp1028
-rw-r--r--ml/dlib/dlib/test/oca.cpp244
-rw-r--r--ml/dlib/dlib/test/one_vs_all_trainer.cpp305
-rw-r--r--ml/dlib/dlib/test/one_vs_one_trainer.cpp218
-rw-r--r--ml/dlib/dlib/test/opt_qp_solver.cpp813
-rw-r--r--ml/dlib/dlib/test/optimization.cpp1231
-rw-r--r--ml/dlib/dlib/test/optimization_test_functions.cpp425
-rw-r--r--ml/dlib/dlib/test/optimization_test_functions.h310
-rw-r--r--ml/dlib/dlib/test/parallel_for.cpp334
-rw-r--r--ml/dlib/dlib/test/parse.cpp233
-rw-r--r--ml/dlib/dlib/test/pipe.cpp688
-rw-r--r--ml/dlib/dlib/test/pixel.cpp777
-rw-r--r--ml/dlib/dlib/test/probabilistic.cpp123
-rw-r--r--ml/dlib/dlib/test/pyramid_down.cpp424
-rw-r--r--ml/dlib/dlib/test/queue.cpp426
-rw-r--r--ml/dlib/dlib/test/rand.cpp436
-rw-r--r--ml/dlib/dlib/test/random_forest.cpp405
-rw-r--r--ml/dlib/dlib/test/ranking.cpp485
-rw-r--r--ml/dlib/dlib/test/read_write_mutex.cpp208
-rw-r--r--ml/dlib/dlib/test/reference_counter.cpp122
-rw-r--r--ml/dlib/dlib/test/rls.cpp196
-rw-r--r--ml/dlib/dlib/test/sammon.cpp211
-rw-r--r--ml/dlib/dlib/test/scan_image.cpp713
-rw-r--r--ml/dlib/dlib/test/sequence.cpp312
-rw-r--r--ml/dlib/dlib/test/sequence_labeler.cpp461
-rw-r--r--ml/dlib/dlib/test/sequence_segmenter.cpp294
-rw-r--r--ml/dlib/dlib/test/serialize.cpp1087
-rw-r--r--ml/dlib/dlib/test/set.cpp464
-rw-r--r--ml/dlib/dlib/test/sldf.cpp296
-rw-r--r--ml/dlib/dlib/test/sliding_buffer.cpp439
-rw-r--r--ml/dlib/dlib/test/smart_pointers.cpp449
-rw-r--r--ml/dlib/dlib/test/sockets.cpp247
-rw-r--r--ml/dlib/dlib/test/sockets2.cpp204
-rw-r--r--ml/dlib/dlib/test/sockstreambuf.cpp253
-rw-r--r--ml/dlib/dlib/test/sparse_vector.cpp301
-rw-r--r--ml/dlib/dlib/test/stack.cpp294
-rw-r--r--ml/dlib/dlib/test/static_map.cpp323
-rw-r--r--ml/dlib/dlib/test/static_set.cpp206
-rw-r--r--ml/dlib/dlib/test/statistics.cpp915
-rw-r--r--ml/dlib/dlib/test/std_vector_c.cpp101
-rw-r--r--ml/dlib/dlib/test/string.cpp329
-rw-r--r--ml/dlib/dlib/test/svm.cpp661
-rw-r--r--ml/dlib/dlib/test/svm_c_linear.cpp392
-rw-r--r--ml/dlib/dlib/test/svm_c_linear_dcd.cpp545
-rw-r--r--ml/dlib/dlib/test/svm_multiclass_linear.cpp226
-rw-r--r--ml/dlib/dlib/test/svm_struct.cpp641
-rw-r--r--ml/dlib/dlib/test/svr_linear_trainer.cpp161
-rw-r--r--ml/dlib/dlib/test/symmetric_matrix_cache.cpp212
-rw-r--r--ml/dlib/dlib/test/tester.cpp175
-rw-r--r--ml/dlib/dlib/test/tester.h187
-rw-r--r--ml/dlib/dlib/test/thread_pool.cpp428
-rw-r--r--ml/dlib/dlib/test/threads.cpp158
-rw-r--r--ml/dlib/dlib/test/timer.cpp347
-rw-r--r--ml/dlib/dlib/test/tokenizer.cpp378
-rw-r--r--ml/dlib/dlib/test/tools/CMakeLists.txt5
-rw-r--r--ml/dlib/dlib/test/trust_region.cpp329
-rw-r--r--ml/dlib/dlib/test/tuple.cpp186
-rw-r--r--ml/dlib/dlib/test/type_safe_union.cpp455
-rw-r--r--ml/dlib/dlib/test/vectorstream.cpp142
-rw-r--r--ml/dlib/dlib/test_for_odr_violations.cpp47
-rw-r--r--ml/dlib/dlib/test_for_odr_violations.h57
-rw-r--r--ml/dlib/dlib/threads.h28
-rw-r--r--ml/dlib/dlib/threads/async.cpp48
-rw-r--r--ml/dlib/dlib/threads/async.h105
-rw-r--r--ml/dlib/dlib/threads/async_abstract.h67
-rw-r--r--ml/dlib/dlib/threads/auto_mutex_extension.h180
-rw-r--r--ml/dlib/dlib/threads/auto_mutex_extension_abstract.h185
-rw-r--r--ml/dlib/dlib/threads/auto_unlock_extension.h116
-rw-r--r--ml/dlib/dlib/threads/auto_unlock_extension_abstract.h116
-rw-r--r--ml/dlib/dlib/threads/create_new_thread_extension.h46
-rw-r--r--ml/dlib/dlib/threads/create_new_thread_extension_abstract.h33
-rw-r--r--ml/dlib/dlib/threads/multithreaded_object_extension.cpp241
-rw-r--r--ml/dlib/dlib/threads/multithreaded_object_extension.h153
-rw-r--r--ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h186
-rw-r--r--ml/dlib/dlib/threads/parallel_for_extension.h676
-rw-r--r--ml/dlib/dlib/threads/parallel_for_extension_abstract.h469
-rw-r--r--ml/dlib/dlib/threads/posix.h6
-rw-r--r--ml/dlib/dlib/threads/read_write_mutex_extension.h177
-rw-r--r--ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h146
-rw-r--r--ml/dlib/dlib/threads/rmutex_extension.h109
-rw-r--r--ml/dlib/dlib/threads/rmutex_extension_abstract.h107
-rw-r--r--ml/dlib/dlib/threads/rsignaler_extension.h90
-rw-r--r--ml/dlib/dlib/threads/rsignaler_extension_abstract.h123
-rw-r--r--ml/dlib/dlib/threads/thread_function_extension.h215
-rw-r--r--ml/dlib/dlib/threads/thread_function_extension_abstract.h146
-rw-r--r--ml/dlib/dlib/threads/thread_pool_extension.cpp347
-rw-r--r--ml/dlib/dlib/threads/thread_pool_extension.h1392
-rw-r--r--ml/dlib/dlib/threads/thread_pool_extension_abstract.h842
-rw-r--r--ml/dlib/dlib/threads/thread_specific_data_extension.h141
-rw-r--r--ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h87
-rw-r--r--ml/dlib/dlib/threads/threaded_object_extension.cpp290
-rw-r--r--ml/dlib/dlib/threads/threaded_object_extension.h123
-rw-r--r--ml/dlib/dlib/threads/threaded_object_extension_abstract.h199
-rw-r--r--ml/dlib/dlib/threads/threads_kernel.h18
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_1.cpp83
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_1.h158
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_2.cpp75
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_2.h180
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_abstract.h302
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_shared.cpp318
-rw-r--r--ml/dlib/dlib/threads/threads_kernel_shared.h274
-rw-r--r--ml/dlib/dlib/threads/windows.h6
-rw-r--r--ml/dlib/dlib/time_this.h36
-rw-r--r--ml/dlib/dlib/timeout.h10
-rw-r--r--ml/dlib/dlib/timeout/timeout.h200
-rw-r--r--ml/dlib/dlib/timeout/timeout_abstract.h188
-rw-r--r--ml/dlib/dlib/timer.h10
-rw-r--r--ml/dlib/dlib/timer/timer.cpp235
-rw-r--r--ml/dlib/dlib/timer/timer.h427
-rw-r--r--ml/dlib/dlib/timer/timer_abstract.h190
-rw-r--r--ml/dlib/dlib/timer/timer_heavy.h392
-rw-r--r--ml/dlib/dlib/timing.h196
-rw-r--r--ml/dlib/dlib/tokenizer.h33
-rw-r--r--ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp295
-rw-r--r--ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h155
-rw-r--r--ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h289
-rw-r--r--ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h167
-rwxr-xr-xml/dlib/dlib/travis/build-and-test.sh45
-rw-r--r--ml/dlib/dlib/tuple.h10
-rw-r--r--ml/dlib/dlib/tuple/tuple.h410
-rw-r--r--ml/dlib/dlib/tuple/tuple_abstract.h302
-rw-r--r--ml/dlib/dlib/type_safe_union.h11
-rw-r--r--ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h711
-rw-r--r--ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h329
-rw-r--r--ml/dlib/dlib/uintn.h96
-rw-r--r--ml/dlib/dlib/unicode.h9
-rw-r--r--ml/dlib/dlib/unicode/unicode.cpp175
-rw-r--r--ml/dlib/dlib/unicode/unicode.h622
-rw-r--r--ml/dlib/dlib/unicode/unicode_abstract.h233
-rw-r--r--ml/dlib/dlib/unordered_pair.h176
-rw-r--r--ml/dlib/dlib/vectorstream.h11
-rw-r--r--ml/dlib/dlib/vectorstream/unserialize.h98
-rw-r--r--ml/dlib/dlib/vectorstream/unserialize_abstract.h58
-rw-r--r--ml/dlib/dlib/vectorstream/vectorstream.h138
-rw-r--r--ml/dlib/dlib/vectorstream/vectorstream_abstract.h62
-rw-r--r--ml/dlib/dlib/windows_magic.h50
-rw-r--r--ml/dlib/dlib/xml_parser.h13
-rw-r--r--ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h1532
-rw-r--r--ml/dlib/dlib/xml_parser/xml_parser_kernel_abstract.h276
-rw-r--r--ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h244
1530 files changed, 538563 insertions, 0 deletions
diff --git a/ml/dlib/dlib/CMakeLists.txt b/ml/dlib/dlib/CMakeLists.txt
new file mode 100644
index 000000000..15208064c
--- /dev/null
+++ b/ml/dlib/dlib/CMakeLists.txt
@@ -0,0 +1,841 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+
+
+cmake_minimum_required(VERSION 2.8.12)
+project(dlib)
+
+
+include(cmake_utils/set_compiler_specific_options.cmake)
+
+
+# Adhere to GNU filesystem layout conventions
+include(GNUInstallDirs)
+
+# default to a Release build (except if CMAKE_BUILD_TYPE is set)
+include(cmake_utils/release_build_by_default)
+include(cmake_utils/use_cpp_11.cmake)
+
+
+set(CPACK_PACKAGE_VERSION_MAJOR "19")
+set(CPACK_PACKAGE_VERSION_MINOR "10")
+set(CPACK_PACKAGE_VERSION_PATCH "0")
+set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})
+# Set DLIB_VERSION in the including CMake file so they can use it to do whatever they want.
+get_directory_property(has_parent PARENT_DIRECTORY)
+if(has_parent)
+ set(DLIB_VERSION ${VERSION} PARENT_SCOPE)
+ if (NOT DEFINED DLIB_IN_PROJECT_BUILD)
+ set(DLIB_IN_PROJECT_BUILD true)
+ endif()
+endif()
+
+
+if (DLIB_IN_PROJECT_BUILD)
+ # DLIB_IN_PROJECT_BUILD==true means you are using dlib by invoking
+ # add_subdirectory(dlib) in the parent project. In this case, we always want
+ # to build dlib as a static library so the parent project doesn't need to
+ # deal with some random dlib shared library file. It is much better to
+ # statically compile dlib into the parent project. So the following bit of
+ # CMake ensures that happens. However, we have to take care to compile dlib
+ # with position independent code if appropriate (i.e. if the parent project
+ # is a shared library).
+ if (BUILD_SHARED_LIBS)
+ if (CMAKE_COMPILER_IS_GNUCXX)
+ # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set
+ # -fPIC for GCC but sometimes it still doesn't get set, so make sure it
+ # does.
+ add_definitions("-fPIC")
+ endif()
+ set(CMAKE_POSITION_INDEPENDENT_CODE true)
+ endif()
+
+ # Tell cmake to build dlib as a static library
+ set(BUILD_SHARED_LIBS false)
+endif()
+
+
+if (CMAKE_VERSION VERSION_LESS "3.9.0")
+ # Set only because there are old target_link_libraries() statements in the
+ # FindCUDA.cmake file that comes with CMake that error out if the new behavior
+ # is used. In newer versions of CMake we can instead set CUDA_LINK_LIBRARIES_KEYWORD which fixes this issue.
+ cmake_policy(SET CMP0023 OLD)
+else()
+ set(CUDA_LINK_LIBRARIES_KEYWORD PUBLIC)
+endif()
+
+
+macro (enable_preprocessor_switch option_name)
+ list(APPEND active_preprocessor_switches "-D${option_name}")
+endmacro()
+
+macro (disable_preprocessor_switch option_name)
+ if (active_preprocessor_switches)
+ list(REMOVE_ITEM active_preprocessor_switches "-D${option_name}")
+ endif()
+endmacro()
+
+macro (toggle_preprocessor_switch option_name)
+ if (${option_name})
+ enable_preprocessor_switch(${option_name})
+ else()
+ disable_preprocessor_switch(${option_name})
+ endif()
+endmacro()
+
+
+
+# Suppress superfluous randlib warnings about libdlib.a having no symbols on MacOSX.
+if (APPLE)
+ set(CMAKE_C_ARCHIVE_CREATE "<CMAKE_AR> Scr <TARGET> <LINK_FLAGS> <OBJECTS>")
+ set(CMAKE_CXX_ARCHIVE_CREATE "<CMAKE_AR> Scr <TARGET> <LINK_FLAGS> <OBJECTS>")
+ set(CMAKE_C_ARCHIVE_FINISH "<CMAKE_RANLIB> -no_warning_for_no_symbols -c <TARGET>")
+ set(CMAKE_CXX_ARCHIVE_FINISH "<CMAKE_RANLIB> -no_warning_for_no_symbols -c <TARGET>")
+endif()
+
+# Don't try to call add_library(dlib) and setup dlib's stuff if it has already
+# been done by some other part of the current cmake project. We do this
+# because it avoids getting warnings/errors about cmake policy CMP0002. This
+# happens when a project tries to call add_subdirectory() on dlib more than
+# once. This most often happens when the top level of a project depends on two
+# or more other things which both depend on dlib.
+if (NOT TARGET dlib)
+
+ set (DLIB_ISO_CPP_ONLY_STR
+ "Enable this if you don't want to compile any non-ISO C++ code (i.e. you don't use any of the API Wrappers)" )
+ set (DLIB_NO_GUI_SUPPORT_STR
+ "Enable this if you don't want to compile any of the dlib GUI code" )
+ set (DLIB_ENABLE_STACK_TRACE_STR
+ "Enable this if you want to turn on the DLIB_STACK_TRACE macros" )
+ set (DLIB_USE_BLAS_STR
+ "Disable this if you don't want to use a BLAS library" )
+ set (DLIB_USE_LAPACK_STR
+ "Disable this if you don't want to use a LAPACK library" )
+ set (DLIB_USE_CUDA_STR
+ "Disable this if you don't want to use NVIDIA CUDA" )
+ set (DLIB_PNG_SUPPORT_STR
+ "Disable this if you don't want to link against libpng" )
+ set (DLIB_GIF_SUPPORT_STR
+ "Disable this if you don't want to link against libgif" )
+ set (DLIB_JPEG_SUPPORT_STR
+ "Disable this if you don't want to link against libjpeg" )
+ set (DLIB_LINK_WITH_SQLITE3_STR
+ "Disable this if you don't want to link against sqlite3" )
+ #set (DLIB_USE_FFTW_STR "Disable this if you don't want to link against fftw" )
+ set (DLIB_USE_MKL_FFT_STR
+ "Disable this is you don't want to use the MKL DFTI FFT implementation" )
+ set (DLIB_ENABLE_ASSERTS_STR
+ "Enable this if you want to turn on the DLIB_ASSERT macro" )
+
+
+ option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF)
+ option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF)
+ toggle_preprocessor_switch(DLIB_ISO_CPP_ONLY)
+ option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF)
+ toggle_preprocessor_switch(DLIB_NO_GUI_SUPPORT)
+ option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF)
+ toggle_preprocessor_switch(DLIB_ENABLE_STACK_TRACE)
+
+ if(DLIB_ENABLE_ASSERTS)
+ # Set these variables so they are set in the config.h.in file when dlib
+ # is installed.
+ set (DLIB_DISABLE_ASSERTS false)
+ set (ENABLE_ASSERTS true)
+ enable_preprocessor_switch(ENABLE_ASSERTS)
+ disable_preprocessor_switch(DLIB_DISABLE_ASSERTS)
+ else()
+ # Set these variables so they are set in the config.h.in file when dlib
+ # is installed.
+ set (DLIB_DISABLE_ASSERTS true)
+ set (ENABLE_ASSERTS false)
+ disable_preprocessor_switch(ENABLE_ASSERTS)
+ # Never force the asserts off when doing an in project build. The only
+ # time this matters is when using visual studio. The visual studio IDE
+ # has a drop down that lets the user select either release or debug
+ # builds. The DLIB_ASSERT macro is setup to enable/disable automatically
+ # based on this drop down (via preprocessor magic). However, if
+ # DLIB_DISABLE_ASSERTS is defined it permanently disables asserts no
+ # matter what, which would defeat the visual studio drop down. So here
+ # we make a point to not do that kind of severe disabling when in a
+ # project build. It should also be pointed out that DLIB_DISABLE_ASSERTS
+ # is only needed when building and installing dlib as a separately
+ # installed library. It doesn't matter when doing an in project build.
+ if (NOT DLIB_IN_PROJECT_BUILD)
+ enable_preprocessor_switch(DLIB_DISABLE_ASSERTS)
+ endif()
+ endif()
+
+ if (DLIB_ISO_CPP_ONLY)
+ option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} OFF)
+ option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} OFF)
+ option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} OFF)
+ option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} OFF)
+ option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} OFF)
+ option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} OFF)
+ option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} OFF)
+ #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} OFF)
+ option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} OFF)
+ else()
+ option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} ON)
+ option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON)
+ option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON)
+ option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON)
+ option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} ON)
+ option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} ON)
+ option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} ON)
+ #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} ON)
+ option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} ON)
+ endif()
+ toggle_preprocessor_switch(DLIB_JPEG_SUPPORT)
+ toggle_preprocessor_switch(DLIB_USE_BLAS)
+ toggle_preprocessor_switch(DLIB_USE_LAPACK)
+ toggle_preprocessor_switch(DLIB_USE_CUDA)
+ toggle_preprocessor_switch(DLIB_PNG_SUPPORT)
+ toggle_preprocessor_switch(DLIB_GIF_SUPPORT)
+ #toggle_preprocessor_switch(DLIB_USE_FFTW)
+ toggle_preprocessor_switch(DLIB_USE_MKL_FFT)
+
+
+ set(source_files
+ base64/base64_kernel_1.cpp
+ bigint/bigint_kernel_1.cpp
+ bigint/bigint_kernel_2.cpp
+ bit_stream/bit_stream_kernel_1.cpp
+ entropy_decoder/entropy_decoder_kernel_1.cpp
+ entropy_decoder/entropy_decoder_kernel_2.cpp
+ entropy_encoder/entropy_encoder_kernel_1.cpp
+ entropy_encoder/entropy_encoder_kernel_2.cpp
+ md5/md5_kernel_1.cpp
+ tokenizer/tokenizer_kernel_1.cpp
+ unicode/unicode.cpp
+ data_io/image_dataset_metadata.cpp
+ data_io/mnist.cpp
+ global_optimization/global_function_search.cpp
+ filtering/kalman_filter.cpp
+ test_for_odr_violations.cpp
+ )
+
+
+ set(dlib_needed_libraries)
+ set(dlib_needed_includes)
+
+ if (DLIB_ISO_CPP_ONLY)
+ add_library(dlib ${source_files} )
+ else()
+
+ set(source_files ${source_files}
+ sockets/sockets_kernel_1.cpp
+ bsp/bsp.cpp
+ dir_nav/dir_nav_kernel_1.cpp
+ dir_nav/dir_nav_kernel_2.cpp
+ dir_nav/dir_nav_extensions.cpp
+ linker/linker_kernel_1.cpp
+ logger/extra_logger_headers.cpp
+ logger/logger_kernel_1.cpp
+ logger/logger_config_file.cpp
+ misc_api/misc_api_kernel_1.cpp
+ misc_api/misc_api_kernel_2.cpp
+ sockets/sockets_extensions.cpp
+ sockets/sockets_kernel_2.cpp
+ sockstreambuf/sockstreambuf.cpp
+ sockstreambuf/sockstreambuf_unbuffered.cpp
+ server/server_kernel.cpp
+ server/server_iostream.cpp
+ server/server_http.cpp
+ threads/multithreaded_object_extension.cpp
+ threads/threaded_object_extension.cpp
+ threads/threads_kernel_1.cpp
+ threads/threads_kernel_2.cpp
+ threads/threads_kernel_shared.cpp
+ threads/thread_pool_extension.cpp
+ threads/async.cpp
+ timer/timer.cpp
+ stack_trace.cpp
+ dnn/cpu_dlib.cpp
+ dnn/tensor_tools.cpp
+ )
+
+ if(UNIX)
+ set(CMAKE_THREAD_PREFER_PTHREAD ON)
+ find_package(Threads REQUIRED)
+ set(dlib_needed_libraries ${dlib_needed_libraries} ${CMAKE_THREAD_LIBS_INIT})
+ endif()
+
+ # we want to link to the right stuff depending on our platform.
+ if (WIN32 AND NOT CYGWIN) ###############################################################################
+ if (DLIB_NO_GUI_SUPPORT)
+ set (dlib_needed_libraries ws2_32 winmm)
+ else()
+ set (dlib_needed_libraries ws2_32 winmm comctl32 gdi32 imm32)
+ endif()
+ elseif(APPLE) ############################################################################
+ set(CMAKE_MACOSX_RPATH 1)
+ if (NOT DLIB_NO_GUI_SUPPORT)
+ find_package(X11 QUIET)
+ if (X11_FOUND)
+ # If both X11 and anaconda are installed, it's possible for the
+ # anaconda path to appear before /opt/X11, so we remove anaconda.
+ foreach (ITR ${X11_INCLUDE_DIR})
+ if ("${ITR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)")
+ list (REMOVE_ITEM X11_INCLUDE_DIR ${ITR})
+ endif ()
+ endforeach(ITR)
+ include_directories(${X11_INCLUDE_DIR})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${X11_LIBRARIES})
+ else()
+ find_library(xlib X11)
+ # Make sure X11 is in the include path. Note that we look for
+ # Xlocale.h rather than Xlib.h because it avoids finding a partial
+ # copy of the X11 headers on systems with anaconda installed.
+ find_path(xlib_path Xlocale.h
+ PATHS
+ /Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include
+ /opt/local/include
+ PATH_SUFFIXES X11
+ )
+ if (xlib AND xlib_path)
+ get_filename_component(x11_path ${xlib_path} PATH CACHE)
+ include_directories(${x11_path})
+ set(dlib_needed_libraries ${dlib_needed_libraries} ${xlib} )
+ set(X11_FOUND 1)
+ endif()
+ endif()
+ if (NOT X11_FOUND)
+ message(" *****************************************************************************")
+ message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***")
+ message(" *** Make sure XQuartz is installed if you want GUI support. ***")
+ message(" *** You can download XQuartz from: https://www.xquartz.org/ ***")
+ message(" *****************************************************************************")
+ set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE )
+ enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT)
+ endif()
+ endif()
+
+ mark_as_advanced(pthreadlib xlib xlib_path x11_path)
+ else () ##################################################################################
+ # link to the nsl library if it exists. this is something you need sometimes
+ find_library(nsllib nsl)
+ if (nsllib)
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${nsllib})
+ endif ()
+
+ # link to the socket library if it exists. this is something you need on solaris
+ find_library(socketlib socket)
+ if (socketlib)
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${socketlib})
+ endif ()
+
+ if (NOT DLIB_NO_GUI_SUPPORT)
+ include(FindX11)
+ if (X11_FOUND)
+ include_directories(${X11_INCLUDE_DIR})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${X11_LIBRARIES})
+ else()
+ message(" *****************************************************************************")
+ message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***")
+ message(" *** Make sure libx11-dev is installed if you want GUI support. ***")
+ message(" *** On Ubuntu run: sudo apt-get install libx11-dev ***")
+ message(" *****************************************************************************")
+ set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE )
+ enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT)
+ endif()
+ endif()
+
+ mark_as_advanced(nsllib pthreadlib socketlib)
+ endif () ##################################################################################
+
+ if (NOT DLIB_NO_GUI_SUPPORT)
+ set(source_files ${source_files}
+ gui_widgets/fonts.cpp
+ gui_widgets/widgets.cpp
+ gui_widgets/drawable.cpp
+ gui_widgets/canvas_drawing.cpp
+ gui_widgets/style.cpp
+ gui_widgets/base_widgets.cpp
+ gui_core/gui_core_kernel_1.cpp
+ gui_core/gui_core_kernel_2.cpp
+ )
+ endif()
+
+ INCLUDE (CheckFunctionExists)
+
+ if (DLIB_GIF_SUPPORT)
+ find_package(GIF QUIET)
+ if (GIF_FOUND)
+ set (dlib_needed_includes ${dlib_needed_includes} ${GIF_INCLUDE_DIR})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${GIF_LIBRARY})
+ else()
+ set(DLIB_GIF_SUPPORT OFF CACHE STRING ${DLIB_GIF_SUPPORT_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_GIF_SUPPORT)
+ endif()
+ endif()
+
+ if (DLIB_PNG_SUPPORT)
+ # try to find libpng
+ find_package(PNG QUIET)
+ # Make sure there isn't something wrong with the version of LIBPNG
+ # installed on this system.
+ if (PNG_FOUND)
+ set(CMAKE_REQUIRED_LIBRARIES ${PNG_LIBRARIES})
+ CHECK_FUNCTION_EXISTS(png_create_read_struct LIBPNG_IS_GOOD)
+ endif()
+ if (PNG_FOUND AND LIBPNG_IS_GOOD)
+ include_directories(${PNG_INCLUDE_DIR})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${PNG_LIBRARIES})
+ set(REQUIRES_LIBS " libpng")
+ else()
+ # If we can't find libpng then statically compile it in.
+ include_directories(external/libpng external/zlib)
+ set(source_files ${source_files}
+ external/libpng/png.c
+ external/libpng/pngerror.c
+ external/libpng/pngget.c
+ external/libpng/pngmem.c
+ external/libpng/pngpread.c
+ external/libpng/pngread.c
+ external/libpng/pngrio.c
+ external/libpng/pngrtran.c
+ external/libpng/pngrutil.c
+ external/libpng/pngset.c
+ external/libpng/pngtrans.c
+ external/libpng/pngwio.c
+ external/libpng/pngwrite.c
+ external/libpng/pngwtran.c
+ external/libpng/pngwutil.c
+ external/zlib/adler32.c
+ external/zlib/compress.c
+ external/zlib/crc32.c
+ external/zlib/deflate.c
+ external/zlib/gzclose.c
+ external/zlib/gzlib.c
+ external/zlib/gzread.c
+ external/zlib/gzwrite.c
+ external/zlib/infback.c
+ external/zlib/inffast.c
+ external/zlib/inflate.c
+ external/zlib/inftrees.c
+ external/zlib/trees.c
+ external/zlib/uncompr.c
+ external/zlib/zutil.c
+ )
+
+ include(cmake_utils/check_if_neon_available.cmake)
+ if (ARM_NEON_IS_AVAILABLE)
+ message (STATUS "NEON instructions will be used for libpng.")
+ enable_language(ASM)
+ set(source_files ${source_files}
+ external/libpng/arm/arm_init.c
+ external/libpng/arm/filter_neon_intrinsics.c
+ external/libpng/arm/filter_neon.S
+ )
+ set_source_files_properties(external/libpng/arm/filter_neon.S PROPERTIES COMPILE_FLAGS "${CMAKE_ASM_FLAGS} ${CMAKE_CXX_FLAGS} -x assembler-with-cpp")
+ endif()
+ set(REQUIRES_LIBS "")
+ endif()
+ set(source_files ${source_files}
+ image_loader/png_loader.cpp
+ image_saver/save_png.cpp
+ )
+ endif()
+
+ if (DLIB_JPEG_SUPPORT)
+ # try to find libjpeg
+ find_package(JPEG QUIET)
+ # Make sure there isn't something wrong with the version of libjpeg
+ # installed on this system. Also don't use the installed libjpeg
+ # if this is an APPLE system because apparently it's broken (as of 2015/01/01).
+ if (JPEG_FOUND AND NOT ("${JPEG_INCLUDE_DIR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)"))
+ set(CMAKE_REQUIRED_LIBRARIES ${JPEG_LIBRARY})
+ CHECK_FUNCTION_EXISTS(jpeg_read_header LIBJPEG_IS_GOOD)
+ endif()
+ if (JPEG_FOUND AND LIBJPEG_IS_GOOD AND NOT APPLE)
+ include_directories(${JPEG_INCLUDE_DIR})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${JPEG_LIBRARY})
+ else()
+ # If we can't find libjpeg then statically compile it in.
+ add_definitions(-DDLIB_JPEG_STATIC)
+ set(source_files ${source_files}
+ external/libjpeg/jcomapi.cpp
+ external/libjpeg/jdapimin.cpp
+ external/libjpeg/jdapistd.cpp
+ external/libjpeg/jdatasrc.cpp
+ external/libjpeg/jdcoefct.cpp
+ external/libjpeg/jdcolor.cpp
+ external/libjpeg/jddctmgr.cpp
+ external/libjpeg/jdhuff.cpp
+ external/libjpeg/jdinput.cpp
+ external/libjpeg/jdmainct.cpp
+ external/libjpeg/jdmarker.cpp
+ external/libjpeg/jdmaster.cpp
+ external/libjpeg/jdmerge.cpp
+ external/libjpeg/jdphuff.cpp
+ external/libjpeg/jdpostct.cpp
+ external/libjpeg/jdsample.cpp
+ external/libjpeg/jerror.cpp
+ external/libjpeg/jidctflt.cpp
+ external/libjpeg/jidctfst.cpp
+ external/libjpeg/jidctint.cpp
+ external/libjpeg/jidctred.cpp
+ external/libjpeg/jmemmgr.cpp
+ external/libjpeg/jmemnobs.cpp
+ external/libjpeg/jquant1.cpp
+ external/libjpeg/jquant2.cpp
+ external/libjpeg/jutils.cpp
+ external/libjpeg/jcapimin.cpp
+ external/libjpeg/jdatadst.cpp
+ external/libjpeg/jcparam.cpp
+ external/libjpeg/jcapistd.cpp
+ external/libjpeg/jcmarker.cpp
+ external/libjpeg/jcinit.cpp
+ external/libjpeg/jcmaster.cpp
+ external/libjpeg/jcdctmgr.cpp
+ external/libjpeg/jccoefct.cpp
+ external/libjpeg/jccolor.cpp
+ external/libjpeg/jchuff.cpp
+ external/libjpeg/jcmainct.cpp
+ external/libjpeg/jcphuff.cpp
+ external/libjpeg/jcprepct.cpp
+ external/libjpeg/jcsample.cpp
+ external/libjpeg/jfdctint.cpp
+ external/libjpeg/jfdctflt.cpp
+ external/libjpeg/jfdctfst.cpp
+ )
+ endif()
+ set(source_files ${source_files}
+ image_loader/jpeg_loader.cpp
+ image_saver/save_jpeg.cpp
+ )
+ endif()
+
+
+ if (DLIB_USE_BLAS OR DLIB_USE_LAPACK OR DLIB_USE_MKL_FFT)
+ # Try to find BLAS, LAPACK and MKL
+ include(cmake_utils/find_blas.cmake)
+
+ if (DLIB_USE_BLAS)
+ if (blas_found)
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${blas_libraries})
+ else()
+ set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_USE_BLAS)
+ endif()
+ endif()
+
+ if (DLIB_USE_LAPACK)
+ if (lapack_found)
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${lapack_libraries})
+ if (lapack_with_underscore)
+ set(LAPACK_FORCE_UNDERSCORE 1)
+ enable_preprocessor_switch(LAPACK_FORCE_UNDERSCORE)
+ elseif (lapack_without_underscore)
+ set(LAPACK_FORCE_NOUNDERSCORE 1)
+ enable_preprocessor_switch(LAPACK_FORCE_NOUNDERSCORE)
+ endif ()
+ else()
+ set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_USE_LAPACK)
+ endif()
+ endif()
+
+ if (DLIB_USE_MKL_FFT)
+ if (found_intel_mkl AND found_intel_mkl_headers)
+ set (dlib_needed_includes ${dlib_needed_includes} ${mkl_include_dir})
+ set (dlib_needed_libraries ${dlib_needed_libraries} ${mkl_libraries})
+ else()
+ set(DLIB_USE_MKL_FFT OFF CACHE STRING ${DLIB_USE_MKL_FFT_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_USE_MKL_FFT)
+ endif()
+ endif()
+ endif()
+
+
+ if (DLIB_USE_CUDA)
+ find_package(CUDA 7.5)
+
+ if (CUDA_FOUND AND MSVC AND NOT CUDA_CUBLAS_LIBRARIES AND "${CMAKE_SIZEOF_VOID_P}" EQUAL "4")
+ message(WARNING "You have CUDA installed, but we can't use it unless you put visual studio in 64bit mode.")
+ set(CUDA_FOUND 0)
+ endif()
+
+ if (CUDA_FOUND AND (NOT USING_OLD_VISUAL_STUDIO_COMPILER))
+
+ # There is some bug in cmake that causes it to mess up the
+ # -std=c++11 option if you let it propagate it to nvcc in some
+ # cases. So instead we disable this and manually include
+ # things from CMAKE_CXX_FLAGS in the CUDA_NVCC_FLAGS list below.
+ if (APPLE)
+ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
+ # Grab all the -D flags from CMAKE_CXX_FLAGS so we can pass them
+ # to nvcc.
+ string(REGEX MATCHALL "-D[^ ]*" FLAGS_FOR_NVCC "${CMAKE_CXX_FLAGS}")
+ endif()
+
+
+ set(CUDA_HOST_COMPILATION_CPP ON)
+ # Note that we add __STRICT_ANSI__ to avoid freaking out nvcc with gcc specific
+ # magic in the standard C++ header files (since nvcc uses gcc headers on
+ # linux).
+ list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES;${FLAGS_FOR_NVCC}")
+ list(APPEND CUDA_NVCC_FLAGS ${active_preprocessor_switches})
+ if (NOT MSVC)
+ list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
+ endif()
+ if (CMAKE_POSITION_INDEPENDENT_CODE)
+ # sometimes this setting isn't propagated to NVCC, which then causes the
+ # compile to fail. So make sure it's propagated.
+ if (NOT MSVC) # Visual studio doesn't have -fPIC so don't do it in that case.
+ list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
+ endif()
+ endif()
+
+ include(cmake_utils/test_for_cudnn/find_cudnn.txt)
+
+ if (cudnn AND cudnn_include AND NOT DEFINED cuda_test_compile_worked AND NOT DEFINED cudnn_test_compile_worked)
+ # make sure cuda is really working by doing a test compile
+ message(STATUS "Building a CUDA test project to see if your compiler is compatible with CUDA...")
+
+ set(CUDA_TEST_CMAKE_FLAGS
+ "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}"
+ "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}"
+ "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}")
+
+ if (NOT MSVC) # see https://github.com/davisking/dlib/issues/363
+ list(APPEND CUDA_TEST_CMAKE_FLAGS "-DCUDA_HOST_COMPILER=${CUDA_HOST_COMPILER}")
+ endif()
+
+ try_compile(cuda_test_compile_worked
+ ${PROJECT_BINARY_DIR}/cuda_test_build
+ ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cuda cuda_test
+ CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS}
+ OUTPUT_VARIABLE try_compile_output_message
+ )
+ if (NOT cuda_test_compile_worked)
+ string(REPLACE "\n" "\n *** " try_compile_output_message "${try_compile_output_message}")
+ message(STATUS "*****************************************************************************************************************")
+ message(STATUS "*** CUDA was found but your compiler failed to compile a simple CUDA program so dlib isn't going to use CUDA. ")
+ message(STATUS "*** The output of the failed CUDA test compile is shown below: ")
+ message(STATUS "*** ${try_compile_output_message}")
+ message(STATUS "*****************************************************************************************************************")
+ else()
+ message(STATUS "Checking if you have the right version of cuDNN installed.")
+ try_compile(cudnn_test_compile_worked
+ ${PROJECT_BINARY_DIR}/cudnn_test_build
+ ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cudnn cudnn_test
+ CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS}
+ )
+ if (NOT cudnn_test_compile_worked)
+ message(STATUS "*** Found cuDNN, but it looks like the wrong version so dlib will not use it. ***")
+ message(STATUS "*** Dlib requires cuDNN V5.0 OR GREATER. Since cuDNN is not found DLIB WILL NOT USE CUDA. ***")
+ message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder. ***")
+ endif()
+ endif()
+ endif()
+
+ # Find where cuSOLVER is since the FindCUDA cmake package doesn't
+ # bother to look for it.
+ get_filename_component(cuda_blas_path "${CUDA_CUBLAS_LIBRARIES}" DIRECTORY)
+ find_library(cusolver cusolver HINTS ${cuda_blas_path})
+ mark_as_advanced(cusolver)
+ # Also find OpenMP since cuSOLVER needs it. Importantly, we only
+ # look for one to link to if our use of BLAS, specifically the
+ # Intel MKL, hasn't already decided what to use. This is because
+ # it makes the MKL bug out if you link to another openmp lib other
+ # than Intel's when you use the MKL.
+ if (NOT openmp_libraries AND NOT MSVC AND NOT XCODE)
+ find_package(OpenMP)
+ if (OPENMP_FOUND)
+ set(openmp_libraries ${OpenMP_CXX_FLAGS})
+ else()
+ message(STATUS "*** Didn't find OpenMP, which is required to use CUDA. ***")
+ set(CUDA_FOUND 0)
+ endif()
+ endif()
+ endif()
+
+ if (CUDA_FOUND AND cudnn AND (NOT USING_OLD_VISUAL_STUDIO_COMPILER) AND cuda_test_compile_worked AND cudnn_test_compile_worked AND cudnn_include)
+ set(source_files ${source_files}
+ dnn/cuda_dlib.cu
+ dnn/cudnn_dlibapi.cpp
+ dnn/cublas_dlibapi.cpp
+ dnn/cusolver_dlibapi.cu
+ dnn/curand_dlibapi.cpp
+ dnn/cuda_data_ptr.cpp
+ dnn/gpu_data.cpp
+ )
+ set(dlib_needed_libraries ${dlib_needed_libraries}
+ ${CUDA_CUBLAS_LIBRARIES}
+ ${cudnn}
+ ${CUDA_curand_LIBRARY}
+ ${cusolver}
+ )
+ if(openmp_libraries)
+ list(APPEND dlib_needed_libraries ${openmp_libraries})
+ endif()
+
+ include_directories(${cudnn_include})
+ message(STATUS "Enabling CUDA support for dlib. DLIB WILL USE CUDA")
+ else()
+ set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_USE_CUDA)
+ if (USING_OLD_VISUAL_STUDIO_COMPILER)
+ message(STATUS "*** Dlib CUDA support requires C++11 but your compiler doesn't support it. ***")
+ endif()
+ message(STATUS "Disabling CUDA support for dlib. DLIB WILL NOT USE CUDA")
+ endif()
+ endif()
+
+
+ if (DLIB_LINK_WITH_SQLITE3)
+ find_library(sqlite sqlite3)
+ # make sure sqlite3.h is in the include path
+ find_path(sqlite_path sqlite3.h)
+ if (sqlite AND sqlite_path)
+ set(dlib_needed_includes ${dlib_needed_includes} ${sqlite_path})
+ set(dlib_needed_libraries ${dlib_needed_libraries} ${sqlite} )
+ else()
+ set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE )
+ endif()
+ mark_as_advanced(sqlite sqlite_path)
+ endif()
+
+
+
+ if (DLIB_USE_FFTW)
+ find_library(fftw fftw3)
+ # make sure fftw3.h is in the include path
+ find_path(fftw_path fftw3.h)
+ if (fftw AND fftw_path)
+ set(dlib_needed_includes ${dlib_needed_includes} ${fftw_path})
+ set(dlib_needed_libraries ${dlib_needed_libraries} ${fftw} )
+ else()
+ set(DLIB_USE_FFTW OFF CACHE STRING ${DLIB_USE_FFTW_STR} FORCE )
+ toggle_preprocessor_switch(DLIB_USE_FFTW)
+ endif()
+ mark_as_advanced(fftw fftw_path)
+ endif()
+
+
+
+ # Tell CMake to build dlib via add_library()/cuda_add_library()
+ if (DLIB_USE_CUDA)
+ # The old cuda_add_library() command doesn't support CMake's newer dependency
+ # stuff, so we have to set the include path manually still, which we do here.
+ include_directories(${dlib_needed_includes})
+ cuda_add_library(dlib ${source_files} )
+ else()
+ add_library(dlib ${source_files} )
+ endif()
+
+ endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ##########################################################
+
+
+ target_include_directories(dlib
+ INTERFACE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/..>
+ INTERFACE $<INSTALL_INTERFACE:include>
+ PUBLIC ${dlib_needed_includes}
+ )
+ target_link_libraries(dlib PUBLIC ${dlib_needed_libraries})
+ if (DLIB_IN_PROJECT_BUILD)
+ target_compile_options(dlib PUBLIC ${active_preprocessor_switches})
+ else()
+ # These are private in this case because they will be controlled by the
+ # contents of dlib/config.h once it's installed. But for in project
+ # builds, there is no real config.h so they are public in the above case.
+ target_compile_options(dlib PRIVATE ${active_preprocessor_switches})
+ # Do this so that dlib/config.h won't set DLIB_NOT_CONFIGURED. This will then allow
+ # the code in dlib/threads_kernel_shared.cpp to emit a linker error for users who
+ # don't use the configured config.h file generated by cmake.
+ target_compile_options(dlib PRIVATE -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE)
+
+ # Do this so that dlib/config.h can record the version of dlib it's configured with
+ # and ultimately issue a linker error to people who try to use a binary dlib that is
+ # the wrong version.
+ set(DLIB_CHECK_FOR_VERSION_MISMATCH
+ DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_${CPACK_PACKAGE_VERSION_MAJOR}_${CPACK_PACKAGE_VERSION_MINOR}_${CPACK_PACKAGE_VERSION_PATCH})
+ target_compile_options(dlib PRIVATE "-DDLIB_CHECK_FOR_VERSION_MISMATCH=${DLIB_CHECK_FOR_VERSION_MISMATCH}")
+ endif()
+
+
+ # Allow the unit tests to ask us to compile the all/source.cpp file just to make sure it compiles.
+ if (DLIB_TEST_COMPILE_ALL_SOURCE_CPP)
+ add_library(dlib_all_source_cpp STATIC all/source.cpp)
+ target_link_libraries(dlib_all_source_cpp dlib)
+ target_compile_options(dlib_all_source_cpp PUBLIC ${active_preprocessor_switches})
+ enable_cpp11_for_target(dlib_all_source_cpp)
+ endif()
+
+ if (TARGET dlib)
+ enable_cpp11_for_target(dlib)
+ target_compile_options(dlib PUBLIC ${active_compile_opts})
+ endif()
+
+ # Install the library
+ if (NOT DLIB_IN_PROJECT_BUILD)
+ set_target_properties(dlib PROPERTIES
+ VERSION ${VERSION})
+ install(TARGETS dlib
+ EXPORT dlib
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # Windows considers .dll to be runtime artifacts
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
+
+ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib
+ FILES_MATCHING PATTERN "*.h" PATTERN "*.cmake"
+ REGEX "${CMAKE_CURRENT_BINARY_DIR}" EXCLUDE)
+
+
+ configure_file(${PROJECT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/config.h)
+ # overwrite config.h with the configured one
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib)
+
+ configure_file(${PROJECT_SOURCE_DIR}/revision.h.in ${CMAKE_CURRENT_BINARY_DIR}/revision.h)
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/revision.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib)
+
+ ## Config.cmake generation and installation
+
+ set(ConfigPackageLocation "${CMAKE_INSTALL_LIBDIR}/cmake/dlib")
+ install(EXPORT dlib
+ NAMESPACE dlib::
+ DESTINATION ${ConfigPackageLocation})
+
+ configure_file(cmake_utils/dlibConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" @ONLY)
+
+ include(CMakePackageConfigHelpers)
+ write_basic_package_version_file(
+ "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake"
+ VERSION ${VERSION}
+ COMPATIBILITY AnyNewerVersion
+ )
+
+ install(FILES
+ "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake"
+ "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake"
+ DESTINATION ${ConfigPackageLocation})
+
+ ## dlib-1.pc generation and installation
+
+ configure_file("cmake_utils/dlib.pc.in" "dlib-1.pc" @ONLY)
+ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/dlib-1.pc"
+ DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig")
+
+ endif()
+
+endif()
+
+if (MSVC)
+ # Give the output library files names that are unique functions of the
+ # visual studio mode that compiled them. We do this so that people who
+ # compile dlib and then copy the .lib files around (which they shouldn't be
+ # doing in the first place!) will hopefully be slightly less confused by
+ # what happens since, at the very least, the filenames will indicate what
+ # visual studio runtime they go with.
+ math(EXPR numbits ${CMAKE_SIZEOF_VOID_P}*8)
+ set_target_properties(dlib PROPERTIES DEBUG_POSTFIX "${VERSION}_debug_${numbits}bit_msvc${MSVC_VERSION}")
+ set_target_properties(dlib PROPERTIES RELEASE_POSTFIX "${VERSION}_release_${numbits}bit_msvc${MSVC_VERSION}")
+ set_target_properties(dlib PROPERTIES MINSIZEREL_POSTFIX "${VERSION}_minsizerel_${numbits}bit_msvc${MSVC_VERSION}")
+ set_target_properties(dlib PROPERTIES RELWITHDEBINFO_POSTFIX "${VERSION}_relwithdebinfo_${numbits}bit_msvc${MSVC_VERSION}")
+endif()
+
+add_library(dlib::dlib ALIAS dlib)
diff --git a/ml/dlib/dlib/LICENSE.txt b/ml/dlib/dlib/LICENSE.txt
new file mode 100644
index 000000000..127a5bc39
--- /dev/null
+++ b/ml/dlib/dlib/LICENSE.txt
@@ -0,0 +1,23 @@
+Boost Software License - Version 1.0 - August 17th, 2003
+
+Permission is hereby granted, free of charge, to any person or organization
+obtaining a copy of the software and accompanying documentation covered by
+this license (the "Software") to use, reproduce, display, distribute,
+execute, and transmit the Software, and to prepare derivative works of the
+Software, and to permit third-parties to whom the Software is furnished to
+do so, all subject to the following:
+
+The copyright notices in the Software and this entire statement, including
+the above license grant, this restriction and the following disclaimer,
+must be included in all copies of the Software, in whole or in part, and
+all derivative works of the Software, unless such copies or derivative
+works are solely in the form of machine-executable object code generated by
+a source language processor.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
+SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
+FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE. \ No newline at end of file
diff --git a/ml/dlib/dlib/algs.h b/ml/dlib/dlib/algs.h
new file mode 100644
index 000000000..d0f74b1f2
--- /dev/null
+++ b/ml/dlib/dlib/algs.h
@@ -0,0 +1,1157 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_ALGs_
+#define DLIB_ALGs_
+
+// this file contains miscellaneous stuff
+
+// Give people who forget the -std=c++11 option a reminder
+#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \
+ (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3)))
+ #if __cplusplus < 201103
+ #error "Dlib requires C++11 support. Give your compiler the -std=c++11 option to enable it."
+ #endif
+#endif
+
+#if defined __NVCC__
+ // Disable the "statement is unreachable" message since it will go off on code that is
+ // actually reachable but just happens to not be reachable sometimes during certain
+ // template instantiations.
+ #pragma diag_suppress code_is_unreachable
+#endif
+
+
+#ifdef _MSC_VER
+
+#if _MSC_VER < 1900
+#error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer."
+#endif
+
+// Disable the following warnings for Visual Studio
+
+// this is to disable the "'this' : used in base member initializer list"
+// warning you get from some of the GUI objects since all the objects
+// require that their parent class be passed into their constructor.
+// In this case though it is totally safe so it is ok to disable this warning.
+#pragma warning(disable : 4355)
+
+// This is a warning you get sometimes when Visual Studio performs a Koenig Lookup.
+// This is a bug in visual studio. It is a totally legitimate thing to
+// expect from a compiler.
+#pragma warning(disable : 4675)
+
+// This is a warning you get from visual studio 2005 about things in the standard C++
+// library being "deprecated." I checked the C++ standard and it doesn't say jack
+// about any of them (I checked the searchable PDF). So this warning is total Bunk.
+#pragma warning(disable : 4996)
+
+// This is a warning you get from visual studio 2003:
+// warning C4345: behavior change: an object of POD type constructed with an initializer
+// of the form () will be default-initialized.
+// I love it when this compiler gives warnings about bugs in previous versions of itself.
+#pragma warning(disable : 4345)
+
+
+// Disable warnings about conversion from size_t to unsigned long and long.
+#pragma warning(disable : 4267)
+
+// Disable warnings about conversion from double to float
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4305)
+
+// Disable "warning C4180: qualifier applied to function type has no meaning; ignored".
+// This warning happens often in generic code that works with functions and isn't useful.
+#pragma warning(disable : 4180)
+
+// Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)"
+#pragma warning(disable : 4290)
+
+
+// DNN module uses template-based network declaration that leads to very long
+// type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says
+// that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occurr could be confusing.
+#pragma warning( disable: 4503 )
+
+
+#endif
+
+#ifdef __BORLANDC__
+// Disable the following warnings for the Borland Compilers
+//
+// These warnings just say that the compiler is refusing to inline functions with
+// loops or try blocks in them.
+//
+#pragma option -w-8027
+#pragma option -w-8026
+#endif
+
+#include <string> // for the exceptions
+
+#ifdef __CYGWIN__
+namespace std
+{
+ typedef std::basic_string<wchar_t> wstring;
+}
+#endif
+
+#include "platform.h"
+#include "windows_magic.h"
+
+
+#include <algorithm> // for std::swap
+#include <new> // for std::bad_alloc
+#include <cstdlib>
+#include <limits> // for std::numeric_limits for is_finite()
+#include "assert.h"
+#include "error.h"
+#include "noncopyable.h"
+#include "enable_if.h"
+#include "uintn.h"
+#include "numeric_constants.h"
+#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" // for the default memory manager
+
+
+
+// ----------------------------------------------------------------------------------------
+/*!A _dT !*/
+
+template <typename charT>
+inline charT _dTcast (const char a, const wchar_t b);
+template <>
+inline char _dTcast<char> (const char a, const wchar_t ) { return a; }
+template <>
+inline wchar_t _dTcast<wchar_t> (const char , const wchar_t b) { return b; }
+
+template <typename charT>
+inline const charT* _dTcast ( const char* a, const wchar_t* b);
+template <>
+inline const char* _dTcast<char> ( const char* a, const wchar_t* ) { return a; }
+template <>
+inline const wchar_t* _dTcast<wchar_t> ( const char* , const wchar_t* b) { return b; }
+
+
+#define _dT(charT,str) _dTcast<charT>(str,L##str)
+/*!
+ requires
+ - charT == char or wchar_t
+ - str == a string or character literal
+ ensures
+ - returns the literal in the form of a charT type literal.
+!*/
+
+// ----------------------------------------------------------------------------------------
+
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A default_memory_manager
+
+ This memory manager just calls new and delete directly.
+
+ !*/
+ typedef memory_manager_stateless_kernel_1<char> default_memory_manager;
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A swap !*/
+ // make swap available in the dlib namespace
+ using std::swap;
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ Here is where I define my return codes. It is
+ important that they all be < 0.
+ !*/
+
+ enum general_return_codes
+ {
+ TIMEOUT = -1,
+ WOULDBLOCK = -2,
+ OTHER_ERROR = -3,
+ SHUTDOWN = -4,
+ PORTINUSE = -5
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long square_root (
+ unsigned long value
+ )
+ /*!
+ requires
+ - value <= 2^32 - 1
+ ensures
+ - returns the square root of value. if the square root is not an
+ integer then it will be rounded up to the nearest integer.
+ !*/
+ {
+ unsigned long x;
+
+ // set the initial guess for what the root is depending on
+ // how big value is
+ if (value < 3)
+ return value;
+ else if (value < 4096) // 12
+ x = 45;
+ else if (value < 65536) // 16
+ x = 179;
+ else if (value < 1048576) // 20
+ x = 717;
+ else if (value < 16777216) // 24
+ x = 2867;
+ else if (value < 268435456) // 28
+ x = 11469;
+ else // 32
+ x = 45875;
+
+
+
+ // find the root
+ x = (x + value/x)>>1;
+ x = (x + value/x)>>1;
+ x = (x + value/x)>>1;
+ x = (x + value/x)>>1;
+
+
+
+ if (x*x < value)
+ return x+1;
+ else
+ return x;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void median (
+ T& one,
+ T& two,
+ T& three
+ );
+ /*!
+ requires
+ - T implements operator<
+ - T is swappable by a global swap()
+ ensures
+ - #one is the median
+ - #one, #two, and #three is some permutation of one, two, and three.
+ !*/
+
+
+ template <
+ typename T
+ >
+ void median (
+ T& one,
+ T& two,
+ T& three
+ )
+ {
+ using std::swap;
+ using dlib::swap;
+
+ if ( one < two )
+ {
+ // one < two
+ if ( two < three )
+ {
+ // one < two < three : two
+ swap(one,two);
+
+ }
+ else
+ {
+ // one < two >= three
+ if ( one < three)
+ {
+ // three
+ swap(three,one);
+ }
+ }
+
+ }
+ else
+ {
+ // one >= two
+ if ( three < one )
+ {
+ // three <= one >= two
+ if ( three < two )
+ {
+ // two
+ swap(two,one);
+ }
+ else
+ {
+ // three
+ swap(three,one);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace relational_operators
+ {
+ template <
+ typename A,
+ typename B
+ >
+ constexpr bool operator> (
+ const A& a,
+ const B& b
+ ) { return b < a; }
+
+ // ---------------------------------
+
+ template <
+ typename A,
+ typename B
+ >
+ constexpr bool operator!= (
+ const A& a,
+ const B& b
+ ) { return !(a == b); }
+
+ // ---------------------------------
+
+ template <
+ typename A,
+ typename B
+ >
+ constexpr bool operator<= (
+ const A& a,
+ const B& b
+ ) { return !(b < a); }
+
+ // ---------------------------------
+
+ template <
+ typename A,
+ typename B
+ >
+ constexpr bool operator>= (
+ const A& a,
+ const B& b
+ ) { return !(a < b); }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void exchange (
+ T& a,
+ T& b
+ )
+ /*!
+ This function does the exact same thing that global swap does and it does it by
+ just calling swap. But a lot of compilers have problems doing a Koenig Lookup
+ and the fact that this has a different name (global swap has the same name as
+ the member functions called swap) makes them compile right.
+
+ So this is a workaround but not too ugly of one. But hopefully I get get
+ rid of this in a few years. So this function is already deprecated.
+
+ This also means you should NOT use this function in your own code unless
+ you have to support an old buggy compiler that benefits from this hack.
+ !*/
+ {
+ using std::swap;
+ using dlib::swap;
+ swap(a,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_pointer_type
+
+ This is a template where is_pointer_type<T>::value == true when T is a pointer
+ type and false otherwise.
+ !*/
+
+ template <
+ typename T
+ >
+ class is_pointer_type
+ {
+ public:
+ enum { value = false };
+ private:
+ is_pointer_type();
+ };
+
+ template <
+ typename T
+ >
+ class is_pointer_type<T*>
+ {
+ public:
+ enum { value = true };
+ private:
+ is_pointer_type();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_const_type
+
+ This is a template where is_const_type<T>::value == true when T is a const
+ type and false otherwise.
+ !*/
+
+ template <typename T>
+ struct is_const_type
+ {
+ static const bool value = false;
+ };
+ template <typename T>
+ struct is_const_type<const T>
+ {
+ static const bool value = true;
+ };
+ template <typename T>
+ struct is_const_type<const T&>
+ {
+ static const bool value = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_reference_type
+
+ This is a template where is_reference_type<T>::value == true when T is a reference
+ type and false otherwise.
+ !*/
+
+ template <typename T>
+ struct is_reference_type
+ {
+ static const bool value = false;
+ };
+
+ template <typename T> struct is_reference_type<const T&> { static const bool value = true; };
+ template <typename T> struct is_reference_type<T&> { static const bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_same_type
+
+ This is a template where is_same_type<T,U>::value == true when T and U are the
+ same type and false otherwise.
+ !*/
+
+ template <
+ typename T,
+ typename U
+ >
+ class is_same_type
+ {
+ public:
+ enum {value = false};
+ private:
+ is_same_type();
+ };
+
+ template <typename T>
+ class is_same_type<T,T>
+ {
+ public:
+ enum {value = true};
+ private:
+ is_same_type();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_float_type
+
+ This is a template that can be used to determine if a type is one of the built
+ int floating point types (i.e. float, double, or long double).
+ !*/
+
+ template < typename T > struct is_float_type { const static bool value = false; };
+ template <> struct is_float_type<float> { const static bool value = true; };
+ template <> struct is_float_type<double> { const static bool value = true; };
+ template <> struct is_float_type<long double> { const static bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_convertible
+
+ This is a template that can be used to determine if one type is convertible
+ into another type.
+
+ For example:
+ is_convertible<int,float>::value == true // because ints are convertible to floats
+ is_convertible<int*,float>::value == false // because int pointers are NOT convertible to floats
+ !*/
+
+ template <typename from, typename to>
+ struct is_convertible
+ {
+ struct yes_type { char a; };
+ struct no_type { yes_type a[2]; };
+ static const from& from_helper();
+ static yes_type test(to);
+ static no_type test(...);
+ const static bool value = sizeof(test(from_helper())) == sizeof(yes_type);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct general_ {};
+ struct special_ : general_ {};
+ template<typename> struct int_ { typedef int type; };
+
+// ----------------------------------------------------------------------------------------
+
+
+ /*!A is_same_object
+
+ This is a templated function which checks if both of its arguments are actually
+ references to the same object. It returns true if they are and false otherwise.
+
+ !*/
+
+ // handle the case where T and U are unrelated types.
+ template < typename T, typename U >
+ typename disable_if_c<is_convertible<T*, U*>::value || is_convertible<U*,T*>::value, bool>::type
+ is_same_object (
+ const T& a,
+ const U& b
+ )
+ {
+ return ((void*)&a == (void*)&b);
+ }
+
+ // handle the case where T and U are related types because their pointers can be
+ // implicitly converted into one or the other. E.g. a derived class and its base class.
+ // Or where both T and U are just the same type. This way we make sure that if there is a
+ // valid way to convert between these two pointer types then we will take that route rather
+ // than the void* approach used otherwise.
+ template < typename T, typename U >
+ typename enable_if_c<is_convertible<T*, U*>::value || is_convertible<U*,T*>::value, bool>::type
+ is_same_object (
+ const T& a,
+ const U& b
+ )
+ {
+ return (&a == &b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_unsigned_type
+
+ This is a template where is_unsigned_type<T>::value == true when T is an unsigned
+ scalar type and false when T is a signed scalar type.
+ !*/
+ template <
+ typename T
+ >
+ struct is_unsigned_type
+ {
+ static const bool value = static_cast<T>((static_cast<T>(0)-static_cast<T>(1))) > 0;
+ };
+ template <> struct is_unsigned_type<long double> { static const bool value = false; };
+ template <> struct is_unsigned_type<double> { static const bool value = false; };
+ template <> struct is_unsigned_type<float> { static const bool value = false; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_signed_type
+
+ This is a template where is_signed_type<T>::value == true when T is a signed
+ scalar type and false when T is an unsigned scalar type.
+ !*/
+ template <
+ typename T
+ >
+ struct is_signed_type
+ {
+ static const bool value = !is_unsigned_type<T>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class copy_functor
+ {
+ public:
+ void operator() (
+ const T& source,
+ T& destination
+ ) const
+ {
+ destination = source;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A static_switch
+
+ To use this template you give it some number of boolean expressions and it
+ tells you which one of them is true. If more than one of them is true then
+ it causes a compile time error.
+
+ for example:
+ static_switch<1 + 1 == 2, 4 - 1 == 4>::value == 1 // because the first expression is true
+ static_switch<1 + 1 == 3, 4 == 4>::value == 2 // because the second expression is true
+ static_switch<1 + 1 == 3, 4 == 5>::value == 0 // 0 here because none of them are true
+ static_switch<1 + 1 == 2, 4 == 4>::value == compiler error // because more than one expression is true
+ !*/
+
+ template < bool v1 = 0, bool v2 = 0, bool v3 = 0, bool v4 = 0, bool v5 = 0,
+ bool v6 = 0, bool v7 = 0, bool v8 = 0, bool v9 = 0, bool v10 = 0,
+ bool v11 = 0, bool v12 = 0, bool v13 = 0, bool v14 = 0, bool v15 = 0 >
+ struct static_switch;
+
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 0; };
+ template <> struct static_switch<1,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 1; };
+ template <> struct static_switch<0,1,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 2; };
+ template <> struct static_switch<0,0,1,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 3; };
+ template <> struct static_switch<0,0,0,1,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 4; };
+ template <> struct static_switch<0,0,0,0,1,0,0,0,0,0,0,0,0,0,0> { const static int value = 5; };
+ template <> struct static_switch<0,0,0,0,0,1,0,0,0,0,0,0,0,0,0> { const static int value = 6; };
+ template <> struct static_switch<0,0,0,0,0,0,1,0,0,0,0,0,0,0,0> { const static int value = 7; };
+ template <> struct static_switch<0,0,0,0,0,0,0,1,0,0,0,0,0,0,0> { const static int value = 8; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,1,0,0,0,0,0,0> { const static int value = 9; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,1,0,0,0,0,0> { const static int value = 10; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,1,0,0,0,0> { const static int value = 11; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,1,0,0,0> { const static int value = 12; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,1,0,0> { const static int value = 13; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,1,0> { const static int value = 14; };
+ template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,1> { const static int value = 15; };
+
+// ----------------------------------------------------------------------------------------
+ /*!A is_built_in_scalar_type
+
+ This is a template that allows you to determine if the given type is a built
+ in scalar type such as an int, char, float, short, etc.
+
+ For example, is_built_in_scalar_type<char>::value == true
+ For example, is_built_in_scalar_type<std::string>::value == false
+ !*/
+
+ template <typename T> struct is_built_in_scalar_type { const static bool value = false; };
+
+ template <> struct is_built_in_scalar_type<float> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<double> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<long double> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<short> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<int> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<long> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<unsigned short> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<unsigned int> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<unsigned long> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<uint64> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<int64> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<char> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<signed char> { const static bool value = true; };
+ template <> struct is_built_in_scalar_type<unsigned char> { const static bool value = true; };
+ // Don't define one for wchar_t when using a version of visual studio
+ // older than 8.0 (visual studio 2005) since before then they improperly set
+ // wchar_t to be a typedef rather than its own type as required by the C++
+ // standard.
+#if !defined(_MSC_VER) || _NATIVE_WCHAR_T_DEFINED
+ template <> struct is_built_in_scalar_type<wchar_t> { const static bool value = true; };
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if<is_built_in_scalar_type<T>,bool>::type is_finite (
+ const T& value
+ )
+ /*!
+ requires
+ - value must be some kind of scalar type such as int or double
+ ensures
+ - returns true if value is a finite value (e.g. not infinity or NaN) and false
+ otherwise.
+ !*/
+ {
+ if (is_float_type<T>::value)
+ return -std::numeric_limits<T>::infinity() < value && value < std::numeric_limits<T>::infinity();
+ else
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A promote
+
+ This is a template that takes one of the built in scalar types and gives you another
+ scalar type that should be big enough to hold sums of values from the original scalar
+ type. The new scalar type will also always be signed.
+
+ For example, promote<uint16>::type == int32
+ !*/
+
+ template <typename T, size_t s = sizeof(T)> struct promote;
+ template <typename T> struct promote<T,1> { typedef int32 type; };
+ template <typename T> struct promote<T,2> { typedef int32 type; };
+ template <typename T> struct promote<T,4> { typedef int64 type; };
+ template <typename T> struct promote<T,8> { typedef int64 type; };
+
+ template <> struct promote<float,sizeof(float)> { typedef double type; };
+ template <> struct promote<double,sizeof(double)> { typedef double type; };
+ template <> struct promote<long double,sizeof(long double)> { typedef long double type; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A assign_zero_if_built_in_scalar_type
+
+ This function assigns its argument the value of 0 if it is a built in scalar
+ type according to the is_built_in_scalar_type<> template. If it isn't a
+ built in scalar type then it does nothing.
+ !*/
+
+ template <typename T> inline typename disable_if<is_built_in_scalar_type<T>,void>::type assign_zero_if_built_in_scalar_type (T&){}
+ template <typename T> inline typename enable_if<is_built_in_scalar_type<T>,void>::type assign_zero_if_built_in_scalar_type (T& a){a=0;}
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A basic_type
+
+ This is a template that takes a type and strips off any const, volatile, or reference
+ qualifiers and gives you back the basic underlying type. So for example:
+
+ basic_type<const int&>::type == int
+ !*/
+
+ template <typename T> struct basic_type { typedef T type; };
+ template <typename T> struct basic_type<const T> { typedef T type; };
+ template <typename T> struct basic_type<const T&> { typedef T type; };
+ template <typename T> struct basic_type<volatile const T&> { typedef T type; };
+ template <typename T> struct basic_type<T&> { typedef T type; };
+ template <typename T> struct basic_type<volatile T&> { typedef T type; };
+ template <typename T> struct basic_type<volatile T> { typedef T type; };
+ template <typename T> struct basic_type<volatile const T> { typedef T type; };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ T put_in_range (
+ const T& a,
+ const T& b,
+ const T& val
+ )
+ /*!
+ requires
+ - T is a type that looks like double, float, int, or so forth
+ ensures
+ - if (val is within the range [a,b]) then
+ - returns val
+ - else
+ - returns the end of the range [a,b] that is closest to val
+ !*/
+ {
+ if (a < b)
+ {
+ if (val < a)
+ return a;
+ else if (val > b)
+ return b;
+ }
+ else
+ {
+ if (val < b)
+ return b;
+ else if (val > a)
+ return a;
+ }
+
+ return val;
+ }
+
+ // overload for double
+ inline double put_in_range(const double& a, const double& b, const double& val)
+ { return put_in_range<double>(a,b,val); }
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A tabs
+
+ This is a template to compute the absolute value a number at compile time.
+
+ For example,
+ abs<-4>::value == 4
+ abs<4>::value == 4
+ !*/
+
+ template <long x, typename enabled=void>
+ struct tabs { const static long value = x; };
+ template <long x>
+ struct tabs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A tmax
+
+ This is a template to compute the max of two values at compile time
+
+ For example,
+ abs<4,7>::value == 7
+ !*/
+
+ template <long x, long y, typename enabled=void>
+ struct tmax { const static long value = x; };
+ template <long x, long y>
+ struct tmax<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A tmin
+
+ This is a template to compute the min of two values at compile time
+
+ For example,
+ abs<4,7>::value == 4
+ !*/
+
+ template <long x, long y, typename enabled=void>
+ struct tmin { const static long value = x; };
+ template <long x, long y>
+ struct tmin<x,y,typename enable_if_c<(y < x)>::type> { const static long value = y; };
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \
+ struct _two_bytes_##testname { char a[2]; }; \
+ template < typename T, returnT (T::*funct)args > \
+ struct _helper_##testname { typedef char type; }; \
+ template <typename T> \
+ static char _has_##testname##_helper( typename _helper_##testname<T,&T::funct_name >::type ) { return 0;} \
+ template <typename T> \
+ static _two_bytes_##testname _has_##testname##_helper(int) { return _two_bytes_##testname();} \
+ template <typename T> struct _##testname##workaroundbug { \
+ const static unsigned long U = sizeof(_has_##testname##_helper<T>('a')); }; \
+ template <typename T, unsigned long U = _##testname##workaroundbug<T>::U > \
+ struct testname { static const bool value = false; }; \
+ template <typename T> \
+ struct testname<T,1> { static const bool value = true; };
+ /*!A DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST
+
+ The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates
+ that tell you if a class has a certain member function. For example, to make a
+ test to see if a class has a public method with the signature void print(int) you
+ would say:
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int))
+
+ Then you can check if a class, T, has this method by looking at the boolean value:
+ has_print<T>::value
+ which will be true if the member function is in the T class.
+
+ Note that you can test for member functions taking no arguments by simply passing
+ in empty () like so:
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ())
+ This would test for a member of the form:
+ void print().
+
+ To test for const member functions you would use a statement such as this:
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()const)
+ This would test for a member of the form:
+ void print() const.
+
+ To test for const templated member functions you would use a statement such as this:
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, template print<int>, ())
+ This would test for a member of the form:
+ template <typename T> void print().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_function
+
+ This is a template that allows you to determine if the given type is a function.
+
+ For example,
+ void funct();
+
+ is_built_in_scalar_type<funct>::value == true
+ is_built_in_scalar_type<int>::value == false
+ !*/
+
+ template <typename T> struct is_function { static const bool value = false; };
+ template <typename T>
+ struct is_function<T (void)> { static const bool value = true; };
+ template <typename T, typename A0>
+ struct is_function<T (A0)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1>
+ struct is_function<T (A0, A1)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2>
+ struct is_function<T (A0, A1, A2)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3>
+ struct is_function<T (A0, A1, A2, A3)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4>
+ struct is_function<T (A0, A1, A2, A3, A4)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4,
+ typename A5>
+ struct is_function<T (A0,A1,A2,A3,A4,A5)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4,
+ typename A5, typename A6>
+ struct is_function<T (A0,A1,A2,A3,A4,A5,A6)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4,
+ typename A5, typename A6, typename A7>
+ struct is_function<T (A0,A1,A2,A3,A4,A5,A6,A7)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4,
+ typename A5, typename A6, typename A7, typename A8>
+ struct is_function<T (A0,A1,A2,A3,A4,A5,A6,A7,A8)> { static const bool value = true; };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4,
+ typename A5, typename A6, typename A7, typename A8, typename A9>
+ struct is_function<T (A0,A1,A2,A3,A4,A5,A6,A7,A8,A9)> { static const bool value = true; };
+
+
+ template <typename T> class funct_wrap0
+ {
+ public:
+ funct_wrap0(T (&f_)()):f(f_){}
+ T operator()() const { return f(); }
+ private:
+ T (&f)();
+ };
+ template <typename T, typename A0> class funct_wrap1
+ {
+ public:
+ funct_wrap1(T (&f_)(A0)):f(f_){}
+ T operator()(A0 a0) const { return f(a0); }
+ private:
+ T (&f)(A0);
+ };
+ template <typename T, typename A0, typename A1> class funct_wrap2
+ {
+ public:
+ funct_wrap2(T (&f_)(A0,A1)):f(f_){}
+ T operator()(A0 a0, A1 a1) const { return f(a0,a1); }
+ private:
+ T (&f)(A0,A1);
+ };
+ template <typename T, typename A0, typename A1, typename A2> class funct_wrap3
+ {
+ public:
+ funct_wrap3(T (&f_)(A0,A1,A2)):f(f_){}
+ T operator()(A0 a0, A1 a1, A2 a2) const { return f(a0,a1,a2); }
+ private:
+ T (&f)(A0,A1,A2);
+ };
+ template <typename T, typename A0, typename A1, typename A2, typename A3> class funct_wrap4
+ {
+ public:
+ funct_wrap4(T (&f_)(A0,A1,A2,A3)):f(f_){}
+ T operator()(A0 a0, A1 a1, A2 a2, A3 a3) const { return f(a0,a1,a2,a3); }
+ private:
+ T (&f)(A0,A1,A2,A3);
+ };
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4> class funct_wrap5
+ {
+ public:
+ funct_wrap5(T (&f_)(A0,A1,A2,A3,A4)):f(f_){}
+ T operator()(A0 a0, A1 a1, A2 a2, A3 a3, A4 a4) const { return f(a0,a1,a2,a3,a4); }
+ private:
+ T (&f)(A0,A1,A2,A3,A4);
+ };
+
+ /*!A wrap_function
+
+ This is a template that allows you to turn a global function into a
+ function object. The reason for this template's existance is so you can
+ do stuff like this:
+
+ template <typename T>
+ void call_funct(const T& funct)
+ { cout << funct(); }
+
+ std::string test() { return "asdfasf"; }
+
+ int main()
+ {
+ call_funct(wrap_function(test));
+ }
+
+ The above code doesn't work right on some compilers if you don't
+ use wrap_function.
+ !*/
+
+ template <typename T>
+ funct_wrap0<T> wrap_function(T (&f)()) { return funct_wrap0<T>(f); }
+ template <typename T, typename A0>
+ funct_wrap1<T,A0> wrap_function(T (&f)(A0)) { return funct_wrap1<T,A0>(f); }
+ template <typename T, typename A0, typename A1>
+ funct_wrap2<T,A0,A1> wrap_function(T (&f)(A0, A1)) { return funct_wrap2<T,A0,A1>(f); }
+ template <typename T, typename A0, typename A1, typename A2>
+ funct_wrap3<T,A0,A1,A2> wrap_function(T (&f)(A0, A1, A2)) { return funct_wrap3<T,A0,A1,A2>(f); }
+ template <typename T, typename A0, typename A1, typename A2, typename A3>
+ funct_wrap4<T,A0,A1,A2,A3> wrap_function(T (&f)(A0, A1, A2, A3)) { return funct_wrap4<T,A0,A1,A2,A3>(f); }
+ template <typename T, typename A0, typename A1, typename A2, typename A3, typename A4>
+ funct_wrap5<T,A0,A1,A2,A3,A4> wrap_function(T (&f)(A0, A1, A2, A3, A4)) { return funct_wrap5<T,A0,A1,A2,A3,A4>(f); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned long bSIZE>
+ class stack_based_memory_block : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple container for a block of memory
+ of bSIZE bytes. This memory block is located on the stack
+ and properly aligned to hold any kind of object.
+ !*/
+ public:
+ static const unsigned long size = bSIZE;
+
+ stack_based_memory_block(): data(mem.data) {}
+
+ void* get () { return data; }
+ /*!
+ ensures
+ - returns a pointer to the block of memory contained in this object
+ !*/
+
+ const void* get () const { return data; }
+ /*!
+ ensures
+ - returns a pointer to the block of memory contained in this object
+ !*/
+
+ private:
+
+ // You obviously can't have a block of memory that has zero bytes in it.
+ COMPILE_TIME_ASSERT(bSIZE > 0);
+
+ union mem_block
+ {
+ // All of this garbage is to make sure this union is properly aligned
+ // (a union is always aligned such that everything in it would be properly
+ // aligned. So the assumption here is that one of these objects has
+ // a large enough alignment requirement to satisfy any object this
+ // block of memory might be cast into).
+ void* void_ptr;
+ int integer;
+ struct {
+ void (stack_based_memory_block::*callback)();
+ stack_based_memory_block* o;
+ } stuff;
+ long double more_stuff;
+
+ uint64 var1;
+ uint32 var2;
+ double var3;
+
+ char data[size];
+ } mem;
+
+ // The reason for having this variable is that doing it this way avoids
+ // warnings from gcc about violations of strict-aliasing rules.
+ void* const data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename F
+ >
+ auto max_scoring_element(
+ const T& container,
+ F score_func
+ ) -> decltype(std::make_pair(*container.begin(), 0.0))
+ /*!
+ requires
+ - container has .begin() and .end(), allowing it to be enumerated.
+ - score_func() is a function that takes an element of the container and returns a double.
+ ensures
+ - This function finds the element of container that has the largest score,
+ according to score_func(), and returns a std::pair containing that maximal
+ element along with the score.
+ - If the container is empty then make_pair(a default initialized object, -infinity) is returned.
+ !*/
+ {
+ double best_score = -std::numeric_limits<double>::infinity();
+ auto best_i = container.begin();
+ for (auto i = container.begin(); i != container.end(); ++i)
+ {
+ auto score = score_func(*i);
+ if (score > best_score)
+ {
+ best_score = score;
+ best_i = i;
+ }
+ }
+
+ using item_type = typename std::remove_reference<decltype(*best_i)>::type;
+
+ if (best_i == container.end())
+ return std::make_pair(item_type(), best_score);
+ else
+ return std::make_pair(*best_i, best_score);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename F
+ >
+ auto min_scoring_element(
+ const T& container,
+ F score_func
+ ) -> decltype(std::make_pair(*container.begin(), 0.0))
+ /*!
+ requires
+ - container has .begin() and .end(), allowing it to be enumerated.
+ - score_func() is a function that takes an element of the container and returns a double.
+ ensures
+ - This function finds the element of container that has the smallest score,
+ according to score_func(), and returns a std::pair containing that minimal
+ element along with the score.
+ - If the container is empty then make_pair(a default initialized object, infinity) is returned.
+ !*/
+ {
+ double best_score = std::numeric_limits<double>::infinity();
+ auto best_i = container.begin();
+ for (auto i = container.begin(); i != container.end(); ++i)
+ {
+ auto score = score_func(*i);
+ if (score < best_score)
+ {
+ best_score = score;
+ best_i = i;
+ }
+ }
+
+ using item_type = typename std::remove_reference<decltype(*best_i)>::type;
+
+ if (best_i == container.end())
+ return std::make_pair(item_type(), best_score);
+ else
+ return std::make_pair(*best_i, best_score);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ALGs_
+
diff --git a/ml/dlib/dlib/all/source.cpp b/ml/dlib/dlib/all/source.cpp
new file mode 100644
index 000000000..1fc646a86
--- /dev/null
+++ b/ml/dlib/dlib/all/source.cpp
@@ -0,0 +1,98 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ALL_SOURCe_
+#define DLIB_ALL_SOURCe_
+
+#if defined(DLIB_ALGs_) || defined(DLIB_PLATFORm_)
+#include "../dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+// ISO C++ code
+#include "../base64/base64_kernel_1.cpp"
+#include "../bigint/bigint_kernel_1.cpp"
+#include "../bigint/bigint_kernel_2.cpp"
+#include "../bit_stream/bit_stream_kernel_1.cpp"
+#include "../entropy_decoder/entropy_decoder_kernel_1.cpp"
+#include "../entropy_decoder/entropy_decoder_kernel_2.cpp"
+#include "../entropy_encoder/entropy_encoder_kernel_1.cpp"
+#include "../entropy_encoder/entropy_encoder_kernel_2.cpp"
+#include "../md5/md5_kernel_1.cpp"
+#include "../tokenizer/tokenizer_kernel_1.cpp"
+#include "../unicode/unicode.cpp"
+#include "../test_for_odr_violations.cpp"
+
+
+
+
+#ifndef DLIB_ISO_CPP_ONLY
+// Code that depends on OS specific APIs
+
+// include this first so that it can disable the older version
+// of the winsock API when compiled in windows.
+#include "../sockets/sockets_kernel_1.cpp"
+#include "../bsp/bsp.cpp"
+
+#include "../dir_nav/dir_nav_kernel_1.cpp"
+#include "../dir_nav/dir_nav_kernel_2.cpp"
+#include "../dir_nav/dir_nav_extensions.cpp"
+#include "../linker/linker_kernel_1.cpp"
+#include "../logger/extra_logger_headers.cpp"
+#include "../logger/logger_kernel_1.cpp"
+#include "../logger/logger_config_file.cpp"
+#include "../misc_api/misc_api_kernel_1.cpp"
+#include "../misc_api/misc_api_kernel_2.cpp"
+#include "../sockets/sockets_extensions.cpp"
+#include "../sockets/sockets_kernel_2.cpp"
+#include "../sockstreambuf/sockstreambuf.cpp"
+#include "../sockstreambuf/sockstreambuf_unbuffered.cpp"
+#include "../server/server_kernel.cpp"
+#include "../server/server_iostream.cpp"
+#include "../server/server_http.cpp"
+#include "../threads/multithreaded_object_extension.cpp"
+#include "../threads/threaded_object_extension.cpp"
+#include "../threads/threads_kernel_1.cpp"
+#include "../threads/threads_kernel_2.cpp"
+#include "../threads/threads_kernel_shared.cpp"
+#include "../threads/thread_pool_extension.cpp"
+#include "../threads/async.cpp"
+#include "../timer/timer.cpp"
+#include "../stack_trace.cpp"
+
+#ifdef DLIB_PNG_SUPPORT
+#include "../image_loader/png_loader.cpp"
+#include "../image_saver/save_png.cpp"
+#endif
+
+#ifdef DLIB_JPEG_SUPPORT
+#include "../image_loader/jpeg_loader.cpp"
+#include "../image_saver/save_jpeg.cpp"
+#endif
+
+#ifndef DLIB_NO_GUI_SUPPORT
+#include "../gui_widgets/fonts.cpp"
+#include "../gui_widgets/widgets.cpp"
+#include "../gui_widgets/drawable.cpp"
+#include "../gui_widgets/canvas_drawing.cpp"
+#include "../gui_widgets/style.cpp"
+#include "../gui_widgets/base_widgets.cpp"
+#include "../gui_core/gui_core_kernel_1.cpp"
+#include "../gui_core/gui_core_kernel_2.cpp"
+#endif // DLIB_NO_GUI_SUPPORT
+
+#include "../dnn/cpu_dlib.cpp"
+#include "../dnn/tensor_tools.cpp"
+
+#endif // DLIB_ISO_CPP_ONLY
+
+
+
+#include "../data_io/image_dataset_metadata.cpp"
+#include "../data_io/mnist.cpp"
+#include "../global_optimization/global_function_search.cpp"
+#include "../filtering/kalman_filter.cpp"
+
+
+#define DLIB_ALL_SOURCE_END
+
+#endif // DLIB_ALL_SOURCe_
+
diff --git a/ml/dlib/dlib/any.h b/ml/dlib/dlib/any.h
new file mode 100644
index 000000000..01f047066
--- /dev/null
+++ b/ml/dlib/dlib/any.h
@@ -0,0 +1,13 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AnY_
+#define DLIB_AnY_
+
+#include "any/any.h"
+#include "any/any_trainer.h"
+#include "any/any_decision_function.h"
+#include "any/any_function.h"
+
+#endif // DLIB_AnY_
+
+
diff --git a/ml/dlib/dlib/any/any.h b/ml/dlib/dlib/any/any.h
new file mode 100644
index 000000000..b5ef1bc8b
--- /dev/null
+++ b/ml/dlib/dlib/any/any.h
@@ -0,0 +1,183 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AnY_H_
+#define DLIB_AnY_H_
+
+#include "any_abstract.h"
+#include "../algs.h"
+
+#include <memory>
+#include <typeinfo>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_any_cast : public std::bad_cast
+ {
+ public:
+ virtual const char * what() const throw()
+ {
+ return "bad_any_cast";
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class any
+ {
+
+ public:
+
+ any()
+ {
+ }
+
+ any (
+ const any& item
+ )
+ {
+ if (item.data)
+ {
+ item.data->copy_to(data);
+ }
+ }
+
+ template <typename T>
+ any (
+ const T& item
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ data.reset(new derived<U>(item));
+ }
+
+ void clear (
+ )
+ {
+ data.reset();
+ }
+
+ template <typename T>
+ bool contains (
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ return dynamic_cast<derived<U>*>(data.get()) != 0;
+ }
+
+ bool is_empty(
+ ) const
+ {
+ return data.get() == 0;
+ }
+
+ template <typename T>
+ T& cast_to(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ const T& cast_to(
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ T& get(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ d = new derived<U>();
+ data.reset(d);
+ }
+
+ return d->item;
+ }
+
+ any& operator= (
+ const any& item
+ )
+ {
+ any(item).swap(*this);
+ return *this;
+ }
+
+ void swap (
+ any& item
+ )
+ {
+ data.swap(item.data);
+ }
+
+ private:
+
+ struct base
+ {
+ virtual ~base() {}
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const = 0;
+ };
+
+ template <typename T>
+ struct derived : public base
+ {
+ T item;
+ derived() {}
+ derived(const T& val) : item(val) {}
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const
+ {
+ dest.reset(new derived<T>(item));
+ }
+ };
+
+ std::unique_ptr<base> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ any& a,
+ any& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T> T& any_cast(any& a) { return a.cast_to<T>(); }
+ template <typename T> const T& any_cast(const any& a) { return a.cast_to<T>(); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_AnY_H_
+
+
+
diff --git a/ml/dlib/dlib/any/any_abstract.h b/ml/dlib/dlib/any/any_abstract.h
new file mode 100644
index 000000000..2fea96381
--- /dev/null
+++ b/ml/dlib/dlib/any/any_abstract.h
@@ -0,0 +1,210 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AnY_ABSTRACT_H_
+#ifdef DLIB_AnY_ABSTRACT_H_
+
+#include <typeinfo>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_any_cast : public std::bad_cast
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is the exception class used by the any object.
+ It is used to indicate when someone attempts to cast an any
+ object into a type which isn't contained in the any object.
+ !*/
+
+ public:
+ virtual const char* what() const throw() { return "bad_any_cast"; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class any
+ {
+ /*!
+ INITIAL VALUE
+ - is_empty() == true
+ - for all T: contains<T>() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is basically a type-safe version of a void*. In particular,
+ it is a container which can contain only one object but the object may
+ be of any type.
+
+ It is somewhat like the type_safe_union except you don't have to declare
+ the set of possible content types beforehand. So in some sense this is
+ like a less type-strict version of the type_safe_union.
+ !*/
+
+ public:
+
+ any(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ any (
+ const any& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ template < typename T >
+ any (
+ const T& item
+ );
+ /*!
+ ensures
+ - #contains<T>() == true
+ - #cast_to<T>() == item
+ (i.e. a copy of item will be stored in *this)
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this will have its default value. I.e. #is_empty() == true
+ !*/
+
+ template <typename T>
+ bool contains (
+ ) const;
+ /*!
+ ensures
+ - if (this object currently contains an object of type T) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_empty(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains any kind of object) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ template <typename T>
+ T& cast_to(
+ );
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a non-const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ const T& cast_to(
+ ) const;
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ T& get(
+ );
+ /*!
+ ensures
+ - #is_empty() == false
+ - #contains<T>() == true
+ - if (contains<T>() == true)
+ - returns a non-const reference to the object contained in *this.
+ - else
+ - Constructs an object of type T inside *this
+ - Any previous object stored in this any object is destructed and its
+ state is lost.
+ - returns a non-const reference to the newly created T object.
+ !*/
+
+ any& operator= (
+ const any& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ void swap (
+ any& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ - does not invalidate pointers or references to the object contained
+ inside *this or item. Moreover, a pointer or reference to the object in
+ *this will now refer to the contents of #item and vice versa.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ any& a,
+ any& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T& any_cast(
+ any& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const T& any_cast(
+ const any& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AnY_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/any/any_decision_function.h b/ml/dlib/dlib/any/any_decision_function.h
new file mode 100644
index 000000000..771e9302b
--- /dev/null
+++ b/ml/dlib/dlib/any/any_decision_function.h
@@ -0,0 +1,209 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AnY_DECISION_FUNCTION_Hh_
+#define DLIB_AnY_DECISION_FUNCTION_Hh_
+
+#include "any.h"
+
+#include "any_decision_function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type_,
+ typename result_type_ = double
+ >
+ class any_decision_function
+ {
+
+ public:
+
+ typedef sample_type_ sample_type;
+ typedef result_type_ result_type;
+ typedef default_memory_manager mem_manager_type;
+
+ any_decision_function()
+ {
+ }
+
+ any_decision_function (
+ const any_decision_function& item
+ )
+ {
+ if (item.data)
+ {
+ item.data->copy_to(data);
+ }
+ }
+
+ template <typename T>
+ any_decision_function (
+ const T& item
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ data.reset(new derived<U>(item));
+ }
+
+ void clear (
+ )
+ {
+ data.reset();
+ }
+
+ template <typename T>
+ bool contains (
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ return dynamic_cast<derived<U>*>(data.get()) != 0;
+ }
+
+ bool is_empty(
+ ) const
+ {
+ return data.get() == 0;
+ }
+
+ result_type operator() (
+ const sample_type& item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_empty() == false,
+ "\t result_type any_decision_function::operator()"
+ << "\n\t You can't call operator() on an empty any_decision_function"
+ << "\n\t this: " << this
+ );
+
+ return data->evaluate(item);
+ }
+
+ template <typename T>
+ T& cast_to(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ const T& cast_to(
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ T& get(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ d = new derived<U>();
+ data.reset(d);
+ }
+
+ return d->item;
+ }
+
+ any_decision_function& operator= (
+ const any_decision_function& item
+ )
+ {
+ any_decision_function(item).swap(*this);
+ return *this;
+ }
+
+ void swap (
+ any_decision_function& item
+ )
+ {
+ data.swap(item.data);
+ }
+
+ private:
+
+ struct base
+ {
+ virtual ~base() {}
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const = 0;
+
+ virtual result_type evaluate (
+ const sample_type& samp
+ ) const = 0;
+ };
+
+ template <typename T>
+ struct derived : public base
+ {
+ T item;
+ derived() {}
+ derived(const T& val) : item(val) {}
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const
+ {
+ dest.reset(new derived<T>(item));
+ }
+
+ virtual result_type evaluate (
+ const sample_type& samp
+ ) const
+ {
+ return item(samp);
+ }
+ };
+
+ std::unique_ptr<base> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename result_type
+ >
+ inline void swap (
+ any_decision_function<sample_type, result_type>& a,
+ any_decision_function<sample_type, result_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V>
+ T& any_cast(any_decision_function<U,V>& a) { return a.template cast_to<T>(); }
+
+ template <typename T, typename U, typename V>
+ const T& any_cast(const any_decision_function<U,V>& a) { return a.template cast_to<T>(); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_AnY_DECISION_FUNCTION_Hh_
+
+
diff --git a/ml/dlib/dlib/any/any_decision_function_abstract.h b/ml/dlib/dlib/any/any_decision_function_abstract.h
new file mode 100644
index 000000000..8b6644210
--- /dev/null
+++ b/ml/dlib/dlib/any/any_decision_function_abstract.h
@@ -0,0 +1,224 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_
+#ifdef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_
+
+#include "any_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type_,
+ typename result_type_ = double
+ >
+ class any_decision_function
+ {
+ /*!
+ INITIAL VALUE
+ - is_empty() == true
+ - for all T: contains<T>() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a version of dlib::any that is restricted to containing
+ elements which are some kind of function object with an operator() with
+ the following signature:
+ result_type operator()(const sample_type&) const
+
+ It is intended to be used to contain dlib::decision_function objects and
+ other types which represent learned decision functions. It allows you
+ to write code which contains and processes these decision functions
+ without needing to know the specific types of decision functions used.
+ !*/
+
+ public:
+
+ typedef sample_type_ sample_type;
+ typedef result_type_ result_type;
+ typedef default_memory_manager mem_manager_type;
+
+ any_decision_function(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ any_decision_function (
+ const any_decision_function& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ template < typename T >
+ any_decision_function (
+ const T& item
+ );
+ /*!
+ ensures
+ - #contains<T>() == true
+ - #cast_to<T>() == item
+ (i.e. a copy of item will be stored in *this)
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this will have its default value. I.e. #is_empty() == true
+ !*/
+
+ template <typename T>
+ bool contains (
+ ) const;
+ /*!
+ ensures
+ - if (this object currently contains an object of type T) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_empty(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains any kind of object) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ result_type operator() (
+ const sample_type& item
+ ) const;
+ /*!
+ requires
+ - is_empty() == false
+ ensures
+ - Let F denote the function object contained within *this. Then
+ this function performs:
+ return F(item)
+ !*/
+
+ template <typename T>
+ T& cast_to(
+ );
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a non-const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ const T& cast_to(
+ ) const;
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ T& get(
+ );
+ /*!
+ ensures
+ - #is_empty() == false
+ - #contains<T>() == true
+ - if (contains<T>() == true)
+ - returns a non-const reference to the object contained in *this.
+ - else
+ - Constructs an object of type T inside *this
+ - Any previous object stored in this any_decision_function object is destructed and its
+ state is lost.
+ - returns a non-const reference to the newly created T object.
+ !*/
+
+ any_decision_function& operator= (
+ const any_decision_function& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ void swap (
+ any_decision_function& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename result_type
+ >
+ inline void swap (
+ any_decision_function<sample_type,result_type>& a,
+ any_decision_function<sample_type,result_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename sample_type,
+ typename result_type
+ >
+ T& any_cast(
+ any_decision_function<sample_type,result_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename sample_type,
+ typename result_type
+ >
+ const T& any_cast(
+ const any_decision_function<sample_type,result_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/any/any_function.h b/ml/dlib/dlib/any/any_function.h
new file mode 100644
index 000000000..f186b4d3f
--- /dev/null
+++ b/ml/dlib/dlib/any/any_function.h
@@ -0,0 +1,885 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AnY_FUNCTION_Hh_
+#define DLIB_AnY_FUNCTION_Hh_
+
+#include "any.h"
+
+#include "any_function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct sig_traits {};
+
+ template <
+ typename T
+ >
+ struct sig_traits<T ()>
+ {
+ typedef T result_type;
+ typedef void arg1_type;
+ typedef void arg2_type;
+ typedef void arg3_type;
+ typedef void arg4_type;
+ typedef void arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 0;
+ };
+
+ template <
+ typename T,
+ typename A1
+ >
+ struct sig_traits<T (A1)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef void arg2_type;
+ typedef void arg3_type;
+ typedef void arg4_type;
+ typedef void arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 1;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2
+ >
+ struct sig_traits<T (A1,A2)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef void arg3_type;
+ typedef void arg4_type;
+ typedef void arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 2;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3
+ >
+ struct sig_traits<T (A1,A2,A3)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef void arg4_type;
+ typedef void arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 3;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4
+ >
+ struct sig_traits<T (A1,A2,A3,A4)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef void arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 4;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef void arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 5;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef void arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 6;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef void arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 7;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef void arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 8;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef void arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 9;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef void arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 10;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef void arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 11;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef void arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 12;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef void arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 13;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef void arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 14;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef void arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 15;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15,
+ typename A16
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef A16 arg16_type;
+ typedef void arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 16;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15,
+ typename A16,
+ typename A17
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef A16 arg16_type;
+ typedef A17 arg17_type;
+ typedef void arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 17;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15,
+ typename A16,
+ typename A17,
+ typename A18
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef A16 arg16_type;
+ typedef A17 arg17_type;
+ typedef A18 arg18_type;
+ typedef void arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 18;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15,
+ typename A16,
+ typename A17,
+ typename A18,
+ typename A19
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef A16 arg16_type;
+ typedef A17 arg17_type;
+ typedef A18 arg18_type;
+ typedef A19 arg19_type;
+ typedef void arg20_type;
+
+ const static unsigned long num_args = 19;
+ };
+
+ template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10,
+ typename A11,
+ typename A12,
+ typename A13,
+ typename A14,
+ typename A15,
+ typename A16,
+ typename A17,
+ typename A18,
+ typename A19,
+ typename A20
+ >
+ struct sig_traits<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19,A20)>
+ {
+ typedef T result_type;
+ typedef A1 arg1_type;
+ typedef A2 arg2_type;
+ typedef A3 arg3_type;
+ typedef A4 arg4_type;
+ typedef A5 arg5_type;
+ typedef A6 arg6_type;
+ typedef A7 arg7_type;
+ typedef A8 arg8_type;
+ typedef A9 arg9_type;
+ typedef A10 arg10_type;
+ typedef A11 arg11_type;
+ typedef A12 arg12_type;
+ typedef A13 arg13_type;
+ typedef A14 arg14_type;
+ typedef A15 arg15_type;
+ typedef A16 arg16_type;
+ typedef A17 arg17_type;
+ typedef A18 arg18_type;
+ typedef A19 arg19_type;
+ typedef A20 arg20_type;
+
+ const static unsigned long num_args = 20;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type,
+ // These arguments are used to control the overloading. A user should
+ // not mess with them.
+ typename Enabled = void,
+ unsigned long Num_args = sig_traits<function_type>::num_args
+ >
+ class any_function
+ {
+ private:
+ any_function() {}
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ An error on this line means you are trying to use a function signature
+ with more than the supported number of arguments. The current version
+ of dlib only supports up to 10 arguments.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+ };
+
+
+ // The following preprocessor commands build the various overloaded versions
+ // of any_function for different numbers of commands and void vs. non-void return
+ // types.
+
+// 0 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST
+#define DLIB_ANY_FUNCTION_ARGS
+#define DLIB_ANY_FUNCTION_NUM_ARGS 0
+#include "any_function_impl2.h"
+
+// 1 argument
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1
+#define DLIB_ANY_FUNCTION_ARGS a1
+#define DLIB_ANY_FUNCTION_NUM_ARGS 1
+#include "any_function_impl2.h"
+
+// 2 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2
+#define DLIB_ANY_FUNCTION_ARGS a1,a2
+#define DLIB_ANY_FUNCTION_NUM_ARGS 2
+#include "any_function_impl2.h"
+
+// 3 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3
+#define DLIB_ANY_FUNCTION_NUM_ARGS 3
+#include "any_function_impl2.h"
+
+// 4 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4
+#define DLIB_ANY_FUNCTION_NUM_ARGS 4
+#include "any_function_impl2.h"
+
+// 5 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5
+#define DLIB_ANY_FUNCTION_NUM_ARGS 5
+#include "any_function_impl2.h"
+
+// 6 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5, arg6_type a6
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6
+#define DLIB_ANY_FUNCTION_NUM_ARGS 6
+#include "any_function_impl2.h"
+
+// 7 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5, arg6_type a6, arg7_type a7
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7
+#define DLIB_ANY_FUNCTION_NUM_ARGS 7
+#include "any_function_impl2.h"
+
+// 8 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8
+#define DLIB_ANY_FUNCTION_NUM_ARGS 8
+#include "any_function_impl2.h"
+
+// 9 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8, \
+ arg9_type a9
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8,a9
+#define DLIB_ANY_FUNCTION_NUM_ARGS 9
+#include "any_function_impl2.h"
+
+// 10 arguments
+#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \
+ arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8, \
+ arg9_type a9, arg10_type a10
+#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8,a9,a10
+#define DLIB_ANY_FUNCTION_NUM_ARGS 10
+#include "any_function_impl2.h"
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename function_type>
+ inline void swap (
+ any_function<function_type>& a,
+ any_function<function_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename function_type>
+ T& any_cast(any_function<function_type>& a) { return a.template cast_to<T>(); }
+
+ template <typename T, typename function_type>
+ const T& any_cast(const any_function<function_type>& a) { return a.template cast_to<T>(); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AnY_FUNCTION_Hh_
+
diff --git a/ml/dlib/dlib/any/any_function_abstract.h b/ml/dlib/dlib/any/any_function_abstract.h
new file mode 100644
index 000000000..1fc129edb
--- /dev/null
+++ b/ml/dlib/dlib/any/any_function_abstract.h
@@ -0,0 +1,292 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AnY_FUNCTION_ABSTRACT_H_
+#ifdef DLIB_AnY_FUNCTION_ABSTRACT_H_
+
+#include "any_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type
+ >
+ class any_function
+ {
+ /*!
+ REQUIREMENTS ON function_type
+ This type should be a function signature. Some examples are:
+ void (int,int) // a function returning nothing and taking two ints
+ void () // a function returning nothing and taking no arguments
+ char (double&) // a function returning a char and taking a reference to a double
+
+ The number of arguments in the function must be no greater than 10.
+
+ INITIAL VALUE
+ - is_empty() == true
+ - for all T: contains<T>() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a version of dlib::any that is restricted to containing
+ elements which are some kind of function object with an operator() which
+ matches the function signature defined by function_type.
+
+
+ Here is an example:
+ #include <iostream>
+ #include <string>
+ #include "dlib/any.h"
+ using namespace std;
+ void print_message(string str) { cout << str << endl; }
+
+ int main()
+ {
+ dlib::any_function<void(string)> f;
+ f = print_message;
+ f("hello world"); // calls print_message("hello world")
+ }
+
+ Note that any_function objects can be used to store general function
+ objects (i.e. defined by a class with an overloaded operator()) in
+ addition to regular global functions.
+ !*/
+
+ public:
+
+ // This is the type of object returned by function_type functions.
+ typedef result_type_for_function_type result_type;
+ // Typedefs defining the argument types. If an argument does not exist
+ // then it is set to void.
+ typedef type_of_first_argument_in_funct_type arg1_type;
+ typedef type_of_second_argument_in_funct_type arg2_type;
+ ...
+ typedef type_of_last_argument_in_funct_type arg10_type;
+ const static unsigned long num_args = total_number_of_non_void_arguments;
+
+ any_function(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ any_function (
+ const any_function& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ template < typename T >
+ any_function (
+ const T& item
+ );
+ /*!
+ ensures
+ - #contains<T>() == true
+ - #cast_to<T>() == item
+ (i.e. a copy of item will be stored in *this)
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this will have its default value. I.e. #is_empty() == true
+ !*/
+
+ template <typename T>
+ bool contains (
+ ) const;
+ /*!
+ ensures
+ - if (this object currently contains an object of type T) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_empty(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains any kind of object) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ bool is_set (
+ ) const;
+ /*!
+ ensures
+ - returns !is_empty()
+ !*/
+
+ result_type operator() (
+ ) const;
+ /*!
+ requires
+ - is_empty() == false
+ - the signature defined by function_type takes no arguments
+ ensures
+ - Let F denote the function object contained within *this. Then
+ this function performs:
+ return F()
+ or if result_type is void then this function performs:
+ F()
+ !*/
+
+ result_type operator() (
+ const arg1_type& a1
+ ) const;
+ /*!
+ requires
+ - is_empty() == false
+ - the signature defined by function_type takes one argument
+ ensures
+ - Let F denote the function object contained within *this. Then
+ this function performs:
+ return F(a1)
+ or if result_type is void then this function performs:
+ F(a1)
+ !*/
+
+ result_type operator() (
+ const arg1_type& a1,
+ const arg2_type& a2
+ ) const;
+ /*!
+ requires
+ - is_empty() == false
+ - the signature defined by function_type takes two arguments
+ ensures
+ - Let F denote the function object contained within *this. Then
+ this function performs:
+ return F(a1,a2)
+ or if result_type is void then this function performs:
+ F(a1,a2)
+ !*/
+
+ /* !!!!!!!!! NOTE !!!!!!!!!
+
+ In addition to the above, operator() is defined for up to 10 arguments.
+ They are not listed here because it would clutter the documentation.
+
+ !!!!!!!!! NOTE !!!!!!!!! */
+
+ template <typename T>
+ T& cast_to(
+ );
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a non-const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ const T& cast_to(
+ ) const;
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ T& get(
+ );
+ /*!
+ ensures
+ - #is_empty() == false
+ - #contains<T>() == true
+ - if (contains<T>() == true)
+ - returns a non-const reference to the object contained in *this.
+ - else
+ - Constructs an object of type T inside *this
+ - Any previous object stored in this any_function object is destructed and its
+ state is lost.
+ - returns a non-const reference to the newly created T object.
+ !*/
+
+ any_function& operator= (
+ const any_function& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ void swap (
+ any_function& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type
+ >
+ inline void swap (
+ any_function<function_type>& a,
+ any_function<function_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename function_type
+ >
+ T& any_cast(
+ any_function<function_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename function_type
+ >
+ const T& any_cast(
+ const any_function<function_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AnY_FUNCTION_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/any/any_function_impl.h b/ml/dlib/dlib/any/any_function_impl.h
new file mode 100644
index 000000000..fec66cde7
--- /dev/null
+++ b/ml/dlib/dlib/any/any_function_impl.h
@@ -0,0 +1,516 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ANY_FUNCTION_RETURN
+#error "You aren't supposed to directly #include this file. #include <dlib/any.h> instead."
+#endif
+
+#ifdef _MSC_VER
+// When using visual studio 2012, disable the warning "warning C4180: qualifier applied to function type has no meaning; ignored"
+// that you get about some template expansions applying & to function types.
+#pragma warning(disable : 4180)
+#endif
+
+#ifdef DLIB_ANY_FUNCTION_RETURN
+
+// This file contains the body of the any_function class. We use the
+// preprocessor to generate many different versions. There are
+// versions which return a value and those which return void. For
+// each of these types there are versions with differing numbers
+// of arguments.
+
+public:
+typedef typename sig_traits<function_type>::result_type result_type;
+typedef typename sig_traits<function_type>::arg1_type arg1_type;
+typedef typename sig_traits<function_type>::arg2_type arg2_type;
+typedef typename sig_traits<function_type>::arg3_type arg3_type;
+typedef typename sig_traits<function_type>::arg4_type arg4_type;
+typedef typename sig_traits<function_type>::arg5_type arg5_type;
+typedef typename sig_traits<function_type>::arg6_type arg6_type;
+typedef typename sig_traits<function_type>::arg7_type arg7_type;
+typedef typename sig_traits<function_type>::arg8_type arg8_type;
+typedef typename sig_traits<function_type>::arg9_type arg9_type;
+typedef typename sig_traits<function_type>::arg10_type arg10_type;
+const static unsigned long num_args = sig_traits<function_type>::num_args;
+
+any_function()
+{
+}
+
+any_function (
+ const any_function& item
+)
+{
+ if (item.data)
+ {
+ item.data->copy_to(data);
+ }
+}
+
+template <typename T>
+any_function (
+ const T& item
+)
+{
+ typedef typename basic_type<T>::type U;
+ data.reset(new derived<U,function_type>(item));
+}
+
+void clear (
+)
+{
+ data.reset();
+}
+
+template <typename T>
+bool contains (
+) const
+{
+ typedef typename basic_type<T>::type U;
+ return dynamic_cast<derived<U,function_type>*>(data.get()) != 0;
+}
+
+bool is_empty(
+) const
+{
+ return data.get() == 0;
+}
+
+bool is_set(
+) const
+{
+ return !is_empty();
+}
+
+template <typename T>
+T& cast_to(
+)
+{
+ typedef typename basic_type<T>::type U;
+ derived<U,function_type>* d = dynamic_cast<derived<U,function_type>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+}
+
+template <typename T>
+const T& cast_to(
+) const
+{
+ typedef typename basic_type<T>::type U;
+ derived<U,function_type>* d = dynamic_cast<derived<U,function_type>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+}
+
+template <typename T>
+T& get(
+)
+{
+ typedef typename basic_type<T>::type U;
+ derived<U,function_type>* d = dynamic_cast<derived<U,function_type>*>(data.get());
+ if (d == 0)
+ {
+ d = new derived<U,function_type>();
+ data.reset(d);
+ }
+
+ return d->item;
+}
+
+any_function& operator= (
+ const any_function& item
+)
+{
+ any_function(item).swap(*this);
+ return *this;
+}
+
+void swap (
+ any_function& item
+)
+{
+ data.swap(item.data);
+}
+
+result_type operator()(DLIB_ANY_FUNCTION_ARG_LIST) const
+{ validate(); DLIB_ANY_FUNCTION_RETURN data->evaluate(DLIB_ANY_FUNCTION_ARGS); }
+/* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to call a dlib::any_function but you have supplied
+ arguments which don't match the function signature used by the
+ dlib::any_function.
+!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+
+private:
+
+void validate () const
+{
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_empty() == false,
+ "\t result_type any_function::operator()"
+ << "\n\t You can't call operator() on an empty any_function"
+ << "\n\t this: " << this
+ );
+}
+
+
+template <typename FT>
+struct Tbase
+{
+ virtual ~Tbase() {}
+ virtual result_type evaluate () const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1
+ >
+struct Tbase<T (A1)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate ( A1) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2
+ >
+struct Tbase<T (A1,A2)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3
+ >
+struct Tbase<T (A1,A2,A3)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4
+ >
+struct Tbase<T (A1,A2,A3,A4)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5
+ >
+struct Tbase<T (A1,A2,A3,A4,A5)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6
+ >
+struct Tbase<T (A1,A2,A3,A4,A5,A6)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5,A6) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7
+ >
+struct Tbase<T (A1,A2,A3,A4,A5,A6,A7)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5,A6,A7) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8
+ >
+struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9
+ >
+struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8,A9)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+template <
+ typename T,
+ typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10
+ >
+struct Tbase<T (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10)>
+{
+ virtual ~Tbase() {}
+ virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10) const = 0;
+ virtual void copy_to ( std::unique_ptr<Tbase>& dest) const = 0;
+};
+
+typedef Tbase<function_type> base;
+
+// -----------------------------------------------
+
+// Some templates to help deal with the weirdness of storing C function types (rather than pointer to functions).
+// Basically, we make sure things always get turned into function pointers even if the user gives a function reference.
+template <typename T, typename enabled = void>
+struct funct_type { typedef T type; };
+template <typename T>
+struct funct_type<T, typename enable_if<is_function<T> >::type> { typedef T* type; };
+
+template <typename T>
+static typename enable_if<is_function<T>,const T*>::type copy (const T& item) { return &item; }
+template <typename T>
+static typename disable_if<is_function<T>,const T&>::type copy (const T& item) { return item; }
+
+template <typename T, typename U>
+static typename enable_if<is_function<T>,const T&>::type deref (const U& item) { return *item; }
+template <typename T, typename U>
+static typename disable_if<is_function<T>,const T&>::type deref (const U& item) { return item; }
+
+// -----------------------------------------------
+
+#define DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE \
+ typename funct_type<T>::type item; \
+ derived() {} \
+ derived(const T& val) : item(copy(val)) {} \
+ virtual void copy_to ( std::unique_ptr<base>& dest) const \
+ { dest.reset(new derived(deref<T>(item))); }
+
+template <typename T, typename FT>
+struct derived : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ ) const { DLIB_ANY_FUNCTION_RETURN item(); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1>
+struct derived<T,result_type (A1)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2>
+struct derived<T,result_type (A1,A2)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3>
+struct derived<T,result_type (A1,A2,A3)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4>
+struct derived<T,result_type (A1,A2,A3,A4)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5>
+struct derived<T,result_type (A1,A2,A3,A4,A5)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6>
+struct derived<T,result_type (A1,A2,A3,A4,A5,A6)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7>
+struct derived<T,result_type (A1,A2,A3,A4,A5,A6,A7)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8>
+struct derived<T,result_type (A1,A2,A3,A4,A5,A6,A7,A8)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9>
+struct derived<T,result_type (A1,A2,A3,A4,A5,A6,A7,A8,A9)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8, A9 a9
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8,a9); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+template <typename T, typename A1, typename A2, typename A3,
+ typename A4, typename A5, typename A6,
+ typename A7, typename A8, typename A9,
+ typename A10>
+struct derived<T,result_type (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10)> : public base
+{
+ DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+ virtual result_type evaluate (
+ A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8, A9 a9, A10 a10
+ ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8,a9,a10); }
+ /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ If you are getting an error on the above line then it means you
+ have attempted to assign a function or function object to a
+ dlib::any_function but the signatures of the source and
+ destination functions don't match.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+};
+
+std::unique_ptr<base> data;
+
+#undef DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE
+
+#endif // DLIB_ANY_FUNCTION_RETURN
+
diff --git a/ml/dlib/dlib/any/any_function_impl2.h b/ml/dlib/dlib/any/any_function_impl2.h
new file mode 100644
index 000000000..e1801ddc1
--- /dev/null
+++ b/ml/dlib/dlib/any/any_function_impl2.h
@@ -0,0 +1,44 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ANY_FUNCTION_ARG_LIST
+#error "You aren't supposed to directly #include this file. #include <dlib/any.h> instead."
+#endif
+
+#ifdef DLIB_ANY_FUNCTION_ARG_LIST
+
+// The case where function_type has a non-void return type
+ template <typename function_type, typename Enabled>
+ class any_function<function_type, Enabled, DLIB_ANY_FUNCTION_NUM_ARGS>
+ {
+#define DLIB_ANY_FUNCTION_RETURN return
+#include "any_function_impl.h"
+#undef DLIB_ANY_FUNCTION_RETURN
+
+ private:
+ // You get a compiler error about this function being private if you try to assign
+ // or copy between any_functions with different types. You must only copy between
+ // any_functions that represent functions with the same signature.
+ template <typename T, typename U> any_function(const any_function<T,U>&);
+ };
+
+// The case where function_type has a void return type
+ template <typename function_type>
+ class any_function<function_type, typename sig_traits<function_type>::type, DLIB_ANY_FUNCTION_NUM_ARGS>
+ {
+#define DLIB_ANY_FUNCTION_RETURN
+#include "any_function_impl.h"
+#undef DLIB_ANY_FUNCTION_RETURN
+
+ private:
+ // You get a compiler error about this function being private if you try to assign
+ // or copy between any_functions with different types. You must only copy between
+ // any_functions that represent functions with the same signature.
+ template <typename T> any_function(const any_function<T>&);
+ };
+
+#undef DLIB_ANY_FUNCTION_ARG_LIST
+#undef DLIB_ANY_FUNCTION_ARGS
+#undef DLIB_ANY_FUNCTION_NUM_ARGS
+
+#endif // DLIB_ANY_FUNCTION_ARG_LIST
+
diff --git a/ml/dlib/dlib/any/any_trainer.h b/ml/dlib/dlib/any/any_trainer.h
new file mode 100644
index 000000000..4df10a140
--- /dev/null
+++ b/ml/dlib/dlib/any/any_trainer.h
@@ -0,0 +1,217 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AnY_TRAINER_H_
+#define DLIB_AnY_TRAINER_H_
+
+#include "any.h"
+
+#include "any_decision_function.h"
+
+#include "any_trainer_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type_,
+ typename scalar_type_ = double
+ >
+ class any_trainer
+ {
+ public:
+ typedef sample_type_ sample_type;
+ typedef scalar_type_ scalar_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef any_decision_function<sample_type, scalar_type> trained_function_type;
+
+
+ any_trainer()
+ {
+ }
+
+ any_trainer (
+ const any_trainer& item
+ )
+ {
+ if (item.data)
+ {
+ item.data->copy_to(data);
+ }
+ }
+
+ template <typename T>
+ any_trainer (
+ const T& item
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ data.reset(new derived<U>(item));
+ }
+
+ void clear (
+ )
+ {
+ data.reset();
+ }
+
+ template <typename T>
+ bool contains (
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ return dynamic_cast<derived<U>*>(data.get()) != 0;
+ }
+
+ bool is_empty(
+ ) const
+ {
+ return data.get() == 0;
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_empty() == false,
+ "\t trained_function_type any_trainer::train()"
+ << "\n\t You can't call train() on an empty any_trainer"
+ << "\n\t this: " << this
+ );
+
+ return data->train(samples, labels);
+ }
+
+ template <typename T>
+ T& cast_to(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ const T& cast_to(
+ ) const
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ throw bad_any_cast();
+ }
+
+ return d->item;
+ }
+
+ template <typename T>
+ T& get(
+ )
+ {
+ typedef typename basic_type<T>::type U;
+ derived<U>* d = dynamic_cast<derived<U>*>(data.get());
+ if (d == 0)
+ {
+ d = new derived<U>();
+ data.reset(d);
+ }
+
+ return d->item;
+ }
+
+ any_trainer& operator= (
+ const any_trainer& item
+ )
+ {
+ any_trainer(item).swap(*this);
+ return *this;
+ }
+
+ void swap (
+ any_trainer& item
+ )
+ {
+ data.swap(item.data);
+ }
+
+ private:
+
+ struct base
+ {
+ virtual ~base() {}
+
+ virtual trained_function_type train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels
+ ) const = 0;
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const = 0;
+ };
+
+ template <typename T>
+ struct derived : public base
+ {
+ T item;
+ derived() {}
+ derived(const T& val) : item(val) {}
+
+ virtual void copy_to (
+ std::unique_ptr<base>& dest
+ ) const
+ {
+ dest.reset(new derived<T>(item));
+ }
+
+ virtual trained_function_type train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels
+ ) const
+ {
+ return item.train(samples, labels);
+ }
+ };
+
+ std::unique_ptr<base> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename scalar_type
+ >
+ inline void swap (
+ any_trainer<sample_type,scalar_type>& a,
+ any_trainer<sample_type,scalar_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V>
+ T& any_cast(any_trainer<U,V>& a) { return a.template cast_to<T>(); }
+
+ template <typename T, typename U, typename V>
+ const T& any_cast(const any_trainer<U,V>& a) { return a.template cast_to<T>(); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_AnY_TRAINER_H_
+
+
+
+
diff --git a/ml/dlib/dlib/any/any_trainer_abstract.h b/ml/dlib/dlib/any/any_trainer_abstract.h
new file mode 100644
index 000000000..877792fc1
--- /dev/null
+++ b/ml/dlib/dlib/any/any_trainer_abstract.h
@@ -0,0 +1,234 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AnY_TRAINER_ABSTRACT_H_
+#ifdef DLIB_AnY_TRAINER_ABSTRACT_H_
+
+#include "any_abstract.h"
+#include "../algs.h"
+#include "any_decision_function_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type_,
+ typename scalar_type_ = double
+ >
+ class any_trainer
+ {
+ /*!
+ INITIAL VALUE
+ - is_empty() == true
+ - for all T: contains<T>() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a version of dlib::any that is restricted to containing
+ elements which are some kind of object with a .train() method compatible
+ with the following signature:
+
+ decision_function train(
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels
+ ) const
+
+ Where decision_function is a type capable of being stored in an
+ any_decision_function<sample_type,scalar_type> object.
+
+ any_trainer is intended to be used to contain objects such as the svm_nu_trainer
+ and other similar types which represent supervised machine learning algorithms.
+ It allows you to write code which contains and processes these trainer objects
+ without needing to know the specific types of trainer objects used.
+ !*/
+
+ public:
+
+ typedef sample_type_ sample_type;
+ typedef scalar_type_ scalar_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef any_decision_function<sample_type, scalar_type> trained_function_type;
+
+ any_trainer(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ any_trainer (
+ const any_trainer& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ template < typename T >
+ any_trainer (
+ const T& item
+ );
+ /*!
+ ensures
+ - #contains<T>() == true
+ - #cast_to<T>() == item
+ (i.e. a copy of item will be stored in *this)
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this will have its default value. I.e. #is_empty() == true
+ !*/
+
+ template <typename T>
+ bool contains (
+ ) const;
+ /*!
+ ensures
+ - if (this object currently contains an object of type T) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_empty(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains any kind of object) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ trained_function_type train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels
+ ) const
+ /*!
+ requires
+ - is_empty() == false
+ ensures
+ - Let TRAINER denote the object contained within *this. Then
+ this function performs:
+ return TRAINER.train(samples, labels)
+ !*/
+
+ template <typename T>
+ T& cast_to(
+ );
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a non-const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ const T& cast_to(
+ ) const;
+ /*!
+ ensures
+ - if (contains<T>() == true) then
+ - returns a const reference to the object contained within *this
+ - else
+ - throws bad_any_cast
+ !*/
+
+ template <typename T>
+ T& get(
+ );
+ /*!
+ ensures
+ - #is_empty() == false
+ - #contains<T>() == true
+ - if (contains<T>() == true)
+ - returns a non-const reference to the object contained in *this.
+ - else
+ - Constructs an object of type T inside *this
+ - Any previous object stored in this any_trainer object is destructed and its
+ state is lost.
+ - returns a non-const reference to the newly created T object.
+ !*/
+
+ any_trainer& operator= (
+ const any_trainer& item
+ );
+ /*!
+ ensures
+ - copies the state of item into *this.
+ - Note that *this and item will contain independent copies of the
+ contents of item. That is, this function performs a deep
+ copy and therefore does not result in *this containing
+ any kind of reference to item.
+ !*/
+
+ void swap (
+ any_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename scalar_type
+ >
+ inline void swap (
+ any_trainer<sample_type,scalar_type>& a,
+ any_trainer<sample_type,scalar_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename sample_type,
+ typename scalar_type
+ >
+ T& any_cast(
+ any_trainer<sample_type,scalar_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename sample_type,
+ typename scalar_type
+ >
+ const T& any_cast(
+ const any_trainer<sample_type,scalar_type>& a
+ ) { return a.cast_to<T>(); }
+ /*!
+ ensures
+ - returns a.cast_to<T>()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AnY_TRAINER_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/appveyor/dtest.yml b/ml/dlib/dlib/appveyor/dtest.yml
new file mode 100644
index 000000000..212528eb1
--- /dev/null
+++ b/ml/dlib/dlib/appveyor/dtest.yml
@@ -0,0 +1,19 @@
+version: "{build}"
+
+configuration: Release
+
+build_script:
+ # build test
+ - mkdir %APPVEYOR_BUILD_FOLDER%\build_test
+ - cd %APPVEYOR_BUILD_FOLDER%\build_test
+ - cmake -G "Visual Studio 14 2015 Win64" -T host=x64 ../dlib/test
+ - cmake --build . --config %CONFIGURATION% --target dlib_all_source_cpp
+ - cmake --build . --config %CONFIGURATION% --target imglab
+ - cmake --build . --config %CONFIGURATION% --target htmlify
+ - cmake --build . --config %CONFIGURATION% --target gui
+ - cmake --build . --config %CONFIGURATION% --target dtest
+
+test_script:
+ # run test
+ - cd %APPVEYOR_BUILD_FOLDER%\build_test\%CONFIGURATION%
+ - dtest --runall
diff --git a/ml/dlib/dlib/appveyor/dtest_vc2017.yml b/ml/dlib/dlib/appveyor/dtest_vc2017.yml
new file mode 100644
index 000000000..9e5e72ee9
--- /dev/null
+++ b/ml/dlib/dlib/appveyor/dtest_vc2017.yml
@@ -0,0 +1,21 @@
+version: "{build}"
+
+configuration: Release
+
+image: Visual Studio 2017
+
+build_script:
+ # build test
+ - mkdir %APPVEYOR_BUILD_FOLDER%\build_test
+ - cd %APPVEYOR_BUILD_FOLDER%\build_test
+ - cmake -G "Visual Studio 15 2017 Win64" -T host=x64 ../dlib/test
+ - cmake --build . --config %CONFIGURATION% --target dlib_all_source_cpp
+ - cmake --build . --config %CONFIGURATION% --target imglab
+ - cmake --build . --config %CONFIGURATION% --target htmlify
+ - cmake --build . --config %CONFIGURATION% --target gui
+ - cmake --build . --config %CONFIGURATION% --target dtest
+
+test_script:
+ # run test
+ - cd %APPVEYOR_BUILD_FOLDER%\build_test\%CONFIGURATION%
+ - dtest --runall
diff --git a/ml/dlib/dlib/appveyor/examples.yml b/ml/dlib/dlib/appveyor/examples.yml
new file mode 100644
index 000000000..55791d224
--- /dev/null
+++ b/ml/dlib/dlib/appveyor/examples.yml
@@ -0,0 +1,16 @@
+version: "{build}"
+
+configuration: Release
+
+image: Visual Studio 2017
+
+build_script:
+ # build test
+ - mkdir %APPVEYOR_BUILD_FOLDER%\build_examples
+ - cd %APPVEYOR_BUILD_FOLDER%\build_examples
+ #- cmake -G "Visual Studio 14 2015 Win64" -T host=x64 ../examples
+ - cmake -G "Visual Studio 15 2017 Win64" -T host=x64 ../examples
+ - cmake --build . --config %CONFIGURATION%
+
+test_script:
+ # run test
diff --git a/ml/dlib/dlib/appveyor/python.yml b/ml/dlib/dlib/appveyor/python.yml
new file mode 100644
index 000000000..2e1466084
--- /dev/null
+++ b/ml/dlib/dlib/appveyor/python.yml
@@ -0,0 +1,33 @@
+
+environment:
+ matrix:
+ - PYTHON: "C:\\Python27"
+ PYTHON_VERSION: "2.7.x"
+ PYTHON_ARCH: "32"
+
+ - PYTHON: "C:\\Python27-x64"
+ PYTHON_VERSION: "2.7.x"
+ PYTHON_ARCH: "64"
+
+ - PYTHON: "C:\\Python33"
+ PYTHON_VERSION: "3.3.x"
+ PYTHON_ARCH: "32"
+
+ - PYTHON: "C:\\Python33-x64"
+ PYTHON_VERSION: "3.3.x"
+ PYTHON_ARCH: "64"
+
+ - PYTHON: "C:\\Python35"
+ PYTHON_VERSION: "3.5.x"
+ PYTHON_ARCH: "32"
+
+ - PYTHON: "C:\\Python35-x64"
+ PYTHON_VERSION: "3.5.x"
+ PYTHON_ARCH: "64"
+
+
+build_script:
+ - python setup.py build
+
+test_script:
+ - python setup.py test
diff --git a/ml/dlib/dlib/array.h b/ml/dlib/dlib/array.h
new file mode 100644
index 000000000..ecdafc497
--- /dev/null
+++ b/ml/dlib/dlib/array.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAy_
+#define DLIB_ARRAy_
+
+#include "array/array_kernel.h"
+#include "array/array_tools.h"
+
+#endif // DLIB_ARRAy_
+
diff --git a/ml/dlib/dlib/array/array_kernel.h b/ml/dlib/dlib/array/array_kernel.h
new file mode 100644
index 000000000..48160941b
--- /dev/null
+++ b/ml/dlib/dlib/array/array_kernel.h
@@ -0,0 +1,810 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY_KERNEl_2_
+#define DLIB_ARRAY_KERNEl_2_
+
+#include "array_kernel_abstract.h"
+#include "../interfaces/enumerable.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../sort.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class array : public enumerable<T>
+ {
+
+ /*!
+ INITIAL VALUE
+ - array_size == 0
+ - max_array_size == 0
+ - array_elements == 0
+ - pos == 0
+ - last_pos == 0
+ - _at_start == true
+
+ CONVENTION
+ - array_size == size()
+ - max_array_size == max_size()
+ - if (max_array_size > 0)
+ - array_elements == pointer to max_array_size elements of type T
+ - else
+ - array_elements == 0
+
+ - if (array_size > 0)
+ - last_pos == array_elements + array_size - 1
+ - else
+ - last_pos == 0
+
+
+ - at_start() == _at_start
+ - current_element_valid() == pos != 0
+ - if (current_element_valid()) then
+ - *pos == element()
+ !*/
+
+ public:
+
+ // These typedefs are here for backwards compatibility with old versions of dlib.
+ typedef array kernel_1a;
+ typedef array kernel_1a_c;
+ typedef array kernel_2a;
+ typedef array kernel_2a_c;
+ typedef array sort_1a;
+ typedef array sort_1a_c;
+ typedef array sort_1b;
+ typedef array sort_1b_c;
+ typedef array sort_2a;
+ typedef array sort_2a_c;
+ typedef array sort_2b;
+ typedef array sort_2b_c;
+ typedef array expand_1a;
+ typedef array expand_1a_c;
+ typedef array expand_1b;
+ typedef array expand_1b_c;
+ typedef array expand_1c;
+ typedef array expand_1c_c;
+ typedef array expand_1d;
+ typedef array expand_1d_c;
+
+
+
+
+ typedef T type;
+ typedef T value_type;
+ typedef mem_manager mem_manager_type;
+
+ array (
+ ) :
+ array_size(0),
+ max_array_size(0),
+ array_elements(0),
+ pos(0),
+ last_pos(0),
+ _at_start(true)
+ {}
+
+ array(
+ array&& item
+ ) : array()
+ {
+ swap(item);
+ }
+
+ array& operator=(
+ array&& item
+ )
+ {
+ swap(item);
+ return *this;
+ }
+
+ explicit array (
+ size_t new_size
+ ) :
+ array_size(0),
+ max_array_size(0),
+ array_elements(0),
+ pos(0),
+ last_pos(0),
+ _at_start(true)
+ {
+ resize(new_size);
+ }
+
+ ~array (
+ );
+
+ void clear (
+ );
+
+ inline const T& operator[] (
+ size_t pos
+ ) const;
+
+ inline T& operator[] (
+ size_t pos
+ );
+
+ void set_size (
+ size_t size
+ );
+
+ inline size_t max_size(
+ ) const;
+
+ void set_max_size(
+ size_t max
+ );
+
+ void swap (
+ array& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ void sort (
+ );
+
+ void resize (
+ size_t new_size
+ );
+
+ const T& back (
+ ) const;
+
+ T& back (
+ );
+
+ void pop_back (
+ );
+
+ void pop_back (
+ T& item
+ );
+
+ void push_back (
+ T& item
+ );
+
+ void push_back (
+ T&& item
+ );
+
+ typedef T* iterator;
+ typedef const T* const_iterator;
+ iterator begin() { return array_elements; }
+ const_iterator begin() const { return array_elements; }
+ iterator end() { return array_elements+array_size; }
+ const_iterator end() const { return array_elements+array_size; }
+
+ private:
+
+ typename mem_manager::template rebind<T>::other pool;
+
+ // data members
+ size_t array_size;
+ size_t max_array_size;
+ T* array_elements;
+
+ mutable T* pos;
+ T* last_pos;
+ mutable bool _at_start;
+
+ // restricted functions
+ array(array<T>&); // copy constructor
+ array<T>& operator=(array<T>&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ array<T,mem_manager>& a,
+ array<T,mem_manager>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void serialize (
+ const array<T,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.max_size(),out);
+ serialize(item.size(),out);
+
+ for (size_t i = 0; i < item.size(); ++i)
+ serialize(item[i],out);
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array");
+ }
+ }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ array<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ size_t max_size, size;
+ deserialize(max_size,in);
+ deserialize(size,in);
+ item.set_max_size(max_size);
+ item.set_size(size);
+ for (size_t i = 0; i < size; ++i)
+ deserialize(item[i],in);
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ array<T,mem_manager>::
+ ~array (
+ )
+ {
+ if (array_elements)
+ {
+ pool.deallocate_array(array_elements);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ clear (
+ )
+ {
+ reset();
+ last_pos = 0;
+ array_size = 0;
+ if (array_elements)
+ {
+ pool.deallocate_array(array_elements);
+ }
+ array_elements = 0;
+ max_array_size = 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& array<T,mem_manager>::
+ operator[] (
+ size_t pos
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( pos < this->size() ,
+ "\tconst T& array::operator[]"
+ << "\n\tpos must < size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ return array_elements[pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& array<T,mem_manager>::
+ operator[] (
+ size_t pos
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( pos < this->size() ,
+ "\tT& array::operator[]"
+ << "\n\tpos must be < size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ return array_elements[pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ set_size (
+ size_t size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( size <= this->max_size() ),
+ "\tvoid array::set_size"
+ << "\n\tsize must be <= max_size()"
+ << "\n\tsize: " << size
+ << "\n\tmax size: " << this->max_size()
+ << "\n\tthis: " << this
+ );
+
+ reset();
+ array_size = size;
+ if (size > 0)
+ last_pos = array_elements + size - 1;
+ else
+ last_pos = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t array<T,mem_manager>::
+ size (
+ ) const
+ {
+ return array_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ set_max_size(
+ size_t max
+ )
+ {
+ reset();
+ array_size = 0;
+ last_pos = 0;
+ if (max != 0)
+ {
+ // if new max size is different
+ if (max != max_array_size)
+ {
+ if (array_elements)
+ {
+ pool.deallocate_array(array_elements);
+ }
+ // try to get more memroy
+ try { array_elements = pool.allocate_array(max); }
+ catch (...) { array_elements = 0; max_array_size = 0; throw; }
+ max_array_size = max;
+ }
+
+ }
+ // if the array is being made to be zero
+ else
+ {
+ if (array_elements)
+ pool.deallocate_array(array_elements);
+ max_array_size = 0;
+ array_elements = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t array<T,mem_manager>::
+ max_size (
+ ) const
+ {
+ return max_array_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ swap (
+ array<T,mem_manager>& item
+ )
+ {
+ auto array_size_temp = item.array_size;
+ auto max_array_size_temp = item.max_array_size;
+ T* array_elements_temp = item.array_elements;
+
+ item.array_size = array_size;
+ item.max_array_size = max_array_size;
+ item.array_elements = array_elements;
+
+ array_size = array_size_temp;
+ max_array_size = max_array_size_temp;
+ array_elements = array_elements_temp;
+
+ exchange(_at_start,item._at_start);
+ exchange(pos,item.pos);
+ exchange(last_pos,item.last_pos);
+ pool.swap(item.pool);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool array<T,mem_manager>::
+ at_start (
+ ) const
+ {
+ return _at_start;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ reset (
+ ) const
+ {
+ _at_start = true;
+ pos = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool array<T,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return pos != 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& array<T,mem_manager>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(this->current_element_valid(),
+ "\tconst T& array::element()"
+ << "\n\tThe current element must be valid if you are to access it."
+ << "\n\tthis: " << this
+ );
+
+ return *pos;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& array<T,mem_manager>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(this->current_element_valid(),
+ "\tT& array::element()"
+ << "\n\tThe current element must be valid if you are to access it."
+ << "\n\tthis: " << this
+ );
+
+ return *pos;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool array<T,mem_manager>::
+ move_next (
+ ) const
+ {
+ if (!_at_start)
+ {
+ if (pos < last_pos)
+ {
+ ++pos;
+ return true;
+ }
+ else
+ {
+ pos = 0;
+ return false;
+ }
+ }
+ else
+ {
+ _at_start = false;
+ if (array_size > 0)
+ {
+ pos = array_elements;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Yet more functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ sort (
+ )
+ {
+ if (this->size() > 1)
+ {
+ // call the quick sort function for arrays that is in algs.h
+ dlib::qsort_array(*this,0,this->size()-1);
+ }
+ this->reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ resize (
+ size_t new_size
+ )
+ {
+ if (this->max_size() < new_size)
+ {
+ array temp;
+ temp.set_max_size(new_size);
+ temp.set_size(new_size);
+ for (size_t i = 0; i < this->size(); ++i)
+ {
+ exchange((*this)[i],temp[i]);
+ }
+ temp.swap(*this);
+ }
+ else
+ {
+ this->set_size(new_size);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& array<T,mem_manager>::
+ back (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( this->size() > 0 ,
+ "\tT& array::back()"
+ << "\n\tsize() must be bigger than 0"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ return (*this)[this->size()-1];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& array<T,mem_manager>::
+ back (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( this->size() > 0 ,
+ "\tconst T& array::back()"
+ << "\n\tsize() must be bigger than 0"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ return (*this)[this->size()-1];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ pop_back (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( this->size() > 0 ,
+ "\tvoid array::pop_back()"
+ << "\n\tsize() must be bigger than 0"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ exchange(item,(*this)[this->size()-1]);
+ this->set_size(this->size()-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ pop_back (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( this->size() > 0 ,
+ "\tvoid array::pop_back()"
+ << "\n\tsize() must be bigger than 0"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ this->set_size(this->size()-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ push_back (
+ T& item
+ )
+ {
+ if (this->max_size() == this->size())
+ {
+ // double the size of the array
+ array temp;
+ temp.set_max_size(this->size()*2 + 1);
+ temp.set_size(this->size()+1);
+ for (size_t i = 0; i < this->size(); ++i)
+ {
+ exchange((*this)[i],temp[i]);
+ }
+ exchange(item,temp[temp.size()-1]);
+ temp.swap(*this);
+ }
+ else
+ {
+ this->set_size(this->size()+1);
+ exchange(item,(*this)[this->size()-1]);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array<T,mem_manager>::
+ push_back (
+ T&& item
+ ) { push_back(item); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename MM>
+ struct is_array <array<T,MM> >
+ {
+ const static bool value = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ARRAY_KERNEl_2_
+
diff --git a/ml/dlib/dlib/array/array_kernel_abstract.h b/ml/dlib/dlib/array/array_kernel_abstract.h
new file mode 100644
index 000000000..5cfdd483a
--- /dev/null
+++ b/ml/dlib/dlib/array/array_kernel_abstract.h
@@ -0,0 +1,360 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ARRAY_KERNEl_ABSTRACT_
+#ifdef DLIB_ARRAY_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../serialize.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class array : public enumerable<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor.
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ front(), back(), swap(), max_size(), set_size(), and operator[]
+ functions do not invalidate pointers or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+ max_size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements of the array in the
+ order (*this)[0], (*this)[1], (*this)[2], ...
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an ordered 1-dimensional array of items,
+ each item is associated with an integer value. The items are
+ numbered from 0 though size() - 1 and the operator[] functions
+ run in constant time.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef T value_type;
+ typedef mem_manager mem_manager_type;
+
+ array (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ explicit array (
+ size_t new_size
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #size() == new_size
+ - #max_size() == new_size
+ - All elements of the array will have initial values for their type.
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ ~array (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ array(
+ array&& item
+ );
+ /*!
+ ensures
+ - move constructs *this from item. Therefore, the state of item is
+ moved into *this and #item has a valid but unspecified state.
+ !*/
+
+ array& operator=(
+ array&& item
+ );
+ /*!
+ ensures
+ - move assigns *this from item. Therefore, the state of item is
+ moved into *this and #item has a valid but unspecified state.
+ - returns a reference to #*this
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then the array object is unusable
+ until clear() is called and succeeds
+ !*/
+
+ const T& operator[] (
+ size_t pos
+ ) const;
+ /*!
+ requires
+ - pos < size()
+ ensures
+ - returns a const reference to the element at position pos
+ !*/
+
+ T& operator[] (
+ size_t pos
+ );
+ /*!
+ requires
+ - pos < size()
+ ensures
+ - returns a non-const reference to the element at position pos
+ !*/
+
+ void set_size (
+ size_t size
+ );
+ /*!
+ requires
+ - size <= max_size()
+ ensures
+ - #size() == size
+ - any element with index between 0 and size - 1 which was in the
+ array before the call to set_size() retains its value and index.
+ All other elements have undetermined (but valid for their type)
+ values. (e.g. this object might buffer old T objects and reuse
+ them without reinitializing them between calls to set_size())
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ may throw this exception if there is not enough memory and
+ if it does throw then the call to set_size() has no effect
+ !*/
+
+ size_t max_size(
+ ) const;
+ /*!
+ ensures
+ - returns the maximum size of *this
+ !*/
+
+ void set_max_size(
+ size_t max
+ );
+ /*!
+ ensures
+ - #max_size() == max
+ - #size() == 0
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ may throw this exception if there is not enough
+ memory and if it does throw then max_size() == 0
+ !*/
+
+ void swap (
+ array<T>& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ void sort (
+ );
+ /*!
+ requires
+ - T must be a type with that is comparable via operator<
+ ensures
+ - for all elements in #*this the ith element is <= the i+1 element
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ data may be lost if sort() throws
+ !*/
+
+ void resize (
+ size_t new_size
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ - #max_size() == max(new_size,max_size())
+ - for all i < size() && i < new_size:
+ - #(*this)[i] == (*this)[i]
+ (i.e. All the original elements of *this which were at index
+ values less than new_size are unmodified.)
+ - for all valid i >= size():
+ - #(*this)[i] has an undefined value
+ (i.e. any new elements of the array have an undefined value)
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If an exception is thrown then it has no effect on *this.
+ !*/
+
+
+ const T& back (
+ ) const;
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a const reference to (*this)[size()-1]
+ !*/
+
+ T& back (
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a non-const reference to (*this)[size()-1]
+ !*/
+
+ void pop_back (
+ T& item
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - swaps (*this)[size()-1] into item
+ - All elements with an index less than size()-1 are
+ unmodified by this operation.
+ !*/
+
+ void pop_back (
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - All elements with an index less than size()-1 are
+ unmodified by this operation.
+ !*/
+
+ void push_back (
+ T& item
+ );
+ /*!
+ ensures
+ - #size() == size()+1
+ - swaps item into (*this)[#size()-1]
+ - #back() == item
+ - #item has some undefined value (whatever happens to
+ get swapped out of the array)
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If an exception is thrown then it has no effect on *this.
+ !*/
+
+ void push_back (T&& item) { push_back(item); }
+ /*!
+ enable push_back from rvalues
+ !*/
+
+ typedef T* iterator;
+ typedef const T* const_iterator;
+
+ iterator begin(
+ );
+ /*!
+ ensures
+ - returns an iterator that points to the first element in this array or
+ end() if the array is empty.
+ !*/
+
+ const_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - returns a const iterator that points to the first element in this
+ array or end() if the array is empty.
+ !*/
+
+ iterator end(
+ );
+ /*!
+ ensures
+ - returns an iterator that points to one past the end of the array.
+ !*/
+
+ const_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns a const iterator that points to one past the end of the
+ array.
+ !*/
+
+ private:
+
+ // restricted functions
+ array(array<T>&); // copy constructor
+ array<T>& operator=(array<T>&); // assignment operator
+
+ };
+
+ template <
+ typename T
+ >
+ inline void swap (
+ array<T>& a,
+ array<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T
+ >
+ void serialize (
+ const array<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ array<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_ARRAY_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/array/array_tools.h b/ml/dlib/dlib/array/array_tools.h
new file mode 100644
index 000000000..fce634396
--- /dev/null
+++ b/ml/dlib/dlib/array/array_tools.h
@@ -0,0 +1,38 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY_tOOLS_H_
+#define DLIB_ARRAY_tOOLS_H_
+
+#include "../assert.h"
+#include "array_tools_abstract.h"
+
+namespace dlib
+{
+ template <typename T>
+ void split_array (
+ T& a,
+ T& b,
+ double frac
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= frac && frac <= 1,
+ "\t void split_array()"
+ << "\n\t frac must be between 0 and 1."
+ << "\n\t frac: " << frac
+ );
+
+ const unsigned long asize = static_cast<unsigned long>(a.size()*frac);
+ const unsigned long bsize = a.size()-asize;
+
+ b.resize(bsize);
+ for (unsigned long i = 0; i < b.size(); ++i)
+ {
+ swap(b[i], a[i+asize]);
+ }
+ a.resize(asize);
+ }
+}
+
+#endif // DLIB_ARRAY_tOOLS_H_
+
diff --git a/ml/dlib/dlib/array/array_tools_abstract.h b/ml/dlib/dlib/array/array_tools_abstract.h
new file mode 100644
index 000000000..e9b957518
--- /dev/null
+++ b/ml/dlib/dlib/array/array_tools_abstract.h
@@ -0,0 +1,33 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ARRAY_tOOLS_ABSTRACT_H_
+#ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_
+
+#include "array_kernel_abstract.h"
+
+namespace dlib
+{
+ template <typename T>
+ void split_array (
+ T& a,
+ T& b,
+ double frac
+ );
+ /*!
+ requires
+ - 0 <= frac <= 1
+ - T must be an array type such as dlib::array or std::vector
+ ensures
+ - This function takes the elements of a and splits them into two groups. The
+ first group remains in a and the second group is put into b. The ordering of
+ elements in a is preserved. In particular, concatenating #a with #b will
+ reproduce the original contents of a.
+ - The elements in a are moved around using global swap(). So they must be
+ swappable, but do not need to be copyable.
+ - #a.size() == floor(a.size()*frac)
+ - #b.size() == a.size()-#a.size()
+ !*/
+}
+
+#endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/array2d.h b/ml/dlib/dlib/array2d.h
new file mode 100644
index 000000000..f5325e4a2
--- /dev/null
+++ b/ml/dlib/dlib/array2d.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY2d_
+#define DLIB_ARRAY2d_
+
+
+#include "array2d/array2d_kernel.h"
+#include "array2d/serialize_pixel_overloads.h"
+#include "array2d/array2d_generic_image.h"
+
+#endif // DLIB_ARRAY2d_
+
diff --git a/ml/dlib/dlib/array2d/array2d_generic_image.h b/ml/dlib/dlib/array2d/array2d_generic_image.h
new file mode 100644
index 000000000..a96f5e3c2
--- /dev/null
+++ b/ml/dlib/dlib/array2d/array2d_generic_image.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
+#define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
+
+#include "array2d_kernel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ template <typename T, typename mm>
+ struct image_traits<array2d<T,mm> >
+ {
+ typedef T pixel_type;
+ };
+ template <typename T, typename mm>
+ struct image_traits<const array2d<T,mm> >
+ {
+ typedef T pixel_type;
+ };
+
+ template <typename T, typename mm>
+ inline long num_rows( const array2d<T,mm>& img) { return img.nr(); }
+ template <typename T, typename mm>
+ inline long num_columns( const array2d<T,mm>& img) { return img.nc(); }
+
+ template <typename T, typename mm>
+ inline void set_image_size(
+ array2d<T,mm>& img,
+ long rows,
+ long cols
+ ) { img.set_size(rows,cols); }
+
+ template <typename T, typename mm>
+ inline void* image_data(
+ array2d<T,mm>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img[0][0];
+ else
+ return 0;
+ }
+
+ template <typename T, typename mm>
+ inline const void* image_data(
+ const array2d<T,mm>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img[0][0];
+ else
+ return 0;
+ }
+
+ template <typename T, typename mm>
+ inline long width_step(
+ const array2d<T,mm>& img
+ )
+ {
+ return img.width_step();
+ }
+
+}
+
+#endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
+
diff --git a/ml/dlib/dlib/array2d/array2d_kernel.h b/ml/dlib/dlib/array2d/array2d_kernel.h
new file mode 100644
index 000000000..597112341
--- /dev/null
+++ b/ml/dlib/dlib/array2d/array2d_kernel.h
@@ -0,0 +1,498 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY2D_KERNEl_1_
+#define DLIB_ARRAY2D_KERNEl_1_
+
+#include "array2d_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../serialize.h"
+#include "../geometry/rectangle.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class array2d : public enumerable<T>
+ {
+
+ /*!
+ INITIAL VALUE
+ - nc_ == 0
+ - nr_ == 0
+ - data == 0
+ - at_start_ == true
+ - cur == 0
+ - last == 0
+
+ CONVENTION
+ - nc_ == nc()
+ - nr_ == nc()
+ - if (data != 0) then
+ - last == a pointer to the last element in the data array
+ - data == pointer to an array of nc_*nr_ T objects
+ - else
+ - nc_ == 0
+ - nr_ == 0
+ - data == 0
+ - last == 0
+
+
+ - nr_ * nc_ == size()
+ - if (cur == 0) then
+ - current_element_valid() == false
+ - else
+ - current_element_valid() == true
+ - *cur == element()
+
+ - at_start_ == at_start()
+ !*/
+
+
+ class row_helper;
+ public:
+
+ // These typedefs are here for backwards compatibility with older versions of dlib.
+ typedef array2d kernel_1a;
+ typedef array2d kernel_1a_c;
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ // -----------------------------------
+
+ class row
+ {
+ /*!
+ CONVENTION
+ - nc_ == nc()
+ - for all x < nc_:
+ - (*this)[x] == data[x]
+ !*/
+
+ friend class array2d<T,mem_manager>;
+ friend class row_helper;
+
+ public:
+ long nc (
+ ) const { return nc_; }
+
+ const T& operator[] (
+ long column
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(column < nc() && column >= 0,
+ "\tconst T& array2d::operator[](long column) const"
+ << "\n\tThe column index given must be less than the number of columns."
+ << "\n\tthis: " << this
+ << "\n\tcolumn: " << column
+ << "\n\tnc(): " << nc()
+ );
+
+ return data[column];
+ }
+
+ T& operator[] (
+ long column
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(column < nc() && column >= 0,
+ "\tT& array2d::operator[](long column)"
+ << "\n\tThe column index given must be less than the number of columns."
+ << "\n\tthis: " << this
+ << "\n\tcolumn: " << column
+ << "\n\tnc(): " << nc()
+ );
+
+ return data[column];
+ }
+
+ private:
+
+ row(T* data_, long cols) : data(data_), nc_(cols) {}
+
+ T* data;
+ long nc_;
+
+
+ // restricted functions
+ row(){}
+ row& operator=(row&);
+ };
+
+ // -----------------------------------
+
+ array2d (
+ ) :
+ data(0),
+ nc_(0),
+ nr_(0),
+ cur(0),
+ last(0),
+ at_start_(true)
+ {
+ }
+
+ array2d(
+ long rows,
+ long cols
+ ) :
+ data(0),
+ nc_(0),
+ nr_(0),
+ cur(0),
+ last(0),
+ at_start_(true)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT((cols >= 0 && rows >= 0),
+ "\t array2d::array2d(long rows, long cols)"
+ << "\n\t The array2d can't have negative rows or columns."
+ << "\n\t this: " << this
+ << "\n\t cols: " << cols
+ << "\n\t rows: " << rows
+ );
+
+ set_size(rows,cols);
+ }
+
+ array2d(const array2d&) = delete; // copy constructor
+ array2d& operator=(const array2d&) = delete; // assignment operator
+
+#ifdef DLIB_HAS_RVALUE_REFERENCES
+ array2d(array2d&& item) : array2d()
+ {
+ swap(item);
+ }
+
+ array2d& operator= (
+ array2d&& rhs
+ )
+ {
+ swap(rhs);
+ return *this;
+ }
+#endif
+
+ virtual ~array2d (
+ ) { clear(); }
+
+ long nc (
+ ) const { return nc_; }
+
+ long nr (
+ ) const { return nr_; }
+
+ row operator[] (
+ long row_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(row_ < nr() && row_ >= 0,
+ "\trow array2d::operator[](long row_)"
+ << "\n\tThe row index given must be less than the number of rows."
+ << "\n\tthis: " << this
+ << "\n\trow_: " << row_
+ << "\n\tnr(): " << nr()
+ );
+
+ return row(data+row_*nc_, nc_);
+ }
+
+ const row operator[] (
+ long row_
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(row_ < nr() && row_ >= 0,
+ "\tconst row array2d::operator[](long row_) const"
+ << "\n\tThe row index given must be less than the number of rows."
+ << "\n\tthis: " << this
+ << "\n\trow_: " << row_
+ << "\n\tnr(): " << nr()
+ );
+
+ return row(data+row_*nc_, nc_);
+ }
+
+ void swap (
+ array2d& item
+ )
+ {
+ exchange(data,item.data);
+ exchange(nr_,item.nr_);
+ exchange(nc_,item.nc_);
+ exchange(at_start_,item.at_start_);
+ exchange(cur,item.cur);
+ exchange(last,item.last);
+ pool.swap(item.pool);
+ }
+
+ void clear (
+ )
+ {
+ if (data != 0)
+ {
+ pool.deallocate_array(data);
+ nc_ = 0;
+ nr_ = 0;
+ data = 0;
+ at_start_ = true;
+ cur = 0;
+ last = 0;
+ }
+ }
+
+ void set_size (
+ long rows,
+ long cols
+ );
+
+ bool at_start (
+ ) const { return at_start_; }
+
+ void reset (
+ ) const { at_start_ = true; cur = 0; }
+
+ bool current_element_valid (
+ ) const { return (cur != 0); }
+
+ const T& element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tconst T& array2d::element()()"
+ << "\n\tYou can only call element() when you are at a valid one."
+ << "\n\tthis: " << this
+ );
+
+ return *cur;
+ }
+
+ T& element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tT& array2d::element()()"
+ << "\n\tYou can only call element() when you are at a valid one."
+ << "\n\tthis: " << this
+ );
+
+ return *cur;
+ }
+
+ bool move_next (
+ ) const
+ {
+ if (cur != 0)
+ {
+ if (cur != last)
+ {
+ ++cur;
+ return true;
+ }
+ cur = 0;
+ return false;
+ }
+ else if (at_start_)
+ {
+ cur = data;
+ at_start_ = false;
+ return (data != 0);
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ size_t size (
+ ) const { return static_cast<size_t>(nc_) * static_cast<size_t>(nr_); }
+
+ long width_step (
+ ) const
+ {
+ return nc_*sizeof(T);
+ }
+
+ private:
+
+
+ T* data;
+ long nc_;
+ long nr_;
+
+ typename mem_manager::template rebind<T>::other pool;
+ mutable T* cur;
+ T* last;
+ mutable bool at_start_;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ array2d<T,mem_manager>& a,
+ array2d<T,mem_manager>& b
+ ) { a.swap(b); }
+
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<T,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ item.reset();
+ while (item.move_next())
+ serialize(item.element(),out);
+ item.reset();
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+ item.set_size(nr,nc);
+
+ while (item.move_next())
+ deserialize(item.element(),in);
+ item.reset();
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void array2d<T,mem_manager>::
+ set_size (
+ long rows,
+ long cols
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT((cols >= 0 && rows >= 0) ,
+ "\tvoid array2d::set_size(long rows, long cols)"
+ << "\n\tThe array2d can't have negative rows or columns."
+ << "\n\tthis: " << this
+ << "\n\tcols: " << cols
+ << "\n\trows: " << rows
+ );
+
+ // set the enumerator back at the start
+ at_start_ = true;
+ cur = 0;
+
+ // don't do anything if we are already the right size.
+ if (nc_ == cols && nr_ == rows)
+ {
+ return;
+ }
+
+ nc_ = cols;
+ nr_ = rows;
+
+ // free any existing memory
+ if (data != 0)
+ {
+ pool.deallocate_array(data);
+ data = 0;
+ }
+
+ // now setup this object to have the new size
+ try
+ {
+ if (nr_ > 0)
+ {
+ data = pool.allocate_array(nr_*nc_);
+ last = data + nr_*nc_ - 1;
+ }
+ }
+ catch (...)
+ {
+ if (data)
+ pool.deallocate_array(data);
+
+ data = 0;
+ nc_ = 0;
+ nr_ = 0;
+ last = 0;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename MM>
+ struct is_array2d <array2d<T,MM> >
+ {
+ const static bool value = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ARRAY2D_KERNEl_1_
+
diff --git a/ml/dlib/dlib/array2d/array2d_kernel_abstract.h b/ml/dlib/dlib/array2d/array2d_kernel_abstract.h
new file mode 100644
index 000000000..cbb0e9b2b
--- /dev/null
+++ b/ml/dlib/dlib/array2d/array2d_kernel_abstract.h
@@ -0,0 +1,301 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ARRAY2D_KERNEl_ABSTRACT_
+#ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include "../geometry/rectangle_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class array2d : public enumerable<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor.
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ No member functions in this object will invalidate pointers
+ or references to internal data except for the set_size()
+ and clear() member functions.
+
+ INITIAL VALUE
+ nr() == 0
+ nc() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements of the array starting
+ with row 0 and then proceeding to row 1 and so on. Each row will be
+ fully enumerated before proceeding on to the next row and the elements
+ in a row will be enumerated beginning with the 0th column, then the 1st
+ column and so on.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a 2-Dimensional array of objects of
+ type T.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+
+ Finally, note that this object stores its data contiguously and in
+ row major order. Moreover, there is no padding at the end of each row.
+ This means that its width_step() value is always equal to sizeof(type)*nc().
+ !*/
+
+
+ public:
+
+ // ----------------------------------------
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ // ----------------------------------------
+
+ class row
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ No member functions in this object will invalidate pointers
+ or references to internal data.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a row of Ts in an array2d object.
+ !*/
+ public:
+ long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this row
+ !*/
+
+ const T& operator[] (
+ long column
+ ) const;
+ /*!
+ requires
+ - 0 <= column < nc()
+ ensures
+ - returns a const reference to the T in the given column
+ !*/
+
+ T& operator[] (
+ long column
+ );
+ /*!
+ requires
+ - 0 <= column < nc()
+ ensures
+ - returns a non-const reference to the T in the given column
+ !*/
+
+ private:
+ // restricted functions
+ row();
+ row& operator=(row&);
+ };
+
+ // ----------------------------------------
+
+ array2d (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ array2d(const array2d&) = delete; // copy constructor
+ array2d& operator=(const array2d&) = delete; // assignment operator
+
+ array2d(
+ array2d&& item
+ );
+ /*!
+ ensures
+ - Moves the state of item into *this.
+ - #item is in a valid but unspecified state.
+ !*/
+
+ array2d (
+ long rows,
+ long cols
+ );
+ /*!
+ requires
+ - rows >= 0 && cols >= 0
+ ensures
+ - #nc() == cols
+ - #nr() == rows
+ - #at_start() == true
+ - all elements in this array have initial values for their type
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~array2d (
+ );
+ /*!
+ ensures
+ - all resources associated with *this has been released
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has an initial value for its type
+ !*/
+
+ long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements there are in a row. i.e. returns
+ the number of columns in *this
+ !*/
+
+ long nr (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in *this
+ !*/
+
+ void set_size (
+ long rows,
+ long cols
+ );
+ /*!
+ requires
+ - rows >= 0 && cols >= 0
+ ensures
+ - #nc() == cols
+ - #nr() == rows
+ - #at_start() == true
+ - if (the call to set_size() doesn't change the dimensions of this array) then
+ - all elements in this array retain their values from before this function was called
+ - else
+ - all elements in this array have initial values for their type
+ throws
+ - std::bad_alloc
+ If this exception is thrown then #*this will have an initial
+ value for its type.
+ !*/
+
+ row operator[] (
+ long row_index
+ );
+ /*!
+ requires
+ - 0 <= row_index < nr()
+ ensures
+ - returns a non-const row of nc() elements that represents the
+ given row_index'th row in *this.
+ !*/
+
+ const row operator[] (
+ long row_index
+ ) const;
+ /*!
+ requires
+ - 0 <= row_index < nr()
+ ensures
+ - returns a const row of nc() elements that represents the
+ given row_index'th row in *this.
+ !*/
+
+ void swap (
+ array2d& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ array2d& operator= (
+ array2d&& rhs
+ );
+ /*!
+ ensures
+ - Moves the state of item into *this.
+ - #item is in a valid but unspecified state.
+ - returns #*this
+ !*/
+
+ long width_step (
+ ) const;
+ /*!
+ ensures
+ - returns the size of one row of the image, in bytes.
+ More precisely, return a number N such that:
+ (char*)&item[0][0] + N == (char*)&item[1][0].
+ - for dlib::array2d objects, the returned value
+ is always equal to sizeof(type)*nc(). However,
+ other objects which implement dlib::array2d style
+ interfaces might have padding at the ends of their
+ rows and therefore might return larger numbers.
+ An example of such an object is the dlib::cv_image.
+ !*/
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ array2d<T,mem_manager>& a,
+ array2d<T,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<T,mem_manager>& item,
+ std::ostream& out
+ );
+ /*!
+ Provides serialization support. Note that the serialization formats used by the
+ dlib::matrix and dlib::array2d objects are compatible. That means you can load the
+ serialized data from one into another and it will work properly.
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<T,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/array2d/serialize_pixel_overloads.h b/ml/dlib/dlib/array2d/serialize_pixel_overloads.h
new file mode 100644
index 000000000..9ce2c4a13
--- /dev/null
+++ b/ml/dlib/dlib/array2d/serialize_pixel_overloads.h
@@ -0,0 +1,371 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
+#define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
+
+#include "array2d_kernel.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ This file contains overloads of the serialize functions for array2d object
+ for the case where they contain simple 8bit POD pixel types. In these
+ cases we can perform a much faster serialization by writing data in chunks
+ instead of one pixel at a time (this avoids a lot of function call overhead
+ inside the iostreams).
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<rgb_pixel,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3);
+
+ if (item.size() != 0)
+ out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<rgb_pixel,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3);
+
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+ item.set_size(nr,nc);
+
+ if (item.size() != 0)
+ in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<bgr_pixel,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3);
+
+ if (item.size() != 0)
+ out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<bgr_pixel,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3);
+
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+
+ item.set_size(nr,nc);
+
+ if (item.size() != 0)
+ in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<hsi_pixel,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3);
+
+ if (item.size() != 0)
+ out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<hsi_pixel,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3);
+
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+
+ item.set_size(nr,nc);
+
+ if (item.size() != 0)
+ in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<rgb_alpha_pixel,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4);
+
+ if (item.size() != 0)
+ out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<rgb_alpha_pixel,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4);
+
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+
+ item.set_size(nr,nc);
+
+ if (item.size() != 0)
+ in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mem_manager
+ >
+ void serialize (
+ const array2d<unsigned char,mem_manager>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+
+ if (item.size() != 0)
+ out.write((char*)&item[0][0], sizeof(unsigned char)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type array2d");
+ }
+ }
+
+ template <
+ typename mem_manager
+ >
+ void deserialize (
+ array2d<unsigned char,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+ else
+ {
+ std::swap(nr,nc);
+ }
+
+
+ item.set_size(nr,nc);
+
+ if (item.size() != 0)
+ in.read((char*)&item[0][0], sizeof(unsigned char)*item.size());
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type array2d");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
+
diff --git a/ml/dlib/dlib/assert.h b/ml/dlib/dlib/assert.h
new file mode 100644
index 000000000..2220dd73a
--- /dev/null
+++ b/ml/dlib/dlib/assert.h
@@ -0,0 +1,216 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ASSERt_
+#define DLIB_ASSERt_
+
+#include "config.h"
+#include <sstream>
+#include <iosfwd>
+#include "error.h"
+
+// -----------------------------
+
+// Use some stuff from boost here
+// (C) Copyright John Maddock 2001 - 2003.
+// (C) Copyright Darin Adler 2001.
+// (C) Copyright Peter Dimov 2001.
+// (C) Copyright Bill Kempf 2002.
+// (C) Copyright Jens Maurer 2002.
+// (C) Copyright David Abrahams 2002 - 2003.
+// (C) Copyright Gennaro Prota 2003.
+// (C) Copyright Eric Friedman 2003.
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef BOOST_JOIN
+#define BOOST_JOIN( X, Y ) BOOST_DO_JOIN( X, Y )
+#define BOOST_DO_JOIN( X, Y ) BOOST_DO_JOIN2(X,Y)
+#define BOOST_DO_JOIN2( X, Y ) X##Y
+#endif
+
+// figure out if the compiler has rvalue references.
+#if defined(__clang__)
+# if __has_feature(cxx_rvalue_references)
+# define DLIB_HAS_RVALUE_REFERENCES
+# endif
+# if __has_feature(cxx_generalized_initializers)
+# define DLIB_HAS_INITIALIZER_LISTS
+# endif
+#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
+# define DLIB_HAS_RVALUE_REFERENCES
+# define DLIB_HAS_INITIALIZER_LISTS
+#elif defined(_MSC_VER) && _MSC_VER >= 1800
+# define DLIB_HAS_INITIALIZER_LISTS
+# define DLIB_HAS_RVALUE_REFERENCES
+#elif defined(_MSC_VER) && _MSC_VER >= 1600
+# define DLIB_HAS_RVALUE_REFERENCES
+#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
+# define DLIB_HAS_RVALUE_REFERENCES
+# define DLIB_HAS_INITIALIZER_LISTS
+#endif
+
+#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402)
+ // Apple has not updated libstdc++ in some time and anything under 4.02 does not have <initializer_list> for sure.
+# undef DLIB_HAS_INITIALIZER_LISTS
+#endif
+
+// figure out if the compiler has static_assert.
+#if defined(__clang__)
+# if __has_feature(cxx_static_assert)
+# define DLIB_HAS_STATIC_ASSERT
+# endif
+#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
+# define DLIB_HAS_STATIC_ASSERT
+#elif defined(_MSC_VER) && _MSC_VER >= 1600
+# define DLIB_HAS_STATIC_ASSERT
+#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
+# define DLIB_HAS_STATIC_ASSERT
+#endif
+
+
+// -----------------------------
+
+namespace dlib
+{
+ template <bool value> struct compile_time_assert;
+ template <> struct compile_time_assert<true> { enum {value=1}; };
+
+ template <typename T, typename U> struct assert_are_same_type;
+ template <typename T> struct assert_are_same_type<T,T> {enum{value=1};};
+ template <typename T, typename U> struct assert_are_not_same_type {enum{value=1}; };
+ template <typename T> struct assert_are_not_same_type<T,T> {};
+
+ template <typename T, typename U> struct assert_types_match {enum{value=0};};
+ template <typename T> struct assert_types_match<T,T> {enum{value=1};};
+}
+
+
+// gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile
+// time assert macros so we need to make it not complain about them "not being used".
+#ifdef __GNUC__
+#define DLIB_NO_WARN_UNUSED __attribute__ ((unused))
+#else
+#define DLIB_NO_WARN_UNUSED
+#endif
+
+// Use the newer static_assert if it's available since it produces much more readable error
+// messages.
+#ifdef DLIB_HAS_STATIC_ASSERT
+ #define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion")
+ #define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match<type1,type2>::value, "These types should be the same but aren't.")
+ #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match<type1,type2>::value, "These types should NOT be the same.")
+#else
+ #define COMPILE_TIME_ASSERT(expression) \
+ DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value]
+
+ #define ASSERT_ARE_SAME_TYPE(type1, type2) \
+ DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type<type1,type2>::value]
+
+ #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \
+ DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type<type1,type2>::value]
+#endif
+
+// -----------------------------
+
+#if defined DLIB_DISABLE_ASSERTS
+ // if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what.
+ #undef ENABLE_ASSERTS
+#endif
+
+#if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG)
+ // make sure ENABLE_ASSERTS is defined if we are indeed using them.
+ #ifndef ENABLE_ASSERTS
+ #define ENABLE_ASSERTS
+ #endif
+#endif
+
+// -----------------------------
+
+#ifdef __GNUC__
+// There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault
+// when __PRETTY_FUNCTION__ is used within certain templated functions. So just
+// don't use it with this version of GCC.
+# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5)
+# define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__
+# else
+# define DLIB_FUNCTION_NAME "unknown function"
+# endif
+#elif defined(_MSC_VER)
+#define DLIB_FUNCTION_NAME __FUNCSIG__
+#else
+#define DLIB_FUNCTION_NAME "unknown function"
+#endif
+
+#define DLIBM_CASSERT(_exp,_message) \
+ {if ( !(_exp) ) \
+ { \
+ dlib_assert_breakpoint(); \
+ std::ostringstream dlib_o_out; \
+ dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \
+ dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \
+ dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \
+ dlib_o_out << "Failing expression was " << #_exp << ".\n"; \
+ dlib_o_out << std::boolalpha << _message << "\n"; \
+ throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \
+ }}
+
+// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor.
+#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x
+// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like
+// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message).
+#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"")
+#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message)
+#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3
+#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS))
+#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__))
+
+
+#ifdef ENABLE_ASSERTS
+ #define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__)
+ #define DLIB_IF_ASSERT(exp) exp
+#else
+ #define DLIB_ASSERT(...) {}
+ #define DLIB_IF_ASSERT(exp)
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT
+
+ This macro is meant to cause a compiler error if a type doesn't have a simple
+ memory layout (like a C struct). In particular, types with simple layouts are
+ ones which can be copied via memcpy().
+
+
+ This was called a POD type in C++03 and in C++0x we are looking to check if
+ it is a "standard layout type". Once we can use C++0x we can change this macro
+ to something that uses the std::is_standard_layout type_traits class.
+ See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs
+ !*/
+ // Use the fact that in C++03 you can't put non-PODs into a union.
+#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \
+ union BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \
+ DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(BOOST_JOIN(DAHSL_,__LINE__))];
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// breakpoints
+extern "C"
+{
+ inline void dlib_assert_breakpoint(
+ ) {}
+ /*!
+ ensures
+ - this function does nothing
+ It exists just so you can put breakpoints on it in a debugging tool.
+ It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to
+ throw an exception.
+ !*/
+}
+
+// -----------------------------
+
+#include "stack_trace.h"
+
+#endif // DLIB_ASSERt_
+
diff --git a/ml/dlib/dlib/base64.h b/ml/dlib/dlib/base64.h
new file mode 100644
index 000000000..8308920d6
--- /dev/null
+++ b/ml/dlib/dlib/base64.h
@@ -0,0 +1,9 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BASe64_
+#define DLIB_BASe64_
+
+#include "base64/base64_kernel_1.h"
+
+#endif // DLIB_BASe64_
+
diff --git a/ml/dlib/dlib/base64/base64_kernel_1.cpp b/ml/dlib/dlib/base64/base64_kernel_1.cpp
new file mode 100644
index 000000000..5b48c789e
--- /dev/null
+++ b/ml/dlib/dlib/base64/base64_kernel_1.cpp
@@ -0,0 +1,403 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BASE64_KERNEL_1_CPp_
+#define DLIB_BASE64_KERNEL_1_CPp_
+
+#include "base64_kernel_1.h"
+#include <iostream>
+#include <sstream>
+#include <climits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ base64::line_ending_type base64::
+ line_ending (
+ ) const
+ {
+ return eol_style;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base64::
+ set_line_ending (
+ line_ending_type eol_style_
+ )
+ {
+ eol_style = eol_style_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base64::
+ base64 (
+ ) :
+ encode_table(0),
+ decode_table(0),
+ bad_value(100),
+ eol_style(LF)
+ {
+ try
+ {
+ encode_table = new char[64];
+ decode_table = new unsigned char[UCHAR_MAX];
+ }
+ catch (...)
+ {
+ if (encode_table) delete [] encode_table;
+ if (decode_table) delete [] decode_table;
+ throw;
+ }
+
+ // now set up the tables with the right stuff
+ encode_table[0] = 'A';
+ encode_table[17] = 'R';
+ encode_table[34] = 'i';
+ encode_table[51] = 'z';
+
+ encode_table[1] = 'B';
+ encode_table[18] = 'S';
+ encode_table[35] = 'j';
+ encode_table[52] = '0';
+
+ encode_table[2] = 'C';
+ encode_table[19] = 'T';
+ encode_table[36] = 'k';
+ encode_table[53] = '1';
+
+ encode_table[3] = 'D';
+ encode_table[20] = 'U';
+ encode_table[37] = 'l';
+ encode_table[54] = '2';
+
+ encode_table[4] = 'E';
+ encode_table[21] = 'V';
+ encode_table[38] = 'm';
+ encode_table[55] = '3';
+
+ encode_table[5] = 'F';
+ encode_table[22] = 'W';
+ encode_table[39] = 'n';
+ encode_table[56] = '4';
+
+ encode_table[6] = 'G';
+ encode_table[23] = 'X';
+ encode_table[40] = 'o';
+ encode_table[57] = '5';
+
+ encode_table[7] = 'H';
+ encode_table[24] = 'Y';
+ encode_table[41] = 'p';
+ encode_table[58] = '6';
+
+ encode_table[8] = 'I';
+ encode_table[25] = 'Z';
+ encode_table[42] = 'q';
+ encode_table[59] = '7';
+
+ encode_table[9] = 'J';
+ encode_table[26] = 'a';
+ encode_table[43] = 'r';
+ encode_table[60] = '8';
+
+ encode_table[10] = 'K';
+ encode_table[27] = 'b';
+ encode_table[44] = 's';
+ encode_table[61] = '9';
+
+ encode_table[11] = 'L';
+ encode_table[28] = 'c';
+ encode_table[45] = 't';
+ encode_table[62] = '+';
+
+ encode_table[12] = 'M';
+ encode_table[29] = 'd';
+ encode_table[46] = 'u';
+ encode_table[63] = '/';
+
+ encode_table[13] = 'N';
+ encode_table[30] = 'e';
+ encode_table[47] = 'v';
+
+ encode_table[14] = 'O';
+ encode_table[31] = 'f';
+ encode_table[48] = 'w';
+
+ encode_table[15] = 'P';
+ encode_table[32] = 'g';
+ encode_table[49] = 'x';
+
+ encode_table[16] = 'Q';
+ encode_table[33] = 'h';
+ encode_table[50] = 'y';
+
+
+
+ // we can now fill out the decode_table by using the encode_table
+ for (int i = 0; i < UCHAR_MAX; ++i)
+ {
+ decode_table[i] = bad_value;
+ }
+ for (unsigned char i = 0; i < 64; ++i)
+ {
+ decode_table[(unsigned char)encode_table[i]] = i;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base64::
+ ~base64 (
+ )
+ {
+ delete [] encode_table;
+ delete [] decode_table;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base64::
+ encode (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ using namespace std;
+ streambuf& in = *in_.rdbuf();
+ streambuf& out = *out_.rdbuf();
+
+ unsigned char inbuf[3];
+ unsigned char outbuf[4];
+ streamsize status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
+
+ unsigned char c1, c2, c3, c4, c5, c6;
+
+ int counter = 19;
+
+ // while we haven't hit the end of the input stream
+ while (status != 0)
+ {
+ if (counter == 0)
+ {
+ counter = 19;
+ // write a newline
+ char ch;
+ switch (eol_style)
+ {
+ case CR:
+ ch = '\r';
+ if (out.sputn(&ch,1)!=1)
+ throw std::ios_base::failure("error occurred in the base64 object");
+ break;
+ case LF:
+ ch = '\n';
+ if (out.sputn(&ch,1)!=1)
+ throw std::ios_base::failure("error occurred in the base64 object");
+ break;
+ case CRLF:
+ ch = '\r';
+ if (out.sputn(&ch,1)!=1)
+ throw std::ios_base::failure("error occurred in the base64 object");
+ ch = '\n';
+ if (out.sputn(&ch,1)!=1)
+ throw std::ios_base::failure("error occurred in the base64 object");
+ break;
+ default:
+ DLIB_CASSERT(false,"this should never happen");
+ }
+ }
+ --counter;
+
+ if (status == 3)
+ {
+ // encode the bytes in inbuf to base64 and write them to the output stream
+ c1 = inbuf[0]&0xfc;
+ c2 = inbuf[0]&0x03;
+ c3 = inbuf[1]&0xf0;
+ c4 = inbuf[1]&0x0f;
+ c5 = inbuf[2]&0xc0;
+ c6 = inbuf[2]&0x3f;
+
+ outbuf[0] = c1>>2;
+ outbuf[1] = (c2<<4)|(c3>>4);
+ outbuf[2] = (c4<<2)|(c5>>6);
+ outbuf[3] = c6;
+
+
+ outbuf[0] = encode_table[outbuf[0]];
+ outbuf[1] = encode_table[outbuf[1]];
+ outbuf[2] = encode_table[outbuf[2]];
+ outbuf[3] = encode_table[outbuf[3]];
+
+ // write the encoded bytes to the output stream
+ if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
+ {
+ throw std::ios_base::failure("error occurred in the base64 object");
+ }
+
+ // get 3 more input bytes
+ status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
+ continue;
+ }
+ else if (status == 2)
+ {
+ // we are at the end of the input stream and need to add some padding
+
+ // encode the bytes in inbuf to base64 and write them to the output stream
+ c1 = inbuf[0]&0xfc;
+ c2 = inbuf[0]&0x03;
+ c3 = inbuf[1]&0xf0;
+ c4 = inbuf[1]&0x0f;
+ c5 = 0;
+
+ outbuf[0] = c1>>2;
+ outbuf[1] = (c2<<4)|(c3>>4);
+ outbuf[2] = (c4<<2)|(c5>>6);
+ outbuf[3] = '=';
+
+ outbuf[0] = encode_table[outbuf[0]];
+ outbuf[1] = encode_table[outbuf[1]];
+ outbuf[2] = encode_table[outbuf[2]];
+
+ // write the encoded bytes to the output stream
+ if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
+ {
+ throw std::ios_base::failure("error occurred in the base64 object");
+ }
+
+
+ break;
+ }
+ else // in this case status must be 1
+ {
+ // we are at the end of the input stream and need to add some padding
+
+ // encode the bytes in inbuf to base64 and write them to the output stream
+ c1 = inbuf[0]&0xfc;
+ c2 = inbuf[0]&0x03;
+ c3 = 0;
+
+ outbuf[0] = c1>>2;
+ outbuf[1] = (c2<<4)|(c3>>4);
+ outbuf[2] = '=';
+ outbuf[3] = '=';
+
+ outbuf[0] = encode_table[outbuf[0]];
+ outbuf[1] = encode_table[outbuf[1]];
+
+
+ // write the encoded bytes to the output stream
+ if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
+ {
+ throw std::ios_base::failure("error occurred in the base64 object");
+ }
+
+ break;
+ }
+ } // while (status != 0)
+
+
+ // make sure the stream buffer flushes to its I/O channel
+ out.pubsync();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base64::
+ decode (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ using namespace std;
+ streambuf& in = *in_.rdbuf();
+ streambuf& out = *out_.rdbuf();
+
+ unsigned char inbuf[4];
+ unsigned char outbuf[3];
+ int inbuf_pos = 0;
+ streamsize status = in.sgetn(reinterpret_cast<char*>(inbuf),1);
+
+ // only count this character if it isn't some kind of filler
+ if (status == 1 && decode_table[inbuf[0]] != bad_value )
+ ++inbuf_pos;
+
+ unsigned char c1, c2, c3, c4, c5, c6;
+ streamsize outsize;
+
+ // while we haven't hit the end of the input stream
+ while (status != 0)
+ {
+ // if we have 4 valid characters
+ if (inbuf_pos == 4)
+ {
+ inbuf_pos = 0;
+
+ // this might be the end of the encoded data so we need to figure out if
+ // there was any padding applied.
+ outsize = 3;
+ if (inbuf[3] == '=')
+ {
+ if (inbuf[2] == '=')
+ outsize = 1;
+ else
+ outsize = 2;
+ }
+
+ // decode the incoming characters
+ inbuf[0] = decode_table[inbuf[0]];
+ inbuf[1] = decode_table[inbuf[1]];
+ inbuf[2] = decode_table[inbuf[2]];
+ inbuf[3] = decode_table[inbuf[3]];
+
+
+ // now pack these guys into bytes rather than 6 bit chunks
+ c1 = inbuf[0]<<2;
+ c2 = inbuf[1]>>4;
+ c3 = inbuf[1]<<4;
+ c4 = inbuf[2]>>2;
+ c5 = inbuf[2]<<6;
+ c6 = inbuf[3];
+
+ outbuf[0] = c1|c2;
+ outbuf[1] = c3|c4;
+ outbuf[2] = c5|c6;
+
+
+ // write the encoded bytes to the output stream
+ if (out.sputn(reinterpret_cast<char*>(&outbuf),outsize)!=outsize)
+ {
+ throw std::ios_base::failure("error occurred in the base64 object");
+ }
+ }
+
+ // get more input characters
+ status = in.sgetn(reinterpret_cast<char*>(inbuf + inbuf_pos),1);
+ // only count this character if it isn't some kind of filler
+ if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') &&
+ status != 0)
+ ++inbuf_pos;
+ } // while (status != 0)
+
+ if (inbuf_pos != 0)
+ {
+ ostringstream sout;
+ sout << inbuf_pos << " extra characters were found at the end of the encoded data."
+ << " This may indicate that the data stream has been truncated.";
+ // this happens if we hit EOF in the middle of decoding a 24bit block.
+ throw decode_error(sout.str());
+ }
+
+ // make sure the stream buffer flushes to its I/O channel
+ out.pubsync();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BASE64_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/base64/base64_kernel_1.h b/ml/dlib/dlib/base64/base64_kernel_1.h
new file mode 100644
index 000000000..d8f49b1b8
--- /dev/null
+++ b/ml/dlib/dlib/base64/base64_kernel_1.h
@@ -0,0 +1,92 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BASE64_KERNEl_1_
+#define DLIB_BASE64_KERNEl_1_
+
+#include "../algs.h"
+#include "base64_kernel_abstract.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class base64
+ {
+ /*!
+ INITIAL VALUE
+ - bad_value == 100
+ - encode_table == a pointer to an array of 64 chars
+ - where x is a 6 bit value the following is true:
+ - encode_table[x] == the base64 encoding of x
+ - decode_table == a pointer to an array of UCHAR_MAX chars
+ - where x is any char value:
+ - if (x is a valid character in the base64 coding scheme) then
+ - decode_table[x] == the 6 bit value that x encodes
+ - else
+ - decode_table[x] == bad_value
+
+ CONVENTION
+ - The state of this object never changes so just refer to its
+ initial value.
+
+
+ !*/
+
+ public:
+ // this is here for backwards compatibility with older versions of dlib.
+ typedef base64 kernel_1a;
+
+ class decode_error : public dlib::error { public:
+ decode_error( const std::string& e) : error(e) {}};
+
+ base64 (
+ );
+
+ virtual ~base64 (
+ );
+
+ enum line_ending_type
+ {
+ CR, // i.e. "\r"
+ LF, // i.e. "\n"
+ CRLF // i.e. "\r\n"
+ };
+
+ line_ending_type line_ending (
+ ) const;
+
+ void set_line_ending (
+ line_ending_type eol_style_
+ );
+
+ void encode (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ void decode (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ private:
+
+ char* encode_table;
+ unsigned char* decode_table;
+ const unsigned char bad_value;
+ line_ending_type eol_style;
+
+ // restricted functions
+ base64(base64&); // copy constructor
+ base64& operator=(base64&); // assignment operator
+
+ };
+
+}
+
+#ifdef NO_MAKEFILE
+#include "base64_kernel_1.cpp"
+#endif
+
+#endif // DLIB_BASE64_KERNEl_1_
+
diff --git a/ml/dlib/dlib/base64/base64_kernel_abstract.h b/ml/dlib/dlib/base64/base64_kernel_abstract.h
new file mode 100644
index 000000000..0a63d3b87
--- /dev/null
+++ b/ml/dlib/dlib/base64/base64_kernel_abstract.h
@@ -0,0 +1,121 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BASE64_KERNEl_ABSTRACT_
+#ifdef DLIB_BASE64_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class base64
+ {
+ /*!
+ INITIAL VALUE
+ - line_ending() == LF
+
+ WHAT THIS OBJECT REPRESENTS
+ This object consists of the two functions encode and decode.
+ These functions allow you to encode and decode data to and from
+ the Base64 Content-Transfer-Encoding defined in section 6.8 of
+ rfc2045.
+ !*/
+
+ public:
+
+ class decode_error : public dlib::error {};
+
+ base64 (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~base64 (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ enum line_ending_type
+ {
+ CR, // i.e. "\r"
+ LF, // i.e. "\n"
+ CRLF // i.e. "\r\n"
+ };
+
+ line_ending_type line_ending (
+ ) const;
+ /*!
+ ensures
+ - returns the type of end of line bytes the encoder
+ will use when encoding data to base64 blocks. Note that
+ the ostream object you use might apply some sort of transform
+ to line endings as well. For example, C++ ofstream objects
+ usually convert '\n' into whatever a normal newline is for
+ your platform unless you open a file in binary mode. But
+ aside from file streams the ostream objects usually don't
+ modify the data you pass to them.
+ !*/
+
+ void set_line_ending (
+ line_ending_type eol_style
+ );
+ /*!
+ ensures
+ - #line_ending() == eol_style
+ !*/
+
+ void encode (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - reads all data from in (until EOF is reached) and encodes it
+ and writes it to out
+ throws
+ - std::ios_base::failure
+ if there was a problem writing to out then this exception will
+ be thrown.
+ - any other exception
+ this exception may be thrown if there is any other problem
+ !*/
+
+ void decode (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - reads data from in (until EOF is reached), decodes it,
+ and writes it to out.
+ throws
+ - std::ios_base::failure
+ if there was a problem writing to out then this exception will
+ be thrown.
+ - decode_error
+ if an error was detected in the encoded data that prevented
+ it from being correctly decoded then this exception is
+ thrown.
+ - any other exception
+ this exception may be thrown if there is any other problem
+ !*/
+
+ private:
+
+ // restricted functions
+ base64(base64&); // copy constructor
+ base64& operator=(base64&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_BASE64_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/bayes_utils.h b/ml/dlib/dlib/bayes_utils.h
new file mode 100644
index 000000000..51ef6d2ed
--- /dev/null
+++ b/ml/dlib/dlib/bayes_utils.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BAYES_UTILs_H_
+#define DLIB_BAYES_UTILs_H_
+
+#include "bayes_utils/bayes_utils.h"
+
+#endif // DLIB_BAYES_UTILs_H_
+
+
+
diff --git a/ml/dlib/dlib/bayes_utils/bayes_utils.h b/ml/dlib/dlib/bayes_utils/bayes_utils.h
new file mode 100644
index 000000000..04b3d1187
--- /dev/null
+++ b/ml/dlib/dlib/bayes_utils/bayes_utils.h
@@ -0,0 +1,1678 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BAYES_UTILs_
+#define DLIB_BAYES_UTILs_
+
+#include "bayes_utils_abstract.h"
+
+#include <algorithm>
+#include <ctime>
+#include <memory>
+#include <vector>
+
+#include "../string.h"
+#include "../map.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include "../array.h"
+#include "../set.h"
+#include "../algs.h"
+#include "../noncopyable.h"
+#include "../graph.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class assignment
+ {
+ public:
+
+ assignment()
+ {
+ }
+
+ assignment(
+ const assignment& a
+ )
+ {
+ a.reset();
+ while (a.move_next())
+ {
+ unsigned long idx = a.element().key();
+ unsigned long value = a.element().value();
+ vals.add(idx,value);
+ }
+ }
+
+ assignment& operator = (
+ const assignment& rhs
+ )
+ {
+ if (this == &rhs)
+ return *this;
+
+ assignment(rhs).swap(*this);
+ return *this;
+ }
+
+ void clear()
+ {
+ vals.clear();
+ }
+
+ bool operator < (
+ const assignment& item
+ ) const
+ {
+ if (size() < item.size())
+ return true;
+ else if (size() > item.size())
+ return false;
+
+ reset();
+ item.reset();
+ while (move_next())
+ {
+ item.move_next();
+ if (element().key() < item.element().key())
+ return true;
+ else if (element().key() > item.element().key())
+ return false;
+ else if (element().value() < item.element().value())
+ return true;
+ else if (element().value() > item.element().value())
+ return false;
+ }
+
+ return false;
+ }
+
+ bool has_index (
+ unsigned long idx
+ ) const
+ {
+ return vals.is_in_domain(idx);
+ }
+
+ void add (
+ unsigned long idx,
+ unsigned long value = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( has_index(idx) == false ,
+ "\tvoid assignment::add(idx)"
+ << "\n\tYou can't add the same index to an assignment object more than once"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+
+ vals.add(idx, value);
+ }
+
+ unsigned long& operator[] (
+ const long idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( has_index(idx) == true ,
+ "\tunsigned long assignment::operator[](idx)"
+ << "\n\tYou can't access an index value if it isn't already in the object"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+
+ return vals[idx];
+ }
+
+ const unsigned long& operator[] (
+ const long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( has_index(idx) == true ,
+ "\tunsigned long assignment::operator[](idx)"
+ << "\n\tYou can't access an index value if it isn't already in the object"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+
+ return vals[idx];
+ }
+
+ void swap (
+ assignment& item
+ )
+ {
+ vals.swap(item.vals);
+ }
+
+ void remove (
+ unsigned long idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( has_index(idx) == true ,
+ "\tunsigned long assignment::remove(idx)"
+ << "\n\tYou can't remove an index value if it isn't already in the object"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+
+ vals.destroy(idx);
+ }
+
+ unsigned long size() const { return vals.size(); }
+
+ void reset() const { vals.reset(); }
+
+ bool move_next() const { return vals.move_next(); }
+
+ map_pair<unsigned long, unsigned long>& element()
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tmap_pair<unsigned long,unsigned long>& assignment::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+ return vals.element();
+ }
+
+ const map_pair<unsigned long, unsigned long>& element() const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tconst map_pair<unsigned long,unsigned long>& assignment::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return vals.element();
+ }
+
+ bool at_start() const { return vals.at_start(); }
+
+ bool current_element_valid() const { return vals.current_element_valid(); }
+
+ friend inline void serialize (
+ const assignment& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.vals, out);
+ }
+
+ friend inline void deserialize (
+ assignment& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.vals, in);
+ }
+
+ private:
+ mutable dlib::map<unsigned long, unsigned long>::kernel_1b_c vals;
+ };
+
+ inline std::ostream& operator << (
+ std::ostream& out,
+ const assignment& a
+ )
+ {
+ a.reset();
+ out << "(";
+ if (a.move_next())
+ out << a.element().key() << ":" << a.element().value();
+
+ while (a.move_next())
+ {
+ out << ", " << a.element().key() << ":" << a.element().value();
+ }
+
+ out << ")";
+ return out;
+ }
+
+
+ inline void swap (
+ assignment& a,
+ assignment& b
+ )
+ {
+ a.swap(b);
+ }
+
+
+// ------------------------------------------------------------------------
+
+ class joint_probability_table
+ {
+ /*!
+ INITIAL VALUE
+ - table.size() == 0
+
+ CONVENTION
+ - size() == table.size()
+ - probability(a) == table[a]
+ !*/
+ public:
+
+ joint_probability_table (
+ const joint_probability_table& t
+ )
+ {
+ t.reset();
+ while (t.move_next())
+ {
+ assignment a = t.element().key();
+ double p = t.element().value();
+ set_probability(a,p);
+ }
+ }
+
+ joint_probability_table() {}
+
+ joint_probability_table& operator= (
+ const joint_probability_table& rhs
+ )
+ {
+ if (this == &rhs)
+ return *this;
+ joint_probability_table(rhs).swap(*this);
+ return *this;
+ }
+
+ void set_probability (
+ const assignment& a,
+ double p
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0.0 <= p && p <= 1.0,
+ "\tvoid& joint_probability_table::set_probability(a,p)"
+ << "\n\tyou have given an invalid probability value"
+ << "\n\tp: " << p
+ << "\n\ta: " << a
+ << "\n\tthis: " << this
+ );
+
+ if (table.is_in_domain(a))
+ {
+ table[a] = p;
+ }
+ else
+ {
+ assignment temp(a);
+ table.add(temp,p);
+ }
+ }
+
+ bool has_entry_for (
+ const assignment& a
+ ) const
+ {
+ return table.is_in_domain(a);
+ }
+
+ void add_probability (
+ const assignment& a,
+ double p
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0.0 <= p && p <= 1.0,
+ "\tvoid& joint_probability_table::add_probability(a,p)"
+ << "\n\tyou have given an invalid probability value"
+ << "\n\tp: " << p
+ << "\n\ta: " << a
+ << "\n\tthis: " << this
+ );
+
+ if (table.is_in_domain(a))
+ {
+ table[a] += p;
+ if (table[a] > 1.0)
+ table[a] = 1.0;
+ }
+ else
+ {
+ assignment temp(a);
+ table.add(temp,p);
+ }
+ }
+
+ double probability (
+ const assignment& a
+ ) const
+ {
+ return table[a];
+ }
+
+ void clear()
+ {
+ table.clear();
+ }
+
+ size_t size () const { return table.size(); }
+ bool move_next() const { return table.move_next(); }
+ void reset() const { table.reset(); }
+ map_pair<assignment,double>& element()
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tmap_pair<assignment,double>& joint_probability_table::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return table.element();
+ }
+
+ const map_pair<assignment,double>& element() const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid() == true,
+ "\tconst map_pair<assignment,double>& joint_probability_table::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return table.element();
+ }
+
+ bool at_start() const { return table.at_start(); }
+
+ bool current_element_valid() const { return table.current_element_valid(); }
+
+
+ template <typename T>
+ void marginalize (
+ const T& vars,
+ joint_probability_table& out
+ ) const
+ {
+ out.clear();
+ double p;
+ reset();
+ while (move_next())
+ {
+ assignment a;
+ const assignment& asrc = element().key();
+ p = element().value();
+
+ asrc.reset();
+ while (asrc.move_next())
+ {
+ if (vars.is_member(asrc.element().key()))
+ a.add(asrc.element().key(), asrc.element().value());
+ }
+
+ out.add_probability(a,p);
+ }
+ }
+
+ void marginalize (
+ const unsigned long var,
+ joint_probability_table& out
+ ) const
+ {
+ out.clear();
+ double p;
+ reset();
+ while (move_next())
+ {
+ assignment a;
+ const assignment& asrc = element().key();
+ p = element().value();
+
+ asrc.reset();
+ while (asrc.move_next())
+ {
+ if (var == asrc.element().key())
+ a.add(asrc.element().key(), asrc.element().value());
+ }
+
+ out.add_probability(a,p);
+ }
+ }
+
+ void normalize (
+ )
+ {
+ double sum = 0;
+
+ reset();
+ while (move_next())
+ sum += element().value();
+
+ reset();
+ while (move_next())
+ element().value() /= sum;
+ }
+
+ void swap (
+ joint_probability_table& item
+ )
+ {
+ table.swap(item.table);
+ }
+
+ friend inline void serialize (
+ const joint_probability_table& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.table, out);
+ }
+
+ friend inline void deserialize (
+ joint_probability_table& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.table, in);
+ }
+
+ private:
+
+ dlib::map<assignment, double >::kernel_1b_c table;
+ };
+
+ inline void swap (
+ joint_probability_table& a,
+ joint_probability_table& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ class conditional_probability_table : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - table.size() == 0
+
+ CONVENTION
+ - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then
+ - has_entry_for(value,ps) == true
+ - probability(value,ps) == table[ps](value)
+ - else
+ - has_entry_for(value,ps) == false
+
+ - num_values() == num_vals
+ !*/
+ public:
+
+ conditional_probability_table()
+ {
+ clear();
+ }
+
+ void set_num_values (
+ unsigned long num
+ )
+ {
+ num_vals = num;
+ table.clear();
+ }
+
+ bool has_entry_for (
+ unsigned long value,
+ const assignment& ps
+ ) const
+ {
+ if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0)
+ return true;
+ else
+ return false;
+ }
+
+ unsigned long num_values (
+ ) const { return num_vals; }
+
+ void set_probability (
+ unsigned long value,
+ const assignment& ps,
+ double p
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 ,
+ "\tvoid conditional_probability_table::set_probability()"
+ << "\n\tinvalid arguments to set_probability"
+ << "\n\tvalue: " << value
+ << "\n\tnum_values(): " << num_values()
+ << "\n\tp: " << p
+ << "\n\tps: " << ps
+ << "\n\tthis: " << this
+ );
+
+ if (table.is_in_domain(ps))
+ {
+ table[ps](value) = p;
+ }
+ else
+ {
+ matrix<double,1> dist(num_vals);
+ set_all_elements(dist,-1);
+ dist(value) = p;
+ assignment temp(ps);
+ table.add(temp,dist);
+ }
+ }
+
+ double probability(
+ unsigned long value,
+ const assignment& ps
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) ,
+ "\tvoid conditional_probability_table::probability()"
+ << "\n\tinvalid arguments to probability"
+ << "\n\tvalue: " << value
+ << "\n\tnum_values(): " << num_values()
+ << "\n\tps: " << ps
+ << "\n\tthis: " << this
+ );
+
+ return table[ps](value);
+ }
+
+ void clear()
+ {
+ table.clear();
+ num_vals = 0;
+ }
+
+ void empty_table ()
+ {
+ table.clear();
+ }
+
+ void swap (
+ conditional_probability_table& item
+ )
+ {
+ exchange(num_vals, item.num_vals);
+ table.swap(item.table);
+ }
+
+ friend inline void serialize (
+ const conditional_probability_table& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.table, out);
+ serialize(item.num_vals, out);
+ }
+
+ friend inline void deserialize (
+ conditional_probability_table& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.table, in);
+ deserialize(item.num_vals, in);
+ }
+
+ private:
+ dlib::map<assignment, matrix<double,1> >::kernel_1b_c table;
+ unsigned long num_vals;
+ };
+
+ inline void swap (
+ conditional_probability_table& a,
+ conditional_probability_table& b
+ ) { a.swap(b); }
+
+// ------------------------------------------------------------------------
+
+ class bayes_node : noncopyable
+ {
+ public:
+ bayes_node ()
+ {
+ is_instantiated = false;
+ value_ = 0;
+ }
+
+ unsigned long value (
+ ) const { return value_;}
+
+ void set_value (
+ unsigned long new_value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( new_value < table().num_values(),
+ "\tvoid bayes_node::set_value(new_value)"
+ << "\n\tnew_value must be less than the number of possible values for this node"
+ << "\n\tnew_value: " << new_value
+ << "\n\ttable().num_values(): " << table().num_values()
+ << "\n\tthis: " << this
+ );
+
+ value_ = new_value;
+ }
+
+ conditional_probability_table& table (
+ ) { return table_; }
+
+ const conditional_probability_table& table (
+ ) const { return table_; }
+
+ bool is_evidence (
+ ) const { return is_instantiated; }
+
+ void set_as_nonevidence (
+ ) { is_instantiated = false; }
+
+ void set_as_evidence (
+ ) { is_instantiated = true; }
+
+ void swap (
+ bayes_node& item
+ )
+ {
+ exchange(value_, item.value_);
+ exchange(is_instantiated, item.is_instantiated);
+ table_.swap(item.table_);
+ }
+
+ friend inline void serialize (
+ const bayes_node& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.value_, out);
+ serialize(item.is_instantiated, out);
+ serialize(item.table_, out);
+ }
+
+ friend inline void deserialize (
+ bayes_node& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.value_, in);
+ deserialize(item.is_instantiated, in);
+ deserialize(item.table_, in);
+ }
+
+ private:
+
+ unsigned long value_;
+ bool is_instantiated;
+ conditional_probability_table table_;
+ };
+
+ inline void swap (
+ bayes_node& a,
+ bayes_node& b
+ ) { a.swap(b); }
+
+// ------------------------------------------------------------------------
+
+ namespace bayes_node_utils
+ {
+
+ template <typename T>
+ unsigned long node_num_values (
+ const T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tvoid bayes_node_utils::node_num_values(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ return bn.node(n).data.table().num_values();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void set_node_value (
+ T& bn,
+ unsigned long n,
+ unsigned long val
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n),
+ "\tvoid bayes_node_utils::set_node_value(bn, n, val)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tval: " << val
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
+ );
+
+ bn.node(n).data.set_value(val);
+ }
+
+ // ----------------------------------------------------------------------------------------
+ template <typename T>
+ unsigned long node_value (
+ const T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tunsigned long bayes_node_utils::node_value(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ return bn.node(n).data.value();
+ }
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool node_is_evidence (
+ const T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tbool bayes_node_utils::node_is_evidence(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ return bn.node(n).data.is_evidence();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void set_node_as_evidence (
+ T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ bn.node(n).data.set_as_evidence();
+ }
+
+ // ----------------------------------------------------------------------------------------
+ template <typename T>
+ void set_node_as_nonevidence (
+ T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ bn.node(n).data.set_as_nonevidence();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void set_node_num_values (
+ T& bn,
+ unsigned long n,
+ unsigned long num
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ bn.node(n).data.table().set_num_values(num);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ double node_probability (
+ const T& bn,
+ unsigned long n,
+ unsigned long value,
+ const assignment& parents
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
+ "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tvalue: " << value
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
+ );
+
+ DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
+ "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tparents.size(): " << parents.size()
+ << "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
+ );
+
+#ifdef ENABLE_ASSERTS
+ parents.reset();
+ while (parents.move_next())
+ {
+ const unsigned long x = parents.element().key();
+ DLIB_ASSERT( bn.has_edge(x, n),
+ "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ );
+ DLIB_ASSERT( parents[x] < node_num_values(bn,x),
+ "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ << "\n\tparents[x]: " << parents[x]
+ << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
+ );
+ }
+#endif
+
+ return bn.node(n).data.table().probability(value, parents);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void set_node_probability (
+ T& bn,
+ unsigned long n,
+ unsigned long value,
+ const assignment& parents,
+ double p
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
+ "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tp: " << p
+ << "\n\tvalue: " << value
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
+ );
+
+ DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
+ "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tp: " << p
+ << "\n\tparents.size(): " << parents.size()
+ << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
+ );
+
+ DLIB_ASSERT( 0.0 <= p && p <= 1.0,
+ "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tp: " << p
+ );
+
+#ifdef ENABLE_ASSERTS
+ parents.reset();
+ while (parents.move_next())
+ {
+ const unsigned long x = parents.element().key();
+ DLIB_ASSERT( bn.has_edge(x, n),
+ "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ );
+ DLIB_ASSERT( parents[x] < node_num_values(bn,x),
+ "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ << "\n\tparents[x]: " << parents[x]
+ << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
+ );
+ }
+#endif
+
+ bn.node(n).data.table().set_probability(value,parents,p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const assignment node_first_parent_assignment (
+ const T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ );
+
+ assignment a;
+ const unsigned long num_parents = bn.node(n).number_of_parents();
+ for (unsigned long i = 0; i < num_parents; ++i)
+ {
+ a.add(bn.node(n).parent(i).index(), 0);
+ }
+ return a;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool node_next_parent_assignment (
+ const T& bn,
+ unsigned long n,
+ assignment& a
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ );
+
+ DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(),
+ "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\ta.size(): " << a.size()
+ << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
+ );
+
+#ifdef ENABLE_ASSERTS
+ a.reset();
+ while (a.move_next())
+ {
+ const unsigned long x = a.element().key();
+ DLIB_ASSERT( bn.has_edge(x, n),
+ "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ );
+ DLIB_ASSERT( a[x] < node_num_values(bn,x),
+ "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tx: " << x
+ << "\n\ta[x]: " << a[x]
+ << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
+ );
+ }
+#endif
+
+ // basically this loop just adds 1 to the assignment but performs
+ // carries if necessary
+ for (unsigned long p = 0; p < a.size(); ++p)
+ {
+ const unsigned long pindex = bn.node(n).parent(p).index();
+ a[pindex] += 1;
+
+ // if we need to perform a carry
+ if (a[pindex] >= node_num_values(bn,pindex))
+ {
+ a[pindex] = 0;
+ }
+ else
+ {
+ // no carry necessary so we are done
+ return true;
+ }
+ }
+
+ // we got through the entire loop which means a carry propagated all the way out
+ // so there must not be any more valid assignments left
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool node_cpt_filled_out (
+ const T& bn,
+ unsigned long n
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( n < bn.number_of_nodes(),
+ "\tbool bayes_node_utils::node_cpt_filled_out(bn, n)"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tn: " << n
+ << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
+ );
+
+ const unsigned long num_values = node_num_values(bn,n);
+
+
+ const conditional_probability_table& table = bn.node(n).data.table();
+
+ // now loop over all the possible parent assignments for this node
+ assignment a(node_first_parent_assignment(bn,n));
+ do
+ {
+ double sum = 0;
+ // make sure that this assignment has an entry for all the values this node can take one
+ for (unsigned long value = 0; value < num_values; ++value)
+ {
+ if (table.has_entry_for(value,a) == false)
+ return false;
+ else
+ sum += table.probability(value,a);
+ }
+
+ // check if the sum of probabilities equals 1 as it should
+ if (std::abs(sum-1.0) > 1e-5)
+ return false;
+ } while (node_next_parent_assignment(bn,n,a));
+
+ return true;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class bayesian_network_gibbs_sampler : noncopyable
+ {
+ public:
+
+ bayesian_network_gibbs_sampler ()
+ {
+ rnd.set_seed(cast_to_string(std::time(0)));
+ }
+
+
+ template <
+ typename T
+ >
+ void sample_graph (
+ T& bn
+ )
+ {
+ using namespace bayes_node_utils;
+ for (unsigned long n = 0; n < bn.number_of_nodes(); ++n)
+ {
+ if (node_is_evidence(bn, n))
+ continue;
+
+ samples.set_size(node_num_values(bn,n));
+ // obtain the probability distribution for this node
+ for (long i = 0; i < samples.nc(); ++i)
+ {
+ set_node_value(bn, n, i);
+ samples(i) = node_probability(bn, n);
+
+ for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j)
+ samples(i) *= node_probability(bn, bn.node(n).child(j).index());
+ }
+
+ //normalize samples
+ samples /= sum(samples);
+
+
+ // select a random point in the probability distribution
+ double prob = rnd.get_random_double();
+
+ // now find the point in the distribution this probability corresponds to
+ long j;
+ for (j = 0; j < samples.nc()-1; ++j)
+ {
+ if (prob <= samples(j))
+ break;
+ else
+ prob -= samples(j);
+ }
+
+ set_node_value(bn, n, j);
+ }
+ }
+
+
+ private:
+
+ template <
+ typename T
+ >
+ double node_probability (
+ const T& bn,
+ unsigned long n
+ )
+ /*!
+ requires
+ - n < bn.number_of_nodes()
+ ensures
+ - computes the probability of node n having its current value given
+ the current values of its parents in the network bn
+ !*/
+ {
+ v.clear();
+ for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i)
+ {
+ v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value());
+ }
+ return bn.node(n).data.table().probability(bn.node(n).data.value(), v);
+ }
+
+ assignment v;
+
+ dlib::rand rnd;
+ matrix<double,1> samples;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace bayesian_network_join_tree_helpers
+ {
+ class bnjt
+ {
+ /*!
+ this object is the base class used in this pimpl idiom
+ !*/
+ public:
+ virtual ~bnjt() {}
+
+ virtual const matrix<double,1> probability(
+ unsigned long idx
+ ) const = 0;
+ };
+
+ template <typename T, typename U>
+ class bnjt_impl : public bnjt
+ {
+ /*!
+ This object is the implementation in the pimpl idiom
+ !*/
+
+ public:
+
+ bnjt_impl (
+ const T& bn,
+ const U& join_tree
+ )
+ {
+ create_bayesian_network_join_tree(bn, join_tree, join_tree_values);
+
+ cliques.resize(bn.number_of_nodes());
+
+ // figure out which cliques contain each node
+ for (unsigned long i = 0; i < cliques.size(); ++i)
+ {
+ // find the smallest clique that contains node with index i
+ unsigned long smallest_clique = 0;
+ unsigned long size = std::numeric_limits<unsigned long>::max();
+
+ for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n)
+ {
+ if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size)
+ {
+ size = join_tree.node(n).data.size();
+ smallest_clique = n;
+ }
+ }
+
+ cliques[i] = smallest_clique;
+ }
+ }
+
+ virtual const matrix<double,1> probability(
+ unsigned long idx
+ ) const
+ {
+ join_tree_values.node(cliques[idx]).data.marginalize(idx, table);
+ table.normalize();
+ var.clear();
+ var.add(idx);
+ dist.set_size(table.size());
+
+ // read the probabilities out of the table and into the row matrix
+ for (unsigned long i = 0; i < table.size(); ++i)
+ {
+ var[idx] = i;
+ dist(i) = table.probability(var);
+ }
+
+ return dist;
+ }
+
+ private:
+
+ graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values;
+ array<unsigned long> cliques;
+ mutable joint_probability_table table;
+ mutable assignment var;
+ mutable matrix<double,1> dist;
+
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename set_type, typename node_type>
+ bool set_contains_all_parents_of_node (
+ const set_type& set,
+ const node_type& node
+ )
+ {
+ for (unsigned long i = 0; i < node.number_of_parents(); ++i)
+ {
+ if (set.is_member(node.parent(i).index()) == false)
+ return false;
+ }
+ return true;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <
+ typename V
+ >
+ void pass_join_tree_message (
+ const U& join_tree,
+ V& bn_join_tree ,
+ unsigned long from,
+ unsigned long to
+ )
+ {
+ using namespace bayes_node_utils;
+ const typename U::edge_type& e = edge(join_tree, from, to);
+ typename V::edge_type& old_s = edge(bn_join_tree, from, to);
+
+ typedef typename V::edge_type joint_prob_table;
+
+ joint_prob_table new_s;
+ bn_join_tree.node(from).data.marginalize(e, new_s);
+
+ joint_probability_table temp(new_s);
+ // divide new_s by old_s and store the result in temp.
+ // if old_s is empty then that is the same as if it was all 1s
+ // so we don't have to do this if that is the case.
+ if (old_s.size() > 0)
+ {
+ temp.reset();
+ old_s.reset();
+ while (temp.move_next())
+ {
+ old_s.move_next();
+ if (old_s.element().value() != 0)
+ temp.element().value() /= old_s.element().value();
+ }
+ }
+
+ // now multiply temp by d and store the results in d
+ joint_probability_table& d = bn_join_tree.node(to).data;
+ d.reset();
+ while (d.move_next())
+ {
+ assignment a;
+ const assignment& asrc = d.element().key();
+ asrc.reset();
+ while (asrc.move_next())
+ {
+ if (e.is_member(asrc.element().key()))
+ a.add(asrc.element().key(), asrc.element().value());
+ }
+
+ d.element().value() *= temp.probability(a);
+
+ }
+
+ // store new_s in old_s
+ new_s.swap(old_s);
+
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <
+ typename V
+ >
+ void create_bayesian_network_join_tree (
+ const T& bn,
+ const U& join_tree,
+ V& bn_join_tree
+ )
+ /*!
+ requires
+ - bn is a proper bayesian network
+ - join_tree is the join tree for that bayesian network
+ ensures
+ - bn_join_tree == the output of the join tree algorithm for bayesian network inference.
+ So each node in this graph contains a joint_probability_table for the clique
+ in the corresponding node in the join_tree graph.
+ !*/
+ {
+ using namespace bayes_node_utils;
+ bn_join_tree.clear();
+ copy_graph_structure(join_tree, bn_join_tree);
+
+ // we need to keep track of which node is "in" each clique for the purposes of
+ // initializing the tables in each clique. So this vector will be used to do that
+ // and a value of join_tree.number_of_nodes() means that the node with
+ // that index is unassigned.
+ std::vector<unsigned long> node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes());
+
+ // populate evidence with all the evidence node indices and their values
+ dlib::map<unsigned long, unsigned long>::kernel_1b_c evidence;
+ for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
+ {
+ if (node_is_evidence(bn, i))
+ {
+ unsigned long idx = i;
+ unsigned long value = node_value(bn, i);
+ evidence.add(idx,value);
+ }
+ }
+
+
+ // initialize the bn join tree
+ for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
+ {
+ bool contains_evidence = false;
+ std::vector<unsigned long> indices;
+ assignment value;
+
+ // loop over all the nodes in this clique in the join tree. In this loop
+ // we are making an assignment with all the values of the nodes it represents set to 0
+ join_tree.node(i).data.reset();
+ while (join_tree.node(i).data.move_next())
+ {
+ const unsigned long idx = join_tree.node(i).data.element();
+ indices.push_back(idx);
+ value.add(idx);
+
+ if (evidence.is_in_domain(join_tree.node(i).data.element()))
+ contains_evidence = true;
+ }
+
+ // now loop over all possible combinations of values that the nodes this
+ // clique in the join tree can take on. We do this by counting by one through all
+ // legal values
+ bool more_assignments = true;
+ while (more_assignments)
+ {
+ bn_join_tree.node(i).data.set_probability(value,1);
+
+ // account for any evidence
+ if (contains_evidence)
+ {
+ // loop over all the nodes in this cluster
+ for (unsigned long j = 0; j < indices.size(); ++j)
+ {
+ // if the current node is an evidence node
+ if (evidence.is_in_domain(indices[j]))
+ {
+ const unsigned long idx = indices[j];
+ const unsigned long evidence_value = evidence[idx];
+ if (value[idx] != evidence_value)
+ bn_join_tree.node(i).data.set_probability(value , 0);
+ }
+ }
+ }
+
+
+ // now check if any of the nodes in this cluster also have their parents in this cluster
+ join_tree.node(i).data.reset();
+ while (join_tree.node(i).data.move_next())
+ {
+ const unsigned long idx = join_tree.node(i).data.element();
+ // if this clique contains all the parents of this node and also hasn't
+ // been assigned to another clique
+ if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) &&
+ (i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) )
+ {
+ // note that this node is now assigned to this clique
+ node_assigned_to[idx] = i;
+ // node idx has all its parents in the cluster
+ assignment parent_values;
+ for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j)
+ {
+ const unsigned long pidx = bn.node(idx).parent(j).index();
+ parent_values.add(pidx, value[pidx]);
+ }
+
+ double temp = bn_join_tree.node(i).data.probability(value);
+ bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values));
+
+ }
+ }
+
+
+ // now advance the value variable to its next possible state if there is one
+ more_assignments = false;
+ value.reset();
+ while (value.move_next())
+ {
+ value.element().value() += 1;
+ // if overflow
+ if (value.element().value() == node_num_values(bn, value.element().key()))
+ {
+ value.element().value() = 0;
+ }
+ else
+ {
+ more_assignments = true;
+ break;
+ }
+ }
+
+ } // end while (more_assignments)
+ }
+
+
+
+
+ // the tree is now initialized. Now all we need to do is perform the propagation and
+ // we are done
+ dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_send;
+ dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_receive;
+ remaining_msg_to_receive.resize(join_tree.number_of_nodes());
+ remaining_msg_to_send.resize(join_tree.number_of_nodes());
+ for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i)
+ {
+ for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j)
+ {
+ const unsigned long idx = join_tree.node(i).neighbor(j).index();
+ unsigned long temp;
+ temp = idx; remaining_msg_to_receive[i].add(temp);
+ temp = idx; remaining_msg_to_send[i].add(temp);
+ }
+ }
+
+ // now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received
+ // a message from.
+ // we will consider node 0 to be the root node.
+
+
+ bool message_sent = true;
+ std::vector<unsigned long>::iterator iter;
+ while (message_sent)
+ {
+ message_sent = false;
+ for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i)
+ {
+ // if node i hasn't sent any messages but has received all but one then send a message to the one
+ // node who hasn't sent i a message
+ if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1)
+ {
+ unsigned long to;
+ // get the last remaining thing from this set
+ remaining_msg_to_receive[i].remove_any(to);
+
+ // send the message
+ pass_join_tree_message(join_tree, bn_join_tree, i, to);
+
+ // record that we sent this message
+ remaining_msg_to_send[i].destroy(to);
+ remaining_msg_to_receive[to].destroy(i);
+
+ // put to back in since we still need to receive it
+ remaining_msg_to_receive[i].add(to);
+ message_sent = true;
+ }
+ else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0)
+ {
+ unsigned long to;
+ remaining_msg_to_send[i].remove_any(to);
+ remaining_msg_to_receive[to].destroy(i);
+ pass_join_tree_message(join_tree, bn_join_tree, i, to);
+ message_sent = true;
+ }
+ }
+
+ if (remaining_msg_to_receive[0].size() == 0)
+ {
+ // send a message to all of the root nodes neighbors unless we have already sent out he messages
+ while (remaining_msg_to_send[0].size() > 0)
+ {
+ unsigned long to;
+ remaining_msg_to_send[0].remove_any(to);
+ remaining_msg_to_receive[to].destroy(0);
+ pass_join_tree_message(join_tree, bn_join_tree, 0, to);
+ message_sent = true;
+ }
+ }
+
+
+ }
+
+ }
+
+ };
+ }
+
+ class bayesian_network_join_tree : noncopyable
+ {
+ /*!
+ use the pimpl idiom to push the template arguments from the class level to the
+ constructor level
+ !*/
+
+ public:
+
+ template <
+ typename T,
+ typename U
+ >
+ bayesian_network_join_tree (
+ const T& bn,
+ const U& join_tree
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( bn.number_of_nodes() > 0 ,
+ "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
+ << "\n\tYou have given an invalid bayesian network"
+ << "\n\tthis: " << this
+ );
+
+ DLIB_ASSERT( is_join_tree(bn, join_tree) == true ,
+ "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
+ << "\n\tYou have given an invalid join tree for the supplied bayesian network"
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false,
+ "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
+ << "\n\tYou have given an invalid bayesian network"
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT( graph_is_connected(bn) == true,
+ "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
+ << "\n\tYou have given an invalid bayesian network"
+ << "\n\tthis: " << this
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
+ {
+ DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true,
+ "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
+ << "\n\tYou have given an invalid bayesian network. "
+ << "\n\tYou must finish filling out the conditional_probability_table of node " << i
+ << "\n\tthis: " << this
+ );
+ }
+#endif
+
+ impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl<T,U>(bn, join_tree));
+ num_nodes = bn.number_of_nodes();
+ }
+
+ const matrix<double,1> probability(
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( idx < number_of_nodes() ,
+ "\tconst matrix<double,1> bayesian_network_join_tree::probability(idx)"
+ << "\n\tYou have specified an invalid node index"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ << "\n\tthis: " << this
+ );
+
+ return impl->probability(idx);
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return num_nodes; }
+
+ void swap (
+ bayesian_network_join_tree& item
+ )
+ {
+ exchange(num_nodes, item.num_nodes);
+ impl.swap(item.impl);
+ }
+
+ private:
+
+ std::unique_ptr<bayesian_network_join_tree_helpers::bnjt> impl;
+ unsigned long num_nodes;
+
+ };
+
+ inline void swap (
+ bayesian_network_join_tree& a,
+ bayesian_network_join_tree& b
+ ) { a.swap(b); }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_BAYES_UTILs_
+
diff --git a/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h b/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h
new file mode 100644
index 000000000..b19e6e1da
--- /dev/null
+++ b/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h
@@ -0,0 +1,1042 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BAYES_UTILs_ABSTRACT_
+#ifdef DLIB_BAYES_UTILs_ABSTRACT_
+
+#include "../algs.h"
+#include "../noncopyable.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/map_pair.h"
+#include "../serialize.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class assignment : public enumerable<map_pair<unsigned long, unsigned long> >
+ {
+ /*!
+ INITIAL VALUE
+ - size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the entries in the assignment in
+ ascending order according to index values. (i.e. the elements are
+ enumerated in sorted order according to the value of their keys)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object models an assignment of random variables to particular values.
+ It is used with the joint_probability_table and conditional_probability_table
+ objects to represent assignments of various random variables to actual values.
+
+ So for example, if you had a joint_probability_table that represented the
+ following table:
+ P(A = 0, B = 0) = 0.2
+ P(A = 0, B = 1) = 0.3
+ P(A = 1, B = 0) = 0.1
+ P(A = 1, B = 1) = 0.4
+
+ Also lets define an enum so we have concrete index numbers for A and B
+ enum { A = 0, B = 1};
+
+ Then you could query the value of P(A=1, B=0) as follows:
+ assignment a;
+ a.set(A, 1);
+ a.set(B, 0);
+ // and now it is the case that:
+ table.probability(a) == 0.1
+ a[A] == 1
+ a[B] == 0
+
+
+ Also note that when enumerating the elements of an assignment object
+ the key() refers to the index and the value() refers to the value at that
+ index. For example:
+
+ // assume a is an assignment object
+ a.reset();
+ while (a.move_next())
+ {
+ // in this loop it is always the case that:
+ // a[a.element().key()] == a.element().value()
+ }
+ !*/
+
+ public:
+
+ assignment(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ assignment(
+ const assignment& a
+ );
+ /*!
+ ensures
+ - #*this is a copy of a
+ !*/
+
+ assignment& operator = (
+ const assignment& rhs
+ );
+ /*!
+ ensures
+ - #*this is a copy of rhs
+ - returns *this
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has been returned to its initial value
+ !*/
+
+ bool operator < (
+ const assignment& item
+ ) const;
+ /*!
+ ensures
+ - The exact functioning of this operator is undefined. The only guarantee
+ is that it establishes a total ordering on all possible assignment objects.
+ In other words, this operator makes it so that you can use assignment
+ objects in the associative containers but otherwise isn't of any
+ particular use.
+ !*/
+
+ bool has_index (
+ unsigned long idx
+ ) const;
+ /*!
+ ensures
+ - if (this assignment object has an entry for index idx) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void add (
+ unsigned long idx,
+ unsigned long value = 0
+ );
+ /*!
+ requires
+ - has_index(idx) == false
+ ensures
+ - #has_index(idx) == true
+ - #(*this)[idx] == value
+ !*/
+
+ void remove (
+ unsigned long idx
+ );
+ /*!
+ requires
+ - has_index(idx) == true
+ ensures
+ - #has_index(idx) == false
+ !*/
+
+ unsigned long& operator[] (
+ const long idx
+ );
+ /*!
+ requires
+ - has_index(idx) == true
+ ensures
+ - returns a reference to the value associated with index idx
+ !*/
+
+ const unsigned long& operator[] (
+ const long idx
+ ) const;
+ /*!
+ requires
+ - has_index(idx) == true
+ ensures
+ - returns a const reference to the value associated with index idx
+ !*/
+
+ void swap (
+ assignment& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ assignment& a,
+ assignment& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+ std::ostream& operator << (
+ std::ostream& out,
+ const assignment& a
+ );
+ /*!
+ ensures
+ - writes a to the given output stream in the following format:
+ (index1:value1, index2:value2, ..., indexN:valueN)
+ !*/
+
+ void serialize (
+ const assignment& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ assignment& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ------------------------------------------------------------------------
+
+ class joint_probability_table : public enumerable<map_pair<assignment, double> >
+ {
+ /*!
+ INITIAL VALUE
+ - size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the entries in the probability table
+ in no particular order but they will all be visited.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object models a joint probability table. That is, it models
+ the function p(X). So this object models the probability of a particular
+ set of variables (referred to as X).
+ !*/
+
+ public:
+
+ joint_probability_table(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ joint_probability_table (
+ const joint_probability_table& t
+ );
+ /*!
+ ensures
+ - this object is a copy of t
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ !*/
+
+ joint_probability_table& operator= (
+ const joint_probability_table& rhs
+ );
+ /*!
+ ensures
+ - this object is a copy of rhs
+ - returns a reference to *this
+ !*/
+
+ bool has_entry_for (
+ const assignment& a
+ ) const;
+ /*!
+ ensures
+ - if (this joint_probability_table has an entry for p(X = a)) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void set_probability (
+ const assignment& a,
+ double p
+ );
+ /*!
+ requires
+ - 0 <= p <= 1
+ ensures
+ - if (has_entry_for(a) == false) then
+ - #size() == size() + 1
+ - #probability(a) == p
+ - #has_entry_for(a) == true
+ !*/
+
+ void add_probability (
+ const assignment& a,
+ double p
+ );
+ /*!
+ requires
+ - 0 <= p <= 1
+ ensures
+ - if (has_entry_for(a) == false) then
+ - #size() == size() + 1
+ - #probability(a) == p
+ - else
+ - #probability(a) == min(probability(a) + p, 1.0)
+ (i.e. does a saturating add)
+ - #has_entry_for(a) == true
+ !*/
+
+ const double probability (
+ const assignment& a
+ ) const;
+ /*!
+ ensures
+ - returns the probability p(X == a)
+ !*/
+
+ template <
+ typename T
+ >
+ void marginalize (
+ const T& vars,
+ joint_probability_table& output_table
+ ) const;
+ /*!
+ requires
+ - T is an implementation of set/set_kernel_abstract.h
+ ensures
+ - marginalizes *this by summing over all variables not in vars. The
+ result is stored in output_table.
+ !*/
+
+ void marginalize (
+ const unsigned long var,
+ joint_probability_table& output_table
+ ) const;
+ /*!
+ ensures
+ - is identical to calling the above marginalize() function with a set
+ that contains only var. Or in other words, performs a marginalization
+ with just one variable var. So that output_table will contain a table giving
+ the marginal probability of var all by itself.
+ !*/
+
+ void normalize (
+ );
+ /*!
+ ensures
+ - let sum == the sum of all the probabilities in this table
+ - after normalize() has finished it will be the case that the sum of all
+ the entries in this table is 1.0. This is accomplished by dividing all
+ the entries by the sum described above.
+ !*/
+
+ void swap (
+ joint_probability_table& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ joint_probability_table& a,
+ joint_probability_table& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+ void serialize (
+ const joint_probability_table& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ joint_probability_table& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class conditional_probability_table : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - num_values() == 0
+ - has_value_for(x, y) == false for all values of x and y
+
+ WHAT THIS OBJECT REPRESENTS
+ This object models a conditional probability table. That is, it models
+ the function p( X | parents). So this object models the conditional
+ probability of a particular variable (referred to as X) given another set
+ of variables (referred to as parents).
+ !*/
+
+ public:
+
+ conditional_probability_table(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ !*/
+
+ void empty_table (
+ );
+ /*!
+ ensures
+ - for all possible v and p:
+ - #has_entry_for(v,p) == false
+ (i.e. this function clears out the table when you call it but doesn't
+ change the value of num_values())
+ !*/
+
+ void set_num_values (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #num_values() == num
+ - for all possible v and p:
+ - #has_entry_for(v,p) == false
+ (i.e. this function clears out the table when you call it)
+ !*/
+
+ unsigned long num_values (
+ ) const;
+ /*!
+ ensures
+ - This object models the probability table p(X | parents). This
+ function returns the number of values X can take on.
+ !*/
+
+ bool has_entry_for (
+ unsigned long value,
+ const assignment& ps
+ ) const;
+ /*!
+ ensures
+ - if (this conditional_probability_table has an entry for p(X = value, parents = ps)) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void set_probability (
+ unsigned long value,
+ const assignment& ps,
+ double p
+ );
+ /*!
+ requires
+ - value < num_values()
+ - 0 <= p <= 1
+ ensures
+ - #probability(ps, value) == p
+ - #has_entry_for(value, ps) == true
+ !*/
+
+ double probability(
+ unsigned long value,
+ const assignment& ps
+ ) const;
+ /*!
+ requires
+ - value < num_values()
+ - has_entry_for(value, ps) == true
+ ensures
+ - returns the probability p( X = value | parents = ps).
+ !*/
+
+ void swap (
+ conditional_probability_table& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ inline void swap (
+ conditional_probability_table& a,
+ conditional_probability_table& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+ void serialize (
+ const conditional_probability_table& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ conditional_probability_table& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ------------------------------------------------------------------------
+// ------------------------------------------------------------------------
+// ------------------------------------------------------------------------
+
+ class bayes_node : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - is_evidence() == false
+ - value() == 0
+ - table().num_values() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a node in a bayesian network. It is
+ intended to be used inside the dlib::directed_graph object to
+ represent bayesian networks.
+ !*/
+
+ public:
+ bayes_node (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ unsigned long value (
+ ) const;
+ /*!
+ ensures
+ - returns the current value of this node
+ !*/
+
+ void set_value (
+ unsigned long new_value
+ );
+ /*!
+ requires
+ - new_value < table().num_values()
+ ensures
+ - #value() == new_value
+ !*/
+
+ conditional_probability_table& table (
+ );
+ /*!
+ ensures
+ - returns a reference to the conditional_probability_table associated with this node
+ !*/
+
+ const conditional_probability_table& table (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the conditional_probability_table associated with this
+ node.
+ !*/
+
+ bool is_evidence (
+ ) const;
+ /*!
+ ensures
+ - if (this is an evidence node) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void set_as_nonevidence (
+ );
+ /*!
+ ensures
+ - #is_evidence() == false
+ !*/
+
+ void set_as_evidence (
+ );
+ /*!
+ ensures
+ - #is_evidence() == true
+ !*/
+
+ void swap (
+ bayes_node& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ bayes_node& a,
+ bayes_node& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+ void serialize (
+ const bayes_node& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ bayes_node& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ /*
+ The following group of functions are convenience functions for manipulating
+ bayes_node objects while they are inside a directed_graph. These functions
+ also have additional requires clauses that, in debug mode, will protect you
+ from attempts to manipulate a bayesian network in an inappropriate way.
+ */
+
+ namespace bayes_node_utils
+ {
+
+ template <
+ typename T
+ >
+ void set_node_value (
+ T& bn,
+ unsigned long n,
+ unsigned long val
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ - val < node_num_values(bn, n)
+ ensures
+ - #bn.node(n).data.value() = val
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long node_value (
+ const T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - returns bn.node(n).data.value()
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool node_is_evidence (
+ const T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - returns bn.node(n).data.is_evidence()
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void set_node_as_evidence (
+ T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - executes: bn.node(n).data.set_as_evidence()
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void set_node_as_nonevidence (
+ T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - executes: bn.node(n).data.set_as_nonevidence()
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void set_node_num_values (
+ T& bn,
+ unsigned long n,
+ unsigned long num
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - #bn.node(n).data.table().num_values() == num
+ (i.e. sets the number of different values this node can take)
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long node_num_values (
+ const T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - returns bn.node(n).data.table().num_values()
+ (i.e. returns the number of different values this node can take)
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const double node_probability (
+ const T& bn,
+ unsigned long n,
+ unsigned long value,
+ const assignment& parents
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ - value < node_num_values(bn,n)
+ - parents.size() == bn.node(n).number_of_parents()
+ - if (parents.has_index(x)) then
+ - bn.has_edge(x, n)
+ - parents[x] < node_num_values(bn,x)
+ ensures
+ - returns bn.node(n).data.table().probability(value, parents)
+ (i.e. returns the probability of node n having the given value when
+ its parents have the given assignment)
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const double set_node_probability (
+ const T& bn,
+ unsigned long n,
+ unsigned long value,
+ const assignment& parents,
+ double p
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ - value < node_num_values(bn,n)
+ - 0 <= p <= 1
+ - parents.size() == bn.node(n).number_of_parents()
+ - if (parents.has_index(x)) then
+ - bn.has_edge(x, n)
+ - parents[x] < node_num_values(bn,x)
+ ensures
+ - #bn.node(n).data.table().probability(value, parents) == p
+ (i.e. sets the probability of node n having the given value when
+ its parents have the given assignment to the probability p)
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ const assignment node_first_parent_assignment (
+ const T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - returns an assignment A such that:
+ - A.size() == bn.node(n).number_of_parents()
+ - if (P is a parent of bn.node(n)) then
+ - A.has_index(P)
+ - A[P] == 0
+ - I.e. this function returns an assignment that contains all
+ the parents of the given node. Also, all the values of each
+ parent in the assignment is set to zero.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool node_next_parent_assignment (
+ const T& bn,
+ unsigned long n,
+ assignment& A
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ - A.size() == bn.node(n).number_of_parents()
+ - if (A.has_index(x)) then
+ - bn.has_edge(x, n)
+ - A[x] < node_num_values(bn,x)
+ ensures
+ - The behavior of this function is defined by the following code:
+ assignment a(node_first_parent_assignment(bn,n);
+ do {
+ // this loop loops over all possible parent assignments
+ // of the node bn.node(n). Each time through the loop variable a
+ // will be the next assignment.
+ } while (node_next_parent_assignment(bn,n,a))
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool node_cpt_filled_out (
+ const T& bn,
+ unsigned long n
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ - n < bn.number_of_nodes()
+ ensures
+ - if (the conditional_probability_table bn.node(n).data.table() is
+ fully filled out for this node) then
+ - returns true
+ - This means that each parent assignment for the given node
+ along with all possible values of this node shows up in the
+ table.
+ - It also means that all the probabilities conditioned on the
+ same parent assignment sum to 1.0
+ - else
+ - returns false
+ !*/
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class bayesian_network_gibbs_sampler : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ This object has no state
+
+ WHAT THIS OBJECT REPRESENTS
+ This object performs Markov Chain Monte Carlo sampling of a bayesian
+ network using the Gibbs sampling technique.
+
+ Note that this object is limited to only bayesian networks that
+ don't contain deterministic nodes. That is, incorrect results may
+ be computed if this object is used when the bayesian network contains
+ any nodes that have a probability of 1 in their conditional probability
+ tables for any event. So don't use this object for networks with
+ deterministic nodes.
+ !*/
+ public:
+
+ bayesian_network_gibbs_sampler (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename T
+ >
+ void sample_graph (
+ T& bn
+ )
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - T::type == bayes_node
+ ensures
+ - modifies randomly (via the Gibbs sampling technique) samples all the nodes
+ in the network and updates their values with the newly sampled values
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bayesian_network_join_tree : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an implementation of the join tree algorithm
+ for inference in bayesian networks. It doesn't have any mutable state.
+ To you use you just give it a directed_graph that contains a bayesian
+ network and a graph object that contains that networks corresponding
+ join tree. Then you may query this object to determine the probabilities
+ of any variables in the original bayesian network.
+ !*/
+
+ public:
+
+ template <
+ typename bn_type,
+ typename join_tree_type
+ >
+ bayesian_network_join_tree (
+ const bn_type& bn,
+ const join_tree_type& join_tree
+ );
+ /*!
+ requires
+ - bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - bn_type::type == bayes_node
+ - join_tree_type is an implementation of graph/graph_kernel_abstract.h
+ - join_tree_type::type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - is_join_tree(bn, join_tree) == true
+ - bn == a valid bayesian network with all its conditional probability tables
+ filled out
+ - for all valid n:
+ - node_cpt_filled_out(bn,n) == true
+ - graph_contains_length_one_cycle(bn) == false
+ - graph_is_connected(bn) == true
+ - bn.number_of_nodes() > 0
+ ensures
+ - this object is properly initialized
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the bayesian network that this
+ object was instantiated from.
+ !*/
+
+ const matrix<double,1> probability(
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns the probability distribution for the node with index idx that was in the bayesian
+ network that *this was instantiated from. Let D represent this distribution, then:
+ - D.nc() == the number of values the node idx ranges over
+ - D.nr() == 1
+ - D(i) == the probability of node idx taking on the value i
+ !*/
+
+ void swap (
+ bayesian_network_join_tree& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ };
+
+ inline void swap (
+ bayesian_network_join_tree& a,
+ bayesian_network_join_tree& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BAYES_UTILs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/bigint.h b/ml/dlib/dlib/bigint.h
new file mode 100644
index 000000000..73496689a
--- /dev/null
+++ b/ml/dlib/dlib/bigint.h
@@ -0,0 +1,43 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINt_
+#define DLIB_BIGINt_
+
+#include "bigint/bigint_kernel_1.h"
+#include "bigint/bigint_kernel_2.h"
+#include "bigint/bigint_kernel_c.h"
+
+
+
+
+namespace dlib
+{
+
+
+ class bigint
+ {
+ bigint() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef bigint_kernel_1
+ kernel_1a;
+ typedef bigint_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_2a
+ typedef bigint_kernel_2
+ kernel_2a;
+ typedef bigint_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ };
+}
+
+#endif // DLIB_BIGINt_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_1.cpp b/ml/dlib/dlib/bigint/bigint_kernel_1.cpp
new file mode 100644
index 000000000..feef761c2
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_1.cpp
@@ -0,0 +1,1720 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINT_KERNEL_1_CPp_
+#define DLIB_BIGINT_KERNEL_1_CPp_
+#include "bigint_kernel_1.h"
+
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member/friend function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1::
+ bigint_kernel_1 (
+ ) :
+ slack(25),
+ data(new data_record(slack))
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1::
+ bigint_kernel_1 (
+ uint32 value
+ ) :
+ slack(25),
+ data(new data_record(slack))
+ {
+ *(data->number) = static_cast<uint16>(value&0xFFFF);
+ *(data->number+1) = static_cast<uint16>((value>>16)&0xFFFF);
+ if (*(data->number+1) != 0)
+ data->digits_used = 2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1::
+ bigint_kernel_1 (
+ const bigint_kernel_1& item
+ ) :
+ slack(25),
+ data(item.data)
+ {
+ data->references += 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1::
+ ~bigint_kernel_1 (
+ )
+ {
+ if (data->references == 1)
+ {
+ delete data;
+ }
+ else
+ {
+ data->references -= 1;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator+ (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ std::max(rhs.data->digits_used,data->digits_used) + slack
+ );
+ long_add(data,rhs.data,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator+= (
+ const bigint_kernel_1& rhs
+ )
+ {
+ // if there are other references to our data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
+ data->references -= 1;
+ long_add(data,rhs.data,temp);
+ data = temp;
+ }
+ // if data is not big enough for the result
+ else if (data->size <= std::max(data->digits_used,rhs.data->digits_used))
+ {
+ data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
+ long_add(data,rhs.data,temp);
+ delete data;
+ data = temp;
+ }
+ // there is enough size and no references
+ else
+ {
+ long_add(data,rhs.data,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator- (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ data->digits_used + slack
+ );
+ long_sub(data,rhs.data,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator-= (
+ const bigint_kernel_1& rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ long_sub(data,rhs.data,temp);
+ data = temp;
+ }
+ else
+ {
+ long_sub(data,rhs.data,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator* (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ data->digits_used + rhs.data->digits_used + slack
+ );
+ long_mul(data,rhs.data,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator*= (
+ const bigint_kernel_1& rhs
+ )
+ {
+ // create a data_record to store the result of the multiplication in
+ data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack);
+ long_mul(data,rhs.data,temp);
+
+ // if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ else
+ {
+ delete data;
+ }
+ data = temp;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator/ (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+ delete remainder;
+
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator/= (
+ const bigint_kernel_1& rhs
+ )
+ {
+
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+
+ // check if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ // if there are no references to data then it must be deleted
+ else
+ {
+ delete data;
+ }
+ data = temp;
+ delete remainder;
+
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator% (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+ delete temp;
+ return bigint_kernel_1(remainder,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator%= (
+ const bigint_kernel_1& rhs
+ )
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+
+ // check if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ // if there are no references to data then it must be deleted
+ else
+ {
+ delete data;
+ }
+ data = remainder;
+ delete temp;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_1::
+ operator < (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ return is_less_than(data,rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_1::
+ operator == (
+ const bigint_kernel_1& rhs
+ ) const
+ {
+ return is_equal_to(data,rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator= (
+ const bigint_kernel_1& rhs
+ )
+ {
+ if (this == &rhs)
+ return *this;
+
+ // if we have the only reference to our data then delete it
+ if (data->references == 1)
+ {
+ delete data;
+ data = rhs.data;
+ data->references += 1;
+ }
+ else
+ {
+ data->references -= 1;
+ data = rhs.data;
+ data->references += 1;
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out_,
+ const bigint_kernel_1& rhs
+ )
+ {
+ std::ostream out(out_.rdbuf());
+
+ typedef bigint_kernel_1 bigint;
+
+ bigint::data_record* temp = new bigint::data_record(*rhs.data,0);
+
+
+
+ // get a char array big enough to hold the number in ascii format
+ char* str;
+ try {
+ str = new char[(rhs.data->digits_used)*5+10];
+ } catch (...) { delete temp; throw; }
+
+ char* str_start = str;
+ str += (rhs.data->digits_used)*5+9;
+ *str = 0; --str;
+
+
+ uint16 remainder;
+ rhs.short_div(temp,10000,temp,remainder);
+
+ // pull the digits out of remainder
+ char a = remainder % 10 + '0';
+ remainder /= 10;
+ char b = remainder % 10 + '0';
+ remainder /= 10;
+ char c = remainder % 10 + '0';
+ remainder /= 10;
+ char d = remainder % 10 + '0';
+ remainder /= 10;
+
+
+ *str = a; --str;
+ *str = b; --str;
+ *str = c; --str;
+ *str = d; --str;
+
+
+ // keep looping until temp represents zero
+ while (temp->digits_used != 1 || *(temp->number) != 0)
+ {
+ rhs.short_div(temp,10000,temp,remainder);
+
+ // pull the digits out of remainder
+ char a = remainder % 10 + '0';
+ remainder /= 10;
+ char b = remainder % 10 + '0';
+ remainder /= 10;
+ char c = remainder % 10 + '0';
+ remainder /= 10;
+ char d = remainder % 10 + '0';
+ remainder /= 10;
+
+ *str = a; --str;
+ *str = b; --str;
+ *str = c; --str;
+ *str = d; --str;
+ }
+
+ // throw away and extra leading zeros
+ ++str;
+ if (*str == '0')
+ ++str;
+ if (*str == '0')
+ ++str;
+ if (*str == '0')
+ ++str;
+
+
+
+
+ out << str;
+ delete [] str_start;
+ delete temp;
+ return out_;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& operator>> (
+ std::istream& in_,
+ bigint_kernel_1& rhs
+ )
+ {
+ std::istream in(in_.rdbuf());
+
+ // ignore any leading whitespaces
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n')
+ {
+ in.get();
+ }
+
+ // if the first digit is not an integer then this is an error
+ if ( !(in.peek() >= '0' && in.peek() <= '9'))
+ {
+ in_.clear(std::ios::failbit);
+ return in_;
+ }
+
+ int num_read;
+ bigint_kernel_1 temp;
+ do
+ {
+
+ // try to get 4 chars from in
+ num_read = 1;
+ char a = 0;
+ char b = 0;
+ char c = 0;
+ char d = 0;
+
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ a = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ b = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ c = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ d = in.get();
+ }
+
+ // merge the for digits into an uint16
+ uint16 num = 0;
+ if (a != 0)
+ {
+ num = a - '0';
+ }
+ if (b != 0)
+ {
+ num *= 10;
+ num += b - '0';
+ }
+ if (c != 0)
+ {
+ num *= 10;
+ num += c - '0';
+ }
+ if (d != 0)
+ {
+ num *= 10;
+ num += d - '0';
+ }
+
+
+ if (num_read != 1)
+ {
+ // shift the digits in temp left by the number of new digits we just read
+ temp *= num_read;
+ // add in new digits
+ temp += num;
+ }
+
+ } while (num_read == 10000);
+
+
+ rhs = temp;
+ return in_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator+ (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (rhs.data->digits_used+rhs.slack);
+
+ rhs.short_add(rhs.data,lhs,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator+ (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_add(lhs.data,rhs,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator+= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_add(data,rhs,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data then do so
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ short_add(data,rhs,temp);
+ delete data;
+ data = temp;
+ }
+ // or if there is plenty of space and no references
+ else
+ {
+ short_add(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator- (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ *(temp->number) = lhs - *(rhs.data->number);
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator- (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_sub(lhs.data,rhs,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator-= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_sub(data,rhs,temp);
+ data = temp;
+ }
+ else
+ {
+ short_sub(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator* (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (rhs.data->digits_used+rhs.slack);
+
+ rhs.short_mul(rhs.data,lhs,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator* (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_mul(lhs.data,rhs,temp);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator*= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_mul(data,rhs,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ short_mul(data,rhs,temp);
+ delete data;
+ data = temp;
+ }
+ else
+ {
+ short_mul(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator/ (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ // if rhs might not be bigger than lhs
+ if (rhs.data->digits_used == 1)
+ {
+ *(temp->number) = lhs/ *(rhs.data->number);
+ }
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator/ (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ uint16 remainder;
+ lhs.short_div(lhs.data,rhs,temp,remainder);
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator/= (
+ uint16 rhs
+ )
+ {
+ uint16 remainder;
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_div(data,rhs,temp,remainder);
+ data = temp;
+ }
+ else
+ {
+ short_div(data,rhs,data,remainder);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator% (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ // temp is zero by default
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ if (rhs.data->digits_used == 1)
+ {
+ // if rhs is just an uint16 inside then perform the modulus
+ *(temp->number) = lhs % *(rhs.data->number);
+ }
+ else
+ {
+ // if rhs is bigger than lhs then the answer is lhs
+ *(temp->number) = lhs;
+ }
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 operator% (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_1 bigint;
+ bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack);
+
+ uint16 remainder;
+
+ lhs.short_div(lhs.data,rhs,temp,remainder);
+ temp->digits_used = 1;
+ *(temp->number) = remainder;
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator%= (
+ uint16 rhs
+ )
+ {
+ uint16 remainder;
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_div(data,rhs,temp,remainder);
+ data = temp;
+ }
+ else
+ {
+ short_div(data,rhs,data,remainder);
+ }
+
+ data->digits_used = 1;
+ *(data->number) = remainder;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator < (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator < (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator == (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ )
+ {
+ return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator == (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ )
+ {
+ return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator= (
+ uint16 rhs
+ )
+ {
+ // check if there are other references to our data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ try {
+ data = new data_record(slack);
+ } catch (...) { data->references += 1; throw; }
+ }
+ else
+ {
+ data->digits_used = 1;
+ }
+
+ *(data->number) = rhs;
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator++ (
+ )
+ {
+ // if there are other references to this data then make a copy of it
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ increment(data,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data then do so
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ increment(data,temp);
+ delete data;
+ data = temp;
+ }
+ else
+ {
+ increment(data,data);
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator++ (
+ int
+ )
+ {
+ data_record* temp; // this is the copy of temp we will return in the end
+
+ data_record* temp2 = new data_record(data->digits_used+slack);
+ increment(data,temp2);
+
+ temp = data;
+ data = temp2;
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_1& bigint_kernel_1::
+ operator-- (
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ decrement(data,temp);
+ data = temp;
+ }
+ else
+ {
+ decrement(data,data);
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_1 bigint_kernel_1::
+ operator-- (
+ int
+ )
+ {
+ data_record* temp; // this is the copy of temp we will return in the end
+
+ data_record* temp2 = new data_record(data->digits_used+slack);
+ decrement(data,temp2);
+
+ temp = data;
+ data = temp2;
+
+ return bigint_kernel_1(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ short_add (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+ // put value into the carry part of temp
+ uint32 temp = value;
+ temp <<= 16;
+
+
+ const uint16* number = data->number;
+ const uint16* end = number + data->digits_used; // one past the end of number
+ uint16* r = result->number;
+
+ while (number != end)
+ {
+ // add *number and the current carry
+ temp = *number + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number;
+ ++r;
+ }
+
+ // if there is a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = data->digits_used + 1;
+ // store the carry in the most significant digit of the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ short_sub (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+
+
+ const uint16* number = data->number;
+ const uint16* end = number + data->digits_used - 1;
+ uint16* r = result->number;
+
+ uint32 temp = *number - value;
+
+ // put the low word of temp into *data
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+
+ while (number != end)
+ {
+ ++number;
+ ++r;
+
+ // subtract the carry from *number
+ temp = *number - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+ }
+
+ // if we lost a digit in the subtraction
+ if (*r == 0)
+ {
+ if (data->digits_used == 1)
+ result->digits_used = 1;
+ else
+ result->digits_used = data->digits_used - 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ short_mul (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+
+ uint32 temp = 0;
+
+
+ const uint16* number = data->number;
+ uint16* r = result->number;
+ const uint16* end = r + data->digits_used;
+
+
+
+ while ( r != end)
+ {
+
+ // multiply *data and value and add in the carry
+ temp = *number*(uint32)value + (temp>>16);
+
+ // put the low word of temp into *data
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number;
+ ++r;
+ }
+
+ // if there is a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = data->digits_used + 1;
+ // put the final carry into the most significant digit of the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ short_div (
+ const data_record* data,
+ uint16 value,
+ data_record* result,
+ uint16& rem
+ ) const
+ {
+
+ uint16 remainder = 0;
+ uint32 temp;
+
+
+
+ const uint16* number = data->number + data->digits_used - 1;
+ const uint16* end = number - data->digits_used;
+ uint16* r = result->number + data->digits_used - 1;
+
+
+ // if we are losing a digit in this division
+ if (*number < value)
+ {
+ if (data->digits_used == 1)
+ result->digits_used = 1;
+ else
+ result->digits_used = data->digits_used - 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ // perform the actual division
+ while (number != end)
+ {
+
+ temp = *number + (((uint32)remainder)<<16);
+
+ *r = static_cast<uint16>(temp/value);
+ remainder = static_cast<uint16>(temp%value);
+
+ --number;
+ --r;
+ }
+
+ rem = remainder;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ long_add (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+ // put value into the carry part of temp
+ uint32 temp=0;
+
+ uint16* min_num; // the number with the least digits used
+ uint16* max_num; // the number with the most digits used
+ uint16* min_end; // one past the end of min_num
+ uint16* max_end; // one past the end of max_num
+ uint16* r = result->number;
+
+ uint32 max_digits_used;
+ if (lhs->digits_used < rhs->digits_used)
+ {
+ max_digits_used = rhs->digits_used;
+ min_num = lhs->number;
+ max_num = rhs->number;
+ min_end = min_num + lhs->digits_used;
+ max_end = max_num + rhs->digits_used;
+ }
+ else
+ {
+ max_digits_used = lhs->digits_used;
+ min_num = rhs->number;
+ max_num = lhs->number;
+ min_end = min_num + rhs->digits_used;
+ max_end = max_num + lhs->digits_used;
+ }
+
+
+
+
+ while (min_num != min_end)
+ {
+ // add *min_num, *max_num and the current carry
+ temp = *min_num + *max_num + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++min_num;
+ ++max_num;
+ ++r;
+ }
+
+
+ while (max_num != max_end)
+ {
+ // add *max_num and the current carry
+ temp = *max_num + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++max_num;
+ ++r;
+ }
+
+ // check if there was a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = max_digits_used + 1;
+ // put the carry into the most significant digit in the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = max_digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ long_sub (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+
+
+ const uint16* number1 = lhs->number;
+ const uint16* number2 = rhs->number;
+ const uint16* end = number2 + rhs->digits_used;
+ uint16* r = result->number;
+
+
+
+ uint32 temp =0;
+
+
+ while (number2 != end)
+ {
+
+ // subtract *number2 from *number1 and then subtract any carry
+ temp = *number1 - *number2 - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number1;
+ ++number2;
+ ++r;
+ }
+
+ end = lhs->number + lhs->digits_used;
+ while (number1 != end)
+ {
+
+ // subtract the carry from *number1
+ temp = *number1 - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number1;
+ ++r;
+ }
+
+ result->digits_used = lhs->digits_used;
+ // adjust the number of digits used appropriately
+ --r;
+ while (*r == 0 && result->digits_used > 1)
+ {
+ --r;
+ --result->digits_used;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ long_div (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result,
+ data_record* remainder
+ ) const
+ {
+ // zero result
+ result->digits_used = 1;
+ *(result->number) = 0;
+
+ uint16* a;
+ uint16* b;
+ uint16* end;
+
+ // copy lhs into remainder
+ remainder->digits_used = lhs->digits_used;
+ a = remainder->number;
+ end = a + remainder->digits_used;
+ b = lhs->number;
+ while (a != end)
+ {
+ *a = *b;
+ ++a;
+ ++b;
+ }
+
+
+ // if rhs is bigger than lhs then result == 0 and remainder == lhs
+ // so then we can quit right now
+ if (is_less_than(lhs,rhs))
+ {
+ return;
+ }
+
+
+ // make a temporary number
+ data_record temp(lhs->digits_used + slack);
+
+
+ // shift rhs left until it is one shift away from being larger than lhs and
+ // put the number of left shifts necessary into shifts
+ uint32 shifts;
+ shifts = (lhs->digits_used - rhs->digits_used) * 16;
+
+ shift_left(rhs,&temp,shifts);
+
+
+ // while (lhs > temp)
+ while (is_less_than(&temp,lhs))
+ {
+ shift_left(&temp,&temp,1);
+ ++shifts;
+ }
+ // make sure lhs isn't smaller than temp
+ while (is_less_than(lhs,&temp))
+ {
+ shift_right(&temp,&temp);
+ --shifts;
+ }
+
+
+
+ // we want to execute the loop shifts +1 times
+ ++shifts;
+ while (shifts != 0)
+ {
+ shift_left(result,result,1);
+ // if (temp <= remainder)
+ if (!is_less_than(remainder,&temp))
+ {
+ long_sub(remainder,&temp,remainder);
+
+ // increment result
+ uint16* r = result->number;
+ uint16* end = r + result->digits_used;
+ while (true)
+ {
+ ++(*r);
+ // if there was no carry then we are done
+ if (*r != 0)
+ break;
+
+ ++r;
+
+ // if we hit the end of r and there is still a carry then
+ // the next digit of r is 1 and there is one more digit used
+ if (r == end)
+ {
+ *r = 1;
+ ++(result->digits_used);
+ break;
+ }
+ }
+ }
+ shift_right(&temp,&temp);
+ --shifts;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ long_mul (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+ // make result be zero
+ result->digits_used = 1;
+ *(result->number) = 0;
+
+
+ const data_record* aa;
+ const data_record* bb;
+
+ if (lhs->digits_used < rhs->digits_used)
+ {
+ // make copies of lhs and rhs and give them an appropriate amount of
+ // extra memory so there won't be any overflows
+ aa = lhs;
+ bb = rhs;
+ }
+ else
+ {
+ // make copies of lhs and rhs and give them an appropriate amount of
+ // extra memory so there won't be any overflows
+ aa = rhs;
+ bb = lhs;
+ }
+ // this is where we actually copy lhs and rhs
+ data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs
+
+
+ uint32 shift_value = 0;
+ uint16* anum = aa->number;
+ uint16* end = anum + aa->digits_used;
+ while (anum != end )
+ {
+ uint16 bit = 0x0001;
+
+ for (int i = 0; i < 16; ++i)
+ {
+ // if the specified bit of a is 1
+ if ((*anum & bit) != 0)
+ {
+ shift_left(&b,&b,shift_value);
+ shift_value = 0;
+ long_add(&b,result,result);
+ }
+ ++shift_value;
+ bit <<= 1;
+ }
+
+ ++anum;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ shift_left (
+ const data_record* data,
+ data_record* result,
+ uint32 shift_amount
+ ) const
+ {
+ uint32 offset = shift_amount/16;
+ shift_amount &= 0xf; // same as shift_amount %= 16;
+
+ uint16* r = result->number + data->digits_used + offset; // result
+ uint16* end = data->number;
+ uint16* s = end + data->digits_used; // source
+ const uint32 temp = 16 - shift_amount;
+
+ *r = (*(--s) >> temp);
+ // set the number of digits used in the result
+ // if the upper bits from *s were zero then don't count this first word
+ if (*r == 0)
+ {
+ result->digits_used = data->digits_used + offset;
+ }
+ else
+ {
+ result->digits_used = data->digits_used + offset + 1;
+ }
+ --r;
+
+ while (s != end)
+ {
+ *r = ((*s << shift_amount) | ( *(s-1) >> temp));
+ --r;
+ --s;
+ }
+ *r = *s << shift_amount;
+
+ // now zero the rest of the result
+ end = result->number;
+ while (r != end)
+ *(--r) = 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ shift_right (
+ const data_record* data,
+ data_record* result
+ ) const
+ {
+
+ uint16* r = result->number; // result
+ uint16* s = data->number; // source
+ uint16* end = s + data->digits_used - 1;
+
+ while (s != end)
+ {
+ *r = (*s >> 1) | (*(s+1) << 15);
+ ++r;
+ ++s;
+ }
+ *r = *s >> 1;
+
+
+ // calculate the new number for digits_used
+ if (*r == 0)
+ {
+ if (data->digits_used != 1)
+ result->digits_used = data->digits_used - 1;
+ else
+ result->digits_used = 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_1::
+ is_less_than (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const
+ {
+ uint32 lhs_digits_used = lhs->digits_used;
+ uint32 rhs_digits_used = rhs->digits_used;
+
+ // if lhs is definitely less than rhs
+ if (lhs_digits_used < rhs_digits_used )
+ return true;
+ // if lhs is definitely greater than rhs
+ else if (lhs_digits_used > rhs_digits_used)
+ return false;
+ else
+ {
+ uint16* end = lhs->number;
+ uint16* l = end + lhs_digits_used;
+ uint16* r = rhs->number + rhs_digits_used;
+
+ while (l != end)
+ {
+ --l;
+ --r;
+ if (*l < *r)
+ return true;
+ else if (*l > *r)
+ return false;
+ }
+
+ // at this point we know that they are equal
+ return false;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_1::
+ is_equal_to (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const
+ {
+ // if lhs and rhs are definitely not equal
+ if (lhs->digits_used != rhs->digits_used )
+ {
+ return false;
+ }
+ else
+ {
+ uint16* l = lhs->number;
+ uint16* r = rhs->number;
+ uint16* end = l + lhs->digits_used;
+
+ while (l != end)
+ {
+ if (*l != *r)
+ return false;
+ ++l;
+ ++r;
+ }
+
+ // at this point we know that they are equal
+ return true;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ increment (
+ const data_record* source,
+ data_record* dest
+ ) const
+ {
+ uint16* s = source->number;
+ uint16* d = dest->number;
+ uint16* end = s + source->digits_used;
+ while (true)
+ {
+ *d = *s + 1;
+
+ // if there was no carry then break out of the loop
+ if (*d != 0)
+ {
+ dest->digits_used = source->digits_used;
+
+ // copy the rest of the digits over to d
+ ++d; ++s;
+ while (s != end)
+ {
+ *d = *s;
+ ++d;
+ ++s;
+ }
+
+ break;
+ }
+
+
+ ++s;
+
+ // if we have hit the end of s and there was a carry up to this point
+ // then just make the next digit 1 and add one to the digits used
+ if (s == end)
+ {
+ ++d;
+ dest->digits_used = source->digits_used + 1;
+ *d = 1;
+ break;
+ }
+
+ ++d;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_1::
+ decrement (
+ const data_record* source,
+ data_record* dest
+ ) const
+ {
+ uint16* s = source->number;
+ uint16* d = dest->number;
+ uint16* end = s + source->digits_used;
+ while (true)
+ {
+ *d = *s - 1;
+
+ // if there was no carry then break out of the loop
+ if (*d != 0xFFFF)
+ {
+ // if we lost a digit in the subtraction
+ if (*d == 0 && s+1 == end)
+ {
+ if (source->digits_used == 1)
+ dest->digits_used = 1;
+ else
+ dest->digits_used = source->digits_used - 1;
+ }
+ else
+ {
+ dest->digits_used = source->digits_used;
+ }
+ break;
+ }
+ else
+ {
+ ++d;
+ ++s;
+ }
+
+ }
+
+ // copy the rest of the digits over to d
+ ++d;
+ ++s;
+ while (s != end)
+ {
+ *d = *s;
+ ++d;
+ ++s;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_BIGINT_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_1.h b/ml/dlib/dlib/bigint/bigint_kernel_1.h
new file mode 100644
index 000000000..3e7f3d851
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_1.h
@@ -0,0 +1,544 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINT_KERNEl_1_
+#define DLIB_BIGINT_KERNEl_1_
+
+#include "bigint_kernel_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../uintn.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+
+ class bigint_kernel_1
+ {
+ /*!
+ INITIAL VALUE
+ slack == 25
+ data->number[0] == 0
+ data->size == slack
+ data->references == 1
+ data->digits_used == 1
+
+
+ CONVENTION
+ slack == the number of extra digits placed into the number when it is
+ created. the slack value should never be less than 1
+
+ data->number == pointer to an array of data->size uint16s.
+ data represents a string of base 65535 numbers with data[0] being
+ the least significant bit and data[data->digits_used-1] being the most
+ significant
+
+
+ NOTE: In the comments I will consider a word to be a 16 bit value
+
+
+ data->digits_used == the number of significant digits in the number.
+ data->digits_used tells us the number of used elements in the
+ data->number array so everything beyond data->number[data->digits_used-1]
+ is undefined
+
+ data->references == the number of bigint_kernel_1 objects which refer
+ to this data_record
+
+
+
+ !*/
+
+
+ struct data_record
+ {
+
+
+ explicit data_record(
+ uint32 size_
+ ) :
+ size(size_),
+ number(new uint16[size_]),
+ references(1),
+ digits_used(1)
+ {*number = 0;}
+ /*!
+ ensures
+ - initializes *this to represent zero
+ !*/
+
+ data_record(
+ const data_record& item,
+ uint32 additional_size
+ ) :
+ size(item.digits_used + additional_size),
+ number(new uint16[size]),
+ references(1),
+ digits_used(item.digits_used)
+ {
+ uint16* source = item.number;
+ uint16* dest = number;
+ uint16* end = source + digits_used;
+ while (source != end)
+ {
+ *dest = *source;
+ ++dest;
+ ++source;
+ }
+ }
+ /*!
+ ensures
+ - *this is a copy of item except with
+ size == item.digits_used + additional_size
+ !*/
+
+ ~data_record(
+ )
+ {
+ delete [] number;
+ }
+
+
+ const uint32 size;
+ uint16* number;
+ uint32 references;
+ uint32 digits_used;
+
+ private:
+ // no copy constructor
+ data_record ( data_record&);
+ };
+
+
+
+ // note that the second parameter is just there
+ // to resolve the ambiguity between this constructor and
+ // bigint_kernel_1(uint32)
+ explicit bigint_kernel_1 (
+ data_record* data_, int
+ ): slack(25),data(data_) {}
+ /*!
+ ensures
+ - *this is initialized with data_ as its data member
+ !*/
+
+
+ public:
+
+ bigint_kernel_1 (
+ );
+
+ bigint_kernel_1 (
+ uint32 value
+ );
+
+ bigint_kernel_1 (
+ const bigint_kernel_1& item
+ );
+
+ virtual ~bigint_kernel_1 (
+ );
+
+ const bigint_kernel_1 operator+ (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator+= (
+ const bigint_kernel_1& rhs
+ );
+
+ const bigint_kernel_1 operator- (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator-= (
+ const bigint_kernel_1& rhs
+ );
+
+ const bigint_kernel_1 operator* (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator*= (
+ const bigint_kernel_1& rhs
+ );
+
+ const bigint_kernel_1 operator/ (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator/= (
+ const bigint_kernel_1& rhs
+ );
+
+ const bigint_kernel_1 operator% (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator%= (
+ const bigint_kernel_1& rhs
+ );
+
+ bool operator < (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bool operator == (
+ const bigint_kernel_1& rhs
+ ) const;
+
+ bigint_kernel_1& operator= (
+ const bigint_kernel_1& rhs
+ );
+
+ friend std::ostream& operator<< (
+ std::ostream& out,
+ const bigint_kernel_1& rhs
+ );
+
+ friend std::istream& operator>> (
+ std::istream& in,
+ bigint_kernel_1& rhs
+ );
+
+ bigint_kernel_1& operator++ (
+ );
+
+ const bigint_kernel_1 operator++ (
+ int
+ );
+
+ bigint_kernel_1& operator-- (
+ );
+
+ const bigint_kernel_1 operator-- (
+ int
+ );
+
+ friend const bigint_kernel_1 operator+ (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend const bigint_kernel_1 operator+ (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_1& operator+= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_1 operator- (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend const bigint_kernel_1 operator- (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_1& operator-= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_1 operator* (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend const bigint_kernel_1 operator* (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_1& operator*= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_1 operator/ (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend const bigint_kernel_1 operator/ (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_1& operator/= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_1 operator% (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend const bigint_kernel_1 operator% (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_1& operator%= (
+ uint16 rhs
+ );
+
+ friend bool operator < (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ friend bool operator < (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ friend bool operator == (
+ const bigint_kernel_1& lhs,
+ uint16 rhs
+ );
+
+ friend bool operator == (
+ uint16 lhs,
+ const bigint_kernel_1& rhs
+ );
+
+ bigint_kernel_1& operator= (
+ uint16 rhs
+ );
+
+
+ void swap (
+ bigint_kernel_1& item
+ ) { data_record* temp = data; data = item.data; item.data = temp; }
+
+
+ private:
+
+ void long_add (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= max(lhs->digits_used,rhs->digits_used) + 1
+ ensures
+ - result == lhs + rhs
+ !*/
+
+ void long_sub (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - lhs >= rhs
+ - result->size >= lhs->digits_used
+ ensures
+ - result == lhs - rhs
+ !*/
+
+ void long_div (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result,
+ data_record* remainder
+ ) const;
+ /*!
+ requires
+ - rhs != 0
+ - result->size >= lhs->digits_used
+ - remainder->size >= lhs->digits_used
+ - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.)
+ ensures
+ - result == lhs / rhs
+ - remainder == lhs % rhs
+ !*/
+
+ void long_mul (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= lhs->digits_used + rhs->digits_used
+ - result != lhs
+ - result != rhs
+ ensures
+ - result == lhs * rhs
+ !*/
+
+ void short_add (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->size + 1
+ ensures
+ - result == data + value
+ !*/
+
+ void short_sub (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - data >= value
+ - result->size >= data->digits_used
+ ensures
+ - result == data - value
+ !*/
+
+ void short_mul (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used + 1
+ ensures
+ - result == data * value
+ !*/
+
+ void short_div (
+ const data_record* data,
+ uint16 value,
+ data_record* result,
+ uint16& remainder
+ ) const;
+ /*!
+ requires
+ - value != 0
+ - result->size >= data->digits_used
+ ensures
+ - result = data*value
+ - remainder = data%value
+ !*/
+
+ void shift_left (
+ const data_record* data,
+ data_record* result,
+ uint32 shift_amount
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used + shift_amount/8 + 1
+ ensures
+ - result == data << shift_amount
+ !*/
+
+ void shift_right (
+ const data_record* data,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used
+ ensures
+ - result == data >> 1
+ !*/
+
+ bool is_less_than (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if lhs < rhs
+ - returns false otherwise
+ !*/
+
+ bool is_equal_to (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if lhs == rhs
+ - returns false otherwise
+ !*/
+
+ void increment (
+ const data_record* source,
+ data_record* dest
+ ) const;
+ /*!
+ requires
+ - dest->size >= source->digits_used + 1
+ ensures
+ - dest = source + 1
+ !*/
+
+ void decrement (
+ const data_record* source,
+ data_record* dest
+ ) const;
+ /*!
+ requires
+ source != 0
+ ensuers
+ dest = source - 1
+ !*/
+
+ // member data
+ const uint32 slack;
+ data_record* data;
+
+
+
+ };
+
+ inline void swap (
+ bigint_kernel_1& a,
+ bigint_kernel_1& b
+ ) { a.swap(b); }
+
+ inline void serialize (
+ const bigint_kernel_1& item,
+ std::ostream& out
+ )
+ {
+ std::ios::fmtflags oldflags = out.flags();
+ out.flags();
+ out << item << ' ';
+ out.flags(oldflags);
+ if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c");
+ }
+
+ inline void deserialize (
+ bigint_kernel_1& item,
+ std::istream& in
+ )
+ {
+ std::ios::fmtflags oldflags = in.flags();
+ in.flags();
+ in >> item; in.flags(oldflags);
+ if (in.get() != ' ')
+ {
+ item = 0;
+ throw serialization_error("Error deserializing object of type bigint_kernel_c");
+ }
+ }
+
+ inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; }
+ inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); }
+ inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); }
+ inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); }
+}
+
+#ifdef NO_MAKEFILE
+#include "bigint_kernel_1.cpp"
+#endif
+
+#endif // DLIB_BIGINT_KERNEl_1_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_2.cpp b/ml/dlib/dlib/bigint/bigint_kernel_2.cpp
new file mode 100644
index 000000000..005e080af
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_2.cpp
@@ -0,0 +1,1945 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINT_KERNEL_2_CPp_
+#define DLIB_BIGINT_KERNEL_2_CPp_
+#include "bigint_kernel_2.h"
+
+#include <iostream>
+#include <cmath>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member/friend function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2::
+ bigint_kernel_2 (
+ ) :
+ slack(25),
+ data(new data_record(slack))
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2::
+ bigint_kernel_2 (
+ uint32 value
+ ) :
+ slack(25),
+ data(new data_record(slack))
+ {
+ *(data->number) = static_cast<uint16>(value&0xFFFF);
+ *(data->number+1) = static_cast<uint16>((value>>16)&0xFFFF);
+ if (*(data->number+1) != 0)
+ data->digits_used = 2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2::
+ bigint_kernel_2 (
+ const bigint_kernel_2& item
+ ) :
+ slack(25),
+ data(item.data)
+ {
+ data->references += 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2::
+ ~bigint_kernel_2 (
+ )
+ {
+ if (data->references == 1)
+ {
+ delete data;
+ }
+ else
+ {
+ data->references -= 1;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator+ (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ std::max(rhs.data->digits_used,data->digits_used) + slack
+ );
+ long_add(data,rhs.data,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator+= (
+ const bigint_kernel_2& rhs
+ )
+ {
+ // if there are other references to our data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
+ data->references -= 1;
+ long_add(data,rhs.data,temp);
+ data = temp;
+ }
+ // if data is not big enough for the result
+ else if (data->size <= std::max(data->digits_used,rhs.data->digits_used))
+ {
+ data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
+ long_add(data,rhs.data,temp);
+ delete data;
+ data = temp;
+ }
+ // there is enough size and no references
+ else
+ {
+ long_add(data,rhs.data,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator- (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ data->digits_used + slack
+ );
+ long_sub(data,rhs.data,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator-= (
+ const bigint_kernel_2& rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ long_sub(data,rhs.data,temp);
+ data = temp;
+ }
+ else
+ {
+ long_sub(data,rhs.data,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator* (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ data_record* temp = new data_record (
+ data->digits_used + rhs.data->digits_used + slack
+ );
+ long_mul(data,rhs.data,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator*= (
+ const bigint_kernel_2& rhs
+ )
+ {
+ // create a data_record to store the result of the multiplication in
+ data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack);
+ long_mul(data,rhs.data,temp);
+
+ // if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ else
+ {
+ delete data;
+ }
+ data = temp;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator/ (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+ delete remainder;
+
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator/= (
+ const bigint_kernel_2& rhs
+ )
+ {
+
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+
+ // check if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ // if there are no references to data then it must be deleted
+ else
+ {
+ delete data;
+ }
+ data = temp;
+ delete remainder;
+
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator% (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+ delete temp;
+ return bigint_kernel_2(remainder,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator%= (
+ const bigint_kernel_2& rhs
+ )
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data_record* remainder;
+ try {
+ remainder = new data_record(data->digits_used+slack);
+ } catch (...) { delete temp; throw; }
+
+ long_div(data,rhs.data,temp,remainder);
+
+ // check if there are other references to data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ }
+ // if there are no references to data then it must be deleted
+ else
+ {
+ delete data;
+ }
+ data = remainder;
+ delete temp;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_2::
+ operator < (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ return is_less_than(data,rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_2::
+ operator == (
+ const bigint_kernel_2& rhs
+ ) const
+ {
+ return is_equal_to(data,rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator= (
+ const bigint_kernel_2& rhs
+ )
+ {
+ if (this == &rhs)
+ return *this;
+
+ // if we have the only reference to our data then delete it
+ if (data->references == 1)
+ {
+ delete data;
+ data = rhs.data;
+ data->references += 1;
+ }
+ else
+ {
+ data->references -= 1;
+ data = rhs.data;
+ data->references += 1;
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out_,
+ const bigint_kernel_2& rhs
+ )
+ {
+ std::ostream out(out_.rdbuf());
+
+ typedef bigint_kernel_2 bigint;
+
+ bigint::data_record* temp = new bigint::data_record(*rhs.data,0);
+
+
+
+ // get a char array big enough to hold the number in ascii format
+ char* str;
+ try {
+ str = new char[(rhs.data->digits_used)*5+10];
+ } catch (...) { delete temp; throw; }
+
+ char* str_start = str;
+ str += (rhs.data->digits_used)*5+9;
+ *str = 0; --str;
+
+
+ uint16 remainder;
+ rhs.short_div(temp,10000,temp,remainder);
+
+ // pull the digits out of remainder
+ char a = remainder % 10 + '0';
+ remainder /= 10;
+ char b = remainder % 10 + '0';
+ remainder /= 10;
+ char c = remainder % 10 + '0';
+ remainder /= 10;
+ char d = remainder % 10 + '0';
+ remainder /= 10;
+
+
+ *str = a; --str;
+ *str = b; --str;
+ *str = c; --str;
+ *str = d; --str;
+
+
+ // keep looping until temp represents zero
+ while (temp->digits_used != 1 || *(temp->number) != 0)
+ {
+ rhs.short_div(temp,10000,temp,remainder);
+
+ // pull the digits out of remainder
+ char a = remainder % 10 + '0';
+ remainder /= 10;
+ char b = remainder % 10 + '0';
+ remainder /= 10;
+ char c = remainder % 10 + '0';
+ remainder /= 10;
+ char d = remainder % 10 + '0';
+ remainder /= 10;
+
+ *str = a; --str;
+ *str = b; --str;
+ *str = c; --str;
+ *str = d; --str;
+ }
+
+ // throw away and extra leading zeros
+ ++str;
+ if (*str == '0')
+ ++str;
+ if (*str == '0')
+ ++str;
+ if (*str == '0')
+ ++str;
+
+
+
+
+ out << str;
+ delete [] str_start;
+ delete temp;
+ return out_;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& operator>> (
+ std::istream& in_,
+ bigint_kernel_2& rhs
+ )
+ {
+ std::istream in(in_.rdbuf());
+
+ // ignore any leading whitespaces
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n')
+ {
+ in.get();
+ }
+
+ // if the first digit is not an integer then this is an error
+ if ( !(in.peek() >= '0' && in.peek() <= '9'))
+ {
+ in_.clear(std::ios::failbit);
+ return in_;
+ }
+
+ int num_read;
+ bigint_kernel_2 temp;
+ do
+ {
+
+ // try to get 4 chars from in
+ num_read = 1;
+ char a = 0;
+ char b = 0;
+ char c = 0;
+ char d = 0;
+
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ a = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ b = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ c = in.get();
+ }
+ if (in.peek() >= '0' && in.peek() <= '9')
+ {
+ num_read *= 10;
+ d = in.get();
+ }
+
+ // merge the for digits into an uint16
+ uint16 num = 0;
+ if (a != 0)
+ {
+ num = a - '0';
+ }
+ if (b != 0)
+ {
+ num *= 10;
+ num += b - '0';
+ }
+ if (c != 0)
+ {
+ num *= 10;
+ num += c - '0';
+ }
+ if (d != 0)
+ {
+ num *= 10;
+ num += d - '0';
+ }
+
+
+ if (num_read != 1)
+ {
+ // shift the digits in temp left by the number of new digits we just read
+ temp *= num_read;
+ // add in new digits
+ temp += num;
+ }
+
+ } while (num_read == 10000);
+
+
+ rhs = temp;
+ return in_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator+ (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (rhs.data->digits_used+rhs.slack);
+
+ rhs.short_add(rhs.data,lhs,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator+ (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_add(lhs.data,rhs,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator+= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_add(data,rhs,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data then do so
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ short_add(data,rhs,temp);
+ delete data;
+ data = temp;
+ }
+ // or if there is plenty of space and no references
+ else
+ {
+ short_add(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator- (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ *(temp->number) = lhs - *(rhs.data->number);
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator- (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_sub(lhs.data,rhs,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator-= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_sub(data,rhs,temp);
+ data = temp;
+ }
+ else
+ {
+ short_sub(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator* (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (rhs.data->digits_used+rhs.slack);
+
+ rhs.short_mul(rhs.data,lhs,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator* (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ lhs.short_mul(lhs.data,rhs,temp);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator*= (
+ uint16 rhs
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_mul(data,rhs,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ short_mul(data,rhs,temp);
+ delete data;
+ data = temp;
+ }
+ else
+ {
+ short_mul(data,rhs,data);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator/ (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ // if rhs might not be bigger than lhs
+ if (rhs.data->digits_used == 1)
+ {
+ *(temp->number) = lhs/ *(rhs.data->number);
+ }
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator/ (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record
+ (lhs.data->digits_used+lhs.slack);
+
+ uint16 remainder;
+ lhs.short_div(lhs.data,rhs,temp,remainder);
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator/= (
+ uint16 rhs
+ )
+ {
+ uint16 remainder;
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_div(data,rhs,temp,remainder);
+ data = temp;
+ }
+ else
+ {
+ short_div(data,rhs,data,remainder);
+ }
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator% (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ // temp is zero by default
+ bigint::data_record* temp = new bigint::data_record(rhs.slack);
+
+ if (rhs.data->digits_used == 1)
+ {
+ // if rhs is just an uint16 inside then perform the modulus
+ *(temp->number) = lhs % *(rhs.data->number);
+ }
+ else
+ {
+ // if rhs is bigger than lhs then the answer is lhs
+ *(temp->number) = lhs;
+ }
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 operator% (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ typedef bigint_kernel_2 bigint;
+ bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack);
+
+ uint16 remainder;
+
+ lhs.short_div(lhs.data,rhs,temp,remainder);
+ temp->digits_used = 1;
+ *(temp->number) = remainder;
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator%= (
+ uint16 rhs
+ )
+ {
+ uint16 remainder;
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ short_div(data,rhs,temp,remainder);
+ data = temp;
+ }
+ else
+ {
+ short_div(data,rhs,data,remainder);
+ }
+
+ data->digits_used = 1;
+ *(data->number) = remainder;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator < (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator < (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator == (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ )
+ {
+ return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool operator == (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ )
+ {
+ return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator= (
+ uint16 rhs
+ )
+ {
+ // check if there are other references to our data
+ if (data->references != 1)
+ {
+ data->references -= 1;
+ try {
+ data = new data_record(slack);
+ } catch (...) { data->references += 1; throw; }
+ }
+ else
+ {
+ data->digits_used = 1;
+ }
+
+ *(data->number) = rhs;
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator++ (
+ )
+ {
+ // if there are other references to this data then make a copy of it
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ increment(data,temp);
+ data = temp;
+ }
+ // or if we need to enlarge data then do so
+ else if (data->digits_used == data->size)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ increment(data,temp);
+ delete data;
+ data = temp;
+ }
+ else
+ {
+ increment(data,data);
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator++ (
+ int
+ )
+ {
+ data_record* temp; // this is the copy of temp we will return in the end
+
+ data_record* temp2 = new data_record(data->digits_used+slack);
+ increment(data,temp2);
+
+ temp = data;
+ data = temp2;
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bigint_kernel_2& bigint_kernel_2::
+ operator-- (
+ )
+ {
+ // if there are other references to this data
+ if (data->references != 1)
+ {
+ data_record* temp = new data_record(data->digits_used+slack);
+ data->references -= 1;
+ decrement(data,temp);
+ data = temp;
+ }
+ else
+ {
+ decrement(data,data);
+ }
+
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const bigint_kernel_2 bigint_kernel_2::
+ operator-- (
+ int
+ )
+ {
+ data_record* temp; // this is the copy of temp we will return in the end
+
+ data_record* temp2 = new data_record(data->digits_used+slack);
+ decrement(data,temp2);
+
+ temp = data;
+ data = temp2;
+
+ return bigint_kernel_2(temp,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ short_add (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+ // put value into the carry part of temp
+ uint32 temp = value;
+ temp <<= 16;
+
+
+ const uint16* number = data->number;
+ const uint16* end = number + data->digits_used; // one past the end of number
+ uint16* r = result->number;
+
+ while (number != end)
+ {
+ // add *number and the current carry
+ temp = *number + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number;
+ ++r;
+ }
+
+ // if there is a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = data->digits_used + 1;
+ // store the carry in the most significant digit of the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ short_sub (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+
+
+ const uint16* number = data->number;
+ const uint16* end = number + data->digits_used - 1;
+ uint16* r = result->number;
+
+ uint32 temp = *number - value;
+
+ // put the low word of temp into *data
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+
+ while (number != end)
+ {
+ ++number;
+ ++r;
+
+ // subtract the carry from *number
+ temp = *number - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+ }
+
+ // if we lost a digit in the subtraction
+ if (*r == 0)
+ {
+ if (data->digits_used == 1)
+ result->digits_used = 1;
+ else
+ result->digits_used = data->digits_used - 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ short_mul (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const
+ {
+
+ uint32 temp = 0;
+
+
+ const uint16* number = data->number;
+ uint16* r = result->number;
+ const uint16* end = r + data->digits_used;
+
+
+
+ while ( r != end)
+ {
+
+ // multiply *data and value and add in the carry
+ temp = *number*(uint32)value + (temp>>16);
+
+ // put the low word of temp into *data
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number;
+ ++r;
+ }
+
+ // if there is a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = data->digits_used + 1;
+ // put the final carry into the most significant digit of the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ short_div (
+ const data_record* data,
+ uint16 value,
+ data_record* result,
+ uint16& rem
+ ) const
+ {
+
+ uint16 remainder = 0;
+ uint32 temp;
+
+
+
+ const uint16* number = data->number + data->digits_used - 1;
+ const uint16* end = number - data->digits_used;
+ uint16* r = result->number + data->digits_used - 1;
+
+
+ // if we are losing a digit in this division
+ if (*number < value)
+ {
+ if (data->digits_used == 1)
+ result->digits_used = 1;
+ else
+ result->digits_used = data->digits_used - 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ // perform the actual division
+ while (number != end)
+ {
+
+ temp = *number + (((uint32)remainder)<<16);
+
+ *r = static_cast<uint16>(temp/value);
+ remainder = static_cast<uint16>(temp%value);
+
+ --number;
+ --r;
+ }
+
+ rem = remainder;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ long_add (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+ // put value into the carry part of temp
+ uint32 temp=0;
+
+ uint16* min_num; // the number with the least digits used
+ uint16* max_num; // the number with the most digits used
+ uint16* min_end; // one past the end of min_num
+ uint16* max_end; // one past the end of max_num
+ uint16* r = result->number;
+
+ uint32 max_digits_used;
+ if (lhs->digits_used < rhs->digits_used)
+ {
+ max_digits_used = rhs->digits_used;
+ min_num = lhs->number;
+ max_num = rhs->number;
+ min_end = min_num + lhs->digits_used;
+ max_end = max_num + rhs->digits_used;
+ }
+ else
+ {
+ max_digits_used = lhs->digits_used;
+ min_num = rhs->number;
+ max_num = lhs->number;
+ min_end = min_num + rhs->digits_used;
+ max_end = max_num + lhs->digits_used;
+ }
+
+
+
+
+ while (min_num != min_end)
+ {
+ // add *min_num, *max_num and the current carry
+ temp = *min_num + *max_num + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++min_num;
+ ++max_num;
+ ++r;
+ }
+
+
+ while (max_num != max_end)
+ {
+ // add *max_num and the current carry
+ temp = *max_num + (temp>>16);
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++max_num;
+ ++r;
+ }
+
+ // check if there was a final carry
+ if ((temp>>16) != 0)
+ {
+ result->digits_used = max_digits_used + 1;
+ // put the carry into the most significant digit in the result
+ *r = static_cast<uint16>(temp>>16);
+ }
+ else
+ {
+ result->digits_used = max_digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ long_sub (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+
+
+ const uint16* number1 = lhs->number;
+ const uint16* number2 = rhs->number;
+ const uint16* end = number2 + rhs->digits_used;
+ uint16* r = result->number;
+
+
+
+ uint32 temp =0;
+
+
+ while (number2 != end)
+ {
+
+ // subtract *number2 from *number1 and then subtract any carry
+ temp = *number1 - *number2 - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number1;
+ ++number2;
+ ++r;
+ }
+
+ end = lhs->number + lhs->digits_used;
+ while (number1 != end)
+ {
+
+ // subtract the carry from *number1
+ temp = *number1 - (temp>>31);
+
+ // put the low word of temp into *r
+ *r = static_cast<uint16>(temp & 0xFFFF);
+
+ ++number1;
+ ++r;
+ }
+
+ result->digits_used = lhs->digits_used;
+ // adjust the number of digits used appropriately
+ --r;
+ while (*r == 0 && result->digits_used > 1)
+ {
+ --r;
+ --result->digits_used;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ long_div (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result,
+ data_record* remainder
+ ) const
+ {
+ // zero result
+ result->digits_used = 1;
+ *(result->number) = 0;
+
+ uint16* a;
+ uint16* b;
+ uint16* end;
+
+ // copy lhs into remainder
+ remainder->digits_used = lhs->digits_used;
+ a = remainder->number;
+ end = a + remainder->digits_used;
+ b = lhs->number;
+ while (a != end)
+ {
+ *a = *b;
+ ++a;
+ ++b;
+ }
+
+
+ // if rhs is bigger than lhs then result == 0 and remainder == lhs
+ // so then we can quit right now
+ if (is_less_than(lhs,rhs))
+ {
+ return;
+ }
+
+
+ // make a temporary number
+ data_record temp(lhs->digits_used + slack);
+
+
+ // shift rhs left until it is one shift away from being larger than lhs and
+ // put the number of left shifts necessary into shifts
+ uint32 shifts;
+ shifts = (lhs->digits_used - rhs->digits_used) * 16;
+
+ shift_left(rhs,&temp,shifts);
+
+
+ // while (lhs > temp)
+ while (is_less_than(&temp,lhs))
+ {
+ shift_left(&temp,&temp,1);
+ ++shifts;
+ }
+ // make sure lhs isn't smaller than temp
+ while (is_less_than(lhs,&temp))
+ {
+ shift_right(&temp,&temp);
+ --shifts;
+ }
+
+
+
+ // we want to execute the loop shifts +1 times
+ ++shifts;
+ while (shifts != 0)
+ {
+ shift_left(result,result,1);
+ // if (temp <= remainder)
+ if (!is_less_than(remainder,&temp))
+ {
+ long_sub(remainder,&temp,remainder);
+
+ // increment result
+ uint16* r = result->number;
+ uint16* end = r + result->digits_used;
+ while (true)
+ {
+ ++(*r);
+ // if there was no carry then we are done
+ if (*r != 0)
+ break;
+
+ ++r;
+
+ // if we hit the end of r and there is still a carry then
+ // the next digit of r is 1 and there is one more digit used
+ if (r == end)
+ {
+ *r = 1;
+ ++(result->digits_used);
+ break;
+ }
+ }
+ }
+ shift_right(&temp,&temp);
+ --shifts;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ long_mul (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const
+ {
+ // if one of the numbers is small then use this simple but O(n^2) algorithm
+ if (std::min(lhs->digits_used, rhs->digits_used) < 10)
+ {
+ // make result be zero
+ result->digits_used = 1;
+ *(result->number) = 0;
+
+
+ const data_record* aa;
+ const data_record* bb;
+
+ if (lhs->digits_used < rhs->digits_used)
+ {
+ // make copies of lhs and rhs and give them an appropriate amount of
+ // extra memory so there won't be any overflows
+ aa = lhs;
+ bb = rhs;
+ }
+ else
+ {
+ // make copies of lhs and rhs and give them an appropriate amount of
+ // extra memory so there won't be any overflows
+ aa = rhs;
+ bb = lhs;
+ }
+
+ // copy the larger(approximately) of lhs and rhs into b
+ data_record b(*bb,aa->digits_used+slack);
+
+
+ uint32 shift_value = 0;
+ uint16* anum = aa->number;
+ uint16* end = anum + aa->digits_used;
+ while (anum != end )
+ {
+ uint16 bit = 0x0001;
+
+ for (int i = 0; i < 16; ++i)
+ {
+ // if the specified bit of a is 1
+ if ((*anum & bit) != 0)
+ {
+ shift_left(&b,&b,shift_value);
+ shift_value = 0;
+ long_add(&b,result,result);
+ }
+ ++shift_value;
+ bit <<= 1;
+ }
+
+ ++anum;
+ }
+ }
+ else // else if both lhs and rhs are large then use the more complex
+ // O(n*logn) algorithm
+ {
+ uint32 size = 1;
+ // make size a power of 2
+ while (size < (lhs->digits_used + rhs->digits_used)*2)
+ {
+ size *= 2;
+ }
+
+ // allocate some temporary space so we can do the FFT
+ ct* a = new ct[size];
+ ct* b; try {b = new ct[size]; } catch (...) { delete [] a; throw; }
+
+ // load lhs into the a array. We are breaking the input number into
+ // 8bit chunks for the purpose of using this fft algorithm. The reason
+ // for this is so that we have smaller numbers coming out of the final
+ // ifft. This helps avoid overflow.
+ for (uint32 i = 0; i < lhs->digits_used; ++i)
+ {
+ a[i*2] = ct((t)(lhs->number[i]&0xFF),0);
+ a[i*2+1] = ct((t)(lhs->number[i]>>8),0);
+ }
+ for (uint32 i = lhs->digits_used*2; i < size; ++i)
+ {
+ a[i] = 0;
+ }
+
+ // load rhs into the b array
+ for (uint32 i = 0; i < rhs->digits_used; ++i)
+ {
+ b[i*2] = ct((t)(rhs->number[i]&0xFF),0);
+ b[i*2+1] = ct((t)(rhs->number[i]>>8),0);
+ }
+ for (uint32 i = rhs->digits_used*2; i < size; ++i)
+ {
+ b[i] = 0;
+ }
+
+ // perform the forward fft of a and b
+ fft(a,size);
+ fft(b,size);
+
+ const double l = 1.0/size;
+
+ // do the pointwise multiply of a and b and also apply the scale
+ // factor in this loop too.
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ a[i] = l*a[i]*b[i];
+ }
+
+ // Now compute the inverse fft of the pointwise multiplication of a and b.
+ // This is basically the result. We just have to take care of any carries
+ // that should happen.
+ ifft(a,size);
+
+ // loop over the result and propagate any carries that need to take place.
+ // We will also be moving the resulting numbers into result->number at
+ // the same time.
+ uint64 carry = 0;
+ result->digits_used = 0;
+ int zeros = 0;
+ const uint32 len = lhs->digits_used + rhs->digits_used;
+ for (unsigned long i = 0; i < len; ++i)
+ {
+ uint64 num1 = static_cast<uint64>(std::floor(a[i*2].real()+0.5));
+ num1 += carry;
+ carry = 0;
+ if (num1 > 255)
+ {
+ carry = num1 >> 8;
+ num1 = (num1&0xFF);
+ }
+
+ uint64 num2 = static_cast<uint64>(std::floor(a[i*2+1].real()+0.5));
+ num2 += carry;
+ carry = 0;
+ if (num2 > 255)
+ {
+ carry = num2 >> 8;
+ num2 = (num2&0xFF);
+ }
+
+ // put the new number into its final place
+ num1 = (num2<<8) | num1;
+ result->number[i] = static_cast<uint16>(num1);
+
+ // keep track of the number of leading zeros
+ if (num1 == 0)
+ ++zeros;
+ else
+ zeros = 0;
+ ++(result->digits_used);
+ }
+
+ // adjust digits_used so that it reflects the actual number
+ // of non-zero digits in our representation.
+ result->digits_used -= zeros;
+
+ // if the result was zero then adjust the result accordingly
+ if (result->digits_used == 0)
+ {
+ // make result be zero
+ result->digits_used = 1;
+ *(result->number) = 0;
+ }
+
+ // free all the temporary buffers
+ delete [] a;
+ delete [] b;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ shift_left (
+ const data_record* data,
+ data_record* result,
+ uint32 shift_amount
+ ) const
+ {
+ uint32 offset = shift_amount/16;
+ shift_amount &= 0xf; // same as shift_amount %= 16;
+
+ uint16* r = result->number + data->digits_used + offset; // result
+ uint16* end = data->number;
+ uint16* s = end + data->digits_used; // source
+ const uint32 temp = 16 - shift_amount;
+
+ *r = (*(--s) >> temp);
+ // set the number of digits used in the result
+ // if the upper bits from *s were zero then don't count this first word
+ if (*r == 0)
+ {
+ result->digits_used = data->digits_used + offset;
+ }
+ else
+ {
+ result->digits_used = data->digits_used + offset + 1;
+ }
+ --r;
+
+ while (s != end)
+ {
+ *r = ((*s << shift_amount) | ( *(s-1) >> temp));
+ --r;
+ --s;
+ }
+ *r = *s << shift_amount;
+
+ // now zero the rest of the result
+ end = result->number;
+ while (r != end)
+ *(--r) = 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ shift_right (
+ const data_record* data,
+ data_record* result
+ ) const
+ {
+
+ uint16* r = result->number; // result
+ uint16* s = data->number; // source
+ uint16* end = s + data->digits_used - 1;
+
+ while (s != end)
+ {
+ *r = (*s >> 1) | (*(s+1) << 15);
+ ++r;
+ ++s;
+ }
+ *r = *s >> 1;
+
+
+ // calculate the new number for digits_used
+ if (*r == 0)
+ {
+ if (data->digits_used != 1)
+ result->digits_used = data->digits_used - 1;
+ else
+ result->digits_used = 1;
+ }
+ else
+ {
+ result->digits_used = data->digits_used;
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_2::
+ is_less_than (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const
+ {
+ uint32 lhs_digits_used = lhs->digits_used;
+ uint32 rhs_digits_used = rhs->digits_used;
+
+ // if lhs is definitely less than rhs
+ if (lhs_digits_used < rhs_digits_used )
+ return true;
+ // if lhs is definitely greater than rhs
+ else if (lhs_digits_used > rhs_digits_used)
+ return false;
+ else
+ {
+ uint16* end = lhs->number;
+ uint16* l = end + lhs_digits_used;
+ uint16* r = rhs->number + rhs_digits_used;
+
+ while (l != end)
+ {
+ --l;
+ --r;
+ if (*l < *r)
+ return true;
+ else if (*l > *r)
+ return false;
+ }
+
+ // at this point we know that they are equal
+ return false;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bigint_kernel_2::
+ is_equal_to (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const
+ {
+ // if lhs and rhs are definitely not equal
+ if (lhs->digits_used != rhs->digits_used )
+ {
+ return false;
+ }
+ else
+ {
+ uint16* l = lhs->number;
+ uint16* r = rhs->number;
+ uint16* end = l + lhs->digits_used;
+
+ while (l != end)
+ {
+ if (*l != *r)
+ return false;
+ ++l;
+ ++r;
+ }
+
+ // at this point we know that they are equal
+ return true;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ increment (
+ const data_record* source,
+ data_record* dest
+ ) const
+ {
+ uint16* s = source->number;
+ uint16* d = dest->number;
+ uint16* end = s + source->digits_used;
+ while (true)
+ {
+ *d = *s + 1;
+
+ // if there was no carry then break out of the loop
+ if (*d != 0)
+ {
+ dest->digits_used = source->digits_used;
+
+ // copy the rest of the digits over to d
+ ++d; ++s;
+ while (s != end)
+ {
+ *d = *s;
+ ++d;
+ ++s;
+ }
+
+ break;
+ }
+
+
+ ++s;
+
+ // if we have hit the end of s and there was a carry up to this point
+ // then just make the next digit 1 and add one to the digits used
+ if (s == end)
+ {
+ ++d;
+ dest->digits_used = source->digits_used + 1;
+ *d = 1;
+ break;
+ }
+
+ ++d;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ decrement (
+ const data_record* source,
+ data_record* dest
+ ) const
+ {
+ uint16* s = source->number;
+ uint16* d = dest->number;
+ uint16* end = s + source->digits_used;
+ while (true)
+ {
+ *d = *s - 1;
+
+ // if there was no carry then break out of the loop
+ if (*d != 0xFFFF)
+ {
+ // if we lost a digit in the subtraction
+ if (*d == 0 && s+1 == end)
+ {
+ if (source->digits_used == 1)
+ dest->digits_used = 1;
+ else
+ dest->digits_used = source->digits_used - 1;
+ }
+ else
+ {
+ dest->digits_used = source->digits_used;
+ }
+ break;
+ }
+ else
+ {
+ ++d;
+ ++s;
+ }
+
+ }
+
+ // copy the rest of the digits over to d
+ ++d;
+ ++s;
+ while (s != end)
+ {
+ *d = *s;
+ ++d;
+ ++s;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ fft (
+ ct* data,
+ unsigned long len
+ ) const
+ {
+ const t pi2 = -2.0*3.1415926535897932384626433832795028841971693993751;
+
+ const unsigned long half = len/2;
+
+ std::vector<ct> twiddle_factors;
+ twiddle_factors.resize(half);
+
+ // compute the complex root of unity w
+ const t temp = pi2/len;
+ ct w = ct(std::cos(temp),std::sin(temp));
+
+ ct w_pow = 1;
+
+ // compute the twiddle factors
+ for (std::vector<ct>::size_type j = 0; j < twiddle_factors.size(); ++j)
+ {
+ twiddle_factors[j] = w_pow;
+ w_pow *= w;
+ }
+
+ ct a, b;
+
+ // now compute the decimation in frequency. This first
+ // outer loop loops log2(len) number of times
+ unsigned long skip = 1;
+ for (unsigned long step = half; step != 0; step >>= 1)
+ {
+ // do blocks of butterflies in this loop
+ for (unsigned long j = 0; j < len; j += step*2)
+ {
+ // do step butterflies
+ for (unsigned long k = 0; k < step; ++k)
+ {
+ const unsigned long a_idx = j+k;
+ const unsigned long b_idx = j+k+step;
+ a = data[a_idx] + data[b_idx];
+ b = (data[a_idx] - data[b_idx])*twiddle_factors[k*skip];
+ data[a_idx] = a;
+ data[b_idx] = b;
+ }
+ }
+ skip *= 2;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bigint_kernel_2::
+ ifft(
+ ct* data,
+ unsigned long len
+ ) const
+ {
+ const t pi2 = 2.0*3.1415926535897932384626433832795028841971693993751;
+
+ const unsigned long half = len/2;
+
+ std::vector<ct> twiddle_factors;
+ twiddle_factors.resize(half);
+
+ // compute the complex root of unity w
+ const t temp = pi2/len;
+ ct w = ct(std::cos(temp),std::sin(temp));
+
+ ct w_pow = 1;
+
+ // compute the twiddle factors
+ for (std::vector<ct>::size_type j = 0; j < twiddle_factors.size(); ++j)
+ {
+ twiddle_factors[j] = w_pow;
+ w_pow *= w;
+ }
+
+ ct a, b;
+
+ // now compute the inverse decimation in frequency. This first
+ // outer loop loops log2(len) number of times
+ unsigned long skip = half;
+ for (unsigned long step = 1; step <= half; step <<= 1)
+ {
+ // do blocks of butterflies in this loop
+ for (unsigned long j = 0; j < len; j += step*2)
+ {
+ // do step butterflies
+ for (unsigned long k = 0; k < step; ++k)
+ {
+ const unsigned long a_idx = j+k;
+ const unsigned long b_idx = j+k+step;
+ data[b_idx] *= twiddle_factors[k*skip];
+ a = data[a_idx] + data[b_idx];
+ b = data[a_idx] - data[b_idx];
+ data[a_idx] = a;
+ data[b_idx] = b;
+ }
+ }
+ skip /= 2;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_BIGINT_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_2.h b/ml/dlib/dlib/bigint/bigint_kernel_2.h
new file mode 100644
index 000000000..cbd8f895d
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_2.h
@@ -0,0 +1,570 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINT_KERNEl_2_
+#define DLIB_BIGINT_KERNEl_2_
+
+#include "bigint_kernel_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../uintn.h"
+#include <iosfwd>
+#include <cmath>
+#include <complex>
+#include <vector>
+
+namespace dlib
+{
+
+ class bigint_kernel_2
+ {
+ /*!
+ INITIAL VALUE
+ slack == 25
+ data->number[0] == 0
+ data->size == slack
+ data->references == 1
+ data->digits_used == 1
+
+
+ CONVENTION
+ slack == the number of extra digits placed into the number when it is
+ created. the slack value should never be less than 1
+
+ data->number == pointer to an array of data->size uint16s.
+ data represents a string of base 65535 numbers with data[0] being
+ the least significant bit and data[data->digits_used-1] being the most
+ significant
+
+
+ NOTE: In the comments I will consider a word to be a 16 bit value
+
+
+ data->digits_used == the number of significant digits in the number.
+ data->digits_used tells us the number of used elements in the
+ data->number array so everything beyond data->number[data->digits_used-1]
+ is undefined
+
+ data->references == the number of bigint_kernel_2 objects which refer
+ to this data_record
+ !*/
+
+
+ struct data_record
+ {
+
+
+ explicit data_record(
+ uint32 size_
+ ) :
+ size(size_),
+ number(new uint16[size_]),
+ references(1),
+ digits_used(1)
+ {*number = 0;}
+ /*!
+ ensures
+ - initializes *this to represent zero
+ !*/
+
+ data_record(
+ const data_record& item,
+ uint32 additional_size
+ ) :
+ size(item.digits_used + additional_size),
+ number(new uint16[size]),
+ references(1),
+ digits_used(item.digits_used)
+ {
+ uint16* source = item.number;
+ uint16* dest = number;
+ uint16* end = source + digits_used;
+ while (source != end)
+ {
+ *dest = *source;
+ ++dest;
+ ++source;
+ }
+ }
+ /*!
+ ensures
+ - *this is a copy of item except with
+ size == item.digits_used + additional_size
+ !*/
+
+ ~data_record(
+ )
+ {
+ delete [] number;
+ }
+
+
+ const uint32 size;
+ uint16* number;
+ uint32 references;
+ uint32 digits_used;
+
+ private:
+ // no copy constructor
+ data_record ( data_record&);
+ };
+
+
+ // note that the second parameter is just there
+ // to resolve the ambiguity between this constructor and
+ // bigint_kernel_2(uint32)
+ explicit bigint_kernel_2 (
+ data_record* data_, int
+ ): slack(25),data(data_) {}
+ /*!
+ ensures
+ - *this is initialized with data_ as its data member
+ !*/
+
+ public:
+
+ bigint_kernel_2 (
+ );
+
+ bigint_kernel_2 (
+ uint32 value
+ );
+
+ bigint_kernel_2 (
+ const bigint_kernel_2& item
+ );
+
+ virtual ~bigint_kernel_2 (
+ );
+
+ const bigint_kernel_2 operator+ (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator+= (
+ const bigint_kernel_2& rhs
+ );
+
+ const bigint_kernel_2 operator- (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator-= (
+ const bigint_kernel_2& rhs
+ );
+
+ const bigint_kernel_2 operator* (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator*= (
+ const bigint_kernel_2& rhs
+ );
+
+ const bigint_kernel_2 operator/ (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator/= (
+ const bigint_kernel_2& rhs
+ );
+
+ const bigint_kernel_2 operator% (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator%= (
+ const bigint_kernel_2& rhs
+ );
+
+ bool operator < (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bool operator == (
+ const bigint_kernel_2& rhs
+ ) const;
+
+ bigint_kernel_2& operator= (
+ const bigint_kernel_2& rhs
+ );
+
+ friend std::ostream& operator<< (
+ std::ostream& out,
+ const bigint_kernel_2& rhs
+ );
+
+ friend std::istream& operator>> (
+ std::istream& in,
+ bigint_kernel_2& rhs
+ );
+
+ bigint_kernel_2& operator++ (
+ );
+
+ const bigint_kernel_2 operator++ (
+ int
+ );
+
+ bigint_kernel_2& operator-- (
+ );
+
+ const bigint_kernel_2 operator-- (
+ int
+ );
+
+ friend const bigint_kernel_2 operator+ (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend const bigint_kernel_2 operator+ (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_2& operator+= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_2 operator- (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend const bigint_kernel_2 operator- (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_2& operator-= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_2 operator* (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend const bigint_kernel_2 operator* (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_2& operator*= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_2 operator/ (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend const bigint_kernel_2 operator/ (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_2& operator/= (
+ uint16 rhs
+ );
+
+ friend const bigint_kernel_2 operator% (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend const bigint_kernel_2 operator% (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_2& operator%= (
+ uint16 rhs
+ );
+
+ friend bool operator < (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ friend bool operator < (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ friend bool operator == (
+ const bigint_kernel_2& lhs,
+ uint16 rhs
+ );
+
+ friend bool operator == (
+ uint16 lhs,
+ const bigint_kernel_2& rhs
+ );
+
+ bigint_kernel_2& operator= (
+ uint16 rhs
+ );
+
+
+ void swap (
+ bigint_kernel_2& item
+ ) { data_record* temp = data; data = item.data; item.data = temp; }
+
+
+ private:
+
+ typedef double t;
+ typedef std::complex<t> ct;
+
+ void fft(
+ ct* data,
+ unsigned long len
+ ) const;
+ /*!
+ requires
+ - len == x^n for some integer n (i.e. len is a power of 2)
+ - len > 0
+ ensures
+ - #data == the FT decimation in frequency of data
+ !*/
+
+ void ifft(
+ ct* data,
+ unsigned long len
+ ) const;
+ /*!
+ requires
+ - len == x^n for some integer n (i.e. len is a power of 2)
+ - len > 0
+ ensures
+ - #data == the inverse decimation in frequency of data.
+ (i.e. the inverse of what fft(data,len,-1) does to data)
+ !*/
+
+ void long_add (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= max(lhs->digits_used,rhs->digits_used) + 1
+ ensures
+ - result == lhs + rhs
+ !*/
+
+ void long_sub (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - lhs >= rhs
+ - result->size >= lhs->digits_used
+ ensures
+ - result == lhs - rhs
+ !*/
+
+ void long_div (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result,
+ data_record* remainder
+ ) const;
+ /*!
+ requires
+ - rhs != 0
+ - result->size >= lhs->digits_used
+ - remainder->size >= lhs->digits_used
+ - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.)
+ ensures
+ - result == lhs / rhs
+ - remainder == lhs % rhs
+ !*/
+
+ void long_mul (
+ const data_record* lhs,
+ const data_record* rhs,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= lhs->digits_used + rhs->digits_used
+ - result != lhs
+ - result != rhs
+ ensures
+ - result == lhs * rhs
+ !*/
+
+ void short_add (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->size + 1
+ ensures
+ - result == data + value
+ !*/
+
+ void short_sub (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - data >= value
+ - result->size >= data->digits_used
+ ensures
+ - result == data - value
+ !*/
+
+ void short_mul (
+ const data_record* data,
+ uint16 value,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used + 1
+ ensures
+ - result == data * value
+ !*/
+
+ void short_div (
+ const data_record* data,
+ uint16 value,
+ data_record* result,
+ uint16& remainder
+ ) const;
+ /*!
+ requires
+ - value != 0
+ - result->size >= data->digits_used
+ ensures
+ - result = data*value
+ - remainder = data%value
+ !*/
+
+ void shift_left (
+ const data_record* data,
+ data_record* result,
+ uint32 shift_amount
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used + shift_amount/8 + 1
+ ensures
+ - result == data << shift_amount
+ !*/
+
+ void shift_right (
+ const data_record* data,
+ data_record* result
+ ) const;
+ /*!
+ requires
+ - result->size >= data->digits_used
+ ensures
+ - result == data >> 1
+ !*/
+
+ bool is_less_than (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if lhs < rhs
+ - returns false otherwise
+ !*/
+
+ bool is_equal_to (
+ const data_record* lhs,
+ const data_record* rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if lhs == rhs
+ - returns false otherwise
+ !*/
+
+ void increment (
+ const data_record* source,
+ data_record* dest
+ ) const;
+ /*!
+ requires
+ - dest->size >= source->digits_used + 1
+ ensures
+ - dest = source + 1
+ !*/
+
+ void decrement (
+ const data_record* source,
+ data_record* dest
+ ) const;
+ /*!
+ requires
+ source != 0
+ ensuers
+ dest = source - 1
+ !*/
+
+ // member data
+ const uint32 slack;
+ data_record* data;
+
+
+
+ };
+
+ inline void swap (
+ bigint_kernel_2& a,
+ bigint_kernel_2& b
+ ) { a.swap(b); }
+
+ inline void serialize (
+ const bigint_kernel_2& item,
+ std::ostream& out
+ )
+ {
+ std::ios::fmtflags oldflags = out.flags();
+ out.flags();
+ out << item << ' ';
+ out.flags(oldflags);
+ if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c");
+ }
+
+ inline void deserialize (
+ bigint_kernel_2& item,
+ std::istream& in
+ )
+ {
+ std::ios::fmtflags oldflags = in.flags();
+ in.flags();
+ in >> item; in.flags(oldflags);
+ if (in.get() != ' ')
+ {
+ item = 0;
+ throw serialization_error("Error deserializing object of type bigint_kernel_c");
+ }
+ }
+
+ inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; }
+ inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); }
+ inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); }
+ inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); }
+
+}
+
+#ifdef NO_MAKEFILE
+#include "bigint_kernel_2.cpp"
+#endif
+
+#endif // DLIB_BIGINT_KERNEl_2_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_abstract.h b/ml/dlib/dlib/bigint/bigint_kernel_abstract.h
new file mode 100644
index 000000000..99a54520b
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_abstract.h
@@ -0,0 +1,670 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BIGINT_KERNEl_ABSTRACT_
+#ifdef DLIB_BIGINT_KERNEl_ABSTRACT_
+
+#include <iosfwd>
+#include "../algs.h"
+#include "../serialize.h"
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class bigint
+ {
+ /*!
+ INITIAL VALUE
+ *this == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an arbitrary precision unsigned integer
+
+ the following operators are supported:
+ operator +
+ operator +=
+ operator -
+ operator -=
+ operator *
+ operator *=
+ operator /
+ operator /=
+ operator %
+ operator %=
+ operator ==
+ operator <
+ operator =
+ operator << (for writing to ostreams)
+ operator >> (for reading from istreams)
+ operator++ // pre increment
+ operator++(int) // post increment
+ operator-- // pre decrement
+ operator--(int) // post decrement
+
+
+ the other comparason operators(>, !=, <=, and >=) are
+ available and come from the templates in dlib::relational_operators
+
+ THREAD SAFETY
+ bigint may be reference counted so it is very unthread safe.
+ use with care in a multithreaded program
+
+ !*/
+
+ public:
+
+ bigint (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ if this is thrown the bigint will be unusable but
+ will not leak memory
+ !*/
+
+ bigint (
+ uint32 value
+ );
+ /*!
+ requires
+ - value <= (2^32)-1
+ ensures
+ - #*this is properly initialized
+ - #*this == value
+ throws
+ - std::bad_alloc
+ if this is thrown the bigint will be unusable but
+ will not leak memory
+ !*/
+
+ bigint (
+ const bigint& item
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this == value
+ throws
+ - std::bad_alloc
+ if this is thrown the bigint will be unusable but
+ will not leak memory
+ !*/
+
+ virtual ~bigint (
+ );
+ /*!
+ ensures
+ - all resources associated with #*this have been released
+ !*/
+
+ const bigint operator+ (
+ const bigint& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of adding rhs to *this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator+= (
+ const bigint& rhs
+ );
+ /*!
+ ensures
+ - #*this == *this + rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator- (
+ const bigint& rhs
+ ) const;
+ /*!
+ requires
+ - *this >= rhs
+ ensures
+ - returns the result of subtracting rhs from *this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator-= (
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - *this >= rhs
+ ensures
+ - #*this == *this - rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator* (
+ const bigint& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of multiplying *this and rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator*= (
+ const bigint& rhs
+ );
+ /*!
+ ensures
+ - #*this == *this * rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator/ (
+ const bigint& rhs
+ ) const;
+ /*!
+ requires
+ - rhs != 0
+ ensures
+ - returns the result of dividing *this by rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator/= (
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ ensures
+ - #*this == *this / rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator% (
+ const bigint& rhs
+ ) const;
+ /*!
+ requires
+ - rhs != 0
+ ensures
+ - returns the result of *this mod rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator%= (
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ ensures
+ - #*this == *this % rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bool operator < (
+ const bigint& rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if *this is less than rhs
+ - returns false otherwise
+ !*/
+
+ bool operator == (
+ const bigint& rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if *this and rhs represent the same number
+ - returns false otherwise
+ !*/
+
+ bigint& operator= (
+ const bigint& rhs
+ );
+ /*!
+ ensures
+ - #*this == rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+
+ friend std::ostream& operator<< (
+ std::ostream& out,
+ const bigint& rhs
+ );
+ /*!
+ ensures
+ - the number in *this has been written to #out as a base ten number
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect (nothing
+ is written to out)
+ !*/
+
+ friend std::istream& operator>> (
+ std::istream& in,
+ bigint& rhs
+ );
+ /*!
+ ensures
+ - reads a number from in and puts it into #*this
+ - if (there is no positive base ten number on the input stream ) then
+ - #in.fail() == true
+ throws
+ - std::bad_alloc
+ if this function throws the value in rhs is undefined and some
+ characters may have been read from in. rhs is still usable though,
+ its value is just unknown.
+ !*/
+
+
+ bigint& operator++ (
+ );
+ /*!
+ ensures
+ - #*this == *this + 1
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator++ (
+ int
+ );
+ /*!
+ ensures
+ - #*this == *this + 1
+ - returns *this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator-- (
+ );
+ /*!
+ requires
+ - *this != 0
+ ensures
+ - #*this == *this - 1
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ const bigint operator-- (
+ int
+ );
+ /*!
+ requires
+ - *this != 0
+ ensures
+ - #*this == *this - 1
+ - returns *this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ void swap (
+ bigint& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ // ------------------------------------------------------------------
+ // ---- The following functions are identical to the above -----
+ // ---- but take uint16 as one of their arguments. They ---
+ // ---- exist only to allow for a more efficient implementation ---
+ // ------------------------------------------------------------------
+
+
+ friend const bigint operator+ (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - lhs <= 65535
+ ensures
+ - returns the result of adding rhs to lhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator+ (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - returns the result of adding rhs to lhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator+= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - #*this == *this + rhs
+ - returns #this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator- (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - lhs >= rhs
+ - lhs <= 65535
+ ensures
+ - returns the result of subtracting rhs from lhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator- (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - lhs >= rhs
+ - rhs <= 65535
+ ensures
+ - returns the result of subtracting rhs from lhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator-= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - *this >= rhs
+ - rhs <= 65535
+ ensures
+ - #*this == *this - rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator* (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - lhs <= 65535
+ ensures
+ - returns the result of multiplying lhs and rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator* (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - returns the result of multiplying lhs and rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator*= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - #*this == *this * rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator/ (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - lhs <= 65535
+ ensures
+ - returns the result of dividing lhs by rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator/ (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - rhs <= 65535
+ ensures
+ - returns the result of dividing lhs by rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator/= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - rhs <= 65535
+ ensures
+ - #*this == *this / rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator% (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - lhs <= 65535
+ ensures
+ - returns the result of lhs mod rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ friend const bigint operator% (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - rhs <= 65535
+ ensures
+ - returns the result of lhs mod rhs
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ bigint& operator%= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs != 0
+ - rhs <= 65535
+ ensures
+ - #*this == *this % rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+
+ friend bool operator < (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - lhs <= 65535
+ ensures
+ - returns true if lhs is less than rhs
+ - returns false otherwise
+ !*/
+
+ friend bool operator < (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - returns true if lhs is less than rhs
+ - returns false otherwise
+ !*/
+
+ friend bool operator == (
+ const bigint& lhs,
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - returns true if lhs and rhs represent the same number
+ - returns false otherwise
+ !*/
+
+ friend bool operator == (
+ uint16 lhs,
+ const bigint& rhs
+ );
+ /*!
+ requires
+ - lhs <= 65535
+ ensures
+ - returns true if lhs and rhs represent the same number
+ - returns false otherwise
+ !*/
+
+ bigint& operator= (
+ uint16 rhs
+ );
+ /*!
+ requires
+ - rhs <= 65535
+ ensures
+ - #*this == rhs
+ - returns #*this
+ throws
+ - std::bad_alloc
+ if this function throws then it has no effect
+ !*/
+
+ };
+
+ inline void swap (
+ bigint& a,
+ bigint& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ void serialize (
+ const bigint& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ bigint& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ inline bool operator> (const bigint& a, const bigint& b) { return b < a; }
+ inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); }
+ inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); }
+ inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); }
+}
+
+#endif // DLIB_BIGINT_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/bigint/bigint_kernel_c.h b/ml/dlib/dlib/bigint/bigint_kernel_c.h
new file mode 100644
index 000000000..954869a38
--- /dev/null
+++ b/ml/dlib/dlib/bigint/bigint_kernel_c.h
@@ -0,0 +1,1141 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIGINT_KERNEl_C_
+#define DLIB_BIGINT_KERNEl_C_
+
+#include "bigint_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ class bigint_kernel_c
+ {
+ bigint_base data;
+
+ explicit bigint_kernel_c (
+ const bigint_base& item
+ ) : data(item) {}
+
+ public:
+
+
+ bigint_kernel_c (
+ );
+
+ bigint_kernel_c (
+ uint32 value
+ );
+
+ bigint_kernel_c (
+ const bigint_kernel_c<bigint_base>& item
+ );
+
+ ~bigint_kernel_c (
+ );
+
+ const bigint_kernel_c<bigint_base> operator+ (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bigint_kernel_c<bigint_base>& operator+= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ const bigint_kernel_c<bigint_base> operator- (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+ bigint_kernel_c<bigint_base>& operator-= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ const bigint_kernel_c<bigint_base> operator* (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bigint_kernel_c<bigint_base>& operator*= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ const bigint_kernel_c<bigint_base> operator/ (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bigint_kernel_c<bigint_base>& operator/= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ const bigint_kernel_c<bigint_base> operator% (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bigint_kernel_c<bigint_base>& operator%= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ bool operator < (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bool operator == (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const;
+
+ bigint_kernel_c<bigint_base>& operator= (
+ const bigint_kernel_c<bigint_base>& rhs
+ );
+
+ template <typename T>
+ friend std::ostream& operator<< (
+ std::ostream& out,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend std::istream& operator>> (
+ std::istream& in,
+ bigint_kernel_c<T>& rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator++ (
+ );
+
+ const bigint_kernel_c<bigint_base> operator++ (
+ int
+ );
+
+ bigint_kernel_c<bigint_base>& operator-- (
+ );
+
+ const bigint_kernel_c<bigint_base> operator-- (
+ int
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator+ (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator+ (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator+= (
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator- (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator- (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator-= (
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator* (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator* (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator*= (
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator/ (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator/ (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator/= (
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator% (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend const bigint_kernel_c<T> operator% (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator%= (
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend bool operator < (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ template <typename T>
+ friend bool operator < (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend bool operator == (
+ const bigint_kernel_c<T>& lhs,
+ uint16 rhs
+ );
+
+ template <typename T>
+ friend bool operator == (
+ uint16 lhs,
+ const bigint_kernel_c<T>& rhs
+ );
+
+ bigint_kernel_c<bigint_base>& operator= (
+ uint16 rhs
+ );
+
+
+ void swap (
+ bigint_kernel_c<bigint_base>& item
+ ) { data.swap(item.data); }
+
+ };
+
+ template <
+ typename bigint_base
+ >
+ void swap (
+ bigint_kernel_c<bigint_base>& a,
+ bigint_kernel_c<bigint_base>& b
+ ) { a.swap(b); }
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ inline void serialize (
+ const bigint_kernel_c<bigint_base>& item,
+ std::ostream& out
+ )
+ {
+ std::ios::fmtflags oldflags = out.flags();
+ out.flags();
+ out << item << ' ';
+ out.flags(oldflags);
+ if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c");
+ }
+
+ template <
+ typename bigint_base
+ >
+ inline void deserialize (
+ bigint_kernel_c<bigint_base>& item,
+ std::istream& in
+ )
+ {
+ std::ios::fmtflags oldflags = in.flags();
+ in.flags();
+ in >> item; in.flags(oldflags);
+ if (in.get() != ' ')
+ {
+ item = 0;
+ throw serialization_error("Error deserializing object of type bigint_kernel_c");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>::
+ bigint_kernel_c (
+ )
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>::
+ bigint_kernel_c (
+ uint32 value
+ ) :
+ data(value)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( value <= 0xFFFFFFFF ,
+ "\tbigint::bigint(uint16)"
+ << "\n\t value must be <= (2^32)-1"
+ << "\n\tthis: " << this
+ << "\n\tvalue: " << value
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>::
+ bigint_kernel_c (
+ const bigint_kernel_c<bigint_base>& item
+ ) :
+ data(item.data)
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>::
+ ~bigint_kernel_c (
+ )
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator+ (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ return bigint_kernel_c<bigint_base>(data + rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator+= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ data += rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator- (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(*this < rhs),
+ "\tconst bigint bigint::operator-(const bigint&)"
+ << "\n\t *this should not be less than rhs"
+ << "\n\tthis: " << this
+ << "\n\t*this: " << *this
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(data-rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator-= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(*this < rhs),
+ "\tbigint& bigint::operator-=(const bigint&)"
+ << "\n\t *this should not be less than rhs"
+ << "\n\tthis: " << this
+ << "\n\t*this: " << *this
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ data -= rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator* (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ return bigint_kernel_c<bigint_base>(data * rhs.data );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator*= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ data *= rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator/ (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ //make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0),
+ "\tconst bigint bigint::operator/(const bigint&)"
+ << "\n\t can't divide by zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(data/rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator/= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0),
+ "\tbigint& bigint::operator/=(const bigint&)"
+ << "\n\t can't divide by zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ data /= rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator% (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0),
+ "\tconst bigint bigint::operator%(const bigint&)"
+ << "\n\t can't divide by zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(data%rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator%= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0),
+ "\tbigint& bigint::operator%=(const bigint&)"
+ << "\n\t can't divide by zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ data %= rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool bigint_kernel_c<bigint_base>::
+ operator < (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ return data < rhs.data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool bigint_kernel_c<bigint_base>::
+ operator == (
+ const bigint_kernel_c<bigint_base>& rhs
+ ) const
+ {
+ return data == rhs.data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator= (
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ data = rhs.data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ std::ostream& operator<< (
+ std::ostream& out,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ out << rhs.data;
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ std::istream& operator>> (
+ std::istream& in,
+ bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ in >> rhs.data;
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator++ (
+ )
+ {
+ ++data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator++ (
+ int
+ )
+ {
+ return bigint_kernel_c<bigint_base>(data++);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator-- (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(*this == 0),
+ "\tbigint& bigint::operator--()"
+ << "\n\t *this to subtract from *this it must not be zero to begin with"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ --data;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> bigint_kernel_c<bigint_base>::
+ operator-- (
+ int
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(*this == 0),
+ "\tconst bigint bigint::operator--(int)"
+ << "\n\t *this to subtract from *this it must not be zero to begin with"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(data--);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator+ (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( lhs <= 65535,
+ "\tconst bigint operator+(uint16, const bigint&)"
+ << "\n\t lhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return bigint_kernel_c<bigint_base>(static_cast<uint16>(lhs)+rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator+ (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tconst bigint operator+(const bigint&, uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return bigint_kernel_c<bigint_base>(lhs.data+static_cast<uint16>(rhs));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator+= (
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tbigint& bigint::operator+=(uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\tthis: " << this
+ << "\n\t*this: " << *this
+ << "\n\trhs: " << rhs
+ );
+
+ data += rhs;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator- (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(static_cast<uint16>(lhs) < rhs) && lhs <= 65535,
+ "\tconst bigint operator-(uint16,const bigint&)"
+ << "\n\t lhs must be greater than or equal to rhs and lhs <= 65535"
+ << "\n\tlhs: " << lhs
+ << "\n\trhs: " << rhs
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(static_cast<uint16>(lhs)-rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator- (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(lhs < static_cast<uint16>(rhs)) && rhs <= 65535,
+ "\tconst bigint operator-(const bigint&,uint16)"
+ << "\n\t lhs must be greater than or equal to rhs and rhs <= 65535"
+ << "\n\tlhs: " << lhs
+ << "\n\trhs: " << rhs
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(lhs.data-static_cast<uint16>(rhs));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator-= (
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(*this < static_cast<uint16>(rhs)) && rhs <= 65535,
+ "\tbigint& bigint::operator-=(uint16)"
+ << "\n\t *this must not be less than rhs and rhs <= 65535"
+ << "\n\tthis: " << this
+ << "\n\t*this: " << *this
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ data -= static_cast<uint16>(rhs);
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator* (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( lhs <= 65535,
+ "\tconst bigint operator*(uint16, const bigint&)"
+ << "\n\t lhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return bigint_kernel_c<bigint_base>(lhs*rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator* (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tconst bigint operator*(const bigint&, uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return bigint_kernel_c<bigint_base>(lhs.data*rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator*= (
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\t bigint bigint::operator*=(uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\tthis: " << this
+ << "\n\t*this: " << *this
+ << "\n\trhs: " << rhs
+ );
+
+ data *= static_cast<uint16>(rhs);
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator/ (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && lhs <= 65535,
+ "\tconst bigint operator/(uint16,const bigint&)"
+ << "\n\t you can't divide by zero and lhs <= 65535"
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(lhs/rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator/ (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && rhs <= 65535,
+ "\tconst bigint operator/(const bigint&,uint16)"
+ << "\n\t you can't divide by zero and rhs <= 65535"
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(lhs.data/static_cast<uint16>(rhs));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator/= (
+ uint16 rhs
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && static_cast<uint32>(rhs) <= 65535,
+ "\tbigint& bigint::operator/=(uint16)"
+ << "\n\t you can't divide by zero and rhs must be <= 65535"
+ << "\n\tthis: " << this
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ data /= rhs;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator% (
+ uint16 lhs,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && static_cast<uint32>(lhs) <= 65535,
+ "\tconst bigint operator%(uint16,const bigint&)"
+ << "\n\t you can't divide by zero and lhs must be <= 65535"
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(lhs%rhs.data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ const bigint_kernel_c<bigint_base> operator% (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && rhs <= 65535,
+ "\tconst bigint operator%(const bigint&,uint16)"
+ << "\n\t you can't divide by zero and rhs must be <= 65535"
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ return bigint_kernel_c<bigint_base>(lhs.data%static_cast<uint16>(rhs));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator%= (
+ uint16 r
+ )
+ {
+
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !(rhs == 0) && rhs <= 65535,
+ "\tbigint& bigint::operator%=(uint16)"
+ << "\n\t you can't divide by zero and rhs must be <= 65535"
+ << "\n\tthis: " << this
+ << "\n\trhs: " << rhs
+ );
+
+ // call the real function
+ data %= rhs;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool operator < (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( lhs <= 65535,
+ "\tbool operator<(uint16, const bigint&)"
+ << "\n\t lhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return static_cast<uint16>(lhs) < rhs.data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool operator < (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tbool operator<(const bigint&, uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return lhs.data < static_cast<uint16>(rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool operator == (
+ const bigint_kernel_c<bigint_base>& lhs,
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tbool operator==(const bigint&, uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return lhs.data == static_cast<uint16>(rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bool operator == (
+ uint16 l,
+ const bigint_kernel_c<bigint_base>& rhs
+ )
+ {
+ uint32 lhs = l;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( lhs <= 65535,
+ "\tbool operator==(uint16, const bigint&)"
+ << "\n\t lhs must be <= 65535"
+ << "\n\trhs: " << rhs
+ << "\n\tlhs: " << lhs
+ );
+
+ return static_cast<uint16>(lhs) == rhs.data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bigint_base
+ >
+ bigint_kernel_c<bigint_base>& bigint_kernel_c<bigint_base>::
+ operator= (
+ uint16 r
+ )
+ {
+ uint32 rhs = r;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( rhs <= 65535,
+ "\tbigint bigint::operator=(uint16)"
+ << "\n\t rhs must be <= 65535"
+ << "\n\t*this: " << *this
+ << "\n\tthis: " << this
+ << "\n\tlhs: " << rhs
+ );
+
+ data = static_cast<uint16>(rhs);
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template < typename bigint_base >
+ inline bool operator> (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return b < a; }
+ template < typename bigint_base >
+ inline bool operator!= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(a == b); }
+ template < typename bigint_base >
+ inline bool operator<= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(b < a); }
+ template < typename bigint_base >
+ inline bool operator>= (const bigint_kernel_c<bigint_base>& a, const bigint_kernel_c<bigint_base>& b) { return !(a < b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BIGINT_KERNEl_C_
+
diff --git a/ml/dlib/dlib/binary_search_tree.h b/ml/dlib/dlib/binary_search_tree.h
new file mode 100644
index 000000000..5273e8ce9
--- /dev/null
+++ b/ml/dlib/dlib/binary_search_tree.h
@@ -0,0 +1,50 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREe_
+#define DLIB_BINARY_SEARCH_TREe_
+
+
+#include "binary_search_tree/binary_search_tree_kernel_1.h"
+#include "binary_search_tree/binary_search_tree_kernel_2.h"
+#include "binary_search_tree/binary_search_tree_kernel_c.h"
+
+
+#include "algs.h"
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class binary_search_tree
+ {
+ binary_search_tree() {}
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef binary_search_tree_kernel_1<domain,range,mem_manager,compare>
+ kernel_1a;
+ typedef binary_search_tree_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef binary_search_tree_kernel_2<domain,range,mem_manager,compare>
+ kernel_2a;
+ typedef binary_search_tree_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+ };
+}
+
+#endif // DLIB_BINARY_SEARCH_TREe_
+
diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h
new file mode 100644
index 000000000..418eb07d0
--- /dev/null
+++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h
@@ -0,0 +1,2064 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_1_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_1_
+
+#include "binary_search_tree_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include <cstdlib>
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare = std::less<domain>
+ >
+ class binary_search_tree_kernel_1 : public enumerable<map_pair<domain,range> >,
+ public asc_pair_remover<domain,range,compare>
+ {
+
+ /*!
+ INITIAL VALUE
+ tree_size == 0
+ tree_root == 0
+ tree_height == 0
+ at_start_ == true
+ current_element == 0
+ stack == array of 50 node pointers
+ stack_pos == 0
+
+
+ CONVENTION
+ tree_size == size()
+ tree_height == height()
+
+ stack[stack_pos-1] == pop()
+
+ current_element_valid() == (current_element != 0)
+ if (current_element_valid()) then
+ element() == current_element->d and current_element->r
+ at_start_ == at_start()
+ if (current_element != 0 && current_element != tree_root) then
+ stack[stack_pos-1] == the parent of the node pointed to by current_element
+
+ if (tree_size != 0)
+ tree_root == pointer to the root node of the binary search tree
+ else
+ tree_root == 0
+
+
+ for all nodes:
+ {
+ left points to the left subtree or 0 if there is no left subtree and
+ right points to the right subtree or 0 if there is no right subtree and
+ all elements in a left subtree are <= the root and
+ all elements in a right subtree are >= the root and
+ d is the item in the domain of *this contained in the node
+ r is the item in the range of *this contained in the node
+ balance:
+ balance == 0 if both subtrees have the same height
+ balance == -1 if the left subtree has a height that is greater
+ than the height of the right subtree by 1
+ balance == 1 if the right subtree has a height that is greater
+ than the height of the left subtree by 1
+ for all trees:
+ the height of the left and right subtrees differ by at most one
+ }
+
+ !*/
+
+ class node
+ {
+ public:
+ node* left;
+ node* right;
+ domain d;
+ range r;
+ signed char balance;
+ };
+
+ class mpair : public map_pair<domain,range>
+ {
+ public:
+ const domain* d;
+ range* r;
+
+ const domain& key(
+ ) const { return *d; }
+
+ const range& value(
+ ) const { return *r; }
+
+ range& value(
+ ) { return *r; }
+ };
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ binary_search_tree_kernel_1(
+ ) :
+ tree_size(0),
+ tree_root(0),
+ current_element(0),
+ tree_height(0),
+ at_start_(true),
+ stack_pos(0),
+ stack(ppool.allocate_array(50))
+ {
+ }
+
+ virtual ~binary_search_tree_kernel_1(
+ );
+
+ inline void clear(
+ );
+
+ inline short height (
+ ) const;
+
+ inline unsigned long count (
+ const domain& item
+ ) const;
+
+ inline void add (
+ domain& d,
+ range& r
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& item
+ );
+
+ inline const range* operator[] (
+ const domain& item
+ ) const;
+
+ inline range* operator[] (
+ const domain& item
+ );
+
+ inline void swap (
+ binary_search_tree_kernel_1& item
+ );
+
+ // function from the asc_pair_remover interface
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const map_pair<domain,range>& element (
+ ) const;
+
+ map_pair<domain,range>& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ void remove_last_in_order (
+ domain& d,
+ range& r
+ );
+
+ void remove_current_element (
+ domain& d,
+ range& r
+ );
+
+ void position_enumerator (
+ const domain& d
+ ) const;
+
+ private:
+
+
+ inline void rotate_left (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == 2
+ - t->right->balance == 0 or 1
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - #t is still a binary search tree
+ - #t->balance is between 1 and -1
+ - #t now has a height smaller by 1 if #t->balance == 0
+ !*/
+
+ inline void rotate_right (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == -2
+ - t->left->balance == 0 or -1
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - #t is still a binary search tree
+ - #t->balance is between 1 and -1
+ - #t now has a height smaller by 1 if #t->balance == 0
+
+ !*/
+
+ inline void double_rotate_right (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == -2
+ - t->left->balance == 1
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - #t is still a binary search tree
+ - #t now has a balance of 0
+ - #t now has a height smaller by 1
+ !*/
+
+ inline void double_rotate_left (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == 2
+ - t->right->balance == -1
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - #t is still a binary search tree
+ - #t now has a balance of 0
+ - #t now has a height smaller by 1
+ !*/
+
+ bool remove_biggest_element_in_tree (
+ node*& t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t != 0 (i.e. there must be something in the tree to remove)
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - the biggest node in t has been removed
+ - the biggest node domain element in t has been put into #d
+ - the biggest node range element in t has been put into #r
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ bool remove_least_element_in_tree (
+ node*& t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t != 0 (i.e. there must be something in the tree to remove)
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - the least node in t has been removed
+ - the least node domain element in t has been put into #d
+ - the least node range element in t has been put into #r
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ bool add_to_tree (
+ node*& t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - the mapping (d --> r) has been added to #t
+ - #d and #r have initial values for their types
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has grown by one
+ !*/
+
+ bool remove_from_tree (
+ node*& t,
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - return_reference(t,d) != 0
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - #d_copy is equivalent to d
+ - an element in t equivalent to d has been removed and swapped
+ into #d_copy and its associated range object has been
+ swapped into #r
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ bool remove_from_tree (
+ node*& t,
+ const domain& item
+ );
+ /*!
+ requires
+ - return_reference(t,item) != 0
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - an element in t equivalent to item has been removed
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ const range* return_reference (
+ const node* t,
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - if (there is a domain element equivalent to d in t) then
+ - returns a pointer to the element in the range equivalent to d
+ - else
+ - returns 0
+ !*/
+
+ range* return_reference (
+ node* t,
+ const domain& d
+ );
+ /*!
+ ensures
+ - if (there is a domain element equivalent to d in t) then
+ - returns a pointer to the element in the range equivalent to d
+ - else
+ - returns 0
+ !*/
+
+
+ inline bool keep_node_balanced (
+ node*& t
+ );
+ /*!
+ requires
+ - t != 0
+ - t == reference to the pointer in t's parent node that points to t
+ ensures
+ - if (t->balance is < 1 or > 1) then
+ - keep_node_balanced() will ensure that #t->balance == 0, -1, or 1
+ - #t is still a binary search tree
+ - returns true if it made the tree one height shorter
+ - returns false if it didn't change the height
+ !*/
+
+
+ unsigned long get_count (
+ const domain& item,
+ node* tree_root
+ ) const;
+ /*!
+ requires
+ - tree_root == the root of a binary search tree or 0
+ ensures
+ - if (tree_root == 0) then
+ - returns 0
+ - else
+ - returns the number of elements in tree_root that are
+ equivalent to item
+ !*/
+
+
+ void delete_tree (
+ node* t
+ );
+ /*!
+ requires
+ - t != 0
+ ensures
+ - deallocates the node pointed to by t and all of t's left and right children
+ !*/
+
+
+ void push (
+ node* n
+ ) const { stack[stack_pos] = n; ++stack_pos; }
+ /*!
+ ensures
+ - pushes n onto the stack
+ !*/
+
+
+ node* pop (
+ ) const { --stack_pos; return stack[stack_pos]; }
+ /*!
+ ensures
+ - pops the top of the stack and returns it
+ !*/
+
+
+
+ bool fix_stack (
+ node* t,
+ unsigned char depth = 0
+ );
+ /*!
+ requires
+ - current_element != 0
+ - depth == 0
+ - t == tree_root
+ ensures
+ - makes the stack contain the correct set of parent pointers.
+ also adjusts stack_pos so it is correct.
+ - #t is still a binary search tree
+ !*/
+
+ bool remove_current_element_from_tree (
+ node*& t,
+ domain& d,
+ range& r,
+ unsigned long cur_stack_pos = 1
+ );
+ /*!
+ requires
+ - t == tree_root
+ - cur_stack_pos == 1
+ - current_element != 0
+ ensures
+ - removes the data in the node given by current_element and swaps it into
+ #d and #r.
+ - #t is still a binary search tree
+ - the enumerator is advances on to the next element but its stack is
+ potentially corrupted. so you must call fix_stack(tree_root) to fix
+ it.
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+
+ // data members
+
+ mutable mpair p;
+ unsigned long tree_size;
+ node* tree_root;
+ mutable node* current_element;
+ typename mem_manager::template rebind<node>::other pool;
+ typename mem_manager::template rebind<node*>::other ppool;
+ short tree_height;
+ mutable bool at_start_;
+ mutable unsigned char stack_pos;
+ mutable node** stack;
+ compare comp;
+
+ // restricted functions
+ binary_search_tree_kernel_1(binary_search_tree_kernel_1&);
+ binary_search_tree_kernel_1& operator=(binary_search_tree_kernel_1&);
+
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ binary_search_tree_kernel_1<domain,range,mem_manager,compare>& a,
+ binary_search_tree_kernel_1<domain,range,mem_manager,compare>& b
+ ) { a.swap(b); }
+
+
+
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ binary_search_tree_kernel_1<domain,range,mem_manager,compare>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_1");
+ }
+ }
+
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ ~binary_search_tree_kernel_1 (
+ )
+ {
+ ppool.deallocate_array(stack);
+ if (tree_size != 0)
+ {
+ delete_tree(tree_root);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ clear (
+ )
+ {
+ if (tree_size > 0)
+ {
+ delete_tree(tree_root);
+ tree_root = 0;
+ tree_size = 0;
+ tree_height = 0;
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ size_t binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ size (
+ ) const
+ {
+ return tree_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ short binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ height (
+ ) const
+ {
+ return tree_height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ count (
+ const domain& item
+ ) const
+ {
+ return get_count(item,tree_root);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+ tree_height += add_to_tree(tree_root,d,r);
+ ++tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ tree_height -= remove_from_tree(tree_root,d,d_copy,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ destroy (
+ const domain& item
+ )
+ {
+ tree_height -= remove_from_tree(tree_root,item);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_any (
+ domain& d,
+ range& r
+ )
+ {
+ tree_height -= remove_least_element_in_tree(tree_root,d,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ range* binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ operator[] (
+ const domain& item
+ )
+ {
+ return return_reference(tree_root,item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const range* binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ operator[] (
+ const domain& item
+ ) const
+ {
+ return return_reference(tree_root,item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ swap (
+ binary_search_tree_kernel_1<domain,range,mem_manager,compare>& item
+ )
+ {
+ pool.swap(item.pool);
+ ppool.swap(item.ppool);
+ exchange(p,item.p);
+ exchange(stack,item.stack);
+ exchange(stack_pos,item.stack_pos);
+ exchange(comp,item.comp);
+
+
+ node* tree_root_temp = item.tree_root;
+ unsigned long tree_size_temp = item.tree_size;
+ short tree_height_temp = item.tree_height;
+ node* current_element_temp = item.current_element;
+ bool at_start_temp = item.at_start_;
+
+ item.tree_root = tree_root;
+ item.tree_size = tree_size;
+ item.tree_height = tree_height;
+ item.current_element = current_element;
+ item.at_start_ = at_start_;
+
+ tree_root = tree_root_temp;
+ tree_size = tree_size_temp;
+ tree_height = tree_height_temp;
+ current_element = current_element_temp;
+ at_start_ = at_start_temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_last_in_order (
+ domain& d,
+ range& r
+ )
+ {
+ tree_height -= remove_biggest_element_in_tree(tree_root,d,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_current_element (
+ domain& d,
+ range& r
+ )
+ {
+ tree_height -= remove_current_element_from_tree(tree_root,d,r);
+ --tree_size;
+
+ // fix the enumerator stack if we need to
+ if (current_element)
+ fix_stack(tree_root);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ position_enumerator (
+ const domain& d
+ ) const
+ {
+ // clear the enumerator state and make sure the stack is empty
+ reset();
+ at_start_ = false;
+ node* t = tree_root;
+ bool went_left = false;
+ while (t != 0)
+ {
+ if ( comp(d , t->d) )
+ {
+ push(t);
+ // if item is on the left then look in left
+ t = t->left;
+ went_left = true;
+ }
+ else if (comp(t->d , d))
+ {
+ push(t);
+ // if item is on the right then look in right
+ t = t->right;
+ went_left = false;
+ }
+ else
+ {
+ current_element = t;
+ return;
+ }
+ }
+
+ // if we didn't find any matches but there might be something after the
+ // d in this tree.
+ if (stack_pos > 0)
+ {
+ current_element = pop();
+ // if we went left from this node then this node is the next
+ // biggest.
+ if (went_left)
+ {
+ return;
+ }
+ else
+ {
+ move_next();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ stack_pos = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const map_pair<domain,range>& binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ element (
+ ) const
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ map_pair<domain,range>& binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ element (
+ )
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ move_next (
+ ) const
+ {
+ // if we haven't started iterating yet
+ if (at_start_)
+ {
+ at_start_ = false;
+ if (tree_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the first element in the tree
+ current_element = tree_root;
+ node* temp = current_element->left;
+ while (temp != 0)
+ {
+ push(current_element);
+ current_element = temp;
+ temp = current_element->left;
+ }
+ return true;
+ }
+ }
+ else
+ {
+ if (current_element == 0)
+ {
+ return false;
+ }
+ else
+ {
+ node* temp;
+ bool went_up; // true if we went up the tree from a child node to parent
+ bool from_left = false; // true if we went up and were coming from a left child node
+ // find the next element in the tree
+ if (current_element->right != 0)
+ {
+ // go right and down
+ temp = current_element;
+ push(current_element);
+ current_element = temp->right;
+ went_up = false;
+ }
+ else
+ {
+ // go up to the parent if we can
+ if (current_element == tree_root)
+ {
+ // in this case we have iterated over all the element of the tree
+ current_element = 0;
+ return false;
+ }
+ went_up = true;
+ node* parent = pop();
+
+
+ from_left = (parent->left == current_element);
+ // go up to parent
+ current_element = parent;
+ }
+
+
+ while (true)
+ {
+ if (went_up)
+ {
+ if (from_left)
+ {
+ // in this case we have found the next node
+ break;
+ }
+ else
+ {
+ if (current_element == tree_root)
+ {
+ // in this case we have iterated over all the elements
+ // in the tree
+ current_element = 0;
+ return false;
+ }
+ // we should go up
+ node* parent = pop();
+ from_left = (parent->left == current_element);
+ current_element = parent;
+ }
+ }
+ else
+ {
+ // we just went down to a child node
+ if (current_element->left != 0)
+ {
+ // go left
+ went_up = false;
+ temp = current_element;
+ push(current_element);
+ current_element = temp->left;
+ }
+ else
+ {
+ // if there is no left child then we have found the next node
+ break;
+ }
+ }
+ }
+
+ return true;
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ delete_tree (
+ node* t
+ )
+ {
+ if (t->left != 0)
+ delete_tree(t->left);
+ if (t->right != 0)
+ delete_tree(t->right);
+ pool.deallocate(t);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ rotate_left (
+ node*& t
+ )
+ {
+
+ // set the new balance numbers
+ if (t->right->balance == 1)
+ {
+ t->balance = 0;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->balance = 1;
+ t->right->balance = -1;
+ }
+
+ // perform the rotation
+ node* temp = t->right;
+ t->right = temp->left;
+ temp->left = t;
+ t = temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ rotate_right (
+ node*& t
+ )
+ {
+ // set the new balance numbers
+ if (t->left->balance == -1)
+ {
+ t->balance = 0;
+ t->left->balance = 0;
+ }
+ else
+ {
+ t->balance = -1;
+ t->left->balance = 1;
+ }
+
+ // preform the rotation
+ node* temp = t->left;
+ t->left = temp->right;
+ temp->right = t;
+ t = temp;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ double_rotate_right (
+ node*& t
+ )
+ {
+
+ node* temp = t;
+ t = t->left->right;
+
+ temp->left->right = t->left;
+ t->left = temp->left;
+
+ temp->left = t->right;
+ t->right = temp;
+
+ if (t->balance < 0)
+ {
+ t->left->balance = 0;
+ t->right->balance = 1;
+ }
+ else if (t->balance > 0)
+ {
+ t->left->balance = -1;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->left->balance = 0;
+ t->right->balance = 0;
+ }
+ t->balance = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ double_rotate_left (
+ node*& t
+ )
+ {
+ node* temp = t;
+ t = t->right->left;
+
+ temp->right->left = t->right;
+ t->right = temp->right;
+
+ temp->right = t->left;
+ t->left = temp;
+
+ if (t->balance < 0)
+ {
+ t->left->balance = 0;
+ t->right->balance = 1;
+ }
+ else if (t->balance > 0)
+ {
+ t->left->balance = -1;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->left->balance = 0;
+ t->right->balance = 0;
+ }
+
+ t->balance = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_biggest_element_in_tree (
+ node*& t,
+ domain& d,
+ range& r
+ )
+ {
+ // make a reference to the current node so we don't have to dereference a
+ // pointer a bunch of times
+ node& tree = *t;
+
+ // if the right tree is an empty tree
+ if ( tree.right == 0)
+ {
+ // swap nodes domain and range elements into d and r
+ exchange(d,tree.d);
+ exchange(r,tree.r);
+
+ // plug hole left by removing this node
+ t = tree.left;
+
+ // delete the node that was just removed
+ pool.deallocate(&tree);
+
+ // return that the height of this part of the tree has decreased
+ return true;
+ }
+ else
+ {
+
+ // keep going right
+
+ // if remove made the tree one height shorter
+ if ( remove_biggest_element_in_tree(tree.right,d,r) )
+ {
+ // if this caused the current tree to strink then report that
+ if ( tree.balance == 1)
+ {
+ --tree.balance;
+ return true;
+ }
+ else
+ {
+ --tree.balance;
+ return keep_node_balanced(t);
+ }
+ }
+
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_least_element_in_tree (
+ node*& t,
+ domain& d,
+ range& r
+ )
+ {
+ // make a reference to the current node so we don't have to dereference a
+ // pointer a bunch of times
+ node& tree = *t;
+
+ // if the left tree is an empty tree
+ if ( tree.left == 0)
+ {
+ // swap nodes domain and range elements into d and r
+ exchange(d,tree.d);
+ exchange(r,tree.r);
+
+ // plug hole left by removing this node
+ t = tree.right;
+
+ // delete the node that was just removed
+ pool.deallocate(&tree);
+
+ // return that the height of this part of the tree has decreased
+ return true;
+ }
+ else
+ {
+
+ // keep going left
+
+ // if remove made the tree one height shorter
+ if ( remove_least_element_in_tree(tree.left,d,r) )
+ {
+ // if this caused the current tree to strink then report that
+ if ( tree.balance == -1)
+ {
+ ++tree.balance;
+ return true;
+ }
+ else
+ {
+ ++tree.balance;
+ return keep_node_balanced(t);
+ }
+ }
+
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ add_to_tree (
+ node*& t,
+ domain& d,
+ range& r
+ )
+ {
+
+ // if found place to add
+ if (t == 0)
+ {
+ // create a node to add new item into
+ t = pool.allocate();
+
+ // make a reference to the current node so we don't have to dereference a
+ // pointer a bunch of times
+ node& tree = *t;
+
+
+ // set left and right pointers to NULL to indicate that there are no
+ // left or right subtrees
+ tree.left = 0;
+ tree.right = 0;
+ tree.balance = 0;
+
+ // put d and r into t
+ exchange(tree.d,d);
+ exchange(tree.r,r);
+
+ // indicate that the height of this tree has increased
+ return true;
+ }
+ else // keep looking for a place to add the new item
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+ signed char old_balance = tree.balance;
+
+ // add the new item to whatever subtree it should go into
+ if (comp( d , tree.d) )
+ tree.balance -= add_to_tree(tree.left,d,r);
+ else
+ tree.balance += add_to_tree(tree.right,d,r);
+
+
+ // if the tree was balanced to start with
+ if (old_balance == 0)
+ {
+ // if its not balanced anymore then it grew in height
+ if (tree.balance != 0)
+ return true;
+ else
+ return false;
+ }
+ else
+ {
+ // if the tree is now balanced then it didn't grow
+ if (tree.balance == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // if the tree needs to be balanced
+ if (tree.balance != old_balance)
+ {
+ return !keep_node_balanced(t);
+ }
+ // if there has been no change in the heights
+ else
+ {
+ return false;
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ fix_stack (
+ node* t,
+ unsigned char depth
+ )
+ {
+ // if we found the node we were looking for
+ if (t == current_element)
+ {
+ stack_pos = depth;
+ return true;
+ }
+ else if (t == 0)
+ {
+ return false;
+ }
+
+ if (!( comp(t->d , current_element->d)))
+ {
+ // go left
+ if (fix_stack(t->left,depth+1))
+ {
+ stack[depth] = t;
+ return true;
+ }
+ }
+ if (!(comp(current_element->d , t->d)))
+ {
+ // go right
+ if (fix_stack(t->right,depth+1))
+ {
+ stack[depth] = t;
+ return true;
+ }
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_current_element_from_tree (
+ node*& t,
+ domain& d,
+ range& r,
+ unsigned long cur_stack_pos
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if we found the node we were looking for
+ if (t == current_element)
+ {
+
+ // swap nodes domain and range elements into d_copy and r
+ exchange(d,tree.d);
+ exchange(r,tree.r);
+
+ // if there is no left node
+ if (tree.left == 0)
+ {
+ // move the enumerator on to the next element before we mess with the
+ // tree
+ move_next();
+
+ // plug hole left by removing this node and free memory
+ t = tree.right; // plug hole with right subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height has changed
+ return true;
+ }
+ // if there is no right node
+ else if (tree.right == 0)
+ {
+ // move the enumerator on to the next element before we mess with the
+ // tree
+ move_next();
+
+ // plug hole left by removing this node and free memory
+ t = tree.left; // plug hole with left subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height of this tree has changed
+ return true;
+ }
+ // if there are both a left and right sub node
+ else
+ {
+
+ // in this case the next current element is going to get swapped back
+ // into this t node.
+ current_element = t;
+
+ // get an element that can replace the one being removed and do this
+ // if it made the right subtree shrink by one
+ if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
+ {
+ // adjust the tree height
+ --tree.balance;
+
+ // if the height of the current tree has dropped by one
+ if (tree.balance == 0)
+ {
+ return true;
+ }
+ else
+ {
+ return keep_node_balanced(t);
+ }
+ }
+ // else this remove did not effect the height of this tree
+ else
+ {
+ return false;
+ }
+
+ }
+
+ }
+ else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.left) ||
+ tree.left == current_element )
+ {
+ // go left
+ if (tree.balance == -1)
+ {
+ int balance = tree.balance;
+ balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+ }
+ else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.right) ||
+ tree.right == current_element )
+ {
+ // go right
+ if (tree.balance == 1)
+ {
+ int balance = tree.balance;
+ balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+ }
+
+ // this return should never happen but do it anyway to suppress compiler warnings
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_from_tree (
+ node*& t,
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if item is on the left
+ if (comp(d , tree.d))
+ {
+ // if the left side of the tree has the greatest height
+ if (tree.balance == -1)
+ {
+ int balance = tree.balance;
+ balance += remove_from_tree(tree.left,d,d_copy,r);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance += remove_from_tree(tree.left,d,d_copy,r);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+
+ }
+ // if item is on the right
+ else if (comp(tree.d , d))
+ {
+
+ // if the right side of the tree has the greatest height
+ if (tree.balance == 1)
+ {
+ int balance = tree.balance;
+ balance -= remove_from_tree(tree.right,d,d_copy,r);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance -= remove_from_tree(tree.right,d,d_copy,r);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+ }
+ // if item is found
+ else
+ {
+
+ // swap nodes domain and range elements into d_copy and r
+ exchange(d_copy,tree.d);
+ exchange(r,tree.r);
+
+ // if there is no left node
+ if (tree.left == 0)
+ {
+
+ // plug hole left by removing this node and free memory
+ t = tree.right; // plug hole with right subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height has changed
+ return true;
+ }
+ // if there is no right node
+ else if (tree.right == 0)
+ {
+
+ // plug hole left by removing this node and free memory
+ t = tree.left; // plug hole with left subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height of this tree has changed
+ return true;
+ }
+ // if there are both a left and right sub node
+ else
+ {
+
+ // get an element that can replace the one being removed and do this
+ // if it made the right subtree shrink by one
+ if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
+ {
+ // adjust the tree height
+ --tree.balance;
+
+ // if the height of the current tree has dropped by one
+ if (tree.balance == 0)
+ {
+ return true;
+ }
+ else
+ {
+ return keep_node_balanced(t);
+ }
+ }
+ // else this remove did not effect the height of this tree
+ else
+ {
+ return false;
+ }
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ remove_from_tree (
+ node*& t,
+ const domain& d
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if item is on the left
+ if (comp(d , tree.d))
+ {
+ // if the left side of the tree has the greatest height
+ if (tree.balance == -1)
+ {
+ int balance = tree.balance;
+ balance += remove_from_tree(tree.left,d);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance += remove_from_tree(tree.left,d);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+
+ }
+ // if item is on the right
+ else if (comp(tree.d , d))
+ {
+
+ // if the right side of the tree has the greatest height
+ if (tree.balance == 1)
+ {
+ int balance = tree.balance;
+ balance -= remove_from_tree(tree.right,d);
+ tree.balance = balance;
+ return !tree.balance;
+ }
+ else
+ {
+ int balance = tree.balance;
+ balance -= remove_from_tree(tree.right,d);
+ tree.balance = balance;
+ return keep_node_balanced(t);
+ }
+ }
+ // if item is found
+ else
+ {
+
+ // if there is no left node
+ if (tree.left == 0)
+ {
+
+ // plug hole left by removing this node and free memory
+ t = tree.right; // plug hole with right subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height has changed
+ return true;
+ }
+ // if there is no right node
+ else if (tree.right == 0)
+ {
+
+ // plug hole left by removing this node and free memory
+ t = tree.left; // plug hole with left subtree
+
+ // delete old node
+ pool.deallocate(&tree);
+
+ // indicate that the height of this tree has changed
+ return true;
+ }
+ // if there are both a left and right sub node
+ else
+ {
+
+ // get an element that can replace the one being removed and do this
+ // if it made the right subtree shrink by one
+ if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
+ {
+ // adjust the tree height
+ --tree.balance;
+
+ // if the height of the current tree has dropped by one
+ if (tree.balance == 0)
+ {
+ return true;
+ }
+ else
+ {
+ return keep_node_balanced(t);
+ }
+ }
+ // else this remove did not effect the height of this tree
+ else
+ {
+ return false;
+ }
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ range* binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ return_reference (
+ node* t,
+ const domain& d
+ )
+ {
+ while (t != 0)
+ {
+
+ if ( comp(d , t->d ))
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // if it's found then return a reference to it
+ return &(t->r);
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const range* binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ return_reference (
+ const node* t,
+ const domain& d
+ ) const
+ {
+ while (t != 0)
+ {
+
+ if ( comp(d , t->d) )
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // if it's found then return a reference to it
+ return &(t->r);
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ keep_node_balanced (
+ node*& t
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if tree does not need to be balanced then return false
+ if (tree.balance == 0)
+ return false;
+
+
+ // if tree needs to be rotated left
+ if (tree.balance == 2)
+ {
+ if (tree.right->balance >= 0)
+ rotate_left(t);
+ else
+ double_rotate_left(t);
+ }
+ // else if the tree needs to be rotated right
+ else if (tree.balance == -2)
+ {
+ if (tree.left->balance <= 0)
+ rotate_right(t);
+ else
+ double_rotate_right(t);
+ }
+
+
+ if (t->balance == 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long binary_search_tree_kernel_1<domain,range,mem_manager,compare>::
+ get_count (
+ const domain& d,
+ node* tree_root
+ ) const
+ {
+ if (tree_root != 0)
+ {
+ if (comp(d , tree_root->d))
+ {
+ // go left
+ return get_count(d,tree_root->left);
+ }
+ else if (comp(tree_root->d , d))
+ {
+ // go right
+ return get_count(d,tree_root->right);
+ }
+ else
+ {
+ // go left and right to look for more matches
+ return get_count(d,tree_root->left)
+ + get_count(d,tree_root->right)
+ + 1;
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h
new file mode 100644
index 000000000..098d38c2e
--- /dev/null
+++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h
@@ -0,0 +1,1897 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_2_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_2_
+
+#include "binary_search_tree_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare = std::less<domain>
+ >
+ class binary_search_tree_kernel_2 : public enumerable<map_pair<domain,range> >,
+ public asc_pair_remover<domain,range,compare>
+ {
+
+ /*!
+ INITIAL VALUE
+ NIL == pointer to a node that represents a leaf
+ tree_size == 0
+ tree_root == NIL
+ at_start == true
+ current_element == 0
+
+
+ CONVENTION
+ current_element_valid() == (current_element != 0)
+ if (current_element_valid()) then
+ element() == current_element->d and current_element->r
+ at_start_ == at_start()
+
+
+ tree_size == size()
+
+ NIL == pointer to a node that represents a leaf
+
+ if (tree_size != 0)
+ tree_root == pointer to the root node of the binary search tree
+ else
+ tree_root == NIL
+
+ tree_root->color == black
+ Every leaf is black and all leafs are the NIL node.
+ The number of black nodes in any path from the root to a leaf is the
+ same.
+
+ for all nodes:
+ {
+ - left points to the left subtree or NIL if there is no left subtree
+ - right points to the right subtree or NIL if there is no right
+ subtree
+ - parent points to the parent node or NIL if the node is the root
+ - ordering of nodes is determined by comparing each node's d member
+ - all elements in a left subtree are <= the node
+ - all elements in a right subtree are >= the node
+ - color == red or black
+ - if (color == red)
+ - the node's children are black
+ }
+
+ !*/
+
+ class node
+ {
+ public:
+ node* left;
+ node* right;
+ node* parent;
+ domain d;
+ range r;
+ char color;
+ };
+
+ class mpair : public map_pair<domain,range>
+ {
+ public:
+ const domain* d;
+ range* r;
+
+ const domain& key(
+ ) const { return *d; }
+
+ const range& value(
+ ) const { return *r; }
+
+ range& value(
+ ) { return *r; }
+ };
+
+
+ const static char red = 0;
+ const static char black = 1;
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ binary_search_tree_kernel_2(
+ ) :
+ NIL(pool.allocate()),
+ tree_size(0),
+ tree_root(NIL),
+ current_element(0),
+ at_start_(true)
+ {
+ NIL->color = black;
+ NIL->left = 0;
+ NIL->right = 0;
+ NIL->parent = 0;
+ }
+
+ virtual ~binary_search_tree_kernel_2(
+ );
+
+ inline void clear(
+ );
+
+ inline short height (
+ ) const;
+
+ inline unsigned long count (
+ const domain& d
+ ) const;
+
+ inline void add (
+ domain& d,
+ range& r
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ inline const range* operator[] (
+ const domain& item
+ ) const;
+
+ inline range* operator[] (
+ const domain& item
+ );
+
+ inline void swap (
+ binary_search_tree_kernel_2& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const map_pair<domain,range>& element (
+ ) const;
+
+ map_pair<domain,range>& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ void remove_last_in_order (
+ domain& d,
+ range& r
+ );
+
+ void remove_current_element (
+ domain& d,
+ range& r
+ );
+
+ void position_enumerator (
+ const domain& d
+ ) const;
+
+ private:
+
+ inline void rotate_left (
+ node* t
+ );
+ /*!
+ requires
+ - t != NIL
+ - t->right != NIL
+ ensures
+ - performs a left rotation around t and its right child
+ !*/
+
+ inline void rotate_right (
+ node* t
+ );
+ /*!
+ requires
+ - t != NIL
+ - t->left != NIL
+ ensures
+ - performs a right rotation around t and its left child
+ !*/
+
+ inline void double_rotate_right (
+ node* t
+ );
+ /*!
+ requires
+ - t != NIL
+ - t->left != NIL
+ - t->left->right != NIL
+ - double_rotate_right() is only called in fix_after_add()
+ ensures
+ - performs a left rotation around t->left
+ - then performs a right rotation around t
+ !*/
+
+ inline void double_rotate_left (
+ node* t
+ );
+ /*!
+ requires
+ - t != NIL
+ - t->right != NIL
+ - t->right->left != NIL
+ - double_rotate_left() is only called in fix_after_add()
+ ensures
+ - performs a right rotation around t->right
+ - then performs a left rotation around t
+ !*/
+
+ void remove_biggest_element_in_tree (
+ node* t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t != NIL (i.e. there must be something in the tree to remove)
+ ensures
+ - the biggest node in t has been removed
+ - the biggest node element in t has been put into #d and #r
+ - #t is still a binary search tree
+ !*/
+
+ bool remove_least_element_in_tree (
+ node* t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t != NIL (i.e. there must be something in the tree to remove)
+ ensures
+ - the least node in t has been removed
+ - the least node element in t has been put into #d and #r
+ - #t is still a binary search tree
+ - if (the node that was removed was the one pointed to by current_element) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void add_to_tree (
+ node* t,
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - t != NIL
+ ensures
+ - d and r are now in #t
+ - there is a mapping from d to r in #t
+ - #d and #r have initial values for their types
+ - #t is still a binary search tree
+ !*/
+
+ void remove_from_tree (
+ node* t,
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - return_reference(t,d) != 0
+ ensures
+ - #d_copy is equivalent to d
+ - the first element in t equivalent to d that is encountered when searching down the tree
+ from t has been removed and swapped into #d_copy. Also, the associated range element
+ has been removed and swapped into #r.
+ - if (the node that got removed wasn't current_element) then
+ - adjusts the current_element pointer if the data in the node that it points to gets moved.
+ - else
+ - the value of current_element is now invalid
+ - #t is still a binary search tree
+ !*/
+
+ void remove_from_tree (
+ node* t,
+ const domain& d
+ );
+ /*!
+ requires
+ - return_reference(t,d) != 0
+ ensures
+ - an element in t equivalent to d has been removed
+ - #t is still a binary search tree
+ !*/
+
+ const range* return_reference (
+ const node* t,
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - if (there is a domain element equivalent to d in t) then
+ - returns a pointer to the element in the range equivalent to d
+ - else
+ - returns 0
+ !*/
+
+ range* return_reference (
+ node* t,
+ const domain& d
+ );
+ /*!
+ ensures
+ - if (there is a domain element equivalent to d in t) then
+ - returns a pointer to the element in the range equivalent to d
+ - else
+ - returns 0
+ !*/
+
+ void fix_after_add (
+ node* t
+ );
+ /*!
+ requires
+ - t == pointer to the node just added
+ - t->color == red
+ - t->parent != NIL (t must not be the root)
+ - fix_after_add() is only called after a new node has been added
+ to t
+ ensures
+ - fixes any deviations from the CONVENTION caused by adding a node
+ !*/
+
+ void fix_after_remove (
+ node* t
+ );
+ /*!
+ requires
+ - t == pointer to the only child of the node that was spliced out
+ - fix_after_remove() is only called after a node has been removed
+ from t
+ - the color of the spliced out node was black
+ ensures
+ - fixes any deviations from the CONVENTION causes by removing a node
+ !*/
+
+
+ short tree_height (
+ node* t
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the longest path from the root of the
+ tree to a leaf
+ !*/
+
+ void delete_tree (
+ node* t
+ );
+ /*!
+ requires
+ - t == root of binary search tree
+ - t != NIL
+ ensures
+ - deletes all nodes in t except for NIL
+ !*/
+
+ unsigned long get_count (
+ const domain& item,
+ node* tree_root
+ ) const;
+ /*!
+ requires
+ - tree_root == the root of a binary search tree or NIL
+ ensures
+ - if (tree_root == NIL) then
+ - returns 0
+ - else
+ - returns the number of elements in tree_root that are
+ equivalent to item
+ !*/
+
+
+
+ // data members
+ typename mem_manager::template rebind<node>::other pool;
+ node* NIL;
+ unsigned long tree_size;
+ node* tree_root;
+ mutable node* current_element;
+ mutable bool at_start_;
+ mutable mpair p;
+ compare comp;
+
+
+
+ // restricted functions
+ binary_search_tree_kernel_2(binary_search_tree_kernel_2&);
+ binary_search_tree_kernel_2& operator=(binary_search_tree_kernel_2&);
+
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ binary_search_tree_kernel_2<domain,range,mem_manager,compare>& a,
+ binary_search_tree_kernel_2<domain,range,mem_manager,compare>& b
+ ) { a.swap(b); }
+
+
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ binary_search_tree_kernel_2<domain,range,mem_manager,compare>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_2");
+ }
+ }
+
+
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ ~binary_search_tree_kernel_2 (
+ )
+ {
+ if (tree_root != NIL)
+ delete_tree(tree_root);
+ pool.deallocate(NIL);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ clear (
+ )
+ {
+ if (tree_size > 0)
+ {
+ delete_tree(tree_root);
+ tree_root = NIL;
+ tree_size = 0;
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ size_t binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ size (
+ ) const
+ {
+ return tree_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ short binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ height (
+ ) const
+ {
+ return tree_height(tree_root);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ count (
+ const domain& item
+ ) const
+ {
+ return get_count(item,tree_root);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+ if (tree_size == 0)
+ {
+ tree_root = pool.allocate();
+ tree_root->color = black;
+ tree_root->left = NIL;
+ tree_root->right = NIL;
+ tree_root->parent = NIL;
+ exchange(tree_root->d,d);
+ exchange(tree_root->r,r);
+ }
+ else
+ {
+ add_to_tree(tree_root,d,r);
+ }
+ ++tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ remove_from_tree(tree_root,d,d_copy,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ destroy (
+ const domain& item
+ )
+ {
+ remove_from_tree(tree_root,item);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_any (
+ domain& d,
+ range& r
+ )
+ {
+ remove_least_element_in_tree(tree_root,d,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ range* binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ operator[] (
+ const domain& d
+ )
+ {
+ return return_reference(tree_root,d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const range* binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ operator[] (
+ const domain& d
+ ) const
+ {
+ return return_reference(tree_root,d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ swap (
+ binary_search_tree_kernel_2<domain,range,mem_manager,compare>& item
+ )
+ {
+ pool.swap(item.pool);
+
+ exchange(p,item.p);
+ exchange(comp,item.comp);
+
+ node* tree_root_temp = item.tree_root;
+ unsigned long tree_size_temp = item.tree_size;
+ node* const NIL_temp = item.NIL;
+ node* current_element_temp = item.current_element;
+ bool at_start_temp = item.at_start_;
+
+ item.tree_root = tree_root;
+ item.tree_size = tree_size;
+ item.NIL = NIL;
+ item.current_element = current_element;
+ item.at_start_ = at_start_;
+
+ tree_root = tree_root_temp;
+ tree_size = tree_size_temp;
+ NIL = NIL_temp;
+ current_element = current_element_temp;
+ at_start_ = at_start_temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_last_in_order (
+ domain& d,
+ range& r
+ )
+ {
+ remove_biggest_element_in_tree(tree_root,d,r);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_current_element (
+ domain& d,
+ range& r
+ )
+ {
+ node* t = current_element;
+ move_next();
+ remove_from_tree(t,t->d,d,r);
+ --tree_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ position_enumerator (
+ const domain& d
+ ) const
+ {
+ // clear the enumerator state and make sure the stack is empty
+ reset();
+ at_start_ = false;
+ node* t = tree_root;
+ node* parent = NIL;
+ bool went_left = false;
+ while (t != NIL)
+ {
+ if ( comp(d , t->d ))
+ {
+ // if item is on the left then look in left
+ parent = t;
+ t = t->left;
+ went_left = true;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ parent = t;
+ t = t->right;
+ went_left = false;
+ }
+ else
+ {
+ current_element = t;
+ return;
+ }
+ }
+
+ // if we didn't find any matches but there might be something after the
+ // d in this tree.
+ if (parent != NIL)
+ {
+ current_element = parent;
+ // if we went left from this node then this node is the next
+ // biggest.
+ if (went_left)
+ {
+ return;
+ }
+ else
+ {
+ move_next();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const map_pair<domain,range>& binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ element (
+ ) const
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ map_pair<domain,range>& binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ element (
+ )
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ move_next (
+ ) const
+ {
+ // if we haven't started iterating yet
+ if (at_start_)
+ {
+ at_start_ = false;
+ if (tree_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the first element in the tree
+ current_element = tree_root;
+ node* temp = current_element->left;
+ while (temp != NIL)
+ {
+ current_element = temp;
+ temp = current_element->left;
+ }
+ return true;
+ }
+ }
+ else
+ {
+ if (current_element == 0)
+ {
+ return false;
+ }
+ else
+ {
+ bool went_up; // true if we went up the tree from a child node to parent
+ bool from_left = false; // true if we went up and were coming from a left child node
+ // find the next element in the tree
+ if (current_element->right != NIL)
+ {
+ // go right and down
+ current_element = current_element->right;
+ went_up = false;
+ }
+ else
+ {
+ went_up = true;
+ node* parent = current_element->parent;
+ if (parent == NIL)
+ {
+ // in this case we have iterated over all the element of the tree
+ current_element = 0;
+ return false;
+ }
+
+ from_left = (parent->left == current_element);
+ // go up to parent
+ current_element = parent;
+ }
+
+
+ while (true)
+ {
+ if (went_up)
+ {
+ if (from_left)
+ {
+ // in this case we have found the next node
+ break;
+ }
+ else
+ {
+ // we should go up
+ node* parent = current_element->parent;
+ from_left = (parent->left == current_element);
+ current_element = parent;
+ if (current_element == NIL)
+ {
+ // in this case we have iterated over all the elements
+ // in the tree
+ current_element = 0;
+ return false;
+ }
+ }
+ }
+ else
+ {
+ // we just went down to a child node
+ if (current_element->left != NIL)
+ {
+ // go left
+ went_up = false;
+ current_element = current_element->left;
+ }
+ else
+ {
+ // if there is no left child then we have found the next node
+ break;
+ }
+ }
+ }
+
+ return true;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ delete_tree (
+ node* t
+ )
+ {
+ if (t->left != NIL)
+ delete_tree(t->left);
+ if (t->right != NIL)
+ delete_tree(t->right);
+ pool.deallocate(t);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ rotate_left (
+ node* t
+ )
+ {
+
+ // perform the rotation
+ node* temp = t->right;
+ t->right = temp->left;
+ if (temp->left != NIL)
+ temp->left->parent = t;
+ temp->left = t;
+ temp->parent = t->parent;
+
+
+ if (t == tree_root)
+ tree_root = temp;
+ else
+ {
+ // if t was on the left
+ if (t->parent->left == t)
+ t->parent->left = temp;
+ else
+ t->parent->right = temp;
+ }
+
+ t->parent = temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ rotate_right (
+ node* t
+ )
+ {
+ // perform the rotation
+ node* temp = t->left;
+ t->left = temp->right;
+ if (temp->right != NIL)
+ temp->right->parent = t;
+ temp->right = t;
+ temp->parent = t->parent;
+
+ if (t == tree_root)
+ tree_root = temp;
+ else
+ {
+ // if t is a left child
+ if (t->parent->left == t)
+ t->parent->left = temp;
+ else
+ t->parent->right = temp;
+ }
+
+ t->parent = temp;
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ double_rotate_right (
+ node* t
+ )
+ {
+
+ // preform the rotation
+ node& temp = *(t->left->right);
+ t->left = temp.right;
+ temp.right->parent = t;
+ temp.left->parent = temp.parent;
+ temp.parent->right = temp.left;
+ temp.parent->parent = &temp;
+ temp.right = t;
+ temp.left = temp.parent;
+ temp.parent = t->parent;
+
+
+ if (tree_root == t)
+ tree_root = &temp;
+ else
+ {
+ // t is a left child
+ if (t->parent->left == t)
+ t->parent->left = &temp;
+ else
+ t->parent->right = &temp;
+ }
+ t->parent = &temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ double_rotate_left (
+ node* t
+ )
+ {
+
+
+ // preform the rotation
+ node& temp = *(t->right->left);
+ t->right = temp.left;
+ temp.left->parent = t;
+ temp.right->parent = temp.parent;
+ temp.parent->left = temp.right;
+ temp.parent->parent = &temp;
+ temp.left = t;
+ temp.right = temp.parent;
+ temp.parent = t->parent;
+
+
+ if (tree_root == t)
+ tree_root = &temp;
+ else
+ {
+ // t is a left child
+ if (t->parent->left == t)
+ t->parent->left = &temp;
+ else
+ t->parent->right = &temp;
+ }
+ t->parent = &temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_biggest_element_in_tree (
+ node* t,
+ domain& d,
+ range& r
+ )
+ {
+
+ node* next = t->right;
+ node* child; // the child node of the one we will slice out
+
+ if (next == NIL)
+ {
+ // need to determine if t is a right or left child
+ if (t->parent->right == t)
+ child = t->parent->right = t->left;
+ else
+ child = t->parent->left = t->left;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = child;
+ }
+ else
+ {
+ // find the least node
+ do
+ {
+ t = next;
+ next = next->right;
+ } while (next != NIL);
+ // t is a right child
+ child = t->parent->right = t->left;
+
+ }
+
+ // swap the item from this node into d and r
+ exchange(d,t->d);
+ exchange(r,t->r);
+
+ // plug hole right by removing this node
+ child->parent = t->parent;
+
+ // keep the red-black properties true
+ if (t->color == black)
+ fix_after_remove(child);
+
+ // free the memory for this removed node
+ pool.deallocate(t);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_least_element_in_tree (
+ node* t,
+ domain& d,
+ range& r
+ )
+ {
+
+ node* next = t->left;
+ node* child; // the child node of the one we will slice out
+
+ if (next == NIL)
+ {
+ // need to determine if t is a left or right child
+ if (t->parent->left == t)
+ child = t->parent->left = t->right;
+ else
+ child = t->parent->right = t->right;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = child;
+ }
+ else
+ {
+ // find the least node
+ do
+ {
+ t = next;
+ next = next->left;
+ } while (next != NIL);
+ // t is a left child
+ child = t->parent->left = t->right;
+
+ }
+
+ // swap the item from this node into d and r
+ exchange(d,t->d);
+ exchange(r,t->r);
+
+ // plug hole left by removing this node
+ child->parent = t->parent;
+
+ // keep the red-black properties true
+ if (t->color == black)
+ fix_after_remove(child);
+
+ bool rvalue = (t == current_element);
+ // free the memory for this removed node
+ pool.deallocate(t);
+ return rvalue;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ add_to_tree (
+ node* t,
+ domain& d,
+ range& r
+ )
+ {
+ // parent of the current node
+ node* parent;
+
+ // find a place to add node
+ while (true)
+ {
+ parent = t;
+ // if item should be put on the left then go left
+ if (comp(d , t->d))
+ {
+ t = t->left;
+ if (t == NIL)
+ {
+ t = parent->left = pool.allocate();
+ break;
+ }
+ }
+ // if item should be put on the right then go right
+ else
+ {
+ t = t->right;
+ if (t == NIL)
+ {
+ t = parent->right = pool.allocate();
+ break;
+ }
+ }
+ }
+
+ // t is now the node where we will add item and
+ // parent is the parent of t
+
+ t->parent = parent;
+ t->left = NIL;
+ t->right = NIL;
+ t->color = red;
+ exchange(t->d,d);
+ exchange(t->r,r);
+
+
+ // keep the red-black properties true
+ fix_after_add(t);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_from_tree (
+ node* t,
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ while (true)
+ {
+ if ( comp(d , t->d) )
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // found the node we want to remove
+
+ // swap out the item into d_copy and r
+ exchange(d_copy,t->d);
+ exchange(r,t->r);
+
+ if (t->left == NIL)
+ {
+ // if there is no left subtree
+
+ node* parent = t->parent;
+
+ // plug hole with right subtree
+
+
+ // if t is on the left
+ if (parent->left == t)
+ parent->left = t->right;
+ else
+ parent->right = t->right;
+ t->right->parent = parent;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = t->right;
+
+ if (t->color == black)
+ fix_after_remove(t->right);
+
+ // delete old node
+ pool.deallocate(t);
+ }
+ else if (t->right == NIL)
+ {
+ // if there is no right subtree
+
+ node* parent = t->parent;
+
+ // plug hole with left subtree
+ if (parent->left == t)
+ parent->left = t->left;
+ else
+ parent->right = t->left;
+ t->left->parent = parent;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = t->left;
+
+ if (t->color == black)
+ fix_after_remove(t->left);
+
+ // delete old node
+ pool.deallocate(t);
+ }
+ else
+ {
+ // if there is both a left and right subtree
+ // get an element to fill this node now that its been swapped into
+ // item_copy
+ if (remove_least_element_in_tree(t->right,t->d,t->r))
+ {
+ // the node removed was the one pointed to by current_element so we
+ // need to update it so that it points to the right spot.
+ current_element = t;
+ }
+ }
+
+ // quit loop
+ break;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ remove_from_tree (
+ node* t,
+ const domain& d
+ )
+ {
+ while (true)
+ {
+ if ( comp(d , t->d) )
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // found the node we want to remove
+
+
+ if (t->left == NIL)
+ {
+ // if there is no left subtree
+
+ node* parent = t->parent;
+
+ // plug hole with right subtree
+
+
+ if (parent->left == t)
+ parent->left = t->right;
+ else
+ parent->right = t->right;
+ t->right->parent = parent;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = t->right;
+
+ if (t->color == black)
+ fix_after_remove(t->right);
+
+ // delete old node
+ pool.deallocate(t);
+ }
+ else if (t->right == NIL)
+ {
+ // if there is no right subtree
+
+ node* parent = t->parent;
+
+ // plug hole with left subtree
+ if (parent->left == t)
+ parent->left = t->left;
+ else
+ parent->right = t->left;
+ t->left->parent = parent;
+
+ // update tree_root if necessary
+ if (t == tree_root)
+ tree_root = t->left;
+
+ if (t->color == black)
+ fix_after_remove(t->left);
+
+ // delete old node
+ pool.deallocate(t);
+ }
+ else
+ {
+ // if there is both a left and right subtree
+ // get an element to fill this node now that its been swapped into
+ // item_copy
+ remove_least_element_in_tree(t->right,t->d,t->r);
+
+ }
+
+ // quit loop
+ break;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ range* binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ return_reference (
+ node* t,
+ const domain& d
+ )
+ {
+ while (t != NIL)
+ {
+ if ( comp(d , t->d ))
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // if it's found then return a reference to it
+ return &(t->r);
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const range* binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ return_reference (
+ const node* t,
+ const domain& d
+ ) const
+ {
+ while (t != NIL)
+ {
+ if ( comp(d , t->d) )
+ {
+ // if item is on the left then look in left
+ t = t->left;
+ }
+ else if (comp(t->d , d))
+ {
+ // if item is on the right then look in right
+ t = t->right;
+ }
+ else
+ {
+ // if it's found then return a reference to it
+ return &(t->r);
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ fix_after_add (
+ node* t
+ )
+ {
+
+ while (t->parent->color == red)
+ {
+ node& grandparent = *(t->parent->parent);
+
+ // if both t's parent and its sibling are red
+ if (grandparent.left->color == grandparent.right->color)
+ {
+ grandparent.color = red;
+ grandparent.left->color = black;
+ grandparent.right->color = black;
+ t = &grandparent;
+ }
+ else
+ {
+ // if t is a left child
+ if (t == t->parent->left)
+ {
+ // if t's parent is a left child
+ if (t->parent == grandparent.left)
+ {
+ grandparent.color = red;
+ grandparent.left->color = black;
+ rotate_right(&grandparent);
+ }
+ // if t's parent is a right child
+ else
+ {
+ t->color = black;
+ grandparent.color = red;
+ double_rotate_left(&grandparent);
+ }
+ }
+ // if t is a right child
+ else
+ {
+ // if t's parent is a left child
+ if (t->parent == grandparent.left)
+ {
+ t->color = black;
+ grandparent.color = red;
+ double_rotate_right(&grandparent);
+ }
+ // if t's parent is a right child
+ else
+ {
+ grandparent.color = red;
+ grandparent.right->color = black;
+ rotate_left(&grandparent);
+ }
+ }
+ break;
+ }
+ }
+ tree_root->color = black;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ fix_after_remove (
+ node* t
+ )
+ {
+
+ while (t != tree_root && t->color == black)
+ {
+ if (t->parent->left == t)
+ {
+ node* sibling = t->parent->right;
+ if (sibling->color == red)
+ {
+ sibling->color = black;
+ t->parent->color = red;
+ rotate_left(t->parent);
+ sibling = t->parent->right;
+ }
+
+ if (sibling->left->color == black && sibling->right->color == black)
+ {
+ sibling->color = red;
+ t = t->parent;
+ }
+ else
+ {
+ if (sibling->right->color == black)
+ {
+ sibling->left->color = black;
+ sibling->color = red;
+ rotate_right(sibling);
+ sibling = t->parent->right;
+ }
+
+ sibling->color = t->parent->color;
+ t->parent->color = black;
+ sibling->right->color = black;
+ rotate_left(t->parent);
+ t = tree_root;
+
+ }
+
+
+ }
+ else
+ {
+
+ node* sibling = t->parent->left;
+ if (sibling->color == red)
+ {
+ sibling->color = black;
+ t->parent->color = red;
+ rotate_right(t->parent);
+ sibling = t->parent->left;
+ }
+
+ if (sibling->left->color == black && sibling->right->color == black)
+ {
+ sibling->color = red;
+ t = t->parent;
+ }
+ else
+ {
+ if (sibling->left->color == black)
+ {
+ sibling->right->color = black;
+ sibling->color = red;
+ rotate_left(sibling);
+ sibling = t->parent->left;
+ }
+
+ sibling->color = t->parent->color;
+ t->parent->color = black;
+ sibling->left->color = black;
+ rotate_right(t->parent);
+ t = tree_root;
+
+ }
+
+
+ }
+
+ }
+ t->color = black;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ short binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ tree_height (
+ node* t
+ ) const
+ {
+ if (t == NIL)
+ return 0;
+
+ short height1 = tree_height(t->left);
+ short height2 = tree_height(t->right);
+ if (height1 > height2)
+ return height1 + 1;
+ else
+ return height2 + 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long binary_search_tree_kernel_2<domain,range,mem_manager,compare>::
+ get_count (
+ const domain& d,
+ node* tree_root
+ ) const
+ {
+ if (tree_root != NIL)
+ {
+ if (comp(d , tree_root->d))
+ {
+ // go left
+ return get_count(d,tree_root->left);
+ }
+ else if (comp(tree_root->d , d))
+ {
+ // go right
+ return get_count(d,tree_root->right);
+ }
+ else
+ {
+ // go left and right to look for more matches
+ return get_count(d,tree_root->left)
+ + get_count(d,tree_root->right)
+ + 1;
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_2_
+
diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h
new file mode 100644
index 000000000..2abfe7e39
--- /dev/null
+++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h
@@ -0,0 +1,311 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_
+#ifdef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_
+
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class binary_search_tree : public enumerable<map_pair<domain,range> >,
+ public asc_pair_remover<domain,range,compare>
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ domain is swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range is swappable by a global swap() and
+ range must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap(), count(), height(), and operator[] functions
+ do not invalidate pointers or references to internal data.
+
+ position_enumerator() invalidates pointers or references to
+ data returned by element() and only by element() (i.e. pointers and
+ references returned by operator[] are still valid).
+
+ All other functions have no such guarantees.
+
+ INITIAL VALUE
+ size() == 0
+ height() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the domain (and each associated
+ range element) elements in ascending order according to the compare functor.
+ (i.e. the elements are enumerated in sorted order)
+
+ WHAT THIS OBJECT REPRESENTS
+ this object represents a data dictionary that is built on top of some
+ kind of binary search tree. It maps objects of type domain to objects
+ of type range.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ NOTE:
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ binary_search_tree(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ !*/
+
+ virtual ~binary_search_tree(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ short height (
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements in the longest path from the root
+ of the tree to a leaf
+ !*/
+
+ unsigned long count (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements in the domain of *this that are
+ equivalent to d
+ !*/
+
+ void add (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ ensures
+ - adds a mapping between d and r to *this
+ - if (count(d) == 0) then
+ - #*(*this)[d] == r
+ - else
+ - #(*this)[d] != 0
+ - #d and #r have initial values for their types
+ - #count(d) == count(d) + 1
+ - #at_start() == true
+ - #size() == size() + 1
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if add() throws then it has no effect
+ !*/
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - (*this)[d] != 0
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - &d != &d_copy (i.e. d and d_copy cannot be the same variable)
+ - &r != &d_copy (i.e. r and d_copy cannot be the same variable)
+ ensures
+ - some element in the domain of *this that is equivalent to d has
+ been removed and swapped into #d_copy. Additionally, its
+ associated range element has been removed and swapped into #r.
+ - #count(d) == count(d) - 1
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const domain& d
+ );
+ /*!
+ requires
+ - (*this)[d] != 0
+ ensures
+ - an element in the domain of *this equivalent to d has been removed.
+ The element in the range of *this associated with d has also been
+ removed.
+ - #count(d) == count(d) - 1
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void remove_last_in_order (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - the last/biggest (according to the compare functor) element in the domain of *this has
+ been removed and swapped into #d. The element in the range of *this
+ associated with #d has also been removed and swapped into #r.
+ - #count(#d) == count(#d) - 1
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void remove_current_element (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - current_element_valid() == true
+ ensures
+ - the current element given by element() has been removed and swapped into d and r.
+ - #d == element().key()
+ - #r == element().value()
+ - #count(#d) == count(#d) - 1
+ - #size() == size() - 1
+ - moves the enumerator to the next element. If element() was the last
+ element in enumeration order then #current_element_valid() == false
+ and #at_start() == false.
+ !*/
+
+ void position_enumerator (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - #at_start() == false
+ - if (count(d) > 0) then
+ - #element().key() == d
+ - else if (there are any items in the domain of *this that are bigger than
+ d according to the compare functor) then
+ - #element().key() == the smallest item in the domain of *this that is
+ bigger than d according to the compare functor.
+ - else
+ - #current_element_valid() == false
+ !*/
+
+ const range* operator[] (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ range* operator[] (
+ const domain& d
+ );
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ void swap (
+ binary_search_tree& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ binary_search_tree(binary_search_tree&);
+ binary_search_tree& operator=(binary_search_tree&);
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ binary_search_tree<domain,range,mem_manager,compare>& a,
+ binary_search_tree<domain,range,mem_manager,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ binary_search_tree<domain,range,mem_manager,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h
new file mode 100644
index 000000000..0dc153961
--- /dev/null
+++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h
@@ -0,0 +1,235 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_C_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_C_
+
+#include "../interfaces/map_pair.h"
+#include "binary_search_tree_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename bst_base
+ >
+ class binary_search_tree_kernel_c : public bst_base
+ {
+ typedef typename bst_base::domain_type domain;
+ typedef typename bst_base::range_type range;
+
+ public:
+
+ binary_search_tree_kernel_c () {}
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ void add (
+ domain& d,
+ range& r
+ );
+
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ const map_pair<domain, range>& element(
+ ) const
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst map_pair<domain,range>& binary_search_tree::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return bst_base::element();
+ }
+
+ map_pair<domain, range>& element(
+ )
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tmap_pair<domain,range>& binary_search_tree::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return bst_base::element();
+ }
+
+ void remove_last_in_order (
+ domain& d,
+ range& r
+ );
+
+ void remove_current_element (
+ domain& d,
+ range& r
+ );
+
+
+ };
+
+
+ template <
+ typename bst_base
+ >
+ inline void swap (
+ binary_search_tree_kernel_c<bst_base>& a,
+ binary_search_tree_kernel_c<bst_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT( static_cast<const void*>(&d) != static_cast<void*>(&r),
+ "\tvoid binary_search_tree::add"
+ << "\n\tyou can't call add() and give the same object to both parameters."
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&r: " << &r
+ << "\n\tsize(): " << this->size()
+ );
+
+ bst_base::add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ destroy (
+ const domain& d
+ )
+ {
+ DLIB_CASSERT(this->operator[](d) != 0,
+ "\tvoid binary_search_tree::destroy"
+ << "\n\tthe element must be in the tree for it to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ );
+
+ bst_base::destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->operator[](d) != 0 &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&d_copy)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)) &&
+ (static_cast<const void*>(&r) != static_cast<void*>(&d_copy)),
+ "\tvoid binary_search_tree::remove"
+ << "\n\tthe element must be in the tree for it to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&d_copy: " << &d_copy
+ << "\n\t&r: " << &r
+ );
+
+ bst_base::remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ remove_any(
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->size() != 0 &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid binary_search_tree::remove_any"
+ << "\n\ttree must not be empty if something is going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&r: " << &r
+ );
+
+ bst_base::remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ remove_last_in_order (
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->size() > 0,
+ "\tvoid binary_search_tree::remove_last_in_order()"
+ << "\n\tyou can't remove an element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ bst_base::remove_last_in_order(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bst_base
+ >
+ void binary_search_tree_kernel_c<bst_base>::
+ remove_current_element (
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tvoid binary_search_tree::remove_current_element()"
+ << "\n\tyou can't remove the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ bst_base::remove_current_element(d,r);
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_C_
+
diff --git a/ml/dlib/dlib/bit_stream.h b/ml/dlib/dlib/bit_stream.h
new file mode 100644
index 000000000..8885f3515
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream.h
@@ -0,0 +1,42 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAm_
+#define DLIB_BIT_STREAm_
+
+#include "bit_stream/bit_stream_kernel_1.h"
+#include "bit_stream/bit_stream_kernel_c.h"
+
+#include "bit_stream/bit_stream_multi_1.h"
+#include "bit_stream/bit_stream_multi_c.h"
+
+namespace dlib
+{
+
+
+ class bit_stream
+ {
+ bit_stream() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef bit_stream_kernel_1
+ kernel_1a;
+ typedef bit_stream_kernel_c<kernel_1a >
+ kernel_1a_c;
+
+ //---------- extensions ------------
+
+
+ // multi_1 extend kernel_1a
+ typedef bit_stream_multi_1<kernel_1a>
+ multi_1a;
+ typedef bit_stream_multi_c<bit_stream_multi_1<kernel_1a_c> >
+ multi_1a_c;
+
+ };
+}
+
+#endif // DLIB_BIT_STREAm_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp
new file mode 100644
index 000000000..f49db14d5
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp
@@ -0,0 +1,200 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_
+#define DLIB_BIT_STREAM_KERNEL_1_CPp_
+
+
+#include "bit_stream_kernel_1.h"
+#include "../algs.h"
+
+#include <iostream>
+
+namespace dlib
+{
+
+ inline void swap (
+ bit_stream_kernel_1& a,
+ bit_stream_kernel_1& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ clear (
+ )
+ {
+ if (write_mode)
+ {
+ write_mode = false;
+
+ // flush output buffer
+ if (buffer_size > 0)
+ {
+ buffer <<= 8 - buffer_size;
+ osp->write(reinterpret_cast<char*>(&buffer),1);
+ }
+ }
+ else
+ read_mode = false;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ set_input_stream (
+ std::istream& is
+ )
+ {
+ isp = &is;
+ read_mode = true;
+
+ buffer_size = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ set_output_stream (
+ std::ostream& os
+ )
+ {
+ osp = &os;
+ write_mode = true;
+
+ buffer_size = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ close (
+ )
+ {
+ if (write_mode)
+ {
+ write_mode = false;
+
+ // flush output buffer
+ if (buffer_size > 0)
+ {
+ buffer <<= 8 - buffer_size;
+ osp->write(reinterpret_cast<char*>(&buffer),1);
+ }
+ }
+ else
+ read_mode = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bit_stream_kernel_1::
+ is_in_write_mode (
+ ) const
+ {
+ return write_mode;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bit_stream_kernel_1::
+ is_in_read_mode (
+ ) const
+ {
+ return read_mode;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ write (
+ int bit
+ )
+ {
+ // flush buffer if necessary
+ if (buffer_size == 8)
+ {
+ buffer <<= 8 - buffer_size;
+ if (osp->rdbuf()->sputn(reinterpret_cast<char*>(&buffer),1) == 0)
+ {
+ throw std::ios_base::failure("error occurred in the bit_stream object");
+ }
+
+ buffer_size = 0;
+ }
+
+ ++buffer_size;
+ buffer <<= 1;
+ buffer += static_cast<unsigned char>(bit);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bit_stream_kernel_1::
+ read (
+ int& bit
+ )
+ {
+ // get new byte if necessary
+ if (buffer_size == 0)
+ {
+ if (isp->rdbuf()->sgetn(reinterpret_cast<char*>(&buffer), 1) == 0)
+ {
+ // if we didn't read anything then return false
+ return false;
+ }
+
+ buffer_size = 8;
+ }
+
+ // put the most significant bit from buffer into bit
+ bit = static_cast<int>(buffer >> 7);
+
+ // shift out the bit that was just read
+ buffer <<= 1;
+ --buffer_size;
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bit_stream_kernel_1::
+ swap (
+ bit_stream_kernel_1& item
+ )
+ {
+
+ std::istream* isp_temp = item.isp;
+ std::ostream* osp_temp = item.osp;
+ bool write_mode_temp = item.write_mode;
+ bool read_mode_temp = item.read_mode;
+ unsigned char buffer_temp = item.buffer;
+ unsigned short buffer_size_temp = item.buffer_size;
+
+ item.isp = isp;
+ item.osp = osp;
+ item.write_mode = write_mode;
+ item.read_mode = read_mode;
+ item.buffer = buffer;
+ item.buffer_size = buffer_size;
+
+
+ isp = isp_temp;
+ osp = osp_temp;
+ write_mode = write_mode_temp;
+ read_mode = read_mode_temp;
+ buffer = buffer_temp;
+ buffer_size = buffer_size_temp;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_BIT_STREAM_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h
new file mode 100644
index 000000000..801e93e0a
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h
@@ -0,0 +1,120 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAM_KERNEl_1_
+#define DLIB_BIT_STREAM_KERNEl_1_
+
+#include "bit_stream_kernel_abstract.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class bit_stream_kernel_1
+ {
+
+ /*!
+ INITIAL VALUE
+ write_mode == false
+ read_mode == false
+
+ CONVENTION
+ write_mode == is_in_write_mode()
+ read_mode == is_in_read_mode()
+
+ if (write_mode)
+ {
+ osp == pointer to an ostream object
+ buffer == the low order bits of buffer are the bits to be
+ written
+ buffer_size == the number of low order bits in buffer that are
+ bits that should be written
+ the lowest order bit is the last bit entered by the user
+ }
+
+ if (read_mode)
+ {
+ isp == pointer to an istream object
+ buffer == the high order bits of buffer are the bits
+ waiting to be read by the user
+ buffer_size == the number of high order bits in buffer that
+ are bits that are waiting to be read
+ the highest order bit is the next bit to give to the user
+ }
+ !*/
+
+
+ public:
+
+ bit_stream_kernel_1 (
+ ) :
+ write_mode(false),
+ read_mode(false)
+ {}
+
+ virtual ~bit_stream_kernel_1 (
+ )
+ {}
+
+ void clear (
+ );
+
+ void set_input_stream (
+ std::istream& is
+ );
+
+ void set_output_stream (
+ std::ostream& os
+ );
+
+ void close (
+ );
+
+ inline bool is_in_write_mode (
+ ) const;
+
+ inline bool is_in_read_mode (
+ ) const;
+
+ inline void write (
+ int bit
+ );
+
+ bool read (
+ int& bit
+ );
+
+ void swap (
+ bit_stream_kernel_1& item
+ );
+
+ private:
+
+ // member data
+ std::istream* isp;
+ std::ostream* osp;
+ bool write_mode;
+ bool read_mode;
+ unsigned char buffer;
+ unsigned short buffer_size;
+
+ // restricted functions
+ bit_stream_kernel_1(bit_stream_kernel_1&); // copy constructor
+ bit_stream_kernel_1& operator=(bit_stream_kernel_1&); // assignment operator
+
+ };
+
+ inline void swap (
+ bit_stream_kernel_1& a,
+ bit_stream_kernel_1& b
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "bit_stream_kernel_1.cpp"
+#endif
+
+#endif // DLIB_BIT_STREAM_KERNEl_1_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h
new file mode 100644
index 000000000..00c2ae3b9
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h
@@ -0,0 +1,185 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BIT_STREAM_KERNEl_ABSTRACT_
+#ifdef DLIB_BIT_STREAM_KERNEl_ABSTRACT_
+
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class bit_stream
+ {
+
+ /*!
+ INITIAL VALUE
+ is_in_write_mode() == false
+ is_in_read_mode() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ this object is a middle man between a user and the iostream classes.
+ it allows single bits to be read/written easily to/from
+ the iostream classes
+
+ BUFFERING:
+ This object will only read/write single bytes at a time from/to the
+ iostream objects. Any buffered bits still in the bit_stream object
+ when it is closed or destructed are lost if it is in read mode. If
+ it is in write mode then any remaining bits are guaranteed to be
+ written to the output stream by the time it is closed or destructed.
+ !*/
+
+
+ public:
+
+ bit_stream (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~bit_stream (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+
+ void set_input_stream (
+ std::istream& is
+ );
+ /*!
+ requires
+ - is_in_write_mode() == false
+ - is_in_read_mode() == false
+ - is is ready to give input
+ ensures
+ - #is_in_write_mode() == false
+ - #is_in_read_mode() == true
+ - #*this will now be reading from is
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_output_stream (
+ std::ostream& os
+ );
+ /*!
+ requires
+ - is_in_write_mode() == false
+ - is_in_read_mode() == false
+ - os is ready to take output
+ ensures
+ - #is_in_write_mode() == true
+ - #is_in_read_mode() == false
+ - #*this will now write to os
+ throws
+ - std::bad_alloc
+ !*/
+
+
+
+ void close (
+ );
+ /*!
+ requires
+ - is_in_write_mode() == true || is_in_read_mode() == true
+ ensures
+ - #is_in_write_mode() == false
+ - #is_in_read_mode() == false
+ !*/
+
+ bool is_in_write_mode (
+ ) const;
+ /*!
+ ensures
+ - returns true if *this is associated with an output stream object
+ - returns false otherwise
+ !*/
+
+ bool is_in_read_mode (
+ ) const;
+ /*!
+ ensures
+ - returns true if *this is associated with an input stream object
+ - returns false otherwise
+ !*/
+
+ void write (
+ int bit
+ );
+ /*!
+ requires
+ - is_in_write_mode() == true
+ - bit == 0 || bit == 1
+ ensures
+ - bit will be written to the ostream object associated with *this
+ throws
+ - std::ios_base::failure
+ if (there was a problem writing to the output stream) then
+ this exception will be thrown. #*this will be unusable until
+ clear() is called and succeeds
+ - any other exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ bool read (
+ int& bit
+ );
+ /*!
+ requires
+ - is_in_read_mode() == true
+ ensures
+ - the next bit has been read and placed into #bit
+ - returns true if the read was successful, else false
+ (ex. false if EOF has been reached)
+ throws
+ - any exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void swap (
+ bit_stream& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ bit_stream(bit_stream&); // copy constructor
+ bit_stream& operator=(bit_stream&); // assignment operator
+
+ };
+
+ inline void swap (
+ bit_stream& a,
+ bit_stream& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_BIT_STREAM_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h
new file mode 100644
index 000000000..1d52bff20
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h
@@ -0,0 +1,172 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAM_KERNEl_C_
+#define DLIB_BIT_STREAM_KERNEl_C_
+
+#include "bit_stream_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+ template <
+ typename bit_stream_base // implements bit_stream/bit_stream_kernel_abstract.h
+ >
+ class bit_stream_kernel_c : public bit_stream_base
+ {
+ public:
+
+
+ void set_input_stream (
+ std::istream& is
+ );
+
+ void set_output_stream (
+ std::ostream& os
+ );
+
+ void close (
+ );
+
+ void write (
+ int bit
+ );
+
+ bool read (
+ int& bit
+ );
+
+ };
+
+ template <
+ typename bit_stream_base
+ >
+ inline void swap (
+ bit_stream_kernel_c<bit_stream_base>& a,
+ bit_stream_kernel_c<bit_stream_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_kernel_c<bit_stream_base>::
+ set_input_stream (
+ std::istream& is
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ),
+ "\tvoid bit_stream::set_intput_stream"
+ << "\n\tbit_stream must not be in write or read mode"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ bit_stream_base::set_input_stream(is);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_kernel_c<bit_stream_base>::
+ set_output_stream (
+ std::ostream& os
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ),
+ "\tvoid bit_stream::set_output_stream"
+ << "\n\tbit_stream must not be in write or read mode"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ bit_stream_base::set_output_stream(os);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_kernel_c<bit_stream_base>::
+ close (
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_write_mode() == true ) || ( this->is_in_read_mode() == true ),
+ "\tvoid bit_stream::close"
+ << "\n\tyou can't close a bit_stream that isn't open"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ bit_stream_base::close();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_kernel_c<bit_stream_base>::
+ write (
+ int bit
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_write_mode() == true ) && ( bit == 0 || bit == 1 ),
+ "\tvoid bit_stream::write"
+ << "\n\tthe bit stream bust be in write mode and bit must be either 1 or 0"
+ << "\n\tis_in_write_mode() == " << this->is_in_write_mode()
+ << "\n\tbit == " << bit
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ bit_stream_base::write(bit);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ bool bit_stream_kernel_c<bit_stream_base>::
+ read (
+ int& bit
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_read_mode() == true ),
+ "\tbool bit_stream::read"
+ << "\n\tyou can't read from a bit_stream that isn't in read mode"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return bit_stream_base::read(bit);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BIT_STREAM_KERNEl_C_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h
new file mode 100644
index 000000000..bf1cc0357
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h
@@ -0,0 +1,103 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAM_MULTi_1_
+#define DLIB_BIT_STREAM_MULTi_1_
+
+#include "bit_stream_multi_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename bit_stream_base
+ >
+ class bit_stream_multi_1 : public bit_stream_base
+ {
+
+ public:
+
+ void multi_write (
+ unsigned long data,
+ int num_to_write
+ );
+
+ int multi_read (
+ unsigned long& data,
+ int num_to_read
+ );
+
+ };
+
+ template <
+ typename bit_stream_base
+ >
+ inline void swap (
+ bit_stream_multi_1<bit_stream_base>& a,
+ bit_stream_multi_1<bit_stream_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_multi_1<bit_stream_base>::
+ multi_write (
+ unsigned long data,
+ int num_to_write
+ )
+ {
+ // move the first bit into the most significant position
+ data <<= 32 - num_to_write;
+
+ for (int i = 0; i < num_to_write; ++i)
+ {
+ // write the first bit from data
+ this->write(static_cast<char>(data >> 31));
+
+ // shift the next bit into position
+ data <<= 1;
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ int bit_stream_multi_1<bit_stream_base>::
+ multi_read (
+ unsigned long& data,
+ int num_to_read
+ )
+ {
+ int bit, i;
+ data = 0;
+ for (i = 0; i < num_to_read; ++i)
+ {
+
+ // get a bit
+ if (this->read(bit) == false)
+ break;
+
+ // shift data to make room for this new bit
+ data <<= 1;
+
+ // put bit into the least significant position in data
+ data += static_cast<unsigned long>(bit);
+
+ }
+
+ return i;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BIT_STREAM_MULTi_1_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h
new file mode 100644
index 000000000..061af94f4
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h
@@ -0,0 +1,77 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BIT_STREAM_MULTi_ABSTRACT_
+#ifdef DLIB_BIT_STREAM_MULTi_ABSTRACT_
+
+#include "bit_stream_kernel_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename bit_stream_base
+ >
+ class bit_stream_multi : public bit_stream_base
+ {
+
+ /*!
+ REQUIREMENTS ON BIT_STREAM_BASE
+ it is an implementation of bit_stream/bit_stream_kernel_abstract.h
+
+
+ WHAT THIS EXTENSION DOES FOR BIT_STREAM
+ this gives a bit_stream object the ability to read/write multible bits
+ at a time
+ !*/
+
+
+ public:
+
+ void multi_write (
+ unsigned long data,
+ int num_to_write
+ );
+ /*!
+ requires
+ - is_in_write_mode() == true
+ - 0 <= num_to_write <= 32
+ ensures
+ - num_to_write low order bits from data will be written to the ostream
+ - object associated with *this
+ example: if data is 10010 then the bits will be written in the
+ order 1,0,0,1,0
+ !*/
+
+
+ int multi_read (
+ unsigned long& data,
+ int num_to_read
+ );
+ /*!
+ requires
+ - is_in_read_mode() == true
+ - 0 <= num_to_read <= 32
+ ensures
+ - tries to read num_to_read bits into the low order end of #data
+ example: if the incoming bits were 10010 then data would end
+ up with 10010 as its low order bits
+ - all of the bits in #data not filled in by multi_read() are zero
+ - returns the number of bits actually read into #data
+ !*/
+
+ };
+
+ template <
+ typename bit_stream_base
+ >
+ inline void swap (
+ bit_stream_multi<bit_stream_base>& a,
+ bit_stream_multi<bit_stream_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_BIT_STREAM_MULTi_ABSTRACT_
+
diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h
new file mode 100644
index 000000000..de80c6328
--- /dev/null
+++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h
@@ -0,0 +1,101 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BIT_STREAM_MULTi_C_
+#define DLIB_BIT_STREAM_MULTi_C_
+
+#include "bit_stream_multi_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+ template <
+ typename bit_stream_base // implements bit_stream/bit_stream_multi_abstract.h
+ >
+ class bit_stream_multi_c : public bit_stream_base
+ {
+ public:
+
+ void multi_write (
+ unsigned long data,
+ int num_to_write
+ );
+
+ int multi_read (
+ unsigned long& data,
+ int num_to_read
+ );
+
+ };
+
+ template <
+ typename bit_stream_base
+ >
+ inline void swap (
+ bit_stream_multi_c<bit_stream_base>& a,
+ bit_stream_multi_c<bit_stream_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ void bit_stream_multi_c<bit_stream_base>::
+ multi_write (
+ unsigned long data,
+ int num_to_write
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->is_in_write_mode() == true) && (num_to_write >= 0 && num_to_write <=32),
+ "\tvoid bit_stream::write"
+ << "\n\tthe bit stream bust be in write mode and"
+ << "\n\tnum_to_write must be between 0 and 32 inclusive"
+ << "\n\tnum_to_write == " << num_to_write
+ << "\n\tis_in_write_mode() == " << this->is_in_write_mode()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ bit_stream_base::multi_write(data,num_to_write);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename bit_stream_base
+ >
+ int bit_stream_multi_c<bit_stream_base>::
+ multi_read (
+ unsigned long& data,
+ int num_to_read
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( this->is_in_read_mode() == true && ( num_to_read >= 0 && num_to_read <=32 ) ),
+ "\tvoid bit_stream::read"
+ << "\n\tyou can't read from a bit_stream that isn't in read mode and"
+ << "\n\tnum_to_read must be between 0 and 32 inclusive"
+ << "\n\tnum_to_read == " << num_to_read
+ << "\n\tis_in_read_mode() == " << this->is_in_read_mode()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return bit_stream_base::multi_read(data,num_to_read);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BIT_STREAM_MULTi_C_
+
diff --git a/ml/dlib/dlib/bits/c++config.h b/ml/dlib/dlib/bits/c++config.h
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/bits/c++config.h
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/bound_function_pointer.h b/ml/dlib/dlib/bound_function_pointer.h
new file mode 100644
index 000000000..a482919c6
--- /dev/null
+++ b/ml/dlib/dlib/bound_function_pointer.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BOUND_FUNCTION_POINTEr_
+#define DLIB_BOUND_FUNCTION_POINTEr_
+
+#include "bound_function_pointer/bound_function_pointer_kernel_1.h"
+
+#endif // DLIB_BOUND_FUNCTION_POINTEr_
+
+
diff --git a/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h
new file mode 100644
index 000000000..a39592742
--- /dev/null
+++ b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h
@@ -0,0 +1,774 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
+#define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
+
+#include "../algs.h"
+#include "../member_function_pointer.h"
+#include "bound_function_pointer_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace bfp1_helpers
+ {
+ template <typename T> struct strip { typedef T type; };
+ template <typename T> struct strip<T&> { typedef T type; };
+
+ // ------------------------------------------------------------------------------------
+
+ class bound_function_helper_base_base
+ {
+ public:
+ virtual ~bound_function_helper_base_base(){}
+ virtual void call() const = 0;
+ virtual bool is_set() const = 0;
+ virtual void clone(void* ptr) const = 0;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ class bound_function_helper_base : public bound_function_helper_base_base
+ {
+ public:
+ bound_function_helper_base():arg1(0), arg2(0), arg3(0), arg4(0) {}
+
+ typename strip<T1>::type* arg1;
+ typename strip<T2>::type* arg2;
+ typename strip<T3>::type* arg3;
+ typename strip<T4>::type* arg4;
+
+
+ member_function_pointer<T1,T2,T3,T4> mfp;
+ };
+
+ // ----------------
+
+ template <typename F, typename T1 = void, typename T2 = void, typename T3 = void, typename T4 = void>
+ class bound_function_helper : public bound_function_helper_base<T1,T2,T3,T4>
+ {
+ public:
+ void call() const
+ {
+ (*fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
+ }
+
+ typename strip<F>::type* fp;
+ };
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ class bound_function_helper<void,T1,T2,T3,T4> : public bound_function_helper_base<T1,T2,T3,T4>
+ {
+ public:
+ void call() const
+ {
+ if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
+ else if (fp) fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
+ }
+
+ void (*fp)(T1, T2, T3, T4);
+ };
+
+ // ----------------
+
+ template <typename F>
+ class bound_function_helper<F,void,void,void,void> : public bound_function_helper_base<void,void,void,void>
+ {
+ public:
+ void call() const
+ {
+ (*fp)();
+ }
+
+ typename strip<F>::type* fp;
+ };
+
+ template <>
+ class bound_function_helper<void,void,void,void,void> : public bound_function_helper_base<void,void,void,void>
+ {
+ public:
+ void call() const
+ {
+ if (this->mfp) this->mfp();
+ else if (fp) fp();
+ }
+
+ void (*fp)();
+ };
+
+ // ----------------
+
+ template <typename F, typename T1>
+ class bound_function_helper<F,T1,void,void,void> : public bound_function_helper_base<T1,void,void,void>
+ {
+ public:
+ void call() const
+ {
+ (*fp)(*this->arg1);
+ }
+
+ typename strip<F>::type* fp;
+ };
+
+ template <typename T1>
+ class bound_function_helper<void,T1,void,void,void> : public bound_function_helper_base<T1,void,void,void>
+ {
+ public:
+ void call() const
+ {
+ if (this->mfp) this->mfp(*this->arg1);
+ else if (fp) fp(*this->arg1);
+ }
+
+ void (*fp)(T1);
+ };
+
+ // ----------------
+
+ template <typename F, typename T1, typename T2>
+ class bound_function_helper<F,T1,T2,void,void> : public bound_function_helper_base<T1,T2,void,void>
+ {
+ public:
+ void call() const
+ {
+ (*fp)(*this->arg1, *this->arg2);
+ }
+
+ typename strip<F>::type* fp;
+ };
+
+ template <typename T1, typename T2>
+ class bound_function_helper<void,T1,T2,void,void> : public bound_function_helper_base<T1,T2,void,void>
+ {
+ public:
+ void call() const
+ {
+ if (this->mfp) this->mfp(*this->arg1, *this->arg2);
+ else if (fp) fp(*this->arg1, *this->arg2);
+ }
+
+ void (*fp)(T1, T2);
+ };
+
+ // ----------------
+
+ template <typename F, typename T1, typename T2, typename T3>
+ class bound_function_helper<F,T1,T2,T3,void> : public bound_function_helper_base<T1,T2,T3,void>
+ {
+ public:
+ void call() const
+ {
+ (*fp)(*this->arg1, *this->arg2, *this->arg3);
+ }
+
+ typename strip<F>::type* fp;
+ };
+
+ template <typename T1, typename T2, typename T3>
+ class bound_function_helper<void,T1,T2,T3,void> : public bound_function_helper_base<T1,T2,T3,void>
+ {
+ public:
+
+ void call() const
+ {
+ if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3);
+ else if (fp) fp(*this->arg1, *this->arg2, *this->arg3);
+ }
+
+ void (*fp)(T1, T2, T3);
+ };
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ class bound_function_helper_T : public T
+ {
+ public:
+ bound_function_helper_T(){ this->fp = 0;}
+
+ bool is_set() const
+ {
+ return this->fp != 0 || this->mfp.is_set();
+ }
+
+ template <unsigned long mem_size>
+ void safe_clone(stack_based_memory_block<mem_size>& buf)
+ {
+ // This is here just to validate the assumption that our block of memory we have made
+ // in bf_memory is the right size to store the data for this object. If you
+ // get a compiler error on this line then email me :)
+ COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size);
+ clone(buf.get());
+ }
+
+ void clone (void* ptr) const
+ {
+ bound_function_helper_T* p = new(ptr) bound_function_helper_T();
+ p->arg1 = this->arg1;
+ p->arg2 = this->arg2;
+ p->arg3 = this->arg3;
+ p->arg4 = this->arg4;
+ p->fp = this->fp;
+ p->mfp = this->mfp;
+ }
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class bound_function_pointer
+ {
+ typedef bfp1_helpers::bound_function_helper_T<bfp1_helpers::bound_function_helper<void,int> > bf_null_type;
+
+ public:
+
+ // These typedefs are here for backwards compatibility with previous versions of
+ // dlib.
+ typedef bound_function_pointer kernel_1a;
+ typedef bound_function_pointer kernel_1a_c;
+
+
+ bound_function_pointer (
+ ) { bf_null_type().safe_clone(bf_memory); }
+
+ bound_function_pointer (
+ const bound_function_pointer& item
+ ) { item.bf()->clone(bf_memory.get()); }
+
+ ~bound_function_pointer()
+ { destroy_bf_memory(); }
+
+ bound_function_pointer& operator= (
+ const bound_function_pointer& item
+ ) { bound_function_pointer(item).swap(*this); return *this; }
+
+ void clear (
+ ) { bound_function_pointer().swap(*this); }
+
+ bool is_set (
+ ) const
+ {
+ return bf()->is_set();
+ }
+
+ void swap (
+ bound_function_pointer& item
+ )
+ {
+ // make a temp copy of item
+ bound_function_pointer temp(item);
+
+ // destory the stuff in item
+ item.destroy_bf_memory();
+ // copy *this into item
+ bf()->clone(item.bf_memory.get());
+
+ // destory the stuff in this
+ destroy_bf_memory();
+ // copy temp into *this
+ temp.bf()->clone(bf_memory.get());
+ }
+
+ void operator() (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_set() == true ,
+ "\tvoid bound_function_pointer::operator()"
+ << "\n\tYou must call set() before you can use this function"
+ << "\n\tthis: " << this
+ );
+
+ bf()->call();
+ }
+
+ private:
+ struct dummy{ void nonnull() {}};
+ typedef void (dummy::*safe_bool)();
+
+ public:
+ operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; }
+ bool operator!() const { return !is_set(); }
+
+ // -------------------------------------------
+ // set function object overloads
+ // -------------------------------------------
+
+ template <typename F>
+ void set (
+ F& function_object
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<F> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.fp = &function_object;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename F, typename A1 >
+ void set (
+ F& function_object,
+ A1& arg1
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<F,A1> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.fp = &function_object;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename F, typename A1, typename A2 >
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<F,A1,A2> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.fp = &function_object;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename F, typename A1, typename A2, typename A3 >
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<F,A1,A2,A3> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.fp = &function_object;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<F,A1,A2,A3,A4> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.arg4 = &arg4;
+ temp.fp = &function_object;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // -------------------------------------------
+ // set mfp overloads
+ // -------------------------------------------
+
+ template <typename T>
+ void set (
+ T& object,
+ void (T::*funct)()
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T >
+ void set (
+ const T& object,
+ void (T::*funct)()const
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // -------------------------------------------
+
+ template <typename T, typename T1, typename A1 >
+ void set (
+ T& object,
+ void (T::*funct)(T1),
+ A1& arg1
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T, typename T1, typename A1 >
+ void set (
+ const T& object,
+ void (T::*funct)(T1)const,
+ A1& arg1
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // ----------------
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ T& object,
+ void (T::*funct)(T1, T2),
+ A1& arg1,
+ A2& arg2
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ const T& object,
+ void (T::*funct)(T1, T2)const,
+ A1& arg1,
+ A2& arg2
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // ----------------
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ T& object,
+ void (T::*funct)(T1, T2, T3),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ const T& object,
+ void (T::*funct)(T1, T2, T3)const,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // ----------------
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ T& object,
+ void (T::*funct)(T1, T2, T3, T4),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.arg4 = &arg4;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ const T& object,
+ void (T::*funct)(T1, T2, T3, T4)const,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.arg4 = &arg4;
+ temp.mfp.set(object,funct);
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // -------------------------------------------
+ // set fp overloads
+ // -------------------------------------------
+
+ void set (
+ void (*funct)()
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.fp = funct;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T1, typename A1>
+ void set (
+ void (*funct)(T1),
+ A1& arg1
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.fp = funct;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ void (*funct)(T1, T2),
+ A1& arg1,
+ A2& arg2
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.fp = funct;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ void (*funct)(T1, T2, T3),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.fp = funct;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ void (*funct)(T1, T2, T3, T4),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ )
+ {
+ using namespace bfp1_helpers;
+ destroy_bf_memory();
+ typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;
+
+ bf_helper_type temp;
+ temp.arg1 = &arg1;
+ temp.arg2 = &arg2;
+ temp.arg3 = &arg3;
+ temp.arg4 = &arg4;
+ temp.fp = funct;
+
+ temp.safe_clone(bf_memory);
+ }
+
+ // -------------------------------------------
+
+ private:
+
+ stack_based_memory_block<sizeof(bf_null_type)> bf_memory;
+
+ void destroy_bf_memory (
+ )
+ {
+ // Honestly, this probably doesn't even do anything but I'm putting
+ // it here just for good measure.
+ bf()->~bound_function_helper_base_base();
+ }
+
+ bfp1_helpers::bound_function_helper_base_base* bf ()
+ { return static_cast<bfp1_helpers::bound_function_helper_base_base*>(bf_memory.get()); }
+
+ const bfp1_helpers::bound_function_helper_base_base* bf () const
+ { return static_cast<const bfp1_helpers::bound_function_helper_base_base*>(bf_memory.get()); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ bound_function_pointer& a,
+ bound_function_pointer& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h
new file mode 100644
index 000000000..b5356d6e0
--- /dev/null
+++ b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h
@@ -0,0 +1,456 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
+#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bound_function_pointer
+ {
+ /*!
+ INITIAL VALUE
+ is_set() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a function with all its arguments bound to
+ specific objects. For example:
+
+ void test(int& var) { var = var+1; }
+
+ bound_function_pointer funct;
+
+ int a = 4;
+ funct.set(test,a); // bind the variable a to the first argument of the test() function
+
+ // at this point a == 4
+ funct();
+ // after funct() is called a == 5
+ !*/
+
+ public:
+
+ bound_function_pointer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ !*/
+
+ bound_function_pointer(
+ const bound_function_pointer& item
+ );
+ /*!
+ ensures
+ - *this == item
+ !*/
+
+ ~bound_function_pointer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ bound_function_pointer& operator=(
+ const bound_function_pointer& item
+ );
+ /*!
+ ensures
+ - *this == item
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ !*/
+
+ bool is_set (
+ ) const;
+ /*!
+ ensures
+ - if (this->set() has been called) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ operator some_undefined_pointer_type (
+ ) const;
+ /*!
+ ensures
+ - if (is_set()) then
+ - returns a non 0 value
+ - else
+ - returns a 0 value
+ !*/
+
+ bool operator! (
+ ) const;
+ /*!
+ ensures
+ - returns !is_set()
+ !*/
+
+ void operator () (
+ ) const;
+ /*!
+ requires
+ - is_set() == true
+ ensures
+ - calls the bound function on the object(s) specified by the last
+ call to this->set()
+ throws
+ - any exception thrown by the function specified by
+ the previous call to this->set().
+ If any of these exceptions are thrown then the call to this
+ function will have no effect on *this.
+ !*/
+
+ void swap (
+ bound_function_pointer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ // ----------------------
+
+ template <typename F>
+ void set (
+ F& function_object
+ );
+ /*!
+ requires
+ - function_object() is a valid expression
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call function_object()
+ (This seems pointless but it is a useful base case)
+ !*/
+
+ template < typename T>
+ void set (
+ T& object,
+ void (T::*funct)()
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)()
+ !*/
+
+ template < typename T>
+ void set (
+ const T& object,
+ void (T::*funct)()const
+ );
+ /*!
+ requires
+ - funct == a valid bound function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)()
+ !*/
+
+ void set (
+ void (*funct)()
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call funct()
+ !*/
+
+ // ----------------------
+
+ template <typename F, typename A1 >
+ void set (
+ F& function_object,
+ A1& arg1
+ );
+ /*!
+ requires
+ - function_object(arg1) is a valid expression
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call function_object(arg1)
+ !*/
+
+ template < typename T, typename T1, typename A1 >
+ void set (
+ T& object,
+ void (T::*funct)(T1),
+ A1& arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1)
+ !*/
+
+ template < typename T, typename T1, typename A1 >
+ void set (
+ const T& object,
+ void (T::*funct)(T1)const,
+ A1& arg1
+ );
+ /*!
+ requires
+ - funct == a valid bound function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1)
+ !*/
+
+ template <typename T1, typename A1>
+ void set (
+ void (*funct)(T1),
+ A1& arg1
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call funct(arg1)
+ !*/
+
+ // ----------------------
+ template <typename F, typename A1, typename A2 >
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2
+ );
+ /*!
+ requires
+ - function_object(arg1,arg2) is a valid expression
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call function_object(arg1,arg2)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ T& object,
+ void (T::*funct)(T1,T2),
+ A1& arg1,
+ A2& arg2
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ const T& object,
+ void (T::*funct)(T1,T2)const,
+ A1& arg1,
+ A2& arg2
+ );
+ /*!
+ requires
+ - funct == a valid bound function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2)
+ !*/
+
+ template <typename T1, typename A1,
+ typename T2, typename A2>
+ void set (
+ void (*funct)(T1,T2),
+ A1& arg1,
+ A2& arg2
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call funct(arg1,arg2)
+ !*/
+
+ // ----------------------
+
+ template <typename F, typename A1, typename A2, typename A3 >
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ );
+ /*!
+ requires
+ - function_object(arg1,arg2,arg3) is a valid expression
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call function_object(arg1,arg2,arg3)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ T& object,
+ void (T::*funct)(T1,T2,T3),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ const T& object,
+ void (T::*funct)(T1,T2,T3)const,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ );
+ /*!
+ requires
+ - funct == a valid bound function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
+ !*/
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ void set (
+ void (*funct)(T1,T2,T3),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call funct(arg1,arg2,arg3)
+ !*/
+
+ // ----------------------
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ void set (
+ F& function_object,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ );
+ /*!
+ requires
+ - function_object(arg1,arg2,arg3,arg4) is a valid expression
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call function_object(arg1,arg2,arg3,arg4)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ T& object,
+ void (T::*funct)(T1,T2,T3,T4),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
+ !*/
+
+ template < typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ const T& object,
+ void (T::*funct)(T1,T2,T3,T4)const,
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ );
+ /*!
+ requires
+ - funct == a valid bound function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
+ !*/
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ void set (
+ void (*funct)(T1,T2,T3,T4),
+ A1& arg1,
+ A2& arg2,
+ A3& arg3,
+ A4& arg4
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call funct(arg1,arg2,arg3,arg4)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ bound_function_pointer& a,
+ bound_function_pointer& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/bridge.h b/ml/dlib/dlib/bridge.h
new file mode 100644
index 000000000..4b633c405
--- /dev/null
+++ b/ml/dlib/dlib/bridge.h
@@ -0,0 +1,17 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+
+#ifndef DLIB_BRIdGE_
+#define DLIB_BRIdGE_
+
+
+#include "bridge/bridge.h"
+
+#endif // DLIB_BRIdGE_
+
+
diff --git a/ml/dlib/dlib/bridge/bridge.h b/ml/dlib/dlib/bridge/bridge.h
new file mode 100644
index 000000000..da4e0bd7e
--- /dev/null
+++ b/ml/dlib/dlib/bridge/bridge.h
@@ -0,0 +1,669 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BRIDGe_Hh_
+#define DLIB_BRIDGe_Hh_
+
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "bridge_abstract.h"
+#include "../pipe.h"
+#include "../threads.h"
+#include "../serialize.h"
+#include "../sockets.h"
+#include "../sockstreambuf.h"
+#include "../logger.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct connect_to_ip_and_port
+ {
+ connect_to_ip_and_port (
+ const std::string& ip_,
+ unsigned short port_
+ ): ip(ip_), port(port_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ip_address(ip) && port != 0,
+ "\t connect_to_ip_and_port()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t ip: " << ip
+ << "\n\t port: " << port
+ << "\n\t this: " << this
+ );
+ }
+
+ private:
+ friend class bridge;
+ const std::string ip;
+ const unsigned short port;
+ };
+
+ inline connect_to_ip_and_port connect_to (
+ const network_address& addr
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(addr.port != 0,
+ "\t connect_to_ip_and_port()"
+ << "\n\t The TCP port to connect to can't be 0."
+ << "\n\t addr.port: " << addr.port
+ );
+
+ if (is_ip_address(addr.host_address))
+ {
+ return connect_to_ip_and_port(addr.host_address, addr.port);
+ }
+ else
+ {
+ std::string ip;
+ if(hostname_to_ip(addr.host_address,ip))
+ throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()");
+
+ return connect_to_ip_and_port(ip, addr.port);
+ }
+ }
+
+ struct listen_on_port
+ {
+ listen_on_port(
+ unsigned short port_
+ ) : port(port_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( port != 0,
+ "\t listen_on_port()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t port: " << port
+ << "\n\t this: " << this
+ );
+ }
+
+ private:
+ friend class bridge;
+ const unsigned short port;
+ };
+
+ template <typename pipe_type>
+ struct bridge_transmit_decoration
+ {
+ bridge_transmit_decoration (
+ pipe_type& p_
+ ) : p(p_) {}
+
+ private:
+ friend class bridge;
+ pipe_type& p;
+ };
+
+ template <typename pipe_type>
+ bridge_transmit_decoration<pipe_type> transmit ( pipe_type& p) { return bridge_transmit_decoration<pipe_type>(p); }
+
+ template <typename pipe_type>
+ struct bridge_receive_decoration
+ {
+ bridge_receive_decoration (
+ pipe_type& p_
+ ) : p(p_) {}
+
+ private:
+ friend class bridge;
+ pipe_type& p;
+ };
+
+ template <typename pipe_type>
+ bridge_receive_decoration<pipe_type> receive ( pipe_type& p) { return bridge_receive_decoration<pipe_type>(p); }
+
+// ----------------------------------------------------------------------------------------
+
+ struct bridge_status
+ {
+ bridge_status() : is_connected(false), foreign_port(0){}
+
+ bool is_connected;
+ unsigned short foreign_port;
+ std::string foreign_ip;
+ };
+
+ inline void serialize ( const bridge_status& , std::ostream& )
+ {
+ throw serialization_error("It is illegal to serialize bridge_status objects.");
+ }
+
+ inline void deserialize ( bridge_status& , std::istream& )
+ {
+ throw serialization_error("It is illegal to serialize bridge_status objects.");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_brns
+ {
+ class impl_bridge_base
+ {
+ public:
+
+ virtual ~impl_bridge_base() {}
+
+ virtual bridge_status get_bridge_status (
+ ) const = 0;
+ };
+
+ template <
+ typename transmit_pipe_type,
+ typename receive_pipe_type
+ >
+ class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object
+ {
+ /*!
+ CONVENTION
+ - if (list) then
+ - this object is supposed to be listening on the list object for incoming
+ connections when not connected.
+ - else
+ - this object is supposed to be attempting to connect to ip:port when
+ not connected.
+
+ - get_bridge_status() == current_bs
+ !*/
+ public:
+
+ impl_bridge (
+ unsigned short listen_port,
+ transmit_pipe_type* transmit_pipe_,
+ receive_pipe_type* receive_pipe_
+ ) :
+ s(m),
+ receive_thread_active(false),
+ transmit_thread_active(false),
+ port(0),
+ transmit_pipe(transmit_pipe_),
+ receive_pipe(receive_pipe_),
+ dlog("dlib.bridge"),
+ keepalive_code(0),
+ message_code(1)
+ {
+ int status = create_listener(list, listen_port);
+ if (status == PORTINUSE)
+ {
+ std::ostringstream sout;
+ sout << "Error, the port " << listen_port << " is already in use.";
+ throw socket_error(EPORT_IN_USE, sout.str());
+ }
+ else if (status == OTHER_ERROR)
+ {
+ throw socket_error("Unable to create listening socket for an unknown reason.");
+ }
+
+ register_thread(*this, &impl_bridge::transmit_thread);
+ register_thread(*this, &impl_bridge::receive_thread);
+ register_thread(*this, &impl_bridge::connect_thread);
+
+ start();
+ }
+
+ impl_bridge (
+ const std::string ip_,
+ unsigned short port_,
+ transmit_pipe_type* transmit_pipe_,
+ receive_pipe_type* receive_pipe_
+ ) :
+ s(m),
+ receive_thread_active(false),
+ transmit_thread_active(false),
+ port(port_),
+ ip(ip_),
+ transmit_pipe(transmit_pipe_),
+ receive_pipe(receive_pipe_),
+ dlog("dlib.bridge"),
+ keepalive_code(0),
+ message_code(1)
+ {
+ register_thread(*this, &impl_bridge::transmit_thread);
+ register_thread(*this, &impl_bridge::receive_thread);
+ register_thread(*this, &impl_bridge::connect_thread);
+
+ start();
+ }
+
+ ~impl_bridge()
+ {
+ // tell the threads to terminate
+ stop();
+
+ // save current pipe enabled status so we can restore it to however
+ // it was before this destructor ran.
+ bool transmit_enabled = true;
+ bool receive_enabled = true;
+
+ // make any calls blocked on a pipe return immediately.
+ if (transmit_pipe)
+ {
+ transmit_enabled = transmit_pipe->is_dequeue_enabled();
+ transmit_pipe->disable_dequeue();
+ }
+ if (receive_pipe)
+ {
+ receive_enabled = receive_pipe->is_enqueue_enabled();
+ receive_pipe->disable_enqueue();
+ }
+
+ {
+ auto_mutex lock(m);
+ s.broadcast();
+ // Shutdown the connection if we have one. This will cause
+ // all blocked I/O calls to return an error.
+ if (con)
+ con->shutdown();
+ }
+
+ // wait for all the threads to terminate.
+ wait();
+
+ if (transmit_pipe && transmit_enabled)
+ transmit_pipe->enable_dequeue();
+ if (receive_pipe && receive_enabled)
+ receive_pipe->enable_enqueue();
+ }
+
+ bridge_status get_bridge_status (
+ ) const
+ {
+ auto_mutex lock(current_bs_mutex);
+ return current_bs;
+ }
+
+ private:
+
+
+ template <typename pipe_type>
+ typename enable_if<is_convertible<bridge_status, typename pipe_type::type> >::type enqueue_bridge_status (
+ pipe_type* p,
+ const bridge_status& status
+ )
+ {
+ if (p)
+ {
+ typename pipe_type::type temp(status);
+ p->enqueue(temp);
+ }
+ }
+
+ template <typename pipe_type>
+ typename disable_if<is_convertible<bridge_status, typename pipe_type::type> >::type enqueue_bridge_status (
+ pipe_type* ,
+ const bridge_status&
+ )
+ {
+ }
+
+ void connect_thread (
+ )
+ {
+ while (!should_stop())
+ {
+ auto_mutex lock(m);
+ int status = OTHER_ERROR;
+ if (list)
+ {
+ do
+ {
+ status = list->accept(con, 1000);
+ } while (status == TIMEOUT && !should_stop());
+ }
+ else
+ {
+ status = create_connection(con, port, ip);
+ }
+
+ if (should_stop())
+ break;
+
+ if (status != 0)
+ {
+ // The last connection attempt failed. So pause for a little bit before making another attempt.
+ s.wait_or_timeout(2000);
+ continue;
+ }
+
+ dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".";
+
+ bridge_status temp_bs;
+ { auto_mutex lock(current_bs_mutex);
+ current_bs.is_connected = true;
+ current_bs.foreign_port = con->get_foreign_port();
+ current_bs.foreign_ip = con->get_foreign_ip();
+ temp_bs = current_bs;
+ }
+ enqueue_bridge_status(receive_pipe, temp_bs);
+
+
+ receive_thread_active = true;
+ transmit_thread_active = true;
+
+ s.broadcast();
+
+ // Wait for the transmit and receive threads to end before we continue.
+ // This way we don't invalidate the con pointer while it is in use.
+ while (receive_thread_active || transmit_thread_active)
+ s.wait();
+
+
+ dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".";
+ { auto_mutex lock(current_bs_mutex);
+ current_bs.is_connected = false;
+ current_bs.foreign_port = con->get_foreign_port();
+ current_bs.foreign_ip = con->get_foreign_ip();
+ temp_bs = current_bs;
+ }
+ enqueue_bridge_status(receive_pipe, temp_bs);
+ }
+
+ }
+
+
+ void receive_thread (
+ )
+ {
+ while (true)
+ {
+ // wait until we have a connection
+ { auto_mutex lock(m);
+ while (!receive_thread_active && !should_stop())
+ {
+ s.wait();
+ }
+
+ if (should_stop())
+ break;
+ }
+
+
+
+ try
+ {
+ if (receive_pipe)
+ {
+ sockstreambuf buf(con);
+ std::istream in(&buf);
+ typename receive_pipe_type::type item;
+ // This isn't necessary but doing it avoids a warning about
+ // item being uninitialized sometimes.
+ assign_zero_if_built_in_scalar_type(item);
+
+ while (in.peek() != EOF)
+ {
+ unsigned char code;
+ in.read((char*)&code, sizeof(code));
+ if (code == message_code)
+ {
+ deserialize(item, in);
+ receive_pipe->enqueue(item);
+ }
+ }
+ }
+ else
+ {
+ // Since we don't have a receive pipe to put messages into we will
+ // just read the bytes from the connection and ignore them.
+ char buf[1000];
+ while (con->read(buf, sizeof(buf)) > 0) ;
+ }
+ }
+ catch (std::bad_alloc& )
+ {
+ dlog << LERROR << "std::bad_alloc thrown while deserializing message from "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port();
+ }
+ catch (dlib::serialization_error& e)
+ {
+ dlog << LERROR << "dlib::serialization_error thrown while deserializing message from "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port()
+ << ".\nThe exception error message is: \n" << e.what();
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "std::exception thrown while deserializing message from "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port()
+ << ".\nThe exception error message is: \n" << e.what();
+ }
+
+
+
+
+ con->shutdown();
+ auto_mutex lock(m);
+ receive_thread_active = false;
+ s.broadcast();
+ }
+
+ auto_mutex lock(m);
+ receive_thread_active = false;
+ s.broadcast();
+ }
+
+ void transmit_thread (
+ )
+ {
+ while (true)
+ {
+ // wait until we have a connection
+ { auto_mutex lock(m);
+ while (!transmit_thread_active && !should_stop())
+ {
+ s.wait();
+ }
+
+ if (should_stop())
+ break;
+ }
+
+
+
+ try
+ {
+ sockstreambuf buf(con);
+ std::ostream out(&buf);
+ typename transmit_pipe_type::type item;
+ // This isn't necessary but doing it avoids a warning about
+ // item being uninitialized sometimes.
+ assign_zero_if_built_in_scalar_type(item);
+
+
+ while (out)
+ {
+ bool dequeue_timed_out = false;
+ if (transmit_pipe )
+ {
+ if (transmit_pipe->dequeue_or_timeout(item,1000))
+ {
+ out.write((char*)&message_code, sizeof(message_code));
+ serialize(item, out);
+ if (transmit_pipe->size() == 0)
+ out.flush();
+
+ continue;
+ }
+
+ dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled());
+ }
+
+ // Pause for about a second. Note that we use a wait_or_timeout() call rather
+ // than sleep() here because we want to wake up immediately if this object is
+ // being destructed rather than hang for a second.
+ if (!dequeue_timed_out)
+ {
+ auto_mutex lock(m);
+ if (should_stop())
+ break;
+
+ s.wait_or_timeout(1000);
+ }
+ // Just send the keepalive byte periodically so we can
+ // tell if the connection is alive.
+ out.write((char*)&keepalive_code, sizeof(keepalive_code));
+ out.flush();
+ }
+ }
+ catch (std::bad_alloc& )
+ {
+ dlog << LERROR << "std::bad_alloc thrown while serializing message to "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port();
+ }
+ catch (dlib::serialization_error& e)
+ {
+ dlog << LERROR << "dlib::serialization_error thrown while serializing message to "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port()
+ << ".\nThe exception error message is: \n" << e.what();
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "std::exception thrown while serializing message to "
+ << con->get_foreign_ip() << ":" << con->get_foreign_port()
+ << ".\nThe exception error message is: \n" << e.what();
+ }
+
+
+
+
+ con->shutdown();
+ auto_mutex lock(m);
+ transmit_thread_active = false;
+ s.broadcast();
+ }
+
+ auto_mutex lock(m);
+ transmit_thread_active = false;
+ s.broadcast();
+ }
+
+ mutex m;
+ signaler s;
+ bool receive_thread_active;
+ bool transmit_thread_active;
+ std::unique_ptr<connection> con;
+ std::unique_ptr<listener> list;
+ const unsigned short port;
+ const std::string ip;
+ transmit_pipe_type* const transmit_pipe;
+ receive_pipe_type* const receive_pipe;
+ logger dlog;
+ const unsigned char keepalive_code;
+ const unsigned char message_code;
+
+ mutex current_bs_mutex;
+ bridge_status current_bs;
+ };
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ class bridge : noncopyable
+ {
+ public:
+
+ bridge () {}
+
+ template < typename T, typename U, typename V >
+ bridge (
+ T network_parameters,
+ U pipe1,
+ V pipe2
+ ) { reconfigure(network_parameters,pipe1,pipe2); }
+
+ template < typename T, typename U>
+ bridge (
+ T network_parameters,
+ U pipe
+ ) { reconfigure(network_parameters,pipe); }
+
+
+ void clear (
+ )
+ {
+ pimpl.reset();
+ }
+
+ template < typename T, typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe,
+ bridge_receive_decoration<R> receive_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
+
+ template < typename T, typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe,
+ bridge_transmit_decoration<T> transmit_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
+
+ template < typename T >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.port, &transmit_pipe.p, 0)); }
+
+ template < typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.port, 0, &receive_pipe.p)); }
+
+
+
+
+ template < typename T, typename R >
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe,
+ bridge_receive_decoration<R> receive_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
+
+ template < typename T, typename R >
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe,
+ bridge_transmit_decoration<T> transmit_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); }
+
+ template < typename R >
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); }
+
+ template < typename T >
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe
+ ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); }
+
+
+ bridge_status get_bridge_status (
+ ) const
+ {
+ if (pimpl)
+ return pimpl->get_bridge_status();
+ else
+ return bridge_status();
+ }
+
+ private:
+
+ std::unique_ptr<impl_brns::impl_bridge_base> pimpl;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BRIDGe_Hh_
+
diff --git a/ml/dlib/dlib/bridge/bridge_abstract.h b/ml/dlib/dlib/bridge/bridge_abstract.h
new file mode 100644
index 000000000..76ed21153
--- /dev/null
+++ b/ml/dlib/dlib/bridge/bridge_abstract.h
@@ -0,0 +1,347 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BRIDGe_ABSTRACT_
+#ifdef DLIB_BRIDGe_ABSTRACT_
+
+#include <string>
+#include "../pipe/pipe_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct connect_to_ip_and_port
+ {
+ connect_to_ip_and_port (
+ const std::string& ip,
+ unsigned short port
+ );
+ /*!
+ requires
+ - is_ip_address(ip) == true
+ - port != 0
+ ensures
+ - this object will represent a request to make a TCP connection
+ to the given IP address and port number.
+ !*/
+ };
+
+ connect_to_ip_and_port connect_to (
+ const network_address& addr
+ );
+ /*!
+ requires
+ - addr.port != 0
+ ensures
+ - converts the given network_address object into a connect_to_ip_and_port
+ object.
+ !*/
+
+ struct listen_on_port
+ {
+ listen_on_port(
+ unsigned short port
+ );
+ /*!
+ requires
+ - port != 0
+ ensures
+ - this object will represent a request to listen on the given
+ port number for incoming TCP connections.
+ !*/
+ };
+
+ template <
+ typename pipe_type
+ >
+ bridge_transmit_decoration<pipe_type> transmit (
+ pipe_type& p
+ );
+ /*!
+ requires
+ - pipe_type is some kind of dlib::pipe object
+ - the objects in the pipe must be serializable
+ ensures
+ - Adds a type decoration to the given pipe, marking it as a transmit pipe, and
+ then returns it.
+ !*/
+
+ template <
+ typename pipe_type
+ >
+ bridge_receive_decoration<pipe_type> receive (
+ pipe_type& p
+ );
+ /*!
+ requires
+ - pipe_type is some kind of dlib::pipe object
+ - the objects in the pipe must be serializable
+ ensures
+ - Adds a type decoration to the given pipe, marking it as a receive pipe, and
+ then returns it.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct bridge_status
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This simple struct represents the state of a bridge object. A
+ bridge is either connected or not. If it is connected then it
+ is connected to a foreign host with an IP address and port number
+ as indicated by this object.
+ !*/
+
+ bridge_status(
+ );
+ /*!
+ ensures
+ - #is_connected == false
+ - #foreign_port == 0
+ - #foreign_ip == ""
+ !*/
+
+ bool is_connected;
+ unsigned short foreign_port;
+ std::string foreign_ip;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bridge : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for bridging a dlib::pipe object between
+ two network connected applications.
+
+
+ Note also that this object contains a dlib::logger object
+ which will log various events taking place inside a bridge.
+ If you want to see these log messages then enable the logger
+ named "dlib.bridge".
+
+
+ BRIDGE PROTOCOL DETAILS
+ The bridge object creates a single TCP connection between
+ two applications. Whenever it sends an object from a pipe
+ over a TCP connection it sends a byte with the value 1 followed
+ immediately by the serialized copy of the object from the pipe.
+ The serialization is performed by calling the global serialize()
+ function.
+
+ Additionally, a bridge object will periodically send bytes with
+ a value of 0 to ensure the TCP connection remains alive. These
+ are just read and ignored.
+ !*/
+
+ public:
+
+ bridge (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_bridge_status().is_connected == false
+ !*/
+
+ template <typename T, typename U, typename V>
+ bridge (
+ T network_parameters,
+ U pipe1,
+ V pipe2
+ );
+ /*!
+ requires
+ - T is of type connect_to_ip_and_port or listen_on_port
+ - U and V are of type bridge_transmit_decoration or bridge_receive_decoration,
+ however, U and V must be of different types (i.e. one is a receive type and
+ another a transmit type).
+ ensures
+ - this object is properly initialized
+ - performs: reconfigure(network_parameters, pipe1, pipe2)
+ (i.e. using this constructor is identical to using the default constructor
+ and then calling reconfigure())
+ !*/
+
+ template <typename T, typename U>
+ bridge (
+ T network_parameters,
+ U pipe
+ );
+ /*!
+ requires
+ - T is of type connect_to_ip_and_port or listen_on_port
+ - U is of type bridge_transmit_decoration or bridge_receive_decoration.
+ ensures
+ - this object is properly initialized
+ - performs: reconfigure(network_parameters, pipe)
+ (i.e. using this constructor is identical to using the default constructor
+ and then calling reconfigure())
+ !*/
+
+ ~bridge (
+ );
+ /*!
+ ensures
+ - blocks until all resources associated with this object have been destroyed.
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - returns this object to its default constructed state. That is, it will
+ be inactive, neither maintaining a connection nor attempting to acquire one.
+ - Any active connections or listening sockets will be closed.
+ !*/
+
+ bridge_status get_bridge_status (
+ ) const;
+ /*!
+ ensures
+ - returns the current status of this bridge object. In particular, returns
+ an object BS such that:
+ - BS.is_connected == true if and only if the bridge has an active TCP
+ connection to another computer.
+ - if (BS.is_connected) then
+ - BS.foreign_ip == the IP address of the remote host we are connected to.
+ - BS.foreign_port == the port number on the remote host we are connected to.
+ - else if (the bridge has previously been connected to a remote host but hasn't been
+ reconfigured or cleared since) then
+ - BS.foreign_ip == the IP address of the remote host we were connected to.
+ - BS.foreign_port == the port number on the remote host we were connected to.
+ - else
+ - BS.foreign_ip == ""
+ - BS.foreign_port == 0
+ !*/
+
+
+
+ template < typename T, typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe,
+ bridge_receive_decoration<R> receive_pipe
+ );
+ /*!
+ ensures
+ - This object will begin listening on the port specified by network_parameters
+ for incoming TCP connections. Any previous bridge state is cleared out.
+ - Onces a connection is established we will:
+ - Stop accepting new connections.
+ - Begin dequeuing objects from the transmit pipe and serializing them over
+ the TCP connection.
+ - Begin deserializing objects from the TCP connection and enqueueing them
+ onto the receive pipe.
+ - if (the current TCP connection is lost) then
+ - This object goes back to listening for a new connection.
+ - if (the receive pipe can contain bridge_status objects) then
+ - Whenever the bridge's status changes the updated bridge_status will be
+ enqueued onto the receive pipe unless the change was a TCP disconnect
+ resulting from a user calling reconfigure(), clear(), or destructing this
+ bridge. The status contents are defined by get_bridge_status().
+ throws
+ - socket_error
+ This exception is thrown if we are unable to open the listening socket.
+ !*/
+ template < typename T, typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe,
+ bridge_transmit_decoration<T> transmit_pipe
+ );
+ /*!
+ ensures
+ - performs reconfigure(network_parameters, transmit_pipe, receive_pipe)
+ !*/
+ template < typename T >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe
+ );
+ /*!
+ ensures
+ - This function is identical to the above two reconfigure() functions
+ except that there is no receive pipe.
+ !*/
+ template < typename R >
+ void reconfigure (
+ listen_on_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe
+ );
+ /*!
+ ensures
+ - This function is identical to the above three reconfigure() functions
+ except that there is no transmit pipe.
+ !*/
+
+
+
+ template <typename T, typename R>
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe,
+ bridge_receive_decoration<R> receive_pipe
+ );
+ /*!
+ ensures
+ - This object will begin making TCP connection attempts to the IP address and port
+ specified by network_parameters. Any previous bridge state is cleared out.
+ - Onces a connection is established we will:
+ - Stop attempting new connections.
+ - Begin dequeuing objects from the transmit pipe and serializing them over
+ the TCP connection.
+ - Begin deserializing objects from the TCP connection and enqueueing them
+ onto the receive pipe.
+ - if (the current TCP connection is lost) then
+ - This object goes back to attempting to make a TCP connection with the
+ IP address and port specified by network_parameters.
+ - if (the receive pipe can contain bridge_status objects) then
+ - Whenever the bridge's status changes the updated bridge_status will be
+ enqueued onto the receive pipe unless the change was a TCP disconnect
+ resulting from a user calling reconfigure(), clear(), or destructing this
+ bridge. The status contents are defined by get_bridge_status().
+ !*/
+ template <typename T, typename R>
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe,
+ bridge_transmit_decoration<T> transmit_pipe
+ );
+ /*!
+ ensures
+ - performs reconfigure(network_parameters, transmit_pipe, receive_pipe)
+ !*/
+ template <typename T>
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_transmit_decoration<T> transmit_pipe
+ );
+ /*!
+ ensures
+ - This function is identical to the above two reconfigure() functions
+ except that there is no receive pipe.
+ !*/
+ template <typename R>
+ void reconfigure (
+ connect_to_ip_and_port network_parameters,
+ bridge_receive_decoration<R> receive_pipe
+ );
+ /*!
+ ensures
+ - This function is identical to the above three reconfigure() functions
+ except that there is no transmit pipe.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BRIDGe_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/bsp.h b/ml/dlib/dlib/bsp.h
new file mode 100644
index 000000000..899b6a405
--- /dev/null
+++ b/ml/dlib/dlib/bsp.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BSPh_
+#define DLIB_BSPh_
+
+
+#include "bsp/bsp.h"
+
+#endif // DLIB_BSPh_
+
+
+
diff --git a/ml/dlib/dlib/bsp/bsp.cpp b/ml/dlib/dlib/bsp/bsp.cpp
new file mode 100644
index 000000000..32e23519e
--- /dev/null
+++ b/ml/dlib/dlib/bsp/bsp.cpp
@@ -0,0 +1,496 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BSP_CPph_
+#define DLIB_BSP_CPph_
+
+#include "bsp.h"
+#include <memory>
+#include <stack>
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ namespace impl1
+ {
+
+ void connect_all (
+ map_id_to_con& cons,
+ const std::vector<network_address>& hosts,
+ unsigned long node_id
+ )
+ {
+ cons.clear();
+ for (unsigned long i = 0; i < hosts.size(); ++i)
+ {
+ std::unique_ptr<bsp_con> con(new bsp_con(hosts[i]));
+ dlib::serialize(node_id, con->stream); // tell the other end our node_id
+ unsigned long id = i+1;
+ cons.add(id, con);
+ }
+ }
+
+ void connect_all_hostinfo (
+ map_id_to_con& cons,
+ const std::vector<hostinfo>& hosts,
+ unsigned long node_id,
+ std::string& error_string
+ )
+ {
+ cons.clear();
+ for (unsigned long i = 0; i < hosts.size(); ++i)
+ {
+ try
+ {
+ std::unique_ptr<bsp_con> con(new bsp_con(hosts[i].addr));
+ dlib::serialize(node_id, con->stream); // tell the other end our node_id
+ con->stream.flush();
+ unsigned long id = hosts[i].node_id;
+ cons.add(id, con);
+ }
+ catch (std::exception&)
+ {
+ std::ostringstream sout;
+ sout << "Could not connect to " << hosts[i].addr;
+ error_string = sout.str();
+ break;
+ }
+ }
+ }
+
+
+ void send_out_connection_orders (
+ map_id_to_con& cons,
+ const std::vector<network_address>& hosts
+ )
+ {
+ // tell everyone their node ids
+ cons.reset();
+ while (cons.move_next())
+ {
+ dlib::serialize(cons.element().key(), cons.element().value()->stream);
+ }
+
+ // now tell them who to connect to
+ std::vector<hostinfo> targets;
+ for (unsigned long i = 0; i < hosts.size(); ++i)
+ {
+ hostinfo info(hosts[i], i+1);
+
+ dlib::serialize(targets, cons[info.node_id]->stream);
+ targets.push_back(info);
+
+ // let the other host know how many incoming connections to expect
+ const unsigned long num = hosts.size()-targets.size();
+ dlib::serialize(num, cons[info.node_id]->stream);
+ cons[info.node_id]->stream.flush();
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl2
+ {
+ // These control bytes are sent before each message between nodes. Note that many
+ // of these are only sent between the control node (node 0) and the other nodes.
+ // This is because the controller node is responsible for handling the
+ // synchronization that needs to happen when all nodes block on calls to
+ // receive_data()
+ // at the same time.
+
+ // denotes a normal content message.
+ const static char MESSAGE_HEADER = 0;
+
+ // sent to the controller node when someone receives a message via receive_data().
+ const static char GOT_MESSAGE = 1;
+
+ // sent to the controller node when someone sends a message via send().
+ const static char SENT_MESSAGE = 2;
+
+ // sent to the controller node when someone enters a call to receive_data()
+ const static char IN_WAITING_STATE = 3;
+
+ // broadcast when a node terminates itself.
+ const static char NODE_TERMINATE = 5;
+
+ // broadcast by the controller node when it determines that all nodes are blocked
+ // on calls to receive_data() and there aren't any messages in flight. This is also
+ // what makes us go to the next epoch.
+ const static char SEE_ALL_IN_WAITING_STATE = 6;
+
+ // This isn't ever transmitted between nodes. It is used internally to indicate
+ // that an error occurred.
+ const static char READ_ERROR = 7;
+
+ // ------------------------------------------------------------------------------------
+
+ void read_thread (
+ impl1::bsp_con* con,
+ unsigned long node_id,
+ unsigned long sender_id,
+ impl1::thread_safe_message_queue& msg_buffer
+ )
+ {
+ try
+ {
+ while(true)
+ {
+ impl1::msg_data msg;
+ deserialize(msg.msg_type, con->stream);
+ msg.sender_id = sender_id;
+
+ if (msg.msg_type == MESSAGE_HEADER)
+ {
+ msg.data.reset(new std::vector<char>);
+ deserialize(msg.epoch, con->stream);
+ deserialize(*msg.data, con->stream);
+ }
+
+ msg_buffer.push_and_consume(msg);
+
+ if (msg.msg_type == NODE_TERMINATE)
+ break;
+ }
+ }
+ catch (std::exception& e)
+ {
+ impl1::msg_data msg;
+ msg.data.reset(new std::vector<char>);
+ vectorstream sout(*msg.data);
+ sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
+ sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
+ sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
+ sout << " Receiving processing node id: " << node_id << std::endl;
+ sout << " Error message in the exception: " << e.what() << std::endl;
+
+ msg.sender_id = sender_id;
+ msg.msg_type = READ_ERROR;
+
+ msg_buffer.push_and_consume(msg);
+ }
+ catch (...)
+ {
+ impl1::msg_data msg;
+ msg.data.reset(new std::vector<char>);
+ vectorstream sout(*msg.data);
+ sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
+ sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
+ sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
+ sout << " Receiving processing node id: " << node_id << std::endl;
+
+ msg.sender_id = sender_id;
+ msg.msg_type = READ_ERROR;
+
+ msg_buffer.push_and_consume(msg);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// IMPLEMENTATION OF bsp_context OBJECT MEMBERS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void bsp_context::
+ close_all_connections_gracefully(
+ )
+ {
+ if (node_id() != 0)
+ {
+ _cons.reset();
+ while (_cons.move_next())
+ {
+ // tell the other end that we are intentionally dropping the connection
+ serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
+ _cons.element().value()->stream.flush();
+ }
+ }
+
+ impl1::msg_data msg;
+ // now wait for all the other nodes to terminate
+ while (num_terminated_nodes < _cons.size() )
+ {
+ if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
+ {
+ num_waiting_nodes = 0;
+ broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
+ ++current_epoch;
+ }
+
+ if (!msg_buffer.pop(msg))
+ throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
+
+ if (msg.msg_type == impl2::NODE_TERMINATE)
+ {
+ ++num_terminated_nodes;
+ _cons[msg.sender_id]->terminated = true;
+ }
+ else if (msg.msg_type == impl2::READ_ERROR)
+ {
+ throw dlib::socket_error(msg.data_to_string());
+ }
+ else if (msg.msg_type == impl2::MESSAGE_HEADER)
+ {
+ throw dlib::socket_error("A BSP node received a message after it has terminated.");
+ }
+ else if (msg.msg_type == impl2::GOT_MESSAGE)
+ {
+ --num_waiting_nodes;
+ --outstanding_messages;
+ }
+ else if (msg.msg_type == impl2::SENT_MESSAGE)
+ {
+ ++outstanding_messages;
+ }
+ else if (msg.msg_type == impl2::IN_WAITING_STATE)
+ {
+ ++num_waiting_nodes;
+ }
+ }
+
+ if (node_id() == 0)
+ {
+ _cons.reset();
+ while (_cons.move_next())
+ {
+ // tell the other end that we are intentionally dropping the connection
+ serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
+ _cons.element().value()->stream.flush();
+ }
+
+ if (outstanding_messages != 0)
+ {
+ std::ostringstream sout;
+ sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
+ sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
+ sout << "have a corresponding call to receive().";
+ throw dlib::socket_error(sout.str());
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bsp_context::
+ ~bsp_context()
+ {
+ _cons.reset();
+ while (_cons.move_next())
+ {
+ _cons.element().value()->con->shutdown();
+ }
+
+ msg_buffer.disable();
+
+ // this will wait for all the threads to terminate
+ threads.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bsp_context::
+ bsp_context(
+ unsigned long node_id_,
+ impl1::map_id_to_con& cons_
+ ) :
+ outstanding_messages(0),
+ num_waiting_nodes(0),
+ num_terminated_nodes(0),
+ current_epoch(1),
+ _cons(cons_),
+ _node_id(node_id_)
+ {
+ // spawn a bunch of read threads, one for each connection
+ _cons.reset();
+ while (_cons.move_next())
+ {
+ std::unique_ptr<thread_function> ptr(new thread_function(&impl2::read_thread,
+ _cons.element().value().get(),
+ _node_id,
+ _cons.element().key(),
+ ref(msg_buffer)));
+ threads.push_back(ptr);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bsp_context::
+ receive_data (
+ std::shared_ptr<std::vector<char> >& item,
+ unsigned long& sending_node_id
+ )
+ {
+ notify_control_node(impl2::IN_WAITING_STATE);
+
+ while (true)
+ {
+ // If there aren't any nodes left to give us messages then return right now.
+ // We need to check the msg_buffer size to make sure there aren't any
+ // unprocessed message there. Recall that this can happen because status
+ // messages always jump to the front of the message buffer. So we might have
+ // learned about the node terminations before processing their messages for us.
+ if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
+ {
+ return false;
+ }
+
+ // if all running nodes are currently blocking forever on receive_data()
+ if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
+ {
+ num_waiting_nodes = 0;
+ broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
+
+ // Note that the reason we have this epoch counter is so we can tell if a
+ // sent message is from before or after one of these "all nodes waiting"
+ // synchronization events. If we didn't have the epoch count we would have
+ // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
+ // message before others and then sends out a message to another node
+ // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
+ // node would think the normal message came before SEE_ALL_IN_WAITING_STATE
+ // which would be bad.
+ ++current_epoch;
+ return false;
+ }
+
+ impl1::msg_data data;
+ if (!msg_buffer.pop(data, current_epoch))
+ throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
+
+
+ switch(data.msg_type)
+ {
+ case impl2::MESSAGE_HEADER: {
+ item = data.data;
+ sending_node_id = data.sender_id;
+ notify_control_node(impl2::GOT_MESSAGE);
+ return true;
+ } break;
+
+ case impl2::IN_WAITING_STATE: {
+ ++num_waiting_nodes;
+ } break;
+
+ case impl2::GOT_MESSAGE: {
+ --outstanding_messages;
+ --num_waiting_nodes;
+ } break;
+
+ case impl2::SENT_MESSAGE: {
+ ++outstanding_messages;
+ } break;
+
+ case impl2::NODE_TERMINATE: {
+ ++num_terminated_nodes;
+ _cons[data.sender_id]->terminated = true;
+ } break;
+
+ case impl2::SEE_ALL_IN_WAITING_STATE: {
+ ++current_epoch;
+ return false;
+ } break;
+
+ case impl2::READ_ERROR: {
+ throw dlib::socket_error(data.data_to_string());
+ } break;
+
+ default: {
+ throw dlib::socket_error("Unknown message received by dlib::bsp_context");
+ } break;
+ } // end switch()
+ } // end while (true)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bsp_context::
+ notify_control_node (
+ char val
+ )
+ {
+ if (node_id() == 0)
+ {
+ using namespace impl2;
+ switch(val)
+ {
+ case SENT_MESSAGE: {
+ ++outstanding_messages;
+ } break;
+
+ case GOT_MESSAGE: {
+ --outstanding_messages;
+ } break;
+
+ case IN_WAITING_STATE: {
+ // nothing to do in this case
+ } break;
+
+ default:
+ DLIB_CASSERT(false,"This should never happen");
+ }
+ }
+ else
+ {
+ serialize(val, _cons[0]->stream);
+ _cons[0]->stream.flush();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bsp_context::
+ broadcast_byte (
+ char val
+ )
+ {
+ for (unsigned long i = 0; i < number_of_nodes(); ++i)
+ {
+ // don't send to yourself or to terminated nodes
+ if (i == node_id() || _cons[i]->terminated)
+ continue;
+
+ serialize(val, _cons[i]->stream);
+ _cons[i]->stream.flush();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bsp_context::
+ send_data(
+ const std::vector<char>& item,
+ unsigned long target_node_id
+ )
+ {
+ using namespace impl2;
+ if (_cons[target_node_id]->terminated)
+ throw socket_error("Attempt to send a message to a node that has terminated.");
+
+ serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
+ serialize(current_epoch, _cons[target_node_id]->stream);
+ serialize(item, _cons[target_node_id]->stream);
+ _cons[target_node_id]->stream.flush();
+
+ notify_control_node(SENT_MESSAGE);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BSP_CPph_
+
diff --git a/ml/dlib/dlib/bsp/bsp.h b/ml/dlib/dlib/bsp/bsp.h
new file mode 100644
index 000000000..f0732c153
--- /dev/null
+++ b/ml/dlib/dlib/bsp/bsp.h
@@ -0,0 +1,1043 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BsP_Hh_
+#define DLIB_BsP_Hh_
+
+#include "bsp_abstract.h"
+
+#include <memory>
+#include <queue>
+#include <vector>
+
+#include "../sockets.h"
+#include "../array.h"
+#include "../sockstreambuf.h"
+#include "../string.h"
+#include "../serialize.h"
+#include "../map.h"
+#include "../ref.h"
+#include "../vectorstream.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl1
+ {
+ inline void null_notify(
+ unsigned short
+ ) {}
+
+ struct bsp_con
+ {
+ bsp_con(
+ const network_address& dest
+ ) :
+ con(connect(dest)),
+ buf(con),
+ stream(&buf),
+ terminated(false)
+ {
+ con->disable_nagle();
+ }
+
+ bsp_con(
+ std::unique_ptr<connection>& conptr
+ ) :
+ buf(conptr),
+ stream(&buf),
+ terminated(false)
+ {
+ // make sure we own the connection
+ conptr.swap(con);
+
+ con->disable_nagle();
+ }
+
+ std::unique_ptr<connection> con;
+ sockstreambuf buf;
+ std::iostream stream;
+ bool terminated;
+ };
+
+ typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con;
+
+ void connect_all (
+ map_id_to_con& cons,
+ const std::vector<network_address>& hosts,
+ unsigned long node_id
+ );
+ /*!
+ ensures
+ - creates connections to all the given hosts and stores them into cons
+ !*/
+
+ void send_out_connection_orders (
+ map_id_to_con& cons,
+ const std::vector<network_address>& hosts
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ struct hostinfo
+ {
+ hostinfo() {}
+ hostinfo (
+ const network_address& addr_,
+ unsigned long node_id_
+ ) :
+ addr(addr_),
+ node_id(node_id_)
+ {
+ }
+
+ network_address addr;
+ unsigned long node_id;
+ };
+
+ inline void serialize (
+ const hostinfo& item,
+ std::ostream& out
+ )
+ {
+ dlib::serialize(item.addr, out);
+ dlib::serialize(item.node_id, out);
+ }
+
+ inline void deserialize (
+ hostinfo& item,
+ std::istream& in
+ )
+ {
+ dlib::deserialize(item.addr, in);
+ dlib::deserialize(item.node_id, in);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void connect_all_hostinfo (
+ map_id_to_con& cons,
+ const std::vector<hostinfo>& hosts,
+ unsigned long node_id,
+ std::string& error_string
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type
+ >
+ void listen_and_connect_all(
+ unsigned long& node_id,
+ map_id_to_con& cons,
+ unsigned short port,
+ port_notify_function_type port_notify_function
+ )
+ {
+ cons.clear();
+ std::unique_ptr<listener> list;
+ const int status = create_listener(list, port);
+ if (status == PORTINUSE)
+ {
+ throw socket_error("Unable to create listening port " + cast_to_string(port) +
+ ". The port is already in use");
+ }
+ else if (status != 0)
+ {
+ throw socket_error("Unable to create listening port " + cast_to_string(port) );
+ }
+
+ port_notify_function(list->get_listening_port());
+
+ std::unique_ptr<connection> con;
+ if (list->accept(con))
+ {
+ throw socket_error("Error occurred while accepting new connection");
+ }
+
+ std::unique_ptr<bsp_con> temp(new bsp_con(con));
+
+ unsigned long remote_node_id;
+ dlib::deserialize(remote_node_id, temp->stream);
+ dlib::deserialize(node_id, temp->stream);
+ std::vector<hostinfo> targets;
+ dlib::deserialize(targets, temp->stream);
+ unsigned long num_incoming_connections;
+ dlib::deserialize(num_incoming_connections, temp->stream);
+
+ cons.add(remote_node_id,temp);
+
+ // make a thread that will connect to all the targets
+ map_id_to_con cons2;
+ std::string error_string;
+ thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string));
+ if (error_string.size() != 0)
+ throw socket_error(error_string);
+
+ // accept any incoming connections
+ for (unsigned long i = 0; i < num_incoming_connections; ++i)
+ {
+ // If it takes more than 10 seconds for the other nodes to connect to us
+ // then something has gone horribly wrong and it almost certainly will
+ // never connect at all. So just give up if that happens.
+ const unsigned long timeout_milliseconds = 10000;
+ if (list->accept(con, timeout_milliseconds))
+ {
+ throw socket_error("Error occurred while accepting new connection");
+ }
+
+ temp.reset(new bsp_con(con));
+
+ dlib::deserialize(remote_node_id, temp->stream);
+ cons.add(remote_node_id,temp);
+ }
+
+
+ // put all the connections created by the thread into cons
+ thread.wait();
+ while (cons2.size() > 0)
+ {
+ unsigned long id;
+ std::unique_ptr<bsp_con> temp;
+ cons2.remove_any(id,temp);
+ cons.add(id,temp);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ struct msg_data
+ {
+ std::shared_ptr<std::vector<char> > data;
+ unsigned long sender_id;
+ char msg_type;
+ dlib::uint64 epoch;
+
+ msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}
+
+ std::string data_to_string() const
+ {
+ if (data && data->size() != 0)
+ return std::string(&(*data)[0], data->size());
+ else
+ return "";
+ }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ class thread_safe_message_queue : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple message queue for msg_data objects. Note that it
+ has the special property that, while messages will generally leave
+ the queue in the order they are inserted, any message with a smaller
+ epoch value will always be popped out first. But for all messages
+ with equal epoch values the queue functions as a normal FIFO queue.
+ !*/
+ private:
+ struct msg_wrap
+ {
+ msg_wrap(
+ const msg_data& data_,
+ const dlib::uint64& sequence_number_
+ ) : data(data_), sequence_number(sequence_number_) {}
+
+ msg_wrap() : sequence_number(0){}
+
+ msg_data data;
+ dlib::uint64 sequence_number;
+
+ // Make it so that when msg_wrap objects are in a std::priority_queue,
+ // messages with a smaller epoch number always come first. Then, within an
+ // epoch, messages are ordered by their sequence number (so smaller first
+ // there as well).
+ bool operator<(const msg_wrap& item) const
+ {
+ if (data.epoch < item.data.epoch)
+ {
+ return false;
+ }
+ else if (data.epoch > item.data.epoch)
+ {
+ return true;
+ }
+ else
+ {
+ if (sequence_number < item.sequence_number)
+ return false;
+ else
+ return true;
+ }
+ }
+ };
+
+ public:
+ thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}
+
+ ~thread_safe_message_queue()
+ {
+ disable();
+ }
+
+ void disable()
+ {
+ auto_mutex lock(class_mutex);
+ disabled = true;
+ sig.broadcast();
+ }
+
+ unsigned long size() const
+ {
+ auto_mutex lock(class_mutex);
+ return data.size();
+ }
+
+ void push_and_consume( msg_data& item)
+ {
+ auto_mutex lock(class_mutex);
+ data.push(msg_wrap(item, next_seq_num++));
+ // do this here so that we don't have to worry about different threads touching the shared_ptr.
+ item.data.reset();
+ sig.signal();
+ }
+
+ bool pop (
+ msg_data& item
+ )
+ /*!
+ ensures
+ - if (this function returns true) then
+ - #item == the next thing from the queue
+ - else
+ - this object is disabled
+ !*/
+ {
+ auto_mutex lock(class_mutex);
+ while (data.size() == 0 && !disabled)
+ sig.wait();
+
+ if (disabled)
+ return false;
+
+ item = data.top().data;
+ data.pop();
+
+ return true;
+ }
+
+ bool pop (
+ msg_data& item,
+ const dlib::uint64& max_epoch
+ )
+ /*!
+ ensures
+ - if (this function returns true) then
+ - #item == the next thing from the queue that has an epoch <= max_epoch
+ - else
+ - this object is disabled
+ !*/
+ {
+ auto_mutex lock(class_mutex);
+ while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled)
+ sig.wait();
+
+ if (disabled)
+ return false;
+
+ item = data.top().data;
+ data.pop();
+
+ return true;
+ }
+
+ private:
+ std::priority_queue<msg_wrap> data;
+ dlib::mutex class_mutex;
+ dlib::signaler sig;
+ bool disabled;
+ dlib::uint64 next_seq_num;
+ };
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class bsp_context : noncopyable
+ {
+
+ public:
+
+ template <typename T>
+ void send(
+ const T& item,
+ unsigned long target_node_id
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(target_node_id < number_of_nodes() &&
+ target_node_id != node_id(),
+ "\t void bsp_context::send()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t target_node_id: " << target_node_id
+ << "\n\t node_id(): " << node_id()
+ << "\n\t number_of_nodes(): " << number_of_nodes()
+ << "\n\t this: " << this
+ );
+
+ std::vector<char> buf;
+ vectorstream sout(buf);
+ serialize(item, sout);
+ send_data(buf, target_node_id);
+ }
+
+ template <typename T>
+ void broadcast (
+ const T& item
+ )
+ {
+ std::vector<char> buf;
+ vectorstream sout(buf);
+ serialize(item, sout);
+ for (unsigned long i = 0; i < number_of_nodes(); ++i)
+ {
+ // Don't send to yourself.
+ if (i == node_id())
+ continue;
+
+ send_data(buf, i);
+ }
+ }
+
+ unsigned long node_id (
+ ) const { return _node_id; }
+
+ unsigned long number_of_nodes (
+ ) const { return _cons.size()+1; }
+
+ void receive (
+ )
+ {
+ unsigned long id;
+ std::shared_ptr<std::vector<char> > temp;
+ if (receive_data(temp,id))
+ throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message.");
+ }
+
+ template <typename T>
+ void receive (
+ T& item
+ )
+ {
+ if(!try_receive(item))
+ throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
+ }
+
+ template <typename T>
+ bool try_receive (
+ T& item
+ )
+ {
+ unsigned long sending_node_id;
+ return try_receive(item, sending_node_id);
+ }
+
+ template <typename T>
+ void receive (
+ T& item,
+ unsigned long& sending_node_id
+ )
+ {
+ if(!try_receive(item, sending_node_id))
+ throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
+ }
+
+ template <typename T>
+ bool try_receive (
+ T& item,
+ unsigned long& sending_node_id
+ )
+ {
+ std::shared_ptr<std::vector<char> > temp;
+ if (receive_data(temp, sending_node_id))
+ {
+ vectorstream sin(*temp);
+ deserialize(item, sin);
+ if (sin.peek() != EOF)
+ throw serialization_error("deserialize() did not consume all bytes produced by serialize(). "
+ "This probably means you are calling a receive method with a different type "
+ "of object than the one which was sent.");
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ ~bsp_context();
+
+ private:
+
+ bsp_context();
+
+ bsp_context(
+ unsigned long node_id_,
+ impl1::map_id_to_con& cons_
+ );
+
+ void close_all_connections_gracefully();
+ /*!
+ ensures
+ - closes all the connections to other nodes and lets them know that
+ we are terminating normally rather than as the result of some kind
+ of error.
+ !*/
+
+ bool receive_data (
+ std::shared_ptr<std::vector<char> >& item,
+ unsigned long& sending_node_id
+ );
+
+
+ void notify_control_node (
+ char val
+ );
+
+ void broadcast_byte (
+ char val
+ );
+
+ void send_data(
+ const std::vector<char>& item,
+ unsigned long target_node_id
+ );
+ /*!
+ requires
+ - target_node_id < number_of_nodes()
+ - target_node_id != node_id()
+ ensures
+ - sends a copy of item to the node with the given id.
+ !*/
+
+
+
+
+ unsigned long outstanding_messages;
+ unsigned long num_waiting_nodes;
+ unsigned long num_terminated_nodes;
+ dlib::uint64 current_epoch;
+
+ impl1::thread_safe_message_queue msg_buffer;
+
+ impl1::map_id_to_con& _cons;
+ const unsigned long _node_id;
+ array<std::unique_ptr<thread_function> > threads;
+
+ // -----------------------------------
+
+ template <
+ typename funct_type
+ >
+ friend void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct
+ );
+
+ template <
+ typename funct_type,
+ typename ARG1
+ >
+ friend void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1
+ );
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ friend void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ );
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ friend void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ );
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ friend void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ );
+
+ // -----------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type
+ >
+ friend void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct
+ );
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1
+ >
+ friend void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1
+ );
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ friend void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ );
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ friend void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ );
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ friend void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ );
+
+ // -----------------------------------
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct
+ )
+ {
+ impl1::map_id_to_con cons;
+ const unsigned long node_id = 0;
+ connect_all(cons, hosts, node_id);
+ send_out_connection_orders(cons, hosts);
+ bsp_context obj(node_id, cons);
+ funct(obj);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1
+ )
+ {
+ impl1::map_id_to_con cons;
+ const unsigned long node_id = 0;
+ connect_all(cons, hosts, node_id);
+ send_out_connection_orders(cons, hosts);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ )
+ {
+ impl1::map_id_to_con cons;
+ const unsigned long node_id = 0;
+ connect_all(cons, hosts, node_id);
+ send_out_connection_orders(cons, hosts);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ )
+ {
+ impl1::map_id_to_con cons;
+ const unsigned long node_id = 0;
+ connect_all(cons, hosts, node_id);
+ send_out_connection_orders(cons, hosts);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2,arg3);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ )
+ {
+ impl1::map_id_to_con cons;
+ const unsigned long node_id = 0;
+ connect_all(cons, hosts, node_id);
+ send_out_connection_orders(cons, hosts);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2,arg3,arg4);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(listening_port != 0,
+ "\t void bsp_listen()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(listening_port != 0,
+ "\t void bsp_listen()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(listening_port != 0,
+ "\t void bsp_listen()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(listening_port != 0,
+ "\t void bsp_listen()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(listening_port != 0,
+ "\t void bsp_listen()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct
+ )
+ {
+ impl1::map_id_to_con cons;
+ unsigned long node_id;
+ listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
+ bsp_context obj(node_id, cons);
+ funct(obj);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1
+ )
+ {
+ impl1::map_id_to_con cons;
+ unsigned long node_id;
+ listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ )
+ {
+ impl1::map_id_to_con cons;
+ unsigned long node_id;
+ listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ )
+ {
+ impl1::map_id_to_con cons;
+ unsigned long node_id;
+ listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2,arg3);
+ obj.close_all_connections_gracefully();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ )
+ {
+ impl1::map_id_to_con cons;
+ unsigned long node_id;
+ listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
+ bsp_context obj(node_id, cons);
+ funct(obj,arg1,arg2,arg3,arg4);
+ obj.close_all_connections_gracefully();
+ }
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "bsp.cpp"
+#endif
+
+#endif // DLIB_BsP_Hh_
+
diff --git a/ml/dlib/dlib/bsp/bsp_abstract.h b/ml/dlib/dlib/bsp/bsp_abstract.h
new file mode 100644
index 000000000..b87f3a0c3
--- /dev/null
+++ b/ml/dlib/dlib/bsp/bsp_abstract.h
@@ -0,0 +1,912 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BsP_ABSTRACT_Hh_
+#ifdef DLIB_BsP_ABSTRACT_Hh_
+
+#include "../noncopyable.h"
+#include "../sockets/sockets_extensions_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bsp_context : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool used to implement algorithms using the Bulk Synchronous
+ Parallel (BSP) computing model. A BSP algorithm is composed of a number of
+ processing nodes, each executing in parallel. The general flow of
+ execution in each processing node is the following:
+ 1. Do work locally on some data.
+ 2. Send some messages to other nodes.
+ 3. Receive messages from other nodes.
+ 4. Go to step 1 or terminate if complete.
+
+ To do this, each processing node needs an API used to send and receive
+ messages. This API is implemented by the bsp_connect object which provides
+ these services to a BSP node.
+
+ Note that BSP processing nodes are spawned using the bsp_connect() and
+ bsp_listen() routines defined at the bottom of this file. For example, to
+ start a BSP algorithm consisting of N processing nodes, you would make N-1
+ calls to bsp_listen() and one call to bsp_connect(). The call to
+ bsp_connect() then initiates the computation on all nodes.
+
+ Finally, note that there is no explicit barrier synchronization function
+ you call at the end of step 3. Instead, you can simply call a method such
+ as try_receive() until it returns false. That is, the bsp_context's
+ receive methods incorporate a barrier synchronization that happens once all
+ the BSP nodes are blocked on receive calls and there are no more messages
+ in flight.
+
+
+ THREAD SAFETY
+ This object is not thread-safe. In particular, you should only ever have
+ one thread that works with an instance of this object. This means that,
+ for example, you should not spawn sub-threads from within a BSP processing
+ node and have them invoke methods on this object. Instead, you should only
+ invoke this object's methods from within the BSP processing node's main
+ thread (i.e. the thread that executes the user supplied function funct()).
+ !*/
+
+ public:
+
+ template <typename T>
+ void send(
+ const T& item,
+ unsigned long target_node_id
+ );
+ /*!
+ requires
+ - item is serializable
+ - target_node_id < number_of_nodes()
+ - target_node_id != node_id()
+ ensures
+ - sends a copy of item to the node with the given id.
+ throws
+ - dlib::socket_error:
+ This exception is thrown if there is an error which prevents us from
+ delivering the message to the given node. One way this might happen is
+ if the target node has already terminated its execution or has lost
+ network connectivity.
+ !*/
+
+ template <typename T>
+ void broadcast (
+ const T& item
+ );
+ /*!
+ ensures
+ - item is serializable
+ - sends a copy of item to all other processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents us from
+ delivering a message to one of the other nodes. This might happen, for
+ example, if one of the nodes has terminated its execution or has lost
+ network connectivity.
+ !*/
+
+ unsigned long node_id (
+ ) const;
+ /*!
+ ensures
+ - Returns the id of the current processing node. That is,
+ returns a number N such that:
+ - N < number_of_nodes()
+ - N == the node id of the processing node that called node_id(). This
+ is a number that uniquely identifies the processing node.
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of processing nodes participating in the BSP
+ computation.
+ !*/
+
+ template <typename T>
+ bool try_receive (
+ T& item
+ );
+ /*!
+ requires
+ - item is serializable
+ ensures
+ - if (this function returns true) then
+ - #item == the next message which was sent to the calling processing
+ node.
+ - else
+ - The following must have been true for this function to return false:
+ - All other nodes were blocked on calls to receive(),
+ try_receive(), or have terminated.
+ - There were not any messages in flight between any nodes.
+ - That is, if all the nodes had continued to block on receive
+ methods then they all would have blocked forever. Therefore,
+ this function only returns false once there are no more messages
+ to process by any node and there is no possibility of more being
+ generated until control is returned to the callers of receive
+ methods.
+ - When one BSP node's receive method returns because of the above
+ conditions then all of them will also return. That is, it is NOT the
+ case that just a subset of BSP nodes unblock. Moreover, they all
+ unblock at the same time.
+ throws
+ - dlib::socket_error:
+ This exception is thrown if some error occurs which prevents us from
+ communicating with other processing nodes.
+ - dlib::serialization_error or any exception thrown by the global
+ deserialize(T) routine:
+ This is thrown if there is a problem in deserialize(). This might
+ happen if the message sent doesn't match the type T expected by
+ try_receive().
+ !*/
+
+ template <typename T>
+ void receive (
+ T& item
+ );
+ /*!
+ requires
+ - item is serializable
+ ensures
+ - #item == the next message which was sent to the calling processing
+ node.
+ - This function is just a wrapper around try_receive() that throws an
+ exception if a message is not received (i.e. if try_receive() returns
+ false).
+ throws
+ - dlib::socket_error:
+ This exception is thrown if some error occurs which prevents us from
+ communicating with other processing nodes or if there was not a message
+ to receive.
+ - dlib::serialization_error or any exception thrown by the global
+ deserialize(T) routine:
+ This is thrown if there is a problem in deserialize(). This might
+ happen if the message sent doesn't match the type T expected by
+ receive().
+ !*/
+
+ template <typename T>
+ bool try_receive (
+ T& item,
+ unsigned long& sending_node_id
+ );
+ /*!
+ requires
+ - item is serializable
+ ensures
+ - if (this function returns true) then
+ - #item == the next message which was sent to the calling processing
+ node.
+ - #sending_node_id == the node id of the node that sent this message.
+ - #sending_node_id < number_of_nodes()
+ - else
+ - The following must have been true for this function to return false:
+ - All other nodes were blocked on calls to receive(),
+ try_receive(), or have terminated.
+ - There were not any messages in flight between any nodes.
+ - That is, if all the nodes had continued to block on receive
+ methods then they all would have blocked forever. Therefore,
+ this function only returns false once there are no more messages
+ to process by any node and there is no possibility of more being
+ generated until control is returned to the callers of receive
+ methods.
+ - When one BSP node's receive method returns because of the above
+ conditions then all of them will also return. That is, it is NOT the
+ case that just a subset of BSP nodes unblock. Moreover, they all
+ unblock at the same time.
+ throws
+ - dlib::socket_error:
+ This exception is thrown if some error occurs which prevents us from
+ communicating with other processing nodes.
+ - dlib::serialization_error or any exception thrown by the global
+ deserialize(T) routine:
+ This is thrown if there is a problem in deserialize(). This might
+ happen if the message sent doesn't match the type T expected by
+ try_receive().
+ !*/
+
+ template <typename T>
+ void receive (
+ T& item,
+ unsigned long& sending_node_id
+ );
+ /*!
+ requires
+ - item is serializable
+ ensures
+ - #item == the next message which was sent to the calling processing node.
+ - #sending_node_id == the node id of the node that sent this message.
+ - #sending_node_id < number_of_nodes()
+ - This function is just a wrapper around try_receive() that throws an
+ exception if a message is not received (i.e. if try_receive() returns
+ false).
+ throws
+ - dlib::socket_error:
+ This exception is thrown if some error occurs which prevents us from
+ communicating with other processing nodes or if there was not a message
+ to receive.
+ - dlib::serialization_error or any exception thrown by the global
+ deserialize(T) routine:
+ This is thrown if there is a problem in deserialize(). This might
+ happen if the message sent doesn't match the type T expected by
+ receive().
+ !*/
+
+ void receive (
+ );
+ /*!
+ ensures
+ - Waits for the following to all be true:
+ - All other nodes were blocked on calls to receive(), try_receive(), or
+ have terminated.
+ - There are not any messages in flight between any nodes.
+ - That is, if all the nodes had continued to block on receive methods
+ then they all would have blocked forever. Therefore, this function
+ only returns once there are no more messages to process by any node
+ and there is no possibility of more being generated until control is
+ returned to the callers of receive methods.
+ - When one BSP node's receive method returns because of the above
+ conditions then all of them will also return. That is, it is NOT the
+ case that just a subset of BSP nodes unblock. Moreover, they all unblock
+ at the same time.
+ throws
+ - dlib::socket_error:
+ This exception is thrown if some error occurs which prevents us from
+ communicating with other processing nodes or if a message is received
+ before this function would otherwise return.
+
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
+ - The processing node with a node ID of 0 will run locally on the machine
+ calling bsp_connect(). In particular, this node will execute funct(CONTEXT),
+ which is expected to carry out this node's portion of the BSP computation.
+ - The other processing nodes are executed on the hosts indicated by the input
+ argument. In particular, this function interprets hosts as a list addresses
+ identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
+ routines.
+ - This call to bsp_connect() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
+ - The processing node with a node ID of 0 will run locally on the machine
+ calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1),
+ which is expected to carry out this node's portion of the BSP computation.
+ - The other processing nodes are executed on the hosts indicated by the input
+ argument. In particular, this function interprets hosts as a list addresses
+ identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
+ routines.
+ - This call to bsp_connect() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
+ - The processing node with a node ID of 0 will run locally on the machine
+ calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2),
+ which is expected to carry out this node's portion of the BSP computation.
+ - The other processing nodes are executed on the hosts indicated by the input
+ argument. In particular, this function interprets hosts as a list addresses
+ identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
+ routines.
+ - This call to bsp_connect() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
+ - The processing node with a node ID of 0 will run locally on the machine
+ calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3),
+ which is expected to carry out this node's portion of the BSP computation.
+ - The other processing nodes are executed on the hosts indicated by the input
+ argument. In particular, this function interprets hosts as a list addresses
+ identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
+ routines.
+ - This call to bsp_connect() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_connect (
+ const std::vector<network_address>& hosts,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
+ - The processing node with a node ID of 0 will run locally on the machine
+ calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4),
+ which is expected to carry out this node's portion of the BSP computation.
+ - The other processing nodes are executed on the hosts indicated by the input
+ argument. In particular, this function interprets hosts as a list addresses
+ identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
+ routines.
+ - This call to bsp_connect() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct
+ );
+ /*!
+ requires
+ - listening_port != 0
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT) will be executed and it will
+ then be able to participate in the BSP computation as one of the processing
+ nodes.
+ - This function will listen on TCP port listening_port for a connection from
+ bsp_connect(). Once the connection is established, it will close the
+ listening port so it is free for use by other applications. The connection
+ and BSP computation will continue uninterrupted.
+ - This call to bsp_listen() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1
+ );
+ /*!
+ requires
+ - listening_port != 0
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1) will be executed and it will
+ then be able to participate in the BSP computation as one of the processing
+ nodes.
+ - This function will listen on TCP port listening_port for a connection from
+ bsp_connect(). Once the connection is established, it will close the
+ listening port so it is free for use by other applications. The connection
+ and BSP computation will continue uninterrupted.
+ - This call to bsp_listen() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ );
+ /*!
+ requires
+ - listening_port != 0
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2) will be executed and
+ it will then be able to participate in the BSP computation as one of the
+ processing nodes.
+ - This function will listen on TCP port listening_port for a connection from
+ bsp_connect(). Once the connection is established, it will close the
+ listening port so it is free for use by other applications. The connection
+ and BSP computation will continue uninterrupted.
+ - This call to bsp_listen() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ );
+ /*!
+ requires
+ - listening_port != 0
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be
+ executed and it will then be able to participate in the BSP computation as
+ one of the processing nodes.
+ - This function will listen on TCP port listening_port for a connection from
+ bsp_connect(). Once the connection is established, it will close the
+ listening port so it is free for use by other applications. The connection
+ and BSP computation will continue uninterrupted.
+ - This call to bsp_listen() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_listen (
+ unsigned short listening_port,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ );
+ /*!
+ requires
+ - listening_port != 0
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
+ (i.e. funct must be a function or function object)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be
+ executed and it will then be able to participate in the BSP computation as
+ one of the processing nodes.
+ - This function will listen on TCP port listening_port for a connection from
+ bsp_connect(). Once the connection is established, it will close the
+ listening port so it is free for use by other applications. The connection
+ and BSP computation will continue uninterrupted.
+ - This call to bsp_listen() blocks until the BSP computation has completed on
+ all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT) must be a valid expression
+ (i.e. funct must be a function or function object)
+ - port_notify_function((unsigned short) 1234) must be a valid expression
+ (i.e. port_notify_function() must be a function or function object taking an
+ unsigned short)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT) will be executed and it will
+ then be able to participate in the BSP computation as one of the processing
+ nodes.
+ - if (listening_port != 0) then
+ - This function will listen on TCP port listening_port for a connection
+ from bsp_connect().
+ - else
+ - An available TCP port number is automatically selected and this function
+ will listen on it for a connection from bsp_connect().
+ - Once a listening port is opened, port_notify_function() is called with the
+ port number used. This provides a mechanism to find out what listening port
+ has been used if it is automatically selected. It also allows you to find
+ out when the routine has begun listening for an incoming connection from
+ bsp_connect().
+ - Once a connection is established, we will close the listening port so it is
+ free for use by other applications. The connection and BSP computation will
+ continue uninterrupted.
+ - This call to bsp_listen_dynamic_port() blocks until the BSP computation has
+ completed on all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1) must be a valid expression
+ (i.e. funct must be a function or function object)
+ - port_notify_function((unsigned short) 1234) must be a valid expression
+ (i.e. port_notify_function() must be a function or function object taking an
+ unsigned short)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1) will be executed and it
+ will then be able to participate in the BSP computation as one of the
+ processing nodes.
+ - if (listening_port != 0) then
+ - This function will listen on TCP port listening_port for a connection
+ from bsp_connect().
+ - else
+ - An available TCP port number is automatically selected and this function
+ will listen on it for a connection from bsp_connect().
+ - Once a listening port is opened, port_notify_function() is called with the
+ port number used. This provides a mechanism to find out what listening port
+ has been used if it is automatically selected. It also allows you to find
+ out when the routine has begun listening for an incoming connection from
+ bsp_connect().
+ - Once a connection is established, we will close the listening port so it is
+ free for use by other applications. The connection and BSP computation will
+ continue uninterrupted.
+ - This call to bsp_listen_dynamic_port() blocks until the BSP computation has
+ completed on all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2) must be a valid expression
+ (i.e. funct must be a function or function object)
+ - port_notify_function((unsigned short) 1234) must be a valid expression
+ (i.e. port_notify_function() must be a function or function object taking an
+ unsigned short)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2) will be executed and
+ it will then be able to participate in the BSP computation as one of the
+ processing nodes.
+ - if (listening_port != 0) then
+ - This function will listen on TCP port listening_port for a connection
+ from bsp_connect().
+ - else
+ - An available TCP port number is automatically selected and this function
+ will listen on it for a connection from bsp_connect().
+ - Once a listening port is opened, port_notify_function() is called with the
+ port number used. This provides a mechanism to find out what listening port
+ has been used if it is automatically selected. It also allows you to find
+ out when the routine has begun listening for an incoming connection from
+ bsp_connect().
+ - Once a connection is established, we will close the listening port so it is
+ free for use by other applications. The connection and BSP computation will
+ continue uninterrupted.
+ - This call to bsp_listen_dynamic_port() blocks until the BSP computation has
+ completed on all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
+ (i.e. funct must be a function or function object)
+ - port_notify_function((unsigned short) 1234) must be a valid expression
+ (i.e. port_notify_function() must be a function or function object taking an
+ unsigned short)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be
+ executed and it will then be able to participate in the BSP computation as
+ one of the processing nodes.
+ - if (listening_port != 0) then
+ - This function will listen on TCP port listening_port for a connection
+ from bsp_connect().
+ - else
+ - An available TCP port number is automatically selected and this function
+ will listen on it for a connection from bsp_connect().
+ - Once a listening port is opened, port_notify_function() is called with the
+ port number used. This provides a mechanism to find out what listening port
+ has been used if it is automatically selected. It also allows you to find
+ out when the routine has begun listening for an incoming connection from
+ bsp_connect().
+ - Once a connection is established, we will close the listening port so it is
+ free for use by other applications. The connection and BSP computation will
+ continue uninterrupted.
+ - This call to bsp_listen_dynamic_port() blocks until the BSP computation has
+ completed on all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename port_notify_function_type,
+ typename funct_type,
+ typename ARG1,
+ typename ARG2,
+ typename ARG3,
+ typename ARG4
+ >
+ void bsp_listen_dynamic_port (
+ unsigned short listening_port,
+ port_notify_function_type port_notify_function,
+ funct_type funct,
+ ARG1 arg1,
+ ARG2 arg2,
+ ARG3 arg3,
+ ARG4 arg4
+ );
+ /*!
+ requires
+ - let CONTEXT be an instance of a bsp_context object. Then:
+ - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
+ (i.e. funct must be a function or function object)
+ - port_notify_function((unsigned short) 1234) must be a valid expression
+ (i.e. port_notify_function() must be a function or function object taking an
+ unsigned short)
+ ensures
+ - This function listens for a connection from the bsp_connect() routine. Once
+ this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be
+ executed and it will then be able to participate in the BSP computation as
+ one of the processing nodes.
+ - if (listening_port != 0) then
+ - This function will listen on TCP port listening_port for a connection
+ from bsp_connect().
+ - else
+ - An available TCP port number is automatically selected and this function
+ will listen on it for a connection from bsp_connect().
+ - Once a listening port is opened, port_notify_function() is called with the
+ port number used. This provides a mechanism to find out what listening port
+ has been used if it is automatically selected. It also allows you to find
+ out when the routine has begun listening for an incoming connection from
+ bsp_connect().
+ - Once a connection is established, we will close the listening port so it is
+ free for use by other applications. The connection and BSP computation will
+ continue uninterrupted.
+ - This call to bsp_listen_dynamic_port() blocks until the BSP computation has
+ completed on all processing nodes.
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is an error which prevents the BSP
+ job from executing.
+ - Any exception thrown by funct() will be propagated out of this call to
+ bsp_connect().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BsP_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/byte_orderer.h b/ml/dlib/dlib/byte_orderer.h
new file mode 100644
index 000000000..bc8f6108d
--- /dev/null
+++ b/ml/dlib/dlib/byte_orderer.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BYTE_ORDEREr_
+#define DLIB_BYTE_ORDEREr_
+
+
+#include "byte_orderer/byte_orderer_kernel_1.h"
+
+#endif // DLIB_BYTE_ORDEREr_
+
diff --git a/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h
new file mode 100644
index 000000000..9f8e8342f
--- /dev/null
+++ b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h
@@ -0,0 +1,176 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BYTE_ORDEREr_KERNEL_1_
+#define DLIB_BYTE_ORDEREr_KERNEL_1_
+
+#include "byte_orderer_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ class byte_orderer
+ {
+ /*!
+ INITIAL VALUE
+ - if (this machine is little endian) then
+ - little_endian == true
+ - else
+ - little_endian == false
+
+ CONVENTION
+ - host_is_big_endian() == !little_endian
+ - host_is_little_endian() == little_endian
+
+ - if (this machine is little endian) then
+ - little_endian == true
+ - else
+ - little_endian == false
+
+
+ !*/
+
+
+ public:
+
+ // this is here for backwards compatibility with older versions of dlib.
+ typedef byte_orderer kernel_1a;
+
+ byte_orderer (
+ )
+ {
+ // This will probably never be false but if it is then it means chars are not 8bits
+ // on this system. Which is a problem for this object.
+ COMPILE_TIME_ASSERT(sizeof(short) >= 2);
+
+ unsigned long temp = 1;
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(&temp);
+ if (*ptr == 1)
+ little_endian = true;
+ else
+ little_endian = false;
+ }
+
+ virtual ~byte_orderer (
+ ){}
+
+ bool host_is_big_endian (
+ ) const { return !little_endian; }
+
+ bool host_is_little_endian (
+ ) const { return little_endian; }
+
+ template <
+ typename T
+ >
+ inline void host_to_network (
+ T& item
+ ) const
+ { if (little_endian) flip(item); }
+
+ template <
+ typename T
+ >
+ inline void network_to_host (
+ T& item
+ ) const { if (little_endian) flip(item); }
+
+ template <
+ typename T
+ >
+ void host_to_big (
+ T& item
+ ) const { if (little_endian) flip(item); }
+
+ template <
+ typename T
+ >
+ void big_to_host (
+ T& item
+ ) const { if (little_endian) flip(item); }
+
+ template <
+ typename T
+ >
+ void host_to_little (
+ T& item
+ ) const { if (!little_endian) flip(item); }
+
+ template <
+ typename T
+ >
+ void little_to_host (
+ T& item
+ ) const { if (!little_endian) flip(item); }
+
+
+ private:
+
+ template <
+ typename T,
+ size_t size
+ >
+ inline void flip (
+ T (&array)[size]
+ ) const
+ /*!
+ ensures
+ - flips the bytes in every element of this array
+ !*/
+ {
+ for (size_t i = 0; i < size; ++i)
+ {
+ flip(array[i]);
+ }
+ }
+
+ template <
+ typename T
+ >
+ inline void flip (
+ T& item
+ ) const
+ /*!
+ ensures
+ - reverses the byte ordering in item
+ !*/
+ {
+ DLIB_ASSERT_HAS_STANDARD_LAYOUT(T);
+
+ T value;
+
+ // If you are getting this as an error then you are probably using
+ // this object wrong. If you think you aren't then send me (Davis) an
+ // email and I'll either set you straight or change/remove this check so
+ // your stuff works :)
+ COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double));
+
+ // If you are getting a compile error on this line then it means T is
+ // a pointer type. It doesn't make any sense to byte swap pointers
+ // since they have no meaning outside the context of their own process.
+ // So you probably just forgot to dereference that pointer before passing
+ // it to this function :)
+ COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);
+
+
+ const size_t size = sizeof(T);
+ unsigned char* const ptr = reinterpret_cast<unsigned char*>(&item);
+ unsigned char* const ptr_temp = reinterpret_cast<unsigned char*>(&value);
+ for (size_t i = 0; i < size; ++i)
+ ptr_temp[size-i-1] = ptr[i];
+
+ item = value;
+ }
+
+ bool little_endian;
+ };
+
+ // make flip not do anything at all for chars
+ template <> inline void byte_orderer::flip<char> ( char& ) const {}
+ template <> inline void byte_orderer::flip<unsigned char> ( unsigned char& ) const {}
+ template <> inline void byte_orderer::flip<signed char> ( signed char& ) const {}
+}
+
+#endif // DLIB_BYTE_ORDEREr_KERNEL_1_
+
diff --git a/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h
new file mode 100644
index 000000000..f7ea15103
--- /dev/null
+++ b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h
@@ -0,0 +1,149 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BYTE_ORDEREr_ABSTRACT_
+#ifdef DLIB_BYTE_ORDEREr_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ class byte_orderer
+ {
+ /*!
+ INITIAL VALUE
+ This object has no state.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object simply provides a mechanism to convert data from a
+ host machine's own byte ordering to big or little endian and to
+ also do the reverse.
+
+ It also provides a pair of functions to convert to/from network byte
+ order where network byte order is big endian byte order. This pair of
+ functions does the exact same thing as the host_to_big() and big_to_host()
+ functions and is provided simply so that client code can use the most
+ self documenting name appropriate.
+
+ Also note that this object is capable of correctly flipping the contents
+ of arrays when the arrays are declared on the stack. e.g. You can
+ say things like:
+ int array[10];
+ bo.host_to_network(array);
+ !*/
+
+ public:
+
+ byte_orderer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~byte_orderer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ bool host_is_big_endian (
+ ) const;
+ /*!
+ ensures
+ - if (the host computer is a big endian machine) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool host_is_little_endian (
+ ) const;
+ /*!
+ ensures
+ - if (the host computer is a little endian machine) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template <
+ typename T
+ >
+ void host_to_network (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from host byte order
+ to network byte order.
+ !*/
+
+ template <
+ typename T
+ >
+ void network_to_host (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from network byte order
+ to host byte order.
+ !*/
+
+ template <
+ typename T
+ >
+ void host_to_big (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from host byte order
+ to big endian byte order.
+ !*/
+
+ template <
+ typename T
+ >
+ void big_to_host (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from big endian byte order
+ to host byte order.
+ !*/
+
+ template <
+ typename T
+ >
+ void host_to_little (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from host byte order
+ to little endian byte order.
+ !*/
+
+ template <
+ typename T
+ >
+ void little_to_host (
+ T& item
+ ) const;
+ /*!
+ ensures
+ - #item == the value of item converted from little endian byte order
+ to host byte order.
+ !*/
+
+ };
+}
+
+#endif // DLIB_BYTE_ORDEREr_ABSTRACT_
+
diff --git a/ml/dlib/dlib/cassert b/ml/dlib/dlib/cassert
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/cassert
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/clustering.h b/ml/dlib/dlib/clustering.h
new file mode 100644
index 000000000..3cbd6cfd4
--- /dev/null
+++ b/ml/dlib/dlib/clustering.h
@@ -0,0 +1,13 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CLuSTERING_
+#define DLIB_CLuSTERING_
+
+#include "clustering/modularity_clustering.h"
+#include "clustering/chinese_whispers.h"
+#include "clustering/spectral_cluster.h"
+#include "clustering/bottom_up_cluster.h"
+#include "svm/kkmeans.h"
+
+#endif // DLIB_CLuSTERING_
+
diff --git a/ml/dlib/dlib/clustering/bottom_up_cluster.h b/ml/dlib/dlib/clustering/bottom_up_cluster.h
new file mode 100644
index 000000000..f80b65108
--- /dev/null
+++ b/ml/dlib/dlib/clustering/bottom_up_cluster.h
@@ -0,0 +1,253 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_
+#define DLIB_BOTTOM_uP_CLUSTER_Hh_
+
+#include <queue>
+#include <map>
+
+#include "bottom_up_cluster_abstract.h"
+#include "../algs.h"
+#include "../matrix.h"
+#include "../disjoint_subsets.h"
+#include "../graph_utils.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace buc_impl
+ {
+ inline void merge_sets (
+ matrix<double>& dists,
+ unsigned long dest,
+ unsigned long src
+ )
+ {
+ for (long r = 0; r < dists.nr(); ++r)
+ dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src));
+ }
+
+ struct compare_dist
+ {
+ bool operator() (
+ const sample_pair& a,
+ const sample_pair& b
+ ) const
+ {
+ return a.distance() > b.distance();
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ unsigned long bottom_up_cluster (
+ const matrix_exp<EXP>& dists_,
+ std::vector<unsigned long>& labels,
+ unsigned long min_num_clusters,
+ double max_dist = std::numeric_limits<double>::infinity()
+ )
+ {
+ matrix<double> dists = matrix_cast<double>(dists_);
+ // make sure requires clause is not broken
+ DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0,
+ "\t unsigned long bottom_up_cluster()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t dists.nr(): " << dists.nr()
+ << "\n\t dists.nc(): " << dists.nc()
+ << "\n\t min_num_clusters: " << min_num_clusters
+ );
+
+ using namespace buc_impl;
+
+ labels.resize(dists.nr());
+ disjoint_subsets sets;
+ sets.set_size(dists.nr());
+ if (labels.size() == 0)
+ return 0;
+
+ // push all the edges in the graph into a priority queue so the best edges to merge
+ // come first.
+ std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que;
+ for (long r = 0; r < dists.nr(); ++r)
+ for (long c = r+1; c < dists.nc(); ++c)
+ que.push(sample_pair(r,c,dists(r,c)));
+
+ // Now start merging nodes.
+ for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter)
+ {
+ // find the next best thing to merge.
+ double best_dist = que.top().distance();
+ unsigned long a = sets.find_set(que.top().index1());
+ unsigned long b = sets.find_set(que.top().index2());
+ que.pop();
+ // we have been merging and modifying the distances, so make sure this distance
+ // is still valid and these guys haven't been merged already.
+ while(a == b || best_dist < dists(a,b))
+ {
+ // Haven't merged it yet, so put it back in with updated distance for
+ // reconsideration later.
+ if (a != b)
+ que.push(sample_pair(a, b, dists(a, b)));
+
+ best_dist = que.top().distance();
+ a = sets.find_set(que.top().index1());
+ b = sets.find_set(que.top().index2());
+ que.pop();
+ }
+
+
+ // now merge these sets if the best distance is small enough
+ if (best_dist > max_dist)
+ break;
+ unsigned long news = sets.merge_sets(a,b);
+ unsigned long olds = (news==a)?b:a;
+ merge_sets(dists, news, olds);
+ }
+
+ // figure out which cluster each element is in. Also make sure the labels are
+ // contiguous.
+ std::map<unsigned long, unsigned long> relabel;
+ for (unsigned long r = 0; r < labels.size(); ++r)
+ {
+ unsigned long l = sets.find_set(r);
+ // relabel to make contiguous
+ if (relabel.count(l) == 0)
+ {
+ unsigned long next = relabel.size();
+ relabel[l] = next;
+ }
+ labels[r] = relabel[l];
+ }
+
+
+ return relabel.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct snl_range
+ {
+ snl_range() = default;
+ snl_range(double val) : lower(val), upper(val) {}
+ snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)}
+
+ double lower = 0;
+ double upper = 0;
+
+ double width() const { return upper-lower; }
+ bool operator<(const snl_range& item) const { return lower < item.lower; }
+ };
+
+ inline snl_range merge(const snl_range& a, const snl_range& b)
+ {
+ return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper));
+ }
+
+ inline double distance (const snl_range& a, const snl_range& b)
+ {
+ return std::max(a.lower,b.lower) - std::min(a.upper,b.upper);
+ }
+
+ inline std::ostream& operator<< (std::ostream& out, const snl_range& item )
+ {
+ out << "["<<item.lower<<","<<item.upper<<"]";
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<snl_range> segment_number_line (
+ const std::vector<double>& x,
+ const double max_range_width
+ )
+ {
+ DLIB_CASSERT(max_range_width >= 0);
+
+ // create initial ranges, one for each value in x. So initially, all the ranges have
+ // width of 0.
+ std::vector<snl_range> ranges;
+ for (auto v : x)
+ ranges.push_back(v);
+ std::sort(ranges.begin(), ranges.end());
+
+ std::vector<snl_range> greedy_final_ranges;
+ if (ranges.size() == 0)
+ return greedy_final_ranges;
+ // We will try two different clustering strategies. One that does a simple greedy left
+ // to right sweep and another that does a bottom up agglomerative clustering. This
+ // first loop runs the greedy left to right sweep. Then at the end of this routine we
+ // will return the results that produced the tightest clustering.
+ greedy_final_ranges.push_back(ranges[0]);
+ for (size_t i = 1; i < ranges.size(); ++i)
+ {
+ auto m = merge(greedy_final_ranges.back(), ranges[i]);
+ if (m.width() <= max_range_width)
+ greedy_final_ranges.back() = m;
+ else
+ greedy_final_ranges.push_back(ranges[i]);
+ }
+
+
+ // Here we do the bottom up clustering. So compute the edges connecting our ranges.
+ // We will simply say there are edges between ranges if and only if they are
+ // immediately adjacent on the number line.
+ std::vector<sample_pair> edges;
+ for (size_t i = 1; i < ranges.size(); ++i)
+ edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i])));
+ std::sort(edges.begin(), edges.end(), order_by_distance<sample_pair>);
+
+ disjoint_subsets sets;
+ sets.set_size(ranges.size());
+
+ // Now start merging nodes.
+ for (auto edge : edges)
+ {
+ // find the next best thing to merge.
+ unsigned long a = sets.find_set(edge.index1());
+ unsigned long b = sets.find_set(edge.index2());
+
+ // merge it if it doesn't result in an interval that's too big.
+ auto m = merge(ranges[a], ranges[b]);
+ if (m.width() <= max_range_width)
+ {
+ unsigned long news = sets.merge_sets(a,b);
+ ranges[news] = m;
+ }
+ }
+
+ // Now create a list of the final ranges. We will do this by keeping track of which
+ // range we already added to final_ranges.
+ std::vector<snl_range> final_ranges;
+ std::vector<bool> already_output(ranges.size(), false);
+ for (unsigned long i = 0; i < sets.size(); ++i)
+ {
+ auto s = sets.find_set(i);
+ if (!already_output[s])
+ {
+ final_ranges.push_back(ranges[s]);
+ already_output[s] = true;
+ }
+ }
+
+ // only use the greedy clusters if they found a clustering with fewer clusters.
+ // Otherwise, the bottom up clustering probably produced a more sensible clustering.
+ if (final_ranges.size() <= greedy_final_ranges.size())
+ return final_ranges;
+ else
+ return greedy_final_ranges;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_
+
diff --git a/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h b/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h
new file mode 100644
index 000000000..72d362c12
--- /dev/null
+++ b/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h
@@ -0,0 +1,136 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
+#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ unsigned long bottom_up_cluster (
+ const matrix_exp<EXP>& dists,
+ std::vector<unsigned long>& labels,
+ unsigned long min_num_clusters,
+ double max_dist = std::numeric_limits<double>::infinity()
+ );
+ /*!
+ requires
+ - dists.nr() == dists.nc()
+ - min_num_clusters > 0
+ - dists == trans(dists)
+ (l.e. dists should be symmetric)
+ ensures
+ - Runs a bottom up agglomerative clustering algorithm.
+ - Interprets dists as a matrix that gives the distances between dists.nr()
+ items. In particular, we take dists(i,j) to be the distance between the ith
+ and jth element of some set. This function clusters the elements of this set
+ into at least min_num_clusters (or dists.nr() if there aren't enough
+ elements). Additionally, within each cluster, the maximum pairwise distance
+ between any two cluster elements is <= max_dist.
+ - returns the number of clusters found.
+ - #labels.size() == dists.nr()
+ - for all valid i:
+ - #labels[i] == the cluster ID of the node with index i (i.e. the node
+ corresponding to the distances dists(i,*)).
+ - 0 <= #labels[i] < the number of clusters found
+ (i.e. cluster IDs are assigned contiguously and start at 0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct snl_range
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an interval on the real number line. It is used
+ to store the outputs of the segment_number_line() routine defined below.
+ !*/
+
+ snl_range(
+ );
+ /*!
+ ensures
+ - #lower == 0
+ - #upper == 0
+ !*/
+
+ snl_range(
+ double val
+ );
+ /*!
+ ensures
+ - #lower == val
+ - #upper == val
+ !*/
+
+ snl_range(
+ double l,
+ double u
+ );
+ /*!
+ requires
+ - l <= u
+ ensures
+ - #lower == l
+ - #upper == u
+ !*/
+
+ double lower;
+ double upper;
+
+ double width(
+ ) const { return upper-lower; }
+ /*!
+ ensures
+ - returns the width of this interval on the number line.
+ !*/
+
+ bool operator<(const snl_range& item) const { return lower < item.lower; }
+ /*!
+ ensures
+ - provides a total ordering of snl_range objects assuming they are
+ non-overlapping.
+ !*/
+ };
+
+ std::ostream& operator<< (std::ostream& out, const snl_range& item );
+ /*!
+ ensures
+ - prints item to out in the form [lower,upper].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<snl_range> segment_number_line (
+ const std::vector<double>& x,
+ const double max_range_width
+ );
+ /*!
+ requires
+ - max_range_width >= 0
+ ensures
+ - Finds a clustering of the values in x and returns the ranges that define the
+ clustering. This routine uses a combination of bottom up clustering and a
+ simple greedy scan to try and find the most compact set of ranges that
+ contain all the values in x.
+ - This routine has approximately linear runtime.
+ - Every value in x will be contained inside one of the returned snl_range
+ objects;
+ - All returned snl_range object's will have a width() <= max_range_width and
+ will also be non-overlapping.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/clustering/chinese_whispers.h b/ml/dlib/dlib/clustering/chinese_whispers.h
new file mode 100644
index 000000000..332cce1a0
--- /dev/null
+++ b/ml/dlib/dlib/clustering/chinese_whispers.h
@@ -0,0 +1,135 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CHINESE_WHISPErS_Hh_
+#define DLIB_CHINESE_WHISPErS_Hh_
+
+#include "chinese_whispers_abstract.h"
+#include <vector>
+#include "../rand.h"
+#include "../graph_utils/edge_list_graphs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long chinese_whispers (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations,
+ dlib::rand& rnd
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ordered_by_index(edges),
+ "\t unsigned long chinese_whispers()"
+ << "\n\t Invalid inputs were given to this function"
+ );
+
+ labels.clear();
+ if (edges.size() == 0)
+ return 0;
+
+ std::vector<std::pair<unsigned long, unsigned long> > neighbors;
+ find_neighbor_ranges(edges, neighbors);
+
+ // Initialize the labels, each node gets a different label.
+ labels.resize(neighbors.size());
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ labels[i] = i;
+
+
+ for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter)
+ {
+ // Pick a random node.
+ const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size();
+
+ // Count how many times each label happens amongst our neighbors.
+ std::map<unsigned long, double> labels_to_counts;
+ const unsigned long end = neighbors[idx].second;
+ for (unsigned long i = neighbors[idx].first; i != end; ++i)
+ {
+ labels_to_counts[labels[edges[i].index2()]] += edges[i].distance();
+ }
+
+ // find the most common label
+ std::map<unsigned long, double>::iterator i;
+ double best_score = -std::numeric_limits<double>::infinity();
+ unsigned long best_label = labels[idx];
+ for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i)
+ {
+ if (i->second > best_score)
+ {
+ best_score = i->second;
+ best_label = i->first;
+ }
+ }
+
+ labels[idx] = best_label;
+ }
+
+
+ // Remap the labels into a contiguous range. First we find the
+ // mapping.
+ std::map<unsigned long,unsigned long> label_remap;
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ const unsigned long next_id = label_remap.size();
+ if (label_remap.count(labels[i]) == 0)
+ label_remap[labels[i]] = next_id;
+ }
+ // now apply the mapping to all the labels.
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ labels[i] = label_remap[labels[i]];
+ }
+
+ return label_remap.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long chinese_whispers (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations,
+ dlib::rand& rnd
+ )
+ {
+ std::vector<ordered_sample_pair> oedges;
+ convert_unordered_to_ordered(edges, oedges);
+ std::sort(oedges.begin(), oedges.end(), &order_by_index<ordered_sample_pair>);
+
+ return chinese_whispers(oedges, labels, num_iterations, rnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long chinese_whispers (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations = 100
+ )
+ {
+ dlib::rand rnd;
+ return chinese_whispers(edges, labels, num_iterations, rnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long chinese_whispers (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations = 100
+ )
+ {
+ dlib::rand rnd;
+ return chinese_whispers(edges, labels, num_iterations, rnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CHINESE_WHISPErS_Hh_
+
diff --git a/ml/dlib/dlib/clustering/chinese_whispers_abstract.h b/ml/dlib/dlib/clustering/chinese_whispers_abstract.h
new file mode 100644
index 000000000..7a184c6f9
--- /dev/null
+++ b/ml/dlib/dlib/clustering/chinese_whispers_abstract.h
@@ -0,0 +1,97 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
+#ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
+
+#include <vector>
+#include "../rand.h"
+#include "../graph_utils/ordered_sample_pair_abstract.h"
+#include "../graph_utils/sample_pair_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long chinese_whispers (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations,
+ dlib::rand& rnd
+ );
+ /*!
+ requires
+ - is_ordered_by_index(edges) == true
+ ensures
+ - This function implements the graph clustering algorithm described in the
+ paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its
+ Application to Natural Language Processing Problems by Chris Biemann.
+ - Interprets edges as a directed graph. That is, it contains the edges on the
+ said graph and the ordered_sample_pair::distance() values define the edge
+ weights (larger values indicating a stronger edge connection between the
+ nodes). If an edge has a distance() value of infinity then it is considered
+ a "must link" edge.
+ - returns the number of clusters found.
+ - #labels.size() == max_index_plus_one(edges)
+ - for all valid i:
+ - #labels[i] == the cluster ID of the node with index i in the graph.
+ - 0 <= #labels[i] < the number of clusters found
+ (i.e. cluster IDs are assigned contiguously and start at 0)
+ - Duplicate edges are interpreted as if there had been just one edge with a
+ distance value equal to the sum of all the duplicate edge's distance values.
+ - The algorithm performs exactly num_iterations passes over the graph before
+ terminating.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long chinese_whispers (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations,
+ dlib::rand& rnd
+ );
+ /*!
+ ensures
+ - This function is identical to the above chinese_whispers() routine except
+ that it operates on a vector of sample_pair objects instead of
+ ordered_sample_pairs. Therefore, this is simply a convenience routine. In
+ particular, it is implemented by transforming the given edges into
+ ordered_sample_pairs and then calling the chinese_whispers() routine defined
+ above.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long chinese_whispers (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations = 100
+ );
+ /*!
+ requires
+ - is_ordered_by_index(edges) == true
+ ensures
+ - performs: return chinese_whispers(edges, labels, num_iterations, rnd)
+ where rnd is a default initialized dlib::rand object.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long chinese_whispers (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const unsigned long num_iterations = 100
+ );
+ /*!
+ ensures
+ - performs: return chinese_whispers(edges, labels, num_iterations, rnd)
+ where rnd is a default initialized dlib::rand object.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/clustering/modularity_clustering.h b/ml/dlib/dlib/clustering/modularity_clustering.h
new file mode 100644
index 000000000..8b8a0b0a5
--- /dev/null
+++ b/ml/dlib/dlib/clustering/modularity_clustering.h
@@ -0,0 +1,515 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MODULARITY_ClUSTERING__H__
+#define DLIB_MODULARITY_ClUSTERING__H__
+
+#include "modularity_clustering_abstract.h"
+#include "../sparse_vector.h"
+#include "../graph_utils/edge_list_graphs.h"
+#include "../matrix.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// -----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline double newman_cluster_split (
+ dlib::rand& rnd,
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix<double,0,1>& node_degrees, // k from the Newman paper
+ const matrix<double,0,1>& Bdiag, // diag(B) from the Newman paper
+ const double& edge_sum, // m from the Newman paper
+ matrix<double,0,1>& labels,
+ const double eps,
+ const unsigned long max_iterations
+ )
+ /*!
+ requires
+ - node_degrees.size() == max_index_plus_one(edges)
+ - Bdiag.size() == max_index_plus_one(edges)
+ - edges must be sorted according to order_by_index()
+ ensures
+ - This routine splits a graph into two subgraphs using the Newman
+ clustering method.
+ - returns the modularity obtained when the graph is split according
+ to the contents of #labels.
+ - #labels.size() == node_degrees.size()
+ - for all valid i: #labels(i) == -1 or +1
+ - if (this function returns 0) then
+ - all the labels are equal, i.e. the graph is not split.
+ !*/
+ {
+ // Scale epsilon so that it is relative to the expected value of an element of a
+ // unit vector of length node_degrees.size().
+ const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size());
+
+ // Make a random unit vector and put in labels.
+ labels.set_size(node_degrees.size());
+ for (long i = 0; i < labels.size(); ++i)
+ labels(i) = rnd.get_random_gaussian();
+ labels /= length(labels);
+
+ matrix<double,0,1> Bv, Bv_unit;
+
+ // Do the power iteration for a while.
+ double eig = -1;
+ double offset = 0;
+ while (eig < 0)
+ {
+
+ // any number larger than power_iter_eps
+ double iteration_change = power_iter_eps*2+1;
+ for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i)
+ {
+ sparse_matrix_vector_multiply(edges, labels, Bv);
+ Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees;
+
+ if (offset != 0)
+ {
+ Bv -= offset*labels;
+ }
+
+
+ const double len = length(Bv);
+ if (len != 0)
+ {
+ Bv_unit = Bv/len;
+ iteration_change = max(abs(labels-Bv_unit));
+ labels.swap(Bv_unit);
+ }
+ else
+ {
+ // Had a bad time, pick another random vector and try it with the
+ // power iteration.
+ for (long i = 0; i < labels.size(); ++i)
+ labels(i) = rnd.get_random_gaussian();
+ }
+ }
+
+ eig = dot(Bv,labels);
+ // we will repeat this loop if the largest eigenvalue is negative
+ offset = eig;
+ }
+
+
+ for (long i = 0; i < labels.size(); ++i)
+ {
+ if (labels(i) > 0)
+ labels(i) = 1;
+ else
+ labels(i) = -1;
+ }
+
+
+ // compute B*labels, store result in Bv.
+ sparse_matrix_vector_multiply(edges, labels, Bv);
+ Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees;
+
+ // Do some label refinement. In this step we swap labels if it
+ // improves the modularity score.
+ bool flipped_label = true;
+ while(flipped_label)
+ {
+ flipped_label = false;
+ unsigned long idx = 0;
+ for (long i = 0; i < labels.size(); ++i)
+ {
+ const double val = -2*labels(i);
+ const double increase = 4*Bdiag(i) + 2*val*Bv(i);
+
+ // if there is an increase in modularity for swapping this label
+ if (increase > 0)
+ {
+ labels(i) *= -1;
+ while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
+ {
+ const long j = edges[idx].index2();
+ Bv(j) += val*edges[idx].distance();
+ ++idx;
+ }
+
+ Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees;
+
+ flipped_label = true;
+ }
+ else
+ {
+ while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
+ {
+ ++idx;
+ }
+ }
+ }
+ }
+
+
+ const double modularity = dot(Bv, labels)/(4*edge_sum);
+
+ return modularity;
+ }
+
+ // -------------------------------------------------------------------------------------
+
+ inline unsigned long newman_cluster_helper (
+ dlib::rand& rnd,
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix<double,0,1>& node_degrees, // k from the Newman paper
+ const matrix<double,0,1>& Bdiag, // diag(B) from the Newman paper
+ const double& edge_sum, // m from the Newman paper
+ std::vector<unsigned long>& labels,
+ double modularity_threshold,
+ const double eps,
+ const unsigned long max_iterations
+ )
+ /*!
+ ensures
+ - returns the number of clusters the data was split into
+ !*/
+ {
+ matrix<double,0,1> l;
+ const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations);
+
+
+ // We need to collapse the node index values down to contiguous values. So
+ // we use the following two vectors to contain the mappings from input index
+ // values to their corresponding index values in each split.
+ std::vector<unsigned long> left_idx_map(node_degrees.size());
+ std::vector<unsigned long> right_idx_map(node_degrees.size());
+
+ // figure out how many nodes went into each side of the split.
+ unsigned long num_left_split = 0;
+ unsigned long num_right_split = 0;
+ for (long i = 0; i < l.size(); ++i)
+ {
+ if (l(i) > 0)
+ {
+ left_idx_map[i] = num_left_split;
+ ++num_left_split;
+ }
+ else
+ {
+ right_idx_map[i] = num_right_split;
+ ++num_right_split;
+ }
+ }
+
+ // do a recursive split if it will improve the modularity.
+ if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0)
+ {
+
+ // split the node_degrees and Bdiag matrices into left and right split parts
+ matrix<double,0,1> left_node_degrees(num_left_split);
+ matrix<double,0,1> right_node_degrees(num_right_split);
+ matrix<double,0,1> left_Bdiag(num_left_split);
+ matrix<double,0,1> right_Bdiag(num_right_split);
+ for (long i = 0; i < l.size(); ++i)
+ {
+ if (l(i) > 0)
+ {
+ left_node_degrees(left_idx_map[i]) = node_degrees(i);
+ left_Bdiag(left_idx_map[i]) = Bdiag(i);
+ }
+ else
+ {
+ right_node_degrees(right_idx_map[i]) = node_degrees(i);
+ right_Bdiag(right_idx_map[i]) = Bdiag(i);
+ }
+ }
+
+
+ // put the edges from one side of the split into split_edges
+ std::vector<ordered_sample_pair> split_edges;
+ modularity_threshold = 0;
+ for (unsigned long k = 0; k < edges.size(); ++k)
+ {
+ const unsigned long i = edges[k].index1();
+ const unsigned long j = edges[k].index2();
+ const double d = edges[k].distance();
+ if (l(i) > 0 && l(j) > 0)
+ {
+ split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d));
+ modularity_threshold += d;
+ }
+ }
+ modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum);
+ modularity_threshold /= 4*edge_sum;
+
+ unsigned long num_left_clusters;
+ std::vector<unsigned long> left_labels;
+ num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag,
+ edge_sum,left_labels,modularity_threshold,
+ eps, max_iterations);
+
+ // now load the other side into split_edges and cluster it as well
+ split_edges.clear();
+ modularity_threshold = 0;
+ for (unsigned long k = 0; k < edges.size(); ++k)
+ {
+ const unsigned long i = edges[k].index1();
+ const unsigned long j = edges[k].index2();
+ const double d = edges[k].distance();
+ if (l(i) < 0 && l(j) < 0)
+ {
+ split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d));
+ modularity_threshold += d;
+ }
+ }
+ modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum);
+ modularity_threshold /= 4*edge_sum;
+
+ unsigned long num_right_clusters;
+ std::vector<unsigned long> right_labels;
+ num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag,
+ edge_sum,right_labels,modularity_threshold,
+ eps, max_iterations);
+
+ // Now merge the labels from the two splits.
+ labels.resize(node_degrees.size());
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ // if this node was in the left split
+ if (l(i) > 0)
+ {
+ labels[i] = left_labels[left_idx_map[i]];
+ }
+ else // if this node was in the right split
+ {
+ labels[i] = right_labels[right_idx_map[i]] + num_left_clusters;
+ }
+ }
+
+
+ return num_left_clusters + num_right_clusters;
+ }
+ else
+ {
+ labels.assign(node_degrees.size(),0);
+ return 1;
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long newman_cluster (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const double eps = 1e-4,
+ const unsigned long max_iterations = 2000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ordered_by_index(edges),
+ "\t unsigned long newman_cluster()"
+ << "\n\t Invalid inputs were given to this function"
+ );
+
+ labels.clear();
+ if (edges.size() == 0)
+ return 0;
+
+ const unsigned long num_nodes = max_index_plus_one(edges);
+
+ // compute the node_degrees vector, edge_sum value, and diag(B).
+ matrix<double,0,1> node_degrees(num_nodes);
+ matrix<double,0,1> Bdiag(num_nodes);
+ Bdiag = 0;
+ double edge_sum = 0;
+ node_degrees = 0;
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ node_degrees(edges[i].index1()) += edges[i].distance();
+ edge_sum += edges[i].distance();
+ if (edges[i].index1() == edges[i].index2())
+ Bdiag(edges[i].index1()) += edges[i].distance();
+ }
+ edge_sum /= 2;
+ Bdiag -= squared(node_degrees)/(2*edge_sum);
+
+
+ dlib::rand rnd;
+ return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long newman_cluster (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const double eps = 1e-4,
+ const unsigned long max_iterations = 2000
+ )
+ {
+ std::vector<ordered_sample_pair> oedges;
+ convert_unordered_to_ordered(edges, oedges);
+ std::sort(oedges.begin(), oedges.end(), &order_by_index<ordered_sample_pair>);
+
+ return newman_cluster(oedges, labels, eps, max_iterations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline std::vector<unsigned long> remap_labels (
+ const std::vector<unsigned long>& labels,
+ unsigned long& num_labels
+ )
+ /*!
+ ensures
+ - This function takes labels and produces a mapping which maps elements of
+ labels into the most compact range in [0, max] as possible. In particular,
+ there won't be any unused integers in the mapped range.
+ - #num_labels == the number of distinct values in labels.
+ - returns a vector V such that:
+ - V.size() == labels.size()
+ - max(mat(V))+1 == num_labels.
+ - for all valid i,j:
+ - if (labels[i] == labels[j]) then
+ - V[i] == V[j]
+ - else
+ - V[i] != V[j]
+ !*/
+ {
+ std::map<unsigned long, unsigned long> temp;
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ if (temp.count(labels[i]) == 0)
+ {
+ const unsigned long next = temp.size();
+ temp[labels[i]] = next;
+ }
+ }
+
+ num_labels = temp.size();
+
+ std::vector<unsigned long> result(labels.size());
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ result[i] = temp[labels[i]];
+ }
+ return result;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double modularity (
+ const std::vector<sample_pair>& edges,
+ const std::vector<unsigned long>& labels
+ )
+ {
+ const unsigned long num_nodes = max_index_plus_one(edges);
+ // make sure requires clause is not broken
+ DLIB_ASSERT(labels.size() == num_nodes,
+ "\t double modularity()"
+ << "\n\t Invalid inputs were given to this function"
+ );
+
+ unsigned long num_labels;
+ const std::vector<unsigned long>& labels_ = dlib::impl::remap_labels(labels,num_labels);
+
+ std::vector<double> cluster_sums(num_labels,0);
+ std::vector<double> k(num_nodes,0);
+
+ double Q = 0;
+ double m = 0;
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ const unsigned long n1 = edges[i].index1();
+ const unsigned long n2 = edges[i].index2();
+ k[n1] += edges[i].distance();
+ if (n1 != n2)
+ k[n2] += edges[i].distance();
+
+ if (n1 != n2)
+ m += edges[i].distance();
+ else
+ m += edges[i].distance()/2;
+
+ if (labels_[n1] == labels_[n2])
+ {
+ if (n1 != n2)
+ Q += 2*edges[i].distance();
+ else
+ Q += edges[i].distance();
+ }
+ }
+
+ if (m == 0)
+ return 0;
+
+ for (unsigned long i = 0; i < labels_.size(); ++i)
+ {
+ cluster_sums[labels_[i]] += k[i];
+ }
+
+ for (unsigned long i = 0; i < labels_.size(); ++i)
+ {
+ Q -= k[i]*cluster_sums[labels_[i]]/(2*m);
+ }
+
+ return 1.0/(2*m)*Q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double modularity (
+ const std::vector<ordered_sample_pair>& edges,
+ const std::vector<unsigned long>& labels
+ )
+ {
+ const unsigned long num_nodes = max_index_plus_one(edges);
+ // make sure requires clause is not broken
+ DLIB_ASSERT(labels.size() == num_nodes,
+ "\t double modularity()"
+ << "\n\t Invalid inputs were given to this function"
+ );
+
+
+ unsigned long num_labels;
+ const std::vector<unsigned long>& labels_ = dlib::impl::remap_labels(labels,num_labels);
+
+ std::vector<double> cluster_sums(num_labels,0);
+ std::vector<double> k(num_nodes,0);
+
+ double Q = 0;
+ double m = 0;
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ const unsigned long n1 = edges[i].index1();
+ const unsigned long n2 = edges[i].index2();
+ k[n1] += edges[i].distance();
+ m += edges[i].distance();
+ if (labels_[n1] == labels_[n2])
+ {
+ Q += edges[i].distance();
+ }
+ }
+
+ if (m == 0)
+ return 0;
+
+ for (unsigned long i = 0; i < labels_.size(); ++i)
+ {
+ cluster_sums[labels_[i]] += k[i];
+ }
+
+ for (unsigned long i = 0; i < labels_.size(); ++i)
+ {
+ Q -= k[i]*cluster_sums[labels_[i]]/m;
+ }
+
+ return 1.0/m*Q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MODULARITY_ClUSTERING__H__
+
diff --git a/ml/dlib/dlib/clustering/modularity_clustering_abstract.h b/ml/dlib/dlib/clustering/modularity_clustering_abstract.h
new file mode 100644
index 000000000..c1e7c20c4
--- /dev/null
+++ b/ml/dlib/dlib/clustering/modularity_clustering_abstract.h
@@ -0,0 +1,125 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
+#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
+
+#include <vector>
+#include "../graph_utils/ordered_sample_pair_abstract.h"
+#include "../graph_utils/sample_pair_abstract.h"
+
+namespace dlib
+{
+
+// -----------------------------------------------------------------------------------------
+
+ double modularity (
+ const std::vector<sample_pair>& edges,
+ const std::vector<unsigned long>& labels
+ );
+ /*!
+ requires
+ - labels.size() == max_index_plus_one(edges)
+ - for all valid i:
+ - 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
+ ensures
+ - Interprets edges as an undirected graph. That is, it contains the edges on
+ the said graph and the sample_pair::distance() values define the edge weights
+ (larger values indicating a stronger edge connection between the nodes).
+ - This function returns the modularity value obtained when the given input
+ graph is broken into subgraphs according to the contents of labels. In
+ particular, we say that two nodes with indices i and j are in the same
+ subgraph or community if and only if labels[i] == labels[j].
+ - Duplicate edges are interpreted as if there had been just one edge with a
+ distance value equal to the sum of all the duplicate edge's distance values.
+ - See the paper Modularity and community structure in networks by M. E. J. Newman
+ for a detailed definition.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double modularity (
+ const std::vector<ordered_sample_pair>& edges,
+ const std::vector<unsigned long>& labels
+ );
+ /*!
+ requires
+ - labels.size() == max_index_plus_one(edges)
+ - for all valid i:
+ - 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
+ ensures
+ - Interprets edges as a directed graph. That is, it contains the edges on the
+ said graph and the ordered_sample_pair::distance() values define the edge
+ weights (larger values indicating a stronger edge connection between the
+ nodes). Note that, generally, modularity is only really defined for
+ undirected graphs. Therefore, the "directed graph" given to this function
+ should have symmetric edges between all nodes. The reason this function is
+ provided at all is because sometimes a vector of ordered_sample_pair objects
+ is a useful representation of an undirected graph.
+ - This function returns the modularity value obtained when the given input
+ graph is broken into subgraphs according to the contents of labels. In
+ particular, we say that two nodes with indices i and j are in the same
+ subgraph or community if and only if labels[i] == labels[j].
+ - Duplicate edges are interpreted as if there had been just one edge with a
+ distance value equal to the sum of all the duplicate edge's distance values.
+ - See the paper Modularity and community structure in networks by M. E. J. Newman
+ for a detailed definition.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long newman_cluster (
+ const std::vector<ordered_sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const double eps = 1e-4,
+ const unsigned long max_iterations = 2000
+ );
+ /*!
+ requires
+ - is_ordered_by_index(edges) == true
+ - for all valid i:
+ - 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
+ ensures
+ - This function performs the clustering algorithm described in the paper
+ Modularity and community structure in networks by M. E. J. Newman.
+ - This function interprets edges as a graph and attempts to find the labeling
+ that maximizes modularity(edges, #labels).
+ - returns the number of clusters found.
+ - #labels.size() == max_index_plus_one(edges)
+ - for all valid i:
+ - #labels[i] == the cluster ID of the node with index i in the graph.
+ - 0 <= #labels[i] < the number of clusters found
+ (i.e. cluster IDs are assigned contiguously and start at 0)
+ - The main computation of the algorithm is involved in finding an eigenvector
+ of a certain matrix. To do this, we use the power iteration. In particular,
+ each time we try to find an eigenvector we will let the power iteration loop
+ at most max_iterations times or until it reaches an accuracy of eps.
+ Whichever comes first.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long newman_cluster (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const double eps = 1e-4,
+ const unsigned long max_iterations = 2000
+ );
+ /*!
+ requires
+ - for all valid i:
+ - 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
+ ensures
+ - This function is identical to the above newman_cluster() routine except that
+ it operates on a vector of sample_pair objects instead of
+ ordered_sample_pairs. Therefore, this is simply a convenience routine. In
+ particular, it is implemented by transforming the given edges into
+ ordered_sample_pairs and then calling the newman_cluster() routine defined
+ above.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/clustering/spectral_cluster.h b/ml/dlib/dlib/clustering/spectral_cluster.h
new file mode 100644
index 000000000..2cac9870f
--- /dev/null
+++ b/ml/dlib/dlib/clustering/spectral_cluster.h
@@ -0,0 +1,80 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SPECTRAL_CLUSTEr_H_
+#define DLIB_SPECTRAL_CLUSTEr_H_
+
+#include "spectral_cluster_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "../svm/kkmeans.h"
+
+namespace dlib
+{
+ template <
+ typename kernel_type,
+ typename vector_type
+ >
+ std::vector<unsigned long> spectral_cluster (
+ const kernel_type& k,
+ const vector_type& samples,
+ const unsigned long num_clusters
+ )
+ {
+ DLIB_CASSERT(num_clusters > 0,
+ "\t std::vector<unsigned long> spectral_cluster(k,samples,num_clusters)"
+ << "\n\t num_clusters can't be 0."
+ );
+
+ if (num_clusters == 1)
+ {
+ // nothing to do, just assign everything to the 0 cluster.
+ return std::vector<unsigned long>(samples.size(), 0);
+ }
+
+ // compute the similarity matrix.
+ matrix<double> K(samples.size(), samples.size());
+ for (long r = 0; r < K.nr(); ++r)
+ for (long c = r+1; c < K.nc(); ++c)
+ K(r,c) = K(c,r) = (double)k(samples[r], samples[c]);
+ for (long r = 0; r < K.nr(); ++r)
+ K(r,r) = 0;
+
+ matrix<double,0,1> D(K.nr());
+ for (long r = 0; r < K.nr(); ++r)
+ D(r) = sum(rowm(K,r));
+ D = sqrt(reciprocal(D));
+ K = diagm(D)*K*diagm(D);
+ matrix<double> u,w,v;
+ // Use the normal SVD routine unless the matrix is really big, then use the fast
+ // approximate version.
+ if (K.nr() < 1000)
+ svd3(K,u,w,v);
+ else
+ svd_fast(K,u,w,v, num_clusters+100, 5);
+ // Pick out the eigenvectors associated with the largest eigenvalues.
+ rsort_columns(v,w);
+ v = colm(v, range(0,num_clusters-1));
+ // Now build the normalized spectral vectors, one for each input vector.
+ std::vector<matrix<double,0,1> > spec_samps, centers;
+ for (long r = 0; r < v.nr(); ++r)
+ {
+ spec_samps.push_back(trans(rowm(v,r)));
+ const double len = length(spec_samps.back());
+ if (len != 0)
+ spec_samps.back() /= len;
+ }
+ // Finally do the K-means clustering
+ pick_initial_centers(num_clusters, centers, spec_samps);
+ find_clusters_using_kmeans(spec_samps, centers);
+ // And then compute the cluster assignments based on the output of K-means.
+ std::vector<unsigned long> assignments;
+ for (unsigned long i = 0; i < spec_samps.size(); ++i)
+ assignments.push_back(nearest_center(centers, spec_samps[i]));
+
+ return assignments;
+ }
+
+}
+
+#endif // DLIB_SPECTRAL_CLUSTEr_H_
+
diff --git a/ml/dlib/dlib/clustering/spectral_cluster_abstract.h b/ml/dlib/dlib/clustering/spectral_cluster_abstract.h
new file mode 100644
index 000000000..880ad80af
--- /dev/null
+++ b/ml/dlib/dlib/clustering/spectral_cluster_abstract.h
@@ -0,0 +1,43 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
+#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
+
+#include <vector>
+
+namespace dlib
+{
+ template <
+ typename kernel_type,
+ typename vector_type
+ >
+ std::vector<unsigned long> spectral_cluster (
+ const kernel_type& k,
+ const vector_type& samples,
+ const unsigned long num_clusters
+ );
+ /*!
+ requires
+ - samples must be something with an interface compatible with std::vector.
+ - The following expression must evaluate to a double or float:
+ k(samples[i], samples[j])
+ - num_clusters > 0
+ ensures
+ - Performs the spectral clustering algorithm described in the paper:
+ On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss.
+ and returns the results.
+ - This function clusters the input data samples into num_clusters clusters and
+ returns a vector that indicates which cluster each sample falls into. In
+ particular, we return an array A such that:
+ - A.size() == samples.size()
+ - A[i] == the cluster assignment of samples[i].
+ - for all valid i: 0 <= A[i] < num_clusters
+ - The "similarity" of samples[i] with samples[j] is given by
+ k(samples[i],samples[j]). This means that k() should output a number >= 0
+ and the number should be larger for samples that are more similar.
+ !*/
+}
+
+#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/cmake b/ml/dlib/dlib/cmake
new file mode 100644
index 000000000..d3695b30e
--- /dev/null
+++ b/ml/dlib/dlib/cmake
@@ -0,0 +1,5 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+
+add_subdirectory(${CMAKE_CURRENT_LIST_DIR} dlib_build)
+
diff --git a/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake b/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake
new file mode 100644
index 000000000..5f3d83ce4
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake
@@ -0,0 +1,35 @@
+
+
+cmake_minimum_required(VERSION 2.8.12)
+
+message(WARNING "add_global_compiler_switch() is deprecated. Use target_compile_options() instead")
+
+# Make macros that can add compiler switches to the entire project. Not just
+# to the current cmake folder being built.
+macro ( add_global_compiler_switch switch_name )
+ # If removing the switch would change the flags then it's already present
+ # and we don't need to do anything.
+ string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}")
+ if ("${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" )
+ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${switch_name}"
+ CACHE STRING "Flags used by the compiler during all C++ builds."
+ FORCE)
+ endif ()
+endmacro()
+
+macro ( remove_global_compiler_switch switch_name )
+ string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}")
+ if (NOT "${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" )
+ set (CMAKE_CXX_FLAGS "${tempstr}"
+ CACHE STRING "Flags used by the compiler during all C++ builds."
+ FORCE)
+ endif ()
+endmacro()
+
+macro (add_global_define def_name)
+ add_global_compiler_switch(-D${def_name})
+endmacro()
+
+macro (remove_global_define def_name)
+ remove_global_compiler_switch(-D${def_name})
+endmacro()
diff --git a/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake b/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake
new file mode 100644
index 000000000..0510707df
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake
@@ -0,0 +1,20 @@
+# This script checks if __ARM_NEON__ is defined for your compiler
+
+cmake_minimum_required(VERSION 2.8.12)
+
+# Don't rerun this script if its already been executed.
+if (DEFINED ARM_NEON_IS_AVAILABLE)
+ return()
+endif()
+
+# Set to false unless we find out otherwise in the code below.
+set(ARM_NEON_IS_AVAILABLE 0)
+
+# test if __ARM_NEON__ is defined
+try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon
+ neon_test)
+
+if(test_for_neon_worked)
+ message (STATUS "__ARM_NEON__ defined.")
+ set(ARM_NEON_IS_AVAILABLE 1)
+endif()
diff --git a/ml/dlib/dlib/cmake_utils/dlib.pc.in b/ml/dlib/dlib/cmake_utils/dlib.pc.in
new file mode 100644
index 000000000..188a673c3
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/dlib.pc.in
@@ -0,0 +1,9 @@
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: @PROJECT_NAME@
+Description: Numerical and networking C++ library
+Version: @VERSION@
+Libs: -L${libdir} -ldlib
+Cflags: -I${includedir}
+Requires:@REQUIRES_LIBS@
diff --git a/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in b/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in
new file mode 100644
index 000000000..df427e40e
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in
@@ -0,0 +1,50 @@
+# ===================================================================================
+# The dlib CMake configuration file
+#
+# ** File generated automatically, do not modify **
+#
+# Usage from an external project:
+# In your CMakeLists.txt, add these lines:
+#
+# FIND_PACKAGE(dlib REQUIRED)
+# TARGET_LINK_LIBRARIES(MY_TARGET_NAME ${dlib_LIBRARIES})
+#
+# This file will define the following variables:
+# - dlib_LIBRARIES : The list of all imported targets for dlib modules.
+# - dlib_INCLUDE_DIRS : The dlib include directories.
+# - dlib_VERSION : The version of this dlib build.
+# - dlib_VERSION_MAJOR : Major version part of this dlib revision.
+# - dlib_VERSION_MINOR : Minor version part of this dlib revision.
+#
+# ===================================================================================
+
+
+
+
+# Our library dependencies (contains definitions for IMPORTED targets)
+if(NOT TARGET dlib-shared AND NOT dlib_BINARY_DIR)
+ # Compute paths
+ get_filename_component(dlib_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+ include("${dlib_CMAKE_DIR}/dlib.cmake")
+endif()
+
+set(dlib_LIBRARIES dlib::dlib)
+set(dlib_LIBS dlib::dlib)
+set(dlib_INCLUDE_DIRS "@CMAKE_INSTALL_FULL_INCLUDEDIR@" "@dlib_needed_includes@")
+
+mark_as_advanced(dlib_LIBRARIES)
+mark_as_advanced(dlib_LIBS)
+mark_as_advanced(dlib_INCLUDE_DIRS)
+
+# Mark these variables above as deprecated.
+function(__deprecated_var var access)
+ if(access STREQUAL "READ_ACCESS")
+ message(WARNING "The variable '${var}' is deprecated! Instead, simply use target_link_libraries(your_app dlib::dlib). See http://dlib.net/examples/CMakeLists.txt.html for an example.")
+ endif()
+endfunction()
+variable_watch(dlib_LIBRARIES __deprecated_var)
+variable_watch(dlib_LIBS __deprecated_var)
+variable_watch(dlib_INCLUDE_DIRS __deprecated_var)
+
+
+
diff --git a/ml/dlib/dlib/cmake_utils/find_blas.cmake b/ml/dlib/dlib/cmake_utils/find_blas.cmake
new file mode 100644
index 000000000..24fca7123
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/find_blas.cmake
@@ -0,0 +1,385 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+#
+# This cmake file tries to find installed BLAS and LAPACK libraries.
+# It looks for an installed copy of the Intel MKL library first and then
+# attempts to find some other BLAS and LAPACK libraries if you don't have
+# the Intel MKL.
+#
+# blas_found - True if BLAS is available
+# lapack_found - True if LAPACK is available
+# found_intel_mkl - True if the Intel MKL library is available
+# found_intel_mkl_headers - True if Intel MKL headers are available
+# blas_libraries - link against these to use BLAS library
+# lapack_libraries - link against these to use LAPACK library
+# mkl_libraries - link against these to use the MKL library
+# mkl_include_dir - add to the include path to use the MKL library
+# openmp_libraries - Set to Intel's OpenMP library if and only if we
+# find the MKL.
+
+# setting this makes CMake allow normal looking if else statements
+SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true)
+
+SET(blas_found 0)
+SET(lapack_found 0)
+SET(found_intel_mkl 0)
+SET(found_intel_mkl_headers 0)
+SET(lapack_with_underscore 0)
+SET(lapack_without_underscore 0)
+
+message(STATUS "Searching for BLAS and LAPACK")
+
+if (UNIX OR MINGW)
+ message(STATUS "Searching for BLAS and LAPACK")
+
+ if (BUILDING_MATLAB_MEX_FILE)
+ # # This commented out stuff would link directly to MATLAB's built in
+ # BLAS and LAPACK. But it's better to not link to anything and do a
+ #find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} )
+ #find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} )
+ #if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY)
+ # add_subdirectory(external/cblas)
+ # set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas )
+ # set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} )
+ # set(blas_found 1)
+ # set(lapack_found 1)
+ # message(STATUS "Found MATLAB's BLAS and LAPACK libraries")
+ #endif()
+
+ # We need cblas since MATLAB doesn't provide cblas symbols.
+ add_subdirectory(external/cblas)
+ set(blas_libraries cblas )
+ set(blas_found 1)
+ set(lapack_found 1)
+ message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)")
+
+
+ ## Don't try to link to anything other than MATLAB's own internal blas
+ ## and lapack libraries because doing so generally upsets MATLAB. So
+ ## we just end here no matter what.
+ return()
+ endif()
+
+ # First, search for libraries via pkg-config, which is the cleanest path
+ find_package(PkgConfig)
+ pkg_check_modules(BLAS_REFERENCE cblas)
+ pkg_check_modules(LAPACK_REFERENCE lapack)
+ if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND)
+ set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}")
+ set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}")
+ set(blas_found 1)
+ set(lapack_found 1)
+ set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack")
+ message(STATUS "Found BLAS and LAPACK via pkg-config")
+ return()
+ endif()
+
+ include(CheckTypeSize)
+ check_type_size( "void*" SIZE_OF_VOID_PTR)
+
+ if (SIZE_OF_VOID_PTR EQUAL 8)
+ set( mkl_search_path
+ /opt/intel/mkl/*/lib/em64t
+ /opt/intel/mkl/lib/intel64
+ /opt/intel/lib/intel64
+ /opt/intel/mkl/lib
+ )
+
+ find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
+ mark_as_advanced(mkl_intel)
+ else()
+ set( mkl_search_path
+ /opt/intel/mkl/*/lib/32
+ /opt/intel/mkl/lib/ia32
+ /opt/intel/lib/ia32
+ )
+
+ find_library(mkl_intel mkl_intel ${mkl_search_path})
+ mark_as_advanced(mkl_intel)
+ endif()
+
+ include(CheckLibraryExists)
+
+ # Get mkl_include_dir
+ set(mkl_include_search_path
+ /opt/intel/mkl/include
+ /opt/intel/include
+ )
+ find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path})
+ mark_as_advanced(mkl_include_dir)
+
+ # Search for the needed libraries from the MKL. We will try to link against the mkl_rt
+ # file first since this way avoids linking bugs in some cases.
+ find_library(mkl_rt mkl_rt ${mkl_search_path})
+ find_library(openmp_libraries iomp5 ${mkl_search_path})
+ mark_as_advanced( mkl_rt openmp_libraries )
+ # if we found the MKL
+ if ( mkl_rt)
+ set(mkl_libraries ${mkl_rt} )
+ set(blas_libraries ${mkl_rt} )
+ set(lapack_libraries ${mkl_rt} )
+ set(blas_found 1)
+ set(lapack_found 1)
+ set(found_intel_mkl 1)
+ message(STATUS "Found Intel MKL BLAS/LAPACK library")
+ endif()
+
+ if (NOT found_intel_mkl)
+ # Search for the needed libraries from the MKL. This time try looking for a different
+ # set of MKL files and try to link against those.
+ find_library(mkl_core mkl_core ${mkl_search_path})
+ find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
+ find_library(mkl_iomp iomp5 ${mkl_search_path})
+ find_library(mkl_pthread pthread ${mkl_search_path})
+
+ mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp mkl_pthread)
+ # If we found the MKL
+ if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp AND mkl_pthread)
+ set(mkl_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
+ set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
+ set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread})
+ set(blas_found 1)
+ set(lapack_found 1)
+ set(found_intel_mkl 1)
+ message(STATUS "Found Intel MKL BLAS/LAPACK library")
+ endif()
+ endif()
+
+ if (found_intel_mkl AND mkl_include_dir)
+ set(found_intel_mkl_headers 1)
+ endif()
+
+ # try to find some other LAPACK libraries if we didn't find the MKL
+ set(extra_paths
+ /usr/lib64
+ /usr/lib64/atlas-sse3
+ /usr/lib64/atlas-sse2
+ /usr/lib64/atlas
+ /usr/lib
+ /usr/lib/atlas-sse3
+ /usr/lib/atlas-sse2
+ /usr/lib/atlas
+ /usr/lib/openblas-base
+ /opt/OpenBLAS/lib
+ $ENV{OPENBLAS_HOME}/lib
+ )
+
+ INCLUDE (CheckFunctionExists)
+
+ if (NOT blas_found)
+ find_library(cblas_lib openblas PATHS ${extra_paths})
+ if (cblas_lib)
+ set(blas_libraries ${cblas_lib})
+ set(blas_found 1)
+ message(STATUS "Found OpenBLAS library")
+ set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
+ # If you compiled OpenBLAS with LAPACK in it then it should have the
+ # sgetrf_single function in it. So if we find that function in
+ # OpenBLAS then just use OpenBLAS's LAPACK.
+ CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK)
+ if (OPENBLAS_HAS_LAPACK)
+ message(STATUS "Using OpenBLAS's built in LAPACK")
+ # set(lapack_libraries gfortran)
+ set(lapack_found 1)
+ endif()
+ endif()
+ mark_as_advanced( cblas_lib)
+ endif()
+
+
+ if (NOT lapack_found)
+ find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths})
+ if (lapack_lib)
+ set(lapack_libraries ${lapack_lib})
+ set(lapack_found 1)
+ message(STATUS "Found LAPACK library")
+ endif()
+ mark_as_advanced( lapack_lib)
+ endif()
+
+
+ # try to find some other BLAS libraries if we didn't find the MKL
+
+ if (NOT blas_found)
+ find_library(atlas_lib atlas PATHS ${extra_paths})
+ find_library(cblas_lib cblas PATHS ${extra_paths})
+ if (atlas_lib AND cblas_lib)
+ set(blas_libraries ${atlas_lib} ${cblas_lib})
+ set(blas_found 1)
+ message(STATUS "Found ATLAS BLAS library")
+ endif()
+ mark_as_advanced( atlas_lib cblas_lib)
+ endif()
+
+ # CentOS 7 atlas
+ if (NOT blas_found)
+ find_library(tatlas_lib tatlas PATHS ${extra_paths})
+ find_library(satlas_lib satlas PATHS ${extra_paths})
+ if (tatlas_lib AND satlas_lib )
+ set(blas_libraries ${tatlas_lib} ${satlas_lib})
+ set(blas_found 1)
+ message(STATUS "Found ATLAS BLAS library")
+ endif()
+ mark_as_advanced( tatlas_lib satlas_lib)
+ endif()
+
+
+ if (NOT blas_found)
+ find_library(cblas_lib cblas PATHS ${extra_paths})
+ if (cblas_lib)
+ set(blas_libraries ${cblas_lib})
+ set(blas_found 1)
+ message(STATUS "Found CBLAS library")
+ endif()
+ mark_as_advanced( cblas_lib)
+ endif()
+
+
+ if (NOT blas_found)
+ find_library(generic_blas blas PATHS ${extra_paths})
+ if (generic_blas)
+ set(blas_libraries ${generic_blas})
+ set(blas_found 1)
+ message(STATUS "Found BLAS library")
+ endif()
+ mark_as_advanced( generic_blas)
+ endif()
+
+
+
+
+ # Make sure we really found a CBLAS library. That is, it needs to expose
+ # the proper cblas link symbols. So here we test if one of them is present
+ # and assume everything is good if it is. Note that we don't do this check if
+ # we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work
+ # with it. But it's fine since the MKL should always have cblas.
+ if (blas_found AND NOT found_intel_mkl)
+ set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
+ CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS)
+ if (NOT HAVE_CBLAS)
+ message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK")
+ set(blas_found 0)
+ set(lapack_found 0)
+ endif()
+ endif()
+
+
+
+elseif(WIN32 AND NOT MINGW)
+ message(STATUS "Searching for BLAS and LAPACK")
+
+ include(CheckTypeSize)
+ check_type_size( "void*" SIZE_OF_VOID_PTR)
+ if (SIZE_OF_VOID_PTR EQUAL 8)
+ set( mkl_search_path
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64"
+ "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64"
+ "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64"
+ "C:/Program Files/Intel/Composer XE/mkl/lib/intel64"
+ "C:/Program Files/Intel/Composer XE/compiler/lib/intel64"
+ )
+ find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path})
+ else()
+ set( mkl_search_path
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32"
+ "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32"
+ "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32"
+ "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32"
+ "C:/Program Files/Intel/Composer XE/mkl/lib/ia32"
+ "C:/Program Files/Intel/Composer XE/compiler/lib/ia32"
+ )
+ find_library(mkl_intel mkl_intel_c ${mkl_search_path})
+ endif()
+
+ INCLUDE (CheckFunctionExists)
+
+ # Search for the needed libraries from the MKL.
+ find_library(mkl_core mkl_core ${mkl_search_path})
+ find_library(mkl_thread mkl_intel_thread ${mkl_search_path})
+ find_library(mkl_iomp libiomp5md ${mkl_search_path})
+
+ mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp)
+ # If we found the MKL
+ if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp )
+ set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
+ set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} )
+ set(blas_found 1)
+ set(lapack_found 1)
+ message(STATUS "Found Intel MKL BLAS/LAPACK library")
+
+ # Make sure the version of the Intel MKL we found is compatible with
+ # the compiler we are using. One way to do this check is to see if we can
+ # link to it right now.
+ set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries})
+ CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS)
+ if (NOT HAVE_CBLAS)
+ message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK")
+ set(blas_found 0)
+ set(lapack_found 0)
+ endif()
+
+ endif()
+
+
+endif()
+
+
+# When all else fails use CMake's built in functions to find BLAS and LAPACK
+if (NOT blas_found)
+ find_package(BLAS QUIET)
+ if (${BLAS_FOUND})
+ set(blas_libraries ${BLAS_LIBRARIES})
+ set(blas_found 1)
+ if (NOT lapack_found)
+ find_package(LAPACK QUIET)
+ if (${LAPACK_FOUND})
+ set(lapack_libraries ${LAPACK_LIBRARIES})
+ set(lapack_found 1)
+ endif()
+ endif()
+ endif()
+endif()
+
+
+# If using lapack, determine whether to mangle functions
+if (lapack_found)
+ include(CheckFunctionExists)
+ include(CheckFortranFunctionExists)
+ set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries})
+
+ check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED)
+ check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED)
+ if (CMAKE_Fortran_COMPILER_LOADED)
+ check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED)
+ check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED)
+ endif ()
+ if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED)
+ set(lapack_with_underscore 1)
+ elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED)
+ set(lapack_without_underscore 1)
+ endif ()
+endif()
+
+
+if (UNIX OR MINGW)
+ if (NOT blas_found)
+ message(" *****************************************************************************")
+ message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***")
+ message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***")
+ message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***")
+ message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***")
+ message(" *** Or you can easily install OpenBLAS from source by downloading the ***")
+ message(" *** source tar file from http://www.openblas.net, extracting it, and ***")
+ message(" *** running: ***")
+ message(" *** make; sudo make install ***")
+ message(" *****************************************************************************")
+ endif()
+endif()
+
diff --git a/ml/dlib/dlib/cmake_utils/release_build_by_default b/ml/dlib/dlib/cmake_utils/release_build_by_default
new file mode 100644
index 000000000..1b0e95831
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/release_build_by_default
@@ -0,0 +1,9 @@
+
+#set default build type to Release
+if(NOT CMAKE_BUILD_TYPE)
+ set(CMAKE_BUILD_TYPE "Release" CACHE STRING
+ "Choose the type of build, options are: Debug Release
+ RelWithDebInfo MinSizeRel." FORCE)
+endif()
+
+
diff --git a/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake b/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake
new file mode 100644
index 000000000..dd8e3a7a5
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake
@@ -0,0 +1,131 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+
+if (POLICY CMP0054)
+ cmake_policy(SET CMP0054 NEW)
+endif()
+
+set(USING_OLD_VISUAL_STUDIO_COMPILER 0)
+if(MSVC AND MSVC_VERSION VERSION_LESS 1900)
+ message(FATAL_ERROR "C++11 is required to use dlib, but the version of Visual Studio you are using is too old and doesn't support C++11. You need Visual Studio 2015 or newer. ")
+elseif(MSVC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 19.0.24210.0 )
+ message(STATUS "NOTE: Visual Studio didn't have good enough C++11 support until Visual Studio 2015 update 3 (v19.0.24210.0)")
+ message(STATUS "So we aren't enabling things that require full C++11 support (e.g. the deep learning tools).")
+ message(STATUS "Also, be aware that Visual Studio's version naming is confusing, in particular, there are multiple versions of 'update 3'")
+ message(STATUS "So if you are getting this message you need to update to the newer version of Visual Studio to use full C++11.")
+ set(USING_OLD_VISUAL_STUDIO_COMPILER 1)
+elseif(MSVC AND (MSVC_VERSION EQUAL 1911 OR MSVC_VERSION EQUAL 1910))
+ message(STATUS "******************************************************************************************")
+ message(STATUS "Your version of Visual Studio has incomplete C++11 support and is unable to compile the ")
+ message(STATUS "DNN examples. So we are disabling the deep learning tools. If you want to use the DNN ")
+ message(STATUS "tools in dlib then update your copy of Visual Studio.")
+ message(STATUS "******************************************************************************************")
+ set(USING_OLD_VISUAL_STUDIO_COMPILER 1)
+endif()
+
+if(CMAKE_COMPILER_IS_GNUCXX)
+ execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION)
+ if (GCC_VERSION VERSION_LESS 4.8)
+ message(FATAL_ERROR "C++11 is required to use dlib, but the version of GCC you are using is too old and doesn't support C++11. You need GCC 4.8 or newer. ")
+ endif()
+endif()
+
+
+# push USING_OLD_VISUAL_STUDIO_COMPILER to the parent so we can use it in the
+# examples CMakeLists.txt file.
+get_directory_property(has_parent PARENT_DIRECTORY)
+if(has_parent)
+ set(USING_OLD_VISUAL_STUDIO_COMPILER ${USING_OLD_VISUAL_STUDIO_COMPILER} PARENT_SCOPE)
+endif()
+
+
+
+set(gcc_like_compilers GNU Clang Intel)
+set(intel_archs x86_64 i386 i686 AMD64 amd64 x86)
+
+
+# Setup some options to allow a user to enable SSE and AVX instruction use.
+if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND
+ (";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR)
+ option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF)
+ option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
+ option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
+ if(USE_AVX_INSTRUCTIONS)
+ list(APPEND active_compile_opts -mavx)
+ message(STATUS "Enabling AVX instructions")
+ elseif (USE_SSE4_INSTRUCTIONS)
+ list(APPEND active_compile_opts -msse4)
+ message(STATUS "Enabling SSE4 instructions")
+ elseif(USE_SSE2_INSTRUCTIONS)
+ list(APPEND active_compile_opts -msse2)
+ message(STATUS "Enabling SSE2 instructions")
+ endif()
+elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio
+ # Use SSE2 by default when using Visual Studio.
+ option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON)
+ option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF)
+ option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF)
+
+ include(CheckTypeSize)
+ check_type_size( "void*" SIZE_OF_VOID_PTR)
+ if(USE_AVX_INSTRUCTIONS)
+ list(APPEND active_compile_opts /arch:AVX)
+ message(STATUS "Enabling AVX instructions")
+ elseif (USE_SSE4_INSTRUCTIONS)
+ # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
+ # So only give it when we are doing a 32 bit build.
+ if (SIZE_OF_VOID_PTR EQUAL 4)
+ list(APPEND active_compile_opts /arch:SSE2)
+ endif()
+ message(STATUS "Enabling SSE4 instructions")
+ list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2")
+ list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3")
+ list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41")
+ elseif(USE_SSE2_INSTRUCTIONS)
+ # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes.
+ # So only give it when we are doing a 32 bit build.
+ if (SIZE_OF_VOID_PTR EQUAL 4)
+ list(APPEND active_compile_opts /arch:SSE2)
+ endif()
+ message(STATUS "Enabling SSE2 instructions")
+ list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2")
+ endif()
+
+elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND
+ ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm"))
+ option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF)
+ if(USE_NEON_INSTRUCTIONS)
+ list(APPEND active_compile_opts -mfpu=neon)
+ message(STATUS "Enabling ARM-NEON instructions")
+ endif()
+endif()
+
+
+
+
+if (CMAKE_COMPILER_IS_GNUCXX)
+ # By default, g++ won't warn or error if you forget to return a value in a
+ # function which requires you to do so. This option makes it give a warning
+ # for doing this.
+ list(APPEND active_compile_opts "-Wreturn-type")
+endif()
+
+if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID})
+ # Increase clang's default tempalte recurision depth so the dnn examples don't error out.
+ list(APPEND active_compile_opts "-ftemplate-depth=500")
+endif()
+
+if (MSVC)
+ # By default Visual Studio does not support .obj files with more than 65k sections.
+ # However, code generated by file_to_code_ex and code using DNN module can have
+ # them. So this flag enables > 65k sections, but produces .obj files
+ # that will not be readable by VS 2005.
+ list(APPEND active_compile_opts "/bigobj")
+
+ if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3)
+ # Clang can compile all Dlib's code at Windows platform. Tested with Clang 5
+ list(APPEND active_compile_opts "-Xclang -fcxx-exceptions")
+ endif()
+endif()
+
+
diff --git a/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake b/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake
new file mode 100644
index 000000000..e5fb09129
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake
@@ -0,0 +1,19 @@
+
+# Including this cmake script into your cmake project will cause visual studio
+# to build your project against the static C runtime.
+
+cmake_minimum_required(VERSION 2.8.12)
+if (POLICY CMP0054)
+ cmake_policy(SET CMP0054 NEW)
+endif()
+
+if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
+ foreach(flag_var
+ CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
+ CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
+ if(${flag_var} MATCHES "/MD")
+ string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
+ endif()
+ endforeach(flag_var)
+endif()
+
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt
new file mode 100644
index 000000000..bc6f02563
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt
@@ -0,0 +1,17 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+project(cpp11_test)
+
+# Try to enable C++11
+include(CheckCXXCompilerFlag)
+CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11)
+CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X)
+if(COMPILER_SUPPORTS_CXX11)
+ message(STATUS "C++11 activated.")
+ ADD_DEFINITIONS("-std=c++11")
+elseif(COMPILER_SUPPORTS_CXX0X)
+ message(STATUS "C++0x activated.")
+ ADD_DEFINITIONS("-std=c++0x")
+endif()
+
+add_library(cpp11_test STATIC cpp11_test.cpp )
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp b/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp
new file mode 100644
index 000000000..6cc4f479b
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp
@@ -0,0 +1,51 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <memory>
+#include <iostream>
+
+using namespace std;
+
+class testme
+{
+public:
+
+ testme(testme&&) = default;
+ testme(const testme&) = delete;
+
+
+ template <typename T>
+ auto auto_return(T f) -> decltype(f(4)) { return f(4); }
+
+ template <typename T>
+ auto auto_return(T f) -> decltype(f()) { return f(); }
+
+ static int returnint() { return 0; }
+
+ void dostuff()
+ {
+ thread_local int stuff1 = 999;
+ auto x = 4;
+
+ decltype(x) asdf = 9;
+
+ auto f = []() { cout << "in a lambda!" << endl; };
+ f();
+
+ auto_return(returnint);
+ }
+
+ template <typename ...T>
+ void variadic_template(
+ T&& ...args
+ )
+ {
+ }
+
+
+
+ std::shared_ptr<int> asdf;
+};
+
+// ------------------------------------------------------------------------------------
+
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt
new file mode 100644
index 000000000..5f6af245e
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt
@@ -0,0 +1,14 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+project(cuda_test)
+
+include_directories(../../dnn)
+add_definitions(-DDLIB_USE_CUDA)
+
+# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain:
+option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF)
+find_package(CUDA 7.5 REQUIRED)
+set(CUDA_HOST_COMPILATION_CPP ON)
+list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES")
+
+cuda_add_library(cuda_test STATIC cuda_test.cu )
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu b/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu
new file mode 100644
index 000000000..fb1ffe0da
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu
@@ -0,0 +1,21 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "cuda_utils.h"
+#include "cuda_dlib.h"
+
+
+// ------------------------------------------------------------------------------------
+
+__global__ void cuda_add_arrays(const float* a, const float* b, float* out, size_t n)
+{
+ out[0] += a[0]+b[0];
+}
+
+void add_arrays()
+{
+ cuda_add_arrays<<<512,512>>>(0,0,0,0);
+}
+
+// ------------------------------------------------------------------------------------
+
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt
new file mode 100644
index 000000000..556088259
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt
@@ -0,0 +1,19 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+project(cudnn_test)
+include(../use_cpp_11.cmake)
+
+# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain:
+option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF)
+find_package(CUDA 7.5 REQUIRED)
+set(CUDA_HOST_COMPILATION_CPP ON)
+list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__")
+add_definitions(-DDLIB_USE_CUDA)
+
+include(find_cudnn.txt)
+
+if (cudnn_include AND cudnn)
+ include_directories(${cudnn_include})
+ cuda_add_library(cudnn_test STATIC ../../dnn/cudnn_dlibapi.cpp ${cudnn} )
+ enable_cpp11_for_target(cudnn_test)
+endif()
diff --git a/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt b/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt
new file mode 100644
index 000000000..dd5f14e3f
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt
@@ -0,0 +1,24 @@
+
+message(STATUS "Looking for cuDNN install...")
+# Look for cudnn, we will look in the same place as other CUDA
+# libraries and also a few other places as well.
+find_path(cudnn_include cudnn.h
+ HINTS ${CUDA_INCLUDE_DIRS} ENV CUDNN_INCLUDE_DIR ENV CUDNN_HOME
+ PATHS /usr/local ENV CPATH
+ PATH_SUFFIXES include
+ )
+get_filename_component(cudnn_hint_path "${CUDA_CUBLAS_LIBRARIES}" PATH)
+find_library(cudnn cudnn
+ HINTS ${cudnn_hint_path} ENV CUDNN_LIBRARY_DIR ENV CUDNN_HOME
+ PATHS /usr/local /usr/local/cuda ENV LD_LIBRARY_PATH
+ PATH_SUFFIXES lib64 lib x64
+ )
+mark_as_advanced(cudnn cudnn_include)
+
+if (cudnn AND cudnn_include)
+ message(STATUS "Found cuDNN: " ${cudnn})
+else()
+ message(STATUS "*** cuDNN V5.0 OR GREATER NOT FOUND. ***")
+ message(STATUS "*** Dlib requires cuDNN V5.0 OR GREATER. Since cuDNN is not found DLIB WILL NOT USE CUDA. ***")
+ message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder. ***")
+endif()
diff --git a/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt
new file mode 100644
index 000000000..0b6eb6f28
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt
@@ -0,0 +1,6 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+project(neon_test)
+
+add_library(neon_test STATIC neon_test.cpp )
+
diff --git a/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp b/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp
new file mode 100644
index 000000000..a4abdbade
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp
@@ -0,0 +1,9 @@
+#ifdef __ARM_NEON__
+#else
+#error "No NEON"
+#endif
+int main(){}
+
+
+// ------------------------------------------------------------------------------------
+
diff --git a/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake b/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake
new file mode 100644
index 000000000..e49e30f2a
--- /dev/null
+++ b/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake
@@ -0,0 +1,113 @@
+# This script creates a function, enable_cpp11_for_target(), which checks if your
+# compiler has C++11 support and enables it if it does.
+
+
+cmake_minimum_required(VERSION 2.8.12)
+
+if (POLICY CMP0054)
+ cmake_policy(SET CMP0054 NEW)
+endif()
+
+
+set(_where_is_cmake_utils_dir ${CMAKE_CURRENT_LIST_DIR})
+
+function(enable_cpp11_for_target target_name)
+
+
+# Set to false unless we find out otherwise in the code below.
+set(COMPILER_CAN_DO_CPP_11 0)
+
+
+
+macro(test_compiler_for_cpp11)
+ message(STATUS "Building a C++11 test project to see if your compiler supports C++11")
+ try_compile(test_for_cpp11_worked ${PROJECT_BINARY_DIR}/cpp11_test_build
+ ${_where_is_cmake_utils_dir}/test_for_cpp11 cpp11_test)
+ if (test_for_cpp11_worked)
+ message(STATUS "C++11 activated.")
+ set(COMPILER_CAN_DO_CPP_11 1)
+ else()
+ set(COMPILER_CAN_DO_CPP_11 0)
+ message(STATUS "********** Your compiler failed to build a C++11 project. C++11 is required to use all parts of dlib! **********")
+ endif()
+endmacro()
+
+# Now turn on the appropriate compiler switch to enable C++11 if you have a
+# C++11 compiler. In CMake 3.1 there is a simple flag you can set, but earlier
+# verions of CMake are not so convenient.
+if (CMAKE_VERSION VERSION_LESS "3.1.2")
+ if(CMAKE_COMPILER_IS_GNUCXX)
+ execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION)
+ if (GCC_VERSION VERSION_GREATER 4.8 OR GCC_VERSION VERSION_EQUAL 4.8)
+ message(STATUS "C++11 activated.")
+ target_compile_options(${target_name} PUBLIC "-std=gnu++11")
+ set(COMPILER_CAN_DO_CPP_11 1)
+ endif()
+ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
+ execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string )
+ string (REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION ${clang_full_version_string})
+ if (CLANG_VERSION VERSION_GREATER 3.3)
+ message(STATUS "C++11 activated.")
+ target_compile_options(${target_name} PUBLIC "-std=c++11")
+ set(COMPILER_CAN_DO_CPP_11 1)
+ endif()
+ else()
+ # Since we don't know what compiler this is just try to build a c++11 project and see if it compiles.
+ test_compiler_for_cpp11()
+ endif()
+else()
+
+ # Set a flag if the compiler you are using is capable of providing C++11 features.
+ get_property(cxx_features GLOBAL PROPERTY CMAKE_CXX_KNOWN_FEATURES)
+ if (";${cxx_features};" MATCHES ";cxx_rvalue_references;" AND
+ ";${cxx_features};" MATCHES ";cxx_variadic_templates;" AND
+ ";${cxx_features};" MATCHES ";cxx_lambdas;" AND
+ ";${cxx_features};" MATCHES ";cxx_defaulted_move_initializers;" AND
+ ";${cxx_features};" MATCHES ";cxx_delegating_constructors;" AND
+ ";${cxx_features};" MATCHES ";cxx_thread_local;" AND
+ ";${cxx_features};" MATCHES ";cxx_constexpr;" AND
+ ";${cxx_features};" MATCHES ";cxx_decltype_incomplete_return_types;" AND
+ ";${cxx_features};" MATCHES ";cxx_auto_type;")
+
+ set(COMPILER_CAN_DO_CPP_11 1)
+ # Tell cmake that we need C++11 for dlib
+ target_compile_features(${target_name}
+ PUBLIC
+ cxx_rvalue_references
+ cxx_variadic_templates
+ cxx_lambdas
+ cxx_defaulted_move_initializers
+ cxx_delegating_constructors
+ cxx_thread_local
+ cxx_constexpr
+ # cxx_decltype_incomplete_return_types # purposfully commented out because cmake errors out on this when using visual studio and cmake 3.8.0
+ cxx_auto_type
+ )
+
+ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
+ # Sometimes clang will lie and report that it supports C++11 when
+ # really it doesn't support thread_local. So check for that.
+ test_compiler_for_cpp11()
+ else()
+ message(STATUS "C++11 activated.")
+ endif()
+ endif()
+endif()
+
+# Always enable whatever partial C++11 support we have, even if it isn't full
+# support, and just hope for the best.
+if (NOT COMPILER_CAN_DO_CPP_11)
+ include(CheckCXXCompilerFlag)
+ CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11)
+ CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X)
+ if(COMPILER_SUPPORTS_CXX11)
+ message(STATUS "C++11 activated (compiler doesn't have full C++11 support).")
+ target_compile_options(${target_name} PUBLIC "-std=c++11")
+ elseif(COMPILER_SUPPORTS_CXX0X)
+ message(STATUS "C++0x activated (compiler doesn't have full C++11 support).")
+ target_compile_options(${target_name} PUBLIC "-std=c++0x")
+ endif()
+endif()
+
+endfunction()
+
diff --git a/ml/dlib/dlib/cmd_line_parser.h b/ml/dlib/dlib/cmd_line_parser.h
new file mode 100644
index 000000000..fd1148038
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser.h
@@ -0,0 +1,84 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSEr_
+#define DLIB_CMD_LINE_PARSEr_
+
+#include "cmd_line_parser/cmd_line_parser_kernel_1.h"
+#include "cmd_line_parser/cmd_line_parser_kernel_c.h"
+#include "cmd_line_parser/cmd_line_parser_print_1.h"
+#include "cmd_line_parser/cmd_line_parser_check_1.h"
+#include "cmd_line_parser/cmd_line_parser_check_c.h"
+#include <string>
+#include "cmd_line_parser/get_option.h"
+
+#include "map.h"
+#include "sequence.h"
+
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT
+ >
+ class impl_cmd_line_parser
+ {
+ /*!
+ This class is basically just a big templated typedef for building
+ a complete command line parser type out of all the parts it needs.
+ !*/
+
+ impl_cmd_line_parser() {}
+
+ typedef typename sequence<std::basic_string<charT> >::kernel_2a sequence_2a;
+ typedef typename sequence<std::basic_string<charT>*>::kernel_2a psequence_2a;
+ typedef typename map<std::basic_string<charT>,void*>::kernel_1a map_1a_string;
+
+ public:
+
+ typedef cmd_line_parser_kernel_1<charT,map_1a_string,sequence_2a,psequence_2a> kernel_1a;
+ typedef cmd_line_parser_kernel_c<kernel_1a> kernel_1a_c;
+ typedef cmd_line_parser_print_1<kernel_1a_c> print_1a_c;
+ typedef cmd_line_parser_check_c<cmd_line_parser_check_1<print_1a_c> > check_1a_c;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT
+ >
+ class cmd_line_parser : public impl_cmd_line_parser<charT>::check_1a_c
+ {
+ public:
+
+ // These typedefs are here for backwards compatibility with previous versions of dlib.
+ typedef cmd_line_parser kernel_1a;
+ typedef cmd_line_parser kernel_1a_c;
+ typedef cmd_line_parser print_1a;
+ typedef cmd_line_parser print_1a_c;
+ typedef cmd_line_parser check_1a;
+ typedef cmd_line_parser check_1a_c;
+ };
+
+ template <
+ typename charT
+ >
+ inline void swap (
+ cmd_line_parser<charT>& a,
+ cmd_line_parser<charT>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ typedef cmd_line_parser<char> command_line_parser;
+ typedef cmd_line_parser<wchar_t> wcommand_line_parser;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSEr_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h
new file mode 100644
index 000000000..1736b4b56
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h
@@ -0,0 +1,580 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_CHECk_1_
+#define DLIB_CMD_LINE_PARSER_CHECk_1_
+
+#include "cmd_line_parser_kernel_abstract.h"
+#include <sstream>
+#include <string>
+#include "../string.h"
+#include <vector>
+
+namespace dlib
+{
+
+ template <
+ typename clp_base
+ >
+ class cmd_line_parser_check_1 : public clp_base
+ {
+
+ /*!
+ This extension doesn't add any state.
+
+ !*/
+
+
+ public:
+ typedef typename clp_base::char_type char_type;
+ typedef typename clp_base::string_type string_type;
+
+ // ------------------------------------------------------------------------------------
+
+ class cmd_line_check_error : public dlib::error
+ {
+ friend class cmd_line_parser_check_1;
+
+ cmd_line_check_error(
+ error_type t,
+ const string_type& opt_,
+ const string_type& arg_
+ ) :
+ dlib::error(t),
+ opt(opt_),
+ opt2(),
+ arg(arg_),
+ required_opts()
+ { set_info_string(); }
+
+ cmd_line_check_error(
+ error_type t,
+ const string_type& opt_,
+ const string_type& opt2_,
+ int // this is just to make this constructor different from the one above
+ ) :
+ dlib::error(t),
+ opt(opt_),
+ opt2(opt2_),
+ arg(),
+ required_opts()
+ { set_info_string(); }
+
+ cmd_line_check_error (
+ error_type t,
+ const string_type& opt_,
+ const std::vector<string_type>& vect
+ ) :
+ dlib::error(t),
+ opt(opt_),
+ opt2(),
+ arg(),
+ required_opts(vect)
+ { set_info_string(); }
+
+ cmd_line_check_error(
+ error_type t,
+ const string_type& opt_
+ ) :
+ dlib::error(t),
+ opt(opt_),
+ opt2(),
+ arg(),
+ required_opts()
+ { set_info_string(); }
+
+ ~cmd_line_check_error() throw() {}
+
+ void set_info_string (
+ )
+ {
+ std::ostringstream sout;
+ switch (type)
+ {
+ case EINVALID_OPTION_ARG:
+ sout << "Command line error: '" << narrow(arg) << "' is not a valid argument to "
+ << "the '" << narrow(opt) << "' option.";
+ break;
+ case EMISSING_REQUIRED_OPTION:
+ if (required_opts.size() == 1)
+ {
+ sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of "
+ << "the '" << required_opts[0] << "' option.";
+ }
+ else
+ {
+ sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of "
+ << "one of the following options: ";
+ for (unsigned long i = 0; i < required_opts.size(); ++i)
+ {
+ if (i == required_opts.size()-2)
+ sout << "'" << required_opts[i] << "' or ";
+ else if (i == required_opts.size()-1)
+ sout << "'" << required_opts[i] << "'.";
+ else
+ sout << "'" << required_opts[i] << "', ";
+ }
+ }
+ break;
+ case EINCOMPATIBLE_OPTIONS:
+ sout << "Command line error: The '" << narrow(opt) << "' and '" << narrow(opt2)
+ << "' options cannot be given together on the command line.";
+ break;
+ case EMULTIPLE_OCCURANCES:
+ sout << "Command line error: The '" << narrow(opt) << "' option can only "
+ << "be given on the command line once.";
+ break;
+ default:
+ sout << "Command line error.";
+ break;
+ }
+ const_cast<std::string&>(info) = wrap_string(sout.str(),0,0);
+ }
+
+ public:
+ const string_type opt;
+ const string_type opt2;
+ const string_type arg;
+ const std::vector<string_type> required_opts;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void check_option_arg_type (
+ const string_type& option_name
+ ) const;
+
+ template <
+ typename T
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T& first,
+ const T& last
+ ) const;
+
+ template <
+ typename T,
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T (&arg_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const char_type* (&arg_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_incompatible_options (
+ const char_type* (&option_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_one_time_options (
+ const char_type* (&option_set)[length]
+ ) const;
+
+ void check_incompatible_options (
+ const string_type& option_name1,
+ const string_type& option_name2
+ ) const;
+
+ void check_sub_option (
+ const string_type& parent_option,
+ const string_type& sub_option
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const string_type& parent_option,
+ const char_type* (&sub_option_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[length],
+ const string_type& sub_option
+ ) const;
+
+ template <
+ size_t parent_length,
+ size_t sub_length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[parent_length],
+ const char_type* (&sub_option_set)[sub_length]
+ ) const;
+ };
+
+ template <
+ typename clp_base
+ >
+ inline void swap (
+ cmd_line_parser_check_1<clp_base>& a,
+ cmd_line_parser_check_1<clp_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template <typename T>
+ void cmd_line_parser_check_1<clp_base>::
+ check_option_arg_type (
+ const string_type& option_name
+ ) const
+ {
+ try
+ {
+ const typename clp_base::option_type& opt = this->option(option_name);
+ const unsigned long number_of_arguments = opt.number_of_arguments();
+ const unsigned long count = opt.count();
+ for (unsigned long i = 0; i < number_of_arguments; ++i)
+ {
+ for (unsigned long j = 0; j < count; ++j)
+ {
+ string_cast<T>(opt.argument(i,j));
+ }
+ }
+ }
+ catch (string_cast_error& e)
+ {
+ throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template <typename T>
+ void cmd_line_parser_check_1<clp_base>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const T& first,
+ const T& last
+ ) const
+ {
+ try
+ {
+ const typename clp_base::option_type& opt = this->option(option_name);
+ const unsigned long number_of_arguments = opt.number_of_arguments();
+ const unsigned long count = opt.count();
+ for (unsigned long i = 0; i < number_of_arguments; ++i)
+ {
+ for (unsigned long j = 0; j < count; ++j)
+ {
+ T temp(string_cast<T>(opt.argument(i,j)));
+ if (temp < first || last < temp)
+ {
+ throw cmd_line_check_error(
+ EINVALID_OPTION_ARG,
+ option_name,
+ opt.argument(i,j)
+ );
+ }
+ }
+ }
+ }
+ catch (string_cast_error& e)
+ {
+ throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < typename T, size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const T (&arg_set)[length]
+ ) const
+ {
+ try
+ {
+ const typename clp_base::option_type& opt = this->option(option_name);
+ const unsigned long number_of_arguments = opt.number_of_arguments();
+ const unsigned long count = opt.count();
+ for (unsigned long i = 0; i < number_of_arguments; ++i)
+ {
+ for (unsigned long j = 0; j < count; ++j)
+ {
+ T temp(string_cast<T>(opt.argument(i,j)));
+ size_t k = 0;
+ for (; k < length; ++k)
+ {
+ if (arg_set[k] == temp)
+ break;
+ }
+ if (k == length)
+ {
+ throw cmd_line_check_error(
+ EINVALID_OPTION_ARG,
+ option_name,
+ opt.argument(i,j)
+ );
+ }
+ }
+ }
+ }
+ catch (string_cast_error& e)
+ {
+ throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const char_type* (&arg_set)[length]
+ ) const
+ {
+ const typename clp_base::option_type& opt = this->option(option_name);
+ const unsigned long number_of_arguments = opt.number_of_arguments();
+ const unsigned long count = opt.count();
+ for (unsigned long i = 0; i < number_of_arguments; ++i)
+ {
+ for (unsigned long j = 0; j < count; ++j)
+ {
+ size_t k = 0;
+ for (; k < length; ++k)
+ {
+ if (arg_set[k] == opt.argument(i,j))
+ break;
+ }
+ if (k == length)
+ {
+ throw cmd_line_check_error(
+ EINVALID_OPTION_ARG,
+ option_name,
+ opt.argument(i,j)
+ );
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_incompatible_options (
+ const char_type* (&option_set)[length]
+ ) const
+ {
+ for (size_t i = 0; i < length; ++i)
+ {
+ for (size_t j = i+1; j < length; ++j)
+ {
+ if (this->option(option_set[i]).count() > 0 &&
+ this->option(option_set[j]).count() > 0 )
+ {
+ throw cmd_line_check_error(
+ EINCOMPATIBLE_OPTIONS,
+ option_set[i],
+ option_set[j],
+ 0 // this argument has no meaning and is only here to make this
+ // call different from the other constructor
+ );
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ void cmd_line_parser_check_1<clp_base>::
+ check_incompatible_options (
+ const string_type& option_name1,
+ const string_type& option_name2
+ ) const
+ {
+ if (this->option(option_name1).count() > 0 &&
+ this->option(option_name2).count() > 0 )
+ {
+ throw cmd_line_check_error(
+ EINCOMPATIBLE_OPTIONS,
+ option_name1,
+ option_name2,
+ 0 // this argument has no meaning and is only here to make this
+ // call different from the other constructor
+ );
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ void cmd_line_parser_check_1<clp_base>::
+ check_sub_option (
+ const string_type& parent_option,
+ const string_type& sub_option
+ ) const
+ {
+ if (this->option(parent_option).count() == 0)
+ {
+ if (this->option(sub_option).count() != 0)
+ {
+ std::vector<string_type> vect;
+ vect.resize(1);
+ vect[0] = parent_option;
+ throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_sub_options (
+ const string_type& parent_option,
+ const char_type* (&sub_option_set)[length]
+ ) const
+ {
+ if (this->option(parent_option).count() == 0)
+ {
+ size_t i = 0;
+ for (; i < length; ++i)
+ {
+ if (this->option(sub_option_set[i]).count() > 0)
+ break;
+ }
+ if (i != length)
+ {
+ std::vector<string_type> vect;
+ vect.resize(1);
+ vect[0] = parent_option;
+ throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_sub_options (
+ const char_type* (&parent_option_set)[length],
+ const string_type& sub_option
+ ) const
+ {
+ // first check if the sub_option is present
+ if (this->option(sub_option).count() > 0)
+ {
+ // now check if any of the parents are present
+ bool parents_present = false;
+ for (size_t i = 0; i < length; ++i)
+ {
+ if (this->option(parent_option_set[i]).count() > 0)
+ {
+ parents_present = true;
+ break;
+ }
+ }
+
+ if (!parents_present)
+ {
+ std::vector<string_type> vect(parent_option_set, parent_option_set+length);
+ throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t parent_length, size_t sub_length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_sub_options (
+ const char_type* (&parent_option_set)[parent_length],
+ const char_type* (&sub_option_set)[sub_length]
+ ) const
+ {
+ // first check if any of the parent options are present
+ bool parents_present = false;
+ for (size_t i = 0; i < parent_length; ++i)
+ {
+ if (this->option(parent_option_set[i]).count() > 0)
+ {
+ parents_present = true;
+ break;
+ }
+ }
+
+ if (!parents_present)
+ {
+ // none of these sub options should be present
+ size_t i = 0;
+ for (; i < sub_length; ++i)
+ {
+ if (this->option(sub_option_set[i]).count() > 0)
+ break;
+ }
+ if (i != sub_length)
+ {
+ std::vector<string_type> vect(parent_option_set, parent_option_set+parent_length);
+ throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_base>
+ template < size_t length >
+ void cmd_line_parser_check_1<clp_base>::
+ check_one_time_options (
+ const char_type* (&option_set)[length]
+ ) const
+ {
+ size_t i = 0;
+ for (; i < length; ++i)
+ {
+ if (this->option(option_set[i]).count() > 1)
+ break;
+ }
+ if (i != length)
+ {
+ throw cmd_line_check_error(
+ EMULTIPLE_OCCURANCES,
+ option_set[i]
+ );
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_CHECk_1_
+
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h
new file mode 100644
index 000000000..7ff858e89
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h
@@ -0,0 +1,453 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_CHECk_C_
+#define DLIB_CMD_LINE_PARSER_CHECk_C_
+
+#include "cmd_line_parser_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <string>
+#include "../interfaces/cmd_line_parser_option.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+ template <
+ typename clp_check
+ >
+ class cmd_line_parser_check_c : public clp_check
+ {
+ public:
+
+ typedef typename clp_check::char_type char_type;
+ typedef typename clp_check::string_type string_type;
+
+ template <
+ typename T
+ >
+ void check_option_arg_type (
+ const string_type& option_name
+ ) const;
+
+ template <
+ typename T
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T& first,
+ const T& last
+ ) const;
+
+ template <
+ typename T,
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T (&arg_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const char_type* (&arg_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_incompatible_options (
+ const char_type* (&option_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_one_time_options (
+ const char_type* (&option_set)[length]
+ ) const;
+
+ void check_incompatible_options (
+ const string_type& option_name1,
+ const string_type& option_name2
+ ) const;
+
+ void check_sub_option (
+ const string_type& parent_option,
+ const string_type& sub_option
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const string_type& parent_option,
+ const char_type* (&sub_option_set)[length]
+ ) const;
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[length],
+ const string_type& sub_option
+ ) const;
+
+ template <
+ size_t parent_length,
+ size_t sub_length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[parent_length],
+ const char_type* (&sub_option_set)[sub_length]
+ ) const;
+ };
+
+
+ template <
+ typename clp_check
+ >
+ inline void swap (
+ cmd_line_parser_check_c<clp_check>& a,
+ cmd_line_parser_check_c<clp_check>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template <typename T>
+ void cmd_line_parser_check_c<clp_check>::
+ check_option_arg_type (
+ const string_type& option_name
+ ) const
+ {
+ COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
+ "\tvoid cmd_line_parser_check::check_option_arg_type()"
+ << "\n\tYou must have already parsed the command line and option_name must be valid."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_name: " << option_name
+ );
+
+ clp_check::template check_option_arg_type<T>(option_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template <typename T>
+ void cmd_line_parser_check_c<clp_check>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const T& first,
+ const T& last
+ ) const
+ {
+ COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name) &&
+ first <= last,
+ "\tvoid cmd_line_parser_check::check_option_arg_range()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_name: " << option_name
+ << "\n\tfirst: " << first
+ << "\n\tlast: " << last
+ );
+
+ clp_check::check_option_arg_range(option_name,first,last);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < typename T, size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const T (&arg_set)[length]
+ ) const
+ {
+ COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
+ "\tvoid cmd_line_parser_check::check_option_arg_range()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_name: " << option_name
+ );
+
+ clp_check::check_option_arg_range(option_name,arg_set);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_option_arg_range (
+ const string_type& option_name,
+ const char_type* (&arg_set)[length]
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
+ "\tvoid cmd_line_parser_check::check_option_arg_range()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_name: " << option_name
+ );
+
+ clp_check::check_option_arg_range(option_name,arg_set);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_incompatible_options (
+ const char_type* (&option_set)[length]
+ ) const
+ {
+ // make sure requires clause is not broken
+ for (size_t i = 0; i < length; ++i)
+ {
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]),
+ "\tvoid cmd_line_parser_check::check_incompatible_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_set[i]: " << option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+
+ }
+ clp_check::check_incompatible_options(option_set);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ void cmd_line_parser_check_c<clp_check>::
+ check_incompatible_options (
+ const string_type& option_name1,
+ const string_type& option_name2
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name1) &&
+ this->option_is_defined(option_name2),
+ "\tvoid cmd_line_parser_check::check_incompatible_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_name1): " << ((this->option_is_defined(option_name1))?"true":"false")
+ << "\n\toption_is_defined(option_name2): " << ((this->option_is_defined(option_name2))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_name1: " << option_name1
+ << "\n\toption_name2: " << option_name2
+ );
+
+ clp_check::check_incompatible_options(option_name1,option_name2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ void cmd_line_parser_check_c<clp_check>::
+ check_sub_option (
+ const string_type& parent_option,
+ const string_type& sub_option
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option) &&
+ this->option_is_defined(sub_option),
+ "\tvoid cmd_line_parser_check::check_sub_option()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\tparsed_line(): " << this->parsed_line()
+ << "\n\toption_is_defined(parent_option): " << this->option_is_defined(parent_option)
+ << "\n\toption_is_defined(sub_option): " << this->option_is_defined(sub_option)
+ << "\n\tparent_option: " << parent_option
+ << "\n\tsub_option: " << sub_option
+ );
+ clp_check::check_sub_option(parent_option,sub_option);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_sub_options (
+ const string_type& parent_option,
+ const char_type* (&sub_option_set)[length]
+ ) const
+ {
+ // make sure requires clause is not broken
+ for (size_t i = 0; i < length; ++i)
+ {
+ DLIB_CASSERT( this->option_is_defined(sub_option_set[i]),
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(sub_option_set[i]): "
+ << ((this->option_is_defined(sub_option_set[i]))?"true":"false")
+ << "\n\tsub_option_set[i]: " << sub_option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+
+ }
+
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option),
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(parent_option): " << ((this->option_is_defined(parent_option))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\tparent_option: " << parent_option
+ );
+ clp_check::check_sub_options(parent_option,sub_option_set);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_sub_options (
+ const char_type* (&parent_option_set)[length],
+ const string_type& sub_option
+ ) const
+ {
+ // make sure requires clause is not broken
+ for (size_t i = 0; i < length; ++i)
+ {
+ DLIB_CASSERT( this->option_is_defined(parent_option_set[i]),
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(parent_option_set[i]): "
+ << ((this->option_is_defined(parent_option_set[i]))?"true":"false")
+ << "\n\tparent_option_set[i]: " << parent_option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+
+ }
+
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(sub_option),
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(sub_option): " << ((this->option_is_defined(sub_option))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\tsub_option: " << sub_option
+ );
+ clp_check::check_sub_options(parent_option_set,sub_option);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t parent_length, size_t sub_length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_sub_options (
+ const char_type* (&parent_option_set)[parent_length],
+ const char_type* (&sub_option_set)[sub_length]
+ ) const
+ {
+ // make sure requires clause is not broken
+ for (size_t i = 0; i < sub_length; ++i)
+ {
+ DLIB_CASSERT( this->option_is_defined(sub_option_set[i]),
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(sub_option_set[i]): "
+ << ((this->option_is_defined(sub_option_set[i]))?"true":"false")
+ << "\n\tsub_option_set[i]: " << sub_option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+ }
+
+ for (size_t i = 0; i < parent_length; ++i)
+ {
+ DLIB_CASSERT( this->option_is_defined(parent_option_set[i]),
+ "\tvoid cmd_line_parser_check::check_parent_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(parent_option_set[i]): "
+ << ((this->option_is_defined(parent_option_set[i]))?"true":"false")
+ << "\n\tparent_option_set[i]: " << parent_option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+ }
+
+
+
+ DLIB_CASSERT( this->parsed_line() == true ,
+ "\tvoid cmd_line_parser_check::check_sub_options()"
+ << "\n\tYou must have parsed the command line before you call this function."
+ << "\n\tthis: " << this
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ );
+
+ clp_check::check_sub_options(parent_option_set,sub_option_set);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename clp_check>
+ template < size_t length >
+ void cmd_line_parser_check_c<clp_check>::
+ check_one_time_options (
+ const char_type* (&option_set)[length]
+ ) const
+ {
+ // make sure requires clause is not broken
+ for (size_t i = 0; i < length; ++i)
+ {
+ DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]),
+ "\tvoid cmd_line_parser_check::check_one_time_options()"
+ << "\n\tSee the requires clause for this function."
+ << "\n\tthis: " << this
+ << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false")
+ << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
+ << "\n\toption_set[i]: " << option_set[i]
+ << "\n\ti: " << static_cast<unsigned long>(i)
+ );
+
+ }
+ clp_check::check_one_time_options(option_set);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_CHECk_C_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h
new file mode 100644
index 000000000..68ea5a135
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h
@@ -0,0 +1,799 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_KERNEl_1_
+#define DLIB_CMD_LINE_PARSER_KERNEl_1_
+
+#include "cmd_line_parser_kernel_abstract.h"
+#include "../algs.h"
+#include <string>
+#include <sstream>
+#include "../interfaces/enumerable.h"
+#include "../interfaces/cmd_line_parser_option.h"
+#include "../assert.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ class cmd_line_parser_kernel_1 : public enumerable<cmd_line_parser_option<charT> >
+ {
+ /*!
+ REQUIREMENTS ON map
+ is an implementation of map/map_kernel_abstract.h
+ is instantiated to map items of type std::basic_string<charT> to void*
+
+ REQUIREMENTS ON sequence
+ is an implementation of sequence/sequence_kernel_abstract.h and
+ is instantiated with std::basic_string<charT>
+
+ REQUIREMENTS ON sequence2
+ is an implementation of sequence/sequence_kernel_abstract.h and
+ is instantiated with std::basic_string<charT>*
+
+ INITIAL VALUE
+ options.size() == 0
+ argv.size() == 0
+ have_parsed_line == false
+
+ CONVENTION
+ have_parsed_line == parsed_line()
+ argv[index] == operator[](index)
+ argv.size() == number_of_arguments()
+ *((option_t*)options[name]) == option(name)
+ options.is_in_domain(name) == option_is_defined(name)
+ !*/
+
+
+
+
+ public:
+
+ typedef charT char_type;
+ typedef std::basic_string<charT> string_type;
+ typedef cmd_line_parser_option<charT> option_type;
+
+ // exception class
+ class cmd_line_parse_error : public dlib::error
+ {
+ void set_info_string (
+ )
+ {
+ std::ostringstream sout;
+ switch (type)
+ {
+ case EINVALID_OPTION:
+ sout << "Command line error: '" << narrow(item) << "' is not a valid option.";
+ break;
+ case ETOO_FEW_ARGS:
+ if (num > 1)
+ {
+ sout << "Command line error: The '" << narrow(item) << "' option requires " << num
+ << " arguments.";
+ }
+ else
+ {
+ sout << "Command line error: The '" << narrow(item) << "' option requires " << num
+ << " argument.";
+ }
+ break;
+ case ETOO_MANY_ARGS:
+ sout << "Command line error: The '" << narrow(item) << "' option does not take any arguments.\n";
+ break;
+ default:
+ sout << "Command line error.";
+ break;
+ }
+ const_cast<std::string&>(info) = wrap_string(sout.str(),0,0);
+ }
+
+ public:
+ cmd_line_parse_error(
+ error_type t,
+ const std::basic_string<charT>& _item
+ ) :
+ dlib::error(t),
+ item(_item),
+ num(0)
+ { set_info_string();}
+
+ cmd_line_parse_error(
+ error_type t,
+ const std::basic_string<charT>& _item,
+ unsigned long _num
+ ) :
+ dlib::error(t),
+ item(_item),
+ num(_num)
+ { set_info_string();}
+
+ cmd_line_parse_error(
+ ) :
+ dlib::error(),
+ item(),
+ num(0)
+ { set_info_string();}
+
+ ~cmd_line_parse_error() throw() {}
+
+ const std::basic_string<charT> item;
+ const unsigned long num;
+ };
+
+
+ private:
+
+ class option_t : public cmd_line_parser_option<charT>
+ {
+ /*!
+ INITIAL VALUE
+ options.size() == 0
+
+ CONVENTION
+ name_ == name()
+ description_ == description()
+ number_of_arguments_ == number_of_arguments()
+ options[N][arg] == argument(arg,N)
+ num_present == count()
+ !*/
+
+ friend class cmd_line_parser_kernel_1<charT,map,sequence,sequence2>;
+
+ public:
+
+ const std::basic_string<charT>& name (
+ ) const { return name_; }
+
+ const std::basic_string<charT>& group_name (
+ ) const { return group_name_; }
+
+ const std::basic_string<charT>& description (
+ ) const { return description_; }
+
+ unsigned long number_of_arguments(
+ ) const { return number_of_arguments_; }
+
+ unsigned long count (
+ ) const { return num_present; }
+
+ const std::basic_string<charT>& argument (
+ unsigned long arg,
+ unsigned long N
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( N < count() && arg < number_of_arguments(),
+ "\tconst string_type& cmd_line_parser_option::argument(unsigned long,unsigned long)"
+ << "\n\tInvalid arguments were given to this function."
+ << "\n\tthis: " << this
+ << "\n\tN: " << N
+ << "\n\targ: " << arg
+ << "\n\tname(): " << narrow(name())
+ << "\n\tcount(): " << count()
+ << "\n\tnumber_of_arguments(): " << number_of_arguments()
+ );
+
+ return options[N][arg];
+ }
+
+ protected:
+
+ option_t (
+ ) :
+ num_present(0)
+ {}
+
+ ~option_t()
+ {
+ clear();
+ }
+
+ private:
+
+ void clear()
+ /*!
+ ensures
+ - #count() == 0
+ - clears everything out of options and frees memory
+ !*/
+ {
+ for (unsigned long i = 0; i < options.size(); ++i)
+ {
+ delete [] options[i];
+ }
+ options.clear();
+ num_present = 0;
+ }
+
+ // data members
+ std::basic_string<charT> name_;
+ std::basic_string<charT> group_name_;
+ std::basic_string<charT> description_;
+ sequence2 options;
+ unsigned long number_of_arguments_;
+ unsigned long num_present;
+
+
+
+ // restricted functions
+ option_t(option_t&); // copy constructor
+ option_t& operator=(option_t&); // assignment operator
+ };
+
+ // --------------------------
+
+ public:
+
+ cmd_line_parser_kernel_1 (
+ );
+
+ virtual ~cmd_line_parser_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ void parse (
+ int argc,
+ const charT** argv
+ );
+
+ void parse (
+ int argc,
+ charT** argv
+ )
+ {
+ parse(argc, const_cast<const charT**>(argv));
+ }
+
+ bool parsed_line(
+ ) const;
+
+ bool option_is_defined (
+ const string_type& name
+ ) const;
+
+ void add_option (
+ const string_type& name,
+ const string_type& description,
+ unsigned long number_of_arguments = 0
+ );
+
+ void set_group_name (
+ const string_type& group_name
+ );
+
+ string_type get_group_name (
+ ) const { return group_name; }
+
+ const cmd_line_parser_option<charT>& option (
+ const string_type& name
+ ) const;
+
+ unsigned long number_of_arguments(
+ ) const;
+
+ const string_type& operator[] (
+ unsigned long index
+ ) const;
+
+ void swap (
+ cmd_line_parser_kernel_1& item
+ );
+
+ // functions from the enumerable interface
+ bool at_start (
+ ) const { return options.at_start(); }
+
+ void reset (
+ ) const { options.reset(); }
+
+ bool current_element_valid (
+ ) const { return options.current_element_valid(); }
+
+ const cmd_line_parser_option<charT>& element (
+ ) const { return *static_cast<cmd_line_parser_option<charT>*>(options.element().value()); }
+
+ cmd_line_parser_option<charT>& element (
+ ) { return *static_cast<cmd_line_parser_option<charT>*>(options.element().value()); }
+
+ bool move_next (
+ ) const { return options.move_next(); }
+
+ size_t size (
+ ) const { return options.size(); }
+
+ private:
+
+ // data members
+ map options;
+ sequence argv;
+ bool have_parsed_line;
+ string_type group_name;
+
+ // restricted functions
+ cmd_line_parser_kernel_1(cmd_line_parser_kernel_1&); // copy constructor
+ cmd_line_parser_kernel_1& operator=(cmd_line_parser_kernel_1&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ inline void swap (
+ cmd_line_parser_kernel_1<charT,map,sequence,sequence2>& a,
+ cmd_line_parser_kernel_1<charT,map,sequence,sequence2>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ cmd_line_parser_kernel_1 (
+ ) :
+ have_parsed_line(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ ~cmd_line_parser_kernel_1 (
+ )
+ {
+ // delete all option_t objects in options
+ options.reset();
+ while (options.move_next())
+ {
+ delete static_cast<option_t*>(options.element().value());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ void cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ clear(
+ )
+ {
+ have_parsed_line = false;
+ argv.clear();
+
+
+ // delete all option_t objects in options
+ options.reset();
+ while (options.move_next())
+ {
+ delete static_cast<option_t*>(options.element().value());
+ }
+ options.clear();
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ void cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ parse (
+ int argc_,
+ const charT** argv
+ )
+ {
+ using namespace std;
+
+ // make sure there aren't any arguments hanging around from the last time
+ // parse was called
+ this->argv.clear();
+
+ // make sure that the options have been cleared of any arguments since
+ // the last time parse() was called
+ if (have_parsed_line)
+ {
+ options.reset();
+ while (options.move_next())
+ {
+ static_cast<option_t*>(options.element().value())->clear();
+ }
+ options.reset();
+ }
+
+ // this tells us if we have seen -- on the command line all by itself
+ // or not.
+ bool escape = false;
+
+ const unsigned long argc = static_cast<unsigned long>(argc_);
+ try
+ {
+
+ for (unsigned long i = 1; i < argc; ++i)
+ {
+ if (argv[i][0] == _dT(charT,'-') && !escape)
+ {
+ // we are looking at the start of an option
+
+ // --------------------------------------------------------------------
+ if (argv[i][1] == _dT(charT,'-'))
+ {
+ // we are looking at the start of a "long named" option
+ string_type temp = &argv[i][2];
+ string_type first_argument;
+ typename string_type::size_type pos = temp.find_first_of(_dT(charT,'='));
+ // This variable will be 1 if there is an argument supplied via the = sign
+ // and 0 otherwise.
+ unsigned long extra_argument = 0;
+ if (pos != string_type::npos)
+ {
+ // there should be an extra argument
+ extra_argument = 1;
+ first_argument = temp.substr(pos+1);
+ temp = temp.substr(0,pos);
+ }
+
+ // make sure this name is defined
+ if (!options.is_in_domain(temp))
+ {
+ // the long name is not a valid option
+ if (argv[i][2] == _dT(charT,'\0'))
+ {
+ // there was nothing after the -- on the command line
+ escape = true;
+ continue;
+ }
+ else
+ {
+ // there was something after the command line but it
+ // wasn't a valid option
+ throw cmd_line_parse_error(EINVALID_OPTION,temp);
+ }
+ }
+
+
+ option_t* o = static_cast<option_t*>(options[temp]);
+
+ // check the number of arguments after this option and make sure
+ // it is correct
+ if (argc + extra_argument <= o->number_of_arguments() + i)
+ {
+ // there are too few arguments
+ throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments());
+ }
+ if (extra_argument && first_argument.size() == 0 )
+ {
+ // if there would be exactly the right number of arguments if
+ // the first_argument wasn't empty
+ if (argc == o->number_of_arguments() + i)
+ throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments());
+ else
+ {
+ // in this case we just ignore the trailing = and parse everything
+ // the same.
+ extra_argument = 0;
+ }
+ }
+ // you can't force an option that doesn't have any arguments to take
+ // one by using the --option=arg syntax
+ if (extra_argument == 1 && o->number_of_arguments() == 0)
+ {
+ throw cmd_line_parse_error(ETOO_MANY_ARGS,temp);
+ }
+
+
+
+
+
+
+ // at this point we know that the option is ok and we should
+ // populate its options object
+ if (o->number_of_arguments() > 0)
+ {
+
+ string_type* stemp = new string_type[o->number_of_arguments()];
+ unsigned long j = 0;
+
+ // add the argument after the = sign if one is present
+ if (extra_argument)
+ {
+ stemp[0] = first_argument;
+ ++j;
+ }
+
+ for (; j < o->number_of_arguments(); ++j)
+ {
+ stemp[j] = argv[i+j+1-extra_argument];
+ }
+ o->options.add(o->options.size(),stemp);
+ }
+ o->num_present += 1;
+
+
+ // adjust the value of i to account for the arguments to
+ // this option
+ i += o->number_of_arguments() - extra_argument;
+ }
+ // --------------------------------------------------------------------
+ else
+ {
+ // we are looking at the start of a list of a single char options
+
+ // make sure there is something in this string other than -
+ if (argv[i][1] == _dT(charT,'\0'))
+ {
+ throw cmd_line_parse_error();
+ }
+
+ string_type temp = &argv[i][1];
+ const typename string_type::size_type num = temp.size();
+ for (unsigned long k = 0; k < num; ++k)
+ {
+ string_type name;
+ // Doing this instead of name = temp[k] seems to avoid a bug in g++ (Ubuntu/Linaro 4.5.2-8ubuntu4) 4.5.2
+ // which results in name[0] having the wrong value.
+ name.resize(1);
+ name[0] = temp[k];
+
+
+ // make sure this name is defined
+ if (!options.is_in_domain(name))
+ {
+ // the name is not a valid option
+ throw cmd_line_parse_error(EINVALID_OPTION,name);
+ }
+
+ option_t* o = static_cast<option_t*>(options[name]);
+
+ // if there are chars immediately following this option
+ int delta = 0;
+ if (num != k+1)
+ {
+ delta = 1;
+ }
+
+ // check the number of arguments after this option and make sure
+ // it is correct
+ if (argc + delta <= o->number_of_arguments() + i)
+ {
+ // there are too few arguments
+ std::ostringstream sout;
+ throw cmd_line_parse_error(ETOO_FEW_ARGS,name,o->number_of_arguments());
+ }
+
+
+ o->num_present += 1;
+
+ // at this point we know that the option is ok and we should
+ // populate its options object
+ if (o->number_of_arguments() > 0)
+ {
+ string_type* stemp = new string_type[o->number_of_arguments()];
+ if (delta == 1)
+ {
+ temp = &argv[i][2+k];
+ k = (unsigned long)num; // this ensures that the argument to this
+ // option isn't going to be treated as a
+ // list of options
+
+ stemp[0] = temp;
+ }
+ for (unsigned long j = 0; j < o->number_of_arguments()-delta; ++j)
+ {
+ stemp[j+delta] = argv[i+j+1];
+ }
+ o->options.add(o->options.size(),stemp);
+
+ // adjust the value of i to account for the arguments to
+ // this option
+ i += o->number_of_arguments()-delta;
+ }
+ } // for (unsigned long k = 0; k < num; ++k)
+ }
+ // --------------------------------------------------------------------
+
+ }
+ else
+ {
+ // this is just a normal argument
+ string_type temp = argv[i];
+ this->argv.add(this->argv.size(),temp);
+ }
+
+ }
+ have_parsed_line = true;
+
+ }
+ catch (...)
+ {
+ have_parsed_line = false;
+
+ // clear all the option objects
+ options.reset();
+ while (options.move_next())
+ {
+ static_cast<option_t*>(options.element().value())->clear();
+ }
+ options.reset();
+
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ bool cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ parsed_line(
+ ) const
+ {
+ return have_parsed_line;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ bool cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ option_is_defined (
+ const string_type& name
+ ) const
+ {
+ return options.is_in_domain(name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ void cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ set_group_name (
+ const string_type& group_name_
+ )
+ {
+ group_name = group_name_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ void cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ add_option (
+ const string_type& name,
+ const string_type& description,
+ unsigned long number_of_arguments
+ )
+ {
+ option_t* temp = new option_t;
+ try
+ {
+ temp->name_ = name;
+ temp->group_name_ = group_name;
+ temp->description_ = description;
+ temp->number_of_arguments_ = number_of_arguments;
+ void* t = temp;
+ string_type n(name);
+ options.add(n,t);
+ }catch (...) { delete temp; throw;}
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ const cmd_line_parser_option<charT>& cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ option (
+ const string_type& name
+ ) const
+ {
+ return *static_cast<cmd_line_parser_option<charT>*>(options[name]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ unsigned long cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ number_of_arguments(
+ ) const
+ {
+ return argv.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ const std::basic_string<charT>& cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ operator[] (
+ unsigned long index
+ ) const
+ {
+ return argv[index];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename map,
+ typename sequence,
+ typename sequence2
+ >
+ void cmd_line_parser_kernel_1<charT,map,sequence,sequence2>::
+ swap (
+ cmd_line_parser_kernel_1<charT,map,sequence,sequence2>& item
+ )
+ {
+ options.swap(item.options);
+ argv.swap(item.argv);
+ exchange(have_parsed_line,item.have_parsed_line);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h
new file mode 100644
index 000000000..8461ffb26
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h
@@ -0,0 +1,673 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_
+#ifdef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <string>
+#include "../interfaces/enumerable.h"
+#include "../interfaces/cmd_line_parser_option.h"
+#include <vector>
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename charT
+ >
+ class cmd_line_parser : public enumerable<cmd_line_parser_option<charT> >
+ {
+ /*!
+ REQUIREMENTS ON charT
+ Must be an integral type suitable for storing characters. (e.g. char
+ or wchar_t)
+
+ INITIAL VALUE
+ - parsed_line() == false
+ - option_is_defined(x) == false, for all values of x
+ - get_group_name() == ""
+
+ ENUMERATION ORDER
+ The enumerator will enumerate over all the options defined in *this
+ in alphebetical order according to the name of the option.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ parsed_line(), option_is_defined(), option(), number_of_arguments(),
+ operator[](), and swap() functions do not invalidate pointers or
+ references to internal data. All other functions have no such guarantee.
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a command line parser.
+ The command lines must match the following BNF.
+
+ command_line ::= <program_name> { <options> | <arg> } [ -- {<word>} ]
+ program_name ::= <word>
+ arg ::= any <word> that does not start with -
+ option_arg ::= <sword>
+ option_name ::= <char>
+ long_option_name ::= <char> {<char> | - }
+ options ::= <bword> - <option_name> {<option_name>} {<option_arg>} |
+ <bword> -- <long_option_name> [=<option_arg>] {<bword> <option_arg>}
+ char ::= any character other than - or =
+ word ::= any string from argv where argv is the second
+ parameter to main()
+ sword ::= any suffix of a string from argv where argv is the
+ second parameter to main()
+ bword ::= This is an empty string which denotes the begining of a
+ <word>.
+
+
+ Options with arguments:
+ An option with N arguments will consider the next N swords to be
+ its arguments.
+
+ so for example, if we have an option o that expects 2 arguments
+ then the following are a few legal examples:
+
+ program -o arg1 arg2 general_argument
+ program -oarg1 arg2 general_argument
+
+ arg1 and arg2 are associated with the option o and general_argument
+ is not.
+
+ Arguments not associated with an option:
+ An argument that is not associated with an option is considered a
+ general command line argument and is indexed by operator[] defined
+ by the cmd_line_parser object. Additionally, if the string
+ "--" appears in the command line all by itself then all words
+ following it are considered to be general command line arguments.
+
+
+ Consider the following two examples involving a command line and
+ a cmd_line_parser object called parser.
+
+ Example 1:
+ command line: program general_arg1 -o arg1 arg2 general_arg2
+ Then the following is true (assuming the o option is defined
+ and takes 2 arguments).
+
+ parser[0] == "general_arg1"
+ parser[1] == "general_arg2"
+ parser.number_of_arguments() == 2
+ parser.option("o").argument(0) == "arg1"
+ parser.option("o").argument(1) == "arg2"
+ parser.option("o").count() == 1
+
+ Example 2:
+ command line: program general_arg1 -- -o arg1 arg2 general_arg2
+ Then the following is true (the -- causes everything following
+ it to be treated as a general argument).
+
+ parser[0] == "general_arg1"
+ parser[1] == "-o"
+ parser[2] == "arg1"
+ parser[3] == "arg2"
+ parser[4] == "general_arg2"
+ parser.number_of_arguments() == 5
+ parser.option("o").count() == 0
+ !*/
+
+ public:
+
+ typedef charT char_type;
+ typedef std::basic_string<charT> string_type;
+ typedef cmd_line_parser_option<charT> option_type;
+
+ // exception class
+ class cmd_line_parse_error : public dlib::error
+ {
+ /*!
+ GENERAL
+ This exception is thrown if there is an error detected in a
+ command line while it is being parsed. You can consult this
+ object's type and item members to determine the nature of the
+ error. (note that the type member is inherited from dlib::error).
+
+ INTERPRETING THIS EXCEPTION
+ - if (type == EINVALID_OPTION) then
+ - There was an undefined option on the command line
+ - item == The invalid option that was on the command line
+ - if (type == ETOO_FEW_ARGS) then
+ - An option was given on the command line but it was not
+ supplied with the required number of arguments.
+ - item == The name of this option.
+ - num == The number of arguments expected by this option.
+ - if (type == ETOO_MANY_ARGS) then
+ - An option was given on the command line such as --option=arg
+ but this option doesn't take any arguments.
+ - item == The name of this option.
+ !*/
+ public:
+ const std::basic_string<charT> item;
+ const unsigned long num;
+ };
+
+ // --------------------------
+
+ cmd_line_parser (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~cmd_line_parser (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void parse (
+ int argc,
+ const charT** argv
+ );
+ /*!
+ requires
+ - argv == an array of strings that was obtained from the second argument
+ of the function main().
+ (i.e. argv[0] should be the <program> token, argv[1] should be
+ an <options> or <arg> token, etc.)
+ - argc == the number of strings in argv
+ ensures
+ - parses the command line given by argc and argv
+ - #parsed_line() == true
+ - #at_start() == true
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable until clear()
+ is called successfully
+ - cmd_line_parse_error
+ This exception is thrown if there is an error parsing the command line.
+ If this exception is thrown then #parsed_line() == false and all
+ options will have their count() set to 0 but otherwise there will
+ be no effect (i.e. all registered options will remain registered).
+ !*/
+
+ void parse (
+ int argc,
+ charT** argv
+ );
+ /*!
+ This just calls this->parse(argc,argv) and performs the necessary const_cast
+ on argv.
+ !*/
+
+ bool parsed_line(
+ ) const;
+ /*!
+ ensures
+ - returns true if parse() has been called successfully
+ - returns false otherwise
+ !*/
+
+ bool option_is_defined (
+ const string_type& name
+ ) const;
+ /*!
+ ensures
+ - returns true if the option has been added to the parser object
+ by calling add_option(name).
+ - returns false otherwise
+ !*/
+
+ void add_option (
+ const string_type& name,
+ const string_type& description,
+ unsigned long number_of_arguments = 0
+ );
+ /*!
+ requires
+ - parsed_line() == false
+ - option_is_defined(name) == false
+ - name does not contain any ' ', '\t', '\n', or '=' characters
+ - name[0] != '-'
+ - name.size() > 0
+ ensures
+ - #option_is_defined(name) == true
+ - #at_start() == true
+ - #option(name).count() == 0
+ - #option(name).description() == description
+ - #option(name).number_of_arguments() == number_of_arguments
+ - #option(name).group_name() == get_group_name()
+ throws
+ - std::bad_alloc
+ if this exception is thrown then the add_option() function has no
+ effect
+ !*/
+
+ const option_type& option (
+ const string_type& name
+ ) const;
+ /*!
+ requires
+ - option_is_defined(name) == true
+ ensures
+ - returns the option specified by name
+ !*/
+
+ unsigned long number_of_arguments(
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ ensures
+ - returns the number of arguments present in the command line.
+ This count does not include options or their arguments. Only
+ arguments unrelated to any option are counted.
+ !*/
+
+ const string_type& operator[] (
+ unsigned long N
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - N < number_of_arguments()
+ ensures
+ - returns the Nth command line argument
+ !*/
+
+ void swap (
+ cmd_line_parser& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ void print_options (
+ std::basic_ostream<char_type>& out
+ ) const;
+ /*!
+ ensures
+ - prints all the command line options to out.
+ - #at_start() == true
+ throws
+ - any exception.
+ if an exception is thrown then #at_start() == true but otherwise
+ it will have no effect on the state of #*this.
+ !*/
+
+ void print_options (
+ ) const;
+ /*!
+ ensures
+ - prints all the command line options to cout.
+ - #at_start() == true
+ throws
+ - any exception.
+ if an exception is thrown then #at_start() == true but otherwise
+ it will have no effect on the state of #*this.
+ !*/
+
+ string_type get_group_name (
+ ) const;
+ /*!
+ ensures
+ - returns the current group name. This is the group new options will be
+ added into when added via add_option().
+ - The group name of an option is used by print_options(). In particular,
+ it groups all options with the same group name together and displays them
+ under a title containing the text of the group name. This allows you to
+ group similar options together in the output of print_options().
+ - A group name of "" (i.e. the empty string) means that no group name is
+ set.
+ !*/
+
+ void set_group_name (
+ const string_type& group_name
+ );
+ /*!
+ ensures
+ - #get_group_name() == group_name
+ !*/
+
+ // -------------------------------------------------------------
+ // Input Validation Tools
+ // -------------------------------------------------------------
+
+ class cmd_line_check_error : public dlib::error
+ {
+ /*!
+ This is the exception thrown by the check_*() routines if they find a
+ command line error. The interpretation of the member variables is defined
+ below in each check_*() routine.
+ !*/
+
+ public:
+ const string_type opt;
+ const string_type opt2;
+ const string_type arg;
+ const std::vector<string_type> required_opts;
+ };
+
+ template <
+ typename T
+ >
+ void check_option_arg_type (
+ const string_type& option_name
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(option_name) == true
+ - T is not a pointer type
+ ensures
+ - all the arguments for the given option are convertible
+ by string_cast<T>() to an object of type T.
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINVALID_OPTION_ARG
+ - opt == option_name
+ - arg == the text of the offending argument
+ !*/
+
+ template <
+ typename T
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T& first,
+ const T& last
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(option_name) == true
+ - first <= last
+ - T is not a pointer type
+ ensures
+ - all the arguments for the given option are convertible
+ by string_cast<T>() to an object of type T and the resulting value is
+ in the range first to last inclusive.
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINVALID_OPTION_ARG
+ - opt == option_name
+ - arg == the text of the offending argument
+ !*/
+
+ template <
+ typename T,
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const T (&arg_set)[length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(option_name) == true
+ - T is not a pointer type
+ ensures
+ - for each argument to the given option:
+ - this argument is convertible by string_cast<T>() to an object of
+ type T and the resulting value is equal to some element in the
+ arg_set array.
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINVALID_OPTION_ARG
+ - opt == option_name
+ - arg == the text of the offending argument
+ !*/
+
+ template <
+ size_t length
+ >
+ void check_option_arg_range (
+ const string_type& option_name,
+ const char_type* (&arg_set)[length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(option_name) == true
+ ensures
+ - for each argument to the given option:
+ - there is a string in the arg_set array that is equal to this argument.
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINVALID_OPTION_ARG
+ - opt == option_name
+ - arg == the text of the offending argument
+ !*/
+
+ template <
+ size_t length
+ >
+ void check_one_time_options (
+ const char_type* (&option_set)[length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - for all valid i:
+ - option_is_defined(option_set[i]) == true
+ ensures
+ - all the options in the option_set array occur at most once on the
+ command line.
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EMULTIPLE_OCCURANCES
+ - opt == the option that occurred more than once on the command line.
+ !*/
+
+ void check_incompatible_options (
+ const string_type& option_name1,
+ const string_type& option_name2
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(option_name1) == true
+ - option_is_defined(option_name2) == true
+ ensures
+ - option(option_name1).count() == 0 || option(option_name2).count() == 0
+ (i.e. at most, only one of the options is currently present)
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINCOMPATIBLE_OPTIONS
+ - opt == option_name1
+ - opt2 == option_name2
+ !*/
+
+ template <
+ size_t length
+ >
+ void check_incompatible_options (
+ const char_type* (&option_set)[length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - for all valid i:
+ - option_is_defined(option_set[i]) == true
+ ensures
+ - At most only one of the options in the array option_set has a count()
+ greater than 0. (i.e. at most, only one of the options is currently present)
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EINCOMPATIBLE_OPTIONS
+ - opt == One of the incompatible options found.
+ - opt2 == The next incompatible option found.
+ !*/
+
+ void check_sub_option (
+ const string_type& parent_option,
+ const string_type& sub_option
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(parent_option) == true
+ - option_is_defined(sub_option) == true
+ ensures
+ - if (option(parent_option).count() == 0) then
+ - option(sub_option).count() == 0
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EMISSING_REQUIRED_OPTION
+ - opt == sub_option.
+ - required_opts == a vector that contains only parent_option.
+ !*/
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[length],
+ const string_type& sub_option
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(sub_option) == true
+ - for all valid i:
+ - option_is_defined(parent_option_set[i] == true
+ ensures
+ - if (option(sub_option).count() > 0) then
+ - At least one of the options in the array parent_option_set has a count()
+ greater than 0. (i.e. at least one of the options in parent_option_set
+ is currently present)
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EMISSING_REQUIRED_OPTION
+ - opt == the first option from the sub_option that is present.
+ - required_opts == a vector containing everything from parent_option_set.
+ !*/
+
+ template <
+ size_t length
+ >
+ void check_sub_options (
+ const string_type& parent_option,
+ const char_type* (&sub_option_set)[length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - option_is_defined(parent_option) == true
+ - for all valid i:
+ - option_is_defined(sub_option_set[i]) == true
+ ensures
+ - if (option(parent_option).count() == 0) then
+ - for all valid i:
+ - option(sub_option_set[i]).count() == 0
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EMISSING_REQUIRED_OPTION
+ - opt == the first option from the sub_option_set that is present.
+ - required_opts == a vector that contains only parent_option.
+ !*/
+
+ template <
+ size_t parent_length,
+ size_t sub_length
+ >
+ void check_sub_options (
+ const char_type* (&parent_option_set)[parent_length],
+ const char_type* (&sub_option_set)[sub_length]
+ ) const;
+ /*!
+ requires
+ - parsed_line() == true
+ - for all valid i:
+ - option_is_defined(parent_option_set[i] == true
+ - for all valid j:
+ - option_is_defined(sub_option_set[j]) == true
+ ensures
+ - for all valid j:
+ - if (option(sub_option_set[j]).count() > 0) then
+ - At least one of the options in the array parent_option_set has a count()
+ greater than 0. (i.e. at least one of the options in parent_option_set
+ is currently present)
+ throws
+ - std::bad_alloc
+ - cmd_line_check_error
+ This exception is thrown if the ensures clause could not be satisfied.
+ The exception's members will be set as follows:
+ - type == EMISSING_REQUIRED_OPTION
+ - opt == the first option from the sub_option_set that is present.
+ - required_opts == a vector containing everything from parent_option_set.
+ !*/
+
+
+ private:
+
+ // restricted functions
+ cmd_line_parser(cmd_line_parser&); // copy constructor
+ cmd_line_parser& operator=(cmd_line_parser&); // assignment operator
+
+ };
+
+// -----------------------------------------------------------------------------------------
+
+ typedef cmd_line_parser<char> command_line_parser;
+ typedef cmd_line_parser<wchar_t> wcommand_line_parser;
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename charT
+ >
+ inline void swap (
+ cmd_line_parser<charT>& a,
+ cmd_line_parser<charT>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h
new file mode 100644
index 000000000..e80543018
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h
@@ -0,0 +1,203 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_KERNEl_C_
+#define DLIB_CMD_LINE_PARSER_KERNEl_C_
+
+#include "cmd_line_parser_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <string>
+#include "../interfaces/cmd_line_parser_option.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+ template <
+ typename clp_base
+ >
+ class cmd_line_parser_kernel_c : public clp_base
+ {
+ public:
+
+ typedef typename clp_base::char_type char_type;
+ typedef typename clp_base::string_type string_type;
+ typedef typename clp_base::option_type option_type;
+
+ void add_option (
+ const string_type& name,
+ const string_type& description,
+ unsigned long number_of_arguments = 0
+ );
+
+ const option_type& option (
+ const string_type& name
+ ) const;
+
+ unsigned long number_of_arguments(
+ ) const;
+
+ const option_type& element (
+ ) const;
+
+ option_type& element (
+ );
+
+ const string_type& operator[] (
+ unsigned long N
+ ) const;
+
+ };
+
+
+ template <
+ typename clp_base
+ >
+ inline void swap (
+ cmd_line_parser_kernel_c<clp_base>& a,
+ cmd_line_parser_kernel_c<clp_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ const typename clp_base::string_type& cmd_line_parser_kernel_c<clp_base>::
+ operator[] (
+ unsigned long N
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true && N < number_of_arguments(),
+ "\tvoid cmd_line_parser::operator[](unsigned long N)"
+ << "\n\tYou must specify a valid index N and the parser must have run already."
+ << "\n\tthis: " << this
+ << "\n\tN: " << N
+ << "\n\tparsed_line(): " << this->parsed_line()
+ << "\n\tnumber_of_arguments(): " << number_of_arguments()
+ );
+
+ return clp_base::operator[](N);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ void cmd_line_parser_kernel_c<clp_base>::
+ add_option (
+ const string_type& name,
+ const string_type& description,
+ unsigned long number_of_arguments
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == false &&
+ name.size() > 0 &&
+ this->option_is_defined(name) == false &&
+ name.find_first_of(_dT(char_type," \t\n=")) == string_type::npos &&
+ name[0] != '-',
+ "\tvoid cmd_line_parser::add_option(const string_type&,const string_type&,unsigned long)"
+ << "\n\tsee the requires clause of add_option()"
+ << "\n\tthis: " << this
+ << "\n\tname.size(): " << static_cast<unsigned long>(name.size())
+ << "\n\tname: \"" << narrow(name) << "\""
+ << "\n\tparsed_line(): " << (this->parsed_line()? "true" : "false")
+ << "\n\tis_option_defined(\"" << narrow(name) << "\"): " << (this->option_is_defined(name)? "true" : "false")
+ );
+
+ clp_base::add_option(name,description,number_of_arguments);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ const typename clp_base::option_type& cmd_line_parser_kernel_c<clp_base>::
+ option (
+ const string_type& name
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->option_is_defined(name) == true,
+ "\toption cmd_line_parser::option(const string_type&)"
+ << "\n\tto get an option it must be defined by a call to add_option()"
+ << "\n\tthis: " << this
+ << "\n\tname: \"" << narrow(name) << "\""
+ );
+
+ return clp_base::option(name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ unsigned long cmd_line_parser_kernel_c<clp_base>::
+ number_of_arguments(
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->parsed_line() == true ,
+ "\tunsigned long cmd_line_parser::number_of_arguments()"
+ << "\n\tyou must parse the command line before you can find out how many arguments it has"
+ << "\n\tthis: " << this
+ );
+
+ return clp_base::number_of_arguments();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ const typename clp_base::option_type& cmd_line_parser_kernel_c<clp_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst cmd_line_parser_option& cmd_line_parser::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return clp_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ typename clp_base::option_type& cmd_line_parser_kernel_c<clp_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tcmd_line_parser_option& cmd_line_parser::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return clp_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h
new file mode 100644
index 000000000..3f52c842f
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h
@@ -0,0 +1,205 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_PRINt_1_
+#define DLIB_CMD_LINE_PARSER_PRINt_1_
+
+#include "cmd_line_parser_kernel_abstract.h"
+#include "../algs.h"
+#include "../string.h"
+#include <iostream>
+#include <string>
+#include <sstream>
+#include <map>
+#include <memory>
+
+namespace dlib
+{
+
+ template <
+ typename clp_base
+ >
+ class cmd_line_parser_print_1 : public clp_base
+ {
+
+ public:
+
+ void print_options (
+ std::basic_ostream<typename clp_base::char_type>& out
+ ) const;
+
+ void print_options (
+ ) const
+ {
+ print_options(std::cout);
+ }
+
+ };
+
+ template <
+ typename clp_base
+ >
+ inline void swap (
+ cmd_line_parser_print_1<clp_base>& a,
+ cmd_line_parser_print_1<clp_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename clp_base
+ >
+ void cmd_line_parser_print_1<clp_base>::
+ print_options (
+ std::basic_ostream<typename clp_base::char_type>& out
+ ) const
+ {
+ typedef typename clp_base::char_type ct;
+ typedef std::basic_string<ct> string;
+ typedef typename string::size_type size_type;
+
+ typedef std::basic_ostringstream<ct> ostringstream;
+
+ try
+ {
+
+
+ size_type max_len = 0;
+ this->reset();
+
+ // this loop here is just the bottom loop but without the print statements.
+ // I'm doing this to figure out what len should be.
+ while (this->move_next())
+ {
+ size_type len = 0;
+ len += 3;
+ if (this->element().name().size() > 1)
+ {
+ ++len;
+ }
+ len += this->element().name().size();
+
+ if (this->element().number_of_arguments() == 1)
+ {
+ len += 6;
+ }
+ else
+ {
+ for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i)
+ {
+ len += 7;
+ if (i+1 > 9)
+ ++len;
+ }
+ }
+
+ len += 3;
+ if (len < 33)
+ max_len = std::max(max_len,len);
+ }
+
+
+ // Make a separate ostringstream for each option group. We are going to write
+ // the output for each group to a separate ostringstream so that we can keep
+ // them grouped together in the final output.
+ std::map<string,std::shared_ptr<ostringstream> > groups;
+ this->reset();
+ while(this->move_next())
+ {
+ if (!groups[this->element().group_name()])
+ groups[this->element().group_name()].reset(new ostringstream);
+ }
+
+
+
+
+ this->reset();
+
+ while (this->move_next())
+ {
+ ostringstream& sout = *groups[this->element().group_name()];
+
+ size_type len = 0;
+ sout << _dT(ct,"\n -");
+ len += 3;
+ if (this->element().name().size() > 1)
+ {
+ sout << _dT(ct,"-");
+ ++len;
+ }
+ sout << this->element().name();
+ len += this->element().name().size();
+
+ if (this->element().number_of_arguments() == 1)
+ {
+ sout << _dT(ct," <arg>");
+ len += 6;
+ }
+ else
+ {
+ for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i)
+ {
+ sout << _dT(ct," <arg") << i+1 << _dT(ct,">");
+ len += 7;
+ if (i+1 > 9)
+ ++len;
+ }
+ }
+
+ sout << _dT(ct," ");
+ len += 3;
+
+ while (len < max_len)
+ {
+ ++len;
+ sout << _dT(ct," ");
+ }
+
+ const unsigned long ml = static_cast<unsigned long>(max_len);
+ // now print the description but make it wrap around nicely if it
+ // is to long to fit on one line.
+ if (len <= max_len)
+ sout << wrap_string(this->element().description(),0,ml);
+ else
+ sout << _dT(ct,"\n") << wrap_string(this->element().description(),ml,ml);
+ }
+
+ // Only print out a generic Options: group name if there is an unnamed option
+ // present.
+ if (groups.count(string()) == 1)
+ out << _dT(ct,"Options:");
+
+ // Now print everything out
+ typename std::map<string,std::shared_ptr<ostringstream> >::iterator i;
+ for (i = groups.begin(); i != groups.end(); ++i)
+ {
+ // print the group name if we have one
+ if (i->first.size() != 0)
+ {
+ if (i != groups.begin())
+ out << _dT(ct,"\n\n");
+ out << i->first << _dT(ct,":");
+ }
+
+ // print the options in the group
+ out << i->second->str();
+ }
+ out << _dT(ct,"\n\n");
+ this->reset();
+ }
+ catch (...)
+ {
+ this->reset();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_PRINt_1_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/get_option.h b/ml/dlib/dlib/cmd_line_parser/get_option.h
new file mode 100644
index 000000000..2c8d1644f
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/get_option.h
@@ -0,0 +1,181 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GET_OPTiON_Hh_
+#define DLIB_GET_OPTiON_Hh_
+
+#include "get_option_abstract.h"
+#include "../string.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class option_parse_error : public error
+ {
+ public:
+ option_parse_error(const std::string& option_string, const std::string& str):
+ error(EOPTION_PARSE,"Error parsing argument for option '" + option_string + "', offending string is '" + str + "'.") {}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename config_reader_type, typename T>
+ T impl_config_reader_get_option (
+ const config_reader_type& cr,
+ const std::string& option_name,
+ const std::string& full_option_name,
+ T default_value
+ )
+ {
+ std::string::size_type pos = option_name.find_first_of(".");
+ if (pos == std::string::npos)
+ {
+ if (cr.is_key_defined(option_name))
+ {
+ try{ return string_cast<T>(cr[option_name]); }
+ catch (string_cast_error&) { throw option_parse_error(full_option_name, cr[option_name]); }
+ }
+ }
+ else
+ {
+ std::string block_name = option_name.substr(0,pos);
+ if (cr.is_block_defined(block_name))
+ {
+ return impl_config_reader_get_option(cr.block(block_name),
+ option_name.substr(pos+1),
+ full_option_name,
+ default_value);
+ }
+ }
+
+ return default_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename cr_type, typename T>
+ typename enable_if<is_config_reader<cr_type>,T>::type get_option (
+ const cr_type& cr,
+ const std::string& option_name,
+ T default_value
+ )
+ {
+ return impl_config_reader_get_option(cr, option_name, option_name, default_value);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename parser_type, typename T>
+ typename disable_if<is_config_reader<parser_type>,T>::type get_option (
+ const parser_type& parser,
+ const std::string& option_name,
+ T default_value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( parser.option_is_defined(option_name) == true &&
+ parser.option(option_name).number_of_arguments() == 1,
+ "\t T get_option()"
+ << "\n\t option_name: " << option_name
+ << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name)
+ << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments()
+ );
+
+ if (parser.option(option_name))
+ {
+ try
+ {
+ default_value = string_cast<T>(parser.option(option_name).argument());
+ }
+ catch (string_cast_error&)
+ {
+ throw option_parse_error(option_name, parser.option(option_name).argument());
+ }
+ }
+ return default_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename parser_type, typename cr_type, typename T>
+ typename disable_if<is_config_reader<parser_type>,T>::type get_option (
+ const parser_type& parser,
+ const cr_type& cr,
+ const std::string& option_name,
+ T default_value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( parser.option_is_defined(option_name) == true &&
+ parser.option(option_name).number_of_arguments() == 1,
+ "\t T get_option()"
+ << "\n\t option_name: " << option_name
+ << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name)
+ << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments()
+ );
+
+ if (parser.option(option_name))
+ return get_option(parser, option_name, default_value);
+ else
+ return get_option(cr, option_name, default_value);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename parser_type, typename cr_type, typename T>
+ typename disable_if<is_config_reader<parser_type>,T>::type get_option (
+ const cr_type& cr,
+ const parser_type& parser,
+ const std::string& option_name,
+ T default_value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( parser.option_is_defined(option_name) == true &&
+ parser.option(option_name).number_of_arguments() == 1,
+ "\t T get_option()"
+ << "\n\t option_name: " << option_name
+ << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name)
+ << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments()
+ );
+
+ if (parser.option(option_name))
+ return get_option(parser, option_name, default_value);
+ else
+ return get_option(cr, option_name, default_value);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ inline std::string get_option (
+ const T& cr,
+ const std::string& option_name,
+ const char* default_value
+ )
+ {
+ return get_option(cr, option_name, std::string(default_value));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline std::string get_option (
+ const T& parser,
+ const U& cr,
+ const std::string& option_name,
+ const char* default_value
+ )
+ {
+ return get_option(parser, cr, option_name, std::string(default_value));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GET_OPTiON_Hh_
+
diff --git a/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h b/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h
new file mode 100644
index 000000000..90dc16721
--- /dev/null
+++ b/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GET_OPTiON_ABSTRACT_Hh_
+#ifdef DLIB_GET_OPTiON_ABSTRACT_Hh_
+
+#inclue <string>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class option_parse_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown by the get_option() functions. It is
+ thrown when the option string given by a command line parser or
+ config reader can't be converted into the type T.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_type,
+ typename T
+ >
+ T get_option (
+ const config_reader_type& cr,
+ const std::string& option_name,
+ T default_value
+ );
+ /*!
+ requires
+ - T is a type which can be read from an input stream
+ - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h
+ ensures
+ - option_name is used to index into the given config_reader.
+ - if (cr contains an entry corresponding to option_name) then
+ - converts the string value in cr corresponding to option_name into
+ an object of type T and returns it.
+ - else
+ - returns default_value
+ - The scheme for indexing into cr based on option_name is best
+ understood by looking at a few examples:
+ - an option name of "name" corresponds to cr["name"]
+ - an option name of "block1.name" corresponds to cr.block("block1")["name"]
+ - an option name of "block1.block2.name" corresponds to cr.block("block1").block("block2")["name"]
+ throws
+ - option_parse_error
+ This exception is thrown if we attempt but fail to convert the string value
+ in cr into an object of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename command_line_parser_type,
+ typename T
+ >
+ T get_option (
+ const command_line_parser_type& parser,
+ const std::string& option_name,
+ T default_value
+ );
+ /*!
+ requires
+ - parser.option_is_defined(option_name) == true
+ - parser.option(option_name).number_of_arguments() == 1
+ - T is a type which can be read from an input stream
+ - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h
+ ensures
+ - if (parser.option(option_name)) then
+ - converts parser.option(option_name).argument() into an object
+ of type T and returns it. That is, the string argument to this
+ command line option is converted into a T and returned.
+ - else
+ - returns default_value
+ throws
+ - option_parse_error
+ This exception is thrown if we attempt but fail to convert the string
+ argument into an object of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename command_line_parser_type,
+ typename config_reader_type,
+ typename T
+ >
+ T get_option (
+ const command_line_parser_type& parser,
+ const config_reader_type& cr,
+ const std::string& option_name,
+ T default_value
+ );
+ /*!
+ requires
+ - parser.option_is_defined(option_name) == true
+ - parser.option(option_name).number_of_arguments() == 1
+ - T is a type which can be read from an input stream
+ - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h
+ - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h
+ ensures
+ - if (parser.option(option_name)) then
+ - returns get_option(parser, option_name, default_value)
+ - else
+ - returns get_option(cr, option_name, default_value)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename command_line_parser_type,
+ typename config_reader_type,
+ typename T
+ >
+ T get_option (
+ const config_reader_type& cr,
+ const command_line_parser_type& parser,
+ const std::string& option_name,
+ T default_value
+ );
+ /*!
+ requires
+ - parser.option_is_defined(option_name) == true
+ - parser.option(option_name).number_of_arguments() == 1
+ - T is a type which can be read from an input stream
+ - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h
+ - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h
+ ensures
+ - if (parser.option(option_name)) then
+ - returns get_option(parser, option_name, default_value)
+ - else
+ - returns get_option(cr, option_name, default_value)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GET_OPTiON_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/compress_stream.h b/ml/dlib/dlib/compress_stream.h
new file mode 100644
index 000000000..8ccc1d52f
--- /dev/null
+++ b/ml/dlib/dlib/compress_stream.h
@@ -0,0 +1,133 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_COMPRESS_STREAm_
+#define DLIB_COMPRESS_STREAm_
+
+#include "compress_stream/compress_stream_kernel_1.h"
+#include "compress_stream/compress_stream_kernel_2.h"
+#include "compress_stream/compress_stream_kernel_3.h"
+
+#include "conditioning_class.h"
+#include "entropy_encoder.h"
+#include "entropy_decoder.h"
+
+#include "entropy_encoder_model.h"
+#include "entropy_decoder_model.h"
+#include "lz77_buffer.h"
+#include "sliding_buffer.h"
+#include "lzp_buffer.h"
+#include "crc32.h"
+
+
+namespace dlib
+{
+
+ class compress_stream
+ {
+ compress_stream() {}
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_1b fce1;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_1b fcd1;
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2b fce2;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2b fcd2;
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_3b fce3;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_3b fcd3;
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4a fce4a;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4a fcd4a;
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4b fce4b;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4b fcd4b;
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5a fce5a;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5a fcd5a;
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5b fce5b;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5b fcd5b;
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5c fce5c;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5c fcd5c;
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_6a fce6;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_6a fcd6;
+
+
+ typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2d fce2d;
+ typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2d fcd2d;
+
+ typedef sliding_buffer<unsigned char>::kernel_1a sliding_buffer1;
+ typedef lz77_buffer::kernel_2a lz77_buffer2a;
+
+
+ typedef lzp_buffer::kernel_1a lzp_buf_1;
+ typedef lzp_buffer::kernel_2a lzp_buf_2;
+
+
+ typedef entropy_encoder_model<513,entropy_encoder::kernel_2a>::kernel_1b fce_length;
+ typedef entropy_decoder_model<513,entropy_decoder::kernel_2a>::kernel_1b fcd_length;
+
+ typedef entropy_encoder_model<65534,entropy_encoder::kernel_2a>::kernel_1b fce_length_2;
+ typedef entropy_decoder_model<65534,entropy_decoder::kernel_2a>::kernel_1b fcd_length_2;
+
+
+ typedef entropy_encoder_model<32257,entropy_encoder::kernel_2a>::kernel_1b fce_index;
+ typedef entropy_decoder_model<32257,entropy_decoder::kernel_2a>::kernel_1b fcd_index;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef compress_stream_kernel_1 <fce1,fcd1,crc32::kernel_1a>
+ kernel_1a;
+
+ // kernel_1b
+ typedef compress_stream_kernel_1 <fce2,fcd2,crc32::kernel_1a>
+ kernel_1b;
+
+ // kernel_1c
+ typedef compress_stream_kernel_1 <fce3,fcd3,crc32::kernel_1a>
+ kernel_1c;
+
+ // kernel_1da
+ typedef compress_stream_kernel_1 <fce4a,fcd4a,crc32::kernel_1a>
+ kernel_1da;
+
+ // kernel_1ea
+ typedef compress_stream_kernel_1 <fce5a,fcd5a,crc32::kernel_1a>
+ kernel_1ea;
+
+ // kernel_1db
+ typedef compress_stream_kernel_1 <fce4b,fcd4b,crc32::kernel_1a>
+ kernel_1db;
+
+ // kernel_1eb
+ typedef compress_stream_kernel_1 <fce5b,fcd5b,crc32::kernel_1a>
+ kernel_1eb;
+
+ // kernel_1ec
+ typedef compress_stream_kernel_1 <fce5c,fcd5c,crc32::kernel_1a>
+ kernel_1ec;
+
+
+
+
+ // kernel_2a
+ typedef compress_stream_kernel_2 <fce2,fcd2,lz77_buffer2a,sliding_buffer1,fce_length,fcd_length,fce_index,fcd_index,crc32::kernel_1a>
+ kernel_2a;
+
+
+
+
+ // kernel_3a
+ typedef compress_stream_kernel_3 <lzp_buf_1,crc32::kernel_1a,16>
+ kernel_3a;
+ // kernel_3b
+ typedef compress_stream_kernel_3 <lzp_buf_2,crc32::kernel_1a,16>
+ kernel_3b;
+
+
+ };
+}
+
+#endif // DLIB_COMPRESS_STREAm_
+
diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h
new file mode 100644
index 000000000..1a75ec6ce
--- /dev/null
+++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h
@@ -0,0 +1,252 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_COMPRESS_STREAM_KERNEl_1_
+#define DLIB_COMPRESS_STREAM_KERNEl_1_
+
+#include "../algs.h"
+#include <iostream>
+#include <streambuf>
+#include <cstdio>
+#include "compress_stream_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename fce,
+ typename fcd,
+ typename crc32
+ >
+ class compress_stream_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON fce
+ is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
+ the alphabet_size of fce must be 257.
+ fce and fcd share the same kernel number.
+
+ REQUIREMENTS ON fcd
+ is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
+ the alphabet_size of fcd must be 257.
+ fce and fcd share the same kernel number.
+
+ REQUIREMENTS ON crc32
+ is an implementation of crc32/crc32_kernel_abstract.h
+
+
+
+ INITIAL VALUE
+ this object has no state
+
+ CONVENTION
+ this object has no state
+ !*/
+
+ const static unsigned long eof_symbol = 256;
+
+ public:
+
+ class decompression_error : public dlib::error
+ {
+ public:
+ decompression_error(
+ const char* i
+ ) :
+ dlib::error(std::string(i))
+ {}
+
+ decompression_error(
+ const std::string& i
+ ) :
+ dlib::error(i)
+ {}
+ };
+
+
+ compress_stream_kernel_1 (
+ )
+ {}
+
+ ~compress_stream_kernel_1 (
+ )
+ {}
+
+ void compress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ void decompress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ private:
+
+ // restricted functions
+ compress_stream_kernel_1(compress_stream_kernel_1&); // copy constructor
+ compress_stream_kernel_1& operator=(compress_stream_kernel_1&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename fce,
+ typename fcd,
+ typename crc32
+ >
+ void compress_stream_kernel_1<fce,fcd,crc32>::
+ compress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ std::streambuf::int_type temp;
+
+ std::streambuf& in = *in_.rdbuf();
+
+ typename fce::entropy_encoder_type coder;
+ coder.set_stream(out_);
+
+ fce model(coder);
+
+ crc32 crc;
+
+ unsigned long count = 0;
+
+ while (true)
+ {
+ // write out a known value every 20000 symbols
+ if (count == 20000)
+ {
+ count = 0;
+ coder.encode(1500,1501,8000);
+ }
+ ++count;
+
+ // get the next character
+ temp = in.sbumpc();
+
+ // if we have hit EOF then encode the marker symbol
+ if (temp != EOF)
+ {
+ // encode the symbol
+ model.encode(static_cast<unsigned long>(temp));
+ crc.add(static_cast<unsigned char>(temp));
+ continue;
+ }
+ else
+ {
+ model.encode(eof_symbol);
+
+ // now write the checksum
+ unsigned long checksum = crc.get_checksum();
+ unsigned char byte1 = static_cast<unsigned char>((checksum>>24)&0xFF);
+ unsigned char byte2 = static_cast<unsigned char>((checksum>>16)&0xFF);
+ unsigned char byte3 = static_cast<unsigned char>((checksum>>8)&0xFF);
+ unsigned char byte4 = static_cast<unsigned char>((checksum)&0xFF);
+
+ model.encode(byte1);
+ model.encode(byte2);
+ model.encode(byte3);
+ model.encode(byte4);
+
+ break;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename fce,
+ typename fcd,
+ typename crc32
+ >
+ void compress_stream_kernel_1<fce,fcd,crc32>::
+ decompress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+
+ std::streambuf& out = *out_.rdbuf();
+
+ typename fcd::entropy_decoder_type coder;
+ coder.set_stream(in_);
+
+ fcd model(coder);
+
+ unsigned long symbol;
+ unsigned long count = 0;
+
+ crc32 crc;
+
+ // decode until we hit the marker symbol
+ while (true)
+ {
+ // make sure this is the value we expect
+ if (count == 20000)
+ {
+ if (coder.get_target(8000) != 1500)
+ {
+ throw decompression_error("Error detected in compressed data stream.");
+ }
+ count = 0;
+ coder.decode(1500,1501);
+ }
+ ++count;
+
+ // decode the next symbol
+ model.decode(symbol);
+ if (symbol != eof_symbol)
+ {
+ crc.add(static_cast<unsigned char>(symbol));
+ // write this symbol to out
+ if (out.sputc(static_cast<char>(symbol)) != static_cast<int>(symbol))
+ {
+ throw std::ios::failure("error occurred in compress_stream_kernel_1::decompress");
+ }
+ continue;
+ }
+ else
+ {
+ // we read eof from the encoded data. now we just have to check the checksum and we are done.
+ unsigned char byte1;
+ unsigned char byte2;
+ unsigned char byte3;
+ unsigned char byte4;
+
+ model.decode(symbol); byte1 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte2 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte3 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte4 = static_cast<unsigned char>(symbol);
+
+ unsigned long checksum = byte1;
+ checksum <<= 8;
+ checksum |= byte2;
+ checksum <<= 8;
+ checksum |= byte3;
+ checksum <<= 8;
+ checksum |= byte4;
+
+ if (checksum != crc.get_checksum())
+ throw decompression_error("Error detected in compressed data stream.");
+
+ break;
+ }
+ } // while (true)
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_COMPRESS_STREAM_KERNEl_1_
+
diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h
new file mode 100644
index 000000000..e46b23fad
--- /dev/null
+++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h
@@ -0,0 +1,431 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_COMPRESS_STREAM_KERNEl_2_
+#define DLIB_COMPRESS_STREAM_KERNEl_2_
+
+#include "../algs.h"
+#include <iostream>
+#include <streambuf>
+#include "compress_stream_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename fce,
+ typename fcd,
+ typename lz77_buffer,
+ typename sliding_buffer,
+ typename fce_length,
+ typename fcd_length,
+ typename fce_index,
+ typename fcd_index,
+ typename crc32
+ >
+ class compress_stream_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON fce
+ is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
+ the alphabet_size of fce must be 257.
+ fce and fcd share the same kernel number.
+
+ REQUIREMENTS ON fcd
+ is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
+ the alphabet_size of fcd must be 257.
+ fce and fcd share the same kernel number.
+
+ REQUIREMENTS ON lz77_buffer
+ is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h
+
+ REQUIREMENTS ON sliding_buffer
+ is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+ is instantiated with T = unsigned char
+
+ REQUIREMENTS ON fce_length
+ is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
+ the alphabet_size of fce must be 513. This will be used to encode the length of lz77 matches.
+ fce_length and fcd share the same kernel number.
+
+ REQUIREMENTS ON fcd_length
+ is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
+ the alphabet_size of fcd must be 513. This will be used to decode the length of lz77 matches.
+ fce_length and fcd share the same kernel number.
+
+ REQUIREMENTS ON fce_index
+ is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
+ the alphabet_size of fce must be 32257. This will be used to encode the index of lz77 matches.
+ fce_index and fcd share the same kernel number.
+
+ REQUIREMENTS ON fcd_index
+ is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
+ the alphabet_size of fcd must be 32257. This will be used to decode the index of lz77 matches.
+ fce_index and fcd share the same kernel number.
+
+ REQUIREMENTS ON crc32
+ is an implementation of crc32/crc32_kernel_abstract.h
+
+ INITIAL VALUE
+ this object has no state
+
+ CONVENTION
+ this object has no state
+ !*/
+
+ const static unsigned long eof_symbol = 256;
+
+ public:
+
+ class decompression_error : public dlib::error
+ {
+ public:
+ decompression_error(
+ const char* i
+ ) :
+ dlib::error(std::string(i))
+ {}
+
+ decompression_error(
+ const std::string& i
+ ) :
+ dlib::error(i)
+ {}
+ };
+
+
+ compress_stream_kernel_2 (
+ )
+ {}
+
+ ~compress_stream_kernel_2 (
+ )
+ {}
+
+ void compress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ void decompress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ private:
+
+ // restricted functions
+ compress_stream_kernel_2(compress_stream_kernel_2&); // copy constructor
+ compress_stream_kernel_2& operator=(compress_stream_kernel_2&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename fce,
+ typename fcd,
+ typename lz77_buffer,
+ typename sliding_buffer,
+ typename fce_length,
+ typename fcd_length,
+ typename fce_index,
+ typename fcd_index,
+ typename crc32
+ >
+ void compress_stream_kernel_2<fce,fcd,lz77_buffer,sliding_buffer,fce_length,fcd_length,fce_index,fcd_index,crc32>::
+ compress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ std::streambuf::int_type temp;
+
+ std::streambuf& in = *in_.rdbuf();
+
+ typename fce::entropy_encoder_type coder;
+ coder.set_stream(out_);
+
+ fce model(coder);
+ fce_length model_length(coder);
+ fce_index model_index(coder);
+
+ const unsigned long LOOKAHEAD_LIMIT = 512;
+ lz77_buffer buffer(15,LOOKAHEAD_LIMIT);
+
+ crc32 crc;
+
+
+ unsigned long count = 0;
+
+ unsigned long lz77_count = 1; // number of times we used lz77 to encode
+ unsigned long ppm_count = 1; // number of times we used ppm to encode
+
+
+ while (true)
+ {
+ // write out a known value every 20000 symbols
+ if (count == 20000)
+ {
+ count = 0;
+ coder.encode(150,151,400);
+ }
+ ++count;
+
+ // try to fill the lookahead buffer
+ if (buffer.get_lookahead_buffer_size() < buffer.get_lookahead_buffer_limit())
+ {
+ temp = in.sbumpc();
+ while (temp != EOF)
+ {
+ crc.add(static_cast<unsigned char>(temp));
+ buffer.add(static_cast<unsigned char>(temp));
+ if (buffer.get_lookahead_buffer_size() == buffer.get_lookahead_buffer_limit())
+ break;
+ temp = in.sbumpc();
+ }
+ }
+
+ // compute the sum of ppm_count and lz77_count but make sure
+ // it is less than 65536
+ unsigned long sum = ppm_count + lz77_count;
+ if (sum >= 65536)
+ {
+ ppm_count >>= 1;
+ lz77_count >>= 1;
+ ppm_count |= 1;
+ lz77_count |= 1;
+ sum = ppm_count+lz77_count;
+ }
+
+ // if there are still more symbols in the lookahead buffer to encode
+ if (buffer.get_lookahead_buffer_size() > 0)
+ {
+ unsigned long match_index, match_length;
+ buffer.find_match(match_index,match_length,6);
+ if (match_length != 0)
+ {
+
+ // signal the decoder that we are using lz77
+ coder.encode(0,lz77_count,sum);
+ ++lz77_count;
+
+ // encode the index and length pair
+ model_index.encode(match_index);
+ model_length.encode(match_length);
+
+ }
+ else
+ {
+
+ // signal the decoder that we are using ppm
+ coder.encode(lz77_count,sum,sum);
+ ++ppm_count;
+
+ // encode the symbol using the ppm model
+ model.encode(buffer.lookahead_buffer(0));
+ buffer.shift_buffers(1);
+ }
+ }
+ else
+ {
+ // signal the decoder that we are using ppm
+ coder.encode(lz77_count,sum,sum);
+
+
+ model.encode(eof_symbol);
+ // now write the checksum
+ unsigned long checksum = crc.get_checksum();
+ unsigned char byte1 = static_cast<unsigned char>((checksum>>24)&0xFF);
+ unsigned char byte2 = static_cast<unsigned char>((checksum>>16)&0xFF);
+ unsigned char byte3 = static_cast<unsigned char>((checksum>>8)&0xFF);
+ unsigned char byte4 = static_cast<unsigned char>((checksum)&0xFF);
+
+ model.encode(byte1);
+ model.encode(byte2);
+ model.encode(byte3);
+ model.encode(byte4);
+
+ break;
+ }
+ } // while (true)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename fce,
+ typename fcd,
+ typename lz77_buffer,
+ typename sliding_buffer,
+ typename fce_length,
+ typename fcd_length,
+ typename fce_index,
+ typename fcd_index,
+ typename crc32
+ >
+ void compress_stream_kernel_2<fce,fcd,lz77_buffer,sliding_buffer,fce_length,fcd_length,fce_index,fcd_index,crc32>::
+ decompress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+
+ std::streambuf& out = *out_.rdbuf();
+
+ typename fcd::entropy_decoder_type coder;
+ coder.set_stream(in_);
+
+ fcd model(coder);
+ fcd_length model_length(coder);
+ fcd_index model_index(coder);
+
+ unsigned long symbol;
+ unsigned long count = 0;
+
+ sliding_buffer buffer;
+ buffer.set_size(15);
+
+ // Initialize the buffer to all zeros. There is no algorithmic reason to
+ // do this. But doing so avoids a warning from valgrind so that is why
+ // I'm doing this.
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ crc32 crc;
+
+ unsigned long lz77_count = 1; // number of times we used lz77 to encode
+ unsigned long ppm_count = 1; // number of times we used ppm to encode
+ bool next_block_lz77;
+
+
+ // decode until we hit the marker symbol
+ while (true)
+ {
+ // make sure this is the value we expect
+ if (count == 20000)
+ {
+ if (coder.get_target(400) != 150)
+ {
+ throw decompression_error("Error detected in compressed data stream.");
+ }
+ count = 0;
+ coder.decode(150,151);
+ }
+ ++count;
+
+
+ // compute the sum of ppm_count and lz77_count but make sure
+ // it is less than 65536
+ unsigned long sum = ppm_count + lz77_count;
+ if (sum >= 65536)
+ {
+ ppm_count >>= 1;
+ lz77_count >>= 1;
+ ppm_count |= 1;
+ lz77_count |= 1;
+ sum = ppm_count+lz77_count;
+ }
+
+ // check if we are decoding a lz77 or ppm block
+ if (coder.get_target(sum) < lz77_count)
+ {
+ coder.decode(0,lz77_count);
+ next_block_lz77 = true;
+ ++lz77_count;
+ }
+ else
+ {
+ coder.decode(lz77_count,sum);
+ next_block_lz77 = false;
+ ++ppm_count;
+ }
+
+
+ if (next_block_lz77)
+ {
+
+ unsigned long match_length, match_index;
+ // decode the match index
+ model_index.decode(match_index);
+
+ // decode the match length
+ model_length.decode(match_length);
+
+
+ match_index += match_length;
+ buffer.rotate_left(match_length);
+ for (unsigned long i = 0; i < match_length; ++i)
+ {
+ unsigned char ch = buffer[match_index-i];
+ buffer[match_length-i-1] = ch;
+
+ crc.add(ch);
+ // write this ch to out
+ if (out.sputc(static_cast<char>(ch)) != static_cast<int>(ch))
+ {
+ throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress");
+ }
+ }
+
+ }
+ else
+ {
+
+ // decode the next symbol
+ model.decode(symbol);
+ if (symbol != eof_symbol)
+ {
+ buffer.rotate_left(1);
+ buffer[0] = static_cast<unsigned char>(symbol);
+
+
+ crc.add(static_cast<unsigned char>(symbol));
+ // write this symbol to out
+ if (out.sputc(static_cast<char>(symbol)) != static_cast<int>(symbol))
+ {
+ throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress");
+ }
+ }
+ else
+ {
+ // this was the eof marker symbol so we are done. now check the checksum
+
+ // now get the checksum and make sure it matches
+ unsigned char byte1;
+ unsigned char byte2;
+ unsigned char byte3;
+ unsigned char byte4;
+
+ model.decode(symbol); byte1 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte2 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte3 = static_cast<unsigned char>(symbol);
+ model.decode(symbol); byte4 = static_cast<unsigned char>(symbol);
+
+ unsigned long checksum = byte1;
+ checksum <<= 8;
+ checksum |= byte2;
+ checksum <<= 8;
+ checksum |= byte3;
+ checksum <<= 8;
+ checksum |= byte4;
+
+ if (checksum != crc.get_checksum())
+ throw decompression_error("Error detected in compressed data stream.");
+
+ break;
+ }
+ }
+
+ } // while (true)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_COMPRESS_STREAM_KERNEl_2_
+
diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h
new file mode 100644
index 000000000..ed4eee290
--- /dev/null
+++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h
@@ -0,0 +1,381 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_COMPRESS_STREAM_KERNEl_3_
+#define DLIB_COMPRESS_STREAM_KERNEl_3_
+
+#include "../algs.h"
+#include "compress_stream_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename lzp_buf,
+ typename crc32,
+ unsigned long buffer_size
+ >
+ class compress_stream_kernel_3
+ {
+ /*!
+ REQUIREMENTS ON lzp_buf
+ is an implementation of lzp_buffer/lzp_buffer_kernel_abstract.h
+
+ REQUIREMENTS ON buffer_size
+ 10 < buffer_size < 32
+
+ REQUIREMENTS ON crc32
+ is an implementation of crc32/crc32_kernel_abstract.h
+
+
+ INITIAL VALUE
+ this object has no state
+
+ CONVENTION
+ this object has no state
+
+
+ This implementation uses the lzp_buffer and writes out matches
+ in a byte aligned format.
+
+ !*/
+
+
+ public:
+
+ class decompression_error : public dlib::error
+ {
+ public:
+ decompression_error(
+ const char* i
+ ) :
+ dlib::error(std::string(i))
+ {}
+
+ decompression_error(
+ const std::string& i
+ ) :
+ dlib::error(i)
+ {}
+ };
+
+
+ compress_stream_kernel_3 (
+ )
+ {
+ COMPILE_TIME_ASSERT(10 < buffer_size && buffer_size < 32);
+ }
+
+ ~compress_stream_kernel_3 (
+ )
+ {}
+
+ void compress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+ void decompress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+
+
+
+ private:
+
+ inline void write (
+ unsigned char symbol
+ ) const
+ {
+ if (out->sputn(reinterpret_cast<char*>(&symbol),1)==0)
+ throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3");
+ }
+
+ inline void decode (
+ unsigned char& symbol,
+ unsigned char& flag
+ ) const
+ {
+ if (count == 0)
+ {
+ if (((size_t)in->sgetn(reinterpret_cast<char*>(buffer),sizeof(buffer)))!=sizeof(buffer))
+ throw decompression_error("Error detected in compressed data stream.");
+ count = 8;
+ }
+ --count;
+ symbol = buffer[8-count];
+ flag = buffer[0] >> 7;
+ buffer[0] <<= 1;
+ }
+
+ inline void encode (
+ unsigned char symbol,
+ unsigned char flag
+ ) const
+ /*!
+ requires
+ - 0 <= flag <= 1
+ ensures
+ - writes symbol with the given one bit flag
+ !*/
+ {
+ // add this symbol and flag to the buffer
+ ++count;
+ buffer[0] <<= 1;
+ buffer[count] = symbol;
+ buffer[0] |= flag;
+
+ if (count == 8)
+ {
+ if (((size_t)out->sputn(reinterpret_cast<char*>(buffer),sizeof(buffer)))!=sizeof(buffer))
+ throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3");
+ count = 0;
+ buffer[0] = 0;
+ }
+ }
+
+ void clear (
+ ) const
+ /*!
+ ensures
+ - resets the buffers
+ !*/
+ {
+ count = 0;
+ }
+
+ void flush (
+ ) const
+ /*!
+ ensures
+ - flushes any data in the buffers to out
+ !*/
+ {
+ if (count != 0)
+ {
+ buffer[0] <<= (8-count);
+ if (((size_t)out->sputn(reinterpret_cast<char*>(buffer),sizeof(buffer)))!=sizeof(buffer))
+ throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3");
+ }
+ }
+
+ mutable unsigned int count;
+ // count tells us how many bytes are buffered in buffer and how many flag
+ // bit are currently in buffer[0]
+ mutable unsigned char buffer[9];
+ // buffer[0] holds the flag bits to be writen.
+ // the rest of the buffer holds the bytes to be writen.
+
+ mutable std::streambuf* in;
+ mutable std::streambuf* out;
+
+ // restricted functions
+ compress_stream_kernel_3(compress_stream_kernel_3<lzp_buf,crc32,buffer_size>&); // copy constructor
+ compress_stream_kernel_3<lzp_buf,crc32,buffer_size>& operator=(compress_stream_kernel_3<lzp_buf,crc32,buffer_size>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lzp_buf,
+ typename crc32,
+ unsigned long buffer_size
+ >
+ void compress_stream_kernel_3<lzp_buf,crc32,buffer_size>::
+ compress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ in = in_.rdbuf();
+ out = out_.rdbuf();
+ clear();
+
+ crc32 crc;
+
+ lzp_buf buffer(buffer_size);
+
+ std::streambuf::int_type temp = in->sbumpc();
+ unsigned long index;
+ unsigned char symbol;
+ unsigned char length;
+
+ while (temp != EOF)
+ {
+ symbol = static_cast<unsigned char>(temp);
+ if (buffer.predict_match(index))
+ {
+ if (buffer[index] == symbol)
+ {
+ // this is a match so we must find out how long it is
+ length = 1;
+
+ buffer.add(symbol);
+ crc.add(symbol);
+
+ temp = in->sbumpc();
+ while (length < 255)
+ {
+ if (temp == EOF)
+ {
+ break;
+ }
+ else if (static_cast<unsigned long>(length) >= index)
+ {
+ break;
+ }
+ else if (static_cast<unsigned char>(temp) == buffer[index])
+ {
+ ++length;
+ buffer.add(static_cast<unsigned char>(temp));
+ crc.add(static_cast<unsigned char>(temp));
+ temp = in->sbumpc();
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ encode(length,1);
+ }
+ else
+ {
+ // this is also not a match
+ encode(symbol,0);
+ buffer.add(symbol);
+ crc.add(symbol);
+
+ // get the next symbol
+ temp = in->sbumpc();
+ }
+ }
+ else
+ {
+ // there wasn't a match so just write this symbol
+ encode(symbol,0);
+ buffer.add(symbol);
+ crc.add(symbol);
+
+ // get the next symbol
+ temp = in->sbumpc();
+ }
+ }
+
+ // use a match of zero length to indicate EOF
+ encode(0,1);
+
+ // now write the checksum
+ unsigned long checksum = crc.get_checksum();
+ unsigned char byte1 = static_cast<unsigned char>((checksum>>24)&0xFF);
+ unsigned char byte2 = static_cast<unsigned char>((checksum>>16)&0xFF);
+ unsigned char byte3 = static_cast<unsigned char>((checksum>>8)&0xFF);
+ unsigned char byte4 = static_cast<unsigned char>((checksum)&0xFF);
+
+ encode(byte1,0);
+ encode(byte2,0);
+ encode(byte3,0);
+ encode(byte4,0);
+
+ flush();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lzp_buf,
+ typename crc32,
+ unsigned long buffer_size
+ >
+ void compress_stream_kernel_3<lzp_buf,crc32,buffer_size>::
+ decompress (
+ std::istream& in_,
+ std::ostream& out_
+ ) const
+ {
+ in = in_.rdbuf();
+ out = out_.rdbuf();
+ clear();
+
+ crc32 crc;
+
+ lzp_buf buffer(buffer_size);
+
+
+ unsigned long index = 0;
+ unsigned char symbol;
+ unsigned char length;
+ unsigned char flag;
+
+ decode(symbol,flag);
+ while (flag == 0 || symbol != 0)
+ {
+ buffer.predict_match(index);
+
+ if (flag == 1)
+ {
+ length = symbol;
+ do
+ {
+ --length;
+ symbol = buffer[index];
+ write(symbol);
+ buffer.add(symbol);
+ crc.add(symbol);
+ } while (length != 0);
+ }
+ else
+ {
+ // this is just a literal
+ write(symbol);
+ buffer.add(symbol);
+ crc.add(symbol);
+ }
+ decode(symbol,flag);
+ }
+
+
+ // now get the checksum and make sure it matches
+ unsigned char byte1;
+ unsigned char byte2;
+ unsigned char byte3;
+ unsigned char byte4;
+
+ decode(byte1,flag);
+ if (flag != 0)
+ throw decompression_error("Error detected in compressed data stream.");
+ decode(byte2,flag);
+ if (flag != 0)
+ throw decompression_error("Error detected in compressed data stream.");
+ decode(byte3,flag);
+ if (flag != 0)
+ throw decompression_error("Error detected in compressed data stream.");
+ decode(byte4,flag);
+ if (flag != 0)
+ throw decompression_error("Error detected in compressed data stream.");
+
+ unsigned long checksum = byte1;
+ checksum <<= 8;
+ checksum |= byte2;
+ checksum <<= 8;
+ checksum |= byte3;
+ checksum <<= 8;
+ checksum |= byte4;
+
+ if (checksum != crc.get_checksum())
+ throw decompression_error("Error detected in compressed data stream.");
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_COMPRESS_STREAM_KERNEl_3_
+
diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h
new file mode 100644
index 000000000..48f46d9e1
--- /dev/null
+++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h
@@ -0,0 +1,94 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_
+#ifdef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class compress_stream
+ {
+ /*!
+ INITIAL VALUE
+ This object does not have any state associated with it.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object consists of the two functions compress and decompress.
+ These functions allow you to compress and decompress data.
+ !*/
+
+ public:
+
+ class decompression_error : public dlib::error {};
+
+ compress_stream (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~compress_stream (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+
+ void compress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - reads all data from in (until EOF is reached) and compresses it
+ and writes it to out
+ throws
+ - std::ios_base::failure
+ if there was a problem writing to out then this exception will
+ be thrown.
+ - any other exception
+ this exception may be thrown if there is any other problem
+ !*/
+
+
+ void decompress (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - reads data from in, decompresses it and writes it to out. note that
+ it stops reading data from in when it encounters the end of the
+ compressed data, not when it encounters EOF.
+ throws
+ - std::ios_base::failure
+ if there was a problem writing to out then this exception will
+ be thrown.
+ - decompression_error
+ if an error was detected in the compressed data that prevented
+ it from being correctly decompressed then this exception is
+ thrown.
+ - any other exception
+ this exception may be thrown if there is any other problem
+ !*/
+
+
+ private:
+
+ // restricted functions
+ compress_stream(compress_stream&); // copy constructor
+ compress_stream& operator=(compress_stream&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/conditioning_class.h b/ml/dlib/dlib/conditioning_class.h
new file mode 100644
index 000000000..409b98716
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class.h
@@ -0,0 +1,80 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASs_
+#define DLIB_CONDITIONING_CLASs_
+
+#include "conditioning_class/conditioning_class_kernel_1.h"
+#include "conditioning_class/conditioning_class_kernel_2.h"
+#include "conditioning_class/conditioning_class_kernel_3.h"
+#include "conditioning_class/conditioning_class_kernel_4.h"
+#include "conditioning_class/conditioning_class_kernel_c.h"
+
+
+#include "memory_manager.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size
+ >
+ class conditioning_class
+ {
+ conditioning_class() {}
+
+ typedef memory_manager<char>::kernel_2b mm;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef conditioning_class_kernel_1<alphabet_size>
+ kernel_1a;
+ typedef conditioning_class_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_2a
+ typedef conditioning_class_kernel_2<alphabet_size>
+ kernel_2a;
+ typedef conditioning_class_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+ // kernel_3a
+ typedef conditioning_class_kernel_3<alphabet_size>
+ kernel_3a;
+ typedef conditioning_class_kernel_c<kernel_3a>
+ kernel_3a_c;
+
+
+ // -------- kernel_4 ---------
+
+ // kernel_4a
+ typedef conditioning_class_kernel_4<alphabet_size,10000,mm>
+ kernel_4a;
+ typedef conditioning_class_kernel_c<kernel_4a>
+ kernel_4a_c;
+
+ // kernel_4b
+ typedef conditioning_class_kernel_4<alphabet_size,100000,mm>
+ kernel_4b;
+ typedef conditioning_class_kernel_c<kernel_4b>
+ kernel_4b_c;
+
+ // kernel_4c
+ typedef conditioning_class_kernel_4<alphabet_size,1000000,mm>
+ kernel_4c;
+ typedef conditioning_class_kernel_c<kernel_4c>
+ kernel_4c_c;
+
+ // kernel_4d
+ typedef conditioning_class_kernel_4<alphabet_size,10000000,mm>
+ kernel_4d;
+ typedef conditioning_class_kernel_c<kernel_4d>
+ kernel_4d_c;
+
+ };
+}
+
+#endif // DLIB_CONDITIONING_CLASS_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h
new file mode 100644
index 000000000..d26d80244
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h
@@ -0,0 +1,333 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASS_KERNEl_1_
+#define DLIB_CONDITIONING_CLASS_KERNEl_1_
+
+#include "conditioning_class_kernel_abstract.h"
+#include "../assert.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size
+ >
+ class conditioning_class_kernel_1
+ {
+ /*!
+ INITIAL VALUE
+ total == 1
+ counts == pointer to an array of alphabet_size unsigned shorts
+ for all i except i == alphabet_size-1: counts[i] == 0
+ counts[alphabet_size-1] == 1
+
+ CONVENTION
+ counts == pointer to an array of alphabet_size unsigned shorts
+ get_total() == total
+ get_count(symbol) == counts[symbol]
+
+ LOW_COUNT(symbol) == sum of counts[0] though counts[symbol-1]
+ or 0 if symbol == 0
+
+ get_memory_usage() == global_state.memory_usage
+ !*/
+
+ public:
+
+ class global_state_type
+ {
+ public:
+ global_state_type () : memory_usage(0) {}
+ private:
+ unsigned long memory_usage;
+
+ friend class conditioning_class_kernel_1<alphabet_size>;
+ };
+
+ conditioning_class_kernel_1 (
+ global_state_type& global_state_
+ );
+
+ ~conditioning_class_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+
+ unsigned long get_total (
+ ) const;
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+
+ unsigned long get_memory_usage (
+ ) const;
+
+ global_state_type& get_global_state (
+ );
+
+ static unsigned long get_alphabet_size (
+ );
+
+
+ private:
+
+ // restricted functions
+ conditioning_class_kernel_1(conditioning_class_kernel_1<alphabet_size>&); // copy constructor
+ conditioning_class_kernel_1& operator=(conditioning_class_kernel_1<alphabet_size>&); // assignment operator
+
+ // data members
+ unsigned short total;
+ unsigned short* counts;
+ global_state_type& global_state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_1<alphabet_size>::
+ conditioning_class_kernel_1 (
+ global_state_type& global_state_
+ ) :
+ total(1),
+ counts(new unsigned short[alphabet_size]),
+ global_state(global_state_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );
+
+ unsigned short* start = counts;
+ unsigned short* end = counts+alphabet_size-1;
+ while (start != end)
+ {
+ *start = 0;
+ ++start;
+ }
+ *start = 1;
+
+ // update memory usage
+ global_state.memory_usage += sizeof(unsigned short)*alphabet_size +
+ sizeof(conditioning_class_kernel_1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_1<alphabet_size>::
+ ~conditioning_class_kernel_1 (
+ )
+ {
+ delete [] counts;
+ // update memory usage
+ global_state.memory_usage -= sizeof(unsigned short)*alphabet_size +
+ sizeof(conditioning_class_kernel_1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_1<alphabet_size>::
+ clear(
+ )
+ {
+ total = 1;
+ unsigned short* start = counts;
+ unsigned short* end = counts+alphabet_size-1;
+ while (start != end)
+ {
+ *start = 0;
+ ++start;
+ }
+ *start = 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_1<alphabet_size>::
+ get_memory_usage(
+ ) const
+ {
+ return global_state.memory_usage;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ typename conditioning_class_kernel_1<alphabet_size>::global_state_type& conditioning_class_kernel_1<alphabet_size>::
+ get_global_state(
+ )
+ {
+ return global_state;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ bool conditioning_class_kernel_1<alphabet_size>::
+ increment_count (
+ unsigned long symbol,
+ unsigned short amount
+ )
+ {
+ // if we are going over a total of 65535 then scale down all counts by 2
+ if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536)
+ {
+ total = 0;
+ unsigned short* start = counts;
+ unsigned short* end = counts+alphabet_size;
+ while (start != end)
+ {
+ *start >>= 1;
+ total += *start;
+ ++start;
+ }
+ // make sure it is at least one
+ if (counts[alphabet_size-1]==0)
+ {
+ ++total;
+ counts[alphabet_size-1] = 1;
+ }
+ }
+ counts[symbol] += amount;
+ total += amount;
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_1<alphabet_size>::
+ get_count (
+ unsigned long symbol
+ ) const
+ {
+ return counts[symbol];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_1<alphabet_size>::
+ get_alphabet_size (
+ )
+ {
+ return alphabet_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_1<alphabet_size>::
+ get_total (
+ ) const
+ {
+ return total;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_1<alphabet_size>::
+ get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const
+ {
+ if (counts[symbol] == 0)
+ return 0;
+
+ total_count = total;
+
+ const unsigned short* start = counts;
+ const unsigned short* end = counts+symbol;
+ unsigned short high_count_temp = *start;
+ while (start != end)
+ {
+ ++start;
+ high_count_temp += *start;
+ }
+ low_count = high_count_temp - *start;
+ high_count = high_count_temp;
+ return *start;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_1<alphabet_size>::
+ get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const
+ {
+ unsigned long high_count_temp = *counts;
+ const unsigned short* start = counts;
+ while (target >= high_count_temp)
+ {
+ ++start;
+ high_count_temp += *start;
+ }
+
+ low_count = high_count_temp - *start;
+ high_count = high_count_temp;
+ symbol = static_cast<unsigned long>(start-counts);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h
new file mode 100644
index 000000000..c9b38c8e3
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h
@@ -0,0 +1,500 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_
+#define DLIB_CONDITIONING_CLASS_KERNEl_2_
+
+#include "conditioning_class_kernel_abstract.h"
+#include "../assert.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size
+ >
+ class conditioning_class_kernel_2
+ {
+ /*!
+ INITIAL VALUE
+ total == 1
+ symbols == pointer to array of alphabet_size data structs
+ for all i except i == alphabet_size-1: symbols[i].count == 0
+ symbols[i].left_count == 0
+
+ symbols[alphabet_size-1].count == 1
+ symbols[alpahbet_size-1].left_count == 0
+
+ CONVENTION
+ symbols == pointer to array of alphabet_size data structs
+ get_total() == total
+ get_count(symbol) == symbols[symbol].count
+
+ symbols is organized as a tree with symbols[0] as the root.
+
+ the left subchild of symbols[i] is symbols[i*2+1] and
+ the right subchild is symbols[i*2+2].
+ the partent of symbols[i] == symbols[(i-1)/2]
+
+ symbols[i].left_count == the sum of the counts of all the
+ symbols to the left of symbols[i]
+
+ get_memory_usage() == global_state.memory_usage
+ !*/
+
+ public:
+
+ class global_state_type
+ {
+ public:
+ global_state_type () : memory_usage(0) {}
+ private:
+ unsigned long memory_usage;
+
+ friend class conditioning_class_kernel_2<alphabet_size>;
+ };
+
+ conditioning_class_kernel_2 (
+ global_state_type& global_state_
+ );
+
+ ~conditioning_class_kernel_2 (
+ );
+
+ void clear(
+ );
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+
+ inline unsigned long get_total (
+ ) const;
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+
+ unsigned long get_memory_usage (
+ ) const;
+
+ global_state_type& get_global_state (
+ );
+
+ static unsigned long get_alphabet_size (
+ );
+
+ private:
+
+ // restricted functions
+ conditioning_class_kernel_2(conditioning_class_kernel_2<alphabet_size>&); // copy constructor
+ conditioning_class_kernel_2& operator=(conditioning_class_kernel_2<alphabet_size>&); // assignment operator
+
+ // data members
+ unsigned short total;
+ struct data
+ {
+ unsigned short count;
+ unsigned short left_count;
+ };
+
+ data* symbols;
+ global_state_type& global_state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_2<alphabet_size>::
+ conditioning_class_kernel_2 (
+ global_state_type& global_state_
+ ) :
+ total(1),
+ symbols(new data[alphabet_size]),
+ global_state(global_state_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );
+
+ data* start = symbols;
+ data* end = symbols + alphabet_size-1;
+
+ while (start != end)
+ {
+ start->count = 0;
+ start->left_count = 0;
+ ++start;
+ }
+
+ start->count = 1;
+ start->left_count = 0;
+
+
+ // update the left_counts for the symbol alphabet_size-1
+ unsigned short temp;
+ unsigned long symbol = alphabet_size-1;
+ while (symbol != 0)
+ {
+ // temp will be 1 if symbol is odd, 0 if it is even
+ temp = static_cast<unsigned short>(symbol&0x1);
+
+ // set symbol to its parent
+ symbol = (symbol-1)>>1;
+
+ // note that all left subchidren are odd and also that
+ // if symbol was a left subchild then we want to increment
+ // its parents left_count
+ if (temp)
+ ++symbols[symbol].left_count;
+ }
+
+ global_state.memory_usage += sizeof(data)*alphabet_size +
+ sizeof(conditioning_class_kernel_2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_2<alphabet_size>::
+ ~conditioning_class_kernel_2 (
+ )
+ {
+ delete [] symbols;
+ global_state.memory_usage -= sizeof(data)*alphabet_size +
+ sizeof(conditioning_class_kernel_2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_2<alphabet_size>::
+ clear(
+ )
+ {
+ data* start = symbols;
+ data* end = symbols + alphabet_size-1;
+
+ total = 1;
+
+ while (start != end)
+ {
+ start->count = 0;
+ start->left_count = 0;
+ ++start;
+ }
+
+ start->count = 1;
+ start->left_count = 0;
+
+ // update the left_counts
+ unsigned short temp;
+ unsigned long symbol = alphabet_size-1;
+ while (symbol != 0)
+ {
+ // temp will be 1 if symbol is odd, 0 if it is even
+ temp = static_cast<unsigned short>(symbol&0x1);
+
+ // set symbol to its parent
+ symbol = (symbol-1)>>1;
+
+ // note that all left subchidren are odd and also that
+ // if symbol was a left subchild then we want to increment
+ // its parents left_count
+ symbols[symbol].left_count += temp;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_2<alphabet_size>::
+ get_memory_usage(
+ ) const
+ {
+ return global_state.memory_usage;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ typename conditioning_class_kernel_2<alphabet_size>::global_state_type& conditioning_class_kernel_2<alphabet_size>::
+ get_global_state(
+ )
+ {
+ return global_state;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ bool conditioning_class_kernel_2<alphabet_size>::
+ increment_count (
+ unsigned long symbol,
+ unsigned short amount
+ )
+ {
+ // if we need to renormalize then do so
+ if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536)
+ {
+ unsigned long s;
+ unsigned short temp;
+ for (unsigned short i = 0; i < alphabet_size-1; ++i)
+ {
+ s = i;
+
+ // divide the count for this symbol by 2
+ symbols[i].count >>= 1;
+
+ symbols[i].left_count = 0;
+
+ // bubble this change up though the tree
+ while (s != 0)
+ {
+ // temp will be 1 if symbol is odd, 0 if it is even
+ temp = static_cast<unsigned short>(s&0x1);
+
+ // set s to its parent
+ s = (s-1)>>1;
+
+ // note that all left subchidren are odd and also that
+ // if s was a left subchild then we want to increment
+ // its parents left_count
+ if (temp)
+ symbols[s].left_count += symbols[i].count;
+ }
+ }
+
+ // update symbols alphabet_size-1
+ {
+ s = alphabet_size-1;
+
+ // divide alphabet_size-1 symbol by 2 if it's > 1
+ if (symbols[alphabet_size-1].count > 1)
+ symbols[alphabet_size-1].count >>= 1;
+
+ // bubble this change up though the tree
+ while (s != 0)
+ {
+ // temp will be 1 if symbol is odd, 0 if it is even
+ temp = static_cast<unsigned short>(s&0x1);
+
+ // set s to its parent
+ s = (s-1)>>1;
+
+ // note that all left subchidren are odd and also that
+ // if s was a left subchild then we want to increment
+ // its parents left_count
+ if (temp)
+ symbols[s].left_count += symbols[alphabet_size-1].count;
+ }
+ }
+
+
+
+
+
+
+ // calculate the new total
+ total = 0;
+ unsigned long m = 0;
+ while (m < alphabet_size)
+ {
+ total += symbols[m].count + symbols[m].left_count;
+ m = (m<<1) + 2;
+ }
+
+ }
+
+
+
+
+ // increment the count for the specified symbol
+ symbols[symbol].count += amount;;
+ total += amount;
+
+
+ unsigned short temp;
+ while (symbol != 0)
+ {
+ // temp will be 1 if symbol is odd, 0 if it is even
+ temp = static_cast<unsigned short>(symbol&0x1);
+
+ // set symbol to its parent
+ symbol = (symbol-1)>>1;
+
+ // note that all left subchidren are odd and also that
+ // if symbol was a left subchild then we want to increment
+ // its parents left_count
+ if (temp)
+ symbols[symbol].left_count += amount;
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_2<alphabet_size>::
+ get_count (
+ unsigned long symbol
+ ) const
+ {
+ return symbols[symbol].count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_2<alphabet_size>::
+ get_alphabet_size (
+ )
+ {
+ return alphabet_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_2<alphabet_size>::
+ get_total (
+ ) const
+ {
+ return total;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_2<alphabet_size>::
+ get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const
+ {
+ if (symbols[symbol].count == 0)
+ return 0;
+
+ unsigned long current = symbol;
+ total_count = total;
+ unsigned long high_count_temp = 0;
+ bool came_from_right = true;
+ while (true)
+ {
+ if (came_from_right)
+ {
+ high_count_temp += symbols[current].count + symbols[current].left_count;
+ }
+
+ // note that if current is even then it is a right child
+ came_from_right = !(current&0x1);
+
+ if (current == 0)
+ break;
+
+ // set current to its parent
+ current = (current-1)>>1 ;
+ }
+
+
+ low_count = high_count_temp - symbols[symbol].count;
+ high_count = high_count_temp;
+
+ return symbols[symbol].count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_2<alphabet_size>::
+ get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const
+ {
+ unsigned long current = 0;
+ unsigned long low_count_temp = 0;
+
+ while (true)
+ {
+ if (static_cast<unsigned short>(target) < symbols[current].left_count)
+ {
+ // we should go left
+ current = (current<<1) + 1;
+ }
+ else
+ {
+ target -= symbols[current].left_count;
+ low_count_temp += symbols[current].left_count;
+ if (static_cast<unsigned short>(target) < symbols[current].count)
+ {
+ // we have found our target
+ symbol = current;
+ high_count = low_count_temp + symbols[current].count;
+ low_count = low_count_temp;
+ break;
+ }
+ else
+ {
+ // go right
+ target -= symbols[current].count;
+ low_count_temp += symbols[current].count;
+ current = (current<<1) + 2;
+ }
+ }
+
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h
new file mode 100644
index 000000000..b6de48555
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h
@@ -0,0 +1,438 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_
+#define DLIB_CONDITIONING_CLASS_KERNEl_3_
+
+#include "conditioning_class_kernel_abstract.h"
+#include "../assert.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size
+ >
+ class conditioning_class_kernel_3
+ {
+ /*!
+ INITIAL VALUE
+ total == 1
+ counts == pointer to an array of alphabet_size data structs
+ for all i except i == 0: counts[i].count == 0
+ counts[0].count == 1
+ counts[0].symbol == alphabet_size-1
+ for all i except i == alphabet_size-1: counts[i].present == false
+ counts[alphabet_size-1].present == true
+
+ CONVENTION
+ counts == pointer to an array of alphabet_size data structs
+ get_total() == total
+ get_count(symbol) == counts[x].count where
+ counts[x].symbol == symbol
+
+
+ LOW_COUNT(symbol) == sum of counts[0].count though counts[x-1].count
+ where counts[x].symbol == symbol
+ if (counts[0].symbol == symbol) LOW_COUNT(symbol)==0
+
+
+ if (counts[i].count == 0) then
+ counts[i].symbol == undefined value
+
+ if (symbol has a nonzero count) then
+ counts[symbol].present == true
+
+ get_memory_usage() == global_state.memory_usage
+ !*/
+
+ public:
+
+ class global_state_type
+ {
+ public:
+ global_state_type () : memory_usage(0) {}
+ private:
+ unsigned long memory_usage;
+
+ friend class conditioning_class_kernel_3<alphabet_size>;
+ };
+
+ conditioning_class_kernel_3 (
+ global_state_type& global_state_
+ );
+
+ ~conditioning_class_kernel_3 (
+ );
+
+ void clear(
+ );
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+
+ unsigned long get_total (
+ ) const;
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+
+ unsigned long get_memory_usage (
+ ) const;
+
+ global_state_type& get_global_state (
+ );
+
+ static unsigned long get_alphabet_size (
+ );
+
+ private:
+
+ // restricted functions
+ conditioning_class_kernel_3(conditioning_class_kernel_3<alphabet_size>&); // copy constructor
+ conditioning_class_kernel_3& operator=(conditioning_class_kernel_3<alphabet_size>&); // assignment operator
+
+ struct data
+ {
+ unsigned short count;
+ unsigned short symbol;
+ bool present;
+ };
+
+ // data members
+ unsigned short total;
+ data* counts;
+ global_state_type& global_state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_3<alphabet_size>::
+ conditioning_class_kernel_3 (
+ global_state_type& global_state_
+ ) :
+ total(1),
+ counts(new data[alphabet_size]),
+ global_state(global_state_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );
+
+ data* start = counts;
+ data* end = counts+alphabet_size;
+ start->count = 1;
+ start->symbol = alphabet_size-1;
+ start->present = false;
+ ++start;
+ while (start != end)
+ {
+ start->count = 0;
+ start->present = false;
+ ++start;
+ }
+ counts[alphabet_size-1].present = true;
+
+ // update memory usage
+ global_state.memory_usage += sizeof(data)*alphabet_size +
+ sizeof(conditioning_class_kernel_3);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ conditioning_class_kernel_3<alphabet_size>::
+ ~conditioning_class_kernel_3 (
+ )
+ {
+ delete [] counts;
+ // update memory usage
+ global_state.memory_usage -= sizeof(data)*alphabet_size +
+ sizeof(conditioning_class_kernel_3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_3<alphabet_size>::
+ clear(
+ )
+ {
+ total = 1;
+ data* start = counts;
+ data* end = counts+alphabet_size;
+ start->count = 1;
+ start->symbol = alphabet_size-1;
+ start->present = false;
+ ++start;
+ while (start != end)
+ {
+ start->count = 0;
+ start->present = false;
+ ++start;
+ }
+ counts[alphabet_size-1].present = true;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ typename conditioning_class_kernel_3<alphabet_size>::global_state_type& conditioning_class_kernel_3<alphabet_size>::
+ get_global_state(
+ )
+ {
+ return global_state;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_3<alphabet_size>::
+ get_memory_usage(
+ ) const
+ {
+ return global_state.memory_usage;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ bool conditioning_class_kernel_3<alphabet_size>::
+ increment_count (
+ unsigned long symbol,
+ unsigned short amount
+ )
+ {
+ // if we are going over a total of 65535 then scale down all counts by 2
+ if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536)
+ {
+ total = 0;
+ data* start = counts;
+ data* end = counts+alphabet_size;
+
+ while (start != end)
+ {
+ if (start->count == 1)
+ {
+ if (start->symbol == alphabet_size-1)
+ {
+ // this symbol must never be zero so we will leave its count at 1
+ ++total;
+ }
+ else
+ {
+ start->count = 0;
+ counts[start->symbol].present = false;
+ }
+ }
+ else
+ {
+ start->count >>= 1;
+ total += start->count;
+ }
+
+ ++start;
+ }
+ }
+
+
+ data* start = counts;
+ data* swap_spot = counts;
+
+ if (counts[symbol].present)
+ {
+ while (true)
+ {
+ if (start->symbol == symbol && start->count!=0)
+ {
+ unsigned short temp = start->count + amount;
+
+ start->symbol = swap_spot->symbol;
+ start->count = swap_spot->count;
+
+ swap_spot->symbol = static_cast<unsigned short>(symbol);
+ swap_spot->count = temp;
+ break;
+ }
+
+ if ( (start->count) < (swap_spot->count))
+ {
+ swap_spot = start;
+ }
+
+
+ ++start;
+ }
+ }
+ else
+ {
+ counts[symbol].present = true;
+ while (true)
+ {
+ if (start->count == 0)
+ {
+ start->symbol = swap_spot->symbol;
+ start->count = swap_spot->count;
+
+ swap_spot->symbol = static_cast<unsigned short>(symbol);
+ swap_spot->count = amount;
+ break;
+ }
+
+ if ((start->count) < (swap_spot->count))
+ {
+ swap_spot = start;
+ }
+
+ ++start;
+ }
+ }
+
+ total += amount;
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_3<alphabet_size>::
+ get_count (
+ unsigned long symbol
+ ) const
+ {
+ if (counts[symbol].present == false)
+ return 0;
+
+ data* start = counts;
+ while (start->symbol != symbol)
+ {
+ ++start;
+ }
+ return start->count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_3<alphabet_size>::
+ get_alphabet_size (
+ )
+ {
+ return alphabet_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_3<alphabet_size>::
+ get_total (
+ ) const
+ {
+ return total;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ unsigned long conditioning_class_kernel_3<alphabet_size>::
+ get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const
+ {
+ if (counts[symbol].present == false)
+ return 0;
+
+ total_count = total;
+ unsigned long low_count_temp = 0;
+ data* start = counts;
+ while (start->symbol != symbol)
+ {
+ low_count_temp += start->count;
+ ++start;
+ }
+
+ low_count = low_count_temp;
+ high_count = low_count_temp + start->count;
+ return start->count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size
+ >
+ void conditioning_class_kernel_3<alphabet_size>::
+ get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const
+ {
+ unsigned long high_count_temp = counts->count;
+ const data* start = counts;
+ while (target >= high_count_temp)
+ {
+ ++start;
+ high_count_temp += start->count;
+ }
+
+ low_count = high_count_temp - start->count;
+ high_count = high_count_temp;
+ symbol = static_cast<unsigned long>(start->symbol);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_3_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h
new file mode 100644
index 000000000..cb48ac196
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h
@@ -0,0 +1,533 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_
+#define DLIB_CONDITIONING_CLASS_KERNEl_4_
+
+#include "conditioning_class_kernel_abstract.h"
+#include "../assert.h"
+#include "../algs.h"
+
+namespace dlib
+{
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ class conditioning_class_kernel_4
+ {
+ /*!
+ REQUIREMENTS ON pool_size
+ pool_size > 0
+ this will be the number of nodes contained in our memory pool
+
+ REQUIREMENTS ON mem_manager
+ mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h
+
+ INITIAL VALUE
+ total == 1
+ escapes == 1
+ next == 0
+
+ CONVENTION
+ get_total() == total
+ get_count(alphabet_size-1) == escapes
+
+ if (next != 0) then
+ next == pointer to the start of a linked list and the linked list
+ is terminated by a node with a next pointer of 0.
+
+ get_count(symbol) == node::count for the node where node::symbol==symbol
+ or 0 if no such node currently exists.
+
+ if (there is a node for the symbol) then
+ LOW_COUNT(symbol) == the sum of all node's counts in the linked list
+ up to but not including the node for the symbol.
+
+ get_memory_usage() == global_state.memory_usage
+ !*/
+
+
+ struct node
+ {
+ unsigned short symbol;
+ unsigned short count;
+ node* next;
+ };
+
+ public:
+
+ class global_state_type
+ {
+ public:
+ global_state_type (
+ ) :
+ memory_usage(pool_size*sizeof(node)+sizeof(global_state_type))
+ {}
+ private:
+ unsigned long memory_usage;
+
+ typename mem_manager::template rebind<node>::other pool;
+
+ friend class conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>;
+ };
+
+ conditioning_class_kernel_4 (
+ global_state_type& global_state_
+ );
+
+ ~conditioning_class_kernel_4 (
+ );
+
+ void clear(
+ );
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+
+ inline unsigned long get_total (
+ ) const;
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+
+ unsigned long get_memory_usage (
+ ) const;
+
+ global_state_type& get_global_state (
+ );
+
+ static unsigned long get_alphabet_size (
+ );
+
+
+ private:
+
+ void half_counts (
+ );
+ /*!
+ ensures
+ - divides all counts by 2 but ensures that escapes is always at least 1
+ !*/
+
+ // restricted functions
+ conditioning_class_kernel_4(conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>&); // copy constructor
+ conditioning_class_kernel_4& operator=(conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>&); // assignment operator
+
+ // data members
+ unsigned short total;
+ unsigned short escapes;
+ node* next;
+ global_state_type& global_state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ conditioning_class_kernel_4 (
+ global_state_type& global_state_
+ ) :
+ total(1),
+ escapes(1),
+ next(0),
+ global_state(global_state_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );
+
+ // update memory usage
+ global_state.memory_usage += sizeof(conditioning_class_kernel_4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ ~conditioning_class_kernel_4 (
+ )
+ {
+ clear();
+ // update memory usage
+ global_state.memory_usage -= sizeof(conditioning_class_kernel_4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ clear(
+ )
+ {
+ total = 1;
+ escapes = 1;
+ while (next)
+ {
+ node* temp = next;
+ next = next->next;
+ global_state.pool.deallocate(temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_memory_usage(
+ ) const
+ {
+ return global_state.memory_usage;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ typename conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::global_state_type& conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_global_state(
+ )
+ {
+ return global_state;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ bool conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ increment_count (
+ unsigned long symbol,
+ unsigned short amount
+ )
+ {
+ if (symbol == alphabet_size-1)
+ {
+ // make sure we won't cause any overflow
+ if (total >= 65536 - amount )
+ half_counts();
+
+ escapes += amount;
+ total += amount;
+ return true;
+ }
+
+
+ // find the symbol and increment it or add a new node to the list
+ if (next)
+ {
+ node* temp = next;
+ node* previous = 0;
+ while (true)
+ {
+ if (temp->symbol == static_cast<unsigned short>(symbol))
+ {
+ // make sure we won't cause any overflow
+ if (total >= 65536 - amount )
+ half_counts();
+
+ // we have found the symbol
+ total += amount;
+ temp->count += amount;
+
+ // if this node now has a count greater than its parent node
+ if (previous && temp->count > previous->count)
+ {
+ // swap the nodes so that the nodes will be in semi-sorted order
+ swap(temp->count,previous->count);
+ swap(temp->symbol,previous->symbol);
+ }
+ return true;
+ }
+ else if (temp->next == 0)
+ {
+ // we did not find the symbol so try to add it to the list
+ if (global_state.pool.get_number_of_allocations() < pool_size)
+ {
+ // make sure we won't cause any overflow
+ if (total >= 65536 - amount )
+ half_counts();
+
+ node* t = global_state.pool.allocate();
+ t->next = 0;
+ t->symbol = static_cast<unsigned short>(symbol);
+ t->count = amount;
+ temp->next = t;
+ total += amount;
+ return true;
+ }
+ else
+ {
+ // no memory left
+ return false;
+ }
+ }
+ else if (temp->count == 0)
+ {
+ // remove nodes that have a zero count
+ if (previous)
+ {
+ previous->next = temp->next;
+ node* t = temp;
+ temp = temp->next;
+ global_state.pool.deallocate(t);
+ }
+ else
+ {
+ next = temp->next;
+ node* t = temp;
+ temp = temp->next;
+ global_state.pool.deallocate(t);
+ }
+ }
+ else
+ {
+ previous = temp;
+ temp = temp->next;
+ }
+ } // while (true)
+ }
+ // if there aren't any nodes in the list yet then do this instead
+ else
+ {
+ if (global_state.pool.get_number_of_allocations() < pool_size)
+ {
+ // make sure we won't cause any overflow
+ if (total >= 65536 - amount )
+ half_counts();
+
+ next = global_state.pool.allocate();
+ next->next = 0;
+ next->symbol = static_cast<unsigned short>(symbol);
+ next->count = amount;
+ total += amount;
+ return true;
+ }
+ else
+ {
+ // no memory left
+ return false;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_count (
+ unsigned long symbol
+ ) const
+ {
+ if (symbol == alphabet_size-1)
+ {
+ return escapes;
+ }
+ else
+ {
+ node* temp = next;
+ while (temp)
+ {
+ if (temp->symbol == symbol)
+ return temp->count;
+ temp = temp->next;
+ }
+ return 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_alphabet_size (
+ )
+ {
+ return alphabet_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_total (
+ ) const
+ {
+ return total;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const
+ {
+ if (symbol != alphabet_size-1)
+ {
+ node* temp = next;
+ unsigned long low = 0;
+ while (temp)
+ {
+ if (temp->symbol == static_cast<unsigned short>(symbol))
+ {
+ high_count = temp->count + low;
+ low_count = low;
+ total_count = total;
+ return temp->count;
+ }
+ low += temp->count;
+ temp = temp->next;
+ }
+ return 0;
+ }
+ else
+ {
+ total_count = total;
+ high_count = total;
+ low_count = total-escapes;
+ return escapes;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const
+ {
+ node* temp = next;
+ unsigned long high = 0;
+ while (true)
+ {
+ if (temp != 0)
+ {
+ high += temp->count;
+ if (target < high)
+ {
+ symbol = temp->symbol;
+ high_count = high;
+ low_count = high - temp->count;
+ return;
+ }
+ temp = temp->next;
+ }
+ else
+ {
+ // this must be the escape symbol
+ symbol = alphabet_size-1;
+ low_count = total-escapes;
+ high_count = total;
+ return;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ unsigned long alphabet_size,
+ unsigned long pool_size,
+ typename mem_manager
+ >
+ void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
+ half_counts (
+ )
+ {
+ total = 0;
+ if (escapes > 1)
+ escapes >>= 1;
+
+ //divide all counts by 2
+ node* temp = next;
+ while (temp)
+ {
+ temp->count >>= 1;
+ total += temp->count;
+ temp = temp->next;
+ }
+ total += escapes;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_4_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h
new file mode 100644
index 000000000..411aea566
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h
@@ -0,0 +1,228 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_
+#ifdef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size
+ >
+ class conditioning_class
+ {
+ /*!
+ REQUIREMENTS ON alphabet_size
+ 1 < alphabet_size < 65536
+
+ INITIAL VALUE
+ get_total() == 1
+ get_count(X) == 0 : for all valid values of X except alphabet_size-1
+ get_count(alphabet_size-1) == 1
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a conditioning class used for arithmetic style
+ compression. It maintains the cumulative counts which are needed
+ by the entropy_coder and entropy_decoder objects.
+
+ At any moment a conditioning_class object represents a set of
+ alphabet_size symbols. Each symbol is associated with an integer
+ called its count.
+
+ All symbols start out with a count of zero except for alphabet_size-1.
+ This last symbol will always have a count of at least one. It is
+ intended to be used as an escape into a lower context when coding
+ and so it must never have a zero probability or the decoder won't
+ be able to identify the escape symbol.
+
+ NOTATION:
+ Let MAP(i) be a function which maps integers to symbols. MAP(i) is
+ one to one and onto. Its domain is 1 to alphabet_size inclusive.
+
+ Let RMAP(s) be the inverse of MAP(i).
+ ( i.e. RMAP(MAP(i)) == i and MAP(RMAP(s)) == s )
+
+ Let COUNT(i) give the count for the symbol MAP(i).
+ ( i.e. COUNT(i) == get_count(MAP(i)) )
+
+
+ Let LOW_COUNT(s) == the sum of COUNT(x) for x == 1 to x == RMAP(s)-1
+ (note that the sum of COUNT(x) for x == 1 to x == 0 is 0)
+ Let HIGH_COUNT(s) == LOW_COUNT(s) + get_count(s)
+
+
+
+ Basically what this is saying is just that you shoudln't assume you know
+ what order the symbols are placed in when calculating the cumulative
+ sums. The specific mapping provided by the MAP() function is unspecified.
+
+ THREAD SAFETY
+ This object can be used safely in a multithreaded program as long as the
+ global state is not shared between conditioning classes which run on
+ different threads.
+
+ GLOBAL_STATE_TYPE
+ The global_state_type obejct allows instances of the conditioning_class
+ object to share any kind of global state the implementer desires.
+ However, the global_state_type object exists primarily to facilitate the
+ sharing of a memory pool between many instances of a conditioning_class
+ object. But note that it is not required that there be any kind of
+ memory pool at all, it is just a possibility.
+ !*/
+
+ public:
+
+ class global_state_type
+ {
+ global_state_type (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ // my contents are implementation specific.
+ };
+
+ conditioning_class (
+ global_state_type& global_state
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - &#get_global_state() == &global_state
+ throws
+ - std::bad_alloc
+ !*/
+
+ ~conditioning_class (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ !*/
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+ /*!
+ requires
+ - 0 <= symbol < alphabet_size
+ - 0 < amount < 32768
+ ensures
+ - if (sufficient memory is available to complete this operation) then
+ - returns true
+ - if (get_total()+amount < 65536) then
+ - #get_count(symbol) == get_count(symbol) + amount
+ - else
+ - #get_count(symbol) == get_count(symbol)/2 + amount
+ - if (get_count(alphabet_size-1) == 1) then
+ - #get_count(alphabet_size-1) == 1
+ - else
+ - #get_count(alphabet_size-1) == get_count(alphabet_size-1)/2
+ - for all X where (X != symbol)&&(X != alpahbet_size-1):
+ #get_count(X) == get_count(X)/2
+ - else
+ - returns false
+ !*/
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+ /*!
+ requires
+ - 0 <= symbol < alphabet_size
+ ensures
+ - returns the count for the specified symbol
+ !*/
+
+ unsigned long get_total (
+ ) const;
+ /*!
+ ensures
+ - returns the sum of get_count(X) for all valid values of X
+ (i.e. returns the sum of the counts for all the symbols)
+ !*/
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+ /*!
+ requires
+ - 0 <= symbol < alphabet_size
+ ensures
+ - returns get_count(symbol)
+ - if (get_count(symbol) != 0) then
+ - #total_count == get_total()
+ - #low_count == LOW_COUNT(symbol)
+ - #high_count == HIGH_COUNT(symbol)
+ - #low_count < #high_count <= #total_count
+ !*/
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+ /*!
+ requires
+ - 0 <= target < get_total()
+ ensures
+ - LOW_COUNT(#symbol) <= target < HIGH_COUNT(#symbol)
+ - #low_count == LOW_COUNT(#symbol)
+ - #high_count == HIGH_COUNT(#symbol)
+ - #low_count < #high_count <= get_total()
+ !*/
+
+ global_state_type& get_global_state (
+ );
+ /*!
+ ensures
+ - returns a reference to the global state used by *this
+ !*/
+
+ unsigned long get_memory_usage (
+ ) const;
+ /*!
+ ensures
+ - returns the number of bytes of memory allocated by all conditioning_class
+ objects that share the global state given by get_global_state()
+ !*/
+
+ static unsigned long get_alphabet_size (
+ );
+ /*!
+ ensures
+ - returns alphabet_size
+ !*/
+
+ private:
+
+ // restricted functions
+ conditioning_class(conditioning_class<alphabet_size>&); // copy constructor
+ conditioning_class<alphabet_size>& operator=(conditioning_class<alphabet_size>&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h
new file mode 100644
index 000000000..964240be8
--- /dev/null
+++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONDITIONING_CLASS_KERNEl_C_
+#define DLIB_CONDITIONING_CLASS_KERNEl_C_
+
+#include "conditioning_class_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename cc_base
+ >
+ class conditioning_class_kernel_c : public cc_base
+ {
+ const unsigned long alphabet_size;
+
+ public:
+
+ conditioning_class_kernel_c (
+ typename cc_base::global_state_type& global_state
+ ) : cc_base(global_state),alphabet_size(cc_base::get_alphabet_size()) {}
+
+ bool increment_count (
+ unsigned long symbol,
+ unsigned short amount = 1
+ );
+
+ unsigned long get_count (
+ unsigned long symbol
+ ) const;
+
+ unsigned long get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const;
+
+ void get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename cc_base
+ >
+ bool conditioning_class_kernel_c<cc_base>::
+ increment_count (
+ unsigned long symbol,
+ unsigned short amount
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(symbol < alphabet_size &&
+ 0 < amount && amount < 32768,
+ "\tvoid conditioning_class::increment_count()"
+ << "\n\tthe symbol must be in the range 0 to alphabet_size-1. and"
+ << "\n\tamount must be in the range 1 to 32767"
+ << "\n\talphabet_size: " << alphabet_size
+ << "\n\tsymbol: " << symbol
+ << "\n\tamount: " << amount
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return cc_base::increment_count(symbol,amount);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename cc_base
+ >
+ unsigned long conditioning_class_kernel_c<cc_base>::
+ get_count (
+ unsigned long symbol
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(symbol < alphabet_size,
+ "\tvoid conditioning_class::get_count()"
+ << "\n\tthe symbol must be in the range 0 to alphabet_size-1"
+ << "\n\talphabet_size: " << alphabet_size
+ << "\n\tsymbol: " << symbol
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return cc_base::get_count(symbol);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename cc_base
+ >
+ unsigned long conditioning_class_kernel_c<cc_base>::
+ get_range (
+ unsigned long symbol,
+ unsigned long& low_count,
+ unsigned long& high_count,
+ unsigned long& total_count
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(symbol < alphabet_size,
+ "\tvoid conditioning_class::get_range()"
+ << "\n\tthe symbol must be in the range 0 to alphabet_size-1"
+ << "\n\talphabet_size: " << alphabet_size
+ << "\n\tsymbol: " << symbol
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return cc_base::get_range(symbol,low_count,high_count,total_count);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename cc_base
+ >
+ void conditioning_class_kernel_c<cc_base>::
+ get_symbol (
+ unsigned long target,
+ unsigned long& symbol,
+ unsigned long& low_count,
+ unsigned long& high_count
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( target < this->get_total(),
+ "\tvoid conditioning_class::get_symbol()"
+ << "\n\tthe target must be in the range 0 to get_total()-1"
+ << "\n\tget_total(): " << this->get_total()
+ << "\n\ttarget: " << target
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ cc_base::get_symbol(target,symbol,low_count,high_count);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONDITIONING_CLASS_KERNEl_C_
+
diff --git a/ml/dlib/dlib/config.h b/ml/dlib/dlib/config.h
new file mode 100644
index 000000000..a48ac415f
--- /dev/null
+++ b/ml/dlib/dlib/config.h
@@ -0,0 +1,31 @@
+
+
+// If you are compiling dlib as a shared library and installing it somewhere on your system
+// then it is important that any programs that use dlib agree on the state of the
+// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore,
+// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or
+// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle
+// automatically depending on the state of certain other macros, which is not what you want
+// when creating a shared library.
+//#define ENABLE_ASSERTS // asserts always enabled
+//#define DLIB_DISABLE_ASSERTS // asserts always disabled
+
+//#define DLIB_ISO_CPP_ONLY
+//#define DLIB_NO_GUI_SUPPORT
+//#define DLIB_ENABLE_STACK_TRACE
+
+// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA,
+// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines.
+// #define DLIB_JPEG_SUPPORT
+// #define DLIB_PNG_SUPPORT
+// #define DLIB_GIF_SUPPORT
+// #define DLIB_USE_FFTW
+// #define DLIB_USE_BLAS
+// #define DLIB_USE_LAPACK
+// #define DLIB_USE_CUDA
+
+
+// Define this so the code in dlib/test_for_odr_violations.h can detect ODR violations
+// related to users doing bad things with config.h
+#define DLIB_NOT_CONFIGURED
+
diff --git a/ml/dlib/dlib/config.h.in b/ml/dlib/dlib/config.h.in
new file mode 100644
index 000000000..27dff2b27
--- /dev/null
+++ b/ml/dlib/dlib/config.h.in
@@ -0,0 +1,34 @@
+
+
+// If you are compiling dlib as a shared library and installing it somewhere on your system
+// then it is important that any programs that use dlib agree on the state of the
+// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore,
+// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or
+// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle
+// automatically depending on the state of certain other macros, which is not what you want
+// when creating a shared library.
+#cmakedefine ENABLE_ASSERTS // asserts always enabled
+#cmakedefine DLIB_DISABLE_ASSERTS // asserts always disabled
+
+#cmakedefine DLIB_ISO_CPP_ONLY
+#cmakedefine DLIB_NO_GUI_SUPPORT
+#cmakedefine DLIB_ENABLE_STACK_TRACE
+
+#cmakedefine LAPACK_FORCE_UNDERSCORE
+#cmakedefine LAPACK_FORCE_NOUNDERSCORE
+
+// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA,
+// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines.
+#cmakedefine DLIB_JPEG_SUPPORT
+#cmakedefine DLIB_PNG_SUPPORT
+#cmakedefine DLIB_GIF_SUPPORT
+#cmakedefine DLIB_USE_FFTW
+#cmakedefine DLIB_USE_BLAS
+#cmakedefine DLIB_USE_LAPACK
+#cmakedefine DLIB_USE_CUDA
+#cmakedefine DLIB_USE_MKL_FFT
+
+// This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use
+// headers from one version of dlib with a compiled dlib binary from a different dlib version.
+#cmakedefine DLIB_CHECK_FOR_VERSION_MISMATCH @DLIB_CHECK_FOR_VERSION_MISMATCH@
+
diff --git a/ml/dlib/dlib/config_reader.h b/ml/dlib/dlib/config_reader.h
new file mode 100644
index 000000000..d140a310c
--- /dev/null
+++ b/ml/dlib/dlib/config_reader.h
@@ -0,0 +1,39 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONFIG_READEr_
+#define DLIB_CONFIG_READEr_
+
+#include "config_reader/config_reader_kernel_1.h"
+#include "map.h"
+#include "tokenizer.h"
+#include "cmd_line_parser/get_option.h"
+
+#include "algs.h"
+#include "is_kind.h"
+
+
+namespace dlib
+{
+
+ typedef config_reader_kernel_1<
+ map<std::string,std::string>::kernel_1b,
+ map<std::string,void*>::kernel_1b,
+ tokenizer::kernel_1a
+ > config_reader;
+
+ template <> struct is_config_reader<config_reader> { const static bool value = true; };
+
+#ifndef DLIB_ISO_CPP_ONLY
+ typedef config_reader_thread_safe_1<
+ config_reader,
+ map<std::string,void*>::kernel_1b
+ > config_reader_thread_safe;
+
+ template <> struct is_config_reader<config_reader_thread_safe> { const static bool value = true; };
+#endif // DLIB_ISO_CPP_ONLY
+
+
+}
+
+#endif // DLIB_CONFIG_READEr_
+
diff --git a/ml/dlib/dlib/config_reader/config_reader_kernel_1.h b/ml/dlib/dlib/config_reader/config_reader_kernel_1.h
new file mode 100644
index 000000000..c0f9e5a71
--- /dev/null
+++ b/ml/dlib/dlib/config_reader/config_reader_kernel_1.h
@@ -0,0 +1,738 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONFIG_READER_KERNEl_1_
+#define DLIB_CONFIG_READER_KERNEl_1_
+
+#include "config_reader_kernel_abstract.h"
+#include <string>
+#include <iostream>
+#include <sstream>
+#include <fstream>
+#include "../algs.h"
+#include "../stl_checked/std_vector_c.h"
+
+#ifndef DLIB_ISO_CPP_ONLY
+#include "config_reader_thread_safe_1.h"
+#endif
+
+namespace dlib
+{
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ class config_reader_kernel_1
+ {
+
+ /*!
+ REQUIREMENTS ON map_string_string
+ is an implementation of map/map_kernel_abstract.h that maps std::string to std::string
+
+ REQUIREMENTS ON map_string_void
+ is an implementation of map/map_kernel_abstract.h that maps std::string to void*
+
+ REQUIREMENTS ON tokenizer
+ is an implementation of tokenizer/tokenizer_kernel_abstract.h
+
+ CONVENTION
+ key_table.is_in_domain(x) == is_key_defined(x)
+ block_table.is_in_domain(x) == is_block_defined(x)
+
+ key_table[x] == operator[](x)
+ block_table[x] == (void*)&block(x)
+ !*/
+
+ public:
+
+ // These two typedefs are defined for backwards compatibility with older versions of dlib.
+ typedef config_reader_kernel_1 kernel_1a;
+#ifndef DLIB_ISO_CPP_ONLY
+ typedef config_reader_thread_safe_1<
+ config_reader_kernel_1,
+ map_string_void
+ > thread_safe_1a;
+#endif // DLIB_ISO_CPP_ONLY
+
+
+ config_reader_kernel_1();
+
+ class config_reader_error : public dlib::error
+ {
+ friend class config_reader_kernel_1;
+ config_reader_error(
+ unsigned long ln,
+ bool r = false
+ ) :
+ dlib::error(ECONFIG_READER),
+ line_number(ln),
+ redefinition(r)
+ {
+ std::ostringstream sout;
+ sout << "Error in config_reader while parsing at line number " << line_number << ".";
+ if (redefinition)
+ sout << "\nThe identifier on this line has already been defined in this scope.";
+ const_cast<std::string&>(info) = sout.str();
+ }
+ public:
+ const unsigned long line_number;
+ const bool redefinition;
+ };
+
+ class file_not_found : public dlib::error
+ {
+ friend class config_reader_kernel_1;
+ file_not_found(
+ const std::string& file_name_
+ ) :
+ dlib::error(ECONFIG_READER, "Error in config_reader, unable to open file " + file_name_),
+ file_name(file_name_)
+ {}
+
+ ~file_not_found() throw() {}
+
+ public:
+ const std::string file_name;
+ };
+
+ class config_reader_access_error : public dlib::error
+ {
+ public:
+ config_reader_access_error(
+ const std::string& block_name_,
+ const std::string& key_name_
+ ) :
+ dlib::error(ECONFIG_READER),
+ block_name(block_name_),
+ key_name(key_name_)
+ {
+ std::ostringstream sout;
+ sout << "Error in config_reader.\n";
+ if (block_name.size() > 0)
+ sout << " A block with the name '" << block_name << "' was expected but not found.";
+ else if (key_name.size() > 0)
+ sout << " A key with the name '" << key_name << "' was expected but not found.";
+
+ const_cast<std::string&>(info) = sout.str();
+ }
+
+ ~config_reader_access_error() throw() {}
+ const std::string block_name;
+ const std::string key_name;
+ };
+
+ config_reader_kernel_1(
+ const std::string& config_file
+ );
+
+ config_reader_kernel_1(
+ std::istream& in
+ );
+
+ virtual ~config_reader_kernel_1(
+ );
+
+ void clear (
+ );
+
+ void load_from (
+ std::istream& in
+ );
+
+ void load_from (
+ const std::string& config_file
+ );
+
+ bool is_key_defined (
+ const std::string& key
+ ) const;
+
+ bool is_block_defined (
+ const std::string& name
+ ) const;
+
+ typedef config_reader_kernel_1 this_type;
+ const this_type& block (
+ const std::string& name
+ ) const;
+
+ const std::string& operator[] (
+ const std::string& key
+ ) const;
+
+ template <
+ typename queue_of_strings
+ >
+ void get_keys (
+ queue_of_strings& keys
+ ) const;
+
+ template <
+ typename alloc
+ >
+ void get_keys (
+ std::vector<std::string,alloc>& keys
+ ) const;
+
+ template <
+ typename alloc
+ >
+ void get_keys (
+ std_vector_c<std::string,alloc>& keys
+ ) const;
+
+ template <
+ typename queue_of_strings
+ >
+ void get_blocks (
+ queue_of_strings& blocks
+ ) const;
+
+ template <
+ typename alloc
+ >
+ void get_blocks (
+ std::vector<std::string,alloc>& blocks
+ ) const;
+
+ template <
+ typename alloc
+ >
+ void get_blocks (
+ std_vector_c<std::string,alloc>& blocks
+ ) const;
+
+ private:
+
+ static void parse_config_file (
+ config_reader_kernel_1& cr,
+ tokenizer& tok,
+ unsigned long& line_number,
+ const bool top_of_recursion = true
+ );
+ /*!
+ requires
+ - line_number == 1
+ - cr == *this
+ - top_of_recursion == true
+ ensures
+ - parses the data coming from tok and puts it into cr.
+ throws
+ - config_reader_error
+ !*/
+
+ map_string_string key_table;
+ map_string_void block_table;
+
+ // restricted functions
+ config_reader_kernel_1(config_reader_kernel_1&);
+ config_reader_kernel_1& operator=(config_reader_kernel_1&);
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ config_reader_kernel_1(
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ clear(
+ )
+ {
+ // free all our blocks
+ block_table.reset();
+ while (block_table.move_next())
+ {
+ delete static_cast<config_reader_kernel_1*>(block_table.element().value());
+ }
+ block_table.clear();
+ key_table.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ load_from(
+ std::istream& in
+ )
+ {
+ clear();
+
+ tokenizer tok;
+ tok.set_stream(in);
+ tok.set_identifier_token(
+ tok.lowercase_letters() + tok.uppercase_letters(),
+ tok.lowercase_letters() + tok.uppercase_letters() + tok.numbers() + "_-."
+ );
+
+ unsigned long line_number = 1;
+ try
+ {
+ parse_config_file(*this,tok,line_number);
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ load_from(
+ const std::string& config_file
+ )
+ {
+ clear();
+ std::ifstream fin(config_file.c_str());
+ if (!fin)
+ throw file_not_found(config_file);
+
+ load_from(fin);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ config_reader_kernel_1(
+ std::istream& in
+ )
+ {
+ load_from(in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ config_reader_kernel_1(
+ const std::string& config_file
+ )
+ {
+ load_from(config_file);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ parse_config_file(
+ config_reader_kernel_1<map_string_string,map_string_void,tokenizer>& cr,
+ tokenizer& tok,
+ unsigned long& line_number,
+ const bool top_of_recursion
+ )
+ {
+ int type;
+ std::string token;
+ bool in_comment = false;
+ bool seen_identifier = false;
+ std::string identifier;
+ while (true)
+ {
+ tok.get_token(type,token);
+ // ignore white space
+ if (type == tokenizer::WHITE_SPACE)
+ continue;
+
+ // basically ignore end of lines
+ if (type == tokenizer::END_OF_LINE)
+ {
+ ++line_number;
+ in_comment = false;
+ continue;
+ }
+
+ // we are in a comment still so ignore this
+ if (in_comment)
+ continue;
+
+ // if this is the start of a comment
+ if (type == tokenizer::CHAR && token[0] == '#')
+ {
+ in_comment = true;
+ continue;
+ }
+
+ // if this is the case then we have just finished parsing a block so we should
+ // quit this function
+ if ( (type == tokenizer::CHAR && token[0] == '}' && !top_of_recursion) ||
+ (type == tokenizer::END_OF_FILE && top_of_recursion) )
+ {
+ break;
+ }
+
+ if (seen_identifier)
+ {
+ seen_identifier = false;
+ // the next character should be either a '=' or a '{'
+ if (type != tokenizer::CHAR || (token[0] != '=' && token[0] != '{'))
+ throw config_reader_error(line_number);
+
+ if (token[0] == '=')
+ {
+ // we should parse the value out now
+ // first discard any white space
+ if (tok.peek_type() == tokenizer::WHITE_SPACE)
+ tok.get_token(type,token);
+
+ std::string value;
+ type = tok.peek_type();
+ token = tok.peek_token();
+ while (true)
+ {
+ if (type == tokenizer::END_OF_FILE || type == tokenizer::END_OF_LINE)
+ break;
+
+ if (type == tokenizer::CHAR && token[0] == '\\')
+ {
+ tok.get_token(type,token);
+ if (tok.peek_type() == tokenizer::CHAR &&
+ tok.peek_token()[0] == '#')
+ {
+ tok.get_token(type,token);
+ value += '#';
+ }
+ else if (tok.peek_type() == tokenizer::CHAR &&
+ tok.peek_token()[0] == '}')
+ {
+ tok.get_token(type,token);
+ value += '}';
+ }
+ else
+ {
+ value += '\\';
+ }
+ }
+ else if (type == tokenizer::CHAR &&
+ (token[0] == '#' || token[0] == '}'))
+ {
+ break;
+ }
+ else
+ {
+ value += token;
+ tok.get_token(type,token);
+ }
+ type = tok.peek_type();
+ token = tok.peek_token();
+ } // while(true)
+
+ // strip of any tailing white space from value
+ std::string::size_type pos = value.find_last_not_of(" \t\r\n");
+ if (pos == std::string::npos)
+ value.clear();
+ else
+ value.erase(pos+1);
+
+ // make sure this key isn't already in the key_table
+ if (cr.key_table.is_in_domain(identifier))
+ throw config_reader_error(line_number,true);
+
+ // add this key/value pair to the key_table
+ cr.key_table.add(identifier,value);
+
+ }
+ else // when token[0] == '{'
+ {
+ // make sure this identifier isn't already in the block_table
+ if (cr.block_table.is_in_domain(identifier))
+ throw config_reader_error(line_number,true);
+
+ config_reader_kernel_1* new_cr = new config_reader_kernel_1;
+ void* vtemp = new_cr;
+ try { cr.block_table.add(identifier,vtemp); }
+ catch (...) { delete new_cr; throw; }
+
+ // now parse this block
+ parse_config_file(*new_cr,tok,line_number,false);
+ }
+ }
+ else
+ {
+ // the next thing should be an identifier but if it isn't this is an error
+ if (type != tokenizer::IDENTIFIER)
+ throw config_reader_error(line_number);
+
+ seen_identifier = true;
+ identifier = token;
+ }
+ } // while (true)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ ~config_reader_kernel_1(
+ )
+ {
+ clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ bool config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ is_key_defined (
+ const std::string& key
+ ) const
+ {
+ return key_table.is_in_domain(key);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ bool config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ is_block_defined (
+ const std::string& name
+ ) const
+ {
+ return block_table.is_in_domain(name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename mss,
+ typename msv,
+ typename tokenizer
+ >
+ const config_reader_kernel_1<mss,msv,tokenizer>& config_reader_kernel_1<mss,msv,tokenizer>::
+ block (
+ const std::string& name
+ ) const
+ {
+ if (is_block_defined(name) == false)
+ {
+ throw config_reader_access_error(name,"");
+ }
+
+ return *static_cast<config_reader_kernel_1*>(block_table[name]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ const std::string& config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ operator[] (
+ const std::string& key
+ ) const
+ {
+ if (is_key_defined(key) == false)
+ {
+ throw config_reader_access_error("",key);
+ }
+
+ return key_table[key];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename queue_of_strings
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_keys (
+ queue_of_strings& keys
+ ) const
+ {
+ keys.clear();
+ key_table.reset();
+ std::string temp;
+ while (key_table.move_next())
+ {
+ temp = key_table.element().key();
+ keys.enqueue(temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename alloc
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_keys (
+ std::vector<std::string,alloc>& keys
+ ) const
+ {
+ keys.clear();
+ key_table.reset();
+ while (key_table.move_next())
+ {
+ keys.push_back(key_table.element().key());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename alloc
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_keys (
+ std_vector_c<std::string,alloc>& keys
+ ) const
+ {
+ keys.clear();
+ key_table.reset();
+ while (key_table.move_next())
+ {
+ keys.push_back(key_table.element().key());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename queue_of_strings
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_blocks (
+ queue_of_strings& blocks
+ ) const
+ {
+ blocks.clear();
+ block_table.reset();
+ std::string temp;
+ while (block_table.move_next())
+ {
+ temp = block_table.element().key();
+ blocks.enqueue(temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename alloc
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_blocks (
+ std::vector<std::string,alloc>& blocks
+ ) const
+ {
+ blocks.clear();
+ block_table.reset();
+ while (block_table.move_next())
+ {
+ blocks.push_back(block_table.element().key());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_string_string,
+ typename map_string_void,
+ typename tokenizer
+ >
+ template <
+ typename alloc
+ >
+ void config_reader_kernel_1<map_string_string,map_string_void,tokenizer>::
+ get_blocks (
+ std_vector_c<std::string,alloc>& blocks
+ ) const
+ {
+ blocks.clear();
+ block_table.reset();
+ while (block_table.move_next())
+ {
+ blocks.push_back(block_table.element().key());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONFIG_READER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h b/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h
new file mode 100644
index 000000000..e8c44c2b2
--- /dev/null
+++ b/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h
@@ -0,0 +1,363 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CONFIG_READER_KERNEl_ABSTRACT_
+#ifdef DLIB_CONFIG_READER_KERNEl_ABSTRACT_
+
+#include <string>
+#include <iosfwd>
+
+namespace dlib
+{
+
+ class config_reader
+ {
+
+ /*!
+ INITIAL VALUE
+ - there aren't any keys defined for this object
+ - there aren't any blocks defined for this object
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ The destructor, clear(), and load_from() invalidate pointers
+ and references to internal data. All other functions are guaranteed
+ to NOT invalidate pointers or references to internal data.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something which is intended to be used to read
+ text configuration files that are defined by the following EBNF (with
+ config_file as the starting symbol):
+
+ config_file = block;
+ block = { key_value_pair | sub_block };
+ key_value_pair = key_name, "=", value;
+ sub_block = block_name, "{", block, "}";
+
+ key_name = identifier;
+ block_name = identifier;
+ value = matches any string of text that ends with a newline character, # or }.
+ note that the trailing newline, # or } is not part of the value though.
+ identifier = Any string that matches the following regular expression:
+ [a-zA-Z][a-zA-Z0-9_-\.]*
+ i.e. Any string that starts with a letter and then is continued
+ with any number of letters, numbers, _ . or - characters.
+
+ Whitespace and comments are ignored. A comment is text that starts with # (but not \#
+ since the \ escapes the # so that you can have a # symbol in a value if you want) and
+ ends in a new line. You can also escape a } (e.g. "\}") if you want to have one in a
+ value.
+
+ Note that in a value the leading and trailing white spaces are stripped off but any
+ white space inside the value is preserved.
+
+ Also note that all key_names and block_names within a block syntax group must be unique
+ but don't have to be globally unique. I.e. different blocks can reuse names.
+
+ EXAMPLE CONFIG FILES:
+
+ Example 1:
+ #comment. This line is ignored because it starts with #
+
+ #here we have key1 which will have the value of "my value"
+ key1 = my value
+
+ another_key= another value # this is another key called "another_key" with
+ # a value of "another value"
+
+ # this key's value is the empty string. I.e. ""
+ key2=
+
+ Example 2:
+ #this example illustrates the use of blocks
+ some_key = blah blah
+
+ # now here is a block
+ our_block
+ {
+ # here we can define some keys and values that are local to this block.
+ a_key = something
+ foo = bar
+ some_key = more stuff # note that it is ok to name our key this even though
+ # there is a key called some_key above. This is because
+ # we are doing so inside a different block
+ }
+
+ another_block { foo = bar2 } # this block has only one key and is all on a single line
+ !*/
+
+ public:
+
+ // exception classes
+ class config_reader_error : public dlib::error
+ {
+ /*!
+ GENERAL
+ This exception is thrown if there is an error while parsing the
+ config file. The type member of this exception will be set
+ to ECONFIG_READER.
+
+ INTERPRETING THIS EXCEPTION
+ - line_number == the line number the parser was at when the
+ error occurred.
+ - if (redefinition) then
+ - The key or block name on line line_number has already
+ been defined in this scope which is an error.
+ - else
+ - Some other general syntax error was detected
+ !*/
+ public:
+ const unsigned long line_number;
+ const bool redefinition;
+ };
+
+ class file_not_found : public dlib::error
+ {
+ /*!
+ GENERAL
+ This exception is thrown if the config file can't be opened for
+ some reason. The type member of this exception will be set
+ to ECONFIG_READER.
+
+ INTERPRETING THIS EXCEPTION
+ - file_name == the name of the config file which we failed to open
+ !*/
+ public:
+ const std::string file_name;
+ };
+
+
+ class config_reader_access_error : public dlib::error
+ {
+ /*!
+ GENERAL
+ This exception is thrown if you try to access a key or
+ block that doesn't exist inside a config reader. The type
+ member of this exception will be set to ECONFIG_READER.
+ !*/
+ public:
+ config_reader_access_error(
+ const std::string& block_name_,
+ const std::string& key_name_
+ );
+ /*!
+ ensures
+ - #block_name == block_name_
+ - #key_name == key_name_
+ !*/
+
+ const std::string block_name;
+ const std::string key_name;
+ };
+
+ // --------------------------
+
+ config_reader(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - This object will not have any keys or blocks defined in it.
+ throws
+ - std::bad_alloc
+ - config_reader_error
+ !*/
+
+ config_reader(
+ std::istream& in
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - reads the config file to parse from the given input stream,
+ parses it and loads this object up with all the sub blocks and
+ key/value pairs it finds.
+ - before the load is performed, the previous state of the config file
+ reader is erased. So after the load the config file reader will contain
+ only information from the given config file.
+ - This object will represent the top most block of the config file.
+ throws
+ - std::bad_alloc
+ - config_reader_error
+ !*/
+
+ config_reader(
+ const std::string& config_file
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - parses the config file named by the config_file string. Specifically,
+ parses it and loads this object up with all the sub blocks and
+ key/value pairs it finds in the file.
+ - before the load is performed, the previous state of the config file
+ reader is erased. So after the load the config file reader will contain
+ only information from the given config file.
+ - This object will represent the top most block of the config file.
+ throws
+ - std::bad_alloc
+ - config_reader_error
+ - file_not_found
+ !*/
+
+ virtual ~config_reader(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ If this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void load_from (
+ std::istream& in
+ );
+ /*!
+ ensures
+ - reads the config file to parse from the given input stream,
+ parses it and loads this object up with all the sub blocks and
+ key/value pairs it finds.
+ - before the load is performed, the previous state of the config file
+ reader is erased. So after the load the config file reader will contain
+ only information from the given config file.
+ - *this will represent the top most block of the config file contained
+ in the input stream in.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ - config_reader_error
+ If this exception is thrown then this object will
+ revert to its initial value.
+ !*/
+
+ void load_from (
+ const std::string& config_file
+ );
+ /*!
+ ensures
+ - parses the config file named by the config_file string. Specifically,
+ parses it and loads this object up with all the sub blocks and
+ key/value pairs it finds in the file.
+ - before the load is performed, the previous state of the config file
+ reader is erased. So after the load the config file reader will contain
+ only information from the given config file.
+ - This object will represent the top most block of the config file.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ - config_reader_error
+ If this exception is thrown then this object will
+ revert to its initial value.
+ - file_not_found
+ If this exception is thrown then this object will
+ revert to its initial value.
+ !*/
+
+ bool is_key_defined (
+ const std::string& key_name
+ ) const;
+ /*!
+ ensures
+ - if (there is a key with the given name defined within this config_reader's block) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_block_defined (
+ const std::string& block_name
+ ) const;
+ /*!
+ ensures
+ - if (there is a sub block with the given name defined within this config_reader's block) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ typedef config_reader this_type;
+ const this_type& block (
+ const std::string& block_name
+ ) const;
+ /*!
+ ensures
+ - if (is_block_defined(block_name) == true) then
+ - returns a const reference to the config_reader that represents the given named sub block
+ - else
+ - throws config_reader_access_error
+ throws
+ - config_reader_access_error
+ if this exception is thrown then its block_name field will be set to the
+ given block_name string.
+ !*/
+
+ const std::string& operator[] (
+ const std::string& key_name
+ ) const;
+ /*!
+ ensures
+ - if (is_key_defined(key_name) == true) then
+ - returns a const reference to the value string associated with the given key in
+ this config_reader's block.
+ - else
+ - throws config_reader_access_error
+ throws
+ - config_reader_access_error
+ if this exception is thrown then its key_name field will be set to the
+ given key_name string.
+ !*/
+
+ template <
+ typename queue_of_strings
+ >
+ void get_keys (
+ queue_of_strings& keys
+ ) const;
+ /*!
+ requires
+ - queue_of_strings is an implementation of queue/queue_kernel_abstract.h
+ with T set to std::string, or std::vector<std::string>, or
+ dlib::std_vector_c<std::string>
+ ensures
+ - #keys == a collection containing all the keys defined in this config_reader's block.
+ (i.e. for all strings str in keys it is the case that is_key_defined(str) == true)
+ !*/
+
+ template <
+ typename queue_of_strings
+ >
+ void get_blocks (
+ queue_of_strings& blocks
+ ) const;
+ /*!
+ requires
+ - queue_of_strings is an implementation of queue/queue_kernel_abstract.h
+ with T set to std::string, or std::vector<std::string>, or
+ dlib::std_vector_c<std::string>
+ ensures
+ - #blocks == a collection containing the names of all the blocks defined in this
+ config_reader's block.
+ (i.e. for all strings str in blocks it is the case that is_block_defined(str) == true)
+ !*/
+
+ private:
+
+ // restricted functions
+ config_reader(config_reader&); // copy constructor
+ config_reader& operator=(config_reader&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_CONFIG_READER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h b/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h
new file mode 100644
index 000000000..1ad250c99
--- /dev/null
+++ b/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h
@@ -0,0 +1,456 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONFIG_READER_THREAD_SAFe_
+#define DLIB_CONFIG_READER_THREAD_SAFe_
+
+#include "config_reader_kernel_abstract.h"
+#include <string>
+#include <iostream>
+#include <sstream>
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../threads.h"
+#include "config_reader_thread_safe_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ class config_reader_thread_safe_1
+ {
+
+ /*!
+ CONVENTION
+ - get_mutex() == *m
+ - *cr == the config reader being extended
+ - block_table[x] == (void*)&block(x)
+ - block_table.size() == the number of blocks in *cr
+ - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key)
+ - if (own_pointers) then
+ - this object owns the m and cr pointers and should delete them when destructed
+ !*/
+
+ public:
+
+ config_reader_thread_safe_1 (
+ const config_reader_base* base,
+ rmutex* m_
+ );
+
+ config_reader_thread_safe_1();
+
+ typedef typename config_reader_base::config_reader_error config_reader_error;
+ typedef typename config_reader_base::config_reader_access_error config_reader_access_error;
+
+ config_reader_thread_safe_1(
+ std::istream& in
+ );
+
+ config_reader_thread_safe_1(
+ const std::string& config_file
+ );
+
+ virtual ~config_reader_thread_safe_1(
+ );
+
+ void clear (
+ );
+
+ void load_from (
+ std::istream& in
+ );
+
+ void load_from (
+ const std::string& config_file
+ );
+
+ bool is_key_defined (
+ const std::string& key
+ ) const;
+
+ bool is_block_defined (
+ const std::string& name
+ ) const;
+
+ typedef config_reader_thread_safe_1 this_type;
+ const this_type& block (
+ const std::string& name
+ ) const;
+
+ const std::string& operator[] (
+ const std::string& key
+ ) const;
+
+ template <
+ typename queue_of_strings
+ >
+ void get_keys (
+ queue_of_strings& keys
+ ) const;
+
+ template <
+ typename queue_of_strings
+ >
+ void get_blocks (
+ queue_of_strings& blocks
+ ) const;
+
+ inline const rmutex& get_mutex (
+ ) const;
+
+ private:
+
+ void fill_block_table (
+ );
+ /*!
+ ensures
+ - block_table.size() == the number of blocks in cr
+ - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key)
+ !*/
+
+ rmutex* m;
+ config_reader_base* cr;
+ map_string_void block_table;
+ const bool own_pointers;
+
+ // restricted functions
+ config_reader_thread_safe_1(config_reader_thread_safe_1&);
+ config_reader_thread_safe_1& operator=(config_reader_thread_safe_1&);
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ config_reader_thread_safe_1(
+ const config_reader_base* base,
+ rmutex* m_
+ ) :
+ m(m_),
+ cr(const_cast<config_reader_base*>(base)),
+ own_pointers(false)
+ {
+ fill_block_table();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ config_reader_thread_safe_1(
+ ) :
+ m(0),
+ cr(0),
+ own_pointers(true)
+ {
+ try
+ {
+ m = new rmutex;
+ cr = new config_reader_base;
+ }
+ catch (...)
+ {
+ if (m) delete m;
+ if (cr) delete cr;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ clear(
+ )
+ {
+ auto_mutex M(*m);
+ cr->clear();
+ fill_block_table();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ load_from(
+ std::istream& in
+ )
+ {
+ auto_mutex M(*m);
+ cr->load_from(in);
+ fill_block_table();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ load_from(
+ const std::string& config_file
+ )
+ {
+ auto_mutex M(*m);
+ cr->load_from(config_file);
+ fill_block_table();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ config_reader_thread_safe_1(
+ std::istream& in
+ ) :
+ m(0),
+ cr(0),
+ own_pointers(true)
+ {
+ try
+ {
+ m = new rmutex;
+ cr = new config_reader_base(in);
+ fill_block_table();
+ }
+ catch (...)
+ {
+ if (m) delete m;
+ if (cr) delete cr;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ config_reader_thread_safe_1(
+ const std::string& config_file
+ ) :
+ m(0),
+ cr(0),
+ own_pointers(true)
+ {
+ try
+ {
+ m = new rmutex;
+ cr = new config_reader_base(config_file);
+ fill_block_table();
+ }
+ catch (...)
+ {
+ if (m) delete m;
+ if (cr) delete cr;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ ~config_reader_thread_safe_1(
+ )
+ {
+ if (own_pointers)
+ {
+ delete m;
+ delete cr;
+ }
+
+ // clear out the block table
+ block_table.reset();
+ while (block_table.move_next())
+ {
+ delete static_cast<config_reader_thread_safe_1*>(block_table.element().value());
+ }
+ block_table.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ bool config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ is_key_defined (
+ const std::string& key
+ ) const
+ {
+ auto_mutex M(*m);
+ return cr->is_key_defined(key);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ bool config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ is_block_defined (
+ const std::string& name
+ ) const
+ {
+ auto_mutex M(*m);
+ return cr->is_block_defined(name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ const config_reader_thread_safe_1<config_reader_base,map_string_void>& config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ block (
+ const std::string& name
+ ) const
+ {
+ auto_mutex M(*m);
+ if (block_table.is_in_domain(name) == false)
+ {
+ throw config_reader_access_error(name,"");
+ }
+
+ return *static_cast<config_reader_thread_safe_1*>(block_table[name]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ const std::string& config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ operator[] (
+ const std::string& key
+ ) const
+ {
+ auto_mutex M(*m);
+ return (*cr)[key];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ template <
+ typename queue_of_strings
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ get_keys (
+ queue_of_strings& keys
+ ) const
+ {
+ auto_mutex M(*m);
+ cr->get_keys(keys);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ template <
+ typename queue_of_strings
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ get_blocks (
+ queue_of_strings& blocks
+ ) const
+ {
+ auto_mutex M(*m);
+ cr->get_blocks(blocks);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ const rmutex& config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ get_mutex (
+ ) const
+ {
+ return *m;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// private member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename config_reader_base,
+ typename map_string_void
+ >
+ void config_reader_thread_safe_1<config_reader_base,map_string_void>::
+ fill_block_table (
+ )
+ {
+ using namespace std;
+ // first empty out the block table
+ block_table.reset();
+ while (block_table.move_next())
+ {
+ delete static_cast<config_reader_thread_safe_1*>(block_table.element().value());
+ }
+ block_table.clear();
+
+ std::vector<std::string> blocks;
+ cr->get_blocks(blocks);
+
+ // now fill the block table up to match what is in cr
+ for (unsigned long i = 0; i < blocks.size(); ++i)
+ {
+ config_reader_thread_safe_1* block = new config_reader_thread_safe_1(&cr->block(blocks[i]),m);
+ void* temp = block;
+ block_table.add(blocks[i],temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONFIG_READER_THREAD_SAFe_
+
+
diff --git a/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h b/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h
new file mode 100644
index 000000000..25bcbae4a
--- /dev/null
+++ b/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_
+#ifdef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_
+
+#include <string>
+#include <iosfwd>
+#include "config_reader_kernel_abstract.h"
+#include "../threads/threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+ class config_reader_thread_safe
+ {
+
+ /*!
+ WHAT THIS EXTENSION DOES FOR config_reader
+ This object extends a normal config_reader by simply wrapping all
+ its member functions inside mutex locks to make it safe to use
+ in a threaded program.
+
+ So this object provides an interface identical to the one defined
+ in the config_reader/config_reader_kernel_abstract.h file except that
+ the rmutex returned by get_mutex() is always locked when this
+ object's member functions are called.
+ !*/
+
+ public:
+
+ const rmutex& get_mutex (
+ ) const;
+ /*!
+ ensures
+ - returns the rmutex used to make this object thread safe. i.e. returns
+ the rmutex that is locked when this object's functions are called.
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/console_progress_indicator.h b/ml/dlib/dlib/console_progress_indicator.h
new file mode 100644
index 000000000..8f04aa533
--- /dev/null
+++ b/ml/dlib/dlib/console_progress_indicator.h
@@ -0,0 +1,207 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_
+#define DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_
+
+#include <ctime>
+#include <cmath>
+#include <limits>
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class console_progress_indicator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for reporting how long a task will take
+ to complete.
+
+ For example, consider the following bit of code:
+
+ console_progress_indicator pbar(100)
+ for (int i = 1; i <= 100; ++i)
+ {
+ pbar.print_status(i);
+ long_running_operation();
+ }
+
+ The above code will print a message to the console each iteration
+ which shows how much time is remaining until the loop terminates.
+ !*/
+
+ public:
+
+ inline explicit console_progress_indicator (
+ double target_value
+ );
+ /*!
+ ensures
+ - #target() == target_value
+ !*/
+
+ inline void reset (
+ double target_value
+ );
+ /*!
+ ensures
+ - #target() == target_value
+ - performs the equivalent of:
+ *this = console_progress_indicator(target_value)
+ (i.e. resets this object with a new target value)
+
+ !*/
+
+ inline double target (
+ ) const;
+ /*!
+ ensures
+ - This object attempts to measure how much time is
+ left until we reach a certain targeted value. This
+ function returns that targeted value.
+ !*/
+
+ inline bool print_status (
+ double cur,
+ bool always_print = false
+ );
+ /*!
+ ensures
+ - print_status() assumes it is called with values which are linearly
+ approaching target(). It will attempt to predict how much time is
+ remaining until cur becomes equal to target().
+ - prints a status message to the screen which indicates how much
+ more time is left until cur is equal to target()
+ - if (always_print) then
+ - This function prints to the screen each time it is called.
+ - else
+ - This function throttles the printing so that at most 1 message is
+ printed each second. Note that it won't print anything to the screen
+ until about one second has elapsed. This means that the first call
+ to print_status() never prints to the screen.
+ - This function returns true if it prints to the screen and false
+ otherwise.
+ !*/
+
+ private:
+
+ double target_val;
+
+ time_t start_time;
+ double first_val;
+ double seen_first_val;
+ time_t last_time;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// IMPLEMENTATION DETAILS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ console_progress_indicator::
+ console_progress_indicator (
+ double target_value
+ ) :
+ target_val(target_value),
+ start_time(0),
+ first_val(0),
+ seen_first_val(false),
+ last_time(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool console_progress_indicator::
+ print_status (
+ double cur,
+ bool always_print
+ )
+ {
+ const time_t cur_time = std::time(0);
+
+ // if this is the first time print_status has been called
+ // then collect some information and exit. We will print status
+ // on the next call.
+ if (!seen_first_val)
+ {
+ start_time = cur_time;
+ last_time = cur_time;
+ first_val = cur;
+ seen_first_val = true;
+ return false;
+ }
+
+ if (cur_time != last_time || always_print)
+ {
+ last_time = cur_time;
+ double delta_t = static_cast<double>(cur_time - start_time);
+ double delta_val = std::abs(cur - first_val);
+
+ // don't do anything if cur is equal to first_val
+ if (delta_val < std::numeric_limits<double>::epsilon())
+ return false;
+
+ double seconds = delta_t/delta_val * std::abs(target_val - cur);
+
+ std::ios::fmtflags oldflags = std::cout.flags();
+ std::cout.flags();
+ std::cout.setf(std::ios::fixed,std::ios::floatfield);
+ std::streamsize ss;
+
+ if (seconds < 60)
+ {
+ ss = std::cout.precision(0);
+ std::cout << "Time remaining: " << seconds << " seconds. \r" << std::flush;
+ }
+ else if (seconds < 60*60)
+ {
+ ss = std::cout.precision(2);
+ std::cout << "Time remaining: " << seconds/60 << " minutes. \r" << std::flush;
+ }
+ else
+ {
+ ss = std::cout.precision(2);
+ std::cout << "Time remaining: " << seconds/60/60 << " hours. \r" << std::flush;
+ }
+
+ // restore previous output flags and precision settings
+ std::cout.flags(oldflags);
+ std::cout.precision(ss);
+
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double console_progress_indicator::
+ target (
+ ) const
+ {
+ return target_val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void console_progress_indicator::
+ reset (
+ double target_value
+ )
+ {
+ *this = console_progress_indicator(target_value);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_
+
diff --git a/ml/dlib/dlib/control.h b/ml/dlib/dlib/control.h
new file mode 100644
index 000000000..85d00817d
--- /dev/null
+++ b/ml/dlib/dlib/control.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CONTRoL_
+#define DLIB_CONTRoL_
+
+#include "control/lspi.h"
+#include "control/mpc.h"
+
+#endif // DLIB_CONTRoL_
+
+
diff --git a/ml/dlib/dlib/control/approximate_linear_models.h b/ml/dlib/dlib/control/approximate_linear_models.h
new file mode 100644
index 000000000..9732d71e9
--- /dev/null
+++ b/ml/dlib/dlib/control/approximate_linear_models.h
@@ -0,0 +1,128 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+
+#include "approximate_linear_models_abstract.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ struct process_sample
+ {
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+ process_sample(){}
+
+ process_sample(
+ const state_type& s,
+ const action_type& a,
+ const state_type& n,
+ const double& r
+ ) : state(s), action(a), next_state(n), reward(r) {}
+
+ state_type state;
+ action_type action;
+ state_type next_state;
+ double reward;
+ };
+
+ template < typename feature_extractor >
+ void serialize (const process_sample<feature_extractor>& item, std::ostream& out)
+ {
+ serialize(item.state, out);
+ serialize(item.action, out);
+ serialize(item.next_state, out);
+ serialize(item.reward, out);
+ }
+
+ template < typename feature_extractor >
+ void deserialize (process_sample<feature_extractor>& item, std::istream& in)
+ {
+ deserialize(item.state, in);
+ deserialize(item.action, in);
+ deserialize(item.next_state, in);
+ deserialize(item.reward, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class policy
+ {
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+
+ policy (
+ )
+ {
+ w.set_size(fe.num_features());
+ w = 0;
+ }
+
+ policy (
+ const matrix<double,0,1>& weights_,
+ const feature_extractor& fe_
+ ) : w(weights_), fe(fe_) {}
+
+ action_type operator() (
+ const state_type& state
+ ) const
+ {
+ return fe.find_best_action(state,w);
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ const matrix<double,0,1>& get_weights (
+ ) const { return w; }
+
+
+ private:
+ matrix<double,0,1> w;
+ feature_extractor fe;
+ };
+
+ template < typename feature_extractor >
+ inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.get_feature_extractor(), out);
+ serialize(item.get_weights(), out);
+ }
+ template < typename feature_extractor >
+ inline void deserialize(policy<feature_extractor>& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
+ feature_extractor fe;
+ matrix<double,0,1> w;
+ deserialize(fe, in);
+ deserialize(w, in);
+ item = policy<feature_extractor>(w,fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+
diff --git a/ml/dlib/dlib/control/approximate_linear_models_abstract.h b/ml/dlib/dlib/control/approximate_linear_models_abstract.h
new file mode 100644
index 000000000..59dac4276
--- /dev/null
+++ b/ml/dlib/dlib/control/approximate_linear_models_abstract.h
@@ -0,0 +1,213 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
+#ifdef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct example_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a feature extractor must implement if it
+ is to be used with the process_sample and policy objects defined at the
+ bottom of this file. Moreover, it is meant to represent the core part
+ of a model used in a reinforcement learning algorithm.
+
+ In particular, this object models a Q(state,action) function where
+ Q(state,action) == dot(w, PSI(state,action))
+ where PSI(state,action) is a feature vector and w is a parameter
+ vector.
+
+ Therefore, a feature extractor defines how the PSI(x,y) feature vector is
+ calculated. It also defines the types used to represent the state and
+ action objects.
+
+
+ THREAD SAFETY
+ Instances of this object are required to be threadsafe, that is, it should
+ be safe for multiple threads to make concurrent calls to the member
+ functions of this object.
+ !*/
+
+ // The state and actions can be any types so long as you provide typedefs for them.
+ typedef T state_type;
+ typedef U action_type;
+ // We can also say that the last element in the weight vector w must be 1. This
+ // can be useful for including a prior into your model.
+ const static bool force_last_weight_to_1 = false;
+
+ example_feature_extractor(
+ );
+ /*!
+ ensures
+ - this object is properly initialized.
+ !*/
+
+ unsigned long num_features(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the PSI() feature vector.
+ !*/
+
+ action_type find_best_action (
+ const state_type& state,
+ const matrix<double,0,1>& w
+ ) const;
+ /*!
+ ensures
+ - returns the action A that maximizes Q(state,A) = dot(w,PSI(state,A)).
+ That is, this function finds the best action to take in the given state
+ when our model is parameterized by the given weight vector w.
+ !*/
+
+ void get_features (
+ const state_type& state,
+ const action_type& action,
+ matrix<double,0,1>& feats
+ ) const;
+ /*!
+ ensures
+ - #feats.size() == num_features()
+ - #feats == PSI(state,action)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ struct process_sample
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ feature_extractor should implement the example_feature_extractor interface
+ defined at the top of this file.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object holds a training sample for a reinforcement learning algorithm.
+ In particular, it should be a sample from some process where the process
+ was in state this->state, then took this->action action which resulted in
+ receiving this->reward and ending up in the state this->next_state.
+ !*/
+
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+ process_sample(){}
+
+ process_sample(
+ const state_type& s,
+ const action_type& a,
+ const state_type& n,
+ const double& r
+ ) : state(s), action(a), next_state(n), reward(r) {}
+
+ state_type state;
+ action_type action;
+ state_type next_state;
+ double reward;
+ };
+
+ template < typename feature_extractor >
+ void serialize (const process_sample<feature_extractor>& item, std::ostream& out);
+ template < typename feature_extractor >
+ void deserialize (process_sample<feature_extractor>& item, std::istream& in);
+ /*!
+ provides serialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class policy
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ feature_extractor should implement the example_feature_extractor interface
+ defined at the top of this file.
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a policy based on the supplied feature_extractor model. In
+ particular, it maps from feature_extractor::state_type to the best action
+ to take in that state.
+ !*/
+
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+
+ policy (
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights().size() == #get_feature_extractor().num_features()
+ - #get_weights() == 0
+ !*/
+
+ policy (
+ const matrix<double,0,1>& weights,
+ const feature_extractor& fe
+ );
+ /*!
+ requires
+ - fe.num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == fe
+ - #get_weights() == weights
+ !*/
+
+ action_type operator() (
+ const state_type& state
+ ) const;
+ /*!
+ ensures
+ - returns get_feature_extractor().find_best_action(state,w);
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ const matrix<double,0,1>& get_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the parameter vector (w) associated with this object. The length
+ of the vector is get_feature_extractor().num_features().
+ !*/
+
+ };
+
+ template < typename feature_extractor >
+ void serialize(const policy<feature_extractor>& item, std::ostream& out);
+ template < typename feature_extractor >
+ void deserialize(policy<feature_extractor>& item, std::istream& in);
+ /*!
+ provides serialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+#endif // DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/control/lspi.h b/ml/dlib/dlib/control/lspi.h
new file mode 100644
index 000000000..b21a501d2
--- /dev/null
+++ b/ml/dlib/dlib/control/lspi.h
@@ -0,0 +1,188 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LSPI_Hh_
+#define DLIB_LSPI_Hh_
+
+#include "lspi_abstract.h"
+#include "approximate_linear_models.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class lspi
+ {
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+ explicit lspi(
+ const feature_extractor& fe_
+ ) : fe(fe_)
+ {
+ init();
+ }
+
+ lspi(
+ )
+ {
+ init();
+ }
+
+ double get_discount (
+ ) const { return discount; }
+
+ void set_discount (
+ double value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < value && value <= 1,
+ "\t void lspi::set_discount(value)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t value: " << value
+ );
+ discount = value;
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void lspi::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_lambda (
+ double lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(lambda_ >= 0,
+ "\t void lspi::set_lambda(lambda_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t lambda_: " << lambda_
+ );
+ lambda = lambda_;
+ }
+
+ double get_lambda (
+ ) const
+ {
+ return lambda;
+ }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ ) { max_iterations = max_iter; }
+
+ unsigned long get_max_iterations (
+ ) { return max_iterations; }
+
+ template <typename vector_type>
+ policy<feature_extractor> train (
+ const vector_type& samples
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0,
+ "\t policy lspi::train(samples)"
+ << "\n\t invalid inputs were given to this function"
+ );
+
+ matrix<double,0,1> w(fe.num_features());
+ w = 0;
+ matrix<double,0,1> prev_w, b, f1, f2;
+
+ matrix<double> A;
+
+ double change;
+ unsigned long iter = 0;
+ do
+ {
+ A = identity_matrix<double>(fe.num_features())*lambda;
+ b = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ fe.get_features(samples[i].state, samples[i].action, f1);
+ fe.get_features(samples[i].next_state,
+ fe.find_best_action(samples[i].next_state,w),
+ f2);
+ A += f1*trans(f1 - discount*f2);
+ b += f1*samples[i].reward;
+ }
+
+ prev_w = w;
+ if (feature_extractor::force_last_weight_to_1)
+ w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0));
+ else
+ w = pinv(A)*b;
+
+ change = length(w-prev_w);
+ ++iter;
+
+ if (verbose)
+ std::cout << "iteration: " << iter << "\tchange: " << change << std::endl;
+
+ } while(change > eps && iter < max_iterations);
+
+ return policy<feature_extractor>(w,fe);
+ }
+
+
+ private:
+
+ void init()
+ {
+ lambda = 0.01;
+ discount = 0.8;
+ eps = 0.01;
+ verbose = false;
+ max_iterations = 100;
+ }
+
+ double lambda;
+ double discount;
+ double eps;
+ bool verbose;
+ unsigned long max_iterations;
+ feature_extractor fe;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LSPI_Hh_
+
diff --git a/ml/dlib/dlib/control/lspi_abstract.h b/ml/dlib/dlib/control/lspi_abstract.h
new file mode 100644
index 000000000..f262d16f4
--- /dev/null
+++ b/ml/dlib/dlib/control/lspi_abstract.h
@@ -0,0 +1,193 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LSPI_ABSTRACT_Hh_
+#ifdef DLIB_LSPI_ABSTRACT_Hh_
+
+#include "approximate_linear_models_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class lspi
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ feature_extractor should implement the example_feature_extractor interface
+ defined at the top of dlib/control/approximate_linear_models_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is an implementation of the reinforcement learning algorithm
+ described in the following paper:
+ Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy
+ iteration." The Journal of Machine Learning Research 4 (2003):
+ 1107-1149.
+
+ This means that it takes a bunch of training data in the form of
+ process_samples and outputs a policy that hopefully performs well when run
+ on the process that generated those samples.
+ !*/
+
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::state_type state_type;
+ typedef typename feature_extractor::action_type action_type;
+
+ explicit lspi(
+ const feature_extractor& fe_
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == fe_
+ - #get_lambda() == 0.01
+ - #get_discount == 0.8
+ - #get_epsilon() == 0.01
+ - is not verbose
+ - #get_max_iterations() == 100
+ !*/
+
+ lspi(
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_lambda() == 0.01
+ - #get_discount == 0.8
+ - #get_epsilon() == 0.01
+ - is not verbose
+ - #get_max_iterations() == 100
+ !*/
+
+ double get_discount (
+ ) const;
+ /*!
+ ensures
+ - returns the discount applied to the sum of rewards in the Bellman
+ equation.
+ !*/
+
+ void set_discount (
+ double value
+ );
+ /*!
+ requires
+ - 0 < value <= 1
+ ensures
+ - #get_discount() == value
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train.
+ !*/
+
+ void set_lambda (
+ double lambda_
+ );
+ /*!
+ requires
+ - lambda >= 0
+ ensures
+ - #get_lambda() == lambda
+ !*/
+
+ double get_lambda (
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization ability of the resulting function. Smaller values
+ encourage exact fitting while larger values of lambda may encourage
+ better generalization.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ template <
+ typename vector_type
+ >
+ policy<feature_extractor> train (
+ const vector_type& samples
+ ) const;
+ /*!
+ requires
+ - samples.size() > 0
+ - samples is something with an interface that looks like
+ std::vector<process_sample<feature_extractor>>. That is, it should
+ be some kind of array of process_sample objects.
+ ensures
+ - Trains a policy based on the given data and returns the results. The
+ idea is to find a policy that will obtain the largest possible reward
+ when run on the process that generated the samples. In particular,
+ if the returned policy is P then:
+ - P(S) == the best action to take when in state S.
+ - if (feature_extractor::force_last_weight_to_1) then
+ - The last element of P.get_weights() is 1.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LSPI_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/control/mpc.h b/ml/dlib/dlib/control/mpc.h
new file mode 100644
index 000000000..48ef2b72d
--- /dev/null
+++ b/ml/dlib/dlib/control/mpc.h
@@ -0,0 +1,370 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MPC_Hh_
+#define DLIB_MPC_Hh_
+
+#include "mpc_abstract.h"
+#include "../matrix.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+ template <
+ long S_,
+ long I_,
+ unsigned long horizon_
+ >
+ class mpc
+ {
+
+ public:
+
+ const static long S = S_;
+ const static long I = I_;
+ const static unsigned long horizon = horizon_;
+
+ mpc(
+ )
+ {
+ A = 0;
+ B = 0;
+ C = 0;
+ Q = 0;
+ R = 0;
+ lower = 0;
+ upper = 0;
+
+ max_iterations = 0;
+ eps = 0.01;
+ for (unsigned long i = 0; i < horizon; ++i)
+ {
+ target[i].set_size(A.nr());
+ target[i] = 0;
+
+ controls[i].set_size(B.nc());
+ controls[i] = 0;
+ }
+ lambda = 0;
+ }
+
+ mpc (
+ const matrix<double,S,S>& A_,
+ const matrix<double,S,I>& B_,
+ const matrix<double,S,1>& C_,
+ const matrix<double,S,1>& Q_,
+ const matrix<double,I,1>& R_,
+ const matrix<double,I,1>& lower_,
+ const matrix<double,I,1>& upper_
+ ) : A(A_), B(B_), C(C_), Q(Q_), R(R_), lower(lower_), upper(upper_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.nr() > 0 && B.nc() > 0,
+ "\t mpc::mpc()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t A.nr(): " << A.nr()
+ << "\n\t B.nc(): " << B.nc()
+ );
+
+ DLIB_ASSERT(A.nr() == A.nc() &&
+ A.nr() == B.nr() &&
+ A.nr() == C.nr() &&
+ A.nr() == Q.nr(),
+ "\t mpc::mpc()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t A.nr(): " << A.nr()
+ << "\n\t A.nc(): " << A.nc()
+ << "\n\t B.nr(): " << B.nr()
+ << "\n\t C.nr(): " << C.nr()
+ << "\n\t Q.nr(): " << Q.nr()
+ );
+ DLIB_ASSERT(
+ B.nc() == R.nr() &&
+ B.nc() == lower.nr() &&
+ B.nc() == upper.nr() ,
+ "\t mpc::mpc()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t B.nr(): " << B.nr()
+ << "\n\t B.nc(): " << B.nc()
+ << "\n\t lower.nr(): " << lower.nr()
+ << "\n\t upper.nr(): " << upper.nr()
+ );
+ DLIB_ASSERT(min(Q) >= 0 &&
+ min(R) > 0 &&
+ min(upper-lower) >= 0,
+ "\t mpc::mpc()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t min(Q): " << min(Q)
+ << "\n\t min(R): " << min(R)
+ << "\n\t min(upper-lower): " << min(upper-lower)
+ );
+
+
+ max_iterations = 10000;
+ eps = 0.01;
+ for (unsigned long i = 0; i < horizon; ++i)
+ {
+ target[i].set_size(A.nr());
+ target[i] = 0;
+
+ controls[i].set_size(B.nc());
+ controls[i] = 0;
+ }
+
+ // Bound the maximum eigenvalue of the hessian by computing the trace of the
+ // hessian matrix.
+ lambda = sum(R)*horizon;
+ matrix<double,S,S> temp = diagm(Q);
+ for (unsigned long c = 0; c < horizon; ++c)
+ {
+ lambda += trace(trans(B)*temp*B);
+ Q_diag[horizon-c-1] = diag(trans(B)*temp*B);
+ temp = trans(A)*temp*A + diagm(Q);
+ }
+
+ }
+
+ const matrix<double,S,S>& get_A (
+ ) const { return A; }
+ const matrix<double,S,I>& get_B (
+ ) const { return B; }
+ const matrix<double,S,1>& get_C (
+ ) const { return C; }
+ const matrix<double,S,1>& get_Q (
+ ) const { return Q; }
+ const matrix<double,I,1>& get_R (
+ ) const { return R; }
+ const matrix<double,I,1>& get_lower_constraints (
+ ) const { return lower; }
+ const matrix<double,I,1>& get_upper_constraints (
+ ) const { return upper; }
+
+ void set_target (
+ const matrix<double,S,1>& val,
+ const unsigned long time
+ )
+ {
+ DLIB_ASSERT(time < horizon,
+ "\t void mpc::set_target(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t time: " << time
+ << "\n\t horizon: " << horizon
+ );
+
+ target[time] = val;
+ }
+
+ void set_target (
+ const matrix<double,S,1>& val
+ )
+ {
+ for (unsigned long i = 0; i < horizon; ++i)
+ target[i] = val;
+ }
+
+ void set_last_target (
+ const matrix<double,S,1>& val
+ )
+ {
+ set_target(val, horizon-1);
+ }
+
+ const matrix<double,S,1>& get_target (
+ const unsigned long time
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(time < horizon,
+ "\t matrix mpc::get_target(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t time: " << time
+ << "\n\t horizon: " << horizon
+ );
+
+ return target[time];
+ }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void mpc::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ matrix<double,I,1> operator() (
+ const matrix<double,S,1>& current_state
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(min(R) > 0 && A.nr() == current_state.size(),
+ "\t matrix mpc::operator(current_state)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t min(R): " << min(R)
+ << "\n\t A.nr(): " << A.nr()
+ << "\n\t current_state.size(): " << current_state.size()
+ );
+
+ // Shift the inputs over by one time step so we can use them to warm start the
+ // optimizer.
+ for (unsigned long i = 1; i < horizon; ++i)
+ controls[i-1] = controls[i];
+
+ solve_linear_mpc(current_state);
+
+ for (unsigned long i = 1; i < horizon; ++i)
+ target[i-1] = target[i];
+
+ return controls[0];
+ }
+
+ private:
+
+
+ // These temporary variables here just to avoid reallocating them on each call to
+ // operator().
+ matrix<double,S,1> M[horizon];
+ matrix<double,I,1> MM[horizon];
+ matrix<double,I,1> df[horizon];
+ matrix<double,I,1> v[horizon];
+ matrix<double,I,1> v_old[horizon];
+
+ void solve_linear_mpc (
+ const matrix<double,S,1>& initial_state
+ )
+ {
+ // make it so MM == trans(K)*Q*(M-target)
+ M[0] = A*initial_state + C;
+ for (unsigned long i = 1; i < horizon; ++i)
+ M[i] = A*M[i-1] + C;
+ for (unsigned long i = 0; i < horizon; ++i)
+ M[i] = diagm(Q)*(M[i]-target[i]);
+ for (long i = (long)horizon-2; i >= 0; --i)
+ M[i] += trans(A)*M[i+1];
+ for (unsigned long i = 0; i < horizon; ++i)
+ MM[i] = trans(B)*M[i];
+
+
+
+ unsigned long iter = 0;
+ for (; iter < max_iterations; ++iter)
+ {
+ // compute current gradient and put it into df.
+ // df == H*controls + MM;
+ M[0] = B*controls[0];
+ for (unsigned long i = 1; i < horizon; ++i)
+ M[i] = A*M[i-1] + B*controls[i];
+ for (unsigned long i = 0; i < horizon; ++i)
+ M[i] = diagm(Q)*M[i];
+ for (long i = (long)horizon-2; i >= 0; --i)
+ M[i] += trans(A)*M[i+1];
+ for (unsigned long i = 0; i < horizon; ++i)
+ df[i] = MM[i] + trans(B)*M[i] + diagm(R)*controls[i];
+
+
+
+ // Check the stopping condition, which is the magnitude of the largest element
+ // of the gradient.
+ double max_df = 0;
+ unsigned long max_t = 0;
+ long max_v = 0;
+ for (unsigned long i = 0; i < horizon; ++i)
+ {
+ for (long j = 0; j < controls[i].size(); ++j)
+ {
+ // if this variable isn't an active constraint then we care about it's
+ // derivative.
+ if (!((controls[i](j) <= lower(j) && df[i](j) > 0) ||
+ (controls[i](j) >= upper(j) && df[i](j) < 0)))
+ {
+ if (std::abs(df[i](j)) > max_df)
+ {
+ max_df = std::abs(df[i](j));
+ max_t = i;
+ max_v = j;
+ }
+ }
+ }
+ }
+ if (max_df < eps)
+ break;
+
+
+
+ // We will start out by doing a little bit of coordinate descent because it
+ // allows us to optimize individual variables exactly. Since we are warm
+ // starting each iteration with a really good solution this helps speed
+ // things up a lot.
+ const unsigned long smo_iters = 50;
+ if (iter < smo_iters)
+ {
+ if (Q_diag[max_t](max_v) == 0) continue;
+
+ // Take the optimal step but just for one variable.
+ controls[max_t](max_v) = -(df[max_t](max_v)-Q_diag[max_t](max_v)*controls[max_t](max_v))/Q_diag[max_t](max_v);
+ controls[max_t](max_v) = put_in_range(lower(max_v), upper(max_v), controls[max_t](max_v));
+
+ // If this is the last SMO iteration then don't forget to initialize v
+ // for the gradient steps.
+ if (iter+1 == smo_iters)
+ {
+ for (unsigned long i = 0; i < horizon; ++i)
+ v[i] = controls[i];
+ }
+ }
+ else
+ {
+ // Take a projected gradient step.
+ for (unsigned long i = 0; i < horizon; ++i)
+ {
+ v_old[i] = v[i];
+ v[i] = dlib::clamp(controls[i] - 1.0/lambda * df[i], lower, upper);
+ controls[i] = dlib::clamp(v[i] + (std::sqrt(lambda)-1)/(std::sqrt(lambda)+1)*(v[i]-v_old[i]), lower, upper);
+ }
+ }
+ }
+ }
+
+ unsigned long max_iterations;
+ double eps;
+
+ matrix<double,S,S> A;
+ matrix<double,S,I> B;
+ matrix<double,S,1> C;
+ matrix<double,S,1> Q;
+ matrix<double,I,1> R;
+ matrix<double,I,1> lower;
+ matrix<double,I,1> upper;
+ matrix<double,S,1> target[horizon];
+
+ double lambda; // abound on the largest eigenvalue of the hessian matrix.
+ matrix<double,I,1> Q_diag[horizon];
+ matrix<double,I,1> controls[horizon];
+
+ };
+
+}
+
+#endif // DLIB_MPC_Hh_
+
diff --git a/ml/dlib/dlib/control/mpc_abstract.h b/ml/dlib/dlib/control/mpc_abstract.h
new file mode 100644
index 000000000..b4421c076
--- /dev/null
+++ b/ml/dlib/dlib/control/mpc_abstract.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MPC_ABSTRACT_Hh_
+#ifdef DLIB_MPC_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+ template <
+ long S_,
+ long I_,
+ unsigned long horizon_
+ >
+ class mpc
+ {
+ /*!
+ REQUIREMENTS ON horizon_
+ horizon_ > 0
+
+ REQUIREMENTS ON S_
+ S_ >= 0
+
+ REQUIREMENTS ON I_
+ I_ >= 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a linear model predictive controller. To explain
+ what that means, suppose you have some process you want to control and the
+ process dynamics are described by the linear equation:
+ x_{i+1} = A*x_i + B*u_i + C
+ That is, the next state the system goes into is a linear function of its
+ current state (x_i) and the current control (u_i) plus some constant bias
+ or disturbance.
+
+ A model predictive controller can find the control (u) you should apply to
+ drive the state (x) to some reference value, or alternatively to make the
+ state track some reference time-varying sequence. It does this by
+ simulating the process for horizon_ time steps and selecting the control
+ that leads to the best performance over the next horizon_ steps.
+
+ To be precise, each time you ask this object for a control, it solves the
+ following quadratic program:
+
+ min sum_i trans(x_i-target_i)*Q*(x_i-target_i) + trans(u_i)*R*u_i
+ x_i,u_i
+
+ such that: x_0 == current_state
+ x_{i+1} == A*x_i + B*u_i + C
+ lower <= u_i <= upper
+ 0 <= i < horizon_
+
+ and reports u_0 as the control you should take given that you are currently
+ in current_state. Q and R are user supplied matrices that define how we
+ penalize variations away from the target state as well as how much we want
+ to avoid generating large control signals.
+
+ Finally, the algorithm we use to solve this quadratic program is based
+ largely on the method described in:
+ A Fast Gradient method for embedded linear predictive control (2011)
+ by Markus Kogel and Rolf Findeisen
+ !*/
+
+ public:
+
+ const static long S = S_;
+ const static long I = I_;
+ const static unsigned long horizon = horizon_;
+
+ mpc(
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == 0
+ - The A,B,C,Q,R,lower, and upper parameter matrices are filled with zeros.
+ Therefore, to use this object you must initialize it via the constructor
+ that supplies these parameters.
+ !*/
+
+ mpc (
+ const matrix<double,S,S>& A,
+ const matrix<double,S,I>& B,
+ const matrix<double,S,1>& C,
+ const matrix<double,S,1>& Q,
+ const matrix<double,I,1>& R,
+ const matrix<double,I,1>& lower,
+ const matrix<double,I,1>& upper
+ );
+ /*!
+ requires
+ - A.nr() > 0
+ - B.nc() > 0
+ - A.nr() == A.nc() == B.nr() == C.nr() == Q.nr()
+ - B.nc() == R.nr() == lower.nr() == upper.nr()
+ - min(Q) >= 0
+ - min(R) > 0
+ - min(upper-lower) >= 0
+ ensures
+ - #get_A() == A
+ - #get_B() == B
+ - #get_C() == C
+ - #get_Q() == Q
+ - #get_R() == R
+ - #get_lower_constraints() == lower
+ - #get_upper_constraints() == upper
+ - for all valid i:
+ - get_target(i) == a vector of all zeros
+ - get_target(i).size() == A.nr()
+ - #get_max_iterations() == 10000
+ - #get_epsilon() == 0.01
+ !*/
+
+ const matrix<double,S,S>& get_A (
+ ) const;
+ /*!
+ ensures
+ - returns the A matrix from the quadratic program defined above.
+ !*/
+
+ const matrix<double,S,I>& get_B (
+ ) const;
+ /*!
+ ensures
+ - returns the B matrix from the quadratic program defined above.
+ !*/
+
+ const matrix<double,S,1>& get_C (
+ ) const;
+ /*!
+ ensures
+ - returns the C matrix from the quadratic program defined above.
+ !*/
+
+ const matrix<double,S,1>& get_Q (
+ ) const;
+ /*!
+ ensures
+ - returns the diagonal of the Q matrix from the quadratic program defined
+ above.
+ !*/
+
+ const matrix<double,I,1>& get_R (
+ ) const;
+ /*!
+ ensures
+ - returns the diagonal of the R matrix from the quadratic program defined
+ above.
+ !*/
+
+ const matrix<double,I,1>& get_lower_constraints (
+ ) const;
+ /*!
+ ensures
+ - returns the lower matrix from the quadratic program defined above. All
+ controls generated by this object will have values no less than this
+ lower bound. That is, any control u will satisfy min(u-lower) >= 0.
+ !*/
+
+ const matrix<double,I,1>& get_upper_constraints (
+ ) const;
+ /*!
+ ensures
+ - returns the upper matrix from the quadratic program defined above. All
+ controls generated by this object will have values no larger than this
+ upper bound. That is, any control u will satisfy min(upper-u) >= 0.
+ !*/
+
+ const matrix<double,S,1>& get_target (
+ const unsigned long time
+ ) const;
+ /*!
+ requires
+ - time < horizon
+ ensures
+ - This object will try to find the control sequence that results in the
+ process obtaining get_target(time) state at the indicated time. Note
+ that the next time instant after "right now" is time 0.
+ !*/
+
+ void set_target (
+ const matrix<double,S,1>& val,
+ const unsigned long time
+ );
+ /*!
+ requires
+ - time < horizon
+ ensures
+ - #get_target(time) == val
+ !*/
+
+ void set_target (
+ const matrix<double,S,1>& val
+ );
+ /*!
+ ensures
+ - for all valid t:
+ - #get_target(t) == val
+ !*/
+
+ void set_last_target (
+ const matrix<double,S,1>& val
+ );
+ /*!
+ ensures
+ - performs: set_target(val, horizon-1)
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - When operator() is called it solves an optimization problem to
+ get_epsilon() precision to determine the next control action. In
+ particular, we run the optimizer until the magnitude of each element of
+ the gradient vector is less than get_epsilon() or until
+ get_max_iterations() solver iterations have been executed.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - When operator() is called it solves an optimization problem to
+ get_epsilon() precision to determine the next control action. In
+ particular, we run the optimizer until the magnitude of each element of
+ the gradient vector is less than get_epsilon() or until
+ get_max_iterations() solver iterations have been executed. This means
+ that smaller epsilon values will give more accurate outputs but may take
+ longer to compute.
+ !*/
+
+ matrix<double,I,1> operator() (
+ const matrix<double,S,1>& current_state
+ );
+ /*!
+ requires
+ - min(R) > 0
+ - A.nr() == current_state.size()
+ ensures
+ - Solves the model predictive control problem defined by the arguments to
+ this objects constructor, assuming that the starting state is given by
+ current_state. Then we return the control that should be taken in the
+ current state that best optimizes the quadratic objective function
+ defined above.
+ - We also shift over the target states so that you only need to update the
+ last one (if you are using non-zero target states) via a call to
+ set_last_target()). In particular, for all valid t, it will be the case
+ that:
+ - #get_target(t) == get_target(t+1)
+ - #get_target(horizon-1) == get_target(horizon-1)
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_MPC_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/cpp_pretty_printer.h b/ml/dlib/dlib/cpp_pretty_printer.h
new file mode 100644
index 000000000..5315559ba
--- /dev/null
+++ b/ml/dlib/dlib/cpp_pretty_printer.h
@@ -0,0 +1,39 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_PRETTY_PRINTEr_
+#define DLIB_CPP_PRETTY_PRINTEr_
+
+
+#include "cpp_pretty_printer/cpp_pretty_printer_kernel_1.h"
+#include "cpp_pretty_printer/cpp_pretty_printer_kernel_2.h"
+#include "cpp_tokenizer.h"
+#include "stack.h"
+
+namespace dlib
+{
+
+ class cpp_pretty_printer
+ {
+ cpp_pretty_printer() {}
+
+
+ typedef stack<unsigned long>::kernel_1a stack;
+ typedef cpp_tokenizer::kernel_1a tok;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef cpp_pretty_printer_kernel_1<stack,tok>
+ kernel_1a;
+
+ // kernel_2a
+ typedef cpp_pretty_printer_kernel_2<stack,tok>
+ kernel_2a;
+
+ };
+}
+
+#endif // DLIB_CPP_PRETTY_PRINTEr_
+
diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h
new file mode 100644
index 000000000..668d5049d
--- /dev/null
+++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h
@@ -0,0 +1,583 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_1_
+#define DLIB_CPP_PRETTY_PRINTER_KERNEl_1_
+
+#include <string>
+#include <iostream>
+#include <sstream>
+#include "cpp_pretty_printer_kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename stack,
+ typename tok
+ >
+ class cpp_pretty_printer_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON stack
+ must be an implementation of stack/stack_kernel_abstract.h and
+ stack::type == unsigned long
+
+ REQUIREMENTS ON tok
+ must be an implementation of tokenizer/tokenizer_kernel_abstract.h
+
+ INFO
+ This implementation applies a color scheme, turns include directives
+ such as #include "file.h" into links to file.h.html, and it also puts
+ HTML anchor points on function and class declarations.
+ !*/
+
+ public:
+
+ cpp_pretty_printer_kernel_1 (
+ );
+
+ virtual ~cpp_pretty_printer_kernel_1 (
+ );
+
+ void print (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+
+ void print_and_number (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+
+ private:
+
+ const std::string htmlify (
+ const std::string& str
+ ) const;
+ /*!
+ ensures
+ - str == str but with any '<' replaced with '&lt;', any '>' replaced
+ with '&gt;', and any '&' replaced with '&amp;'
+ !*/
+
+ // data members
+ mutable tok t;
+
+ void number (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - prints in to out and adds line numbers
+ !*/
+
+ // restricted functions
+ cpp_pretty_printer_kernel_1(const cpp_pretty_printer_kernel_1&); // copy constructor
+ cpp_pretty_printer_kernel_1& operator=(const cpp_pretty_printer_kernel_1&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ cpp_pretty_printer_kernel_1<stack,tok>::
+ cpp_pretty_printer_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ cpp_pretty_printer_kernel_1<stack,tok>::
+ ~cpp_pretty_printer_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_1<stack,tok>::
+ print (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const
+ {
+ using namespace std;
+
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print");
+
+ t.set_stream(in);
+
+ out << "<html><!-- "
+ << "Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates."
+ << " --><head><title>" << title << "</title></head><body bgcolor='white'><pre>\n";
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print");
+
+ unsigned long scope = 0; // counts the number of new scopes we have entered
+ // since we were at a scope where functions can be declared
+
+ bool recently_seen_class_keyword = false;
+ // true if we have seen the keywords class, struct, or enum and
+ // we have not seen any identifiers or { characters
+
+ bool recently_seen_include = false;
+ // true if we have seen the #include keyword and have not seen double
+ // quoted text or >
+
+ bool recently_seen_new_scope = false;
+ // true if we have seen the keywords class, namespace, or struct and
+ // we have not seen the characters {, ), or ; since then
+
+ bool recently_seen_paren = false;
+ // true if we have seen a ) and we have only seen white_space or comments since
+
+ bool in_initialization_list = false;
+ // true if we have seen a ) followed by any white space or comments and then
+ // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we
+ // have not yet seen the character { or ;
+
+ bool recently_seen_preprocessor = false;
+ // true if we have seen the #pragma or #if or #define or #elif keywords and have
+ // not seen an end of line.
+
+ bool recently_seen_extern = false;
+ // true if we have seen the extern keyword and haven't seen a ; or { yet.
+
+ unsigned long paren_count = 0;
+ // this is the number of ( we have seen minus the number of ) we have
+ // seen.
+
+
+ int type;
+ stack scopes; // a stack to hold old scopes
+ string token, temp;
+ t.get_token(type,token);
+ while (type != tok::END_OF_FILE)
+ {
+ switch (type)
+ {
+ case tok::IDENTIFIER: // ------------------------------------------
+ if ( recently_seen_class_keyword)
+ {
+ // this might be a class name so check if there is a
+ // ; or identifier or * or & coming up.
+ type = t.peek_type();
+ temp.clear();
+ if (type == tok::WHITE_SPACE)
+ {
+ t.get_token(type,temp);
+ if (temp.find_first_of("\n\r") != string::npos)
+ recently_seen_preprocessor = false;
+ }
+ if (t.peek_token() != ";" && t.peek_type() != tok::IDENTIFIER &&
+ t.peek_token() != "*" && t.peek_token() != "&")
+ {
+ // this is the name of a class or struct in a class or
+ // struct declaration.
+ out << "<b><a name='" << token << "'></a>" << token << "</b>" << temp;
+ }
+ else
+ {
+ out << token << temp;
+ }
+ }
+ else if ( !in_initialization_list &&
+ !recently_seen_preprocessor )
+ {
+ // this might be a function name so check if there is a
+ // ( coming up.
+ type = t.peek_type();
+ temp.clear();
+ if (type == tok::WHITE_SPACE)
+ {
+ t.get_token(type,temp);
+ type = t.peek_type();
+ }
+ if (type == tok::OTHER && t.peek_token() == "(")
+ {
+ if (scope == 0 && paren_count == 0)
+ {
+ // this is a function definition or prototype
+ out << "<b><a name='" << token << "'></a>" << token << "</b>" << temp;
+ }
+ else
+ {
+ // this is a function call (probably)
+ out << "<font color='#BB00BB'>" << token << "</font>" << temp;
+ }
+ }
+ else
+ {
+ out << token << temp;
+ }
+ }
+ else
+ {
+ out << token;
+ }
+
+
+
+ recently_seen_class_keyword = false;
+ recently_seen_paren = false;
+ break;
+
+ case tok::KEYWORD: // ---------------------------------------------
+ if (scope == 0 && token == "operator")
+ {
+ // Doing this is sort of weird since operator is really a keyword
+ // but I just like how this looks.
+ out << "<b><a name='" << token << "'></a>" << token << "</b>";
+ }
+ // this isn't a keyword if it is something like #include <new>
+ else if ( token == "true" || token == "false")
+ {
+ // color 'true' and 'false' the same way we color numbers
+ out << "<font color='#979000'>" << token << "</font>";
+ }
+ else if (!recently_seen_include)
+ {
+ // This is a normal keyword
+ if (token == "char" || token == "unsigned" || token == "signed" ||
+ token == "short" || token == "int" || token == "long" ||
+ token == "float" || token == "double" || token == "bool" ||
+ token == "void" || token == "size_t" || token == "wchar_t")
+ {
+ out << "<font color='#0000FF'><u>" << token << "</u></font>";
+ }
+ else
+ {
+ out << "<font color='#0000FF'>" << token << "</font>";
+ }
+ }
+ else
+ {
+ out << token;
+ }
+
+ if (token == "#include")
+ {
+ recently_seen_include = true;
+ }
+ else if (token == "class")
+ {
+ recently_seen_new_scope = true;
+ recently_seen_class_keyword = true;
+ }
+ else if (token == "namespace")
+ {
+ recently_seen_new_scope = true;
+ }
+ else if (token == "enum")
+ {
+ recently_seen_class_keyword = true;
+ }
+ else if (token == "struct")
+ {
+ recently_seen_new_scope = true;
+ recently_seen_class_keyword = true;
+ }
+ else if (token == "#pragma" || token == "#if" || token == "#define" || token == "#elif")
+ {
+ recently_seen_preprocessor = true;
+ }
+ else if (token == "extern")
+ {
+ recently_seen_extern = true;
+ }
+ recently_seen_paren = false;
+ break;
+
+ case tok::COMMENT: // ---------------------------------------------
+ {
+ // if this is a special anchor comment
+ if (token.size() > 4 &&
+ token[0] == '/' &&
+ token[1] == '*' &&
+ token[2] == '!' &&
+ token[3] == 'A' &&
+ token[4] == ' '
+ )
+ {
+ temp = token;
+ istringstream sin(token);
+ sin >> temp;
+ sin >> temp;
+ sin.get();
+ // if there was still more stuff in the token then we are ok.
+ if (sin)
+ out << "<a name='" << temp << "'/>";
+ }
+ out << "<font color='#009900'>" << htmlify(token) << "</font>";
+ }
+ break;
+
+ case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
+ {
+ out << "<font color='#FF0000'>" << htmlify(token) << "</font>";
+ recently_seen_paren = false;
+ }
+ break;
+
+ case tok::NUMBER: // -----------------------------------------
+ {
+ out << "<font color='#979000'>" << token << "</font>";
+ recently_seen_include = false;
+ }
+ break;
+
+ case tok::WHITE_SPACE: // -----------------------------------------
+ {
+ out << token;
+ if (token.find_first_of("\n\r") != string::npos)
+ recently_seen_preprocessor = false;
+ }
+ break;
+
+ case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
+ {
+ if (recently_seen_include)
+ {
+ // this is the name of an included file
+ recently_seen_include = false;
+ out << "<a style='text-decoration:none' href='" << htmlify(token) << ".html'>" << htmlify(token) << "</a>";
+ }
+ else
+ {
+ // this is just a normal quoted string
+ out << "<font color='#CC0000'>" << htmlify(token) << "</font>";
+ }
+ recently_seen_paren = false;
+ }
+ break;
+
+ case tok::OTHER: // -----------------------------------------------
+ switch (token[0])
+ {
+ case '{':
+ out << "<b>{</b>";
+ // if we are entering a new scope
+ if (recently_seen_new_scope || recently_seen_extern)
+ {
+ recently_seen_new_scope = false;
+ scopes.push(scope);
+ scope = 0;
+ }
+ else
+ {
+ ++scope;
+ }
+ in_initialization_list = false;
+ recently_seen_paren = false;
+ recently_seen_class_keyword = false;
+ recently_seen_extern = false;
+ break;
+ case '}':
+ out << "<b>}</b>";
+ if (scope > 0)
+ {
+ --scope;
+ }
+ else if (scopes.size())
+ {
+ scopes.pop(scope);
+ }
+ recently_seen_paren = false;
+ break;
+
+ case ':':
+ out << ':';
+ if (recently_seen_paren && scope == 0 &&
+ recently_seen_preprocessor == false)
+ {
+ in_initialization_list = true;
+ }
+ recently_seen_paren = false;
+ break;
+
+ case ';':
+ out << ';';
+ recently_seen_new_scope = false;
+ recently_seen_paren = false;
+ recently_seen_extern = false;
+ break;
+
+ case ')':
+ out << "<font face='Lucida Console'>)</font>";
+ recently_seen_paren = true;
+ recently_seen_new_scope = false;
+ --paren_count;
+ break;
+
+ case '(':
+ out << "<font face='Lucida Console'>(</font>";
+ recently_seen_paren = false;
+ ++paren_count;
+ break;
+
+ case '>':
+ recently_seen_include = false;
+ out << "<font color='#5555FF'>&gt;</font>";
+ recently_seen_paren = false;
+ break;
+
+ case '<':
+ out << "<font color='#5555FF'>&lt;</font>";
+ recently_seen_paren = false;
+ break;
+
+ case '&':
+ out << "<font color='#5555FF'>&amp;</font>";
+ recently_seen_paren = false;
+ break;
+
+ case '=':
+ case '+':
+ case '-':
+ case '/':
+ case '*':
+ case '!':
+ case '|':
+ case '%':
+ out << "<font color='#5555FF'>" << token << "</font>";
+ recently_seen_paren = false;
+ break;
+
+ default:
+ out << token;
+ recently_seen_paren = false;
+ break;
+
+ } // switch (token[0])
+ break;
+
+ } // switch (type)
+
+ t.get_token(type,token);
+ } // while (type != tok::END_OF_FILE)
+
+
+ out << "\n</pre></body></html>";
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_1<stack,tok>::
+ print_and_number (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const
+ {
+ using namespace std;
+ ostringstream sout;
+ print(in,sout,title);
+ istringstream sin(sout.str());
+ number(sin,out);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_1<stack,tok>::
+ number (
+ std::istream& in,
+ std::ostream& out
+ ) const
+ {
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number");
+
+ std::string space = "&nbsp;&nbsp;&nbsp;";
+ std::ios::int_type ch;
+ unsigned long count = 1;
+ while ((ch=in.get()) != EOF)
+ {
+ if (ch != '\n')
+ {
+ out << (char)ch;
+ }
+ else
+ {
+ out << "\n<font color='555555'>" << count << " </font> " + space;
+ ++count;
+ if (count == 10)
+ space = "&nbsp;&nbsp;";
+ if (count == 100)
+ space = "&nbsp;";
+ if (count == 1000)
+ space = "";
+ }
+ }
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ const std::string cpp_pretty_printer_kernel_1<stack,tok>::
+ htmlify (
+ const std::string& str
+ ) const
+ {
+ std::string::size_type i;
+ std::string temp;
+ for (i = 0; i < str.size(); ++i)
+ {
+ if (str[i] == '<')
+ temp += "&lt;";
+ else if (str[i] == '>')
+ temp += "&gt;";
+ else if (str[i] == '&')
+ temp += "&amp;";
+ else
+ temp += str[i];
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h
new file mode 100644
index 000000000..5ac894b33
--- /dev/null
+++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h
@@ -0,0 +1,520 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_2_
+#define DLIB_CPP_PRETTY_PRINTER_KERNEl_2_
+
+#include <string>
+#include <iostream>
+#include <sstream>
+#include "cpp_pretty_printer_kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename stack,
+ typename tok
+ >
+ class cpp_pretty_printer_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON stack
+ must be an implementation of stack/stack_kernel_abstract.h and
+ stack::type == unsigned long
+
+ REQUIREMENTS ON tok
+ must be an implementation of tokenizer/tokenizer_kernel_abstract.h
+
+ INFO
+ This implementation applies a black and white color scheme suitable
+ for printing on a black and white printer. It also places the document
+ title prominently at the top of the pretty printed source file.
+ !*/
+
+ public:
+
+ cpp_pretty_printer_kernel_2 (
+ );
+
+ virtual ~cpp_pretty_printer_kernel_2 (
+ );
+
+ void print (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+
+ void print_and_number (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+
+ private:
+
+ // data members
+ mutable tok t;
+
+ const std::string htmlify (
+ const std::string& str
+ ) const;
+ /*!
+ ensures
+ - str == str but with any '<' replaced with '&lt;', any '>' replaced
+ with '&gt;', and any '&' replaced with '&amp;'
+ !*/
+
+ void number (
+ std::istream& in,
+ std::ostream& out
+ ) const;
+ /*!
+ ensures
+ - prints in to out and adds line numbers
+ !*/
+
+ // restricted functions
+ cpp_pretty_printer_kernel_2(const cpp_pretty_printer_kernel_2&); // copy constructor
+ cpp_pretty_printer_kernel_2& operator=(const cpp_pretty_printer_kernel_2&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ cpp_pretty_printer_kernel_2<stack,tok>::
+ cpp_pretty_printer_kernel_2 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ cpp_pretty_printer_kernel_2<stack,tok>::
+ ~cpp_pretty_printer_kernel_2 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_2<stack,tok>::
+ print (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const
+ {
+ using namespace std;
+
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print");
+
+ t.set_stream(in);
+
+ out << "<html><!-- "
+ << "Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates."
+ << " --><head>"
+ << "<title>" << title << "</title></head><body bgcolor='white'>"
+ << "<h1><center>" << title << "</center></h1><pre>\n"
+ << "<font style='font-size:9pt' face='Lucida Console'>\n";
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print");
+
+ unsigned long scope = 0; // counts the number of new scopes we have entered
+ // since we were at a scope where functions can be declared
+
+ bool recently_seen_class_keyword = false;
+ // true if we have seen the keywords class or struct and
+ // we have not seen any identifiers or { characters
+
+ bool recently_seen_include = false;
+ // true if we have seen the #include keyword and have not seen double
+ // quoted text or >
+
+ bool recently_seen_new_scope = false;
+ // true if we have seen the keywords class, namespace, or struct and
+ // we have not seen the characters {, ), or ; since then
+
+ bool recently_seen_paren = false;
+ // true if we have seen a ) and we have only seen white_space or comments since
+
+ bool in_initialization_list = false;
+ // true if we have seen a ) followed by any white space or comments and then
+ // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we
+ // have not yet seen the character { or ;
+
+ bool recently_seen_preprocessor = false;
+ // true if we have seen the #pragma or #if or #define or #elif keyword and
+ // have not seen an identifier.
+
+
+ bool recently_seen_extern = false;
+ // true if we have seen the extern keyword and haven't yet seen a
+ // { or ; character.
+
+ unsigned long paren_count = 0;
+ // this is the number of ( we have seen minus the number of ) we have
+ // seen.
+
+
+ int type;
+ stack scopes; // a stack to hold old scopes
+ string token, temp;
+ t.get_token(type,token);
+ while (type != tok::END_OF_FILE)
+ {
+ switch (type)
+ {
+ case tok::IDENTIFIER: // ------------------------------------------
+ if ( recently_seen_class_keyword)
+ {
+ // this might be a class name so check if there is a
+ // ; or identifier or * or &amp; coming up.
+ type = t.peek_type();
+ temp.clear();
+ if (type == tok::WHITE_SPACE)
+ {
+ t.get_token(type,temp);
+ if (temp.find_first_of("\n\r") != string::npos)
+ recently_seen_preprocessor = false;
+ }
+ if (t.peek_token() != ";" && t.peek_type() != tok::IDENTIFIER &&
+ t.peek_token() != "*" && t.peek_token() != "&amp;")
+ {
+ // this is the name of a class or struct in a class or
+ // struct declaration.
+ out << "<b><i>" << token << "</i></b>" << temp;
+ }
+ else
+ {
+ out << token << temp;
+ }
+ }
+ else if ( !in_initialization_list &&
+ !recently_seen_preprocessor &&
+ scope == 0 &&
+ paren_count == 0)
+ {
+ // this might be a function name so check if there is a
+ // ( coming up.
+ type = t.peek_type();
+ temp.clear();
+ if (type == tok::WHITE_SPACE)
+ {
+ t.get_token(type,temp);
+ type = t.peek_type();
+ }
+ if (type == tok::OTHER && t.peek_token() == "(")
+ {
+ // this is a function definition or prototype
+ out << "<b><i>" << token << "</i></b>" << temp;
+ }
+ else
+ {
+ out << token << temp;
+ }
+ }
+ else
+ {
+ out << token;
+ }
+
+
+
+ recently_seen_class_keyword = false;
+ recently_seen_paren = false;
+ break;
+
+ case tok::KEYWORD: // ---------------------------------------------
+ if (scope == 0 && token == "operator")
+ {
+ // Doing this is sort of weird since operator is really a keyword
+ // but I just like how this looks.
+ out << "<b><i>" << token << "</i></b>";
+ }
+ // this isn't a keyword if it is something like #include <new>
+ else if (!recently_seen_include)
+ {
+ // This is a normal keyword
+ out << "<u><font face='Fixedsys'>" << token << "</font></u>";
+ }
+ else
+ {
+ out << token;
+ }
+
+ if (token == "#include")
+ {
+ recently_seen_include = true;
+ }
+ else if (token == "class")
+ {
+ recently_seen_new_scope = true;
+ recently_seen_class_keyword = true;
+ }
+ else if (token == "namespace")
+ {
+ recently_seen_new_scope = true;
+ }
+ else if (token == "struct")
+ {
+ recently_seen_new_scope = true;
+ recently_seen_class_keyword = true;
+ }
+ else if (token == "#pragma" || token == "#define" || token == "#elif" || token == "#if")
+ {
+ recently_seen_preprocessor = true;
+ }
+ else if (token == "extern")
+ {
+ recently_seen_extern = true;
+ }
+ recently_seen_paren = false;
+ break;
+
+ case tok::COMMENT: // ---------------------------------------------
+ {
+ out << "<font face='Courier New'>" << htmlify(token) << "</font>";
+ }
+ break;
+
+ case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
+ {
+ out << htmlify(token);
+ recently_seen_paren = false;
+ }
+ break;
+
+ case tok::WHITE_SPACE: // -----------------------------------------
+ {
+ out << token;
+ if (token.find_first_of("\n\r") != string::npos)
+ recently_seen_preprocessor = false;
+ }
+ break;
+
+ case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
+ {
+ out << htmlify(token);
+ recently_seen_paren = false;
+ recently_seen_include = false;
+ }
+ break;
+
+ case tok::NUMBER:
+ case tok::OTHER: // -----------------------------------------------
+ switch (token[0])
+ {
+ case '{':
+ out << "<b>{</b>";
+ // if we are entering a new scope
+ if (recently_seen_new_scope || recently_seen_extern)
+ {
+ recently_seen_new_scope = false;
+ scopes.push(scope);
+ scope = 0;
+ }
+ else
+ {
+ ++scope;
+ }
+ in_initialization_list = false;
+ recently_seen_paren = false;
+ recently_seen_class_keyword = false;
+ recently_seen_extern = false;
+ break;
+ case '}':
+ out << "<b>}</b>";
+ if (scope > 0)
+ {
+ --scope;
+ }
+ else if (scopes.size())
+ {
+ scopes.pop(scope);
+ }
+ recently_seen_paren = false;
+ break;
+
+ case ':':
+ out << ':';
+ if (recently_seen_paren && scope == 0 &&
+ recently_seen_preprocessor == false)
+ {
+ in_initialization_list = true;
+ }
+ recently_seen_paren = false;
+ break;
+
+ case ';':
+ out << ';';
+ recently_seen_new_scope = false;
+ recently_seen_paren = false;
+ recently_seen_extern = false;
+ break;
+
+ case ')':
+ out << ')';
+ recently_seen_paren = true;
+ recently_seen_new_scope = false;
+ --paren_count;
+ break;
+
+ case '(':
+ out << '(';
+ recently_seen_paren = false;
+ ++paren_count;
+ break;
+
+ case '>':
+ recently_seen_include = false;
+ out << "&gt;";
+ recently_seen_paren = true;
+ break;
+
+ case '<':
+ out << "&lt;";
+ recently_seen_paren = true;
+ break;
+
+ case '&':
+ out << "&amp;";
+ recently_seen_paren = true;
+ break;
+
+ default:
+ out << token;
+ recently_seen_paren = false;
+ if (token == "&gt;")
+ recently_seen_include = false;
+ break;
+
+ } // switch (token[0])
+ break;
+
+ } // switch (type)
+
+ t.get_token(type,token);
+ } // while (type != tok::END_OF_FILE)
+
+
+ out << "</font></pre></body></html>";
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_2<stack,tok>::
+ print_and_number (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const
+ {
+ using namespace std;
+ ostringstream sout;
+ print(in,sout,title);
+ istringstream sin(sout.str());
+ number(sin,out);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ void cpp_pretty_printer_kernel_2<stack,tok>::
+ number (
+ std::istream& in,
+ std::ostream& out
+ ) const
+ {
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number");
+
+ std::string space = "&nbsp;&nbsp;&nbsp;";
+ std::ios::int_type ch;
+ unsigned long count = 1;
+ while ((ch=in.get()) != EOF)
+ {
+ if (ch != '\n')
+ {
+ out << (char)ch;
+ }
+ else
+ {
+ out << "\n<i><font face='Courier New'>" << count << " </font></i> " + space;
+ ++count;
+ if (count == 10)
+ space = "&nbsp;&nbsp;";
+ if (count == 100)
+ space = "&nbsp;";
+ if (count == 1000)
+ space = "";
+ }
+ }
+ if (!out)
+ throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack,
+ typename tok
+ >
+ const std::string cpp_pretty_printer_kernel_2<stack,tok>::
+ htmlify (
+ const std::string& str
+ ) const
+ {
+ std::string::size_type i;
+ std::string temp;
+ for (i = 0; i < str.size(); ++i)
+ {
+ if (str[i] == '<')
+ temp += "&lt;";
+ else if (str[i] == '>')
+ temp += "&gt;";
+ else if (str[i] == '&')
+ temp += "&amp;";
+ else
+ temp += str[i];
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h
new file mode 100644
index 000000000..7d572d4fd
--- /dev/null
+++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h
@@ -0,0 +1,88 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_
+#ifdef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_
+
+#include <string>
+#include <ioswfd>
+
+namespace dlib
+{
+
+ class cpp_pretty_printer
+ {
+ /*!
+ INITIAL VALUE
+ This object does not have any state associated with it.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an HTML pretty printer for C++ source code.
+
+ !*/
+
+ public:
+
+ cpp_pretty_printer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~cpp_pretty_printer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void print (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+ /*!
+ ensures
+ - treats data from in as C++ source code and pretty prints it in
+ HTML and writes it to out.
+ - The title of the HTML document writen to out will be title
+ throws
+ - std::ios_base::failure
+ If there was a problem writing to out then this exception will
+ be thrown.
+ - any other exception
+ This exception may be thrown if there is any other problem.
+ !*/
+
+ void print_and_number (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& title
+ ) const;
+ /*!
+ ensures
+ - treats data from in as C++ source code and pretty prints it in
+ HTML with line numbers and writes it to out.
+ - The title of the HTML document writen to out will be title
+ throws
+ - std::ios_base::failure
+ If there was a problem writing to out then this exception will
+ be thrown.
+ - any other exception
+ This exception may be thrown if there is any other problem.
+ !*/
+
+ private:
+
+ // restricted functions
+ cpp_pretty_printer(const cpp_pretty_printer&); // copy constructor
+ cpp_pretty_printer& operator=(const cpp_pretty_printer&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/cpp_tokenizer.h b/ml/dlib/dlib/cpp_tokenizer.h
new file mode 100644
index 000000000..676ad7a52
--- /dev/null
+++ b/ml/dlib/dlib/cpp_tokenizer.h
@@ -0,0 +1,40 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_TOKENIZEr_
+#define DLIB_CPP_TOKENIZEr_
+
+#include <string>
+#include "cpp_tokenizer/cpp_tokenizer_kernel_1.h"
+#include "cpp_tokenizer/cpp_tokenizer_kernel_c.h"
+#include "tokenizer.h"
+#include "queue.h"
+#include "set.h"
+
+namespace dlib
+{
+
+ class cpp_tokenizer
+ {
+ cpp_tokenizer() {}
+
+
+ typedef set<std::string>::kernel_1a set;
+ typedef queue<cpp_tok_kernel_1_helper::token_text_pair>::kernel_2a queue;
+ typedef tokenizer::kernel_1a tok;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef cpp_tokenizer_kernel_1<tok,queue,set>
+ kernel_1a;
+ typedef cpp_tokenizer_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ };
+}
+
+#endif // DLIB_CPP_TOKENIZEr_
+
diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h
new file mode 100644
index 000000000..8a244faa7
--- /dev/null
+++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h
@@ -0,0 +1,675 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_TOKENIZER_KERNEl_1_
+#define DLIB_CPP_TOKENIZER_KERNEl_1_
+
+#include <string>
+#include <iostream>
+#include "cpp_tokenizer_kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ namespace cpp_tok_kernel_1_helper
+ {
+ struct token_text_pair
+ {
+ std::string token;
+ int type;
+ };
+
+ }
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ class cpp_tokenizer_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON tok
+ tok must be an implementation of tokenizer/tokenizer_kernel_abstract.h
+
+ REQUIREMENTS ON queue
+ queue must be an implementation of queue/queue_kernel_abstract.h
+ and must have T==cpp_tok_kernel_1_helper::token_text_pair
+
+ REQUIREMENTS ON set
+ set must be an implemention of set/set_kernel_abstract.h or
+ hash_set/hash_set_kernel_abstract.h and must have T==std::string.
+
+ INITIAL VALUE
+ - keywords == a set of all the C++ keywords
+ - tokenizer.stream_is_set() == false
+ - buffer.size() == 0
+ - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() +
+ tokenizer.uppercase_letters()
+ - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() +
+ tokenizer.uppercase_letters() + tokenizer.numbers()
+ - have_peeked == false
+
+
+ CONVENTION
+ - tokenizer.stream_is_set() == stream_is_set()
+ - tokenizer.get_stream() == get_stream()
+ - keywords == a set of all the C++ keywords
+
+ - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() +
+ tokenizer.uppercase_letters()
+ - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() +
+ tokenizer.uppercase_letters() + tokenizer.numbers()
+
+ - buffer == a queue of tokens. This is where we put tokens
+ we gathered early due to looking ahead.
+
+
+ - if (have_peeked) then
+ - next_token == the next token to be returned from get_token()
+ - next_type == the type of token in peek_token
+ !*/
+
+ typedef cpp_tok_kernel_1_helper::token_text_pair token_text_pair;
+
+ public:
+
+ enum
+ {
+ END_OF_FILE,
+ KEYWORD,
+ COMMENT,
+ SINGLE_QUOTED_TEXT,
+ DOUBLE_QUOTED_TEXT,
+ IDENTIFIER,
+ OTHER,
+ NUMBER,
+ WHITE_SPACE
+ };
+
+ cpp_tokenizer_kernel_1 (
+ );
+
+ virtual ~cpp_tokenizer_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::istream& in
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::istream& get_stream (
+ ) const;
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+
+ int peek_type (
+ ) const;
+
+ const std::string& peek_token (
+ ) const;
+
+ void swap (
+ cpp_tokenizer_kernel_1<tok,queue,set>& item
+ );
+
+ private:
+
+ void buffer_token(
+ int type,
+ const std::string& token
+ )
+ /*!
+ ensures
+ - stores the token and its type into buffer
+ !*/
+ {
+ token_text_pair temp;
+ temp.token = token;
+ temp.type = type;
+ buffer.enqueue(temp);
+ }
+
+ void buffer_token(
+ int type,
+ char token
+ )
+ /*!
+ ensures
+ - stores the token and its type into buffer
+ !*/
+ {
+ token_text_pair temp;
+ temp.token = token;
+ temp.type = type;
+ buffer.enqueue(temp);
+ }
+
+ // restricted functions
+ cpp_tokenizer_kernel_1(const cpp_tokenizer_kernel_1<tok,queue,set>&); // copy constructor
+ cpp_tokenizer_kernel_1<tok,queue,set>& operator=(const cpp_tokenizer_kernel_1<tok,queue,set>&); // assignment operator
+
+ // data members
+ set keywords;
+ queue buffer;
+ tok tokenizer;
+
+ mutable std::string next_token;
+ mutable int next_type;
+ mutable bool have_peeked;
+
+
+ };
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ inline void swap (
+ cpp_tokenizer_kernel_1<tok,queue,set>& a,
+ cpp_tokenizer_kernel_1<tok,queue,set>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ cpp_tokenizer_kernel_1<tok,queue,set>::
+ cpp_tokenizer_kernel_1(
+ ) :
+ have_peeked(false)
+ {
+ // add C++ keywords to keywords
+ std::string temp;
+ temp = "#include"; keywords.add(temp);
+ temp = "__asm"; keywords.add(temp);
+ temp = "_asm"; keywords.add(temp);
+ temp = "if"; keywords.add(temp);
+ temp = "int"; keywords.add(temp);
+ temp = "else"; keywords.add(temp);
+ temp = "template"; keywords.add(temp);
+ temp = "void"; keywords.add(temp);
+ temp = "false"; keywords.add(temp);
+ temp = "class"; keywords.add(temp);
+ temp = "public"; keywords.add(temp);
+ temp = "while"; keywords.add(temp);
+ temp = "bool"; keywords.add(temp);
+ temp = "new"; keywords.add(temp);
+ temp = "delete"; keywords.add(temp);
+ temp = "true"; keywords.add(temp);
+ temp = "typedef"; keywords.add(temp);
+ temp = "const"; keywords.add(temp);
+ temp = "virtual"; keywords.add(temp);
+ temp = "inline"; keywords.add(temp);
+ temp = "for"; keywords.add(temp);
+ temp = "break"; keywords.add(temp);
+ temp = "struct"; keywords.add(temp);
+ temp = "float"; keywords.add(temp);
+ temp = "case"; keywords.add(temp);
+ temp = "enum"; keywords.add(temp);
+ temp = "this"; keywords.add(temp);
+ temp = "typeid"; keywords.add(temp);
+ temp = "double"; keywords.add(temp);
+ temp = "char"; keywords.add(temp);
+ temp = "typename"; keywords.add(temp);
+ temp = "signed"; keywords.add(temp);
+ temp = "friend"; keywords.add(temp);
+ temp = "wint_t"; keywords.add(temp);
+ temp = "default"; keywords.add(temp);
+ temp = "asm"; keywords.add(temp);
+ temp = "reinterpret_cast"; keywords.add(temp);
+ temp = "#define"; keywords.add(temp);
+ temp = "do"; keywords.add(temp);
+ temp = "continue"; keywords.add(temp);
+ temp = "auto"; keywords.add(temp);
+ temp = "unsigned"; keywords.add(temp);
+ temp = "size_t"; keywords.add(temp);
+ temp = "#undef"; keywords.add(temp);
+ temp = "#pragma"; keywords.add(temp);
+ temp = "namespace"; keywords.add(temp);
+ temp = "private"; keywords.add(temp);
+ temp = "#endif"; keywords.add(temp);
+ temp = "catch"; keywords.add(temp);
+ temp = "#else"; keywords.add(temp);
+ temp = "register"; keywords.add(temp);
+ temp = "volatile"; keywords.add(temp);
+ temp = "const_cast"; keywords.add(temp);
+ temp = "#end"; keywords.add(temp);
+ temp = "mutable"; keywords.add(temp);
+ temp = "static_cast"; keywords.add(temp);
+ temp = "wchar_t"; keywords.add(temp);
+ temp = "#if"; keywords.add(temp);
+ temp = "protected"; keywords.add(temp);
+ temp = "throw"; keywords.add(temp);
+ temp = "using"; keywords.add(temp);
+ temp = "dynamic_cast"; keywords.add(temp);
+ temp = "#ifdef"; keywords.add(temp);
+ temp = "return"; keywords.add(temp);
+ temp = "short"; keywords.add(temp);
+ temp = "#error"; keywords.add(temp);
+ temp = "#line"; keywords.add(temp);
+ temp = "explicit"; keywords.add(temp);
+ temp = "union"; keywords.add(temp);
+ temp = "#ifndef"; keywords.add(temp);
+ temp = "try"; keywords.add(temp);
+ temp = "sizeof"; keywords.add(temp);
+ temp = "goto"; keywords.add(temp);
+ temp = "long"; keywords.add(temp);
+ temp = "#elif"; keywords.add(temp);
+ temp = "static"; keywords.add(temp);
+ temp = "operator"; keywords.add(temp);
+ temp = "switch"; keywords.add(temp);
+ temp = "extern"; keywords.add(temp);
+
+
+ // set the tokenizer's IDENTIFIER token for C++ identifiers
+ tokenizer.set_identifier_token(
+ "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(),
+ "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() +
+ tokenizer.numbers()
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ cpp_tokenizer_kernel_1<tok,queue,set>::
+ ~cpp_tokenizer_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ void cpp_tokenizer_kernel_1<tok,queue,set>::
+ clear(
+ )
+ {
+ tokenizer.clear();
+ buffer.clear();
+ have_peeked = false;
+
+ // set the tokenizer's IDENTIFIER token for C++ identifiers
+ tokenizer.set_identifier_token(
+ "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(),
+ "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() +
+ tokenizer.numbers()
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ void cpp_tokenizer_kernel_1<tok,queue,set>::
+ set_stream (
+ std::istream& in
+ )
+ {
+ tokenizer.set_stream(in);
+ buffer.clear();
+ have_peeked = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ bool cpp_tokenizer_kernel_1<tok,queue,set>::
+ stream_is_set (
+ ) const
+ {
+ return tokenizer.stream_is_set();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ std::istream& cpp_tokenizer_kernel_1<tok,queue,set>::
+ get_stream (
+ ) const
+ {
+ return tokenizer.get_stream();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ void cpp_tokenizer_kernel_1<tok,queue,set>::
+ get_token (
+ int& type,
+ std::string& token
+ )
+ {
+ using namespace std;
+
+ if (!have_peeked)
+ {
+
+ if (buffer.size() > 0)
+ {
+ // just return what is in the buffer
+ token_text_pair temp;
+ buffer.dequeue(temp);
+ type = temp.type;
+ token = temp.token;
+ return;
+ }
+
+ tokenizer.get_token(type,token);
+
+ switch (type)
+ {
+ case tok::END_OF_FILE:
+ {
+ type = END_OF_FILE;
+ } break;
+
+ case tok::END_OF_LINE:
+ case tok::WHITE_SPACE:
+ {
+ type = tokenizer.peek_type();
+ if (type == tok::END_OF_LINE || type == tok::WHITE_SPACE)
+ {
+ std::string temp;
+ do
+ {
+ tokenizer.get_token(type,temp);
+ token += temp;
+ type = tokenizer.peek_type();
+ }while (type == tok::END_OF_LINE || type == tok::WHITE_SPACE);
+ }
+ type = WHITE_SPACE;
+
+ } break;
+
+ case tok::NUMBER:
+ {
+ // this could be a hex number such as 0xa33. we should check for this.
+ if (tokenizer.peek_type() == tok::IDENTIFIER && token == "0" &&
+ (tokenizer.peek_token()[0] == 'x' || tokenizer.peek_token()[0] == 'X'))
+ {
+ // this is a hex number so accumulate all the numbers and identifiers that follow
+ // because they have to be part of the number
+ std::string temp;
+ tokenizer.get_token(type,temp);
+ token = "0" + temp;
+
+ // get the rest of the hex number
+ while (tokenizer.peek_type() == tok::IDENTIFIER ||
+ tokenizer.peek_type() == tok::NUMBER
+ )
+ {
+ tokenizer.get_token(type,temp);
+ token += temp;
+ }
+
+ }
+ // or this could be a floating point value or something with an 'e' or 'E' in it.
+ else if ((tokenizer.peek_type() == tok::CHAR && tokenizer.peek_token()[0] == '.') ||
+ (tokenizer.peek_type() == tok::IDENTIFIER && std::tolower(tokenizer.peek_token()[0]) == 'e'))
+ {
+ std::string temp;
+ tokenizer.get_token(type,temp);
+ token += temp;
+ // now get the rest of the floating point value
+ while (tokenizer.peek_type() == tok::IDENTIFIER ||
+ tokenizer.peek_type() == tok::NUMBER
+ )
+ {
+ tokenizer.get_token(type,temp);
+ token += temp;
+ }
+ }
+ type = NUMBER;
+
+ } break;
+
+ case tok::IDENTIFIER:
+ {
+ if (keywords.is_member(token))
+ {
+ type = KEYWORD;
+ }
+ else
+ {
+ type = IDENTIFIER;
+ }
+ } break;
+
+ case tok::CHAR:
+ type = OTHER;
+ switch (token[0])
+ {
+ case '#':
+ {
+ // this might be a preprocessor keyword so we should check the
+ // next token
+ if (tokenizer.peek_type() == tok::IDENTIFIER &&
+ keywords.is_member('#'+tokenizer.peek_token()))
+ {
+ tokenizer.get_token(type,token);
+ token = '#' + token;
+ type = KEYWORD;
+ }
+ else
+ {
+ token = '#';
+ type = OTHER;
+ }
+ }
+ break;
+
+ case '"':
+ {
+ string temp;
+ tokenizer.get_token(type,token);
+ while (type != tok::END_OF_FILE)
+ {
+ // if this is the end of the quoted string
+ if (type == tok::CHAR && token[0] == '"' &&
+ (temp.size() == 0 || temp[temp.size()-1] != '\\' ||
+ (temp.size() > 1 && temp[temp.size()-2] == '\\') ))
+ {
+ buffer_token(DOUBLE_QUOTED_TEXT,temp);
+ buffer_token(OTHER,"\"");
+ break;
+ }
+ else
+ {
+ temp += token;
+ }
+ tokenizer.get_token(type,token);
+ }
+
+
+ type = OTHER;
+ token = '"';
+ } break;
+
+ case '\'':
+ {
+ string temp;
+ tokenizer.get_token(type,token);
+ if (type == tok::CHAR && token[0] == '\\')
+ {
+ temp += '\\';
+ tokenizer.get_token(type,token);
+ }
+ temp += token;
+ buffer_token(SINGLE_QUOTED_TEXT,temp);
+
+ // The next character should be a ' so take it out and put it in
+ // the buffer.
+ tokenizer.get_token(type,token);
+ buffer_token(OTHER,token);
+
+ type = OTHER;
+ token = '\'';
+ } break;
+
+ case '/':
+ {
+ // look ahead to see if this is the start of a comment
+ if (tokenizer.peek_type() == tok::CHAR)
+ {
+ if (tokenizer.peek_token()[0] == '/')
+ {
+ tokenizer.get_token(type,token);
+ // this is the start of a line comment
+ token = "//";
+ string temp;
+ tokenizer.get_token(type,temp);
+ while (type != tok::END_OF_FILE)
+ {
+ // if this is the end of the comment
+ if (type == tok::END_OF_LINE &&
+ token[token.size()-1] != '\\' )
+ {
+ token += '\n';
+ break;
+ }
+ else
+ {
+ token += temp;
+ }
+ tokenizer.get_token(type,temp);
+ }
+ type = COMMENT;
+
+ }
+ else if (tokenizer.peek_token()[0] == '*')
+ {
+ tokenizer.get_token(type,token);
+ // this is the start of a block comment
+ token = "/*";
+ string temp;
+ tokenizer.get_token(type,temp);
+ while (type != tok::END_OF_FILE)
+ {
+ // if this is the end of the comment
+ if (type == tok::CHAR && temp[0] == '/' &&
+ token[token.size()-1] == '*')
+ {
+ token += '/';
+ break;
+ }
+ else
+ {
+ token += temp;
+ }
+ tokenizer.get_token(type,temp);
+ }
+ type = COMMENT;
+ }
+ }
+ } break;
+
+ default:
+ break;
+ } // switch (token[0])
+ } // switch (type)
+ }
+ else
+ {
+ // if we get this far it means we have peeked so we should
+ // return the peek data.
+ type = next_type;
+ token = next_token;
+ have_peeked = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ int cpp_tokenizer_kernel_1<tok,queue,set>::
+ peek_type (
+ ) const
+ {
+ const_cast<cpp_tokenizer_kernel_1<tok,queue,set>*>(this)->get_token(next_type,next_token);
+ have_peeked = true;
+ return next_type;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ const std::string& cpp_tokenizer_kernel_1<tok,queue,set>::
+ peek_token (
+ ) const
+ {
+ const_cast<cpp_tokenizer_kernel_1<tok,queue,set>*>(this)->get_token(next_type,next_token);
+ have_peeked = true;
+ return next_token;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tok,
+ typename queue,
+ typename set
+ >
+ void cpp_tokenizer_kernel_1<tok,queue,set>::
+ swap (
+ cpp_tokenizer_kernel_1& item
+ )
+ {
+ tokenizer.swap(item.tokenizer);
+ buffer.swap(item.buffer);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CPP_TOKENIZER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h
new file mode 100644
index 000000000..e7ac23284
--- /dev/null
+++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h
@@ -0,0 +1,224 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_
+#ifdef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_
+
+#include <string>
+#include <ioswfd>
+
+namespace dlib
+{
+
+ class cpp_tokenizer
+ {
+ /*!
+ INITIAL VALUE
+ stream_is_set() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple tokenizer for C++ source code.
+
+ BUFFERING
+ This object is allowed to buffer data from the input stream.
+ Thus if you clear it or switch streams (via calling set_stream())
+ any buffered data will be lost.
+
+ TOKENS
+ When picking out tokens the cpp_tokenizer will always extract the
+ longest token it can. For example, if faced with the string
+ "AAA" it will consider the three As to be a single IDENTIFIER
+ token not three smaller IDENTIFIER tokens.
+
+ Also note that no characters in the input stream are discarded.
+ They will all be returned in the text of some token.
+ Additionally, each character will never be returned more than once.
+ This means that if you concatenated all returned tokens it would exactly
+ reproduce the contents of the input stream.
+
+ The tokens are defined as follows:
+
+ END_OF_FILE
+ This token represents the end of file. It doesn't have any
+ actual characters associated with it.
+
+ KEYWORD
+ This token matches a C++ keyword. (This includes the preprocessor
+ directives).
+
+ COMMENT
+ This token matches a C++ comment.
+
+ SINGLE_QUOTED_TEXT
+ This token matches the text of any single quoted literal.
+ For example, 'a' would be a match and the text of this token
+ would be the single character a.
+
+ DOUBLE_QUOTED_TEXT
+ This token matches the text of any double quoted string.
+ For example, "C++" would be a match and the text of this token
+ would be the three character string C++.
+
+ WHITE_SPACE
+ This is a multi character token. It is defined as a sequence of
+ one or more spaces, carrage returns, newlines, and tabs. I.e. It
+ is composed of characters from the following string " \r\n\t".
+
+ IDENTIFIER
+ This token matches any C++ identifier that isn't matched by any
+ of the above tokens. (A C++ identifier being a string matching
+ the regular expression [_$a-zA-Z][_$a-zA-Z0-9]*).
+
+ NUMBER
+ This token matches any C++ numerical constant.
+
+ OTHER
+ This matches anything that isn't part of one of the above tokens.
+ It is always a single character.
+ !*/
+
+ public:
+
+ enum
+ {
+ END_OF_FILE,
+ KEYWORD,
+ COMMENT,
+ SINGLE_QUOTED_TEXT,
+ DOUBLE_QUOTED_TEXT,
+ IDENTIFIER,
+ OTHER,
+ NUMBER,
+ WHITE_SPACE
+ };
+
+ cpp_tokenizer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~cpp_tokenizer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ If this exception is thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ void set_stream (
+ std::istream& in
+ );
+ /*!
+ ensures
+ - #*this will read data from in and tokenize it
+ - #stream_is_set() == true
+ - #get_stream() == in
+ !*/
+
+ bool stream_is_set (
+ ) const;
+ /*!
+ ensures
+ - returns true if a stream has been associated with *this by calling
+ set_stream()
+ !*/
+
+ std::istream& get_stream (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns a reference to the istream object that *this is reading
+ from.
+ !*/
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - #token == the next token from the input stream get_stream()
+ - #type == the type of the token in #token
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this but the values of #type and #token will be
+ undefined. Additionally, some characters may have been read
+ from the stream get_stream() and lost.
+ !*/
+
+ int peek_type (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns the type of the token that will be returned from
+ the next call to get_token()
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this. However, some characters may have been
+ read from the stream get_stream() and lost.
+ !*/
+
+ const std::string& peek_token (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns the text of the token that will be returned from
+ the next call to get_token()
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this. However, some characters may have been
+ read from the stream get_stream() and lost.
+ !*/
+
+ void swap (
+ cpp_tokenizer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ cpp_tokenizer(const cpp_tokenizer&); // copy constructor
+ cpp_tokenizer& operator=(const cpp_tokenizer&); // assignment operator
+
+ };
+
+ inline void swap (
+ cpp_tokenizer& a,
+ cpp_tokenizer& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h
new file mode 100644
index 000000000..0073a5680
--- /dev/null
+++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h
@@ -0,0 +1,137 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CPP_TOKENIZER_KERNEl_C_
+#define DLIB_CPP_TOKENIZER_KERNEl_C_
+
+#include "cpp_tokenizer_kernel_abstract.h"
+#include "../assert.h"
+#include <string>
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename tokenizer
+ >
+ class cpp_tokenizer_kernel_c : public tokenizer
+ {
+
+ public:
+ std::istream& get_stream (
+ ) const;
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+
+ int peek_type (
+ ) const;
+
+ const std::string& peek_token (
+ ) const;
+
+ };
+
+ template <
+ typename tokenizer
+ >
+ inline void swap (
+ cpp_tokenizer_kernel_c<tokenizer>& a,
+ cpp_tokenizer_kernel_c<tokenizer>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ std::istream& cpp_tokenizer_kernel_c<tokenizer>::
+ get_stream (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tstd::istream& cpp_tokenizer::get_stream()"
+ << "\n\tyou must set a stream for this object before you can get it"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::get_stream();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ const std::string& cpp_tokenizer_kernel_c<tokenizer>::
+ peek_token (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tconst std::string& cpp_tokenizer::peek_token()"
+ << "\n\tyou must set a stream for this object before you can peek at what it contains"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::peek_token();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ int cpp_tokenizer_kernel_c<tokenizer>::
+ peek_type (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tint cpp_tokenizer::peek_type()"
+ << "\n\tyou must set a stream for this object before you can peek at what it contains"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::peek_type();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ void cpp_tokenizer_kernel_c<tokenizer>::
+ get_token (
+ int& type,
+ std::string& token
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tvoid cpp_tokenizer::get_token()"
+ << "\n\tyou must set a stream for this object before you can get tokens from it."
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ tokenizer::get_token(type,token);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TOKENIZER_KERNEl_C_
+
+
diff --git a/ml/dlib/dlib/crc32.h b/ml/dlib/dlib/crc32.h
new file mode 100644
index 000000000..004aaeba4
--- /dev/null
+++ b/ml/dlib/dlib/crc32.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CRc32_
+#define DLIB_CRc32_
+
+
+#include "crc32/crc32_kernel_1.h"
+
+#endif // DLIB_CRc32_
+
diff --git a/ml/dlib/dlib/crc32/crc32_kernel_1.h b/ml/dlib/dlib/crc32/crc32_kernel_1.h
new file mode 100644
index 000000000..4c679d2f2
--- /dev/null
+++ b/ml/dlib/dlib/crc32/crc32_kernel_1.h
@@ -0,0 +1,262 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CRC32_KERNEl_1_
+#define DLIB_CRC32_KERNEl_1_
+
+#include "../algs.h"
+#include <string>
+#include <vector>
+#include "crc32_kernel_abstract.h"
+
+namespace dlib
+{
+
+ class crc32
+ {
+ /*!
+ INITIAL VALUE
+ checksum == 0xFFFFFFFF
+
+ CONVENTION
+ get_checksum() == checksum ^ 0xFFFFFFFF
+ !*/
+
+ public:
+
+ // this is here for backwards compatibility with older versions of dlib.
+ typedef crc32 kernel_1a;
+
+ inline crc32 (
+ );
+
+ inline crc32 (
+ const std::string& item
+ );
+
+ inline crc32 (
+ const std::vector<char>& item
+ );
+
+ inline virtual ~crc32 (
+ );
+
+ inline void clear(
+ );
+
+ inline void add (
+ unsigned char item
+ );
+
+ inline void add (
+ const std::string& item
+ );
+
+ inline void add (
+ const std::vector<char>& item
+ );
+
+ inline operator unsigned long (
+ ) const { return get_checksum(); }
+
+ inline unsigned long get_checksum (
+ ) const;
+
+ inline void swap (
+ crc32& item
+ );
+
+ inline crc32& operator=(
+ const crc32&
+ );
+
+ private:
+
+ unsigned long checksum;
+
+ inline unsigned long table (
+ unsigned int idx
+ ) const
+ {
+ /*
+ // This code generates the crc_table used below.
+ unsigned long crc_table[256];
+ for (unsigned long i = 0; i < 256; ++i)
+ {
+ unsigned long temp = i;
+ for (unsigned long j = 0; j < 8; ++j)
+ {
+ if (temp&1)
+ temp = (temp>>1)^0xedb88320;
+ else
+ temp >>= 1;
+ }
+ crc_table[i] = temp;
+ std::cout << std::hex << crc_table[i] << std::endl;
+ }
+ */
+
+ const static unsigned long crc_table[256] = {
+ 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x76dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
+ 0xedb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x9b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
+ 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
+ 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
+ 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
+ 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
+ 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
+ 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
+ 0x76dc4190, 0x1db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x6b6b51f, 0x9fbfe4a5, 0xe8b8d433,
+ 0x7807c9a2, 0xf00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x86d3d2d, 0x91646c97, 0xe6635c01,
+ 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
+ 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
+ 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
+ 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
+ 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
+ 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
+ 0xedb88320, 0x9abfb3b6, 0x3b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x4db2615, 0x73dc1683,
+ 0xe3630b12, 0x94643b84, 0xd6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0xa00ae27, 0x7d079eb1,
+ 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
+ 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
+ 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
+ 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
+ 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
+ 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
+ 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x26d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x5005713,
+ 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0xcb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0xbdbdf21,
+ 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
+ 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
+ 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
+ 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
+ 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
+ 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
+ };
+
+ return crc_table[idx];
+ }
+
+ };
+
+ inline void swap (
+ crc32& a,
+ crc32& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ crc32::
+ crc32 (
+ )
+ {
+ checksum = 0xFFFFFFFF;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ crc32::
+ crc32 (
+ const std::string& item
+ )
+ {
+ checksum = 0xFFFFFFFF;
+ add(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ crc32::
+ crc32 (
+ const std::vector<char>& item
+ )
+ {
+ checksum = 0xFFFFFFFF;
+ add(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ crc32::
+ ~crc32 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void crc32::
+ clear(
+ )
+ {
+ checksum = 0xFFFFFFFF;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void crc32::
+ add (
+ unsigned char item
+ )
+ {
+ checksum = (checksum>>8) ^ table((checksum^item) & 0xFF);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void crc32::
+ add (
+ const std::string& item
+ )
+ {
+ for (std::string::size_type i = 0; i < item.size(); ++i)
+ checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void crc32::
+ add (
+ const std::vector<char>& item
+ )
+ {
+ for (unsigned long i = 0; i < item.size(); ++i)
+ checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long crc32::
+ get_checksum (
+ ) const
+ {
+ return checksum ^ 0xFFFFFFFF;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void crc32::
+ swap (
+ crc32& item
+ )
+ {
+ exchange(checksum,item.checksum);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ crc32& crc32::
+ operator=(
+ const crc32& item
+ )
+ {
+ checksum = item.checksum;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CRC32_KERNEl_1_
+
diff --git a/ml/dlib/dlib/crc32/crc32_kernel_abstract.h b/ml/dlib/dlib/crc32/crc32_kernel_abstract.h
new file mode 100644
index 000000000..76da49fbc
--- /dev/null
+++ b/ml/dlib/dlib/crc32/crc32_kernel_abstract.h
@@ -0,0 +1,132 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CRC32_KERNEl_ABSTRACT_
+#ifdef DLIB_CRC32_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <string>
+#include <vector>
+
+namespace dlib
+{
+
+ class crc32
+ {
+ /*!
+ INITIAL VALUE
+ The current checksum covers zero bytes.
+ get_checksum() == 0x00000000
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the CRC32 algorithm for calculating
+ checksums.
+ !*/
+
+ public:
+
+ crc32 (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ !*/
+
+ crc32 (
+ const std::string& item
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - calls this->add(item).
+ (i.e. Using this constructor is the same as using the default
+ constructor and then calling add() on item)
+ !*/
+
+ crc32 (
+ const std::vector<char>& item
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - calls this->add(item).
+ (i.e. Using this constructor is the same as using the default
+ constructor and then calling add() on item)
+ !*/
+
+ virtual ~crc32 (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ !*/
+
+ void add (
+ unsigned char item
+ );
+ /*!
+ ensures
+ - #get_checksum() == The checksum of all items added to *this previously
+ concatenated with item.
+ !*/
+
+ void add (
+ const std::string& item
+ );
+ /*!
+ ensures
+ - #get_checksum() == The checksum of all items added to *this previously
+ concatenated with item.
+ !*/
+
+ void add (
+ const std::vector<char>& item
+ );
+ /*!
+ ensures
+ - #get_checksum() == The checksum of all items added to *this previously
+ concatenated with item.
+ !*/
+
+ unsigned long get_checksum (
+ ) const;
+ /*!
+ ensures
+ - returns the current checksum
+ !*/
+
+ operator unsigned long (
+ ) const;
+ /*!
+ ensures
+ - returns get_checksum()
+ !*/
+
+ void swap (
+ crc32& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ void swap (
+ crc32& a,
+ crc32& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_CRC32_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/cstring b/ml/dlib/dlib/cstring
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/cstring
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/data_io.h b/ml/dlib/dlib/data_io.h
new file mode 100644
index 000000000..845e95f40
--- /dev/null
+++ b/ml/dlib/dlib/data_io.h
@@ -0,0 +1,18 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DATA_Io_HEADER
+#define DLIB_DATA_Io_HEADER
+
+#include "data_io/libsvm_io.h"
+#include "data_io/image_dataset_metadata.h"
+#include "data_io/mnist.h"
+
+#ifndef DLIB_ISO_CPP_ONLY
+#include "data_io/load_image_dataset.h"
+#endif
+
+#endif // DLIB_DATA_Io_HEADER
+
+
+
+
diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.cpp b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp
new file mode 100644
index 000000000..390ef6a0a
--- /dev/null
+++ b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp
@@ -0,0 +1,411 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_DAtASET_METADATA_CPPh_
+#define DLIB_IMAGE_DAtASET_METADATA_CPPh_
+
+#include "image_dataset_metadata.h"
+
+#include <fstream>
+#include <sstream>
+#include "../compress_stream.h"
+#include "../base64.h"
+#include "../xml_parser.h"
+#include "../string.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ namespace image_dataset_metadata
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ const std::string get_decoded_string();
+ void create_image_metadata_stylesheet_file(const std::string& main_filename)
+ {
+ std::string path;
+ std::string::size_type pos = main_filename.find_last_of("/\\");
+ if (pos != std::string::npos)
+ path = main_filename.substr(0,pos+1);
+
+ std::ofstream fout((path + "image_metadata_stylesheet.xsl").c_str());
+ if (!fout)
+ throw dlib::error("ERROR: Unable to open image_metadata_stylesheet.xsl for writing.");
+
+ fout << get_decoded_string();
+
+ if (!fout)
+ throw dlib::error("ERROR: Unable to write to image_metadata_stylesheet.xsl.");
+ }
+
+ void save_image_dataset_metadata (
+ const dataset& meta,
+ const std::string& filename
+ )
+ {
+ create_image_metadata_stylesheet_file(filename);
+
+ const std::vector<image>& images = meta.images;
+
+ std::ofstream fout(filename.c_str());
+ if (!fout)
+ throw dlib::error("ERROR: Unable to open " + filename + " for writing.");
+
+ fout << "<?xml version='1.0' encoding='ISO-8859-1'?>\n";
+ fout << "<?xml-stylesheet type='text/xsl' href='image_metadata_stylesheet.xsl'?>\n";
+ fout << "<dataset>\n";
+ fout << "<name>" << meta.name << "</name>\n";
+ fout << "<comment>" << meta.comment << "</comment>\n";
+ fout << "<images>\n";
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ fout << " <image file='" << images[i].filename << "'>\n";
+
+ // save all the boxes
+ for (unsigned long j = 0; j < images[i].boxes.size(); ++j)
+ {
+ const box& b = images[i].boxes[j];
+ fout << " <box top='" << b.rect.top() << "' "
+ << "left='" << b.rect.left() << "' "
+ << "width='" << b.rect.width() << "' "
+ << "height='" << b.rect.height() << "'";
+ if (b.difficult)
+ fout << " difficult='" << b.difficult << "'";
+ if (b.truncated)
+ fout << " truncated='" << b.truncated << "'";
+ if (b.occluded)
+ fout << " occluded='" << b.occluded << "'";
+ if (b.ignore)
+ fout << " ignore='" << b.ignore << "'";
+ if (b.angle != 0)
+ fout << " angle='" << b.angle << "'";
+ if (b.age != 0)
+ fout << " age='" << b.age << "'";
+ if (b.gender == FEMALE)
+ fout << " gender='female'";
+ else if (b.gender == MALE)
+ fout << " gender='male'";
+ if (b.pose != 0)
+ fout << " pose='" << b.pose << "'";
+ if (b.detection_score != 0)
+ fout << " detection_score='" << b.detection_score << "'";
+
+ if (b.has_label() || b.parts.size() != 0)
+ {
+ fout << ">\n";
+
+ if (b.has_label())
+ fout << " <label>" << b.label << "</label>\n";
+
+ // save all the parts
+ std::map<std::string,point>::const_iterator itr;
+ for (itr = b.parts.begin(); itr != b.parts.end(); ++itr)
+ {
+ fout << " <part name='"<< itr->first << "' x='"<< itr->second.x() <<"' y='"<< itr->second.y() <<"'/>\n";
+ }
+
+ fout << " </box>\n";
+ }
+ else
+ {
+ fout << "/>\n";
+ }
+ }
+
+
+
+ fout << " </image>\n";
+
+ if (!fout)
+ throw dlib::error("ERROR: Unable to write to " + filename + ".");
+ }
+ fout << "</images>\n";
+ fout << "</dataset>";
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ class doc_handler : public document_handler
+ {
+ std::vector<std::string> ts;
+ image temp_image;
+ box temp_box;
+
+ dataset& meta;
+
+ public:
+
+ doc_handler(
+ dataset& metadata_
+ ):
+ meta(metadata_)
+ {}
+
+
+ virtual void start_document (
+ )
+ {
+ meta = dataset();
+ ts.clear();
+ temp_image = image();
+ temp_box = box();
+ }
+
+ virtual void end_document (
+ )
+ {
+ }
+
+ virtual void start_element (
+ const unsigned long line_number,
+ const std::string& name,
+ const dlib::attribute_list& atts
+ )
+ {
+ try
+ {
+ if (ts.size() == 0)
+ {
+ if (name != "dataset")
+ {
+ std::ostringstream sout;
+ sout << "Invalid XML document. Root tag must be <dataset>. Found <" << name << "> instead.";
+ throw dlib::error(sout.str());
+ }
+ else
+ {
+ ts.push_back(name);
+ return;
+ }
+ }
+
+
+ if (name == "box")
+ {
+ if (atts.is_in_list("top")) temp_box.rect.top() = sa = atts["top"];
+ else throw dlib::error("<box> missing required attribute 'top'");
+
+ if (atts.is_in_list("left")) temp_box.rect.left() = sa = atts["left"];
+ else throw dlib::error("<box> missing required attribute 'left'");
+
+ if (atts.is_in_list("width")) temp_box.rect.right() = sa = atts["width"];
+ else throw dlib::error("<box> missing required attribute 'width'");
+
+ if (atts.is_in_list("height")) temp_box.rect.bottom() = sa = atts["height"];
+ else throw dlib::error("<box> missing required attribute 'height'");
+
+ if (atts.is_in_list("difficult")) temp_box.difficult = sa = atts["difficult"];
+ if (atts.is_in_list("truncated")) temp_box.truncated = sa = atts["truncated"];
+ if (atts.is_in_list("occluded")) temp_box.occluded = sa = atts["occluded"];
+ if (atts.is_in_list("ignore")) temp_box.ignore = sa = atts["ignore"];
+ if (atts.is_in_list("angle")) temp_box.angle = sa = atts["angle"];
+ if (atts.is_in_list("age")) temp_box.age = sa = atts["age"];
+ if (atts.is_in_list("gender"))
+ {
+ if (atts["gender"] == "male")
+ temp_box.gender = MALE;
+ else if (atts["gender"] == "female")
+ temp_box.gender = FEMALE;
+ else if (atts["gender"] == "unknown")
+ temp_box.gender = UNKNOWN;
+ else
+ throw dlib::error("Invalid gender string in box attribute.");
+ }
+ if (atts.is_in_list("pose")) temp_box.pose = sa = atts["pose"];
+ if (atts.is_in_list("detection_score")) temp_box.detection_score = sa = atts["detection_score"];
+
+ temp_box.rect.bottom() += temp_box.rect.top()-1;
+ temp_box.rect.right() += temp_box.rect.left()-1;
+ }
+ else if (name == "part" && ts.back() == "box")
+ {
+ point temp;
+ if (atts.is_in_list("x")) temp.x() = sa = atts["x"];
+ else throw dlib::error("<part> missing required attribute 'x'");
+
+ if (atts.is_in_list("y")) temp.y() = sa = atts["y"];
+ else throw dlib::error("<part> missing required attribute 'y'");
+
+ if (atts.is_in_list("name"))
+ {
+ if (temp_box.parts.count(atts["name"])==0)
+ {
+ temp_box.parts[atts["name"]] = temp;
+ }
+ else
+ {
+ throw dlib::error("<part> with name '" + atts["name"] + "' is defined more than one time in a single box.");
+ }
+ }
+ else
+ {
+ throw dlib::error("<part> missing required attribute 'name'");
+ }
+ }
+ else if (name == "image")
+ {
+ temp_image.boxes.clear();
+
+ if (atts.is_in_list("file")) temp_image.filename = atts["file"];
+ else throw dlib::error("<image> missing required attribute 'file'");
+ }
+
+ ts.push_back(name);
+ }
+ catch (error& e)
+ {
+ throw dlib::error("Error on line " + cast_to_string(line_number) + ": " + e.what());
+ }
+ }
+
+ virtual void end_element (
+ const unsigned long ,
+ const std::string& name
+ )
+ {
+ ts.pop_back();
+ if (ts.size() == 0)
+ return;
+
+ if (name == "box" && ts.back() == "image")
+ {
+ temp_image.boxes.push_back(temp_box);
+ temp_box = box();
+ }
+ else if (name == "image" && ts.back() == "images")
+ {
+ meta.images.push_back(temp_image);
+ temp_image = image();
+ }
+ }
+
+ virtual void characters (
+ const std::string& data
+ )
+ {
+ if (ts.size() == 2 && ts[1] == "name")
+ {
+ meta.name = trim(data);
+ }
+ else if (ts.size() == 2 && ts[1] == "comment")
+ {
+ meta.comment = trim(data);
+ }
+ else if (ts.size() >= 2 && ts[ts.size()-1] == "label" &&
+ ts[ts.size()-2] == "box")
+ {
+ temp_box.label = trim(data);
+ }
+ }
+
+ virtual void processing_instruction (
+ const unsigned long ,
+ const std::string& ,
+ const std::string&
+ )
+ {
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ class xml_error_handler : public error_handler
+ {
+ public:
+ virtual void error (
+ const unsigned long
+ ) { }
+
+ virtual void fatal_error (
+ const unsigned long line_number
+ )
+ {
+ std::ostringstream sout;
+ sout << "There is a fatal error on line " << line_number << " so parsing will now halt.";
+ throw dlib::error(sout.str());
+ }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ void load_image_dataset_metadata (
+ dataset& meta,
+ const std::string& filename
+ )
+ {
+ xml_error_handler eh;
+ doc_handler dh(meta);
+
+ std::ifstream fin(filename.c_str());
+ if (!fin)
+ throw dlib::error("ERROR: unable to open " + filename + " for reading.");
+
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ parser.add_error_handler(eh);
+ parser.parse(fin);
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'images.xsl'
+ const std::string get_decoded_string()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'image_metadata_stylesheet.xsl' we want to decode and return.
+ sout << "PFWfgmWfCHr1DkV63lbjjeY2dCc2FbHDOVh0Kd7dkvaOfRYrOG24f0x77/5iMVq8FtE3UBxtGwSd";
+ sout << "1ZHOHRSHgieNoeBv8ssJQ75RRxYtFKRY3OTPX5eKQoCN9jUaUnHnR4QZtEHgmKqXSs50Yrdd+2Ah";
+ sout << "gNyarPZCiR6nvqNvCjtP2MP5FxleqNf8Fylatm2KdsXmrv5K87LYVN7i7JMkmZ++cTXYSOxDmxZi";
+ sout << "OiCH8funXUdF9apDW547gCjz9HOQUI6dkz5dYUeFjfp6dFugpnaJyyprFLKq048Qk7+QiL4CNF/G";
+ sout << "7e0VpBw8dMpiyRNi2fSQGSZGfIAUQKKT6+rPwQoRH2spdjsdXVWj4XQAqBX87nmqMnqjMhn/Vd1s";
+ sout << "W5aoC0drwRGu3Xe3gn9vBL8hBkRXcJvEy6q/lb9bYnsLemhE5Zp/+nTmTBjfT9UFYLcsmgsjC+4n";
+ sout << "Bq6h9QlpuyMYqJ8RvW8pp3mFlvXc3Yg+18t5F0hSMQfaIFYAuDPU2lVzPpY+ba0B39iu9IrPCLsS";
+ sout << "+tUtSNSmQ74CtzZgKKjkTMA3nwYP2SDmZE3firq42pihT7hdU5vYkes69K8AQl8WZyLPpMww+r0z";
+ sout << "+veEHPlAuxF7kL3ZvVjdB+xABwwqDe0kSRHRZINYdUfJwJdfYLyDnYoMjj6afqIJZ7QOBPZ42tV5";
+ sout << "3hYOQTFwTNovOastzJJXQe1kxPg1AQ8ynmfjjJZqD0xKedlyeJybP919mVAA23UryHsq9TVlabou";
+ sout << "qNl3xZW/mKKktvVsd/nuH62HIv/kgomyhaEUY5HgupupBUbQFZfyljZ5bl3g3V3Y1400Z1xTM/LL";
+ sout << "LJpeLdlqoGzIe/19vAN1zUUVId9F/OLNUl3Zoar63yZERSJHcsuq/Pasisp0HIGi7rfI9EIQF7C/";
+ sout << "IhLKLZsJ+LOycreQGOJALZIEZHOqxYLSXG0qaPM5bQL/MQJ2OZfwEhQgYOrjaM7oPOHHEfTq5kcO";
+ sout << "daMwzefKfxrF2GXbUs0bYsEXsIGwENIUKMliFaAI4qKLxxb94oc+O3BRjWueZjZty2zKawQyTHNd";
+ sout << "ltFJBUzfffdZN9Wq4zbPzntkM3U6Ys4LRztx5M15dtbhFeKx5rAf2tPXT6wU01hx7EJxBJzpvoDE";
+ sout << "YwEoYVDSYulRKpgk82cHFzzUDgWXbl4paFSe1L1w8r9KHr67SYJDTUG86Lrm6LJ0rw73Xp0NAFcU";
+ sout << "MKpiG9g1cHW74HYbUb/yAbtVWt40eB7M637umdo2jWz/r/vP5WnfSMXEbkyWebsa1fFceg/TLWy6";
+ sout << "E8OTc4XKB48h1oFIlGagOiprxho3+F3TIcxDSwA=";
+
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_IMAGE_DAtASET_METADATA_CPPh_
+
+
diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.h b/ml/dlib/dlib/data_io/image_dataset_metadata.h
new file mode 100644
index 000000000..3dac29ba6
--- /dev/null
+++ b/ml/dlib/dlib/data_io/image_dataset_metadata.h
@@ -0,0 +1,174 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_DAtASET_METADATA_Hh_
+#define DLIB_IMAGE_DAtASET_METADATA_Hh_
+
+#include <string>
+#include <vector>
+#include "../geometry.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ namespace image_dataset_metadata
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ enum gender_t
+ {
+ UNKNOWN,
+ MALE,
+ FEMALE
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ struct box
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an annotated rectangular area of an image.
+ It is typically used to mark the location of an object such as a
+ person, car, etc.
+
+ The main variable of interest is rect. It gives the location of
+ the box. All the other variables are optional.
+ !*/
+
+ box(
+ ) :
+ difficult(false),
+ truncated(false),
+ occluded(false),
+ ignore(false),
+ pose(0),
+ detection_score(0),
+ angle(0),
+ gender(UNKNOWN),
+ age(0)
+ {}
+
+ box (
+ const rectangle& rect_
+ ) :
+ rect(rect_),
+ difficult(false),
+ truncated(false),
+ occluded(false),
+ ignore(false),
+ pose(0),
+ detection_score(0),
+ angle(0),
+ gender(UNKNOWN),
+ age(0)
+ {}
+
+ rectangle rect;
+
+ std::map<std::string,point> parts;
+
+ // optional fields
+ std::string label;
+ bool difficult;
+ bool truncated;
+ bool occluded;
+ bool ignore;
+ double pose;
+ double detection_score;
+
+ // The angle of the object in radians. Positive values indicate that the
+ // object at the center of the box is rotated clockwise by angle radians. A
+ // value of 0 would indicate that the object is in its "standard" upright pose.
+ // Therefore, to make the object appear upright we would have to rotate the
+ // image counter-clockwise by angle radians.
+ double angle;
+
+ gender_t gender;
+ double age;
+
+ bool has_label() const { return label.size() != 0; }
+ /*!
+ ensures
+ - returns true if label metadata is present and false otherwise.
+ !*/
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ struct image
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an annotated image.
+ !*/
+
+ image() {}
+ image(const std::string& f) : filename(f) {}
+
+ std::string filename;
+ std::vector<box> boxes;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ struct dataset
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a labeled set of images. In particular, it
+ contains the filename for each image as well as annotated boxes.
+ !*/
+
+ std::vector<image> images;
+ std::string comment;
+ std::string name;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ void save_image_dataset_metadata (
+ const dataset& meta,
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - Writes the contents of the meta object to a file with the given
+ filename. The file will be in an XML format.
+ throws
+ - dlib::error
+ This exception is thrown if there is an error which prevents
+ this function from succeeding.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void load_image_dataset_metadata (
+ dataset& meta,
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - Attempts to interpret filename as a file containing XML formatted data
+ as produced by the save_image_dataset_metadata() function. Then
+ meta is loaded with the contents of the file.
+ throws
+ - dlib::error
+ This exception is thrown if there is an error which prevents
+ this function from succeeding.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "image_dataset_metadata.cpp"
+#endif
+
+#endif // DLIB_IMAGE_DAtASET_METADATA_Hh_
+
diff --git a/ml/dlib/dlib/data_io/libsvm_io.h b/ml/dlib/dlib/data_io/libsvm_io.h
new file mode 100644
index 000000000..f365e82d7
--- /dev/null
+++ b/ml/dlib/dlib/data_io/libsvm_io.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LIBSVM_iO_Hh_
+#define DLIB_LIBSVM_iO_Hh_
+
+#include "libsvm_io_abstract.h"
+
+#include <fstream>
+#include <string>
+#include <utility>
+#include "../algs.h"
+#include "../matrix.h"
+#include "../string.h"
+#include "../svm/sparse_vector.h"
+#include <vector>
+
+namespace dlib
+{
+ struct sample_data_io_error : public error
+ {
+ sample_data_io_error(const std::string& message): error(message) {}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
+ void load_libsvm_formatted_data (
+ const std::string& file_name,
+ std::vector<sample_type, alloc1>& samples,
+ std::vector<label_type, alloc2>& labels
+ )
+ {
+ using namespace std;
+ typedef typename sample_type::value_type pair_type;
+ typedef typename basic_type<typename pair_type::first_type>::type key_type;
+ typedef typename pair_type::second_type value_type;
+
+ // You must use unsigned integral key types in your sparse vectors
+ COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value);
+
+ samples.clear();
+ labels.clear();
+
+ ifstream fin(file_name.c_str());
+
+ if (!fin)
+ throw sample_data_io_error("Unable to open file " + file_name);
+
+ string line;
+ istringstream sin;
+ key_type key;
+ value_type value;
+ label_type label;
+ sample_type sample;
+ long line_num = 0;
+ while (fin.peek() != EOF)
+ {
+ ++line_num;
+ getline(fin, line);
+
+ string::size_type pos = line.find_first_not_of(" \t\r\n");
+
+ // ignore empty lines or comment lines
+ if (pos == string::npos || line[pos] == '#')
+ continue;
+
+ sin.clear();
+ sin.str(line);
+ sample.clear();
+
+ sin >> label;
+
+ if (!sin)
+ throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name );
+
+ // eat whitespace
+ sin >> ws;
+
+ while (sin.peek() != EOF && sin.peek() != '#')
+ {
+
+ sin >> key >> ws;
+
+ // ignore what should be a : character
+ if (sin.get() != ':')
+ throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name);
+
+ sin >> value;
+
+ if (sin && value != 0)
+ {
+ sample.insert(sample.end(), make_pair(key, value));
+ }
+
+ sin >> ws;
+ }
+
+ samples.push_back(sample);
+ labels.push_back(label);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type, typename alloc>
+ typename enable_if<is_const_type<typename sample_type::value_type::first_type> >::type
+ fix_nonzero_indexing (
+ std::vector<sample_type,alloc>& samples
+ )
+ {
+ typedef typename sample_type::value_type pair_type;
+ typedef typename basic_type<typename pair_type::first_type>::type key_type;
+
+ if (samples.size() == 0)
+ return;
+
+ // figure out the min index value
+ key_type min_idx = samples[0].begin()->first;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ min_idx = std::min(min_idx, samples[i].begin()->first);
+
+ // Now adjust all the samples so that their min index value is zero.
+ if (min_idx != 0)
+ {
+ sample_type temp;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // copy samples[i] into temp but make sure it has a min index of zero.
+ temp.clear();
+ typename sample_type::iterator j;
+ for (j = samples[i].begin(); j != samples[i].end(); ++j)
+ {
+ temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second));
+ }
+
+ // replace the current sample with temp.
+ samples[i].swap(temp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// If the "first" values in the std::pair objects are not const then we can modify them
+// directly and that is what this version of fix_nonzero_indexing() does.
+ template <typename sample_type, typename alloc>
+ typename disable_if<is_const_type<typename sample_type::value_type::first_type> >::type
+ fix_nonzero_indexing (
+ std::vector<sample_type,alloc>& samples
+ )
+ {
+ typedef typename sample_type::value_type pair_type;
+ typedef typename basic_type<typename pair_type::first_type>::type key_type;
+
+ if (samples.size() == 0)
+ return;
+
+ // figure out the min index value
+ key_type min_idx = samples[0].begin()->first;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ min_idx = std::min(min_idx, samples[i].begin()->first);
+
+ // Now adjust all the samples so that their min index value is zero.
+ if (min_idx != 0)
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ typename sample_type::iterator j;
+ for (j = samples[i].begin(); j != samples[i].end(); ++j)
+ {
+ j->first -= min_idx;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// This is an overload for sparse vectors
+ template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
+ typename disable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data (
+ const std::string& file_name,
+ const std::vector<sample_type, alloc1>& samples,
+ const std::vector<label_type, alloc2>& labels
+ )
+ {
+ typedef typename sample_type::value_type pair_type;
+ typedef typename basic_type<typename pair_type::first_type>::type key_type;
+
+ // You must use unsigned integral key types in your sparse vectors
+ COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() == labels.size(),
+ "\t void save_libsvm_formatted_data()"
+ << "\n\t You have to have labels for each sample and vice versa"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t labels.size(): " << labels.size()
+ );
+
+
+ using namespace std;
+ ofstream fout(file_name.c_str());
+ fout.precision(14);
+
+ if (!fout)
+ throw sample_data_io_error("Unable to open file " + file_name);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ fout << labels[i];
+
+ for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j)
+ {
+ if (j->second != 0)
+ fout << " " << j->first << ":" << j->second;
+ }
+ fout << "\n";
+
+ if (!fout)
+ throw sample_data_io_error("Error while writing to file " + file_name);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// This is an overload for dense vectors
+ template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
+ typename enable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data (
+ const std::string& file_name,
+ const std::vector<sample_type, alloc1>& samples,
+ const std::vector<label_type, alloc2>& labels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() == labels.size(),
+ "\t void save_libsvm_formatted_data()"
+ << "\n\t You have to have labels for each sample and vice versa"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t labels.size(): " << labels.size()
+ );
+
+ using namespace std;
+ ofstream fout(file_name.c_str());
+ fout.precision(14);
+
+ if (!fout)
+ throw sample_data_io_error("Unable to open file " + file_name);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ fout << labels[i];
+
+ for (long j = 0; j < samples[i].size(); ++j)
+ {
+ if (samples[i](j) != 0)
+ fout << " " << j << ":" << samples[i](j);
+ }
+ fout << "\n";
+
+ if (!fout)
+ throw sample_data_io_error("Error while writing to file " + file_name);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LIBSVM_iO_Hh_
+
diff --git a/ml/dlib/dlib/data_io/libsvm_io_abstract.h b/ml/dlib/dlib/data_io/libsvm_io_abstract.h
new file mode 100644
index 000000000..88d934fdb
--- /dev/null
+++ b/ml/dlib/dlib/data_io/libsvm_io_abstract.h
@@ -0,0 +1,125 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LIBSVM_iO_ABSTRACT_Hh_
+#ifdef DLIB_LIBSVM_iO_ABSTRACT_Hh_
+
+#include <fstream>
+#include <string>
+#include <utility>
+#include "../algs.h"
+#include "../matrix.h"
+#include <vector>
+
+namespace dlib
+{
+ struct sample_data_io_error : public error
+ {
+ /*!
+ This is the exception class used by the file IO functions defined below.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename label_type,
+ typename alloc1,
+ typename alloc2
+ >
+ void load_libsvm_formatted_data (
+ const std::string& file_name,
+ std::vector<sample_type, alloc1>& samples,
+ std::vector<label_type, alloc2>& labels
+ );
+ /*!
+ requires
+ - sample_type must be an STL container
+ - sample_type::value_type == std::pair<T,U> where T is some kind of
+ unsigned integral type
+ ensures
+ - attempts to read a file of the given name that should contain libsvm
+ formatted data. We turn the data into sparse vectors and store it
+ in samples
+ - #labels.size() == #samples.size()
+ - for all valid i: #labels[i] is the label for #samples[i]
+ throws
+ - sample_data_io_error
+ This exception is thrown if there is any problem loading data from file
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename label_type,
+ typename alloc1,
+ typename alloc2
+ >
+ void save_libsvm_formatted_data (
+ const std::string& file_name,
+ const std::vector<sample_type, alloc1>& samples,
+ const std::vector<label_type, alloc2>& labels
+ );
+ /*!
+ requires
+ - sample_type must be an STL container
+ - sample_type::value_type == std::pair<T,U> where T is some kind of
+ unsigned integral type
+ - samples.size() == labels.size()
+ ensures
+ - saves the data to the given file in libsvm format
+ throws
+ - sample_data_io_error
+ This exception is thrown if there is any problem saving data to file
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename label_type,
+ typename alloc1,
+ typename alloc2
+ >
+ void save_libsvm_formatted_data (
+ const std::string& file_name,
+ const std::vector<sample_type, alloc1>& samples,
+ const std::vector<label_type, alloc2>& labels
+ );
+ /*!
+ requires
+ - sample_type == a dense matrix (i.e. dlib::matrix)
+ - for all valid i: is_vector(samples[i]) == true
+ - samples.size() == labels.size()
+ ensures
+ - saves the data to the given file in libsvm format
+ throws
+ - sample_data_io_error
+ This exception is thrown if there is any problem saving data to file
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type, typename alloc>
+ void fix_nonzero_indexing (
+ std::vector<sample_type,alloc>& samples
+ );
+ /*!
+ requires
+ - samples must only contain valid sparse vectors. The definition of
+ a sparse vector can be found at the top of dlib/svm/sparse_vector_abstract.h
+ ensures
+ - Adjusts the sparse vectors in samples so that they are zero-indexed.
+ Or in other words, assume the smallest used index value in any of the sparse
+ vectors is N. Then this function subtracts N from all the index values in
+ samples. This is useful, for example, if you load a libsvm formatted datafile
+ with features indexed from 1 rather than 0 and you would like to fix this.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LIBSVM_iO_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/data_io/load_image_dataset.h b/ml/dlib/dlib/data_io/load_image_dataset.h
new file mode 100644
index 000000000..5664d96b2
--- /dev/null
+++ b/ml/dlib/dlib/data_io/load_image_dataset.h
@@ -0,0 +1,510 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOAD_IMAGE_DaTASET_Hh_
+#define DLIB_LOAD_IMAGE_DaTASET_Hh_
+
+#include "load_image_dataset_abstract.h"
+#include "../misc_api.h"
+#include "../dir_nav.h"
+#include "../image_io.h"
+#include "../array.h"
+#include <vector>
+#include "../geometry.h"
+#include "image_dataset_metadata.h"
+#include <string>
+#include <set>
+#include "../image_processing/full_object_detection.h"
+#include <utility>
+#include <limits>
+#include "../image_transforms/image_pyramid.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_dataset_file
+ {
+ public:
+ image_dataset_file(const std::string& filename)
+ {
+ _skip_empty_images = false;
+ _have_parts = false;
+ _filename = filename;
+ _box_area_thresh = std::numeric_limits<double>::infinity();
+ }
+
+ image_dataset_file boxes_match_label(
+ const std::string& label
+ ) const
+ {
+ image_dataset_file temp(*this);
+ temp._labels.insert(label);
+ return temp;
+ }
+
+ image_dataset_file skip_empty_images(
+ ) const
+ {
+ image_dataset_file temp(*this);
+ temp._skip_empty_images = true;
+ return temp;
+ }
+
+ image_dataset_file boxes_have_parts(
+ ) const
+ {
+ image_dataset_file temp(*this);
+ temp._have_parts = true;
+ return temp;
+ }
+
+ image_dataset_file shrink_big_images(
+ double new_box_area_thresh = 150*150
+ ) const
+ {
+ image_dataset_file temp(*this);
+ temp._box_area_thresh = new_box_area_thresh;
+ return temp;
+ }
+
+ bool should_load_box (
+ const image_dataset_metadata::box& box
+ ) const
+ {
+ if (_have_parts && box.parts.size() == 0)
+ return false;
+ if (_labels.size() == 0)
+ return true;
+ if (_labels.count(box.label) != 0)
+ return true;
+ return false;
+ }
+
+ const std::string& get_filename() const { return _filename; }
+ bool should_skip_empty_images() const { return _skip_empty_images; }
+ bool should_boxes_have_parts() const { return _have_parts; }
+ double box_area_thresh() const { return _box_area_thresh; }
+ const std::set<std::string>& get_selected_box_labels() const { return _labels; }
+
+ private:
+ std::string _filename;
+ std::set<std::string> _labels;
+ bool _skip_empty_images;
+ bool _have_parts;
+ double _box_area_thresh;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations,
+ const image_dataset_file& source
+ )
+ {
+ images.clear();
+ object_locations.clear();
+
+ std::vector<std::vector<rectangle> > ignored_rects;
+
+ using namespace dlib::image_dataset_metadata;
+ dataset data;
+ load_image_dataset_metadata(data, source.get_filename());
+
+ // Set the current directory to be the one that contains the
+ // metadata file. We do this because the file might contain
+ // file paths which are relative to this folder.
+ locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
+
+
+ typedef typename array_type::value_type image_type;
+
+
+ image_type img;
+ std::vector<rectangle> rects, ignored;
+ for (unsigned long i = 0; i < data.images.size(); ++i)
+ {
+ double min_rect_size = std::numeric_limits<double>::infinity();
+ rects.clear();
+ ignored.clear();
+ for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
+ {
+ if (source.should_load_box(data.images[i].boxes[j]))
+ {
+ if (data.images[i].boxes[j].ignore)
+ {
+ ignored.push_back(data.images[i].boxes[j].rect);
+ }
+ else
+ {
+ rects.push_back(data.images[i].boxes[j].rect);
+ min_rect_size = std::min<double>(min_rect_size, rects.back().area());
+ }
+ }
+ }
+
+ if (!source.should_skip_empty_images() || rects.size() != 0)
+ {
+ load_image(img, data.images[i].filename);
+ if (rects.size() != 0)
+ {
+ // if shrinking the image would still result in the smallest box being
+ // bigger than the box area threshold then shrink the image.
+ while(min_rect_size/2/2 > source.box_area_thresh())
+ {
+ pyramid_down<2> pyr;
+ pyr(img);
+ min_rect_size *= (1.0/2.0)*(1.0/2.0);
+ for (auto&& r : rects)
+ r = pyr.rect_down(r);
+ for (auto&& r : ignored)
+ r = pyr.rect_down(r);
+ }
+ while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
+ {
+ pyramid_down<3> pyr;
+ pyr(img);
+ min_rect_size *= (2.0/3.0)*(2.0/3.0);
+ for (auto&& r : rects)
+ r = pyr.rect_down(r);
+ for (auto&& r : ignored)
+ r = pyr.rect_down(r);
+ }
+ }
+ images.push_back(img);
+ object_locations.push_back(rects);
+ ignored_rects.push_back(ignored);
+ }
+ }
+
+ return ignored_rects;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline size_t num_non_ignored_boxes (const std::vector<mmod_rect>& rects)
+ {
+ size_t cnt = 0;
+ for (auto& b : rects)
+ {
+ if (!b.ignore)
+ cnt++;
+ }
+ return cnt;
+ }
+ }
+
+ template <
+ typename array_type
+ >
+ void load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<mmod_rect> >& object_locations,
+ const image_dataset_file& source
+ )
+ {
+ images.clear();
+ object_locations.clear();
+
+ using namespace dlib::image_dataset_metadata;
+ dataset data;
+ load_image_dataset_metadata(data, source.get_filename());
+
+ // Set the current directory to be the one that contains the
+ // metadata file. We do this because the file might contain
+ // file paths which are relative to this folder.
+ locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
+
+ typedef typename array_type::value_type image_type;
+
+ image_type img;
+ std::vector<mmod_rect> rects;
+ for (unsigned long i = 0; i < data.images.size(); ++i)
+ {
+ double min_rect_size = std::numeric_limits<double>::infinity();
+ rects.clear();
+ for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
+ {
+ if (source.should_load_box(data.images[i].boxes[j]))
+ {
+ if (data.images[i].boxes[j].ignore)
+ {
+ rects.push_back(ignored_mmod_rect(data.images[i].boxes[j].rect));
+ }
+ else
+ {
+ rects.push_back(mmod_rect(data.images[i].boxes[j].rect));
+ min_rect_size = std::min<double>(min_rect_size, rects.back().rect.area());
+ }
+ rects.back().label = data.images[i].boxes[j].label;
+
+ }
+ }
+
+ if (!source.should_skip_empty_images() || impl::num_non_ignored_boxes(rects) != 0)
+ {
+ load_image(img, data.images[i].filename);
+ if (rects.size() != 0)
+ {
+ // if shrinking the image would still result in the smallest box being
+ // bigger than the box area threshold then shrink the image.
+ while(min_rect_size/2/2 > source.box_area_thresh())
+ {
+ pyramid_down<2> pyr;
+ pyr(img);
+ min_rect_size *= (1.0/2.0)*(1.0/2.0);
+ for (auto&& r : rects)
+ r.rect = pyr.rect_down(r.rect);
+ }
+ while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
+ {
+ pyramid_down<3> pyr;
+ pyr(img);
+ min_rect_size *= (2.0/3.0)*(2.0/3.0);
+ for (auto&& r : rects)
+ r.rect = pyr.rect_down(r.rect);
+ }
+ }
+ images.push_back(std::move(img));
+ object_locations.push_back(std::move(rects));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// ******* THIS FUNCTION IS DEPRECATED, you should use another version of load_image_dataset() *******
+ template <
+ typename image_type,
+ typename MM
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array<image_type,MM>& images,
+ std::vector<std::vector<rectangle> >& object_locations,
+ const std::string& filename,
+ const std::string& label,
+ bool skip_empty_images = false
+ )
+ {
+ image_dataset_file f(filename);
+ if (label.size() != 0)
+ f = f.boxes_match_label(label);
+ if (skip_empty_images)
+ f = f.skip_empty_images();
+ return load_image_dataset(images, object_locations, f);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations,
+ const std::string& filename
+ )
+ {
+ return load_image_dataset(images, object_locations, image_dataset_file(filename));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ void load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<mmod_rect>>& object_locations,
+ const std::string& filename
+ )
+ {
+ load_image_dataset(images, object_locations, image_dataset_file(filename));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const image_dataset_file& source,
+ std::vector<std::string>& parts_list
+ )
+ {
+ typedef typename array_type::value_type image_type;
+ parts_list.clear();
+ images.clear();
+ object_locations.clear();
+
+ using namespace dlib::image_dataset_metadata;
+ dataset data;
+ load_image_dataset_metadata(data, source.get_filename());
+
+ // Set the current directory to be the one that contains the
+ // metadata file. We do this because the file might contain
+ // file paths which are relative to this folder.
+ locally_change_current_dir chdir(get_parent_directory(file(source.get_filename())));
+
+
+ std::set<std::string> all_parts;
+
+ // find out what parts are being used in the dataset. Store results in all_parts.
+ for (unsigned long i = 0; i < data.images.size(); ++i)
+ {
+ for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
+ {
+ if (source.should_load_box(data.images[i].boxes[j]))
+ {
+ const std::map<std::string,point>& parts = data.images[i].boxes[j].parts;
+ std::map<std::string,point>::const_iterator itr;
+
+ for (itr = parts.begin(); itr != parts.end(); ++itr)
+ {
+ all_parts.insert(itr->first);
+ }
+ }
+ }
+ }
+
+ // make a mapping between part names and the integers [0, all_parts.size())
+ std::map<std::string,int> parts_idx;
+ for (std::set<std::string>::iterator i = all_parts.begin(); i != all_parts.end(); ++i)
+ {
+ parts_idx[*i] = parts_list.size();
+ parts_list.push_back(*i);
+ }
+
+ std::vector<std::vector<rectangle> > ignored_rects;
+ std::vector<rectangle> ignored;
+ image_type img;
+ std::vector<full_object_detection> object_dets;
+ for (unsigned long i = 0; i < data.images.size(); ++i)
+ {
+ double min_rect_size = std::numeric_limits<double>::infinity();
+ object_dets.clear();
+ ignored.clear();
+ for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j)
+ {
+ if (source.should_load_box(data.images[i].boxes[j]))
+ {
+ if (data.images[i].boxes[j].ignore)
+ {
+ ignored.push_back(data.images[i].boxes[j].rect);
+ }
+ else
+ {
+ std::vector<point> partlist(parts_idx.size(), OBJECT_PART_NOT_PRESENT);
+
+ // populate partlist with all the parts present in this box.
+ const std::map<std::string,point>& parts = data.images[i].boxes[j].parts;
+ std::map<std::string,point>::const_iterator itr;
+ for (itr = parts.begin(); itr != parts.end(); ++itr)
+ {
+ partlist[parts_idx[itr->first]] = itr->second;
+ }
+
+ object_dets.push_back(full_object_detection(data.images[i].boxes[j].rect, partlist));
+ min_rect_size = std::min<double>(min_rect_size, object_dets.back().get_rect().area());
+ }
+ }
+ }
+
+ if (!source.should_skip_empty_images() || object_dets.size() != 0)
+ {
+ load_image(img, data.images[i].filename);
+ if (object_dets.size() != 0)
+ {
+ // if shrinking the image would still result in the smallest box being
+ // bigger than the box area threshold then shrink the image.
+ while(min_rect_size/2/2 > source.box_area_thresh())
+ {
+ pyramid_down<2> pyr;
+ pyr(img);
+ min_rect_size *= (1.0/2.0)*(1.0/2.0);
+ for (auto&& r : object_dets)
+ {
+ r.get_rect() = pyr.rect_down(r.get_rect());
+ for (unsigned long k = 0; k < r.num_parts(); ++k)
+ r.part(k) = pyr.point_down(r.part(k));
+ }
+ for (auto&& r : ignored)
+ {
+ r = pyr.rect_down(r);
+ }
+ }
+ while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh())
+ {
+ pyramid_down<3> pyr;
+ pyr(img);
+ min_rect_size *= (2.0/3.0)*(2.0/3.0);
+ for (auto&& r : object_dets)
+ {
+ r.get_rect() = pyr.rect_down(r.get_rect());
+ for (unsigned long k = 0; k < r.num_parts(); ++k)
+ r.part(k) = pyr.point_down(r.part(k));
+ }
+ for (auto&& r : ignored)
+ {
+ r = pyr.rect_down(r);
+ }
+ }
+ }
+ images.push_back(img);
+ object_locations.push_back(object_dets);
+ ignored_rects.push_back(ignored);
+ }
+ }
+
+
+ return ignored_rects;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const image_dataset_file& source
+ )
+ {
+ std::vector<std::string> parts_list;
+ return load_image_dataset(images, object_locations, source, parts_list);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const std::string& filename
+ )
+ {
+ std::vector<std::string> parts_list;
+ return load_image_dataset(images, object_locations, image_dataset_file(filename), parts_list);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LOAD_IMAGE_DaTASET_Hh_
+
diff --git a/ml/dlib/dlib/data_io/load_image_dataset_abstract.h b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h
new file mode 100644
index 000000000..b06252098
--- /dev/null
+++ b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h
@@ -0,0 +1,358 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_
+#ifdef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_
+
+#include "image_dataset_metadata.h"
+#include "../array/array_kernel_abstract.h"
+#include <string>
+#include <vector>
+#include "../image_processing/full_object_detection_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_dataset_file
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool used to tell the load_image_dataset() functions which
+ boxes and images to load from an XML based image dataset file. By default,
+ this object tells load_image_dataset() to load all images and object boxes.
+ !*/
+
+ public:
+ image_dataset_file(
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - #get_filename() == filename
+ - #should_skip_empty_images() == false
+ - #get_selected_box_labels().size() == 0
+ This means that, initially, all boxes will be loaded. Therefore, for all
+ possible boxes B we have:
+ - #should_load_box(B) == true
+ - #box_area_thresh() == infinity
+ !*/
+
+ const std::string& get_filename(
+ ) const;
+ /*!
+ ensures
+ - returns the name of the XML image dataset metadata file given to this
+ object's constructor.
+ !*/
+
+ bool should_skip_empty_images(
+ ) const;
+ /*!
+ ensures
+ - returns true if we are supposed to skip images that don't have any
+ non-ignored boxes to load when loading an image dataset using
+ load_image_dataset().
+ !*/
+
+ image_dataset_file boxes_match_label(
+ const std::string& label
+ ) const;
+ /*!
+ ensures
+ - returns a copy of *this that is identical in all respects to *this except
+ that label will be included in the labels set (i.e. the set returned by
+ get_selected_box_labels()).
+ !*/
+
+ const std::set<std::string>& get_selected_box_labels(
+ ) const;
+ /*!
+ ensures
+ - returns the set of box labels currently selected by the should_load_box()
+ method. Note that if the set is empty then we select all boxes.
+ !*/
+
+ image_dataset_file skip_empty_images(
+ ) const;
+ /*!
+ ensures
+ - returns a copy of *this that is identical in all respects to *this except
+ that #should_skip_empty_images() == true.
+ !*/
+
+ bool should_boxes_have_parts(
+ ) const;
+ /*!
+ ensures
+ - returns true if boxes must have some parts defined for them to be loaded.
+ !*/
+
+ image_dataset_file boxes_have_parts(
+ ) const;
+ /*!
+ ensures
+ - returns a copy of *this that is identical in all respects to *this except
+ that #should_boxes_have_parts() == true.
+ !*/
+
+ bool should_load_box (
+ const image_dataset_metadata::box& box
+ ) const;
+ /*!
+ ensures
+ - returns true if we are supposed to load the given box from an image
+ dataset XML file. In particular, if should_load_box() returns false then
+ the load_image_dataset() routines will not return the box at all, neither
+ in the ignore rectangles list or in the primary object_locations vector.
+ The behavior of this function is defined as follows:
+ - if (should_boxes_have_parts() && boxes.parts.size() == 0) then
+ - returns false
+ - else if (get_selected_box_labels().size() == 0) then
+ - returns true
+ - else if (get_selected_box_labels().count(box.label) != 0) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ image_dataset_file shrink_big_images(
+ double new_box_area_thresh = 150*150
+ ) const;
+ /*!
+ ensures
+ - returns a copy of *this that is identical in all respects to *this except
+ that #box_area_thresh() == new_box_area_thresh
+ !*/
+
+ double box_area_thresh(
+ ) const;
+ /*!
+ ensures
+ - If the smallest non-ignored rectangle in an image has an area greater
+ than box_area_thresh() then we will shrink the image until the area of
+ the box is about equal to box_area_thresh(). This is useful if you have
+ a dataset containing very high resolution images and you don't want to
+ load it in its native high resolution. Setting the box_area_thresh()
+ allows you to control the resolution of the loaded images.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations,
+ const image_dataset_file& source
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - This routine loads the images and their associated object boxes from the
+ image metadata file indicated by source.get_filename(). This metadata file
+ should be in the XML format used by the save_image_dataset_metadata() routine.
+ - #images.size() == The number of images loaded from the metadata file. This
+ is all the images listed in the file unless source.should_skip_empty_images()
+ is set to true.
+ - #images.size() == #object_locations.size()
+ - This routine is capable of loading any image format which can be read by the
+ load_image() routine.
+ - let IGNORED_RECTS denote the vector returned from this function.
+ - IGNORED_RECTS.size() == #object_locations.size()
+ - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to
+ true in the input XML file.
+ - for all valid i:
+ - #images[i] == a copy of the i-th image from the dataset.
+ - #object_locations[i] == a vector of all the rectangles associated with
+ #images[i]. These are the rectangles for which source.should_load_box()
+ returns true and are also not marked as "ignore" in the XML file.
+ - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i]
+ that are marked as "ignore" but not discarded by source.should_load_box().
+ - if (source.should_skip_empty_images() == true) then
+ - #object_locations[i].size() != 0
+ (i.e. we won't load images that don't end up having any object locations)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations,
+ const std::string& filename
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename));
+ (i.e. it ignores box labels and therefore loads all the boxes in the dataset)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ void load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<mmod_rect> >& object_locations,
+ const image_dataset_file& source
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - This function has essentially the same behavior as the above
+ load_image_dataset() routines, except here we output to a vector of
+ mmod_rects instead of rectangles. In this case, both ignore and non-ignore
+ rectangles go into object_locations since mmod_rect has an ignore boolean
+ field that records the ignored/non-ignored state of each rectangle. We also store
+ a each box's string label into the mmod_rect::label field as well.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ void load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<mmod_rect> >& object_locations,
+ const std::string& filename
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - performs: load_image_dataset(images, object_locations, image_dataset_file(filename));
+ (i.e. it ignores box labels and therefore loads all the boxes in the dataset)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const image_dataset_file& source,
+ std::vector<std::string>& parts_list
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - This routine loads the images and their associated object locations from the
+ image metadata file indicated by source.get_filename(). This metadata file
+ should be in the XML format used by the save_image_dataset_metadata() routine.
+ - The difference between this function and the version of load_image_dataset()
+ defined above is that this version will also load object part information and
+ thus fully populates the full_object_detection objects.
+ - #images.size() == The number of images loaded from the metadata file. This
+ is all the images listed in the file unless source.should_skip_empty_images()
+ is set to true.
+ - #images.size() == #object_locations.size()
+ - This routine is capable of loading any image format which can be read
+ by the load_image() routine.
+ - #parts_list == a vector that contains the list of object parts found in the
+ input file and loaded into object_locations.
+ - #parts_list is in lexicographic sorted order.
+ - let IGNORED_RECTS denote the vector returned from this function.
+ - IGNORED_RECTS.size() == #object_locations.size()
+ - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to
+ true in the input XML file.
+ - for all valid i:
+ - #images[i] == a copy of the i-th image from the dataset.
+ - #object_locations[i] == a vector of all the rectangles associated with
+ #images[i]. These are the rectangles for which source.should_load_box()
+ returns true and are also not marked as "ignore" in the XML file.
+ - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i]
+ that are marked as "ignore" but not discarded by source.should_load_box().
+ - if (source.should_skip_empty_images() == true) then
+ - #object_locations[i].size() != 0
+ (i.e. we won't load images that don't end up having any object locations)
+ - for all valid j:
+ - #object_locations[i][j].num_parts() == #parts_list.size()
+ - for all valid k:
+ - #object_locations[i][j].part(k) == the location of the part
+ with name #parts_list[k] or OBJECT_PART_NOT_PRESENT if the
+ part was not indicated for object #object_locations[i][j].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const image_dataset_file& source
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - performs: return load_image_dataset(images, object_locations, source, parts_list);
+ (i.e. this function simply calls the above function and discards the output
+ parts_list. So it is just a convenience function you can call if you don't
+ care about getting the parts list.)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ std::vector<std::vector<rectangle> > load_image_dataset (
+ array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations,
+ const std::string& filename
+ );
+ /*!
+ requires
+ - array_type == An array of images. This is anything with an interface that
+ looks like std::vector<some generic image type> where a "generic image" is
+ anything that implements the generic image interface defined in
+ dlib/image_processing/generic_image.h.
+ ensures
+ - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename));
+ (i.e. it ignores box labels and therefore loads all the boxes in the dataset)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/data_io/mnist.cpp b/ml/dlib/dlib/data_io/mnist.cpp
new file mode 100644
index 000000000..d6a62fb67
--- /dev/null
+++ b/ml/dlib/dlib/data_io/mnist.cpp
@@ -0,0 +1,133 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MNIST_CPp_
+#define DLIB_MNIST_CPp_
+
+#include "mnist.h"
+#include <fstream>
+#include "../byte_orderer.h"
+#include "../uintn.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ void load_mnist_dataset (
+ const std::string& folder_name,
+ std::vector<matrix<unsigned char> >& training_images,
+ std::vector<unsigned long>& training_labels,
+ std::vector<matrix<unsigned char> >& testing_images,
+ std::vector<unsigned long>& testing_labels
+ )
+ {
+ using namespace std;
+ ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary);
+ if (!fin1)
+ {
+ fin1.open((folder_name + "/train-images.idx3-ubyte").c_str(), ios::binary);
+ }
+
+ ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary);
+ if (!fin2)
+ {
+ fin2.open((folder_name + "/train-labels.idx1-ubyte").c_str(), ios::binary);
+ }
+
+ ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary);
+ if (!fin3)
+ {
+ fin3.open((folder_name + "/t10k-images.idx3-ubyte").c_str(), ios::binary);
+ }
+
+ ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary);
+ if (!fin4)
+ {
+ fin4.open((folder_name + "/t10k-labels.idx1-ubyte").c_str(), ios::binary);
+ }
+
+ if (!fin1) throw error("Unable to open file train-images-idx3-ubyte or train-images.idx3-ubyte");
+ if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte or train-labels.idx1-ubyte");
+ if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte or t10k-images.idx3-ubyte");
+ if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte or t10k-labels.idx1-ubyte");
+
+ byte_orderer bo;
+
+ // make sure the files have the contents we expect.
+ uint32 magic, num, nr, nc, num2, num3, num4;
+ fin1.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
+ fin1.read((char*)&num, sizeof(num)); bo.big_to_host(num);
+ fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
+ fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
+ if (magic != 2051 || num != 60000 || nr != 28 || nc != 28)
+ throw error("mndist dat files are corrupted.");
+
+ fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
+ fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2);
+ if (magic != 2049 || num2 != 60000)
+ throw error("mndist dat files are corrupted.");
+
+ fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
+ fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3);
+ fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
+ fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
+ if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28)
+ throw error("mndist dat files are corrupted.");
+
+ fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
+ fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4);
+ if (magic != 2049 || num4 != 10000)
+ throw error("mndist dat files are corrupted.");
+
+ if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
+ if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
+ if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
+ if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");
+
+
+ training_images.resize(60000);
+ training_labels.resize(60000);
+ testing_images.resize(10000);
+ testing_labels.resize(10000);
+
+ for (size_t i = 0; i < training_images.size(); ++i)
+ {
+ training_images[i].set_size(nr,nc);
+ fin1.read((char*)&training_images[i](0,0), nr*nc);
+ }
+ for (size_t i = 0; i < training_labels.size(); ++i)
+ {
+ char l;
+ fin2.read(&l, 1);
+ training_labels[i] = l;
+ }
+
+ for (size_t i = 0; i < testing_images.size(); ++i)
+ {
+ testing_images[i].set_size(nr,nc);
+ fin3.read((char*)&testing_images[i](0,0), nr*nc);
+ }
+ for (size_t i = 0; i < testing_labels.size(); ++i)
+ {
+ char l;
+ fin4.read(&l, 1);
+ testing_labels[i] = l;
+ }
+
+ if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
+ if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
+ if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
+ if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");
+
+ if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte");
+ if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte");
+ if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte");
+ if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte");
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_MNIST_CPp_
+
+
+
diff --git a/ml/dlib/dlib/data_io/mnist.h b/ml/dlib/dlib/data_io/mnist.h
new file mode 100644
index 000000000..e71be6f2b
--- /dev/null
+++ b/ml/dlib/dlib/data_io/mnist.h
@@ -0,0 +1,32 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MNIST_Hh_
+#define DLIB_MNIST_Hh_
+
+#include "mnist_abstract.h"
+#include <string>
+#include <vector>
+#include "../matrix.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ void load_mnist_dataset (
+ const std::string& folder_name,
+ std::vector<matrix<unsigned char> >& training_images,
+ std::vector<unsigned long>& training_labels,
+ std::vector<matrix<unsigned char> >& testing_images,
+ std::vector<unsigned long>& testing_labels
+ );
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "mnist.cpp"
+#endif
+
+#endif // DLIB_MNIST_Hh_
+
+
diff --git a/ml/dlib/dlib/data_io/mnist_abstract.h b/ml/dlib/dlib/data_io/mnist_abstract.h
new file mode 100644
index 000000000..09121633e
--- /dev/null
+++ b/ml/dlib/dlib/data_io/mnist_abstract.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MNIST_ABSTRACT_Hh_
+#ifdef DLIB_MNIST_ABSTRACT_Hh_
+
+#include <string>
+#include <vector>
+#include "../matrix.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ void load_mnist_dataset (
+ const std::string& folder_name,
+ std::vector<matrix<unsigned char> >& training_images,
+ std::vector<unsigned long>& training_labels,
+ std::vector<matrix<unsigned char> >& testing_images,
+ std::vector<unsigned long>& testing_labels
+ );
+ /*!
+ ensures
+ - Attempts to load the MNIST dataset from the hard drive. This is the dataset
+ of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In
+ particular, the 4 files comprising the MNIST dataset should be present in the
+ folder indicated by folder_name. These four files are:
+ - train-images-idx3-ubyte
+ - train-labels-idx1-ubyte
+ - t10k-images-idx3-ubyte
+ - t10k-labels-idx1-ubyte
+ - #training_images == The 60,000 training images from the dataset.
+ - #training_labels == The labels for the contents of #training_images.
+ I.e. #training_labels[i] is the label of #training_images[i].
+ - #testing_images == The 10,000 testing images from the dataset.
+ - #testing_labels == The labels for the contents of #testing_images.
+ I.e. #testing_labels[i] is the label of #testing_images[i].
+ throws
+ - dlib::error if some problem prevents us from loading the data or the files
+ can't be found.
+ !*/
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_MNIST_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/dir_nav.h b/ml/dlib/dlib/dir_nav.h
new file mode 100644
index 000000000..c5956615d
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav.h
@@ -0,0 +1,21 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAv_
+#define DLIB_DIR_NAv_
+
+
+#include "platform.h"
+
+
+#ifdef WIN32
+#include "dir_nav/windows.h"
+#endif
+
+#ifndef WIN32
+#include "dir_nav/posix.h"
+#endif
+
+#include "dir_nav/dir_nav_extensions.h"
+
+#endif // DLIB_DIR_NAv_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp b/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp
new file mode 100644
index 000000000..db05e4cc4
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp
@@ -0,0 +1,121 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_EXTENSIONs_CPP_
+#define DLIB_DIR_NAV_EXTENSIONs_CPP_
+
+#include "dir_nav_extensions.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace implementation_details
+ {
+ void get_all_sub_dirs (
+ const directory& top_of_tree,
+ unsigned long max_depth,
+ std::vector<directory>& result,
+ std::vector<directory>& temp
+ )
+ {
+ if (max_depth > 0)
+ {
+ top_of_tree.get_dirs(temp);
+ const unsigned long start = result.size();
+ result.insert(result.end(), temp.begin(), temp.end());
+ const unsigned long end = start + temp.size();
+
+ for (unsigned long i = start; i < end; ++i)
+ {
+ get_all_sub_dirs(result[i], max_depth-1, result, temp);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool file_exists (
+ const std::string& filename
+ )
+ {
+ try
+ {
+ dlib::file temp(filename);
+ return true;
+ }
+ catch (file::file_not_found&)
+ {
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const directory& dir
+ )
+ {
+ return dir.get_parent();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const file& f
+ )
+ {
+ if (f.full_name().size() == 0)
+ return directory();
+
+ std::string::size_type pos = f.full_name().find_last_of("\\/");
+
+ if (pos == std::string::npos)
+ return directory();
+
+ return directory(f.full_name().substr(0,pos));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_oldest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ )
+ {
+ file f1, f2;
+ try{f1 = file(filename1);} catch(file::file_not_found&) { return filename1; }
+ try{f2 = file(filename2);} catch(file::file_not_found&) { return filename2; }
+
+ if (f1.last_modified() < f2.last_modified())
+ return filename1;
+ else
+ return filename2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_newest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ )
+ {
+ file f1, f2;
+ try{f1 = file(filename1);} catch(file::file_not_found&) { return filename2; }
+ try{f2 = file(filename2);} catch(file::file_not_found&) { return filename1; }
+
+ if (f1.last_modified() > f2.last_modified())
+ return filename1;
+ else
+ return filename2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DIR_NAV_EXTENSIONs_CPP_
+
+
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions.h b/ml/dlib/dlib/dir_nav/dir_nav_extensions.h
new file mode 100644
index 000000000..93dde1159
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions.h
@@ -0,0 +1,172 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_EXTENSIONs_H_
+#define DLIB_DIR_NAV_EXTENSIONs_H_
+
+#include <string>
+#include <vector>
+#include <algorithm>
+#include "dir_nav_extensions_abstract.h"
+#include "../dir_nav.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ bool file_exists (
+ const std::string& filename
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ namespace implementation_details
+ {
+ void get_all_sub_dirs (
+ const directory& top_of_tree,
+ unsigned long max_depth,
+ std::vector<directory>& result,
+ std::vector<directory>& temp
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const std::vector<file> get_files_in_directory_tree (
+ const directory& top_of_tree,
+ const T& add_file,
+ unsigned long max_depth = 30
+ )
+ {
+ std::vector<file> result, temp;
+ std::vector<directory> dirs, dirs_temp;
+ dirs.push_back(top_of_tree);
+
+ // get all the directories in the tree first
+ implementation_details::get_all_sub_dirs(top_of_tree, max_depth, dirs, dirs_temp);
+
+ // now just loop over all the directories and pick out the files we want to keep
+ for (unsigned long d = 0; d < dirs.size(); ++d)
+ {
+ dirs[d].get_files(temp);
+
+ // pick out the members of temp that we should keep
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ if (add_file(temp[i]))
+ result.push_back(temp[i]);
+ }
+ }
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class match_ending
+ {
+
+ public:
+ match_ending (
+ const std::string& ending_
+ ) : ending(ending_) {}
+
+ bool operator() (
+ const file& f
+ ) const
+ {
+ // if the ending is bigger than f's name then it obviously doesn't match
+ if (ending.size() > f.name().size())
+ return false;
+
+ // now check if the actual characters that make up the end of the file name
+ // matches what is in ending.
+ return std::equal(ending.begin(), ending.end(), f.name().end()-ending.size());
+ }
+
+ private:
+ std::string ending;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class match_endings
+ {
+
+ public:
+ match_endings (
+ const std::string& endings_
+ )
+ {
+ const std::vector<std::string>& s = split(endings_);
+ for (unsigned long i = 0; i < s.size(); ++i)
+ {
+ endings.push_back(match_ending(s[i]));
+ }
+ }
+
+ bool operator() (
+ const file& f
+ ) const
+ {
+ for (unsigned long i = 0; i < endings.size(); ++i)
+ {
+ if (endings[i](f))
+ return true;
+ }
+
+ return false;
+ }
+
+ private:
+ std::vector<match_ending> endings;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class match_all
+ {
+ public:
+ bool operator() (
+ const file&
+ ) const { return true; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const directory& dir
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const file& f
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_oldest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_newest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "dir_nav_extensions.cpp"
+#endif
+
+#endif // DLIB_DIR_NAV_EXTENSIONs_H_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h b/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h
new file mode 100644
index 000000000..4aa6cc4f2
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h
@@ -0,0 +1,203 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_
+#ifdef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_
+
+#include <string>
+#include <vector>
+#include "dir_nav_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ bool file_exists (
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - if (a file with the given filename exists) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const std::vector<file> get_files_in_directory_tree (
+ const directory& top_of_tree,
+ const T& add_file,
+ unsigned long max_depth = 30
+ );
+ /*!
+ requires
+ - add_file must be a function object with the following prototype:
+ bool add_file (file f);
+ ensures
+ - performs a recursive search through the directory top_of_tree and all
+ its sub-directories (up to the given max depth). All files in these
+ directories are examined by passing them to add_file() and if it
+ returns true then they will be included in the returned std::vector<file>
+ object.
+ - Note that a max_depth of 0 means that only the files in the directory
+ top_of_tree will be considered. A depth of 1 means that only files in
+ top_of_tree and its immediate sub-directories will be considered. And
+ so on...
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class match_ending
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that can be used with the
+ above get_files_in_directory_tree() function. This object
+ just looks for files with a certain ending.
+ !*/
+
+ public:
+ match_ending (
+ const std::string& ending
+ );
+ /*!
+ ensures
+ - this object will be a function that checks if a file has a
+ name that ends with the given ending string.
+ !*/
+
+ bool operator() (
+ const file& f
+ ) const;
+ /*!
+ ensures
+ - if (the file f has a name that ends with the ending string given
+ to this object's constructor) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class match_endings
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that can be used with the
+ above get_files_in_directory_tree() function. This object
+ allows you to look for files with a number of different
+ endings.
+ !*/
+
+ public:
+ match_endings (
+ const std::string& ending_list
+ );
+ /*!
+ ensures
+ - ending_list is interpreted as a whitespace separated list
+ of file endings.
+ - this object will be a function that checks if a file has a
+ name that ends with one of the strings in ending_list.
+ !*/
+
+ bool operator() (
+ const file& f
+ ) const;
+ /*!
+ ensures
+ - if (the file f has a name that ends with one of the ending strings
+ given to this object's constructor) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class match_all
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that can be used with the
+ above get_files_in_directory_tree() function. This object
+ matches all files.
+ !*/
+
+ public:
+ bool operator() (
+ const file& f
+ ) const;
+ /*!
+ ensures
+ - returns true
+ (i.e. this function doesn't do anything. It just says it
+ matches all files no matter what)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const directory& dir
+ );
+ /*!
+ ensures
+ - returns the parent directory of dir. In particular, this
+ function returns the value of dir.get_parent()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ directory get_parent_directory (
+ const file& f
+ );
+ /*!
+ ensures
+ - if (f.full_name() != "") then
+ - returns the directory which contains the given file
+ - else
+ - returns a default initialized directory (i.e. directory())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_oldest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ );
+ /*!
+ ensures
+ - Checks the last modification times of the two given files and returns the
+ filename of the oldest file, i.e., the file that has gone longest since being
+ modified. Ties are broken arbitrarily.
+ - For the purpose of comparison, a file that doesn't exist is presumed to have
+ a last modification time of -infinity (i.e. very far in the past).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::string select_newest_file (
+ const std::string& filename1,
+ const std::string& filename2
+ );
+ /*!
+ ensures
+ - Checks the last modification times of the two given files and returns the
+ filename that was most recently modified. Ties are broken arbitrarily.
+ - For the purpose of comparison, a file that doesn't exist is presumed to have
+ a last modification time of -infinity (i.e. very far in the past).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp
new file mode 100644
index 000000000..9891d5dff
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp
@@ -0,0 +1,258 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEL_1_CPp_
+#define DLIB_DIR_NAV_KERNEL_1_CPp_
+#include "../platform.h"
+
+#ifdef WIN32
+
+#include "dir_nav_kernel_1.h"
+#include "../string.h"
+
+
+#ifdef __BORLANDC__
+// Apparently the borland compiler doesn't define this.
+#define INVALID_FILE_ATTRIBUTES ((DWORD)-1)
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // file object implementation
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void file::
+ init (
+ const std::string& name
+ )
+ {
+ using namespace std;
+
+
+ char buf[3000];
+ char* str;
+ if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0)
+ {
+ // the file was not found
+ throw file_not_found("Unable to find file " + name);
+ }
+ state.full_name = buf;
+
+
+ string::size_type pos = state.full_name.find_last_of(directory::get_separator());
+ if (pos == string::npos)
+ {
+ // no valid full path has no separator characters.
+ throw file_not_found("Unable to find file " + name);
+ }
+ state.name = state.full_name.substr(pos+1);
+
+
+ // now find the size of this file
+ WIN32_FIND_DATAA data;
+ HANDLE ffind = FindFirstFileA(state.full_name.c_str(), &data);
+ if (ffind == INVALID_HANDLE_VALUE ||
+ (data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0)
+ {
+ throw file_not_found("Unable to find file " + name);
+ }
+ else
+ {
+ uint64 temp = data.nFileSizeHigh;
+ temp <<= 32;
+ temp |= data.nFileSizeLow;
+ state.file_size = temp;
+ FindClose(ffind);
+
+ ULARGE_INTEGER ull;
+ ull.LowPart = data.ftLastWriteTime.dwLowDateTime;
+ ull.HighPart = data.ftLastWriteTime.dwHighDateTime;
+ std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000));
+ state.last_modified = std::chrono::time_point<std::chrono::system_clock>(std::chrono::duration_cast<std::chrono::system_clock::duration>(epoch));
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool file::
+ operator == (
+ const file& rhs
+ ) const
+ {
+ using namespace std;
+
+ if (state.full_name.size() != rhs.state.full_name.size())
+ return false;
+
+ // compare the strings but ignore the case because file names
+ // are not case sensitive on windows
+ return tolower(state.full_name) == tolower(rhs.state.full_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // directory object implementation
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void directory::
+ init (
+ const std::string& name
+ )
+ {
+ using namespace std;
+
+
+ char buf[3000];
+ char* str;
+ if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0)
+ {
+ // the directory was not found
+ throw dir_not_found("Unable to find directory " + name);
+ }
+ state.full_name = buf;
+
+
+ const char sep = get_separator();
+ if (is_root_path(state.full_name) == false)
+ {
+ // ensure that thre is not a trialing separator
+ if (state.full_name[state.full_name.size()-1] == sep)
+ state.full_name.erase(state.full_name.size()-1);
+
+ // pick out the directory name
+ string::size_type pos = state.full_name.find_last_of(sep);
+ state.name = state.full_name.substr(pos+1);
+ }
+ else
+ {
+ // ensure that there is a trailing separator
+ if (state.full_name[state.full_name.size()-1] != sep)
+ state.full_name += sep;
+ }
+
+
+ // now check that this is actually a valid directory
+ DWORD attribs = GetFileAttributesA(state.full_name.c_str());
+ if (attribs == INVALID_FILE_ATTRIBUTES ||
+ (attribs&FILE_ATTRIBUTE_DIRECTORY) == 0)
+ {
+ // the directory was not found
+ throw dir_not_found("Unable to find directory " + name);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ char directory::
+ get_separator (
+ )
+ {
+ return '\\';
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool directory::
+ operator == (
+ const directory& rhs
+ ) const
+ {
+ using namespace std;
+
+ if (state.full_name.size() != rhs.state.full_name.size())
+ return false;
+
+ // compare the strings but ignore the case because file names
+ // are not case sensitive on windows
+ return tolower(state.full_name) == tolower(rhs.state.full_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const directory directory::
+ get_parent (
+ ) const
+ {
+ using namespace std;
+ // if *this is the root then just return *this
+ if (is_root())
+ {
+ return *this;
+ }
+ else
+ {
+ directory temp;
+
+ const char sep = get_separator();
+
+ string::size_type pos = state.full_name.find_last_of(sep);
+ temp.state.full_name = state.full_name.substr(0,pos);
+
+ if ( is_root_path(temp.state.full_name))
+ {
+ temp.state.full_name += sep;
+ }
+ else
+ {
+ pos = temp.state.full_name.find_last_of(sep);
+ if (pos != string::npos)
+ {
+ temp.state.name = temp.state.full_name.substr(pos+1);
+ }
+ else
+ {
+ temp.state.full_name += sep;
+ }
+ }
+ return temp;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool directory::
+ is_root_path (
+ const std::string& path
+ ) const
+ {
+ using namespace std;
+ const char sep = get_separator();
+ bool root_path = false;
+ if (path.size() > 2 && path[0] == sep && path[1] == sep)
+ {
+ // in this case this is a windows share path
+ string::size_type pos = path.find_first_of(sep,2);
+ if (pos != string::npos)
+ {
+ pos = path.find_first_of(sep,pos+1);
+
+ if (pos == string::npos && path[path.size()-1] != sep)
+ root_path = true;
+ else if (pos == path.size()-1)
+ root_path = true;
+ }
+
+ }
+ else if ( (path.size() == 2 || path.size() == 3) && path[1] == ':')
+ {
+ // if this is a valid windows path then it must be a root path
+ root_path = true;
+ }
+
+ return root_path;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // WIN32
+
+#endif // DLIB_DIR_NAV_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h
new file mode 100644
index 000000000..a31f689de
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h
@@ -0,0 +1,634 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEl_1_
+#define DLIB_DIR_NAV_KERNEl_1_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#include "../platform.h"
+
+
+#include "dir_nav_kernel_abstract.h"
+#include <string>
+#include "../uintn.h"
+#include "../algs.h"
+
+#include "../windows_magic.h"
+#include <windows.h>
+#include <vector>
+#include "../stl_checked.h"
+#include "../enable_if.h"
+#include "../queue.h"
+#include <chrono>
+
+namespace dlib
+{
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // file object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class file
+ {
+ /*!
+ INITIAL VALUES
+ state.name == name()
+ state.full_name == full_name()
+ state.file_size == size()
+ state.last_modified == last_modified()
+
+ CONVENTION
+ state.name == name()
+ state.full_name == full_name()
+ state.file_size == size()
+ state.last_modified == last_modified()
+
+ !*/
+
+ friend class directory;
+
+ struct data
+ {
+ uint64 file_size;
+ std::string name;
+ std::string full_name;
+ std::chrono::time_point<std::chrono::system_clock> last_modified;
+ };
+
+
+ void init ( const std::string& name);
+
+ public:
+
+ struct private_constructor{};
+ inline file (
+ const std::string& name,
+ const std::string& full_name,
+ const uint64 file_size,
+ const std::chrono::time_point<std::chrono::system_clock>& last_modified,
+ private_constructor
+ )
+ {
+ state.file_size = file_size;
+ state.name = name;
+ state.full_name = full_name;
+ state.last_modified = last_modified;
+ }
+
+
+
+
+ class file_not_found : public error {
+ public: file_not_found(const std::string& s): error(s){}
+ };
+
+ inline file (
+ )
+ {
+ state.file_size = 0;
+ }
+
+ file (
+ const std::string& name
+ ) { init(name); }
+
+ file (
+ const char* name
+ ) { init(name); }
+
+ inline const std::string& name (
+ ) const { return state.name; }
+
+ inline const std::string& full_name (
+ ) const { return state.full_name; }
+
+ operator std::string (
+ ) const { return full_name(); }
+
+ inline uint64 size (
+ ) const { return state.file_size; }
+
+ inline std::chrono::time_point<std::chrono::system_clock> last_modified (
+ ) const { return state.last_modified; }
+
+ bool operator == (
+ const file& rhs
+ ) const;
+
+ bool operator != (
+ const file& rhs
+ ) const { return !(*this == rhs); }
+
+ inline bool operator < (
+ const file& item
+ ) const { return full_name() < item.full_name(); }
+
+ inline void swap (
+ file& item
+ )
+ {
+ exchange(state,item.state);
+ }
+
+ private:
+
+ data state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // directory object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class directory
+ {
+ /*!
+ INITIAL VALUES
+ state.name == name()
+ state.full_name == full_name()
+
+ CONVENTION
+ state.name == name()
+ state.full_name == full_name()
+ is_root() == state.name.size() == 0
+
+ !*/
+
+ void init (const std::string& name);
+
+ public:
+
+ struct data
+ {
+ std::string name;
+ std::string full_name;
+ };
+
+
+ /*
+ The reason we don't just make this constructor actually
+ private is because doing it this way avoids a bug that
+ sometimes occurs in visual studio 7.1. The bug has
+ something to do with templated friend functions
+ such as the get_filesystem_roots() function below if
+ it was declared as a friend template of this class.
+ */
+ struct private_constructor{};
+ inline directory (
+ const std::string& name,
+ const std::string& full_name,
+ private_constructor
+ )
+ {
+ state.name = name;
+ state.full_name = full_name;
+ }
+
+
+ class dir_not_found : public error {
+ public: dir_not_found(const std::string& s):error(s){}
+ };
+ class listing_error : public error {
+ public: listing_error(const std::string& s):error(s){}
+ };
+
+ inline directory (
+ )
+ {
+ }
+
+ directory (
+ const std::string& name
+ ) { init(name); }
+
+ directory (
+ const char* name
+ ) { init(name); }
+
+
+ static char get_separator (
+ );
+
+
+ template <
+ typename queue_of_files
+ >
+ void get_files (
+ queue_of_files& files
+ ) const;
+
+ template <
+ typename queue_of_dirs
+ >
+ void get_dirs (
+ queue_of_dirs& dirs
+ ) const;
+
+ std::vector<file> get_files (
+ ) const
+ {
+ std::vector<file> temp_vector;
+ get_files(temp_vector);
+ return temp_vector;
+ }
+
+ std::vector<directory> get_dirs (
+ ) const
+ {
+ std::vector<directory> temp_vector;
+ get_dirs(temp_vector);
+ return temp_vector;
+ }
+
+ const directory get_parent (
+ ) const;
+
+ inline bool is_root (
+ ) const { return state.name.size() == 0; }
+
+ inline const std::string& name (
+ ) const { return state.name; }
+
+ inline const std::string& full_name (
+ ) const { return state.full_name; }
+
+ operator std::string (
+ ) const { return full_name(); }
+
+ bool operator == (
+ const directory& rhs
+ ) const;
+
+ bool operator != (
+ const directory& rhs
+ ) const { return !(*this == rhs); }
+
+ inline bool operator < (
+ const directory& item
+ ) const { return full_name() < item.full_name(); }
+
+ inline void swap (
+ directory& item
+ )
+ {
+ exchange(state,item.state);
+ }
+
+ private:
+
+ // member data
+ data state;
+
+ bool is_root_path (
+ const std::string& path
+ ) const;
+ /*!
+ ensures
+ - returns true if path is a root path.
+ Note that this function considers root paths that don't
+ have a trailing separator to also be valid.
+ !*/
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const directory& item
+ ) { out << (std::string)item; return out; }
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const file& item
+ ) { out << (std::string)item; return out; }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dir
+ >
+ typename disable_if<is_std_vector<queue_of_dir>,void>::type get_filesystem_roots (
+ queue_of_dir& roots
+ )
+ {
+ roots.clear();
+ const DWORD mask = GetLogicalDrives();
+ DWORD bit = 1;
+ char buf[] = "A:\\";
+
+ do
+ {
+ if (mask & bit)
+ {
+ directory dir("",buf,directory::private_constructor());
+ roots.enqueue(dir);
+ }
+ bit <<= 1;
+ ++buf[0];
+ } while (buf[0] != 'Z');
+ }
+
+ template <
+ typename queue_of_dir
+ >
+ typename enable_if<is_std_vector<queue_of_dir>,void>::type get_filesystem_roots (
+ queue_of_dir& roots
+ )
+ {
+ roots.clear();
+ const DWORD mask = GetLogicalDrives();
+ DWORD bit = 1;
+ char buf[] = "A:\\";
+
+ do
+ {
+ if (mask & bit)
+ {
+ directory dir("",buf,directory::private_constructor());
+ roots.push_back(dir);
+ }
+ bit <<= 1;
+ ++buf[0];
+ } while (buf[0] != 'Z');
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ file& a,
+ file& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ directory& a,
+ directory& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // templated member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ typename disable_if<is_std_vector<queue_of_files>,void>::type
+ directory_helper_get_files (
+ const directory::data& state,
+ queue_of_files& files
+ )
+ {
+ using namespace std;
+ typedef directory::listing_error listing_error;
+ typedef file::private_constructor private_constructor;
+
+ files.clear();
+ if (state.full_name.size() == 0)
+ throw listing_error("This directory object currently doesn't represent any directory.");
+
+ HANDLE ffind = INVALID_HANDLE_VALUE;
+ try
+ {
+ WIN32_FIND_DATAA data;
+ string path = state.full_name;
+ // ensure that the path ends with a separator
+ if (path[path.size()-1] != directory::get_separator())
+ path += directory::get_separator();
+
+ ffind = FindFirstFileA((path+"*").c_str(), &data);
+ if (ffind == INVALID_HANDLE_VALUE)
+ {
+ throw listing_error("Unable to list the contents of " + state.full_name);
+ }
+
+
+ bool no_more_files = false;
+ do
+ {
+ if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) == 0)
+ {
+ uint64 file_size = data.nFileSizeHigh;
+ file_size <<= 32;
+ file_size |= data.nFileSizeLow;
+
+ ULARGE_INTEGER ull;
+ ull.LowPart = data.ftLastWriteTime.dwLowDateTime;
+ ull.HighPart = data.ftLastWriteTime.dwHighDateTime;
+ std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000));
+ auto last_modified = std::chrono::time_point<std::chrono::system_clock>(std::chrono::duration_cast<std::chrono::system_clock::duration>(epoch));
+
+ // this is a file so add it to the queue
+ file temp(data.cFileName,path+data.cFileName,file_size, last_modified, private_constructor());
+ files.enqueue(temp);
+ }
+
+ if (FindNextFileA(ffind,&data) == 0)
+ {
+ // an error occurred
+ if ( GetLastError() == ERROR_NO_MORE_FILES)
+ {
+ // there are no more files
+ no_more_files = true;
+ }
+ else
+ {
+ // there was an error
+ throw listing_error("Unable to list the contents of " + state.full_name);
+ }
+ }
+ } while (no_more_files == false);
+
+ FindClose(ffind);
+ ffind = INVALID_HANDLE_VALUE;
+ }
+ catch (...)
+ {
+ if (ffind != INVALID_HANDLE_VALUE)
+ FindClose(ffind);
+ files.clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ typename enable_if<is_std_vector<queue_of_files>,void>::type
+ directory_helper_get_files (
+ const directory::data& state,
+ queue_of_files& files
+ )
+ {
+ queue<file>::kernel_2a temp_files;
+ directory_helper_get_files(state,temp_files);
+
+ files.clear();
+
+ // copy the queue of files into the vector
+ temp_files.reset();
+ while (temp_files.move_next())
+ {
+ files.push_back(temp_files.element());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ void directory::
+ get_files (
+ queue_of_files& files
+ ) const
+ {
+ // the reason for this indirection here is because it avoids a bug in
+ // the mingw version of gcc
+ directory_helper_get_files(state,files);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dirs
+ >
+ typename disable_if<is_std_vector<queue_of_dirs>,void>::type
+ directory_helper_get_dirs (
+ const directory::data& state,
+ queue_of_dirs& dirs
+ )
+ {
+ using namespace std;
+ typedef directory::listing_error listing_error;
+ typedef directory::private_constructor private_constructor;
+
+ dirs.clear();
+ if (state.full_name.size() == 0)
+ throw listing_error("This directory object currently doesn't represent any directory.");
+
+ HANDLE dfind = INVALID_HANDLE_VALUE;
+ try
+ {
+ WIN32_FIND_DATAA data;
+ string path = state.full_name;
+ // ensure that the path ends with a separator
+ if (path[path.size()-1] != directory::get_separator())
+ path += directory::get_separator();
+
+ dfind = FindFirstFileA((path+"*").c_str(), &data);
+ if (dfind == INVALID_HANDLE_VALUE)
+ {
+ throw listing_error("Unable to list the contents of " + state.full_name);
+ }
+
+
+ bool no_more_files = false;
+ do
+ {
+ string tname(data.cFileName);
+ if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0 &&
+ tname != "." &&
+ tname != "..")
+ {
+ // this is a directory so add it to the queue
+ directory temp(tname,path+tname,private_constructor());
+ dirs.enqueue(temp);
+ }
+
+ if (FindNextFileA(dfind,&data) == 0)
+ {
+ // an error occurred
+ if ( GetLastError() == ERROR_NO_MORE_FILES)
+ {
+ // there are no more files
+ no_more_files = true;
+ }
+ else
+ {
+ // there was an error
+ throw listing_error("Unable to list the contents of " + state.full_name);
+ }
+ }
+ } while (no_more_files == false);
+
+ FindClose(dfind);
+ dfind = INVALID_HANDLE_VALUE;
+ }
+ catch (...)
+ {
+ if (dfind != INVALID_HANDLE_VALUE)
+ FindClose(dfind);
+ dirs.clear();
+ throw;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dirs
+ >
+ typename enable_if<is_std_vector<queue_of_dirs>,void>::type
+ directory_helper_get_dirs (
+ const directory::data& state,
+ queue_of_dirs& dirs
+ )
+ {
+ queue<directory>::kernel_2a temp_dirs;
+ directory_helper_get_dirs(state,temp_dirs);
+
+ dirs.clear();
+
+ // copy the queue of dirs into the vector
+ temp_dirs.reset();
+ while (temp_dirs.move_next())
+ {
+ dirs.push_back(temp_dirs.element());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dirs
+ >
+ void directory::
+ get_dirs (
+ queue_of_dirs& dirs
+ ) const
+ {
+ // the reason for this indirection here is because it avoids a bug in
+ // the mingw version of gcc
+ directory_helper_get_dirs(state,dirs);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#ifdef NO_MAKEFILE
+#include "dir_nav_kernel_1.cpp"
+#endif
+
+#endif // DLIB_DIR_NAV_KERNEl_1_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp
new file mode 100644
index 000000000..be97b984c
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp
@@ -0,0 +1,254 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEL_2_CPp_
+#define DLIB_DIR_NAV_KERNEL_2_CPp_
+
+#include "../platform.h"
+
+#ifdef POSIX
+
+
+#include "dir_nav_kernel_2.h"
+
+
+
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // file object implementation
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void file::
+ init (
+ const std::string& name
+ )
+ {
+ using namespace std;
+
+
+
+ char buf[PATH_MAX];
+ if (realpath(name.c_str(),buf) == 0)
+ {
+ // the file was not found
+ throw file_not_found("Unable to find file " + name);
+ }
+ state.full_name = buf;
+
+
+ string::size_type pos = state.full_name.find_last_of(directory::get_separator());
+ if (pos == string::npos)
+ {
+ // no valid full path has no separtor characters.
+ throw file_not_found("Unable to find file " + name);
+ }
+ state.name = state.full_name.substr(pos+1);
+
+
+ // now find the size of this file
+ struct stat64 buffer;
+ if (::stat64(state.full_name.c_str(), &buffer) ||
+ S_ISDIR(buffer.st_mode))
+ {
+ // there was an error during the call to stat64 or
+ // name is actually a directory
+ throw file_not_found("Unable to find file " + name);
+ }
+ else
+ {
+ state.file_size = static_cast<uint64>(buffer.st_size);
+
+
+ state.last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime);
+#ifdef _BSD_SOURCE
+ state.last_modified += std::chrono::duration_cast<std::chrono::system_clock::duration>(std::chrono::nanoseconds(buffer.st_atim.tv_nsec));
+#endif
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool file::
+ operator == (
+ const file& rhs
+ ) const
+ {
+ using namespace std;
+ if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0)
+ return true;
+
+ // These files might have different names but actually represent the same
+ // file due to the presence of symbolic links.
+ char buf[PATH_MAX];
+ string left, right;
+ if (realpath(state.full_name.c_str(),buf) == 0)
+ return false;
+ left = buf;
+
+ if (realpath(rhs.state.full_name.c_str(),buf) == 0)
+ return false;
+ right = buf;
+
+ return (left == right);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // directory object implementation
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void directory::
+ init (
+ const std::string& name
+ )
+ {
+ using namespace std;
+
+
+ char buf[PATH_MAX];
+ if (realpath(name.c_str(),buf) == 0)
+ {
+ // the directory was not found
+ throw dir_not_found("Unable to find directory " + name);
+ }
+ state.full_name = buf;
+
+
+ const char sep = get_separator();
+ if (is_root_path(state.full_name) == false)
+ {
+ // ensure that thre is not a trialing separator
+ if (state.full_name[state.full_name.size()-1] == sep)
+ state.full_name.erase(state.full_name.size()-1);
+
+ // pick out the directory name
+ string::size_type pos = state.full_name.find_last_of(sep);
+ state.name = state.full_name.substr(pos+1);
+ }
+ else
+ {
+ // ensure that there is a trailing separator
+ if (state.full_name[state.full_name.size()-1] != sep)
+ state.full_name += sep;
+ }
+
+
+ struct stat64 buffer;
+ // now check that this is actually a valid directory
+ if (::stat64(state.full_name.c_str(),&buffer))
+ {
+ // the directory was not found
+ throw dir_not_found("Unable to find directory " + name);
+ }
+ else if (S_ISDIR(buffer.st_mode) == 0)
+ {
+ // It is not a directory
+ throw dir_not_found("Unable to find directory " + name);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ char directory::
+ get_separator (
+ )
+ {
+ return '/';
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool directory::
+ operator == (
+ const directory& rhs
+ ) const
+ {
+ using namespace std;
+ if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0)
+ return true;
+
+ // These directories might have different names but actually represent the same
+ // directory due to the presence of symbolic links.
+ char buf[PATH_MAX];
+ string left, right;
+ if (realpath(state.full_name.c_str(),buf) == 0)
+ return false;
+ left = buf;
+
+ if (realpath(rhs.state.full_name.c_str(),buf) == 0)
+ return false;
+ right = buf;
+
+ return (left == right);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const directory directory::
+ get_parent (
+ ) const
+ {
+ using namespace std;
+ // if *this is the root then just return *this
+ if (is_root())
+ {
+ return *this;
+ }
+ else
+ {
+ directory temp;
+
+ const char sep = get_separator();
+
+ string::size_type pos = state.full_name.find_last_of(sep);
+ temp.state.full_name = state.full_name.substr(0,pos);
+
+ if ( is_root_path(temp.state.full_name))
+ {
+ temp.state.full_name += sep;
+ }
+ else
+ {
+ pos = temp.state.full_name.find_last_of(sep);
+ if (pos != string::npos)
+ {
+ temp.state.name = temp.state.full_name.substr(pos+1);
+ }
+ else
+ {
+ temp.state.full_name += sep;
+ }
+ }
+ return temp;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool directory::
+ is_root_path (
+ const std::string& path
+ ) const
+ {
+ const char sep = get_separator();
+ if (path.size() == 1 && path[0] == sep)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // POSIX
+
+#endif // DLIB_DIR_NAV_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h
new file mode 100644
index 000000000..af2f3d5dc
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h
@@ -0,0 +1,659 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEl_2_
+#define DLIB_DIR_NAV_KERNEl_2_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+
+#include "dir_nav_kernel_abstract.h"
+
+#include <string>
+#include "../uintn.h"
+#include "../algs.h"
+
+#include <sys/types.h>
+#include <dirent.h>
+#include <libgen.h>
+#include <limits.h>
+#include <unistd.h>
+#include <sys/stat.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <chrono>
+
+#if !defined(__USE_LARGEFILE64 ) && !defined(_LARGEFILE64_SOURCE)
+#define stat64 stat
+#endif
+
+#include <vector>
+#include "../stl_checked.h"
+#include "../enable_if.h"
+#include "../queue.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // file object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class file
+ {
+ /*!
+ INITIAL VALUES
+ state.name == name()
+ state.full_name == full_name()
+ state.file_size == size()
+ state.last_modified == last_modified()
+
+ CONVENTION
+ state.name == name()
+ state.full_name == full_name()
+ state.file_size == size()
+ state.last_modified == last_modified()
+
+ !*/
+
+ friend class directory;
+
+ struct data
+ {
+ uint64 file_size;
+ std::string name;
+ std::string full_name;
+ std::chrono::time_point<std::chrono::system_clock> last_modified;
+ };
+
+ void init(const std::string& name);
+
+ public:
+
+ struct private_constructor{};
+ inline file (
+ const std::string& name,
+ const std::string& full_name,
+ const uint64 file_size,
+ const std::chrono::time_point<std::chrono::system_clock>& last_modified,
+ private_constructor
+ )
+ {
+ state.file_size = file_size;
+ state.name = name;
+ state.full_name = full_name;
+ state.last_modified = last_modified;
+ }
+
+
+ class file_not_found : public error {
+ public: file_not_found(const std::string& s): error(s){}
+ };
+
+ inline file (
+ )
+ {
+ state.file_size = 0;
+ }
+
+ file (
+ const std::string& name
+ ) { init(name); }
+
+ file (
+ const char* name
+ ) { init(name); }
+
+ inline const std::string& name (
+ ) const { return state.name; }
+
+ inline const std::string& full_name (
+ ) const { return state.full_name; }
+
+ inline uint64 size (
+ ) const { return state.file_size; }
+
+ inline std::chrono::time_point<std::chrono::system_clock> last_modified (
+ ) const { return state.last_modified; }
+
+ operator std::string (
+ ) const { return full_name(); }
+
+ bool operator == (
+ const file& rhs
+ ) const;
+
+ bool operator != (
+ const file& rhs
+ ) const { return !(*this == rhs); }
+
+ inline bool operator < (
+ const file& item
+ ) const { return full_name() < item.full_name(); }
+
+ inline void swap (
+ file& item
+ )
+ {
+ exchange(state,item.state);
+ }
+
+ private:
+
+ // member data
+ data state;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // directory object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class directory
+ {
+ /*!
+ INITIAL VALUES
+ state.name == name()
+ state.full_name == full_name()
+
+ CONVENTION
+ state.name == name()
+ state.full_name == full_name()
+ is_root() == state.name.size() == 0
+
+ !*/
+
+ void init(const std::string& name);
+
+ public:
+ struct private_constructor{};
+ inline directory (
+ const std::string& name,
+ const std::string& full_name,
+ private_constructor
+ )
+ {
+ state.name = name;
+ state.full_name = full_name;
+ }
+
+ struct data
+ {
+ std::string name;
+ std::string full_name;
+ };
+
+ class dir_not_found : public error {
+ public: dir_not_found(const std::string& s):error(s){}
+ };
+ class listing_error : public error {
+ public: listing_error(const std::string& s):error(s){}
+ };
+
+ inline directory (
+ )
+ {
+ }
+
+ directory (
+ const std::string& name
+ ) { init(name); }
+
+ directory (
+ const char* name
+ ) { init(name); }
+
+ static char get_separator (
+ );
+
+ template <
+ typename queue_of_files
+ >
+ void get_files (
+ queue_of_files& files
+ ) const;
+
+ template <
+ typename queue_of_dirs
+ >
+ void get_dirs (
+ queue_of_dirs& dirs
+ ) const;
+
+ std::vector<file> get_files (
+ ) const
+ {
+ std::vector<file> temp_vector;
+ get_files(temp_vector);
+ return temp_vector;
+ }
+
+ std::vector<directory> get_dirs (
+ ) const
+ {
+ std::vector<directory> temp_vector;
+ get_dirs(temp_vector);
+ return temp_vector;
+ }
+
+ const directory get_parent (
+ ) const;
+
+ inline bool is_root (
+ ) const { return state.name.size() == 0; }
+
+ inline const std::string& name (
+ ) const { return state.name; }
+
+ inline const std::string& full_name (
+ ) const { return state.full_name; }
+
+ operator std::string (
+ ) const { return full_name(); }
+
+ bool operator == (
+ const directory& rhs
+ ) const;
+
+ bool operator != (
+ const directory& rhs
+ ) const { return !(*this == rhs); }
+
+ inline bool operator < (
+ const directory& item
+ ) const { return full_name() < item.full_name(); }
+
+ inline void swap (
+ directory& item
+ )
+ {
+ exchange(state,item.state);
+ }
+
+ private:
+
+ // member data
+ data state;
+
+ bool is_root_path (
+ const std::string& path
+ ) const;
+ /*!
+ ensures
+ - returns true if path is a root path.
+ Note that this function considers root paths that don't
+ have a trailing separator to also be valid.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const directory& item
+ ) { out << (std::string)item; return out; }
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const file& item
+ ) { out << (std::string)item; return out; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ file& a,
+ file& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ directory& a,
+ directory& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // templated member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ typename disable_if<is_std_vector<queue_of_files>,void>::type
+ directory_helper_get_files (
+ const directory::data& state,
+ queue_of_files& files
+ )
+ {
+ using namespace std;
+
+ files.clear();
+ if (state.full_name.size() == 0)
+ throw directory::listing_error("This directory object currently doesn't represent any directory.");
+
+ DIR* ffind = 0;
+ struct dirent* data;
+ struct stat64 buffer;
+
+ try
+ {
+ string path = state.full_name;
+ // ensure that the path ends with a separator
+ if (path[path.size()-1] != directory::get_separator())
+ path += directory::get_separator();
+
+ // get a handle to something we can search with
+ ffind = opendir(state.full_name.c_str());
+ if (ffind == 0)
+ {
+ throw directory::listing_error("Unable to list the contents of " + state.full_name);
+ }
+
+ while(true)
+ {
+ errno = 0;
+ if ( (data = readdir(ffind)) == 0)
+ {
+ // there was an error or no more files
+ if ( errno == 0)
+ {
+ // there are no more files
+ break;
+ }
+ else
+ {
+ // there was an error
+ throw directory::listing_error("Unable to list the contents of " + state.full_name);
+ }
+ }
+
+ uint64 file_size;
+ // get a stat64 structure so we can see if this is a file
+ if (::stat64((path+data->d_name).c_str(), &buffer) != 0)
+ {
+ // this might be a broken symbolic link. We can check by calling
+ // readlink and seeing if it finds anything.
+ char buf[PATH_MAX];
+ ssize_t temp = readlink((path+data->d_name).c_str(),buf,sizeof(buf));
+ if (temp == -1)
+ throw directory::listing_error("Unable to list the contents of " + state.full_name);
+ else
+ file_size = static_cast<uint64>(temp);
+ }
+ else
+ {
+ file_size = static_cast<uint64>(buffer.st_size);
+ }
+ auto last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime);
+#ifdef _BSD_SOURCE
+ last_modified += std::chrono::duration_cast<std::chrono::system_clock::duration>(std::chrono::nanoseconds(buffer.st_atim.tv_nsec));
+#endif
+
+ if (S_ISDIR(buffer.st_mode) == 0)
+ {
+ // this is actually a file
+ file temp(
+ data->d_name,
+ path+data->d_name,
+ file_size,
+ last_modified,
+ file::private_constructor()
+ );
+ files.enqueue(temp);
+ }
+ } // while (true)
+
+ if (ffind != 0)
+ {
+ while (closedir(ffind))
+ {
+ if (errno != EINTR)
+ break;
+ }
+ ffind = 0;
+ }
+
+ }
+ catch (...)
+ {
+ if (ffind != 0)
+ {
+ while (closedir(ffind))
+ {
+ if (errno != EINTR)
+ break;
+ }
+ ffind = 0;
+ }
+ files.clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ typename enable_if<is_std_vector<queue_of_files>,void>::type
+ directory_helper_get_files (
+ const directory::data& state,
+ queue_of_files& files
+ )
+ {
+ queue<file>::kernel_2a temp_files;
+ directory_helper_get_files(state,temp_files);
+
+ files.clear();
+
+ // copy the queue of files into the vector
+ temp_files.reset();
+ while (temp_files.move_next())
+ {
+ files.push_back(temp_files.element());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_files
+ >
+ void directory::
+ get_files (
+ queue_of_files& files
+ ) const
+ {
+ // the reason for this indirection here is because it avoids a bug in
+ // the cygwin version of gcc
+ directory_helper_get_files(state,files);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dirs
+ >
+ typename disable_if<is_std_vector<queue_of_dirs>,void>::type
+ directory_helper_get_dirs (
+ const directory::data& state,
+ queue_of_dirs& dirs
+ )
+ {
+ using namespace std;
+
+ dirs.clear();
+ if (state.full_name.size() == 0)
+ throw directory::listing_error("This directory object currently doesn't represent any directory.");
+
+ DIR* ffind = 0;
+ struct dirent* data;
+ struct stat64 buffer;
+
+ try
+ {
+ string path = state.full_name;
+ // ensure that the path ends with a separator
+ if (path[path.size()-1] != directory::get_separator())
+ path += directory::get_separator();
+
+ // get a handle to something we can search with
+ ffind = opendir(state.full_name.c_str());
+ if (ffind == 0)
+ {
+ throw directory::listing_error("Unable to list the contents of " + state.full_name);
+ }
+
+ while(true)
+ {
+ errno = 0;
+ if ( (data = readdir(ffind)) == 0)
+ {
+ // there was an error or no more files
+ if ( errno == 0)
+ {
+ // there are no more files
+ break;
+ }
+ else
+ {
+ // there was an error
+ throw directory::listing_error("Unable to list the contents of " + state.full_name);
+ }
+ }
+
+ // get a stat64 structure so we can see if this is a file
+ if (::stat64((path+data->d_name).c_str(), &buffer) != 0)
+ {
+ // just assume this isn't a directory. It is probably a broken
+ // symbolic link.
+ continue;
+ }
+
+ string dtemp(data->d_name);
+ if (S_ISDIR(buffer.st_mode) &&
+ dtemp != "." &&
+ dtemp != ".." )
+ {
+ // this is a directory so add it to dirs
+ directory temp(dtemp,path+dtemp, directory::private_constructor());
+ dirs.enqueue(temp);
+ }
+ } // while (true)
+
+ if (ffind != 0)
+ {
+ while (closedir(ffind))
+ {
+ if (errno != EINTR)
+ break;
+ }
+ ffind = 0;
+ }
+
+ }
+ catch (...)
+ {
+ if (ffind != 0)
+ {
+ while (closedir(ffind))
+ {
+ if (errno != EINTR)
+ break;
+ }
+ ffind = 0;
+ }
+ dirs.clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dirs
+ >
+ typename enable_if<is_std_vector<queue_of_dirs>,void>::type
+ directory_helper_get_dirs (
+ const directory::data& state,
+ queue_of_dirs& dirs
+ )
+ {
+ queue<directory>::kernel_2a temp_dirs;
+ directory_helper_get_dirs(state,temp_dirs);
+
+ dirs.clear();
+
+ // copy the queue of dirs into the vector
+ temp_dirs.reset();
+ while (temp_dirs.move_next())
+ {
+ dirs.push_back(temp_dirs.element());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename queue_of_dirs
+ >
+ void directory::
+ get_dirs (
+ queue_of_dirs& dirs
+ ) const
+ {
+ // the reason for this indirection here is because it avoids a bug in
+ // the cygwin version of gcc
+ directory_helper_get_dirs(state,dirs);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dir
+ >
+ typename disable_if<is_std_vector<queue_of_dir>,void>::type get_filesystem_roots (
+ queue_of_dir& roots
+ )
+ {
+ roots.clear();
+ directory dir("/");
+ roots.enqueue(dir);
+ }
+
+ template <
+ typename queue_of_dir
+ >
+ typename enable_if<is_std_vector<queue_of_dir>,void>::type get_filesystem_roots (
+ std::vector<directory>& roots
+ )
+ {
+ roots.clear();
+ directory dir("/");
+ roots.push_back(dir);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#ifdef NO_MAKEFILE
+#include "dir_nav_kernel_2.cpp"
+#endif
+
+#endif // DLIB_DIR_NAV_KERNEl_2_
+
diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h
new file mode 100644
index 000000000..53254ee03
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h
@@ -0,0 +1,515 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DIR_NAV_KERNEl_ABSTRACT_
+#ifdef DLIB_DIR_NAV_KERNEl_ABSTRACT_
+
+#include <string>
+#include <vector>
+#include "../uintn.h"
+#include "../algs.h"
+#include <chrono>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ GENERAL WARNING
+ Don't call any of these functions or make any of these objects
+ before main() has been entered. That means no instances
+ of file or directory at the global scope.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_of_dir
+ >
+ void get_filesystem_roots (
+ queue_of_dir& roots
+ );
+ /*!
+ requires
+ - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T
+ set to directory or a std::vector<directory> or dlib::std_vector_c<directory>.
+ ensures
+ - #roots == a queue containing directories that represent all the roots
+ of the filesystem on this machine. (e.g. in windows you have c:\, d:\
+ etc.)
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // file object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class file
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a file.
+
+ Note that the size of a file is determined at the time the file
+ object is constructed. Thus if a file changes sizes after its
+ file object has been created its file object's size() method
+ will not reflect the new file size.
+ !*/
+
+ public:
+
+ class file_not_found : public error {};
+
+ file (
+ );
+ /*!
+ ensures
+ - #*this has been properly initialized
+ - #name() == ""
+ - #full_name() == ""
+ - #size() == 0
+ - #*this does not represent any file
+ throws
+ - std::bad_alloc
+ !*/
+
+ file (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - #*this has been properly initialized
+ - #*this represents the file given by name
+ Note that name can be a fully qualified path or just a path
+ relative to the current working directory. Also, any symbolic
+ links in name will be resolved.
+ throws
+ - std::bad_alloc
+ - file_not_found
+ This exception is thrown if the file can not be found or
+ accessed.
+ !*/
+
+ file (
+ const char* name
+ );
+ /*!
+ ensures
+ - this function is identical to file(const std::string& name)
+ !*/
+
+ file (
+ const file& item
+ );
+ /*!
+ ensures
+ - #*this == item
+ throws
+ - std::bad_alloc
+ !*/
+
+ ~file (
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ const std::string& name (
+ ) const;
+ /*!
+ ensures
+ - returns the name of the file. This is full_name() minus
+ the path to the file.
+ !*/
+
+ const std::string& full_name (
+ ) const;
+ /*!
+ ensures
+ - returns the fully qualified name for the file represented by *this
+ !*/
+
+ uint64 size (
+ ) const;
+ /*!
+ ensures
+ - returns the size of this file in bytes.
+ !*/
+
+ std::chrono::time_point<std::chrono::system_clock> last_modified (
+ ) const;
+ /*!
+ ensures
+ - returns the time the file was last modified.
+ !*/
+
+ operator std::string (
+ ) const;
+ /*!
+ ensures
+ - returns full_name()
+ (i.e. provides an implicit conversion to string from dlib::file)
+ !*/
+
+ file& operator= (
+ const file& rhs
+ );
+ /*!
+ ensures
+ - #*this == rhs
+ !*/
+
+ bool operator == (
+ const file& rhs
+ ) const;
+ /*!
+ ensures
+ - if (*this and rhs represent the same file) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator != (
+ const file& rhs
+ ) const;
+ /*!
+ ensures
+ - if (*this and rhs represent the same file) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ bool operator < (
+ const file& item
+ ) const;
+ /*!
+ ensures
+ - if (full_name() < item.full_name()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void swap (
+ file& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // directory object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class directory
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a directory in a file system. It gives
+ the ability to traverse a directory tree.
+
+ Note that the directories . and .. are not returned by get_dirs()
+ !*/
+
+ public:
+
+ class dir_not_found : public error {};
+ class listing_error : public error {};
+
+ directory (
+ );
+ /*!
+ ensures
+ - #*this has been properly initialized
+ - #full_name() == ""
+ - #name() == ""
+ - #is_root() == true
+ - #*this does not represent any directory
+ throws
+ - std::bad_alloc
+ !*/
+
+ directory (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - #*this has been properly initialized
+ - #*this represents the directory given by name.
+ Note that name can be a fully qualified path or just a path
+ relative to the current working directory. Also, any symbolic
+ links in name will be resolved.
+ throws
+ - std::bad_alloc
+ - dir_not_found
+ This exception is thrown if the directory can not be found or
+ accessed.
+ !*/
+
+ directory (
+ const char* name
+ );
+ /*!
+ ensures
+ - this function is identical to directory(const std::string& name)
+ !*/
+
+ directory (
+ const directory& item
+ );
+ /*!
+ ensures
+ - #*this == item
+ throws
+ - std::bad_alloc
+ !*/
+
+ ~directory (
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ static char get_separator (
+ );
+ /*!
+ ensures
+ - returns the character used to separate directories and file names in a
+ path. (i.e. \ on windows and / in unix)
+ !*/
+
+ template <
+ typename queue_of_files
+ >
+ void get_files (
+ queue_of_files& files
+ ) const;
+ /*!
+ requires
+ - queue_of_files == an implementation of queue/queue_kernel_abstract.h with T
+ set to file or a std::vector<file> or dlib::std_vector_c<file>.
+ ensures
+ - #files == A queue containing all the files present in this directory.
+ (Note that symbolic links will not have been resolved in the names
+ of the returned files.)
+ - #files.size() == the number of files in this directory
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to get_files() has
+ no effect on *this and #files is unusable until files.clear()
+ is called and succeeds.
+ - listing_error
+ This exception is thrown if listing access has been denied to this
+ directory or if some error occurred that prevented us from successfully
+ getting the contents of this directory.
+ If this exception is thrown then the call to get_files() has
+ no effect on *this and #files.size()==0.
+ !*/
+
+ std::vector<file> get_files (
+ ) const;
+ /*!
+ ensures
+ - This function simply calls get_files(temp_vector) and then returns temp_vector.
+ !*/
+
+ template <
+ typename queue_of_dirs
+ >
+ void get_dirs (
+ queue_of_dirs& dirs
+ ) const;
+ /*!
+ requires
+ - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T
+ set to directory or a std::vector<directory> or dlib::std_vector_c<directory>.
+ ensures
+ - #dirs == a queue containing all the directories present in this directory.
+ (note that symbolic links will not have been resolved in the names
+ of the returned directories.)
+ - #dirs.size() == the number of subdirectories in this directory
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to get_files() has
+ no effect on *this and #files is unusable until files.clear()
+ is called and succeeds.
+ - listing_error
+ This exception is thrown if listing access has been denied to this
+ directory or if some error occurred that prevented us from successfully
+ getting the contents of this directory.
+ If this exception is thrown then the call to get_dirs() has
+ no effect on *this and #dirs.size()==0.
+ !*/
+
+ std::vector<directory> get_dirs (
+ ) const;
+ /*!
+ ensures
+ - This function simply calls get_dirs(temp_vector) and then returns temp_vector.
+ !*/
+
+ bool is_root (
+ ) const;
+ /*!
+ ensures
+ - if (*this represents the root of this directory tree) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ const directory get_parent (
+ ) const;
+ /*!
+ ensures
+ - if (is_root()) then
+ - returns a copy of *this
+ - else
+ - returns the parent directory of *this
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to get_parent() will
+ have no effect.
+ !*/
+
+ const std::string& name (
+ ) const;
+ /*!
+ ensures
+ - if (is_root()) then
+ - returns ""
+ - else
+ - returns the name of the directory. This is full_name() minus
+ the path to the directory.
+ !*/
+
+ const std::string& full_name (
+ ) const;
+ /*!
+ ensures
+ - returns the fully qualified directory name for *this
+ - if (is_root()) then
+ - the last character of #full_name() is get_separator()
+ - else
+ - the last character of #full_name() is NOT get_separator()
+ !*/
+
+ operator std::string (
+ ) const;
+ /*!
+ ensures
+ - returns full_name()
+ (i.e. provides an implicit conversion to string from dlib::directory)
+ !*/
+
+ directory& operator= (
+ const directory& rhs
+ );
+ /*!
+ ensures
+ - #*this == rhs
+ !*/
+
+ bool operator == (
+ const directory& rhs
+ ) const;
+ /*!
+ ensures
+ - if (*this and rhs represent the same directory) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator != (
+ const directory& rhs
+ ) const;
+ /*!
+ ensures
+ - if (*this and rhs represent the same directory) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ bool operator < (
+ const directory& item
+ ) const;
+ /*!
+ ensures
+ - if (full_name() < item.full_name()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void swap (
+ directory& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const directory& item
+ );
+ /*!
+ ensures
+ - performs: out << item.full_name()
+ - returns out
+ !*/
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const file& item
+ );
+ /*!
+ ensures
+ - performs: out << item.full_name()
+ - returns out
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ file& a,
+ file& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function for file objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void swap (
+ directory& a,
+ directory& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function for directory objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DIR_NAV_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/dir_nav/posix.h b/ml/dlib/dlib/dir_nav/posix.h
new file mode 100644
index 000000000..8d499064f
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/posix.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEl_1_
+#include "dir_nav_kernel_2.h"
+#endif
+
diff --git a/ml/dlib/dlib/dir_nav/windows.h b/ml/dlib/dlib/dir_nav/windows.h
new file mode 100644
index 000000000..b0f1e1bf5
--- /dev/null
+++ b/ml/dlib/dlib/dir_nav/windows.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIR_NAV_KERNEl_2_
+#include "dir_nav_kernel_1.h"
+#endif
+
diff --git a/ml/dlib/dlib/directed_graph.h b/ml/dlib/dlib/directed_graph.h
new file mode 100644
index 000000000..a452521dc
--- /dev/null
+++ b/ml/dlib/dlib/directed_graph.h
@@ -0,0 +1,37 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIRECTED_GRAPh_
+#define DLIB_DIRECTED_GRAPh_
+
+#include "directed_graph/directed_graph_kernel_1.h"
+
+#include "algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager
+ >
+ class directed_graph
+ {
+ directed_graph() {}
+ public:
+
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef directed_graph_kernel_1<T,E,mem_manager,false>
+ kernel_1a;
+ typedef directed_graph_kernel_1<T,E,mem_manager,true>
+ kernel_1a_c;
+
+ };
+}
+
+#endif // DLIB_DIRECTED_GRAPh_
+
+
diff --git a/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h b/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h
new file mode 100644
index 000000000..b0cc6a2c2
--- /dev/null
+++ b/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h
@@ -0,0 +1,704 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DIRECTED_GRAPH_KERNEl_1_
+#define DLIB_DIRECTED_GRAPH_KERNEl_1_
+
+#include <memory>
+#include <vector>
+
+#include "../serialize.h"
+#include "../noncopyable.h"
+#include "../std_allocator.h"
+#include "../algs.h"
+#include "directed_graph_kernel_abstract.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename node_type, typename directed_graph, bool is_checked>
+ struct directed_graph_checker_helper
+ {
+ /*!
+ This object is used to check preconditions based on the value of is_checked
+ !*/
+
+ static void check_parent_edge (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_parents(),
+ "\tnode_type& directed_graph::node_type::parent_edge(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_parents(): " << self.number_of_parents()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_child_edge (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_children(),
+ "\tnode_type& directed_graph::node_type::child_edge(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_children(): " << self.number_of_children()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_parent (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_parents(),
+ "\tnode_type& directed_graph::node_type::parent(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_parents(): " << self.number_of_parents()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_child (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_children(),
+ "\tnode_type& directed_graph::node_type::child(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_children(): " << self.number_of_children()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_node (
+ unsigned long index,
+ const directed_graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(index < self.number_of_nodes(),
+ "\tnode_type& directed_graph::node(index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tindex: " << index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_has_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index,
+ const directed_graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(parent_node_index < self.number_of_nodes() &&
+ child_node_index < self.number_of_nodes(),
+ "\tvoid directed_graph::has_edge(parent_node_index, child_node_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tparent_node_index: " << parent_node_index
+ << "\n\tchild_node_index: " << child_node_index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_add_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index,
+ const directed_graph& self
+ )
+ {
+ DLIB_CASSERT(parent_node_index < self.number_of_nodes() &&
+ child_node_index < self.number_of_nodes(),
+ "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tparent_node_index: " << parent_node_index
+ << "\n\tchild_node_index: " << child_node_index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == false,
+ "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)"
+ << "\n\tYou can't add an edge if it already exists in the graph"
+ << "\n\tparent_node_index: " << parent_node_index
+ << "\n\tchild_node_index: " << child_node_index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ }
+
+ static void check_remove_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index,
+ const directed_graph& self
+ )
+ {
+ DLIB_CASSERT(parent_node_index < self.number_of_nodes() &&
+ child_node_index < self.number_of_nodes(),
+ "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tparent_node_index: " << parent_node_index
+ << "\n\tchild_node_index: " << child_node_index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == true,
+ "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)"
+ << "\n\tYou can't remove an edge if it isn't in the graph"
+ << "\n\tparent_node_index: " << parent_node_index
+ << "\n\tchild_node_index: " << child_node_index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ }
+
+ static void check_remove_node (
+ unsigned long index,
+ const directed_graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(index < self.number_of_nodes(),
+ "\tvoid directed_graph::remove_node(index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tindex: " << index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+ };
+
+ template <typename node_type, typename directed_graph>
+ struct directed_graph_checker_helper <node_type, directed_graph, false>
+ {
+ static inline void check_parent ( unsigned long , const node_type&) { }
+ static inline void check_child ( unsigned long , const node_type& ) { }
+ static inline void check_parent_edge ( unsigned long , const node_type&) { }
+ static inline void check_child_edge ( unsigned long , const node_type& ) { }
+ static inline void check_node ( unsigned long , const directed_graph& ) { }
+ static inline void check_has_edge ( unsigned long , unsigned long , const directed_graph& ) { }
+ static inline void check_add_edge ( unsigned long , unsigned long , const directed_graph& ) { }
+ static inline void check_remove_edge ( unsigned long , unsigned long , const directed_graph& ) { }
+ static inline void check_remove_node ( unsigned long , const directed_graph& ) { }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager,
+ bool is_checked = true
+ >
+ class directed_graph_kernel_1 : noncopyable
+ {
+
+ /*!
+ INITIAL VALUE
+ - nodes.size() == 0
+
+ CONVENTION
+ - nodes.size() == number_of_nodes()
+ - for all valid i:
+ - *nodes[i] == node(i)
+ - nodes[i]->parents.size() == nodes[i]->number_of_parents(i)
+ - nodes[i]->children.size() == nodes[i]->number_of_children(i)
+ - nodes[i]->edge_parents.size() == nodes[i]->number_of_parents(i)
+ - nodes[i]->edge_children.size() == nodes[i]->number_of_children(i)
+ - nodes[i]->idx == i == nodes[i]->index()
+ - for all valid p:
+ - nodes[i]->parents[p] == pointer to the p'th parent node of i
+ - *nodes[i]->parents[p] == nodes[i]->parent(p)
+ - *nodes[i]->edge_parents[p] == nodes[i]->parent_edge(p)
+ - for all valid c:
+ - nodes[i]->children[c] == pointer to the c'th child node of i
+ - *nodes[i]->children[c] == nodes[i]->child(c)
+ - *nodes[i]->edge_children[c] == nodes[i]->child_edge(c)
+ !*/
+
+ public:
+ struct node_type;
+
+ private:
+ typedef directed_graph_checker_helper<node_type, directed_graph_kernel_1, is_checked> checker;
+
+
+ public:
+
+ typedef T type;
+ typedef E edge_type;
+ typedef mem_manager mem_manager_type;
+
+ template <typename Tr, typename Er, typename MMr>
+ struct rebind {
+ typedef directed_graph_kernel_1<Tr,Er,MMr> other;
+ };
+
+ directed_graph_kernel_1(
+ ) {}
+
+ virtual ~directed_graph_kernel_1(
+ ) {}
+
+ void clear(
+ ) { nodes.clear(); }
+
+ void set_number_of_nodes (
+ unsigned long new_size
+ );
+
+ unsigned long number_of_nodes (
+ ) const { return nodes.size(); }
+
+ node_type& node (
+ unsigned long index
+ ) { checker::check_node(index,*this); return *nodes[index]; }
+
+ const node_type& node (
+ unsigned long index
+ ) const { checker::check_node(index,*this); return *nodes[index]; }
+
+ bool has_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ ) const;
+
+ void add_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ );
+
+ void remove_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ );
+
+ unsigned long add_node (
+ );
+
+ void remove_node (
+ unsigned long index
+ );
+
+ void swap (
+ directed_graph_kernel_1& item
+ ) { nodes.swap(item.nodes); }
+
+ private:
+
+
+ public:
+
+ struct node_type
+ {
+ T data;
+ typedef directed_graph_kernel_1 graph_type;
+
+ unsigned long index(
+ ) const { return idx; }
+
+ unsigned long number_of_parents (
+ ) const { return parents.size(); }
+
+ unsigned long number_of_children (
+ ) const { return children.size(); }
+
+ const node_type& parent (
+ unsigned long edge_index
+ ) const { checker::check_parent(edge_index,*this); return *parents[edge_index]; }
+
+ node_type& parent (
+ unsigned long edge_index
+ ) { checker::check_parent(edge_index,*this); return *parents[edge_index]; }
+
+ const node_type& child (
+ unsigned long edge_index
+ ) const { checker::check_child(edge_index,*this); return *children[edge_index]; }
+
+ node_type& child (
+ unsigned long edge_index
+ ) { checker::check_child(edge_index,*this); return *children[edge_index]; }
+
+ const E& parent_edge (
+ unsigned long edge_index
+ ) const { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; }
+
+ E& parent_edge (
+ unsigned long edge_index
+ ) { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; }
+
+ const E& child_edge (
+ unsigned long edge_index
+ ) const { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; }
+
+ E& child_edge (
+ unsigned long edge_index
+ ) { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; }
+
+ private:
+ friend class directed_graph_kernel_1;
+ typedef std_allocator<node_type*,mem_manager> alloc_type;
+ typedef std_allocator<std::shared_ptr<E>,mem_manager> alloc_edge_type;
+ std::vector<node_type*,alloc_type> parents;
+ std::vector<node_type*,alloc_type> children;
+ std::vector<std::shared_ptr<E>,alloc_edge_type> edge_parents;
+ std::vector<std::shared_ptr<E>,alloc_edge_type> edge_children;
+ unsigned long idx;
+ };
+
+ private:
+
+ typedef std_allocator<std::shared_ptr<node_type>,mem_manager> alloc_type;
+ typedef std::vector<std::shared_ptr<node_type>, alloc_type> vector_type;
+ vector_type nodes;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ struct is_directed_graph<directed_graph_kernel_1<T,E,mem_manager, is_checked> >
+ {
+ static const bool value = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ inline void swap (
+ directed_graph_kernel_1<T,E,mem_manager,is_checked>& a,
+ directed_graph_kernel_1<T,E,mem_manager,is_checked>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void serialize (
+ const directed_graph_kernel_1<T,E,mem_manager,is_checked>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.number_of_nodes(), out);
+
+ // serialize each node
+ for (unsigned long i = 0; i < item.number_of_nodes(); ++i)
+ {
+ serialize(item.node(i).data, out);
+
+ // serialize all the child edges
+ serialize(item.node(i).number_of_children(), out);
+ for (unsigned long c = 0; c < item.node(i).number_of_children(); ++c)
+ {
+ serialize(item.node(i).child(c).index(), out);
+ serialize(item.node(i).child_edge(c), out);
+ }
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type directed_graph_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void deserialize (
+ directed_graph_kernel_1<T,E,mem_manager,is_checked>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size, in);
+
+ item.clear();
+ item.set_number_of_nodes(size);
+
+ // deserialize each node
+ for (unsigned long i = 0; i < item.number_of_nodes(); ++i)
+ {
+ deserialize(item.node(i).data, in);
+
+ unsigned long num_children;
+ deserialize(num_children, in);
+
+ // Add all the edges going to this nodes children nodes
+ for (unsigned long c = 0; c < num_children; ++c)
+ {
+ unsigned long child_index;
+ deserialize(child_index, in);
+
+ item.add_edge(i, child_index);
+
+ // find the edge we just added
+ for (unsigned long j = 0; j < item.node(i).number_of_children(); ++j)
+ {
+ if (item.node(i).child(j).index() == child_index)
+ {
+ deserialize(item.node(i).child_edge(j), in);
+ break;
+ }
+ }
+ }
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type directed_graph_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ set_number_of_nodes (
+ unsigned long new_size
+ )
+ {
+ try
+ {
+ nodes.resize(new_size);
+ for (unsigned long i = 0; i < nodes.size(); ++i)
+ {
+ nodes[i].reset(new node_type);
+ nodes[i]->idx = i;
+ }
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ bool directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ has_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ ) const
+ {
+ checker::check_has_edge(parent_node_index, child_node_index, *this);
+
+ node_type& n = *nodes[parent_node_index];
+
+ // search all the child nodes to see if there is a link to the right node
+ for (unsigned long i = 0; i < n.children.size(); ++i)
+ {
+ if (n.children[i]->idx == child_node_index)
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ add_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ )
+ {
+ checker::check_add_edge(parent_node_index, child_node_index, *this);
+ try
+ {
+ node_type& p = *nodes[parent_node_index];
+ node_type& c = *nodes[child_node_index];
+
+ p.children.push_back(&c);
+ c.parents.push_back(&p);
+
+ p.edge_children.push_back(std::shared_ptr<E>(new E));
+ c.edge_parents.push_back(p.edge_children.back());
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ remove_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ )
+ {
+ checker::check_remove_edge(parent_node_index, child_node_index, *this);
+
+ node_type& p = *nodes[parent_node_index];
+ node_type& c = *nodes[child_node_index];
+
+ // remove the record of the link from the parent node
+ unsigned long pos = static_cast<unsigned long>(find( p.children.begin(),
+ p.children.end(),
+ &c) - p.children.begin());
+ p.children.erase(p.children.begin()+pos);
+ p.edge_children.erase(p.edge_children.begin()+pos);
+
+ // remove the record of the link from the child node
+ pos = static_cast<unsigned long>(find( c.parents.begin(),
+ c.parents.end(),
+ &p) - c.parents.begin());
+ c.parents.erase(c.parents.begin() + pos);
+ c.edge_parents.erase(c.edge_parents.begin() + pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ unsigned long directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ add_node (
+ )
+ {
+ try
+ {
+ std::shared_ptr<node_type> n(new node_type);
+ n->idx = nodes.size();
+ nodes.push_back(n);
+ return n->idx;
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void directed_graph_kernel_1<T,E,mem_manager,is_checked>::
+ remove_node (
+ unsigned long index
+ )
+ {
+ checker::check_remove_node(index,*this);
+
+ node_type& n = *nodes[index];
+
+ // remove all edges pointing to this node from its parents
+ for (unsigned long i = 0; i < n.parents.size(); ++i)
+ {
+ // remove the edge from this specific parent
+ unsigned long pos = static_cast<unsigned long>(find(n.parents[i]->children.begin(),
+ n.parents[i]->children.end(),
+ &n) - n.parents[i]->children.begin());
+
+ n.parents[i]->children.erase(n.parents[i]->children.begin() + pos);
+ n.parents[i]->edge_children.erase(n.parents[i]->edge_children.begin() + pos);
+ }
+
+ // remove all edges pointing to this node from its children
+ for (unsigned long i = 0; i < n.children.size(); ++i)
+ {
+ // remove the edge from this specific child
+ unsigned long pos = static_cast<unsigned long>(find(n.children[i]->parents.begin(),
+ n.children[i]->parents.end(),
+ &n) - n.children[i]->parents.begin());
+
+ n.children[i]->parents.erase(n.children[i]->parents.begin() + pos);
+ n.children[i]->edge_parents.erase(n.children[i]->edge_parents.begin() + pos);
+ }
+
+ // now remove this node by replacing it with the last node in the nodes vector
+ nodes[index] = nodes[nodes.size()-1];
+
+ // update the index for the node we just moved
+ nodes[index]->idx = index;
+
+ // now remove the duplicated node at the end of the vector
+ nodes.pop_back();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DIRECTED_GRAPH_KERNEl_1_
+
diff --git a/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h b/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h
new file mode 100644
index 000000000..70dd66efd
--- /dev/null
+++ b/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h
@@ -0,0 +1,383 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_
+#ifdef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_
+
+#include "../serialize.h"
+#include "../algs.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager
+ >
+ class directed_graph : noncopyable
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON E
+ E must be swappable by a global swap() and
+ E must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ The only time pointers or references to nodes or edges become invalid is when
+ they reference nodes or edges that have been removed from a graph.
+
+ INITIAL VALUE
+ number_of_nodes() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a directed graph which is a set of nodes with directed
+ edges connecting various nodes.
+
+ In this object if there is a directed edge from a node A to a node B then I say
+ that A is the parent of B and B is the child of A.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef E edge_type;
+ typedef mem_manager mem_manager_type;
+
+ template <typename Tr, typename Er, typename MMr>
+ struct rebind {
+ typedef directed_graph<Tr,Er,MMr> other;
+ };
+
+ directed_graph(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ !*/
+
+ virtual ~directed_graph(
+ );
+ /*!
+ ensures
+ - all resources associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void set_number_of_nodes (
+ unsigned long new_size
+ );
+ /*!
+ ensures
+ - #number_of_nodes() == new_size
+ - for all i < new_size:
+ - number_of_parents(i) == 0
+ - number_of_children(i) == 0
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in this graph
+ !*/
+
+ struct node_type
+ {
+ T data;
+ typedef directed_graph graph_type;
+
+ unsigned long index(
+ ) const;
+ /*!
+ ensures
+ - let G be the graph that contains the node *this
+ - returns a number N such that G.node(N) == *this
+ (i.e. returns the index of this node in the graph)
+ !*/
+
+ unsigned long number_of_parents (
+ ) const;
+ /*!
+ ensures
+ - returns the number of parents of this node
+ !*/
+
+ unsigned long number_of_children (
+ ) const;
+ /*!
+ ensures
+ - returns the number of children of this node
+ !*/
+
+ const node_type& parent (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_parents()
+ ensures
+ - returns a const reference to the edge_index'th parent of *this
+ !*/
+
+ node_type& parent (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_parents()
+ ensures
+ - returns a non-const reference to the edge_index'th parent of *this
+ !*/
+
+ const node_type& child (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_children()
+ ensures
+ - returns a const reference to the edge_index'th child of *this
+ !*/
+
+ node_type& child (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_children()
+ ensures
+ - returns a non-const reference to the edge_index'th child of *this
+ !*/
+
+ const E& parent_edge (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_parents()
+ ensures
+ - returns a const reference to the edge_index'th edge data for the
+ edge connecting to node this->parent(edge_index)
+ !*/
+
+ E& parent_edge (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_parents()
+ ensures
+ - returns a non-const reference to the edge_index'th edge data for the
+ edge connecting to node this->parent(edge_index)
+ !*/
+
+ const E& child_edge (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_children()
+ ensures
+ - returns a const reference to the edge_index'th edge data for the
+ edge connecting to node this->child(edge_index)
+ !*/
+
+ E& child_edge (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_children()
+ ensures
+ - returns a non-const reference to the edge_index'th edge data for the
+ edge connecting to node this->child(edge_index)
+ !*/
+ };
+
+ node_type& node (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - returns a non-const reference to the node with the given index
+ !*/
+
+ const node_type& node (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - returns a const reference to the node with the given index
+ !*/
+
+ bool has_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ ) const;
+ /*!
+ requires
+ - parent_node_index < number_of_nodes()
+ - child_node_index < number_of_nodes()
+ ensures
+ - if (there is an edge leading from node(parent_node_index) to
+ node(child_node_index)) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void add_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ );
+ /*!
+ requires
+ - parent_node_index < number_of_nodes()
+ - child_node_index < number_of_nodes()
+ - has_edge(parent_node_index, child_node_index) == false
+ ensures
+ - #has_edge(parent_node_index, child_node_index) == true
+ throws
+ - std::bad_alloc
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void remove_edge (
+ unsigned long parent_node_index,
+ unsigned long child_node_index
+ );
+ /*!
+ requires
+ - parent_node_index < number_of_nodes()
+ - child_node_index < number_of_nodes()
+ - has_edge(parent_node_index, child_node_index) == true
+ ensures
+ - #has_edge(parent_node_index, child_node_index) == false
+ throws
+ - std::bad_alloc
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ unsigned long add_node (
+ );
+ /*!
+ ensures
+ - does not change the index number of existing nodes
+ - adds a node with index N == number_of_nodes() such that:
+ - #node(N).number_of_parents() == 0
+ - #node(N).number_of_children() == 0
+ - #number_of_nodes() == number_of_nodes() + 1
+ - returns N
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void remove_node (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - removes the node with the given index from the graph.
+ - removes all edges linking the removed node to the rest
+ of the graph.
+ - the remaining node indexes are remapped so that they remain
+ contiguous. (This means that for all valid N, node(N) doesn't
+ necessarily reference the same node as #node(N))
+ - #number_of_nodes() == number_of_nodes() - 1
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void swap (
+ directed_graph& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ directed_graph<T,mem_manager>& a,
+ directed_graph<T,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void serialize (
+ const directed_graph<T,mem_manager>& item,
+ std::ostream& out
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ directed_graph<T,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/disjoint_subsets.h b/ml/dlib/dlib/disjoint_subsets.h
new file mode 100644
index 000000000..d33ef63f4
--- /dev/null
+++ b/ml/dlib/dlib/disjoint_subsets.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DISJOINt_SUBSETS_
+#define DLIB_DISJOINt_SUBSETS_
+
+
+#include "disjoint_subsets/disjoint_subsets.h"
+#include "disjoint_subsets/disjoint_subsets_sized.h"
+
+#endif // DLIB_DISJOINt_SUBSETS_
+
+
diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h
new file mode 100644
index 000000000..7fab9eba3
--- /dev/null
+++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h
@@ -0,0 +1,141 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DISJOINT_SUBsETS_Hh_
+#define DLIB_DISJOINT_SUBsETS_Hh_
+
+#include "disjoint_subsets_abstract.h"
+#include <vector>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class disjoint_subsets
+ {
+ public:
+
+ void clear (
+ ) noexcept
+ {
+ items.clear();
+ }
+
+ void set_size (
+ unsigned long new_size
+ )
+ {
+ items.resize(new_size);
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ items[i].parent = i;
+ items[i].rank = 0;
+ }
+ }
+
+ size_t size (
+ ) const noexcept
+ {
+ return items.size();
+ }
+
+ unsigned long find_set (
+ unsigned long item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(item < size(),
+ "\t unsigned long disjoint_subsets::find_set()"
+ << "\n\t item must be less than size()"
+ << "\n\t item: " << item
+ << "\n\t size(): " << size()
+ << "\n\t this: " << this
+ );
+
+ if (items[item].parent == item)
+ {
+ return item;
+ }
+ else
+ {
+ // find root of item
+ unsigned long x = item;
+ do
+ {
+ x = items[x].parent;
+ } while (items[x].parent != x);
+
+ // do path compression
+ const unsigned long root = x;
+ x = item;
+ while (items[x].parent != x)
+ {
+ const unsigned long prev = x;
+ x = items[x].parent;
+ items[prev].parent = root;
+ }
+
+ return root;
+ }
+ }
+
+ unsigned long merge_sets (
+ unsigned long a,
+ unsigned long b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a != b &&
+ a < size() &&
+ b < size() &&
+ find_set(a) == a &&
+ find_set(b) == b,
+ "\t unsigned long disjoint_subsets::merge_sets(a,b)"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t a: " << a
+ << "\n\t b: " << b
+ << "\n\t size(): " << size()
+ << "\n\t find_set(a): " << find_set(a)
+ << "\n\t find_set(b): " << find_set(b)
+ << "\n\t this: " << this
+ );
+
+ if (items[a].rank > items[b].rank)
+ {
+ items[b].parent = a;
+ return a;
+ }
+ else
+ {
+ items[a].parent = b;
+ if (items[a].rank == items[b].rank)
+ {
+ items[b].rank = items[b].rank + 1;
+ }
+ return b;
+ }
+ }
+
+ private:
+
+ /*
+ See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein
+ for a discussion of how this algorithm works.
+ */
+
+ struct data
+ {
+ unsigned long rank;
+ unsigned long parent;
+ };
+
+ mutable std::vector<data> items;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DISJOINT_SUBsETS_Hh_
diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h
new file mode 100644
index 000000000..bd67d0d5c
--- /dev/null
+++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h
@@ -0,0 +1,96 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_
+#ifdef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_
+
+#include <vector>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class disjoint_subsets
+ {
+ /*!
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a set of integers which is partitioned into
+ a number of disjoint subsets. It supports the two fundamental operations
+ of finding which subset a particular integer belongs to as well as
+ merging subsets.
+ !*/
+ public:
+
+ void clear (
+ ) noexcept;
+ /*!
+ ensures
+ - #size() == 0
+ - returns this object to its initial value
+ !*/
+
+ void set_size (
+ unsigned long new_size
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ - for all valid i:
+ - #find_set(i) == i
+ (i.e. this object contains new_size subsets, each containing exactly one element)
+ !*/
+
+ size_t size (
+ ) const noexcept;
+ /*!
+ ensures
+ - returns the total number of integer elements represented
+ by this object.
+ !*/
+
+ unsigned long find_set (
+ unsigned long item
+ ) const;
+ /*!
+ requires
+ - item < size()
+ ensures
+ - Each disjoint subset can be represented by any of its elements (since
+ the sets are all disjoint). In particular, for each subset we define
+ a special "representative element" which is used to represent it.
+ Therefore, this function returns the representative element for the
+ set which contains item.
+ - find_set(find_set(item)) == find_set(item)
+ - Note that if A and B are both elements of the same subset then we always
+ have find_set(A) == find_set(B).
+ !*/
+
+ unsigned long merge_sets (
+ unsigned long a,
+ unsigned long b
+ );
+ /*!
+ requires
+ - a != b
+ - a < size()
+ - b < size()
+ - find_set(a) == a
+ (i.e. a is the representative element of some set)
+ - find_set(b) == b
+ (i.e. b is the representative element of some set)
+ ensures
+ - #find_set(a) == #find_set(b)
+ (i.e. merges the set's containing a and b)
+ - returns #find_set(a)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_
diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h
new file mode 100644
index 000000000..9aa657f43
--- /dev/null
+++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h
@@ -0,0 +1,130 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DISJOINT_SUBsETS_SIZED_Hh_
+#define DLIB_DISJOINT_SUBsETS_SIZED_Hh_
+
+#include "disjoint_subsets_sized_abstract.h"
+#include "disjoint_subsets.h"
+#include <vector>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class disjoint_subsets_sized
+ {
+ public:
+
+ void clear (
+ ) noexcept
+ {
+ disjoint_subsets_.clear();
+ sets_size.clear();
+ number_of_sets = 0;
+ }
+
+ void set_size (
+ unsigned long new_size
+ )
+ {
+ disjoint_subsets_.set_size(new_size);
+ sets_size.assign(new_size, 1);
+ number_of_sets = new_size;
+ }
+
+ size_t size (
+ ) const noexcept
+ {
+ return disjoint_subsets_.size();
+ }
+
+ unsigned long find_set (
+ unsigned long item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(item < size(),
+ "\t unsigned long disjoint_subsets::find_set()"
+ << "\n\t item must be less than size()"
+ << "\n\t item: " << item
+ << "\n\t size(): " << size()
+ << "\n\t this: " << this
+ );
+
+ return disjoint_subsets_.find_set(item);
+ }
+
+ unsigned long merge_sets (
+ unsigned long a,
+ unsigned long b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a != b &&
+ a < size() &&
+ b < size() &&
+ find_set(a) == a &&
+ find_set(b) == b,
+ "\t unsigned long disjoint_subsets::merge_sets(a,b)"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t a: " << a
+ << "\n\t b: " << b
+ << "\n\t size(): " << size()
+ << "\n\t find_set(a): " << find_set(a)
+ << "\n\t find_set(b): " << find_set(b)
+ << "\n\t this: " << this
+ );
+
+ disjoint_subsets_.merge_sets(a, b);
+
+ if (find_set(a) == a) sets_size[a] += sets_size[b];
+ else sets_size[b] += sets_size[a];
+ --number_of_sets;
+
+ return find_set(a);
+ }
+
+ unsigned long get_number_of_sets (
+ ) const noexcept
+ {
+ return number_of_sets;
+ }
+
+ unsigned long get_size_of_set(
+ unsigned long item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(item < size() &&
+ find_set(item) == item,
+ "\t unsigned long disjoint_subsets::get_size_of_set()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t item: " << item
+ << "\n\t size(): " << size()
+ << "\n\t find_set(item): " << find_set(item)
+ << "\n\t this: " << this
+ );
+
+ return sets_size[item];
+ }
+
+ private:
+
+ /*
+ See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein
+ for a discussion of how this algorithm works.
+ */
+
+ mutable std::vector<unsigned long> sets_size;
+ unsigned long number_of_sets{0};
+ disjoint_subsets disjoint_subsets_;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DISJOINT_SUBsETS_SIZED_Hh_
diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h
new file mode 100644
index 000000000..ecc5ef005
--- /dev/null
+++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h
@@ -0,0 +1,123 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_
+#ifdef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_
+
+#include <vector>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class disjoint_subsets_sized
+ {
+ /*!
+ INITIAL VALUE
+ - size() == 0
+ - get_number_of_sets() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a set of integers which is partitioned into
+ a number of disjoint subsets. It supports the two fundamental operations
+ of finding which subset a particular integer belongs to as well as
+ merging subsets. It also allows you to find out how big each subset is. It
+ is therefore essentially the same thing as dlib::disjoint_subsets, except
+ it also keeps track of of the size of each subset.
+ !*/
+ public:
+
+ void clear (
+ ) noexcept;
+ /*!
+ ensures
+ - #size() == 0
+ - #get_number_of_sets() == 0
+ - returns this object to its initial value
+ !*/
+
+ void set_size (
+ unsigned long new_size
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ - #get_number_of_sets() == new_size
+ - for all valid i:
+ - #find_set(i) == i
+ (i.e. this object contains new_size subsets, each containing exactly one element)
+ - #get_size_of_set(i) == 1
+ !*/
+
+ size_t size (
+ ) const noexcept;
+ /*!
+ ensures
+ - returns the total number of integer elements represented
+ by this object.
+ !*/
+
+ unsigned long find_set (
+ unsigned long item
+ ) const;
+ /*!
+ requires
+ - item < size()
+ ensures
+ - Each disjoint subset can be represented by any of its elements (since
+ the sets are all disjoint). In particular, for each subset we define
+ a special "representative element" which is used to represent it.
+ Therefore, this function returns the representative element for the
+ set which contains item.
+ - find_set(find_set(item)) == find_set(item)
+ - Note that if A and B are both elements of the same subset then we always
+ have find_set(A) == find_set(B).
+ !*/
+
+ unsigned long merge_sets (
+ unsigned long a,
+ unsigned long b
+ );
+ /*!
+ requires
+ - a != b
+ - a < size()
+ - b < size()
+ - find_set(a) == a
+ (i.e. a is the representative element of some set)
+ - find_set(b) == b
+ (i.e. b is the representative element of some set)
+ ensures
+ - #find_set(a) == #find_set(b)
+ (i.e. merges the set's containing a and b)
+ - #get_size_of_set(#find_set(a)) == get_size_of_set(a) + get_size_of_set(b)
+ - #get_number_of_sets() == get_number_of_sets() - 1
+ - returns #find_set(a)
+ !*/
+
+ unsigned long get_number_of_sets (
+ ) const noexcept;
+ /*!
+ ensures
+ - returns the current number of different subsets.
+ !*/
+
+ unsigned long get_size_of_set(
+ unsigned long item
+ ) const;
+ /*!
+ requires
+ - item < size()
+ - find_set(item) == item
+ (i.e. item is the representative element of some set)
+ ensures
+ - returns the number of elements which belongs to the set where item is the representative element.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_
diff --git a/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt b/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt
new file mode 100644
index 000000000..fa1a970ab
--- /dev/null
+++ b/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt
@@ -0,0 +1,13 @@
+#error "Don't write #include <dlib/all/source.cpp> in your code."
+/*
+ In C++, it is generally an error to #include .cpp files. This is because it
+ can lead to what are called multiply defined symbol errors. Therefore, you
+ should compile dlib/all/source.cpp into your application just like you would
+ compile any other .cpp file.
+
+ If you are using Visual Studio you add .cpp files to your application using
+ the solution explorer window. Specifically, right click on Source Files,
+ then select Add -> Existing Item and select the .cpp files you want to add.
+
+ For general information on compiling dlib see http://dlib.net/compile.html
+*/
diff --git a/ml/dlib/dlib/dlib_include_path_tutorial.txt b/ml/dlib/dlib/dlib_include_path_tutorial.txt
new file mode 100644
index 000000000..f279ce103
--- /dev/null
+++ b/ml/dlib/dlib/dlib_include_path_tutorial.txt
@@ -0,0 +1,20 @@
+#error "Don't put the dlib folder in your include path"
+/*
+ You are getting this error because you have added the dlib folder to your
+ compiler's include search path.
+
+ You should *NOT* add the dlib folder itself to your compiler's include path.
+ Doing so will cause the build to fail because of name collisions (such as
+ dlib/string.h and string.h from the standard library). Instead you should
+ add the folder that contains the dlib folder to your include search path
+ and then use include statements of the form #include <dlib/queue.h> or
+ #include "dlib/queue.h". This will ensure that everything builds correctly.
+
+ XCode:
+ The XCode IDE often puts all folders that it knows about into
+ the compiler search path. So if you are using XCode then either
+ don't drag the whole dlib folder into the project or alternatively
+ modify your XCode project settings to not auto-add all folders to
+ the include path. Instead just make sure that the dlib folder is
+ itself inside a folder in your include path.
+*/
diff --git a/ml/dlib/dlib/dnn.h b/ml/dlib/dlib/dnn.h
new file mode 100644
index 000000000..db59948fb
--- /dev/null
+++ b/ml/dlib/dlib/dnn.h
@@ -0,0 +1,37 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_
+#define DLIB_DNn_
+
+// DNN module uses template-based network declaration that leads to very long
+// type names. Visual Studio will produce Warning C4503 in such cases
+#ifdef _MSC_VER
+# pragma warning( disable: 4503 )
+#endif
+
+#include "dnn/tensor.h"
+#include "dnn/input.h"
+
+// Problem: Visual Studio's vcpkgsrv.exe constantly uses a single CPU core,
+// apparently never finishing whatever it's trying to do. Moreover,
+// this issue prevents some operations like switching from Debug to
+// Release (and vice versa) in the IDE. (Your mileage may vary.)
+// Workaround: Keep manually killing the vcpkgsrv.exe process.
+// Solution: Disable IntelliSense for some files. Which files? Unfortunately
+// this seems to be a trial-and-error process.
+#ifndef __INTELLISENSE__
+#include "dnn/layers.h"
+#endif // __INTELLISENSE__
+
+#include "dnn/loss.h"
+#include "dnn/core.h"
+#include "dnn/solvers.h"
+#include "dnn/trainer.h"
+#include "dnn/cpu_dlib.h"
+#include "dnn/tensor_tools.h"
+#include "dnn/utilities.h"
+#include "dnn/validation.h"
+
+#endif // DLIB_DNn_
+
+
diff --git a/ml/dlib/dlib/dnn/core.h b/ml/dlib/dlib/dnn/core.h
new file mode 100644
index 000000000..5f1d05498
--- /dev/null
+++ b/ml/dlib/dlib/dnn/core.h
@@ -0,0 +1,3599 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_CORE_H_
+#define DLIB_DNn_CORE_H_
+
+#include "core_abstract.h"
+#include "tensor.h"
+#include <iterator>
+#include <memory>
+#include <sstream>
+#include <type_traits>
+#include "../statistics.h"
+#include "../rand.h"
+#include "../algs.h"
+#include <utility>
+#include <tuple>
+#include <cmath>
+#include <vector>
+#include "tensor_tools.h"
+#include <type_traits>
+#include "../metaprogramming.h"
+
+#ifdef _MSC_VER
+// Tell Visual Studio not to recursively inline functions very much because otherwise it
+// takes hours to compile the DNN code sometimes. It's crazy. Hopefully we can remove
+// this some day when the visual studio compiler is more efficient.
+#pragma inline_depth(2)
+#endif
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T, typename int_<decltype(&T::get_learning_rate_multiplier)>::type = 0>
+ double get_learning_rate_multiplier (
+ const T& obj,
+ special_
+ ) { return obj.get_learning_rate_multiplier(); }
+
+ template <typename T>
+ double get_learning_rate_multiplier ( const T& , general_) { return 1; }
+ }
+ template <typename T>
+ double get_learning_rate_multiplier(const T& obj) { return impl::get_learning_rate_multiplier(obj, special_()); }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T, typename int_<decltype(&T::get_weight_decay_multiplier)>::type = 0>
+ double get_weight_decay_multiplier (
+ const T& obj,
+ special_
+ ) { return obj.get_weight_decay_multiplier(); }
+
+ template <typename T>
+ double get_weight_decay_multiplier ( const T& , general_) { return 1; }
+ }
+ template <typename T>
+ double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ // The reason we return an int for this version rather than doing the more straight forward thing (like we do above) is to avoid a bug in visual studio 2015.
+ template <typename T>
+ auto call_clean_method_if_exists (
+ T& obj,
+ special_
+ ) -> typename int_<decltype(&T::clean)>::type { obj.clean(); return 0; }
+
+ template <typename T>
+ void call_clean_method_if_exists (T& , general_) {}
+ }
+ template <typename T>
+ void call_clean_method_if_exists(T& obj) { impl::call_clean_method_if_exists(obj, special_()); }
+ /*!
+ ensures
+ - calls obj.clean() if obj has a .clean() method.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class repeat_input_layer
+ {
+ /*!
+ None of the declarations in this object are really used. The only reason it
+ exists is to allow the repeat object to use a special input layer in its
+ internal networks which will cause add_tag_layer objects that happen to be
+ right at the input to not create copies of their input tensors. So
+ introducing the repeat_input_layer object allows us to optimize the
+ implementation of add_tag_layer for a special case that arises when it's
+ used in the context of the repeat layer.
+ !*/
+ public:
+ typedef int input_type;
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ,
+ forward_iterator ,
+ resizable_tensor&
+ ) const
+ {
+ }
+
+ friend void serialize(const repeat_input_layer&, std::ostream&){}
+ friend void deserialize(repeat_input_layer&, std::istream&){}
+ friend std::ostream& operator<<(std::ostream& out, const repeat_input_layer&) { return out; }
+ };
+
+ inline std::string tensor_to_str (
+ const tensor& t,
+ int& min_length
+ )
+ {
+ if (t.size() == 0)
+ return "";
+
+ std::ostringstream sout;
+ sout << "output size=(num:"<< t.num_samples() << ", ";
+ sout << "k:" << t.k() << ",";
+ while (sout.tellp() < 28) sout << " ";
+ sout << "nr:" << t.nr() << ",";
+ while (sout.tellp() < 28+8) sout << " ";
+ sout << "nc:" << t.nc() << ")";
+ while (sout.tellp() < min_length) sout << " ";
+ min_length = sout.tellp();
+ sout << "\t";
+ return sout.str();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Tell us if T is one of the special layer types (i.e. add_layer, repeat, add_tag_layer, or
+ // add_skip_layer).
+ template <typename T> struct is_nonloss_layer_type : std::false_type {};
+ // Tell us if T is an instance of add_loss_layer.
+ template <typename T> struct is_loss_layer_type : std::false_type {};
+ // Tell us if T is an instance of add_layer
+ template <typename T> struct is_add_layer : std::false_type {};
+
+ namespace impl
+ {
+ template <size_t... indices, typename Tuple>
+ auto tuple_subset(
+ const Tuple& item,
+ compile_time_integer_list<indices...>
+ ) -> decltype(std::make_tuple(std::get<indices>(item)...))
+ {
+ return std::make_tuple(std::get<indices>(item)...);
+ }
+
+ template <typename Head, typename... Tail>
+ std::tuple<Tail...> basic_tuple_tail(
+ const std::tuple<Head, Tail...>& item
+ )
+ {
+ return tuple_subset(item, typename make_compile_time_integer_range<sizeof...(Tail)>::type());
+ }
+
+ template <typename T>
+ std::tuple<T> tuple_flatten(const T& t)
+ {
+ return std::make_tuple(t);
+ }
+
+ template <typename... T>
+ auto tuple_flatten(
+ const std::tuple<T...>& item
+ ) -> decltype(tuple_flatten(item, typename make_compile_time_integer_range<sizeof...(T)>::type()))
+ {
+ return tuple_flatten(item, typename make_compile_time_integer_range<sizeof...(T)>::type());
+ }
+
+ template <size_t... indices, typename... T>
+ auto tuple_flatten(
+ const std::tuple<T...>& item,
+ compile_time_integer_list<indices...>
+ ) -> decltype(std::tuple_cat(tuple_flatten(std::get<indices-1>(item))...))
+ {
+ return std::tuple_cat(tuple_flatten(std::get<indices-1>(item))...);
+ }
+
+ template <typename T>
+ struct tuple_head_helper
+ {
+ typedef T type;
+ static const type& get(const T& item)
+ {
+ return item;
+ }
+ };
+
+ template <typename T, typename... U>
+ struct tuple_head_helper<std::tuple<T, U...>>
+ {
+ typedef typename tuple_head_helper<T>::type type;
+ static const type& get(const std::tuple<T,U...>& item)
+ {
+ return tuple_head_helper<T>::get(std::get<0>(item));
+ }
+ };
+
+ template <typename T> struct alwaysbool { typedef bool type; };
+ // one more structure for VS 2015 UP3 support workaround
+ template <typename T> struct alwaysbool2 { typedef bool type; };
+
+ resizable_tensor& rt();
+
+ // The significance of a layer's backward method requiring forward's outputs is
+ // that such as layer can't have an in-place layer stacked on top of it because
+ // in-place layers overwrite the output of the layer they sit on top of.
+ template <typename layer_type, typename SUBNET>
+ constexpr auto backward_requires_forward_output(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool<decltype(layer.backward(rt(),rt(),sub,rt()))>::type
+ {
+ return true;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto backward_requires_forward_output(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool<decltype(layer.backward(rt(),sub,rt()))>::type
+ {
+ return false;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto backward_requires_forward_output(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),rt(),sub.get_gradient_input(),rt()))>::type
+ {
+ return true;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto backward_requires_forward_output(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+ {
+ return false;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto has_inplace_backward(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool2<decltype(layer.backward(rt(),rt(),sub,rt()))>::type
+ {
+ return false;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto has_inplace_backward(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool2<decltype(layer.backward(rt(),sub,rt()))>::type
+ {
+ return false;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto has_inplace_backward(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool2<decltype(layer.backward_inplace(rt(),rt(),sub.get_gradient_input(),rt()))>::type
+ {
+ return true;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto has_inplace_backward(
+ layer_type& layer,
+ SUBNET& sub
+ ) -> typename alwaysbool2<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+ {
+ return true;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto is_inplace_layer(
+ layer_type& layer,
+ const SUBNET& sub
+ ) -> typename alwaysbool2<decltype(layer.forward(sub,rt()))>::type
+ {
+ return false;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ constexpr auto is_inplace_layer(
+ layer_type& layer,
+ const SUBNET& sub
+ ) -> typename alwaysbool<decltype(layer.forward_inplace(sub.get_output(),rt()))>::type
+ {
+ return true;
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_backward(
+ layer_type& layer,
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ ) -> decltype(layer.backward(computed_output,gradient_input,sub,params_grad))
+ {
+ layer.backward(computed_output,gradient_input,sub,params_grad);
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_backward(
+ layer_type& layer,
+ const tensor& ,
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ ) -> decltype(layer.backward(gradient_input,sub,params_grad))
+ {
+ layer.backward(gradient_input,sub,params_grad);
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_backward(
+ layer_type& layer,
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ ) -> decltype(layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad))
+ {
+ layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad);
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_backward(
+ layer_type& layer,
+ const tensor& ,
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ ) -> decltype(layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad))
+ {
+ layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad);
+ }
+
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_forward(
+ layer_type& layer,
+ const SUBNET& sub,
+ tensor& /*data_output*/
+ ) -> decltype(layer.forward(sub,rt()))
+ {
+ // This overload of call_layer_forward() is here because this template
+ // naturally gets instantiated but only on code paths that never get executed.
+ // So rather than writing a bunch of hard to read template magic around call
+ // sites we just have this overload that doesn't do anything (and an assert to
+ // make sure that's the case).
+ DLIB_CASSERT(false, "This should never happen");
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_forward(
+ layer_type& layer,
+ const SUBNET& sub,
+ resizable_tensor& data_output
+ ) -> decltype(layer.forward(sub,data_output))
+ {
+ layer.forward(sub,data_output);
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_forward(
+ layer_type& layer,
+ const SUBNET& sub,
+ tensor& data_output
+ ) -> decltype(layer.forward_inplace(sub.get_output(),data_output))
+ {
+ layer.forward_inplace(sub.get_output(),data_output);
+ }
+
+ template <typename layer_type, typename SUBNET>
+ auto call_layer_forward(
+ layer_type& layer,
+ const SUBNET& sub,
+ resizable_tensor& data_output
+ ) -> decltype(layer.forward_inplace(sub.get_output(),data_output))
+ {
+ if (!have_same_dimensions(data_output, sub.get_output()))
+ data_output.copy_size(sub.get_output());
+ layer.forward_inplace(sub.get_output(),static_cast<tensor&>(data_output));
+ }
+
+
+ } // end namespace impl
+
+ template <typename... T>
+ typename impl::tuple_head_helper<std::tuple<T...>>::type tuple_head (
+ const std::tuple<T...>& item
+ )
+ {
+ return impl::tuple_head_helper<std::tuple<T...>>::get(item);
+ }
+
+ template <typename... T>
+ auto tuple_tail(
+ const std::tuple<T...>& item
+ ) -> decltype(impl::basic_tuple_tail(impl::tuple_flatten(item)))
+ {
+ return impl::basic_tuple_tail(impl::tuple_flatten(item));
+ }
+
+ inline std::tuple<> tuple_tail(
+ const std::tuple<>& item
+ )
+ {
+ return item;
+ }
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class sstack
+ {
+ public:
+ typedef T value_type;
+
+ sstack() = delete;
+
+ sstack (
+ T* data_,
+ size_t s
+ ) : data(data_), mysize(s) {}
+
+ const T& top() const
+ {
+ DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack");
+ return *data;
+ }
+ T& top()
+ {
+ DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack");
+ return *data;
+ }
+
+ size_t size() const { return mysize; }
+
+ sstack pop(size_t num=1)
+ {
+ DLIB_CASSERT(num <= size(), "You can't pop more things from the stack than it has in it.");
+ return sstack(data+num, mysize-num);
+ }
+
+ private:
+
+ T* data;
+ size_t mysize;
+ };
+
+ template <typename T>
+ sstack<T> make_sstack(std::vector<T>& item)
+ {
+ return sstack<T>(item.data(), item.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace dimpl
+ {
+ template <typename T, bool is_first = true, typename enabled=void>
+ class subnet_wrapper
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool that makes an add_layer or add_loss_layer object
+ expose only the part of its interface defined by the SUBNET
+ type in layers_abstract.h. This way, when we pass subnetwork
+ objects to the layer callbacks those callbacks won't be able to
+ interact with the subnetworks in a way other than specified
+ by the SUBNET interface spec.
+
+ We also allow the top layer of a subnet_wrapper stack to call the
+ private_get_output() and private_get_gradient_input() functions. This
+ way, layers that have had their output/gradient overwritten by in-place
+ layers can only be accessed from the in-place layers that sit directly
+ on top of them since those in-place layers are the only layers that
+ know how to interact with them properly.
+ !*/
+
+ public:
+ subnet_wrapper(const subnet_wrapper&) = delete;
+ subnet_wrapper& operator=(const subnet_wrapper&) = delete;
+
+ subnet_wrapper(T& l_, unsigned int sef) : l(l_),_sample_expansion_factor(sef) {}
+ // Not much here because in this case T is one of the input layer types
+ // that doesn't have anything in it.
+ typedef T layer_details_type;
+ const layer_details_type& layer_details() const { return l; }
+ unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
+ private:
+ T& l;
+ unsigned int _sample_expansion_factor;
+ };
+
+ template <typename T>
+ class subnet_wrapper<T,true, typename std::enable_if<is_nonloss_layer_type<T>::value>::type>
+ {
+
+ public:
+ subnet_wrapper(const subnet_wrapper&) = delete;
+ subnet_wrapper& operator=(const subnet_wrapper&) = delete;
+
+ typedef T wrapped_type;
+ const static size_t num_computational_layers = T::num_computational_layers;
+ const static size_t num_layers = T::num_layers;
+ typedef typename T::layer_details_type layer_details_type;
+
+ subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
+
+ const tensor& get_output() const { return l.private_get_output(); }
+ tensor& get_gradient_input() { return l.private_get_gradient_input(); }
+
+ const layer_details_type& layer_details() const { return l.layer_details(); }
+
+ const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
+ subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
+ unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
+
+ private:
+ T& l;
+ subnet_wrapper<typename T::subnet_type,false> subnetwork;
+ };
+
+ template <typename T>
+ class subnet_wrapper<T,false, typename std::enable_if<is_nonloss_layer_type<T>::value>::type>
+ {
+
+ public:
+ subnet_wrapper(const subnet_wrapper&) = delete;
+ subnet_wrapper& operator=(const subnet_wrapper&) = delete;
+
+ typedef T wrapped_type;
+ const static size_t num_computational_layers = T::num_computational_layers;
+ const static size_t num_layers = T::num_layers;
+ typedef typename T::layer_details_type layer_details_type;
+
+ subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
+
+ const tensor& get_output() const { return l.get_output(); }
+ tensor& get_gradient_input() { return l.get_gradient_input(); }
+
+ const layer_details_type& layer_details() const { return l.layer_details(); }
+
+ const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
+ subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
+ unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
+
+ private:
+ T& l;
+ subnet_wrapper<typename T::subnet_type,false> subnetwork;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
+ class add_layer;
+
+ template <typename LAYER_DETAILS, typename SUBNET, typename enabled>
+ void serialize(const add_layer<LAYER_DETAILS,SUBNET,enabled>& item, std::ostream& out);
+ template <typename LAYER_DETAILS, typename SUBNET, typename enabled>
+ void deserialize(add_layer<LAYER_DETAILS,SUBNET,enabled>& item, std::istream& in);
+
+ template <typename T, typename U>
+ struct is_nonloss_layer_type<add_layer<T,U>> : std::true_type {};
+
+ template <typename LAYER_DETAILS, typename SUBNET>
+ class add_layer<LAYER_DETAILS,SUBNET,
+ typename std::enable_if<is_nonloss_layer_type<SUBNET>::value>::type>
+ {
+ public:
+ typedef LAYER_DETAILS layer_details_type;
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ const static size_t num_layers = subnet_type::num_layers + 1;
+ const static size_t num_computational_layers = subnet_type::num_computational_layers + 1;
+
+ add_layer(
+ ):
+ subnetwork(new subnet_type()),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ add_layer(const add_layer& item)
+ {
+ details = item.details;
+ subnetwork.reset(new subnet_type(*item.subnetwork));
+ this_layer_setup_called = item.this_layer_setup_called;
+ gradient_input_is_stale = item.gradient_input_is_stale;
+ get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled;
+ x_grad = item.x_grad;
+ cached_output = item.cached_output;
+ params_grad = item.params_grad;
+ temp_tensor = item.temp_tensor;
+ }
+ add_layer& operator=(const add_layer& item) { add_layer(item).swap(*this); return *this;}
+ add_layer(add_layer&& item) : add_layer() { swap(item); }
+ add_layer& operator=(add_layer&& item) { swap(item); return *this; }
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ // Allow copying networks from one to another as long as their corresponding
+ // layers can be constructed from each other.
+ template <typename T, typename U, typename E>
+ add_layer(
+ const add_layer<T,U,E>& item
+ ) :
+ details(item.layer_details()),
+ subnetwork(new subnet_type(item.subnet())),
+ this_layer_setup_called(item.this_layer_setup_called),
+ gradient_input_is_stale(item.gradient_input_is_stale),
+ get_output_and_gradient_input_disabled(item.get_output_and_gradient_input_disabled),
+ x_grad(item.x_grad),
+ cached_output(item.cached_output)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ template <typename ...T>
+ add_layer(
+ const LAYER_DETAILS& layer_det,
+ T&& ...args
+ ) :
+ details(layer_det),
+ subnetwork(new subnet_type(std::forward<T>(args)...)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ template <typename T, typename ...U>
+ struct disable_forwarding_constr
+ {
+ const static bool value = std::is_constructible<LAYER_DETAILS,T>::value;
+ };
+ template <typename ...T, typename ...U>
+ struct disable_forwarding_constr<std::tuple<T...>,U...>
+ {
+ const static bool value = disable_forwarding_constr<typename std::remove_reference<T>::type...>::value;
+ };
+ template <typename T, typename ...U>
+ struct disable_forwarding_constr<std::tuple<T>,U...>
+ {
+ const static bool value = disable_forwarding_constr<typename std::remove_reference<T>::type>::value;
+ };
+ template <typename ...U>
+ struct disable_forwarding_constr<std::tuple<>,U...>
+ {
+ const static bool value = true;
+ };
+ template <typename ...T>
+ struct disable_forwarding_constr<add_layer<T...>>
+ {
+ const static bool value = true;
+ };
+
+ template <
+ typename ...T,
+ typename = typename std::enable_if<!disable_forwarding_constr<typename std::remove_reference<T>::type...>::value>::type
+ >
+ add_layer(
+ T&& ...args
+ ) :
+ subnetwork(new subnet_type(std::forward<T>(args)...)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ template <typename ...T>
+ add_layer(
+ LAYER_DETAILS&& layer_det,
+ T&& ...args
+ ) :
+ details(std::move(layer_det)),
+ subnetwork(new subnet_type(std::forward<T>(args)...)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ template <typename ...T, typename LD, typename ...U>
+ add_layer(
+ const std::tuple<LD,U...>& layer_det,
+ T&& ...args
+ ) :
+ details(tuple_head(layer_det)),
+ subnetwork(new subnet_type(tuple_tail(layer_det),std::forward<T>(args)...)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false)
+ {
+ if (this_layer_operates_inplace())
+ subnetwork->disable_output_and_gradient_getters();
+ }
+
+ template <typename ...T, typename LD, typename ...U>
+ add_layer(
+ std::tuple<>,
+ const std::tuple<LD,U...>& layer_det,
+ T&& ...args
+ ) : add_layer(layer_det,args...) { }
+
+ add_layer (
+ std::tuple<>
+ ) : add_layer() {}
+
+ template <typename ...T>
+ add_layer(
+ std::tuple<>,
+ LAYER_DETAILS&& layer_det,
+ T&& ...args
+ ) : add_layer(layer_det, args...) { }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ subnetwork->to_tensor(ibegin,iend,data);
+ }
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return forward(temp_tensor);
+ }
+
+
+ const tensor& operator() (const input_type& x)
+ {
+ return (*this)(&x, &x+1);
+ }
+
+ const tensor& forward(const tensor& x)
+ {
+ subnetwork->forward(x);
+ const dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
+ if (!this_layer_setup_called)
+ {
+ details.setup(wsub);
+ this_layer_setup_called = true;
+ }
+ if (this_layer_operates_inplace())
+ impl::call_layer_forward(details, wsub, private_get_output());
+ else
+ impl::call_layer_forward(details, wsub, cached_output);
+
+ gradient_input_is_stale = true;
+ return private_get_output();
+ }
+
+ private:
+ tensor& private_get_output() const
+ {
+ if (const_cast<add_layer&>(*this).this_layer_operates_inplace())
+ return subnetwork->private_get_output();
+ else
+ return const_cast<resizable_tensor&>(cached_output);
+ }
+ tensor& private_get_gradient_input()
+ {
+ if (this_layer_operates_inplace())
+ {
+ return subnetwork->private_get_gradient_input();
+ }
+ else
+ {
+ if (gradient_input_is_stale)
+ {
+ gradient_input_is_stale = false;
+ x_grad.copy_size(private_get_output());
+ x_grad = 0;
+ }
+ return x_grad;
+ }
+ }
+ void disable_output_and_gradient_getters (
+ ) { get_output_and_gradient_input_disabled = true; }
+ public:
+ const tensor& get_output() const
+ {
+ if (get_output_and_gradient_input_disabled)
+ throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it.");
+ return private_get_output();
+ }
+ tensor& get_gradient_input()
+ {
+ if (get_output_and_gradient_input_disabled)
+ throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it.");
+ return private_get_gradient_input();
+ }
+
+ const tensor& get_final_data_gradient(
+ ) const { return subnetwork->get_final_data_gradient(); }
+
+ void back_propagate_error(const tensor& x)
+ {
+ back_propagate_error(x, private_get_gradient_input());
+ }
+ void back_propagate_error(const tensor& x, const tensor& gradient_input)
+ {
+ dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
+ params_grad.copy_size(details.get_layer_params());
+ impl::call_layer_backward(details, private_get_output(),
+ gradient_input, wsub, static_cast<tensor&>(params_grad));
+
+ subnetwork->back_propagate_error(x);
+
+ // zero out get_gradient_input()
+ gradient_input_is_stale = true;
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> solvers, double learning_rate)
+ {
+ DLIB_CASSERT(solvers.size()>=num_computational_layers);
+ // Don't try to adjust the parameters if this layer doesn't have any or the
+ // learning rate is disabled for this layer.
+ if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0)
+ {
+ const tensor& step = solvers.top()(learning_rate, details, static_cast<const tensor&>(params_grad));
+ tt::add(details.get_layer_params(), details.get_layer_params(), step);
+ }
+ subnetwork->update_parameters(solvers.pop(), learning_rate);
+ }
+
+ const tensor& get_parameter_gradient(
+ ) const { return params_grad; }
+
+ tensor& get_parameter_gradient (
+ ) { return params_grad; }
+
+ const subnet_type& subnet() const { return *subnetwork; }
+ subnet_type& subnet() { return *subnetwork; }
+
+ const layer_details_type& layer_details() const { return details; }
+ layer_details_type& layer_details() { return details; }
+
+ unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
+
+ void clean()
+ {
+ x_grad.clear();
+ cached_output.clear();
+ params_grad.clear();
+ temp_tensor.clear();
+ gradient_input_is_stale = true;
+ subnetwork->clean();
+ call_clean_method_if_exists(details);
+ }
+
+ friend void serialize(const add_layer& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(*item.subnetwork, out);
+ serialize(item.details, out);
+ serialize(item.this_layer_setup_called, out);
+ serialize(item.gradient_input_is_stale, out);
+ serialize(item.get_output_and_gradient_input_disabled, out);
+ serialize(item.x_grad, out);
+ serialize(item.cached_output, out);
+ serialize(item.params_grad, out);
+ }
+
+ friend void deserialize(add_layer& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (!(1 <= version && version <= 2))
+ throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
+ deserialize(*item.subnetwork, in);
+ deserialize(item.details, in);
+ deserialize(item.this_layer_setup_called, in);
+ deserialize(item.gradient_input_is_stale, in);
+ deserialize(item.get_output_and_gradient_input_disabled, in);
+ deserialize(item.x_grad, in);
+ deserialize(item.cached_output, in);
+ if (version == 2)
+ deserialize(item.params_grad, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const add_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n";
+ subnet().print(out, idx+1, min_length);
+ }
+
+ private:
+
+ bool this_layer_operates_inplace(
+ )
+ {
+ // This layer can run in-place if it's an in-place capable layer and also if
+ // the layer it's on top of doesn't need its own output tensor (since in-place
+ // layers overwrite that tensor)
+ return impl::is_inplace_layer(details, *subnetwork) && !subnetwork->this_layer_requires_forward_output();
+ }
+ bool this_layer_requires_forward_output(
+ )
+ {
+ return impl::backward_requires_forward_output(details, *subnetwork);
+ }
+
+ void swap(add_layer& item)
+ {
+ std::swap(subnetwork,item.subnetwork);
+ std::swap(details, item.details);
+ std::swap(this_layer_setup_called, item.this_layer_setup_called);
+ std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
+ std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled);
+ std::swap(x_grad, item.x_grad);
+ std::swap(cached_output, item.cached_output);
+ std::swap(params_grad, item.params_grad);
+ }
+
+
+ LAYER_DETAILS details;
+ std::unique_ptr<subnet_type> subnetwork;
+ bool this_layer_setup_called;
+ bool gradient_input_is_stale;
+ bool get_output_and_gradient_input_disabled;
+ // Note that if this_layer_operates_inplace()==true then x_grad and cached_output
+ // are not used at all. Instead, this layer uses these variables from the lower
+ // layer.
+ resizable_tensor x_grad;
+ resizable_tensor cached_output;
+
+ resizable_tensor params_grad;
+
+ // temp_tensor doesn't logically contribute to the state of this object.
+ // It is here only to prevent it from being reallocated over and over.
+ resizable_tensor temp_tensor;
+
+ };
+
+ template <typename T, typename U, typename E>
+ struct is_add_layer<add_layer<T,U,E>> : std::true_type {};
+ template <typename T, typename U, typename E>
+ struct is_add_layer<const add_layer<T,U,E>> : std::true_type {};
+ template <typename T, typename U, typename E>
+ struct is_add_layer<add_layer<T,U,E>&> : std::true_type {};
+ template <typename T, typename U, typename E>
+ struct is_add_layer<const add_layer<T,U,E>&> : std::true_type {};
+
+// ----------------------------------------------------------------------------------------
+
+// This version of add_layer handles the special case where the subnetwork being given is
+// just an input layer object.
+ template <typename LAYER_DETAILS, typename INPUT_LAYER, typename enabled>
+ class add_layer
+ {
+ public:
+ typedef LAYER_DETAILS layer_details_type;
+ typedef INPUT_LAYER subnet_type;
+ typedef typename INPUT_LAYER::input_type input_type;
+ const static size_t num_layers = 2;
+ const static size_t num_computational_layers = 1;
+
+ add_layer(
+ ):
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(0)
+ {}
+
+ add_layer(const add_layer&) = default;
+ add_layer(add_layer&& item) : add_layer() { swap(item); }
+ add_layer& operator=(const add_layer&) = default;
+ add_layer& operator=(add_layer&& item) { swap(item); return *this; }
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ // Allow copying networks from one to another as long as their corresponding
+ // layers can be constructed from each other.
+ template <typename T, typename U, typename E>
+ add_layer(
+ const add_layer<T,U,E>& item
+ ):
+ input_layer(item.subnet()),
+ details(item.layer_details()),
+ this_layer_setup_called(item.this_layer_setup_called),
+ gradient_input_is_stale(item.gradient_input_is_stale),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(item._sample_expansion_factor),
+ x_grad(item.x_grad),
+ cached_output(item.cached_output),
+ grad_final(item.grad_final)
+ {
+ }
+
+ add_layer(
+ const LAYER_DETAILS& layer_det
+ ) :
+ details(layer_det),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(0)
+ {}
+
+ add_layer(
+ const INPUT_LAYER& il
+ ) :
+ input_layer(il),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(0)
+ {}
+
+ add_layer(
+ LAYER_DETAILS&& layer_det
+ ) :
+ details(std::move(layer_det)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(0)
+ {}
+
+ add_layer(
+ LAYER_DETAILS layer_det,
+ INPUT_LAYER il
+ ) :
+ details(std::move(layer_det)),
+ input_layer(std::move(il)),
+ this_layer_setup_called(false),
+ gradient_input_is_stale(true),
+ get_output_and_gradient_input_disabled(false),
+ _sample_expansion_factor(0)
+ {}
+
+ add_layer(
+ std::tuple<>,
+ const LAYER_DETAILS& layer_det
+ ) : add_layer(layer_det) {}
+
+ add_layer(
+ std::tuple<>,
+ LAYER_DETAILS&& layer_det
+ ) : add_layer(layer_det) {}
+
+ add_layer(
+ std::tuple<>,
+ LAYER_DETAILS layer_det,
+ INPUT_LAYER il
+ ) : add_layer(layer_det,il) {}
+
+ add_layer(
+ const std::tuple<LAYER_DETAILS>& layer_det
+ ) : add_layer(tuple_head(layer_det)) {}
+
+ add_layer(
+ const std::tuple<LAYER_DETAILS>& layer_det,
+ INPUT_LAYER il
+ ) : add_layer(tuple_head(layer_det),il) {}
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ input_layer.to_tensor(ibegin, iend, data);
+ // make sure the input layer's to_tensor() function is implemented properly.
+ DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
+ "The input layer can't produce fewer output tensors than there are inputs.");
+ DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0,
+ "The number of tensors produced by the input layer must be an integer multiple of the number of input objects.");
+
+ _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend);
+ data.async_copy_to_device();
+ }
+
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return forward(temp_tensor);
+ }
+
+
+ const tensor& operator() (const input_type& x)
+ {
+ return (*this)(&x, &x+1);
+ }
+
+ const tensor& forward (const tensor& x)
+ {
+ DLIB_CASSERT(sample_expansion_factor() != 0, "You must call to_tensor() before this function can be used.");
+ DLIB_CASSERT(x.num_samples()%sample_expansion_factor() == 0);
+ subnet_wrapper wsub(x, grad_final, _sample_expansion_factor);
+ if (!this_layer_setup_called)
+ {
+ details.setup(wsub);
+ this_layer_setup_called = true;
+ }
+ impl::call_layer_forward(details, wsub, cached_output);
+ gradient_input_is_stale = true;
+ return private_get_output();
+ }
+
+ private:
+ tensor& private_get_output() const { return const_cast<resizable_tensor&>(cached_output); }
+ tensor& private_get_gradient_input()
+ {
+ if (gradient_input_is_stale)
+ {
+ gradient_input_is_stale = false;
+ x_grad.copy_size(private_get_output());
+ x_grad = 0;
+ }
+ return x_grad;
+ }
+ void disable_output_and_gradient_getters (
+ ) { get_output_and_gradient_input_disabled = true; }
+ public:
+ const tensor& get_output() const
+ {
+ if (get_output_and_gradient_input_disabled)
+ throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it.");
+ return private_get_output();
+ }
+ tensor& get_gradient_input()
+ {
+ if (get_output_and_gradient_input_disabled)
+ throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it.");
+ return private_get_gradient_input();
+ }
+
+ const tensor& get_final_data_gradient(
+ ) const { return grad_final; }
+
+ void back_propagate_error(const tensor& x)
+ {
+ back_propagate_error(x, private_get_gradient_input());
+ }
+ void back_propagate_error(const tensor& x, const tensor& gradient_input)
+ {
+ // make sure grad_final is initialized to 0
+ if (!have_same_dimensions(x, grad_final))
+ grad_final.copy_size(x);
+ grad_final = 0;
+
+ subnet_wrapper wsub(x, grad_final, _sample_expansion_factor);
+ params_grad.copy_size(details.get_layer_params());
+ impl::call_layer_backward(details, private_get_output(),
+ gradient_input, wsub, static_cast<tensor&>(params_grad));
+
+ // zero out get_gradient_input()
+ gradient_input_is_stale = true;
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> solvers, double learning_rate)
+ {
+ DLIB_CASSERT(solvers.size()>=num_computational_layers);
+ // Don't try to adjust the parameters if this layer doesn't have any or the
+ // learning rate is disabled for this layer.
+ if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0)
+ {
+ const tensor& step = solvers.top()(learning_rate, details, static_cast<const tensor&>(params_grad));
+ tt::add(details.get_layer_params(), details.get_layer_params(), step);
+ }
+ }
+
+ const tensor& get_parameter_gradient(
+ ) const { return params_grad; }
+
+ tensor& get_parameter_gradient (
+ ) { return params_grad; }
+
+ const subnet_type& subnet() const { return input_layer; }
+ subnet_type& subnet() { return input_layer; }
+
+ const layer_details_type& layer_details() const { return details; }
+ layer_details_type& layer_details() { return details; }
+
+ unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
+
+ void clean()
+ {
+ x_grad.clear();
+ grad_final.clear();
+ cached_output.clear();
+ params_grad.clear();
+ temp_tensor.clear();
+ gradient_input_is_stale = true;
+ call_clean_method_if_exists(details);
+ }
+
+ friend void serialize(const add_layer& item, std::ostream& out)
+ {
+ int version = 3;
+ serialize(version, out);
+ serialize(item.input_layer, out);
+ serialize(item.details, out);
+ serialize(item.this_layer_setup_called, out);
+ serialize(item.gradient_input_is_stale, out);
+ serialize(item.get_output_and_gradient_input_disabled, out);
+ serialize(item.x_grad, out);
+ serialize(item.cached_output, out);
+ serialize(item.grad_final, out);
+ serialize(item._sample_expansion_factor, out);
+ }
+
+ friend void deserialize(add_layer& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (!(2 <= version && version <= 3))
+ throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
+ deserialize(item.input_layer, in);
+ deserialize(item.details, in);
+ deserialize(item.this_layer_setup_called, in);
+ deserialize(item.gradient_input_is_stale, in);
+ deserialize(item.get_output_and_gradient_input_disabled, in);
+ deserialize(item.x_grad, in);
+ deserialize(item.cached_output, in);
+ deserialize(item.grad_final, in);
+ if (version >= 3)
+ deserialize(item._sample_expansion_factor, in);
+ else
+ item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here.
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const add_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n";
+
+ // Don't print the repeat_input_layer since it doesn't exist from the user's
+ // point of view. It's just an artifact of how repeat<> works.
+ if (!std::is_same<subnet_type, impl::repeat_input_layer>::value)
+ out << "layer<" << idx+1 << ">\t" << subnet() << "\n";
+ }
+
+ private:
+
+ bool this_layer_requires_forward_output(
+ )
+ {
+ subnet_wrapper wsub(grad_final, grad_final, _sample_expansion_factor);
+ return impl::backward_requires_forward_output(details, wsub);
+ }
+
+ class subnet_wrapper
+ {
+ public:
+ subnet_wrapper(const tensor& x_, resizable_tensor& grad_final_, unsigned int sef) :
+ x(x_), grad_final(grad_final_), _sample_expansion_factor(sef) {}
+
+ subnet_wrapper(const subnet_wrapper&) = delete;
+ subnet_wrapper& operator=(const subnet_wrapper&) = delete;
+
+ unsigned int sample_expansion_factor() const { return _sample_expansion_factor;}
+ const tensor& get_output() const { return x; }
+ tensor& get_gradient_input()
+ {
+ if (!have_same_dimensions(x, grad_final))
+ {
+ grad_final.copy_size(x);
+ grad_final = 0;
+ }
+ return grad_final;
+ }
+
+ private:
+ const tensor& x;
+ resizable_tensor& grad_final;
+ unsigned int _sample_expansion_factor;
+ };
+
+ void swap(add_layer& item)
+ {
+ std::swap(input_layer, item.input_layer);
+ std::swap(details, item.details);
+ std::swap(this_layer_setup_called, item.this_layer_setup_called);
+ std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
+ std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled);
+ std::swap(x_grad, item.x_grad);
+ std::swap(cached_output, item.cached_output);
+ std::swap(grad_final, item.grad_final);
+ std::swap(_sample_expansion_factor, item._sample_expansion_factor);
+ }
+
+ subnet_type input_layer;
+ LAYER_DETAILS details;
+ bool this_layer_setup_called;
+ bool gradient_input_is_stale;
+ bool get_output_and_gradient_input_disabled;
+ mutable unsigned int _sample_expansion_factor;
+ resizable_tensor x_grad;
+ resizable_tensor cached_output;
+ resizable_tensor grad_final;
+
+ // The following 2 objects don't logically contribute to the state of this class.
+ // They are only here to prevent them from being reallocated over and over in
+ // member functions.
+ resizable_tensor params_grad;
+ resizable_tensor temp_tensor;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned long ID, typename SUBNET, typename enabled=void>
+ class add_tag_layer;
+
+ template <template<typename SUBNET> class tag>
+ struct tag_id
+ {
+ const static unsigned long id = tag<impl::repeat_input_layer>::id;
+ };
+
+ template <unsigned long ID, typename SUBNET>
+ class add_tag_layer<ID,SUBNET,
+ typename std::enable_if<is_nonloss_layer_type<SUBNET>::value>::type>
+ {
+ public:
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
+ const static size_t num_layers = subnet_type::num_layers + 1;
+ const static size_t num_computational_layers = subnet_type::num_computational_layers;
+ const static unsigned long id = ID;
+
+ add_tag_layer() {};
+ add_tag_layer(const add_tag_layer&) = default;
+ add_tag_layer(add_tag_layer&&) = default;
+ add_tag_layer& operator=(add_tag_layer&&) = default;
+ add_tag_layer& operator=(const add_tag_layer&) = default;
+
+ template <typename T>
+ add_tag_layer(
+ const add_tag_layer<ID,T>& item
+ ) : subnetwork(item.subnet())
+ {}
+
+ template <typename ...T>
+ add_tag_layer(
+ T ...args
+ ) :
+ subnetwork(std::move(args)...)
+ {
+ }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ subnetwork.to_tensor(ibegin,iend,data);
+ }
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ return subnetwork(ibegin,iend);
+ }
+
+ const tensor& operator() (const input_type& x)
+ {
+ return subnetwork(x);
+ }
+
+ const tensor& forward(const tensor& x)
+ {
+ return subnetwork.forward(x);
+ }
+
+ const tensor& get_output() const { return subnetwork.get_output(); }
+
+ tensor& get_gradient_input()
+ {
+ return subnetwork.get_gradient_input();
+ }
+
+ const tensor& get_final_data_gradient(
+ ) const { return subnetwork.get_final_data_gradient(); }
+
+ void back_propagate_error(const tensor& x)
+ {
+ subnetwork.back_propagate_error(x);
+ }
+ void back_propagate_error(const tensor& x, const tensor& gradient_input)
+ {
+ subnetwork.back_propagate_error(x,gradient_input);
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> solvers, double learning_rate)
+ {
+ subnetwork.update_parameters(solvers, learning_rate);
+ }
+
+ const tensor& get_parameter_gradient(
+ ) const { return params_grad; }
+
+ tensor& get_parameter_gradient (
+ ) { return params_grad; }
+
+ const subnet_type& subnet() const { return subnetwork; }
+ subnet_type& subnet() { return subnetwork; }
+
+ unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
+
+ void clean()
+ {
+ subnetwork.clean();
+ }
+
+ friend void serialize(const add_tag_layer& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.subnetwork, out);
+ }
+
+ friend void deserialize(add_tag_layer& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
+ deserialize(item.subnetwork, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << "tag" << ID << "\n";
+ subnet().print(out, idx+1, min_length);
+ }
+
+ private:
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ // You wouldn't put a tag on a layer if you didn't want to access its forward
+ // outputs. So this is always true.
+ bool this_layer_requires_forward_output(
+ ) { return true; }
+
+ void disable_output_and_gradient_getters (
+ )
+ {
+ // This should never happen because only inplace layers call
+ // disable_output_and_gradient_getters(), however, putting a tag layer right
+ // before an inplace layer basically means you don't want the following layer
+ // to operate in place. So the inplace layer should turn itself into an
+ // out-of-place layer and not call disable_output_and_gradient_getters().
+ DLIB_CASSERT(false,"This should never happen");
+ }
+
+ tensor& private_get_output() const
+ { return subnetwork.private_get_output(); }
+ tensor& private_get_gradient_input()
+ { return subnetwork.private_get_gradient_input(); }
+
+ subnet_type subnetwork;
+
+ // This member doesn't logically contribute to the state of the object since it is
+ // always empty. It's just here so we can have the get_parameter_gradient() methods
+ // which have to return something. So they return this empty tensor.
+ resizable_tensor params_grad;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename ...T>
+ struct decorator_repeat_group
+ {
+ decorator_repeat_group(
+ T&& ...args
+ ) : data(std::forward<T>(args)...) {}
+
+ std::tuple<T...> data;
+ };
+ template <typename ...T>
+ decorator_repeat_group<T...> repeat_group (
+ T&& ...args
+ )
+ {
+ return decorator_repeat_group<T...>(std::forward<T>(args)...);
+ }
+
+ template <
+ size_t num,
+ template<typename> class REPEATED_LAYER,
+ typename SUBNET
+ >
+ class repeat
+ {
+ static_assert(num > 0, "You can't have a layer repeated 0 times.");
+ public:
+ typedef SUBNET subnet_type;
+ typedef typename SUBNET::input_type input_type;
+ typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
+ const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers);
+ const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num;
+ const static size_t num_computational_layers = comp_layers_in_repeated_group + SUBNET::num_computational_layers;
+
+ const static size_t layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_layers-SUBNET::num_layers);
+ const static size_t layers_in_repeated_group = layers_in_each_group*num;
+ const static size_t num_layers = subnet_type::num_layers + layers_in_repeated_group;
+
+
+ typedef REPEATED_LAYER<impl::repeat_input_layer> repeated_layer_type;
+
+ repeat(
+ ) :
+ details(num)
+ {
+ }
+
+ size_t num_repetitions (
+ ) const { return num; }
+
+ const repeated_layer_type& get_repeated_layer (
+ size_t i
+ ) const
+ {
+ DLIB_CASSERT(i < num_repetitions());
+ return details[i];
+ }
+
+ repeated_layer_type& get_repeated_layer (
+ size_t i
+ )
+ {
+ DLIB_CASSERT(i < num_repetitions());
+ return details[i];
+ }
+
+ repeat(const repeat&) = default;
+ repeat(repeat&&) = default;
+ repeat& operator=(repeat&&) = default;
+ repeat& operator=(const repeat&) = default;
+
+ template <template<typename> class T, typename U>
+ repeat(
+ const repeat<num,T,U>& item
+ ) :
+ subnetwork(item.subnetwork)
+ {
+ for (auto&& d : item.details)
+ details.emplace_back(d);
+ }
+
+ template <typename T, typename ...U>
+ repeat(
+ T arg1,
+ U ...args2
+ ):
+ details(num, std::move(arg1)),
+ subnetwork(std::move(args2)...)
+ {
+ }
+
+ template <typename ...T, typename ...U>
+ repeat(
+ decorator_repeat_group<T...>&& arg1,
+ U ...args2
+ ):
+ details(num, arg1.data),
+ subnetwork(std::move(args2)...)
+ {
+ }
+
+ template <typename T, typename ...U>
+ repeat(
+ std::tuple<>,
+ T arg1,
+ U ...args2
+ ):
+ details(num, std::move(arg1)),
+ subnetwork(std::move(args2)...)
+ {
+ }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ subnetwork.to_tensor(ibegin,iend,data);
+ // call to_tensor on the networks in details just to populate the
+ // _sample_expansion_factor values in those networks. Other than that this
+ // call is a noop.
+ for (auto& d : details)
+ d.to_tensor(ibegin, iend, data);
+ }
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return forward(temp_tensor);
+ }
+
+ const tensor& operator() (const input_type& x)
+ {
+ return (*this)(&x, &x+1);
+ }
+
+ const tensor& forward(const tensor& x)
+ {
+ subnetwork.forward(x);
+ details[details.size()-1].forward(subnetwork.get_output());
+ for (long i = details.size()-2; i >= 0; --i)
+ details[i].forward(details[i+1].get_output());
+ return private_get_output();
+ }
+
+ private:
+ tensor& private_get_output() const
+ {
+ return details[0].private_get_output();
+ }
+ tensor& private_get_gradient_input()
+ {
+ return details[0].private_get_gradient_input();
+ }
+ public:
+ const tensor& get_output() const
+ {
+ return details[0].get_output();
+ }
+ tensor& get_gradient_input()
+ {
+ return details[0].get_gradient_input();
+ }
+
+ const tensor& get_parameter_gradient(
+ ) const { return details[0].get_parameter_gradient(); }
+
+ tensor& get_parameter_gradient (
+ ) { return details[0].get_parameter_gradient(); }
+
+ void back_propagate_error(const tensor& x)
+ {
+ back_propagate_error(x, private_get_gradient_input());
+ }
+ void back_propagate_error(const tensor& x, const tensor& gradient_input)
+ {
+ if (details.size() > 1)
+ {
+ details[0].back_propagate_error(details[1].get_output(), gradient_input);
+ for (size_t i = 1; i < details.size(); ++i)
+ {
+ if (i+1 < details.size())
+ details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient());
+ else
+ details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient());
+ }
+ }
+ else
+ {
+ details[0].back_propagate_error(subnetwork.get_output(), gradient_input);
+ }
+ subnetwork.back_propagate_error(x, details.back().get_final_data_gradient());
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> solvers, double learning_rate)
+ {
+ for (size_t i = 0; i < details.size(); ++i)
+ details[i].update_parameters(solvers.pop(comp_layers_in_each_group*i),learning_rate);
+ subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate);
+ }
+
+ const subnet_type& subnet() const { return subnetwork; }
+ subnet_type& subnet() { return subnetwork; }
+
+ unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
+
+ void clean()
+ {
+ temp_tensor.clear();
+ subnetwork.clean();
+ for (auto&& d : details)
+ d.clean();
+ }
+
+ friend void serialize(const repeat& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.details, out);
+ serialize(item.subnetwork, out);
+ }
+
+ friend void deserialize(repeat& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::repeat.");
+ deserialize(item.details, in);
+ deserialize(item.subnetwork, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const repeat& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ for (size_t i = 0; i < num_repetitions(); ++i)
+ {
+ get_repeated_layer(i).print(out, idx, min_length);
+ idx += layers_in_each_group;
+ }
+ subnet().print(out, idx, min_length);
+ }
+ private:
+
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ bool this_layer_requires_forward_output(
+ )
+ {
+ return details[0].this_layer_requires_forward_output();
+ }
+
+ void disable_output_and_gradient_getters (
+ )
+ {
+ details[0].disable_output_and_gradient_getters();
+ }
+
+
+ std::vector<repeated_layer_type> details;
+ subnet_type subnetwork;
+
+ // temp_tensor doesn't logically contribute to the state of this class.
+ // It is here only to void needing to reallocate it over and over.
+ resizable_tensor temp_tensor;
+ };
+
+ template <
+ size_t num,
+ template<typename> class REPEATED_LAYER,
+ typename SUBNET
+ >
+ struct is_nonloss_layer_type<repeat<num,REPEATED_LAYER,SUBNET>> : std::true_type {};
+
+// ----------------------------------------------------------------------------------------
+
+// This version of add_tag_layer handles the special case where the subnetwork being given
+// is just an input layer object.
+ template <unsigned long ID, typename INPUT_LAYER, typename enabled>
+ class add_tag_layer
+ {
+ public:
+ typedef INPUT_LAYER subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
+ const static size_t num_computational_layers = 0;
+ const static size_t num_layers = 2;
+ const static unsigned long id = ID;
+
+ add_tag_layer():cached_output_ptr(nullptr),gradient_input_is_stale(true),_sample_expansion_factor(0) {}
+
+ add_tag_layer(const add_tag_layer&) = default;
+ add_tag_layer& operator=(const add_tag_layer&) = default;
+ add_tag_layer(add_tag_layer&& item) : add_tag_layer() { swap(item); }
+ add_tag_layer& operator=(add_tag_layer&& item) { swap(item); return *this; }
+
+ template <typename T, typename E>
+ add_tag_layer(
+ const add_tag_layer<ID,T,E>& item
+ ) : input_layer(item.subnet()),
+ cached_output(item.cached_output),
+ cached_output_ptr(nullptr),
+ grad_final(item.grad_final),
+ gradient_input_is_stale(item.gradient_input_is_stale),
+ _sample_expansion_factor(0)
+ {}
+
+ template <typename ...T>
+ add_tag_layer(
+ T ...args
+ ) :
+ input_layer(std::move(args)...),
+ cached_output_ptr(nullptr),
+ gradient_input_is_stale(true),
+ _sample_expansion_factor(0)
+ {
+ }
+
+ add_tag_layer (
+ std::tuple<>
+ ) :
+ cached_output_ptr(nullptr),
+ gradient_input_is_stale(true),
+ _sample_expansion_factor(0)
+ {}
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ input_layer.to_tensor(ibegin,iend,data);
+
+ // make sure the input layer's to_tensor() function is implemented properly.
+ DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
+ "The input layer can't produce fewer output tensors than there are inputs.");
+ DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0,
+ "The number of tensors produced by the input layer must be an integer multiple of the number of input objects.");
+
+ _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend);
+ data.async_copy_to_device();
+ }
+
+ unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ input_layer.to_tensor(ibegin,iend,cached_output);
+ cached_output_ptr = nullptr;
+ return get_output();
+ }
+
+ const tensor& operator() (const input_type& x)
+ {
+ return (*this)(&x, &x+1);
+ }
+
+ const tensor& forward(const tensor& x)
+ {
+ // If this tag is the first layer in one of the sub networks inside a repeat
+ // layer then we don't want it to be creating copies of x. This is because, we
+ // can just hold a pointer to x since the way repeat is constructed guarantees
+ // that x will have a lifetime larger than this pointer.
+ if (is_same_type<INPUT_LAYER, impl::repeat_input_layer>::value)
+ cached_output_ptr = const_cast<tensor*>(&x);
+ else
+ cached_output = x;
+ gradient_input_is_stale = true;
+ return get_output();
+ }
+
+ const tensor& get_output() const
+ {
+ if (cached_output_ptr)
+ return *cached_output_ptr;
+ else
+ return cached_output;
+ }
+
+ const tensor& get_final_data_gradient(
+ ) const { return grad_final; }
+
+ tensor& get_gradient_input()
+ {
+ if (!have_same_dimensions(get_output(), grad_final) ||
+ gradient_input_is_stale)
+ {
+ grad_final.copy_size(get_output());
+ grad_final = 0;
+ gradient_input_is_stale = false;
+ }
+ return grad_final;
+ }
+
+ void back_propagate_error(const tensor& /*x*/)
+ {
+ // nothing to do
+ }
+ void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/)
+ {
+ // nothing to do
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> /*solvers*/, double /*learning_rate*/)
+ {
+ // nothing to do
+ }
+
+ const subnet_type& subnet() const { return input_layer; }
+ subnet_type& subnet() { return input_layer; }
+
+ void clean()
+ {
+ grad_final.clear();
+ cached_output.clear();
+ cached_output_ptr = 0;
+ }
+
+ friend void serialize(const add_tag_layer& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.input_layer, out);
+ serialize(item.cached_output, out);
+ serialize(item.grad_final, out);
+ serialize(item.gradient_input_is_stale, out);
+ serialize(item._sample_expansion_factor, out);
+ }
+
+ friend void deserialize(add_tag_layer& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (!(1 <= version && version <= 2))
+ throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
+ deserialize(item.input_layer, in);
+ deserialize(item.cached_output, in);
+ deserialize(item.grad_final, in);
+ deserialize(item.gradient_input_is_stale, in);
+ item.cached_output_ptr = nullptr;
+ if (version >= 2)
+ deserialize(item._sample_expansion_factor, in);
+ else
+ item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here.
+
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<"<<idx << ">\t"<<impl::tensor_to_str(private_get_output(), min_length)<< "tag" << ID << "\n";
+ // Don't print the repeat_input_layer since it doesn't exist from the user's
+ // point of view. It's just an artifact of how repeat<> works.
+ if (!std::is_same<subnet_type, impl::repeat_input_layer>::value)
+ out << "layer<"<< idx+1 << ">\t" << subnet() << "\n";
+ }
+
+ private:
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ // You woudln't put a tag on a layer if you didn't want to access its forward
+ // outputs. So this is always true.
+ bool this_layer_requires_forward_output(
+ ) { return true; }
+
+ void disable_output_and_gradient_getters (
+ )
+ {
+ // This should never happen because only inplace layers call
+ // disable_output_and_gradient_getters(), however, putting a tag layer right
+ // before an inplace layer basically means you don't want the following layer
+ // to operate in place. So the inplace layer should turn itself into an
+ // out-of-place layer and not call disable_output_and_gradient_getters().
+ DLIB_CASSERT(false,"This should never happen");
+ }
+
+ tensor& private_get_output() const
+ { return const_cast<tensor&>(get_output()); }
+ tensor& private_get_gradient_input()
+ { return get_gradient_input(); }
+
+ void swap(add_tag_layer& item)
+ {
+ std::swap(input_layer, item.input_layer);
+ std::swap(cached_output, item.cached_output);
+ std::swap(cached_output_ptr, item.cached_output_ptr);
+ std::swap(grad_final, item.grad_final);
+ std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
+ std::swap(_sample_expansion_factor, item._sample_expansion_factor);
+ }
+
+ subnet_type input_layer;
+ resizable_tensor cached_output;
+ tensor* cached_output_ptr;
+ resizable_tensor grad_final;
+ bool gradient_input_is_stale;
+ mutable unsigned int _sample_expansion_factor;
+ };
+
+ template <unsigned long ID, typename U, typename E>
+ struct is_nonloss_layer_type<add_tag_layer<ID,U,E>> : std::true_type {};
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename LOSS_DETAILS, typename SUBNET>
+ class add_loss_layer;
+
+ class no_label_type
+ {
+ private:
+ // We don't want anyone making these no_label_type objects. They are here only to
+ // allow add_loss_layer::training_label_type and dnn_trainer::training_label_type
+ // to exist which avoids needing to overload add_loss_layer and dnn_trainer for
+ // supervised an unsupervised losses. It also can be a type to use in template
+ // metaprogramming to indicate "no label". So here we make the constructor private
+ // with the exception that add_loss_layer objects can make it (again, just to
+ // simplify add_loss_layer's implementation).
+ no_label_type(){};
+ template <typename LOSS_DETAILS, typename SUBNET> friend class add_loss_layer;
+ template < typename net_type, typename solver_type > friend class dnn_trainer;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename LOSS_DETAILS, typename SUBNET>
+ class add_loss_layer
+ {
+ template <typename T, typename enabled=void>
+ struct get_loss_layer_training_label_type
+ {
+ typedef no_label_type type;
+ };
+ template <typename T>
+ struct get_loss_layer_training_label_type<T,typename std::enable_if<sizeof(typename T::training_label_type)!=0>::type>
+ {
+ typedef typename T::training_label_type type;
+ };
+
+ template <typename T, typename enabled=void>
+ struct get_loss_layer_output_label_type
+ {
+ typedef no_label_type type;
+ };
+ template <typename T>
+ struct get_loss_layer_output_label_type<T,typename std::enable_if<sizeof(typename T::output_label_type)!=0>::type>
+ {
+ typedef typename T::output_label_type type;
+ };
+
+ public:
+ typedef LOSS_DETAILS loss_details_type;
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ const static size_t num_layers = subnet_type::num_layers + 1;
+ // Note that the loss layer doesn't count as an additional computational layer.
+ const static size_t num_computational_layers = subnet_type::num_computational_layers;
+ typedef typename get_loss_layer_training_label_type<LOSS_DETAILS>::type training_label_type;
+ typedef typename get_loss_layer_output_label_type<LOSS_DETAILS>::type output_label_type;
+
+ static_assert(is_nonloss_layer_type<SUBNET>::value,
+ "SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer.");
+
+
+ add_loss_layer() {};
+ add_loss_layer(const add_loss_layer&) = default;
+ add_loss_layer& operator=(const add_loss_layer&) = default;
+ add_loss_layer(add_loss_layer&& item) : add_loss_layer() { swap(item); }
+ add_loss_layer& operator=(add_loss_layer&& item) { swap(item); return *this; }
+
+ template <typename T, typename U>
+ add_loss_layer(
+ const add_loss_layer<T,U>& item
+ ) :
+ loss(item.loss_details()),
+ subnetwork(item.subnet())
+ {}
+
+ template <typename ...T>
+ add_loss_layer(
+ const LOSS_DETAILS& layer_det,
+ T&& ...args
+ ) :
+ loss(layer_det),
+ subnetwork(std::forward<T>(args)...)
+ {
+ }
+
+ template <typename ...T>
+ add_loss_layer(
+ LOSS_DETAILS&& layer_det,
+ T&& ...args
+ ) :
+ loss(std::move(layer_det)),
+ subnetwork(std::forward<T>(args)...)
+ {
+ }
+
+ template <typename T, typename ...U>
+ struct disable_forwarding_constr
+ {
+ const static bool value = std::is_constructible<LOSS_DETAILS,T>::value;
+ };
+ template <typename ...T>
+ struct disable_forwarding_constr<add_loss_layer<T...>>
+ {
+ const static bool value = true;
+ };
+
+ template <
+ typename ...T,
+ typename = typename std::enable_if<!disable_forwarding_constr<typename std::remove_reference<T>::type...>::value>::type
+ >
+ add_loss_layer(
+ T&& ...args
+ ) :
+ subnetwork(std::forward<T>(args)...)
+ {
+ }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ subnetwork.to_tensor(ibegin,iend,data);
+ }
+
+ unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
+
+ template <typename output_iterator>
+ void operator() (
+ const tensor& x,
+ output_iterator obegin
+ )
+ {
+ subnetwork.forward(x);
+ const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ loss.to_label(x, wsub, obegin);
+ }
+
+ template <typename forward_iterator, typename output_iterator>
+ void operator() (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ output_iterator obegin
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ (*this)(temp_tensor, obegin);
+ }
+
+ const output_label_type& operator() (const input_type& x)
+ {
+ (*this)(&x, &x+1, &temp_label);
+ return temp_label;
+ }
+
+ template <typename ...T>
+ const output_label_type& process (const input_type& x, T&& ...args)
+ {
+ to_tensor(&x,&x+1,temp_tensor);
+ subnetwork.forward(temp_tensor);
+ const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ loss.to_label(temp_tensor, wsub, &temp_label, std::forward<T>(args)...);
+ return temp_label;
+ }
+
+ template <typename iterable_type, typename ...T>
+ std::vector<output_label_type> process_batch (const iterable_type& data, size_t batch_size, T&& ...args)
+ {
+ std::vector<output_label_type> results(std::distance(data.begin(), data.end()));
+ auto o = results.begin();
+ auto i = data.begin();
+ auto num_remaining = results.size();
+ while(num_remaining != 0)
+ {
+ auto inc = std::min(batch_size, num_remaining);
+ to_tensor(i,i+inc,temp_tensor);
+ subnetwork.forward(temp_tensor);
+ const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ loss.to_label(temp_tensor, wsub, o, std::forward<T>(args)...);
+
+ i += inc;
+ o += inc;
+ num_remaining -= inc;
+ }
+ return results;
+ }
+
+ template <typename iterable_type>
+ std::vector<output_label_type> operator() (
+ const iterable_type& data,
+ size_t batch_size = 128
+ )
+ {
+ std::vector<output_label_type> results(std::distance(data.begin(), data.end()));
+ auto o = results.begin();
+ auto i = data.begin();
+ auto num_remaining = results.size();
+ while(num_remaining != 0)
+ {
+ auto inc = std::min(batch_size, num_remaining);
+ (*this)(i, i+inc, o);
+ i += inc;
+ o += inc;
+ num_remaining -= inc;
+ }
+ return results;
+ }
+
+ template <typename label_iterator>
+ double compute_loss (
+ const tensor& x,
+ label_iterator lbegin
+ )
+ {
+ subnetwork.forward(x);
+ dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ return loss.compute_loss_value_and_gradient(x, lbegin, wsub);
+ }
+
+ template <typename forward_iterator, typename label_iterator>
+ double compute_loss (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ label_iterator lbegin
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return compute_loss(temp_tensor, lbegin);
+ }
+
+ double compute_loss (
+ const tensor& x
+ )
+ {
+ subnetwork.forward(x);
+ dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ return loss.compute_loss_value_and_gradient(x, wsub);
+ }
+
+ template <typename forward_iterator>
+ double compute_loss (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return compute_loss(temp_tensor);
+ }
+
+ template <typename label_iterator>
+ double compute_parameter_gradients (
+ const tensor& x,
+ label_iterator lbegin
+ )
+ {
+ subnetwork.forward(x);
+ dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub);
+ subnetwork.back_propagate_error(x);
+ return l;
+ }
+ template <typename forward_iterator, typename label_iterator>
+ double compute_parameter_gradients (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ label_iterator lbegin
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return compute_parameter_gradients(temp_tensor, lbegin);
+ }
+ double compute_parameter_gradients (
+ const tensor& x
+ )
+ {
+ subnetwork.forward(x);
+ dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+ double l = loss.compute_loss_value_and_gradient(x, wsub);
+ subnetwork.back_propagate_error(x);
+ return l;
+ }
+ template <typename forward_iterator>
+ double compute_parameter_gradients (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ to_tensor(ibegin,iend,temp_tensor);
+ return compute_parameter_gradients(temp_tensor);
+ }
+
+ template <typename solver_type>
+ void update_parameters (
+ sstack<solver_type> solvers,
+ double learning_rate
+ )
+ {
+ subnetwork.update_parameters(solvers, learning_rate);
+ }
+
+ const subnet_type& subnet() const { return subnetwork; }
+ subnet_type& subnet() { return subnetwork; }
+ const loss_details_type& loss_details() const { return loss; }
+ loss_details_type& loss_details() { return loss; }
+
+ void clean (
+ )
+ {
+ temp_tensor.clear();
+ subnetwork.clean();
+ }
+
+ template <typename T, typename U>
+ friend void serialize(const add_loss_layer<T,U>& item, std::ostream& out);
+ template <typename T, typename U>
+ friend void deserialize(add_loss_layer<T,U>& item, std::istream& in);
+
+ friend std::ostream& operator<< (std::ostream& out, const add_loss_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<" << idx << ">\t" << loss_details() << "\n";
+ subnet().print(out, idx+1, min_length);
+ }
+
+ private:
+
+
+ void swap(add_loss_layer& item)
+ {
+ std::swap(loss, item.loss);
+ std::swap(subnetwork, item.subnetwork);
+ }
+
+ loss_details_type loss;
+ subnet_type subnetwork;
+
+ // These two objects don't logically contribute to the state of this object. They
+ // are here to prevent them from being reallocated over and over.
+ output_label_type temp_label;
+ resizable_tensor temp_tensor;
+ };
+
+ template <typename LOSS_DETAILS, typename SUBNET>
+ void serialize(const add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.loss, out);
+ serialize(item.subnetwork, out);
+ }
+
+ template <typename LOSS_DETAILS, typename SUBNET>
+ void deserialize(add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
+ deserialize(item.loss, in);
+ deserialize(item.subnetwork, in);
+ }
+
+
+ template <typename T, typename U>
+ struct is_loss_layer_type<add_loss_layer<T,U>> : std::true_type {};
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <unsigned int i, typename T, typename enabled = void>
+ struct layer_helper
+ {
+ static_assert(i < T::num_layers, "Call to layer() attempted to access non-existing layer in neural network.");
+ static T& makeT();
+ using next_type = typename std::remove_reference<decltype(makeT().subnet())>::type;
+ using type = typename layer_helper<i-1,next_type>::type;
+ static type& layer(T& n)
+ {
+ return layer_helper<i-1,next_type>::layer(n.subnet());
+ }
+ };
+ template <
+ unsigned int i,
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<i,repeat<N,L,S>, typename std::enable_if<(i!=0&&i>=repeat<N,L,S>::layers_in_repeated_group)>::type>
+ {
+ const static size_t layers_in_repeated_group = repeat<N,L,S>::layers_in_repeated_group;
+
+ static repeat<N,L,S>& makeT();
+ using next_type = typename std::remove_reference<decltype(makeT().subnet())>::type;
+ using type = typename layer_helper<i-layers_in_repeated_group,next_type>::type;
+ static type& layer(repeat<N,L,S>& n)
+ {
+ return layer_helper<i-layers_in_repeated_group,next_type>::layer(n.subnet());
+ }
+ };
+ template <
+ unsigned int i,
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<i,repeat<N,L,S>, typename std::enable_if<(i!=0&&i<repeat<N,L,S>::layers_in_repeated_group)>::type>
+ {
+ const static size_t layers_in_each_group = repeat<N,L,S>::layers_in_each_group;
+ typedef typename repeat<N,L,S>::repeated_layer_type repeated_layer_type;
+ using next_type = repeated_layer_type;
+ using type = typename layer_helper<i%layers_in_each_group,next_type>::type;
+ static type& layer(repeat<N,L,S>& n)
+ {
+ return layer_helper<i%layers_in_each_group,next_type>::layer(n.get_repeated_layer(i/layers_in_each_group));
+ }
+ };
+ template <
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<0,repeat<N,L,S>, void>
+ {
+ typedef typename repeat<N,L,S>::repeated_layer_type repeated_layer_type;
+ using type = repeated_layer_type;
+ static type& layer(repeat<N,L,S>& n)
+ {
+ return n.get_repeated_layer(0);
+ }
+ };
+
+
+
+ template <
+ unsigned int i,
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<i,const repeat<N,L,S>, typename std::enable_if<(i!=0&&i>=repeat<N,L,S>::layers_in_repeated_group)>::type>
+ {
+ const static size_t layers_in_repeated_group = repeat<N,L,S>::layers_in_repeated_group;
+
+ static const repeat<N,L,S>& makeT();
+ using next_type = const typename std::remove_reference<decltype(makeT().subnet())>::type;
+ using type = const typename layer_helper<i-layers_in_repeated_group,next_type>::type;
+ static type& layer(const repeat<N,L,S>& n)
+ {
+ return layer_helper<i-layers_in_repeated_group,next_type>::layer(n.subnet());
+ }
+ };
+ template <
+ unsigned int i,
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<i,const repeat<N,L,S>, typename std::enable_if<(i!=0&&i<repeat<N,L,S>::layers_in_repeated_group)>::type>
+ {
+ const static size_t layers_in_each_group = repeat<N,L,S>::layers_in_each_group;
+ typedef typename repeat<N,L,S>::repeated_layer_type repeated_layer_type;
+ using next_type = const repeated_layer_type;
+ using type = const typename layer_helper<i%layers_in_each_group,next_type>::type;
+ static type& layer(const repeat<N,L,S>& n)
+ {
+ return layer_helper<i%layers_in_each_group,next_type>::layer(n.get_repeated_layer(i/layers_in_each_group));
+ }
+ };
+ template <
+ size_t N, template<typename> class L, typename S
+ >
+ struct layer_helper<0,const repeat<N,L,S>, void>
+ {
+ typedef typename repeat<N,L,S>::repeated_layer_type repeated_layer_type;
+ using type = const repeated_layer_type;
+ static type& layer(const repeat<N,L,S>& n)
+ {
+ return n.get_repeated_layer(0);
+ }
+ };
+
+
+
+ template <typename T>
+ struct layer_helper<0,T,void>
+ {
+ using type = T;
+ static type& layer(T& n)
+ {
+ return n;
+ }
+ };
+
+ template <template<typename> class Match, typename T, unsigned int i, typename enabled = void>
+ struct layer_helper_match
+ {
+ static T& makeT();
+ using next_type = typename std::remove_reference<decltype(makeT().subnet())>::type;
+ using type = typename layer_helper_match<Match,next_type,i>::type;
+ static type& layer(T& n)
+ {
+ return layer_helper_match<Match,next_type,i>::layer(n.subnet());
+ }
+ };
+ // This overload catches add_layer and add_loss_layer templates.
+ template <template<typename> class Match, typename T, unsigned int i>
+ struct layer_helper_match<Match,T,i,
+ typename std::enable_if<std::is_same<const T,const Match<typename T::subnet_type>>::value>::type>
+ {
+ using type = typename layer_helper<i,T>::type;
+ static type& layer(T& n)
+ {
+ return layer_helper<i,T>::layer(n);
+ }
+ };
+ // This overload catches input templates.
+ template <template<typename> class Match, typename T, unsigned int i>
+ struct layer_helper_match<Match,T,i,
+ typename std::enable_if<std::is_same<const T,const Match<typename T::input_type>>::value>::type>
+ {
+ using type = typename layer_helper<i,T>::type;
+ static type& layer(T& n)
+ {
+ return layer_helper<i,T>::layer(n);
+ }
+ };
+ // This overload catches subnet_wrapper templates.
+ template <template<typename> class Match, typename T, unsigned int i>
+ struct layer_helper_match<Match,T,i,
+ typename std::enable_if<std::is_same<const typename T::wrapped_type,
+ const Match<typename T::wrapped_type::subnet_type>>::value>::type>
+ {
+ using type = typename layer_helper<i,T>::type;
+ static type& layer(T& n)
+ {
+ return layer_helper<i,T>::layer(n);
+ }
+ };
+ }
+
+ template <unsigned int i, typename T>
+ typename impl::layer_helper<i,T>::type& layer (T& n)
+ {
+ return impl::layer_helper<i,T>::layer(n);
+ }
+
+ template <template<typename> class Match, typename T>
+ typename impl::layer_helper_match<Match,T,0>::type& layer (T& n)
+ {
+ return impl::layer_helper_match<Match,T,0>::layer(n);
+ }
+
+ template <template<typename> class Match, unsigned int i, typename T>
+ typename impl::layer_helper_match<Match,T,i>::type& layer (T& n)
+ {
+ return impl::layer_helper_match<Match,T,i>::layer(n);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ namespace dimpl
+ {
+ template <typename T>
+ T& get_input_details (
+ T& net
+ )
+ {
+ return net;
+ }
+
+ template <typename T, bool is_first, typename enabled>
+ auto get_input_details (
+ dimpl::subnet_wrapper<T,is_first,enabled>& net
+ ) -> decltype(net.layer_details())&
+ {
+ return net.layer_details();
+ }
+
+ template <typename T, bool is_first, typename enabled>
+ auto get_input_details (
+ const dimpl::subnet_wrapper<T,is_first,enabled>& net
+ ) -> decltype(net.layer_details())&
+ {
+ return net.layer_details();
+ }
+ }
+
+ template <typename net_type>
+ auto input_layer (
+ net_type& net
+ ) -> decltype(dimpl::get_input_details(layer<net_type::num_layers-1>(net)))&
+ {
+ // Calling input_layer() on a subnet_wrapper is a little funny since the behavior of
+ // .subnet() returns another subnet_wrapper rather than an input details object as it
+ // does in add_layer.
+ return dimpl::get_input_details(layer<net_type::num_layers-1>(net));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <template<typename> class TAG_TYPE, typename SUBNET>
+ class add_skip_layer
+ {
+ public:
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
+ const static size_t num_layers = subnet_type::num_layers + 1;
+ const static size_t num_computational_layers = subnet_type::num_computational_layers;
+ const static unsigned long id = tag_id<TAG_TYPE>::id;
+
+ add_skip_layer() {};
+ add_skip_layer(const add_skip_layer&) = default;
+ add_skip_layer(add_skip_layer&&) = default;
+ add_skip_layer& operator=(add_skip_layer&&) = default;
+ add_skip_layer& operator=(const add_skip_layer&) = default;
+
+ template <typename T>
+ add_skip_layer(
+ const add_skip_layer<TAG_TYPE,T>& item
+ ) : subnetwork(item.subnet())
+ {}
+
+ template <typename ...T>
+ add_skip_layer(
+ T ...args
+ ) :
+ subnetwork(std::move(args)...)
+ {
+ }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ subnetwork.to_tensor(ibegin,iend,data);
+ }
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ )
+ {
+ subnetwork(ibegin,iend);
+ return layer<TAG_TYPE>(subnetwork).get_output();
+ }
+
+ const tensor& operator() (const input_type& x)
+ {
+ subnetwork(x);
+ return layer<TAG_TYPE>(subnetwork).get_output();
+ }
+
+ const tensor& forward(const tensor& x)
+ {
+ subnetwork.forward(x);
+ return layer<TAG_TYPE>(subnetwork).get_output();
+ }
+
+ const tensor& get_output() const
+ {
+ return layer<TAG_TYPE>(subnetwork).get_output();
+ }
+
+ tensor& get_gradient_input()
+ {
+ return layer<TAG_TYPE>(subnetwork).get_gradient_input();
+ }
+
+ const tensor& get_final_data_gradient(
+ ) const
+ {
+ return subnetwork.get_final_data_gradient();
+ }
+
+ void back_propagate_error(const tensor& x)
+ {
+ subnetwork.back_propagate_error(x);
+ }
+
+ template <typename solver_type>
+ void update_parameters(sstack<solver_type> solvers, double learning_rate)
+ {
+ subnetwork.update_parameters(solvers, learning_rate);
+ }
+
+ const tensor& get_parameter_gradient(
+ ) const { return params_grad; }
+
+ tensor& get_parameter_gradient (
+ ) { return params_grad; }
+
+
+ const subnet_type& subnet() const
+ {
+ return subnetwork;
+ }
+
+ subnet_type& subnet()
+ {
+ return subnetwork;
+ }
+
+ unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
+
+ void clean()
+ {
+ subnetwork.clean();
+ }
+
+ friend void serialize(const add_skip_layer& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.subnetwork, out);
+ }
+
+ friend void deserialize(add_skip_layer& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::add_skip_layer.");
+ deserialize(item.subnetwork, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const add_skip_layer& item)
+ {
+ int min_length = 0;
+ item.print(out, 0, min_length);
+ return out;
+ }
+
+ void print (std::ostream& out, unsigned long idx, int& min_length) const
+ {
+ out << "layer<" << idx << ">\t"<<impl::tensor_to_str(private_get_output(), min_length) <<"skip"<<id<<"\n";
+ subnet().print(out, idx+1, min_length);
+ }
+
+ private:
+
+
+ template <typename T, typename U, typename E>
+ friend class add_layer;
+ template <typename T, bool is_first, typename E>
+ friend class dimpl::subnet_wrapper;
+ template <unsigned long T, typename U, typename E>
+ friend class add_tag_layer;
+ template <template<typename> class T, typename U>
+ friend class add_skip_layer;
+ template <size_t N, template<typename> class L, typename S>
+ friend class repeat;
+
+ bool this_layer_requires_forward_output(
+ ) { return layer<TAG_TYPE>(subnetwork).this_layer_requires_forward_output(); }
+
+ void disable_output_and_gradient_getters (
+ ) { layer<TAG_TYPE>(subnetwork).disable_output_and_gradient_getters(); }
+
+ tensor& private_get_output() const
+ { return layer<TAG_TYPE>(subnetwork).private_get_output(); }
+ tensor& private_get_gradient_input()
+ { return layer<TAG_TYPE>(subnetwork).private_get_gradient_input(); }
+
+ subnet_type subnetwork;
+
+ // This member doesn't logically contribute to the state of the object since it is
+ // always empty. It's just here so we can have the get_parameter_gradient() methods
+ // which have to return something. So they return this empty tensor.
+ resizable_tensor params_grad;
+ };
+ template <template<typename> class T, typename U>
+ struct is_nonloss_layer_type<add_skip_layer<T,U>> : std::true_type {};
+
+ template <typename SUBNET> using tag1 = add_tag_layer< 1, SUBNET>;
+ template <typename SUBNET> using tag2 = add_tag_layer< 2, SUBNET>;
+ template <typename SUBNET> using tag3 = add_tag_layer< 3, SUBNET>;
+ template <typename SUBNET> using tag4 = add_tag_layer< 4, SUBNET>;
+ template <typename SUBNET> using tag5 = add_tag_layer< 5, SUBNET>;
+ template <typename SUBNET> using tag6 = add_tag_layer< 6, SUBNET>;
+ template <typename SUBNET> using tag7 = add_tag_layer< 7, SUBNET>;
+ template <typename SUBNET> using tag8 = add_tag_layer< 8, SUBNET>;
+ template <typename SUBNET> using tag9 = add_tag_layer< 9, SUBNET>;
+ template <typename SUBNET> using tag10 = add_tag_layer<10, SUBNET>;
+
+ template <typename SUBNET> using skip1 = add_skip_layer< tag1, SUBNET>;
+ template <typename SUBNET> using skip2 = add_skip_layer< tag2, SUBNET>;
+ template <typename SUBNET> using skip3 = add_skip_layer< tag3, SUBNET>;
+ template <typename SUBNET> using skip4 = add_skip_layer< tag4, SUBNET>;
+ template <typename SUBNET> using skip5 = add_skip_layer< tag5, SUBNET>;
+ template <typename SUBNET> using skip6 = add_skip_layer< tag6, SUBNET>;
+ template <typename SUBNET> using skip7 = add_skip_layer< tag7, SUBNET>;
+ template <typename SUBNET> using skip8 = add_skip_layer< tag8, SUBNET>;
+ template <typename SUBNET> using skip9 = add_skip_layer< tag9, SUBNET>;
+ template <typename SUBNET> using skip10 = add_skip_layer<tag10, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ namespace timpl
+ {
+ inline void fill_with_gassuan_random_numbers (
+ tensor& t,
+ dlib::rand& rnd,
+ double sigma = 1
+ )
+ {
+ float* data = t.host();
+ for (size_t i = 0; i < t.size(); ++i)
+ data[i] = rnd.get_random_gaussian()*sigma;
+ }
+
+ class test_layer_subnet
+ {
+ public:
+ test_layer_subnet (
+ dlib::rand& rnd_
+ ) : rnd(rnd_)
+ {
+ // Output and gradient_input have to have the same dimensions in each
+ // layer.
+ const long num_samples = rnd.get_random_32bit_number()%4+3;
+ const long k = rnd.get_random_32bit_number()%4+2;
+ const long nr = rnd.get_random_32bit_number()%4+2;
+ const long nc = rnd.get_random_32bit_number()%4+2;
+
+ output.set_size(num_samples, k, nr, nc);
+ gradient_input.set_size(num_samples, k, nr, nc);
+
+ // Use a non-zero initial gradient to make sure the layers add to it
+ // rather than assign and blow away the initial value.
+ fill_with_gassuan_random_numbers(gradient_input, rnd, 0.01);
+
+ fill_with_gassuan_random_numbers(output, rnd);
+ }
+
+
+ tensor& get_mutable_output() { return output; }
+ const tensor& get_output() const { return output; }
+ const tensor& private_get_output() const { return get_output(); }
+ const test_layer_subnet& subnet() const { init_sub(); return *subnetwork; }
+
+ tensor& get_gradient_input() { return gradient_input; }
+ tensor& private_get_gradient_input() { return get_gradient_input(); }
+ test_layer_subnet& subnet() { init_sub(); return *subnetwork; }
+
+
+
+ unsigned long count_outputs() const
+ {
+ if (subnetwork)
+ return subnetwork->count_outputs() + output.size();
+ else
+ return output.size();
+ }
+
+ float& get_output_element(unsigned long i)
+ {
+ if (i < output.size())
+ return output.host()[i];
+ else
+ return subnet().get_output_element(i-output.size());
+ }
+
+ float get_gradient_input_element(unsigned long i) const
+ {
+ if (i < gradient_input.size())
+ return gradient_input.host()[i];
+ else
+ return subnet().get_gradient_input_element(i-gradient_input.size());
+ }
+
+
+ private:
+ // We lazily initialize sub-layers as needed when someone tries to call
+ // subnet()
+ void init_sub() const
+ {
+ if (!subnetwork)
+ subnetwork.reset(new test_layer_subnet(rnd));
+ }
+
+ dlib::rand& rnd;
+ mutable std::unique_ptr<test_layer_subnet> subnetwork;
+ resizable_tensor output;
+ resizable_tensor gradient_input;
+ };
+
+ }
+
+ struct layer_test_results
+ {
+ layer_test_results() : was_good(true) {}
+ explicit layer_test_results(const std::string& l) : log(l),was_good(false) {}
+
+ std::string log;
+ bool was_good;
+
+ operator bool() const { return was_good; }
+ };
+
+ inline std::ostream& operator<< (std::ostream& out, const layer_test_results& item)
+ {
+ out << item.log;
+ return out;
+ }
+
+ template <
+ typename layer_details_type
+ >
+ layer_test_results impl_test_layer (
+ layer_details_type l,
+ const float base_eps
+ )
+ {
+ using namespace timpl;
+ // Do some setup
+ running_stats<double> rs_data, rs_params;
+ dlib::rand rnd;
+ std::ostringstream sout;
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ test_layer_subnet subnetwork(rnd);
+ resizable_tensor output, out2, out3;
+ // Run setup() and forward() as well to make sure any calls to subnet() have
+ // happened before we start assuming we know how many data elements there are
+ // (since we do a lazy layer creation thing based on calls to subnet() inside
+ // test_layer_subnet).
+ l.setup(subnetwork);
+ impl::call_layer_forward(l, subnetwork, output);
+
+ resizable_tensor input_grad;
+ input_grad.copy_size(output);
+ fill_with_gassuan_random_numbers(input_grad, rnd);
+
+
+ // The f() we are computing gradients of is this thing. It's value at the current
+ // parameter and data values is:
+ //sout << "f(data,params): " << dot(output, input_grad) << std::endl;
+
+ // We are going to save a copy of the subnetwork.get_gradient_input() data before we do
+ // backpropagation since the backward() function is supposed to *add* to the
+ // gradients rather than overwrite them. We will use this saved data to check if
+ // that is the case.
+ const unsigned long num_data_inputs = subnetwork.count_outputs();
+ std::vector<float> initial_gradient_input(num_data_inputs);
+ for (unsigned long i = 0; i < num_data_inputs; ++i)
+ initial_gradient_input[i] = subnetwork.get_gradient_input_element(i);
+
+
+ // Now tell the layer to compute all the gradients. In the rest of this function
+ // we will just be checking that these gradients were computed correctly by
+ // comparing them to a central differences approximation.
+ resizable_tensor params_grad;
+ params_grad.copy_size(l.get_layer_params());
+ // But first, set the params grad to something crazy so that it's very obvious if
+ // it doesn't get fully assigned.
+ params_grad = std::numeric_limits<float>::infinity();
+ impl::call_layer_backward(l, output, input_grad, subnetwork, params_grad);
+
+ static_assert(impl::is_inplace_layer(l, subnetwork) == impl::has_inplace_backward(l, subnetwork),
+ "Layer not defined correctly. forward and backward methods must either both be in-place or both out-of-place. ");
+
+ // Make sure the outputs of forward() and backward() are the same when they are run
+ // in in-place mode.
+ if (impl::is_inplace_layer(l, subnetwork))
+ {
+ test_layer_subnet subnetwork2(rnd);
+ layer_details_type ll(l);
+ ll.setup(subnetwork2);
+ resizable_tensor ip_out;
+ impl::call_layer_forward(ll, subnetwork2, ip_out);
+ impl::call_layer_forward(ll, subnetwork2, subnetwork2.get_mutable_output());
+ const auto forward_error = max(abs(mat(ip_out) - mat(subnetwork2.get_output())));
+ if (forward_error > 0.00001)
+ {
+ using namespace std;
+ sout << "This layer is supposed to support in-place computations but the output of forward_inplace()\n";
+ sout << "changes when invoked in-place vs. out-of-place. The error was: " << forward_error << endl;
+ return layer_test_results(sout.str());
+ }
+
+ resizable_tensor params_grad;
+ params_grad.copy_size(ll.get_layer_params());
+ params_grad = std::numeric_limits<float>::infinity();
+
+ resizable_tensor input_grad;
+ input_grad.copy_size(ip_out);
+ fill_with_gassuan_random_numbers(input_grad, rnd);
+ resizable_tensor params_grad1, params_grad2, data_grad1, data_grad2;
+ params_grad1 = params_grad;
+ params_grad2 = params_grad;
+ // Now call backward() and make sure it works as well. Recall that when an
+ // in-place layer works in-place it assigns to it's outputs but when it's
+ // not running in-place it adds. So we initialize to a non-zero value to
+ // check that this is the behavior that really executes.
+ subnetwork2.get_gradient_input() = 9;
+ impl::call_layer_backward(ll, ip_out, input_grad, subnetwork2, params_grad1);
+ data_grad1 = subnetwork2.get_gradient_input();
+
+ subnetwork2.get_gradient_input() = mat(input_grad);
+ impl::call_layer_backward(ll, ip_out, subnetwork2.get_gradient_input(), subnetwork2, params_grad2);
+ data_grad2 = subnetwork2.get_gradient_input();
+ if (params_grad.size() != 0)
+ {
+ const auto backward_param_error = max(abs(mat(params_grad1) - mat(params_grad2)));
+ if (backward_param_error > 0.00001)
+ {
+ using namespace std;
+ sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n";
+ sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_param_error << endl;
+ return layer_test_results(sout.str());
+ }
+ }
+ const auto backward_data_error = max(abs(mat(data_grad1)-9 - mat(data_grad2)));
+ if (backward_data_error > 0.00001)
+ {
+ using namespace std;
+ sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n";
+ sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_data_error << endl;
+ return layer_test_results(sout.str());
+ }
+ }
+
+ // ==================================================================
+ // first validate the way the parameter gradients are computed
+ for (unsigned long i = 0; i < params_grad.size(); ++i)
+ {
+ layer_details_type l1(l);
+
+ float eps = l1.get_layer_params().host()[i]*base_eps;
+ if (eps == 0)
+ eps = base_eps;
+ const float oldval = l1.get_layer_params().host()[i];
+ l1.get_layer_params().host()[i] = oldval+eps;
+ impl::call_layer_forward(l1, subnetwork, out2);
+ l1.get_layer_params().host()[i] = oldval-eps;
+ impl::call_layer_forward(l1, subnetwork, out3);
+ l1.get_layer_params().host()[i] = oldval;
+
+ // Compute a reference derivative via a central differences approximation and
+ // compare it to the one output by the layer and make sure they match.
+ double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
+ double output_derivative = params_grad.host()[i];
+ double relative_error;
+ if (reference_derivative*output_derivative != 0)
+ relative_error = (reference_derivative - output_derivative)/(reference_derivative);
+ else
+ relative_error = (reference_derivative - output_derivative);
+ double absolute_error = (reference_derivative - output_derivative);
+ rs_params.add(std::abs(relative_error));
+ if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
+ {
+ using namespace std;
+ sout << "Gradient error in parameter #" << i <<". Relative error: "<< relative_error << endl;
+ sout << "expected derivative: " << reference_derivative << endl;
+ sout << "output derivative: " << output_derivative << endl;
+ sout << "iteration: " << iter << endl;
+ return layer_test_results(sout.str());
+ }
+ }
+
+ // ==================================================================
+ // now validate the data gradients
+ for (unsigned long i = 0; i < num_data_inputs; ++i)
+ {
+ const float oldval = subnetwork.get_output_element(i);
+ float eps = oldval*base_eps;
+ if (eps == 0)
+ eps = base_eps;
+ subnetwork.get_output_element(i) = oldval+eps;
+ impl::call_layer_forward(l, subnetwork, out2);
+ subnetwork.get_output_element(i) = oldval-eps;
+ impl::call_layer_forward(l, subnetwork, out3);
+ subnetwork.get_output_element(i) = oldval;
+
+ // Compute a reference derivative via a central differences approximation and
+ // compare it to the one output by the layer and make sure they match.
+ double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
+ double output_derivative = subnetwork.get_gradient_input_element(i);
+ output_derivative -= initial_gradient_input[i];
+ double relative_error;
+ if (reference_derivative*output_derivative != 0)
+ relative_error = (reference_derivative - output_derivative)/(reference_derivative);
+ else
+ relative_error = (reference_derivative - output_derivative);
+ double absolute_error = (reference_derivative - output_derivative);
+ rs_data.add(std::abs(relative_error));
+ if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
+ {
+ using namespace std;
+ sout << "Gradient error in data variable #" << i <<". Relative error: "<< relative_error << endl;
+ sout << "expected derivative: " << reference_derivative << endl;
+ sout << "output derivative: " << output_derivative << endl;
+ sout << "iteration: " << iter << endl;
+ return layer_test_results(sout.str());
+ }
+ }
+
+ } // end for (int iter = 0; iter < 10; ++iter)
+
+ if (rs_params.mean() > 0.003)
+ {
+ using namespace std;
+ sout << "Average parameter gradient error is somewhat large at: "<< rs_params.mean() << endl;
+ return layer_test_results(sout.str());
+ }
+ if (rs_data.mean() > 0.003)
+ {
+ using namespace std;
+ sout << "Average data gradient error is somewhat large at: "<< rs_data.mean() << endl;
+ return layer_test_results(sout.str());
+ }
+
+ return layer_test_results();
+ }
+
+ template <
+ typename layer_details_type
+ >
+ layer_test_results test_layer (
+ layer_details_type l
+ )
+ {
+ // Try a few different derivative step sizes to see if any work.
+ for (float base_eps = 0.0001; base_eps < 0.1; base_eps *= 2)
+ {
+ auto result = impl_test_layer(l, base_eps);
+ if (result)
+ return result;
+ }
+ // However, if none of the step sizes worked then try this one and probably result
+ // in returning an error.
+ return impl_test_layer(l, 0.01);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <size_t i, size_t num>
+ struct vlp_loop
+ {
+ template <typename T, typename U>
+ static typename std::enable_if<!is_add_layer<U>::value>::type invoke_functor(T&& , size_t& , U&& )
+ {
+ // intentionally left empty
+ }
+
+ template <typename T, typename U>
+ static typename std::enable_if<is_add_layer<U>::value>::type invoke_functor(T&& v , size_t& comp_i, U&& l )
+ {
+ v(comp_i, l.layer_details().get_layer_params());
+ ++comp_i;
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ size_t comp_i,
+ net_type& net,
+ visitor&& v
+ )
+ {
+ invoke_functor(v, comp_i, layer<i>(net));
+ vlp_loop<i+1, num>::visit(comp_i, net,v);
+ }
+ };
+
+ template <size_t num>
+ struct vlp_loop<num,num>
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ size_t,
+ net_type&,
+ visitor&&
+ )
+ {
+ // Base case of recursion. Don't do anything.
+ }
+ };
+
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layer_parameters(
+ net_type& net,
+ visitor v
+ )
+ {
+ size_t comp_i = 0;
+ impl::vlp_loop<0, net_type::num_layers>::visit(comp_i, net, v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <size_t i, size_t num>
+ struct vlpg_loop
+ {
+ template <typename T, typename U>
+ static typename std::enable_if<!is_add_layer<U>::value>::type invoke_functor(T&& , size_t& , U&& )
+ {
+ // intentionally left empty
+ }
+
+ template <typename T, typename U>
+ static typename std::enable_if<is_add_layer<U>::value>::type invoke_functor(T&& v , size_t& comp_i, U&& l )
+ {
+ v(comp_i, l.get_parameter_gradient());
+ ++comp_i;
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ size_t comp_i,
+ net_type& net,
+ visitor&& v
+ )
+ {
+ invoke_functor(v, comp_i, layer<i>(net));
+ vlpg_loop<i+1, num>::visit(comp_i, net,v);
+ }
+ };
+
+ template <size_t num>
+ struct vlpg_loop<num,num>
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ size_t,
+ net_type&,
+ visitor&&
+ )
+ {
+ // Base case of recursion. Don't do anything.
+ }
+ };
+
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layer_parameter_gradients(
+ net_type& net,
+ visitor v
+ )
+ {
+ size_t comp_i = 0;
+ impl::vlpg_loop<0, net_type::num_layers>::visit(comp_i, net, v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <size_t i, size_t num>
+ struct vl_loop
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ net_type& net,
+ visitor&& v
+ )
+ {
+ v(i, layer<i>(net));
+ vl_loop<i+1, num>::visit(net,v);
+ }
+ };
+
+ template <size_t num>
+ struct vl_loop<num,num>
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ net_type&,
+ visitor&&
+ )
+ {
+ // Base case of recursion. Don't do anything.
+ }
+ };
+
+ template <size_t i, size_t num>
+ struct vl_loop_backwards
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ net_type& net,
+ visitor&& v
+ )
+ {
+ vl_loop_backwards<i+1, num>::visit(net,v);
+ v(i, layer<i>(net));
+ }
+ };
+
+ template <size_t num>
+ struct vl_loop_backwards<num,num>
+ {
+ template <
+ typename net_type,
+ typename visitor
+ >
+ static void visit(
+ net_type&,
+ visitor&&
+ )
+ {
+ // Base case of recursion. Don't do anything.
+ }
+ };
+
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers(
+ net_type& net,
+ visitor v
+ )
+ {
+ impl::vl_loop<0, net_type::num_layers>::visit(net, v);
+ }
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_backwards(
+ net_type& net,
+ visitor v
+ )
+ {
+ impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v);
+ }
+
+ template <
+ size_t begin,
+ size_t end,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_range(
+ net_type& net,
+ visitor v
+ )
+ {
+ static_assert(begin <= end, "Invalid range");
+ static_assert(end <= net_type::num_layers, "Invalid range");
+ impl::vl_loop<begin,end>::visit(net, v);
+ }
+
+ template <
+ size_t begin,
+ size_t end,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_backwards_range(
+ net_type& net,
+ visitor v
+ )
+ {
+ static_assert(begin <= end, "Invalid range");
+ static_assert(end <= net_type::num_layers, "Invalid range");
+ impl::vl_loop_backwards<begin,end>::visit(net, v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <size_t i, unsigned long tag_id>
+ struct vl_until_tag
+ {
+ template <
+ typename net_type,
+ typename next_net_type,
+ typename visitor
+ >
+ static void visit(
+ net_type& net,
+ next_net_type& next_net,
+ visitor&& v
+ )
+ {
+ v(next_net);
+ vl_until_tag<i+1,tag_id>::visit(net,layer<i+1>(net),v);
+ }
+
+ template <
+ typename net_type,
+ typename SUBNET,
+ typename visitor
+ >
+ static void visit(
+ net_type& net,
+ const add_tag_layer<tag_id,SUBNET>& next_net,
+ visitor&& v
+ )
+ {
+ v(next_net);
+ }
+
+ template <
+ typename net_type,
+ typename SUBNET,
+ typename visitor
+ >
+ static void visit(
+ net_type& net,
+ add_tag_layer<tag_id,SUBNET>& next_net,
+ visitor&& v
+ )
+ {
+ v(next_net);
+ }
+ };
+ }
+
+ template <
+ unsigned long tag_id,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_until_tag(
+ net_type& net,
+ visitor v
+ )
+ {
+ impl::vl_until_tag<0,tag_id>::visit(net, net, v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_CORE_H_
+
+
diff --git a/ml/dlib/dlib/dnn/core_abstract.h b/ml/dlib/dlib/dnn/core_abstract.h
new file mode 100644
index 000000000..db168a88b
--- /dev/null
+++ b/ml/dlib/dlib/dnn/core_abstract.h
@@ -0,0 +1,1700 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_CORE_ABSTRACT_H_
+#ifdef DLIB_DNn_CORE_ABSTRACT_H_
+
+#include "tensor_abstract.h"
+#include <memory>
+#include <type_traits>
+#include <tuple>
+#include <vector>
+#include "../rand.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename... T
+ >
+ auto tuple_tail(
+ const std::tuple<T...>& item
+ );
+ /*!
+ ensures
+ - returns a tuple that contains everything in item except for tuple_head(item).
+ The items will be in the same order as they are in item, just without
+ tuple_head(item).
+ - This function will correctly handle nested tuples.
+ !*/
+
+ template <typename... T>
+ auto tuple_head (
+ const std::tuple<T...>& item
+ );
+ /*!
+ ensures
+ - returns a copy of the first thing in the tuple that isn't a std::tuple.
+ Essentially, this function calls std::get<0>() recursively on item until
+ a non-std::tuple object is found.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ double get_learning_rate_multiplier(
+ const T& obj
+ );
+ /*!
+ ensures
+ - if (obj has a get_learning_rate_multiplier() member function) then
+ - returns obj.get_learning_rate_multiplier()
+ - else
+ - returns 1
+ !*/
+
+ template <typename T>
+ double get_weight_decay_multiplier(
+ const T& obj
+ );
+ /*!
+ ensures
+ - if (obj has a get_weight_decay_multiplier() member function) then
+ - returns obj.get_weight_decay_multiplier()
+ - else
+ - returns 1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool dnn_prefer_fastest_algorithms(
+ );
+ /*!
+ ensures
+ - If dlib should prefer to use fast algorithms rather than ones that use less
+ RAM then this function returns true and false otherwise.
+ - On program startup this function will default to true.
+ !*/
+
+ void set_dnn_prefer_fastest_algorithms(
+ );
+ /*!
+ ensures
+ - #dnn_prefer_fastest_algorithms() == true
+ !*/
+
+ void set_dnn_prefer_smallest_algorithms(
+ );
+ /*!
+ ensures
+ - #dnn_prefer_fastest_algorithms() == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class sstack
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a basic stack of T objects. It contains no data itself but simply
+ points to a memory range of T object and allows you to access that block of
+ T objects as a stack.
+ !*/
+
+ public:
+ typedef T value_type;
+
+ sstack() = delete;
+
+ sstack (
+ T* data,
+ size_t s
+ );
+ /*!
+ ensures
+ - #size() == s
+ - #top() == *data
+ - #pop(i).top() == data[i]
+ !*/
+
+ const T& top(
+ ) const;
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns the top element of the stack.
+ !*/
+
+ T& top(
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns the top element of the stack.
+ !*/
+
+ size_t size(
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements in this stack.
+ !*/
+
+ sstack pop(
+ size_t num = 1
+ );
+ /*!
+ requires
+ - num <= size()
+ ensures
+ - returns a reference to the sub-stack S such that:
+ - S.size() == size()-num.
+ - S.top() is num elements down the stack.
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ sstack<T> make_sstack(
+ std::vector<T>& item
+ ) { return sstack<T>(item.data(), item.size()); }
+ /*!
+ ensures
+ - returns a sstack that sits on top of the given std::vector.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename LAYER_DETAILS,
+ typename SUBNET
+ >
+ class add_layer
+ {
+ /*!
+ REQUIREMENTS ON LAYER_DETAILS
+ - Must be a type that implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined in layers_abstract.h
+
+ REQUIREMENTS ON SUBNET
+ - One of the following must be true:
+ - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in
+ input_abstract.h.
+ - SUBNET is an add_layer object.
+ - SUBNET is an add_tag_layer object.
+ - SUBNET is an add_skip_layer object.
+ - SUBNET is a repeat object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a deep neural network. In particular, it is a tool
+ for adding another layer on top of the neural network of type SUBNET, which
+ is specified as a template argument. The specific layer added is defined
+ by the LAYER_DETAILS details template argument.
+ !*/
+
+ public:
+ typedef LAYER_DETAILS layer_details_type;
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ // num_computational_layers will always give the number of layers in the network
+ // that transform tensors (i.e. layers defined by something that implements the
+ // EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for
+ // loss, tag, and skip layers.
+ const static size_t num_computational_layers = subnet_type::num_computational_layers + 1;
+ // num_layers counts all the layers in the network regardless of their type.
+ const static size_t num_layers = subnet_type::num_layers + 1;
+
+ add_layer(
+ );
+ /*!
+ ensures
+ - default constructs all the layers in this network.
+ - #sample_expansion_factor() == 0
+ !*/
+
+ add_layer(const add_layer&) = default;
+ add_layer(add_layer&&) = default;
+ add_layer& operator=(add_layer&&) = default;
+ add_layer& operator=(const add_layer&) = default;
+ /*!
+ ensures
+ - this object is copyable and movable.
+ !*/
+
+ template <typename T, typename U>
+ add_layer(
+ const add_layer<T,U>& item
+ );
+ /*!
+ ensures
+ - This constructor allows you to copy neural network objects from one to
+ another as long as their corresponding layers can be constructed from
+ each other.
+ - #layer_details() == layer_details_type(item.layer_details())
+ - #subnet() == subnet_type(item.subnet())
+ - #sample_expansion_factor() == item.sample_expansion_factor()
+ !*/
+
+ template <typename ...T, typename LD, typename ...U>
+ add_layer(
+ const std::tuple<LD,U...>& layer_det,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - #layer_details() == layer_details_type(tuple_head(layer_det))
+ - #subnet() == subnet_type(tuple_tail(layer_det),args)
+ - #sample_expansion_factor() == 0
+ !*/
+
+ template <typename ...T>
+ add_layer(
+ const layer_details_type& layer_det,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - #layer_details() == layer_details_type(layer_det)
+ - #subnet() == subnet_type(args)
+ - #sample_expansion_factor() == 0
+ !*/
+
+ template <typename ...T>
+ add_layer(
+ T&& ...args
+ );
+ /*!
+ ensures
+ - This version of the constructor is only called if layer_details_type
+ can't be constructed from the first thing in args. In this case, the
+ args are simply passed on to the sub layers in their entirety.
+ - #layer_details() == layer_details_type()
+ - #subnet() == subnet_type(args)
+ - #sample_expansion_factor() == 0
+ !*/
+
+ template <typename ...T>
+ add_layer(
+ layer_details_type&& layer_det,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - #layer_details() == layer_det
+ - #subnet() == subnet_type(args)
+ - #sample_expansion_factor() == 0
+ !*/
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data.
+ - #data.num_samples()%distance(ibegin,iend) == 0.
+ - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend).
+ - #sample_expansion_factor() > 0
+ - The data in the ith sample of #data corresponds to the input_type object
+ *(ibegin+i/#sample_expansion_factor()).
+ - Invokes data.async_copy_to_device() so that the data begins transferring
+ to the GPU device, if present.
+ - This function is implemented by calling the to_tensor() routine defined
+ at the input layer of this network.
+ !*/
+
+ unsigned int sample_expansion_factor (
+ ) const;
+ /*!
+ ensures
+ - When to_tensor() is invoked on this network's input layer it converts N
+ input objects into M samples, all stored inside a resizable_tensor. It
+ is always the case that M is some integer multiple of N.
+ sample_expansion_factor() returns the value of this multiplier. To be
+ very specific, it is always true that M==I*N where I is some integer.
+ This integer I is what is returned by sample_expansion_factor().
+ !*/
+
+ const subnet_type& subnet(
+ ) const;
+ /*!
+ ensures
+ - returns the immediate subnetwork of *this network.
+ !*/
+
+ subnet_type& subnet(
+ );
+ /*!
+ ensures
+ - returns the immediate subnetwork of *this network.
+ !*/
+
+ const layer_details_type& layer_details(
+ ) const;
+ /*!
+ ensures
+ - returns the layer_details_type instance that defines the behavior of the
+ layer at the top of this network. I.e. returns the layer details that
+ defines the behavior of the layer nearest to the network output rather
+ than the input layer.
+ !*/
+
+ layer_details_type& layer_details(
+ );
+ /*!
+ ensures
+ - returns the layer_details_type instance that defines the behavior of the
+ layer at the top of this network. I.e. returns the layer details that
+ defines the behavior of the layer nearest to the network output rather
+ than the input layer.
+ !*/
+
+ template <typename forward_iterator>
+ const tensor& operator() (
+ forward_iterator ibegin,
+ forward_iterator iend
+ );
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - runs [ibegin,iend) through the network and returns the results.
+ In particular, this function performs:
+ to_tensor(ibegin,iend,temp_tensor);
+ return forward(temp_tensor);
+ - The return value from this function is also available in #get_output().
+ i.e. this function returns #get_output().
+ - have_same_dimensions(#get_gradient_input(), #get_output()) == true.
+ - All elements of #get_gradient_input() are set to 0.
+ i.e. calling this function clears out #get_gradient_input() and ensures
+ it has the same dimensions as the most recent output.
+ !*/
+
+ const tensor& operator() (
+ const input_type& x
+ );
+ /*!
+ ensures
+ - runs a single x through the network and returns the output.
+ I.e. returns (*this)(&x, &x+1);
+ !*/
+
+ const tensor& forward(
+ const tensor& x
+ );
+ /*!
+ requires
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ ensures
+ - Runs x through the network and returns the results. In particular, this
+ function performs the equivalent of:
+ subnet().forward(x);
+ if (this is the first time forward() has been called) then
+ layer_details().setup(subnet());
+ layer_details().forward(subnet(), get_output());
+ - The return value from this function is also available in #get_output().
+ i.e. this function returns #get_output().
+ - have_same_dimensions(#get_gradient_input(), #get_output()) == true
+ - All elements of #get_gradient_input() are set to 0.
+ i.e. calling this function clears out #get_gradient_input() and ensures
+ it has the same dimensions as the most recent output.
+ !*/
+
+ const tensor& get_output(
+ ) const;
+ /*!
+ ensures
+ - returns the output for the last tensor that was run through the network.
+ If nothing has been run through the network yet then returns an empty
+ tensor.
+ !*/
+
+ tensor& get_gradient_input(
+ );
+ /*!
+ ensures
+ - returns the error gradient for this network. That is, this is the error
+ gradient that this network will use to compute parameter gradients when
+ back_propagate_error() is called. Therefore, when performing back
+ propagation, layers that sit on top of this network layer write their
+ back-propagated error gradients into get_gradient_input(). Or to put it
+ another way, during back-propagation, layers take the contents of their
+ get_gradient_input() and back-propagate it through themselves and store
+ the result into their subnetwork's get_gradient_input().
+
+ This means you should consider get_gradient_input() as an input to the
+ back_propagate_error() method.
+ !*/
+
+ const tensor& get_final_data_gradient(
+ ) const;
+ /*!
+ ensures
+ - if back_propagate_error() has been called to back-propagate a gradient
+ through this network then you can call get_final_data_gradient() to
+ obtain the last data gradient computed. That is, this function returns
+ the gradient of the network with respect to its inputs.
+ - Note that there is only one "final data gradient" for an entire network,
+ not one per layer, since there is only one input to the entire network.
+ !*/
+
+ const tensor& get_parameter_gradient(
+ ) const;
+ /*!
+ ensures
+ - if back_propagate_error() has been called then you can call
+ get_parameter_gradient() to find the gradient of this layer's parameters.
+ When we update the parameters by calling update_parameters(), it will use
+ the gradient in get_parameter_gradient() to perform the update.
+ Therefore, you should consider get_parameter_gradient() as an input to
+ update_parameters().
+ !*/
+
+ tensor& get_parameter_gradient (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the tensor returned by the above
+ get_parameter_gradient() method. You could use this method to modify the
+ parameter gradient in some way before invoking update_parameters().
+ !*/
+
+ void back_propagate_error(
+ const tensor& x
+ );
+ /*!
+ requires
+ - forward(x) was called to forward propagate x though the network.
+ Moreover, this was the most recent call to forward() and x has not been
+ subsequently modified in any way.
+ - get_gradient_input() has been set equal to the gradient of this network's
+ output with respect to some loss function.
+ ensures
+ - Back propagates the error gradient, get_gradient_input(), through this
+ network and computes parameter and data gradients, via backpropagation.
+ Specifically, this function populates get_final_data_gradient() and also,
+ for each layer, the tensor returned by get_parameter_gradient().
+ - All elements of #get_gradient_input() are set to 0.
+ - have_same_dimensions(#get_final_data_gradient(), x) == true.
+ - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
+ - #get_final_data_gradient() contains the gradient of the network with
+ respect to x.
+ !*/
+
+ void back_propagate_error(
+ const tensor& x,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - forward(x) was called to forward propagate x though the network.
+ Moreover, this was the most recent call to forward() and x has not been
+ subsequently modified in any way.
+ - have_same_dimensions(gradient_input, get_output()) == true
+ ensures
+ - This function is identical to the version of back_propagate_error()
+ defined immediately above except that it back-propagates gradient_input
+ through the network instead of get_gradient_input(). Therefore, this
+ version of back_propagate_error() is equivalent to performing:
+ get_gradient_input() = gradient_input;
+ back_propagate_error(x);
+ Except that calling back_propagate_error(x,gradient_input) avoids the
+ copy and is therefore slightly more efficient.
+ - All elements of #get_gradient_input() are set to 0.
+ - have_same_dimensions(#get_final_data_gradient(), x) == true.
+ - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
+ - #get_final_data_gradient() contains the gradient of the network with
+ respect to x.
+ !*/
+
+ template <typename solver_type>
+ void update_parameters(
+ sstack<solver_type> solvers,
+ double learning_rate
+ );
+ /*!
+ requires
+ - solver_type is an implementation of the EXAMPLE_SOLVER interface defined
+ in solvers_abstract.h
+ - back_propagate_error() has been called.
+ - The given solvers have only ever been used with this network. That is,
+ if you want to call update_parameters() on some other neural network
+ object then you must NOT reuse the same solvers object.
+ - solvers.size() >= num_computational_layers
+ - 0 < learning_rate <= 1
+ ensures
+ - Updates all the parameters in the network. In particular, we pass each
+ layer's parameter gradient (i.e. the tensor returned by the layer's
+ get_parameter_gradient() member) through that layer's corresponding
+ solver object. This produces a parameter delta vector which we add to
+ the layer's parameters.
+ - The solvers use the given learning rate.
+ !*/
+
+ void clean(
+ );
+ /*!
+ ensures
+ - Causes the network to forget about everything but its parameters.
+ That is, for each layer we will have:
+ - get_output().num_samples() == 0
+ - get_gradient_input().num_samples() == 0
+ However, running new input data though this network will still produce
+ the same output it would have produced regardless of any calls to
+ clean(). The purpose of clean() is to compact the network object prior
+ to saving it to disk so that it takes up less space and the IO is
+ quicker.
+ - This also calls the .clean() method on any layer details objects that
+ define a .clean() method.
+ !*/
+
+ };
+
+ template <typename T, typename U>
+ std::ostream& operator<<(std::ostream& out, const add_layer<T,U>& item);
+ /*!
+ prints the network architecture to the given output stream.
+ !*/
+
+ template <typename T, typename U>
+ void serialize(const add_layer<T,U>& item, std::ostream& out);
+ template <typename T, typename U>
+ void deserialize(add_layer<T,U>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class no_label_type;
+
+ template <
+ typename LOSS_DETAILS,
+ typename SUBNET
+ >
+ class add_loss_layer
+ {
+ /*!
+ REQUIREMENTS ON LOSS_DETAILS
+ - Must be a type that implements the EXAMPLE_LOSS_LAYER_ interface defined
+ in loss_abstract.h
+
+ REQUIREMENTS ON SUBNET
+ - One of the following must be true:
+ - SUBNET is an add_layer object.
+ - SUBNET is an add_tag_layer object.
+ - SUBNET is an add_skip_layer object.
+ - SUBNET is a repeat object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a deep neural network. In particular, it is a tool
+ for adding a loss layer on top of the neural network of type SUBNET, which
+ is specified as a template argument. The specific layer added is defined
+ by the LOSS_DETAILS details template argument. Importantly, a loss layer
+ is the last layer in a deep neural network. So once it is added you can't
+ add any other layers of any type.
+ !*/
+
+ public:
+ typedef LOSS_DETAILS loss_details_type;
+ typedef SUBNET subnet_type;
+ typedef typename subnet_type::input_type input_type;
+ const static size_t num_computational_layers = subnet_type::num_computational_layers;
+ const static size_t num_layers = subnet_type::num_layers + 1;
+ // If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type.
+ // Otherwise it is defined as follows:
+ typedef typename LOSS_DETAILS::training_label_type training_label_type;
+ // Similarly, if LOSS_DETAILS doesn't provide any output conversion then
+ // output_label_type==no_label_type.
+ typedef typename LOSS_DETAILS::output_label_type output_label_type;
+
+
+
+ add_loss_layer() = default;
+ /*!
+ ensures
+ - default constructs all the layers in this network.
+ !*/
+
+ add_loss_layer(const add_loss_layer&) = default;
+ add_loss_layer(add_loss_layer&&) = default;
+ add_loss_layer& operator=(add_loss_layer&&) = default;
+ add_loss_layer& operator=(const add_loss_layer&) = default;
+ /*!
+ ensures
+ - this object is copyable and movable.
+ !*/
+
+ template <typename T, typename U>
+ add_loss_layer(
+ const add_loss_layer<T,U>& item
+ );
+ /*!
+ ensures
+ - This constructor allows you to copy neural network objects from one to
+ another as long as their corresponding layers can be constructed from
+ each other.
+ - #loss_details() == loss_details_type(item.loss_details())
+ - #subnet() == subnet_type(item.subnet())
+ !*/
+
+ template <typename ...T>
+ add_loss_layer(
+ const LOSS_DETAILS& layer_det,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - #loss_details() == loss_details_type(layer_det)
+ - #subnet() == subnet_type(args)
+ !*/
+
+ template <typename ...T>
+ add_loss_layer(
+ LOSS_DETAILS&& layer_det,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - #loss_details() == loss_details_type(layer_det)
+ - #subnet() == subnet_type(args)
+ !*/
+
+ template <typename ...T>
+ add_loss_layer(
+ T&& ...args
+ );
+ /*!
+ ensures
+ - This version of the constructor is only called if loss_details_type can't
+ be constructed from the first thing in args. In this case, the args are
+ simply passed on to the sub layers in their entirety.
+ - #loss_details() == loss_details_type()
+ - #subnet() == subnet_type(args)
+ !*/
+
+ const subnet_type& subnet(
+ ) const;
+ /*!
+ ensures
+ - returns the immediate subnetwork of *this network.
+ !*/
+
+ subnet_type& subnet(
+ );
+ /*!
+ ensures
+ - returns the immediate subnetwork of *this network.
+ !*/
+
+ const loss_details_type& loss_details(
+ ) const;
+ /*!
+ ensures
+ - returns the loss_details_type instance that defines the behavior of the
+ loss layer used by this network.
+ !*/
+
+ loss_details_type& loss_details(
+ );
+ /*!
+ ensures
+ - returns the loss_details_type instance that defines the behavior of the
+ loss layer used by this network.
+ !*/
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data.
+ - #data.num_samples()%distance(ibegin,iend) == 0.
+ - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend).
+ - #sample_expansion_factor() > 0
+ - The data in the ith sample of #data corresponds to the input_type object
+ *(ibegin+i/sample_expansion_factor()).
+ - Invokes data.async_copy_to_device() so that the data begins transferring
+ to the GPU device, if present.
+ - This function is implemented by calling the to_tensor() routine defined
+ at the input layer of this network.
+ !*/
+
+ unsigned int sample_expansion_factor (
+ ) const;
+ /*!
+ ensures
+ - When to_tensor() is invoked on this network's input layer it converts N
+ input objects into M samples, all stored inside a resizable_tensor. It
+ is always the case that M is some integer multiple of N.
+ sample_expansion_factor() returns the value of this multiplier. To be
+ very specific, it is always true that M==I*N where I is some integer.
+ This integer I is what is returned by sample_expansion_factor().
+ !*/
+
+ // -------------
+
+ template <typename output_iterator>
+ void operator() (
+ const tensor& x,
+ output_iterator obegin
+ );
+ /*!
+ requires
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ - obegin == iterator pointing to the start of a range of
+ x.num_samples()/sample_expansion_factor() output_label_type elements.
+ ensures
+ - runs x through the network and writes the output to the range at obegin.
+ - loss_details().to_label() is used to write the network output into
+ obegin.
+ !*/
+
+ template <typename forward_iterator, typename label_iterator>
+ void operator() (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ label_iterator obegin
+ );
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - obegin == iterator pointing to the start of a range of
+ std::distance(ibegin,iend) output_label_type elements.
+ ensures
+ - runs [ibegin,iend) through the network and writes the output to the range
+ at obegin.
+ - loss_details().to_label() is used to write the network output into
+ obegin.
+ !*/
+
+ // -------------
+
+ const output_label_type& operator() (
+ const input_type& x
+ );
+ /*!
+ ensures
+ - runs a single object, x, through the network and returns the output.
+ - loss_details().to_label() is used to convert the network output into a
+ output_label_type.
+ !*/
+
+ template <typename iterable_type>
+ std::vector<output_label_type> operator() (
+ const iterable_type& data,
+ size_t batch_size = 128
+ );
+ /*!
+ requires
+ - batch_size > 0
+ - data must have a .begin() and .end() that supply iterators over a
+ sequence of input_type elements. E.g. data could have a type of
+ std::vector<input_type>
+ ensures
+ - runs all the objects in data through the network and returns their
+ predicted labels. This means this function returns a vector V such that:
+ - V.size() == data.size()
+ - for all valid i: V[i] == the predicted label of data[i].
+ - Elements of data are run through the network in batches of batch_size
+ items. Using a batch_size > 1 can be faster because it better exploits
+ the available hardware parallelism.
+ - loss_details().to_label() is used to convert the network output into a
+ output_label_type.
+ !*/
+
+ template <typename ...T>
+ const output_label_type& process (
+ const input_type& x,
+ T&& ...args
+ );
+ /*!
+ ensures
+ - This function is just like (*this)(x), i.e. it runs a single object, x,
+ through the network and returns the output. But we additionally pass the
+ given args to loss_details().to_label() as the 4th argument (or more,
+ depending on how many things are in args) when converting the network
+ output to an output_label_type. This is useful, for instance, with loss
+ layers like loss_mmod_ which has an optional adjust_threshold argument to
+ to_label() that adjusts the detection threshold. Therefore, for such
+ networks you could call them like: net.process(some_image, -0.5), and -0.5
+ would be passed so the adjust_threshold argument of to_tensor().
+ !*/
+
+ template <typename iterable_type, typename ...T>
+ std::vector<output_label_type> process_batch (
+ const iterable_type& data,
+ size_t batch_size,
+ T&& ...args
+ );
+ /*!
+ requires
+ - batch_size > 0
+ - data must have a .begin() and .end() that supply iterators over a
+ sequence of input_type elements. E.g. data could have a type of
+ std::vector<input_type>
+ ensures
+ - This function is just like (*this)(data,batch_size), i.e. it runs a
+ bunch of objects through the network and returns the outputs. But we
+ additionally pass the given args to loss_details().to_label() as the 4th
+ argument (or more, depending on how many things are in args) when
+ converting the network output to output_label_types. This is useful,
+ for instance, with loss layers like loss_mmod_ which has an optional
+ adjust_threshold argument to to_label() that adjusts the detection
+ threshold. Therefore, for such networks you could call them like:
+ net.process_batch(std::vector<image_type>({some_image, another_image}), 128, -0.5),
+ and -0.5 would be passed so the adjust_threshold argument of to_tensor().
+ !*/
+
+ // -------------
+
+ template <typename label_iterator>
+ double compute_loss (
+ const tensor& x,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ - lbegin == iterator pointing to the start of a range of
+ x.num_samples()/sample_expansion_factor() training_label_type elements.
+ ensures
+ - runs x through the network, compares the output to the expected output
+ pointed to by lbegin, and returns the resulting loss.
+ - for all valid k:
+ - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
+ - This function does not update the network parameters.
+ !*/
+
+ template <typename forward_iterator, typename label_iterator>
+ double compute_loss (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - lbegin == iterator pointing to the start of a range of
+ std::distance(ibegin,iend) training_label_type elements.
+ ensures
+ - runs [ibegin,iend) through the network, compares the output to the
+ expected output pointed to by lbegin, and returns the resulting loss.
+ - for all valid k:
+ - the expected label of *(ibegin+k) is *(lbegin+k).
+ - This function does not update the network parameters.
+ !*/
+
+ // -------------
+
+ double compute_loss (
+ const tensor& x
+ );
+ /*!
+ requires
+ - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type.
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ ensures
+ - runs x through the network and returns the resulting loss.
+ - This function does not update the network parameters.
+ !*/
+
+ template <typename forward_iterator>
+ double compute_loss (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ );
+ /*!
+ requires
+ - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type.
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - runs [ibegin,iend) through the network and returns the resulting loss.
+ - This function does not update the network parameters.
+ !*/
+
+ // -------------
+
+ template <typename label_iterator>
+ double compute_parameter_gradients (
+ const tensor& x,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ - lbegin == iterator pointing to the start of a range of
+ x.num_samples()/sample_expansion_factor() training_label_type elements.
+ ensures
+ - runs x through the network, compares the output to the expected output
+ pointed to by lbegin, and computes parameter and data gradients with
+ respect to the loss, via backpropagation. Specifically, this function
+ updates get_final_data_gradient() and also, for each layer, the tensor
+ returned by get_parameter_gradient().
+ - for all valid k:
+ - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
+ - returns compute_loss(x,lbegin)
+ !*/
+
+ template <typename forward_iterator, typename label_iterator>
+ double compute_parameter_gradients (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - lbegin == iterator pointing to the start of a range of
+ std::distance(ibegin,iend) training_label_type elements.
+ ensures
+ - runs [ibegin,iend) through the network, compares the output to the
+ expected output pointed to by lbegin, and computes parameter and data
+ gradients with respect to the loss, via backpropagation. Specifically,
+ this function updates get_final_data_gradient() and also, for each layer,
+ the tensor returned by get_parameter_gradient().
+ - for all valid k:
+ - the expected label of *(ibegin+k) is *(lbegin+k).
+ - returns compute_loss(ibegin,iend,lbegin)
+ !*/
+
+ double compute_parameter_gradients (
+ const tensor& x
+ );
+ /*!
+ requires
+ - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type.
+ - sample_expansion_factor() != 0
+ (i.e. to_tensor() must have been called to set sample_expansion_factor()
+ to something non-zero.)
+ - x.num_samples()%sample_expansion_factor() == 0
+ - x.num_samples() > 0
+ ensures
+ - runs x through the network and computes parameter and data gradients with
+ respect to the loss, via backpropagation. Specifically, this function
+ updates get_final_data_gradient() and also, for each layer, the tensor
+ returned by get_parameter_gradient().
+ - returns compute_loss(x)
+ !*/
+
+ template <typename forward_iterator>
+ double compute_parameter_gradients (
+ forward_iterator ibegin,
+ forward_iterator iend
+ );
+ /*!
+ requires
+ - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type.
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - runs [ibegin,iend) through the network and computes parameter and data
+ gradients with respect to the loss, via backpropagation. Specifically,
+ this function updates get_final_data_gradient() and also, for each layer,
+ the tensor returned by get_parameter_gradient().
+ - returns compute_loss(ibegin,iend)
+ !*/
+
+ template <typename solver_type>
+ void update_parameters (
+ sstack<solver_type> solvers,
+ double learning_rate
+ );
+ /*!
+ requires
+ - solver_type is an implementation of the EXAMPLE_SOLVER interface defined
+ in solvers_abstract.h
+ - compute_parameter_gradients() has been called.
+ - The given solvers have only ever been used with this network. That
+ is, if you want to call update_parameters() on some other neural network
+ object then you must NOT reuse the same solvers object.
+ - solvers.size() >= num_computational_layers
+ - 0 < learning_rate <= 1
+ ensures
+ - Updates all the parameters in the network. In particular, we pass each
+ layer's parameter gradient (i.e. the tensor returned by the layer's
+ get_parameter_gradient() member) through that layer's corresponding
+ solver object. This produces a parameter delta vector which we add to
+ the layer's parameters.
+ - The solvers use the given learning rate.
+ !*/
+
+ // -------------
+
+ void clean (
+ );
+ /*!
+ ensures
+ - Causes the network to forget about everything but its parameters.
+ - invokes subnet().clean()
+ !*/
+ };
+
+ template <typename T, typename U>
+ std::ostream& operator<<(std::ostream& out, const add_loss_layer<T,U>& item);
+ /*!
+ prints the network architecture to the given output stream.
+ !*/
+
+ template <typename T, typename U>
+ void serialize(const add_loss_layer<T,U>& item, std::ostream& out);
+ template <typename T, typename U>
+ void deserialize(add_loss_layer<T,U>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename ...T>
+ decorator_repeat_group<T...> repeat_group (
+ T&& ...args
+ );
+ /*!
+ ensures
+ - Decorates a group of variables. This is essentially like std::make_tuple()
+ except it's only purpose is to group variables together so they can be passed
+ to the repeat object's constructor.
+ !*/
+
+ template <
+ size_t num,
+ template<typename> class REPEATED_LAYER,
+ typename SUBNET
+ >
+ class repeat
+ {
+ /*!
+ REQUIREMENTS ON num
+ - num > 0
+
+ REQUIREMENTS ON REPEATED_LAYER
+ - REPEATED_LAYER must be a template that stacks more layers onto a deep neural
+ network. For example, if net_type were a network without a loss layer,
+ then it should be legal to create a deeper network with a type of
+ REPEATED_LAYER<net_type>.
+
+ REQUIREMENTS ON SUBNET
+ - One of the following must be true:
+ - SUBNET is an add_layer object.
+ - SUBNET is an add_tag_layer object.
+ - SUBNET is an add_skip_layer object.
+ - SUBNET is a repeat object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object adds more layers to a deep neural network. In particular, it
+ adds REPEATED_LAYER on top of SUBNET num times. So for example, if num were 2 then
+ repeat<2,REPEATED_LAYER,SUBNET> would create a network equivalent to REPEATED_LAYER<REPEATED_LAYER<SUBNET>>.
+
+ Also, this object provides an interface identical to the one defined by the
+ add_layer object except that we add the num_repetitions() and
+ get_repeated_layer() methods. These additions are shown below along with
+ some additional explanatory comments.
+ !*/
+
+ public:
+
+ typedef SUBNET subnet_type;
+ typedef typename SUBNET::input_type input_type;
+ const static size_t num_computational_layers = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers;
+ const static size_t num_layers = (REPEATED_LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers;
+ typedef REPEATED_LAYER<an_unspecified_input_type> repeated_layer_type;
+
+ template <typename T, typename ...U>
+ repeat(
+ T arg1,
+ U ...args2
+ );
+ /*!
+ ensures
+ - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside
+ this object. That is, all the REPEATED_LAYER elements are initialized identically
+ by being given copies of arg1.
+ - The rest of the arguments to the constructor, i.e. args2, are passed to
+ SUBNET's constructor.
+ !*/
+
+ template <typename ...T, typename ...U>
+ repeat(
+ decorator_repeat_group<T...>&& arg1,
+ U ...args2
+ );
+ /*!
+ ensures
+ - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside
+ this object. That is, all the REPEATED_LAYER elements are initialized identically
+ by being given copies of an undecorated arg1.
+ - The rest of the arguments to the constructor, i.e. args2, are passed to
+ SUBNET's constructor.
+ !*/
+
+ size_t num_repetitions (
+ ) const;
+ /*!
+ ensures
+ - returns num (i.e. the number of times REPEATED_LAYER was stacked on top of SUBNET)
+ !*/
+
+ const repeated_layer_type& get_repeated_layer (
+ size_t i
+ ) const;
+ /*!
+ requires
+ - i < num_repetitions()
+ ensures
+ - returns a reference to the i-th instance of REPEATED_LAYER. For example,
+ get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of
+ the network while get_repeated_layer(num_repetitions()-1) returns the
+ instance of REPEATED_LAYER that is stacked immediately on top of SUBNET.
+ !*/
+
+ repeated_layer_type& get_repeated_layer (
+ size_t i
+ );
+ /*!
+ requires
+ - i < num_repetitions()
+ ensures
+ - returns a reference to the i-th instance of REPEATED_LAYER. For example,
+ get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of
+ the network while get_repeated_layer(num_repetitions()-1) returns the
+ instance of REPEATED_LAYER that is stacked immediately on top of SUBNET.
+ !*/
+
+ const subnet_type& subnet(
+ ) const;
+ /*!
+ ensures
+ - returns the SUBNET base network that repeat sits on top of. If you want
+ to access the REPEATED_LAYER components then you must use get_repeated_layer().
+ !*/
+
+ subnet_type& subnet(
+ );
+ /*!
+ ensures
+ - returns the SUBNET base network that repeat sits on top of. If you want
+ to access the REPEATED_LAYER components then you must use get_repeated_layer().
+ !*/
+ };
+
+ template < size_t num, template<typename> class T, typename U >
+ std::ostream& operator<<(std::ostream& out, const repeat<num,T,U>& item);
+ /*!
+ prints the network architecture to the given output stream.
+ !*/
+
+ template < size_t num, template<typename> class T, typename U >
+ void serialize(const repeat<num,T,U>& item, std::ostream& out);
+ template < size_t num, template<typename> class T, typename U >
+ void deserialize(repeat<num,T,U>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long ID,
+ typename SUBNET
+ >
+ class add_tag_layer
+ {
+ /*!
+ REQUIREMENTS ON SUBNET
+ - One of the following must be true:
+ - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in
+ input_abstract.h.
+ - SUBNET is an add_layer object.
+ - SUBNET is an add_tag_layer object.
+ - SUBNET is an add_skip_layer object.
+ - SUBNET is a repeat object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object adds a new layer to a deep neural network. However, this layer
+ simply performs the identity transform. This means it is a no-op and its
+ presence does not change the behavior of the network. It exists solely to
+ be used by add_skip_layer to reference a particular part of a network.
+
+ Also, this object provides an interface identical to the one defined by the
+ add_layer object.
+ !*/
+ };
+
+ template <unsigned long ID, typename U>
+ std::ostream& operator<<(std::ostream& out, const add_tag_layer<ID,U>& item);
+ /*!
+ prints the network architecture to the given output stream.
+ !*/
+
+ template <unsigned long ID, typename U>
+ void serialize(const add_tag_layer<ID,U>& item, std::ostream& out);
+ template <unsigned long ID, typename U>
+ void deserialize(add_tag_layer<ID,U>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename SUBNET> using tag1 = add_tag_layer< 1, SUBNET>;
+ template <typename SUBNET> using tag2 = add_tag_layer< 2, SUBNET>;
+ template <typename SUBNET> using tag3 = add_tag_layer< 3, SUBNET>;
+ template <typename SUBNET> using tag4 = add_tag_layer< 4, SUBNET>;
+ template <typename SUBNET> using tag5 = add_tag_layer< 5, SUBNET>;
+ template <typename SUBNET> using tag6 = add_tag_layer< 6, SUBNET>;
+ template <typename SUBNET> using tag7 = add_tag_layer< 7, SUBNET>;
+ template <typename SUBNET> using tag8 = add_tag_layer< 8, SUBNET>;
+ template <typename SUBNET> using tag9 = add_tag_layer< 9, SUBNET>;
+ template <typename SUBNET> using tag10 = add_tag_layer<10, SUBNET>;
+
+ template <template<typename SUBNET> class tag>
+ struct tag_id
+ {
+ /*!
+ REQUIREMENTS ON tag
+ Tag should be an add_tag_layer template such as tag1, tag2, etc.
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for finding the numeric ID of a tag layer. For example,
+ tag_id<tag3>::id == 3.
+ !*/
+
+ const static unsigned long id;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class TAG_TYPE,
+ typename SUBNET
+ >
+ class add_skip_layer
+ {
+ /*!
+ REQUIREMENTS ON SUBNET
+ - One of the following must be true:
+ - SUBNET is an add_layer object.
+ - SUBNET is an add_tag_layer object.
+ - SUBNET is an add_skip_layer object.
+ - SUBNET is a repeat object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object adds a new layer to a deep neural network which draws its
+ inputs from layer<TAG_TYPE>(subnet()) and performs the identity transform.
+
+ Also, this object provides an interface identical to the one defined by the
+ add_layer object.
+ !*/
+ };
+
+ template <template<typename> class T, typename U>
+ std::ostream& operator<<(std::ostream& out, const add_skip_layer<T,U>& item);
+ /*!
+ prints the network architecture to the given output stream.
+ !*/
+
+ template <template<typename> class T, typename U>
+ void serialize(const add_skip_layer<T,U>& item, std::ostream& out);
+ template <template<typename> class T, typename U>
+ void deserialize(add_skip_layer<T,U>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename SUBNET> using skip1 = add_skip_layer< tag1, SUBNET>;
+ template <typename SUBNET> using skip2 = add_skip_layer< tag2, SUBNET>;
+ template <typename SUBNET> using skip3 = add_skip_layer< tag3, SUBNET>;
+ template <typename SUBNET> using skip4 = add_skip_layer< tag4, SUBNET>;
+ template <typename SUBNET> using skip5 = add_skip_layer< tag5, SUBNET>;
+ template <typename SUBNET> using skip6 = add_skip_layer< tag6, SUBNET>;
+ template <typename SUBNET> using skip7 = add_skip_layer< tag7, SUBNET>;
+ template <typename SUBNET> using skip8 = add_skip_layer< tag8, SUBNET>;
+ template <typename SUBNET> using skip9 = add_skip_layer< tag9, SUBNET>;
+ template <typename SUBNET> using skip10 = add_skip_layer<tag10, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned int i,
+ typename net_type
+ >
+ auto& layer (
+ net_type& n
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - i < net_type::num_layers
+ ensures
+ - This function allows you to access any layer in a network by its layer index
+ i. Therefore, it will walk i steps down the network and return the layer
+ object there. Since networks can be big, the best way to find layer index
+ numbers is to print a network to the screen since the print out will include
+ indexes for each layer.
+ - In general, this function chains together i calls to n.subnet() and returns
+ the result. So for example:
+ - if (i == 0)
+ - returns n
+ - else if (i == 1)
+ - returns n.subnet()
+ - else if (i == 2)
+ - returns n.subnet().subnet()
+ - else if (i == 3)
+ - returns n.subnet().subnet().subnet()
+ - else
+ - etc.
+ Except that when it hits a repeat layer it recurses into the repeated layers
+ contained inside. That is, if the layer index indicates a layer in a repeat
+ object this function will make the appropriate call to get_repeated_layer()
+ and do the right thing.
+ !*/
+
+ template <
+ template<typename> class Match,
+ typename net_type
+ >
+ auto& layer (
+ net_type& n
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ ensures
+ - returns the first layer in n that is of type Match. E.g. if net_type is
+ fc<relu<fc<input<sample_type>>>> then calling layer<relu>(n) would return
+ layer<1>(n), that is, a reference to the relu layer.
+ !*/
+
+ template <
+ template<typename> class Match,
+ unsigned int i,
+ typename net_type
+ >
+ auto& layer (
+ net_type& n
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ ensures
+ - returns layer<i>(layer<Match>(n))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename net_type>
+ auto& input_layer (
+ net_type& net
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ ensures
+ - returns the input later of the given network object. Specifically, this
+ function is equivalent to calling:
+ layer<net_type::num_layers-1>(net);
+ That is, you get the input layer details object for the network.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layer_parameters(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, tensor& t)
+ ensures
+ - Loops over all the computational layers (i.e. layers with parameters, as
+ opposed to loss, tag, or input layers) in net and passes their parameters to
+ v(). To be specific, this function essentially performs the following:
+
+ size_t computational_layer_idx = 0;
+ for (size_t i = 0; i < net_type::num_layers; ++i)
+ {
+ if (layer<i>(net) is a computational layer)
+ {
+ v(computational_layer_idx, layer<i>(net).layer_details().get_layer_params());
+ ++computational_layer_idx;
+ }
+ }
+ - When v() is called, the first argument is always < net_type::num_computational_layers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layer_parameter_gradients(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, tensor& t)
+ ensures
+ - Loops over all the computational layers (i.e. layers with parameters, as
+ opposed to loss, tag, or input layers) in net and passes their parameter
+ gradients to v(). To be specific, this function essentially performs the
+ following:
+
+ size_t computational_layer_idx = 0;
+ for (size_t i = 0; i < net_type::num_layers; ++i)
+ {
+ if (layer<i>(net) is a computational layer)
+ {
+ v(computational_layer_idx, layer<i>(net).get_parameter_gradient());
+ ++computational_layer_idx;
+ }
+ }
+ - When v() is called, the first argument is always < net_type::num_computational_layers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, any_net_type& t)
+ That is, it must take a size_t and then any of the network types such as
+ add_layer, add_loss_layer, etc.
+ ensures
+ - Loops over all the layers in net and calls v() on them. To be specific, this
+ function essentially performs the following:
+
+ for (size_t i = 0; i < net_type::num_layers; ++i)
+ v(i, layer<i>(net));
+ !*/
+
+ template <
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_backwards(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, any_net_type& t)
+ That is, it must take a size_t and then any of the network types such as
+ add_layer, add_loss_layer, etc.
+ ensures
+ - Loops over all the layers in net and calls v() on them. The loop happens in
+ the reverse order of visit_layers(). To be specific, this function
+ essentially performs the following:
+
+ for (size_t i = net_type::num_layers; i != 0; --i)
+ v(i-1, layer<i-1>(net));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ size_t begin,
+ size_t end,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_range(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, any_net_type& t)
+ That is, it must take a size_t and then any of the network types such as
+ add_layer, add_loss_layer, etc.
+ - begin <= end <= net_type::num_layers
+ ensures
+ - Loops over the layers in the range [begin,end) in net and calls v() on them.
+ The loop happens in the reverse order of visit_layers(). To be specific,
+ this function essentially performs the following:
+
+ for (size_t i = begin; i < end; ++i)
+ v(i, layer<i>(net));
+ !*/
+
+ template <
+ size_t begin,
+ size_t end,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_backwards_range(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(size_t idx, any_net_type& t)
+ That is, it must take a size_t and then any of the network types such as
+ add_layer, add_loss_layer, etc.
+ - begin <= end <= net_type::num_layers
+ ensures
+ - Loops over the layers in the range [begin,end) in net and calls v() on them.
+ The loop happens in the reverse order of visit_layers_range(). To be specific,
+ this function essentially performs the following:
+
+ for (size_t i = end; i != begin; --i)
+ v(i-1, layer<i-1>(net));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long tag_id,
+ typename net_type,
+ typename visitor
+ >
+ void visit_layers_until_tag(
+ net_type& net,
+ visitor v
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - v is a function object with a signature equivalent to:
+ v(any_net_type& t)
+ That is, it must take any of the network types such as add_layer,
+ add_loss_layer, etc.
+ ensures
+ - Loops over all the layers in net beginning with layer<0>(net) and going until
+ a tag layer with an ID of tag_id is encountered. To be specific, this
+ function essentially performs the following:
+
+ size_t i = 0;
+ while(layer<i>(net) isn't an add_tag_layer with ID == tag_id) {
+ v(layer<i>(net));
+ ++i;
+ }
+ v(layer<i>(net)); // also visits the tag layer itself at the very end.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct layer_test_results
+ {
+ std::string log;
+ bool was_good;
+
+ operator bool() const { return was_good; }
+ };
+
+ inline std::ostream& operator<< (std::ostream& out, const layer_test_results& item)
+ {
+ out << item.log;
+ return out;
+ }
+
+ template <
+ typename layer_details_type
+ >
+ layer_test_results test_layer (
+ layer_details_type l
+ );
+ /*!
+ ensures
+ - Checks if l correctly implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined in layers_abstract.h. Importantly, it computes numerical approximations
+ to the gradients and compares them to the outputs of the layer.
+ - The results of the testing are returned. In particular, if the returned object
+ is RESULT then we will have:
+ - RESULT.was_good == false if and only if the layer failed the testing.
+ - RESULT.log == a string describing why the testing failed if was_good==false.
+ - Note that this function is only capable of checking layers that take
+ arbitrary subnetworks as input. So if you have designed a layer that expects
+ only a certain restricted type of subnetwork then you might get a compile or
+ runtime error when you call this function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_CORE_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/cpu_dlib.cpp b/ml/dlib/dlib/dnn/cpu_dlib.cpp
new file mode 100644
index 000000000..ed5661102
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cpu_dlib.cpp
@@ -0,0 +1,2170 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CPU_cPP_
+#define DLIB_DNN_CPU_cPP_
+
+// This file contains CPU implementations of the GPU based functions in cuda_dlib.h
+
+#include "cpu_dlib.h"
+#include "tensor_tools.h"
+#include "../image_transforms/interpolation.h"
+#include "../threads.h"
+
+namespace dlib
+{
+ namespace cpu
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() &&
+ dest.nr() == src1.nr() && src1.nr() == src2.nr() &&
+ dest.nc() == src1.nc() && src1.nc() == src2.nc() );
+ const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples());
+ DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) &&
+ (src1.num_samples()==1 || src1.num_samples()==MD) &&
+ (src2.num_samples()==1 || src2.num_samples()==MD) );
+
+ if (dest.size() == 0)
+ return;
+
+ const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size());
+ const auto d = dest.host();
+ const auto s1 = src1.host();
+ const auto s2 = src2.host();
+ if (dest.size() == src1.size() && src1.size() == src2.size())
+ {
+ if (add_to)
+ {
+ for (size_t i = 0; i < src1.size(); ++i)
+ d[i] += s1[i]*s2[i];
+ }
+ else
+ {
+ for (size_t i = 0; i < src1.size(); ++i)
+ d[i] = s1[i]*s2[i];
+ }
+ }
+ else if (dest.num_samples() == 1)
+ {
+ if (!add_to)
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ d[i] = 0;
+ }
+ for (size_t i = 0; i < max_size; ++i)
+ d[i%dest.size()] += s1[i%src1.size()]*s2[i%src2.size()];
+ }
+ else
+ {
+ if (add_to)
+ {
+ for (size_t i = 0; i < max_size; ++i)
+ d[i] += s1[i%src1.size()]*s2[i%src2.size()];
+ }
+ else
+ {
+ for (size_t i = 0; i < max_size; ++i)
+ d[i] = s1[i%src1.size()]*s2[i%src2.size()];
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ auto d = dest.host();
+ auto s1 = src1.host();
+ auto s2 = src2.host();
+ if (have_same_dimensions(dest,src1))
+ {
+ DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k());
+
+ if (add_to)
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ *d++ += (*s1++)*s2[k];
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ *d++ = (*s1++)*s2[k];
+ }
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ DLIB_CASSERT(have_same_dimensions(src1,src2));
+ DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k());
+
+ if (!add_to)
+ {
+ for (long k = 0; k < src1.k(); ++k)
+ d[k] = 0;
+ }
+
+ for (long n = 0; n < src1.num_samples(); ++n)
+ {
+ for (long k = 0; k < src1.k(); ++k)
+ {
+ for (long r = 0; r < src1.nr(); ++r)
+ {
+ for (long c = 0; c < src1.nc(); ++c)
+ {
+ d[k] += (*s1++)*(*s2++);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src) &&
+ scales.num_samples() == src.num_samples() &&
+ scales.k() == src.k() &&
+ scales.nr() == 1 &&
+ scales.nc() == 1 );
+
+ if (dest.size() == 0)
+ return;
+
+ if (add_to)
+ {
+ auto d = dest.host();
+ auto s = src.host();
+ auto scal = scales.host();
+
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ const auto scale = scal[n*scales.k() + k];
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ *d++ += (*s++) * scale;
+ }
+ }
+ }
+ }
+
+
+ }
+ else
+ {
+ auto d = dest.host_write_only();
+ auto s = src.host();
+ auto scal = scales.host();
+
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ const auto scale = scal[n*scales.k() + k];
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ *d++ = (*s++) * scale;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(
+ (have_same_dimensions(src, dest) ||
+ (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
+ (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
+ (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
+ (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) &&
+ is_same_object(src,dest) == false ,
+ "\n\t dest.num_samples(): " << dest.num_samples()
+ <<"\n\t dest.k(): " << dest.k()
+ <<"\n\t dest.nr(): " << dest.nr()
+ <<"\n\t dest.nc(): " << dest.nc()
+ <<"\n\t src.num_samples(): " << src.num_samples()
+ <<"\n\t src.k(): " << src.k()
+ <<"\n\t src.nr(): " << src.nr()
+ <<"\n\t src.nc(): " << src.nc()
+ );
+
+
+ if (beta == 0 && alpha == 0)
+ {
+ dest = 0;
+ return;
+ }
+
+ auto d = dest.host();
+ auto s = src.host();
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ const auto sn = src.num_samples()==1 ? 0:n;
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ const auto sk = src.k()==1 ? 0:k;
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ const auto sr = src.nr()==1 ? 0:r;
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ const auto sc = src.nc()==1 ? 0:c;
+
+ const auto s_idx = ((sn*src.k() + sk)*src.nr() + sr)*src.nc() + sc;
+ *d = beta*(*d) + alpha*s[s_idx];
+ ++d;
+ }
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ auto d = dest.host();
+ auto s1 = src1.host();
+ auto s2 = src2.host();
+
+ // Do the simple and fast version if everything has the same dimensions
+ if (have_same_dimensions(dest, src1) &&
+ have_same_dimensions(dest, src2))
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ d[i] = s1[i] + s2[i];
+ return;
+ }
+
+ // Otherwise, do the more complex version with bounds checking.
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float v1 = 0;
+ float v2 = 0;
+
+ // if this index is inside src1
+ if (n < src1.num_samples() &&
+ k < src1.k() &&
+ r < src1.nr() &&
+ c < src1.nc() )
+ {
+ const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c;
+ v1 = s1[s_idx];
+ }
+
+ // if this index is inside src2
+ if (n < src2.num_samples() &&
+ k < src2.k() &&
+ r < src2.nr() &&
+ c < src2.nc() )
+ {
+ const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c;
+ v2 = s2[s_idx];
+ }
+
+ *d = v1 + v2;
+ ++d;
+ }
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ auto d = dest.host();
+ auto s1 = src1.host();
+ auto s2 = src2.host();
+
+ // Do the simple and fast version if everything has the same dimensions
+ if (have_same_dimensions(dest, src1) &&
+ have_same_dimensions(dest, src2))
+ {
+ if (add_to)
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ d[i] += s1[i] * s2[i];
+ }
+ else
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ d[i] = s1[i] * s2[i];
+ }
+ return;
+ }
+
+ // Otherwise, do the more complex version with bounds checking.
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float v1 = 0;
+ float v2 = 0;
+
+ // if this index is inside src1
+ if (n < src1.num_samples() &&
+ k < src1.k() &&
+ r < src1.nr() &&
+ c < src1.nc() )
+ {
+ const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c;
+ v1 = s1[s_idx];
+ }
+
+ // if this index is inside src2
+ if (n < src2.num_samples() &&
+ k < src2.k() &&
+ r < src2.nr() &&
+ c < src2.nc() )
+ {
+ const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c;
+ v2 = s2[s_idx];
+ }
+
+ if (add_to)
+ *d += v1 * v2;
+ else
+ *d = v1 * v2;
+ ++d;
+ }
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ grad.num_samples() == 1 &&
+ gradient_input.k() == grad.k() &&
+ gradient_input.nr() == grad.nr() &&
+ gradient_input.nc() == grad.nc() &&
+ gradient_input.size() > 0);
+
+ auto out = grad.host();
+ auto in = gradient_input.host();
+
+ for (size_t i = 0; i < grad.size(); ++i)
+ out[i] = *in++;
+
+ for (long j = 1; j < gradient_input.num_samples(); ++j)
+ {
+ for (size_t i = 0; i < grad.size(); ++i)
+ out[i] += *in++;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ grad.num_samples() == 1 &&
+ grad.k() >= 1 &&
+ grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ gradient_input.k() == grad.k() &&
+ gradient_input.size() > 0 &&
+ is_same_object(grad,gradient_input) == false
+ );
+
+ auto g = grad.host();
+ auto gi = gradient_input.host();
+
+ for (long k = 0; k < gradient_input.k(); ++k)
+ g[k] = 0;
+
+ for (long n = 0; n < gradient_input.num_samples(); ++n)
+ {
+ for (long k = 0; k < gradient_input.k(); ++k)
+ {
+ for (long r = 0; r < gradient_input.nr(); ++r)
+ {
+ for (long c = 0; c < gradient_input.nc(); ++c)
+ {
+ g[k] += (*gi++);
+ }
+ }
+ }
+ }
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ )
+ {
+ DLIB_CASSERT(dest.size()==src.size());
+ const auto d = dest.host();
+ const auto s = src.host();
+ for (size_t i = 0; i < src.size(); ++i)
+ d[i] = A*s[i] + B;
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ const auto d = dest.host();
+ const auto s1 = src1.host();
+ const auto s2 = src2.host();
+ for (size_t i = 0; i < src1.size(); ++i)
+ d[i] = A*s1[i] + B*s2[i] + C;
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ DLIB_CASSERT(dest.size()==src3.size());
+ const auto d = dest.host();
+ const auto s1 = src1.host();
+ const auto s2 = src2.host();
+ const auto s3 = src3.host();
+ for (size_t i = 0; i < src1.size(); ++i)
+ d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D;
+ }
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ DLIB_CASSERT(dest.size()==src3.size());
+ DLIB_CASSERT(begin <= end && end <= dest.size());
+ const auto d = dest.host();
+ const auto s1 = src1.host();
+ const auto s2 = src2.host();
+ const auto s3 = src3.host();
+ for (size_t i = begin; i < end; ++i)
+ d[i] = A*s1[i] + B*s2[i] + C*s3[i];
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ DLIB_CASSERT(
+ ((A.num_samples()==1 && B.num_samples()==1) ||
+ (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) &&
+ A.nr()==B.nr() && B.nr()==src.nr() &&
+ A.nc()==B.nc() && B.nc()==src.nc() &&
+ A.k() ==B.k() && B.k()==src.k());
+
+ auto d = dest.host();
+ auto s = src.host();
+ const auto a = A.host();
+ const auto b = B.host();
+ if (A.num_samples() == 1)
+ {
+ const long num = src.size()/src.num_samples();
+ for (long i = 0; i < src.num_samples(); ++i)
+ {
+ for (long j = 0; j < num; ++j)
+ {
+ *d = a[j]*(*s) + b[j];
+ d++;
+ s++;
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < src.size(); ++i)
+ d[i] = a[i]*s[i] + b[i];
+ }
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ DLIB_CASSERT(have_same_dimensions(A,B));
+ DLIB_CASSERT(A.num_samples() == 1 &&
+ A.nr() == 1 &&
+ A.nc() == 1 &&
+ A.k() == src.k());
+
+ auto d = dest.host();
+ auto s = src.host();
+ const auto a = A.host();
+ const auto b = B.host();
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ *d++ = a[k]*(*s++) + b[k];
+ }
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ )
+ {
+ DLIB_CASSERT(dest.size() == src1.size());
+ DLIB_CASSERT(dest.size() == src2.size());
+ DLIB_CASSERT(dest.size() == src3.size());
+ DLIB_CASSERT(dest.num_samples() == src1.num_samples());
+ DLIB_CASSERT(dest.num_samples() == src2.num_samples());
+ DLIB_CASSERT(dest.num_samples() == src3.num_samples());
+ DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect));
+
+
+ auto d = dest.host();
+ auto s1 = src1.host();
+ auto s2 = src2.host();
+ auto s3 = src3.host();
+
+ const auto nc = dest.size()/dest.num_samples();
+
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ auto idx = r*nc + c;
+ d[idx] = s1[idx]*A + s2[idx]*B + s3[idx]*C;
+ }
+ }
+
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ )
+ {
+ DLIB_CASSERT(s.size() == m.size() &&
+ s.size() == v.size() &&
+ s.size() == params.size() &&
+ s.size() == params_grad.size());
+ DLIB_CASSERT(begin <= end && end <= params.size());
+ const float eps = 1e-8;
+ const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));
+
+ // The loop is equivalent to doing this:
+ // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
+ // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
+ // s = -alpha*m/(sqrt(v) + eps);
+ auto pm = m.host();
+ auto pv = v.host();
+ auto ps = s.host_write_only();
+ auto pparams = params.host();
+ auto ppgrad = params_grad.host();
+ for (size_t i = begin; i < end; ++i)
+ {
+ float g = weight_decay*pparams[i] + ppgrad[i];
+ pm[i] = momentum1*pm[i] + (1-momentum1)*g;
+ pv[i] = momentum2*pv[i] + (1-momentum2)*g*g;
+ ps[i] = -alpha*pm[i]/(std::sqrt(pv[i]) + eps);
+ }
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+ DLIB_CASSERT(
+ gamma.num_samples() == 1 &&
+ gamma.nr() == src.nr() &&
+ gamma.nc() == src.nc() &&
+ gamma.k() == src.k() &&
+ have_same_dimensions(gamma, beta) &&
+ have_same_dimensions(gamma, running_means) &&
+ have_same_dimensions(gamma, running_variances) &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nrunning_means.num_samples(): " << running_means.num_samples() <<
+ "\nrunning_means.k(): " << running_means.k() <<
+ "\nrunning_means.nr(): " << running_means.nr() <<
+ "\nrunning_means.nc(): " << running_means.nc() <<
+ "\nrunning_variances.num_samples(): " << running_variances.num_samples() <<
+ "\nrunning_variances.k(): " << running_variances.k() <<
+ "\nrunning_variances.nr(): " << running_variances.nr() <<
+ "\nrunning_variances.nc(): " << running_variances.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+ dest.copy_size(src);
+
+ auto d = dest.host();
+ auto s = src.host();
+ auto g = gamma.host();
+ auto b = beta.host();
+ auto m = running_means.host();
+ auto v = running_variances.host();
+
+ const long num = src.k()*src.nr()*src.nc();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < num; ++k)
+ {
+ *d = g[k]*(*s - m[k])/std::sqrt(v[k]+eps) + b[k];
+ ++d;
+ ++s;
+ }
+ }
+ }
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+ DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor);
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means));
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds));
+ DLIB_CASSERT(
+ src.num_samples() > 1 &&
+ gamma.num_samples() == 1 &&
+ beta.num_samples() == 1 &&
+ gamma.nr() == beta.nr() && beta.nr() == src.nr() &&
+ gamma.nc() == beta.nc() && beta.nc() == src.nc() &&
+ gamma.k() == beta.k() && beta.k() == src.k() &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+
+ dest.copy_size(src);
+ means.set_size(1, src.k(), src.nr(), src.nc());
+ invstds.set_size(1, src.k(), src.nr(), src.nc());
+
+ // first compute means and invstds
+ means = 0;
+ invstds = 0;
+ const auto p_invstds = invstds.host();
+ const auto p_means = means.host();
+ auto p_src = src.host();
+ const long num = src.k()*src.nr()*src.nc();
+ // compute means, and sum of squares
+ for (long i = 0; i < num; ++i)
+ {
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ float val = p_src[n*num+i];
+ p_means[i] += val;
+ p_invstds[i] += val*val;
+ }
+ }
+ means /= src.num_samples();
+ invstds /= src.num_samples();
+ // copy data back to host
+ invstds.host(); means.host();
+
+ // compute variances
+ running_variances.copy_size(invstds);
+ auto rvar = running_variances.host();
+ // This scale makes the running variances unbiased.
+ const double scale = (src.num_samples())/(src.num_samples()-1.0);
+ for (long i = 0; i < num; ++i)
+ {
+ auto actual_var = p_invstds[i] - p_means[i]*p_means[i];
+ if (averaging_factor == 1)
+ rvar[i] = scale*actual_var;
+ else
+ rvar[i] = (1-averaging_factor)*rvar[i] + scale*averaging_factor*actual_var;
+
+ p_invstds[i] = 1.0f/std::sqrt(actual_var + eps);
+ }
+
+ p_src = src.host();
+ auto p_dest = dest.host();
+ const auto p_gamma = gamma.host();
+ const auto p_beta = beta.host();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ *p_dest = (*p_src - p_means[i])*p_invstds[i];
+ *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
+ ++p_src;
+ ++p_dest;
+ }
+ }
+
+ // now keep track of the running means
+ running_means.copy_size(means);
+ if (averaging_factor != 1)
+ running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
+ else
+ running_means = means;
+ }
+
+ void batch_normalize_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+
+ const long num = src.k()*src.nr()*src.nc();
+ DLIB_CASSERT(src.num_samples() > 1);
+ DLIB_CASSERT(num == (long)means.size());
+ DLIB_CASSERT(num == (long)invstds.size());
+ DLIB_CASSERT(num == (long)gamma.size());
+ DLIB_CASSERT(num == (long)gamma_grad.size());
+ DLIB_CASSERT(num == (long)beta_grad.size());
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+ DLIB_CASSERT(eps > 0);
+
+ beta_grad = 0;
+ gamma_grad = 0;
+ auto p_grad = gradient_input.host();
+ auto p_src = src.host();
+ const auto p_gamma = gamma.host();
+ const auto p_gamma_grad = gamma_grad.host();
+ const auto p_beta_grad = beta_grad.host();
+ const auto p_invstds = invstds.host();
+ const auto p_means = means.host();
+
+ resizable_tensor dvars, dmeans;
+ dvars.copy_size(invstds);
+ dmeans.copy_size(means);
+ dvars = 0;
+ dmeans = 0;
+ const auto p_dvars = dvars.host();
+ const auto p_dmeans = dmeans.host();
+
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ const float x_hat = (*p_src - p_means[i])*p_invstds[i];
+ p_beta_grad[i] += *p_grad;
+ p_gamma_grad[i] += (*p_grad)*x_hat;
+
+ const float dx = *p_grad * p_gamma[i];
+
+ p_dvars[i] += dx*(*p_src - p_means[i])*-0.5*std::pow(p_invstds[i], 3.0f);
+
+ ++p_grad;
+ ++p_src;
+ }
+ }
+
+ const float invnum = 1.0f/src.num_samples();
+ p_grad = gradient_input.host();
+ p_src = src.host();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ const float dx = *p_grad * p_gamma[i];
+
+ p_dmeans[i] += dx*-p_invstds[i] + p_dvars[i] * -2*(*p_src - p_means[i])*invnum;
+
+ ++p_grad;
+ ++p_src;
+ }
+ }
+ p_grad = gradient_input.host();
+ p_src = src.host();
+ auto p_src_grad = src_grad.host();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ const float dx = *p_grad * p_gamma[i];
+
+ *p_src_grad += dx*p_invstds[i] +
+ p_dvars[i] *2*(*p_src - p_means[i])*invnum +
+ p_dmeans[i]*invnum;
+
+
+ ++p_grad;
+ ++p_src;
+ ++p_src_grad;
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+ DLIB_CASSERT(
+ gamma.num_samples() == 1 &&
+ gamma.nr() == 1 &&
+ gamma.nc() == 1 &&
+ gamma.k() == src.k() &&
+ have_same_dimensions(gamma, beta) &&
+ have_same_dimensions(gamma, running_means) &&
+ have_same_dimensions(gamma, running_variances) &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nrunning_means.num_samples(): " << running_means.num_samples() <<
+ "\nrunning_means.k(): " << running_means.k() <<
+ "\nrunning_means.nr(): " << running_means.nr() <<
+ "\nrunning_means.nc(): " << running_means.nc() <<
+ "\nrunning_variances.num_samples(): " << running_variances.num_samples() <<
+ "\nrunning_variances.k(): " << running_variances.k() <<
+ "\nrunning_variances.nr(): " << running_variances.nr() <<
+ "\nrunning_variances.nc(): " << running_variances.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+ dest.copy_size(src);
+
+ auto d = dest.host();
+ auto s = src.host();
+ auto g = gamma.host();
+ auto b = beta.host();
+ auto m = running_means.host();
+ auto v = running_variances.host();
+
+ const long num = src.nr()*src.nc();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ const float invstd = 1.0f/std::sqrt(v[k] + eps);
+ for (long j = 0; j < num; ++j)
+ {
+ *d = g[k]*(*s - m[k])*invstd + b[k];
+ ++d;
+ ++s;
+ }
+ }
+ }
+ }
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+ DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor);
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means));
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds));
+ DLIB_CASSERT(
+ src.num_samples() > 1 &&
+ gamma.num_samples() == 1 &&
+ beta.num_samples() == 1 &&
+ gamma.nr() == 1 &&
+ beta.nr() == 1 &&
+ gamma.nc() == 1 &&
+ beta.nc() == 1 &&
+ gamma.k() == beta.k() && beta.k() == src.k() &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+
+ dest.copy_size(src);
+ means.set_size(1, src.k());
+ invstds.set_size(1, src.k());
+
+ // first compute means and invstds
+ means = 0;
+ invstds = 0;
+ const auto p_invstds = invstds.host();
+ const auto p_means = means.host();
+ const auto p_gamma = gamma.host();
+ const auto p_beta = beta.host();
+ auto p_src = src.host();
+ const long num = src.nr()*src.nc();
+ // compute means, and sum of squares
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ p_means[k] += *p_src;
+ p_invstds[k] += (*p_src)*(*p_src);
+ ++p_src;
+ }
+ }
+ }
+ means /= src.num_samples()*num;
+ invstds /= src.num_samples()*num;
+ // copy data back to host
+ invstds.host(); means.host();
+
+ p_src = src.host();
+ // compute variances
+ running_variances.copy_size(invstds);
+ auto rvar = running_variances.host();
+ // This scale makes the running variances unbiased.
+ const double scale = (src.num_samples()*num)/(src.num_samples()*num-1.0);
+ for (long k = 0; k < src.k(); ++k)
+ {
+ float actual_var = p_invstds[k] - p_means[k]*p_means[k];
+ if (averaging_factor == 1)
+ rvar[k] = scale*actual_var;
+ else
+ rvar[k] = (1-averaging_factor)*rvar[k] + scale*averaging_factor*actual_var;
+
+ p_invstds[k] = 1.0f/std::sqrt(actual_var + eps);
+ }
+
+ p_src = src.host();
+ auto p_dest = dest.host();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ *p_dest = (*p_src - p_means[k])*p_invstds[k];
+ *p_dest = (*p_dest)*p_gamma[k] + p_beta[k];
+ ++p_src;
+ ++p_dest;
+ }
+ }
+ }
+
+ // now keep track of the running means
+ running_means.copy_size(means);
+ if (averaging_factor != 1)
+ running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
+ else
+ running_means = means;
+ }
+
+ void batch_normalize_conv_gradient(
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+
+ const long num = src.nr()*src.nc();
+ DLIB_CASSERT(src.num_samples() > 1);
+ DLIB_CASSERT(src.k() == (long)means.size());
+ DLIB_CASSERT(src.k() == (long)invstds.size());
+ DLIB_CASSERT(src.k() == (long)gamma.size());
+ DLIB_CASSERT(src.k() == (long)gamma_grad.size());
+ DLIB_CASSERT(src.k() == (long)beta_grad.size());
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+ DLIB_CASSERT(eps > 0);
+
+ beta_grad = 0;
+ gamma_grad = 0;
+
+ auto p_grad = gradient_input.host();
+ auto p_src = src.host();
+ const auto p_gamma = gamma.host();
+ const auto p_gamma_grad = gamma_grad.host();
+ const auto p_beta_grad = beta_grad.host();
+ const auto p_invstds = invstds.host();
+ const auto p_means = means.host();
+
+ resizable_tensor dvars, dmeans;
+ dvars.copy_size(invstds);
+ dmeans.copy_size(means);
+ dvars = 0;
+ dmeans = 0;
+ const auto p_dvars = dvars.host();
+ const auto p_dmeans = dmeans.host();
+
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ const float invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f);
+ for (long i = 0; i < num; ++i)
+ {
+ const float x_hat = (*p_src - p_means[k])*p_invstds[k];
+ p_beta_grad[k] += *p_grad;
+ p_gamma_grad[k] += (*p_grad)*x_hat;
+
+ const float dx = *p_grad * p_gamma[k];
+
+ p_dvars[k] += dx*(*p_src - p_means[k])*invstd_pow;
+
+ ++p_grad;
+ ++p_src;
+ }
+ }
+ }
+
+ p_grad = gradient_input.host();
+ p_src = src.host();
+ const float invnum = 1.0f/(src.num_samples()*num);
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ const float dx = *p_grad * p_gamma[k];
+
+ p_dmeans[k] += -dx*p_invstds[k] + p_dvars[k] * -2*(*p_src - p_means[k])*invnum;
+
+ ++p_grad;
+ ++p_src;
+ }
+ }
+ }
+ p_grad = gradient_input.host();
+ p_src = src.host();
+ auto p_src_grad = src_grad.host();
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ for (long k = 0; k < src.k(); ++k)
+ {
+ for (long i = 0; i < num; ++i)
+ {
+ const float dx = *p_grad * p_gamma[k];
+
+ *p_src_grad += dx*p_invstds[k] +
+ p_dvars[k]*2*(*p_src - p_means[k])*invnum +
+ p_dmeans[k]*invnum;
+
+
+ ++p_grad;
+ ++p_src;
+ ++p_src_grad;
+ }
+ }
+ }
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void threshold (
+ tensor& data,
+ float thresh
+ )
+ {
+ const auto d = data.host();
+ for (size_t i = 0; i < data.size(); ++i)
+ d[i] = d[i]>thresh ? 1:0;
+ }
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ )
+ {
+ DLIB_CASSERT(a.size() == b.size());
+ DLIB_CASSERT(idx < result.size());
+
+ const auto aa = a.host();
+ const auto bb = b.host();
+ auto r = result.host();
+ for (size_t i = 0; i < a.size(); ++i)
+ r[idx] += aa[i]*bb[i];
+ }
+
+ // -----------------------------------------------------------------------------------
+ // -----------------------------------------------------------------------------------
+ // -----------------------------------------------------------------------------------
+
+ namespace ttimpl
+ {
+ void softmax (
+ const long num_locations,
+ const long num_channels,
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_ASSERT(num_channels*num_locations == src.nr()*src.nc()*src.k());
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ const auto d = dest.host();
+ const auto s = src.host();
+
+ // Note that we subtract out the max values in each channel before applying
+ // exp() to avoid numeric overflow in the subsequent computations. Doing this
+ // doesn't change the resulting output, it just makes it more numerically
+ // stable.
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ auto ss = s + num_locations*num_channels*n;
+ auto dd = d + num_locations*num_channels*n;
+ for (long i = 0; i < num_locations; ++i)
+ {
+ float max_val = -std::numeric_limits<float>::infinity();
+ for (long k = 0; k < num_channels; ++k)
+ max_val = std::max(max_val, ss[k*num_locations]);
+
+ for (long k = 0; k < num_channels; ++k)
+ dd[k*num_locations] = std::exp(ss[k*num_locations]-max_val);
+
+ ++ss;
+ ++dd;
+ }
+ }
+
+ // Now normalize each channel so they sum to 1.
+ for (long n = 0; n < src.num_samples(); ++n)
+ {
+ const auto dd = d + num_locations*num_channels*n;
+ for (long i = 0; i < num_locations; ++i)
+ {
+ const auto ddd = dd+i;
+
+ float temp = 0;
+ for (long k = 0; k < num_channels; ++k)
+ temp += ddd[k*num_locations];
+ for (long k = 0; k < num_channels; ++k)
+ ddd[k*num_locations] /= temp;
+ }
+ }
+ }
+
+ void softmax_gradient (
+ const long num_locations,
+ const long num_channels,
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_ASSERT(num_channels*num_locations == grad.nr()*grad.nc()*grad.k());
+ DLIB_CASSERT(have_same_dimensions(grad,dest));
+ DLIB_CASSERT(have_same_dimensions(grad,gradient_input));
+ const auto d = dest.host();
+ const auto g = grad.host();
+ const auto in = gradient_input.host();
+
+
+ for (long n = 0; n < grad.num_samples(); ++n)
+ {
+ const auto d2 = d + num_locations*num_channels*n;
+ const auto g2 = g + num_locations*num_channels*n;
+ const auto in2 = in + num_locations*num_channels*n;
+ for (long i = 0; i < num_locations; ++i)
+ {
+ const auto d3 = d2+i;
+ const auto g3 = g2+i;
+ const auto in3 = in2+i;
+
+ float temp = 0;
+ for (long k = 0; k < num_channels; ++k)
+ temp += -d3[k*num_locations]*in3[k*num_locations];
+ if (is_same_object(gradient_input, grad))
+ {
+ for (long k = 0; k < num_channels; ++k)
+ g3[k*num_locations] = d3[k*num_locations]*(temp+in3[k*num_locations]);
+ }
+ else
+ {
+ for (long k = 0; k < num_channels; ++k)
+ g3[k*num_locations] += d3[k*num_locations]*(temp+in3[k*num_locations]);
+ }
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ ttimpl::softmax(src.nr()*src.nc(), src.k(), dest, src);
+ }
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(grad,dest));
+ DLIB_CASSERT(have_same_dimensions(grad,gradient_input));
+ ttimpl::softmax_gradient(grad.nr()*grad.nc(), grad.k(), grad, dest, gradient_input);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ ttimpl::softmax(1, src.nr()*src.nc()*src.k(), dest, src);
+ }
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(grad,dest));
+ DLIB_CASSERT(have_same_dimensions(grad,gradient_input));
+ ttimpl::softmax_gradient(1, grad.nr()*grad.nc()*grad.k(), grad, dest, gradient_input);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ const auto d = dest.host();
+ const auto s = src.host();
+ for (size_t i = 0; i < src.size(); ++i)
+ d[i] = 1/(1+std::exp(-s[i]));
+ }
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ const auto g = grad.host();
+ const auto d = dest.host();
+ const auto in = gradient_input.host();
+ if (is_same_object(gradient_input, grad))
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ g[i] = in[i]*d[i]*(1-d[i]);
+ }
+ else
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ g[i] += in[i]*d[i]*(1-d[i]);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ dest = lowerbound(mat(src), 0);
+ }
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ const float* gi = gradient_input.host();
+ const float* in = dest.host();
+ float* out = grad.host();
+ if (is_same_object(grad, gradient_input))
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ {
+ if (in[i] > 0)
+ out[i] = gi[i];
+ else
+ out[i] = 0;
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ {
+ if (in[i] > 0)
+ out[i] += gi[i];
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ )
+ {
+ const float p = param.host()[0];
+ const float* s = src.host();
+ float* d = dest.host();
+ for (size_t i = 0; i < dest.size(); ++i)
+ {
+ if (s[i] > 0)
+ d[i] = s[i];
+ else
+ d[i] = p*s[i];
+ }
+ }
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ )
+ {
+ DLIB_CASSERT(is_same_object(grad, gradient_input) == false);
+ const float p = param.host()[0];
+ const float* gi = gradient_input.host();
+ const float* s = src.host();
+ float* out = grad.host();
+ float pgrad = 0;
+ for (size_t i = 0; i < src.size(); ++i)
+ {
+ if (s[i] > 0)
+ {
+ out[i] += gi[i];
+ }
+ else
+ {
+ out[i] += p*gi[i];
+ pgrad += gi[i]*s[i];
+ }
+ }
+ params_grad.host()[0] = pgrad;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ const auto d = dest.host();
+ const auto s = src.host();
+ for (size_t i = 0; i < src.size(); ++i)
+ d[i] = std::tanh(s[i]);
+ }
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ const auto g = grad.host();
+ const auto d = dest.host();
+ const auto in = gradient_input.host();
+ if (is_same_object(grad, gradient_input))
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ g[i] = in[i]*(1-d[i]*d[i]);
+ }
+ else
+ {
+ for (size_t i = 0; i < dest.size(); ++i)
+ g[i] += in[i]*(1-d[i]*d[i]);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ )
+ {
+ DLIB_CASSERT(is_same_object(dest, src)==false);
+ DLIB_CASSERT(dest.num_samples() == src.num_samples());
+ DLIB_CASSERT(dest.k() == src.k());
+
+ if (dest.size() == 0 || src.size() == 0)
+ return;
+
+ const float* s = src.host();
+ float* d = dest.host();
+
+ parallel_for(0, dest.k()*dest.num_samples(), [&](long i)
+ {
+ auto simg = sub_image(s+i*src_channel_stride, src.nr(), src.nc(), src_row_stride);
+ auto dimg = sub_image(d+i*dest_channel_stride, dest.nr(), dest.nc(), dest_row_stride);
+
+ resize_image(simg, dimg);
+ });
+ }
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ )
+ {
+ DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
+ DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples());
+ DLIB_CASSERT(gradient_input.k() == grad.k());
+
+ if (gradient_input.size() == 0 || grad.size() == 0)
+ return;
+
+ const float* gi = gradient_input.host();
+ float* g = grad.host();
+ const float x_scale = (grad.nc()-1)/(float)std::max<long>((gradient_input.nc()-1),1);
+ const float y_scale = (grad.nr()-1)/(float)std::max<long>((gradient_input.nr()-1),1);
+ for (long long samp = 0; samp < gradient_input.num_samples(); ++samp)
+ {
+ for (long long k = 0; k < gradient_input.k(); ++k)
+ {
+ for (long long r = 0; r < gradient_input.nr(); ++r)
+ {
+ const float y = r*y_scale;
+ const long long top = static_cast<long long>(std::floor(y));
+ const long long bottom = std::min(top+1, grad.nr()-1);
+ const float tb_frac = y - top;
+ for (long long c = 0; c < gradient_input.nc(); ++c)
+ {
+ const float x = c*x_scale;
+ const long long left = static_cast<long long>(std::floor(x));
+ const long long right = std::min(left+1, grad.nc()-1);
+ const float lr_frac = x - left;
+
+ const float tmp = gi[r*gradient_input_row_stride+c];
+
+ g[top*grad_row_stride+left] += tmp*(1-tb_frac)*(1-lr_frac);
+ g[top*grad_row_stride+right] += tmp*(1-tb_frac)*(lr_frac);
+ g[bottom*grad_row_stride+left] += tmp*(tb_frac)*(1-lr_frac);
+ g[bottom*grad_row_stride+right] += tmp*(tb_frac)*(lr_frac);
+ }
+ }
+
+ g += grad_channel_stride;
+ gi += gradient_input_channel_stride;
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ pooling::pooling (
+ ) : window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0),padding_x(0),do_max_pooling(true)
+ {
+ }
+
+ void pooling::
+ clear(
+ )
+ {
+ window_height = 0;
+ window_width = 0;
+ stride_y = 0;
+ stride_x = 0;
+ padding_y = 0;
+ padding_x = 0;
+ }
+
+ void pooling::
+ setup_max_pooling(
+ int window_height_,
+ int window_width_,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_
+ )
+ {
+ DLIB_CASSERT(window_width_ > 0);
+ DLIB_CASSERT(window_height_ > 0);
+ DLIB_CASSERT(stride_y_ > 0);
+ DLIB_CASSERT(stride_x_ > 0);
+ DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_);
+ DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_);
+
+ window_height = window_height_;
+ window_width = window_width_;
+ stride_y = stride_y_;
+ stride_x = stride_x_;
+ padding_y = padding_y_;
+ padding_x = padding_x_;
+ do_max_pooling = true;
+ }
+
+ void pooling::
+ setup_avg_pooling(
+ int window_height_,
+ int window_width_,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_
+ )
+ {
+ DLIB_CASSERT(window_width_ > 0);
+ DLIB_CASSERT(window_height_ > 0);
+ DLIB_CASSERT(stride_y_ > 0);
+ DLIB_CASSERT(stride_x_ > 0);
+ DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_);
+ DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_);
+
+ window_height = window_height_;
+ window_width = window_width_;
+ stride_y = stride_y_;
+ stride_x = stride_x_;
+ padding_y = padding_y_;
+ padding_x = padding_x_;
+ do_max_pooling = false;
+ }
+
+ void pooling::
+ operator() (
+ resizable_tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(window_width > 0);
+ DLIB_CASSERT(window_height > 0);
+ DLIB_CASSERT(stride_y > 0);
+ DLIB_CASSERT(stride_x > 0);
+ DLIB_CASSERT(0 <= padding_y && padding_y < window_height);
+ DLIB_CASSERT(0 <= padding_x && padding_x < window_width);
+ DLIB_CASSERT(window_width <= src.nc() + 2*padding_x,
+ "Pooling windows must be small enough to fit into the padded image.");
+ DLIB_CASSERT(window_height <= src.nr() + 2*padding_y,
+ "Pooling windows must be small enough to fit into the padded image.");
+
+ dest.set_size(
+ src.num_samples(),
+ src.k(),
+ 1+(src.nr()+2*padding_y-window_height)/stride_y,
+ 1+(src.nc()+2*padding_x-window_width)/stride_x
+ );
+
+ if (src.size() == 0)
+ {
+ dest = 0;
+ return;
+ }
+
+
+ auto d = dest.host();
+ const long x_offset = window_width/2 - padding_x;
+ const long y_offset = window_height/2 - padding_y;
+ if (does_max_pooling())
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ auto simg = image_plane(src,n,k);
+ auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc();
+
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ auto win = centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height);
+ dimg[r*dest.nc() + c] = max(subm_clipped(simg,win));
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ auto simg = image_plane(src,n,k);
+ auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc();
+
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ auto win = centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height);
+ dimg[r*dest.nc() + c] = mean(subm_clipped(simg,win));
+ }
+ }
+ }
+ }
+ }
+
+ }
+
+ void pooling::get_gradient(
+ const tensor& gradient_input,
+ const tensor& dest,
+ const tensor& src,
+ tensor& grad
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(gradient_input,dest));
+ DLIB_CASSERT(have_same_dimensions(src,grad));
+
+
+ if (src.size() == 0)
+ {
+ return;
+ }
+
+
+ auto gi = gradient_input.host();
+ auto g = grad.host();
+ const long x_offset = window_width/2 - padding_x;
+ const long y_offset = window_height/2 - padding_y;
+ if (does_max_pooling())
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ auto simg = image_plane(src,n,k);
+ auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc();
+ auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc();
+ auto imgbox = get_rect(simg);
+
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ auto win = centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height).intersect(imgbox);
+ auto p = max_point(subm(simg,win))+win.tl_corner();
+ gimg[p.y()*grad.nc()+p.x()] += giimg[r*dest.nc()+c];
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ for (long n = 0; n < dest.num_samples(); ++n)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ auto simg = image_plane(src,n,k);
+ auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc();
+ auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc();
+ auto imgbox = get_rect(simg);
+
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ auto win = centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height).intersect(imgbox);
+ const float delta = giimg[r*dest.nc()+c]/win.area();
+ for (long y = win.top(); y <= win.bottom(); ++y)
+ {
+ for (long x = win.left(); x <= win.right(); ++x)
+ {
+ gimg[y*grad.nc()+x] += delta;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ void img2col(
+ matrix<float>& output,
+ const tensor& data,
+ long n,
+ long filter_nr,
+ long filter_nc,
+ long stride_y,
+ long stride_x,
+ long padding_y,
+ long padding_x
+ )
+ {
+ const auto d = data.host() + data.k()*data.nr()*data.nc()*n;
+ const rectangle boundary = get_rect(data);
+
+ const long out_nr = 1+(data.nr()+2*padding_y-filter_nr)/stride_y;
+ const long out_nc = 1+(data.nc()+2*padding_x-filter_nc)/stride_x;
+
+ output.set_size(out_nr*out_nc,
+ data.k()*filter_nr*filter_nc);
+ DLIB_CASSERT(output.size() != 0);
+ float* t = &output(0,0);
+
+ // now fill in the Toeplitz output matrix for the n-th sample in data.
+ size_t cnt = 0;
+ const long max_r = data.nr() + padding_y-(filter_nr-1);
+ const long max_c = data.nc() + padding_x-(filter_nc-1);
+ for (long r = -padding_y; r < max_r; r+=stride_y)
+ {
+ for (long c = -padding_x; c < max_c; c+=stride_x)
+ {
+ for (long k = 0; k < data.k(); ++k)
+ {
+ for (long y = 0; y < filter_nr; ++y)
+ {
+ for (long x = 0; x < filter_nc; ++x)
+ {
+ DLIB_ASSERT(cnt < output.size());
+ long xx = c+x;
+ long yy = r+y;
+ if (boundary.contains(xx,yy))
+ *t = d[(k*data.nr() + yy)*data.nc() + xx];
+ else
+ *t = 0;
+ ++t;
+ ++cnt;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void col2img(
+ const matrix<float>& output,
+ tensor& data,
+ long n,
+ long filter_nr,
+ long filter_nc,
+ long stride_y,
+ long stride_x,
+ long padding_y,
+ long padding_x
+ )
+ {
+ const auto d = data.host() + data.k()*data.nr()*data.nc()*n;
+ const rectangle boundary = get_rect(data);
+
+ DLIB_CASSERT(output.size() != 0);
+ const float* t = &output(0,0);
+
+ // now fill in the Toeplitz output matrix for the n-th sample in data.
+ const long max_r = data.nr() + padding_y-(filter_nr-1);
+ const long max_c = data.nc() + padding_x-(filter_nc-1);
+ for (long r = -padding_y; r < max_r; r+=stride_y)
+ {
+ for (long c = -padding_x; c < max_c; c+=stride_x)
+ {
+ for (long k = 0; k < data.k(); ++k)
+ {
+ for (long y = 0; y < filter_nr; ++y)
+ {
+ for (long x = 0; x < filter_nc; ++x)
+ {
+ long xx = c+x;
+ long yy = r+y;
+ if (boundary.contains(xx,yy))
+ d[(k*data.nr() + yy)*data.nc() + xx] += *t;
+ ++t;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void tensor_conv::operator() (
+ const bool add_to_output,
+ resizable_tensor& output,
+ const tensor& data,
+ const tensor& filters
+ )
+ {
+ DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function.");
+ output.set_size(data.num_samples(),
+ filters.num_samples(),
+ 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y,
+ 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x);
+ (*this)(add_to_output, static_cast<tensor&>(output),data,filters);
+ }
+
+ void tensor_conv::operator() (
+ const bool add_to_output,
+ tensor& output,
+ const tensor& data,
+ const tensor& filters
+ )
+ {
+ DLIB_CASSERT(is_same_object(output,data) == false);
+ DLIB_CASSERT(is_same_object(output,filters) == false);
+ DLIB_CASSERT(filters.k() == data.k());
+ DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function.");
+ DLIB_CASSERT(filters.nr() <= data.nr() + 2*last_padding_y,
+ "Filter windows must be small enough to fit into the padded image.");
+ DLIB_CASSERT(filters.nc() <= data.nc() + 2*last_padding_x,
+ "Filter windows must be small enough to fit into the padded image.");
+
+ DLIB_CASSERT(output.num_samples() == data.num_samples());
+ DLIB_CASSERT(output.k() == filters.num_samples());
+ DLIB_CASSERT(output.nr() == 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y);
+ DLIB_CASSERT(output.nc() == 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x);
+
+
+ matrix<float> temp;
+ for (long n = 0; n < data.num_samples(); ++n)
+ {
+ img2col(temp, data, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x);
+
+ if (add_to_output)
+ output.add_to_sample(n, mat(filters)*trans(temp));
+ else
+ output.set_sample(n, mat(filters)*trans(temp));
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void tensor_conv::
+ get_gradient_for_data (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& filters,
+ tensor& data_gradient
+ )
+ {
+ matrix<float> temp;
+ if (!add_to_output)
+ data_gradient = 0;
+ for (long n = 0; n < gradient_input.num_samples(); ++n)
+ {
+ auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n,
+ gradient_input.k(),
+ gradient_input.nr()*gradient_input.nc());
+
+
+ temp = trans(gi)*mat(filters);
+ col2img(temp, data_gradient, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void tensor_conv::
+ get_gradient_for_filters (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& data,
+ tensor& filters_gradient
+ )
+ {
+ matrix<float> temp;
+ for (long n = 0; n < gradient_input.num_samples(); ++n)
+ {
+ auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n,
+ gradient_input.k(),
+ gradient_input.nr()*gradient_input.nc());
+
+
+ img2col(temp, data, n, filters_gradient.nr(), filters_gradient.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x);
+ if (n == 0)
+ {
+ if (add_to_output)
+ filters_gradient += gi*temp;
+ else
+ filters_gradient = gi*temp;
+ }
+ else
+ {
+ filters_gradient += gi*temp;
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ )
+ {
+ const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
+ const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
+
+ const size_t block_size = count_k * dest.nc() * dest.nr();
+
+ DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
+ dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
+ DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor");
+ DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor");
+
+ float* dest_p = dest.host() + dest_k_offset * dest.nc() * dest.nr();
+ const float* src_p = src.host() + src_k_offset * src.nc() * src.nr();
+
+ for (long i = 0; i < src.num_samples(); ++i)
+ {
+ if (add_to)
+ {
+ for (size_t j = 0; j < block_size; ++j)
+ dest_p[j] += src_p[j];
+ }
+ else
+ {
+ ::memcpy(dest_p, src_p, block_size * sizeof(float));
+ }
+
+ dest_p += dest_sample_size;
+ src_p += src_sample_size;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+
+#endif // DLIB_DNN_CPU_cPP_
+
+
diff --git a/ml/dlib/dlib/dnn/cpu_dlib.h b/ml/dlib/dlib/dnn/cpu_dlib.h
new file mode 100644
index 000000000..330df01a2
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cpu_dlib.h
@@ -0,0 +1,505 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CPU_H_
+#define DLIB_DNN_CPU_H_
+
+// This file contains CPU implementations of the GPU based functions in cuda_dlib.h
+// and cudnn_dlibapi.h
+
+#include "tensor.h"
+#include "../geometry/rectangle.h"
+
+namespace dlib
+{
+ namespace cpu
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ );
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ );
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ );
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+
+ void batch_normalize_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+
+ void batch_normalize_conv_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void threshold (
+ tensor& data,
+ float thresh
+ );
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ----------------------------------------------------------------------------------------
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ );
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ----------------------------------------------------------------------------------------
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ );
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ );
+
+ inline void resize_bilinear (
+ tensor& dest,
+ const tensor& src
+ ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
+
+ inline void resize_bilinear_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
+
+ // -----------------------------------------------------------------------------------
+
+ class pooling
+ {
+ public:
+
+ pooling(const pooling&) = delete;
+ pooling& operator=(const pooling&) = delete;
+
+ pooling (
+ );
+
+ void clear(
+ );
+
+ void setup_max_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ );
+
+ void setup_avg_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ );
+
+ bool does_max_pooling(
+ ) const { return do_max_pooling; }
+
+ void operator() (
+ resizable_tensor& dest,
+ const tensor& src
+ );
+
+ void get_gradient(
+ const tensor& gradient_input,
+ const tensor& dest,
+ const tensor& src,
+ tensor& grad
+ );
+
+ private:
+ int window_height;
+ int window_width;
+ int stride_y;
+ int stride_x;
+ int padding_y;
+ int padding_x;
+ bool do_max_pooling;
+
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ class tensor_conv
+ {
+ public:
+ tensor_conv(const tensor_conv&) = delete;
+ tensor_conv& operator=(const tensor_conv&) = delete;
+
+ tensor_conv() {}
+
+ void clear(
+ ) {}
+
+ void setup(
+ const tensor& data, /* not used but required for interface */
+ const tensor& filters, /* not used but required for interface */
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ )
+ {
+ (void)data; /* silence compiler */
+ DLIB_CASSERT(stride_y > 0 && stride_x > 0);
+ DLIB_CASSERT(0 <= padding_y && padding_y < filters.nr());
+ DLIB_CASSERT(0 <= padding_x && padding_x < filters.nc());
+ last_stride_y = stride_y;
+ last_stride_x = stride_x;
+ last_padding_y = padding_y;
+ last_padding_x = padding_x;
+ }
+
+ void operator() (
+ const bool add_to_output,
+ resizable_tensor& output,
+ const tensor& data,
+ const tensor& filters
+ );
+
+ void operator() (
+ const bool add_to_output,
+ tensor& output,
+ const tensor& data,
+ const tensor& filters
+ );
+
+ void get_gradient_for_data (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& filters,
+ tensor& data_gradient
+ );
+
+ void get_gradient_for_filters (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& data,
+ tensor& filters_gradient
+ );
+
+ private:
+
+ long last_stride_y = 0;
+ long last_stride_x = 0;
+ long last_padding_y = 0;
+ long last_padding_x = 0;
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ }
+}
+
+#ifdef NO_MAKEFILE
+#include "cpu_dlib.cpp"
+#endif
+
+#endif // DLIB_DNN_CPU_H_
+
+
diff --git a/ml/dlib/dlib/dnn/cublas_dlibapi.cpp b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp
new file mode 100644
index 000000000..376cc9f00
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp
@@ -0,0 +1,165 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuBLAS_CPP_
+#define DLIB_DNN_CuBLAS_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cublas_dlibapi.h"
+#include "cuda_utils.h"
+
+#include <cublas_v2.h>
+#include <vector>
+
+static const char* cublas_get_error_string(cublasStatus_t s)
+{
+ switch(s)
+ {
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUDA Runtime API initialization failed.";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUDA Resources could not be allocated.";
+ default:
+ return "A call to cuBLAS failed";
+ }
+}
+
+// Check the return value of a call to the cuBLAS runtime for an error condition.
+#define CHECK_CUBLAS(call) \
+do{ \
+ const cublasStatus_t error = call; \
+ if (error != CUBLAS_STATUS_SUCCESS) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
+ throw dlib::cublas_error(sout.str()); \
+ } \
+}while(false)
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ class cublas_context
+ {
+ public:
+ // not copyable
+ cublas_context(const cublas_context&) = delete;
+ cublas_context& operator=(const cublas_context&) = delete;
+
+ cublas_context()
+ {
+ handles.resize(16);
+ }
+ ~cublas_context()
+ {
+ for (auto h : handles)
+ {
+ if (h)
+ cublasDestroy(h);
+ }
+ }
+
+ cublasHandle_t get_handle (
+ )
+ {
+ int new_device_id;
+ CHECK_CUDA(cudaGetDevice(&new_device_id));
+ // make room for more devices if needed
+ if (new_device_id >= (long)handles.size())
+ handles.resize(new_device_id+16);
+
+ // If we don't have a handle already for this device then make one
+ if (!handles[new_device_id])
+ CHECK_CUBLAS(cublasCreate(&handles[new_device_id]));
+
+ // Finally, return the handle for the current device
+ return handles[new_device_id];
+ }
+
+ private:
+
+ std::vector<cublasHandle_t> handles;
+ };
+
+ static cublasHandle_t context()
+ {
+ thread_local cublas_context c;
+ return c.get_handle();
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ void gemm (
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& lhs,
+ bool trans_lhs,
+ const tensor& rhs,
+ bool trans_rhs
+ )
+ {
+ // Recall that BLAS uses column major order so to deal with that we flip the
+ // order of the lhs and rhs arguments.
+ const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+ const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ const int dest_nr = dest.num_samples();
+ const int dest_nc = dest.size()/dest_nr;
+ const int lhs_nr = lhs.num_samples();
+ const int lhs_nc = lhs.size()/lhs_nr;
+ const int rhs_nr = rhs.num_samples();
+ const int rhs_nc = rhs.size()/rhs_nr;
+ if (trans_lhs && trans_rhs)
+ {
+ DLIB_ASSERT( dest_nr == lhs_nc &&
+ dest_nc == rhs_nr &&
+ lhs_nr == rhs_nc)
+ }
+ else if (!trans_lhs && trans_rhs)
+ {
+ DLIB_ASSERT( dest_nr == lhs_nr &&
+ dest_nc == rhs_nr &&
+ lhs_nc == rhs_nc)
+ }
+ else if (trans_lhs && !trans_rhs)
+ {
+ DLIB_ASSERT( dest_nr == lhs_nc &&
+ dest_nc == rhs_nc &&
+ lhs_nr == rhs_nr)
+ }
+ else
+ {
+ DLIB_ASSERT( dest_nr == lhs_nr &&
+ dest_nc == rhs_nc &&
+ lhs_nc == rhs_nr)
+ }
+
+ const int k = trans_rhs ? rhs_nc : rhs_nr;
+ CHECK_CUBLAS(cublasSgemm(context(),
+ transb,
+ transa,
+ dest_nc, dest_nr, k,
+ &alpha,
+ rhs.device(), rhs_nc,
+ lhs.device(), lhs_nc,
+ &beta,
+ dest.device(),dest_nc));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuBLAS_CPP_
+
+
+
diff --git a/ml/dlib/dlib/dnn/cublas_dlibapi.h b/ml/dlib/dlib/dnn/cublas_dlibapi.h
new file mode 100644
index 000000000..b46fd25ca
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cublas_dlibapi.h
@@ -0,0 +1,50 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuBLAS_H_
+#define DLIB_DNN_CuBLAS_H_
+
+#ifdef DLIB_USE_CUDA
+
+#include "tensor.h"
+#include "cuda_errors.h"
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ void gemm (
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& lhs,
+ bool trans_lhs,
+ const tensor& rhs,
+ bool trans_rhs
+ );
+ /*!
+ requires
+ - The dimensions of lhs and rhs must be compatible for matrix
+ multiplication. In particular:
+ - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
+ - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
+ - Let D == mat(dest)
+ - D.nr() == L.nr() && D.nc() == R.nc()
+ (i.e. dest must be preallocated and have the correct output dimensions)
+ - L.nc() == R.nr()
+ ensures
+ - performs: dest = alpha*L*R + beta*mat(dest)
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuBLAS_H_
+
+
diff --git a/ml/dlib/dlib/dnn/cuda_data_ptr.cpp b/ml/dlib/dlib/dnn/cuda_data_ptr.cpp
new file mode 100644
index 000000000..8abce0695
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_data_ptr.cpp
@@ -0,0 +1,71 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_
+#define DLIB_DNN_CuDA_DATA_PTR_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cuda_data_ptr.h"
+#include "cuda_utils.h"
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ cuda_data_void_ptr::
+ cuda_data_void_ptr(
+ size_t n
+ ) : num(n)
+ {
+ if (n == 0)
+ return;
+
+ void* data = nullptr;
+
+ CHECK_CUDA(cudaMalloc(&data, n));
+ pdata.reset(data, [](void* ptr){
+ auto err = cudaFree(ptr);
+ if(err!=cudaSuccess)
+ std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
+ });
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void memcpy(
+ void* dest,
+ const cuda_data_void_ptr& src
+ )
+ {
+ if (src.size() != 0)
+ {
+ CHECK_CUDA(cudaMemcpy(dest, src.data(), src.size(), cudaMemcpyDefault));
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void memcpy(
+ cuda_data_void_ptr& dest,
+ const void* src
+ )
+ {
+ if (dest.size() != 0)
+ {
+ CHECK_CUDA(cudaMemcpy(dest.data(), src, dest.size(), cudaMemcpyDefault));
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuDA_DATA_PTR_CPP_
+
+
diff --git a/ml/dlib/dlib/dnn/cuda_data_ptr.h b/ml/dlib/dlib/dnn/cuda_data_ptr.h
new file mode 100644
index 000000000..7eca608a0
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_data_ptr.h
@@ -0,0 +1,184 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuDA_DATA_PTR_H_
+#define DLIB_DNN_CuDA_DATA_PTR_H_
+
+#ifdef DLIB_USE_CUDA
+
+#include <memory>
+#include <vector>
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ class cuda_data_void_ptr
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a block of memory on a CUDA device.
+ !*/
+ public:
+
+ cuda_data_void_ptr() = default;
+
+ cuda_data_void_ptr(size_t n);
+ /*!
+ ensures
+ - This object will allocate a device memory buffer of n bytes.
+ - #size() == n
+ !*/
+
+ void* data() { return pdata.get(); }
+ const void* data() const { return pdata.get(); }
+ operator void*() { return pdata.get(); }
+ operator const void*() const { return pdata.get(); }
+
+ void reset() { pdata.reset(); }
+
+ size_t size() const { return num; }
+ /*!
+ ensures
+ - returns the length of this buffer, in bytes.
+ !*/
+
+ private:
+
+ size_t num = 0;
+ std::shared_ptr<void> pdata;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ void memcpy(
+ void* dest,
+ const cuda_data_void_ptr& src
+ );
+ /*!
+ requires
+ - dest == a pointer to at least src.size() bytes on the host machine.
+ ensures
+ - copies the GPU data from src into dest.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void memcpy(
+ cuda_data_void_ptr& dest,
+ const void* src
+ );
+ /*!
+ requires
+ - dest == a pointer to at least src.size() bytes on the host machine.
+ ensures
+ - copies the host data from src to the GPU memory buffer dest.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ class cuda_data_ptr
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a block of memory on a CUDA device. It is just a type safe
+ version of cuda_data_void_ptr.
+ !*/
+
+ public:
+
+ static_assert(std::is_standard_layout<T>::value, "You can only create basic standard layout types on the GPU");
+
+ cuda_data_ptr() = default;
+ cuda_data_ptr(size_t n) : num(n)
+ /*!
+ ensures
+ - This object will allocate a device memory buffer of n T objects.
+ - #size() == n
+ !*/
+ {
+ if (n == 0)
+ return;
+
+ pdata = cuda_data_void_ptr(n*sizeof(T));
+ }
+
+ T* data() { return (T*)pdata.data(); }
+ const T* data() const { return (T*)pdata.data(); }
+
+ operator T*() { return (T*)pdata.data(); }
+ operator const T*() const { return (T*)pdata.data(); }
+
+ void reset() { pdata.reset(); }
+
+ size_t size() const { return num; }
+
+
+ friend void memcpy(
+ std::vector<T>& dest,
+ const cuda_data_ptr& src
+ )
+ {
+ dest.resize(src.size());
+ if (src.size() != 0)
+ memcpy(dest.data(), src.pdata);
+ }
+
+ friend void memcpy(
+ cuda_data_ptr& src,
+ const std::vector<T>& dest
+ )
+ {
+ if (dest.size() != src.size())
+ dest = cuda_data_ptr<T>(src.size());
+
+ if (src.size() != 0)
+ memcpy(src.pdata, dest.data());
+ }
+
+ private:
+
+ size_t num = 0;
+ cuda_data_void_ptr pdata;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ class resizable_cuda_buffer
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a block of memory on a CUDA device that will be automatically
+ resized if requested size is larger than allocated.
+ !*/
+ public:
+ cuda_data_void_ptr get(size_t size)
+ /*!
+ ensures
+ - This object will return the buffer of requested size of larger
+ - buffer.size() >= size
+ !*/
+ {
+ if (buffer.size() < size)
+ {
+ buffer.reset();
+ buffer = cuda_data_void_ptr(size);
+ }
+ return buffer;
+ }
+ private:
+ cuda_data_void_ptr buffer;
+ };
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuDA_DATA_PTR_H_
+
diff --git a/ml/dlib/dlib/dnn/cuda_dlib.cu b/ml/dlib/dlib/dnn/cuda_dlib.cu
new file mode 100644
index 000000000..6c37593f1
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_dlib.cu
@@ -0,0 +1,1630 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "cuda_utils.h"
+#include "cuda_dlib.h"
+
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ void set_device (
+ int dev
+ )
+ {
+ CHECK_CUDA(cudaSetDevice(dev));
+ }
+
+ int get_device (
+ )
+ {
+ int dev = 0;
+ CHECK_CUDA(cudaGetDevice(&dev));
+ return dev;
+ }
+
+ std::string get_device_name (
+ int device
+ )
+ {
+ cudaDeviceProp props;
+ CHECK_CUDA(cudaGetDeviceProperties(&props, device));
+ return props.name;
+ }
+
+ void set_current_device_blocking_sync(
+ )
+ {
+ CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
+ }
+
+ int get_num_devices (
+ )
+ {
+ int num_devices;
+ CHECK_CUDA(cudaGetDeviceCount(&num_devices));
+ return num_devices;
+ }
+
+ bool can_access_peer (int device_id, int peer_device_id)
+ {
+ int can_access;
+ CHECK_CUDA(cudaDeviceCanAccessPeer(&can_access, device_id, peer_device_id));
+ return can_access != 0;
+ }
+ bool can_access_peer (const tensor& device, const tensor& peer_device)
+ {
+ return can_access_peer(device.device_id(), peer_device.device_id());
+ }
+
+ void device_synchronize (int dev)
+ {
+ raii_set_device set_dev(dev);
+ CHECK_CUDA(cudaDeviceSynchronize());
+ }
+ void device_synchronize (const tensor& dev) { device_synchronize(dev.device_id()); }
+
+ enable_peer_access::
+ enable_peer_access(
+ int device_id,
+ int peer_device_id
+ ) : call_disable(false), device_id(device_id), peer_device_id(peer_device_id)
+ {
+ raii_set_device set_dev(device_id);
+
+ auto err = cudaDeviceEnablePeerAccess(peer_device_id, 0);
+ if (err == cudaSuccess)
+ {
+ call_disable = true;
+ }
+ else if (err == cudaErrorPeerAccessAlreadyEnabled)
+ {
+ // call cudaGetLastError() to dispose of this error since we don't
+ // care.
+ auto err2 = cudaGetLastError();
+ if (err2 != cudaErrorPeerAccessAlreadyEnabled)
+ CHECK_CUDA(err2);
+ }
+ else
+ {
+ CHECK_CUDA(err);
+ }
+ }
+
+
+ enable_peer_access::
+ ~enable_peer_access() noexcept(false)
+ {
+ if (call_disable)
+ {
+ raii_set_device set_dev(device_id);
+ CHECK_CUDA(cudaDeviceDisablePeerAccess(peer_device_id));
+ }
+ }
+
+ // -----------------------------------------------------------------------------------
+ // -----------------------------------------------------------------------------------
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _cuda_inverse_norms(float* invnorms, const float* data, size_t nr, size_t nc, const float eps)
+ {
+ // initialize invnorms before we begin.
+ for (auto i : grid_stride_range_y(0, nr))
+ for (auto j : grid_stride_range(0, 1))
+ invnorms[i] = eps;
+ __syncthreads();
+
+ for (auto i : grid_stride_range_y(0, nr))
+ {
+ auto p = data + i*nc;
+ float temp = 0;
+ for (auto j : grid_stride_range(0, nc))
+ temp += p[j]*p[j];
+
+ // and store the sum into invnorms[i]
+ warp_reduce_atomic_add(invnorms[i], temp);
+ }
+ __syncthreads();
+
+ for (auto i : grid_stride_range_y(0, nr))
+ for (auto j : grid_stride_range(0, 1))
+ invnorms[i] = 1.0/std::sqrt(invnorms[i]);
+ }
+
+ void inverse_norms (
+ resizable_tensor& invnorms,
+ const tensor& data,
+ const double eps
+ )
+ {
+ invnorms.set_size(data.num_samples());
+ launch_kernel(_cuda_inverse_norms, max_jobs(data.size()/data.num_samples(), data.num_samples()),
+ invnorms.device(), data.device(), data.num_samples(), data.size()/data.num_samples(), eps);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_dot_prods(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
+ {
+ // initialize out before we begin.
+ for (auto i : grid_stride_range_y(0, nr))
+ for (auto j : grid_stride_range(0, 1))
+ out[i] = 0;
+ __syncthreads();
+
+ for (auto i : grid_stride_range_y(0, nr))
+ {
+ auto l = lhs + i*nc;
+ auto r = rhs + i*nc;
+ float temp = 0;
+ for (auto j : grid_stride_range(0, nc))
+ temp += l[j]*r[j];
+
+ // and store the sum into out[i]
+ warp_reduce_atomic_add(out[i], temp);
+ }
+ }
+
+ __global__ void _cuda_dot_prods_add_to(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
+ {
+ for (auto i : grid_stride_range_y(0, nr))
+ {
+ auto l = lhs + i*nc;
+ auto r = rhs + i*nc;
+ float temp = 0;
+ for (auto j : grid_stride_range(0, nc))
+ temp += l[j]*r[j];
+
+ // and store the sum into out[i]
+ warp_reduce_atomic_add(out[i], temp);
+ }
+ }
+
+ void dot_prods (
+ resizable_tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(lhs,rhs));
+
+ out.set_size(lhs.num_samples());
+ if (out.size() == 0)
+ return;
+
+ const auto nr = lhs.num_samples();
+ const auto nc = lhs.size()/lhs.num_samples();
+
+ launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
+ }
+
+ void dot_prods (
+ bool add_to,
+ tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(lhs,rhs));
+ DLIB_CASSERT(out.k() == 1 && out.nr() == 1 && out.nc() == 1);
+ DLIB_CASSERT(out.size() == lhs.num_samples());
+
+ const auto nr = lhs.num_samples();
+ const auto nc = lhs.size()/lhs.num_samples();
+
+ if (add_to)
+ launch_kernel(_cuda_dot_prods_add_to, max_jobs(nc,nr), out.device(), lhs.device(), rhs.device(), nr, nc);
+ else
+ launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_scale_columns(float* out, const float* m, const float* v, size_t nr, size_t nc)
+ {
+ for (auto j : grid_stride_range(0, nr*nc))
+ {
+ out[j] = m[j]*v[j%nc];
+ }
+ }
+
+ void scale_columns (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ )
+ {
+ launch_kernel(_cuda_scale_columns, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_scale_rows(float* out, const float* m, const float* v, size_t nr, size_t nc)
+ {
+ for (auto j : grid_stride_range(0, nr*nc))
+ {
+ out[j] = m[j]*v[j/nc];
+ }
+ }
+
+ void scale_rows (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ )
+ {
+ launch_kernel(_cuda_scale_rows, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_scale_rows2(float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
+ {
+ for (auto j : grid_stride_range(0, nr*nc))
+ {
+ out[j] = (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc];
+ }
+ }
+
+ __global__ void _cuda_scale_rows2_beta(const float beta, float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
+ {
+ for (auto j : grid_stride_range(0, nr*nc))
+ {
+ out[j] = beta*out[j] + (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc];
+ }
+ }
+
+ void scale_rows2 (
+ float beta,
+ tensor& out,
+ const tensor& m1,
+ const tensor& m2,
+ const tensor& v1,
+ const tensor& v2
+ )
+ {
+ if (beta == 0)
+ {
+ launch_kernel(_cuda_scale_rows2, max_jobs(m1.size()), out.device(),
+ m1.device(), m2.device(), v1.device(), v2.device(), m1.num_samples(),
+ m1.size()/m1.num_samples());
+ }
+ else
+ {
+ launch_kernel(_cuda_scale_rows2_beta, max_jobs(m1.size()), beta,
+ out.device(), m1.device(), m2.device(), v1.device(), v2.device(),
+ m1.num_samples(), m1.size()/m1.num_samples());
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_exp(float* dest, const float* src, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ dest[i] = ::exp(src[i]);
+ }
+
+ void exp (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_ASSERT(dest.size() == src.size());
+ launch_kernel(_cuda_exp, max_jobs(src.size()), dest.device(), src.device(), src.size());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_log(float* dest, const float* src, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ dest[i] = ::log(src[i]);
+ }
+
+ void log (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_ASSERT(dest.size() == src.size());
+ launch_kernel(_cuda_log, max_jobs(src.size()), dest.device(), src.device(), src.size());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_log10(float* dest, const float* src, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ dest[i] = ::log10(src[i]);
+ }
+
+ void log10 (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_ASSERT(dest.size() == src.size());
+ launch_kernel(_cuda_log10, max_jobs(src.size()), dest.device(), src.device(), src.size());
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _cuda_multiply1(float* d, const float* s1, const float* s2, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = s1[i]*s2[i];
+ }
+ }
+ __global__ void _cuda_multiply2(float* d, const float* s1, const float* s2,
+ size_t n, size_t s1_n, size_t s2_n, size_t max_size)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = 0;
+ for (size_t j = i; j < max_size; j += n)
+ d[i] += s1[j%s1_n]*s2[j%s2_n];
+ }
+ }
+
+ __global__ void _cuda_multiply3(float* d, const float* s1, const float* s2,
+ size_t n, size_t s1_n, size_t s2_n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = s1[i%s1_n]*s2[i%s2_n];
+ }
+ }
+
+ __global__ void _cuda_multiply1_add_to(float* d, const float* s1, const float* s2, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] += s1[i]*s2[i];
+ }
+ }
+ __global__ void _cuda_multiply2_add_to(float* d, const float* s1, const float* s2,
+ size_t n, size_t s1_n, size_t s2_n, size_t max_size)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ for (size_t j = i; j < max_size; j += n)
+ d[i] += s1[j%s1_n]*s2[j%s2_n];
+ }
+ }
+
+ __global__ void _cuda_multiply3_add_to(float* d, const float* s1, const float* s2,
+ size_t n, size_t s1_n, size_t s2_n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] += s1[i%s1_n]*s2[i%s2_n];
+ }
+ }
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+
+ DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() &&
+ dest.nr() == src1.nr() && src1.nr() == src2.nr() &&
+ dest.nc() == src1.nc() && src1.nc() == src2.nc() );
+ const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples());
+ DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) &&
+ (src1.num_samples()==1 || src1.num_samples()==MD) &&
+ (src2.num_samples()==1 || src2.num_samples()==MD) );
+
+ if (dest.size() == 0)
+ return;
+
+ const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size());
+ const auto d = dest.host();
+ const auto s1 = src1.host();
+ const auto s2 = src2.host();
+ if (dest.size() == src1.size() && src1.size() == src2.size())
+ {
+ if (add_to)
+ launch_kernel(_cuda_multiply1_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), src1.size());
+ else
+ launch_kernel(_cuda_multiply1,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), src1.size());
+ }
+ else if (dest.num_samples() == 1)
+ {
+ if (add_to)
+ launch_kernel(_cuda_multiply2_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
+ dest.size(), src1.size(), src2.size(), max_size);
+ else
+ launch_kernel(_cuda_multiply2,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
+ dest.size(), src1.size(), src2.size(), max_size);
+ }
+ else
+ {
+ if (add_to)
+ launch_kernel(_cuda_multiply3_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
+ dest.size(), src1.size(), src2.size());
+ else
+ launch_kernel(_cuda_multiply3,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
+ dest.size(), src1.size(), src2.size());
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_multiply_conv(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ auto k = (i/bs)%ks;
+ d[i] = s1[i]*s2[k];
+ }
+ }
+
+ __global__ void _cuda_multiply_conv2(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks)
+ {
+ // zero initialize d before we begin.
+ for (auto i : grid_stride_range_y(0, ks))
+ for (auto j : grid_stride_range(0, 1))
+ d[i] = 0;
+ __syncthreads();
+
+ // loop over all the image planes
+ for (auto i : grid_stride_range_y(0, n))
+ {
+ // sum all the elements in the i-th image plane
+ float temp = 0;
+ for (auto j : grid_stride_range(i*bs, (i+1)*bs))
+ temp += s1[j]*s2[j];
+ auto k = i%ks;
+ // and store the sum into d[k]
+ warp_reduce_atomic_add(d[k], temp);
+ }
+ }
+
+ __global__ void _cuda_multiply_conv_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ auto k = (i/bs)%ks;
+ d[i] += s1[i]*s2[k];
+ }
+ }
+
+ __global__ void _cuda_multiply_conv2_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks)
+ {
+ // loop over all the image planes
+ for (auto i : grid_stride_range_y(0, n))
+ {
+ // sum all the elements in the i-th image plane
+ float temp = 0;
+ for (auto j : grid_stride_range(i*bs, (i+1)*bs))
+ temp += s1[j]*s2[j];
+ auto k = i%ks;
+ // and store the sum into d[k]
+ warp_reduce_atomic_add(d[k], temp);
+ }
+ }
+
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ if (have_same_dimensions(dest,src1))
+ {
+ DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k());
+ if (dest.size() == 0)
+ return;
+
+ if (add_to)
+ launch_kernel(_cuda_multiply_conv_add_to,max_jobs(dest.size()),
+ dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k());
+ else
+ launch_kernel(_cuda_multiply_conv,max_jobs(dest.size()),
+ dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k());
+ }
+ else
+ {
+ DLIB_CASSERT(have_same_dimensions(src1,src2));
+ DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k());
+ if (dest.size() == 0)
+ return;
+
+
+ const auto bs = src1.nr()*src1.nc();
+ const auto n = src1.num_samples()*src1.k();
+ if (add_to)
+ launch_kernel(_cuda_multiply_conv2_add_to, max_jobs(bs,n),
+ dest.device(), src1.device(), n, src2.device(), bs, src1.k());
+ else
+ launch_kernel(_cuda_multiply_conv2, max_jobs(bs,n),
+ dest.device(), src1.device(), n, src2.device(), bs, src1.k());
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_scale_channels_add_to(float* d, const float* src, size_t n, const float* scales, size_t bs)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ auto k = i/bs;
+ d[i] += src[i]*scales[k];
+ }
+ }
+
+ __global__ void _cuda_scale_channels(float* d, const float* src, size_t n, const float* scales, size_t bs)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ auto k = i/bs;
+ d[i] = src[i]*scales[k];
+ }
+ }
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src) &&
+ scales.num_samples() == src.num_samples() &&
+ scales.k() == src.k() &&
+ scales.nr() == 1 &&
+ scales.nc() == 1 );
+
+ if (dest.size() == 0)
+ return;
+
+ if (add_to)
+ launch_kernel(_cuda_scale_channels_add_to,max_jobs(dest.size()),
+ dest.device(), src.device(), src.size(), scales.device(), src.nr()*src.nc());
+ else
+ launch_kernel(_cuda_scale_channels,max_jobs(dest.size()),
+ dest.device_write_only(), src.device(), src.size(), scales.device(), src.nr()*src.nc());
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_mult1(float* d, const float* s1, const float* s2, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = s1[i]*s2[i];
+ }
+ }
+
+ __global__ void _cuda_mult1_add_to(float* d, const float* s1, const float* s2, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] += s1[i]*s2[i];
+ }
+ }
+
+ __global__ void _cuda_mult2(float* d, const float* s1, const float* s2,
+ size_t dn, size_t dk, size_t dr, size_t dc,
+ size_t s1n, size_t s1k, size_t s1r, size_t s1c,
+ size_t s2n, size_t s2k, size_t s2r, size_t s2c)
+ {
+ for (auto i : grid_stride_range(0, dn*dk*dr*dc))
+ {
+ size_t n,k,r,c;
+ unpack_idx(i, dk,dr,dc, n,k,r,c);
+
+ float v1 = 0;
+ float v2 = 0;
+
+ if (n < s1n &&
+ k < s1k &&
+ r < s1r &&
+ c < s1c )
+ {
+ v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)];
+ }
+
+ if (n < s2n &&
+ k < s2k &&
+ r < s2r &&
+ c < s2c )
+ {
+ v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)];
+ }
+
+ d[i] = v1*v2;
+ }
+ }
+
+ __global__ void _cuda_mult2_add_to(float* d, const float* s1, const float* s2,
+ size_t dn, size_t dk, size_t dr, size_t dc,
+ size_t s1n, size_t s1k, size_t s1r, size_t s1c,
+ size_t s2n, size_t s2k, size_t s2r, size_t s2c)
+ {
+ for (auto i : grid_stride_range(0, dn*dk*dr*dc))
+ {
+ size_t n,k,r,c;
+ unpack_idx(i, dk,dr,dc, n,k,r,c);
+
+ float v1 = 0;
+ float v2 = 0;
+
+ if (n < s1n &&
+ k < s1k &&
+ r < s1r &&
+ c < s1c )
+ {
+ v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)];
+ }
+
+ if (n < s2n &&
+ k < s2k &&
+ r < s2r &&
+ c < s2c )
+ {
+ v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)];
+ }
+
+ d[i] += v1*v2;
+ }
+ }
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ if (dest.size() == 0)
+ return;
+
+ // Do the simple and fast version if everything has the same dimensions
+ if (have_same_dimensions(dest, src1) &&
+ have_same_dimensions(dest, src2))
+ {
+ if (add_to)
+ launch_kernel(_cuda_mult1_add_to,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size());
+ else
+ launch_kernel(_cuda_mult1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size());
+ }
+ else
+ {
+ if (add_to)
+ {
+ // Otherwise, do the more complex version with bounds checking.
+ launch_kernel(_cuda_mult2_add_to,max_jobs(dest.size()),
+ dest.device(), src1.device(), src2.device(),
+ dest.num_samples(), dest.k(), dest.nr(), dest.nc(),
+ src1.num_samples(), src1.k(), src1.nr(), src1.nc(),
+ src2.num_samples(), src2.k(), src2.nr(), src2.nc()
+ );
+ }
+ else
+ {
+ // Otherwise, do the more complex version with bounds checking.
+ launch_kernel(_cuda_mult2,max_jobs(dest.size()),
+ dest.device(), src1.device(), src2.device(),
+ dest.num_samples(), dest.k(), dest.nr(), dest.nc(),
+ src1.num_samples(), src1.k(), src1.nr(), src1.nc(),
+ src2.num_samples(), src2.k(), src2.nr(), src2.nc()
+ );
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_add1(float* d, const float* s1, const float* s2, size_t n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = s1[i]+s2[i];
+ }
+ }
+
+ __global__ void _cuda_add2(float* d, const float* s1, const float* s2,
+ size_t dn, size_t dk, size_t dr, size_t dc,
+ size_t s1n, size_t s1k, size_t s1r, size_t s1c,
+ size_t s2n, size_t s2k, size_t s2r, size_t s2c)
+ {
+ for (auto i : grid_stride_range(0, dn*dk*dr*dc))
+ {
+ size_t n,k,r,c;
+ unpack_idx(i, dk,dr,dc, n,k,r,c);
+
+ float v1 = 0;
+ float v2 = 0;
+
+ if (n < s1n &&
+ k < s1k &&
+ r < s1r &&
+ c < s1c )
+ {
+ v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)];
+ }
+
+ if (n < s2n &&
+ k < s2k &&
+ r < s2r &&
+ c < s2c )
+ {
+ v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)];
+ }
+
+ d[i] = v1+v2;
+ }
+ }
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ if (dest.size() == 0)
+ return;
+
+ // Do the simple and fast version if everything has the same dimensions
+ if (have_same_dimensions(dest, src1) &&
+ have_same_dimensions(dest, src2))
+ {
+ launch_kernel(_cuda_add1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size());
+ }
+ else
+ {
+ // Otherwise, do the more complex version with bounds checking.
+ launch_kernel(_cuda_add2,max_jobs(dest.size()),
+ dest.device(), src1.device(), src2.device(),
+ dest.num_samples(), dest.k(), dest.nr(), dest.nc(),
+ src1.num_samples(), src1.k(), src1.nr(), src1.nc(),
+ src2.num_samples(), src2.k(), src2.nr(), src2.nc()
+ );
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform1(float* d, const float* s, size_t n, float A, float B)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A*s[i] + B;
+ }
+ }
+
+ __global__ void _cuda_affine_transform1_0(float* d, const float* s, size_t n, float A)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A*s[i];
+ }
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ )
+ {
+ DLIB_CASSERT(dest.size()==src.size());
+ if (B != 0)
+ launch_kernel(_cuda_affine_transform1,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A, B);
+ else
+ launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A);
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A
+ )
+ {
+ DLIB_CASSERT(dest.size()==src.size());
+ launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform_rect(
+ float* d,
+ const float* s1,
+ const float* s2,
+ const float* s3,
+ float A,
+ float B,
+ float C,
+ size_t start_idx,
+ size_t n,
+ size_t rect_nc,
+ size_t total_nc
+ )
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ size_t r = i/rect_nc;
+ size_t c = i%rect_nc;
+ size_t idx = r*total_nc + c + start_idx;
+ d[idx] = A*s1[idx] + B*s2[idx] + C*s3[idx];
+ }
+ }
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ )
+ {
+ DLIB_CASSERT(dest.size() == src1.size());
+ DLIB_CASSERT(dest.size() == src2.size());
+ DLIB_CASSERT(dest.size() == src3.size());
+ DLIB_CASSERT(dest.num_samples() == src1.num_samples());
+ DLIB_CASSERT(dest.num_samples() == src2.num_samples());
+ DLIB_CASSERT(dest.num_samples() == src3.num_samples());
+ DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect));
+ launch_kernel(_cuda_affine_transform_rect,max_jobs(rect.area()),
+ dest.device(), src1.device(), src2.device(), src3.device(), A, B, C,
+ rect.left() + rect.top()*(dest.size()/dest.num_samples()),
+ rect.area(),
+ rect.width(),
+ dest.size()/dest.num_samples());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A*s1[i] + B*s2[i] + C;
+ }
+ }
+
+ __global__ void _cuda_affine_transform4_0(float* d, const float* s1, const float* s2, size_t n, float A, float B)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A*s1[i] + B*s2[i];
+ }
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ if (C != 0)
+ launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C);
+ else
+ launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B);
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] += scale*s[i];
+ }
+ }
+
+ void add_scaled(
+ tensor& dest,
+ const float scale,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.size()==src.size());
+ launch_kernel(_cuda_add_scaled,max_jobs(dest.size()),dest.device(), src.device(), dest.size(), scale);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_add_cv_to_all_columns(float beta, float* dest, float alpha, const float* src, size_t size, size_t stride)
+ {
+ for (auto i : grid_stride_range(0, size))
+ {
+ dest[i] = beta*dest[i] + alpha*src[i/stride];
+ }
+ }
+
+ __global__ void _cuda_add_cv_to_all_columns_no_beta(float* dest, float alpha, const float* src, size_t size, size_t stride)
+ {
+ for (auto i : grid_stride_range(0, size))
+ {
+ dest[i] = alpha*src[i/stride];
+ }
+ }
+
+ void add_cv_to_all_columns(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.num_samples() == src.num_samples() && src.num_samples() == src.size());
+ if (beta == 0)
+ launch_kernel(_cuda_add_cv_to_all_columns_no_beta, max_jobs(dest.size()), dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples());
+ else
+ launch_kernel(_cuda_add_cv_to_all_columns, max_jobs(dest.size()), beta, dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform5(
+ float* d, const float* s1, const float* s2, const float* s3, size_t n, float A, float B, float C, float D
+ )
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D;
+ }
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ DLIB_CASSERT(dest.size()==src3.size());
+ launch_kernel(_cuda_affine_transform5,max_jobs(dest.size()),dest.device(), src1.device(),
+ src2.device(), src3.device(), dest.size(), A, B, C, D);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform_range(
+ float* d, const float* s1, const float* s2, const float* s3, size_t begin, size_t end, float A, float B, float C
+ )
+ {
+ for (auto i : grid_stride_range(begin, end))
+ {
+ d[i] = A*s1[i] + B*s2[i] + C*s3[i];
+ }
+ }
+
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+ DLIB_CASSERT(dest.size()==src1.size());
+ DLIB_CASSERT(dest.size()==src2.size());
+ DLIB_CASSERT(dest.size()==src3.size());
+ DLIB_CASSERT(begin <= end && end <= dest.size());
+ launch_kernel(_cuda_affine_transform_range,max_jobs(end-begin),
+ dest.device(), src1.device(),
+ src2.device(), src3.device(), begin, end, A, B, C);
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform2(float* d, const float* s, size_t n, const float* A, const float* B)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A[i]*s[i] + B[i];
+ }
+ }
+ __global__ void _cuda_affine_transform3(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = A[i%bs]*s[i] + B[i%bs];
+ }
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest, src));
+ DLIB_CASSERT(
+ ((A.num_samples()==1 && B.num_samples()==1) ||
+ (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())));
+ DLIB_CASSERT(
+ A.nr()==B.nr() && B.nr()==src.nr() &&
+ A.nc()==B.nc() && B.nc()==src.nc() &&
+ A.k() ==B.k() && B.k()==src.k(),
+ "\nA.nr(): " << A.nr() << "\nB.nr(): " << B.nr() << "\nsrc.nr(): " << src.nr()
+ <<"\nA.nc(): " << A.nc() << "\nB.nc(): " << B.nc() << "\nsrc.nc(): " << src.nc()
+ <<"\nA.k(): " << A.k() << "\nB.k(): " << B.k() << "\nsrc.k(): " << src.k()
+ );
+
+ if (A.num_samples() == 1)
+ {
+ launch_kernel(_cuda_affine_transform3,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device(), A.size());
+ }
+ else
+ {
+ launch_kernel(_cuda_affine_transform2,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device());
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_compute_adam_update(
+ size_t begin,
+ size_t end,
+ float* s,
+ float* m,
+ float* v,
+ const float alpha,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const float* params,
+ const float* params_grad
+ )
+ {
+ const float eps = 1e-8;
+ // The loop is equivalent to doing this:
+ // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
+ // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
+ // s = -alpha*m/(sqrt(v) + eps);
+ for (auto i : grid_stride_range(begin, end))
+ {
+ float g = (weight_decay*params[i] + params_grad[i]);
+ m[i] = momentum1*m[i] + (1-momentum1)*g;
+ v[i] = momentum2*v[i] + (1-momentum2)*g*g;
+ s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps);
+ }
+ }
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ )
+ {
+ DLIB_CASSERT(s.size() == m.size() &&
+ s.size() == v.size() &&
+ s.size() == params.size() &&
+ s.size() == params_grad.size());
+ DLIB_CASSERT(begin <= end && end <= params.size());
+ const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));
+
+ launch_kernel(_cuda_compute_adam_update,max_jobs(end-begin),
+ begin, end, s.device(), m.device(), v.device(), alpha, weight_decay,
+ momentum1, momentum2, params.device(), params_grad.device());
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _cuda_affine_transform_conv(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs, size_t ks)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ auto k = (i/bs)%ks;
+ d[i] = A[k]*s[i] + B[k];
+ }
+ }
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest, src));
+ DLIB_CASSERT(have_same_dimensions(A, B));
+ DLIB_CASSERT(A.num_samples() == 1 && A.nr() == 1 && A.nc() == 1 && A.k() == src.k());
+
+ launch_kernel(_cuda_affine_transform_conv,max_jobs(dest.size()),
+ dest.device(), src.device(), src.size(), A.device(), B.device(), src.nr()*src.nc(), src.k());
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _add_bias_gradient(float* out, const float* in, size_t n, size_t total_n)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ out[i] = in[i];
+ for (size_t j = i+n; j < total_n; j+=n)
+ out[i] += in[j];
+ }
+ }
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ grad.num_samples() == 1 &&
+ gradient_input.k() == grad.k() &&
+ gradient_input.nr() == grad.nr() &&
+ gradient_input.nc() == grad.nc() &&
+ gradient_input.size() > 0);
+
+ launch_kernel(_add_bias_gradient,max_jobs(grad.size()),grad.device(), gradient_input.device(), grad.size(), gradient_input.size());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _set_tensor(float* out, size_t n, const float val)
+ {
+ for (auto i : grid_stride_range(0, n))
+ out[i] = val;
+ }
+
+ void set_tensor (
+ tensor& t,
+ float value
+ )
+ {
+ launch_kernel(_set_tensor, max_jobs(t.size()), t.device(), t.size(), value);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _scale_tensor(float* out, size_t n, const float val)
+ {
+ for (auto i : grid_stride_range(0, n))
+ out[i] *= val;
+ }
+
+ void scale_tensor (
+ tensor& t,
+ float value
+ )
+ {
+ launch_kernel(_scale_tensor, max_jobs(t.size()), t.device(), t.size(), value);
+ }
+
+ // -----------------------------------------------------------------------------------
+ // -----------------------------------------------------------------------------------
+
+ __global__ void _cuda_threshold(float* d, size_t n, float thresh)
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ d[i] = d[i]>thresh ? 1:0;
+ }
+ }
+
+ void threshold (
+ tensor& data,
+ float thresh
+ )
+ {
+ launch_kernel(_cuda_threshold,max_jobs(data.size()),data.device(), data.size(), thresh);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_dot(const float* a, const float* b, size_t n, float* result)
+ {
+ // Parallel sum everything into local temp variables.
+ float temp = 0;
+ for(auto i : grid_stride_range(0, n))
+ temp += a[i]*b[i];
+
+ // Then do the warp reduce add thing to merge into one output value.
+ warp_reduce_atomic_add(*result, temp);
+ }
+
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ )
+ {
+ DLIB_CASSERT(a.size() == b.size());
+ DLIB_CASSERT(idx < result.size());
+
+ launch_kernel(_cuda_dot, max_jobs(a.size()), a.device(), b.device(), a.size(), result.device()+idx);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_prelu(const float* s, float* d, size_t n, const float* pp)
+ {
+ const float p = *pp;
+ for (auto i : grid_stride_range(0, n))
+ {
+ if (s[i] > 0)
+ d[i] = s[i];
+ else
+ d[i] = p*s[i];
+ }
+ }
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ )
+ {
+ launch_kernel(_cuda_prelu, max_jobs(dest.size()),
+ src.device(), dest.device(), src.size(), param.device());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_prelu_gradient(float* out, const float* s, const float* gi, size_t n, const float* pp, float* ppgrad)
+ {
+ const float p = *pp;
+ float pgrad = 0;
+ for(auto i : grid_stride_range(0, n))
+ {
+ if (s[i] > 0)
+ {
+ out[i] += gi[i];
+ }
+ else
+ {
+ out[i] += p*gi[i];
+ pgrad += gi[i]*s[i];
+ }
+ }
+
+ // Then do the warp reduce add thing to merge into one output value.
+ warp_reduce_atomic_add(*ppgrad, pgrad);
+ }
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ )
+ {
+ params_grad = 0;
+ launch_kernel(_cuda_prelu_gradient, max_jobs(grad.size()),
+ grad.device(), src.device(), gradient_input.device(), grad.size(),
+ param.device(), params_grad.device());
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
+ size_t schan_size, int snr, int snc, const float* s,
+ const float x_scale, const float y_scale)
+ {
+ for(auto i : grid_stride_range(0, dsize))
+ {
+ const int idx = i%dchan_size;
+ const int channel = i/dchan_size;
+ const int sidx = channel*schan_size;
+ const int r = idx/dnc;
+ const int c = idx%dnc;
+
+ const float y = r*y_scale;
+ const int top = static_cast<int>(::floor(y));
+ const int bottom = ::min(top+1, snr-1);
+ const float tb_frac = y - top;
+
+ const float x = c*x_scale;
+ const int left = static_cast<int>(::floor(x));
+ const int right = ::min(left+1, snc-1);
+ const float lr_frac = x - left;
+
+ float tl = s[sidx+top*snc+left];
+ float tr = s[sidx+top*snc+right];
+ float bl = s[sidx+bottom*snc+left];
+ float br = s[sidx+bottom*snc+right];
+
+ float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ d[i] = temp;
+ }
+ }
+
+ __global__ void _cuda_resize_bilinear_strided(size_t dsize, size_t dchan_size, size_t dnc, float* d,
+ size_t schan_size, int snr, int snc, const float* s,
+ const float x_scale, const float y_scale,
+ size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
+ )
+ {
+ for(auto i : grid_stride_range(0, dsize))
+ {
+ const int idx = i%dchan_size;
+ const int channel = i/dchan_size;
+ const int sidx = channel*schan_size;
+ const int r = idx/dnc;
+ const int c = idx%dnc;
+ const int didx = channel*dest_chan_size_strided + r*dest_row_stride+c;
+
+ const float y = r*y_scale;
+ const int top = static_cast<int>(::floor(y));
+ const int bottom = ::min(top+1, snr-1);
+ const float tb_frac = y - top;
+
+ const float x = c*x_scale;
+ const int left = static_cast<int>(::floor(x));
+ const int right = ::min(left+1, snc-1);
+ const float lr_frac = x - left;
+
+ float tl = s[sidx+top*src_row_stride+left];
+ float tr = s[sidx+top*src_row_stride+right];
+ float bl = s[sidx+bottom*src_row_stride+left];
+ float br = s[sidx+bottom*src_row_stride+right];
+
+ float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ d[didx] = temp;
+ }
+ }
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ )
+ {
+ DLIB_CASSERT(is_same_object(dest, src)==false);
+ DLIB_CASSERT(dest.num_samples() == src.num_samples());
+ DLIB_CASSERT(dest.k() == src.k());
+
+ if (dest.size() == 0 || src.size() == 0)
+ return;
+
+ const float x_scale = (src.nc()-1)/(float)std::max<long>((dest.nc()-1),1);
+ const float y_scale = (src.nr()-1)/(float)std::max<long>((dest.nr()-1),1);
+
+ if (dest.nc() == dest_row_stride && dest.nr()*dest.nc()==dest_channel_stride &&
+ src.nc() == src_row_stride && src.nr()*src.nc()==src_channel_stride)
+ {
+ launch_kernel(_cuda_resize_bilinear,
+ dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
+ src.nr()*src.nc(), src.nr(), src.nc(), src.device(),
+ x_scale, y_scale);
+ }
+ else
+ {
+ launch_kernel(_cuda_resize_bilinear_strided,
+ dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
+ src_channel_stride, src.nr(), src.nc(), src.device(),
+ x_scale, y_scale, dest_row_stride, src_row_stride, dest_channel_stride);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_resize_bilinear_gradient(size_t dsize, size_t dchan_size, size_t dnc, const float* d,
+ size_t schan_size, int snr, int snc, float* s,
+ const float x_scale, const float y_scale)
+ {
+ for(auto i : grid_stride_range(0, dsize))
+ {
+ const float tmp = d[i];
+
+ const int idx = i%dchan_size;
+ const int channel = i/dchan_size;
+ const int sidx = channel*schan_size;
+ const int r = idx/dnc;
+ const int c = idx%dnc;
+
+ const float y = r*y_scale;
+ const int top = static_cast<int>(::floor(y));
+ const int bottom = ::min(top+1, snr-1);
+ const float tb_frac = y - top;
+
+ const float x = c*x_scale;
+ const int left = static_cast<int>(::floor(x));
+ const int right = ::min(left+1, snc-1);
+ const float lr_frac = x - left;
+
+
+ atomicAdd(s+sidx+top*snc+left, tmp*(1-tb_frac)*(1-lr_frac));
+ atomicAdd(s+sidx+top*snc+right, tmp*(1-tb_frac)*(lr_frac));
+ atomicAdd(s+sidx+bottom*snc+left, tmp*(tb_frac)*(1-lr_frac));
+ atomicAdd(s+sidx+bottom*snc+right, tmp*(tb_frac)*(lr_frac));
+ }
+ }
+
+ __global__ void _cuda_resize_bilinear_gradient_strided(size_t dsize, size_t dchan_size, size_t dnc, const float* d,
+ size_t schan_size, int snr, int snc, float* s,
+ const float x_scale, const float y_scale,
+ size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
+ )
+ {
+ for(auto i : grid_stride_range(0, dsize))
+ {
+
+ const int idx = i%dchan_size;
+ const int channel = i/dchan_size;
+ const int didx = channel*dest_chan_size_strided;
+ const int sidx = channel*schan_size;
+ const int r = idx/dnc;
+ const int c = idx%dnc;
+
+ const float tmp = d[didx + r*dest_row_stride+c];
+
+ const float y = r*y_scale;
+ const int top = static_cast<int>(::floor(y));
+ const int bottom = ::min(top+1, snr-1);
+ const float tb_frac = y - top;
+
+ const float x = c*x_scale;
+ const int left = static_cast<int>(::floor(x));
+ const int right = ::min(left+1, snc-1);
+ const float lr_frac = x - left;
+
+
+ atomicAdd(s+sidx+top*src_row_stride+left, tmp*(1-tb_frac)*(1-lr_frac));
+ atomicAdd(s+sidx+top*src_row_stride+right, tmp*(1-tb_frac)*(lr_frac));
+ atomicAdd(s+sidx+bottom*src_row_stride+left, tmp*(tb_frac)*(1-lr_frac));
+ atomicAdd(s+sidx+bottom*src_row_stride+right, tmp*(tb_frac)*(lr_frac));
+ }
+ }
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ )
+ {
+ DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
+ DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples());
+ DLIB_CASSERT(gradient_input.k() == grad.k());
+
+ if (grad.size() == 0 || gradient_input.size() == 0)
+ return;
+
+ const float x_scale = (grad.nc()-1)/(float)std::max<long>((gradient_input.nc()-1),1);
+ const float y_scale = (grad.nr()-1)/(float)std::max<long>((gradient_input.nr()-1),1);
+
+ if (grad.nc() == grad_row_stride && grad.nr()*grad.nc()==grad_channel_stride &&
+ gradient_input.nc() == gradient_input_row_stride && gradient_input.nr()*gradient_input.nc()==gradient_input_channel_stride)
+ {
+ launch_kernel(_cuda_resize_bilinear_gradient,
+ gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
+ grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(),
+ x_scale, y_scale);
+ }
+ else
+ {
+ launch_kernel(_cuda_resize_bilinear_gradient_strided,
+ gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
+ grad_channel_stride, grad.nr(), grad.nc(), grad.device(),
+ x_scale, y_scale, gradient_input_row_stride, grad_row_stride, gradient_input_channel_stride);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
+ {
+ for(auto i : grid_stride_range(0, size))
+ {
+ size_t blk = i/block_size;
+ size_t j = i%block_size;
+ dest[blk*dest_stride + j] += src[blk*src_stride + j];
+ }
+ }
+
+ __global__ void _cuda_copy_tensor (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
+ {
+ for(auto i : grid_stride_range(0, size))
+ {
+ size_t blk = i/block_size;
+ size_t j = i%block_size;
+ dest[blk*dest_stride + j] = src[blk*src_stride + j];
+ }
+ }
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ )
+ {
+ const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
+ const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
+
+ const size_t block_size = count_k * dest.nc() * dest.nr();
+
+ DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
+ dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
+ DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor");
+ DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor");
+
+ float* dest_p = dest.device() + dest_k_offset * dest.nc() * dest.nr();
+ const float* src_p = src.device() + src_k_offset * src.nc() * src.nr();;
+
+ if (add_to)
+ {
+ launch_kernel(_cuda_copy_tensor_add_to, max_jobs(dest.size()),
+ dest_p, block_size*dest.num_samples(),
+ src_p, dest_sample_size, src_sample_size, block_size);
+ }
+ else
+ {
+ launch_kernel(_cuda_copy_tensor, max_jobs(dest.size()),
+ dest_p, block_size*dest.num_samples(),
+ src_p, dest_sample_size, src_sample_size, block_size);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ }
+}
+
diff --git a/ml/dlib/dlib/dnn/cuda_dlib.h b/ml/dlib/dlib/dnn/cuda_dlib.h
new file mode 100644
index 000000000..3a057ffc4
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_dlib.h
@@ -0,0 +1,469 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuDA_H_
+#define DLIB_DNN_CuDA_H_
+
+
+#include "tensor.h"
+#include "../geometry/rectangle.h"
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ void set_device (
+ int dev
+ );
+
+ int get_device (
+ );
+
+ int get_num_devices (
+ );
+
+ std::string get_device_name (
+ int device
+ );
+
+ void set_current_device_blocking_sync(
+ );
+
+ bool can_access_peer (int device_id, int peer_device_id);
+ bool can_access_peer (const tensor& device, const tensor& peer_device);
+
+ void device_synchronize (int dev);
+ void device_synchronize (const tensor& dev);
+
+
+ class raii_set_device
+ {
+ public:
+ raii_set_device() = delete;
+ raii_set_device(const raii_set_device&) = delete;
+ raii_set_device& operator=(const raii_set_device&) = delete;
+
+ raii_set_device(int dev)
+ {
+ prev_dev = get_device();
+ set_device(dev);
+ }
+
+ raii_set_device(const tensor& dev)
+ {
+ prev_dev = get_device();
+ set_device(dev.device_id());
+ }
+
+ void operator() (int dev)
+ {
+ set_device(dev);
+ }
+
+ void operator() (const tensor& dev)
+ {
+ set_device(dev.device_id());
+ }
+
+ ~raii_set_device() noexcept(false)
+ {
+ set_device(prev_dev);
+ }
+
+ private:
+ int prev_dev;
+ };
+
+
+#ifdef DLIB_USE_CUDA
+
+ class enable_peer_access
+ {
+ public:
+
+ enable_peer_access() = delete;
+ enable_peer_access(const enable_peer_access&) = delete;
+ enable_peer_access& operator=(const enable_peer_access&) = delete;
+
+ enable_peer_access(
+ int device_id,
+ int peer_device_id
+ );
+
+ enable_peer_access(
+ const tensor& device,
+ const tensor& peer_device
+ ) : enable_peer_access(device.device_id(), peer_device.device_id())
+ {}
+
+ ~enable_peer_access() noexcept(false);
+
+ private:
+
+ bool call_disable;
+ int device_id;
+ int peer_device_id;
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ void inverse_norms (
+ resizable_tensor& invnorms,
+ const tensor& data,
+ const double eps
+ );
+
+ void dot_prods (
+ resizable_tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ );
+
+ void dot_prods (
+ bool add_to,
+ tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ );
+
+ void scale_columns (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ );
+
+ void scale_rows (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ );
+
+ void scale_rows2 (
+ float beta,
+ tensor& out,
+ const tensor& m1,
+ const tensor& m2,
+ const tensor& v1,
+ const tensor& v2
+ );
+
+ void exp (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void log (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void log10 (
+ tensor& dest,
+ const tensor& src
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void set_tensor (
+ tensor& t,
+ float value
+ );
+
+ void scale_tensor (
+ tensor& t,
+ float value
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ );
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B
+ );
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ );
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ );
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ );
+
+ // Note that this function isn't in the tt:: namespace because add_scaled() is
+ // called by cuda::add() so we don't need a tt:: version of add_scaled().
+ void add_scaled(
+ tensor& dest,
+ const float scale,
+ const tensor& src
+ );
+
+ void add_cv_to_all_columns(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+
+ // ----------------------------------------------------------------------------------------
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+
+ // -----------------------------------------------------------------------------------
+
+ void threshold (
+ tensor& data,
+ float thresh
+ );
+
+ // ----------------------------------------------------------------------------------------
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ );
+
+ // ----------------------------------------------------------------------------------------
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ );
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ );
+
+
+ // ----------------------------------------------------------------------------------------
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ );
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ );
+
+ inline void resize_bilinear (
+ tensor& dest,
+ const tensor& src
+ ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
+
+ inline void resize_bilinear_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
+
+ // ----------------------------------------------------------------------------------------
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ );
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+#else // if DLIB_USE_CUDA NOT DEFINED
+
+ inline void set_device (
+ int id
+ )
+ {
+ DLIB_CASSERT(id == 0, "dlib::cuda::set_device(id) called with an invalid device id.");
+ }
+
+ inline int get_device (
+ ){ return 0; }
+
+ inline int get_num_devices (
+ ) { return 1; }
+
+ inline std::string get_device_name (
+ int device
+ )
+ {
+ DLIB_CASSERT(device == 0, "dlib::cuda::set_device(id) called with an invalid device id.");
+ return "CUDA_DISABLED";
+ }
+
+ inline void set_current_device_blocking_sync(
+ ) {}
+
+
+ inline bool can_access_peer (int , int )
+ { return false; }
+ inline bool can_access_peer (const tensor& , const tensor& )
+ { return false; }
+
+ inline void device_synchronize (int ){}
+ inline void device_synchronize (const tensor& ){}
+
+ class enable_peer_access
+ {
+ public:
+ enable_peer_access() = delete;
+ enable_peer_access(const enable_peer_access&) = delete;
+ enable_peer_access& operator=(const enable_peer_access&) = delete;
+ enable_peer_access( int, int ){}
+ enable_peer_access( const tensor&, const tensor& ) {}
+ };
+
+#endif // DLIB_USE_CUDA
+
+ }
+}
+
+
+#endif // DLIB_DNN_CuDA_H_
+
diff --git a/ml/dlib/dlib/dnn/cuda_errors.h b/ml/dlib/dlib/dnn/cuda_errors.h
new file mode 100644
index 000000000..fd28693c2
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_errors.h
@@ -0,0 +1,70 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CUDA_ERRORs_H_
+#define DLIB_CUDA_ERRORs_H_
+
+
+#include "../error.h"
+
+namespace dlib
+{
+ struct cuda_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown if any calls to the NVIDIA CUDA runtime
+ returns an error.
+ !*/
+
+ cuda_error(const std::string& message): error(message) {}
+ };
+
+
+ struct cudnn_error : public cuda_error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown if any calls to the NVIDIA cuDNN library
+ returns an error.
+ !*/
+
+ cudnn_error(const std::string& message): cuda_error(message) {}
+ };
+
+ struct curand_error : public cuda_error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown if any calls to the NVIDIA cuRAND library
+ returns an error.
+ !*/
+
+ curand_error(const std::string& message): cuda_error(message) {}
+ };
+
+ struct cublas_error : public cuda_error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown if any calls to the NVIDIA cuBLAS library
+ returns an error.
+ !*/
+
+ cublas_error(const std::string& message): cuda_error(message) {}
+ };
+
+ struct cusolver_error : public cuda_error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown if any calls to the NVIDIA cuSolver library
+ returns an error.
+ !*/
+
+ cusolver_error(const std::string& message): cuda_error(message) {}
+ };
+}
+
+
+#endif // DLIB_CUDA_ERRORs_H_
+
diff --git a/ml/dlib/dlib/dnn/cuda_utils.h b/ml/dlib/dlib/dnn/cuda_utils.h
new file mode 100644
index 000000000..673a4e8ad
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cuda_utils.h
@@ -0,0 +1,413 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CUDA_UtILS_H_
+#define DLIB_CUDA_UtILS_H_
+
+#ifndef DLIB_USE_CUDA
+#error "This file shouldn't be #included unless DLIB_USE_CUDA is #defined"
+#endif
+
+#include "cuda_errors.h"
+#include "../algs.h"
+#include <cmath>
+
+#include <cuda_runtime.h>
+#include <sstream>
+#include <iostream>
+#include <memory>
+#include <vector>
+#include <type_traits>
+
+
+// Check the return value of a call to the CUDA runtime for an error condition.
+#define CHECK_CUDA(call) \
+do{ \
+ const cudaError_t error = call; \
+ if (error != cudaSuccess) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << cudaGetErrorString(error);\
+ throw dlib::cuda_error(sout.str()); \
+ } \
+}while(false)
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef __CUDACC__
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ __inline__ __device__ size_t pack_idx (
+ size_t dim_size3,
+ size_t dim_size2,
+ size_t dim_size1,
+ size_t idx4,
+ size_t idx3,
+ size_t idx2,
+ size_t idx1
+ )
+ /*!
+ ensures
+ - Converts a 4D array index into a 1D index assuming row major layout. To
+ understand precisely what this function does, imagine we had an array
+ declared like this:
+ int ARRAY[anything][dim_size3][dim_size2][dim_size1];
+ Then we could index it like this:
+ ARRAY[idx4][idx3][idx2][idx1]
+ or equivalently like this:
+ ((int*)ARRAY)[pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)]
+ !*/
+ {
+ return ((idx4*dim_size3 + idx3)*dim_size2 + idx2)*dim_size1 + idx1;
+ }
+
+ __inline__ __device__ void unpack_idx (
+ size_t idx,
+ size_t dim_size3,
+ size_t dim_size2,
+ size_t dim_size1,
+ size_t& idx4,
+ size_t& idx3,
+ size_t& idx2,
+ size_t& idx1
+ )
+ /*!
+ ensures
+ - This function computes the inverse of pack_idx(). Therefore,
+ if PACKED == pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)
+ then unpack_idx(PACKED,dim_size3,dim_size2,dim_size1, IDX4,IDX3,IDX2,IDX1)
+ results in:
+ - IDX1 == idx1
+ - IDX2 == idx2
+ - IDX3 == idx3
+ - IDX4 == idx4
+ !*/
+ {
+ idx1 = idx%dim_size1;
+
+ idx /= dim_size1;
+ idx2 = idx%dim_size2;
+
+ idx /= dim_size2;
+ idx3 = idx%dim_size3;
+
+ idx /= dim_size3;
+ idx4 = idx;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function is from the article:
+ // http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
+ __inline__ __device__ float warp_reduce_sum(float val)
+ {
+ for (int offset = warpSize/2; offset > 0; offset /= 2)
+#if CUDART_VERSION >= 9000
+ val += __shfl_down_sync(0xFFFFFFFF,val, offset);
+#else
+ val += __shfl_down(val, offset);
+#endif
+ return val;
+ }
+
+ __inline__ __device__ bool is_first_thread_in_warp()
+ {
+ return (threadIdx.x & (warpSize - 1)) == 0;
+ }
+
+ __inline__ __device__ void warp_reduce_atomic_add(
+ float& out,
+ float val
+ )
+ /*!
+ ensures
+ - Atomically adds all the val variables in the current warp to out.
+ See this page for an extended discussion:
+ http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
+ !*/
+ {
+ val = warp_reduce_sum(val);
+ if (is_first_thread_in_warp())
+ atomicAdd(&out, val);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ struct max_jobs
+ {
+ max_jobs(int x) : num_x(x) {}
+ max_jobs(int x, int y) : num_x(x), num_y(y) {}
+ int num_x;
+ int num_y = 1;
+ };
+
+ template <typename Kernel, typename... T>
+ void launch_kernel (
+ Kernel K,
+ T ...args
+ )
+ /*!
+ ensures
+ - launches the given kernel K(args...). The point of this function is to
+ automatically set the kernel launch parameters to something reasonable
+ based on the properties of the kernel and the current GPU card.
+ !*/
+ {
+ int num_blocks, num_threads;
+ CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
+ K<<<num_blocks,num_threads>>>(args...);
+ }
+
+ template <typename Kernel, typename... T>
+ void launch_kernel (
+ Kernel K,
+ max_jobs m,
+ T ...args
+ )
+ /*!
+ ensures
+ - This function is just like launch_kernel(K,args...) except that you can
+ additionally supply a max_jobs number that tells it how many possible
+ total threads could be used. This is useful when launching potentially
+ small jobs that might not need the number of threads suggested by
+ launch_kernel().
+ !*/
+ {
+ if (m.num_x == 0 || m.num_y == 0)
+ return;
+ int num_blocks, num_threads;
+ CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
+ // Check if the job is really small and we don't really need to launch a kernel
+ // with this many blocks and threads.
+ if (num_blocks*num_threads > m.num_x*m.num_y)
+ num_blocks = (m.num_x*m.num_y+num_threads-1)/num_threads;
+
+ if (m.num_y == 1)
+ {
+ K<<<num_blocks,num_threads>>>(args...);
+ }
+ else
+ {
+ /*
+ In general, the reason m.num_y!=1 (i.e. the reason you are in this
+ code path) is because we are using nested grid-stride loops. There are
+ two important things to note about what we are doing here. To
+ illustrate them we will talk about this little CUDA code snippet:
+
+ // initialize out before we begin.
+ for (auto i : grid_stride_range_y(0, nr))
+ for (auto j : grid_stride_range(0, 1))
+ out[i] = 0;
+
+ __syncthreads(); // synchronize threads in block
+
+ // loop over some 2D thing and sum and store things into out.
+ for (auto i : grid_stride_range_y(0, nr))
+ {
+ float temp = 0;
+ for (auto j : grid_stride_range(0, nc))
+ temp += whatever[i*nc+j];
+
+ // store the sum into out[i]
+ warp_reduce_atomic_add(out[i], temp);
+ }
+
+ First, we make sure the number of x threads is a multiple of 32 so that
+ you can use warp_reduce_atomic_add() inside the y loop.
+
+ Second, we put the x block size to 1 so inter-block synchronization is
+ easier. For example, if the number of x blocks wasn't 1 the above code
+ would have a race condition in it. This is because the execution of
+ out[i]=0 would be done by blocks with blockIdx.x==0, but then in the
+ second set of loops, *all* the x blocks use out[i]. Since
+ __syncthreads() doesn't do any synchronization between blocks some of
+ the blocks might begin before the out[i]=0 statements finished and that
+ would be super bad.
+ */
+
+ // Try and make sure that the ratio of x to y threads is reasonable based
+ // on the respective size of our loops.
+ int x_threads = 32;
+ int y_threads = num_threads/32;
+ const int ratio = static_cast<int>(std::round(put_in_range(1, y_threads, m.num_x/(double)m.num_y)));
+ x_threads *= ratio;
+ y_threads /= ratio;
+
+ dim3 blocks(1,num_blocks);
+ dim3 threads(x_threads,y_threads);
+ K<<<blocks,threads>>>(args...);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ class grid_stride_range
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for making a for loop that loops over an entire block of
+ memory inside a kernel, but doing so in a way that parallelizes
+ appropriately across all the threads in a kernel launch. For example,
+ the following kernel would add the vector a to the vector b and store
+ the output in out (assuming all vectors are of dimension n):
+ __global__ void add_arrays(
+ const float* a,
+ const float* b,
+ float* out,
+ size_t n
+ )
+ {
+ for (auto i : grid_stride_range(0, n))
+ {
+ out[i] = a[i]+b[i];
+ }
+ }
+ !*/
+
+ public:
+ __device__ grid_stride_range(
+ size_t ibegin_,
+ size_t iend_
+ ) :
+ ibegin(ibegin_),
+ iend(iend_)
+ {}
+
+ class iterator
+ {
+ public:
+ __device__ iterator() {}
+ __device__ iterator(size_t pos_) : pos(pos_) {}
+
+ __device__ size_t operator*() const
+ {
+ return pos;
+ }
+
+ __device__ iterator& operator++()
+ {
+ pos += gridDim.x * blockDim.x;
+ return *this;
+ }
+
+ __device__ bool operator!=(const iterator& item) const
+ { return pos < item.pos; }
+
+ private:
+ size_t pos;
+ };
+
+ __device__ iterator begin() const
+ {
+ return iterator(ibegin+blockDim.x * blockIdx.x + threadIdx.x);
+ }
+ __device__ iterator end() const
+ {
+ return iterator(iend);
+ }
+ private:
+
+ size_t ibegin;
+ size_t iend;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ class grid_stride_range_y
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is just like grid_stride_range except that it looks at
+ CUDA's y thread index (e.g. threadIdx.y) instead of the x index.
+ Therefore, if you launch a cuda kernel with a statement like:
+ dim3 blocks(1,10);
+ dim3 threads(32,32); // You need to have x and y not equal to 1 to get parallelism over both loops.
+ add_arrays<<<blocks,threads>>>(a,b,out,nr,nc);
+ You can perform a nested 2D parallel for loop rather than doing just a
+ 1D for loop.
+
+ So the code in the kernel would look like this if you wanted to add two
+ 2D matrices:
+ __global__ void add_arrays(
+ const float* a,
+ const float* b,
+ float* out,
+ size_t nr,
+ size_t nc
+ )
+ {
+ for (auto r : grid_stride_range_y(0, nr))
+ {
+ for (auto c : grid_stride_range(0, nc))
+ {
+ auto i = r*nc+c;
+ out[i] = a[i]+b[i];
+ }
+ }
+ }
+ !*/
+
+ public:
+ __device__ grid_stride_range_y(
+ size_t ibegin_,
+ size_t iend_
+ ) :
+ ibegin(ibegin_),
+ iend(iend_)
+ {}
+
+ class iterator
+ {
+ public:
+ __device__ iterator() {}
+ __device__ iterator(size_t pos_) : pos(pos_) {}
+
+ __device__ size_t operator*() const
+ {
+ return pos;
+ }
+
+ __device__ iterator& operator++()
+ {
+ pos += gridDim.y * blockDim.y;
+ return *this;
+ }
+
+ __device__ bool operator!=(const iterator& item) const
+ { return pos < item.pos; }
+
+ private:
+ size_t pos;
+ };
+
+ __device__ iterator begin() const
+ {
+ return iterator(ibegin+blockDim.y * blockIdx.y + threadIdx.y);
+ }
+ __device__ iterator end() const
+ {
+ return iterator(iend);
+ }
+ private:
+
+ size_t ibegin;
+ size_t iend;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // __CUDACC__
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CUDA_UtILS_H_
+
diff --git a/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp b/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp
new file mode 100644
index 000000000..6926561f1
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp
@@ -0,0 +1,1604 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuDNN_CPP_
+#define DLIB_DNN_CuDNN_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cudnn_dlibapi.h"
+#include "tensor.h"
+#include <cudnn.h>
+#include <iostream>
+#include <string>
+#include <vector>
+#include "cuda_utils.h"
+#include "cpu_dlib.h"
+#include "cuda_dlib.h"
+#include "tensor_tools.h"
+
+static const char* cudnn_get_error_string(cudnnStatus_t s)
+{
+ switch(s)
+ {
+ case CUDNN_STATUS_NOT_INITIALIZED:
+ return "CUDA Runtime API initialization failed.";
+ case CUDNN_STATUS_ALLOC_FAILED:
+ return "CUDA Resources could not be allocated.";
+ case CUDNN_STATUS_BAD_PARAM:
+ return "CUDNN_STATUS_BAD_PARAM";
+ case CUDNN_STATUS_EXECUTION_FAILED:
+ return "CUDNN_STATUS_EXECUTION_FAILED";
+ case CUDNN_STATUS_NOT_SUPPORTED:
+ return "CUDNN_STATUS_NOT_SUPPORTED";
+ case CUDNN_STATUS_ARCH_MISMATCH:
+ return "CUDNN_STATUS_ARCH_MISMATCH: Your GPU is too old and not supported by cuDNN";
+ default:
+ return "A call to cuDNN failed";
+ }
+}
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CUDNN(call) \
+do{ \
+ const cudnnStatus_t error = call; \
+ if (error != CUDNN_STATUS_SUCCESS) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\
+ throw dlib::cudnn_error(sout.str()); \
+ } \
+}while(false)
+
+
+namespace dlib
+{
+
+ namespace cuda
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ static cudnnTensorDescriptor_t descriptor(const tensor& t)
+ {
+ return (const cudnnTensorDescriptor_t)t.get_cudnn_tensor_descriptor().get_handle();
+ }
+ static cudnnTensorDescriptor_t descriptor(const tensor_descriptor& t)
+ {
+ return (const cudnnTensorDescriptor_t)t.get_handle();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ class cudnn_context
+ {
+ public:
+ // not copyable
+ cudnn_context(const cudnn_context&) = delete;
+ cudnn_context& operator=(const cudnn_context&) = delete;
+
+ cudnn_context()
+ {
+ handles.resize(16);
+ }
+ ~cudnn_context()
+ {
+ for (auto h : handles)
+ {
+ if (h)
+ cudnnDestroy(h);
+ }
+ }
+
+ cudnnHandle_t get_handle (
+ )
+ {
+ int new_device_id;
+ CHECK_CUDA(cudaGetDevice(&new_device_id));
+ // make room for more devices if needed
+ if (new_device_id >= (long)handles.size())
+ handles.resize(new_device_id+16);
+
+ // If we don't have a handle already for this device then make one
+ if (!handles[new_device_id])
+ CHECK_CUDNN(cudnnCreate(&handles[new_device_id]));
+
+ // Finally, return the handle for the current device
+ return handles[new_device_id];
+ }
+
+ private:
+
+ std::vector<cudnnHandle_t> handles;
+ };
+
+ static cudnnHandle_t context()
+ {
+ thread_local cudnn_context c;
+ return c.get_handle();
+ }
+ // ------------------------------------------------------------------------------------
+
+ class cudnn_device_buffer
+ {
+ public:
+ // not copyable
+ cudnn_device_buffer(const cudnn_device_buffer&) = delete;
+ cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete;
+
+ cudnn_device_buffer()
+ {
+ buffers.resize(16);
+ }
+ ~cudnn_device_buffer()
+ {
+ }
+
+ std::shared_ptr<resizable_cuda_buffer> get_buffer (
+ )
+ {
+ int new_device_id;
+ CHECK_CUDA(cudaGetDevice(&new_device_id));
+ // make room for more devices if needed
+ if (new_device_id >= (long)buffers.size())
+ buffers.resize(new_device_id+16);
+
+ // If we don't have a buffer already for this device then make one
+ std::shared_ptr<resizable_cuda_buffer> buff = buffers[new_device_id].lock();
+ if (!buff)
+ {
+ buff = std::make_shared<resizable_cuda_buffer>();
+ buffers[new_device_id] = buff;
+ }
+
+ // Finally, return the buffer for the current device
+ return buff;
+ }
+
+ private:
+
+ std::vector<std::weak_ptr<resizable_cuda_buffer>> buffers;
+ };
+
+
+ static std::shared_ptr<resizable_cuda_buffer> device_global_buffer()
+ {
+ thread_local cudnn_device_buffer buffer;
+ return buffer.get_buffer();
+ }
+ // ------------------------------------------------------------------------------------
+
+ class cudnn_activation_descriptor
+ {
+ public:
+ // not copyable
+ cudnn_activation_descriptor(const cudnn_activation_descriptor&) = delete;
+ cudnn_activation_descriptor& operator=(const cudnn_activation_descriptor&) = delete;
+
+ cudnn_activation_descriptor(
+ cudnnActivationMode_t mode,
+ cudnnNanPropagation_t reluNanOpt,
+ double reluCeiling
+ )
+ {
+ CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle));
+ CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, reluCeiling));
+ }
+
+ ~cudnn_activation_descriptor()
+ {
+ cudnnDestroyActivationDescriptor(handle);
+ }
+
+ cudnnActivationDescriptor_t get_handle (
+ )
+ {
+ return handle;
+ }
+ private:
+ cudnnActivationDescriptor_t handle;
+ };
+
+ static cudnnActivationDescriptor_t relu_activation_descriptor()
+ {
+ thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0);
+ return des.get_handle();
+ }
+
+ static cudnnActivationDescriptor_t sigmoid_activation_descriptor()
+ {
+ thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,0);
+ return des.get_handle();
+ }
+
+ static cudnnActivationDescriptor_t tanh_activation_descriptor()
+ {
+ thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,0);
+ return des.get_handle();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ tensor_descriptor::
+ tensor_descriptor(
+ ) : handle(nullptr)
+ {
+ }
+
+ tensor_descriptor::
+ ~tensor_descriptor()
+ {
+ set_size(0,0,0,0);
+ }
+
+ void tensor_descriptor::
+ set_size(
+ int n,
+ int k,
+ int nr,
+ int nc
+ )
+ {
+ if (handle)
+ {
+ cudnnDestroyTensorDescriptor((cudnnTensorDescriptor_t)handle);
+ handle = nullptr;
+ }
+
+ if (n != 0 && nr != 0 && nc != 0 && k != 0)
+ {
+ cudnnTensorDescriptor_t h;
+ CHECK_CUDNN(cudnnCreateTensorDescriptor(&h));
+ handle = h;
+
+ CHECK_CUDNN(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+ CUDNN_TENSOR_NCHW,
+ CUDNN_DATA_FLOAT,
+ n,
+ k,
+ nr,
+ nc));
+ }
+ }
+
+ void tensor_descriptor::
+ get_size (
+ int& n,
+ int& k,
+ int& nr,
+ int& nc
+ ) const
+ {
+ if (handle)
+ {
+ int nStride, cStride, hStride, wStride;
+ cudnnDataType_t datatype;
+ CHECK_CUDNN(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+ &datatype,
+ &n,
+ &k,
+ &nr,
+ &nc,
+ &nStride,
+ &cStride,
+ &hStride,
+ &wStride));
+ }
+ else
+ {
+ n = 0;
+ k = 0;
+ nr = 0;
+ nc = 0;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(
+ (have_same_dimensions(src, dest) ||
+ (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
+ (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
+ (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
+ (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) &&
+ is_same_object(src,dest) == false ,
+ "\n\t dest.num_samples(): " << dest.num_samples()
+ <<"\n\t dest.k(): " << dest.k()
+ <<"\n\t dest.nr(): " << dest.nr()
+ <<"\n\t dest.nc(): " << dest.nc()
+ <<"\n\t src.num_samples(): " << src.num_samples()
+ <<"\n\t src.k(): " << src.k()
+ <<"\n\t src.nr(): " << src.nr()
+ <<"\n\t src.nc(): " << src.nc()
+ );
+
+ if (dest.size() == src.size() && beta == 1)
+ {
+ // Call the dlib function in this case since it's faster than the one that
+ // comes with cuDNN (at least as of cuDNN v4).
+ add_scaled(dest, alpha, src);
+ return;
+ }
+ else if (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)
+ {
+ add_cv_to_all_columns(beta, dest, alpha, src);
+ return;
+ }
+
+ CHECK_CUDNN(cudnnAddTensor(context(),
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ grad.num_samples() == 1 &&
+ grad.k() >= 1 &&
+ grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ gradient_input.k() == grad.k() &&
+ gradient_input.size() > 0 &&
+ is_same_object(grad,gradient_input) == false
+ );
+
+ const float alpha = 1;
+ const float beta = 0;
+ CHECK_CUDNN(cudnnConvolutionBackwardBias(context(),
+ &alpha,
+ descriptor(gradient_input),
+ gradient_input.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+ DLIB_CASSERT(
+ gamma.num_samples() == 1 &&
+ gamma.nr() == src.nr() &&
+ gamma.nc() == src.nc() &&
+ gamma.k() == src.k() &&
+ have_same_dimensions(gamma, beta) &&
+ have_same_dimensions(gamma, running_means) &&
+ have_same_dimensions(gamma, running_variances) &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nrunning_means.num_samples(): " << running_means.num_samples() <<
+ "\nrunning_means.k(): " << running_means.k() <<
+ "\nrunning_means.nr(): " << running_means.nr() <<
+ "\nrunning_means.nc(): " << running_means.nc() <<
+ "\nrunning_variances.num_samples(): " << running_variances.num_samples() <<
+ "\nrunning_variances.k(): " << running_variances.k() <<
+ "\nrunning_variances.nr(): " << running_variances.nr() <<
+ "\nrunning_variances.nc(): " << running_variances.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+ const float in_scale = 1;
+ const float out_scale = 0;
+
+ dest.copy_size(src);
+
+ CHECK_CUDNN(cudnnBatchNormalizationForwardInference(
+ context(),
+ CUDNN_BATCHNORM_PER_ACTIVATION,
+ &in_scale,
+ &out_scale,
+ descriptor(src),
+ src.device(),
+ descriptor(dest),
+ dest.device(),
+ descriptor(gamma),
+ gamma.device(),
+ beta.device(),
+ running_means.device(),
+ running_variances.device(),
+ eps));
+ }
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+ DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor);
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means));
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds));
+ DLIB_CASSERT(
+ src.num_samples() > 1 &&
+ gamma.num_samples() == 1 &&
+ beta.num_samples() == 1 &&
+ gamma.nr() == beta.nr() && beta.nr() == src.nr() &&
+ gamma.nc() == beta.nc() && beta.nc() == src.nc() &&
+ gamma.k() == beta.k() && beta.k() == src.k() &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+
+ const float in_scale = 1;
+ const float out_scale = 0;
+
+ dest.copy_size(src);
+ means.set_size(1, src.k(), src.nr(), src.nc());
+ invstds.copy_size(means);
+ running_means.copy_size(means);
+ running_variances.copy_size(means);
+ // cuDNN requires that running_means and running_variances be initialized to
+ // some valid float values even if the averaging factor would have ignored
+ // them.
+ if (averaging_factor == 1)
+ {
+ running_means = 0;
+ running_variances = 1;
+ }
+
+ CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(
+ context(),
+ CUDNN_BATCHNORM_PER_ACTIVATION,
+ &in_scale,
+ &out_scale,
+ descriptor(src),
+ src.device(),
+ descriptor(dest),
+ dest.device(),
+ descriptor(gamma),
+ gamma.device(),
+ beta.device(),
+ averaging_factor,
+ running_means.device(),
+ running_variances.device(),
+ eps,
+ means.device(),
+ invstds.device()));
+ }
+
+ void batch_normalize_gradient(
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+ const long num = src.k()*src.nr()*src.nc();
+ DLIB_CASSERT(src.num_samples() > 1);
+ DLIB_CASSERT(num == (long)means.size());
+ DLIB_CASSERT(num == (long)invstds.size());
+ DLIB_CASSERT(num == (long)gamma.size());
+ DLIB_CASSERT(num == (long)gamma_grad.size());
+ DLIB_CASSERT(num == (long)beta_grad.size());
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+ DLIB_CASSERT(eps > 0);
+
+ const float in_scale = 1;
+ const float out_scale = 1;
+ const float in_scale_params = 1;
+ const float out_scale_params = 0;
+
+ CHECK_CUDNN(cudnnBatchNormalizationBackward(
+ context(),
+ CUDNN_BATCHNORM_PER_ACTIVATION,
+ &in_scale,
+ &out_scale,
+ &in_scale_params,
+ &out_scale_params,
+ descriptor(src),
+ src.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(src_grad),
+ src_grad.device(),
+ descriptor(gamma),
+ gamma.device(),
+ gamma_grad.device(),
+ beta_grad.device(),
+ eps,
+ means.device(),
+ invstds.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+ DLIB_CASSERT(
+ gamma.num_samples() == 1 &&
+ gamma.nr() == 1 &&
+ gamma.nc() == 1 &&
+ gamma.k() == src.k() &&
+ have_same_dimensions(gamma, beta) &&
+ have_same_dimensions(gamma, running_means) &&
+ have_same_dimensions(gamma, running_variances) &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nrunning_means.num_samples(): " << running_means.num_samples() <<
+ "\nrunning_means.k(): " << running_means.k() <<
+ "\nrunning_means.nr(): " << running_means.nr() <<
+ "\nrunning_means.nc(): " << running_means.nc() <<
+ "\nrunning_variances.num_samples(): " << running_variances.num_samples() <<
+ "\nrunning_variances.k(): " << running_variances.k() <<
+ "\nrunning_variances.nr(): " << running_variances.nr() <<
+ "\nrunning_variances.nc(): " << running_variances.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+ const float in_scale = 1;
+ const float out_scale = 0;
+
+ dest.copy_size(src);
+
+ CHECK_CUDNN(cudnnBatchNormalizationForwardInference(
+ context(),
+ CUDNN_BATCHNORM_SPATIAL,
+ &in_scale,
+ &out_scale,
+ descriptor(src),
+ src.device(),
+ descriptor(dest),
+ dest.device(),
+ descriptor(gamma),
+ gamma.device(),
+ beta.device(),
+ running_means.device(),
+ running_variances.device(),
+ eps));
+ }
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+ DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor);
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means));
+ DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds));
+ DLIB_CASSERT(
+ src.num_samples() > 1 &&
+ gamma.num_samples() == 1 &&
+ beta.num_samples() == 1 &&
+ gamma.nr() == 1 &&
+ beta.nr() == 1 &&
+ gamma.nc() == 1 &&
+ beta.nc() == 1 &&
+ gamma.k() == beta.k() && beta.k() == src.k() &&
+ eps > 0,
+ "\ngamma.num_samples(): " << gamma.num_samples() <<
+ "\ngamma.k(): " << gamma.k() <<
+ "\ngamma.nr(): " << gamma.nr() <<
+ "\ngamma.nc(): " << gamma.nc() <<
+ "\nbeta.num_samples(): " << beta.num_samples() <<
+ "\nbeta.k(): " << beta.k() <<
+ "\nbeta.nr(): " << beta.nr() <<
+ "\nbeta.nc(): " << beta.nc() <<
+ "\nsrc.k(): " << src.k() <<
+ "\nsrc.nr(): " << src.nr() <<
+ "\nsrc.nc(): " << src.nc() <<
+ "\neps: " << eps
+ );
+ const float in_scale = 1;
+ const float out_scale = 0;
+
+ dest.copy_size(src);
+ means.set_size(1, src.k());
+ invstds.copy_size(means);
+ running_means.copy_size(means);
+ running_variances.copy_size(means);
+ // cuDNN requires that running_means and running_variances be initialized to
+ // some valid float values even if the averaging factor would have ignored
+ // them.
+ if (averaging_factor == 1)
+ {
+ running_means = 0;
+ running_variances = 1;
+ }
+
+ CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(
+ context(),
+ CUDNN_BATCHNORM_SPATIAL,
+ &in_scale,
+ &out_scale,
+ descriptor(src),
+ src.device(),
+ descriptor(dest),
+ dest.device(),
+ descriptor(gamma),
+ gamma.device(),
+ beta.device(),
+ averaging_factor,
+ running_means.device(),
+ running_variances.device(),
+ eps,
+ means.device(),
+ invstds.device()));
+ }
+
+ void batch_normalize_conv_gradient(
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+ DLIB_CASSERT(src.k() == (long)means.size());
+ DLIB_CASSERT(src.k() == (long)invstds.size());
+ DLIB_CASSERT(src.k() == (long)gamma.size());
+ DLIB_CASSERT(src.k() == (long)gamma_grad.size());
+ DLIB_CASSERT(src.k() == (long)beta_grad.size());
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+ DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+ DLIB_CASSERT(eps > 0);
+
+ const float in_scale = 1;
+ const float out_scale = 1;
+ const float in_scale_params = 1;
+ const float out_scale_params = 0;
+
+ CHECK_CUDNN(cudnnBatchNormalizationBackward(
+ context(),
+ CUDNN_BATCHNORM_SPATIAL,
+ &in_scale,
+ &out_scale,
+ &in_scale_params,
+ &out_scale_params,
+ descriptor(src),
+ src.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(src_grad),
+ src_grad.device(),
+ descriptor(gamma),
+ gamma.device(),
+ gamma_grad.device(),
+ beta_grad.device(),
+ eps,
+ means.device(),
+ invstds.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ tensor_conv::
+ tensor_conv(
+ ) :
+ filter_handle(nullptr),
+ conv_handle(nullptr),
+ forward_algo(0),
+ backward_data_algo(0),
+ backward_filters_algo(0)
+ {
+ clear();
+ }
+
+ void tensor_conv::
+ clear (
+ )
+ {
+ if (filter_handle)
+ cudnnDestroyFilterDescriptor((cudnnFilterDescriptor_t)filter_handle);
+ if (conv_handle)
+ cudnnDestroyConvolutionDescriptor((cudnnConvolutionDescriptor_t)conv_handle);
+ filter_handle = nullptr;
+ conv_handle = nullptr;
+ out_num_samples = 0;
+ out_k = 0;
+ out_nr = 0;
+ out_nc = 0;
+
+ stride_y = 0;
+ stride_x = 0;
+ padding_y = 0;
+ padding_x = 0;
+ data_num_samples = 0;
+ data_k = 0;
+ data_nr = 0;
+ data_nc = 0;
+ filters_num_samples = 0;
+ filters_k = 0;
+ filters_nr = 0;
+ filters_nc = 0;
+
+ forward_algo = 0;
+ backward_data_algo = 0;
+ backward_filters_algo = 0;
+
+ forward_workspace_size_in_bytes = 0;
+ backward_data_workspace_size_in_bytes = 0;
+ backward_filters_workspace_size_in_bytes = 0;
+
+ forward_workspace.reset();
+ backward_data_workspace.reset();
+ backward_filters_workspace.reset();
+ workspace.reset();
+ }
+
+ void tensor_conv::
+ setup(
+ const tensor& data,
+ const tensor& filters,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_
+ )
+ {
+ DLIB_CASSERT(data.k() == filters.k());
+
+ // if the last call to setup gave the same exact settings then don't do
+ // anything.
+ if (stride_y_ == stride_y &&
+ stride_x_ == stride_x &&
+ padding_y_ == padding_y &&
+ padding_x_ == padding_x &&
+ data_num_samples == data.num_samples() &&
+ data_k == data.k() &&
+ data_nr == data.nr() &&
+ data_nc == data.nc() &&
+ filters_num_samples == filters.num_samples() &&
+ filters_k == filters.k() &&
+ filters_nr == filters.nr() &&
+ filters_nc == filters.nc())
+ {
+ return;
+ }
+
+ clear();
+ try
+ {
+ stride_y = stride_y_;
+ stride_x = stride_x_;
+ padding_y = padding_y_;
+ padding_x = padding_x_;
+ data_num_samples = data.num_samples();
+ data_k = data.k();
+ data_nr = data.nr();
+ data_nc = data.nc();
+ filters_num_samples = filters.num_samples();
+ filters_k = filters.k();
+ filters_nr = filters.nr();
+ filters_nc = filters.nc();
+
+ CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
+ CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
+ CUDNN_DATA_FLOAT,
+ CUDNN_TENSOR_NCHW,
+ filters.num_samples(),
+ filters.k(),
+ filters.nr(),
+ filters.nc()));
+
+ CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
+#if CUDNN_MAJOR >= 6
+ CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
+ padding_y, // vertical padding
+ padding_x, // horizontal padding
+ stride_y,
+ stride_x,
+ 1, 1, // must be 1,1
+ CUDNN_CROSS_CORRELATION,
+ CUDNN_DATA_FLOAT)); // could also be CUDNN_CONVOLUTION
+#else
+ CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
+ padding_y, // vertical padding
+ padding_x, // horizontal padding
+ stride_y,
+ stride_x,
+ 1, 1, // must be 1,1
+ CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION
+#endif
+
+ CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ descriptor(data),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ &out_num_samples,
+ &out_k,
+ &out_nr,
+ &out_nc));
+
+ tensor_descriptor dest_desc;
+ dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
+
+ // Pick which forward algorithm we will use and allocate the necessary
+ // workspace buffer.
+ cudnnConvolutionFwdAlgo_t forward_best_algo;
+ CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
+ context(),
+ descriptor(data),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ descriptor(dest_desc),
+ dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
+ std::numeric_limits<size_t>::max(),
+ &forward_best_algo));
+ forward_algo = forward_best_algo;
+ CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
+ context(),
+ descriptor(data),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ descriptor(dest_desc),
+ forward_best_algo,
+ &forward_workspace_size_in_bytes));
+
+ // Pick which backward data algorithm we will use and allocate the
+ // necessary workspace buffer.
+ cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
+ CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
+ context(),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ descriptor(dest_desc),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ descriptor(data),
+ dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
+ std::numeric_limits<size_t>::max(),
+ &backward_data_best_algo));
+ backward_data_algo = backward_data_best_algo;
+
+ CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
+ context(),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ descriptor(dest_desc),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ descriptor(data),
+ backward_data_best_algo,
+ &backward_data_workspace_size_in_bytes));
+
+ // Pick which backward filters algorithm we will use and allocate the
+ // necessary workspace buffer.
+ cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
+ CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
+ context(),
+ descriptor(data),
+ descriptor(dest_desc),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ (const cudnnFilterDescriptor_t)filter_handle,
+ dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
+ std::numeric_limits<size_t>::max(),
+ &backward_filters_best_algo));
+ // cuDNN 5.1 has a bug that causes
+ // cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
+ // algorithm even for cases where cuDNN doesn't support it, leading to
+ // incorrect outputs. So here we check if we are in a case where winograd
+ // isn't supported and manually overrule
+ // cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
+ // algorithm.
+ if (dnn_prefer_fastest_algorithms() &&
+ !(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
+ )
+ {
+ backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+ }
+ backward_filters_algo = backward_filters_best_algo;
+
+ CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ context(),
+ descriptor(data),
+ descriptor(dest_desc),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ (const cudnnFilterDescriptor_t)filter_handle,
+ backward_filters_best_algo,
+ &backward_filters_workspace_size_in_bytes));
+
+ workspace = device_global_buffer();
+ }
+ catch(...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+ tensor_conv::
+ ~tensor_conv (
+ )
+ {
+ clear();
+ }
+
+ void tensor_conv::operator() (
+ const bool add_to_output,
+ resizable_tensor& output,
+ const tensor& data,
+ const tensor& filters
+ )
+ {
+ DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");
+
+ output.set_size(out_num_samples, out_k, out_nr, out_nc);
+ (*this)(add_to_output, static_cast<tensor&>(output), data, filters);
+ }
+
+ void tensor_conv::operator() (
+ const bool add_to_output,
+ tensor& output,
+ const tensor& data,
+ const tensor& filters
+ )
+ {
+ DLIB_CASSERT(is_same_object(output,data) == false);
+ DLIB_CASSERT(is_same_object(output,filters) == false);
+ DLIB_CASSERT(filters.k() == data.k());
+ DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function");
+ DLIB_CASSERT(filters.nc() <= data.nc() + 2*padding_x,
+ "Filter windows must be small enough to fit into the padded image."
+ << "\n\t filters.nc(): " << filters.nc()
+ << "\n\t data.nc(): " << data.nc()
+ << "\n\t padding_x: " << padding_x
+ );
+ DLIB_CASSERT(filters.nr() <= data.nr() + 2*padding_y,
+ "Filter windows must be small enough to fit into the padded image."
+ << "\n\t filters.nr(): " << filters.nr()
+ << "\n\t data.nr(): " << data.nr()
+ << "\n\t padding_y: " << padding_y
+ );
+
+
+ DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples());
+ DLIB_CASSERT(output.k() == filters.num_samples());
+ DLIB_CASSERT(output.nr() == 1+(data.nr()+2*padding_y-filters.nr())/stride_y);
+ DLIB_CASSERT(output.nc() == 1+(data.nc()+2*padding_x-filters.nc())/stride_x);
+
+
+
+ const float alpha = 1;
+ const float beta = add_to_output ? 1 : 0;
+
+ // Since cudnnConvolutionForward() is an asynchronous call, we need to hold a
+ // reference to the workspace buffer so we can be sure it isn't reallocated
+ // while the function is still executing on the device. But each time we come
+ // here, we make sure to grab the latest workspace buffer so that, globally, we
+ // minimize the number of such buffers.
+ forward_workspace = workspace->get(forward_workspace_size_in_bytes);
+
+ CHECK_CUDNN(cudnnConvolutionForward(
+ context(),
+ &alpha,
+ descriptor(data),
+ data.device(),
+ (const cudnnFilterDescriptor_t)filter_handle,
+ filters.device(),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ (cudnnConvolutionFwdAlgo_t)forward_algo,
+ forward_workspace,
+ forward_workspace_size_in_bytes,
+ &beta,
+ descriptor(output),
+ output.device()));
+ }
+
+ void tensor_conv::get_gradient_for_data (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& filters,
+ tensor& data_gradient
+ )
+ {
+ const float alpha = 1;
+ const float beta = add_to_output ? 1 : 0;
+
+ // Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a
+ // reference to the workspace buffer so we can be sure it isn't reallocated
+ // while the function is still executing on the device. But each time we come
+ // here, we make sure to grab the latest workspace buffer so that, globally, we
+ // minimize the number of such buffers.
+ backward_data_workspace = workspace->get(backward_data_workspace_size_in_bytes);
+
+
+ CHECK_CUDNN(cudnnConvolutionBackwardData(context(),
+ &alpha,
+ (const cudnnFilterDescriptor_t)filter_handle,
+ filters.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ (cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
+ backward_data_workspace,
+ backward_data_workspace_size_in_bytes,
+ &beta,
+ descriptor(data_gradient),
+ data_gradient.device()));
+ }
+
+ void tensor_conv::
+ get_gradient_for_filters (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& data,
+ tensor& filters_gradient
+ )
+ {
+ const float alpha = 1;
+ const float beta = add_to_output ? 1 : 0;
+
+ // Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a
+ // reference to the workspace buffer so we can be sure it isn't reallocated
+ // while the function is still executing on the device. But each time we come
+ // here, we make sure to grab the latest workspace buffer so that, globally, we
+ // minimize the number of such buffers.
+ backward_filters_workspace = workspace->get(backward_filters_workspace_size_in_bytes);
+
+ CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(),
+ &alpha,
+ descriptor(data),
+ data.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ (const cudnnConvolutionDescriptor_t)conv_handle,
+ (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
+ backward_filters_workspace,
+ backward_filters_workspace_size_in_bytes,
+ &beta,
+ (const cudnnFilterDescriptor_t)filter_handle,
+ filters_gradient.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ pooling::pooling (
+ ) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0), padding_x(0)
+ {
+ }
+
+ pooling::~pooling(
+ )
+ {
+ clear();
+ }
+
+ void pooling::
+ clear(
+ )
+ {
+ if (handle)
+ cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
+ handle = nullptr;
+ window_height = 0;
+ window_width = 0;
+ stride_y = 0;
+ stride_x = 0;
+ padding_y = 0;
+ padding_x = 0;
+ }
+
+ void pooling::
+ setup_max_pooling(
+ int window_height_,
+ int window_width_,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_
+ )
+ {
+ setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_MAX);
+ do_max_pooling = true;
+ }
+
+ void pooling::
+ setup_avg_pooling(
+ int window_height_,
+ int window_width_,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_
+ )
+ {
+ setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING);
+ do_max_pooling = false;
+ }
+
+ void pooling::
+ setup(
+ int window_height_,
+ int window_width_,
+ int stride_y_,
+ int stride_x_,
+ int padding_y_,
+ int padding_x_,
+ int pooling_mode
+ )
+ {
+ DLIB_CASSERT (window_height_ > 0 && window_width_ > 0 &&
+ stride_y_ > 0 && stride_x_ > 0 ,
+ "window_height_: " << window_height_
+ << "\t\n window_width_: " << window_width_
+ << "\t\n stride_y_: " << stride_y_
+ << "\t\n stride_x_: " << stride_x_ );
+ DLIB_CASSERT( 0 <= padding_y_ && padding_y_ < window_height_ &&
+ 0 <= padding_x_ && padding_x_ < window_width_,
+ "window_height_: " << window_height_
+ << "\t\n window_width_: " << window_width_
+ << "\t\n padding_y_: " << padding_y_
+ << "\t\n padding_x_: " << padding_x_ );
+
+ if (window_height == window_height_ &&
+ window_width == window_width_ &&
+ stride_y == stride_y_ &&
+ stride_x == stride_x_ &&
+ padding_y == padding_y_ &&
+ padding_x == padding_x_
+ )
+ {
+ return;
+ }
+
+ clear();
+ try
+ {
+ window_height = window_height_;
+ window_width = window_width_;
+ stride_x = stride_x_;
+ stride_y = stride_y_;
+ padding_y = padding_y_;
+ padding_x = padding_x_;
+ cudnnPoolingDescriptor_t poolingDesc;
+ CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
+ handle = poolingDesc;
+
+ CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
+ (cudnnPoolingMode_t)pooling_mode,
+ CUDNN_PROPAGATE_NAN,
+ window_height,
+ window_width,
+ padding_y,
+ padding_x,
+ stride_y,
+ stride_x));
+ }
+ catch(...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+ void pooling::
+ operator() (
+ resizable_tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(window_width <= src.nc() + 2*padding_x,
+ "Pooling windows must be small enough to fit into the padded image."
+ << "\n\t window_width: " << window_width
+ << "\n\t src.nc(): " << src.nc()
+ << "\n\t padding_x: " << padding_x
+ );
+ DLIB_CASSERT(window_height <= src.nr() + 2*padding_y,
+ "Pooling windows must be small enough to fit into the padded image."
+ << "\n\t window_height: " << window_height
+ << "\n\t src.nr(): " << src.nr()
+ << "\n\t padding_y: " << padding_y
+ );
+ const float alpha = 1;
+ const float beta = 0;
+ int outN;
+ int outC;
+ int outH;
+ int outW;
+ CHECK_CUDNN(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
+ descriptor(src),
+ &outN,
+ &outC,
+ &outH,
+ &outW));
+
+
+ dest.set_size(outN,outC,outH,outW);
+
+ DLIB_CASSERT(dest.num_samples() == src.num_samples());
+ DLIB_CASSERT(dest.k() == src.k());
+ DLIB_CASSERT(dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y,
+ "\n stride_y: " << stride_y <<
+ "\n padding_y: " << padding_y <<
+ "\n window_height: " << window_height <<
+ "\n src.nr(): " << src.nr() <<
+ "\n dest.nr(): " << dest.nr() <<
+ "\n src.nr()/stride_y: " << src.nr()/stride_y);
+ DLIB_CASSERT(dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x,
+ "\n stride_x: " << stride_x <<
+ "\n padding_x: " << padding_x <<
+ "\n window_width: " << window_width <<
+ "\n src.nc(): " << src.nc() <<
+ "\n dest.nc(): " << dest.nc() <<
+ "\n src.nc()/stride_x: " << src.nc()/stride_x);
+
+ CHECK_CUDNN(cudnnPoolingForward(context(),
+ (const cudnnPoolingDescriptor_t)handle,
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+ void pooling::get_gradient(
+ const tensor& gradient_input,
+ const tensor& dest,
+ const tensor& src,
+ tensor& grad
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(gradient_input,dest));
+ DLIB_CASSERT(have_same_dimensions(src,grad));
+
+ const float alpha = 1;
+ const float beta = 1;
+ CHECK_CUDNN(cudnnPoolingBackward(context(),
+ (const cudnnPoolingDescriptor_t)handle,
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ if (src.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = 0;
+
+ CHECK_CUDNN(cudnnSoftmaxForward(context(),
+ CUDNN_SOFTMAX_ACCURATE,
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ have_same_dimensions(dest,gradient_input) == true &&
+ have_same_dimensions(dest,grad) == true );
+ if (dest.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
+ CHECK_CUDNN(cudnnSoftmaxBackward(context(),
+ CUDNN_SOFTMAX_ACCURATE,
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ if (src.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = 0;
+
+ CHECK_CUDNN(cudnnSoftmaxForward(context(),
+ CUDNN_SOFTMAX_ACCURATE,
+ CUDNN_SOFTMAX_MODE_INSTANCE,
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ have_same_dimensions(dest,gradient_input) == true &&
+ have_same_dimensions(dest,grad) == true );
+ if (dest.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
+ CHECK_CUDNN(cudnnSoftmaxBackward(context(),
+ CUDNN_SOFTMAX_ACCURATE,
+ CUDNN_SOFTMAX_MODE_INSTANCE,
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ if (src.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = 0;
+ CHECK_CUDNN(cudnnActivationForward(context(),
+ sigmoid_activation_descriptor(),
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ have_same_dimensions(dest,gradient_input) == true &&
+ have_same_dimensions(dest,grad) == true );
+ if (dest.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
+ CHECK_CUDNN(cudnnActivationBackward(context(),
+ sigmoid_activation_descriptor(),
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(dest),
+ dest.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ if (src.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = 0;
+ CHECK_CUDNN(cudnnActivationForward(context(),
+ relu_activation_descriptor(),
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ have_same_dimensions(dest,gradient_input) == true &&
+ have_same_dimensions(dest,grad) == true );
+ if (dest.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
+ CHECK_CUDNN(cudnnActivationBackward(context(),
+ relu_activation_descriptor(),
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(dest),
+ dest.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(dest,src));
+ if (src.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = 0;
+ CHECK_CUDNN(cudnnActivationForward(context(),
+ tanh_activation_descriptor(),
+ &alpha,
+ descriptor(src),
+ src.device(),
+ &beta,
+ descriptor(dest),
+ dest.device()));
+ }
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+ DLIB_CASSERT(
+ have_same_dimensions(dest,gradient_input) == true &&
+ have_same_dimensions(dest,grad) == true);
+ if (dest.size() == 0)
+ return;
+
+ const float alpha = 1;
+ const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
+ CHECK_CUDNN(cudnnActivationBackward(context(),
+ tanh_activation_descriptor(),
+ &alpha,
+ descriptor(dest),
+ dest.device(),
+ descriptor(gradient_input),
+ gradient_input.device(),
+ descriptor(dest),
+ dest.device(),
+ &beta,
+ descriptor(grad),
+ grad.device()));
+ }
+
+ // ------------------------------------------------------------------------------------
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuDNN_CPP_
+
+
diff --git a/ml/dlib/dlib/dnn/cudnn_dlibapi.h b/ml/dlib/dlib/dnn/cudnn_dlibapi.h
new file mode 100644
index 000000000..e9ffe5f6d
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cudnn_dlibapi.h
@@ -0,0 +1,518 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuDNN_H_
+#define DLIB_DNN_CuDNN_H_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cuda_errors.h"
+#include <memory>
+#include "cuda_data_ptr.h"
+
+namespace dlib
+{
+ class tensor;
+ class resizable_tensor;
+
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ class tensor_descriptor
+ {
+ /*!
+ Each tensor object will carry a tensor_descriptor in it when compiled with
+ CUDA.
+ !*/
+
+ public:
+ // not copyable
+ tensor_descriptor(const tensor_descriptor&) = delete;
+ tensor_descriptor& operator=(const tensor_descriptor&) = delete;
+ // but is movable
+ tensor_descriptor(tensor_descriptor&& item) : tensor_descriptor() { swap(item); }
+ tensor_descriptor& operator=(tensor_descriptor&& item) { swap(item); return *this; }
+
+ tensor_descriptor();
+ ~tensor_descriptor();
+
+ void set_size(
+ int n,
+ int k,
+ int nr,
+ int nc
+ );
+ /*!
+ ensures
+ - if any of the arguments are 0 then they are all set to 0 in the tensor.
+ !*/
+
+ void get_size (
+ int& n,
+ int& k,
+ int& nr,
+ int& nc
+ ) const;
+
+ const void* get_handle (
+ ) const { return handle; }
+
+ private:
+
+ void swap(tensor_descriptor& item) { std::swap(handle, item.handle); }
+
+ void* handle;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ );
+ /*!
+ requires
+ - One of the following is true:
+ - have_same_dimensions(src, dest)
+ - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
+ - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
+ - src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
+ - is_same_object(src,dest) == false
+ ensures
+ - performs: dest = beta*dest + alpha*src
+ However, how the addition happens depends on the dimensions of src. In
+ particular, this function adds the scaled values of one src tensor to
+ dest. Each dimension of the src tensor must match the corresponding
+ dimension of the dest tensor or must be equal to 1. In the latter case,
+ the same value from the src tensor, for those dimensions, will be used to
+ add into the dest tensor.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - grad.num_samples() == 1
+ - grad.k() >= 1
+ - grad.nr() == 1
+ - grad.nc() == 1
+ - gradient_input.k() == grad.k()
+ - gradient_input.size() > 0
+ - is_same_object(grad,gradient_input) == false
+ ensures
+ - let BIAS be a tensor with all dimensions equal to 1 except for k which is >= 1.
+ - let OUT be the output of add(1,OUT,1,BIAS)
+ - let f(gradient_input,BIAS) == dot(gradient_input,OUT)
+ - Then this function computes the gradient of f() with respect to BIAS and
+ assigns it to grad.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+
+ void batch_normalize_gradient(
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+
+ void batch_normalize_conv_gradient(
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ class tensor_conv
+ {
+ public:
+ tensor_conv(const tensor_conv&) = delete;
+ tensor_conv& operator=(const tensor_conv&) = delete;
+
+ tensor_conv();
+
+ void clear(
+ );
+
+ ~tensor_conv (
+ );
+
+ void operator() (
+ const bool add_to_output,
+ tensor& output,
+ const tensor& data,
+ const tensor& filters
+ );
+
+ void operator() (
+ const bool add_to_output,
+ resizable_tensor& output,
+ const tensor& data,
+ const tensor& filters
+ );
+
+ void get_gradient_for_data (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& filters,
+ tensor& data_gradient
+ );
+
+ void get_gradient_for_filters (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& data,
+ tensor& filters_gradient
+ );
+
+ void setup(
+ const tensor& data,
+ const tensor& filters,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ );
+
+ private:
+
+ // These variables record the type of data given to the last call to setup().
+ int stride_y;
+ int stride_x;
+ int padding_y;
+ int padding_x;
+ long data_num_samples, data_k, data_nr, data_nc;
+ long filters_num_samples, filters_k, filters_nr, filters_nc;
+
+
+ void* filter_handle;
+ void* conv_handle;
+
+ // dimensions of the output tensor from operator()
+ int out_num_samples;
+ int out_k;
+ int out_nr;
+ int out_nc;
+
+ int forward_algo;
+ int backward_data_algo;
+ int backward_filters_algo;
+
+ size_t forward_workspace_size_in_bytes;
+ size_t backward_data_workspace_size_in_bytes;
+ size_t backward_filters_workspace_size_in_bytes;
+ std::shared_ptr<resizable_cuda_buffer> workspace;
+ cuda_data_void_ptr forward_workspace;
+ cuda_data_void_ptr backward_data_workspace;
+ cuda_data_void_ptr backward_filters_workspace;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ class pooling
+ {
+ public:
+
+ pooling(const pooling&) = delete;
+ pooling& operator=(const pooling&) = delete;
+
+ pooling (
+ );
+
+ ~pooling(
+ );
+
+ void clear(
+ );
+
+ void setup_max_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ );
+
+ void setup_avg_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ );
+
+ bool does_max_pooling(
+ ) const { return do_max_pooling; }
+
+ void operator() (
+ resizable_tensor& dest,
+ const tensor& src
+ );
+
+ void get_gradient(
+ const tensor& gradient_input,
+ const tensor& dest,
+ const tensor& src,
+ tensor& grad
+ );
+
+ private:
+
+ void setup(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x,
+ int pooling_mode
+ );
+
+ void* handle;
+ int window_height;
+ int window_width;
+ int stride_y;
+ int stride_x;
+ int padding_y;
+ int padding_x;
+ bool do_max_pooling;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - Note that the softmax function is a vector valued function:
+ s(x) == exp(x)/sum(exp(x))
+ - Computes the softmax function on src and writes the results to dest. The
+ softmax is computed per spatial location across the different channels at
+ each location. That is, softmax() outputs a new tensor, #dest, where
+ each of the spatial locations in dest (i.e. image idx, row idx, and
+ column idx) contains the output of s() evaluated over the channel values
+ at each location.
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ - is_same_object(grad, dest)==false
+ ensures
+ - We interpret dest as the output of softmax(dest,SRC) for some SRC tensor.
+ Then let f(SRC) == dot(gradient_input,dest) Then this function computes
+ the gradient of f() with respect to SRC and assigns it to grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ );
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+
+ // ------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == 1/(1+std::exp(-src.host()[i]))
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ - is_same_object(grad,dest) == false
+ ensures
+ - Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest)
+ - Then this function computes the gradient of f() with respect to SRC and
+ assigns it to grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == std::max(0,src.host()[i])
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ - is_same_object(grad,dest) == false
+ ensures
+ - Recalling that dest is the output of relu(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest)
+ - Then this function computes the gradient of f() with respect to SRC and
+ assigns it to grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == std::tanh(src.host()[i])
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ - is_same_object(grad,dest) == false
+ ensures
+ - Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest)
+ - Then this function computes the gradient of f() with respect to SRC and
+ assigns it to grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuDNN_H_
+
diff --git a/ml/dlib/dlib/dnn/curand_dlibapi.cpp b/ml/dlib/dlib/dnn/curand_dlibapi.cpp
new file mode 100644
index 000000000..67828e664
--- /dev/null
+++ b/ml/dlib/dlib/dnn/curand_dlibapi.cpp
@@ -0,0 +1,113 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuRAND_CPP_
+#define DLIB_DNN_CuRAND_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "curand_dlibapi.h"
+#include <curand.h>
+#include "../string.h"
+
+static const char* curand_get_error_string(curandStatus_t s)
+{
+ switch(s)
+ {
+ case CURAND_STATUS_NOT_INITIALIZED:
+ return "CUDA Runtime API initialization failed.";
+ case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
+ return "The requested length must be a multiple of two.";
+ default:
+ return "A call to cuRAND failed";
+ }
+}
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CURAND(call) \
+do{ \
+ const curandStatus_t error = call; \
+ if (error != CURAND_STATUS_SUCCESS) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << curand_get_error_string(error);\
+ throw dlib::curand_error(sout.str()); \
+ } \
+}while(false)
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ curand_generator::
+ curand_generator(
+ unsigned long long seed
+ ) : handle(nullptr)
+ {
+ curandGenerator_t gen;
+ CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
+ handle = gen;
+
+ CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
+ }
+
+ curand_generator::
+ ~curand_generator()
+ {
+ if (handle)
+ {
+ curandDestroyGenerator((curandGenerator_t)handle);
+ }
+ }
+
+ void curand_generator::
+ fill_gaussian (
+ tensor& data,
+ float mean,
+ float stddev
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle,
+ data.device(),
+ data.size(),
+ mean,
+ stddev));
+ }
+
+ void curand_generator::
+ fill_uniform (
+ tensor& data
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
+ }
+
+ void curand_generator::
+ fill (
+ cuda_data_ptr<unsigned int>& data
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerate((curandGenerator_t)handle, data, data.size()));
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuRAND_CPP_
+
diff --git a/ml/dlib/dlib/dnn/curand_dlibapi.h b/ml/dlib/dlib/dnn/curand_dlibapi.h
new file mode 100644
index 000000000..cd51fecee
--- /dev/null
+++ b/ml/dlib/dlib/dnn/curand_dlibapi.h
@@ -0,0 +1,75 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuRAND_H_
+#define DLIB_DNN_CuRAND_H_
+
+#ifdef DLIB_USE_CUDA
+
+#include "tensor.h"
+#include "cuda_errors.h"
+#include "cuda_data_ptr.h"
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ class curand_generator
+ {
+ public:
+ // not copyable
+ curand_generator(const curand_generator&) = delete;
+ curand_generator& operator=(const curand_generator&) = delete;
+
+ curand_generator() : curand_generator(0) {}
+ curand_generator(unsigned long long seed);
+ ~curand_generator();
+
+ void fill (
+ cuda_data_ptr<unsigned int>& data
+ );
+ /*!
+ ensures
+ - Fills data with random 32-bit unsigned integers.
+ !*/
+
+ void fill_gaussian (
+ tensor& data,
+ float mean = 0,
+ float stddev = 1
+ );
+ /*!
+ requires
+ - data.size()%2 == 0
+ - stddev >= 0
+ ensures
+ - Fills data with random numbers drawn from a Gaussian distribution
+ with the given mean and standard deviation.
+ !*/
+
+ void fill_uniform (
+ tensor& data
+ );
+ /*!
+ ensures
+ - Fills data with uniform random numbers in the range (0.0, 1.0].
+ !*/
+
+ private:
+
+ void* handle;
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuRAND_H_
+
+
+
diff --git a/ml/dlib/dlib/dnn/cusolver_dlibapi.cu b/ml/dlib/dlib/dnn/cusolver_dlibapi.cu
new file mode 100644
index 000000000..942613134
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cusolver_dlibapi.cu
@@ -0,0 +1,204 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuSOLVER_CU_
+#define DLIB_DNN_CuSOLVER_CU_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cusolver_dlibapi.h"
+#include <cublas_v2.h>
+#include <cusolverDn.h>
+#include "cuda_utils.h"
+
+// ----------------------------------------------------------------------------------------
+
+static const char* cusolver_get_error_string(cusolverStatus_t s)
+{
+ switch(s)
+ {
+ case CUSOLVER_STATUS_NOT_INITIALIZED:
+ return "CUDA Runtime API initialization failed.";
+ case CUSOLVER_STATUS_ALLOC_FAILED:
+ return "CUDA Resources could not be allocated.";
+ default:
+ return "A call to cuSolver failed";
+ }
+}
+
+// Check the return value of a call to the cuSolver runtime for an error condition.
+#define CHECK_CUSOLVER(call) \
+do{ \
+ const cusolverStatus_t error = call; \
+ if (error != CUSOLVER_STATUS_SUCCESS) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << cusolver_get_error_string(error);\
+ throw dlib::cusolver_error(sout.str()); \
+ } \
+}while(false)
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ class cusolver_context
+ {
+ public:
+ // not copyable
+ cusolver_context(const cusolver_context&) = delete;
+ cusolver_context& operator=(const cusolver_context&) = delete;
+
+ cusolver_context()
+ {
+ handles.resize(16);
+ }
+ ~cusolver_context()
+ {
+ for (auto h : handles)
+ {
+ if (h)
+ cusolverDnDestroy(h);
+ }
+ }
+
+ cusolverDnHandle_t get_handle (
+ )
+ {
+ int new_device_id;
+ CHECK_CUDA(cudaGetDevice(&new_device_id));
+ // make room for more devices if needed
+ if (new_device_id >= (long)handles.size())
+ handles.resize(new_device_id+16);
+
+ // If we don't have a handle already for this device then make one
+ if (!handles[new_device_id])
+ CHECK_CUSOLVER(cusolverDnCreate(&handles[new_device_id]));
+
+ // Finally, return the handle for the current device
+ return handles[new_device_id];
+ }
+
+ private:
+
+ std::vector<cusolverDnHandle_t> handles;
+ };
+
+ static cusolverDnHandle_t context()
+ {
+ thread_local cusolver_context c;
+ return c.get_handle();
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ __global__ void _cuda_set_to_identity_matrix(float* m, size_t nr)
+ {
+ for (auto j : grid_stride_range(0, nr*nr))
+ {
+ if (j%(nr+1) == 0)
+ m[j] = 1;
+ else
+ m[j] = 0;
+ }
+ }
+
+ void set_to_identity_matrix (
+ tensor& m
+ )
+ {
+ DLIB_CASSERT(m.size() == m.num_samples()*m.num_samples());
+ launch_kernel(_cuda_set_to_identity_matrix, max_jobs(m.size()), m.device(), m.num_samples());
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inv::~inv()
+ {
+ sync_if_needed();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void inv::
+ operator() (
+ const tensor& m_,
+ resizable_tensor& out
+ )
+ {
+ DLIB_CASSERT(m_.size() == m_.num_samples()*m_.num_samples(), "Input matrix must be square if you want to invert it.");
+ m = m_;
+
+ out.copy_size(m);
+ set_to_identity_matrix(out);
+
+ const int nc = m.num_samples();
+ int Lwork;
+ CHECK_CUSOLVER(cusolverDnSgetrf_bufferSize(context(), nc , nc, m.device(), nc, &Lwork));
+
+ if (Lwork > (int)workspace.size())
+ {
+ sync_if_needed();
+ workspace = cuda_data_ptr<float>(Lwork);
+ }
+ if (nc > (int)Ipiv.size())
+ {
+ sync_if_needed();
+ Ipiv = cuda_data_ptr<int>(nc);
+ }
+ if (info.size() != 1)
+ {
+ info = cuda_data_ptr<int>(1);
+ }
+
+ CHECK_CUSOLVER(cusolverDnSgetrf(context(), nc, nc, m.device(), nc, workspace, Ipiv, info));
+ CHECK_CUSOLVER(cusolverDnSgetrs(context(), CUBLAS_OP_N, nc, nc, m.device(), nc, Ipiv, out.device(), nc, info));
+ did_work_lately = true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ int inv::
+ get_last_status(
+ )
+ {
+ std::vector<int> linfo;
+ memcpy(linfo, info);
+ if (linfo.size() != 0)
+ return linfo[0];
+ else
+ return 0;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void inv::
+ sync_if_needed()
+ {
+ if (did_work_lately)
+ {
+ did_work_lately = false;
+ // make sure we wait until any previous kernel launches have finished
+ // before we do something like deallocate the GPU memory.
+ cudaDeviceSynchronize();
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuSOLVER_CU_
+
+
diff --git a/ml/dlib/dlib/dnn/cusolver_dlibapi.h b/ml/dlib/dlib/dnn/cusolver_dlibapi.h
new file mode 100644
index 000000000..e5c77c151
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cusolver_dlibapi.h
@@ -0,0 +1,75 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuSOLVER_H_
+#define DLIB_DNN_CuSOLVER_H_
+
+#ifdef DLIB_USE_CUDA
+
+#include "tensor.h"
+#include "cuda_errors.h"
+#include "cuda_data_ptr.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ class inv : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a functor for doing matrix inversion on the GPU. The only
+ reason it's an object is to avoid the reallocation of some GPU memory
+ blocks if you want to do a bunch of matrix inversions in a row.
+ !*/
+
+ public:
+
+ inv() = default;
+ ~inv();
+
+ void operator() (
+ const tensor& m,
+ resizable_tensor& out
+ );
+ /*!
+ requires
+ - m.size() == m.num_samples()*m.num_samples()
+ (i.e. mat(m) must be a square matrix)
+ ensures
+ - out == inv(mat(m));
+ !*/
+
+ int get_last_status(
+ );
+ /*!
+ ensures
+ - returns 0 if the last matrix inversion was successful and != 0
+ otherwise.
+ !*/
+
+ private:
+
+ void sync_if_needed();
+
+ bool did_work_lately = false;
+ resizable_tensor m;
+ cuda_data_ptr<float> workspace;
+ cuda_data_ptr<int> Ipiv;
+ cuda_data_ptr<int> info;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuSOLVER_H_
+
+
+
diff --git a/ml/dlib/dlib/dnn/gpu_data.cpp b/ml/dlib/dlib/dnn/gpu_data.cpp
new file mode 100644
index 000000000..6e7cec6be
--- /dev/null
+++ b/ml/dlib/dlib/dnn/gpu_data.cpp
@@ -0,0 +1,228 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GPU_DaTA_CPP_
+#define DLIB_GPU_DaTA_CPP_
+
+// Only things that require CUDA are declared in this cpp file. Everything else is in the
+// gpu_data.h header so that it can operate as "header-only" code when using just the CPU.
+#ifdef DLIB_USE_CUDA
+
+#include "gpu_data.h"
+#include <iostream>
+#include "cuda_utils.h"
+#include <cstring>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void memcpy (
+ gpu_data& dest,
+ const gpu_data& src
+ )
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+ if (src.size() == 0 || &dest == &src)
+ return;
+
+ memcpy(dest,0, src, 0, src.size());
+ }
+
+ void memcpy (
+ gpu_data& dest,
+ size_t dest_offset,
+ const gpu_data& src,
+ size_t src_offset,
+ size_t num
+ )
+ {
+ DLIB_CASSERT(dest_offset + num <= dest.size());
+ DLIB_CASSERT(src_offset + num <= src.size());
+ if (num == 0)
+ return;
+
+ // if there is aliasing
+ if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num)
+ {
+ // if they perfectly alias each other then there is nothing to do
+ if (dest_offset == src_offset)
+ return;
+ else
+ std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
+ }
+ else
+ {
+ // if we write to the entire thing then we can use device_write_only()
+ if (dest_offset == 0 && num == dest.size())
+ {
+ // copy the memory efficiently based on which copy is current in each object.
+ if (src.device_ready())
+ CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice));
+ else
+ CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice));
+ }
+ else
+ {
+ // copy the memory efficiently based on which copy is current in each object.
+ if (dest.device_ready() && src.device_ready())
+ CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice));
+ else if (!dest.device_ready() && src.device_ready())
+ CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToHost));
+ else if (dest.device_ready() && !src.device_ready())
+ CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice));
+ else
+ CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToHost));
+ }
+ }
+ }
+// ----------------------------------------------------------------------------------------
+
+ void gpu_data::
+ wait_for_transfer_to_finish() const
+ {
+ if (have_active_transfer)
+ {
+ CHECK_CUDA(cudaStreamSynchronize((cudaStream_t)cuda_stream.get()));
+ have_active_transfer = false;
+ // Check for errors. These calls to cudaGetLastError() are what help us find
+ // out if our kernel launches have been failing.
+ CHECK_CUDA(cudaGetLastError());
+ }
+ }
+
+ void gpu_data::
+ copy_to_device() const
+ {
+ // We want transfers to the device to always be concurrent with any device
+ // computation. So we use our non-default stream to do the transfer.
+ async_copy_to_device();
+ wait_for_transfer_to_finish();
+ }
+
+ void gpu_data::
+ copy_to_host() const
+ {
+ if (!host_current)
+ {
+ wait_for_transfer_to_finish();
+ CHECK_CUDA(cudaMemcpy(data_host.get(), data_device.get(), data_size*sizeof(float), cudaMemcpyDeviceToHost));
+ host_current = true;
+ // At this point we know our RAM block isn't in use because cudaMemcpy()
+ // implicitly syncs with the device.
+ device_in_use = false;
+ // Check for errors. These calls to cudaGetLastError() are what help us find
+ // out if our kernel launches have been failing.
+ CHECK_CUDA(cudaGetLastError());
+ }
+ }
+
+ void gpu_data::
+ async_copy_to_device() const
+ {
+ if (!device_current)
+ {
+ if (device_in_use)
+ {
+ // Wait for any possible CUDA kernels that might be using our memory block to
+ // complete before we overwrite the memory.
+ CHECK_CUDA(cudaStreamSynchronize(0));
+ device_in_use = false;
+ }
+ CHECK_CUDA(cudaMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice, (cudaStream_t)cuda_stream.get()));
+ have_active_transfer = true;
+ device_current = true;
+ }
+ }
+
+ void gpu_data::
+ set_size(
+ size_t new_size
+ )
+ {
+ if (new_size == 0)
+ {
+ if (device_in_use)
+ {
+ // Wait for any possible CUDA kernels that might be using our memory block to
+ // complete before we free the memory.
+ CHECK_CUDA(cudaStreamSynchronize(0));
+ device_in_use = false;
+ }
+ wait_for_transfer_to_finish();
+ data_size = 0;
+ host_current = true;
+ device_current = true;
+ device_in_use = false;
+ data_host.reset();
+ data_device.reset();
+ }
+ else if (new_size != data_size)
+ {
+ if (device_in_use)
+ {
+ // Wait for any possible CUDA kernels that might be using our memory block to
+ // complete before we free the memory.
+ CHECK_CUDA(cudaStreamSynchronize(0));
+ device_in_use = false;
+ }
+ wait_for_transfer_to_finish();
+ data_size = new_size;
+ host_current = true;
+ device_current = true;
+ device_in_use = false;
+
+ try
+ {
+ CHECK_CUDA(cudaGetDevice(&the_device_id));
+
+ // free memory blocks before we allocate new ones.
+ data_host.reset();
+ data_device.reset();
+
+ void* data;
+ CHECK_CUDA(cudaMallocHost(&data, new_size*sizeof(float)));
+ // Note that we don't throw exceptions since the free calls are invariably
+ // called in destructors. They also shouldn't fail anyway unless someone
+ // is resetting the GPU card in the middle of their program.
+ data_host.reset((float*)data, [](float* ptr){
+ auto err = cudaFreeHost(ptr);
+ if(err!=cudaSuccess)
+ std::cerr << "cudaFreeHost() failed. Reason: " << cudaGetErrorString(err) << std::endl;
+ });
+
+ CHECK_CUDA(cudaMalloc(&data, new_size*sizeof(float)));
+ data_device.reset((float*)data, [](float* ptr){
+ auto err = cudaFree(ptr);
+ if(err!=cudaSuccess)
+ std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
+ });
+
+ if (!cuda_stream)
+ {
+ cudaStream_t cstream;
+ CHECK_CUDA(cudaStreamCreateWithFlags(&cstream, cudaStreamNonBlocking));
+ cuda_stream.reset(cstream, [](void* ptr){
+ auto err = cudaStreamDestroy((cudaStream_t)ptr);
+ if(err!=cudaSuccess)
+ std::cerr << "cudaStreamDestroy() failed. Reason: " << cudaGetErrorString(err) << std::endl;
+ });
+ }
+
+ }
+ catch(...)
+ {
+ set_size(0);
+ throw;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_GPU_DaTA_CPP_
+
diff --git a/ml/dlib/dlib/dnn/gpu_data.h b/ml/dlib/dlib/dnn/gpu_data.h
new file mode 100644
index 000000000..022a05f71
--- /dev/null
+++ b/ml/dlib/dlib/dnn/gpu_data.h
@@ -0,0 +1,266 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GPU_DaTA_H_
+#define DLIB_GPU_DaTA_H_
+
+#include "gpu_data_abstract.h"
+#include <memory>
+#include <cstring>
+#include "cuda_errors.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class gpu_data
+ {
+ /*!
+ CONVENTION
+ - if (size() != 0) then
+ - data_host == a pointer to size() floats in CPU memory.
+ - if (data_device) then
+ - data_device == a pointer to size() floats in device memory.
+
+ - if (there might be an active async transfer from host to device) then
+ - have_active_transfer == true
+
+ - We use the host_current and device_current bools to keep track of which
+ copy of the data (or both) are most current. e.g. if the CPU has
+ modified the data and it hasn't been copied to the device yet then
+ host_current==true and device_current==false.
+
+ Similarly, we use device_in_use==true to indicate that device() has been
+ called and no operation to wait for all CUDA kernel completion has been
+ executed. So if device_in_use==true then there might be a CUDA kernel
+ executing that is using the device memory block contained in this object.
+
+ !*/
+ public:
+
+ gpu_data(
+ ) : data_size(0), host_current(true), device_current(true),have_active_transfer(false),device_in_use(false), the_device_id(0)
+ {
+ }
+
+ // Not copyable
+ gpu_data(const gpu_data&) = delete;
+ gpu_data& operator=(const gpu_data&) = delete;
+
+ // but is movable
+ gpu_data(gpu_data&& item) : gpu_data() { swap(item); }
+ gpu_data& operator=(gpu_data&& item) { swap(item); return *this; }
+
+ int device_id() const { return the_device_id; }
+
+#ifdef DLIB_USE_CUDA
+ void async_copy_to_device() const;
+ void set_size(size_t new_size);
+#else
+ // Note that calls to host() or device() will block until any async transfers are complete.
+ void async_copy_to_device() const{}
+
+ void set_size(size_t new_size)
+ {
+ if (new_size == 0)
+ {
+ data_size = 0;
+ host_current = true;
+ device_current = true;
+ device_in_use = false;
+ data_host.reset();
+ data_device.reset();
+ }
+ else if (new_size != data_size)
+ {
+ data_size = new_size;
+ host_current = true;
+ device_current = true;
+ device_in_use = false;
+ data_host.reset(new float[new_size], std::default_delete<float[]>());
+ data_device.reset();
+ }
+ }
+#endif
+
+ const float* host() const
+ {
+ copy_to_host();
+ return data_host.get();
+ }
+
+ float* host()
+ {
+ copy_to_host();
+ device_current = false;
+ return data_host.get();
+ }
+
+ float* host_write_only()
+ {
+ host_current = true;
+ device_current = false;
+ return data_host.get();
+ }
+
+ const float* device() const
+ {
+#ifndef DLIB_USE_CUDA
+ DLIB_CASSERT(false, "CUDA NOT ENABLED");
+#endif
+ copy_to_device();
+ device_in_use = true;
+ return data_device.get();
+ }
+
+ float* device()
+ {
+#ifndef DLIB_USE_CUDA
+ DLIB_CASSERT(false, "CUDA NOT ENABLED");
+#endif
+ copy_to_device();
+ host_current = false;
+ device_in_use = true;
+ return data_device.get();
+ }
+
+ float* device_write_only()
+ {
+#ifndef DLIB_USE_CUDA
+ DLIB_CASSERT(false, "CUDA NOT ENABLED");
+#endif
+ wait_for_transfer_to_finish();
+ host_current = false;
+ device_current = true;
+ device_in_use = true;
+ return data_device.get();
+ }
+
+ bool host_ready (
+ ) const { return host_current; }
+
+ bool device_ready (
+ ) const { return device_current && !have_active_transfer; }
+
+ size_t size() const { return data_size; }
+
+ void swap (gpu_data& item)
+ {
+ std::swap(data_size, item.data_size);
+ std::swap(host_current, item.host_current);
+ std::swap(device_current, item.device_current);
+ std::swap(have_active_transfer, item.have_active_transfer);
+ std::swap(data_host, item.data_host);
+ std::swap(data_device, item.data_device);
+ std::swap(cuda_stream, item.cuda_stream);
+ std::swap(the_device_id, item.the_device_id);
+ }
+
+ private:
+
+#ifdef DLIB_USE_CUDA
+ void copy_to_device() const;
+ void copy_to_host() const;
+ void wait_for_transfer_to_finish() const;
+#else
+ void copy_to_device() const{}
+ void copy_to_host() const{}
+ void wait_for_transfer_to_finish() const{}
+#endif
+
+
+ size_t data_size;
+ mutable bool host_current;
+ mutable bool device_current;
+ mutable bool have_active_transfer;
+ mutable bool device_in_use;
+
+ std::shared_ptr<float> data_host;
+ std::shared_ptr<float> data_device;
+ std::shared_ptr<void> cuda_stream;
+ int the_device_id;
+ };
+
+ inline void serialize(const gpu_data& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.size(), out);
+ auto data = item.host();
+ for (size_t i = 0; i < item.size(); ++i)
+ serialize(data[i], out);
+ }
+
+ inline void deserialize(gpu_data& item, std::istream& in)
+ {
+ int version;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::gpu_data.");
+ size_t s;
+ deserialize(s, in);
+ item.set_size(s);
+ auto data = item.host();
+ for (size_t i = 0; i < item.size(); ++i)
+ deserialize(data[i], in);
+ }
+
+#ifdef DLIB_USE_CUDA
+ void memcpy (gpu_data& dest, const gpu_data& src);
+
+ void memcpy (
+ gpu_data& dest,
+ size_t dest_offset,
+ const gpu_data& src,
+ size_t src_offset,
+ size_t num
+ );
+
+#else
+
+ inline void memcpy (gpu_data& dest, const gpu_data& src)
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+ if (src.size() == 0 || &dest == &src)
+ return;
+ std::memcpy(dest.host_write_only(), src.host(), sizeof(float)*src.size());
+ }
+
+ inline void memcpy (
+ gpu_data& dest,
+ size_t dest_offset,
+ const gpu_data& src,
+ size_t src_offset,
+ size_t num
+ )
+ {
+ DLIB_CASSERT(dest_offset + num <= dest.size());
+ DLIB_CASSERT(src_offset + num <= src.size());
+ if (num == 0)
+ return;
+ if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num)
+ {
+ // if they perfectly alias each other then there is nothing to do
+ if (dest_offset == src_offset)
+ return;
+ else
+ std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
+ }
+ else
+ {
+ // if we write to the entire thing then we can use host_write_only()
+ if (dest_offset == 0 && num == dest.size())
+ std::memcpy(dest.host_write_only(), src.host()+src_offset, sizeof(float)*num);
+ else
+ std::memcpy(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
+ }
+ }
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GPU_DaTA_H_
+
diff --git a/ml/dlib/dlib/dnn/gpu_data_abstract.h b/ml/dlib/dlib/dnn/gpu_data_abstract.h
new file mode 100644
index 000000000..f2423dee1
--- /dev/null
+++ b/ml/dlib/dlib/dnn/gpu_data_abstract.h
@@ -0,0 +1,266 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GPU_DaTA_ABSTRACT_H_
+#ifdef DLIB_GPU_DaTA_ABSTRACT_H_
+
+#include "cuda_errors.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class gpu_data
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a block of size() floats, all stored contiguously in memory.
+ Importantly, it keeps two copies of the floats, one on the host CPU side
+ and another on the GPU device side. It automatically performs the necessary
+ host/device transfers to keep these two copies of the data in sync.
+
+ All transfers to the device happen asynchronously with respect to the
+ default CUDA stream so that CUDA kernel computations can overlap with data
+ transfers. However, any transfers from the device to the host happen
+ synchronously in the default CUDA stream. Therefore, you should perform
+ all your CUDA kernel launches on the default stream so that transfers back
+ to the host do not happen before the relevant computations have completed.
+
+ If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all.
+ Instead, it will simply store one host side memory block of floats.
+
+ THREAD SAFETY
+ Instances of this object are not thread-safe. So don't touch one from
+ multiple threads at the same time.
+ !*/
+ public:
+
+ gpu_data(
+ );
+ /*!
+ ensures
+ - #size() == 0
+ - #host() == nullptr
+ - #device() == nullptr
+ - #host_ready() == true
+ - #device_ready() == true
+ - #device_id() == 0
+ !*/
+
+ // This object is not copyable, however, it is movable.
+ gpu_data(const gpu_data&) = delete;
+ gpu_data& operator=(const gpu_data&) = delete;
+ gpu_data(gpu_data&& item);
+ gpu_data& operator=(gpu_data&& item);
+
+ int device_id(
+ ) const;
+ /*!
+ ensures
+ - returns the ID of the CUDA device that allocated this memory. I.e. the
+ number returned by cudaGetDevice() when the memory was allocated.
+ - If CUDA is not being used then this function always returns 0.
+ !*/
+
+ void async_copy_to_device(
+ );
+ /*!
+ ensures
+ - if (!device_ready()) then
+ - Begins asynchronously copying host data to the device once it is safe
+ to do so. I.e. This function will wait until any previously
+ scheduled CUDA kernels, which are using the device() memory block,
+ have completed before transferring the new data to the device.
+ - A call to device() that happens before the transfer completes will
+ block until the transfer is complete. That is, it is safe to call
+ async_copy_to_device() and then immediately call device().
+ !*/
+
+ void set_size(
+ size_t new_size
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ !*/
+
+ bool host_ready (
+ ) const;
+ /*!
+ ensures
+ - returns true if and only if the host's copy of the data is current. The
+ host's data is current if there aren't any modifications to the data
+ which were made on the device side that have yet to be copied to the
+ host.
+ !*/
+
+ bool device_ready (
+ ) const;
+ /*!
+ ensures
+ - returns true if and only if the device's copy of the data is current.
+ The device's data is current if there aren't any modifications to the
+ data which were made on the host side that have yet to be copied to the
+ device.
+ !*/
+
+ const float* host(
+ ) const;
+ /*!
+ ensures
+ - returns a pointer to the host memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (!host_ready()) then
+ - copies the data from the device to the host, while this is happening
+ the call to host() blocks.
+ - #host_ready() == true
+ !*/
+
+ float* host(
+ );
+ /*!
+ ensures
+ - returns a pointer to the host memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (!host_ready()) then
+ - copies the data from the device to the host, while this is happening
+ the call to host() blocks.
+ - #host_ready() == true
+ - #device_ready() == false
+ I.e. Marks the device side data as out of date so that the next call to
+ device() will perform a host to device transfer. If you want to begin
+ the transfer immediately then you can call async_copy_to_device() after
+ calling host().
+ !*/
+
+ float* host_write_only(
+ );
+ /*!
+ ensures
+ - This function returns the same pointer as host(), except that it never
+ performs a device to host memory copy. Instead, it immediately marks the
+ device side data as out of date, effectively discarding it. Therefore,
+ the values in the data pointed to by host_write_only() are undefined and
+ you should only call host_write_only() if you are going to assign to
+ every memory location in the returned memory block.
+ - #host_ready() == true
+ - #device_ready() == false
+ !*/
+
+ const float* device(
+ ) const;
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - returns a pointer to the device memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (!device_ready()) then
+ - copies the data from the host to the device, while this is happening
+ the call to device() blocks.
+ - #device_ready() == true
+ !*/
+
+ float* device(
+ );
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - returns a pointer to the device memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (!device_ready()) then
+ - copies the data from the host to the device, while this is happening
+ the call to device() blocks.
+ - #host_ready() == false
+ - #device_ready() == true
+ !*/
+
+ float* device_write_only(
+ );
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - This function returns the same pointer as device(), except that it never
+ performs a host to device memory copy. Instead, it immediately marks the
+ host side data as out of date, effectively discarding it. Therefore, the
+ values in the data pointed to by device_write_only() are undefined and
+ you should only call device_write_only() if you are going to assign to
+ every memory location in the returned memory block.
+ - #host_ready() == false
+ - #device_ready() == true
+ !*/
+
+
+ size_t size(
+ ) const;
+ /*!
+ ensures
+ - returns the number of floats contained in this object.
+ !*/
+
+ void swap (
+ gpu_data& item
+ );
+ /*!
+ ensures
+ - swaps the state of *this and item
+ !*/
+
+ };
+
+ void serialize(const gpu_data& item, std::ostream& out);
+ void deserialize(gpu_data& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ void memcpy (
+ gpu_data& dest,
+ const gpu_data& src
+ );
+ /*!
+ requires
+ - dest.size() == src.size()
+ ensures
+ - Copies the data in src to dest. If the device data is current (i.e.
+ device_ready()==true) on both src and dest then the copy will happen entirely
+ on the device side.
+ - It doesn't matter what GPU device is selected by cudaSetDevice(). You can
+ always copy gpu_data objects to and from each other regardless.
+ - This function blocks until the copy has completed.
+ !*/
+
+ void memcpy (
+ gpu_data& dest,
+ size_t dest_offset,
+ const gpu_data& src,
+ size_t src_offset,
+ size_t num
+ );
+ /*!
+ requires
+ - dest_offset + num <= dest.size()
+ - src_offset + num <= src.size()
+ ensures
+ - Copies the data in src to dest, but only copies data in the range
+ [src.host()+src_offset, src.host()+src_offset+num) to
+ [dest.host()+dest_offset, dest.host()+dest_offset+num). Therefore, it is
+ just like the above memcpy() except that you can specify some subset of data
+ in a gpu_data object to be copied.
+ - Like the above version of memcpy(), the copy will happen in the most
+ efficient way, automatically using the appropriate type of host/device
+ transfers based on where data is currently resident.
+ - It doesn't matter what GPU device is selected by cudaSetDevice(). You can
+ always copy gpu_data objects to and from each other regardless.
+ - This function blocks until the copy has completed.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GPU_DaTA_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/input.h b/ml/dlib/dlib/dnn/input.h
new file mode 100644
index 000000000..3b5c954e6
--- /dev/null
+++ b/ml/dlib/dlib/dnn/input.h
@@ -0,0 +1,808 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_INPUT_H_
+#define DLIB_DNn_INPUT_H_
+
+#include "input_abstract.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../image_processing.h"
+#include <sstream>
+#include <array>
+#include "tensor_tools.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class input
+ {
+ const static bool always_false = sizeof(T)!=sizeof(T);
+ static_assert(always_false, "Unsupported type given to input<>. input<> only supports "
+ "dlib::matrix and dlib::array2d objects.");
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t NR, size_t NC=NR>
+ class input_rgb_image_sized;
+
+ class input_rgb_image
+ {
+ public:
+ typedef matrix<rgb_pixel> input_type;
+
+ input_rgb_image (
+ ) :
+ avg_red(122.782),
+ avg_green(117.001),
+ avg_blue(104.298)
+ {
+ }
+
+ input_rgb_image (
+ float avg_red_,
+ float avg_green_,
+ float avg_blue_
+ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_)
+ {}
+
+ template <size_t NR, size_t NC>
+ inline input_rgb_image (
+ const input_rgb_image_sized<NR,NC>& item
+ );
+
+ float get_avg_red() const { return avg_red; }
+ float get_avg_green() const { return avg_green; }
+ float get_avg_blue() const { return avg_blue; }
+
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ const auto nr = ibegin->nr();
+ const auto nc = ibegin->nc();
+ // make sure all the input matrices have the same dimensions
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ DLIB_CASSERT(i->nr()==nr && i->nc()==nc,
+ "\t input_rgb_image::to_tensor()"
+ << "\n\t All matrices given to to_tensor() must have the same dimensions."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t i->nr(): " << i->nr()
+ << "\n\t i->nc(): " << i->nc()
+ );
+ }
+
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), 3, nr, nc);
+
+
+ const size_t offset = nr*nc;
+ auto ptr = data.host();
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ rgb_pixel temp = (*i)(r,c);
+ auto p = ptr++;
+ *p = (temp.red-avg_red)/256.0;
+ p += offset;
+ *p = (temp.green-avg_green)/256.0;
+ p += offset;
+ *p = (temp.blue-avg_blue)/256.0;
+ p += offset;
+ }
+ }
+ ptr += offset*(data.k()-1);
+ }
+
+ }
+
+ friend void serialize(const input_rgb_image& item, std::ostream& out)
+ {
+ serialize("input_rgb_image", out);
+ serialize(item.avg_red, out);
+ serialize(item.avg_green, out);
+ serialize(item.avg_blue, out);
+ }
+
+ friend void deserialize(input_rgb_image& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input_rgb_image" && version != "input_rgb_image_sized")
+ throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image.");
+ deserialize(item.avg_red, in);
+ deserialize(item.avg_green, in);
+ deserialize(item.avg_blue, in);
+
+ // read and discard the sizes if this was really a sized input layer.
+ if (version == "input_rgb_image_sized")
+ {
+ size_t nr, nc;
+ deserialize(nr, in);
+ deserialize(nc, in);
+ }
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const input_rgb_image& item)
+ {
+ out << "input_rgb_image("<<item.avg_red<<","<<item.avg_green<<","<<item.avg_blue<<")";
+ return out;
+ }
+
+ friend void to_xml(const input_rgb_image& item, std::ostream& out)
+ {
+ out << "<input_rgb_image r='"<<item.avg_red<<"' g='"<<item.avg_green<<"' b='"<<item.avg_blue<<"'/>";
+ }
+
+ private:
+ float avg_red;
+ float avg_green;
+ float avg_blue;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t NR, size_t NC>
+ class input_rgb_image_sized
+ {
+ public:
+ static_assert(NR != 0 && NC != 0, "The input image can't be empty.");
+
+ typedef matrix<rgb_pixel> input_type;
+
+ input_rgb_image_sized (
+ ) :
+ avg_red(122.782),
+ avg_green(117.001),
+ avg_blue(104.298)
+ {
+ }
+
+ input_rgb_image_sized (
+ const input_rgb_image& item
+ ) : avg_red(item.get_avg_red()),
+ avg_green(item.get_avg_green()),
+ avg_blue(item.get_avg_blue())
+ {}
+
+ input_rgb_image_sized (
+ float avg_red_,
+ float avg_green_,
+ float avg_blue_
+ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_)
+ {}
+
+ float get_avg_red() const { return avg_red; }
+ float get_avg_green() const { return avg_green; }
+ float get_avg_blue() const { return avg_blue; }
+
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ // make sure all input images have the correct size
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ DLIB_CASSERT(i->nr()==NR && i->nc()==NC,
+ "\t input_rgb_image_sized::to_tensor()"
+ << "\n\t All input images must have "<<NR<<" rows and "<<NC<< " columns, but we got one with "<<i->nr()<<" rows and "<<i->nc()<<" columns."
+ );
+ }
+
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), 3, NR, NC);
+
+
+ const size_t offset = NR*NC;
+ auto ptr = data.host();
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (size_t r = 0; r < NR; ++r)
+ {
+ for (size_t c = 0; c < NC; ++c)
+ {
+ rgb_pixel temp = (*i)(r,c);
+ auto p = ptr++;
+ *p = (temp.red-avg_red)/256.0;
+ p += offset;
+ *p = (temp.green-avg_green)/256.0;
+ p += offset;
+ *p = (temp.blue-avg_blue)/256.0;
+ p += offset;
+ }
+ }
+ ptr += offset*(data.k()-1);
+ }
+
+ }
+
+ friend void serialize(const input_rgb_image_sized& item, std::ostream& out)
+ {
+ serialize("input_rgb_image_sized", out);
+ serialize(item.avg_red, out);
+ serialize(item.avg_green, out);
+ serialize(item.avg_blue, out);
+ serialize(NR, out);
+ serialize(NC, out);
+ }
+
+ friend void deserialize(input_rgb_image_sized& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input_rgb_image_sized")
+ throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_sized.");
+ deserialize(item.avg_red, in);
+ deserialize(item.avg_green, in);
+ deserialize(item.avg_blue, in);
+ size_t nr, nc;
+ deserialize(nr, in);
+ deserialize(nc, in);
+ if (nr != NR || nc != NC)
+ {
+ std::ostringstream sout;
+ sout << "Wrong image dimensions found while deserializing dlib::input_rgb_image_sized.\n";
+ sout << "Expected "<<NR<<" rows and "<<NC<< " columns, but found "<<nr<<" rows and "<<nc<<" columns.";
+ throw serialization_error(sout.str());
+ }
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_sized& item)
+ {
+ out << "input_rgb_image_sized("<<item.avg_red<<","<<item.avg_green<<","<<item.avg_blue<<") nr="<<NR<<" nc="<<NC;
+ return out;
+ }
+
+ friend void to_xml(const input_rgb_image_sized& item, std::ostream& out)
+ {
+ out << "<input_rgb_image_sized r='"<<item.avg_red<<"' g='"<<item.avg_green<<"' b='"<<item.avg_blue<<"' nr='"<<NR<<"' nc='"<<NC<<"'/>";
+ }
+
+ private:
+ float avg_red;
+ float avg_green;
+ float avg_blue;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t NR, size_t NC>
+ input_rgb_image::
+ input_rgb_image (
+ const input_rgb_image_sized<NR,NC>& item
+ ) : avg_red(item.get_avg_red()),
+ avg_green(item.get_avg_green()),
+ avg_blue(item.get_avg_blue())
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ class input<matrix<T,NR,NC,MM,L>>
+ {
+ public:
+ typedef matrix<T,NR,NC,MM,L> input_type;
+
+ input() {}
+ input(const input&) {}
+
+ template <typename mm>
+ input(const input<array2d<T,mm>>&) {}
+
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ const auto nr = ibegin->nr();
+ const auto nc = ibegin->nc();
+ // make sure all the input matrices have the same dimensions
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ DLIB_CASSERT(i->nr()==nr && i->nc()==nc,
+ "\t input::to_tensor()"
+ << "\n\t All matrices given to to_tensor() must have the same dimensions."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t i->nr(): " << i->nr()
+ << "\n\t i->nc(): " << i->nc()
+ );
+ }
+
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), pixel_traits<T>::num, nr, nc);
+
+ typedef typename pixel_traits<T>::basic_pixel_type bptype;
+
+ const size_t offset = nr*nc;
+ auto ptr = data.host();
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ auto temp = pixel_to_vector<float>((*i)(r,c));
+ auto p = ptr++;
+ for (long j = 0; j < temp.size(); ++j)
+ {
+ if (is_same_type<bptype,unsigned char>::value)
+ *p = temp(j)/256.0;
+ else
+ *p = temp(j);
+ p += offset;
+ }
+ }
+ }
+ ptr += offset*(data.k()-1);
+ }
+
+ }
+
+ friend void serialize(const input& /*item*/, std::ostream& out)
+ {
+ serialize("input<matrix>", out);
+ }
+
+ friend void deserialize(input& /*item*/, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input<matrix>")
+ throw serialization_error("Unexpected version found while deserializing dlib::input.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const input& /*item*/)
+ {
+ out << "input<matrix>";
+ return out;
+ }
+
+ friend void to_xml(const input& /*item*/, std::ostream& out)
+ {
+ out << "<input/>";
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, size_t K>
+ class input<std::array<matrix<T,NR,NC,MM,L>,K>>
+ {
+ public:
+ typedef std::array<matrix<T,NR,NC,MM,L>,K> input_type;
+
+ input() {}
+ input(const input&) {}
+
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ DLIB_CASSERT(ibegin->size() != 0, "When using std::array<matrix> inputs you can't give 0 sized arrays.");
+ const auto nr = (*ibegin)[0].nr();
+ const auto nc = (*ibegin)[0].nc();
+ // make sure all the input matrices have the same dimensions
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (size_t k = 0; k < K; ++k)
+ {
+ const auto& arr = *i;
+ DLIB_CASSERT(arr[k].nr()==nr && arr[k].nc()==nc,
+ "\t input::to_tensor()"
+ << "\n\t When using std::array<matrix> as input, all matrices in a batch must have the same dimensions."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t k: " << k
+ << "\n\t arr[k].nr(): " << arr[k].nr()
+ << "\n\t arr[k].nc(): " << arr[k].nc()
+ );
+ }
+ }
+
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), K, nr, nc);
+
+ auto ptr = data.host();
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (size_t k = 0; k < K; ++k)
+ {
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ if (is_same_type<T,unsigned char>::value)
+ *ptr++ = (*i)[k](r,c)/256.0;
+ else
+ *ptr++ = (*i)[k](r,c);
+ }
+ }
+ }
+ }
+
+ }
+
+ friend void serialize(const input& /*item*/, std::ostream& out)
+ {
+ serialize("input<array<matrix>>", out);
+ }
+
+ friend void deserialize(input& /*item*/, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input<array<matrix>>")
+ throw serialization_error("Unexpected version found while deserializing dlib::input<array<matrix>>.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const input& /*item*/)
+ {
+ out << "input<array<matrix>>";
+ return out;
+ }
+
+ friend void to_xml(const input& /*item*/, std::ostream& out)
+ {
+ out << "<input/>";
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename MM>
+ class input<array2d<T,MM>>
+ {
+ public:
+ typedef array2d<T,MM> input_type;
+
+ input() {}
+ input(const input&) {}
+
+ template <long NR, long NC, typename mm, typename L>
+ input(const input<matrix<T,NR,NC,mm,L>>&) {}
+
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ const auto nr = ibegin->nr();
+ const auto nc = ibegin->nc();
+ // make sure all the input matrices have the same dimensions
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ DLIB_CASSERT(i->nr()==nr && i->nc()==nc,
+ "\t input::to_tensor()"
+ << "\n\t All array2d objects given to to_tensor() must have the same dimensions."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t i->nr(): " << i->nr()
+ << "\n\t i->nc(): " << i->nc()
+ );
+ }
+
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), pixel_traits<T>::num, nr, nc);
+ typedef typename pixel_traits<T>::basic_pixel_type bptype;
+
+ const size_t offset = nr*nc;
+ auto ptr = data.host();
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ auto temp = pixel_to_vector<float>((*i)[r][c]);
+ auto p = ptr++;
+ for (long j = 0; j < temp.size(); ++j)
+ {
+ if (is_same_type<bptype,unsigned char>::value)
+ *p = temp(j)/256.0;
+ else
+ *p = temp(j);
+ p += offset;
+ }
+ }
+ }
+ ptr += offset*(data.k()-1);
+ }
+
+ }
+
+ friend void serialize(const input& item, std::ostream& out)
+ {
+ serialize("input<array2d>", out);
+ }
+
+ friend void deserialize(input& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input<array2d>")
+ throw serialization_error("Unexpected version found while deserializing dlib::input.");
+ }
+ friend std::ostream& operator<<(std::ostream& out, const input& item)
+ {
+ out << "input<array2d>";
+ return out;
+ }
+
+ friend void to_xml(const input& item, std::ostream& out)
+ {
+ out << "<input/>";
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename PYRAMID_TYPE>
+ class input_rgb_image_pyramid
+ {
+ public:
+ typedef matrix<rgb_pixel> input_type;
+ typedef PYRAMID_TYPE pyramid_type;
+
+ input_rgb_image_pyramid (
+ ) :
+ avg_red(122.782),
+ avg_green(117.001),
+ avg_blue(104.298)
+ {
+ }
+
+ input_rgb_image_pyramid (
+ float avg_red_,
+ float avg_green_,
+ float avg_blue_
+ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_)
+ {}
+
+ float get_avg_red() const { return avg_red; }
+ float get_avg_green() const { return avg_green; }
+ float get_avg_blue() const { return avg_blue; }
+
+ unsigned long get_pyramid_padding () const { return pyramid_padding; }
+ void set_pyramid_padding (unsigned long value) { pyramid_padding = value; }
+
+ unsigned long get_pyramid_outer_padding () const { return pyramid_outer_padding; }
+ void set_pyramid_outer_padding (unsigned long value) { pyramid_outer_padding = value; }
+
+ bool image_contained_point (
+ const tensor& data,
+ const point& p
+ ) const
+ {
+ auto&& rects = any_cast<std::vector<rectangle>>(data.annotation());
+ DLIB_CASSERT(rects.size() > 0);
+ return rects[0].contains(p+rects[0].tl_corner());
+ }
+
+ drectangle tensor_space_to_image_space (
+ const tensor& data,
+ drectangle r
+ ) const
+ {
+ auto&& rects = any_cast<std::vector<rectangle>>(data.annotation());
+ return tiled_pyramid_to_image<pyramid_type>(rects, r);
+ }
+
+ drectangle image_space_to_tensor_space (
+ const tensor& data,
+ double scale,
+ drectangle r
+ ) const
+ {
+ DLIB_CASSERT(0 < scale && scale <= 1 , "scale: "<< scale);
+ auto&& rects = any_cast<std::vector<rectangle>>(data.annotation());
+ return image_to_tiled_pyramid<pyramid_type>(rects, scale, r);
+ }
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const
+ {
+ DLIB_CASSERT(std::distance(ibegin,iend) > 0);
+ auto nr = ibegin->nr();
+ auto nc = ibegin->nc();
+ // make sure all the input matrices have the same dimensions
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ DLIB_CASSERT(i->nr()==nr && i->nc()==nc,
+ "\t input_rgb_image_pyramid::to_tensor()"
+ << "\n\t All matrices given to to_tensor() must have the same dimensions."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t i->nr(): " << i->nr()
+ << "\n\t i->nc(): " << i->nc()
+ );
+ }
+
+ long NR, NC;
+ pyramid_type pyr;
+ auto& rects = data.annotation().get<std::vector<rectangle>>();
+ impl::compute_tiled_image_pyramid_details(pyr, nr, nc, pyramid_padding, pyramid_outer_padding, rects, NR, NC);
+
+ // initialize data to the right size to contain the stuff in the iterator range.
+ data.set_size(std::distance(ibegin,iend), 3, NR, NC);
+
+ // We need to zero the image before doing the pyramid, since the pyramid
+ // creation code doesn't write to all parts of the image. We also take
+ // care to avoid triggering any device to hosts copies.
+ auto ptr = data.host_write_only();
+ for (size_t i = 0; i < data.size(); ++i)
+ ptr[i] = 0;
+
+ if (rects.size() == 0)
+ return;
+
+ // copy the first raw image into the top part of the tiled pyramid. We need to
+ // do this for each of the input images/samples in the tensor.
+ for (auto i = ibegin; i != iend; ++i)
+ {
+ auto& img = *i;
+ ptr += rects[0].top()*data.nc();
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ auto p = ptr+rects[0].left();
+ for (long c = 0; c < img.nc(); ++c)
+ p[c] = (img(r,c).red-avg_red)/256.0;
+ ptr += data.nc();
+ }
+ ptr += data.nc()*(data.nr()-rects[0].bottom()-1);
+
+ ptr += rects[0].top()*data.nc();
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ auto p = ptr+rects[0].left();
+ for (long c = 0; c < img.nc(); ++c)
+ p[c] = (img(r,c).green-avg_green)/256.0;
+ ptr += data.nc();
+ }
+ ptr += data.nc()*(data.nr()-rects[0].bottom()-1);
+
+ ptr += rects[0].top()*data.nc();
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ auto p = ptr+rects[0].left();
+ for (long c = 0; c < img.nc(); ++c)
+ p[c] = (img(r,c).blue-avg_blue)/256.0;
+ ptr += data.nc();
+ }
+ ptr += data.nc()*(data.nr()-rects[0].bottom()-1);
+ }
+
+ // now build the image pyramid into data. This does the same thing as
+ // create_tiled_pyramid(), except we use the GPU if one is available.
+ for (size_t i = 1; i < rects.size(); ++i)
+ {
+ alias_tensor src(data.num_samples(),data.k(),rects[i-1].height(),rects[i-1].width());
+ alias_tensor dest(data.num_samples(),data.k(),rects[i].height(),rects[i].width());
+
+ auto asrc = src(data, data.nc()*rects[i-1].top() + rects[i-1].left());
+ auto adest = dest(data, data.nc()*rects[i].top() + rects[i].left());
+
+ tt::resize_bilinear(adest, data.nc(), data.nr()*data.nc(),
+ asrc, data.nc(), data.nr()*data.nc());
+ }
+ }
+
+ friend void serialize(const input_rgb_image_pyramid& item, std::ostream& out)
+ {
+ serialize("input_rgb_image_pyramid2", out);
+ serialize(item.avg_red, out);
+ serialize(item.avg_green, out);
+ serialize(item.avg_blue, out);
+ serialize(item.pyramid_padding, out);
+ serialize(item.pyramid_outer_padding, out);
+ }
+
+ friend void deserialize(input_rgb_image_pyramid& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "input_rgb_image_pyramid" && version != "input_rgb_image_pyramid2")
+ throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_pyramid.");
+ deserialize(item.avg_red, in);
+ deserialize(item.avg_green, in);
+ deserialize(item.avg_blue, in);
+ if (version == "input_rgb_image_pyramid2")
+ {
+ deserialize(item.pyramid_padding, in);
+ deserialize(item.pyramid_outer_padding, in);
+ }
+ else
+ {
+ item.pyramid_padding = 10;
+ item.pyramid_outer_padding = 11;
+ }
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_pyramid& item)
+ {
+ out << "input_rgb_image_pyramid("<<item.avg_red<<","<<item.avg_green<<","<<item.avg_blue<<")";
+ out << " pyramid_padding="<<item.pyramid_padding;
+ out << " pyramid_outer_padding="<<item.pyramid_outer_padding;
+ return out;
+ }
+
+ friend void to_xml(const input_rgb_image_pyramid& item, std::ostream& out)
+ {
+ out << "<input_rgb_image_pyramid r='"<<item.avg_red<<"' g='"<<item.avg_green
+ <<"' b='"<<item.avg_blue
+ <<"' pyramid_padding='"<<item.pyramid_padding
+ <<"' pyramid_outer_padding='"<<item.pyramid_outer_padding
+ <<"'/>";
+ }
+
+ private:
+ float avg_red;
+ float avg_green;
+ float avg_blue;
+ unsigned long pyramid_padding = 10;
+ unsigned long pyramid_outer_padding = 11;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_INPUT_H_
+
diff --git a/ml/dlib/dlib/dnn/input_abstract.h b/ml/dlib/dlib/dnn/input_abstract.h
new file mode 100644
index 000000000..7130efb17
--- /dev/null
+++ b/ml/dlib/dlib/dnn/input_abstract.h
@@ -0,0 +1,467 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_INPUT_ABSTRACT_H_
+#ifdef DLIB_DNn_INPUT_ABSTRACT_H_
+
+#include "../matrix.h"
+#include "../pixel.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class EXAMPLE_INPUT_LAYER
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ Each deep neural network model in dlib begins with an input layer. The job
+ of the input layer is to convert an input_type into a tensor. Nothing more
+ and nothing less.
+
+ Note that there is no dlib::EXAMPLE_INPUT_LAYER type. It is shown here
+ purely to document the interface that an input layer object must implement.
+ If you are using some kind of image or matrix object as your input_type
+ then you can use the provided dlib::input layer defined below. Otherwise,
+ you need to define your own custom input layer.
+
+ THREAD SAFETY
+ to_tensor() must be thread safe. That is, multiple threads must be able to
+ make calls to to_tensor() on a single instance of this object at the same
+ time.
+ !*/
+ public:
+
+ EXAMPLE_INPUT_LAYER(
+ );
+ /*!
+ ensures
+ - Default constructs this object. This function is not required to do
+ anything in particular but it must exist, that is, it is required that
+ layer objects be default constructable.
+ !*/
+
+ EXAMPLE_INPUT_LAYER (
+ const EXAMPLE_INPUT_LAYER& item
+ );
+ /*!
+ ensures
+ - EXAMPLE_INPUT_LAYER objects are copy constructable
+ !*/
+
+ EXAMPLE_INPUT_LAYER(
+ const some_other_input_layer_type& item
+ );
+ /*!
+ ensures
+ - Constructs this object from item. This form of constructor is optional
+ but it allows you to provide a conversion from one input layer type to
+ another. For example, the following code is valid only if my_input_layer2 can
+ be constructed from my_input_layer1:
+ relu<fc<relu<fc<my_input_layer1>>>> my_dnn1;
+ relu<fc<relu<fc<my_input_layer2>>>> my_dnn2(my_dnn1);
+ This kind of pattern is useful if you want to use one type of input layer
+ during training but a different type of layer during testing since it
+ allows you to easily convert between related deep neural network types.
+ !*/
+
+ typedef whatever_type_to_tensor_expects input_type;
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data.
+ - #data.num_samples()%distance(ibegin,iend) == 0.
+ Normally you would have #data.num_samples() == distance(ibegin,iend) but
+ you can also expand the output by some integer factor so long as the loss
+ you use can deal with it correctly.
+ - The data in the ith sample of #data corresponds to the input_type object
+ *(ibegin+i/sample_expansion_factor).
+ where sample_expansion_factor==#data.num_samples()/distance(ibegin,iend).
+ !*/
+ };
+
+ std::ostream& operator<<(std::ostream& out, const EXAMPLE_INPUT_LAYER& item);
+ /*!
+ print a string describing this layer.
+ !*/
+
+ void to_xml(const EXAMPLE_INPUT_LAYER& item, std::ostream& out);
+ /*!
+ This function is optional, but required if you want to print your networks with
+ net_to_xml(). Therefore, to_xml() prints a layer as XML.
+ !*/
+
+ void serialize(const EXAMPLE_INPUT_LAYER& item, std::ostream& out);
+ void deserialize(EXAMPLE_INPUT_LAYER& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class input
+ {
+ /*!
+ REQUIREMENTS ON T
+ One of the following must be true:
+ - T is a matrix or array2d object and it must contain some kind of
+ pixel type. I.e. pixel_traits<T::type> must be defined.
+ - T is a std::array<matrix<U>> where U is any built in scalar type like
+ float, double, or unsigned char.
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a basic input layer that simply copies images into a tensor.
+ !*/
+
+ public:
+ typedef T input_type;
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - The input range should contain image objects that all have the same
+ dimensions.
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data. In
+ particular, if the input images have R rows, C columns, and K channels
+ (where K is given by pixel_traits::num or std::array::size() if
+ std::array inputs are used) then we will have:
+ - #data.num_samples() == std::distance(ibegin,iend)
+ - #data.nr() == R
+ - #data.nc() == C
+ - #data.k() == K
+ For example, a matrix<float,3,3> would turn into a tensor with 3 rows, 3
+ columns, and k()==1. Or a matrix<rgb_pixel,4,5> would turn into a tensor
+ with 4 rows, 5 columns, and k()==3 (since rgb_pixels have 3 channels).
+ Or a std::array<matrix<float,3,3>,5> would turn into a tensor with 3 rows
+ and columns, and k()==5 channels.
+ - If the input data contains pixels of type unsigned char, rgb_pixel, or
+ other pixel types with a basic_pixel_type of unsigned char then each
+ value written to the output tensor is first divided by 256.0 so that the
+ resulting outputs are all in the range [0,1].
+ !*/
+
+ // Provided for compatibility with input_rgb_image_pyramid's interface
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class input_rgb_image
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This input layer works with RGB images of type matrix<rgb_pixel>. It is
+ very similar to the dlib::input layer except that it allows you to subtract
+ the average color value from each color channel when converting an image to
+ a tensor.
+ !*/
+ public:
+ typedef matrix<rgb_pixel> input_type;
+
+ input_rgb_image (
+ );
+ /*!
+ ensures
+ - #get_avg_red() == 122.782
+ - #get_avg_green() == 117.001
+ - #get_avg_blue() == 104.298
+ !*/
+
+ input_rgb_image (
+ float avg_red,
+ float avg_green,
+ float avg_blue
+ );
+ /*!
+ ensures
+ - #get_avg_red() == avg_red
+ - #get_avg_green() == avg_green
+ - #get_avg_blue() == avg_blue
+ !*/
+
+ float get_avg_red(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the red color channel.
+ !*/
+
+ float get_avg_green(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the green color channel.
+ !*/
+
+ float get_avg_blue(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the blue color channel.
+ !*/
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - The input range should contain images that all have the same
+ dimensions.
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data. In
+ particular, if the input images have R rows, C columns then we will have:
+ - #data.num_samples() == std::distance(ibegin,iend)
+ - #data.nr() == R
+ - #data.nc() == C
+ - #data.k() == 3
+ Moreover, each color channel is normalized by having its average value
+ subtracted (according to get_avg_red(), get_avg_green(), or
+ get_avg_blue()) and then is divided by 256.0.
+ !*/
+
+
+ // Provided for compatibility with input_rgb_image_pyramid's interface
+ bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); }
+ drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; }
+ drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t NR, size_t NC=NR>
+ class input_rgb_image_sized
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This layer has an interface and behavior identical to input_rgb_image
+ except that it requires input images to have NR rows and NC columns. This
+ is checked by a DLIB_CASSERT inside to_tensor().
+
+ You can also convert between input_rgb_image and input_rgb_image_sized by
+ copy construction or assignment.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PYRAMID_TYPE
+ >
+ class input_rgb_image_pyramid
+ {
+ /*!
+ REQUIREMENTS ON PYRAMID_TYPE
+ PYRAMID_TYPE must be an instance of the dlib::pyramid_down template.
+
+ WHAT THIS OBJECT REPRESENTS
+ This input layer works with RGB images of type matrix<rgb_pixel>. It is
+ identical to input_rgb_image except that it outputs a tensor containing a
+ tiled image pyramid of each input image rather than a simple copy of each
+ image. The tiled image pyramid is created using create_tiled_pyramid().
+ !*/
+
+ public:
+
+ typedef matrix<rgb_pixel> input_type;
+ typedef PYRAMID_TYPE pyramid_type;
+
+ input_rgb_image_pyramid (
+ );
+ /*!
+ ensures
+ - #get_avg_red() == 122.782
+ - #get_avg_green() == 117.001
+ - #get_avg_blue() == 104.298
+ - #get_pyramid_padding() == 10
+ - #get_pyramid_outer_padding() == 11
+ !*/
+
+ input_rgb_image_pyramid (
+ float avg_red,
+ float avg_green,
+ float avg_blue
+ );
+ /*!
+ ensures
+ - #get_avg_red() == avg_red
+ - #get_avg_green() == avg_green
+ - #get_avg_blue() == avg_blue
+ - #get_pyramid_padding() == 10
+ - #get_pyramid_outer_padding() == 11
+ !*/
+
+ float get_avg_red(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the red color channel.
+ !*/
+
+ float get_avg_green(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the green color channel.
+ !*/
+
+ float get_avg_blue(
+ ) const;
+ /*!
+ ensures
+ - returns the value subtracted from the blue color channel.
+ !*/
+
+ unsigned long get_pyramid_padding (
+ ) const;
+ /*!
+ ensures
+ - When this object creates a pyramid it will call create_tiled_pyramid() and
+ set create_tiled_pyramid's pyramid_padding parameter to get_pyramid_padding().
+ !*/
+ void set_pyramid_padding (
+ unsigned long value
+ );
+ /*!
+ ensures
+ - #get_pyramid_padding() == value
+ !*/
+
+ unsigned long get_pyramid_outer_padding (
+ ) const;
+ /*!
+ ensures
+ - When this object creates a pyramid it will call create_tiled_pyramid()
+ and set create_tiled_pyramid's pyramid_outer_padding parameter to
+ get_pyramid_outer_padding().
+ !*/
+ void set_pyramid_outer_padding (
+ unsigned long value
+ );
+ /*!
+ ensures
+ - #get_pyramid_outer_padding() == value
+ !*/
+
+ template <typename forward_iterator>
+ void to_tensor (
+ forward_iterator ibegin,
+ forward_iterator iend,
+ resizable_tensor& data
+ ) const;
+ /*!
+ requires
+ - [ibegin, iend) is an iterator range over input_type objects.
+ - std::distance(ibegin,iend) > 0
+ - The input range should contain images that all have the same
+ dimensions.
+ ensures
+ - Converts the iterator range into a tensor and stores it into #data. In
+ particular, we will have:
+ - #data.num_samples() == std::distance(ibegin,iend)
+ - #data.k() == 3
+ - Each sample in #data contains a tiled image pyramid of the
+ corresponding input image. The tiled pyramid is created by
+ create_tiled_pyramid().
+ Moreover, each color channel is normalized by having its average value
+ subtracted (according to get_avg_red(), get_avg_green(), or
+ get_avg_blue()) and then is divided by 256.0.
+ !*/
+
+ bool image_contained_point (
+ const tensor& data,
+ const point& p
+ ) const;
+ /*!
+ requires
+ - data is a tensor that was produced by this->to_tensor()
+ ensures
+ - Since data is a tensor that is built from a bunch of identically sized
+ images, we can ask if those images were big enough to contain the point
+ p. This function returns the answer to that question.
+ !*/
+
+ drectangle image_space_to_tensor_space (
+ const tensor& data,
+ double scale,
+ drectangle r
+ ) const;
+ /*!
+ requires
+ - data is a tensor that was produced by this->to_tensor()
+ - 0 < scale <= 1
+ ensures
+ - This function maps from to_tensor()'s input image space to its output
+ tensor space. Therefore, given that data is a tensor produced by
+ to_tensor(), image_space_to_tensor_space() allows you to ask for the
+ rectangle in data that corresponds to a rectangle in the original image
+ space.
+
+ Note that since the output tensor contains an image pyramid, there are
+ multiple points in the output tensor that correspond to any input
+ location. So you must also specify a scale so we know what level of the
+ pyramid is needed. So given a rectangle r in an input image, you can
+ ask, what rectangle in data corresponds to r when things are scale times
+ smaller? That rectangle is returned by this function.
+ - A scale of 1 means we don't move anywhere in the pyramid scale space relative
+ to the input image while smaller values of scale mean we move down the
+ pyramid.
+ !*/
+
+ drectangle tensor_space_to_image_space (
+ const tensor& data,
+ drectangle r
+ ) const;
+ /*!
+ requires
+ - data is a tensor that was produced by this->to_tensor()
+ ensures
+ - This function maps from to_tensor()'s output tensor space to its input
+ image space. Therefore, given that data is a tensor produced by
+ to_tensor(), tensor_space_to_image_space() allows you to ask for the
+ rectangle in the input image that corresponds to a rectangle in data.
+ - It should be noted that this function isn't always an inverse of
+ image_space_to_tensor_space(). This is because you can ask
+ image_space_to_tensor_space() for the coordinates of points outside the input
+ image and they will be mapped to somewhere that doesn't have an inverse.
+ But for points actually inside the input image this function performs an
+ approximate inverse mapping. I.e. when image_contained_point(data,center(r))==true
+ there is an approximate inverse.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_INPUT_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/layers.h b/ml/dlib/dlib/dnn/layers.h
new file mode 100644
index 000000000..91436f635
--- /dev/null
+++ b/ml/dlib/dlib/dnn/layers.h
@@ -0,0 +1,3244 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_LAYERS_H_
+#define DLIB_DNn_LAYERS_H_
+
+#include "layers_abstract.h"
+#include "tensor.h"
+#include "core.h"
+#include <iostream>
+#include <string>
+#include "../rand.h"
+#include "../string.h"
+#include "tensor_tools.h"
+#include "../vectorstream.h"
+#include "utilities.h"
+#include <sstream>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct num_con_outputs
+ {
+ num_con_outputs(unsigned long n) : num_outputs(n) {}
+ unsigned long num_outputs;
+ };
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class con_
+ {
+ public:
+
+ static_assert(_num_filters > 0, "The number of filters must be > 0");
+ static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
+ static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
+ static_assert(_stride_y > 0, "The filter stride must be > 0");
+ static_assert(_stride_x > 0, "The filter stride must be > 0");
+ static_assert(_nr==0 || (0 <= _padding_y && _padding_y < _nr), "The padding must be smaller than the filter size.");
+ static_assert(_nc==0 || (0 <= _padding_x && _padding_x < _nc), "The padding must be smaller than the filter size.");
+ static_assert(_nr!=0 || 0 == _padding_y, "If _nr==0 then the padding must be set to 0 as well.");
+ static_assert(_nc!=0 || 0 == _padding_x, "If _nr==0 then the padding must be set to 0 as well.");
+
+ con_(
+ num_con_outputs o
+ ) :
+ learning_rate_multiplier(1),
+ weight_decay_multiplier(1),
+ bias_learning_rate_multiplier(1),
+ bias_weight_decay_multiplier(0),
+ num_filters_(o.num_outputs),
+ padding_y_(_padding_y),
+ padding_x_(_padding_x)
+ {
+ DLIB_CASSERT(num_filters_ > 0);
+ }
+
+ con_() : con_(num_con_outputs(_num_filters)) {}
+
+ long num_filters() const { return num_filters_; }
+ long nr() const
+ {
+ if (_nr==0)
+ return filters.nr();
+ else
+ return _nr;
+ }
+ long nc() const
+ {
+ if (_nc==0)
+ return filters.nc();
+ else
+ return _nc;
+ }
+ long stride_y() const { return _stride_y; }
+ long stride_x() const { return _stride_x; }
+ long padding_y() const { return padding_y_; }
+ long padding_x() const { return padding_x_; }
+
+ void set_num_filters(long num)
+ {
+ DLIB_CASSERT(num > 0);
+ if (num != num_filters_)
+ {
+ DLIB_CASSERT(get_layer_params().size() == 0,
+ "You can't change the number of filters in con_ if the parameter tensor has already been allocated.");
+ num_filters_ = num;
+ }
+ }
+
+ double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+ double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+ void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+ void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+
+ double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+ double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+ void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+ void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+
+ inline dpoint map_input_to_output (
+ dpoint p
+ ) const
+ {
+ p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
+ p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
+ return p;
+ }
+
+ inline dpoint map_output_to_input (
+ dpoint p
+ ) const
+ {
+ p.x() = p.x()*stride_x() - padding_x() + nc()/2;
+ p.y() = p.y()*stride_y() - padding_y() + nr()/2;
+ return p;
+ }
+
+ con_ (
+ const con_& item
+ ) :
+ params(item.params),
+ filters(item.filters),
+ biases(item.biases),
+ learning_rate_multiplier(item.learning_rate_multiplier),
+ weight_decay_multiplier(item.weight_decay_multiplier),
+ bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
+ bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
+ num_filters_(item.num_filters_),
+ padding_y_(item.padding_y_),
+ padding_x_(item.padding_x_)
+ {
+ // this->conv is non-copyable and basically stateless, so we have to write our
+ // own copy to avoid trying to copy it and getting an error.
+ }
+
+ con_& operator= (
+ const con_& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ // this->conv is non-copyable and basically stateless, so we have to write our
+ // own copy to avoid trying to copy it and getting an error.
+ params = item.params;
+ filters = item.filters;
+ biases = item.biases;
+ padding_y_ = item.padding_y_;
+ padding_x_ = item.padding_x_;
+ learning_rate_multiplier = item.learning_rate_multiplier;
+ weight_decay_multiplier = item.weight_decay_multiplier;
+ bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
+ bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
+ num_filters_ = item.num_filters_;
+ return *this;
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ const long filt_nr = _nr!=0 ? _nr : sub.get_output().nr();
+ const long filt_nc = _nc!=0 ? _nc : sub.get_output().nc();
+
+ long num_inputs = filt_nr*filt_nc*sub.get_output().k();
+ long num_outputs = num_filters_;
+ // allocate params for the filters and also for the filter bias values.
+ params.set_size(num_inputs*num_filters_ + num_filters_);
+
+ dlib::rand rnd(std::rand());
+ randomize_parameters(params, num_inputs+num_outputs, rnd);
+
+ filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc);
+ biases = alias_tensor(1,num_filters_);
+
+ // set the initial bias values to zero
+ biases(params,filters.size()) = 0;
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ conv.setup(sub.get_output(),
+ filters(params,0),
+ _stride_y,
+ _stride_x,
+ padding_y_,
+ padding_x_);
+ conv(false, output,
+ sub.get_output(),
+ filters(params,0));
+
+ tt::add(1,output,1,biases(params,filters.size()));
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
+ {
+ conv.get_gradient_for_data (true, gradient_input, filters(params,0), sub.get_gradient_input());
+ // no dpoint computing the parameter gradients if they won't be used.
+ if (learning_rate_multiplier != 0)
+ {
+ auto filt = filters(params_grad,0);
+ conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt);
+ auto b = biases(params_grad, filters.size());
+ tt::assign_conv_bias_gradient(b, gradient_input);
+ }
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const con_& item, std::ostream& out)
+ {
+ serialize("con_4", out);
+ serialize(item.params, out);
+ serialize(item.num_filters_, out);
+ serialize(_nr, out);
+ serialize(_nc, out);
+ serialize(_stride_y, out);
+ serialize(_stride_x, out);
+ serialize(item.padding_y_, out);
+ serialize(item.padding_x_, out);
+ serialize(item.filters, out);
+ serialize(item.biases, out);
+ serialize(item.learning_rate_multiplier, out);
+ serialize(item.weight_decay_multiplier, out);
+ serialize(item.bias_learning_rate_multiplier, out);
+ serialize(item.bias_weight_decay_multiplier, out);
+ }
+
+ friend void deserialize(con_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ long nr;
+ long nc;
+ int stride_y;
+ int stride_x;
+ if (version == "con_4")
+ {
+ deserialize(item.params, in);
+ deserialize(item.num_filters_, in);
+ deserialize(nr, in);
+ deserialize(nc, in);
+ deserialize(stride_y, in);
+ deserialize(stride_x, in);
+ deserialize(item.padding_y_, in);
+ deserialize(item.padding_x_, in);
+ deserialize(item.filters, in);
+ deserialize(item.biases, in);
+ deserialize(item.learning_rate_multiplier, in);
+ deserialize(item.weight_decay_multiplier, in);
+ deserialize(item.bias_learning_rate_multiplier, in);
+ deserialize(item.bias_weight_decay_multiplier, in);
+ if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
+ if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
+ if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
+ if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
+ if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
+ if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
+ }
+ else
+ {
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
+ }
+ }
+
+
+ friend std::ostream& operator<<(std::ostream& out, const con_& item)
+ {
+ out << "con\t ("
+ << "num_filters="<<item.num_filters_
+ << ", nr="<<item.nr()
+ << ", nc="<<item.nc()
+ << ", stride_y="<<_stride_y
+ << ", stride_x="<<_stride_x
+ << ", padding_y="<<item.padding_y_
+ << ", padding_x="<<item.padding_x_
+ << ")";
+ out << " learning_rate_mult="<<item.learning_rate_multiplier;
+ out << " weight_decay_mult="<<item.weight_decay_multiplier;
+ out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+ out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+ return out;
+ }
+
+ friend void to_xml(const con_& item, std::ostream& out)
+ {
+ out << "<con"
+ << " num_filters='"<<item.num_filters_<<"'"
+ << " nr='"<<item.nr()<<"'"
+ << " nc='"<<item.nc()<<"'"
+ << " stride_y='"<<_stride_y<<"'"
+ << " stride_x='"<<_stride_x<<"'"
+ << " padding_y='"<<item.padding_y_<<"'"
+ << " padding_x='"<<item.padding_x_<<"'"
+ << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
+ << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
+ << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
+ << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
+ out << mat(item.params);
+ out << "</con>";
+ }
+
+ private:
+
+ resizable_tensor params;
+ alias_tensor filters, biases;
+
+ tt::tensor_conv conv;
+ double learning_rate_multiplier;
+ double weight_decay_multiplier;
+ double bias_learning_rate_multiplier;
+ double bias_weight_decay_multiplier;
+ long num_filters_;
+
+ // These are here only because older versions of con (which you might encounter
+ // serialized to disk) used different padding settings.
+ int padding_y_;
+ int padding_x_;
+
+ };
+
+ template <
+ long num_filters,
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class cont_
+ {
+ public:
+
+ static_assert(_num_filters > 0, "The number of filters must be > 0");
+ static_assert(_nr > 0, "The number of rows in a filter must be > 0");
+ static_assert(_nc > 0, "The number of columns in a filter must be > 0");
+ static_assert(_stride_y > 0, "The filter stride must be > 0");
+ static_assert(_stride_x > 0, "The filter stride must be > 0");
+ static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
+ static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");
+
+ cont_(
+ num_con_outputs o
+ ) :
+ learning_rate_multiplier(1),
+ weight_decay_multiplier(1),
+ bias_learning_rate_multiplier(1),
+ bias_weight_decay_multiplier(0),
+ num_filters_(o.num_outputs),
+ padding_y_(_padding_y),
+ padding_x_(_padding_x)
+ {
+ DLIB_CASSERT(num_filters_ > 0);
+ }
+
+ cont_() : cont_(num_con_outputs(_num_filters)) {}
+
+ long num_filters() const { return num_filters_; }
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ long stride_y() const { return _stride_y; }
+ long stride_x() const { return _stride_x; }
+ long padding_y() const { return padding_y_; }
+ long padding_x() const { return padding_x_; }
+
+ void set_num_filters(long num)
+ {
+ DLIB_CASSERT(num > 0);
+ if (num != num_filters_)
+ {
+ DLIB_CASSERT(get_layer_params().size() == 0,
+ "You can't change the number of filters in cont_ if the parameter tensor has already been allocated.");
+ num_filters_ = num;
+ }
+ }
+
+ double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+ double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+ void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+ void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+
+ double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+ double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+ void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+ void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+
+ inline dpoint map_output_to_input (
+ dpoint p
+ ) const
+ {
+ p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
+ p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
+ return p;
+ }
+
+ inline dpoint map_input_to_output (
+ dpoint p
+ ) const
+ {
+ p.x() = p.x()*stride_x() - padding_x() + nc()/2;
+ p.y() = p.y()*stride_y() - padding_y() + nr()/2;
+ return p;
+ }
+
+ cont_ (
+ const cont_& item
+ ) :
+ params(item.params),
+ filters(item.filters),
+ biases(item.biases),
+ learning_rate_multiplier(item.learning_rate_multiplier),
+ weight_decay_multiplier(item.weight_decay_multiplier),
+ bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
+ bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
+ num_filters_(item.num_filters_),
+ padding_y_(item.padding_y_),
+ padding_x_(item.padding_x_)
+ {
+ // this->conv is non-copyable and basically stateless, so we have to write our
+ // own copy to avoid trying to copy it and getting an error.
+ }
+
+ cont_& operator= (
+ const cont_& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ // this->conv is non-copyable and basically stateless, so we have to write our
+ // own copy to avoid trying to copy it and getting an error.
+ params = item.params;
+ filters = item.filters;
+ biases = item.biases;
+ padding_y_ = item.padding_y_;
+ padding_x_ = item.padding_x_;
+ learning_rate_multiplier = item.learning_rate_multiplier;
+ weight_decay_multiplier = item.weight_decay_multiplier;
+ bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
+ bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
+ num_filters_ = item.num_filters_;
+ return *this;
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ long num_inputs = _nr*_nc*sub.get_output().k();
+ long num_outputs = num_filters_;
+ // allocate params for the filters and also for the filter bias values.
+ params.set_size(num_inputs*num_filters_ + num_filters_);
+
+ dlib::rand rnd(std::rand());
+ randomize_parameters(params, num_inputs+num_outputs, rnd);
+
+ filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
+ biases = alias_tensor(1,num_filters_);
+
+ // set the initial bias values to zero
+ biases(params,filters.size()) = 0;
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ auto filt = filters(params,0);
+ unsigned int gnr = _stride_y * (sub.get_output().nr() - 1) + filt.nr() - 2 * padding_y_;
+ unsigned int gnc = _stride_x * (sub.get_output().nc() - 1) + filt.nc() - 2 * padding_x_;
+ unsigned int gnsamps = sub.get_output().num_samples();
+ unsigned int gk = filt.k();
+ output.set_size(gnsamps,gk,gnr,gnc);
+ conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_);
+ conv.get_gradient_for_data(false, sub.get_output(),filt,output);
+ tt::add(1,output,1,biases(params,filters.size()));
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
+ {
+ auto filt = filters(params,0);
+ conv(true, sub.get_gradient_input(),gradient_input, filt);
+ // no point computing the parameter gradients if they won't be used.
+ if (learning_rate_multiplier != 0)
+ {
+ auto filt = filters(params_grad,0);
+ conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt);
+ auto b = biases(params_grad, filters.size());
+ tt::assign_conv_bias_gradient(b, gradient_input);
+ }
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const cont_& item, std::ostream& out)
+ {
+ serialize("cont_1", out);
+ serialize(item.params, out);
+ serialize(item.num_filters_, out);
+ serialize(_nr, out);
+ serialize(_nc, out);
+ serialize(_stride_y, out);
+ serialize(_stride_x, out);
+ serialize(item.padding_y_, out);
+ serialize(item.padding_x_, out);
+ serialize(item.filters, out);
+ serialize(item.biases, out);
+ serialize(item.learning_rate_multiplier, out);
+ serialize(item.weight_decay_multiplier, out);
+ serialize(item.bias_learning_rate_multiplier, out);
+ serialize(item.bias_weight_decay_multiplier, out);
+ }
+
+ friend void deserialize(cont_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ long nr;
+ long nc;
+ int stride_y;
+ int stride_x;
+ if (version == "cont_1")
+ {
+ deserialize(item.params, in);
+ deserialize(item.num_filters_, in);
+ deserialize(nr, in);
+ deserialize(nc, in);
+ deserialize(stride_y, in);
+ deserialize(stride_x, in);
+ deserialize(item.padding_y_, in);
+ deserialize(item.padding_x_, in);
+ deserialize(item.filters, in);
+ deserialize(item.biases, in);
+ deserialize(item.learning_rate_multiplier, in);
+ deserialize(item.weight_decay_multiplier, in);
+ deserialize(item.bias_learning_rate_multiplier, in);
+ deserialize(item.bias_weight_decay_multiplier, in);
+ if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
+ if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
+ if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
+ if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
+ if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
+ if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
+ }
+ else
+ {
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
+ }
+ }
+
+
+ friend std::ostream& operator<<(std::ostream& out, const cont_& item)
+ {
+ out << "cont\t ("
+ << "num_filters="<<item.num_filters_
+ << ", nr="<<_nr
+ << ", nc="<<_nc
+ << ", stride_y="<<_stride_y
+ << ", stride_x="<<_stride_x
+ << ", padding_y="<<item.padding_y_
+ << ", padding_x="<<item.padding_x_
+ << ")";
+ out << " learning_rate_mult="<<item.learning_rate_multiplier;
+ out << " weight_decay_mult="<<item.weight_decay_multiplier;
+ out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+ out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+ return out;
+ }
+
+ friend void to_xml(const cont_& item, std::ostream& out)
+ {
+ out << "<cont"
+ << " num_filters='"<<item.num_filters_<<"'"
+ << " nr='"<<_nr<<"'"
+ << " nc='"<<_nc<<"'"
+ << " stride_y='"<<_stride_y<<"'"
+ << " stride_x='"<<_stride_x<<"'"
+ << " padding_y='"<<item.padding_y_<<"'"
+ << " padding_x='"<<item.padding_x_<<"'"
+ << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
+ << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
+ << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
+ << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
+ out << mat(item.params);
+ out << "</cont>";
+ }
+
+ private:
+
+ resizable_tensor params;
+ alias_tensor filters, biases;
+
+ tt::tensor_conv conv;
+ double learning_rate_multiplier;
+ double weight_decay_multiplier;
+ double bias_learning_rate_multiplier;
+ double bias_weight_decay_multiplier;
+ long num_filters_;
+
+ int padding_y_;
+ int padding_x_;
+
+ };
+
+ template <
+ long num_filters,
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using cont = add_layer<cont_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ int scale_y,
+ int scale_x
+ >
+ class upsample_
+ {
+ public:
+ static_assert(scale_y >= 1, "upsampling scale factor can't be less than 1.");
+ static_assert(scale_x >= 1, "upsampling scale factor can't be less than 1.");
+
+ upsample_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ output.set_size(
+ sub.get_output().num_samples(),
+ sub.get_output().k(),
+ scale_y*sub.get_output().nr(),
+ scale_x*sub.get_output().nc());
+ tt::resize_bilinear(output, sub.get_output());
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input);
+ }
+
+ inline dpoint map_input_to_output (dpoint p) const
+ {
+ p.x() = p.x()*scale_x;
+ p.y() = p.y()*scale_y;
+ return p;
+ }
+ inline dpoint map_output_to_input (dpoint p) const
+ {
+ p.x() = p.x()/scale_x;
+ p.y() = p.y()/scale_y;
+ return p;
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const upsample_& , std::ostream& out)
+ {
+ serialize("upsample_", out);
+ serialize(scale_y, out);
+ serialize(scale_x, out);
+ }
+
+ friend void deserialize(upsample_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "upsample_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::upsample_.");
+
+ int _scale_y;
+ int _scale_x;
+ deserialize(_scale_y, in);
+ deserialize(_scale_x, in);
+ if (_scale_y != scale_y || _scale_x != scale_x)
+ throw serialization_error("Wrong scale found while deserializing dlib::upsample_");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const upsample_& )
+ {
+ out << "upsample\t ("
+ << "scale_y="<<scale_y
+ << ", scale_x="<<scale_x
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const upsample_& /*item*/, std::ostream& out)
+ {
+ out << "<upsample"
+ << " scale_y='"<<scale_y<<"'"
+ << " scale_x='"<<scale_x<<"'/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+ template <
+ int scale,
+ typename SUBNET
+ >
+ using upsample = add_layer<upsample_<scale,scale>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class max_pool_
+ {
+ static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
+ static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
+ static_assert(_stride_y > 0, "The filter stride must be > 0");
+ static_assert(_stride_x > 0, "The filter stride must be > 0");
+ static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
+ "The padding must be smaller than the filter size, unless the filters size is 0.");
+ static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
+ "The padding must be smaller than the filter size, unless the filters size is 0.");
+ public:
+
+
+ max_pool_(
+ ) :
+ padding_y_(_padding_y),
+ padding_x_(_padding_x)
+ {}
+
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ long stride_y() const { return _stride_y; }
+ long stride_x() const { return _stride_x; }
+ long padding_y() const { return padding_y_; }
+ long padding_x() const { return padding_x_; }
+
+ inline dpoint map_input_to_output (
+ dpoint p
+ ) const
+ {
+ p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
+ p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
+ return p;
+ }
+
+ inline dpoint map_output_to_input (
+ dpoint p
+ ) const
+ {
+ p.x() = p.x()*stride_x() - padding_x() + nc()/2;
+ p.y() = p.y()*stride_y() - padding_y() + nr()/2;
+ return p;
+ }
+
+ max_pool_ (
+ const max_pool_& item
+ ) :
+ padding_y_(item.padding_y_),
+ padding_x_(item.padding_x_)
+ {
+ // this->mp is non-copyable so we have to write our own copy to avoid trying to
+ // copy it and getting an error.
+ }
+
+ max_pool_& operator= (
+ const max_pool_& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ padding_y_ = item.padding_y_;
+ padding_x_ = item.padding_x_;
+
+ // this->mp is non-copyable so we have to write our own copy to avoid trying to
+ // copy it and getting an error.
+ return *this;
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
+ _nc!=0?_nc:sub.get_output().nc(),
+ _stride_y, _stride_x, padding_y_, padding_x_);
+
+ mp(output, sub.get_output());
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
+ _nc!=0?_nc:sub.get_output().nc(),
+ _stride_y, _stride_x, padding_y_, padding_x_);
+
+ mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const max_pool_& item, std::ostream& out)
+ {
+ serialize("max_pool_2", out);
+ serialize(_nr, out);
+ serialize(_nc, out);
+ serialize(_stride_y, out);
+ serialize(_stride_x, out);
+ serialize(item.padding_y_, out);
+ serialize(item.padding_x_, out);
+ }
+
+ friend void deserialize(max_pool_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ long nr;
+ long nc;
+ int stride_y;
+ int stride_x;
+ if (version == "max_pool_2")
+ {
+ deserialize(nr, in);
+ deserialize(nc, in);
+ deserialize(stride_y, in);
+ deserialize(stride_x, in);
+ deserialize(item.padding_y_, in);
+ deserialize(item.padding_x_, in);
+ }
+ else
+ {
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
+ }
+
+ if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
+ if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
+ if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
+ if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
+ if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
+ if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
+ {
+ out << "max_pool ("
+ << "nr="<<_nr
+ << ", nc="<<_nc
+ << ", stride_y="<<_stride_y
+ << ", stride_x="<<_stride_x
+ << ", padding_y="<<item.padding_y_
+ << ", padding_x="<<item.padding_x_
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const max_pool_& item, std::ostream& out)
+ {
+ out << "<max_pool"
+ << " nr='"<<_nr<<"'"
+ << " nc='"<<_nc<<"'"
+ << " stride_y='"<<_stride_y<<"'"
+ << " stride_x='"<<_stride_x<<"'"
+ << " padding_y='"<<item.padding_y_<<"'"
+ << " padding_x='"<<item.padding_x_<<"'"
+ << "/>\n";
+ }
+
+
+ private:
+
+
+ tt::pooling mp;
+ resizable_tensor params;
+
+ int padding_y_;
+ int padding_x_;
+ };
+
+ template <
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
+
+ template <
+ typename SUBNET
+ >
+ using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class avg_pool_
+ {
+ public:
+ static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
+ static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
+ static_assert(_stride_y > 0, "The filter stride must be > 0");
+ static_assert(_stride_x > 0, "The filter stride must be > 0");
+ static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
+ "The padding must be smaller than the filter size, unless the filters size is 0.");
+ static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
+ "The padding must be smaller than the filter size, unless the filters size is 0.");
+
+ avg_pool_(
+ ) :
+ padding_y_(_padding_y),
+ padding_x_(_padding_x)
+ {}
+
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ long stride_y() const { return _stride_y; }
+ long stride_x() const { return _stride_x; }
+ long padding_y() const { return padding_y_; }
+ long padding_x() const { return padding_x_; }
+
+ inline dpoint map_input_to_output (
+ dpoint p
+ ) const
+ {
+ p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
+ p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
+ return p;
+ }
+
+ inline dpoint map_output_to_input (
+ dpoint p
+ ) const
+ {
+ p.x() = p.x()*stride_x() - padding_x() + nc()/2;
+ p.y() = p.y()*stride_y() - padding_y() + nr()/2;
+ return p;
+ }
+
+ avg_pool_ (
+ const avg_pool_& item
+ ) :
+ padding_y_(item.padding_y_),
+ padding_x_(item.padding_x_)
+ {
+ // this->ap is non-copyable so we have to write our own copy to avoid trying to
+ // copy it and getting an error.
+ }
+
+ avg_pool_& operator= (
+ const avg_pool_& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ padding_y_ = item.padding_y_;
+ padding_x_ = item.padding_x_;
+
+ // this->ap is non-copyable so we have to write our own copy to avoid trying to
+ // copy it and getting an error.
+ return *this;
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
+ _nc!=0?_nc:sub.get_output().nc(),
+ _stride_y, _stride_x, padding_y_, padding_x_);
+
+ ap(output, sub.get_output());
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
+ _nc!=0?_nc:sub.get_output().nc(),
+ _stride_y, _stride_x, padding_y_, padding_x_);
+
+ ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const avg_pool_& item, std::ostream& out)
+ {
+ serialize("avg_pool_2", out);
+ serialize(_nr, out);
+ serialize(_nc, out);
+ serialize(_stride_y, out);
+ serialize(_stride_x, out);
+ serialize(item.padding_y_, out);
+ serialize(item.padding_x_, out);
+ }
+
+ friend void deserialize(avg_pool_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+
+ long nr;
+ long nc;
+ int stride_y;
+ int stride_x;
+ if (version == "avg_pool_2")
+ {
+ deserialize(nr, in);
+ deserialize(nc, in);
+ deserialize(stride_y, in);
+ deserialize(stride_x, in);
+ deserialize(item.padding_y_, in);
+ deserialize(item.padding_x_, in);
+ }
+ else
+ {
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
+ }
+
+ if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
+ if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
+ if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
+ if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
+ if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
+ if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
+ {
+ out << "avg_pool ("
+ << "nr="<<_nr
+ << ", nc="<<_nc
+ << ", stride_y="<<_stride_y
+ << ", stride_x="<<_stride_x
+ << ", padding_y="<<item.padding_y_
+ << ", padding_x="<<item.padding_x_
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const avg_pool_& item, std::ostream& out)
+ {
+ out << "<avg_pool"
+ << " nr='"<<_nr<<"'"
+ << " nc='"<<_nc<<"'"
+ << " stride_y='"<<_stride_y<<"'"
+ << " stride_x='"<<_stride_x<<"'"
+ << " padding_y='"<<item.padding_y_<<"'"
+ << " padding_x='"<<item.padding_x_<<"'"
+ << "/>\n";
+ }
+ private:
+
+ tt::pooling ap;
+ resizable_tensor params;
+
+ int padding_y_;
+ int padding_x_;
+ };
+
+ template <
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
+
+ template <
+ typename SUBNET
+ >
+ using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ enum layer_mode
+ {
+ CONV_MODE = 0,
+ FC_MODE = 1
+ };
+
+ const double DEFAULT_BATCH_NORM_EPS = 0.0001;
+
+ template <
+ layer_mode mode
+ >
+ class bn_
+ {
+ public:
+ explicit bn_(
+ unsigned long window_size,
+ double eps_ = DEFAULT_BATCH_NORM_EPS
+ ) :
+ num_updates(0),
+ running_stats_window_size(window_size),
+ learning_rate_multiplier(1),
+ weight_decay_multiplier(0),
+ bias_learning_rate_multiplier(1),
+ bias_weight_decay_multiplier(1),
+ eps(eps_)
+ {
+ DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0.");
+ }
+
+ bn_() : bn_(100) {}
+
+ layer_mode get_mode() const { return mode; }
+ unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
+ void set_running_stats_window_size (unsigned long new_window_size )
+ {
+ DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0.");
+ running_stats_window_size = new_window_size;
+ }
+ double get_eps() const { return eps; }
+
+ double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+ double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+ void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+ void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+
+ double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+ double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+ void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+ void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ if (mode == FC_MODE)
+ {
+ gamma = alias_tensor(1,
+ sub.get_output().k(),
+ sub.get_output().nr(),
+ sub.get_output().nc());
+ }
+ else
+ {
+ gamma = alias_tensor(1, sub.get_output().k());
+ }
+ beta = gamma;
+
+ params.set_size(gamma.size()+beta.size());
+
+ gamma(params,0) = 1;
+ beta(params,gamma.size()) = 0;
+
+ running_means.copy_size(gamma(params,0));
+ running_variances.copy_size(gamma(params,0));
+ running_means = 0;
+ running_variances = 1;
+ num_updates = 0;
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ auto g = gamma(params,0);
+ auto b = beta(params,gamma.size());
+ if (sub.get_output().num_samples() > 1)
+ {
+ const double decay = 1.0 - num_updates/(num_updates+1.0);
+ ++num_updates;
+ if (num_updates > running_stats_window_size)
+ num_updates = running_stats_window_size;
+
+ if (mode == FC_MODE)
+ tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
+ else
+ tt::batch_normalize_conv(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
+ }
+ else // we are running in testing mode so we just linearly scale the input tensor.
+ {
+ if (mode == FC_MODE)
+ tt::batch_normalize_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
+ else
+ tt::batch_normalize_conv_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
+ }
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
+ {
+ auto g = gamma(params,0);
+ auto g_grad = gamma(params_grad, 0);
+ auto b_grad = beta(params_grad, gamma.size());
+ if (mode == FC_MODE)
+ tt::batch_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
+ else
+ tt::batch_normalize_conv_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const bn_& item, std::ostream& out)
+ {
+ if (mode == CONV_MODE)
+ serialize("bn_con2", out);
+ else // if FC_MODE
+ serialize("bn_fc2", out);
+ serialize(item.params, out);
+ serialize(item.gamma, out);
+ serialize(item.beta, out);
+ serialize(item.means, out);
+ serialize(item.invstds, out);
+ serialize(item.running_means, out);
+ serialize(item.running_variances, out);
+ serialize(item.num_updates, out);
+ serialize(item.running_stats_window_size, out);
+ serialize(item.learning_rate_multiplier, out);
+ serialize(item.weight_decay_multiplier, out);
+ serialize(item.bias_learning_rate_multiplier, out);
+ serialize(item.bias_weight_decay_multiplier, out);
+ serialize(item.eps, out);
+ }
+
+ friend void deserialize(bn_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (mode == CONV_MODE)
+ {
+ if (version != "bn_con2")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
+ }
+ else // must be in FC_MODE
+ {
+ if (version != "bn_fc2")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
+ }
+
+ deserialize(item.params, in);
+ deserialize(item.gamma, in);
+ deserialize(item.beta, in);
+ deserialize(item.means, in);
+ deserialize(item.invstds, in);
+ deserialize(item.running_means, in);
+ deserialize(item.running_variances, in);
+ deserialize(item.num_updates, in);
+ deserialize(item.running_stats_window_size, in);
+ deserialize(item.learning_rate_multiplier, in);
+ deserialize(item.weight_decay_multiplier, in);
+ deserialize(item.bias_learning_rate_multiplier, in);
+ deserialize(item.bias_weight_decay_multiplier, in);
+ deserialize(item.eps, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const bn_& item)
+ {
+ if (mode == CONV_MODE)
+ out << "bn_con ";
+ else
+ out << "bn_fc ";
+ out << " eps="<<item.eps;
+ out << " running_stats_window_size="<<item.running_stats_window_size;
+ out << " learning_rate_mult="<<item.learning_rate_multiplier;
+ out << " weight_decay_mult="<<item.weight_decay_multiplier;
+ out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+ out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+ return out;
+ }
+
+ friend void to_xml(const bn_& item, std::ostream& out)
+ {
+ if (mode==CONV_MODE)
+ out << "<bn_con";
+ else
+ out << "<bn_fc";
+
+ out << " eps='"<<item.eps<<"'";
+ out << " running_stats_window_size='"<<item.running_stats_window_size<<"'";
+ out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
+ out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
+ out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
+ out << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
+ out << ">\n";
+
+ out << mat(item.params);
+
+ if (mode==CONV_MODE)
+ out << "</bn_con>\n";
+ else
+ out << "</bn_fc>\n";
+ }
+
+ private:
+
+ friend class affine_;
+
+ resizable_tensor params;
+ alias_tensor gamma, beta;
+ resizable_tensor means, running_means;
+ resizable_tensor invstds, running_variances;
+ unsigned long num_updates;
+ unsigned long running_stats_window_size;
+ double learning_rate_multiplier;
+ double weight_decay_multiplier;
+ double bias_learning_rate_multiplier;
+ double bias_weight_decay_multiplier;
+ double eps;
+ };
+
+ template <typename SUBNET>
+ using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
+ template <typename SUBNET>
+ using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class visitor_bn_running_stats_window_size
+ {
+ public:
+
+ visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}
+
+ template <typename T>
+ void set_window_size(T&) const
+ {
+ // ignore other layer detail types
+ }
+
+ template < layer_mode mode >
+ void set_window_size(bn_<mode>& l) const
+ {
+ l.set_running_stats_window_size(new_window_size);
+ }
+
+ template<typename input_layer_type>
+ void operator()(size_t , input_layer_type& ) const
+ {
+ // ignore other layers
+ }
+
+ template <typename T, typename U, typename E>
+ void operator()(size_t , add_layer<T,U,E>& l) const
+ {
+ set_window_size(l.layer_details());
+ }
+
+ private:
+
+ unsigned long new_window_size;
+ };
+ }
+
+ template <typename net_type>
+ void set_all_bn_running_stats_window_sizes (
+ net_type& net,
+ unsigned long new_window_size
+ )
+ {
+ visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum fc_bias_mode
+ {
+ FC_HAS_BIAS = 0,
+ FC_NO_BIAS = 1
+ };
+
+ struct num_fc_outputs
+ {
+ num_fc_outputs(unsigned long n) : num_outputs(n) {}
+ unsigned long num_outputs;
+ };
+
+ template <
+ unsigned long num_outputs_,
+ fc_bias_mode bias_mode
+ >
+ class fc_
+ {
+ static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");
+
+ public:
+ fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0),
+ learning_rate_multiplier(1),
+ weight_decay_multiplier(1),
+ bias_learning_rate_multiplier(1),
+ bias_weight_decay_multiplier(0)
+ {}
+
+ fc_() : fc_(num_fc_outputs(num_outputs_)) {}
+
+ double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+ double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+ void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+ void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+
+ double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+ double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+ void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+ void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
+
+ unsigned long get_num_outputs (
+ ) const { return num_outputs; }
+
+ void set_num_outputs(long num)
+ {
+ DLIB_CASSERT(num > 0);
+ if (num != (long)num_outputs)
+ {
+ DLIB_CASSERT(get_layer_params().size() == 0,
+ "You can't change the number of filters in fc_ if the parameter tensor has already been allocated.");
+ num_outputs = num;
+ }
+ }
+
+ fc_bias_mode get_bias_mode (
+ ) const { return bias_mode; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
+ if (bias_mode == FC_HAS_BIAS)
+ params.set_size(num_inputs+1, num_outputs);
+ else
+ params.set_size(num_inputs, num_outputs);
+
+ dlib::rand rnd(std::rand());
+ randomize_parameters(params, num_inputs+num_outputs, rnd);
+
+ weights = alias_tensor(num_inputs, num_outputs);
+
+ if (bias_mode == FC_HAS_BIAS)
+ {
+ biases = alias_tensor(1,num_outputs);
+ // set the initial bias values to zero
+ biases(params,weights.size()) = 0;
+ }
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ DLIB_CASSERT((long)num_inputs == sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(),
+ "The size of the input tensor to this fc layer doesn't match the size the fc layer was trained with.");
+ output.set_size(sub.get_output().num_samples(), num_outputs);
+
+ auto w = weights(params, 0);
+ tt::gemm(0,output, 1,sub.get_output(),false, w,false);
+ if (bias_mode == FC_HAS_BIAS)
+ {
+ auto b = biases(params, weights.size());
+ tt::add(1,output,1,b);
+ }
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
+ {
+ // no point computing the parameter gradients if they won't be used.
+ if (learning_rate_multiplier != 0)
+ {
+ // compute the gradient of the weight parameters.
+ auto pw = weights(params_grad, 0);
+ tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);
+
+ if (bias_mode == FC_HAS_BIAS)
+ {
+ // compute the gradient of the bias parameters.
+ auto pb = biases(params_grad, weights.size());
+ tt::assign_bias_gradient(pb, gradient_input);
+ }
+ }
+
+ // compute the gradient for the data
+ auto w = weights(params, 0);
+ tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true);
+ }
+
+ alias_tensor_instance get_weights()
+ {
+ return weights(params, 0);
+ }
+
+ alias_tensor_const_instance get_weights() const
+ {
+ return weights(params, 0);
+ }
+
+ alias_tensor_instance get_biases()
+ {
+ static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector "
+ "to be retrieved, as per template parameter 'bias_mode'.");
+ return biases(params, weights.size());
+ }
+
+ alias_tensor_const_instance get_biases() const
+ {
+ static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector "
+ "to be retrieved, as per template parameter 'bias_mode'.");
+ return biases(params, weights.size());
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const fc_& item, std::ostream& out)
+ {
+ serialize("fc_2", out);
+ serialize(item.num_outputs, out);
+ serialize(item.num_inputs, out);
+ serialize(item.params, out);
+ serialize(item.weights, out);
+ serialize(item.biases, out);
+ serialize((int)bias_mode, out);
+ serialize(item.learning_rate_multiplier, out);
+ serialize(item.weight_decay_multiplier, out);
+ serialize(item.bias_learning_rate_multiplier, out);
+ serialize(item.bias_weight_decay_multiplier, out);
+ }
+
+ friend void deserialize(fc_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "fc_2")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
+
+ deserialize(item.num_outputs, in);
+ deserialize(item.num_inputs, in);
+ deserialize(item.params, in);
+ deserialize(item.weights, in);
+ deserialize(item.biases, in);
+ int bmode = 0;
+ deserialize(bmode, in);
+ if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
+ deserialize(item.learning_rate_multiplier, in);
+ deserialize(item.weight_decay_multiplier, in);
+ deserialize(item.bias_learning_rate_multiplier, in);
+ deserialize(item.bias_weight_decay_multiplier, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const fc_& item)
+ {
+ if (bias_mode == FC_HAS_BIAS)
+ {
+ out << "fc\t ("
+ << "num_outputs="<<item.num_outputs
+ << ")";
+ out << " learning_rate_mult="<<item.learning_rate_multiplier;
+ out << " weight_decay_mult="<<item.weight_decay_multiplier;
+ out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+ out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
+ }
+ else
+ {
+ out << "fc_no_bias ("
+ << "num_outputs="<<item.num_outputs
+ << ")";
+ out << " learning_rate_mult="<<item.learning_rate_multiplier;
+ out << " weight_decay_mult="<<item.weight_decay_multiplier;
+ }
+ return out;
+ }
+
+ friend void to_xml(const fc_& item, std::ostream& out)
+ {
+ if (bias_mode==FC_HAS_BIAS)
+ {
+ out << "<fc"
+ << " num_outputs='"<<item.num_outputs<<"'"
+ << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
+ << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
+ << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
+ << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
+ out << ">\n";
+ out << mat(item.params);
+ out << "</fc>\n";
+ }
+ else
+ {
+ out << "<fc_no_bias"
+ << " num_outputs='"<<item.num_outputs<<"'"
+ << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
+ << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
+ out << ">\n";
+ out << mat(item.params);
+ out << "</fc_no_bias>\n";
+ }
+ }
+
+ private:
+
+ unsigned long num_outputs;
+ unsigned long num_inputs;
+ resizable_tensor params;
+ alias_tensor weights, biases;
+ double learning_rate_multiplier;
+ double weight_decay_multiplier;
+ double bias_learning_rate_multiplier;
+ double bias_weight_decay_multiplier;
+ };
+
+ template <
+ unsigned long num_outputs,
+ typename SUBNET
+ >
+ using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;
+
+ template <
+ unsigned long num_outputs,
+ typename SUBNET
+ >
+ using fc_no_bias = add_layer<fc_<num_outputs,FC_NO_BIAS>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class dropout_
+ {
+ public:
+ explicit dropout_(
+ float drop_rate_ = 0.5
+ ) :
+ drop_rate(drop_rate_),
+ rnd(std::rand())
+ {
+ DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1);
+ }
+
+ // We have to add a copy constructor and assignment operator because the rnd object
+ // is non-copyable.
+ dropout_(
+ const dropout_& item
+ ) : drop_rate(item.drop_rate), mask(item.mask), rnd(std::rand())
+ {}
+
+ dropout_& operator= (
+ const dropout_& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ drop_rate = item.drop_rate;
+ mask = item.mask;
+ return *this;
+ }
+
+ float get_drop_rate (
+ ) const { return drop_rate; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ // create a random mask and use it to filter the data
+ mask.copy_size(input);
+ rnd.fill_uniform(mask);
+ tt::threshold(mask, drop_rate);
+ tt::multiply(false, output, input, mask);
+ }
+
+ void backward_inplace(
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor& /*params_grad*/
+ )
+ {
+ if (is_same_object(gradient_input, data_grad))
+ tt::multiply(false, data_grad, mask, gradient_input);
+ else
+ tt::multiply(true, data_grad, mask, gradient_input);
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const dropout_& item, std::ostream& out)
+ {
+ serialize("dropout_", out);
+ serialize(item.drop_rate, out);
+ serialize(item.mask, out);
+ }
+
+ friend void deserialize(dropout_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "dropout_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
+ deserialize(item.drop_rate, in);
+ deserialize(item.mask, in);
+ }
+
+ void clean(
+ )
+ {
+ mask.clear();
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
+ {
+ out << "dropout\t ("
+ << "drop_rate="<<item.drop_rate
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const dropout_& item, std::ostream& out)
+ {
+ out << "<dropout"
+ << " drop_rate='"<<item.drop_rate<<"'";
+ out << "/>\n";
+ }
+
+ private:
+ float drop_rate;
+ resizable_tensor mask;
+
+ tt::tensor_rand rnd;
+ resizable_tensor params; // unused
+ };
+
+
+ template <typename SUBNET>
+ using dropout = add_layer<dropout_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class multiply_
+ {
+ public:
+ explicit multiply_(
+ float val_ = 0.5
+ ) :
+ val(val_)
+ {
+ }
+
+ multiply_ (
+ const dropout_& item
+ ) : val(1-item.get_drop_rate()) {}
+
+ float get_multiply_value (
+ ) const { return val; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::affine_transform(output, input, val);
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ void backward_inplace(
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor& /*params_grad*/
+ )
+ {
+ if (is_same_object(gradient_input, data_grad))
+ tt::affine_transform(data_grad, gradient_input, val);
+ else
+ tt::affine_transform(data_grad, data_grad, gradient_input, 1, val);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const multiply_& item, std::ostream& out)
+ {
+ serialize("multiply_", out);
+ serialize(item.val, out);
+ }
+
+ friend void deserialize(multiply_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version == "dropout_")
+ {
+ // Since we can build a multiply_ from a dropout_ we check if that's what
+ // is in the stream and if so then just convert it right here.
+ unserialize sin(version, in);
+ dropout_ temp;
+ deserialize(temp, sin);
+ item = temp;
+ return;
+ }
+
+ if (version != "multiply_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
+ deserialize(item.val, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const multiply_& item)
+ {
+ out << "multiply ("
+ << "val="<<item.val
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const multiply_& item, std::ostream& out)
+ {
+ out << "<multiply"
+ << " val='"<<item.val<<"'";
+ out << "/>\n";
+ }
+ private:
+ float val;
+ resizable_tensor params; // unused
+ };
+
+ template <typename SUBNET>
+ using multiply = add_layer<multiply_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class affine_
+ {
+ public:
+ affine_(
+ ) : mode(FC_MODE)
+ {
+ }
+
+ affine_(
+ layer_mode mode_
+ ) : mode(mode_)
+ {
+ }
+
+ template <
+ layer_mode bnmode
+ >
+ affine_(
+ const bn_<bnmode>& item
+ )
+ {
+ gamma = item.gamma;
+ beta = item.beta;
+ mode = bnmode;
+
+ params.copy_size(item.params);
+
+ auto g = gamma(params,0);
+ auto b = beta(params,gamma.size());
+
+ resizable_tensor temp(item.params);
+ auto sg = gamma(temp,0);
+ auto sb = beta(temp,gamma.size());
+
+ g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_variances)+item.get_eps()));
+ b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
+ }
+
+ layer_mode get_mode() const { return mode; }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ if (mode == FC_MODE)
+ {
+ gamma = alias_tensor(1,
+ sub.get_output().k(),
+ sub.get_output().nr(),
+ sub.get_output().nc());
+ }
+ else
+ {
+ gamma = alias_tensor(1, sub.get_output().k());
+ }
+ beta = gamma;
+
+ params.set_size(gamma.size()+beta.size());
+
+ gamma(params,0) = 1;
+ beta(params,gamma.size()) = 0;
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ auto g = gamma(params,0);
+ auto b = beta(params,gamma.size());
+ if (mode == FC_MODE)
+ tt::affine_transform(output, input, g, b);
+ else
+ tt::affine_transform_conv(output, input, g, b);
+ }
+
+ void backward_inplace(
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor& /*params_grad*/
+ )
+ {
+ auto g = gamma(params,0);
+ auto b = beta(params,gamma.size());
+
+ // We are computing the gradient of dot(gradient_input, computed_output*g + b)
+ if (mode == FC_MODE)
+ {
+ if (is_same_object(gradient_input, data_grad))
+ tt::multiply(false, data_grad, gradient_input, g);
+ else
+ tt::multiply(true, data_grad, gradient_input, g);
+ }
+ else
+ {
+ if (is_same_object(gradient_input, data_grad))
+ tt::multiply_conv(false, data_grad, gradient_input, g);
+ else
+ tt::multiply_conv(true, data_grad, gradient_input, g);
+ }
+ }
+
+ const tensor& get_layer_params() const { return empty_params; }
+ tensor& get_layer_params() { return empty_params; }
+
+ friend void serialize(const affine_& item, std::ostream& out)
+ {
+ serialize("affine_", out);
+ serialize(item.params, out);
+ serialize(item.gamma, out);
+ serialize(item.beta, out);
+ serialize((int)item.mode, out);
+ }
+
+ friend void deserialize(affine_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version == "bn_con2")
+ {
+ // Since we can build an affine_ from a bn_ we check if that's what is in
+ // the stream and if so then just convert it right here.
+ unserialize sin(version, in);
+ bn_<CONV_MODE> temp;
+ deserialize(temp, sin);
+ item = temp;
+ return;
+ }
+ else if (version == "bn_fc2")
+ {
+ // Since we can build an affine_ from a bn_ we check if that's what is in
+ // the stream and if so then just convert it right here.
+ unserialize sin(version, in);
+ bn_<FC_MODE> temp;
+ deserialize(temp, sin);
+ item = temp;
+ return;
+ }
+
+ if (version != "affine_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
+ deserialize(item.params, in);
+ deserialize(item.gamma, in);
+ deserialize(item.beta, in);
+ int mode;
+ deserialize(mode, in);
+ item.mode = (layer_mode)mode;
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const affine_& )
+ {
+ out << "affine";
+ return out;
+ }
+
+ friend void to_xml(const affine_& item, std::ostream& out)
+ {
+ if (item.mode==CONV_MODE)
+ out << "<affine_con>\n";
+ else
+ out << "<affine_fc>\n";
+
+ out << mat(item.params);
+
+ if (item.mode==CONV_MODE)
+ out << "</affine_con>\n";
+ else
+ out << "</affine_fc>\n";
+ }
+
+ private:
+ resizable_tensor params, empty_params;
+ alias_tensor gamma, beta;
+ layer_mode mode;
+ };
+
+ template <typename SUBNET>
+ using affine = add_layer<affine_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class add_prev_
+ {
+ public:
+ const static unsigned long id = tag_id<tag>::id;
+
+ add_prev_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ auto&& t1 = sub.get_output();
+ auto&& t2 = layer<tag>(sub).get_output();
+ output.set_size(std::max(t1.num_samples(),t2.num_samples()),
+ std::max(t1.k(),t2.k()),
+ std::max(t1.nr(),t2.nr()),
+ std::max(t1.nc(),t2.nc()));
+ tt::add(output, t1, t2);
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ // The gradient just flows backwards to the two layers that forward() added
+ // together.
+ tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input);
+ tt::add(layer<tag>(sub).get_gradient_input(), layer<tag>(sub).get_gradient_input(), gradient_input);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ friend void serialize(const add_prev_& , std::ostream& out)
+ {
+ serialize("add_prev_", out);
+ }
+
+ friend void deserialize(add_prev_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "add_prev_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const add_prev_& item)
+ {
+ out << "add_prev"<<id;
+ return out;
+ }
+
+ friend void to_xml(const add_prev_& item, std::ostream& out)
+ {
+ out << "<add_prev tag='"<<id<<"'/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using add_prev = add_layer<add_prev_<tag>, SUBNET>;
+
+ template <typename SUBNET> using add_prev1 = add_prev<tag1, SUBNET>;
+ template <typename SUBNET> using add_prev2 = add_prev<tag2, SUBNET>;
+ template <typename SUBNET> using add_prev3 = add_prev<tag3, SUBNET>;
+ template <typename SUBNET> using add_prev4 = add_prev<tag4, SUBNET>;
+ template <typename SUBNET> using add_prev5 = add_prev<tag5, SUBNET>;
+ template <typename SUBNET> using add_prev6 = add_prev<tag6, SUBNET>;
+ template <typename SUBNET> using add_prev7 = add_prev<tag7, SUBNET>;
+ template <typename SUBNET> using add_prev8 = add_prev<tag8, SUBNET>;
+ template <typename SUBNET> using add_prev9 = add_prev<tag9, SUBNET>;
+ template <typename SUBNET> using add_prev10 = add_prev<tag10, SUBNET>;
+
+ using add_prev1_ = add_prev_<tag1>;
+ using add_prev2_ = add_prev_<tag2>;
+ using add_prev3_ = add_prev_<tag3>;
+ using add_prev4_ = add_prev_<tag4>;
+ using add_prev5_ = add_prev_<tag5>;
+ using add_prev6_ = add_prev_<tag6>;
+ using add_prev7_ = add_prev_<tag7>;
+ using add_prev8_ = add_prev_<tag8>;
+ using add_prev9_ = add_prev_<tag9>;
+ using add_prev10_ = add_prev_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class mult_prev_
+ {
+ public:
+ const static unsigned long id = tag_id<tag>::id;
+
+ mult_prev_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ auto&& t1 = sub.get_output();
+ auto&& t2 = layer<tag>(sub).get_output();
+ output.set_size(std::max(t1.num_samples(),t2.num_samples()),
+ std::max(t1.k(),t2.k()),
+ std::max(t1.nr(),t2.nr()),
+ std::max(t1.nc(),t2.nc()));
+ tt::multiply_zero_padded(false, output, t1, t2);
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ auto&& t1 = sub.get_output();
+ auto&& t2 = layer<tag>(sub).get_output();
+ // The gradient just flows backwards to the two layers that forward()
+ // multiplied together.
+ tt::multiply_zero_padded(true, sub.get_gradient_input(), t2, gradient_input);
+ tt::multiply_zero_padded(true, layer<tag>(sub).get_gradient_input(), t1, gradient_input);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const mult_prev_& , std::ostream& out)
+ {
+ serialize("mult_prev_", out);
+ }
+
+ friend void deserialize(mult_prev_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "mult_prev_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mult_prev_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const mult_prev_& item)
+ {
+ out << "mult_prev"<<id;
+ return out;
+ }
+
+ friend void to_xml(const mult_prev_& item, std::ostream& out)
+ {
+ out << "<mult_prev tag='"<<id<<"'/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using mult_prev = add_layer<mult_prev_<tag>, SUBNET>;
+
+ template <typename SUBNET> using mult_prev1 = mult_prev<tag1, SUBNET>;
+ template <typename SUBNET> using mult_prev2 = mult_prev<tag2, SUBNET>;
+ template <typename SUBNET> using mult_prev3 = mult_prev<tag3, SUBNET>;
+ template <typename SUBNET> using mult_prev4 = mult_prev<tag4, SUBNET>;
+ template <typename SUBNET> using mult_prev5 = mult_prev<tag5, SUBNET>;
+ template <typename SUBNET> using mult_prev6 = mult_prev<tag6, SUBNET>;
+ template <typename SUBNET> using mult_prev7 = mult_prev<tag7, SUBNET>;
+ template <typename SUBNET> using mult_prev8 = mult_prev<tag8, SUBNET>;
+ template <typename SUBNET> using mult_prev9 = mult_prev<tag9, SUBNET>;
+ template <typename SUBNET> using mult_prev10 = mult_prev<tag10, SUBNET>;
+
+ using mult_prev1_ = mult_prev_<tag1>;
+ using mult_prev2_ = mult_prev_<tag2>;
+ using mult_prev3_ = mult_prev_<tag3>;
+ using mult_prev4_ = mult_prev_<tag4>;
+ using mult_prev5_ = mult_prev_<tag5>;
+ using mult_prev6_ = mult_prev_<tag6>;
+ using mult_prev7_ = mult_prev_<tag7>;
+ using mult_prev8_ = mult_prev_<tag8>;
+ using mult_prev9_ = mult_prev_<tag9>;
+ using mult_prev10_ = mult_prev_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class scale_
+ {
+ public:
+ const static unsigned long id = tag_id<tag>::id;
+
+ scale_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ auto&& scales = sub.get_output();
+ auto&& src = layer<tag>(sub).get_output();
+ DLIB_CASSERT(scales.num_samples() == src.num_samples() &&
+ scales.k() == src.k() &&
+ scales.nr() == 1 &&
+ scales.nc() == 1,
+ "scales.k(): " << scales.k() <<
+ "\nsrc.k(): " << src.k()
+ );
+
+ output.copy_size(src);
+ tt::scale_channels(false, output, src, scales);
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ auto&& scales = sub.get_output();
+ auto&& src = layer<tag>(sub).get_output();
+ // The gradient just flows backwards to the two layers that forward()
+ // read from.
+ tt::scale_channels(true, layer<tag>(sub).get_gradient_input(), gradient_input, scales);
+
+ if (reshape_src.num_samples() != src.num_samples())
+ {
+ reshape_scales = alias_tensor(src.num_samples()*src.k());
+ reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
+ }
+
+ auto&& scales_grad = sub.get_gradient_input();
+ auto sgrad = reshape_scales(scales_grad);
+ tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const scale_& item, std::ostream& out)
+ {
+ serialize("scale_", out);
+ serialize(item.reshape_scales, out);
+ serialize(item.reshape_src, out);
+ }
+
+ friend void deserialize(scale_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "scale_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_.");
+ deserialize(item.reshape_scales, in);
+ deserialize(item.reshape_src, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const scale_& item)
+ {
+ out << "scale"<<id;
+ return out;
+ }
+
+ friend void to_xml(const scale_& item, std::ostream& out)
+ {
+ out << "<scale tag='"<<id<<"'/>\n";
+ }
+
+ private:
+ alias_tensor reshape_scales;
+ alias_tensor reshape_src;
+ resizable_tensor params;
+ };
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using scale = add_layer<scale_<tag>, SUBNET>;
+
+ template <typename SUBNET> using scale1 = scale<tag1, SUBNET>;
+ template <typename SUBNET> using scale2 = scale<tag2, SUBNET>;
+ template <typename SUBNET> using scale3 = scale<tag3, SUBNET>;
+ template <typename SUBNET> using scale4 = scale<tag4, SUBNET>;
+ template <typename SUBNET> using scale5 = scale<tag5, SUBNET>;
+ template <typename SUBNET> using scale6 = scale<tag6, SUBNET>;
+ template <typename SUBNET> using scale7 = scale<tag7, SUBNET>;
+ template <typename SUBNET> using scale8 = scale<tag8, SUBNET>;
+ template <typename SUBNET> using scale9 = scale<tag9, SUBNET>;
+ template <typename SUBNET> using scale10 = scale<tag10, SUBNET>;
+
+ using scale1_ = scale_<tag1>;
+ using scale2_ = scale_<tag2>;
+ using scale3_ = scale_<tag3>;
+ using scale4_ = scale_<tag4>;
+ using scale5_ = scale_<tag5>;
+ using scale6_ = scale_<tag6>;
+ using scale7_ = scale_<tag7>;
+ using scale8_ = scale_<tag8>;
+ using scale9_ = scale_<tag9>;
+ using scale10_ = scale_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ class relu_
+ {
+ public:
+ relu_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::relu(output, input);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor&
+ )
+ {
+ tt::relu_gradient(data_grad, computed_output, gradient_input);
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const relu_& , std::ostream& out)
+ {
+ serialize("relu_", out);
+ }
+
+ friend void deserialize(relu_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "relu_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const relu_& )
+ {
+ out << "relu";
+ return out;
+ }
+
+ friend void to_xml(const relu_& /*item*/, std::ostream& out)
+ {
+ out << "<relu/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+
+ template <typename SUBNET>
+ using relu = add_layer<relu_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class prelu_
+ {
+ public:
+ explicit prelu_(
+ float initial_param_value_ = 0.25
+ ) : initial_param_value(initial_param_value_)
+ {
+ }
+
+ float get_initial_param_value (
+ ) const { return initial_param_value; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ params.set_size(1);
+ params = initial_param_value;
+ }
+
+ template <typename SUBNET>
+ void forward(
+ const SUBNET& sub,
+ resizable_tensor& data_output
+ )
+ {
+ data_output.copy_size(sub.get_output());
+ tt::prelu(data_output, sub.get_output(), params);
+ }
+
+ template <typename SUBNET>
+ void backward(
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ )
+ {
+ tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(),
+ gradient_input, params, params_grad);
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const prelu_& item, std::ostream& out)
+ {
+ serialize("prelu_", out);
+ serialize(item.params, out);
+ serialize(item.initial_param_value, out);
+ }
+
+ friend void deserialize(prelu_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "prelu_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
+ deserialize(item.params, in);
+ deserialize(item.initial_param_value, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const prelu_& item)
+ {
+ out << "prelu\t ("
+ << "initial_param_value="<<item.initial_param_value
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const prelu_& item, std::ostream& out)
+ {
+ out << "<prelu initial_param_value='"<<item.initial_param_value<<"'>\n";
+ out << mat(item.params);
+ out << "</prelu>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ float initial_param_value;
+ };
+
+ template <typename SUBNET>
+ using prelu = add_layer<prelu_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class sig_
+ {
+ public:
+ sig_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::sigmoid(output, input);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor&
+ )
+ {
+ tt::sigmoid_gradient(data_grad, computed_output, gradient_input);
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const sig_& , std::ostream& out)
+ {
+ serialize("sig_", out);
+ }
+
+ friend void deserialize(sig_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "sig_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const sig_& )
+ {
+ out << "sig";
+ return out;
+ }
+
+ friend void to_xml(const sig_& /*item*/, std::ostream& out)
+ {
+ out << "<sig/>\n";
+ }
+
+
+ private:
+ resizable_tensor params;
+ };
+
+
+ template <typename SUBNET>
+ using sig = add_layer<sig_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class htan_
+ {
+ public:
+ htan_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+ inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::tanh(output, input);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor&
+ )
+ {
+ tt::tanh_gradient(data_grad, computed_output, gradient_input);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const htan_& , std::ostream& out)
+ {
+ serialize("htan_", out);
+ }
+
+ friend void deserialize(htan_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "htan_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const htan_& )
+ {
+ out << "htan";
+ return out;
+ }
+
+ friend void to_xml(const htan_& /*item*/, std::ostream& out)
+ {
+ out << "<htan/>\n";
+ }
+
+
+ private:
+ resizable_tensor params;
+ };
+
+
+ template <typename SUBNET>
+ using htan = add_layer<htan_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class softmax_
+ {
+ public:
+ softmax_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::softmax(output, input);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor&
+ )
+ {
+ tt::softmax_gradient(data_grad, computed_output, gradient_input);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const softmax_& , std::ostream& out)
+ {
+ serialize("softmax_", out);
+ }
+
+ friend void deserialize(softmax_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "softmax_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const softmax_& )
+ {
+ out << "softmax";
+ return out;
+ }
+
+ friend void to_xml(const softmax_& /*item*/, std::ostream& out)
+ {
+ out << "<softmax/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+ template <typename SUBNET>
+ using softmax = add_layer<softmax_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class softmax_all_
+ {
+ public:
+ softmax_all_()
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::softmax_all(output, input);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor&
+ )
+ {
+ tt::softmax_all_gradient(data_grad, computed_output, gradient_input);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const softmax_all_& , std::ostream& out)
+ {
+ serialize("softmax_all_", out);
+ }
+
+ friend void deserialize(softmax_all_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "softmax_all_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_all_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const softmax_all_& )
+ {
+ out << "softmax_all";
+ return out;
+ }
+
+ friend void to_xml(const softmax_all_& /*item*/, std::ostream& out)
+ {
+ out << "<softmax_all/>\n";
+ }
+
+ private:
+ resizable_tensor params;
+ };
+
+ template <typename SUBNET>
+ using softmax_all = add_layer<softmax_all_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
+ struct concat_helper_impl{
+
+ constexpr static size_t tag_count() {return 1 + concat_helper_impl<TAG_TYPES...>::tag_count();}
+ static void list_tags(std::ostream& out)
+ {
+ out << tag_id<TAG_TYPE>::id << (tag_count() > 1 ? "," : "");
+ concat_helper_impl<TAG_TYPES...>::list_tags(out);
+ }
+
+ template<typename SUBNET>
+ static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_output();
+ concat_helper_impl<TAG_TYPES...>::resize_out(out, sub, sum_k + t.k());
+ }
+ template<typename SUBNET>
+ static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_output();
+ tt::copy_tensor(false, out, k_offset, t, 0, t.k());
+ k_offset += t.k();
+ concat_helper_impl<TAG_TYPES...>::concat(out, sub, k_offset);
+ }
+ template<typename SUBNET>
+ static void split(const tensor& input, SUBNET& sub, size_t k_offset)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
+ tt::copy_tensor(true, t, 0, input, k_offset, t.k());
+ k_offset += t.k();
+ concat_helper_impl<TAG_TYPES...>::split(input, sub, k_offset);
+ }
+ };
+ template <template<typename> class TAG_TYPE>
+ struct concat_helper_impl<TAG_TYPE>{
+ constexpr static size_t tag_count() {return 1;}
+ static void list_tags(std::ostream& out)
+ {
+ out << tag_id<TAG_TYPE>::id;
+ }
+
+ template<typename SUBNET>
+ static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_output();
+ out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc());
+ }
+ template<typename SUBNET>
+ static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_output();
+ tt::copy_tensor(false, out, k_offset, t, 0, t.k());
+ }
+ template<typename SUBNET>
+ static void split(const tensor& input, SUBNET& sub, size_t k_offset)
+ {
+ auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
+ tt::copy_tensor(true, t, 0, input, k_offset, t.k());
+ }
+ };
+ }
+ // concat layer
+ template<
+ template<typename> class... TAG_TYPES
+ >
+ class concat_
+ {
+ static void list_tags(std::ostream& out) { impl::concat_helper_impl<TAG_TYPES...>::list_tags(out);};
+
+ public:
+ constexpr static size_t tag_count() {return impl::concat_helper_impl<TAG_TYPES...>::tag_count();};
+
+ template <typename SUBNET>
+ void setup (const SUBNET&)
+ {
+ // do nothing
+ }
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ // the total depth of result is the sum of depths from all tags
+ impl::concat_helper_impl<TAG_TYPES...>::resize_out(output, sub, 0);
+
+ // copy output from each tag into different part result
+ impl::concat_helper_impl<TAG_TYPES...>::concat(output, sub, 0);
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor&)
+ {
+ // Gradient is split into parts for each tag layer
+ impl::concat_helper_impl<TAG_TYPES...>::split(gradient_input, sub, 0);
+ }
+
+ dpoint map_input_to_output(dpoint p) const { return p; }
+ dpoint map_output_to_input(dpoint p) const { return p; }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const concat_& item, std::ostream& out)
+ {
+ serialize("concat_", out);
+ size_t count = tag_count();
+ serialize(count, out);
+ }
+
+ friend void deserialize(concat_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "concat_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
+ size_t count_tags;
+ deserialize(count_tags, in);
+ if (count_tags != tag_count())
+ throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " +
+ std::to_string(tag_count()) +
+ " found while deserializing dlib::concat_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const concat_& item)
+ {
+ out << "concat\t (";
+ list_tags(out);
+ out << ")";
+ return out;
+ }
+
+ friend void to_xml(const concat_& item, std::ostream& out)
+ {
+ out << "<concat tags='";
+ list_tags(out);
+ out << "'/>\n";
+ }
+
+ private:
+ resizable_tensor params; // unused
+ };
+
+
+ // concat layer definitions
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ typename SUBNET>
+ using concat2 = add_layer<concat_<TAG1, TAG2>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ typename SUBNET>
+ using concat3 = add_layer<concat_<TAG1, TAG2, TAG3>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ template<typename> class TAG4,
+ typename SUBNET>
+ using concat4 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ template<typename> class TAG4,
+ template<typename> class TAG5,
+ typename SUBNET>
+ using concat5 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4, TAG5>, SUBNET>;
+
+ // inception layer will use tags internally. If user will use tags too, some conflicts
+ // possible to exclude them, here are new tags specially for inceptions
+ template <typename SUBNET> using itag0 = add_tag_layer< 1000 + 0, SUBNET>;
+ template <typename SUBNET> using itag1 = add_tag_layer< 1000 + 1, SUBNET>;
+ template <typename SUBNET> using itag2 = add_tag_layer< 1000 + 2, SUBNET>;
+ template <typename SUBNET> using itag3 = add_tag_layer< 1000 + 3, SUBNET>;
+ template <typename SUBNET> using itag4 = add_tag_layer< 1000 + 4, SUBNET>;
+ template <typename SUBNET> using itag5 = add_tag_layer< 1000 + 5, SUBNET>;
+ // skip to inception input
+ template <typename SUBNET> using iskip = add_skip_layer< itag0, SUBNET>;
+
+ // here are some templates to be used for creating inception layer groups
+ template <template<typename>class B1,
+ template<typename>class B2,
+ typename SUBNET>
+ using inception2 = concat2<itag1, itag2, itag1<B1<iskip< itag2<B2< itag0<SUBNET>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ typename SUBNET>
+ using inception3 = concat3<itag1, itag2, itag3, itag1<B1<iskip< itag2<B2<iskip< itag3<B3< itag0<SUBNET>>>>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ template<typename>class B4,
+ typename SUBNET>
+ using inception4 = concat4<itag1, itag2, itag3, itag4,
+ itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip< itag4<B4< itag0<SUBNET>>>>>>>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ template<typename>class B4,
+ template<typename>class B5,
+ typename SUBNET>
+ using inception5 = concat5<itag1, itag2, itag3, itag4, itag5,
+ itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip< itag4<B4<iskip< itag5<B5< itag0<SUBNET>>>>>>>>>>>>>>>>;
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const double DEFAULT_L2_NORM_EPS = 1e-5;
+
+ class l2normalize_
+ {
+ public:
+ explicit l2normalize_(
+ double eps_ = DEFAULT_L2_NORM_EPS
+ ) :
+ eps(eps_)
+ {
+ }
+
+ double get_eps() const { return eps; }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& /*sub*/)
+ {
+ }
+
+ void forward_inplace(const tensor& input, tensor& output)
+ {
+ tt::inverse_norms(norm, input, eps);
+ tt::scale_rows(output, input, norm);
+ }
+
+ void backward_inplace(
+ const tensor& computed_output,
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor& /*params_grad*/
+ )
+ {
+ if (is_same_object(gradient_input, data_grad))
+ {
+ tt::dot_prods(temp, gradient_input, computed_output);
+ tt::scale_rows2(0, data_grad, gradient_input, computed_output, temp, norm);
+ }
+ else
+ {
+ tt::dot_prods(temp, gradient_input, computed_output);
+ tt::scale_rows2(1, data_grad, gradient_input, computed_output, temp, norm);
+ }
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const l2normalize_& item, std::ostream& out)
+ {
+ serialize("l2normalize_", out);
+ serialize(item.eps, out);
+ }
+
+ friend void deserialize(l2normalize_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "l2normalize_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::l2normalize_.");
+ deserialize(item.eps, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const l2normalize_& item)
+ {
+ out << "l2normalize";
+ out << " eps="<<item.eps;
+ return out;
+ }
+
+ friend void to_xml(const l2normalize_& item, std::ostream& out)
+ {
+ out << "<l2normalize";
+ out << " eps='"<<item.eps<<"'";
+ out << "/>\n";
+ }
+ private:
+ double eps;
+
+ resizable_tensor params; // unused
+ // Here only to avoid reallocation and as a cache between forward/backward
+ // functions.
+ resizable_tensor norm;
+ resizable_tensor temp;
+ };
+
+ template <typename SUBNET>
+ using l2normalize = add_layer<l2normalize_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _offset,
+ long _k,
+ long _nr,
+ long _nc
+ >
+ class extract_
+ {
+ static_assert(_offset >= 0, "The offset must be >= 0.");
+ static_assert(_k > 0, "The number of channels must be > 0.");
+ static_assert(_nr > 0, "The number of rows must be > 0.");
+ static_assert(_nc > 0, "The number of columns must be > 0.");
+ public:
+ extract_(
+ )
+ {
+ }
+
+ template <typename SUBNET>
+ void setup (const SUBNET& sub)
+ {
+ DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset+_k*_nr*_nc),
+ "The tensor we are trying to extract from the input tensor is too big to fit into the input tensor.");
+
+ aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc);
+ ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples());
+ }
+
+ template <typename SUBNET>
+ void forward(const SUBNET& sub, resizable_tensor& output)
+ {
+ if (aout.num_samples() != sub.get_output().num_samples())
+ {
+ aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc);
+ ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples());
+ }
+
+ output.set_size(sub.get_output().num_samples(), _k, _nr, _nc);
+ auto out = aout(output,0);
+ auto in = ain(sub.get_output(),0);
+ tt::copy_tensor(false, out, 0, in, _offset, _k*_nr*_nc);
+ }
+
+ template <typename SUBNET>
+ void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+ {
+ auto out = ain(sub.get_gradient_input(),0);
+ auto in = aout(gradient_input,0);
+ tt::copy_tensor(true, out, _offset, in, 0, _k*_nr*_nc);
+ }
+
+ const tensor& get_layer_params() const { return params; }
+ tensor& get_layer_params() { return params; }
+
+ friend void serialize(const extract_& item, std::ostream& out)
+ {
+ serialize("extract_", out);
+ serialize(_offset, out);
+ serialize(_k, out);
+ serialize(_nr, out);
+ serialize(_nc, out);
+ }
+
+ friend void deserialize(extract_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "extract_")
+ throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::extract_.");
+
+ long offset;
+ long k;
+ long nr;
+ long nc;
+ deserialize(offset, in);
+ deserialize(k, in);
+ deserialize(nr, in);
+ deserialize(nc, in);
+
+ if (offset != _offset) throw serialization_error("Wrong offset found while deserializing dlib::extract_");
+ if (k != _k) throw serialization_error("Wrong k found while deserializing dlib::extract_");
+ if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::extract_");
+ if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::extract_");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const extract_& item)
+ {
+ out << "extract\t ("
+ << "offset="<<_offset
+ << ", k="<<_k
+ << ", nr="<<_nr
+ << ", nc="<<_nc
+ << ")";
+ return out;
+ }
+
+ friend void to_xml(const extract_& item, std::ostream& out)
+ {
+ out << "<extract";
+ out << " offset='"<<_offset<<"'";
+ out << " k='"<<_k<<"'";
+ out << " nr='"<<_nr<<"'";
+ out << " nc='"<<_nc<<"'";
+ out << "/>\n";
+ }
+ private:
+ alias_tensor aout, ain;
+
+ resizable_tensor params; // unused
+ };
+
+ template <
+ long offset,
+ long k,
+ long nr,
+ long nc,
+ typename SUBNET
+ >
+ using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_LAYERS_H_
+
+
diff --git a/ml/dlib/dlib/dnn/layers_abstract.h b/ml/dlib/dlib/dnn/layers_abstract.h
new file mode 100644
index 000000000..f07025ff8
--- /dev/null
+++ b/ml/dlib/dlib/dnn/layers_abstract.h
@@ -0,0 +1,2631 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_LAYERS_ABSTRACT_H_
+#ifdef DLIB_DNn_LAYERS_ABSTRACT_H_
+
+#include "tensor_abstract.h"
+#include "core_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class SUBNET
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a deep neural network. In particular, it is
+ the simplified interface through which layer objects interact with their
+ subnetworks. A layer's two important tasks are to (1) take outputs from its
+ subnetwork and forward propagate them through itself and (2) to backwards
+ propagate an error gradient through itself and onto its subnetwork.
+ The idea of a subnetwork is illustrated in the following diagram:
+
+ +---------------------------------------------------------+
+ | loss <-- layer1 <-- layer2 <-- ... <-- layern <-- input |
+ +---------------------------------------------------------+
+ ^ ^
+ \__ subnetwork for layer1 __/
+
+ Therefore, by "subnetwork" we mean the part of the network closer to the
+ input.
+
+ Note that there is no dlib::SUBNET type. It is shown here purely to
+ document the interface layer objects expect to see when they interact
+ with a network.
+ !*/
+
+ public:
+ // You aren't allowed to copy subnetworks from inside a layer.
+ SUBNET(const SUBNET&) = delete;
+ SUBNET& operator=(const SUBNET&) = delete;
+
+ const tensor& get_output(
+ ) const;
+ /*!
+ ensures
+ - returns the output of this subnetwork. This is the data that the next
+ layer in the network will take as input.
+ - have_same_dimensions(#get_gradient_input(), get_output()) == true
+ !*/
+
+ tensor& get_gradient_input(
+ );
+ /*!
+ ensures
+ - returns the error gradient for this subnetwork. That is, this is the
+ error gradient that this network will use to update itself. Therefore,
+ when performing back propagation, layers that sit on top of this
+ subnetwork write their back propagated error gradients into
+ get_gradient_input(). Or to put it another way, during back propagation,
+ layers take the contents of their get_gradient_input() and back propagate
+ it through themselves and store the results into their subnetwork's
+ get_gradient_input().
+ !*/
+
+ const NEXT_SUBNET& subnet(
+ ) const;
+ /*!
+ ensures
+ - returns the subnetwork of *this network. With respect to the diagram
+ above, if *this was layer1 then subnet() would return the network that
+ begins with layer2.
+ !*/
+
+ NEXT_SUBNET& subnet(
+ );
+ /*!
+ ensures
+ - returns the subnetwork of *this network. With respect to the diagram
+ above, if *this was layer1 then subnet() would return the network that
+ begins with layer2.
+ !*/
+
+ const layer_details_type& layer_details(
+ ) const;
+ /*!
+ ensures
+ - returns the layer_details_type instance that defines the behavior of the
+ layer at the top of this network. I.e. returns the layer details that
+ defines the behavior of the layer nearest to the network output rather
+ than the input layer. For computational layers, this is the object
+ implementing the EXAMPLE_COMPUTATIONAL_LAYER_ interface that defines the
+ layer's behavior.
+ !*/
+
+ unsigned int sample_expansion_factor (
+ ) const;
+ /*!
+ ensures
+ - When to_tensor() is invoked on this network's input layer it converts N
+ input objects into M samples, all stored inside a resizable_tensor. It
+ is always the case that M is some integer multiple of N.
+ sample_expansion_factor() returns the value of this multiplier. To be
+ very specific, it is always true that M==I*N where I is some integer.
+ This integer I is what is returned by sample_expansion_factor().
+
+ It should be noted that computational layers likely do not care about the
+ sample expansion factor. It is only really of concern inside a loss
+ layer where you need to know its value so that tensor samples can be
+ matched against truth objects. Moreover, in most cases the sample
+ expansion factor is 1.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class EXAMPLE_COMPUTATIONAL_LAYER_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ Each computational layer in a deep neural network can be thought of as a
+ function, f(data,parameters), that takes in a data tensor, some parameters,
+ and produces an output tensor. You create an entire deep network by
+ composing these functions. Importantly, you are able to use a wide range
+ of different functions to accommodate the task you are trying to
+ accomplish. Therefore, dlib includes a number of common layer types but if
+ you want to define your own then you simply implement a class with the same
+ interface as EXAMPLE_COMPUTATIONAL_LAYER_.
+
+ Note that there is no dlib::EXAMPLE_COMPUTATIONAL_LAYER_ type. It is shown
+ here purely to document the interface that a layer object must implement.
+
+ The central work of defining a layer is implementing the forward and backward
+ methods. When you do this you have four options:
+ - Implement the forward() and backward() methods according to the
+ specification shown below. Do not implement forward_inplace() and
+ backward_inplace().
+ - Implement the forward() and backward() methods according to the
+ specification shown below, except exclude the computed_output
+ parameter from backward(). Doing this will allow dlib to make some
+ layers execute in-place and therefore run a little faster and use
+ less memory. Do not implement forward_inplace() and
+ backward_inplace().
+ - Implement the forward_inplace() and backward_inplace() methods
+ according to the specification shown below. Do not implement
+ forward() and backward(). These in-place methods allow some types of
+ layers to be implemented more efficiently.
+ - Implement the forward_inplace() and backward_inplace() methods
+ according to the specification shown below, except exclude the
+ computed_output parameter from backward_inplace(). Doing this will
+ allow dlib to make some layers execute in-place and therefore run a
+ little faster and use less memory. Do not implement forward() and
+ backward().
+
+
+ It should also be noted that layers may define additional layer specific
+ fields and the solvers can use these fields as they see fit. For example,
+ some layers define get_learning_rate_multiplier() and
+ get_weight_decay_multiplier() methods. The solvers that come with dlib
+ look at these methods, if they exist, and adjust the learning rate or
+ weight decay for that layer according to the multiplier. Therefore, you
+ can add these methods to your layer types if you want, or even define new
+ fields and new solvers that use those fields in some way.
+ !*/
+
+ public:
+
+ EXAMPLE_COMPUTATIONAL_LAYER_(
+ );
+ /*!
+ ensures
+ - Default constructs this object. This function is not required to do
+ anything in particular but it must exist, that is, it is required that
+ layer objects be default constructable.
+ !*/
+
+ EXAMPLE_COMPUTATIONAL_LAYER_ (
+ const EXAMPLE_COMPUTATIONAL_LAYER_& item
+ );
+ /*!
+ ensures
+ - EXAMPLE_COMPUTATIONAL_LAYER_ objects are copy constructable
+ !*/
+
+ EXAMPLE_COMPUTATIONAL_LAYER_(
+ const some_other_layer_type& item
+ );
+ /*!
+ ensures
+ - Constructs this object from item. This form of constructor is optional
+ but it allows you to provide a conversion from one layer type to another.
+ For example, the following code is valid only if my_layer2 can be
+ constructed from my_layer1:
+ relu<fc<my_layer1<fc<input<matrix<float>>>>>> my_dnn1;
+ relu<fc<my_layer2<fc<input<matrix<float>>>>>> my_dnn2(my_dnn1);
+ This kind of pattern is useful if you want to use one type of layer
+ during training but a different type of layer during testing since it
+ allows you to easily convert between related deep neural network types.
+
+ Additionally, if you provide a constructor to build a layer from another
+ layer type you should also write your layer's deserialize() routine such
+ that it can read that other layer's serialized data in addition to your
+ own serialized data.
+ !*/
+
+ template <typename SUBNET>
+ void setup (
+ const SUBNET& sub
+ );
+ /*!
+ requires
+ - SUBNET implements the SUBNET interface defined at the top of this file.
+ ensures
+ - performs any necessary initial memory allocations and/or sets parameters
+ to their initial values prior to learning. Therefore, calling setup
+ destroys any previously learned parameters. Also, typically setup()
+ would look at the dimensions of the outputs of sub and configure the
+ number of parameters in *this accordingly.
+ !*/
+
+ template <typename SUBNET>
+ void forward(
+ const SUBNET& sub,
+ resizable_tensor& data_output
+ );
+ /*!
+ requires
+ - SUBNET implements the SUBNET interface defined at the top of this file.
+ - setup() has been called.
+ ensures
+ - Runs the output of the subnetwork through this layer and stores the
+ results into #data_output. In particular, forward() can use any of the
+ outputs in sub (e.g. sub.get_output(), sub.subnet().get_output(), etc.)
+ to compute whatever it wants.
+ !*/
+
+ template <typename SUBNET>
+ void backward(
+ const tensor& computed_output, // this parameter is optional
+ const tensor& gradient_input,
+ SUBNET& sub,
+ tensor& params_grad
+ );
+ /*!
+ requires
+ - SUBNET implements the SUBNET interface defined at the top of this file.
+ - setup() has been called.
+ - computed_output is the tensor resulting from calling forward(sub,computed_output).
+ Moreover, this was the most recent call to forward(). This means that
+ forward() is allowed to cache intermediate results so they can be used
+ during the backward computation.
+ - have_same_dimensions(gradient_input, computed_output) == true
+ - have_same_dimensions(sub.get_gradient_input(), sub.get_output()) == true
+ - have_same_dimensions(params_grad, get_layer_params()) == true
+ ensures
+ - This function outputs the gradients of this layer with respect to the
+ input data from sub and also with respect to this layer's parameters.
+ These gradients are stored into #sub and #params_grad, respectively. To be
+ precise, the gradients are taken of a function f(sub,get_layer_params())
+ which is defined thusly:
+ - Recalling that computed_output is a function of both sub and get_layer_params(),
+ since it is the result of calling forward(sub,computed_output):
+ let f(sub,get_layer_params()) == dot(computed_output, gradient_input)
+ Then we define the following gradient vectors:
+ - PARAMETER_GRADIENT == gradient of f(sub,get_layer_params()) with
+ respect to get_layer_params().
+ - for all valid I:
+ - DATA_GRADIENT_I == gradient of f(sub,get_layer_params()) with
+ respect to layer<I>(sub).get_output() (recall that forward() can
+ draw inputs from the immediate sub layer, sub.subnet(), or
+ any earlier layer. So you must consider the gradients with
+ respect to all inputs drawn from sub)
+ Finally, backward() outputs these gradients by performing:
+ - params_grad = PARAMETER_GRADIENT
+ - for all valid I:
+ - layer<I>(sub).get_gradient_input() += DATA_GRADIENT_I
+ !*/
+
+ void forward_inplace(
+ const tensor& data_input,
+ tensor& data_output
+ );
+ /*!
+ requires
+ - have_same_dimensions(data_input,data_output) == true
+ - setup() has been called.
+ ensures
+ - Runs the data_input tensor through this layer and stores the output into
+ #data_output.
+ - This function supports in-place operation, i.e. having
+ is_same_object(data_input, data_output)==true
+ !*/
+
+ void backward_inplace(
+ const tensor& computed_output, // this parameter is optional
+ const tensor& gradient_input,
+ tensor& data_grad,
+ tensor& params_grad
+ );
+ /*!
+ requires
+ - setup() has been called.
+ - computed_output is the tensor resulting from the most recent call to
+ forward_inplace(). This means that forward_inplace() is allowed to cache
+ intermediate results so they can be used during the backward computation.
+ - have_same_dimensions(gradient_input, data_grad) == true
+ - have_same_dimensions(gradient_input, computed_output) == true
+ - have_same_dimensions(params_grad, get_layer_params()) == true
+ ensures
+ - This function supports in-place operation, i.e. having
+ is_same_object(gradient_input, data_grad)==true
+ - This function outputs the gradients of this layer with respect to the
+ input data from a sublayer and also with respect to this layer's parameters.
+ These gradients are stored into #data_grad and #params_grad, respectively. To be
+ precise, the gradients are taken of a function f(data_input,get_layer_params())
+ which is defined thusly:
+ - Recalling that computed_output is a function of both the input to
+ forward_inplace() and get_layer_params(), since it is the result of
+ calling forward_inplace(data_input,computed_output):
+ let f(data_input,get_layer_params()) == dot(computed_output, gradient_input)
+ Then we define the following gradient vectors:
+ - PARAMETER_GRADIENT == gradient of f(data_input,get_layer_params()) with
+ respect to get_layer_params().
+ - DATA_GRADIENT == gradient of f(data_input,get_layer_params()) with respect
+ to data_input.
+ Finally, backward_inplace() outputs these gradients by performing:
+ - params_grad = PARAMETER_GRADIENT
+ - if (is_same_object(gradient_input, data_grad)) then
+ - data_grad = DATA_GRADIENT
+ - else
+ - data_grad += DATA_GRADIENT
+ !*/
+
+ const tensor& get_layer_params(
+ ) const;
+ /*!
+ ensures
+ - returns the parameters that define the behavior of forward().
+ !*/
+
+ tensor& get_layer_params(
+ );
+ /*!
+ ensures
+ - returns the parameters that define the behavior of forward().
+ !*/
+
+
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ /*!
+ These two functions are optional. If provided, they should map between
+ (column,row) coordinates in input and output tensors of forward(). Providing
+ these functions allows you to use global utility functions like
+ input_tensor_to_output_tensor().
+ !*/
+
+ void clean (
+ );
+ /*!
+ Implementing this function is optional. If you don't need it then you don't
+ have to provide a clean(). But if you do provide it then it must behave as
+ follows:
+
+ ensures
+ - calling clean() Causes this object to forget about everything except its
+ parameters. This is useful if your layer caches information between
+ forward and backward passes and you want to clean out that cache
+ information before saving the network to disk.
+ !*/
+
+ };
+
+ std::ostream& operator<<(std::ostream& out, const EXAMPLE_COMPUTATIONAL_LAYER_& item);
+ /*!
+ print a string describing this layer.
+ !*/
+
+ void to_xml(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out);
+ /*!
+ This function is optional, but required if you want to print your networks with
+ net_to_xml(). Therefore, to_xml() prints a layer as XML.
+ !*/
+
+ void serialize(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out);
+ void deserialize(EXAMPLE_COMPUTATIONAL_LAYER_& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ // For each layer you define, always define an add_layer template so that layers can be
+ // easily composed. Moreover, the convention is that the layer class ends with an _
+ // while the add_layer template has the same name but without the trailing _.
+ template <typename SUBNET>
+ using EXAMPLE_COMPUTATIONAL_LAYER = add_layer<EXAMPLE_COMPUTATIONAL_LAYER_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum fc_bias_mode
+ {
+ FC_HAS_BIAS = 0,
+ FC_NO_BIAS = 1
+ };
+
+ struct num_fc_outputs
+ {
+ num_fc_outputs(unsigned long n) : num_outputs(n) {}
+ unsigned long num_outputs;
+ };
+
+ template <
+ unsigned long num_outputs,
+ fc_bias_mode bias_mode
+ >
+ class fc_
+ {
+ /*!
+ REQUIREMENTS ON num_outputs
+ num_outputs > 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a fully connected layer that
+ takes an input tensor and multiplies it by a weight matrix and outputs the
+ results.
+
+ The dimensions of the tensors output by this layer are as follows (letting
+ IN be the input tensor and OUT the output tensor):
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == get_num_outputs()
+ - OUT.nr() == 1
+ - OUT.nc() == 1
+ !*/
+
+ public:
+
+ fc_(
+ );
+ /*!
+ ensures
+ - #get_num_outputs() == num_outputs
+ - #get_bias_mode() == bias_mode
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ fc_(
+ num_fc_outputs o
+ );
+ /*!
+ ensures
+ - #get_num_outputs() == o.num_outputs
+ - #get_bias_mode() == bias_mode
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ unsigned long get_num_outputs (
+ ) const;
+ /*!
+ ensures
+ - This layer outputs column vectors that contain get_num_outputs()
+ elements. That is, the output tensor T from forward() will be such that:
+ - T.num_samples() == however many samples were given to forward().
+ - T.k() == get_num_outputs()
+ - The rest of the dimensions of T will be 1.
+ !*/
+
+ void set_num_outputs(
+ long num
+ );
+ /*!
+ requires
+ - num > 0
+ - get_layer_params().size() == 0 || get_num_outputs() == num
+ (i.e. You can't change the number of outputs in fc_ if the parameter
+ tensor has already been allocated.)
+ ensures
+ - #get_num_outputs() == num
+ !*/
+
+ fc_bias_mode get_bias_mode (
+ ) const;
+ /*!
+ ensures
+ - returns the bias mode which determines if this layer includes bias terms.
+ That is, if the bias mode is FC_HAS_BIAS then a different constant scalar
+ is added to each of the outputs of this layer.
+ !*/
+
+ double get_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its parameters be
+ multiplied by get_learning_rate_multiplier().
+ !*/
+
+ double get_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its parameters be
+ multiplied by get_weight_decay_multiplier().
+ !*/
+
+ void set_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_learning_rate_multiplier() == val
+ !*/
+
+ void set_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_weight_decay_multiplier() == val
+ !*/
+
+ double get_bias_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its bias parameters be
+ multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+ !*/
+
+ double get_bias_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its bias parameters be
+ multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+ !*/
+
+ void set_bias_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_learning_rate_multiplier() == val
+ !*/
+
+ void set_bias_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_weight_decay_multiplier() == val
+ !*/
+
+ alias_tensor_const_instance get_weights(
+ ) const;
+ /*!
+ ensures
+ - returns an alias of get_layer_params(), containing the weights matrix of
+ the fully connected layer.
+ - #get_weights().num_samples() is the number of elements in input sample,
+ i.e. sublayer's output's k * nc * nr.
+ - #get_bias().k() == #get_num_outputs()
+ - if get_bias_mode() == FC_HAS_BIAS:
+ - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
+ - else:
+ - #get_layer_params().size() == #get_weights().size()
+ !*/
+
+ alias_tensor_instance get_weights(
+ );
+ /*!
+ ensures
+ - returns an alias of get_layer_params(), containing the weights matrix of
+ the fully connected layer.
+ - #get_weights().num_samples() is the number of elements in input sample,
+ i.e. sublayer's output's k * nc * nr.
+ - #get_bias().k() == #get_num_outputs()
+ - if get_bias_mode() == FC_HAS_BIAS:
+ - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
+ - else:
+ - #get_layer_params().size() == #get_weights().size()
+ !*/
+
+ alias_tensor_const_instance get_biases(
+ ) const;
+ /*!
+ requires
+ - #get_bias_mode() == FC_HAS_BIAS
+ ensures
+ - returns an alias of get_layer_params(), containing the bias vector of
+ the fully connected layer.
+ - #get_bias().num_samples() == 1
+ - #get_bias().k() == #get_num_outputs()
+ - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
+ !*/
+
+ alias_tensor_instance get_biases(
+ );
+ /*!
+ requires
+ - #get_bias_mode() == FC_HAS_BIAS
+ ensures
+ - returns an alias of get_layer_params(), containing the bias vector of
+ the fully connected layer.
+ - #get_bias().num_samples() == 1
+ - #get_bias().k() == #get_num_outputs()
+ - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+
+ };
+
+ template <
+ unsigned long num_outputs,
+ typename SUBNET
+ >
+ using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;
+
+ template <
+ unsigned long num_outputs,
+ typename SUBNET
+ >
+ using fc_no_bias = add_layer<fc_<num_outputs,FC_NO_BIAS>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ struct num_con_outputs
+ {
+ num_con_outputs(unsigned long n) : num_outputs(n) {}
+ unsigned long num_outputs;
+ };
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class con_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ - _num_filters > 0
+ - _nr >= 0
+ - _nc >= 0
+ - _stride_y > 0
+ - _stride_x > 0
+ - _padding_y >= 0
+ - _padding_x >= 0
+ - Also, we require that:
+ - if (_nr == 0) then
+ - _padding_y == 0
+ - else
+ - _padding_y < _nr
+ - if (_nc == 0) then
+ - _padding_x == 0
+ - else
+ - _padding_x < _nc
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a convolution layer that takes an
+ input tensor (nominally representing an image) and convolves it with a set
+ of filters and then outputs the results.
+
+ The dimensions of the tensors output by this layer are as follows (letting
+ IN be the input tensor and OUT the output tensor):
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == num_filters()
+ - OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y()
+ - OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x()
+
+ Note also that setting _nr or _nc to 0 has a special meaning of "set the
+ filter size equal to the input image size". Specifically, it means:
+ - if (_nr == 0) then
+ - nr() == IN.nr()
+ - OUT.nr() == 1
+ - if (_nc == 0) then
+ - nc() == IN.nc()
+ - OUT.nc() == 1
+ !*/
+
+ public:
+ con_(
+ );
+ /*!
+ ensures
+ - #num_filters() == _num_filters
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ con_(
+ num_con_outputs o
+ );
+ /*!
+ ensures
+ - #num_filters() == o.num_outputs
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ long num_filters(
+ ) const;
+ /*!
+ ensures
+ - returns the number of filters contained in this layer. The k dimension
+ of the output tensors produced by this layer will be equal to the number
+ of filters.
+ !*/
+
+ void set_num_filters(
+ long num
+ );
+ /*!
+ requires
+ - num > 0
+ - get_layer_params().size() == 0 || num_filters() == num
+ (i.e. You can't change the number of filters in con_ if the parameter
+ tensor has already been allocated.)
+ ensures
+ - #num_filters() == num
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the filters in this layer. Note that if
+ nr()==0 then it means the size of the filter is not yet assigned, but
+ once setup() is called nr() will be set to the input tensor's nr().
+ Therefore, nr()==0 has the special interpretation of "be the same size as
+ the input tensor".
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in the filters in this layer. Note that if
+ nc()==0 then it means the size of the filter is not yet assigned, but
+ once setup() is called nc() will be set to the input tensor's nc().
+ Therefore, nc()==0 has the special interpretation of "be the same size as
+ the input tensor".
+ !*/
+
+ long stride_y(
+ ) const;
+ /*!
+ ensures
+ - returns the vertical stride used when convolving the filters over an
+ image. That is, each filter will be moved stride_y() pixels down at a
+ time when it moves over the image.
+ !*/
+
+ long stride_x(
+ ) const;
+ /*!
+ ensures
+ - returns the horizontal stride used when convolving the filters over an
+ image. That is, each filter will be moved stride_x() pixels right at a
+ time when it moves over the image.
+ !*/
+
+ long padding_y(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the top and bottom
+ sides of the image.
+ !*/
+
+ long padding_x(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the left and right
+ sides of the image.
+ !*/
+
+ double get_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its parameters be
+ multiplied by get_learning_rate_multiplier().
+ !*/
+
+ double get_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its parameters be
+ multiplied by get_weight_decay_multiplier().
+ !*/
+
+ void set_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_learning_rate_multiplier() == val
+ !*/
+
+ void set_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_weight_decay_multiplier() == val
+ !*/
+
+ double get_bias_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its bias parameters be
+ multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+ !*/
+
+ double get_bias_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its bias parameters be
+ multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+ !*/
+
+ void set_bias_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_learning_rate_multiplier() == val
+ !*/
+
+ void set_bias_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_weight_decay_multiplier() == val
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+
+ };
+
+ template <
+ long num_filters,
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class cont_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ All of them must be > 0.
+ Also, we require that:
+ - 0 <= _padding_y && _padding_y < _nr
+ - 0 <= _padding_x && _padding_x < _nc
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a transposed convolution layer
+ that takes an input tensor and transpose convolves (sometimes called
+ "deconvolution") it with a set of filters and then outputs the results.
+
+ This is essentially a convolutional layer that allows fractional strides.
+ Therefore, you can make output tensors that are larger than the input
+ tensors using this layer type.
+
+
+ The dimensions of the tensors output by this layer are as follows (letting
+ IN be the input tensor and OUT the output tensor):
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == num_filters()
+ - OUT.nr() == stride_y()*(IN.nr()-1) + nr() - 2*padding_y()
+ - OUT.nc() == stride_x()*(IN.nc()-1) + nc() - 2*padding_x()
+ !*/
+
+ public:
+ cont_(
+ );
+ /*!
+ ensures
+ - #num_filters() == _num_filters
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ cont_(
+ num_con_outputs o
+ );
+ /*!
+ ensures
+ - #num_filters() == o.num_outputs
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 1
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 0
+ !*/
+
+ long num_filters(
+ ) const;
+ /*!
+ ensures
+ - returns the number of filters contained in this layer. The k dimension
+ of the output tensors produced by this layer will be equal to the number
+ of filters.
+ !*/
+
+ void set_num_filters(
+ long num
+ );
+ /*!
+ requires
+ - num > 0
+ - get_layer_params().size() == 0 || num_filters() == num
+ (i.e. You can't change the number of filters in cont_ if the parameter
+ tensor has already been allocated.)
+ ensures
+ - #num_filters() == num
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the filters in this layer.
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in the filters in this layer.
+ !*/
+
+ long stride_y(
+ ) const;
+ /*!
+ ensures
+ - returns the vertical stride used when convolving the filters over an
+ image. That is, each filter will be moved 1.0/stride_y() pixels down at
+ a time when it moves over the image.
+ !*/
+
+ long stride_x(
+ ) const;
+ /*!
+ ensures
+ - returns the horizontal stride used when convolving the filters over an
+ image. That is, each filter will be moved 1.0/stride_x() pixels right at
+ a time when it moves over the image.
+ !*/
+
+ long padding_y(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the top and bottom
+ sides of the image.
+ !*/
+
+ long padding_x(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the left and right
+ sides of the image.
+ !*/
+
+ double get_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its parameters be
+ multiplied by get_learning_rate_multiplier().
+ !*/
+
+ double get_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its parameters be
+ multiplied by get_weight_decay_multiplier().
+ !*/
+
+ void set_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_learning_rate_multiplier() == val
+ !*/
+
+ void set_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_weight_decay_multiplier() == val
+ !*/
+
+ double get_bias_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its bias parameters be
+ multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+ !*/
+
+ double get_bias_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its bias parameters be
+ multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+ !*/
+
+ void set_bias_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_learning_rate_multiplier() == val
+ !*/
+
+ void set_bias_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_weight_decay_multiplier() == val
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+
+ };
+
+ template <
+ long num_filters,
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using cont = add_layer<cont_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ int scale_y,
+ int scale_x
+ >
+ class upsample_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ All of them must be >= 1.
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it allows you to upsample a layer using
+ bilinear interpolation. To be very specific, it upsamples each of the
+ channels in an input tensor. Therefore, if IN is the input tensor to this
+ layer and OUT the output tensor, then we will have:
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == IN.k()
+ - OUT.nr() == IN.nr()*scale_y
+ - OUT.nc() == IN.nr()*scale_x
+ - for all valid i,k: image_plane(OUT,i,k) is a copy of
+ image_plane(IN,i,k) that has been bilinearly interpolated to fit into
+ the shape of image_plane(OUT,i,k).
+ !*/
+ public:
+
+ upsample_(
+ );
+ /*!
+ ensures
+ - This object has no state, so the constructor does nothing, aside from
+ providing default constructability.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <
+ int scale,
+ typename SUBNET
+ >
+ using upsample = add_layer<upsample_<scale,scale>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class dropout_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a dropout layer. Therefore, it
+ passes its inputs through the stochastic function f(x) which outputs either
+ 0 or x. The probability of 0 being output is given by the drop_rate
+ argument to this object's constructor.
+
+ Note that, after you finish training a network with dropout, it is a good
+ idea to replace each dropout_ layer with a multiply_ layer because the
+ multiply_ layer is faster and deterministic.
+ !*/
+
+ public:
+
+ explicit dropout_(
+ float drop_rate = 0.5
+ );
+ /*!
+ requires
+ - 0 <= drop_rate <= 1
+ ensures
+ - #get_drop_rate() == drop_rate
+ !*/
+
+ float get_drop_rate (
+ ) const;
+ /*!
+ ensures
+ - returns the probability that an individual input value to this layer will
+ be replaced with 0.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using dropout = add_layer<dropout_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class multiply_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a basic layer that just
+ multiplies its input tensor with a constant value and returns the result.
+ It therefore has no learnable parameters.
+ !*/
+
+ public:
+ explicit multiply_(
+ float val = 0.5
+ );
+ /*!
+ ensures
+ - #get_multiply_value() == val
+ !*/
+
+ multiply_ (
+ const dropout_& item
+ );
+ /*!
+ ensures
+ - #get_multiply_value() == 1-item.get_drop_rate()
+ (i.e. We construct the multiply_ layer so that it is essentially a
+ deterministic version of the given dropout_ layer)
+ !*/
+
+ float get_multiply_value (
+ ) const;
+ /*!
+ ensures
+ - this layer simply multiplies its input tensor by get_multiply_value() and
+ produces the result as output.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using multiply = add_layer<multiply_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ enum layer_mode
+ {
+ CONV_MODE = 0, // convolutional mode
+ FC_MODE = 1 // fully connected mode
+ };
+
+ const double DEFAULT_BATCH_NORM_EPS = 0.0001;
+
+ template <
+ layer_mode mode
+ >
+ class bn_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a batch normalization layer that
+ implements the method described in the paper:
+ Batch Normalization: Accelerating Deep Network Training by Reducing
+ Internal Covariate Shift by Sergey Ioffe and Christian Szegedy
+
+ In particular, this layer produces output tensors with the same
+ dimensionality as the input tensors, except that the mean and variances of
+ the elements have been standardized to 0 and 1 respectively.
+
+ It should also be noted that when tensors with a num_samples() dimension of
+ 1 are passed to this layer it doesn't perform batch normalization.
+ Instead, it runs in "inference mode" where the learned linear normalizing
+ transformation is used to transform the tensor.
+
+ Finally, after you finish training a batch normalized network, it is a good
+ idea to replace each bn_ layer with an affine_ layer because the affine_
+ layer is faster and will never surprise you by performing batch
+ normalization on tensors that have a num_samples() dimension > 1. This allows
+ you to run large mini-batches of samples through your final network without
+ batch normalization executing at all.
+ !*/
+
+ public:
+ bn_(
+ );
+ /*!
+ ensures
+ - #get_mode() == mode
+ - #get_running_stats_window_size() == 100
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 0
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 1
+ - #get_eps() == tt::DEFAULT_BATCH_NORM_EPS
+ !*/
+
+ explicit bn_(
+ unsigned long window_size,
+ double eps = tt::DEFAULT_BATCH_NORM_EPS
+ );
+ /*!
+ requires
+ - eps > 0
+ - window_size > 0
+ ensures
+ - #get_mode() == mode
+ - #get_running_stats_window_size() == window_size
+ - #get_learning_rate_multiplier() == 1
+ - #get_weight_decay_multiplier() == 0
+ - #get_bias_learning_rate_multiplier() == 1
+ - #get_bias_weight_decay_multiplier() == 1
+ - #get_eps() == eps
+ !*/
+
+ layer_mode get_mode(
+ ) const;
+ /*!
+ ensures
+ - returns the mode of this layer, either CONV_MODE or FC_MODE.
+ If the mode is FC_MODE then the normalization is applied across the
+ samples in a tensor (i.e. k()*nr()*nc() different things will be
+ normalized). Otherwise, normalization is applied across everything
+ except for the k() dimension, resulting in there being only k()
+ normalization equations that are applied spatially over the tensor.
+
+ Therefore, if you are putting batch normalization after a fully connected
+ layer you should use FC_MODE. Otherwise, if you are putting batch
+ normalization after a convolutional layer you should use CONV_MODE.
+ !*/
+
+ double get_eps(
+ ) const;
+ /*!
+ ensures
+ - When doing batch normalization, we are dividing by the standard
+ deviation. This epsilon value returned by this function is added to the
+ variance to prevent the division from dividing by zero.
+ !*/
+
+ unsigned long get_running_stats_window_size (
+ ) const;
+ /*!
+ ensures
+ - Just as recommended in the batch normalization paper, this object keeps a
+ running average of the mean and standard deviations of the features.
+ These averages are used during "inference mode" so you can run a single
+ object through a batch normalized network. They are also what is used to
+ initialize an affine_ layer that is constructed from a bn_ layer. This
+ function returns the effective number of recent samples used to compute
+ the running average.
+ !*/
+
+ void set_running_stats_window_size (
+ unsigned long new_window_size
+ );
+ /*!
+ requires
+ - new_window_size > 0
+ ensures
+ - #get_running_stats_window_size() == new_window_size
+ !*/
+
+ double get_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its parameters be
+ multiplied by get_learning_rate_multiplier().
+ !*/
+
+ double get_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its parameters be
+ multiplied by get_weight_decay_multiplier().
+ !*/
+
+ void set_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_learning_rate_multiplier() == val
+ !*/
+
+ void set_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_weight_decay_multiplier() == val
+ !*/
+
+ double get_bias_learning_rate_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the learning rate used to optimize its bias parameters be
+ multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+ !*/
+
+ double get_bias_weight_decay_multiplier(
+ ) const;
+ /*!
+ ensures
+ - returns a multiplier number. The interpretation is that this object is
+ requesting that the weight decay used to optimize its bias parameters be
+ multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+ !*/
+
+ void set_bias_learning_rate_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_learning_rate_multiplier() == val
+ !*/
+
+ void set_bias_weight_decay_multiplier(
+ double val
+ );
+ /*!
+ requires
+ - val >= 0
+ ensures
+ - #get_bias_weight_decay_multiplier() == val
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
+ template <typename SUBNET>
+ using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename net_type>
+ void set_all_bn_running_stats_window_sizes (
+ const net_type& net,
+ unsigned long new_window_size
+ );
+ /*!
+ requires
+ - new_window_size > 0
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ ensures
+ - Sets the get_running_stats_window_size() field of all bn_ layers in net to
+ new_window_size.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class affine_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it applies a simple pointwise linear
+ transformation to an input tensor. You can think of it as having two
+ parameter tensors, A and B. If the input tensor is called INPUT then the
+ output of this layer is:
+ A*INPUT+B
+ where all operations are performed element wise and each sample in the
+ INPUT tensor is processed separately.
+
+ Moreover, this object has two modes that effect the dimensionalities of A
+ and B and how they are applied to compute A*INPUT+B. If
+ get_mode()==FC_MODE then A and B each have the same dimensionality as the
+ input tensor, except their num_samples() dimensions are 1. If
+ get_mode()==CONV_MODE then A and B have all their dimensions set to 1
+ except for k(), which is equal to INPUT.k().
+
+ In either case, the computation of A*INPUT+B is performed pointwise over all
+ the elements of INPUT using either:
+ OUTPUT(n,k,r,c) == A(1,k,r,c)*INPUT(n,k,r,c)+B(1,k,r,c)
+ or
+ OUTPUT(n,k,r,c) == A(1,k,1,1)*INPUT(n,k,r,c)+B(1,k,1,1)
+ as appropriate.
+
+
+ Finally, note that the parameters of this layer are not learnable and
+ therefore not modified during network updates. Instead, the layer will
+ perform the identity transformation unless it is initialized with a bn_
+ layer, in which case it will perform whatever transformation the bn_ layer
+ has learned.
+ !*/
+
+ public:
+
+ affine_(
+ );
+ /*!
+ ensures
+ - #get_mode() == FC_MODE
+ !*/
+
+ affine_(
+ layer_mode mode
+ );
+ /*!
+ ensures
+ - #get_mode() == mode
+ !*/
+
+ template <
+ layer_mode mode
+ >
+ affine_(
+ const bn_<mode>& layer
+ );
+ /*!
+ ensures
+ - Constructs affine_ so that it performs the same transformation as the
+ supplied batch normalization layer. You would want to do this after you
+ finish training a network with bn_ layers because the affine_ layer will
+ execute faster.
+ - #get_mode() == layer.get_mode()
+ !*/
+
+ layer_mode get_mode(
+ ) const;
+ /*!
+ ensures
+ - returns the mode of this layer, either CONV_MODE or FC_MODE.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the
+ EXAMPLE_COMPUTATIONAL_LAYER_ interface. Also note that get_layer_params()
+ always returns an empty tensor since there are no learnable parameters in this
+ object.
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using affine = add_layer<affine_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class max_pool_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ - _nr >= 0
+ - _nc >= 0
+ - _stride_y > 0
+ - _stride_x > 0
+ - _padding_y >= 0
+ - _padding_x >= 0
+ - if (_nr != 0) then
+ - _padding_y < _nr
+ - else
+ - _padding_y == 0
+ - if (_nc != 0) then
+ - _padding_x < _nr
+ - else
+ - _padding_x == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a max pooling layer that takes an
+ input tensor and downsamples it. It does this by sliding a window over the
+ images in an input tensor and outputting, for each channel, the maximum
+ element within the window.
+
+ If _nr == 0 then it means the filter size covers all the rows in the input
+ tensor, similarly for the _nc parameter. To be precise, if we call the
+ input tensor IN and the output tensor OUT, then OUT is defined as follows:
+ - let FILT_NR == (nr()==0) ? IN.nr() : nr()
+ - let FILT_NC == (nc()==0) ? IN.nc() : nc()
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == IN.k()
+ - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y()
+ - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x()
+ - for all valid s, k, r, and c:
+ - image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k),
+ centered_rect(x*stride_x() + FILT_NC/2 - padding_x(),
+ y*stride_y() + FILT_NR/2 - padding_y(),
+ FILT_NC,
+ FILT_NR)))
+ !*/
+
+ public:
+
+ max_pool_ (
+ );
+ /*!
+ ensures
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the pooling window or 0 if the window size
+ is "the entire input tensor".
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the pooling window or 0 if the window size
+ is "the entire input tensor".
+ !*/
+
+ long stride_y(
+ ) const;
+ /*!
+ ensures
+ - returns the vertical stride used when scanning the max pooling window
+ over an image. That is, each window will be moved stride_y() pixels down
+ at a time when it moves over the image.
+ !*/
+
+ long stride_x(
+ ) const;
+ /*!
+ ensures
+ - returns the horizontal stride used when scanning the max pooling window
+ over an image. That is, each window will be moved stride_x() pixels down
+ at a time when it moves over the image.
+ !*/
+
+ long padding_y(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the top and bottom
+ sides of the image.
+ !*/
+
+ long padding_x(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the left and right
+ sides of the image.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
+
+ template <
+ typename SUBNET
+ >
+ using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y = _stride_y!=1? 0 : _nr/2,
+ int _padding_x = _stride_x!=1? 0 : _nc/2
+ >
+ class avg_pool_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ - _nr >= 0
+ - _nc >= 0
+ - _stride_y > 0
+ - _stride_x > 0
+ - _padding_y >= 0
+ - _padding_x >= 0
+ - if (_nr != 0) then
+ - _padding_y < _nr
+ - else
+ - _padding_y == 0
+ - if (_nc != 0) then
+ - _padding_x < _nr
+ - else
+ - _padding_x == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines an average pooling layer that
+ takes an input tensor and downsamples it. It does this by sliding a window
+ over the images in an input tensor and outputting, for each channel, the
+ average element within the window.
+
+ If _nr == 0 then it means the filter size covers all the rows in the input
+ tensor, similarly for the _nc parameter. To be precise, if we call the
+ input tensor IN and the output tensor OUT, then OUT is defined as follows:
+ - let FILT_NR == (nr()==0) ? IN.nr() : nr()
+ - let FILT_NC == (nc()==0) ? IN.nc() : nc()
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == IN.k()
+ - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y()
+ - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x()
+ - for all valid s, k, r, and c:
+ - image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k),
+ centered_rect(x*stride_x() + FILT_NC/2 - padding_x(),
+ y*stride_y() + FILT_NR/2 - padding_y(),
+ FILT_NC,
+ FILT_NR)))
+ !*/
+
+ public:
+
+ avg_pool_ (
+ );
+ /*!
+ ensures
+ - #nr() == _nr
+ - #nc() == _nc
+ - #stride_y() == _stride_y
+ - #stride_x() == _stride_x
+ - #padding_y() == _padding_y
+ - #padding_x() == _padding_x
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the pooling window or 0 if the window size
+ is "the entire input tensor".
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the pooling window or 0 if the window size
+ is "the entire input tensor".
+ !*/
+
+ long stride_y(
+ ) const;
+ /*!
+ ensures
+ - returns the vertical stride used when scanning the pooling window
+ over an image. That is, each window will be moved stride_y() pixels down
+ at a time when it moves over the image.
+ !*/
+
+ long stride_x(
+ ) const;
+ /*!
+ ensures
+ - returns the horizontal stride used when scanning the pooling window
+ over an image. That is, each window will be moved stride_x() pixels down
+ at a time when it moves over the image.
+ !*/
+
+ long padding_y(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the top and bottom
+ sides of the image.
+ !*/
+
+ long padding_x(
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels of zero padding added to the left and right
+ sides of the image.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+
+ };
+
+ template <
+ long nr,
+ long nc,
+ int stride_y,
+ int stride_x,
+ typename SUBNET
+ >
+ using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
+
+ template <
+ typename SUBNET
+ >
+ using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class relu_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a rectified linear layer.
+ Therefore, it passes its inputs through the function
+ f(x)=max(x,0)
+ where f() is applied pointwise across the input tensor.
+ !*/
+
+ public:
+
+ relu_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using relu = add_layer<relu_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class prelu_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a parametric rectified linear
+ layer. Therefore, it passes its inputs through the function
+ f(x) = x>0 ? x : p*x
+ where f() is applied pointwise across the input tensor and p is a scalar
+ parameter learned by this layer.
+
+
+ This is the layer type introduced in the paper:
+ He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
+ human-level performance on imagenet classification." Proceedings of the
+ IEEE International Conference on Computer Vision. 2015.
+ !*/
+
+ public:
+
+ explicit prelu_(
+ float initial_param_value = 0.25
+ );
+ /*!
+ ensures
+ - The p parameter will be initialized with initial_param_value.
+ - #get_initial_param_value() == initial_param_value.
+ !*/
+
+ float get_initial_param_value (
+ ) const;
+ /*!
+ ensures
+ - returns the initial value of the prelu parameter.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using prelu = add_layer<prelu_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class sig_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a sigmoid layer. Therefore, it
+ passes its inputs through the function
+ f(x)=1/(1+exp(-x))
+ where f() is applied pointwise across the input tensor.
+ !*/
+
+ public:
+
+ sig_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using sig = add_layer<sig_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class htan_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a hyperbolic tangent layer.
+ Therefore, it passes its inputs through the function
+ f(x)=std::tanh(x)
+ where f() is applied pointwise across the input tensor.
+ !*/
+
+ public:
+
+ htan_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using htan = add_layer<htan_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class softmax_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a softmax layer. To be precise,
+ we define the softmax function s(x) as:
+ s(x) == exp(x)/sum(exp(x))
+ where x is a vector. Then this layer treats its input tensor as a
+ collection of multi-channel images and applies s() to each spatial location
+ in each image. In each application, the tensor::k() channel elements at
+ each position are input to s() and then replaced by the outputs of s().
+
+ This means that, for example, if you collapsed each output image to a 1
+ channel image by adding the channels then you would end up with images
+ where each pixel value was 1. This is because the sum of the outputs of
+ s() will always be equal to 1.
+ !*/
+
+ public:
+
+ softmax_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using softmax = add_layer<softmax_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class softmax_all_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, it defines a softmax layer. To be precise,
+ we define the softmax function s(x) as:
+ s(x) == exp(x)/sum(exp(x))
+ where x is a vector. Then this layer treats its input tensor as a
+ collection of tensor::num_samples() vectors and applies s() to each vector
+ in the tensor. Therefore, there are logically tensor::num_samples()
+ invocations of s().
+ !*/
+
+ public:
+
+ softmax_all_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
+ interface. Note that this layer doesn't have any parameters, so the tensor
+ returned by get_layer_params() is always empty.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using softmax_all = add_layer<softmax_all_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class add_prev_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. This layer simply adds the output of two previous layers.
+ In particular, it adds the tensor from its immediate predecessor layer,
+ sub.get_output(), with the tensor from a deeper layer,
+ layer<tag>(sub).get_output().
+
+ Therefore, you supply a tag via add_prev_'s template argument that tells it
+ what layer to add to the output of the previous layer. The result of this
+ addition is output by add_prev_. Finally, the addition happens pointwise
+ according to 4D tensor arithmetic. If the dimensions don't match then
+ missing elements are presumed to be equal to 0. Moreover, each dimension
+ of the output tensor is equal to the maximum dimension of either of the
+ inputs. That is, if the tensors A and B are being added to produce C then:
+ - C.num_samples() == max(A.num_samples(), B.num_samples())
+ - C.k() == max(A.k(), B.k())
+ - C.nr() == max(A.nr(), B.nr())
+ - C.nc() == max(A.nc(), B.nc())
+ !*/
+
+ public:
+ add_prev_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using add_prev = add_layer<add_prev_<tag>, SUBNET>;
+
+ // Here we add some convenient aliases for using add_prev_ with the tag layers.
+ template <typename SUBNET> using add_prev1 = add_prev<tag1, SUBNET>;
+ template <typename SUBNET> using add_prev2 = add_prev<tag2, SUBNET>;
+ template <typename SUBNET> using add_prev3 = add_prev<tag3, SUBNET>;
+ template <typename SUBNET> using add_prev4 = add_prev<tag4, SUBNET>;
+ template <typename SUBNET> using add_prev5 = add_prev<tag5, SUBNET>;
+ template <typename SUBNET> using add_prev6 = add_prev<tag6, SUBNET>;
+ template <typename SUBNET> using add_prev7 = add_prev<tag7, SUBNET>;
+ template <typename SUBNET> using add_prev8 = add_prev<tag8, SUBNET>;
+ template <typename SUBNET> using add_prev9 = add_prev<tag9, SUBNET>;
+ template <typename SUBNET> using add_prev10 = add_prev<tag10, SUBNET>;
+ using add_prev1_ = add_prev_<tag1>;
+ using add_prev2_ = add_prev_<tag2>;
+ using add_prev3_ = add_prev_<tag3>;
+ using add_prev4_ = add_prev_<tag4>;
+ using add_prev5_ = add_prev_<tag5>;
+ using add_prev6_ = add_prev_<tag6>;
+ using add_prev7_ = add_prev_<tag7>;
+ using add_prev8_ = add_prev_<tag8>;
+ using add_prev9_ = add_prev_<tag9>;
+ using add_prev10_ = add_prev_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class mult_prev_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. This layer simply multiplies the output of two previous
+ layers. In particular, it multiplies the tensor from its immediate
+ predecessor layer, sub.get_output(), with the tensor from a deeper layer,
+ layer<tag>(sub).get_output().
+
+ Therefore, you supply a tag via mult_prev_'s template argument that tells
+ it what layer to multiply with the output of the previous layer. The
+ result of this multiplication is output by mult_prev_. Finally, the
+ multiplication happens pointwise according to 4D tensor arithmetic. If the
+ dimensions don't match then missing elements are presumed to be equal to 0.
+ Moreover, each dimension of the output tensor is equal to the maximum
+ dimension of either of the inputs. That is, if the tensors A and B are
+ being multiplied to produce C then:
+ - C.num_samples() == max(A.num_samples(), B.num_samples())
+ - C.k() == max(A.k(), B.k())
+ - C.nr() == max(A.nr(), B.nr())
+ - C.nc() == max(A.nc(), B.nc())
+ !*/
+
+ public:
+ mult_prev_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using mult_prev = add_layer<mult_prev_<tag>, SUBNET>;
+
+ // Here we add some convenient aliases for using mult_prev_ with the tag layers.
+ template <typename SUBNET> using mult_prev1 = mult_prev<tag1, SUBNET>;
+ template <typename SUBNET> using mult_prev2 = mult_prev<tag2, SUBNET>;
+ template <typename SUBNET> using mult_prev3 = mult_prev<tag3, SUBNET>;
+ template <typename SUBNET> using mult_prev4 = mult_prev<tag4, SUBNET>;
+ template <typename SUBNET> using mult_prev5 = mult_prev<tag5, SUBNET>;
+ template <typename SUBNET> using mult_prev6 = mult_prev<tag6, SUBNET>;
+ template <typename SUBNET> using mult_prev7 = mult_prev<tag7, SUBNET>;
+ template <typename SUBNET> using mult_prev8 = mult_prev<tag8, SUBNET>;
+ template <typename SUBNET> using mult_prev9 = mult_prev<tag9, SUBNET>;
+ template <typename SUBNET> using mult_prev10 = mult_prev<tag10, SUBNET>;
+ using mult_prev1_ = mult_prev_<tag1>;
+ using mult_prev2_ = mult_prev_<tag2>;
+ using mult_prev3_ = mult_prev_<tag3>;
+ using mult_prev4_ = mult_prev_<tag4>;
+ using mult_prev5_ = mult_prev_<tag5>;
+ using mult_prev6_ = mult_prev_<tag6>;
+ using mult_prev7_ = mult_prev_<tag7>;
+ using mult_prev8_ = mult_prev_<tag8>;
+ using mult_prev9_ = mult_prev_<tag9>;
+ using mult_prev10_ = mult_prev_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ template<typename> class tag
+ >
+ class scale_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. This layer scales the output channels of the tagged layer
+ by multiplying it with the output of the previous layer. To be specific:
+ - Let INPUT == layer<tag>(sub).get_output()
+ - Let SCALES == sub.get_output()
+ - This layer takes INPUT and SCALES as input.
+ - The output of this layer has the same dimensions as INPUT.
+ - This layer requires:
+ - SCALES.num_samples() == INPUT.num_samples()
+ - SCALES.k() == INPUT.k()
+ - SCALES.nr() == 1
+ - SCALES.nc() == 1
+ - The output tensor is produced by pointwise multiplying SCALES with
+ INPUT at each spatial location. Therefore, if OUT is the output of
+ this layer then we would have:
+ OUT(n,k,r,c) == INPUT(n,k,r,c)*SCALES(n,k)
+ !*/
+
+ public:
+ scale_(
+ );
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+
+ template <
+ template<typename> class tag,
+ typename SUBNET
+ >
+ using scale = add_layer<scale_<tag>, SUBNET>;
+
+ // Here we add some convenient aliases for using scale_ with the tag layers.
+ template <typename SUBNET> using scale1 = scale<tag1, SUBNET>;
+ template <typename SUBNET> using scale2 = scale<tag2, SUBNET>;
+ template <typename SUBNET> using scale3 = scale<tag3, SUBNET>;
+ template <typename SUBNET> using scale4 = scale<tag4, SUBNET>;
+ template <typename SUBNET> using scale5 = scale<tag5, SUBNET>;
+ template <typename SUBNET> using scale6 = scale<tag6, SUBNET>;
+ template <typename SUBNET> using scale7 = scale<tag7, SUBNET>;
+ template <typename SUBNET> using scale8 = scale<tag8, SUBNET>;
+ template <typename SUBNET> using scale9 = scale<tag9, SUBNET>;
+ template <typename SUBNET> using scale10 = scale<tag10, SUBNET>;
+ using scale1_ = scale_<tag1>;
+ using scale2_ = scale_<tag2>;
+ using scale3_ = scale_<tag3>;
+ using scale4_ = scale_<tag4>;
+ using scale5_ = scale_<tag5>;
+ using scale6_ = scale_<tag6>;
+ using scale7_ = scale_<tag7>;
+ using scale8_ = scale_<tag8>;
+ using scale9_ = scale_<tag9>;
+ using scale10_ = scale_<tag10>;
+
+// ----------------------------------------------------------------------------------------
+
+ template<
+ template<typename> class... TAG_TYPES
+ >
+ class concat_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. This layer simply concatenates the output of tagged layers.
+ Importantly, each input layer must have the same dimensions (i.e.
+ num_samples, nr, and nc) except for the k channel, which may vary. This is
+ because the concatenation happens along the k dimension. That is, the
+ output of this network is a tensor, OUT, that is the concatenation of the
+ tensors:
+ for each (tag in TAG_TYPES)
+ layer<tag>(subnet).get_output()
+ Therefore, out.num_samples(), out.nr(), and out.nc() match the dimensions
+ of the input tensors while OUT.k() is the sum of the input layer's k()
+ dimensions.
+ !*/
+
+ public:
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ dpoint map_input_to_output(dpoint p) const;
+ dpoint map_output_to_input(dpoint p) const;
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+
+ // concat layer definitions
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ typename SUBNET>
+ using concat2 = add_layer<concat_<TAG1, TAG2>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ typename SUBNET>
+ using concat3 = add_layer<concat_<TAG1, TAG2, TAG3>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ template<typename> class TAG4,
+ typename SUBNET>
+ using concat4 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4>, SUBNET>;
+
+ template <template<typename> class TAG1,
+ template<typename> class TAG2,
+ template<typename> class TAG3,
+ template<typename> class TAG4,
+ template<typename> class TAG5,
+ typename SUBNET>
+ using concat5 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4, TAG5>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A inception layer definitions !*/
+
+ // Now define inception layer tag types. These layer aliases allow creating
+ // the networks described in the paper:
+ // Szegedy, Christian, et al. "Going deeper with convolutions." Proceedings of
+ // the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
+ // See the dnn_inception_ex.cpp example for a complete example of their use. Note also
+ // that we use tag ID numbers >= 1000 to avoid conflict with user's tag layers.
+ template <typename SUBNET> using itag0 = add_tag_layer< 1000 + 0, SUBNET>;
+ template <typename SUBNET> using itag1 = add_tag_layer< 1000 + 1, SUBNET>;
+ template <typename SUBNET> using itag2 = add_tag_layer< 1000 + 2, SUBNET>;
+ template <typename SUBNET> using itag3 = add_tag_layer< 1000 + 3, SUBNET>;
+ template <typename SUBNET> using itag4 = add_tag_layer< 1000 + 4, SUBNET>;
+ template <typename SUBNET> using itag5 = add_tag_layer< 1000 + 5, SUBNET>;
+ // skip to inception input
+ template <typename SUBNET> using iskip = add_skip_layer< itag0, SUBNET>;
+
+ // here are some templates to be used for creating inception layer groups
+ template <template<typename>class B1,
+ template<typename>class B2,
+ typename SUBNET>
+ using inception2 = concat2<itag1, itag2, itag1<B1<iskip< itag2<B2< itag0<SUBNET>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ typename SUBNET>
+ using inception3 = concat3<itag1, itag2, itag3, itag1<B1<iskip< itag2<B2<iskip< itag3<B3< itag0<SUBNET>>>>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ template<typename>class B4,
+ typename SUBNET>
+ using inception4 = concat4<itag1, itag2, itag3, itag4,
+ itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip< itag4<B4< itag0<SUBNET>>>>>>>>>>>>>;
+
+ template <template<typename>class B1,
+ template<typename>class B2,
+ template<typename>class B3,
+ template<typename>class B4,
+ template<typename>class B5,
+ typename SUBNET>
+ using inception5 = concat5<itag1, itag2, itag3, itag4, itag5,
+ itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip< itag4<B4<iskip< itag5<B5< itag0<SUBNET>>>>>>>>>>>>>>>>;
+
+// ----------------------------------------------------------------------------------------
+
+ const double DEFAULT_L2_NORM_EPS = 1e-5;
+
+ class l2normalize_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. It takes tensors as input and L2 normalizes them. In particular,
+ it has the following properties:
+ - The output tensors from this layer have the same dimensions as the
+ input tensors.
+ - If you think of each input tensor as a set of tensor::num_samples()
+ vectors, then the output tensor contains the same vectors except they
+ have been length normalized so that their L2 norms are all 1. I.e.
+ for each vector v we will have ||v||==1.
+ !*/
+
+ public:
+
+ explicit l2normalize_(
+ double eps = tt::DEFAULT_L2_NORM_EPS
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_eps() == eps
+ !*/
+
+ double get_eps(
+ ) const;
+ /*!
+ ensures
+ - When we normalize a vector we divide it by its L2 norm. However, the
+ get_eps() value is added to the squared norm prior to division to avoid
+ ever dividing by zero.
+ !*/
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ void forward_inplace(const tensor& input, tensor& output);
+ void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long _offset,
+ long _k,
+ long _nr,
+ long _nc
+ >
+ class extract_
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE ARGUMENTS
+ - 0 <= _offset
+ - 0 < _k
+ - 0 < _nr
+ - 0 < _nc
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+ defined above. In particular, the output of this layer is simply a copy of
+ the input tensor. However, you can configure the extract layer to output
+ only some subset of the input tensor and also to reshape it. Therefore,
+ the dimensions of the tensor output by this layer are as follows (letting
+ IN be the input tensor and OUT the output tensor):
+ - OUT.num_samples() == IN.num_samples()
+ - OUT.k() == _k
+ - OUT.nr() == _nr
+ - OUT.nc() == _nc
+
+ So the output will always have the same number of samples as the input, but
+ within each sample (the k,nr,nc part) we will copy only a subset of the
+ values. Moreover, the _offset parameter controls which part of each sample
+ we take. To be very precise, we will have:
+ - let IN_SIZE = IN.k()*IN.nr()*IN.nc()
+ - let OUT_SIZE = _k*_nr*_nc
+ - for i in range[0,IN.num_samples()) and j in range[0,OUT_SIZE):
+ - OUT.host()[i*OUT_SIZE+j] == IN.host()[i*IN_SIZE+_offset+j]
+
+
+ Finally, all this means that the input tensor to this layer must have a big
+ enough size to accommodate taking a _k*_nr*_nc slice from each of its
+ samples.
+ !*/
+
+ public:
+
+ template <typename SUBNET> void setup (const SUBNET& sub);
+ template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+ template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+ const tensor& get_layer_params() const;
+ tensor& get_layer_params();
+ /*!
+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+ !*/
+ };
+
+ template <
+ long offset,
+ long k,
+ long nr,
+ long nc,
+ typename SUBNET
+ >
+ using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/loss.h b/ml/dlib/dlib/dnn/loss.h
new file mode 100644
index 000000000..1b09b85c3
--- /dev/null
+++ b/ml/dlib/dlib/dnn/loss.h
@@ -0,0 +1,2870 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_LOSS_H_
+#define DLIB_DNn_LOSS_H_
+
+#include "loss_abstract.h"
+#include "core.h"
+#include "../matrix.h"
+#include "tensor_tools.h"
+#include "../geometry.h"
+#include "../image_processing/box_overlap_testing.h"
+#include "../image_processing/full_object_detection.h"
+#include "../svm/ranking_tools.h"
+#include <sstream>
+#include <map>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_binary_hinge_
+ {
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = out_data[i];
+ }
+ }
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ const float* out_data = output_tensor.host();
+ float* g = grad.host_write_only();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const float y = *truth++;
+ DLIB_CASSERT(y == +1 || y == -1, "y: " << y);
+ const float temp = 1-y*out_data[i];
+ if (temp > 0)
+ {
+ loss += scale*temp;
+ g[i] = -scale*y;
+ }
+ else
+ {
+ g[i] = 0;
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_binary_hinge_& , std::ostream& out)
+ {
+ serialize("loss_binary_hinge_", out);
+ }
+
+ friend void deserialize(loss_binary_hinge_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_binary_hinge_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_hinge_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_binary_hinge_& )
+ {
+ out << "loss_binary_hinge";
+ return out;
+ }
+
+ friend void to_xml(const loss_binary_hinge_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_binary_hinge/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_binary_hinge = add_loss_layer<loss_binary_hinge_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_binary_log_
+ {
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = out_data[i];
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ grad.k() == 1);
+
+ tt::sigmoid(grad, output_tensor);
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host();
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const float y = *truth++;
+ DLIB_CASSERT(y == +1 || y == -1, "y: " << y);
+ float temp;
+ if (y > 0)
+ {
+ temp = log1pexp(-out_data[i]);
+ loss += scale*temp;
+ g[i] = scale*(g[i]-1);
+ }
+ else
+ {
+ temp = -(-out_data[i]-log1pexp(-out_data[i]));
+ loss += scale*temp;
+ g[i] = scale*g[i];
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_binary_log_& , std::ostream& out)
+ {
+ serialize("loss_binary_log_", out);
+ }
+
+ friend void deserialize(loss_binary_log_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_binary_log_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_log_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_binary_log_& )
+ {
+ out << "loss_binary_log";
+ return out;
+ }
+
+ friend void to_xml(const loss_binary_log_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_binary_log/>";
+ }
+
+ };
+
+ template <typename T>
+ T safe_log(T input, T epsilon = 1e-10)
+ {
+ // Prevent trying to calculate the logarithm of a very small number (let alone zero)
+ return std::log(std::max(input, epsilon));
+ }
+
+ template <typename SUBNET>
+ using loss_binary_log = add_loss_layer<loss_binary_log_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_
+ {
+ public:
+
+ typedef unsigned long training_label_type;
+ typedef unsigned long output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 );
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+
+ // Note that output_tensor.k() should match the number of labels.
+
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ // The index of the largest output for this sample is the label.
+ *iter++ = index_of_max(rowm(mat(output_tensor),i));
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1);
+
+ tt::softmax(grad, output_tensor);
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const long y = (long)*truth++;
+ // The network must produce a number of outputs that is equal to the number
+ // of labels when using this type of loss.
+ DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k());
+ for (long k = 0; k < output_tensor.k(); ++k)
+ {
+ const unsigned long idx = i*output_tensor.k()+k;
+ if (k == y)
+ {
+ loss += scale*-safe_log(g[idx]);
+ g[idx] = scale*(g[idx]-1);
+ }
+ else
+ {
+ g[idx] = scale*g[idx];
+ }
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_multiclass_log_& , std::ostream& out)
+ {
+ serialize("loss_multiclass_log_", out);
+ }
+
+ friend void deserialize(loss_multiclass_log_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_multiclass_log_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_& )
+ {
+ out << "loss_multiclass_log";
+ return out;
+ }
+
+ friend void to_xml(const loss_multiclass_log_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_multiclass_log/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multimulticlass_log_
+ {
+
+ public:
+
+ loss_multimulticlass_log_ () = default;
+
+ loss_multimulticlass_log_ (
+ const std::map<std::string,std::vector<std::string>>& labels
+ )
+ {
+ for (auto& l : labels)
+ {
+ possible_labels[l.first] = std::make_shared<decltype(l.second)>(l.second);
+ DLIB_CASSERT(l.second.size() >= 2, "Each classifier must have at least two possible labels.");
+
+ for (size_t i = 0; i < l.second.size(); ++i)
+ {
+ label_idx_lookup[l.first][l.second[i]] = i;
+ ++total_num_labels;
+ }
+ }
+ }
+
+ unsigned long number_of_labels() const { return total_num_labels; }
+
+ unsigned long number_of_classifiers() const { return possible_labels.size(); }
+
+ std::map<std::string,std::vector<std::string>> get_labels (
+ ) const
+ {
+ std::map<std::string,std::vector<std::string>> info;
+ for (auto& i : possible_labels)
+ {
+ for (auto& label : *i.second)
+ info[i.first].emplace_back(label);
+ }
+ return info;
+ }
+
+ class classifier_output
+ {
+
+ public:
+ classifier_output() = default;
+
+ size_t num_classes() const { return class_probs.size(); }
+
+ double probability_of_class (
+ size_t i
+ ) const
+ {
+ DLIB_CASSERT(i < num_classes());
+ return class_probs(i);
+ }
+
+ const std::string& label(
+ size_t i
+ ) const
+ {
+ DLIB_CASSERT(i < num_classes());
+ return (*_labels)[i];
+ }
+
+ operator std::string(
+ ) const
+ {
+ DLIB_CASSERT(num_classes() != 0);
+ return (*_labels)[index_of_max(class_probs)];
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const classifier_output& item)
+ {
+ DLIB_ASSERT(item.num_classes() != 0);
+ out << static_cast<std::string>(item);
+ return out;
+ }
+
+ private:
+
+ friend class loss_multimulticlass_log_;
+
+ template <typename EXP>
+ classifier_output(
+ const matrix_exp<EXP>& class_probs,
+ const std::shared_ptr<std::vector<std::string>>& _labels
+ ) :
+ class_probs(class_probs),
+ _labels(_labels)
+ {
+ }
+
+ matrix<float,1,0> class_probs;
+ std::shared_ptr<std::vector<std::string>> _labels;
+ };
+
+ typedef std::map<std::string,std::string> training_label_type;
+ typedef std::map<std::string,classifier_output> output_label_type;
+
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter_begin
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 );
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!");
+ DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels.");
+
+
+ long k_offset = 0;
+ for (auto& l : possible_labels)
+ {
+ auto iter = iter_begin;
+ const std::string& classifier_name = l.first;
+ const auto& labels = (*l.second);
+ scratch.set_size(output_tensor.num_samples(), labels.size());
+ tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, labels.size());
+
+ tt::softmax(scratch, scratch);
+
+ for (long i = 0; i < scratch.num_samples(); ++i)
+ (*iter++)[classifier_name] = classifier_output(rowm(mat(scratch),i), l.second);
+
+ k_offset += labels.size();
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth_begin,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1);
+ DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!");
+ DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels.");
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ long k_offset = 0;
+ for (auto& l : label_idx_lookup)
+ {
+ const std::string& classifier_name = l.first;
+ const auto& int_labels = l.second;
+ scratch.set_size(output_tensor.num_samples(), int_labels.size());
+ tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, int_labels.size());
+
+ tt::softmax(scratch, scratch);
+
+
+ auto truth = truth_begin;
+ float* g = scratch.host();
+ for (long i = 0; i < scratch.num_samples(); ++i)
+ {
+ const long y = int_labels.at(truth->at(classifier_name));
+ ++truth;
+
+ for (long k = 0; k < scratch.k(); ++k)
+ {
+ const unsigned long idx = i*scratch.k()+k;
+ if (k == y)
+ {
+ loss += scale*-std::log(g[idx]);
+ g[idx] = scale*(g[idx]-1);
+ }
+ else
+ {
+ g[idx] = scale*g[idx];
+ }
+ }
+ }
+
+ tt::copy_tensor(false, grad, k_offset, scratch, 0, int_labels.size());
+
+ k_offset += int_labels.size();
+ }
+ return loss;
+ }
+
+
+ friend void serialize(const loss_multimulticlass_log_& item, std::ostream& out)
+ {
+ serialize("loss_multimulticlass_log_", out);
+ serialize(item.get_labels(), out);
+ }
+
+ friend void deserialize(loss_multimulticlass_log_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_multimulticlass_log_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_multimulticlass_log_.");
+
+ std::map<std::string,std::vector<std::string>> info;
+ deserialize(info, in);
+ item = loss_multimulticlass_log_(info);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_multimulticlass_log_& item)
+ {
+ out << "loss_multimulticlass_log, labels={";
+ for (auto i = item.possible_labels.begin(); i != item.possible_labels.end(); )
+ {
+ auto& category = i->first;
+ auto& labels = *(i->second);
+ out << category << ":(";
+ for (size_t j = 0; j < labels.size(); ++j)
+ {
+ out << labels[j];
+ if (j+1 < labels.size())
+ out << ",";
+ }
+
+ out << ")";
+ if (++i != item.possible_labels.end())
+ out << ", ";
+ }
+ out << "}";
+ return out;
+ }
+
+ friend void to_xml(const loss_multimulticlass_log_& item, std::ostream& out)
+ {
+ out << "<loss_multimulticlass_log>\n";
+ out << item;
+ out << "\n</loss_multimulticlass_log>";
+ }
+
+ private:
+
+ std::map<std::string,std::shared_ptr<std::vector<std::string>>> possible_labels;
+ unsigned long total_num_labels = 0;
+
+ // We make it true that: possible_labels[classifier][label_idx_lookup[classifier][label]] == label
+ std::map<std::string, std::map<std::string,long>> label_idx_lookup;
+
+
+ // Scratch doesn't logically contribute to the state of this object. It's just
+ // temporary scratch space used by this class.
+ mutable resizable_tensor scratch;
+
+
+ };
+
+ template <typename SUBNET>
+ using loss_multimulticlass_log = add_loss_layer<loss_multimulticlass_log_, SUBNET>;
+
+ inline bool operator== (const std::string& lhs, const loss_multimulticlass_log_::classifier_output& rhs)
+ { return lhs == static_cast<const std::string&>(rhs); }
+ inline bool operator== (const loss_multimulticlass_log_::classifier_output& lhs, const std::string& rhs)
+ { return rhs == static_cast<const std::string&>(lhs); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum class use_image_pyramid : uint8_t
+ {
+ no,
+ yes
+ };
+
+ struct mmod_options
+ {
+ public:
+
+ struct detector_window_details
+ {
+ detector_window_details() = default;
+ detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {}
+ detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {}
+
+ unsigned long width = 0;
+ unsigned long height = 0;
+ std::string label;
+
+ friend inline void serialize(const detector_window_details& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.width, out);
+ serialize(item.height, out);
+ serialize(item.label, out);
+ }
+
+ friend inline void deserialize(detector_window_details& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1 && version != 2)
+ throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_details");
+ deserialize(item.width, in);
+ deserialize(item.height, in);
+ if (version == 2)
+ deserialize(item.label, in);
+ }
+
+ };
+
+ mmod_options() = default;
+
+ std::vector<detector_window_details> detector_windows;
+ double loss_per_false_alarm = 1;
+ double loss_per_missed_target = 1;
+ double truth_match_iou_threshold = 0.5;
+ test_box_overlap overlaps_nms = test_box_overlap(0.4);
+ test_box_overlap overlaps_ignore;
+
+ use_image_pyramid assume_image_pyramid = use_image_pyramid::yes;
+
+ mmod_options (
+ const std::vector<std::vector<mmod_rect>>& boxes,
+ const unsigned long target_size, // We want the length of the longest dimension of the detector window to be this.
+ const unsigned long min_target_size, // But we require that the smallest dimension of the detector window be at least this big.
+ const double min_detector_window_overlap_iou = 0.75
+ )
+ {
+ DLIB_CASSERT(0 < min_target_size && min_target_size <= target_size);
+ DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1);
+
+ // Figure out what detector windows we will need.
+ for (auto& label : get_labels(boxes))
+ {
+ for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou), label))
+ {
+ double detector_width;
+ double detector_height;
+ if (ratio < 1)
+ {
+ detector_height = target_size;
+ detector_width = ratio*target_size;
+ if (detector_width < min_target_size)
+ {
+ detector_height = min_target_size/ratio;
+ detector_width = min_target_size;
+ }
+ }
+ else
+ {
+ detector_width = target_size;
+ detector_height = target_size/ratio;
+ if (detector_height < min_target_size)
+ {
+ detector_width = min_target_size*ratio;
+ detector_height = min_target_size;
+ }
+ }
+
+ detector_window_details p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height), label);
+ detector_windows.push_back(p);
+ }
+ }
+
+ DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");
+
+ set_overlap_nms(boxes);
+ }
+
+ mmod_options(
+ use_image_pyramid assume_image_pyramid,
+ const std::vector<std::vector<mmod_rect>>& boxes,
+ const double min_detector_window_overlap_iou = 0.75
+ )
+ : assume_image_pyramid(assume_image_pyramid)
+ {
+ DLIB_CASSERT(assume_image_pyramid == use_image_pyramid::no);
+ DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1);
+
+ // Figure out what detector windows we will need.
+ for (auto& label : get_labels(boxes))
+ {
+ for (auto rectangle : find_covering_rectangles(boxes, test_box_overlap(min_detector_window_overlap_iou), label))
+ {
+ detector_windows.push_back(detector_window_details(rectangle.width(), rectangle.height(), label));
+ }
+ }
+
+ DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");
+
+ set_overlap_nms(boxes);
+ }
+
+ private:
+
+ void set_overlap_nms(const std::vector<std::vector<mmod_rect>>& boxes)
+ {
+ // Convert from mmod_rect to rectangle so we can call
+ // find_tight_overlap_tester().
+ std::vector<std::vector<rectangle>> temp;
+ for (auto&& bi : boxes)
+ {
+ std::vector<rectangle> rtemp;
+ for (auto&& b : bi)
+ {
+ if (b.ignore)
+ continue;
+ rtemp.push_back(b.rect);
+ }
+ temp.push_back(std::move(rtemp));
+ }
+ overlaps_nms = find_tight_overlap_tester(temp);
+ // Relax the non-max-suppression a little so that it doesn't accidentally make
+ // it impossible for the detector to output boxes matching the training data.
+ // This could be a problem with the tightest possible nms test since there is
+ // some small variability in how boxes get positioned between the training data
+ // and the coordinate system used by the detector when it runs. So relaxing it
+ // here takes care of that.
+ auto iou_thresh = advance_toward_1(overlaps_nms.get_iou_thresh());
+ auto percent_covered_thresh = advance_toward_1(overlaps_nms.get_percent_covered_thresh());
+ overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh);
+ }
+
+ static double advance_toward_1 (
+ double val
+ )
+ {
+ if (val < 1)
+ val += (1-val)*0.1;
+ return val;
+ }
+
+ static size_t count_overlaps (
+ const std::vector<rectangle>& rects,
+ const test_box_overlap& overlaps,
+ const rectangle& ref_box
+ )
+ {
+ size_t cnt = 0;
+ for (auto& b : rects)
+ {
+ if (overlaps(b, ref_box))
+ ++cnt;
+ }
+ return cnt;
+ }
+
+ static std::vector<rectangle> find_rectangles_overlapping_all_others (
+ std::vector<rectangle> rects,
+ const test_box_overlap& overlaps
+ )
+ {
+ std::vector<rectangle> exemplars;
+ dlib::rand rnd;
+
+ while(rects.size() > 0)
+ {
+ // Pick boxes at random and see if they overlap a lot of other boxes. We will try
+ // 500 different boxes each iteration and select whichever hits the most others to
+ // add to our exemplar set.
+ rectangle best_ref_box;
+ size_t best_cnt = 0;
+ for (int iter = 0; iter < 500; ++iter)
+ {
+ rectangle ref_box = rects[rnd.get_random_64bit_number()%rects.size()];
+ size_t cnt = count_overlaps(rects, overlaps, ref_box);
+ if (cnt >= best_cnt)
+ {
+ best_cnt = cnt;
+ best_ref_box = ref_box;
+ }
+ }
+
+ // Now mark all the boxes the new ref box hit as hit.
+ for (size_t i = 0; i < rects.size(); ++i)
+ {
+ if (overlaps(rects[i], best_ref_box))
+ {
+ // remove box from rects so we don't hit it again later
+ swap(rects[i], rects.back());
+ rects.pop_back();
+ --i;
+ }
+ }
+
+ exemplars.push_back(best_ref_box);
+ }
+
+ return exemplars;
+ }
+
+ static std::set<std::string> get_labels (
+ const std::vector<std::vector<mmod_rect>>& rects
+ )
+ {
+ std::set<std::string> labels;
+ for (auto& rr : rects)
+ {
+ for (auto& r : rr)
+ labels.insert(r.label);
+ }
+ return labels;
+ }
+
+ static std::vector<double> find_covering_aspect_ratios (
+ const std::vector<std::vector<mmod_rect>>& rects,
+ const test_box_overlap& overlaps,
+ const std::string& label
+ )
+ {
+ std::vector<rectangle> boxes;
+ // Make sure all the boxes have the same size and position, so that the only thing our
+ // checks for overlap will care about is aspect ratio (i.e. scale and x,y position are
+ // ignored).
+ for (auto& bb : rects)
+ {
+ for (auto&& b : bb)
+ {
+ if (!b.ignore && b.label == label)
+ boxes.push_back(move_rect(set_rect_area(b.rect,400*400), point(0,0)));
+ }
+ }
+
+ std::vector<double> ratios;
+ for (auto r : find_rectangles_overlapping_all_others(boxes, overlaps))
+ ratios.push_back(r.width()/(double)r.height());
+ return ratios;
+ }
+
+ static std::vector<dlib::rectangle> find_covering_rectangles (
+ const std::vector<std::vector<mmod_rect>>& rects,
+ const test_box_overlap& overlaps,
+ const std::string& label
+ )
+ {
+ std::vector<rectangle> boxes;
+ // Make sure all the boxes have the same position, so that the we only check for
+ // width and height.
+ for (auto& bb : rects)
+ {
+ for (auto&& b : bb)
+ {
+ if (!b.ignore && b.label == label)
+ boxes.push_back(rectangle(b.rect.width(), b.rect.height()));
+ }
+ }
+
+ return find_rectangles_overlapping_all_others(boxes, overlaps);
+ }
+ };
+
+ inline void serialize(const mmod_options& item, std::ostream& out)
+ {
+ int version = 3;
+
+ serialize(version, out);
+ serialize(item.detector_windows, out);
+ serialize(item.loss_per_false_alarm, out);
+ serialize(item.loss_per_missed_target, out);
+ serialize(item.truth_match_iou_threshold, out);
+ serialize(item.overlaps_nms, out);
+ serialize(item.overlaps_ignore, out);
+ serialize(static_cast<uint8_t>(item.assume_image_pyramid), out);
+ }
+
+ inline void deserialize(mmod_options& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 3 && version != 2 && version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::mmod_options");
+ if (version == 1)
+ {
+ unsigned long width;
+ unsigned long height;
+ deserialize(width, in);
+ deserialize(height, in);
+ item.detector_windows = {mmod_options::detector_window_details(width, height)};
+ }
+ else
+ {
+ deserialize(item.detector_windows, in);
+ }
+ deserialize(item.loss_per_false_alarm, in);
+ deserialize(item.loss_per_missed_target, in);
+ deserialize(item.truth_match_iou_threshold, in);
+ deserialize(item.overlaps_nms, in);
+ deserialize(item.overlaps_ignore, in);
+ item.assume_image_pyramid = use_image_pyramid::yes;
+ if (version >= 3)
+ {
+ uint8_t assume_image_pyramid = 0;
+ deserialize(assume_image_pyramid, in);
+ item.assume_image_pyramid = static_cast<use_image_pyramid>(assume_image_pyramid);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mmod_
+ {
+ struct intermediate_detection
+ {
+ intermediate_detection() = default;
+
+ intermediate_detection(
+ rectangle rect_
+ ) : rect(rect_) {}
+
+ intermediate_detection(
+ rectangle rect_,
+ double detection_confidence_,
+ size_t tensor_offset_,
+ long channel
+ ) : rect(rect_), detection_confidence(detection_confidence_), tensor_offset(tensor_offset_), tensor_channel(channel) {}
+
+ rectangle rect;
+ double detection_confidence = 0;
+ size_t tensor_offset = 0;
+ long tensor_channel = 0;
+
+ bool operator<(const intermediate_detection& item) const { return detection_confidence < item.detection_confidence; }
+ };
+
+ public:
+
+ typedef std::vector<mmod_rect> training_label_type;
+ typedef std::vector<mmod_rect> output_label_type;
+
+ loss_mmod_() {}
+
+ loss_mmod_(mmod_options options_) : options(options_) {}
+
+ const mmod_options& get_options (
+ ) const { return options; }
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter,
+ double adjust_threshold = 0
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor());
+
+ std::vector<intermediate_detection> dets_accum;
+ output_label_type final_dets;
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ tensor_to_dets(input_tensor, output_tensor, i, dets_accum, adjust_threshold, sub);
+
+ // Do non-max suppression
+ final_dets.clear();
+ for (unsigned long i = 0; i < dets_accum.size(); ++i)
+ {
+ if (overlaps_any_box_nms(final_dets, dets_accum[i].rect))
+ continue;
+
+ final_dets.push_back(mmod_rect(dets_accum[i].rect,
+ dets_accum[i].detection_confidence,
+ options.detector_windows[dets_accum[i].tensor_channel].label));
+ }
+
+ *iter++ = std::move(final_dets);
+ }
+ }
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size());
+
+ double det_thresh_speed_adjust = 0;
+
+
+ // we will scale the loss so that it doesn't get really huge
+ const double scale = 1.0/output_tensor.size();
+ double loss = 0;
+
+ float* g = grad.host_write_only();
+ for (size_t i = 0; i < grad.size(); ++i)
+ g[i] = 0;
+
+ const float* out_data = output_tensor.host();
+
+ std::vector<size_t> truth_idxs; truth_idxs.reserve(truth->size());
+ std::vector<intermediate_detection> dets;
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ tensor_to_dets(input_tensor, output_tensor, i, dets, -options.loss_per_false_alarm + det_thresh_speed_adjust, sub);
+
+ const unsigned long max_num_dets = 50 + truth->size()*5;
+ // Prevent calls to tensor_to_dets() from running for a really long time
+ // due to the production of an obscene number of detections.
+ const unsigned long max_num_initial_dets = max_num_dets*100;
+ if (dets.size() >= max_num_initial_dets)
+ {
+ det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm);
+ }
+
+
+ // The loss will measure the number of incorrect detections. A detection is
+ // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection
+ // on a truth rectangle.
+ loss += truth->size()*options.loss_per_missed_target;
+ for (auto&& x : *truth)
+ {
+ if (!x.ignore)
+ {
+ size_t k;
+ point p;
+ if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k, options.assume_image_pyramid))
+ {
+ // Ignore boxes that can't be detected by the CNN.
+ loss -= options.loss_per_missed_target;
+ continue;
+ }
+ const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
+ loss -= out_data[idx];
+ // compute gradient
+ g[idx] = -scale;
+ truth_idxs.push_back(idx);
+ }
+ else
+ {
+ // This box was ignored so shouldn't have been counted in the loss.
+ loss -= options.loss_per_missed_target;
+ truth_idxs.push_back(0);
+ }
+ }
+
+ // Measure the loss augmented score for the detections which hit a truth rect.
+ std::vector<double> truth_score_hits(truth->size(), 0);
+
+ // keep track of which truth boxes we have hit so far.
+ std::vector<bool> hit_truth_table(truth->size(), false);
+
+ std::vector<intermediate_detection> final_dets;
+ // The point of this loop is to fill out the truth_score_hits array.
+ for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i)
+ {
+ if (overlaps_any_box_nms(final_dets, dets[i].rect))
+ continue;
+
+ const auto& det_label = options.detector_windows[dets[i].tensor_channel].label;
+
+ const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect, det_label);
+
+ final_dets.push_back(dets[i].rect);
+
+ const double truth_match = hittruth.first;
+ // if hit truth rect
+ if (truth_match > options.truth_match_iou_threshold)
+ {
+ // if this is the first time we have seen a detect which hit (*truth)[hittruth.second]
+ const double score = dets[i].detection_confidence;
+ if (hit_truth_table[hittruth.second] == false)
+ {
+ hit_truth_table[hittruth.second] = true;
+ truth_score_hits[hittruth.second] += score;
+ }
+ else
+ {
+ truth_score_hits[hittruth.second] += score + options.loss_per_false_alarm;
+ }
+ }
+ }
+
+ // Check if any of the truth boxes are unobtainable because the NMS is
+ // killing them. If so, automatically set those unobtainable boxes to
+ // ignore and print a warning message to the user.
+ for (size_t i = 0; i < hit_truth_table.size(); ++i)
+ {
+ if (!hit_truth_table[i] && !(*truth)[i].ignore)
+ {
+ // So we didn't hit this truth box. Is that because there is
+ // another, different truth box, that overlaps it according to NMS?
+ const std::pair<double,unsigned int> hittruth = find_best_match(*truth, (*truth)[i], i);
+ if (hittruth.second == i || (*truth)[hittruth.second].ignore)
+ continue;
+ rectangle best_matching_truth_box = (*truth)[hittruth.second];
+ if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
+ {
+ const size_t idx = truth_idxs[i];
+ // We are ignoring this box so we shouldn't have counted it in the
+ // loss in the first place. So we subtract out the loss values we
+ // added for it in the code above.
+ loss -= options.loss_per_missed_target-out_data[idx];
+ g[idx] = 0;
+ std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
+ std::cout << " that is suppressed by non-max-suppression ";
+ std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
+ << " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
+ << box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
+ }
+ }
+ }
+
+ hit_truth_table.assign(hit_truth_table.size(), false);
+ final_dets.clear();
+
+
+ // Now figure out which detections jointly maximize the loss and detection score sum. We
+ // need to take into account the fact that allowing a true detection in the output, while
+ // initially reducing the loss, may allow us to increase the loss later with many duplicate
+ // detections.
+ for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i)
+ {
+ if (overlaps_any_box_nms(final_dets, dets[i].rect))
+ continue;
+
+ const auto& det_label = options.detector_windows[dets[i].tensor_channel].label;
+
+ const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect, det_label);
+
+ const double truth_match = hittruth.first;
+ if (truth_match > options.truth_match_iou_threshold)
+ {
+ if (truth_score_hits[hittruth.second] > options.loss_per_missed_target)
+ {
+ if (!hit_truth_table[hittruth.second])
+ {
+ hit_truth_table[hittruth.second] = true;
+ final_dets.push_back(dets[i]);
+ loss -= options.loss_per_missed_target;
+ }
+ else
+ {
+ final_dets.push_back(dets[i]);
+ loss += options.loss_per_false_alarm;
+ }
+ }
+ }
+ else if (!overlaps_ignore_box(*truth, dets[i].rect))
+ {
+ // didn't hit anything
+ final_dets.push_back(dets[i]);
+ loss += options.loss_per_false_alarm;
+ }
+ }
+
+ for (auto&& x : final_dets)
+ {
+ loss += out_data[x.tensor_offset];
+ g[x.tensor_offset] += scale;
+ }
+
+ ++truth;
+ g += output_tensor.k()*output_tensor.nr()*output_tensor.nc();
+ out_data += output_tensor.k()*output_tensor.nr()*output_tensor.nc();
+ } // END for (long i = 0; i < output_tensor.num_samples(); ++i)
+
+
+ // Here we scale the loss so that it's roughly equal to the number of mistakes
+ // in an image. Note that this scaling is different than the scaling we
+ // applied to the gradient but it doesn't matter since the loss value isn't
+ // used to update parameters. It's used only for display and to check if we
+ // have converged. So it doesn't matter that they are scaled differently and
+ // this way the loss that is displayed is readily interpretable to the user.
+ return loss/output_tensor.num_samples();
+ }
+
+
+ friend void serialize(const loss_mmod_& item, std::ostream& out)
+ {
+ serialize("loss_mmod_", out);
+ serialize(item.options, out);
+ }
+
+ friend void deserialize(loss_mmod_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_mmod_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_mmod_.");
+ deserialize(item.options, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_mmod_& item)
+ {
+ out << "loss_mmod\t (";
+
+ out << "detector_windows:(";
+ auto& opts = item.options;
+ for (size_t i = 0; i < opts.detector_windows.size(); ++i)
+ {
+ out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height;
+ if (i+1 < opts.detector_windows.size())
+ out << ",";
+ }
+ out << ")";
+ out << ", loss per FA:" << opts.loss_per_false_alarm;
+ out << ", loss per miss:" << opts.loss_per_missed_target;
+ out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold;
+ out << ", overlaps_nms:("<<opts.overlaps_nms.get_iou_thresh()<<","<<opts.overlaps_nms.get_percent_covered_thresh()<<")";
+ out << ", overlaps_ignore:("<<opts.overlaps_ignore.get_iou_thresh()<<","<<opts.overlaps_ignore.get_percent_covered_thresh()<<")";
+
+ out << ")";
+ return out;
+ }
+
+ friend void to_xml(const loss_mmod_& /*item*/, std::ostream& out)
+ {
+ // TODO, add options fields
+ out << "<loss_mmod/>";
+ }
+
+ private:
+
+ template <typename net_type>
+ void tensor_to_dets (
+ const tensor& input_tensor,
+ const tensor& output_tensor,
+ long i,
+ std::vector<intermediate_detection>& dets_accum,
+ double adjust_threshold,
+ const net_type& net
+ ) const
+ {
+ DLIB_CASSERT(net.sample_expansion_factor() == 1,net.sample_expansion_factor());
+ DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size());
+ const float* out_data = output_tensor.host() + output_tensor.k()*output_tensor.nr()*output_tensor.nc()*i;
+ // scan the final layer and output the positive scoring locations
+ dets_accum.clear();
+ for (long k = 0; k < output_tensor.k(); ++k)
+ {
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ double score = out_data[(k*output_tensor.nr() + r)*output_tensor.nc() + c];
+ if (score > adjust_threshold)
+ {
+ dpoint p = output_tensor_to_input_tensor(net, point(c,r));
+ drectangle rect = centered_drect(p, options.detector_windows[k].width, options.detector_windows[k].height);
+ rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect);
+
+ dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c, k));
+ }
+ }
+ }
+ }
+ std::sort(dets_accum.rbegin(), dets_accum.rend());
+ }
+
+ size_t find_best_detection_window (
+ rectangle rect,
+ const std::string& label,
+ use_image_pyramid assume_image_pyramid
+ ) const
+ {
+ if (assume_image_pyramid == use_image_pyramid::yes)
+ {
+ rect = move_rect(set_rect_area(rect, 400*400), point(0,0));
+ }
+ else
+ {
+ rect = rectangle(rect.width(), rect.height());
+ }
+
+ // Figure out which detection window in options.detector_windows is most similar to rect
+ // (in terms of aspect ratio, if assume_image_pyramid == use_image_pyramid::yes).
+ size_t best_i = 0;
+ double best_ratio_diff = -std::numeric_limits<double>::infinity();
+ for (size_t i = 0; i < options.detector_windows.size(); ++i)
+ {
+ if (options.detector_windows[i].label != label)
+ continue;
+
+ rectangle det_window;
+
+ if (options.assume_image_pyramid == use_image_pyramid::yes)
+ {
+ det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height);
+ det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0));
+ }
+ else
+ {
+ det_window = rectangle(options.detector_windows[i].width, options.detector_windows[i].height);
+ }
+
+ double iou = box_intersection_over_union(rect, det_window);
+ if (iou > best_ratio_diff)
+ {
+ best_ratio_diff = iou;
+ best_i = i;
+ }
+ }
+ return best_i;
+ }
+
+ template <typename net_type>
+ bool image_rect_to_feat_coord (
+ point& tensor_p,
+ const tensor& input_tensor,
+ const rectangle& rect,
+ const std::string& label,
+ const net_type& net,
+ size_t& det_idx,
+ use_image_pyramid assume_image_pyramid
+ ) const
+ {
+ using namespace std;
+ if (!input_layer(net).image_contained_point(input_tensor,center(rect)))
+ {
+ std::ostringstream sout;
+ sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << endl;
+ sout << "The center of each truth rectangle must be within the image." << endl;
+ throw impossible_labeling_error(sout.str());
+ }
+
+ det_idx = find_best_detection_window(rect,label,assume_image_pyramid);
+
+ double scale = 1.0;
+ if (options.assume_image_pyramid == use_image_pyramid::yes)
+ {
+ // Compute the scale we need to be at to get from rect to our detection window.
+ // Note that we compute the scale as the max of two numbers. It doesn't
+ // actually matter which one we pick, because if they are very different then
+ // it means the box can't be matched by the sliding window. But picking the
+ // max causes the right error message to be selected in the logic below.
+ scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height());
+ }
+ else
+ {
+ // We don't want invariance to scale.
+ scale = 1.0;
+ }
+
+ const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
+
+ // compute the detection window that we would use at this position.
+ tensor_p = center(mapped_rect);
+ rectangle det_window = centered_rect(tensor_p, options.detector_windows[det_idx].width,options.detector_windows[det_idx].height);
+ det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window);
+
+ // make sure the rect can actually be represented by the image pyramid we are
+ // using.
+ if (box_intersection_over_union(rect, det_window) <= options.truth_match_iou_threshold)
+ {
+ std::cout << "Warning, ignoring object. We encountered a truth rectangle with a width and height of " << rect.width() << " and " << rect.height() << ". ";
+ std::cout << "The image pyramid and sliding windows can't output a rectangle of this shape. ";
+ const double detector_area = options.detector_windows[det_idx].width*options.detector_windows[det_idx].height;
+ if (mapped_rect.area()/detector_area <= options.truth_match_iou_threshold)
+ {
+ std::cout << "This is because the rectangle is smaller than the best matching detection window, which has a width ";
+ std::cout << "and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl;
+ }
+ else
+ {
+ std::cout << "This is either because (1) the final layer's features have too large of a stride across the image, limiting the possible locations the sliding window can search ";
+ std::cout << "or (2) because the rectangle's aspect ratio is too different from the best matching detection window, ";
+ std::cout << "which has a width and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl;
+ }
+ return true;
+ }
+
+ // now map through the CNN to the output layer.
+ tensor_p = input_tensor_to_output_tensor(net,tensor_p);
+
+ const tensor& output_tensor = net.get_output();
+ if (!get_rect(output_tensor).contains(tensor_p))
+ {
+ std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << rect << " that is too close to the edge ";
+ std::cout << "of the image to be captured by the CNN features." << std::endl;
+ return true;
+ }
+
+ return false;
+ }
+
+
+ bool overlaps_ignore_box (
+ const std::vector<mmod_rect>& boxes,
+ const rectangle& rect
+ ) const
+ {
+ for (auto&& b : boxes)
+ {
+ if (b.ignore && options.overlaps_ignore(b, rect))
+ return true;
+ }
+ return false;
+ }
+
+ std::pair<double,unsigned int> find_best_match(
+ const std::vector<mmod_rect>& boxes,
+ const rectangle& rect,
+ const std::string& label
+ ) const
+ {
+ double match = 0;
+ unsigned int best_idx = 0;
+ for (unsigned long i = 0; i < boxes.size(); ++i)
+ {
+ if (boxes[i].ignore || boxes[i].label != label)
+ continue;
+
+ const double new_match = box_intersection_over_union(rect, boxes[i]);
+ if (new_match > match)
+ {
+ match = new_match;
+ best_idx = i;
+ }
+ }
+
+ return std::make_pair(match,best_idx);
+ }
+
+ std::pair<double,unsigned int> find_best_match(
+ const std::vector<mmod_rect>& boxes,
+ const rectangle& rect,
+ const size_t excluded_idx
+ ) const
+ {
+ double match = 0;
+ unsigned int best_idx = 0;
+ for (unsigned long i = 0; i < boxes.size(); ++i)
+ {
+ if (boxes[i].ignore || excluded_idx == i)
+ continue;
+
+ const double new_match = box_intersection_over_union(rect, boxes[i]);
+ if (new_match > match)
+ {
+ match = new_match;
+ best_idx = i;
+ }
+ }
+
+ return std::make_pair(match,best_idx);
+ }
+
+ template <typename T>
+ inline bool overlaps_any_box_nms (
+ const std::vector<T>& rects,
+ const rectangle& rect
+ ) const
+ {
+ for (auto&& r : rects)
+ {
+ if (options.overlaps_nms(r.rect, rect))
+ return true;
+ }
+ return false;
+ }
+
+
+ mmod_options options;
+
+ };
+
+ template <typename SUBNET>
+ using loss_mmod = add_loss_layer<loss_mmod_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_metric_
+ {
+ public:
+
+ typedef unsigned long training_label_type;
+ typedef matrix<float,0,1> output_label_type;
+
+ loss_metric_() = default;
+
+ loss_metric_(
+ float margin_,
+ float dist_thresh_
+ ) : margin(margin_), dist_thresh(dist_thresh_)
+ {
+ DLIB_CASSERT(margin_ > 0);
+ DLIB_CASSERT(dist_thresh_ > 0);
+ }
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1);
+
+ const float* p = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter = mat(p,output_tensor.k(),1);
+
+ ++iter;
+ p += output_tensor.k();
+ }
+ }
+
+
+ float get_margin() const { return margin; }
+ float get_distance_threshold() const { return dist_thresh; }
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1);
+
+
+
+ temp.set_size(output_tensor.num_samples(), output_tensor.num_samples());
+ grad_mul.copy_size(temp);
+
+ tt::gemm(0, temp, 1, output_tensor, false, output_tensor, true);
+
+
+ std::vector<double> temp_threshs;
+ const float* d = temp.host();
+ double loss = 0;
+ double num_pos_samps = 0.0001;
+ double num_neg_samps = 0.0001;
+ for (long r = 0; r < temp.num_samples(); ++r)
+ {
+ auto xx = d[r*temp.num_samples() + r];
+ const auto x_label = *(truth + r);
+ for (long c = r+1; c < temp.num_samples(); ++c)
+ {
+ const auto y_label = *(truth + c);
+ if (x_label == y_label)
+ {
+ ++num_pos_samps;
+ }
+ else
+ {
+ ++num_neg_samps;
+
+ // Figure out what distance threshold, when applied to the negative pairs,
+ // causes there to be an equal number of positive and negative pairs.
+ auto yy = d[c*temp.num_samples() + c];
+ auto xy = d[r*temp.num_samples() + c];
+ // compute the distance between x and y samples.
+ auto d2 = xx + yy - 2*xy;
+ if (d2 < 0)
+ d2 = 0;
+ temp_threshs.push_back(d2);
+ }
+ }
+ }
+ // The whole objective function is multiplied by this to scale the loss
+ // relative to the number of things in the mini-batch.
+ const double scale = 0.5/num_pos_samps;
+ DLIB_CASSERT(num_pos_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs");
+ DLIB_CASSERT(num_neg_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs");
+
+ std::sort(temp_threshs.begin(), temp_threshs.end());
+ const float neg_thresh = std::sqrt(temp_threshs[std::min(num_pos_samps,num_neg_samps)-1]);
+
+ // loop over all the pairs of training samples and compute the loss and
+ // gradients. Note that we only use the hardest negative pairs and that in
+ // particular we pick the number of negative pairs equal to the number of
+ // positive pairs so everything is balanced.
+ float* gm = grad_mul.host();
+ for (long r = 0; r < temp.num_samples(); ++r)
+ {
+ gm[r*temp.num_samples() + r] = 0;
+ const auto x_label = *(truth + r);
+ auto xx = d[r*temp.num_samples() + r];
+ for (long c = 0; c < temp.num_samples(); ++c)
+ {
+ if (r==c)
+ continue;
+ const auto y_label = *(truth + c);
+ auto yy = d[c*temp.num_samples() + c];
+ auto xy = d[r*temp.num_samples() + c];
+
+ // compute the distance between x and y samples.
+ auto d2 = xx + yy - 2*xy;
+ if (d2 <= 0)
+ d2 = 0;
+ else
+ d2 = std::sqrt(d2);
+
+ // It should be noted that the derivative of length(x-y) with respect
+ // to the x vector is the unit vector (x-y)/length(x-y). If you stare
+ // at the code below long enough you will see that it's just an
+ // application of this formula.
+
+ if (x_label == y_label)
+ {
+ // Things with the same label should have distances < dist_thresh between
+ // them. If not then we experience non-zero loss.
+ if (d2 < dist_thresh-margin)
+ {
+ gm[r*temp.num_samples() + c] = 0;
+ }
+ else
+ {
+ loss += scale*(d2 - (dist_thresh-margin));
+ gm[r*temp.num_samples() + r] += scale/d2;
+ gm[r*temp.num_samples() + c] = -scale/d2;
+ }
+ }
+ else
+ {
+ // Things with different labels should have distances > dist_thresh between
+ // them. If not then we experience non-zero loss.
+ if (d2 > dist_thresh+margin || d2 > neg_thresh)
+ {
+ gm[r*temp.num_samples() + c] = 0;
+ }
+ else
+ {
+ loss += scale*((dist_thresh+margin) - d2);
+ // don't divide by zero (or a really small number)
+ d2 = std::max(d2, 0.001f);
+ gm[r*temp.num_samples() + r] -= scale/d2;
+ gm[r*temp.num_samples() + c] = scale/d2;
+ }
+ }
+ }
+ }
+
+
+ tt::gemm(0, grad, 1, grad_mul, false, output_tensor, false);
+
+ return loss;
+ }
+
+ friend void serialize(const loss_metric_& item, std::ostream& out)
+ {
+ serialize("loss_metric_2", out);
+ serialize(item.margin, out);
+ serialize(item.dist_thresh, out);
+ }
+
+ friend void deserialize(loss_metric_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version == "loss_metric_")
+ {
+ // These values used to be hard coded, so for this version of the metric
+ // learning loss we just use these values.
+ item.margin = 0.1;
+ item.dist_thresh = 0.75;
+ return;
+ }
+ else if (version == "loss_metric_2")
+ {
+ deserialize(item.margin, in);
+ deserialize(item.dist_thresh, in);
+ }
+ else
+ {
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_metric_. Instead found " + version);
+ }
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_metric_& item )
+ {
+ out << "loss_metric (margin="<<item.margin<<", distance_threshold="<<item.dist_thresh<<")";
+ return out;
+ }
+
+ friend void to_xml(const loss_metric_& item, std::ostream& out)
+ {
+ out << "<loss_metric margin='"<<item.margin<<"' distance_threshold='"<<item.dist_thresh<<"'/>";
+ }
+
+ private:
+ float margin = 0.04;
+ float dist_thresh = 0.6;
+
+
+ // These variables are only here to avoid being reallocated over and over in
+ // compute_loss_value_and_gradient()
+ mutable resizable_tensor temp, grad_mul;
+
+ };
+
+ template <typename SUBNET>
+ using loss_metric = add_loss_layer<loss_metric_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_ranking_
+ {
+ public:
+
+ typedef float training_label_type; // nominally +1/-1
+ typedef float output_label_type; // ranking score
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = out_data[i];
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ grad.k() == 1);
+
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<long> rel_idx, nonrel_idx;
+
+ const float* out_data = output_tensor.host();
+ float* g = grad.host_write_only();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const float y = *truth++;
+ if (y > 0)
+ {
+ rel_scores.push_back(out_data[i]-y);
+ rel_idx.push_back(i);
+ }
+ else if (y < 0)
+ {
+ nonrel_scores.push_back(out_data[i]-y);
+ nonrel_idx.push_back(i);
+ }
+ else
+ {
+ g[i] = 0;
+ }
+ }
+
+
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+ const unsigned long total_pairs = rel_scores.size()*nonrel_scores.size();
+ DLIB_CASSERT(total_pairs > 0, "You can't give a ranking mini-batch that contains only one class. Both classes must be represented.");
+ const double scale = 1.0/total_pairs;
+
+
+ double loss = 0;
+ for (unsigned long k = 0; k < rel_counts.size(); ++k)
+ {
+ loss -= rel_counts[k]*rel_scores[k];
+ g[rel_idx[k]] = -1.0*rel_counts[k]*scale;
+ }
+
+ for (unsigned long k = 0; k < nonrel_counts.size(); ++k)
+ {
+ loss += nonrel_counts[k]*nonrel_scores[k];
+ g[nonrel_idx[k]] = nonrel_counts[k]*scale;
+ }
+
+ return loss*scale;
+ }
+
+ friend void serialize(const loss_ranking_& , std::ostream& out)
+ {
+ serialize("loss_ranking_", out);
+ }
+
+ friend void deserialize(loss_ranking_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_ranking_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_ranking_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_ranking_& )
+ {
+ out << "loss_ranking";
+ return out;
+ }
+
+ friend void to_xml(const loss_ranking_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_ranking/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_ranking = add_loss_layer<loss_ranking_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_
+ {
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = out_data[i];
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ grad.k() == 1);
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host_write_only();
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const float y = *truth++;
+ const float temp1 = y - out_data[i];
+ const float temp2 = scale*temp1;
+ loss += temp2*temp1;
+ g[i] = -temp2;
+
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_mean_squared_& , std::ostream& out)
+ {
+ serialize("loss_mean_squared_", out);
+ }
+
+ friend void deserialize(loss_mean_squared_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_mean_squared_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_& )
+ {
+ out << "loss_mean_squared";
+ return out;
+ }
+
+ friend void to_xml(const loss_mean_squared_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_mean_squared/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared = add_loss_layer<loss_mean_squared_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_epsilon_insensitive_
+ {
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ loss_epsilon_insensitive_() = default;
+ loss_epsilon_insensitive_(double eps) : eps(eps)
+ {
+ DLIB_CASSERT(eps >= 0, "You can't set a negative error epsilon.");
+ }
+
+ double get_epsilon () const { return eps; }
+ void set_epsilon(double e)
+ {
+ DLIB_CASSERT(e >= 0, "You can't set a negative error epsilon.");
+ eps = e;
+ }
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = out_data[i];
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1 &&
+ output_tensor.k() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1 &&
+ grad.k() == 1);
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host_write_only();
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ const float y = *truth++;
+ const float err = out_data[i]-y;
+ if (err > eps)
+ {
+ loss += scale*(err-eps);
+ g[i] = scale;
+ }
+ else if (err < -eps)
+ {
+ loss += scale*(eps-err);
+ g[i] = -scale;
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_epsilon_insensitive_& item, std::ostream& out)
+ {
+ serialize("loss_epsilon_insensitive_", out);
+ serialize(item.eps, out);
+ }
+
+ friend void deserialize(loss_epsilon_insensitive_& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_epsilon_insensitive_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_epsilon_insensitive_.");
+ deserialize(item.eps, in);
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_epsilon_insensitive_& item)
+ {
+ out << "loss_epsilon_insensitive epsilon: " << item.eps;
+ return out;
+ }
+
+ friend void to_xml(const loss_epsilon_insensitive_& item, std::ostream& out)
+ {
+ out << "<loss_epsilon_insensitive_ epsilon='" << item.eps << "'/>";
+ }
+
+ private:
+ double eps = 1;
+
+ };
+
+ template <typename SUBNET>
+ using loss_epsilon_insensitive = add_loss_layer<loss_epsilon_insensitive_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_multioutput_
+ {
+ public:
+
+ typedef matrix<float> training_label_type;
+ typedef matrix<float> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1)
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ *iter++ = mat(out_data, output_tensor.k(), 1);
+ out_data += output_tensor.k();
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.nr() == 1 &&
+ output_tensor.nc() == 1);
+ DLIB_CASSERT(grad.nr() == 1 &&
+ grad.nc() == 1);
+ DLIB_CASSERT(grad.k() == output_tensor.k());
+ const long k = output_tensor.k();
+ for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
+ {
+ const_label_iterator truth_matrix_ptr = (truth + idx);
+ DLIB_CASSERT((*truth_matrix_ptr).nr() == k &&
+ (*truth_matrix_ptr).nc() == 1);
+ }
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host_write_only();
+ const float* out_data = output_tensor.host();
+ matrix<float> ytrue;
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ ytrue = *truth++;
+ for (long j = 0; j < output_tensor.k(); ++j)
+ {
+ const float y = ytrue(j, 0);
+ const float temp1 = y - *out_data++;
+ const float temp2 = scale*temp1;
+ loss += temp2*temp1;
+ *g = -temp2;
+ ++g;
+ }
+
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_mean_squared_multioutput_& , std::ostream& out)
+ {
+ serialize("loss_mean_squared_multioutput_", out);
+ }
+
+ friend void deserialize(loss_mean_squared_multioutput_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_mean_squared_multioutput_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_multioutput_& )
+ {
+ out << "loss_mean_squared_multioutput";
+ return out;
+ }
+
+ friend void to_xml(const loss_mean_squared_multioutput_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_mean_squared_multioutput/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared_multioutput = add_loss_layer<loss_mean_squared_multioutput_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_per_pixel_
+ {
+ public:
+
+ // In semantic segmentation, if you don't know the ground-truth of some pixel,
+ // set the label of that pixel to this value. When you do so, the pixel will be
+ // ignored when computing gradients.
+ static const uint16_t label_to_ignore = std::numeric_limits<uint16_t>::max();
+
+
+ // In semantic segmentation, 65535 classes ought to be enough for anybody.
+ typedef matrix<uint16_t> training_label_type;
+ typedef matrix<uint16_t> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ static void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ )
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.k() >= 1); // Note that output_tensor.k() should match the number of labels.
+ DLIB_CASSERT(output_tensor.k() < std::numeric_limits<uint16_t>::max());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* const out_data = output_tensor.host();
+
+ // The index of the largest output for each element is the label.
+ const auto find_label = [&](long sample, long r, long c)
+ {
+ uint16_t label = 0;
+ float max_value = out_data[tensor_index(output_tensor, sample, 0, r, c)];
+ for (long k = 1; k < output_tensor.k(); ++k)
+ {
+ const float value = out_data[tensor_index(output_tensor, sample, k, r, c)];
+ if (value > max_value)
+ {
+ label = static_cast<uint16_t>(k);
+ max_value = value;
+ }
+ }
+ return label;
+ };
+
+ for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter)
+ {
+ iter->set_size(output_tensor.nr(), output_tensor.nc());
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ // The index of the largest output for this element is the label.
+ iter->operator()(r, c) = find_label(i, r, c);
+ }
+ }
+ }
+ }
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.k() >= 1);
+ DLIB_CASSERT(output_tensor.k() < std::numeric_limits<uint16_t>::max());
+ DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
+ output_tensor.nc() == grad.nc() &&
+ output_tensor.k() == grad.k());
+ for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
+ {
+ const_label_iterator truth_matrix_ptr = (truth + idx);
+ DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() &&
+ truth_matrix_ptr->nc() == output_tensor.nc(),
+ "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", "
+ "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
+ }
+
+ tt::softmax(grad, output_tensor);
+
+ // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
+ const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
+ double loss = 0;
+ float* const g = grad.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
+ {
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ const uint16_t y = truth->operator()(r, c);
+ // The network must produce a number of outputs that is equal to the number
+ // of labels when using this type of loss.
+ DLIB_CASSERT(static_cast<long>(y) < output_tensor.k() || y == label_to_ignore,
+ "y: " << y << ", output_tensor.k(): " << output_tensor.k());
+ for (long k = 0; k < output_tensor.k(); ++k)
+ {
+ const size_t idx = tensor_index(output_tensor, i, k, r, c);
+ if (k == y)
+ {
+ loss += scale*-safe_log(g[idx]);
+ g[idx] = scale*(g[idx] - 1);
+ }
+ else if (y == label_to_ignore)
+ {
+ g[idx] = 0.f;
+ }
+ else
+ {
+ g[idx] = scale*g[idx];
+ }
+ }
+ }
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_multiclass_log_per_pixel_& , std::ostream& out)
+ {
+ serialize("loss_multiclass_log_per_pixel_", out);
+ }
+
+ friend void deserialize(loss_multiclass_log_per_pixel_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_multiclass_log_per_pixel_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_& )
+ {
+ out << "loss_multiclass_log_per_pixel";
+ return out;
+ }
+
+ friend void to_xml(const loss_multiclass_log_per_pixel_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_multiclass_log_per_pixel/>";
+ }
+
+ private:
+ static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
+ {
+ // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
+ return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log_per_pixel = add_loss_layer<loss_multiclass_log_per_pixel_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_per_pixel_weighted_
+ {
+ public:
+
+ struct weighted_label
+ {
+ weighted_label()
+ {}
+
+ weighted_label(uint16_t label, float weight = 1.f)
+ : label(label), weight(weight)
+ {}
+
+ // In semantic segmentation, 65536 classes ought to be enough for anybody.
+ uint16_t label = 0;
+ float weight = 1.f;
+ };
+
+ typedef matrix<weighted_label> training_label_type;
+ typedef matrix<uint16_t> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ static void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ )
+ {
+ loss_multiclass_log_per_pixel_::to_label(input_tensor, sub, iter);
+ }
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.k() >= 1);
+ DLIB_CASSERT(output_tensor.k() < std::numeric_limits<uint16_t>::max());
+ DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
+ output_tensor.nc() == grad.nc() &&
+ output_tensor.k() == grad.k());
+ for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
+ {
+ const_label_iterator truth_matrix_ptr = (truth + idx);
+ DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() &&
+ truth_matrix_ptr->nc() == output_tensor.nc(),
+ "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", "
+ "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
+ }
+
+ tt::softmax(grad, output_tensor);
+
+ // The loss we output is the weighted average loss over the mini-batch, and also over each element of the matrix output.
+ const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
+ double loss = 0;
+ float* const g = grad.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
+ {
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ const weighted_label& weighted_label = truth->operator()(r, c);
+ const uint16_t y = weighted_label.label;
+ const float weight = weighted_label.weight;
+ // The network must produce a number of outputs that is equal to the number
+ // of labels when using this type of loss.
+ DLIB_CASSERT(static_cast<long>(y) < output_tensor.k() || weight == 0.f,
+ "y: " << y << ", output_tensor.k(): " << output_tensor.k());
+ for (long k = 0; k < output_tensor.k(); ++k)
+ {
+ const size_t idx = tensor_index(output_tensor, i, k, r, c);
+ if (k == y)
+ {
+ loss += weight*scale*-safe_log(g[idx]);
+ g[idx] = weight*scale*(g[idx] - 1);
+ }
+ else
+ {
+ g[idx] = weight*scale*g[idx];
+ }
+ }
+ }
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_multiclass_log_per_pixel_weighted_& , std::ostream& out)
+ {
+ serialize("loss_multiclass_log_per_pixel_weighted_", out);
+ }
+
+ friend void deserialize(loss_multiclass_log_per_pixel_weighted_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_multiclass_log_per_pixel_weighted_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_weighted_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_weighted_& )
+ {
+ out << "loss_multiclass_log_per_pixel_weighted";
+ return out;
+ }
+
+ friend void to_xml(const loss_multiclass_log_per_pixel_weighted_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_multiclass_log_per_pixel_weighted/>";
+ }
+
+ private:
+ static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
+ {
+ // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
+ return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log_per_pixel_weighted = add_loss_layer<loss_multiclass_log_per_pixel_weighted_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_per_pixel_
+ {
+ public:
+
+ typedef matrix<float> training_label_type;
+ typedef matrix<float> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+
+ const tensor& output_tensor = sub.get_output();
+
+ DLIB_CASSERT(output_tensor.k() == 1, "output k = " << output_tensor.k());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter)
+ {
+ iter->set_size(output_tensor.nr(), output_tensor.nc());
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, 0, r, c)];
+ }
+ }
+ }
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+ DLIB_CASSERT(output_tensor.k() >= 1);
+ DLIB_CASSERT(output_tensor.k() < std::numeric_limits<uint16_t>::max());
+ DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
+ output_tensor.nc() == grad.nc() &&
+ output_tensor.k() == grad.k());
+ for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
+ {
+ const_label_iterator truth_matrix_ptr = (truth + idx);
+ DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() &&
+ truth_matrix_ptr->nc() == output_tensor.nc(),
+ "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", "
+ "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
+ }
+
+ // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
+ const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
+ double loss = 0;
+ float* const g = grad.host();
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
+ {
+ for (long r = 0; r < output_tensor.nr(); ++r)
+ {
+ for (long c = 0; c < output_tensor.nc(); ++c)
+ {
+ const float y = truth->operator()(r, c);
+ const size_t idx = tensor_index(output_tensor, i, 0, r, c);
+ const float temp1 = y - out_data[idx];
+ const float temp2 = scale*temp1;
+ loss += temp2*temp1;
+ g[idx] = -temp2;
+ }
+ }
+ }
+ return loss;
+ }
+
+ friend void serialize(const loss_mean_squared_per_pixel_& , std::ostream& out)
+ {
+ serialize("loss_mean_squared_per_pixel_", out);
+ }
+
+ friend void deserialize(loss_mean_squared_per_pixel_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_mean_squared_per_pixel_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_pixel_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_pixel_& )
+ {
+ out << "loss_mean_squared_per_pixel";
+ return out;
+ }
+
+ friend void to_xml(const loss_mean_squared_per_pixel_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_mean_squared_per_pixel/>";
+ }
+
+ private:
+ static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
+ {
+ // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
+ return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
+ }
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared_per_pixel = add_loss_layer<loss_mean_squared_per_pixel_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_dot_
+ {
+ public:
+
+ typedef matrix<float,0,1> training_label_type;
+ typedef matrix<float,0,1> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ *iter++ = trans(rowm(mat(output_tensor),i));
+ }
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const
+ {
+ const tensor& output_tensor = sub.get_output();
+ tensor& grad = sub.get_gradient_input();
+
+ DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+ DLIB_CASSERT(input_tensor.num_samples() != 0);
+ DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
+ DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
+ DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+ const long network_output_dims = output_tensor.size()/output_tensor.num_samples();
+
+
+ // The loss we output is the average loss over the mini-batch.
+ const double scale = 1.0/output_tensor.num_samples();
+ double loss = 0;
+ float* g = grad.host();
+ const float* out_data = output_tensor.host();
+ for (long i = 0; i < output_tensor.num_samples(); ++i)
+ {
+ DLIB_CASSERT(truth->size() == network_output_dims, "The network must output a vector with the same dimensionality as the training labels. "
+ << "\ntruth->size(): " << truth->size()
+ << "\nnetwork_output_dims: " << network_output_dims);
+
+ const float* t = &(*truth++)(0);
+
+ for (long j = 0; j < network_output_dims; ++j)
+ {
+ g[j] = -t[j]*scale;
+ loss -= out_data[j]*t[j];
+ }
+
+ g += network_output_dims;
+ out_data += network_output_dims;
+ }
+ return loss*scale;
+ }
+
+ friend void serialize(const loss_dot_& , std::ostream& out)
+ {
+ serialize("loss_dot_", out);
+ }
+
+ friend void deserialize(loss_dot_& , std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "loss_dot_")
+ throw serialization_error("Unexpected version found while deserializing dlib::loss_dot_.");
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, const loss_dot_& )
+ {
+ out << "loss_dot";
+ return out;
+ }
+
+ friend void to_xml(const loss_dot_& /*item*/, std::ostream& out)
+ {
+ out << "<loss_dot/>";
+ }
+
+ };
+
+ template <typename SUBNET>
+ using loss_dot = add_loss_layer<loss_dot_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_LOSS_H_
+
diff --git a/ml/dlib/dlib/dnn/loss_abstract.h b/ml/dlib/dlib/dnn/loss_abstract.h
new file mode 100644
index 000000000..0dd043677
--- /dev/null
+++ b/ml/dlib/dlib/dnn/loss_abstract.h
@@ -0,0 +1,1542 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_LOSS_ABSTRACT_H_
+#ifdef DLIB_DNn_LOSS_ABSTRACT_H_
+
+#include "core_abstract.h"
+#include "../image_processing/full_object_detection_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class EXAMPLE_LOSS_LAYER_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ A loss layer is the final layer in a deep neural network. It computes the
+ task loss. That is, it computes a number that tells us how well the
+ network is performing on some task, such as predicting a binary label.
+
+ You can use one of the loss layers that comes with dlib (defined below).
+ But importantly, you are able to define your own loss layers to suit your
+ needs. You do this by creating a class that defines an interface matching
+ the one described by this EXAMPLE_LOSS_LAYER_ class. Note that there is no
+ dlib::EXAMPLE_LOSS_LAYER_ type. It is shown here purely to document the
+ interface that a loss layer must implement.
+
+ A loss layer can optionally provide a to_label() method that converts the
+ output of a network into a user defined type. If to_label() is not
+ provided then the operator() methods of add_loss_layer will not be
+ available, but otherwise everything will function as normal.
+
+ Finally, note that there are two broad flavors of loss layer, supervised
+ and unsupervised. The EXAMPLE_LOSS_LAYER_ as shown here is a supervised
+ layer. To make an unsupervised loss you simply leave out the
+ training_label_type typedef and the truth iterator argument to
+ compute_loss_value_and_gradient().
+ !*/
+
+ public:
+
+ // In most cases training_label_type and output_label_type will be the same type.
+ typedef whatever_type_you_use_for_training_labels training_label_type;
+ typedef whatever_type_you_use_for_outout_labels output_label_type;
+
+ EXAMPLE_LOSS_LAYER_ (
+ );
+ /*!
+ ensures
+ - EXAMPLE_LOSS_LAYER_ objects are default constructable.
+ !*/
+
+ EXAMPLE_LOSS_LAYER_ (
+ const EXAMPLE_LOSS_LAYER_& item
+ );
+ /*!
+ ensures
+ - EXAMPLE_LOSS_LAYER_ objects are copy constructable.
+ !*/
+
+ // Implementing to_label() is optional.
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ requires
+ - SUBNET implements the SUBNET interface defined at the top of
+ layers_abstract.h.
+ - input_tensor was given as input to the network sub and the outputs are
+ now visible in layer<i>(sub).get_output(), for all valid i.
+ - input_tensor.num_samples() > 0
+ - input_tensor.num_samples()%sub.sample_expansion_factor() == 0.
+ - iter == an iterator pointing to the beginning of a range of
+ input_tensor.num_samples()/sub.sample_expansion_factor() elements. Moreover,
+ they must be output_label_type elements.
+ ensures
+ - Converts the output of the provided network to output_label_type objects and
+ stores the results into the range indicated by iter. In particular, for
+ all valid i, it will be the case that:
+ *(iter+i/sub.sample_expansion_factor()) is populated based on the output of
+ sub and corresponds to the ith sample in input_tensor.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ requires
+ - SUBNET implements the SUBNET interface defined at the top of
+ layers_abstract.h.
+ - input_tensor was given as input to the network sub and the outputs are
+ now visible in layer<i>(sub).get_output(), for all valid i.
+ - input_tensor.num_samples() > 0
+ - input_tensor.num_samples()%sub.sample_expansion_factor() == 0.
+ - for all valid i:
+ - layer<i>(sub).get_gradient_input() has the same dimensions as
+ layer<i>(sub).get_output().
+ - layer<i>(sub).get_gradient_input() contains all zeros (i.e.
+ initially, all input gradients are 0).
+ - truth == an iterator pointing to the beginning of a range of
+ input_tensor.num_samples()/sub.sample_expansion_factor() elements. Moreover,
+ they must be training_label_type elements.
+ - for all valid i:
+ - *(truth+i/sub.sample_expansion_factor()) is the label of the ith sample in
+ input_tensor.
+ ensures
+ - This function computes a loss function that describes how well the output
+ of sub matches the expected labels given by truth. Let's write the loss
+ function as L(input_tensor, truth, sub).
+ - Then compute_loss_value_and_gradient() computes the gradient of L() with
+ respect to the outputs in sub. Specifically, compute_loss_value_and_gradient()
+ assigns the gradients into sub by performing the following tensor
+ assignments, for all valid i:
+ - layer<i>(sub).get_gradient_input() = the gradient of
+ L(input_tensor,truth,sub) with respect to layer<i>(sub).get_output().
+ Note that, since get_gradient_input() is zero initialized, you don't
+ have to write gradient information to layers that have a zero
+ loss gradient.
+ - returns L(input_tensor,truth,sub)
+ !*/
+ };
+
+ std::ostream& operator<<(std::ostream& out, const EXAMPLE_LOSS_LAYER_& item);
+ /*!
+ print a string describing this layer.
+ !*/
+
+ void to_xml(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out);
+ /*!
+ This function is optional, but required if you want to print your networks with
+ net_to_xml(). Therefore, to_xml() prints a layer as XML.
+ !*/
+
+ void serialize(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out);
+ void deserialize(EXAMPLE_LOSS_LAYER_& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ // For each loss layer you define, always define an add_loss_layer template so that
+ // layers can be easily composed. Moreover, the convention is that the layer class
+ // ends with an _ while the add_loss_layer template has the same name but without the
+ // trailing _.
+ template <typename SUBNET>
+ using EXAMPLE_LOSS_LAYER = add_loss_layer<EXAMPLE_LOSS_LAYER_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class loss_binary_hinge_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the hinge loss, which is
+ appropriate for binary classification problems. Therefore, the possible
+ labels when using this loss are +1 and -1. Moreover, it will cause the
+ network to produce outputs > 0 when predicting a member of the +1 class and
+ values < 0 otherwise.
+ !*/
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the raw score for each classified object. If the score
+ is > 0 then the classifier is predicting the +1 class, otherwise it is
+ predicting the -1 class.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - all values pointed to by truth are +1 or -1.
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_binary_hinge = add_loss_layer<loss_binary_hinge_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_binary_log_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the log loss, which is
+ appropriate for binary classification problems. Therefore, the possible
+ labels when using this loss are +1 and -1. Moreover, it will cause the
+ network to produce outputs > 0 when predicting a member of the +1 class and
+ values < 0 otherwise.
+
+ To be more specific, this object contains a sigmoid layer followed by a
+ cross-entropy layer.
+ !*/
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the raw score for each classified object. If the score
+ is > 0 then the classifier is predicting the +1 class, otherwise it is
+ predicting the -1 class.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - all values pointed to by truth are +1 or -1.
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_binary_log = add_loss_layer<loss_binary_log_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic
+ regression loss (e.g. negative log-likelihood loss), which is appropriate
+ for multiclass classification problems. This means that the possible
+ labels when using this loss are integers >= 0.
+
+ Moreover, if after training you were to replace the loss layer of the
+ network with a softmax layer, the network outputs would give the
+ probabilities of each class assignment. That is, if you have K classes
+ then the network should output tensors with the tensor::k()'th dimension
+ equal to K. Applying softmax to these K values gives the probabilities of
+ each class. The index into that K dimensional vector with the highest
+ probability is the predicted class label.
+ !*/
+
+ public:
+
+ typedef unsigned long training_label_type;
+ typedef unsigned long output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted class for each classified object. The number
+ of possible output classes is sub.get_output().k().
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - all values pointed to by truth are < sub.get_output().k()
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multimulticlass_log_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements a collection of
+ multiclass classifiers. An example will make its use clear. So suppose,
+ for example, that you want to make something that takes a picture of a
+ vehicle and answers the following questions:
+ - What type of vehicle is it? A sedan or a truck?
+ - What color is it? red, green, blue, gray, or black?
+ You need two separate multi-class classifiers to do this. One to decide
+ the type of vehicle, and another to decide the color. The
+ loss_multimulticlass_log_ allows you to pack these two classifiers into one
+ neural network. This means that when you use the network to process an
+ image it will output 2 labels for each image, the type label and the color
+ label.
+
+ To create a loss_multimulticlass_log_ for the above case you would
+ construct it as follows:
+ std::map<std::string,std::vector<std::string>> labels;
+ labels["type"] = {"sedan", "truck"};
+ labels["color"] = {"red", "green", "blue", "gray", "black"};
+ loss_multimulticlass_log_ myloss(labels);
+ Then you could use myloss with a network object and train it to do this
+ task. More generally, you can use any number of classifiers and labels
+ when using this object. Finally, each of the classifiers uses a standard
+ multi-class logistic regression loss.
+ !*/
+
+ public:
+
+ loss_multimulticlass_log_(
+ );
+ /*!
+ ensures
+ - #number_of_labels() == 0
+ - #get_labels().size() == 0
+ !*/
+
+ loss_multimulticlass_log_ (
+ const std::map<std::string,std::vector<std::string>>& labels
+ );
+ /*!
+ requires
+ - Each vector in labels must contain at least 2 strings. I.e. each
+ classifier must have at least two possible labels.
+ ensures
+ - #number_of_labels() == the total number of strings in all the
+ std::vectors in labels.
+ - #number_of_classifiers() == labels.size()
+ - #get_labels() == labels
+ !*/
+
+ unsigned long number_of_labels(
+ ) const;
+ /*!
+ ensures
+ - returns the total number of labels known to this loss. This is the count of
+ all the labels in each classifier.
+ !*/
+
+ unsigned long number_of_classifiers(
+ ) const;
+ /*!
+ ensures
+ - returns the number of classifiers defined by this loss.
+ !*/
+
+ std::map<std::string,std::vector<std::string>> get_labels (
+ ) const;
+ /*!
+ ensures
+ - returns the names of the classifiers and labels used by this loss. In
+ particular, if the returned object is L then:
+ - L[CLASS] == the set of labels used by the classifier CLASS.
+ - L.size() == number_of_classifiers()
+ - The count of strings in the vectors in L == number_of_labels()
+ !*/
+
+ class classifier_output
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object stores the predictions from one of the classifiers in
+ loss_multimulticlass_log_. It allows you to find out the most likely
+ string label predicted by that classifier, as well as get the class
+ conditional probability of any of the classes in the classifier.
+ !*/
+
+ public:
+
+ classifier_output(
+ );
+ /*!
+ ensures
+ - #num_classes() == 0
+ !*/
+
+ size_t num_classes(
+ ) const;
+ /*!
+ ensures
+ - returns the number of possible classes output by this classifier.
+ !*/
+
+ double probability_of_class (
+ size_t i
+ ) const;
+ /*!
+ requires
+ - i < num_classes()
+ ensures
+ - returns the probability that the true class has a label of label(i).
+ - The sum of probability_of_class(j) for j in the range [0, num_classes()) is always 1.
+ !*/
+
+ const std::string& label(
+ size_t i
+ ) const;
+ /*!
+ requires
+ - i < num_classes()
+ ensures
+ - returns the string label for the ith class.
+ !*/
+
+ operator std::string(
+ ) const;
+ /*!
+ requires
+ - num_classes() != 0
+ ensures
+ - returns the string label for the most probable class.
+ !*/
+
+ friend std::ostream& operator<< (std::ostream& out, const classifier_output& item);
+ /*!
+ requires
+ - num_classes() != 0
+ ensures
+ - prints the most probable class label to out.
+ !*/
+
+ };
+
+ // Both training_label_type and output_label_type should always have sizes equal to
+ // number_of_classifiers(). That is, the std::map should have an entry for every
+ // classifier known to this loss.
+ typedef std::map<std::string,std::string> training_label_type;
+ typedef std::map<std::string,classifier_output> output_label_type;
+
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - number_of_labels() != 0
+ - sub.get_output().k() == number_of_labels()
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - number_of_labels() != 0
+ - sub.get_output().k() == number_of_labels()
+ It should be noted that the last layer in your network should usually
+ be an fc layer. If so, you can satisfy this requirement of k() being
+ number_of_labels() by calling set_num_outputs() prior to training your
+ network like so:
+ your_network.subnet().layer_details().set_num_outputs(your_network.loss_details().number_of_labels());
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - All the std::maps pointed to by truth contain entries for all the
+ classifiers known to this loss. That is, it must be valid to call
+ truth[i][classifier] for any of the classifiers known to this loss. To
+ say this another way, all the training samples must contain labels for
+ each of the classifiers defined by this loss.
+
+ To really belabor this, this also means that truth[i].size() ==
+ get_labels().size() and that both truth[i] and get_labels() have the same
+ set of key strings. It also means that the value strings in truth[i]
+ must be strings known to the loss, i.e. they are valid labels according
+ to get_labels().
+ !*/
+ };
+
+ template <typename SUBNET>
+ using loss_multimulticlass_log = add_loss_layer<loss_multimulticlass_log_, SUBNET>;
+
+ // Allow comparison between classifier_outputs and std::string to check if the
+ // predicted class is a particular string.
+ inline bool operator== (const std::string& lhs, const loss_multimulticlass_log_::classifier_output& rhs)
+ { return lhs == static_cast<const std::string&>(rhs); }
+ inline bool operator== (const loss_multimulticlass_log_::classifier_output& lhs, const std::string& rhs)
+ { return rhs == static_cast<const std::string&>(lhs); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum class use_image_pyramid : uint8_t
+ {
+ no,
+ yes
+ };
+
+ struct mmod_options
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object contains all the parameters that control the behavior of loss_mmod_.
+ !*/
+
+ public:
+
+ struct detector_window_details
+ {
+ detector_window_details() = default;
+ detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {}
+ detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {}
+
+ unsigned long width = 0;
+ unsigned long height = 0;
+ std::string label;
+
+ friend inline void serialize(const detector_window_details& item, std::ostream& out);
+ friend inline void deserialize(detector_window_details& item, std::istream& in);
+ };
+
+ mmod_options() = default;
+
+ // This kind of object detector is a sliding window detector. The detector_windows
+ // field determines how many sliding windows we will use and what the shape of each
+ // window is. It also determines the output label applied to each detection
+ // identified by each window. Since you will usually use the MMOD loss with an
+ // image pyramid, the detector sizes also determine the size of the smallest object
+ // you can detect.
+ std::vector<detector_window_details> detector_windows;
+
+ // These parameters control how we penalize different kinds of mistakes. See
+ // Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046)
+ // for further details.
+ double loss_per_false_alarm = 1;
+ double loss_per_missed_target = 1;
+
+ // A detection must have an intersection-over-union value greater than this for us
+ // to consider it a match against a ground truth box.
+ double truth_match_iou_threshold = 0.5;
+
+ // When doing non-max suppression, we use overlaps_nms to decide if a box overlaps
+ // an already output detection and should therefore be thrown out.
+ test_box_overlap overlaps_nms = test_box_overlap(0.4);
+
+ // Any mmod_rect in the training data that has its ignore field set to true defines
+ // an "ignore zone" in an image. Any detection from that area is totally ignored
+ // by the optimizer. Therefore, this overlaps_ignore field defines how we decide
+ // if a box falls into an ignore zone. You use these ignore zones if there are
+ // objects in your dataset that you are unsure if you want to detect or otherwise
+ // don't care if the detector gets them or not.
+ test_box_overlap overlaps_ignore;
+
+ // Usually the detector would be scale-invariant, and used with an image pyramid.
+ // However, sometimes scale-invariance may not be desired.
+ use_image_pyramid assume_image_pyramid = use_image_pyramid::yes;
+
+ mmod_options (
+ const std::vector<std::vector<mmod_rect>>& boxes,
+ const unsigned long target_size,
+ const unsigned long min_target_size,
+ const double min_detector_window_overlap_iou = 0.75
+ );
+ /*!
+ requires
+ - 0 < min_target_size <= target_size
+ - 0.5 < min_detector_window_overlap_iou < 1
+ ensures
+ - use_image_pyramid_ == use_image_pyramid::yes
+ - This function should be used when scale-invariance is desired, and
+ input_rgb_image_pyramid is therefore used as the input layer.
+ - This function tries to automatically set the MMOD options to reasonable
+ values, assuming you have a training dataset of boxes.size() images, where
+ the ith image contains objects boxes[i] you want to detect.
+ - The most important thing this function does is decide what detector
+ windows should be used. This is done by finding a set of detector
+ windows that are sized such that:
+ - When slid over an image pyramid, each box in boxes will have an
+ intersection-over-union with one of the detector windows of at least
+ min_detector_window_overlap_iou. That is, we will make sure that
+ each box in boxes could potentially be detected by one of the
+ detector windows. This essentially comes down to picking detector
+ windows with aspect ratios similar to the aspect ratios in boxes.
+ Note that we also make sure that each box can be detected by a window
+ with the same label. For example, if all the boxes had the same
+ aspect ratio but there were 4 different labels used in boxes then
+ there would be 4 resulting detector windows, one for each label.
+ - The longest edge of each detector window is target_size pixels in
+ length, unless the window's shortest side would be less than
+ min_target_size pixels in length. In this case the shortest side
+ will be set to min_target_size length, and the other side sized to
+ preserve the aspect ratio of the window.
+ This means that target_size and min_target_size control the size of the
+ detector windows, while the aspect ratios of the detector windows are
+ automatically determined by the contents of boxes. It should also be
+ emphasized that the detector isn't going to be able to detect objects
+ smaller than any of the detector windows. So consider that when setting
+ these sizes.
+ - This function will also set the overlaps_nms tester to the most
+ restrictive tester that doesn't reject anything in boxes.
+ !*/
+
+ mmod_options (
+ use_image_pyramid use_image_pyramid,
+ const std::vector<std::vector<mmod_rect>>& boxes,
+ const double min_detector_window_overlap_iou = 0.75
+ );
+ /*!
+ requires
+ - use_image_pyramid == use_image_pyramid::no
+ - 0.5 < min_detector_window_overlap_iou < 1
+ ensures
+ - This function should be used when scale-invariance is not desired, and
+ there is no intention to apply an image pyramid.
+ - This function tries to automatically set the MMOD options to reasonable
+ values, assuming you have a training dataset of boxes.size() images, where
+ the ith image contains objects boxes[i] you want to detect.
+ - The most important thing this function does is decide what detector
+ windows should be used. This is done by finding a set of detector
+ windows that are sized such that:
+ - When slid over an image, each box in boxes will have an
+ intersection-over-union with one of the detector windows of at least
+ min_detector_window_overlap_iou. That is, we will make sure that
+ each box in boxes could potentially be detected by one of the
+ detector windows.
+ - This function will also set the overlaps_nms tester to the most
+ restrictive tester that doesn't reject anything in boxes.
+ !*/
+ };
+
+ void serialize(const mmod_options& item, std::ostream& out);
+ void deserialize(mmod_options& item, std::istream& in);
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mmod_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the Max Margin Object
+ Detection loss defined in the paper:
+ Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046).
+
+ This means you use this loss if you want to detect the locations of objects
+ in images.
+
+ It should also be noted that this loss layer requires an input layer that
+ defines the following functions:
+ - image_contained_point()
+ - tensor_space_to_image_space()
+ - image_space_to_tensor_space()
+ A reference implementation of them and their definitions can be found in
+ the input_rgb_image_pyramid object, which is the recommended input layer to
+ be used with loss_mmod_.
+ !*/
+
+ public:
+
+ typedef std::vector<mmod_rect> training_label_type;
+ typedef std::vector<mmod_rect> output_label_type;
+
+ loss_mmod_(
+ );
+ /*!
+ ensures
+ - #get_options() == mmod_options()
+ !*/
+
+ loss_mmod_(
+ mmod_options options_
+ );
+ /*!
+ ensures
+ - #get_options() == options_
+ !*/
+
+ const mmod_options& get_options (
+ ) const;
+ /*!
+ ensures
+ - returns the options object that defines the general behavior of this loss layer.
+ !*/
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter,
+ double adjust_threshold = 0
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ Also, the output labels are std::vectors of mmod_rects where, for each mmod_rect R,
+ we have the following interpretations:
+ - R.rect == the location of an object in the image.
+ - R.detection_confidence the score for the object, the bigger the score the
+ more confident the detector is that an object is really there. Only
+ objects with a detection_confidence > adjust_threshold are output. So if
+ you want to output more objects (that are also of less confidence) you
+ can call to_label() with a smaller value of adjust_threshold.
+ - R.ignore == false (this value is unused by to_label()).
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ Also, the loss value returned is roughly equal to the average number of
+ mistakes made per image. This is the sum of false alarms and missed
+ detections, weighted by the loss weights for these types of mistakes specified
+ in the mmod_options.
+ !*/
+ };
+
+ template <typename SUBNET>
+ using loss_mmod = add_loss_layer<loss_mmod_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_metric_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it allows you to learn to map objects
+ into a vector space where objects sharing the same class label are close to
+ each other, while objects with different labels are far apart.
+
+ To be specific, it optimizes the following loss function which considers
+ all pairs of objects in a mini-batch and computes a different loss depending
+ on their respective class labels. So if objects A1 and A2 in a mini-batch
+ share the same class label then their contribution to the loss is:
+ max(0, length(A1-A2)-get_distance_threshold() + get_margin())
+
+ While if A1 and B1 have different class labels then their contribution to
+ the loss function is:
+ max(0, get_distance_threshold()-length(A1-B1) + get_margin())
+
+ Therefore, this loss layer optimizes a version of the hinge loss.
+ Moreover, the loss is trying to make sure that all objects with the same
+ label are within get_distance_threshold() distance of each other.
+ Conversely, if two objects have different labels then they should be more
+ than get_distance_threshold() distance from each other in the learned
+ embedding. So this loss function gives you a natural decision boundary for
+ deciding if two objects are from the same class.
+
+ Finally, the loss balances the number of negative pairs relative to the
+ number of positive pairs. Therefore, if there are N pairs that share the
+ same identity in a mini-batch then the algorithm will only include the N
+ worst non-matching pairs in the loss. That is, the algorithm performs hard
+ negative mining on the non-matching pairs. This is important since there
+ are in general way more non-matching pairs than matching pairs. So to
+ avoid imbalance in the loss this kind of hard negative mining is useful.
+ !*/
+ public:
+
+ typedef unsigned long training_label_type;
+ typedef matrix<float,0,1> output_label_type;
+
+ loss_metric_(
+ );
+ /*!
+ ensures
+ - #get_margin() == 0.04
+ - #get_distance_threshold() == 0.6
+ !*/
+
+ loss_metric_(
+ float margin,
+ float dist_thresh
+ );
+ /*!
+ requires
+ - margin > 0
+ - dist_thresh > 0
+ ensures
+ - #get_margin() == margin
+ - #get_distance_threshold() == dist_thresh
+ !*/
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ This loss expects the network to produce a single vector (per sample) as
+ output. This vector is the learned embedding. Therefore, to_label() just
+ copies these output vectors from the network into the output label_iterators
+ given to this function, one for each sample in the input_tensor.
+ !*/
+
+ float get_margin() const;
+ /*!
+ ensures
+ - returns the margin value used by the loss function. See the discussion
+ in WHAT THIS OBJECT REPRESENTS for details.
+ !*/
+
+ float get_distance_threshold() const;
+ /*!
+ ensures
+ - returns the distance threshold value used by the loss function. See the discussion
+ in WHAT THIS OBJECT REPRESENTS for details.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_metric = add_loss_layer<loss_metric_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_ranking_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the pairwise ranking
+ loss described in the paper:
+ Optimizing Search Engines using Clickthrough Data by Thorsten Joachims
+
+ This is the same loss function used by the dlib::svm_rank_trainer object.
+ Therefore, it is generally appropriate when you have a two class problem
+ and you want to learn a function that ranks one class before the other.
+
+ So for example, suppose you have two classes of data. Objects of type A
+ and objects of type B. Moreover, suppose that you want to sort the objects
+ so that A objects always come before B objects. This loss will help you
+ learn a function that assigns a real number to each object such that A
+ objects get a larger number assigned to them than B objects. This lets you
+ then sort the objects according to the output of the neural network and
+ obtain the desired result of having A objects come before B objects.
+
+ The training labels should be positive values for objects you want to get
+ high scores and negative for objects that should get small scores. So
+ relative to our A/B example, you would give A objects labels of +1 and B
+ objects labels of -1. This should cause the learned network to give A
+ objects large positive values and B objects negative values.
+
+
+ Finally, the specific loss function is:
+ For all pairs of positive vs negative training examples A_i and B_j respectively:
+ sum_ij: max(0, B_i - A_j + margin_ij)
+ where margin_ij = the label for A_j minus the label for B_i. If you
+ always use +1 and -1 labels then the margin is always 2. However, this
+ formulation allows you to give certain training samples different weight by
+ adjusting the training labels appropriately.
+ !*/
+
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted ranking score.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_ranking = add_loss_layer<loss_ranking_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_epsilon_insensitive_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the epsilon insensitive
+ loss, which is appropriate for regression problems. In particular, this
+ loss function is;
+ loss(y1,y2) = abs(y1-y2)<epsilon ? 0 : abs(y1-y2)-epsilon
+
+ Therefore, the loss is basically just the abs() loss except there is a dead
+ zone around zero, causing the loss to not care about mistakes of magnitude
+ smaller than epsilon.
+ !*/
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ loss_epsilon_insensitive_(
+ ) = default;
+ /*!
+ ensures
+ - #get_epsilon() == 1
+ !*/
+
+ loss_epsilon_insensitive_(
+ double eps
+ );
+ /*!
+ requires
+ - eps >= 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the epsilon value used in the loss function. Mistakes in the
+ regressor smaller than get_epsilon() are ignored by the loss function.
+ !*/
+
+ void set_epsilon(
+ double eps
+ );
+ /*!
+ requires
+ - eps >= 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted continuous variable.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_epsilon_insensitive = add_loss_layer<loss_epsilon_insensitive_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss, which is
+ appropriate for regression problems.
+ !*/
+ public:
+
+ typedef float training_label_type;
+ typedef float output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted continuous variable.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared = add_loss_layer<loss_mean_squared_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_multioutput_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss,
+ which is appropriate for regression problems. It is basically just like
+ loss_mean_squared_ except that it lets you define multiple outputs instead
+ of just 1.
+ !*/
+ public:
+
+ typedef matrix<float> training_label_type;
+ typedef matrix<float> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted continuous variable.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().nr() == 1
+ - sub.get_output().nc() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - (*(truth + idx)).nc() == 1 for all idx such that 0 <= idx < sub.get_output().num_samples()
+ - (*(truth + idx)).nr() == sub.get_output().k() for all idx such that 0 <= idx < sub.get_output().num_samples()
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared_multioutput = add_loss_layer<loss_mean_squared_multioutput_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_per_pixel_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic
+ regression loss (e.g. negative log-likelihood loss), which is appropriate
+ for multiclass classification problems. It is basically just like
+ loss_multiclass_log_ except that it lets you define matrix outputs instead
+ of scalar outputs. It should be useful, for example, in semantic
+ segmentation where we want to classify each pixel of an image.
+ !*/
+ public:
+
+ // In semantic segmentation, if you don't know the ground-truth of some pixel,
+ // set the label of that pixel to this value. When you do so, the pixel will be
+ // ignored when computing gradients.
+ static const uint16_t label_to_ignore = std::numeric_limits<uint16_t>::max();
+
+ // In semantic segmentation, 65535 classes ought to be enough for anybody.
+ typedef matrix<uint16_t> training_label_type;
+ typedef matrix<uint16_t> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted class for each classified element. The number
+ of possible output classes is sub.get_output().k().
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - all values pointed to by truth are < sub.get_output().k() or are equal to label_to_ignore.
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log_per_pixel = add_loss_layer<loss_multiclass_log_per_pixel_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_multiclass_log_per_pixel_weighted_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic
+ regression loss (e.g. negative log-likelihood loss), which is appropriate
+ for multiclass classification problems. It is basically just like
+ loss_multiclass_log_per_pixel_ except that it lets you define per-pixel
+ weights, which may be useful e.g. if you want to emphasize rare classes
+ while training. (If the classification problem is difficult, a flat weight
+ structure may lead the network to always predict the most common label, in
+ particular if the degree of imbalance is high. To emphasize a certain
+ class or classes, simply increase the weights of the corresponding pixels,
+ relative to the weights of the other pixels.)
+
+ Note that if you set the weight to 0 whenever a pixel's label is equal to
+ loss_multiclass_log_per_pixel_::label_to_ignore, and to 1 otherwise, then
+ you essentially get loss_multiclass_log_per_pixel_ as a special case.
+ !*/
+ public:
+
+ struct weighted_label
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the truth label of a single pixel, together with
+ an associated weight (the higher the weight, the more emphasis the
+ corresponding pixel is given during the training).
+ !*/
+
+ weighted_label();
+ weighted_label(uint16_t label, float weight = 1.f);
+
+ // The ground-truth label. In semantic segmentation, 65536 classes ought to be
+ // enough for anybody.
+ uint16_t label = 0;
+
+ // The weight of the corresponding pixel.
+ float weight = 1.f;
+ };
+
+ typedef matrix<weighted_label> training_label_type;
+ typedef matrix<uint16_t> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output label is the predicted class for each classified element. The number
+ of possible output classes is sub.get_output().k().
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - all labels pointed to by truth are < sub.get_output().k(), or the corresponding weight
+ is zero.
+ !*/
+
+ };
+
+ template <typename SUBNET>
+ using loss_multiclass_log_per_pixel_weighted = add_loss_layer<loss_multiclass_log_per_pixel_weighted_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_mean_squared_per_pixel_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss,
+ which is appropriate for regression problems. It is basically just like
+ loss_mean_squared_multioutput_ except that it lets you define matrix or
+ image outputs, instead of vector.
+ !*/
+ public:
+
+ typedef matrix<float> training_label_type;
+ typedef matrix<float> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output labels are the predicted continuous variables.
+ !*/
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().k() == 1
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - for all idx such that 0 <= idx < sub.get_output().num_samples():
+ - sub.get_output().nr() == (*(truth + idx)).nr()
+ - sub.get_output().nc() == (*(truth + idx)).nc()
+ !*/
+ };
+
+ template <typename SUBNET>
+ using loss_mean_squared_per_pixel = add_loss_layer<loss_mean_squared_per_pixel_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+ class loss_dot_
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the loss layer interface defined above by
+ EXAMPLE_LOSS_LAYER_. In particular, selecting this loss means you want
+ maximize the dot product between the output of a network and a set of
+ training vectors. The loss is therefore the negative dot product. To be
+ very specific, if X is the output vector of a network and Y is a training
+ label (also a vector), then the loss for this training sample is: -dot(X,Y)
+ !*/
+
+ public:
+
+ typedef matrix<float,0,1> training_label_type;
+ typedef matrix<float,0,1> output_label_type;
+
+ template <
+ typename SUB_TYPE,
+ typename label_iterator
+ >
+ void to_label (
+ const tensor& input_tensor,
+ const SUB_TYPE& sub,
+ label_iterator iter
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
+ it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ and the output labels are simply the final network outputs stuffed into a
+ vector. To be very specific, the output is the following for all valid i:
+ *(iter+i) == trans(rowm(mat(sub.get_output()),i))
+ !*/
+
+
+ template <
+ typename const_label_iterator,
+ typename SUBNET
+ >
+ double compute_loss_value_and_gradient (
+ const tensor& input_tensor,
+ const_label_iterator truth,
+ SUBNET& sub
+ ) const;
+ /*!
+ This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
+ except it has the additional calling requirements that:
+ - sub.get_output().num_samples() == input_tensor.num_samples()
+ - sub.sample_expansion_factor() == 1
+ - Let NETWORK_OUTPUT_DIMS == sub.get_output().size()/sub.get_output().num_samples()
+ - for all idx such that 0 <= idx < sub.get_output().num_samples():
+ - NETWORK_OUTPUT_DIMS == (*(truth + idx)).size()
+ !*/
+ };
+
+ template <typename SUBNET>
+ using loss_dot = add_loss_layer<loss_dot_, SUBNET>;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_LOSS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/solvers.h b/ml/dlib/dlib/dnn/solvers.h
new file mode 100644
index 000000000..204541a7e
--- /dev/null
+++ b/ml/dlib/dlib/dnn/solvers.h
@@ -0,0 +1,405 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_SOLVERS_H_
+#define DLIB_DNn_SOLVERS_H_
+
+#include "solvers_abstract.h"
+#include "tensor.h"
+#include <iostream>
+#include "layers.h"
+
+namespace dlib
+{
+ class sgd
+ {
+ public:
+
+ explicit sgd(
+ float weight_decay_,
+ float momentum_ = 0.9
+ )
+ {
+ weight_decay = weight_decay_;
+ momentum = momentum_;
+ }
+
+ sgd(
+ ) : sgd(0.0005, 0.9)
+ {
+ }
+
+ float get_momentum (
+ ) const { return momentum; }
+
+ float get_weight_decay (
+ ) const { return weight_decay; }
+
+ template <typename layer_type>
+ const tensor& operator() (
+ const float learning_rate,
+ const layer_type& l,
+ const tensor& params_grad
+ )
+ {
+ const tensor& params = l.get_layer_params();
+
+ DLIB_CASSERT(params.size() != 0);
+ if (v.size() == 0)
+ {
+ v.copy_size(params_grad);
+ v = 0;
+ }
+
+ const double lr = learning_rate*get_learning_rate_multiplier(l);
+ const double wd = weight_decay*get_weight_decay_multiplier(l);
+
+ //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);
+ tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);
+
+ return v;
+ }
+
+ template <unsigned long N>
+ const tensor& operator() (
+ const float learning_rate,
+ const fc_<N,FC_HAS_BIAS>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.get_num_outputs());
+ return v;
+ }
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y,
+ int _padding_x
+ >
+ const tensor& operator() (
+ const float learning_rate,
+ const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
+ return v;
+ }
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y,
+ int _padding_x
+ >
+ const tensor& operator() (
+ const float learning_rate,
+ const cont_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
+ return v;
+ }
+
+ template < layer_mode mode >
+ const tensor& operator() (
+ const float learning_rate,
+ const bn_<mode>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2);
+ return v;
+ }
+
+ friend void serialize(const sgd& item, std::ostream& out)
+ {
+ serialize("sgd2", out);
+ serialize(item.v, out);
+ serialize(item.weight_decay, out);
+ serialize(item.momentum, out);
+ }
+
+ friend void deserialize(sgd& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "sgd2")
+ throw serialization_error("Unexpected version found while deserializing dlib::sgd.");
+ deserialize(item.v, in);
+ deserialize(item.weight_decay, in);
+ deserialize(item.momentum, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const sgd& item)
+ {
+ out << "sgd: weight_decay="<<item.get_weight_decay() << ", momentum="<<item.get_momentum();
+ return out;
+ }
+
+ private:
+
+ template <typename layer_type>
+ void update_considering_bias(
+ const float learning_rate,
+ const layer_type& l,
+ const tensor& params_grad,
+ unsigned long bias_offset
+ )
+ {
+ const tensor& params = l.get_layer_params();
+
+ DLIB_CASSERT(params.size() != 0);
+ if (v.size() == 0)
+ {
+ v.copy_size(params_grad);
+ v = 0;
+ }
+
+ double lr = learning_rate*get_learning_rate_multiplier(l);
+ double wd = weight_decay*get_weight_decay_multiplier(l);
+
+ //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);
+
+ if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1)
+ {
+ tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);
+ }
+ else
+ {
+
+ tt::affine_transform_range(0, bias_offset, v, v, params, params_grad, momentum, -wd*lr, -lr);
+
+ // now update the biases but apply their multipliers
+ lr *= l.get_bias_learning_rate_multiplier();
+ wd *= l.get_bias_weight_decay_multiplier();
+ tt::affine_transform_range(bias_offset, v.size(), v, v, params, params_grad, momentum, -wd*lr, -lr);
+ }
+ }
+
+ resizable_tensor v;
+ float weight_decay;
+ float momentum;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class adam
+ {
+ public:
+
+ adam(
+ float weight_decay_,
+ float momentum1_,
+ float momentum2_
+ )
+ {
+ weight_decay = weight_decay_;
+ momentum1 = momentum1_;
+ momentum2 = momentum2_;
+ t = 0;
+ }
+
+ adam(
+ ) : adam(0.0005, 0.9, 0.999)
+ {}
+
+ float get_momentum1 (
+ ) const { return momentum1; }
+
+ float get_momentum2 (
+ ) const { return momentum2; }
+
+ float get_weight_decay (
+ ) const { return weight_decay; }
+
+ template <typename layer_type>
+ const tensor& operator() (
+ const float learning_rate,
+ const layer_type& l,
+ const tensor& params_grad
+ )
+ {
+ const tensor& params = l.get_layer_params();
+ DLIB_CASSERT(params.size() != 0);
+ if (v.size() == 0)
+ {
+ m.copy_size(params_grad);
+ m = 0;
+ v.copy_size(params_grad);
+ v = 0;
+ s.copy_size(params_grad);
+ }
+
+ ++t;
+
+
+ tt::compute_adam_update(0, params.size(), s, m, v, t,
+ learning_rate*get_learning_rate_multiplier(l),
+ weight_decay*get_weight_decay_multiplier(l),
+ momentum1, momentum2, params, params_grad);
+
+ return s;
+ }
+
+ template <unsigned long N>
+ const tensor& operator() (
+ const float learning_rate,
+ const fc_<N,FC_HAS_BIAS>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.get_num_outputs());
+ return s;
+ }
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y,
+ int _padding_x
+ >
+ const tensor& operator() (
+ const float learning_rate,
+ const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
+ return s;
+ }
+
+ template <
+ long _num_filters,
+ long _nr,
+ long _nc,
+ int _stride_y,
+ int _stride_x,
+ int _padding_y,
+ int _padding_x
+ >
+ const tensor& operator() (
+ const float learning_rate,
+ const cont_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
+ return s;
+ }
+
+ template < layer_mode mode >
+ const tensor& operator() (
+ const float learning_rate,
+ const bn_<mode>& l,
+ const tensor& params_grad
+ )
+ {
+ update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2);
+ return s;
+ }
+
+
+ friend void serialize(const adam& item, std::ostream& out)
+ {
+ serialize("adam2", out);
+ serialize(item.m, out);
+ serialize(item.v, out);
+ serialize(item.s, out);
+ serialize(item.weight_decay, out);
+ serialize(item.momentum1, out);
+ serialize(item.momentum2, out);
+ serialize(item.t, out);
+ }
+
+ friend void deserialize(adam& item, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != "adam2")
+ throw serialization_error("Unexpected version found while deserializing dlib::adam.");
+ deserialize(item.m, in);
+ deserialize(item.v, in);
+ deserialize(item.s, in);
+ deserialize(item.weight_decay, in);
+ deserialize(item.momentum1, in);
+ deserialize(item.momentum2, in);
+ deserialize(item.t, in);
+ }
+
+ friend std::ostream& operator<< (std::ostream& out, const adam& item)
+ {
+ out << "adam: weight_decay="<<item.get_weight_decay() << ", momentum1="<<item.get_momentum1() << ", momentum2="<<item.get_momentum2();
+ return out;
+ }
+
+ private:
+
+ template <typename layer_type>
+ void update_considering_bias(
+ const float learning_rate,
+ const layer_type& l,
+ const tensor& params_grad,
+ unsigned long bias_offset
+ )
+ {
+ const tensor& params = l.get_layer_params();
+ DLIB_CASSERT(params.size() != 0);
+ if (v.size() == 0)
+ {
+ m.copy_size(params_grad);
+ m = 0;
+ v.copy_size(params_grad);
+ v = 0;
+ s.copy_size(params_grad);
+ }
+
+
+ ++t;
+
+ if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1)
+ {
+ tt::compute_adam_update(0, params.size(), s, m, v, t,
+ learning_rate*get_learning_rate_multiplier(l),
+ weight_decay*get_weight_decay_multiplier(l),
+ momentum1, momentum2, params, params_grad);
+ }
+ else
+ {
+ tt::compute_adam_update(0, bias_offset, s, m, v, t,
+ learning_rate*get_learning_rate_multiplier(l),
+ weight_decay*get_weight_decay_multiplier(l),
+ momentum1, momentum2, params, params_grad);
+
+ tt::compute_adam_update(bias_offset, params.size(), s, m, v, t,
+ learning_rate*get_learning_rate_multiplier(l)*l.get_bias_learning_rate_multiplier(),
+ weight_decay*get_weight_decay_multiplier(l)*l.get_bias_weight_decay_multiplier(),
+ momentum1, momentum2, params, params_grad);
+ }
+ }
+ resizable_tensor m;
+ resizable_tensor v;
+ resizable_tensor s;
+ float weight_decay;
+ float momentum1;
+ float momentum2;
+ float t;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_SOLVERS_H_
+
diff --git a/ml/dlib/dlib/dnn/solvers_abstract.h b/ml/dlib/dlib/dnn/solvers_abstract.h
new file mode 100644
index 000000000..d10ef163a
--- /dev/null
+++ b/ml/dlib/dlib/dnn/solvers_abstract.h
@@ -0,0 +1,204 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_SOLVERS_ABSTRACT_H_
+#ifdef DLIB_DNn_SOLVERS_ABSTRACT_H_
+
+#include "tensor_abstract.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class EXAMPLE_SOLVER
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ A solver defines the parameter update rule for a single layer in a deep
+ neural network. It takes a parameter gradient vector and the layer's
+ parameters and tells you how the parameters should be updated.
+ Importantly, each solver instance is used with only one layer in a network.
+ This allows us to define solvers that have per layer state, for example, a
+ solver may keep a momentum term and apply it to its update rule.
+
+ Note that there is no dlib::EXAMPLE_SOLVER type. It is shown here purely
+ to document the interface a solver object must implement.
+ !*/
+
+ public:
+
+ EXAMPLE_SOLVER(
+ );
+
+ template <typename layer_type>
+ const tensor& operator() (
+ const float learning_rate,
+ const layer_type& l,
+ const tensor& params_grad
+ )
+ /*!
+ requires
+ - l.get_layer_params().size() != 0
+ - have_same_dimensions(l.get_layer_params(), params_grad) == true.
+ - When this function is invoked on a particular solver instance, it is
+ always supplied with the same layer instance, l. That is, the solver is
+ allowed to remember things from one invocation to another and to assume
+ that it is being serially applied to optimize the same layer's
+ parameters.
+ ensures
+ - Returns a step vector V that is intended to be used to update the
+ parameters by adding V to l.get_layer_params().
+ - This function will use the given "learning rate" to compute V. How the
+ learning rate is used is solver dependent. But in general the learning
+ rate should be used to select the step size, i.e. to somehow determine
+ the magnitude of V.
+ !*/
+ };
+
+ void serialize(const EXAMPLE_SOLVER& item, std::ostream& out);
+ void deserialize(EXAMPLE_SOLVER& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ std::ostream& operator<< (std::ostream& out, const EXAMPLE_SOLVER& item);
+ /*!
+ Prints the solver's name and parameters to out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class sgd
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the EXAMPLE_SOLVER interface defined above. It is a
+ basic stochastic gradient descent solver which uses momentum and weight
+ decay. In particular, it computes the update vector V according to:
+ V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad;
+ Here V is a momentum term that is remembered by the solver from one
+ invocation of operator() to the next.
+
+
+ Note that the actual learning rate and weight decay used by the solver are
+ multiplied by the per layer multipliers. That is, the solver will call
+ get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
+ multiply these values with the nominal learning rate and weight decay,
+ respectively, to determine the values it will use during each step. It is
+ also overloaded to allow additional learning rate multipliers to be applied
+ to fc_ and con_ bias parameters.
+ !*/
+ public:
+
+ sgd(
+ );
+ /*!
+ ensures
+ - #get_weight_decay() == 0.0005
+ - #get_momentum() == 0.9
+ !*/
+
+ explicit sgd(
+ float weight_decay,
+ float momentum = 0.9
+ );
+ /*!
+ requires
+ - weight_decay >= 0
+ - momentum >= 0
+ ensures
+ - #get_weight_decay() == weight_decay
+ - #get_momentum() == momentum
+ !*/
+
+ float get_weight_decay () const;
+ float get_momentum () const;
+ };
+
+ void serialize(const sgd& item, std::ostream& out);
+ void deserialize(sgd& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ std::ostream& operator<< (std::ostream& out, const sgd& item);
+ /*!
+ Prints the solver's name and parameters to out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class adam
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the EXAMPLE_SOLVER interface defined above. In
+ particular, it implements the ADAM parameter update method described in the
+ paper:
+ Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic
+ optimization." International Conference on Learning Representation. 2015.
+
+
+ Note that the actual learning rate and weight decay used by the solver are
+ multiplied by the per layer multipliers. That is, the solver will call
+ get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
+ multiply these values with the nominal learning rate and weight decay,
+ respectively, to determine the values it will use during each step. It is
+ also overloaded to allow additional learning rate multipliers to be applied
+ to fc_ and con_ bias parameters.
+ !*/
+
+ public:
+
+ adam(
+ );
+ /*!
+ ensures
+ - #get_weight_decay() == 0.0005
+ - #get_momentum1() == 0.9
+ - #get_momentum2() == 0.999
+ !*/
+
+ adam(
+ float weight_decay,
+ float momentum1,
+ float momentum2
+ );
+ /*!
+ requires
+ - weight_decay >= 0
+ - 0 <= momentum1 < 1
+ - 0 <= momentum2 < 1
+ ensures
+ - #get_weight_decay() == weight_decay
+ - #get_momentum1() == momentum1
+ - #get_momentum2() == momentum2
+ !*/
+
+ float get_weight_decay () const;
+ float get_momentum1 () const;
+ float get_momentum2 () const;
+ };
+
+ void serialize(const adam& item, std::ostream& out);
+ void deserialize(adam& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+ std::ostream& operator<< (std::ostream& out, const adam& item);
+ /*!
+ Prints the solver's name and parameters to out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_SOLVERS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/dnn/tensor.h b/ml/dlib/dlib/dnn/tensor.h
new file mode 100644
index 000000000..8039fe666
--- /dev/null
+++ b/ml/dlib/dlib/dnn/tensor.h
@@ -0,0 +1,686 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_TENSOR_H_
+#define DLIB_DNn_TENSOR_H_
+
+#include "tensor_abstract.h"
+#include <cstring>
+#include "../matrix.h"
+#include "cudnn_dlibapi.h"
+#include "gpu_data.h"
+#include "../byte_orderer.h"
+#include <memory>
+#include "../any.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class tensor;
+ namespace cuda
+ {
+ void set_tensor (
+ tensor& t,
+ float value
+ );
+
+ void scale_tensor (
+ tensor& t,
+ float value
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class tensor
+ {
+ public:
+
+ tensor (
+ ) :
+ m_n(0), m_k(0), m_nr(0), m_nc(0), m_size(0)
+ {
+ }
+
+ virtual ~tensor() {}
+
+ long long num_samples() const { return m_n; }
+ long long k() const { return m_k; }
+ long long nr() const { return m_nr; }
+ long long nc() const { return m_nc; }
+ size_t size() const { return m_size; }
+
+ typedef float* iterator;
+ typedef const float* const_iterator;
+ iterator begin() { return host(); }
+ const_iterator begin() const { return host(); }
+ iterator end() { return host()+size(); }
+ const_iterator end() const { return host()+size(); }
+
+ void async_copy_to_device() const
+ {
+ data().async_copy_to_device();
+ }
+
+ virtual const float* host() const = 0;
+ virtual float* host() = 0;
+ virtual float* host_write_only() = 0;
+ virtual const float* device() const = 0;
+ virtual float* device() = 0;
+ virtual float* device_write_only() = 0;
+
+ virtual const any& annotation() const = 0;
+ virtual any& annotation() = 0;
+
+ int device_id() const { return data().device_id(); }
+
+ tensor& operator= (float val)
+ {
+#ifdef DLIB_USE_CUDA
+ // If you are using CUDA then presumably you will be mostly using tensors on
+ // the GPU. So unless you seem to be actively working with the host side's
+ // data then we do this initialization on the device side since this avoids a
+ // host to device transfer that would likely immediately follow.
+ if (data().device_ready())
+ {
+ cuda::set_tensor(*this, val);
+ return *this;
+ }
+#endif
+ auto d = host_write_only();
+ for (size_t i = 0; i < size(); ++i)
+ d[i] = val;
+
+ return *this;
+ }
+
+ tensor& operator*= (float val)
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::scale_tensor(*this, val);
+ return *this;
+#else
+ for (auto& d : *this)
+ d *= val;
+
+ return *this;
+#endif
+ }
+
+ tensor& operator/= (float val)
+ {
+ *this *= 1.0/val;
+ return *this;
+ }
+
+ template <typename EXP>
+ tensor& operator= (const matrix_exp<EXP>& item)
+ {
+ DLIB_CASSERT(num_samples() == item.nr() &&
+ nr()*nc()*k() == item.nc());
+ static_assert((is_same_type<float, typename EXP::type>::value == true),
+ "To assign a matrix to a tensor the matrix must contain float values");
+
+ set_ptrm(host_write_only(), m_n, m_nr*m_nc*m_k) = item;
+ return *this;
+ }
+
+ template <typename EXP>
+ tensor& operator+= (const matrix_exp<EXP>& item)
+ {
+ DLIB_CASSERT(num_samples() == item.nr() &&
+ nr()*nc()*k() == item.nc());
+ static_assert((is_same_type<float, typename EXP::type>::value == true),
+ "To assign a matrix to a tensor the matrix must contain float values");
+ set_ptrm(host(), m_n, m_nr*m_nc*m_k) += item;
+ return *this;
+ }
+
+ template <typename EXP>
+ tensor& operator-= (const matrix_exp<EXP>& item)
+ {
+ DLIB_CASSERT(num_samples() == item.nr() &&
+ nr()*nc()*k() == item.nc());
+ static_assert((is_same_type<float, typename EXP::type>::value == true),
+ "To assign a matrix to a tensor the matrix must contain float values");
+ set_ptrm(host(), m_n, m_nr*m_nc*m_k) -= item;
+ return *this;
+ }
+
+ template <typename EXP>
+ void set_sample (
+ unsigned long long idx,
+ const matrix_exp<EXP>& item
+ )
+ {
+ DLIB_CASSERT(idx < (unsigned long long)num_samples());
+ DLIB_CASSERT(item.size() == nr()*nc()*k());
+ static_assert((is_same_type<float, typename EXP::type>::value == true),
+ "To assign a matrix to a tensor the matrix must contain float values");
+ set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item;
+ }
+
+
+ template <typename EXP>
+ void add_to_sample (
+ unsigned long long idx,
+ const matrix_exp<EXP>& item
+ )
+ {
+ DLIB_CASSERT(idx < (unsigned long long)num_samples());
+ DLIB_CASSERT(item.size() == nr()*nc()*k());
+ static_assert((is_same_type<float, typename EXP::type>::value == true),
+ "To assign a matrix to a tensor the matrix must contain float values");
+ set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item;
+ }
+
+
+#ifdef DLIB_USE_CUDA
+ virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
+ ) const = 0;
+#endif
+
+ friend void memcpy (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+ memcpy(dest.data(), dest.get_alias_offset(),
+ src.data(), src.get_alias_offset(),
+ src.size());
+ }
+
+
+ protected:
+
+ friend class alias_tensor;
+
+ virtual gpu_data& data() = 0;
+ virtual const gpu_data& data() const = 0;
+ virtual size_t get_alias_offset() const { return 0; } // needed by alias_tensor.
+
+ long long m_n;
+ long long m_k;
+ long long m_nr;
+ long long m_nc;
+ long long m_size; // always equal to m_n*m_k*m_nr*m_nc
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool is_vector (
+ const tensor& t
+ )
+ {
+ return t.size() == (size_t)t.num_samples() ||
+ t.size() == (size_t)t.k() ||
+ t.size() == (size_t)t.nr() ||
+ t.size() == (size_t)t.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix_op<op_pointer_to_mat<float> > mat (
+ const tensor& t,
+ long long nr,
+ long long nc
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0 ,
+ "\tconst matrix_exp mat(tensor, nr, nc)"
+ << "\n\t nr and nc must be >= 0"
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ );
+ DLIB_ASSERT(nr*nc == (long long)t.size() ,
+ "\tconst matrix_exp mat(tensor, nr, nc)"
+ << "\n\t The sizes don't match up."
+ << "\n\t nr*nc: " << nr*nc
+ << "\n\t t.size(): " << t.size()
+ );
+ typedef op_pointer_to_mat<float> op;
+ return matrix_op<op>(op(t.host(),nr,nc));
+ }
+
+ inline const matrix_op<op_pointer_to_mat<float> > mat (
+ const tensor& t
+ )
+ {
+ if (t.size() != 0)
+ return mat(t, t.num_samples(), t.size()/t.num_samples());
+ else
+ return mat((float*)0,0,0);
+ }
+
+ inline const matrix_op<op_pointer_to_mat<float> > image_plane (
+ const tensor& t,
+ long long sample = 0,
+ long long k = 0
+ )
+ {
+ DLIB_ASSERT(0 <= sample && sample < t.num_samples() &&
+ 0 <= k && k < t.k() &&
+ t.size() != 0,
+ "\tconst matrix_exp image_plane(tensor,sample,k)"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t sample: " << sample
+ << "\n\t k: " << k
+ << "\n\t t.num_samples(): " << t.num_samples()
+ << "\n\t t.k(): " << t.k()
+ << "\n\t t.size(): " << t.size()
+ );
+
+
+ typedef op_pointer_to_mat<float> op;
+ return matrix_op<op>(op(t.host() + ((sample*t.k() + k)*t.nr())*t.nc(),
+ t.nr(),
+ t.nc()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool have_same_dimensions (
+ const tensor& a,
+ const tensor& b
+ )
+ {
+ return a.num_samples() == b.num_samples() &&
+ a.k() == b.k() &&
+ a.nr() == b.nr() &&
+ a.nc() == b.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class resizable_tensor : public tensor
+ {
+ public:
+ resizable_tensor(
+ )
+ {}
+
+ template <typename EXP>
+ resizable_tensor(
+ const matrix_exp<EXP>& item
+ )
+ {
+ set_size(item.nr(), item.nc());
+ *this = item;
+ }
+
+ explicit resizable_tensor(
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ )
+ {
+ DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
+
+ set_size(n_,k_,nr_,nc_);
+ }
+
+ resizable_tensor(const resizable_tensor& item) : _annotation(item.annotation())
+ {
+ copy_size(item);
+ memcpy(*this, item);
+ }
+ resizable_tensor(const tensor& item) : _annotation(item.annotation())
+ {
+ copy_size(item);
+ memcpy(*this, item);
+ }
+
+ resizable_tensor(resizable_tensor&& item) { swap(item); }
+ resizable_tensor& operator=(resizable_tensor&& item) { swap(item); return *this; }
+
+ virtual const float* host() const { return data_instance.host(); }
+ virtual float* host() { return data_instance.host(); }
+ virtual float* host_write_only() { return data_instance.host_write_only(); }
+ virtual const float* device() const { return data_instance.device(); }
+ virtual float* device() { return data_instance.device(); }
+ virtual float* device_write_only() { return data_instance.device_write_only(); }
+
+ virtual const any& annotation() const { return _annotation; }
+ virtual any& annotation() { return _annotation; }
+
+ void clear(
+ )
+ {
+ set_size(0,0,0,0);
+ _annotation.clear();
+ // free underlying memory
+ data_instance.set_size(0);
+ }
+
+ void copy_size (
+ const tensor& item
+ )
+ {
+ set_size(item.num_samples(), item.k(), item.nr(), item.nc());
+ }
+
+ resizable_tensor& operator= (float val)
+ {
+ tensor::operator=(val);
+ return *this;
+ }
+
+ template <typename EXP>
+ resizable_tensor& operator= (
+ const matrix_exp<EXP>& item
+ )
+ {
+ if (!(num_samples() == item.nr() && k()*nr()*nc() == item.nc()))
+ set_size(item.nr(), item.nc());
+ tensor::operator=(item);
+ return *this;
+ }
+
+ void set_size(
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ )
+ {
+ DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
+
+ m_n = n_;
+ m_k = k_;
+ m_nr = nr_;
+ m_nc = nc_;
+ m_size = n_*k_*nr_*nc_;
+ if ((long long)data_instance.size() < m_size)
+ data_instance.set_size(m_size);
+#ifdef DLIB_USE_CUDA
+ cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc);
+#endif
+ }
+
+
+ resizable_tensor& operator= (const resizable_tensor& item)
+ {
+ resizable_tensor temp(item);
+ temp.swap(*this);
+ return *this;
+ }
+
+ resizable_tensor& operator= (const tensor& item)
+ {
+ resizable_tensor temp(item);
+ temp.swap(*this);
+ return *this;
+ }
+
+
+ void swap(resizable_tensor& item)
+ {
+ std::swap(m_n, item.m_n);
+ std::swap(m_k, item.m_k);
+ std::swap(m_nr, item.m_nr);
+ std::swap(m_nc, item.m_nc);
+ std::swap(m_size, item.m_size);
+ std::swap(data_instance, item.data_instance);
+ std::swap(_annotation, item._annotation);
+#ifdef DLIB_USE_CUDA
+ std::swap(cudnn_descriptor, item.cudnn_descriptor);
+#endif
+ }
+
+#ifdef DLIB_USE_CUDA
+ virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
+ ) const { return cudnn_descriptor; }
+#endif
+
+ private:
+
+#ifdef DLIB_USE_CUDA
+ cuda::tensor_descriptor cudnn_descriptor;
+#endif
+
+ gpu_data data_instance;
+ any _annotation;
+ virtual gpu_data& data() { return data_instance; }
+ virtual const gpu_data& data() const { return data_instance; }
+ };
+
+ inline void serialize(const tensor& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.num_samples(), out);
+ serialize(item.k(), out);
+ serialize(item.nr(), out);
+ serialize(item.nc(), out);
+ byte_orderer bo;
+ auto sbuf = out.rdbuf();
+ for (auto d : item)
+ {
+ // Write out our data as 4byte little endian IEEE floats rather than using
+ // dlib's default float serialization. We do this because it will result in
+ // more compact outputs. It's slightly less portable but it seems doubtful
+ // that any CUDA enabled platform isn't going to use IEEE floats. But if one
+ // does we can just update the serialization code here to handle it if such a
+ // platform is encountered.
+ bo.host_to_little(d);
+ static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
+ sbuf->sputn((char*)&d, sizeof(d));
+ }
+ }
+
+ inline void deserialize(resizable_tensor& item, std::istream& in)
+ {
+ int version;
+ deserialize(version, in);
+ if (version != 2)
+ throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");
+
+ long long num_samples=0, k=0, nr=0, nc=0;
+ deserialize(num_samples, in);
+ deserialize(k, in);
+ deserialize(nr, in);
+ deserialize(nc, in);
+ item.set_size(num_samples, k, nr, nc);
+ byte_orderer bo;
+ auto sbuf = in.rdbuf();
+ for (auto& d : item)
+ {
+ static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
+ if (sbuf->sgetn((char*)&d,sizeof(d)) != sizeof(d))
+ {
+ in.setstate(std::ios::badbit);
+ throw serialization_error("Error reading data while deserializing dlib::resizable_tensor.");
+ }
+ bo.little_to_host(d);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double dot(
+ const tensor& a,
+ const tensor& b
+ )
+ {
+ DLIB_CASSERT(a.size() == b.size());
+ const float* da = a.host();
+ const float* db = b.host();
+ double sum = 0;
+ for (size_t i = 0; i < a.size(); ++i)
+ sum += da[i]*db[i];
+ return sum;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class alias_tensor_instance : public tensor
+ {
+ alias_tensor_instance(
+ ) : data_instance(0), _annotation(0), data_offset(0) {}
+
+ public:
+ friend class alias_tensor;
+ friend class alias_tensor_const_instance;
+
+ alias_tensor_instance& operator= (float val)
+ {
+ tensor::operator=(val);
+ return *this;
+ }
+
+ template <typename EXP>
+ alias_tensor_instance& operator= (const matrix_exp<EXP>& item)
+ {
+ tensor::operator=(item);
+ return *this;
+ }
+
+ virtual const float* host() const { return data_instance->host()+data_offset; }
+ virtual float* host() { return data_instance->host()+data_offset; }
+ virtual float* host_write_only() { return data_instance->host()+data_offset; }
+ virtual const float* device() const { return data_instance->device()+data_offset; }
+ virtual float* device() { return data_instance->device()+data_offset; }
+ virtual float* device_write_only() { return data_instance->device()+data_offset; }
+
+ virtual const any& annotation() const { return *_annotation; }
+ virtual any& annotation() { return *_annotation; }
+
+#ifdef DLIB_USE_CUDA
+ virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
+ ) const { return *cudnn_descriptor; }
+#endif
+ private:
+
+ virtual size_t get_alias_offset() const { return data_offset; }
+
+#ifdef DLIB_USE_CUDA
+ std::shared_ptr<cuda::tensor_descriptor> cudnn_descriptor;
+#endif
+ gpu_data* data_instance;
+ any* _annotation;
+ size_t data_offset;
+ virtual gpu_data& data() { return *data_instance; }
+ virtual const gpu_data& data() const { return *data_instance; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class alias_tensor_const_instance
+ {
+ public:
+ const tensor& get() const { return inst; }
+ operator const tensor& () { return inst; }
+
+ alias_tensor_const_instance(const alias_tensor_instance& item) : inst(item) {}
+
+ private:
+ alias_tensor_instance inst;
+
+ friend class alias_tensor;
+ alias_tensor_const_instance() {}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class alias_tensor
+ {
+ public:
+
+ alias_tensor (
+ ) {}
+
+ alias_tensor (
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ )
+ {
+ DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0);
+
+ inst.m_n = n_;
+ inst.m_k = k_;
+ inst.m_nr = nr_;
+ inst.m_nc = nc_;
+ inst.m_size = n_*k_*nr_*nc_;
+ }
+
+ long long num_samples(
+ ) const { return inst.m_n; }
+
+ long long k(
+ ) const { return inst.m_k; }
+
+ long long nr(
+ ) const { return inst.m_nr; }
+
+ long long nc(
+ ) const { return inst.m_nc; }
+
+ size_t size(
+ ) const { return inst.m_size; }
+
+ alias_tensor_instance operator() (
+ tensor& t,
+ size_t offset = 0
+ ) const
+ {
+ DLIB_CASSERT(offset+size() <= t.size(),
+ "offset: "<<offset <<"\n"<<
+ "size(): "<<size() <<"\n"<<
+ "t.size(): "<<t.size() <<"\n");
+
+#ifdef DLIB_USE_CUDA
+ if (!inst.cudnn_descriptor)
+ {
+ inst.cudnn_descriptor = std::make_shared<cuda::tensor_descriptor>();
+ inst.cudnn_descriptor->set_size(inst.m_n, inst.m_k, inst.m_nr, inst.m_nc);
+ }
+#endif
+ inst.data_instance = &t.data();
+ inst._annotation = &t.annotation();
+ // Note that t might already be an aliasing tensor so we need to take that into
+ // account.
+ inst.data_offset = t.get_alias_offset()+offset;
+ return inst;
+ }
+
+ alias_tensor_const_instance operator() (
+ const tensor& t,
+ size_t offset = 0
+ ) const
+ {
+ alias_tensor_const_instance temp;
+ temp.inst = (*this)(const_cast<tensor&>(t),offset);
+ return temp;
+ }
+
+ private:
+ mutable alias_tensor_instance inst;
+ };
+
+ inline void serialize(const alias_tensor& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.num_samples(), out);
+ serialize(item.k(), out);
+ serialize(item.nr(), out);
+ serialize(item.nc(), out);
+ }
+
+ inline void deserialize(alias_tensor& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::alias_tensor.");
+ long long num_samples, k, nr, nc;
+ deserialize(num_samples, in);
+ deserialize(k, in);
+ deserialize(nr, in);
+ deserialize(nc, in);
+ item = alias_tensor(num_samples, k, nr, nc);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_TENSOR_H_
+
diff --git a/ml/dlib/dlib/dnn/tensor_abstract.h b/ml/dlib/dlib/dnn/tensor_abstract.h
new file mode 100644
index 000000000..73a9fff77
--- /dev/null
+++ b/ml/dlib/dlib/dnn/tensor_abstract.h
@@ -0,0 +1,727 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_TENSOR_ABSTRACT_H_
+#ifdef DLIB_DNn_TENSOR_ABSTRACT_H_
+
+#include "../matrix.h"
+#include "../any/any_abstract.h"
+
+namespace dlib
+{
+// ----------------------------------------------------------------------------------------
+
+ class tensor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a 4D array of float values, all stored contiguously
+ in memory. Importantly, it keeps two copies of the floats, one on the host
+ CPU side and another on the GPU device side. It automatically performs the
+ necessary host/device transfers to keep these two copies of the data in
+ sync.
+
+ All transfers to the device happen asynchronously with respect to the
+ default CUDA stream so that CUDA kernel computations can overlap with data
+ transfers. However, any transfers from the device to the host happen
+ synchronously in the default CUDA stream. Therefore, you should perform
+ all your CUDA kernel launches on the default stream so that transfers back
+ to the host do not happen before the relevant computations have completed.
+
+ If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all.
+ Instead, it will simply store one host side memory block of floats.
+
+ Finally, the convention in dlib code is to interpret the tensor as a set of
+ num_samples() 3D arrays, each of dimension k() by nr() by nc(). Also,
+ while this class does not specify a memory layout, the convention is to
+ assume that indexing into an element at coordinates (sample,k,r,c) can be
+ accomplished via:
+ host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c]
+
+ THREAD SAFETY
+ Instances of this object are not thread-safe. So don't touch one from
+ multiple threads at the same time.
+ !*/
+
+ public:
+
+ virtual ~tensor();
+
+ long long num_samples(
+ ) const;
+ /*!
+ ensures
+ - returns the number of 3D arrays of dimension k() by nr() by nc() there
+ are in this object.
+ !*/
+
+ long long k(
+ ) const;
+ /*!
+ ensures
+ - returns the k dimension of this tensor. Generally, we think of a tensor
+ as containing num_samples() images of nr() by nc() rows and columns, each
+ with k() channels.
+ !*/
+
+ long long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this tensor.
+ !*/
+
+ long long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this tensor.
+ !*/
+
+ size_t size(
+ ) const;
+ /*!
+ ensures
+ - returns num_samples()*k()*nr()*nc()
+ (i.e. the total number of floats in this tensor)
+ !*/
+
+ void async_copy_to_device(
+ ) const;
+ /*!
+ ensures
+ - This function does not block.
+ - if (the host version of the data is newer than the device's copy) then
+ - Begins asynchronously copying host data to the device.
+ - A call to device() that happens before the transfer completes will
+ block until the transfer is complete. That is, it is safe to call
+ async_copy_to_device() and then immediately call device().
+ !*/
+
+ typedef float* iterator;
+ typedef const float* const_iterator;
+ iterator begin() { return host(); }
+ const_iterator begin() const { return host(); }
+ iterator end() { return host()+size(); }
+ const_iterator end() const { return host()+size(); }
+ /*!
+ ensures
+ - makes a tensor iterable just like the STL containers.
+ !*/
+
+ virtual const float* host(
+ ) const = 0;
+ /*!
+ ensures
+ - returns a pointer to the host memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (the host's copy of the data is out of date) then
+ - copies the data from the device to the host, while this is happening
+ the call to host() blocks.
+ !*/
+
+ virtual float* host(
+ ) = 0;
+ /*!
+ ensures
+ - returns a pointer to the host memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (the host's copy of the data is out of date) then
+ - copies the data from the device to the host, while this is happening
+ the call to host() blocks.
+ - Marks the device side data as out of date so that the next call to
+ device() will perform a host to device transfer. If you want to begin
+ the transfer immediately then you can call async_copy_to_device() after
+ calling host().
+ !*/
+
+ virtual float* host_write_only(
+ ) = 0;
+ /*!
+ ensures
+ - This function returns the same pointer as host(), except that it never
+ performs a device to host memory copy. Instead, it immediately marks the
+ device side data as out of date, effectively discarding it. Therefore,
+ the values in the data pointed to by host_write_only() are undefined and
+ you should only call host_write_only() if you are going to assign to
+ every memory location in the returned memory block.
+ !*/
+
+ virtual const float* device(
+ ) const = 0;
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - returns a pointer to the device memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (the device's copy of the data is out of date) then
+ - copies the data from the host to the device, while this is happening
+ the call to device() blocks.
+ !*/
+
+ virtual float* device(
+ ) = 0;
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - returns a pointer to the device memory block of size() contiguous float
+ values or nullptr if size()==0.
+ - if (the device's copy of the data is out of date) then
+ - copies the data from the host to the device, while this is happening
+ the call to device() blocks.
+ - Marks the host side data as out of date so that the next call to
+ host() will perform a device to host transfer.
+ !*/
+
+ virtual float* device_write_only(
+ ) = 0;
+ /*!
+ requires
+ - DLIB_USE_CUDA is #defined
+ ensures
+ - This function returns the same pointer as device(), except that it never
+ performs a host to device memory copy. Instead, it immediately marks the
+ host side data as out of date, effectively discarding it. Therefore, the
+ values in the data pointed to by device_write_only() are undefined and
+ you should only call device_write_only() if you are going to assign to
+ every memory location in the returned memory block.
+ !*/
+
+ virtual const any& annotation(
+ ) const = 0;
+ /*!
+ ensures
+ - returns a const reference to the any object in this tensor. The any
+ object can be used to store any additional annotation you like in a
+ tensor. However, it should be noted that the annotation() is ignored by
+ serialize() and therefore not saved when a tensor is serialized.
+ !*/
+
+ virtual any& annotation(
+ ) = 0;
+ /*!
+ ensures
+ - returns a non-const reference to the any object in this tensor. The any
+ object can be used to store any additional annotation you like in a
+ tensor. However, it should be noted that the annotation() is ignored by
+ serialize() and therefore not saved when a tensor is serialized.
+ !*/
+
+ int device_id(
+ ) const;
+ /*!
+ ensures
+ - returns the ID of the CUDA device that allocated this memory. I.e. the
+ number returned by cudaGetDevice() when the memory was allocated.
+ - If CUDA is not being used then this function always returns 0.
+ !*/
+
+ tensor& operator= (
+ float val
+ );
+ /*!
+ ensures
+ - sets all elements of this tensor equal to val.
+ - returns *this
+ !*/
+
+ tensor& operator*= (
+ float val
+ );
+ /*!
+ ensures
+ - pointwise multiplies all elements of *this tensor with val.
+ - returns *this
+ !*/
+
+ tensor& operator/= (
+ float val
+ );
+ /*!
+ ensures
+ - pointwise divides all elements of *this tensor with val.
+ - returns *this
+ !*/
+
+ template <typename EXP>
+ tensor& operator= (
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - num_samples() == item.nr()
+ - k()*nr()*nc() == item.nc()
+ - item contains float values
+ ensures
+ - Assigns item to *this tensor by performing:
+ set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
+ !*/
+
+ template <typename EXP>
+ tensor& operator+= (
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - num_samples() == item.nr()
+ - k()*nr()*nc() == item.nc()
+ - item contains float values
+ ensures
+ - Adds item to *this tensor by performing:
+ set_ptrm(host(), num_samples(), k()*nr()*nc()) += item;
+ !*/
+
+ template <typename EXP>
+ tensor& operator-= (
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - num_samples() == item.nr()
+ - k()*nr()*nc() == item.nc()
+ - item contains float values
+ ensures
+ - Subtracts item from *this tensor by performing:
+ set_ptrm(host(), num_samples(), k()*nr()*nc()) -= item;
+ !*/
+
+ template <typename EXP>
+ void set_sample (
+ unsigned long long idx,
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - idx < num_samples()
+ - k()*nr()*nc() == item.size()
+ - item contains float values
+ ensures
+ - Assigns item to the idx'th sample in *this by performing:
+ set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item;
+ !*/
+
+
+ template <typename EXP>
+ void add_to_sample (
+ unsigned long long idx,
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - idx < num_samples()
+ - k()*nr()*nc() == item.size()
+ - item contains float values
+ ensures
+ - Adds item to the idx'th sample in *this by performing:
+ set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item;
+ !*/
+
+ protected:
+
+ // You can't move or copy another tensor into *this since that might modify the
+ // tensor's dimensions. If you want to do that sort of thing then use a
+ // resizable_tensor.
+ tensor(const tensor& item);
+ tensor& operator= (const tensor& item);
+ tensor(tensor&& item);
+ tensor& operator=(tensor&& item);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void memcpy (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - dest.size() == src.size()
+ ensures
+ - Copies the data in src to dest. If the device data is current on both src
+ and dest then the copy will happen entirely on the device side.
+ - It doesn't matter what GPU device is selected by cudaSetDevice(). You can
+ always copy tensor objects to and from each other regardless.
+ - This function blocks until the copy has completed.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_vector (
+ const tensor& t
+ );
+ /*!
+ ensures
+ - returns true if and only if one of the following is true:
+ - t.size() == t.num_samples()
+ - t.size() == t.k()
+ - t.size() == t.nr()
+ - t.size() == t.nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp mat (
+ const tensor& t,
+ long long nr,
+ long long nc
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ - nr*nc == t.size()
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == nr
+ - m.nc() == nc
+ - for all valid r and c:
+ M(r,c) == t.host()[r*nc + c]
+ (i.e. the tensor is interpreted as a matrix laid out in memory
+ in row major order)
+ !*/
+
+ const matrix_exp mat (
+ const tensor& t
+ );
+ /*!
+ ensures
+ - if (t.size() != 0) then
+ - returns mat(t, t.num_samples(), t.size()/t.num_samples())
+ - else
+ - returns an empty matrix.
+ !*/
+
+ const matrix_exp image_plane (
+ const tensor& t,
+ long long sample = 0,
+ long long k = 0
+ );
+ /*!
+ requires
+ - t.size() != 0
+ - 0 <= sample < t.num_samples()
+ - 0 <= k < t.k()
+ ensures
+ - returns the k-th image plane from the sample-th image in t. That is,
+ returns a matrix M such that:
+ - M contains float valued elements.
+ - M.nr() == t.nr()
+ - M.nc() == t.nc()
+ - for all valid r and c:
+ - M(r,c) == t.host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool have_same_dimensions (
+ const tensor& a,
+ const tensor& b
+ );
+ /*!
+ ensures
+ - returns true if and only if all of the fallowing are satisfied:
+ - a.num_samples() == b.num_samples()
+ - a.k() == b.k()
+ - a.nr() == b.nr()
+ - a.nc() == b.nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class resizable_tensor : public tensor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is just a tensor with the additional ability to be resized.
+ !*/
+
+ public:
+ resizable_tensor(
+ );
+ /*!
+ ensures
+ - #size() == 0
+ - #num_samples() == 0
+ - #k() == 0
+ - #nr() == 0
+ - #nc() == 0
+ - #capacity() == 0
+ !*/
+
+ template <typename EXP>
+ resizable_tensor(
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - item contains float values
+ ensures
+ - #num_samples() == item.nr()
+ - #k() == item.nc()
+ - #nr() == 1
+ - #nc() == 1
+ - Assigns item to *this tensor by performing:
+ set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
+ - #capacity() == size()
+ !*/
+
+ explicit resizable_tensor(
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ );
+ /*!
+ requires
+ - n_ >= 0
+ - k_ >= 0
+ - nr_ >= 0
+ - nc_ >= 0
+ ensures
+ - #size() == n_*k_*nr_*nc_
+ - #num_samples() == n_
+ - #k() == k_
+ - #nr() == nr_
+ - #nc() == nc_
+ - #capacity() == size()
+ !*/
+
+ // This object is copyable and movable
+ resizable_tensor(const resizable_tensor&) = default;
+ resizable_tensor(resizable_tensor&&) = default;
+ resizable_tensor& operator= (const resizable_tensor&) = default;
+ resizable_tensor& operator= (resizable_tensor&&) = default;
+
+ size_t capacity (
+ ) const;
+ /*!
+ ensures
+ - returns the total number of floats allocated. This might be different
+ from the size() since calls to set_size() that make a tensor smaller
+ don't trigger reallocations. They simply adjust the nominal dimensions
+ while keeping the same allocated memory block. This makes calls to
+ set_size() very fast. If you need to deallocate a tensor then use
+ clear().
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #size() == 0
+ - #num_samples() == 0
+ - #k() == 0
+ - #nr() == 0
+ - #nc() == 0
+ - #annotation().is_empty() == true
+ - #capacity() == 0
+ !*/
+
+ void copy_size (
+ const tensor& item
+ );
+ /*!
+ ensures
+ - resizes *this so that: have_same_dimensions(#*this, item)==true
+ !*/
+
+ void set_size(
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ );
+ /*!
+ requires
+ - n_ >= 0
+ - k_ >= 0
+ - nr_ >= 0
+ - nc_ >= 0
+ ensures
+ - #size() == n_*k_*nr_*nc_
+ - #num_samples() == n_
+ - #k() == k_
+ - #nr() == nr_
+ - #nc() == nc_
+ - #capacity() == max(#size(), capacity())
+ (i.e. capacity() never goes down when calling set_size().)
+ !*/
+
+ template <typename EXP>
+ resizable_tensor& operator= (
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ requires
+ - item contains float values
+ ensures
+ - if (num_samples() == item.nr() && k()*nr()*nc() == item.nc()) then
+ - the dimensions of this tensor are not changed
+ - else
+ - #num_samples() == item.nr()
+ - #k() == item.nc()
+ - #nr() == 1
+ - #nc() == 1
+ - Assigns item to *this tensor by performing:
+ set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
+ !*/
+ };
+
+ void serialize(const tensor& item, std::ostream& out);
+ void deserialize(resizable_tensor& item, std::istream& in);
+ /*!
+ provides serialization support for tensor and resizable_tensor. Note that you can
+ serialize to/from any combination of tenor and resizable_tensor objects.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double dot(
+ const tensor& a,
+ const tensor& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ ensures
+ - returns the dot product between a and b when they are both treated as
+ a.size() dimensional vectors. That is, this function pointwise multiplies
+ the vectors together, then sums the result and returns it.
+
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class alias_tensor_instance : public tensor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tensor that aliases another tensor. That is, it doesn't
+ have its own block of memory but instead simply holds pointers to the
+ memory of another tensor object. It therefore allows you to efficiently
+ break a tensor into pieces and pass those pieces into functions.
+
+ An alias_tensor_instance doesn't own the resources it points to in any sense.
+ So it is important to make sure that the underlying owning tensor doesn't get
+ destructed before any alias tensors which point to it are destructed.
+ !*/
+
+ // You can't default initialize this object. You can only get instances of it from
+ // alias_tensor::operator().
+ alias_tensor_instance(
+ );
+ };
+
+ class alias_tensor_const_instance
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is essentially a const version of alias_tensor_instance and therefore
+ represents a tensor. However, due to the mechanics of C++, this object
+ can't inherit from tensor. So instead it provides a get() and an implicit
+ conversion to const tensor.
+ !*/
+
+ public:
+
+ // non-const alias tensors are convertible to const ones.
+ alias_tensor_const_instance(const alias_tensor_instance& item);
+
+ // Methods that cast the alias to a tensor.
+ const tensor& get() const;
+ operator const tensor& ();
+
+ private:
+ // You can't default initialize this object. You can only get instances of it from
+ // alias_tensor::operator().
+ alias_tensor_const_instance();
+ };
+
+ class alias_tensor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for creating tensor objects that alias other tensor objects.
+ That is, it allows you to make a tensor that references the memory space of
+ another tensor object rather than owning its own memory. This allows you
+ to do things like interpret a single tensor in different ways or even as a
+ group of multiple tensors.
+ !*/
+ public:
+
+ alias_tensor (
+ );
+ /*!
+ ensures
+ - #size() == 0
+ - #num_samples() == 0
+ - #k() == 0
+ - #nr() == 0
+ - #nc() == 0
+ !*/
+
+ alias_tensor (
+ long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1
+ );
+ /*!
+ requires
+ - n_ >= 0
+ - k_ >= 0
+ - nr_ >= 0
+ - nc_ >= 0
+ ensures
+ - #size() == n_*k_*nr_*nc_
+ - #num_samples() == n_
+ - #k() == k_
+ - #nr() == nr_
+ - #nc() == nc_
+ !*/
+
+ long long num_samples() const;
+ long long k() const;
+ long long nr() const;
+ long long nc() const;
+ size_t size() const;
+
+ alias_tensor_instance operator() (
+ tensor& t,
+ size_t offset = 0
+ ) const;
+ /*!
+ requires
+ - offset+size() <= t.size()
+ ensures
+ - Returns a tensor that simply aliases the elements of t beginning with t's
+ offset'th element. Specifically, this function returns an aliasing
+ tensor T such that:
+ - T.size() == size()
+ - T.num_samples() == num_samples()
+ - T.k() == k()
+ - T.nr() == nr()
+ - T.nc() == nc()
+ - T.host() == t.host()+offset
+ - T.device() == t.device()+offset
+ - &T.annotation() == &t.annotation()
+ !*/
+
+ alias_tensor_const_instance operator() (
+ const tensor& t,
+ size_t offset = 0
+ ) const;
+ /*!
+ requires
+ - offset+size() <= t.size()
+ ensures
+ - This function is identical to the above version of operator() except that
+ it takes and returns const tensors instead of non-const tensors.
+ !*/
+ };
+
+ void serialize(const alias_tensor& item, std::ostream& out);
+ void deserialize(alias_tensor& item, std::istream& in);
+ /*!
+ provides serialization support for alias_tensor.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_TENSOR_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/dnn/tensor_tools.cpp b/ml/dlib/dlib/dnn/tensor_tools.cpp
new file mode 100644
index 000000000..c0f7fd69d
--- /dev/null
+++ b/ml/dlib/dlib/dnn/tensor_tools.cpp
@@ -0,0 +1,985 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TeNSOR_TOOLS_CPP_
+#define DLIB_TeNSOR_TOOLS_CPP_
+
+#include "tensor_tools.h"
+#include "../string.h"
+#include <atomic>
+
+namespace dlib
+{
+ namespace
+ {
+ std::atomic<bool>& dnn_prefer_fastest_algo (
+ )
+ {
+ static std::atomic<bool> var(true);
+ return var;
+ }
+ }
+
+ bool dnn_prefer_fastest_algorithms (
+ )
+ {
+ return dnn_prefer_fastest_algo();
+ }
+
+ void set_dnn_prefer_fastest_algorithms(
+ )
+ {
+ dnn_prefer_fastest_algo() = true;
+ }
+
+ void set_dnn_prefer_smallest_algorithms(
+ )
+ {
+ dnn_prefer_fastest_algo() = false;
+ }
+}
+
+namespace dlib { namespace tt
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void inverse_norms (
+ resizable_tensor& invnorms,
+ const tensor& data,
+ const double eps
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::inverse_norms(invnorms, data, eps);
+#else
+ invnorms = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps));
+#endif
+ }
+
+ void dot_prods (
+ resizable_tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::dot_prods(out, lhs, rhs);
+#else
+ out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+#endif
+ }
+
+ void dot_prods (
+ bool add_to,
+ tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::dot_prods(add_to, out, lhs, rhs);
+#else
+ if (add_to)
+ out += sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+ else
+ out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+#endif
+ }
+
+ void scale_columns (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(out,m));
+ DLIB_CASSERT(is_vector(v));
+ if (m.size() == 0 && v.size() == 0)
+ return;
+ DLIB_CASSERT(m.size() != 0);
+ DLIB_CASSERT(m.size()/m.num_samples() == v.size());
+
+#ifdef DLIB_USE_CUDA
+ cuda::scale_columns(out, m, v);
+#else
+ DLIB_CASSERT(false, "shouldn't be called right now");
+ out = scale_columns(mat(m), mat(v));
+#endif
+ }
+
+ void scale_rows (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(out,m));
+ DLIB_CASSERT(is_vector(v));
+ if (m.size() == 0 && v.size() == 0)
+ return;
+ DLIB_CASSERT(m.size() != 0);
+ DLIB_CASSERT(m.num_samples() == v.size());
+
+#ifdef DLIB_USE_CUDA
+ cuda::scale_rows(out, m, v);
+#else
+ out = scale_rows(mat(m), mat(v));
+#endif
+ }
+
+ void scale_rows2 (
+ float beta,
+ tensor& out,
+ const tensor& m1,
+ const tensor& m2,
+ const tensor& v1,
+ const tensor& v2
+ )
+ {
+ DLIB_CASSERT(have_same_dimensions(out,m1));
+ DLIB_CASSERT(have_same_dimensions(out,m2));
+ DLIB_CASSERT(have_same_dimensions(v1,v2));
+ DLIB_CASSERT(is_vector(mat(v1)));
+ DLIB_CASSERT(v1.size() == m1.num_samples());
+
+#ifdef DLIB_USE_CUDA
+ cuda::scale_rows2(beta, out, m1, m2, v1, v2);
+#else
+ if (beta == 0)
+ out = scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
+ else
+ out = beta*mat(out) + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void exp (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+
+#ifdef DLIB_USE_CUDA
+ cuda::exp(dest,src);
+#else
+ dest = exp(mat(src));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void log (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+
+#ifdef DLIB_USE_CUDA
+ cuda::log(dest,src);
+#else
+ dest = log(mat(src));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void log10 (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+ DLIB_CASSERT(dest.size() == src.size());
+
+#ifdef DLIB_USE_CUDA
+ cuda::log10(dest,src);
+#else
+ dest = log10(mat(src));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void gemm (
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& lhs,
+ bool trans_lhs,
+ const tensor& rhs,
+ bool trans_rhs
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::gemm(beta, dest, alpha, lhs, trans_lhs, rhs, trans_rhs);
+#else
+ if (beta != 0)
+ {
+ if (trans_lhs && trans_rhs)
+ dest = alpha*trans(mat(lhs))*trans(mat(rhs)) + beta*mat(dest);
+ else if (!trans_lhs && trans_rhs)
+ dest = alpha*mat(lhs)*trans(mat(rhs)) + beta*mat(dest);
+ else if (trans_lhs && !trans_rhs)
+ dest = alpha*trans(mat(lhs))*mat(rhs) + beta*mat(dest);
+ else
+ dest = alpha*mat(lhs)*mat(rhs) + beta*mat(dest);
+ }
+ else
+ {
+ if (trans_lhs && trans_rhs)
+ dest = alpha*trans(mat(lhs))*trans(mat(rhs));
+ else if (!trans_lhs && trans_rhs)
+ dest = alpha*mat(lhs)*trans(mat(rhs));
+ else if (trans_lhs && !trans_rhs)
+ dest = alpha*trans(mat(lhs))*mat(rhs);
+ else
+ dest = alpha*mat(lhs)*mat(rhs);
+ }
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ tensor_rand::
+ tensor_rand(
+ unsigned long long seed
+ )
+#ifdef DLIB_USE_CUDA
+ :rnd(seed){}
+#else
+ {rnd.set_seed(cast_to_string(seed)); }
+#endif
+
+ void tensor_rand::
+ fill_gaussian (
+ tensor& data,
+ float mean,
+ float stddev
+ )
+ {
+ DLIB_CASSERT(data.size()%2 == 0);
+#ifdef DLIB_USE_CUDA
+ rnd.fill_gaussian(data, mean, stddev);
+#else
+ for (auto& x : data)
+ x = rnd.get_random_gaussian()*stddev + mean;
+#endif
+ }
+
+ void tensor_rand::
+ fill_uniform (
+ tensor& data
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ rnd.fill_uniform(data);
+#else
+ for (auto& x : data)
+ x = rnd.get_random_float();
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+ DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() &&
+ dest.nr() == src1.nr() && src1.nr() == src2.nr() &&
+ dest.nc() == src1.nc() && src1.nc() == src2.nc() );
+ const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples());
+ DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) &&
+ (src1.num_samples()==1 || src1.num_samples()==MD) &&
+ (src2.num_samples()==1 || src2.num_samples()==MD) );
+#ifdef DLIB_USE_CUDA
+ cuda::multiply(add_to, dest, src1, src2);
+#else
+ cpu::multiply(add_to, dest, src1, src2);
+#endif
+
+ }
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::scale_channels(add_to, dest, src, scales);
+#else
+ cpu::scale_channels(add_to, dest, src, scales);
+#endif
+ }
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::multiply_conv(add_to, dest, src1, src2);
+#else
+ cpu::multiply_conv(add_to, dest, src1, src2);
+#endif
+ }
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::multiply_zero_padded(add_to, dest, src1, src2);
+#else
+ cpu::multiply_zero_padded(add_to, dest, src1, src2);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src,A,B);
+#else
+ cpu::affine_transform(dest,src,A,B);
+#endif
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src,A);
+#else
+ cpu::affine_transform(dest,src,A,0);
+#endif
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src1,src2,A,B,C);
+#else
+ cpu::affine_transform(dest,src1,src2,A,B,C);
+#endif
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src1,src2,A,B);
+#else
+ cpu::affine_transform(dest,src1,src2,A,B,0);
+#endif
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src1,src2,src3,A,B,C,D);
+#else
+ cpu::affine_transform(dest,src1,src2,src3,A,B,C,D);
+#endif
+ }
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C);
+#else
+ cpu::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C);
+#endif
+ }
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(rect, dest,src1,src2,src3,A,B,C);
+#else
+ cpu::affine_transform(rect, dest,src1,src2,src3,A,B,C);
+#endif
+ }
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C);
+#else
+ cpu::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform(dest,src,A,B);
+#else
+ cpu::affine_transform(dest,src,A,B);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::affine_transform_conv(dest,src,A,B);
+#else
+ cpu::affine_transform_conv(dest,src,A,B);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1,
+ momentum2, params, params_grad);
+#else
+ cpu::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1,
+ momentum2, params, params_grad);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances);
+#else
+ cpu::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances);
+#endif
+ }
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& vars,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta);
+#else
+ cpu::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta);
+#endif
+ }
+
+ void batch_normalize_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+#else
+ cpu::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances);
+#else
+ cpu::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances);
+#endif
+ }
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& vars,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta);
+#else
+ cpu::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta);
+#endif
+ }
+
+ void batch_normalize_conv_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ )
+ {
+
+#ifdef DLIB_USE_CUDA
+ cuda::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+#else
+ cpu::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threshold (
+ tensor& data,
+ float thresh
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::threshold(data,thresh);
+#else
+ cpu::threshold(data,thresh);
+#endif
+ }
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::dot(a,b,result,idx);
+#else
+ cpu::dot(a,b,result,idx);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::add(beta,dest,alpha,src);
+#else
+ cpu::add(beta,dest,alpha,src);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::add(dest, src1, src2);
+#else
+ cpu::add(dest, src1, src2);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::assign_conv_bias_gradient(grad,gradient_input);
+#else
+ cpu::assign_conv_bias_gradient(grad,gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::assign_bias_gradient(grad,gradient_input);
+#else
+ cpu::assign_bias_gradient(grad,gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::softmax(dest,src);
+#else
+ cpu::softmax(dest,src);
+#endif
+ }
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::softmax_gradient(grad, dest, gradient_input);
+#else
+ cpu::softmax_gradient(grad, dest, gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::softmax_all(dest,src);
+#else
+ cpu::softmax_all(dest,src);
+#endif
+ }
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::softmax_all_gradient(grad, dest, gradient_input);
+#else
+ cpu::softmax_all_gradient(grad, dest, gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::sigmoid(dest,src);
+#else
+ cpu::sigmoid(dest,src);
+#endif
+ }
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::sigmoid_gradient(grad, dest, gradient_input);
+#else
+ cpu::sigmoid_gradient(grad, dest, gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::relu(dest,src);
+#else
+ cpu::relu(dest,src);
+#endif
+ }
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::relu_gradient(grad, dest, gradient_input);
+#else
+ cpu::relu_gradient(grad, dest, gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::prelu(dest, src, param);
+#else
+ cpu::prelu(dest, src, param);
+#endif
+ }
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::prelu_gradient(grad, src, gradient_input, param, params_grad);
+#else
+ cpu::prelu_gradient(grad, src, gradient_input, param, params_grad);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::tanh(dest,src);
+#else
+ cpu::tanh(dest,src);
+#endif
+ }
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::tanh_gradient(grad, dest, gradient_input);
+#else
+ cpu::tanh_gradient(grad, dest, gradient_input);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride);
+#else
+ cpu::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride);
+#endif
+ }
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride);
+#else
+ cpu::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride);
+#endif
+ }
+
+// ------------------------------------------------------------------------------------
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ cuda::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k);
+#else
+ cpu::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void inv::
+ operator() (
+ const tensor& m,
+ resizable_tensor& out
+ )
+ {
+#ifdef DLIB_USE_CUDA
+ finv(m,out);
+#else
+ out = dlib::inv(mat(m));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}}
+
+#endif // DLIB_TeNSOR_TOOLS_CPP_
+
diff --git a/ml/dlib/dlib/dnn/tensor_tools.h b/ml/dlib/dlib/dnn/tensor_tools.h
new file mode 100644
index 000000000..9ba3154e5
--- /dev/null
+++ b/ml/dlib/dlib/dnn/tensor_tools.h
@@ -0,0 +1,1711 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TeNSOR_TOOLS_H_
+#define DLIB_TeNSOR_TOOLS_H_
+
+#include "tensor.h"
+#include "cudnn_dlibapi.h"
+#include "cublas_dlibapi.h"
+#include "cusolver_dlibapi.h"
+#include "curand_dlibapi.h"
+#include "cpu_dlib.h"
+#include "cuda_dlib.h"
+#include "../rand.h"
+#include <memory>
+#include "../geometry/rectangle.h"
+#include "../test_for_odr_violations.h"
+
+namespace dlib
+{
+ bool dnn_prefer_fastest_algorithms();
+ void set_dnn_prefer_fastest_algorithms();
+ void set_dnn_prefer_smallest_algorithms();
+}
+
+namespace dlib { namespace tt
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void inverse_norms (
+ resizable_tensor& invnorms,
+ const tensor& data,
+ const double eps
+ );
+ /*!
+ ensures
+ - #invnorms == reciprocal(sqrt(sum_cols(squared(mat(data))) + eps))
+ !*/
+
+ void dot_prods (
+ resizable_tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ );
+ /*!
+ requires
+ - have_same_dimensions(lhs,rhs) == true
+ ensures
+ - #out.num_samples() == lhs.num_samples()
+ - #out.k() == #out.nr() == #out.nc() == 1
+ - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+ !*/
+
+ void dot_prods (
+ bool add_to,
+ tensor& out,
+ const tensor& lhs,
+ const tensor& rhs
+ );
+ /*!
+ requires
+ - have_same_dimensions(lhs,rhs) == true
+ - out.size() == lhs.num_samples()
+ - out.k() == out.nr() == out.nc() == 1
+ ensures
+ - if (add_to) then
+ - #out == mat(out) + sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+ - else
+ - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
+ !*/
+
+ void scale_columns (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ );
+ /*!
+ requires
+ - have_same_dimensions(out,m) == true
+ - is_vector(v) == true
+ - v.size() == mat(m).nc()
+ ensures
+ - performs: out = scale_columns(mat(m),mat(v));
+ !*/
+
+ void scale_rows (
+ tensor& out,
+ const tensor& m,
+ const tensor& v
+ );
+ /*!
+ requires
+ - have_same_dimensions(out,m) == true
+ - is_vector(v) == true
+ - v.size() == m.num_samples()
+ ensures
+ - performs: out = scale_rows(mat(m),mat(v));
+ !*/
+
+ void scale_rows2 (
+ float beta,
+ tensor& out,
+ const tensor& m1,
+ const tensor& m2,
+ const tensor& v1,
+ const tensor& v2
+ );
+ /*!
+ requires
+ - have_same_dimensions(out,m1) == true
+ - have_same_dimensions(out,m2) == true
+ - have_same_dimensions(v1,v2) == true
+ - is_vector(v1) == true
+ - v1.size() == m1.num_samples()
+ ensures
+ - performs:
+ out = beta*out + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void exp (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - dest.size() == src.size()
+ ensures
+ - performs: dest = exp(mat(src))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void log (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - dest.size() == src.size()
+ ensures
+ - performs: dest = log(mat(src))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void log10 (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - dest.size() == src.size()
+ ensures
+ - performs: dest = log10(mat(src))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void gemm (
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& lhs,
+ bool trans_lhs,
+ const tensor& rhs,
+ bool trans_rhs
+ );
+ /*!
+ requires
+ - dest does not alias the memory of lhs or rhs
+ - The dimensions of lhs and rhs must be compatible for matrix multiplication.
+ In particular:
+ - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
+ - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
+ - Let D == mat(dest)
+ - D.nr() == L.nr() && D.nc() == R.nc()
+ (i.e. dest must be preallocated and have the correct output dimensions)
+ - L.nc() == R.nr()
+ ensures
+ - performs: dest = alpha*L*R + beta*mat(dest)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class inv
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a functor for doing matrix inversion on the GPU. The only
+ reason it's an object is to avoid the reallocation of some GPU memory
+ blocks if you want to do a bunch of matrix inversions in a row.
+ !*/
+ public:
+
+ void operator() (
+ const tensor& m,
+ resizable_tensor& out
+ );
+ /*!
+ requires
+ - m.size() == m.num_samples()*m.num_samples()
+ (i.e. mat(m) must be a square matrix)
+ ensures
+ - out == inv(mat(m));
+ !*/
+
+ private:
+#ifdef DLIB_USE_CUDA
+ cuda::inv finv;
+#endif
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class tensor_rand
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for filling a tensor with random numbers.
+
+ Note that the sequence of random numbers output by this object is different
+ when dlib is compiled with DLIB_USE_CUDA. So you should not write code
+ that depends on any specific sequence of numbers coming out of a
+ tensor_rand.
+
+ !*/
+
+ public:
+ // not copyable
+ tensor_rand(const tensor_rand&) = delete;
+ tensor_rand& operator=(const tensor_rand&) = delete;
+
+ tensor_rand() : tensor_rand(0) {}
+ tensor_rand(unsigned long long seed);
+
+ void fill_gaussian (
+ tensor& data,
+ float mean = 0,
+ float stddev = 1
+ );
+ /*!
+ requires
+ - data.size()%2 == 0
+ ensures
+ - Fills data with random numbers drawn from a Gaussian distribution
+ with the given mean and standard deviation.
+ !*/
+
+ void fill_uniform (
+ tensor& data
+ );
+ /*!
+ ensures
+ - Fills data with uniform random numbers in the range (0.0, 1.0].
+ !*/
+
+#ifdef DLIB_USE_CUDA
+ cuda::curand_generator rnd;
+#else
+ dlib::rand rnd;
+#endif
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void multiply (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+ /*!
+ requires
+ - dest.k() == src1.k() == src2.k()
+ - dest.nr() == src1.nr() == src2.nr()
+ - dest.nc() == src1.nc() == src2.nc()
+ - dest.num_samples(), src1.num_samples(), and src2.num_samples() must each
+ either be 1 or whichever ones aren't equal to 1 must have the same values.
+ ensures
+ - let MD = max(dest.num_samples(), src1.num_samples(), src2.num_samples)
+ - This function pointwise multiplies src1 with src2 and stores the result into
+ #dest. However, how the multiplication happens depends on the dimensions of
+ the tensors. First, when src1 and src2 are multiplied together, if either
+ has a num_samples() dimension that is != MD, then it is first replicated to
+ produce a tensor with num_samples()==MD dimensions and then they are
+ pointwise multiplied together.
+
+ Second, if dest.num_samples()==1, then after the pointwise multiplication of
+ src1 with src2, the result has its samples summed to produce an output tensor
+ with num_samples()==1 which is then assigned to #dest.
+ - if (add_to) then
+ - Instead of assigning the result to dest, this function adds the result to dest.
+ !*/
+
+ void scale_channels (
+ bool add_to,
+ tensor& dest,
+ const tensor& src,
+ const tensor& scales
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ - scales.num_samples() == src.num_samples()
+ - scales.k() == src.k()
+ - scales.nr() == 1
+ - scales.nc() == 1
+ ensures
+ - Scales each channel of src by the corresponding value in scales. To be
+ precise, we will have:
+ - #dest(n,k,r,c) == src(n,k,r,c)*scales(n,k,1,1)
+ - if (add_to) then
+ - Instead of assigning the result to dest, this function adds the result to dest.
+ !*/
+
+ void multiply_conv (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+ /*!
+ requires
+ - if (have_same_dimensions(dest, src1) == true) then
+ - src2.num_samples() == 1
+ - src2.nr() == 1
+ - src2.nc() == 1
+ - src2.k() == src1.k()
+ - else
+ - have_same_dimensions(src1, src2) == true)
+ - dest.num_samples() == 1
+ - dest.nr() == 1
+ - dest.nc() == 1
+ - dest.k() == src1.k()
+ ensures
+ - Performs #dest == src1*src2
+ In particular, if the elements of dest, src1, and src2 were indexed by (n,k,r,c) then
+ we would have:
+ - if (have_same_dimensions(dest,src1)) then
+ #dest(n,k,r,c) == src1(n,k,r,c)*src2(k)
+ - else
+ #dest(k) == sum over {n,r,c} of src1(n,k,r,c)*src2(n,k,r,c)
+ - if (add_to) then
+ - Instead of assigning the result to dest, this function adds the result to dest.
+ !*/
+
+ void multiply_zero_padded (
+ bool add_to,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+ /*!
+ ensures
+ - if (add_to) then
+ - performs: dest += src1 * src2
+ - else
+ - performs: dest = src1 * src2
+ - In either case, the multiplication happens pointwise according to 4D tensor
+ arithmetic. If the dimensions don't match then missing elements are presumed
+ to be equal to 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A,
+ const float B
+ );
+ /*!
+ requires
+ - dest.size()==src.size()
+ ensures
+ - #dest == A*src + B
+ !*/
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const float A
+ );
+ /*!
+ requires
+ - dest.size()==src.size()
+ ensures
+ - #dest == A*src
+ !*/
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B,
+ const float C
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ ensures
+ - #dest == A*src1 + B*src2 + C
+ !*/
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const float A,
+ const float B
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ ensures
+ - #dest == A*src1 + B*src2
+ !*/
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C,
+ const float D
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ - dest.size()==src3.size()
+ ensures
+ - #dest == A*src1 + B*src2 + C*src3 + D
+ !*/
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ - dest.size()==src3.size()
+ ensures
+ - #dest == A*src1 + B*src2 + C*src3
+ !*/
+
+ void affine_transform_range(
+ size_t begin,
+ size_t end,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ const float A,
+ const float B,
+ const float C
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ - dest.size()==src3.size()
+ - begin <= end <= dest.size()
+ ensures
+ - This function operates much like
+ affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only
+ the half open range [begin,end) rather than processing the entire tensor.
+ Specifically, it does this:
+ - for i in the range [begin, end):
+ - #dest.host()[i] == A*src1.host()[i] + B*src2.host()[i] + C*src3.host()[i]
+ !*/
+
+ void affine_transform(
+ const rectangle& rect,
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2,
+ const tensor& src3,
+ float A,
+ float B,
+ float C
+ );
+ /*!
+ requires
+ - dest.size()==src1.size()
+ - dest.size()==src2.size()
+ - dest.size()==src3.size()
+ - dest.num_samples()==src1.num_samples()
+ - dest.num_samples()==src2.num_samples()
+ - dest.num_samples()==src3.num_samples()
+ - get_rect(mat(dest)).contains(rect) == true
+ (i.e. rect must be entirely contained within dest)
+ ensures
+ - This function operates much like
+ affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only
+ the sub-rectangle indicated by rect. In particular, this function is equivalent
+ to:
+ set_subm(dest,rect) = A*subm(mat(src1),rect) + B*subm(mat(src2),rect) + C*subm(mat(src3),rect)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,src) == true
+ - if (A.num_samples() == 1) then
+ - B.num_samples() == 1
+ - else
+ - A.num_samples() == src.num_samples()
+ - B.num_samples() == src.num_samples()
+ - A.nr() == B.nr() == src.nr()
+ - A.nc() == B.nc() == src.nc()
+ - A.k() == B.k() == src.k()
+ ensures
+ - if (A.num_samples() == 1) then
+ - #dest == A*src + B
+ (done for each sample in src)
+ - else
+ - for all valid i:
+ - #dest.host()[i] == A.host()[i]*src.host()[i] + B.host()[i]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void affine_transform_conv(
+ tensor& dest,
+ const tensor& src,
+ const tensor& A,
+ const tensor& B
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,src) == true
+ - have_same_dimensions(A, B) == true
+ - A.num_samples() == 1
+ - A.nr() == 1
+ - A.nc() == 1
+ - A.k() == src.k()
+ ensures
+ - Performs #dest == A*src + B
+ In particular, if the elements of dest and src were indexed by (n,k,r,c) then
+ we would have:
+ #dest(n,k,r,c) == A(k)*src(n,k,r,c) + B(k).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void compute_adam_update (
+ size_t begin,
+ size_t end,
+ tensor& s,
+ tensor& m,
+ tensor& v,
+ const float t,
+ const float learning_rate,
+ const float weight_decay,
+ const float momentum1,
+ const float momentum2,
+ const tensor& params,
+ const tensor& params_grad
+ );
+ /*!
+ requires
+ - s.size() == m.size() = v.size() == params.size() == params_grad.size()
+ - t > 0
+ - learning_rate > 0
+ - weight_decay >= 0
+ - 0 <= momentum1 < 1
+ - 0 <= momentum2 < 1
+ - begin <= end <= params.size()
+ ensures
+ - This function implements the ADAM parameter update method described in the paper:
+ Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic
+ optimization." International Conference on Learning Representation. 2015.
+ Specifically, it implements the method shown as Algorithm 1.
+ - #s is the update vector that should be added to the parameters.
+ - The function only operates in the half open range [begin,end) of the memory
+ blocks of each tensor. E.g. to make this function run on the entire tensor
+ set begin to 0 and end to params.size().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void batch_normalize_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+ /*!
+ requires
+ - eps > 0
+ - gamma.num_samples() == 1
+ - gamma.nr() == src.nr()
+ - gamma.nc() == src.nc()
+ - gamma.k() == src.k()
+ - have_same_dimensions(gamma, beta)
+ - have_same_dimensions(gamma, running_means)
+ - have_same_dimensions(gamma, running_variances)
+ ensures
+ - Linearly transforms src as a call to batch_normalize() would if src had means
+ and variances as given by running_means and running_variances. That is, this
+ function performs:
+ dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta
+ Note that it does it in a pointwise fashion over the samples in src.
+ !*/
+
+ void batch_normalize (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+ /*!
+ requires
+ - eps > 0
+ - src.num_samples() > 1
+ - gamma.num_samples() == 1
+ - beta.num_samples() == 1
+ - gamma.nr() == beta.nr() == src.nr()
+ - gamma.nc() == beta.nc() == src.nc()
+ - gamma.k() == beta.k() == src.k()
+ - 0 <= averaging_factor <= 1
+ - if (averaging_factor != 1)
+ - have_same_dimensions(running_means, means) == true
+ - have_same_dimensions(running_variances, invstds) == true
+ ensures
+ - have_same_dimensions(#dest, src) == true
+ - #means.num_samples() == 1
+ - #invstds.num_samples() == 1
+ - means.nr() == invstds.nr() == src.nr()
+ - means.nc() == invstds.nc() == src.nc()
+ - means.k() == invstds.k() == src.k()
+ - #src == the batch normalized version of src.
+ - #means == the mean values of the contents of src.
+ - #invstds == 1/(the standard deviation values of the contents of src).
+ - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means);
+ - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src);
+ !*/
+
+ void batch_normalize_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+ /*!
+ requires
+ - eps > 0
+ - invstds and means should be the output of a call to
+ batch_normalize(eps,dest,means,invstds,src,gamma,beta)
+ - have_same_dimensions(gradient_input, src) == true
+ - have_same_dimensions(src, src_grad) == true
+ - src.num_samples() > 1
+ - gamma.num_samples() == 1
+ - have_same_dimensions(gamma, gamma_grad) == true
+ - have_same_dimensions(gamma, beta_grad) == true
+ - gamma.nr() == src.nr()
+ - gamma.nc() == src.nc()
+ - gamma.k() == src.k()
+ - have_same_dimensions(means, gamma) == true
+ - have_same_dimensions(invstds, gamma) == true
+ ensures
+ - Let f(src,gamma,beta) == dot(gradient_input, dest output of
+ batch_normalize(eps,dest,means,invstds,src,gamma,beta))
+ - Adds the gradient of f() with respect to src to #src_grad.
+ - Assigns the gradient of f() with respect to gamma to #gamma_grad.
+ - Assigns the gradient of f() with respect to beta to #beta_grad.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void batch_normalize_conv_inference (
+ const double eps,
+ resizable_tensor& dest,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta,
+ const tensor& running_means,
+ const tensor& running_variances
+ );
+ /*!
+ requires
+ - eps > 0
+ - gamma.num_samples() == 1
+ - gamma.nr() == 1
+ - gamma.nc() == 1
+ - gamma.k() == src.k()
+ - have_same_dimensions(gamma, beta)
+ - have_same_dimensions(gamma, running_means)
+ - have_same_dimensions(gamma, running_variances)
+ ensures
+ - Linearly transforms src as a call to batch_normalize_conv() would if src had
+ means and variances as given by running_means and running_variances. That
+ is, this function performs:
+ dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta
+ Note that it does this in a pointwise fashion over the samples, rows, and
+ columns in src.
+ !*/
+
+ void batch_normalize_conv (
+ const double eps,
+ resizable_tensor& dest,
+ resizable_tensor& means,
+ resizable_tensor& invstds,
+ const double averaging_factor,
+ resizable_tensor& running_means,
+ resizable_tensor& running_variances,
+ const tensor& src,
+ const tensor& gamma,
+ const tensor& beta
+ );
+ /*!
+ requires
+ - eps > 0
+ - src.num_samples() > 1
+ - gamma.num_samples()==gamma.nr()==gamma.nc() == 1
+ - beta.num_samples() ==beta.nr() ==gamma.nc() == 1
+ - gamma.k() == beta.k() == src.k()
+ - 0 <= averaging_factor <= 1
+ - if (averaging_factor != 1)
+ - have_same_dimensions(running_means, means) == true
+ - have_same_dimensions(running_variances, invstds) == true
+ ensures
+ - have_same_dimensions(#dest, src) == true
+ - #means.num_samples()==means.nr()==means.nc() == 1
+ - #invstds.num_samples() ==invstds.nr() ==invstds.nc() == 1
+ - means.k() == invstds.k() == src.k()
+ - #src == the batch normalized version of src.
+ - #means == the mean values of the contents of src.
+ - #invstds == 1/(the standard deviation values of the contents of src).
+ - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means);
+ - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src);
+ !*/
+
+ void batch_normalize_conv_gradient (
+ const double eps,
+ const tensor& gradient_input,
+ const tensor& means,
+ const tensor& invstds,
+ const tensor& src,
+ const tensor& gamma,
+ tensor& src_grad,
+ tensor& gamma_grad,
+ tensor& beta_grad
+ );
+ /*!
+ requires
+ - eps > 0
+ - invstds and means should be the output of a call to
+ batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta)
+ - have_same_dimensions(gradient_input, src) == true
+ - have_same_dimensions(src, src_grad) == true
+ - src.num_samples() > 1
+ - gamma.num_samples()==gamma.nr()==gamma.nc() == 1
+ - have_same_dimensions(gamma, gamma_grad) == true
+ - have_same_dimensions(gamma, beta_grad) == true
+ - gamma.k() == src.k()
+ - have_same_dimensions(means, gamma) == true
+ - have_same_dimensions(invstds, gamma) == true
+ ensures
+ - Let f(src,gamma,beta) == dot(gradient_input, dest output of
+ batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta))
+ - Adds the gradient of f() with respect to src to #src_grad.
+ - Assigns the gradient of f() with respect to gamma to #gamma_grad.
+ - Assigns the gradient of f() with respect to beta to #beta_grad.
+ !*/
+
+// -----------------------------------------------------------------------------------
+
+ void threshold (
+ tensor& data,
+ float thresh
+ );
+ /*!
+ ensures
+ - Sets all elements of data to 1 or 0 depending on if they are above or below
+ the given threshold. Specifically, for all valid i:
+ - #data.host()[i] == data.host()[i]>thresh ? 1 : 0
+ !*/
+
+ void dot (
+ const tensor& a,
+ const tensor& b,
+ tensor& result,
+ size_t idx
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ - idx < result.size()
+ ensures
+ - #result.host()[idx] == result.host()[idx] + dot(a,b);
+ I.e. Adds the dot product between a and b into the idx-th element of result.
+ The reason you might want to use this more complex version of dot() is
+ because, when using CUDA, it runs by generating asynchronous kernel launches
+ whereas the version of dot() that returns the result immediately as a scalar
+ must block the host while we wait for the result to be computed and then
+ transfered from the GPU do the host for return by dot(). So this version of
+ dot() might be much faster in some cases.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void add(
+ float beta,
+ tensor& dest,
+ float alpha,
+ const tensor& src
+ );
+ /*!
+ requires
+ - One of the following is true:
+ - have_same_dimensions(src, dest)
+ - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
+ - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
+ - src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
+ - src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1
+ - is_same_object(src,dest) == false
+ ensures
+ - performs: dest = beta*dest + alpha*src
+ However, how the addition happens depends on the dimensions of src. In
+ particular, this function adds the scaled values of one src tensor to dest.
+ Each dimension of the src tensor must match the corresponding dimension of
+ the dest tensor or must be equal to 1. In the latter case, the same value
+ from the src tensor, for those dimensions, will be used to add into the dest
+ tensor.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void add (
+ tensor& dest,
+ const tensor& src1,
+ const tensor& src2
+ );
+ /*!
+ ensures
+ - performs: dest = src1 + src2
+ The addition happens pointwise according to 4D tensor arithmetic. If the
+ dimensions don't match then missing elements are presumed to be equal to 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void assign_conv_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - grad.num_samples() == 1
+ - grad.k() >= 1
+ - grad.nr() == 1
+ - grad.nc() == 1
+ - gradient_input.k() == grad.k()
+ - gradient_input.size() > 0
+ - is_same_object(grad,gradient_input) == false
+ ensures
+ - let BIAS be a tensor with the same dimensions as grad.
+ - let OUT be the output of add(1,OUT,1,BIAS)
+ - let f(gradient_input,BIAS) == dot(gradient_input,OUT)
+ - Then this function computes the gradient of f() with respect to BIAS and
+ assigns it to grad.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void assign_bias_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - grad.num_samples() == 1
+ - gradient_input.k() == grad.k()
+ - gradient_input.nr() == grad.nr()
+ - gradient_input.nc() == grad.nc()
+ - gradient_input.size() > 0
+ - is_same_object(grad,gradient_input) == false
+ ensures
+ - let BIAS be a tensor with the same dimensions as grad.
+ - let OUT be the output of add(1,OUT,1,BIAS)
+ - let f(gradient_input,BIAS) == dot(gradient_input,OUT)
+ - Then this function computes the gradient of f() with respect to BIAS and
+ assigns it to grad.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class tensor_conv
+ {
+ public:
+ tensor_conv(const tensor_conv&) = delete;
+ tensor_conv& operator=(const tensor_conv&) = delete;
+
+ tensor_conv() {}
+
+ void clear(
+ ) { impl.clear(); }
+
+ void operator() (
+ const bool add_to_output,
+ tensor& output,
+ const tensor& data,
+ const tensor& filters
+ ) { impl(add_to_output,output,data,filters); }
+ /*!
+ requires
+ - setup() has been called. Specifically, setup() has been called like this:
+ this->setup(data, filters, stride_y, stride_x, padding_y, padding_x);
+ - is_same_object(output,data) == false
+ - is_same_object(output,filters) == false
+ - filters.k() == data.k()
+ - filters.nr() <= src.nr() + 2*padding_y
+ - filters.nc() <= src.nc() + 2*padding_x
+ - #output.num_samples() == data.num_samples()
+ - #output.k() == filters.num_samples()
+ - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
+ - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
+ ensures
+ - Convolves filters over data. If add_to_output==true then we add the
+ results to output, otherwise we assign to output, overwriting the
+ previous values in output.
+ - filters contains filters.num_samples() filters.
+ !*/
+
+ void operator() (
+ const bool add_to_output,
+ resizable_tensor& output,
+ const tensor& data,
+ const tensor& filters
+ ) { impl(add_to_output,output,data,filters); }
+ /*!
+ requires
+ - setup() has been called. Specifically, setup() has been called like this:
+ this->setup(data, filters, stride_y, stride_x, padding_y, padding_x);
+ - is_same_object(output,data) == false
+ - is_same_object(output,filters) == false
+ - filters.k() == data.k()
+ - filters.nr() <= src.nr() + 2*padding_y
+ - filters.nc() <= src.nc() + 2*padding_x
+ ensures
+ - Convolves filters over data. If add_to_output==true then we add the
+ results to output, otherwise we assign to output, overwriting the
+ previous values in output.
+ - filters contains filters.num_samples() filters.
+ - #output.num_samples() == data.num_samples()
+ - #output.k() == filters.num_samples()
+ - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
+ - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
+ !*/
+
+ void get_gradient_for_data (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& filters,
+ tensor& data_gradient
+ ) { impl.get_gradient_for_data(add_to_output,gradient_input,filters,data_gradient); }
+ /*!
+ requires
+ - One of the following must be true:
+ - filters has the same dimensions as the filters object given to the
+ last call to operator(). Also, data_gradient has the same dimensions
+ as the data object given to the last call to operator().
+ - setup() has been called. Specifically, setup() has been called like this:
+ this->setup(data_gradient, filters, stride_y, stride_x, padding_y, padding_x);
+ - gradient_input has the following dimensions:
+ - gradient_input.num_samples() == data_gradient.num_samples()
+ - gradient_input.k() == filters.num_samples()
+ - gradient_input.nr() == 1+(data_gradient.nr() + 2*padding_y - filters.nr())/stride_y
+ - gradient_input.nc() == 1+(data_gradient.nc() + 2*padding_x - filters.nc())/stride_x
+ - NOTE, these dimensions are what you would obtain if gradient_input
+ has the same dimensions as the last output of operator().
+ - is_same_object(data_gradient,filters) == false
+ - is_same_object(data_gradient,gradient_input) == false
+ ensures
+ - let OUT be the output of (*this)(OUT,data,filters,sx,sy).
+ - let f(data,filters) == dot(OUT, gradient_input)
+ - if (add_to_output) then
+ - This function finds the gradient of f() with respect to data and adds
+ this gradient to data_gradient.
+ - else
+ - This function finds the gradient of f() with respect to data and
+ assigns this gradient to data_gradient, overwriting the previous
+ values in data_gradient.
+ !*/
+
+ void get_gradient_for_filters (
+ const bool add_to_output,
+ const tensor& gradient_input,
+ const tensor& data,
+ tensor& filters_gradient
+ ) { impl.get_gradient_for_filters(add_to_output,gradient_input,data,filters_gradient); }
+ /*!
+ requires
+ - One of the following must be true:
+ - filters_gradient has the same dimensions as the filters object given
+ to the last call to operator(). Also, data has the same dimensions
+ as the data object given to the last call to operator().
+ - setup() has been called. Specifically, setup() has been called like this:
+ this->setup(data, filters_gradient, stride_y, stride_x, padding_y, padding_x);
+ - gradient_input has the following dimensions:
+ - gradient_input.num_samples() == data.num_samples()
+ - gradient_input.k() == filters.num_samples()
+ - gradient_input.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
+ - gradient_input.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
+ - NOTE, these dimensions are what you would obtain if gradient_input
+ has the same dimensions as the last output of operator().
+ - is_same_object(filters_gradient,data) == false
+ - is_same_object(filters_gradient,gradient_input) == false
+ ensures
+ - let OUT be the output of (*this)(OUT,data,filters,sx,sy).
+ - let f(data,filters) == dot(OUT, gradient_input)
+ - if (add_to_output) then
+ - This function finds the gradient of f() with respect to filters and
+ adds this gradient to filters_gradient.
+ - else
+ - This function finds the gradient of f() with respect to filters and
+ assigns this gradient to filters_gradient, overwriting the previous
+ values in filters_gradient.
+ !*/
+
+
+ void setup(
+ const tensor& data,
+ const tensor& filters,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ ) {impl.setup(data,filters,stride_y,stride_x,padding_y,padding_x); }
+ /*!
+ requires
+ - filters.k() == data.k()
+ - stride_y > 0
+ - stride_x > 0
+ - 0 <= padding_y < filters.nr()
+ - 0 <= padding_x < filters.nc()
+ ensures
+ - When operator() is called, the output tensor will have these dimensions:
+ - output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y
+ - output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x
+ - output.num_samples() == data.num_samples()
+ - output.k() == filters.num_samples()
+ - The point of setup() is to allow this object to gather information about
+ all the tensor sizes and filter layouts involved in the computation. In
+ particular, the reason the tensors are input into setup() is just to
+ observe their sizes. setup() doesn't do anything with the contents of
+ the tensors, or store any kind of references to the data or filter
+ tensors.
+ !*/
+
+ private:
+#ifdef DLIB_USE_CUDA
+ cuda::tensor_conv impl;
+#else
+ cpu::tensor_conv impl;
+#endif
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class pooling
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ The pooling object is a tool for performing spatial pooling over a tensor.
+ It can be configured to do either max or average pooling.
+ !*/
+ public:
+
+ pooling(const pooling&) = delete;
+ pooling& operator=(const pooling&) = delete;
+
+ pooling (
+ ) = default;
+
+ void clear(
+ ) { impl.clear(); }
+
+ void setup_max_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ ) { impl.setup_max_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); }
+ /*!
+ requires
+ - window_height > 0
+ - window_width > 0
+ - stride_y > 0
+ - stride_x > 0
+ - 0 <= padding_y < window_height
+ - 0 <= padding_x < window_width
+ ensures
+ - When you call operator() it will do max pooling with the given
+ parameters.
+ !*/
+
+ void setup_avg_pooling(
+ int window_height,
+ int window_width,
+ int stride_y,
+ int stride_x,
+ int padding_y,
+ int padding_x
+ ) { impl.setup_avg_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); }
+ /*!
+ requires
+ - window_height > 0
+ - window_width > 0
+ - stride_y > 0
+ - stride_x > 0
+ - 0 <= padding_y < window_height
+ - 0 <= padding_x < window_width
+ ensures
+ - When you call operator() it will do average pooling with the given
+ parameters.
+ !*/
+
+ bool does_max_pooling(
+ ) const { return impl.does_max_pooling(); }
+
+ void operator() (
+ resizable_tensor& dest,
+ const tensor& src
+ ) { impl(dest, src); }
+ /*!
+ requires
+ - is_same_object(dest,src) == false
+ - either setup_max_pooling() or setup_avg_pooling() has been called.
+ - window_width <= src.nc() + 2*padding_x
+ - window_height <= src.nr() + 2*padding_y
+ ensures
+ - #dest.num_samples() == src.num_samples()
+ - #dest.k() == src.k()
+ - #dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y
+ - #dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x
+ - WINDOW == centered_rect(x*stride_x + window_width/2 - padding_x,
+ y*stride_y + window_height/2 - padding_y,
+ window_width,
+ window_height)
+ - for all valid s, k, r, and c:
+ - if (does_max_pooling()) then
+ - image_plane(#dest,s,k)(r,c) == max(subm_clipped(image_plane(src,s,k),WINDOW(c,r)))
+ - else
+ - image_plane(#dest,s,k)(r,c) == mean(subm_clipped(image_plane(src,s,k),WINDOW(c,r)))
+ !*/
+
+ void get_gradient(
+ const tensor& gradient_input,
+ const tensor& dest,
+ const tensor& src,
+ tensor& grad
+ ) { impl.get_gradient(gradient_input, dest, src, grad); }
+ /*!
+ requires
+ - have_same_dimensions(gradient_input,dest) == true
+ - have_same_dimensions(src,grad) == true
+ - dest contains the result of calling (*this)(dest,src)
+ - is_same_object(grad,gradient_input) == false
+ - is_same_object(grad,dest) == false
+ - is_same_object(grad,src) == false
+ ensures
+ - Recalling that dest is the output of (*this)(dest,src),
+ let f(src) == dot(gradient_input,dest)
+ - Then this function computes the gradient of f() with respect to src and
+ adds it to grad.
+ !*/
+
+ private:
+#ifdef DLIB_USE_CUDA
+ cuda::pooling impl;
+#else
+ cpu::pooling impl;
+#endif
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void softmax (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - Note that the softmax function is a vector valued function:
+ s(x) == exp(x)/sum(exp(x))
+ - Computes the softmax function on src and writes the results to dest. The
+ softmax is computed per spatial location across the different channels at
+ each location. That is, softmax() outputs a new tensor, #dest, where each of
+ the spatial locations in dest (i.e. image idx, row idx, and column idx)
+ contains the output of s() evaluated over the channel values at each
+ location.
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void softmax_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ ensures
+ - We interpret dest as the output of softmax(dest,SRC) for some SRC tensor.
+ Then let f(SRC) == dot(gradient_input,dest). Then this function computes the
+ gradient of f() with respect to SRC and stores it to grad. Moreover, if
+ is_same_object(grad,gradient_input)==true then the output is assigned to
+ grad, replacing its previous contents. Otherwise the output is added to
+ grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void softmax_all (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - Note that the softmax function is a vector valued function:
+ s(x) == exp(x)/sum(exp(x))
+ - Computes the softmax function on src and writes the results to dest. The
+ softmax is computed over the entire tensor with one invocation of s(). So
+ unlike softmax() which computes many s() evaluations, one for each spatial
+ location, softmax_all() calls s() once for the entire tensor.
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void softmax_all_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ - is_same_object(grad, dest)==false
+ ensures
+ - We interpret dest as the output of softmax_all(dest,SRC) for some SRC tensor.
+ Then let f(SRC) == dot(gradient_input,dest) Then this function computes the
+ gradient of f() with respect to SRC and assigns it to grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void sigmoid (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == 1/(1+std::exp(-src.host()[i]))
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void sigmoid_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ ensures
+ - Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest). Then this function computes the
+ gradient of f() with respect to SRC and stores it to grad. Moreover, if
+ is_same_object(grad,gradient_input)==true then the output is assigned to
+ grad, replacing its previous contents. Otherwise the output is added to
+ grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void relu (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == std::max(0,src.host()[i])
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void relu_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ ensures
+ - Recalling that dest is the output of relu(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest). Then this function computes the
+ gradient of f() with respect to SRC and stores it to grad. Moreover, if
+ is_same_object(grad,gradient_input)==true then the output is assigned to
+ grad, replacing its previous contents. Otherwise the output is added to
+ grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void prelu (
+ tensor& dest,
+ const tensor& src,
+ const tensor& param
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ - param.size() == 1
+ ensures
+ - for all valid i:
+ - if (src.host()[i] > 0) then
+ - #dest.host()[i] == src.host()[i]
+ - else
+ - #dest.host()[i] == src.host()[i] * param.host()[0]
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void prelu_gradient (
+ tensor& grad,
+ const tensor& src,
+ const tensor& gradient_input,
+ const tensor& param,
+ tensor& params_grad
+ );
+ /*!
+ requires
+ - have_same_dimensions(grad,src) == true
+ - have_same_dimensions(grad,gradient_input) == true
+ - param.size() == 1
+ - params_grad.size() == 1
+ - is_same_object(grad, gradient_input) == false
+ ensures
+ - Recalling that dest is the output of prelu(dest,src,param) let
+ f(src,param) == dot(gradient_input,dest)
+ - Then this function computes the gradient of f() with respect to src and
+ param. It assigns the gradient with respect to param to #params_grad and
+ adds the gradient with respect to src to #grad.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void tanh (
+ tensor& dest,
+ const tensor& src
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest, src) == true
+ ensures
+ - for all valid i:
+ - #dest.host()[i] == std::tanh(src.host()[i])
+ - This function supports in-place operation, i.e. having
+ is_same_object(dest, src)==true
+ !*/
+
+ void tanh_gradient (
+ tensor& grad,
+ const tensor& dest,
+ const tensor& gradient_input
+ );
+ /*!
+ requires
+ - have_same_dimensions(dest,gradient_input) == true
+ - have_same_dimensions(dest,grad) == true
+ ensures
+ - Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor,
+ let f(SRC) == dot(gradient_input,dest). Then this function computes the
+ gradient of f() with respect to SRC and stores it to grad. Moreover, if
+ is_same_object(grad,gradient_input)==true then the output is assigned to
+ grad, replacing its previous contents. Otherwise the output is added to
+ grad.
+ - This function supports in-place operation, i.e. having
+ is_same_object(grad, gradient_input)==true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void resize_bilinear (
+ tensor& dest,
+ long dest_row_stride,
+ long dest_channel_stride,
+ const tensor& src,
+ long src_row_stride,
+ long src_channel_stride
+ );
+ /*!
+ requires
+ - is_same_object(dest, src)==false
+ - dest.num_samples() == src.num_samples()
+ - dest.k() == src.k()
+ ensures
+ - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k)
+ that has been bilinearly interpolated to fit into the shape of
+ image_plane(dest,i,k).
+ - Instead of supposing the row stride and channel stride in the tensors is
+ given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the
+ provided stride values to transition from one row and channel to the next.
+ This is useful in combination with alias_tensor objects since it allows you
+ to operate on subwindows in an image.
+ !*/
+
+ void resize_bilinear_gradient (
+ tensor& grad,
+ long grad_row_stride,
+ long grad_channel_stride,
+ const tensor& gradient_input,
+ long gradient_input_row_stride,
+ long gradient_input_channel_stride
+ );
+ /*!
+ requires
+ - is_same_object(grad, gradient_input)==false
+ - gradient_input.num_samples() == grad.num_samples()
+ - gradient_input.k() == grad.k()
+ ensures
+ - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC
+ tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes
+ the gradient of f() with respect to SRC and adds it to grad. It should be
+ noted that we don't need to know the contents of DEST to compute this
+ gradient. All that matters is that gradient_input have the same dimensions
+ as DEST.
+ - Instead of supposing the row stride and channel stride in the tensors is
+ given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the
+ provided stride values to transition from one row and channel to the next.
+ This is useful in combination with alias_tensor objects since it allows you
+ to operate on subwindows in an image.
+ !*/
+
+ inline void resize_bilinear (
+ tensor& dest,
+ const tensor& src
+ ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); }
+ /*!
+ requires
+ - is_same_object(dest, src)==false
+ - dest.num_samples() == src.num_samples()
+ - dest.k() == src.k()
+ ensures
+ - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k)
+ that has been bilinearly interpolated to fit into the shape of
+ image_plane(dest,i,k).
+ !*/
+
+ inline void resize_bilinear_gradient (
+ tensor& grad,
+ const tensor& gradient_input
+ ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); }
+ /*!
+ requires
+ - is_same_object(grad, gradient_input)==false
+ - gradient_input.num_samples() == grad.num_samples()
+ - gradient_input.k() == grad.k()
+ ensures
+ - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC
+ tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes
+ the gradient of f() with respect to SRC and adds it to grad. It should be
+ noted that we don't need to know the contents of DEST to compute this
+ gradient. All that matters is that gradient_input have the same dimensions
+ as DEST.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class multi_device_tensor_averager
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for very quickly averaging a bunch of tensors
+ together.
+ !*/
+ public:
+
+ multi_device_tensor_averager(const multi_device_tensor_averager&) = delete;
+ multi_device_tensor_averager& operator=(const multi_device_tensor_averager&) = delete;
+
+ multi_device_tensor_averager() = default;
+
+ void set(
+ std::vector<tensor*> items
+ )
+ /*!
+ requires
+ - All the tensors in items are the same size
+ ensures
+ - When you call average() we will average the tensors in items.
+ - It's important that the tensors already be allocated to their devices
+ before you call set(). This is because set() will setup the types of
+ between device transfers now and use them when you call average().
+ !*/
+ {
+ using namespace ::dlib::cuda;
+ accessible_groups.clear();
+ epa.clear();
+ if (items.size() < 1)
+ return;
+
+ scale = 1.0/items.size();
+
+ // split item into groups of accessible devices
+ std::vector<tensor*> group, unused;
+ while(items.size() > 0)
+ {
+ group.push_back(items[0]);
+ for(size_t i = 1; i < items.size(); ++i)
+ {
+ if (can_access_peer(*items[0], *items[i]))
+ group.push_back(items[i]);
+ else
+ unused.push_back(items[i]);
+ }
+ accessible_groups.push_back(group);
+ unused.swap(items);
+ unused.clear();
+ group.clear();
+ }
+ for (auto&& g : accessible_groups)
+ {
+ for (size_t i = 1; i < g.size(); ++i)
+ {
+ epa.emplace_back(new enable_peer_access(*g[0], *g[i]));
+ }
+ }
+ }
+
+ size_t num_device_groups(
+ ) const { return accessible_groups.size(); }
+ /*!
+ ensures
+ - The devices given to set() are grouped together when they can directly
+ access each other using GPUDirect. This function returns the number of
+ such groups. For example, if all devices can directly access each other
+ then the number of groups is 1.
+ !*/
+
+ void average()
+ /*!
+ requires
+ - All the devices have stopped writing to the tensors given to set(). So
+ you should probably call cudaDeviceSynchronize() on each of the relevant
+ devices before calling average().
+ ensures
+ - Computes the average of all the tensors given to set() and then sets them
+ all equal to the average.
+ !*/
+ {
+ using namespace ::dlib::cuda;
+
+
+ // First we average things within each group
+ for (auto&& g : accessible_groups)
+ {
+ raii_set_device set_dev(*g[0]);
+ if (g.size() == 1)
+ tt::affine_transform(*g[0], *g[0], scale);
+ else
+ tt::affine_transform(*g[0], *g[0], *g[1], scale, scale);
+
+ for (size_t i = 2; i < g.size(); ++i)
+ tt::affine_transform(*g[0], *g[0], *g[i], 1, scale);
+ }
+
+ if (accessible_groups.size() > 1)
+ {
+ tensor& total_avg = *accessible_groups[0][0];
+ raii_set_device set_dev(total_avg);
+ accum_buffer.copy_size(total_avg);
+ // now we need to average things across groups
+ for (size_t i = 1; i < accessible_groups.size(); ++i)
+ {
+ memcpy(accum_buffer, *accessible_groups[i][0]);
+ tt::add(total_avg, total_avg, accum_buffer);
+ }
+
+ // Now total_avg has the final average in it. So we need to send
+ // copies of it back to each of the groups.
+ for (size_t i = 1; i < accessible_groups.size(); ++i)
+ {
+ memcpy(*accessible_groups[i][0], total_avg);
+ }
+ }
+
+
+ // Now propagate averages back out to each element using point to point
+ // communication inside a group.
+ for (auto&& g : accessible_groups)
+ {
+ raii_set_device set_dev(*g[0]);
+ for (size_t i = 1; i < g.size(); ++i)
+ memcpy(*g[i], *g[0]);
+ }
+ }
+
+ private:
+ std::vector<std::unique_ptr<::dlib::cuda::enable_peer_access>> epa;
+ std::vector<std::vector<tensor*>> accessible_groups;
+ float scale;
+
+ resizable_tensor accum_buffer;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void copy_tensor(
+ bool add_to,
+ tensor& dest,
+ size_t dest_k_offset,
+ const tensor& src,
+ size_t src_k_offset,
+ size_t count_k
+ );
+ /*!
+ requires
+ - dest.nc() == src.nc()
+ - dest.nr() == src.nr()
+ - dest.num_samples() == src.num_samples()
+ - dest.k() - dest_k_offset >= count_k
+ - src.k() - src_k_offset >= count_k
+ - is_same_object(dest,src) == false
+ - The memory areas of src and dest do not overlap.
+ ensures
+ - if (add_to) then
+ - performs: dest[i, k + dest_k_offset, r, c] += src[i, k + src_k_offset, r, c], where k in [0..count_k]
+ i.e., adds content of each sample from src in to corresponding place of sample at dest.
+ - else
+ - performs: dest[i, k + dest_k_offset, r, c] = src[i, k + src_k_offset, r, c], where k in [0..count_k]
+ i.e., copies content of each sample from src in to corresponding place of sample at dest.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}}
+
+#ifdef NO_MAKEFILE
+#include "tensor_tools.cpp"
+#endif
+
+#endif // DLIB_TeNSOR_TOOLS_H_
+
+
diff --git a/ml/dlib/dlib/dnn/trainer.h b/ml/dlib/dlib/dnn/trainer.h
new file mode 100644
index 000000000..7cb2bf5e5
--- /dev/null
+++ b/ml/dlib/dlib/dnn/trainer.h
@@ -0,0 +1,1333 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_TRAINER_H_
+#define DLIB_DNn_TRAINER_H_
+
+#include "trainer_abstract.h"
+#include "core.h"
+#include "solvers.h"
+#include "../statistics.h"
+#include <chrono>
+#include <fstream>
+#include <sstream>
+#include "../serialize.h"
+
+#include "../pipe.h"
+#include "../threads.h"
+#include "cuda_dlib.h"
+#include "../statistics/running_gradient.h"
+#include <atomic>
+#include <cstdio>
+#include <set>
+#include <future>
+#include <exception>
+#include <mutex>
+#include "../dir_nav.h"
+#include "../md5.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename training_label_type>
+ struct dnn_job_t
+ {
+ dnn_job_t() = default;
+ dnn_job_t(const dnn_job_t&) = delete;
+ dnn_job_t& operator=(const dnn_job_t&) = delete;
+
+ std::vector<std::vector<training_label_type>> labels;
+ std::vector<resizable_tensor> t;
+ std::vector<int> have_data; // have_data[i] is true if there is data in labels[i] and t[i].
+ bool test_only = false;
+ };
+
+ template <typename training_label_type>
+ void swap(dnn_job_t<training_label_type>& a, dnn_job_t<training_label_type>& b)
+ {
+ a.labels.swap(b.labels);
+ a.t.swap(b.t);
+ a.have_data.swap(b.have_data);
+ std::swap(a.test_only,b.test_only);
+ }
+ }
+
+ enum class force_flush_to_disk {
+ no = 0,
+ yes = 1
+ };
+
+ template <
+ typename net_type,
+ typename solver_type = sgd
+ >
+ class dnn_trainer : private threaded_object
+ {
+ public:
+
+ static_assert(is_loss_layer_type<net_type>::value,
+ "The last layer in a network must be a loss layer.");
+
+ typedef typename net_type::training_label_type training_label_type;
+ typedef typename net_type::input_type input_type;
+ const static size_t num_computational_layers = net_type::num_computational_layers;
+ const static size_t num_layers = net_type::num_layers;
+ private:
+ typedef impl::dnn_job_t<training_label_type> job_t;
+ public:
+
+ dnn_trainer() = delete;
+ dnn_trainer(const dnn_trainer&) = delete;
+ dnn_trainer& operator=(const dnn_trainer&) = delete;
+
+ explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_)
+ {
+ solver_type default_solver;
+ devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, default_solver));
+
+ init();
+ }
+
+ dnn_trainer(
+ net_type& net_,
+ const solver_type& solver_
+ ) : job_pipe(0), net(net_)
+ {
+ devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));
+
+ init();
+ }
+
+ dnn_trainer(
+ net_type& net_,
+ const solver_type& solver_,
+ const std::vector<int>& cuda_extra_devices
+ ) : job_pipe(0), net(net_)
+ {
+ devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));
+
+ const int total_devices = dlib::cuda::get_num_devices();
+
+ // Make device contexts for the extra device ids but be careful to avoid any
+ // duplicate ids.
+ std::set<int> temp(cuda_extra_devices.begin(), cuda_extra_devices.end());
+ temp.erase(devices[0]->device_id);
+ for (auto id : temp)
+ {
+ DLIB_CASSERT(0 <= id && id < total_devices, "Invalid CUDA device id given to dnn_trainer.");
+ // Switch to this device so that any tensor objects that get allocated when
+ // we create the device context happen on this device.
+ dlib::cuda::set_device(id);
+ devices.push_back(std::make_shared<device_data>(id, net, solver_, clone_net()));
+ }
+ // Set the current device back to what it was before this constructor was
+ // called.
+ dlib::cuda::set_device(devices[0]->device_id);
+
+ init();
+ }
+
+ ~dnn_trainer(
+ )
+ {
+ job_pipe.disable();
+ stop();
+ wait();
+ }
+
+ net_type& get_net (
+ force_flush_to_disk force_flush = force_flush_to_disk::yes
+ )
+ {
+ wait_for_thread_to_pause();
+ sync_to_disk(force_flush == force_flush_to_disk::yes);
+ propagate_exception();
+ return net;
+ }
+
+
+ unsigned long get_mini_batch_size (
+ ) const { return mini_batch_size; }
+
+ void set_mini_batch_size (
+ unsigned long batch_size
+ )
+ {
+ DLIB_CASSERT(batch_size > 0);
+ mini_batch_size = batch_size;
+ }
+
+ unsigned long get_max_num_epochs (
+ ) const { return max_num_epochs; }
+
+ void set_max_num_epochs (
+ unsigned long num
+ )
+ {
+ DLIB_CASSERT(num > 0);
+ max_num_epochs = num;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+
+ const std::vector<solver_type>& get_solvers (
+ ) const
+ {
+ wait_for_thread_to_pause();
+ propagate_exception();
+ return devices[0]->solvers;
+ }
+
+ void train_one_step (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ )
+ {
+ DLIB_CASSERT(data.size() == labels.size());
+
+ train_one_step(data.begin(), data.end(), labels.begin());
+ }
+
+ template <
+ typename data_iterator,
+ typename label_iterator
+ >
+ void train_one_step (
+ data_iterator dbegin,
+ data_iterator dend,
+ label_iterator lbegin
+ )
+ {
+ DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+
+ print_periodic_verbose_status();
+ sync_to_disk();
+ send_job(false, dbegin, dend, lbegin);
+
+ ++train_one_step_calls;
+ }
+
+ void train_one_step (
+ const std::vector<input_type>& data
+ )
+ {
+ train_one_step(data.begin(), data.end());
+ }
+
+ template <
+ typename data_iterator
+ >
+ void train_one_step (
+ data_iterator dbegin,
+ data_iterator dend
+ )
+ {
+ DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+ print_periodic_verbose_status();
+ sync_to_disk();
+ send_job(false, dbegin, dend);
+ ++train_one_step_calls;
+ }
+
+ void test_one_step (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ )
+ {
+ DLIB_CASSERT(data.size() == labels.size());
+
+ test_one_step(data.begin(), data.end(), labels.begin());
+ }
+
+ template <
+ typename data_iterator,
+ typename label_iterator
+ >
+ void test_one_step (
+ data_iterator dbegin,
+ data_iterator dend,
+ label_iterator lbegin
+ )
+ {
+ DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+
+ print_periodic_verbose_status();
+ sync_to_disk();
+ send_job(true, dbegin, dend, lbegin);
+
+ ++test_one_step_calls;
+ }
+
+ void test_one_step (
+ const std::vector<input_type>& data
+ )
+ {
+ test_one_step(data.begin(), data.end());
+ }
+
+ template <
+ typename data_iterator
+ >
+ void test_one_step (
+ data_iterator dbegin,
+ data_iterator dend
+ )
+ {
+ DLIB_CASSERT(std::distance(dbegin, dend) > 0);
+ print_periodic_verbose_status();
+ sync_to_disk();
+ send_job(true, dbegin, dend);
+ ++test_one_step_calls;
+ }
+
+ void train (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ )
+ {
+ DLIB_CASSERT(data.size() == labels.size() && data.size() > 0);
+
+ // The reason these two loops don't initialize their counter variables but
+ // instead use class members is so we can include the state of the loops in the
+ // stuff written by sync_to_disk()
+ for (;
+ epoch_iteration < max_num_epochs && learning_rate >= min_learning_rate;
+ ++epoch_iteration)
+ {
+ using namespace std::chrono;
+ last_time = system_clock::now();
+ clear_average_loss();
+ for (; epoch_pos < data.size() && learning_rate >= min_learning_rate; epoch_pos += mini_batch_size)
+ {
+ if (verbose)
+ {
+ auto now_time = system_clock::now();
+ if (now_time-last_time > seconds(20))
+ {
+ last_time = now_time;
+ auto iter = epoch_iteration + epoch_pos/(double)data.size();
+ std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
+ << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
+ << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ print_progress();
+ }
+ }
+
+ sync_to_disk();
+ send_job(false, data.begin()+epoch_pos,
+ data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
+ labels.begin()+epoch_pos);
+ }
+ epoch_pos = 0;
+
+ if (verbose)
+ {
+ // Capitalize the E in Epoch so it's easy to grep out the lines that
+ // are for full epoch status statements.
+ std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
+ << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
+ << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ print_progress();
+ }
+ }
+ wait_for_thread_to_pause();
+ // if we modified the network at all then be sure to sync the final result.
+ sync_to_disk(true);
+ }
+
+ void train (
+ const std::vector<input_type>& data
+ )
+ {
+ DLIB_CASSERT(data.size() > 0);
+
+ const bool has_unsupervised_loss = std::is_same<no_label_type, training_label_type>::value;
+ static_assert(has_unsupervised_loss,
+ "You can only call this version of train() when using an unsupervised loss.");
+
+ // The reason these two loops don't initialize their counter variables but
+ // instead use class members is so we can include the state of the loops in the
+ // stuff written by sync_to_disk()
+ for (;
+ epoch_iteration < max_num_epochs && learning_rate >= min_learning_rate;
+ ++epoch_iteration)
+ {
+ using namespace std::chrono;
+ last_time = system_clock::now();
+ clear_average_loss();
+ for (; epoch_pos < data.size() && learning_rate >= min_learning_rate; epoch_pos += mini_batch_size)
+ {
+ if (verbose)
+ {
+ auto now_time = system_clock::now();
+ if (now_time-last_time > seconds(20))
+ {
+ last_time = now_time;
+ auto iter = epoch_iteration + epoch_pos/(double)data.size();
+ std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
+ << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
+ << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ print_progress();
+ }
+ }
+
+ sync_to_disk();
+ send_job(false, data.begin()+epoch_pos,
+ data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
+ }
+ epoch_pos = 0;
+
+ if (verbose)
+ {
+ // Capitalize the E in Epoch so it's easy to grep out the lines that
+ // are for full epoch status statements.
+ std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
+ << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
+ << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ print_progress();
+ }
+ }
+ wait_for_thread_to_pause();
+ // if we modified the network at all then be sure to sync the final result.
+ sync_to_disk(true);
+ }
+
+ void set_synchronization_file (
+ const std::string& filename,
+ std::chrono::seconds time_between_syncs_ = std::chrono::minutes(15)
+ )
+ {
+ last_sync_time = std::chrono::system_clock::now();
+ sync_filename = filename;
+ time_between_syncs = time_between_syncs_;
+
+ // check if the sync file already exists, if it does we should load it.
+ std::ifstream fin(newest_syncfile(), std::ios::binary);
+ if (fin)
+ deserialize(*this, fin);
+ }
+
+ const std::string& get_synchronization_file (
+ )
+ {
+ return sync_filename;
+ }
+
+ double get_average_loss (
+ ) const
+ {
+ wait_for_thread_to_pause();
+ return rs.mean();
+ }
+
+ double get_average_test_loss (
+ ) const
+ {
+ wait_for_thread_to_pause();
+ return rs_test.mean();
+ }
+
+ void clear_average_loss (
+ )
+ {
+ wait_for_thread_to_pause();
+ rs.clear();
+ }
+
+ void set_learning_rate (
+ double lr
+ )
+ {
+ DLIB_CASSERT(lr > 0);
+ wait_for_thread_to_pause();
+ if (learning_rate != lr)
+ {
+ steps_without_progress = 0;
+ test_steps_without_progress = 0;
+ previous_loss_values.clear();
+ test_previous_loss_values.clear();
+ }
+ learning_rate = lr;
+ lr_schedule.set_size(0);
+ }
+
+ double get_learning_rate(
+ ) const
+ {
+ return learning_rate;
+ }
+
+ void set_min_learning_rate (
+ double lr
+ )
+ {
+ DLIB_CASSERT(lr > 0);
+ wait_for_thread_to_pause();
+ lr_schedule.set_size(0);
+ min_learning_rate = lr;
+ }
+
+ double get_min_learning_rate (
+ ) const
+ {
+ return min_learning_rate;
+ }
+
+ template <typename EXP>
+ void set_learning_rate_schedule (
+ const matrix_exp<EXP>& schedule
+ )
+ {
+ DLIB_CASSERT(schedule.size() > 0);
+ DLIB_CASSERT(min(schedule) > 0);
+ set_learning_rate(schedule(0,0));
+ set_min_learning_rate(min(schedule));
+ set_learning_rate_shrink_factor(1);
+ lr_schedule = matrix_cast<double>(reshape_to_column_vector(schedule));
+ lr_schedule_pos = 0;
+ }
+
+ const matrix<double,0,1>& get_learning_rate_schedule (
+ ) const
+ {
+ return lr_schedule;
+ }
+
+ void set_iterations_without_progress_threshold (
+ unsigned long thresh
+ )
+ {
+ wait_for_thread_to_pause();
+ lr_schedule.set_size(0);
+ iter_without_progress_thresh = thresh;
+ }
+
+ unsigned long get_iterations_without_progress_threshold (
+ ) const
+ {
+ return iter_without_progress_thresh;
+ }
+
+ unsigned long get_steps_without_progress (
+ ) const
+ {
+ return steps_without_progress;
+ }
+
+ void set_test_iterations_without_progress_threshold (
+ unsigned long thresh
+ )
+ {
+ wait_for_thread_to_pause();
+ lr_schedule.set_size(0);
+ test_iter_without_progress_thresh = thresh;
+ }
+
+ unsigned long get_test_iterations_without_progress_threshold (
+ ) const
+ {
+ return test_iter_without_progress_thresh;
+ }
+
+ unsigned long get_test_steps_without_progress (
+ ) const
+ {
+ return test_steps_without_progress;
+ }
+
+ void set_learning_rate_shrink_factor (
+ double shrink
+ )
+ {
+ DLIB_CASSERT(0 < shrink && shrink <= 1);
+ wait_for_thread_to_pause();
+ lr_schedule.set_size(0);
+ learning_rate_shrink = shrink;
+ steps_without_progress = 0;
+ test_steps_without_progress = 0;
+ }
+
+ double get_learning_rate_shrink_factor (
+ ) const
+ {
+ return learning_rate_shrink;
+ }
+
+ unsigned long long get_train_one_step_calls (
+ ) const
+ {
+ return train_one_step_calls;
+ }
+
+ unsigned long long get_test_one_step_calls (
+ ) const
+ {
+ return test_one_step_calls;
+ }
+
+ private:
+
+ void record_test_loss(double loss)
+ {
+ test_previous_loss_values.push_back(loss);
+ if (is_finite(loss))
+ rs_test.add(loss);
+ // discard really old loss values.
+ while (test_previous_loss_values.size() > test_iter_without_progress_thresh)
+ test_previous_loss_values.pop_front();
+ }
+
+ void record_loss(double loss)
+ {
+ // This kind of budgeting causes our gradient checking to use a fixed amount of
+ // computational resources, regardless of the size of iter_without_progress_thresh.
+ gradient_check_budget += 200;
+
+ rs.add(loss);
+ previous_loss_values.push_back(loss);
+ // discard really old loss values.
+ while (previous_loss_values.size() > iter_without_progress_thresh)
+ previous_loss_values.pop_front();
+ }
+
+ template <typename T>
+ double compute_parameter_gradients(size_t device, job_t& next_job, const T&)
+ {
+ if (next_job.have_data[device])
+ {
+ auto&& dev = *devices[device];
+ dlib::cuda::set_device(dev.device_id);
+ if (next_job.test_only)
+ return dev.net.compute_loss(next_job.t[device], next_job.labels[device].begin());
+ else
+ return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ double compute_parameter_gradients(size_t device, job_t& next_job, const no_label_type&)
+ {
+ if (next_job.have_data[device])
+ {
+ auto&& dev = *devices[device];
+ dlib::cuda::set_device(dev.device_id);
+ no_label_type pick_which_run_update;
+ if (next_job.test_only)
+ return dev.net.compute_loss(next_job.t[device]);
+ else
+ return dev.net.compute_parameter_gradients(next_job.t[device]);
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ void update_parameters(size_t device)
+ {
+ auto&& dev = *devices[device];
+ dlib::cuda::set_device(dev.device_id);
+ dev.net.update_parameters(make_sstack(dev.solvers), learning_rate);
+ }
+
+ void thread() try
+ {
+ training_label_type pick_which_run_update;
+ job_t next_job;
+
+ std::vector<dlib::future<double>> losses(devices.size());
+
+ std::vector<tt::multi_device_tensor_averager> averagers;
+ // An array of all the parameter tensors in the first network. We will
+ // periodically copy these tensors to all the other devices to make sure the
+ // different GPUs don't go out of sync.
+ std::vector<tensor*> reference_params;
+ visit_layer_parameters(devices[0]->net, [&](size_t, tensor& t) { reference_params.push_back(&t); });
+
+ // We make separate thread pools with just one thread in them because we want
+ // to make sure each device is always executed on the same thread. We care
+ // about this because there are thread_local context variables for some cuda
+ // components and they get allocated for each combination of thread and device.
+ // So if we make sure the same device always uses the same thread this will
+ // reduce the number of contexts we allocate from num_devices*num_devices to
+ // just num_devices.
+ std::vector<std::shared_ptr<thread_pool>> tp;
+ for (size_t i = 0; i < devices.size(); ++i)
+ tp.push_back(std::make_shared<thread_pool>(1));
+
+
+ main_iteration_counter = 0;
+ while(job_pipe.dequeue(next_job))
+ {
+ if (next_job.test_only)
+ {
+ // compute the testing loss
+ for (size_t i = 0; i < devices.size(); ++i)
+ tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
+ // aggregate loss values from all the network computations.
+ double theloss = 0;
+ for (auto&& loss : losses)
+ theloss += loss.get();
+ record_test_loss(theloss/losses.size());
+
+ // Check if we should shrink the learning rate based on how the test
+ // error has been doing lately.
+ if (learning_rate_shrink != 1)
+ {
+ test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
+ if (test_steps_without_progress >= test_iter_without_progress_thresh)
+ {
+ test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values);
+ if (test_steps_without_progress >= test_iter_without_progress_thresh)
+ {
+ // optimization has flattened out, so drop the learning rate.
+ learning_rate = learning_rate_shrink*learning_rate;
+ test_steps_without_progress = 0;
+ // Empty out some of the previous loss values so that test_steps_without_progress
+ // will decrease below test_iter_without_progress_thresh.
+ for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
+ test_previous_loss_values.pop_front();
+ }
+ }
+ }
+ continue;
+ }
+
+ updated_net_since_last_sync = true;
+ ++main_iteration_counter;
+ // Call compute_parameter_gradients() and update_parameters() but pick the
+ // right version for unsupervised or supervised training based on the type
+ // of training_label_type.
+ for (size_t i = 0; i < devices.size(); ++i)
+ tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
+ // aggregate loss values from all the network computations.
+ double theloss = 0;
+ for (auto&& loss : losses)
+ theloss += loss.get();
+ record_loss(theloss/losses.size());
+
+ // Now, if there is more than one active device we need to synchronize the
+ // gradient updates between devices. So we do that now.
+ if (devices.size() > 1)
+ {
+ // if this is the first iteration then we need to setup the averagers.
+ // We can't do this outside the loop because the tensors that get
+ // averaged need to be allocated to their devices before we call set()
+ // so that the averagers can determine how best to average them.
+ if (averagers.size() == 0 || sync_file_reloaded)
+ {
+ averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
+ // setup the averagers to point to the tensors in the networks.
+ std::vector<std::vector<tensor*>> all_tensors(devices.size());
+ for (size_t i = 0; i < all_tensors.size(); ++i)
+ {
+ all_tensors[i].resize(net_type::num_computational_layers);
+ visit_layer_parameter_gradients(devices[i]->net, [&](size_t j, tensor& t){
+ all_tensors[i][j] = &t;
+ });
+ }
+ // Now set each averager to average the tensors at the same layer in each
+ // network.
+ for (size_t i = 0; i < net_type::num_computational_layers; ++i)
+ {
+ std::vector<tensor*> temp(all_tensors.size());
+ for (size_t j = 0; j < all_tensors.size(); ++j)
+ temp[j] = all_tensors[j][i];
+ // ignore layers that don't have parameters
+ if (temp[0]->size() != 0)
+ averagers[i].set(temp);
+ }
+
+ sync_file_reloaded = false;
+ }
+
+
+ for (auto&& d : devices)
+ cuda::device_synchronize(d->device_id);
+
+ for (auto&& avg : averagers)
+ avg.average();
+ }
+
+
+ // Now apply all the updates to each device.
+ for (size_t i = 0; i < devices.size(); ++i)
+ tp[i]->add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); });
+ // and wait for the updates to all happen.
+ for (size_t i = 0; i < devices.size(); ++i)
+ tp[i]->wait_for_all_tasks();
+
+
+ // Every now and then force all the parameters to be the same just to make
+ // sure they aren't drifting apart due to any non-deterministic behavior on
+ // the GPU. It's also important to do this on the first iteration because
+ // the different networks may be initialized differently when tensor data
+ // is first passed through them. So this code block deals with these
+ // issues.
+ if (devices.size() > 1 && main_iteration_counter%2000 == 1)
+ {
+ for (size_t i = 1; i < devices.size(); ++i)
+ {
+ visit_layer_parameters(devices[i]->net, [&](size_t j, tensor& t)
+ {
+ memcpy(t, *reference_params[j]);
+ });
+ }
+ }
+
+ // If we have been running for a while then check if the loss is still
+ // dropping. If it isn't then we will reduce the learning rate. Note that we
+ // have a "budget" that prevents us from calling
+ // count_steps_without_decrease() every iteration. We do this because
+ // it can be expensive to compute when previous_loss_values is large.
+ if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1)
+ {
+ gradient_check_budget = 0;
+ steps_without_progress = count_steps_without_decrease(previous_loss_values);
+ if (steps_without_progress >= iter_without_progress_thresh)
+ {
+ // Double check that we aren't seeing decrease. This second check
+ // discards the top 10% largest values and checks again. We do
+ // this because sometimes a mini-batch might be bad and cause the
+ // loss to suddenly jump up, making count_steps_without_decrease()
+ // return a large number. But if we discard the top 10% of the
+ // values in previous_loss_values then we are robust to that kind
+ // of noise. Another way of looking at it, if the reason
+ // count_steps_without_decrease() returns a large value is only
+ // because the most recent loss values have suddenly been large,
+ // then we shouldn't stop or lower the learning rate. We should
+ // keep going until whatever disturbance we hit is damped down.
+ steps_without_progress = count_steps_without_decrease_robust(previous_loss_values);
+ if (steps_without_progress >= iter_without_progress_thresh)
+ {
+ // optimization has flattened out, so drop the learning rate.
+ learning_rate = learning_rate_shrink*learning_rate;
+ steps_without_progress = 0;
+ // Empty out some of the previous loss values so that steps_without_progress
+ // will decrease below iter_without_progress_thresh.
+ for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
+ previous_loss_values.pop_front();
+ }
+ }
+ }
+ else if (lr_schedule.size() != 0) // or use the learning rate schedule if we have one.
+ {
+ if (lr_schedule_pos < lr_schedule.size())
+ learning_rate = lr_schedule(lr_schedule_pos++);
+ else
+ learning_rate = lr_schedule(lr_schedule.size()-1)*0.99;
+ }
+ }
+ }
+ catch(...)
+ {
+ // If an exception happens then permanently disable the trainer object.
+ job_pipe.disable();
+ std::lock_guard<std::mutex> lock(eptr_mutex);
+ eptr = std::current_exception();
+ }
+
+ void wait_for_thread_to_pause() const
+ {
+ job_pipe.wait_for_num_blocked_dequeues(1);
+ }
+
+ const static long string_pad = 11;
+ const static long epoch_string_pad = 4;
+ const static long lr_string_pad = 4;
+
+ void init()
+ {
+ max_num_epochs = 10000;
+ mini_batch_size = 128;
+ verbose = false;
+ learning_rate = 1e-2;
+ min_learning_rate = 1e-5;
+ iter_without_progress_thresh = 2000;
+ steps_without_progress = 0;
+ test_iter_without_progress_thresh = 500;
+ test_steps_without_progress = 0;
+
+ learning_rate_shrink = 0.1;
+ epoch_iteration = 0;
+ epoch_pos = 0;
+ train_one_step_calls = 0;
+ test_one_step_calls = 0;
+ gradient_check_budget = 0;
+ lr_schedule_pos = 0;
+
+ main_iteration_counter = 0;
+ main_iteration_counter_at_last_disk_sync = 0;
+ prob_loss_increasing_thresh_default_value = 0.99;
+ prob_loss_increasing_thresh_max_value = 0.99999;
+ prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
+ updated_net_since_last_sync = false;
+ sync_file_reloaded = false;
+ previous_loss_values_dump_amount = 400;
+ test_previous_loss_values_dump_amount = 100;
+
+ rs_test = running_stats_decayed<double>(200);
+
+ start();
+ }
+
+ // serialize and deserialize are private because we hold net by reference so
+ // allowing someone to serialize this training object is weird and will likely
+ // result in user errors. However, we use these functions as part of the automatic
+ // sync code in this object.
+ friend void serialize(const dnn_trainer& item, std::ostream& out)
+ {
+ item.wait_for_thread_to_pause();
+ int version = 12;
+ serialize(version, out);
+
+ size_t nl = dnn_trainer::num_layers;
+ serialize(nl, out);
+ serialize(item.rs, out);
+ serialize(item.rs_test, out);
+ serialize(item.previous_loss_values, out);
+ serialize(item.max_num_epochs, out);
+ serialize(item.mini_batch_size, out);
+ serialize(item.verbose, out);
+ serialize(item.net, out);
+ serialize(item.devices[0]->solvers, out);
+ serialize(item.learning_rate.load(), out);
+ serialize(item.min_learning_rate, out);
+ serialize(item.iter_without_progress_thresh.load(), out);
+ serialize(item.steps_without_progress.load(), out);
+ serialize(item.learning_rate_shrink.load(), out);
+ serialize(item.epoch_iteration, out);
+ serialize(item.epoch_pos, out);
+ serialize(item.train_one_step_calls, out);
+ serialize(item.test_one_step_calls, out);
+ serialize(item.lr_schedule, out);
+ serialize(item.lr_schedule_pos, out);
+ serialize(item.test_iter_without_progress_thresh.load(), out);
+ serialize(item.test_steps_without_progress.load(), out);
+ serialize(item.test_previous_loss_values, out);
+ serialize(item.previous_loss_values_dump_amount, out);
+ serialize(item.test_previous_loss_values_dump_amount, out);
+
+ }
+ friend void deserialize(dnn_trainer& item, std::istream& in)
+ {
+ item.wait_for_thread_to_pause();
+ int version = 0;
+ deserialize(version, in);
+ if (version != 12)
+ throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
+
+ size_t num_layers = 0;
+ deserialize(num_layers, in);
+ if (num_layers != dnn_trainer::num_layers)
+ {
+ std::ostringstream sout;
+ sout << "Error deserializing dlib::dnn_trainer. The saved sync file is for a network with " << std::endl;
+ sout << "a different number of layers. We expected the number of layers to be " << dnn_trainer::num_layers << " but" << std::endl;
+ sout << "instead the file contains " << num_layers << " layers." << std::endl;
+ throw serialization_error(sout.str());
+ }
+
+ double dtemp; long ltemp;
+ deserialize(item.rs, in);
+ deserialize(item.rs_test, in);
+ deserialize(item.previous_loss_values, in);
+ deserialize(item.max_num_epochs, in);
+ deserialize(item.mini_batch_size, in);
+ deserialize(item.verbose, in);
+ deserialize(item.net, in);
+ deserialize(item.devices[0]->solvers, in);
+ deserialize(dtemp, in); item.learning_rate = dtemp;
+ deserialize(item.min_learning_rate, in);
+ deserialize(ltemp, in); item.iter_without_progress_thresh = ltemp;
+ deserialize(ltemp, in); item.steps_without_progress = ltemp;
+ deserialize(dtemp, in); item.learning_rate_shrink = dtemp;
+ deserialize(item.epoch_iteration, in);
+ deserialize(item.epoch_pos, in);
+ deserialize(item.train_one_step_calls, in);
+ deserialize(item.test_one_step_calls, in);
+ deserialize(item.lr_schedule, in);
+ deserialize(item.lr_schedule_pos, in);
+ deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
+ deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
+ deserialize(item.test_previous_loss_values, in);
+ deserialize(item.previous_loss_values_dump_amount, in);
+ deserialize(item.test_previous_loss_values_dump_amount, in);
+
+ if (item.devices.size() > 1)
+ {
+ const auto prev_dev = dlib::cuda::get_device();
+ // initialize all the other device networks and solver objects
+ for (size_t i = 1; i < item.devices.size(); ++i)
+ {
+ // Switch to this device so that any tensor objects that get allocated when
+ // we copy this stuff happen on this device.
+ dlib::cuda::set_device(item.devices[i]->device_id);
+ item.devices[i]->solvers = item.devices[0]->solvers;
+ item.devices[i]->net = item.devices[0]->net;
+ }
+ dlib::cuda::set_device(prev_dev);
+ }
+ }
+
+ void sync_to_disk (
+ bool do_it_now = false
+ )
+ {
+ // don't sync anything if we haven't updated the network since the last sync
+ if (!updated_net_since_last_sync)
+ return;
+
+ // If the sync file isn't set then don't do anything.
+ if (sync_filename.size() == 0)
+ return;
+
+ // Only sync if it has been long enough since the last sync or we are being
+ // explicitly forced to do it.
+ if (std::chrono::system_clock::now() - last_sync_time > time_between_syncs ||
+ do_it_now)
+ {
+ wait_for_thread_to_pause();
+
+ // compact network before saving to disk.
+ this->net.clean();
+
+ // if the loss has actually been going up since the last time we saved our
+ // state to disk then something has probably gone wrong in the
+ // optimization. So in this case we do the opposite and recall the
+ // previously saved state in the hopes that the problem won't reoccur.
+ if (loss_increased_since_last_disk_sync())
+ {
+ std::ifstream fin(newest_syncfile(), std::ios::binary);
+ deserialize(*this, fin);
+ sync_file_reloaded = true;
+ if (verbose)
+ std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl;
+ }
+ else
+ {
+
+ const std::string filename = oldest_syncfile();
+ serialize(filename) << *this;
+
+ if (verbose)
+ std::cout << "Saved state to " << filename << std::endl;
+ }
+
+ last_sync_time = std::chrono::system_clock::now();
+ main_iteration_counter_at_last_disk_sync = main_iteration_counter;
+ updated_net_since_last_sync = false;
+ }
+ }
+
+ std::string newest_syncfile (
+ )
+ {
+ return select_newest_file(sync_filename, sync_filename + "_");
+ }
+
+ std::string oldest_syncfile (
+ )
+ {
+ return select_oldest_file(sync_filename, sync_filename + "_");
+ }
+
+ bool loss_increased_since_last_disk_sync()
+ {
+ size_t gradient_updates_since_last_sync = main_iteration_counter - main_iteration_counter_at_last_disk_sync;
+
+ // if we haven't synced anything to disk yet then return false.
+ if (!std::ifstream(newest_syncfile(), std::ios::binary))
+ return false;
+
+ for (auto x : previous_loss_values)
+ {
+ // If we get a NaN value of loss assume things have gone horribly wrong and
+ // we should reload the state of the trainer.
+ if (std::isnan(x))
+ return true;
+ }
+
+ // if we haven't seen much data yet then just say false. Or, alternatively, if
+ // it's been too long since the last sync then don't reload either.
+ if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
+ return false;
+
+ // Now look at the data since a little before the last disk sync. We will
+ // check if the loss is getting bettor or worse.
+ running_gradient g;
+ for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
+ g.add(previous_loss_values[i]);
+
+ // if the loss is very likely to be increasing then return true
+ const double prob = g.probability_gradient_greater_than(0);
+ if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
+ {
+ // Exponentially decay the threshold towards 1 so that if we keep finding
+ // the loss to be increasing over and over we will make the test
+ // progressively harder and harder until it fails, therefore ensuring we
+ // can't get stuck reloading from a previous state over and over.
+ prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
+ return true;
+ }
+ else
+ {
+ // decay back to the default threshold
+ prob_loss_increasing_thresh = std::pow(prob_loss_increasing_thresh, 10.0);
+ // but don't decay below the default value
+ prob_loss_increasing_thresh = std::max(prob_loss_increasing_thresh, prob_loss_increasing_thresh_default_value);
+
+ return false;
+ }
+ }
+
+
+ struct clone_net{};
+
+ // per device state. All the containers have the same number of objects in them.
+ struct device_data
+ {
+ device_data(
+ int device_id_,
+ net_type& net_,
+ const solver_type& solver_
+ ) : device_id(device_id_), net(net_), solvers(num_computational_layers, solver_) {}
+
+ device_data(
+ int device_id_,
+ net_type& net_,
+ const solver_type& solver_,
+ clone_net
+ ) : device_id(device_id_), net_copy(std::make_shared<net_type>(net_)), net(*net_copy), solvers(num_computational_layers, solver_) {}
+
+ int device_id;
+ std::shared_ptr<net_type> net_copy;
+ net_type& net;
+ std::vector<solver_type> solvers;
+ };
+
+ template <
+ typename data_iterator,
+ typename label_iterator
+ >
+ void send_job (
+ bool test_only,
+ data_iterator dbegin,
+ data_iterator dend,
+ label_iterator lbegin
+ )
+ {
+ propagate_exception();
+ size_t num = std::distance(dbegin, dend);
+ size_t devs = devices.size();
+ job.t.resize(devs);
+ job.labels.resize(devs);
+ job.have_data.resize(devs);
+ job.test_only = test_only;
+
+ // chop the data into devs blocks, each of about block_size elements.
+ size_t block_size = (num+devs-1)/devs;
+
+ const auto prev_dev = dlib::cuda::get_device();
+ for (size_t i = 0; i < devs; ++i)
+ {
+ dlib::cuda::set_device(devices[i]->device_id);
+
+ size_t start = i*block_size;
+ size_t stop = std::min(num, start+block_size);
+
+ if (start < stop)
+ {
+ devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]);
+ job.labels[i].assign(lbegin+start, lbegin+stop);
+ job.have_data[i] = true;
+ }
+ else
+ {
+ job.have_data[i] = false;
+ }
+ }
+
+ dlib::cuda::set_device(prev_dev);
+ job_pipe.enqueue(job);
+ }
+
+ template <
+ typename data_iterator
+ >
+ void send_job (
+ bool test_only,
+ data_iterator dbegin,
+ data_iterator dend
+ )
+ {
+ typename std::vector<training_label_type>::iterator nothing;
+ send_job(test_only, dbegin, dend, nothing);
+ }
+
+ void print_progress()
+ {
+ if (lr_schedule.size() == 0)
+ {
+ if (test_previous_loss_values.size() == 0)
+ std::cout << "steps without apparent progress: " << steps_without_progress;
+ else
+ std::cout << "steps without apparent progress: train=" << steps_without_progress << ", test=" << test_steps_without_progress;
+ }
+ else
+ {
+ std::ostringstream sout;
+ sout << "percent complete: " << std::fixed << std::setprecision(2) << 100.0*lr_schedule_pos/(double)lr_schedule.size() << "%";
+ std::cout << sout.str();
+ }
+ std::cout << std::endl;
+ }
+
+ void print_periodic_verbose_status()
+ {
+ if (verbose)
+ {
+ using namespace std::chrono;
+ auto now_time = system_clock::now();
+ if (now_time-last_time > seconds(40))
+ {
+ last_time = now_time;
+ std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
+ << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " ";
+ if (test_previous_loss_values.size() == 0)
+ {
+ std::cout << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ }
+ else
+ {
+ std::cout << "train loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
+ std::cout << "test loss: " << rpad(cast_to_string(get_average_test_loss()),string_pad) << " ";
+ }
+ print_progress();
+ clear_average_loss();
+ }
+ }
+ }
+
+ std::vector<std::shared_ptr<device_data>> devices;
+ dlib::pipe<job_t> job_pipe;
+ job_t job;
+
+
+ running_stats<double> rs;
+ running_stats_decayed<double> rs_test;
+ std::deque<double> previous_loss_values;
+ unsigned long max_num_epochs;
+ size_t mini_batch_size;
+ bool verbose;
+ net_type& net;
+ std::atomic<double> learning_rate;
+ double min_learning_rate;
+ std::atomic<unsigned long> iter_without_progress_thresh;
+ std::atomic<unsigned long> steps_without_progress;
+
+ std::atomic<unsigned long> test_iter_without_progress_thresh;
+ std::atomic<unsigned long> test_steps_without_progress;
+ std::deque<double> test_previous_loss_values;
+
+ std::atomic<double> learning_rate_shrink;
+ std::chrono::time_point<std::chrono::system_clock> last_sync_time;
+ std::string sync_filename;
+ std::chrono::seconds time_between_syncs;
+ unsigned long epoch_iteration;
+ size_t epoch_pos;
+ std::chrono::time_point<std::chrono::system_clock> last_time;
+ unsigned long long train_one_step_calls;
+ unsigned long long test_one_step_calls;
+ matrix<double,0,1> lr_schedule;
+ long lr_schedule_pos;
+ unsigned long gradient_check_budget;
+
+ std::exception_ptr eptr = nullptr;
+ mutable std::mutex eptr_mutex;
+ void propagate_exception() const
+ {
+ std::lock_guard<std::mutex> lock(eptr_mutex);
+ if (eptr)
+ std::rethrow_exception(eptr);
+ }
+
+ // These 5 variables are not serialized
+ size_t main_iteration_counter;
+ size_t main_iteration_counter_at_last_disk_sync;
+ double prob_loss_increasing_thresh_default_value;
+ double prob_loss_increasing_thresh_max_value;
+ double prob_loss_increasing_thresh;
+ std::atomic<bool> updated_net_since_last_sync;
+
+ bool sync_file_reloaded;
+ unsigned long previous_loss_values_dump_amount;
+ unsigned long test_previous_loss_values_dump_amount;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename solver_type
+ >
+ std::ostream& operator<< (
+ std::ostream& out,
+ dnn_trainer<net_type,solver_type>& trainer
+ )
+ {
+ using std::endl;
+ out << "dnn_trainer details: \n";
+ out << " net_type::num_layers: " << net_type::num_layers << endl;
+ // figure out how big the net is in MB.
+ std::ostringstream sout;
+ net_type temp = trainer.get_net(); // make a copy so that we can clean it without mutating the trainer's net.
+ temp.clean();
+ serialize(temp, sout);
+ out << " net size: " << sout.str().size()/1024.0/1024.0 << "MB" << endl;
+ // Don't include the loss params in the hash since we print them on the next line.
+ // They also aren't really part of the "architecture" of the network.
+ out << " net architecture hash: " << md5(cast_to_string(trainer.get_net().subnet())) << endl;
+ out << " loss: " << trainer.get_net().loss_details() << endl;
+
+ out << " synchronization file: " << trainer.get_synchronization_file() << endl;
+ out << " trainer.get_solvers()[0]: " << trainer.get_solvers()[0] << endl;
+ auto sched = trainer.get_learning_rate_schedule();
+ if (sched.size() != 0)
+ {
+ out << " using explicit user-supplied learning rate schedule" << endl;
+ }
+ else
+ {
+ out << " learning rate: "<< trainer.get_learning_rate() << endl;
+ out << " learning rate shrink factor: "<< trainer.get_learning_rate_shrink_factor() << endl;
+ out << " min learning rate: "<< trainer.get_min_learning_rate() << endl;
+ out << " iterations without progress threshold: "<< trainer.get_iterations_without_progress_threshold() << endl;
+ out << " test iterations without progress threshold: "<< trainer.get_test_iterations_without_progress_threshold() << endl;
+ }
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_TRAINER_H_
+
diff --git a/ml/dlib/dlib/dnn/trainer_abstract.h b/ml/dlib/dlib/dnn/trainer_abstract.h
new file mode 100644
index 000000000..3bfb6dc99
--- /dev/null
+++ b/ml/dlib/dlib/dnn/trainer_abstract.h
@@ -0,0 +1,765 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_TRAINER_ABSTRACT_H_
+#ifdef DLIB_DNn_TRAINER_ABSTRACT_H_
+
+#include "core_abstract.h"
+#include "solvers_abstract.h"
+#include <vector>
+#include <chrono>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ enum class force_flush_to_disk {
+ no = 0,
+ yes = 1
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename solver_type = sgd
+ >
+ class dnn_trainer
+ {
+ /*!
+ REQUIREMENTS ON net_type
+ - net_type is an add_loss_layer object.
+
+ REQUIREMENTS ON solver_type
+ - solver_type is an implementation of the EXAMPLE_SOLVER interface defined
+ in solvers_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool training a deep neural network. To use it you supply
+ a neural network type and a solver, then you call train() with your
+ training data and it will output a new network instance that has hopefully
+ learned something useful from your training data.
+
+ If you are compiling with CUDA then this object will use the GPU that is
+ currently selected (i.e. the one indicated by cudaGetDevice()) when
+ dnn_trainer is constructed. It will continue to use that device even if
+ you later change it by a call to cudaSetDevice().
+
+ EXCEPTIONS
+ If an exception is thrown by any part of the neural network during training
+ then the exception will be propagated out of the trainer to the user.
+ Moreover, the trainer instance will be unusable and should be destroyed.
+ !*/
+
+ public:
+
+ typedef typename net_type::training_label_type training_label_type;
+ typedef typename net_type::input_type input_type;
+ const static size_t num_computational_layers = net_type::num_computational_layers;
+
+ dnn_trainer() = delete;
+ dnn_trainer(const dnn_trainer&) = delete;
+ dnn_trainer& operator=(const dnn_trainer&) = delete;
+
+ dnn_trainer(
+ net_type& net,
+ const solver_type& solver = solver_type(),
+ const std::vector<int>& cuda_extra_devices = {}
+ );
+ /*!
+ requires
+ - for all valid i:
+ - 0 <= cuda_extra_devices[i] < dlib::cuda::get_num_devices()
+ ensures
+ - &#get_net() == &net
+ (i.e. The dnn_trainer holds a reference to net, it does not copy it.
+ Therefore, you must ensure net has a lifetime at least as long as the
+ dnn_trainer).
+ - #get_solvers() == a set of solvers that are all initialized with the
+ provided solver instance.
+ - #get_max_num_epochs() == 10000
+ - #get_mini_batch_size() == 128
+ - #get_learning_rate() == 1e-2
+ - #get_min_learning_rate() == 1e-5
+ - #get_iterations_without_progress_threshold() == 2000
+ - #get_test_iterations_without_progress_threshold() == 500
+ - #get_learning_rate_shrink_factor() == 0.1
+ - #get_learning_rate_schedule().size() == 0
+ - #get_train_one_step_calls() == 0
+ - #get_test_one_step_calls() == 0
+ - #get_synchronization_file() == ""
+ - if (cuda_extra_devices.size() > 0) then
+ - This object will use multiple graphics cards to run the learning
+ algorithms. In particular, it will always use whatever device is
+ currently selected on the calling thread (the device indicated by
+ cudaGetDevice()). In addition, you can ask to use additional
+ devices, which you do by putting their device numbers into
+ cuda_extra_devices.
+ !*/
+
+ net_type& get_net (
+ force_flush_to_disk force_flush = force_flush_to_disk::yes
+ );
+ /*!
+ ensures
+ - returns the neural network object used by this trainer. This is the
+ network that is optimized when you call train() or train_one_step().
+ Recall that the dnn_trainer doesn't contain the net_type object but
+ simply holds a reference to an external network which was provided to the
+ dnn_trainer's constructor.
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ - If force_flush is yes, then this function will sync the trainer state to
+ disk if the current state hasn't already been synced to disk since the
+ last network modification.
+ !*/
+
+ const std::vector<solver_type>& get_solvers (
+ ) const;
+ /*!
+ ensures
+ - returns the solvers used to optimize each layer of the neural network
+ get_net(). In particular, the first layer's solver is
+ get_solvers()[0], the second layer's solver is
+ get_solvers()[1], and so on.
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ unsigned long get_mini_batch_size (
+ ) const;
+ /*!
+ ensures
+ - During training, we call the network's update() routine over and over
+ with training data. The number of training samples we give to each call
+ to update is the "mini-batch size", which is defined by
+ get_mini_batch_size().
+ !*/
+
+ void set_mini_batch_size (
+ unsigned long batch_size
+ );
+ /*!
+ requires
+ - batch_size > 0
+ ensures
+ - #get_mini_batch_size() == batch_size
+ !*/
+
+ unsigned long get_max_num_epochs (
+ ) const;
+ /*!
+ ensures
+ - train() will execute at most get_max_num_epochs() iterations over the
+ training data before returning.
+ !*/
+
+ void set_max_num_epochs (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_max_num_epochs() == num
+ !*/
+
+ void set_learning_rate (
+ double lr
+ );
+ /*!
+ requires
+ - lr > 0
+ ensures
+ - #get_learning_rate() == lr
+ - #get_learning_rate_schedule().size() == 0
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ double get_learning_rate(
+ ) const;
+ /*!
+ ensures
+ - During each training step, a solver tells us how to modify the parameters
+ of each layer in the network. It does this by outputting a step vector
+ that, when added to the parameters, will hopefully result in improved
+ network performance. The learning rate is one of the inputs to the
+ solver and influences the size of this step vector. This function
+ returns the current learning rate, that is, the learning rate that will
+ be used during the next training step.
+ !*/
+
+ void set_min_learning_rate (
+ double lr
+ );
+ /*!
+ requires
+ - lr > 0
+ ensures
+ - #get_min_learning_rate() == lr
+ - #get_learning_rate_schedule().size() == 0
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ double get_min_learning_rate (
+ ) const;
+ /*!
+ ensures
+ - During training via this->train(), this object will test if progress is
+ still being made and if it isn't then it will reduce get_learning_rate()
+ by setting it to get_learning_rate()*get_learning_rate_shrink_factor().
+ However, it will not reduce it below get_min_learning_rate(). Once this
+ minimum learning rate is crossed the training will terminate.
+ - get_min_learning_rate() doesn't apply if you are using train_one_step().
+ You can keep calling train_one_step() as many times as you want and the
+ learning rate will drop infinitely close to 0 if you run long enough.
+ !*/
+
+ template <typename EXP>
+ void set_learning_rate_schedule (
+ const matrix_exp<EXP>& schedule
+ );
+ /*!
+ requires
+ - schedule.size() > 0
+ - min(schedule) > 0
+ ensures
+ - #get_learning_rate_schedule() == reshape_to_column_vector(schedule)
+ - #get_learning_rate() == schedule(0,0)
+ - #get_min_learning_rate() == min(schedule)
+ - #set_learning_rate_shrink_factor() == 1
+ !*/
+
+ const matrix<double,0,1>& get_learning_rate_schedule (
+ ) const;
+ /*!
+ ensures
+ - if (this function returns a non-empty matrix) then
+ - This trainer will use an explicit learning rate schedule defined by
+ the learning rate values in get_learning_rate_schedule(). For
+ example, if get_learning_rate_schedule() returned {0.1, 0.09, 0.08,
+ 0.07, 0.06} then the first training mini-batch would use a learning
+ rate of 0.1, then the next training mini-batch uses 0.09, and then
+ 0.8, and so on until the end of the schedule is reached.
+
+ If you continue to run training after the end of the schedule has
+ been reached then the learning rate will be fixed to 0.99 times the
+ final value. So in our example, eventually the learning rate would
+ be fixed to 0.99*0.06. This allows you to test if we have reached the
+ end of the schedule by checking if get_learning_rate() >= 0.06.
+ !*/
+
+ unsigned long get_steps_without_progress (
+ ) const;
+ /*!
+ ensures
+ - if (get_learning_rate_shrink_factor() != 1) then
+ - returns an estimate of how many mini-batches have executed without us
+ observing a statistically significant decrease in the training error.
+ - else
+ - returns 0
+ !*/
+
+ void set_iterations_without_progress_threshold (
+ unsigned long thresh
+ );
+ /*!
+ ensures
+ - #get_iterations_without_progress_threshold() == thresh
+ - #get_learning_rate_schedule().size() == 0
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ unsigned long get_iterations_without_progress_threshold (
+ ) const;
+ /*!
+ ensures
+ - This object monitors the progress of training and estimates if the
+ training error is being reduced. It does this by looking at the previous
+ get_iterations_without_progress_threshold() mini-batch results and
+ applying the statistical test defined by the running_gradient object to
+ see if the training error is getting smaller. If it isn't being reduced
+ then get_learning_rate() is made smaller by a factor of get_learning_rate_shrink_factor().
+
+ Therefore, get_iterations_without_progress_threshold() should always be
+ set to something sensibly large so that this test can be done with
+ reasonably high confidence. Think of this test as saying "if the loss
+ hasn't decreased for the previous get_iterations_without_progress_threshold()
+ then shrink the learning rate".
+ !*/
+
+ void set_learning_rate_shrink_factor (
+ double shrink
+ );
+ /*!
+ requires
+ - 0 < shrink && shrink <= 1
+ ensures
+ - #get_learning_rate_shrink_factor() == shrink
+ - #get_learning_rate_schedule().size() == 0
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ double get_learning_rate_shrink_factor (
+ ) const;
+ /*!
+ ensures
+ - Whenever the training routine thinks it isn't making progress anymore it
+ will reduce get_learning_rate() by multiplying it by get_learning_rate_shrink_factor().
+ - You can disable the automatic learning rate reduction by setting
+ get_learning_rate_shrink_factor() to 1.
+ !*/
+
+ unsigned long long get_train_one_step_calls (
+ ) const;
+ /*!
+ ensures
+ - returns the number of times train_one_step() has been called.
+ !*/
+
+ unsigned long long get_test_one_step_calls (
+ ) const;
+ /*!
+ ensures
+ - returns the number of times test_one_step() has been called.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - This object will not print anything to standard out
+ !*/
+
+ void set_synchronization_file (
+ const std::string& filename,
+ std::chrono::seconds time_between_syncs = std::chrono::minutes(15)
+ );
+ /*!
+ ensures
+ - #get_synchronization_file() == filename
+ - While training is running, either via train() or repeated calls to
+ train_one_step(), this object will save its entire state, including the
+ state of get_net(), to disk in the file named filename every
+ time_between_syncs seconds.
+ - If the filename file already exists then the state of this trainer will
+ be loaded from that file by this call to set_synchronization_file().
+ This allows you to resume a training session which was previously
+ interrupted.
+ - It should be noted that when saving, the trainer will alternate between
+ saving to a file called filename and another file called filename+"_".
+ We do this because it's possible that your computer might crash (not
+ because of dlib, just in general) before the data is safely saved to
+ disk. This way, you will always have a backup file if the write to disk
+ gets corrupted or is incomplete. Moreover, when loading, we will always
+ load from the newest of the two possible files.
+ !*/
+
+ const std::string& get_synchronization_file (
+ );
+ /*!
+ ensures
+ - Returns the name of the file the dnn_trainer will periodically save it's
+ state to. If the return value is "" then synchronization is disabled.
+ !*/
+
+ void train (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ );
+ /*!
+ requires
+ - data.size() == labels.size()
+ - data.size() > 0
+ - net_type uses a supervised loss.
+ i.e. net_type::training_label_type != no_label_type.
+ ensures
+ - Trains a supervised neural network based on the given training data.
+ The goal of training is to find the network parameters that minimize
+ get_net().compute_loss(data.begin(), data.end(), labels.begin()).
+ - The optimizer will run until get_learning_rate() < get_min_learning_rate()
+ or get_max_num_epochs() training epochs have been executed.
+ - Each layer in the network will be optimized by its corresponding solver
+ in get_solvers().
+ - Each call to train DOES NOT reinitialize the state of get_net() or
+ get_solvers(). That is, the existing state of the solvers and network is
+ the starting point for the optimization each time train() is called. In
+ particular, if you use the set_synchronization_file() method you can
+ resume an interrupted train() call by simply calling train() again and it
+ will pick up from the last synchronization point.
+ - You can obtain the average loss value during the final training epoch by
+ calling get_average_loss().
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ void train (
+ const std::vector<input_type>& data
+ );
+ /*!
+ requires
+ - data.size() > 0
+ - net_type uses an unsupervised loss.
+ i.e. net_type::training_label_type == no_label_type.
+ ensures
+ - Trains an unsupervised neural network based on the given training data.
+ The goal of training is to find the network parameters that minimize
+ get_net().compute_loss(data.begin(), data.end()).
+ - The optimizer will run until get_learning_rate() < get_min_learning_rate()
+ or get_max_num_epochs() training epochs have been executed.
+ - Each layer in the network will be optimized by its corresponding solver
+ in get_solvers().
+ - Each call to train DOES NOT reinitialize the state of get_net() or
+ get_solvers(). That is, the existing state of the solvers and network is
+ the starting point for the optimization each time train() is called. In
+ particular, if you use the set_synchronization_file() method you can
+ resume an interrupted train() call by simply calling train() again and it
+ will pick up from the last synchronization point.
+ - You can obtain the average loss value during the final training epoch by
+ calling get_average_loss().
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ void train_one_step (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ );
+ /*!
+ requires
+ - data.size() == labels.size()
+ - data.size() > 0
+ - net_type uses a supervised loss.
+ i.e. net_type::training_label_type != no_label_type.
+ ensures
+ - Performs one stochastic gradient update step based on the mini-batch of
+ data and labels supplied to this function. In particular, calling
+ train_one_step() in a loop is equivalent to calling the train() method
+ defined above. However, train_one_step() allows you to stream data from
+ disk into the training process while train() requires you to first load
+ all the training data into RAM. Otherwise, these training methods are
+ equivalent.
+ - You can observe the current average loss value by calling get_average_loss().
+ - The network training will happen in another thread. Therefore, after
+ calling this function you should call get_net() before you touch the net
+ object from the calling thread to ensure no other threads are still
+ accessing the network.
+ - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+ !*/
+
+ template <
+ typename data_iterator,
+ typename label_iterator
+ >
+ void train_one_step (
+ data_iterator dbegin,
+ data_iterator dend,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
+ - std::distance(dbegin, dend) > 0
+ - net_type uses a supervised loss.
+ i.e. net_type::training_label_type != no_label_type.
+ ensures
+ - Performs one stochastic gradient update step based on the mini-batch of
+ data and labels supplied to this function. In particular, calling
+ train_one_step() in a loop is equivalent to calling the train() method
+ defined above. However, train_one_step() allows you to stream data from
+ disk into the training process while train() requires you to first load
+ all the training data into RAM. Otherwise, these training methods are
+ equivalent.
+ - You can observe the current average loss value by calling get_average_loss().
+ - The network training will happen in another thread. Therefore, after
+ calling this function you should call get_net() before you touch the net
+ object from the calling thread to ensure no other threads are still
+ accessing the network.
+ - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+ !*/
+
+ void train_one_step (
+ const std::vector<input_type>& data
+ );
+ /*!
+ requires
+ - data.size() > 0
+ - net_type uses an unsupervised loss.
+ i.e. net_type::training_label_type == no_label_type.
+ ensures
+ - Performs one stochastic gradient update step based on the mini-batch of
+ data supplied to this function. In particular, calling train_one_step()
+ in a loop is equivalent to calling the train() method defined above.
+ However, train_one_step() allows you to stream data from disk into the
+ training process while train() requires you to first load all the
+ training data into RAM. Otherwise, these training methods are
+ equivalent.
+ - You can observe the current average loss value by calling get_average_loss().
+ - The network training will happen in another thread. Therefore, after
+ calling this function you should call get_net() before you touch the net
+ object from the calling thread to ensure no other threads are still
+ accessing the network.
+ - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+ !*/
+
+ template <
+ typename data_iterator
+ >
+ void train_one_step (
+ data_iterator dbegin,
+ data_iterator dend
+ );
+ /*!
+ requires
+ - std::distance(dbegin, dend) > 0
+ - net_type uses an unsupervised loss.
+ i.e. net_type::training_label_type == no_label_type.
+ ensures
+ - Performs one stochastic gradient update step based on the mini-batch of
+ data supplied to this function. In particular, calling train_one_step()
+ in a loop is equivalent to calling the train() method defined above.
+ However, train_one_step() allows you to stream data from disk into the
+ training process while train() requires you to first load all the
+ training data into RAM. Otherwise, these training methods are
+ equivalent.
+ - You can observe the current average loss value by calling get_average_loss().
+ - The network training will happen in another thread. Therefore, after
+ calling this function you should call get_net() before you touch the net
+ object from the calling thread to ensure no other threads are still
+ accessing the network.
+ - #get_train_one_step_calls() == get_train_one_step_calls() + 1.
+ !*/
+
+ double get_average_loss (
+ ) const;
+ /*!
+ ensures
+ - returns the average loss value observed during previous calls to
+ train_one_step() or train(). That is, the average output of
+ net_type::update() during the previous mini-batch updates.
+ - Note that, if be_verbose() has been called, then this object will
+ automatically call clear_average_loss() periodically when it logs the
+ loss to the console.
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ void clear_average_loss (
+ );
+ /*!
+ ensures
+ - #get_average_loss() == 0
+ - get_average_loss() uses a dlib::running_stats object to keep a running
+ average of the loss values seen during the previous mini-batch updates
+ applied during training. Calling clear_average_loss() resets the
+ running_stats object so it forgets about all previous loss values
+ observed.
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ // ----------------------
+
+ double get_average_test_loss (
+ ) const;
+ /*!
+ ensures
+ - returns the average loss value observed during previous calls to
+ test_one_step().
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ void test_one_step (
+ const std::vector<input_type>& data,
+ const std::vector<training_label_type>& labels
+ );
+ /*!
+ requires
+ - data.size() == labels.size()
+ - data.size() > 0
+ - net_type uses a supervised loss.
+ i.e. net_type::training_label_type != no_label_type.
+ ensures
+ - Runs the given data through the network and computes and records the loss.
+ - This call does not modify network parameters. The point of
+ test_one_step() is two fold, to allow you to observe the accuracy of the
+ network on hold out data during training, and to allow the trainer to
+ automatically adjust the learning rate when the test loss stops
+ improving. It should be noted that you are not required to use
+ test_one_step() at all, but if you want to do this kind of thing it is
+ available.
+ - You can observe the current average loss value by calling get_average_test_loss().
+ - The computation will happen in another thread. Therefore, after calling
+ this function you should call get_net() before you touch the net object
+ from the calling thread to ensure no other threads are still accessing
+ the network.
+ - #get_test_one_step_calls() == get_test_one_step_calls() + 1.
+ !*/
+
+ template <
+ typename data_iterator,
+ typename label_iterator
+ >
+ void test_one_step (
+ data_iterator dbegin,
+ data_iterator dend,
+ label_iterator lbegin
+ );
+ /*!
+ requires
+ - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable
+ - std::distance(dbegin, dend) > 0
+ - net_type uses a supervised loss.
+ i.e. net_type::training_label_type != no_label_type.
+ ensures
+ - Runs the given data through the network and computes and records the loss.
+ - This call does not modify network parameters. The point of
+ test_one_step() is two fold, to allow you to observe the accuracy of the
+ network on hold out data during training, and to allow the trainer to
+ automatically adjust the learning rate when the test loss stops
+ improving. It should be noted that you are not required to use
+ test_one_step() at all, but if you want to do this kind of thing it is
+ available.
+ - You can observe the current average loss value by calling get_average_test_loss().
+ - The computation will happen in another thread. Therefore, after calling
+ this function you should call get_net() before you touch the net object
+ from the calling thread to ensure no other threads are still accessing
+ the network.
+ - #get_test_one_step_calls() == get_test_one_step_calls() + 1.
+ !*/
+
+ void test_one_step (
+ const std::vector<input_type>& data
+ );
+ /*!
+ requires
+ - data.size() > 0
+ - net_type uses an unsupervised loss.
+ i.e. net_type::training_label_type == no_label_type.
+ ensures
+ - Runs the given data through the network and computes and records the loss.
+ - This call does not modify network parameters. The point of
+ test_one_step() is two fold, to allow you to observe the accuracy of the
+ network on hold out data during training, and to allow the trainer to
+ automatically adjust the learning rate when the test loss stops
+ improving. It should be noted that you are not required to use
+ test_one_step() at all, but if you want to do this kind of thing it is
+ available.
+ - You can observe the current average loss value by calling get_average_test_loss().
+ - The computation will happen in another thread. Therefore, after calling
+ this function you should call get_net() before you touch the net object
+ from the calling thread to ensure no other threads are still accessing
+ the network.
+ - #get_test_one_step_calls() == get_test_one_step_calls() + 1.
+ !*/
+
+ template <
+ typename data_iterator
+ >
+ void test_one_step (
+ data_iterator dbegin,
+ data_iterator dend
+ );
+ /*!
+ requires
+ - std::distance(dbegin, dend) > 0
+ - net_type uses an unsupervised loss.
+ i.e. net_type::training_label_type == no_label_type.
+ ensures
+ - Runs the given data through the network and computes and records the loss.
+ - This call does not modify network parameters. The point of
+ test_one_step() is two fold, to allow you to observe the accuracy of the
+ network on hold out data during training, and to allow the trainer to
+ automatically adjust the learning rate when the test loss stops
+ improving. It should be noted that you are not required to use
+ test_one_step() at all, but if you want to do this kind of thing it is
+ available.
+ - You can observe the current average loss value by calling get_average_test_loss().
+ - The computation will happen in another thread. Therefore, after calling
+ this function you should call get_net() before you touch the net object
+ from the calling thread to ensure no other threads are still accessing
+ the network.
+ - #get_test_one_step_calls() == get_test_one_step_calls() + 1.
+ !*/
+
+ void set_test_iterations_without_progress_threshold (
+ unsigned long thresh
+ );
+ /*!
+ ensures
+ - #get_test_iterations_without_progress_threshold() == thresh
+ - #get_learning_rate_schedule().size() == 0
+ - This function blocks until all threads inside the dnn_trainer have
+ stopped touching the net.
+ !*/
+
+ unsigned long get_test_iterations_without_progress_threshold (
+ ) const;
+ /*!
+ ensures
+ - This object monitors the progress of training and estimates if the
+ testing error is being reduced. It does this by looking at the previous
+ get_test_iterations_without_progress_threshold() mini-batch results from
+ test_one_step() and applying the statistical test defined by the
+ running_gradient object to see if the testing error is getting smaller.
+ If it isn't being reduced then get_learning_rate() is made smaller by a
+ factor of get_learning_rate_shrink_factor().
+
+ Therefore, get_test_iterations_without_progress_threshold() should always be
+ set to something sensibly large so that this test can be done with
+ reasonably high confidence. Think of this test as saying "if the testing loss
+ hasn't decreased for the previous get_test_iterations_without_progress_threshold()
+ calls to test_one_step() then shrink the learning rate".
+ !*/
+
+ unsigned long get_test_steps_without_progress (
+ ) const;
+ /*!
+ ensures
+ - if (get_learning_rate_shrink_factor() != 1) then
+ - returns an estimate of how many mini-batches have executed without us
+ observing a statistically significant decrease in the testing error
+ (i.e. the error on the data given to the trainer via test_one_step()
+ calls).
+ - else
+ - returns 0
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename net_type,
+ typename solver_type
+ >
+ std::ostream& operator<< (
+ std::ostream& out,
+ dnn_trainer<net_type,solver_type>& trainer
+ );
+ /*!
+ ensures
+ - Prints a log of the current parameters of trainer to out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_TRAINER_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/dnn/utilities.h b/ml/dlib/dlib/dnn/utilities.h
new file mode 100644
index 000000000..976128c81
--- /dev/null
+++ b/ml/dlib/dlib/dnn/utilities.h
@@ -0,0 +1,281 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_UTILITIES_H_
+#define DLIB_DNn_UTILITIES_H_
+
+#include "core.h"
+#include "utilities_abstract.h"
+#include "../geometry.h"
+#include <fstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline double log1pexp(double x)
+ {
+ using std::exp;
+ using namespace std; // Do this instead of using std::log1p because some compilers
+ // error out otherwise (E.g. gcc 4.9 in cygwin)
+ if (x <= -37)
+ return exp(x);
+ else if (-37 < x && x <= 18)
+ return log1p(exp(x));
+ else if (18 < x && x <= 33.3)
+ return x + exp(-x);
+ else
+ return x;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void randomize_parameters (
+ tensor& params,
+ unsigned long num_inputs_and_outputs,
+ dlib::rand& rnd
+ )
+ {
+ for (auto& val : params)
+ {
+ // Draw a random number to initialize the layer according to formula (16)
+ // from Understanding the difficulty of training deep feedforward neural
+ // networks by Xavier Glorot and Yoshua Bengio.
+ val = 2*rnd.get_random_float()-1;
+ val *= std::sqrt(6.0/(num_inputs_and_outputs));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class visitor_net_to_xml
+ {
+ public:
+
+ visitor_net_to_xml(std::ostream& out_) : out(out_) {}
+
+ template<typename input_layer_type>
+ void operator()(size_t idx, const input_layer_type& l)
+ {
+ out << "<layer idx='"<<idx<<"' type='input'>\n";
+ to_xml(l,out);
+ out << "</layer>\n";
+ }
+
+ template <typename T, typename U>
+ void operator()(size_t idx, const add_loss_layer<T,U>& l)
+ {
+ out << "<layer idx='"<<idx<<"' type='loss'>\n";
+ to_xml(l.loss_details(),out);
+ out << "</layer>\n";
+ }
+
+ template <typename T, typename U, typename E>
+ void operator()(size_t idx, const add_layer<T,U,E>& l)
+ {
+ out << "<layer idx='"<<idx<<"' type='comp'>\n";
+ to_xml(l.layer_details(),out);
+ out << "</layer>\n";
+ }
+
+ template <unsigned long ID, typename U, typename E>
+ void operator()(size_t idx, const add_tag_layer<ID,U,E>& l)
+ {
+ out << "<layer idx='"<<idx<<"' type='tag' id='"<<ID<<"'/>\n";
+ }
+
+ template <template<typename> class T, typename U>
+ void operator()(size_t idx, const add_skip_layer<T,U>& l)
+ {
+ out << "<layer idx='"<<idx<<"' type='skip' id='"<<(tag_id<T>::id)<<"'/>\n";
+ }
+
+ private:
+
+ std::ostream& out;
+ };
+ }
+
+ template <typename net_type>
+ void net_to_xml (
+ const net_type& net,
+ std::ostream& out
+ )
+ {
+ auto old_precision = out.precision(9);
+ out << "<net>\n";
+ visit_layers(net, impl::visitor_net_to_xml(out));
+ out << "</net>\n";
+ // restore the original stream precision.
+ out.precision(old_precision);
+ }
+
+ template <typename net_type>
+ void net_to_xml (
+ const net_type& net,
+ const std::string& filename
+ )
+ {
+ std::ofstream fout(filename);
+ net_to_xml(net, fout);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ class visitor_net_map_input_to_output
+ {
+ public:
+
+ visitor_net_map_input_to_output(dpoint& p_) : p(p_) {}
+
+ dpoint& p;
+
+ template<typename input_layer_type>
+ void operator()(const input_layer_type& net)
+ {
+ }
+
+ template <typename T, typename U>
+ void operator()(const add_loss_layer<T,U>& net)
+ {
+ (*this)(net.subnet());
+ }
+
+ template <typename T, typename U, typename E>
+ void operator()(const add_layer<T,U,E>& net)
+ {
+ (*this)(net.subnet());
+ p = net.layer_details().map_input_to_output(p);
+ }
+ template <bool B, typename T, typename U, typename E>
+ void operator()(const dimpl::subnet_wrapper<add_layer<T,U,E>,B>& net)
+ {
+ (*this)(net.subnet());
+ p = net.layer_details().map_input_to_output(p);
+ }
+
+
+ template <unsigned long ID, typename U, typename E>
+ void operator()(const add_tag_layer<ID,U,E>& net)
+ {
+ // tag layers are an identity transform, so do nothing
+ (*this)(net.subnet());
+ }
+ template <bool is_first, unsigned long ID, typename U, typename E>
+ void operator()(const dimpl::subnet_wrapper<add_tag_layer<ID,U,E>,is_first>& net)
+ {
+ // tag layers are an identity transform, so do nothing
+ (*this)(net.subnet());
+ }
+
+
+ template <template<typename> class TAG_TYPE, typename U>
+ void operator()(const add_skip_layer<TAG_TYPE,U>& net)
+ {
+ (*this)(layer<TAG_TYPE>(net));
+ }
+ template <bool is_first, template<typename> class TAG_TYPE, typename SUBNET>
+ void operator()(const dimpl::subnet_wrapper<add_skip_layer<TAG_TYPE,SUBNET>,is_first>& net)
+ {
+ // skip layers are an identity transform, so do nothing
+ (*this)(layer<TAG_TYPE>(net));
+ }
+
+ };
+
+ class visitor_net_map_output_to_input
+ {
+ public:
+ visitor_net_map_output_to_input(dpoint& p_) : p(p_) {}
+
+ dpoint& p;
+
+ template<typename input_layer_type>
+ void operator()(const input_layer_type& net)
+ {
+ }
+
+ template <typename T, typename U>
+ void operator()(const add_loss_layer<T,U>& net)
+ {
+ (*this)(net.subnet());
+ }
+
+ template <typename T, typename U, typename E>
+ void operator()(const add_layer<T,U,E>& net)
+ {
+ p = net.layer_details().map_output_to_input(p);
+ (*this)(net.subnet());
+ }
+ template <bool B, typename T, typename U, typename E>
+ void operator()(const dimpl::subnet_wrapper<add_layer<T,U,E>,B>& net)
+ {
+ p = net.layer_details().map_output_to_input(p);
+ (*this)(net.subnet());
+ }
+
+
+ template <unsigned long ID, typename U, typename E>
+ void operator()(const add_tag_layer<ID,U,E>& net)
+ {
+ // tag layers are an identity transform, so do nothing
+ (*this)(net.subnet());
+ }
+ template <bool is_first, unsigned long ID, typename U, typename E>
+ void operator()(const dimpl::subnet_wrapper<add_tag_layer<ID,U,E>,is_first>& net)
+ {
+ // tag layers are an identity transform, so do nothing
+ (*this)(net.subnet());
+ }
+
+
+ template <template<typename> class TAG_TYPE, typename U>
+ void operator()(const add_skip_layer<TAG_TYPE,U>& net)
+ {
+ (*this)(layer<TAG_TYPE>(net));
+ }
+ template <bool is_first, template<typename> class TAG_TYPE, typename SUBNET>
+ void operator()(const dimpl::subnet_wrapper<add_skip_layer<TAG_TYPE,SUBNET>,is_first>& net)
+ {
+ // skip layers are an identity transform, so do nothing
+ (*this)(layer<TAG_TYPE>(net));
+ }
+
+ };
+ }
+
+ template <typename net_type>
+ inline dpoint input_tensor_to_output_tensor(
+ const net_type& net,
+ dpoint p
+ )
+ {
+ impl::visitor_net_map_input_to_output temp(p);
+ temp(net);
+ return p;
+ }
+
+ template <typename net_type>
+ inline dpoint output_tensor_to_input_tensor(
+ const net_type& net,
+ dpoint p
+ )
+ {
+ impl::visitor_net_map_output_to_input temp(p);
+ temp(net);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_UTILITIES_H_
+
+
+
diff --git a/ml/dlib/dlib/dnn/utilities_abstract.h b/ml/dlib/dlib/dnn/utilities_abstract.h
new file mode 100644
index 000000000..2a9a3d3fc
--- /dev/null
+++ b/ml/dlib/dlib/dnn/utilities_abstract.h
@@ -0,0 +1,127 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DNn_UTILITIES_ABSTRACT_H_
+#ifdef DLIB_DNn_UTILITIES_ABSTRACT_H_
+
+#include "core_abstract.h"
+#include "../geometry/vector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ double log1pexp(
+ double x
+ );
+ /*!
+ ensures
+ - returns log(1+exp(x))
+ (except computes it using a numerically accurate method)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void randomize_parameters (
+ tensor& params,
+ unsigned long num_inputs_and_outputs,
+ dlib::rand& rnd
+ );
+ /*!
+ ensures
+ - This function assigns random values into params based on the given random
+ number generator. In particular, it uses the parameter initialization method
+ of formula 16 from the paper "Understanding the difficulty of training deep
+ feedforward neural networks" by Xavier Glorot and Yoshua Bengio.
+ - It is assumed that the total number of inputs and outputs from the layer is
+ num_inputs_and_outputs. That is, you should set num_inputs_and_outputs to
+ the sum of the dimensionalities of the vectors going into and out of the
+ layer that uses params as its parameters.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename net_type>
+ void net_to_xml (
+ const net_type& net,
+ std::ostream& out
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - All layers in the net must provide to_xml() functions.
+ ensures
+ - Prints the given neural network object as an XML document to the given output
+ stream.
+ !*/
+
+ template <typename net_type>
+ void net_to_xml (
+ const net_type& net,
+ const std::string& filename
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
+ add_tag_layer.
+ - All layers in the net must provide to_xml() functions.
+ ensures
+ - This function is just like the above net_to_xml(), except it writes to a file
+ rather than an ostream.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename net_type>
+ dpoint input_tensor_to_output_tensor(
+ const net_type& net,
+ dpoint p
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_skip_layer, or add_tag_layer.
+ - All layers in the net must provide map_input_to_output() functions.
+ ensures
+ - Given a dpoint (i.e. a row,column coordinate) in the input tensor given to
+ net, this function returns the corresponding dpoint in the output tensor
+ net.get_output(). This kind of mapping is useful when working with fully
+ convolutional networks as you will often want to know what parts of the
+ output feature maps correspond to what parts of the input.
+ - If the network contains skip layers then any layers skipped over by the skip
+ layer are ignored for the purpose of computing this coordinate mapping. That
+ is, if you walk the network from the output layer to the input layer, where
+ each time you encounter a skip layer you jump to the layer indicated by the
+ skip layer, you will visit exactly the layers in the network involved in the
+ input_tensor_to_output_tensor() calculation. This behavior is useful since it
+ allows you to compute some auxiliary DNN as a separate branch of computation
+ that is separate from the main network's job of running some kind of fully
+ convolutional network over an image. For instance, you might want to have a
+ branch in your network that computes some global image level
+ summarization/feature.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename net_type>
+ dpoint output_tensor_to_input_tensor(
+ const net_type& net,
+ dpoint p
+ );
+ /*!
+ requires
+ - net_type is an object of type add_layer, add_skip_layer, or add_tag_layer.
+ - All layers in the net must provide map_output_to_input() functions.
+ ensures
+ - This function provides the reverse mapping of input_tensor_to_output_tensor().
+ That is, given a dpoint in net.get_output(), what is the corresponding dpoint
+ in the input tensor?
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_UTILITIES_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/dnn/validation.h b/ml/dlib/dlib/dnn/validation.h
new file mode 100644
index 000000000..c65cb4526
--- /dev/null
+++ b/ml/dlib/dlib/dnn/validation.h
@@ -0,0 +1,122 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNn_VALIDATION_H_
+#define DLIB_DNn_VALIDATION_H_
+
+#include "../svm/cross_validate_object_detection_trainer_abstract.h"
+#include "../svm/cross_validate_object_detection_trainer.h"
+#include "layers.h"
+#include <set>
+
+namespace dlib
+{
+ namespace impl
+ {
+ inline std::set<std::string> get_labels (
+ const std::vector<mmod_rect>& rects1,
+ const std::vector<mmod_rect>& rects2
+ )
+ {
+ std::set<std::string> labels;
+ for (auto& rr : rects1)
+ labels.insert(rr.label);
+ for (auto& rr : rects2)
+ labels.insert(rr.label);
+ return labels;
+ }
+ }
+
+ template <
+ typename SUBNET,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ loss_mmod<SUBNET>& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<mmod_rect>>& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0,
+ const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( is_learning_problem(images,truth_dets) == true ,
+ "\t matrix test_object_detection_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
+ << "\n\t images.size(): " << images.size()
+ );
+
+
+
+ double correct_hits = 0;
+ double total_true_targets = 0;
+
+ std::vector<std::pair<double,bool> > all_dets;
+ unsigned long missing_detections = 0;
+
+ resizable_tensor temp;
+
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ std::vector<mmod_rect> hits;
+ detector.to_tensor(&images[i], &images[i]+1, temp);
+ detector.subnet().forward(temp);
+ detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
+
+
+ for (auto& label : impl::get_labels(truth_dets[i], hits))
+ {
+ std::vector<full_object_detection> truth_boxes;
+ std::vector<rectangle> ignore;
+ std::vector<std::pair<double,rectangle>> boxes;
+ // copy hits and truth_dets into the above three objects
+ for (auto&& b : truth_dets[i])
+ {
+ if (b.ignore)
+ {
+ ignore.push_back(b);
+ }
+ else if (b.label == label)
+ {
+ truth_boxes.push_back(full_object_detection(b.rect));
+ ++total_true_targets;
+ }
+ }
+ for (auto&& b : hits)
+ {
+ if (b.label == label)
+ boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
+ }
+
+ correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
+ }
+ }
+
+ std::sort(all_dets.rbegin(), all_dets.rend());
+
+ double precision, recall;
+
+ double total_hits = all_dets.size();
+
+ if (total_hits == 0)
+ precision = 1;
+ else
+ precision = correct_hits / total_hits;
+
+ if (total_true_targets == 0)
+ recall = 1;
+ else
+ recall = correct_hits / total_true_targets;
+
+ matrix<double, 1, 3> res;
+ res = precision, recall, average_precision(all_dets, missing_detections);
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DNn_VALIDATION_H_
+
diff --git a/ml/dlib/dlib/enable_if.h b/ml/dlib/dlib/enable_if.h
new file mode 100644
index 000000000..f081dea6d
--- /dev/null
+++ b/ml/dlib/dlib/enable_if.h
@@ -0,0 +1,62 @@
+// Copyright 2003 (C) The Trustees of Indiana University.
+// Use, modification, and distribution is subject to the Boost Software
+// License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
+// http://www.boost.org/LICENSE_1_0.txt)
+// Authors: Jaakko Jarvi (jajarvi at osl.iu.edu)
+// Jeremiah Willcock (jewillco at osl.iu.edu)
+// Andrew Lumsdaine (lums at osl.iu.edu)
+#ifndef DLIB_BOOST_UTILITY_ENABLE_IF_HPP
+#define DLIB_BOOST_UTILITY_ENABLE_IF_HPP
+
+namespace dlib
+{
+
+ template <bool B, class T = void>
+ struct enable_if_c {
+ typedef T type;
+ };
+
+ template <class T>
+ struct enable_if_c<false, T> {};
+
+ template <class Cond, class T = void>
+ struct enable_if : public enable_if_c<Cond::value, T> {};
+
+ template <bool B, class T>
+ struct lazy_enable_if_c {
+ typedef typename T::type type;
+ };
+
+ template <class T>
+ struct lazy_enable_if_c<false, T> {};
+
+ template <class Cond, class T>
+ struct lazy_enable_if : public lazy_enable_if_c<Cond::value, T> {};
+
+
+ template <bool B, class T = void>
+ struct disable_if_c {
+ typedef T type;
+ };
+
+ template <class T>
+ struct disable_if_c<true, T> {};
+
+ template <class Cond, class T = void>
+ struct disable_if : public disable_if_c<Cond::value, T> {};
+
+ template <bool B, class T>
+ struct lazy_disable_if_c {
+ typedef typename T::type type;
+ };
+
+ template <class T>
+ struct lazy_disable_if_c<true, T> {};
+
+ template <class Cond, class T>
+ struct lazy_disable_if : public lazy_disable_if_c<Cond::value, T> {};
+
+} // namespace dlib
+
+#endif // DLIB_BOOST_UTILITY_ENABLE_IF_HPP
+
diff --git a/ml/dlib/dlib/entropy_decoder.h b/ml/dlib/dlib/entropy_decoder.h
new file mode 100644
index 000000000..345dffa8c
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder.h
@@ -0,0 +1,44 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODEr_
+#define DLIB_ENTROPY_DECODEr_
+
+#include "entropy_decoder/entropy_decoder_kernel_1.h"
+#include "entropy_decoder/entropy_decoder_kernel_2.h"
+#include "entropy_decoder/entropy_decoder_kernel_c.h"
+
+
+
+
+namespace dlib
+{
+
+
+ class entropy_decoder
+ {
+ entropy_decoder() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef entropy_decoder_kernel_1
+ kernel_1a;
+ typedef entropy_decoder_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef entropy_decoder_kernel_2
+ kernel_2a;
+ typedef entropy_decoder_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ };
+}
+
+#endif // DLIB_ENTROPY_DECODEr_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp
new file mode 100644
index 000000000..82c583634
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp
@@ -0,0 +1,220 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_KERNEL_1_CPp_
+#define DLIB_ENTROPY_DECODER_KERNEL_1_CPp_
+#include "entropy_decoder_kernel_1.h"
+#include <iostream>
+#include <streambuf>
+#include <sstream>
+
+#include "../assert.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_decoder_kernel_1::
+ entropy_decoder_kernel_1(
+ ) :
+ initial_low(0x00000001),
+ initial_high(0xffffffff),
+ in(0),
+ low(initial_low),
+ high(initial_high),
+ buf(0),
+ buf_used(0),
+ target(0x00000000),
+ r(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_decoder_kernel_1::
+ ~entropy_decoder_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_1::
+ clear(
+ )
+ {
+ in = 0;
+ buf_used = 0;
+ buf = 0;
+ r = 0;
+ low = initial_low;
+ high = initial_high;
+ target = 0x00000000;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_1::
+ set_stream (
+ std::istream& in_
+ )
+ {
+ buf_used = 0;
+ buf = 0;
+ r = 0;
+ low = initial_low;
+ high = initial_high;
+ target = 0x00000000;
+
+ in = &in_;
+ streambuf = in_.rdbuf();
+
+
+
+ unsigned char ch;
+
+
+ streambuf->sgetn((char*)&ch,1);
+ target = ch;
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_decoder_kernel_1::
+ stream_is_set (
+ ) const
+ {
+ if (in != 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& entropy_decoder_kernel_1::
+ get_stream (
+ ) const
+ {
+ return *in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_1::
+ decode (
+ uint32 low_count,
+ uint32 high_count
+ )
+ {
+ // note that we must subtract 1 to preserve the convention that
+ // high == the real upper range - 1
+ high = low + r*high_count - 1;
+ low = low + r*low_count;
+ r = 0;
+
+
+
+ while (true)
+ {
+
+ // if the highest order bit in high and low is the same
+ if ( low >= 0x80000000 || high < 0x80000000)
+ {
+ // make sure buf isn't empty
+ if (buf_used == 0)
+ {
+ buf_used = 8;
+ if (streambuf->sgetn(reinterpret_cast<char*>(&buf),1)==0)
+ {
+ // if there isn't anything else in the streambuffer then just
+ // make buf zero.
+ buf = 0;
+ }
+ }
+
+ // we will be taking one bit from buf to replace the one we threw away
+ --buf_used;
+
+ // roll off the bit in target
+ target <<= 1;
+
+ // roll off the bit
+ high <<= 1;
+ low <<= 1;
+ high |= 1; // note that it is ok to add one to high here because
+ // of the convention that high == real upper range - 1.
+ // so that means that if we want to shift the upper range
+ // left by one then we must shift a one into high also
+ // since real upper range == high + 0.999999999...
+
+ // make sure low is never zero
+ if (low == 0)
+ low = 1;
+
+ // take a bit from buf to fill in the one we threw away
+ target += (buf>>buf_used)&0x01;
+ }
+ // if the distance between high and low is small and there aren't
+ // any bits we can roll off then round low up or high down.
+ else if (high-low < 0x10000)
+ {
+ if (high == 0x80000000)
+ high = 0x7fffffff;
+ else
+ low = 0x80000000;
+ }
+ else
+ {
+ break;
+ }
+ } // while (true)
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_decoder_kernel_1::
+ get_target_called (
+ ) const
+ {
+ return (r != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint32 entropy_decoder_kernel_1::
+ get_target (
+ uint32 total
+ )
+ {
+ // note that we must add one because of the convention that
+ // high == the real upper range minus 1
+ r = (high-low+1)/total;
+ uint32 temp = (target-low)/r;
+ if (temp < total)
+ return temp;
+ else
+ return total-1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_ENTROPY_DECODER_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h
new file mode 100644
index 000000000..daf6d11ec
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h
@@ -0,0 +1,132 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_KERNEl_1_
+#define DLIB_ENTROPY_DECODER_KERNEl_1_
+
+#include "../algs.h"
+#include "entropy_decoder_kernel_abstract.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_decoder_kernel_1
+ {
+ /*!
+ GENERAL NOTES
+ this decoder is implemented using arithmetic coding
+
+ INITIAL VALUE
+ in == 0
+ buf_used == 0
+ buf == 0
+ initial_low == 0x00000001 (slightly more than zero)
+ initial_high == 0xffffffff (slightly less than one, 0.99999999976717)
+ target == 0x00000000 (zero)
+ low == initial_low
+ high == initial_high
+ r == 0
+
+ CONVENTION
+ if (in != 0)
+ *in == get_stream()
+ true == stream_is_set()
+ streambuf == in->rdbuf()
+ else
+ false == stream_is_set()
+
+ buf == used to hold fractional byte values which are fed to target.
+ buf_used == the number of low order bits in buf that are currently
+ in use
+ low == the low end of the range used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ low is also never allowed to be zero to avoid overflow
+ in the calculation (high-low+1)/total.
+
+ high == the high end of the range - 1 used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so when we
+ interpret high as a real number then it is always in the
+ range [0,1)
+
+ the range for arithmetic encoding is always
+ [low,high + 0.9999999...) the 0.9999999... is why
+ high == real upper range - 1
+
+ target == 32 bits of the fraction produced from an arithmetic encoder.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ r == the value (high-low+1)/total from the last call to
+ get_target() or 0 if get_target_called() should be false
+
+ get_target_called() == (r != 0)
+
+ !*/
+
+ public:
+
+ entropy_decoder_kernel_1 (
+ );
+
+ virtual ~entropy_decoder_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::istream& in
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::istream& get_stream (
+ ) const;
+
+ void decode (
+ uint32 low_count,
+ uint32 high_count
+ );
+
+ bool get_target_called (
+ ) const;
+
+ uint32 get_target (
+ uint32 total
+ );
+
+ private:
+
+ // restricted functions
+ entropy_decoder_kernel_1(entropy_decoder_kernel_1&); // copy constructor
+ entropy_decoder_kernel_1& operator=(entropy_decoder_kernel_1&); // assignment operator
+
+ // data members
+ const uint32 initial_low;
+ const uint32 initial_high;
+ std::istream* in;
+ uint32 low;
+ uint32 high;
+ unsigned char buf;
+ uint32 buf_used;
+ uint32 target;
+ uint32 r;
+ std::streambuf* streambuf;
+
+ };
+
+}
+
+#ifdef NO_MAKEFILE
+#include "entropy_decoder_kernel_1.cpp"
+#endif
+
+#endif // DLIB_ENTROPY_DECODER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp
new file mode 100644
index 000000000..5b986273e
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp
@@ -0,0 +1,224 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_KERNEL_2_CPp_
+#define DLIB_ENTROPY_DECODER_KERNEL_2_CPp_
+#include "entropy_decoder_kernel_2.h"
+#include <iostream>
+#include <streambuf>
+#include <sstream>
+
+#include "../assert.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_decoder_kernel_2::
+ entropy_decoder_kernel_2(
+ ) :
+ initial_low(0x00000001),
+ initial_high(0xffffffff),
+ in(0),
+ low(initial_low),
+ high(initial_high),
+ target(0x00000000),
+ r(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_decoder_kernel_2::
+ ~entropy_decoder_kernel_2 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_2::
+ clear(
+ )
+ {
+ in = 0;
+ r = 0;
+ low = initial_low;
+ high = initial_high;
+ target = 0x00000000;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_2::
+ set_stream (
+ std::istream& in_
+ )
+ {
+ r = 0;
+ low = initial_low;
+ high = initial_high;
+ target = 0x00000000;
+
+ in = &in_;
+ streambuf = in_.rdbuf();
+
+
+
+ unsigned char ch;
+
+
+ streambuf->sgetn((char*)&ch,1);
+ target = ch;
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+
+
+ target <<= 8;
+ if (streambuf->sgetn((char*)&ch,1))
+ target += ch;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_decoder_kernel_2::
+ stream_is_set (
+ ) const
+ {
+ if (in != 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& entropy_decoder_kernel_2::
+ get_stream (
+ ) const
+ {
+ return *in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_decoder_kernel_2::
+ decode (
+ uint32 low_count,
+ uint32 high_count
+ )
+ {
+ // note that we must subtract 1 to preserve the convention that
+ // high == the real upper range - 1
+ high = low + r*high_count - 1;
+ low = low + r*low_count;
+ r = 0;
+
+
+ while (true )
+ {
+
+ // if high and low don't have the same 8 high order bits
+ if ((high&0xFF000000) != (low&0xFF000000))
+ {
+ // if the distance between high and low is small and there aren't
+ // any bits we can roll off then force high and low to have common high
+ // order bits.
+ if ((high-low < 0x10000))
+ {
+ if (high-low > 0x1000)
+ {
+ high>>=1;
+ low>>=1;
+ high = low = high+low;
+ high += 0xFF;
+ low -= 0xFF;
+ }
+ else /**/
+ {
+ high>>=1;
+ low>>=1;
+ high = low = high+low;
+ }
+ }
+ else
+ {
+ // there are no bits to roll off and high and low are not
+ // too close so just quit the loop
+ break;
+ }
+
+ }
+ // else if there are 8 bits we can roll off
+ else
+ {
+ unsigned char buf;
+ if (streambuf->sgetn(reinterpret_cast<char*>(&buf),1)==0)
+ {
+ // if there isn't anything else in the streambuffer then just
+ // make buf zero.
+ buf = 0;
+ }
+
+ // also roll off the bits in target
+ target <<= 8;
+
+ // roll off the bits
+ high <<= 8;
+ low <<= 8;
+ high |= 0xFF; // note that it is ok to add 0xFF to high here because
+ // of the convention that high == real upper range - 1.
+ // so that means that if we want to shift the upper range
+ // left by one then we must shift a one into high also
+ // since real upper range == high + 0.999999999...
+
+ // make sure low is never zero
+ if (low == 0)
+ low = 1;
+
+
+ // put the new bits into target
+ target |= static_cast<uint32>(buf);
+ }
+
+ } // while (true)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_decoder_kernel_2::
+ get_target_called (
+ ) const
+ {
+ return (r != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint32 entropy_decoder_kernel_2::
+ get_target (
+ uint32 total
+ )
+ {
+ // note that we must add one because of the convention that
+ // high == the real upper range minus 1
+ r = (high-low+1)/total;
+ uint32 temp = (target-low)/r;
+ if (temp < total)
+ return temp;
+ else
+ return total-1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_ENTROPY_DECODER_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h
new file mode 100644
index 000000000..7284cec3c
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h
@@ -0,0 +1,127 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_KERNEl_2_
+#define DLIB_ENTROPY_DECODER_KERNEl_2_
+
+#include "../algs.h"
+#include "entropy_decoder_kernel_abstract.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_decoder_kernel_2
+ {
+ /*!
+ GENERAL NOTES
+ this decoder is implemented using "range" coding
+
+ INITIAL VALUE
+ in == 0
+ initial_low == 0x00000001 (slightly more than zero)
+ initial_high == 0xffffffff (slightly less than one, 0.99999999976717)
+ target == 0x00000000 (zero)
+ low == initial_low
+ high == initial_high
+ r == 0
+
+ CONVENTION
+ if (in != 0)
+ *in == get_stream()
+ true == stream_is_set()
+ streambuf == in->rdbuf()
+ else
+ false == stream_is_set()
+
+
+ low == the low end of the range used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ low is also never allowed to be zero to avoid overflow
+ in the calculation (high-low+1)/total.
+
+ high == the high end of the range - 1 used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so when we
+ interpret high as a real number then it is always in the
+ range [0,1)
+
+ the range for arithmetic encoding is always
+ [low,high + 0.9999999...) the 0.9999999... is why
+ high == real upper range - 1
+
+ target == 32 bits of the fraction produced from an arithmetic encoder.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ r == the value (high-low+1)/total from the last call to
+ get_target() or 0 if get_target_called() should be false
+
+ get_target_called() == (r != 0)
+
+ !*/
+
+ public:
+
+ entropy_decoder_kernel_2 (
+ );
+
+ virtual ~entropy_decoder_kernel_2 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::istream& in
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::istream& get_stream (
+ ) const;
+
+ void decode (
+ uint32 low_count,
+ uint32 high_count
+ );
+
+ bool get_target_called (
+ ) const;
+
+ uint32 get_target (
+ uint32 total
+ );
+
+ private:
+
+ // restricted functions
+ entropy_decoder_kernel_2(entropy_decoder_kernel_2&); // copy constructor
+ entropy_decoder_kernel_2& operator=(entropy_decoder_kernel_2&); // assignment operator
+
+ // data members
+ const uint32 initial_low;
+ const uint32 initial_high;
+ std::istream* in;
+ uint32 low;
+ uint32 high;
+ uint32 target;
+ uint32 r;
+ std::streambuf* streambuf;
+
+ };
+
+
+}
+
+#ifdef NO_MAKEFILE
+#include "entropy_decoder_kernel_2.cpp"
+#endif
+
+#endif // DLIB_ENTROPY_DECODER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h
new file mode 100644
index 000000000..89906b9ae
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h
@@ -0,0 +1,207 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_
+#ifdef DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_decoder
+ {
+ /*!
+ INITIAL VALUE
+ stream_is_set() == false
+ get_target_called() == false
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an entropy decoder (could be implemented as an
+ arithmetic decoder for example).
+
+ Note that all implementations of entropy_encoder and entropy_decoder
+ are paired. This means that if you use entropy_encoder_kernel_n to
+ encode something then you must use the corresponding
+ entropy_decoder_kernel_n to decode it.
+
+
+ WHERE IS EOF?
+ It is important to note that this object will not give any indication
+ that is has hit the end of the input stream when it occurs. It is
+ up to you to use some kind of coding scheme to detect this in the
+ compressed data stream.
+
+ Another important thing to know is that decode() must be called
+ exactly the same number of times as encode() and with the same values
+ supplied for TOTAL, high_count, and low_count. Doing this ensures
+ that the decoder consumes exactly all the bytes from the input
+ stream that were written by the entropy_encoder.
+
+ NOTATION:
+ At any moment each symbol has a certain probability of appearing in
+ the input stream. These probabilities may change as each symbol is
+ decoded and the probability model is updated accordingly.
+
+
+ - Before considering current symbol:
+
+ let P(i) be a function which gives the probability of seeing the ith
+ symbol of an N symbol alphabet. Note that P(i) refers to the probability
+ of seeing the ith symbol WITHOUT considering the symbol currently given
+ by get_target(TOTAL). ( The domain of P(i) is from 0 to N-1. )
+
+ for each i: P(i) == COUNT/TOTAL where COUNT and TOTAL are integers
+ and TOTAL is the same number for all P(i) but COUNT may vary.
+
+ let LOW_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i-1
+ (note that LOW_COUNT(0) == 0)
+ let HIGH_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i
+
+
+ - After considering current symbol:
+
+ let #P(i) be a function which gives the probability of seeing the ith
+ symbol after we have updated our probability model to take the symbol
+ given by get_target(TOTAL) into account.
+
+ for each i: #P(i) == #COUNT/#TOTAL where #COUNT and #TOTAL are integers
+ and #TOTAL is the same number for all #P(i) but #COUNT may vary.
+ !*/
+
+ public:
+
+ entropy_decoder (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~entropy_decoder (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - if (stream_is_set())
+ - clears any state accumulated in *this from decoding data from
+ the stream get_stream()
+ throws
+ - any exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void set_stream (
+ std::istream& in
+ );
+ /*!
+ ensures
+ - #*this will read data from in and decode it
+ - #stream_is_set() == true
+ - #get_target() == a number representing the first symbol from in
+ - #get_target_called() == false
+ - if (stream_is_set())
+ - clears any state accumulated in *this from decoding data from
+ the stream get_stream()
+ throws
+ - any exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ bool stream_is_set (
+ ) const;
+ /*!
+ ensures
+ - returns true if a stream has been associated with *this by calling
+ set_stream()
+ !*/
+
+ std::istream& get_stream (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns a reference to the istream object that *this is reading
+ encoded data from
+ !*/
+
+
+ void decode (
+ uint32 low_count,
+ uint32 high_count
+ );
+ /*!
+ requires
+ - get_target_called() == true
+ - stream_is_set() == true
+ - low_count == LOW_COUNT(S) where S is the symbol represented
+ by get_target(TOTAL)
+ - high_count == HIGH_COUNT(S) where S is the symbol represented
+ by get_target(TOTAL)
+ - low_count <= get_target(TOTAL) < high_count <= TOTAL
+ ensures
+ - #get_target(#TOTAL) == a number which represents the next symbol
+ - #get_target_called() == false
+ throws
+ - any exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ bool get_target_called (
+ ) const;
+ /*!
+ ensures
+ - returns true if get_target() has been called and since then decode()
+ and set_stream() have not been called
+ - returns false otherwise
+ !*/
+
+ uint32 get_target (
+ uint32 total
+ );
+ /*!
+ requires
+ - 0 < total < 65536 (2^16)
+ - total == TOTAL
+ - stream_is_set() == true
+ ensures
+ - in the next call to decode() the value of TOTAL will be
+ considered to be total
+ - #get_target_called() == true
+ - returns a number N such that:
+ - N is in the range 0 to total - 1
+ - N represents a symbol S where
+ LOW_COUNT(S) <= N < HIGH_COUNT(S)
+ throws
+ - any exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ private:
+
+ // restricted functions
+ entropy_decoder(entropy_decoder&); // copy constructor
+ entropy_decoder& operator=(entropy_decoder&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h
new file mode 100644
index 000000000..4b94791f3
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h
@@ -0,0 +1,123 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_KERNEl_C_
+#define DLIB_ENTROPY_DECODER_KERNEl_C_
+
+#include "entropy_decoder_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename decoder
+ >
+ class entropy_decoder_kernel_c : public decoder
+ {
+
+ public:
+ std::istream& get_stream (
+ ) const;
+
+ void decode (
+ uint32 low_count,
+ uint32 high_count
+ );
+
+ uint32 get_target (
+ uint32 total
+ );
+
+ private:
+ uint32 _get_target;
+ uint32 TOTAL;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename decoder
+ >
+ std::istream& entropy_decoder_kernel_c<decoder>::
+ get_stream (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tstd::istream& entropy_decoder::get_stream()"
+ << "\n\tyou must set a stream for this object before you can get it"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return decoder::get_stream();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename decoder
+ >
+ void entropy_decoder_kernel_c<decoder>::
+ decode (
+ uint32 low_count,
+ uint32 high_count
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (low_count <= _get_target) && (_get_target < high_count) &&
+ (high_count <= TOTAL) &&
+ (this->stream_is_set() == true) && (this->get_target_called() == true),
+ "\tvoid entropy_decoder::decode()"
+ << "\n\tRefer to the ensures clause for this function for further information."
+ << "\n\tNote that _get_target refers to get_target(TOTAL)"
+ << "\n\tthis: " << this
+ << "\n\tlow_count: " << low_count
+ << "\n\thigh_count: " << high_count
+ << "\n\tTOTAL: " << TOTAL
+ << "\n\tget_target(TOTAL): " << _get_target
+ << "\n\tis_stream_set(): " << (this->stream_is_set() ? "true" : "false" )
+ << "\n\tget_target_called(): " << (this->get_target_called() ? "true" : "false" )
+ );
+
+ // call the real function
+ decoder::decode(low_count,high_count);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename decoder
+ >
+ uint32 entropy_decoder_kernel_c<decoder>::
+ get_target (
+ uint32 total
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (total > 0) && (total < 65536) && (this->stream_is_set() == true),
+ "\tvoid entropy_decoder::get_target()"
+ << "\n\tyou must set a stream for this object before you can get the "
+ << "\n\rnext target."
+ << "\n\tthis: " << this
+ << "\n\ttotal: " << total
+ );
+
+ // call the real function
+ _get_target = decoder::get_target(total);
+ TOTAL = total;
+ return _get_target;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model.h b/ml/dlib/dlib/entropy_decoder_model.h
new file mode 100644
index 000000000..d898161aa
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model.h
@@ -0,0 +1,108 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEl_
+#define DLIB_ENTROPY_DECODER_MODEl_
+
+#include "entropy_decoder_model/entropy_decoder_model_kernel_1.h"
+#include "entropy_decoder_model/entropy_decoder_model_kernel_2.h"
+#include "entropy_decoder_model/entropy_decoder_model_kernel_3.h"
+#include "entropy_decoder_model/entropy_decoder_model_kernel_4.h"
+#include "entropy_decoder_model/entropy_decoder_model_kernel_5.h"
+#include "entropy_decoder_model/entropy_decoder_model_kernel_6.h"
+
+#include "conditioning_class.h"
+#include "memory_manager.h"
+
+namespace dlib
+{
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ class entropy_decoder_model
+ {
+ entropy_decoder_model() {}
+
+ typedef typename conditioning_class<alphabet_size+1>::kernel_1a cc1;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_2a cc2;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_3a cc3;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4a cc4a;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4b cc4b;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4c cc4c;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4d cc4d;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1
+ typedef entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc1>
+ kernel_1a;
+
+ typedef entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc2>
+ kernel_1b;
+
+ typedef entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc3>
+ kernel_1c;
+
+ // --------------------
+
+ // kernel_2
+ typedef entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc1,cc1>
+ kernel_2a;
+
+ typedef entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc2,cc2>
+ kernel_2b;
+
+ typedef entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc3,cc3>
+ kernel_2c;
+
+ typedef entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc2,cc4b>
+ kernel_2d;
+
+ // --------------------
+
+ // kernel_3
+ typedef entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc1,cc4b>
+ kernel_3a;
+
+ typedef entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc2,cc4b>
+ kernel_3b;
+
+ typedef entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc3,cc4b>
+ kernel_3c;
+
+ // --------------------
+
+ // kernel_4
+ typedef entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,200000,4>
+ kernel_4a;
+ typedef entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,1000000,5>
+ kernel_4b;
+
+
+ // --------------------
+
+ // kernel_5
+ typedef entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,200000,4>
+ kernel_5a;
+ typedef entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,1000000,5>
+ kernel_5b;
+ typedef entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,2500000,7>
+ kernel_5c;
+
+
+ // --------------------
+
+ // kernel_6
+ typedef entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>
+ kernel_6a;
+
+
+ };
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEl_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h
new file mode 100644
index 000000000..a0a94c948
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h
@@ -0,0 +1,173 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc
+ >
+ class entropy_decoder_model_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+
+ INITIAL VALUE
+ Initially this object's finite context model is empty
+
+ CONVENTION
+ &get_entropy_decoder() == coder
+ &order_0.get_global_state() == &gs
+
+ This is an order-0 model. The last symbol in the order-0 context is
+ an escape into the order minus 1 context.
+ !*/
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_1 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_1 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_decoder& coder;
+ typename cc::global_state_type gs;
+ cc order_0;
+
+ // restricted functions
+ entropy_decoder_model_kernel_1(entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>&); // copy constructor
+ entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>& operator=(entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc
+ >
+ entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>::
+ entropy_decoder_model_kernel_1 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc
+ >
+ entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>::
+ ~entropy_decoder_model_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc
+ >
+ void entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>::
+ clear(
+ )
+ {
+ order_0.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc
+ >
+ void entropy_decoder_model_kernel_1<alphabet_size,entropy_decoder,cc>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ unsigned long current_symbol, low_count, high_count, target;
+
+ // look in the order-0 context
+ target = coder.get_target(order_0.get_total());
+ order_0.get_symbol(target,current_symbol,low_count,high_count);
+
+
+ // have coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if current_symbol is not an escape from the order-0 context
+ if (current_symbol != alphabet_size)
+ {
+ // update the count for this symbol
+ order_0.increment_count(current_symbol,2);
+
+ symbol = current_symbol;
+ return;
+ }
+
+ // update the count for the escape symbol
+ order_0.increment_count(alphabet_size);
+
+
+ // go into the order minus one context
+ target = coder.get_target(alphabet_size);
+ coder.decode(target,target+1);
+
+
+ // update the count for this symbol in the order-0 context
+ order_0.increment_count(target,2);
+
+ symbol = target;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h
new file mode 100644
index 000000000..6841db391
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h
@@ -0,0 +1,245 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename ccbig
+ >
+ class entropy_decoder_model_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+ this will be used for the order-0 context
+
+ REQUIREMENTS ON ccbig
+ ccbig is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ ccbig::get_alphabet_size() == alphabet_size+1
+ this will be used for the order-1 context
+
+ INITIAL VALUE
+ Initially this object's finite context model is empty
+ previous_symbol == 0
+
+ CONVENTION
+ &get_entropy_decoder() == coder
+ &order_0.get_global_state() == &gs
+ &order_1[i]->get_global_state() == &gsbig
+
+
+ This is an order-1-0 model. The last symbol in the order-0 and order-1
+ context is an escape into the lower context.
+
+ previous_symbol == the last symbol seen
+ !*/
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_2 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_2 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_decoder& coder;
+ typename cc::global_state_type gs;
+ typename ccbig::global_state_type gsbig;
+ cc order_0;
+ ccbig* order_1[alphabet_size];
+ unsigned long previous_symbol;
+
+
+ // restricted functions
+ entropy_decoder_model_kernel_2(entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>&); // copy constructor
+ entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>& operator=(entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename ccbig
+ >
+ entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>::
+ entropy_decoder_model_kernel_2 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs),
+ previous_symbol(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535);
+
+ unsigned long i;
+ try
+ {
+ for (i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i] = new ccbig(gsbig);
+ }
+ }
+ catch (...)
+ {
+ for (unsigned long j = 0; j < i; ++j)
+ {
+ delete order_1[j];
+ }
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename ccbig
+ >
+ entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>::
+ ~entropy_decoder_model_kernel_2 (
+ )
+ {
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ delete order_1[i];
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename ccbig
+ >
+ void entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>::
+ clear(
+ )
+ {
+ previous_symbol = 0;
+ order_0.clear();
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i]->clear();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename ccbig
+ >
+ void entropy_decoder_model_kernel_2<alphabet_size,entropy_decoder,cc,ccbig>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ unsigned long current_symbol, low_count, high_count, target;
+
+ // look in the order-1 context
+ target = coder.get_target(order_1[previous_symbol]->get_total());
+ order_1[previous_symbol]->get_symbol(target,current_symbol,low_count,high_count);
+
+ // have the coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if the current_symbol is not an escape from the order-1 context
+ if (current_symbol != alphabet_size)
+ {
+ symbol = current_symbol;
+ order_1[previous_symbol]->increment_count(current_symbol,2);
+ previous_symbol = current_symbol;
+ return;
+ }
+
+ // since this is an escape to order-0 we should increment
+ // the escape symbol
+ order_1[previous_symbol]->increment_count(alphabet_size);
+
+
+
+ // look in the order-0 context
+ target = coder.get_target(order_0.get_total());
+ order_0.get_symbol(target,current_symbol,low_count,high_count);
+
+ // have coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if current_symbol is not an escape from the order-0 context
+ if (current_symbol != alphabet_size)
+ {
+ // update the count for this symbol
+ order_1[previous_symbol]->increment_count(current_symbol,2);
+ order_0.increment_count(current_symbol,2);
+
+ symbol = current_symbol;
+ previous_symbol = current_symbol;
+ return;
+ }
+
+ // update the count for the escape symbol
+ order_0.increment_count(current_symbol);
+
+
+ // go into the order minus one context
+ target = coder.get_target(alphabet_size);
+ coder.decode(target,target+1);
+
+
+ // update the count for this symbol
+ order_1[previous_symbol]->increment_count(target,2);
+ order_0.increment_count(target,2);
+
+ symbol = target;
+ previous_symbol = target;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h
new file mode 100644
index 000000000..c55c09e85
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h
@@ -0,0 +1,335 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename cc_high
+ >
+ class entropy_decoder_model_kernel_3
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+
+ REQUIREMENTS ON cc_high
+ cc_high is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc_high::get_alphabet_size() == alphabet_size+1
+
+ INITIAL VALUE
+ - Initially this object's finite context model is empty
+ - previous_symbol == 0
+ - previous_symbol2 == 0
+ - order_1 == pointer to an array of alphabet_size elements
+ - order_2 == pointer to an array of alphabet_size*alphabet_size elements
+ - for all values of i: order_2[i] == 0
+
+ CONVENTION
+ &get_entropy_encoder() == coder
+ &order_0.get_global_state() == &gs
+ &order_1[i]->get_global_state() == &gs
+
+ if (order_2[i] != 0) then
+ &order_2[i]->get_global_state() == &gs_high
+
+ This is an order-2-1-0 model. The last symbol in the order-2, order-1 and
+ order-0 contexts is an escape into the lower context.
+
+ previous_symbol == the last symbol seen
+ previous_symbol2 == the symbol we saw before previous_symbol
+ !*/
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_3 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_3 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_decoder& coder;
+ typename cc::global_state_type gs;
+ typename cc_high::global_state_type gs_high;
+ cc order_0;
+ cc** order_1;
+ unsigned long previous_symbol;
+ cc_high** order_2;
+ unsigned long previous_symbol2;
+
+ // restricted functions
+ entropy_decoder_model_kernel_3(entropy_decoder_model_kernel_3&); // copy constructor
+ entropy_decoder_model_kernel_3& operator=(entropy_decoder_model_kernel_3&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename cc_high
+ >
+ entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc,cc_high>::
+ entropy_decoder_model_kernel_3 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs),
+ order_1(0),
+ previous_symbol(0),
+ order_2(0),
+ previous_symbol2(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535);
+
+ try
+ {
+ order_1 = new cc*[alphabet_size];
+ order_2 = new cc_high*[alphabet_size*alphabet_size];
+ }
+ catch (...)
+ {
+ if (order_1) delete [] order_1;
+ if (order_2) delete [] order_2;
+ throw;
+ }
+
+
+ unsigned long i;
+
+ for (i = 0; i < alphabet_size*alphabet_size; ++i)
+ {
+ order_2[i] = 0;
+ }
+
+ try
+ {
+ for (i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i] = new cc(gs);
+ }
+ }
+ catch (...)
+ {
+ for (unsigned long j = 0; j < i; ++j)
+ {
+ delete order_1[j];
+ }
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename cc_high
+ >
+ entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc,cc_high>::
+ ~entropy_decoder_model_kernel_3 (
+ )
+ {
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ delete order_1[i];
+ }
+
+ for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i)
+ {
+ if (order_2[i] != 0)
+ delete order_2[i];
+ }
+ delete [] order_1;
+ delete [] order_2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename cc_high
+ >
+ void entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc,cc_high>::
+ clear(
+ )
+ {
+ previous_symbol = 0;
+ previous_symbol2 = 0;
+ order_0.clear();
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i]->clear();
+ }
+
+ for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i)
+ {
+ if (order_2[i] != 0)
+ {
+ delete order_2[i];
+ order_2[i] = 0;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ typename cc,
+ typename cc_high
+ >
+ void entropy_decoder_model_kernel_3<alphabet_size,entropy_decoder,cc,cc_high>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ unsigned long current_symbol, low_count, high_count, target;
+
+
+ // look in the order-2 context
+ unsigned long temp = previous_symbol + (previous_symbol2 * alphabet_size);
+ if (order_2[temp] != 0)
+ {
+ target = coder.get_target(order_2[temp]->get_total());
+ order_2[temp]->get_symbol(target,current_symbol,low_count,high_count);
+
+ // have the coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if the current_symbol is not an escape from the order-2 context
+ if (current_symbol != alphabet_size)
+ {
+ symbol = current_symbol;
+ order_2[temp]->increment_count(current_symbol,2);
+ previous_symbol2 = previous_symbol;
+ previous_symbol = current_symbol;
+ return;
+ }
+
+ // since this is an escape to order-1 we should increment
+ // the escape symbol
+ order_2[temp]->increment_count(alphabet_size);
+ }
+ else
+ {
+ order_2[temp] = new cc_high(gs_high);
+ }
+
+
+
+
+
+
+ // look in the order-1 context
+ target = coder.get_target(order_1[previous_symbol]->get_total());
+ order_1[previous_symbol]->get_symbol(target,current_symbol,low_count,high_count);
+
+ // have the coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if the current_symbol is not an escape from the order-1 context
+ if (current_symbol != alphabet_size)
+ {
+ symbol = current_symbol;
+ order_2[temp]->increment_count(current_symbol,2);
+ order_1[previous_symbol]->increment_count(current_symbol,2);
+ previous_symbol2 = previous_symbol;
+ previous_symbol = current_symbol;
+ return;
+ }
+
+ // since this is an escape to order-0 we should increment
+ // the escape symbol
+ order_1[previous_symbol]->increment_count(alphabet_size);
+
+
+
+ // look in the order-0 context
+ target = coder.get_target(order_0.get_total());
+ order_0.get_symbol(target,current_symbol,low_count,high_count);
+
+ // have coder decode the next symbol
+ coder.decode(low_count,high_count);
+
+ // if current_symbol is not an escape from the order-0 context
+ if (current_symbol != alphabet_size)
+ {
+ // update the count for this symbol
+ order_2[temp]->increment_count(current_symbol,2);
+ order_1[previous_symbol]->increment_count(current_symbol,2);
+ order_0.increment_count(current_symbol,2);
+
+
+ symbol = current_symbol;
+ previous_symbol2 = previous_symbol;
+ previous_symbol = current_symbol;
+ return;
+ }
+
+ // update the count for the escape symbol
+ order_0.increment_count(current_symbol);
+
+
+ // go into the order minus one context
+ target = coder.get_target(alphabet_size);
+ coder.decode(target,target+1);
+
+
+ // update the count for this symbol
+ order_2[temp]->increment_count(target,2);
+ order_1[previous_symbol]->increment_count(target,2);
+ order_0.increment_count(target,2);
+
+
+ symbol = target;
+ previous_symbol2 = previous_symbol;
+ previous_symbol = target;
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h
new file mode 100644
index 000000000..dcbfef6a0
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h
@@ -0,0 +1,622 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+
+namespace dlib
+{
+
+ namespace edmk4
+ {
+ struct node
+ {
+ node* next;
+ node* child_context;
+ node* parent_context;
+
+ unsigned short symbol;
+ unsigned short count;
+ unsigned short total;
+ unsigned short escapes;
+ };
+ }
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ class entropy_decoder_model_kernel_4
+ {
+ /*!
+ REQUIREMENTS ON total_nodes
+ - 4096 < total_nodes
+ - this is the total number of nodes that we will use in the tree
+
+ REQUIREMENTS ON order
+ - 0 <= order
+ - this is the maximum depth-1 the tree will be allowed to go (note
+ that the root level is depth 0).
+
+
+ GENERAL NOTES
+ This implementation follows more or less the implementation
+ strategy laid out by Alistair Moffat in his paper
+ Implementing the PPM data compression scheme. Published in IEEE
+ Transactions on Communications, 38(11):1917-1921, 1990.
+
+ The escape method used will be method D.
+
+
+ INITIAL VALUE
+ - root == pointer to an array of total_nodes nodes
+ - next_node == 1
+ - cur == root
+ - cur_order = 0
+ - root->next == 0
+ - root->parent_context == 0
+ - root->child_context == 0
+ - root->escapes == 0
+ - root->total == 0
+ - stack_size == 0
+
+ CONVENTION
+ - pop() == stack[stack_size-1]
+ - &get_entropy_decoder() == coder
+ - root == pointer to an array of total_nodes nodes.
+ this is also the root of the tree.
+
+ - if (next_node < total_nodes) then
+ - next_node == the next node in root that has not yet been allocated
+
+
+ - root->next == 0
+ - root->parent_context == 0
+
+
+ - for every node in the tree:
+ {
+ - NOTATION:
+ - The "context" of a node is the string of symbols seen
+ when you go from the root of the tree down (down though
+ child context pointers) to the node, including the symbol at
+ the node itself. (note that the context of the root node
+ is "" or the empty string)
+ - A set of nodes is in the same "context set" if all the node's
+ contexts are of length n and all the node's contexts share
+ the same prefix of length n-1.
+ - The "child context set" of a node is a set of nodes with
+ contexts that are one symbol longer and prefixed by the node's
+ context. For example, if a node has a context "abc" then the
+ nodes for contexts "abca", "abcb", "abcc", etc. are all in
+ the child context set of the node.
+ - The "parent context" of a node is the context that is one
+ symbol shorter than the node's context and includes the
+ symbol in the node. So the parent context of a node with
+ context "abcd" would be the context "bcd".
+
+
+ - if (next != 0) then
+ - next == pointer to the next node in the same context set
+ - if (child_context != 0) then
+ - child_context == pointer to the first node of the child
+ context set for this node.
+ - if (parent_context != 0) then
+ - parent_context == pointer to the parent context of this node.
+ - else
+ - this node is the root node of the tree
+
+
+ - if (this is not the root node) then
+ - symbol == the symbol represented with this node
+ - count == the number of times this symbol has been seen in its
+ parent context.
+ - else
+ - the root doesn't have a symbol. i.e. the context for the
+ root node is "" or the empty string.
+
+ - total == The sum of the counts of all the nodes
+ in the child context set + escapes.
+ - escapes == the escape count for the context represented
+ by the node.
+ }
+
+
+ - cur_order < order
+ - cur_order == the depth of the node cur in the tree.
+ (note that the root node has depth 0)
+ - cur == pointer to the node in the tree who's context matches
+ the most recent symbols we have seen.
+
+
+ !*/
+
+ typedef edmk4::node node;
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_4 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_4 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+
+ inline void push (
+ edmk4::node* n
+ );
+ /*!
+ requires
+ - stack_size <= order
+ ensures
+ - #pop() == n
+ !*/
+
+ inline edmk4::node* pop (
+ );
+ /*!
+ requires
+ - stack_size > 0
+ ensures
+ - returns the node at the top of the stack
+ !*/
+
+ inline edmk4::node* allocate_node (
+ );
+ /*!
+ requires
+ - space_left() == true
+ ensures
+ - returns a pointer to a new node
+ !*/
+
+ inline void destroy_tree (
+ );
+ /*!
+ ensures
+ - deallocates all nodes except the root
+ - #root->child_context == 0
+ - #root->escapes == 0
+ - #root->total == 0
+ - #cur == root
+ - #cur_order == 0
+ - #stack_size == 0
+ !*/
+
+
+ inline bool space_left (
+ ) const;
+ /*!
+ ensures
+ - returns true if there is at least 1 free node left.
+ - returns false otherwise
+ !*/
+
+
+ inline void scale_counts (
+ node* n
+ );
+ /*!
+ ensures
+ - divides all the counts in the child context set of n by 2.
+ - none of the nodes in the child context set will have a count of 0
+ !*/
+
+
+ entropy_decoder& coder;
+ unsigned long next_node;
+ node* root;
+ node* cur;
+ unsigned long cur_order;
+ node* stack[order+1];
+ unsigned long stack_size;
+
+ // restricted functions
+ entropy_decoder_model_kernel_4(entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>&); // copy constructor
+ entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>& operator=(entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ entropy_decoder_model_kernel_4 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_),
+ next_node(1),
+ cur_order(0),
+ stack_size(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535);
+ COMPILE_TIME_ASSERT( 4096 < total_nodes );
+
+ root = new node[total_nodes];
+ cur = root;
+
+ root->child_context = 0;
+ root->escapes = 0;
+ root->next = 0;
+ root->parent_context = 0;
+ root->total = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ ~entropy_decoder_model_kernel_4 (
+ )
+ {
+ delete [] root;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ clear(
+ )
+ {
+ destroy_tree();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ node* temp = cur;
+ cur = 0;
+ unsigned long low_count, high_count, total_count;
+ unsigned long target;
+ node* new_node = 0;
+
+ // local_order will track the level of temp in the tree
+ unsigned long local_order = cur_order;
+
+
+ while (true)
+ {
+ high_count = 0;
+ if (space_left())
+ {
+ total_count = temp->total;
+
+ if (total_count > 0)
+ {
+ // check if we need to scale the counts
+ if (total_count > 10000)
+ {
+ scale_counts(temp);
+ total_count = temp->total;
+ }
+
+ target = coder.get_target(total_count);
+
+ // find either the symbol we are looking for or the
+ // end of the context set
+ node* n = temp->child_context;
+ node* last = 0;
+ while (true)
+ {
+ high_count += n->count;
+
+ if (high_count > target || n->next == 0)
+ break;
+ last = n;
+ n = n->next;
+ }
+
+ low_count = high_count - n->count;
+
+ // if we found the symbol
+ if (high_count > target)
+ {
+ if (new_node != 0)
+ {
+ new_node->parent_context = n;
+ }
+
+ symbol = n->symbol;
+
+ coder.decode(low_count,high_count);
+ n->count += 8;
+ temp->total += 8;
+
+ // move this node to the front
+ if (last)
+ {
+ last->next = n->next;
+ n->next = temp->child_context;
+ temp->child_context = n;
+ }
+
+
+ if (cur == 0)
+ {
+ if (local_order < order)
+ {
+ cur_order = local_order+1;
+ cur = n;
+ }
+ else
+ {
+ cur = n->parent_context;
+ cur_order = local_order;
+ }
+ }
+
+ break;
+
+ }
+ // if we hit the end of the context set without finding the symbol
+ else
+ {
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ n->next = new_node;
+
+ // get the escape code
+ coder.decode(high_count,total_count);
+ }
+
+ }
+ else // if (total_count == 0)
+ {
+ // this means that temp->child_context == 0 so we should make
+ // a new node here.
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ temp->child_context = new_node;
+ }
+
+ if (cur == 0 && local_order < order)
+ {
+ cur = new_node;
+ cur_order = local_order+1;
+ }
+
+ // fill out the new node
+ new_node->child_context = 0;
+ new_node->count = 4;
+ new_node->escapes = 0;
+ new_node->next = 0;
+ push(new_node);
+ new_node->total = 0;
+ temp->escapes += 4;
+ temp->total += 8;
+
+
+ if (temp != root)
+ {
+ temp = temp->parent_context;
+ --local_order;
+ continue;
+ }
+
+ // since this is the root we are going to the order-(-1) context
+ // so we can just take care of that here.
+ target = coder.get_target(alphabet_size);
+ new_node->parent_context = root;
+ coder.decode(target,target+1);
+ symbol = target;
+
+ if (cur == 0)
+ {
+ cur = root;
+ cur_order = 0;
+ }
+ break;
+ }
+ else
+ {
+ // there isn't enough space so we should rebuild the tree
+ destroy_tree();
+ temp = cur;
+ local_order = cur_order;
+ cur = 0;
+ new_node = 0;
+ }
+ } // while (true)
+
+ while (stack_size > 0)
+ {
+ pop()->symbol = static_cast<unsigned short>(symbol);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ edmk4::node* entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ allocate_node (
+ )
+ {
+ node* temp;
+ temp = root + next_node;
+ ++next_node;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ destroy_tree (
+ )
+ {
+ next_node = 1;
+ root->child_context = 0;
+ root->escapes = 0;
+ root->total = 0;
+ cur = root;
+ cur_order = 0;
+ stack_size = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ space_left (
+ ) const
+ {
+ return (next_node < total_nodes);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ push (
+ edmk4::node* n
+ )
+ {
+ stack[stack_size] = n;
+ ++stack_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ edmk4::node* entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ pop (
+ )
+ {
+ --stack_size;
+ return stack[stack_size];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_4<alphabet_size,entropy_decoder,total_nodes,order>::
+ scale_counts (
+ node* temp
+ )
+ {
+ if (temp->escapes > 1)
+ temp->escapes >>= 1;
+ temp->total = temp->escapes;
+
+ node* n = temp->child_context;
+ while (n != 0)
+ {
+ if (n->count > 1)
+ n->count >>= 1;
+
+ temp->total += n->count;
+ n = n->next;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h
new file mode 100644
index 000000000..9253e950b
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h
@@ -0,0 +1,793 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+
+namespace dlib
+{
+
+ namespace edmk5
+ {
+ struct node
+ {
+ node* next;
+ node* child_context;
+ node* parent_context;
+
+ unsigned short symbol;
+ unsigned short count;
+ unsigned short total;
+ unsigned short escapes;
+ };
+ }
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ class entropy_decoder_model_kernel_5
+ {
+ /*!
+ REQUIREMENTS ON total_nodes
+ - 4096 < total_nodes
+ - this is the total number of nodes that we will use in the tree
+
+ REQUIREMENTS ON order
+ - 0 <= order
+ - this is the maximum depth-1 the tree will be allowed to go (note
+ that the root level is depth 0).
+
+
+ GENERAL NOTES
+ This implementation follows more or less the implementation
+ strategy laid out by Alistair Moffat in his paper
+ Implementing the PPM data compression scheme. Published in IEEE
+ Transactions on Communications, 38(11):1917-1921, 1990.
+
+ The escape method used will be method D.
+
+ This also uses Dmitry Shkarin's Information Inheritance scheme.
+ (described in "PPM: one step to practicality" and "Improving the
+ Efficiency of the PPM Algorithm")
+
+
+ INITIAL VALUE
+ - root == pointer to an array of total_nodes nodes
+ - next_node == 1
+ - cur == root
+ - cur_order = 0
+ - root->next == 0
+ - root->parent_context == 0
+ - root->child_context == 0
+ - root->escapes == 0
+ - root->total == 0
+ - stack_size == 0
+ - exc_used == false
+ - for all i: exc[i] == 0
+
+ CONVENTION
+ - exc_used == something_is_excluded()
+ - pop() == stack[stack_size-1].n and stack[stack_size-1].nc
+ - is_excluded(symbol) == bit symbol&0x1F from exc[symbol>>5]
+ - &get_entropy_decoder() == coder
+ - root == pointer to an array of total_nodes nodes.
+ this is also the root of the tree.
+ - if (next_node < total_nodes) then
+ - next_node == the next node in root that has not yet been allocated
+
+ - root->next == 0
+ - root->parent_context == 0
+
+
+ - for every node in the tree:
+ {
+ - NOTATION:
+ - The "context" of a node is the string of symbols seen
+ when you go from the root of the tree down (down though
+ child context pointers) to the node, including the symbol at
+ the node itself. (note that the context of the root node
+ is "" or the empty string)
+ - A set of nodes is in the same "context set" if all the node's
+ contexts are of length n and all the node's contexts share
+ the same prefix of length n-1.
+ - The "child context set" of a node is a set of nodes with
+ contexts that are one symbol longer and prefixed by the node's
+ context. For example, if a node has a context "abc" then the
+ nodes for contexts "abca", "abcb", "abcc", etc. are all in
+ the child context set of the node.
+ - The "parent context" of a node is the context that is one
+ symbol shorter than the node's context and includes the
+ symbol in the node. So the parent context of a node with
+ context "abcd" would be the context "bcd".
+
+
+ - if (next != 0) then
+ - next == pointer to the next node in the same context set
+ - if (child_context != 0) then
+ - child_context == pointer to the first node of the child
+ context set for this node.
+ - escapes > 0
+ - if (parent_context != 0) then
+ - parent_context == pointer to the parent context of this node.
+ - else
+ - this node is the root node of the tree
+
+
+ - if (this is not the root node) then
+ - symbol == the symbol represented with this node
+ - count == the number of times this symbol has been seen in its
+ parent context.
+ - else
+ - the root doesn't have a symbol. i.e. the context for the
+ root node is "" or the empty string.
+
+ - total == The sum of the counts of all the nodes
+ in the child context set + escapes.
+ - escapes == the escape count for the context represented
+ by the node.
+ - count > 0
+ }
+
+
+ - cur_order < order
+ - cur_order == the depth of the node cur in the tree.
+ (note that the root node has depth 0)
+ - cur == pointer to the node in the tree who's context matches
+ the most recent symbols we have seen.
+
+
+ !*/
+
+ typedef edmk5::node node;
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_5 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_5 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+
+ inline void push (
+ node* n,
+ node* nc
+ );
+ /*!
+ requires
+ - stack_size < order
+ ensures
+ - #pop(a,b): a == n && b == nc
+ !*/
+
+ inline void pop (
+ node*& n,
+ node*& nc
+ );
+ /*!
+ requires
+ - stack_size > 0
+ ensures
+ - returns the two nodes at the top of the stack
+ !*/
+
+ inline edmk5::node* allocate_node (
+ );
+ /*!
+ requires
+ - space_left() == true
+ ensures
+ - returns a pointer to a new node
+ !*/
+
+ inline bool space_left (
+ ) const;
+ /*!
+ ensures
+ - returns true if there is at least 1 free node left.
+ - returns false otherwise
+ !*/
+
+ inline void exclude (
+ unsigned short symbol
+ );
+ /*!
+ ensures
+ - #is_excluded(symbol) == true
+ - #something_is_excluded() == true
+ !*/
+
+ inline bool is_excluded (
+ unsigned short symbol
+ );
+ /*!
+ ensures
+ - if (symbol has been excluded) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ inline bool something_is_excluded (
+ );
+ /*!
+ ensures
+ - returns true if some symbol has been excluded.
+ returns false otherwise
+ !*/
+
+ inline void clear_exclusions (
+ );
+ /*!
+ ensures
+ - for all symbols #is_excluded(symbol) == false
+ - #something_is_excluded() == false
+ !*/
+
+ inline void scale_counts (
+ node* n
+ );
+ /*!
+ ensures
+ - divides all the counts in the child context set of n by 2.
+ - none of the nodes in the child context set will have a count of 0
+ !*/
+
+ struct nodes
+ {
+ node* n;
+ node* nc;
+ };
+
+ entropy_decoder& coder;
+ unsigned long next_node;
+ node* root;
+ node* cur;
+ unsigned long cur_order;
+ unsigned long exc[alphabet_size/32+1];
+ nodes stack[order+1];
+ unsigned long stack_size;
+ bool exc_used;
+
+ // restricted functions
+ entropy_decoder_model_kernel_5(entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>&); // copy constructor
+ entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>& operator=(entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ entropy_decoder_model_kernel_5 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_),
+ next_node(1),
+ cur_order(0),
+ stack_size(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535);
+ COMPILE_TIME_ASSERT( 4096 < total_nodes );
+
+ root = new node[total_nodes];
+ cur = root;
+
+ root->child_context = 0;
+ root->escapes = 0;
+ root->next = 0;
+ root->parent_context = 0;
+ root->total = 0;
+
+ clear_exclusions();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ ~entropy_decoder_model_kernel_5 (
+ )
+ {
+ delete [] root;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ clear(
+ )
+ {
+ next_node = 1;
+ root->child_context = 0;
+ root->escapes = 0;
+ root->total = 0;
+ cur = root;
+ cur_order = 0;
+ stack_size = 0;
+
+ clear_exclusions();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ node* temp = cur;
+ cur = 0;
+ unsigned long low_count, high_count, total_count;
+ unsigned long target;
+ node* new_node = 0;
+
+ // local_order will track the level of temp in the tree
+ unsigned long local_order = cur_order;
+
+
+ unsigned short c; // c == t(a|sk)
+ unsigned short t; // t == T(sk)
+
+
+ if (something_is_excluded())
+ clear_exclusions();
+
+ while (true)
+ {
+ high_count = 0;
+ if (space_left())
+ {
+ total_count = temp->total;
+
+ if (total_count > 0)
+ {
+ // check if we need to scale the counts
+ if (total_count > 10000)
+ {
+ scale_counts(temp);
+ total_count = temp->total;
+ }
+
+ if (something_is_excluded())
+ {
+ node* n = temp->child_context;
+ total_count = temp->escapes;
+ while (true)
+ {
+ if (is_excluded(n->symbol) == false)
+ {
+ total_count += n->count;
+ }
+ if (n->next == 0)
+ break;
+ n = n->next;
+ }
+ }
+
+
+
+ target = coder.get_target(total_count);
+
+ // find either the symbol we are looking for or the
+ // end of the context set
+ node* n = temp->child_context;
+ node* last = 0;
+ while (true)
+ {
+ if (is_excluded(n->symbol) == false)
+ {
+ high_count += n->count;
+ exclude(n->symbol);
+ }
+
+
+ if (high_count > target || n->next == 0)
+ break;
+ last = n;
+ n = n->next;
+ }
+
+
+ // if we found the symbol
+ if (high_count > target)
+ {
+ low_count = high_count - n->count;
+
+ if (new_node != 0)
+ {
+ new_node->parent_context = n;
+ }
+
+ symbol = n->symbol;
+
+ coder.decode(low_count,high_count);
+ c = n->count += 8;
+ t = temp->total += 8;
+
+
+ // move this node to the front
+ if (last)
+ {
+ last->next = n->next;
+ n->next = temp->child_context;
+ temp->child_context = n;
+ }
+
+ if (cur == 0)
+ {
+ if (local_order < order)
+ {
+ cur_order = local_order+1;
+ cur = n;
+ }
+ else
+ {
+ cur = n->parent_context;
+ cur_order = local_order;
+ }
+ }
+
+ break;
+
+
+ }
+ // if we hit the end of the context set without finding the symbol
+ else
+ {
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ n->next = new_node;
+
+ // get the escape code
+ coder.decode(high_count,total_count);
+ }
+
+ }
+ else // if (total_count == 0)
+ {
+ // this means that temp->child_context == 0 so we should make
+ // a new node here.
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ temp->child_context = new_node;
+ }
+
+ if (cur == 0 && local_order < order)
+ {
+ cur = new_node;
+ cur_order = local_order+1;
+ }
+
+ // fill out the new node
+ new_node->child_context = 0;
+ new_node->escapes = 0;
+ new_node->next = 0;
+ push(new_node,temp);
+ new_node->total = 0;
+
+
+
+ if (temp != root)
+ {
+ temp = temp->parent_context;
+ --local_order;
+ continue;
+ }
+
+ t = 2056;
+ c = 8;
+
+ // since this is the root we are going to the order-(-1) context
+ // so we can just take care of that here.
+ target = coder.get_target(alphabet_size);
+ new_node->parent_context = root;
+ coder.decode(target,target+1);
+ symbol = target;
+
+ if (cur == 0)
+ {
+ cur = root;
+ cur_order = 0;
+ }
+ break;
+ }
+ else
+ {
+ // there isn't enough space so we should rebuild the tree
+ clear();
+ temp = cur;
+ local_order = cur_order;
+ cur = 0;
+ new_node = 0;
+ }
+ } // while (true)
+
+ // initialize the counts and symbol for any new nodes we have added
+ // to the tree.
+ node* n, *nc;
+ while (stack_size > 0)
+ {
+ pop(n,nc);
+
+ n->symbol = static_cast<unsigned short>(symbol);
+
+ // if nc is not a determnistic context
+ if (nc->total)
+ {
+ unsigned long temp2 = t-c+nc->total - nc->escapes - nc->escapes;
+ unsigned long temp = nc->total;
+ temp *= c;
+ temp /= (temp2|1); // this oring by 1 is just to make sure that temp2 is never zero
+ temp += 2;
+ if (temp > 50000) temp = 50000;
+ n->count = static_cast<unsigned short>(temp);
+
+
+ nc->escapes += 4;
+ nc->total += static_cast<unsigned short>(temp) + 4;
+ }
+ else
+ {
+ n->count = 3 + 5*(c)/(t-c);
+
+ nc->escapes = 4;
+ nc->total = n->count + 4;
+ }
+
+ while (nc->total > 10000)
+ {
+ scale_counts(nc);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ edmk5::node* entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ allocate_node (
+ )
+ {
+ node* temp;
+ temp = root + next_node;
+ ++next_node;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ space_left (
+ ) const
+ {
+ return (next_node < total_nodes);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ exclude (
+ unsigned short symbol
+ )
+ {
+ exc_used = true;
+ unsigned long temp = 1;
+ temp <<= symbol&0x1F;
+ exc[symbol>>5] |= temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ is_excluded (
+ unsigned short symbol
+ )
+ {
+ unsigned long temp = 1;
+ temp <<= symbol&0x1F;
+ return ((exc[symbol>>5]&temp) != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ clear_exclusions (
+ )
+ {
+ exc_used = false;
+ for (unsigned long i = 0; i < alphabet_size/32+1; ++i)
+ {
+ exc[i] = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ push (
+ node* n,
+ node* nc
+ )
+ {
+ stack[stack_size].n = n;
+ stack[stack_size].nc = nc;
+ ++stack_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ pop (
+ node*& n,
+ node*& nc
+ )
+ {
+ --stack_size;
+ n = stack[stack_size].n;
+ nc = stack[stack_size].nc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ something_is_excluded (
+ )
+ {
+ return exc_used;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_decoder_model_kernel_5<alphabet_size,entropy_decoder,total_nodes,order>::
+ scale_counts (
+ node* temp
+ )
+ {
+ if (temp->escapes > 1)
+ temp->escapes >>= 1;
+ temp->total = temp->escapes;
+
+ node* n = temp->child_context;
+ while (n != 0)
+ {
+ if (n->count > 1)
+ n->count >>= 1;
+
+ temp->total += n->count;
+ n = n->next;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_
+
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h
new file mode 100644
index 000000000..dc23f10eb
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h
@@ -0,0 +1,131 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_
+#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_
+
+#include "../algs.h"
+#include "entropy_decoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ class entropy_decoder_model_kernel_6
+ {
+ /*!
+ INITIAL VALUE
+ This object has no state
+
+ CONVENTION
+ &get_entropy_decoder() == coder
+
+ This is an order-(-1) model. So it doesn't really do anything.
+ Every symbol has the same probability.
+ !*/
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model_kernel_6 (
+ entropy_decoder& coder
+ );
+
+ virtual ~entropy_decoder_model_kernel_6 (
+ );
+
+ inline void clear(
+ );
+
+ inline void decode (
+ unsigned long& symbol
+ );
+
+ entropy_decoder& get_entropy_decoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_decoder& coder;
+
+ // restricted functions
+ entropy_decoder_model_kernel_6(entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>&); // copy constructor
+ entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>& operator=(entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>::
+ entropy_decoder_model_kernel_6 (
+ entropy_decoder& coder_
+ ) :
+ coder(coder_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>::
+ ~entropy_decoder_model_kernel_6 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ void entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>::
+ clear(
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ void entropy_decoder_model_kernel_6<alphabet_size,entropy_decoder>::
+ decode (
+ unsigned long& symbol
+ )
+ {
+ unsigned long target;
+
+ target = coder.get_target(alphabet_size);
+ coder.decode(target,target+1);
+
+ symbol = target;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_
+
diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
new file mode 100644
index 000000000..5b2deabd7
--- /dev/null
+++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_
+#ifdef DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_decoder
+ >
+ class entropy_decoder_model
+ {
+ /*!
+ REQUIREMENTS ON alphabet_size
+ 1 < alphabet_size < 65535
+
+ REQUIREMENTS ON entropy_decoder
+ is an implementation of entropy_decoder/entropy_decoder_kernel_abstract.h
+
+ INITIAL VALUE
+ Initially this object is at some predefined empty or ground state.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some kind of statistical model. You
+ can use it to read symbols from an entropy_decoder and it will calculate
+ the cumulative counts/probabilities and manage contexts for you.
+
+ Note that all implementations of entropy_encoder_model and
+ entropy_decoder_model are paired. This means that if you use
+ entropy_encoder_model_kernel_n to encode something then you must
+ use the corresponding entropy_decoder_model_kernel_n to decode it.
+
+ Also note that this object does not perform any buffering of symbols. It
+ reads them from its associated entropy_decoder simply as it needs them.
+ This makes it safe to use multiple entropy_decoder_model objects with
+ a single entropy_decoder without them trampling each other.
+ !*/
+
+ public:
+
+ typedef entropy_decoder entropy_decoder_type;
+
+ entropy_decoder_model (
+ entropy_decoder& coder
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - &#get_entropy_decoder() == &coder
+ throws
+ - any exception
+ !*/
+
+ virtual ~entropy_decoder_model (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - does not modify get_entropy_decoder()
+ throws
+ - any exception
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void decode (
+ unsigned long& symbol
+ );
+ /*!
+ ensures
+ - decodes the next symbol
+ - #symbol == the next symbol
+ - #symbol < alphabet_size
+ throws
+ - any exception
+ If this exception is thrown then #*this is unusable until
+ clear() is called and succeeds.
+ !*/
+
+ entropy_decoder& get_entropy_decoder (
+ );
+ /*!
+ ensures
+ - returns a reference to the entropy_decoder used by *this
+ !*/
+
+ static unsigned long get_alphabet_size (
+ );
+ /*!
+ ensures
+ - returns alphabet_size
+ !*/
+
+ private:
+
+ // restricted functions
+ entropy_decoder_model(entropy_decoder_model<alphabet_size>&); // copy constructor
+ entropy_decoder_model<alphabet_size>& operator=(entropy_decoder_model<alphabet_size>&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/entropy_encoder.h b/ml/dlib/dlib/entropy_encoder.h
new file mode 100644
index 000000000..5afda1976
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder.h
@@ -0,0 +1,43 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODEr_
+#define DLIB_ENTROPY_ENCODEr_
+
+#include "entropy_encoder/entropy_encoder_kernel_1.h"
+#include "entropy_encoder/entropy_encoder_kernel_2.h"
+#include "entropy_encoder/entropy_encoder_kernel_c.h"
+
+
+
+
+namespace dlib
+{
+
+
+ class entropy_encoder
+ {
+ entropy_encoder() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef entropy_encoder_kernel_1
+ kernel_1a;
+ typedef entropy_encoder_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef entropy_encoder_kernel_2
+ kernel_2a;
+ typedef entropy_encoder_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+ };
+}
+
+#endif // DLIB_ENTROPY_ENCODEr_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp
new file mode 100644
index 000000000..effcf3123
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp
@@ -0,0 +1,239 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_
+#define DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_
+#include "entropy_encoder_kernel_1.h"
+#include <iostream>
+#include <streambuf>
+
+namespace dlib
+{
+
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_encoder_kernel_1::
+ entropy_encoder_kernel_1(
+ ) :
+ initial_low(0x00000001),
+ initial_high(0xffffffff),
+ out(0),
+ low(initial_low),
+ high(initial_high),
+ buf(0),
+ buf_used(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_encoder_kernel_1::
+ ~entropy_encoder_kernel_1 (
+ )
+ {
+ try {
+ if (out != 0)
+ {
+ flush();
+ }
+ } catch (...) {}
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_1::
+ clear(
+ )
+ {
+ if (out != 0)
+ {
+ flush();
+ }
+ out = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_1::
+ set_stream (
+ std::ostream& out_
+ )
+ {
+ if (out != 0)
+ {
+ // if a stream is currently set then flush the buffers to it before
+ // we switch to the new stream
+ flush();
+ }
+
+ out = &out_;
+ streambuf = out_.rdbuf();
+
+ // reset the encoder state
+ buf_used = 0;
+ buf = 0;
+ low = initial_low;
+ high = initial_high;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_encoder_kernel_1::
+ stream_is_set (
+ ) const
+ {
+ if (out != 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& entropy_encoder_kernel_1::
+ get_stream (
+ ) const
+ {
+ return *out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_1::
+ encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ )
+ {
+ // note that we must add one because of the convention that
+ // high == the real upper range minus 1
+ uint32 r = (high-low+1)/total;
+
+ // note that we must subtract 1 to preserve the convention that
+ // high == the real upper range - 1
+ high = low + r*high_count-1;
+ low = low + r*low_count;
+
+
+ while (true)
+ {
+
+ // if the highest order bit in high and low is the same
+ if ( low >= 0x80000000 || high < 0x80000000)
+ {
+ // if buf is full then write it out
+ if (buf_used == 8)
+ {
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ {
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+ }
+ buf = 0;
+ buf_used = 0;
+ }
+
+
+ // write the high order bit from low into buf
+ buf <<= 1;
+ ++buf_used;
+ if (low&0x80000000)
+ buf |= 0x1;
+
+ // roll off the bit we just wrote to buf
+ low <<= 1;
+ high <<= 1;
+ high |= 1; // note that it is ok to add one to high here because
+ // of the convention that high == real upper range - 1.
+ // so that means that if we want to shift the upper range
+ // left by one then we must shift a one into high also
+ // since real upper range == high + 0.999999999...
+
+ // make sure low is never zero
+ if (low == 0)
+ low = 1;
+ }
+ // if the distance between high and low is small and there aren't
+ // any bits we can roll off then round low up or high down.
+ else if (high-low < 0x10000)
+ {
+ if (high == 0x80000000)
+ high = 0x7fffffff;
+ else
+ low = 0x80000000;
+ }
+ else
+ {
+ break;
+ }
+ } // while (true)
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_1::
+ flush (
+ )
+ {
+ // flush the next 4 or 5 bytes that are buffered
+ // thats whatever is contained in buf and then all of low plus any extra
+ // bits needed to pad that to be an even 4 or 5 bytes
+
+
+ if (buf_used != 8)
+ {
+ buf <<= (8-buf_used);
+ buf |= static_cast<unsigned char>(low>>(24+buf_used));
+ low <<= (8-buf_used);
+ }
+
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1) == 0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+ buf = static_cast<unsigned char>((low >> 24)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1) == 0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+
+ buf = static_cast<unsigned char>((low >> 16)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+ buf = static_cast<unsigned char>((low >> 8)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+ if (buf_used != 0)
+ {
+ buf = static_cast<unsigned char>((low)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+ }
+
+
+
+ // make sure the stream buffer flushes to its I/O channel
+ streambuf->pubsync();
+
+
+ // reset the encoder state
+ buf_used = 0;
+ buf = 0;
+ low = initial_low;
+ high = initial_high;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h
new file mode 100644
index 000000000..ccaaf9824
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h
@@ -0,0 +1,119 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_KERNEl_1_
+#define DLIB_ENTROPY_ENCODER_KERNEl_1_
+
+#include "../algs.h"
+#include "entropy_encoder_kernel_abstract.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_encoder_kernel_1
+ {
+ /*!
+ GENERAL NOTES
+ this encoder is implemented using arithmetic coding
+
+ INITIAL VALUE
+ out == 0
+ buf_used == 0
+ buf == 0
+ initial_low == 0x00000001 (slightly more than zero)
+ initial_high == 0xffffffff (slightly less than one, 0.99999999976717)
+ low == initial_low
+ high == initial_high
+
+ CONVENTION
+ if (out != 0)
+ *out == get_stream()
+ true == stream_is_set()
+ streambuf == out->rdbuf()
+ else
+ false == stream_is_set()
+
+ buf == used to accumulate bits before writing them to out.
+ buf_used == the number of low order bits in buf that are currently
+ in use
+ low == the low end of the range used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ low is also never allowed to be zero to avoid overflow
+ in the calculation (high-low+1)/total.
+
+ high == the high end of the range - 1 used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so when we
+ interpret high as a real number then it is always in the
+ range [0,1)
+
+ the range for arithmetic encoding is always
+ [low,high + 0.9999999...) the 0.9999999... is why
+ high == real upper range - 1
+
+ !*/
+
+ public:
+
+ entropy_encoder_kernel_1 (
+ );
+
+ virtual ~entropy_encoder_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::ostream& out
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::ostream& get_stream (
+ ) const;
+
+ void encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ );
+
+ private:
+
+ void flush (
+ );
+ /*!
+ requires
+ out != 0 (i.e. there is a stream object to flush the data to
+ !*/
+
+ // restricted functions
+ entropy_encoder_kernel_1(entropy_encoder_kernel_1&); // copy constructor
+ entropy_encoder_kernel_1& operator=(entropy_encoder_kernel_1&); // assignment operator
+
+ // data members
+ const uint32 initial_low;
+ const uint32 initial_high;
+ std::ostream* out;
+ uint32 low;
+ uint32 high;
+ unsigned char buf;
+ uint32 buf_used;
+ std::streambuf* streambuf;
+
+ };
+
+}
+
+#ifdef NO_MAKEFILE
+#include "entropy_encoder_kernel_1.cpp"
+#endif
+
+#endif // DLIB_ENTROPY_ENCODER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp
new file mode 100644
index 000000000..4f64a6155
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp
@@ -0,0 +1,233 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_
+#define DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_
+#include "entropy_encoder_kernel_2.h"
+#include <iostream>
+#include <streambuf>
+
+namespace dlib
+{
+
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_encoder_kernel_2::
+ entropy_encoder_kernel_2(
+ ) :
+ initial_low(0x00000001),
+ initial_high(0xffffffff),
+ out(0),
+ low(initial_low),
+ high(initial_high)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ entropy_encoder_kernel_2::
+ ~entropy_encoder_kernel_2 (
+ )
+ {
+ try {
+ if (out != 0)
+ {
+ flush();
+ }
+ } catch (...) {}
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_2::
+ clear(
+ )
+ {
+ if (out != 0)
+ {
+ flush();
+ }
+ out = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_2::
+ set_stream (
+ std::ostream& out_
+ )
+ {
+ if (out != 0)
+ {
+ // if a stream is currently set then flush the buffers to it before
+ // we switch to the new stream
+ flush();
+ }
+
+ out = &out_;
+ streambuf = out_.rdbuf();
+
+ // reset the encoder state
+ low = initial_low;
+ high = initial_high;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool entropy_encoder_kernel_2::
+ stream_is_set (
+ ) const
+ {
+ if (out != 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& entropy_encoder_kernel_2::
+ get_stream (
+ ) const
+ {
+ return *out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_2::
+ encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ )
+ {
+ // note that we must add one because of the convention that
+ // high == the real upper range minus 1
+ uint32 r = (high-low+1)/total;
+
+ // note that we must subtract 1 to preserve the convention that
+ // high == the real upper range - 1
+ high = low + r*high_count-1;
+ low = low + r*low_count;
+
+
+
+ while (true )
+ {
+
+ // if high and low don't have the same 8 high order bits
+ if ((high&0xFF000000) != (low&0xFF000000))
+ {
+ // if the distance between high and low is small and there aren't
+ // any bits we can roll off then force high and low to have common high
+ // order bits.
+ if ((high-low < 0x10000))
+ {
+ if (high-low > 0x1000)
+ {
+ high>>=1;
+ low>>=1;
+ high = low = high+low;
+ high += 0xFF;
+ low -= 0xFF;
+ }
+ else /**/
+ {
+ high>>=1;
+ low>>=1;
+ high = low = high+low;
+ }
+ }
+ else
+ {
+ // there are no bits to roll off and high and low are not
+ // too close so just quit the loop
+ break;
+ }
+
+ }
+ // else if there are 8 bits we can roll off
+ else
+ {
+ // write the 8 high order bits from low into buf
+ unsigned char buf = static_cast<unsigned char>(low>>24);
+
+
+ // roll off the bits we just wrote to buf
+ high <<= 8;
+ low <<= 8;
+ high |= 0xFF; // note that it is ok to add 0xFF to high here because
+ // of the convention that high == real upper range - 1.
+ // so that means that if we want to shift the upper range
+ // left by one then we must shift a one into high also
+ // since real upper range == high + 0.999999999...
+
+ // make sure low is never zero
+ if (low == 0)
+ low = 1;
+
+ // write buf to the output stream
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ {
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+ }
+
+ }
+
+ } // while (true)
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void entropy_encoder_kernel_2::
+ flush (
+ )
+ {
+
+ // flush low to the output stream
+
+
+ unsigned char buf;
+
+
+ buf = static_cast<unsigned char>((low >> 24)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1) == 0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+
+ buf = static_cast<unsigned char>((low >> 16)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+ buf = static_cast<unsigned char>((low >> 8)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+ buf = static_cast<unsigned char>((low)&0xFF);
+ if (streambuf->sputn(reinterpret_cast<char*>(&buf),1)==0)
+ throw std::ios_base::failure("error occurred in the entropy_encoder object");
+
+
+
+
+ // make sure the stream buffer flushes to its I/O channel
+ streambuf->pubsync();
+
+
+ // reset the encoder state
+ low = initial_low;
+ high = initial_high;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h
new file mode 100644
index 000000000..71a7503c6
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h
@@ -0,0 +1,112 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_KERNEl_2_
+#define DLIB_ENTROPY_ENCODER_KERNEl_2_
+
+#include "../algs.h"
+#include "entropy_encoder_kernel_abstract.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_encoder_kernel_2
+ {
+ /*!
+ GENERAL NOTES
+ this encoder is implemented using "range" coding
+
+ INITIAL VALUE
+ out == 0
+ initial_low == 0x00000001 (slightly more than zero)
+ initial_high == 0xffffffff (slightly less than one, 0.99999999976717)
+ low == initial_low
+ high == initial_high
+
+ CONVENTION
+ if (out != 0)
+ *out == get_stream()
+ true == stream_is_set()
+ streambuf == out->rdbuf()
+ else
+ false == stream_is_set()
+
+
+ low == the low end of the range used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so it is
+ always in the range [0,1)
+
+ low is also never allowed to be zero to avoid overflow
+ in the calculation (high-low+1)/total.
+
+ high == the high end of the range - 1 used for arithmetic encoding.
+ this number is used as a 32bit fixed point real number.
+ the point is fixed just before the first bit, so when we
+ interpret high as a real number then it is always in the
+ range [0,1)
+
+ the range for arithmetic encoding is always
+ [low,high + 0.9999999...) the 0.9999999... is why
+ high == real upper range - 1
+ !*/
+
+ public:
+
+ entropy_encoder_kernel_2 (
+ );
+
+ virtual ~entropy_encoder_kernel_2 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::ostream& out
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::ostream& get_stream (
+ ) const;
+
+ void encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ );
+
+ private:
+
+ void flush (
+ );
+ /*!
+ requires
+ out != 0 (i.e. there is a stream object to flush the data to
+ !*/
+
+ // restricted functions
+ entropy_encoder_kernel_2(entropy_encoder_kernel_2&); // copy constructor
+ entropy_encoder_kernel_2& operator=(entropy_encoder_kernel_2&); // assignment operator
+
+ // data members
+ const uint32 initial_low;
+ const uint32 initial_high;
+ std::ostream* out;
+ uint32 low;
+ uint32 high;
+ std::streambuf* streambuf;
+
+ };
+
+}
+
+#ifdef NO_MAKEFILE
+#include "entropy_encoder_kernel_2.cpp"
+#endif
+
+#endif // DLIB_ENTROPY_ENCODER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h
new file mode 100644
index 000000000..48af93307
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h
@@ -0,0 +1,161 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_
+#ifdef DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include <iosfwd>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class entropy_encoder
+ {
+ /*!
+ INITIAL VALUE
+ stream_is_set() == false
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an entropy encoder (could be implemented as an
+ arithmetic encoder for example).
+
+ Note that all implementations of entropy_encoder and entropy_decoder
+ are paired. This means that if you use entropy_encoder_kernel_n to
+ encode something then you must use the corresponding
+ entropy_decoder_kernel_n to decode it.
+
+ NOTATION:
+ At any moment each symbol has a certain probability of appearing in
+ the input stream. These probabilities may change as each symbol is
+ encountered and the probability model is updated accordingly.
+
+
+ let P(i) be a function which gives the probability of seeing the ith
+ symbol of an N symbol alphabet BEFORE the probability model is updated
+ to account for the current symbol. ( The domain of P(i) is from 0 to N-1. )
+
+ for each i: P(i) == COUNT/TOTAL where COUNT and TOTAL are integers.
+ and TOTAL is the same number for all P(i) but COUNT may vary.
+
+ let LOW_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i-1
+ (note that LOW_COUNT(0) == 0)
+ let HIGH_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i
+ !*/
+
+ public:
+
+ entropy_encoder (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~entropy_encoder (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ - if (stream_is_set()) then
+ - any buffered data in *this will be written to get_stream()
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - if (stream_is_set()) then
+ - any buffered data in *this will be written to get_stream()
+ - clears any memory of all previous calls to encode() from #*this
+ throws
+ - std::ios_base::failure
+ if (stream_is_set() && there was a problem writing to get_stream())
+ then this exception will be thrown. #*this will be unusable until
+ clear() is called and succeeds
+ - any other exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+
+ void set_stream (
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - #get_stream() == out
+ - #stream_is_set() == true
+ - if (stream_is_set()) then
+ - any buffered data in *this will be written to get_stream()
+ - clears any memory of all previous calls to encode() from #*this
+ throws
+ - std::ios_base::failure
+ if (stream_is_set() && there was a problem writing to get_stream())
+ then this exception will be thrown. #*this will be unusable until
+ clear() is called and succeeds
+ - any other exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ bool stream_is_set (
+ ) const;
+ /*!
+ ensures
+ - returns true if a stream has been associated with *this by calling
+ set_stream()
+ !*/
+
+ std::ostream& get_stream (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns a reference to the ostream object that *this writes its
+ encoded data to
+ !*/
+
+ void encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ );
+ /*!
+ requires
+ - 0 < total < 65536 (2^16)
+ - total == TOTAL
+ - low_count < high_count <= total
+ - stream_is_set() == true
+ ensures
+ - encodes the symbol S where:
+ - LOW_COUNT(S) == low_count
+ - HIGH_COUNT(S) == high_count
+ throws
+ - std::ios_base::failure
+ if (there was a problem writing to get_stream()) then
+ this exception will be thrown. #*this will be unusable until
+ clear() is called and succeeds
+ - any other exception
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+
+ private:
+
+ // restricted functions
+ entropy_encoder(entropy_encoder&); // copy constructor
+ entropy_encoder& operator=(entropy_encoder&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h
new file mode 100644
index 000000000..f11241ecc
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h
@@ -0,0 +1,112 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_KERNEl_C_
+#define DLIB_ENTROPY_ENCODER_KERNEl_C_
+
+#include "entropy_encoder_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename encoder
+ >
+ class entropy_encoder_kernel_c : public encoder
+ {
+
+ public:
+ std::ostream& get_stream (
+ ) const;
+
+ void encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ );
+
+ void flush (
+ );
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename encoder
+ >
+ std::ostream& entropy_encoder_kernel_c<encoder>::
+ get_stream (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tstd::ostream& entropy_encoder::get_stream()"
+ << "\n\tyou must set a stream for this object before you can get it"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return encoder::get_stream();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename encoder
+ >
+ void entropy_encoder_kernel_c<encoder>::
+ encode (
+ uint32 low_count,
+ uint32 high_count,
+ uint32 total
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (0 < total) && (total < 65536) && (low_count < high_count) && (high_count <= total) &&
+ (this->stream_is_set() == true),
+ "\tvoid entropy_encoder::encode()"
+ << "\n\trefer to the ensures clause for this function for further information"
+ << "\n\tthis: " << this
+ << "\n\ttotal: " << total
+ << "\n\tlow_count: " << low_count
+ << "\n\thigh_count: " << high_count
+ << "\n\tis_stream_set(): " << (this->stream_is_set() ? "true" : "false" )
+ );
+
+ // call the real function
+ encoder::encode(low_count,high_count,total);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename encoder
+ >
+ void entropy_encoder_kernel_c<encoder>::
+ flush (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tvoid entropy_encoder::flush()"
+ << "\n\tyou must set a stream for this object before you can flush to it"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ encoder::flush();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model.h b/ml/dlib/dlib/entropy_encoder_model.h
new file mode 100644
index 000000000..465377e97
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEl_
+#define DLIB_ENTROPY_ENCODER_MODEl_
+
+#include "entropy_encoder_model/entropy_encoder_model_kernel_1.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_2.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_3.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_4.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_5.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_6.h"
+#include "entropy_encoder_model/entropy_encoder_model_kernel_c.h"
+
+#include "conditioning_class.h"
+#include "memory_manager.h"
+#include "sliding_buffer.h"
+
+
+namespace dlib
+{
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ class entropy_encoder_model
+ {
+ entropy_encoder_model() {}
+
+ typedef typename conditioning_class<alphabet_size+1>::kernel_1a cc1;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_2a cc2;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_3a cc3;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4a cc4a;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4b cc4b;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4c cc4c;
+ typedef typename conditioning_class<alphabet_size+1>::kernel_4d cc4d;
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1
+ typedef entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc1>
+ kernel_1a;
+ typedef entropy_encoder_model_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ typedef entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc2>
+ kernel_1b;
+ typedef entropy_encoder_model_kernel_c<kernel_1b>
+ kernel_1b_c;
+
+ typedef entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc3>
+ kernel_1c;
+ typedef entropy_encoder_model_kernel_c<kernel_1c>
+ kernel_1c_c;
+
+ // --------------------
+
+ // kernel_2
+ typedef entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc1,cc1>
+ kernel_2a;
+ typedef entropy_encoder_model_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+ typedef entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc2,cc2>
+ kernel_2b;
+ typedef entropy_encoder_model_kernel_c<kernel_2b>
+ kernel_2b_c;
+
+ typedef entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc3,cc3>
+ kernel_2c;
+ typedef entropy_encoder_model_kernel_c<kernel_2c>
+ kernel_2c_c;
+
+ typedef entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc2,cc4b>
+ kernel_2d;
+ typedef entropy_encoder_model_kernel_c<kernel_2d>
+ kernel_2d_c;
+
+ // --------------------
+
+ // kernel_3
+ typedef entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc1,cc4b>
+ kernel_3a;
+ typedef entropy_encoder_model_kernel_c<kernel_3a>
+ kernel_3a_c;
+
+ typedef entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc2,cc4b>
+ kernel_3b;
+ typedef entropy_encoder_model_kernel_c<kernel_3b>
+ kernel_3b_c;
+
+ typedef entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc3,cc4b>
+ kernel_3c;
+ typedef entropy_encoder_model_kernel_c<kernel_3c>
+ kernel_3c_c;
+
+ // --------------------
+
+ // kernel_4
+ typedef entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,200000,4>
+ kernel_4a;
+ typedef entropy_encoder_model_kernel_c<kernel_4a>
+ kernel_4a_c;
+
+ typedef entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,1000000,5>
+ kernel_4b;
+ typedef entropy_encoder_model_kernel_c<kernel_4b>
+ kernel_4b_c;
+
+ // --------------------
+
+ // kernel_5
+ typedef entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,200000,4>
+ kernel_5a;
+ typedef entropy_encoder_model_kernel_c<kernel_5a>
+ kernel_5a_c;
+
+ typedef entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,1000000,5>
+ kernel_5b;
+ typedef entropy_encoder_model_kernel_c<kernel_5b>
+ kernel_5b_c;
+
+ typedef entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,2500000,7>
+ kernel_5c;
+ typedef entropy_encoder_model_kernel_c<kernel_5c>
+ kernel_5c_c;
+
+ // --------------------
+
+ // kernel_6
+ typedef entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>
+ kernel_6a;
+ typedef entropy_encoder_model_kernel_c<kernel_6a>
+ kernel_6a_c;
+
+
+
+ };
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEl_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h
new file mode 100644
index 000000000..29c82e5be
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h
@@ -0,0 +1,167 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc
+ >
+ class entropy_encoder_model_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+
+ INITIAL VALUE
+ Initially this object's finite context model is empty
+
+ CONVENTION
+ &get_entropy_encoder() == coder
+ &order_0.get_global_state() == &gs
+
+ This is an order-0 model. The last symbol in the order-0 context is
+ an escape into the order minus 1 context.
+ !*/
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_1 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_1 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_encoder& coder;
+ typename cc::global_state_type gs;
+ cc order_0;
+
+ // restricted functions
+ entropy_encoder_model_kernel_1(entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>&); // copy constructor
+ entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>& operator=(entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc
+ >
+ entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>::
+ entropy_encoder_model_kernel_1 (
+ entropy_encoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc
+ >
+ entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>::
+ ~entropy_encoder_model_kernel_1 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc
+ >
+ void entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>::
+ clear(
+ )
+ {
+ order_0.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc
+ >
+ void entropy_encoder_model_kernel_1<alphabet_size,entropy_encoder,cc>::
+ encode (
+ unsigned long symbol
+ )
+ {
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+
+ // if we have seen this symbol in the order-0 context
+ if (order_0.get_range(symbol,low_count,high_count,total_count))
+ {
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ return;
+ }
+
+ // if we are here then the symbol does not appear in the order-0 context
+
+
+ // since we have never seen the current symbol in this context
+ // escape from order-0 context
+ order_0.get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+ // increment the count for the escape symbol
+ order_0.increment_count(alphabet_size);
+
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+
+ // use order minus one context
+ coder.encode(symbol,symbol+1,alphabet_size);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h
new file mode 100644
index 000000000..08a16cae0
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h
@@ -0,0 +1,246 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename ccbig
+ >
+ class entropy_encoder_model_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+ this will be used for the order-0 context
+
+ REQUIREMENTS ON ccbig
+ ccbig is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ ccbig::get_alphabet_size() == alphabet_size+1
+ this will be used for the order-1 context
+
+ INITIAL VALUE
+ Initially this object's finite context model is empty
+ previous_symbol == 0
+
+ CONVENTION
+ &get_entropy_encoder() == coder
+ &order_0.get_global_state() == &gs
+ &order_1[i]->get_global_state() == &gsbig
+
+
+ This is an order-1-0 model. The last symbol in the order-0 and order-1
+ context is an escape into the lower context.
+
+ previous_symbol == the last symbol seen
+ !*/
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_2 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_2 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_encoder& coder;
+ typename cc::global_state_type gs;
+ typename ccbig::global_state_type gsbig;
+ cc order_0;
+ ccbig* order_1[alphabet_size];
+ unsigned long previous_symbol;
+
+
+ // restricted functions
+ entropy_encoder_model_kernel_2(entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>&); // copy constructor
+ entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>& operator=(entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename ccbig
+ >
+ entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>::
+ entropy_encoder_model_kernel_2 (
+ entropy_encoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs),
+ previous_symbol(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+
+ unsigned long i;
+ try
+ {
+ for (i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i] = new ccbig(gsbig);
+ }
+ }
+ catch (...)
+ {
+ for (unsigned long j = 0; j < i; ++j)
+ {
+ delete order_1[j];
+ }
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename ccbig
+ >
+ entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>::
+ ~entropy_encoder_model_kernel_2 (
+ )
+ {
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ delete order_1[i];
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename ccbig
+ >
+ void entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>::
+ clear(
+ )
+ {
+ previous_symbol = 0;
+ order_0.clear();
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i]->clear();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename ccbig
+ >
+ void entropy_encoder_model_kernel_2<alphabet_size,entropy_encoder,cc,ccbig>::
+ encode (
+ unsigned long symbol
+ )
+ {
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+
+
+ ccbig& context = *order_1[previous_symbol];
+
+ // if we have seen this symbol in the order-1 context
+ if (context.get_range(symbol,low_count,high_count,total_count))
+ {
+ // update the count for this symbol
+ context.increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ previous_symbol = symbol;
+ return;
+ }
+
+ // we didn't find the symbol in the order-1 context so we must escape to a
+ // lower context.
+
+ // escape to the order-0 context
+ context.get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+
+
+ // increment counts for the escape symbol and the current symbol
+ context.increment_count(alphabet_size);
+ context.increment_count(symbol,2);
+
+ previous_symbol = symbol;
+
+
+
+
+
+ // if we have seen this symbol in the order-0 context
+ if (order_0.get_range(symbol,low_count,high_count,total_count))
+ {
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ return;
+ }
+
+ // if we are here then the symbol does not appear in the order-0 context
+
+
+ // since we have never seen the current symbol in this context
+ // escape from order-0 context
+ order_0.get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+ // increment the count for the escape symbol
+ order_0.increment_count(alphabet_size);
+
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+
+ // use order minus one context
+ coder.encode(symbol,symbol+1,alphabet_size);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h
new file mode 100644
index 000000000..0df28f201
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h
@@ -0,0 +1,341 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename cc_high
+ >
+ class entropy_encoder_model_kernel_3
+ {
+ /*!
+ REQUIREMENTS ON cc
+ cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc::get_alphabet_size() == alphabet_size+1
+
+ REQUIREMENTS ON cc_high
+ cc_high is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ cc_high::get_alphabet_size() == alphabet_size+1
+
+ INITIAL VALUE
+ - Initially this object's finite context model is empty
+ - previous_symbol == 0
+ - previous_symbol2 == 0
+ - order_1 == pointer to an array of alphabet_size elements
+ - order_2 == pointer to an array of alphabet_size*alphabet_size elements
+ - order_2[i] == 0
+
+ CONVENTION
+ &get_entropy_encoder() == coder
+ &order_0.get_global_state() == &gs
+ &order_1[i]->get_global_state() == &gs
+
+ if (order_2[i] != 0) then
+ &order_2[i]->get_global_state() == &gs_high
+
+ This is an order-2-1-0 model. The last symbol in the order-2, order-1 and
+ order-0 contexts is an escape into the lower context.
+
+ previous_symbol == the last symbol seen
+ previous_symbol2 == the symbol we saw before previous_symbol
+ !*/
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_3 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_3 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_encoder& coder;
+ typename cc::global_state_type gs;
+ typename cc_high::global_state_type gs_high;
+ cc order_0;
+ cc** order_1;
+ unsigned long previous_symbol;
+ cc_high** order_2;
+ unsigned long previous_symbol2;
+
+
+ // restricted functions
+ entropy_encoder_model_kernel_3(entropy_encoder_model_kernel_3&); // copy constructor
+ entropy_encoder_model_kernel_3& operator=(entropy_encoder_model_kernel_3&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename cc_high
+ >
+ entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc,cc_high>::
+ entropy_encoder_model_kernel_3 (
+ entropy_encoder& coder_
+ ) :
+ coder(coder_),
+ order_0(gs),
+ order_1(0),
+ previous_symbol(0),
+ order_2(0),
+ previous_symbol2(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+
+ try
+ {
+ order_1 = new cc*[alphabet_size];
+ order_2 = new cc_high*[alphabet_size*alphabet_size];
+ }
+ catch (...)
+ {
+ if (order_1) delete [] order_1;
+ if (order_2) delete [] order_2;
+ throw;
+ }
+
+ unsigned long i;
+
+ for (i = 0; i < (alphabet_size*alphabet_size); ++i)
+ {
+ order_2[i] = 0;
+ }
+
+ try
+ {
+ for (i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i] = new cc(gs);
+ }
+ }
+ catch (...)
+ {
+ for (unsigned long j = 0; j < i; ++j)
+ {
+ delete order_1[j];
+ }
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename cc_high
+ >
+ entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc,cc_high>::
+ ~entropy_encoder_model_kernel_3 (
+ )
+ {
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ delete order_1[i];
+ }
+
+ for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i)
+ {
+ if (order_2[i] != 0)
+ delete order_2[i];
+ }
+ delete [] order_1;
+ delete [] order_2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename cc_high
+ >
+ void entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc,cc_high>::
+ clear(
+ )
+ {
+ previous_symbol = 0;
+ previous_symbol2 = 0;
+ order_0.clear();
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ order_1[i]->clear();
+ }
+
+ for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i)
+ {
+ if (order_2[i] != 0)
+ {
+ delete order_2[i];
+ order_2[i] = 0;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ typename cc,
+ typename cc_high
+ >
+ void entropy_encoder_model_kernel_3<alphabet_size,entropy_encoder,cc,cc_high>::
+ encode (
+ unsigned long symbol
+ )
+ {
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+
+
+ // order-2 context stuff
+ {
+ unsigned long temp = previous_symbol + (previous_symbol2 * alphabet_size);
+ previous_symbol2 = previous_symbol;
+
+ if (order_2[temp] != 0)
+ {
+ if (order_2[temp]->get_range(symbol,low_count,high_count,total_count))
+ {
+ // there was an entry for this symbol in this context
+
+ // update the count for this symbol
+ order_2[temp]->increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ previous_symbol = symbol;
+ return;
+ }
+
+ // there was no entry for this symbol in this context so we must
+ // escape to order-1
+
+ // escape to the order-1 context
+ order_2[temp]->get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+
+ // increment the count for the escape symbol
+ order_2[temp]->increment_count(alphabet_size);
+
+ }
+ else
+ {
+ order_2[temp] = new cc_high(gs_high);
+
+ // in this case the decoder knows to escape to order-1 because
+ // there was no conditioning_class object in this context yet.
+ // so we don't need to actually write the escape symbol
+ }
+
+ // update the count for this symbol in this context
+ order_2[temp]->increment_count(symbol,2);
+ }
+
+
+
+
+ // order-1 context stuff
+ {
+ cc& context = *order_1[previous_symbol];
+
+ // if we have seen this symbol in the order-1 context
+ if (context.get_range(symbol,low_count,high_count,total_count))
+ {
+ // update the count for this symbol
+ context.increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ previous_symbol = symbol;
+ return;
+ }
+
+ // we didn't find the symbol in the order-1 context so we must escape to a
+ // lower context.
+
+ // escape to the order-0 context
+ context.get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+
+
+ // increment counts for the escape symbol and the current symbol
+ context.increment_count(alphabet_size);
+ context.increment_count(symbol,2);
+ }
+
+ previous_symbol = symbol;
+
+
+
+
+
+ // if we have seen this symbol in the order-0 context
+ if (order_0.get_range(symbol,low_count,high_count,total_count))
+ {
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+ // encode this symbol
+ coder.encode(low_count,high_count,total_count);
+ return;
+ }
+
+ // if we are here then the symbol does not appear in the order-0 context
+
+
+ // since we have never seen the current symbol in this context
+ // escape from order-0 context
+ order_0.get_range(alphabet_size,low_count,high_count,total_count);
+ coder.encode(low_count,high_count,total_count);
+ // increment the count for the escape symbol
+ order_0.increment_count(alphabet_size);
+
+ // update the count for this symbol
+ order_0.increment_count(symbol,2);
+
+ // use order minus one context
+ coder.encode(symbol,symbol+1,alphabet_size);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h
new file mode 100644
index 000000000..0e5ae46d3
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h
@@ -0,0 +1,553 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+
+
+namespace dlib
+{
+
+ namespace eemk4
+ {
+ struct node
+ {
+ node* next;
+ node* child_context;
+ node* parent_context;
+
+ unsigned short symbol;
+ unsigned short count;
+ unsigned short total;
+ unsigned short escapes;
+ };
+ }
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ class entropy_encoder_model_kernel_4
+ {
+ /*!
+ REQUIREMENTS ON total_nodes
+ - 4096 < total_nodes
+ - this is the total number of nodes that we will use in the tree
+
+ REQUIREMENTS ON order
+ - 0 <= order
+ - this is the maximum depth-1 the tree will be allowed to go (note
+ that the root level is depth 0).
+
+ GENERAL NOTES
+ This implementation follows more or less the implementation
+ strategy laid out by Alistair Moffat in his paper
+ Implementing the PPM data compression scheme. Published in IEEE
+ Transactions on Communications, 38(11):1917-1921, 1990.
+
+ The escape method used will be method D.
+
+
+ INITIAL VALUE
+ - root == pointer to an array of total_nodes nodes
+ - next_node == 1
+ - cur == root
+ - cur_order = 0
+ - root->next == 0
+ - root->parent_context == 0
+ - root->child_context == 0
+ - root->escapes == 0
+ - root->total == 0
+
+ CONVENTION
+ - &get_entropy_encoder() == coder
+ - root == pointer to an array of total_nodes nodes.
+ this is also the root of the tree.
+
+ - if (next_node < total_nodes) then
+ - next_node == the next node in root that has not yet been allocated
+
+ - root->next == 0
+ - root->parent_context == 0
+
+
+ - for every node in the tree:
+ {
+ - NOTATION:
+ - The "context" of a node is the string of symbols seen
+ when you go from the root of the tree down (down though
+ child context pointers) to the node, including the symbol at
+ the node itself. (note that the context of the root node
+ is "" or the empty string)
+ - A set of nodes is in the same "context set" if all the node's
+ contexts are of length n and all the node's contexts share
+ the same prefix of length n-1.
+ - The "child context set" of a node is a set of nodes with
+ contexts that are one symbol longer and prefixed by the node's
+ context. For example, if a node has a context "abc" then the
+ nodes for contexts "abca", "abcb", "abcc", etc. are all in
+ the child context set of the node.
+ - The "parent context" of a node is the context that is one
+ symbol shorter than the node's context and includes the
+ symbol in the node. So the parent context of a node with
+ context "abcd" would be the context "bcd".
+
+
+ - if (next != 0) then
+ - next == pointer to the next node in the same context set
+ - if (child_context != 0) then
+ - child_context == pointer to the first node of the child
+ context set for this node.
+ - if (parent_context != 0) then
+ - parent_context == pointer to the parent context of this node.
+ - else
+ - this node is the root node of the tree
+
+
+ - if (this is not the root node) then
+ - symbol == the symbol represented with this node
+ - count == the number of times this symbol has been seen in its
+ parent context.
+ - else
+ - the root doesn't have a symbol. i.e. the context for the
+ root node is "" or the empty string.
+
+ - total == The sum of the counts of all the nodes
+ in the child context set + escapes.
+ - escapes == the escape count for the context represented
+ by the node.
+ }
+
+
+ - cur_order < order
+ - cur_order == the depth of the node cur in the tree.
+ (note that the root node has depth 0)
+ - cur == pointer to the node in the tree who's context matches
+ the most recent symbols we have seen.
+
+
+ !*/
+
+ typedef eemk4::node node;
+
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_4 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_4 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ inline eemk4::node* allocate_node (
+ );
+ /*!
+ requires
+ - space_left() == true
+ ensures
+ - returns a pointer to a new node
+ !*/
+
+ inline void destroy_tree (
+ );
+ /*!
+ ensures
+ - deallocates all nodes except the root
+ - #root->child_context == 0
+ - #root->escapes == 0
+ - #root->total == 0
+ - #cur == root
+ - #cur_order == 0
+ !*/
+
+
+ inline bool space_left (
+ ) const;
+ /*!
+ ensures
+ - returns true if there is at least 1 free node left.
+ - returns false otherwise
+ !*/
+
+
+ inline void scale_counts (
+ node* n
+ );
+ /*!
+ ensures
+ - divides all the counts in the child context set of n by 2.
+ - none of the nodes in the child context set will have a count of 0
+ !*/
+
+
+ unsigned long next_node;
+ entropy_encoder& coder;
+ node* root;
+ node* cur;
+ unsigned long cur_order;
+
+
+ // restricted functions
+ entropy_encoder_model_kernel_4(entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>&); // copy constructor
+ entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>& operator=(entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ entropy_encoder_model_kernel_4 (
+ entropy_encoder& coder_
+ ) :
+ next_node(1),
+ coder(coder_),
+ cur_order(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ COMPILE_TIME_ASSERT( 4096 < total_nodes );
+
+ root = new node[total_nodes];
+ cur = root;
+
+ root->child_context = 0;
+ root->escapes = 0;
+ root->next = 0;
+ root->parent_context = 0;
+ root->total = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ ~entropy_encoder_model_kernel_4 (
+ )
+ {
+ delete [] root;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ clear(
+ )
+ {
+ destroy_tree();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ encode (
+ unsigned long sym
+ )
+ {
+ unsigned short symbol = static_cast<unsigned short>(sym);
+ node* temp = cur;
+ cur = 0;
+ unsigned short low_count, high_count, total_count;
+ node* new_node = 0;
+
+ // local_order will track the level of temp in the tree
+ unsigned long local_order = cur_order;
+
+ while (true)
+ {
+ high_count = 0;
+ if (space_left())
+ {
+ total_count = temp->total;
+
+ if (total_count > 0)
+ {
+ // check if we need to scale the counts
+ if (total_count > 10000)
+ {
+ scale_counts(temp);
+ total_count = temp->total;
+ }
+
+ // find either the symbol we are looking for or the
+ // end of the context set
+ node* n = temp->child_context;
+ node* last = 0;
+ while (true)
+ {
+ high_count += n->count;
+
+ if (n->symbol == symbol || n->next == 0)
+ break;
+ last = n;
+ n = n->next;
+ }
+
+ low_count = high_count - n->count;
+
+ // if we found the symbol
+ if (n->symbol == symbol)
+ {
+ if (new_node != 0)
+ {
+ new_node->parent_context = n;
+ }
+
+ coder.encode(low_count,high_count,total_count);
+ n->count += 8;
+ temp->total += 8;
+
+
+ // move this node to the front
+ if (last)
+ {
+ last->next = n->next;
+ n->next = temp->child_context;
+ temp->child_context = n;
+ }
+
+
+ if (cur == 0)
+ {
+ if (local_order < order)
+ {
+ cur_order = local_order+1;
+ cur = n;
+ }
+ else
+ {
+ cur = n->parent_context;
+ cur_order = local_order;
+ }
+ }
+
+ break;
+
+ }
+ // if we hit the end of the context set without finding the symbol
+ else
+ {
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ n->next = new_node;
+
+ // write an escape to a lower context
+ coder.encode(high_count,total_count,total_count);
+ }
+
+ }
+ else // if (total_count == 0)
+ {
+ // this means that temp->child_context == 0 so we should make
+ // a new node here.
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ temp->child_context = new_node;
+ }
+
+ if (cur == 0 && local_order < order)
+ {
+ cur = new_node;
+ cur_order = local_order+1;
+ }
+
+ // fill out the new node
+ new_node->child_context = 0;
+ new_node->count = 4;
+ new_node->escapes = 0;
+ new_node->next = 0;
+ new_node->symbol = static_cast<unsigned short>(symbol);
+ new_node->total = 0;
+ temp->escapes += 4;
+ temp->total += 8;
+
+
+ if (temp != root)
+ {
+ temp = temp->parent_context;
+ --local_order;
+ continue;
+ }
+
+ // since this is the root we are going to the order-(-1) context
+ // so we can just take care of that here.
+ new_node->parent_context = root;
+ coder.encode(symbol,symbol+1,alphabet_size);
+
+ if (cur == 0)
+ {
+ cur = root;
+ cur_order = 0;
+ }
+ break;
+ }
+ else
+ {
+ // there isn't enough space so we should rebuild the tree
+ destroy_tree();
+ temp = cur;
+ local_order = cur_order;
+ cur = 0;
+ new_node = 0;
+ }
+ } // while (true)
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ eemk4::node* entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ allocate_node (
+ )
+ {
+ node* temp;
+ temp = root + next_node;
+ ++next_node;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ destroy_tree (
+ )
+ {
+ next_node = 1;
+ root->child_context = 0;
+ root->escapes = 0;
+ root->total = 0;
+ cur = root;
+ cur_order = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ space_left (
+ ) const
+ {
+ return (next_node < total_nodes);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_4<alphabet_size,entropy_encoder,total_nodes,order>::
+ scale_counts (
+ node* temp
+ )
+ {
+ if (temp->escapes > 1)
+ temp->escapes >>= 1;
+ temp->total = temp->escapes;
+
+ node* n = temp->child_context;
+ while (n != 0)
+ {
+ if (n->count > 1)
+ n->count >>= 1;
+
+ temp->total += n->count;
+ n = n->next;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h
new file mode 100644
index 000000000..6c0c30426
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h
@@ -0,0 +1,817 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+
+
+namespace dlib
+{
+
+ namespace eemk5
+ {
+ struct node
+ {
+ node* next;
+ node* child_context;
+ node* parent_context;
+
+ unsigned short symbol;
+ unsigned short count;
+ unsigned short total;
+ unsigned short escapes;
+ };
+ }
+
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ class entropy_encoder_model_kernel_5
+ {
+ /*!
+ REQUIREMENTS ON total_nodes
+ - 4096 < total_nodes
+ - this is the total number of nodes that we will use in the tree
+
+ REQUIREMENTS ON order
+ - 0 <= order
+ - this is the maximum depth-1 the tree will be allowed to go (note
+ that the root level is depth 0).
+
+ GENERAL NOTES
+ This implementation follows more or less the implementation
+ strategy laid out by Alistair Moffat in his paper
+ Implementing the PPM data compression scheme. Published in IEEE
+ Transactions on Communications, 38(11):1917-1921, 1990.
+
+ The escape method used will be method D.
+
+ This also uses Dmitry Shkarin's Information Inheritance scheme.
+ (described in "PPM: one step to practicality" and "Improving the
+ Efficiency of the PPM Algorithm")
+
+
+ INITIAL VALUE
+ - root == pointer to an array of total_nodes nodes
+ - next_node == 1
+ - cur == root
+ - cur_order = 0
+ - root->next == 0
+ - root->parent_context == 0
+ - root->child_context == 0
+ - root->escapes == 0
+ - root->total == 0
+ - stack_size == 0
+ - exc_used == false
+ - for all i: exc[i] == 0
+
+ CONVENTION
+ - pop() == stack[stack_size-1].n and stack[stack_size-1].nc
+ - exc_used == something_is_excluded()
+ - is_excluded(symbol) == bit symbol&0x1F from exc[symbol>>5]
+ - &get_entropy_encoder() == coder
+ - root == pointer to an array of total_nodes nodes.
+ this is also the root of the tree.
+ - if (next_node < total_nodes) then
+ - next_node == the next node in root that has not yet been allocated
+
+ - root->next == 0
+ - root->parent_context == 0
+
+
+ - for every node in the tree:
+ {
+ - NOTATION:
+ - The "context" of a node is the string of symbols seen
+ when you go from the root of the tree down (down though
+ child context pointers) to the node, including the symbol at
+ the node itself. (note that the context of the root node
+ is "" or the empty string)
+ - A set of nodes is in the same "context set" if all the node's
+ contexts are of length n and all the node's contexts share
+ the same prefix of length n-1.
+ - The "child context set" of a node is a set of nodes with
+ contexts that are one symbol longer and prefixed by the node's
+ context. For example, if a node has a context "abc" then the
+ nodes for contexts "abca", "abcb", "abcc", etc. are all in
+ the child context set of the node.
+ - The "parent context" of a node is the context that is one
+ symbol shorter than the node's context and includes the
+ symbol in the node. So the parent context of a node with
+ context "abcd" would be the context "bcd".
+
+
+ - if (next != 0) then
+ - next == pointer to the next node in the same context set
+ - if (child_context != 0) then
+ - child_context == pointer to the first node of the child
+ context set for this node.
+ - escapes > 0
+ - if (parent_context != 0) then
+ - parent_context == pointer to the parent context of this node.
+ - else
+ - this node is the root node of the tree
+
+
+ - if (this is not the root node) then
+ - symbol == the symbol represented with this node
+ - count == the number of times this symbol has been seen in its
+ parent context.
+ - else
+ - the root doesn't have a symbol. i.e. the context for the
+ root node is "" or the empty string.
+
+ - total == The sum of the counts of all the nodes
+ in the child context set + escapes.
+ - escapes == the escape count for the context represented
+ by the node.
+ - count > 0
+ }
+
+
+ - cur_order < order
+ - cur_order == the depth of the node cur in the tree.
+ (note that the root node has depth 0)
+ - cur == pointer to the node in the tree who's context matches
+ the most recent symbols we have seen.
+
+
+ !*/
+
+ typedef eemk5::node node;
+
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_5 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_5 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ inline eemk5::node* allocate_node (
+ );
+ /*!
+ requires
+ - space_left() == true
+ ensures
+ - returns a pointer to a new node
+ !*/
+
+ inline bool space_left (
+ ) const;
+ /*!
+ ensures
+ - returns true if there is at least 1 free node left.
+ - returns false otherwise
+ !*/
+
+ inline void exclude (
+ unsigned short symbol
+ );
+ /*!
+ ensures
+ - #is_excluded(symbol) == true
+ - #something_is_excluded() == true
+ !*/
+
+ inline bool something_is_excluded (
+ );
+ /*!
+ ensures
+ - returns true if some symbol has been excluded.
+ returns false otherwise
+ !*/
+
+ inline bool is_excluded (
+ unsigned short symbol
+ );
+ /*!
+ ensures
+ - if (symbol has been excluded) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ inline void clear_exclusions (
+ );
+ /*!
+ ensures
+ - for all symbols #is_excluded(symbol) == false
+ - #something_is_excluded() == true
+ !*/
+
+ inline void scale_counts (
+ node* n
+ );
+ /*!
+ ensures
+ - divides all the counts in the child context set of n by 2.
+ - none of the nodes in the child context set will have a count of 0
+ !*/
+
+ inline void push (
+ node* n,
+ node* nc
+ );
+ /*!
+ requires
+ - stack_size < order
+ ensures
+ - #pop(a,b): a == n && b == nc
+ !*/
+
+ inline void pop (
+ node*& n,
+ node*& nc
+ );
+ /*!
+ requires
+ - stack_size > 0
+ ensures
+ - returns the two nodes at the top of the stack
+ !*/
+
+ struct nodes
+ {
+ node* n;
+ node* nc;
+ };
+
+ unsigned long next_node;
+ entropy_encoder& coder;
+ node* root;
+ node* cur;
+ unsigned long cur_order;
+ unsigned long exc[alphabet_size/32+1];
+ bool exc_used;
+ nodes stack[order+1];
+ unsigned long stack_size;
+
+ // restricted functions
+ entropy_encoder_model_kernel_5(entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>&); // copy constructor
+ entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>& operator=(entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ entropy_encoder_model_kernel_5 (
+ entropy_encoder& coder_
+ ) :
+ next_node(1),
+ coder(coder_),
+ cur_order(0),
+ stack_size(0)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ COMPILE_TIME_ASSERT( 4096 < total_nodes );
+
+ root = new node[total_nodes];
+ cur = root;
+
+ root->child_context = 0;
+ root->escapes = 0;
+ root->next = 0;
+ root->parent_context = 0;
+ root->total = 0;
+
+ clear_exclusions();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ ~entropy_encoder_model_kernel_5 (
+ )
+ {
+ delete [] root;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ clear(
+ )
+ {
+ next_node = 1;
+ root->child_context = 0;
+ root->escapes = 0;
+ root->total = 0;
+ cur = root;
+ cur_order = 0;
+ stack_size = 0;
+
+ clear_exclusions();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ encode (
+ unsigned long sym
+ )
+ {
+ unsigned short symbol = static_cast<unsigned short>(sym);
+ node* temp = cur;
+ cur = 0;
+ unsigned short low_count, high_count, total_count;
+ node* new_node = 0;
+
+ // local_order will track the level of temp in the tree
+ unsigned long local_order = cur_order;
+
+
+ unsigned short c; // c == t(a|sk)
+ unsigned short t; // t == T(sk)
+
+
+ if (something_is_excluded())
+ clear_exclusions();
+
+ while (true)
+ {
+ low_count = 0;
+ high_count = 0;
+ if (space_left())
+ {
+ total_count = temp->total;
+
+ if (total_count > 0)
+ {
+ // check if we need to scale the counts
+ if (total_count > 10000)
+ {
+ scale_counts(temp);
+ total_count = temp->total;
+ }
+
+
+ // find the symbol we are looking for and put a pointer to it
+ // into found_symbol. If it isn't found then found_symbol == 0.
+ // also, low_count and high_count will be correctly set.
+ node* n = temp->child_context;
+ node* found_symbol = 0;
+ node* last = 0;
+ if (something_is_excluded())
+ {
+ node* templast = 0;
+ while (true)
+ {
+ if (is_excluded(n->symbol) == false)
+ {
+ exclude(n->symbol);
+ if (found_symbol == 0)
+ {
+ high_count += n->count;
+ if (n->symbol == symbol)
+ {
+ found_symbol = n;
+ last = templast;
+ low_count = high_count - n->count;
+ }
+ }
+ }
+ else
+ {
+ total_count -= n->count;
+ }
+
+ if (n->next == 0)
+ break;
+ templast = n;
+ n = n->next;
+ }
+ }
+ else
+ {
+ while (true)
+ {
+ high_count += n->count;
+ exclude(n->symbol);
+
+ if (n->symbol == symbol)
+ {
+ found_symbol = n;
+ low_count = high_count - n->count;
+ break;
+ }
+
+ if (n->next == 0)
+ break;
+ last = n;
+ n = n->next;
+ }
+ }
+
+
+
+
+
+ // if we found the symbol
+ if (found_symbol)
+ {
+ n = found_symbol;
+ if (new_node != 0)
+ {
+ new_node->parent_context = found_symbol;
+ }
+
+
+ coder.encode(low_count,high_count,total_count);
+ c = n->count += 8;
+ t = temp->total += 8;
+
+ // move this node to the front
+ if (last)
+ {
+ last->next = n->next;
+ n->next = temp->child_context;
+ temp->child_context = n;
+ }
+
+
+ if (cur == 0)
+ {
+ if (local_order >= order)
+ {
+ cur = n->parent_context;
+ cur_order = local_order;
+ }
+ else
+ {
+ cur_order = local_order+1;
+ cur = n;
+ }
+ }
+
+ break;
+
+ }
+ // if we hit the end of the context set without finding the symbol
+ else
+ {
+ // finish excluding all the symbols
+ while (n->next)
+ {
+ exclude(n->symbol);
+ n = n->next;
+ }
+
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ n->next = new_node;
+
+ // write an escape to a lower context
+ coder.encode(high_count,total_count,total_count);
+ }
+
+ }
+ else // if (total_count == 0)
+ {
+ // this means that temp->child_context == 0 so we should make
+ // a new node here.
+ if (new_node != 0)
+ {
+ new_node->parent_context = allocate_node();
+ new_node = new_node->parent_context;
+ }
+ else
+ {
+ new_node = allocate_node();
+ }
+
+ temp->child_context = new_node;
+ }
+
+ if (cur == 0 && local_order < order)
+ {
+ cur = new_node;
+ cur_order = local_order+1;
+ }
+
+ // fill out the new node
+ new_node->child_context = 0;
+ new_node->escapes = 0;
+ new_node->next = 0;
+ new_node->total = 0;
+ push(new_node,temp);
+
+ if (temp != root)
+ {
+ temp = temp->parent_context;
+ --local_order;
+ continue;
+ }
+
+ t = 2056;
+ c = 8;
+
+ // since this is the root we are going to the order-(-1) context
+ // so we can just take care of that here.
+ new_node->parent_context = root;
+ coder.encode(symbol,symbol+1,alphabet_size);
+
+ if (cur == 0)
+ {
+ cur = root;
+ cur_order = 0;
+ }
+ break;
+ }
+ else
+ {
+ // there isn't enough space so we should throw away the tree
+ clear();
+ temp = cur;
+ local_order = cur_order;
+ cur = 0;
+ new_node = 0;
+ }
+ } // while (true)
+
+
+ // initialize the counts and symbol for any new nodes we have added
+ // to the tree.
+ node* n, *nc;
+ while (stack_size > 0)
+ {
+ pop(n,nc);
+
+ n->symbol = static_cast<unsigned short>(symbol);
+
+ // if nc is not a determnistic context
+ if (nc->total)
+ {
+ unsigned long temp2 = t-c+nc->total - nc->escapes - nc->escapes;
+ unsigned long temp = nc->total;
+ temp *= c;
+ temp /= (temp2|1); // this oring by 1 is just to make sure that temp2 is never zero
+ temp += 2;
+ if (temp > 50000) temp = 50000;
+ n->count = static_cast<unsigned short>(temp);
+
+
+ nc->escapes += 4;
+ nc->total += static_cast<unsigned short>(temp) + 4;
+ }
+ else
+ {
+ n->count = 3 + 5*(c)/(t-c);
+
+ nc->escapes = 4;
+ nc->total = n->count + 4;
+ }
+
+ while (nc->total > 10000)
+ {
+ scale_counts(nc);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ eemk5::node* entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ allocate_node (
+ )
+ {
+ node* temp;
+ temp = root + next_node;
+ ++next_node;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ space_left (
+ ) const
+ {
+ return (next_node < total_nodes);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ exclude (
+ unsigned short symbol
+ )
+ {
+ exc_used = true;
+ unsigned long temp = 1;
+ temp <<= symbol&0x1F;
+ exc[symbol>>5] |= temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ is_excluded (
+ unsigned short symbol
+ )
+ {
+ unsigned long temp = 1;
+ temp <<= symbol&0x1F;
+ return ((exc[symbol>>5]&temp) != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ clear_exclusions (
+ )
+ {
+ exc_used = false;
+ for (unsigned long i = 0; i < alphabet_size/32+1; ++i)
+ {
+ exc[i] = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ bool entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ something_is_excluded (
+ )
+ {
+ return exc_used;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ push (
+ node* n,
+ node* nc
+ )
+ {
+ stack[stack_size].n = n;
+ stack[stack_size].nc = nc;
+ ++stack_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ pop (
+ node*& n,
+ node*& nc
+ )
+ {
+ --stack_size;
+ n = stack[stack_size].n;
+ nc = stack[stack_size].nc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder,
+ unsigned long total_nodes,
+ unsigned long order
+ >
+ void entropy_encoder_model_kernel_5<alphabet_size,entropy_encoder,total_nodes,order>::
+ scale_counts (
+ node* temp
+ )
+ {
+ if (temp->escapes > 1)
+ temp->escapes >>= 1;
+ temp->total = temp->escapes;
+
+ node* n = temp->child_context;
+ while (n != 0)
+ {
+ if (n->count > 1)
+ n->count >>= 1;
+
+ temp->total += n->count;
+ n = n->next;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_
+
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h
new file mode 100644
index 000000000..2199bfbe4
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h
@@ -0,0 +1,127 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_
+
+#include "../algs.h"
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ class entropy_encoder_model_kernel_6
+ {
+ /*!
+ INITIAL VALUE
+ Initially this object's finite context model is empty
+
+ CONVENTION
+ &get_entropy_encoder() == coder
+
+ This is an order-(-1) model. So it doesn't really do anything.
+ Every symbol has the same probability.
+ !*/
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model_kernel_6 (
+ entropy_encoder& coder
+ );
+
+ virtual ~entropy_encoder_model_kernel_6 (
+ );
+
+ inline void clear(
+ );
+
+ inline void encode (
+ unsigned long symbol
+ );
+
+ entropy_encoder& get_entropy_encoder (
+ ) { return coder; }
+
+ static unsigned long get_alphabet_size (
+ ) { return alphabet_size; }
+
+ private:
+
+ entropy_encoder& coder;
+
+ // restricted functions
+ entropy_encoder_model_kernel_6(entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>&); // copy constructor
+ entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>& operator=(entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>::
+ entropy_encoder_model_kernel_6 (
+ entropy_encoder& coder_
+ ) :
+ coder(coder_)
+ {
+ COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>::
+ ~entropy_encoder_model_kernel_6 (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ void entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>::
+ clear(
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ void entropy_encoder_model_kernel_6<alphabet_size,entropy_encoder>::
+ encode (
+ unsigned long symbol
+ )
+ {
+ // use order minus one context
+ coder.encode(symbol,symbol+1,alphabet_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
new file mode 100644
index 000000000..fb5f01bc7
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_
+#ifdef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned long alphabet_size,
+ typename entropy_encoder
+ >
+ class entropy_encoder_model
+ {
+ /*!
+ REQUIREMENTS ON alphabet_size
+ 1 < alphabet_size < 65535
+
+ REQUIREMENTS ON entropy_encoder
+ is an implementation of entropy_encoder/entropy_encoder_kernel_abstract.h
+
+ INITIAL VALUE
+ Initially this object is at some predefined empty or ground state.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some kind of statistical model. You
+ can use it to write symbols to an entropy_encoder and it will calculate
+ the cumulative counts/probabilities and manage contexts for you.
+
+ Note that all implementations of entropy_encoder_model and
+ entropy_decoder_model are paired. This means that if you use
+ entropy_encoder_model_kernel_n to encode something then you must
+ use the corresponding entropy_decoder_model_kernel_n to decode it.
+
+ Also note that this object does not perform any buffering of symbols. It
+ writes them to its associated entropy_encoder immediately.
+ This makes it safe to use multiple entropy_encoder_model objects with
+ a single entropy_encoder without them trampling each other.
+ !*/
+
+ public:
+
+ typedef entropy_encoder entropy_encoder_type;
+
+ entropy_encoder_model (
+ entropy_encoder& coder
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - &#get_entropy_encoder() == &coder
+ throws
+ - any exception
+ !*/
+
+ virtual ~entropy_encoder_model (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - does not modify get_entropy_encoder()
+ throws
+ - any exception
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void encode (
+ unsigned long symbol
+ );
+ /*!
+ requires
+ - symbol < alphabet_size
+ ensures
+ - encodes and writes the symbol to get_entropy_encoder().
+ This also means that there is no internal buffering. symbol is
+ written immediately to the entropy_encoder.
+ throws
+ - any exception
+ If this exception is thrown then #*this is unusable until
+ clear() is called and succeeds.
+ !*/
+
+ entropy_encoder& get_entropy_encoder (
+ );
+ /*!
+ ensures
+ - returns a reference to the entropy_encoder used by *this
+ !*/
+
+ static unsigned long get_alphabet_size (
+ );
+ /*!
+ ensures
+ - returns alphabet_size
+ !*/
+
+ private:
+
+ // restricted functions
+ entropy_encoder_model(entropy_encoder_model<alphabet_size,entropy_encoder>&); // copy constructor
+ entropy_encoder_model<alphabet_size,entropy_encoder>& operator=(entropy_encoder_model<alphabet_size,entropy_encoder>&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h
new file mode 100644
index 000000000..4637ddd1e
--- /dev/null
+++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h
@@ -0,0 +1,65 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_
+#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_
+
+#include "entropy_encoder_model_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename eem_base
+ >
+ class entropy_encoder_model_kernel_c : public eem_base
+ {
+ const unsigned long alphabet_size;
+ typedef typename eem_base::entropy_encoder_type entropy_encoder;
+
+ public:
+
+ entropy_encoder_model_kernel_c (
+ entropy_encoder& coder
+ ) : eem_base(coder), alphabet_size(eem_base::get_alphabet_size()) {}
+
+ void encode (
+ unsigned long symbol
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename eem_base
+ >
+ void entropy_encoder_model_kernel_c<eem_base>::
+ encode (
+ unsigned long symbol
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(symbol < alphabet_size,
+ "\tvoid entropy_encoder_model::encode()"
+ << "\n\tthe symbol must be in the range 0 to alphabet_size-1"
+ << "\n\talphabet_size: " << alphabet_size
+ << "\n\tsymbol: " << symbol
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ eem_base::encode(symbol);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_
+
diff --git a/ml/dlib/dlib/error.h b/ml/dlib/dlib/error.h
new file mode 100644
index 000000000..ce9b95b1a
--- /dev/null
+++ b/ml/dlib/dlib/error.h
@@ -0,0 +1,449 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ERROr_
+#define DLIB_ERROr_
+
+#include <string>
+#include <new> // for std::bad_alloc
+#include <iostream>
+#include <cassert>
+#include <cstdlib>
+#include <exception>
+
+// -------------------------------
+// ------ exception classes ------
+// -------------------------------
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ enum error_type
+ {
+ EPORT_IN_USE,
+ ETIMEOUT,
+ ECONNECTION,
+ ELISTENER,
+ ERESOLVE,
+ EMONITOR,
+ ECREATE_THREAD,
+ ECREATE_MUTEX,
+ ECREATE_SIGNALER,
+ EUNSPECIFIED,
+ EGENERAL_TYPE1,
+ EGENERAL_TYPE2,
+ EGENERAL_TYPE3,
+ EINVALID_OPTION,
+ ETOO_FEW_ARGS,
+ ETOO_MANY_ARGS,
+ ESOCKET,
+ ETHREAD,
+ EGUI,
+ EFATAL,
+ EBROKEN_ASSERT,
+ EIMAGE_LOAD,
+ EDIR_CREATE,
+ EINCOMPATIBLE_OPTIONS,
+ EMISSING_REQUIRED_OPTION,
+ EINVALID_OPTION_ARG,
+ EMULTIPLE_OCCURANCES,
+ ECONFIG_READER,
+ EIMAGE_SAVE,
+ ECAST_TO_STRING,
+ ESTRING_CAST,
+ EUTF8_TO_UTF32,
+ EOPTION_PARSE
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ // the base exception class
+ class error : public std::exception
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the base exception class for the dlib library. i.e. all
+ exceptions in this library inherit from this class.
+ !*/
+
+ public:
+ error(
+ error_type t,
+ const std::string& a
+ ): info(a), type(t) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == a
+ !*/
+
+ error(
+ error_type t
+ ): type(t) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == ""
+ !*/
+
+ error(
+ const std::string& a
+ ): info(a), type(EUNSPECIFIED) {}
+ /*!
+ ensures
+ - #type == EUNSPECIFIED
+ - #info == a
+ !*/
+
+ error(
+ ): type(EUNSPECIFIED) {}
+ /*!
+ ensures
+ - #type == EUNSPECIFIED
+ - #info == ""
+ !*/
+
+ virtual ~error(
+ ) throw() {}
+ /*!
+ ensures
+ - does nothing
+ !*/
+
+ const char* what(
+ ) const throw()
+ /*!
+ ensures
+ - if (info.size() != 0) then
+ - returns info.c_str()
+ - else
+ - returns type_to_string(type)
+ !*/
+ {
+ if (info.size() > 0)
+ return info.c_str();
+ else
+ return type_to_string();
+ }
+
+ const char* type_to_string (
+ ) const throw()
+ /*!
+ ensures
+ - returns a string that names the contents of the type member.
+ !*/
+ {
+ if ( type == EPORT_IN_USE) return "EPORT_IN_USE";
+ else if ( type == ETIMEOUT) return "ETIMEOUT";
+ else if ( type == ECONNECTION) return "ECONNECTION";
+ else if ( type == ELISTENER) return "ELISTENER";
+ else if ( type == ERESOLVE) return "ERESOLVE";
+ else if ( type == EMONITOR) return "EMONITOR";
+ else if ( type == ECREATE_THREAD) return "ECREATE_THREAD";
+ else if ( type == ECREATE_MUTEX) return "ECREATE_MUTEX";
+ else if ( type == ECREATE_SIGNALER) return "ECREATE_SIGNALER";
+ else if ( type == EUNSPECIFIED) return "EUNSPECIFIED";
+ else if ( type == EGENERAL_TYPE1) return "EGENERAL_TYPE1";
+ else if ( type == EGENERAL_TYPE2) return "EGENERAL_TYPE2";
+ else if ( type == EGENERAL_TYPE3) return "EGENERAL_TYPE3";
+ else if ( type == EINVALID_OPTION) return "EINVALID_OPTION";
+ else if ( type == ETOO_FEW_ARGS) return "ETOO_FEW_ARGS";
+ else if ( type == ETOO_MANY_ARGS) return "ETOO_MANY_ARGS";
+ else if ( type == ESOCKET) return "ESOCKET";
+ else if ( type == ETHREAD) return "ETHREAD";
+ else if ( type == EGUI) return "EGUI";
+ else if ( type == EFATAL) return "EFATAL";
+ else if ( type == EBROKEN_ASSERT) return "EBROKEN_ASSERT";
+ else if ( type == EIMAGE_LOAD) return "EIMAGE_LOAD";
+ else if ( type == EDIR_CREATE) return "EDIR_CREATE";
+ else if ( type == EINCOMPATIBLE_OPTIONS) return "EINCOMPATIBLE_OPTIONS";
+ else if ( type == EMISSING_REQUIRED_OPTION) return "EMISSING_REQUIRED_OPTION";
+ else if ( type == EINVALID_OPTION_ARG) return "EINVALID_OPTION_ARG";
+ else if ( type == EMULTIPLE_OCCURANCES) return "EMULTIPLE_OCCURANCES";
+ else if ( type == ECONFIG_READER) return "ECONFIG_READER";
+ else if ( type == EIMAGE_SAVE) return "EIMAGE_SAVE";
+ else if ( type == ECAST_TO_STRING) return "ECAST_TO_STRING";
+ else if ( type == ESTRING_CAST) return "ESTRING_CAST";
+ else if ( type == EUTF8_TO_UTF32) return "EUTF8_TO_UTF32";
+ else if ( type == EOPTION_PARSE) return "EOPTION_PARSE";
+ else return "undefined error type";
+ }
+
+ const std::string info; // info about the error
+ const error_type type; // the type of the error
+
+ private:
+ const error& operator=(const error&);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class fatal_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ As the name says, this object represents some kind of fatal error.
+ That is, it represents an unrecoverable error and any program that
+ throws this exception is, by definition, buggy and needs to be fixed.
+
+ Note that a fatal_error exception can only be thrown once. The second
+ time an application attempts to construct a fatal_error it will be
+ immediately aborted and an error message will be printed to std::cerr.
+ The reason for this is because the first fatal_error was apparently ignored
+ so the second fatal_error is going to make itself impossible to ignore
+ by calling abort. The lesson here is that you should not try to ignore
+ fatal errors.
+
+ This is also the exception thrown by the DLIB_ASSERT and DLIB_CASSERT macros.
+ !*/
+
+ public:
+ fatal_error(
+ error_type t,
+ const std::string& a
+ ): error(t,a) {check_for_previous_fatal_errors();}
+ /*!
+ ensures
+ - #type == t
+ - #info == a
+ !*/
+
+ fatal_error(
+ error_type t
+ ): error(t) {check_for_previous_fatal_errors();}
+ /*!
+ ensures
+ - #type == t
+ - #info == ""
+ !*/
+
+ fatal_error(
+ const std::string& a
+ ): error(EFATAL,a) {check_for_previous_fatal_errors();}
+ /*!
+ ensures
+ - #type == EFATAL
+ - #info == a
+ !*/
+
+ fatal_error(
+ ): error(EFATAL) {check_for_previous_fatal_errors();}
+ /*!
+ ensures
+ - #type == EFATAL
+ - #info == ""
+ !*/
+
+ private:
+
+ static inline char* message ()
+ {
+ static char buf[2000];
+ buf[1999] = '\0'; // just to be extra safe
+ return buf;
+ }
+
+ static inline void dlib_fatal_error_terminate (
+ )
+ {
+ std::cerr << "\n**************************** FATAL ERROR DETECTED ****************************";
+ std::cerr << message() << std::endl;
+ std::cerr << "******************************************************************************\n" << std::endl;
+ }
+
+ void check_for_previous_fatal_errors()
+ {
+ // If dlib is being use to create plugins for some other application, like
+ // MATLAB, then don't do these checks since it terminates the over arching
+ // system. Just let the errors go to the plugin handler and it will deal with
+ // them.
+#if defined(MATLAB_MEX_FILE) || defined(DLIB_NO_ABORT_ON_2ND_FATAL_ERROR)
+ return;
+#else
+ static bool is_first_fatal_error = true;
+ if (is_first_fatal_error == false)
+ {
+ std::cerr << "\n\n ************************** FATAL ERROR DETECTED ************************** " << std::endl;
+ std::cerr << " ************************** FATAL ERROR DETECTED ************************** " << std::endl;
+ std::cerr << " ************************** FATAL ERROR DETECTED ************************** \n" << std::endl;
+ std::cerr << "Two fatal errors have been detected, the first was inappropriately ignored. \n"
+ << "To prevent further fatal errors from being ignored this application will be \n"
+ << "terminated immediately and you should go fix this buggy program.\n\n"
+ << "The error message from this fatal error was:\n" << this->what() << "\n\n" << std::endl;
+ using namespace std;
+ assert(false);
+ abort();
+ }
+ else
+ {
+ // copy the message into the fixed message buffer so that it can be recalled by dlib_fatal_error_terminate
+ // if needed.
+ char* msg = message();
+ unsigned long i;
+ for (i = 0; i < 2000-1 && i < this->info.size(); ++i)
+ msg[i] = info[i];
+ msg[i] = '\0';
+
+ // set this termination handler so that if the user doesn't catch this dlib::fatal_error that is being
+ // thrown then it will eventually be printed to standard error
+ std::set_terminate(&dlib_fatal_error_terminate);
+ }
+ is_first_fatal_error = false;
+#endif
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class gui_error : public error
+ {
+ public:
+ gui_error(
+ error_type t,
+ const std::string& a
+ ): error(t,a) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == a
+ !*/
+
+ gui_error(
+ error_type t
+ ): error(t) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == ""
+ !*/
+
+ gui_error(
+ const std::string& a
+ ): error(EGUI,a) {}
+ /*!
+ ensures
+ - #type == EGUI
+ - #info == a
+ !*/
+
+ gui_error(
+ ): error(EGUI) {}
+ /*!
+ ensures
+ - #type == EGUI
+ - #info == ""
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class socket_error : public error
+ {
+ public:
+ socket_error(
+ error_type t,
+ const std::string& a
+ ): error(t,a) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == a
+ !*/
+
+ socket_error(
+ error_type t
+ ): error(t) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == ""
+ !*/
+
+ socket_error(
+ const std::string& a
+ ): error(ESOCKET,a) {}
+ /*!
+ ensures
+ - #type == ESOCKET
+ - #info == a
+ !*/
+
+ socket_error(
+ ): error(ESOCKET) {}
+ /*!
+ ensures
+ - #type == ESOCKET
+ - #info == ""
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_error : public error
+ {
+ public:
+ thread_error(
+ error_type t,
+ const std::string& a
+ ): error(t,a) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == a
+ !*/
+
+ thread_error(
+ error_type t
+ ): error(t) {}
+ /*!
+ ensures
+ - #type == t
+ - #info == ""
+ !*/
+
+ thread_error(
+ const std::string& a
+ ): error(ETHREAD,a) {}
+ /*!
+ ensures
+ - #type == ETHREAD
+ - #info == a
+ !*/
+
+ thread_error(
+ ): error(ETHREAD) {}
+ /*!
+ ensures
+ - #type == ETHREAD
+ - #info == ""
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class impossible_labeling_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown by code that trains object detectors (e.g.
+ structural_svm_object_detection_problem) when they detect that the set of
+ truth boxes given to the training algorithm contains some impossible to
+ obtain outputs.
+
+ This kind of problem can happen when the set of image positions scanned by
+ the underlying object detection method doesn't include the truth rectangle
+ as a possible output. Another possibility is when two truth boxes are very
+ close together and hard coded non-max suppression logic would prevent two
+ boxes in such close proximity from being output.
+ !*/
+ public:
+ impossible_labeling_error(const std::string& msg) : dlib::error(msg) {};
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ERROr_
+
diff --git a/ml/dlib/dlib/external/cblas/CMakeLists.txt b/ml/dlib/dlib/external/cblas/CMakeLists.txt
new file mode 100644
index 000000000..0d800ae13
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/CMakeLists.txt
@@ -0,0 +1,182 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+
+
+cmake_minimum_required(VERSION 2.8.12)
+project(cblas)
+
+
+enable_language (Fortran)
+
+set(CMAKE_POSITION_INDEPENDENT_CODE True)
+add_definitions(-DADD_ -DF77_INT=ptrdiff_t)
+
+add_library(cblas STATIC
+ cblas_caxpy.c
+ #cblas_ccopy.c
+ cblas_cdotc_sub.c
+ cblas_cdotu_sub.c
+ #cblas_cgbmv.c
+ cblas_cgemm.c
+ cblas_cgemv.c
+ cblas_cgerc.c
+ cblas_cgeru.c
+ #cblas_chbmv.c
+ #cblas_chemm.c
+ #cblas_chemv.c
+ #cblas_cher2.c
+ #cblas_cher2k.c
+ #cblas_cher.c
+ #cblas_cherk.c
+ #cblas_chpmv.c
+ #cblas_chpr2.c
+ #cblas_chpr.c
+ cblas_cscal.c
+ #cblas_csscal.c
+ #cblas_cswap.c
+ #cblas_csymm.c
+ #cblas_csyr2k.c
+ #cblas_csyrk.c
+ #cblas_ctbmv.c
+ #cblas_ctbsv.c
+ #cblas_ctpmv.c
+ #cblas_ctpsv.c
+ #cblas_ctrmm.c
+ #cblas_ctrmv.c
+ cblas_ctrsm.c
+ #cblas_ctrsv.c
+ #cblas_dasum.c
+ cblas_daxpy.c
+ #cblas_dcopy.c
+ cblas_ddot.c
+ #cblas_dgbmv.c
+ cblas_dgemm.c
+ cblas_dgemv.c
+ cblas_dger.c
+ #cblas_dnrm2.c
+ #cblas_drot.c
+ #cblas_drotg.c
+ #cblas_drotm.c
+ #cblas_drotmg.c
+ #cblas_dsbmv.c
+ cblas_dscal.c
+ #cblas_dsdot.c
+ #cblas_dspmv.c
+ #cblas_dspr2.c
+ #cblas_dspr.c
+ #cblas_dswap.c
+ #cblas_dsymm.c
+ #cblas_dsymv.c
+ #cblas_dsyr2.c
+ #cblas_dsyr2k.c
+ #cblas_dsyr.c
+ #cblas_dsyrk.c
+ #cblas_dtbmv.c
+ #cblas_dtbsv.c
+ #cblas_dtpmv.c
+ #cblas_dtpsv.c
+ #cblas_dtrmm.c
+ #cblas_dtrmv.c
+ cblas_dtrsm.c
+ #cblas_dtrsv.c
+ #cblas_dzasum.c
+ #cblas_dznrm2.c
+ #cblas_icamax.c
+ #cblas_idamax.c
+ #cblas_isamax.c
+ #cblas_izamax.c
+ #cblas_sasum.c
+ cblas_saxpy.c
+ #cblas_scasum.c
+ #cblas_scnrm2.c
+ #cblas_scopy.c
+ cblas_sdot.c
+ #cblas_sdsdot.c
+ #cblas_sgbmv.c
+ cblas_sgemm.c
+ cblas_sgemv.c
+ cblas_sger.c
+ #cblas_snrm2.c
+ #cblas_srot.c
+ #cblas_srotg.c
+ #cblas_srotm.c
+ #cblas_srotmg.c
+ #cblas_ssbmv.c
+ cblas_sscal.c
+ #cblas_sspmv.c
+ #cblas_sspr2.c
+ #cblas_sspr.c
+ #cblas_sswap.c
+ #cblas_ssymm.c
+ #cblas_ssymv.c
+ #cblas_ssyr2.c
+ #cblas_ssyr2k.c
+ #cblas_ssyr.c
+ #cblas_ssyrk.c
+ #cblas_stbmv.c
+ #cblas_stbsv.c
+ #cblas_stpmv.c
+ #cblas_stpsv.c
+ #cblas_strmm.c
+ #cblas_strmv.c
+ cblas_strsm.c
+ #cblas_strsv.c
+ cblas_xerbla.c
+ cblas_zaxpy.c
+ #cblas_zcopy.c
+ cblas_zdotc_sub.c
+ cblas_zdotu_sub.c
+ #cblas_zdscal.c
+ #cblas_zgbmv.c
+ cblas_zgemm.c
+ cblas_zgemv.c
+ cblas_zgerc.c
+ cblas_zgeru.c
+ #cblas_zhbmv.c
+ #cblas_zhemm.c
+ #cblas_zhemv.c
+ #cblas_zher2.c
+ #cblas_zher2k.c
+ #cblas_zher.c
+ #cblas_zherk.c
+ #cblas_zhpmv.c
+ #cblas_zhpr2.c
+ #cblas_zhpr.c
+ cblas_zscal.c
+ #cblas_zswap.c
+ #cblas_zsymm.c
+ #cblas_zsyr2k.c
+ #cblas_zsyrk.c
+ #cblas_ztbmv.c
+ #cblas_ztbsv.c
+ #cblas_ztpmv.c
+ #cblas_ztpsv.c
+ #cblas_ztrmm.c
+ #cblas_ztrmv.c
+ cblas_ztrsm.c
+ #cblas_ztrsv.c
+
+ cdotcsub.f
+ cdotusub.f
+ dasumsub.f
+ ddotsub.f
+ dnrm2sub.f
+ dsdotsub.f
+ dzasumsub.f
+ dznrm2sub.f
+ icamaxsub.f
+ idamaxsub.f
+ isamaxsub.f
+ izamaxsub.f
+ sasumsub.f
+ scasumsub.f
+ scnrm2sub.f
+ sdotsub.f
+ sdsdotsub.f
+ snrm2sub.f
+ zdotcsub.f
+ zdotusub.f
+ )
+
diff --git a/ml/dlib/dlib/external/cblas/README b/ml/dlib/dlib/external/cblas/README
new file mode 100644
index 000000000..a89feaffe
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/README
@@ -0,0 +1,7 @@
+This folder contains a copy of CBLAS (from http://www.netlib.org/blas/) which
+has been setup so you can compile it with CMake. It also only compiles the
+part of CBLAS needed by dlib.
+
+Most BLAS libraries come with CBLAS, however, some don't. In particular, if
+you are using the BLAS that comes with MATLAB then you will need this CBLAS
+code linked into your own to get dlib working with MATLAB's built in BLAS.
diff --git a/ml/dlib/dlib/external/cblas/cblas.h b/ml/dlib/dlib/external/cblas/cblas.h
new file mode 100644
index 000000000..f91557e74
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas.h
@@ -0,0 +1,575 @@
+#ifndef CBLAS_H
+#define CBLAS_H
+#include <stddef.h>
+
+/*
+ * Enumerated and derived types
+ */
+#define CBLAS_INDEX size_t /* this may vary between platforms */
+
+enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
+enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
+enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
+enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
+enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * ===========================================================================
+ * Prototypes for level 1 BLAS functions (complex are recast as routines)
+ * ===========================================================================
+ */
+float cblas_sdsdot(const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY);
+double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
+ const int incY);
+float cblas_sdot(const int N, const float *X, const int incX,
+ const float *Y, const int incY);
+double cblas_ddot(const int N, const double *X, const int incX,
+ const double *Y, const int incY);
+
+/*
+ * Functions having prefixes Z and C only
+ */
+void cblas_cdotu_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotu);
+void cblas_cdotc_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotc);
+
+void cblas_zdotu_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotu);
+void cblas_zdotc_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotc);
+
+
+/*
+ * Functions having prefixes S D SC DZ
+ */
+float cblas_snrm2(const int N, const float *X, const int incX);
+float cblas_sasum(const int N, const float *X, const int incX);
+
+double cblas_dnrm2(const int N, const double *X, const int incX);
+double cblas_dasum(const int N, const double *X, const int incX);
+
+float cblas_scnrm2(const int N, const void *X, const int incX);
+float cblas_scasum(const int N, const void *X, const int incX);
+
+double cblas_dznrm2(const int N, const void *X, const int incX);
+double cblas_dzasum(const int N, const void *X, const int incX);
+
+
+/*
+ * Functions having standard 4 prefixes (S D C Z)
+ */
+CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
+CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
+CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
+CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
+
+/*
+ * ===========================================================================
+ * Prototypes for level 1 BLAS routines
+ * ===========================================================================
+ */
+
+/*
+ * Routines with standard 4 prefixes (s, d, c, z)
+ */
+void cblas_sswap(const int N, float *X, const int incX,
+ float *Y, const int incY);
+void cblas_scopy(const int N, const float *X, const int incX,
+ float *Y, const int incY);
+void cblas_saxpy(const int N, const float alpha, const float *X,
+ const int incX, float *Y, const int incY);
+
+void cblas_dswap(const int N, double *X, const int incX,
+ double *Y, const int incY);
+void cblas_dcopy(const int N, const double *X, const int incX,
+ double *Y, const int incY);
+void cblas_daxpy(const int N, const double alpha, const double *X,
+ const int incX, double *Y, const int incY);
+
+void cblas_cswap(const int N, void *X, const int incX,
+ void *Y, const int incY);
+void cblas_ccopy(const int N, const void *X, const int incX,
+ void *Y, const int incY);
+void cblas_caxpy(const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY);
+
+void cblas_zswap(const int N, void *X, const int incX,
+ void *Y, const int incY);
+void cblas_zcopy(const int N, const void *X, const int incX,
+ void *Y, const int incY);
+void cblas_zaxpy(const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY);
+
+
+/*
+ * Routines with S and D prefix only
+ */
+void cblas_srotg(float *a, float *b, float *c, float *s);
+void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
+void cblas_srot(const int N, float *X, const int incX,
+ float *Y, const int incY, const float c, const float s);
+void cblas_srotm(const int N, float *X, const int incX,
+ float *Y, const int incY, const float *P);
+
+void cblas_drotg(double *a, double *b, double *c, double *s);
+void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
+void cblas_drot(const int N, double *X, const int incX,
+ double *Y, const int incY, const double c, const double s);
+void cblas_drotm(const int N, double *X, const int incX,
+ double *Y, const int incY, const double *P);
+
+
+/*
+ * Routines with S D C Z CS and ZD prefixes
+ */
+void cblas_sscal(const int N, const float alpha, float *X, const int incX);
+void cblas_dscal(const int N, const double alpha, double *X, const int incX);
+void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
+void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
+void cblas_csscal(const int N, const float alpha, void *X, const int incX);
+void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
+
+/*
+ * ===========================================================================
+ * Prototypes for level 2 BLAS
+ * ===========================================================================
+ */
+
+/*
+ * Routines with standard 4 prefixes (S, D, C, Z)
+ */
+void cblas_sgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY);
+void cblas_sgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU, const float alpha,
+ const float *A, const int lda, const float *X,
+ const int incX, const float beta, float *Y, const int incY);
+void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *A, const int lda,
+ float *X, const int incX);
+void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const float *A, const int lda,
+ float *X, const int incX);
+void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *Ap, float *X, const int incX);
+void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *A, const int lda, float *X,
+ const int incX);
+void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const float *A, const int lda,
+ float *X, const int incX);
+void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *Ap, float *X, const int incX);
+
+void cblas_dgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY);
+void cblas_dgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU, const double alpha,
+ const double *A, const int lda, const double *X,
+ const int incX, const double beta, double *Y, const int incY);
+void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *A, const int lda,
+ double *X, const int incX);
+void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const double *A, const int lda,
+ double *X, const int incX);
+void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *Ap, double *X, const int incX);
+void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *A, const int lda, double *X,
+ const int incX);
+void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const double *A, const int lda,
+ double *X, const int incX);
+void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *Ap, double *X, const int incX);
+
+void cblas_cgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY);
+void cblas_cgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU, const void *alpha,
+ const void *A, const int lda, const void *X,
+ const int incX, const void *beta, void *Y, const int incY);
+void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX);
+void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda, void *X,
+ const int incX);
+void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX);
+
+void cblas_zgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY);
+void cblas_zgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU, const void *alpha,
+ const void *A, const int lda, const void *X,
+ const int incX, const void *beta, void *Y, const int incY);
+void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX);
+void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda, void *X,
+ const int incX);
+void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX);
+void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX);
+
+
+/*
+ * Routines with S and D prefixes only
+ */
+void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *A,
+ const int lda, const float *X, const int incX,
+ const float beta, float *Y, const int incY);
+void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const int K, const float alpha, const float *A,
+ const int lda, const float *X, const int incX,
+ const float beta, float *Y, const int incY);
+void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *Ap,
+ const float *X, const int incX,
+ const float beta, float *Y, const int incY);
+void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
+ const float alpha, const float *X, const int incX,
+ const float *Y, const int incY, float *A, const int lda);
+void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, float *A, const int lda);
+void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, float *Ap);
+void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY, float *A,
+ const int lda);
+void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY, float *A);
+
+void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *A,
+ const int lda, const double *X, const int incX,
+ const double beta, double *Y, const int incY);
+void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const int K, const double alpha, const double *A,
+ const int lda, const double *X, const int incX,
+ const double beta, double *Y, const int incY);
+void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *Ap,
+ const double *X, const int incX,
+ const double beta, double *Y, const int incY);
+void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
+ const double alpha, const double *X, const int incX,
+ const double *Y, const int incY, double *A, const int lda);
+void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, double *A, const int lda);
+void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, double *Ap);
+void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, const double *Y, const int incY, double *A,
+ const int lda);
+void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, const double *Y, const int incY, double *A);
+
+
+/*
+ * Routines with C and Z prefixes only
+ */
+void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *A,
+ const int lda, const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const int K, const void *alpha, const void *A,
+ const int lda, const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *Ap,
+ const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const void *X, const int incX,
+ void *A, const int lda);
+void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const void *X,
+ const int incX, void *A);
+void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *Ap);
+
+void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *A,
+ const int lda, const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const int K, const void *alpha, const void *A,
+ const int lda, const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *Ap,
+ const void *X, const int incX,
+ const void *beta, void *Y, const int incY);
+void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const void *X, const int incX,
+ void *A, const int lda);
+void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const void *X,
+ const int incX, void *A);
+void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *Ap);
+
+/*
+ * ===========================================================================
+ * Prototypes for level 3 BLAS
+ * ===========================================================================
+ */
+
+/*
+ * Routines with standard 4 prefixes (S, D, C, Z)
+ */
+void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const float alpha, const float *A,
+ const int lda, const float *B, const int ldb,
+ const float beta, float *C, const int ldc);
+void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *B, const int ldb, const float beta,
+ float *C, const int ldc);
+void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const float *A, const int lda,
+ const float beta, float *C, const int ldc);
+void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const float *A, const int lda,
+ const float *B, const int ldb, const float beta,
+ float *C, const int ldc);
+void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb);
+void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb);
+
+void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const double alpha, const double *A,
+ const int lda, const double *B, const int ldb,
+ const double beta, double *C, const int ldc);
+void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *B, const int ldb, const double beta,
+ double *C, const int ldc);
+void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const double *A, const int lda,
+ const double beta, double *C, const int ldc);
+void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const double *A, const int lda,
+ const double *B, const int ldb, const double beta,
+ double *C, const int ldc);
+void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb);
+void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb);
+
+void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc);
+void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *beta, void *C, const int ldc);
+void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb);
+void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb);
+
+void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc);
+void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *beta, void *C, const int ldc);
+void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb);
+void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb);
+
+
+/*
+ * Routines with prefixes C and Z only
+ */
+void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const void *A, const int lda,
+ const float beta, void *C, const int ldc);
+void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const float beta,
+ void *C, const int ldc);
+
+void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc);
+void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const void *A, const int lda,
+ const double beta, void *C, const int ldc);
+void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const double beta,
+ void *C, const int ldc);
+
+void cblas_xerbla(int p, const char *rout, const char *form, ...);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/ml/dlib/dlib/external/cblas/cblas_caxpy.c b/ml/dlib/dlib/external/cblas/cblas_caxpy.c
new file mode 100644
index 000000000..7579aa707
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_caxpy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_caxpy.c
+ *
+ * The program is a C interface to caxpy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_caxpy( const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_caxpy( &F77_N, alpha, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ccopy.c b/ml/dlib/dlib/external/cblas/cblas_ccopy.c
new file mode 100644
index 000000000..b7bc42847
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ccopy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_ccopy.c
+ *
+ * The program is a C interface to ccopy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ccopy( const int N, const void *X,
+ const int incX, void *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_ccopy( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c b/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c
new file mode 100644
index 000000000..d6086814e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_cdotc_sub.c
+ *
+ * The program is a C interface to cdotc.
+ * It calls the fortran wrapper before calling cdotc.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cdotc_sub( const int N, const void *X, const int incX,
+ const void *Y, const int incY,void *dotc)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_cdotc_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotc);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c b/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c
new file mode 100644
index 000000000..d06e4e5fa
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_cdotu_sub.f
+ *
+ * The program is a C interface to cdotu.
+ * It calls the forteran wrapper before calling cdotu.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cdotu_sub( const int N, const void *X,
+ const int incX, const void *Y, const int incY,void *dotu)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_cdotu_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotu);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cgbmv.c b/ml/dlib/dlib/external/cblas/cblas_cgbmv.c
new file mode 100644
index 000000000..94cc175f3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cgbmv.c
@@ -0,0 +1,154 @@
+/*
+ * cblas_cgbmv.c
+ * The program is a C interface of cgbmv
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+ F77_INT F77_KL=KL,F77_KU=KU;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_KL KL
+ #define F77_KU KU
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n=0, i=0, incx=incX;
+ const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta;
+ float ALPHA[2],BETA[2];
+ int tincY, tincx;
+ float *x=(float *)X, *y=(float *)Y, *st=0, *tx=0;
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_cgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_cgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, alpha,
+ A, &F77_lda, X, &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+ TA = 'N';
+ if (M > 0)
+ {
+ n = M << 1;
+ x = malloc(n*sizeof(float));
+ tx = x;
+
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if( incY > 0 )
+ tincY = incY;
+ else
+ tincY = -incY;
+
+ y++;
+
+ if (N > 0)
+ {
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ }
+ }
+ else x = (float *) X;
+
+
+ }
+ else
+ {
+ cblas_xerbla(2, "cblas_cgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ if (TransA == CblasConjTrans)
+ F77_cgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, ALPHA,
+ A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY);
+ else
+ F77_cgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, alpha,
+ A ,&F77_lda, x,&F77_incX, beta, Y, &F77_incY);
+ if (TransA == CblasConjTrans)
+ {
+ if (x != X) free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_cgbmv", "Illegal Order setting, %d\n", order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cgemm.c b/ml/dlib/dlib/external/cblas/cblas_cgemm.c
new file mode 100644
index 000000000..c11641023
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cgemm.c
@@ -0,0 +1,94 @@
+/*
+ *
+ * cblas_cgemm.c
+ * This program is a C interface to cgemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc)
+{
+ char TA, TB;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_TB;
+#else
+ #define F77_TA &TA
+ #define F77_TB &TB
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if(TransA == CblasTrans) TA='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_cgemm", "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if(TransB == CblasTrans) TB='T';
+ else if ( TransB == CblasConjTrans ) TB='C';
+ else if ( TransB == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_cgemm", "Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_cgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, alpha, A,
+ &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if(TransA == CblasTrans) TB='T';
+ else if ( TransA == CblasConjTrans ) TB='C';
+ else if ( TransA == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_cgemm", "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if(TransB == CblasTrans) TA='T';
+ else if ( TransB == CblasConjTrans ) TA='C';
+ else if ( TransB == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_cgemm", "Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_cgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, alpha, B,
+ &F77_ldb, A, &F77_lda, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_cgemm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cgemv.c b/ml/dlib/dlib/external/cblas/cblas_cgemv.c
new file mode 100644
index 000000000..a1cbb94ee
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cgemv.c
@@ -0,0 +1,151 @@
+/*
+ * cblas_cgemv.c
+ * The program is a C interface of cgemv
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+
+ int n=0, i=0, incx=incX;
+ const float *xx= (const float *)X;
+ float ALPHA[2],BETA[2];
+ int tincY, tincx;
+ float *x=(float *)X, *y=(float *)Y, *st=0, *tx=0;
+ const float *stx = x;
+
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_cgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_cgemv(F77_TA, &F77_M, &F77_N, alpha, A, &F77_lda, X, &F77_incX,
+ beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ ALPHA[0]= *( (const float *) alpha );
+ ALPHA[1]= -( *( (const float *) alpha+1) );
+ BETA[0]= *( (const float *) beta );
+ BETA[1]= -( *( (const float *) beta+1 ) );
+ TA = 'N';
+ if (M > 0)
+ {
+ n = M << 1;
+ x = malloc(n*sizeof(float));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ F77_incX = 1;
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+
+ y++;
+
+ if (N > 0)
+ {
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ }
+ stx = x;
+ }
+ else stx = (const float *)X;
+ }
+ else
+ {
+ cblas_xerbla(2, "cblas_cgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ if (TransA == CblasConjTrans)
+ F77_cgemv(F77_TA, &F77_N, &F77_M, ALPHA, A, &F77_lda, stx,
+ &F77_incX, BETA, Y, &F77_incY);
+ else
+ F77_cgemv(F77_TA, &F77_N, &F77_M, alpha, A, &F77_lda, x,
+ &F77_incX, beta, Y, &F77_incY);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (x != (const float *)X) free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_cgemv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cgerc.c b/ml/dlib/dlib/external/cblas/cblas_cgerc.c
new file mode 100644
index 000000000..e843f099b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cgerc.c
@@ -0,0 +1,77 @@
+/*
+ * cblas_cgerc.c
+ * The program is a C interface to cgerc.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incy
+ #define F77_lda lda
+#endif
+
+ int n, i, tincy, incy=incY;
+ float *y=(float *)Y, *yy=(float *)Y, *ty, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ F77_cgerc( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ } else if (order == CblasRowMajor)
+ {
+ if (N > 0)
+ {
+ n = N << 1;
+ y = malloc(n*sizeof(float));
+
+ ty = y;
+ if( incY > 0 ) {
+ i = incY << 1;
+ tincy = 2;
+ st= y+n;
+ } else {
+ i = incY *(-2);
+ tincy = -2;
+ st = y-2;
+ y +=(n-2);
+ }
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += tincy ;
+ yy += i;
+ }
+ while (y != st);
+ y = ty;
+
+ #ifdef F77_INT
+ F77_incY = 1;
+ #else
+ incy = 1;
+ #endif
+ }
+ else y = (float *) Y;
+
+ F77_cgeru( &F77_N, &F77_M, alpha, y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+ if(Y!=y)
+ free(y);
+
+ } else cblas_xerbla(1, "cblas_cgerc", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cgeru.c b/ml/dlib/dlib/external/cblas/cblas_cgeru.c
new file mode 100644
index 000000000..4471d1f80
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cgeru.c
@@ -0,0 +1,38 @@
+/*
+ * cblas_cgeru.c
+ * The program is a C interface to cgeru.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+
+
+ if (order == CblasColMajor)
+ {
+ F77_cgeru( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ }
+ else if (order == CblasRowMajor)
+ {
+ F77_cgeru( &F77_N, &F77_M, alpha, Y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+ }
+ else cblas_xerbla(1, "cblas_cgeru","Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chbmv.c b/ml/dlib/dlib/external/cblas/cblas_chbmv.c
new file mode 100644
index 000000000..8681fa345
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chbmv.c
@@ -0,0 +1,145 @@
+/*
+ * cblas_chbmv.c
+ * The program is a C interface to chbmv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+#include <stdio.h>
+#include <stdlib.h>
+void cblas_chbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo,const int N,const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta;
+ float ALPHA[2],BETA[2];
+ int tincY, tincx;
+ float *x=(float *)X, *y=(float *)Y, *st=0, *tx;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chbmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_chbmv(F77_UL, &F77_N, &F77_K, alpha, A, &F77_lda, X,
+ &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (float *) X;
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_chbmv(F77_UL, &F77_N, &F77_K, ALPHA,
+ A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_chbmv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if(X!=x)
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chemm.c b/ml/dlib/dlib/external/cblas/cblas_chemm.c
new file mode 100644
index 000000000..e64f0d0f3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chemm.c
@@ -0,0 +1,91 @@
+/*
+ *
+ * cblas_chemm.c
+ * This program is a C interface to chemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_chemm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_chemm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_chemm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_chemm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_chemm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_chemm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A,
+ &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_chemm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chemv.c b/ml/dlib/dlib/external/cblas/cblas_chemv.c
new file mode 100644
index 000000000..bbaaefdc2
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chemv.c
@@ -0,0 +1,146 @@
+/*
+ * cblas_chemv.c
+ * The program is a C interface to chemv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_chemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n=0, i=0, incx=incX;
+ const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta;
+ float ALPHA[2],BETA[2];
+ int tincY, tincx;
+ float *x=(float *)X, *y=(float *)Y, *st=0, *tx;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_chemv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_chemv(F77_UL, &F77_N, alpha, A, &F77_lda, X, &F77_incX,
+ beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (float *) X;
+
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chemv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_chemv(F77_UL, &F77_N, ALPHA, A, &F77_lda, x, &F77_incX,
+ BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_chemv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if ( X != x )
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cher.c b/ml/dlib/dlib/external/cblas/cblas_cher.c
new file mode 100644
index 000000000..580413b02
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cher.c
@@ -0,0 +1,103 @@
+/*
+ * cblas_cher.c
+ * The program is a C interface to cher.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const void *X, const int incX
+ ,void *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+#endif
+ int n, i, tincx, incx=incX;
+ float *x=(float *)X, *xx=(float *)X, *tx, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_cher(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+ }
+ else x = (float *) X;
+ F77_cher(F77_UL, &F77_N, &alpha, x, &F77_incX, A, &F77_lda);
+ } else
+ {
+ cblas_xerbla(1, "cblas_cher","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cher2.c b/ml/dlib/dlib/external/cblas/cblas_cher2.c
new file mode 100644
index 000000000..89d36a0ee
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cher2.c
@@ -0,0 +1,139 @@
+/*
+ * cblas_cher2.c
+ * The program is a C interface to cher2.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incy
+#endif
+ int n, i, j, tincx, tincy, incx=incX, incy=incY;
+ float *x=(float *)X, *xx=(float *)X, *y=(float *)Y,
+ *yy=(float *)Y, *tx, *ty, *stx, *sty;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_cher2(F77_UL, &F77_N, alpha, X, &F77_incX,
+ Y, &F77_incY, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher2","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+ y = malloc(n*sizeof(float));
+ tx = x;
+ ty = y;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ stx= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ stx = x-2;
+ x +=(n-2);
+ }
+
+ if( incY > 0 ) {
+ j = incY << 1;
+ tincy = 2;
+ sty= y+n;
+ } else {
+ j = incY *(-2);
+ tincy = -2;
+ sty = y-2;
+ y +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != stx);
+
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += tincy ;
+ yy += j;
+ }
+ while (y != sty);
+
+ x=tx;
+ y=ty;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ F77_incY = 1;
+ #else
+ incx = 1;
+ incy = 1;
+ #endif
+ } else
+ {
+ x = (float *) X;
+ y = (float *) Y;
+ }
+ F77_cher2(F77_UL, &F77_N, alpha, y, &F77_incY, x,
+ &F77_incX, A, &F77_lda);
+ } else
+ {
+ cblas_xerbla(1, "cblas_cher2","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ if(Y!=y)
+ free(y);
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cher2k.c b/ml/dlib/dlib/external/cblas/cblas_cher2k.c
new file mode 100644
index 000000000..cad1c432e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cher2k.c
@@ -0,0 +1,96 @@
+/*
+ *
+ * cblas_cher2k.c
+ * This program is a C interface to cher2k.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const float beta,
+ void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+ float ALPHA[2];
+ const float *alp=(float *)alpha;
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_cher2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_cher2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(2, "cblas_cher2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='C';
+ else
+ {
+ cblas_xerbla(3, "cblas_cher2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ F77_cher2k(F77_UL,F77_TR, &F77_N, &F77_K, ALPHA, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_cher2k", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cherk.c b/ml/dlib/dlib/external/cblas/cblas_cherk.c
new file mode 100644
index 000000000..0b6362db7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cherk.c
@@ -0,0 +1,90 @@
+/*
+ *
+ * cblas_cherk.c
+ * This program is a C interface to cherk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const void *A, const int lda,
+ const float beta, void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_cherk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_cherk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_cherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_cherk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='C';
+ else
+ {
+ cblas_xerbla(3, "cblas_cherk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_cherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_cherk", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chpmv.c b/ml/dlib/dlib/external/cblas/cblas_chpmv.c
new file mode 100644
index 000000000..048734760
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chpmv.c
@@ -0,0 +1,146 @@
+/*
+ * cblas_chpmv.c
+ * The program is a C interface of chpmv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_chpmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo,const int N,
+ const void *alpha, const void *AP,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta;
+ float ALPHA[2],BETA[2];
+ int tincY, tincx;
+ float *x=(float *)X, *y=(float *)Y, *st=0, *tx;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_chpmv(F77_UL, &F77_N, alpha, AP, X,
+ &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (float *) X;
+
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpmv","Illegal Uplo setting, %d\n", Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_chpmv(F77_UL, &F77_N, ALPHA,
+ AP, x, &F77_incX, BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_chpmv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if(X!=x)
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chpr.c b/ml/dlib/dlib/external/cblas/cblas_chpr.c
new file mode 100644
index 000000000..72796b158
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chpr.c
@@ -0,0 +1,102 @@
+/*
+ * cblas_chpr.c
+ * The program is a C interface to chpr.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const void *X,
+ const int incX, void *A)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incx
+#endif
+ int n, i, tincx, incx=incX;
+ float *x=(float *)X, *xx=(float *)X, *tx, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_chpr(F77_UL, &F77_N, &alpha, X, &F77_incX, A);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpr","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+ }
+ else x = (float *) X;
+
+ F77_chpr(F77_UL, &F77_N, &alpha, x, &F77_incX, A);
+
+ } else
+ {
+ cblas_xerbla(1, "cblas_chpr","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_chpr2.c b/ml/dlib/dlib/external/cblas/cblas_chpr2.c
new file mode 100644
index 000000000..f80d087aa
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_chpr2.c
@@ -0,0 +1,136 @@
+/*
+ * cblas_chpr2.c
+ * The program is a C interface to chpr2.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N,const void *alpha, const void *X,
+ const int incX,const void *Y, const int incY, void *Ap)
+
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incx
+ #define F77_incY incy
+#endif
+ int n, i, j, tincx, tincy, incx=incX, incy=incY;
+ float *x=(float *)X, *xx=(float *)X, *y=(float *)Y,
+ *yy=(float *)Y, *tx, *ty, *stx, *sty;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_chpr2(F77_UL, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, Ap);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_chpr2","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(float));
+ y = malloc(n*sizeof(float));
+ tx = x;
+ ty = y;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ stx= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ stx = x-2;
+ x +=(n-2);
+ }
+
+ if( incY > 0 ) {
+ j = incY << 1;
+ tincy = 2;
+ sty= y+n;
+ } else {
+ j = incY *(-2);
+ tincy = -2;
+ sty = y-2;
+ y +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != stx);
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += tincy ;
+ yy += j;
+ }
+ while (y != sty);
+
+ x=tx;
+ y=ty;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ F77_incY = 1;
+ #else
+ incx = 1;
+ incy = 1;
+ #endif
+
+ } else
+ {
+ x = (float *) X;
+ y = (void *) Y;
+ }
+ F77_chpr2(F77_UL, &F77_N, alpha, y, &F77_incY, x, &F77_incX, Ap);
+ } else
+ {
+ cblas_xerbla(1, "cblas_chpr2","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ if(Y!=y)
+ free(y);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cscal.c b/ml/dlib/dlib/external/cblas/cblas_cscal.c
new file mode 100644
index 000000000..a23e6ee57
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_cscal.c
+ *
+ * The program is a C interface to cscal.f.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cscal( const int N, const void *alpha, void *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_cscal( &F77_N, alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_csscal.c b/ml/dlib/dlib/external/cblas/cblas_csscal.c
new file mode 100644
index 000000000..39983fe07
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_csscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_csscal.c
+ *
+ * The program is a C interface to csscal.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_csscal( const int N, const float alpha, void *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_csscal( &F77_N, &alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_cswap.c b/ml/dlib/dlib/external/cblas/cblas_cswap.c
new file mode 100644
index 000000000..127282072
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_cswap.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_cswap.c
+ *
+ * The program is a C interface to cswap.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_cswap( const int N, void *X, const int incX, void *Y,
+ const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_cswap( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_csymm.c b/ml/dlib/dlib/external/cblas/cblas_csymm.c
new file mode 100644
index 000000000..a462b5ebd
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_csymm.c
@@ -0,0 +1,91 @@
+/*
+ *
+ * cblas_csymm.c
+ * This program is a C interface to csymm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_csymm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_csymm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_csymm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_csymm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_csymm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_csymm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_csymm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_csyr2k.c b/ml/dlib/dlib/external/cblas/cblas_csyr2k.c
new file mode 100644
index 000000000..c93facb2b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_csyr2k.c
@@ -0,0 +1,93 @@
+/*
+ *
+ * cblas_csyr2k.c
+ * This program is a C interface to csyr2k.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_csyr2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyr2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_csyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyr2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyr2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_csyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_csyr2k", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_csyrk.c b/ml/dlib/dlib/external/cblas/cblas_csyrk.c
new file mode 100644
index 000000000..4ff0bd535
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_csyrk.c
@@ -0,0 +1,93 @@
+/*
+ *
+ * cblas_csyrk.c
+ * This program is a C interface to csyrk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *beta, void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_csyrk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyrk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_csyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyrk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_csyrk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_csyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_csyrk", "Illegal Order setting, %d\n", Order);
+ return;
+}
+
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctbmv.c b/ml/dlib/dlib/external/cblas/cblas_ctbmv.c
new file mode 100644
index 000000000..0b313d858
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctbmv.c
@@ -0,0 +1,139 @@
+/*
+ * cblas_ctbmv.c
+ * The program is a C interface to ctbmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0, *x=(float *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctbmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ x++;
+ st = x + n;
+ do
+ {
+ *x = -(*x);
+ x+= i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctbsv.c b/ml/dlib/dlib/external/cblas/cblas_ctbsv.c
new file mode 100644
index 000000000..31f3f5bb0
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctbsv.c
@@ -0,0 +1,143 @@
+/*
+ * cblas_ctbsv.c
+ * The program is a C interface to ctbsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0,*x=(float *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+
+ x++;
+
+ st=x+n;
+
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x+= i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctbsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctpmv.c b/ml/dlib/dlib/external/cblas/cblas_ctpmv.c
new file mode 100644
index 000000000..03dad131e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctpmv.c
@@ -0,0 +1,133 @@
+/*
+ * cblas_ctpmv.c
+ * The program is a C interface to ctpmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0,*x=(float *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ x++;
+ st = x + n;
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctpmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctpsv.c b/ml/dlib/dlib/external/cblas/cblas_ctpsv.c
new file mode 100644
index 000000000..3023306bd
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctpsv.c
@@ -0,0 +1,138 @@
+/*
+ * cblas_ctpsv.c
+ * The program is a C interface to ctpsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0, *x=(float*)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+
+ x++;
+
+ st=x+n;
+
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctpsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrmm.c b/ml/dlib/dlib/external/cblas/cblas_ctrmm.c
new file mode 100644
index 000000000..ceeb2a5ff
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctrmm.c
@@ -0,0 +1,123 @@
+/*
+ *
+ * cblas_ctrmm.c
+ * This program is a C interface to ctrmm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight ) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrmm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper ) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrmm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans ) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrmm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else cblas_xerbla(5, "cblas_ctrmm",
+ "Illegal Diag setting, %d\n", Diag);
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight ) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrmm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper ) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrmm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans ) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrmm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ctrmm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_ctrmm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrmv.c b/ml/dlib/dlib/external/cblas/cblas_ctrmv.c
new file mode 100644
index 000000000..dc590fe76
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctrmv.c
@@ -0,0 +1,136 @@
+/*
+ * cblas_ctrmv.c
+ * The program is a C interface to ctrmv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda,
+ void *X, const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0,*x=(float *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ st = x + n;
+ do
+ {
+ x[1] = -x[1];
+ x+= i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ x[1] = -x[1];
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctrmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrsm.c b/ml/dlib/dlib/external/cblas/cblas_ctrsm.c
new file mode 100644
index 000000000..ae2d8796e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctrsm.c
@@ -0,0 +1,132 @@
+/*
+ *
+ * cblas_ctrsm.c
+ * This program is a C interface to ctrsm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ctrsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ctrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A,
+ &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ctrsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+
+ F77_ctrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A,
+ &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_ctrsm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrsv.c b/ml/dlib/dlib/external/cblas/cblas_ctrsv.c
new file mode 100644
index 000000000..8bfacd913
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ctrsv.c
@@ -0,0 +1,137 @@
+/*
+ * cblas_ctrsv.c
+ * The program is a C interface to ctrsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda, void *X,
+ const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ float *st=0,*x=(float *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ctrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+ x++;
+ st=x+n;
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ctrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ctrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ctrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ctrsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dasum.c b/ml/dlib/dlib/external/cblas/cblas_dasum.c
new file mode 100644
index 000000000..1a3667f2d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dasum.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_dasum.c
+ *
+ * The program is a C interface to dasum.
+ * It calls the fortran wrapper before calling dasum.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_dasum( const int N, const double *X, const int incX)
+{
+ double asum;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_dasum_sub( &F77_N, X, &F77_incX, &asum);
+ return asum;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_daxpy.c b/ml/dlib/dlib/external/cblas/cblas_daxpy.c
new file mode 100644
index 000000000..3678137fb
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_daxpy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_daxpy.c
+ *
+ * The program is a C interface to daxpy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_daxpy( const int N, const double alpha, const double *X,
+ const int incX, double *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_daxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dcopy.c b/ml/dlib/dlib/external/cblas/cblas_dcopy.c
new file mode 100644
index 000000000..422a55e51
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dcopy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_dcopy.c
+ *
+ * The program is a C interface to dcopy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dcopy( const int N, const double *X,
+ const int incX, double *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_dcopy( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ddot.c b/ml/dlib/dlib/external/cblas/cblas_ddot.c
new file mode 100644
index 000000000..d77343403
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ddot.c
@@ -0,0 +1,25 @@
+/*
+ * cblas_ddot.c
+ *
+ * The program is a C interface to ddot.
+ * It calls the fortran wrapper before calling ddot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_ddot( const int N, const double *X,
+ const int incX, const double *Y, const int incY)
+{
+ double dot;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_ddot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot);
+ return dot;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dgbmv.c b/ml/dlib/dlib/external/cblas/cblas_dgbmv.c
new file mode 100644
index 000000000..886dab740
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dgbmv.c
@@ -0,0 +1,70 @@
+/*
+ *
+ * cblas_dgbmv.c
+ * This program is a C interface to dgbmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+ F77_INT F77_KL=KL,F77_KU=KU;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_KL KL
+ #define F77_KU KU
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_dgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, &alpha,
+ A, &F77_lda, X, &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_dgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, &alpha,
+ A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_dgbmv", "Illegal Order setting, %d\n", order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dgemm.c b/ml/dlib/dlib/external/cblas/cblas_dgemm.c
new file mode 100644
index 000000000..4fa9d8603
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dgemm.c
@@ -0,0 +1,94 @@
+/*
+ *
+ * cblas_dgemm.c
+ * This program is a C interface to dgemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const double alpha, const double *A,
+ const int lda, const double *B, const int ldb,
+ const double beta, double *C, const int ldc)
+{
+ char TA, TB;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_TB;
+#else
+ #define F77_TA &TA
+ #define F77_TB &TB
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if(TransA == CblasTrans) TA='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgemm","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if(TransB == CblasTrans) TB='T';
+ else if ( TransB == CblasConjTrans ) TB='C';
+ else if ( TransB == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dgemm","Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_dgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, &alpha, A,
+ &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if(TransA == CblasTrans) TB='T';
+ else if ( TransA == CblasConjTrans ) TB='C';
+ else if ( TransA == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgemm","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if(TransB == CblasTrans) TA='T';
+ else if ( TransB == CblasConjTrans ) TA='C';
+ else if ( TransB == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgemm","Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_dgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B,
+ &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_dgemm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dgemv.c b/ml/dlib/dlib/external/cblas/cblas_dgemv.c
new file mode 100644
index 000000000..23a0f51e7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dgemv.c
@@ -0,0 +1,67 @@
+/*
+ *
+ * cblas_dgemv.c
+ * This program is a C interface to dgemv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_dgemv(F77_TA, &F77_M, &F77_N, &alpha, A, &F77_lda, X, &F77_incX,
+ &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(2, "cblas_dgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_dgemv(F77_TA, &F77_N, &F77_M, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_dgemv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dger.c b/ml/dlib/dlib/external/cblas/cblas_dger.c
new file mode 100644
index 000000000..d021cc401
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dger.c
@@ -0,0 +1,40 @@
+/*
+ *
+ * cblas_dger.c
+ * This program is a C interface to dger.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
+ const double alpha, const double *X, const int incX,
+ const double *Y, const int incY, double *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+
+ if (order == CblasColMajor)
+ {
+ F77_dger( &F77_M, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ }
+ else if (order == CblasRowMajor)
+ {
+ F77_dger( &F77_N, &F77_M ,&alpha, Y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+
+ }
+ else cblas_xerbla(1, "cblas_dger", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dnrm2.c b/ml/dlib/dlib/external/cblas/cblas_dnrm2.c
new file mode 100644
index 000000000..fe46ad484
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dnrm2.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_dnrm2.c
+ *
+ * The program is a C interface to dnrm2.
+ * It calls the fortranwrapper before calling dnrm2.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_dnrm2( const int N, const double *X, const int incX)
+{
+ double nrm2;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_dnrm2_sub( &F77_N, X, &F77_incX, &nrm2);
+ return nrm2;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_drot.c b/ml/dlib/dlib/external/cblas/cblas_drot.c
new file mode 100644
index 000000000..51dc4ad5e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_drot.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_drot.c
+ *
+ * The program is a C interface to drot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_drot(const int N, double *X, const int incX,
+ double *Y, const int incY, const double c, const double s)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_drot(&F77_N, X, &F77_incX, Y, &F77_incY, &c, &s);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_drotg.c b/ml/dlib/dlib/external/cblas/cblas_drotg.c
new file mode 100644
index 000000000..0cbbd8bc0
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_drotg.c
@@ -0,0 +1,14 @@
+/*
+ * cblas_drotg.c
+ *
+ * The program is a C interface to drotg.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_drotg( double *a, double *b, double *c, double *s)
+{
+ F77_drotg(a,b,c,s);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_drotm.c b/ml/dlib/dlib/external/cblas/cblas_drotm.c
new file mode 100644
index 000000000..ebe20ad62
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_drotm.c
@@ -0,0 +1,14 @@
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_drotm( const int N, double *X, const int incX, double *Y,
+ const int incY, const double *P)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_drotm( &F77_N, X, &F77_incX, Y, &F77_incY, P);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_drotmg.c b/ml/dlib/dlib/external/cblas/cblas_drotmg.c
new file mode 100644
index 000000000..13a2208e5
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_drotmg.c
@@ -0,0 +1,15 @@
+/*
+ * cblas_drotmg.c
+ *
+ * The program is a C interface to drotmg.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_drotmg( double *d1, double *d2, double *b1,
+ const double b2, double *p)
+{
+ F77_drotmg(d1,d2,b1,&b2,p);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsbmv.c b/ml/dlib/dlib/external/cblas/cblas_dsbmv.c
new file mode 100644
index 000000000..c2f1a71c3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsbmv.c
@@ -0,0 +1,66 @@
+/*
+ *
+ * cblas_dsbmv.c
+ * This program is a C interface to dsbmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N, const int K,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsbmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsbmv(F77_UL, &F77_N, &F77_K, &alpha,
+ A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_dsbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dscal.c b/ml/dlib/dlib/external/cblas/cblas_dscal.c
new file mode 100644
index 000000000..bd04de77d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_dscal.c
+ *
+ * The program is a C interface to dscal.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dscal( const int N, const double alpha, double *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_dscal( &F77_N, &alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsdot.c b/ml/dlib/dlib/external/cblas/cblas_dsdot.c
new file mode 100644
index 000000000..52cd877a2
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsdot.c
@@ -0,0 +1,25 @@
+/*
+ * cblas_dsdot.c
+ *
+ * The program is a C interface to dsdot.
+ * It calls fthe fortran wrapper before calling dsdot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_dsdot( const int N, const float *X,
+ const int incX, const float *Y, const int incY)
+{
+ double dot;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_dsdot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot);
+ return dot;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dspmv.c b/ml/dlib/dlib/external/cblas/cblas_dspmv.c
new file mode 100644
index 000000000..ecdf081cb
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dspmv.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_dspmv.c
+ * This program is a C interface to dspmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dspmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const double alpha, const double *AP,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dspmv(F77_UL, &F77_N, &alpha, AP, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dspmv(F77_UL, &F77_N, &alpha,
+ AP, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_dspmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dspr.c b/ml/dlib/dlib/external/cblas/cblas_dspr.c
new file mode 100644
index 000000000..9e40cc11a
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dspr.c
@@ -0,0 +1,59 @@
+/*
+ *
+ * cblas_dspr.c
+ * This program is a C interface to dspr.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, double *Ap)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_dspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap);
+ } else cblas_xerbla(1, "cblas_dspr", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dspr2.c b/ml/dlib/dlib/external/cblas/cblas_dspr2.c
new file mode 100644
index 000000000..4ebbbbd52
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dspr2.c
@@ -0,0 +1,59 @@
+/*
+ * cblas_dspr2.c
+ * The program is a C interface to dspr2.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, const double *Y, const int incY, double *A)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_dspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dspr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A);
+ } else cblas_xerbla(1, "cblas_dspr2", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dswap.c b/ml/dlib/dlib/external/cblas/cblas_dswap.c
new file mode 100644
index 000000000..9ae5bb93c
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dswap.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_dswap.c
+ *
+ * The program is a C interface to dswap.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dswap( const int N, double *X, const int incX, double *Y,
+ const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_dswap( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsymm.c b/ml/dlib/dlib/external/cblas/cblas_dsymm.c
new file mode 100644
index 000000000..99b3858a1
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsymm.c
@@ -0,0 +1,91 @@
+/*
+ *
+ * cblas_dsymm.c
+ * This program is a C interface to dsymm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *B, const int ldb, const double beta,
+ double *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsymm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsymm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_dsymm(F77_SD, F77_UL, &F77_M, &F77_N, &alpha, A, &F77_lda,
+ B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsymm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsymm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_dsymm(F77_SD, F77_UL, &F77_N, &F77_M, &alpha, A, &F77_lda, B,
+ &F77_ldb, &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_dsymm","Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsymv.c b/ml/dlib/dlib/external/cblas/cblas_dsymv.c
new file mode 100644
index 000000000..f0d32398a
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsymv.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_dsymv.c
+ * This program is a C interface to dsymv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsymv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsymv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsymv(F77_UL, &F77_N, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsymv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsymv(F77_UL, &F77_N, &alpha,
+ A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_dsymv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr.c b/ml/dlib/dlib/external/cblas/cblas_dsyr.c
new file mode 100644
index 000000000..d21b846e3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsyr.c
@@ -0,0 +1,60 @@
+/*
+ *
+ * cblas_dsyr.c
+ * This program is a C interface to dsyr.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, double *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_lda=lda;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_lda lda
+#endif
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_dsyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+ } else cblas_xerbla(1, "cblas_dsyr", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr2.c b/ml/dlib/dlib/external/cblas/cblas_dsyr2.c
new file mode 100644
index 000000000..7ce59657d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsyr2.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_dsyr2.c
+ * This program is a C interface to dsyr2.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const double *X,
+ const int incX, const double *Y, const int incY, double *A,
+ const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY, F77_lda=lda;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_dsyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_dsyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ } else cblas_xerbla(1, "cblas_dsyr2", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c b/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c
new file mode 100644
index 000000000..dc11e9549
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c
@@ -0,0 +1,94 @@
+/*
+ *
+ * cblas_dsyr2k.c
+ * This program is a C interface to dsyr2k.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const double *A, const int lda,
+ const double *B, const int ldb, const double beta,
+ double *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyr2k","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyr2k","Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_dsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyr2k","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyr2k","Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_dsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B,
+ &F77_ldb, &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_dsyr2k","Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyrk.c b/ml/dlib/dlib/external/cblas/cblas_dsyrk.c
new file mode 100644
index 000000000..7ee834ea6
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dsyrk.c
@@ -0,0 +1,93 @@
+/*
+ *
+ * cblas_dsyrk.c
+ * This program is a C interface to dsyrk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const double *A, const int lda,
+ const double beta, double *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dsyrk","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyrk","Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_dsyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyrk","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_dsyrk","Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_dsyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_dsyrk","Illegal Order setting, %d\n", Order);
+ return;
+}
+
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtbmv.c b/ml/dlib/dlib/external/cblas/cblas_dtbmv.c
new file mode 100644
index 000000000..1a06d1886
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtbmv.c
@@ -0,0 +1,103 @@
+/*
+ * cblas_dtbmv.c
+ * The program is a C interface to dtbmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const double *A, const int lda,
+ double *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtbmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+
+ }
+ else cblas_xerbla(1, "cblas_dtbmv", "Illegal Order setting, %d\n", order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtbsv.c b/ml/dlib/dlib/external/cblas/cblas_dtbsv.c
new file mode 100644
index 000000000..aaf4a8d4b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtbsv.c
@@ -0,0 +1,103 @@
+/*
+ * cblas_dtbsv.c
+ * The program is a C interface to dtbsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const double *A, const int lda,
+ double *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_dtbsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtpmv.c b/ml/dlib/dlib/external/cblas/cblas_dtpmv.c
new file mode 100644
index 000000000..565f97a68
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtpmv.c
@@ -0,0 +1,98 @@
+/*
+ * cblas_dtpmv.c
+ * The program is a C interface to dtpmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *Ap, double *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_dtpmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtpsv.c b/ml/dlib/dlib/external/cblas/cblas_dtpsv.c
new file mode 100644
index 000000000..4f51ccc56
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtpsv.c
@@ -0,0 +1,99 @@
+/*
+ * cblas_dtpsv.c
+ * The program is a C interface to dtpsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *Ap, double *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+
+ }
+ else cblas_xerbla(1, "cblas_dtpsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrmm.c b/ml/dlib/dlib/external/cblas/cblas_dtrmm.c
new file mode 100644
index 000000000..0f4c0d161
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtrmm.c
@@ -0,0 +1,125 @@
+/*
+ *
+ * cblas_dtrmm.c
+ * This program is a C interface to dtrmm.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrmm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrmm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrmm","Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_dtrmm","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrmm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrmm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrmm","Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_dtrmm","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_dtrmm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrmv.c b/ml/dlib/dlib/external/cblas/cblas_dtrmv.c
new file mode 100644
index 000000000..c20ea0626
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtrmv.c
@@ -0,0 +1,103 @@
+/*
+ *
+ * cblas_dtrmv.c
+ * This program is a C interface to sgemv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *A, const int lda,
+ double *X, const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ } else cblas_xerbla(1, "cblas_dtrmv", "Illegal order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrsm.c b/ml/dlib/dlib/external/cblas/cblas_dtrsm.c
new file mode 100644
index 000000000..986425f60
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtrsm.c
@@ -0,0 +1,130 @@
+/*
+ *
+ * cblas_dtrsm.c
+ * This program is a C interface to dtrsm.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb)
+
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if ( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrsm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if ( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrsm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if ( TransA == CblasTrans ) TA='T';
+ else if ( TransA == CblasConjTrans) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrsm","Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if ( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_dtrsm","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha,
+ A, &F77_lda, B, &F77_ldb);
+ }
+ else if (Order == CblasRowMajor)
+ {
+ if ( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrsm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if ( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrsm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if ( TransA == CblasTrans ) TA='T';
+ else if ( TransA == CblasConjTrans) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrsm","Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if ( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_dtrsm","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_dtrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A,
+ &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_dtrsm","Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrsv.c b/ml/dlib/dlib/external/cblas/cblas_dtrsv.c
new file mode 100644
index 000000000..5c4ed5637
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dtrsv.c
@@ -0,0 +1,102 @@
+/*
+ * cblas_dtrsv.c
+ * The program is a C interface to dtrsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const double *A, const int lda, double *X,
+ const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_dtrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_dtrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_dtrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_dtrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_dtrsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dzasum.c b/ml/dlib/dlib/external/cblas/cblas_dzasum.c
new file mode 100644
index 000000000..b32f573e5
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dzasum.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_dzasum.c
+ *
+ * The program is a C interface to dzasum.
+ * It calls the fortran wrapper before calling dzasum.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_dzasum( const int N, const void *X, const int incX)
+{
+ double asum;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_dzasum_sub( &F77_N, X, &F77_incX, &asum);
+ return asum;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_dznrm2.c b/ml/dlib/dlib/external/cblas/cblas_dznrm2.c
new file mode 100644
index 000000000..dfa2bfc83
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_dznrm2.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_dznrm2.c
+ *
+ * The program is a C interface to dznrm2.
+ * It calls the fortran wrapper before calling dznrm2.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+double cblas_dznrm2( const int N, const void *X, const int incX)
+{
+ double nrm2;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_dznrm2_sub( &F77_N, X, &F77_incX, &nrm2);
+ return nrm2;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_f77.h b/ml/dlib/dlib/external/cblas/cblas_f77.h
new file mode 100644
index 000000000..18435cd30
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_f77.h
@@ -0,0 +1,701 @@
+/*
+ * cblas_f77.h
+ * Written by Keita Teranishi
+ *
+ * Updated by Jeff Horner
+ * Merged cblas_f77.h and cblas_fortran_header.h
+ */
+
+#ifndef CBLAS_F77_H
+#define CBLAS_f77_H
+
+#ifdef CRAY
+ #include <fortran.h>
+ #define F77_CHAR _fcd
+ #define C2F_CHAR(a) ( _cptofcd( (a), 1 ) )
+ #define C2F_STR(a, i) ( _cptofcd( (a), (i) ) )
+ #define F77_STRLEN(a) (_fcdlen)
+#endif
+
+#ifdef WeirdNEC
+ #define F77_INT long
+#endif
+
+#ifdef F77_CHAR
+ #define FCHAR F77_CHAR
+#else
+ #define FCHAR char *
+#endif
+
+#ifdef F77_INT
+ #define FINT const F77_INT *
+ #define FINT2 F77_INT *
+#else
+ #define FINT const int *
+ #define FINT2 int *
+#endif
+
+#if defined(ADD_)
+/*
+ * Level 1 BLAS
+ */
+#define F77_xerbla xerbla_
+ #define F77_srotg srotg_
+ #define F77_srotmg srotmg_
+ #define F77_srot srot_
+ #define F77_srotm srotm_
+ #define F77_drotg drotg_
+ #define F77_drotmg drotmg_
+ #define F77_drot drot_
+ #define F77_drotm drotm_
+ #define F77_sswap sswap_
+ #define F77_scopy scopy_
+ #define F77_saxpy saxpy_
+ #define F77_isamax_sub isamaxsub_
+ #define F77_dswap dswap_
+ #define F77_dcopy dcopy_
+ #define F77_daxpy daxpy_
+ #define F77_idamax_sub idamaxsub_
+ #define F77_cswap cswap_
+ #define F77_ccopy ccopy_
+ #define F77_caxpy caxpy_
+ #define F77_icamax_sub icamaxsub_
+ #define F77_zswap zswap_
+ #define F77_zcopy zcopy_
+ #define F77_zaxpy zaxpy_
+ #define F77_izamax_sub izamaxsub_
+ #define F77_sdot_sub sdotsub_
+ #define F77_ddot_sub ddotsub_
+ #define F77_dsdot_sub dsdotsub_
+ #define F77_sscal sscal_
+ #define F77_dscal dscal_
+ #define F77_cscal cscal_
+ #define F77_zscal zscal_
+ #define F77_csscal csscal_
+ #define F77_zdscal zdscal_
+ #define F77_cdotu_sub cdotusub_
+ #define F77_cdotc_sub cdotcsub_
+ #define F77_zdotu_sub zdotusub_
+ #define F77_zdotc_sub zdotcsub_
+ #define F77_snrm2_sub snrm2sub_
+ #define F77_sasum_sub sasumsub_
+ #define F77_dnrm2_sub dnrm2sub_
+ #define F77_dasum_sub dasumsub_
+ #define F77_scnrm2_sub scnrm2sub_
+ #define F77_scasum_sub scasumsub_
+ #define F77_dznrm2_sub dznrm2sub_
+ #define F77_dzasum_sub dzasumsub_
+ #define F77_sdsdot_sub sdsdotsub_
+/*
+ * Level 2 BLAS
+ */
+ #define F77_ssymv ssymv_
+ #define F77_ssbmv ssbmv_
+ #define F77_sspmv sspmv_
+ #define F77_sger sger_
+ #define F77_ssyr ssyr_
+ #define F77_sspr sspr_
+ #define F77_ssyr2 ssyr2_
+ #define F77_sspr2 sspr2_
+ #define F77_dsymv dsymv_
+ #define F77_dsbmv dsbmv_
+ #define F77_dspmv dspmv_
+ #define F77_dger dger_
+ #define F77_dsyr dsyr_
+ #define F77_dspr dspr_
+ #define F77_dsyr2 dsyr2_
+ #define F77_dspr2 dspr2_
+ #define F77_chemv chemv_
+ #define F77_chbmv chbmv_
+ #define F77_chpmv chpmv_
+ #define F77_cgeru cgeru_
+ #define F77_cgerc cgerc_
+ #define F77_cher cher_
+ #define F77_chpr chpr_
+ #define F77_cher2 cher2_
+ #define F77_chpr2 chpr2_
+ #define F77_zhemv zhemv_
+ #define F77_zhbmv zhbmv_
+ #define F77_zhpmv zhpmv_
+ #define F77_zgeru zgeru_
+ #define F77_zgerc zgerc_
+ #define F77_zher zher_
+ #define F77_zhpr zhpr_
+ #define F77_zher2 zher2_
+ #define F77_zhpr2 zhpr2_
+ #define F77_sgemv sgemv_
+ #define F77_sgbmv sgbmv_
+ #define F77_strmv strmv_
+ #define F77_stbmv stbmv_
+ #define F77_stpmv stpmv_
+ #define F77_strsv strsv_
+ #define F77_stbsv stbsv_
+ #define F77_stpsv stpsv_
+ #define F77_dgemv dgemv_
+ #define F77_dgbmv dgbmv_
+ #define F77_dtrmv dtrmv_
+ #define F77_dtbmv dtbmv_
+ #define F77_dtpmv dtpmv_
+ #define F77_dtrsv dtrsv_
+ #define F77_dtbsv dtbsv_
+ #define F77_dtpsv dtpsv_
+ #define F77_cgemv cgemv_
+ #define F77_cgbmv cgbmv_
+ #define F77_ctrmv ctrmv_
+ #define F77_ctbmv ctbmv_
+ #define F77_ctpmv ctpmv_
+ #define F77_ctrsv ctrsv_
+ #define F77_ctbsv ctbsv_
+ #define F77_ctpsv ctpsv_
+ #define F77_zgemv zgemv_
+ #define F77_zgbmv zgbmv_
+ #define F77_ztrmv ztrmv_
+ #define F77_ztbmv ztbmv_
+ #define F77_ztpmv ztpmv_
+ #define F77_ztrsv ztrsv_
+ #define F77_ztbsv ztbsv_
+ #define F77_ztpsv ztpsv_
+/*
+ * Level 3 BLAS
+ */
+ #define F77_chemm chemm_
+ #define F77_cherk cherk_
+ #define F77_cher2k cher2k_
+ #define F77_zhemm zhemm_
+ #define F77_zherk zherk_
+ #define F77_zher2k zher2k_
+ #define F77_sgemm sgemm_
+ #define F77_ssymm ssymm_
+ #define F77_ssyrk ssyrk_
+ #define F77_ssyr2k ssyr2k_
+ #define F77_strmm strmm_
+ #define F77_strsm strsm_
+ #define F77_dgemm dgemm_
+ #define F77_dsymm dsymm_
+ #define F77_dsyrk dsyrk_
+ #define F77_dsyr2k dsyr2k_
+ #define F77_dtrmm dtrmm_
+ #define F77_dtrsm dtrsm_
+ #define F77_cgemm cgemm_
+ #define F77_csymm csymm_
+ #define F77_csyrk csyrk_
+ #define F77_csyr2k csyr2k_
+ #define F77_ctrmm ctrmm_
+ #define F77_ctrsm ctrsm_
+ #define F77_zgemm zgemm_
+ #define F77_zsymm zsymm_
+ #define F77_zsyrk zsyrk_
+ #define F77_zsyr2k zsyr2k_
+ #define F77_ztrmm ztrmm_
+ #define F77_ztrsm ztrsm_
+#elif defined(UPCASE)
+/*
+ * Level 1 BLAS
+ */
+#define F77_xerbla XERBLA
+ #define F77_srotg SROTG
+ #define F77_srotmg SROTMG
+ #define F77_srot SROT
+ #define F77_srotm SROTM
+ #define F77_drotg DROTG
+ #define F77_drotmg DROTMG
+ #define F77_drot DROT
+ #define F77_drotm DROTM
+ #define F77_sswap SSWAP
+ #define F77_scopy SCOPY
+ #define F77_saxpy SAXPY
+ #define F77_isamax_sub ISAMAXSUB
+ #define F77_dswap DSWAP
+ #define F77_dcopy DCOPY
+ #define F77_daxpy DAXPY
+ #define F77_idamax_sub IDAMAXSUB
+ #define F77_cswap CSWAP
+ #define F77_ccopy CCOPY
+ #define F77_caxpy CAXPY
+ #define F77_icamax_sub ICAMAXSUB
+ #define F77_zswap ZSWAP
+ #define F77_zcopy ZCOPY
+ #define F77_zaxpy ZAXPY
+ #define F77_izamax_sub IZAMAXSUB
+ #define F77_sdot_sub SDOTSUB
+ #define F77_ddot_sub DDOTSUB
+ #define F77_dsdot_sub DSDOTSUB
+ #define F77_sscal SSCAL
+ #define F77_dscal DSCAL
+ #define F77_cscal CSCAL
+ #define F77_zscal ZSCAL
+ #define F77_csscal CSSCAL
+ #define F77_zdscal ZDSCAL
+ #define F77_cdotu_sub CDOTUSUB
+ #define F77_cdotc_sub CDOTCSUB
+ #define F77_zdotu_sub ZDOTUSUB
+ #define F77_zdotc_sub ZDOTCSUB
+ #define F77_snrm2_sub SNRM2SUB
+ #define F77_sasum_sub SASUMSUB
+ #define F77_dnrm2_sub DNRM2SUB
+ #define F77_dasum_sub DASUMSUB
+ #define F77_scnrm2_sub SCNRM2SUB
+ #define F77_scasum_sub SCASUMSUB
+ #define F77_dznrm2_sub DZNRM2SUB
+ #define F77_dzasum_sub DZASUMSUB
+ #define F77_sdsdot_sub SDSDOTSUB
+/*
+ * Level 2 BLAS
+ */
+ #define F77_ssymv SSYMV
+ #define F77_ssbmv SSBMV
+ #define F77_sspmv SSPMV
+ #define F77_sger SGER
+ #define F77_ssyr SSYR
+ #define F77_sspr SSPR
+ #define F77_ssyr2 SSYR2
+ #define F77_sspr2 SSPR2
+ #define F77_dsymv DSYMV
+ #define F77_dsbmv DSBMV
+ #define F77_dspmv DSPMV
+ #define F77_dger DGER
+ #define F77_dsyr DSYR
+ #define F77_dspr DSPR
+ #define F77_dsyr2 DSYR2
+ #define F77_dspr2 DSPR2
+ #define F77_chemv CHEMV
+ #define F77_chbmv CHBMV
+ #define F77_chpmv CHPMV
+ #define F77_cgeru CGERU
+ #define F77_cgerc CGERC
+ #define F77_cher CHER
+ #define F77_chpr CHPR
+ #define F77_cher2 CHER2
+ #define F77_chpr2 CHPR2
+ #define F77_zhemv ZHEMV
+ #define F77_zhbmv ZHBMV
+ #define F77_zhpmv ZHPMV
+ #define F77_zgeru ZGERU
+ #define F77_zgerc ZGERC
+ #define F77_zher ZHER
+ #define F77_zhpr ZHPR
+ #define F77_zher2 ZHER2
+ #define F77_zhpr2 ZHPR2
+ #define F77_sgemv SGEMV
+ #define F77_sgbmv SGBMV
+ #define F77_strmv STRMV
+ #define F77_stbmv STBMV
+ #define F77_stpmv STPMV
+ #define F77_strsv STRSV
+ #define F77_stbsv STBSV
+ #define F77_stpsv STPSV
+ #define F77_dgemv DGEMV
+ #define F77_dgbmv DGBMV
+ #define F77_dtrmv DTRMV
+ #define F77_dtbmv DTBMV
+ #define F77_dtpmv DTPMV
+ #define F77_dtrsv DTRSV
+ #define F77_dtbsv DTBSV
+ #define F77_dtpsv DTPSV
+ #define F77_cgemv CGEMV
+ #define F77_cgbmv CGBMV
+ #define F77_ctrmv CTRMV
+ #define F77_ctbmv CTBMV
+ #define F77_ctpmv CTPMV
+ #define F77_ctrsv CTRSV
+ #define F77_ctbsv CTBSV
+ #define F77_ctpsv CTPSV
+ #define F77_zgemv ZGEMV
+ #define F77_zgbmv ZGBMV
+ #define F77_ztrmv ZTRMV
+ #define F77_ztbmv ZTBMV
+ #define F77_ztpmv ZTPMV
+ #define F77_ztrsv ZTRSV
+ #define F77_ztbsv ZTBSV
+ #define F77_ztpsv ZTPSV
+/*
+ * Level 3 BLAS
+ */
+ #define F77_chemm CHEMM
+ #define F77_cherk CHERK
+ #define F77_cher2k CHER2K
+ #define F77_zhemm ZHEMM
+ #define F77_zherk ZHERK
+ #define F77_zher2k ZHER2K
+ #define F77_sgemm SGEMM
+ #define F77_ssymm SSYMM
+ #define F77_ssyrk SSYRK
+ #define F77_ssyr2k SSYR2K
+ #define F77_strmm STRMM
+ #define F77_strsm STRSM
+ #define F77_dgemm DGEMM
+ #define F77_dsymm DSYMM
+ #define F77_dsyrk DSYRK
+ #define F77_dsyr2k DSYR2K
+ #define F77_dtrmm DTRMM
+ #define F77_dtrsm DTRSM
+ #define F77_cgemm CGEMM
+ #define F77_csymm CSYMM
+ #define F77_csyrk CSYRK
+ #define F77_csyr2k CSYR2K
+ #define F77_ctrmm CTRMM
+ #define F77_ctrsm CTRSM
+ #define F77_zgemm ZGEMM
+ #define F77_zsymm ZSYMM
+ #define F77_zsyrk ZSYRK
+ #define F77_zsyr2k ZSYR2K
+ #define F77_ztrmm ZTRMM
+ #define F77_ztrsm ZTRSM
+#elif defined(NOCHANGE)
+/*
+ * Level 1 BLAS
+ */
+#define F77_xerbla xerbla
+ #define F77_srotg srotg
+ #define F77_srotmg srotmg
+ #define F77_srot srot
+ #define F77_srotm srotm
+ #define F77_drotg drotg
+ #define F77_drotmg drotmg
+ #define F77_drot drot
+ #define F77_drotm drotm
+ #define F77_sswap sswap
+ #define F77_scopy scopy
+ #define F77_saxpy saxpy
+ #define F77_isamax_sub isamaxsub
+ #define F77_dswap dswap
+ #define F77_dcopy dcopy
+ #define F77_daxpy daxpy
+ #define F77_idamax_sub idamaxsub
+ #define F77_cswap cswap
+ #define F77_ccopy ccopy
+ #define F77_caxpy caxpy
+ #define F77_icamax_sub icamaxsub
+ #define F77_zswap zswap
+ #define F77_zcopy zcopy
+ #define F77_zaxpy zaxpy
+ #define F77_izamax_sub izamaxsub
+ #define F77_sdot_sub sdotsub
+ #define F77_ddot_sub ddotsub
+ #define F77_dsdot_sub dsdotsub
+ #define F77_sscal sscal
+ #define F77_dscal dscal
+ #define F77_cscal cscal
+ #define F77_zscal zscal
+ #define F77_csscal csscal
+ #define F77_zdscal zdscal
+ #define F77_cdotu_sub cdotusub
+ #define F77_cdotc_sub cdotcsub
+ #define F77_zdotu_sub zdotusub
+ #define F77_zdotc_sub zdotcsub
+ #define F77_snrm2_sub snrm2sub
+ #define F77_sasum_sub sasumsub
+ #define F77_dnrm2_sub dnrm2sub
+ #define F77_dasum_sub dasumsub
+ #define F77_scnrm2_sub scnrm2sub
+ #define F77_scasum_sub scasumsub
+ #define F77_dznrm2_sub dznrm2sub
+ #define F77_dzasum_sub dzasumsub
+ #define F77_sdsdot_sub sdsdotsub
+/*
+ * Level 2 BLAS
+ */
+ #define F77_ssymv ssymv
+ #define F77_ssbmv ssbmv
+ #define F77_sspmv sspmv
+ #define F77_sger sger
+ #define F77_ssyr ssyr
+ #define F77_sspr sspr
+ #define F77_ssyr2 ssyr2
+ #define F77_sspr2 sspr2
+ #define F77_dsymv dsymv
+ #define F77_dsbmv dsbmv
+ #define F77_dspmv dspmv
+ #define F77_dger dger
+ #define F77_dsyr dsyr
+ #define F77_dspr dspr
+ #define F77_dsyr2 dsyr2
+ #define F77_dspr2 dspr2
+ #define F77_chemv chemv
+ #define F77_chbmv chbmv
+ #define F77_chpmv chpmv
+ #define F77_cgeru cgeru
+ #define F77_cgerc cgerc
+ #define F77_cher cher
+ #define F77_chpr chpr
+ #define F77_cher2 cher2
+ #define F77_chpr2 chpr2
+ #define F77_zhemv zhemv
+ #define F77_zhbmv zhbmv
+ #define F77_zhpmv zhpmv
+ #define F77_zgeru zgeru
+ #define F77_zgerc zgerc
+ #define F77_zher zher
+ #define F77_zhpr zhpr
+ #define F77_zher2 zher2
+ #define F77_zhpr2 zhpr2
+ #define F77_sgemv sgemv
+ #define F77_sgbmv sgbmv
+ #define F77_strmv strmv
+ #define F77_stbmv stbmv
+ #define F77_stpmv stpmv
+ #define F77_strsv strsv
+ #define F77_stbsv stbsv
+ #define F77_stpsv stpsv
+ #define F77_dgemv dgemv
+ #define F77_dgbmv dgbmv
+ #define F77_dtrmv dtrmv
+ #define F77_dtbmv dtbmv
+ #define F77_dtpmv dtpmv
+ #define F77_dtrsv dtrsv
+ #define F77_dtbsv dtbsv
+ #define F77_dtpsv dtpsv
+ #define F77_cgemv cgemv
+ #define F77_cgbmv cgbmv
+ #define F77_ctrmv ctrmv
+ #define F77_ctbmv ctbmv
+ #define F77_ctpmv ctpmv
+ #define F77_ctrsv ctrsv
+ #define F77_ctbsv ctbsv
+ #define F77_ctpsv ctpsv
+ #define F77_zgemv zgemv
+ #define F77_zgbmv zgbmv
+ #define F77_ztrmv ztrmv
+ #define F77_ztbmv ztbmv
+ #define F77_ztpmv ztpmv
+ #define F77_ztrsv ztrsv
+ #define F77_ztbsv ztbsv
+ #define F77_ztpsv ztpsv
+/*
+ * Level 3 BLAS
+ */
+ #define F77_chemm chemm
+ #define F77_cherk cherk
+ #define F77_cher2k cher2k
+ #define F77_zhemm zhemm
+ #define F77_zherk zherk
+ #define F77_zher2k zher2k
+ #define F77_sgemm sgemm
+ #define F77_ssymm ssymm
+ #define F77_ssyrk ssyrk
+ #define F77_ssyr2k ssyr2k
+ #define F77_strmm strmm
+ #define F77_strsm strsm
+ #define F77_dgemm dgemm
+ #define F77_dsymm dsymm
+ #define F77_dsyrk dsyrk
+ #define F77_dsyr2k dsyr2k
+ #define F77_dtrmm dtrmm
+ #define F77_dtrsm dtrsm
+ #define F77_cgemm cgemm
+ #define F77_csymm csymm
+ #define F77_csyrk csyrk
+ #define F77_csyr2k csyr2k
+ #define F77_ctrmm ctrmm
+ #define F77_ctrsm ctrsm
+ #define F77_zgemm zgemm
+ #define F77_zsymm zsymm
+ #define F77_zsyrk zsyrk
+ #define F77_zsyr2k zsyr2k
+ #define F77_ztrmm ztrmm
+ #define F77_ztrsm ztrsm
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ void F77_xerbla(FCHAR, void *);
+/*
+ * Level 1 Fortran Prototypes
+ */
+
+/* Single Precision */
+
+ void F77_srot(FINT, float *, FINT, float *, FINT, const float *, const float *);
+ void F77_srotg(float *,float *,float *,float *);
+ void F77_srotm( FINT, float *, FINT, float *, FINT, const float *);
+ void F77_srotmg(float *,float *,float *,const float *, float *);
+ void F77_sswap( FINT, float *, FINT, float *, FINT);
+ void F77_scopy( FINT, const float *, FINT, float *, FINT);
+ void F77_saxpy( FINT, const float *, const float *, FINT, float *, FINT);
+ void F77_sdot_sub(FINT, const float *, FINT, const float *, FINT, float *);
+ void F77_sdsdot_sub( FINT, const float *, const float *, FINT, const float *, FINT, float *);
+ void F77_sscal( FINT, const float *, float *, FINT);
+ void F77_snrm2_sub( FINT, const float *, FINT, float *);
+ void F77_sasum_sub( FINT, const float *, FINT, float *);
+ void F77_isamax_sub( FINT, const float * , FINT, FINT2);
+
+/* Double Precision */
+
+ void F77_drot(FINT, double *, FINT, double *, FINT, const double *, const double *);
+ void F77_drotg(double *,double *,double *,double *);
+ void F77_drotm( FINT, double *, FINT, double *, FINT, const double *);
+ void F77_drotmg(double *,double *,double *,const double *, double *);
+ void F77_dswap( FINT, double *, FINT, double *, FINT);
+ void F77_dcopy( FINT, const double *, FINT, double *, FINT);
+ void F77_daxpy( FINT, const double *, const double *, FINT, double *, FINT);
+ void F77_dswap( FINT, double *, FINT, double *, FINT);
+ void F77_dsdot_sub(FINT, const float *, FINT, const float *, FINT, double *);
+ void F77_ddot_sub( FINT, const double *, FINT, const double *, FINT, double *);
+ void F77_dscal( FINT, const double *, double *, FINT);
+ void F77_dnrm2_sub( FINT, const double *, FINT, double *);
+ void F77_dasum_sub( FINT, const double *, FINT, double *);
+ void F77_idamax_sub( FINT, const double * , FINT, FINT2);
+
+/* Single Complex Precision */
+
+ void F77_cswap( FINT, void *, FINT, void *, FINT);
+ void F77_ccopy( FINT, const void *, FINT, void *, FINT);
+ void F77_caxpy( FINT, const void *, const void *, FINT, void *, FINT);
+ void F77_cswap( FINT, void *, FINT, void *, FINT);
+ void F77_cdotc_sub( FINT, const void *, FINT, const void *, FINT, void *);
+ void F77_cdotu_sub( FINT, const void *, FINT, const void *, FINT, void *);
+ void F77_cscal( FINT, const void *, void *, FINT);
+ void F77_icamax_sub( FINT, const void *, FINT, FINT2);
+ void F77_csscal( FINT, const float *, void *, FINT);
+ void F77_scnrm2_sub( FINT, const void *, FINT, float *);
+ void F77_scasum_sub( FINT, const void *, FINT, float *);
+
+/* Double Complex Precision */
+
+ void F77_zswap( FINT, void *, FINT, void *, FINT);
+ void F77_zcopy( FINT, const void *, FINT, void *, FINT);
+ void F77_zaxpy( FINT, const void *, const void *, FINT, void *, FINT);
+ void F77_zswap( FINT, void *, FINT, void *, FINT);
+ void F77_zdotc_sub( FINT, const void *, FINT, const void *, FINT, void *);
+ void F77_zdotu_sub( FINT, const void *, FINT, const void *, FINT, void *);
+ void F77_zdscal( FINT, const double *, void *, FINT);
+ void F77_zscal( FINT, const void *, void *, FINT);
+ void F77_dznrm2_sub( FINT, const void *, FINT, double *);
+ void F77_dzasum_sub( FINT, const void *, FINT, double *);
+ void F77_izamax_sub( FINT, const void *, FINT, FINT2);
+
+/*
+ * Level 2 Fortran Prototypes
+ */
+
+/* Single Precision */
+
+ void F77_sgemv(FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_sgbmv(FCHAR, FINT, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_ssymv(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_ssbmv(FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_sspmv(FCHAR, FINT, const float *, const float *, const float *, FINT, const float *, float *, FINT);
+ void F77_strmv( FCHAR, FCHAR, FCHAR, FINT, const float *, FINT, float *, FINT);
+ void F77_stbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, FINT, float *, FINT);
+ void F77_strsv( FCHAR, FCHAR, FCHAR, FINT, const float *, FINT, float *, FINT);
+ void F77_stbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, FINT, float *, FINT);
+ void F77_stpmv( FCHAR, FCHAR, FCHAR, FINT, const float *, float *, FINT);
+ void F77_stpsv( FCHAR, FCHAR, FCHAR, FINT, const float *, float *, FINT);
+ void F77_sger( FINT, FINT, const float *, const float *, FINT, const float *, FINT, float *, FINT);
+ void F77_ssyr(FCHAR, FINT, const float *, const float *, FINT, float *, FINT);
+ void F77_sspr(FCHAR, FINT, const float *, const float *, FINT, float *);
+ void F77_sspr2(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, float *);
+ void F77_ssyr2(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, float *, FINT);
+
+/* Double Precision */
+
+ void F77_dgemv(FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dgbmv(FCHAR, FINT, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dsymv(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dsbmv(FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dspmv(FCHAR, FINT, const double *, const double *, const double *, FINT, const double *, double *, FINT);
+ void F77_dtrmv( FCHAR, FCHAR, FCHAR, FINT, const double *, FINT, double *, FINT);
+ void F77_dtbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, FINT, double *, FINT);
+ void F77_dtrsv( FCHAR, FCHAR, FCHAR, FINT, const double *, FINT, double *, FINT);
+ void F77_dtbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, FINT, double *, FINT);
+ void F77_dtpmv( FCHAR, FCHAR, FCHAR, FINT, const double *, double *, FINT);
+ void F77_dtpsv( FCHAR, FCHAR, FCHAR, FINT, const double *, double *, FINT);
+ void F77_dger( FINT, FINT, const double *, const double *, FINT, const double *, FINT, double *, FINT);
+ void F77_dsyr(FCHAR, FINT, const double *, const double *, FINT, double *, FINT);
+ void F77_dspr(FCHAR, FINT, const double *, const double *, FINT, double *);
+ void F77_dspr2(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, double *);
+ void F77_dsyr2(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, double *, FINT);
+
+/* Single Complex Precision */
+
+ void F77_cgemv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_cgbmv(FCHAR, FINT, FINT, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_chemv(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_chbmv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_chpmv(FCHAR, FINT, const void *, const void *, const void *, FINT, const void *, void *, FINT);
+ void F77_ctrmv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT);
+ void F77_ctbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT);
+ void F77_ctpmv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *, FINT);
+ void F77_ctrsv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT);
+ void F77_ctbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT);
+ void F77_ctpsv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *,FINT);
+ void F77_cgerc( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_cgeru( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_cher(FCHAR, FINT, const float *, const void *, FINT, void *, FINT);
+ void F77_cher2(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_chpr(FCHAR, FINT, const float *, const void *, FINT, void *);
+ void F77_chpr2(FCHAR, FINT, const float *, const void *, FINT, const void *, FINT, void *);
+
+/* Double Complex Precision */
+
+ void F77_zgemv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_zgbmv(FCHAR, FINT, FINT, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_zhemv(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_zhbmv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT);
+ void F77_zhpmv(FCHAR, FINT, const void *, const void *, const void *, FINT, const void *, void *, FINT);
+ void F77_ztrmv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT);
+ void F77_ztbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT);
+ void F77_ztpmv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *, FINT);
+ void F77_ztrsv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT);
+ void F77_ztbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT);
+ void F77_ztpsv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *,FINT);
+ void F77_zgerc( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_zgeru( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_zher(FCHAR, FINT, const double *, const void *, FINT, void *, FINT);
+ void F77_zher2(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT);
+ void F77_zhpr(FCHAR, FINT, const double *, const void *, FINT, void *);
+ void F77_zhpr2(FCHAR, FINT, const double *, const void *, FINT, const void *, FINT, void *);
+
+/*
+ * Level 3 Fortran Prototypes
+ */
+
+/* Single Precision */
+
+ void F77_sgemm(FCHAR, FCHAR, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_ssymm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_ssyrk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT);
+ void F77_ssyr2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_strmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT);
+ void F77_strsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT);
+
+/* Double Precision */
+
+ void F77_dgemm(FCHAR, FCHAR, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dsymm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dsyrk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT);
+ void F77_dsyr2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_dtrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT);
+ void F77_dtrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT);
+
+/* Single Complex Precision */
+
+ void F77_cgemm(FCHAR, FCHAR, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_csymm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_chemm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_csyrk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT);
+ void F77_cherk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT);
+ void F77_csyr2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_cher2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT);
+ void F77_ctrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT);
+ void F77_ctrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT);
+
+/* Double Complex Precision */
+
+ void F77_zgemm(FCHAR, FCHAR, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_zsymm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_zhemm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_zsyrk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT);
+ void F77_zherk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT);
+ void F77_zsyr2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_zher2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT);
+ void F77_ztrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT);
+ void F77_ztrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* CBLAS_F77_H */
diff --git a/ml/dlib/dlib/external/cblas/cblas_icamax.c b/ml/dlib/dlib/external/cblas/cblas_icamax.c
new file mode 100644
index 000000000..b3ffe6eec
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_icamax.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_icamax.c
+ *
+ * The program is a C interface to icamax.
+ * It calls the fortran wrapper before calling icamax.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+CBLAS_INDEX cblas_icamax( const int N, const void *X, const int incX)
+{
+ F77_INT iamax;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_icamax_sub( &F77_N, X, &F77_incX, &iamax);
+ return iamax ? iamax-1 : 0;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_idamax.c b/ml/dlib/dlib/external/cblas/cblas_idamax.c
new file mode 100644
index 000000000..e42e459ea
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_idamax.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_idamax.c
+ *
+ * The program is a C interface to idamax.
+ * It calls the fortran wrapper before calling idamax.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+CBLAS_INDEX cblas_idamax( const int N, const double *X, const int incX)
+{
+ F77_INT iamax;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_idamax_sub( &F77_N, X, &F77_incX, &iamax);
+ return iamax ? iamax-1 : 0;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_isamax.c b/ml/dlib/dlib/external/cblas/cblas_isamax.c
new file mode 100644
index 000000000..63d639c7f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_isamax.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_isamax.c
+ *
+ * The program is a C interface to isamax.
+ * It calls the fortran wrapper before calling isamax.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+CBLAS_INDEX cblas_isamax( const int N, const float *X, const int incX)
+{
+ F77_INT iamax;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_isamax_sub( &F77_N, X, &F77_incX, &iamax);
+ return iamax ? iamax-1 : 0;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_izamax.c b/ml/dlib/dlib/external/cblas/cblas_izamax.c
new file mode 100644
index 000000000..78eda4042
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_izamax.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_izamax.c
+ *
+ * The program is a C interface to izamax.
+ * It calls the fortran wrapper before calling izamax.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+CBLAS_INDEX cblas_izamax( const int N, const void *X, const int incX)
+{
+ F77_INT iamax;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_izamax_sub( &F77_N, X, &F77_incX, &iamax);
+ return (iamax ? iamax-1 : 0);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sasum.c b/ml/dlib/dlib/external/cblas/cblas_sasum.c
new file mode 100644
index 000000000..7d4c32cf9
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sasum.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_sasum.c
+ *
+ * The program is a C interface to sasum.
+ * It calls the fortran wrapper before calling sasum.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_sasum( const int N, const float *X, const int incX)
+{
+ float asum;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_sasum_sub( &F77_N, X, &F77_incX, &asum);
+ return asum;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_saxpy.c b/ml/dlib/dlib/external/cblas/cblas_saxpy.c
new file mode 100644
index 000000000..2eee8e06e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_saxpy.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_saxpy.c
+ *
+ * The program is a C interface to saxpy.
+ * It calls the fortran wrapper before calling saxpy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_saxpy( const int N, const float alpha, const float *X,
+ const int incX, float *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_saxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_scasum.c b/ml/dlib/dlib/external/cblas/cblas_scasum.c
new file mode 100644
index 000000000..e1fa53090
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_scasum.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_scasum.c
+ *
+ * The program is a C interface to scasum.
+ * It calls the fortran wrapper before calling scasum.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_scasum( const int N, const void *X, const int incX)
+{
+ float asum;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_scasum_sub( &F77_N, X, &F77_incX, &asum);
+ return asum;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_scnrm2.c b/ml/dlib/dlib/external/cblas/cblas_scnrm2.c
new file mode 100644
index 000000000..fa48454ed
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_scnrm2.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_scnrm2.c
+ *
+ * The program is a C interface to scnrm2.
+ * It calls the fortran wrapper before calling scnrm2.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_scnrm2( const int N, const void *X, const int incX)
+{
+ float nrm2;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_scnrm2_sub( &F77_N, X, &F77_incX, &nrm2);
+ return nrm2;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_scopy.c b/ml/dlib/dlib/external/cblas/cblas_scopy.c
new file mode 100644
index 000000000..7796959f3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_scopy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_scopy.c
+ *
+ * The program is a C interface to scopy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_scopy( const int N, const float *X,
+ const int incX, float *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_scopy( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sdot.c b/ml/dlib/dlib/external/cblas/cblas_sdot.c
new file mode 100644
index 000000000..baf859272
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sdot.c
@@ -0,0 +1,25 @@
+/*
+ * cblas_sdot.c
+ *
+ * The program is a C interface to sdot.
+ * It calls the fortran wrapper before calling sdot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_sdot( const int N, const float *X,
+ const int incX, const float *Y, const int incY)
+{
+ float dot;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_sdot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot);
+ return dot;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sdsdot.c b/ml/dlib/dlib/external/cblas/cblas_sdsdot.c
new file mode 100644
index 000000000..b824849b9
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sdsdot.c
@@ -0,0 +1,25 @@
+/*
+ * cblas_sdsdot.c
+ *
+ * The program is a C interface to sdsdot.
+ * It calls the fortran wrapper before calling sdsdot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_sdsdot( const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY)
+{
+ float dot;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_sdsdot_sub( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, &dot);
+ return dot;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sgbmv.c b/ml/dlib/dlib/external/cblas/cblas_sgbmv.c
new file mode 100644
index 000000000..b6de24977
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sgbmv.c
@@ -0,0 +1,72 @@
+/*
+ *
+ * cblas_sgbmv.c
+ * This program is a C interface to sgbmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+ F77_INT F77_KL=KL,F77_KU=KU;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_KL KL
+ #define F77_KU KU
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_sgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, &alpha,
+ A, &F77_lda, X, &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_sgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, &alpha,
+ A ,&F77_lda, X, &F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_sgbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sgemm.c b/ml/dlib/dlib/external/cblas/cblas_sgemm.c
new file mode 100644
index 000000000..f8adeda49
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sgemm.c
@@ -0,0 +1,95 @@
+/*
+ *
+ * cblas_sgemm.c
+ * This program is a C interface to sgemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const float alpha, const float *A,
+ const int lda, const float *B, const int ldb,
+ const float beta, float *C, const int ldc)
+{
+ char TA, TB;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_TB;
+#else
+ #define F77_TA &TA
+ #define F77_TB &TB
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+ if( Order == CblasColMajor )
+ {
+ if(TransA == CblasTrans) TA='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgemm",
+ "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if(TransB == CblasTrans) TB='T';
+ else if ( TransB == CblasConjTrans ) TB='C';
+ else if ( TransB == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_sgemm",
+ "Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_sgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if(TransA == CblasTrans) TB='T';
+ else if ( TransA == CblasConjTrans ) TB='C';
+ else if ( TransA == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgemm",
+ "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if(TransB == CblasTrans) TA='T';
+ else if ( TransB == CblasConjTrans ) TA='C';
+ else if ( TransB == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgemm",
+ "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_sgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc);
+ } else
+ cblas_xerbla(1, "cblas_sgemm",
+ "Illegal Order setting, %d\n", Order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sgemv.c b/ml/dlib/dlib/external/cblas/cblas_sgemv.c
new file mode 100644
index 000000000..d47f3be55
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sgemv.c
@@ -0,0 +1,67 @@
+/*
+ *
+ * cblas_sgemv.c
+ * This program is a C interface to sgemv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgemv","Illegal TransA setting, %d\n", TransA);
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_sgemv(F77_TA, &F77_M, &F77_N, &alpha, A, &F77_lda, X, &F77_incX,
+ &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(2, "cblas_sgemv", "Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_sgemv(F77_TA, &F77_N, &F77_M, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_sgemv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sger.c b/ml/dlib/dlib/external/cblas/cblas_sger.c
new file mode 100644
index 000000000..0313590c7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sger.c
@@ -0,0 +1,39 @@
+/*
+ *
+ * cblas_sger.c
+ * This program is a C interface to sger.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
+ const float alpha, const float *X, const int incX,
+ const float *Y, const int incY, float *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+
+ if (order == CblasColMajor)
+ {
+ F77_sger( &F77_M, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ }
+ else if (order == CblasRowMajor)
+ {
+ F77_sger( &F77_N, &F77_M, &alpha, Y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+ }
+ else cblas_xerbla(1, "cblas_sger", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_snrm2.c b/ml/dlib/dlib/external/cblas/cblas_snrm2.c
new file mode 100644
index 000000000..18161b4fa
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_snrm2.c
@@ -0,0 +1,23 @@
+/*
+ * cblas_snrm2.c
+ *
+ * The program is a C interface to snrm2.
+ * It calls the fortran wrapper before calling snrm2.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+float cblas_snrm2( const int N, const float *X, const int incX)
+{
+ float nrm2;
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_snrm2_sub( &F77_N, X, &F77_incX, &nrm2);
+ return nrm2;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_srot.c b/ml/dlib/dlib/external/cblas/cblas_srot.c
new file mode 100644
index 000000000..cbd1c8c90
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_srot.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_srot.c
+ *
+ * The program is a C interface to srot.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_srot( const int N, float *X, const int incX, float *Y,
+ const int incY, const float c, const float s)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_srot(&F77_N, X, &F77_incX, Y, &F77_incY, &c, &s);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_srotg.c b/ml/dlib/dlib/external/cblas/cblas_srotg.c
new file mode 100644
index 000000000..f6460048d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_srotg.c
@@ -0,0 +1,14 @@
+/*
+ * cblas_srotg.c
+ *
+ * The program is a C interface to srotg.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_srotg( float *a, float *b, float *c, float *s)
+{
+ F77_srotg(a,b,c,s);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_srotm.c b/ml/dlib/dlib/external/cblas/cblas_srotm.c
new file mode 100644
index 000000000..496746454
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_srotm.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_srotm.c
+ *
+ * The program is a C interface to srotm.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_srotm( const int N, float *X, const int incX, float *Y,
+ const int incY, const float *P)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_srotm( &F77_N, X, &F77_incX, Y, &F77_incY, P);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_srotmg.c b/ml/dlib/dlib/external/cblas/cblas_srotmg.c
new file mode 100644
index 000000000..04f978b40
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_srotmg.c
@@ -0,0 +1,15 @@
+/*
+ * cblas_srotmg.c
+ *
+ * The program is a C interface to srotmg.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_srotmg( float *d1, float *d2, float *b1,
+ const float b2, float *p)
+{
+ F77_srotmg(d1,d2,b1,&b2,p);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssbmv.c b/ml/dlib/dlib/external/cblas/cblas_ssbmv.c
new file mode 100644
index 000000000..69663b619
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssbmv.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_ssbmv.c
+ * This program is a C interface to ssbmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const int K, const float alpha, const float *A,
+ const int lda, const float *X, const int incX,
+ const float beta, float *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssbmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_ssbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sscal.c b/ml/dlib/dlib/external/cblas/cblas_sscal.c
new file mode 100644
index 000000000..1f09abe7a
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_sscal.c
+ *
+ * The program is a C interface to sscal.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sscal( const int N, const float alpha, float *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_sscal( &F77_N, &alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sspmv.c b/ml/dlib/dlib/external/cblas/cblas_sspmv.c
new file mode 100644
index 000000000..e3485e742
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sspmv.c
@@ -0,0 +1,62 @@
+/*
+ *
+ * cblas_sspmv.c
+ * This program is a C interface to sspmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sspmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const float alpha, const float *AP,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_sspmv(F77_UL, &F77_N, &alpha, AP, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_sspmv(F77_UL, &F77_N, &alpha,
+ AP, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_sspmv", "Illegal Order setting, %d\n", order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sspr.c b/ml/dlib/dlib/external/cblas/cblas_sspr.c
new file mode 100644
index 000000000..75d669b30
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sspr.c
@@ -0,0 +1,61 @@
+/*
+ *
+ * cblas_sspr.c
+ * This program is a C interface to sspr.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, float *Ap)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_sspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_sspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap);
+ } else cblas_xerbla(1, "cblas_sspr", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sspr2.c b/ml/dlib/dlib/external/cblas/cblas_sspr2.c
new file mode 100644
index 000000000..b6ff9bfa2
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sspr2.c
@@ -0,0 +1,60 @@
+/*
+ *
+ * cblas_sspr2.c
+ * This program is a C interface to sspr2.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY, float *A)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_sspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_sspr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_sspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A);
+ } else cblas_xerbla(1, "cblas_sspr2", "Illegal Order setting, %d\n", order);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_sswap.c b/ml/dlib/dlib/external/cblas/cblas_sswap.c
new file mode 100644
index 000000000..b74d8469c
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_sswap.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_sswap.c
+ *
+ * The program is a C interface to sswap.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_sswap( const int N, float *X, const int incX, float *Y,
+ const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_sswap( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssymm.c b/ml/dlib/dlib/external/cblas/cblas_ssymm.c
new file mode 100644
index 000000000..55c413b48
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssymm.c
@@ -0,0 +1,93 @@
+/*
+ *
+ * cblas_ssymm.c
+ * This program is a C interface to ssymm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *B, const int ldb, const float beta,
+ float *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssymm",
+ "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssymm",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_ssymm(F77_SD, F77_UL, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssymm",
+ "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssymm",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_ssymm(F77_SD, F77_UL, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else cblas_xerbla(1, "cblas_ssymm",
+ "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssymv.c b/ml/dlib/dlib/external/cblas/cblas_ssymv.c
new file mode 100644
index 000000000..c56d4d457
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssymv.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_ssymv.c
+ * This program is a C interface to ssymv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssymv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssymv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssymv(F77_UL, &F77_N, &alpha, A, &F77_lda, X,
+ &F77_incX, &beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssymv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssymv(F77_UL, &F77_N, &alpha,
+ A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY);
+ }
+ else cblas_xerbla(1, "cblas_ssymv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr.c b/ml/dlib/dlib/external/cblas/cblas_ssyr.c
new file mode 100644
index 000000000..4215b9671
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssyr.c
@@ -0,0 +1,59 @@
+/*
+ *
+ * cblas_ssyr.c
+ * This program is a C interface to ssyr.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, float *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_lda=lda;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_lda lda
+#endif
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_ssyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+ } else cblas_xerbla(1, "cblas_ssyr", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr2.c b/ml/dlib/dlib/external/cblas/cblas_ssyr2.c
new file mode 100644
index 000000000..9cdaa412d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssyr2.c
@@ -0,0 +1,65 @@
+/*
+ *
+ * cblas_ssyr2.c
+ * This program is a C interface to ssyr2.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const float alpha, const float *X,
+ const int incX, const float *Y, const int incY, float *A,
+ const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY, F77_lda=lda;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_ssyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasLower) UL = 'U';
+ else if (Uplo == CblasUpper) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_ssyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ } else cblas_xerbla(1, "cblas_ssyr2", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c b/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c
new file mode 100644
index 000000000..9e9f538df
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c
@@ -0,0 +1,96 @@
+/*
+ *
+ * cblas_ssyr2k.c
+ * This program is a C interface to ssyr2k.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const float *A, const int lda,
+ const float *B, const int ldb, const float beta,
+ float *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyr2k",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyr2k",
+ "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_ssyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyr2k",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyr2k",
+ "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_ssyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else cblas_xerbla(1, "cblas_ssyr2k",
+ "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyrk.c b/ml/dlib/dlib/external/cblas/cblas_ssyrk.c
new file mode 100644
index 000000000..55ceb7e3a
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ssyrk.c
@@ -0,0 +1,95 @@
+/*
+ *
+ * cblas_ssyrk.c
+ * This program is a C interface to ssyrk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const float alpha, const float *A, const int lda,
+ const float beta, float *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ssyrk",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyrk",
+ "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_ssyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyrk",
+ "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_ssyrk",
+ "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_ssyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, &beta, C, &F77_ldc);
+ } else cblas_xerbla(1, "cblas_ssyrk",
+ "Illegal Order setting, %d\n", Order);
+ return;
+}
+
diff --git a/ml/dlib/dlib/external/cblas/cblas_stbmv.c b/ml/dlib/dlib/external/cblas/cblas_stbmv.c
new file mode 100644
index 000000000..71ef469a7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_stbmv.c
@@ -0,0 +1,103 @@
+/*
+ * cblas_stbmv.c
+ * This program is a C interface to stbmv.
+ * Written by Keita Teranishi
+ * 3/3/1998
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+
+void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const float *A, const int lda,
+ float *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_stbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stbmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_stbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_stbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_stbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_stbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_stbsv.c b/ml/dlib/dlib/external/cblas/cblas_stbsv.c
new file mode 100644
index 000000000..96df7c0c1
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_stbsv.c
@@ -0,0 +1,103 @@
+/*
+ * cblas_stbsv.c
+ * The program is a C interface to stbsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const float *A, const int lda,
+ float *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_stbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_stbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_stbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_stbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_stbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_stbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_stbsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_stpmv.c b/ml/dlib/dlib/external/cblas/cblas_stpmv.c
new file mode 100644
index 000000000..5cb5cd29d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_stpmv.c
@@ -0,0 +1,99 @@
+/*
+ *
+ * cblas_stpmv.c
+ * This program is a C interface to stpmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *Ap, float *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_stpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_stpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_stpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_stpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_stpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_stpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_stpmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_stpsv.c b/ml/dlib/dlib/external/cblas/cblas_stpsv.c
new file mode 100644
index 000000000..2f0d29c0d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_stpsv.c
@@ -0,0 +1,99 @@
+/*
+ * cblas_stpsv.c
+ * The program is a C interface to stpsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *Ap, float *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_stpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_stpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_stpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_stpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_stpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_stpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_stpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+
+ }
+ else cblas_xerbla(1, "cblas_stpsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_strmm.c b/ml/dlib/dlib/external/cblas/cblas_strmm.c
new file mode 100644
index 000000000..40d6e23dd
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_strmm.c
@@ -0,0 +1,125 @@
+/*
+ *
+ * cblas_strmm.c
+ * This program is a C interface to strmm.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_strmm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_strmm","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strmm","Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_strmm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_strmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_strmm","Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_strmm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strmm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_strmm","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+#ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+#endif
+ F77_strmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A,
+ &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_strmm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_strmv.c b/ml/dlib/dlib/external/cblas/cblas_strmv.c
new file mode 100644
index 000000000..4c2f7b6a6
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_strmv.c
@@ -0,0 +1,103 @@
+/*
+ *
+ * cblas_strmv.c
+ * This program is a C interface to strmv.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *A, const int lda,
+ float *X, const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_strmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_strmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_strmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_strmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_strmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_strmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_strmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_strsm.c b/ml/dlib/dlib/external/cblas/cblas_strsm.c
new file mode 100644
index 000000000..178cf7d06
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_strsm.c
@@ -0,0 +1,120 @@
+/*
+ *
+ * cblas_strsm.c
+ * This program is a C interface to strsm.
+ * Written by Keita Teranishi
+ * 4/6/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb)
+
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_strsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_strsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_strsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_strsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_strsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_strsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_strsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_strsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_strsm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_strsv.c b/ml/dlib/dlib/external/cblas/cblas_strsv.c
new file mode 100644
index 000000000..7c3811974
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_strsv.c
@@ -0,0 +1,102 @@
+/*
+ * cblas_strsv.c
+ * The program is a C interface to strsv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const float *A, const int lda, float *X,
+ const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_strsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_strsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_strsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_strsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans) TA = 'N';
+ else
+ {
+ cblas_xerbla(3, "cblas_strsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_strsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_strsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else cblas_xerbla(1, "cblas_strsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_xerbla.c b/ml/dlib/dlib/external/cblas/cblas_xerbla.c
new file mode 100644
index 000000000..0b5b39f53
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_xerbla.c
@@ -0,0 +1,66 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdarg.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+
+void cblas_xerbla(int info, const char *rout, const char *form, ...)
+{
+ char empty[1] = "";
+ va_list argptr;
+
+ va_start(argptr, form);
+
+ {
+ if (strstr(rout,"gemm") != 0)
+ {
+ if (info == 5 ) info = 4;
+ else if (info == 4 ) info = 5;
+ else if (info == 11) info = 9;
+ else if (info == 9 ) info = 11;
+ }
+ else if (strstr(rout,"symm") != 0 || strstr(rout,"hemm") != 0)
+ {
+ if (info == 5 ) info = 4;
+ else if (info == 4 ) info = 5;
+ }
+ else if (strstr(rout,"trmm") != 0 || strstr(rout,"trsm") != 0)
+ {
+ if (info == 7 ) info = 6;
+ else if (info == 6 ) info = 7;
+ }
+ else if (strstr(rout,"gemv") != 0)
+ {
+ if (info == 4) info = 3;
+ else if (info == 3) info = 4;
+ }
+ else if (strstr(rout,"gbmv") != 0)
+ {
+ if (info == 4) info = 3;
+ else if (info == 3) info = 4;
+ else if (info == 6) info = 5;
+ else if (info == 5) info = 6;
+ }
+ else if (strstr(rout,"ger") != 0)
+ {
+ if (info == 3) info = 2;
+ else if (info == 2) info = 3;
+ else if (info == 8) info = 6;
+ else if (info == 6) info = 8;
+ }
+ else if ( (strstr(rout,"her2") != 0 || strstr(rout,"hpr2") != 0)
+ && strstr(rout,"her2k") == 0 )
+ {
+ if (info == 8) info = 6;
+ else if (info == 6) info = 8;
+ }
+ }
+ if (info)
+ fprintf(stderr, "Parameter %d to routine %s was incorrect\n", info, rout);
+ vfprintf(stderr, form, argptr);
+ va_end(argptr);
+ if (info && !info)
+ F77_xerbla(empty, &info); /* Force link of our F77 error handler */
+ exit(-1);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zaxpy.c b/ml/dlib/dlib/external/cblas/cblas_zaxpy.c
new file mode 100644
index 000000000..f63c4c39b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zaxpy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_zaxpy.c
+ *
+ * The program is a C interface to zaxpy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zaxpy( const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_zaxpy( &F77_N, alpha, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zcopy.c b/ml/dlib/dlib/external/cblas/cblas_zcopy.c
new file mode 100644
index 000000000..a16be28e7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zcopy.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_zcopy.c
+ *
+ * The program is a C interface to zcopy.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zcopy( const int N, const void *X,
+ const int incX, void *Y, const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_zcopy( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c b/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c
new file mode 100644
index 000000000..29dec6c57
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c
@@ -0,0 +1,24 @@
+/*
+ * cblas_zdotc_sub.c
+ *
+ * The program is a C interface to zdotc.
+ * It calls the fortran wrapper before calling zdotc.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zdotc_sub( const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotc)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_zdotc_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotc);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c b/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c
new file mode 100644
index 000000000..48a14bf3d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c
@@ -0,0 +1,24 @@
+/*
+ * cblas_zdotu_sub.c
+ *
+ * The program is a C interface to zdotu.
+ * It calls the fortran wrapper before calling zdotu.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zdotu_sub( const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotu)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_zdotu_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotu);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zdscal.c b/ml/dlib/dlib/external/cblas/cblas_zdscal.c
new file mode 100644
index 000000000..788365bef
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zdscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_zdscal.c
+ *
+ * The program is a C interface to zdscal.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zdscal( const int N, const double alpha, void *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_zdscal( &F77_N, &alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zgbmv.c b/ml/dlib/dlib/external/cblas/cblas_zgbmv.c
new file mode 100644
index 000000000..e5b922429
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zgbmv.c
@@ -0,0 +1,155 @@
+/*
+ * cblas_zgbmv.c
+ * The program is a C interface of zgbmv
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zgbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const int KL, const int KU,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+ F77_INT F77_KL=KL,F77_KU=KU;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_KL KL
+ #define F77_KU KU
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta;
+ double ALPHA[2],BETA[2];
+ int tincY, tincx;
+ double *x=(double *)X, *y=(double *)Y, *st=0, *tx;
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_zgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_zgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, alpha,
+ A, &F77_lda, X, &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+ TA = 'N';
+ if (M > 0)
+ {
+ n = M << 1;
+ x = malloc(n*sizeof(double));
+ tx = x;
+
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if( incY > 0 )
+ tincY = incY;
+ else
+ tincY = -incY;
+
+ y++;
+
+ if (N > 0)
+ {
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ }
+ }
+ else x = (double *) X;
+
+
+ }
+ else
+ {
+ cblas_xerbla(2, "cblas_zgbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ if (TransA == CblasConjTrans)
+ F77_zgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, ALPHA,
+ A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY);
+ else
+ F77_zgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, alpha,
+ A ,&F77_lda, x,&F77_incX, beta, Y, &F77_incY);
+ if (TransA == CblasConjTrans)
+ {
+ if (x != X) free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_zgbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zgemm.c b/ml/dlib/dlib/external/cblas/cblas_zgemm.c
new file mode 100644
index 000000000..c348afa2d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zgemm.c
@@ -0,0 +1,94 @@
+/*
+ *
+ * cblas_zgemm.c
+ * This program is a C interface to zgemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc)
+{
+ char TA, TB;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_TB;
+#else
+ #define F77_TA &TA
+ #define F77_TB &TB
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if(TransA == CblasTrans) TA='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_zgemm","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if(TransB == CblasTrans) TB='T';
+ else if ( TransB == CblasConjTrans ) TB='C';
+ else if ( TransB == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_zgemm","Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_zgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, alpha, A,
+ &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if(TransA == CblasTrans) TB='T';
+ else if ( TransA == CblasConjTrans ) TB='C';
+ else if ( TransA == CblasNoTrans ) TB='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_zgemm","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if(TransB == CblasTrans) TA='T';
+ else if ( TransB == CblasConjTrans ) TA='C';
+ else if ( TransB == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(2, "cblas_zgemm","Illegal TransB setting, %d\n", TransB);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ F77_TB = C2F_CHAR(&TB);
+ #endif
+
+ F77_zgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, alpha, B,
+ &F77_ldb, A, &F77_lda, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zgemm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zgemv.c b/ml/dlib/dlib/external/cblas/cblas_zgemv.c
new file mode 100644
index 000000000..6d5cd0cb2
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zgemv.c
@@ -0,0 +1,153 @@
+/*
+ * cblas_zgemv.c
+ * The program is a C interface of zgemv
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zgemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char TA;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA;
+#else
+ #define F77_TA &TA
+#endif
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+
+ int n, i=0, incx=incX;
+ const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta;
+ double ALPHA[2],BETA[2];
+ int tincY, tincx;
+ double *x=(double *)X, *y=(double *)Y, *st=0, *tx;
+
+
+ if (order == CblasColMajor)
+ {
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(2, "cblas_zgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ F77_zgemv(F77_TA, &F77_M, &F77_N, alpha, A, &F77_lda, X, &F77_incX,
+ beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+ TA = 'N';
+ if (M > 0)
+ {
+ n = M << 1;
+ x = malloc(n*sizeof(double));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+
+ y++;
+
+ if (N > 0)
+ {
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ }
+ }
+ else x = (double *) X;
+ }
+ else
+ {
+ cblas_xerbla(2, "cblas_zgemv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_TA = C2F_CHAR(&TA);
+ #endif
+ if (TransA == CblasConjTrans)
+ F77_zgemv(F77_TA, &F77_N, &F77_M, ALPHA, A, &F77_lda, x,
+ &F77_incX, BETA, Y, &F77_incY);
+ else
+ F77_zgemv(F77_TA, &F77_N, &F77_M, alpha, A, &F77_lda, x,
+ &F77_incX, beta, Y, &F77_incY);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (x != (double *)X) free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_zgemv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zgerc.c b/ml/dlib/dlib/external/cblas/cblas_zgerc.c
new file mode 100644
index 000000000..2fbbcb028
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zgerc.c
@@ -0,0 +1,77 @@
+/*
+ * cblas_zgerc.c
+ * The program is a C interface to zgerc.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incy
+ #define F77_lda lda
+#endif
+
+ int n, i, tincy, incy=incY;
+ double *y=(double *)Y, *yy=(double *)Y, *ty, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ F77_zgerc( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ } else if (order == CblasRowMajor)
+ {
+ if (N > 0)
+ {
+ n = N << 1;
+ y = malloc(n*sizeof(double));
+
+ ty = y;
+ if( incY > 0 ) {
+ i = incY << 1;
+ tincy = 2;
+ st= y+n;
+ } else {
+ i = incY *(-2);
+ tincy = -2;
+ st = y-2;
+ y +=(n-2);
+ }
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += tincy ;
+ yy += i;
+ }
+ while (y != st);
+ y = ty;
+
+ #ifdef F77_INT
+ F77_incY = 1;
+ #else
+ incy = 1;
+ #endif
+ }
+ else y = (double *) Y;
+
+ F77_zgeru( &F77_N, &F77_M, alpha, y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+ if(Y!=y)
+ free(y);
+
+ } else cblas_xerbla(1, "cblas_zgerc", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zgeru.c b/ml/dlib/dlib/external/cblas/cblas_zgeru.c
new file mode 100644
index 000000000..56c3ded68
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zgeru.c
@@ -0,0 +1,37 @@
+/*
+ * cblas_zgeru.c
+ * The program is a C interface to zgeru.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+ #define F77_lda lda
+#endif
+
+
+ if (order == CblasColMajor)
+ {
+ F77_zgeru( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A,
+ &F77_lda);
+ }
+ else if (order == CblasRowMajor)
+ {
+ F77_zgeru( &F77_N, &F77_M, alpha, Y, &F77_incY, X, &F77_incX, A,
+ &F77_lda);
+ }
+ else cblas_xerbla(1, "cblas_zgeru", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhbmv.c b/ml/dlib/dlib/external/cblas/cblas_zhbmv.c
new file mode 100644
index 000000000..491207d3f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhbmv.c
@@ -0,0 +1,145 @@
+/*
+ * cblas_zhbmv.c
+ * The program is a C interface to zhbmv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+#include <stdio.h>
+#include <stdlib.h>
+void cblas_zhbmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo,const int N,const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta;
+ double ALPHA[2],BETA[2];
+ int tincY, tincx;
+ double *x=(double *)X, *y=(double *)Y, *st=0, *tx;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhbmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_zhbmv(F77_UL, &F77_N, &F77_K, alpha, A, &F77_lda, X,
+ &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (double *) X;
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_zhbmv(F77_UL, &F77_N, &F77_K, ALPHA,
+ A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_zhbmv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if(X!=x)
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhemm.c b/ml/dlib/dlib/external/cblas/cblas_zhemm.c
new file mode 100644
index 000000000..a31e9ae87
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhemm.c
@@ -0,0 +1,91 @@
+/*
+ *
+ * cblas_zhemm.c
+ * This program is a C interface to zhemm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhemm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_zhemm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_zhemm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhemm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_zhemm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_zhemm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A,
+ &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zhemm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhemv.c b/ml/dlib/dlib/external/cblas/cblas_zhemv.c
new file mode 100644
index 000000000..c3cdb5958
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhemv.c
@@ -0,0 +1,146 @@
+/*
+ * cblas_zhemv.c
+ * The program is a C interface to zhemv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zhemv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta;
+ double ALPHA[2],BETA[2];
+ int tincY, tincx;
+ double *x=(double *)X, *y=(double *)Y, *st=0, *tx;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhemv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_zhemv(F77_UL, &F77_N, alpha, A, &F77_lda, X, &F77_incX,
+ beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (double *) X;
+
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhemv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_zhemv(F77_UL, &F77_N, ALPHA, A, &F77_lda, x, &F77_incX,
+ BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_zhemv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if ( X != x )
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zher.c b/ml/dlib/dlib/external/cblas/cblas_zher.c
new file mode 100644
index 000000000..30453737c
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zher.c
@@ -0,0 +1,99 @@
+/*
+ * cblas_zher.c
+ * The program is a C interface to zher.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const void *X, const int incX
+ ,void *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+#endif
+ int n, i, tincx, incx=incX;
+ double *x=(double *)X, *xx=(double *)X, *tx, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_zher(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+ }
+ else x = (double *) X;
+ F77_zher(F77_UL, &F77_N, &alpha, x, &F77_incX, A, &F77_lda);
+ } else cblas_xerbla(1, "cblas_zher", "Illegal Order setting, %d\n", order);
+ if(X!=x)
+ free(x);
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zher2.c b/ml/dlib/dlib/external/cblas/cblas_zher2.c
new file mode 100644
index 000000000..8bf0bd733
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zher2.c
@@ -0,0 +1,140 @@
+/*
+ * cblas_zher2.c
+ * The program is a C interface to zher2.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incx
+ #define F77_incY incy
+#endif
+ int n, i, j, tincx, tincy, incx=incX, incy=incY;
+ double *x=(double *)X, *xx=(double *)X, *y=(double *)Y,
+ *yy=(double *)Y, *tx, *ty, *stx, *sty;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher2", "Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_zher2(F77_UL, &F77_N, alpha, X, &F77_incX,
+ Y, &F77_incY, A, &F77_lda);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher2", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+ y = malloc(n*sizeof(double));
+ tx = x;
+ ty = y;
+ if( incX > 0 ) {
+ i = incX << 1 ;
+ tincx = 2;
+ stx= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ stx = x-2;
+ x +=(n-2);
+ }
+
+ if( incY > 0 ) {
+ j = incY << 1;
+ tincy = 2;
+ sty= y+n;
+ } else {
+ j = incY *(-2);
+ tincy = -2;
+ sty = y-2;
+ y +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != stx);
+
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += tincy ;
+ yy += j;
+ }
+ while (y != sty);
+
+ x=tx;
+ y=ty;
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ F77_incY = 1;
+ #else
+ incx = 1;
+ incy = 1;
+ #endif
+ } else
+ {
+ x = (double *) X;
+ y = (double *) Y;
+ }
+ F77_zher2(F77_UL, &F77_N, alpha, y, &F77_incY, x,
+ &F77_incX, A, &F77_lda);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_zher2", "Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ if(Y!=y)
+ free(y);
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zher2k.c b/ml/dlib/dlib/external/cblas/cblas_zher2k.c
new file mode 100644
index 000000000..96bcfe2a5
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zher2k.c
@@ -0,0 +1,95 @@
+/*
+ *
+ * cblas_zher2k.c
+ * This program is a C interface to zher2k.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const double beta,
+ void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+ double ALPHA[2];
+ const double *alp=(double *)alpha;
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_zher2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zher2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zher2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='C';
+ else
+ {
+ cblas_xerbla(3, "cblas_zher2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ F77_zher2k(F77_UL,F77_TR, &F77_N, &F77_K, ALPHA, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc);
+ } else cblas_xerbla(1, "cblas_zher2k", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zherk.c b/ml/dlib/dlib/external/cblas/cblas_zherk.c
new file mode 100644
index 000000000..bddef491b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zherk.c
@@ -0,0 +1,90 @@
+/*
+ *
+ * cblas_zherk.c
+ * This program is a C interface to zherk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const double alpha, const void *A, const int lda,
+ const double beta, void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zherk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_zherk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_zherk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='C';
+ else
+ {
+ cblas_xerbla(3, "cblas_zherk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_zherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda,
+ &beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zherk", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpmv.c b/ml/dlib/dlib/external/cblas/cblas_zhpmv.c
new file mode 100644
index 000000000..1812884fc
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhpmv.c
@@ -0,0 +1,146 @@
+/*
+ * cblas_zhpmv.c
+ * The program is a C interface of zhpmv
+ *
+ * Keita Teranishi 5/18/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zhpmv(const enum CBLAS_ORDER order,
+ const enum CBLAS_UPLO Uplo,const int N,
+ const void *alpha, const void *AP,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incx
+ #define F77_incY incY
+#endif
+ int n, i=0, incx=incX;
+ const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta;
+ double ALPHA[2],BETA[2];
+ int tincY, tincx;
+ double *x=(double *)X, *y=(double *)Y, *st=0, *tx;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpmv","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ F77_zhpmv(F77_UL, &F77_N, alpha, AP, X,
+ &F77_incX, beta, Y, &F77_incY);
+ }
+ else if (order == CblasRowMajor)
+ {
+ ALPHA[0]= *alp;
+ ALPHA[1]= -alp[1];
+ BETA[0]= *bet;
+ BETA[1]= -bet[1];
+
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+
+
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+
+ if(incY > 0)
+ tincY = incY;
+ else
+ tincY = -incY;
+ y++;
+
+ i = tincY << 1;
+ n = i * N ;
+ st = y + n;
+ do {
+ *y = -(*y);
+ y += i;
+ } while(y != st);
+ y -= n;
+ } else
+ x = (double *) X;
+
+
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpmv","Illegal Uplo setting, %d\n", Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_zhpmv(F77_UL, &F77_N, ALPHA,
+ AP, x, &F77_incX, BETA, Y, &F77_incY);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_zhpmv","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if ( order == CblasRowMajor )
+ {
+ if(X!=x)
+ free(x);
+ if (N > 0)
+ {
+ do
+ {
+ *y = -(*y);
+ y += i;
+ }
+ while (y != st);
+ }
+ }
+
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpr.c b/ml/dlib/dlib/external/cblas/cblas_zhpr.c
new file mode 100644
index 000000000..3ed2a8f61
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhpr.c
@@ -0,0 +1,102 @@
+/*
+ * cblas_zhpr.c
+ * The program is a C interface to zhpr.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N, const double alpha, const void *X,
+ const int incX, void *A)
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incx
+#endif
+ int n, i, tincx, incx=incX;
+ double *x=(double *)X, *xx=(double *)X, *tx, *st;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpr","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_zhpr(F77_UL, &F77_N, &alpha, X, &F77_incX, A);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpr","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+ tx = x;
+ if( incX > 0 ) {
+ i = incX << 1;
+ tincx = 2;
+ st= x+n;
+ } else {
+ i = incX *(-2);
+ tincx = -2;
+ st = x-2;
+ x +=(n-2);
+ }
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += tincx ;
+ xx += i;
+ }
+ while (x != st);
+ x=tx;
+ #ifdef F77_INT
+ F77_incX = 1;
+ #else
+ incx = 1;
+ #endif
+ }
+ else x = (double *) X;
+
+ F77_zhpr(F77_UL, &F77_N, &alpha, x, &F77_incX, A);
+
+ } else
+ {
+ cblas_xerbla(1, "cblas_zhpr","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpr2.c b/ml/dlib/dlib/external/cblas/cblas_zhpr2.c
new file mode 100644
index 000000000..0793a298a
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zhpr2.c
@@ -0,0 +1,137 @@
+/*
+ * cblas_zhpr2.c
+ * The program is a C interface to zhpr2.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const int N,const void *alpha, const void *X,
+ const int incX,const void *Y, const int incY, void *Ap)
+
+{
+ char UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_UL;
+#else
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incx
+ #define F77_incY incy
+#endif
+ int n, i, j, incx=incX, incy=incY;
+ double *x=(double *)X, *xx=(double *)X, *y=(double *)Y,
+ *yy=(double *)Y, *stx, *sty;
+
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasLower) UL = 'L';
+ else if (Uplo == CblasUpper) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpr2","Illegal Uplo setting, %d\n",Uplo );
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+
+ F77_zhpr2(F77_UL, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, Ap);
+
+ } else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_zhpr2","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ #endif
+ if (N > 0)
+ {
+ n = N << 1;
+ x = malloc(n*sizeof(double));
+ y = malloc(n*sizeof(double));
+ stx = x + n;
+ sty = y + n;
+ if( incX > 0 )
+ i = incX << 1;
+ else
+ i = incX *(-2);
+
+ if( incY > 0 )
+ j = incY << 1;
+ else
+ j = incY *(-2);
+ do
+ {
+ *x = *xx;
+ x[1] = -xx[1];
+ x += 2;
+ xx += i;
+ } while (x != stx);
+ do
+ {
+ *y = *yy;
+ y[1] = -yy[1];
+ y += 2;
+ yy += j;
+ }
+ while (y != sty);
+ x -= n;
+ y -= n;
+
+ #ifdef F77_INT
+ if(incX > 0 )
+ F77_incX = 1;
+ else
+ F77_incX = -1;
+
+ if(incY > 0 )
+ F77_incY = 1;
+ else
+ F77_incY = -1;
+
+ #else
+ if(incX > 0 )
+ incx = 1;
+ else
+ incx = -1;
+
+ if(incY > 0 )
+ incy = 1;
+ else
+ incy = -1;
+ #endif
+
+ } else
+ {
+ x = (double *) X;
+ y = (void *) Y;
+ }
+ F77_zhpr2(F77_UL, &F77_N, alpha, y, &F77_incY, x, &F77_incX, Ap);
+ }
+ else
+ {
+ cblas_xerbla(1, "cblas_zhpr2","Illegal Order setting, %d\n", order);
+ return;
+ }
+ if(X!=x)
+ free(x);
+ if(Y!=y)
+ free(y);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zscal.c b/ml/dlib/dlib/external/cblas/cblas_zscal.c
new file mode 100644
index 000000000..37b319f38
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zscal.c
@@ -0,0 +1,21 @@
+/*
+ * cblas_zscal.c
+ *
+ * The program is a C interface to zscal.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zscal( const int N, const void *alpha, void *X,
+ const int incX)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ F77_zscal( &F77_N, alpha, X, &F77_incX);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zswap.c b/ml/dlib/dlib/external/cblas/cblas_zswap.c
new file mode 100644
index 000000000..dfde2cbd0
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zswap.c
@@ -0,0 +1,22 @@
+/*
+ * cblas_zswap.c
+ *
+ * The program is a C interface to zswap.
+ *
+ * Written by Keita Teranishi. 2/11/1998
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zswap( const int N, void *X, const int incX, void *Y,
+ const int incY)
+{
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
+#else
+ #define F77_N N
+ #define F77_incX incX
+ #define F77_incY incY
+#endif
+ F77_zswap( &F77_N, X, &F77_incX, Y, &F77_incY);
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zsymm.c b/ml/dlib/dlib/external/cblas/cblas_zsymm.c
new file mode 100644
index 000000000..85d5e3f49
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zsymm.c
@@ -0,0 +1,91 @@
+/*
+ *
+ * cblas_zsymm.c
+ * This program is a C interface to zsymm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char SD, UL;
+#ifdef F77_CHAR
+ F77_CHAR F77_SD, F77_UL;
+#else
+ #define F77_SD &SD
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zsymm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsymm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_zsymm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_zsymm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsymm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_SD = C2F_CHAR(&SD);
+ #endif
+
+ F77_zsymm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zsymm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c b/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c
new file mode 100644
index 000000000..ffac33462
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c
@@ -0,0 +1,93 @@
+/*
+ *
+ * cblas_zsyr2k.c
+ * This program is a C interface to zsyr2k.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *B, const int ldb, const void *beta,
+ void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldb ldb
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zsyr2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyr2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ B, &F77_ldb, beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyr2k", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyr2k", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zsyr2k", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_zsyrk.c b/ml/dlib/dlib/external/cblas/cblas_zsyrk.c
new file mode 100644
index 000000000..45796074f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_zsyrk.c
@@ -0,0 +1,92 @@
+/*
+ *
+ * cblas_zsyrk.c
+ * This program is a C interface to zsyrk.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
+ const void *alpha, const void *A, const int lda,
+ const void *beta, void *C, const int ldc)
+{
+ char UL, TR;
+#ifdef F77_CHAR
+ F77_CHAR F77_TR, F77_UL;
+#else
+ #define F77_TR &TR
+ #define F77_UL &UL
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_K=K, F77_lda=lda;
+ F77_INT F77_ldc=ldc;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_ldc ldc
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_zsyrk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( Trans == CblasTrans) TR ='T';
+ else if ( Trans == CblasConjTrans ) TR='C';
+ else if ( Trans == CblasNoTrans ) TR='N';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyrk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zsyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ beta, C, &F77_ldc);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyrk", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if( Trans == CblasTrans) TR ='N';
+ else if ( Trans == CblasConjTrans ) TR='N';
+ else if ( Trans == CblasNoTrans ) TR='T';
+ else
+ {
+ cblas_xerbla(3, "cblas_zsyrk", "Illegal Trans setting, %d\n", Trans);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TR = C2F_CHAR(&TR);
+ #endif
+
+ F77_zsyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda,
+ beta, C, &F77_ldc);
+ }
+ else cblas_xerbla(1, "cblas_zsyrk", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztbmv.c b/ml/dlib/dlib/external/cblas/cblas_ztbmv.c
new file mode 100644
index 000000000..916774d27
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztbmv.c
@@ -0,0 +1,139 @@
+/*
+ * cblas_ztbmv.c
+ * The program is a C interface to ztbmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0, *x=(double *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztbmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ x++;
+ st = x + n;
+ do
+ {
+ *x = -(*x);
+ x+= i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztbmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztbmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztbsv.c b/ml/dlib/dlib/external/cblas/cblas_ztbsv.c
new file mode 100644
index 000000000..cc5d3f73f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztbsv.c
@@ -0,0 +1,143 @@
+/*
+ * cblas_ztbsv.c
+ * The program is a C interface to ztbsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const int K, const void *A, const int lda,
+ void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_K K
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0,*x=(double *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztbsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+
+ x++;
+
+ st=x+n;
+
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztbsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztbsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X,
+ &F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x+= i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztbsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztpmv.c b/ml/dlib/dlib/external/cblas/cblas_ztpmv.c
new file mode 100644
index 000000000..2e7949a25
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztpmv.c
@@ -0,0 +1,133 @@
+/*
+ * cblas_ztpmv.c
+ * The program is a C interface to ztpmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0,*x=(double *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztpmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ x++;
+ st = x + n;
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztpmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztpmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztpmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztpsv.c b/ml/dlib/dlib/external/cblas/cblas_ztpsv.c
new file mode 100644
index 000000000..c41d02016
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztpsv.c
@@ -0,0 +1,138 @@
+/*
+ * cblas_ztpsv.c
+ * The program is a C interface to ztpsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *Ap, void *X, const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0, *x=(double*)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztpsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+
+ x++;
+
+ st=x+n;
+
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztpsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztpsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX);
+
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztpsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrmm.c b/ml/dlib/dlib/external/cblas/cblas_ztrmm.c
new file mode 100644
index 000000000..4e76377d9
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztrmm.c
@@ -0,0 +1,126 @@
+/*
+ *
+ * cblas_ztrmm.c
+ * This program is a C interface to ztrmm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+ if( Side == CblasRight ) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrmm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+ if( Uplo == CblasUpper ) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrmm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans ) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrmm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ztrmm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+ if( Side == CblasRight ) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrmm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper ) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrmm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans ) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrmm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ztrmm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_ztrmm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrmv.c b/ml/dlib/dlib/external/cblas/cblas_ztrmv.c
new file mode 100644
index 000000000..65e760529
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztrmv.c
@@ -0,0 +1,137 @@
+/*
+ * cblas_ztrmv.c
+ * The program is a C interface to ztrmv.
+ *
+ * Keita Teranishi 5/20/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda,
+ void *X, const int incX)
+
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0,*x=(double *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrmv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if(incX > 0)
+ tincX = incX;
+ else
+ tincX = -incX;
+ i = tincX << 1;
+ n = i * N;
+ x++;
+ st = x + n;
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrmv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrmv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztrmv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrsm.c b/ml/dlib/dlib/external/cblas/cblas_ztrsm.c
new file mode 100644
index 000000000..7540147c8
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztrsm.c
@@ -0,0 +1,132 @@
+/*
+ *
+ * cblas_ztrsm.c
+ * This program is a C interface to ztrsm.
+ * Written by Keita Teranishi
+ * 4/8/1998
+ *
+ */
+
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
+ const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
+ const enum CBLAS_DIAG Diag, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ void *B, const int ldb)
+{
+ char UL, TA, SD, DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_SD &SD
+ #define F77_DI &DI
+#endif
+
+#ifdef F77_INT
+ F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb;
+#else
+ #define F77_M M
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_ldb ldb
+#endif
+
+
+ if( Order == CblasColMajor )
+ {
+
+ if( Side == CblasRight) SD='R';
+ else if ( Side == CblasLeft ) SD='L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='U';
+ else if ( Uplo == CblasLower ) UL='L';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ztrsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+ F77_ztrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A,
+ &F77_lda, B, &F77_ldb);
+ } else if (Order == CblasRowMajor)
+ {
+
+ if( Side == CblasRight) SD='L';
+ else if ( Side == CblasLeft ) SD='R';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrsm", "Illegal Side setting, %d\n", Side);
+ return;
+ }
+
+ if( Uplo == CblasUpper) UL='L';
+ else if ( Uplo == CblasLower ) UL='U';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrsm", "Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if( TransA == CblasTrans) TA ='T';
+ else if ( TransA == CblasConjTrans ) TA='C';
+ else if ( TransA == CblasNoTrans ) TA='N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrsm", "Illegal Trans setting, %d\n", TransA);
+ return;
+ }
+
+ if( Diag == CblasUnit ) DI='U';
+ else if ( Diag == CblasNonUnit ) DI='N';
+ else
+ {
+ cblas_xerbla(5, "cblas_ztrsm", "Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_SD = C2F_CHAR(&SD);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+
+
+ F77_ztrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A,
+ &F77_lda, B, &F77_ldb);
+ }
+ else cblas_xerbla(1, "cblas_ztrsm", "Illegal Order setting, %d\n", Order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrsv.c b/ml/dlib/dlib/external/cblas/cblas_ztrsv.c
new file mode 100644
index 000000000..07e2653b4
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cblas_ztrsv.c
@@ -0,0 +1,137 @@
+/*
+ * cblas_ztrsv.c
+ * The program is a C interface to ztrsv.
+ *
+ * Keita Teranishi 3/23/98
+ *
+ */
+#include "cblas.h"
+#include "cblas_f77.h"
+void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
+ const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
+ const int N, const void *A, const int lda, void *X,
+ const int incX)
+{
+ char TA;
+ char UL;
+ char DI;
+#ifdef F77_CHAR
+ F77_CHAR F77_TA, F77_UL, F77_DI;
+#else
+ #define F77_TA &TA
+ #define F77_UL &UL
+ #define F77_DI &DI
+#endif
+#ifdef F77_INT
+ F77_INT F77_N=N, F77_lda=lda, F77_incX=incX;
+#else
+ #define F77_N N
+ #define F77_lda lda
+ #define F77_incX incX
+#endif
+ int n, i=0, tincX;
+ double *st=0,*x=(double *)X;
+
+ if (order == CblasColMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'U';
+ else if (Uplo == CblasLower) UL = 'L';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+ if (TransA == CblasNoTrans) TA = 'N';
+ else if (TransA == CblasTrans) TA = 'T';
+ else if (TransA == CblasConjTrans) TA = 'C';
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ }
+ else if (order == CblasRowMajor)
+ {
+ if (Uplo == CblasUpper) UL = 'L';
+ else if (Uplo == CblasLower) UL = 'U';
+ else
+ {
+ cblas_xerbla(2, "cblas_ztrsv","Illegal Uplo setting, %d\n", Uplo);
+ return;
+ }
+
+ if (TransA == CblasNoTrans) TA = 'T';
+ else if (TransA == CblasTrans) TA = 'N';
+ else if (TransA == CblasConjTrans)
+ {
+ TA = 'N';
+ if ( N > 0)
+ {
+ if ( incX > 0 )
+ tincX = incX;
+ else
+ tincX = -incX;
+
+ n = N*2*(tincX);
+ x++;
+ st=x+n;
+ i = tincX << 1;
+ do
+ {
+ *x = -(*x);
+ x+=i;
+ }
+ while (x != st);
+ x -= n;
+ }
+ }
+ else
+ {
+ cblas_xerbla(3, "cblas_ztrsv","Illegal TransA setting, %d\n", TransA);
+ return;
+ }
+
+ if (Diag == CblasUnit) DI = 'U';
+ else if (Diag == CblasNonUnit) DI = 'N';
+ else
+ {
+ cblas_xerbla(4, "cblas_ztrsv","Illegal Diag setting, %d\n", Diag);
+ return;
+ }
+ #ifdef F77_CHAR
+ F77_UL = C2F_CHAR(&UL);
+ F77_TA = C2F_CHAR(&TA);
+ F77_DI = C2F_CHAR(&DI);
+ #endif
+ F77_ztrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X,
+ &F77_incX);
+ if (TransA == CblasConjTrans)
+ {
+ if (N > 0)
+ {
+ do
+ {
+ *x = -(*x);
+ x += i;
+ }
+ while (x != st);
+ }
+ }
+ }
+ else cblas_xerbla(1, "cblas_ztrsv", "Illegal Order setting, %d\n", order);
+ return;
+}
diff --git a/ml/dlib/dlib/external/cblas/cdotcsub.f b/ml/dlib/dlib/external/cblas/cdotcsub.f
new file mode 100644
index 000000000..f97d7159e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cdotcsub.f
@@ -0,0 +1,15 @@
+c cdotcsub.f
+c
+c The program is a fortran wrapper for cdotc.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine cdotcsub(n,x,incx,y,incy,dotc)
+c
+ external cdotc
+ complex cdotc,dotc
+ integer n,incx,incy
+ complex x(*),y(*)
+c
+ dotc=cdotc(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/cdotusub.f b/ml/dlib/dlib/external/cblas/cdotusub.f
new file mode 100644
index 000000000..5107c0402
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/cdotusub.f
@@ -0,0 +1,15 @@
+c cdotusub.f
+c
+c The program is a fortran wrapper for cdotu.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine cdotusub(n,x,incx,y,incy,dotu)
+c
+ external cdotu
+ complex cdotu,dotu
+ integer n,incx,incy
+ complex x(*),y(*)
+c
+ dotu=cdotu(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/dasumsub.f b/ml/dlib/dlib/external/cblas/dasumsub.f
new file mode 100644
index 000000000..3d64d17e6
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/dasumsub.f
@@ -0,0 +1,15 @@
+c dasumsun.f
+c
+c The program is a fortran wrapper for dasum..
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine dasumsub(n,x,incx,asum)
+c
+ external dasum
+ double precision dasum,asum
+ integer n,incx
+ double precision x(*)
+c
+ asum=dasum(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/ddotsub.f b/ml/dlib/dlib/external/cblas/ddotsub.f
new file mode 100644
index 000000000..205f3b46f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/ddotsub.f
@@ -0,0 +1,15 @@
+c ddotsub.f
+c
+c The program is a fortran wrapper for ddot.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine ddotsub(n,x,incx,y,incy,dot)
+c
+ external ddot
+ double precision ddot
+ integer n,incx,incy
+ double precision x(*),y(*),dot
+c
+ dot=ddot(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/dnrm2sub.f b/ml/dlib/dlib/external/cblas/dnrm2sub.f
new file mode 100644
index 000000000..88f17db8b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/dnrm2sub.f
@@ -0,0 +1,15 @@
+c dnrm2sub.f
+c
+c The program is a fortran wrapper for dnrm2.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine dnrm2sub(n,x,incx,nrm2)
+c
+ external dnrm2
+ double precision dnrm2,nrm2
+ integer n,incx
+ double precision x(*)
+c
+ nrm2=dnrm2(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/dsdotsub.f b/ml/dlib/dlib/external/cblas/dsdotsub.f
new file mode 100644
index 000000000..e7e872c9e
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/dsdotsub.f
@@ -0,0 +1,15 @@
+c dsdotsub.f
+c
+c The program is a fortran wrapper for dsdot.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine dsdotsub(n,x,incx,y,incy,dot)
+c
+ external dsdot
+ double precision dsdot,dot
+ integer n,incx,incy
+ real x(*),y(*)
+c
+ dot=dsdot(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/dzasumsub.f b/ml/dlib/dlib/external/cblas/dzasumsub.f
new file mode 100644
index 000000000..9aaf16387
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/dzasumsub.f
@@ -0,0 +1,15 @@
+c dzasumsub.f
+c
+c The program is a fortran wrapper for dzasum.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine dzasumsub(n,x,incx,asum)
+c
+ external dzasum
+ double precision dzasum,asum
+ integer n,incx
+ double complex x(*)
+c
+ asum=dzasum(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/dznrm2sub.f b/ml/dlib/dlib/external/cblas/dznrm2sub.f
new file mode 100644
index 000000000..45dc599f8
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/dznrm2sub.f
@@ -0,0 +1,15 @@
+c dznrm2sub.f
+c
+c The program is a fortran wrapper for dznrm2.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine dznrm2sub(n,x,incx,nrm2)
+c
+ external dznrm2
+ double precision dznrm2,nrm2
+ integer n,incx
+ double complex x(*)
+c
+ nrm2=dznrm2(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/icamaxsub.f b/ml/dlib/dlib/external/cblas/icamaxsub.f
new file mode 100644
index 000000000..3f47071eb
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/icamaxsub.f
@@ -0,0 +1,15 @@
+c icamaxsub.f
+c
+c The program is a fortran wrapper for icamax.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine icamaxsub(n,x,incx,iamax)
+c
+ external icamax
+ integer icamax,iamax
+ integer n,incx
+ complex x(*)
+c
+ iamax=icamax(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/idamaxsub.f b/ml/dlib/dlib/external/cblas/idamaxsub.f
new file mode 100644
index 000000000..3c1ee5c32
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/idamaxsub.f
@@ -0,0 +1,15 @@
+c icamaxsub.f
+c
+c The program is a fortran wrapper for idamax.
+c Witten by Keita Teranishi. 2/22/1998
+c
+ subroutine idamaxsub(n,x,incx,iamax)
+c
+ external idamax
+ integer idamax,iamax
+ integer n,incx
+ double precision x(*)
+c
+ iamax=idamax(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/isamaxsub.f b/ml/dlib/dlib/external/cblas/isamaxsub.f
new file mode 100644
index 000000000..0faf42fde
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/isamaxsub.f
@@ -0,0 +1,15 @@
+c isamaxsub.f
+c
+c The program is a fortran wrapper for isamax.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine isamaxsub(n,x,incx,iamax)
+c
+ external isamax
+ integer isamax,iamax
+ integer n,incx
+ real x(*)
+c
+ iamax=isamax(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/izamaxsub.f b/ml/dlib/dlib/external/cblas/izamaxsub.f
new file mode 100644
index 000000000..5b15855a7
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/izamaxsub.f
@@ -0,0 +1,15 @@
+c izamaxsub.f
+c
+c The program is a fortran wrapper for izamax.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine izamaxsub(n,x,incx,iamax)
+c
+ external izamax
+ integer izamax,iamax
+ integer n,incx
+ double complex x(*)
+c
+ iamax=izamax(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/sasumsub.f b/ml/dlib/dlib/external/cblas/sasumsub.f
new file mode 100644
index 000000000..955f11e8d
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/sasumsub.f
@@ -0,0 +1,15 @@
+c sasumsub.f
+c
+c The program is a fortran wrapper for sasum.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine sasumsub(n,x,incx,asum)
+c
+ external sasum
+ real sasum,asum
+ integer n,incx
+ real x(*)
+c
+ asum=sasum(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/scasumsub.f b/ml/dlib/dlib/external/cblas/scasumsub.f
new file mode 100644
index 000000000..077ace670
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/scasumsub.f
@@ -0,0 +1,15 @@
+c scasumsub.f
+c
+c The program is a fortran wrapper for scasum.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine scasumsub(n,x,incx,asum)
+c
+ external scasum
+ real scasum,asum
+ integer n,incx
+ complex x(*)
+c
+ asum=scasum(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/scnrm2sub.f b/ml/dlib/dlib/external/cblas/scnrm2sub.f
new file mode 100644
index 000000000..7242c9742
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/scnrm2sub.f
@@ -0,0 +1,15 @@
+c scnrm2sub.f
+c
+c The program is a fortran wrapper for scnrm2.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine scnrm2sub(n,x,incx,nrm2)
+c
+ external scnrm2
+ real scnrm2,nrm2
+ integer n,incx
+ complex x(*)
+c
+ nrm2=scnrm2(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/sdotsub.f b/ml/dlib/dlib/external/cblas/sdotsub.f
new file mode 100644
index 000000000..e1af3c97b
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/sdotsub.f
@@ -0,0 +1,15 @@
+c sdotsub.f
+c
+c The program is a fortran wrapper for sdot.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine sdotsub(n,x,incx,y,incy,dot)
+c
+ external sdot
+ real sdot
+ integer n,incx,incy
+ real x(*),y(*),dot
+c
+ dot=sdot(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/sdsdotsub.f b/ml/dlib/dlib/external/cblas/sdsdotsub.f
new file mode 100644
index 000000000..80008e9ce
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/sdsdotsub.f
@@ -0,0 +1,15 @@
+c sdsdotsub.f
+c
+c The program is a fortran wrapper for sdsdot.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine sdsdotsub(n,x,incx,y,incy,dot)
+c
+ external sdsdot
+ real sdsdot,dot
+ integer n,incx,incy
+ real x(*),y(*)
+c
+ dot=sdsdot(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/snrm2sub.f b/ml/dlib/dlib/external/cblas/snrm2sub.f
new file mode 100644
index 000000000..871a6e49f
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/snrm2sub.f
@@ -0,0 +1,15 @@
+c snrm2sub.f
+c
+c The program is a fortran wrapper for snrm2.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine snrm2sub(n,x,incx,nrm2)
+c
+ external snrm2
+ real snrm2,nrm2
+ integer n,incx
+ real x(*)
+c
+ nrm2=snrm2(n,x,incx)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/zdotcsub.f b/ml/dlib/dlib/external/cblas/zdotcsub.f
new file mode 100644
index 000000000..8d483c895
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/zdotcsub.f
@@ -0,0 +1,15 @@
+c zdotcsub.f
+c
+c The program is a fortran wrapper for zdotc.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine zdotcsub(n,x,incx,y,incy,dotc)
+c
+ external zdotc
+ double complex zdotc,dotc
+ integer n,incx,incy
+ double complex x(*),y(*)
+c
+ dotc=zdotc(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/cblas/zdotusub.f b/ml/dlib/dlib/external/cblas/zdotusub.f
new file mode 100644
index 000000000..23f32dec3
--- /dev/null
+++ b/ml/dlib/dlib/external/cblas/zdotusub.f
@@ -0,0 +1,15 @@
+c zdotusub.f
+c
+c The program is a fortran wrapper for zdotu.
+c Witten by Keita Teranishi. 2/11/1998
+c
+ subroutine zdotusub(n,x,incx,y,incy,dotu)
+c
+ external zdotu
+ double complex zdotu,dotu
+ integer n,incx,incy
+ double complex x(*),y(*)
+c
+ dotu=zdotu(n,x,incx,y,incy)
+ return
+ end
diff --git a/ml/dlib/dlib/external/libjpeg/README b/ml/dlib/dlib/external/libjpeg/README
new file mode 100644
index 000000000..86cc20669
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/README
@@ -0,0 +1,385 @@
+The Independent JPEG Group's JPEG software
+==========================================
+
+README for release 6b of 27-Mar-1998
+====================================
+
+This distribution contains the sixth public release of the Independent JPEG
+Group's free JPEG software. You are welcome to redistribute this software and
+to use it for any purpose, subject to the conditions under LEGAL ISSUES, below.
+
+Serious users of this software (particularly those incorporating it into
+larger programs) should contact IJG at jpeg-info@uunet.uu.net to be added to
+our electronic mailing list. Mailing list members are notified of updates
+and have a chance to participate in technical discussions, etc.
+
+This software is the work of Tom Lane, Philip Gladstone, Jim Boucher,
+Lee Crocker, Julian Minguillon, Luis Ortiz, George Phillips, Davide Rossi,
+Guido Vollbeding, Ge' Weijers, and other members of the Independent JPEG
+Group.
+
+IJG is not affiliated with the official ISO JPEG standards committee.
+
+
+DOCUMENTATION ROADMAP
+=====================
+
+This file contains the following sections:
+
+OVERVIEW General description of JPEG and the IJG software.
+LEGAL ISSUES Copyright, lack of warranty, terms of distribution.
+REFERENCES Where to learn more about JPEG.
+ARCHIVE LOCATIONS Where to find newer versions of this software.
+RELATED SOFTWARE Other stuff you should get.
+FILE FORMAT WARS Software *not* to get.
+TO DO Plans for future IJG releases.
+
+Other documentation files in the distribution are:
+
+User documentation:
+ install.doc How to configure and install the IJG software.
+ usage.doc Usage instructions for cjpeg, djpeg, jpegtran,
+ rdjpgcom, and wrjpgcom.
+ *.1 Unix-style man pages for programs (same info as usage.doc).
+ wizard.doc Advanced usage instructions for JPEG wizards only.
+ change.log Version-to-version change highlights.
+Programmer and internal documentation:
+ libjpeg.doc How to use the JPEG library in your own programs.
+ example.c Sample code for calling the JPEG library.
+ structure.doc Overview of the JPEG library's internal structure.
+ filelist.doc Road map of IJG files.
+ coderules.doc Coding style rules --- please read if you contribute code.
+
+Please read at least the files install.doc and usage.doc. Useful information
+can also be found in the JPEG FAQ (Frequently Asked Questions) article. See
+ARCHIVE LOCATIONS below to find out where to obtain the FAQ article.
+
+If you want to understand how the JPEG code works, we suggest reading one or
+more of the REFERENCES, then looking at the documentation files (in roughly
+the order listed) before diving into the code.
+
+
+OVERVIEW
+========
+
+This package contains C software to implement JPEG image compression and
+decompression. JPEG (pronounced "jay-peg") is a standardized compression
+method for full-color and gray-scale images. JPEG is intended for compressing
+"real-world" scenes; line drawings, cartoons and other non-realistic images
+are not its strong suit. JPEG is lossy, meaning that the output image is not
+exactly identical to the input image. Hence you must not use JPEG if you
+have to have identical output bits. However, on typical photographic images,
+very good compression levels can be obtained with no visible change, and
+remarkably high compression levels are possible if you can tolerate a
+low-quality image. For more details, see the references, or just experiment
+with various compression settings.
+
+This software implements JPEG baseline, extended-sequential, and progressive
+compression processes. Provision is made for supporting all variants of these
+processes, although some uncommon parameter settings aren't implemented yet.
+For legal reasons, we are not distributing code for the arithmetic-coding
+variants of JPEG; see LEGAL ISSUES. We have made no provision for supporting
+the hierarchical or lossless processes defined in the standard.
+
+We provide a set of library routines for reading and writing JPEG image files,
+plus two sample applications "cjpeg" and "djpeg", which use the library to
+perform conversion between JPEG and some other popular image file formats.
+The library is intended to be reused in other applications.
+
+In order to support file conversion and viewing software, we have included
+considerable functionality beyond the bare JPEG coding/decoding capability;
+for example, the color quantization modules are not strictly part of JPEG
+decoding, but they are essential for output to colormapped file formats or
+colormapped displays. These extra functions can be compiled out of the
+library if not required for a particular application. We have also included
+"jpegtran", a utility for lossless transcoding between different JPEG
+processes, and "rdjpgcom" and "wrjpgcom", two simple applications for
+inserting and extracting textual comments in JFIF files.
+
+The emphasis in designing this software has been on achieving portability and
+flexibility, while also making it fast enough to be useful. In particular,
+the software is not intended to be read as a tutorial on JPEG. (See the
+REFERENCES section for introductory material.) Rather, it is intended to
+be reliable, portable, industrial-strength code. We do not claim to have
+achieved that goal in every aspect of the software, but we strive for it.
+
+We welcome the use of this software as a component of commercial products.
+No royalty is required, but we do ask for an acknowledgement in product
+documentation, as described under LEGAL ISSUES.
+
+
+LEGAL ISSUES
+============
+
+In plain English:
+
+1. We don't promise that this software works. (But if you find any bugs,
+ please let us know!)
+2. You can use this software for whatever you want. You don't have to pay us.
+3. You may not pretend that you wrote this software. If you use it in a
+ program, you must acknowledge somewhere in your documentation that
+ you've used the IJG code.
+
+In legalese:
+
+The authors make NO WARRANTY or representation, either express or implied,
+with respect to this software, its quality, accuracy, merchantability, or
+fitness for a particular purpose. This software is provided "AS IS", and you,
+its user, assume the entire risk as to its quality and accuracy.
+
+This software is copyright (C) 1991-1998, Thomas G. Lane.
+All Rights Reserved except as specified below.
+
+Permission is hereby granted to use, copy, modify, and distribute this
+software (or portions thereof) for any purpose, without fee, subject to these
+conditions:
+(1) If any part of the source code for this software is distributed, then this
+README file must be included, with this copyright and no-warranty notice
+unaltered; and any additions, deletions, or changes to the original files
+must be clearly indicated in accompanying documentation.
+(2) If only executable code is distributed, then the accompanying
+documentation must state that "this software is based in part on the work of
+the Independent JPEG Group".
+(3) Permission for use of this software is granted only if the user accepts
+full responsibility for any undesirable consequences; the authors accept
+NO LIABILITY for damages of any kind.
+
+These conditions apply to any software derived from or based on the IJG code,
+not just to the unmodified library. If you use our work, you ought to
+acknowledge us.
+
+Permission is NOT granted for the use of any IJG author's name or company name
+in advertising or publicity relating to this software or products derived from
+it. This software may be referred to only as "the Independent JPEG Group's
+software".
+
+We specifically permit and encourage the use of this software as the basis of
+commercial products, provided that all warranty or liability claims are
+assumed by the product vendor.
+
+
+ansi2knr.c is included in this distribution by permission of L. Peter Deutsch,
+sole proprietor of its copyright holder, Aladdin Enterprises of Menlo Park, CA.
+ansi2knr.c is NOT covered by the above copyright and conditions, but instead
+by the usual distribution terms of the Free Software Foundation; principally,
+that you must include source code if you redistribute it. (See the file
+ansi2knr.c for full details.) However, since ansi2knr.c is not needed as part
+of any program generated from the IJG code, this does not limit you more than
+the foregoing paragraphs do.
+
+The Unix configuration script "configure" was produced with GNU Autoconf.
+It is copyright by the Free Software Foundation but is freely distributable.
+The same holds for its supporting scripts (config.guess, config.sub,
+ltconfig, ltmain.sh). Another support script, install-sh, is copyright
+by M.I.T. but is also freely distributable.
+
+It appears that the arithmetic coding option of the JPEG spec is covered by
+patents owned by IBM, AT&T, and Mitsubishi. Hence arithmetic coding cannot
+legally be used without obtaining one or more licenses. For this reason,
+support for arithmetic coding has been removed from the free JPEG software.
+(Since arithmetic coding provides only a marginal gain over the unpatented
+Huffman mode, it is unlikely that very many implementations will support it.)
+So far as we are aware, there are no patent restrictions on the remaining
+code.
+
+The IJG distribution formerly included code to read and write GIF files.
+To avoid entanglement with the Unisys LZW patent, GIF reading support has
+been removed altogether, and the GIF writer has been simplified to produce
+"uncompressed GIFs". This technique does not use the LZW algorithm; the
+resulting GIF files are larger than usual, but are readable by all standard
+GIF decoders.
+
+We are required to state that
+ "The Graphics Interchange Format(c) is the Copyright property of
+ CompuServe Incorporated. GIF(sm) is a Service Mark property of
+ CompuServe Incorporated."
+
+
+REFERENCES
+==========
+
+We highly recommend reading one or more of these references before trying to
+understand the innards of the JPEG software.
+
+The best short technical introduction to the JPEG compression algorithm is
+ Wallace, Gregory K. "The JPEG Still Picture Compression Standard",
+ Communications of the ACM, April 1991 (vol. 34 no. 4), pp. 30-44.
+(Adjacent articles in that issue discuss MPEG motion picture compression,
+applications of JPEG, and related topics.) If you don't have the CACM issue
+handy, a PostScript file containing a revised version of Wallace's article is
+available at ftp://ftp.uu.net/graphics/jpeg/wallace.ps.gz. The file (actually
+a preprint for an article that appeared in IEEE Trans. Consumer Electronics)
+omits the sample images that appeared in CACM, but it includes corrections
+and some added material. Note: the Wallace article is copyright ACM and IEEE,
+and it may not be used for commercial purposes.
+
+A somewhat less technical, more leisurely introduction to JPEG can be found in
+"The Data Compression Book" by Mark Nelson and Jean-loup Gailly, published by
+M&T Books (New York), 2nd ed. 1996, ISBN 1-55851-434-1. This book provides
+good explanations and example C code for a multitude of compression methods
+including JPEG. It is an excellent source if you are comfortable reading C
+code but don't know much about data compression in general. The book's JPEG
+sample code is far from industrial-strength, but when you are ready to look
+at a full implementation, you've got one here...
+
+The best full description of JPEG is the textbook "JPEG Still Image Data
+Compression Standard" by William B. Pennebaker and Joan L. Mitchell, published
+by Van Nostrand Reinhold, 1993, ISBN 0-442-01272-1. Price US$59.95, 638 pp.
+The book includes the complete text of the ISO JPEG standards (DIS 10918-1
+and draft DIS 10918-2). This is by far the most complete exposition of JPEG
+in existence, and we highly recommend it.
+
+The JPEG standard itself is not available electronically; you must order a
+paper copy through ISO or ITU. (Unless you feel a need to own a certified
+official copy, we recommend buying the Pennebaker and Mitchell book instead;
+it's much cheaper and includes a great deal of useful explanatory material.)
+In the USA, copies of the standard may be ordered from ANSI Sales at (212)
+642-4900, or from Global Engineering Documents at (800) 854-7179. (ANSI
+doesn't take credit card orders, but Global does.) It's not cheap: as of
+1992, ANSI was charging $95 for Part 1 and $47 for Part 2, plus 7%
+shipping/handling. The standard is divided into two parts, Part 1 being the
+actual specification, while Part 2 covers compliance testing methods. Part 1
+is titled "Digital Compression and Coding of Continuous-tone Still Images,
+Part 1: Requirements and guidelines" and has document numbers ISO/IEC IS
+10918-1, ITU-T T.81. Part 2 is titled "Digital Compression and Coding of
+Continuous-tone Still Images, Part 2: Compliance testing" and has document
+numbers ISO/IEC IS 10918-2, ITU-T T.83.
+
+Some extensions to the original JPEG standard are defined in JPEG Part 3,
+a newer ISO standard numbered ISO/IEC IS 10918-3 and ITU-T T.84. IJG
+currently does not support any Part 3 extensions.
+
+The JPEG standard does not specify all details of an interchangeable file
+format. For the omitted details we follow the "JFIF" conventions, revision
+1.02. A copy of the JFIF spec is available from:
+ Literature Department
+ C-Cube Microsystems, Inc.
+ 1778 McCarthy Blvd.
+ Milpitas, CA 95035
+ phone (408) 944-6300, fax (408) 944-6314
+A PostScript version of this document is available by FTP at
+ftp://ftp.uu.net/graphics/jpeg/jfif.ps.gz. There is also a plain text
+version at ftp://ftp.uu.net/graphics/jpeg/jfif.txt.gz, but it is missing
+the figures.
+
+The TIFF 6.0 file format specification can be obtained by FTP from
+ftp://ftp.sgi.com/graphics/tiff/TIFF6.ps.gz. The JPEG incorporation scheme
+found in the TIFF 6.0 spec of 3-June-92 has a number of serious problems.
+IJG does not recommend use of the TIFF 6.0 design (TIFF Compression tag 6).
+Instead, we recommend the JPEG design proposed by TIFF Technical Note #2
+(Compression tag 7). Copies of this Note can be obtained from ftp.sgi.com or
+from ftp://ftp.uu.net/graphics/jpeg/. It is expected that the next revision
+of the TIFF spec will replace the 6.0 JPEG design with the Note's design.
+Although IJG's own code does not support TIFF/JPEG, the free libtiff library
+uses our library to implement TIFF/JPEG per the Note. libtiff is available
+from ftp://ftp.sgi.com/graphics/tiff/.
+
+
+ARCHIVE LOCATIONS
+=================
+
+The "official" archive site for this software is ftp.uu.net (Internet
+address 192.48.96.9). The most recent released version can always be found
+there in directory graphics/jpeg. This particular version will be archived
+as ftp://ftp.uu.net/graphics/jpeg/jpegsrc.v6b.tar.gz. If you don't have
+direct Internet access, UUNET's archives are also available via UUCP; contact
+help@uunet.uu.net for information on retrieving files that way.
+
+Numerous Internet sites maintain copies of the UUNET files. However, only
+ftp.uu.net is guaranteed to have the latest official version.
+
+You can also obtain this software in DOS-compatible "zip" archive format from
+the SimTel archives (ftp://ftp.simtel.net/pub/simtelnet/msdos/graphics/), or
+on CompuServe in the Graphics Support forum (GO CIS:GRAPHSUP), library 12
+"JPEG Tools". Again, these versions may sometimes lag behind the ftp.uu.net
+release.
+
+The JPEG FAQ (Frequently Asked Questions) article is a useful source of
+general information about JPEG. It is updated constantly and therefore is
+not included in this distribution. The FAQ is posted every two weeks to
+Usenet newsgroups comp.graphics.misc, news.answers, and other groups.
+It is available on the World Wide Web at http://www.faqs.org/faqs/jpeg-faq/
+and other news.answers archive sites, including the official news.answers
+archive at rtfm.mit.edu: ftp://rtfm.mit.edu/pub/usenet/news.answers/jpeg-faq/.
+If you don't have Web or FTP access, send e-mail to mail-server@rtfm.mit.edu
+with body
+ send usenet/news.answers/jpeg-faq/part1
+ send usenet/news.answers/jpeg-faq/part2
+
+
+RELATED SOFTWARE
+================
+
+Numerous viewing and image manipulation programs now support JPEG. (Quite a
+few of them use this library to do so.) The JPEG FAQ described above lists
+some of the more popular free and shareware viewers, and tells where to
+obtain them on Internet.
+
+If you are on a Unix machine, we highly recommend Jef Poskanzer's free
+PBMPLUS software, which provides many useful operations on PPM-format image
+files. In particular, it can convert PPM images to and from a wide range of
+other formats, thus making cjpeg/djpeg considerably more useful. The latest
+version is distributed by the NetPBM group, and is available from numerous
+sites, notably ftp://wuarchive.wustl.edu/graphics/graphics/packages/NetPBM/.
+Unfortunately PBMPLUS/NETPBM is not nearly as portable as the IJG software is;
+you are likely to have difficulty making it work on any non-Unix machine.
+
+A different free JPEG implementation, written by the PVRG group at Stanford,
+is available from ftp://havefun.stanford.edu/pub/jpeg/. This program
+is designed for research and experimentation rather than production use;
+it is slower, harder to use, and less portable than the IJG code, but it
+is easier to read and modify. Also, the PVRG code supports lossless JPEG,
+which we do not. (On the other hand, it doesn't do progressive JPEG.)
+
+
+FILE FORMAT WARS
+================
+
+Some JPEG programs produce files that are not compatible with our library.
+The root of the problem is that the ISO JPEG committee failed to specify a
+concrete file format. Some vendors "filled in the blanks" on their own,
+creating proprietary formats that no one else could read. (For example, none
+of the early commercial JPEG implementations for the Macintosh were able to
+exchange compressed files.)
+
+The file format we have adopted is called JFIF (see REFERENCES). This format
+has been agreed to by a number of major commercial JPEG vendors, and it has
+become the de facto standard. JFIF is a minimal or "low end" representation.
+We recommend the use of TIFF/JPEG (TIFF revision 6.0 as modified by TIFF
+Technical Note #2) for "high end" applications that need to record a lot of
+additional data about an image. TIFF/JPEG is fairly new and not yet widely
+supported, unfortunately.
+
+The upcoming JPEG Part 3 standard defines a file format called SPIFF.
+SPIFF is interoperable with JFIF, in the sense that most JFIF decoders should
+be able to read the most common variant of SPIFF. SPIFF has some technical
+advantages over JFIF, but its major claim to fame is simply that it is an
+official standard rather than an informal one. At this point it is unclear
+whether SPIFF will supersede JFIF or whether JFIF will remain the de-facto
+standard. IJG intends to support SPIFF once the standard is frozen, but we
+have not decided whether it should become our default output format or not.
+(In any case, our decoder will remain capable of reading JFIF indefinitely.)
+
+Various proprietary file formats incorporating JPEG compression also exist.
+We have little or no sympathy for the existence of these formats. Indeed,
+one of the original reasons for developing this free software was to help
+force convergence on common, open format standards for JPEG files. Don't
+use a proprietary file format!
+
+
+TO DO
+=====
+
+The major thrust for v7 will probably be improvement of visual quality.
+The current method for scaling the quantization tables is known not to be
+very good at low Q values. We also intend to investigate block boundary
+smoothing, "poor man's variable quantization", and other means of improving
+quality-vs-file-size performance without sacrificing compatibility.
+
+In future versions, we are considering supporting some of the upcoming JPEG
+Part 3 extensions --- principally, variable quantization and the SPIFF file
+format.
+
+As always, speeding things up is of great interest.
+
+Please send bug reports, offers of help, etc. to jpeg-info@uunet.uu.net.
diff --git a/ml/dlib/dlib/external/libjpeg/jcapimin.cpp b/ml/dlib/dlib/external/libjpeg/jcapimin.cpp
new file mode 100644
index 000000000..bc0ceac5b
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcapimin.cpp
@@ -0,0 +1,280 @@
+/*
+ * jcapimin.c
+ *
+ * Copyright (C) 1994-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains application interface code for the compression half
+ * of the JPEG library. These are the "minimum" API routines that may be
+ * needed in either the normal full-compression case or the transcoding-only
+ * case.
+ *
+ * Most of the routines intended to be called directly by an application
+ * are in this file or in jcapistd.c. But also see jcparam.c for
+ * parameter-setup helper routines, jcomapi.c for routines shared by
+ * compression and decompression, and jctrans.c for the transcoding case.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Initialization of a JPEG compression object.
+ * The error manager must already be set up (in case memory manager fails).
+ */
+
+GLOBAL(void)
+jpeg_CreateCompress (j_compress_ptr cinfo, int version, size_t structsize)
+{
+ int i;
+
+ /* Guard against version mismatches between library and caller. */
+ cinfo->mem = NULL; /* so jpeg_destroy knows mem mgr not called */
+ if (version != JPEG_LIB_VERSION)
+ ERREXIT2(cinfo, JERR_BAD_LIB_VERSION, JPEG_LIB_VERSION, version);
+ if (structsize != SIZEOF(struct jpeg_compress_struct))
+ ERREXIT2(cinfo, JERR_BAD_STRUCT_SIZE,
+ (int) SIZEOF(struct jpeg_compress_struct), (int) structsize);
+
+ /* For debugging purposes, we zero the whole master structure.
+ * But the application has already set the err pointer, and may have set
+ * client_data, so we have to save and restore those fields.
+ * Note: if application hasn't set client_data, tools like Purify may
+ * complain here.
+ */
+ {
+ struct jpeg_error_mgr * err = cinfo->err;
+ void * client_data = cinfo->client_data; /* ignore Purify complaint here */
+ MEMZERO(cinfo, SIZEOF(struct jpeg_compress_struct));
+ cinfo->err = err;
+ cinfo->client_data = client_data;
+ }
+ cinfo->is_decompressor = FALSE;
+
+ /* Initialize a memory manager instance for this object */
+ jinit_memory_mgr((j_common_ptr) cinfo);
+
+ /* Zero out pointers to permanent structures. */
+ cinfo->progress = NULL;
+ cinfo->dest = NULL;
+
+ cinfo->comp_info = NULL;
+
+ for (i = 0; i < NUM_QUANT_TBLS; i++)
+ cinfo->quant_tbl_ptrs[i] = NULL;
+
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ cinfo->dc_huff_tbl_ptrs[i] = NULL;
+ cinfo->ac_huff_tbl_ptrs[i] = NULL;
+ }
+
+ cinfo->script_space = NULL;
+
+ cinfo->input_gamma = 1.0; /* in case application forgets */
+
+ /* OK, I'm ready */
+ cinfo->global_state = CSTATE_START;
+}
+
+
+/*
+ * Destruction of a JPEG compression object
+ */
+
+GLOBAL(void)
+jpeg_destroy_compress (j_compress_ptr cinfo)
+{
+ jpeg_destroy((j_common_ptr) cinfo); /* use common routine */
+}
+
+
+/*
+ * Abort processing of a JPEG compression operation,
+ * but don't destroy the object itself.
+ */
+
+GLOBAL(void)
+jpeg_abort_compress (j_compress_ptr cinfo)
+{
+ jpeg_abort((j_common_ptr) cinfo); /* use common routine */
+}
+
+
+/*
+ * Forcibly suppress or un-suppress all quantization and Huffman tables.
+ * Marks all currently defined tables as already written (if suppress)
+ * or not written (if !suppress). This will control whether they get emitted
+ * by a subsequent jpeg_start_compress call.
+ *
+ * This routine is exported for use by applications that want to produce
+ * abbreviated JPEG datastreams. It logically belongs in jcparam.c, but
+ * since it is called by jpeg_start_compress, we put it here --- otherwise
+ * jcparam.o would be linked whether the application used it or not.
+ */
+
+GLOBAL(void)
+jpeg_suppress_tables (j_compress_ptr cinfo, int suppress)
+{
+ int i;
+ JQUANT_TBL * qtbl;
+ JHUFF_TBL * htbl;
+
+ for (i = 0; i < NUM_QUANT_TBLS; i++) {
+ if ((qtbl = cinfo->quant_tbl_ptrs[i]) != NULL)
+ qtbl->sent_table = suppress;
+ }
+
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ if ((htbl = cinfo->dc_huff_tbl_ptrs[i]) != NULL)
+ htbl->sent_table = suppress;
+ if ((htbl = cinfo->ac_huff_tbl_ptrs[i]) != NULL)
+ htbl->sent_table = suppress;
+ }
+}
+
+
+/*
+ * Finish JPEG compression.
+ *
+ * If a multipass operating mode was selected, this may do a great deal of
+ * work including most of the actual output.
+ */
+
+GLOBAL(void)
+jpeg_finish_compress (j_compress_ptr cinfo)
+{
+ JDIMENSION iMCU_row;
+
+ if (cinfo->global_state == CSTATE_SCANNING ||
+ cinfo->global_state == CSTATE_RAW_OK) {
+ /* Terminate first pass */
+ if (cinfo->next_scanline < cinfo->image_height)
+ ERREXIT(cinfo, JERR_TOO_LITTLE_DATA);
+ (*cinfo->master->finish_pass) (cinfo);
+ } else if (cinfo->global_state != CSTATE_WRCOEFS)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ /* Perform any remaining passes */
+ while (! cinfo->master->is_last_pass) {
+ (*cinfo->master->prepare_for_pass) (cinfo);
+ for (iMCU_row = 0; iMCU_row < cinfo->total_iMCU_rows; iMCU_row++) {
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) iMCU_row;
+ cinfo->progress->pass_limit = (long) cinfo->total_iMCU_rows;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+ /* We bypass the main controller and invoke coef controller directly;
+ * all work is being done from the coefficient buffer.
+ */
+ if (! (*cinfo->coef->compress_data) (cinfo, (JSAMPIMAGE) NULL))
+ ERREXIT(cinfo, JERR_CANT_SUSPEND);
+ }
+ (*cinfo->master->finish_pass) (cinfo);
+ }
+ /* Write EOI, do final cleanup */
+ (*cinfo->marker->write_file_trailer) (cinfo);
+ (*cinfo->dest->term_destination) (cinfo);
+ /* We can use jpeg_abort to release memory and reset global_state */
+ jpeg_abort((j_common_ptr) cinfo);
+}
+
+
+/*
+ * Write a special marker.
+ * This is only recommended for writing COM or APPn markers.
+ * Must be called after jpeg_start_compress() and before
+ * first call to jpeg_write_scanlines() or jpeg_write_raw_data().
+ */
+
+GLOBAL(void)
+jpeg_write_marker (j_compress_ptr cinfo, int marker,
+ const JOCTET *dataptr, unsigned int datalen)
+{
+ JMETHOD(void, write_marker_byte, (j_compress_ptr info, int val));
+
+ if (cinfo->next_scanline != 0 ||
+ (cinfo->global_state != CSTATE_SCANNING &&
+ cinfo->global_state != CSTATE_RAW_OK &&
+ cinfo->global_state != CSTATE_WRCOEFS))
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ (*cinfo->marker->write_marker_header) (cinfo, marker, datalen);
+ write_marker_byte = cinfo->marker->write_marker_byte; /* copy for speed */
+ while (datalen--) {
+ (*write_marker_byte) (cinfo, *dataptr);
+ dataptr++;
+ }
+}
+
+/* Same, but piecemeal. */
+
+GLOBAL(void)
+jpeg_write_m_header (j_compress_ptr cinfo, int marker, unsigned int datalen)
+{
+ if (cinfo->next_scanline != 0 ||
+ (cinfo->global_state != CSTATE_SCANNING &&
+ cinfo->global_state != CSTATE_RAW_OK &&
+ cinfo->global_state != CSTATE_WRCOEFS))
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ (*cinfo->marker->write_marker_header) (cinfo, marker, datalen);
+}
+
+GLOBAL(void)
+jpeg_write_m_byte (j_compress_ptr cinfo, int val)
+{
+ (*cinfo->marker->write_marker_byte) (cinfo, val);
+}
+
+
+/*
+ * Alternate compression function: just write an abbreviated table file.
+ * Before calling this, all parameters and a data destination must be set up.
+ *
+ * To produce a pair of files containing abbreviated tables and abbreviated
+ * image data, one would proceed as follows:
+ *
+ * initialize JPEG object
+ * set JPEG parameters
+ * set destination to table file
+ * jpeg_write_tables(cinfo);
+ * set destination to image file
+ * jpeg_start_compress(cinfo, FALSE);
+ * write data...
+ * jpeg_finish_compress(cinfo);
+ *
+ * jpeg_write_tables has the side effect of marking all tables written
+ * (same as jpeg_suppress_tables(..., TRUE)). Thus a subsequent start_compress
+ * will not re-emit the tables unless it is passed write_all_tables=TRUE.
+ */
+
+GLOBAL(void)
+jpeg_write_tables (j_compress_ptr cinfo)
+{
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ /* (Re)initialize error mgr and destination modules */
+ (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo);
+ (*cinfo->dest->init_destination) (cinfo);
+ /* Initialize the marker writer ... bit of a crock to do it here. */
+ jinit_marker_writer(cinfo);
+ /* Write them tables! */
+ (*cinfo->marker->write_tables_only) (cinfo);
+ /* And clean up. */
+ (*cinfo->dest->term_destination) (cinfo);
+ /*
+ * In library releases up through v6a, we called jpeg_abort() here to free
+ * any working memory allocated by the destination manager and marker
+ * writer. Some applications had a problem with that: they allocated space
+ * of their own from the library memory manager, and didn't want it to go
+ * away during write_tables. So now we do nothing. This will cause a
+ * memory leak if an app calls write_tables repeatedly without doing a full
+ * compression cycle or otherwise resetting the JPEG object. However, that
+ * seems less bad than unexpectedly freeing memory in the normal case.
+ * An app that prefers the old behavior can call jpeg_abort for itself after
+ * each call to jpeg_write_tables().
+ */
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcapistd.cpp b/ml/dlib/dlib/external/libjpeg/jcapistd.cpp
new file mode 100644
index 000000000..3f4e08063
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcapistd.cpp
@@ -0,0 +1,161 @@
+/*
+ * jcapistd.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains application interface code for the compression half
+ * of the JPEG library. These are the "standard" API routines that are
+ * used in the normal full-compression case. They are not used by a
+ * transcoding-only application. Note that if an application links in
+ * jpeg_start_compress, it will end up linking in the entire compressor.
+ * We thus must separate this file from jcapimin.c to avoid linking the
+ * whole compression library into a transcoder.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Compression initialization.
+ * Before calling this, all parameters and a data destination must be set up.
+ *
+ * We require a write_all_tables parameter as a failsafe check when writing
+ * multiple datastreams from the same compression object. Since prior runs
+ * will have left all the tables marked sent_table=TRUE, a subsequent run
+ * would emit an abbreviated stream (no tables) by default. This may be what
+ * is wanted, but for safety's sake it should not be the default behavior:
+ * programmers should have to make a deliberate choice to emit abbreviated
+ * images. Therefore the documentation and examples should encourage people
+ * to pass write_all_tables=TRUE; then it will take active thought to do the
+ * wrong thing.
+ */
+
+GLOBAL(void)
+jpeg_start_compress (j_compress_ptr cinfo, int write_all_tables)
+{
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ if (write_all_tables)
+ jpeg_suppress_tables(cinfo, FALSE); /* mark all tables to be written */
+
+ /* (Re)initialize error mgr and destination modules */
+ (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo);
+ (*cinfo->dest->init_destination) (cinfo);
+ /* Perform master selection of active modules */
+ jinit_compress_master(cinfo);
+ /* Set up for the first pass */
+ (*cinfo->master->prepare_for_pass) (cinfo);
+ /* Ready for application to drive first pass through jpeg_write_scanlines
+ * or jpeg_write_raw_data.
+ */
+ cinfo->next_scanline = 0;
+ cinfo->global_state = (cinfo->raw_data_in ? CSTATE_RAW_OK : CSTATE_SCANNING);
+}
+
+
+/*
+ * Write some scanlines of data to the JPEG compressor.
+ *
+ * The return value will be the number of lines actually written.
+ * This should be less than the supplied num_lines only in case that
+ * the data destination module has requested suspension of the compressor,
+ * or if more than image_height scanlines are passed in.
+ *
+ * Note: we warn about excess calls to jpeg_write_scanlines() since
+ * this likely signals an application programmer error. However,
+ * excess scanlines passed in the last valid call are *silently* ignored,
+ * so that the application need not adjust num_lines for end-of-image
+ * when using a multiple-scanline buffer.
+ */
+
+GLOBAL(JDIMENSION)
+jpeg_write_scanlines (j_compress_ptr cinfo, JSAMPARRAY scanlines,
+ JDIMENSION num_lines)
+{
+ JDIMENSION row_ctr, rows_left;
+
+ if (cinfo->global_state != CSTATE_SCANNING)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ if (cinfo->next_scanline >= cinfo->image_height)
+ WARNMS(cinfo, JWRN_TOO_MUCH_DATA);
+
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) cinfo->next_scanline;
+ cinfo->progress->pass_limit = (long) cinfo->image_height;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+
+ /* Give master control module another chance if this is first call to
+ * jpeg_write_scanlines. This lets output of the frame/scan headers be
+ * delayed so that application can write COM, etc, markers between
+ * jpeg_start_compress and jpeg_write_scanlines.
+ */
+ if (cinfo->master->call_pass_startup)
+ (*cinfo->master->pass_startup) (cinfo);
+
+ /* Ignore any extra scanlines at bottom of image. */
+ rows_left = cinfo->image_height - cinfo->next_scanline;
+ if (num_lines > rows_left)
+ num_lines = rows_left;
+
+ row_ctr = 0;
+ (*cinfo->main->process_data) (cinfo, scanlines, &row_ctr, num_lines);
+ cinfo->next_scanline += row_ctr;
+ return row_ctr;
+}
+
+
+/*
+ * Alternate entry point to write raw data.
+ * Processes exactly one iMCU row per call, unless suspended.
+ */
+
+GLOBAL(JDIMENSION)
+jpeg_write_raw_data (j_compress_ptr cinfo, JSAMPIMAGE data,
+ JDIMENSION num_lines)
+{
+ JDIMENSION lines_per_iMCU_row;
+
+ if (cinfo->global_state != CSTATE_RAW_OK)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ if (cinfo->next_scanline >= cinfo->image_height) {
+ WARNMS(cinfo, JWRN_TOO_MUCH_DATA);
+ return 0;
+ }
+
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) cinfo->next_scanline;
+ cinfo->progress->pass_limit = (long) cinfo->image_height;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+
+ /* Give master control module another chance if this is first call to
+ * jpeg_write_raw_data. This lets output of the frame/scan headers be
+ * delayed so that application can write COM, etc, markers between
+ * jpeg_start_compress and jpeg_write_raw_data.
+ */
+ if (cinfo->master->call_pass_startup)
+ (*cinfo->master->pass_startup) (cinfo);
+
+ /* Verify that at least one iMCU row has been passed. */
+ lines_per_iMCU_row = cinfo->max_v_samp_factor * DCTSIZE;
+ if (num_lines < lines_per_iMCU_row)
+ ERREXIT(cinfo, JERR_BUFFER_SIZE);
+
+ /* Directly compress the row. */
+ if (! (*cinfo->coef->compress_data) (cinfo, data)) {
+ /* If compressor did not consume the whole row, suspend processing. */
+ return 0;
+ }
+
+ /* OK, we processed one iMCU row. */
+ cinfo->next_scanline += lines_per_iMCU_row;
+ return lines_per_iMCU_row;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jccoefct.cpp b/ml/dlib/dlib/external/libjpeg/jccoefct.cpp
new file mode 100644
index 000000000..175d7ecd9
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jccoefct.cpp
@@ -0,0 +1,449 @@
+/*
+ * jccoefct.c
+ *
+ * Copyright (C) 1994-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the coefficient buffer controller for compression.
+ * This controller is the top level of the JPEG compressor proper.
+ * The coefficient buffer lies between forward-DCT and entropy encoding steps.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* We use a full-image coefficient buffer when doing Huffman optimization,
+ * and also for writing multiple-scan JPEG files. In all cases, the DCT
+ * step is run during the first pass, and subsequent passes need only read
+ * the buffered coefficients.
+ */
+#ifdef ENTROPY_OPT_SUPPORTED
+#define FULL_COEF_BUFFER_SUPPORTED
+#else
+#ifdef C_MULTISCAN_FILES_SUPPORTED
+#define FULL_COEF_BUFFER_SUPPORTED
+#endif
+#endif
+
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_c_coef_controller pub; /* public fields */
+
+ JDIMENSION iMCU_row_num; /* iMCU row # within image */
+ JDIMENSION mcu_ctr; /* counts MCUs processed in current row */
+ int MCU_vert_offset; /* counts MCU rows within iMCU row */
+ int MCU_rows_per_iMCU_row; /* number of such rows needed */
+
+ /* For single-pass compression, it's sufficient to buffer just one MCU
+ * (although this may prove a bit slow in practice). We allocate a
+ * workspace of C_MAX_BLOCKS_IN_MCU coefficient blocks, and reuse it for each
+ * MCU constructed and sent. (On 80x86, the workspace is FAR even though
+ * it's not really very big; this is to keep the module interfaces unchanged
+ * when a large coefficient buffer is necessary.)
+ * In multi-pass modes, this array points to the current MCU's blocks
+ * within the virtual arrays.
+ */
+ JBLOCKROW MCU_buffer[C_MAX_BLOCKS_IN_MCU];
+
+ /* In multi-pass modes, we need a virtual block array for each component. */
+ jvirt_barray_ptr whole_image[MAX_COMPONENTS];
+} my_coef_controller;
+
+typedef my_coef_controller * my_coef_ptr;
+
+
+/* Forward declarations */
+METHODDEF(int) compress_data
+ JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf));
+#ifdef FULL_COEF_BUFFER_SUPPORTED
+METHODDEF(int) compress_first_pass
+ JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf));
+METHODDEF(int) compress_output
+ JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf));
+#endif
+
+
+LOCAL(void)
+start_iMCU_row (j_compress_ptr cinfo)
+/* Reset within-iMCU-row counters for a new row */
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+
+ /* In an interleaved scan, an MCU row is the same as an iMCU row.
+ * In a noninterleaved scan, an iMCU row has v_samp_factor MCU rows.
+ * But at the bottom of the image, process only what's left.
+ */
+ if (cinfo->comps_in_scan > 1) {
+ coef->MCU_rows_per_iMCU_row = 1;
+ } else {
+ if (coef->iMCU_row_num < (cinfo->total_iMCU_rows-1))
+ coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->v_samp_factor;
+ else
+ coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->last_row_height;
+ }
+
+ coef->mcu_ctr = 0;
+ coef->MCU_vert_offset = 0;
+}
+
+
+/*
+ * Initialize for a processing pass.
+ */
+
+METHODDEF(void)
+start_pass_coef (j_compress_ptr cinfo, J_BUF_MODE pass_mode)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+
+ coef->iMCU_row_num = 0;
+ start_iMCU_row(cinfo);
+
+ switch (pass_mode) {
+ case JBUF_PASS_THRU:
+ if (coef->whole_image[0] != NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ coef->pub.compress_data = compress_data;
+ break;
+#ifdef FULL_COEF_BUFFER_SUPPORTED
+ case JBUF_SAVE_AND_PASS:
+ if (coef->whole_image[0] == NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ coef->pub.compress_data = compress_first_pass;
+ break;
+ case JBUF_CRANK_DEST:
+ if (coef->whole_image[0] == NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ coef->pub.compress_data = compress_output;
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ break;
+ }
+}
+
+
+/*
+ * Process some data in the single-pass case.
+ * We process the equivalent of one fully interleaved MCU row ("iMCU" row)
+ * per call, ie, v_samp_factor block rows for each component in the image.
+ * Returns TRUE if the iMCU row is completed, FALSE if suspended.
+ *
+ * NB: input_buf contains a plane for each component in image,
+ * which we index according to the component's SOF position.
+ */
+
+METHODDEF(int)
+compress_data (j_compress_ptr cinfo, JSAMPIMAGE input_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION MCU_col_num; /* index of current MCU within row */
+ JDIMENSION last_MCU_col = cinfo->MCUs_per_row - 1;
+ JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1;
+ int blkn, bi, ci, yindex, yoffset, blockcnt;
+ JDIMENSION ypos, xpos;
+ jpeg_component_info *compptr;
+
+ /* Loop to write as much as one whole iMCU row */
+ for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row;
+ yoffset++) {
+ for (MCU_col_num = coef->mcu_ctr; MCU_col_num <= last_MCU_col;
+ MCU_col_num++) {
+ /* Determine where data comes from in input_buf and do the DCT thing.
+ * Each call on forward_DCT processes a horizontal row of DCT blocks
+ * as wide as an MCU; we rely on having allocated the MCU_buffer[] blocks
+ * sequentially. Dummy blocks at the right or bottom edge are filled in
+ * specially. The data in them does not matter for image reconstruction,
+ * so we fill them with values that will encode to the smallest amount of
+ * data, viz: all zeroes in the AC entries, DC entries equal to previous
+ * block's DC value. (Thanks to Thomas Kinsman for this idea.)
+ */
+ blkn = 0;
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ blockcnt = (MCU_col_num < last_MCU_col) ? compptr->MCU_width
+ : compptr->last_col_width;
+ xpos = MCU_col_num * compptr->MCU_sample_width;
+ ypos = yoffset * DCTSIZE; /* ypos == (yoffset+yindex) * DCTSIZE */
+ for (yindex = 0; yindex < compptr->MCU_height; yindex++) {
+ if (coef->iMCU_row_num < last_iMCU_row ||
+ yoffset+yindex < compptr->last_row_height) {
+ (*cinfo->fdct->forward_DCT) (cinfo, compptr,
+ input_buf[compptr->component_index],
+ coef->MCU_buffer[blkn],
+ ypos, xpos, (JDIMENSION) blockcnt);
+ if (blockcnt < compptr->MCU_width) {
+ /* Create some dummy blocks at the right edge of the image. */
+ jzero_far((void FAR *) coef->MCU_buffer[blkn + blockcnt],
+ (compptr->MCU_width - blockcnt) * SIZEOF(JBLOCK));
+ for (bi = blockcnt; bi < compptr->MCU_width; bi++) {
+ coef->MCU_buffer[blkn+bi][0][0] = coef->MCU_buffer[blkn+bi-1][0][0];
+ }
+ }
+ } else {
+ /* Create a row of dummy blocks at the bottom of the image. */
+ jzero_far((void FAR *) coef->MCU_buffer[blkn],
+ compptr->MCU_width * SIZEOF(JBLOCK));
+ for (bi = 0; bi < compptr->MCU_width; bi++) {
+ coef->MCU_buffer[blkn+bi][0][0] = coef->MCU_buffer[blkn-1][0][0];
+ }
+ }
+ blkn += compptr->MCU_width;
+ ypos += DCTSIZE;
+ }
+ }
+ /* Try to write the MCU. In event of a suspension failure, we will
+ * re-DCT the MCU on restart (a bit inefficient, could be fixed...)
+ */
+ if (! (*cinfo->entropy->encode_mcu) (cinfo, coef->MCU_buffer)) {
+ /* Suspension forced; update state counters and exit */
+ coef->MCU_vert_offset = yoffset;
+ coef->mcu_ctr = MCU_col_num;
+ return FALSE;
+ }
+ }
+ /* Completed an MCU row, but perhaps not an iMCU row */
+ coef->mcu_ctr = 0;
+ }
+ /* Completed the iMCU row, advance counters for next one */
+ coef->iMCU_row_num++;
+ start_iMCU_row(cinfo);
+ return TRUE;
+}
+
+
+#ifdef FULL_COEF_BUFFER_SUPPORTED
+
+/*
+ * Process some data in the first pass of a multi-pass case.
+ * We process the equivalent of one fully interleaved MCU row ("iMCU" row)
+ * per call, ie, v_samp_factor block rows for each component in the image.
+ * This amount of data is read from the source buffer, DCT'd and quantized,
+ * and saved into the virtual arrays. We also generate suitable dummy blocks
+ * as needed at the right and lower edges. (The dummy blocks are constructed
+ * in the virtual arrays, which have been padded appropriately.) This makes
+ * it possible for subsequent passes not to worry about real vs. dummy blocks.
+ *
+ * We must also emit the data to the entropy encoder. This is conveniently
+ * done by calling compress_output() after we've loaded the current strip
+ * of the virtual arrays.
+ *
+ * NB: input_buf contains a plane for each component in image. All
+ * components are DCT'd and loaded into the virtual arrays in this pass.
+ * However, it may be that only a subset of the components are emitted to
+ * the entropy encoder during this first pass; be careful about looking
+ * at the scan-dependent variables (MCU dimensions, etc).
+ */
+
+METHODDEF(int)
+compress_first_pass (j_compress_ptr cinfo, JSAMPIMAGE input_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1;
+ JDIMENSION blocks_across, MCUs_across, MCUindex;
+ int bi, ci, h_samp_factor, block_row, block_rows, ndummy;
+ JCOEF lastDC;
+ jpeg_component_info *compptr;
+ JBLOCKARRAY buffer;
+ JBLOCKROW thisblockrow, lastblockrow;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Align the virtual buffer for this component. */
+ buffer = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[ci],
+ coef->iMCU_row_num * compptr->v_samp_factor,
+ (JDIMENSION) compptr->v_samp_factor, TRUE);
+ /* Count non-dummy DCT block rows in this iMCU row. */
+ if (coef->iMCU_row_num < last_iMCU_row)
+ block_rows = compptr->v_samp_factor;
+ else {
+ /* NB: can't use last_row_height here, since may not be set! */
+ block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor);
+ if (block_rows == 0) block_rows = compptr->v_samp_factor;
+ }
+ blocks_across = compptr->width_in_blocks;
+ h_samp_factor = compptr->h_samp_factor;
+ /* Count number of dummy blocks to be added at the right margin. */
+ ndummy = (int) (blocks_across % h_samp_factor);
+ if (ndummy > 0)
+ ndummy = h_samp_factor - ndummy;
+ /* Perform DCT for all non-dummy blocks in this iMCU row. Each call
+ * on forward_DCT processes a complete horizontal row of DCT blocks.
+ */
+ for (block_row = 0; block_row < block_rows; block_row++) {
+ thisblockrow = buffer[block_row];
+ (*cinfo->fdct->forward_DCT) (cinfo, compptr,
+ input_buf[ci], thisblockrow,
+ (JDIMENSION) (block_row * DCTSIZE),
+ (JDIMENSION) 0, blocks_across);
+ if (ndummy > 0) {
+ /* Create dummy blocks at the right edge of the image. */
+ thisblockrow += blocks_across; /* => first dummy block */
+ jzero_far((void FAR *) thisblockrow, ndummy * SIZEOF(JBLOCK));
+ lastDC = thisblockrow[-1][0];
+ for (bi = 0; bi < ndummy; bi++) {
+ thisblockrow[bi][0] = lastDC;
+ }
+ }
+ }
+ /* If at end of image, create dummy block rows as needed.
+ * The tricky part here is that within each MCU, we want the DC values
+ * of the dummy blocks to match the last real block's DC value.
+ * This squeezes a few more bytes out of the resulting file...
+ */
+ if (coef->iMCU_row_num == last_iMCU_row) {
+ blocks_across += ndummy; /* include lower right corner */
+ MCUs_across = blocks_across / h_samp_factor;
+ for (block_row = block_rows; block_row < compptr->v_samp_factor;
+ block_row++) {
+ thisblockrow = buffer[block_row];
+ lastblockrow = buffer[block_row-1];
+ jzero_far((void FAR *) thisblockrow,
+ (size_t) (blocks_across * SIZEOF(JBLOCK)));
+ for (MCUindex = 0; MCUindex < MCUs_across; MCUindex++) {
+ lastDC = lastblockrow[h_samp_factor-1][0];
+ for (bi = 0; bi < h_samp_factor; bi++) {
+ thisblockrow[bi][0] = lastDC;
+ }
+ thisblockrow += h_samp_factor; /* advance to next MCU in row */
+ lastblockrow += h_samp_factor;
+ }
+ }
+ }
+ }
+ /* NB: compress_output will increment iMCU_row_num if successful.
+ * A suspension return will result in redoing all the work above next time.
+ */
+
+ /* Emit data to the entropy encoder, sharing code with subsequent passes */
+ return compress_output(cinfo, input_buf);
+}
+
+
+/*
+ * Process some data in subsequent passes of a multi-pass case.
+ * We process the equivalent of one fully interleaved MCU row ("iMCU" row)
+ * per call, ie, v_samp_factor block rows for each component in the scan.
+ * The data is obtained from the virtual arrays and fed to the entropy coder.
+ * Returns TRUE if the iMCU row is completed, FALSE if suspended.
+ *
+ * NB: input_buf is ignored; it is likely to be a NULL pointer.
+ */
+
+METHODDEF(int)
+compress_output (j_compress_ptr cinfo, JSAMPIMAGE )//input_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION MCU_col_num; /* index of current MCU within row */
+ int blkn, ci, xindex, yindex, yoffset;
+ JDIMENSION start_col;
+ JBLOCKARRAY buffer[MAX_COMPS_IN_SCAN];
+ JBLOCKROW buffer_ptr;
+ jpeg_component_info *compptr;
+
+ /* Align the virtual buffers for the components used in this scan.
+ * NB: during first pass, this is safe only because the buffers will
+ * already be aligned properly, so jmemmgr.c won't need to do any I/O.
+ */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ buffer[ci] = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[compptr->component_index],
+ coef->iMCU_row_num * compptr->v_samp_factor,
+ (JDIMENSION) compptr->v_samp_factor, FALSE);
+ }
+
+ /* Loop to process one whole iMCU row */
+ for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row;
+ yoffset++) {
+ for (MCU_col_num = coef->mcu_ctr; MCU_col_num < cinfo->MCUs_per_row;
+ MCU_col_num++) {
+ /* Construct list of pointers to DCT blocks belonging to this MCU */
+ blkn = 0; /* index of current DCT block within MCU */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ start_col = MCU_col_num * compptr->MCU_width;
+ for (yindex = 0; yindex < compptr->MCU_height; yindex++) {
+ buffer_ptr = buffer[ci][yindex+yoffset] + start_col;
+ for (xindex = 0; xindex < compptr->MCU_width; xindex++) {
+ coef->MCU_buffer[blkn++] = buffer_ptr++;
+ }
+ }
+ }
+ /* Try to write the MCU. */
+ if (! (*cinfo->entropy->encode_mcu) (cinfo, coef->MCU_buffer)) {
+ /* Suspension forced; update state counters and exit */
+ coef->MCU_vert_offset = yoffset;
+ coef->mcu_ctr = MCU_col_num;
+ return FALSE;
+ }
+ }
+ /* Completed an MCU row, but perhaps not an iMCU row */
+ coef->mcu_ctr = 0;
+ }
+ /* Completed the iMCU row, advance counters for next one */
+ coef->iMCU_row_num++;
+ start_iMCU_row(cinfo);
+ return TRUE;
+}
+
+#endif /* FULL_COEF_BUFFER_SUPPORTED */
+
+
+/*
+ * Initialize coefficient buffer controller.
+ */
+
+GLOBAL(void)
+jinit_c_coef_controller (j_compress_ptr cinfo, int need_full_buffer)
+{
+ my_coef_ptr coef;
+
+ coef = (my_coef_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_coef_controller));
+ cinfo->coef = (struct jpeg_c_coef_controller *) coef;
+ coef->pub.start_pass = start_pass_coef;
+
+ /* Create the coefficient buffer. */
+ if (need_full_buffer) {
+#ifdef FULL_COEF_BUFFER_SUPPORTED
+ /* Allocate a full-image virtual array for each component, */
+ /* padded to a multiple of samp_factor DCT blocks in each direction. */
+ int ci;
+ jpeg_component_info *compptr;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ coef->whole_image[ci] = (*cinfo->mem->request_virt_barray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE,
+ (JDIMENSION) jround_up((long) compptr->width_in_blocks,
+ (long) compptr->h_samp_factor),
+ (JDIMENSION) jround_up((long) compptr->height_in_blocks,
+ (long) compptr->v_samp_factor),
+ (JDIMENSION) compptr->v_samp_factor);
+ }
+#else
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+#endif
+ } else {
+ /* We only need a single-MCU buffer. */
+ JBLOCKROW buffer;
+ int i;
+
+ buffer = (JBLOCKROW)
+ (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ C_MAX_BLOCKS_IN_MCU * SIZEOF(JBLOCK));
+ for (i = 0; i < C_MAX_BLOCKS_IN_MCU; i++) {
+ coef->MCU_buffer[i] = buffer + i;
+ }
+ coef->whole_image[0] = NULL; /* flag for no virtual arrays */
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jccolor.cpp b/ml/dlib/dlib/external/libjpeg/jccolor.cpp
new file mode 100644
index 000000000..c5cfeded5
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jccolor.cpp
@@ -0,0 +1,459 @@
+/*
+ * jccolor.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains input colorspace conversion routines.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_color_converter pub; /* public fields */
+
+ /* Private state for RGB->YCC conversion */
+ long * rgb_ycc_tab; /* => table for RGB to YCbCr conversion */
+} my_color_converter;
+
+typedef my_color_converter * my_cconvert_ptr;
+
+
+/**************** RGB -> YCbCr conversion: most common case **************/
+
+/*
+ * YCbCr is defined per CCIR 601-1, except that Cb and Cr are
+ * normalized to the range 0..MAXJSAMPLE rather than -0.5 .. 0.5.
+ * The conversion equations to be implemented are therefore
+ * Y = 0.29900 * R + 0.58700 * G + 0.11400 * B
+ * Cb = -0.16874 * R - 0.33126 * G + 0.50000 * B + CENTERJSAMPLE
+ * Cr = 0.50000 * R - 0.41869 * G - 0.08131 * B + CENTERJSAMPLE
+ * (These numbers are derived from TIFF 6.0 section 21, dated 3-June-92.)
+ * Note: older versions of the IJG code used a zero offset of MAXJSAMPLE/2,
+ * rather than CENTERJSAMPLE, for Cb and Cr. This gave equal positive and
+ * negative swings for Cb/Cr, but meant that grayscale values (Cb=Cr=0)
+ * were not represented exactly. Now we sacrifice exact representation of
+ * maximum red and maximum blue in order to get exact grayscales.
+ *
+ * To avoid floating-point arithmetic, we represent the fractional constants
+ * as integers scaled up by 2^16 (about 4 digits precision); we have to divide
+ * the products by 2^16, with appropriate rounding, to get the correct answer.
+ *
+ * For even more speed, we avoid doing any multiplications in the inner loop
+ * by precalculating the constants times R,G,B for all possible values.
+ * For 8-bit JSAMPLEs this is very reasonable (only 256 entries per table);
+ * for 12-bit samples it is still acceptable. It's not very reasonable for
+ * 16-bit samples, but if you want lossless storage you shouldn't be changing
+ * colorspace anyway.
+ * The CENTERJSAMPLE offsets and the rounding fudge-factor of 0.5 are included
+ * in the tables to save adding them separately in the inner loop.
+ */
+
+#define SCALEBITS 16 /* speediest right-shift on some machines */
+#define CBCR_OFFSET ((long) CENTERJSAMPLE << SCALEBITS)
+#define ONE_HALF ((long) 1 << (SCALEBITS-1))
+#define FIX(x) ((long) ((x) * (1L<<SCALEBITS) + 0.5))
+
+/* We allocate one big table and divide it up into eight parts, instead of
+ * doing eight alloc_small requests. This lets us use a single table base
+ * address, which can be held in a register in the inner loops on many
+ * machines (more than can hold all eight addresses, anyway).
+ */
+
+#define R_Y_OFF 0 /* offset to R => Y section */
+#define G_Y_OFF (1*(MAXJSAMPLE+1)) /* offset to G => Y section */
+#define B_Y_OFF (2*(MAXJSAMPLE+1)) /* etc. */
+#define R_CB_OFF (3*(MAXJSAMPLE+1))
+#define G_CB_OFF (4*(MAXJSAMPLE+1))
+#define B_CB_OFF (5*(MAXJSAMPLE+1))
+#define R_CR_OFF B_CB_OFF /* B=>Cb, R=>Cr are the same */
+#define G_CR_OFF (6*(MAXJSAMPLE+1))
+#define B_CR_OFF (7*(MAXJSAMPLE+1))
+#define TABLE_SIZE (8*(MAXJSAMPLE+1))
+
+
+/*
+ * Initialize for RGB->YCC colorspace conversion.
+ */
+
+METHODDEF(void)
+rgb_ycc_start (j_compress_ptr cinfo)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ long * rgb_ycc_tab;
+ long i;
+
+ /* Allocate and fill in the conversion tables. */
+ cconvert->rgb_ycc_tab = rgb_ycc_tab = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (TABLE_SIZE * SIZEOF(long)));
+
+ for (i = 0; i <= MAXJSAMPLE; i++) {
+ rgb_ycc_tab[i+R_Y_OFF] = FIX(0.29900) * i;
+ rgb_ycc_tab[i+G_Y_OFF] = FIX(0.58700) * i;
+ rgb_ycc_tab[i+B_Y_OFF] = FIX(0.11400) * i + ONE_HALF;
+ rgb_ycc_tab[i+R_CB_OFF] = (-FIX(0.16874)) * i;
+ rgb_ycc_tab[i+G_CB_OFF] = (-FIX(0.33126)) * i;
+ /* We use a rounding fudge-factor of 0.5-epsilon for Cb and Cr.
+ * This ensures that the maximum output will round to MAXJSAMPLE
+ * not MAXJSAMPLE+1, and thus that we don't have to range-limit.
+ */
+ rgb_ycc_tab[i+B_CB_OFF] = FIX(0.50000) * i + CBCR_OFFSET + ONE_HALF-1;
+/* B=>Cb and R=>Cr tables are the same
+ rgb_ycc_tab[i+R_CR_OFF] = FIX(0.50000) * i + CBCR_OFFSET + ONE_HALF-1;
+*/
+ rgb_ycc_tab[i+G_CR_OFF] = (-FIX(0.41869)) * i;
+ rgb_ycc_tab[i+B_CR_OFF] = (-FIX(0.08131)) * i;
+ }
+}
+
+
+/*
+ * Convert some rows of samples to the JPEG colorspace.
+ *
+ * Note that we change from the application's interleaved-pixel format
+ * to our internal noninterleaved, one-plane-per-component format.
+ * The input buffer is therefore three times as wide as the output buffer.
+ *
+ * A starting row offset is provided only for the output buffer. The caller
+ * can easily adjust the passed input_buf value to accommodate any row
+ * offset required on that side.
+ */
+
+METHODDEF(void)
+rgb_ycc_convert (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int r, g, b;
+ long * ctab = cconvert->rgb_ycc_tab;
+ JSAMPROW inptr;
+ JSAMPROW outptr0, outptr1, outptr2;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->image_width;
+
+ while (--num_rows >= 0) {
+ inptr = *input_buf++;
+ outptr0 = output_buf[0][output_row];
+ outptr1 = output_buf[1][output_row];
+ outptr2 = output_buf[2][output_row];
+ output_row++;
+ for (col = 0; col < num_cols; col++) {
+ r = GETJSAMPLE(inptr[RGB_RED]);
+ g = GETJSAMPLE(inptr[RGB_GREEN]);
+ b = GETJSAMPLE(inptr[RGB_BLUE]);
+ inptr += RGB_PIXELSIZE;
+ /* If the inputs are 0..MAXJSAMPLE, the outputs of these equations
+ * must be too; we do not need an explicit range-limiting operation.
+ * Hence the value being shifted is never negative, and we don't
+ * need the general RIGHT_SHIFT macro.
+ */
+ /* Y */
+ outptr0[col] = (JSAMPLE)
+ ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF])
+ >> SCALEBITS);
+ /* Cb */
+ outptr1[col] = (JSAMPLE)
+ ((ctab[r+R_CB_OFF] + ctab[g+G_CB_OFF] + ctab[b+B_CB_OFF])
+ >> SCALEBITS);
+ /* Cr */
+ outptr2[col] = (JSAMPLE)
+ ((ctab[r+R_CR_OFF] + ctab[g+G_CR_OFF] + ctab[b+B_CR_OFF])
+ >> SCALEBITS);
+ }
+ }
+}
+
+
+/**************** Cases other than RGB -> YCbCr **************/
+
+
+/*
+ * Convert some rows of samples to the JPEG colorspace.
+ * This version handles RGB->grayscale conversion, which is the same
+ * as the RGB->Y portion of RGB->YCbCr.
+ * We assume rgb_ycc_start has been called (we only use the Y tables).
+ */
+
+METHODDEF(void)
+rgb_gray_convert (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int r, g, b;
+ long * ctab = cconvert->rgb_ycc_tab;
+ JSAMPROW inptr;
+ JSAMPROW outptr;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->image_width;
+
+ while (--num_rows >= 0) {
+ inptr = *input_buf++;
+ outptr = output_buf[0][output_row];
+ output_row++;
+ for (col = 0; col < num_cols; col++) {
+ r = GETJSAMPLE(inptr[RGB_RED]);
+ g = GETJSAMPLE(inptr[RGB_GREEN]);
+ b = GETJSAMPLE(inptr[RGB_BLUE]);
+ inptr += RGB_PIXELSIZE;
+ /* Y */
+ outptr[col] = (JSAMPLE)
+ ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF])
+ >> SCALEBITS);
+ }
+ }
+}
+
+
+/*
+ * Convert some rows of samples to the JPEG colorspace.
+ * This version handles Adobe-style CMYK->YCCK conversion,
+ * where we convert R=1-C, G=1-M, and B=1-Y to YCbCr using the same
+ * conversion as above, while passing K (black) unchanged.
+ * We assume rgb_ycc_start has been called.
+ */
+
+METHODDEF(void)
+cmyk_ycck_convert (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int r, g, b;
+ long * ctab = cconvert->rgb_ycc_tab;
+ JSAMPROW inptr;
+ JSAMPROW outptr0, outptr1, outptr2, outptr3;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->image_width;
+
+ while (--num_rows >= 0) {
+ inptr = *input_buf++;
+ outptr0 = output_buf[0][output_row];
+ outptr1 = output_buf[1][output_row];
+ outptr2 = output_buf[2][output_row];
+ outptr3 = output_buf[3][output_row];
+ output_row++;
+ for (col = 0; col < num_cols; col++) {
+ r = MAXJSAMPLE - GETJSAMPLE(inptr[0]);
+ g = MAXJSAMPLE - GETJSAMPLE(inptr[1]);
+ b = MAXJSAMPLE - GETJSAMPLE(inptr[2]);
+ /* K passes through as-is */
+ outptr3[col] = inptr[3]; /* don't need GETJSAMPLE here */
+ inptr += 4;
+ /* If the inputs are 0..MAXJSAMPLE, the outputs of these equations
+ * must be too; we do not need an explicit range-limiting operation.
+ * Hence the value being shifted is never negative, and we don't
+ * need the general RIGHT_SHIFT macro.
+ */
+ /* Y */
+ outptr0[col] = (JSAMPLE)
+ ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF])
+ >> SCALEBITS);
+ /* Cb */
+ outptr1[col] = (JSAMPLE)
+ ((ctab[r+R_CB_OFF] + ctab[g+G_CB_OFF] + ctab[b+B_CB_OFF])
+ >> SCALEBITS);
+ /* Cr */
+ outptr2[col] = (JSAMPLE)
+ ((ctab[r+R_CR_OFF] + ctab[g+G_CR_OFF] + ctab[b+B_CR_OFF])
+ >> SCALEBITS);
+ }
+ }
+}
+
+
+/*
+ * Convert some rows of samples to the JPEG colorspace.
+ * This version handles grayscale output with no conversion.
+ * The source can be either plain grayscale or YCbCr (since Y == gray).
+ */
+
+METHODDEF(void)
+grayscale_convert (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows)
+{
+ JSAMPROW inptr;
+ JSAMPROW outptr;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->image_width;
+ int instride = cinfo->input_components;
+
+ while (--num_rows >= 0) {
+ inptr = *input_buf++;
+ outptr = output_buf[0][output_row];
+ output_row++;
+ for (col = 0; col < num_cols; col++) {
+ outptr[col] = inptr[0]; /* don't need GETJSAMPLE() here */
+ inptr += instride;
+ }
+ }
+}
+
+
+/*
+ * Convert some rows of samples to the JPEG colorspace.
+ * This version handles multi-component colorspaces without conversion.
+ * We assume input_components == num_components.
+ */
+
+METHODDEF(void)
+null_convert (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows)
+{
+ JSAMPROW inptr;
+ JSAMPROW outptr;
+ JDIMENSION col;
+ int ci;
+ int nc = cinfo->num_components;
+ JDIMENSION num_cols = cinfo->image_width;
+
+ while (--num_rows >= 0) {
+ /* It seems fastest to make a separate pass for each component. */
+ for (ci = 0; ci < nc; ci++) {
+ inptr = *input_buf;
+ outptr = output_buf[ci][output_row];
+ for (col = 0; col < num_cols; col++) {
+ outptr[col] = inptr[ci]; /* don't need GETJSAMPLE() here */
+ inptr += nc;
+ }
+ }
+ input_buf++;
+ output_row++;
+ }
+}
+
+
+/*
+ * Empty method for start_pass.
+ */
+
+METHODDEF(void)
+null_method (j_compress_ptr )//cinfo)
+{
+ /* no work needed */
+}
+
+
+/*
+ * Module initialization routine for input colorspace conversion.
+ */
+
+GLOBAL(void)
+jinit_color_converter (j_compress_ptr cinfo)
+{
+ my_cconvert_ptr cconvert;
+
+ cconvert = (my_cconvert_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_color_converter));
+ cinfo->cconvert = (struct jpeg_color_converter *) cconvert;
+ /* set start_pass to null method until we find out differently */
+ cconvert->pub.start_pass = null_method;
+
+ /* Make sure input_components agrees with in_color_space */
+ switch (cinfo->in_color_space) {
+ case JCS_GRAYSCALE:
+ if (cinfo->input_components != 1)
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ break;
+
+ case JCS_RGB:
+#if RGB_PIXELSIZE != 3
+ if (cinfo->input_components != RGB_PIXELSIZE)
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ break;
+#endif /* else share code with YCbCr */
+
+ case JCS_YCbCr:
+ if (cinfo->input_components != 3)
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ break;
+
+ case JCS_CMYK:
+ case JCS_YCCK:
+ if (cinfo->input_components != 4)
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ break;
+
+ default: /* JCS_UNKNOWN can be anything */
+ if (cinfo->input_components < 1)
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ break;
+ }
+
+ /* Check num_components, set conversion method based on requested space */
+ switch (cinfo->jpeg_color_space) {
+ case JCS_GRAYSCALE:
+ if (cinfo->num_components != 1)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ if (cinfo->in_color_space == JCS_GRAYSCALE)
+ cconvert->pub.color_convert = grayscale_convert;
+ else if (cinfo->in_color_space == JCS_RGB) {
+ cconvert->pub.start_pass = rgb_ycc_start;
+ cconvert->pub.color_convert = rgb_gray_convert;
+ } else if (cinfo->in_color_space == JCS_YCbCr)
+ cconvert->pub.color_convert = grayscale_convert;
+ else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_RGB:
+ if (cinfo->num_components != 3)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ if (cinfo->in_color_space == JCS_RGB && RGB_PIXELSIZE == 3)
+ cconvert->pub.color_convert = null_convert;
+ else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_YCbCr:
+ if (cinfo->num_components != 3)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ if (cinfo->in_color_space == JCS_RGB) {
+ cconvert->pub.start_pass = rgb_ycc_start;
+ cconvert->pub.color_convert = rgb_ycc_convert;
+ } else if (cinfo->in_color_space == JCS_YCbCr)
+ cconvert->pub.color_convert = null_convert;
+ else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_CMYK:
+ if (cinfo->num_components != 4)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ if (cinfo->in_color_space == JCS_CMYK)
+ cconvert->pub.color_convert = null_convert;
+ else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_YCCK:
+ if (cinfo->num_components != 4)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ if (cinfo->in_color_space == JCS_CMYK) {
+ cconvert->pub.start_pass = rgb_ycc_start;
+ cconvert->pub.color_convert = cmyk_ycck_convert;
+ } else if (cinfo->in_color_space == JCS_YCCK)
+ cconvert->pub.color_convert = null_convert;
+ else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ default: /* allow null conversion of JCS_UNKNOWN */
+ if (cinfo->jpeg_color_space != cinfo->in_color_space ||
+ cinfo->num_components != cinfo->input_components)
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ cconvert->pub.color_convert = null_convert;
+ break;
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp b/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp
new file mode 100644
index 000000000..cbfc1a857
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp
@@ -0,0 +1,387 @@
+/*
+ * jcdctmgr.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the forward-DCT management logic.
+ * This code selects a particular DCT implementation to be used,
+ * and it performs related housekeeping chores including coefficient
+ * quantization.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+
+/* Private subobject for this module */
+
+typedef struct {
+ struct jpeg_forward_dct pub; /* public fields */
+
+ /* Pointer to the DCT routine actually in use */
+ forward_DCT_method_ptr do_dct;
+
+ /* The actual post-DCT divisors --- not identical to the quant table
+ * entries, because of scaling (especially for an unnormalized DCT).
+ * Each table is given in normal array order.
+ */
+ DCTELEM * divisors[NUM_QUANT_TBLS];
+
+#ifdef DCT_FLOAT_SUPPORTED
+ /* Same as above for the floating-point case. */
+ float_DCT_method_ptr do_float_dct;
+ FAST_FLOAT * float_divisors[NUM_QUANT_TBLS];
+#endif
+} my_fdct_controller;
+
+typedef my_fdct_controller * my_fdct_ptr;
+
+
+/*
+ * Initialize for a processing pass.
+ * Verify that all referenced Q-tables are present, and set up
+ * the divisor table for each one.
+ * In the current implementation, DCT of all components is done during
+ * the first pass, even if only some components will be output in the
+ * first scan. Hence all components should be examined here.
+ */
+
+METHODDEF(void)
+start_pass_fdctmgr (j_compress_ptr cinfo)
+{
+ my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct;
+ int ci, qtblno, i;
+ jpeg_component_info *compptr;
+ JQUANT_TBL * qtbl;
+ DCTELEM * dtbl;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ qtblno = compptr->quant_tbl_no;
+ /* Make sure specified quantization table is present */
+ if (qtblno < 0 || qtblno >= NUM_QUANT_TBLS ||
+ cinfo->quant_tbl_ptrs[qtblno] == NULL)
+ ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, qtblno);
+ qtbl = cinfo->quant_tbl_ptrs[qtblno];
+ /* Compute divisors for this quant table */
+ /* We may do this more than once for same table, but it's not a big deal */
+ switch (cinfo->dct_method) {
+#ifdef DCT_ISLOW_SUPPORTED
+ case JDCT_ISLOW:
+ /* For LL&M IDCT method, divisors are equal to raw quantization
+ * coefficients multiplied by 8 (to counteract scaling).
+ */
+ if (fdct->divisors[qtblno] == NULL) {
+ fdct->divisors[qtblno] = (DCTELEM *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ DCTSIZE2 * SIZEOF(DCTELEM));
+ }
+ dtbl = fdct->divisors[qtblno];
+ for (i = 0; i < DCTSIZE2; i++) {
+ dtbl[i] = ((DCTELEM) qtbl->quantval[i]) << 3;
+ }
+ break;
+#endif
+#ifdef DCT_IFAST_SUPPORTED
+ case JDCT_IFAST:
+ {
+ /* For AA&N IDCT method, divisors are equal to quantization
+ * coefficients scaled by scalefactor[row]*scalefactor[col], where
+ * scalefactor[0] = 1
+ * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7
+ * We apply a further scale factor of 8.
+ */
+#define CONST_BITS 14
+ static const short aanscales[DCTSIZE2] = {
+ /* precomputed values scaled up by 14 bits */
+ 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520,
+ 22725, 31521, 29692, 26722, 22725, 17855, 12299, 6270,
+ 21407, 29692, 27969, 25172, 21407, 16819, 11585, 5906,
+ 19266, 26722, 25172, 22654, 19266, 15137, 10426, 5315,
+ 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520,
+ 12873, 17855, 16819, 15137, 12873, 10114, 6967, 3552,
+ 8867, 12299, 11585, 10426, 8867, 6967, 4799, 2446,
+ 4520, 6270, 5906, 5315, 4520, 3552, 2446, 1247
+ };
+ SHIFT_TEMPS
+
+ if (fdct->divisors[qtblno] == NULL) {
+ fdct->divisors[qtblno] = (DCTELEM *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ DCTSIZE2 * SIZEOF(DCTELEM));
+ }
+ dtbl = fdct->divisors[qtblno];
+ for (i = 0; i < DCTSIZE2; i++) {
+ dtbl[i] = (DCTELEM)
+ DESCALE(MULTIPLY16V16((long) qtbl->quantval[i],
+ (long) aanscales[i]),
+ CONST_BITS-3);
+ }
+ }
+ break;
+#endif
+#ifdef DCT_FLOAT_SUPPORTED
+ case JDCT_FLOAT:
+ {
+ /* For float AA&N IDCT method, divisors are equal to quantization
+ * coefficients scaled by scalefactor[row]*scalefactor[col], where
+ * scalefactor[0] = 1
+ * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7
+ * We apply a further scale factor of 8.
+ * What's actually stored is 1/divisor so that the inner loop can
+ * use a multiplication rather than a division.
+ */
+ FAST_FLOAT * fdtbl;
+ int row, col;
+ static const double aanscalefactor[DCTSIZE] = {
+ 1.0, 1.387039845, 1.306562965, 1.175875602,
+ 1.0, 0.785694958, 0.541196100, 0.275899379
+ };
+
+ if (fdct->float_divisors[qtblno] == NULL) {
+ fdct->float_divisors[qtblno] = (FAST_FLOAT *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ DCTSIZE2 * SIZEOF(FAST_FLOAT));
+ }
+ fdtbl = fdct->float_divisors[qtblno];
+ i = 0;
+ for (row = 0; row < DCTSIZE; row++) {
+ for (col = 0; col < DCTSIZE; col++) {
+ fdtbl[i] = (FAST_FLOAT)
+ (1.0 / (((double) qtbl->quantval[i] *
+ aanscalefactor[row] * aanscalefactor[col] * 8.0)));
+ i++;
+ }
+ }
+ }
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ break;
+ }
+ }
+}
+
+
+/*
+ * Perform forward DCT on one or more blocks of a component.
+ *
+ * The input samples are taken from the sample_data[] array starting at
+ * position start_row/start_col, and moving to the right for any additional
+ * blocks. The quantized coefficients are returned in coef_blocks[].
+ */
+
+METHODDEF(void)
+forward_DCT (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY sample_data, JBLOCKROW coef_blocks,
+ JDIMENSION start_row, JDIMENSION start_col,
+ JDIMENSION num_blocks)
+/* This version is used for integer DCT implementations. */
+{
+ /* This routine is heavily used, so it's worth coding it tightly. */
+ my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct;
+ forward_DCT_method_ptr do_dct = fdct->do_dct;
+ DCTELEM * divisors = fdct->divisors[compptr->quant_tbl_no];
+ DCTELEM workspace[DCTSIZE2]; /* work area for FDCT subroutine */
+ JDIMENSION bi;
+
+ sample_data += start_row; /* fold in the vertical offset once */
+
+ for (bi = 0; bi < num_blocks; bi++, start_col += DCTSIZE) {
+ /* Load data into workspace, applying unsigned->signed conversion */
+ { DCTELEM *workspaceptr;
+ JSAMPROW elemptr;
+ int elemr;
+
+ workspaceptr = workspace;
+ for (elemr = 0; elemr < DCTSIZE; elemr++) {
+ elemptr = sample_data[elemr] + start_col;
+#if DCTSIZE == 8 /* unroll the inner loop */
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+#else
+ { int elemc;
+ for (elemc = DCTSIZE; elemc > 0; elemc--) {
+ *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE;
+ }
+ }
+#endif
+ }
+ }
+
+ /* Perform the DCT */
+ (*do_dct) (workspace);
+
+ /* Quantize/descale the coefficients, and store into coef_blocks[] */
+ { DCTELEM temp, qval;
+ int i;
+ JCOEFPTR output_ptr = coef_blocks[bi];
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ qval = divisors[i];
+ temp = workspace[i];
+ /* Divide the coefficient value by qval, ensuring proper rounding.
+ * Since C does not specify the direction of rounding for negative
+ * quotients, we have to force the dividend positive for portability.
+ *
+ * In most files, at least half of the output values will be zero
+ * (at default quantization settings, more like three-quarters...)
+ * so we should ensure that this case is fast. On many machines,
+ * a comparison is enough cheaper than a divide to make a special test
+ * a win. Since both inputs will be nonnegative, we need only test
+ * for a < b to discover whether a/b is 0.
+ * If your machine's division is fast enough, define FAST_DIVIDE.
+ */
+#ifdef FAST_DIVIDE
+#define DIVIDE_BY(a,b) a /= b
+#else
+#define DIVIDE_BY(a,b) if (a >= b) a /= b; else a = 0
+#endif
+ if (temp < 0) {
+ temp = -temp;
+ temp += qval>>1; /* for rounding */
+ DIVIDE_BY(temp, qval);
+ temp = -temp;
+ } else {
+ temp += qval>>1; /* for rounding */
+ DIVIDE_BY(temp, qval);
+ }
+ output_ptr[i] = (JCOEF) temp;
+ }
+ }
+ }
+}
+
+
+#ifdef DCT_FLOAT_SUPPORTED
+
+METHODDEF(void)
+forward_DCT_float (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY sample_data, JBLOCKROW coef_blocks,
+ JDIMENSION start_row, JDIMENSION start_col,
+ JDIMENSION num_blocks)
+/* This version is used for floating-point DCT implementations. */
+{
+ /* This routine is heavily used, so it's worth coding it tightly. */
+ my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct;
+ float_DCT_method_ptr do_dct = fdct->do_float_dct;
+ FAST_FLOAT * divisors = fdct->float_divisors[compptr->quant_tbl_no];
+ FAST_FLOAT workspace[DCTSIZE2]; /* work area for FDCT subroutine */
+ JDIMENSION bi;
+
+ sample_data += start_row; /* fold in the vertical offset once */
+
+ for (bi = 0; bi < num_blocks; bi++, start_col += DCTSIZE) {
+ /* Load data into workspace, applying unsigned->signed conversion */
+ { FAST_FLOAT *workspaceptr;
+ JSAMPROW elemptr;
+ int elemr;
+
+ workspaceptr = workspace;
+ for (elemr = 0; elemr < DCTSIZE; elemr++) {
+ elemptr = sample_data[elemr] + start_col;
+#if DCTSIZE == 8 /* unroll the inner loop */
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+#else
+ { int elemc;
+ for (elemc = DCTSIZE; elemc > 0; elemc--) {
+ *workspaceptr++ = (FAST_FLOAT)
+ (GETJSAMPLE(*elemptr++) - CENTERJSAMPLE);
+ }
+ }
+#endif
+ }
+ }
+
+ /* Perform the DCT */
+ (*do_dct) (workspace);
+
+ /* Quantize/descale the coefficients, and store into coef_blocks[] */
+ { FAST_FLOAT temp;
+ int i;
+ JCOEFPTR output_ptr = coef_blocks[bi];
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ /* Apply the quantization and scaling factor */
+ temp = workspace[i] * divisors[i];
+ /* Round to nearest integer.
+ * Since C does not specify the direction of rounding for negative
+ * quotients, we have to force the dividend positive for portability.
+ * The maximum coefficient size is +-16K (for 12-bit data), so this
+ * code should work for either 16-bit or 32-bit ints.
+ */
+ output_ptr[i] = (JCOEF) ((int) (temp + (FAST_FLOAT) 16384.5) - 16384);
+ }
+ }
+ }
+}
+
+#endif /* DCT_FLOAT_SUPPORTED */
+
+
+/*
+ * Initialize FDCT manager.
+ */
+
+GLOBAL(void)
+jinit_forward_dct (j_compress_ptr cinfo)
+{
+ my_fdct_ptr fdct;
+ int i;
+
+ fdct = (my_fdct_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_fdct_controller));
+ cinfo->fdct = (struct jpeg_forward_dct *) fdct;
+ fdct->pub.start_pass = start_pass_fdctmgr;
+
+ switch (cinfo->dct_method) {
+#ifdef DCT_ISLOW_SUPPORTED
+ case JDCT_ISLOW:
+ fdct->pub.forward_DCT = forward_DCT;
+ fdct->do_dct = jpeg_fdct_islow;
+ break;
+#endif
+#ifdef DCT_IFAST_SUPPORTED
+ case JDCT_IFAST:
+ fdct->pub.forward_DCT = forward_DCT;
+ fdct->do_dct = jpeg_fdct_ifast;
+ break;
+#endif
+#ifdef DCT_FLOAT_SUPPORTED
+ case JDCT_FLOAT:
+ fdct->pub.forward_DCT = forward_DCT_float;
+ fdct->do_float_dct = jpeg_fdct_float;
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ break;
+ }
+
+ /* Mark divisor tables unallocated */
+ for (i = 0; i < NUM_QUANT_TBLS; i++) {
+ fdct->divisors[i] = NULL;
+#ifdef DCT_FLOAT_SUPPORTED
+ fdct->float_divisors[i] = NULL;
+#endif
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jchuff.cpp b/ml/dlib/dlib/external/libjpeg/jchuff.cpp
new file mode 100644
index 000000000..d543319a6
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jchuff.cpp
@@ -0,0 +1,909 @@
+/*
+ * jchuff.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains Huffman entropy encoding routines.
+ *
+ * Much of the complexity here has to do with supporting output suspension.
+ * If the data destination module demands suspension, we want to be able to
+ * back up to the start of the current MCU. To do this, we copy state
+ * variables into local working storage, and update them back to the
+ * permanent JPEG objects only upon successful completion of an MCU.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jchuff.h" /* Declarations shared with jcphuff.c */
+
+
+/* Expanded entropy encoder object for Huffman encoding.
+ *
+ * The savable_state subrecord contains fields that change within an MCU,
+ * but must not be updated permanently until we complete the MCU.
+ */
+
+typedef struct {
+ long put_buffer; /* current bit-accumulation buffer */
+ int put_bits; /* # of bits now in it */
+ int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */
+} savable_state;
+
+/* This macro is to work around compilers with missing or broken
+ * structure assignment. You'll need to fix this code if you have
+ * such a compiler and you change MAX_COMPS_IN_SCAN.
+ */
+
+#ifndef NO_STRUCT_ASSIGN
+#define ASSIGN_STATE(dest,src) ((dest) = (src))
+#else
+#if MAX_COMPS_IN_SCAN == 4
+#define ASSIGN_STATE(dest,src) \
+ ((dest).put_buffer = (src).put_buffer, \
+ (dest).put_bits = (src).put_bits, \
+ (dest).last_dc_val[0] = (src).last_dc_val[0], \
+ (dest).last_dc_val[1] = (src).last_dc_val[1], \
+ (dest).last_dc_val[2] = (src).last_dc_val[2], \
+ (dest).last_dc_val[3] = (src).last_dc_val[3])
+#endif
+#endif
+
+
+typedef struct {
+ struct jpeg_entropy_encoder pub; /* public fields */
+
+ savable_state saved; /* Bit buffer & DC state at start of MCU */
+
+ /* These fields are NOT loaded into local working state. */
+ unsigned int restarts_to_go; /* MCUs left in this restart interval */
+ int next_restart_num; /* next restart number to write (0-7) */
+
+ /* Pointers to derived tables (these workspaces have image lifespan) */
+ c_derived_tbl * dc_derived_tbls[NUM_HUFF_TBLS];
+ c_derived_tbl * ac_derived_tbls[NUM_HUFF_TBLS];
+
+#ifdef ENTROPY_OPT_SUPPORTED /* Statistics tables for optimization */
+ long * dc_count_ptrs[NUM_HUFF_TBLS];
+ long * ac_count_ptrs[NUM_HUFF_TBLS];
+#endif
+} huff_entropy_encoder;
+
+typedef huff_entropy_encoder * huff_entropy_ptr;
+
+/* Working state while writing an MCU.
+ * This struct contains all the fields that are needed by subroutines.
+ */
+
+typedef struct {
+ JOCTET * next_output_byte; /* => next byte to write in buffer */
+ size_t free_in_buffer; /* # of byte spaces remaining in buffer */
+ savable_state cur; /* Current bit buffer & DC state */
+ j_compress_ptr cinfo; /* dump_buffer needs access to this */
+} working_state;
+
+
+/* Forward declarations */
+METHODDEF(int) encode_mcu_huff JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(void) finish_pass_huff JPP((j_compress_ptr cinfo));
+#ifdef ENTROPY_OPT_SUPPORTED
+METHODDEF(int) encode_mcu_gather JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(void) finish_pass_gather JPP((j_compress_ptr cinfo));
+#endif
+
+
+/*
+ * Initialize for a Huffman-compressed scan.
+ * If gather_statistics is TRUE, we do not output anything during the scan,
+ * just count the Huffman symbols used and generate Huffman code tables.
+ */
+
+METHODDEF(void)
+start_pass_huff (j_compress_ptr cinfo, int gather_statistics)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int ci, dctbl, actbl;
+ jpeg_component_info * compptr;
+
+ if (gather_statistics) {
+#ifdef ENTROPY_OPT_SUPPORTED
+ entropy->pub.encode_mcu = encode_mcu_gather;
+ entropy->pub.finish_pass = finish_pass_gather;
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ entropy->pub.encode_mcu = encode_mcu_huff;
+ entropy->pub.finish_pass = finish_pass_huff;
+ }
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ dctbl = compptr->dc_tbl_no;
+ actbl = compptr->ac_tbl_no;
+ if (gather_statistics) {
+#ifdef ENTROPY_OPT_SUPPORTED
+ /* Check for invalid table indexes */
+ /* (make_c_derived_tbl does this in the other path) */
+ if (dctbl < 0 || dctbl >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, dctbl);
+ if (actbl < 0 || actbl >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, actbl);
+ /* Allocate and zero the statistics tables */
+ /* Note that jpeg_gen_optimal_table expects 257 entries in each table! */
+ if (entropy->dc_count_ptrs[dctbl] == NULL)
+ entropy->dc_count_ptrs[dctbl] = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ 257 * SIZEOF(long));
+ MEMZERO(entropy->dc_count_ptrs[dctbl], 257 * SIZEOF(long));
+ if (entropy->ac_count_ptrs[actbl] == NULL)
+ entropy->ac_count_ptrs[actbl] = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ 257 * SIZEOF(long));
+ MEMZERO(entropy->ac_count_ptrs[actbl], 257 * SIZEOF(long));
+#endif
+ } else {
+ /* Compute derived values for Huffman tables */
+ /* We may do this more than once for a table, but it's not expensive */
+ jpeg_make_c_derived_tbl(cinfo, TRUE, dctbl,
+ & entropy->dc_derived_tbls[dctbl]);
+ jpeg_make_c_derived_tbl(cinfo, FALSE, actbl,
+ & entropy->ac_derived_tbls[actbl]);
+ }
+ /* Initialize DC predictions to 0 */
+ entropy->saved.last_dc_val[ci] = 0;
+ }
+
+ /* Initialize bit buffer to empty */
+ entropy->saved.put_buffer = 0;
+ entropy->saved.put_bits = 0;
+
+ /* Initialize restart stuff */
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num = 0;
+}
+
+
+/*
+ * Compute the derived values for a Huffman table.
+ * This routine also performs some validation checks on the table.
+ *
+ * Note this is also used by jcphuff.c.
+ */
+
+GLOBAL(void)
+jpeg_make_c_derived_tbl (j_compress_ptr cinfo, int isDC, int tblno,
+ c_derived_tbl ** pdtbl)
+{
+ JHUFF_TBL *htbl;
+ c_derived_tbl *dtbl;
+ int p, i, l, lastp, si, maxsymbol;
+ char huffsize[257];
+ unsigned int huffcode[257];
+ unsigned int code;
+
+ /* Note that huffsize[] and huffcode[] are filled in code-length order,
+ * paralleling the order of the symbols themselves in htbl->huffval[].
+ */
+
+ /* Find the input Huffman table */
+ if (tblno < 0 || tblno >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno);
+ htbl =
+ isDC ? cinfo->dc_huff_tbl_ptrs[tblno] : cinfo->ac_huff_tbl_ptrs[tblno];
+ if (htbl == NULL)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno);
+
+ /* Allocate a workspace if we haven't already done so. */
+ if (*pdtbl == NULL)
+ *pdtbl = (c_derived_tbl *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(c_derived_tbl));
+ dtbl = *pdtbl;
+
+ /* Figure C.1: make table of Huffman code length for each symbol */
+
+ p = 0;
+ for (l = 1; l <= 16; l++) {
+ i = (int) htbl->bits[l];
+ if (i < 0 || p + i > 256) /* protect against table overrun */
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ while (i--)
+ huffsize[p++] = (char) l;
+ }
+ huffsize[p] = 0;
+ lastp = p;
+
+ /* Figure C.2: generate the codes themselves */
+ /* We also validate that the counts represent a legal Huffman code tree. */
+
+ code = 0;
+ si = huffsize[0];
+ p = 0;
+ while (huffsize[p]) {
+ while (((int) huffsize[p]) == si) {
+ huffcode[p++] = code;
+ code++;
+ }
+ /* code is now 1 more than the last code used for codelength si; but
+ * it must still fit in si bits, since no code is allowed to be all ones.
+ */
+ if (((long) code) >= (((long) 1) << si))
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ code <<= 1;
+ si++;
+ }
+
+ /* Figure C.3: generate encoding tables */
+ /* These are code and size indexed by symbol value */
+
+ /* Set all codeless symbols to have code length 0;
+ * this lets us detect duplicate VAL entries here, and later
+ * allows emit_bits to detect any attempt to emit such symbols.
+ */
+ MEMZERO(dtbl->ehufsi, SIZEOF(dtbl->ehufsi));
+
+ /* This is also a convenient place to check for out-of-range
+ * and duplicated VAL entries. We allow 0..255 for AC symbols
+ * but only 0..15 for DC. (We could constrain them further
+ * based on data depth and mode, but this seems enough.)
+ */
+ maxsymbol = isDC ? 15 : 255;
+
+ for (p = 0; p < lastp; p++) {
+ i = htbl->huffval[p];
+ if (i < 0 || i > maxsymbol || dtbl->ehufsi[i])
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ dtbl->ehufco[i] = huffcode[p];
+ dtbl->ehufsi[i] = huffsize[p];
+ }
+}
+
+
+/* Outputting bytes to the file */
+
+/* Emit a byte, taking 'action' if must suspend. */
+#define emit_byte(state,val,action) \
+ { *(state)->next_output_byte++ = (JOCTET) (val); \
+ if (--(state)->free_in_buffer == 0) \
+ if (! dump_buffer(state)) \
+ { action; } }
+
+
+LOCAL(int)
+dump_buffer (working_state * state)
+/* Empty the output buffer; return TRUE if successful, FALSE if must suspend */
+{
+ struct jpeg_destination_mgr * dest = state->cinfo->dest;
+
+ if (! (*dest->empty_output_buffer) (state->cinfo))
+ return FALSE;
+ /* After a successful buffer dump, must reset buffer pointers */
+ state->next_output_byte = dest->next_output_byte;
+ state->free_in_buffer = dest->free_in_buffer;
+ return TRUE;
+}
+
+
+/* Outputting bits to the file */
+
+/* Only the right 24 bits of put_buffer are used; the valid bits are
+ * left-justified in this part. At most 16 bits can be passed to emit_bits
+ * in one call, and we never retain more than 7 bits in put_buffer
+ * between calls, so 24 bits are sufficient.
+ */
+
+inline
+LOCAL(int)
+emit_bits (working_state * state, unsigned int code, int size)
+/* Emit some bits; return TRUE if successful, FALSE if must suspend */
+{
+ /* This routine is heavily used, so it's worth coding tightly. */
+ long put_buffer = (long) code;
+ int put_bits = state->cur.put_bits;
+
+ /* if size is 0, caller used an invalid Huffman table entry */
+ if (size == 0)
+ ERREXIT(state->cinfo, JERR_HUFF_MISSING_CODE);
+
+ put_buffer &= (((long) 1)<<size) - 1; /* mask off any extra bits in code */
+
+ put_bits += size; /* new number of bits in buffer */
+
+ put_buffer <<= 24 - put_bits; /* align incoming bits */
+
+ put_buffer |= state->cur.put_buffer; /* and merge with old buffer contents */
+
+ while (put_bits >= 8) {
+ int c = (int) ((put_buffer >> 16) & 0xFF);
+
+ emit_byte(state, c, return FALSE);
+ if (c == 0xFF) { /* need to stuff a zero byte? */
+ emit_byte(state, 0, return FALSE);
+ }
+ put_buffer <<= 8;
+ put_bits -= 8;
+ }
+
+ state->cur.put_buffer = put_buffer; /* update state variables */
+ state->cur.put_bits = put_bits;
+
+ return TRUE;
+}
+
+
+LOCAL(int)
+flush_bits (working_state * state)
+{
+ if (! emit_bits(state, 0x7F, 7)) /* fill any partial byte with ones */
+ return FALSE;
+ state->cur.put_buffer = 0; /* and reset bit-buffer to empty */
+ state->cur.put_bits = 0;
+ return TRUE;
+}
+
+
+/* Encode a single block's worth of coefficients */
+
+LOCAL(int)
+encode_one_block (working_state * state, JCOEFPTR block, int last_dc_val,
+ c_derived_tbl *dctbl, c_derived_tbl *actbl)
+{
+ int temp, temp2;
+ int nbits;
+ int k, r, i;
+
+ /* Encode the DC coefficient difference per section F.1.2.1 */
+
+ temp = temp2 = block[0] - last_dc_val;
+
+ if (temp < 0) {
+ temp = -temp; /* temp is abs value of input */
+ /* For a negative input, want temp2 = bitwise complement of abs(input) */
+ /* This code assumes we are on a two's complement machine */
+ temp2--;
+ }
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 0;
+ while (temp) {
+ nbits++;
+ temp >>= 1;
+ }
+ /* Check for out-of-range coefficient values.
+ * Since we're encoding a difference, the range limit is twice as much.
+ */
+ if (nbits > MAX_COEF_BITS+1)
+ ERREXIT(state->cinfo, JERR_BAD_DCT_COEF);
+
+ /* Emit the Huffman-coded symbol for the number of bits */
+ if (! emit_bits(state, dctbl->ehufco[nbits], dctbl->ehufsi[nbits]))
+ return FALSE;
+
+ /* Emit that number of bits of the value, if positive, */
+ /* or the complement of its magnitude, if negative. */
+ if (nbits) /* emit_bits rejects calls with size 0 */
+ if (! emit_bits(state, (unsigned int) temp2, nbits))
+ return FALSE;
+
+ /* Encode the AC coefficients per section F.1.2.2 */
+
+ r = 0; /* r = run length of zeros */
+
+ for (k = 1; k < DCTSIZE2; k++) {
+ if ((temp = block[jpeg_natural_order[k]]) == 0) {
+ r++;
+ } else {
+ /* if run length > 15, must emit special run-length-16 codes (0xF0) */
+ while (r > 15) {
+ if (! emit_bits(state, actbl->ehufco[0xF0], actbl->ehufsi[0xF0]))
+ return FALSE;
+ r -= 16;
+ }
+
+ temp2 = temp;
+ if (temp < 0) {
+ temp = -temp; /* temp is abs value of input */
+ /* This code assumes we are on a two's complement machine */
+ temp2--;
+ }
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 1; /* there must be at least one 1 bit */
+ while ((temp >>= 1))
+ nbits++;
+ /* Check for out-of-range coefficient values */
+ if (nbits > MAX_COEF_BITS)
+ ERREXIT(state->cinfo, JERR_BAD_DCT_COEF);
+
+ /* Emit Huffman symbol for run length / number of bits */
+ i = (r << 4) + nbits;
+ if (! emit_bits(state, actbl->ehufco[i], actbl->ehufsi[i]))
+ return FALSE;
+
+ /* Emit that number of bits of the value, if positive, */
+ /* or the complement of its magnitude, if negative. */
+ if (! emit_bits(state, (unsigned int) temp2, nbits))
+ return FALSE;
+
+ r = 0;
+ }
+ }
+
+ /* If the last coef(s) were zero, emit an end-of-block code */
+ if (r > 0)
+ if (! emit_bits(state, actbl->ehufco[0], actbl->ehufsi[0]))
+ return FALSE;
+
+ return TRUE;
+}
+
+
+/*
+ * Emit a restart marker & resynchronize predictions.
+ */
+
+LOCAL(int)
+emit_restart (working_state * state, int restart_num)
+{
+ int ci;
+
+ if (! flush_bits(state))
+ return FALSE;
+
+ emit_byte(state, 0xFF, return FALSE);
+ emit_byte(state, JPEG_RST0 + restart_num, return FALSE);
+
+ /* Re-initialize DC predictions to 0 */
+ for (ci = 0; ci < state->cinfo->comps_in_scan; ci++)
+ state->cur.last_dc_val[ci] = 0;
+
+ /* The restart counter is not updated until we successfully write the MCU. */
+
+ return TRUE;
+}
+
+
+/*
+ * Encode and output one MCU's worth of Huffman-compressed coefficients.
+ */
+
+METHODDEF(int)
+encode_mcu_huff (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ working_state state;
+ int blkn, ci;
+ jpeg_component_info * compptr;
+
+ /* Load up working state */
+ state.next_output_byte = cinfo->dest->next_output_byte;
+ state.free_in_buffer = cinfo->dest->free_in_buffer;
+ ASSIGN_STATE(state.cur, entropy->saved);
+ state.cinfo = cinfo;
+
+ /* Emit restart marker if needed */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! emit_restart(&state, entropy->next_restart_num))
+ return FALSE;
+ }
+
+ /* Encode the MCU data blocks */
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ ci = cinfo->MCU_membership[blkn];
+ compptr = cinfo->cur_comp_info[ci];
+ if (! encode_one_block(&state,
+ MCU_data[blkn][0], state.cur.last_dc_val[ci],
+ entropy->dc_derived_tbls[compptr->dc_tbl_no],
+ entropy->ac_derived_tbls[compptr->ac_tbl_no]))
+ return FALSE;
+ /* Update last_dc_val */
+ state.cur.last_dc_val[ci] = MCU_data[blkn][0][0];
+ }
+
+ /* Completed MCU, so update state */
+ cinfo->dest->next_output_byte = state.next_output_byte;
+ cinfo->dest->free_in_buffer = state.free_in_buffer;
+ ASSIGN_STATE(entropy->saved, state.cur);
+
+ /* Update restart-interval state too */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num++;
+ entropy->next_restart_num &= 7;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * Finish up at the end of a Huffman-compressed scan.
+ */
+
+METHODDEF(void)
+finish_pass_huff (j_compress_ptr cinfo)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ working_state state;
+
+ /* Load up working state ... flush_bits needs it */
+ state.next_output_byte = cinfo->dest->next_output_byte;
+ state.free_in_buffer = cinfo->dest->free_in_buffer;
+ ASSIGN_STATE(state.cur, entropy->saved);
+ state.cinfo = cinfo;
+
+ /* Flush out the last data */
+ if (! flush_bits(&state))
+ ERREXIT(cinfo, JERR_CANT_SUSPEND);
+
+ /* Update state */
+ cinfo->dest->next_output_byte = state.next_output_byte;
+ cinfo->dest->free_in_buffer = state.free_in_buffer;
+ ASSIGN_STATE(entropy->saved, state.cur);
+}
+
+
+/*
+ * Huffman coding optimization.
+ *
+ * We first scan the supplied data and count the number of uses of each symbol
+ * that is to be Huffman-coded. (This process MUST agree with the code above.)
+ * Then we build a Huffman coding tree for the observed counts.
+ * Symbols which are not needed at all for the particular image are not
+ * assigned any code, which saves space in the DHT marker as well as in
+ * the compressed data.
+ */
+
+#ifdef ENTROPY_OPT_SUPPORTED
+
+
+/* Process a single block's worth of coefficients */
+
+LOCAL(void)
+htest_one_block (j_compress_ptr cinfo, JCOEFPTR block, int last_dc_val,
+ long dc_counts[], long ac_counts[])
+{
+ int temp;
+ int nbits;
+ int k, r;
+
+ /* Encode the DC coefficient difference per section F.1.2.1 */
+
+ temp = block[0] - last_dc_val;
+ if (temp < 0)
+ temp = -temp;
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 0;
+ while (temp) {
+ nbits++;
+ temp >>= 1;
+ }
+ /* Check for out-of-range coefficient values.
+ * Since we're encoding a difference, the range limit is twice as much.
+ */
+ if (nbits > MAX_COEF_BITS+1)
+ ERREXIT(cinfo, JERR_BAD_DCT_COEF);
+
+ /* Count the Huffman symbol for the number of bits */
+ dc_counts[nbits]++;
+
+ /* Encode the AC coefficients per section F.1.2.2 */
+
+ r = 0; /* r = run length of zeros */
+
+ for (k = 1; k < DCTSIZE2; k++) {
+ if ((temp = block[jpeg_natural_order[k]]) == 0) {
+ r++;
+ } else {
+ /* if run length > 15, must emit special run-length-16 codes (0xF0) */
+ while (r > 15) {
+ ac_counts[0xF0]++;
+ r -= 16;
+ }
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ if (temp < 0)
+ temp = -temp;
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 1; /* there must be at least one 1 bit */
+ while ((temp >>= 1))
+ nbits++;
+ /* Check for out-of-range coefficient values */
+ if (nbits > MAX_COEF_BITS)
+ ERREXIT(cinfo, JERR_BAD_DCT_COEF);
+
+ /* Count Huffman symbol for run length / number of bits */
+ ac_counts[(r << 4) + nbits]++;
+
+ r = 0;
+ }
+ }
+
+ /* If the last coef(s) were zero, emit an end-of-block code */
+ if (r > 0)
+ ac_counts[0]++;
+}
+
+
+/*
+ * Trial-encode one MCU's worth of Huffman-compressed coefficients.
+ * No data is actually output, so no suspension return is possible.
+ */
+
+METHODDEF(int)
+encode_mcu_gather (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int blkn, ci;
+ jpeg_component_info * compptr;
+
+ /* Take care of restart intervals if needed */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ /* Re-initialize DC predictions to 0 */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++)
+ entropy->saved.last_dc_val[ci] = 0;
+ /* Update restart state */
+ entropy->restarts_to_go = cinfo->restart_interval;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ ci = cinfo->MCU_membership[blkn];
+ compptr = cinfo->cur_comp_info[ci];
+ htest_one_block(cinfo, MCU_data[blkn][0], entropy->saved.last_dc_val[ci],
+ entropy->dc_count_ptrs[compptr->dc_tbl_no],
+ entropy->ac_count_ptrs[compptr->ac_tbl_no]);
+ entropy->saved.last_dc_val[ci] = MCU_data[blkn][0][0];
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * Generate the best Huffman code table for the given counts, fill htbl.
+ * Note this is also used by jcphuff.c.
+ *
+ * The JPEG standard requires that no symbol be assigned a codeword of all
+ * one bits (so that padding bits added at the end of a compressed segment
+ * can't look like a valid code). Because of the canonical ordering of
+ * codewords, this just means that there must be an unused slot in the
+ * longest codeword length category. Section K.2 of the JPEG spec suggests
+ * reserving such a slot by pretending that symbol 256 is a valid symbol
+ * with count 1. In theory that's not optimal; giving it count zero but
+ * including it in the symbol set anyway should give a better Huffman code.
+ * But the theoretically better code actually seems to come out worse in
+ * practice, because it produces more all-ones bytes (which incur stuffed
+ * zero bytes in the final file). In any case the difference is tiny.
+ *
+ * The JPEG standard requires Huffman codes to be no more than 16 bits long.
+ * If some symbols have a very small but nonzero probability, the Huffman tree
+ * must be adjusted to meet the code length restriction. We currently use
+ * the adjustment method suggested in JPEG section K.2. This method is *not*
+ * optimal; it may not choose the best possible limited-length code. But
+ * typically only very-low-frequency symbols will be given less-than-optimal
+ * lengths, so the code is almost optimal. Experimental comparisons against
+ * an optimal limited-length-code algorithm indicate that the difference is
+ * microscopic --- usually less than a hundredth of a percent of total size.
+ * So the extra complexity of an optimal algorithm doesn't seem worthwhile.
+ */
+
+GLOBAL(void)
+jpeg_gen_optimal_table (j_compress_ptr cinfo, JHUFF_TBL * htbl, long freq[])
+{
+#define MAX_CLEN 32 /* assumed maximum initial code length */
+ unsigned short bits[MAX_CLEN+1]; /* bits[k] = # of symbols with code length k */
+ int codesize[257]; /* codesize[k] = code length of symbol k */
+ int others[257]; /* next symbol in current branch of tree */
+ int c1, c2;
+ int p, i, j;
+ long v;
+
+ /* This algorithm is explained in section K.2 of the JPEG standard */
+
+ MEMZERO(bits, SIZEOF(bits));
+ MEMZERO(codesize, SIZEOF(codesize));
+ for (i = 0; i < 257; i++)
+ others[i] = -1; /* init links to empty */
+
+ freq[256] = 1; /* make sure 256 has a nonzero count */
+ /* Including the pseudo-symbol 256 in the Huffman procedure guarantees
+ * that no real symbol is given code-value of all ones, because 256
+ * will be placed last in the largest codeword category.
+ */
+
+ /* Huffman's basic algorithm to assign optimal code lengths to symbols */
+
+ for (;;) {
+ /* Find the smallest nonzero frequency, set c1 = its symbol */
+ /* In case of ties, take the larger symbol number */
+ c1 = -1;
+ v = 1000000000L;
+ for (i = 0; i <= 256; i++) {
+ if (freq[i] && freq[i] <= v) {
+ v = freq[i];
+ c1 = i;
+ }
+ }
+
+ /* Find the next smallest nonzero frequency, set c2 = its symbol */
+ /* In case of ties, take the larger symbol number */
+ c2 = -1;
+ v = 1000000000L;
+ for (i = 0; i <= 256; i++) {
+ if (freq[i] && freq[i] <= v && i != c1) {
+ v = freq[i];
+ c2 = i;
+ }
+ }
+
+ /* Done if we've merged everything into one frequency */
+ if (c2 < 0)
+ break;
+
+ /* Else merge the two counts/trees */
+ freq[c1] += freq[c2];
+ freq[c2] = 0;
+
+ /* Increment the codesize of everything in c1's tree branch */
+ codesize[c1]++;
+ while (others[c1] >= 0) {
+ c1 = others[c1];
+ codesize[c1]++;
+ }
+
+ others[c1] = c2; /* chain c2 onto c1's tree branch */
+
+ /* Increment the codesize of everything in c2's tree branch */
+ codesize[c2]++;
+ while (others[c2] >= 0) {
+ c2 = others[c2];
+ codesize[c2]++;
+ }
+ }
+
+ /* Now count the number of symbols of each code length */
+ for (i = 0; i <= 256; i++) {
+ if (codesize[i]) {
+ /* The JPEG standard seems to think that this can't happen, */
+ /* but I'm paranoid... */
+ if (codesize[i] > MAX_CLEN)
+ ERREXIT(cinfo, JERR_HUFF_CLEN_OVERFLOW);
+
+ bits[codesize[i]]++;
+ }
+ }
+
+ /* JPEG doesn't allow symbols with code lengths over 16 bits, so if the pure
+ * Huffman procedure assigned any such lengths, we must adjust the coding.
+ * Here is what the JPEG spec says about how this next bit works:
+ * Since symbols are paired for the longest Huffman code, the symbols are
+ * removed from this length category two at a time. The prefix for the pair
+ * (which is one bit shorter) is allocated to one of the pair; then,
+ * skipping the BITS entry for that prefix length, a code word from the next
+ * shortest nonzero BITS entry is converted into a prefix for two code words
+ * one bit longer.
+ */
+
+ for (i = MAX_CLEN; i > 16; i--) {
+ while (bits[i] > 0) {
+ j = i - 2; /* find length of new prefix to be used */
+ while (bits[j] == 0)
+ j--;
+
+ bits[i] -= 2; /* remove two symbols */
+ bits[i-1]++; /* one goes in this length */
+ bits[j+1] += 2; /* two new symbols in this length */
+ bits[j]--; /* symbol of this length is now a prefix */
+ }
+ }
+
+ /* Remove the count for the pseudo-symbol 256 from the largest codelength */
+ while (bits[i] == 0) /* find largest codelength still in use */
+ i--;
+ bits[i]--;
+
+ /* Return final symbol counts (only for lengths 0..16) */
+ MEMCOPY(htbl->bits, bits, SIZEOF(htbl->bits));
+
+ /* Return a list of the symbols sorted by code length */
+ /* It's not real clear to me why we don't need to consider the codelength
+ * changes made above, but the JPEG spec seems to think this works.
+ */
+ p = 0;
+ for (i = 1; i <= MAX_CLEN; i++) {
+ for (j = 0; j <= 255; j++) {
+ if (codesize[j] == i) {
+ htbl->huffval[p] = (unsigned char) j;
+ p++;
+ }
+ }
+ }
+
+ /* Set sent_table FALSE so updated table will be written to JPEG file. */
+ htbl->sent_table = FALSE;
+}
+
+
+/*
+ * Finish up a statistics-gathering pass and create the new Huffman tables.
+ */
+
+METHODDEF(void)
+finish_pass_gather (j_compress_ptr cinfo)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int ci, dctbl, actbl;
+ jpeg_component_info * compptr;
+ JHUFF_TBL **htblptr;
+ int did_dc[NUM_HUFF_TBLS];
+ int did_ac[NUM_HUFF_TBLS];
+
+ /* It's important not to apply jpeg_gen_optimal_table more than once
+ * per table, because it clobbers the input frequency counts!
+ */
+ MEMZERO(did_dc, SIZEOF(did_dc));
+ MEMZERO(did_ac, SIZEOF(did_ac));
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ dctbl = compptr->dc_tbl_no;
+ actbl = compptr->ac_tbl_no;
+ if (! did_dc[dctbl]) {
+ htblptr = & cinfo->dc_huff_tbl_ptrs[dctbl];
+ if (*htblptr == NULL)
+ *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo);
+ jpeg_gen_optimal_table(cinfo, *htblptr, entropy->dc_count_ptrs[dctbl]);
+ did_dc[dctbl] = TRUE;
+ }
+ if (! did_ac[actbl]) {
+ htblptr = & cinfo->ac_huff_tbl_ptrs[actbl];
+ if (*htblptr == NULL)
+ *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo);
+ jpeg_gen_optimal_table(cinfo, *htblptr, entropy->ac_count_ptrs[actbl]);
+ did_ac[actbl] = TRUE;
+ }
+ }
+}
+
+
+#endif /* ENTROPY_OPT_SUPPORTED */
+
+
+/*
+ * Module initialization routine for Huffman entropy encoding.
+ */
+
+GLOBAL(void)
+jinit_huff_encoder (j_compress_ptr cinfo)
+{
+ huff_entropy_ptr entropy;
+ int i;
+
+ entropy = (huff_entropy_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(huff_entropy_encoder));
+ cinfo->entropy = (struct jpeg_entropy_encoder *) entropy;
+ entropy->pub.start_pass = start_pass_huff;
+
+ /* Mark tables unallocated */
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ entropy->dc_derived_tbls[i] = entropy->ac_derived_tbls[i] = NULL;
+#ifdef ENTROPY_OPT_SUPPORTED
+ entropy->dc_count_ptrs[i] = entropy->ac_count_ptrs[i] = NULL;
+#endif
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jchuff.h b/ml/dlib/dlib/external/libjpeg/jchuff.h
new file mode 100644
index 000000000..2d184ec55
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jchuff.h
@@ -0,0 +1,47 @@
+/*
+ * jchuff.h
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains declarations for Huffman entropy encoding routines
+ * that are shared between the sequential encoder (jchuff.c) and the
+ * progressive encoder (jcphuff.c). No other modules need to see these.
+ */
+
+/* The legal range of a DCT coefficient is
+ * -1024 .. +1023 for 8-bit data;
+ * -16384 .. +16383 for 12-bit data.
+ * Hence the magnitude should always fit in 10 or 14 bits respectively.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define MAX_COEF_BITS 10
+#else
+#define MAX_COEF_BITS 14
+#endif
+
+/* Derived data constructed for each Huffman table */
+
+typedef struct {
+ unsigned int ehufco[256]; /* code for each symbol */
+ char ehufsi[256]; /* length of code for each symbol */
+ /* If no code has been allocated for a symbol S, ehufsi[S] contains 0 */
+} c_derived_tbl;
+
+/* Short forms of external names for systems with brain-damaged linkers. */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_make_c_derived_tbl jMkCDerived
+#define jpeg_gen_optimal_table jGenOptTbl
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+/* Expand a Huffman table definition into the derived format */
+EXTERN(void) jpeg_make_c_derived_tbl
+ JPP((j_compress_ptr cinfo, int isDC, int tblno,
+ c_derived_tbl ** pdtbl));
+
+/* Generate an optimal table definition given the specified counts */
+EXTERN(void) jpeg_gen_optimal_table
+ JPP((j_compress_ptr cinfo, JHUFF_TBL * htbl, long freq[]));
diff --git a/ml/dlib/dlib/external/libjpeg/jcinit.cpp b/ml/dlib/dlib/external/libjpeg/jcinit.cpp
new file mode 100644
index 000000000..2ca809cfd
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcinit.cpp
@@ -0,0 +1,72 @@
+/*
+ * jcinit.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains initialization logic for the JPEG compressor.
+ * This routine is in charge of selecting the modules to be executed and
+ * making an initialization call to each one.
+ *
+ * Logically, this code belongs in jcmaster.c. It's split out because
+ * linking this routine implies linking the entire compression library.
+ * For a transcoding-only application, we want to be able to use jcmaster.c
+ * without linking in the whole library.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Master selection of compression modules.
+ * This is done once at the start of processing an image. We determine
+ * which modules will be used and give them appropriate initialization calls.
+ */
+
+GLOBAL(void)
+jinit_compress_master (j_compress_ptr cinfo)
+{
+ /* Initialize master control (includes parameter checking/processing) */
+ jinit_c_master_control(cinfo, FALSE /* full compression */);
+
+ /* Preprocessing */
+ if (! cinfo->raw_data_in) {
+ jinit_color_converter(cinfo);
+ jinit_downsampler(cinfo);
+ jinit_c_prep_controller(cinfo, FALSE /* never need full buffer here */);
+ }
+ /* Forward DCT */
+ jinit_forward_dct(cinfo);
+ /* Entropy encoding: either Huffman or arithmetic coding. */
+ if (cinfo->arith_code) {
+ ERREXIT(cinfo, JERR_ARITH_NOTIMPL);
+ } else {
+ if (cinfo->progressive_mode) {
+#ifdef C_PROGRESSIVE_SUPPORTED
+ jinit_phuff_encoder(cinfo);
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else
+ jinit_huff_encoder(cinfo);
+ }
+
+ /* Need a full-image coefficient buffer in any multi-pass mode. */
+ jinit_c_coef_controller(cinfo,
+ (int) (cinfo->num_scans > 1 || cinfo->optimize_coding));
+ jinit_c_main_controller(cinfo, FALSE /* never need full buffer here */);
+
+ jinit_marker_writer(cinfo);
+
+ /* We can now tell the memory manager to allocate virtual arrays. */
+ (*cinfo->mem->realize_virt_arrays) ((j_common_ptr) cinfo);
+
+ /* Write the datastream header (SOI) immediately.
+ * Frame and scan headers are postponed till later.
+ * This lets application insert special markers after the SOI.
+ */
+ (*cinfo->marker->write_file_header) (cinfo);
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcmainct.cpp b/ml/dlib/dlib/external/libjpeg/jcmainct.cpp
new file mode 100644
index 000000000..1e5f97a20
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcmainct.cpp
@@ -0,0 +1,293 @@
+/*
+ * jcmainct.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the main buffer controller for compression.
+ * The main buffer lies between the pre-processor and the JPEG
+ * compressor proper; it holds downsampled data in the JPEG colorspace.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Note: currently, there is no operating mode in which a full-image buffer
+ * is needed at this step. If there were, that mode could not be used with
+ * "raw data" input, since this module is bypassed in that case. However,
+ * we've left the code here for possible use in special applications.
+ */
+#undef FULL_MAIN_BUFFER_SUPPORTED
+
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_c_main_controller pub; /* public fields */
+
+ JDIMENSION cur_iMCU_row; /* number of current iMCU row */
+ JDIMENSION rowgroup_ctr; /* counts row groups received in iMCU row */
+ int suspended; /* remember if we suspended output */
+ J_BUF_MODE pass_mode; /* current operating mode */
+
+ /* If using just a strip buffer, this points to the entire set of buffers
+ * (we allocate one for each component). In the full-image case, this
+ * points to the currently accessible strips of the virtual arrays.
+ */
+ JSAMPARRAY buffer[MAX_COMPONENTS];
+
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+ /* If using full-image storage, this array holds pointers to virtual-array
+ * control blocks for each component. Unused if not full-image storage.
+ */
+ jvirt_sarray_ptr whole_image[MAX_COMPONENTS];
+#endif
+} my_main_controller;
+
+typedef my_main_controller * my_main_ptr;
+
+
+/* Forward declarations */
+METHODDEF(void) process_data_simple_main
+ JPP((j_compress_ptr cinfo, JSAMPARRAY input_buf,
+ JDIMENSION *in_row_ctr, JDIMENSION in_rows_avail));
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+METHODDEF(void) process_data_buffer_main
+ JPP((j_compress_ptr cinfo, JSAMPARRAY input_buf,
+ JDIMENSION *in_row_ctr, JDIMENSION in_rows_avail));
+#endif
+
+
+/*
+ * Initialize for a processing pass.
+ */
+
+METHODDEF(void)
+start_pass_main (j_compress_ptr cinfo, J_BUF_MODE pass_mode)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+
+ /* Do nothing in raw-data mode. */
+ if (cinfo->raw_data_in)
+ return;
+
+ main->cur_iMCU_row = 0; /* initialize counters */
+ main->rowgroup_ctr = 0;
+ main->suspended = FALSE;
+ main->pass_mode = pass_mode; /* save mode for use by process_data */
+
+ switch (pass_mode) {
+ case JBUF_PASS_THRU:
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+ if (main->whole_image[0] != NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+#endif
+ main->pub.process_data = process_data_simple_main;
+ break;
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+ case JBUF_SAVE_SOURCE:
+ case JBUF_CRANK_DEST:
+ case JBUF_SAVE_AND_PASS:
+ if (main->whole_image[0] == NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ main->pub.process_data = process_data_buffer_main;
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ break;
+ }
+}
+
+
+/*
+ * Process some data.
+ * This routine handles the simple pass-through mode,
+ * where we have only a strip buffer.
+ */
+
+METHODDEF(void)
+process_data_simple_main (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+
+ while (main->cur_iMCU_row < cinfo->total_iMCU_rows) {
+ /* Read input data if we haven't filled the main buffer yet */
+ if (main->rowgroup_ctr < DCTSIZE)
+ (*cinfo->prep->pre_process_data) (cinfo,
+ input_buf, in_row_ctr, in_rows_avail,
+ main->buffer, &main->rowgroup_ctr,
+ (JDIMENSION) DCTSIZE);
+
+ /* If we don't have a full iMCU row buffered, return to application for
+ * more data. Note that preprocessor will always pad to fill the iMCU row
+ * at the bottom of the image.
+ */
+ if (main->rowgroup_ctr != DCTSIZE)
+ return;
+
+ /* Send the completed row to the compressor */
+ if (! (*cinfo->coef->compress_data) (cinfo, main->buffer)) {
+ /* If compressor did not consume the whole row, then we must need to
+ * suspend processing and return to the application. In this situation
+ * we pretend we didn't yet consume the last input row; otherwise, if
+ * it happened to be the last row of the image, the application would
+ * think we were done.
+ */
+ if (! main->suspended) {
+ (*in_row_ctr)--;
+ main->suspended = TRUE;
+ }
+ return;
+ }
+ /* We did finish the row. Undo our little suspension hack if a previous
+ * call suspended; then mark the main buffer empty.
+ */
+ if (main->suspended) {
+ (*in_row_ctr)++;
+ main->suspended = FALSE;
+ }
+ main->rowgroup_ctr = 0;
+ main->cur_iMCU_row++;
+ }
+}
+
+
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+
+/*
+ * Process some data.
+ * This routine handles all of the modes that use a full-size buffer.
+ */
+
+METHODDEF(void)
+process_data_buffer_main (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ int ci;
+ jpeg_component_info *compptr;
+ int writing = (main->pass_mode != JBUF_CRANK_DEST);
+
+ while (main->cur_iMCU_row < cinfo->total_iMCU_rows) {
+ /* Realign the virtual buffers if at the start of an iMCU row. */
+ if (main->rowgroup_ctr == 0) {
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ main->buffer[ci] = (*cinfo->mem->access_virt_sarray)
+ ((j_common_ptr) cinfo, main->whole_image[ci],
+ main->cur_iMCU_row * (compptr->v_samp_factor * DCTSIZE),
+ (JDIMENSION) (compptr->v_samp_factor * DCTSIZE), writing);
+ }
+ /* In a read pass, pretend we just read some source data. */
+ if (! writing) {
+ *in_row_ctr += cinfo->max_v_samp_factor * DCTSIZE;
+ main->rowgroup_ctr = DCTSIZE;
+ }
+ }
+
+ /* If a write pass, read input data until the current iMCU row is full. */
+ /* Note: preprocessor will pad if necessary to fill the last iMCU row. */
+ if (writing) {
+ (*cinfo->prep->pre_process_data) (cinfo,
+ input_buf, in_row_ctr, in_rows_avail,
+ main->buffer, &main->rowgroup_ctr,
+ (JDIMENSION) DCTSIZE);
+ /* Return to application if we need more data to fill the iMCU row. */
+ if (main->rowgroup_ctr < DCTSIZE)
+ return;
+ }
+
+ /* Emit data, unless this is a sink-only pass. */
+ if (main->pass_mode != JBUF_SAVE_SOURCE) {
+ if (! (*cinfo->coef->compress_data) (cinfo, main->buffer)) {
+ /* If compressor did not consume the whole row, then we must need to
+ * suspend processing and return to the application. In this situation
+ * we pretend we didn't yet consume the last input row; otherwise, if
+ * it happened to be the last row of the image, the application would
+ * think we were done.
+ */
+ if (! main->suspended) {
+ (*in_row_ctr)--;
+ main->suspended = TRUE;
+ }
+ return;
+ }
+ /* We did finish the row. Undo our little suspension hack if a previous
+ * call suspended; then mark the main buffer empty.
+ */
+ if (main->suspended) {
+ (*in_row_ctr)++;
+ main->suspended = FALSE;
+ }
+ }
+
+ /* If get here, we are done with this iMCU row. Mark buffer empty. */
+ main->rowgroup_ctr = 0;
+ main->cur_iMCU_row++;
+ }
+}
+
+#endif /* FULL_MAIN_BUFFER_SUPPORTED */
+
+
+/*
+ * Initialize main buffer controller.
+ */
+
+GLOBAL(void)
+jinit_c_main_controller (j_compress_ptr cinfo, int need_full_buffer)
+{
+ my_main_ptr main;
+ int ci;
+ jpeg_component_info *compptr;
+
+ main = (my_main_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_main_controller));
+ cinfo->main = (struct jpeg_c_main_controller *) main;
+ main->pub.start_pass = start_pass_main;
+
+ /* We don't need to create a buffer in raw-data mode. */
+ if (cinfo->raw_data_in)
+ return;
+
+ /* Create the buffer. It holds downsampled data, so each component
+ * may be of a different size.
+ */
+ if (need_full_buffer) {
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+ /* Allocate a full-image virtual array for each component */
+ /* Note we pad the bottom to a multiple of the iMCU height */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ main->whole_image[ci] = (*cinfo->mem->request_virt_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE,
+ compptr->width_in_blocks * DCTSIZE,
+ (JDIMENSION) jround_up((long) compptr->height_in_blocks,
+ (long) compptr->v_samp_factor) * DCTSIZE,
+ (JDIMENSION) (compptr->v_samp_factor * DCTSIZE));
+ }
+#else
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+#endif
+ } else {
+#ifdef FULL_MAIN_BUFFER_SUPPORTED
+ main->whole_image[0] = NULL; /* flag for no virtual arrays */
+#endif
+ /* Allocate a strip buffer for each component */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ main->buffer[ci] = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ compptr->width_in_blocks * DCTSIZE,
+ (JDIMENSION) (compptr->v_samp_factor * DCTSIZE));
+ }
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcmarker.cpp b/ml/dlib/dlib/external/libjpeg/jcmarker.cpp
new file mode 100644
index 000000000..1bfd15c55
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcmarker.cpp
@@ -0,0 +1,664 @@
+/*
+ * jcmarker.c
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains routines to write JPEG datastream markers.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+typedef enum { /* JPEG marker codes */
+ M_SOF0 = 0xc0,
+ M_SOF1 = 0xc1,
+ M_SOF2 = 0xc2,
+ M_SOF3 = 0xc3,
+
+ M_SOF5 = 0xc5,
+ M_SOF6 = 0xc6,
+ M_SOF7 = 0xc7,
+
+ M_JPG = 0xc8,
+ M_SOF9 = 0xc9,
+ M_SOF10 = 0xca,
+ M_SOF11 = 0xcb,
+
+ M_SOF13 = 0xcd,
+ M_SOF14 = 0xce,
+ M_SOF15 = 0xcf,
+
+ M_DHT = 0xc4,
+
+ M_DAC = 0xcc,
+
+ M_RST0 = 0xd0,
+ M_RST1 = 0xd1,
+ M_RST2 = 0xd2,
+ M_RST3 = 0xd3,
+ M_RST4 = 0xd4,
+ M_RST5 = 0xd5,
+ M_RST6 = 0xd6,
+ M_RST7 = 0xd7,
+
+ M_SOI = 0xd8,
+ M_EOI = 0xd9,
+ M_SOS = 0xda,
+ M_DQT = 0xdb,
+ M_DNL = 0xdc,
+ M_DRI = 0xdd,
+ M_DHP = 0xde,
+ M_EXP = 0xdf,
+
+ M_APP0 = 0xe0,
+ M_APP1 = 0xe1,
+ M_APP2 = 0xe2,
+ M_APP3 = 0xe3,
+ M_APP4 = 0xe4,
+ M_APP5 = 0xe5,
+ M_APP6 = 0xe6,
+ M_APP7 = 0xe7,
+ M_APP8 = 0xe8,
+ M_APP9 = 0xe9,
+ M_APP10 = 0xea,
+ M_APP11 = 0xeb,
+ M_APP12 = 0xec,
+ M_APP13 = 0xed,
+ M_APP14 = 0xee,
+ M_APP15 = 0xef,
+
+ M_JPG0 = 0xf0,
+ M_JPG13 = 0xfd,
+ M_COM = 0xfe,
+
+ M_TEM = 0x01,
+
+ M_ERROR = 0x100
+} JPEG_MARKER;
+
+
+/* Private state */
+
+typedef struct {
+ struct jpeg_marker_writer pub; /* public fields */
+
+ unsigned int last_restart_interval; /* last DRI value emitted; 0 after SOI */
+} my_marker_writer;
+
+typedef my_marker_writer * my_marker_ptr;
+
+
+/*
+ * Basic output routines.
+ *
+ * Note that we do not support suspension while writing a marker.
+ * Therefore, an application using suspension must ensure that there is
+ * enough buffer space for the initial markers (typ. 600-700 bytes) before
+ * calling jpeg_start_compress, and enough space to write the trailing EOI
+ * (a few bytes) before calling jpeg_finish_compress. Multipass compression
+ * modes are not supported at all with suspension, so those two are the only
+ * points where markers will be written.
+ */
+
+LOCAL(void)
+emit_byte (j_compress_ptr cinfo, int val)
+/* Emit a byte */
+{
+ struct jpeg_destination_mgr * dest = cinfo->dest;
+
+ *(dest->next_output_byte)++ = (JOCTET) val;
+ if (--dest->free_in_buffer == 0) {
+ if (! (*dest->empty_output_buffer) (cinfo))
+ ERREXIT(cinfo, JERR_CANT_SUSPEND);
+ }
+}
+
+
+LOCAL(void)
+emit_marker (j_compress_ptr cinfo, JPEG_MARKER mark)
+/* Emit a marker code */
+{
+ emit_byte(cinfo, 0xFF);
+ emit_byte(cinfo, (int) mark);
+}
+
+
+LOCAL(void)
+emit_2bytes (j_compress_ptr cinfo, int value)
+/* Emit a 2-byte integer; these are always MSB first in JPEG files */
+{
+ emit_byte(cinfo, (value >> 8) & 0xFF);
+ emit_byte(cinfo, value & 0xFF);
+}
+
+
+/*
+ * Routines to write specific marker types.
+ */
+
+LOCAL(int)
+emit_dqt (j_compress_ptr cinfo, int index)
+/* Emit a DQT marker */
+/* Returns the precision used (0 = 8bits, 1 = 16bits) for baseline checking */
+{
+ JQUANT_TBL * qtbl = cinfo->quant_tbl_ptrs[index];
+ int prec;
+ int i;
+
+ if (qtbl == NULL)
+ ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, index);
+
+ prec = 0;
+ for (i = 0; i < DCTSIZE2; i++) {
+ if (qtbl->quantval[i] > 255)
+ prec = 1;
+ }
+
+ if (! qtbl->sent_table) {
+ emit_marker(cinfo, M_DQT);
+
+ emit_2bytes(cinfo, prec ? DCTSIZE2*2 + 1 + 2 : DCTSIZE2 + 1 + 2);
+
+ emit_byte(cinfo, index + (prec<<4));
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ /* The table entries must be emitted in zigzag order. */
+ unsigned int qval = qtbl->quantval[jpeg_natural_order[i]];
+ if (prec)
+ emit_byte(cinfo, (int) (qval >> 8));
+ emit_byte(cinfo, (int) (qval & 0xFF));
+ }
+
+ qtbl->sent_table = TRUE;
+ }
+
+ return prec;
+}
+
+
+LOCAL(void)
+emit_dht (j_compress_ptr cinfo, int index, int is_ac)
+/* Emit a DHT marker */
+{
+ JHUFF_TBL * htbl;
+ int length, i;
+
+ if (is_ac) {
+ htbl = cinfo->ac_huff_tbl_ptrs[index];
+ index += 0x10; /* output index has AC bit set */
+ } else {
+ htbl = cinfo->dc_huff_tbl_ptrs[index];
+ }
+
+ if (htbl == NULL)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, index);
+
+ if (! htbl->sent_table) {
+ emit_marker(cinfo, M_DHT);
+
+ length = 0;
+ for (i = 1; i <= 16; i++)
+ length += htbl->bits[i];
+
+ emit_2bytes(cinfo, length + 2 + 1 + 16);
+ emit_byte(cinfo, index);
+
+ for (i = 1; i <= 16; i++)
+ emit_byte(cinfo, htbl->bits[i]);
+
+ for (i = 0; i < length; i++)
+ emit_byte(cinfo, htbl->huffval[i]);
+
+ htbl->sent_table = TRUE;
+ }
+}
+
+
+LOCAL(void)
+emit_dac (j_compress_ptr )//cinfo)
+/* Emit a DAC marker */
+/* Since the useful info is so small, we want to emit all the tables in */
+/* one DAC marker. Therefore this routine does its own scan of the table. */
+{
+#ifdef C_ARITH_CODING_SUPPORTED
+ char dc_in_use[NUM_ARITH_TBLS];
+ char ac_in_use[NUM_ARITH_TBLS];
+ int length, i;
+ jpeg_component_info *compptr;
+
+ for (i = 0; i < NUM_ARITH_TBLS; i++)
+ dc_in_use[i] = ac_in_use[i] = 0;
+
+ for (i = 0; i < cinfo->comps_in_scan; i++) {
+ compptr = cinfo->cur_comp_info[i];
+ dc_in_use[compptr->dc_tbl_no] = 1;
+ ac_in_use[compptr->ac_tbl_no] = 1;
+ }
+
+ length = 0;
+ for (i = 0; i < NUM_ARITH_TBLS; i++)
+ length += dc_in_use[i] + ac_in_use[i];
+
+ emit_marker(cinfo, M_DAC);
+
+ emit_2bytes(cinfo, length*2 + 2);
+
+ for (i = 0; i < NUM_ARITH_TBLS; i++) {
+ if (dc_in_use[i]) {
+ emit_byte(cinfo, i);
+ emit_byte(cinfo, cinfo->arith_dc_L[i] + (cinfo->arith_dc_U[i]<<4));
+ }
+ if (ac_in_use[i]) {
+ emit_byte(cinfo, i + 0x10);
+ emit_byte(cinfo, cinfo->arith_ac_K[i]);
+ }
+ }
+#endif /* C_ARITH_CODING_SUPPORTED */
+}
+
+
+LOCAL(void)
+emit_dri (j_compress_ptr cinfo)
+/* Emit a DRI marker */
+{
+ emit_marker(cinfo, M_DRI);
+
+ emit_2bytes(cinfo, 4); /* fixed length */
+
+ emit_2bytes(cinfo, (int) cinfo->restart_interval);
+}
+
+
+LOCAL(void)
+emit_sof (j_compress_ptr cinfo, JPEG_MARKER code)
+/* Emit a SOF marker */
+{
+ int ci;
+ jpeg_component_info *compptr;
+
+ emit_marker(cinfo, code);
+
+ emit_2bytes(cinfo, 3 * cinfo->num_components + 2 + 5 + 1); /* length */
+
+ /* Make sure image isn't bigger than SOF field can handle */
+ if ((long) cinfo->image_height > 65535L ||
+ (long) cinfo->image_width > 65535L)
+ ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) 65535);
+
+ emit_byte(cinfo, cinfo->data_precision);
+ emit_2bytes(cinfo, (int) cinfo->image_height);
+ emit_2bytes(cinfo, (int) cinfo->image_width);
+
+ emit_byte(cinfo, cinfo->num_components);
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ emit_byte(cinfo, compptr->component_id);
+ emit_byte(cinfo, (compptr->h_samp_factor << 4) + compptr->v_samp_factor);
+ emit_byte(cinfo, compptr->quant_tbl_no);
+ }
+}
+
+
+LOCAL(void)
+emit_sos (j_compress_ptr cinfo)
+/* Emit a SOS marker */
+{
+ int i, td, ta;
+ jpeg_component_info *compptr;
+
+ emit_marker(cinfo, M_SOS);
+
+ emit_2bytes(cinfo, 2 * cinfo->comps_in_scan + 2 + 1 + 3); /* length */
+
+ emit_byte(cinfo, cinfo->comps_in_scan);
+
+ for (i = 0; i < cinfo->comps_in_scan; i++) {
+ compptr = cinfo->cur_comp_info[i];
+ emit_byte(cinfo, compptr->component_id);
+ td = compptr->dc_tbl_no;
+ ta = compptr->ac_tbl_no;
+ if (cinfo->progressive_mode) {
+ /* Progressive mode: only DC or only AC tables are used in one scan;
+ * furthermore, Huffman coding of DC refinement uses no table at all.
+ * We emit 0 for unused field(s); this is recommended by the P&M text
+ * but does not seem to be specified in the standard.
+ */
+ if (cinfo->Ss == 0) {
+ ta = 0; /* DC scan */
+ if (cinfo->Ah != 0 && !cinfo->arith_code)
+ td = 0; /* no DC table either */
+ } else {
+ td = 0; /* AC scan */
+ }
+ }
+ emit_byte(cinfo, (td << 4) + ta);
+ }
+
+ emit_byte(cinfo, cinfo->Ss);
+ emit_byte(cinfo, cinfo->Se);
+ emit_byte(cinfo, (cinfo->Ah << 4) + cinfo->Al);
+}
+
+
+LOCAL(void)
+emit_jfif_app0 (j_compress_ptr cinfo)
+/* Emit a JFIF-compliant APP0 marker */
+{
+ /*
+ * Length of APP0 block (2 bytes)
+ * Block ID (4 bytes - ASCII "JFIF")
+ * Zero byte (1 byte to terminate the ID string)
+ * Version Major, Minor (2 bytes - major first)
+ * Units (1 byte - 0x00 = none, 0x01 = inch, 0x02 = cm)
+ * Xdpu (2 bytes - dots per unit horizontal)
+ * Ydpu (2 bytes - dots per unit vertical)
+ * Thumbnail X size (1 byte)
+ * Thumbnail Y size (1 byte)
+ */
+
+ emit_marker(cinfo, M_APP0);
+
+ emit_2bytes(cinfo, 2 + 4 + 1 + 2 + 1 + 2 + 2 + 1 + 1); /* length */
+
+ emit_byte(cinfo, 0x4A); /* Identifier: ASCII "JFIF" */
+ emit_byte(cinfo, 0x46);
+ emit_byte(cinfo, 0x49);
+ emit_byte(cinfo, 0x46);
+ emit_byte(cinfo, 0);
+ emit_byte(cinfo, cinfo->JFIF_major_version); /* Version fields */
+ emit_byte(cinfo, cinfo->JFIF_minor_version);
+ emit_byte(cinfo, cinfo->density_unit); /* Pixel size information */
+ emit_2bytes(cinfo, (int) cinfo->X_density);
+ emit_2bytes(cinfo, (int) cinfo->Y_density);
+ emit_byte(cinfo, 0); /* No thumbnail image */
+ emit_byte(cinfo, 0);
+}
+
+
+LOCAL(void)
+emit_adobe_app14 (j_compress_ptr cinfo)
+/* Emit an Adobe APP14 marker */
+{
+ /*
+ * Length of APP14 block (2 bytes)
+ * Block ID (5 bytes - ASCII "Adobe")
+ * Version Number (2 bytes - currently 100)
+ * Flags0 (2 bytes - currently 0)
+ * Flags1 (2 bytes - currently 0)
+ * Color transform (1 byte)
+ *
+ * Although Adobe TN 5116 mentions Version = 101, all the Adobe files
+ * now in circulation seem to use Version = 100, so that's what we write.
+ *
+ * We write the color transform byte as 1 if the JPEG color space is
+ * YCbCr, 2 if it's YCCK, 0 otherwise. Adobe's definition has to do with
+ * whether the encoder performed a transformation, which is pretty useless.
+ */
+
+ emit_marker(cinfo, M_APP14);
+
+ emit_2bytes(cinfo, 2 + 5 + 2 + 2 + 2 + 1); /* length */
+
+ emit_byte(cinfo, 0x41); /* Identifier: ASCII "Adobe" */
+ emit_byte(cinfo, 0x64);
+ emit_byte(cinfo, 0x6F);
+ emit_byte(cinfo, 0x62);
+ emit_byte(cinfo, 0x65);
+ emit_2bytes(cinfo, 100); /* Version */
+ emit_2bytes(cinfo, 0); /* Flags0 */
+ emit_2bytes(cinfo, 0); /* Flags1 */
+ switch (cinfo->jpeg_color_space) {
+ case JCS_YCbCr:
+ emit_byte(cinfo, 1); /* Color transform = 1 */
+ break;
+ case JCS_YCCK:
+ emit_byte(cinfo, 2); /* Color transform = 2 */
+ break;
+ default:
+ emit_byte(cinfo, 0); /* Color transform = 0 */
+ break;
+ }
+}
+
+
+/*
+ * These routines allow writing an arbitrary marker with parameters.
+ * The only intended use is to emit COM or APPn markers after calling
+ * write_file_header and before calling write_frame_header.
+ * Other uses are not guaranteed to produce desirable results.
+ * Counting the parameter bytes properly is the caller's responsibility.
+ */
+
+METHODDEF(void)
+write_marker_header (j_compress_ptr cinfo, int marker, unsigned int datalen)
+/* Emit an arbitrary marker header */
+{
+ if (datalen > (unsigned int) 65533) /* safety check */
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ emit_marker(cinfo, (JPEG_MARKER) marker);
+
+ emit_2bytes(cinfo, (int) (datalen + 2)); /* total length */
+}
+
+METHODDEF(void)
+write_marker_byte (j_compress_ptr cinfo, int val)
+/* Emit one byte of marker parameters following write_marker_header */
+{
+ emit_byte(cinfo, val);
+}
+
+
+/*
+ * Write datastream header.
+ * This consists of an SOI and optional APPn markers.
+ * We recommend use of the JFIF marker, but not the Adobe marker,
+ * when using YCbCr or grayscale data. The JFIF marker should NOT
+ * be used for any other JPEG colorspace. The Adobe marker is helpful
+ * to distinguish RGB, CMYK, and YCCK colorspaces.
+ * Note that an application can write additional header markers after
+ * jpeg_start_compress returns.
+ */
+
+METHODDEF(void)
+write_file_header (j_compress_ptr cinfo)
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+
+ emit_marker(cinfo, M_SOI); /* first the SOI */
+
+ /* SOI is defined to reset restart interval to 0 */
+ marker->last_restart_interval = 0;
+
+ if (cinfo->write_JFIF_header) /* next an optional JFIF APP0 */
+ emit_jfif_app0(cinfo);
+ if (cinfo->write_Adobe_marker) /* next an optional Adobe APP14 */
+ emit_adobe_app14(cinfo);
+}
+
+
+/*
+ * Write frame header.
+ * This consists of DQT and SOFn markers.
+ * Note that we do not emit the SOF until we have emitted the DQT(s).
+ * This avoids compatibility problems with incorrect implementations that
+ * try to error-check the quant table numbers as soon as they see the SOF.
+ */
+
+METHODDEF(void)
+write_frame_header (j_compress_ptr cinfo)
+{
+ int ci, prec;
+ int is_baseline;
+ jpeg_component_info *compptr;
+
+ /* Emit DQT for each quantization table.
+ * Note that emit_dqt() suppresses any duplicate tables.
+ */
+ prec = 0;
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ prec += emit_dqt(cinfo, compptr->quant_tbl_no);
+ }
+ /* now prec is nonzero iff there are any 16-bit quant tables. */
+
+ /* Check for a non-baseline specification.
+ * Note we assume that Huffman table numbers won't be changed later.
+ */
+ if (cinfo->arith_code || cinfo->progressive_mode ||
+ cinfo->data_precision != 8) {
+ is_baseline = FALSE;
+ } else {
+ is_baseline = TRUE;
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ if (compptr->dc_tbl_no > 1 || compptr->ac_tbl_no > 1)
+ is_baseline = FALSE;
+ }
+ if (prec && is_baseline) {
+ is_baseline = FALSE;
+ /* If it's baseline except for quantizer size, warn the user */
+ TRACEMS(cinfo, 0, JTRC_16BIT_TABLES);
+ }
+ }
+
+ /* Emit the proper SOF marker */
+ if (cinfo->arith_code) {
+ emit_sof(cinfo, M_SOF9); /* SOF code for arithmetic coding */
+ } else {
+ if (cinfo->progressive_mode)
+ emit_sof(cinfo, M_SOF2); /* SOF code for progressive Huffman */
+ else if (is_baseline)
+ emit_sof(cinfo, M_SOF0); /* SOF code for baseline implementation */
+ else
+ emit_sof(cinfo, M_SOF1); /* SOF code for non-baseline Huffman file */
+ }
+}
+
+
+/*
+ * Write scan header.
+ * This consists of DHT or DAC markers, optional DRI, and SOS.
+ * Compressed data will be written following the SOS.
+ */
+
+METHODDEF(void)
+write_scan_header (j_compress_ptr cinfo)
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+ int i;
+ jpeg_component_info *compptr;
+
+ if (cinfo->arith_code) {
+ /* Emit arith conditioning info. We may have some duplication
+ * if the file has multiple scans, but it's so small it's hardly
+ * worth worrying about.
+ */
+ emit_dac(cinfo);
+ } else {
+ /* Emit Huffman tables.
+ * Note that emit_dht() suppresses any duplicate tables.
+ */
+ for (i = 0; i < cinfo->comps_in_scan; i++) {
+ compptr = cinfo->cur_comp_info[i];
+ if (cinfo->progressive_mode) {
+ /* Progressive mode: only DC or only AC tables are used in one scan */
+ if (cinfo->Ss == 0) {
+ if (cinfo->Ah == 0) /* DC needs no table for refinement scan */
+ emit_dht(cinfo, compptr->dc_tbl_no, FALSE);
+ } else {
+ emit_dht(cinfo, compptr->ac_tbl_no, TRUE);
+ }
+ } else {
+ /* Sequential mode: need both DC and AC tables */
+ emit_dht(cinfo, compptr->dc_tbl_no, FALSE);
+ emit_dht(cinfo, compptr->ac_tbl_no, TRUE);
+ }
+ }
+ }
+
+ /* Emit DRI if required --- note that DRI value could change for each scan.
+ * We avoid wasting space with unnecessary DRIs, however.
+ */
+ if (cinfo->restart_interval != marker->last_restart_interval) {
+ emit_dri(cinfo);
+ marker->last_restart_interval = cinfo->restart_interval;
+ }
+
+ emit_sos(cinfo);
+}
+
+
+/*
+ * Write datastream trailer.
+ */
+
+METHODDEF(void)
+write_file_trailer (j_compress_ptr cinfo)
+{
+ emit_marker(cinfo, M_EOI);
+}
+
+
+/*
+ * Write an abbreviated table-specification datastream.
+ * This consists of SOI, DQT and DHT tables, and EOI.
+ * Any table that is defined and not marked sent_table = TRUE will be
+ * emitted. Note that all tables will be marked sent_table = TRUE at exit.
+ */
+
+METHODDEF(void)
+write_tables_only (j_compress_ptr cinfo)
+{
+ int i;
+
+ emit_marker(cinfo, M_SOI);
+
+ for (i = 0; i < NUM_QUANT_TBLS; i++) {
+ if (cinfo->quant_tbl_ptrs[i] != NULL)
+ (void) emit_dqt(cinfo, i);
+ }
+
+ if (! cinfo->arith_code) {
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ if (cinfo->dc_huff_tbl_ptrs[i] != NULL)
+ emit_dht(cinfo, i, FALSE);
+ if (cinfo->ac_huff_tbl_ptrs[i] != NULL)
+ emit_dht(cinfo, i, TRUE);
+ }
+ }
+
+ emit_marker(cinfo, M_EOI);
+}
+
+
+/*
+ * Initialize the marker writer module.
+ */
+
+GLOBAL(void)
+jinit_marker_writer (j_compress_ptr cinfo)
+{
+ my_marker_ptr marker;
+
+ /* Create the subobject */
+ marker = (my_marker_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_marker_writer));
+ cinfo->marker = (struct jpeg_marker_writer *) marker;
+ /* Initialize method pointers */
+ marker->pub.write_file_header = write_file_header;
+ marker->pub.write_frame_header = write_frame_header;
+ marker->pub.write_scan_header = write_scan_header;
+ marker->pub.write_file_trailer = write_file_trailer;
+ marker->pub.write_tables_only = write_tables_only;
+ marker->pub.write_marker_header = write_marker_header;
+ marker->pub.write_marker_byte = write_marker_byte;
+ /* Initialize private state */
+ marker->last_restart_interval = 0;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcmaster.cpp b/ml/dlib/dlib/external/libjpeg/jcmaster.cpp
new file mode 100644
index 000000000..3e1b1d711
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcmaster.cpp
@@ -0,0 +1,590 @@
+/*
+ * jcmaster.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains master control logic for the JPEG compressor.
+ * These routines are concerned with parameter validation, initial setup,
+ * and inter-pass control (determining the number of passes and the work
+ * to be done in each pass).
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private state */
+
+typedef enum {
+ main_pass, /* input data, also do first output step */
+ huff_opt_pass, /* Huffman code optimization pass */
+ output_pass /* data output pass */
+} c_pass_type;
+
+typedef struct {
+ struct jpeg_comp_master pub; /* public fields */
+
+ c_pass_type pass_type; /* the type of the current pass */
+
+ int pass_number; /* # of passes completed */
+ int total_passes; /* total # of passes needed */
+
+ int scan_number; /* current index in scan_info[] */
+} my_comp_master;
+
+typedef my_comp_master * my_master_ptr;
+
+
+/*
+ * Support routines that do various essential calculations.
+ */
+
+LOCAL(void)
+initial_setup (j_compress_ptr cinfo)
+/* Do computations that are needed before master selection phase */
+{
+ int ci;
+ jpeg_component_info *compptr;
+ long samplesperrow;
+ JDIMENSION jd_samplesperrow;
+
+ /* Sanity check on image dimensions */
+ if (cinfo->image_height <= 0 || cinfo->image_width <= 0
+ || cinfo->num_components <= 0 || cinfo->input_components <= 0)
+ ERREXIT(cinfo, JERR_EMPTY_IMAGE);
+
+ /* Make sure image isn't bigger than I can handle */
+ if ((long) cinfo->image_height > (long) JPEG_MAX_DIMENSION ||
+ (long) cinfo->image_width > (long) JPEG_MAX_DIMENSION)
+ ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) JPEG_MAX_DIMENSION);
+
+ /* Width of an input scanline must be representable as JDIMENSION. */
+ samplesperrow = (long) cinfo->image_width * (long) cinfo->input_components;
+ jd_samplesperrow = (JDIMENSION) samplesperrow;
+ if ((long) jd_samplesperrow != samplesperrow)
+ ERREXIT(cinfo, JERR_WIDTH_OVERFLOW);
+
+ /* For now, precision must match compiled-in value... */
+ if (cinfo->data_precision != BITS_IN_JSAMPLE)
+ ERREXIT1(cinfo, JERR_BAD_PRECISION, cinfo->data_precision);
+
+ /* Check that number of components won't exceed internal array sizes */
+ if (cinfo->num_components > MAX_COMPONENTS)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components,
+ MAX_COMPONENTS);
+
+ /* Compute maximum sampling factors; check factor validity */
+ cinfo->max_h_samp_factor = 1;
+ cinfo->max_v_samp_factor = 1;
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ if (compptr->h_samp_factor<=0 || compptr->h_samp_factor>MAX_SAMP_FACTOR ||
+ compptr->v_samp_factor<=0 || compptr->v_samp_factor>MAX_SAMP_FACTOR)
+ ERREXIT(cinfo, JERR_BAD_SAMPLING);
+ cinfo->max_h_samp_factor = MAX(cinfo->max_h_samp_factor,
+ compptr->h_samp_factor);
+ cinfo->max_v_samp_factor = MAX(cinfo->max_v_samp_factor,
+ compptr->v_samp_factor);
+ }
+
+ /* Compute dimensions of components */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Fill in the correct component_index value; don't rely on application */
+ compptr->component_index = ci;
+ /* For compression, we never do DCT scaling. */
+ compptr->DCT_scaled_size = DCTSIZE;
+ /* Size in DCT blocks */
+ compptr->width_in_blocks = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor,
+ (long) (cinfo->max_h_samp_factor * DCTSIZE));
+ compptr->height_in_blocks = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor,
+ (long) (cinfo->max_v_samp_factor * DCTSIZE));
+ /* Size in samples */
+ compptr->downsampled_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor,
+ (long) cinfo->max_h_samp_factor);
+ compptr->downsampled_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor,
+ (long) cinfo->max_v_samp_factor);
+ /* Mark component needed (this flag isn't actually used for compression) */
+ compptr->component_needed = TRUE;
+ }
+
+ /* Compute number of fully interleaved MCU rows (number of times that
+ * main controller will call coefficient controller).
+ */
+ cinfo->total_iMCU_rows = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height,
+ (long) (cinfo->max_v_samp_factor*DCTSIZE));
+}
+
+
+#ifdef C_MULTISCAN_FILES_SUPPORTED
+
+LOCAL(void)
+validate_script (j_compress_ptr cinfo)
+/* Verify that the scan script in cinfo->scan_info[] is valid; also
+ * determine whether it uses progressive JPEG, and set cinfo->progressive_mode.
+ */
+{
+ const jpeg_scan_info * scanptr;
+ int scanno, ncomps, ci, coefi, thisi;
+ int Ss, Se, Ah, Al;
+ int component_sent[MAX_COMPONENTS];
+#ifdef C_PROGRESSIVE_SUPPORTED
+ int * last_bitpos_ptr;
+ int last_bitpos[MAX_COMPONENTS][DCTSIZE2];
+ /* -1 until that coefficient has been seen; then last Al for it */
+#endif
+
+ if (cinfo->num_scans <= 0)
+ ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, 0);
+
+ /* For sequential JPEG, all scans must have Ss=0, Se=DCTSIZE2-1;
+ * for progressive JPEG, no scan can have this.
+ */
+ scanptr = cinfo->scan_info;
+ if (scanptr->Ss != 0 || scanptr->Se != DCTSIZE2-1) {
+#ifdef C_PROGRESSIVE_SUPPORTED
+ cinfo->progressive_mode = TRUE;
+ last_bitpos_ptr = & last_bitpos[0][0];
+ for (ci = 0; ci < cinfo->num_components; ci++)
+ for (coefi = 0; coefi < DCTSIZE2; coefi++)
+ *last_bitpos_ptr++ = -1;
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ cinfo->progressive_mode = FALSE;
+ for (ci = 0; ci < cinfo->num_components; ci++)
+ component_sent[ci] = FALSE;
+ }
+
+ for (scanno = 1; scanno <= cinfo->num_scans; scanptr++, scanno++) {
+ /* Validate component indexes */
+ ncomps = scanptr->comps_in_scan;
+ if (ncomps <= 0 || ncomps > MAX_COMPS_IN_SCAN)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, ncomps, MAX_COMPS_IN_SCAN);
+ for (ci = 0; ci < ncomps; ci++) {
+ thisi = scanptr->component_index[ci];
+ if (thisi < 0 || thisi >= cinfo->num_components)
+ ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno);
+ /* Components must appear in SOF order within each scan */
+ if (ci > 0 && thisi <= scanptr->component_index[ci-1])
+ ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno);
+ }
+ /* Validate progression parameters */
+ Ss = scanptr->Ss;
+ Se = scanptr->Se;
+ Ah = scanptr->Ah;
+ Al = scanptr->Al;
+ if (cinfo->progressive_mode) {
+#ifdef C_PROGRESSIVE_SUPPORTED
+ /* The JPEG spec simply gives the ranges 0..13 for Ah and Al, but that
+ * seems wrong: the upper bound ought to depend on data precision.
+ * Perhaps they really meant 0..N+1 for N-bit precision.
+ * Here we allow 0..10 for 8-bit data; Al larger than 10 results in
+ * out-of-range reconstructed DC values during the first DC scan,
+ * which might cause problems for some decoders.
+ */
+#if BITS_IN_JSAMPLE == 8
+#define MAX_AH_AL 10
+#else
+#define MAX_AH_AL 13
+#endif
+ if (Ss < 0 || Ss >= DCTSIZE2 || Se < Ss || Se >= DCTSIZE2 ||
+ Ah < 0 || Ah > MAX_AH_AL || Al < 0 || Al > MAX_AH_AL)
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ if (Ss == 0) {
+ if (Se != 0) /* DC and AC together not OK */
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ } else {
+ if (ncomps != 1) /* AC scans must be for only one component */
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ }
+ for (ci = 0; ci < ncomps; ci++) {
+ last_bitpos_ptr = & last_bitpos[scanptr->component_index[ci]][0];
+ if (Ss != 0 && last_bitpos_ptr[0] < 0) /* AC without prior DC scan */
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ for (coefi = Ss; coefi <= Se; coefi++) {
+ if (last_bitpos_ptr[coefi] < 0) {
+ /* first scan of this coefficient */
+ if (Ah != 0)
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ } else {
+ /* not first scan */
+ if (Ah != last_bitpos_ptr[coefi] || Al != Ah-1)
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ }
+ last_bitpos_ptr[coefi] = Al;
+ }
+ }
+#endif
+ } else {
+ /* For sequential JPEG, all progression parameters must be these: */
+ if (Ss != 0 || Se != DCTSIZE2-1 || Ah != 0 || Al != 0)
+ ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno);
+ /* Make sure components are not sent twice */
+ for (ci = 0; ci < ncomps; ci++) {
+ thisi = scanptr->component_index[ci];
+ if (component_sent[thisi])
+ ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno);
+ component_sent[thisi] = TRUE;
+ }
+ }
+ }
+
+ /* Now verify that everything got sent. */
+ if (cinfo->progressive_mode) {
+#ifdef C_PROGRESSIVE_SUPPORTED
+ /* For progressive mode, we only check that at least some DC data
+ * got sent for each component; the spec does not require that all bits
+ * of all coefficients be transmitted. Would it be wiser to enforce
+ * transmission of all coefficient bits??
+ */
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ if (last_bitpos[ci][0] < 0)
+ ERREXIT(cinfo, JERR_MISSING_DATA);
+ }
+#endif
+ } else {
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ if (! component_sent[ci])
+ ERREXIT(cinfo, JERR_MISSING_DATA);
+ }
+ }
+}
+
+#endif /* C_MULTISCAN_FILES_SUPPORTED */
+
+
+LOCAL(void)
+select_scan_parameters (j_compress_ptr cinfo)
+/* Set up the scan parameters for the current scan */
+{
+ int ci;
+
+#ifdef C_MULTISCAN_FILES_SUPPORTED
+ if (cinfo->scan_info != NULL) {
+ /* Prepare for current scan --- the script is already validated */
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+ const jpeg_scan_info * scanptr = cinfo->scan_info + master->scan_number;
+
+ cinfo->comps_in_scan = scanptr->comps_in_scan;
+ for (ci = 0; ci < scanptr->comps_in_scan; ci++) {
+ cinfo->cur_comp_info[ci] =
+ &cinfo->comp_info[scanptr->component_index[ci]];
+ }
+ cinfo->Ss = scanptr->Ss;
+ cinfo->Se = scanptr->Se;
+ cinfo->Ah = scanptr->Ah;
+ cinfo->Al = scanptr->Al;
+ }
+ else
+#endif
+ {
+ /* Prepare for single sequential-JPEG scan containing all components */
+ if (cinfo->num_components > MAX_COMPS_IN_SCAN)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components,
+ MAX_COMPS_IN_SCAN);
+ cinfo->comps_in_scan = cinfo->num_components;
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ cinfo->cur_comp_info[ci] = &cinfo->comp_info[ci];
+ }
+ cinfo->Ss = 0;
+ cinfo->Se = DCTSIZE2-1;
+ cinfo->Ah = 0;
+ cinfo->Al = 0;
+ }
+}
+
+
+LOCAL(void)
+per_scan_setup (j_compress_ptr cinfo)
+/* Do computations that are needed before processing a JPEG scan */
+/* cinfo->comps_in_scan and cinfo->cur_comp_info[] are already set */
+{
+ int ci, mcublks, tmp;
+ jpeg_component_info *compptr;
+
+ if (cinfo->comps_in_scan == 1) {
+
+ /* Noninterleaved (single-component) scan */
+ compptr = cinfo->cur_comp_info[0];
+
+ /* Overall image size in MCUs */
+ cinfo->MCUs_per_row = compptr->width_in_blocks;
+ cinfo->MCU_rows_in_scan = compptr->height_in_blocks;
+
+ /* For noninterleaved scan, always one block per MCU */
+ compptr->MCU_width = 1;
+ compptr->MCU_height = 1;
+ compptr->MCU_blocks = 1;
+ compptr->MCU_sample_width = DCTSIZE;
+ compptr->last_col_width = 1;
+ /* For noninterleaved scans, it is convenient to define last_row_height
+ * as the number of block rows present in the last iMCU row.
+ */
+ tmp = (int) (compptr->height_in_blocks % compptr->v_samp_factor);
+ if (tmp == 0) tmp = compptr->v_samp_factor;
+ compptr->last_row_height = tmp;
+
+ /* Prepare array describing MCU composition */
+ cinfo->blocks_in_MCU = 1;
+ cinfo->MCU_membership[0] = 0;
+
+ } else {
+
+ /* Interleaved (multi-component) scan */
+ if (cinfo->comps_in_scan <= 0 || cinfo->comps_in_scan > MAX_COMPS_IN_SCAN)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->comps_in_scan,
+ MAX_COMPS_IN_SCAN);
+
+ /* Overall image size in MCUs */
+ cinfo->MCUs_per_row = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width,
+ (long) (cinfo->max_h_samp_factor*DCTSIZE));
+ cinfo->MCU_rows_in_scan = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height,
+ (long) (cinfo->max_v_samp_factor*DCTSIZE));
+
+ cinfo->blocks_in_MCU = 0;
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* Sampling factors give # of blocks of component in each MCU */
+ compptr->MCU_width = compptr->h_samp_factor;
+ compptr->MCU_height = compptr->v_samp_factor;
+ compptr->MCU_blocks = compptr->MCU_width * compptr->MCU_height;
+ compptr->MCU_sample_width = compptr->MCU_width * DCTSIZE;
+ /* Figure number of non-dummy blocks in last MCU column & row */
+ tmp = (int) (compptr->width_in_blocks % compptr->MCU_width);
+ if (tmp == 0) tmp = compptr->MCU_width;
+ compptr->last_col_width = tmp;
+ tmp = (int) (compptr->height_in_blocks % compptr->MCU_height);
+ if (tmp == 0) tmp = compptr->MCU_height;
+ compptr->last_row_height = tmp;
+ /* Prepare array describing MCU composition */
+ mcublks = compptr->MCU_blocks;
+ if (cinfo->blocks_in_MCU + mcublks > C_MAX_BLOCKS_IN_MCU)
+ ERREXIT(cinfo, JERR_BAD_MCU_SIZE);
+ while (mcublks-- > 0) {
+ cinfo->MCU_membership[cinfo->blocks_in_MCU++] = ci;
+ }
+ }
+
+ }
+
+ /* Convert restart specified in rows to actual MCU count. */
+ /* Note that count must fit in 16 bits, so we provide limiting. */
+ if (cinfo->restart_in_rows > 0) {
+ long nominal = (long) cinfo->restart_in_rows * (long) cinfo->MCUs_per_row;
+ cinfo->restart_interval = (unsigned int) MIN(nominal, 65535L);
+ }
+}
+
+
+/*
+ * Per-pass setup.
+ * This is called at the beginning of each pass. We determine which modules
+ * will be active during this pass and give them appropriate start_pass calls.
+ * We also set is_last_pass to indicate whether any more passes will be
+ * required.
+ */
+
+METHODDEF(void)
+prepare_for_pass (j_compress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+
+ switch (master->pass_type) {
+ case main_pass:
+ /* Initial pass: will collect input data, and do either Huffman
+ * optimization or data output for the first scan.
+ */
+ select_scan_parameters(cinfo);
+ per_scan_setup(cinfo);
+ if (! cinfo->raw_data_in) {
+ (*cinfo->cconvert->start_pass) (cinfo);
+ (*cinfo->downsample->start_pass) (cinfo);
+ (*cinfo->prep->start_pass) (cinfo, JBUF_PASS_THRU);
+ }
+ (*cinfo->fdct->start_pass) (cinfo);
+ (*cinfo->entropy->start_pass) (cinfo, cinfo->optimize_coding);
+ (*cinfo->coef->start_pass) (cinfo,
+ (master->total_passes > 1 ?
+ JBUF_SAVE_AND_PASS : JBUF_PASS_THRU));
+ (*cinfo->main->start_pass) (cinfo, JBUF_PASS_THRU);
+ if (cinfo->optimize_coding) {
+ /* No immediate data output; postpone writing frame/scan headers */
+ master->pub.call_pass_startup = FALSE;
+ } else {
+ /* Will write frame/scan headers at first jpeg_write_scanlines call */
+ master->pub.call_pass_startup = TRUE;
+ }
+ break;
+#ifdef ENTROPY_OPT_SUPPORTED
+ case huff_opt_pass:
+ /* Do Huffman optimization for a scan after the first one. */
+ select_scan_parameters(cinfo);
+ per_scan_setup(cinfo);
+ if (cinfo->Ss != 0 || cinfo->Ah == 0 || cinfo->arith_code) {
+ (*cinfo->entropy->start_pass) (cinfo, TRUE);
+ (*cinfo->coef->start_pass) (cinfo, JBUF_CRANK_DEST);
+ master->pub.call_pass_startup = FALSE;
+ break;
+ }
+ /* Special case: Huffman DC refinement scans need no Huffman table
+ * and therefore we can skip the optimization pass for them.
+ */
+ master->pass_type = output_pass;
+ master->pass_number++;
+ /*FALLTHROUGH*/
+#endif
+ case output_pass:
+ /* Do a data-output pass. */
+ /* We need not repeat per-scan setup if prior optimization pass did it. */
+ if (! cinfo->optimize_coding) {
+ select_scan_parameters(cinfo);
+ per_scan_setup(cinfo);
+ }
+ (*cinfo->entropy->start_pass) (cinfo, FALSE);
+ (*cinfo->coef->start_pass) (cinfo, JBUF_CRANK_DEST);
+ /* We emit frame/scan headers now */
+ if (master->scan_number == 0)
+ (*cinfo->marker->write_frame_header) (cinfo);
+ (*cinfo->marker->write_scan_header) (cinfo);
+ master->pub.call_pass_startup = FALSE;
+ break;
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ }
+
+ master->pub.is_last_pass = (master->pass_number == master->total_passes-1);
+
+ /* Set up progress monitor's pass info if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->completed_passes = master->pass_number;
+ cinfo->progress->total_passes = master->total_passes;
+ }
+}
+
+
+/*
+ * Special start-of-pass hook.
+ * This is called by jpeg_write_scanlines if call_pass_startup is TRUE.
+ * In single-pass processing, we need this hook because we don't want to
+ * write frame/scan headers during jpeg_start_compress; we want to let the
+ * application write COM markers etc. between jpeg_start_compress and the
+ * jpeg_write_scanlines loop.
+ * In multi-pass processing, this routine is not used.
+ */
+
+METHODDEF(void)
+pass_startup (j_compress_ptr cinfo)
+{
+ cinfo->master->call_pass_startup = FALSE; /* reset flag so call only once */
+
+ (*cinfo->marker->write_frame_header) (cinfo);
+ (*cinfo->marker->write_scan_header) (cinfo);
+}
+
+
+/*
+ * Finish up at end of pass.
+ */
+
+METHODDEF(void)
+finish_pass_master (j_compress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+
+ /* The entropy coder always needs an end-of-pass call,
+ * either to analyze statistics or to flush its output buffer.
+ */
+ (*cinfo->entropy->finish_pass) (cinfo);
+
+ /* Update state for next pass */
+ switch (master->pass_type) {
+ case main_pass:
+ /* next pass is either output of scan 0 (after optimization)
+ * or output of scan 1 (if no optimization).
+ */
+ master->pass_type = output_pass;
+ if (! cinfo->optimize_coding)
+ master->scan_number++;
+ break;
+ case huff_opt_pass:
+ /* next pass is always output of current scan */
+ master->pass_type = output_pass;
+ break;
+ case output_pass:
+ /* next pass is either optimization or output of next scan */
+ if (cinfo->optimize_coding)
+ master->pass_type = huff_opt_pass;
+ master->scan_number++;
+ break;
+ }
+
+ master->pass_number++;
+}
+
+
+/*
+ * Initialize master compression control.
+ */
+
+GLOBAL(void)
+jinit_c_master_control (j_compress_ptr cinfo, int transcode_only)
+{
+ my_master_ptr master;
+
+ master = (my_master_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_comp_master));
+ cinfo->master = (struct jpeg_comp_master *) master;
+ master->pub.prepare_for_pass = prepare_for_pass;
+ master->pub.pass_startup = pass_startup;
+ master->pub.finish_pass = finish_pass_master;
+ master->pub.is_last_pass = FALSE;
+
+ /* Validate parameters, determine derived values */
+ initial_setup(cinfo);
+
+ if (cinfo->scan_info != NULL) {
+#ifdef C_MULTISCAN_FILES_SUPPORTED
+ validate_script(cinfo);
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ cinfo->progressive_mode = FALSE;
+ cinfo->num_scans = 1;
+ }
+
+ if (cinfo->progressive_mode) /* TEMPORARY HACK ??? */
+ cinfo->optimize_coding = TRUE; /* assume default tables no good for progressive mode */
+
+ /* Initialize my private state */
+ if (transcode_only) {
+ /* no main pass in transcoding */
+ if (cinfo->optimize_coding)
+ master->pass_type = huff_opt_pass;
+ else
+ master->pass_type = output_pass;
+ } else {
+ /* for normal compression, first pass is always this type: */
+ master->pass_type = main_pass;
+ }
+ master->scan_number = 0;
+ master->pass_number = 0;
+ if (cinfo->optimize_coding)
+ master->total_passes = cinfo->num_scans * 2;
+ else
+ master->total_passes = cinfo->num_scans;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcomapi.cpp b/ml/dlib/dlib/external/libjpeg/jcomapi.cpp
new file mode 100644
index 000000000..9b1fa7568
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcomapi.cpp
@@ -0,0 +1,106 @@
+/*
+ * jcomapi.c
+ *
+ * Copyright (C) 1994-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains application interface routines that are used for both
+ * compression and decompression.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Abort processing of a JPEG compression or decompression operation,
+ * but don't destroy the object itself.
+ *
+ * For this, we merely clean up all the nonpermanent memory pools.
+ * Note that temp files (virtual arrays) are not allowed to belong to
+ * the permanent pool, so we will be able to close all temp files here.
+ * Closing a data source or destination, if necessary, is the application's
+ * responsibility.
+ */
+
+GLOBAL(void)
+jpeg_abort (j_common_ptr cinfo)
+{
+ int pool;
+
+ /* Do nothing if called on a not-initialized or destroyed JPEG object. */
+ if (cinfo->mem == NULL)
+ return;
+
+ /* Releasing pools in reverse order might help avoid fragmentation
+ * with some (brain-damaged) malloc libraries.
+ */
+ for (pool = JPOOL_NUMPOOLS-1; pool > JPOOL_PERMANENT; pool--) {
+ (*cinfo->mem->free_pool) (cinfo, pool);
+ }
+
+ /* Reset overall state for possible reuse of object */
+ if (cinfo->is_decompressor) {
+ cinfo->global_state = DSTATE_START;
+ /* Try to keep application from accessing now-deleted marker list.
+ * A bit kludgy to do it here, but this is the most central place.
+ */
+ ((j_decompress_ptr) cinfo)->marker_list = NULL;
+ } else {
+ cinfo->global_state = CSTATE_START;
+ }
+}
+
+
+/*
+ * Destruction of a JPEG object.
+ *
+ * Everything gets deallocated except the master jpeg_compress_struct itself
+ * and the error manager struct. Both of these are supplied by the application
+ * and must be freed, if necessary, by the application. (Often they are on
+ * the stack and so don't need to be freed anyway.)
+ * Closing a data source or destination, if necessary, is the application's
+ * responsibility.
+ */
+
+GLOBAL(void)
+jpeg_destroy (j_common_ptr cinfo)
+{
+ /* We need only tell the memory manager to release everything. */
+ /* NB: mem pointer is NULL if memory mgr failed to initialize. */
+ if (cinfo->mem != NULL)
+ (*cinfo->mem->self_destruct) (cinfo);
+ cinfo->mem = NULL; /* be safe if jpeg_destroy is called twice */
+ cinfo->global_state = 0; /* mark it destroyed */
+}
+
+
+/*
+ * Convenience routines for allocating quantization and Huffman tables.
+ * (Would jutils.c be a more reasonable place to put these?)
+ */
+
+GLOBAL(JQUANT_TBL *)
+jpeg_alloc_quant_table (j_common_ptr cinfo)
+{
+ JQUANT_TBL *tbl;
+
+ tbl = (JQUANT_TBL *)
+ (*cinfo->mem->alloc_small) (cinfo, JPOOL_PERMANENT, SIZEOF(JQUANT_TBL));
+ tbl->sent_table = FALSE; /* make sure this is false in any new table */
+ return tbl;
+}
+
+
+GLOBAL(JHUFF_TBL *)
+jpeg_alloc_huff_table (j_common_ptr cinfo)
+{
+ JHUFF_TBL *tbl;
+
+ tbl = (JHUFF_TBL *)
+ (*cinfo->mem->alloc_small) (cinfo, JPOOL_PERMANENT, SIZEOF(JHUFF_TBL));
+ tbl->sent_table = FALSE; /* make sure this is false in any new table */
+ return tbl;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jconfig.h b/ml/dlib/dlib/external/libjpeg/jconfig.h
new file mode 100644
index 000000000..9594ec56b
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jconfig.h
@@ -0,0 +1,45 @@
+/* jconfig.h. Generated automatically by configure. */
+/* jconfig.cfg --- source file edited by configure script */
+/* see jconfig.doc for explanations */
+
+#define HAVE_PROTOTYPES
+#define HAVE_UNSIGNED_CHAR
+#define HAVE_UNSIGNED_SHORT
+#undef void
+#undef const
+#undef CHAR_IS_UNSIGNED
+#define HAVE_STDDEF_H
+#define HAVE_STDLIB_H
+#undef NEED_BSD_STRINGS
+#undef NEED_SYS_TYPES_H
+#undef NEED_FAR_POINTERS
+#undef NEED_SHORT_EXTERNAL_NAMES
+/* Define this if you get warnings about undefined structures. */
+#undef INCOMPLETE_TYPES_BROKEN
+
+#ifdef JPEG_INTERNALS
+
+#undef RIGHT_SHIFT_IS_UNSIGNED
+#define INLINE __inline__
+/* These are for configuring the JPEG memory manager. */
+#undef DEFAULT_MAX_MEM
+#undef NO_MKTEMP
+
+#endif /* JPEG_INTERNALS */
+
+#ifdef JPEG_CJPEG_DJPEG
+
+#define BMP_SUPPORTED /* BMP image file format */
+#define GIF_SUPPORTED /* GIF image file format */
+#define PPM_SUPPORTED /* PBMPLUS PPM/PGM image file format */
+#undef RLE_SUPPORTED /* Utah RLE image file format */
+#define TARGA_SUPPORTED /* Targa image file format */
+
+#undef TWO_FILE_COMMANDLINE
+#undef NEED_SIGNAL_CATCHER
+#undef DONT_USE_B_MODE
+
+/* Define this if you want percent-done progress reports from cjpeg/djpeg. */
+#undef PROGRESS_REPORT
+
+#endif /* JPEG_CJPEG_DJPEG */
diff --git a/ml/dlib/dlib/external/libjpeg/jcparam.cpp b/ml/dlib/dlib/external/libjpeg/jcparam.cpp
new file mode 100644
index 000000000..a87ef2079
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcparam.cpp
@@ -0,0 +1,610 @@
+/*
+ * jcparam.c
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains optional default-setting code for the JPEG compressor.
+ * Applications do not have to use this file, but those that don't use it
+ * must know a lot more about the innards of the JPEG code.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Quantization table setup routines
+ */
+
+GLOBAL(void)
+jpeg_add_quant_table (j_compress_ptr cinfo, int which_tbl,
+ const unsigned int *basic_table,
+ int scale_factor, int force_baseline)
+/* Define a quantization table equal to the basic_table times
+ * a scale factor (given as a percentage).
+ * If force_baseline is TRUE, the computed quantization table entries
+ * are limited to 1..255 for JPEG baseline compatibility.
+ */
+{
+ JQUANT_TBL ** qtblptr;
+ int i;
+ long temp;
+
+ /* Safety check to ensure start_compress not called yet. */
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ if (which_tbl < 0 || which_tbl >= NUM_QUANT_TBLS)
+ ERREXIT1(cinfo, JERR_DQT_INDEX, which_tbl);
+
+ qtblptr = & cinfo->quant_tbl_ptrs[which_tbl];
+
+ if (*qtblptr == NULL)
+ *qtblptr = jpeg_alloc_quant_table((j_common_ptr) cinfo);
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ temp = ((long) basic_table[i] * scale_factor + 50L) / 100L;
+ /* limit the values to the valid range */
+ if (temp <= 0L) temp = 1L;
+ if (temp > 32767L) temp = 32767L; /* max quantizer needed for 12 bits */
+ if (force_baseline && temp > 255L)
+ temp = 255L; /* limit to baseline range if requested */
+ (*qtblptr)->quantval[i] = (unsigned short) temp;
+ }
+
+ /* Initialize sent_table FALSE so table will be written to JPEG file. */
+ (*qtblptr)->sent_table = FALSE;
+}
+
+
+GLOBAL(void)
+jpeg_set_linear_quality (j_compress_ptr cinfo, int scale_factor,
+ int force_baseline)
+/* Set or change the 'quality' (quantization) setting, using default tables
+ * and a straight percentage-scaling quality scale. In most cases it's better
+ * to use jpeg_set_quality (below); this entry point is provided for
+ * applications that insist on a linear percentage scaling.
+ */
+{
+ /* These are the sample quantization tables given in JPEG spec section K.1.
+ * The spec says that the values given produce "good" quality, and
+ * when divided by 2, "very good" quality.
+ */
+ static const unsigned int std_luminance_quant_tbl[DCTSIZE2] = {
+ 16, 11, 10, 16, 24, 40, 51, 61,
+ 12, 12, 14, 19, 26, 58, 60, 55,
+ 14, 13, 16, 24, 40, 57, 69, 56,
+ 14, 17, 22, 29, 51, 87, 80, 62,
+ 18, 22, 37, 56, 68, 109, 103, 77,
+ 24, 35, 55, 64, 81, 104, 113, 92,
+ 49, 64, 78, 87, 103, 121, 120, 101,
+ 72, 92, 95, 98, 112, 100, 103, 99
+ };
+ static const unsigned int std_chrominance_quant_tbl[DCTSIZE2] = {
+ 17, 18, 24, 47, 99, 99, 99, 99,
+ 18, 21, 26, 66, 99, 99, 99, 99,
+ 24, 26, 56, 99, 99, 99, 99, 99,
+ 47, 66, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99,
+ 99, 99, 99, 99, 99, 99, 99, 99
+ };
+
+ /* Set up two quantization tables using the specified scaling */
+ jpeg_add_quant_table(cinfo, 0, std_luminance_quant_tbl,
+ scale_factor, force_baseline);
+ jpeg_add_quant_table(cinfo, 1, std_chrominance_quant_tbl,
+ scale_factor, force_baseline);
+}
+
+
+GLOBAL(int)
+jpeg_quality_scaling (int quality)
+/* Convert a user-specified quality rating to a percentage scaling factor
+ * for an underlying quantization table, using our recommended scaling curve.
+ * The input 'quality' factor should be 0 (terrible) to 100 (very good).
+ */
+{
+ /* Safety limit on quality factor. Convert 0 to 1 to avoid zero divide. */
+ if (quality <= 0) quality = 1;
+ if (quality > 100) quality = 100;
+
+ /* The basic table is used as-is (scaling 100) for a quality of 50.
+ * Qualities 50..100 are converted to scaling percentage 200 - 2*Q;
+ * note that at Q=100 the scaling is 0, which will cause jpeg_add_quant_table
+ * to make all the table entries 1 (hence, minimum quantization loss).
+ * Qualities 1..50 are converted to scaling percentage 5000/Q.
+ */
+ if (quality < 50)
+ quality = 5000 / quality;
+ else
+ quality = 200 - quality*2;
+
+ return quality;
+}
+
+
+GLOBAL(void)
+jpeg_set_quality (j_compress_ptr cinfo, int quality, int force_baseline)
+/* Set or change the 'quality' (quantization) setting, using default tables.
+ * This is the standard quality-adjusting entry point for typical user
+ * interfaces; only those who want detailed control over quantization tables
+ * would use the preceding three routines directly.
+ */
+{
+ /* Convert user 0-100 rating to percentage scaling */
+ quality = jpeg_quality_scaling(quality);
+
+ /* Set up standard quality tables */
+ jpeg_set_linear_quality(cinfo, quality, force_baseline);
+}
+
+
+/*
+ * Huffman table setup routines
+ */
+
+LOCAL(void)
+add_huff_table (j_compress_ptr cinfo,
+ JHUFF_TBL **htblptr, const unsigned char *bits, const unsigned char *val)
+/* Define a Huffman table */
+{
+ int nsymbols, len;
+
+ if (*htblptr == NULL)
+ *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo);
+
+ /* Copy the number-of-symbols-of-each-code-length counts */
+ MEMCOPY((*htblptr)->bits, bits, SIZEOF((*htblptr)->bits));
+
+ /* Validate the counts. We do this here mainly so we can copy the right
+ * number of symbols from the val[] array, without risking marching off
+ * the end of memory. jchuff.c will do a more thorough test later.
+ */
+ nsymbols = 0;
+ for (len = 1; len <= 16; len++)
+ nsymbols += bits[len];
+ if (nsymbols < 1 || nsymbols > 256)
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+
+ MEMCOPY((*htblptr)->huffval, val, nsymbols * SIZEOF(unsigned char));
+
+ /* Initialize sent_table FALSE so table will be written to JPEG file. */
+ (*htblptr)->sent_table = FALSE;
+}
+
+
+LOCAL(void)
+std_huff_tables (j_compress_ptr cinfo)
+/* Set up the standard Huffman tables (cf. JPEG standard section K.3) */
+/* IMPORTANT: these are only valid for 8-bit data precision! */
+{
+ static const unsigned char bits_dc_luminance[17] =
+ { /* 0-base */ 0, 0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 };
+ static const unsigned char val_dc_luminance[] =
+ { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
+
+ static const unsigned char bits_dc_chrominance[17] =
+ { /* 0-base */ 0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 };
+ static const unsigned char val_dc_chrominance[] =
+ { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
+
+ static const unsigned char bits_ac_luminance[17] =
+ { /* 0-base */ 0, 0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 0x7d };
+ static const unsigned char val_ac_luminance[] =
+ { 0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12,
+ 0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07,
+ 0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08,
+ 0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0,
+ 0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16,
+ 0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28,
+ 0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39,
+ 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49,
+ 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59,
+ 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
+ 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
+ 0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89,
+ 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98,
+ 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
+ 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6,
+ 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5,
+ 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4,
+ 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2,
+ 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea,
+ 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8,
+ 0xf9, 0xfa };
+
+ static const unsigned char bits_ac_chrominance[17] =
+ { /* 0-base */ 0, 0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 0x77 };
+ static const unsigned char val_ac_chrominance[] =
+ { 0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21,
+ 0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71,
+ 0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91,
+ 0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0,
+ 0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34,
+ 0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26,
+ 0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38,
+ 0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48,
+ 0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58,
+ 0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
+ 0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78,
+ 0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
+ 0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96,
+ 0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5,
+ 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4,
+ 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3,
+ 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2,
+ 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda,
+ 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9,
+ 0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8,
+ 0xf9, 0xfa };
+
+ add_huff_table(cinfo, &cinfo->dc_huff_tbl_ptrs[0],
+ bits_dc_luminance, val_dc_luminance);
+ add_huff_table(cinfo, &cinfo->ac_huff_tbl_ptrs[0],
+ bits_ac_luminance, val_ac_luminance);
+ add_huff_table(cinfo, &cinfo->dc_huff_tbl_ptrs[1],
+ bits_dc_chrominance, val_dc_chrominance);
+ add_huff_table(cinfo, &cinfo->ac_huff_tbl_ptrs[1],
+ bits_ac_chrominance, val_ac_chrominance);
+}
+
+
+/*
+ * Default parameter setup for compression.
+ *
+ * Applications that don't choose to use this routine must do their
+ * own setup of all these parameters. Alternately, you can call this
+ * to establish defaults and then alter parameters selectively. This
+ * is the recommended approach since, if we add any new parameters,
+ * your code will still work (they'll be set to reasonable defaults).
+ */
+
+GLOBAL(void)
+jpeg_set_defaults (j_compress_ptr cinfo)
+{
+ int i;
+
+ /* Safety check to ensure start_compress not called yet. */
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ /* Allocate comp_info array large enough for maximum component count.
+ * Array is made permanent in case application wants to compress
+ * multiple images at same param settings.
+ */
+ if (cinfo->comp_info == NULL)
+ cinfo->comp_info = (jpeg_component_info *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ MAX_COMPONENTS * SIZEOF(jpeg_component_info));
+
+ /* Initialize everything not dependent on the color space */
+
+ cinfo->data_precision = BITS_IN_JSAMPLE;
+ /* Set up two quantization tables using default quality of 75 */
+ jpeg_set_quality(cinfo, 75, TRUE);
+ /* Set up two Huffman tables */
+ std_huff_tables(cinfo);
+
+ /* Initialize default arithmetic coding conditioning */
+ for (i = 0; i < NUM_ARITH_TBLS; i++) {
+ cinfo->arith_dc_L[i] = 0;
+ cinfo->arith_dc_U[i] = 1;
+ cinfo->arith_ac_K[i] = 5;
+ }
+
+ /* Default is no multiple-scan output */
+ cinfo->scan_info = NULL;
+ cinfo->num_scans = 0;
+
+ /* Expect normal source image, not raw downsampled data */
+ cinfo->raw_data_in = FALSE;
+
+ /* Use Huffman coding, not arithmetic coding, by default */
+ cinfo->arith_code = FALSE;
+
+ /* By default, don't do extra passes to optimize entropy coding */
+ cinfo->optimize_coding = FALSE;
+ /* The standard Huffman tables are only valid for 8-bit data precision.
+ * If the precision is higher, force optimization on so that usable
+ * tables will be computed. This test can be removed if default tables
+ * are supplied that are valid for the desired precision.
+ */
+ if (cinfo->data_precision > 8)
+ cinfo->optimize_coding = TRUE;
+
+ /* By default, use the simpler non-cosited sampling alignment */
+ cinfo->CCIR601_sampling = FALSE;
+
+ /* No input smoothing */
+ cinfo->smoothing_factor = 0;
+
+ /* DCT algorithm preference */
+ cinfo->dct_method = JDCT_DEFAULT;
+
+ /* No restart markers */
+ cinfo->restart_interval = 0;
+ cinfo->restart_in_rows = 0;
+
+ /* Fill in default JFIF marker parameters. Note that whether the marker
+ * will actually be written is determined by jpeg_set_colorspace.
+ *
+ * By default, the library emits JFIF version code 1.01.
+ * An application that wants to emit JFIF 1.02 extension markers should set
+ * JFIF_minor_version to 2. We could probably get away with just defaulting
+ * to 1.02, but there may still be some decoders in use that will complain
+ * about that; saying 1.01 should minimize compatibility problems.
+ */
+ cinfo->JFIF_major_version = 1; /* Default JFIF version = 1.01 */
+ cinfo->JFIF_minor_version = 1;
+ cinfo->density_unit = 0; /* Pixel size is unknown by default */
+ cinfo->X_density = 1; /* Pixel aspect ratio is square by default */
+ cinfo->Y_density = 1;
+
+ /* Choose JPEG colorspace based on input space, set defaults accordingly */
+
+ jpeg_default_colorspace(cinfo);
+}
+
+
+/*
+ * Select an appropriate JPEG colorspace for in_color_space.
+ */
+
+GLOBAL(void)
+jpeg_default_colorspace (j_compress_ptr cinfo)
+{
+ switch (cinfo->in_color_space) {
+ case JCS_GRAYSCALE:
+ jpeg_set_colorspace(cinfo, JCS_GRAYSCALE);
+ break;
+ case JCS_RGB:
+ jpeg_set_colorspace(cinfo, JCS_YCbCr);
+ break;
+ case JCS_YCbCr:
+ jpeg_set_colorspace(cinfo, JCS_YCbCr);
+ break;
+ case JCS_CMYK:
+ jpeg_set_colorspace(cinfo, JCS_CMYK); /* By default, no translation */
+ break;
+ case JCS_YCCK:
+ jpeg_set_colorspace(cinfo, JCS_YCCK);
+ break;
+ case JCS_UNKNOWN:
+ jpeg_set_colorspace(cinfo, JCS_UNKNOWN);
+ break;
+ default:
+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE);
+ }
+}
+
+
+/*
+ * Set the JPEG colorspace, and choose colorspace-dependent default values.
+ */
+
+GLOBAL(void)
+jpeg_set_colorspace (j_compress_ptr cinfo, J_COLOR_SPACE colorspace)
+{
+ jpeg_component_info * compptr;
+ int ci;
+
+#define SET_COMP(index,id,hsamp,vsamp,quant,dctbl,actbl) \
+ (compptr = &cinfo->comp_info[index], \
+ compptr->component_id = (id), \
+ compptr->h_samp_factor = (hsamp), \
+ compptr->v_samp_factor = (vsamp), \
+ compptr->quant_tbl_no = (quant), \
+ compptr->dc_tbl_no = (dctbl), \
+ compptr->ac_tbl_no = (actbl) )
+
+ /* Safety check to ensure start_compress not called yet. */
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ /* For all colorspaces, we use Q and Huff tables 0 for luminance components,
+ * tables 1 for chrominance components.
+ */
+
+ cinfo->jpeg_color_space = colorspace;
+
+ cinfo->write_JFIF_header = FALSE; /* No marker for non-JFIF colorspaces */
+ cinfo->write_Adobe_marker = FALSE; /* write no Adobe marker by default */
+
+ switch (colorspace) {
+ case JCS_GRAYSCALE:
+ cinfo->write_JFIF_header = TRUE; /* Write a JFIF marker */
+ cinfo->num_components = 1;
+ /* JFIF specifies component ID 1 */
+ SET_COMP(0, 1, 1,1, 0, 0,0);
+ break;
+ case JCS_RGB:
+ cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag RGB */
+ cinfo->num_components = 3;
+ SET_COMP(0, 0x52 /* 'R' */, 1,1, 0, 0,0);
+ SET_COMP(1, 0x47 /* 'G' */, 1,1, 0, 0,0);
+ SET_COMP(2, 0x42 /* 'B' */, 1,1, 0, 0,0);
+ break;
+ case JCS_YCbCr:
+ cinfo->write_JFIF_header = TRUE; /* Write a JFIF marker */
+ cinfo->num_components = 3;
+ /* JFIF specifies component IDs 1,2,3 */
+ /* We default to 2x2 subsamples of chrominance */
+ SET_COMP(0, 1, 2,2, 0, 0,0);
+ SET_COMP(1, 2, 1,1, 1, 1,1);
+ SET_COMP(2, 3, 1,1, 1, 1,1);
+ break;
+ case JCS_CMYK:
+ cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag CMYK */
+ cinfo->num_components = 4;
+ SET_COMP(0, 0x43 /* 'C' */, 1,1, 0, 0,0);
+ SET_COMP(1, 0x4D /* 'M' */, 1,1, 0, 0,0);
+ SET_COMP(2, 0x59 /* 'Y' */, 1,1, 0, 0,0);
+ SET_COMP(3, 0x4B /* 'K' */, 1,1, 0, 0,0);
+ break;
+ case JCS_YCCK:
+ cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag YCCK */
+ cinfo->num_components = 4;
+ SET_COMP(0, 1, 2,2, 0, 0,0);
+ SET_COMP(1, 2, 1,1, 1, 1,1);
+ SET_COMP(2, 3, 1,1, 1, 1,1);
+ SET_COMP(3, 4, 2,2, 0, 0,0);
+ break;
+ case JCS_UNKNOWN:
+ cinfo->num_components = cinfo->input_components;
+ if (cinfo->num_components < 1 || cinfo->num_components > MAX_COMPONENTS)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components,
+ MAX_COMPONENTS);
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ SET_COMP(ci, ci, 1,1, 0, 0,0);
+ }
+ break;
+ default:
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ }
+}
+
+
+#ifdef C_PROGRESSIVE_SUPPORTED
+
+LOCAL(jpeg_scan_info *)
+fill_a_scan (jpeg_scan_info * scanptr, int ci,
+ int Ss, int Se, int Ah, int Al)
+/* Support routine: generate one scan for specified component */
+{
+ scanptr->comps_in_scan = 1;
+ scanptr->component_index[0] = ci;
+ scanptr->Ss = Ss;
+ scanptr->Se = Se;
+ scanptr->Ah = Ah;
+ scanptr->Al = Al;
+ scanptr++;
+ return scanptr;
+}
+
+LOCAL(jpeg_scan_info *)
+fill_scans (jpeg_scan_info * scanptr, int ncomps,
+ int Ss, int Se, int Ah, int Al)
+/* Support routine: generate one scan for each component */
+{
+ int ci;
+
+ for (ci = 0; ci < ncomps; ci++) {
+ scanptr->comps_in_scan = 1;
+ scanptr->component_index[0] = ci;
+ scanptr->Ss = Ss;
+ scanptr->Se = Se;
+ scanptr->Ah = Ah;
+ scanptr->Al = Al;
+ scanptr++;
+ }
+ return scanptr;
+}
+
+LOCAL(jpeg_scan_info *)
+fill_dc_scans (jpeg_scan_info * scanptr, int ncomps, int Ah, int Al)
+/* Support routine: generate interleaved DC scan if possible, else N scans */
+{
+ int ci;
+
+ if (ncomps <= MAX_COMPS_IN_SCAN) {
+ /* Single interleaved DC scan */
+ scanptr->comps_in_scan = ncomps;
+ for (ci = 0; ci < ncomps; ci++)
+ scanptr->component_index[ci] = ci;
+ scanptr->Ss = scanptr->Se = 0;
+ scanptr->Ah = Ah;
+ scanptr->Al = Al;
+ scanptr++;
+ } else {
+ /* Noninterleaved DC scan for each component */
+ scanptr = fill_scans(scanptr, ncomps, 0, 0, Ah, Al);
+ }
+ return scanptr;
+}
+
+
+/*
+ * Create a recommended progressive-JPEG script.
+ * cinfo->num_components and cinfo->jpeg_color_space must be correct.
+ */
+
+GLOBAL(void)
+jpeg_simple_progression (j_compress_ptr cinfo)
+{
+ int ncomps = cinfo->num_components;
+ int nscans;
+ jpeg_scan_info * scanptr;
+
+ /* Safety check to ensure start_compress not called yet. */
+ if (cinfo->global_state != CSTATE_START)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ /* Figure space needed for script. Calculation must match code below! */
+ if (ncomps == 3 && cinfo->jpeg_color_space == JCS_YCbCr) {
+ /* Custom script for YCbCr color images. */
+ nscans = 10;
+ } else {
+ /* All-purpose script for other color spaces. */
+ if (ncomps > MAX_COMPS_IN_SCAN)
+ nscans = 6 * ncomps; /* 2 DC + 4 AC scans per component */
+ else
+ nscans = 2 + 4 * ncomps; /* 2 DC scans; 4 AC scans per component */
+ }
+
+ /* Allocate space for script.
+ * We need to put it in the permanent pool in case the application performs
+ * multiple compressions without changing the settings. To avoid a memory
+ * leak if jpeg_simple_progression is called repeatedly for the same JPEG
+ * object, we try to re-use previously allocated space, and we allocate
+ * enough space to handle YCbCr even if initially asked for grayscale.
+ */
+ if (cinfo->script_space == NULL || cinfo->script_space_size < nscans) {
+ cinfo->script_space_size = MAX(nscans, 10);
+ cinfo->script_space = (jpeg_scan_info *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ cinfo->script_space_size * SIZEOF(jpeg_scan_info));
+ }
+ scanptr = cinfo->script_space;
+ cinfo->scan_info = scanptr;
+ cinfo->num_scans = nscans;
+
+ if (ncomps == 3 && cinfo->jpeg_color_space == JCS_YCbCr) {
+ /* Custom script for YCbCr color images. */
+ /* Initial DC scan */
+ scanptr = fill_dc_scans(scanptr, ncomps, 0, 1);
+ /* Initial AC scan: get some luma data out in a hurry */
+ scanptr = fill_a_scan(scanptr, 0, 1, 5, 0, 2);
+ /* Chroma data is too small to be worth expending many scans on */
+ scanptr = fill_a_scan(scanptr, 2, 1, 63, 0, 1);
+ scanptr = fill_a_scan(scanptr, 1, 1, 63, 0, 1);
+ /* Complete spectral selection for luma AC */
+ scanptr = fill_a_scan(scanptr, 0, 6, 63, 0, 2);
+ /* Refine next bit of luma AC */
+ scanptr = fill_a_scan(scanptr, 0, 1, 63, 2, 1);
+ /* Finish DC successive approximation */
+ scanptr = fill_dc_scans(scanptr, ncomps, 1, 0);
+ /* Finish AC successive approximation */
+ scanptr = fill_a_scan(scanptr, 2, 1, 63, 1, 0);
+ scanptr = fill_a_scan(scanptr, 1, 1, 63, 1, 0);
+ /* Luma bottom bit comes last since it's usually largest scan */
+ scanptr = fill_a_scan(scanptr, 0, 1, 63, 1, 0);
+ } else {
+ /* All-purpose script for other color spaces. */
+ /* Successive approximation first pass */
+ scanptr = fill_dc_scans(scanptr, ncomps, 0, 1);
+ scanptr = fill_scans(scanptr, ncomps, 1, 5, 0, 2);
+ scanptr = fill_scans(scanptr, ncomps, 6, 63, 0, 2);
+ /* Successive approximation second pass */
+ scanptr = fill_scans(scanptr, ncomps, 1, 63, 2, 1);
+ /* Successive approximation final pass */
+ scanptr = fill_dc_scans(scanptr, ncomps, 1, 0);
+ scanptr = fill_scans(scanptr, ncomps, 1, 63, 1, 0);
+ }
+}
+
+#endif /* C_PROGRESSIVE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jcphuff.cpp b/ml/dlib/dlib/external/libjpeg/jcphuff.cpp
new file mode 100644
index 000000000..66a85b8c7
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcphuff.cpp
@@ -0,0 +1,833 @@
+/*
+ * jcphuff.c
+ *
+ * Copyright (C) 1995-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains Huffman entropy encoding routines for progressive JPEG.
+ *
+ * We do not support output suspension in this module, since the library
+ * currently does not allow multiple-scan files to be written with output
+ * suspension.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jchuff.h" /* Declarations shared with jchuff.c */
+
+#ifdef C_PROGRESSIVE_SUPPORTED
+
+/* Expanded entropy encoder object for progressive Huffman encoding. */
+
+typedef struct {
+ struct jpeg_entropy_encoder pub; /* public fields */
+
+ /* Mode flag: TRUE for optimization, FALSE for actual data output */
+ int gather_statistics;
+
+ /* Bit-level coding status.
+ * next_output_byte/free_in_buffer are local copies of cinfo->dest fields.
+ */
+ JOCTET * next_output_byte; /* => next byte to write in buffer */
+ size_t free_in_buffer; /* # of byte spaces remaining in buffer */
+ long put_buffer; /* current bit-accumulation buffer */
+ int put_bits; /* # of bits now in it */
+ j_compress_ptr cinfo; /* link to cinfo (needed for dump_buffer) */
+
+ /* Coding status for DC components */
+ int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */
+
+ /* Coding status for AC components */
+ int ac_tbl_no; /* the table number of the single component */
+ unsigned int EOBRUN; /* run length of EOBs */
+ unsigned int BE; /* # of buffered correction bits before MCU */
+ char * bit_buffer; /* buffer for correction bits (1 per char) */
+ /* packing correction bits tightly would save some space but cost time... */
+
+ unsigned int restarts_to_go; /* MCUs left in this restart interval */
+ int next_restart_num; /* next restart number to write (0-7) */
+
+ /* Pointers to derived tables (these workspaces have image lifespan).
+ * Since any one scan codes only DC or only AC, we only need one set
+ * of tables, not one for DC and one for AC.
+ */
+ c_derived_tbl * derived_tbls[NUM_HUFF_TBLS];
+
+ /* Statistics tables for optimization; again, one set is enough */
+ long * count_ptrs[NUM_HUFF_TBLS];
+} phuff_entropy_encoder;
+
+typedef phuff_entropy_encoder * phuff_entropy_ptr;
+
+/* MAX_CORR_BITS is the number of bits the AC refinement correction-bit
+ * buffer can hold. Larger sizes may slightly improve compression, but
+ * 1000 is already well into the realm of overkill.
+ * The minimum safe size is 64 bits.
+ */
+
+#define MAX_CORR_BITS 1000 /* Max # of correction bits I can buffer */
+
+/* IRIGHT_SHIFT is like RIGHT_SHIFT, but works on int rather than long.
+ * We assume that int right shift is unsigned if long right shift is,
+ * which should be safe.
+ */
+
+#ifdef RIGHT_SHIFT_IS_UNSIGNED
+#define ISHIFT_TEMPS int ishift_temp;
+#define IRIGHT_SHIFT(x,shft) \
+ ((ishift_temp = (x)) < 0 ? \
+ (ishift_temp >> (shft)) | ((~0) << (16-(shft))) : \
+ (ishift_temp >> (shft)))
+#else
+#define ISHIFT_TEMPS
+#define IRIGHT_SHIFT(x,shft) ((x) >> (shft))
+#endif
+
+/* Forward declarations */
+METHODDEF(int) encode_mcu_DC_first JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) encode_mcu_AC_first JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) encode_mcu_DC_refine JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) encode_mcu_AC_refine JPP((j_compress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(void) finish_pass_phuff JPP((j_compress_ptr cinfo));
+METHODDEF(void) finish_pass_gather_phuff JPP((j_compress_ptr cinfo));
+
+
+/*
+ * Initialize for a Huffman-compressed scan using progressive JPEG.
+ */
+
+METHODDEF(void)
+start_pass_phuff (j_compress_ptr cinfo, int gather_statistics)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int is_DC_band;
+ int ci, tbl;
+ jpeg_component_info * compptr;
+
+ entropy->cinfo = cinfo;
+ entropy->gather_statistics = gather_statistics;
+
+ is_DC_band = (cinfo->Ss == 0);
+
+ /* We assume jcmaster.c already validated the scan parameters. */
+
+ /* Select execution routines */
+ if (cinfo->Ah == 0) {
+ if (is_DC_band)
+ entropy->pub.encode_mcu = encode_mcu_DC_first;
+ else
+ entropy->pub.encode_mcu = encode_mcu_AC_first;
+ } else {
+ if (is_DC_band)
+ entropy->pub.encode_mcu = encode_mcu_DC_refine;
+ else {
+ entropy->pub.encode_mcu = encode_mcu_AC_refine;
+ /* AC refinement needs a correction bit buffer */
+ if (entropy->bit_buffer == NULL)
+ entropy->bit_buffer = (char *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ MAX_CORR_BITS * SIZEOF(char));
+ }
+ }
+ if (gather_statistics)
+ entropy->pub.finish_pass = finish_pass_gather_phuff;
+ else
+ entropy->pub.finish_pass = finish_pass_phuff;
+
+ /* Only DC coefficients may be interleaved, so cinfo->comps_in_scan = 1
+ * for AC coefficients.
+ */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* Initialize DC predictions to 0 */
+ entropy->last_dc_val[ci] = 0;
+ /* Get table index */
+ if (is_DC_band) {
+ if (cinfo->Ah != 0) /* DC refinement needs no table */
+ continue;
+ tbl = compptr->dc_tbl_no;
+ } else {
+ entropy->ac_tbl_no = tbl = compptr->ac_tbl_no;
+ }
+ if (gather_statistics) {
+ /* Check for invalid table index */
+ /* (make_c_derived_tbl does this in the other path) */
+ if (tbl < 0 || tbl >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tbl);
+ /* Allocate and zero the statistics tables */
+ /* Note that jpeg_gen_optimal_table expects 257 entries in each table! */
+ if (entropy->count_ptrs[tbl] == NULL)
+ entropy->count_ptrs[tbl] = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ 257 * SIZEOF(long));
+ MEMZERO(entropy->count_ptrs[tbl], 257 * SIZEOF(long));
+ } else {
+ /* Compute derived values for Huffman table */
+ /* We may do this more than once for a table, but it's not expensive */
+ jpeg_make_c_derived_tbl(cinfo, is_DC_band, tbl,
+ & entropy->derived_tbls[tbl]);
+ }
+ }
+
+ /* Initialize AC stuff */
+ entropy->EOBRUN = 0;
+ entropy->BE = 0;
+
+ /* Initialize bit buffer to empty */
+ entropy->put_buffer = 0;
+ entropy->put_bits = 0;
+
+ /* Initialize restart stuff */
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num = 0;
+}
+
+
+/* Outputting bytes to the file.
+ * NB: these must be called only when actually outputting,
+ * that is, entropy->gather_statistics == FALSE.
+ */
+
+/* Emit a byte */
+#define emit_byte(entropy,val) \
+ { *(entropy)->next_output_byte++ = (JOCTET) (val); \
+ if (--(entropy)->free_in_buffer == 0) \
+ dump_buffer(entropy); }
+
+
+LOCAL(void)
+dump_buffer (phuff_entropy_ptr entropy)
+/* Empty the output buffer; we do not support suspension in this module. */
+{
+ struct jpeg_destination_mgr * dest = entropy->cinfo->dest;
+
+ if (! (*dest->empty_output_buffer) (entropy->cinfo))
+ ERREXIT(entropy->cinfo, JERR_CANT_SUSPEND);
+ /* After a successful buffer dump, must reset buffer pointers */
+ entropy->next_output_byte = dest->next_output_byte;
+ entropy->free_in_buffer = dest->free_in_buffer;
+}
+
+
+/* Outputting bits to the file */
+
+/* Only the right 24 bits of put_buffer are used; the valid bits are
+ * left-justified in this part. At most 16 bits can be passed to emit_bits
+ * in one call, and we never retain more than 7 bits in put_buffer
+ * between calls, so 24 bits are sufficient.
+ */
+
+inline
+LOCAL(void)
+emit_bits (phuff_entropy_ptr entropy, unsigned int code, int size)
+/* Emit some bits, unless we are in gather mode */
+{
+ /* This routine is heavily used, so it's worth coding tightly. */
+ long put_buffer = (long) code;
+ int put_bits = entropy->put_bits;
+
+ /* if size is 0, caller used an invalid Huffman table entry */
+ if (size == 0)
+ ERREXIT(entropy->cinfo, JERR_HUFF_MISSING_CODE);
+
+ if (entropy->gather_statistics)
+ return; /* do nothing if we're only getting stats */
+
+ put_buffer &= (((long) 1)<<size) - 1; /* mask off any extra bits in code */
+
+ put_bits += size; /* new number of bits in buffer */
+
+ put_buffer <<= 24 - put_bits; /* align incoming bits */
+
+ put_buffer |= entropy->put_buffer; /* and merge with old buffer contents */
+
+ while (put_bits >= 8) {
+ int c = (int) ((put_buffer >> 16) & 0xFF);
+
+ emit_byte(entropy, c);
+ if (c == 0xFF) { /* need to stuff a zero byte? */
+ emit_byte(entropy, 0);
+ }
+ put_buffer <<= 8;
+ put_bits -= 8;
+ }
+
+ entropy->put_buffer = put_buffer; /* update variables */
+ entropy->put_bits = put_bits;
+}
+
+
+LOCAL(void)
+flush_bits (phuff_entropy_ptr entropy)
+{
+ emit_bits(entropy, 0x7F, 7); /* fill any partial byte with ones */
+ entropy->put_buffer = 0; /* and reset bit-buffer to empty */
+ entropy->put_bits = 0;
+}
+
+
+/*
+ * Emit (or just count) a Huffman symbol.
+ */
+
+inline
+LOCAL(void)
+emit_symbol (phuff_entropy_ptr entropy, int tbl_no, int symbol)
+{
+ if (entropy->gather_statistics)
+ entropy->count_ptrs[tbl_no][symbol]++;
+ else {
+ c_derived_tbl * tbl = entropy->derived_tbls[tbl_no];
+ emit_bits(entropy, tbl->ehufco[symbol], tbl->ehufsi[symbol]);
+ }
+}
+
+
+/*
+ * Emit bits from a correction bit buffer.
+ */
+
+LOCAL(void)
+emit_buffered_bits (phuff_entropy_ptr entropy, char * bufstart,
+ unsigned int nbits)
+{
+ if (entropy->gather_statistics)
+ return; /* no real work */
+
+ while (nbits > 0) {
+ emit_bits(entropy, (unsigned int) (*bufstart), 1);
+ bufstart++;
+ nbits--;
+ }
+}
+
+
+/*
+ * Emit any pending EOBRUN symbol.
+ */
+
+LOCAL(void)
+emit_eobrun (phuff_entropy_ptr entropy)
+{
+ int temp, nbits;
+
+ if (entropy->EOBRUN > 0) { /* if there is any pending EOBRUN */
+ temp = entropy->EOBRUN;
+ nbits = 0;
+ while ((temp >>= 1))
+ nbits++;
+ /* safety check: shouldn't happen given limited correction-bit buffer */
+ if (nbits > 14)
+ ERREXIT(entropy->cinfo, JERR_HUFF_MISSING_CODE);
+
+ emit_symbol(entropy, entropy->ac_tbl_no, nbits << 4);
+ if (nbits)
+ emit_bits(entropy, entropy->EOBRUN, nbits);
+
+ entropy->EOBRUN = 0;
+
+ /* Emit any buffered correction bits */
+ emit_buffered_bits(entropy, entropy->bit_buffer, entropy->BE);
+ entropy->BE = 0;
+ }
+}
+
+
+/*
+ * Emit a restart marker & resynchronize predictions.
+ */
+
+LOCAL(void)
+emit_restart (phuff_entropy_ptr entropy, int restart_num)
+{
+ int ci;
+
+ emit_eobrun(entropy);
+
+ if (! entropy->gather_statistics) {
+ flush_bits(entropy);
+ emit_byte(entropy, 0xFF);
+ emit_byte(entropy, JPEG_RST0 + restart_num);
+ }
+
+ if (entropy->cinfo->Ss == 0) {
+ /* Re-initialize DC predictions to 0 */
+ for (ci = 0; ci < entropy->cinfo->comps_in_scan; ci++)
+ entropy->last_dc_val[ci] = 0;
+ } else {
+ /* Re-initialize all AC-related fields to 0 */
+ entropy->EOBRUN = 0;
+ entropy->BE = 0;
+ }
+}
+
+
+/*
+ * MCU encoding for DC initial scan (either spectral selection,
+ * or first pass of successive approximation).
+ */
+
+METHODDEF(int)
+encode_mcu_DC_first (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int temp, temp2;
+ int nbits;
+ int blkn, ci;
+ int Al = cinfo->Al;
+ JBLOCKROW block;
+ jpeg_component_info * compptr;
+ ISHIFT_TEMPS
+
+ entropy->next_output_byte = cinfo->dest->next_output_byte;
+ entropy->free_in_buffer = cinfo->dest->free_in_buffer;
+
+ /* Emit restart marker if needed */
+ if (cinfo->restart_interval)
+ if (entropy->restarts_to_go == 0)
+ emit_restart(entropy, entropy->next_restart_num);
+
+ /* Encode the MCU data blocks */
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ block = MCU_data[blkn];
+ ci = cinfo->MCU_membership[blkn];
+ compptr = cinfo->cur_comp_info[ci];
+
+ /* Compute the DC value after the required point transform by Al.
+ * This is simply an arithmetic right shift.
+ */
+ temp2 = IRIGHT_SHIFT((int) ((*block)[0]), Al);
+
+ /* DC differences are figured on the point-transformed values. */
+ temp = temp2 - entropy->last_dc_val[ci];
+ entropy->last_dc_val[ci] = temp2;
+
+ /* Encode the DC coefficient difference per section G.1.2.1 */
+ temp2 = temp;
+ if (temp < 0) {
+ temp = -temp; /* temp is abs value of input */
+ /* For a negative input, want temp2 = bitwise complement of abs(input) */
+ /* This code assumes we are on a two's complement machine */
+ temp2--;
+ }
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 0;
+ while (temp) {
+ nbits++;
+ temp >>= 1;
+ }
+ /* Check for out-of-range coefficient values.
+ * Since we're encoding a difference, the range limit is twice as much.
+ */
+ if (nbits > MAX_COEF_BITS+1)
+ ERREXIT(cinfo, JERR_BAD_DCT_COEF);
+
+ /* Count/emit the Huffman-coded symbol for the number of bits */
+ emit_symbol(entropy, compptr->dc_tbl_no, nbits);
+
+ /* Emit that number of bits of the value, if positive, */
+ /* or the complement of its magnitude, if negative. */
+ if (nbits) /* emit_bits rejects calls with size 0 */
+ emit_bits(entropy, (unsigned int) temp2, nbits);
+ }
+
+ cinfo->dest->next_output_byte = entropy->next_output_byte;
+ cinfo->dest->free_in_buffer = entropy->free_in_buffer;
+
+ /* Update restart-interval state too */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num++;
+ entropy->next_restart_num &= 7;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * MCU encoding for AC initial scan (either spectral selection,
+ * or first pass of successive approximation).
+ */
+
+METHODDEF(int)
+encode_mcu_AC_first (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int temp, temp2;
+ int nbits;
+ int r, k;
+ int Se = cinfo->Se;
+ int Al = cinfo->Al;
+ JBLOCKROW block;
+
+ entropy->next_output_byte = cinfo->dest->next_output_byte;
+ entropy->free_in_buffer = cinfo->dest->free_in_buffer;
+
+ /* Emit restart marker if needed */
+ if (cinfo->restart_interval)
+ if (entropy->restarts_to_go == 0)
+ emit_restart(entropy, entropy->next_restart_num);
+
+ /* Encode the MCU data block */
+ block = MCU_data[0];
+
+ /* Encode the AC coefficients per section G.1.2.2, fig. G.3 */
+
+ r = 0; /* r = run length of zeros */
+
+ for (k = cinfo->Ss; k <= Se; k++) {
+ if ((temp = (*block)[jpeg_natural_order[k]]) == 0) {
+ r++;
+ continue;
+ }
+ /* We must apply the point transform by Al. For AC coefficients this
+ * is an integer division with rounding towards 0. To do this portably
+ * in C, we shift after obtaining the absolute value; so the code is
+ * interwoven with finding the abs value (temp) and output bits (temp2).
+ */
+ if (temp < 0) {
+ temp = -temp; /* temp is abs value of input */
+ temp >>= Al; /* apply the point transform */
+ /* For a negative coef, want temp2 = bitwise complement of abs(coef) */
+ temp2 = ~temp;
+ } else {
+ temp >>= Al; /* apply the point transform */
+ temp2 = temp;
+ }
+ /* Watch out for case that nonzero coef is zero after point transform */
+ if (temp == 0) {
+ r++;
+ continue;
+ }
+
+ /* Emit any pending EOBRUN */
+ if (entropy->EOBRUN > 0)
+ emit_eobrun(entropy);
+ /* if run length > 15, must emit special run-length-16 codes (0xF0) */
+ while (r > 15) {
+ emit_symbol(entropy, entropy->ac_tbl_no, 0xF0);
+ r -= 16;
+ }
+
+ /* Find the number of bits needed for the magnitude of the coefficient */
+ nbits = 1; /* there must be at least one 1 bit */
+ while ((temp >>= 1))
+ nbits++;
+ /* Check for out-of-range coefficient values */
+ if (nbits > MAX_COEF_BITS)
+ ERREXIT(cinfo, JERR_BAD_DCT_COEF);
+
+ /* Count/emit Huffman symbol for run length / number of bits */
+ emit_symbol(entropy, entropy->ac_tbl_no, (r << 4) + nbits);
+
+ /* Emit that number of bits of the value, if positive, */
+ /* or the complement of its magnitude, if negative. */
+ emit_bits(entropy, (unsigned int) temp2, nbits);
+
+ r = 0; /* reset zero run length */
+ }
+
+ if (r > 0) { /* If there are trailing zeroes, */
+ entropy->EOBRUN++; /* count an EOB */
+ if (entropy->EOBRUN == 0x7FFF)
+ emit_eobrun(entropy); /* force it out to avoid overflow */
+ }
+
+ cinfo->dest->next_output_byte = entropy->next_output_byte;
+ cinfo->dest->free_in_buffer = entropy->free_in_buffer;
+
+ /* Update restart-interval state too */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num++;
+ entropy->next_restart_num &= 7;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * MCU encoding for DC successive approximation refinement scan.
+ * Note: we assume such scans can be multi-component, although the spec
+ * is not very clear on the point.
+ */
+
+METHODDEF(int)
+encode_mcu_DC_refine (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int temp;
+ int blkn;
+ int Al = cinfo->Al;
+ JBLOCKROW block;
+
+ entropy->next_output_byte = cinfo->dest->next_output_byte;
+ entropy->free_in_buffer = cinfo->dest->free_in_buffer;
+
+ /* Emit restart marker if needed */
+ if (cinfo->restart_interval)
+ if (entropy->restarts_to_go == 0)
+ emit_restart(entropy, entropy->next_restart_num);
+
+ /* Encode the MCU data blocks */
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ block = MCU_data[blkn];
+
+ /* We simply emit the Al'th bit of the DC coefficient value. */
+ temp = (*block)[0];
+ emit_bits(entropy, (unsigned int) (temp >> Al), 1);
+ }
+
+ cinfo->dest->next_output_byte = entropy->next_output_byte;
+ cinfo->dest->free_in_buffer = entropy->free_in_buffer;
+
+ /* Update restart-interval state too */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num++;
+ entropy->next_restart_num &= 7;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * MCU encoding for AC successive approximation refinement scan.
+ */
+
+METHODDEF(int)
+encode_mcu_AC_refine (j_compress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int temp;
+ int r, k;
+ int EOB;
+ char *BR_buffer;
+ unsigned int BR;
+ int Se = cinfo->Se;
+ int Al = cinfo->Al;
+ JBLOCKROW block;
+ int absvalues[DCTSIZE2];
+
+ entropy->next_output_byte = cinfo->dest->next_output_byte;
+ entropy->free_in_buffer = cinfo->dest->free_in_buffer;
+
+ /* Emit restart marker if needed */
+ if (cinfo->restart_interval)
+ if (entropy->restarts_to_go == 0)
+ emit_restart(entropy, entropy->next_restart_num);
+
+ /* Encode the MCU data block */
+ block = MCU_data[0];
+
+ /* It is convenient to make a pre-pass to determine the transformed
+ * coefficients' absolute values and the EOB position.
+ */
+ EOB = 0;
+ for (k = cinfo->Ss; k <= Se; k++) {
+ temp = (*block)[jpeg_natural_order[k]];
+ /* We must apply the point transform by Al. For AC coefficients this
+ * is an integer division with rounding towards 0. To do this portably
+ * in C, we shift after obtaining the absolute value.
+ */
+ if (temp < 0)
+ temp = -temp; /* temp is abs value of input */
+ temp >>= Al; /* apply the point transform */
+ absvalues[k] = temp; /* save abs value for main pass */
+ if (temp == 1)
+ EOB = k; /* EOB = index of last newly-nonzero coef */
+ }
+
+ /* Encode the AC coefficients per section G.1.2.3, fig. G.7 */
+
+ r = 0; /* r = run length of zeros */
+ BR = 0; /* BR = count of buffered bits added now */
+ BR_buffer = entropy->bit_buffer + entropy->BE; /* Append bits to buffer */
+
+ for (k = cinfo->Ss; k <= Se; k++) {
+ if ((temp = absvalues[k]) == 0) {
+ r++;
+ continue;
+ }
+
+ /* Emit any required ZRLs, but not if they can be folded into EOB */
+ while (r > 15 && k <= EOB) {
+ /* emit any pending EOBRUN and the BE correction bits */
+ emit_eobrun(entropy);
+ /* Emit ZRL */
+ emit_symbol(entropy, entropy->ac_tbl_no, 0xF0);
+ r -= 16;
+ /* Emit buffered correction bits that must be associated with ZRL */
+ emit_buffered_bits(entropy, BR_buffer, BR);
+ BR_buffer = entropy->bit_buffer; /* BE bits are gone now */
+ BR = 0;
+ }
+
+ /* If the coef was previously nonzero, it only needs a correction bit.
+ * NOTE: a straight translation of the spec's figure G.7 would suggest
+ * that we also need to test r > 15. But if r > 15, we can only get here
+ * if k > EOB, which implies that this coefficient is not 1.
+ */
+ if (temp > 1) {
+ /* The correction bit is the next bit of the absolute value. */
+ BR_buffer[BR++] = (char) (temp & 1);
+ continue;
+ }
+
+ /* Emit any pending EOBRUN and the BE correction bits */
+ emit_eobrun(entropy);
+
+ /* Count/emit Huffman symbol for run length / number of bits */
+ emit_symbol(entropy, entropy->ac_tbl_no, (r << 4) + 1);
+
+ /* Emit output bit for newly-nonzero coef */
+ temp = ((*block)[jpeg_natural_order[k]] < 0) ? 0 : 1;
+ emit_bits(entropy, (unsigned int) temp, 1);
+
+ /* Emit buffered correction bits that must be associated with this code */
+ emit_buffered_bits(entropy, BR_buffer, BR);
+ BR_buffer = entropy->bit_buffer; /* BE bits are gone now */
+ BR = 0;
+ r = 0; /* reset zero run length */
+ }
+
+ if (r > 0 || BR > 0) { /* If there are trailing zeroes, */
+ entropy->EOBRUN++; /* count an EOB */
+ entropy->BE += BR; /* concat my correction bits to older ones */
+ /* We force out the EOB if we risk either:
+ * 1. overflow of the EOB counter;
+ * 2. overflow of the correction bit buffer during the next MCU.
+ */
+ if (entropy->EOBRUN == 0x7FFF || entropy->BE > (MAX_CORR_BITS-DCTSIZE2+1))
+ emit_eobrun(entropy);
+ }
+
+ cinfo->dest->next_output_byte = entropy->next_output_byte;
+ cinfo->dest->free_in_buffer = entropy->free_in_buffer;
+
+ /* Update restart-interval state too */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0) {
+ entropy->restarts_to_go = cinfo->restart_interval;
+ entropy->next_restart_num++;
+ entropy->next_restart_num &= 7;
+ }
+ entropy->restarts_to_go--;
+ }
+
+ return TRUE;
+}
+
+
+/*
+ * Finish up at the end of a Huffman-compressed progressive scan.
+ */
+
+METHODDEF(void)
+finish_pass_phuff (j_compress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+
+ entropy->next_output_byte = cinfo->dest->next_output_byte;
+ entropy->free_in_buffer = cinfo->dest->free_in_buffer;
+
+ /* Flush out any buffered data */
+ emit_eobrun(entropy);
+ flush_bits(entropy);
+
+ cinfo->dest->next_output_byte = entropy->next_output_byte;
+ cinfo->dest->free_in_buffer = entropy->free_in_buffer;
+}
+
+
+/*
+ * Finish up a statistics-gathering pass and create the new Huffman tables.
+ */
+
+METHODDEF(void)
+finish_pass_gather_phuff (j_compress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int is_DC_band;
+ int ci, tbl;
+ jpeg_component_info * compptr;
+ JHUFF_TBL **htblptr;
+ int did[NUM_HUFF_TBLS];
+
+ /* Flush out buffered data (all we care about is counting the EOB symbol) */
+ emit_eobrun(entropy);
+
+ is_DC_band = (cinfo->Ss == 0);
+
+ /* It's important not to apply jpeg_gen_optimal_table more than once
+ * per table, because it clobbers the input frequency counts!
+ */
+ MEMZERO(did, SIZEOF(did));
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ if (is_DC_band) {
+ if (cinfo->Ah != 0) /* DC refinement needs no table */
+ continue;
+ tbl = compptr->dc_tbl_no;
+ } else {
+ tbl = compptr->ac_tbl_no;
+ }
+ if (! did[tbl]) {
+ if (is_DC_band)
+ htblptr = & cinfo->dc_huff_tbl_ptrs[tbl];
+ else
+ htblptr = & cinfo->ac_huff_tbl_ptrs[tbl];
+ if (*htblptr == NULL)
+ *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo);
+ jpeg_gen_optimal_table(cinfo, *htblptr, entropy->count_ptrs[tbl]);
+ did[tbl] = TRUE;
+ }
+ }
+}
+
+
+/*
+ * Module initialization routine for progressive Huffman entropy encoding.
+ */
+
+GLOBAL(void)
+jinit_phuff_encoder (j_compress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy;
+ int i;
+
+ entropy = (phuff_entropy_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(phuff_entropy_encoder));
+ cinfo->entropy = (struct jpeg_entropy_encoder *) entropy;
+ entropy->pub.start_pass = start_pass_phuff;
+
+ /* Mark tables unallocated */
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ entropy->derived_tbls[i] = NULL;
+ entropy->count_ptrs[i] = NULL;
+ }
+ entropy->bit_buffer = NULL; /* needed only in AC refinement scan */
+}
+
+#endif /* C_PROGRESSIVE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jcprepct.cpp b/ml/dlib/dlib/external/libjpeg/jcprepct.cpp
new file mode 100644
index 000000000..d1532c273
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcprepct.cpp
@@ -0,0 +1,354 @@
+/*
+ * jcprepct.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the compression preprocessing controller.
+ * This controller manages the color conversion, downsampling,
+ * and edge expansion steps.
+ *
+ * Most of the complexity here is associated with buffering input rows
+ * as required by the downsampler. See the comments at the head of
+ * jcsample.c for the downsampler's needs.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* At present, jcsample.c can request context rows only for smoothing.
+ * In the future, we might also need context rows for CCIR601 sampling
+ * or other more-complex downsampling procedures. The code to support
+ * context rows should be compiled only if needed.
+ */
+#ifdef INPUT_SMOOTHING_SUPPORTED
+#define CONTEXT_ROWS_SUPPORTED
+#endif
+
+
+/*
+ * For the simple (no-context-row) case, we just need to buffer one
+ * row group's worth of pixels for the downsampling step. At the bottom of
+ * the image, we pad to a full row group by replicating the last pixel row.
+ * The downsampler's last output row is then replicated if needed to pad
+ * out to a full iMCU row.
+ *
+ * When providing context rows, we must buffer three row groups' worth of
+ * pixels. Three row groups are physically allocated, but the row pointer
+ * arrays are made five row groups high, with the extra pointers above and
+ * below "wrapping around" to point to the last and first real row groups.
+ * This allows the downsampler to access the proper context rows.
+ * At the top and bottom of the image, we create dummy context rows by
+ * copying the first or last real pixel row. This copying could be avoided
+ * by pointer hacking as is done in jdmainct.c, but it doesn't seem worth the
+ * trouble on the compression side.
+ */
+
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_c_prep_controller pub; /* public fields */
+
+ /* Downsampling input buffer. This buffer holds color-converted data
+ * until we have enough to do a downsample step.
+ */
+ JSAMPARRAY color_buf[MAX_COMPONENTS];
+
+ JDIMENSION rows_to_go; /* counts rows remaining in source image */
+ int next_buf_row; /* index of next row to store in color_buf */
+
+#ifdef CONTEXT_ROWS_SUPPORTED /* only needed for context case */
+ int this_row_group; /* starting row index of group to process */
+ int next_buf_stop; /* downsample when we reach this index */
+#endif
+} my_prep_controller;
+
+typedef my_prep_controller * my_prep_ptr;
+
+
+/*
+ * Initialize for a processing pass.
+ */
+
+METHODDEF(void)
+start_pass_prep (j_compress_ptr cinfo, J_BUF_MODE pass_mode)
+{
+ my_prep_ptr prep = (my_prep_ptr) cinfo->prep;
+
+ if (pass_mode != JBUF_PASS_THRU)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+
+ /* Initialize total-height counter for detecting bottom of image */
+ prep->rows_to_go = cinfo->image_height;
+ /* Mark the conversion buffer empty */
+ prep->next_buf_row = 0;
+#ifdef CONTEXT_ROWS_SUPPORTED
+ /* Preset additional state variables for context mode.
+ * These aren't used in non-context mode, so we needn't test which mode.
+ */
+ prep->this_row_group = 0;
+ /* Set next_buf_stop to stop after two row groups have been read in. */
+ prep->next_buf_stop = 2 * cinfo->max_v_samp_factor;
+#endif
+}
+
+
+/*
+ * Expand an image vertically from height input_rows to height output_rows,
+ * by duplicating the bottom row.
+ */
+
+LOCAL(void)
+expand_bottom_edge (JSAMPARRAY image_data, JDIMENSION num_cols,
+ int input_rows, int output_rows)
+{
+ int row;
+
+ for (row = input_rows; row < output_rows; row++) {
+ jcopy_sample_rows(image_data, input_rows-1, image_data, row,
+ 1, num_cols);
+ }
+}
+
+
+/*
+ * Process some data in the simple no-context case.
+ *
+ * Preprocessor output data is counted in "row groups". A row group
+ * is defined to be v_samp_factor sample rows of each component.
+ * Downsampling will produce this much data from each max_v_samp_factor
+ * input rows.
+ */
+
+METHODDEF(void)
+pre_process_data (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail,
+ JSAMPIMAGE output_buf, JDIMENSION *out_row_group_ctr,
+ JDIMENSION out_row_groups_avail)
+{
+ my_prep_ptr prep = (my_prep_ptr) cinfo->prep;
+ int numrows, ci;
+ JDIMENSION inrows;
+ jpeg_component_info * compptr;
+
+ while (*in_row_ctr < in_rows_avail &&
+ *out_row_group_ctr < out_row_groups_avail) {
+ /* Do color conversion to fill the conversion buffer. */
+ inrows = in_rows_avail - *in_row_ctr;
+ numrows = cinfo->max_v_samp_factor - prep->next_buf_row;
+ numrows = (int) MIN((JDIMENSION) numrows, inrows);
+ (*cinfo->cconvert->color_convert) (cinfo, input_buf + *in_row_ctr,
+ prep->color_buf,
+ (JDIMENSION) prep->next_buf_row,
+ numrows);
+ *in_row_ctr += numrows;
+ prep->next_buf_row += numrows;
+ prep->rows_to_go -= numrows;
+ /* If at bottom of image, pad to fill the conversion buffer. */
+ if (prep->rows_to_go == 0 &&
+ prep->next_buf_row < cinfo->max_v_samp_factor) {
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ expand_bottom_edge(prep->color_buf[ci], cinfo->image_width,
+ prep->next_buf_row, cinfo->max_v_samp_factor);
+ }
+ prep->next_buf_row = cinfo->max_v_samp_factor;
+ }
+ /* If we've filled the conversion buffer, empty it. */
+ if (prep->next_buf_row == cinfo->max_v_samp_factor) {
+ (*cinfo->downsample->downsample) (cinfo,
+ prep->color_buf, (JDIMENSION) 0,
+ output_buf, *out_row_group_ctr);
+ prep->next_buf_row = 0;
+ (*out_row_group_ctr)++;
+ }
+ /* If at bottom of image, pad the output to a full iMCU height.
+ * Note we assume the caller is providing a one-iMCU-height output buffer!
+ */
+ if (prep->rows_to_go == 0 &&
+ *out_row_group_ctr < out_row_groups_avail) {
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ expand_bottom_edge(output_buf[ci],
+ compptr->width_in_blocks * DCTSIZE,
+ (int) (*out_row_group_ctr * compptr->v_samp_factor),
+ (int) (out_row_groups_avail * compptr->v_samp_factor));
+ }
+ *out_row_group_ctr = out_row_groups_avail;
+ break; /* can exit outer loop without test */
+ }
+ }
+}
+
+
+#ifdef CONTEXT_ROWS_SUPPORTED
+
+/*
+ * Process some data in the context case.
+ */
+
+METHODDEF(void)
+pre_process_context (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail,
+ JSAMPIMAGE output_buf, JDIMENSION *out_row_group_ctr,
+ JDIMENSION out_row_groups_avail)
+{
+ my_prep_ptr prep = (my_prep_ptr) cinfo->prep;
+ int numrows, ci;
+ int buf_height = cinfo->max_v_samp_factor * 3;
+ JDIMENSION inrows;
+
+ while (*out_row_group_ctr < out_row_groups_avail) {
+ if (*in_row_ctr < in_rows_avail) {
+ /* Do color conversion to fill the conversion buffer. */
+ inrows = in_rows_avail - *in_row_ctr;
+ numrows = prep->next_buf_stop - prep->next_buf_row;
+ numrows = (int) MIN((JDIMENSION) numrows, inrows);
+ (*cinfo->cconvert->color_convert) (cinfo, input_buf + *in_row_ctr,
+ prep->color_buf,
+ (JDIMENSION) prep->next_buf_row,
+ numrows);
+ /* Pad at top of image, if first time through */
+ if (prep->rows_to_go == cinfo->image_height) {
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ int row;
+ for (row = 1; row <= cinfo->max_v_samp_factor; row++) {
+ jcopy_sample_rows(prep->color_buf[ci], 0,
+ prep->color_buf[ci], -row,
+ 1, cinfo->image_width);
+ }
+ }
+ }
+ *in_row_ctr += numrows;
+ prep->next_buf_row += numrows;
+ prep->rows_to_go -= numrows;
+ } else {
+ /* Return for more data, unless we are at the bottom of the image. */
+ if (prep->rows_to_go != 0)
+ break;
+ /* When at bottom of image, pad to fill the conversion buffer. */
+ if (prep->next_buf_row < prep->next_buf_stop) {
+ for (ci = 0; ci < cinfo->num_components; ci++) {
+ expand_bottom_edge(prep->color_buf[ci], cinfo->image_width,
+ prep->next_buf_row, prep->next_buf_stop);
+ }
+ prep->next_buf_row = prep->next_buf_stop;
+ }
+ }
+ /* If we've gotten enough data, downsample a row group. */
+ if (prep->next_buf_row == prep->next_buf_stop) {
+ (*cinfo->downsample->downsample) (cinfo,
+ prep->color_buf,
+ (JDIMENSION) prep->this_row_group,
+ output_buf, *out_row_group_ctr);
+ (*out_row_group_ctr)++;
+ /* Advance pointers with wraparound as necessary. */
+ prep->this_row_group += cinfo->max_v_samp_factor;
+ if (prep->this_row_group >= buf_height)
+ prep->this_row_group = 0;
+ if (prep->next_buf_row >= buf_height)
+ prep->next_buf_row = 0;
+ prep->next_buf_stop = prep->next_buf_row + cinfo->max_v_samp_factor;
+ }
+ }
+}
+
+
+/*
+ * Create the wrapped-around downsampling input buffer needed for context mode.
+ */
+
+LOCAL(void)
+create_context_buffer (j_compress_ptr cinfo)
+{
+ my_prep_ptr prep = (my_prep_ptr) cinfo->prep;
+ int rgroup_height = cinfo->max_v_samp_factor;
+ int ci, i;
+ jpeg_component_info * compptr;
+ JSAMPARRAY true_buffer, fake_buffer;
+
+ /* Grab enough space for fake row pointers for all the components;
+ * we need five row groups' worth of pointers for each component.
+ */
+ fake_buffer = (JSAMPARRAY)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (cinfo->num_components * 5 * rgroup_height) *
+ SIZEOF(JSAMPROW));
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Allocate the actual buffer space (3 row groups) for this component.
+ * We make the buffer wide enough to allow the downsampler to edge-expand
+ * horizontally within the buffer, if it so chooses.
+ */
+ true_buffer = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (JDIMENSION) (((long) compptr->width_in_blocks * DCTSIZE *
+ cinfo->max_h_samp_factor) / compptr->h_samp_factor),
+ (JDIMENSION) (3 * rgroup_height));
+ /* Copy true buffer row pointers into the middle of the fake row array */
+ MEMCOPY(fake_buffer + rgroup_height, true_buffer,
+ 3 * rgroup_height * SIZEOF(JSAMPROW));
+ /* Fill in the above and below wraparound pointers */
+ for (i = 0; i < rgroup_height; i++) {
+ fake_buffer[i] = true_buffer[2 * rgroup_height + i];
+ fake_buffer[4 * rgroup_height + i] = true_buffer[i];
+ }
+ prep->color_buf[ci] = fake_buffer + rgroup_height;
+ fake_buffer += 5 * rgroup_height; /* point to space for next component */
+ }
+}
+
+#endif /* CONTEXT_ROWS_SUPPORTED */
+
+
+/*
+ * Initialize preprocessing controller.
+ */
+
+GLOBAL(void)
+jinit_c_prep_controller (j_compress_ptr cinfo, int need_full_buffer)
+{
+ my_prep_ptr prep;
+ int ci;
+ jpeg_component_info * compptr;
+
+ if (need_full_buffer) /* safety check */
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+
+ prep = (my_prep_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_prep_controller));
+ cinfo->prep = (struct jpeg_c_prep_controller *) prep;
+ prep->pub.start_pass = start_pass_prep;
+
+ /* Allocate the color conversion buffer.
+ * We make the buffer wide enough to allow the downsampler to edge-expand
+ * horizontally within the buffer, if it so chooses.
+ */
+ if (cinfo->downsample->need_context_rows) {
+ /* Set up to provide context rows */
+#ifdef CONTEXT_ROWS_SUPPORTED
+ prep->pub.pre_process_data = pre_process_context;
+ create_context_buffer(cinfo);
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ /* No context, just make it tall enough for one row group */
+ prep->pub.pre_process_data = pre_process_data;
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ prep->color_buf[ci] = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (JDIMENSION) (((long) compptr->width_in_blocks * DCTSIZE *
+ cinfo->max_h_samp_factor) / compptr->h_samp_factor),
+ (JDIMENSION) cinfo->max_v_samp_factor);
+ }
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jcsample.cpp b/ml/dlib/dlib/external/libjpeg/jcsample.cpp
new file mode 100644
index 000000000..b73270120
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jcsample.cpp
@@ -0,0 +1,519 @@
+/*
+ * jcsample.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains downsampling routines.
+ *
+ * Downsampling input data is counted in "row groups". A row group
+ * is defined to be max_v_samp_factor pixel rows of each component,
+ * from which the downsampler produces v_samp_factor sample rows.
+ * A single row group is processed in each call to the downsampler module.
+ *
+ * The downsampler is responsible for edge-expansion of its output data
+ * to fill an integral number of DCT blocks horizontally. The source buffer
+ * may be modified if it is helpful for this purpose (the source buffer is
+ * allocated wide enough to correspond to the desired output width).
+ * The caller (the prep controller) is responsible for vertical padding.
+ *
+ * The downsampler may request "context rows" by setting need_context_rows
+ * during startup. In this case, the input arrays will contain at least
+ * one row group's worth of pixels above and below the passed-in data;
+ * the caller will create dummy rows at image top and bottom by replicating
+ * the first or last real pixel row.
+ *
+ * An excellent reference for image resampling is
+ * Digital Image Warping, George Wolberg, 1990.
+ * Pub. by IEEE Computer Society Press, Los Alamitos, CA. ISBN 0-8186-8944-7.
+ *
+ * The downsampling algorithm used here is a simple average of the source
+ * pixels covered by the output pixel. The hi-falutin sampling literature
+ * refers to this as a "box filter". In general the characteristics of a box
+ * filter are not very good, but for the specific cases we normally use (1:1
+ * and 2:1 ratios) the box is equivalent to a "triangle filter" which is not
+ * nearly so bad. If you intend to use other sampling ratios, you'd be well
+ * advised to improve this code.
+ *
+ * A simple input-smoothing capability is provided. This is mainly intended
+ * for cleaning up color-dithered GIF input files (if you find it inadequate,
+ * we suggest using an external filtering program such as pnmconvol). When
+ * enabled, each input pixel P is replaced by a weighted sum of itself and its
+ * eight neighbors. P's weight is 1-8*SF and each neighbor's weight is SF,
+ * where SF = (smoothing_factor / 1024).
+ * Currently, smoothing is only supported for 2h2v sampling factors.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Pointer to routine to downsample a single component */
+typedef JMETHOD(void, downsample1_ptr,
+ (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data));
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_downsampler pub; /* public fields */
+
+ /* Downsampling method pointers, one per component */
+ downsample1_ptr methods[MAX_COMPONENTS];
+} my_downsampler;
+
+typedef my_downsampler * my_downsample_ptr;
+
+
+/*
+ * Initialize for a downsampling pass.
+ */
+
+METHODDEF(void)
+start_pass_downsample (j_compress_ptr )//cinfo)
+{
+ /* no work for now */
+}
+
+
+/*
+ * Expand a component horizontally from width input_cols to width output_cols,
+ * by duplicating the rightmost samples.
+ */
+
+LOCAL(void)
+expand_right_edge (JSAMPARRAY image_data, int num_rows,
+ JDIMENSION input_cols, JDIMENSION output_cols)
+{
+ JSAMPROW ptr;
+ JSAMPLE pixval;
+ int count;
+ int row;
+ int numcols = (int) (output_cols - input_cols);
+
+ if (numcols > 0) {
+ for (row = 0; row < num_rows; row++) {
+ ptr = image_data[row] + input_cols;
+ pixval = ptr[-1]; /* don't need GETJSAMPLE() here */
+ for (count = numcols; count > 0; count--)
+ *ptr++ = pixval;
+ }
+ }
+}
+
+
+/*
+ * Do downsampling for a whole row group (all components).
+ *
+ * In this version we simply downsample each component independently.
+ */
+
+METHODDEF(void)
+sep_downsample (j_compress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION in_row_index,
+ JSAMPIMAGE output_buf, JDIMENSION out_row_group_index)
+{
+ my_downsample_ptr downsample = (my_downsample_ptr) cinfo->downsample;
+ int ci;
+ jpeg_component_info * compptr;
+ JSAMPARRAY in_ptr, out_ptr;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ in_ptr = input_buf[ci] + in_row_index;
+ out_ptr = output_buf[ci] + (out_row_group_index * compptr->v_samp_factor);
+ (*downsample->methods[ci]) (cinfo, compptr, in_ptr, out_ptr);
+ }
+}
+
+
+/*
+ * Downsample pixel values of a single component.
+ * One row group is processed per call.
+ * This version handles arbitrary integral sampling ratios, without smoothing.
+ * Note that this version is not actually used for customary sampling ratios.
+ */
+
+METHODDEF(void)
+int_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ int inrow, outrow, h_expand, v_expand, numpix, numpix2, h, v;
+ JDIMENSION outcol, outcol_h; /* outcol_h == outcol*h_expand */
+ JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE;
+ JSAMPROW inptr, outptr;
+ long outvalue;
+
+ h_expand = cinfo->max_h_samp_factor / compptr->h_samp_factor;
+ v_expand = cinfo->max_v_samp_factor / compptr->v_samp_factor;
+ numpix = h_expand * v_expand;
+ numpix2 = numpix/2;
+
+ /* Expand input data enough to let all the output samples be generated
+ * by the standard loop. Special-casing padded output would be more
+ * efficient.
+ */
+ expand_right_edge(input_data, cinfo->max_v_samp_factor,
+ cinfo->image_width, output_cols * h_expand);
+
+ inrow = 0;
+ for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) {
+ outptr = output_data[outrow];
+ for (outcol = 0, outcol_h = 0; outcol < output_cols;
+ outcol++, outcol_h += h_expand) {
+ outvalue = 0;
+ for (v = 0; v < v_expand; v++) {
+ inptr = input_data[inrow+v] + outcol_h;
+ for (h = 0; h < h_expand; h++) {
+ outvalue += (long) GETJSAMPLE(*inptr++);
+ }
+ }
+ *outptr++ = (JSAMPLE) ((outvalue + numpix2) / numpix);
+ }
+ inrow += v_expand;
+ }
+}
+
+
+/*
+ * Downsample pixel values of a single component.
+ * This version handles the special case of a full-size component,
+ * without smoothing.
+ */
+
+METHODDEF(void)
+fullsize_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ /* Copy the data */
+ jcopy_sample_rows(input_data, 0, output_data, 0,
+ cinfo->max_v_samp_factor, cinfo->image_width);
+ /* Edge-expand */
+ expand_right_edge(output_data, cinfo->max_v_samp_factor,
+ cinfo->image_width, compptr->width_in_blocks * DCTSIZE);
+}
+
+
+/*
+ * Downsample pixel values of a single component.
+ * This version handles the common case of 2:1 horizontal and 1:1 vertical,
+ * without smoothing.
+ *
+ * A note about the "bias" calculations: when rounding fractional values to
+ * integer, we do not want to always round 0.5 up to the next integer.
+ * If we did that, we'd introduce a noticeable bias towards larger values.
+ * Instead, this code is arranged so that 0.5 will be rounded up or down at
+ * alternate pixel locations (a simple ordered dither pattern).
+ */
+
+METHODDEF(void)
+h2v1_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ int outrow;
+ JDIMENSION outcol;
+ JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE;
+ JSAMPROW inptr, outptr;
+ int bias;
+
+ /* Expand input data enough to let all the output samples be generated
+ * by the standard loop. Special-casing padded output would be more
+ * efficient.
+ */
+ expand_right_edge(input_data, cinfo->max_v_samp_factor,
+ cinfo->image_width, output_cols * 2);
+
+ for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) {
+ outptr = output_data[outrow];
+ inptr = input_data[outrow];
+ bias = 0; /* bias = 0,1,0,1,... for successive samples */
+ for (outcol = 0; outcol < output_cols; outcol++) {
+ *outptr++ = (JSAMPLE) ((GETJSAMPLE(*inptr) + GETJSAMPLE(inptr[1])
+ + bias) >> 1);
+ bias ^= 1; /* 0=>1, 1=>0 */
+ inptr += 2;
+ }
+ }
+}
+
+
+/*
+ * Downsample pixel values of a single component.
+ * This version handles the standard case of 2:1 horizontal and 2:1 vertical,
+ * without smoothing.
+ */
+
+METHODDEF(void)
+h2v2_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ int inrow, outrow;
+ JDIMENSION outcol;
+ JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE;
+ JSAMPROW inptr0, inptr1, outptr;
+ int bias;
+
+ /* Expand input data enough to let all the output samples be generated
+ * by the standard loop. Special-casing padded output would be more
+ * efficient.
+ */
+ expand_right_edge(input_data, cinfo->max_v_samp_factor,
+ cinfo->image_width, output_cols * 2);
+
+ inrow = 0;
+ for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) {
+ outptr = output_data[outrow];
+ inptr0 = input_data[inrow];
+ inptr1 = input_data[inrow+1];
+ bias = 1; /* bias = 1,2,1,2,... for successive samples */
+ for (outcol = 0; outcol < output_cols; outcol++) {
+ *outptr++ = (JSAMPLE) ((GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) +
+ GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1])
+ + bias) >> 2);
+ bias ^= 3; /* 1=>2, 2=>1 */
+ inptr0 += 2; inptr1 += 2;
+ }
+ inrow += 2;
+ }
+}
+
+
+#ifdef INPUT_SMOOTHING_SUPPORTED
+
+/*
+ * Downsample pixel values of a single component.
+ * This version handles the standard case of 2:1 horizontal and 2:1 vertical,
+ * with smoothing. One row of context is required.
+ */
+
+METHODDEF(void)
+h2v2_smooth_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ int inrow, outrow;
+ JDIMENSION colctr;
+ JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE;
+ JSAMPROW inptr0, inptr1, above_ptr, below_ptr, outptr;
+ long membersum, neighsum, memberscale, neighscale;
+
+ /* Expand input data enough to let all the output samples be generated
+ * by the standard loop. Special-casing padded output would be more
+ * efficient.
+ */
+ expand_right_edge(input_data - 1, cinfo->max_v_samp_factor + 2,
+ cinfo->image_width, output_cols * 2);
+
+ /* We don't bother to form the individual "smoothed" input pixel values;
+ * we can directly compute the output which is the average of the four
+ * smoothed values. Each of the four member pixels contributes a fraction
+ * (1-8*SF) to its own smoothed image and a fraction SF to each of the three
+ * other smoothed pixels, therefore a total fraction (1-5*SF)/4 to the final
+ * output. The four corner-adjacent neighbor pixels contribute a fraction
+ * SF to just one smoothed pixel, or SF/4 to the final output; while the
+ * eight edge-adjacent neighbors contribute SF to each of two smoothed
+ * pixels, or SF/2 overall. In order to use integer arithmetic, these
+ * factors are scaled by 2^16 = 65536.
+ * Also recall that SF = smoothing_factor / 1024.
+ */
+
+ memberscale = 16384 - cinfo->smoothing_factor * 80; /* scaled (1-5*SF)/4 */
+ neighscale = cinfo->smoothing_factor * 16; /* scaled SF/4 */
+
+ inrow = 0;
+ for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) {
+ outptr = output_data[outrow];
+ inptr0 = input_data[inrow];
+ inptr1 = input_data[inrow+1];
+ above_ptr = input_data[inrow-1];
+ below_ptr = input_data[inrow+2];
+
+ /* Special case for first column: pretend column -1 is same as column 0 */
+ membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) +
+ GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]);
+ neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) +
+ GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) +
+ GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[2]) +
+ GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[2]);
+ neighsum += neighsum;
+ neighsum += GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[2]) +
+ GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[2]);
+ membersum = membersum * memberscale + neighsum * neighscale;
+ *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16);
+ inptr0 += 2; inptr1 += 2; above_ptr += 2; below_ptr += 2;
+
+ for (colctr = output_cols - 2; colctr > 0; colctr--) {
+ /* sum of pixels directly mapped to this output element */
+ membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) +
+ GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]);
+ /* sum of edge-neighbor pixels */
+ neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) +
+ GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) +
+ GETJSAMPLE(inptr0[-1]) + GETJSAMPLE(inptr0[2]) +
+ GETJSAMPLE(inptr1[-1]) + GETJSAMPLE(inptr1[2]);
+ /* The edge-neighbors count twice as much as corner-neighbors */
+ neighsum += neighsum;
+ /* Add in the corner-neighbors */
+ neighsum += GETJSAMPLE(above_ptr[-1]) + GETJSAMPLE(above_ptr[2]) +
+ GETJSAMPLE(below_ptr[-1]) + GETJSAMPLE(below_ptr[2]);
+ /* form final output scaled up by 2^16 */
+ membersum = membersum * memberscale + neighsum * neighscale;
+ /* round, descale and output it */
+ *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16);
+ inptr0 += 2; inptr1 += 2; above_ptr += 2; below_ptr += 2;
+ }
+
+ /* Special case for last column */
+ membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) +
+ GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]);
+ neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) +
+ GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) +
+ GETJSAMPLE(inptr0[-1]) + GETJSAMPLE(inptr0[1]) +
+ GETJSAMPLE(inptr1[-1]) + GETJSAMPLE(inptr1[1]);
+ neighsum += neighsum;
+ neighsum += GETJSAMPLE(above_ptr[-1]) + GETJSAMPLE(above_ptr[1]) +
+ GETJSAMPLE(below_ptr[-1]) + GETJSAMPLE(below_ptr[1]);
+ membersum = membersum * memberscale + neighsum * neighscale;
+ *outptr = (JSAMPLE) ((membersum + 32768) >> 16);
+
+ inrow += 2;
+ }
+}
+
+
+/*
+ * Downsample pixel values of a single component.
+ * This version handles the special case of a full-size component,
+ * with smoothing. One row of context is required.
+ */
+
+METHODDEF(void)
+fullsize_smooth_downsample (j_compress_ptr cinfo, jpeg_component_info *compptr,
+ JSAMPARRAY input_data, JSAMPARRAY output_data)
+{
+ int outrow;
+ JDIMENSION colctr;
+ JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE;
+ JSAMPROW inptr, above_ptr, below_ptr, outptr;
+ long membersum, neighsum, memberscale, neighscale;
+ int colsum, lastcolsum, nextcolsum;
+
+ /* Expand input data enough to let all the output samples be generated
+ * by the standard loop. Special-casing padded output would be more
+ * efficient.
+ */
+ expand_right_edge(input_data - 1, cinfo->max_v_samp_factor + 2,
+ cinfo->image_width, output_cols);
+
+ /* Each of the eight neighbor pixels contributes a fraction SF to the
+ * smoothed pixel, while the main pixel contributes (1-8*SF). In order
+ * to use integer arithmetic, these factors are multiplied by 2^16 = 65536.
+ * Also recall that SF = smoothing_factor / 1024.
+ */
+
+ memberscale = 65536L - cinfo->smoothing_factor * 512L; /* scaled 1-8*SF */
+ neighscale = cinfo->smoothing_factor * 64; /* scaled SF */
+
+ for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) {
+ outptr = output_data[outrow];
+ inptr = input_data[outrow];
+ above_ptr = input_data[outrow-1];
+ below_ptr = input_data[outrow+1];
+
+ /* Special case for first column */
+ colsum = GETJSAMPLE(*above_ptr++) + GETJSAMPLE(*below_ptr++) +
+ GETJSAMPLE(*inptr);
+ membersum = GETJSAMPLE(*inptr++);
+ nextcolsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(*below_ptr) +
+ GETJSAMPLE(*inptr);
+ neighsum = colsum + (colsum - membersum) + nextcolsum;
+ membersum = membersum * memberscale + neighsum * neighscale;
+ *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16);
+ lastcolsum = colsum; colsum = nextcolsum;
+
+ for (colctr = output_cols - 2; colctr > 0; colctr--) {
+ membersum = GETJSAMPLE(*inptr++);
+ above_ptr++; below_ptr++;
+ nextcolsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(*below_ptr) +
+ GETJSAMPLE(*inptr);
+ neighsum = lastcolsum + (colsum - membersum) + nextcolsum;
+ membersum = membersum * memberscale + neighsum * neighscale;
+ *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16);
+ lastcolsum = colsum; colsum = nextcolsum;
+ }
+
+ /* Special case for last column */
+ membersum = GETJSAMPLE(*inptr);
+ neighsum = lastcolsum + (colsum - membersum) + colsum;
+ membersum = membersum * memberscale + neighsum * neighscale;
+ *outptr = (JSAMPLE) ((membersum + 32768) >> 16);
+
+ }
+}
+
+#endif /* INPUT_SMOOTHING_SUPPORTED */
+
+
+/*
+ * Module initialization routine for downsampling.
+ * Note that we must select a routine for each component.
+ */
+
+GLOBAL(void)
+jinit_downsampler (j_compress_ptr cinfo)
+{
+ my_downsample_ptr downsample;
+ int ci;
+ jpeg_component_info * compptr;
+ int smoothok = TRUE;
+
+ downsample = (my_downsample_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_downsampler));
+ cinfo->downsample = (struct jpeg_downsampler *) downsample;
+ downsample->pub.start_pass = start_pass_downsample;
+ downsample->pub.downsample = sep_downsample;
+ downsample->pub.need_context_rows = FALSE;
+
+ if (cinfo->CCIR601_sampling)
+ ERREXIT(cinfo, JERR_CCIR601_NOTIMPL);
+
+ /* Verify we can handle the sampling factors, and set up method pointers */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ if (compptr->h_samp_factor == cinfo->max_h_samp_factor &&
+ compptr->v_samp_factor == cinfo->max_v_samp_factor) {
+#ifdef INPUT_SMOOTHING_SUPPORTED
+ if (cinfo->smoothing_factor) {
+ downsample->methods[ci] = fullsize_smooth_downsample;
+ downsample->pub.need_context_rows = TRUE;
+ } else
+#endif
+ downsample->methods[ci] = fullsize_downsample;
+ } else if (compptr->h_samp_factor * 2 == cinfo->max_h_samp_factor &&
+ compptr->v_samp_factor == cinfo->max_v_samp_factor) {
+ smoothok = FALSE;
+ downsample->methods[ci] = h2v1_downsample;
+ } else if (compptr->h_samp_factor * 2 == cinfo->max_h_samp_factor &&
+ compptr->v_samp_factor * 2 == cinfo->max_v_samp_factor) {
+#ifdef INPUT_SMOOTHING_SUPPORTED
+ if (cinfo->smoothing_factor) {
+ downsample->methods[ci] = h2v2_smooth_downsample;
+ downsample->pub.need_context_rows = TRUE;
+ } else
+#endif
+ downsample->methods[ci] = h2v2_downsample;
+ } else if ((cinfo->max_h_samp_factor % compptr->h_samp_factor) == 0 &&
+ (cinfo->max_v_samp_factor % compptr->v_samp_factor) == 0) {
+ smoothok = FALSE;
+ downsample->methods[ci] = int_downsample;
+ } else
+ ERREXIT(cinfo, JERR_FRACT_SAMPLE_NOTIMPL);
+ }
+
+#ifdef INPUT_SMOOTHING_SUPPORTED
+ if (cinfo->smoothing_factor && !smoothok)
+ TRACEMS(cinfo, 0, JTRC_SMOOTH_NOTIMPL);
+#endif
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdapimin.cpp b/ml/dlib/dlib/external/libjpeg/jdapimin.cpp
new file mode 100644
index 000000000..3ea5bf161
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdapimin.cpp
@@ -0,0 +1,395 @@
+/*
+ * jdapimin.c
+ *
+ * Copyright (C) 1994-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains application interface code for the decompression half
+ * of the JPEG library. These are the "minimum" API routines that may be
+ * needed in either the normal full-decompression case or the
+ * transcoding-only case.
+ *
+ * Most of the routines intended to be called directly by an application
+ * are in this file or in jdapistd.c. But also see jcomapi.c for routines
+ * shared by compression and decompression, and jdtrans.c for the transcoding
+ * case.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * Initialization of a JPEG decompression object.
+ * The error manager must already be set up (in case memory manager fails).
+ */
+
+GLOBAL(void)
+jpeg_CreateDecompress (j_decompress_ptr cinfo, int version, size_t structsize)
+{
+ int i;
+
+ /* Guard against version mismatches between library and caller. */
+ cinfo->mem = NULL; /* so jpeg_destroy knows mem mgr not called */
+ if (version != JPEG_LIB_VERSION)
+ ERREXIT2(cinfo, JERR_BAD_LIB_VERSION, JPEG_LIB_VERSION, version);
+ if (structsize != SIZEOF(struct jpeg_decompress_struct))
+ ERREXIT2(cinfo, JERR_BAD_STRUCT_SIZE,
+ (int) SIZEOF(struct jpeg_decompress_struct), (int) structsize);
+
+ /* For debugging purposes, we zero the whole master structure.
+ * But the application has already set the err pointer, and may have set
+ * client_data, so we have to save and restore those fields.
+ * Note: if application hasn't set client_data, tools like Purify may
+ * complain here.
+ */
+ {
+ struct jpeg_error_mgr * err = cinfo->err;
+ void * client_data = cinfo->client_data; /* ignore Purify complaint here */
+ MEMZERO(cinfo, SIZEOF(struct jpeg_decompress_struct));
+ cinfo->err = err;
+ cinfo->client_data = client_data;
+ }
+ cinfo->is_decompressor = TRUE;
+
+ /* Initialize a memory manager instance for this object */
+ jinit_memory_mgr((j_common_ptr) cinfo);
+
+ /* Zero out pointers to permanent structures. */
+ cinfo->progress = NULL;
+ cinfo->src = NULL;
+
+ for (i = 0; i < NUM_QUANT_TBLS; i++)
+ cinfo->quant_tbl_ptrs[i] = NULL;
+
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ cinfo->dc_huff_tbl_ptrs[i] = NULL;
+ cinfo->ac_huff_tbl_ptrs[i] = NULL;
+ }
+
+ /* Initialize marker processor so application can override methods
+ * for COM, APPn markers before calling jpeg_read_header.
+ */
+ cinfo->marker_list = NULL;
+ jinit_marker_reader(cinfo);
+
+ /* And initialize the overall input controller. */
+ jinit_input_controller(cinfo);
+
+ /* OK, I'm ready */
+ cinfo->global_state = DSTATE_START;
+}
+
+
+/*
+ * Destruction of a JPEG decompression object
+ */
+
+GLOBAL(void)
+jpeg_destroy_decompress (j_decompress_ptr cinfo)
+{
+ jpeg_destroy((j_common_ptr) cinfo); /* use common routine */
+}
+
+
+/*
+ * Abort processing of a JPEG decompression operation,
+ * but don't destroy the object itself.
+ */
+
+GLOBAL(void)
+jpeg_abort_decompress (j_decompress_ptr cinfo)
+{
+ jpeg_abort((j_common_ptr) cinfo); /* use common routine */
+}
+
+
+/*
+ * Set default decompression parameters.
+ */
+
+LOCAL(void)
+default_decompress_parms (j_decompress_ptr cinfo)
+{
+ /* Guess the input colorspace, and set output colorspace accordingly. */
+ /* (Wish JPEG committee had provided a real way to specify this...) */
+ /* Note application may override our guesses. */
+ switch (cinfo->num_components) {
+ case 1:
+ cinfo->jpeg_color_space = JCS_GRAYSCALE;
+ cinfo->out_color_space = JCS_GRAYSCALE;
+ break;
+
+ case 3:
+ if (cinfo->saw_JFIF_marker) {
+ cinfo->jpeg_color_space = JCS_YCbCr; /* JFIF implies YCbCr */
+ } else if (cinfo->saw_Adobe_marker) {
+ switch (cinfo->Adobe_transform) {
+ case 0:
+ cinfo->jpeg_color_space = JCS_RGB;
+ break;
+ case 1:
+ cinfo->jpeg_color_space = JCS_YCbCr;
+ break;
+ default:
+ WARNMS1(cinfo, JWRN_ADOBE_XFORM, cinfo->Adobe_transform);
+ cinfo->jpeg_color_space = JCS_YCbCr; /* assume it's YCbCr */
+ break;
+ }
+ } else {
+ /* Saw no special markers, try to guess from the component IDs */
+ int cid0 = cinfo->comp_info[0].component_id;
+ int cid1 = cinfo->comp_info[1].component_id;
+ int cid2 = cinfo->comp_info[2].component_id;
+
+ if (cid0 == 1 && cid1 == 2 && cid2 == 3)
+ cinfo->jpeg_color_space = JCS_YCbCr; /* assume JFIF w/out marker */
+ else if (cid0 == 82 && cid1 == 71 && cid2 == 66)
+ cinfo->jpeg_color_space = JCS_RGB; /* ASCII 'R', 'G', 'B' */
+ else {
+ TRACEMS3(cinfo, 1, JTRC_UNKNOWN_IDS, cid0, cid1, cid2);
+ cinfo->jpeg_color_space = JCS_YCbCr; /* assume it's YCbCr */
+ }
+ }
+ /* Always guess RGB is proper output colorspace. */
+ cinfo->out_color_space = JCS_RGB;
+ break;
+
+ case 4:
+ if (cinfo->saw_Adobe_marker) {
+ switch (cinfo->Adobe_transform) {
+ case 0:
+ cinfo->jpeg_color_space = JCS_CMYK;
+ break;
+ case 2:
+ cinfo->jpeg_color_space = JCS_YCCK;
+ break;
+ default:
+ WARNMS1(cinfo, JWRN_ADOBE_XFORM, cinfo->Adobe_transform);
+ cinfo->jpeg_color_space = JCS_YCCK; /* assume it's YCCK */
+ break;
+ }
+ } else {
+ /* No special markers, assume straight CMYK. */
+ cinfo->jpeg_color_space = JCS_CMYK;
+ }
+ cinfo->out_color_space = JCS_CMYK;
+ break;
+
+ default:
+ cinfo->jpeg_color_space = JCS_UNKNOWN;
+ cinfo->out_color_space = JCS_UNKNOWN;
+ break;
+ }
+
+ /* Set defaults for other decompression parameters. */
+ cinfo->scale_num = 1; /* 1:1 scaling */
+ cinfo->scale_denom = 1;
+ cinfo->output_gamma = 1.0;
+ cinfo->buffered_image = FALSE;
+ cinfo->raw_data_out = FALSE;
+ cinfo->dct_method = JDCT_DEFAULT;
+ cinfo->do_fancy_upsampling = TRUE;
+ cinfo->do_block_smoothing = TRUE;
+ cinfo->quantize_colors = FALSE;
+ /* We set these in case application only sets quantize_colors. */
+ cinfo->dither_mode = JDITHER_FS;
+#ifdef QUANT_2PASS_SUPPORTED
+ cinfo->two_pass_quantize = TRUE;
+#else
+ cinfo->two_pass_quantize = FALSE;
+#endif
+ cinfo->desired_number_of_colors = 256;
+ cinfo->colormap = NULL;
+ /* Initialize for no mode change in buffered-image mode. */
+ cinfo->enable_1pass_quant = FALSE;
+ cinfo->enable_external_quant = FALSE;
+ cinfo->enable_2pass_quant = FALSE;
+}
+
+
+/*
+ * Decompression startup: read start of JPEG datastream to see what's there.
+ * Need only initialize JPEG object and supply a data source before calling.
+ *
+ * This routine will read as far as the first SOS marker (ie, actual start of
+ * compressed data), and will save all tables and parameters in the JPEG
+ * object. It will also initialize the decompression parameters to default
+ * values, and finally return JPEG_HEADER_OK. On return, the application may
+ * adjust the decompression parameters and then call jpeg_start_decompress.
+ * (Or, if the application only wanted to determine the image parameters,
+ * the data need not be decompressed. In that case, call jpeg_abort or
+ * jpeg_destroy to release any temporary space.)
+ * If an abbreviated (tables only) datastream is presented, the routine will
+ * return JPEG_HEADER_TABLES_ONLY upon reaching EOI. The application may then
+ * re-use the JPEG object to read the abbreviated image datastream(s).
+ * It is unnecessary (but OK) to call jpeg_abort in this case.
+ * The JPEG_SUSPENDED return code only occurs if the data source module
+ * requests suspension of the decompressor. In this case the application
+ * should load more source data and then re-call jpeg_read_header to resume
+ * processing.
+ * If a non-suspending data source is used and require_image is TRUE, then the
+ * return code need not be inspected since only JPEG_HEADER_OK is possible.
+ *
+ * This routine is now just a front end to jpeg_consume_input, with some
+ * extra error checking.
+ */
+
+GLOBAL(int)
+jpeg_read_header (j_decompress_ptr cinfo, int require_image)
+{
+ int retcode;
+
+ if (cinfo->global_state != DSTATE_START &&
+ cinfo->global_state != DSTATE_INHEADER)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ retcode = jpeg_consume_input(cinfo);
+
+ switch (retcode) {
+ case JPEG_REACHED_SOS:
+ retcode = JPEG_HEADER_OK;
+ break;
+ case JPEG_REACHED_EOI:
+ if (require_image) /* Complain if application wanted an image */
+ ERREXIT(cinfo, JERR_NO_IMAGE);
+ /* Reset to start state; it would be safer to require the application to
+ * call jpeg_abort, but we can't change it now for compatibility reasons.
+ * A side effect is to free any temporary memory (there shouldn't be any).
+ */
+ jpeg_abort((j_common_ptr) cinfo); /* sets state = DSTATE_START */
+ retcode = JPEG_HEADER_TABLES_ONLY;
+ break;
+ case JPEG_SUSPENDED:
+ /* no work */
+ break;
+ }
+
+ return retcode;
+}
+
+
+/*
+ * Consume data in advance of what the decompressor requires.
+ * This can be called at any time once the decompressor object has
+ * been created and a data source has been set up.
+ *
+ * This routine is essentially a state machine that handles a couple
+ * of critical state-transition actions, namely initial setup and
+ * transition from header scanning to ready-for-start_decompress.
+ * All the actual input is done via the input controller's consume_input
+ * method.
+ */
+
+GLOBAL(int)
+jpeg_consume_input (j_decompress_ptr cinfo)
+{
+ int retcode = JPEG_SUSPENDED;
+
+ /* NB: every possible DSTATE value should be listed in this switch */
+ switch (cinfo->global_state) {
+ case DSTATE_START:
+ /* Start-of-datastream actions: reset appropriate modules */
+ (*cinfo->inputctl->reset_input_controller) (cinfo);
+ /* Initialize application's data source module */
+ (*cinfo->src->init_source) (cinfo);
+ cinfo->global_state = DSTATE_INHEADER;
+ /*FALLTHROUGH*/
+ case DSTATE_INHEADER:
+ retcode = (*cinfo->inputctl->consume_input) (cinfo);
+ if (retcode == JPEG_REACHED_SOS) { /* Found SOS, prepare to decompress */
+ /* Set up default parameters based on header data */
+ default_decompress_parms(cinfo);
+ /* Set global state: ready for start_decompress */
+ cinfo->global_state = DSTATE_READY;
+ }
+ break;
+ case DSTATE_READY:
+ /* Can't advance past first SOS until start_decompress is called */
+ retcode = JPEG_REACHED_SOS;
+ break;
+ case DSTATE_PRELOAD:
+ case DSTATE_PRESCAN:
+ case DSTATE_SCANNING:
+ case DSTATE_RAW_OK:
+ case DSTATE_BUFIMAGE:
+ case DSTATE_BUFPOST:
+ case DSTATE_STOPPING:
+ retcode = (*cinfo->inputctl->consume_input) (cinfo);
+ break;
+ default:
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ }
+ return retcode;
+}
+
+
+/*
+ * Have we finished reading the input file?
+ */
+
+GLOBAL(int)
+jpeg_input_complete (j_decompress_ptr cinfo)
+{
+ /* Check for valid jpeg object */
+ if (cinfo->global_state < DSTATE_START ||
+ cinfo->global_state > DSTATE_STOPPING)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ return cinfo->inputctl->eoi_reached;
+}
+
+
+/*
+ * Is there more than one scan?
+ */
+
+GLOBAL(int)
+jpeg_has_multiple_scans (j_decompress_ptr cinfo)
+{
+ /* Only valid after jpeg_read_header completes */
+ if (cinfo->global_state < DSTATE_READY ||
+ cinfo->global_state > DSTATE_STOPPING)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ return cinfo->inputctl->has_multiple_scans;
+}
+
+
+/*
+ * Finish JPEG decompression.
+ *
+ * This will normally just verify the file trailer and release temp storage.
+ *
+ * Returns FALSE if suspended. The return value need be inspected only if
+ * a suspending data source is used.
+ */
+
+GLOBAL(int)
+jpeg_finish_decompress (j_decompress_ptr cinfo)
+{
+ if ((cinfo->global_state == DSTATE_SCANNING ||
+ cinfo->global_state == DSTATE_RAW_OK) && ! cinfo->buffered_image) {
+ /* Terminate final pass of non-buffered mode */
+ if (cinfo->output_scanline < cinfo->output_height)
+ ERREXIT(cinfo, JERR_TOO_LITTLE_DATA);
+ (*cinfo->master->finish_output_pass) (cinfo);
+ cinfo->global_state = DSTATE_STOPPING;
+ } else if (cinfo->global_state == DSTATE_BUFIMAGE) {
+ /* Finishing after a buffered-image operation */
+ cinfo->global_state = DSTATE_STOPPING;
+ } else if (cinfo->global_state != DSTATE_STOPPING) {
+ /* STOPPING = repeat call after a suspension, anything else is error */
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ }
+ /* Read until EOI */
+ while (! cinfo->inputctl->eoi_reached) {
+ if ((*cinfo->inputctl->consume_input) (cinfo) == JPEG_SUSPENDED)
+ return FALSE; /* Suspend, come back later */
+ }
+ /* Do final cleanup */
+ (*cinfo->src->term_source) (cinfo);
+ /* We can use jpeg_abort to release memory and reset global_state */
+ jpeg_abort((j_common_ptr) cinfo);
+ return TRUE;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdapistd.cpp b/ml/dlib/dlib/external/libjpeg/jdapistd.cpp
new file mode 100644
index 000000000..03d909e00
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdapistd.cpp
@@ -0,0 +1,275 @@
+/*
+ * jdapistd.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains application interface code for the decompression half
+ * of the JPEG library. These are the "standard" API routines that are
+ * used in the normal full-decompression case. They are not used by a
+ * transcoding-only application. Note that if an application links in
+ * jpeg_start_decompress, it will end up linking in the entire decompressor.
+ * We thus must separate this file from jdapimin.c to avoid linking the
+ * whole decompression library into a transcoder.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Forward declarations */
+LOCAL(int) output_pass_setup JPP((j_decompress_ptr cinfo));
+
+
+/*
+ * Decompression initialization.
+ * jpeg_read_header must be completed before calling this.
+ *
+ * If a multipass operating mode was selected, this will do all but the
+ * last pass, and thus may take a great deal of time.
+ *
+ * Returns FALSE if suspended. The return value need be inspected only if
+ * a suspending data source is used.
+ */
+
+GLOBAL(int)
+jpeg_start_decompress (j_decompress_ptr cinfo)
+{
+ if (cinfo->global_state == DSTATE_READY) {
+ /* First call: initialize master control, select active modules */
+ jinit_master_decompress(cinfo);
+ if (cinfo->buffered_image) {
+ /* No more work here; expecting jpeg_start_output next */
+ cinfo->global_state = DSTATE_BUFIMAGE;
+ return TRUE;
+ }
+ cinfo->global_state = DSTATE_PRELOAD;
+ }
+ if (cinfo->global_state == DSTATE_PRELOAD) {
+ /* If file has multiple scans, absorb them all into the coef buffer */
+ if (cinfo->inputctl->has_multiple_scans) {
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+ for (;;) {
+ int retcode;
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL)
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ /* Absorb some more input */
+ retcode = (*cinfo->inputctl->consume_input) (cinfo);
+ if (retcode == JPEG_SUSPENDED)
+ return FALSE;
+ if (retcode == JPEG_REACHED_EOI)
+ break;
+ /* Advance progress counter if appropriate */
+ if (cinfo->progress != NULL &&
+ (retcode == JPEG_ROW_COMPLETED || retcode == JPEG_REACHED_SOS)) {
+ if (++cinfo->progress->pass_counter >= cinfo->progress->pass_limit) {
+ /* jdmaster underestimated number of scans; ratchet up one scan */
+ cinfo->progress->pass_limit += (long) cinfo->total_iMCU_rows;
+ }
+ }
+ }
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif /* D_MULTISCAN_FILES_SUPPORTED */
+ }
+ cinfo->output_scan_number = cinfo->input_scan_number;
+ } else if (cinfo->global_state != DSTATE_PRESCAN)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ /* Perform any dummy output passes, and set up for the final pass */
+ return output_pass_setup(cinfo);
+}
+
+
+/*
+ * Set up for an output pass, and perform any dummy pass(es) needed.
+ * Common subroutine for jpeg_start_decompress and jpeg_start_output.
+ * Entry: global_state = DSTATE_PRESCAN only if previously suspended.
+ * Exit: If done, returns TRUE and sets global_state for proper output mode.
+ * If suspended, returns FALSE and sets global_state = DSTATE_PRESCAN.
+ */
+
+LOCAL(int)
+output_pass_setup (j_decompress_ptr cinfo)
+{
+ if (cinfo->global_state != DSTATE_PRESCAN) {
+ /* First call: do pass setup */
+ (*cinfo->master->prepare_for_output_pass) (cinfo);
+ cinfo->output_scanline = 0;
+ cinfo->global_state = DSTATE_PRESCAN;
+ }
+ /* Loop over any required dummy passes */
+ while (cinfo->master->is_dummy_pass) {
+#ifdef QUANT_2PASS_SUPPORTED
+ /* Crank through the dummy pass */
+ while (cinfo->output_scanline < cinfo->output_height) {
+ JDIMENSION last_scanline;
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) cinfo->output_scanline;
+ cinfo->progress->pass_limit = (long) cinfo->output_height;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+ /* Process some data */
+ last_scanline = cinfo->output_scanline;
+ (*cinfo->main->process_data) (cinfo, (JSAMPARRAY) NULL,
+ &cinfo->output_scanline, (JDIMENSION) 0);
+ if (cinfo->output_scanline == last_scanline)
+ return FALSE; /* No progress made, must suspend */
+ }
+ /* Finish up dummy pass, and set up for another one */
+ (*cinfo->master->finish_output_pass) (cinfo);
+ (*cinfo->master->prepare_for_output_pass) (cinfo);
+ cinfo->output_scanline = 0;
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif /* QUANT_2PASS_SUPPORTED */
+ }
+ /* Ready for application to drive output pass through
+ * jpeg_read_scanlines or jpeg_read_raw_data.
+ */
+ cinfo->global_state = cinfo->raw_data_out ? DSTATE_RAW_OK : DSTATE_SCANNING;
+ return TRUE;
+}
+
+
+/*
+ * Read some scanlines of data from the JPEG decompressor.
+ *
+ * The return value will be the number of lines actually read.
+ * This may be less than the number requested in several cases,
+ * including bottom of image, data source suspension, and operating
+ * modes that emit multiple scanlines at a time.
+ *
+ * Note: we warn about excess calls to jpeg_read_scanlines() since
+ * this likely signals an application programmer error. However,
+ * an oversize buffer (max_lines > scanlines remaining) is not an error.
+ */
+
+GLOBAL(JDIMENSION)
+jpeg_read_scanlines (j_decompress_ptr cinfo, JSAMPARRAY scanlines,
+ JDIMENSION max_lines)
+{
+ JDIMENSION row_ctr;
+
+ if (cinfo->global_state != DSTATE_SCANNING)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ if (cinfo->output_scanline >= cinfo->output_height) {
+ WARNMS(cinfo, JWRN_TOO_MUCH_DATA);
+ return 0;
+ }
+
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) cinfo->output_scanline;
+ cinfo->progress->pass_limit = (long) cinfo->output_height;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+
+ /* Process some data */
+ row_ctr = 0;
+ (*cinfo->main->process_data) (cinfo, scanlines, &row_ctr, max_lines);
+ cinfo->output_scanline += row_ctr;
+ return row_ctr;
+}
+
+
+/*
+ * Alternate entry point to read raw data.
+ * Processes exactly one iMCU row per call, unless suspended.
+ */
+
+GLOBAL(JDIMENSION)
+jpeg_read_raw_data (j_decompress_ptr cinfo, JSAMPIMAGE data,
+ JDIMENSION max_lines)
+{
+ JDIMENSION lines_per_iMCU_row;
+
+ if (cinfo->global_state != DSTATE_RAW_OK)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ if (cinfo->output_scanline >= cinfo->output_height) {
+ WARNMS(cinfo, JWRN_TOO_MUCH_DATA);
+ return 0;
+ }
+
+ /* Call progress monitor hook if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->pass_counter = (long) cinfo->output_scanline;
+ cinfo->progress->pass_limit = (long) cinfo->output_height;
+ (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo);
+ }
+
+ /* Verify that at least one iMCU row can be returned. */
+ lines_per_iMCU_row = cinfo->max_v_samp_factor * cinfo->min_DCT_scaled_size;
+ if (max_lines < lines_per_iMCU_row)
+ ERREXIT(cinfo, JERR_BUFFER_SIZE);
+
+ /* Decompress directly into user's buffer. */
+ if (! (*cinfo->coef->decompress_data) (cinfo, data))
+ return 0; /* suspension forced, can do nothing more */
+
+ /* OK, we processed one iMCU row. */
+ cinfo->output_scanline += lines_per_iMCU_row;
+ return lines_per_iMCU_row;
+}
+
+
+/* Additional entry points for buffered-image mode. */
+
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+
+/*
+ * Initialize for an output pass in buffered-image mode.
+ */
+
+GLOBAL(int)
+jpeg_start_output (j_decompress_ptr cinfo, int scan_number)
+{
+ if (cinfo->global_state != DSTATE_BUFIMAGE &&
+ cinfo->global_state != DSTATE_PRESCAN)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ /* Limit scan number to valid range */
+ if (scan_number <= 0)
+ scan_number = 1;
+ if (cinfo->inputctl->eoi_reached &&
+ scan_number > cinfo->input_scan_number)
+ scan_number = cinfo->input_scan_number;
+ cinfo->output_scan_number = scan_number;
+ /* Perform any dummy output passes, and set up for the real pass */
+ return output_pass_setup(cinfo);
+}
+
+
+/*
+ * Finish up after an output pass in buffered-image mode.
+ *
+ * Returns FALSE if suspended. The return value need be inspected only if
+ * a suspending data source is used.
+ */
+
+GLOBAL(int)
+jpeg_finish_output (j_decompress_ptr cinfo)
+{
+ if ((cinfo->global_state == DSTATE_SCANNING ||
+ cinfo->global_state == DSTATE_RAW_OK) && cinfo->buffered_image) {
+ /* Terminate this pass. */
+ /* We do not require the whole pass to have been completed. */
+ (*cinfo->master->finish_output_pass) (cinfo);
+ cinfo->global_state = DSTATE_BUFPOST;
+ } else if (cinfo->global_state != DSTATE_BUFPOST) {
+ /* BUFPOST = repeat call after a suspension, anything else is error */
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+ }
+ /* Read markers looking for SOS or EOI */
+ while (cinfo->input_scan_number <= cinfo->output_scan_number &&
+ ! cinfo->inputctl->eoi_reached) {
+ if ((*cinfo->inputctl->consume_input) (cinfo) == JPEG_SUSPENDED)
+ return FALSE; /* Suspend, come back later */
+ }
+ cinfo->global_state = DSTATE_BUFIMAGE;
+ return TRUE;
+}
+
+#endif /* D_MULTISCAN_FILES_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jdatadst.cpp b/ml/dlib/dlib/external/libjpeg/jdatadst.cpp
new file mode 100644
index 000000000..afa3c83c6
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdatadst.cpp
@@ -0,0 +1,151 @@
+/*
+ * jdatadst.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains compression data destination routines for the case of
+ * emitting JPEG data to a file (or any stdio stream). While these routines
+ * are sufficient for most applications, some will want to use a different
+ * destination manager.
+ * IMPORTANT: we assume that fwrite() will correctly transcribe an array of
+ * JOCTETs into 8-bit-wide elements on external storage. If char is wider
+ * than 8 bits on your machine, you may need to do some tweaking.
+ */
+
+/* this is not a core library module, so it doesn't define JPEG_INTERNALS */
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jerror.h"
+
+
+/* Expanded data destination object for stdio output */
+
+typedef struct {
+ struct jpeg_destination_mgr pub; /* public fields */
+
+ FILE * outfile; /* target stream */
+ JOCTET * buffer; /* start of buffer */
+} my_destination_mgr;
+
+typedef my_destination_mgr * my_dest_ptr;
+
+#define OUTPUT_BUF_SIZE 4096 /* choose an efficiently fwrite'able size */
+
+
+/*
+ * Initialize destination --- called by jpeg_start_compress
+ * before any data is actually written.
+ */
+
+METHODDEF(void)
+init_destination (j_compress_ptr cinfo)
+{
+ my_dest_ptr dest = (my_dest_ptr) cinfo->dest;
+
+ /* Allocate the output buffer --- it will be released when done with image */
+ dest->buffer = (JOCTET *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ OUTPUT_BUF_SIZE * SIZEOF(JOCTET));
+
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = OUTPUT_BUF_SIZE;
+}
+
+
+/*
+ * Empty the output buffer --- called whenever buffer fills up.
+ *
+ * In typical applications, this should write the entire output buffer
+ * (ignoring the current state of next_output_byte & free_in_buffer),
+ * reset the pointer & count to the start of the buffer, and return TRUE
+ * indicating that the buffer has been dumped.
+ *
+ * In applications that need to be able to suspend compression due to output
+ * overrun, a FALSE return indicates that the buffer cannot be emptied now.
+ * In this situation, the compressor will return to its caller (possibly with
+ * an indication that it has not accepted all the supplied scanlines). The
+ * application should resume compression after it has made more room in the
+ * output buffer. Note that there are substantial restrictions on the use of
+ * suspension --- see the documentation.
+ *
+ * When suspending, the compressor will back up to a convenient restart point
+ * (typically the start of the current MCU). next_output_byte & free_in_buffer
+ * indicate where the restart point will be if the current call returns FALSE.
+ * Data beyond this point will be regenerated after resumption, so do not
+ * write it out when emptying the buffer externally.
+ */
+
+METHODDEF(int)
+empty_output_buffer (j_compress_ptr cinfo)
+{
+ my_dest_ptr dest = (my_dest_ptr) cinfo->dest;
+
+ if (JFWRITE(dest->outfile, dest->buffer, OUTPUT_BUF_SIZE) !=
+ (size_t) OUTPUT_BUF_SIZE)
+ ERREXIT(cinfo, JERR_FILE_WRITE);
+
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = OUTPUT_BUF_SIZE;
+
+ return TRUE;
+}
+
+
+/*
+ * Terminate destination --- called by jpeg_finish_compress
+ * after all data has been written. Usually needs to flush buffer.
+ *
+ * NB: *not* called by jpeg_abort or jpeg_destroy; surrounding
+ * application must deal with any cleanup that should happen even
+ * for error exit.
+ */
+
+METHODDEF(void)
+term_destination (j_compress_ptr cinfo)
+{
+ my_dest_ptr dest = (my_dest_ptr) cinfo->dest;
+ size_t datacount = OUTPUT_BUF_SIZE - dest->pub.free_in_buffer;
+
+ /* Write any data remaining in the buffer */
+ if (datacount > 0) {
+ if (JFWRITE(dest->outfile, dest->buffer, datacount) != datacount)
+ ERREXIT(cinfo, JERR_FILE_WRITE);
+ }
+ fflush(dest->outfile);
+ /* Make sure we wrote the output file OK */
+ if (ferror(dest->outfile))
+ ERREXIT(cinfo, JERR_FILE_WRITE);
+}
+
+
+/*
+ * Prepare for output to a stdio stream.
+ * The caller must have already opened the stream, and is responsible
+ * for closing it after finishing compression.
+ */
+
+GLOBAL(void)
+jpeg_stdio_dest (j_compress_ptr cinfo, FILE * outfile)
+{
+ my_dest_ptr dest;
+
+ /* The destination object is made permanent so that multiple JPEG images
+ * can be written to the same file without re-executing jpeg_stdio_dest.
+ * This makes it dangerous to use this manager and a different destination
+ * manager serially with the same JPEG object, because their private object
+ * sizes may be different. Caveat programmer.
+ */
+ if (cinfo->dest == NULL) { /* first time for this JPEG object? */
+ cinfo->dest = (struct jpeg_destination_mgr *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ SIZEOF(my_destination_mgr));
+ }
+
+ dest = (my_dest_ptr) cinfo->dest;
+ dest->pub.init_destination = init_destination;
+ dest->pub.empty_output_buffer = empty_output_buffer;
+ dest->pub.term_destination = term_destination;
+ dest->outfile = outfile;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp b/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp
new file mode 100644
index 000000000..7af097f02
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp
@@ -0,0 +1,212 @@
+/*
+ * jdatasrc.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains decompression data source routines for the case of
+ * reading JPEG data from a file (or any stdio stream). While these routines
+ * are sufficient for most applications, some will want to use a different
+ * source manager.
+ * IMPORTANT: we assume that fread() will correctly transcribe an array of
+ * JOCTETs from 8-bit-wide elements on external storage. If char is wider
+ * than 8 bits on your machine, you may need to do some tweaking.
+ */
+
+/* this is not a core library module, so it doesn't define JPEG_INTERNALS */
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jerror.h"
+
+
+/* Expanded data source object for stdio input */
+
+typedef struct {
+ struct jpeg_source_mgr pub; /* public fields */
+
+ FILE * infile; /* source stream */
+ JOCTET * buffer; /* start of buffer */
+ int start_of_file; /* have we gotten any data yet? */
+} my_source_mgr;
+
+typedef my_source_mgr * my_src_ptr;
+
+#define INPUT_BUF_SIZE 4096 /* choose an efficiently fread'able size */
+
+
+/*
+ * Initialize source --- called by jpeg_read_header
+ * before any data is actually read.
+ */
+
+METHODDEF(void)
+init_source (j_decompress_ptr cinfo)
+{
+ my_src_ptr src = (my_src_ptr) cinfo->src;
+
+ /* We reset the empty-input-file flag for each image,
+ * but we don't clear the input buffer.
+ * This is correct behavior for reading a series of images from one source.
+ */
+ src->start_of_file = TRUE;
+}
+
+
+/*
+ * Fill the input buffer --- called whenever buffer is emptied.
+ *
+ * In typical applications, this should read fresh data into the buffer
+ * (ignoring the current state of next_input_byte & bytes_in_buffer),
+ * reset the pointer & count to the start of the buffer, and return TRUE
+ * indicating that the buffer has been reloaded. It is not necessary to
+ * fill the buffer entirely, only to obtain at least one more byte.
+ *
+ * There is no such thing as an EOF return. If the end of the file has been
+ * reached, the routine has a choice of ERREXIT() or inserting fake data into
+ * the buffer. In most cases, generating a warning message and inserting a
+ * fake EOI marker is the best course of action --- this will allow the
+ * decompressor to output however much of the image is there. However,
+ * the resulting error message is misleading if the real problem is an empty
+ * input file, so we handle that case specially.
+ *
+ * In applications that need to be able to suspend compression due to input
+ * not being available yet, a FALSE return indicates that no more data can be
+ * obtained right now, but more may be forthcoming later. In this situation,
+ * the decompressor will return to its caller (with an indication of the
+ * number of scanlines it has read, if any). The application should resume
+ * decompression after it has loaded more data into the input buffer. Note
+ * that there are substantial restrictions on the use of suspension --- see
+ * the documentation.
+ *
+ * When suspending, the decompressor will back up to a convenient restart point
+ * (typically the start of the current MCU). next_input_byte & bytes_in_buffer
+ * indicate where the restart point will be if the current call returns FALSE.
+ * Data beyond this point must be rescanned after resumption, so move it to
+ * the front of the buffer rather than discarding it.
+ */
+
+METHODDEF(int)
+fill_input_buffer (j_decompress_ptr cinfo)
+{
+ my_src_ptr src = (my_src_ptr) cinfo->src;
+ size_t nbytes;
+
+ nbytes = JFREAD(src->infile, src->buffer, INPUT_BUF_SIZE);
+
+ if (nbytes <= 0) {
+ if (src->start_of_file) /* Treat empty input file as fatal error */
+ ERREXIT(cinfo, JERR_INPUT_EMPTY);
+ WARNMS(cinfo, JWRN_JPEG_EOF);
+ /* Insert a fake EOI marker */
+ src->buffer[0] = (JOCTET) 0xFF;
+ src->buffer[1] = (JOCTET) JPEG_EOI;
+ nbytes = 2;
+ }
+
+ src->pub.next_input_byte = src->buffer;
+ src->pub.bytes_in_buffer = nbytes;
+ src->start_of_file = FALSE;
+
+ return TRUE;
+}
+
+
+/*
+ * Skip data --- used to skip over a potentially large amount of
+ * uninteresting data (such as an APPn marker).
+ *
+ * Writers of suspendable-input applications must note that skip_input_data
+ * is not granted the right to give a suspension return. If the skip extends
+ * beyond the data currently in the buffer, the buffer can be marked empty so
+ * that the next read will cause a fill_input_buffer call that can suspend.
+ * Arranging for additional bytes to be discarded before reloading the input
+ * buffer is the application writer's problem.
+ */
+
+METHODDEF(void)
+skip_input_data (j_decompress_ptr cinfo, long num_bytes)
+{
+ my_src_ptr src = (my_src_ptr) cinfo->src;
+
+ /* Just a dumb implementation for now. Could use fseek() except
+ * it doesn't work on pipes. Not clear that being smart is worth
+ * any trouble anyway --- large skips are infrequent.
+ */
+ if (num_bytes > 0) {
+ while (num_bytes > (long) src->pub.bytes_in_buffer) {
+ num_bytes -= (long) src->pub.bytes_in_buffer;
+ (void) fill_input_buffer(cinfo);
+ /* note we assume that fill_input_buffer will never return FALSE,
+ * so suspension need not be handled.
+ */
+ }
+ src->pub.next_input_byte += (size_t) num_bytes;
+ src->pub.bytes_in_buffer -= (size_t) num_bytes;
+ }
+}
+
+
+/*
+ * An additional method that can be provided by data source modules is the
+ * resync_to_restart method for error recovery in the presence of RST markers.
+ * For the moment, this source module just uses the default resync method
+ * provided by the JPEG library. That method assumes that no backtracking
+ * is possible.
+ */
+
+
+/*
+ * Terminate source --- called by jpeg_finish_decompress
+ * after all data has been read. Often a no-op.
+ *
+ * NB: *not* called by jpeg_abort or jpeg_destroy; surrounding
+ * application must deal with any cleanup that should happen even
+ * for error exit.
+ */
+
+METHODDEF(void)
+term_source (j_decompress_ptr )
+{
+ /* no work necessary here */
+}
+
+
+/*
+ * Prepare for input from a stdio stream.
+ * The caller must have already opened the stream, and is responsible
+ * for closing it after finishing decompression.
+ */
+
+GLOBAL(void)
+jpeg_stdio_src (j_decompress_ptr cinfo, FILE * infile)
+{
+ my_src_ptr src;
+
+ /* The source object and input buffer are made permanent so that a series
+ * of JPEG images can be read from the same file by calling jpeg_stdio_src
+ * only before the first one. (If we discarded the buffer at the end of
+ * one image, we'd likely lose the start of the next one.)
+ * This makes it unsafe to use this manager and a different source
+ * manager serially with the same JPEG object. Caveat programmer.
+ */
+ if (cinfo->src == NULL) { /* first time for this JPEG object? */
+ cinfo->src = (struct jpeg_source_mgr *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ SIZEOF(my_source_mgr));
+ src = (my_src_ptr) cinfo->src;
+ src->buffer = (JOCTET *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ INPUT_BUF_SIZE * SIZEOF(JOCTET));
+ }
+
+ src = (my_src_ptr) cinfo->src;
+ src->pub.init_source = init_source;
+ src->pub.fill_input_buffer = fill_input_buffer;
+ src->pub.skip_input_data = skip_input_data;
+ src->pub.resync_to_restart = jpeg_resync_to_restart; /* use default method */
+ src->pub.term_source = term_source;
+ src->infile = infile;
+ src->pub.bytes_in_buffer = 0; /* forces fill_input_buffer on first read */
+ src->pub.next_input_byte = NULL; /* until buffer loaded */
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp b/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp
new file mode 100644
index 000000000..11b618920
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp
@@ -0,0 +1,736 @@
+/*
+ * jdcoefct.c
+ *
+ * Copyright (C) 1994-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the coefficient buffer controller for decompression.
+ * This controller is the top level of the JPEG decompressor proper.
+ * The coefficient buffer lies between entropy decoding and inverse-DCT steps.
+ *
+ * In buffered-image mode, this controller is the interface between
+ * input-oriented processing and output-oriented processing.
+ * Also, the input side (only) is used when reading a file for transcoding.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+/* Block smoothing is only applicable for progressive JPEG, so: */
+#ifndef D_PROGRESSIVE_SUPPORTED
+#undef BLOCK_SMOOTHING_SUPPORTED
+#endif
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_d_coef_controller pub; /* public fields */
+
+ /* These variables keep track of the current location of the input side. */
+ /* cinfo->input_iMCU_row is also used for this. */
+ JDIMENSION MCU_ctr; /* counts MCUs processed in current row */
+ int MCU_vert_offset; /* counts MCU rows within iMCU row */
+ int MCU_rows_per_iMCU_row; /* number of such rows needed */
+
+ /* The output side's location is represented by cinfo->output_iMCU_row. */
+
+ /* In single-pass modes, it's sufficient to buffer just one MCU.
+ * We allocate a workspace of D_MAX_BLOCKS_IN_MCU coefficient blocks,
+ * and let the entropy decoder write into that workspace each time.
+ * (On 80x86, the workspace is FAR even though it's not really very big;
+ * this is to keep the module interfaces unchanged when a large coefficient
+ * buffer is necessary.)
+ * In multi-pass modes, this array points to the current MCU's blocks
+ * within the virtual arrays; it is used only by the input side.
+ */
+ JBLOCKROW MCU_buffer[D_MAX_BLOCKS_IN_MCU];
+
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+ /* In multi-pass modes, we need a virtual block array for each component. */
+ jvirt_barray_ptr whole_image[MAX_COMPONENTS];
+#endif
+
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+ /* When doing block smoothing, we latch coefficient Al values here */
+ int * coef_bits_latch;
+#define SAVED_COEFS 6 /* we save coef_bits[0..5] */
+#endif
+} my_coef_controller;
+
+typedef my_coef_controller * my_coef_ptr;
+
+/* Forward declarations */
+METHODDEF(int) decompress_onepass
+ JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf));
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+METHODDEF(int) decompress_data
+ JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf));
+#endif
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+LOCAL(int) smoothing_ok JPP((j_decompress_ptr cinfo));
+METHODDEF(int) decompress_smooth_data
+ JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf));
+#endif
+
+
+LOCAL(void)
+start_iMCU_row (j_decompress_ptr cinfo)
+/* Reset within-iMCU-row counters for a new row (input side) */
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+
+ /* In an interleaved scan, an MCU row is the same as an iMCU row.
+ * In a noninterleaved scan, an iMCU row has v_samp_factor MCU rows.
+ * But at the bottom of the image, process only what's left.
+ */
+ if (cinfo->comps_in_scan > 1) {
+ coef->MCU_rows_per_iMCU_row = 1;
+ } else {
+ if (cinfo->input_iMCU_row < (cinfo->total_iMCU_rows-1))
+ coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->v_samp_factor;
+ else
+ coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->last_row_height;
+ }
+
+ coef->MCU_ctr = 0;
+ coef->MCU_vert_offset = 0;
+}
+
+
+/*
+ * Initialize for an input processing pass.
+ */
+
+METHODDEF(void)
+start_input_pass (j_decompress_ptr cinfo)
+{
+ cinfo->input_iMCU_row = 0;
+ start_iMCU_row(cinfo);
+}
+
+
+/*
+ * Initialize for an output processing pass.
+ */
+
+METHODDEF(void)
+start_output_pass (j_decompress_ptr cinfo)
+{
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+
+ /* If multipass, check to see whether to use block smoothing on this pass */
+ if (coef->pub.coef_arrays != NULL) {
+ if (cinfo->do_block_smoothing && smoothing_ok(cinfo))
+ coef->pub.decompress_data = decompress_smooth_data;
+ else
+ coef->pub.decompress_data = decompress_data;
+ }
+#endif
+ cinfo->output_iMCU_row = 0;
+}
+
+
+/*
+ * Decompress and return some data in the single-pass case.
+ * Always attempts to emit one fully interleaved MCU row ("iMCU" row).
+ * Input and output must run in lockstep since we have only a one-MCU buffer.
+ * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED.
+ *
+ * NB: output_buf contains a plane for each component in image,
+ * which we index according to the component's SOF position.
+ */
+
+METHODDEF(int)
+decompress_onepass (j_decompress_ptr cinfo, JSAMPIMAGE output_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION MCU_col_num; /* index of current MCU within row */
+ JDIMENSION last_MCU_col = cinfo->MCUs_per_row - 1;
+ JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1;
+ int blkn, ci, xindex, yindex, yoffset, useful_width;
+ JSAMPARRAY output_ptr;
+ JDIMENSION start_col, output_col;
+ jpeg_component_info *compptr;
+ inverse_DCT_method_ptr inverse_DCT;
+
+ /* Loop to process as much as one whole iMCU row */
+ for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row;
+ yoffset++) {
+ for (MCU_col_num = coef->MCU_ctr; MCU_col_num <= last_MCU_col;
+ MCU_col_num++) {
+ /* Try to fetch an MCU. Entropy decoder expects buffer to be zeroed. */
+ jzero_far((void FAR *) coef->MCU_buffer[0],
+ (size_t) (cinfo->blocks_in_MCU * SIZEOF(JBLOCK)));
+ if (! (*cinfo->entropy->decode_mcu) (cinfo, coef->MCU_buffer)) {
+ /* Suspension forced; update state counters and exit */
+ coef->MCU_vert_offset = yoffset;
+ coef->MCU_ctr = MCU_col_num;
+ return JPEG_SUSPENDED;
+ }
+ /* Determine where data should go in output_buf and do the IDCT thing.
+ * We skip dummy blocks at the right and bottom edges (but blkn gets
+ * incremented past them!). Note the inner loop relies on having
+ * allocated the MCU_buffer[] blocks sequentially.
+ */
+ blkn = 0; /* index of current DCT block within MCU */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* Don't bother to IDCT an uninteresting component. */
+ if (! compptr->component_needed) {
+ blkn += compptr->MCU_blocks;
+ continue;
+ }
+ inverse_DCT = cinfo->idct->inverse_DCT[compptr->component_index];
+ useful_width = (MCU_col_num < last_MCU_col) ? compptr->MCU_width
+ : compptr->last_col_width;
+ output_ptr = output_buf[compptr->component_index] +
+ yoffset * compptr->DCT_scaled_size;
+ start_col = MCU_col_num * compptr->MCU_sample_width;
+ for (yindex = 0; yindex < compptr->MCU_height; yindex++) {
+ if (cinfo->input_iMCU_row < last_iMCU_row ||
+ yoffset+yindex < compptr->last_row_height) {
+ output_col = start_col;
+ for (xindex = 0; xindex < useful_width; xindex++) {
+ (*inverse_DCT) (cinfo, compptr,
+ (JCOEFPTR) coef->MCU_buffer[blkn+xindex],
+ output_ptr, output_col);
+ output_col += compptr->DCT_scaled_size;
+ }
+ }
+ blkn += compptr->MCU_width;
+ output_ptr += compptr->DCT_scaled_size;
+ }
+ }
+ }
+ /* Completed an MCU row, but perhaps not an iMCU row */
+ coef->MCU_ctr = 0;
+ }
+ /* Completed the iMCU row, advance counters for next one */
+ cinfo->output_iMCU_row++;
+ if (++(cinfo->input_iMCU_row) < cinfo->total_iMCU_rows) {
+ start_iMCU_row(cinfo);
+ return JPEG_ROW_COMPLETED;
+ }
+ /* Completed the scan */
+ (*cinfo->inputctl->finish_input_pass) (cinfo);
+ return JPEG_SCAN_COMPLETED;
+}
+
+
+/*
+ * Dummy consume-input routine for single-pass operation.
+ */
+
+METHODDEF(int)
+dummy_consume_data (j_decompress_ptr )
+{
+ return JPEG_SUSPENDED; /* Always indicate nothing was done */
+}
+
+
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+
+/*
+ * Consume input data and store it in the full-image coefficient buffer.
+ * We read as much as one fully interleaved MCU row ("iMCU" row) per call,
+ * ie, v_samp_factor block rows for each component in the scan.
+ * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED.
+ */
+
+METHODDEF(int)
+consume_data (j_decompress_ptr cinfo)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION MCU_col_num; /* index of current MCU within row */
+ int blkn, ci, xindex, yindex, yoffset;
+ JDIMENSION start_col;
+ JBLOCKARRAY buffer[MAX_COMPS_IN_SCAN];
+ JBLOCKROW buffer_ptr;
+ jpeg_component_info *compptr;
+
+ /* Align the virtual buffers for the components used in this scan. */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ buffer[ci] = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[compptr->component_index],
+ cinfo->input_iMCU_row * compptr->v_samp_factor,
+ (JDIMENSION) compptr->v_samp_factor, TRUE);
+ /* Note: entropy decoder expects buffer to be zeroed,
+ * but this is handled automatically by the memory manager
+ * because we requested a pre-zeroed array.
+ */
+ }
+
+ /* Loop to process one whole iMCU row */
+ for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row;
+ yoffset++) {
+ for (MCU_col_num = coef->MCU_ctr; MCU_col_num < cinfo->MCUs_per_row;
+ MCU_col_num++) {
+ /* Construct list of pointers to DCT blocks belonging to this MCU */
+ blkn = 0; /* index of current DCT block within MCU */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ start_col = MCU_col_num * compptr->MCU_width;
+ for (yindex = 0; yindex < compptr->MCU_height; yindex++) {
+ buffer_ptr = buffer[ci][yindex+yoffset] + start_col;
+ for (xindex = 0; xindex < compptr->MCU_width; xindex++) {
+ coef->MCU_buffer[blkn++] = buffer_ptr++;
+ }
+ }
+ }
+ /* Try to fetch the MCU. */
+ if (! (*cinfo->entropy->decode_mcu) (cinfo, coef->MCU_buffer)) {
+ /* Suspension forced; update state counters and exit */
+ coef->MCU_vert_offset = yoffset;
+ coef->MCU_ctr = MCU_col_num;
+ return JPEG_SUSPENDED;
+ }
+ }
+ /* Completed an MCU row, but perhaps not an iMCU row */
+ coef->MCU_ctr = 0;
+ }
+ /* Completed the iMCU row, advance counters for next one */
+ if (++(cinfo->input_iMCU_row) < cinfo->total_iMCU_rows) {
+ start_iMCU_row(cinfo);
+ return JPEG_ROW_COMPLETED;
+ }
+ /* Completed the scan */
+ (*cinfo->inputctl->finish_input_pass) (cinfo);
+ return JPEG_SCAN_COMPLETED;
+}
+
+
+/*
+ * Decompress and return some data in the multi-pass case.
+ * Always attempts to emit one fully interleaved MCU row ("iMCU" row).
+ * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED.
+ *
+ * NB: output_buf contains a plane for each component in image.
+ */
+
+METHODDEF(int)
+decompress_data (j_decompress_ptr cinfo, JSAMPIMAGE output_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1;
+ JDIMENSION block_num;
+ int ci, block_row, block_rows;
+ JBLOCKARRAY buffer;
+ JBLOCKROW buffer_ptr;
+ JSAMPARRAY output_ptr;
+ JDIMENSION output_col;
+ jpeg_component_info *compptr;
+ inverse_DCT_method_ptr inverse_DCT;
+
+ /* Force some input to be done if we are getting ahead of the input. */
+ while (cinfo->input_scan_number < cinfo->output_scan_number ||
+ (cinfo->input_scan_number == cinfo->output_scan_number &&
+ cinfo->input_iMCU_row <= cinfo->output_iMCU_row)) {
+ if ((*cinfo->inputctl->consume_input)(cinfo) == JPEG_SUSPENDED)
+ return JPEG_SUSPENDED;
+ }
+
+ /* OK, output from the virtual arrays. */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Don't bother to IDCT an uninteresting component. */
+ if (! compptr->component_needed)
+ continue;
+ /* Align the virtual buffer for this component. */
+ buffer = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[ci],
+ cinfo->output_iMCU_row * compptr->v_samp_factor,
+ (JDIMENSION) compptr->v_samp_factor, FALSE);
+ /* Count non-dummy DCT block rows in this iMCU row. */
+ if (cinfo->output_iMCU_row < last_iMCU_row)
+ block_rows = compptr->v_samp_factor;
+ else {
+ /* NB: can't use last_row_height here; it is input-side-dependent! */
+ block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor);
+ if (block_rows == 0) block_rows = compptr->v_samp_factor;
+ }
+ inverse_DCT = cinfo->idct->inverse_DCT[ci];
+ output_ptr = output_buf[ci];
+ /* Loop over all DCT blocks to be processed. */
+ for (block_row = 0; block_row < block_rows; block_row++) {
+ buffer_ptr = buffer[block_row];
+ output_col = 0;
+ for (block_num = 0; block_num < compptr->width_in_blocks; block_num++) {
+ (*inverse_DCT) (cinfo, compptr, (JCOEFPTR) buffer_ptr,
+ output_ptr, output_col);
+ buffer_ptr++;
+ output_col += compptr->DCT_scaled_size;
+ }
+ output_ptr += compptr->DCT_scaled_size;
+ }
+ }
+
+ if (++(cinfo->output_iMCU_row) < cinfo->total_iMCU_rows)
+ return JPEG_ROW_COMPLETED;
+ return JPEG_SCAN_COMPLETED;
+}
+
+#endif /* D_MULTISCAN_FILES_SUPPORTED */
+
+
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+
+/*
+ * This code applies interblock smoothing as described by section K.8
+ * of the JPEG standard: the first 5 AC coefficients are estimated from
+ * the DC values of a DCT block and its 8 neighboring blocks.
+ * We apply smoothing only for progressive JPEG decoding, and only if
+ * the coefficients it can estimate are not yet known to full precision.
+ */
+
+/* Natural-order array positions of the first 5 zigzag-order coefficients */
+#define Q01_POS 1
+#define Q10_POS 8
+#define Q20_POS 16
+#define Q11_POS 9
+#define Q02_POS 2
+
+/*
+ * Determine whether block smoothing is applicable and safe.
+ * We also latch the current states of the coef_bits[] entries for the
+ * AC coefficients; otherwise, if the input side of the decompressor
+ * advances into a new scan, we might think the coefficients are known
+ * more accurately than they really are.
+ */
+
+LOCAL(int)
+smoothing_ok (j_decompress_ptr cinfo)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ int smoothing_useful = FALSE;
+ int ci, coefi;
+ jpeg_component_info *compptr;
+ JQUANT_TBL * qtable;
+ int * coef_bits;
+ int * coef_bits_latch;
+
+ if (! cinfo->progressive_mode || cinfo->coef_bits == NULL)
+ return FALSE;
+
+ /* Allocate latch area if not already done */
+ if (coef->coef_bits_latch == NULL)
+ coef->coef_bits_latch = (int *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ cinfo->num_components *
+ (SAVED_COEFS * SIZEOF(int)));
+ coef_bits_latch = coef->coef_bits_latch;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* All components' quantization values must already be latched. */
+ if ((qtable = compptr->quant_table) == NULL)
+ return FALSE;
+ /* Verify DC & first 5 AC quantizers are nonzero to avoid zero-divide. */
+ if (qtable->quantval[0] == 0 ||
+ qtable->quantval[Q01_POS] == 0 ||
+ qtable->quantval[Q10_POS] == 0 ||
+ qtable->quantval[Q20_POS] == 0 ||
+ qtable->quantval[Q11_POS] == 0 ||
+ qtable->quantval[Q02_POS] == 0)
+ return FALSE;
+ /* DC values must be at least partly known for all components. */
+ coef_bits = cinfo->coef_bits[ci];
+ if (coef_bits[0] < 0)
+ return FALSE;
+ /* Block smoothing is helpful if some AC coefficients remain inaccurate. */
+ for (coefi = 1; coefi <= 5; coefi++) {
+ coef_bits_latch[coefi] = coef_bits[coefi];
+ if (coef_bits[coefi] != 0)
+ smoothing_useful = TRUE;
+ }
+ coef_bits_latch += SAVED_COEFS;
+ }
+
+ return smoothing_useful;
+}
+
+
+/*
+ * Variant of decompress_data for use when doing block smoothing.
+ */
+
+METHODDEF(int)
+decompress_smooth_data (j_decompress_ptr cinfo, JSAMPIMAGE output_buf)
+{
+ my_coef_ptr coef = (my_coef_ptr) cinfo->coef;
+ JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1;
+ JDIMENSION block_num, last_block_column;
+ int ci, block_row, block_rows, access_rows;
+ JBLOCKARRAY buffer;
+ JBLOCKROW buffer_ptr, prev_block_row, next_block_row;
+ JSAMPARRAY output_ptr;
+ JDIMENSION output_col;
+ jpeg_component_info *compptr;
+ inverse_DCT_method_ptr inverse_DCT;
+ int first_row, last_row;
+ JBLOCK workspace;
+ int *coef_bits;
+ JQUANT_TBL *quanttbl;
+ long Q00,Q01,Q02,Q10,Q11,Q20, num;
+ int DC1,DC2,DC3,DC4,DC5,DC6,DC7,DC8,DC9;
+ int Al, pred;
+
+ /* Force some input to be done if we are getting ahead of the input. */
+ while (cinfo->input_scan_number <= cinfo->output_scan_number &&
+ ! cinfo->inputctl->eoi_reached) {
+ if (cinfo->input_scan_number == cinfo->output_scan_number) {
+ /* If input is working on current scan, we ordinarily want it to
+ * have completed the current row. But if input scan is DC,
+ * we want it to keep one row ahead so that next block row's DC
+ * values are up to date.
+ */
+ JDIMENSION delta = (cinfo->Ss == 0) ? 1 : 0;
+ if (cinfo->input_iMCU_row > cinfo->output_iMCU_row+delta)
+ break;
+ }
+ if ((*cinfo->inputctl->consume_input)(cinfo) == JPEG_SUSPENDED)
+ return JPEG_SUSPENDED;
+ }
+
+ /* OK, output from the virtual arrays. */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Don't bother to IDCT an uninteresting component. */
+ if (! compptr->component_needed)
+ continue;
+ /* Count non-dummy DCT block rows in this iMCU row. */
+ if (cinfo->output_iMCU_row < last_iMCU_row) {
+ block_rows = compptr->v_samp_factor;
+ access_rows = block_rows * 2; /* this and next iMCU row */
+ last_row = FALSE;
+ } else {
+ /* NB: can't use last_row_height here; it is input-side-dependent! */
+ block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor);
+ if (block_rows == 0) block_rows = compptr->v_samp_factor;
+ access_rows = block_rows; /* this iMCU row only */
+ last_row = TRUE;
+ }
+ /* Align the virtual buffer for this component. */
+ if (cinfo->output_iMCU_row > 0) {
+ access_rows += compptr->v_samp_factor; /* prior iMCU row too */
+ buffer = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[ci],
+ (cinfo->output_iMCU_row - 1) * compptr->v_samp_factor,
+ (JDIMENSION) access_rows, FALSE);
+ buffer += compptr->v_samp_factor; /* point to current iMCU row */
+ first_row = FALSE;
+ } else {
+ buffer = (*cinfo->mem->access_virt_barray)
+ ((j_common_ptr) cinfo, coef->whole_image[ci],
+ (JDIMENSION) 0, (JDIMENSION) access_rows, FALSE);
+ first_row = TRUE;
+ }
+ /* Fetch component-dependent info */
+ coef_bits = coef->coef_bits_latch + (ci * SAVED_COEFS);
+ quanttbl = compptr->quant_table;
+ Q00 = quanttbl->quantval[0];
+ Q01 = quanttbl->quantval[Q01_POS];
+ Q10 = quanttbl->quantval[Q10_POS];
+ Q20 = quanttbl->quantval[Q20_POS];
+ Q11 = quanttbl->quantval[Q11_POS];
+ Q02 = quanttbl->quantval[Q02_POS];
+ inverse_DCT = cinfo->idct->inverse_DCT[ci];
+ output_ptr = output_buf[ci];
+ /* Loop over all DCT blocks to be processed. */
+ for (block_row = 0; block_row < block_rows; block_row++) {
+ buffer_ptr = buffer[block_row];
+ if (first_row && block_row == 0)
+ prev_block_row = buffer_ptr;
+ else
+ prev_block_row = buffer[block_row-1];
+ if (last_row && block_row == block_rows-1)
+ next_block_row = buffer_ptr;
+ else
+ next_block_row = buffer[block_row+1];
+ /* We fetch the surrounding DC values using a sliding-register approach.
+ * Initialize all nine here so as to do the right thing on narrow pics.
+ */
+ DC1 = DC2 = DC3 = (int) prev_block_row[0][0];
+ DC4 = DC5 = DC6 = (int) buffer_ptr[0][0];
+ DC7 = DC8 = DC9 = (int) next_block_row[0][0];
+ output_col = 0;
+ last_block_column = compptr->width_in_blocks - 1;
+ for (block_num = 0; block_num <= last_block_column; block_num++) {
+ /* Fetch current DCT block into workspace so we can modify it. */
+ jcopy_block_row(buffer_ptr, (JBLOCKROW) workspace, (JDIMENSION) 1);
+ /* Update DC values */
+ if (block_num < last_block_column) {
+ DC3 = (int) prev_block_row[1][0];
+ DC6 = (int) buffer_ptr[1][0];
+ DC9 = (int) next_block_row[1][0];
+ }
+ /* Compute coefficient estimates per K.8.
+ * An estimate is applied only if coefficient is still zero,
+ * and is not known to be fully accurate.
+ */
+ /* AC01 */
+ if ((Al=coef_bits[1]) != 0 && workspace[1] == 0) {
+ num = 36 * Q00 * (DC4 - DC6);
+ if (num >= 0) {
+ pred = (int) (((Q01<<7) + num) / (Q01<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ } else {
+ pred = (int) (((Q01<<7) - num) / (Q01<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ pred = -pred;
+ }
+ workspace[1] = (JCOEF) pred;
+ }
+ /* AC10 */
+ if ((Al=coef_bits[2]) != 0 && workspace[8] == 0) {
+ num = 36 * Q00 * (DC2 - DC8);
+ if (num >= 0) {
+ pred = (int) (((Q10<<7) + num) / (Q10<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ } else {
+ pred = (int) (((Q10<<7) - num) / (Q10<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ pred = -pred;
+ }
+ workspace[8] = (JCOEF) pred;
+ }
+ /* AC20 */
+ if ((Al=coef_bits[3]) != 0 && workspace[16] == 0) {
+ num = 9 * Q00 * (DC2 + DC8 - 2*DC5);
+ if (num >= 0) {
+ pred = (int) (((Q20<<7) + num) / (Q20<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ } else {
+ pred = (int) (((Q20<<7) - num) / (Q20<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ pred = -pred;
+ }
+ workspace[16] = (JCOEF) pred;
+ }
+ /* AC11 */
+ if ((Al=coef_bits[4]) != 0 && workspace[9] == 0) {
+ num = 5 * Q00 * (DC1 - DC3 - DC7 + DC9);
+ if (num >= 0) {
+ pred = (int) (((Q11<<7) + num) / (Q11<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ } else {
+ pred = (int) (((Q11<<7) - num) / (Q11<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ pred = -pred;
+ }
+ workspace[9] = (JCOEF) pred;
+ }
+ /* AC02 */
+ if ((Al=coef_bits[5]) != 0 && workspace[2] == 0) {
+ num = 9 * Q00 * (DC4 + DC6 - 2*DC5);
+ if (num >= 0) {
+ pred = (int) (((Q02<<7) + num) / (Q02<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ } else {
+ pred = (int) (((Q02<<7) - num) / (Q02<<8));
+ if (Al > 0 && pred >= (1<<Al))
+ pred = (1<<Al)-1;
+ pred = -pred;
+ }
+ workspace[2] = (JCOEF) pred;
+ }
+ /* OK, do the IDCT */
+ (*inverse_DCT) (cinfo, compptr, (JCOEFPTR) workspace,
+ output_ptr, output_col);
+ /* Advance for next column */
+ DC1 = DC2; DC2 = DC3;
+ DC4 = DC5; DC5 = DC6;
+ DC7 = DC8; DC8 = DC9;
+ buffer_ptr++, prev_block_row++, next_block_row++;
+ output_col += compptr->DCT_scaled_size;
+ }
+ output_ptr += compptr->DCT_scaled_size;
+ }
+ }
+
+ if (++(cinfo->output_iMCU_row) < cinfo->total_iMCU_rows)
+ return JPEG_ROW_COMPLETED;
+ return JPEG_SCAN_COMPLETED;
+}
+
+#endif /* BLOCK_SMOOTHING_SUPPORTED */
+
+
+/*
+ * Initialize coefficient buffer controller.
+ */
+
+GLOBAL(void)
+jinit_d_coef_controller (j_decompress_ptr cinfo, int need_full_buffer)
+{
+ my_coef_ptr coef;
+
+ coef = (my_coef_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_coef_controller));
+ cinfo->coef = (struct jpeg_d_coef_controller *) coef;
+ coef->pub.start_input_pass = start_input_pass;
+ coef->pub.start_output_pass = start_output_pass;
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+ coef->coef_bits_latch = NULL;
+#endif
+
+ /* Create the coefficient buffer. */
+ if (need_full_buffer) {
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+ /* Allocate a full-image virtual array for each component, */
+ /* padded to a multiple of samp_factor DCT blocks in each direction. */
+ /* Note we ask for a pre-zeroed array. */
+ int ci, access_rows;
+ jpeg_component_info *compptr;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ access_rows = compptr->v_samp_factor;
+#ifdef BLOCK_SMOOTHING_SUPPORTED
+ /* If block smoothing could be used, need a bigger window */
+ if (cinfo->progressive_mode)
+ access_rows *= 3;
+#endif
+ coef->whole_image[ci] = (*cinfo->mem->request_virt_barray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, TRUE,
+ (JDIMENSION) jround_up((long) compptr->width_in_blocks,
+ (long) compptr->h_samp_factor),
+ (JDIMENSION) jround_up((long) compptr->height_in_blocks,
+ (long) compptr->v_samp_factor),
+ (JDIMENSION) access_rows);
+ }
+ coef->pub.consume_data = consume_data;
+ coef->pub.decompress_data = decompress_data;
+ coef->pub.coef_arrays = coef->whole_image; /* link to virtual arrays */
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ /* We only need a single-MCU buffer. */
+ JBLOCKROW buffer;
+ int i;
+
+ buffer = (JBLOCKROW)
+ (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ D_MAX_BLOCKS_IN_MCU * SIZEOF(JBLOCK));
+ for (i = 0; i < D_MAX_BLOCKS_IN_MCU; i++) {
+ coef->MCU_buffer[i] = buffer + i;
+ }
+ coef->pub.consume_data = dummy_consume_data;
+ coef->pub.decompress_data = decompress_onepass;
+ coef->pub.coef_arrays = NULL; /* flag for no virtual arrays */
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdcolor.cpp b/ml/dlib/dlib/external/libjpeg/jdcolor.cpp
new file mode 100644
index 000000000..8dd88bfd3
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdcolor.cpp
@@ -0,0 +1,396 @@
+/*
+ * jdcolor.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains output colorspace conversion routines.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_color_deconverter pub; /* public fields */
+
+ /* Private state for YCC->RGB conversion */
+ int * Cr_r_tab; /* => table for Cr to R conversion */
+ int * Cb_b_tab; /* => table for Cb to B conversion */
+ long * Cr_g_tab; /* => table for Cr to G conversion */
+ long * Cb_g_tab; /* => table for Cb to G conversion */
+} my_color_deconverter;
+
+typedef my_color_deconverter * my_cconvert_ptr;
+
+
+/**************** YCbCr -> RGB conversion: most common case **************/
+
+/*
+ * YCbCr is defined per CCIR 601-1, except that Cb and Cr are
+ * normalized to the range 0..MAXJSAMPLE rather than -0.5 .. 0.5.
+ * The conversion equations to be implemented are therefore
+ * R = Y + 1.40200 * Cr
+ * G = Y - 0.34414 * Cb - 0.71414 * Cr
+ * B = Y + 1.77200 * Cb
+ * where Cb and Cr represent the incoming values less CENTERJSAMPLE.
+ * (These numbers are derived from TIFF 6.0 section 21, dated 3-June-92.)
+ *
+ * To avoid floating-point arithmetic, we represent the fractional constants
+ * as integers scaled up by 2^16 (about 4 digits precision); we have to divide
+ * the products by 2^16, with appropriate rounding, to get the correct answer.
+ * Notice that Y, being an integral input, does not contribute any fraction
+ * so it need not participate in the rounding.
+ *
+ * For even more speed, we avoid doing any multiplications in the inner loop
+ * by precalculating the constants times Cb and Cr for all possible values.
+ * For 8-bit JSAMPLEs this is very reasonable (only 256 entries per table);
+ * for 12-bit samples it is still acceptable. It's not very reasonable for
+ * 16-bit samples, but if you want lossless storage you shouldn't be changing
+ * colorspace anyway.
+ * The Cr=>R and Cb=>B values can be rounded to integers in advance; the
+ * values for the G calculation are left scaled up, since we must add them
+ * together before rounding.
+ */
+
+#define SCALEBITS 16 /* speediest right-shift on some machines */
+#define ONE_HALF ((long) 1 << (SCALEBITS-1))
+#define FIX(x) ((long) ((x) * (1L<<SCALEBITS) + 0.5))
+
+
+/*
+ * Initialize tables for YCC->RGB colorspace conversion.
+ */
+
+LOCAL(void)
+build_ycc_rgb_table (j_decompress_ptr cinfo)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int i;
+ long x;
+ SHIFT_TEMPS
+
+ cconvert->Cr_r_tab = (int *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(int));
+ cconvert->Cb_b_tab = (int *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(int));
+ cconvert->Cr_g_tab = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(long));
+ cconvert->Cb_g_tab = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(long));
+
+ for (i = 0, x = -CENTERJSAMPLE; i <= MAXJSAMPLE; i++, x++) {
+ /* i is the actual input pixel value, in the range 0..MAXJSAMPLE */
+ /* The Cb or Cr value we are thinking of is x = i - CENTERJSAMPLE */
+ /* Cr=>R value is nearest int to 1.40200 * x */
+ cconvert->Cr_r_tab[i] = (int)
+ RIGHT_SHIFT(FIX(1.40200) * x + ONE_HALF, SCALEBITS);
+ /* Cb=>B value is nearest int to 1.77200 * x */
+ cconvert->Cb_b_tab[i] = (int)
+ RIGHT_SHIFT(FIX(1.77200) * x + ONE_HALF, SCALEBITS);
+ /* Cr=>G value is scaled-up -0.71414 * x */
+ cconvert->Cr_g_tab[i] = (- FIX(0.71414)) * x;
+ /* Cb=>G value is scaled-up -0.34414 * x */
+ /* We also add in ONE_HALF so that need not do it in inner loop */
+ cconvert->Cb_g_tab[i] = (- FIX(0.34414)) * x + ONE_HALF;
+ }
+}
+
+
+/*
+ * Convert some rows of samples to the output colorspace.
+ *
+ * Note that we change from noninterleaved, one-plane-per-component format
+ * to interleaved-pixel format. The output buffer is therefore three times
+ * as wide as the input buffer.
+ * A starting row offset is provided only for the input buffer. The caller
+ * can easily adjust the passed output_buf value to accommodate any row
+ * offset required on that side.
+ */
+
+METHODDEF(void)
+ycc_rgb_convert (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int y, cb, cr;
+ JSAMPROW outptr;
+ JSAMPROW inptr0, inptr1, inptr2;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->output_width;
+ /* copy these pointers into registers if possible */
+ JSAMPLE * range_limit = cinfo->sample_range_limit;
+ int * Crrtab = cconvert->Cr_r_tab;
+ int * Cbbtab = cconvert->Cb_b_tab;
+ long * Crgtab = cconvert->Cr_g_tab;
+ long * Cbgtab = cconvert->Cb_g_tab;
+ SHIFT_TEMPS
+
+ while (--num_rows >= 0) {
+ inptr0 = input_buf[0][input_row];
+ inptr1 = input_buf[1][input_row];
+ inptr2 = input_buf[2][input_row];
+ input_row++;
+ outptr = *output_buf++;
+ for (col = 0; col < num_cols; col++) {
+ y = GETJSAMPLE(inptr0[col]);
+ cb = GETJSAMPLE(inptr1[col]);
+ cr = GETJSAMPLE(inptr2[col]);
+ /* Range-limiting is essential due to noise introduced by DCT losses. */
+ outptr[RGB_RED] = range_limit[y + Crrtab[cr]];
+ outptr[RGB_GREEN] = range_limit[y +
+ ((int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr],
+ SCALEBITS))];
+ outptr[RGB_BLUE] = range_limit[y + Cbbtab[cb]];
+ outptr += RGB_PIXELSIZE;
+ }
+ }
+}
+
+
+/**************** Cases other than YCbCr -> RGB **************/
+
+
+/*
+ * Color conversion for no colorspace change: just copy the data,
+ * converting from separate-planes to interleaved representation.
+ */
+
+METHODDEF(void)
+null_convert (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows)
+{
+ JSAMPROW inptr, outptr;
+ JDIMENSION count;
+ int num_components = cinfo->num_components;
+ JDIMENSION num_cols = cinfo->output_width;
+ int ci;
+
+ while (--num_rows >= 0) {
+ for (ci = 0; ci < num_components; ci++) {
+ inptr = input_buf[ci][input_row];
+ outptr = output_buf[0] + ci;
+ for (count = num_cols; count > 0; count--) {
+ *outptr = *inptr++; /* needn't bother with GETJSAMPLE() here */
+ outptr += num_components;
+ }
+ }
+ input_row++;
+ output_buf++;
+ }
+}
+
+
+/*
+ * Color conversion for grayscale: just copy the data.
+ * This also works for YCbCr -> grayscale conversion, in which
+ * we just copy the Y (luminance) component and ignore chrominance.
+ */
+
+METHODDEF(void)
+grayscale_convert (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows)
+{
+ jcopy_sample_rows(input_buf[0], (int) input_row, output_buf, 0,
+ num_rows, cinfo->output_width);
+}
+
+
+/*
+ * Convert grayscale to RGB: just duplicate the graylevel three times.
+ * This is provided to support applications that don't want to cope
+ * with grayscale as a separate case.
+ */
+
+METHODDEF(void)
+gray_rgb_convert (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows)
+{
+ JSAMPROW inptr, outptr;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->output_width;
+
+ while (--num_rows >= 0) {
+ inptr = input_buf[0][input_row++];
+ outptr = *output_buf++;
+ for (col = 0; col < num_cols; col++) {
+ /* We can dispense with GETJSAMPLE() here */
+ outptr[RGB_RED] = outptr[RGB_GREEN] = outptr[RGB_BLUE] = inptr[col];
+ outptr += RGB_PIXELSIZE;
+ }
+ }
+}
+
+
+/*
+ * Adobe-style YCCK->CMYK conversion.
+ * We convert YCbCr to R=1-C, G=1-M, and B=1-Y using the same
+ * conversion as above, while passing K (black) unchanged.
+ * We assume build_ycc_rgb_table has been called.
+ */
+
+METHODDEF(void)
+ycck_cmyk_convert (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows)
+{
+ my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert;
+ int y, cb, cr;
+ JSAMPROW outptr;
+ JSAMPROW inptr0, inptr1, inptr2, inptr3;
+ JDIMENSION col;
+ JDIMENSION num_cols = cinfo->output_width;
+ /* copy these pointers into registers if possible */
+ JSAMPLE * range_limit = cinfo->sample_range_limit;
+ int * Crrtab = cconvert->Cr_r_tab;
+ int * Cbbtab = cconvert->Cb_b_tab;
+ long * Crgtab = cconvert->Cr_g_tab;
+ long * Cbgtab = cconvert->Cb_g_tab;
+ SHIFT_TEMPS
+
+ while (--num_rows >= 0) {
+ inptr0 = input_buf[0][input_row];
+ inptr1 = input_buf[1][input_row];
+ inptr2 = input_buf[2][input_row];
+ inptr3 = input_buf[3][input_row];
+ input_row++;
+ outptr = *output_buf++;
+ for (col = 0; col < num_cols; col++) {
+ y = GETJSAMPLE(inptr0[col]);
+ cb = GETJSAMPLE(inptr1[col]);
+ cr = GETJSAMPLE(inptr2[col]);
+ /* Range-limiting is essential due to noise introduced by DCT losses. */
+ outptr[0] = range_limit[MAXJSAMPLE - (y + Crrtab[cr])]; /* red */
+ outptr[1] = range_limit[MAXJSAMPLE - (y + /* green */
+ ((int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr],
+ SCALEBITS)))];
+ outptr[2] = range_limit[MAXJSAMPLE - (y + Cbbtab[cb])]; /* blue */
+ /* K passes through unchanged */
+ outptr[3] = inptr3[col]; /* don't need GETJSAMPLE here */
+ outptr += 4;
+ }
+ }
+}
+
+
+/*
+ * Empty method for start_pass.
+ */
+
+METHODDEF(void)
+start_pass_dcolor (j_decompress_ptr )
+{
+ /* no work needed */
+}
+
+
+/*
+ * Module initialization routine for output colorspace conversion.
+ */
+
+GLOBAL(void)
+jinit_color_deconverter (j_decompress_ptr cinfo)
+{
+ my_cconvert_ptr cconvert;
+ int ci;
+
+ cconvert = (my_cconvert_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_color_deconverter));
+ cinfo->cconvert = (struct jpeg_color_deconverter *) cconvert;
+ cconvert->pub.start_pass = start_pass_dcolor;
+
+ /* Make sure num_components agrees with jpeg_color_space */
+ switch (cinfo->jpeg_color_space) {
+ case JCS_GRAYSCALE:
+ if (cinfo->num_components != 1)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ break;
+
+ case JCS_RGB:
+ case JCS_YCbCr:
+ if (cinfo->num_components != 3)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ break;
+
+ case JCS_CMYK:
+ case JCS_YCCK:
+ if (cinfo->num_components != 4)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ break;
+
+ default: /* JCS_UNKNOWN can be anything */
+ if (cinfo->num_components < 1)
+ ERREXIT(cinfo, JERR_BAD_J_COLORSPACE);
+ break;
+ }
+
+ /* Set out_color_components and conversion method based on requested space.
+ * Also clear the component_needed flags for any unused components,
+ * so that earlier pipeline stages can avoid useless computation.
+ */
+
+ switch (cinfo->out_color_space) {
+ case JCS_GRAYSCALE:
+ cinfo->out_color_components = 1;
+ if (cinfo->jpeg_color_space == JCS_GRAYSCALE ||
+ cinfo->jpeg_color_space == JCS_YCbCr) {
+ cconvert->pub.color_convert = grayscale_convert;
+ /* For color->grayscale conversion, only the Y (0) component is needed */
+ for (ci = 1; ci < cinfo->num_components; ci++)
+ cinfo->comp_info[ci].component_needed = FALSE;
+ } else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_RGB:
+ cinfo->out_color_components = RGB_PIXELSIZE;
+ if (cinfo->jpeg_color_space == JCS_YCbCr) {
+ cconvert->pub.color_convert = ycc_rgb_convert;
+ build_ycc_rgb_table(cinfo);
+ } else if (cinfo->jpeg_color_space == JCS_GRAYSCALE) {
+ cconvert->pub.color_convert = gray_rgb_convert;
+ } else if (cinfo->jpeg_color_space == JCS_RGB && RGB_PIXELSIZE == 3) {
+ cconvert->pub.color_convert = null_convert;
+ } else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ case JCS_CMYK:
+ cinfo->out_color_components = 4;
+ if (cinfo->jpeg_color_space == JCS_YCCK) {
+ cconvert->pub.color_convert = ycck_cmyk_convert;
+ build_ycc_rgb_table(cinfo);
+ } else if (cinfo->jpeg_color_space == JCS_CMYK) {
+ cconvert->pub.color_convert = null_convert;
+ } else
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+
+ default:
+ /* Permit null conversion to same output space */
+ if (cinfo->out_color_space == cinfo->jpeg_color_space) {
+ cinfo->out_color_components = cinfo->num_components;
+ cconvert->pub.color_convert = null_convert;
+ } else /* unsupported non-null conversion */
+ ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL);
+ break;
+ }
+
+ if (cinfo->quantize_colors)
+ cinfo->output_components = 1; /* single colormapped output component */
+ else
+ cinfo->output_components = cinfo->out_color_components;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdct.h b/ml/dlib/dlib/external/libjpeg/jdct.h
new file mode 100644
index 000000000..a89c9550b
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdct.h
@@ -0,0 +1,176 @@
+/*
+ * jdct.h
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This include file contains common declarations for the forward and
+ * inverse DCT modules. These declarations are private to the DCT managers
+ * (jcdctmgr.c, jddctmgr.c) and the individual DCT algorithms.
+ * The individual DCT algorithms are kept in separate files to ease
+ * machine-dependent tuning (e.g., assembly coding).
+ */
+
+
+/*
+ * A forward DCT routine is given a pointer to a work area of type DCTELEM[];
+ * the DCT is to be performed in-place in that buffer. Type DCTELEM is int
+ * for 8-bit samples, long for 12-bit samples. (NOTE: Floating-point DCT
+ * implementations use an array of type FAST_FLOAT, instead.)
+ * The DCT inputs are expected to be signed (range +-CENTERJSAMPLE).
+ * The DCT outputs are returned scaled up by a factor of 8; they therefore
+ * have a range of +-8K for 8-bit data, +-128K for 12-bit data. This
+ * convention improves accuracy in integer implementations and saves some
+ * work in floating-point ones.
+ * Quantization of the output coefficients is done by jcdctmgr.c.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+typedef int DCTELEM; /* 16 or 32 bits is fine */
+#else
+typedef long DCTELEM; /* must have 32 bits */
+#endif
+
+typedef JMETHOD(void, forward_DCT_method_ptr, (DCTELEM * data));
+typedef JMETHOD(void, float_DCT_method_ptr, (FAST_FLOAT * data));
+
+
+/*
+ * An inverse DCT routine is given a pointer to the input JBLOCK and a pointer
+ * to an output sample array. The routine must dequantize the input data as
+ * well as perform the IDCT; for dequantization, it uses the multiplier table
+ * pointed to by compptr->dct_table. The output data is to be placed into the
+ * sample array starting at a specified column. (Any row offset needed will
+ * be applied to the array pointer before it is passed to the IDCT code.)
+ * Note that the number of samples emitted by the IDCT routine is
+ * DCT_scaled_size * DCT_scaled_size.
+ */
+
+/* typedef inverse_DCT_method_ptr is declared in jpegint.h */
+
+/*
+ * Each IDCT routine has its own ideas about the best dct_table element type.
+ */
+
+typedef MULTIPLIER ISLOW_MULT_TYPE; /* short or int, whichever is faster */
+#if BITS_IN_JSAMPLE == 8
+typedef MULTIPLIER IFAST_MULT_TYPE; /* 16 bits is OK, use short if faster */
+#define IFAST_SCALE_BITS 2 /* fractional bits in scale factors */
+#else
+typedef long IFAST_MULT_TYPE; /* need 32 bits for scaled quantizers */
+#define IFAST_SCALE_BITS 13 /* fractional bits in scale factors */
+#endif
+typedef FAST_FLOAT FLOAT_MULT_TYPE; /* preferred floating type */
+
+
+/*
+ * Each IDCT routine is responsible for range-limiting its results and
+ * converting them to unsigned form (0..MAXJSAMPLE). The raw outputs could
+ * be quite far out of range if the input data is corrupt, so a bulletproof
+ * range-limiting step is required. We use a mask-and-table-lookup method
+ * to do the combined operations quickly. See the comments with
+ * prepare_range_limit_table (in jdmaster.c) for more info.
+ */
+
+#define IDCT_range_limit(cinfo) ((cinfo)->sample_range_limit + CENTERJSAMPLE)
+
+#define RANGE_MASK (MAXJSAMPLE * 4 + 3) /* 2 bits wider than legal samples */
+
+
+/* Short forms of external names for systems with brain-damaged linkers. */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_fdct_islow jFDislow
+#define jpeg_fdct_ifast jFDifast
+#define jpeg_fdct_float jFDfloat
+#define jpeg_idct_islow jRDislow
+#define jpeg_idct_ifast jRDifast
+#define jpeg_idct_float jRDfloat
+#define jpeg_idct_4x4 jRD4x4
+#define jpeg_idct_2x2 jRD2x2
+#define jpeg_idct_1x1 jRD1x1
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+/* Extern declarations for the forward and inverse DCT routines. */
+
+EXTERN(void) jpeg_fdct_islow JPP((DCTELEM * data));
+EXTERN(void) jpeg_fdct_ifast JPP((DCTELEM * data));
+EXTERN(void) jpeg_fdct_float JPP((FAST_FLOAT * data));
+
+EXTERN(void) jpeg_idct_islow
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+EXTERN(void) jpeg_idct_ifast
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+EXTERN(void) jpeg_idct_float
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+EXTERN(void) jpeg_idct_4x4
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+EXTERN(void) jpeg_idct_2x2
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+EXTERN(void) jpeg_idct_1x1
+ JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col));
+
+
+/*
+ * Macros for handling fixed-point arithmetic; these are used by many
+ * but not all of the DCT/IDCT modules.
+ *
+ * All values are expected to be of type long.
+ * Fractional constants are scaled left by CONST_BITS bits.
+ * CONST_BITS is defined within each module using these macros,
+ * and may differ from one module to the next.
+ */
+
+#define ONE ((long) 1)
+#define CONST_SCALE (ONE << CONST_BITS)
+
+/* Convert a positive real constant to an integer scaled by CONST_SCALE.
+ * Caution: some C compilers fail to reduce "FIX(constant)" at compile time,
+ * thus causing a lot of useless floating-point operations at run time.
+ */
+
+#define FIX(x) ((long) ((x) * CONST_SCALE + 0.5))
+
+/* Descale and correctly round an long value that's scaled by N bits.
+ * We assume RIGHT_SHIFT rounds towards minus infinity, so adding
+ * the fudge factor is correct for either sign of X.
+ */
+
+#define DESCALE(x,n) RIGHT_SHIFT((x) + (ONE << ((n)-1)), n)
+
+/* Multiply an long variable by an long constant to yield an long result.
+ * This macro is used only when the two inputs will actually be no more than
+ * 16 bits wide, so that a 16x16->32 bit multiply can be used instead of a
+ * full 32x32 multiply. This provides a useful speedup on many machines.
+ * Unfortunately there is no way to specify a 16x16->32 multiply portably
+ * in C, but some C compilers will do the right thing if you provide the
+ * correct combination of casts.
+ */
+
+#ifdef SHORTxSHORT_32 /* may work if 'int' is 32 bits */
+#define MULTIPLY16C16(var,const) (((short) (var)) * ((short) (const)))
+#endif
+#ifdef SHORTxLCONST_32 /* known to work with Microsoft C 6.0 */
+#define MULTIPLY16C16(var,const) (((short) (var)) * ((long) (const)))
+#endif
+
+#ifndef MULTIPLY16C16 /* default definition */
+#define MULTIPLY16C16(var,const) ((var) * (const))
+#endif
+
+/* Same except both inputs are variables. */
+
+#ifdef SHORTxSHORT_32 /* may work if 'int' is 32 bits */
+#define MULTIPLY16V16(var1,var2) (((short) (var1)) * ((short) (var2)))
+#endif
+
+#ifndef MULTIPLY16V16 /* default definition */
+#define MULTIPLY16V16(var1,var2) ((var1) * (var2))
+#endif
diff --git a/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp b/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp
new file mode 100644
index 000000000..620da686d
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp
@@ -0,0 +1,269 @@
+/*
+ * jddctmgr.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the inverse-DCT management logic.
+ * This code selects a particular IDCT implementation to be used,
+ * and it performs related housekeeping chores. No code in this file
+ * is executed per IDCT step, only during output pass setup.
+ *
+ * Note that the IDCT routines are responsible for performing coefficient
+ * dequantization as well as the IDCT proper. This module sets up the
+ * dequantization multiplier table needed by the IDCT routine.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+
+/*
+ * The decompressor input side (jdinput.c) saves away the appropriate
+ * quantization table for each component at the start of the first scan
+ * involving that component. (This is necessary in order to correctly
+ * decode files that reuse Q-table slots.)
+ * When we are ready to make an output pass, the saved Q-table is converted
+ * to a multiplier table that will actually be used by the IDCT routine.
+ * The multiplier table contents are IDCT-method-dependent. To support
+ * application changes in IDCT method between scans, we can remake the
+ * multiplier tables if necessary.
+ * In buffered-image mode, the first output pass may occur before any data
+ * has been seen for some components, and thus before their Q-tables have
+ * been saved away. To handle this case, multiplier tables are preset
+ * to zeroes; the result of the IDCT will be a neutral gray level.
+ */
+
+
+/* Private subobject for this module */
+
+typedef struct {
+ struct jpeg_inverse_dct pub; /* public fields */
+
+ /* This array contains the IDCT method code that each multiplier table
+ * is currently set up for, or -1 if it's not yet set up.
+ * The actual multiplier tables are pointed to by dct_table in the
+ * per-component comp_info structures.
+ */
+ int cur_method[MAX_COMPONENTS];
+} my_idct_controller;
+
+typedef my_idct_controller * my_idct_ptr;
+
+
+/* Allocated multiplier tables: big enough for any supported variant */
+
+typedef union {
+ ISLOW_MULT_TYPE islow_array[DCTSIZE2];
+#ifdef DCT_IFAST_SUPPORTED
+ IFAST_MULT_TYPE ifast_array[DCTSIZE2];
+#endif
+#ifdef DCT_FLOAT_SUPPORTED
+ FLOAT_MULT_TYPE float_array[DCTSIZE2];
+#endif
+} multiplier_table;
+
+
+/* The current scaled-IDCT routines require ISLOW-style multiplier tables,
+ * so be sure to compile that code if either ISLOW or SCALING is requested.
+ */
+#ifdef DCT_ISLOW_SUPPORTED
+#define PROVIDE_ISLOW_TABLES
+#else
+#ifdef IDCT_SCALING_SUPPORTED
+#define PROVIDE_ISLOW_TABLES
+#endif
+#endif
+
+
+/*
+ * Prepare for an output pass.
+ * Here we select the proper IDCT routine for each component and build
+ * a matching multiplier table.
+ */
+
+METHODDEF(void)
+start_pass (j_decompress_ptr cinfo)
+{
+ my_idct_ptr idct = (my_idct_ptr) cinfo->idct;
+ int ci, i;
+ jpeg_component_info *compptr;
+ int method = 0;
+ inverse_DCT_method_ptr method_ptr = NULL;
+ JQUANT_TBL * qtbl;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Select the proper IDCT routine for this component's scaling */
+ switch (compptr->DCT_scaled_size) {
+#ifdef IDCT_SCALING_SUPPORTED
+ case 1:
+ method_ptr = jpeg_idct_1x1;
+ method = JDCT_ISLOW; /* jidctred uses islow-style table */
+ break;
+ case 2:
+ method_ptr = jpeg_idct_2x2;
+ method = JDCT_ISLOW; /* jidctred uses islow-style table */
+ break;
+ case 4:
+ method_ptr = jpeg_idct_4x4;
+ method = JDCT_ISLOW; /* jidctred uses islow-style table */
+ break;
+#endif
+ case DCTSIZE:
+ switch (cinfo->dct_method) {
+#ifdef DCT_ISLOW_SUPPORTED
+ case JDCT_ISLOW:
+ method_ptr = jpeg_idct_islow;
+ method = JDCT_ISLOW;
+ break;
+#endif
+#ifdef DCT_IFAST_SUPPORTED
+ case JDCT_IFAST:
+ method_ptr = jpeg_idct_ifast;
+ method = JDCT_IFAST;
+ break;
+#endif
+#ifdef DCT_FLOAT_SUPPORTED
+ case JDCT_FLOAT:
+ method_ptr = jpeg_idct_float;
+ method = JDCT_FLOAT;
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ break;
+ }
+ break;
+ default:
+ ERREXIT1(cinfo, JERR_BAD_DCTSIZE, compptr->DCT_scaled_size);
+ break;
+ }
+ idct->pub.inverse_DCT[ci] = method_ptr;
+ /* Create multiplier table from quant table.
+ * However, we can skip this if the component is uninteresting
+ * or if we already built the table. Also, if no quant table
+ * has yet been saved for the component, we leave the
+ * multiplier table all-zero; we'll be reading zeroes from the
+ * coefficient controller's buffer anyway.
+ */
+ if (! compptr->component_needed || idct->cur_method[ci] == method)
+ continue;
+ qtbl = compptr->quant_table;
+ if (qtbl == NULL) /* happens if no data yet for component */
+ continue;
+ idct->cur_method[ci] = method;
+ switch (method) {
+#ifdef PROVIDE_ISLOW_TABLES
+ case JDCT_ISLOW:
+ {
+ /* For LL&M IDCT method, multipliers are equal to raw quantization
+ * coefficients, but are stored as ints to ensure access efficiency.
+ */
+ ISLOW_MULT_TYPE * ismtbl = (ISLOW_MULT_TYPE *) compptr->dct_table;
+ for (i = 0; i < DCTSIZE2; i++) {
+ ismtbl[i] = (ISLOW_MULT_TYPE) qtbl->quantval[i];
+ }
+ }
+ break;
+#endif
+#ifdef DCT_IFAST_SUPPORTED
+ case JDCT_IFAST:
+ {
+ /* For AA&N IDCT method, multipliers are equal to quantization
+ * coefficients scaled by scalefactor[row]*scalefactor[col], where
+ * scalefactor[0] = 1
+ * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7
+ * For integer operation, the multiplier table is to be scaled by
+ * IFAST_SCALE_BITS.
+ */
+ IFAST_MULT_TYPE * ifmtbl = (IFAST_MULT_TYPE *) compptr->dct_table;
+#define CONST_BITS 14
+ static const short aanscales[DCTSIZE2] = {
+ /* precomputed values scaled up by 14 bits */
+ 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520,
+ 22725, 31521, 29692, 26722, 22725, 17855, 12299, 6270,
+ 21407, 29692, 27969, 25172, 21407, 16819, 11585, 5906,
+ 19266, 26722, 25172, 22654, 19266, 15137, 10426, 5315,
+ 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520,
+ 12873, 17855, 16819, 15137, 12873, 10114, 6967, 3552,
+ 8867, 12299, 11585, 10426, 8867, 6967, 4799, 2446,
+ 4520, 6270, 5906, 5315, 4520, 3552, 2446, 1247
+ };
+ SHIFT_TEMPS
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ ifmtbl[i] = (IFAST_MULT_TYPE)
+ DESCALE(MULTIPLY16V16((long) qtbl->quantval[i],
+ (long) aanscales[i]),
+ CONST_BITS-IFAST_SCALE_BITS);
+ }
+ }
+ break;
+#endif
+#ifdef DCT_FLOAT_SUPPORTED
+ case JDCT_FLOAT:
+ {
+ /* For float AA&N IDCT method, multipliers are equal to quantization
+ * coefficients scaled by scalefactor[row]*scalefactor[col], where
+ * scalefactor[0] = 1
+ * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7
+ */
+ FLOAT_MULT_TYPE * fmtbl = (FLOAT_MULT_TYPE *) compptr->dct_table;
+ int row, col;
+ static const double aanscalefactor[DCTSIZE] = {
+ 1.0, 1.387039845, 1.306562965, 1.175875602,
+ 1.0, 0.785694958, 0.541196100, 0.275899379
+ };
+
+ i = 0;
+ for (row = 0; row < DCTSIZE; row++) {
+ for (col = 0; col < DCTSIZE; col++) {
+ fmtbl[i] = (FLOAT_MULT_TYPE)
+ ((double) qtbl->quantval[i] *
+ aanscalefactor[row] * aanscalefactor[col]);
+ i++;
+ }
+ }
+ }
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ break;
+ }
+ }
+}
+
+
+/*
+ * Initialize IDCT manager.
+ */
+
+GLOBAL(void)
+jinit_inverse_dct (j_decompress_ptr cinfo)
+{
+ my_idct_ptr idct;
+ int ci;
+ jpeg_component_info *compptr;
+
+ idct = (my_idct_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_idct_controller));
+ cinfo->idct = (struct jpeg_inverse_dct *) idct;
+ idct->pub.start_pass = start_pass;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Allocate and pre-zero a multiplier table for each component */
+ compptr->dct_table =
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(multiplier_table));
+ MEMZERO(compptr->dct_table, SIZEOF(multiplier_table));
+ /* Mark multiplier table not yet set up for any method */
+ idct->cur_method[ci] = -1;
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdhuff.cpp b/ml/dlib/dlib/external/libjpeg/jdhuff.cpp
new file mode 100644
index 000000000..26a2a36f2
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdhuff.cpp
@@ -0,0 +1,654 @@
+/*
+ * jdhuff.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains Huffman entropy decoding routines.
+ *
+ * Much of the complexity here has to do with supporting input suspension.
+ * If the data source module demands suspension, we want to be able to back
+ * up to the start of the current MCU. To do this, we copy state variables
+ * into local working storage, and update them back to the permanent
+ * storage only upon successful completion of an MCU.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdhuff.h" /* Declarations shared with jdphuff.c */
+
+#ifdef __GNUC__
+#pragma GCC diagnostic ignored "-Wshift-negative-value"
+#endif
+
+/*
+ * Expanded entropy decoder object for Huffman decoding.
+ *
+ * The savable_state subrecord contains fields that change within an MCU,
+ * but must not be updated permanently until we complete the MCU.
+ */
+
+typedef struct {
+ int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */
+} savable_state;
+
+/* This macro is to work around compilers with missing or broken
+ * structure assignment. You'll need to fix this code if you have
+ * such a compiler and you change MAX_COMPS_IN_SCAN.
+ */
+
+#ifndef NO_STRUCT_ASSIGN
+#define ASSIGN_STATE(dest,src) ((dest) = (src))
+#else
+#if MAX_COMPS_IN_SCAN == 4
+#define ASSIGN_STATE(dest,src) \
+ ((dest).last_dc_val[0] = (src).last_dc_val[0], \
+ (dest).last_dc_val[1] = (src).last_dc_val[1], \
+ (dest).last_dc_val[2] = (src).last_dc_val[2], \
+ (dest).last_dc_val[3] = (src).last_dc_val[3])
+#endif
+#endif
+
+
+typedef struct {
+ struct jpeg_entropy_decoder pub; /* public fields */
+
+ /* These fields are loaded into local variables at start of each MCU.
+ * In case of suspension, we exit WITHOUT updating them.
+ */
+ bitread_perm_state bitstate; /* Bit buffer at start of MCU */
+ savable_state saved; /* Other state at start of MCU */
+
+ /* These fields are NOT loaded into local working state. */
+ unsigned int restarts_to_go; /* MCUs left in this restart interval */
+
+ /* Pointers to derived tables (these workspaces have image lifespan) */
+ d_derived_tbl * dc_derived_tbls[NUM_HUFF_TBLS];
+ d_derived_tbl * ac_derived_tbls[NUM_HUFF_TBLS];
+
+ /* Precalculated info set up by start_pass for use in decode_mcu: */
+
+ /* Pointers to derived tables to be used for each block within an MCU */
+ d_derived_tbl * dc_cur_tbls[D_MAX_BLOCKS_IN_MCU];
+ d_derived_tbl * ac_cur_tbls[D_MAX_BLOCKS_IN_MCU];
+ /* Whether we care about the DC and AC coefficient values for each block */
+ int dc_needed[D_MAX_BLOCKS_IN_MCU];
+ int ac_needed[D_MAX_BLOCKS_IN_MCU];
+} huff_entropy_decoder;
+
+typedef huff_entropy_decoder * huff_entropy_ptr;
+
+
+/*
+ * Initialize for a Huffman-compressed scan.
+ */
+
+METHODDEF(void)
+start_pass_huff_decoder (j_decompress_ptr cinfo)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int ci, blkn, dctbl, actbl;
+ jpeg_component_info * compptr;
+
+ /* Check that the scan parameters Ss, Se, Ah/Al are OK for sequential JPEG.
+ * This ought to be an error condition, but we make it a warning because
+ * there are some baseline files out there with all zeroes in these bytes.
+ */
+ if (cinfo->Ss != 0 || cinfo->Se != DCTSIZE2-1 ||
+ cinfo->Ah != 0 || cinfo->Al != 0)
+ WARNMS(cinfo, JWRN_NOT_SEQUENTIAL);
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ dctbl = compptr->dc_tbl_no;
+ actbl = compptr->ac_tbl_no;
+ /* Compute derived values for Huffman tables */
+ /* We may do this more than once for a table, but it's not expensive */
+ jpeg_make_d_derived_tbl(cinfo, TRUE, dctbl,
+ & entropy->dc_derived_tbls[dctbl]);
+ jpeg_make_d_derived_tbl(cinfo, FALSE, actbl,
+ & entropy->ac_derived_tbls[actbl]);
+ /* Initialize DC predictions to 0 */
+ entropy->saved.last_dc_val[ci] = 0;
+ }
+
+ /* Precalculate decoding info for each block in an MCU of this scan */
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ ci = cinfo->MCU_membership[blkn];
+ compptr = cinfo->cur_comp_info[ci];
+ /* Precalculate which table to use for each block */
+ entropy->dc_cur_tbls[blkn] = entropy->dc_derived_tbls[compptr->dc_tbl_no];
+ entropy->ac_cur_tbls[blkn] = entropy->ac_derived_tbls[compptr->ac_tbl_no];
+ /* Decide whether we really care about the coefficient values */
+ if (compptr->component_needed) {
+ entropy->dc_needed[blkn] = TRUE;
+ /* we don't need the ACs if producing a 1/8th-size image */
+ entropy->ac_needed[blkn] = (compptr->DCT_scaled_size > 1);
+ } else {
+ entropy->dc_needed[blkn] = entropy->ac_needed[blkn] = FALSE;
+ }
+ }
+
+ /* Initialize bitread state variables */
+ entropy->bitstate.bits_left = 0;
+ entropy->bitstate.get_buffer = 0; /* unnecessary, but keeps Purify quiet */
+ entropy->pub.insufficient_data = FALSE;
+
+ /* Initialize restart counter */
+ entropy->restarts_to_go = cinfo->restart_interval;
+}
+
+
+/*
+ * Compute the derived values for a Huffman table.
+ * This routine also performs some validation checks on the table.
+ *
+ * Note this is also used by jdphuff.c.
+ */
+
+GLOBAL(void)
+jpeg_make_d_derived_tbl (j_decompress_ptr cinfo, int isDC, int tblno,
+ d_derived_tbl ** pdtbl)
+{
+ JHUFF_TBL *htbl;
+ d_derived_tbl *dtbl;
+ int p, i, l, si, numsymbols;
+ int lookbits, ctr;
+ char huffsize[257];
+ unsigned int huffcode[257];
+ unsigned int code;
+
+ /* Note that huffsize[] and huffcode[] are filled in code-length order,
+ * paralleling the order of the symbols themselves in htbl->huffval[].
+ */
+
+ /* Find the input Huffman table */
+ if (tblno < 0 || tblno >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno);
+ htbl =
+ isDC ? cinfo->dc_huff_tbl_ptrs[tblno] : cinfo->ac_huff_tbl_ptrs[tblno];
+ if (htbl == NULL)
+ ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno);
+
+ /* Allocate a workspace if we haven't already done so. */
+ if (*pdtbl == NULL)
+ *pdtbl = (d_derived_tbl *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(d_derived_tbl));
+ dtbl = *pdtbl;
+ dtbl->pub = htbl; /* fill in back link */
+
+ /* Figure C.1: make table of Huffman code length for each symbol */
+
+ p = 0;
+ for (l = 1; l <= 16; l++) {
+ i = (int) htbl->bits[l];
+ if (i < 0 || p + i > 256) /* protect against table overrun */
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ while (i--)
+ huffsize[p++] = (char) l;
+ }
+ huffsize[p] = 0;
+ numsymbols = p;
+
+ /* Figure C.2: generate the codes themselves */
+ /* We also validate that the counts represent a legal Huffman code tree. */
+
+ code = 0;
+ si = huffsize[0];
+ p = 0;
+ while (huffsize[p]) {
+ while (((int) huffsize[p]) == si) {
+ huffcode[p++] = code;
+ code++;
+ }
+ /* code is now 1 more than the last code used for codelength si; but
+ * it must still fit in si bits, since no code is allowed to be all ones.
+ */
+ if (((long) code) >= (((long) 1) << si))
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ code <<= 1;
+ si++;
+ }
+
+ /* Figure F.15: generate decoding tables for bit-sequential decoding */
+
+ p = 0;
+ for (l = 1; l <= 16; l++) {
+ if (htbl->bits[l]) {
+ /* valoffset[l] = huffval[] index of 1st symbol of code length l,
+ * minus the minimum code of length l
+ */
+ dtbl->valoffset[l] = (long) p - (long) huffcode[p];
+ p += htbl->bits[l];
+ dtbl->maxcode[l] = huffcode[p-1]; /* maximum code of length l */
+ } else {
+ dtbl->maxcode[l] = -1; /* -1 if no codes of this length */
+ }
+ }
+ dtbl->maxcode[17] = 0xFFFFFL; /* ensures jpeg_huff_decode terminates */
+
+ /* Compute lookahead tables to speed up decoding.
+ * First we set all the table entries to 0, indicating "too long";
+ * then we iterate through the Huffman codes that are short enough and
+ * fill in all the entries that correspond to bit sequences starting
+ * with that code.
+ */
+
+ MEMZERO(dtbl->look_nbits, SIZEOF(dtbl->look_nbits));
+
+ p = 0;
+ for (l = 1; l <= HUFF_LOOKAHEAD; l++) {
+ for (i = 1; i <= (int) htbl->bits[l]; i++, p++) {
+ /* l = current code's length, p = its index in huffcode[] & huffval[]. */
+ /* Generate left-justified code followed by all possible bit sequences */
+ lookbits = huffcode[p] << (HUFF_LOOKAHEAD-l);
+ for (ctr = 1 << (HUFF_LOOKAHEAD-l); ctr > 0; ctr--) {
+ dtbl->look_nbits[lookbits] = l;
+ dtbl->look_sym[lookbits] = htbl->huffval[p];
+ lookbits++;
+ }
+ }
+ }
+
+ /* Validate symbols as being reasonable.
+ * For AC tables, we make no check, but accept all byte values 0..255.
+ * For DC tables, we require the symbols to be in range 0..15.
+ * (Tighter bounds could be applied depending on the data depth and mode,
+ * but this is sufficient to ensure safe decoding.)
+ */
+ if (isDC) {
+ for (i = 0; i < numsymbols; i++) {
+ int sym = htbl->huffval[i];
+ if (sym < 0 || sym > 15)
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+ }
+ }
+}
+
+
+/*
+ * Out-of-line code for bit fetching (shared with jdphuff.c).
+ * See jdhuff.h for info about usage.
+ * Note: current values of get_buffer and bits_left are passed as parameters,
+ * but are returned in the corresponding fields of the state struct.
+ *
+ * On most machines MIN_GET_BITS should be 25 to allow the full 32-bit width
+ * of get_buffer to be used. (On machines with wider words, an even larger
+ * buffer could be used.) However, on some machines 32-bit shifts are
+ * quite slow and take time proportional to the number of places shifted.
+ * (This is true with most PC compilers, for instance.) In this case it may
+ * be a win to set MIN_GET_BITS to the minimum value of 15. This reduces the
+ * average shift distance at the cost of more calls to jpeg_fill_bit_buffer.
+ */
+
+#ifdef SLOW_SHIFT_32
+#define MIN_GET_BITS 15 /* minimum allowable value */
+#else
+#define MIN_GET_BITS (BIT_BUF_SIZE-7)
+#endif
+
+
+GLOBAL(int)
+jpeg_fill_bit_buffer (bitread_working_state * state,
+ bit_buf_type get_buffer, register int bits_left,
+ int nbits)
+/* Load up the bit buffer to a depth of at least nbits */
+{
+ /* Copy heavily used state fields into locals (hopefully registers) */
+ const JOCTET * next_input_byte = state->next_input_byte;
+ size_t bytes_in_buffer = state->bytes_in_buffer;
+ j_decompress_ptr cinfo = state->cinfo;
+
+ /* Attempt to load at least MIN_GET_BITS bits into get_buffer. */
+ /* (It is assumed that no request will be for more than that many bits.) */
+ /* We fail to do so only if we hit a marker or are forced to suspend. */
+
+ if (cinfo->unread_marker == 0) { /* cannot advance past a marker */
+ while (bits_left < MIN_GET_BITS) {
+ int c;
+
+ /* Attempt to read a byte */
+ if (bytes_in_buffer == 0) {
+ if (! (*cinfo->src->fill_input_buffer) (cinfo))
+ return FALSE;
+ next_input_byte = cinfo->src->next_input_byte;
+ bytes_in_buffer = cinfo->src->bytes_in_buffer;
+ }
+ bytes_in_buffer--;
+ c = GETJOCTET(*next_input_byte++);
+
+ /* If it's 0xFF, check and discard stuffed zero byte */
+ if (c == 0xFF) {
+ /* Loop here to discard any padding FF's on terminating marker,
+ * so that we can save a valid unread_marker value. NOTE: we will
+ * accept multiple FF's followed by a 0 as meaning a single FF data
+ * byte. This data pattern is not valid according to the standard.
+ */
+ do {
+ if (bytes_in_buffer == 0) {
+ if (! (*cinfo->src->fill_input_buffer) (cinfo))
+ return FALSE;
+ next_input_byte = cinfo->src->next_input_byte;
+ bytes_in_buffer = cinfo->src->bytes_in_buffer;
+ }
+ bytes_in_buffer--;
+ c = GETJOCTET(*next_input_byte++);
+ } while (c == 0xFF);
+
+ if (c == 0) {
+ /* Found FF/00, which represents an FF data byte */
+ c = 0xFF;
+ } else {
+ /* Oops, it's actually a marker indicating end of compressed data.
+ * Save the marker code for later use.
+ * Fine point: it might appear that we should save the marker into
+ * bitread working state, not straight into permanent state. But
+ * once we have hit a marker, we cannot need to suspend within the
+ * current MCU, because we will read no more bytes from the data
+ * source. So it is OK to update permanent state right away.
+ */
+ cinfo->unread_marker = c;
+ /* See if we need to insert some fake zero bits. */
+ goto no_more_bytes;
+ }
+ }
+
+ /* OK, load c into get_buffer */
+ get_buffer = (get_buffer << 8) | c;
+ bits_left += 8;
+ } /* end while */
+ } else {
+ no_more_bytes:
+ /* We get here if we've read the marker that terminates the compressed
+ * data segment. There should be enough bits in the buffer register
+ * to satisfy the request; if so, no problem.
+ */
+ if (nbits > bits_left) {
+ /* Uh-oh. Report corrupted data to user and stuff zeroes into
+ * the data stream, so that we can produce some kind of image.
+ * We use a nonvolatile flag to ensure that only one warning message
+ * appears per data segment.
+ */
+ if (! cinfo->entropy->insufficient_data) {
+ WARNMS(cinfo, JWRN_HIT_MARKER);
+ cinfo->entropy->insufficient_data = TRUE;
+ }
+ /* Fill the buffer with zero bits */
+ get_buffer <<= MIN_GET_BITS - bits_left;
+ bits_left = MIN_GET_BITS;
+ }
+ }
+
+ /* Unload the local registers */
+ state->next_input_byte = next_input_byte;
+ state->bytes_in_buffer = bytes_in_buffer;
+ state->get_buffer = get_buffer;
+ state->bits_left = bits_left;
+
+ return TRUE;
+}
+
+
+/*
+ * Out-of-line code for Huffman code decoding.
+ * See jdhuff.h for info about usage.
+ */
+
+GLOBAL(int)
+jpeg_huff_decode (bitread_working_state * state,
+ bit_buf_type get_buffer, register int bits_left,
+ d_derived_tbl * htbl, int min_bits)
+{
+ int l = min_bits;
+ long code;
+
+ /* HUFF_DECODE has determined that the code is at least min_bits */
+ /* bits long, so fetch that many bits in one swoop. */
+
+ CHECK_BIT_BUFFER(*state, l, return -1);
+ code = GET_BITS(l);
+
+ /* Collect the rest of the Huffman code one bit at a time. */
+ /* This is per Figure F.16 in the JPEG spec. */
+
+ while (code > htbl->maxcode[l]) {
+ code <<= 1;
+ CHECK_BIT_BUFFER(*state, 1, return -1);
+ code |= GET_BITS(1);
+ l++;
+ }
+
+ /* Unload the local registers */
+ state->get_buffer = get_buffer;
+ state->bits_left = bits_left;
+
+ /* With garbage input we may reach the sentinel value l = 17. */
+
+ if (l > 16) {
+ WARNMS(state->cinfo, JWRN_HUFF_BAD_CODE);
+ return 0; /* fake a zero as the safest result */
+ }
+
+ return htbl->pub->huffval[ (int) (code + htbl->valoffset[l]) ];
+}
+
+
+/*
+ * Figure F.12: extend sign bit.
+ * On some machines, a shift and add will be faster than a table lookup.
+ */
+
+#ifdef AVOID_TABLES
+
+#define HUFF_EXTEND(x,s) ((x) < (1<<((s)-1)) ? (x) + (((-1)<<(s)) + 1) : (x))
+
+#else
+
+#define HUFF_EXTEND(x,s) ((x) < extend_test[s] ? (x) + extend_offset[s] : (x))
+
+static const int extend_test[16] = /* entry n is 2**(n-1) */
+ { 0, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080,
+ 0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000 };
+
+static const int extend_offset[16] = /* entry n is (-1 << n) + 1 */
+ { 0, ((-1)<<1) + 1, ((-1)<<2) + 1, ((-1)<<3) + 1, ((-1)<<4) + 1,
+ ((-1)<<5) + 1, ((-1)<<6) + 1, ((-1)<<7) + 1, ((-1)<<8) + 1,
+ ((-1)<<9) + 1, ((-1)<<10) + 1, ((-1)<<11) + 1, ((-1)<<12) + 1,
+ ((-1)<<13) + 1, ((-1)<<14) + 1, ((-1)<<15) + 1 };
+
+#endif /* AVOID_TABLES */
+
+
+/*
+ * Check for a restart marker & resynchronize decoder.
+ * Returns FALSE if must suspend.
+ */
+
+LOCAL(int)
+process_restart (j_decompress_ptr cinfo)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int ci;
+
+ /* Throw away any unused bits remaining in bit buffer; */
+ /* include any full bytes in next_marker's count of discarded bytes */
+ cinfo->marker->discarded_bytes += entropy->bitstate.bits_left / 8;
+ entropy->bitstate.bits_left = 0;
+
+ /* Advance past the RSTn marker */
+ if (! (*cinfo->marker->read_restart_marker) (cinfo))
+ return FALSE;
+
+ /* Re-initialize DC predictions to 0 */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++)
+ entropy->saved.last_dc_val[ci] = 0;
+
+ /* Reset restart counter */
+ entropy->restarts_to_go = cinfo->restart_interval;
+
+ /* Reset out-of-data flag, unless read_restart_marker left us smack up
+ * against a marker. In that case we will end up treating the next data
+ * segment as empty, and we can avoid producing bogus output pixels by
+ * leaving the flag set.
+ */
+ if (cinfo->unread_marker == 0)
+ entropy->pub.insufficient_data = FALSE;
+
+ return TRUE;
+}
+
+
+/*
+ * Decode and return one MCU's worth of Huffman-compressed coefficients.
+ * The coefficients are reordered from zigzag order into natural array order,
+ * but are not dequantized.
+ *
+ * The i'th block of the MCU is stored into the block pointed to by
+ * MCU_data[i]. WE ASSUME THIS AREA HAS BEEN ZEROED BY THE CALLER.
+ * (Wholesale zeroing is usually a little faster than retail...)
+ *
+ * Returns FALSE if data source requested suspension. In that case no
+ * changes have been made to permanent state. (Exception: some output
+ * coefficients may already have been assigned. This is harmless for
+ * this module, since we'll just re-assign them on the next call.)
+ */
+
+METHODDEF(int)
+decode_mcu (j_decompress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy;
+ int blkn;
+ BITREAD_STATE_VARS;
+ savable_state state;
+
+ /* Process restart marker if needed; may have to suspend */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! process_restart(cinfo))
+ return FALSE;
+ }
+
+ /* If we've run out of data, just leave the MCU set to zeroes.
+ * This way, we return uniform gray for the remainder of the segment.
+ */
+ if (! entropy->pub.insufficient_data) {
+
+ /* Load up working state */
+ BITREAD_LOAD_STATE(cinfo,entropy->bitstate);
+ ASSIGN_STATE(state, entropy->saved);
+
+ /* Outer loop handles each block in the MCU */
+
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ JBLOCKROW block = MCU_data[blkn];
+ d_derived_tbl * dctbl = entropy->dc_cur_tbls[blkn];
+ d_derived_tbl * actbl = entropy->ac_cur_tbls[blkn];
+ int s, k, r;
+
+ /* Decode a single block's worth of coefficients */
+
+ /* Section F.2.2.1: decode the DC coefficient difference */
+ HUFF_DECODE(s, br_state, dctbl, return FALSE, label1);
+ if (s) {
+ CHECK_BIT_BUFFER(br_state, s, return FALSE);
+ r = GET_BITS(s);
+ s = HUFF_EXTEND(r, s);
+ }
+
+ if (entropy->dc_needed[blkn]) {
+ /* Convert DC difference to actual value, update last_dc_val */
+ int ci = cinfo->MCU_membership[blkn];
+ s += state.last_dc_val[ci];
+ state.last_dc_val[ci] = s;
+ /* Output the DC coefficient (assumes jpeg_natural_order[0] = 0) */
+ (*block)[0] = (JCOEF) s;
+ }
+
+ if (entropy->ac_needed[blkn]) {
+
+ /* Section F.2.2.2: decode the AC coefficients */
+ /* Since zeroes are skipped, output area must be cleared beforehand */
+ for (k = 1; k < DCTSIZE2; k++) {
+ HUFF_DECODE(s, br_state, actbl, return FALSE, label2);
+
+ r = s >> 4;
+ s &= 15;
+
+ if (s) {
+ k += r;
+ CHECK_BIT_BUFFER(br_state, s, return FALSE);
+ r = GET_BITS(s);
+ s = HUFF_EXTEND(r, s);
+ /* Output coefficient in natural (dezigzagged) order.
+ * Note: the extra entries in jpeg_natural_order[] will save us
+ * if k >= DCTSIZE2, which could happen if the data is corrupted.
+ */
+ (*block)[jpeg_natural_order[k]] = (JCOEF) s;
+ } else {
+ if (r != 15)
+ break;
+ k += 15;
+ }
+ }
+
+ } else {
+
+ /* Section F.2.2.2: decode the AC coefficients */
+ /* In this path we just discard the values */
+ for (k = 1; k < DCTSIZE2; k++) {
+ HUFF_DECODE(s, br_state, actbl, return FALSE, label3);
+
+ r = s >> 4;
+ s &= 15;
+
+ if (s) {
+ k += r;
+ CHECK_BIT_BUFFER(br_state, s, return FALSE);
+ DROP_BITS(s);
+ } else {
+ if (r != 15)
+ break;
+ k += 15;
+ }
+ }
+
+ }
+ }
+
+ /* Completed MCU, so update state */
+ BITREAD_SAVE_STATE(cinfo,entropy->bitstate);
+ ASSIGN_STATE(entropy->saved, state);
+ }
+
+ /* Account for restart interval (no-op if not using restarts) */
+ entropy->restarts_to_go--;
+
+ return TRUE;
+}
+
+
+/*
+ * Module initialization routine for Huffman entropy decoding.
+ */
+
+GLOBAL(void)
+jinit_huff_decoder (j_decompress_ptr cinfo)
+{
+ huff_entropy_ptr entropy;
+ int i;
+
+ entropy = (huff_entropy_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(huff_entropy_decoder));
+ cinfo->entropy = (struct jpeg_entropy_decoder *) entropy;
+ entropy->pub.start_pass = start_pass_huff_decoder;
+ entropy->pub.decode_mcu = decode_mcu;
+
+ /* Mark tables unallocated */
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ entropy->dc_derived_tbls[i] = entropy->ac_derived_tbls[i] = NULL;
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdhuff.h b/ml/dlib/dlib/external/libjpeg/jdhuff.h
new file mode 100644
index 000000000..6a0e939af
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdhuff.h
@@ -0,0 +1,201 @@
+/*
+ * jdhuff.h
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains declarations for Huffman entropy decoding routines
+ * that are shared between the sequential decoder (jdhuff.c) and the
+ * progressive decoder (jdphuff.c). No other modules need to see these.
+ */
+
+/* Short forms of external names for systems with brain-damaged linkers. */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_make_d_derived_tbl jMkDDerived
+#define jpeg_fill_bit_buffer jFilBitBuf
+#define jpeg_huff_decode jHufDecode
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+
+/* Derived data constructed for each Huffman table */
+
+#define HUFF_LOOKAHEAD 8 /* # of bits of lookahead */
+
+typedef struct {
+ /* Basic tables: (element [0] of each array is unused) */
+ long maxcode[18]; /* largest code of length k (-1 if none) */
+ /* (maxcode[17] is a sentinel to ensure jpeg_huff_decode terminates) */
+ long valoffset[17]; /* huffval[] offset for codes of length k */
+ /* valoffset[k] = huffval[] index of 1st symbol of code length k, less
+ * the smallest code of length k; so given a code of length k, the
+ * corresponding symbol is huffval[code + valoffset[k]]
+ */
+
+ /* Link to public Huffman table (needed only in jpeg_huff_decode) */
+ JHUFF_TBL *pub;
+
+ /* Lookahead tables: indexed by the next HUFF_LOOKAHEAD bits of
+ * the input data stream. If the next Huffman code is no more
+ * than HUFF_LOOKAHEAD bits long, we can obtain its length and
+ * the corresponding symbol directly from these tables.
+ */
+ int look_nbits[1<<HUFF_LOOKAHEAD]; /* # bits, or 0 if too long */
+ unsigned char look_sym[1<<HUFF_LOOKAHEAD]; /* symbol, or unused */
+} d_derived_tbl;
+
+/* Expand a Huffman table definition into the derived format */
+EXTERN(void) jpeg_make_d_derived_tbl
+ JPP((j_decompress_ptr cinfo, int isDC, int tblno,
+ d_derived_tbl ** pdtbl));
+
+
+/*
+ * Fetching the next N bits from the input stream is a time-critical operation
+ * for the Huffman decoders. We implement it with a combination of inline
+ * macros and out-of-line subroutines. Note that N (the number of bits
+ * demanded at one time) never exceeds 15 for JPEG use.
+ *
+ * We read source bytes into get_buffer and dole out bits as needed.
+ * If get_buffer already contains enough bits, they are fetched in-line
+ * by the macros CHECK_BIT_BUFFER and GET_BITS. When there aren't enough
+ * bits, jpeg_fill_bit_buffer is called; it will attempt to fill get_buffer
+ * as full as possible (not just to the number of bits needed; this
+ * prefetching reduces the overhead cost of calling jpeg_fill_bit_buffer).
+ * Note that jpeg_fill_bit_buffer may return FALSE to indicate suspension.
+ * On TRUE return, jpeg_fill_bit_buffer guarantees that get_buffer contains
+ * at least the requested number of bits --- dummy zeroes are inserted if
+ * necessary.
+ */
+
+typedef long bit_buf_type; /* type of bit-extraction buffer */
+#define BIT_BUF_SIZE 32 /* size of buffer in bits */
+
+/* If long is > 32 bits on your machine, and shifting/masking longs is
+ * reasonably fast, making bit_buf_type be long and setting BIT_BUF_SIZE
+ * appropriately should be a win. Unfortunately we can't define the size
+ * with something like #define BIT_BUF_SIZE (sizeof(bit_buf_type)*8)
+ * because not all machines measure sizeof in 8-bit bytes.
+ */
+
+typedef struct { /* Bitreading state saved across MCUs */
+ bit_buf_type get_buffer; /* current bit-extraction buffer */
+ int bits_left; /* # of unused bits in it */
+} bitread_perm_state;
+
+typedef struct { /* Bitreading working state within an MCU */
+ /* Current data source location */
+ /* We need a copy, rather than munging the original, in case of suspension */
+ const JOCTET * next_input_byte; /* => next byte to read from source */
+ size_t bytes_in_buffer; /* # of bytes remaining in source buffer */
+ /* Bit input buffer --- note these values are kept in variables,
+ * not in this struct, inside the inner loops.
+ */
+ bit_buf_type get_buffer; /* current bit-extraction buffer */
+ int bits_left; /* # of unused bits in it */
+ /* Pointer needed by jpeg_fill_bit_buffer. */
+ j_decompress_ptr cinfo; /* back link to decompress master record */
+} bitread_working_state;
+
+/* Macros to declare and load/save bitread local variables. */
+#define BITREAD_STATE_VARS \
+ bit_buf_type get_buffer; \
+ int bits_left; \
+ bitread_working_state br_state
+
+#define BITREAD_LOAD_STATE(cinfop,permstate) \
+ br_state.cinfo = cinfop; \
+ br_state.next_input_byte = cinfop->src->next_input_byte; \
+ br_state.bytes_in_buffer = cinfop->src->bytes_in_buffer; \
+ get_buffer = permstate.get_buffer; \
+ bits_left = permstate.bits_left;
+
+#define BITREAD_SAVE_STATE(cinfop,permstate) \
+ cinfop->src->next_input_byte = br_state.next_input_byte; \
+ cinfop->src->bytes_in_buffer = br_state.bytes_in_buffer; \
+ permstate.get_buffer = get_buffer; \
+ permstate.bits_left = bits_left
+
+/*
+ * These macros provide the in-line portion of bit fetching.
+ * Use CHECK_BIT_BUFFER to ensure there are N bits in get_buffer
+ * before using GET_BITS, PEEK_BITS, or DROP_BITS.
+ * The variables get_buffer and bits_left are assumed to be locals,
+ * but the state struct might not be (jpeg_huff_decode needs this).
+ * CHECK_BIT_BUFFER(state,n,action);
+ * Ensure there are N bits in get_buffer; if suspend, take action.
+ * val = GET_BITS(n);
+ * Fetch next N bits.
+ * val = PEEK_BITS(n);
+ * Fetch next N bits without removing them from the buffer.
+ * DROP_BITS(n);
+ * Discard next N bits.
+ * The value N should be a simple variable, not an expression, because it
+ * is evaluated multiple times.
+ */
+
+#define CHECK_BIT_BUFFER(state,nbits,action) \
+ { if (bits_left < (nbits)) { \
+ if (! jpeg_fill_bit_buffer(&(state),get_buffer,bits_left,nbits)) \
+ { action; } \
+ get_buffer = (state).get_buffer; bits_left = (state).bits_left; } }
+
+#define GET_BITS(nbits) \
+ (((int) (get_buffer >> (bits_left -= (nbits)))) & ((1<<(nbits))-1))
+
+#define PEEK_BITS(nbits) \
+ (((int) (get_buffer >> (bits_left - (nbits)))) & ((1<<(nbits))-1))
+
+#define DROP_BITS(nbits) \
+ (bits_left -= (nbits))
+
+/* Load up the bit buffer to a depth of at least nbits */
+EXTERN(int) jpeg_fill_bit_buffer
+ JPP((bitread_working_state * state, bit_buf_type get_buffer,
+ int bits_left, int nbits));
+
+
+/*
+ * Code for extracting next Huffman-coded symbol from input bit stream.
+ * Again, this is time-critical and we make the main paths be macros.
+ *
+ * We use a lookahead table to process codes of up to HUFF_LOOKAHEAD bits
+ * without looping. Usually, more than 95% of the Huffman codes will be 8
+ * or fewer bits long. The few overlength codes are handled with a loop,
+ * which need not be inline code.
+ *
+ * Notes about the HUFF_DECODE macro:
+ * 1. Near the end of the data segment, we may fail to get enough bits
+ * for a lookahead. In that case, we do it the hard way.
+ * 2. If the lookahead table contains no entry, the next code must be
+ * more than HUFF_LOOKAHEAD bits long.
+ * 3. jpeg_huff_decode returns -1 if forced to suspend.
+ */
+
+#define HUFF_DECODE(result,state,htbl,failaction,slowlabel) \
+{ int nb, look; \
+ if (bits_left < HUFF_LOOKAHEAD) { \
+ if (! jpeg_fill_bit_buffer(&state,get_buffer,bits_left, 0)) {failaction;} \
+ get_buffer = state.get_buffer; bits_left = state.bits_left; \
+ if (bits_left < HUFF_LOOKAHEAD) { \
+ nb = 1; goto slowlabel; \
+ } \
+ } \
+ look = PEEK_BITS(HUFF_LOOKAHEAD); \
+ if ((nb = htbl->look_nbits[look]) != 0) { \
+ DROP_BITS(nb); \
+ result = htbl->look_sym[look]; \
+ } else { \
+ nb = HUFF_LOOKAHEAD+1; \
+slowlabel: \
+ if ((result=jpeg_huff_decode(&state,get_buffer,bits_left,htbl,nb)) < 0) \
+ { failaction; } \
+ get_buffer = state.get_buffer; bits_left = state.bits_left; \
+ } \
+}
+
+/* Out-of-line case for Huffman code fetching */
+EXTERN(int) jpeg_huff_decode
+ JPP((bitread_working_state * state, bit_buf_type get_buffer,
+ int bits_left, d_derived_tbl * htbl, int min_bits));
diff --git a/ml/dlib/dlib/external/libjpeg/jdinput.cpp b/ml/dlib/dlib/external/libjpeg/jdinput.cpp
new file mode 100644
index 000000000..42f79977d
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdinput.cpp
@@ -0,0 +1,381 @@
+/*
+ * jdinput.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains input control logic for the JPEG decompressor.
+ * These routines are concerned with controlling the decompressor's input
+ * processing (marker reading and coefficient decoding). The actual input
+ * reading is done in jdmarker.c, jdhuff.c, and jdphuff.c.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private state */
+
+typedef struct {
+ struct jpeg_input_controller pub; /* public fields */
+
+ int inheaders; /* TRUE until first SOS is reached */
+} my_input_controller;
+
+typedef my_input_controller * my_inputctl_ptr;
+
+
+/* Forward declarations */
+METHODDEF(int) consume_markers JPP((j_decompress_ptr cinfo));
+
+
+/*
+ * Routines to calculate various quantities related to the size of the image.
+ */
+
+LOCAL(void)
+initial_setup (j_decompress_ptr cinfo)
+/* Called once, when first SOS marker is reached */
+{
+ int ci;
+ jpeg_component_info *compptr;
+
+ /* Make sure image isn't bigger than I can handle */
+ if ((long) cinfo->image_height > (long) JPEG_MAX_DIMENSION ||
+ (long) cinfo->image_width > (long) JPEG_MAX_DIMENSION)
+ ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) JPEG_MAX_DIMENSION);
+
+ /* For now, precision must match compiled-in value... */
+ if (cinfo->data_precision != BITS_IN_JSAMPLE)
+ ERREXIT1(cinfo, JERR_BAD_PRECISION, cinfo->data_precision);
+
+ /* Check that number of components won't exceed internal array sizes */
+ if (cinfo->num_components > MAX_COMPONENTS)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components,
+ MAX_COMPONENTS);
+
+ /* Compute maximum sampling factors; check factor validity */
+ cinfo->max_h_samp_factor = 1;
+ cinfo->max_v_samp_factor = 1;
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ if (compptr->h_samp_factor<=0 || compptr->h_samp_factor>MAX_SAMP_FACTOR ||
+ compptr->v_samp_factor<=0 || compptr->v_samp_factor>MAX_SAMP_FACTOR)
+ ERREXIT(cinfo, JERR_BAD_SAMPLING);
+ cinfo->max_h_samp_factor = MAX(cinfo->max_h_samp_factor,
+ compptr->h_samp_factor);
+ cinfo->max_v_samp_factor = MAX(cinfo->max_v_samp_factor,
+ compptr->v_samp_factor);
+ }
+
+ /* We initialize DCT_scaled_size and min_DCT_scaled_size to DCTSIZE.
+ * In the full decompressor, this will be overridden by jdmaster.c;
+ * but in the transcoder, jdmaster.c is not used, so we must do it here.
+ */
+ cinfo->min_DCT_scaled_size = DCTSIZE;
+
+ /* Compute dimensions of components */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ compptr->DCT_scaled_size = DCTSIZE;
+ /* Size in DCT blocks */
+ compptr->width_in_blocks = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor,
+ (long) (cinfo->max_h_samp_factor * DCTSIZE));
+ compptr->height_in_blocks = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor,
+ (long) (cinfo->max_v_samp_factor * DCTSIZE));
+ /* downsampled_width and downsampled_height will also be overridden by
+ * jdmaster.c if we are doing full decompression. The transcoder library
+ * doesn't use these values, but the calling application might.
+ */
+ /* Size in samples */
+ compptr->downsampled_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor,
+ (long) cinfo->max_h_samp_factor);
+ compptr->downsampled_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor,
+ (long) cinfo->max_v_samp_factor);
+ /* Mark component needed, until color conversion says otherwise */
+ compptr->component_needed = TRUE;
+ /* Mark no quantization table yet saved for component */
+ compptr->quant_table = NULL;
+ }
+
+ /* Compute number of fully interleaved MCU rows. */
+ cinfo->total_iMCU_rows = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height,
+ (long) (cinfo->max_v_samp_factor*DCTSIZE));
+
+ /* Decide whether file contains multiple scans */
+ if (cinfo->comps_in_scan < cinfo->num_components || cinfo->progressive_mode)
+ cinfo->inputctl->has_multiple_scans = TRUE;
+ else
+ cinfo->inputctl->has_multiple_scans = FALSE;
+}
+
+
+LOCAL(void)
+per_scan_setup (j_decompress_ptr cinfo)
+/* Do computations that are needed before processing a JPEG scan */
+/* cinfo->comps_in_scan and cinfo->cur_comp_info[] were set from SOS marker */
+{
+ int ci, mcublks, tmp;
+ jpeg_component_info *compptr;
+
+ if (cinfo->comps_in_scan == 1) {
+
+ /* Noninterleaved (single-component) scan */
+ compptr = cinfo->cur_comp_info[0];
+
+ /* Overall image size in MCUs */
+ cinfo->MCUs_per_row = compptr->width_in_blocks;
+ cinfo->MCU_rows_in_scan = compptr->height_in_blocks;
+
+ /* For noninterleaved scan, always one block per MCU */
+ compptr->MCU_width = 1;
+ compptr->MCU_height = 1;
+ compptr->MCU_blocks = 1;
+ compptr->MCU_sample_width = compptr->DCT_scaled_size;
+ compptr->last_col_width = 1;
+ /* For noninterleaved scans, it is convenient to define last_row_height
+ * as the number of block rows present in the last iMCU row.
+ */
+ tmp = (int) (compptr->height_in_blocks % compptr->v_samp_factor);
+ if (tmp == 0) tmp = compptr->v_samp_factor;
+ compptr->last_row_height = tmp;
+
+ /* Prepare array describing MCU composition */
+ cinfo->blocks_in_MCU = 1;
+ cinfo->MCU_membership[0] = 0;
+
+ } else {
+
+ /* Interleaved (multi-component) scan */
+ if (cinfo->comps_in_scan <= 0 || cinfo->comps_in_scan > MAX_COMPS_IN_SCAN)
+ ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->comps_in_scan,
+ MAX_COMPS_IN_SCAN);
+
+ /* Overall image size in MCUs */
+ cinfo->MCUs_per_row = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width,
+ (long) (cinfo->max_h_samp_factor*DCTSIZE));
+ cinfo->MCU_rows_in_scan = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height,
+ (long) (cinfo->max_v_samp_factor*DCTSIZE));
+
+ cinfo->blocks_in_MCU = 0;
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* Sampling factors give # of blocks of component in each MCU */
+ compptr->MCU_width = compptr->h_samp_factor;
+ compptr->MCU_height = compptr->v_samp_factor;
+ compptr->MCU_blocks = compptr->MCU_width * compptr->MCU_height;
+ compptr->MCU_sample_width = compptr->MCU_width * compptr->DCT_scaled_size;
+ /* Figure number of non-dummy blocks in last MCU column & row */
+ tmp = (int) (compptr->width_in_blocks % compptr->MCU_width);
+ if (tmp == 0) tmp = compptr->MCU_width;
+ compptr->last_col_width = tmp;
+ tmp = (int) (compptr->height_in_blocks % compptr->MCU_height);
+ if (tmp == 0) tmp = compptr->MCU_height;
+ compptr->last_row_height = tmp;
+ /* Prepare array describing MCU composition */
+ mcublks = compptr->MCU_blocks;
+ if (cinfo->blocks_in_MCU + mcublks > D_MAX_BLOCKS_IN_MCU)
+ ERREXIT(cinfo, JERR_BAD_MCU_SIZE);
+ while (mcublks-- > 0) {
+ cinfo->MCU_membership[cinfo->blocks_in_MCU++] = ci;
+ }
+ }
+
+ }
+}
+
+
+/*
+ * Save away a copy of the Q-table referenced by each component present
+ * in the current scan, unless already saved during a prior scan.
+ *
+ * In a multiple-scan JPEG file, the encoder could assign different components
+ * the same Q-table slot number, but change table definitions between scans
+ * so that each component uses a different Q-table. (The IJG encoder is not
+ * currently capable of doing this, but other encoders might.) Since we want
+ * to be able to dequantize all the components at the end of the file, this
+ * means that we have to save away the table actually used for each component.
+ * We do this by copying the table at the start of the first scan containing
+ * the component.
+ * The JPEG spec prohibits the encoder from changing the contents of a Q-table
+ * slot between scans of a component using that slot. If the encoder does so
+ * anyway, this decoder will simply use the Q-table values that were current
+ * at the start of the first scan for the component.
+ *
+ * The decompressor output side looks only at the saved quant tables,
+ * not at the current Q-table slots.
+ */
+
+LOCAL(void)
+latch_quant_tables (j_decompress_ptr cinfo)
+{
+ int ci, qtblno;
+ jpeg_component_info *compptr;
+ JQUANT_TBL * qtbl;
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* No work if we already saved Q-table for this component */
+ if (compptr->quant_table != NULL)
+ continue;
+ /* Make sure specified quantization table is present */
+ qtblno = compptr->quant_tbl_no;
+ if (qtblno < 0 || qtblno >= NUM_QUANT_TBLS ||
+ cinfo->quant_tbl_ptrs[qtblno] == NULL)
+ ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, qtblno);
+ /* OK, save away the quantization table */
+ qtbl = (JQUANT_TBL *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(JQUANT_TBL));
+ MEMCOPY(qtbl, cinfo->quant_tbl_ptrs[qtblno], SIZEOF(JQUANT_TBL));
+ compptr->quant_table = qtbl;
+ }
+}
+
+
+/*
+ * Initialize the input modules to read a scan of compressed data.
+ * The first call to this is done by jdmaster.c after initializing
+ * the entire decompressor (during jpeg_start_decompress).
+ * Subsequent calls come from consume_markers, below.
+ */
+
+METHODDEF(void)
+start_input_pass (j_decompress_ptr cinfo)
+{
+ per_scan_setup(cinfo);
+ latch_quant_tables(cinfo);
+ (*cinfo->entropy->start_pass) (cinfo);
+ (*cinfo->coef->start_input_pass) (cinfo);
+ cinfo->inputctl->consume_input = cinfo->coef->consume_data;
+}
+
+
+/*
+ * Finish up after inputting a compressed-data scan.
+ * This is called by the coefficient controller after it's read all
+ * the expected data of the scan.
+ */
+
+METHODDEF(void)
+finish_input_pass (j_decompress_ptr cinfo)
+{
+ cinfo->inputctl->consume_input = consume_markers;
+}
+
+
+/*
+ * Read JPEG markers before, between, or after compressed-data scans.
+ * Change state as necessary when a new scan is reached.
+ * Return value is JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI.
+ *
+ * The consume_input method pointer points either here or to the
+ * coefficient controller's consume_data routine, depending on whether
+ * we are reading a compressed data segment or inter-segment markers.
+ */
+
+METHODDEF(int)
+consume_markers (j_decompress_ptr cinfo)
+{
+ my_inputctl_ptr inputctl = (my_inputctl_ptr) cinfo->inputctl;
+ int val;
+
+ if (inputctl->pub.eoi_reached) /* After hitting EOI, read no further */
+ return JPEG_REACHED_EOI;
+
+ val = (*cinfo->marker->read_markers) (cinfo);
+
+ switch (val) {
+ case JPEG_REACHED_SOS: /* Found SOS */
+ if (inputctl->inheaders) { /* 1st SOS */
+ initial_setup(cinfo);
+ inputctl->inheaders = FALSE;
+ /* Note: start_input_pass must be called by jdmaster.c
+ * before any more input can be consumed. jdapimin.c is
+ * responsible for enforcing this sequencing.
+ */
+ } else { /* 2nd or later SOS marker */
+ if (! inputctl->pub.has_multiple_scans)
+ ERREXIT(cinfo, JERR_EOI_EXPECTED); /* Oops, I wasn't expecting this! */
+ start_input_pass(cinfo);
+ }
+ break;
+ case JPEG_REACHED_EOI: /* Found EOI */
+ inputctl->pub.eoi_reached = TRUE;
+ if (inputctl->inheaders) { /* Tables-only datastream, apparently */
+ if (cinfo->marker->saw_SOF)
+ ERREXIT(cinfo, JERR_SOF_NO_SOS);
+ } else {
+ /* Prevent infinite loop in coef ctlr's decompress_data routine
+ * if user set output_scan_number larger than number of scans.
+ */
+ if (cinfo->output_scan_number > cinfo->input_scan_number)
+ cinfo->output_scan_number = cinfo->input_scan_number;
+ }
+ break;
+ case JPEG_SUSPENDED:
+ break;
+ }
+
+ return val;
+}
+
+
+/*
+ * Reset state to begin a fresh datastream.
+ */
+
+METHODDEF(void)
+reset_input_controller (j_decompress_ptr cinfo)
+{
+ my_inputctl_ptr inputctl = (my_inputctl_ptr) cinfo->inputctl;
+
+ inputctl->pub.consume_input = consume_markers;
+ inputctl->pub.has_multiple_scans = FALSE; /* "unknown" would be better */
+ inputctl->pub.eoi_reached = FALSE;
+ inputctl->inheaders = TRUE;
+ /* Reset other modules */
+ (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo);
+ (*cinfo->marker->reset_marker_reader) (cinfo);
+ /* Reset progression state -- would be cleaner if entropy decoder did this */
+ cinfo->coef_bits = NULL;
+}
+
+
+/*
+ * Initialize the input controller module.
+ * This is called only once, when the decompression object is created.
+ */
+
+GLOBAL(void)
+jinit_input_controller (j_decompress_ptr cinfo)
+{
+ my_inputctl_ptr inputctl;
+
+ /* Create subobject in permanent pool */
+ inputctl = (my_inputctl_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ SIZEOF(my_input_controller));
+ cinfo->inputctl = (struct jpeg_input_controller *) inputctl;
+ /* Initialize method pointers */
+ inputctl->pub.consume_input = consume_markers;
+ inputctl->pub.reset_input_controller = reset_input_controller;
+ inputctl->pub.start_input_pass = start_input_pass;
+ inputctl->pub.finish_input_pass = finish_input_pass;
+ /* Initialize state: can't use reset_input_controller since we don't
+ * want to try to reset other modules yet.
+ */
+ inputctl->pub.has_multiple_scans = FALSE; /* "unknown" would be better */
+ inputctl->pub.eoi_reached = FALSE;
+ inputctl->inheaders = TRUE;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdmainct.cpp b/ml/dlib/dlib/external/libjpeg/jdmainct.cpp
new file mode 100644
index 000000000..bc2c378aa
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdmainct.cpp
@@ -0,0 +1,512 @@
+/*
+ * jdmainct.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the main buffer controller for decompression.
+ * The main buffer lies between the JPEG decompressor proper and the
+ * post-processor; it holds downsampled data in the JPEG colorspace.
+ *
+ * Note that this code is bypassed in raw-data mode, since the application
+ * supplies the equivalent of the main buffer in that case.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * In the current system design, the main buffer need never be a full-image
+ * buffer; any full-height buffers will be found inside the coefficient or
+ * postprocessing controllers. Nonetheless, the main controller is not
+ * trivial. Its responsibility is to provide context rows for upsampling/
+ * rescaling, and doing this in an efficient fashion is a bit tricky.
+ *
+ * Postprocessor input data is counted in "row groups". A row group
+ * is defined to be (v_samp_factor * DCT_scaled_size / min_DCT_scaled_size)
+ * sample rows of each component. (We require DCT_scaled_size values to be
+ * chosen such that these numbers are integers. In practice DCT_scaled_size
+ * values will likely be powers of two, so we actually have the stronger
+ * condition that DCT_scaled_size / min_DCT_scaled_size is an integer.)
+ * Upsampling will typically produce max_v_samp_factor pixel rows from each
+ * row group (times any additional scale factor that the upsampler is
+ * applying).
+ *
+ * The coefficient controller will deliver data to us one iMCU row at a time;
+ * each iMCU row contains v_samp_factor * DCT_scaled_size sample rows, or
+ * exactly min_DCT_scaled_size row groups. (This amount of data corresponds
+ * to one row of MCUs when the image is fully interleaved.) Note that the
+ * number of sample rows varies across components, but the number of row
+ * groups does not. Some garbage sample rows may be included in the last iMCU
+ * row at the bottom of the image.
+ *
+ * Depending on the vertical scaling algorithm used, the upsampler may need
+ * access to the sample row(s) above and below its current input row group.
+ * The upsampler is required to set need_context_rows TRUE at global selection
+ * time if so. When need_context_rows is FALSE, this controller can simply
+ * obtain one iMCU row at a time from the coefficient controller and dole it
+ * out as row groups to the postprocessor.
+ *
+ * When need_context_rows is TRUE, this controller guarantees that the buffer
+ * passed to postprocessing contains at least one row group's worth of samples
+ * above and below the row group(s) being processed. Note that the context
+ * rows "above" the first passed row group appear at negative row offsets in
+ * the passed buffer. At the top and bottom of the image, the required
+ * context rows are manufactured by duplicating the first or last real sample
+ * row; this avoids having special cases in the upsampling inner loops.
+ *
+ * The amount of context is fixed at one row group just because that's a
+ * convenient number for this controller to work with. The existing
+ * upsamplers really only need one sample row of context. An upsampler
+ * supporting arbitrary output rescaling might wish for more than one row
+ * group of context when shrinking the image; tough, we don't handle that.
+ * (This is justified by the assumption that downsizing will be handled mostly
+ * by adjusting the DCT_scaled_size values, so that the actual scale factor at
+ * the upsample step needn't be much less than one.)
+ *
+ * To provide the desired context, we have to retain the last two row groups
+ * of one iMCU row while reading in the next iMCU row. (The last row group
+ * can't be processed until we have another row group for its below-context,
+ * and so we have to save the next-to-last group too for its above-context.)
+ * We could do this most simply by copying data around in our buffer, but
+ * that'd be very slow. We can avoid copying any data by creating a rather
+ * strange pointer structure. Here's how it works. We allocate a workspace
+ * consisting of M+2 row groups (where M = min_DCT_scaled_size is the number
+ * of row groups per iMCU row). We create two sets of redundant pointers to
+ * the workspace. Labeling the physical row groups 0 to M+1, the synthesized
+ * pointer lists look like this:
+ * M+1 M-1
+ * master pointer --> 0 master pointer --> 0
+ * 1 1
+ * ... ...
+ * M-3 M-3
+ * M-2 M
+ * M-1 M+1
+ * M M-2
+ * M+1 M-1
+ * 0 0
+ * We read alternate iMCU rows using each master pointer; thus the last two
+ * row groups of the previous iMCU row remain un-overwritten in the workspace.
+ * The pointer lists are set up so that the required context rows appear to
+ * be adjacent to the proper places when we pass the pointer lists to the
+ * upsampler.
+ *
+ * The above pictures describe the normal state of the pointer lists.
+ * At top and bottom of the image, we diddle the pointer lists to duplicate
+ * the first or last sample row as necessary (this is cheaper than copying
+ * sample rows around).
+ *
+ * This scheme breaks down if M < 2, ie, min_DCT_scaled_size is 1. In that
+ * situation each iMCU row provides only one row group so the buffering logic
+ * must be different (eg, we must read two iMCU rows before we can emit the
+ * first row group). For now, we simply do not support providing context
+ * rows when min_DCT_scaled_size is 1. That combination seems unlikely to
+ * be worth providing --- if someone wants a 1/8th-size preview, they probably
+ * want it quick and dirty, so a context-free upsampler is sufficient.
+ */
+
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_d_main_controller pub; /* public fields */
+
+ /* Pointer to allocated workspace (M or M+2 row groups). */
+ JSAMPARRAY buffer[MAX_COMPONENTS];
+
+ int buffer_full; /* Have we gotten an iMCU row from decoder? */
+ JDIMENSION rowgroup_ctr; /* counts row groups output to postprocessor */
+
+ /* Remaining fields are only used in the context case. */
+
+ /* These are the master pointers to the funny-order pointer lists. */
+ JSAMPIMAGE xbuffer[2]; /* pointers to weird pointer lists */
+
+ int whichptr; /* indicates which pointer set is now in use */
+ int context_state; /* process_data state machine status */
+ JDIMENSION rowgroups_avail; /* row groups available to postprocessor */
+ JDIMENSION iMCU_row_ctr; /* counts iMCU rows to detect image top/bot */
+} my_main_controller;
+
+typedef my_main_controller * my_main_ptr;
+
+/* context_state values: */
+#define CTX_PREPARE_FOR_IMCU 0 /* need to prepare for MCU row */
+#define CTX_PROCESS_IMCU 1 /* feeding iMCU to postprocessor */
+#define CTX_POSTPONED_ROW 2 /* feeding postponed row group */
+
+
+/* Forward declarations */
+METHODDEF(void) process_data_simple_main
+ JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf,
+ JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail));
+METHODDEF(void) process_data_context_main
+ JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf,
+ JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail));
+#ifdef QUANT_2PASS_SUPPORTED
+METHODDEF(void) process_data_crank_post
+ JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf,
+ JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail));
+#endif
+
+
+LOCAL(void)
+alloc_funny_pointers (j_decompress_ptr cinfo)
+/* Allocate space for the funny pointer lists.
+ * This is done only once, not once per pass.
+ */
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ int ci, rgroup;
+ int M = cinfo->min_DCT_scaled_size;
+ jpeg_component_info *compptr;
+ JSAMPARRAY xbuf;
+
+ /* Get top-level space for component array pointers.
+ * We alloc both arrays with one call to save a few cycles.
+ */
+ main->xbuffer[0] = (JSAMPIMAGE)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ cinfo->num_components * 2 * SIZEOF(JSAMPARRAY));
+ main->xbuffer[1] = main->xbuffer[0] + cinfo->num_components;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size; /* height of a row group of component */
+ /* Get space for pointer lists --- M+4 row groups in each list.
+ * We alloc both pointer lists with one call to save a few cycles.
+ */
+ xbuf = (JSAMPARRAY)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ 2 * (rgroup * (M + 4)) * SIZEOF(JSAMPROW));
+ xbuf += rgroup; /* want one row group at negative offsets */
+ main->xbuffer[0][ci] = xbuf;
+ xbuf += rgroup * (M + 4);
+ main->xbuffer[1][ci] = xbuf;
+ }
+}
+
+
+LOCAL(void)
+make_funny_pointers (j_decompress_ptr cinfo)
+/* Create the funny pointer lists discussed in the comments above.
+ * The actual workspace is already allocated (in main->buffer),
+ * and the space for the pointer lists is allocated too.
+ * This routine just fills in the curiously ordered lists.
+ * This will be repeated at the beginning of each pass.
+ */
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ int ci, i, rgroup;
+ int M = cinfo->min_DCT_scaled_size;
+ jpeg_component_info *compptr;
+ JSAMPARRAY buf, xbuf0, xbuf1;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size; /* height of a row group of component */
+ xbuf0 = main->xbuffer[0][ci];
+ xbuf1 = main->xbuffer[1][ci];
+ /* First copy the workspace pointers as-is */
+ buf = main->buffer[ci];
+ for (i = 0; i < rgroup * (M + 2); i++) {
+ xbuf0[i] = xbuf1[i] = buf[i];
+ }
+ /* In the second list, put the last four row groups in swapped order */
+ for (i = 0; i < rgroup * 2; i++) {
+ xbuf1[rgroup*(M-2) + i] = buf[rgroup*M + i];
+ xbuf1[rgroup*M + i] = buf[rgroup*(M-2) + i];
+ }
+ /* The wraparound pointers at top and bottom will be filled later
+ * (see set_wraparound_pointers, below). Initially we want the "above"
+ * pointers to duplicate the first actual data line. This only needs
+ * to happen in xbuffer[0].
+ */
+ for (i = 0; i < rgroup; i++) {
+ xbuf0[i - rgroup] = xbuf0[0];
+ }
+ }
+}
+
+
+LOCAL(void)
+set_wraparound_pointers (j_decompress_ptr cinfo)
+/* Set up the "wraparound" pointers at top and bottom of the pointer lists.
+ * This changes the pointer list state from top-of-image to the normal state.
+ */
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ int ci, i, rgroup;
+ int M = cinfo->min_DCT_scaled_size;
+ jpeg_component_info *compptr;
+ JSAMPARRAY xbuf0, xbuf1;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size; /* height of a row group of component */
+ xbuf0 = main->xbuffer[0][ci];
+ xbuf1 = main->xbuffer[1][ci];
+ for (i = 0; i < rgroup; i++) {
+ xbuf0[i - rgroup] = xbuf0[rgroup*(M+1) + i];
+ xbuf1[i - rgroup] = xbuf1[rgroup*(M+1) + i];
+ xbuf0[rgroup*(M+2) + i] = xbuf0[i];
+ xbuf1[rgroup*(M+2) + i] = xbuf1[i];
+ }
+ }
+}
+
+
+LOCAL(void)
+set_bottom_pointers (j_decompress_ptr cinfo)
+/* Change the pointer lists to duplicate the last sample row at the bottom
+ * of the image. whichptr indicates which xbuffer holds the final iMCU row.
+ * Also sets rowgroups_avail to indicate number of nondummy row groups in row.
+ */
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ int ci, i, rgroup, iMCUheight, rows_left;
+ jpeg_component_info *compptr;
+ JSAMPARRAY xbuf;
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Count sample rows in one iMCU row and in one row group */
+ iMCUheight = compptr->v_samp_factor * compptr->DCT_scaled_size;
+ rgroup = iMCUheight / cinfo->min_DCT_scaled_size;
+ /* Count nondummy sample rows remaining for this component */
+ rows_left = (int) (compptr->downsampled_height % (JDIMENSION) iMCUheight);
+ if (rows_left == 0) rows_left = iMCUheight;
+ /* Count nondummy row groups. Should get same answer for each component,
+ * so we need only do it once.
+ */
+ if (ci == 0) {
+ main->rowgroups_avail = (JDIMENSION) ((rows_left-1) / rgroup + 1);
+ }
+ /* Duplicate the last real sample row rgroup*2 times; this pads out the
+ * last partial rowgroup and ensures at least one full rowgroup of context.
+ */
+ xbuf = main->xbuffer[main->whichptr][ci];
+ for (i = 0; i < rgroup * 2; i++) {
+ xbuf[rows_left + i] = xbuf[rows_left-1];
+ }
+ }
+}
+
+
+/*
+ * Initialize for a processing pass.
+ */
+
+METHODDEF(void)
+start_pass_main (j_decompress_ptr cinfo, J_BUF_MODE pass_mode)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+
+ switch (pass_mode) {
+ case JBUF_PASS_THRU:
+ if (cinfo->upsample->need_context_rows) {
+ main->pub.process_data = process_data_context_main;
+ make_funny_pointers(cinfo); /* Create the xbuffer[] lists */
+ main->whichptr = 0; /* Read first iMCU row into xbuffer[0] */
+ main->context_state = CTX_PREPARE_FOR_IMCU;
+ main->iMCU_row_ctr = 0;
+ } else {
+ /* Simple case with no context needed */
+ main->pub.process_data = process_data_simple_main;
+ }
+ main->buffer_full = FALSE; /* Mark buffer empty */
+ main->rowgroup_ctr = 0;
+ break;
+#ifdef QUANT_2PASS_SUPPORTED
+ case JBUF_CRANK_DEST:
+ /* For last pass of 2-pass quantization, just crank the postprocessor */
+ main->pub.process_data = process_data_crank_post;
+ break;
+#endif
+ default:
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ break;
+ }
+}
+
+
+/*
+ * Process some data.
+ * This handles the simple case where no context is required.
+ */
+
+METHODDEF(void)
+process_data_simple_main (j_decompress_ptr cinfo,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+ JDIMENSION rowgroups_avail;
+
+ /* Read input data if we haven't filled the main buffer yet */
+ if (! main->buffer_full) {
+ if (! (*cinfo->coef->decompress_data) (cinfo, main->buffer))
+ return; /* suspension forced, can do nothing more */
+ main->buffer_full = TRUE; /* OK, we have an iMCU row to work with */
+ }
+
+ /* There are always min_DCT_scaled_size row groups in an iMCU row. */
+ rowgroups_avail = (JDIMENSION) cinfo->min_DCT_scaled_size;
+ /* Note: at the bottom of the image, we may pass extra garbage row groups
+ * to the postprocessor. The postprocessor has to check for bottom
+ * of image anyway (at row resolution), so no point in us doing it too.
+ */
+
+ /* Feed the postprocessor */
+ (*cinfo->post->post_process_data) (cinfo, main->buffer,
+ &main->rowgroup_ctr, rowgroups_avail,
+ output_buf, out_row_ctr, out_rows_avail);
+
+ /* Has postprocessor consumed all the data yet? If so, mark buffer empty */
+ if (main->rowgroup_ctr >= rowgroups_avail) {
+ main->buffer_full = FALSE;
+ main->rowgroup_ctr = 0;
+ }
+}
+
+
+/*
+ * Process some data.
+ * This handles the case where context rows must be provided.
+ */
+
+METHODDEF(void)
+process_data_context_main (j_decompress_ptr cinfo,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ my_main_ptr main = (my_main_ptr) cinfo->main;
+
+ /* Read input data if we haven't filled the main buffer yet */
+ if (! main->buffer_full) {
+ if (! (*cinfo->coef->decompress_data) (cinfo,
+ main->xbuffer[main->whichptr]))
+ return; /* suspension forced, can do nothing more */
+ main->buffer_full = TRUE; /* OK, we have an iMCU row to work with */
+ main->iMCU_row_ctr++; /* count rows received */
+ }
+
+ /* Postprocessor typically will not swallow all the input data it is handed
+ * in one call (due to filling the output buffer first). Must be prepared
+ * to exit and restart. This switch lets us keep track of how far we got.
+ * Note that each case falls through to the next on successful completion.
+ */
+ switch (main->context_state) {
+ case CTX_POSTPONED_ROW:
+ /* Call postprocessor using previously set pointers for postponed row */
+ (*cinfo->post->post_process_data) (cinfo, main->xbuffer[main->whichptr],
+ &main->rowgroup_ctr, main->rowgroups_avail,
+ output_buf, out_row_ctr, out_rows_avail);
+ if (main->rowgroup_ctr < main->rowgroups_avail)
+ return; /* Need to suspend */
+ main->context_state = CTX_PREPARE_FOR_IMCU;
+ if (*out_row_ctr >= out_rows_avail)
+ return; /* Postprocessor exactly filled output buf */
+ /*FALLTHROUGH*/
+ case CTX_PREPARE_FOR_IMCU:
+ /* Prepare to process first M-1 row groups of this iMCU row */
+ main->rowgroup_ctr = 0;
+ main->rowgroups_avail = (JDIMENSION) (cinfo->min_DCT_scaled_size - 1);
+ /* Check for bottom of image: if so, tweak pointers to "duplicate"
+ * the last sample row, and adjust rowgroups_avail to ignore padding rows.
+ */
+ if (main->iMCU_row_ctr == cinfo->total_iMCU_rows)
+ set_bottom_pointers(cinfo);
+ main->context_state = CTX_PROCESS_IMCU;
+ /*FALLTHROUGH*/
+ case CTX_PROCESS_IMCU:
+ /* Call postprocessor using previously set pointers */
+ (*cinfo->post->post_process_data) (cinfo, main->xbuffer[main->whichptr],
+ &main->rowgroup_ctr, main->rowgroups_avail,
+ output_buf, out_row_ctr, out_rows_avail);
+ if (main->rowgroup_ctr < main->rowgroups_avail)
+ return; /* Need to suspend */
+ /* After the first iMCU, change wraparound pointers to normal state */
+ if (main->iMCU_row_ctr == 1)
+ set_wraparound_pointers(cinfo);
+ /* Prepare to load new iMCU row using other xbuffer list */
+ main->whichptr ^= 1; /* 0=>1 or 1=>0 */
+ main->buffer_full = FALSE;
+ /* Still need to process last row group of this iMCU row, */
+ /* which is saved at index M+1 of the other xbuffer */
+ main->rowgroup_ctr = (JDIMENSION) (cinfo->min_DCT_scaled_size + 1);
+ main->rowgroups_avail = (JDIMENSION) (cinfo->min_DCT_scaled_size + 2);
+ main->context_state = CTX_POSTPONED_ROW;
+ }
+}
+
+
+/*
+ * Process some data.
+ * Final pass of two-pass quantization: just call the postprocessor.
+ * Source data will be the postprocessor controller's internal buffer.
+ */
+
+#ifdef QUANT_2PASS_SUPPORTED
+
+METHODDEF(void)
+process_data_crank_post (j_decompress_ptr cinfo,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ (*cinfo->post->post_process_data) (cinfo, (JSAMPIMAGE) NULL,
+ (JDIMENSION *) NULL, (JDIMENSION) 0,
+ output_buf, out_row_ctr, out_rows_avail);
+}
+
+#endif /* QUANT_2PASS_SUPPORTED */
+
+
+/*
+ * Initialize main buffer controller.
+ */
+
+GLOBAL(void)
+jinit_d_main_controller (j_decompress_ptr cinfo, int need_full_buffer)
+{
+ my_main_ptr main;
+ int ci, rgroup, ngroups;
+ jpeg_component_info *compptr;
+
+ main = (my_main_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_main_controller));
+ cinfo->main = (struct jpeg_d_main_controller *) main;
+ main->pub.start_pass = start_pass_main;
+
+ if (need_full_buffer) /* shouldn't happen */
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+
+ /* Allocate the workspace.
+ * ngroups is the number of row groups we need.
+ */
+ if (cinfo->upsample->need_context_rows) {
+ if (cinfo->min_DCT_scaled_size < 2) /* unsupported, see comments above */
+ ERREXIT(cinfo, JERR_NOTIMPL);
+ alloc_funny_pointers(cinfo); /* Alloc space for xbuffer[] lists */
+ ngroups = cinfo->min_DCT_scaled_size + 2;
+ } else {
+ ngroups = cinfo->min_DCT_scaled_size;
+ }
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size; /* height of a row group of component */
+ main->buffer[ci] = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ compptr->width_in_blocks * compptr->DCT_scaled_size,
+ (JDIMENSION) (rgroup * ngroups));
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdmarker.cpp b/ml/dlib/dlib/external/libjpeg/jdmarker.cpp
new file mode 100644
index 000000000..c8c9f8e1c
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdmarker.cpp
@@ -0,0 +1,1360 @@
+/*
+ * jdmarker.c
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains routines to decode JPEG datastream markers.
+ * Most of the complexity arises from our desire to support input
+ * suspension: if not all of the data for a marker is available,
+ * we must exit back to the application. On resumption, we reprocess
+ * the marker.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+typedef enum { /* JPEG marker codes */
+ M_SOF0 = 0xc0,
+ M_SOF1 = 0xc1,
+ M_SOF2 = 0xc2,
+ M_SOF3 = 0xc3,
+
+ M_SOF5 = 0xc5,
+ M_SOF6 = 0xc6,
+ M_SOF7 = 0xc7,
+
+ M_JPG = 0xc8,
+ M_SOF9 = 0xc9,
+ M_SOF10 = 0xca,
+ M_SOF11 = 0xcb,
+
+ M_SOF13 = 0xcd,
+ M_SOF14 = 0xce,
+ M_SOF15 = 0xcf,
+
+ M_DHT = 0xc4,
+
+ M_DAC = 0xcc,
+
+ M_RST0 = 0xd0,
+ M_RST1 = 0xd1,
+ M_RST2 = 0xd2,
+ M_RST3 = 0xd3,
+ M_RST4 = 0xd4,
+ M_RST5 = 0xd5,
+ M_RST6 = 0xd6,
+ M_RST7 = 0xd7,
+
+ M_SOI = 0xd8,
+ M_EOI = 0xd9,
+ M_SOS = 0xda,
+ M_DQT = 0xdb,
+ M_DNL = 0xdc,
+ M_DRI = 0xdd,
+ M_DHP = 0xde,
+ M_EXP = 0xdf,
+
+ M_APP0 = 0xe0,
+ M_APP1 = 0xe1,
+ M_APP2 = 0xe2,
+ M_APP3 = 0xe3,
+ M_APP4 = 0xe4,
+ M_APP5 = 0xe5,
+ M_APP6 = 0xe6,
+ M_APP7 = 0xe7,
+ M_APP8 = 0xe8,
+ M_APP9 = 0xe9,
+ M_APP10 = 0xea,
+ M_APP11 = 0xeb,
+ M_APP12 = 0xec,
+ M_APP13 = 0xed,
+ M_APP14 = 0xee,
+ M_APP15 = 0xef,
+
+ M_JPG0 = 0xf0,
+ M_JPG13 = 0xfd,
+ M_COM = 0xfe,
+
+ M_TEM = 0x01,
+
+ M_ERROR = 0x100
+} JPEG_MARKER;
+
+
+/* Private state */
+
+typedef struct {
+ struct jpeg_marker_reader pub; /* public fields */
+
+ /* Application-overridable marker processing methods */
+ jpeg_marker_parser_method process_COM;
+ jpeg_marker_parser_method process_APPn[16];
+
+ /* Limit on marker data length to save for each marker type */
+ unsigned int length_limit_COM;
+ unsigned int length_limit_APPn[16];
+
+ /* Status of COM/APPn marker saving */
+ jpeg_saved_marker_ptr cur_marker; /* NULL if not processing a marker */
+ unsigned int bytes_read; /* data bytes read so far in marker */
+ /* Note: cur_marker is not linked into marker_list until it's all read. */
+} my_marker_reader;
+
+typedef my_marker_reader * my_marker_ptr;
+
+
+/*
+ * Macros for fetching data from the data source module.
+ *
+ * At all times, cinfo->src->next_input_byte and ->bytes_in_buffer reflect
+ * the current restart point; we update them only when we have reached a
+ * suitable place to restart if a suspension occurs.
+ */
+
+/* Declare and initialize local copies of input pointer/count */
+#define INPUT_VARS(cinfo) \
+ struct jpeg_source_mgr * datasrc = (cinfo)->src; \
+ const JOCTET * next_input_byte = datasrc->next_input_byte; \
+ size_t bytes_in_buffer = datasrc->bytes_in_buffer
+
+/* Unload the local copies --- do this only at a restart boundary */
+#define INPUT_SYNC(cinfo) \
+ ( datasrc->next_input_byte = next_input_byte, \
+ datasrc->bytes_in_buffer = bytes_in_buffer )
+
+/* Reload the local copies --- used only in MAKE_BYTE_AVAIL */
+#define INPUT_RELOAD(cinfo) \
+ ( next_input_byte = datasrc->next_input_byte, \
+ bytes_in_buffer = datasrc->bytes_in_buffer )
+
+/* Internal macro for INPUT_BYTE and INPUT_2BYTES: make a byte available.
+ * Note we do *not* do INPUT_SYNC before calling fill_input_buffer,
+ * but we must reload the local copies after a successful fill.
+ */
+#define MAKE_BYTE_AVAIL(cinfo,action) \
+ if (bytes_in_buffer == 0) { \
+ if (! (*datasrc->fill_input_buffer) (cinfo)) \
+ { action; } \
+ INPUT_RELOAD(cinfo); \
+ }
+
+/* Read a byte into variable V.
+ * If must suspend, take the specified action (typically "return FALSE").
+ */
+#define INPUT_BYTE(cinfo,V,action) \
+ MAKESTMT( MAKE_BYTE_AVAIL(cinfo,action); \
+ bytes_in_buffer--; \
+ V = GETJOCTET(*next_input_byte++); )
+
+/* As above, but read two bytes interpreted as an unsigned 16-bit integer.
+ * V should be declared unsigned int or perhaps long.
+ */
+#define INPUT_2BYTES(cinfo,V,action) \
+ MAKESTMT( MAKE_BYTE_AVAIL(cinfo,action); \
+ bytes_in_buffer--; \
+ V = ((unsigned int) GETJOCTET(*next_input_byte++)) << 8; \
+ MAKE_BYTE_AVAIL(cinfo,action); \
+ bytes_in_buffer--; \
+ V += GETJOCTET(*next_input_byte++); )
+
+
+/*
+ * Routines to process JPEG markers.
+ *
+ * Entry condition: JPEG marker itself has been read and its code saved
+ * in cinfo->unread_marker; input restart point is just after the marker.
+ *
+ * Exit: if return TRUE, have read and processed any parameters, and have
+ * updated the restart point to point after the parameters.
+ * If return FALSE, was forced to suspend before reaching end of
+ * marker parameters; restart point has not been moved. Same routine
+ * will be called again after application supplies more input data.
+ *
+ * This approach to suspension assumes that all of a marker's parameters
+ * can fit into a single input bufferload. This should hold for "normal"
+ * markers. Some COM/APPn markers might have large parameter segments
+ * that might not fit. If we are simply dropping such a marker, we use
+ * skip_input_data to get past it, and thereby put the problem on the
+ * source manager's shoulders. If we are saving the marker's contents
+ * into memory, we use a slightly different convention: when forced to
+ * suspend, the marker processor updates the restart point to the end of
+ * what it's consumed (ie, the end of the buffer) before returning FALSE.
+ * On resumption, cinfo->unread_marker still contains the marker code,
+ * but the data source will point to the next chunk of marker data.
+ * The marker processor must retain internal state to deal with this.
+ *
+ * Note that we don't bother to avoid duplicate trace messages if a
+ * suspension occurs within marker parameters. Other side effects
+ * require more care.
+ */
+
+
+LOCAL(int)
+get_soi (j_decompress_ptr cinfo)
+/* Process an SOI marker */
+{
+ int i;
+
+ TRACEMS(cinfo, 1, JTRC_SOI);
+
+ if (cinfo->marker->saw_SOI)
+ ERREXIT(cinfo, JERR_SOI_DUPLICATE);
+
+ /* Reset all parameters that are defined to be reset by SOI */
+
+ for (i = 0; i < NUM_ARITH_TBLS; i++) {
+ cinfo->arith_dc_L[i] = 0;
+ cinfo->arith_dc_U[i] = 1;
+ cinfo->arith_ac_K[i] = 5;
+ }
+ cinfo->restart_interval = 0;
+
+ /* Set initial assumptions for colorspace etc */
+
+ cinfo->jpeg_color_space = JCS_UNKNOWN;
+ cinfo->CCIR601_sampling = FALSE; /* Assume non-CCIR sampling??? */
+
+ cinfo->saw_JFIF_marker = FALSE;
+ cinfo->JFIF_major_version = 1; /* set default JFIF APP0 values */
+ cinfo->JFIF_minor_version = 1;
+ cinfo->density_unit = 0;
+ cinfo->X_density = 1;
+ cinfo->Y_density = 1;
+ cinfo->saw_Adobe_marker = FALSE;
+ cinfo->Adobe_transform = 0;
+
+ cinfo->marker->saw_SOI = TRUE;
+
+ return TRUE;
+}
+
+
+LOCAL(int)
+get_sof (j_decompress_ptr cinfo, int is_prog, int is_arith)
+/* Process a SOFn marker */
+{
+ long length;
+ int c, ci;
+ jpeg_component_info * compptr;
+ INPUT_VARS(cinfo);
+
+ cinfo->progressive_mode = is_prog;
+ cinfo->arith_code = is_arith;
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+
+ INPUT_BYTE(cinfo, cinfo->data_precision, return FALSE);
+ INPUT_2BYTES(cinfo, cinfo->image_height, return FALSE);
+ INPUT_2BYTES(cinfo, cinfo->image_width, return FALSE);
+ INPUT_BYTE(cinfo, cinfo->num_components, return FALSE);
+
+ length -= 8;
+
+ TRACEMS4(cinfo, 1, JTRC_SOF, cinfo->unread_marker,
+ (int) cinfo->image_width, (int) cinfo->image_height,
+ cinfo->num_components);
+
+ if (cinfo->marker->saw_SOF)
+ ERREXIT(cinfo, JERR_SOF_DUPLICATE);
+
+ /* We don't support files in which the image height is initially specified */
+ /* as 0 and is later redefined by DNL. As long as we have to check that, */
+ /* might as well have a general sanity check. */
+ if (cinfo->image_height <= 0 || cinfo->image_width <= 0
+ || cinfo->num_components <= 0)
+ ERREXIT(cinfo, JERR_EMPTY_IMAGE);
+
+ if (length != (cinfo->num_components * 3))
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ if (cinfo->comp_info == NULL) /* do only once, even if suspend */
+ cinfo->comp_info = (jpeg_component_info *) (*cinfo->mem->alloc_small)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ cinfo->num_components * SIZEOF(jpeg_component_info));
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ compptr->component_index = ci;
+ INPUT_BYTE(cinfo, compptr->component_id, return FALSE);
+ INPUT_BYTE(cinfo, c, return FALSE);
+ compptr->h_samp_factor = (c >> 4) & 15;
+ compptr->v_samp_factor = (c ) & 15;
+ INPUT_BYTE(cinfo, compptr->quant_tbl_no, return FALSE);
+
+ TRACEMS4(cinfo, 1, JTRC_SOF_COMPONENT,
+ compptr->component_id, compptr->h_samp_factor,
+ compptr->v_samp_factor, compptr->quant_tbl_no);
+ }
+
+ cinfo->marker->saw_SOF = TRUE;
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+LOCAL(int)
+get_sos (j_decompress_ptr cinfo)
+/* Process a SOS marker */
+{
+ long length;
+ int i, ci, n, c, cc;
+ jpeg_component_info * compptr;
+ INPUT_VARS(cinfo);
+
+ if (! cinfo->marker->saw_SOF)
+ ERREXIT(cinfo, JERR_SOS_NO_SOF);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+
+ INPUT_BYTE(cinfo, n, return FALSE); /* Number of components */
+
+ TRACEMS1(cinfo, 1, JTRC_SOS, n);
+
+ if (length != (n * 2 + 6) || n < 1 || n > MAX_COMPS_IN_SCAN)
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ cinfo->comps_in_scan = n;
+
+ /* Collect the component-spec parameters */
+
+ for (i = 0; i < n; i++) {
+ INPUT_BYTE(cinfo, cc, return FALSE);
+ INPUT_BYTE(cinfo, c, return FALSE);
+
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ if (cc == compptr->component_id)
+ goto id_found;
+ }
+
+ ERREXIT1(cinfo, JERR_BAD_COMPONENT_ID, cc);
+
+ id_found:
+
+ cinfo->cur_comp_info[i] = compptr;
+ compptr->dc_tbl_no = (c >> 4) & 15;
+ compptr->ac_tbl_no = (c ) & 15;
+
+ TRACEMS3(cinfo, 1, JTRC_SOS_COMPONENT, cc,
+ compptr->dc_tbl_no, compptr->ac_tbl_no);
+ }
+
+ /* Collect the additional scan parameters Ss, Se, Ah/Al. */
+ INPUT_BYTE(cinfo, c, return FALSE);
+ cinfo->Ss = c;
+ INPUT_BYTE(cinfo, c, return FALSE);
+ cinfo->Se = c;
+ INPUT_BYTE(cinfo, c, return FALSE);
+ cinfo->Ah = (c >> 4) & 15;
+ cinfo->Al = (c ) & 15;
+
+ TRACEMS4(cinfo, 1, JTRC_SOS_PARAMS, cinfo->Ss, cinfo->Se,
+ cinfo->Ah, cinfo->Al);
+
+ /* Prepare to scan data & restart markers */
+ cinfo->marker->next_restart_num = 0;
+
+ /* Count another SOS marker */
+ cinfo->input_scan_number++;
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+#ifdef D_ARITH_CODING_SUPPORTED
+
+LOCAL(int)
+get_dac (j_decompress_ptr cinfo)
+/* Process a DAC marker */
+{
+ long length;
+ int index, val;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+
+ while (length > 0) {
+ INPUT_BYTE(cinfo, index, return FALSE);
+ INPUT_BYTE(cinfo, val, return FALSE);
+
+ length -= 2;
+
+ TRACEMS2(cinfo, 1, JTRC_DAC, index, val);
+
+ if (index < 0 || index >= (2*NUM_ARITH_TBLS))
+ ERREXIT1(cinfo, JERR_DAC_INDEX, index);
+
+ if (index >= NUM_ARITH_TBLS) { /* define AC table */
+ cinfo->arith_ac_K[index-NUM_ARITH_TBLS] = (unsigned char) val;
+ } else { /* define DC table */
+ cinfo->arith_dc_L[index] = (unsigned char) (val & 0x0F);
+ cinfo->arith_dc_U[index] = (unsigned char) (val >> 4);
+ if (cinfo->arith_dc_L[index] > cinfo->arith_dc_U[index])
+ ERREXIT1(cinfo, JERR_DAC_VALUE, val);
+ }
+ }
+
+ if (length != 0)
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+#else /* ! D_ARITH_CODING_SUPPORTED */
+
+#define get_dac(cinfo) skip_variable(cinfo)
+
+#endif /* D_ARITH_CODING_SUPPORTED */
+
+
+LOCAL(int)
+get_dht (j_decompress_ptr cinfo)
+/* Process a DHT marker */
+{
+ long length;
+ unsigned char bits[17];
+ unsigned char huffval[256];
+ int i, index, count;
+ JHUFF_TBL **htblptr;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+
+ while (length > 16) {
+ INPUT_BYTE(cinfo, index, return FALSE);
+
+ TRACEMS1(cinfo, 1, JTRC_DHT, index);
+
+ bits[0] = 0;
+ count = 0;
+ for (i = 1; i <= 16; i++) {
+ INPUT_BYTE(cinfo, bits[i], return FALSE);
+ count += bits[i];
+ }
+
+ length -= 1 + 16;
+
+ TRACEMS8(cinfo, 2, JTRC_HUFFBITS,
+ bits[1], bits[2], bits[3], bits[4],
+ bits[5], bits[6], bits[7], bits[8]);
+ TRACEMS8(cinfo, 2, JTRC_HUFFBITS,
+ bits[9], bits[10], bits[11], bits[12],
+ bits[13], bits[14], bits[15], bits[16]);
+
+ /* Here we just do minimal validation of the counts to avoid walking
+ * off the end of our table space. jdhuff.c will check more carefully.
+ */
+ if (count > 256 || ((long) count) > length)
+ ERREXIT(cinfo, JERR_BAD_HUFF_TABLE);
+
+ for (i = 0; i < count; i++)
+ INPUT_BYTE(cinfo, huffval[i], return FALSE);
+
+ length -= count;
+
+ if (index & 0x10) { /* AC table definition */
+ index -= 0x10;
+ htblptr = &cinfo->ac_huff_tbl_ptrs[index];
+ } else { /* DC table definition */
+ htblptr = &cinfo->dc_huff_tbl_ptrs[index];
+ }
+
+ if (index < 0 || index >= NUM_HUFF_TBLS)
+ ERREXIT1(cinfo, JERR_DHT_INDEX, index);
+
+ if (*htblptr == NULL)
+ *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo);
+
+ MEMCOPY((*htblptr)->bits, bits, SIZEOF((*htblptr)->bits));
+ MEMCOPY((*htblptr)->huffval, huffval, SIZEOF((*htblptr)->huffval));
+ }
+
+ if (length != 0)
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+LOCAL(int)
+get_dqt (j_decompress_ptr cinfo)
+/* Process a DQT marker */
+{
+ long length;
+ int n, i, prec;
+ unsigned int tmp;
+ JQUANT_TBL *quant_ptr;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+
+ while (length > 0) {
+ INPUT_BYTE(cinfo, n, return FALSE);
+ prec = n >> 4;
+ n &= 0x0F;
+
+ TRACEMS2(cinfo, 1, JTRC_DQT, n, prec);
+
+ if (n >= NUM_QUANT_TBLS)
+ ERREXIT1(cinfo, JERR_DQT_INDEX, n);
+
+ if (cinfo->quant_tbl_ptrs[n] == NULL)
+ cinfo->quant_tbl_ptrs[n] = jpeg_alloc_quant_table((j_common_ptr) cinfo);
+ quant_ptr = cinfo->quant_tbl_ptrs[n];
+
+ for (i = 0; i < DCTSIZE2; i++) {
+ if (prec)
+ INPUT_2BYTES(cinfo, tmp, return FALSE);
+ else
+ INPUT_BYTE(cinfo, tmp, return FALSE);
+ /* We convert the zigzag-order table to natural array order. */
+ quant_ptr->quantval[jpeg_natural_order[i]] = (unsigned short) tmp;
+ }
+
+ if (cinfo->err->trace_level >= 2) {
+ for (i = 0; i < DCTSIZE2; i += 8) {
+ TRACEMS8(cinfo, 2, JTRC_QUANTVALS,
+ quant_ptr->quantval[i], quant_ptr->quantval[i+1],
+ quant_ptr->quantval[i+2], quant_ptr->quantval[i+3],
+ quant_ptr->quantval[i+4], quant_ptr->quantval[i+5],
+ quant_ptr->quantval[i+6], quant_ptr->quantval[i+7]);
+ }
+ }
+
+ length -= DCTSIZE2+1;
+ if (prec) length -= DCTSIZE2;
+ }
+
+ if (length != 0)
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+LOCAL(int)
+get_dri (j_decompress_ptr cinfo)
+/* Process a DRI marker */
+{
+ long length;
+ unsigned int tmp;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+
+ if (length != 4)
+ ERREXIT(cinfo, JERR_BAD_LENGTH);
+
+ INPUT_2BYTES(cinfo, tmp, return FALSE);
+
+ TRACEMS1(cinfo, 1, JTRC_DRI, tmp);
+
+ cinfo->restart_interval = tmp;
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+/*
+ * Routines for processing APPn and COM markers.
+ * These are either saved in memory or discarded, per application request.
+ * APP0 and APP14 are specially checked to see if they are
+ * JFIF and Adobe markers, respectively.
+ */
+
+#define APP0_DATA_LEN 14 /* Length of interesting data in APP0 */
+#define APP14_DATA_LEN 12 /* Length of interesting data in APP14 */
+#define APPN_DATA_LEN 14 /* Must be the largest of the above!! */
+
+
+LOCAL(void)
+examine_app0 (j_decompress_ptr cinfo, JOCTET FAR * data,
+ unsigned int datalen, long remaining)
+/* Examine first few bytes from an APP0.
+ * Take appropriate action if it is a JFIF marker.
+ * datalen is # of bytes at data[], remaining is length of rest of marker data.
+ */
+{
+ long totallen = (long) datalen + remaining;
+
+ if (datalen >= APP0_DATA_LEN &&
+ GETJOCTET(data[0]) == 0x4A &&
+ GETJOCTET(data[1]) == 0x46 &&
+ GETJOCTET(data[2]) == 0x49 &&
+ GETJOCTET(data[3]) == 0x46 &&
+ GETJOCTET(data[4]) == 0) {
+ /* Found JFIF APP0 marker: save info */
+ cinfo->saw_JFIF_marker = TRUE;
+ cinfo->JFIF_major_version = GETJOCTET(data[5]);
+ cinfo->JFIF_minor_version = GETJOCTET(data[6]);
+ cinfo->density_unit = GETJOCTET(data[7]);
+ cinfo->X_density = (GETJOCTET(data[8]) << 8) + GETJOCTET(data[9]);
+ cinfo->Y_density = (GETJOCTET(data[10]) << 8) + GETJOCTET(data[11]);
+ /* Check version.
+ * Major version must be 1, anything else signals an incompatible change.
+ * (We used to treat this as an error, but now it's a nonfatal warning,
+ * because some bozo at Hijaak couldn't read the spec.)
+ * Minor version should be 0..2, but process anyway if newer.
+ */
+ if (cinfo->JFIF_major_version != 1)
+ WARNMS2(cinfo, JWRN_JFIF_MAJOR,
+ cinfo->JFIF_major_version, cinfo->JFIF_minor_version);
+ /* Generate trace messages */
+ TRACEMS5(cinfo, 1, JTRC_JFIF,
+ cinfo->JFIF_major_version, cinfo->JFIF_minor_version,
+ cinfo->X_density, cinfo->Y_density, cinfo->density_unit);
+ /* Validate thumbnail dimensions and issue appropriate messages */
+ if (GETJOCTET(data[12]) | GETJOCTET(data[13]))
+ TRACEMS2(cinfo, 1, JTRC_JFIF_THUMBNAIL,
+ GETJOCTET(data[12]), GETJOCTET(data[13]));
+ totallen -= APP0_DATA_LEN;
+ if (totallen !=
+ ((long)GETJOCTET(data[12]) * (long)GETJOCTET(data[13]) * (long) 3))
+ TRACEMS1(cinfo, 1, JTRC_JFIF_BADTHUMBNAILSIZE, (int) totallen);
+ } else if (datalen >= 6 &&
+ GETJOCTET(data[0]) == 0x4A &&
+ GETJOCTET(data[1]) == 0x46 &&
+ GETJOCTET(data[2]) == 0x58 &&
+ GETJOCTET(data[3]) == 0x58 &&
+ GETJOCTET(data[4]) == 0) {
+ /* Found JFIF "JFXX" extension APP0 marker */
+ /* The library doesn't actually do anything with these,
+ * but we try to produce a helpful trace message.
+ */
+ switch (GETJOCTET(data[5])) {
+ case 0x10:
+ TRACEMS1(cinfo, 1, JTRC_THUMB_JPEG, (int) totallen);
+ break;
+ case 0x11:
+ TRACEMS1(cinfo, 1, JTRC_THUMB_PALETTE, (int) totallen);
+ break;
+ case 0x13:
+ TRACEMS1(cinfo, 1, JTRC_THUMB_RGB, (int) totallen);
+ break;
+ default:
+ TRACEMS2(cinfo, 1, JTRC_JFIF_EXTENSION,
+ GETJOCTET(data[5]), (int) totallen);
+ break;
+ }
+ } else {
+ /* Start of APP0 does not match "JFIF" or "JFXX", or too short */
+ TRACEMS1(cinfo, 1, JTRC_APP0, (int) totallen);
+ }
+}
+
+
+LOCAL(void)
+examine_app14 (j_decompress_ptr cinfo, JOCTET FAR * data,
+ unsigned int datalen, long remaining)
+/* Examine first few bytes from an APP14.
+ * Take appropriate action if it is an Adobe marker.
+ * datalen is # of bytes at data[], remaining is length of rest of marker data.
+ */
+{
+ unsigned int version, flags0, flags1, transform;
+
+ if (datalen >= APP14_DATA_LEN &&
+ GETJOCTET(data[0]) == 0x41 &&
+ GETJOCTET(data[1]) == 0x64 &&
+ GETJOCTET(data[2]) == 0x6F &&
+ GETJOCTET(data[3]) == 0x62 &&
+ GETJOCTET(data[4]) == 0x65) {
+ /* Found Adobe APP14 marker */
+ version = (GETJOCTET(data[5]) << 8) + GETJOCTET(data[6]);
+ flags0 = (GETJOCTET(data[7]) << 8) + GETJOCTET(data[8]);
+ flags1 = (GETJOCTET(data[9]) << 8) + GETJOCTET(data[10]);
+ transform = GETJOCTET(data[11]);
+ TRACEMS4(cinfo, 1, JTRC_ADOBE, version, flags0, flags1, transform);
+ cinfo->saw_Adobe_marker = TRUE;
+ cinfo->Adobe_transform = (unsigned char) transform;
+ } else {
+ /* Start of APP14 does not match "Adobe", or too short */
+ TRACEMS1(cinfo, 1, JTRC_APP14, (int) (datalen + remaining));
+ }
+}
+
+
+METHODDEF(int)
+get_interesting_appn (j_decompress_ptr cinfo)
+/* Process an APP0 or APP14 marker without saving it */
+{
+ long length;
+ JOCTET b[APPN_DATA_LEN];
+ unsigned int i, numtoread;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+
+ /* get the interesting part of the marker data */
+ if (length >= APPN_DATA_LEN)
+ numtoread = APPN_DATA_LEN;
+ else if (length > 0)
+ numtoread = (unsigned int) length;
+ else
+ numtoread = 0;
+ for (i = 0; i < numtoread; i++)
+ INPUT_BYTE(cinfo, b[i], return FALSE);
+ length -= numtoread;
+
+ /* process it */
+ switch (cinfo->unread_marker) {
+ case M_APP0:
+ examine_app0(cinfo, (JOCTET FAR *) b, numtoread, length);
+ break;
+ case M_APP14:
+ examine_app14(cinfo, (JOCTET FAR *) b, numtoread, length);
+ break;
+ default:
+ /* can't get here unless jpeg_save_markers chooses wrong processor */
+ ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, cinfo->unread_marker);
+ break;
+ }
+
+ /* skip any remaining data -- could be lots */
+ INPUT_SYNC(cinfo);
+ if (length > 0)
+ (*cinfo->src->skip_input_data) (cinfo, (long) length);
+
+ return TRUE;
+}
+
+
+#ifdef SAVE_MARKERS_SUPPORTED
+
+METHODDEF(int)
+save_marker (j_decompress_ptr cinfo)
+/* Save an APPn or COM marker into the marker list */
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+ jpeg_saved_marker_ptr cur_marker = marker->cur_marker;
+ unsigned int bytes_read, data_length;
+ JOCTET FAR * data;
+ long length = 0;
+ INPUT_VARS(cinfo);
+
+ if (cur_marker == NULL) {
+ /* begin reading a marker */
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+ if (length >= 0) { /* watch out for bogus length word */
+ /* figure out how much we want to save */
+ unsigned int limit;
+ if (cinfo->unread_marker == (int) M_COM)
+ limit = marker->length_limit_COM;
+ else
+ limit = marker->length_limit_APPn[cinfo->unread_marker - (int) M_APP0];
+ if ((unsigned int) length < limit)
+ limit = (unsigned int) length;
+ /* allocate and initialize the marker item */
+ cur_marker = (jpeg_saved_marker_ptr)
+ (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(struct jpeg_marker_struct) + limit);
+ cur_marker->next = NULL;
+ cur_marker->marker = (unsigned char) cinfo->unread_marker;
+ cur_marker->original_length = (unsigned int) length;
+ cur_marker->data_length = limit;
+ /* data area is just beyond the jpeg_marker_struct */
+ data = cur_marker->data = (JOCTET FAR *) (cur_marker + 1);
+ marker->cur_marker = cur_marker;
+ marker->bytes_read = 0;
+ bytes_read = 0;
+ data_length = limit;
+ } else {
+ /* deal with bogus length word */
+ bytes_read = data_length = 0;
+ data = NULL;
+ }
+ } else {
+ /* resume reading a marker */
+ bytes_read = marker->bytes_read;
+ data_length = cur_marker->data_length;
+ data = cur_marker->data + bytes_read;
+ }
+
+ while (bytes_read < data_length) {
+ INPUT_SYNC(cinfo); /* move the restart point to here */
+ marker->bytes_read = bytes_read;
+ /* If there's not at least one byte in buffer, suspend */
+ MAKE_BYTE_AVAIL(cinfo, return FALSE);
+ /* Copy bytes with reasonable rapidity */
+ while (bytes_read < data_length && bytes_in_buffer > 0) {
+ *data++ = *next_input_byte++;
+ bytes_in_buffer--;
+ bytes_read++;
+ }
+ }
+
+ /* Done reading what we want to read */
+ if (cur_marker != NULL) { /* will be NULL if bogus length word */
+ /* Add new marker to end of list */
+ if (cinfo->marker_list == NULL) {
+ cinfo->marker_list = cur_marker;
+ } else {
+ jpeg_saved_marker_ptr prev = cinfo->marker_list;
+ while (prev->next != NULL)
+ prev = prev->next;
+ prev->next = cur_marker;
+ }
+ /* Reset pointer & calc remaining data length */
+ data = cur_marker->data;
+ length = cur_marker->original_length - data_length;
+ }
+ /* Reset to initial state for next marker */
+ marker->cur_marker = NULL;
+
+ /* Process the marker if interesting; else just make a generic trace msg */
+ switch (cinfo->unread_marker) {
+ case M_APP0:
+ examine_app0(cinfo, data, data_length, length);
+ break;
+ case M_APP14:
+ examine_app14(cinfo, data, data_length, length);
+ break;
+ default:
+ TRACEMS2(cinfo, 1, JTRC_MISC_MARKER, cinfo->unread_marker,
+ (int) (data_length + length));
+ break;
+ }
+
+ /* skip any remaining data -- could be lots */
+ INPUT_SYNC(cinfo); /* do before skip_input_data */
+ if (length > 0)
+ (*cinfo->src->skip_input_data) (cinfo, (long) length);
+
+ return TRUE;
+}
+
+#endif /* SAVE_MARKERS_SUPPORTED */
+
+
+METHODDEF(int)
+skip_variable (j_decompress_ptr cinfo)
+/* Skip over an unknown or uninteresting variable-length marker */
+{
+ long length;
+ INPUT_VARS(cinfo);
+
+ INPUT_2BYTES(cinfo, length, return FALSE);
+ length -= 2;
+
+ TRACEMS2(cinfo, 1, JTRC_MISC_MARKER, cinfo->unread_marker, (int) length);
+
+ INPUT_SYNC(cinfo); /* do before skip_input_data */
+ if (length > 0)
+ (*cinfo->src->skip_input_data) (cinfo, (long) length);
+
+ return TRUE;
+}
+
+
+/*
+ * Find the next JPEG marker, save it in cinfo->unread_marker.
+ * Returns FALSE if had to suspend before reaching a marker;
+ * in that case cinfo->unread_marker is unchanged.
+ *
+ * Note that the result might not be a valid marker code,
+ * but it will never be 0 or FF.
+ */
+
+LOCAL(int)
+next_marker (j_decompress_ptr cinfo)
+{
+ int c;
+ INPUT_VARS(cinfo);
+
+ for (;;) {
+ INPUT_BYTE(cinfo, c, return FALSE);
+ /* Skip any non-FF bytes.
+ * This may look a bit inefficient, but it will not occur in a valid file.
+ * We sync after each discarded byte so that a suspending data source
+ * can discard the byte from its buffer.
+ */
+ while (c != 0xFF) {
+ cinfo->marker->discarded_bytes++;
+ INPUT_SYNC(cinfo);
+ INPUT_BYTE(cinfo, c, return FALSE);
+ }
+ /* This loop swallows any duplicate FF bytes. Extra FFs are legal as
+ * pad bytes, so don't count them in discarded_bytes. We assume there
+ * will not be so many consecutive FF bytes as to overflow a suspending
+ * data source's input buffer.
+ */
+ do {
+ INPUT_BYTE(cinfo, c, return FALSE);
+ } while (c == 0xFF);
+ if (c != 0)
+ break; /* found a valid marker, exit loop */
+ /* Reach here if we found a stuffed-zero data sequence (FF/00).
+ * Discard it and loop back to try again.
+ */
+ cinfo->marker->discarded_bytes += 2;
+ INPUT_SYNC(cinfo);
+ }
+
+ if (cinfo->marker->discarded_bytes != 0) {
+ WARNMS2(cinfo, JWRN_EXTRANEOUS_DATA, cinfo->marker->discarded_bytes, c);
+ cinfo->marker->discarded_bytes = 0;
+ }
+
+ cinfo->unread_marker = c;
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+LOCAL(int)
+first_marker (j_decompress_ptr cinfo)
+/* Like next_marker, but used to obtain the initial SOI marker. */
+/* For this marker, we do not allow preceding garbage or fill; otherwise,
+ * we might well scan an entire input file before realizing it ain't JPEG.
+ * If an application wants to process non-JFIF files, it must seek to the
+ * SOI before calling the JPEG library.
+ */
+{
+ int c, c2;
+ INPUT_VARS(cinfo);
+
+ INPUT_BYTE(cinfo, c, return FALSE);
+ INPUT_BYTE(cinfo, c2, return FALSE);
+ if (c != 0xFF || c2 != (int) M_SOI)
+ ERREXIT2(cinfo, JERR_NO_SOI, c, c2);
+
+ cinfo->unread_marker = c2;
+
+ INPUT_SYNC(cinfo);
+ return TRUE;
+}
+
+
+/*
+ * Read markers until SOS or EOI.
+ *
+ * Returns same codes as are defined for jpeg_consume_input:
+ * JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI.
+ */
+
+METHODDEF(int)
+read_markers (j_decompress_ptr cinfo)
+{
+ /* Outer loop repeats once for each marker. */
+ for (;;) {
+ /* Collect the marker proper, unless we already did. */
+ /* NB: first_marker() enforces the requirement that SOI appear first. */
+ if (cinfo->unread_marker == 0) {
+ if (! cinfo->marker->saw_SOI) {
+ if (! first_marker(cinfo))
+ return JPEG_SUSPENDED;
+ } else {
+ if (! next_marker(cinfo))
+ return JPEG_SUSPENDED;
+ }
+ }
+ /* At this point cinfo->unread_marker contains the marker code and the
+ * input point is just past the marker proper, but before any parameters.
+ * A suspension will cause us to return with this state still true.
+ */
+ switch (cinfo->unread_marker) {
+ case M_SOI:
+ if (! get_soi(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_SOF0: /* Baseline */
+ case M_SOF1: /* Extended sequential, Huffman */
+ if (! get_sof(cinfo, FALSE, FALSE))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_SOF2: /* Progressive, Huffman */
+ if (! get_sof(cinfo, TRUE, FALSE))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_SOF9: /* Extended sequential, arithmetic */
+ if (! get_sof(cinfo, FALSE, TRUE))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_SOF10: /* Progressive, arithmetic */
+ if (! get_sof(cinfo, TRUE, TRUE))
+ return JPEG_SUSPENDED;
+ break;
+
+ /* Currently unsupported SOFn types */
+ case M_SOF3: /* Lossless, Huffman */
+ case M_SOF5: /* Differential sequential, Huffman */
+ case M_SOF6: /* Differential progressive, Huffman */
+ case M_SOF7: /* Differential lossless, Huffman */
+ case M_JPG: /* Reserved for JPEG extensions */
+ case M_SOF11: /* Lossless, arithmetic */
+ case M_SOF13: /* Differential sequential, arithmetic */
+ case M_SOF14: /* Differential progressive, arithmetic */
+ case M_SOF15: /* Differential lossless, arithmetic */
+ ERREXIT1(cinfo, JERR_SOF_UNSUPPORTED, cinfo->unread_marker);
+ break;
+
+ case M_SOS:
+ if (! get_sos(cinfo))
+ return JPEG_SUSPENDED;
+ cinfo->unread_marker = 0; /* processed the marker */
+ return JPEG_REACHED_SOS;
+
+ case M_EOI:
+ TRACEMS(cinfo, 1, JTRC_EOI);
+ cinfo->unread_marker = 0; /* processed the marker */
+ return JPEG_REACHED_EOI;
+
+ case M_DAC:
+ if (! get_dac(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_DHT:
+ if (! get_dht(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_DQT:
+ if (! get_dqt(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_DRI:
+ if (! get_dri(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_APP0:
+ case M_APP1:
+ case M_APP2:
+ case M_APP3:
+ case M_APP4:
+ case M_APP5:
+ case M_APP6:
+ case M_APP7:
+ case M_APP8:
+ case M_APP9:
+ case M_APP10:
+ case M_APP11:
+ case M_APP12:
+ case M_APP13:
+ case M_APP14:
+ case M_APP15:
+ if (! (*((my_marker_ptr) cinfo->marker)->process_APPn[
+ cinfo->unread_marker - (int) M_APP0]) (cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_COM:
+ if (! (*((my_marker_ptr) cinfo->marker)->process_COM) (cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ case M_RST0: /* these are all parameterless */
+ case M_RST1:
+ case M_RST2:
+ case M_RST3:
+ case M_RST4:
+ case M_RST5:
+ case M_RST6:
+ case M_RST7:
+ case M_TEM:
+ TRACEMS1(cinfo, 1, JTRC_PARMLESS_MARKER, cinfo->unread_marker);
+ break;
+
+ case M_DNL: /* Ignore DNL ... perhaps the wrong thing */
+ if (! skip_variable(cinfo))
+ return JPEG_SUSPENDED;
+ break;
+
+ default: /* must be DHP, EXP, JPGn, or RESn */
+ /* For now, we treat the reserved markers as fatal errors since they are
+ * likely to be used to signal incompatible JPEG Part 3 extensions.
+ * Once the JPEG 3 version-number marker is well defined, this code
+ * ought to change!
+ */
+ ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, cinfo->unread_marker);
+ break;
+ }
+ /* Successfully processed marker, so reset state variable */
+ cinfo->unread_marker = 0;
+ } /* end loop */
+}
+
+
+/*
+ * Read a restart marker, which is expected to appear next in the datastream;
+ * if the marker is not there, take appropriate recovery action.
+ * Returns FALSE if suspension is required.
+ *
+ * This is called by the entropy decoder after it has read an appropriate
+ * number of MCUs. cinfo->unread_marker may be nonzero if the entropy decoder
+ * has already read a marker from the data source. Under normal conditions
+ * cinfo->unread_marker will be reset to 0 before returning; if not reset,
+ * it holds a marker which the decoder will be unable to read past.
+ */
+
+METHODDEF(int)
+read_restart_marker (j_decompress_ptr cinfo)
+{
+ /* Obtain a marker unless we already did. */
+ /* Note that next_marker will complain if it skips any data. */
+ if (cinfo->unread_marker == 0) {
+ if (! next_marker(cinfo))
+ return FALSE;
+ }
+
+ if (cinfo->unread_marker ==
+ ((int) M_RST0 + cinfo->marker->next_restart_num)) {
+ /* Normal case --- swallow the marker and let entropy decoder continue */
+ TRACEMS1(cinfo, 3, JTRC_RST, cinfo->marker->next_restart_num);
+ cinfo->unread_marker = 0;
+ } else {
+ /* Uh-oh, the restart markers have been messed up. */
+ /* Let the data source manager determine how to resync. */
+ if (! (*cinfo->src->resync_to_restart) (cinfo,
+ cinfo->marker->next_restart_num))
+ return FALSE;
+ }
+
+ /* Update next-restart state */
+ cinfo->marker->next_restart_num = (cinfo->marker->next_restart_num + 1) & 7;
+
+ return TRUE;
+}
+
+
+/*
+ * This is the default resync_to_restart method for data source managers
+ * to use if they don't have any better approach. Some data source managers
+ * may be able to back up, or may have additional knowledge about the data
+ * which permits a more intelligent recovery strategy; such managers would
+ * presumably supply their own resync method.
+ *
+ * read_restart_marker calls resync_to_restart if it finds a marker other than
+ * the restart marker it was expecting. (This code is *not* used unless
+ * a nonzero restart interval has been declared.) cinfo->unread_marker is
+ * the marker code actually found (might be anything, except 0 or FF).
+ * The desired restart marker number (0..7) is passed as a parameter.
+ * This routine is supposed to apply whatever error recovery strategy seems
+ * appropriate in order to position the input stream to the next data segment.
+ * Note that cinfo->unread_marker is treated as a marker appearing before
+ * the current data-source input point; usually it should be reset to zero
+ * before returning.
+ * Returns FALSE if suspension is required.
+ *
+ * This implementation is substantially constrained by wanting to treat the
+ * input as a data stream; this means we can't back up. Therefore, we have
+ * only the following actions to work with:
+ * 1. Simply discard the marker and let the entropy decoder resume at next
+ * byte of file.
+ * 2. Read forward until we find another marker, discarding intervening
+ * data. (In theory we could look ahead within the current bufferload,
+ * without having to discard data if we don't find the desired marker.
+ * This idea is not implemented here, in part because it makes behavior
+ * dependent on buffer size and chance buffer-boundary positions.)
+ * 3. Leave the marker unread (by failing to zero cinfo->unread_marker).
+ * This will cause the entropy decoder to process an empty data segment,
+ * inserting dummy zeroes, and then we will reprocess the marker.
+ *
+ * #2 is appropriate if we think the desired marker lies ahead, while #3 is
+ * appropriate if the found marker is a future restart marker (indicating
+ * that we have missed the desired restart marker, probably because it got
+ * corrupted).
+ * We apply #2 or #3 if the found marker is a restart marker no more than
+ * two counts behind or ahead of the expected one. We also apply #2 if the
+ * found marker is not a legal JPEG marker code (it's certainly bogus data).
+ * If the found marker is a restart marker more than 2 counts away, we do #1
+ * (too much risk that the marker is erroneous; with luck we will be able to
+ * resync at some future point).
+ * For any valid non-restart JPEG marker, we apply #3. This keeps us from
+ * overrunning the end of a scan. An implementation limited to single-scan
+ * files might find it better to apply #2 for markers other than EOI, since
+ * any other marker would have to be bogus data in that case.
+ */
+
+GLOBAL(int)
+jpeg_resync_to_restart (j_decompress_ptr cinfo, int desired)
+{
+ int marker = cinfo->unread_marker;
+ int action = 1;
+
+ /* Always put up a warning. */
+ WARNMS2(cinfo, JWRN_MUST_RESYNC, marker, desired);
+
+ /* Outer loop handles repeated decision after scanning forward. */
+ for (;;) {
+ if (marker < (int) M_SOF0)
+ action = 2; /* invalid marker */
+ else if (marker < (int) M_RST0 || marker > (int) M_RST7)
+ action = 3; /* valid non-restart marker */
+ else {
+ if (marker == ((int) M_RST0 + ((desired+1) & 7)) ||
+ marker == ((int) M_RST0 + ((desired+2) & 7)))
+ action = 3; /* one of the next two expected restarts */
+ else if (marker == ((int) M_RST0 + ((desired-1) & 7)) ||
+ marker == ((int) M_RST0 + ((desired-2) & 7)))
+ action = 2; /* a prior restart, so advance */
+ else
+ action = 1; /* desired restart or too far away */
+ }
+ TRACEMS2(cinfo, 4, JTRC_RECOVERY_ACTION, marker, action);
+ switch (action) {
+ case 1:
+ /* Discard marker and let entropy decoder resume processing. */
+ cinfo->unread_marker = 0;
+ return TRUE;
+ case 2:
+ /* Scan to the next marker, and repeat the decision loop. */
+ if (! next_marker(cinfo))
+ return FALSE;
+ marker = cinfo->unread_marker;
+ break;
+ case 3:
+ /* Return without advancing past this marker. */
+ /* Entropy decoder will be forced to process an empty segment. */
+ return TRUE;
+ }
+ } /* end loop */
+}
+
+
+/*
+ * Reset marker processing state to begin a fresh datastream.
+ */
+
+METHODDEF(void)
+reset_marker_reader (j_decompress_ptr cinfo)
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+
+ cinfo->comp_info = NULL; /* until allocated by get_sof */
+ cinfo->input_scan_number = 0; /* no SOS seen yet */
+ cinfo->unread_marker = 0; /* no pending marker */
+ marker->pub.saw_SOI = FALSE; /* set internal state too */
+ marker->pub.saw_SOF = FALSE;
+ marker->pub.discarded_bytes = 0;
+ marker->cur_marker = NULL;
+}
+
+
+/*
+ * Initialize the marker reader module.
+ * This is called only once, when the decompression object is created.
+ */
+
+GLOBAL(void)
+jinit_marker_reader (j_decompress_ptr cinfo)
+{
+ my_marker_ptr marker;
+ int i;
+
+ /* Create subobject in permanent pool */
+ marker = (my_marker_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT,
+ SIZEOF(my_marker_reader));
+ cinfo->marker = (struct jpeg_marker_reader *) marker;
+ /* Initialize public method pointers */
+ marker->pub.reset_marker_reader = reset_marker_reader;
+ marker->pub.read_markers = read_markers;
+ marker->pub.read_restart_marker = read_restart_marker;
+ /* Initialize COM/APPn processing.
+ * By default, we examine and then discard APP0 and APP14,
+ * but simply discard COM and all other APPn.
+ */
+ marker->process_COM = skip_variable;
+ marker->length_limit_COM = 0;
+ for (i = 0; i < 16; i++) {
+ marker->process_APPn[i] = skip_variable;
+ marker->length_limit_APPn[i] = 0;
+ }
+ marker->process_APPn[0] = get_interesting_appn;
+ marker->process_APPn[14] = get_interesting_appn;
+ /* Reset marker processing state */
+ reset_marker_reader(cinfo);
+}
+
+
+/*
+ * Control saving of COM and APPn markers into marker_list.
+ */
+
+#ifdef SAVE_MARKERS_SUPPORTED
+
+GLOBAL(void)
+jpeg_save_markers (j_decompress_ptr cinfo, int marker_code,
+ unsigned int length_limit)
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+ long maxlength;
+ jpeg_marker_parser_method processor;
+
+ /* Length limit mustn't be larger than what we can allocate
+ * (should only be a concern in a 16-bit environment).
+ */
+ maxlength = cinfo->mem->max_alloc_chunk - SIZEOF(struct jpeg_marker_struct);
+ if (((long) length_limit) > maxlength)
+ length_limit = (unsigned int) maxlength;
+
+ /* Choose processor routine to use.
+ * APP0/APP14 have special requirements.
+ */
+ if (length_limit) {
+ processor = save_marker;
+ /* If saving APP0/APP14, save at least enough for our internal use. */
+ if (marker_code == (int) M_APP0 && length_limit < APP0_DATA_LEN)
+ length_limit = APP0_DATA_LEN;
+ else if (marker_code == (int) M_APP14 && length_limit < APP14_DATA_LEN)
+ length_limit = APP14_DATA_LEN;
+ } else {
+ processor = skip_variable;
+ /* If discarding APP0/APP14, use our regular on-the-fly processor. */
+ if (marker_code == (int) M_APP0 || marker_code == (int) M_APP14)
+ processor = get_interesting_appn;
+ }
+
+ if (marker_code == (int) M_COM) {
+ marker->process_COM = processor;
+ marker->length_limit_COM = length_limit;
+ } else if (marker_code >= (int) M_APP0 && marker_code <= (int) M_APP15) {
+ marker->process_APPn[marker_code - (int) M_APP0] = processor;
+ marker->length_limit_APPn[marker_code - (int) M_APP0] = length_limit;
+ } else
+ ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, marker_code);
+}
+
+#endif /* SAVE_MARKERS_SUPPORTED */
+
+
+/*
+ * Install a special processing method for COM or APPn markers.
+ */
+
+GLOBAL(void)
+jpeg_set_marker_processor (j_decompress_ptr cinfo, int marker_code,
+ jpeg_marker_parser_method routine)
+{
+ my_marker_ptr marker = (my_marker_ptr) cinfo->marker;
+
+ if (marker_code == (int) M_COM)
+ marker->process_COM = routine;
+ else if (marker_code >= (int) M_APP0 && marker_code <= (int) M_APP15)
+ marker->process_APPn[marker_code - (int) M_APP0] = routine;
+ else
+ ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, marker_code);
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdmaster.cpp b/ml/dlib/dlib/external/libjpeg/jdmaster.cpp
new file mode 100644
index 000000000..8aea1c8ea
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdmaster.cpp
@@ -0,0 +1,557 @@
+/*
+ * jdmaster.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains master control logic for the JPEG decompressor.
+ * These routines are concerned with selecting the modules to be executed
+ * and with determining the number of passes and the work to be done in each
+ * pass.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private state */
+
+typedef struct {
+ struct jpeg_decomp_master pub; /* public fields */
+
+ int pass_number; /* # of passes completed */
+
+ int using_merged_upsample; /* TRUE if using merged upsample/cconvert */
+
+ /* Saved references to initialized quantizer modules,
+ * in case we need to switch modes.
+ */
+ struct jpeg_color_quantizer * quantizer_1pass;
+ struct jpeg_color_quantizer * quantizer_2pass;
+} my_decomp_master;
+
+typedef my_decomp_master * my_master_ptr;
+
+
+/*
+ * Determine whether merged upsample/color conversion should be used.
+ * CRUCIAL: this must match the actual capabilities of jdmerge.c!
+ */
+
+LOCAL(int)
+use_merged_upsample (j_decompress_ptr cinfo)
+{
+#ifdef UPSAMPLE_MERGING_SUPPORTED
+ /* Merging is the equivalent of plain box-filter upsampling */
+ if (cinfo->do_fancy_upsampling || cinfo->CCIR601_sampling)
+ return FALSE;
+ /* jdmerge.c only supports YCC=>RGB color conversion */
+ if (cinfo->jpeg_color_space != JCS_YCbCr || cinfo->num_components != 3 ||
+ cinfo->out_color_space != JCS_RGB ||
+ cinfo->out_color_components != RGB_PIXELSIZE)
+ return FALSE;
+ /* and it only handles 2h1v or 2h2v sampling ratios */
+ if (cinfo->comp_info[0].h_samp_factor != 2 ||
+ cinfo->comp_info[1].h_samp_factor != 1 ||
+ cinfo->comp_info[2].h_samp_factor != 1 ||
+ cinfo->comp_info[0].v_samp_factor > 2 ||
+ cinfo->comp_info[1].v_samp_factor != 1 ||
+ cinfo->comp_info[2].v_samp_factor != 1)
+ return FALSE;
+ /* furthermore, it doesn't work if we've scaled the IDCTs differently */
+ if (cinfo->comp_info[0].DCT_scaled_size != cinfo->min_DCT_scaled_size ||
+ cinfo->comp_info[1].DCT_scaled_size != cinfo->min_DCT_scaled_size ||
+ cinfo->comp_info[2].DCT_scaled_size != cinfo->min_DCT_scaled_size)
+ return FALSE;
+ /* ??? also need to test for upsample-time rescaling, when & if supported */
+ return TRUE; /* by golly, it'll work... */
+#else
+ return FALSE;
+#endif
+}
+
+
+/*
+ * Compute output image dimensions and related values.
+ * NOTE: this is exported for possible use by application.
+ * Hence it mustn't do anything that can't be done twice.
+ * Also note that it may be called before the master module is initialized!
+ */
+
+GLOBAL(void)
+jpeg_calc_output_dimensions (j_decompress_ptr cinfo)
+/* Do computations that are needed before master selection phase */
+{
+#ifdef IDCT_SCALING_SUPPORTED
+ int ci;
+ jpeg_component_info *compptr;
+#endif
+
+ /* Prevent application from calling me at wrong times */
+ if (cinfo->global_state != DSTATE_READY)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+#ifdef IDCT_SCALING_SUPPORTED
+
+ /* Compute actual output image dimensions and DCT scaling choices. */
+ if (cinfo->scale_num * 8 <= cinfo->scale_denom) {
+ /* Provide 1/8 scaling */
+ cinfo->output_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width, 8L);
+ cinfo->output_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height, 8L);
+ cinfo->min_DCT_scaled_size = 1;
+ } else if (cinfo->scale_num * 4 <= cinfo->scale_denom) {
+ /* Provide 1/4 scaling */
+ cinfo->output_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width, 4L);
+ cinfo->output_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height, 4L);
+ cinfo->min_DCT_scaled_size = 2;
+ } else if (cinfo->scale_num * 2 <= cinfo->scale_denom) {
+ /* Provide 1/2 scaling */
+ cinfo->output_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width, 2L);
+ cinfo->output_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height, 2L);
+ cinfo->min_DCT_scaled_size = 4;
+ } else {
+ /* Provide 1/1 scaling */
+ cinfo->output_width = cinfo->image_width;
+ cinfo->output_height = cinfo->image_height;
+ cinfo->min_DCT_scaled_size = DCTSIZE;
+ }
+ /* In selecting the actual DCT scaling for each component, we try to
+ * scale up the chroma components via IDCT scaling rather than upsampling.
+ * This saves time if the upsampler gets to use 1:1 scaling.
+ * Note this code assumes that the supported DCT scalings are powers of 2.
+ */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ int ssize = cinfo->min_DCT_scaled_size;
+ while (ssize < DCTSIZE &&
+ (compptr->h_samp_factor * ssize * 2 <=
+ cinfo->max_h_samp_factor * cinfo->min_DCT_scaled_size) &&
+ (compptr->v_samp_factor * ssize * 2 <=
+ cinfo->max_v_samp_factor * cinfo->min_DCT_scaled_size)) {
+ ssize = ssize * 2;
+ }
+ compptr->DCT_scaled_size = ssize;
+ }
+
+ /* Recompute downsampled dimensions of components;
+ * application needs to know these if using raw downsampled data.
+ */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Size in samples, after IDCT scaling */
+ compptr->downsampled_width = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_width *
+ (long) (compptr->h_samp_factor * compptr->DCT_scaled_size),
+ (long) (cinfo->max_h_samp_factor * DCTSIZE));
+ compptr->downsampled_height = (JDIMENSION)
+ jdiv_round_up((long) cinfo->image_height *
+ (long) (compptr->v_samp_factor * compptr->DCT_scaled_size),
+ (long) (cinfo->max_v_samp_factor * DCTSIZE));
+ }
+
+#else /* !IDCT_SCALING_SUPPORTED */
+
+ /* Hardwire it to "no scaling" */
+ cinfo->output_width = cinfo->image_width;
+ cinfo->output_height = cinfo->image_height;
+ /* jdinput.c has already initialized DCT_scaled_size to DCTSIZE,
+ * and has computed unscaled downsampled_width and downsampled_height.
+ */
+
+#endif /* IDCT_SCALING_SUPPORTED */
+
+ /* Report number of components in selected colorspace. */
+ /* Probably this should be in the color conversion module... */
+ switch (cinfo->out_color_space) {
+ case JCS_GRAYSCALE:
+ cinfo->out_color_components = 1;
+ break;
+ case JCS_RGB:
+#if RGB_PIXELSIZE != 3
+ cinfo->out_color_components = RGB_PIXELSIZE;
+ break;
+#endif /* else share code with YCbCr */
+ case JCS_YCbCr:
+ cinfo->out_color_components = 3;
+ break;
+ case JCS_CMYK:
+ case JCS_YCCK:
+ cinfo->out_color_components = 4;
+ break;
+ default: /* else must be same colorspace as in file */
+ cinfo->out_color_components = cinfo->num_components;
+ break;
+ }
+ cinfo->output_components = (cinfo->quantize_colors ? 1 :
+ cinfo->out_color_components);
+
+ /* See if upsampler will want to emit more than one row at a time */
+ if (use_merged_upsample(cinfo))
+ cinfo->rec_outbuf_height = cinfo->max_v_samp_factor;
+ else
+ cinfo->rec_outbuf_height = 1;
+}
+
+
+/*
+ * Several decompression processes need to range-limit values to the range
+ * 0..MAXJSAMPLE; the input value may fall somewhat outside this range
+ * due to noise introduced by quantization, roundoff error, etc. These
+ * processes are inner loops and need to be as fast as possible. On most
+ * machines, particularly CPUs with pipelines or instruction prefetch,
+ * a (subscript-check-less) C table lookup
+ * x = sample_range_limit[x];
+ * is faster than explicit tests
+ * if (x < 0) x = 0;
+ * else if (x > MAXJSAMPLE) x = MAXJSAMPLE;
+ * These processes all use a common table prepared by the routine below.
+ *
+ * For most steps we can mathematically guarantee that the initial value
+ * of x is within MAXJSAMPLE+1 of the legal range, so a table running from
+ * -(MAXJSAMPLE+1) to 2*MAXJSAMPLE+1 is sufficient. But for the initial
+ * limiting step (just after the IDCT), a wildly out-of-range value is
+ * possible if the input data is corrupt. To avoid any chance of indexing
+ * off the end of memory and getting a bad-pointer trap, we perform the
+ * post-IDCT limiting thus:
+ * x = range_limit[x & MASK];
+ * where MASK is 2 bits wider than legal sample data, ie 10 bits for 8-bit
+ * samples. Under normal circumstances this is more than enough range and
+ * a correct output will be generated; with bogus input data the mask will
+ * cause wraparound, and we will safely generate a bogus-but-in-range output.
+ * For the post-IDCT step, we want to convert the data from signed to unsigned
+ * representation by adding CENTERJSAMPLE at the same time that we limit it.
+ * So the post-IDCT limiting table ends up looking like this:
+ * CENTERJSAMPLE,CENTERJSAMPLE+1,...,MAXJSAMPLE,
+ * MAXJSAMPLE (repeat 2*(MAXJSAMPLE+1)-CENTERJSAMPLE times),
+ * 0 (repeat 2*(MAXJSAMPLE+1)-CENTERJSAMPLE times),
+ * 0,1,...,CENTERJSAMPLE-1
+ * Negative inputs select values from the upper half of the table after
+ * masking.
+ *
+ * We can save some space by overlapping the start of the post-IDCT table
+ * with the simpler range limiting table. The post-IDCT table begins at
+ * sample_range_limit + CENTERJSAMPLE.
+ *
+ * Note that the table is allocated in near data space on PCs; it's small
+ * enough and used often enough to justify this.
+ */
+
+LOCAL(void)
+prepare_range_limit_table (j_decompress_ptr cinfo)
+/* Allocate and fill in the sample_range_limit table */
+{
+ JSAMPLE * table;
+ int i;
+
+ table = (JSAMPLE *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (5 * (MAXJSAMPLE+1) + CENTERJSAMPLE) * SIZEOF(JSAMPLE));
+ table += (MAXJSAMPLE+1); /* allow negative subscripts of simple table */
+ cinfo->sample_range_limit = table;
+ /* First segment of "simple" table: limit[x] = 0 for x < 0 */
+ MEMZERO(table - (MAXJSAMPLE+1), (MAXJSAMPLE+1) * SIZEOF(JSAMPLE));
+ /* Main part of "simple" table: limit[x] = x */
+ for (i = 0; i <= MAXJSAMPLE; i++)
+ table[i] = (JSAMPLE) i;
+ table += CENTERJSAMPLE; /* Point to where post-IDCT table starts */
+ /* End of simple table, rest of first half of post-IDCT table */
+ for (i = CENTERJSAMPLE; i < 2*(MAXJSAMPLE+1); i++)
+ table[i] = MAXJSAMPLE;
+ /* Second half of post-IDCT table */
+ MEMZERO(table + (2 * (MAXJSAMPLE+1)),
+ (2 * (MAXJSAMPLE+1) - CENTERJSAMPLE) * SIZEOF(JSAMPLE));
+ MEMCOPY(table + (4 * (MAXJSAMPLE+1) - CENTERJSAMPLE),
+ cinfo->sample_range_limit, CENTERJSAMPLE * SIZEOF(JSAMPLE));
+}
+
+
+/*
+ * Master selection of decompression modules.
+ * This is done once at jpeg_start_decompress time. We determine
+ * which modules will be used and give them appropriate initialization calls.
+ * We also initialize the decompressor input side to begin consuming data.
+ *
+ * Since jpeg_read_header has finished, we know what is in the SOF
+ * and (first) SOS markers. We also have all the application parameter
+ * settings.
+ */
+
+LOCAL(void)
+master_selection (j_decompress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+ int use_c_buffer;
+ long samplesperrow;
+ JDIMENSION jd_samplesperrow;
+
+ /* Initialize dimensions and other stuff */
+ jpeg_calc_output_dimensions(cinfo);
+ prepare_range_limit_table(cinfo);
+
+ /* Width of an output scanline must be representable as JDIMENSION. */
+ samplesperrow = (long) cinfo->output_width * (long) cinfo->out_color_components;
+ jd_samplesperrow = (JDIMENSION) samplesperrow;
+ if ((long) jd_samplesperrow != samplesperrow)
+ ERREXIT(cinfo, JERR_WIDTH_OVERFLOW);
+
+ /* Initialize my private state */
+ master->pass_number = 0;
+ master->using_merged_upsample = use_merged_upsample(cinfo);
+
+ /* Color quantizer selection */
+ master->quantizer_1pass = NULL;
+ master->quantizer_2pass = NULL;
+ /* No mode changes if not using buffered-image mode. */
+ if (! cinfo->quantize_colors || ! cinfo->buffered_image) {
+ cinfo->enable_1pass_quant = FALSE;
+ cinfo->enable_external_quant = FALSE;
+ cinfo->enable_2pass_quant = FALSE;
+ }
+ if (cinfo->quantize_colors) {
+ if (cinfo->raw_data_out)
+ ERREXIT(cinfo, JERR_NOTIMPL);
+ /* 2-pass quantizer only works in 3-component color space. */
+ if (cinfo->out_color_components != 3) {
+ cinfo->enable_1pass_quant = TRUE;
+ cinfo->enable_external_quant = FALSE;
+ cinfo->enable_2pass_quant = FALSE;
+ cinfo->colormap = NULL;
+ } else if (cinfo->colormap != NULL) {
+ cinfo->enable_external_quant = TRUE;
+ } else if (cinfo->two_pass_quantize) {
+ cinfo->enable_2pass_quant = TRUE;
+ } else {
+ cinfo->enable_1pass_quant = TRUE;
+ }
+
+ if (cinfo->enable_1pass_quant) {
+#ifdef QUANT_1PASS_SUPPORTED
+ jinit_1pass_quantizer(cinfo);
+ master->quantizer_1pass = cinfo->cquantize;
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ }
+
+ /* We use the 2-pass code to map to external colormaps. */
+ if (cinfo->enable_2pass_quant || cinfo->enable_external_quant) {
+#ifdef QUANT_2PASS_SUPPORTED
+ jinit_2pass_quantizer(cinfo);
+ master->quantizer_2pass = cinfo->cquantize;
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ }
+ /* If both quantizers are initialized, the 2-pass one is left active;
+ * this is necessary for starting with quantization to an external map.
+ */
+ }
+
+ /* Post-processing: in particular, color conversion first */
+ if (! cinfo->raw_data_out) {
+ if (master->using_merged_upsample) {
+#ifdef UPSAMPLE_MERGING_SUPPORTED
+ jinit_merged_upsampler(cinfo); /* does color conversion too */
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else {
+ jinit_color_deconverter(cinfo);
+ jinit_upsampler(cinfo);
+ }
+ jinit_d_post_controller(cinfo, cinfo->enable_2pass_quant);
+ }
+ /* Inverse DCT */
+ jinit_inverse_dct(cinfo);
+ /* Entropy decoding: either Huffman or arithmetic coding. */
+ if (cinfo->arith_code) {
+ ERREXIT(cinfo, JERR_ARITH_NOTIMPL);
+ } else {
+ if (cinfo->progressive_mode) {
+#ifdef D_PROGRESSIVE_SUPPORTED
+ jinit_phuff_decoder(cinfo);
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif
+ } else
+ jinit_huff_decoder(cinfo);
+ }
+
+ /* Initialize principal buffer controllers. */
+ use_c_buffer = cinfo->inputctl->has_multiple_scans || cinfo->buffered_image;
+ jinit_d_coef_controller(cinfo, use_c_buffer);
+
+ if (! cinfo->raw_data_out)
+ jinit_d_main_controller(cinfo, FALSE /* never need full buffer here */);
+
+ /* We can now tell the memory manager to allocate virtual arrays. */
+ (*cinfo->mem->realize_virt_arrays) ((j_common_ptr) cinfo);
+
+ /* Initialize input side of decompressor to consume first scan. */
+ (*cinfo->inputctl->start_input_pass) (cinfo);
+
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+ /* If jpeg_start_decompress will read the whole file, initialize
+ * progress monitoring appropriately. The input step is counted
+ * as one pass.
+ */
+ if (cinfo->progress != NULL && ! cinfo->buffered_image &&
+ cinfo->inputctl->has_multiple_scans) {
+ int nscans;
+ /* Estimate number of scans to set pass_limit. */
+ if (cinfo->progressive_mode) {
+ /* Arbitrarily estimate 2 interleaved DC scans + 3 AC scans/component. */
+ nscans = 2 + 3 * cinfo->num_components;
+ } else {
+ /* For a nonprogressive multiscan file, estimate 1 scan per component. */
+ nscans = cinfo->num_components;
+ }
+ cinfo->progress->pass_counter = 0L;
+ cinfo->progress->pass_limit = (long) cinfo->total_iMCU_rows * nscans;
+ cinfo->progress->completed_passes = 0;
+ cinfo->progress->total_passes = (cinfo->enable_2pass_quant ? 3 : 2);
+ /* Count the input pass as done */
+ master->pass_number++;
+ }
+#endif /* D_MULTISCAN_FILES_SUPPORTED */
+}
+
+
+/*
+ * Per-pass setup.
+ * This is called at the beginning of each output pass. We determine which
+ * modules will be active during this pass and give them appropriate
+ * start_pass calls. We also set is_dummy_pass to indicate whether this
+ * is a "real" output pass or a dummy pass for color quantization.
+ * (In the latter case, jdapistd.c will crank the pass to completion.)
+ */
+
+METHODDEF(void)
+prepare_for_output_pass (j_decompress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+
+ if (master->pub.is_dummy_pass) {
+#ifdef QUANT_2PASS_SUPPORTED
+ /* Final pass of 2-pass quantization */
+ master->pub.is_dummy_pass = FALSE;
+ (*cinfo->cquantize->start_pass) (cinfo, FALSE);
+ (*cinfo->post->start_pass) (cinfo, JBUF_CRANK_DEST);
+ (*cinfo->main->start_pass) (cinfo, JBUF_CRANK_DEST);
+#else
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+#endif /* QUANT_2PASS_SUPPORTED */
+ } else {
+ if (cinfo->quantize_colors && cinfo->colormap == NULL) {
+ /* Select new quantization method */
+ if (cinfo->two_pass_quantize && cinfo->enable_2pass_quant) {
+ cinfo->cquantize = master->quantizer_2pass;
+ master->pub.is_dummy_pass = TRUE;
+ } else if (cinfo->enable_1pass_quant) {
+ cinfo->cquantize = master->quantizer_1pass;
+ } else {
+ ERREXIT(cinfo, JERR_MODE_CHANGE);
+ }
+ }
+ (*cinfo->idct->start_pass) (cinfo);
+ (*cinfo->coef->start_output_pass) (cinfo);
+ if (! cinfo->raw_data_out) {
+ if (! master->using_merged_upsample)
+ (*cinfo->cconvert->start_pass) (cinfo);
+ (*cinfo->upsample->start_pass) (cinfo);
+ if (cinfo->quantize_colors)
+ (*cinfo->cquantize->start_pass) (cinfo, master->pub.is_dummy_pass);
+ (*cinfo->post->start_pass) (cinfo,
+ (master->pub.is_dummy_pass ? JBUF_SAVE_AND_PASS : JBUF_PASS_THRU));
+ (*cinfo->main->start_pass) (cinfo, JBUF_PASS_THRU);
+ }
+ }
+
+ /* Set up progress monitor's pass info if present */
+ if (cinfo->progress != NULL) {
+ cinfo->progress->completed_passes = master->pass_number;
+ cinfo->progress->total_passes = master->pass_number +
+ (master->pub.is_dummy_pass ? 2 : 1);
+ /* In buffered-image mode, we assume one more output pass if EOI not
+ * yet reached, but no more passes if EOI has been reached.
+ */
+ if (cinfo->buffered_image && ! cinfo->inputctl->eoi_reached) {
+ cinfo->progress->total_passes += (cinfo->enable_2pass_quant ? 2 : 1);
+ }
+ }
+}
+
+
+/*
+ * Finish up at end of an output pass.
+ */
+
+METHODDEF(void)
+finish_output_pass (j_decompress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+
+ if (cinfo->quantize_colors)
+ (*cinfo->cquantize->finish_pass) (cinfo);
+ master->pass_number++;
+}
+
+
+#ifdef D_MULTISCAN_FILES_SUPPORTED
+
+/*
+ * Switch to a new external colormap between output passes.
+ */
+
+GLOBAL(void)
+jpeg_new_colormap (j_decompress_ptr cinfo)
+{
+ my_master_ptr master = (my_master_ptr) cinfo->master;
+
+ /* Prevent application from calling me at wrong times */
+ if (cinfo->global_state != DSTATE_BUFIMAGE)
+ ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state);
+
+ if (cinfo->quantize_colors && cinfo->enable_external_quant &&
+ cinfo->colormap != NULL) {
+ /* Select 2-pass quantizer for external colormap use */
+ cinfo->cquantize = master->quantizer_2pass;
+ /* Notify quantizer of colormap change */
+ (*cinfo->cquantize->new_color_map) (cinfo);
+ master->pub.is_dummy_pass = FALSE; /* just in case */
+ } else
+ ERREXIT(cinfo, JERR_MODE_CHANGE);
+}
+
+#endif /* D_MULTISCAN_FILES_SUPPORTED */
+
+
+/*
+ * Initialize master decompression control and select active modules.
+ * This is performed at the start of jpeg_start_decompress.
+ */
+
+GLOBAL(void)
+jinit_master_decompress (j_decompress_ptr cinfo)
+{
+ my_master_ptr master;
+
+ master = (my_master_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_decomp_master));
+ cinfo->master = (struct jpeg_decomp_master *) master;
+ master->pub.prepare_for_output_pass = prepare_for_output_pass;
+ master->pub.finish_output_pass = finish_output_pass;
+
+ master->pub.is_dummy_pass = FALSE;
+
+ master_selection(cinfo);
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdmerge.cpp b/ml/dlib/dlib/external/libjpeg/jdmerge.cpp
new file mode 100644
index 000000000..38a692ff8
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdmerge.cpp
@@ -0,0 +1,400 @@
+/*
+ * jdmerge.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains code for merged upsampling/color conversion.
+ *
+ * This file combines functions from jdsample.c and jdcolor.c;
+ * read those files first to understand what's going on.
+ *
+ * When the chroma components are to be upsampled by simple replication
+ * (ie, box filtering), we can save some work in color conversion by
+ * calculating all the output pixels corresponding to a pair of chroma
+ * samples at one time. In the conversion equations
+ * R = Y + K1 * Cr
+ * G = Y + K2 * Cb + K3 * Cr
+ * B = Y + K4 * Cb
+ * only the Y term varies among the group of pixels corresponding to a pair
+ * of chroma samples, so the rest of the terms can be calculated just once.
+ * At typical sampling ratios, this eliminates half or three-quarters of the
+ * multiplications needed for color conversion.
+ *
+ * This file currently provides implementations for the following cases:
+ * YCbCr => RGB color conversion only.
+ * Sampling ratios of 2h1v or 2h2v.
+ * No scaling needed at upsample time.
+ * Corner-aligned (non-CCIR601) sampling alignment.
+ * Other special cases could be added, but in most applications these are
+ * the only common cases. (For uncommon cases we fall back on the more
+ * general code in jdsample.c and jdcolor.c.)
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+#ifdef UPSAMPLE_MERGING_SUPPORTED
+
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_upsampler pub; /* public fields */
+
+ /* Pointer to routine to do actual upsampling/conversion of one row group */
+ JMETHOD(void, upmethod, (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr,
+ JSAMPARRAY output_buf));
+
+ /* Private state for YCC->RGB conversion */
+ int * Cr_r_tab; /* => table for Cr to R conversion */
+ int * Cb_b_tab; /* => table for Cb to B conversion */
+ long * Cr_g_tab; /* => table for Cr to G conversion */
+ long * Cb_g_tab; /* => table for Cb to G conversion */
+
+ /* For 2:1 vertical sampling, we produce two output rows at a time.
+ * We need a "spare" row buffer to hold the second output row if the
+ * application provides just a one-row buffer; we also use the spare
+ * to discard the dummy last row if the image height is odd.
+ */
+ JSAMPROW spare_row;
+ int spare_full; /* T if spare buffer is occupied */
+
+ JDIMENSION out_row_width; /* samples per output row */
+ JDIMENSION rows_to_go; /* counts rows remaining in image */
+} my_upsampler;
+
+typedef my_upsampler * my_upsample_ptr;
+
+#define SCALEBITS 16 /* speediest right-shift on some machines */
+#define ONE_HALF ((long) 1 << (SCALEBITS-1))
+#define FIX(x) ((long) ((x) * (1L<<SCALEBITS) + 0.5))
+
+
+/*
+ * Initialize tables for YCC->RGB colorspace conversion.
+ * This is taken directly from jdcolor.c; see that file for more info.
+ */
+
+LOCAL(void)
+build_ycc_rgb_table (j_decompress_ptr cinfo)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ int i;
+ long x;
+ SHIFT_TEMPS
+
+ upsample->Cr_r_tab = (int *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(int));
+ upsample->Cb_b_tab = (int *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(int));
+ upsample->Cr_g_tab = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(long));
+ upsample->Cb_g_tab = (long *)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (MAXJSAMPLE+1) * SIZEOF(long));
+
+ for (i = 0, x = -CENTERJSAMPLE; i <= MAXJSAMPLE; i++, x++) {
+ /* i is the actual input pixel value, in the range 0..MAXJSAMPLE */
+ /* The Cb or Cr value we are thinking of is x = i - CENTERJSAMPLE */
+ /* Cr=>R value is nearest int to 1.40200 * x */
+ upsample->Cr_r_tab[i] = (int)
+ RIGHT_SHIFT(FIX(1.40200) * x + ONE_HALF, SCALEBITS);
+ /* Cb=>B value is nearest int to 1.77200 * x */
+ upsample->Cb_b_tab[i] = (int)
+ RIGHT_SHIFT(FIX(1.77200) * x + ONE_HALF, SCALEBITS);
+ /* Cr=>G value is scaled-up -0.71414 * x */
+ upsample->Cr_g_tab[i] = (- FIX(0.71414)) * x;
+ /* Cb=>G value is scaled-up -0.34414 * x */
+ /* We also add in ONE_HALF so that need not do it in inner loop */
+ upsample->Cb_g_tab[i] = (- FIX(0.34414)) * x + ONE_HALF;
+ }
+}
+
+
+/*
+ * Initialize for an upsampling pass.
+ */
+
+METHODDEF(void)
+start_pass_merged_upsample (j_decompress_ptr cinfo)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+
+ /* Mark the spare buffer empty */
+ upsample->spare_full = FALSE;
+ /* Initialize total-height counter for detecting bottom of image */
+ upsample->rows_to_go = cinfo->output_height;
+}
+
+
+/*
+ * Control routine to do upsampling (and color conversion).
+ *
+ * The control routine just handles the row buffering considerations.
+ */
+
+METHODDEF(void)
+merged_2v_upsample (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION ,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+/* 2:1 vertical sampling case: may need a spare row. */
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ JSAMPROW work_ptrs[2];
+ JDIMENSION num_rows; /* number of rows returned to caller */
+
+ if (upsample->spare_full) {
+ /* If we have a spare row saved from a previous cycle, just return it. */
+ jcopy_sample_rows(& upsample->spare_row, 0, output_buf + *out_row_ctr, 0,
+ 1, upsample->out_row_width);
+ num_rows = 1;
+ upsample->spare_full = FALSE;
+ } else {
+ /* Figure number of rows to return to caller. */
+ num_rows = 2;
+ /* Not more than the distance to the end of the image. */
+ if (num_rows > upsample->rows_to_go)
+ num_rows = upsample->rows_to_go;
+ /* And not more than what the client can accept: */
+ out_rows_avail -= *out_row_ctr;
+ if (num_rows > out_rows_avail)
+ num_rows = out_rows_avail;
+ /* Create output pointer array for upsampler. */
+ work_ptrs[0] = output_buf[*out_row_ctr];
+ if (num_rows > 1) {
+ work_ptrs[1] = output_buf[*out_row_ctr + 1];
+ } else {
+ work_ptrs[1] = upsample->spare_row;
+ upsample->spare_full = TRUE;
+ }
+ /* Now do the upsampling. */
+ (*upsample->upmethod) (cinfo, input_buf, *in_row_group_ctr, work_ptrs);
+ }
+
+ /* Adjust counts */
+ *out_row_ctr += num_rows;
+ upsample->rows_to_go -= num_rows;
+ /* When the buffer is emptied, declare this input row group consumed */
+ if (! upsample->spare_full)
+ (*in_row_group_ctr)++;
+}
+
+
+METHODDEF(void)
+merged_1v_upsample (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION ,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION )
+/* 1:1 vertical sampling case: much easier, never need a spare row. */
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+
+ /* Just do the upsampling. */
+ (*upsample->upmethod) (cinfo, input_buf, *in_row_group_ctr,
+ output_buf + *out_row_ctr);
+ /* Adjust counts */
+ (*out_row_ctr)++;
+ (*in_row_group_ctr)++;
+}
+
+
+/*
+ * These are the routines invoked by the control routines to do
+ * the actual upsampling/conversion. One row group is processed per call.
+ *
+ * Note: since we may be writing directly into application-supplied buffers,
+ * we have to be honest about the output width; we can't assume the buffer
+ * has been rounded up to an even width.
+ */
+
+
+/*
+ * Upsample and color convert for the case of 2:1 horizontal and 1:1 vertical.
+ */
+
+METHODDEF(void)
+h2v1_merged_upsample (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr,
+ JSAMPARRAY output_buf)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ int y, cred, cgreen, cblue;
+ int cb, cr;
+ JSAMPROW outptr;
+ JSAMPROW inptr0, inptr1, inptr2;
+ JDIMENSION col;
+ /* copy these pointers into registers if possible */
+ JSAMPLE * range_limit = cinfo->sample_range_limit;
+ int * Crrtab = upsample->Cr_r_tab;
+ int * Cbbtab = upsample->Cb_b_tab;
+ long * Crgtab = upsample->Cr_g_tab;
+ long * Cbgtab = upsample->Cb_g_tab;
+ SHIFT_TEMPS
+
+ inptr0 = input_buf[0][in_row_group_ctr];
+ inptr1 = input_buf[1][in_row_group_ctr];
+ inptr2 = input_buf[2][in_row_group_ctr];
+ outptr = output_buf[0];
+ /* Loop for each pair of output pixels */
+ for (col = cinfo->output_width >> 1; col > 0; col--) {
+ /* Do the chroma part of the calculation */
+ cb = GETJSAMPLE(*inptr1++);
+ cr = GETJSAMPLE(*inptr2++);
+ cred = Crrtab[cr];
+ cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS);
+ cblue = Cbbtab[cb];
+ /* Fetch 2 Y values and emit 2 pixels */
+ y = GETJSAMPLE(*inptr0++);
+ outptr[RGB_RED] = range_limit[y + cred];
+ outptr[RGB_GREEN] = range_limit[y + cgreen];
+ outptr[RGB_BLUE] = range_limit[y + cblue];
+ outptr += RGB_PIXELSIZE;
+ y = GETJSAMPLE(*inptr0++);
+ outptr[RGB_RED] = range_limit[y + cred];
+ outptr[RGB_GREEN] = range_limit[y + cgreen];
+ outptr[RGB_BLUE] = range_limit[y + cblue];
+ outptr += RGB_PIXELSIZE;
+ }
+ /* If image width is odd, do the last output column separately */
+ if (cinfo->output_width & 1) {
+ cb = GETJSAMPLE(*inptr1);
+ cr = GETJSAMPLE(*inptr2);
+ cred = Crrtab[cr];
+ cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS);
+ cblue = Cbbtab[cb];
+ y = GETJSAMPLE(*inptr0);
+ outptr[RGB_RED] = range_limit[y + cred];
+ outptr[RGB_GREEN] = range_limit[y + cgreen];
+ outptr[RGB_BLUE] = range_limit[y + cblue];
+ }
+}
+
+
+/*
+ * Upsample and color convert for the case of 2:1 horizontal and 2:1 vertical.
+ */
+
+METHODDEF(void)
+h2v2_merged_upsample (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr,
+ JSAMPARRAY output_buf)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ int y, cred, cgreen, cblue;
+ int cb, cr;
+ JSAMPROW outptr0, outptr1;
+ JSAMPROW inptr00, inptr01, inptr1, inptr2;
+ JDIMENSION col;
+ /* copy these pointers into registers if possible */
+ JSAMPLE * range_limit = cinfo->sample_range_limit;
+ int * Crrtab = upsample->Cr_r_tab;
+ int * Cbbtab = upsample->Cb_b_tab;
+ long * Crgtab = upsample->Cr_g_tab;
+ long * Cbgtab = upsample->Cb_g_tab;
+ SHIFT_TEMPS
+
+ inptr00 = input_buf[0][in_row_group_ctr*2];
+ inptr01 = input_buf[0][in_row_group_ctr*2 + 1];
+ inptr1 = input_buf[1][in_row_group_ctr];
+ inptr2 = input_buf[2][in_row_group_ctr];
+ outptr0 = output_buf[0];
+ outptr1 = output_buf[1];
+ /* Loop for each group of output pixels */
+ for (col = cinfo->output_width >> 1; col > 0; col--) {
+ /* Do the chroma part of the calculation */
+ cb = GETJSAMPLE(*inptr1++);
+ cr = GETJSAMPLE(*inptr2++);
+ cred = Crrtab[cr];
+ cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS);
+ cblue = Cbbtab[cb];
+ /* Fetch 4 Y values and emit 4 pixels */
+ y = GETJSAMPLE(*inptr00++);
+ outptr0[RGB_RED] = range_limit[y + cred];
+ outptr0[RGB_GREEN] = range_limit[y + cgreen];
+ outptr0[RGB_BLUE] = range_limit[y + cblue];
+ outptr0 += RGB_PIXELSIZE;
+ y = GETJSAMPLE(*inptr00++);
+ outptr0[RGB_RED] = range_limit[y + cred];
+ outptr0[RGB_GREEN] = range_limit[y + cgreen];
+ outptr0[RGB_BLUE] = range_limit[y + cblue];
+ outptr0 += RGB_PIXELSIZE;
+ y = GETJSAMPLE(*inptr01++);
+ outptr1[RGB_RED] = range_limit[y + cred];
+ outptr1[RGB_GREEN] = range_limit[y + cgreen];
+ outptr1[RGB_BLUE] = range_limit[y + cblue];
+ outptr1 += RGB_PIXELSIZE;
+ y = GETJSAMPLE(*inptr01++);
+ outptr1[RGB_RED] = range_limit[y + cred];
+ outptr1[RGB_GREEN] = range_limit[y + cgreen];
+ outptr1[RGB_BLUE] = range_limit[y + cblue];
+ outptr1 += RGB_PIXELSIZE;
+ }
+ /* If image width is odd, do the last output column separately */
+ if (cinfo->output_width & 1) {
+ cb = GETJSAMPLE(*inptr1);
+ cr = GETJSAMPLE(*inptr2);
+ cred = Crrtab[cr];
+ cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS);
+ cblue = Cbbtab[cb];
+ y = GETJSAMPLE(*inptr00);
+ outptr0[RGB_RED] = range_limit[y + cred];
+ outptr0[RGB_GREEN] = range_limit[y + cgreen];
+ outptr0[RGB_BLUE] = range_limit[y + cblue];
+ y = GETJSAMPLE(*inptr01);
+ outptr1[RGB_RED] = range_limit[y + cred];
+ outptr1[RGB_GREEN] = range_limit[y + cgreen];
+ outptr1[RGB_BLUE] = range_limit[y + cblue];
+ }
+}
+
+
+/*
+ * Module initialization routine for merged upsampling/color conversion.
+ *
+ * NB: this is called under the conditions determined by use_merged_upsample()
+ * in jdmaster.c. That routine MUST correspond to the actual capabilities
+ * of this module; no safety checks are made here.
+ */
+
+GLOBAL(void)
+jinit_merged_upsampler (j_decompress_ptr cinfo)
+{
+ my_upsample_ptr upsample;
+
+ upsample = (my_upsample_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_upsampler));
+ cinfo->upsample = (struct jpeg_upsampler *) upsample;
+ upsample->pub.start_pass = start_pass_merged_upsample;
+ upsample->pub.need_context_rows = FALSE;
+
+ upsample->out_row_width = cinfo->output_width * cinfo->out_color_components;
+
+ if (cinfo->max_v_samp_factor == 2) {
+ upsample->pub.upsample = merged_2v_upsample;
+ upsample->upmethod = h2v2_merged_upsample;
+ /* Allocate a spare row buffer */
+ upsample->spare_row = (JSAMPROW)
+ (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (size_t) (upsample->out_row_width * SIZEOF(JSAMPLE)));
+ } else {
+ upsample->pub.upsample = merged_1v_upsample;
+ upsample->upmethod = h2v1_merged_upsample;
+ /* No spare row needed */
+ upsample->spare_row = NULL;
+ }
+
+ build_ycc_rgb_table(cinfo);
+}
+
+#endif /* UPSAMPLE_MERGING_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jdphuff.cpp b/ml/dlib/dlib/external/libjpeg/jdphuff.cpp
new file mode 100644
index 000000000..d9c02fe0b
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdphuff.cpp
@@ -0,0 +1,671 @@
+/*
+ * jdphuff.c
+ *
+ * Copyright (C) 1995-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains Huffman entropy decoding routines for progressive JPEG.
+ *
+ * Much of the complexity here has to do with supporting input suspension.
+ * If the data source module demands suspension, we want to be able to back
+ * up to the start of the current MCU. To do this, we copy state variables
+ * into local working storage, and update them back to the permanent
+ * storage only upon successful completion of an MCU.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdhuff.h" /* Declarations shared with jdhuff.c */
+
+#ifdef __GNUC__
+#pragma GCC diagnostic ignored "-Wshift-negative-value"
+#endif
+
+#ifdef D_PROGRESSIVE_SUPPORTED
+
+/*
+ * Expanded entropy decoder object for progressive Huffman decoding.
+ *
+ * The savable_state subrecord contains fields that change within an MCU,
+ * but must not be updated permanently until we complete the MCU.
+ */
+
+typedef struct {
+ unsigned int EOBRUN; /* remaining EOBs in EOBRUN */
+ int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */
+} savable_state;
+
+/* This macro is to work around compilers with missing or broken
+ * structure assignment. You'll need to fix this code if you have
+ * such a compiler and you change MAX_COMPS_IN_SCAN.
+ */
+
+#ifndef NO_STRUCT_ASSIGN
+#define ASSIGN_STATE(dest,src) ((dest) = (src))
+#else
+#if MAX_COMPS_IN_SCAN == 4
+#define ASSIGN_STATE(dest,src) \
+ ((dest).EOBRUN = (src).EOBRUN, \
+ (dest).last_dc_val[0] = (src).last_dc_val[0], \
+ (dest).last_dc_val[1] = (src).last_dc_val[1], \
+ (dest).last_dc_val[2] = (src).last_dc_val[2], \
+ (dest).last_dc_val[3] = (src).last_dc_val[3])
+#endif
+#endif
+
+
+typedef struct {
+ struct jpeg_entropy_decoder pub; /* public fields */
+
+ /* These fields are loaded into local variables at start of each MCU.
+ * In case of suspension, we exit WITHOUT updating them.
+ */
+ bitread_perm_state bitstate; /* Bit buffer at start of MCU */
+ savable_state saved; /* Other state at start of MCU */
+
+ /* These fields are NOT loaded into local working state. */
+ unsigned int restarts_to_go; /* MCUs left in this restart interval */
+
+ /* Pointers to derived tables (these workspaces have image lifespan) */
+ d_derived_tbl * derived_tbls[NUM_HUFF_TBLS];
+
+ d_derived_tbl * ac_derived_tbl; /* active table during an AC scan */
+} phuff_entropy_decoder;
+
+typedef phuff_entropy_decoder * phuff_entropy_ptr;
+
+/* Forward declarations */
+METHODDEF(int) decode_mcu_DC_first JPP((j_decompress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) decode_mcu_AC_first JPP((j_decompress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) decode_mcu_DC_refine JPP((j_decompress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+METHODDEF(int) decode_mcu_AC_refine JPP((j_decompress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+
+
+/*
+ * Initialize for a Huffman-compressed scan.
+ */
+
+METHODDEF(void)
+start_pass_phuff_decoder (j_decompress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int is_DC_band, bad;
+ int ci, coefi, tbl;
+ int *coef_bit_ptr;
+ jpeg_component_info * compptr;
+
+ is_DC_band = (cinfo->Ss == 0);
+
+ /* Validate scan parameters */
+ bad = FALSE;
+ if (is_DC_band) {
+ if (cinfo->Se != 0)
+ bad = TRUE;
+ } else {
+ /* need not check Ss/Se < 0 since they came from unsigned bytes */
+ if (cinfo->Ss > cinfo->Se || cinfo->Se >= DCTSIZE2)
+ bad = TRUE;
+ /* AC scans may have only one component */
+ if (cinfo->comps_in_scan != 1)
+ bad = TRUE;
+ }
+ if (cinfo->Ah != 0) {
+ /* Successive approximation refinement scan: must have Al = Ah-1. */
+ if (cinfo->Al != cinfo->Ah-1)
+ bad = TRUE;
+ }
+ if (cinfo->Al > 13) /* need not check for < 0 */
+ bad = TRUE;
+ /* Arguably the maximum Al value should be less than 13 for 8-bit precision,
+ * but the spec doesn't say so, and we try to be liberal about what we
+ * accept. Note: large Al values could result in out-of-range DC
+ * coefficients during early scans, leading to bizarre displays due to
+ * overflows in the IDCT math. But we won't crash.
+ */
+ if (bad)
+ ERREXIT4(cinfo, JERR_BAD_PROGRESSION,
+ cinfo->Ss, cinfo->Se, cinfo->Ah, cinfo->Al);
+ /* Update progression status, and verify that scan order is legal.
+ * Note that inter-scan inconsistencies are treated as warnings
+ * not fatal errors ... not clear if this is right way to behave.
+ */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ int cindex = cinfo->cur_comp_info[ci]->component_index;
+ coef_bit_ptr = & cinfo->coef_bits[cindex][0];
+ if (!is_DC_band && coef_bit_ptr[0] < 0) /* AC without prior DC scan */
+ WARNMS2(cinfo, JWRN_BOGUS_PROGRESSION, cindex, 0);
+ for (coefi = cinfo->Ss; coefi <= cinfo->Se; coefi++) {
+ int expected = (coef_bit_ptr[coefi] < 0) ? 0 : coef_bit_ptr[coefi];
+ if (cinfo->Ah != expected)
+ WARNMS2(cinfo, JWRN_BOGUS_PROGRESSION, cindex, coefi);
+ coef_bit_ptr[coefi] = cinfo->Al;
+ }
+ }
+
+ /* Select MCU decoding routine */
+ if (cinfo->Ah == 0) {
+ if (is_DC_band)
+ entropy->pub.decode_mcu = decode_mcu_DC_first;
+ else
+ entropy->pub.decode_mcu = decode_mcu_AC_first;
+ } else {
+ if (is_DC_band)
+ entropy->pub.decode_mcu = decode_mcu_DC_refine;
+ else
+ entropy->pub.decode_mcu = decode_mcu_AC_refine;
+ }
+
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++) {
+ compptr = cinfo->cur_comp_info[ci];
+ /* Make sure requested tables are present, and compute derived tables.
+ * We may build same derived table more than once, but it's not expensive.
+ */
+ if (is_DC_band) {
+ if (cinfo->Ah == 0) { /* DC refinement needs no table */
+ tbl = compptr->dc_tbl_no;
+ jpeg_make_d_derived_tbl(cinfo, TRUE, tbl,
+ & entropy->derived_tbls[tbl]);
+ }
+ } else {
+ tbl = compptr->ac_tbl_no;
+ jpeg_make_d_derived_tbl(cinfo, FALSE, tbl,
+ & entropy->derived_tbls[tbl]);
+ /* remember the single active table */
+ entropy->ac_derived_tbl = entropy->derived_tbls[tbl];
+ }
+ /* Initialize DC predictions to 0 */
+ entropy->saved.last_dc_val[ci] = 0;
+ }
+
+ /* Initialize bitread state variables */
+ entropy->bitstate.bits_left = 0;
+ entropy->bitstate.get_buffer = 0; /* unnecessary, but keeps Purify quiet */
+ entropy->pub.insufficient_data = FALSE;
+
+ /* Initialize private state variables */
+ entropy->saved.EOBRUN = 0;
+
+ /* Initialize restart counter */
+ entropy->restarts_to_go = cinfo->restart_interval;
+}
+
+
+/*
+ * Figure F.12: extend sign bit.
+ * On some machines, a shift and add will be faster than a table lookup.
+ */
+
+#ifdef AVOID_TABLES
+
+#define HUFF_EXTEND(x,s) ((x) < (1<<((s)-1)) ? (x) + (((-1)<<(s)) + 1) : (x))
+
+#else
+
+#define HUFF_EXTEND(x,s) ((x) < extend_test[s] ? (x) + extend_offset[s] : (x))
+
+static const int extend_test[16] = /* entry n is 2**(n-1) */
+ { 0, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080,
+ 0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000 };
+
+static const int extend_offset[16] = /* entry n is (-1 << n) + 1 */
+ { 0, ((-1)<<1) + 1, ((-1)<<2) + 1, ((-1)<<3) + 1, ((-1)<<4) + 1,
+ ((-1)<<5) + 1, ((-1)<<6) + 1, ((-1)<<7) + 1, ((-1)<<8) + 1,
+ ((-1)<<9) + 1, ((-1)<<10) + 1, ((-1)<<11) + 1, ((-1)<<12) + 1,
+ ((-1)<<13) + 1, ((-1)<<14) + 1, ((-1)<<15) + 1 };
+
+#endif /* AVOID_TABLES */
+
+
+/*
+ * Check for a restart marker & resynchronize decoder.
+ * Returns FALSE if must suspend.
+ */
+
+LOCAL(int)
+process_restart (j_decompress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int ci;
+
+ /* Throw away any unused bits remaining in bit buffer; */
+ /* include any full bytes in next_marker's count of discarded bytes */
+ cinfo->marker->discarded_bytes += entropy->bitstate.bits_left / 8;
+ entropy->bitstate.bits_left = 0;
+
+ /* Advance past the RSTn marker */
+ if (! (*cinfo->marker->read_restart_marker) (cinfo))
+ return FALSE;
+
+ /* Re-initialize DC predictions to 0 */
+ for (ci = 0; ci < cinfo->comps_in_scan; ci++)
+ entropy->saved.last_dc_val[ci] = 0;
+ /* Re-init EOB run count, too */
+ entropy->saved.EOBRUN = 0;
+
+ /* Reset restart counter */
+ entropy->restarts_to_go = cinfo->restart_interval;
+
+ /* Reset out-of-data flag, unless read_restart_marker left us smack up
+ * against a marker. In that case we will end up treating the next data
+ * segment as empty, and we can avoid producing bogus output pixels by
+ * leaving the flag set.
+ */
+ if (cinfo->unread_marker == 0)
+ entropy->pub.insufficient_data = FALSE;
+
+ return TRUE;
+}
+
+
+/*
+ * Huffman MCU decoding.
+ * Each of these routines decodes and returns one MCU's worth of
+ * Huffman-compressed coefficients.
+ * The coefficients are reordered from zigzag order into natural array order,
+ * but are not dequantized.
+ *
+ * The i'th block of the MCU is stored into the block pointed to by
+ * MCU_data[i]. WE ASSUME THIS AREA IS INITIALLY ZEROED BY THE CALLER.
+ *
+ * We return FALSE if data source requested suspension. In that case no
+ * changes have been made to permanent state. (Exception: some output
+ * coefficients may already have been assigned. This is harmless for
+ * spectral selection, since we'll just re-assign them on the next call.
+ * Successive approximation AC refinement has to be more careful, however.)
+ */
+
+/*
+ * MCU decoding for DC initial scan (either spectral selection,
+ * or first pass of successive approximation).
+ */
+
+METHODDEF(int)
+decode_mcu_DC_first (j_decompress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int Al = cinfo->Al;
+ int s, r;
+ int blkn, ci;
+ JBLOCKROW block;
+ BITREAD_STATE_VARS;
+ savable_state state;
+ d_derived_tbl * tbl;
+ jpeg_component_info * compptr;
+
+ /* Process restart marker if needed; may have to suspend */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! process_restart(cinfo))
+ return FALSE;
+ }
+
+ /* If we've run out of data, just leave the MCU set to zeroes.
+ * This way, we return uniform gray for the remainder of the segment.
+ */
+ if (! entropy->pub.insufficient_data) {
+
+ /* Load up working state */
+ BITREAD_LOAD_STATE(cinfo,entropy->bitstate);
+ ASSIGN_STATE(state, entropy->saved);
+
+ /* Outer loop handles each block in the MCU */
+
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ block = MCU_data[blkn];
+ ci = cinfo->MCU_membership[blkn];
+ compptr = cinfo->cur_comp_info[ci];
+ tbl = entropy->derived_tbls[compptr->dc_tbl_no];
+
+ /* Decode a single block's worth of coefficients */
+
+ /* Section F.2.2.1: decode the DC coefficient difference */
+ HUFF_DECODE(s, br_state, tbl, return FALSE, label1);
+ if (s) {
+ CHECK_BIT_BUFFER(br_state, s, return FALSE);
+ r = GET_BITS(s);
+ s = HUFF_EXTEND(r, s);
+ }
+
+ /* Convert DC difference to actual value, update last_dc_val */
+ s += state.last_dc_val[ci];
+ state.last_dc_val[ci] = s;
+ /* Scale and output the coefficient (assumes jpeg_natural_order[0]=0) */
+ (*block)[0] = (JCOEF) (s << Al);
+ }
+
+ /* Completed MCU, so update state */
+ BITREAD_SAVE_STATE(cinfo,entropy->bitstate);
+ ASSIGN_STATE(entropy->saved, state);
+ }
+
+ /* Account for restart interval (no-op if not using restarts) */
+ entropy->restarts_to_go--;
+
+ return TRUE;
+}
+
+
+/*
+ * MCU decoding for AC initial scan (either spectral selection,
+ * or first pass of successive approximation).
+ */
+
+METHODDEF(int)
+decode_mcu_AC_first (j_decompress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int Se = cinfo->Se;
+ int Al = cinfo->Al;
+ int s, k, r;
+ unsigned int EOBRUN;
+ JBLOCKROW block;
+ BITREAD_STATE_VARS;
+ d_derived_tbl * tbl;
+
+ /* Process restart marker if needed; may have to suspend */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! process_restart(cinfo))
+ return FALSE;
+ }
+
+ /* If we've run out of data, just leave the MCU set to zeroes.
+ * This way, we return uniform gray for the remainder of the segment.
+ */
+ if (! entropy->pub.insufficient_data) {
+
+ /* Load up working state.
+ * We can avoid loading/saving bitread state if in an EOB run.
+ */
+ EOBRUN = entropy->saved.EOBRUN; /* only part of saved state we need */
+
+ /* There is always only one block per MCU */
+
+ if (EOBRUN > 0) /* if it's a band of zeroes... */
+ EOBRUN--; /* ...process it now (we do nothing) */
+ else {
+ BITREAD_LOAD_STATE(cinfo,entropy->bitstate);
+ block = MCU_data[0];
+ tbl = entropy->ac_derived_tbl;
+
+ for (k = cinfo->Ss; k <= Se; k++) {
+ HUFF_DECODE(s, br_state, tbl, return FALSE, label2);
+ r = s >> 4;
+ s &= 15;
+ if (s) {
+ k += r;
+ CHECK_BIT_BUFFER(br_state, s, return FALSE);
+ r = GET_BITS(s);
+ s = HUFF_EXTEND(r, s);
+ /* Scale and output coefficient in natural (dezigzagged) order */
+ (*block)[jpeg_natural_order[k]] = (JCOEF) (s << Al);
+ } else {
+ if (r == 15) { /* ZRL */
+ k += 15; /* skip 15 zeroes in band */
+ } else { /* EOBr, run length is 2^r + appended bits */
+ EOBRUN = 1 << r;
+ if (r) { /* EOBr, r > 0 */
+ CHECK_BIT_BUFFER(br_state, r, return FALSE);
+ r = GET_BITS(r);
+ EOBRUN += r;
+ }
+ EOBRUN--; /* this band is processed at this moment */
+ break; /* force end-of-band */
+ }
+ }
+ }
+
+ BITREAD_SAVE_STATE(cinfo,entropy->bitstate);
+ }
+
+ /* Completed MCU, so update state */
+ entropy->saved.EOBRUN = EOBRUN; /* only part of saved state we need */
+ }
+
+ /* Account for restart interval (no-op if not using restarts) */
+ entropy->restarts_to_go--;
+
+ return TRUE;
+}
+
+
+/*
+ * MCU decoding for DC successive approximation refinement scan.
+ * Note: we assume such scans can be multi-component, although the spec
+ * is not very clear on the point.
+ */
+
+METHODDEF(int)
+decode_mcu_DC_refine (j_decompress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int p1 = 1 << cinfo->Al; /* 1 in the bit position being coded */
+ int blkn;
+ JBLOCKROW block;
+ BITREAD_STATE_VARS;
+
+ /* Process restart marker if needed; may have to suspend */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! process_restart(cinfo))
+ return FALSE;
+ }
+
+ /* Not worth the cycles to check insufficient_data here,
+ * since we will not change the data anyway if we read zeroes.
+ */
+
+ /* Load up working state */
+ BITREAD_LOAD_STATE(cinfo,entropy->bitstate);
+
+ /* Outer loop handles each block in the MCU */
+
+ for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) {
+ block = MCU_data[blkn];
+
+ /* Encoded data is simply the next bit of the two's-complement DC value */
+ CHECK_BIT_BUFFER(br_state, 1, return FALSE);
+ if (GET_BITS(1))
+ (*block)[0] |= p1;
+ /* Note: since we use |=, repeating the assignment later is safe */
+ }
+
+ /* Completed MCU, so update state */
+ BITREAD_SAVE_STATE(cinfo,entropy->bitstate);
+
+ /* Account for restart interval (no-op if not using restarts) */
+ entropy->restarts_to_go--;
+
+ return TRUE;
+}
+
+
+/*
+ * MCU decoding for AC successive approximation refinement scan.
+ */
+
+METHODDEF(int)
+decode_mcu_AC_refine (j_decompress_ptr cinfo, JBLOCKROW *MCU_data)
+{
+ phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy;
+ int Se = cinfo->Se;
+ int p1 = 1 << cinfo->Al; /* 1 in the bit position being coded */
+ int m1 = (-1) << cinfo->Al; /* -1 in the bit position being coded */
+ int s, k, r;
+ unsigned int EOBRUN;
+ JBLOCKROW block;
+ JCOEFPTR thiscoef;
+ BITREAD_STATE_VARS;
+ d_derived_tbl * tbl;
+ int num_newnz;
+ int newnz_pos[DCTSIZE2];
+
+ /* Process restart marker if needed; may have to suspend */
+ if (cinfo->restart_interval) {
+ if (entropy->restarts_to_go == 0)
+ if (! process_restart(cinfo))
+ return FALSE;
+ }
+
+ /* If we've run out of data, don't modify the MCU.
+ */
+ if (! entropy->pub.insufficient_data) {
+
+ /* Load up working state */
+ BITREAD_LOAD_STATE(cinfo,entropy->bitstate);
+ EOBRUN = entropy->saved.EOBRUN; /* only part of saved state we need */
+
+ /* There is always only one block per MCU */
+ block = MCU_data[0];
+ tbl = entropy->ac_derived_tbl;
+
+ /* If we are forced to suspend, we must undo the assignments to any newly
+ * nonzero coefficients in the block, because otherwise we'd get confused
+ * next time about which coefficients were already nonzero.
+ * But we need not undo addition of bits to already-nonzero coefficients;
+ * instead, we can test the current bit to see if we already did it.
+ */
+ num_newnz = 0;
+
+ /* initialize coefficient loop counter to start of band */
+ k = cinfo->Ss;
+
+ if (EOBRUN == 0) {
+ for (; k <= Se; k++) {
+ HUFF_DECODE(s, br_state, tbl, goto undoit, label3);
+ r = s >> 4;
+ s &= 15;
+ if (s) {
+ if (s != 1) /* size of new coef should always be 1 */
+ WARNMS(cinfo, JWRN_HUFF_BAD_CODE);
+ CHECK_BIT_BUFFER(br_state, 1, goto undoit);
+ if (GET_BITS(1))
+ s = p1; /* newly nonzero coef is positive */
+ else
+ s = m1; /* newly nonzero coef is negative */
+ } else {
+ if (r != 15) {
+ EOBRUN = 1 << r; /* EOBr, run length is 2^r + appended bits */
+ if (r) {
+ CHECK_BIT_BUFFER(br_state, r, goto undoit);
+ r = GET_BITS(r);
+ EOBRUN += r;
+ }
+ break; /* rest of block is handled by EOB logic */
+ }
+ /* note s = 0 for processing ZRL */
+ }
+ /* Advance over already-nonzero coefs and r still-zero coefs,
+ * appending correction bits to the nonzeroes. A correction bit is 1
+ * if the absolute value of the coefficient must be increased.
+ */
+ do {
+ thiscoef = *block + jpeg_natural_order[k];
+ if (*thiscoef != 0) {
+ CHECK_BIT_BUFFER(br_state, 1, goto undoit);
+ if (GET_BITS(1)) {
+ if ((*thiscoef & p1) == 0) { /* do nothing if already set it */
+ if (*thiscoef >= 0)
+ *thiscoef += p1;
+ else
+ *thiscoef += m1;
+ }
+ }
+ } else {
+ if (--r < 0)
+ break; /* reached target zero coefficient */
+ }
+ k++;
+ } while (k <= Se);
+ if (s) {
+ int pos = jpeg_natural_order[k];
+ /* Output newly nonzero coefficient */
+ (*block)[pos] = (JCOEF) s;
+ /* Remember its position in case we have to suspend */
+ newnz_pos[num_newnz++] = pos;
+ }
+ }
+ }
+
+ if (EOBRUN > 0) {
+ /* Scan any remaining coefficient positions after the end-of-band
+ * (the last newly nonzero coefficient, if any). Append a correction
+ * bit to each already-nonzero coefficient. A correction bit is 1
+ * if the absolute value of the coefficient must be increased.
+ */
+ for (; k <= Se; k++) {
+ thiscoef = *block + jpeg_natural_order[k];
+ if (*thiscoef != 0) {
+ CHECK_BIT_BUFFER(br_state, 1, goto undoit);
+ if (GET_BITS(1)) {
+ if ((*thiscoef & p1) == 0) { /* do nothing if already changed it */
+ if (*thiscoef >= 0)
+ *thiscoef += p1;
+ else
+ *thiscoef += m1;
+ }
+ }
+ }
+ }
+ /* Count one block completed in EOB run */
+ EOBRUN--;
+ }
+
+ /* Completed MCU, so update state */
+ BITREAD_SAVE_STATE(cinfo,entropy->bitstate);
+ entropy->saved.EOBRUN = EOBRUN; /* only part of saved state we need */
+ }
+
+ /* Account for restart interval (no-op if not using restarts) */
+ entropy->restarts_to_go--;
+
+ return TRUE;
+
+undoit:
+ /* Re-zero any output coefficients that we made newly nonzero */
+ while (num_newnz > 0)
+ (*block)[newnz_pos[--num_newnz]] = 0;
+
+ return FALSE;
+}
+
+
+/*
+ * Module initialization routine for progressive Huffman entropy decoding.
+ */
+
+GLOBAL(void)
+jinit_phuff_decoder (j_decompress_ptr cinfo)
+{
+ phuff_entropy_ptr entropy;
+ int *coef_bit_ptr;
+ int ci, i;
+
+ entropy = (phuff_entropy_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(phuff_entropy_decoder));
+ cinfo->entropy = (struct jpeg_entropy_decoder *) entropy;
+ entropy->pub.start_pass = start_pass_phuff_decoder;
+
+ /* Mark derived tables unallocated */
+ for (i = 0; i < NUM_HUFF_TBLS; i++) {
+ entropy->derived_tbls[i] = NULL;
+ }
+
+ /* Create progression status table */
+ cinfo->coef_bits = (int (*)[DCTSIZE2])
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ cinfo->num_components*DCTSIZE2*SIZEOF(int));
+ coef_bit_ptr = & cinfo->coef_bits[0][0];
+ for (ci = 0; ci < cinfo->num_components; ci++)
+ for (i = 0; i < DCTSIZE2; i++)
+ *coef_bit_ptr++ = -1;
+}
+
+#endif /* D_PROGRESSIVE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jdpostct.cpp b/ml/dlib/dlib/external/libjpeg/jdpostct.cpp
new file mode 100644
index 000000000..63e10ec17
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdpostct.cpp
@@ -0,0 +1,290 @@
+/*
+ * jdpostct.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the decompression postprocessing controller.
+ * This controller manages the upsampling, color conversion, and color
+ * quantization/reduction steps; specifically, it controls the buffering
+ * between upsample/color conversion and color quantization/reduction.
+ *
+ * If no color quantization/reduction is required, then this module has no
+ * work to do, and it just hands off to the upsample/color conversion code.
+ * An integrated upsample/convert/quantize process would replace this module
+ * entirely.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Private buffer controller object */
+
+typedef struct {
+ struct jpeg_d_post_controller pub; /* public fields */
+
+ /* Color quantization source buffer: this holds output data from
+ * the upsample/color conversion step to be passed to the quantizer.
+ * For two-pass color quantization, we need a full-image buffer;
+ * for one-pass operation, a strip buffer is sufficient.
+ */
+ jvirt_sarray_ptr whole_image; /* virtual array, or NULL if one-pass */
+ JSAMPARRAY buffer; /* strip buffer, or current strip of virtual */
+ JDIMENSION strip_height; /* buffer size in rows */
+ /* for two-pass mode only: */
+ JDIMENSION starting_row; /* row # of first row in current strip */
+ JDIMENSION next_row; /* index of next row to fill/empty in strip */
+} my_post_controller;
+
+typedef my_post_controller * my_post_ptr;
+
+
+/* Forward declarations */
+METHODDEF(void) post_process_1pass
+ JPP((j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+#ifdef QUANT_2PASS_SUPPORTED
+METHODDEF(void) post_process_prepass
+ JPP((j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+METHODDEF(void) post_process_2pass
+ JPP((j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+#endif
+
+
+/*
+ * Initialize for a processing pass.
+ */
+
+METHODDEF(void)
+start_pass_dpost (j_decompress_ptr cinfo, J_BUF_MODE pass_mode)
+{
+ my_post_ptr post = (my_post_ptr) cinfo->post;
+
+ switch (pass_mode) {
+ case JBUF_PASS_THRU:
+ if (cinfo->quantize_colors) {
+ /* Single-pass processing with color quantization. */
+ post->pub.post_process_data = post_process_1pass;
+ /* We could be doing buffered-image output before starting a 2-pass
+ * color quantization; in that case, jinit_d_post_controller did not
+ * allocate a strip buffer. Use the virtual-array buffer as workspace.
+ */
+ if (post->buffer == NULL) {
+ post->buffer = (*cinfo->mem->access_virt_sarray)
+ ((j_common_ptr) cinfo, post->whole_image,
+ (JDIMENSION) 0, post->strip_height, TRUE);
+ }
+ } else {
+ /* For single-pass processing without color quantization,
+ * I have no work to do; just call the upsampler directly.
+ */
+ post->pub.post_process_data = cinfo->upsample->upsample;
+ }
+ break;
+#ifdef QUANT_2PASS_SUPPORTED
+ case JBUF_SAVE_AND_PASS:
+ /* First pass of 2-pass quantization */
+ if (post->whole_image == NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ post->pub.post_process_data = post_process_prepass;
+ break;
+ case JBUF_CRANK_DEST:
+ /* Second pass of 2-pass quantization */
+ if (post->whole_image == NULL)
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ post->pub.post_process_data = post_process_2pass;
+ break;
+#endif /* QUANT_2PASS_SUPPORTED */
+ default:
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+ break;
+ }
+ post->starting_row = post->next_row = 0;
+}
+
+
+/*
+ * Process some data in the one-pass (strip buffer) case.
+ * This is used for color precision reduction as well as one-pass quantization.
+ */
+
+METHODDEF(void)
+post_process_1pass (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ my_post_ptr post = (my_post_ptr) cinfo->post;
+ JDIMENSION num_rows, max_rows;
+
+ /* Fill the buffer, but not more than what we can dump out in one go. */
+ /* Note we rely on the upsampler to detect bottom of image. */
+ max_rows = out_rows_avail - *out_row_ctr;
+ if (max_rows > post->strip_height)
+ max_rows = post->strip_height;
+ num_rows = 0;
+ (*cinfo->upsample->upsample) (cinfo,
+ input_buf, in_row_group_ctr, in_row_groups_avail,
+ post->buffer, &num_rows, max_rows);
+ /* Quantize and emit data. */
+ (*cinfo->cquantize->color_quantize) (cinfo,
+ post->buffer, output_buf + *out_row_ctr, (int) num_rows);
+ *out_row_ctr += num_rows;
+}
+
+
+#ifdef QUANT_2PASS_SUPPORTED
+
+/*
+ * Process some data in the first pass of 2-pass quantization.
+ */
+
+METHODDEF(void)
+post_process_prepass (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY , JDIMENSION *out_row_ctr,
+ JDIMENSION )
+{
+ my_post_ptr post = (my_post_ptr) cinfo->post;
+ JDIMENSION old_next_row, num_rows;
+
+ /* Reposition virtual buffer if at start of strip. */
+ if (post->next_row == 0) {
+ post->buffer = (*cinfo->mem->access_virt_sarray)
+ ((j_common_ptr) cinfo, post->whole_image,
+ post->starting_row, post->strip_height, TRUE);
+ }
+
+ /* Upsample some data (up to a strip height's worth). */
+ old_next_row = post->next_row;
+ (*cinfo->upsample->upsample) (cinfo,
+ input_buf, in_row_group_ctr, in_row_groups_avail,
+ post->buffer, &post->next_row, post->strip_height);
+
+ /* Allow quantizer to scan new data. No data is emitted, */
+ /* but we advance out_row_ctr so outer loop can tell when we're done. */
+ if (post->next_row > old_next_row) {
+ num_rows = post->next_row - old_next_row;
+ (*cinfo->cquantize->color_quantize) (cinfo, post->buffer + old_next_row,
+ (JSAMPARRAY) NULL, (int) num_rows);
+ *out_row_ctr += num_rows;
+ }
+
+ /* Advance if we filled the strip. */
+ if (post->next_row >= post->strip_height) {
+ post->starting_row += post->strip_height;
+ post->next_row = 0;
+ }
+}
+
+
+/*
+ * Process some data in the second pass of 2-pass quantization.
+ */
+
+METHODDEF(void)
+post_process_2pass (j_decompress_ptr cinfo,
+ JSAMPIMAGE , JDIMENSION *,
+ JDIMENSION ,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ my_post_ptr post = (my_post_ptr) cinfo->post;
+ JDIMENSION num_rows, max_rows;
+
+ /* Reposition virtual buffer if at start of strip. */
+ if (post->next_row == 0) {
+ post->buffer = (*cinfo->mem->access_virt_sarray)
+ ((j_common_ptr) cinfo, post->whole_image,
+ post->starting_row, post->strip_height, FALSE);
+ }
+
+ /* Determine number of rows to emit. */
+ num_rows = post->strip_height - post->next_row; /* available in strip */
+ max_rows = out_rows_avail - *out_row_ctr; /* available in output area */
+ if (num_rows > max_rows)
+ num_rows = max_rows;
+ /* We have to check bottom of image here, can't depend on upsampler. */
+ max_rows = cinfo->output_height - post->starting_row;
+ if (num_rows > max_rows)
+ num_rows = max_rows;
+
+ /* Quantize and emit data. */
+ (*cinfo->cquantize->color_quantize) (cinfo,
+ post->buffer + post->next_row, output_buf + *out_row_ctr,
+ (int) num_rows);
+ *out_row_ctr += num_rows;
+
+ /* Advance if we filled the strip. */
+ post->next_row += num_rows;
+ if (post->next_row >= post->strip_height) {
+ post->starting_row += post->strip_height;
+ post->next_row = 0;
+ }
+}
+
+#endif /* QUANT_2PASS_SUPPORTED */
+
+
+/*
+ * Initialize postprocessing controller.
+ */
+
+GLOBAL(void)
+jinit_d_post_controller (j_decompress_ptr cinfo, int need_full_buffer)
+{
+ my_post_ptr post;
+
+ post = (my_post_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_post_controller));
+ cinfo->post = (struct jpeg_d_post_controller *) post;
+ post->pub.start_pass = start_pass_dpost;
+ post->whole_image = NULL; /* flag for no virtual arrays */
+ post->buffer = NULL; /* flag for no strip buffer */
+
+ /* Create the quantization buffer, if needed */
+ if (cinfo->quantize_colors) {
+ /* The buffer strip height is max_v_samp_factor, which is typically
+ * an efficient number of rows for upsampling to return.
+ * (In the presence of output rescaling, we might want to be smarter?)
+ */
+ post->strip_height = (JDIMENSION) cinfo->max_v_samp_factor;
+ if (need_full_buffer) {
+ /* Two-pass color quantization: need full-image storage. */
+ /* We round up the number of rows to a multiple of the strip height. */
+#ifdef QUANT_2PASS_SUPPORTED
+ post->whole_image = (*cinfo->mem->request_virt_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE,
+ cinfo->output_width * cinfo->out_color_components,
+ (JDIMENSION) jround_up((long) cinfo->output_height,
+ (long) post->strip_height),
+ post->strip_height);
+#else
+ ERREXIT(cinfo, JERR_BAD_BUFFER_MODE);
+#endif /* QUANT_2PASS_SUPPORTED */
+ } else {
+ /* One-pass color quantization: just make a strip buffer. */
+ post->buffer = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ cinfo->output_width * cinfo->out_color_components,
+ post->strip_height);
+ }
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jdsample.cpp b/ml/dlib/dlib/external/libjpeg/jdsample.cpp
new file mode 100644
index 000000000..647c8bf45
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jdsample.cpp
@@ -0,0 +1,478 @@
+/*
+ * jdsample.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains upsampling routines.
+ *
+ * Upsampling input data is counted in "row groups". A row group
+ * is defined to be (v_samp_factor * DCT_scaled_size / min_DCT_scaled_size)
+ * sample rows of each component. Upsampling will normally produce
+ * max_v_samp_factor pixel rows from each row group (but this could vary
+ * if the upsampler is applying a scale factor of its own).
+ *
+ * An excellent reference for image resampling is
+ * Digital Image Warping, George Wolberg, 1990.
+ * Pub. by IEEE Computer Society Press, Los Alamitos, CA. ISBN 0-8186-8944-7.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/* Pointer to routine to upsample a single component */
+typedef JMETHOD(void, upsample1_ptr,
+ (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr));
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_upsampler pub; /* public fields */
+
+ /* Color conversion buffer. When using separate upsampling and color
+ * conversion steps, this buffer holds one upsampled row group until it
+ * has been color converted and output.
+ * Note: we do not allocate any storage for component(s) which are full-size,
+ * ie do not need rescaling. The corresponding entry of color_buf[] is
+ * simply set to point to the input data array, thereby avoiding copying.
+ */
+ JSAMPARRAY color_buf[MAX_COMPONENTS];
+
+ /* Per-component upsampling method pointers */
+ upsample1_ptr methods[MAX_COMPONENTS];
+
+ int next_row_out; /* counts rows emitted from color_buf */
+ JDIMENSION rows_to_go; /* counts rows remaining in image */
+
+ /* Height of an input row group for each component. */
+ int rowgroup_height[MAX_COMPONENTS];
+
+ /* These arrays save pixel expansion factors so that int_expand need not
+ * recompute them each time. They are unused for other upsampling methods.
+ */
+ unsigned char h_expand[MAX_COMPONENTS];
+ unsigned char v_expand[MAX_COMPONENTS];
+} my_upsampler;
+
+typedef my_upsampler * my_upsample_ptr;
+
+
+/*
+ * Initialize for an upsampling pass.
+ */
+
+METHODDEF(void)
+start_pass_upsample (j_decompress_ptr cinfo)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+
+ /* Mark the conversion buffer empty */
+ upsample->next_row_out = cinfo->max_v_samp_factor;
+ /* Initialize total-height counter for detecting bottom of image */
+ upsample->rows_to_go = cinfo->output_height;
+}
+
+
+/*
+ * Control routine to do upsampling (and color conversion).
+ *
+ * In this version we upsample each component independently.
+ * We upsample one row group into the conversion buffer, then apply
+ * color conversion a row at a time.
+ */
+
+METHODDEF(void)
+sep_upsample (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr,
+ JDIMENSION ,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ int ci;
+ jpeg_component_info * compptr;
+ JDIMENSION num_rows;
+
+ /* Fill the conversion buffer, if it's empty */
+ if (upsample->next_row_out >= cinfo->max_v_samp_factor) {
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Invoke per-component upsample method. Notice we pass a POINTER
+ * to color_buf[ci], so that fullsize_upsample can change it.
+ */
+ (*upsample->methods[ci]) (cinfo, compptr,
+ input_buf[ci] + (*in_row_group_ctr * upsample->rowgroup_height[ci]),
+ upsample->color_buf + ci);
+ }
+ upsample->next_row_out = 0;
+ }
+
+ /* Color-convert and emit rows */
+
+ /* How many we have in the buffer: */
+ num_rows = (JDIMENSION) (cinfo->max_v_samp_factor - upsample->next_row_out);
+ /* Not more than the distance to the end of the image. Need this test
+ * in case the image height is not a multiple of max_v_samp_factor:
+ */
+ if (num_rows > upsample->rows_to_go)
+ num_rows = upsample->rows_to_go;
+ /* And not more than what the client can accept: */
+ out_rows_avail -= *out_row_ctr;
+ if (num_rows > out_rows_avail)
+ num_rows = out_rows_avail;
+
+ (*cinfo->cconvert->color_convert) (cinfo, upsample->color_buf,
+ (JDIMENSION) upsample->next_row_out,
+ output_buf + *out_row_ctr,
+ (int) num_rows);
+
+ /* Adjust counts */
+ *out_row_ctr += num_rows;
+ upsample->rows_to_go -= num_rows;
+ upsample->next_row_out += num_rows;
+ /* When the buffer is emptied, declare this input row group consumed */
+ if (upsample->next_row_out >= cinfo->max_v_samp_factor)
+ (*in_row_group_ctr)++;
+}
+
+
+/*
+ * These are the routines invoked by sep_upsample to upsample pixel values
+ * of a single component. One row group is processed per call.
+ */
+
+
+/*
+ * For full-size components, we just make color_buf[ci] point at the
+ * input buffer, and thus avoid copying any data. Note that this is
+ * safe only because sep_upsample doesn't declare the input row group
+ * "consumed" until we are done color converting and emitting it.
+ */
+
+METHODDEF(void)
+fullsize_upsample (j_decompress_ptr , jpeg_component_info * ,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ *output_data_ptr = input_data;
+}
+
+
+/*
+ * This is a no-op version used for "uninteresting" components.
+ * These components will not be referenced by color conversion.
+ */
+
+METHODDEF(void)
+noop_upsample (j_decompress_ptr , jpeg_component_info * ,
+ JSAMPARRAY , JSAMPARRAY * output_data_ptr)
+{
+ *output_data_ptr = NULL; /* safety check */
+}
+
+
+/*
+ * This version handles any integral sampling ratios.
+ * This is not used for typical JPEG files, so it need not be fast.
+ * Nor, for that matter, is it particularly accurate: the algorithm is
+ * simple replication of the input pixel onto the corresponding output
+ * pixels. The hi-falutin sampling literature refers to this as a
+ * "box filter". A box filter tends to introduce visible artifacts,
+ * so if you are actually going to use 3:1 or 4:1 sampling ratios
+ * you would be well advised to improve this code.
+ */
+
+METHODDEF(void)
+int_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample;
+ JSAMPARRAY output_data = *output_data_ptr;
+ JSAMPROW inptr, outptr;
+ JSAMPLE invalue;
+ int h;
+ JSAMPROW outend;
+ int h_expand, v_expand;
+ int inrow, outrow;
+
+ h_expand = upsample->h_expand[compptr->component_index];
+ v_expand = upsample->v_expand[compptr->component_index];
+
+ inrow = outrow = 0;
+ while (outrow < cinfo->max_v_samp_factor) {
+ /* Generate one output row with proper horizontal expansion */
+ inptr = input_data[inrow];
+ outptr = output_data[outrow];
+ outend = outptr + cinfo->output_width;
+ while (outptr < outend) {
+ invalue = *inptr++; /* don't need GETJSAMPLE() here */
+ for (h = h_expand; h > 0; h--) {
+ *outptr++ = invalue;
+ }
+ }
+ /* Generate any additional output rows by duplicating the first one */
+ if (v_expand > 1) {
+ jcopy_sample_rows(output_data, outrow, output_data, outrow+1,
+ v_expand-1, cinfo->output_width);
+ }
+ inrow++;
+ outrow += v_expand;
+ }
+}
+
+
+/*
+ * Fast processing for the common case of 2:1 horizontal and 1:1 vertical.
+ * It's still a box filter.
+ */
+
+METHODDEF(void)
+h2v1_upsample (j_decompress_ptr cinfo, jpeg_component_info * ,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ JSAMPARRAY output_data = *output_data_ptr;
+ JSAMPROW inptr, outptr;
+ JSAMPLE invalue;
+ JSAMPROW outend;
+ int inrow;
+
+ for (inrow = 0; inrow < cinfo->max_v_samp_factor; inrow++) {
+ inptr = input_data[inrow];
+ outptr = output_data[inrow];
+ outend = outptr + cinfo->output_width;
+ while (outptr < outend) {
+ invalue = *inptr++; /* don't need GETJSAMPLE() here */
+ *outptr++ = invalue;
+ *outptr++ = invalue;
+ }
+ }
+}
+
+
+/*
+ * Fast processing for the common case of 2:1 horizontal and 2:1 vertical.
+ * It's still a box filter.
+ */
+
+METHODDEF(void)
+h2v2_upsample (j_decompress_ptr cinfo, jpeg_component_info * ,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ JSAMPARRAY output_data = *output_data_ptr;
+ JSAMPROW inptr, outptr;
+ JSAMPLE invalue;
+ JSAMPROW outend;
+ int inrow, outrow;
+
+ inrow = outrow = 0;
+ while (outrow < cinfo->max_v_samp_factor) {
+ inptr = input_data[inrow];
+ outptr = output_data[outrow];
+ outend = outptr + cinfo->output_width;
+ while (outptr < outend) {
+ invalue = *inptr++; /* don't need GETJSAMPLE() here */
+ *outptr++ = invalue;
+ *outptr++ = invalue;
+ }
+ jcopy_sample_rows(output_data, outrow, output_data, outrow+1,
+ 1, cinfo->output_width);
+ inrow++;
+ outrow += 2;
+ }
+}
+
+
+/*
+ * Fancy processing for the common case of 2:1 horizontal and 1:1 vertical.
+ *
+ * The upsampling algorithm is linear interpolation between pixel centers,
+ * also known as a "triangle filter". This is a good compromise between
+ * speed and visual quality. The centers of the output pixels are 1/4 and 3/4
+ * of the way between input pixel centers.
+ *
+ * A note about the "bias" calculations: when rounding fractional values to
+ * integer, we do not want to always round 0.5 up to the next integer.
+ * If we did that, we'd introduce a noticeable bias towards larger values.
+ * Instead, this code is arranged so that 0.5 will be rounded up or down at
+ * alternate pixel locations (a simple ordered dither pattern).
+ */
+
+METHODDEF(void)
+h2v1_fancy_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ JSAMPARRAY output_data = *output_data_ptr;
+ JSAMPROW inptr, outptr;
+ int invalue;
+ JDIMENSION colctr;
+ int inrow;
+
+ for (inrow = 0; inrow < cinfo->max_v_samp_factor; inrow++) {
+ inptr = input_data[inrow];
+ outptr = output_data[inrow];
+ /* Special case for first column */
+ invalue = GETJSAMPLE(*inptr++);
+ *outptr++ = (JSAMPLE) invalue;
+ *outptr++ = (JSAMPLE) ((invalue * 3 + GETJSAMPLE(*inptr) + 2) >> 2);
+
+ for (colctr = compptr->downsampled_width - 2; colctr > 0; colctr--) {
+ /* General case: 3/4 * nearer pixel + 1/4 * further pixel */
+ invalue = GETJSAMPLE(*inptr++) * 3;
+ *outptr++ = (JSAMPLE) ((invalue + GETJSAMPLE(inptr[-2]) + 1) >> 2);
+ *outptr++ = (JSAMPLE) ((invalue + GETJSAMPLE(*inptr) + 2) >> 2);
+ }
+
+ /* Special case for last column */
+ invalue = GETJSAMPLE(*inptr);
+ *outptr++ = (JSAMPLE) ((invalue * 3 + GETJSAMPLE(inptr[-1]) + 1) >> 2);
+ *outptr++ = (JSAMPLE) invalue;
+ }
+}
+
+
+/*
+ * Fancy processing for the common case of 2:1 horizontal and 2:1 vertical.
+ * Again a triangle filter; see comments for h2v1 case, above.
+ *
+ * It is OK for us to reference the adjacent input rows because we demanded
+ * context from the main buffer controller (see initialization code).
+ */
+
+METHODDEF(void)
+h2v2_fancy_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)
+{
+ JSAMPARRAY output_data = *output_data_ptr;
+ JSAMPROW inptr0, inptr1, outptr;
+#if BITS_IN_JSAMPLE == 8
+ int thiscolsum, lastcolsum, nextcolsum;
+#else
+ long thiscolsum, lastcolsum, nextcolsum;
+#endif
+ JDIMENSION colctr;
+ int inrow, outrow, v;
+
+ inrow = outrow = 0;
+ while (outrow < cinfo->max_v_samp_factor) {
+ for (v = 0; v < 2; v++) {
+ /* inptr0 points to nearest input row, inptr1 points to next nearest */
+ inptr0 = input_data[inrow];
+ if (v == 0) /* next nearest is row above */
+ inptr1 = input_data[inrow-1];
+ else /* next nearest is row below */
+ inptr1 = input_data[inrow+1];
+ outptr = output_data[outrow++];
+
+ /* Special case for first column */
+ thiscolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++);
+ nextcolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++);
+ *outptr++ = (JSAMPLE) ((thiscolsum * 4 + 8) >> 4);
+ *outptr++ = (JSAMPLE) ((thiscolsum * 3 + nextcolsum + 7) >> 4);
+ lastcolsum = thiscolsum; thiscolsum = nextcolsum;
+
+ for (colctr = compptr->downsampled_width - 2; colctr > 0; colctr--) {
+ /* General case: 3/4 * nearer pixel + 1/4 * further pixel in each */
+ /* dimension, thus 9/16, 3/16, 3/16, 1/16 overall */
+ nextcolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++);
+ *outptr++ = (JSAMPLE) ((thiscolsum * 3 + lastcolsum + 8) >> 4);
+ *outptr++ = (JSAMPLE) ((thiscolsum * 3 + nextcolsum + 7) >> 4);
+ lastcolsum = thiscolsum; thiscolsum = nextcolsum;
+ }
+
+ /* Special case for last column */
+ *outptr++ = (JSAMPLE) ((thiscolsum * 3 + lastcolsum + 8) >> 4);
+ *outptr++ = (JSAMPLE) ((thiscolsum * 4 + 7) >> 4);
+ }
+ inrow++;
+ }
+}
+
+
+/*
+ * Module initialization routine for upsampling.
+ */
+
+GLOBAL(void)
+jinit_upsampler (j_decompress_ptr cinfo)
+{
+ my_upsample_ptr upsample;
+ int ci;
+ jpeg_component_info * compptr;
+ int need_buffer, do_fancy;
+ int h_in_group, v_in_group, h_out_group, v_out_group;
+
+ upsample = (my_upsample_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_upsampler));
+ cinfo->upsample = (struct jpeg_upsampler *) upsample;
+ upsample->pub.start_pass = start_pass_upsample;
+ upsample->pub.upsample = sep_upsample;
+ upsample->pub.need_context_rows = FALSE; /* until we find out differently */
+
+ if (cinfo->CCIR601_sampling) /* this isn't supported */
+ ERREXIT(cinfo, JERR_CCIR601_NOTIMPL);
+
+ /* jdmainct.c doesn't support context rows when min_DCT_scaled_size = 1,
+ * so don't ask for it.
+ */
+ do_fancy = cinfo->do_fancy_upsampling && cinfo->min_DCT_scaled_size > 1;
+
+ /* Verify we can handle the sampling factors, select per-component methods,
+ * and create storage as needed.
+ */
+ for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components;
+ ci++, compptr++) {
+ /* Compute size of an "input group" after IDCT scaling. This many samples
+ * are to be converted to max_h_samp_factor * max_v_samp_factor pixels.
+ */
+ h_in_group = (compptr->h_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size;
+ v_in_group = (compptr->v_samp_factor * compptr->DCT_scaled_size) /
+ cinfo->min_DCT_scaled_size;
+ h_out_group = cinfo->max_h_samp_factor;
+ v_out_group = cinfo->max_v_samp_factor;
+ upsample->rowgroup_height[ci] = v_in_group; /* save for use later */
+ need_buffer = TRUE;
+ if (! compptr->component_needed) {
+ /* Don't bother to upsample an uninteresting component. */
+ upsample->methods[ci] = noop_upsample;
+ need_buffer = FALSE;
+ } else if (h_in_group == h_out_group && v_in_group == v_out_group) {
+ /* Fullsize components can be processed without any work. */
+ upsample->methods[ci] = fullsize_upsample;
+ need_buffer = FALSE;
+ } else if (h_in_group * 2 == h_out_group &&
+ v_in_group == v_out_group) {
+ /* Special cases for 2h1v upsampling */
+ if (do_fancy && compptr->downsampled_width > 2)
+ upsample->methods[ci] = h2v1_fancy_upsample;
+ else
+ upsample->methods[ci] = h2v1_upsample;
+ } else if (h_in_group * 2 == h_out_group &&
+ v_in_group * 2 == v_out_group) {
+ /* Special cases for 2h2v upsampling */
+ if (do_fancy && compptr->downsampled_width > 2) {
+ upsample->methods[ci] = h2v2_fancy_upsample;
+ upsample->pub.need_context_rows = TRUE;
+ } else
+ upsample->methods[ci] = h2v2_upsample;
+ } else if ((h_out_group % h_in_group) == 0 &&
+ (v_out_group % v_in_group) == 0) {
+ /* Generic integral-factors upsampling method */
+ upsample->methods[ci] = int_upsample;
+ upsample->h_expand[ci] = (unsigned char) (h_out_group / h_in_group);
+ upsample->v_expand[ci] = (unsigned char) (v_out_group / v_in_group);
+ } else
+ ERREXIT(cinfo, JERR_FRACT_SAMPLE_NOTIMPL);
+ if (need_buffer) {
+ upsample->color_buf[ci] = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (JDIMENSION) jround_up((long) cinfo->output_width,
+ (long) cinfo->max_h_samp_factor),
+ (JDIMENSION) cinfo->max_v_samp_factor);
+ }
+ }
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jerror.cpp b/ml/dlib/dlib/external/libjpeg/jerror.cpp
new file mode 100644
index 000000000..117bc4829
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jerror.cpp
@@ -0,0 +1,252 @@
+/*
+ * jerror.c
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains simple error-reporting and trace-message routines.
+ * These are suitable for Unix-like systems and others where writing to
+ * stderr is the right thing to do. Many applications will want to replace
+ * some or all of these routines.
+ *
+ * If you define USE_WINDOWS_MESSAGEBOX in jconfig.h or in the makefile,
+ * you get a Windows-specific hack to display error messages in a dialog box.
+ * It ain't much, but it beats dropping error messages into the bit bucket,
+ * which is what happens to output to stderr under most Windows C compilers.
+ *
+ * These routines are used by both the compression and decompression code.
+ */
+
+/* this is not a core library module, so it doesn't define JPEG_INTERNALS */
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jversion.h"
+#include "jerror.h"
+
+#ifdef USE_WINDOWS_MESSAGEBOX
+#include <windows.h>
+#endif
+
+#ifndef EXIT_FAILURE /* define exit() codes if not provided */
+#define EXIT_FAILURE 1
+#endif
+
+
+/*
+ * Create the message string table.
+ * We do this from the master message list in jerror.h by re-reading
+ * jerror.h with a suitable definition for macro JMESSAGE.
+ * The message table is made an external symbol just in case any applications
+ * want to refer to it directly.
+ */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_std_message_table jMsgTable
+#endif
+
+#define JMESSAGE(code,string) string ,
+
+const char * const jpeg_std_message_table[] = {
+#include "jerror.h"
+ NULL
+};
+
+
+/*
+ * Error exit handler: must not return to caller.
+ *
+ * Applications may override this if they want to get control back after
+ * an error. Typically one would longjmp somewhere instead of exiting.
+ * The setjmp buffer can be made a private field within an expanded error
+ * handler object. Note that the info needed to generate an error message
+ * is stored in the error object, so you can generate the message now or
+ * later, at your convenience.
+ * You should make sure that the JPEG object is cleaned up (with jpeg_abort
+ * or jpeg_destroy) at some point.
+ */
+
+METHODDEF(void)
+error_exit (j_common_ptr cinfo)
+{
+ /* Always display the message */
+ (*cinfo->err->output_message) (cinfo);
+
+ /* Let the memory manager delete any temp files before we die */
+ jpeg_destroy(cinfo);
+
+ exit(EXIT_FAILURE);
+}
+
+
+/*
+ * Actual output of an error or trace message.
+ * Applications may override this method to send JPEG messages somewhere
+ * other than stderr.
+ *
+ * On Windows, printing to stderr is generally completely useless,
+ * so we provide optional code to produce an error-dialog popup.
+ * Most Windows applications will still prefer to override this routine,
+ * but if they don't, it'll do something at least marginally useful.
+ *
+ * NOTE: to use the library in an environment that doesn't support the
+ * C stdio library, you may have to delete the call to fprintf() entirely,
+ * not just not use this routine.
+ */
+
+METHODDEF(void)
+output_message (j_common_ptr cinfo)
+{
+ char buffer[JMSG_LENGTH_MAX];
+
+ /* Create the message */
+ (*cinfo->err->format_message) (cinfo, buffer);
+
+#ifdef USE_WINDOWS_MESSAGEBOX
+ /* Display it in a message dialog box */
+ MessageBox(GetActiveWindow(), buffer, "JPEG Library Error",
+ MB_OK | MB_ICONERROR);
+#else
+ /* Send it to stderr, adding a newline */
+ fprintf(stderr, "%s\n", buffer);
+#endif
+}
+
+
+/*
+ * Decide whether to emit a trace or warning message.
+ * msg_level is one of:
+ * -1: recoverable corrupt-data warning, may want to abort.
+ * 0: important advisory messages (always display to user).
+ * 1: first level of tracing detail.
+ * 2,3,...: successively more detailed tracing messages.
+ * An application might override this method if it wanted to abort on warnings
+ * or change the policy about which messages to display.
+ */
+
+METHODDEF(void)
+emit_message (j_common_ptr cinfo, int msg_level)
+{
+ struct jpeg_error_mgr * err = cinfo->err;
+
+ if (msg_level < 0) {
+ /* It's a warning message. Since corrupt files may generate many warnings,
+ * the policy implemented here is to show only the first warning,
+ * unless trace_level >= 3.
+ */
+ if (err->num_warnings == 0 || err->trace_level >= 3)
+ (*err->output_message) (cinfo);
+ /* Always count warnings in num_warnings. */
+ err->num_warnings++;
+ } else {
+ /* It's a trace message. Show it if trace_level >= msg_level. */
+ if (err->trace_level >= msg_level)
+ (*err->output_message) (cinfo);
+ }
+}
+
+
+/*
+ * Format a message string for the most recent JPEG error or message.
+ * The message is stored into buffer, which should be at least JMSG_LENGTH_MAX
+ * characters. Note that no '\n' character is added to the string.
+ * Few applications should need to override this method.
+ */
+
+METHODDEF(void)
+format_message (j_common_ptr cinfo, char * buffer)
+{
+ struct jpeg_error_mgr * err = cinfo->err;
+ int msg_code = err->msg_code;
+ const char * msgtext = NULL;
+ const char * msgptr;
+ char ch;
+ int isstring;
+
+ /* Look up message string in proper table */
+ if (msg_code > 0 && msg_code <= err->last_jpeg_message) {
+ msgtext = err->jpeg_message_table[msg_code];
+ } else if (err->addon_message_table != NULL &&
+ msg_code >= err->first_addon_message &&
+ msg_code <= err->last_addon_message) {
+ msgtext = err->addon_message_table[msg_code - err->first_addon_message];
+ }
+
+ /* Defend against bogus message number */
+ if (msgtext == NULL) {
+ err->msg_parm.i[0] = msg_code;
+ msgtext = err->jpeg_message_table[0];
+ }
+
+ /* Check for string parameter, as indicated by %s in the message text */
+ isstring = FALSE;
+ msgptr = msgtext;
+ while ((ch = *msgptr++) != '\0') {
+ if (ch == '%') {
+ if (*msgptr == 's') isstring = TRUE;
+ break;
+ }
+ }
+
+ /* Format the message into the passed buffer */
+ if (isstring)
+ sprintf(buffer, msgtext, err->msg_parm.s);
+ else
+ sprintf(buffer, msgtext,
+ err->msg_parm.i[0], err->msg_parm.i[1],
+ err->msg_parm.i[2], err->msg_parm.i[3],
+ err->msg_parm.i[4], err->msg_parm.i[5],
+ err->msg_parm.i[6], err->msg_parm.i[7]);
+}
+
+
+/*
+ * Reset error state variables at start of a new image.
+ * This is called during compression startup to reset trace/error
+ * processing to default state, without losing any application-specific
+ * method pointers. An application might possibly want to override
+ * this method if it has additional error processing state.
+ */
+
+METHODDEF(void)
+reset_error_mgr (j_common_ptr cinfo)
+{
+ cinfo->err->num_warnings = 0;
+ /* trace_level is not reset since it is an application-supplied parameter */
+ cinfo->err->msg_code = 0; /* may be useful as a flag for "no error" */
+}
+
+
+/*
+ * Fill in the standard error-handling methods in a jpeg_error_mgr object.
+ * Typical call is:
+ * struct jpeg_compress_struct cinfo;
+ * struct jpeg_error_mgr err;
+ *
+ * cinfo.err = jpeg_std_error(&err);
+ * after which the application may override some of the methods.
+ */
+
+GLOBAL(struct jpeg_error_mgr *)
+jpeg_std_error (struct jpeg_error_mgr * err)
+{
+ err->error_exit = error_exit;
+ err->emit_message = emit_message;
+ err->output_message = output_message;
+ err->format_message = format_message;
+ err->reset_error_mgr = reset_error_mgr;
+
+ err->trace_level = 0; /* default = no tracing */
+ err->num_warnings = 0; /* no warnings emitted yet */
+ err->msg_code = 0; /* may be useful as a flag for "no error" */
+
+ /* Initialize message table pointers */
+ err->jpeg_message_table = jpeg_std_message_table;
+ err->last_jpeg_message = (int) JMSG_LASTMSGCODE - 1;
+
+ err->addon_message_table = NULL;
+ err->first_addon_message = 0; /* for safety */
+ err->last_addon_message = 0;
+
+ return err;
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jerror.h b/ml/dlib/dlib/external/libjpeg/jerror.h
new file mode 100644
index 000000000..fc2fffeac
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jerror.h
@@ -0,0 +1,291 @@
+/*
+ * jerror.h
+ *
+ * Copyright (C) 1994-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file defines the error and message codes for the JPEG library.
+ * Edit this file to add new codes, or to translate the message strings to
+ * some other language.
+ * A set of error-reporting macros are defined too. Some applications using
+ * the JPEG library may wish to include this file to get the error codes
+ * and/or the macros.
+ */
+
+/*
+ * To define the enum list of message codes, include this file without
+ * defining macro JMESSAGE. To create a message string table, include it
+ * again with a suitable JMESSAGE definition (see jerror.c for an example).
+ */
+#ifndef JMESSAGE
+#ifndef JERROR_H
+/* First time through, define the enum list */
+#define JMAKE_ENUM_LIST
+#else
+/* Repeated inclusions of this file are no-ops unless JMESSAGE is defined */
+#define JMESSAGE(code,string)
+#endif /* JERROR_H */
+#endif /* JMESSAGE */
+
+#ifdef JMAKE_ENUM_LIST
+
+typedef enum {
+
+#define JMESSAGE(code,string) code ,
+
+#endif /* JMAKE_ENUM_LIST */
+
+JMESSAGE(JMSG_NOMESSAGE, "Bogus message code %d") /* Must be first entry! */
+
+/* For maintenance convenience, list is alphabetical by message code name */
+JMESSAGE(JERR_ARITH_NOTIMPL,
+ "Sorry, there are legal restrictions on arithmetic coding")
+JMESSAGE(JERR_BAD_ALIGN_TYPE, "ALIGN_TYPE is wrong, please fix")
+JMESSAGE(JERR_BAD_ALLOC_CHUNK, "MAX_ALLOC_CHUNK is wrong, please fix")
+JMESSAGE(JERR_BAD_BUFFER_MODE, "Bogus buffer control mode")
+JMESSAGE(JERR_BAD_COMPONENT_ID, "Invalid component ID %d in SOS")
+JMESSAGE(JERR_BAD_DCT_COEF, "DCT coefficient out of range")
+JMESSAGE(JERR_BAD_DCTSIZE, "IDCT output block size %d not supported")
+JMESSAGE(JERR_BAD_HUFF_TABLE, "Bogus Huffman table definition")
+JMESSAGE(JERR_BAD_IN_COLORSPACE, "Bogus input colorspace")
+JMESSAGE(JERR_BAD_J_COLORSPACE, "Bogus JPEG colorspace")
+JMESSAGE(JERR_BAD_LENGTH, "Bogus marker length")
+JMESSAGE(JERR_BAD_LIB_VERSION,
+ "Wrong JPEG library version: library is %d, caller expects %d")
+JMESSAGE(JERR_BAD_MCU_SIZE, "Sampling factors too large for interleaved scan")
+JMESSAGE(JERR_BAD_POOL_ID, "Invalid memory pool code %d")
+JMESSAGE(JERR_BAD_PRECISION, "Unsupported JPEG data precision %d")
+JMESSAGE(JERR_BAD_PROGRESSION,
+ "Invalid progressive parameters Ss=%d Se=%d Ah=%d Al=%d")
+JMESSAGE(JERR_BAD_PROG_SCRIPT,
+ "Invalid progressive parameters at scan script entry %d")
+JMESSAGE(JERR_BAD_SAMPLING, "Bogus sampling factors")
+JMESSAGE(JERR_BAD_SCAN_SCRIPT, "Invalid scan script at entry %d")
+JMESSAGE(JERR_BAD_STATE, "Improper call to JPEG library in state %d")
+JMESSAGE(JERR_BAD_STRUCT_SIZE,
+ "JPEG parameter struct mismatch: library thinks size is %u, caller expects %u")
+JMESSAGE(JERR_BAD_VIRTUAL_ACCESS, "Bogus virtual array access")
+JMESSAGE(JERR_BUFFER_SIZE, "Buffer passed to JPEG library is too small")
+JMESSAGE(JERR_CANT_SUSPEND, "Suspension not allowed here")
+JMESSAGE(JERR_CCIR601_NOTIMPL, "CCIR601 sampling not implemented yet")
+JMESSAGE(JERR_COMPONENT_COUNT, "Too many color components: %d, max %d")
+JMESSAGE(JERR_CONVERSION_NOTIMPL, "Unsupported color conversion request")
+JMESSAGE(JERR_DAC_INDEX, "Bogus DAC index %d")
+JMESSAGE(JERR_DAC_VALUE, "Bogus DAC value 0x%x")
+JMESSAGE(JERR_DHT_INDEX, "Bogus DHT index %d")
+JMESSAGE(JERR_DQT_INDEX, "Bogus DQT index %d")
+JMESSAGE(JERR_EMPTY_IMAGE, "Empty JPEG image (DNL not supported)")
+JMESSAGE(JERR_EMS_READ, "Read from EMS failed")
+JMESSAGE(JERR_EMS_WRITE, "Write to EMS failed")
+JMESSAGE(JERR_EOI_EXPECTED, "Didn't expect more than one scan")
+JMESSAGE(JERR_FILE_READ, "Input file read error")
+JMESSAGE(JERR_FILE_WRITE, "Output file write error --- out of disk space?")
+JMESSAGE(JERR_FRACT_SAMPLE_NOTIMPL, "Fractional sampling not implemented yet")
+JMESSAGE(JERR_HUFF_CLEN_OVERFLOW, "Huffman code size table overflow")
+JMESSAGE(JERR_HUFF_MISSING_CODE, "Missing Huffman code table entry")
+JMESSAGE(JERR_IMAGE_TOO_BIG, "Maximum supported image dimension is %u pixels")
+JMESSAGE(JERR_INPUT_EMPTY, "Empty input file")
+JMESSAGE(JERR_INPUT_EOF, "Premature end of input file")
+JMESSAGE(JERR_MISMATCHED_QUANT_TABLE,
+ "Cannot transcode due to multiple use of quantization table %d")
+JMESSAGE(JERR_MISSING_DATA, "Scan script does not transmit all data")
+JMESSAGE(JERR_MODE_CHANGE, "Invalid color quantization mode change")
+JMESSAGE(JERR_NOTIMPL, "Not implemented yet")
+JMESSAGE(JERR_NOT_COMPILED, "Requested feature was omitted at compile time")
+JMESSAGE(JERR_NO_BACKING_STORE, "Backing store not supported")
+JMESSAGE(JERR_NO_HUFF_TABLE, "Huffman table 0x%02x was not defined")
+JMESSAGE(JERR_NO_IMAGE, "JPEG datastream contains no image")
+JMESSAGE(JERR_NO_QUANT_TABLE, "Quantization table 0x%02x was not defined")
+JMESSAGE(JERR_NO_SOI, "Not a JPEG file: starts with 0x%02x 0x%02x")
+JMESSAGE(JERR_OUT_OF_MEMORY, "Insufficient memory (case %d)")
+JMESSAGE(JERR_QUANT_COMPONENTS,
+ "Cannot quantize more than %d color components")
+JMESSAGE(JERR_QUANT_FEW_COLORS, "Cannot quantize to fewer than %d colors")
+JMESSAGE(JERR_QUANT_MANY_COLORS, "Cannot quantize to more than %d colors")
+JMESSAGE(JERR_SOF_DUPLICATE, "Invalid JPEG file structure: two SOF markers")
+JMESSAGE(JERR_SOF_NO_SOS, "Invalid JPEG file structure: missing SOS marker")
+JMESSAGE(JERR_SOF_UNSUPPORTED, "Unsupported JPEG process: SOF type 0x%02x")
+JMESSAGE(JERR_SOI_DUPLICATE, "Invalid JPEG file structure: two SOI markers")
+JMESSAGE(JERR_SOS_NO_SOF, "Invalid JPEG file structure: SOS before SOF")
+JMESSAGE(JERR_TFILE_CREATE, "Failed to create temporary file %s")
+JMESSAGE(JERR_TFILE_READ, "Read failed on temporary file")
+JMESSAGE(JERR_TFILE_SEEK, "Seek failed on temporary file")
+JMESSAGE(JERR_TFILE_WRITE,
+ "Write failed on temporary file --- out of disk space?")
+JMESSAGE(JERR_TOO_LITTLE_DATA, "Application transferred too few scanlines")
+JMESSAGE(JERR_UNKNOWN_MARKER, "Unsupported marker type 0x%02x")
+JMESSAGE(JERR_VIRTUAL_BUG, "Virtual array controller messed up")
+JMESSAGE(JERR_WIDTH_OVERFLOW, "Image too wide for this implementation")
+JMESSAGE(JERR_XMS_READ, "Read from XMS failed")
+JMESSAGE(JERR_XMS_WRITE, "Write to XMS failed")
+JMESSAGE(JMSG_COPYRIGHT, JCOPYRIGHT)
+JMESSAGE(JMSG_VERSION, JVERSION)
+JMESSAGE(JTRC_16BIT_TABLES,
+ "Caution: quantization tables are too coarse for baseline JPEG")
+JMESSAGE(JTRC_ADOBE,
+ "Adobe APP14 marker: version %d, flags 0x%04x 0x%04x, transform %d")
+JMESSAGE(JTRC_APP0, "Unknown APP0 marker (not JFIF), length %u")
+JMESSAGE(JTRC_APP14, "Unknown APP14 marker (not Adobe), length %u")
+JMESSAGE(JTRC_DAC, "Define Arithmetic Table 0x%02x: 0x%02x")
+JMESSAGE(JTRC_DHT, "Define Huffman Table 0x%02x")
+JMESSAGE(JTRC_DQT, "Define Quantization Table %d precision %d")
+JMESSAGE(JTRC_DRI, "Define Restart Interval %u")
+JMESSAGE(JTRC_EMS_CLOSE, "Freed EMS handle %u")
+JMESSAGE(JTRC_EMS_OPEN, "Obtained EMS handle %u")
+JMESSAGE(JTRC_EOI, "End Of Image")
+JMESSAGE(JTRC_HUFFBITS, " %3d %3d %3d %3d %3d %3d %3d %3d")
+JMESSAGE(JTRC_JFIF, "JFIF APP0 marker: version %d.%02d, density %dx%d %d")
+JMESSAGE(JTRC_JFIF_BADTHUMBNAILSIZE,
+ "Warning: thumbnail image size does not match data length %u")
+JMESSAGE(JTRC_JFIF_EXTENSION,
+ "JFIF extension marker: type 0x%02x, length %u")
+JMESSAGE(JTRC_JFIF_THUMBNAIL, " with %d x %d thumbnail image")
+JMESSAGE(JTRC_MISC_MARKER, "Miscellaneous marker 0x%02x, length %u")
+JMESSAGE(JTRC_PARMLESS_MARKER, "Unexpected marker 0x%02x")
+JMESSAGE(JTRC_QUANTVALS, " %4u %4u %4u %4u %4u %4u %4u %4u")
+JMESSAGE(JTRC_QUANT_3_NCOLORS, "Quantizing to %d = %d*%d*%d colors")
+JMESSAGE(JTRC_QUANT_NCOLORS, "Quantizing to %d colors")
+JMESSAGE(JTRC_QUANT_SELECTED, "Selected %d colors for quantization")
+JMESSAGE(JTRC_RECOVERY_ACTION, "At marker 0x%02x, recovery action %d")
+JMESSAGE(JTRC_RST, "RST%d")
+JMESSAGE(JTRC_SMOOTH_NOTIMPL,
+ "Smoothing not supported with nonstandard sampling ratios")
+JMESSAGE(JTRC_SOF, "Start Of Frame 0x%02x: width=%u, height=%u, components=%d")
+JMESSAGE(JTRC_SOF_COMPONENT, " Component %d: %dhx%dv q=%d")
+JMESSAGE(JTRC_SOI, "Start of Image")
+JMESSAGE(JTRC_SOS, "Start Of Scan: %d components")
+JMESSAGE(JTRC_SOS_COMPONENT, " Component %d: dc=%d ac=%d")
+JMESSAGE(JTRC_SOS_PARAMS, " Ss=%d, Se=%d, Ah=%d, Al=%d")
+JMESSAGE(JTRC_TFILE_CLOSE, "Closed temporary file %s")
+JMESSAGE(JTRC_TFILE_OPEN, "Opened temporary file %s")
+JMESSAGE(JTRC_THUMB_JPEG,
+ "JFIF extension marker: JPEG-compressed thumbnail image, length %u")
+JMESSAGE(JTRC_THUMB_PALETTE,
+ "JFIF extension marker: palette thumbnail image, length %u")
+JMESSAGE(JTRC_THUMB_RGB,
+ "JFIF extension marker: RGB thumbnail image, length %u")
+JMESSAGE(JTRC_UNKNOWN_IDS,
+ "Unrecognized component IDs %d %d %d, assuming YCbCr")
+JMESSAGE(JTRC_XMS_CLOSE, "Freed XMS handle %u")
+JMESSAGE(JTRC_XMS_OPEN, "Obtained XMS handle %u")
+JMESSAGE(JWRN_ADOBE_XFORM, "Unknown Adobe color transform code %d")
+JMESSAGE(JWRN_BOGUS_PROGRESSION,
+ "Inconsistent progression sequence for component %d coefficient %d")
+JMESSAGE(JWRN_EXTRANEOUS_DATA,
+ "Corrupt JPEG data: %u extraneous bytes before marker 0x%02x")
+JMESSAGE(JWRN_HIT_MARKER, "Corrupt JPEG data: premature end of data segment")
+JMESSAGE(JWRN_HUFF_BAD_CODE, "Corrupt JPEG data: bad Huffman code")
+JMESSAGE(JWRN_JFIF_MAJOR, "Warning: unknown JFIF revision number %d.%02d")
+JMESSAGE(JWRN_JPEG_EOF, "Premature end of JPEG file")
+JMESSAGE(JWRN_MUST_RESYNC,
+ "Corrupt JPEG data: found marker 0x%02x instead of RST%d")
+JMESSAGE(JWRN_NOT_SEQUENTIAL, "Invalid SOS parameters for sequential JPEG")
+JMESSAGE(JWRN_TOO_MUCH_DATA, "Application transferred too many scanlines")
+
+#ifdef JMAKE_ENUM_LIST
+
+ JMSG_LASTMSGCODE
+} J_MESSAGE_CODE;
+
+#undef JMAKE_ENUM_LIST
+#endif /* JMAKE_ENUM_LIST */
+
+/* Zap JMESSAGE macro so that future re-inclusions do nothing by default */
+#undef JMESSAGE
+
+
+#ifndef JERROR_H
+#define JERROR_H
+
+/* Macros to simplify using the error and trace message stuff */
+/* The first parameter is either type of cinfo pointer */
+
+/* Fatal errors (print message and exit) */
+#define ERREXIT(cinfo,code) \
+ ((cinfo)->err->msg_code = (code), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+#define ERREXIT1(cinfo,code,p1) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+#define ERREXIT2(cinfo,code,p1,p2) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (cinfo)->err->msg_parm.i[1] = (p2), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+#define ERREXIT3(cinfo,code,p1,p2,p3) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (cinfo)->err->msg_parm.i[1] = (p2), \
+ (cinfo)->err->msg_parm.i[2] = (p3), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+#define ERREXIT4(cinfo,code,p1,p2,p3,p4) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (cinfo)->err->msg_parm.i[1] = (p2), \
+ (cinfo)->err->msg_parm.i[2] = (p3), \
+ (cinfo)->err->msg_parm.i[3] = (p4), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+#define ERREXITS(cinfo,code,str) \
+ ((cinfo)->err->msg_code = (code), \
+ strncpy((cinfo)->err->msg_parm.s, (str), JMSG_STR_PARM_MAX), \
+ (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo)))
+
+#define MAKESTMT(stuff) do { stuff } while (0)
+
+/* Nonfatal errors (we can keep going, but the data is probably corrupt) */
+#define WARNMS(cinfo,code) \
+ ((cinfo)->err->msg_code = (code), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1))
+#define WARNMS1(cinfo,code,p1) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1))
+#define WARNMS2(cinfo,code,p1,p2) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (cinfo)->err->msg_parm.i[1] = (p2), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1))
+
+/* Informational/debugging messages */
+#define TRACEMS(cinfo,lvl,code) \
+ ((cinfo)->err->msg_code = (code), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)))
+#define TRACEMS1(cinfo,lvl,code,p1) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)))
+#define TRACEMS2(cinfo,lvl,code,p1,p2) \
+ ((cinfo)->err->msg_code = (code), \
+ (cinfo)->err->msg_parm.i[0] = (p1), \
+ (cinfo)->err->msg_parm.i[1] = (p2), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)))
+#define TRACEMS3(cinfo,lvl,code,p1,p2,p3) \
+ MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \
+ _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); \
+ (cinfo)->err->msg_code = (code); \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); )
+#define TRACEMS4(cinfo,lvl,code,p1,p2,p3,p4) \
+ MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \
+ _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \
+ (cinfo)->err->msg_code = (code); \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); )
+#define TRACEMS5(cinfo,lvl,code,p1,p2,p3,p4,p5) \
+ MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \
+ _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \
+ _mp[4] = (p5); \
+ (cinfo)->err->msg_code = (code); \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); )
+#define TRACEMS8(cinfo,lvl,code,p1,p2,p3,p4,p5,p6,p7,p8) \
+ MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \
+ _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \
+ _mp[4] = (p5); _mp[5] = (p6); _mp[6] = (p7); _mp[7] = (p8); \
+ (cinfo)->err->msg_code = (code); \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); )
+#define TRACEMSS(cinfo,lvl,code,str) \
+ ((cinfo)->err->msg_code = (code), \
+ strncpy((cinfo)->err->msg_parm.s, (str), JMSG_STR_PARM_MAX), \
+ (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)))
+
+#endif /* JERROR_H */
diff --git a/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp b/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp
new file mode 100644
index 000000000..79d7a0078
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp
@@ -0,0 +1,168 @@
+/*
+ * jfdctflt.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a floating-point implementation of the
+ * forward DCT (Discrete Cosine Transform).
+ *
+ * This implementation should be more accurate than either of the integer
+ * DCT implementations. However, it may not give the same results on all
+ * machines because of differences in roundoff behavior. Speed will depend
+ * on the hardware's floating point capacity.
+ *
+ * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT
+ * on each column. Direct algorithms are also available, but they are
+ * much more complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on Arai, Agui, and Nakajima's algorithm for
+ * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in
+ * Japanese, but the algorithm is described in the Pennebaker & Mitchell
+ * JPEG textbook (see REFERENCES section in file README). The following code
+ * is based directly on figure 4-8 in P&M.
+ * While an 8-point DCT cannot be done in less than 11 multiplies, it is
+ * possible to arrange the computation so that many of the multiplies are
+ * simple scalings of the final outputs. These multiplies can then be
+ * folded into the multiplications or divisions by the JPEG quantization
+ * table entries. The AA&N method leaves only 5 multiplies and 29 adds
+ * to be done in the DCT itself.
+ * The primary disadvantage of this method is that with a fixed-point
+ * implementation, accuracy is lost due to imprecise representation of the
+ * scaled quantization values. However, that problem does not arise if
+ * we use floating point arithmetic.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_FLOAT_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/*
+ * Perform the forward DCT on one block of samples.
+ */
+
+GLOBAL(void)
+jpeg_fdct_float (FAST_FLOAT * data)
+{
+ FAST_FLOAT tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
+ FAST_FLOAT tmp10, tmp11, tmp12, tmp13;
+ FAST_FLOAT z1, z2, z3, z4, z5, z11, z13;
+ FAST_FLOAT *dataptr;
+ int ctr;
+
+ /* Pass 1: process rows. */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[0] + dataptr[7];
+ tmp7 = dataptr[0] - dataptr[7];
+ tmp1 = dataptr[1] + dataptr[6];
+ tmp6 = dataptr[1] - dataptr[6];
+ tmp2 = dataptr[2] + dataptr[5];
+ tmp5 = dataptr[2] - dataptr[5];
+ tmp3 = dataptr[3] + dataptr[4];
+ tmp4 = dataptr[3] - dataptr[4];
+
+ /* Even part */
+
+ tmp10 = tmp0 + tmp3; /* phase 2 */
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[0] = tmp10 + tmp11; /* phase 3 */
+ dataptr[4] = tmp10 - tmp11;
+
+ z1 = (tmp12 + tmp13) * ((FAST_FLOAT) 0.707106781); /* c4 */
+ dataptr[2] = tmp13 + z1; /* phase 5 */
+ dataptr[6] = tmp13 - z1;
+
+ /* Odd part */
+
+ tmp10 = tmp4 + tmp5; /* phase 2 */
+ tmp11 = tmp5 + tmp6;
+ tmp12 = tmp6 + tmp7;
+
+ /* The rotator is modified from fig 4-8 to avoid extra negations. */
+ z5 = (tmp10 - tmp12) * ((FAST_FLOAT) 0.382683433); /* c6 */
+ z2 = ((FAST_FLOAT) 0.541196100) * tmp10 + z5; /* c2-c6 */
+ z4 = ((FAST_FLOAT) 1.306562965) * tmp12 + z5; /* c2+c6 */
+ z3 = tmp11 * ((FAST_FLOAT) 0.707106781); /* c4 */
+
+ z11 = tmp7 + z3; /* phase 5 */
+ z13 = tmp7 - z3;
+
+ dataptr[5] = z13 + z2; /* phase 6 */
+ dataptr[3] = z13 - z2;
+ dataptr[1] = z11 + z4;
+ dataptr[7] = z11 - z4;
+
+ dataptr += DCTSIZE; /* advance pointer to next row */
+ }
+
+ /* Pass 2: process columns. */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7];
+ tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7];
+ tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6];
+ tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6];
+ tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5];
+ tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5];
+ tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4];
+ tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4];
+
+ /* Even part */
+
+ tmp10 = tmp0 + tmp3; /* phase 2 */
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[DCTSIZE*0] = tmp10 + tmp11; /* phase 3 */
+ dataptr[DCTSIZE*4] = tmp10 - tmp11;
+
+ z1 = (tmp12 + tmp13) * ((FAST_FLOAT) 0.707106781); /* c4 */
+ dataptr[DCTSIZE*2] = tmp13 + z1; /* phase 5 */
+ dataptr[DCTSIZE*6] = tmp13 - z1;
+
+ /* Odd part */
+
+ tmp10 = tmp4 + tmp5; /* phase 2 */
+ tmp11 = tmp5 + tmp6;
+ tmp12 = tmp6 + tmp7;
+
+ /* The rotator is modified from fig 4-8 to avoid extra negations. */
+ z5 = (tmp10 - tmp12) * ((FAST_FLOAT) 0.382683433); /* c6 */
+ z2 = ((FAST_FLOAT) 0.541196100) * tmp10 + z5; /* c2-c6 */
+ z4 = ((FAST_FLOAT) 1.306562965) * tmp12 + z5; /* c2+c6 */
+ z3 = tmp11 * ((FAST_FLOAT) 0.707106781); /* c4 */
+
+ z11 = tmp7 + z3; /* phase 5 */
+ z13 = tmp7 - z3;
+
+ dataptr[DCTSIZE*5] = z13 + z2; /* phase 6 */
+ dataptr[DCTSIZE*3] = z13 - z2;
+ dataptr[DCTSIZE*1] = z11 + z4;
+ dataptr[DCTSIZE*7] = z11 - z4;
+
+ dataptr++; /* advance pointer to next column */
+ }
+}
+
+#endif /* DCT_FLOAT_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp b/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp
new file mode 100644
index 000000000..4c96d4f3d
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp
@@ -0,0 +1,224 @@
+/*
+ * jfdctfst.c
+ *
+ * Copyright (C) 1994-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a fast, not so accurate integer implementation of the
+ * forward DCT (Discrete Cosine Transform).
+ *
+ * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT
+ * on each column. Direct algorithms are also available, but they are
+ * much more complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on Arai, Agui, and Nakajima's algorithm for
+ * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in
+ * Japanese, but the algorithm is described in the Pennebaker & Mitchell
+ * JPEG textbook (see REFERENCES section in file README). The following code
+ * is based directly on figure 4-8 in P&M.
+ * While an 8-point DCT cannot be done in less than 11 multiplies, it is
+ * possible to arrange the computation so that many of the multiplies are
+ * simple scalings of the final outputs. These multiplies can then be
+ * folded into the multiplications or divisions by the JPEG quantization
+ * table entries. The AA&N method leaves only 5 multiplies and 29 adds
+ * to be done in the DCT itself.
+ * The primary disadvantage of this method is that with fixed-point math,
+ * accuracy is lost due to imprecise representation of the scaled
+ * quantization values. The smaller the quantization table entry, the less
+ * precise the scaled value, so this implementation does worse with high-
+ * quality-setting files than with low-quality ones.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_IFAST_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/* Scaling decisions are generally the same as in the LL&M algorithm;
+ * see jfdctint.c for more details. However, we choose to descale
+ * (right shift) multiplication products as soon as they are formed,
+ * rather than carrying additional fractional bits into subsequent additions.
+ * This compromises accuracy slightly, but it lets us save a few shifts.
+ * More importantly, 16-bit arithmetic is then adequate (for 8-bit samples)
+ * everywhere except in the multiplications proper; this saves a good deal
+ * of work on 16-bit-int machines.
+ *
+ * Again to save a few shifts, the intermediate results between pass 1 and
+ * pass 2 are not upscaled, but are represented only to integral precision.
+ *
+ * A final compromise is to represent the multiplicative constants to only
+ * 8 fractional bits, rather than 13. This saves some shifting work on some
+ * machines, and may also reduce the cost of multiplication (since there
+ * are fewer one-bits in the constants).
+ */
+
+#define CONST_BITS 8
+
+
+/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus
+ * causing a lot of useless floating-point operations at run time.
+ * To get around this we use the following pre-calculated constants.
+ * If you change CONST_BITS you may want to add appropriate values.
+ * (With a reasonable C compiler, you can just rely on the FIX() macro...)
+ */
+
+#if CONST_BITS == 8
+#define FIX_0_382683433 ((long) 98) /* FIX(0.382683433) */
+#define FIX_0_541196100 ((long) 139) /* FIX(0.541196100) */
+#define FIX_0_707106781 ((long) 181) /* FIX(0.707106781) */
+#define FIX_1_306562965 ((long) 334) /* FIX(1.306562965) */
+#else
+#define FIX_0_382683433 FIX(0.382683433)
+#define FIX_0_541196100 FIX(0.541196100)
+#define FIX_0_707106781 FIX(0.707106781)
+#define FIX_1_306562965 FIX(1.306562965)
+#endif
+
+
+/* We can gain a little more speed, with a further compromise in accuracy,
+ * by omitting the addition in a descaling shift. This yields an incorrectly
+ * rounded result half the time...
+ */
+
+#ifndef USE_ACCURATE_ROUNDING
+#undef DESCALE
+#define DESCALE(x,n) RIGHT_SHIFT(x, n)
+#endif
+
+
+/* Multiply a DCTELEM variable by an long constant, and immediately
+ * descale to yield a DCTELEM result.
+ */
+
+#define MULTIPLY(var,const) ((DCTELEM) DESCALE((var) * (const), CONST_BITS))
+
+
+/*
+ * Perform the forward DCT on one block of samples.
+ */
+
+GLOBAL(void)
+jpeg_fdct_ifast (DCTELEM * data)
+{
+ DCTELEM tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
+ DCTELEM tmp10, tmp11, tmp12, tmp13;
+ DCTELEM z1, z2, z3, z4, z5, z11, z13;
+ DCTELEM *dataptr;
+ int ctr;
+ SHIFT_TEMPS
+
+ /* Pass 1: process rows. */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[0] + dataptr[7];
+ tmp7 = dataptr[0] - dataptr[7];
+ tmp1 = dataptr[1] + dataptr[6];
+ tmp6 = dataptr[1] - dataptr[6];
+ tmp2 = dataptr[2] + dataptr[5];
+ tmp5 = dataptr[2] - dataptr[5];
+ tmp3 = dataptr[3] + dataptr[4];
+ tmp4 = dataptr[3] - dataptr[4];
+
+ /* Even part */
+
+ tmp10 = tmp0 + tmp3; /* phase 2 */
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[0] = tmp10 + tmp11; /* phase 3 */
+ dataptr[4] = tmp10 - tmp11;
+
+ z1 = MULTIPLY(tmp12 + tmp13, FIX_0_707106781); /* c4 */
+ dataptr[2] = tmp13 + z1; /* phase 5 */
+ dataptr[6] = tmp13 - z1;
+
+ /* Odd part */
+
+ tmp10 = tmp4 + tmp5; /* phase 2 */
+ tmp11 = tmp5 + tmp6;
+ tmp12 = tmp6 + tmp7;
+
+ /* The rotator is modified from fig 4-8 to avoid extra negations. */
+ z5 = MULTIPLY(tmp10 - tmp12, FIX_0_382683433); /* c6 */
+ z2 = MULTIPLY(tmp10, FIX_0_541196100) + z5; /* c2-c6 */
+ z4 = MULTIPLY(tmp12, FIX_1_306562965) + z5; /* c2+c6 */
+ z3 = MULTIPLY(tmp11, FIX_0_707106781); /* c4 */
+
+ z11 = tmp7 + z3; /* phase 5 */
+ z13 = tmp7 - z3;
+
+ dataptr[5] = z13 + z2; /* phase 6 */
+ dataptr[3] = z13 - z2;
+ dataptr[1] = z11 + z4;
+ dataptr[7] = z11 - z4;
+
+ dataptr += DCTSIZE; /* advance pointer to next row */
+ }
+
+ /* Pass 2: process columns. */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7];
+ tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7];
+ tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6];
+ tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6];
+ tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5];
+ tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5];
+ tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4];
+ tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4];
+
+ /* Even part */
+
+ tmp10 = tmp0 + tmp3; /* phase 2 */
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[DCTSIZE*0] = tmp10 + tmp11; /* phase 3 */
+ dataptr[DCTSIZE*4] = tmp10 - tmp11;
+
+ z1 = MULTIPLY(tmp12 + tmp13, FIX_0_707106781); /* c4 */
+ dataptr[DCTSIZE*2] = tmp13 + z1; /* phase 5 */
+ dataptr[DCTSIZE*6] = tmp13 - z1;
+
+ /* Odd part */
+
+ tmp10 = tmp4 + tmp5; /* phase 2 */
+ tmp11 = tmp5 + tmp6;
+ tmp12 = tmp6 + tmp7;
+
+ /* The rotator is modified from fig 4-8 to avoid extra negations. */
+ z5 = MULTIPLY(tmp10 - tmp12, FIX_0_382683433); /* c6 */
+ z2 = MULTIPLY(tmp10, FIX_0_541196100) + z5; /* c2-c6 */
+ z4 = MULTIPLY(tmp12, FIX_1_306562965) + z5; /* c2+c6 */
+ z3 = MULTIPLY(tmp11, FIX_0_707106781); /* c4 */
+
+ z11 = tmp7 + z3; /* phase 5 */
+ z13 = tmp7 - z3;
+
+ dataptr[DCTSIZE*5] = z13 + z2; /* phase 6 */
+ dataptr[DCTSIZE*3] = z13 - z2;
+ dataptr[DCTSIZE*1] = z11 + z4;
+ dataptr[DCTSIZE*7] = z11 - z4;
+
+ dataptr++; /* advance pointer to next column */
+ }
+}
+
+#endif /* DCT_IFAST_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jfdctint.cpp b/ml/dlib/dlib/external/libjpeg/jfdctint.cpp
new file mode 100644
index 000000000..b6046a26b
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jfdctint.cpp
@@ -0,0 +1,283 @@
+/*
+ * jfdctint.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a slow-but-accurate integer implementation of the
+ * forward DCT (Discrete Cosine Transform).
+ *
+ * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT
+ * on each column. Direct algorithms are also available, but they are
+ * much more complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on an algorithm described in
+ * C. Loeffler, A. Ligtenberg and G. Moschytz, "Practical Fast 1-D DCT
+ * Algorithms with 11 Multiplications", Proc. Int'l. Conf. on Acoustics,
+ * Speech, and Signal Processing 1989 (ICASSP '89), pp. 988-991.
+ * The primary algorithm described there uses 11 multiplies and 29 adds.
+ * We use their alternate method with 12 multiplies and 32 adds.
+ * The advantage of this method is that no data path contains more than one
+ * multiplication; this allows a very simple and accurate implementation in
+ * scaled fixed-point arithmetic, with a minimal number of shifts.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_ISLOW_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/*
+ * The poop on this scaling stuff is as follows:
+ *
+ * Each 1-D DCT step produces outputs which are a factor of sqrt(N)
+ * larger than the true DCT outputs. The final outputs are therefore
+ * a factor of N larger than desired; since N=8 this can be cured by
+ * a simple right shift at the end of the algorithm. The advantage of
+ * this arrangement is that we save two multiplications per 1-D DCT,
+ * because the y0 and y4 outputs need not be divided by sqrt(N).
+ * In the IJG code, this factor of 8 is removed by the quantization step
+ * (in jcdctmgr.c), NOT in this module.
+ *
+ * We have to do addition and subtraction of the integer inputs, which
+ * is no problem, and multiplication by fractional constants, which is
+ * a problem to do in integer arithmetic. We multiply all the constants
+ * by CONST_SCALE and convert them to integer constants (thus retaining
+ * CONST_BITS bits of precision in the constants). After doing a
+ * multiplication we have to divide the product by CONST_SCALE, with proper
+ * rounding, to produce the correct output. This division can be done
+ * cheaply as a right shift of CONST_BITS bits. We postpone shifting
+ * as long as possible so that partial sums can be added together with
+ * full fractional precision.
+ *
+ * The outputs of the first pass are scaled up by PASS1_BITS bits so that
+ * they are represented to better-than-integral precision. These outputs
+ * require BITS_IN_JSAMPLE + PASS1_BITS + 3 bits; this fits in a 16-bit word
+ * with the recommended scaling. (For 12-bit sample data, the intermediate
+ * array is long anyway.)
+ *
+ * To avoid overflow of the 32-bit intermediate results in pass 2, we must
+ * have BITS_IN_JSAMPLE + CONST_BITS + PASS1_BITS <= 26. Error analysis
+ * shows that the values given below are the most effective.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define CONST_BITS 13
+#define PASS1_BITS 2
+#else
+#define CONST_BITS 13
+#define PASS1_BITS 1 /* lose a little precision to avoid overflow */
+#endif
+
+/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus
+ * causing a lot of useless floating-point operations at run time.
+ * To get around this we use the following pre-calculated constants.
+ * If you change CONST_BITS you may want to add appropriate values.
+ * (With a reasonable C compiler, you can just rely on the FIX() macro...)
+ */
+
+#if CONST_BITS == 13
+#define FIX_0_298631336 ((long) 2446) /* FIX(0.298631336) */
+#define FIX_0_390180644 ((long) 3196) /* FIX(0.390180644) */
+#define FIX_0_541196100 ((long) 4433) /* FIX(0.541196100) */
+#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */
+#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */
+#define FIX_1_175875602 ((long) 9633) /* FIX(1.175875602) */
+#define FIX_1_501321110 ((long) 12299) /* FIX(1.501321110) */
+#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */
+#define FIX_1_961570560 ((long) 16069) /* FIX(1.961570560) */
+#define FIX_2_053119869 ((long) 16819) /* FIX(2.053119869) */
+#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */
+#define FIX_3_072711026 ((long) 25172) /* FIX(3.072711026) */
+#else
+#define FIX_0_298631336 FIX(0.298631336)
+#define FIX_0_390180644 FIX(0.390180644)
+#define FIX_0_541196100 FIX(0.541196100)
+#define FIX_0_765366865 FIX(0.765366865)
+#define FIX_0_899976223 FIX(0.899976223)
+#define FIX_1_175875602 FIX(1.175875602)
+#define FIX_1_501321110 FIX(1.501321110)
+#define FIX_1_847759065 FIX(1.847759065)
+#define FIX_1_961570560 FIX(1.961570560)
+#define FIX_2_053119869 FIX(2.053119869)
+#define FIX_2_562915447 FIX(2.562915447)
+#define FIX_3_072711026 FIX(3.072711026)
+#endif
+
+
+/* Multiply an long variable by an long constant to yield an long result.
+ * For 8-bit samples with the recommended scaling, all the variable
+ * and constant values involved are no more than 16 bits wide, so a
+ * 16x16->32 bit multiply can be used instead of a full 32x32 multiply.
+ * For 12-bit samples, a full 32-bit multiplication will be needed.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define MULTIPLY(var,const) MULTIPLY16C16(var,const)
+#else
+#define MULTIPLY(var,const) ((var) * (const))
+#endif
+
+
+/*
+ * Perform the forward DCT on one block of samples.
+ */
+
+GLOBAL(void)
+jpeg_fdct_islow (DCTELEM * data)
+{
+ long tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
+ long tmp10, tmp11, tmp12, tmp13;
+ long z1, z2, z3, z4, z5;
+ DCTELEM *dataptr;
+ int ctr;
+ SHIFT_TEMPS
+
+ /* Pass 1: process rows. */
+ /* Note results are scaled up by sqrt(8) compared to a true DCT; */
+ /* furthermore, we scale the results by 2**PASS1_BITS. */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[0] + dataptr[7];
+ tmp7 = dataptr[0] - dataptr[7];
+ tmp1 = dataptr[1] + dataptr[6];
+ tmp6 = dataptr[1] - dataptr[6];
+ tmp2 = dataptr[2] + dataptr[5];
+ tmp5 = dataptr[2] - dataptr[5];
+ tmp3 = dataptr[3] + dataptr[4];
+ tmp4 = dataptr[3] - dataptr[4];
+
+ /* Even part per LL&M figure 1 --- note that published figure is faulty;
+ * rotator "sqrt(2)*c1" should be "sqrt(2)*c6".
+ */
+
+ tmp10 = tmp0 + tmp3;
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[0] = (DCTELEM) ((tmp10 + tmp11) << PASS1_BITS);
+ dataptr[4] = (DCTELEM) ((tmp10 - tmp11) << PASS1_BITS);
+
+ z1 = MULTIPLY(tmp12 + tmp13, FIX_0_541196100);
+ dataptr[2] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp13, FIX_0_765366865),
+ CONST_BITS-PASS1_BITS);
+ dataptr[6] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp12, - FIX_1_847759065),
+ CONST_BITS-PASS1_BITS);
+
+ /* Odd part per figure 8 --- note paper omits factor of sqrt(2).
+ * cK represents cos(K*pi/16).
+ * i0..i3 in the paper are tmp4..tmp7 here.
+ */
+
+ z1 = tmp4 + tmp7;
+ z2 = tmp5 + tmp6;
+ z3 = tmp4 + tmp6;
+ z4 = tmp5 + tmp7;
+ z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */
+
+ tmp4 = MULTIPLY(tmp4, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */
+ tmp5 = MULTIPLY(tmp5, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */
+ tmp6 = MULTIPLY(tmp6, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */
+ tmp7 = MULTIPLY(tmp7, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */
+ z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */
+ z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */
+ z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */
+ z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */
+
+ z3 += z5;
+ z4 += z5;
+
+ dataptr[7] = (DCTELEM) DESCALE(tmp4 + z1 + z3, CONST_BITS-PASS1_BITS);
+ dataptr[5] = (DCTELEM) DESCALE(tmp5 + z2 + z4, CONST_BITS-PASS1_BITS);
+ dataptr[3] = (DCTELEM) DESCALE(tmp6 + z2 + z3, CONST_BITS-PASS1_BITS);
+ dataptr[1] = (DCTELEM) DESCALE(tmp7 + z1 + z4, CONST_BITS-PASS1_BITS);
+
+ dataptr += DCTSIZE; /* advance pointer to next row */
+ }
+
+ /* Pass 2: process columns.
+ * We remove the PASS1_BITS scaling, but leave the results scaled up
+ * by an overall factor of 8.
+ */
+
+ dataptr = data;
+ for (ctr = DCTSIZE-1; ctr >= 0; ctr--) {
+ tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7];
+ tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7];
+ tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6];
+ tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6];
+ tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5];
+ tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5];
+ tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4];
+ tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4];
+
+ /* Even part per LL&M figure 1 --- note that published figure is faulty;
+ * rotator "sqrt(2)*c1" should be "sqrt(2)*c6".
+ */
+
+ tmp10 = tmp0 + tmp3;
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ dataptr[DCTSIZE*0] = (DCTELEM) DESCALE(tmp10 + tmp11, PASS1_BITS);
+ dataptr[DCTSIZE*4] = (DCTELEM) DESCALE(tmp10 - tmp11, PASS1_BITS);
+
+ z1 = MULTIPLY(tmp12 + tmp13, FIX_0_541196100);
+ dataptr[DCTSIZE*2] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp13, FIX_0_765366865),
+ CONST_BITS+PASS1_BITS);
+ dataptr[DCTSIZE*6] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp12, - FIX_1_847759065),
+ CONST_BITS+PASS1_BITS);
+
+ /* Odd part per figure 8 --- note paper omits factor of sqrt(2).
+ * cK represents cos(K*pi/16).
+ * i0..i3 in the paper are tmp4..tmp7 here.
+ */
+
+ z1 = tmp4 + tmp7;
+ z2 = tmp5 + tmp6;
+ z3 = tmp4 + tmp6;
+ z4 = tmp5 + tmp7;
+ z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */
+
+ tmp4 = MULTIPLY(tmp4, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */
+ tmp5 = MULTIPLY(tmp5, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */
+ tmp6 = MULTIPLY(tmp6, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */
+ tmp7 = MULTIPLY(tmp7, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */
+ z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */
+ z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */
+ z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */
+ z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */
+
+ z3 += z5;
+ z4 += z5;
+
+ dataptr[DCTSIZE*7] = (DCTELEM) DESCALE(tmp4 + z1 + z3,
+ CONST_BITS+PASS1_BITS);
+ dataptr[DCTSIZE*5] = (DCTELEM) DESCALE(tmp5 + z2 + z4,
+ CONST_BITS+PASS1_BITS);
+ dataptr[DCTSIZE*3] = (DCTELEM) DESCALE(tmp6 + z2 + z3,
+ CONST_BITS+PASS1_BITS);
+ dataptr[DCTSIZE*1] = (DCTELEM) DESCALE(tmp7 + z1 + z4,
+ CONST_BITS+PASS1_BITS);
+
+ dataptr++; /* advance pointer to next column */
+ }
+}
+
+#endif /* DCT_ISLOW_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jidctflt.cpp b/ml/dlib/dlib/external/libjpeg/jidctflt.cpp
new file mode 100644
index 000000000..3e9b54579
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jidctflt.cpp
@@ -0,0 +1,242 @@
+/*
+ * jidctflt.c
+ *
+ * Copyright (C) 1994-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a floating-point implementation of the
+ * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine
+ * must also perform dequantization of the input coefficients.
+ *
+ * This implementation should be more accurate than either of the integer
+ * IDCT implementations. However, it may not give the same results on all
+ * machines because of differences in roundoff behavior. Speed will depend
+ * on the hardware's floating point capacity.
+ *
+ * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT
+ * on each row (or vice versa, but it's more convenient to emit a row at
+ * a time). Direct algorithms are also available, but they are much more
+ * complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on Arai, Agui, and Nakajima's algorithm for
+ * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in
+ * Japanese, but the algorithm is described in the Pennebaker & Mitchell
+ * JPEG textbook (see REFERENCES section in file README). The following code
+ * is based directly on figure 4-8 in P&M.
+ * While an 8-point DCT cannot be done in less than 11 multiplies, it is
+ * possible to arrange the computation so that many of the multiplies are
+ * simple scalings of the final outputs. These multiplies can then be
+ * folded into the multiplications or divisions by the JPEG quantization
+ * table entries. The AA&N method leaves only 5 multiplies and 29 adds
+ * to be done in the DCT itself.
+ * The primary disadvantage of this method is that with a fixed-point
+ * implementation, accuracy is lost due to imprecise representation of the
+ * scaled quantization values. However, that problem does not arise if
+ * we use floating point arithmetic.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_FLOAT_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/* Dequantize a coefficient by multiplying it by the multiplier-table
+ * entry; produce a float result.
+ */
+
+#define DEQUANTIZE(coef,quantval) (((FAST_FLOAT) (coef)) * (quantval))
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients.
+ */
+
+GLOBAL(void)
+jpeg_idct_float (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ FAST_FLOAT tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
+ FAST_FLOAT tmp10, tmp11, tmp12, tmp13;
+ FAST_FLOAT z5, z10, z11, z12, z13;
+ JCOEFPTR inptr;
+ FLOAT_MULT_TYPE * quantptr;
+ FAST_FLOAT * wsptr;
+ JSAMPROW outptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ int ctr;
+ FAST_FLOAT workspace[DCTSIZE2]; /* buffers data between passes */
+ SHIFT_TEMPS
+
+ /* Pass 1: process columns from input, store into work array. */
+
+ inptr = coef_block;
+ quantptr = (FLOAT_MULT_TYPE *) compptr->dct_table;
+ wsptr = workspace;
+ for (ctr = DCTSIZE; ctr > 0; ctr--) {
+ /* Due to quantization, we will usually find that many of the input
+ * coefficients are zero, especially the AC terms. We can exploit this
+ * by short-circuiting the IDCT calculation for any column in which all
+ * the AC terms are zero. In that case each output is equal to the
+ * DC coefficient (with scale factor as needed).
+ * With typical images and quantization tables, half or more of the
+ * column DCT calculations can be simplified this way.
+ */
+
+ if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 &&
+ inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 &&
+ inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 &&
+ inptr[DCTSIZE*7] == 0) {
+ /* AC terms all zero */
+ FAST_FLOAT dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+
+ wsptr[DCTSIZE*0] = dcval;
+ wsptr[DCTSIZE*1] = dcval;
+ wsptr[DCTSIZE*2] = dcval;
+ wsptr[DCTSIZE*3] = dcval;
+ wsptr[DCTSIZE*4] = dcval;
+ wsptr[DCTSIZE*5] = dcval;
+ wsptr[DCTSIZE*6] = dcval;
+ wsptr[DCTSIZE*7] = dcval;
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ continue;
+ }
+
+ /* Even part */
+
+ tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+ tmp1 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]);
+ tmp2 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]);
+ tmp3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]);
+
+ tmp10 = tmp0 + tmp2; /* phase 3 */
+ tmp11 = tmp0 - tmp2;
+
+ tmp13 = tmp1 + tmp3; /* phases 5-3 */
+ tmp12 = (tmp1 - tmp3) * ((FAST_FLOAT) 1.414213562) - tmp13; /* 2*c4 */
+
+ tmp0 = tmp10 + tmp13; /* phase 2 */
+ tmp3 = tmp10 - tmp13;
+ tmp1 = tmp11 + tmp12;
+ tmp2 = tmp11 - tmp12;
+
+ /* Odd part */
+
+ tmp4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]);
+ tmp5 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]);
+ tmp6 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]);
+ tmp7 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]);
+
+ z13 = tmp6 + tmp5; /* phase 6 */
+ z10 = tmp6 - tmp5;
+ z11 = tmp4 + tmp7;
+ z12 = tmp4 - tmp7;
+
+ tmp7 = z11 + z13; /* phase 5 */
+ tmp11 = (z11 - z13) * ((FAST_FLOAT) 1.414213562); /* 2*c4 */
+
+ z5 = (z10 + z12) * ((FAST_FLOAT) 1.847759065); /* 2*c2 */
+ tmp10 = ((FAST_FLOAT) 1.082392200) * z12 - z5; /* 2*(c2-c6) */
+ tmp12 = ((FAST_FLOAT) -2.613125930) * z10 + z5; /* -2*(c2+c6) */
+
+ tmp6 = tmp12 - tmp7; /* phase 2 */
+ tmp5 = tmp11 - tmp6;
+ tmp4 = tmp10 + tmp5;
+
+ wsptr[DCTSIZE*0] = tmp0 + tmp7;
+ wsptr[DCTSIZE*7] = tmp0 - tmp7;
+ wsptr[DCTSIZE*1] = tmp1 + tmp6;
+ wsptr[DCTSIZE*6] = tmp1 - tmp6;
+ wsptr[DCTSIZE*2] = tmp2 + tmp5;
+ wsptr[DCTSIZE*5] = tmp2 - tmp5;
+ wsptr[DCTSIZE*4] = tmp3 + tmp4;
+ wsptr[DCTSIZE*3] = tmp3 - tmp4;
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ }
+
+ /* Pass 2: process rows from work array, store into output array. */
+ /* Note that we must descale the results by a factor of 8 == 2**3. */
+
+ wsptr = workspace;
+ for (ctr = 0; ctr < DCTSIZE; ctr++) {
+ outptr = output_buf[ctr] + output_col;
+ /* Rows of zeroes can be exploited in the same way as we did with columns.
+ * However, the column calculation has created many nonzero AC terms, so
+ * the simplification applies less often (typically 5% to 10% of the time).
+ * And testing floats for zero is relatively expensive, so we don't bother.
+ */
+
+ /* Even part */
+
+ tmp10 = wsptr[0] + wsptr[4];
+ tmp11 = wsptr[0] - wsptr[4];
+
+ tmp13 = wsptr[2] + wsptr[6];
+ tmp12 = (wsptr[2] - wsptr[6]) * ((FAST_FLOAT) 1.414213562) - tmp13;
+
+ tmp0 = tmp10 + tmp13;
+ tmp3 = tmp10 - tmp13;
+ tmp1 = tmp11 + tmp12;
+ tmp2 = tmp11 - tmp12;
+
+ /* Odd part */
+
+ z13 = wsptr[5] + wsptr[3];
+ z10 = wsptr[5] - wsptr[3];
+ z11 = wsptr[1] + wsptr[7];
+ z12 = wsptr[1] - wsptr[7];
+
+ tmp7 = z11 + z13;
+ tmp11 = (z11 - z13) * ((FAST_FLOAT) 1.414213562);
+
+ z5 = (z10 + z12) * ((FAST_FLOAT) 1.847759065); /* 2*c2 */
+ tmp10 = ((FAST_FLOAT) 1.082392200) * z12 - z5; /* 2*(c2-c6) */
+ tmp12 = ((FAST_FLOAT) -2.613125930) * z10 + z5; /* -2*(c2+c6) */
+
+ tmp6 = tmp12 - tmp7;
+ tmp5 = tmp11 - tmp6;
+ tmp4 = tmp10 + tmp5;
+
+ /* Final output stage: scale down by a factor of 8 and range-limit */
+
+ outptr[0] = range_limit[(int) DESCALE((long) (tmp0 + tmp7), 3)
+ & RANGE_MASK];
+ outptr[7] = range_limit[(int) DESCALE((long) (tmp0 - tmp7), 3)
+ & RANGE_MASK];
+ outptr[1] = range_limit[(int) DESCALE((long) (tmp1 + tmp6), 3)
+ & RANGE_MASK];
+ outptr[6] = range_limit[(int) DESCALE((long) (tmp1 - tmp6), 3)
+ & RANGE_MASK];
+ outptr[2] = range_limit[(int) DESCALE((long) (tmp2 + tmp5), 3)
+ & RANGE_MASK];
+ outptr[5] = range_limit[(int) DESCALE((long) (tmp2 - tmp5), 3)
+ & RANGE_MASK];
+ outptr[4] = range_limit[(int) DESCALE((long) (tmp3 + tmp4), 3)
+ & RANGE_MASK];
+ outptr[3] = range_limit[(int) DESCALE((long) (tmp3 - tmp4), 3)
+ & RANGE_MASK];
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ }
+}
+
+#endif /* DCT_FLOAT_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jidctfst.cpp b/ml/dlib/dlib/external/libjpeg/jidctfst.cpp
new file mode 100644
index 000000000..e08835bc5
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jidctfst.cpp
@@ -0,0 +1,368 @@
+/*
+ * jidctfst.c
+ *
+ * Copyright (C) 1994-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a fast, not so accurate integer implementation of the
+ * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine
+ * must also perform dequantization of the input coefficients.
+ *
+ * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT
+ * on each row (or vice versa, but it's more convenient to emit a row at
+ * a time). Direct algorithms are also available, but they are much more
+ * complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on Arai, Agui, and Nakajima's algorithm for
+ * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in
+ * Japanese, but the algorithm is described in the Pennebaker & Mitchell
+ * JPEG textbook (see REFERENCES section in file README). The following code
+ * is based directly on figure 4-8 in P&M.
+ * While an 8-point DCT cannot be done in less than 11 multiplies, it is
+ * possible to arrange the computation so that many of the multiplies are
+ * simple scalings of the final outputs. These multiplies can then be
+ * folded into the multiplications or divisions by the JPEG quantization
+ * table entries. The AA&N method leaves only 5 multiplies and 29 adds
+ * to be done in the DCT itself.
+ * The primary disadvantage of this method is that with fixed-point math,
+ * accuracy is lost due to imprecise representation of the scaled
+ * quantization values. The smaller the quantization table entry, the less
+ * precise the scaled value, so this implementation does worse with high-
+ * quality-setting files than with low-quality ones.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_IFAST_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/* Scaling decisions are generally the same as in the LL&M algorithm;
+ * see jidctint.c for more details. However, we choose to descale
+ * (right shift) multiplication products as soon as they are formed,
+ * rather than carrying additional fractional bits into subsequent additions.
+ * This compromises accuracy slightly, but it lets us save a few shifts.
+ * More importantly, 16-bit arithmetic is then adequate (for 8-bit samples)
+ * everywhere except in the multiplications proper; this saves a good deal
+ * of work on 16-bit-int machines.
+ *
+ * The dequantized coefficients are not integers because the AA&N scaling
+ * factors have been incorporated. We represent them scaled up by PASS1_BITS,
+ * so that the first and second IDCT rounds have the same input scaling.
+ * For 8-bit JSAMPLEs, we choose IFAST_SCALE_BITS = PASS1_BITS so as to
+ * avoid a descaling shift; this compromises accuracy rather drastically
+ * for small quantization table entries, but it saves a lot of shifts.
+ * For 12-bit JSAMPLEs, there's no hope of using 16x16 multiplies anyway,
+ * so we use a much larger scaling factor to preserve accuracy.
+ *
+ * A final compromise is to represent the multiplicative constants to only
+ * 8 fractional bits, rather than 13. This saves some shifting work on some
+ * machines, and may also reduce the cost of multiplication (since there
+ * are fewer one-bits in the constants).
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define CONST_BITS 8
+#define PASS1_BITS 2
+#else
+#define CONST_BITS 8
+#define PASS1_BITS 1 /* lose a little precision to avoid overflow */
+#endif
+
+/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus
+ * causing a lot of useless floating-point operations at run time.
+ * To get around this we use the following pre-calculated constants.
+ * If you change CONST_BITS you may want to add appropriate values.
+ * (With a reasonable C compiler, you can just rely on the FIX() macro...)
+ */
+
+#if CONST_BITS == 8
+#define FIX_1_082392200 ((long) 277) /* FIX(1.082392200) */
+#define FIX_1_414213562 ((long) 362) /* FIX(1.414213562) */
+#define FIX_1_847759065 ((long) 473) /* FIX(1.847759065) */
+#define FIX_2_613125930 ((long) 669) /* FIX(2.613125930) */
+#else
+#define FIX_1_082392200 FIX(1.082392200)
+#define FIX_1_414213562 FIX(1.414213562)
+#define FIX_1_847759065 FIX(1.847759065)
+#define FIX_2_613125930 FIX(2.613125930)
+#endif
+
+
+/* We can gain a little more speed, with a further compromise in accuracy,
+ * by omitting the addition in a descaling shift. This yields an incorrectly
+ * rounded result half the time...
+ */
+
+#ifndef USE_ACCURATE_ROUNDING
+#undef DESCALE
+#define DESCALE(x,n) RIGHT_SHIFT(x, n)
+#endif
+
+
+/* Multiply a DCTELEM variable by an long constant, and immediately
+ * descale to yield a DCTELEM result.
+ */
+
+#define MULTIPLY(var,const) ((DCTELEM) DESCALE((var) * (const), CONST_BITS))
+
+
+/* Dequantize a coefficient by multiplying it by the multiplier-table
+ * entry; produce a DCTELEM result. For 8-bit data a 16x16->16
+ * multiplication will do. For 12-bit data, the multiplier table is
+ * declared long, so a 32-bit multiply will be used.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define DEQUANTIZE(coef,quantval) (((IFAST_MULT_TYPE) (coef)) * (quantval))
+#else
+#define DEQUANTIZE(coef,quantval) \
+ DESCALE((coef)*(quantval), IFAST_SCALE_BITS-PASS1_BITS)
+#endif
+
+
+/* Like DESCALE, but applies to a DCTELEM and produces an int.
+ * We assume that int right shift is unsigned if long right shift is.
+ */
+
+#ifdef RIGHT_SHIFT_IS_UNSIGNED
+#define ISHIFT_TEMPS DCTELEM ishift_temp;
+#if BITS_IN_JSAMPLE == 8
+#define DCTELEMBITS 16 /* DCTELEM may be 16 or 32 bits */
+#else
+#define DCTELEMBITS 32 /* DCTELEM must be 32 bits */
+#endif
+#define IRIGHT_SHIFT(x,shft) \
+ ((ishift_temp = (x)) < 0 ? \
+ (ishift_temp >> (shft)) | ((~((DCTELEM) 0)) << (DCTELEMBITS-(shft))) : \
+ (ishift_temp >> (shft)))
+#else
+#define ISHIFT_TEMPS
+#define IRIGHT_SHIFT(x,shft) ((x) >> (shft))
+#endif
+
+#ifdef USE_ACCURATE_ROUNDING
+#define IDESCALE(x,n) ((int) IRIGHT_SHIFT((x) + (1 << ((n)-1)), n))
+#else
+#define IDESCALE(x,n) ((int) IRIGHT_SHIFT(x, n))
+#endif
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients.
+ */
+
+GLOBAL(void)
+jpeg_idct_ifast (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ DCTELEM tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
+ DCTELEM tmp10, tmp11, tmp12, tmp13;
+ DCTELEM z5, z10, z11, z12, z13;
+ JCOEFPTR inptr;
+ IFAST_MULT_TYPE * quantptr;
+ int * wsptr;
+ JSAMPROW outptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ int ctr;
+ int workspace[DCTSIZE2]; /* buffers data between passes */
+ SHIFT_TEMPS /* for DESCALE */
+ ISHIFT_TEMPS /* for IDESCALE */
+
+ /* Pass 1: process columns from input, store into work array. */
+
+ inptr = coef_block;
+ quantptr = (IFAST_MULT_TYPE *) compptr->dct_table;
+ wsptr = workspace;
+ for (ctr = DCTSIZE; ctr > 0; ctr--) {
+ /* Due to quantization, we will usually find that many of the input
+ * coefficients are zero, especially the AC terms. We can exploit this
+ * by short-circuiting the IDCT calculation for any column in which all
+ * the AC terms are zero. In that case each output is equal to the
+ * DC coefficient (with scale factor as needed).
+ * With typical images and quantization tables, half or more of the
+ * column DCT calculations can be simplified this way.
+ */
+
+ if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 &&
+ inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 &&
+ inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 &&
+ inptr[DCTSIZE*7] == 0) {
+ /* AC terms all zero */
+ int dcval = (int) DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+
+ wsptr[DCTSIZE*0] = dcval;
+ wsptr[DCTSIZE*1] = dcval;
+ wsptr[DCTSIZE*2] = dcval;
+ wsptr[DCTSIZE*3] = dcval;
+ wsptr[DCTSIZE*4] = dcval;
+ wsptr[DCTSIZE*5] = dcval;
+ wsptr[DCTSIZE*6] = dcval;
+ wsptr[DCTSIZE*7] = dcval;
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ continue;
+ }
+
+ /* Even part */
+
+ tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+ tmp1 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]);
+ tmp2 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]);
+ tmp3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]);
+
+ tmp10 = tmp0 + tmp2; /* phase 3 */
+ tmp11 = tmp0 - tmp2;
+
+ tmp13 = tmp1 + tmp3; /* phases 5-3 */
+ tmp12 = MULTIPLY(tmp1 - tmp3, FIX_1_414213562) - tmp13; /* 2*c4 */
+
+ tmp0 = tmp10 + tmp13; /* phase 2 */
+ tmp3 = tmp10 - tmp13;
+ tmp1 = tmp11 + tmp12;
+ tmp2 = tmp11 - tmp12;
+
+ /* Odd part */
+
+ tmp4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]);
+ tmp5 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]);
+ tmp6 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]);
+ tmp7 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]);
+
+ z13 = tmp6 + tmp5; /* phase 6 */
+ z10 = tmp6 - tmp5;
+ z11 = tmp4 + tmp7;
+ z12 = tmp4 - tmp7;
+
+ tmp7 = z11 + z13; /* phase 5 */
+ tmp11 = MULTIPLY(z11 - z13, FIX_1_414213562); /* 2*c4 */
+
+ z5 = MULTIPLY(z10 + z12, FIX_1_847759065); /* 2*c2 */
+ tmp10 = MULTIPLY(z12, FIX_1_082392200) - z5; /* 2*(c2-c6) */
+ tmp12 = MULTIPLY(z10, - FIX_2_613125930) + z5; /* -2*(c2+c6) */
+
+ tmp6 = tmp12 - tmp7; /* phase 2 */
+ tmp5 = tmp11 - tmp6;
+ tmp4 = tmp10 + tmp5;
+
+ wsptr[DCTSIZE*0] = (int) (tmp0 + tmp7);
+ wsptr[DCTSIZE*7] = (int) (tmp0 - tmp7);
+ wsptr[DCTSIZE*1] = (int) (tmp1 + tmp6);
+ wsptr[DCTSIZE*6] = (int) (tmp1 - tmp6);
+ wsptr[DCTSIZE*2] = (int) (tmp2 + tmp5);
+ wsptr[DCTSIZE*5] = (int) (tmp2 - tmp5);
+ wsptr[DCTSIZE*4] = (int) (tmp3 + tmp4);
+ wsptr[DCTSIZE*3] = (int) (tmp3 - tmp4);
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ }
+
+ /* Pass 2: process rows from work array, store into output array. */
+ /* Note that we must descale the results by a factor of 8 == 2**3, */
+ /* and also undo the PASS1_BITS scaling. */
+
+ wsptr = workspace;
+ for (ctr = 0; ctr < DCTSIZE; ctr++) {
+ outptr = output_buf[ctr] + output_col;
+ /* Rows of zeroes can be exploited in the same way as we did with columns.
+ * However, the column calculation has created many nonzero AC terms, so
+ * the simplification applies less often (typically 5% to 10% of the time).
+ * On machines with very fast multiplication, it's possible that the
+ * test takes more time than it's worth. In that case this section
+ * may be commented out.
+ */
+
+#ifndef NO_ZERO_ROW_TEST
+ if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 && wsptr[4] == 0 &&
+ wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) {
+ /* AC terms all zero */
+ JSAMPLE dcval = range_limit[IDESCALE(wsptr[0], PASS1_BITS+3)
+ & RANGE_MASK];
+
+ outptr[0] = dcval;
+ outptr[1] = dcval;
+ outptr[2] = dcval;
+ outptr[3] = dcval;
+ outptr[4] = dcval;
+ outptr[5] = dcval;
+ outptr[6] = dcval;
+ outptr[7] = dcval;
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ continue;
+ }
+#endif
+
+ /* Even part */
+
+ tmp10 = ((DCTELEM) wsptr[0] + (DCTELEM) wsptr[4]);
+ tmp11 = ((DCTELEM) wsptr[0] - (DCTELEM) wsptr[4]);
+
+ tmp13 = ((DCTELEM) wsptr[2] + (DCTELEM) wsptr[6]);
+ tmp12 = MULTIPLY((DCTELEM) wsptr[2] - (DCTELEM) wsptr[6], FIX_1_414213562)
+ - tmp13;
+
+ tmp0 = tmp10 + tmp13;
+ tmp3 = tmp10 - tmp13;
+ tmp1 = tmp11 + tmp12;
+ tmp2 = tmp11 - tmp12;
+
+ /* Odd part */
+
+ z13 = (DCTELEM) wsptr[5] + (DCTELEM) wsptr[3];
+ z10 = (DCTELEM) wsptr[5] - (DCTELEM) wsptr[3];
+ z11 = (DCTELEM) wsptr[1] + (DCTELEM) wsptr[7];
+ z12 = (DCTELEM) wsptr[1] - (DCTELEM) wsptr[7];
+
+ tmp7 = z11 + z13; /* phase 5 */
+ tmp11 = MULTIPLY(z11 - z13, FIX_1_414213562); /* 2*c4 */
+
+ z5 = MULTIPLY(z10 + z12, FIX_1_847759065); /* 2*c2 */
+ tmp10 = MULTIPLY(z12, FIX_1_082392200) - z5; /* 2*(c2-c6) */
+ tmp12 = MULTIPLY(z10, - FIX_2_613125930) + z5; /* -2*(c2+c6) */
+
+ tmp6 = tmp12 - tmp7; /* phase 2 */
+ tmp5 = tmp11 - tmp6;
+ tmp4 = tmp10 + tmp5;
+
+ /* Final output stage: scale down by a factor of 8 and range-limit */
+
+ outptr[0] = range_limit[IDESCALE(tmp0 + tmp7, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[7] = range_limit[IDESCALE(tmp0 - tmp7, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[1] = range_limit[IDESCALE(tmp1 + tmp6, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[6] = range_limit[IDESCALE(tmp1 - tmp6, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[2] = range_limit[IDESCALE(tmp2 + tmp5, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[5] = range_limit[IDESCALE(tmp2 - tmp5, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[4] = range_limit[IDESCALE(tmp3 + tmp4, PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[3] = range_limit[IDESCALE(tmp3 - tmp4, PASS1_BITS+3)
+ & RANGE_MASK];
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ }
+}
+
+#endif /* DCT_IFAST_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jidctint.cpp b/ml/dlib/dlib/external/libjpeg/jidctint.cpp
new file mode 100644
index 000000000..630b4fb89
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jidctint.cpp
@@ -0,0 +1,389 @@
+/*
+ * jidctint.c
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains a slow-but-accurate integer implementation of the
+ * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine
+ * must also perform dequantization of the input coefficients.
+ *
+ * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT
+ * on each row (or vice versa, but it's more convenient to emit a row at
+ * a time). Direct algorithms are also available, but they are much more
+ * complex and seem not to be any faster when reduced to code.
+ *
+ * This implementation is based on an algorithm described in
+ * C. Loeffler, A. Ligtenberg and G. Moschytz, "Practical Fast 1-D DCT
+ * Algorithms with 11 Multiplications", Proc. Int'l. Conf. on Acoustics,
+ * Speech, and Signal Processing 1989 (ICASSP '89), pp. 988-991.
+ * The primary algorithm described there uses 11 multiplies and 29 adds.
+ * We use their alternate method with 12 multiplies and 32 adds.
+ * The advantage of this method is that no data path contains more than one
+ * multiplication; this allows a very simple and accurate implementation in
+ * scaled fixed-point arithmetic, with a minimal number of shifts.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef DCT_ISLOW_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/*
+ * The poop on this scaling stuff is as follows:
+ *
+ * Each 1-D IDCT step produces outputs which are a factor of sqrt(N)
+ * larger than the true IDCT outputs. The final outputs are therefore
+ * a factor of N larger than desired; since N=8 this can be cured by
+ * a simple right shift at the end of the algorithm. The advantage of
+ * this arrangement is that we save two multiplications per 1-D IDCT,
+ * because the y0 and y4 inputs need not be divided by sqrt(N).
+ *
+ * We have to do addition and subtraction of the integer inputs, which
+ * is no problem, and multiplication by fractional constants, which is
+ * a problem to do in integer arithmetic. We multiply all the constants
+ * by CONST_SCALE and convert them to integer constants (thus retaining
+ * CONST_BITS bits of precision in the constants). After doing a
+ * multiplication we have to divide the product by CONST_SCALE, with proper
+ * rounding, to produce the correct output. This division can be done
+ * cheaply as a right shift of CONST_BITS bits. We postpone shifting
+ * as long as possible so that partial sums can be added together with
+ * full fractional precision.
+ *
+ * The outputs of the first pass are scaled up by PASS1_BITS bits so that
+ * they are represented to better-than-integral precision. These outputs
+ * require BITS_IN_JSAMPLE + PASS1_BITS + 3 bits; this fits in a 16-bit word
+ * with the recommended scaling. (To scale up 12-bit sample data further, an
+ * intermediate long array would be needed.)
+ *
+ * To avoid overflow of the 32-bit intermediate results in pass 2, we must
+ * have BITS_IN_JSAMPLE + CONST_BITS + PASS1_BITS <= 26. Error analysis
+ * shows that the values given below are the most effective.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define CONST_BITS 13
+#define PASS1_BITS 2
+#else
+#define CONST_BITS 13
+#define PASS1_BITS 1 /* lose a little precision to avoid overflow */
+#endif
+
+/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus
+ * causing a lot of useless floating-point operations at run time.
+ * To get around this we use the following pre-calculated constants.
+ * If you change CONST_BITS you may want to add appropriate values.
+ * (With a reasonable C compiler, you can just rely on the FIX() macro...)
+ */
+
+#if CONST_BITS == 13
+#define FIX_0_298631336 ((long) 2446) /* FIX(0.298631336) */
+#define FIX_0_390180644 ((long) 3196) /* FIX(0.390180644) */
+#define FIX_0_541196100 ((long) 4433) /* FIX(0.541196100) */
+#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */
+#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */
+#define FIX_1_175875602 ((long) 9633) /* FIX(1.175875602) */
+#define FIX_1_501321110 ((long) 12299) /* FIX(1.501321110) */
+#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */
+#define FIX_1_961570560 ((long) 16069) /* FIX(1.961570560) */
+#define FIX_2_053119869 ((long) 16819) /* FIX(2.053119869) */
+#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */
+#define FIX_3_072711026 ((long) 25172) /* FIX(3.072711026) */
+#else
+#define FIX_0_298631336 FIX(0.298631336)
+#define FIX_0_390180644 FIX(0.390180644)
+#define FIX_0_541196100 FIX(0.541196100)
+#define FIX_0_765366865 FIX(0.765366865)
+#define FIX_0_899976223 FIX(0.899976223)
+#define FIX_1_175875602 FIX(1.175875602)
+#define FIX_1_501321110 FIX(1.501321110)
+#define FIX_1_847759065 FIX(1.847759065)
+#define FIX_1_961570560 FIX(1.961570560)
+#define FIX_2_053119869 FIX(2.053119869)
+#define FIX_2_562915447 FIX(2.562915447)
+#define FIX_3_072711026 FIX(3.072711026)
+#endif
+
+
+/* Multiply an long variable by an long constant to yield an long result.
+ * For 8-bit samples with the recommended scaling, all the variable
+ * and constant values involved are no more than 16 bits wide, so a
+ * 16x16->32 bit multiply can be used instead of a full 32x32 multiply.
+ * For 12-bit samples, a full 32-bit multiplication will be needed.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define MULTIPLY(var,const) MULTIPLY16C16(var,const)
+#else
+#define MULTIPLY(var,const) ((var) * (const))
+#endif
+
+
+/* Dequantize a coefficient by multiplying it by the multiplier-table
+ * entry; produce an int result. In this module, both inputs and result
+ * are 16 bits or less, so either int or short multiply will work.
+ */
+
+#define DEQUANTIZE(coef,quantval) (((ISLOW_MULT_TYPE) (coef)) * (quantval))
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients.
+ */
+
+GLOBAL(void)
+jpeg_idct_islow (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ long tmp0, tmp1, tmp2, tmp3;
+ long tmp10, tmp11, tmp12, tmp13;
+ long z1, z2, z3, z4, z5;
+ JCOEFPTR inptr;
+ ISLOW_MULT_TYPE * quantptr;
+ int * wsptr;
+ JSAMPROW outptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ int ctr;
+ int workspace[DCTSIZE2]; /* buffers data between passes */
+ SHIFT_TEMPS
+
+ /* Pass 1: process columns from input, store into work array. */
+ /* Note results are scaled up by sqrt(8) compared to a true IDCT; */
+ /* furthermore, we scale the results by 2**PASS1_BITS. */
+
+ inptr = coef_block;
+ quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table;
+ wsptr = workspace;
+ for (ctr = DCTSIZE; ctr > 0; ctr--) {
+ /* Due to quantization, we will usually find that many of the input
+ * coefficients are zero, especially the AC terms. We can exploit this
+ * by short-circuiting the IDCT calculation for any column in which all
+ * the AC terms are zero. In that case each output is equal to the
+ * DC coefficient (with scale factor as needed).
+ * With typical images and quantization tables, half or more of the
+ * column DCT calculations can be simplified this way.
+ */
+
+ if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 &&
+ inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 &&
+ inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 &&
+ inptr[DCTSIZE*7] == 0) {
+ /* AC terms all zero */
+ int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS;
+
+ wsptr[DCTSIZE*0] = dcval;
+ wsptr[DCTSIZE*1] = dcval;
+ wsptr[DCTSIZE*2] = dcval;
+ wsptr[DCTSIZE*3] = dcval;
+ wsptr[DCTSIZE*4] = dcval;
+ wsptr[DCTSIZE*5] = dcval;
+ wsptr[DCTSIZE*6] = dcval;
+ wsptr[DCTSIZE*7] = dcval;
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ continue;
+ }
+
+ /* Even part: reverse the even part of the forward DCT. */
+ /* The rotator is sqrt(2)*c(-6). */
+
+ z2 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]);
+ z3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]);
+
+ z1 = MULTIPLY(z2 + z3, FIX_0_541196100);
+ tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065);
+ tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865);
+
+ z2 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+ z3 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]);
+
+ tmp0 = (z2 + z3) << CONST_BITS;
+ tmp1 = (z2 - z3) << CONST_BITS;
+
+ tmp10 = tmp0 + tmp3;
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ /* Odd part per figure 8; the matrix is unitary and hence its
+ * transpose is its inverse. i0..i3 are y7,y5,y3,y1 respectively.
+ */
+
+ tmp0 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]);
+ tmp1 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]);
+ tmp2 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]);
+ tmp3 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]);
+
+ z1 = tmp0 + tmp3;
+ z2 = tmp1 + tmp2;
+ z3 = tmp0 + tmp2;
+ z4 = tmp1 + tmp3;
+ z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */
+
+ tmp0 = MULTIPLY(tmp0, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */
+ tmp1 = MULTIPLY(tmp1, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */
+ tmp2 = MULTIPLY(tmp2, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */
+ tmp3 = MULTIPLY(tmp3, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */
+ z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */
+ z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */
+ z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */
+ z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */
+
+ z3 += z5;
+ z4 += z5;
+
+ tmp0 += z1 + z3;
+ tmp1 += z2 + z4;
+ tmp2 += z2 + z3;
+ tmp3 += z1 + z4;
+
+ /* Final output stage: inputs are tmp10..tmp13, tmp0..tmp3 */
+
+ wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp3, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*7] = (int) DESCALE(tmp10 - tmp3, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*1] = (int) DESCALE(tmp11 + tmp2, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*6] = (int) DESCALE(tmp11 - tmp2, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*2] = (int) DESCALE(tmp12 + tmp1, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*5] = (int) DESCALE(tmp12 - tmp1, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*3] = (int) DESCALE(tmp13 + tmp0, CONST_BITS-PASS1_BITS);
+ wsptr[DCTSIZE*4] = (int) DESCALE(tmp13 - tmp0, CONST_BITS-PASS1_BITS);
+
+ inptr++; /* advance pointers to next column */
+ quantptr++;
+ wsptr++;
+ }
+
+ /* Pass 2: process rows from work array, store into output array. */
+ /* Note that we must descale the results by a factor of 8 == 2**3, */
+ /* and also undo the PASS1_BITS scaling. */
+
+ wsptr = workspace;
+ for (ctr = 0; ctr < DCTSIZE; ctr++) {
+ outptr = output_buf[ctr] + output_col;
+ /* Rows of zeroes can be exploited in the same way as we did with columns.
+ * However, the column calculation has created many nonzero AC terms, so
+ * the simplification applies less often (typically 5% to 10% of the time).
+ * On machines with very fast multiplication, it's possible that the
+ * test takes more time than it's worth. In that case this section
+ * may be commented out.
+ */
+
+#ifndef NO_ZERO_ROW_TEST
+ if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 && wsptr[4] == 0 &&
+ wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) {
+ /* AC terms all zero */
+ JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3)
+ & RANGE_MASK];
+
+ outptr[0] = dcval;
+ outptr[1] = dcval;
+ outptr[2] = dcval;
+ outptr[3] = dcval;
+ outptr[4] = dcval;
+ outptr[5] = dcval;
+ outptr[6] = dcval;
+ outptr[7] = dcval;
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ continue;
+ }
+#endif
+
+ /* Even part: reverse the even part of the forward DCT. */
+ /* The rotator is sqrt(2)*c(-6). */
+
+ z2 = (long) wsptr[2];
+ z3 = (long) wsptr[6];
+
+ z1 = MULTIPLY(z2 + z3, FIX_0_541196100);
+ tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065);
+ tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865);
+
+ tmp0 = ((long) wsptr[0] + (long) wsptr[4]) << CONST_BITS;
+ tmp1 = ((long) wsptr[0] - (long) wsptr[4]) << CONST_BITS;
+
+ tmp10 = tmp0 + tmp3;
+ tmp13 = tmp0 - tmp3;
+ tmp11 = tmp1 + tmp2;
+ tmp12 = tmp1 - tmp2;
+
+ /* Odd part per figure 8; the matrix is unitary and hence its
+ * transpose is its inverse. i0..i3 are y7,y5,y3,y1 respectively.
+ */
+
+ tmp0 = (long) wsptr[7];
+ tmp1 = (long) wsptr[5];
+ tmp2 = (long) wsptr[3];
+ tmp3 = (long) wsptr[1];
+
+ z1 = tmp0 + tmp3;
+ z2 = tmp1 + tmp2;
+ z3 = tmp0 + tmp2;
+ z4 = tmp1 + tmp3;
+ z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */
+
+ tmp0 = MULTIPLY(tmp0, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */
+ tmp1 = MULTIPLY(tmp1, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */
+ tmp2 = MULTIPLY(tmp2, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */
+ tmp3 = MULTIPLY(tmp3, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */
+ z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */
+ z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */
+ z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */
+ z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */
+
+ z3 += z5;
+ z4 += z5;
+
+ tmp0 += z1 + z3;
+ tmp1 += z2 + z4;
+ tmp2 += z2 + z3;
+ tmp3 += z1 + z4;
+
+ /* Final output stage: inputs are tmp10..tmp13, tmp0..tmp3 */
+
+ outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp3,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[7] = range_limit[(int) DESCALE(tmp10 - tmp3,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[1] = range_limit[(int) DESCALE(tmp11 + tmp2,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[6] = range_limit[(int) DESCALE(tmp11 - tmp2,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[2] = range_limit[(int) DESCALE(tmp12 + tmp1,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[5] = range_limit[(int) DESCALE(tmp12 - tmp1,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[3] = range_limit[(int) DESCALE(tmp13 + tmp0,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+ outptr[4] = range_limit[(int) DESCALE(tmp13 - tmp0,
+ CONST_BITS+PASS1_BITS+3)
+ & RANGE_MASK];
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ }
+}
+
+#endif /* DCT_ISLOW_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jidctred.cpp b/ml/dlib/dlib/external/libjpeg/jidctred.cpp
new file mode 100644
index 000000000..fa442ac9a
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jidctred.cpp
@@ -0,0 +1,398 @@
+/*
+ * jidctred.c
+ *
+ * Copyright (C) 1994-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains inverse-DCT routines that produce reduced-size output:
+ * either 4x4, 2x2, or 1x1 pixels from an 8x8 DCT block.
+ *
+ * The implementation is based on the Loeffler, Ligtenberg and Moschytz (LL&M)
+ * algorithm used in jidctint.c. We simply replace each 8-to-8 1-D IDCT step
+ * with an 8-to-4 step that produces the four averages of two adjacent outputs
+ * (or an 8-to-2 step producing two averages of four outputs, for 2x2 output).
+ * These steps were derived by computing the corresponding values at the end
+ * of the normal LL&M code, then simplifying as much as possible.
+ *
+ * 1x1 is trivial: just take the DC coefficient divided by 8.
+ *
+ * See jidctint.c for additional comments.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jdct.h" /* Private declarations for DCT subsystem */
+
+#ifdef IDCT_SCALING_SUPPORTED
+
+
+/*
+ * This module is specialized to the case DCTSIZE = 8.
+ */
+
+#if DCTSIZE != 8
+ Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */
+#endif
+
+
+/* Scaling is the same as in jidctint.c. */
+
+#if BITS_IN_JSAMPLE == 8
+#define CONST_BITS 13
+#define PASS1_BITS 2
+#else
+#define CONST_BITS 13
+#define PASS1_BITS 1 /* lose a little precision to avoid overflow */
+#endif
+
+/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus
+ * causing a lot of useless floating-point operations at run time.
+ * To get around this we use the following pre-calculated constants.
+ * If you change CONST_BITS you may want to add appropriate values.
+ * (With a reasonable C compiler, you can just rely on the FIX() macro...)
+ */
+
+#if CONST_BITS == 13
+#define FIX_0_211164243 ((long) 1730) /* FIX(0.211164243) */
+#define FIX_0_509795579 ((long) 4176) /* FIX(0.509795579) */
+#define FIX_0_601344887 ((long) 4926) /* FIX(0.601344887) */
+#define FIX_0_720959822 ((long) 5906) /* FIX(0.720959822) */
+#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */
+#define FIX_0_850430095 ((long) 6967) /* FIX(0.850430095) */
+#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */
+#define FIX_1_061594337 ((long) 8697) /* FIX(1.061594337) */
+#define FIX_1_272758580 ((long) 10426) /* FIX(1.272758580) */
+#define FIX_1_451774981 ((long) 11893) /* FIX(1.451774981) */
+#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */
+#define FIX_2_172734803 ((long) 17799) /* FIX(2.172734803) */
+#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */
+#define FIX_3_624509785 ((long) 29692) /* FIX(3.624509785) */
+#else
+#define FIX_0_211164243 FIX(0.211164243)
+#define FIX_0_509795579 FIX(0.509795579)
+#define FIX_0_601344887 FIX(0.601344887)
+#define FIX_0_720959822 FIX(0.720959822)
+#define FIX_0_765366865 FIX(0.765366865)
+#define FIX_0_850430095 FIX(0.850430095)
+#define FIX_0_899976223 FIX(0.899976223)
+#define FIX_1_061594337 FIX(1.061594337)
+#define FIX_1_272758580 FIX(1.272758580)
+#define FIX_1_451774981 FIX(1.451774981)
+#define FIX_1_847759065 FIX(1.847759065)
+#define FIX_2_172734803 FIX(2.172734803)
+#define FIX_2_562915447 FIX(2.562915447)
+#define FIX_3_624509785 FIX(3.624509785)
+#endif
+
+
+/* Multiply an long variable by an long constant to yield an long result.
+ * For 8-bit samples with the recommended scaling, all the variable
+ * and constant values involved are no more than 16 bits wide, so a
+ * 16x16->32 bit multiply can be used instead of a full 32x32 multiply.
+ * For 12-bit samples, a full 32-bit multiplication will be needed.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+#define MULTIPLY(var,const) MULTIPLY16C16(var,const)
+#else
+#define MULTIPLY(var,const) ((var) * (const))
+#endif
+
+
+/* Dequantize a coefficient by multiplying it by the multiplier-table
+ * entry; produce an int result. In this module, both inputs and result
+ * are 16 bits or less, so either int or short multiply will work.
+ */
+
+#define DEQUANTIZE(coef,quantval) (((ISLOW_MULT_TYPE) (coef)) * (quantval))
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients,
+ * producing a reduced-size 4x4 output block.
+ */
+
+GLOBAL(void)
+jpeg_idct_4x4 (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ long tmp0, tmp2, tmp10, tmp12;
+ long z1, z2, z3, z4;
+ JCOEFPTR inptr;
+ ISLOW_MULT_TYPE * quantptr;
+ int * wsptr;
+ JSAMPROW outptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ int ctr;
+ int workspace[DCTSIZE*4]; /* buffers data between passes */
+ SHIFT_TEMPS
+
+ /* Pass 1: process columns from input, store into work array. */
+
+ inptr = coef_block;
+ quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table;
+ wsptr = workspace;
+ for (ctr = DCTSIZE; ctr > 0; inptr++, quantptr++, wsptr++, ctr--) {
+ /* Don't bother to process column 4, because second pass won't use it */
+ if (ctr == DCTSIZE-4)
+ continue;
+ if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 &&
+ inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*5] == 0 &&
+ inptr[DCTSIZE*6] == 0 && inptr[DCTSIZE*7] == 0) {
+ /* AC terms all zero; we need not examine term 4 for 4x4 output */
+ int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS;
+
+ wsptr[DCTSIZE*0] = dcval;
+ wsptr[DCTSIZE*1] = dcval;
+ wsptr[DCTSIZE*2] = dcval;
+ wsptr[DCTSIZE*3] = dcval;
+
+ continue;
+ }
+
+ /* Even part */
+
+ tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+ tmp0 <<= (CONST_BITS+1);
+
+ z2 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]);
+ z3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]);
+
+ tmp2 = MULTIPLY(z2, FIX_1_847759065) + MULTIPLY(z3, - FIX_0_765366865);
+
+ tmp10 = tmp0 + tmp2;
+ tmp12 = tmp0 - tmp2;
+
+ /* Odd part */
+
+ z1 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]);
+ z2 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]);
+ z3 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]);
+ z4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]);
+
+ tmp0 = MULTIPLY(z1, - FIX_0_211164243) /* sqrt(2) * (c3-c1) */
+ + MULTIPLY(z2, FIX_1_451774981) /* sqrt(2) * (c3+c7) */
+ + MULTIPLY(z3, - FIX_2_172734803) /* sqrt(2) * (-c1-c5) */
+ + MULTIPLY(z4, FIX_1_061594337); /* sqrt(2) * (c5+c7) */
+
+ tmp2 = MULTIPLY(z1, - FIX_0_509795579) /* sqrt(2) * (c7-c5) */
+ + MULTIPLY(z2, - FIX_0_601344887) /* sqrt(2) * (c5-c1) */
+ + MULTIPLY(z3, FIX_0_899976223) /* sqrt(2) * (c3-c7) */
+ + MULTIPLY(z4, FIX_2_562915447); /* sqrt(2) * (c1+c3) */
+
+ /* Final output stage */
+
+ wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp2, CONST_BITS-PASS1_BITS+1);
+ wsptr[DCTSIZE*3] = (int) DESCALE(tmp10 - tmp2, CONST_BITS-PASS1_BITS+1);
+ wsptr[DCTSIZE*1] = (int) DESCALE(tmp12 + tmp0, CONST_BITS-PASS1_BITS+1);
+ wsptr[DCTSIZE*2] = (int) DESCALE(tmp12 - tmp0, CONST_BITS-PASS1_BITS+1);
+ }
+
+ /* Pass 2: process 4 rows from work array, store into output array. */
+
+ wsptr = workspace;
+ for (ctr = 0; ctr < 4; ctr++) {
+ outptr = output_buf[ctr] + output_col;
+ /* It's not clear whether a zero row test is worthwhile here ... */
+
+#ifndef NO_ZERO_ROW_TEST
+ if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 &&
+ wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) {
+ /* AC terms all zero */
+ JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3)
+ & RANGE_MASK];
+
+ outptr[0] = dcval;
+ outptr[1] = dcval;
+ outptr[2] = dcval;
+ outptr[3] = dcval;
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ continue;
+ }
+#endif
+
+ /* Even part */
+
+ tmp0 = ((long) wsptr[0]) << (CONST_BITS+1);
+
+ tmp2 = MULTIPLY((long) wsptr[2], FIX_1_847759065)
+ + MULTIPLY((long) wsptr[6], - FIX_0_765366865);
+
+ tmp10 = tmp0 + tmp2;
+ tmp12 = tmp0 - tmp2;
+
+ /* Odd part */
+
+ z1 = (long) wsptr[7];
+ z2 = (long) wsptr[5];
+ z3 = (long) wsptr[3];
+ z4 = (long) wsptr[1];
+
+ tmp0 = MULTIPLY(z1, - FIX_0_211164243) /* sqrt(2) * (c3-c1) */
+ + MULTIPLY(z2, FIX_1_451774981) /* sqrt(2) * (c3+c7) */
+ + MULTIPLY(z3, - FIX_2_172734803) /* sqrt(2) * (-c1-c5) */
+ + MULTIPLY(z4, FIX_1_061594337); /* sqrt(2) * (c5+c7) */
+
+ tmp2 = MULTIPLY(z1, - FIX_0_509795579) /* sqrt(2) * (c7-c5) */
+ + MULTIPLY(z2, - FIX_0_601344887) /* sqrt(2) * (c5-c1) */
+ + MULTIPLY(z3, FIX_0_899976223) /* sqrt(2) * (c3-c7) */
+ + MULTIPLY(z4, FIX_2_562915447); /* sqrt(2) * (c1+c3) */
+
+ /* Final output stage */
+
+ outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp2,
+ CONST_BITS+PASS1_BITS+3+1)
+ & RANGE_MASK];
+ outptr[3] = range_limit[(int) DESCALE(tmp10 - tmp2,
+ CONST_BITS+PASS1_BITS+3+1)
+ & RANGE_MASK];
+ outptr[1] = range_limit[(int) DESCALE(tmp12 + tmp0,
+ CONST_BITS+PASS1_BITS+3+1)
+ & RANGE_MASK];
+ outptr[2] = range_limit[(int) DESCALE(tmp12 - tmp0,
+ CONST_BITS+PASS1_BITS+3+1)
+ & RANGE_MASK];
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ }
+}
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients,
+ * producing a reduced-size 2x2 output block.
+ */
+
+GLOBAL(void)
+jpeg_idct_2x2 (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ long tmp0, tmp10, z1;
+ JCOEFPTR inptr;
+ ISLOW_MULT_TYPE * quantptr;
+ int * wsptr;
+ JSAMPROW outptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ int ctr;
+ int workspace[DCTSIZE*2]; /* buffers data between passes */
+ SHIFT_TEMPS
+
+ /* Pass 1: process columns from input, store into work array. */
+
+ inptr = coef_block;
+ quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table;
+ wsptr = workspace;
+ for (ctr = DCTSIZE; ctr > 0; inptr++, quantptr++, wsptr++, ctr--) {
+ /* Don't bother to process columns 2,4,6 */
+ if (ctr == DCTSIZE-2 || ctr == DCTSIZE-4 || ctr == DCTSIZE-6)
+ continue;
+ if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*3] == 0 &&
+ inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*7] == 0) {
+ /* AC terms all zero; we need not examine terms 2,4,6 for 2x2 output */
+ int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS;
+
+ wsptr[DCTSIZE*0] = dcval;
+ wsptr[DCTSIZE*1] = dcval;
+
+ continue;
+ }
+
+ /* Even part */
+
+ z1 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]);
+ tmp10 = z1 << (CONST_BITS+2);
+
+ /* Odd part */
+
+ z1 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]);
+ tmp0 = MULTIPLY(z1, - FIX_0_720959822); /* sqrt(2) * (c7-c5+c3-c1) */
+ z1 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]);
+ tmp0 += MULTIPLY(z1, FIX_0_850430095); /* sqrt(2) * (-c1+c3+c5+c7) */
+ z1 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]);
+ tmp0 += MULTIPLY(z1, - FIX_1_272758580); /* sqrt(2) * (-c1+c3-c5-c7) */
+ z1 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]);
+ tmp0 += MULTIPLY(z1, FIX_3_624509785); /* sqrt(2) * (c1+c3+c5+c7) */
+
+ /* Final output stage */
+
+ wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp0, CONST_BITS-PASS1_BITS+2);
+ wsptr[DCTSIZE*1] = (int) DESCALE(tmp10 - tmp0, CONST_BITS-PASS1_BITS+2);
+ }
+
+ /* Pass 2: process 2 rows from work array, store into output array. */
+
+ wsptr = workspace;
+ for (ctr = 0; ctr < 2; ctr++) {
+ outptr = output_buf[ctr] + output_col;
+ /* It's not clear whether a zero row test is worthwhile here ... */
+
+#ifndef NO_ZERO_ROW_TEST
+ if (wsptr[1] == 0 && wsptr[3] == 0 && wsptr[5] == 0 && wsptr[7] == 0) {
+ /* AC terms all zero */
+ JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3)
+ & RANGE_MASK];
+
+ outptr[0] = dcval;
+ outptr[1] = dcval;
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ continue;
+ }
+#endif
+
+ /* Even part */
+
+ tmp10 = ((long) wsptr[0]) << (CONST_BITS+2);
+
+ /* Odd part */
+
+ tmp0 = MULTIPLY((long) wsptr[7], - FIX_0_720959822) /* sqrt(2) * (c7-c5+c3-c1) */
+ + MULTIPLY((long) wsptr[5], FIX_0_850430095) /* sqrt(2) * (-c1+c3+c5+c7) */
+ + MULTIPLY((long) wsptr[3], - FIX_1_272758580) /* sqrt(2) * (-c1+c3-c5-c7) */
+ + MULTIPLY((long) wsptr[1], FIX_3_624509785); /* sqrt(2) * (c1+c3+c5+c7) */
+
+ /* Final output stage */
+
+ outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp0,
+ CONST_BITS+PASS1_BITS+3+2)
+ & RANGE_MASK];
+ outptr[1] = range_limit[(int) DESCALE(tmp10 - tmp0,
+ CONST_BITS+PASS1_BITS+3+2)
+ & RANGE_MASK];
+
+ wsptr += DCTSIZE; /* advance pointer to next row */
+ }
+}
+
+
+/*
+ * Perform dequantization and inverse DCT on one block of coefficients,
+ * producing a reduced-size 1x1 output block.
+ */
+
+GLOBAL(void)
+jpeg_idct_1x1 (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col)
+{
+ int dcval;
+ ISLOW_MULT_TYPE * quantptr;
+ JSAMPLE *range_limit = IDCT_range_limit(cinfo);
+ SHIFT_TEMPS
+
+ /* We hardly need an inverse DCT routine for this: just take the
+ * average pixel value, which is one-eighth of the DC coefficient.
+ */
+ quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table;
+ dcval = DEQUANTIZE(coef_block[0], quantptr[0]);
+ dcval = (int) DESCALE((long) dcval, 3);
+
+ output_buf[0][output_col] = range_limit[dcval & RANGE_MASK];
+}
+
+#endif /* IDCT_SCALING_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jinclude.h b/ml/dlib/dlib/external/libjpeg/jinclude.h
new file mode 100644
index 000000000..0a4f15146
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jinclude.h
@@ -0,0 +1,91 @@
+/*
+ * jinclude.h
+ *
+ * Copyright (C) 1991-1994, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file exists to provide a single place to fix any problems with
+ * including the wrong system include files. (Common problems are taken
+ * care of by the standard jconfig symbols, but on really weird systems
+ * you may have to edit this file.)
+ *
+ * NOTE: this file is NOT intended to be included by applications using the
+ * JPEG library. Most applications need only include jpeglib.h.
+ */
+
+
+/* Include auto-config file to find out which system include files we need. */
+
+#include "jconfig.h" /* auto configuration options */
+#define JCONFIG_INCLUDED /* so that jpeglib.h doesn't do it again */
+
+/*
+ * We need the NULL macro and size_t typedef.
+ * On an ANSI-conforming system it is sufficient to include <stddef.h>.
+ * Otherwise, we get them from <stdlib.h> or <stdio.h>; we may have to
+ * pull in <sys/types.h> as well.
+ * Note that the core JPEG library does not require <stdio.h>;
+ * only the default error handler and data source/destination modules do.
+ * But we must pull it in because of the references to FILE in jpeglib.h.
+ * You can remove those references if you want to compile without <stdio.h>.
+ */
+
+#ifdef HAVE_STDDEF_H
+#include <stddef.h>
+#endif
+
+#ifdef HAVE_STDLIB_H
+#include <stdlib.h>
+#endif
+
+#ifdef NEED_SYS_TYPES_H
+#include <sys/types.h>
+#endif
+
+#include <stdio.h>
+
+/*
+ * We need memory copying and zeroing functions, plus strncpy().
+ * ANSI and System V implementations declare these in <string.h>.
+ * BSD doesn't have the mem() functions, but it does have bcopy()/bzero().
+ * Some systems may declare memset and memcpy in <memory.h>.
+ *
+ * NOTE: we assume the size parameters to these functions are of type size_t.
+ * Change the casts in these macros if not!
+ */
+
+#ifdef NEED_BSD_STRINGS
+
+#include <strings.h>
+#define MEMZERO(target,size) bzero((void *)(target), (size_t)(size))
+#define MEMCOPY(dest,src,size) bcopy((const void *)(src), (void *)(dest), (size_t)(size))
+
+#else /* not BSD, assume ANSI/SysV string lib */
+
+#include <string.h>
+#define MEMZERO(target,size) memset((void *)(target), 0, (size_t)(size))
+#define MEMCOPY(dest,src,size) memcpy((void *)(dest), (const void *)(src), (size_t)(size))
+
+#endif
+
+/*
+ * In ANSI C, and indeed any rational implementation, size_t is also the
+ * type returned by sizeof(). However, it seems there are some irrational
+ * implementations out there, in which sizeof() returns an int even though
+ * size_t is defined as long or unsigned long. To ensure consistent results
+ * we always use this SIZEOF() macro in place of using sizeof() directly.
+ */
+
+#define SIZEOF(object) ((size_t) sizeof(object))
+
+/*
+ * The modules that use fread() and fwrite() always invoke them through
+ * these macros. On some systems you may need to twiddle the argument casts.
+ * CAUTION: argument order is different from underlying functions!
+ */
+
+#define JFREAD(file,buf,sizeofbuf) \
+ ((size_t) fread((void *) (buf), (size_t) 1, (size_t) (sizeofbuf), (file)))
+#define JFWRITE(file,buf,sizeofbuf) \
+ ((size_t) fwrite((const void *) (buf), (size_t) 1, (size_t) (sizeofbuf), (file)))
diff --git a/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp b/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp
new file mode 100644
index 000000000..3a2e61955
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp
@@ -0,0 +1,1118 @@
+/*
+ * jmemmgr.c
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains the JPEG system-independent memory management
+ * routines. This code is usable across a wide variety of machines; most
+ * of the system dependencies have been isolated in a separate file.
+ * The major functions provided here are:
+ * * pool-based allocation and freeing of memory;
+ * * policy decisions about how to divide available memory among the
+ * virtual arrays;
+ * * control logic for swapping virtual arrays between main memory and
+ * backing storage.
+ * The separate system-dependent file provides the actual backing-storage
+ * access code, and it contains the policy decision about how much total
+ * main memory to use.
+ * This file is system-dependent in the sense that some of its functions
+ * are unnecessary in some systems. For example, if there is enough virtual
+ * memory so that backing storage will never be used, much of the virtual
+ * array control logic could be removed. (Of course, if you have that much
+ * memory then you shouldn't care about a little bit of unused code...)
+ */
+
+#define JPEG_INTERNALS
+#define AM_MEMORY_MANAGER /* we define jvirt_Xarray_control structs */
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jmemsys.h" /* import the system-dependent declarations */
+
+#ifndef NO_GETENV
+#ifndef HAVE_STDLIB_H /* <stdlib.h> should declare getenv() */
+extern char * getenv JPP((const char * name));
+#endif
+#endif
+
+
+/*
+ * Some important notes:
+ * The allocation routines provided here must never return NULL.
+ * They should exit to error_exit if unsuccessful.
+ *
+ * It's not a good idea to try to merge the sarray and barray routines,
+ * even though they are textually almost the same, because samples are
+ * usually stored as bytes while coefficients are shorts or ints. Thus,
+ * in machines where byte pointers have a different representation from
+ * word pointers, the resulting machine code could not be the same.
+ */
+
+
+/*
+ * Many machines require storage alignment: longs must start on 4-byte
+ * boundaries, doubles on 8-byte boundaries, etc. On such machines, malloc()
+ * always returns pointers that are multiples of the worst-case alignment
+ * requirement, and we had better do so too.
+ * There isn't any really portable way to determine the worst-case alignment
+ * requirement. This module assumes that the alignment requirement is
+ * multiples of sizeof(ALIGN_TYPE).
+ * By default, we define ALIGN_TYPE as double. This is necessary on some
+ * workstations (where doubles really do need 8-byte alignment) and will work
+ * fine on nearly everything. If your machine has lesser alignment needs,
+ * you can save a few bytes by making ALIGN_TYPE smaller.
+ * The only place I know of where this will NOT work is certain Macintosh
+ * 680x0 compilers that define double as a 10-byte IEEE extended float.
+ * Doing 10-byte alignment is counterproductive because longwords won't be
+ * aligned well. Put "#define ALIGN_TYPE long" in jconfig.h if you have
+ * such a compiler.
+ */
+
+#ifndef ALIGN_TYPE /* so can override from jconfig.h */
+#define ALIGN_TYPE double
+#endif
+
+
+/*
+ * We allocate objects from "pools", where each pool is gotten with a single
+ * request to jpeg_get_small() or jpeg_get_large(). There is no per-object
+ * overhead within a pool, except for alignment padding. Each pool has a
+ * header with a link to the next pool of the same class.
+ * Small and large pool headers are identical except that the latter's
+ * link pointer must be FAR on 80x86 machines.
+ * Notice that the "real" header fields are union'ed with a dummy ALIGN_TYPE
+ * field. This forces the compiler to make SIZEOF(small_pool_hdr) a multiple
+ * of the alignment requirement of ALIGN_TYPE.
+ */
+
+typedef union small_pool_struct * small_pool_ptr;
+
+typedef union small_pool_struct {
+ struct {
+ small_pool_ptr next; /* next in list of pools */
+ size_t bytes_used; /* how many bytes already used within pool */
+ size_t bytes_left; /* bytes still available in this pool */
+ } hdr;
+ ALIGN_TYPE dummy; /* included in union to ensure alignment */
+} small_pool_hdr;
+
+typedef union large_pool_struct FAR * large_pool_ptr;
+
+typedef union large_pool_struct {
+ struct {
+ large_pool_ptr next; /* next in list of pools */
+ size_t bytes_used; /* how many bytes already used within pool */
+ size_t bytes_left; /* bytes still available in this pool */
+ } hdr;
+ ALIGN_TYPE dummy; /* included in union to ensure alignment */
+} large_pool_hdr;
+
+
+/*
+ * Here is the full definition of a memory manager object.
+ */
+
+typedef struct {
+ struct jpeg_memory_mgr pub; /* public fields */
+
+ /* Each pool identifier (lifetime class) names a linked list of pools. */
+ small_pool_ptr small_list[JPOOL_NUMPOOLS];
+ large_pool_ptr large_list[JPOOL_NUMPOOLS];
+
+ /* Since we only have one lifetime class of virtual arrays, only one
+ * linked list is necessary (for each datatype). Note that the virtual
+ * array control blocks being linked together are actually stored somewhere
+ * in the small-pool list.
+ */
+ jvirt_sarray_ptr virt_sarray_list;
+ jvirt_barray_ptr virt_barray_list;
+
+ /* This counts total space obtained from jpeg_get_small/large */
+ long total_space_allocated;
+
+ /* alloc_sarray and alloc_barray set this value for use by virtual
+ * array routines.
+ */
+ JDIMENSION last_rowsperchunk; /* from most recent alloc_sarray/barray */
+} my_memory_mgr;
+
+typedef my_memory_mgr * my_mem_ptr;
+
+
+/*
+ * The control blocks for virtual arrays.
+ * Note that these blocks are allocated in the "small" pool area.
+ * System-dependent info for the associated backing store (if any) is hidden
+ * inside the backing_store_info struct.
+ */
+
+struct jvirt_sarray_control {
+ JSAMPARRAY mem_buffer; /* => the in-memory buffer */
+ JDIMENSION rows_in_array; /* total virtual array height */
+ JDIMENSION samplesperrow; /* width of array (and of memory buffer) */
+ JDIMENSION maxaccess; /* max rows accessed by access_virt_sarray */
+ JDIMENSION rows_in_mem; /* height of memory buffer */
+ JDIMENSION rowsperchunk; /* allocation chunk size in mem_buffer */
+ JDIMENSION cur_start_row; /* first logical row # in the buffer */
+ JDIMENSION first_undef_row; /* row # of first uninitialized row */
+ int pre_zero; /* pre-zero mode requested? */
+ int dirty; /* do current buffer contents need written? */
+ int b_s_open; /* is backing-store data valid? */
+ jvirt_sarray_ptr next; /* link to next virtual sarray control block */
+ backing_store_info b_s_info; /* System-dependent control info */
+};
+
+struct jvirt_barray_control {
+ JBLOCKARRAY mem_buffer; /* => the in-memory buffer */
+ JDIMENSION rows_in_array; /* total virtual array height */
+ JDIMENSION blocksperrow; /* width of array (and of memory buffer) */
+ JDIMENSION maxaccess; /* max rows accessed by access_virt_barray */
+ JDIMENSION rows_in_mem; /* height of memory buffer */
+ JDIMENSION rowsperchunk; /* allocation chunk size in mem_buffer */
+ JDIMENSION cur_start_row; /* first logical row # in the buffer */
+ JDIMENSION first_undef_row; /* row # of first uninitialized row */
+ int pre_zero; /* pre-zero mode requested? */
+ int dirty; /* do current buffer contents need written? */
+ int b_s_open; /* is backing-store data valid? */
+ jvirt_barray_ptr next; /* link to next virtual barray control block */
+ backing_store_info b_s_info; /* System-dependent control info */
+};
+
+
+#ifdef MEM_STATS /* optional extra stuff for statistics */
+
+LOCAL(void)
+print_mem_stats (j_common_ptr cinfo, int pool_id)
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ small_pool_ptr shdr_ptr;
+ large_pool_ptr lhdr_ptr;
+
+ /* Since this is only a debugging stub, we can cheat a little by using
+ * fprintf directly rather than going through the trace message code.
+ * This is helpful because message parm array can't handle longs.
+ */
+ fprintf(stderr, "Freeing pool %d, total space = %ld\n",
+ pool_id, mem->total_space_allocated);
+
+ for (lhdr_ptr = mem->large_list[pool_id]; lhdr_ptr != NULL;
+ lhdr_ptr = lhdr_ptr->hdr.next) {
+ fprintf(stderr, " Large chunk used %ld\n",
+ (long) lhdr_ptr->hdr.bytes_used);
+ }
+
+ for (shdr_ptr = mem->small_list[pool_id]; shdr_ptr != NULL;
+ shdr_ptr = shdr_ptr->hdr.next) {
+ fprintf(stderr, " Small chunk used %ld free %ld\n",
+ (long) shdr_ptr->hdr.bytes_used,
+ (long) shdr_ptr->hdr.bytes_left);
+ }
+}
+
+#endif /* MEM_STATS */
+
+
+LOCAL(void)
+out_of_memory (j_common_ptr cinfo, int which)
+/* Report an out-of-memory error and stop execution */
+/* If we compiled MEM_STATS support, report alloc requests before dying */
+{
+#ifdef MEM_STATS
+ cinfo->err->trace_level = 2; /* force self_destruct to report stats */
+#endif
+ ERREXIT1(cinfo, JERR_OUT_OF_MEMORY, which);
+}
+
+
+/*
+ * Allocation of "small" objects.
+ *
+ * For these, we use pooled storage. When a new pool must be created,
+ * we try to get enough space for the current request plus a "slop" factor,
+ * where the slop will be the amount of leftover space in the new pool.
+ * The speed vs. space tradeoff is largely determined by the slop values.
+ * A different slop value is provided for each pool class (lifetime),
+ * and we also distinguish the first pool of a class from later ones.
+ * NOTE: the values given work fairly well on both 16- and 32-bit-int
+ * machines, but may be too small if longs are 64 bits or more.
+ */
+
+static const size_t first_pool_slop[JPOOL_NUMPOOLS] =
+{
+ 1600, /* first PERMANENT pool */
+ 16000 /* first IMAGE pool */
+};
+
+static const size_t extra_pool_slop[JPOOL_NUMPOOLS] =
+{
+ 0, /* additional PERMANENT pools */
+ 5000 /* additional IMAGE pools */
+};
+
+#define MIN_SLOP 50 /* greater than 0 to avoid futile looping */
+
+
+METHODDEF(void *)
+alloc_small (j_common_ptr cinfo, int pool_id, size_t sizeofobject)
+/* Allocate a "small" object */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ small_pool_ptr hdr_ptr, prev_hdr_ptr;
+ char * data_ptr;
+ size_t odd_bytes, min_request, slop;
+
+ /* Check for unsatisfiable request (do now to ensure no overflow below) */
+ if (sizeofobject > (size_t) (MAX_ALLOC_CHUNK-SIZEOF(small_pool_hdr)))
+ out_of_memory(cinfo, 1); /* request exceeds malloc's ability */
+
+ /* Round up the requested size to a multiple of SIZEOF(ALIGN_TYPE) */
+ odd_bytes = sizeofobject % SIZEOF(ALIGN_TYPE);
+ if (odd_bytes > 0)
+ sizeofobject += SIZEOF(ALIGN_TYPE) - odd_bytes;
+
+ /* See if space is available in any existing pool */
+ if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS)
+ ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */
+ prev_hdr_ptr = NULL;
+ hdr_ptr = mem->small_list[pool_id];
+ while (hdr_ptr != NULL) {
+ if (hdr_ptr->hdr.bytes_left >= sizeofobject)
+ break; /* found pool with enough space */
+ prev_hdr_ptr = hdr_ptr;
+ hdr_ptr = hdr_ptr->hdr.next;
+ }
+
+ /* Time to make a new pool? */
+ if (hdr_ptr == NULL) {
+ /* min_request is what we need now, slop is what will be leftover */
+ min_request = sizeofobject + SIZEOF(small_pool_hdr);
+ if (prev_hdr_ptr == NULL) /* first pool in class? */
+ slop = first_pool_slop[pool_id];
+ else
+ slop = extra_pool_slop[pool_id];
+ /* Don't ask for more than MAX_ALLOC_CHUNK */
+ if (slop > (size_t) (MAX_ALLOC_CHUNK-min_request))
+ slop = (size_t) (MAX_ALLOC_CHUNK-min_request);
+ /* Try to get space, if fail reduce slop and try again */
+ for (;;) {
+ hdr_ptr = (small_pool_ptr) jpeg_get_small(cinfo, min_request + slop);
+ if (hdr_ptr != NULL)
+ break;
+ slop /= 2;
+ if (slop < MIN_SLOP) /* give up when it gets real small */
+ out_of_memory(cinfo, 2); /* jpeg_get_small failed */
+ }
+ mem->total_space_allocated += (long)(min_request + slop);
+ /* Success, initialize the new pool header and add to end of list */
+ hdr_ptr->hdr.next = NULL;
+ hdr_ptr->hdr.bytes_used = 0;
+ hdr_ptr->hdr.bytes_left = sizeofobject + slop;
+ if (prev_hdr_ptr == NULL) /* first pool in class? */
+ mem->small_list[pool_id] = hdr_ptr;
+ else
+ prev_hdr_ptr->hdr.next = hdr_ptr;
+ }
+
+ /* OK, allocate the object from the current pool */
+ data_ptr = (char *) (hdr_ptr + 1); /* point to first data byte in pool */
+ data_ptr += hdr_ptr->hdr.bytes_used; /* point to place for object */
+ hdr_ptr->hdr.bytes_used += sizeofobject;
+ hdr_ptr->hdr.bytes_left -= sizeofobject;
+
+ return (void *) data_ptr;
+}
+
+
+/*
+ * Allocation of "large" objects.
+ *
+ * The external semantics of these are the same as "small" objects,
+ * except that FAR pointers are used on 80x86. However the pool
+ * management heuristics are quite different. We assume that each
+ * request is large enough that it may as well be passed directly to
+ * jpeg_get_large; the pool management just links everything together
+ * so that we can free it all on demand.
+ * Note: the major use of "large" objects is in JSAMPARRAY and JBLOCKARRAY
+ * structures. The routines that create these structures (see below)
+ * deliberately bunch rows together to ensure a large request size.
+ */
+
+METHODDEF(void FAR *)
+alloc_large (j_common_ptr cinfo, int pool_id, size_t sizeofobject)
+/* Allocate a "large" object */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ large_pool_ptr hdr_ptr;
+ size_t odd_bytes;
+
+ /* Check for unsatisfiable request (do now to ensure no overflow below) */
+ if (sizeofobject > (size_t) (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr)))
+ out_of_memory(cinfo, 3); /* request exceeds malloc's ability */
+
+ /* Round up the requested size to a multiple of SIZEOF(ALIGN_TYPE) */
+ odd_bytes = sizeofobject % SIZEOF(ALIGN_TYPE);
+ if (odd_bytes > 0)
+ sizeofobject += SIZEOF(ALIGN_TYPE) - odd_bytes;
+
+ /* Always make a new pool */
+ if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS)
+ ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */
+
+ hdr_ptr = (large_pool_ptr) jpeg_get_large(cinfo, sizeofobject +
+ SIZEOF(large_pool_hdr));
+ if (hdr_ptr == NULL)
+ out_of_memory(cinfo, 4); /* jpeg_get_large failed */
+ mem->total_space_allocated += (long)(sizeofobject + SIZEOF(large_pool_hdr));
+
+ /* Success, initialize the new pool header and add to list */
+ hdr_ptr->hdr.next = mem->large_list[pool_id];
+ /* We maintain space counts in each pool header for statistical purposes,
+ * even though they are not needed for allocation.
+ */
+ hdr_ptr->hdr.bytes_used = sizeofobject;
+ hdr_ptr->hdr.bytes_left = 0;
+ mem->large_list[pool_id] = hdr_ptr;
+
+ return (void FAR *) (hdr_ptr + 1); /* point to first data byte in pool */
+}
+
+
+/*
+ * Creation of 2-D sample arrays.
+ * The pointers are in near heap, the samples themselves in FAR heap.
+ *
+ * To minimize allocation overhead and to allow I/O of large contiguous
+ * blocks, we allocate the sample rows in groups of as many rows as possible
+ * without exceeding MAX_ALLOC_CHUNK total bytes per allocation request.
+ * NB: the virtual array control routines, later in this file, know about
+ * this chunking of rows. The rowsperchunk value is left in the mem manager
+ * object so that it can be saved away if this sarray is the workspace for
+ * a virtual array.
+ */
+
+METHODDEF(JSAMPARRAY)
+alloc_sarray (j_common_ptr cinfo, int pool_id,
+ JDIMENSION samplesperrow, JDIMENSION numrows)
+/* Allocate a 2-D sample array */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ JSAMPARRAY result;
+ JSAMPROW workspace;
+ JDIMENSION rowsperchunk, currow, i;
+ long ltemp;
+
+ /* Calculate max # of rows allowed in one allocation chunk */
+ ltemp = (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr)) /
+ ((long) samplesperrow * SIZEOF(JSAMPLE));
+ if (ltemp <= 0)
+ ERREXIT(cinfo, JERR_WIDTH_OVERFLOW);
+ if (ltemp < (long) numrows)
+ rowsperchunk = (JDIMENSION) ltemp;
+ else
+ rowsperchunk = numrows;
+ mem->last_rowsperchunk = rowsperchunk;
+
+ /* Get space for row pointers (small object) */
+ result = (JSAMPARRAY) alloc_small(cinfo, pool_id,
+ (size_t) (numrows * SIZEOF(JSAMPROW)));
+
+ /* Get the rows themselves (large objects) */
+ currow = 0;
+ while (currow < numrows) {
+ rowsperchunk = MIN(rowsperchunk, numrows - currow);
+ workspace = (JSAMPROW) alloc_large(cinfo, pool_id,
+ (size_t) ((size_t) rowsperchunk * (size_t) samplesperrow
+ * SIZEOF(JSAMPLE)));
+ for (i = rowsperchunk; i > 0; i--) {
+ result[currow++] = workspace;
+ workspace += samplesperrow;
+ }
+ }
+
+ return result;
+}
+
+
+/*
+ * Creation of 2-D coefficient-block arrays.
+ * This is essentially the same as the code for sample arrays, above.
+ */
+
+METHODDEF(JBLOCKARRAY)
+alloc_barray (j_common_ptr cinfo, int pool_id,
+ JDIMENSION blocksperrow, JDIMENSION numrows)
+/* Allocate a 2-D coefficient-block array */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ JBLOCKARRAY result;
+ JBLOCKROW workspace;
+ JDIMENSION rowsperchunk, currow, i;
+ long ltemp;
+
+ /* Calculate max # of rows allowed in one allocation chunk */
+ ltemp = (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr)) /
+ ((long) blocksperrow * SIZEOF(JBLOCK));
+ if (ltemp <= 0)
+ ERREXIT(cinfo, JERR_WIDTH_OVERFLOW);
+ if (ltemp < (long) numrows)
+ rowsperchunk = (JDIMENSION) ltemp;
+ else
+ rowsperchunk = numrows;
+ mem->last_rowsperchunk = rowsperchunk;
+
+ /* Get space for row pointers (small object) */
+ result = (JBLOCKARRAY) alloc_small(cinfo, pool_id,
+ (size_t) (numrows * SIZEOF(JBLOCKROW)));
+
+ /* Get the rows themselves (large objects) */
+ currow = 0;
+ while (currow < numrows) {
+ rowsperchunk = MIN(rowsperchunk, numrows - currow);
+ workspace = (JBLOCKROW) alloc_large(cinfo, pool_id,
+ (size_t) ((size_t) rowsperchunk * (size_t) blocksperrow
+ * SIZEOF(JBLOCK)));
+ for (i = rowsperchunk; i > 0; i--) {
+ result[currow++] = workspace;
+ workspace += blocksperrow;
+ }
+ }
+
+ return result;
+}
+
+
+/*
+ * About virtual array management:
+ *
+ * The above "normal" array routines are only used to allocate strip buffers
+ * (as wide as the image, but just a few rows high). Full-image-sized buffers
+ * are handled as "virtual" arrays. The array is still accessed a strip at a
+ * time, but the memory manager must save the whole array for repeated
+ * accesses. The intended implementation is that there is a strip buffer in
+ * memory (as high as is possible given the desired memory limit), plus a
+ * backing file that holds the rest of the array.
+ *
+ * The request_virt_array routines are told the total size of the image and
+ * the maximum number of rows that will be accessed at once. The in-memory
+ * buffer must be at least as large as the maxaccess value.
+ *
+ * The request routines create control blocks but not the in-memory buffers.
+ * That is postponed until realize_virt_arrays is called. At that time the
+ * total amount of space needed is known (approximately, anyway), so free
+ * memory can be divided up fairly.
+ *
+ * The access_virt_array routines are responsible for making a specific strip
+ * area accessible (after reading or writing the backing file, if necessary).
+ * Note that the access routines are told whether the caller intends to modify
+ * the accessed strip; during a read-only pass this saves having to rewrite
+ * data to disk. The access routines are also responsible for pre-zeroing
+ * any newly accessed rows, if pre-zeroing was requested.
+ *
+ * In current usage, the access requests are usually for nonoverlapping
+ * strips; that is, successive access start_row numbers differ by exactly
+ * num_rows = maxaccess. This means we can get good performance with simple
+ * buffer dump/reload logic, by making the in-memory buffer be a multiple
+ * of the access height; then there will never be accesses across bufferload
+ * boundaries. The code will still work with overlapping access requests,
+ * but it doesn't handle bufferload overlaps very efficiently.
+ */
+
+
+METHODDEF(jvirt_sarray_ptr)
+request_virt_sarray (j_common_ptr cinfo, int pool_id, int pre_zero,
+ JDIMENSION samplesperrow, JDIMENSION numrows,
+ JDIMENSION maxaccess)
+/* Request a virtual 2-D sample array */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ jvirt_sarray_ptr result;
+
+ /* Only IMAGE-lifetime virtual arrays are currently supported */
+ if (pool_id != JPOOL_IMAGE)
+ ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */
+
+ /* get control block */
+ result = (jvirt_sarray_ptr) alloc_small(cinfo, pool_id,
+ SIZEOF(struct jvirt_sarray_control));
+
+ result->mem_buffer = NULL; /* marks array not yet realized */
+ result->rows_in_array = numrows;
+ result->samplesperrow = samplesperrow;
+ result->maxaccess = maxaccess;
+ result->pre_zero = pre_zero;
+ result->b_s_open = FALSE; /* no associated backing-store object */
+ result->next = mem->virt_sarray_list; /* add to list of virtual arrays */
+ mem->virt_sarray_list = result;
+
+ return result;
+}
+
+
+METHODDEF(jvirt_barray_ptr)
+request_virt_barray (j_common_ptr cinfo, int pool_id, int pre_zero,
+ JDIMENSION blocksperrow, JDIMENSION numrows,
+ JDIMENSION maxaccess)
+/* Request a virtual 2-D coefficient-block array */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ jvirt_barray_ptr result;
+
+ /* Only IMAGE-lifetime virtual arrays are currently supported */
+ if (pool_id != JPOOL_IMAGE)
+ ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */
+
+ /* get control block */
+ result = (jvirt_barray_ptr) alloc_small(cinfo, pool_id,
+ SIZEOF(struct jvirt_barray_control));
+
+ result->mem_buffer = NULL; /* marks array not yet realized */
+ result->rows_in_array = numrows;
+ result->blocksperrow = blocksperrow;
+ result->maxaccess = maxaccess;
+ result->pre_zero = pre_zero;
+ result->b_s_open = FALSE; /* no associated backing-store object */
+ result->next = mem->virt_barray_list; /* add to list of virtual arrays */
+ mem->virt_barray_list = result;
+
+ return result;
+}
+
+
+METHODDEF(void)
+realize_virt_arrays (j_common_ptr cinfo)
+/* Allocate the in-memory buffers for any unrealized virtual arrays */
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ long space_per_minheight, maximum_space, avail_mem;
+ long minheights, max_minheights;
+ jvirt_sarray_ptr sptr;
+ jvirt_barray_ptr bptr;
+
+ /* Compute the minimum space needed (maxaccess rows in each buffer)
+ * and the maximum space needed (full image height in each buffer).
+ * These may be of use to the system-dependent jpeg_mem_available routine.
+ */
+ space_per_minheight = 0;
+ maximum_space = 0;
+ for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) {
+ if (sptr->mem_buffer == NULL) { /* if not realized yet */
+ space_per_minheight += (long) sptr->maxaccess *
+ (long) sptr->samplesperrow * SIZEOF(JSAMPLE);
+ maximum_space += (long) sptr->rows_in_array *
+ (long) sptr->samplesperrow * SIZEOF(JSAMPLE);
+ }
+ }
+ for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) {
+ if (bptr->mem_buffer == NULL) { /* if not realized yet */
+ space_per_minheight += (long) bptr->maxaccess *
+ (long) bptr->blocksperrow * SIZEOF(JBLOCK);
+ maximum_space += (long) bptr->rows_in_array *
+ (long) bptr->blocksperrow * SIZEOF(JBLOCK);
+ }
+ }
+
+ if (space_per_minheight <= 0)
+ return; /* no unrealized arrays, no work */
+
+ /* Determine amount of memory to actually use; this is system-dependent. */
+ avail_mem = jpeg_mem_available(cinfo, space_per_minheight, maximum_space,
+ mem->total_space_allocated);
+
+ /* If the maximum space needed is available, make all the buffers full
+ * height; otherwise parcel it out with the same number of minheights
+ * in each buffer.
+ */
+ if (avail_mem >= maximum_space)
+ max_minheights = 1000000000L;
+ else {
+ max_minheights = avail_mem / space_per_minheight;
+ /* If there doesn't seem to be enough space, try to get the minimum
+ * anyway. This allows a "stub" implementation of jpeg_mem_available().
+ */
+ if (max_minheights <= 0)
+ max_minheights = 1;
+ }
+
+ /* Allocate the in-memory buffers and initialize backing store as needed. */
+
+ for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) {
+ if (sptr->mem_buffer == NULL) { /* if not realized yet */
+ minheights = ((long) sptr->rows_in_array - 1L) / sptr->maxaccess + 1L;
+ if (minheights <= max_minheights) {
+ /* This buffer fits in memory */
+ sptr->rows_in_mem = sptr->rows_in_array;
+ } else {
+ /* It doesn't fit in memory, create backing store. */
+ sptr->rows_in_mem = (JDIMENSION) (max_minheights * sptr->maxaccess);
+ jpeg_open_backing_store(cinfo, & sptr->b_s_info,
+ (long) sptr->rows_in_array *
+ (long) sptr->samplesperrow *
+ (long) SIZEOF(JSAMPLE));
+ sptr->b_s_open = TRUE;
+ }
+ sptr->mem_buffer = alloc_sarray(cinfo, JPOOL_IMAGE,
+ sptr->samplesperrow, sptr->rows_in_mem);
+ sptr->rowsperchunk = mem->last_rowsperchunk;
+ sptr->cur_start_row = 0;
+ sptr->first_undef_row = 0;
+ sptr->dirty = FALSE;
+ }
+ }
+
+ for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) {
+ if (bptr->mem_buffer == NULL) { /* if not realized yet */
+ minheights = ((long) bptr->rows_in_array - 1L) / bptr->maxaccess + 1L;
+ if (minheights <= max_minheights) {
+ /* This buffer fits in memory */
+ bptr->rows_in_mem = bptr->rows_in_array;
+ } else {
+ /* It doesn't fit in memory, create backing store. */
+ bptr->rows_in_mem = (JDIMENSION) (max_minheights * bptr->maxaccess);
+ jpeg_open_backing_store(cinfo, & bptr->b_s_info,
+ (long) bptr->rows_in_array *
+ (long) bptr->blocksperrow *
+ (long) SIZEOF(JBLOCK));
+ bptr->b_s_open = TRUE;
+ }
+ bptr->mem_buffer = alloc_barray(cinfo, JPOOL_IMAGE,
+ bptr->blocksperrow, bptr->rows_in_mem);
+ bptr->rowsperchunk = mem->last_rowsperchunk;
+ bptr->cur_start_row = 0;
+ bptr->first_undef_row = 0;
+ bptr->dirty = FALSE;
+ }
+ }
+}
+
+
+LOCAL(void)
+do_sarray_io (j_common_ptr cinfo, jvirt_sarray_ptr ptr, int writing)
+/* Do backing store read or write of a virtual sample array */
+{
+ long bytesperrow, file_offset, byte_count, rows, thisrow, i;
+
+ bytesperrow = (long) ptr->samplesperrow * SIZEOF(JSAMPLE);
+ file_offset = ptr->cur_start_row * bytesperrow;
+ /* Loop to read or write each allocation chunk in mem_buffer */
+ for (i = 0; i < (long) ptr->rows_in_mem; i += ptr->rowsperchunk) {
+ /* One chunk, but check for short chunk at end of buffer */
+ rows = MIN((long) ptr->rowsperchunk, (long) ptr->rows_in_mem - i);
+ /* Transfer no more than is currently defined */
+ thisrow = (long) ptr->cur_start_row + i;
+ rows = MIN(rows, (long) ptr->first_undef_row - thisrow);
+ /* Transfer no more than fits in file */
+ rows = MIN(rows, (long) ptr->rows_in_array - thisrow);
+ if (rows <= 0) /* this chunk might be past end of file! */
+ break;
+ byte_count = rows * bytesperrow;
+ if (writing)
+ (*ptr->b_s_info.write_backing_store) (cinfo, & ptr->b_s_info,
+ (void FAR *) ptr->mem_buffer[i],
+ file_offset, byte_count);
+ else
+ (*ptr->b_s_info.read_backing_store) (cinfo, & ptr->b_s_info,
+ (void FAR *) ptr->mem_buffer[i],
+ file_offset, byte_count);
+ file_offset += byte_count;
+ }
+}
+
+
+LOCAL(void)
+do_barray_io (j_common_ptr cinfo, jvirt_barray_ptr ptr, int writing)
+/* Do backing store read or write of a virtual coefficient-block array */
+{
+ long bytesperrow, file_offset, byte_count, rows, thisrow, i;
+
+ bytesperrow = (long) ptr->blocksperrow * SIZEOF(JBLOCK);
+ file_offset = ptr->cur_start_row * bytesperrow;
+ /* Loop to read or write each allocation chunk in mem_buffer */
+ for (i = 0; i < (long) ptr->rows_in_mem; i += ptr->rowsperchunk) {
+ /* One chunk, but check for short chunk at end of buffer */
+ rows = MIN((long) ptr->rowsperchunk, (long) ptr->rows_in_mem - i);
+ /* Transfer no more than is currently defined */
+ thisrow = (long) ptr->cur_start_row + i;
+ rows = MIN(rows, (long) ptr->first_undef_row - thisrow);
+ /* Transfer no more than fits in file */
+ rows = MIN(rows, (long) ptr->rows_in_array - thisrow);
+ if (rows <= 0) /* this chunk might be past end of file! */
+ break;
+ byte_count = rows * bytesperrow;
+ if (writing)
+ (*ptr->b_s_info.write_backing_store) (cinfo, & ptr->b_s_info,
+ (void FAR *) ptr->mem_buffer[i],
+ file_offset, byte_count);
+ else
+ (*ptr->b_s_info.read_backing_store) (cinfo, & ptr->b_s_info,
+ (void FAR *) ptr->mem_buffer[i],
+ file_offset, byte_count);
+ file_offset += byte_count;
+ }
+}
+
+
+METHODDEF(JSAMPARRAY)
+access_virt_sarray (j_common_ptr cinfo, jvirt_sarray_ptr ptr,
+ JDIMENSION start_row, JDIMENSION num_rows,
+ int writable)
+/* Access the part of a virtual sample array starting at start_row */
+/* and extending for num_rows rows. writable is true if */
+/* caller intends to modify the accessed area. */
+{
+ JDIMENSION end_row = start_row + num_rows;
+ JDIMENSION undef_row;
+
+ /* debugging check */
+ if (end_row > ptr->rows_in_array || num_rows > ptr->maxaccess ||
+ ptr->mem_buffer == NULL)
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+
+ /* Make the desired part of the virtual array accessible */
+ if (start_row < ptr->cur_start_row ||
+ end_row > ptr->cur_start_row+ptr->rows_in_mem) {
+ if (! ptr->b_s_open)
+ ERREXIT(cinfo, JERR_VIRTUAL_BUG);
+ /* Flush old buffer contents if necessary */
+ if (ptr->dirty) {
+ do_sarray_io(cinfo, ptr, TRUE);
+ ptr->dirty = FALSE;
+ }
+ /* Decide what part of virtual array to access.
+ * Algorithm: if target address > current window, assume forward scan,
+ * load starting at target address. If target address < current window,
+ * assume backward scan, load so that target area is top of window.
+ * Note that when switching from forward write to forward read, will have
+ * start_row = 0, so the limiting case applies and we load from 0 anyway.
+ */
+ if (start_row > ptr->cur_start_row) {
+ ptr->cur_start_row = start_row;
+ } else {
+ /* use long arithmetic here to avoid overflow & unsigned problems */
+ long ltemp;
+
+ ltemp = (long) end_row - (long) ptr->rows_in_mem;
+ if (ltemp < 0)
+ ltemp = 0; /* don't fall off front end of file */
+ ptr->cur_start_row = (JDIMENSION) ltemp;
+ }
+ /* Read in the selected part of the array.
+ * During the initial write pass, we will do no actual read
+ * because the selected part is all undefined.
+ */
+ do_sarray_io(cinfo, ptr, FALSE);
+ }
+ /* Ensure the accessed part of the array is defined; prezero if needed.
+ * To improve locality of access, we only prezero the part of the array
+ * that the caller is about to access, not the entire in-memory array.
+ */
+ if (ptr->first_undef_row < end_row) {
+ if (ptr->first_undef_row < start_row) {
+ if (writable) /* writer skipped over a section of array */
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+ undef_row = start_row; /* but reader is allowed to read ahead */
+ } else {
+ undef_row = ptr->first_undef_row;
+ }
+ if (writable)
+ ptr->first_undef_row = end_row;
+ if (ptr->pre_zero) {
+ size_t bytesperrow = (size_t) ptr->samplesperrow * SIZEOF(JSAMPLE);
+ undef_row -= ptr->cur_start_row; /* make indexes relative to buffer */
+ end_row -= ptr->cur_start_row;
+ while (undef_row < end_row) {
+ jzero_far((void FAR *) ptr->mem_buffer[undef_row], bytesperrow);
+ undef_row++;
+ }
+ } else {
+ if (! writable) /* reader looking at undefined data */
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+ }
+ }
+ /* Flag the buffer dirty if caller will write in it */
+ if (writable)
+ ptr->dirty = TRUE;
+ /* Return address of proper part of the buffer */
+ return ptr->mem_buffer + (start_row - ptr->cur_start_row);
+}
+
+
+METHODDEF(JBLOCKARRAY)
+access_virt_barray (j_common_ptr cinfo, jvirt_barray_ptr ptr,
+ JDIMENSION start_row, JDIMENSION num_rows,
+ int writable)
+/* Access the part of a virtual block array starting at start_row */
+/* and extending for num_rows rows. writable is true if */
+/* caller intends to modify the accessed area. */
+{
+ JDIMENSION end_row = start_row + num_rows;
+ JDIMENSION undef_row;
+
+ /* debugging check */
+ if (end_row > ptr->rows_in_array || num_rows > ptr->maxaccess ||
+ ptr->mem_buffer == NULL)
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+
+ /* Make the desired part of the virtual array accessible */
+ if (start_row < ptr->cur_start_row ||
+ end_row > ptr->cur_start_row+ptr->rows_in_mem) {
+ if (! ptr->b_s_open)
+ ERREXIT(cinfo, JERR_VIRTUAL_BUG);
+ /* Flush old buffer contents if necessary */
+ if (ptr->dirty) {
+ do_barray_io(cinfo, ptr, TRUE);
+ ptr->dirty = FALSE;
+ }
+ /* Decide what part of virtual array to access.
+ * Algorithm: if target address > current window, assume forward scan,
+ * load starting at target address. If target address < current window,
+ * assume backward scan, load so that target area is top of window.
+ * Note that when switching from forward write to forward read, will have
+ * start_row = 0, so the limiting case applies and we load from 0 anyway.
+ */
+ if (start_row > ptr->cur_start_row) {
+ ptr->cur_start_row = start_row;
+ } else {
+ /* use long arithmetic here to avoid overflow & unsigned problems */
+ long ltemp;
+
+ ltemp = (long) end_row - (long) ptr->rows_in_mem;
+ if (ltemp < 0)
+ ltemp = 0; /* don't fall off front end of file */
+ ptr->cur_start_row = (JDIMENSION) ltemp;
+ }
+ /* Read in the selected part of the array.
+ * During the initial write pass, we will do no actual read
+ * because the selected part is all undefined.
+ */
+ do_barray_io(cinfo, ptr, FALSE);
+ }
+ /* Ensure the accessed part of the array is defined; prezero if needed.
+ * To improve locality of access, we only prezero the part of the array
+ * that the caller is about to access, not the entire in-memory array.
+ */
+ if (ptr->first_undef_row < end_row) {
+ if (ptr->first_undef_row < start_row) {
+ if (writable) /* writer skipped over a section of array */
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+ undef_row = start_row; /* but reader is allowed to read ahead */
+ } else {
+ undef_row = ptr->first_undef_row;
+ }
+ if (writable)
+ ptr->first_undef_row = end_row;
+ if (ptr->pre_zero) {
+ size_t bytesperrow = (size_t) ptr->blocksperrow * SIZEOF(JBLOCK);
+ undef_row -= ptr->cur_start_row; /* make indexes relative to buffer */
+ end_row -= ptr->cur_start_row;
+ while (undef_row < end_row) {
+ jzero_far((void FAR *) ptr->mem_buffer[undef_row], bytesperrow);
+ undef_row++;
+ }
+ } else {
+ if (! writable) /* reader looking at undefined data */
+ ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS);
+ }
+ }
+ /* Flag the buffer dirty if caller will write in it */
+ if (writable)
+ ptr->dirty = TRUE;
+ /* Return address of proper part of the buffer */
+ return ptr->mem_buffer + (start_row - ptr->cur_start_row);
+}
+
+
+/*
+ * Release all objects belonging to a specified pool.
+ */
+
+METHODDEF(void)
+free_pool (j_common_ptr cinfo, int pool_id)
+{
+ my_mem_ptr mem = (my_mem_ptr) cinfo->mem;
+ small_pool_ptr shdr_ptr;
+ large_pool_ptr lhdr_ptr;
+ size_t space_freed;
+
+ if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS)
+ ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */
+
+#ifdef MEM_STATS
+ if (cinfo->err->trace_level > 1)
+ print_mem_stats(cinfo, pool_id); /* print pool's memory usage statistics */
+#endif
+
+ /* If freeing IMAGE pool, close any virtual arrays first */
+ if (pool_id == JPOOL_IMAGE) {
+ jvirt_sarray_ptr sptr;
+ jvirt_barray_ptr bptr;
+
+ for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) {
+ if (sptr->b_s_open) { /* there may be no backing store */
+ sptr->b_s_open = FALSE; /* prevent recursive close if error */
+ (*sptr->b_s_info.close_backing_store) (cinfo, & sptr->b_s_info);
+ }
+ }
+ mem->virt_sarray_list = NULL;
+ for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) {
+ if (bptr->b_s_open) { /* there may be no backing store */
+ bptr->b_s_open = FALSE; /* prevent recursive close if error */
+ (*bptr->b_s_info.close_backing_store) (cinfo, & bptr->b_s_info);
+ }
+ }
+ mem->virt_barray_list = NULL;
+ }
+
+ /* Release large objects */
+ lhdr_ptr = mem->large_list[pool_id];
+ mem->large_list[pool_id] = NULL;
+
+ while (lhdr_ptr != NULL) {
+ large_pool_ptr next_lhdr_ptr = lhdr_ptr->hdr.next;
+ space_freed = lhdr_ptr->hdr.bytes_used +
+ lhdr_ptr->hdr.bytes_left +
+ SIZEOF(large_pool_hdr);
+ jpeg_free_large(cinfo, (void FAR *) lhdr_ptr, space_freed);
+ mem->total_space_allocated -= (long)space_freed;
+ lhdr_ptr = next_lhdr_ptr;
+ }
+
+ /* Release small objects */
+ shdr_ptr = mem->small_list[pool_id];
+ mem->small_list[pool_id] = NULL;
+
+ while (shdr_ptr != NULL) {
+ small_pool_ptr next_shdr_ptr = shdr_ptr->hdr.next;
+ space_freed = shdr_ptr->hdr.bytes_used +
+ shdr_ptr->hdr.bytes_left +
+ SIZEOF(small_pool_hdr);
+ jpeg_free_small(cinfo, (void *) shdr_ptr, space_freed);
+ mem->total_space_allocated -= (long)space_freed;
+ shdr_ptr = next_shdr_ptr;
+ }
+}
+
+
+/*
+ * Close up shop entirely.
+ * Note that this cannot be called unless cinfo->mem is non-NULL.
+ */
+
+METHODDEF(void)
+self_destruct (j_common_ptr cinfo)
+{
+ int pool;
+
+ /* Close all backing store, release all memory.
+ * Releasing pools in reverse order might help avoid fragmentation
+ * with some (brain-damaged) malloc libraries.
+ */
+ for (pool = JPOOL_NUMPOOLS-1; pool >= JPOOL_PERMANENT; pool--) {
+ free_pool(cinfo, pool);
+ }
+
+ /* Release the memory manager control block too. */
+ jpeg_free_small(cinfo, (void *) cinfo->mem, SIZEOF(my_memory_mgr));
+ cinfo->mem = NULL; /* ensures I will be called only once */
+
+ jpeg_mem_term(cinfo); /* system-dependent cleanup */
+}
+
+
+/*
+ * Memory manager initialization.
+ * When this is called, only the error manager pointer is valid in cinfo!
+ */
+
+GLOBAL(void)
+jinit_memory_mgr (j_common_ptr cinfo)
+{
+ my_mem_ptr mem;
+ long max_to_use;
+ int pool;
+ size_t test_mac;
+
+ cinfo->mem = NULL; /* for safety if init fails */
+
+ /* Check for configuration errors.
+ * SIZEOF(ALIGN_TYPE) should be a power of 2; otherwise, it probably
+ * doesn't reflect any real hardware alignment requirement.
+ * The test is a little tricky: for X>0, X and X-1 have no one-bits
+ * in common if and only if X is a power of 2, ie has only one one-bit.
+ * Some compilers may give an "unreachable code" warning here; ignore it.
+ */
+ if ((SIZEOF(ALIGN_TYPE) & (SIZEOF(ALIGN_TYPE)-1)) != 0)
+ ERREXIT(cinfo, JERR_BAD_ALIGN_TYPE);
+ /* MAX_ALLOC_CHUNK must be representable as type size_t, and must be
+ * a multiple of SIZEOF(ALIGN_TYPE).
+ * Again, an "unreachable code" warning may be ignored here.
+ * But a "constant too large" warning means you need to fix MAX_ALLOC_CHUNK.
+ */
+ test_mac = (size_t) MAX_ALLOC_CHUNK;
+ if ((long) test_mac != MAX_ALLOC_CHUNK ||
+ (MAX_ALLOC_CHUNK % SIZEOF(ALIGN_TYPE)) != 0)
+ ERREXIT(cinfo, JERR_BAD_ALLOC_CHUNK);
+
+ max_to_use = jpeg_mem_init(cinfo); /* system-dependent initialization */
+
+ /* Attempt to allocate memory manager's control block */
+ mem = (my_mem_ptr) jpeg_get_small(cinfo, SIZEOF(my_memory_mgr));
+
+ if (mem == NULL) {
+ jpeg_mem_term(cinfo); /* system-dependent cleanup */
+ ERREXIT1(cinfo, JERR_OUT_OF_MEMORY, 0);
+ }
+
+ /* OK, fill in the method pointers */
+ mem->pub.alloc_small = alloc_small;
+ mem->pub.alloc_large = alloc_large;
+ mem->pub.alloc_sarray = alloc_sarray;
+ mem->pub.alloc_barray = alloc_barray;
+ mem->pub.request_virt_sarray = request_virt_sarray;
+ mem->pub.request_virt_barray = request_virt_barray;
+ mem->pub.realize_virt_arrays = realize_virt_arrays;
+ mem->pub.access_virt_sarray = access_virt_sarray;
+ mem->pub.access_virt_barray = access_virt_barray;
+ mem->pub.free_pool = free_pool;
+ mem->pub.self_destruct = self_destruct;
+
+ /* Make MAX_ALLOC_CHUNK accessible to other modules */
+ mem->pub.max_alloc_chunk = MAX_ALLOC_CHUNK;
+
+ /* Initialize working state */
+ mem->pub.max_memory_to_use = max_to_use;
+
+ for (pool = JPOOL_NUMPOOLS-1; pool >= JPOOL_PERMANENT; pool--) {
+ mem->small_list[pool] = NULL;
+ mem->large_list[pool] = NULL;
+ }
+ mem->virt_sarray_list = NULL;
+ mem->virt_barray_list = NULL;
+
+ mem->total_space_allocated = SIZEOF(my_memory_mgr);
+
+ /* Declare ourselves open for business */
+ cinfo->mem = & mem->pub;
+
+ /* Check for an environment variable JPEGMEM; if found, override the
+ * default max_memory setting from jpeg_mem_init. Note that the
+ * surrounding application may again override this value.
+ * If your system doesn't support getenv(), define NO_GETENV to disable
+ * this feature.
+ */
+#ifndef NO_GETENV
+ { char * memenv;
+
+ if ((memenv = getenv("JPEGMEM")) != NULL) {
+ char ch = 'x';
+
+ if (sscanf(memenv, "%ld%c", &max_to_use, &ch) > 0) {
+ if (ch == 'm' || ch == 'M')
+ max_to_use *= 1000L;
+ mem->pub.max_memory_to_use = max_to_use * 1000L;
+ }
+ }
+ }
+#endif
+
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp b/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp
new file mode 100644
index 000000000..27fe6c457
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp
@@ -0,0 +1,109 @@
+/*
+ * jmemnobs.c
+ *
+ * Copyright (C) 1992-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file provides a really simple implementation of the system-
+ * dependent portion of the JPEG memory manager. This implementation
+ * assumes that no backing-store files are needed: all required space
+ * can be obtained from malloc().
+ * This is very portable in the sense that it'll compile on almost anything,
+ * but you'd better have lots of main memory (or virtual memory) if you want
+ * to process big images.
+ * Note that the max_memory_to_use option is ignored by this implementation.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+#include "jmemsys.h" /* import the system-dependent declarations */
+
+#ifndef HAVE_STDLIB_H /* <stdlib.h> should declare malloc(),free() */
+extern void * malloc JPP((size_t size));
+extern void free JPP((void *ptr));
+#endif
+
+
+/*
+ * Memory allocation and freeing are controlled by the regular library
+ * routines malloc() and free().
+ */
+
+GLOBAL(void *)
+jpeg_get_small (j_common_ptr , size_t sizeofobject)
+{
+ return (void *) malloc(sizeofobject);
+}
+
+GLOBAL(void)
+jpeg_free_small (j_common_ptr , void * object, size_t )
+{
+ free(object);
+}
+
+
+/*
+ * "Large" objects are treated the same as "small" ones.
+ * NB: although we include FAR keywords in the routine declarations,
+ * this file won't actually work in 80x86 small/medium model; at least,
+ * you probably won't be able to process useful-size images in only 64KB.
+ */
+
+GLOBAL(void FAR *)
+jpeg_get_large (j_common_ptr , size_t sizeofobject)
+{
+ return (void FAR *) malloc(sizeofobject);
+}
+
+GLOBAL(void)
+jpeg_free_large (j_common_ptr , void FAR * object, size_t )
+{
+ free(object);
+}
+
+
+/*
+ * This routine computes the total memory space available for allocation.
+ * Here we always say, "we got all you want bud!"
+ */
+
+GLOBAL(long)
+jpeg_mem_available (j_common_ptr , long ,
+ long max_bytes_needed, long )
+{
+ return max_bytes_needed;
+}
+
+
+/*
+ * Backing store (temporary file) management.
+ * Since jpeg_mem_available always promised the moon,
+ * this should never be called and we can just error out.
+ */
+
+GLOBAL(void)
+jpeg_open_backing_store (j_common_ptr cinfo, backing_store_ptr ,
+ long )
+{
+ ERREXIT(cinfo, JERR_NO_BACKING_STORE);
+}
+
+
+/*
+ * These routines take care of any system-dependent initialization and
+ * cleanup required. Here, there isn't any.
+ */
+
+GLOBAL(long)
+jpeg_mem_init (j_common_ptr )
+{
+ return 0; /* just set max_memory_to_use to 0 */
+}
+
+GLOBAL(void)
+jpeg_mem_term (j_common_ptr )
+{
+ /* no work */
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jmemsys.h b/ml/dlib/dlib/external/libjpeg/jmemsys.h
new file mode 100644
index 000000000..6c3c6d348
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jmemsys.h
@@ -0,0 +1,198 @@
+/*
+ * jmemsys.h
+ *
+ * Copyright (C) 1992-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This include file defines the interface between the system-independent
+ * and system-dependent portions of the JPEG memory manager. No other
+ * modules need include it. (The system-independent portion is jmemmgr.c;
+ * there are several different versions of the system-dependent portion.)
+ *
+ * This file works as-is for the system-dependent memory managers supplied
+ * in the IJG distribution. You may need to modify it if you write a
+ * custom memory manager. If system-dependent changes are needed in
+ * this file, the best method is to #ifdef them based on a configuration
+ * symbol supplied in jconfig.h, as we have done with USE_MSDOS_MEMMGR
+ * and USE_MAC_MEMMGR.
+ */
+
+
+/* Short forms of external names for systems with brain-damaged linkers. */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_get_small jGetSmall
+#define jpeg_free_small jFreeSmall
+#define jpeg_get_large jGetLarge
+#define jpeg_free_large jFreeLarge
+#define jpeg_mem_available jMemAvail
+#define jpeg_open_backing_store jOpenBackStore
+#define jpeg_mem_init jMemInit
+#define jpeg_mem_term jMemTerm
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+
+/*
+ * These two functions are used to allocate and release small chunks of
+ * memory. (Typically the total amount requested through jpeg_get_small is
+ * no more than 20K or so; this will be requested in chunks of a few K each.)
+ * Behavior should be the same as for the standard library functions malloc
+ * and free; in particular, jpeg_get_small must return NULL on failure.
+ * On most systems, these ARE malloc and free. jpeg_free_small is passed the
+ * size of the object being freed, just in case it's needed.
+ * On an 80x86 machine using small-data memory model, these manage near heap.
+ */
+
+EXTERN(void *) jpeg_get_small JPP((j_common_ptr cinfo, size_t sizeofobject));
+EXTERN(void) jpeg_free_small JPP((j_common_ptr cinfo, void * object,
+ size_t sizeofobject));
+
+/*
+ * These two functions are used to allocate and release large chunks of
+ * memory (up to the total free space designated by jpeg_mem_available).
+ * The interface is the same as above, except that on an 80x86 machine,
+ * far pointers are used. On most other machines these are identical to
+ * the jpeg_get/free_small routines; but we keep them separate anyway,
+ * in case a different allocation strategy is desirable for large chunks.
+ */
+
+EXTERN(void FAR *) jpeg_get_large JPP((j_common_ptr cinfo,
+ size_t sizeofobject));
+EXTERN(void) jpeg_free_large JPP((j_common_ptr cinfo, void FAR * object,
+ size_t sizeofobject));
+
+/*
+ * The macro MAX_ALLOC_CHUNK designates the maximum number of bytes that may
+ * be requested in a single call to jpeg_get_large (and jpeg_get_small for that
+ * matter, but that case should never come into play). This macro is needed
+ * to model the 64Kb-segment-size limit of far addressing on 80x86 machines.
+ * On those machines, we expect that jconfig.h will provide a proper value.
+ * On machines with 32-bit flat address spaces, any large constant may be used.
+ *
+ * NB: jmemmgr.c expects that MAX_ALLOC_CHUNK will be representable as type
+ * size_t and will be a multiple of sizeof(align_type).
+ */
+
+#ifndef MAX_ALLOC_CHUNK /* may be overridden in jconfig.h */
+#define MAX_ALLOC_CHUNK 1000000000L
+#endif
+
+/*
+ * This routine computes the total space still available for allocation by
+ * jpeg_get_large. If more space than this is needed, backing store will be
+ * used. NOTE: any memory already allocated must not be counted.
+ *
+ * There is a minimum space requirement, corresponding to the minimum
+ * feasible buffer sizes; jmemmgr.c will request that much space even if
+ * jpeg_mem_available returns zero. The maximum space needed, enough to hold
+ * all working storage in memory, is also passed in case it is useful.
+ * Finally, the total space already allocated is passed. If no better
+ * method is available, cinfo->mem->max_memory_to_use - already_allocated
+ * is often a suitable calculation.
+ *
+ * It is OK for jpeg_mem_available to underestimate the space available
+ * (that'll just lead to more backing-store access than is really necessary).
+ * However, an overestimate will lead to failure. Hence it's wise to subtract
+ * a slop factor from the true available space. 5% should be enough.
+ *
+ * On machines with lots of virtual memory, any large constant may be returned.
+ * Conversely, zero may be returned to always use the minimum amount of memory.
+ */
+
+EXTERN(long) jpeg_mem_available JPP((j_common_ptr cinfo,
+ long min_bytes_needed,
+ long max_bytes_needed,
+ long already_allocated));
+
+
+/*
+ * This structure holds whatever state is needed to access a single
+ * backing-store object. The read/write/close method pointers are called
+ * by jmemmgr.c to manipulate the backing-store object; all other fields
+ * are private to the system-dependent backing store routines.
+ */
+
+#define TEMP_NAME_LENGTH 64 /* max length of a temporary file's name */
+
+
+#ifdef USE_MSDOS_MEMMGR /* DOS-specific junk */
+
+typedef unsigned short XMSH; /* type of extended-memory handles */
+typedef unsigned short EMSH; /* type of expanded-memory handles */
+
+typedef union {
+ short file_handle; /* DOS file handle if it's a temp file */
+ XMSH xms_handle; /* handle if it's a chunk of XMS */
+ EMSH ems_handle; /* handle if it's a chunk of EMS */
+} handle_union;
+
+#endif /* USE_MSDOS_MEMMGR */
+
+#ifdef USE_MAC_MEMMGR /* Mac-specific junk */
+#include <Files.h>
+#endif /* USE_MAC_MEMMGR */
+
+
+typedef struct backing_store_struct * backing_store_ptr;
+
+typedef struct backing_store_struct {
+ /* Methods for reading/writing/closing this backing-store object */
+ JMETHOD(void, read_backing_store, (j_common_ptr cinfo,
+ backing_store_ptr info,
+ void FAR * buffer_address,
+ long file_offset, long byte_count));
+ JMETHOD(void, write_backing_store, (j_common_ptr cinfo,
+ backing_store_ptr info,
+ void FAR * buffer_address,
+ long file_offset, long byte_count));
+ JMETHOD(void, close_backing_store, (j_common_ptr cinfo,
+ backing_store_ptr info));
+
+ /* Private fields for system-dependent backing-store management */
+#ifdef USE_MSDOS_MEMMGR
+ /* For the MS-DOS manager (jmemdos.c), we need: */
+ handle_union handle; /* reference to backing-store storage object */
+ char temp_name[TEMP_NAME_LENGTH]; /* name if it's a file */
+#else
+#ifdef USE_MAC_MEMMGR
+ /* For the Mac manager (jmemmac.c), we need: */
+ short temp_file; /* file reference number to temp file */
+ FSSpec tempSpec; /* the FSSpec for the temp file */
+ char temp_name[TEMP_NAME_LENGTH]; /* name if it's a file */
+#else
+ /* For a typical implementation with temp files, we need: */
+ FILE * temp_file; /* stdio reference to temp file */
+ char temp_name[TEMP_NAME_LENGTH]; /* name of temp file */
+#endif
+#endif
+} backing_store_info;
+
+
+/*
+ * Initial opening of a backing-store object. This must fill in the
+ * read/write/close pointers in the object. The read/write routines
+ * may take an error exit if the specified maximum file size is exceeded.
+ * (If jpeg_mem_available always returns a large value, this routine can
+ * just take an error exit.)
+ */
+
+EXTERN(void) jpeg_open_backing_store JPP((j_common_ptr cinfo,
+ backing_store_ptr info,
+ long total_bytes_needed));
+
+
+/*
+ * These routines take care of any system-dependent initialization and
+ * cleanup required. jpeg_mem_init will be called before anything is
+ * allocated (and, therefore, nothing in cinfo is of use except the error
+ * manager pointer). It should return a suitable default value for
+ * max_memory_to_use; this may subsequently be overridden by the surrounding
+ * application. (Note that max_memory_to_use is only important if
+ * jpeg_mem_available chooses to consult it ... no one else will.)
+ * jpeg_mem_term may assume that all requested memory has been freed and that
+ * all opened backing-store objects have been closed.
+ */
+
+EXTERN(long) jpeg_mem_init JPP((j_common_ptr cinfo));
+EXTERN(void) jpeg_mem_term JPP((j_common_ptr cinfo));
diff --git a/ml/dlib/dlib/external/libjpeg/jmorecfg.h b/ml/dlib/dlib/external/libjpeg/jmorecfg.h
new file mode 100644
index 000000000..6082f069a
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jmorecfg.h
@@ -0,0 +1,356 @@
+/*
+ * jmorecfg.h
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains additional configuration options that customize the
+ * JPEG software for special applications or support machine-dependent
+ * optimizations. Most users will not need to touch this file.
+ */
+
+
+/*
+ * Define BITS_IN_JSAMPLE as either
+ * 8 for 8-bit sample values (the usual setting)
+ * 12 for 12-bit sample values
+ * Only 8 and 12 are legal data precisions for lossy JPEG according to the
+ * JPEG standard, and the IJG code does not support anything else!
+ * We do not support run-time selection of data precision, sorry.
+ */
+
+#define BITS_IN_JSAMPLE 8 /* use 8 or 12 */
+
+
+/*
+ * Maximum number of components (color channels) allowed in JPEG image.
+ * To meet the letter of the JPEG spec, set this to 255. However, darn
+ * few applications need more than 4 channels (maybe 5 for CMYK + alpha
+ * mask). We recommend 10 as a reasonable compromise; use 4 if you are
+ * really short on memory. (Each allowed component costs a hundred or so
+ * bytes of storage, whether actually used in an image or not.)
+ */
+
+#define MAX_COMPONENTS 10 /* maximum number of image components */
+
+
+/*
+ * Basic data types.
+ * You may need to change these if you have a machine with unusual data
+ * type sizes; for example, "char" not 8 bits, "short" not 16 bits,
+ * or "long" not 32 bits. We don't care whether "int" is 16 or 32 bits,
+ * but it had better be at least 16.
+ */
+
+/* Representation of a single sample (pixel element value).
+ * We frequently allocate large arrays of these, so it's important to keep
+ * them small. But if you have memory to burn and access to char or short
+ * arrays is very slow on your hardware, you might want to change these.
+ */
+
+
+#ifdef _MSC_VER
+// Disable the following warnings for Visual Studio
+// This is a warning you get from visual studio 2005 about things in the standard C++
+// library being "deprecated." I checked the C++ standard and it doesn't say jack
+// about any of them (I checked the searchable PDF). So this warning is total Bunk.
+#pragma warning(disable : 4996)
+#endif
+
+
+#if BITS_IN_JSAMPLE == 8
+/* JSAMPLE should be the smallest type that will hold the values 0..255.
+ * You can use a signed char by having GETJSAMPLE mask it with 0xFF.
+ */
+
+#ifdef HAVE_UNSIGNED_CHAR
+
+typedef unsigned char JSAMPLE;
+#define GETJSAMPLE(value) ((int) (value))
+
+#else /* not HAVE_UNSIGNED_CHAR */
+
+typedef char JSAMPLE;
+#ifdef CHAR_IS_UNSIGNED
+#define GETJSAMPLE(value) ((int) (value))
+#else
+#define GETJSAMPLE(value) ((int) (value) & 0xFF)
+#endif /* CHAR_IS_UNSIGNED */
+
+#endif /* HAVE_UNSIGNED_CHAR */
+
+#define MAXJSAMPLE 255
+#define CENTERJSAMPLE 128
+
+#endif /* BITS_IN_JSAMPLE == 8 */
+
+
+#if BITS_IN_JSAMPLE == 12
+/* JSAMPLE should be the smallest type that will hold the values 0..4095.
+ * On nearly all machines "short" will do nicely.
+ */
+
+typedef short JSAMPLE;
+#define GETJSAMPLE(value) ((int) (value))
+
+#define MAXJSAMPLE 4095
+#define CENTERJSAMPLE 2048
+
+#endif /* BITS_IN_JSAMPLE == 12 */
+
+
+/* Representation of a DCT frequency coefficient.
+ * This should be a signed value of at least 16 bits; "short" is usually OK.
+ * Again, we allocate large arrays of these, but you can change to int
+ * if you have memory to burn and "short" is really slow.
+ */
+
+typedef short JCOEF;
+
+
+/* Compressed datastreams are represented as arrays of JOCTET.
+ * These must be EXACTLY 8 bits wide, at least once they are written to
+ * external storage. Note that when using the stdio data source/destination
+ * managers, this is also the data type passed to fread/fwrite.
+ */
+
+#ifdef HAVE_UNSIGNED_CHAR
+
+typedef unsigned char JOCTET;
+#define GETJOCTET(value) (value)
+
+#else /* not HAVE_UNSIGNED_CHAR */
+
+typedef char JOCTET;
+#ifdef CHAR_IS_UNSIGNED
+#define GETJOCTET(value) (value)
+#else
+#define GETJOCTET(value) ((value) & 0xFF)
+#endif /* CHAR_IS_UNSIGNED */
+
+#endif /* HAVE_UNSIGNED_CHAR */
+
+
+/* These typedefs are used for various table entries and so forth.
+ * They must be at least as wide as specified; but making them too big
+ * won't cost a huge amount of memory, so we don't provide special
+ * extraction code like we did for JSAMPLE. (In other words, these
+ * typedefs live at a different point on the speed/space tradeoff curve.)
+ */
+
+/* unsigned char must hold at least the values 0..255. */
+
+
+/* unsigned short must hold at least the values 0..65535. */
+
+
+
+/* Datatype used for image dimensions. The JPEG standard only supports
+ * images up to 64K*64K due to 16-bit fields in SOF markers. Therefore
+ * "unsigned int" is sufficient on all machines. However, if you need to
+ * handle larger images and you don't mind deviating from the spec, you
+ * can change this datatype.
+ */
+
+typedef unsigned int JDIMENSION;
+
+#define JPEG_MAX_DIMENSION 65500L /* a tad under 64K to prevent overflows */
+
+
+/* These macros are used in all function definitions and extern declarations.
+ * You could modify them if you need to change function linkage conventions;
+ * in particular, you'll need to do that to make the library a Windows DLL.
+ * Another application is to make all functions global for use with debuggers
+ * or code profilers that require it.
+ */
+
+/* a function called through method pointers: */
+#define METHODDEF(type) static type
+/* a function used only in its module: */
+#define LOCAL(type) static type
+/* a function referenced thru EXTERNs: */
+#define GLOBAL(type) type
+/*
+ Use C linking unless we are supposed to be compiling our own copy of
+ libjpeg. Then let it use C++ linking so that we are less likely to get
+ linker name conflicts with other libraries that happen to statically include
+ libjpeg as well.
+*/
+#if defined(__cplusplus) && !defined(DLIB_JPEG_STATIC)
+#define EXTERN(type) extern "C" type
+#else
+#define EXTERN(type) extern type
+#endif
+
+
+/* This macro is used to declare a "method", that is, a function pointer.
+ * We want to supply prototype parameters if the compiler can cope.
+ * Note that the arglist parameter must be parenthesized!
+ * Again, you can customize this if you need special linkage keywords.
+ */
+
+#ifdef HAVE_PROTOTYPES
+#define JMETHOD(type,methodname,arglist) type (*methodname) arglist
+#else
+#define JMETHOD(type,methodname,arglist) type (*methodname) ()
+#endif
+
+
+/* Here is the pseudo-keyword for declaring pointers that must be "far"
+ * on 80x86 machines. Most of the specialized coding for 80x86 is handled
+ * by just saying "FAR *" where such a pointer is needed. In a few places
+ * explicit coding is needed; see uses of the NEED_FAR_POINTERS symbol.
+ */
+
+#ifdef NEED_FAR_POINTERS
+#define FAR far
+#else
+#ifndef FAR
+ #define FAR
+#endif
+#endif
+
+
+/*
+ * On a few systems, type boolean and/or its values FALSE, TRUE may appear
+ * in standard header files. Or you may have conflicts with application-
+ * specific header files that you want to include together with these files.
+ * Defining HAVE_BOOLEAN before including jpeglib.h should make it work.
+ */
+
+#ifndef FALSE /* in case these macros already exist */
+#define FALSE 0 /* values of boolean */
+#endif
+#ifndef TRUE
+#define TRUE 1
+#endif
+
+
+/*
+ * The remaining options affect code selection within the JPEG library,
+ * but they don't need to be visible to most applications using the library.
+ * To minimize application namespace pollution, the symbols won't be
+ * defined unless JPEG_INTERNALS or JPEG_INTERNAL_OPTIONS has been defined.
+ */
+
+#ifdef JPEG_INTERNALS
+#define JPEG_INTERNAL_OPTIONS
+#endif
+
+#ifdef JPEG_INTERNAL_OPTIONS
+
+
+/*
+ * These defines indicate whether to include various optional functions.
+ * Undefining some of these symbols will produce a smaller but less capable
+ * library. Note that you can leave certain source files out of the
+ * compilation/linking process if you've #undef'd the corresponding symbols.
+ * (You may HAVE to do that if your compiler doesn't like null source files.)
+ */
+
+/* Arithmetic coding is unsupported for legal reasons. Complaints to IBM. */
+
+/* Capability options common to encoder and decoder: */
+
+#define DCT_ISLOW_SUPPORTED /* slow but accurate integer algorithm */
+#define DCT_IFAST_SUPPORTED /* faster, less accurate integer method */
+#define DCT_FLOAT_SUPPORTED /* floating-point: accurate, fast on fast HW */
+
+/* Encoder capability options: */
+
+#undef C_ARITH_CODING_SUPPORTED /* Arithmetic coding back end? */
+#define C_MULTISCAN_FILES_SUPPORTED /* Multiple-scan JPEG files? */
+#define C_PROGRESSIVE_SUPPORTED /* Progressive JPEG? (Requires MULTISCAN)*/
+#define ENTROPY_OPT_SUPPORTED /* Optimization of entropy coding parms? */
+/* Note: if you selected 12-bit data precision, it is dangerous to turn off
+ * ENTROPY_OPT_SUPPORTED. The standard Huffman tables are only good for 8-bit
+ * precision, so jchuff.c normally uses entropy optimization to compute
+ * usable tables for higher precision. If you don't want to do optimization,
+ * you'll have to supply different default Huffman tables.
+ * The exact same statements apply for progressive JPEG: the default tables
+ * don't work for progressive mode. (This may get fixed, however.)
+ */
+#define INPUT_SMOOTHING_SUPPORTED /* Input image smoothing option? */
+
+/* Decoder capability options: */
+
+#undef D_ARITH_CODING_SUPPORTED /* Arithmetic coding back end? */
+#define D_MULTISCAN_FILES_SUPPORTED /* Multiple-scan JPEG files? */
+#define D_PROGRESSIVE_SUPPORTED /* Progressive JPEG? (Requires MULTISCAN)*/
+#define SAVE_MARKERS_SUPPORTED /* jpeg_save_markers() needed? */
+#define BLOCK_SMOOTHING_SUPPORTED /* Block smoothing? (Progressive only) */
+#define IDCT_SCALING_SUPPORTED /* Output rescaling via IDCT? */
+#undef UPSAMPLE_SCALING_SUPPORTED /* Output rescaling at upsample stage? */
+#define UPSAMPLE_MERGING_SUPPORTED /* Fast path for sloppy upsampling? */
+#define QUANT_1PASS_SUPPORTED /* 1-pass color quantization? */
+#define QUANT_2PASS_SUPPORTED /* 2-pass color quantization? */
+
+/* more capability options later, no doubt */
+
+
+/*
+ * Ordering of RGB data in scanlines passed to or from the application.
+ * If your application wants to deal with data in the order B,G,R, just
+ * change these macros. You can also deal with formats such as R,G,B,X
+ * (one extra byte per pixel) by changing RGB_PIXELSIZE. Note that changing
+ * the offsets will also change the order in which colormap data is organized.
+ * RESTRICTIONS:
+ * 1. The sample applications cjpeg,djpeg do NOT support modified RGB formats.
+ * 2. These macros only affect RGB<=>YCbCr color conversion, so they are not
+ * useful if you are using JPEG color spaces other than YCbCr or grayscale.
+ * 3. The color quantizer modules will not behave desirably if RGB_PIXELSIZE
+ * is not 3 (they don't understand about dummy color components!). So you
+ * can't use color quantization if you change that value.
+ */
+
+#define RGB_RED 0 /* Offset of Red in an RGB scanline element */
+#define RGB_GREEN 1 /* Offset of Green */
+#define RGB_BLUE 2 /* Offset of Blue */
+#define RGB_PIXELSIZE 3 /* JSAMPLEs per RGB scanline element */
+
+
+/* Definitions for speed-related optimizations. */
+
+
+/* If your compiler supports inline functions, define INLINE
+ * as the inline keyword; otherwise define it as empty.
+ */
+
+#ifndef INLINE
+#ifdef __GNUC__ /* for instance, GNU C knows about inline */
+#define INLINE __inline__
+#endif
+#ifndef INLINE
+#define INLINE /* default is to define it as empty */
+#endif
+#endif
+
+
+/* On some machines (notably 68000 series) "int" is 32 bits, but multiplying
+ * two 16-bit shorts is faster than multiplying two ints. Define MULTIPLIER
+ * as short on such a machine. MULTIPLIER must be at least 16 bits wide.
+ */
+
+#ifndef MULTIPLIER
+#define MULTIPLIER int /* type for fastest integer multiply */
+#endif
+
+
+/* FAST_FLOAT should be either float or double, whichever is done faster
+ * by your compiler. (Note that this type is only used in the floating point
+ * DCT routines, so it only matters if you've defined DCT_FLOAT_SUPPORTED.)
+ * Typically, float is faster in ANSI C compilers, while double is faster in
+ * pre-ANSI compilers (because they insist on converting to double anyway).
+ * The code below therefore chooses float if we have ANSI-style prototypes.
+ */
+
+#ifndef FAST_FLOAT
+#ifdef HAVE_PROTOTYPES
+#define FAST_FLOAT float
+#else
+#define FAST_FLOAT double
+#endif
+#endif
+
+#endif /* JPEG_INTERNAL_OPTIONS */
diff --git a/ml/dlib/dlib/external/libjpeg/jpegint.h b/ml/dlib/dlib/external/libjpeg/jpegint.h
new file mode 100644
index 000000000..654067965
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jpegint.h
@@ -0,0 +1,392 @@
+/*
+ * jpegint.h
+ *
+ * Copyright (C) 1991-1997, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file provides common declarations for the various JPEG modules.
+ * These declarations are considered internal to the JPEG library; most
+ * applications using the library shouldn't need to include this file.
+ */
+
+
+/* Declarations for both compression & decompression */
+
+typedef enum { /* Operating modes for buffer controllers */
+ JBUF_PASS_THRU, /* Plain stripwise operation */
+ /* Remaining modes require a full-image buffer to have been created */
+ JBUF_SAVE_SOURCE, /* Run source subobject only, save output */
+ JBUF_CRANK_DEST, /* Run dest subobject only, using saved data */
+ JBUF_SAVE_AND_PASS /* Run both subobjects, save output */
+} J_BUF_MODE;
+
+/* Values of global_state field (jdapi.c has some dependencies on ordering!) */
+#define CSTATE_START 100 /* after create_compress */
+#define CSTATE_SCANNING 101 /* start_compress done, write_scanlines OK */
+#define CSTATE_RAW_OK 102 /* start_compress done, write_raw_data OK */
+#define CSTATE_WRCOEFS 103 /* jpeg_write_coefficients done */
+#define DSTATE_START 200 /* after create_decompress */
+#define DSTATE_INHEADER 201 /* reading header markers, no SOS yet */
+#define DSTATE_READY 202 /* found SOS, ready for start_decompress */
+#define DSTATE_PRELOAD 203 /* reading multiscan file in start_decompress*/
+#define DSTATE_PRESCAN 204 /* performing dummy pass for 2-pass quant */
+#define DSTATE_SCANNING 205 /* start_decompress done, read_scanlines OK */
+#define DSTATE_RAW_OK 206 /* start_decompress done, read_raw_data OK */
+#define DSTATE_BUFIMAGE 207 /* expecting jpeg_start_output */
+#define DSTATE_BUFPOST 208 /* looking for SOS/EOI in jpeg_finish_output */
+#define DSTATE_RDCOEFS 209 /* reading file in jpeg_read_coefficients */
+#define DSTATE_STOPPING 210 /* looking for EOI in jpeg_finish_decompress */
+
+
+/* Declarations for compression modules */
+
+/* Master control module */
+struct jpeg_comp_master {
+ JMETHOD(void, prepare_for_pass, (j_compress_ptr cinfo));
+ JMETHOD(void, pass_startup, (j_compress_ptr cinfo));
+ JMETHOD(void, finish_pass, (j_compress_ptr cinfo));
+
+ /* State variables made visible to other modules */
+ int call_pass_startup; /* True if pass_startup must be called */
+ int is_last_pass; /* True during last pass */
+};
+
+/* Main buffer control (downsampled-data buffer) */
+struct jpeg_c_main_controller {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode));
+ JMETHOD(void, process_data, (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail));
+};
+
+/* Compression preprocessing (downsampling input buffer control) */
+struct jpeg_c_prep_controller {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode));
+ JMETHOD(void, pre_process_data, (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf,
+ JDIMENSION *in_row_ctr,
+ JDIMENSION in_rows_avail,
+ JSAMPIMAGE output_buf,
+ JDIMENSION *out_row_group_ctr,
+ JDIMENSION out_row_groups_avail));
+};
+
+/* Coefficient buffer control */
+struct jpeg_c_coef_controller {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode));
+ JMETHOD(int, compress_data, (j_compress_ptr cinfo,
+ JSAMPIMAGE input_buf));
+};
+
+/* Colorspace conversion */
+struct jpeg_color_converter {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo));
+ JMETHOD(void, color_convert, (j_compress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPIMAGE output_buf,
+ JDIMENSION output_row, int num_rows));
+};
+
+/* Downsampling */
+struct jpeg_downsampler {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo));
+ JMETHOD(void, downsample, (j_compress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION in_row_index,
+ JSAMPIMAGE output_buf,
+ JDIMENSION out_row_group_index));
+
+ int need_context_rows; /* TRUE if need rows above & below */
+};
+
+/* Forward DCT (also controls coefficient quantization) */
+struct jpeg_forward_dct {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo));
+ /* perhaps this should be an array??? */
+ JMETHOD(void, forward_DCT, (j_compress_ptr cinfo,
+ jpeg_component_info * compptr,
+ JSAMPARRAY sample_data, JBLOCKROW coef_blocks,
+ JDIMENSION start_row, JDIMENSION start_col,
+ JDIMENSION num_blocks));
+};
+
+/* Entropy encoding */
+struct jpeg_entropy_encoder {
+ JMETHOD(void, start_pass, (j_compress_ptr cinfo, int gather_statistics));
+ JMETHOD(int, encode_mcu, (j_compress_ptr cinfo, JBLOCKROW *MCU_data));
+ JMETHOD(void, finish_pass, (j_compress_ptr cinfo));
+};
+
+/* Marker writing */
+struct jpeg_marker_writer {
+ JMETHOD(void, write_file_header, (j_compress_ptr cinfo));
+ JMETHOD(void, write_frame_header, (j_compress_ptr cinfo));
+ JMETHOD(void, write_scan_header, (j_compress_ptr cinfo));
+ JMETHOD(void, write_file_trailer, (j_compress_ptr cinfo));
+ JMETHOD(void, write_tables_only, (j_compress_ptr cinfo));
+ /* These routines are exported to allow insertion of extra markers */
+ /* Probably only COM and APPn markers should be written this way */
+ JMETHOD(void, write_marker_header, (j_compress_ptr cinfo, int marker,
+ unsigned int datalen));
+ JMETHOD(void, write_marker_byte, (j_compress_ptr cinfo, int val));
+};
+
+
+/* Declarations for decompression modules */
+
+/* Master control module */
+struct jpeg_decomp_master {
+ JMETHOD(void, prepare_for_output_pass, (j_decompress_ptr cinfo));
+ JMETHOD(void, finish_output_pass, (j_decompress_ptr cinfo));
+
+ /* State variables made visible to other modules */
+ int is_dummy_pass; /* True during 1st pass for 2-pass quant */
+};
+
+/* Input control module */
+struct jpeg_input_controller {
+ JMETHOD(int, consume_input, (j_decompress_ptr cinfo));
+ JMETHOD(void, reset_input_controller, (j_decompress_ptr cinfo));
+ JMETHOD(void, start_input_pass, (j_decompress_ptr cinfo));
+ JMETHOD(void, finish_input_pass, (j_decompress_ptr cinfo));
+
+ /* State variables made visible to other modules */
+ int has_multiple_scans; /* True if file has multiple scans */
+ int eoi_reached; /* True when EOI has been consumed */
+};
+
+/* Main buffer control (downsampled-data buffer) */
+struct jpeg_d_main_controller {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo, J_BUF_MODE pass_mode));
+ JMETHOD(void, process_data, (j_decompress_ptr cinfo,
+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+};
+
+/* Coefficient buffer control */
+struct jpeg_d_coef_controller {
+ JMETHOD(void, start_input_pass, (j_decompress_ptr cinfo));
+ JMETHOD(int, consume_data, (j_decompress_ptr cinfo));
+ JMETHOD(void, start_output_pass, (j_decompress_ptr cinfo));
+ JMETHOD(int, decompress_data, (j_decompress_ptr cinfo,
+ JSAMPIMAGE output_buf));
+ /* Pointer to array of coefficient virtual arrays, or NULL if none */
+ jvirt_barray_ptr *coef_arrays;
+};
+
+/* Decompression postprocessing (color quantization buffer control) */
+struct jpeg_d_post_controller {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo, J_BUF_MODE pass_mode));
+ JMETHOD(void, post_process_data, (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf,
+ JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf,
+ JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+};
+
+/* Marker reading & parsing */
+struct jpeg_marker_reader {
+ JMETHOD(void, reset_marker_reader, (j_decompress_ptr cinfo));
+ /* Read markers until SOS or EOI.
+ * Returns same codes as are defined for jpeg_consume_input:
+ * JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI.
+ */
+ JMETHOD(int, read_markers, (j_decompress_ptr cinfo));
+ /* Read a restart marker --- exported for use by entropy decoder only */
+ jpeg_marker_parser_method read_restart_marker;
+
+ /* State of marker reader --- nominally internal, but applications
+ * supplying COM or APPn handlers might like to know the state.
+ */
+ int saw_SOI; /* found SOI? */
+ int saw_SOF; /* found SOF? */
+ int next_restart_num; /* next restart number expected (0-7) */
+ unsigned int discarded_bytes; /* # of bytes skipped looking for a marker */
+};
+
+/* Entropy decoding */
+struct jpeg_entropy_decoder {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo));
+ JMETHOD(int, decode_mcu, (j_decompress_ptr cinfo,
+ JBLOCKROW *MCU_data));
+
+ /* This is here to share code between baseline and progressive decoders; */
+ /* other modules probably should not use it */
+ int insufficient_data; /* set TRUE after emitting warning */
+};
+
+/* Inverse DCT (also performs dequantization) */
+typedef JMETHOD(void, inverse_DCT_method_ptr,
+ (j_decompress_ptr cinfo, jpeg_component_info * compptr,
+ JCOEFPTR coef_block,
+ JSAMPARRAY output_buf, JDIMENSION output_col));
+
+struct jpeg_inverse_dct {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo));
+ /* It is useful to allow each component to have a separate IDCT method. */
+ inverse_DCT_method_ptr inverse_DCT[MAX_COMPONENTS];
+};
+
+/* Upsampling (note that upsampler must also call color converter) */
+struct jpeg_upsampler {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo));
+ JMETHOD(void, upsample, (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf,
+ JDIMENSION *in_row_group_ctr,
+ JDIMENSION in_row_groups_avail,
+ JSAMPARRAY output_buf,
+ JDIMENSION *out_row_ctr,
+ JDIMENSION out_rows_avail));
+
+ int need_context_rows; /* TRUE if need rows above & below */
+};
+
+/* Colorspace conversion */
+struct jpeg_color_deconverter {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo));
+ JMETHOD(void, color_convert, (j_decompress_ptr cinfo,
+ JSAMPIMAGE input_buf, JDIMENSION input_row,
+ JSAMPARRAY output_buf, int num_rows));
+};
+
+/* Color quantization or color precision reduction */
+struct jpeg_color_quantizer {
+ JMETHOD(void, start_pass, (j_decompress_ptr cinfo, int is_pre_scan));
+ JMETHOD(void, color_quantize, (j_decompress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPARRAY output_buf,
+ int num_rows));
+ JMETHOD(void, finish_pass, (j_decompress_ptr cinfo));
+ JMETHOD(void, new_color_map, (j_decompress_ptr cinfo));
+};
+
+
+/* Miscellaneous useful macros */
+
+#undef MAX
+#define MAX(a,b) ((a) > (b) ? (a) : (b))
+#undef MIN
+#define MIN(a,b) ((a) < (b) ? (a) : (b))
+
+
+/* We assume that right shift corresponds to signed division by 2 with
+ * rounding towards minus infinity. This is correct for typical "arithmetic
+ * shift" instructions that shift in copies of the sign bit. But some
+ * C compilers implement >> with an unsigned shift. For these machines you
+ * must define RIGHT_SHIFT_IS_UNSIGNED.
+ * RIGHT_SHIFT provides a proper signed right shift of an long quantity.
+ * It is only applied with constant shift counts. SHIFT_TEMPS must be
+ * included in the variables of any routine using RIGHT_SHIFT.
+ */
+
+#ifdef RIGHT_SHIFT_IS_UNSIGNED
+#define SHIFT_TEMPS long shift_temp;
+#define RIGHT_SHIFT(x,shft) \
+ ((shift_temp = (x)) < 0 ? \
+ (shift_temp >> (shft)) | ((~((long) 0)) << (32-(shft))) : \
+ (shift_temp >> (shft)))
+#else
+#define SHIFT_TEMPS
+#define RIGHT_SHIFT(x,shft) ((x) >> (shft))
+#endif
+
+
+/* Short forms of external names for systems with brain-damaged linkers. */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jinit_compress_master jICompress
+#define jinit_c_master_control jICMaster
+#define jinit_c_main_controller jICMainC
+#define jinit_c_prep_controller jICPrepC
+#define jinit_c_coef_controller jICCoefC
+#define jinit_color_converter jICColor
+#define jinit_downsampler jIDownsampler
+#define jinit_forward_dct jIFDCT
+#define jinit_huff_encoder jIHEncoder
+#define jinit_phuff_encoder jIPHEncoder
+#define jinit_marker_writer jIMWriter
+#define jinit_master_decompress jIDMaster
+#define jinit_d_main_controller jIDMainC
+#define jinit_d_coef_controller jIDCoefC
+#define jinit_d_post_controller jIDPostC
+#define jinit_input_controller jIInCtlr
+#define jinit_marker_reader jIMReader
+#define jinit_huff_decoder jIHDecoder
+#define jinit_phuff_decoder jIPHDecoder
+#define jinit_inverse_dct jIIDCT
+#define jinit_upsampler jIUpsampler
+#define jinit_color_deconverter jIDColor
+#define jinit_1pass_quantizer jI1Quant
+#define jinit_2pass_quantizer jI2Quant
+#define jinit_merged_upsampler jIMUpsampler
+#define jinit_memory_mgr jIMemMgr
+#define jdiv_round_up jDivRound
+#define jround_up jRound
+#define jcopy_sample_rows jCopySamples
+#define jcopy_block_row jCopyBlocks
+#define jzero_far jZeroFar
+#define jpeg_zigzag_order jZIGTable
+#define jpeg_natural_order jZAGTable
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+
+/* Compression module initialization routines */
+EXTERN(void) jinit_compress_master JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_c_master_control JPP((j_compress_ptr cinfo,
+ int transcode_only));
+EXTERN(void) jinit_c_main_controller JPP((j_compress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_c_prep_controller JPP((j_compress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_c_coef_controller JPP((j_compress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_color_converter JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_downsampler JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_forward_dct JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_huff_encoder JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_phuff_encoder JPP((j_compress_ptr cinfo));
+EXTERN(void) jinit_marker_writer JPP((j_compress_ptr cinfo));
+/* Decompression module initialization routines */
+EXTERN(void) jinit_master_decompress JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_d_main_controller JPP((j_decompress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_d_coef_controller JPP((j_decompress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_d_post_controller JPP((j_decompress_ptr cinfo,
+ int need_full_buffer));
+EXTERN(void) jinit_input_controller JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_marker_reader JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_huff_decoder JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_phuff_decoder JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_inverse_dct JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_upsampler JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_color_deconverter JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_1pass_quantizer JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_2pass_quantizer JPP((j_decompress_ptr cinfo));
+EXTERN(void) jinit_merged_upsampler JPP((j_decompress_ptr cinfo));
+/* Memory manager initialization */
+EXTERN(void) jinit_memory_mgr JPP((j_common_ptr cinfo));
+
+/* Utility routines in jutils.c */
+EXTERN(long) jdiv_round_up JPP((long a, long b));
+EXTERN(long) jround_up JPP((long a, long b));
+EXTERN(void) jcopy_sample_rows JPP((JSAMPARRAY input_array, int source_row,
+ JSAMPARRAY output_array, int dest_row,
+ int num_rows, JDIMENSION num_cols));
+EXTERN(void) jcopy_block_row JPP((JBLOCKROW input_row, JBLOCKROW output_row,
+ JDIMENSION num_blocks));
+EXTERN(void) jzero_far JPP((void FAR * target, size_t bytestozero));
+/* Constant tables in jutils.c */
+#if 0 /* This table is not actually needed in v6a */
+extern const int jpeg_zigzag_order[]; /* natural coef order to zigzag order */
+#endif
+extern const int jpeg_natural_order[]; /* zigzag coef order to natural order */
+
+/* Suppress undefined-structure complaints if necessary. */
+
+#ifdef INCOMPLETE_TYPES_BROKEN
+#ifndef AM_MEMORY_MANAGER /* only jmemmgr.c defines these */
+struct jvirt_sarray_control { long dummy; };
+struct jvirt_barray_control { long dummy; };
+#endif
+#endif /* INCOMPLETE_TYPES_BROKEN */
diff --git a/ml/dlib/dlib/external/libjpeg/jpeglib.h b/ml/dlib/dlib/external/libjpeg/jpeglib.h
new file mode 100644
index 000000000..e611602d2
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jpeglib.h
@@ -0,0 +1,1096 @@
+/*
+ * jpeglib.h
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file defines the application interface for the JPEG library.
+ * Most applications using the library need only include this file,
+ * and perhaps jerror.h if they want to know the exact error codes.
+ */
+
+#ifndef JPEGLIB_H
+#define JPEGLIB_H
+
+/*
+ * First we include the configuration files that record how this
+ * installation of the JPEG library is set up. jconfig.h can be
+ * generated automatically for many systems. jmorecfg.h contains
+ * manual configuration options that most people need not worry about.
+ */
+
+#ifndef JCONFIG_INCLUDED /* in case jinclude.h already did */
+#include "jconfig.h" /* widely used configuration options */
+#endif
+#include "jmorecfg.h" /* seldom changed options */
+
+
+/* Version ID for the JPEG library.
+ * Might be useful for tests like "#if JPEG_LIB_VERSION >= 60".
+ */
+
+#define JPEG_LIB_VERSION 62 /* Version 6b */
+
+
+/* Various constants determining the sizes of things.
+ * All of these are specified by the JPEG standard, so don't change them
+ * if you want to be compatible.
+ */
+
+#define DCTSIZE 8 /* The basic DCT block is 8x8 samples */
+#define DCTSIZE2 64 /* DCTSIZE squared; # of elements in a block */
+#define NUM_QUANT_TBLS 4 /* Quantization tables are numbered 0..3 */
+#define NUM_HUFF_TBLS 4 /* Huffman tables are numbered 0..3 */
+#define NUM_ARITH_TBLS 16 /* Arith-coding tables are numbered 0..15 */
+#define MAX_COMPS_IN_SCAN 4 /* JPEG limit on # of components in one scan */
+#define MAX_SAMP_FACTOR 4 /* JPEG limit on sampling factors */
+/* Unfortunately, some bozo at Adobe saw no reason to be bound by the standard;
+ * the PostScript DCT filter can emit files with many more than 10 blocks/MCU.
+ * If you happen to run across such a file, you can up D_MAX_BLOCKS_IN_MCU
+ * to handle it. We even let you do this from the jconfig.h file. However,
+ * we strongly discourage changing C_MAX_BLOCKS_IN_MCU; just because Adobe
+ * sometimes emits noncompliant files doesn't mean you should too.
+ */
+#define C_MAX_BLOCKS_IN_MCU 10 /* compressor's limit on blocks per MCU */
+#ifndef D_MAX_BLOCKS_IN_MCU
+#define D_MAX_BLOCKS_IN_MCU 10 /* decompressor's limit on blocks per MCU */
+#endif
+
+
+/* Data structures for images (arrays of samples and of DCT coefficients).
+ * On 80x86 machines, the image arrays are too big for near pointers,
+ * but the pointer arrays can fit in near memory.
+ */
+
+typedef JSAMPLE FAR *JSAMPROW; /* ptr to one image row of pixel samples. */
+typedef JSAMPROW *JSAMPARRAY; /* ptr to some rows (a 2-D sample array) */
+typedef JSAMPARRAY *JSAMPIMAGE; /* a 3-D sample array: top index is color */
+
+typedef JCOEF JBLOCK[DCTSIZE2]; /* one block of coefficients */
+typedef JBLOCK FAR *JBLOCKROW; /* pointer to one row of coefficient blocks */
+typedef JBLOCKROW *JBLOCKARRAY; /* a 2-D array of coefficient blocks */
+typedef JBLOCKARRAY *JBLOCKIMAGE; /* a 3-D array of coefficient blocks */
+
+typedef JCOEF FAR *JCOEFPTR; /* useful in a couple of places */
+
+
+/* Types for JPEG compression parameters and working tables. */
+
+
+/* DCT coefficient quantization tables. */
+
+typedef struct {
+ /* This array gives the coefficient quantizers in natural array order
+ * (not the zigzag order in which they are stored in a JPEG DQT marker).
+ * CAUTION: IJG versions prior to v6a kept this array in zigzag order.
+ */
+ unsigned short quantval[DCTSIZE2]; /* quantization step for each coefficient */
+ /* This field is used only during compression. It's initialized FALSE when
+ * the table is created, and set TRUE when it's been output to the file.
+ * You could suppress output of a table by setting this to TRUE.
+ * (See jpeg_suppress_tables for an example.)
+ */
+ int sent_table; /* TRUE when table has been output */
+} JQUANT_TBL;
+
+
+/* Huffman coding tables. */
+
+typedef struct {
+ /* These two fields directly represent the contents of a JPEG DHT marker */
+ unsigned char bits[17]; /* bits[k] = # of symbols with codes of */
+ /* length k bits; bits[0] is unused */
+ unsigned char huffval[256]; /* The symbols, in order of incr code length */
+ /* This field is used only during compression. It's initialized FALSE when
+ * the table is created, and set TRUE when it's been output to the file.
+ * You could suppress output of a table by setting this to TRUE.
+ * (See jpeg_suppress_tables for an example.)
+ */
+ int sent_table; /* TRUE when table has been output */
+} JHUFF_TBL;
+
+
+/* Basic info about one component (color channel). */
+
+typedef struct {
+ /* These values are fixed over the whole image. */
+ /* For compression, they must be supplied by parameter setup; */
+ /* for decompression, they are read from the SOF marker. */
+ int component_id; /* identifier for this component (0..255) */
+ int component_index; /* its index in SOF or cinfo->comp_info[] */
+ int h_samp_factor; /* horizontal sampling factor (1..4) */
+ int v_samp_factor; /* vertical sampling factor (1..4) */
+ int quant_tbl_no; /* quantization table selector (0..3) */
+ /* These values may vary between scans. */
+ /* For compression, they must be supplied by parameter setup; */
+ /* for decompression, they are read from the SOS marker. */
+ /* The decompressor output side may not use these variables. */
+ int dc_tbl_no; /* DC entropy table selector (0..3) */
+ int ac_tbl_no; /* AC entropy table selector (0..3) */
+
+ /* Remaining fields should be treated as private by applications. */
+
+ /* These values are computed during compression or decompression startup: */
+ /* Component's size in DCT blocks.
+ * Any dummy blocks added to complete an MCU are not counted; therefore
+ * these values do not depend on whether a scan is interleaved or not.
+ */
+ JDIMENSION width_in_blocks;
+ JDIMENSION height_in_blocks;
+ /* Size of a DCT block in samples. Always DCTSIZE for compression.
+ * For decompression this is the size of the output from one DCT block,
+ * reflecting any scaling we choose to apply during the IDCT step.
+ * Values of 1,2,4,8 are likely to be supported. Note that different
+ * components may receive different IDCT scalings.
+ */
+ int DCT_scaled_size;
+ /* The downsampled dimensions are the component's actual, unpadded number
+ * of samples at the main buffer (preprocessing/compression interface), thus
+ * downsampled_width = ceil(image_width * Hi/Hmax)
+ * and similarly for height. For decompression, IDCT scaling is included, so
+ * downsampled_width = ceil(image_width * Hi/Hmax * DCT_scaled_size/DCTSIZE)
+ */
+ JDIMENSION downsampled_width; /* actual width in samples */
+ JDIMENSION downsampled_height; /* actual height in samples */
+ /* This flag is used only for decompression. In cases where some of the
+ * components will be ignored (eg grayscale output from YCbCr image),
+ * we can skip most computations for the unused components.
+ */
+ int component_needed; /* do we need the value of this component? */
+
+ /* These values are computed before starting a scan of the component. */
+ /* The decompressor output side may not use these variables. */
+ int MCU_width; /* number of blocks per MCU, horizontally */
+ int MCU_height; /* number of blocks per MCU, vertically */
+ int MCU_blocks; /* MCU_width * MCU_height */
+ int MCU_sample_width; /* MCU width in samples, MCU_width*DCT_scaled_size */
+ int last_col_width; /* # of non-dummy blocks across in last MCU */
+ int last_row_height; /* # of non-dummy blocks down in last MCU */
+
+ /* Saved quantization table for component; NULL if none yet saved.
+ * See jdinput.c comments about the need for this information.
+ * This field is currently used only for decompression.
+ */
+ JQUANT_TBL * quant_table;
+
+ /* Private per-component storage for DCT or IDCT subsystem. */
+ void * dct_table;
+} jpeg_component_info;
+
+
+/* The script for encoding a multiple-scan file is an array of these: */
+
+typedef struct {
+ int comps_in_scan; /* number of components encoded in this scan */
+ int component_index[MAX_COMPS_IN_SCAN]; /* their SOF/comp_info[] indexes */
+ int Ss, Se; /* progressive JPEG spectral selection parms */
+ int Ah, Al; /* progressive JPEG successive approx. parms */
+} jpeg_scan_info;
+
+/* The decompressor can save APPn and COM markers in a list of these: */
+
+typedef struct jpeg_marker_struct FAR * jpeg_saved_marker_ptr;
+
+struct jpeg_marker_struct {
+ jpeg_saved_marker_ptr next; /* next in list, or NULL */
+ unsigned char marker; /* marker code: JPEG_COM, or JPEG_APP0+n */
+ unsigned int original_length; /* # bytes of data in the file */
+ unsigned int data_length; /* # bytes of data saved at data[] */
+ JOCTET FAR * data; /* the data contained in the marker */
+ /* the marker length word is not counted in data_length or original_length */
+};
+
+/* Known color spaces. */
+
+typedef enum {
+ JCS_UNKNOWN, /* error/unspecified */
+ JCS_GRAYSCALE, /* monochrome */
+ JCS_RGB, /* red/green/blue */
+ JCS_YCbCr, /* Y/Cb/Cr (also known as YUV) */
+ JCS_CMYK, /* C/M/Y/K */
+ JCS_YCCK /* Y/Cb/Cr/K */
+} J_COLOR_SPACE;
+
+/* DCT/IDCT algorithm options. */
+
+typedef enum {
+ JDCT_ISLOW, /* slow but accurate integer algorithm */
+ JDCT_IFAST, /* faster, less accurate integer method */
+ JDCT_FLOAT /* floating-point: accurate, fast on fast HW */
+} J_DCT_METHOD;
+
+#ifndef JDCT_DEFAULT /* may be overridden in jconfig.h */
+#define JDCT_DEFAULT JDCT_ISLOW
+#endif
+#ifndef JDCT_FASTEST /* may be overridden in jconfig.h */
+#define JDCT_FASTEST JDCT_IFAST
+#endif
+
+/* Dithering options for decompression. */
+
+typedef enum {
+ JDITHER_NONE, /* no dithering */
+ JDITHER_ORDERED, /* simple ordered dither */
+ JDITHER_FS /* Floyd-Steinberg error diffusion dither */
+} J_DITHER_MODE;
+
+
+/* Common fields between JPEG compression and decompression master structs. */
+
+#define jpeg_common_fields \
+ struct jpeg_error_mgr * err; /* Error handler module */\
+ struct jpeg_memory_mgr * mem; /* Memory manager module */\
+ struct jpeg_progress_mgr * progress; /* Progress monitor, or NULL if none */\
+ void * client_data; /* Available for use by application */\
+ int is_decompressor; /* So common code can tell which is which */\
+ int global_state /* For checking call sequence validity */
+
+/* Routines that are to be used by both halves of the library are declared
+ * to receive a pointer to this structure. There are no actual instances of
+ * jpeg_common_struct, only of jpeg_compress_struct and jpeg_decompress_struct.
+ */
+struct jpeg_common_struct {
+ jpeg_common_fields; /* Fields common to both master struct types */
+ /* Additional fields follow in an actual jpeg_compress_struct or
+ * jpeg_decompress_struct. All three structs must agree on these
+ * initial fields! (This would be a lot cleaner in C++.)
+ */
+};
+
+typedef struct jpeg_common_struct * j_common_ptr;
+typedef struct jpeg_compress_struct * j_compress_ptr;
+typedef struct jpeg_decompress_struct * j_decompress_ptr;
+
+
+/* Master record for a compression instance */
+
+struct jpeg_compress_struct {
+ jpeg_common_fields; /* Fields shared with jpeg_decompress_struct */
+
+ /* Destination for compressed data */
+ struct jpeg_destination_mgr * dest;
+
+ /* Description of source image --- these fields must be filled in by
+ * outer application before starting compression. in_color_space must
+ * be correct before you can even call jpeg_set_defaults().
+ */
+
+ JDIMENSION image_width; /* input image width */
+ JDIMENSION image_height; /* input image height */
+ int input_components; /* # of color components in input image */
+ J_COLOR_SPACE in_color_space; /* colorspace of input image */
+
+ double input_gamma; /* image gamma of input image */
+
+ /* Compression parameters --- these fields must be set before calling
+ * jpeg_start_compress(). We recommend calling jpeg_set_defaults() to
+ * initialize everything to reasonable defaults, then changing anything
+ * the application specifically wants to change. That way you won't get
+ * burnt when new parameters are added. Also note that there are several
+ * helper routines to simplify changing parameters.
+ */
+
+ int data_precision; /* bits of precision in image data */
+
+ int num_components; /* # of color components in JPEG image */
+ J_COLOR_SPACE jpeg_color_space; /* colorspace of JPEG image */
+
+ jpeg_component_info * comp_info;
+ /* comp_info[i] describes component that appears i'th in SOF */
+
+ JQUANT_TBL * quant_tbl_ptrs[NUM_QUANT_TBLS];
+ /* ptrs to coefficient quantization tables, or NULL if not defined */
+
+ JHUFF_TBL * dc_huff_tbl_ptrs[NUM_HUFF_TBLS];
+ JHUFF_TBL * ac_huff_tbl_ptrs[NUM_HUFF_TBLS];
+ /* ptrs to Huffman coding tables, or NULL if not defined */
+
+ unsigned char arith_dc_L[NUM_ARITH_TBLS]; /* L values for DC arith-coding tables */
+ unsigned char arith_dc_U[NUM_ARITH_TBLS]; /* U values for DC arith-coding tables */
+ unsigned char arith_ac_K[NUM_ARITH_TBLS]; /* Kx values for AC arith-coding tables */
+
+ int num_scans; /* # of entries in scan_info array */
+ const jpeg_scan_info * scan_info; /* script for multi-scan file, or NULL */
+ /* The default value of scan_info is NULL, which causes a single-scan
+ * sequential JPEG file to be emitted. To create a multi-scan file,
+ * set num_scans and scan_info to point to an array of scan definitions.
+ */
+
+ int raw_data_in; /* TRUE=caller supplies downsampled data */
+ int arith_code; /* TRUE=arithmetic coding, FALSE=Huffman */
+ int optimize_coding; /* TRUE=optimize entropy encoding parms */
+ int CCIR601_sampling; /* TRUE=first samples are cosited */
+ int smoothing_factor; /* 1..100, or 0 for no input smoothing */
+ J_DCT_METHOD dct_method; /* DCT algorithm selector */
+
+ /* The restart interval can be specified in absolute MCUs by setting
+ * restart_interval, or in MCU rows by setting restart_in_rows
+ * (in which case the correct restart_interval will be figured
+ * for each scan).
+ */
+ unsigned int restart_interval; /* MCUs per restart, or 0 for no restart */
+ int restart_in_rows; /* if > 0, MCU rows per restart interval */
+
+ /* Parameters controlling emission of special markers. */
+
+ int write_JFIF_header; /* should a JFIF marker be written? */
+ unsigned char JFIF_major_version; /* What to write for the JFIF version number */
+ unsigned char JFIF_minor_version;
+ /* These three values are not used by the JPEG code, merely copied */
+ /* into the JFIF APP0 marker. density_unit can be 0 for unknown, */
+ /* 1 for dots/inch, or 2 for dots/cm. Note that the pixel aspect */
+ /* ratio is defined by X_density/Y_density even when density_unit=0. */
+ unsigned char density_unit; /* JFIF code for pixel size units */
+ unsigned short X_density; /* Horizontal pixel density */
+ unsigned short Y_density; /* Vertical pixel density */
+ int write_Adobe_marker; /* should an Adobe marker be written? */
+
+ /* State variable: index of next scanline to be written to
+ * jpeg_write_scanlines(). Application may use this to control its
+ * processing loop, e.g., "while (next_scanline < image_height)".
+ */
+
+ JDIMENSION next_scanline; /* 0 .. image_height-1 */
+
+ /* Remaining fields are known throughout compressor, but generally
+ * should not be touched by a surrounding application.
+ */
+
+ /*
+ * These fields are computed during compression startup
+ */
+ int progressive_mode; /* TRUE if scan script uses progressive mode */
+ int max_h_samp_factor; /* largest h_samp_factor */
+ int max_v_samp_factor; /* largest v_samp_factor */
+
+ JDIMENSION total_iMCU_rows; /* # of iMCU rows to be input to coef ctlr */
+ /* The coefficient controller receives data in units of MCU rows as defined
+ * for fully interleaved scans (whether the JPEG file is interleaved or not).
+ * There are v_samp_factor * DCTSIZE sample rows of each component in an
+ * "iMCU" (interleaved MCU) row.
+ */
+
+ /*
+ * These fields are valid during any one scan.
+ * They describe the components and MCUs actually appearing in the scan.
+ */
+ int comps_in_scan; /* # of JPEG components in this scan */
+ jpeg_component_info * cur_comp_info[MAX_COMPS_IN_SCAN];
+ /* *cur_comp_info[i] describes component that appears i'th in SOS */
+
+ JDIMENSION MCUs_per_row; /* # of MCUs across the image */
+ JDIMENSION MCU_rows_in_scan; /* # of MCU rows in the image */
+
+ int blocks_in_MCU; /* # of DCT blocks per MCU */
+ int MCU_membership[C_MAX_BLOCKS_IN_MCU];
+ /* MCU_membership[i] is index in cur_comp_info of component owning */
+ /* i'th block in an MCU */
+
+ int Ss, Se, Ah, Al; /* progressive JPEG parameters for scan */
+
+ /*
+ * Links to compression subobjects (methods and private variables of modules)
+ */
+ struct jpeg_comp_master * master;
+ struct jpeg_c_main_controller * main;
+ struct jpeg_c_prep_controller * prep;
+ struct jpeg_c_coef_controller * coef;
+ struct jpeg_marker_writer * marker;
+ struct jpeg_color_converter * cconvert;
+ struct jpeg_downsampler * downsample;
+ struct jpeg_forward_dct * fdct;
+ struct jpeg_entropy_encoder * entropy;
+ jpeg_scan_info * script_space; /* workspace for jpeg_simple_progression */
+ int script_space_size;
+};
+
+
+/* Master record for a decompression instance */
+
+struct jpeg_decompress_struct {
+ jpeg_common_fields; /* Fields shared with jpeg_compress_struct */
+
+ /* Source of compressed data */
+ struct jpeg_source_mgr * src;
+
+ /* Basic description of image --- filled in by jpeg_read_header(). */
+ /* Application may inspect these values to decide how to process image. */
+
+ JDIMENSION image_width; /* nominal image width (from SOF marker) */
+ JDIMENSION image_height; /* nominal image height */
+ int num_components; /* # of color components in JPEG image */
+ J_COLOR_SPACE jpeg_color_space; /* colorspace of JPEG image */
+
+ /* Decompression processing parameters --- these fields must be set before
+ * calling jpeg_start_decompress(). Note that jpeg_read_header() initializes
+ * them to default values.
+ */
+
+ J_COLOR_SPACE out_color_space; /* colorspace for output */
+
+ unsigned int scale_num, scale_denom; /* fraction by which to scale image */
+
+ double output_gamma; /* image gamma wanted in output */
+
+ int buffered_image; /* TRUE=multiple output passes */
+ int raw_data_out; /* TRUE=downsampled data wanted */
+
+ J_DCT_METHOD dct_method; /* IDCT algorithm selector */
+ int do_fancy_upsampling; /* TRUE=apply fancy upsampling */
+ int do_block_smoothing; /* TRUE=apply interblock smoothing */
+
+ int quantize_colors; /* TRUE=colormapped output wanted */
+ /* the following are ignored if not quantize_colors: */
+ J_DITHER_MODE dither_mode; /* type of color dithering to use */
+ int two_pass_quantize; /* TRUE=use two-pass color quantization */
+ int desired_number_of_colors; /* max # colors to use in created colormap */
+ /* these are significant only in buffered-image mode: */
+ int enable_1pass_quant; /* enable future use of 1-pass quantizer */
+ int enable_external_quant;/* enable future use of external colormap */
+ int enable_2pass_quant; /* enable future use of 2-pass quantizer */
+
+ /* Description of actual output image that will be returned to application.
+ * These fields are computed by jpeg_start_decompress().
+ * You can also use jpeg_calc_output_dimensions() to determine these values
+ * in advance of calling jpeg_start_decompress().
+ */
+
+ JDIMENSION output_width; /* scaled image width */
+ JDIMENSION output_height; /* scaled image height */
+ int out_color_components; /* # of color components in out_color_space */
+ int output_components; /* # of color components returned */
+ /* output_components is 1 (a colormap index) when quantizing colors;
+ * otherwise it equals out_color_components.
+ */
+ int rec_outbuf_height; /* min recommended height of scanline buffer */
+ /* If the buffer passed to jpeg_read_scanlines() is less than this many rows
+ * high, space and time will be wasted due to unnecessary data copying.
+ * Usually rec_outbuf_height will be 1 or 2, at most 4.
+ */
+
+ /* When quantizing colors, the output colormap is described by these fields.
+ * The application can supply a colormap by setting colormap non-NULL before
+ * calling jpeg_start_decompress; otherwise a colormap is created during
+ * jpeg_start_decompress or jpeg_start_output.
+ * The map has out_color_components rows and actual_number_of_colors columns.
+ */
+ int actual_number_of_colors; /* number of entries in use */
+ JSAMPARRAY colormap; /* The color map as a 2-D pixel array */
+
+ /* State variables: these variables indicate the progress of decompression.
+ * The application may examine these but must not modify them.
+ */
+
+ /* Row index of next scanline to be read from jpeg_read_scanlines().
+ * Application may use this to control its processing loop, e.g.,
+ * "while (output_scanline < output_height)".
+ */
+ JDIMENSION output_scanline; /* 0 .. output_height-1 */
+
+ /* Current input scan number and number of iMCU rows completed in scan.
+ * These indicate the progress of the decompressor input side.
+ */
+ int input_scan_number; /* Number of SOS markers seen so far */
+ JDIMENSION input_iMCU_row; /* Number of iMCU rows completed */
+
+ /* The "output scan number" is the notional scan being displayed by the
+ * output side. The decompressor will not allow output scan/row number
+ * to get ahead of input scan/row, but it can fall arbitrarily far behind.
+ */
+ int output_scan_number; /* Nominal scan number being displayed */
+ JDIMENSION output_iMCU_row; /* Number of iMCU rows read */
+
+ /* Current progression status. coef_bits[c][i] indicates the precision
+ * with which component c's DCT coefficient i (in zigzag order) is known.
+ * It is -1 when no data has yet been received, otherwise it is the point
+ * transform (shift) value for the most recent scan of the coefficient
+ * (thus, 0 at completion of the progression).
+ * This pointer is NULL when reading a non-progressive file.
+ */
+ int (*coef_bits)[DCTSIZE2]; /* -1 or current Al value for each coef */
+
+ /* Internal JPEG parameters --- the application usually need not look at
+ * these fields. Note that the decompressor output side may not use
+ * any parameters that can change between scans.
+ */
+
+ /* Quantization and Huffman tables are carried forward across input
+ * datastreams when processing abbreviated JPEG datastreams.
+ */
+
+ JQUANT_TBL * quant_tbl_ptrs[NUM_QUANT_TBLS];
+ /* ptrs to coefficient quantization tables, or NULL if not defined */
+
+ JHUFF_TBL * dc_huff_tbl_ptrs[NUM_HUFF_TBLS];
+ JHUFF_TBL * ac_huff_tbl_ptrs[NUM_HUFF_TBLS];
+ /* ptrs to Huffman coding tables, or NULL if not defined */
+
+ /* These parameters are never carried across datastreams, since they
+ * are given in SOF/SOS markers or defined to be reset by SOI.
+ */
+
+ int data_precision; /* bits of precision in image data */
+
+ jpeg_component_info * comp_info;
+ /* comp_info[i] describes component that appears i'th in SOF */
+
+ int progressive_mode; /* TRUE if SOFn specifies progressive mode */
+ int arith_code; /* TRUE=arithmetic coding, FALSE=Huffman */
+
+ unsigned char arith_dc_L[NUM_ARITH_TBLS]; /* L values for DC arith-coding tables */
+ unsigned char arith_dc_U[NUM_ARITH_TBLS]; /* U values for DC arith-coding tables */
+ unsigned char arith_ac_K[NUM_ARITH_TBLS]; /* Kx values for AC arith-coding tables */
+
+ unsigned int restart_interval; /* MCUs per restart interval, or 0 for no restart */
+
+ /* These fields record data obtained from optional markers recognized by
+ * the JPEG library.
+ */
+ int saw_JFIF_marker; /* TRUE iff a JFIF APP0 marker was found */
+ /* Data copied from JFIF marker; only valid if saw_JFIF_marker is TRUE: */
+ unsigned char JFIF_major_version; /* JFIF version number */
+ unsigned char JFIF_minor_version;
+ unsigned char density_unit; /* JFIF code for pixel size units */
+ unsigned short X_density; /* Horizontal pixel density */
+ unsigned short Y_density; /* Vertical pixel density */
+ int saw_Adobe_marker; /* TRUE iff an Adobe APP14 marker was found */
+ unsigned char Adobe_transform; /* Color transform code from Adobe marker */
+
+ int CCIR601_sampling; /* TRUE=first samples are cosited */
+
+ /* Aside from the specific data retained from APPn markers known to the
+ * library, the uninterpreted contents of any or all APPn and COM markers
+ * can be saved in a list for examination by the application.
+ */
+ jpeg_saved_marker_ptr marker_list; /* Head of list of saved markers */
+
+ /* Remaining fields are known throughout decompressor, but generally
+ * should not be touched by a surrounding application.
+ */
+
+ /*
+ * These fields are computed during decompression startup
+ */
+ int max_h_samp_factor; /* largest h_samp_factor */
+ int max_v_samp_factor; /* largest v_samp_factor */
+
+ int min_DCT_scaled_size; /* smallest DCT_scaled_size of any component */
+
+ JDIMENSION total_iMCU_rows; /* # of iMCU rows in image */
+ /* The coefficient controller's input and output progress is measured in
+ * units of "iMCU" (interleaved MCU) rows. These are the same as MCU rows
+ * in fully interleaved JPEG scans, but are used whether the scan is
+ * interleaved or not. We define an iMCU row as v_samp_factor DCT block
+ * rows of each component. Therefore, the IDCT output contains
+ * v_samp_factor*DCT_scaled_size sample rows of a component per iMCU row.
+ */
+
+ JSAMPLE * sample_range_limit; /* table for fast range-limiting */
+
+ /*
+ * These fields are valid during any one scan.
+ * They describe the components and MCUs actually appearing in the scan.
+ * Note that the decompressor output side must not use these fields.
+ */
+ int comps_in_scan; /* # of JPEG components in this scan */
+ jpeg_component_info * cur_comp_info[MAX_COMPS_IN_SCAN];
+ /* *cur_comp_info[i] describes component that appears i'th in SOS */
+
+ JDIMENSION MCUs_per_row; /* # of MCUs across the image */
+ JDIMENSION MCU_rows_in_scan; /* # of MCU rows in the image */
+
+ int blocks_in_MCU; /* # of DCT blocks per MCU */
+ int MCU_membership[D_MAX_BLOCKS_IN_MCU];
+ /* MCU_membership[i] is index in cur_comp_info of component owning */
+ /* i'th block in an MCU */
+
+ int Ss, Se, Ah, Al; /* progressive JPEG parameters for scan */
+
+ /* This field is shared between entropy decoder and marker parser.
+ * It is either zero or the code of a JPEG marker that has been
+ * read from the data source, but has not yet been processed.
+ */
+ int unread_marker;
+
+ /*
+ * Links to decompression subobjects (methods, private variables of modules)
+ */
+ struct jpeg_decomp_master * master;
+ struct jpeg_d_main_controller * main;
+ struct jpeg_d_coef_controller * coef;
+ struct jpeg_d_post_controller * post;
+ struct jpeg_input_controller * inputctl;
+ struct jpeg_marker_reader * marker;
+ struct jpeg_entropy_decoder * entropy;
+ struct jpeg_inverse_dct * idct;
+ struct jpeg_upsampler * upsample;
+ struct jpeg_color_deconverter * cconvert;
+ struct jpeg_color_quantizer * cquantize;
+};
+
+
+/* "Object" declarations for JPEG modules that may be supplied or called
+ * directly by the surrounding application.
+ * As with all objects in the JPEG library, these structs only define the
+ * publicly visible methods and state variables of a module. Additional
+ * private fields may exist after the public ones.
+ */
+
+
+/* Error handler object */
+
+struct jpeg_error_mgr {
+ /* Error exit handler: does not return to caller */
+ JMETHOD(void, error_exit, (j_common_ptr cinfo));
+ /* Conditionally emit a trace or warning message */
+ JMETHOD(void, emit_message, (j_common_ptr cinfo, int msg_level));
+ /* Routine that actually outputs a trace or error message */
+ JMETHOD(void, output_message, (j_common_ptr cinfo));
+ /* Format a message string for the most recent JPEG error or message */
+ JMETHOD(void, format_message, (j_common_ptr cinfo, char * buffer));
+#define JMSG_LENGTH_MAX 200 /* recommended size of format_message buffer */
+ /* Reset error state variables at start of a new image */
+ JMETHOD(void, reset_error_mgr, (j_common_ptr cinfo));
+
+ /* The message ID code and any parameters are saved here.
+ * A message can have one string parameter or up to 8 int parameters.
+ */
+ int msg_code;
+#define JMSG_STR_PARM_MAX 80
+ union {
+ int i[8];
+ char s[JMSG_STR_PARM_MAX];
+ } msg_parm;
+
+ /* Standard state variables for error facility */
+
+ int trace_level; /* max msg_level that will be displayed */
+
+ /* For recoverable corrupt-data errors, we emit a warning message,
+ * but keep going unless emit_message chooses to abort. emit_message
+ * should count warnings in num_warnings. The surrounding application
+ * can check for bad data by seeing if num_warnings is nonzero at the
+ * end of processing.
+ */
+ long num_warnings; /* number of corrupt-data warnings */
+
+ /* These fields point to the table(s) of error message strings.
+ * An application can change the table pointer to switch to a different
+ * message list (typically, to change the language in which errors are
+ * reported). Some applications may wish to add additional error codes
+ * that will be handled by the JPEG library error mechanism; the second
+ * table pointer is used for this purpose.
+ *
+ * First table includes all errors generated by JPEG library itself.
+ * Error code 0 is reserved for a "no such error string" message.
+ */
+ const char * const * jpeg_message_table; /* Library errors */
+ int last_jpeg_message; /* Table contains strings 0..last_jpeg_message */
+ /* Second table can be added by application (see cjpeg/djpeg for example).
+ * It contains strings numbered first_addon_message..last_addon_message.
+ */
+ const char * const * addon_message_table; /* Non-library errors */
+ int first_addon_message; /* code for first string in addon table */
+ int last_addon_message; /* code for last string in addon table */
+};
+
+
+/* Progress monitor object */
+
+struct jpeg_progress_mgr {
+ JMETHOD(void, progress_monitor, (j_common_ptr cinfo));
+
+ long pass_counter; /* work units completed in this pass */
+ long pass_limit; /* total number of work units in this pass */
+ int completed_passes; /* passes completed so far */
+ int total_passes; /* total number of passes expected */
+};
+
+
+/* Data destination object for compression */
+
+struct jpeg_destination_mgr {
+ JOCTET * next_output_byte; /* => next byte to write in buffer */
+ size_t free_in_buffer; /* # of byte spaces remaining in buffer */
+
+ JMETHOD(void, init_destination, (j_compress_ptr cinfo));
+ JMETHOD(int, empty_output_buffer, (j_compress_ptr cinfo));
+ JMETHOD(void, term_destination, (j_compress_ptr cinfo));
+};
+
+
+/* Data source object for decompression */
+
+struct jpeg_source_mgr {
+ const JOCTET * next_input_byte; /* => next byte to read from buffer */
+ size_t bytes_in_buffer; /* # of bytes remaining in buffer */
+
+ JMETHOD(void, init_source, (j_decompress_ptr cinfo));
+ JMETHOD(int, fill_input_buffer, (j_decompress_ptr cinfo));
+ JMETHOD(void, skip_input_data, (j_decompress_ptr cinfo, long num_bytes));
+ JMETHOD(int, resync_to_restart, (j_decompress_ptr cinfo, int desired));
+ JMETHOD(void, term_source, (j_decompress_ptr cinfo));
+};
+
+
+/* Memory manager object.
+ * Allocates "small" objects (a few K total), "large" objects (tens of K),
+ * and "really big" objects (virtual arrays with backing store if needed).
+ * The memory manager does not allow individual objects to be freed; rather,
+ * each created object is assigned to a pool, and whole pools can be freed
+ * at once. This is faster and more convenient than remembering exactly what
+ * to free, especially where malloc()/free() are not too speedy.
+ * NB: alloc routines never return NULL. They exit to error_exit if not
+ * successful.
+ */
+
+#define JPOOL_PERMANENT 0 /* lasts until master record is destroyed */
+#define JPOOL_IMAGE 1 /* lasts until done with image/datastream */
+#define JPOOL_NUMPOOLS 2
+
+typedef struct jvirt_sarray_control * jvirt_sarray_ptr;
+typedef struct jvirt_barray_control * jvirt_barray_ptr;
+
+
+struct jpeg_memory_mgr {
+ /* Method pointers */
+ JMETHOD(void *, alloc_small, (j_common_ptr cinfo, int pool_id,
+ size_t sizeofobject));
+ JMETHOD(void FAR *, alloc_large, (j_common_ptr cinfo, int pool_id,
+ size_t sizeofobject));
+ JMETHOD(JSAMPARRAY, alloc_sarray, (j_common_ptr cinfo, int pool_id,
+ JDIMENSION samplesperrow,
+ JDIMENSION numrows));
+ JMETHOD(JBLOCKARRAY, alloc_barray, (j_common_ptr cinfo, int pool_id,
+ JDIMENSION blocksperrow,
+ JDIMENSION numrows));
+ JMETHOD(jvirt_sarray_ptr, request_virt_sarray, (j_common_ptr cinfo,
+ int pool_id,
+ int pre_zero,
+ JDIMENSION samplesperrow,
+ JDIMENSION numrows,
+ JDIMENSION maxaccess));
+ JMETHOD(jvirt_barray_ptr, request_virt_barray, (j_common_ptr cinfo,
+ int pool_id,
+ int pre_zero,
+ JDIMENSION blocksperrow,
+ JDIMENSION numrows,
+ JDIMENSION maxaccess));
+ JMETHOD(void, realize_virt_arrays, (j_common_ptr cinfo));
+ JMETHOD(JSAMPARRAY, access_virt_sarray, (j_common_ptr cinfo,
+ jvirt_sarray_ptr ptr,
+ JDIMENSION start_row,
+ JDIMENSION num_rows,
+ int writable));
+ JMETHOD(JBLOCKARRAY, access_virt_barray, (j_common_ptr cinfo,
+ jvirt_barray_ptr ptr,
+ JDIMENSION start_row,
+ JDIMENSION num_rows,
+ int writable));
+ JMETHOD(void, free_pool, (j_common_ptr cinfo, int pool_id));
+ JMETHOD(void, self_destruct, (j_common_ptr cinfo));
+
+ /* Limit on memory allocation for this JPEG object. (Note that this is
+ * merely advisory, not a guaranteed maximum; it only affects the space
+ * used for virtual-array buffers.) May be changed by outer application
+ * after creating the JPEG object.
+ */
+ long max_memory_to_use;
+
+ /* Maximum allocation request accepted by alloc_large. */
+ long max_alloc_chunk;
+};
+
+
+/* Routine signature for application-supplied marker processing methods.
+ * Need not pass marker code since it is stored in cinfo->unread_marker.
+ */
+typedef JMETHOD(int, jpeg_marker_parser_method, (j_decompress_ptr cinfo));
+
+
+/* Declarations for routines called by application.
+ * The JPP macro hides prototype parameters from compilers that can't cope.
+ * Note JPP requires double parentheses.
+ */
+
+#ifdef HAVE_PROTOTYPES
+#define JPP(arglist) arglist
+#else
+#define JPP(arglist) ()
+#endif
+
+
+/* Short forms of external names for systems with brain-damaged linkers.
+ * We shorten external names to be unique in the first six letters, which
+ * is good enough for all known systems.
+ * (If your compiler itself needs names to be unique in less than 15
+ * characters, you are out of luck. Get a better compiler.)
+ */
+
+#ifdef NEED_SHORT_EXTERNAL_NAMES
+#define jpeg_std_error jStdError
+#define jpeg_CreateCompress jCreaCompress
+#define jpeg_CreateDecompress jCreaDecompress
+#define jpeg_destroy_compress jDestCompress
+#define jpeg_destroy_decompress jDestDecompress
+#define jpeg_stdio_dest jStdDest
+#define jpeg_stdio_src jStdSrc
+#define jpeg_set_defaults jSetDefaults
+#define jpeg_set_colorspace jSetColorspace
+#define jpeg_default_colorspace jDefColorspace
+#define jpeg_set_quality jSetQuality
+#define jpeg_set_linear_quality jSetLQuality
+#define jpeg_add_quant_table jAddQuantTable
+#define jpeg_quality_scaling jQualityScaling
+#define jpeg_simple_progression jSimProgress
+#define jpeg_suppress_tables jSuppressTables
+#define jpeg_alloc_quant_table jAlcQTable
+#define jpeg_alloc_huff_table jAlcHTable
+#define jpeg_start_compress jStrtCompress
+#define jpeg_write_scanlines jWrtScanlines
+#define jpeg_finish_compress jFinCompress
+#define jpeg_write_raw_data jWrtRawData
+#define jpeg_write_marker jWrtMarker
+#define jpeg_write_m_header jWrtMHeader
+#define jpeg_write_m_byte jWrtMByte
+#define jpeg_write_tables jWrtTables
+#define jpeg_read_header jReadHeader
+#define jpeg_start_decompress jStrtDecompress
+#define jpeg_read_scanlines jReadScanlines
+#define jpeg_finish_decompress jFinDecompress
+#define jpeg_read_raw_data jReadRawData
+#define jpeg_has_multiple_scans jHasMultScn
+#define jpeg_start_output jStrtOutput
+#define jpeg_finish_output jFinOutput
+#define jpeg_input_complete jInComplete
+#define jpeg_new_colormap jNewCMap
+#define jpeg_consume_input jConsumeInput
+#define jpeg_calc_output_dimensions jCalcDimensions
+#define jpeg_save_markers jSaveMarkers
+#define jpeg_set_marker_processor jSetMarker
+#define jpeg_read_coefficients jReadCoefs
+#define jpeg_write_coefficients jWrtCoefs
+#define jpeg_copy_critical_parameters jCopyCrit
+#define jpeg_abort_compress jAbrtCompress
+#define jpeg_abort_decompress jAbrtDecompress
+#define jpeg_abort jAbort
+#define jpeg_destroy jDestroy
+#define jpeg_resync_to_restart jResyncRestart
+#endif /* NEED_SHORT_EXTERNAL_NAMES */
+
+
+/* Default error-management setup */
+EXTERN(struct jpeg_error_mgr *) jpeg_std_error
+ JPP((struct jpeg_error_mgr * err));
+
+/* Initialization of JPEG compression objects.
+ * jpeg_create_compress() and jpeg_create_decompress() are the exported
+ * names that applications should call. These expand to calls on
+ * jpeg_CreateCompress and jpeg_CreateDecompress with additional information
+ * passed for version mismatch checking.
+ * NB: you must set up the error-manager BEFORE calling jpeg_create_xxx.
+ */
+#define jpeg_create_compress(cinfo) \
+ jpeg_CreateCompress((cinfo), JPEG_LIB_VERSION, \
+ (size_t) sizeof(struct jpeg_compress_struct))
+#define jpeg_create_decompress(cinfo) \
+ jpeg_CreateDecompress((cinfo), JPEG_LIB_VERSION, \
+ (size_t) sizeof(struct jpeg_decompress_struct))
+EXTERN(void) jpeg_CreateCompress JPP((j_compress_ptr cinfo,
+ int version, size_t structsize));
+EXTERN(void) jpeg_CreateDecompress JPP((j_decompress_ptr cinfo,
+ int version, size_t structsize));
+/* Destruction of JPEG compression objects */
+EXTERN(void) jpeg_destroy_compress JPP((j_compress_ptr cinfo));
+EXTERN(void) jpeg_destroy_decompress JPP((j_decompress_ptr cinfo));
+
+/* Standard data source and destination managers: stdio streams. */
+/* Caller is responsible for opening the file before and closing after. */
+EXTERN(void) jpeg_stdio_dest JPP((j_compress_ptr cinfo, FILE * outfile));
+EXTERN(void) jpeg_stdio_src JPP((j_decompress_ptr cinfo, FILE * infile));
+
+/* Default parameter setup for compression */
+EXTERN(void) jpeg_set_defaults JPP((j_compress_ptr cinfo));
+/* Compression parameter setup aids */
+EXTERN(void) jpeg_set_colorspace JPP((j_compress_ptr cinfo,
+ J_COLOR_SPACE colorspace));
+EXTERN(void) jpeg_default_colorspace JPP((j_compress_ptr cinfo));
+EXTERN(void) jpeg_set_quality JPP((j_compress_ptr cinfo, int quality,
+ int force_baseline));
+EXTERN(void) jpeg_set_linear_quality JPP((j_compress_ptr cinfo,
+ int scale_factor,
+ int force_baseline));
+EXTERN(void) jpeg_add_quant_table JPP((j_compress_ptr cinfo, int which_tbl,
+ const unsigned int *basic_table,
+ int scale_factor,
+ int force_baseline));
+EXTERN(int) jpeg_quality_scaling JPP((int quality));
+EXTERN(void) jpeg_simple_progression JPP((j_compress_ptr cinfo));
+EXTERN(void) jpeg_suppress_tables JPP((j_compress_ptr cinfo,
+ int suppress));
+EXTERN(JQUANT_TBL *) jpeg_alloc_quant_table JPP((j_common_ptr cinfo));
+EXTERN(JHUFF_TBL *) jpeg_alloc_huff_table JPP((j_common_ptr cinfo));
+
+/* Main entry points for compression */
+EXTERN(void) jpeg_start_compress JPP((j_compress_ptr cinfo,
+ int write_all_tables));
+EXTERN(JDIMENSION) jpeg_write_scanlines JPP((j_compress_ptr cinfo,
+ JSAMPARRAY scanlines,
+ JDIMENSION num_lines));
+EXTERN(void) jpeg_finish_compress JPP((j_compress_ptr cinfo));
+
+/* Replaces jpeg_write_scanlines when writing raw downsampled data. */
+EXTERN(JDIMENSION) jpeg_write_raw_data JPP((j_compress_ptr cinfo,
+ JSAMPIMAGE data,
+ JDIMENSION num_lines));
+
+/* Write a special marker. See libjpeg.doc concerning safe usage. */
+EXTERN(void) jpeg_write_marker
+ JPP((j_compress_ptr cinfo, int marker,
+ const JOCTET * dataptr, unsigned int datalen));
+/* Same, but piecemeal. */
+EXTERN(void) jpeg_write_m_header
+ JPP((j_compress_ptr cinfo, int marker, unsigned int datalen));
+EXTERN(void) jpeg_write_m_byte
+ JPP((j_compress_ptr cinfo, int val));
+
+/* Alternate compression function: just write an abbreviated table file */
+EXTERN(void) jpeg_write_tables JPP((j_compress_ptr cinfo));
+
+/* Decompression startup: read start of JPEG datastream to see what's there */
+EXTERN(int) jpeg_read_header JPP((j_decompress_ptr cinfo,
+ int require_image));
+/* Return value is one of: */
+#define JPEG_SUSPENDED 0 /* Suspended due to lack of input data */
+#define JPEG_HEADER_OK 1 /* Found valid image datastream */
+#define JPEG_HEADER_TABLES_ONLY 2 /* Found valid table-specs-only datastream */
+/* If you pass require_image = TRUE (normal case), you need not check for
+ * a TABLES_ONLY return code; an abbreviated file will cause an error exit.
+ * JPEG_SUSPENDED is only possible if you use a data source module that can
+ * give a suspension return (the stdio source module doesn't).
+ */
+
+/* Main entry points for decompression */
+EXTERN(int) jpeg_start_decompress JPP((j_decompress_ptr cinfo));
+EXTERN(JDIMENSION) jpeg_read_scanlines JPP((j_decompress_ptr cinfo,
+ JSAMPARRAY scanlines,
+ JDIMENSION max_lines));
+EXTERN(int) jpeg_finish_decompress JPP((j_decompress_ptr cinfo));
+
+/* Replaces jpeg_read_scanlines when reading raw downsampled data. */
+EXTERN(JDIMENSION) jpeg_read_raw_data JPP((j_decompress_ptr cinfo,
+ JSAMPIMAGE data,
+ JDIMENSION max_lines));
+
+/* Additional entry points for buffered-image mode. */
+EXTERN(int) jpeg_has_multiple_scans JPP((j_decompress_ptr cinfo));
+EXTERN(int) jpeg_start_output JPP((j_decompress_ptr cinfo,
+ int scan_number));
+EXTERN(int) jpeg_finish_output JPP((j_decompress_ptr cinfo));
+EXTERN(int) jpeg_input_complete JPP((j_decompress_ptr cinfo));
+EXTERN(void) jpeg_new_colormap JPP((j_decompress_ptr cinfo));
+EXTERN(int) jpeg_consume_input JPP((j_decompress_ptr cinfo));
+/* Return value is one of: */
+/* #define JPEG_SUSPENDED 0 Suspended due to lack of input data */
+#define JPEG_REACHED_SOS 1 /* Reached start of new scan */
+#define JPEG_REACHED_EOI 2 /* Reached end of image */
+#define JPEG_ROW_COMPLETED 3 /* Completed one iMCU row */
+#define JPEG_SCAN_COMPLETED 4 /* Completed last iMCU row of a scan */
+
+/* Precalculate output dimensions for current decompression parameters. */
+EXTERN(void) jpeg_calc_output_dimensions JPP((j_decompress_ptr cinfo));
+
+/* Control saving of COM and APPn markers into marker_list. */
+EXTERN(void) jpeg_save_markers
+ JPP((j_decompress_ptr cinfo, int marker_code,
+ unsigned int length_limit));
+
+/* Install a special processing method for COM or APPn markers. */
+EXTERN(void) jpeg_set_marker_processor
+ JPP((j_decompress_ptr cinfo, int marker_code,
+ jpeg_marker_parser_method routine));
+
+/* Read or write raw DCT coefficients --- useful for lossless transcoding. */
+EXTERN(jvirt_barray_ptr *) jpeg_read_coefficients JPP((j_decompress_ptr cinfo));
+EXTERN(void) jpeg_write_coefficients JPP((j_compress_ptr cinfo,
+ jvirt_barray_ptr * coef_arrays));
+EXTERN(void) jpeg_copy_critical_parameters JPP((j_decompress_ptr srcinfo,
+ j_compress_ptr dstinfo));
+
+/* If you choose to abort compression or decompression before completing
+ * jpeg_finish_(de)compress, then you need to clean up to release memory,
+ * temporary files, etc. You can just call jpeg_destroy_(de)compress
+ * if you're done with the JPEG object, but if you want to clean it up and
+ * reuse it, call this:
+ */
+EXTERN(void) jpeg_abort_compress JPP((j_compress_ptr cinfo));
+EXTERN(void) jpeg_abort_decompress JPP((j_decompress_ptr cinfo));
+
+/* Generic versions of jpeg_abort and jpeg_destroy that work on either
+ * flavor of JPEG object. These may be more convenient in some places.
+ */
+EXTERN(void) jpeg_abort JPP((j_common_ptr cinfo));
+EXTERN(void) jpeg_destroy JPP((j_common_ptr cinfo));
+
+/* Default restart-marker-resync procedure for use by data source modules */
+EXTERN(int) jpeg_resync_to_restart JPP((j_decompress_ptr cinfo,
+ int desired));
+
+
+/* These marker codes are exported since applications and data source modules
+ * are likely to want to use them.
+ */
+
+#define JPEG_RST0 0xD0 /* RST0 marker code */
+#define JPEG_EOI 0xD9 /* EOI marker code */
+#define JPEG_APP0 0xE0 /* APP0 marker code */
+#define JPEG_COM 0xFE /* COM marker code */
+
+
+/* If we have a brain-damaged compiler that emits warnings (or worse, errors)
+ * for structure definitions that are never filled in, keep it quiet by
+ * supplying dummy definitions for the various substructures.
+ */
+
+#ifdef INCOMPLETE_TYPES_BROKEN
+#ifndef JPEG_INTERNALS /* will be defined in jpegint.h */
+struct jvirt_sarray_control { long dummy; };
+struct jvirt_barray_control { long dummy; };
+struct jpeg_comp_master { long dummy; };
+struct jpeg_c_main_controller { long dummy; };
+struct jpeg_c_prep_controller { long dummy; };
+struct jpeg_c_coef_controller { long dummy; };
+struct jpeg_marker_writer { long dummy; };
+struct jpeg_color_converter { long dummy; };
+struct jpeg_downsampler { long dummy; };
+struct jpeg_forward_dct { long dummy; };
+struct jpeg_entropy_encoder { long dummy; };
+struct jpeg_decomp_master { long dummy; };
+struct jpeg_d_main_controller { long dummy; };
+struct jpeg_d_coef_controller { long dummy; };
+struct jpeg_d_post_controller { long dummy; };
+struct jpeg_input_controller { long dummy; };
+struct jpeg_marker_reader { long dummy; };
+struct jpeg_entropy_decoder { long dummy; };
+struct jpeg_inverse_dct { long dummy; };
+struct jpeg_upsampler { long dummy; };
+struct jpeg_color_deconverter { long dummy; };
+struct jpeg_color_quantizer { long dummy; };
+#endif /* JPEG_INTERNALS */
+#endif /* INCOMPLETE_TYPES_BROKEN */
+
+
+/*
+ * The JPEG library modules define JPEG_INTERNALS before including this file.
+ * The internal structure declarations are read only when that is true.
+ * Applications using the library should not include jpegint.h, but may wish
+ * to include jerror.h.
+ */
+
+#ifdef JPEG_INTERNALS
+#include "jpegint.h" /* fetch private declarations */
+#include "jerror.h" /* fetch error codes too */
+#endif
+
+#endif /* JPEGLIB_H */
diff --git a/ml/dlib/dlib/external/libjpeg/jquant1.cpp b/ml/dlib/dlib/external/libjpeg/jquant1.cpp
new file mode 100644
index 000000000..7582015ad
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jquant1.cpp
@@ -0,0 +1,856 @@
+/*
+ * jquant1.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains 1-pass color quantization (color mapping) routines.
+ * These routines provide mapping to a fixed color map using equally spaced
+ * color values. Optional Floyd-Steinberg or ordered dithering is available.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+#ifdef QUANT_1PASS_SUPPORTED
+
+
+/*
+ * The main purpose of 1-pass quantization is to provide a fast, if not very
+ * high quality, colormapped output capability. A 2-pass quantizer usually
+ * gives better visual quality; however, for quantized grayscale output this
+ * quantizer is perfectly adequate. Dithering is highly recommended with this
+ * quantizer, though you can turn it off if you really want to.
+ *
+ * In 1-pass quantization the colormap must be chosen in advance of seeing the
+ * image. We use a map consisting of all combinations of Ncolors[i] color
+ * values for the i'th component. The Ncolors[] values are chosen so that
+ * their product, the total number of colors, is no more than that requested.
+ * (In most cases, the product will be somewhat less.)
+ *
+ * Since the colormap is orthogonal, the representative value for each color
+ * component can be determined without considering the other components;
+ * then these indexes can be combined into a colormap index by a standard
+ * N-dimensional-array-subscript calculation. Most of the arithmetic involved
+ * can be precalculated and stored in the lookup table colorindex[].
+ * colorindex[i][j] maps pixel value j in component i to the nearest
+ * representative value (grid plane) for that component; this index is
+ * multiplied by the array stride for component i, so that the
+ * index of the colormap entry closest to a given pixel value is just
+ * sum( colorindex[component-number][pixel-component-value] )
+ * Aside from being fast, this scheme allows for variable spacing between
+ * representative values with no additional lookup cost.
+ *
+ * If gamma correction has been applied in color conversion, it might be wise
+ * to adjust the color grid spacing so that the representative colors are
+ * equidistant in linear space. At this writing, gamma correction is not
+ * implemented by jdcolor, so nothing is done here.
+ */
+
+
+/* Declarations for ordered dithering.
+ *
+ * We use a standard 16x16 ordered dither array. The basic concept of ordered
+ * dithering is described in many references, for instance Dale Schumacher's
+ * chapter II.2 of Graphics Gems II (James Arvo, ed. Academic Press, 1991).
+ * In place of Schumacher's comparisons against a "threshold" value, we add a
+ * "dither" value to the input pixel and then round the result to the nearest
+ * output value. The dither value is equivalent to (0.5 - threshold) times
+ * the distance between output values. For ordered dithering, we assume that
+ * the output colors are equally spaced; if not, results will probably be
+ * worse, since the dither may be too much or too little at a given point.
+ *
+ * The normal calculation would be to form pixel value + dither, range-limit
+ * this to 0..MAXJSAMPLE, and then index into the colorindex table as usual.
+ * We can skip the separate range-limiting step by extending the colorindex
+ * table in both directions.
+ */
+
+#define ODITHER_SIZE 16 /* dimension of dither matrix */
+/* NB: if ODITHER_SIZE is not a power of 2, ODITHER_MASK uses will break */
+#define ODITHER_CELLS (ODITHER_SIZE*ODITHER_SIZE) /* # cells in matrix */
+#define ODITHER_MASK (ODITHER_SIZE-1) /* mask for wrapping around counters */
+
+typedef int ODITHER_MATRIX[ODITHER_SIZE][ODITHER_SIZE];
+typedef int (*ODITHER_MATRIX_PTR)[ODITHER_SIZE];
+
+static const unsigned char base_dither_matrix[ODITHER_SIZE][ODITHER_SIZE] = {
+ /* Bayer's order-4 dither array. Generated by the code given in
+ * Stephen Hawley's article "Ordered Dithering" in Graphics Gems I.
+ * The values in this array must range from 0 to ODITHER_CELLS-1.
+ */
+ { 0,192, 48,240, 12,204, 60,252, 3,195, 51,243, 15,207, 63,255 },
+ { 128, 64,176,112,140, 76,188,124,131, 67,179,115,143, 79,191,127 },
+ { 32,224, 16,208, 44,236, 28,220, 35,227, 19,211, 47,239, 31,223 },
+ { 160, 96,144, 80,172,108,156, 92,163, 99,147, 83,175,111,159, 95 },
+ { 8,200, 56,248, 4,196, 52,244, 11,203, 59,251, 7,199, 55,247 },
+ { 136, 72,184,120,132, 68,180,116,139, 75,187,123,135, 71,183,119 },
+ { 40,232, 24,216, 36,228, 20,212, 43,235, 27,219, 39,231, 23,215 },
+ { 168,104,152, 88,164,100,148, 84,171,107,155, 91,167,103,151, 87 },
+ { 2,194, 50,242, 14,206, 62,254, 1,193, 49,241, 13,205, 61,253 },
+ { 130, 66,178,114,142, 78,190,126,129, 65,177,113,141, 77,189,125 },
+ { 34,226, 18,210, 46,238, 30,222, 33,225, 17,209, 45,237, 29,221 },
+ { 162, 98,146, 82,174,110,158, 94,161, 97,145, 81,173,109,157, 93 },
+ { 10,202, 58,250, 6,198, 54,246, 9,201, 57,249, 5,197, 53,245 },
+ { 138, 74,186,122,134, 70,182,118,137, 73,185,121,133, 69,181,117 },
+ { 42,234, 26,218, 38,230, 22,214, 41,233, 25,217, 37,229, 21,213 },
+ { 170,106,154, 90,166,102,150, 86,169,105,153, 89,165,101,149, 85 }
+};
+
+
+/* Declarations for Floyd-Steinberg dithering.
+ *
+ * Errors are accumulated into the array fserrors[], at a resolution of
+ * 1/16th of a pixel count. The error at a given pixel is propagated
+ * to its not-yet-processed neighbors using the standard F-S fractions,
+ * ... (here) 7/16
+ * 3/16 5/16 1/16
+ * We work left-to-right on even rows, right-to-left on odd rows.
+ *
+ * We can get away with a single array (holding one row's worth of errors)
+ * by using it to store the current row's errors at pixel columns not yet
+ * processed, but the next row's errors at columns already processed. We
+ * need only a few extra variables to hold the errors immediately around the
+ * current column. (If we are lucky, those variables are in registers, but
+ * even if not, they're probably cheaper to access than array elements are.)
+ *
+ * The fserrors[] array is indexed [component#][position].
+ * We provide (#columns + 2) entries per component; the extra entry at each
+ * end saves us from special-casing the first and last pixels.
+ *
+ * Note: on a wide image, we might not have enough room in a PC's near data
+ * segment to hold the error array; so it is allocated with alloc_large.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+typedef short FSERROR; /* 16 bits should be enough */
+typedef int LOCFSERROR; /* use 'int' for calculation temps */
+#else
+typedef long FSERROR; /* may need more than 16 bits */
+typedef long LOCFSERROR; /* be sure calculation temps are big enough */
+#endif
+
+typedef FSERROR FAR *FSERRPTR; /* pointer to error array (in FAR storage!) */
+
+
+/* Private subobject */
+
+#define MAX_Q_COMPS 4 /* max components I can handle */
+
+typedef struct {
+ struct jpeg_color_quantizer pub; /* public fields */
+
+ /* Initially allocated colormap is saved here */
+ JSAMPARRAY sv_colormap; /* The color map as a 2-D pixel array */
+ int sv_actual; /* number of entries in use */
+
+ JSAMPARRAY colorindex; /* Precomputed mapping for speed */
+ /* colorindex[i][j] = index of color closest to pixel value j in component i,
+ * premultiplied as described above. Since colormap indexes must fit into
+ * JSAMPLEs, the entries of this array will too.
+ */
+ int is_padded; /* is the colorindex padded for odither? */
+
+ int Ncolors[MAX_Q_COMPS]; /* # of values alloced to each component */
+
+ /* Variables for ordered dithering */
+ int row_index; /* cur row's vertical index in dither matrix */
+ ODITHER_MATRIX_PTR odither[MAX_Q_COMPS]; /* one dither array per component */
+
+ /* Variables for Floyd-Steinberg dithering */
+ FSERRPTR fserrors[MAX_Q_COMPS]; /* accumulated errors */
+ int on_odd_row; /* flag to remember which row we are on */
+} my_cquantizer;
+
+typedef my_cquantizer * my_cquantize_ptr;
+
+
+/*
+ * Policy-making subroutines for create_colormap and create_colorindex.
+ * These routines determine the colormap to be used. The rest of the module
+ * only assumes that the colormap is orthogonal.
+ *
+ * * select_ncolors decides how to divvy up the available colors
+ * among the components.
+ * * output_value defines the set of representative values for a component.
+ * * largest_input_value defines the mapping from input values to
+ * representative values for a component.
+ * Note that the latter two routines may impose different policies for
+ * different components, though this is not currently done.
+ */
+
+
+LOCAL(int)
+select_ncolors (j_decompress_ptr cinfo, int Ncolors[])
+/* Determine allocation of desired colors to components, */
+/* and fill in Ncolors[] array to indicate choice. */
+/* Return value is total number of colors (product of Ncolors[] values). */
+{
+ int nc = cinfo->out_color_components; /* number of color components */
+ int max_colors = cinfo->desired_number_of_colors;
+ int total_colors, iroot, i, j;
+ int changed;
+ long temp;
+ static const int RGB_order[3] = { RGB_GREEN, RGB_RED, RGB_BLUE };
+
+ /* We can allocate at least the nc'th root of max_colors per component. */
+ /* Compute floor(nc'th root of max_colors). */
+ iroot = 1;
+ do {
+ iroot++;
+ temp = iroot; /* set temp = iroot ** nc */
+ for (i = 1; i < nc; i++)
+ temp *= iroot;
+ } while (temp <= (long) max_colors); /* repeat till iroot exceeds root */
+ iroot--; /* now iroot = floor(root) */
+
+ /* Must have at least 2 color values per component */
+ if (iroot < 2)
+ ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, (int) temp);
+
+ /* Initialize to iroot color values for each component */
+ total_colors = 1;
+ for (i = 0; i < nc; i++) {
+ Ncolors[i] = iroot;
+ total_colors *= iroot;
+ }
+ /* We may be able to increment the count for one or more components without
+ * exceeding max_colors, though we know not all can be incremented.
+ * Sometimes, the first component can be incremented more than once!
+ * (Example: for 16 colors, we start at 2*2*2, go to 3*2*2, then 4*2*2.)
+ * In RGB colorspace, try to increment G first, then R, then B.
+ */
+ do {
+ changed = FALSE;
+ for (i = 0; i < nc; i++) {
+ j = (cinfo->out_color_space == JCS_RGB ? RGB_order[i] : i);
+ /* calculate new total_colors if Ncolors[j] is incremented */
+ temp = total_colors / Ncolors[j];
+ temp *= Ncolors[j]+1; /* done in long arith to avoid oflo */
+ if (temp > (long) max_colors)
+ break; /* won't fit, done with this pass */
+ Ncolors[j]++; /* OK, apply the increment */
+ total_colors = (int) temp;
+ changed = TRUE;
+ }
+ } while (changed);
+
+ return total_colors;
+}
+
+
+LOCAL(int)
+output_value (j_decompress_ptr , int , int j, int maxj)
+/* Return j'th output value, where j will range from 0 to maxj */
+/* The output values must fall in 0..MAXJSAMPLE in increasing order */
+{
+ /* We always provide values 0 and MAXJSAMPLE for each component;
+ * any additional values are equally spaced between these limits.
+ * (Forcing the upper and lower values to the limits ensures that
+ * dithering can't produce a color outside the selected gamut.)
+ */
+ return (int) (((long) j * MAXJSAMPLE + maxj/2) / maxj);
+}
+
+
+LOCAL(int)
+largest_input_value (j_decompress_ptr , int , int j, int maxj)
+/* Return largest input value that should map to j'th output value */
+/* Must have largest(j=0) >= 0, and largest(j=maxj) >= MAXJSAMPLE */
+{
+ /* Breakpoints are halfway between values returned by output_value */
+ return (int) (((long) (2*j + 1) * MAXJSAMPLE + maxj) / (2*maxj));
+}
+
+
+/*
+ * Create the colormap.
+ */
+
+LOCAL(void)
+create_colormap (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ JSAMPARRAY colormap; /* Created colormap */
+ int total_colors; /* Number of distinct output colors */
+ int i,j,k, nci, blksize, blkdist, ptr, val;
+
+ /* Select number of colors for each component */
+ total_colors = select_ncolors(cinfo, cquantize->Ncolors);
+
+ /* Report selected color counts */
+ if (cinfo->out_color_components == 3)
+ TRACEMS4(cinfo, 1, JTRC_QUANT_3_NCOLORS,
+ total_colors, cquantize->Ncolors[0],
+ cquantize->Ncolors[1], cquantize->Ncolors[2]);
+ else
+ TRACEMS1(cinfo, 1, JTRC_QUANT_NCOLORS, total_colors);
+
+ /* Allocate and fill in the colormap. */
+ /* The colors are ordered in the map in standard row-major order, */
+ /* i.e. rightmost (highest-indexed) color changes most rapidly. */
+
+ colormap = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (JDIMENSION) total_colors, (JDIMENSION) cinfo->out_color_components);
+
+ /* blksize is number of adjacent repeated entries for a component */
+ /* blkdist is distance between groups of identical entries for a component */
+ blkdist = total_colors;
+
+ for (i = 0; i < cinfo->out_color_components; i++) {
+ /* fill in colormap entries for i'th color component */
+ nci = cquantize->Ncolors[i]; /* # of distinct values for this color */
+ blksize = blkdist / nci;
+ for (j = 0; j < nci; j++) {
+ /* Compute j'th output value (out of nci) for component */
+ val = output_value(cinfo, i, j, nci-1);
+ /* Fill in all colormap entries that have this value of this component */
+ for (ptr = j * blksize; ptr < total_colors; ptr += blkdist) {
+ /* fill in blksize entries beginning at ptr */
+ for (k = 0; k < blksize; k++)
+ colormap[i][ptr+k] = (JSAMPLE) val;
+ }
+ }
+ blkdist = blksize; /* blksize of this color is blkdist of next */
+ }
+
+ /* Save the colormap in private storage,
+ * where it will survive color quantization mode changes.
+ */
+ cquantize->sv_colormap = colormap;
+ cquantize->sv_actual = total_colors;
+}
+
+
+/*
+ * Create the color index table.
+ */
+
+LOCAL(void)
+create_colorindex (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ JSAMPROW indexptr;
+ int i,j,k, nci, blksize, val, pad;
+
+ /* For ordered dither, we pad the color index tables by MAXJSAMPLE in
+ * each direction (input index values can be -MAXJSAMPLE .. 2*MAXJSAMPLE).
+ * This is not necessary in the other dithering modes. However, we
+ * flag whether it was done in case user changes dithering mode.
+ */
+ if (cinfo->dither_mode == JDITHER_ORDERED) {
+ pad = MAXJSAMPLE*2;
+ cquantize->is_padded = TRUE;
+ } else {
+ pad = 0;
+ cquantize->is_padded = FALSE;
+ }
+
+ cquantize->colorindex = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (JDIMENSION) (MAXJSAMPLE+1 + pad),
+ (JDIMENSION) cinfo->out_color_components);
+
+ /* blksize is number of adjacent repeated entries for a component */
+ blksize = cquantize->sv_actual;
+
+ for (i = 0; i < cinfo->out_color_components; i++) {
+ /* fill in colorindex entries for i'th color component */
+ nci = cquantize->Ncolors[i]; /* # of distinct values for this color */
+ blksize = blksize / nci;
+
+ /* adjust colorindex pointers to provide padding at negative indexes. */
+ if (pad)
+ cquantize->colorindex[i] += MAXJSAMPLE;
+
+ /* in loop, val = index of current output value, */
+ /* and k = largest j that maps to current val */
+ indexptr = cquantize->colorindex[i];
+ val = 0;
+ k = largest_input_value(cinfo, i, 0, nci-1);
+ for (j = 0; j <= MAXJSAMPLE; j++) {
+ while (j > k) /* advance val if past boundary */
+ k = largest_input_value(cinfo, i, ++val, nci-1);
+ /* premultiply so that no multiplication needed in main processing */
+ indexptr[j] = (JSAMPLE) (val * blksize);
+ }
+ /* Pad at both ends if necessary */
+ if (pad)
+ for (j = 1; j <= MAXJSAMPLE; j++) {
+ indexptr[-j] = indexptr[0];
+ indexptr[MAXJSAMPLE+j] = indexptr[MAXJSAMPLE];
+ }
+ }
+}
+
+
+/*
+ * Create an ordered-dither array for a component having ncolors
+ * distinct output values.
+ */
+
+LOCAL(ODITHER_MATRIX_PTR)
+make_odither_array (j_decompress_ptr cinfo, int ncolors)
+{
+ ODITHER_MATRIX_PTR odither;
+ int j,k;
+ long num,den;
+
+ odither = (ODITHER_MATRIX_PTR)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(ODITHER_MATRIX));
+ /* The inter-value distance for this color is MAXJSAMPLE/(ncolors-1).
+ * Hence the dither value for the matrix cell with fill order f
+ * (f=0..N-1) should be (N-1-2*f)/(2*N) * MAXJSAMPLE/(ncolors-1).
+ * On 16-bit-int machine, be careful to avoid overflow.
+ */
+ den = 2 * ODITHER_CELLS * ((long) (ncolors - 1));
+ for (j = 0; j < ODITHER_SIZE; j++) {
+ for (k = 0; k < ODITHER_SIZE; k++) {
+ num = ((long) (ODITHER_CELLS-1 - 2*((int)base_dither_matrix[j][k])))
+ * MAXJSAMPLE;
+ /* Ensure round towards zero despite C's lack of consistency
+ * about rounding negative values in integer division...
+ */
+ odither[j][k] = (int) (num<0 ? -((-num)/den) : num/den);
+ }
+ }
+ return odither;
+}
+
+
+/*
+ * Create the ordered-dither tables.
+ * Components having the same number of representative colors may
+ * share a dither table.
+ */
+
+LOCAL(void)
+create_odither_tables (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ ODITHER_MATRIX_PTR odither;
+ int i, j, nci;
+
+ for (i = 0; i < cinfo->out_color_components; i++) {
+ nci = cquantize->Ncolors[i]; /* # of distinct values for this color */
+ odither = NULL; /* search for matching prior component */
+ for (j = 0; j < i; j++) {
+ if (nci == cquantize->Ncolors[j]) {
+ odither = cquantize->odither[j];
+ break;
+ }
+ }
+ if (odither == NULL) /* need a new table? */
+ odither = make_odither_array(cinfo, nci);
+ cquantize->odither[i] = odither;
+ }
+}
+
+
+/*
+ * Map some rows of pixels to the output colormapped representation.
+ */
+
+METHODDEF(void)
+color_quantize (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY output_buf, int num_rows)
+/* General case, no dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ JSAMPARRAY colorindex = cquantize->colorindex;
+ int pixcode, ci;
+ JSAMPROW ptrin, ptrout;
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+ int nc = cinfo->out_color_components;
+
+ for (row = 0; row < num_rows; row++) {
+ ptrin = input_buf[row];
+ ptrout = output_buf[row];
+ for (col = width; col > 0; col--) {
+ pixcode = 0;
+ for (ci = 0; ci < nc; ci++) {
+ pixcode += GETJSAMPLE(colorindex[ci][GETJSAMPLE(*ptrin++)]);
+ }
+ *ptrout++ = (JSAMPLE) pixcode;
+ }
+ }
+}
+
+
+METHODDEF(void)
+color_quantize3 (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY output_buf, int num_rows)
+/* Fast path for out_color_components==3, no dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ int pixcode;
+ JSAMPROW ptrin, ptrout;
+ JSAMPROW colorindex0 = cquantize->colorindex[0];
+ JSAMPROW colorindex1 = cquantize->colorindex[1];
+ JSAMPROW colorindex2 = cquantize->colorindex[2];
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+
+ for (row = 0; row < num_rows; row++) {
+ ptrin = input_buf[row];
+ ptrout = output_buf[row];
+ for (col = width; col > 0; col--) {
+ pixcode = GETJSAMPLE(colorindex0[GETJSAMPLE(*ptrin++)]);
+ pixcode += GETJSAMPLE(colorindex1[GETJSAMPLE(*ptrin++)]);
+ pixcode += GETJSAMPLE(colorindex2[GETJSAMPLE(*ptrin++)]);
+ *ptrout++ = (JSAMPLE) pixcode;
+ }
+ }
+}
+
+
+METHODDEF(void)
+quantize_ord_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY output_buf, int num_rows)
+/* General case, with ordered dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ JSAMPROW input_ptr;
+ JSAMPROW output_ptr;
+ JSAMPROW colorindex_ci;
+ int * dither; /* points to active row of dither matrix */
+ int row_index, col_index; /* current indexes into dither matrix */
+ int nc = cinfo->out_color_components;
+ int ci;
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+
+ for (row = 0; row < num_rows; row++) {
+ /* Initialize output values to 0 so can process components separately */
+ jzero_far((void FAR *) output_buf[row],
+ (size_t) (width * SIZEOF(JSAMPLE)));
+ row_index = cquantize->row_index;
+ for (ci = 0; ci < nc; ci++) {
+ input_ptr = input_buf[row] + ci;
+ output_ptr = output_buf[row];
+ colorindex_ci = cquantize->colorindex[ci];
+ dither = cquantize->odither[ci][row_index];
+ col_index = 0;
+
+ for (col = width; col > 0; col--) {
+ /* Form pixel value + dither, range-limit to 0..MAXJSAMPLE,
+ * select output value, accumulate into output code for this pixel.
+ * Range-limiting need not be done explicitly, as we have extended
+ * the colorindex table to produce the right answers for out-of-range
+ * inputs. The maximum dither is +- MAXJSAMPLE; this sets the
+ * required amount of padding.
+ */
+ *output_ptr += colorindex_ci[GETJSAMPLE(*input_ptr)+dither[col_index]];
+ input_ptr += nc;
+ output_ptr++;
+ col_index = (col_index + 1) & ODITHER_MASK;
+ }
+ }
+ /* Advance row index for next row */
+ row_index = (row_index + 1) & ODITHER_MASK;
+ cquantize->row_index = row_index;
+ }
+}
+
+
+METHODDEF(void)
+quantize3_ord_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY output_buf, int num_rows)
+/* Fast path for out_color_components==3, with ordered dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ int pixcode;
+ JSAMPROW input_ptr;
+ JSAMPROW output_ptr;
+ JSAMPROW colorindex0 = cquantize->colorindex[0];
+ JSAMPROW colorindex1 = cquantize->colorindex[1];
+ JSAMPROW colorindex2 = cquantize->colorindex[2];
+ int * dither0; /* points to active row of dither matrix */
+ int * dither1;
+ int * dither2;
+ int row_index, col_index; /* current indexes into dither matrix */
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+
+ for (row = 0; row < num_rows; row++) {
+ row_index = cquantize->row_index;
+ input_ptr = input_buf[row];
+ output_ptr = output_buf[row];
+ dither0 = cquantize->odither[0][row_index];
+ dither1 = cquantize->odither[1][row_index];
+ dither2 = cquantize->odither[2][row_index];
+ col_index = 0;
+
+ for (col = width; col > 0; col--) {
+ pixcode = GETJSAMPLE(colorindex0[GETJSAMPLE(*input_ptr++) +
+ dither0[col_index]]);
+ pixcode += GETJSAMPLE(colorindex1[GETJSAMPLE(*input_ptr++) +
+ dither1[col_index]]);
+ pixcode += GETJSAMPLE(colorindex2[GETJSAMPLE(*input_ptr++) +
+ dither2[col_index]]);
+ *output_ptr++ = (JSAMPLE) pixcode;
+ col_index = (col_index + 1) & ODITHER_MASK;
+ }
+ row_index = (row_index + 1) & ODITHER_MASK;
+ cquantize->row_index = row_index;
+ }
+}
+
+
+METHODDEF(void)
+quantize_fs_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY output_buf, int num_rows)
+/* General case, with Floyd-Steinberg dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ LOCFSERROR cur; /* current error or pixel value */
+ LOCFSERROR belowerr; /* error for pixel below cur */
+ LOCFSERROR bpreverr; /* error for below/prev col */
+ LOCFSERROR bnexterr; /* error for below/next col */
+ LOCFSERROR delta;
+ FSERRPTR errorptr; /* => fserrors[] at column before current */
+ JSAMPROW input_ptr;
+ JSAMPROW output_ptr;
+ JSAMPROW colorindex_ci;
+ JSAMPROW colormap_ci;
+ int pixcode;
+ int nc = cinfo->out_color_components;
+ int dir; /* 1 for left-to-right, -1 for right-to-left */
+ int dirnc; /* dir * nc */
+ int ci;
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+ JSAMPLE *range_limit = cinfo->sample_range_limit;
+ SHIFT_TEMPS
+
+ for (row = 0; row < num_rows; row++) {
+ /* Initialize output values to 0 so can process components separately */
+ jzero_far((void FAR *) output_buf[row],
+ (size_t) (width * SIZEOF(JSAMPLE)));
+ for (ci = 0; ci < nc; ci++) {
+ input_ptr = input_buf[row] + ci;
+ output_ptr = output_buf[row];
+ if (cquantize->on_odd_row) {
+ /* work right to left in this row */
+ input_ptr += (width-1) * nc; /* so point to rightmost pixel */
+ output_ptr += width-1;
+ dir = -1;
+ dirnc = -nc;
+ errorptr = cquantize->fserrors[ci] + (width+1); /* => entry after last column */
+ } else {
+ /* work left to right in this row */
+ dir = 1;
+ dirnc = nc;
+ errorptr = cquantize->fserrors[ci]; /* => entry before first column */
+ }
+ colorindex_ci = cquantize->colorindex[ci];
+ colormap_ci = cquantize->sv_colormap[ci];
+ /* Preset error values: no error propagated to first pixel from left */
+ cur = 0;
+ /* and no error propagated to row below yet */
+ belowerr = bpreverr = 0;
+
+ for (col = width; col > 0; col--) {
+ /* cur holds the error propagated from the previous pixel on the
+ * current line. Add the error propagated from the previous line
+ * to form the complete error correction term for this pixel, and
+ * round the error term (which is expressed * 16) to an integer.
+ * RIGHT_SHIFT rounds towards minus infinity, so adding 8 is correct
+ * for either sign of the error value.
+ * Note: errorptr points to *previous* column's array entry.
+ */
+ cur = RIGHT_SHIFT(cur + errorptr[dir] + 8, 4);
+ /* Form pixel value + error, and range-limit to 0..MAXJSAMPLE.
+ * The maximum error is +- MAXJSAMPLE; this sets the required size
+ * of the range_limit array.
+ */
+ cur += GETJSAMPLE(*input_ptr);
+ cur = GETJSAMPLE(range_limit[cur]);
+ /* Select output value, accumulate into output code for this pixel */
+ pixcode = GETJSAMPLE(colorindex_ci[cur]);
+ *output_ptr += (JSAMPLE) pixcode;
+ /* Compute actual representation error at this pixel */
+ /* Note: we can do this even though we don't have the final */
+ /* pixel code, because the colormap is orthogonal. */
+ cur -= GETJSAMPLE(colormap_ci[pixcode]);
+ /* Compute error fractions to be propagated to adjacent pixels.
+ * Add these into the running sums, and simultaneously shift the
+ * next-line error sums left by 1 column.
+ */
+ bnexterr = cur;
+ delta = cur * 2;
+ cur += delta; /* form error * 3 */
+ errorptr[0] = (FSERROR) (bpreverr + cur);
+ cur += delta; /* form error * 5 */
+ bpreverr = belowerr + cur;
+ belowerr = bnexterr;
+ cur += delta; /* form error * 7 */
+ /* At this point cur contains the 7/16 error value to be propagated
+ * to the next pixel on the current line, and all the errors for the
+ * next line have been shifted over. We are therefore ready to move on.
+ */
+ input_ptr += dirnc; /* advance input ptr to next column */
+ output_ptr += dir; /* advance output ptr to next column */
+ errorptr += dir; /* advance errorptr to current column */
+ }
+ /* Post-loop cleanup: we must unload the final error value into the
+ * final fserrors[] entry. Note we need not unload belowerr because
+ * it is for the dummy column before or after the actual array.
+ */
+ errorptr[0] = (FSERROR) bpreverr; /* unload prev err into array */
+ }
+ cquantize->on_odd_row = (cquantize->on_odd_row ? FALSE : TRUE);
+ }
+}
+
+
+/*
+ * Allocate workspace for Floyd-Steinberg errors.
+ */
+
+LOCAL(void)
+alloc_fs_workspace (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ size_t arraysize;
+ int i;
+
+ arraysize = (size_t) ((cinfo->output_width + 2) * SIZEOF(FSERROR));
+ for (i = 0; i < cinfo->out_color_components; i++) {
+ cquantize->fserrors[i] = (FSERRPTR)
+ (*cinfo->mem->alloc_large)((j_common_ptr) cinfo, JPOOL_IMAGE, arraysize);
+ }
+}
+
+
+/*
+ * Initialize for one-pass color quantization.
+ */
+
+METHODDEF(void)
+start_pass_1_quant (j_decompress_ptr cinfo, int )
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ size_t arraysize;
+ int i;
+
+ /* Install my colormap. */
+ cinfo->colormap = cquantize->sv_colormap;
+ cinfo->actual_number_of_colors = cquantize->sv_actual;
+
+ /* Initialize for desired dithering mode. */
+ switch (cinfo->dither_mode) {
+ case JDITHER_NONE:
+ if (cinfo->out_color_components == 3)
+ cquantize->pub.color_quantize = color_quantize3;
+ else
+ cquantize->pub.color_quantize = color_quantize;
+ break;
+ case JDITHER_ORDERED:
+ if (cinfo->out_color_components == 3)
+ cquantize->pub.color_quantize = quantize3_ord_dither;
+ else
+ cquantize->pub.color_quantize = quantize_ord_dither;
+ cquantize->row_index = 0; /* initialize state for ordered dither */
+ /* If user changed to ordered dither from another mode,
+ * we must recreate the color index table with padding.
+ * This will cost extra space, but probably isn't very likely.
+ */
+ if (! cquantize->is_padded)
+ create_colorindex(cinfo);
+ /* Create ordered-dither tables if we didn't already. */
+ if (cquantize->odither[0] == NULL)
+ create_odither_tables(cinfo);
+ break;
+ case JDITHER_FS:
+ cquantize->pub.color_quantize = quantize_fs_dither;
+ cquantize->on_odd_row = FALSE; /* initialize state for F-S dither */
+ /* Allocate Floyd-Steinberg workspace if didn't already. */
+ if (cquantize->fserrors[0] == NULL)
+ alloc_fs_workspace(cinfo);
+ /* Initialize the propagated errors to zero. */
+ arraysize = (size_t) ((cinfo->output_width + 2) * SIZEOF(FSERROR));
+ for (i = 0; i < cinfo->out_color_components; i++)
+ jzero_far((void FAR *) cquantize->fserrors[i], arraysize);
+ break;
+ default:
+ ERREXIT(cinfo, JERR_NOT_COMPILED);
+ break;
+ }
+}
+
+
+/*
+ * Finish up at the end of the pass.
+ */
+
+METHODDEF(void)
+finish_pass_1_quant (j_decompress_ptr )
+{
+ /* no work in 1-pass case */
+}
+
+
+/*
+ * Switch to a new external colormap between output passes.
+ * Shouldn't get to this module!
+ */
+
+METHODDEF(void)
+new_color_map_1_quant (j_decompress_ptr cinfo)
+{
+ ERREXIT(cinfo, JERR_MODE_CHANGE);
+}
+
+
+/*
+ * Module initialization routine for 1-pass color quantization.
+ */
+
+GLOBAL(void)
+jinit_1pass_quantizer (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize;
+
+ cquantize = (my_cquantize_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_cquantizer));
+ cinfo->cquantize = (struct jpeg_color_quantizer *) cquantize;
+ cquantize->pub.start_pass = start_pass_1_quant;
+ cquantize->pub.finish_pass = finish_pass_1_quant;
+ cquantize->pub.new_color_map = new_color_map_1_quant;
+ cquantize->fserrors[0] = NULL; /* Flag FS workspace not allocated */
+ cquantize->odither[0] = NULL; /* Also flag odither arrays not allocated */
+
+ /* Make sure my internal arrays won't overflow */
+ if (cinfo->out_color_components > MAX_Q_COMPS)
+ ERREXIT1(cinfo, JERR_QUANT_COMPONENTS, MAX_Q_COMPS);
+ /* Make sure colormap indexes can be represented by JSAMPLEs */
+ if (cinfo->desired_number_of_colors > (MAXJSAMPLE+1))
+ ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXJSAMPLE+1);
+
+ /* Create the colormap and color index table. */
+ create_colormap(cinfo);
+ create_colorindex(cinfo);
+
+ /* Allocate Floyd-Steinberg workspace now if requested.
+ * We do this now since it is FAR storage and may affect the memory
+ * manager's space calculations. If the user changes to FS dither
+ * mode in a later pass, we will allocate the space then, and will
+ * possibly overrun the max_memory_to_use setting.
+ */
+ if (cinfo->dither_mode == JDITHER_FS)
+ alloc_fs_workspace(cinfo);
+}
+
+#endif /* QUANT_1PASS_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jquant2.cpp b/ml/dlib/dlib/external/libjpeg/jquant2.cpp
new file mode 100644
index 000000000..0d7b5969a
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jquant2.cpp
@@ -0,0 +1,1310 @@
+/*
+ * jquant2.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains 2-pass color quantization (color mapping) routines.
+ * These routines provide selection of a custom color map for an image,
+ * followed by mapping of the image to that color map, with optional
+ * Floyd-Steinberg dithering.
+ * It is also possible to use just the second pass to map to an arbitrary
+ * externally-given color map.
+ *
+ * Note: ordered dithering is not supported, since there isn't any fast
+ * way to compute intercolor distances; it's unclear that ordered dither's
+ * fundamental assumptions even hold with an irregularly spaced color map.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+#ifdef QUANT_2PASS_SUPPORTED
+
+
+/*
+ * This module implements the well-known Heckbert paradigm for color
+ * quantization. Most of the ideas used here can be traced back to
+ * Heckbert's seminal paper
+ * Heckbert, Paul. "Color Image Quantization for Frame Buffer Display",
+ * Proc. SIGGRAPH '82, Computer Graphics v.16 #3 (July 1982), pp 297-304.
+ *
+ * In the first pass over the image, we accumulate a histogram showing the
+ * usage count of each possible color. To keep the histogram to a reasonable
+ * size, we reduce the precision of the input; typical practice is to retain
+ * 5 or 6 bits per color, so that 8 or 4 different input values are counted
+ * in the same histogram cell.
+ *
+ * Next, the color-selection step begins with a box representing the whole
+ * color space, and repeatedly splits the "largest" remaining box until we
+ * have as many boxes as desired colors. Then the mean color in each
+ * remaining box becomes one of the possible output colors.
+ *
+ * The second pass over the image maps each input pixel to the closest output
+ * color (optionally after applying a Floyd-Steinberg dithering correction).
+ * This mapping is logically trivial, but making it go fast enough requires
+ * considerable care.
+ *
+ * Heckbert-style quantizers vary a good deal in their policies for choosing
+ * the "largest" box and deciding where to cut it. The particular policies
+ * used here have proved out well in experimental comparisons, but better ones
+ * may yet be found.
+ *
+ * In earlier versions of the IJG code, this module quantized in YCbCr color
+ * space, processing the raw upsampled data without a color conversion step.
+ * This allowed the color conversion math to be done only once per colormap
+ * entry, not once per pixel. However, that optimization precluded other
+ * useful optimizations (such as merging color conversion with upsampling)
+ * and it also interfered with desired capabilities such as quantizing to an
+ * externally-supplied colormap. We have therefore abandoned that approach.
+ * The present code works in the post-conversion color space, typically RGB.
+ *
+ * To improve the visual quality of the results, we actually work in scaled
+ * RGB space, giving G distances more weight than R, and R in turn more than
+ * B. To do everything in integer math, we must use integer scale factors.
+ * The 2/3/1 scale factors used here correspond loosely to the relative
+ * weights of the colors in the NTSC grayscale equation.
+ * If you want to use this code to quantize a non-RGB color space, you'll
+ * probably need to change these scale factors.
+ */
+
+#define R_SCALE 2 /* scale R distances by this much */
+#define G_SCALE 3 /* scale G distances by this much */
+#define B_SCALE 1 /* and B by this much */
+
+/* Relabel R/G/B as components 0/1/2, respecting the RGB ordering defined
+ * in jmorecfg.h. As the code stands, it will do the right thing for R,G,B
+ * and B,G,R orders. If you define some other weird order in jmorecfg.h,
+ * you'll get compile errors until you extend this logic. In that case
+ * you'll probably want to tweak the histogram sizes too.
+ */
+
+#if RGB_RED == 0
+#define C0_SCALE R_SCALE
+#endif
+#if RGB_BLUE == 0
+#define C0_SCALE B_SCALE
+#endif
+#if RGB_GREEN == 1
+#define C1_SCALE G_SCALE
+#endif
+#if RGB_RED == 2
+#define C2_SCALE R_SCALE
+#endif
+#if RGB_BLUE == 2
+#define C2_SCALE B_SCALE
+#endif
+
+
+/*
+ * First we have the histogram data structure and routines for creating it.
+ *
+ * The number of bits of precision can be adjusted by changing these symbols.
+ * We recommend keeping 6 bits for G and 5 each for R and B.
+ * If you have plenty of memory and cycles, 6 bits all around gives marginally
+ * better results; if you are short of memory, 5 bits all around will save
+ * some space but degrade the results.
+ * To maintain a fully accurate histogram, we'd need to allocate a "long"
+ * (preferably unsigned long) for each cell. In practice this is overkill;
+ * we can get by with 16 bits per cell. Few of the cell counts will overflow,
+ * and clamping those that do overflow to the maximum value will give close-
+ * enough results. This reduces the recommended histogram size from 256Kb
+ * to 128Kb, which is a useful savings on PC-class machines.
+ * (In the second pass the histogram space is re-used for pixel mapping data;
+ * in that capacity, each cell must be able to store zero to the number of
+ * desired colors. 16 bits/cell is plenty for that too.)
+ * Since the JPEG code is intended to run in small memory model on 80x86
+ * machines, we can't just allocate the histogram in one chunk. Instead
+ * of a true 3-D array, we use a row of pointers to 2-D arrays. Each
+ * pointer corresponds to a C0 value (typically 2^5 = 32 pointers) and
+ * each 2-D array has 2^6*2^5 = 2048 or 2^6*2^6 = 4096 entries. Note that
+ * on 80x86 machines, the pointer row is in near memory but the actual
+ * arrays are in far memory (same arrangement as we use for image arrays).
+ */
+
+#define MAXNUMCOLORS (MAXJSAMPLE+1) /* maximum size of colormap */
+
+/* These will do the right thing for either R,G,B or B,G,R color order,
+ * but you may not like the results for other color orders.
+ */
+#define HIST_C0_BITS 5 /* bits of precision in R/B histogram */
+#define HIST_C1_BITS 6 /* bits of precision in G histogram */
+#define HIST_C2_BITS 5 /* bits of precision in B/R histogram */
+
+/* Number of elements along histogram axes. */
+#define HIST_C0_ELEMS (1<<HIST_C0_BITS)
+#define HIST_C1_ELEMS (1<<HIST_C1_BITS)
+#define HIST_C2_ELEMS (1<<HIST_C2_BITS)
+
+/* These are the amounts to shift an input value to get a histogram index. */
+#define C0_SHIFT (BITS_IN_JSAMPLE-HIST_C0_BITS)
+#define C1_SHIFT (BITS_IN_JSAMPLE-HIST_C1_BITS)
+#define C2_SHIFT (BITS_IN_JSAMPLE-HIST_C2_BITS)
+
+
+typedef unsigned short histcell; /* histogram cell; prefer an unsigned type */
+
+typedef histcell FAR * histptr; /* for pointers to histogram cells */
+
+typedef histcell hist1d[HIST_C2_ELEMS]; /* typedefs for the array */
+typedef hist1d FAR * hist2d; /* type for the 2nd-level pointers */
+typedef hist2d * hist3d; /* type for top-level pointer */
+
+
+/* Declarations for Floyd-Steinberg dithering.
+ *
+ * Errors are accumulated into the array fserrors[], at a resolution of
+ * 1/16th of a pixel count. The error at a given pixel is propagated
+ * to its not-yet-processed neighbors using the standard F-S fractions,
+ * ... (here) 7/16
+ * 3/16 5/16 1/16
+ * We work left-to-right on even rows, right-to-left on odd rows.
+ *
+ * We can get away with a single array (holding one row's worth of errors)
+ * by using it to store the current row's errors at pixel columns not yet
+ * processed, but the next row's errors at columns already processed. We
+ * need only a few extra variables to hold the errors immediately around the
+ * current column. (If we are lucky, those variables are in registers, but
+ * even if not, they're probably cheaper to access than array elements are.)
+ *
+ * The fserrors[] array has (#columns + 2) entries; the extra entry at
+ * each end saves us from special-casing the first and last pixels.
+ * Each entry is three values long, one value for each color component.
+ *
+ * Note: on a wide image, we might not have enough room in a PC's near data
+ * segment to hold the error array; so it is allocated with alloc_large.
+ */
+
+#if BITS_IN_JSAMPLE == 8
+typedef short FSERROR; /* 16 bits should be enough */
+typedef int LOCFSERROR; /* use 'int' for calculation temps */
+#else
+typedef long FSERROR; /* may need more than 16 bits */
+typedef long LOCFSERROR; /* be sure calculation temps are big enough */
+#endif
+
+typedef FSERROR FAR *FSERRPTR; /* pointer to error array (in FAR storage!) */
+
+
+/* Private subobject */
+
+typedef struct {
+ struct jpeg_color_quantizer pub; /* public fields */
+
+ /* Space for the eventually created colormap is stashed here */
+ JSAMPARRAY sv_colormap; /* colormap allocated at init time */
+ int desired; /* desired # of colors = size of colormap */
+
+ /* Variables for accumulating image statistics */
+ hist3d histogram; /* pointer to the histogram */
+
+ int needs_zeroed; /* TRUE if next pass must zero histogram */
+
+ /* Variables for Floyd-Steinberg dithering */
+ FSERRPTR fserrors; /* accumulated errors */
+ int on_odd_row; /* flag to remember which row we are on */
+ int * error_limiter; /* table for clamping the applied error */
+} my_cquantizer;
+
+typedef my_cquantizer * my_cquantize_ptr;
+
+
+/*
+ * Prescan some rows of pixels.
+ * In this module the prescan simply updates the histogram, which has been
+ * initialized to zeroes by start_pass.
+ * An output_buf parameter is required by the method signature, but no data
+ * is actually output (in fact the buffer controller is probably passing a
+ * NULL pointer).
+ */
+
+METHODDEF(void)
+prescan_quantize (j_decompress_ptr cinfo, JSAMPARRAY input_buf,
+ JSAMPARRAY , int num_rows)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ JSAMPROW ptr;
+ histptr histp;
+ hist3d histogram = cquantize->histogram;
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+
+ for (row = 0; row < num_rows; row++) {
+ ptr = input_buf[row];
+ for (col = width; col > 0; col--) {
+ /* get pixel value and index into the histogram */
+ histp = & histogram[GETJSAMPLE(ptr[0]) >> C0_SHIFT]
+ [GETJSAMPLE(ptr[1]) >> C1_SHIFT]
+ [GETJSAMPLE(ptr[2]) >> C2_SHIFT];
+ /* increment, check for overflow and undo increment if so. */
+ if (++(*histp) <= 0)
+ (*histp)--;
+ ptr += 3;
+ }
+ }
+}
+
+
+/*
+ * Next we have the really interesting routines: selection of a colormap
+ * given the completed histogram.
+ * These routines work with a list of "boxes", each representing a rectangular
+ * subset of the input color space (to histogram precision).
+ */
+
+typedef struct {
+ /* The bounds of the box (inclusive); expressed as histogram indexes */
+ int c0min, c0max;
+ int c1min, c1max;
+ int c2min, c2max;
+ /* The volume (actually 2-norm) of the box */
+ long volume;
+ /* The number of nonzero histogram cells within this box */
+ long colorcount;
+} box;
+
+typedef box * boxptr;
+
+
+LOCAL(boxptr)
+find_biggest_color_pop (boxptr boxlist, int numboxes)
+/* Find the splittable box with the largest color population */
+/* Returns NULL if no splittable boxes remain */
+{
+ boxptr boxp;
+ int i;
+ long maxc = 0;
+ boxptr which = NULL;
+
+ for (i = 0, boxp = boxlist; i < numboxes; i++, boxp++) {
+ if (boxp->colorcount > maxc && boxp->volume > 0) {
+ which = boxp;
+ maxc = boxp->colorcount;
+ }
+ }
+ return which;
+}
+
+
+LOCAL(boxptr)
+find_biggest_volume (boxptr boxlist, int numboxes)
+/* Find the splittable box with the largest (scaled) volume */
+/* Returns NULL if no splittable boxes remain */
+{
+ boxptr boxp;
+ int i;
+ long maxv = 0;
+ boxptr which = NULL;
+
+ for (i = 0, boxp = boxlist; i < numboxes; i++, boxp++) {
+ if (boxp->volume > maxv) {
+ which = boxp;
+ maxv = boxp->volume;
+ }
+ }
+ return which;
+}
+
+
+LOCAL(void)
+update_box (j_decompress_ptr cinfo, boxptr boxp)
+/* Shrink the min/max bounds of a box to enclose only nonzero elements, */
+/* and recompute its volume and population */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ histptr histp;
+ int c0,c1,c2;
+ int c0min,c0max,c1min,c1max,c2min,c2max;
+ long dist0,dist1,dist2;
+ long ccount;
+
+ c0min = boxp->c0min; c0max = boxp->c0max;
+ c1min = boxp->c1min; c1max = boxp->c1max;
+ c2min = boxp->c2min; c2max = boxp->c2max;
+
+ if (c0max > c0min)
+ for (c0 = c0min; c0 <= c0max; c0++)
+ for (c1 = c1min; c1 <= c1max; c1++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++)
+ if (*histp++ != 0) {
+ boxp->c0min = c0min = c0;
+ goto have_c0min;
+ }
+ }
+ have_c0min:
+ if (c0max > c0min)
+ for (c0 = c0max; c0 >= c0min; c0--)
+ for (c1 = c1min; c1 <= c1max; c1++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++)
+ if (*histp++ != 0) {
+ boxp->c0max = c0max = c0;
+ goto have_c0max;
+ }
+ }
+ have_c0max:
+ if (c1max > c1min)
+ for (c1 = c1min; c1 <= c1max; c1++)
+ for (c0 = c0min; c0 <= c0max; c0++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++)
+ if (*histp++ != 0) {
+ boxp->c1min = c1min = c1;
+ goto have_c1min;
+ }
+ }
+ have_c1min:
+ if (c1max > c1min)
+ for (c1 = c1max; c1 >= c1min; c1--)
+ for (c0 = c0min; c0 <= c0max; c0++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++)
+ if (*histp++ != 0) {
+ boxp->c1max = c1max = c1;
+ goto have_c1max;
+ }
+ }
+ have_c1max:
+ if (c2max > c2min)
+ for (c2 = c2min; c2 <= c2max; c2++)
+ for (c0 = c0min; c0 <= c0max; c0++) {
+ histp = & histogram[c0][c1min][c2];
+ for (c1 = c1min; c1 <= c1max; c1++, histp += HIST_C2_ELEMS)
+ if (*histp != 0) {
+ boxp->c2min = c2min = c2;
+ goto have_c2min;
+ }
+ }
+ have_c2min:
+ if (c2max > c2min)
+ for (c2 = c2max; c2 >= c2min; c2--)
+ for (c0 = c0min; c0 <= c0max; c0++) {
+ histp = & histogram[c0][c1min][c2];
+ for (c1 = c1min; c1 <= c1max; c1++, histp += HIST_C2_ELEMS)
+ if (*histp != 0) {
+ boxp->c2max = c2max = c2;
+ goto have_c2max;
+ }
+ }
+ have_c2max:
+
+ /* Update box volume.
+ * We use 2-norm rather than real volume here; this biases the method
+ * against making long narrow boxes, and it has the side benefit that
+ * a box is splittable iff norm > 0.
+ * Since the differences are expressed in histogram-cell units,
+ * we have to shift back to JSAMPLE units to get consistent distances;
+ * after which, we scale according to the selected distance scale factors.
+ */
+ dist0 = ((c0max - c0min) << C0_SHIFT) * C0_SCALE;
+ dist1 = ((c1max - c1min) << C1_SHIFT) * C1_SCALE;
+ dist2 = ((c2max - c2min) << C2_SHIFT) * C2_SCALE;
+ boxp->volume = dist0*dist0 + dist1*dist1 + dist2*dist2;
+
+ /* Now scan remaining volume of box and compute population */
+ ccount = 0;
+ for (c0 = c0min; c0 <= c0max; c0++)
+ for (c1 = c1min; c1 <= c1max; c1++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++, histp++)
+ if (*histp != 0) {
+ ccount++;
+ }
+ }
+ boxp->colorcount = ccount;
+}
+
+
+LOCAL(int)
+median_cut (j_decompress_ptr cinfo, boxptr boxlist, int numboxes,
+ int desired_colors)
+/* Repeatedly select and split the largest box until we have enough boxes */
+{
+ int n,lb;
+ int c0,c1,c2,cmax;
+ boxptr b1,b2;
+
+ while (numboxes < desired_colors) {
+ /* Select box to split.
+ * Current algorithm: by population for first half, then by volume.
+ */
+ if (numboxes*2 <= desired_colors) {
+ b1 = find_biggest_color_pop(boxlist, numboxes);
+ } else {
+ b1 = find_biggest_volume(boxlist, numboxes);
+ }
+ if (b1 == NULL) /* no splittable boxes left! */
+ break;
+ b2 = &boxlist[numboxes]; /* where new box will go */
+ /* Copy the color bounds to the new box. */
+ b2->c0max = b1->c0max; b2->c1max = b1->c1max; b2->c2max = b1->c2max;
+ b2->c0min = b1->c0min; b2->c1min = b1->c1min; b2->c2min = b1->c2min;
+ /* Choose which axis to split the box on.
+ * Current algorithm: longest scaled axis.
+ * See notes in update_box about scaling distances.
+ */
+ c0 = ((b1->c0max - b1->c0min) << C0_SHIFT) * C0_SCALE;
+ c1 = ((b1->c1max - b1->c1min) << C1_SHIFT) * C1_SCALE;
+ c2 = ((b1->c2max - b1->c2min) << C2_SHIFT) * C2_SCALE;
+ /* We want to break any ties in favor of green, then red, blue last.
+ * This code does the right thing for R,G,B or B,G,R color orders only.
+ */
+#if RGB_RED == 0
+ cmax = c1; n = 1;
+ if (c0 > cmax) { cmax = c0; n = 0; }
+ if (c2 > cmax) { n = 2; }
+#else
+ cmax = c1; n = 1;
+ if (c2 > cmax) { cmax = c2; n = 2; }
+ if (c0 > cmax) { n = 0; }
+#endif
+ /* Choose split point along selected axis, and update box bounds.
+ * Current algorithm: split at halfway point.
+ * (Since the box has been shrunk to minimum volume,
+ * any split will produce two nonempty subboxes.)
+ * Note that lb value is max for lower box, so must be < old max.
+ */
+ switch (n) {
+ case 0:
+ lb = (b1->c0max + b1->c0min) / 2;
+ b1->c0max = lb;
+ b2->c0min = lb+1;
+ break;
+ case 1:
+ lb = (b1->c1max + b1->c1min) / 2;
+ b1->c1max = lb;
+ b2->c1min = lb+1;
+ break;
+ case 2:
+ lb = (b1->c2max + b1->c2min) / 2;
+ b1->c2max = lb;
+ b2->c2min = lb+1;
+ break;
+ }
+ /* Update stats for boxes */
+ update_box(cinfo, b1);
+ update_box(cinfo, b2);
+ numboxes++;
+ }
+ return numboxes;
+}
+
+
+LOCAL(void)
+compute_color (j_decompress_ptr cinfo, boxptr boxp, int icolor)
+/* Compute representative color for a box, put it in colormap[icolor] */
+{
+ /* Current algorithm: mean weighted by pixels (not colors) */
+ /* Note it is important to get the rounding correct! */
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ histptr histp;
+ int c0,c1,c2;
+ int c0min,c0max,c1min,c1max,c2min,c2max;
+ long count;
+ long total = 0;
+ long c0total = 0;
+ long c1total = 0;
+ long c2total = 0;
+
+ c0min = boxp->c0min; c0max = boxp->c0max;
+ c1min = boxp->c1min; c1max = boxp->c1max;
+ c2min = boxp->c2min; c2max = boxp->c2max;
+
+ for (c0 = c0min; c0 <= c0max; c0++)
+ for (c1 = c1min; c1 <= c1max; c1++) {
+ histp = & histogram[c0][c1][c2min];
+ for (c2 = c2min; c2 <= c2max; c2++) {
+ if ((count = *histp++) != 0) {
+ total += count;
+ c0total += ((c0 << C0_SHIFT) + ((1<<C0_SHIFT)>>1)) * count;
+ c1total += ((c1 << C1_SHIFT) + ((1<<C1_SHIFT)>>1)) * count;
+ c2total += ((c2 << C2_SHIFT) + ((1<<C2_SHIFT)>>1)) * count;
+ }
+ }
+ }
+
+ cinfo->colormap[0][icolor] = (JSAMPLE) ((c0total + (total>>1)) / total);
+ cinfo->colormap[1][icolor] = (JSAMPLE) ((c1total + (total>>1)) / total);
+ cinfo->colormap[2][icolor] = (JSAMPLE) ((c2total + (total>>1)) / total);
+}
+
+
+LOCAL(void)
+select_colors (j_decompress_ptr cinfo, int desired_colors)
+/* Master routine for color selection */
+{
+ boxptr boxlist;
+ int numboxes;
+ int i;
+
+ /* Allocate workspace for box list */
+ boxlist = (boxptr) (*cinfo->mem->alloc_small)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, desired_colors * SIZEOF(box));
+ /* Initialize one box containing whole space */
+ numboxes = 1;
+ boxlist[0].c0min = 0;
+ boxlist[0].c0max = MAXJSAMPLE >> C0_SHIFT;
+ boxlist[0].c1min = 0;
+ boxlist[0].c1max = MAXJSAMPLE >> C1_SHIFT;
+ boxlist[0].c2min = 0;
+ boxlist[0].c2max = MAXJSAMPLE >> C2_SHIFT;
+ /* Shrink it to actually-used volume and set its statistics */
+ update_box(cinfo, & boxlist[0]);
+ /* Perform median-cut to produce final box list */
+ numboxes = median_cut(cinfo, boxlist, numboxes, desired_colors);
+ /* Compute the representative color for each box, fill colormap */
+ for (i = 0; i < numboxes; i++)
+ compute_color(cinfo, & boxlist[i], i);
+ cinfo->actual_number_of_colors = numboxes;
+ TRACEMS1(cinfo, 1, JTRC_QUANT_SELECTED, numboxes);
+}
+
+
+/*
+ * These routines are concerned with the time-critical task of mapping input
+ * colors to the nearest color in the selected colormap.
+ *
+ * We re-use the histogram space as an "inverse color map", essentially a
+ * cache for the results of nearest-color searches. All colors within a
+ * histogram cell will be mapped to the same colormap entry, namely the one
+ * closest to the cell's center. This may not be quite the closest entry to
+ * the actual input color, but it's almost as good. A zero in the cache
+ * indicates we haven't found the nearest color for that cell yet; the array
+ * is cleared to zeroes before starting the mapping pass. When we find the
+ * nearest color for a cell, its colormap index plus one is recorded in the
+ * cache for future use. The pass2 scanning routines call fill_inverse_cmap
+ * when they need to use an unfilled entry in the cache.
+ *
+ * Our method of efficiently finding nearest colors is based on the "locally
+ * sorted search" idea described by Heckbert and on the incremental distance
+ * calculation described by Spencer W. Thomas in chapter III.1 of Graphics
+ * Gems II (James Arvo, ed. Academic Press, 1991). Thomas points out that
+ * the distances from a given colormap entry to each cell of the histogram can
+ * be computed quickly using an incremental method: the differences between
+ * distances to adjacent cells themselves differ by a constant. This allows a
+ * fairly fast implementation of the "brute force" approach of computing the
+ * distance from every colormap entry to every histogram cell. Unfortunately,
+ * it needs a work array to hold the best-distance-so-far for each histogram
+ * cell (because the inner loop has to be over cells, not colormap entries).
+ * The work array elements have to be INT32s, so the work array would need
+ * 256Kb at our recommended precision. This is not feasible in DOS machines.
+ *
+ * To get around these problems, we apply Thomas' method to compute the
+ * nearest colors for only the cells within a small subbox of the histogram.
+ * The work array need be only as big as the subbox, so the memory usage
+ * problem is solved. Furthermore, we need not fill subboxes that are never
+ * referenced in pass2; many images use only part of the color gamut, so a
+ * fair amount of work is saved. An additional advantage of this
+ * approach is that we can apply Heckbert's locality criterion to quickly
+ * eliminate colormap entries that are far away from the subbox; typically
+ * three-fourths of the colormap entries are rejected by Heckbert's criterion,
+ * and we need not compute their distances to individual cells in the subbox.
+ * The speed of this approach is heavily influenced by the subbox size: too
+ * small means too much overhead, too big loses because Heckbert's criterion
+ * can't eliminate as many colormap entries. Empirically the best subbox
+ * size seems to be about 1/512th of the histogram (1/8th in each direction).
+ *
+ * Thomas' article also describes a refined method which is asymptotically
+ * faster than the brute-force method, but it is also far more complex and
+ * cannot efficiently be applied to small subboxes. It is therefore not
+ * useful for programs intended to be portable to DOS machines. On machines
+ * with plenty of memory, filling the whole histogram in one shot with Thomas'
+ * refined method might be faster than the present code --- but then again,
+ * it might not be any faster, and it's certainly more complicated.
+ */
+
+
+/* log2(histogram cells in update box) for each axis; this can be adjusted */
+#define BOX_C0_LOG (HIST_C0_BITS-3)
+#define BOX_C1_LOG (HIST_C1_BITS-3)
+#define BOX_C2_LOG (HIST_C2_BITS-3)
+
+#define BOX_C0_ELEMS (1<<BOX_C0_LOG) /* # of hist cells in update box */
+#define BOX_C1_ELEMS (1<<BOX_C1_LOG)
+#define BOX_C2_ELEMS (1<<BOX_C2_LOG)
+
+#define BOX_C0_SHIFT (C0_SHIFT + BOX_C0_LOG)
+#define BOX_C1_SHIFT (C1_SHIFT + BOX_C1_LOG)
+#define BOX_C2_SHIFT (C2_SHIFT + BOX_C2_LOG)
+
+
+/*
+ * The next three routines implement inverse colormap filling. They could
+ * all be folded into one big routine, but splitting them up this way saves
+ * some stack space (the mindist[] and bestdist[] arrays need not coexist)
+ * and may allow some compilers to produce better code by registerizing more
+ * inner-loop variables.
+ */
+
+LOCAL(int)
+find_nearby_colors (j_decompress_ptr cinfo, int minc0, int minc1, int minc2,
+ JSAMPLE colorlist[])
+/* Locate the colormap entries close enough to an update box to be candidates
+ * for the nearest entry to some cell(s) in the update box. The update box
+ * is specified by the center coordinates of its first cell. The number of
+ * candidate colormap entries is returned, and their colormap indexes are
+ * placed in colorlist[].
+ * This routine uses Heckbert's "locally sorted search" criterion to select
+ * the colors that need further consideration.
+ */
+{
+ int numcolors = cinfo->actual_number_of_colors;
+ int maxc0, maxc1, maxc2;
+ int centerc0, centerc1, centerc2;
+ int i, x, ncolors;
+ long minmaxdist, min_dist, max_dist, tdist;
+ long mindist[MAXNUMCOLORS]; /* min distance to colormap entry i */
+
+ /* Compute true coordinates of update box's upper corner and center.
+ * Actually we compute the coordinates of the center of the upper-corner
+ * histogram cell, which are the upper bounds of the volume we care about.
+ * Note that since ">>" rounds down, the "center" values may be closer to
+ * min than to max; hence comparisons to them must be "<=", not "<".
+ */
+ maxc0 = minc0 + ((1 << BOX_C0_SHIFT) - (1 << C0_SHIFT));
+ centerc0 = (minc0 + maxc0) >> 1;
+ maxc1 = minc1 + ((1 << BOX_C1_SHIFT) - (1 << C1_SHIFT));
+ centerc1 = (minc1 + maxc1) >> 1;
+ maxc2 = minc2 + ((1 << BOX_C2_SHIFT) - (1 << C2_SHIFT));
+ centerc2 = (minc2 + maxc2) >> 1;
+
+ /* For each color in colormap, find:
+ * 1. its minimum squared-distance to any point in the update box
+ * (zero if color is within update box);
+ * 2. its maximum squared-distance to any point in the update box.
+ * Both of these can be found by considering only the corners of the box.
+ * We save the minimum distance for each color in mindist[];
+ * only the smallest maximum distance is of interest.
+ */
+ minmaxdist = 0x7FFFFFFFL;
+
+ for (i = 0; i < numcolors; i++) {
+ /* We compute the squared-c0-distance term, then add in the other two. */
+ x = GETJSAMPLE(cinfo->colormap[0][i]);
+ if (x < minc0) {
+ tdist = (x - minc0) * C0_SCALE;
+ min_dist = tdist*tdist;
+ tdist = (x - maxc0) * C0_SCALE;
+ max_dist = tdist*tdist;
+ } else if (x > maxc0) {
+ tdist = (x - maxc0) * C0_SCALE;
+ min_dist = tdist*tdist;
+ tdist = (x - minc0) * C0_SCALE;
+ max_dist = tdist*tdist;
+ } else {
+ /* within cell range so no contribution to min_dist */
+ min_dist = 0;
+ if (x <= centerc0) {
+ tdist = (x - maxc0) * C0_SCALE;
+ max_dist = tdist*tdist;
+ } else {
+ tdist = (x - minc0) * C0_SCALE;
+ max_dist = tdist*tdist;
+ }
+ }
+
+ x = GETJSAMPLE(cinfo->colormap[1][i]);
+ if (x < minc1) {
+ tdist = (x - minc1) * C1_SCALE;
+ min_dist += tdist*tdist;
+ tdist = (x - maxc1) * C1_SCALE;
+ max_dist += tdist*tdist;
+ } else if (x > maxc1) {
+ tdist = (x - maxc1) * C1_SCALE;
+ min_dist += tdist*tdist;
+ tdist = (x - minc1) * C1_SCALE;
+ max_dist += tdist*tdist;
+ } else {
+ /* within cell range so no contribution to min_dist */
+ if (x <= centerc1) {
+ tdist = (x - maxc1) * C1_SCALE;
+ max_dist += tdist*tdist;
+ } else {
+ tdist = (x - minc1) * C1_SCALE;
+ max_dist += tdist*tdist;
+ }
+ }
+
+ x = GETJSAMPLE(cinfo->colormap[2][i]);
+ if (x < minc2) {
+ tdist = (x - minc2) * C2_SCALE;
+ min_dist += tdist*tdist;
+ tdist = (x - maxc2) * C2_SCALE;
+ max_dist += tdist*tdist;
+ } else if (x > maxc2) {
+ tdist = (x - maxc2) * C2_SCALE;
+ min_dist += tdist*tdist;
+ tdist = (x - minc2) * C2_SCALE;
+ max_dist += tdist*tdist;
+ } else {
+ /* within cell range so no contribution to min_dist */
+ if (x <= centerc2) {
+ tdist = (x - maxc2) * C2_SCALE;
+ max_dist += tdist*tdist;
+ } else {
+ tdist = (x - minc2) * C2_SCALE;
+ max_dist += tdist*tdist;
+ }
+ }
+
+ mindist[i] = min_dist; /* save away the results */
+ if (max_dist < minmaxdist)
+ minmaxdist = max_dist;
+ }
+
+ /* Now we know that no cell in the update box is more than minmaxdist
+ * away from some colormap entry. Therefore, only colors that are
+ * within minmaxdist of some part of the box need be considered.
+ */
+ ncolors = 0;
+ for (i = 0; i < numcolors; i++) {
+ if (mindist[i] <= minmaxdist)
+ colorlist[ncolors++] = (JSAMPLE) i;
+ }
+ return ncolors;
+}
+
+
+LOCAL(void)
+find_best_colors (j_decompress_ptr cinfo, int minc0, int minc1, int minc2,
+ int numcolors, JSAMPLE colorlist[], JSAMPLE bestcolor[])
+/* Find the closest colormap entry for each cell in the update box,
+ * given the list of candidate colors prepared by find_nearby_colors.
+ * Return the indexes of the closest entries in the bestcolor[] array.
+ * This routine uses Thomas' incremental distance calculation method to
+ * find the distance from a colormap entry to successive cells in the box.
+ */
+{
+ int ic0, ic1, ic2;
+ int i, icolor;
+ long * bptr; /* pointer into bestdist[] array */
+ JSAMPLE * cptr; /* pointer into bestcolor[] array */
+ long dist0, dist1; /* initial distance values */
+ long dist2; /* current distance in inner loop */
+ long xx0, xx1; /* distance increments */
+ long xx2;
+ long inc0, inc1, inc2; /* initial values for increments */
+ /* This array holds the distance to the nearest-so-far color for each cell */
+ long bestdist[BOX_C0_ELEMS * BOX_C1_ELEMS * BOX_C2_ELEMS];
+
+ /* Initialize best-distance for each cell of the update box */
+ bptr = bestdist;
+ for (i = BOX_C0_ELEMS*BOX_C1_ELEMS*BOX_C2_ELEMS-1; i >= 0; i--)
+ *bptr++ = 0x7FFFFFFFL;
+
+ /* For each color selected by find_nearby_colors,
+ * compute its distance to the center of each cell in the box.
+ * If that's less than best-so-far, update best distance and color number.
+ */
+
+ /* Nominal steps between cell centers ("x" in Thomas article) */
+#define STEP_C0 ((1 << C0_SHIFT) * C0_SCALE)
+#define STEP_C1 ((1 << C1_SHIFT) * C1_SCALE)
+#define STEP_C2 ((1 << C2_SHIFT) * C2_SCALE)
+
+ for (i = 0; i < numcolors; i++) {
+ icolor = GETJSAMPLE(colorlist[i]);
+ /* Compute (square of) distance from minc0/c1/c2 to this color */
+ inc0 = (minc0 - GETJSAMPLE(cinfo->colormap[0][icolor])) * C0_SCALE;
+ dist0 = inc0*inc0;
+ inc1 = (minc1 - GETJSAMPLE(cinfo->colormap[1][icolor])) * C1_SCALE;
+ dist0 += inc1*inc1;
+ inc2 = (minc2 - GETJSAMPLE(cinfo->colormap[2][icolor])) * C2_SCALE;
+ dist0 += inc2*inc2;
+ /* Form the initial difference increments */
+ inc0 = inc0 * (2 * STEP_C0) + STEP_C0 * STEP_C0;
+ inc1 = inc1 * (2 * STEP_C1) + STEP_C1 * STEP_C1;
+ inc2 = inc2 * (2 * STEP_C2) + STEP_C2 * STEP_C2;
+ /* Now loop over all cells in box, updating distance per Thomas method */
+ bptr = bestdist;
+ cptr = bestcolor;
+ xx0 = inc0;
+ for (ic0 = BOX_C0_ELEMS-1; ic0 >= 0; ic0--) {
+ dist1 = dist0;
+ xx1 = inc1;
+ for (ic1 = BOX_C1_ELEMS-1; ic1 >= 0; ic1--) {
+ dist2 = dist1;
+ xx2 = inc2;
+ for (ic2 = BOX_C2_ELEMS-1; ic2 >= 0; ic2--) {
+ if (dist2 < *bptr) {
+ *bptr = dist2;
+ *cptr = (JSAMPLE) icolor;
+ }
+ dist2 += xx2;
+ xx2 += 2 * STEP_C2 * STEP_C2;
+ bptr++;
+ cptr++;
+ }
+ dist1 += xx1;
+ xx1 += 2 * STEP_C1 * STEP_C1;
+ }
+ dist0 += xx0;
+ xx0 += 2 * STEP_C0 * STEP_C0;
+ }
+ }
+}
+
+
+LOCAL(void)
+fill_inverse_cmap (j_decompress_ptr cinfo, int c0, int c1, int c2)
+/* Fill the inverse-colormap entries in the update box that contains */
+/* histogram cell c0/c1/c2. (Only that one cell MUST be filled, but */
+/* we can fill as many others as we wish.) */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ int minc0, minc1, minc2; /* lower left corner of update box */
+ int ic0, ic1, ic2;
+ JSAMPLE * cptr; /* pointer into bestcolor[] array */
+ histptr cachep; /* pointer into main cache array */
+ /* This array lists the candidate colormap indexes. */
+ JSAMPLE colorlist[MAXNUMCOLORS];
+ int numcolors; /* number of candidate colors */
+ /* This array holds the actually closest colormap index for each cell. */
+ JSAMPLE bestcolor[BOX_C0_ELEMS * BOX_C1_ELEMS * BOX_C2_ELEMS];
+
+ /* Convert cell coordinates to update box ID */
+ c0 >>= BOX_C0_LOG;
+ c1 >>= BOX_C1_LOG;
+ c2 >>= BOX_C2_LOG;
+
+ /* Compute true coordinates of update box's origin corner.
+ * Actually we compute the coordinates of the center of the corner
+ * histogram cell, which are the lower bounds of the volume we care about.
+ */
+ minc0 = (c0 << BOX_C0_SHIFT) + ((1 << C0_SHIFT) >> 1);
+ minc1 = (c1 << BOX_C1_SHIFT) + ((1 << C1_SHIFT) >> 1);
+ minc2 = (c2 << BOX_C2_SHIFT) + ((1 << C2_SHIFT) >> 1);
+
+ /* Determine which colormap entries are close enough to be candidates
+ * for the nearest entry to some cell in the update box.
+ */
+ numcolors = find_nearby_colors(cinfo, minc0, minc1, minc2, colorlist);
+
+ /* Determine the actually nearest colors. */
+ find_best_colors(cinfo, minc0, minc1, minc2, numcolors, colorlist,
+ bestcolor);
+
+ /* Save the best color numbers (plus 1) in the main cache array */
+ c0 <<= BOX_C0_LOG; /* convert ID back to base cell indexes */
+ c1 <<= BOX_C1_LOG;
+ c2 <<= BOX_C2_LOG;
+ cptr = bestcolor;
+ for (ic0 = 0; ic0 < BOX_C0_ELEMS; ic0++) {
+ for (ic1 = 0; ic1 < BOX_C1_ELEMS; ic1++) {
+ cachep = & histogram[c0+ic0][c1+ic1][c2];
+ for (ic2 = 0; ic2 < BOX_C2_ELEMS; ic2++) {
+ *cachep++ = (histcell) (GETJSAMPLE(*cptr++) + 1);
+ }
+ }
+ }
+}
+
+
+/*
+ * Map some rows of pixels to the output colormapped representation.
+ */
+
+METHODDEF(void)
+pass2_no_dither (j_decompress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPARRAY output_buf, int num_rows)
+/* This version performs no dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ JSAMPROW inptr, outptr;
+ histptr cachep;
+ int c0, c1, c2;
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+
+ for (row = 0; row < num_rows; row++) {
+ inptr = input_buf[row];
+ outptr = output_buf[row];
+ for (col = width; col > 0; col--) {
+ /* get pixel value and index into the cache */
+ c0 = GETJSAMPLE(*inptr++) >> C0_SHIFT;
+ c1 = GETJSAMPLE(*inptr++) >> C1_SHIFT;
+ c2 = GETJSAMPLE(*inptr++) >> C2_SHIFT;
+ cachep = & histogram[c0][c1][c2];
+ /* If we have not seen this color before, find nearest colormap entry */
+ /* and update the cache */
+ if (*cachep == 0)
+ fill_inverse_cmap(cinfo, c0,c1,c2);
+ /* Now emit the colormap index for this cell */
+ *outptr++ = (JSAMPLE) (*cachep - 1);
+ }
+ }
+}
+
+
+METHODDEF(void)
+pass2_fs_dither (j_decompress_ptr cinfo,
+ JSAMPARRAY input_buf, JSAMPARRAY output_buf, int num_rows)
+/* This version performs Floyd-Steinberg dithering */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ LOCFSERROR cur0, cur1, cur2; /* current error or pixel value */
+ LOCFSERROR belowerr0, belowerr1, belowerr2; /* error for pixel below cur */
+ LOCFSERROR bpreverr0, bpreverr1, bpreverr2; /* error for below/prev col */
+ FSERRPTR errorptr; /* => fserrors[] at column before current */
+ JSAMPROW inptr; /* => current input pixel */
+ JSAMPROW outptr; /* => current output pixel */
+ histptr cachep;
+ int dir; /* +1 or -1 depending on direction */
+ int dir3; /* 3*dir, for advancing inptr & errorptr */
+ int row;
+ JDIMENSION col;
+ JDIMENSION width = cinfo->output_width;
+ JSAMPLE *range_limit = cinfo->sample_range_limit;
+ int *error_limit = cquantize->error_limiter;
+ JSAMPROW colormap0 = cinfo->colormap[0];
+ JSAMPROW colormap1 = cinfo->colormap[1];
+ JSAMPROW colormap2 = cinfo->colormap[2];
+ SHIFT_TEMPS
+
+ for (row = 0; row < num_rows; row++) {
+ inptr = input_buf[row];
+ outptr = output_buf[row];
+ if (cquantize->on_odd_row) {
+ /* work right to left in this row */
+ inptr += (width-1) * 3; /* so point to rightmost pixel */
+ outptr += width-1;
+ dir = -1;
+ dir3 = -3;
+ errorptr = cquantize->fserrors + (width+1)*3; /* => entry after last column */
+ cquantize->on_odd_row = FALSE; /* flip for next time */
+ } else {
+ /* work left to right in this row */
+ dir = 1;
+ dir3 = 3;
+ errorptr = cquantize->fserrors; /* => entry before first real column */
+ cquantize->on_odd_row = TRUE; /* flip for next time */
+ }
+ /* Preset error values: no error propagated to first pixel from left */
+ cur0 = cur1 = cur2 = 0;
+ /* and no error propagated to row below yet */
+ belowerr0 = belowerr1 = belowerr2 = 0;
+ bpreverr0 = bpreverr1 = bpreverr2 = 0;
+
+ for (col = width; col > 0; col--) {
+ /* curN holds the error propagated from the previous pixel on the
+ * current line. Add the error propagated from the previous line
+ * to form the complete error correction term for this pixel, and
+ * round the error term (which is expressed * 16) to an integer.
+ * RIGHT_SHIFT rounds towards minus infinity, so adding 8 is correct
+ * for either sign of the error value.
+ * Note: errorptr points to *previous* column's array entry.
+ */
+ cur0 = RIGHT_SHIFT(cur0 + errorptr[dir3+0] + 8, 4);
+ cur1 = RIGHT_SHIFT(cur1 + errorptr[dir3+1] + 8, 4);
+ cur2 = RIGHT_SHIFT(cur2 + errorptr[dir3+2] + 8, 4);
+ /* Limit the error using transfer function set by init_error_limit.
+ * See comments with init_error_limit for rationale.
+ */
+ cur0 = error_limit[cur0];
+ cur1 = error_limit[cur1];
+ cur2 = error_limit[cur2];
+ /* Form pixel value + error, and range-limit to 0..MAXJSAMPLE.
+ * The maximum error is +- MAXJSAMPLE (or less with error limiting);
+ * this sets the required size of the range_limit array.
+ */
+ cur0 += GETJSAMPLE(inptr[0]);
+ cur1 += GETJSAMPLE(inptr[1]);
+ cur2 += GETJSAMPLE(inptr[2]);
+ cur0 = GETJSAMPLE(range_limit[cur0]);
+ cur1 = GETJSAMPLE(range_limit[cur1]);
+ cur2 = GETJSAMPLE(range_limit[cur2]);
+ /* Index into the cache with adjusted pixel value */
+ cachep = & histogram[cur0>>C0_SHIFT][cur1>>C1_SHIFT][cur2>>C2_SHIFT];
+ /* If we have not seen this color before, find nearest colormap */
+ /* entry and update the cache */
+ if (*cachep == 0)
+ fill_inverse_cmap(cinfo, cur0>>C0_SHIFT,cur1>>C1_SHIFT,cur2>>C2_SHIFT);
+ /* Now emit the colormap index for this cell */
+ { int pixcode = *cachep - 1;
+ *outptr = (JSAMPLE) pixcode;
+ /* Compute representation error for this pixel */
+ cur0 -= GETJSAMPLE(colormap0[pixcode]);
+ cur1 -= GETJSAMPLE(colormap1[pixcode]);
+ cur2 -= GETJSAMPLE(colormap2[pixcode]);
+ }
+ /* Compute error fractions to be propagated to adjacent pixels.
+ * Add these into the running sums, and simultaneously shift the
+ * next-line error sums left by 1 column.
+ */
+ { LOCFSERROR bnexterr, delta;
+
+ bnexterr = cur0; /* Process component 0 */
+ delta = cur0 * 2;
+ cur0 += delta; /* form error * 3 */
+ errorptr[0] = (FSERROR) (bpreverr0 + cur0);
+ cur0 += delta; /* form error * 5 */
+ bpreverr0 = belowerr0 + cur0;
+ belowerr0 = bnexterr;
+ cur0 += delta; /* form error * 7 */
+ bnexterr = cur1; /* Process component 1 */
+ delta = cur1 * 2;
+ cur1 += delta; /* form error * 3 */
+ errorptr[1] = (FSERROR) (bpreverr1 + cur1);
+ cur1 += delta; /* form error * 5 */
+ bpreverr1 = belowerr1 + cur1;
+ belowerr1 = bnexterr;
+ cur1 += delta; /* form error * 7 */
+ bnexterr = cur2; /* Process component 2 */
+ delta = cur2 * 2;
+ cur2 += delta; /* form error * 3 */
+ errorptr[2] = (FSERROR) (bpreverr2 + cur2);
+ cur2 += delta; /* form error * 5 */
+ bpreverr2 = belowerr2 + cur2;
+ belowerr2 = bnexterr;
+ cur2 += delta; /* form error * 7 */
+ }
+ /* At this point curN contains the 7/16 error value to be propagated
+ * to the next pixel on the current line, and all the errors for the
+ * next line have been shifted over. We are therefore ready to move on.
+ */
+ inptr += dir3; /* Advance pixel pointers to next column */
+ outptr += dir;
+ errorptr += dir3; /* advance errorptr to current column */
+ }
+ /* Post-loop cleanup: we must unload the final error values into the
+ * final fserrors[] entry. Note we need not unload belowerrN because
+ * it is for the dummy column before or after the actual array.
+ */
+ errorptr[0] = (FSERROR) bpreverr0; /* unload prev errs into array */
+ errorptr[1] = (FSERROR) bpreverr1;
+ errorptr[2] = (FSERROR) bpreverr2;
+ }
+}
+
+
+/*
+ * Initialize the error-limiting transfer function (lookup table).
+ * The raw F-S error computation can potentially compute error values of up to
+ * +- MAXJSAMPLE. But we want the maximum correction applied to a pixel to be
+ * much less, otherwise obviously wrong pixels will be created. (Typical
+ * effects include weird fringes at color-area boundaries, isolated bright
+ * pixels in a dark area, etc.) The standard advice for avoiding this problem
+ * is to ensure that the "corners" of the color cube are allocated as output
+ * colors; then repeated errors in the same direction cannot cause cascading
+ * error buildup. However, that only prevents the error from getting
+ * completely out of hand; Aaron Giles reports that error limiting improves
+ * the results even with corner colors allocated.
+ * A simple clamping of the error values to about +- MAXJSAMPLE/8 works pretty
+ * well, but the smoother transfer function used below is even better. Thanks
+ * to Aaron Giles for this idea.
+ */
+
+LOCAL(void)
+init_error_limit (j_decompress_ptr cinfo)
+/* Allocate and fill in the error_limiter table */
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ int * table;
+ int in, out;
+
+ table = (int *) (*cinfo->mem->alloc_small)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, (MAXJSAMPLE*2+1) * SIZEOF(int));
+ table += MAXJSAMPLE; /* so can index -MAXJSAMPLE .. +MAXJSAMPLE */
+ cquantize->error_limiter = table;
+
+#define STEPSIZE ((MAXJSAMPLE+1)/16)
+ /* Map errors 1:1 up to +- MAXJSAMPLE/16 */
+ out = 0;
+ for (in = 0; in < STEPSIZE; in++, out++) {
+ table[in] = out; table[-in] = -out;
+ }
+ /* Map errors 1:2 up to +- 3*MAXJSAMPLE/16 */
+ for (; in < STEPSIZE*3; in++, out += (in&1) ? 0 : 1) {
+ table[in] = out; table[-in] = -out;
+ }
+ /* Clamp the rest to final out value (which is (MAXJSAMPLE+1)/8) */
+ for (; in <= MAXJSAMPLE; in++) {
+ table[in] = out; table[-in] = -out;
+ }
+#undef STEPSIZE
+}
+
+
+/*
+ * Finish up at the end of each pass.
+ */
+
+METHODDEF(void)
+finish_pass1 (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+
+ /* Select the representative colors and fill in cinfo->colormap */
+ cinfo->colormap = cquantize->sv_colormap;
+ select_colors(cinfo, cquantize->desired);
+ /* Force next pass to zero the color index table */
+ cquantize->needs_zeroed = TRUE;
+}
+
+
+METHODDEF(void)
+finish_pass2 (j_decompress_ptr )
+{
+ /* no work */
+}
+
+
+/*
+ * Initialize for each processing pass.
+ */
+
+METHODDEF(void)
+start_pass_2_quant (j_decompress_ptr cinfo, int is_pre_scan)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+ hist3d histogram = cquantize->histogram;
+ int i;
+
+ /* Only F-S dithering or no dithering is supported. */
+ /* If user asks for ordered dither, give him F-S. */
+ if (cinfo->dither_mode != JDITHER_NONE)
+ cinfo->dither_mode = JDITHER_FS;
+
+ if (is_pre_scan) {
+ /* Set up method pointers */
+ cquantize->pub.color_quantize = prescan_quantize;
+ cquantize->pub.finish_pass = finish_pass1;
+ cquantize->needs_zeroed = TRUE; /* Always zero histogram */
+ } else {
+ /* Set up method pointers */
+ if (cinfo->dither_mode == JDITHER_FS)
+ cquantize->pub.color_quantize = pass2_fs_dither;
+ else
+ cquantize->pub.color_quantize = pass2_no_dither;
+ cquantize->pub.finish_pass = finish_pass2;
+
+ /* Make sure color count is acceptable */
+ i = cinfo->actual_number_of_colors;
+ if (i < 1)
+ ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, 1);
+ if (i > MAXNUMCOLORS)
+ ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXNUMCOLORS);
+
+ if (cinfo->dither_mode == JDITHER_FS) {
+ size_t arraysize = (size_t) ((cinfo->output_width + 2) *
+ (3 * SIZEOF(FSERROR)));
+ /* Allocate Floyd-Steinberg workspace if we didn't already. */
+ if (cquantize->fserrors == NULL)
+ cquantize->fserrors = (FSERRPTR) (*cinfo->mem->alloc_large)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, arraysize);
+ /* Initialize the propagated errors to zero. */
+ jzero_far((void FAR *) cquantize->fserrors, arraysize);
+ /* Make the error-limit table if we didn't already. */
+ if (cquantize->error_limiter == NULL)
+ init_error_limit(cinfo);
+ cquantize->on_odd_row = FALSE;
+ }
+
+ }
+ /* Zero the histogram or inverse color map, if necessary */
+ if (cquantize->needs_zeroed) {
+ for (i = 0; i < HIST_C0_ELEMS; i++) {
+ jzero_far((void FAR *) histogram[i],
+ HIST_C1_ELEMS*HIST_C2_ELEMS * SIZEOF(histcell));
+ }
+ cquantize->needs_zeroed = FALSE;
+ }
+}
+
+
+/*
+ * Switch to a new external colormap between output passes.
+ */
+
+METHODDEF(void)
+new_color_map_2_quant (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize;
+
+ /* Reset the inverse color map */
+ cquantize->needs_zeroed = TRUE;
+}
+
+
+/*
+ * Module initialization routine for 2-pass color quantization.
+ */
+
+GLOBAL(void)
+jinit_2pass_quantizer (j_decompress_ptr cinfo)
+{
+ my_cquantize_ptr cquantize;
+ int i;
+
+ cquantize = (my_cquantize_ptr)
+ (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ SIZEOF(my_cquantizer));
+ cinfo->cquantize = (struct jpeg_color_quantizer *) cquantize;
+ cquantize->pub.start_pass = start_pass_2_quant;
+ cquantize->pub.new_color_map = new_color_map_2_quant;
+ cquantize->fserrors = NULL; /* flag optional arrays not allocated */
+ cquantize->error_limiter = NULL;
+
+ /* Make sure jdmaster didn't give me a case I can't handle */
+ if (cinfo->out_color_components != 3)
+ ERREXIT(cinfo, JERR_NOTIMPL);
+
+ /* Allocate the histogram/inverse colormap storage */
+ cquantize->histogram = (hist3d) (*cinfo->mem->alloc_small)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE, HIST_C0_ELEMS * SIZEOF(hist2d));
+ for (i = 0; i < HIST_C0_ELEMS; i++) {
+ cquantize->histogram[i] = (hist2d) (*cinfo->mem->alloc_large)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ HIST_C1_ELEMS*HIST_C2_ELEMS * SIZEOF(histcell));
+ }
+ cquantize->needs_zeroed = TRUE; /* histogram is garbage now */
+
+ /* Allocate storage for the completed colormap, if required.
+ * We do this now since it is FAR storage and may affect
+ * the memory manager's space calculations.
+ */
+ if (cinfo->enable_2pass_quant) {
+ /* Make sure color count is acceptable */
+ int desired = cinfo->desired_number_of_colors;
+ /* Lower bound on # of colors ... somewhat arbitrary as long as > 0 */
+ if (desired < 8)
+ ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, 8);
+ /* Make sure colormap indexes can be represented by JSAMPLEs */
+ if (desired > MAXNUMCOLORS)
+ ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXNUMCOLORS);
+ cquantize->sv_colormap = (*cinfo->mem->alloc_sarray)
+ ((j_common_ptr) cinfo,JPOOL_IMAGE, (JDIMENSION) desired, (JDIMENSION) 3);
+ cquantize->desired = desired;
+ } else
+ cquantize->sv_colormap = NULL;
+
+ /* Only F-S dithering or no dithering is supported. */
+ /* If user asks for ordered dither, give him F-S. */
+ if (cinfo->dither_mode != JDITHER_NONE)
+ cinfo->dither_mode = JDITHER_FS;
+
+ /* Allocate Floyd-Steinberg workspace if necessary.
+ * This isn't really needed until pass 2, but again it is FAR storage.
+ * Although we will cope with a later change in dither_mode,
+ * we do not promise to honor max_memory_to_use if dither_mode changes.
+ */
+ if (cinfo->dither_mode == JDITHER_FS) {
+ cquantize->fserrors = (FSERRPTR) (*cinfo->mem->alloc_large)
+ ((j_common_ptr) cinfo, JPOOL_IMAGE,
+ (size_t) ((cinfo->output_width + 2) * (3 * SIZEOF(FSERROR))));
+ /* Might as well create the error-limiting table too. */
+ init_error_limit(cinfo);
+ }
+}
+
+#endif /* QUANT_2PASS_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libjpeg/jutils.cpp b/ml/dlib/dlib/external/libjpeg/jutils.cpp
new file mode 100644
index 000000000..fd8906c83
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jutils.cpp
@@ -0,0 +1,179 @@
+/*
+ * jutils.c
+ *
+ * Copyright (C) 1991-1996, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains tables and miscellaneous utility routines needed
+ * for both compression and decompression.
+ * Note we prefix all global names with "j" to minimize conflicts with
+ * a surrounding application.
+ */
+
+#define JPEG_INTERNALS
+#include "jinclude.h"
+#include "jpeglib.h"
+
+
+/*
+ * jpeg_zigzag_order[i] is the zigzag-order position of the i'th element
+ * of a DCT block read in natural order (left to right, top to bottom).
+ */
+
+#if 0 /* This table is not actually needed in v6a */
+
+const int jpeg_zigzag_order[DCTSIZE2] = {
+ 0, 1, 5, 6, 14, 15, 27, 28,
+ 2, 4, 7, 13, 16, 26, 29, 42,
+ 3, 8, 12, 17, 25, 30, 41, 43,
+ 9, 11, 18, 24, 31, 40, 44, 53,
+ 10, 19, 23, 32, 39, 45, 52, 54,
+ 20, 22, 33, 38, 46, 51, 55, 60,
+ 21, 34, 37, 47, 50, 56, 59, 61,
+ 35, 36, 48, 49, 57, 58, 62, 63
+};
+
+#endif
+
+/*
+ * jpeg_natural_order[i] is the natural-order position of the i'th element
+ * of zigzag order.
+ *
+ * When reading corrupted data, the Huffman decoders could attempt
+ * to reference an entry beyond the end of this array (if the decoded
+ * zero run length reaches past the end of the block). To prevent
+ * wild stores without adding an inner-loop test, we put some extra
+ * "63"s after the real entries. This will cause the extra coefficient
+ * to be stored in location 63 of the block, not somewhere random.
+ * The worst case would be a run-length of 15, which means we need 16
+ * fake entries.
+ */
+
+const int jpeg_natural_order[DCTSIZE2+16] = {
+ 0, 1, 8, 16, 9, 2, 3, 10,
+ 17, 24, 32, 25, 18, 11, 4, 5,
+ 12, 19, 26, 33, 40, 48, 41, 34,
+ 27, 20, 13, 6, 7, 14, 21, 28,
+ 35, 42, 49, 56, 57, 50, 43, 36,
+ 29, 22, 15, 23, 30, 37, 44, 51,
+ 58, 59, 52, 45, 38, 31, 39, 46,
+ 53, 60, 61, 54, 47, 55, 62, 63,
+ 63, 63, 63, 63, 63, 63, 63, 63, /* extra entries for safety in decoder */
+ 63, 63, 63, 63, 63, 63, 63, 63
+};
+
+
+/*
+ * Arithmetic utilities
+ */
+
+GLOBAL(long)
+jdiv_round_up (long a, long b)
+/* Compute a/b rounded up to next integer, ie, ceil(a/b) */
+/* Assumes a >= 0, b > 0 */
+{
+ return (a + b - 1L) / b;
+}
+
+
+GLOBAL(long)
+jround_up (long a, long b)
+/* Compute a rounded up to next multiple of b, ie, ceil(a/b)*b */
+/* Assumes a >= 0, b > 0 */
+{
+ a += b - 1L;
+ return a - (a % b);
+}
+
+
+/* On normal machines we can apply MEMCOPY() and MEMZERO() to sample arrays
+ * and coefficient-block arrays. This won't work on 80x86 because the arrays
+ * are FAR and we're assuming a small-pointer memory model. However, some
+ * DOS compilers provide far-pointer versions of memcpy() and memset() even
+ * in the small-model libraries. These will be used if USE_FMEM is defined.
+ * Otherwise, the routines below do it the hard way. (The performance cost
+ * is not all that great, because these routines aren't very heavily used.)
+ */
+
+#ifndef NEED_FAR_POINTERS /* normal case, same as regular macros */
+#define FMEMCOPY(dest,src,size) MEMCOPY(dest,src,size)
+#define FMEMZERO(target,size) MEMZERO(target,size)
+#else /* 80x86 case, define if we can */
+#ifdef USE_FMEM
+#define FMEMCOPY(dest,src,size) _fmemcpy((void FAR *)(dest), (const void FAR *)(src), (size_t)(size))
+#define FMEMZERO(target,size) _fmemset((void FAR *)(target), 0, (size_t)(size))
+#endif
+#endif
+
+
+GLOBAL(void)
+jcopy_sample_rows (JSAMPARRAY input_array, int source_row,
+ JSAMPARRAY output_array, int dest_row,
+ int num_rows, JDIMENSION num_cols)
+/* Copy some rows of samples from one place to another.
+ * num_rows rows are copied from input_array[source_row++]
+ * to output_array[dest_row++]; these areas may overlap for duplication.
+ * The source and destination arrays must be at least as wide as num_cols.
+ */
+{
+ JSAMPROW inptr, outptr;
+#ifdef FMEMCOPY
+ size_t count = (size_t) (num_cols * SIZEOF(JSAMPLE));
+#else
+ JDIMENSION count;
+#endif
+ int row;
+
+ input_array += source_row;
+ output_array += dest_row;
+
+ for (row = num_rows; row > 0; row--) {
+ inptr = *input_array++;
+ outptr = *output_array++;
+#ifdef FMEMCOPY
+ FMEMCOPY(outptr, inptr, count);
+#else
+ for (count = num_cols; count > 0; count--)
+ *outptr++ = *inptr++; /* needn't bother with GETJSAMPLE() here */
+#endif
+ }
+}
+
+
+GLOBAL(void)
+jcopy_block_row (JBLOCKROW input_row, JBLOCKROW output_row,
+ JDIMENSION num_blocks)
+/* Copy a row of coefficient blocks from one place to another. */
+{
+#ifdef FMEMCOPY
+ FMEMCOPY(output_row, input_row, num_blocks * (DCTSIZE2 * SIZEOF(JCOEF)));
+#else
+ JCOEFPTR inptr, outptr;
+ long count;
+
+ inptr = (JCOEFPTR) input_row;
+ outptr = (JCOEFPTR) output_row;
+ for (count = (long) num_blocks * DCTSIZE2; count > 0; count--) {
+ *outptr++ = *inptr++;
+ }
+#endif
+}
+
+
+GLOBAL(void)
+jzero_far (void FAR * target, size_t bytestozero)
+/* Zero out a chunk of FAR memory. */
+/* This might be sample-array data, block-array data, or alloc_large data. */
+{
+#ifdef FMEMZERO
+ FMEMZERO(target, bytestozero);
+#else
+ char FAR * ptr = (char FAR *) target;
+ size_t count;
+
+ for (count = bytestozero; count > 0; count--) {
+ *ptr++ = 0;
+ }
+#endif
+}
diff --git a/ml/dlib/dlib/external/libjpeg/jversion.h b/ml/dlib/dlib/external/libjpeg/jversion.h
new file mode 100644
index 000000000..6472c58d3
--- /dev/null
+++ b/ml/dlib/dlib/external/libjpeg/jversion.h
@@ -0,0 +1,14 @@
+/*
+ * jversion.h
+ *
+ * Copyright (C) 1991-1998, Thomas G. Lane.
+ * This file is part of the Independent JPEG Group's software.
+ * For conditions of distribution and use, see the accompanying README file.
+ *
+ * This file contains software version identification.
+ */
+
+
+#define JVERSION "6b 27-Mar-1998"
+
+#define JCOPYRIGHT "Copyright (C) 1998, Thomas G. Lane"
diff --git a/ml/dlib/dlib/external/libpng/LICENSE b/ml/dlib/dlib/external/libpng/LICENSE
new file mode 100644
index 000000000..b1b97ea57
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/LICENSE
@@ -0,0 +1,111 @@
+
+This copy of the libpng notices is provided for your convenience. In case of
+any discrepancy between this copy and the notices in the file png.h that is
+included in the libpng distribution, the latter shall prevail.
+
+COPYRIGHT NOTICE, DISCLAIMER, and LICENSE:
+
+If you modify libpng you may insert additional notices immediately following
+this sentence.
+
+This code is released under the libpng license.
+
+libpng versions 1.2.6, August 15, 2004, through 1.6.7, November 14, 2013, are
+Copyright (c) 2004, 2006-2013 Glenn Randers-Pehrson, and are
+distributed according to the same disclaimer and license as libpng-1.2.5
+with the following individual added to the list of Contributing Authors
+
+ Cosmin Truta
+
+libpng versions 1.0.7, July 1, 2000, through 1.2.5 - October 3, 2002, are
+Copyright (c) 2000-2002 Glenn Randers-Pehrson, and are
+distributed according to the same disclaimer and license as libpng-1.0.6
+with the following individuals added to the list of Contributing Authors
+
+ Simon-Pierre Cadieux
+ Eric S. Raymond
+ Gilles Vollant
+
+and with the following additions to the disclaimer:
+
+ There is no warranty against interference with your enjoyment of the
+ library or against infringement. There is no warranty that our
+ efforts or the library will fulfill any of your particular purposes
+ or needs. This library is provided with all faults, and the entire
+ risk of satisfactory quality, performance, accuracy, and effort is with
+ the user.
+
+libpng versions 0.97, January 1998, through 1.0.6, March 20, 2000, are
+Copyright (c) 1998, 1999 Glenn Randers-Pehrson, and are
+distributed according to the same disclaimer and license as libpng-0.96,
+with the following individuals added to the list of Contributing Authors:
+
+ Tom Lane
+ Glenn Randers-Pehrson
+ Willem van Schaik
+
+libpng versions 0.89, June 1996, through 0.96, May 1997, are
+Copyright (c) 1996, 1997 Andreas Dilger
+Distributed according to the same disclaimer and license as libpng-0.88,
+with the following individuals added to the list of Contributing Authors:
+
+ John Bowler
+ Kevin Bracey
+ Sam Bushell
+ Magnus Holmgren
+ Greg Roelofs
+ Tom Tanner
+
+libpng versions 0.5, May 1995, through 0.88, January 1996, are
+Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.
+
+For the purposes of this copyright and license, "Contributing Authors"
+is defined as the following set of individuals:
+
+ Andreas Dilger
+ Dave Martindale
+ Guy Eric Schalnat
+ Paul Schmidt
+ Tim Wegner
+
+The PNG Reference Library is supplied "AS IS". The Contributing Authors
+and Group 42, Inc. disclaim all warranties, expressed or implied,
+including, without limitation, the warranties of merchantability and of
+fitness for any purpose. The Contributing Authors and Group 42, Inc.
+assume no liability for direct, indirect, incidental, special, exemplary,
+or consequential damages, which may result from the use of the PNG
+Reference Library, even if advised of the possibility of such damage.
+
+Permission is hereby granted to use, copy, modify, and distribute this
+source code, or portions hereof, for any purpose, without fee, subject
+to the following restrictions:
+
+1. The origin of this source code must not be misrepresented.
+
+2. Altered versions must be plainly marked as such and must not
+ be misrepresented as being the original source.
+
+3. This Copyright notice may not be removed or altered from any
+ source or altered source distribution.
+
+The Contributing Authors and Group 42, Inc. specifically permit, without
+fee, and encourage the use of this source code as a component to
+supporting the PNG file format in commercial products. If you use this
+source code in a product, acknowledgment is not required but would be
+appreciated.
+
+
+A "png_get_copyright" function is available, for convenient use in "about"
+boxes and the like:
+
+ printf("%s",png_get_copyright(NULL));
+
+Also, the PNG logo (in PNG format, of course) is supplied in the
+files "pngbar.png" and "pngbar.jpg (88x31) and "pngnow.png" (98x31).
+
+Libpng is OSI Certified Open Source Software. OSI Certified Open Source is a
+certification mark of the Open Source Initiative.
+
+Glenn Randers-Pehrson
+glennrp at users.sourceforge.net
+November 14, 2013
diff --git a/ml/dlib/dlib/external/libpng/README b/ml/dlib/dlib/external/libpng/README
new file mode 100644
index 000000000..80fc574ad
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/README
@@ -0,0 +1,202 @@
+README for libpng version 1.6.7 - November 14, 2013 (shared library 16.0)
+See the note about version numbers near the top of png.h
+
+See INSTALL for instructions on how to install libpng.
+
+Libpng comes in several distribution formats. Get libpng-*.tar.gz or
+libpng-*.tar.xz or if you want UNIX-style line endings in the text files,
+or lpng*.7z or lpng*.zip if you want DOS-style line endings.
+
+Version 0.89 was the first official release of libpng. Don't let the
+fact that it's the first release fool you. The libpng library has been in
+extensive use and testing since mid-1995. By late 1997 it had
+finally gotten to the stage where there hadn't been significant
+changes to the API in some time, and people have a bad feeling about
+libraries with versions < 1.0. Version 1.0.0 was released in
+March 1998.
+
+****
+Note that some of the changes to the png_info structure render this
+version of the library binary incompatible with libpng-0.89 or
+earlier versions if you are using a shared library. The type of the
+"filler" parameter for png_set_filler() has changed from png_byte to
+png_uint_32, which will affect shared-library applications that use
+this function.
+
+To avoid problems with changes to the internals of png_info_struct,
+new APIs have been made available in 0.95 to avoid direct application
+access to info_ptr. These functions are the png_set_<chunk> and
+png_get_<chunk> functions. These functions should be used when
+accessing/storing the info_struct data, rather than manipulating it
+directly, to avoid such problems in the future.
+
+It is important to note that the APIs do not make current programs
+that access the info struct directly incompatible with the new
+library. However, it is strongly suggested that new programs use
+the new APIs (as shown in example.c and pngtest.c), and older programs
+be converted to the new format, to facilitate upgrades in the future.
+****
+
+Additions since 0.90 include the ability to compile libpng as a
+Windows DLL, and new APIs for accessing data in the info struct.
+Experimental functions include the ability to set weighting and cost
+factors for row filter selection, direct reads of integers from buffers
+on big-endian processors that support misaligned data access, faster
+methods of doing alpha composition, and more accurate 16->8 bit color
+conversion.
+
+The additions since 0.89 include the ability to read from a PNG stream
+which has had some (or all) of the signature bytes read by the calling
+application. This also allows the reading of embedded PNG streams that
+do not have the PNG file signature. As well, it is now possible to set
+the library action on the detection of chunk CRC errors. It is possible
+to set different actions based on whether the CRC error occurred in a
+critical or an ancillary chunk.
+
+The changes made to the library, and bugs fixed are based on discussions
+on the PNG-implement mailing list and not on material submitted
+privately to Guy, Andreas, or Glenn. They will forward any good
+suggestions to the list.
+
+For a detailed description on using libpng, read libpng-manual.txt. For
+examples of libpng in a program, see example.c and pngtest.c. For usage
+information and restrictions (what little they are) on libpng, see
+png.h. For a description on using zlib (the compression library used by
+libpng) and zlib's restrictions, see zlib.h
+
+I have included a general makefile, as well as several machine and
+compiler specific ones, but you may have to modify one for your own needs.
+
+You should use zlib 1.0.4 or later to run this, but it MAY work with
+versions as old as zlib 0.95. Even so, there are bugs in older zlib
+versions which can cause the output of invalid compression streams for
+some images. You will definitely need zlib 1.0.4 or later if you are
+taking advantage of the MS-DOS "far" structure allocation for the small
+and medium memory models. You should also note that zlib is a
+compression library that is useful for more things than just PNG files.
+You can use zlib as a drop-in replacement for fread() and fwrite() if
+you are so inclined.
+
+zlib should be available at the same place that libpng is, or at zlib.net.
+
+You may also want a copy of the PNG specification. It is available
+as an RFC, a W3C Recommendation, and an ISO/IEC Standard. You can find
+these at http://www.libpng.org/pub/png/documents/
+
+This code is currently being archived at libpng.sf.net in the
+[DOWNLOAD] area, and at ftp://ftp.simplesystems.org. If you can't find it
+in any of those places, e-mail me, and I'll help you find it.
+
+If you have any code changes, requests, problems, etc., please e-mail
+them to me. Also, I'd appreciate any make files or project files,
+and any modifications you needed to make to get libpng to compile,
+along with a #define variable to tell what compiler/system you are on.
+If you needed to add transformations to libpng, or wish libpng would
+provide the image in a different way, drop me a note (and code, if
+possible), so I can consider supporting the transformation.
+Finally, if you get any warning messages when compiling libpng
+(note: not zlib), and they are easy to fix, I'd appreciate the
+fix. Please mention "libpng" somewhere in the subject line. Thanks.
+
+This release was created and will be supported by myself (of course
+based in a large way on Guy's and Andreas' earlier work), and the PNG
+development group.
+
+Send comments/corrections/commendations to png-mng-implement at
+lists.sourceforge.net (subscription required; visit
+https://lists.sourceforge.net/lists/listinfo/png-mng-implement
+to subscribe) or to glennrp at users.sourceforge.net
+
+You can't reach Guy, the original libpng author, at the addresses
+given in previous versions of this document. He and Andreas will
+read mail addressed to the png-implement list, however.
+
+Please do not send general questions about PNG. Send them to
+png-mng-misc at lists.sf.net (subscription required; visit
+https://lists.sourceforge.net/lists/listinfo/png-mng-misc to
+subscribe). If you have a question about something
+in the PNG specification that is related to using libpng, send it
+to me. Send me any questions that start with "I was using libpng,
+and ...". If in doubt, send questions to me. I'll bounce them
+to others, if necessary.
+
+Please do not send suggestions on how to change PNG. We have
+been discussing PNG for eighteen years now, and it is official and
+finished. If you have suggestions for libpng, however, I'll
+gladly listen. Even if your suggestion is not used immediately,
+it may be used later.
+
+Files in this distribution:
+
+ ANNOUNCE => Announcement of this version, with recent changes
+ CHANGES => Description of changes between libpng versions
+ KNOWNBUG => List of known bugs and deficiencies
+ LICENSE => License to use and redistribute libpng
+ README => This file
+ TODO => Things not implemented in the current library
+ Y2KINFO => Statement of Y2K compliance
+ example.c => Example code for using libpng functions
+ libpng.3 => manual page for libpng (includes libpng-manual.txt)
+ libpng-manual.txt => Description of libpng and its functions
+ libpngpf.3 => manual page for libpng's private functions
+ png.5 => manual page for the PNG format
+ png.c => Basic interface functions common to library
+ png.h => Library function and interface declarations (public)
+ pngpriv.h => Library function and interface declarations (private)
+ pngconf.h => System specific library configuration (public)
+ pngstruct.h => png_struct declaration (private)
+ pnginfo.h => png_info struct declaration (private)
+ pngdebug.h => debugging macros (private)
+ pngerror.c => Error/warning message I/O functions
+ pngget.c => Functions for retrieving info from struct
+ pngmem.c => Memory handling functions
+ pngbar.png => PNG logo, 88x31
+ pngnow.png => PNG logo, 98x31
+ pngpread.c => Progressive reading functions
+ pngread.c => Read data/helper high-level functions
+ pngrio.c => Lowest-level data read I/O functions
+ pngrtran.c => Read data transformation functions
+ pngrutil.c => Read data utility functions
+ pngset.c => Functions for storing data into the info_struct
+ pngtest.c => Library test program
+ pngtest.png => Library test sample image
+ pngtrans.c => Common data transformation functions
+ pngwio.c => Lowest-level write I/O functions
+ pngwrite.c => High-level write functions
+ pngwtran.c => Write data transformations
+ pngwutil.c => Write utility functions
+ arm => Contains optimized code for the ARM platform
+ contrib => Contributions
+ examples => Example programs
+ gregbook => source code for PNG reading and writing, from
+ Greg Roelofs' "PNG: The Definitive Guide",
+ O'Reilly, 1999
+ libtests => Test programs
+ pngminim => Minimal decoder, encoder, and progressive decoder
+ programs demonstrating use of pngusr.dfa
+ pngminus => Simple pnm2png and png2pnm programs
+ pngsuite => Test images
+ tools => Various tools
+ visupng => Contains a MSVC workspace for VisualPng
+ projects => Contains project files and workspaces for
+ building a DLL
+ owatcom => Contains a WATCOM project for building libpng
+ visualc71 => Contains a Microsoft Visual C++ (MSVC)
+ workspace for building libpng and zlib
+ vstudio => Contains a Microsoft Visual C++ (MSVC)
+ workspace for building libpng and zlib
+ scripts => Directory containing scripts for building libpng:
+ (see scripts/README.txt for the list of scripts)
+
+Good luck, and happy coding.
+
+-Glenn Randers-Pehrson (current maintainer, since 1998)
+ Internet: glennrp at users.sourceforge.net
+
+-Andreas Eric Dilger (former maintainer, 1996-1997)
+ Internet: adilger at enel.ucalgary.ca
+ Web: http://www-mddsp.enel.ucalgary.ca/People/adilger/
+
+-Guy Eric Schalnat (original author and former maintainer, 1995-1996)
+ (formerly of Group 42, Inc)
+ Internet: gschal at infinet.com
diff --git a/ml/dlib/dlib/external/libpng/arm/arm_init.c b/ml/dlib/dlib/external/libpng/arm/arm_init.c
new file mode 100644
index 000000000..098771781
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/arm/arm_init.c
@@ -0,0 +1,232 @@
+
+/* arm_init.c - NEON optimised filter functions
+ *
+ * Copyright (c) 2013 Glenn Randers-Pehrson
+ * Written by Mans Rullgard, 2011.
+ * Last changed in libpng 1.6.6 [September 16, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+/* Below, after checking __linux__, various non-C90 POSIX 1003.1 functions are
+ * called.
+ */
+#define _POSIX_SOURCE 1
+
+#include "../pngpriv.h"
+
+#ifdef PNG_READ_SUPPORTED
+#if PNG_ARM_NEON_OPT > 0
+#ifdef PNG_ARM_NEON_CHECK_SUPPORTED /* Do run-time checks */
+#include <signal.h> /* for sig_atomic_t */
+
+#ifdef __ANDROID__
+/* Linux provides access to information about CPU capabilites via
+ * /proc/self/auxv, however Android blocks this while still claiming to be
+ * Linux. The Andoid NDK, however, provides appropriate support.
+ *
+ * Documentation: http://www.kandroid.org/ndk/docs/CPU-ARM-NEON.html
+ */
+#include <cpu-features.h>
+
+static int
+png_have_neon(png_structp png_ptr)
+{
+ /* This is a whole lot easier than the mess below, however it is probably
+ * implemented as below, therefore it is better to cache the result (these
+ * function calls may be slow!)
+ */
+ PNG_UNUSED(png_ptr)
+ return android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM &&
+ (android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON) != 0;
+}
+#elif defined(__linux__)
+/* The generic __linux__ implementation requires reading /proc/self/auxv and
+ * looking at each element for one that records NEON capabilities.
+ */
+#include <unistd.h> /* for POSIX 1003.1 */
+#include <errno.h> /* for EINTR */
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <elf.h>
+#include <asm/hwcap.h>
+
+/* A read call may be interrupted, in which case it returns -1 and sets errno to
+ * EINTR if nothing was done, otherwise (if something was done) a partial read
+ * may result.
+ */
+static size_t
+safe_read(png_structp png_ptr, int fd, void *buffer_in, size_t nbytes)
+{
+ size_t ntotal = 0;
+ char *buffer = png_voidcast(char*, buffer_in);
+
+ while (nbytes > 0)
+ {
+ unsigned int nread;
+ int iread;
+
+ /* Passing nread > INT_MAX to read is implementation defined in POSIX
+ * 1003.1, therefore despite the unsigned argument portable code must
+ * limit the value to INT_MAX!
+ */
+ if (nbytes > INT_MAX)
+ nread = INT_MAX;
+
+ else
+ nread = (unsigned int)/*SAFE*/nbytes;
+
+ iread = read(fd, buffer, nread);
+
+ if (iread == -1)
+ {
+ /* This is the devil in the details, a read can terminate early with 0
+ * bytes read because of EINTR, yet it still returns -1 otherwise end
+ * of file cannot be distinguished.
+ */
+ if (errno != EINTR)
+ {
+ png_warning(png_ptr, "/proc read failed");
+ return 0; /* I.e. a permanent failure */
+ }
+ }
+
+ else if (iread < 0)
+ {
+ /* Not a valid 'read' result: */
+ png_warning(png_ptr, "OS /proc read bug");
+ return 0;
+ }
+
+ else if (iread > 0)
+ {
+ /* Continue reading until a permanent failure, or EOF */
+ buffer += iread;
+ nbytes -= (unsigned int)/*SAFE*/iread;
+ ntotal += (unsigned int)/*SAFE*/iread;
+ }
+
+ else
+ return ntotal;
+ }
+
+ return ntotal; /* nbytes == 0 */
+}
+
+static int
+png_have_neon(png_structp png_ptr)
+{
+ int fd = open("/proc/self/auxv", O_RDONLY);
+ Elf32_auxv_t aux;
+
+ /* Failsafe: failure to open means no NEON */
+ if (fd == -1)
+ {
+ png_warning(png_ptr, "/proc/self/auxv open failed");
+ return 0;
+ }
+
+ while (safe_read(png_ptr, fd, &aux, sizeof aux) == sizeof aux)
+ {
+ if (aux.a_type == AT_HWCAP && (aux.a_un.a_val & HWCAP_NEON) != 0)
+ {
+ close(fd);
+ return 1;
+ }
+ }
+
+ close(fd);
+ return 0;
+}
+#else
+ /* We don't know how to do a run-time check on this system */
+# error "no support for run-time ARM NEON checks"
+#endif /* OS checks */
+#endif /* PNG_ARM_NEON_CHECK_SUPPORTED */
+
+#ifndef PNG_ALIGNED_MEMORY_SUPPORTED
+# error "ALIGNED_MEMORY is required; set: -DPNG_ALIGNED_MEMORY_SUPPORTED"
+#endif
+
+void
+png_init_filter_functions_neon(png_structp pp, unsigned int bpp)
+{
+ /* The switch statement is compiled in for ARM_NEON_API, the call to
+ * png_have_neon is compiled in for ARM_NEON_CHECK. If both are defined
+ * the check is only performed if the API has not set the NEON option on
+ * or off explicitly. In this case the check controls what happens.
+ *
+ * If the CHECK is not compiled in and the option is UNSET the behavior prior
+ * to 1.6.7 was to use the NEON code - this was a bug caused by having the
+ * wrong order of the 'ON' and 'default' cases. UNSET now defaults to OFF,
+ * as documented in png.h
+ */
+#ifdef PNG_ARM_NEON_API_SUPPORTED
+ switch ((pp->options >> PNG_ARM_NEON) & 3)
+ {
+ case PNG_OPTION_UNSET:
+ /* Allow the run-time check to execute if it has been enabled -
+ * thus both API and CHECK can be turned on. If it isn't supported
+ * this case will fall through to the 'default' below, which just
+ * returns.
+ */
+#endif /* PNG_ARM_NEON_API_SUPPORTED */
+#ifdef PNG_ARM_NEON_CHECK_SUPPORTED
+ {
+ static volatile sig_atomic_t no_neon = -1; /* not checked */
+
+ if (no_neon < 0)
+ no_neon = !png_have_neon(pp);
+
+ if (no_neon)
+ return;
+ }
+#ifdef PNG_ARM_NEON_API_SUPPORTED
+ break;
+#endif
+#endif /* PNG_ARM_NEON_CHECK_SUPPORTED */
+
+#ifdef PNG_ARM_NEON_API_SUPPORTED
+ default: /* OFF or INVALID */
+ return;
+
+ case PNG_OPTION_ON:
+ /* Option turned on */
+ break;
+ }
+#endif
+
+ /* IMPORTANT: any new external functions used here must be declared using
+ * PNG_INTERNAL_FUNCTION in ../pngpriv.h. This is required so that the
+ * 'prefix' option to configure works:
+ *
+ * ./configure --with-libpng-prefix=foobar_
+ *
+ * Verify you have got this right by running the above command, doing a build
+ * and examining pngprefix.h; it must contain a #define for every external
+ * function you add. (Notice that this happens automatically for the
+ * initialization function.)
+ */
+ pp->read_filter[PNG_FILTER_VALUE_UP-1] = png_read_filter_row_up_neon;
+
+ if (bpp == 3)
+ {
+ pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub3_neon;
+ pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg3_neon;
+ pp->read_filter[PNG_FILTER_VALUE_PAETH-1] =
+ png_read_filter_row_paeth3_neon;
+ }
+
+ else if (bpp == 4)
+ {
+ pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub4_neon;
+ pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg4_neon;
+ pp->read_filter[PNG_FILTER_VALUE_PAETH-1] =
+ png_read_filter_row_paeth4_neon;
+ }
+}
+#endif /* PNG_ARM_NEON_OPT > 0 */
+#endif /* PNG_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/arm/filter_neon.S b/ml/dlib/dlib/external/libpng/arm/filter_neon.S
new file mode 100644
index 000000000..3d1ccf505
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/arm/filter_neon.S
@@ -0,0 +1,245 @@
+
+/* filter_neon.S - NEON optimised filter functions
+ *
+ * Copyright (c) 2013 Glenn Randers-Pehrson
+ * Written by Mans Rullgard, 2011.
+ * Last changed in libpng 1.6.7 [November 14, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+/* This is required to get the symbol renames, which are #defines, and also
+ * includes the definition (or not) of PNG_ARM_NEON_OPT.
+ */
+#define PNG_VERSION_INFO_ONLY
+#include "../pngpriv.h"
+
+#if defined(__linux__) && defined(__ELF__)
+.section .note.GNU-stack,"",%progbits /* mark stack as non-executable */
+#endif
+
+/* Assembler NEON support - only works for 32-bit ARM (i.e. it does not work for
+ * ARM64). The code in arm/filter_neon_intrinsics.c supports ARM64, however it
+ * only works if -mfpu=neon is specified on the GCC command line. See pngpriv.h
+ * for the logic which sets PNG_USE_ARM_NEON_ASM:
+ */
+#if PNG_ARM_NEON_IMPLEMENTATION == 2 /* hand-coded assembler */
+
+#ifdef PNG_READ_SUPPORTED
+#if PNG_ARM_NEON_OPT > 0
+
+#ifdef __ELF__
+# define ELF
+#else
+# define ELF @
+#endif
+
+ .arch armv7-a
+ .fpu neon
+
+.macro func name, export=0
+ .macro endfunc
+ELF .size \name, . - \name
+ .endfunc
+ .purgem endfunc
+ .endm
+ .text
+ .if \export
+ .global \name
+ .endif
+ELF .type \name, STT_FUNC
+ .func \name
+\name:
+.endm
+
+func png_read_filter_row_sub4_neon, export=1
+ ldr r3, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+1:
+ vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128]
+ vadd.u8 d0, d3, d4
+ vadd.u8 d1, d0, d5
+ vadd.u8 d2, d1, d6
+ vadd.u8 d3, d2, d7
+ vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]!
+ subs r3, r3, #16
+ bgt 1b
+
+ bx lr
+endfunc
+
+func png_read_filter_row_sub3_neon, export=1
+ ldr r3, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+ mov r0, r1
+ mov r2, #3
+ mov r12, #12
+ vld1.8 {q11}, [r0], r12
+1:
+ vext.8 d5, d22, d23, #3
+ vadd.u8 d0, d3, d22
+ vext.8 d6, d22, d23, #6
+ vadd.u8 d1, d0, d5
+ vext.8 d7, d23, d23, #1
+ vld1.8 {q11}, [r0], r12
+ vst1.32 {d0[0]}, [r1,:32], r2
+ vadd.u8 d2, d1, d6
+ vst1.32 {d1[0]}, [r1], r2
+ vadd.u8 d3, d2, d7
+ vst1.32 {d2[0]}, [r1], r2
+ vst1.32 {d3[0]}, [r1], r2
+ subs r3, r3, #12
+ bgt 1b
+
+ bx lr
+endfunc
+
+func png_read_filter_row_up_neon, export=1
+ ldr r3, [r0, #4] @ rowbytes
+1:
+ vld1.8 {q0}, [r1,:128]
+ vld1.8 {q1}, [r2,:128]!
+ vadd.u8 q0, q0, q1
+ vst1.8 {q0}, [r1,:128]!
+ subs r3, r3, #16
+ bgt 1b
+
+ bx lr
+endfunc
+
+func png_read_filter_row_avg4_neon, export=1
+ ldr r12, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+1:
+ vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128]
+ vld4.32 {d16[],d17[],d18[],d19[]},[r2,:128]!
+ vhadd.u8 d0, d3, d16
+ vadd.u8 d0, d0, d4
+ vhadd.u8 d1, d0, d17
+ vadd.u8 d1, d1, d5
+ vhadd.u8 d2, d1, d18
+ vadd.u8 d2, d2, d6
+ vhadd.u8 d3, d2, d19
+ vadd.u8 d3, d3, d7
+ vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]!
+ subs r12, r12, #16
+ bgt 1b
+
+ bx lr
+endfunc
+
+func png_read_filter_row_avg3_neon, export=1
+ push {r4,lr}
+ ldr r12, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+ mov r0, r1
+ mov r4, #3
+ mov lr, #12
+ vld1.8 {q11}, [r0], lr
+1:
+ vld1.8 {q10}, [r2], lr
+ vext.8 d5, d22, d23, #3
+ vhadd.u8 d0, d3, d20
+ vext.8 d17, d20, d21, #3
+ vadd.u8 d0, d0, d22
+ vext.8 d6, d22, d23, #6
+ vhadd.u8 d1, d0, d17
+ vext.8 d18, d20, d21, #6
+ vadd.u8 d1, d1, d5
+ vext.8 d7, d23, d23, #1
+ vld1.8 {q11}, [r0], lr
+ vst1.32 {d0[0]}, [r1,:32], r4
+ vhadd.u8 d2, d1, d18
+ vst1.32 {d1[0]}, [r1], r4
+ vext.8 d19, d21, d21, #1
+ vadd.u8 d2, d2, d6
+ vhadd.u8 d3, d2, d19
+ vst1.32 {d2[0]}, [r1], r4
+ vadd.u8 d3, d3, d7
+ vst1.32 {d3[0]}, [r1], r4
+ subs r12, r12, #12
+ bgt 1b
+
+ pop {r4,pc}
+endfunc
+
+.macro paeth rx, ra, rb, rc
+ vaddl.u8 q12, \ra, \rb @ a + b
+ vaddl.u8 q15, \rc, \rc @ 2*c
+ vabdl.u8 q13, \rb, \rc @ pa
+ vabdl.u8 q14, \ra, \rc @ pb
+ vabd.u16 q15, q12, q15 @ pc
+ vcle.u16 q12, q13, q14 @ pa <= pb
+ vcle.u16 q13, q13, q15 @ pa <= pc
+ vcle.u16 q14, q14, q15 @ pb <= pc
+ vand q12, q12, q13 @ pa <= pb && pa <= pc
+ vmovn.u16 d28, q14
+ vmovn.u16 \rx, q12
+ vbsl d28, \rb, \rc
+ vbsl \rx, \ra, d28
+.endm
+
+func png_read_filter_row_paeth4_neon, export=1
+ ldr r12, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+ vmov.i8 d20, #0
+1:
+ vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128]
+ vld4.32 {d16[],d17[],d18[],d19[]},[r2,:128]!
+ paeth d0, d3, d16, d20
+ vadd.u8 d0, d0, d4
+ paeth d1, d0, d17, d16
+ vadd.u8 d1, d1, d5
+ paeth d2, d1, d18, d17
+ vadd.u8 d2, d2, d6
+ paeth d3, d2, d19, d18
+ vmov d20, d19
+ vadd.u8 d3, d3, d7
+ vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]!
+ subs r12, r12, #16
+ bgt 1b
+
+ bx lr
+endfunc
+
+func png_read_filter_row_paeth3_neon, export=1
+ push {r4,lr}
+ ldr r12, [r0, #4] @ rowbytes
+ vmov.i8 d3, #0
+ vmov.i8 d4, #0
+ mov r0, r1
+ mov r4, #3
+ mov lr, #12
+ vld1.8 {q11}, [r0], lr
+1:
+ vld1.8 {q10}, [r2], lr
+ paeth d0, d3, d20, d4
+ vext.8 d5, d22, d23, #3
+ vadd.u8 d0, d0, d22
+ vext.8 d17, d20, d21, #3
+ paeth d1, d0, d17, d20
+ vst1.32 {d0[0]}, [r1,:32], r4
+ vext.8 d6, d22, d23, #6
+ vadd.u8 d1, d1, d5
+ vext.8 d18, d20, d21, #6
+ paeth d2, d1, d18, d17
+ vext.8 d7, d23, d23, #1
+ vld1.8 {q11}, [r0], lr
+ vst1.32 {d1[0]}, [r1], r4
+ vadd.u8 d2, d2, d6
+ vext.8 d19, d21, d21, #1
+ paeth d3, d2, d19, d18
+ vst1.32 {d2[0]}, [r1], r4
+ vmov d4, d19
+ vadd.u8 d3, d3, d7
+ vst1.32 {d3[0]}, [r1], r4
+ subs r12, r12, #12
+ bgt 1b
+
+ pop {r4,pc}
+endfunc
+#endif /* PNG_ARM_NEON_OPT > 0 */
+#endif /* PNG_READ_SUPPORTED */
+#endif /* PNG_ARM_NEON_IMPLEMENTATION == 2 (assembler) */
diff --git a/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c b/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c
new file mode 100644
index 000000000..e6a0217ab
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c
@@ -0,0 +1,372 @@
+
+/* filter_neon_intrinsics.c - NEON optimised filter functions
+ *
+ * Copyright (c) 2013 Glenn Randers-Pehrson
+ * Written by James Yu <james.yu at linaro.org>, October 2013.
+ * Based on filter_neon.S, written by Mans Rullgard, 2011.
+ *
+ * Last changed in libpng 1.6.7 [November 14, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "../pngpriv.h"
+
+/* This code requires -mfpu=neon on the command line: */
+#if PNG_ARM_NEON_IMPLEMENTATION == 1 /* intrinsics code */
+
+#include <arm_neon.h>
+
+/* libpng row pointers are not necessarily aligned to any particular boundary,
+ * however this code will only work with appropriate alignment. arm/arm_init.c
+ * checks for this (and will not compile unless it is done), this code uses
+ * variants of png_aligncast to avoid compiler warnings.
+ */
+#define png_ptr(type,pointer) png_aligncast(type *,pointer)
+#define png_ptrc(type,pointer) png_aligncastconst(const type *,pointer)
+
+/* The following relies on a variable 'temp_pointer' being declared with type
+ * 'type'. This is written this way just to hide the GCC strict aliasing
+ * warning; note that the code is safe because there never is an alias between
+ * the input and output pointers.
+ */
+#define png_ldr(type,pointer)\
+ (temp_pointer = png_ptr(type,pointer), *temp_pointer)
+
+#ifdef PNG_READ_SUPPORTED
+#if PNG_ARM_NEON_OPT > 0
+
+void
+png_read_filter_row_up_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+ png_const_bytep pp = prev_row;
+
+ for (; rp < rp_stop; rp += 16, pp += 16)
+ {
+ uint8x16_t qrp, qpp;
+
+ qrp = vld1q_u8(rp);
+ qpp = vld1q_u8(pp);
+ qrp = vaddq_u8(qrp, qpp);
+ vst1q_u8(rp, qrp);
+ }
+}
+
+void
+png_read_filter_row_sub3_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+
+ uint8x16_t vtmp = vld1q_u8(rp);
+ uint8x8x2_t *vrpt = png_ptr(uint8x8x2_t, &vtmp);
+ uint8x8x2_t vrp = *vrpt;
+
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ for (; rp < rp_stop;)
+ {
+ uint8x8_t vtmp1, vtmp2;
+ uint32x2_t *temp_pointer;
+
+ vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3);
+ vdest.val[0] = vadd_u8(vdest.val[3], vrp.val[0]);
+ vtmp2 = vext_u8(vrp.val[0], vrp.val[1], 6);
+ vdest.val[1] = vadd_u8(vdest.val[0], vtmp1);
+
+ vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1);
+ vdest.val[2] = vadd_u8(vdest.val[1], vtmp2);
+ vdest.val[3] = vadd_u8(vdest.val[2], vtmp1);
+
+ vtmp = vld1q_u8(rp + 12);
+ vrpt = png_ptr(uint8x8x2_t, &vtmp);
+ vrp = *vrpt;
+
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0);
+ rp += 3;
+ }
+
+ PNG_UNUSED(prev_row)
+}
+
+void
+png_read_filter_row_sub4_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ for (; rp < rp_stop; rp += 16)
+ {
+ uint32x2x4_t vtmp = vld4_u32(png_ptr(uint32_t,rp));
+ uint8x8x4_t *vrpt = png_ptr(uint8x8x4_t,&vtmp);
+ uint8x8x4_t vrp = *vrpt;
+ uint32x2x4_t *temp_pointer;
+
+ vdest.val[0] = vadd_u8(vdest.val[3], vrp.val[0]);
+ vdest.val[1] = vadd_u8(vdest.val[0], vrp.val[1]);
+ vdest.val[2] = vadd_u8(vdest.val[1], vrp.val[2]);
+ vdest.val[3] = vadd_u8(vdest.val[2], vrp.val[3]);
+ vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0);
+ }
+
+ PNG_UNUSED(prev_row)
+}
+
+void
+png_read_filter_row_avg3_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_const_bytep pp = prev_row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+
+ uint8x16_t vtmp;
+ uint8x8x2_t *vrpt;
+ uint8x8x2_t vrp;
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ vtmp = vld1q_u8(rp);
+ vrpt = png_ptr(uint8x8x2_t,&vtmp);
+ vrp = *vrpt;
+
+ for (; rp < rp_stop; pp += 12)
+ {
+ uint8x8_t vtmp1, vtmp2, vtmp3;
+
+ uint8x8x2_t *vppt;
+ uint8x8x2_t vpp;
+
+ uint32x2_t *temp_pointer;
+
+ vtmp = vld1q_u8(pp);
+ vppt = png_ptr(uint8x8x2_t,&vtmp);
+ vpp = *vppt;
+
+ vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3);
+ vdest.val[0] = vhadd_u8(vdest.val[3], vpp.val[0]);
+ vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]);
+
+ vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 3);
+ vtmp3 = vext_u8(vrp.val[0], vrp.val[1], 6);
+ vdest.val[1] = vhadd_u8(vdest.val[0], vtmp2);
+ vdest.val[1] = vadd_u8(vdest.val[1], vtmp1);
+
+ vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 6);
+ vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1);
+
+ vtmp = vld1q_u8(rp + 12);
+ vrpt = png_ptr(uint8x8x2_t,&vtmp);
+ vrp = *vrpt;
+
+ vdest.val[2] = vhadd_u8(vdest.val[1], vtmp2);
+ vdest.val[2] = vadd_u8(vdest.val[2], vtmp3);
+
+ vtmp2 = vext_u8(vpp.val[1], vpp.val[1], 1);
+
+ vdest.val[3] = vhadd_u8(vdest.val[2], vtmp2);
+ vdest.val[3] = vadd_u8(vdest.val[3], vtmp1);
+
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0);
+ rp += 3;
+ }
+}
+
+void
+png_read_filter_row_avg4_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+ png_const_bytep pp = prev_row;
+
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ for (; rp < rp_stop; rp += 16, pp += 16)
+ {
+ uint32x2x4_t vtmp;
+ uint8x8x4_t *vrpt, *vppt;
+ uint8x8x4_t vrp, vpp;
+ uint32x2x4_t *temp_pointer;
+
+ vtmp = vld4_u32(png_ptr(uint32_t,rp));
+ vrpt = png_ptr(uint8x8x4_t,&vtmp);
+ vrp = *vrpt;
+ vtmp = vld4_u32(png_ptrc(uint32_t,pp));
+ vppt = png_ptr(uint8x8x4_t,&vtmp);
+ vpp = *vppt;
+
+ vdest.val[0] = vhadd_u8(vdest.val[3], vpp.val[0]);
+ vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]);
+ vdest.val[1] = vhadd_u8(vdest.val[0], vpp.val[1]);
+ vdest.val[1] = vadd_u8(vdest.val[1], vrp.val[1]);
+ vdest.val[2] = vhadd_u8(vdest.val[1], vpp.val[2]);
+ vdest.val[2] = vadd_u8(vdest.val[2], vrp.val[2]);
+ vdest.val[3] = vhadd_u8(vdest.val[2], vpp.val[3]);
+ vdest.val[3] = vadd_u8(vdest.val[3], vrp.val[3]);
+
+ vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0);
+ }
+}
+
+static uint8x8_t
+paeth(uint8x8_t a, uint8x8_t b, uint8x8_t c)
+{
+ uint8x8_t d, e;
+ uint16x8_t p1, pa, pb, pc;
+
+ p1 = vaddl_u8(a, b); /* a + b */
+ pc = vaddl_u8(c, c); /* c * 2 */
+ pa = vabdl_u8(b, c); /* pa */
+ pb = vabdl_u8(a, c); /* pb */
+ pc = vabdq_u16(p1, pc); /* pc */
+
+ p1 = vcleq_u16(pa, pb); /* pa <= pb */
+ pa = vcleq_u16(pa, pc); /* pa <= pc */
+ pb = vcleq_u16(pb, pc); /* pb <= pc */
+
+ p1 = vandq_u16(p1, pa); /* pa <= pb && pa <= pc */
+
+ d = vmovn_u16(pb);
+ e = vmovn_u16(p1);
+
+ d = vbsl_u8(d, b, c);
+ e = vbsl_u8(e, a, d);
+
+ return e;
+}
+
+void
+png_read_filter_row_paeth3_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_const_bytep pp = prev_row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+
+ uint8x16_t vtmp;
+ uint8x8x2_t *vrpt;
+ uint8x8x2_t vrp;
+ uint8x8_t vlast = vdup_n_u8(0);
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ vtmp = vld1q_u8(rp);
+ vrpt = png_ptr(uint8x8x2_t,&vtmp);
+ vrp = *vrpt;
+
+ for (; rp < rp_stop; pp += 12)
+ {
+ uint8x8x2_t *vppt;
+ uint8x8x2_t vpp;
+ uint8x8_t vtmp1, vtmp2, vtmp3;
+ uint32x2_t *temp_pointer;
+
+ vtmp = vld1q_u8(pp);
+ vppt = png_ptr(uint8x8x2_t,&vtmp);
+ vpp = *vppt;
+
+ vdest.val[0] = paeth(vdest.val[3], vpp.val[0], vlast);
+ vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]);
+
+ vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3);
+ vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 3);
+ vdest.val[1] = paeth(vdest.val[0], vtmp2, vpp.val[0]);
+ vdest.val[1] = vadd_u8(vdest.val[1], vtmp1);
+
+ vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 6);
+ vtmp3 = vext_u8(vpp.val[0], vpp.val[1], 6);
+ vdest.val[2] = paeth(vdest.val[1], vtmp3, vtmp2);
+ vdest.val[2] = vadd_u8(vdest.val[2], vtmp1);
+
+ vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1);
+ vtmp2 = vext_u8(vpp.val[1], vpp.val[1], 1);
+
+ vtmp = vld1q_u8(rp + 12);
+ vrpt = png_ptr(uint8x8x2_t,&vtmp);
+ vrp = *vrpt;
+
+ vdest.val[3] = paeth(vdest.val[2], vtmp2, vtmp3);
+ vdest.val[3] = vadd_u8(vdest.val[3], vtmp1);
+
+ vlast = vtmp2;
+
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0);
+ rp += 3;
+ vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0);
+ rp += 3;
+ }
+}
+
+void
+png_read_filter_row_paeth4_neon(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp = row;
+ png_bytep rp_stop = row + row_info->rowbytes;
+ png_const_bytep pp = prev_row;
+
+ uint8x8_t vlast = vdup_n_u8(0);
+ uint8x8x4_t vdest;
+ vdest.val[3] = vdup_n_u8(0);
+
+ for (; rp < rp_stop; rp += 16, pp += 16)
+ {
+ uint32x2x4_t vtmp;
+ uint8x8x4_t *vrpt, *vppt;
+ uint8x8x4_t vrp, vpp;
+ uint32x2x4_t *temp_pointer;
+
+ vtmp = vld4_u32(png_ptr(uint32_t,rp));
+ vrpt = png_ptr(uint8x8x4_t,&vtmp);
+ vrp = *vrpt;
+ vtmp = vld4_u32(png_ptrc(uint32_t,pp));
+ vppt = png_ptr(uint8x8x4_t,&vtmp);
+ vpp = *vppt;
+
+ vdest.val[0] = paeth(vdest.val[3], vpp.val[0], vlast);
+ vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]);
+ vdest.val[1] = paeth(vdest.val[0], vpp.val[1], vpp.val[0]);
+ vdest.val[1] = vadd_u8(vdest.val[1], vrp.val[1]);
+ vdest.val[2] = paeth(vdest.val[1], vpp.val[2], vpp.val[1]);
+ vdest.val[2] = vadd_u8(vdest.val[2], vrp.val[2]);
+ vdest.val[3] = paeth(vdest.val[2], vpp.val[3], vpp.val[2]);
+ vdest.val[3] = vadd_u8(vdest.val[3], vrp.val[3]);
+
+ vlast = vpp.val[3];
+
+ vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0);
+ }
+}
+
+#endif /* PNG_ARM_NEON_OPT > 0 */
+#endif /* PNG_READ_SUPPORTED */
+#endif /* PNG_ARM_NEON_IMPLEMENTATION == 1 (intrinsics) */
diff --git a/ml/dlib/dlib/external/libpng/png.c b/ml/dlib/dlib/external/libpng/png.c
new file mode 100644
index 000000000..efcc6eead
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/png.c
@@ -0,0 +1,4299 @@
+
+/* png.c - location for general purpose libpng functions
+ *
+ * Last changed in libpng 1.6.2 [April 25, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+
+/* Generate a compiler error if there is an old png.h in the search path. */
+typedef png_libpng_version_1_6_7 Your_png_h_is_not_version_1_6_7;
+
+/* Tells libpng that we have already handled the first "num_bytes" bytes
+ * of the PNG file signature. If the PNG data is embedded into another
+ * stream we can set num_bytes = 8 so that libpng will not attempt to read
+ * or write any of the magic bytes before it starts on the IHDR.
+ */
+
+#ifdef PNG_READ_SUPPORTED
+void PNGAPI
+png_set_sig_bytes(png_structrp png_ptr, int num_bytes)
+{
+ png_debug(1, "in png_set_sig_bytes");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (num_bytes > 8)
+ png_error(png_ptr, "Too many bytes for PNG signature");
+
+ png_ptr->sig_bytes = (png_byte)(num_bytes < 0 ? 0 : num_bytes);
+}
+
+/* Checks whether the supplied bytes match the PNG signature. We allow
+ * checking less than the full 8-byte signature so that those apps that
+ * already read the first few bytes of a file to determine the file type
+ * can simply check the remaining bytes for extra assurance. Returns
+ * an integer less than, equal to, or greater than zero if sig is found,
+ * respectively, to be less than, to match, or be greater than the correct
+ * PNG signature (this is the same behavior as strcmp, memcmp, etc).
+ */
+int PNGAPI
+png_sig_cmp(png_const_bytep sig, png_size_t start, png_size_t num_to_check)
+{
+ png_byte png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10};
+
+ if (num_to_check > 8)
+ num_to_check = 8;
+
+ else if (num_to_check < 1)
+ return (-1);
+
+ if (start > 7)
+ return (-1);
+
+ if (start + num_to_check > 8)
+ num_to_check = 8 - start;
+
+ return ((int)(memcmp(&sig[start], &png_signature[start], num_to_check)));
+}
+
+#endif /* PNG_READ_SUPPORTED */
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+/* Function to allocate memory for zlib */
+PNG_FUNCTION(voidpf /* PRIVATE */,
+png_zalloc,(voidpf png_ptr, uInt items, uInt size),PNG_ALLOCATED)
+{
+ png_alloc_size_t num_bytes = size;
+
+ if (png_ptr == NULL)
+ return NULL;
+
+ if (items >= (~(png_alloc_size_t)0)/size)
+ {
+ png_warning (png_voidcast(png_structrp, png_ptr),
+ "Potential overflow in png_zalloc()");
+ return NULL;
+ }
+
+ num_bytes *= items;
+ return png_malloc_warn(png_voidcast(png_structrp, png_ptr), num_bytes);
+}
+
+/* Function to free memory for zlib */
+void /* PRIVATE */
+png_zfree(voidpf png_ptr, voidpf ptr)
+{
+ png_free(png_voidcast(png_const_structrp,png_ptr), ptr);
+}
+
+/* Reset the CRC variable to 32 bits of 1's. Care must be taken
+ * in case CRC is > 32 bits to leave the top bits 0.
+ */
+void /* PRIVATE */
+png_reset_crc(png_structrp png_ptr)
+{
+ /* The cast is safe because the crc is a 32 bit value. */
+ png_ptr->crc = (png_uint_32)crc32(0, Z_NULL, 0);
+}
+
+/* Calculate the CRC over a section of data. We can only pass as
+ * much data to this routine as the largest single buffer size. We
+ * also check that this data will actually be used before going to the
+ * trouble of calculating it.
+ */
+void /* PRIVATE */
+png_calculate_crc(png_structrp png_ptr, png_const_bytep ptr, png_size_t length)
+{
+ int need_crc = 1;
+
+ if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name))
+ {
+ if ((png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_MASK) ==
+ (PNG_FLAG_CRC_ANCILLARY_USE | PNG_FLAG_CRC_ANCILLARY_NOWARN))
+ need_crc = 0;
+ }
+
+ else /* critical */
+ {
+ if (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_IGNORE)
+ need_crc = 0;
+ }
+
+ /* 'uLong' is defined in zlib.h as unsigned long; this means that on some
+ * systems it is a 64 bit value. crc32, however, returns 32 bits so the
+ * following cast is safe. 'uInt' may be no more than 16 bits, so it is
+ * necessary to perform a loop here.
+ */
+ if (need_crc && length > 0)
+ {
+ uLong crc = png_ptr->crc; /* Should never issue a warning */
+
+ do
+ {
+ uInt safe_length = (uInt)length;
+ if (safe_length == 0)
+ safe_length = (uInt)-1; /* evil, but safe */
+
+ crc = crc32(crc, ptr, safe_length);
+
+ /* The following should never issue compiler warnings; if they do the
+ * target system has characteristics that will probably violate other
+ * assumptions within the libpng code.
+ */
+ ptr += safe_length;
+ length -= safe_length;
+ }
+ while (length > 0);
+
+ /* And the following is always safe because the crc is only 32 bits. */
+ png_ptr->crc = (png_uint_32)crc;
+ }
+}
+
+/* Check a user supplied version number, called from both read and write
+ * functions that create a png_struct.
+ */
+int
+png_user_version_check(png_structrp png_ptr, png_const_charp user_png_ver)
+{
+ if (user_png_ver)
+ {
+ int i = 0;
+
+ do
+ {
+ if (user_png_ver[i] != png_libpng_ver[i])
+ png_ptr->flags |= PNG_FLAG_LIBRARY_MISMATCH;
+ } while (png_libpng_ver[i++]);
+ }
+
+ else
+ png_ptr->flags |= PNG_FLAG_LIBRARY_MISMATCH;
+
+ if (png_ptr->flags & PNG_FLAG_LIBRARY_MISMATCH)
+ {
+ /* Libpng 0.90 and later are binary incompatible with libpng 0.89, so
+ * we must recompile any applications that use any older library version.
+ * For versions after libpng 1.0, we will be compatible, so we need
+ * only check the first and third digits (note that when we reach version
+ * 1.10 we will need to check the fourth symbol, namely user_png_ver[3]).
+ */
+ if (user_png_ver == NULL || user_png_ver[0] != png_libpng_ver[0] ||
+ (user_png_ver[0] == '1' && (user_png_ver[2] != png_libpng_ver[2] ||
+ user_png_ver[3] != png_libpng_ver[3])) ||
+ (user_png_ver[0] == '0' && user_png_ver[2] < '9'))
+ {
+#ifdef PNG_WARNINGS_SUPPORTED
+ size_t pos = 0;
+ char m[128];
+
+ pos = png_safecat(m, (sizeof m), pos,
+ "Application built with libpng-");
+ pos = png_safecat(m, (sizeof m), pos, user_png_ver);
+ pos = png_safecat(m, (sizeof m), pos, " but running with ");
+ pos = png_safecat(m, (sizeof m), pos, png_libpng_ver);
+
+ png_warning(png_ptr, m);
+#endif
+
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+ png_ptr->flags = 0;
+#endif
+
+ return 0;
+ }
+ }
+
+ /* Success return. */
+ return 1;
+}
+
+/* Generic function to create a png_struct for either read or write - this
+ * contains the common initialization.
+ */
+PNG_FUNCTION(png_structp /* PRIVATE */,
+png_create_png_struct,(png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr,
+ png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED)
+{
+ png_struct create_struct;
+# ifdef PNG_SETJMP_SUPPORTED
+ jmp_buf create_jmp_buf;
+# endif
+
+ /* This temporary stack-allocated structure is used to provide a place to
+ * build enough context to allow the user provided memory allocator (if any)
+ * to be called.
+ */
+ memset(&create_struct, 0, (sizeof create_struct));
+
+ /* Added at libpng-1.2.6 */
+# ifdef PNG_USER_LIMITS_SUPPORTED
+ create_struct.user_width_max = PNG_USER_WIDTH_MAX;
+ create_struct.user_height_max = PNG_USER_HEIGHT_MAX;
+
+# ifdef PNG_USER_CHUNK_CACHE_MAX
+ /* Added at libpng-1.2.43 and 1.4.0 */
+ create_struct.user_chunk_cache_max = PNG_USER_CHUNK_CACHE_MAX;
+# endif
+
+# ifdef PNG_USER_CHUNK_MALLOC_MAX
+ /* Added at libpng-1.2.43 and 1.4.1, required only for read but exists
+ * in png_struct regardless.
+ */
+ create_struct.user_chunk_malloc_max = PNG_USER_CHUNK_MALLOC_MAX;
+# endif
+# endif
+
+ /* The following two API calls simply set fields in png_struct, so it is safe
+ * to do them now even though error handling is not yet set up.
+ */
+# ifdef PNG_USER_MEM_SUPPORTED
+ png_set_mem_fn(&create_struct, mem_ptr, malloc_fn, free_fn);
+# endif
+
+ /* (*error_fn) can return control to the caller after the error_ptr is set,
+ * this will result in a memory leak unless the error_fn does something
+ * extremely sophisticated. The design lacks merit but is implicit in the
+ * API.
+ */
+ png_set_error_fn(&create_struct, error_ptr, error_fn, warn_fn);
+
+# ifdef PNG_SETJMP_SUPPORTED
+ if (!setjmp(create_jmp_buf))
+ {
+ /* Temporarily fake out the longjmp information until we have
+ * successfully completed this function. This only works if we have
+ * setjmp() support compiled in, but it is safe - this stuff should
+ * never happen.
+ */
+ create_struct.jmp_buf_ptr = &create_jmp_buf;
+ create_struct.jmp_buf_size = 0; /*stack allocation*/
+ create_struct.longjmp_fn = longjmp;
+# else
+ {
+# endif
+ /* Call the general version checker (shared with read and write code):
+ */
+ if (png_user_version_check(&create_struct, user_png_ver))
+ {
+ png_structrp png_ptr = png_voidcast(png_structrp,
+ png_malloc_warn(&create_struct, (sizeof *png_ptr)));
+
+ if (png_ptr != NULL)
+ {
+ /* png_ptr->zstream holds a back-pointer to the png_struct, so
+ * this can only be done now:
+ */
+ create_struct.zstream.zalloc = png_zalloc;
+ create_struct.zstream.zfree = png_zfree;
+ create_struct.zstream.opaque = png_ptr;
+
+# ifdef PNG_SETJMP_SUPPORTED
+ /* Eliminate the local error handling: */
+ create_struct.jmp_buf_ptr = NULL;
+ create_struct.jmp_buf_size = 0;
+ create_struct.longjmp_fn = 0;
+# endif
+
+ *png_ptr = create_struct;
+
+ /* This is the successful return point */
+ return png_ptr;
+ }
+ }
+ }
+
+ /* A longjmp because of a bug in the application storage allocator or a
+ * simple failure to allocate the png_struct.
+ */
+ return NULL;
+}
+
+/* Allocate the memory for an info_struct for the application. */
+PNG_FUNCTION(png_infop,PNGAPI
+png_create_info_struct,(png_const_structrp png_ptr),PNG_ALLOCATED)
+{
+ png_inforp info_ptr;
+
+ png_debug(1, "in png_create_info_struct");
+
+ if (png_ptr == NULL)
+ return NULL;
+
+ /* Use the internal API that does not (or at least should not) error out, so
+ * that this call always returns ok. The application typically sets up the
+ * error handling *after* creating the info_struct because this is the way it
+ * has always been done in 'example.c'.
+ */
+ info_ptr = png_voidcast(png_inforp, png_malloc_base(png_ptr,
+ (sizeof *info_ptr)));
+
+ if (info_ptr != NULL)
+ memset(info_ptr, 0, (sizeof *info_ptr));
+
+ return info_ptr;
+}
+
+/* This function frees the memory associated with a single info struct.
+ * Normally, one would use either png_destroy_read_struct() or
+ * png_destroy_write_struct() to free an info struct, but this may be
+ * useful for some applications. From libpng 1.6.0 this function is also used
+ * internally to implement the png_info release part of the 'struct' destroy
+ * APIs. This ensures that all possible approaches free the same data (all of
+ * it).
+ */
+void PNGAPI
+png_destroy_info_struct(png_const_structrp png_ptr, png_infopp info_ptr_ptr)
+{
+ png_inforp info_ptr = NULL;
+
+ png_debug(1, "in png_destroy_info_struct");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (info_ptr_ptr != NULL)
+ info_ptr = *info_ptr_ptr;
+
+ if (info_ptr != NULL)
+ {
+ /* Do this first in case of an error below; if the app implements its own
+ * memory management this can lead to png_free calling png_error, which
+ * will abort this routine and return control to the app error handler.
+ * An infinite loop may result if it then tries to free the same info
+ * ptr.
+ */
+ *info_ptr_ptr = NULL;
+
+ png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1);
+ memset(info_ptr, 0, (sizeof *info_ptr));
+ png_free(png_ptr, info_ptr);
+ }
+}
+
+/* Initialize the info structure. This is now an internal function (0.89)
+ * and applications using it are urged to use png_create_info_struct()
+ * instead. Use deprecated in 1.6.0, internal use removed (used internally it
+ * is just a memset).
+ *
+ * NOTE: it is almost inconceivable that this API is used because it bypasses
+ * the user-memory mechanism and the user error handling/warning mechanisms in
+ * those cases where it does anything other than a memset.
+ */
+PNG_FUNCTION(void,PNGAPI
+png_info_init_3,(png_infopp ptr_ptr, png_size_t png_info_struct_size),
+ PNG_DEPRECATED)
+{
+ png_inforp info_ptr = *ptr_ptr;
+
+ png_debug(1, "in png_info_init_3");
+
+ if (info_ptr == NULL)
+ return;
+
+ if ((sizeof (png_info)) > png_info_struct_size)
+ {
+ *ptr_ptr = NULL;
+ /* The following line is why this API should not be used: */
+ free(info_ptr);
+ info_ptr = png_voidcast(png_inforp, png_malloc_base(NULL,
+ (sizeof *info_ptr)));
+ *ptr_ptr = info_ptr;
+ }
+
+ /* Set everything to 0 */
+ memset(info_ptr, 0, (sizeof *info_ptr));
+}
+
+/* The following API is not called internally */
+void PNGAPI
+png_data_freer(png_const_structrp png_ptr, png_inforp info_ptr,
+ int freer, png_uint_32 mask)
+{
+ png_debug(1, "in png_data_freer");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (freer == PNG_DESTROY_WILL_FREE_DATA)
+ info_ptr->free_me |= mask;
+
+ else if (freer == PNG_USER_WILL_FREE_DATA)
+ info_ptr->free_me &= ~mask;
+
+ else
+ png_error(png_ptr, "Unknown freer parameter in png_data_freer");
+}
+
+void PNGAPI
+png_free_data(png_const_structrp png_ptr, png_inforp info_ptr, png_uint_32 mask,
+ int num)
+{
+ png_debug(1, "in png_free_data");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+#ifdef PNG_TEXT_SUPPORTED
+ /* Free text item num or (if num == -1) all text items */
+ if ((mask & PNG_FREE_TEXT) & info_ptr->free_me)
+ {
+ if (num != -1)
+ {
+ if (info_ptr->text && info_ptr->text[num].key)
+ {
+ png_free(png_ptr, info_ptr->text[num].key);
+ info_ptr->text[num].key = NULL;
+ }
+ }
+
+ else
+ {
+ int i;
+ for (i = 0; i < info_ptr->num_text; i++)
+ png_free_data(png_ptr, info_ptr, PNG_FREE_TEXT, i);
+ png_free(png_ptr, info_ptr->text);
+ info_ptr->text = NULL;
+ info_ptr->num_text=0;
+ }
+ }
+#endif
+
+#ifdef PNG_tRNS_SUPPORTED
+ /* Free any tRNS entry */
+ if ((mask & PNG_FREE_TRNS) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->trans_alpha);
+ info_ptr->trans_alpha = NULL;
+ info_ptr->valid &= ~PNG_INFO_tRNS;
+ }
+#endif
+
+#ifdef PNG_sCAL_SUPPORTED
+ /* Free any sCAL entry */
+ if ((mask & PNG_FREE_SCAL) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->scal_s_width);
+ png_free(png_ptr, info_ptr->scal_s_height);
+ info_ptr->scal_s_width = NULL;
+ info_ptr->scal_s_height = NULL;
+ info_ptr->valid &= ~PNG_INFO_sCAL;
+ }
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+ /* Free any pCAL entry */
+ if ((mask & PNG_FREE_PCAL) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->pcal_purpose);
+ png_free(png_ptr, info_ptr->pcal_units);
+ info_ptr->pcal_purpose = NULL;
+ info_ptr->pcal_units = NULL;
+ if (info_ptr->pcal_params != NULL)
+ {
+ unsigned int i;
+ for (i = 0; i < info_ptr->pcal_nparams; i++)
+ {
+ png_free(png_ptr, info_ptr->pcal_params[i]);
+ info_ptr->pcal_params[i] = NULL;
+ }
+ png_free(png_ptr, info_ptr->pcal_params);
+ info_ptr->pcal_params = NULL;
+ }
+ info_ptr->valid &= ~PNG_INFO_pCAL;
+ }
+#endif
+
+#ifdef PNG_iCCP_SUPPORTED
+ /* Free any profile entry */
+ if ((mask & PNG_FREE_ICCP) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->iccp_name);
+ png_free(png_ptr, info_ptr->iccp_profile);
+ info_ptr->iccp_name = NULL;
+ info_ptr->iccp_profile = NULL;
+ info_ptr->valid &= ~PNG_INFO_iCCP;
+ }
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+ /* Free a given sPLT entry, or (if num == -1) all sPLT entries */
+ if ((mask & PNG_FREE_SPLT) & info_ptr->free_me)
+ {
+ if (num != -1)
+ {
+ if (info_ptr->splt_palettes)
+ {
+ png_free(png_ptr, info_ptr->splt_palettes[num].name);
+ png_free(png_ptr, info_ptr->splt_palettes[num].entries);
+ info_ptr->splt_palettes[num].name = NULL;
+ info_ptr->splt_palettes[num].entries = NULL;
+ }
+ }
+
+ else
+ {
+ if (info_ptr->splt_palettes_num)
+ {
+ int i;
+ for (i = 0; i < info_ptr->splt_palettes_num; i++)
+ png_free_data(png_ptr, info_ptr, PNG_FREE_SPLT, (int)i);
+
+ png_free(png_ptr, info_ptr->splt_palettes);
+ info_ptr->splt_palettes = NULL;
+ info_ptr->splt_palettes_num = 0;
+ }
+ info_ptr->valid &= ~PNG_INFO_sPLT;
+ }
+ }
+#endif
+
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+ if ((mask & PNG_FREE_UNKN) & info_ptr->free_me)
+ {
+ if (num != -1)
+ {
+ if (info_ptr->unknown_chunks)
+ {
+ png_free(png_ptr, info_ptr->unknown_chunks[num].data);
+ info_ptr->unknown_chunks[num].data = NULL;
+ }
+ }
+
+ else
+ {
+ int i;
+
+ if (info_ptr->unknown_chunks_num)
+ {
+ for (i = 0; i < info_ptr->unknown_chunks_num; i++)
+ png_free_data(png_ptr, info_ptr, PNG_FREE_UNKN, (int)i);
+
+ png_free(png_ptr, info_ptr->unknown_chunks);
+ info_ptr->unknown_chunks = NULL;
+ info_ptr->unknown_chunks_num = 0;
+ }
+ }
+ }
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+ /* Free any hIST entry */
+ if ((mask & PNG_FREE_HIST) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->hist);
+ info_ptr->hist = NULL;
+ info_ptr->valid &= ~PNG_INFO_hIST;
+ }
+#endif
+
+ /* Free any PLTE entry that was internally allocated */
+ if ((mask & PNG_FREE_PLTE) & info_ptr->free_me)
+ {
+ png_free(png_ptr, info_ptr->palette);
+ info_ptr->palette = NULL;
+ info_ptr->valid &= ~PNG_INFO_PLTE;
+ info_ptr->num_palette = 0;
+ }
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+ /* Free any image bits attached to the info structure */
+ if ((mask & PNG_FREE_ROWS) & info_ptr->free_me)
+ {
+ if (info_ptr->row_pointers)
+ {
+ png_uint_32 row;
+ for (row = 0; row < info_ptr->height; row++)
+ {
+ png_free(png_ptr, info_ptr->row_pointers[row]);
+ info_ptr->row_pointers[row] = NULL;
+ }
+ png_free(png_ptr, info_ptr->row_pointers);
+ info_ptr->row_pointers = NULL;
+ }
+ info_ptr->valid &= ~PNG_INFO_IDAT;
+ }
+#endif
+
+ if (num != -1)
+ mask &= ~PNG_FREE_MUL;
+
+ info_ptr->free_me &= ~mask;
+}
+#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */
+
+/* This function returns a pointer to the io_ptr associated with the user
+ * functions. The application should free any memory associated with this
+ * pointer before png_write_destroy() or png_read_destroy() are called.
+ */
+png_voidp PNGAPI
+png_get_io_ptr(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return (NULL);
+
+ return (png_ptr->io_ptr);
+}
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+# ifdef PNG_STDIO_SUPPORTED
+/* Initialize the default input/output functions for the PNG file. If you
+ * use your own read or write routines, you can call either png_set_read_fn()
+ * or png_set_write_fn() instead of png_init_io(). If you have defined
+ * PNG_NO_STDIO or otherwise disabled PNG_STDIO_SUPPORTED, you must use a
+ * function of your own because "FILE *" isn't necessarily available.
+ */
+void PNGAPI
+png_init_io(png_structrp png_ptr, png_FILE_p fp)
+{
+ png_debug(1, "in png_init_io");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->io_ptr = (png_voidp)fp;
+}
+# endif
+
+#ifdef PNG_SAVE_INT_32_SUPPORTED
+/* The png_save_int_32 function assumes integers are stored in two's
+ * complement format. If this isn't the case, then this routine needs to
+ * be modified to write data in two's complement format. Note that,
+ * the following works correctly even if png_int_32 has more than 32 bits
+ * (compare the more complex code required on read for sign extension.)
+ */
+void PNGAPI
+png_save_int_32(png_bytep buf, png_int_32 i)
+{
+ buf[0] = (png_byte)((i >> 24) & 0xff);
+ buf[1] = (png_byte)((i >> 16) & 0xff);
+ buf[2] = (png_byte)((i >> 8) & 0xff);
+ buf[3] = (png_byte)(i & 0xff);
+}
+#endif
+
+# ifdef PNG_TIME_RFC1123_SUPPORTED
+/* Convert the supplied time into an RFC 1123 string suitable for use in
+ * a "Creation Time" or other text-based time string.
+ */
+int PNGAPI
+png_convert_to_rfc1123_buffer(char out[29], png_const_timep ptime)
+{
+ static PNG_CONST char short_months[12][4] =
+ {"Jan", "Feb", "Mar", "Apr", "May", "Jun",
+ "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
+
+ if (out == NULL)
+ return 0;
+
+ if (ptime->year > 9999 /* RFC1123 limitation */ ||
+ ptime->month == 0 || ptime->month > 12 ||
+ ptime->day == 0 || ptime->day > 31 ||
+ ptime->hour > 23 || ptime->minute > 59 ||
+ ptime->second > 60)
+ return 0;
+
+ {
+ size_t pos = 0;
+ char number_buf[5]; /* enough for a four-digit year */
+
+# define APPEND_STRING(string) pos = png_safecat(out, 29, pos, (string))
+# define APPEND_NUMBER(format, value)\
+ APPEND_STRING(PNG_FORMAT_NUMBER(number_buf, format, (value)))
+# define APPEND(ch) if (pos < 28) out[pos++] = (ch)
+
+ APPEND_NUMBER(PNG_NUMBER_FORMAT_u, (unsigned)ptime->day);
+ APPEND(' ');
+ APPEND_STRING(short_months[(ptime->month - 1)]);
+ APPEND(' ');
+ APPEND_NUMBER(PNG_NUMBER_FORMAT_u, ptime->year);
+ APPEND(' ');
+ APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->hour);
+ APPEND(':');
+ APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->minute);
+ APPEND(':');
+ APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->second);
+ APPEND_STRING(" +0000"); /* This reliably terminates the buffer */
+
+# undef APPEND
+# undef APPEND_NUMBER
+# undef APPEND_STRING
+ }
+
+ return 1;
+}
+
+# if PNG_LIBPNG_VER < 10700
+/* To do: remove the following from libpng-1.7 */
+/* Original API that uses a private buffer in png_struct.
+ * Deprecated because it causes png_struct to carry a spurious temporary
+ * buffer (png_struct::time_buffer), better to have the caller pass this in.
+ */
+png_const_charp PNGAPI
+png_convert_to_rfc1123(png_structrp png_ptr, png_const_timep ptime)
+{
+ if (png_ptr != NULL)
+ {
+ /* The only failure above if png_ptr != NULL is from an invalid ptime */
+ if (!png_convert_to_rfc1123_buffer(png_ptr->time_buffer, ptime))
+ png_warning(png_ptr, "Ignoring invalid time value");
+
+ else
+ return png_ptr->time_buffer;
+ }
+
+ return NULL;
+}
+# endif
+# endif /* PNG_TIME_RFC1123_SUPPORTED */
+
+#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */
+
+png_const_charp PNGAPI
+png_get_copyright(png_const_structrp png_ptr)
+{
+ PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */
+#ifdef PNG_STRING_COPYRIGHT
+ return PNG_STRING_COPYRIGHT
+#else
+# ifdef __STDC__
+ return PNG_STRING_NEWLINE \
+ "libpng version 1.6.7 - November 14, 2013" PNG_STRING_NEWLINE \
+ "Copyright (c) 1998-2013 Glenn Randers-Pehrson" PNG_STRING_NEWLINE \
+ "Copyright (c) 1996-1997 Andreas Dilger" PNG_STRING_NEWLINE \
+ "Copyright (c) 1995-1996 Guy Eric Schalnat, Group 42, Inc." \
+ PNG_STRING_NEWLINE;
+# else
+ return "libpng version 1.6.7 - November 14, 2013\
+ Copyright (c) 1998-2013 Glenn Randers-Pehrson\
+ Copyright (c) 1996-1997 Andreas Dilger\
+ Copyright (c) 1995-1996 Guy Eric Schalnat, Group 42, Inc.";
+# endif
+#endif
+}
+
+/* The following return the library version as a short string in the
+ * format 1.0.0 through 99.99.99zz. To get the version of *.h files
+ * used with your application, print out PNG_LIBPNG_VER_STRING, which
+ * is defined in png.h.
+ * Note: now there is no difference between png_get_libpng_ver() and
+ * png_get_header_ver(). Due to the version_nn_nn_nn typedef guard,
+ * it is guaranteed that png.c uses the correct version of png.h.
+ */
+png_const_charp PNGAPI
+png_get_libpng_ver(png_const_structrp png_ptr)
+{
+ /* Version of *.c files used when building libpng */
+ return png_get_header_ver(png_ptr);
+}
+
+png_const_charp PNGAPI
+png_get_header_ver(png_const_structrp png_ptr)
+{
+ /* Version of *.h files used when building libpng */
+ PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */
+ return PNG_LIBPNG_VER_STRING;
+}
+
+png_const_charp PNGAPI
+png_get_header_version(png_const_structrp png_ptr)
+{
+ /* Returns longer string containing both version and date */
+ PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */
+#ifdef __STDC__
+ return PNG_HEADER_VERSION_STRING
+# ifndef PNG_READ_SUPPORTED
+ " (NO READ SUPPORT)"
+# endif
+ PNG_STRING_NEWLINE;
+#else
+ return PNG_HEADER_VERSION_STRING;
+#endif
+}
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+int PNGAPI
+png_handle_as_unknown(png_const_structrp png_ptr, png_const_bytep chunk_name)
+{
+ /* Check chunk_name and return "keep" value if it's on the list, else 0 */
+ png_const_bytep p, p_end;
+
+ if (png_ptr == NULL || chunk_name == NULL || png_ptr->num_chunk_list == 0)
+ return PNG_HANDLE_CHUNK_AS_DEFAULT;
+
+ p_end = png_ptr->chunk_list;
+ p = p_end + png_ptr->num_chunk_list*5; /* beyond end */
+
+ /* The code is the fifth byte after each four byte string. Historically this
+ * code was always searched from the end of the list, this is no longer
+ * necessary because the 'set' routine handles duplicate entries correcty.
+ */
+ do /* num_chunk_list > 0, so at least one */
+ {
+ p -= 5;
+
+ if (!memcmp(chunk_name, p, 4))
+ return p[4];
+ }
+ while (p > p_end);
+
+ /* This means that known chunks should be processed and unknown chunks should
+ * be handled according to the value of png_ptr->unknown_default; this can be
+ * confusing because, as a result, there are two levels of defaulting for
+ * unknown chunks.
+ */
+ return PNG_HANDLE_CHUNK_AS_DEFAULT;
+}
+
+#if defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) ||\
+ defined(PNG_HANDLE_AS_UNKNOWN_SUPPORTED)
+int /* PRIVATE */
+png_chunk_unknown_handling(png_const_structrp png_ptr, png_uint_32 chunk_name)
+{
+ png_byte chunk_string[5];
+
+ PNG_CSTRING_FROM_CHUNK(chunk_string, chunk_name);
+ return png_handle_as_unknown(png_ptr, chunk_string);
+}
+#endif /* READ_UNKNOWN_CHUNKS || HANDLE_AS_UNKNOWN */
+#endif /* SET_UNKNOWN_CHUNKS */
+
+#ifdef PNG_READ_SUPPORTED
+/* This function, added to libpng-1.0.6g, is untested. */
+int PNGAPI
+png_reset_zstream(png_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return Z_STREAM_ERROR;
+
+ /* WARNING: this resets the window bits to the maximum! */
+ return (inflateReset(&png_ptr->zstream));
+}
+#endif /* PNG_READ_SUPPORTED */
+
+/* This function was added to libpng-1.0.7 */
+png_uint_32 PNGAPI
+png_access_version_number(void)
+{
+ /* Version of *.c files used when building libpng */
+ return((png_uint_32)PNG_LIBPNG_VER);
+}
+
+
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+/* Ensure that png_ptr->zstream.msg holds some appropriate error message string.
+ * If it doesn't 'ret' is used to set it to something appropriate, even in cases
+ * like Z_OK or Z_STREAM_END where the error code is apparently a success code.
+ */
+void /* PRIVATE */
+png_zstream_error(png_structrp png_ptr, int ret)
+{
+ /* Translate 'ret' into an appropriate error string, priority is given to the
+ * one in zstream if set. This always returns a string, even in cases like
+ * Z_OK or Z_STREAM_END where the error code is a success code.
+ */
+ if (png_ptr->zstream.msg == NULL) switch (ret)
+ {
+ default:
+ case Z_OK:
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected zlib return code");
+ break;
+
+ case Z_STREAM_END:
+ /* Normal exit */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected end of LZ stream");
+ break;
+
+ case Z_NEED_DICT:
+ /* This means the deflate stream did not have a dictionary; this
+ * indicates a bogus PNG.
+ */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("missing LZ dictionary");
+ break;
+
+ case Z_ERRNO:
+ /* gz APIs only: should not happen */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("zlib IO error");
+ break;
+
+ case Z_STREAM_ERROR:
+ /* internal libpng error */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("bad parameters to zlib");
+ break;
+
+ case Z_DATA_ERROR:
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("damaged LZ stream");
+ break;
+
+ case Z_MEM_ERROR:
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("insufficient memory");
+ break;
+
+ case Z_BUF_ERROR:
+ /* End of input or output; not a problem if the caller is doing
+ * incremental read or write.
+ */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("truncated");
+ break;
+
+ case Z_VERSION_ERROR:
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("unsupported zlib version");
+ break;
+
+ case PNG_UNEXPECTED_ZLIB_RETURN:
+ /* Compile errors here mean that zlib now uses the value co-opted in
+ * pngpriv.h for PNG_UNEXPECTED_ZLIB_RETURN; update the switch above
+ * and change pngpriv.h. Note that this message is "... return",
+ * whereas the default/Z_OK one is "... return code".
+ */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected zlib return");
+ break;
+ }
+}
+
+/* png_convert_size: a PNGAPI but no longer in png.h, so deleted
+ * at libpng 1.5.5!
+ */
+
+/* Added at libpng version 1.2.34 and 1.4.0 (moved from pngset.c) */
+#ifdef PNG_GAMMA_SUPPORTED /* always set if COLORSPACE */
+static int
+png_colorspace_check_gamma(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_fixed_point gAMA, int from)
+ /* This is called to check a new gamma value against an existing one. The
+ * routine returns false if the new gamma value should not be written.
+ *
+ * 'from' says where the new gamma value comes from:
+ *
+ * 0: the new gamma value is the libpng estimate for an ICC profile
+ * 1: the new gamma value comes from a gAMA chunk
+ * 2: the new gamma value comes from an sRGB chunk
+ */
+{
+ png_fixed_point gtest;
+
+ if ((colorspace->flags & PNG_COLORSPACE_HAVE_GAMMA) != 0 &&
+ (!png_muldiv(&gtest, colorspace->gamma, PNG_FP_1, gAMA) ||
+ png_gamma_significant(gtest)))
+ {
+ /* Either this is an sRGB image, in which case the calculated gamma
+ * approximation should match, or this is an image with a profile and the
+ * value libpng calculates for the gamma of the profile does not match the
+ * value recorded in the file. The former, sRGB, case is an error, the
+ * latter is just a warning.
+ */
+ if ((colorspace->flags & PNG_COLORSPACE_FROM_sRGB) != 0 || from == 2)
+ {
+ png_chunk_report(png_ptr, "gamma value does not match sRGB",
+ PNG_CHUNK_ERROR);
+ /* Do not overwrite an sRGB value */
+ return from == 2;
+ }
+
+ else /* sRGB tag not involved */
+ {
+ png_chunk_report(png_ptr, "gamma value does not match libpng estimate",
+ PNG_CHUNK_WARNING);
+ return from == 1;
+ }
+ }
+
+ return 1;
+}
+
+void /* PRIVATE */
+png_colorspace_set_gamma(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_fixed_point gAMA)
+{
+ /* Changed in libpng-1.5.4 to limit the values to ensure overflow can't
+ * occur. Since the fixed point representation is assymetrical it is
+ * possible for 1/gamma to overflow the limit of 21474 and this means the
+ * gamma value must be at least 5/100000 and hence at most 20000.0. For
+ * safety the limits here are a little narrower. The values are 0.00016 to
+ * 6250.0, which are truly ridiculous gamma values (and will produce
+ * displays that are all black or all white.)
+ *
+ * In 1.6.0 this test replaces the ones in pngrutil.c, in the gAMA chunk
+ * handling code, which only required the value to be >0.
+ */
+ png_const_charp errmsg;
+
+ if (gAMA < 16 || gAMA > 625000000)
+ errmsg = "gamma value out of range";
+
+# ifdef PNG_READ_gAMA_SUPPORTED
+ /* Allow the application to set the gamma value more than once */
+ else if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 &&
+ (colorspace->flags & PNG_COLORSPACE_FROM_gAMA) != 0)
+ errmsg = "duplicate";
+# endif
+
+ /* Do nothing if the colorspace is already invalid */
+ else if (colorspace->flags & PNG_COLORSPACE_INVALID)
+ return;
+
+ else
+ {
+ if (png_colorspace_check_gamma(png_ptr, colorspace, gAMA, 1/*from gAMA*/))
+ {
+ /* Store this gamma value. */
+ colorspace->gamma = gAMA;
+ colorspace->flags |=
+ (PNG_COLORSPACE_HAVE_GAMMA | PNG_COLORSPACE_FROM_gAMA);
+ }
+
+ /* At present if the check_gamma test fails the gamma of the colorspace is
+ * not updated however the colorspace is not invalidated. This
+ * corresponds to the case where the existing gamma comes from an sRGB
+ * chunk or profile. An error message has already been output.
+ */
+ return;
+ }
+
+ /* Error exit - errmsg has been set. */
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_chunk_report(png_ptr, errmsg, PNG_CHUNK_WRITE_ERROR);
+}
+
+void /* PRIVATE */
+png_colorspace_sync_info(png_const_structrp png_ptr, png_inforp info_ptr)
+{
+ if (info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID)
+ {
+ /* Everything is invalid */
+ info_ptr->valid &= ~(PNG_INFO_gAMA|PNG_INFO_cHRM|PNG_INFO_sRGB|
+ PNG_INFO_iCCP);
+
+# ifdef PNG_COLORSPACE_SUPPORTED
+ /* Clean up the iCCP profile now if it won't be used. */
+ png_free_data(png_ptr, info_ptr, PNG_FREE_ICCP, -1/*not used*/);
+# else
+ PNG_UNUSED(png_ptr)
+# endif
+ }
+
+ else
+ {
+# ifdef PNG_COLORSPACE_SUPPORTED
+ /* Leave the INFO_iCCP flag set if the pngset.c code has already set
+ * it; this allows a PNG to contain a profile which matches sRGB and
+ * yet still have that profile retrievable by the application.
+ */
+ if (info_ptr->colorspace.flags & PNG_COLORSPACE_MATCHES_sRGB)
+ info_ptr->valid |= PNG_INFO_sRGB;
+
+ else
+ info_ptr->valid &= ~PNG_INFO_sRGB;
+
+ if (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS)
+ info_ptr->valid |= PNG_INFO_cHRM;
+
+ else
+ info_ptr->valid &= ~PNG_INFO_cHRM;
+# endif
+
+ if (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA)
+ info_ptr->valid |= PNG_INFO_gAMA;
+
+ else
+ info_ptr->valid &= ~PNG_INFO_gAMA;
+ }
+}
+
+#ifdef PNG_READ_SUPPORTED
+void /* PRIVATE */
+png_colorspace_sync(png_const_structrp png_ptr, png_inforp info_ptr)
+{
+ if (info_ptr == NULL) /* reduce code size; check here not in the caller */
+ return;
+
+ info_ptr->colorspace = png_ptr->colorspace;
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+#endif
+#endif
+
+#ifdef PNG_COLORSPACE_SUPPORTED
+/* Added at libpng-1.5.5 to support read and write of true CIEXYZ values for
+ * cHRM, as opposed to using chromaticities. These internal APIs return
+ * non-zero on a parameter error. The X, Y and Z values are required to be
+ * positive and less than 1.0.
+ */
+static int
+png_xy_from_XYZ(png_xy *xy, const png_XYZ *XYZ)
+{
+ png_int_32 d, dwhite, whiteX, whiteY;
+
+ d = XYZ->red_X + XYZ->red_Y + XYZ->red_Z;
+ if (!png_muldiv(&xy->redx, XYZ->red_X, PNG_FP_1, d)) return 1;
+ if (!png_muldiv(&xy->redy, XYZ->red_Y, PNG_FP_1, d)) return 1;
+ dwhite = d;
+ whiteX = XYZ->red_X;
+ whiteY = XYZ->red_Y;
+
+ d = XYZ->green_X + XYZ->green_Y + XYZ->green_Z;
+ if (!png_muldiv(&xy->greenx, XYZ->green_X, PNG_FP_1, d)) return 1;
+ if (!png_muldiv(&xy->greeny, XYZ->green_Y, PNG_FP_1, d)) return 1;
+ dwhite += d;
+ whiteX += XYZ->green_X;
+ whiteY += XYZ->green_Y;
+
+ d = XYZ->blue_X + XYZ->blue_Y + XYZ->blue_Z;
+ if (!png_muldiv(&xy->bluex, XYZ->blue_X, PNG_FP_1, d)) return 1;
+ if (!png_muldiv(&xy->bluey, XYZ->blue_Y, PNG_FP_1, d)) return 1;
+ dwhite += d;
+ whiteX += XYZ->blue_X;
+ whiteY += XYZ->blue_Y;
+
+ /* The reference white is simply the sum of the end-point (X,Y,Z) vectors,
+ * thus:
+ */
+ if (!png_muldiv(&xy->whitex, whiteX, PNG_FP_1, dwhite)) return 1;
+ if (!png_muldiv(&xy->whitey, whiteY, PNG_FP_1, dwhite)) return 1;
+
+ return 0;
+}
+
+static int
+png_XYZ_from_xy(png_XYZ *XYZ, const png_xy *xy)
+{
+ png_fixed_point red_inverse, green_inverse, blue_scale;
+ png_fixed_point left, right, denominator;
+
+ /* Check xy and, implicitly, z. Note that wide gamut color spaces typically
+ * have end points with 0 tristimulus values (these are impossible end
+ * points, but they are used to cover the possible colors.)
+ */
+ if (xy->redx < 0 || xy->redx > PNG_FP_1) return 1;
+ if (xy->redy < 0 || xy->redy > PNG_FP_1-xy->redx) return 1;
+ if (xy->greenx < 0 || xy->greenx > PNG_FP_1) return 1;
+ if (xy->greeny < 0 || xy->greeny > PNG_FP_1-xy->greenx) return 1;
+ if (xy->bluex < 0 || xy->bluex > PNG_FP_1) return 1;
+ if (xy->bluey < 0 || xy->bluey > PNG_FP_1-xy->bluex) return 1;
+ if (xy->whitex < 0 || xy->whitex > PNG_FP_1) return 1;
+ if (xy->whitey < 0 || xy->whitey > PNG_FP_1-xy->whitex) return 1;
+
+ /* The reverse calculation is more difficult because the original tristimulus
+ * value had 9 independent values (red,green,blue)x(X,Y,Z) however only 8
+ * derived values were recorded in the cHRM chunk;
+ * (red,green,blue,white)x(x,y). This loses one degree of freedom and
+ * therefore an arbitrary ninth value has to be introduced to undo the
+ * original transformations.
+ *
+ * Think of the original end-points as points in (X,Y,Z) space. The
+ * chromaticity values (c) have the property:
+ *
+ * C
+ * c = ---------
+ * X + Y + Z
+ *
+ * For each c (x,y,z) from the corresponding original C (X,Y,Z). Thus the
+ * three chromaticity values (x,y,z) for each end-point obey the
+ * relationship:
+ *
+ * x + y + z = 1
+ *
+ * This describes the plane in (X,Y,Z) space that intersects each axis at the
+ * value 1.0; call this the chromaticity plane. Thus the chromaticity
+ * calculation has scaled each end-point so that it is on the x+y+z=1 plane
+ * and chromaticity is the intersection of the vector from the origin to the
+ * (X,Y,Z) value with the chromaticity plane.
+ *
+ * To fully invert the chromaticity calculation we would need the three
+ * end-point scale factors, (red-scale, green-scale, blue-scale), but these
+ * were not recorded. Instead we calculated the reference white (X,Y,Z) and
+ * recorded the chromaticity of this. The reference white (X,Y,Z) would have
+ * given all three of the scale factors since:
+ *
+ * color-C = color-c * color-scale
+ * white-C = red-C + green-C + blue-C
+ * = red-c*red-scale + green-c*green-scale + blue-c*blue-scale
+ *
+ * But cHRM records only white-x and white-y, so we have lost the white scale
+ * factor:
+ *
+ * white-C = white-c*white-scale
+ *
+ * To handle this the inverse transformation makes an arbitrary assumption
+ * about white-scale:
+ *
+ * Assume: white-Y = 1.0
+ * Hence: white-scale = 1/white-y
+ * Or: red-Y + green-Y + blue-Y = 1.0
+ *
+ * Notice the last statement of the assumption gives an equation in three of
+ * the nine values we want to calculate. 8 more equations come from the
+ * above routine as summarised at the top above (the chromaticity
+ * calculation):
+ *
+ * Given: color-x = color-X / (color-X + color-Y + color-Z)
+ * Hence: (color-x - 1)*color-X + color.x*color-Y + color.x*color-Z = 0
+ *
+ * This is 9 simultaneous equations in the 9 variables "color-C" and can be
+ * solved by Cramer's rule. Cramer's rule requires calculating 10 9x9 matrix
+ * determinants, however this is not as bad as it seems because only 28 of
+ * the total of 90 terms in the various matrices are non-zero. Nevertheless
+ * Cramer's rule is notoriously numerically unstable because the determinant
+ * calculation involves the difference of large, but similar, numbers. It is
+ * difficult to be sure that the calculation is stable for real world values
+ * and it is certain that it becomes unstable where the end points are close
+ * together.
+ *
+ * So this code uses the perhaps slightly less optimal but more
+ * understandable and totally obvious approach of calculating color-scale.
+ *
+ * This algorithm depends on the precision in white-scale and that is
+ * (1/white-y), so we can immediately see that as white-y approaches 0 the
+ * accuracy inherent in the cHRM chunk drops off substantially.
+ *
+ * libpng arithmetic: a simple invertion of the above equations
+ * ------------------------------------------------------------
+ *
+ * white_scale = 1/white-y
+ * white-X = white-x * white-scale
+ * white-Y = 1.0
+ * white-Z = (1 - white-x - white-y) * white_scale
+ *
+ * white-C = red-C + green-C + blue-C
+ * = red-c*red-scale + green-c*green-scale + blue-c*blue-scale
+ *
+ * This gives us three equations in (red-scale,green-scale,blue-scale) where
+ * all the coefficients are now known:
+ *
+ * red-x*red-scale + green-x*green-scale + blue-x*blue-scale
+ * = white-x/white-y
+ * red-y*red-scale + green-y*green-scale + blue-y*blue-scale = 1
+ * red-z*red-scale + green-z*green-scale + blue-z*blue-scale
+ * = (1 - white-x - white-y)/white-y
+ *
+ * In the last equation color-z is (1 - color-x - color-y) so we can add all
+ * three equations together to get an alternative third:
+ *
+ * red-scale + green-scale + blue-scale = 1/white-y = white-scale
+ *
+ * So now we have a Cramer's rule solution where the determinants are just
+ * 3x3 - far more tractible. Unfortunately 3x3 determinants still involve
+ * multiplication of three coefficients so we can't guarantee to avoid
+ * overflow in the libpng fixed point representation. Using Cramer's rule in
+ * floating point is probably a good choice here, but it's not an option for
+ * fixed point. Instead proceed to simplify the first two equations by
+ * eliminating what is likely to be the largest value, blue-scale:
+ *
+ * blue-scale = white-scale - red-scale - green-scale
+ *
+ * Hence:
+ *
+ * (red-x - blue-x)*red-scale + (green-x - blue-x)*green-scale =
+ * (white-x - blue-x)*white-scale
+ *
+ * (red-y - blue-y)*red-scale + (green-y - blue-y)*green-scale =
+ * 1 - blue-y*white-scale
+ *
+ * And now we can trivially solve for (red-scale,green-scale):
+ *
+ * green-scale =
+ * (white-x - blue-x)*white-scale - (red-x - blue-x)*red-scale
+ * -----------------------------------------------------------
+ * green-x - blue-x
+ *
+ * red-scale =
+ * 1 - blue-y*white-scale - (green-y - blue-y) * green-scale
+ * ---------------------------------------------------------
+ * red-y - blue-y
+ *
+ * Hence:
+ *
+ * red-scale =
+ * ( (green-x - blue-x) * (white-y - blue-y) -
+ * (green-y - blue-y) * (white-x - blue-x) ) / white-y
+ * -------------------------------------------------------------------------
+ * (green-x - blue-x)*(red-y - blue-y)-(green-y - blue-y)*(red-x - blue-x)
+ *
+ * green-scale =
+ * ( (red-y - blue-y) * (white-x - blue-x) -
+ * (red-x - blue-x) * (white-y - blue-y) ) / white-y
+ * -------------------------------------------------------------------------
+ * (green-x - blue-x)*(red-y - blue-y)-(green-y - blue-y)*(red-x - blue-x)
+ *
+ * Accuracy:
+ * The input values have 5 decimal digits of accuracy. The values are all in
+ * the range 0 < value < 1, so simple products are in the same range but may
+ * need up to 10 decimal digits to preserve the original precision and avoid
+ * underflow. Because we are using a 32-bit signed representation we cannot
+ * match this; the best is a little over 9 decimal digits, less than 10.
+ *
+ * The approach used here is to preserve the maximum precision within the
+ * signed representation. Because the red-scale calculation above uses the
+ * difference between two products of values that must be in the range -1..+1
+ * it is sufficient to divide the product by 7; ceil(100,000/32767*2). The
+ * factor is irrelevant in the calculation because it is applied to both
+ * numerator and denominator.
+ *
+ * Note that the values of the differences of the products of the
+ * chromaticities in the above equations tend to be small, for example for
+ * the sRGB chromaticities they are:
+ *
+ * red numerator: -0.04751
+ * green numerator: -0.08788
+ * denominator: -0.2241 (without white-y multiplication)
+ *
+ * The resultant Y coefficients from the chromaticities of some widely used
+ * color space definitions are (to 15 decimal places):
+ *
+ * sRGB
+ * 0.212639005871510 0.715168678767756 0.072192315360734
+ * Kodak ProPhoto
+ * 0.288071128229293 0.711843217810102 0.000085653960605
+ * Adobe RGB
+ * 0.297344975250536 0.627363566255466 0.075291458493998
+ * Adobe Wide Gamut RGB
+ * 0.258728243040113 0.724682314948566 0.016589442011321
+ */
+ /* By the argument, above overflow should be impossible here. The return
+ * value of 2 indicates an internal error to the caller.
+ */
+ if (!png_muldiv(&left, xy->greenx-xy->bluex, xy->redy - xy->bluey, 7))
+ return 2;
+ if (!png_muldiv(&right, xy->greeny-xy->bluey, xy->redx - xy->bluex, 7))
+ return 2;
+ denominator = left - right;
+
+ /* Now find the red numerator. */
+ if (!png_muldiv(&left, xy->greenx-xy->bluex, xy->whitey-xy->bluey, 7))
+ return 2;
+ if (!png_muldiv(&right, xy->greeny-xy->bluey, xy->whitex-xy->bluex, 7))
+ return 2;
+
+ /* Overflow is possible here and it indicates an extreme set of PNG cHRM
+ * chunk values. This calculation actually returns the reciprocal of the
+ * scale value because this allows us to delay the multiplication of white-y
+ * into the denominator, which tends to produce a small number.
+ */
+ if (!png_muldiv(&red_inverse, xy->whitey, denominator, left-right) ||
+ red_inverse <= xy->whitey /* r+g+b scales = white scale */)
+ return 1;
+
+ /* Similarly for green_inverse: */
+ if (!png_muldiv(&left, xy->redy-xy->bluey, xy->whitex-xy->bluex, 7))
+ return 2;
+ if (!png_muldiv(&right, xy->redx-xy->bluex, xy->whitey-xy->bluey, 7))
+ return 2;
+ if (!png_muldiv(&green_inverse, xy->whitey, denominator, left-right) ||
+ green_inverse <= xy->whitey)
+ return 1;
+
+ /* And the blue scale, the checks above guarantee this can't overflow but it
+ * can still produce 0 for extreme cHRM values.
+ */
+ blue_scale = png_reciprocal(xy->whitey) - png_reciprocal(red_inverse) -
+ png_reciprocal(green_inverse);
+ if (blue_scale <= 0) return 1;
+
+
+ /* And fill in the png_XYZ: */
+ if (!png_muldiv(&XYZ->red_X, xy->redx, PNG_FP_1, red_inverse)) return 1;
+ if (!png_muldiv(&XYZ->red_Y, xy->redy, PNG_FP_1, red_inverse)) return 1;
+ if (!png_muldiv(&XYZ->red_Z, PNG_FP_1 - xy->redx - xy->redy, PNG_FP_1,
+ red_inverse))
+ return 1;
+
+ if (!png_muldiv(&XYZ->green_X, xy->greenx, PNG_FP_1, green_inverse))
+ return 1;
+ if (!png_muldiv(&XYZ->green_Y, xy->greeny, PNG_FP_1, green_inverse))
+ return 1;
+ if (!png_muldiv(&XYZ->green_Z, PNG_FP_1 - xy->greenx - xy->greeny, PNG_FP_1,
+ green_inverse))
+ return 1;
+
+ if (!png_muldiv(&XYZ->blue_X, xy->bluex, blue_scale, PNG_FP_1)) return 1;
+ if (!png_muldiv(&XYZ->blue_Y, xy->bluey, blue_scale, PNG_FP_1)) return 1;
+ if (!png_muldiv(&XYZ->blue_Z, PNG_FP_1 - xy->bluex - xy->bluey, blue_scale,
+ PNG_FP_1))
+ return 1;
+
+ return 0; /*success*/
+}
+
+static int
+png_XYZ_normalize(png_XYZ *XYZ)
+{
+ png_int_32 Y;
+
+ if (XYZ->red_Y < 0 || XYZ->green_Y < 0 || XYZ->blue_Y < 0 ||
+ XYZ->red_X < 0 || XYZ->green_X < 0 || XYZ->blue_X < 0 ||
+ XYZ->red_Z < 0 || XYZ->green_Z < 0 || XYZ->blue_Z < 0)
+ return 1;
+
+ /* Normalize by scaling so the sum of the end-point Y values is PNG_FP_1.
+ * IMPLEMENTATION NOTE: ANSI requires signed overflow not to occur, therefore
+ * relying on addition of two positive values producing a negative one is not
+ * safe.
+ */
+ Y = XYZ->red_Y;
+ if (0x7fffffff - Y < XYZ->green_X) return 1;
+ Y += XYZ->green_Y;
+ if (0x7fffffff - Y < XYZ->blue_X) return 1;
+ Y += XYZ->blue_Y;
+
+ if (Y != PNG_FP_1)
+ {
+ if (!png_muldiv(&XYZ->red_X, XYZ->red_X, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->red_Y, XYZ->red_Y, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->red_Z, XYZ->red_Z, PNG_FP_1, Y)) return 1;
+
+ if (!png_muldiv(&XYZ->green_X, XYZ->green_X, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->green_Y, XYZ->green_Y, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->green_Z, XYZ->green_Z, PNG_FP_1, Y)) return 1;
+
+ if (!png_muldiv(&XYZ->blue_X, XYZ->blue_X, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->blue_Y, XYZ->blue_Y, PNG_FP_1, Y)) return 1;
+ if (!png_muldiv(&XYZ->blue_Z, XYZ->blue_Z, PNG_FP_1, Y)) return 1;
+ }
+
+ return 0;
+}
+
+static int
+png_colorspace_endpoints_match(const png_xy *xy1, const png_xy *xy2, int delta)
+{
+ /* Allow an error of +/-0.01 (absolute value) on each chromaticity */
+ return !(PNG_OUT_OF_RANGE(xy1->whitex, xy2->whitex,delta) ||
+ PNG_OUT_OF_RANGE(xy1->whitey, xy2->whitey,delta) ||
+ PNG_OUT_OF_RANGE(xy1->redx, xy2->redx, delta) ||
+ PNG_OUT_OF_RANGE(xy1->redy, xy2->redy, delta) ||
+ PNG_OUT_OF_RANGE(xy1->greenx, xy2->greenx,delta) ||
+ PNG_OUT_OF_RANGE(xy1->greeny, xy2->greeny,delta) ||
+ PNG_OUT_OF_RANGE(xy1->bluex, xy2->bluex, delta) ||
+ PNG_OUT_OF_RANGE(xy1->bluey, xy2->bluey, delta));
+}
+
+/* Added in libpng-1.6.0, a different check for the validity of a set of cHRM
+ * chunk chromaticities. Earlier checks used to simply look for the overflow
+ * condition (where the determinant of the matrix to solve for XYZ ends up zero
+ * because the chromaticity values are not all distinct.) Despite this it is
+ * theoretically possible to produce chromaticities that are apparently valid
+ * but that rapidly degrade to invalid, potentially crashing, sets because of
+ * arithmetic inaccuracies when calculations are performed on them. The new
+ * check is to round-trip xy -> XYZ -> xy and then check that the result is
+ * within a small percentage of the original.
+ */
+static int
+png_colorspace_check_xy(png_XYZ *XYZ, const png_xy *xy)
+{
+ int result;
+ png_xy xy_test;
+
+ /* As a side-effect this routine also returns the XYZ endpoints. */
+ result = png_XYZ_from_xy(XYZ, xy);
+ if (result) return result;
+
+ result = png_xy_from_XYZ(&xy_test, XYZ);
+ if (result) return result;
+
+ if (png_colorspace_endpoints_match(xy, &xy_test,
+ 5/*actually, the math is pretty accurate*/))
+ return 0;
+
+ /* Too much slip */
+ return 1;
+}
+
+/* This is the check going the other way. The XYZ is modified to normalize it
+ * (another side-effect) and the xy chromaticities are returned.
+ */
+static int
+png_colorspace_check_XYZ(png_xy *xy, png_XYZ *XYZ)
+{
+ int result;
+ png_XYZ XYZtemp;
+
+ result = png_XYZ_normalize(XYZ);
+ if (result) return result;
+
+ result = png_xy_from_XYZ(xy, XYZ);
+ if (result) return result;
+
+ XYZtemp = *XYZ;
+ return png_colorspace_check_xy(&XYZtemp, xy);
+}
+
+/* Used to check for an endpoint match against sRGB */
+static const png_xy sRGB_xy = /* From ITU-R BT.709-3 */
+{
+ /* color x y */
+ /* red */ 64000, 33000,
+ /* green */ 30000, 60000,
+ /* blue */ 15000, 6000,
+ /* white */ 31270, 32900
+};
+
+static int
+png_colorspace_set_xy_and_XYZ(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, const png_xy *xy, const png_XYZ *XYZ,
+ int preferred)
+{
+ if (colorspace->flags & PNG_COLORSPACE_INVALID)
+ return 0;
+
+ /* The consistency check is performed on the chromaticities; this factors out
+ * variations because of the normalization (or not) of the end point Y
+ * values.
+ */
+ if (preferred < 2 && (colorspace->flags & PNG_COLORSPACE_HAVE_ENDPOINTS))
+ {
+ /* The end points must be reasonably close to any we already have. The
+ * following allows an error of up to +/-.001
+ */
+ if (!png_colorspace_endpoints_match(xy, &colorspace->end_points_xy, 100))
+ {
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_benign_error(png_ptr, "inconsistent chromaticities");
+ return 0; /* failed */
+ }
+
+ /* Only overwrite with preferred values */
+ if (!preferred)
+ return 1; /* ok, but no change */
+ }
+
+ colorspace->end_points_xy = *xy;
+ colorspace->end_points_XYZ = *XYZ;
+ colorspace->flags |= PNG_COLORSPACE_HAVE_ENDPOINTS;
+
+ /* The end points are normally quoted to two decimal digits, so allow +/-0.01
+ * on this test.
+ */
+ if (png_colorspace_endpoints_match(xy, &sRGB_xy, 1000))
+ colorspace->flags |= PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB;
+
+ else
+ colorspace->flags &= PNG_COLORSPACE_CANCEL(
+ PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB);
+
+ return 2; /* ok and changed */
+}
+
+int /* PRIVATE */
+png_colorspace_set_chromaticities(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, const png_xy *xy, int preferred)
+{
+ /* We must check the end points to ensure they are reasonable - in the past
+ * color management systems have crashed as a result of getting bogus
+ * colorant values, while this isn't the fault of libpng it is the
+ * responsibility of libpng because PNG carries the bomb and libpng is in a
+ * position to protect against it.
+ */
+ png_XYZ XYZ;
+
+ switch (png_colorspace_check_xy(&XYZ, xy))
+ {
+ case 0: /* success */
+ return png_colorspace_set_xy_and_XYZ(png_ptr, colorspace, xy, &XYZ,
+ preferred);
+
+ case 1:
+ /* We can't invert the chromaticities so we can't produce value XYZ
+ * values. Likely as not a color management system will fail too.
+ */
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_benign_error(png_ptr, "invalid chromaticities");
+ break;
+
+ default:
+ /* libpng is broken; this should be a warning but if it happens we
+ * want error reports so for the moment it is an error.
+ */
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_error(png_ptr, "internal error checking chromaticities");
+ break;
+ }
+
+ return 0; /* failed */
+}
+
+int /* PRIVATE */
+png_colorspace_set_endpoints(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, const png_XYZ *XYZ_in, int preferred)
+{
+ png_XYZ XYZ = *XYZ_in;
+ png_xy xy;
+
+ switch (png_colorspace_check_XYZ(&xy, &XYZ))
+ {
+ case 0:
+ return png_colorspace_set_xy_and_XYZ(png_ptr, colorspace, &xy, &XYZ,
+ preferred);
+
+ case 1:
+ /* End points are invalid. */
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_benign_error(png_ptr, "invalid end points");
+ break;
+
+ default:
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+ png_error(png_ptr, "internal error checking chromaticities");
+ break;
+ }
+
+ return 0; /* failed */
+}
+
+#if defined(PNG_sRGB_SUPPORTED) || defined(PNG_iCCP_SUPPORTED)
+/* Error message generation */
+static char
+png_icc_tag_char(png_uint_32 byte)
+{
+ byte &= 0xff;
+ if (byte >= 32 && byte <= 126)
+ return (char)byte;
+ else
+ return '?';
+}
+
+static void
+png_icc_tag_name(char *name, png_uint_32 tag)
+{
+ name[0] = '\'';
+ name[1] = png_icc_tag_char(tag >> 24);
+ name[2] = png_icc_tag_char(tag >> 16);
+ name[3] = png_icc_tag_char(tag >> 8);
+ name[4] = png_icc_tag_char(tag );
+ name[5] = '\'';
+}
+
+static int
+is_ICC_signature_char(png_alloc_size_t it)
+{
+ return it == 32 || (it >= 48 && it <= 57) || (it >= 65 && it <= 90) ||
+ (it >= 97 && it <= 122);
+}
+
+static int is_ICC_signature(png_alloc_size_t it)
+{
+ return is_ICC_signature_char(it >> 24) /* checks all the top bits */ &&
+ is_ICC_signature_char((it >> 16) & 0xff) &&
+ is_ICC_signature_char((it >> 8) & 0xff) &&
+ is_ICC_signature_char(it & 0xff);
+}
+
+static int
+png_icc_profile_error(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_charp name, png_alloc_size_t value, png_const_charp reason)
+{
+ size_t pos;
+ char message[196]; /* see below for calculation */
+
+ if (colorspace != NULL)
+ colorspace->flags |= PNG_COLORSPACE_INVALID;
+
+ pos = png_safecat(message, (sizeof message), 0, "profile '"); /* 9 chars */
+ pos = png_safecat(message, pos+79, pos, name); /* Truncate to 79 chars */
+ pos = png_safecat(message, (sizeof message), pos, "': "); /* +2 = 90 */
+ if (is_ICC_signature(value))
+ {
+ /* So 'value' is at most 4 bytes and the following cast is safe */
+ png_icc_tag_name(message+pos, (png_uint_32)value);
+ pos += 6; /* total +8; less than the else clause */
+ message[pos++] = ':';
+ message[pos++] = ' ';
+ }
+# ifdef PNG_WARNINGS_SUPPORTED
+ else
+ {
+ char number[PNG_NUMBER_BUFFER_SIZE]; /* +24 = 114*/
+
+ pos = png_safecat(message, (sizeof message), pos,
+ png_format_number(number, number+(sizeof number),
+ PNG_NUMBER_FORMAT_x, value));
+ pos = png_safecat(message, (sizeof message), pos, "h: "); /*+2 = 116*/
+ }
+# endif
+ /* The 'reason' is an arbitrary message, allow +79 maximum 195 */
+ pos = png_safecat(message, (sizeof message), pos, reason);
+
+ /* This is recoverable, but make it unconditionally an app_error on write to
+ * avoid writing invalid ICC profiles into PNG files. (I.e. we handle them
+ * on read, with a warning, but on write unless the app turns off
+ * application errors the PNG won't be written.)
+ */
+ png_chunk_report(png_ptr, message,
+ (colorspace != NULL) ? PNG_CHUNK_ERROR : PNG_CHUNK_WRITE_ERROR);
+
+ return 0;
+}
+#endif /* sRGB || iCCP */
+
+#ifdef PNG_sRGB_SUPPORTED
+int /* PRIVATE */
+png_colorspace_set_sRGB(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ int intent)
+{
+ /* sRGB sets known gamma, end points and (from the chunk) intent. */
+ /* IMPORTANT: these are not necessarily the values found in an ICC profile
+ * because ICC profiles store values adapted to a D50 environment; it is
+ * expected that the ICC profile mediaWhitePointTag will be D50, see the
+ * checks and code elsewhere to understand this better.
+ *
+ * These XYZ values, which are accurate to 5dp, produce rgb to gray
+ * coefficients of (6968,23435,2366), which are reduced (because they add up
+ * to 32769 not 32768) to (6968,23434,2366). These are the values that
+ * libpng has traditionally used (and are the best values given the 15bit
+ * algorithm used by the rgb to gray code.)
+ */
+ static const png_XYZ sRGB_XYZ = /* D65 XYZ (*not* the D50 adapted values!) */
+ {
+ /* color X Y Z */
+ /* red */ 41239, 21264, 1933,
+ /* green */ 35758, 71517, 11919,
+ /* blue */ 18048, 7219, 95053
+ };
+
+ /* Do nothing if the colorspace is already invalidated. */
+ if (colorspace->flags & PNG_COLORSPACE_INVALID)
+ return 0;
+
+ /* Check the intent, then check for existing settings. It is valid for the
+ * PNG file to have cHRM or gAMA chunks along with sRGB, but the values must
+ * be consistent with the correct values. If, however, this function is
+ * called below because an iCCP chunk matches sRGB then it is quite
+ * conceivable that an older app recorded incorrect gAMA and cHRM because of
+ * an incorrect calculation based on the values in the profile - this does
+ * *not* invalidate the profile (though it still produces an error, which can
+ * be ignored.)
+ */
+ if (intent < 0 || intent >= PNG_sRGB_INTENT_LAST)
+ return png_icc_profile_error(png_ptr, colorspace, "sRGB",
+ (unsigned)intent, "invalid sRGB rendering intent");
+
+ if ((colorspace->flags & PNG_COLORSPACE_HAVE_INTENT) != 0 &&
+ colorspace->rendering_intent != intent)
+ return png_icc_profile_error(png_ptr, colorspace, "sRGB",
+ (unsigned)intent, "inconsistent rendering intents");
+
+ if ((colorspace->flags & PNG_COLORSPACE_FROM_sRGB) != 0)
+ {
+ png_benign_error(png_ptr, "duplicate sRGB information ignored");
+ return 0;
+ }
+
+ /* If the standard sRGB cHRM chunk does not match the one from the PNG file
+ * warn but overwrite the value with the correct one.
+ */
+ if ((colorspace->flags & PNG_COLORSPACE_HAVE_ENDPOINTS) != 0 &&
+ !png_colorspace_endpoints_match(&sRGB_xy, &colorspace->end_points_xy,
+ 100))
+ png_chunk_report(png_ptr, "cHRM chunk does not match sRGB",
+ PNG_CHUNK_ERROR);
+
+ /* This check is just done for the error reporting - the routine always
+ * returns true when the 'from' argument corresponds to sRGB (2).
+ */
+ (void)png_colorspace_check_gamma(png_ptr, colorspace, PNG_GAMMA_sRGB_INVERSE,
+ 2/*from sRGB*/);
+
+ /* intent: bugs in GCC force 'int' to be used as the parameter type. */
+ colorspace->rendering_intent = (png_uint_16)intent;
+ colorspace->flags |= PNG_COLORSPACE_HAVE_INTENT;
+
+ /* endpoints */
+ colorspace->end_points_xy = sRGB_xy;
+ colorspace->end_points_XYZ = sRGB_XYZ;
+ colorspace->flags |=
+ (PNG_COLORSPACE_HAVE_ENDPOINTS|PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB);
+
+ /* gamma */
+ colorspace->gamma = PNG_GAMMA_sRGB_INVERSE;
+ colorspace->flags |= PNG_COLORSPACE_HAVE_GAMMA;
+
+ /* Finally record that we have an sRGB profile */
+ colorspace->flags |=
+ (PNG_COLORSPACE_MATCHES_sRGB|PNG_COLORSPACE_FROM_sRGB);
+
+ return 1; /* set */
+}
+#endif /* sRGB */
+
+#ifdef PNG_iCCP_SUPPORTED
+/* Encoded value of D50 as an ICC XYZNumber. From the ICC 2010 spec the value
+ * is XYZ(0.9642,1.0,0.8249), which scales to:
+ *
+ * (63189.8112, 65536, 54060.6464)
+ */
+static const png_byte D50_nCIEXYZ[12] =
+ { 0x00, 0x00, 0xf6, 0xd6, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d };
+
+int /* PRIVATE */
+png_icc_check_length(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_charp name, png_uint_32 profile_length)
+{
+ if (profile_length < 132)
+ return png_icc_profile_error(png_ptr, colorspace, name, profile_length,
+ "too short");
+
+ if (profile_length & 3)
+ return png_icc_profile_error(png_ptr, colorspace, name, profile_length,
+ "invalid length");
+
+ return 1;
+}
+
+int /* PRIVATE */
+png_icc_check_header(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_charp name, png_uint_32 profile_length,
+ png_const_bytep profile/* first 132 bytes only */, int color_type)
+{
+ png_uint_32 temp;
+
+ /* Length check; this cannot be ignored in this code because profile_length
+ * is used later to check the tag table, so even if the profile seems over
+ * long profile_length from the caller must be correct. The caller can fix
+ * this up on read or write by just passing in the profile header length.
+ */
+ temp = png_get_uint_32(profile);
+ if (temp != profile_length)
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "length does not match profile");
+
+ temp = png_get_uint_32(profile+128); /* tag count: 12 bytes/tag */
+ if (temp > 357913930 || /* (2^32-4-132)/12: maximum possible tag count */
+ profile_length < 132+12*temp) /* truncated tag table */
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "tag count too large");
+
+ /* The 'intent' must be valid or we can't store it, ICC limits the intent to
+ * 16 bits.
+ */
+ temp = png_get_uint_32(profile+64);
+ if (temp >= 0xffff) /* The ICC limit */
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "invalid rendering intent");
+
+ /* This is just a warning because the profile may be valid in future
+ * versions.
+ */
+ if (temp >= PNG_sRGB_INTENT_LAST)
+ (void)png_icc_profile_error(png_ptr, NULL, name, temp,
+ "intent outside defined range");
+
+ /* At this point the tag table can't be checked because it hasn't necessarily
+ * been loaded; however, various header fields can be checked. These checks
+ * are for values permitted by the PNG spec in an ICC profile; the PNG spec
+ * restricts the profiles that can be passed in an iCCP chunk (they must be
+ * appropriate to processing PNG data!)
+ */
+
+ /* Data checks (could be skipped). These checks must be independent of the
+ * version number; however, the version number doesn't accomodate changes in
+ * the header fields (just the known tags and the interpretation of the
+ * data.)
+ */
+ temp = png_get_uint_32(profile+36); /* signature 'ascp' */
+ if (temp != 0x61637370)
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "invalid signature");
+
+ /* Currently the PCS illuminant/adopted white point (the computational
+ * white point) are required to be D50,
+ * however the profile contains a record of the illuminant so perhaps ICC
+ * expects to be able to change this in the future (despite the rationale in
+ * the introduction for using a fixed PCS adopted white.) Consequently the
+ * following is just a warning.
+ */
+ if (memcmp(profile+68, D50_nCIEXYZ, 12) != 0)
+ (void)png_icc_profile_error(png_ptr, NULL, name, 0/*no tag value*/,
+ "PCS illuminant is not D50");
+
+ /* The PNG spec requires this:
+ * "If the iCCP chunk is present, the image samples conform to the colour
+ * space represented by the embedded ICC profile as defined by the
+ * International Color Consortium [ICC]. The colour space of the ICC profile
+ * shall be an RGB colour space for colour images (PNG colour types 2, 3, and
+ * 6), or a greyscale colour space for greyscale images (PNG colour types 0
+ * and 4)."
+ *
+ * This checking code ensures the embedded profile (on either read or write)
+ * conforms to the specification requirements. Notice that an ICC 'gray'
+ * color-space profile contains the information to transform the monochrome
+ * data to XYZ or L*a*b (according to which PCS the profile uses) and this
+ * should be used in preference to the standard libpng K channel replication
+ * into R, G and B channels.
+ *
+ * Previously it was suggested that an RGB profile on grayscale data could be
+ * handled. However it it is clear that using an RGB profile in this context
+ * must be an error - there is no specification of what it means. Thus it is
+ * almost certainly more correct to ignore the profile.
+ */
+ temp = png_get_uint_32(profile+16); /* data colour space field */
+ switch (temp)
+ {
+ case 0x52474220: /* 'RGB ' */
+ if (!(color_type & PNG_COLOR_MASK_COLOR))
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "RGB color space not permitted on grayscale PNG");
+ break;
+
+ case 0x47524159: /* 'GRAY' */
+ if (color_type & PNG_COLOR_MASK_COLOR)
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "Gray color space not permitted on RGB PNG");
+ break;
+
+ default:
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "invalid ICC profile color space");
+ }
+
+ /* It is up to the application to check that the profile class matches the
+ * application requirements; the spec provides no guidance, but it's pretty
+ * weird if the profile is not scanner ('scnr'), monitor ('mntr'), printer
+ * ('prtr') or 'spac' (for generic color spaces). Issue a warning in these
+ * cases. Issue an error for device link or abstract profiles - these don't
+ * contain the records necessary to transform the color-space to anything
+ * other than the target device (and not even that for an abstract profile).
+ * Profiles of these classes may not be embedded in images.
+ */
+ temp = png_get_uint_32(profile+12); /* profile/device class */
+ switch (temp)
+ {
+ case 0x73636E72: /* 'scnr' */
+ case 0x6D6E7472: /* 'mntr' */
+ case 0x70727472: /* 'prtr' */
+ case 0x73706163: /* 'spac' */
+ /* All supported */
+ break;
+
+ case 0x61627374: /* 'abst' */
+ /* May not be embedded in an image */
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "invalid embedded Abstract ICC profile");
+
+ case 0x6C696E6B: /* 'link' */
+ /* DeviceLink profiles cannnot be interpreted in a non-device specific
+ * fashion, if an app uses the AToB0Tag in the profile the results are
+ * undefined unless the result is sent to the intended device,
+ * therefore a DeviceLink profile should not be found embedded in a
+ * PNG.
+ */
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "unexpected DeviceLink ICC profile class");
+
+ case 0x6E6D636C: /* 'nmcl' */
+ /* A NamedColor profile is also device specific, however it doesn't
+ * contain an AToB0 tag that is open to misintrepretation. Almost
+ * certainly it will fail the tests below.
+ */
+ (void)png_icc_profile_error(png_ptr, NULL, name, temp,
+ "unexpected NamedColor ICC profile class");
+ break;
+
+ default:
+ /* To allow for future enhancements to the profile accept unrecognized
+ * profile classes with a warning, these then hit the test below on the
+ * tag content to ensure they are backward compatible with one of the
+ * understood profiles.
+ */
+ (void)png_icc_profile_error(png_ptr, NULL, name, temp,
+ "unrecognized ICC profile class");
+ break;
+ }
+
+ /* For any profile other than a device link one the PCS must be encoded
+ * either in XYZ or Lab.
+ */
+ temp = png_get_uint_32(profile+20);
+ switch (temp)
+ {
+ case 0x58595A20: /* 'XYZ ' */
+ case 0x4C616220: /* 'Lab ' */
+ break;
+
+ default:
+ return png_icc_profile_error(png_ptr, colorspace, name, temp,
+ "unexpected ICC PCS encoding");
+ }
+
+ return 1;
+}
+
+int /* PRIVATE */
+png_icc_check_tag_table(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_charp name, png_uint_32 profile_length,
+ png_const_bytep profile /* header plus whole tag table */)
+{
+ png_uint_32 tag_count = png_get_uint_32(profile+128);
+ png_uint_32 itag;
+ png_const_bytep tag = profile+132; /* The first tag */
+
+ /* First scan all the tags in the table and add bits to the icc_info value
+ * (temporarily in 'tags').
+ */
+ for (itag=0; itag < tag_count; ++itag, tag += 12)
+ {
+ png_uint_32 tag_id = png_get_uint_32(tag+0);
+ png_uint_32 tag_start = png_get_uint_32(tag+4); /* must be aligned */
+ png_uint_32 tag_length = png_get_uint_32(tag+8);/* not padded */
+
+ /* The ICC specification does not exclude zero length tags, therefore the
+ * start might actually be anywhere if there is no data, but this would be
+ * a clear abuse of the intent of the standard so the start is checked for
+ * being in range. All defined tag types have an 8 byte header - a 4 byte
+ * type signature then 0.
+ */
+ if ((tag_start & 3) != 0)
+ {
+ /* CNHP730S.icc shipped with Microsoft Windows 64 violates this, it is
+ * only a warning here because libpng does not care about the
+ * alignment.
+ */
+ (void)png_icc_profile_error(png_ptr, NULL, name, tag_id,
+ "ICC profile tag start not a multiple of 4");
+ }
+
+ /* This is a hard error; potentially it can cause read outside the
+ * profile.
+ */
+ if (tag_start > profile_length || tag_length > profile_length - tag_start)
+ return png_icc_profile_error(png_ptr, colorspace, name, tag_id,
+ "ICC profile tag outside profile");
+ }
+
+ return 1; /* success, maybe with warnings */
+}
+
+#ifdef PNG_sRGB_SUPPORTED
+/* Information about the known ICC sRGB profiles */
+static const struct
+{
+ png_uint_32 adler, crc, length;
+ png_uint_32 md5[4];
+ png_byte have_md5;
+ png_byte is_broken;
+ png_uint_16 intent;
+
+# define PNG_MD5(a,b,c,d) { a, b, c, d }, (a!=0)||(b!=0)||(c!=0)||(d!=0)
+# define PNG_ICC_CHECKSUM(adler, crc, md5, intent, broke, date, length, fname)\
+ { adler, crc, length, md5, broke, intent },
+
+} png_sRGB_checks[] =
+{
+ /* This data comes from contrib/tools/checksum-icc run on downloads of
+ * all four ICC sRGB profiles from www.color.org.
+ */
+ /* adler32, crc32, MD5[4], intent, date, length, file-name */
+ PNG_ICC_CHECKSUM(0x0a3fd9f6, 0x3b8772b9,
+ PNG_MD5(0x29f83dde, 0xaff255ae, 0x7842fae4, 0xca83390d), 0, 0,
+ "2009/03/27 21:36:31", 3048, "sRGB_IEC61966-2-1_black_scaled.icc")
+
+ /* ICC sRGB v2 perceptual no black-compensation: */
+ PNG_ICC_CHECKSUM(0x4909e5e1, 0x427ebb21,
+ PNG_MD5(0xc95bd637, 0xe95d8a3b, 0x0df38f99, 0xc1320389), 1, 0,
+ "2009/03/27 21:37:45", 3052, "sRGB_IEC61966-2-1_no_black_scaling.icc")
+
+ PNG_ICC_CHECKSUM(0xfd2144a1, 0x306fd8ae,
+ PNG_MD5(0xfc663378, 0x37e2886b, 0xfd72e983, 0x8228f1b8), 0, 0,
+ "2009/08/10 17:28:01", 60988, "sRGB_v4_ICC_preference_displayclass.icc")
+
+ /* ICC sRGB v4 perceptual */
+ PNG_ICC_CHECKSUM(0x209c35d2, 0xbbef7812,
+ PNG_MD5(0x34562abf, 0x994ccd06, 0x6d2c5721, 0xd0d68c5d), 0, 0,
+ "2007/07/25 00:05:37", 60960, "sRGB_v4_ICC_preference.icc")
+
+ /* The following profiles have no known MD5 checksum. If there is a match
+ * on the (empty) MD5 the other fields are used to attempt a match and
+ * a warning is produced. The first two of these profiles have a 'cprt' tag
+ * which suggests that they were also made by Hewlett Packard.
+ */
+ PNG_ICC_CHECKSUM(0xa054d762, 0x5d5129ce,
+ PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 1, 0,
+ "2004/07/21 18:57:42", 3024, "sRGB_IEC61966-2-1_noBPC.icc")
+
+ /* This is a 'mntr' (display) profile with a mediaWhitePointTag that does not
+ * match the D50 PCS illuminant in the header (it is in fact the D65 values,
+ * so the white point is recorded as the un-adapted value.) The profiles
+ * below only differ in one byte - the intent - and are basically the same as
+ * the previous profile except for the mediaWhitePointTag error and a missing
+ * chromaticAdaptationTag.
+ */
+ PNG_ICC_CHECKSUM(0xf784f3fb, 0x182ea552,
+ PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 0, 1/*broken*/,
+ "1998/02/09 06:49:00", 3144, "HP-Microsoft sRGB v2 perceptual")
+
+ PNG_ICC_CHECKSUM(0x0398f3fc, 0xf29e526d,
+ PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 1, 1/*broken*/,
+ "1998/02/09 06:49:00", 3144, "HP-Microsoft sRGB v2 media-relative")
+};
+
+static int
+png_compare_ICC_profile_with_sRGB(png_const_structrp png_ptr,
+ png_const_bytep profile, uLong adler)
+{
+ /* The quick check is to verify just the MD5 signature and trust the
+ * rest of the data. Because the profile has already been verified for
+ * correctness this is safe. png_colorspace_set_sRGB will check the 'intent'
+ * field too, so if the profile has been edited with an intent not defined
+ * by sRGB (but maybe defined by a later ICC specification) the read of
+ * the profile will fail at that point.
+ */
+ png_uint_32 length = 0;
+ png_uint_32 intent = 0x10000; /* invalid */
+#if PNG_sRGB_PROFILE_CHECKS > 1
+ uLong crc = 0; /* the value for 0 length data */
+#endif
+ unsigned int i;
+
+ for (i=0; i < (sizeof png_sRGB_checks) / (sizeof png_sRGB_checks[0]); ++i)
+ {
+ if (png_get_uint_32(profile+84) == png_sRGB_checks[i].md5[0] &&
+ png_get_uint_32(profile+88) == png_sRGB_checks[i].md5[1] &&
+ png_get_uint_32(profile+92) == png_sRGB_checks[i].md5[2] &&
+ png_get_uint_32(profile+96) == png_sRGB_checks[i].md5[3])
+ {
+ /* This may be one of the old HP profiles without an MD5, in that
+ * case we can only use the length and Adler32 (note that these
+ * are not used by default if there is an MD5!)
+ */
+# if PNG_sRGB_PROFILE_CHECKS == 0
+ if (png_sRGB_checks[i].have_md5)
+ return 1+png_sRGB_checks[i].is_broken;
+# endif
+
+ /* Profile is unsigned or more checks have been configured in. */
+ if (length == 0)
+ {
+ length = png_get_uint_32(profile);
+ intent = png_get_uint_32(profile+64);
+ }
+
+ /* Length *and* intent must match */
+ if (length == png_sRGB_checks[i].length &&
+ intent == png_sRGB_checks[i].intent)
+ {
+ /* Now calculate the adler32 if not done already. */
+ if (adler == 0)
+ {
+ adler = adler32(0, NULL, 0);
+ adler = adler32(adler, profile, length);
+ }
+
+ if (adler == png_sRGB_checks[i].adler)
+ {
+ /* These basic checks suggest that the data has not been
+ * modified, but if the check level is more than 1 perform
+ * our own crc32 checksum on the data.
+ */
+# if PNG_sRGB_PROFILE_CHECKS > 1
+ if (crc == 0)
+ {
+ crc = crc32(0, NULL, 0);
+ crc = crc32(crc, profile, length);
+ }
+
+ /* So this check must pass for the 'return' below to happen.
+ */
+ if (crc == png_sRGB_checks[i].crc)
+# endif
+ {
+ if (png_sRGB_checks[i].is_broken)
+ {
+ /* These profiles are known to have bad data that may cause
+ * problems if they are used, therefore attempt to
+ * discourage their use, skip the 'have_md5' warning below,
+ * which is made irrelevant by this error.
+ */
+ png_chunk_report(png_ptr, "known incorrect sRGB profile",
+ PNG_CHUNK_ERROR);
+ }
+
+ /* Warn that this being done; this isn't even an error since
+ * the profile is perfectly valid, but it would be nice if
+ * people used the up-to-date ones.
+ */
+ else if (!png_sRGB_checks[i].have_md5)
+ {
+ png_chunk_report(png_ptr,
+ "out-of-date sRGB profile with no signature",
+ PNG_CHUNK_WARNING);
+ }
+
+ return 1+png_sRGB_checks[i].is_broken;
+ }
+ }
+ }
+
+# if PNG_sRGB_PROFILE_CHECKS > 0
+ /* The signature matched, but the profile had been changed in some
+ * way. This is an apparent violation of the ICC terms of use and,
+ * anyway, probably indicates a data error or uninformed hacking.
+ */
+ if (png_sRGB_checks[i].have_md5)
+ png_benign_error(png_ptr,
+ "copyright violation: edited ICC profile ignored");
+# endif
+ }
+ }
+
+ return 0; /* no match */
+}
+#endif
+
+#ifdef PNG_sRGB_SUPPORTED
+void /* PRIVATE */
+png_icc_set_sRGB(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_const_bytep profile, uLong adler)
+{
+ /* Is this profile one of the known ICC sRGB profiles? If it is, just set
+ * the sRGB information.
+ */
+ if (png_compare_ICC_profile_with_sRGB(png_ptr, profile, adler))
+ (void)png_colorspace_set_sRGB(png_ptr, colorspace,
+ (int)/*already checked*/png_get_uint_32(profile+64));
+}
+#endif /* PNG_READ_sRGB_SUPPORTED */
+
+int /* PRIVATE */
+png_colorspace_set_ICC(png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_charp name, png_uint_32 profile_length, png_const_bytep profile,
+ int color_type)
+{
+ if (colorspace->flags & PNG_COLORSPACE_INVALID)
+ return 0;
+
+ if (png_icc_check_length(png_ptr, colorspace, name, profile_length) &&
+ png_icc_check_header(png_ptr, colorspace, name, profile_length, profile,
+ color_type) &&
+ png_icc_check_tag_table(png_ptr, colorspace, name, profile_length,
+ profile))
+ {
+# ifdef PNG_sRGB_SUPPORTED
+ /* If no sRGB support, don't try storing sRGB information */
+ png_icc_set_sRGB(png_ptr, colorspace, profile, 0);
+# endif
+ return 1;
+ }
+
+ /* Failure case */
+ return 0;
+}
+#endif /* iCCP */
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+void /* PRIVATE */
+png_colorspace_set_rgb_coefficients(png_structrp png_ptr)
+{
+ /* Set the rgb_to_gray coefficients from the colorspace. */
+ if (!png_ptr->rgb_to_gray_coefficients_set &&
+ (png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS) != 0)
+ {
+ /* png_set_background has not been called, get the coefficients from the Y
+ * values of the colorspace colorants.
+ */
+ png_fixed_point r = png_ptr->colorspace.end_points_XYZ.red_Y;
+ png_fixed_point g = png_ptr->colorspace.end_points_XYZ.green_Y;
+ png_fixed_point b = png_ptr->colorspace.end_points_XYZ.blue_Y;
+ png_fixed_point total = r+g+b;
+
+ if (total > 0 &&
+ r >= 0 && png_muldiv(&r, r, 32768, total) && r >= 0 && r <= 32768 &&
+ g >= 0 && png_muldiv(&g, g, 32768, total) && g >= 0 && g <= 32768 &&
+ b >= 0 && png_muldiv(&b, b, 32768, total) && b >= 0 && b <= 32768 &&
+ r+g+b <= 32769)
+ {
+ /* We allow 0 coefficients here. r+g+b may be 32769 if two or
+ * all of the coefficients were rounded up. Handle this by
+ * reducing the *largest* coefficient by 1; this matches the
+ * approach used for the default coefficients in pngrtran.c
+ */
+ int add = 0;
+
+ if (r+g+b > 32768)
+ add = -1;
+ else if (r+g+b < 32768)
+ add = 1;
+
+ if (add != 0)
+ {
+ if (g >= r && g >= b)
+ g += add;
+ else if (r >= g && r >= b)
+ r += add;
+ else
+ b += add;
+ }
+
+ /* Check for an internal error. */
+ if (r+g+b != 32768)
+ png_error(png_ptr,
+ "internal error handling cHRM coefficients");
+
+ else
+ {
+ png_ptr->rgb_to_gray_red_coeff = (png_uint_16)r;
+ png_ptr->rgb_to_gray_green_coeff = (png_uint_16)g;
+ }
+ }
+
+ /* This is a png_error at present even though it could be ignored -
+ * it should never happen, but it is important that if it does, the
+ * bug is fixed.
+ */
+ else
+ png_error(png_ptr, "internal error handling cHRM->XYZ");
+ }
+}
+#endif
+
+#endif /* COLORSPACE */
+
+void /* PRIVATE */
+png_check_IHDR(png_const_structrp png_ptr,
+ png_uint_32 width, png_uint_32 height, int bit_depth,
+ int color_type, int interlace_type, int compression_type,
+ int filter_type)
+{
+ int error = 0;
+
+ /* Check for width and height valid values */
+ if (width == 0)
+ {
+ png_warning(png_ptr, "Image width is zero in IHDR");
+ error = 1;
+ }
+
+ if (height == 0)
+ {
+ png_warning(png_ptr, "Image height is zero in IHDR");
+ error = 1;
+ }
+
+# ifdef PNG_SET_USER_LIMITS_SUPPORTED
+ if (width > png_ptr->user_width_max)
+
+# else
+ if (width > PNG_USER_WIDTH_MAX)
+# endif
+ {
+ png_warning(png_ptr, "Image width exceeds user limit in IHDR");
+ error = 1;
+ }
+
+# ifdef PNG_SET_USER_LIMITS_SUPPORTED
+ if (height > png_ptr->user_height_max)
+# else
+ if (height > PNG_USER_HEIGHT_MAX)
+# endif
+ {
+ png_warning(png_ptr, "Image height exceeds user limit in IHDR");
+ error = 1;
+ }
+
+ if (width > PNG_UINT_31_MAX)
+ {
+ png_warning(png_ptr, "Invalid image width in IHDR");
+ error = 1;
+ }
+
+ if (height > PNG_UINT_31_MAX)
+ {
+ png_warning(png_ptr, "Invalid image height in IHDR");
+ error = 1;
+ }
+
+ if (width > (PNG_UINT_32_MAX
+ >> 3) /* 8-byte RGBA pixels */
+ - 48 /* bigrowbuf hack */
+ - 1 /* filter byte */
+ - 7*8 /* rounding of width to multiple of 8 pixels */
+ - 8) /* extra max_pixel_depth pad */
+ png_warning(png_ptr, "Width is too large for libpng to process pixels");
+
+ /* Check other values */
+ if (bit_depth != 1 && bit_depth != 2 && bit_depth != 4 &&
+ bit_depth != 8 && bit_depth != 16)
+ {
+ png_warning(png_ptr, "Invalid bit depth in IHDR");
+ error = 1;
+ }
+
+ if (color_type < 0 || color_type == 1 ||
+ color_type == 5 || color_type > 6)
+ {
+ png_warning(png_ptr, "Invalid color type in IHDR");
+ error = 1;
+ }
+
+ if (((color_type == PNG_COLOR_TYPE_PALETTE) && bit_depth > 8) ||
+ ((color_type == PNG_COLOR_TYPE_RGB ||
+ color_type == PNG_COLOR_TYPE_GRAY_ALPHA ||
+ color_type == PNG_COLOR_TYPE_RGB_ALPHA) && bit_depth < 8))
+ {
+ png_warning(png_ptr, "Invalid color type/bit depth combination in IHDR");
+ error = 1;
+ }
+
+ if (interlace_type >= PNG_INTERLACE_LAST)
+ {
+ png_warning(png_ptr, "Unknown interlace method in IHDR");
+ error = 1;
+ }
+
+ if (compression_type != PNG_COMPRESSION_TYPE_BASE)
+ {
+ png_warning(png_ptr, "Unknown compression method in IHDR");
+ error = 1;
+ }
+
+# ifdef PNG_MNG_FEATURES_SUPPORTED
+ /* Accept filter_method 64 (intrapixel differencing) only if
+ * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and
+ * 2. Libpng did not read a PNG signature (this filter_method is only
+ * used in PNG datastreams that are embedded in MNG datastreams) and
+ * 3. The application called png_permit_mng_features with a mask that
+ * included PNG_FLAG_MNG_FILTER_64 and
+ * 4. The filter_method is 64 and
+ * 5. The color_type is RGB or RGBA
+ */
+ if ((png_ptr->mode & PNG_HAVE_PNG_SIGNATURE) &&
+ png_ptr->mng_features_permitted)
+ png_warning(png_ptr, "MNG features are not allowed in a PNG datastream");
+
+ if (filter_type != PNG_FILTER_TYPE_BASE)
+ {
+ if (!((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) &&
+ (filter_type == PNG_INTRAPIXEL_DIFFERENCING) &&
+ ((png_ptr->mode & PNG_HAVE_PNG_SIGNATURE) == 0) &&
+ (color_type == PNG_COLOR_TYPE_RGB ||
+ color_type == PNG_COLOR_TYPE_RGB_ALPHA)))
+ {
+ png_warning(png_ptr, "Unknown filter method in IHDR");
+ error = 1;
+ }
+
+ if (png_ptr->mode & PNG_HAVE_PNG_SIGNATURE)
+ {
+ png_warning(png_ptr, "Invalid filter method in IHDR");
+ error = 1;
+ }
+ }
+
+# else
+ if (filter_type != PNG_FILTER_TYPE_BASE)
+ {
+ png_warning(png_ptr, "Unknown filter method in IHDR");
+ error = 1;
+ }
+# endif
+
+ if (error == 1)
+ png_error(png_ptr, "Invalid IHDR data");
+}
+
+#if defined(PNG_sCAL_SUPPORTED) || defined(PNG_pCAL_SUPPORTED)
+/* ASCII to fp functions */
+/* Check an ASCII formated floating point value, see the more detailed
+ * comments in pngpriv.h
+ */
+/* The following is used internally to preserve the sticky flags */
+#define png_fp_add(state, flags) ((state) |= (flags))
+#define png_fp_set(state, value) ((state) = (value) | ((state) & PNG_FP_STICKY))
+
+int /* PRIVATE */
+png_check_fp_number(png_const_charp string, png_size_t size, int *statep,
+ png_size_tp whereami)
+{
+ int state = *statep;
+ png_size_t i = *whereami;
+
+ while (i < size)
+ {
+ int type;
+ /* First find the type of the next character */
+ switch (string[i])
+ {
+ case 43: type = PNG_FP_SAW_SIGN; break;
+ case 45: type = PNG_FP_SAW_SIGN + PNG_FP_NEGATIVE; break;
+ case 46: type = PNG_FP_SAW_DOT; break;
+ case 48: type = PNG_FP_SAW_DIGIT; break;
+ case 49: case 50: case 51: case 52:
+ case 53: case 54: case 55: case 56:
+ case 57: type = PNG_FP_SAW_DIGIT + PNG_FP_NONZERO; break;
+ case 69:
+ case 101: type = PNG_FP_SAW_E; break;
+ default: goto PNG_FP_End;
+ }
+
+ /* Now deal with this type according to the current
+ * state, the type is arranged to not overlap the
+ * bits of the PNG_FP_STATE.
+ */
+ switch ((state & PNG_FP_STATE) + (type & PNG_FP_SAW_ANY))
+ {
+ case PNG_FP_INTEGER + PNG_FP_SAW_SIGN:
+ if (state & PNG_FP_SAW_ANY)
+ goto PNG_FP_End; /* not a part of the number */
+
+ png_fp_add(state, type);
+ break;
+
+ case PNG_FP_INTEGER + PNG_FP_SAW_DOT:
+ /* Ok as trailer, ok as lead of fraction. */
+ if (state & PNG_FP_SAW_DOT) /* two dots */
+ goto PNG_FP_End;
+
+ else if (state & PNG_FP_SAW_DIGIT) /* trailing dot? */
+ png_fp_add(state, type);
+
+ else
+ png_fp_set(state, PNG_FP_FRACTION | type);
+
+ break;
+
+ case PNG_FP_INTEGER + PNG_FP_SAW_DIGIT:
+ if (state & PNG_FP_SAW_DOT) /* delayed fraction */
+ png_fp_set(state, PNG_FP_FRACTION | PNG_FP_SAW_DOT);
+
+ png_fp_add(state, type | PNG_FP_WAS_VALID);
+
+ break;
+
+ case PNG_FP_INTEGER + PNG_FP_SAW_E:
+ if ((state & PNG_FP_SAW_DIGIT) == 0)
+ goto PNG_FP_End;
+
+ png_fp_set(state, PNG_FP_EXPONENT);
+
+ break;
+
+ /* case PNG_FP_FRACTION + PNG_FP_SAW_SIGN:
+ goto PNG_FP_End; ** no sign in fraction */
+
+ /* case PNG_FP_FRACTION + PNG_FP_SAW_DOT:
+ goto PNG_FP_End; ** Because SAW_DOT is always set */
+
+ case PNG_FP_FRACTION + PNG_FP_SAW_DIGIT:
+ png_fp_add(state, type | PNG_FP_WAS_VALID);
+ break;
+
+ case PNG_FP_FRACTION + PNG_FP_SAW_E:
+ /* This is correct because the trailing '.' on an
+ * integer is handled above - so we can only get here
+ * with the sequence ".E" (with no preceding digits).
+ */
+ if ((state & PNG_FP_SAW_DIGIT) == 0)
+ goto PNG_FP_End;
+
+ png_fp_set(state, PNG_FP_EXPONENT);
+
+ break;
+
+ case PNG_FP_EXPONENT + PNG_FP_SAW_SIGN:
+ if (state & PNG_FP_SAW_ANY)
+ goto PNG_FP_End; /* not a part of the number */
+
+ png_fp_add(state, PNG_FP_SAW_SIGN);
+
+ break;
+
+ /* case PNG_FP_EXPONENT + PNG_FP_SAW_DOT:
+ goto PNG_FP_End; */
+
+ case PNG_FP_EXPONENT + PNG_FP_SAW_DIGIT:
+ png_fp_add(state, PNG_FP_SAW_DIGIT | PNG_FP_WAS_VALID);
+
+ break;
+
+ /* case PNG_FP_EXPONEXT + PNG_FP_SAW_E:
+ goto PNG_FP_End; */
+
+ default: goto PNG_FP_End; /* I.e. break 2 */
+ }
+
+ /* The character seems ok, continue. */
+ ++i;
+ }
+
+PNG_FP_End:
+ /* Here at the end, update the state and return the correct
+ * return code.
+ */
+ *statep = state;
+ *whereami = i;
+
+ return (state & PNG_FP_SAW_DIGIT) != 0;
+}
+
+
+/* The same but for a complete string. */
+int
+png_check_fp_string(png_const_charp string, png_size_t size)
+{
+ int state=0;
+ png_size_t char_index=0;
+
+ if (png_check_fp_number(string, size, &state, &char_index) &&
+ (char_index == size || string[char_index] == 0))
+ return state /* must be non-zero - see above */;
+
+ return 0; /* i.e. fail */
+}
+#endif /* pCAL or sCAL */
+
+#ifdef PNG_sCAL_SUPPORTED
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+/* Utility used below - a simple accurate power of ten from an integral
+ * exponent.
+ */
+static double
+png_pow10(int power)
+{
+ int recip = 0;
+ double d = 1;
+
+ /* Handle negative exponent with a reciprocal at the end because
+ * 10 is exact whereas .1 is inexact in base 2
+ */
+ if (power < 0)
+ {
+ if (power < DBL_MIN_10_EXP) return 0;
+ recip = 1, power = -power;
+ }
+
+ if (power > 0)
+ {
+ /* Decompose power bitwise. */
+ double mult = 10;
+ do
+ {
+ if (power & 1) d *= mult;
+ mult *= mult;
+ power >>= 1;
+ }
+ while (power > 0);
+
+ if (recip) d = 1/d;
+ }
+ /* else power is 0 and d is 1 */
+
+ return d;
+}
+
+/* Function to format a floating point value in ASCII with a given
+ * precision.
+ */
+void /* PRIVATE */
+png_ascii_from_fp(png_const_structrp png_ptr, png_charp ascii, png_size_t size,
+ double fp, unsigned int precision)
+{
+ /* We use standard functions from math.h, but not printf because
+ * that would require stdio. The caller must supply a buffer of
+ * sufficient size or we will png_error. The tests on size and
+ * the space in ascii[] consumed are indicated below.
+ */
+ if (precision < 1)
+ precision = DBL_DIG;
+
+ /* Enforce the limit of the implementation precision too. */
+ if (precision > DBL_DIG+1)
+ precision = DBL_DIG+1;
+
+ /* Basic sanity checks */
+ if (size >= precision+5) /* See the requirements below. */
+ {
+ if (fp < 0)
+ {
+ fp = -fp;
+ *ascii++ = 45; /* '-' PLUS 1 TOTAL 1 */
+ --size;
+ }
+
+ if (fp >= DBL_MIN && fp <= DBL_MAX)
+ {
+ int exp_b10; /* A base 10 exponent */
+ double base; /* 10^exp_b10 */
+
+ /* First extract a base 10 exponent of the number,
+ * the calculation below rounds down when converting
+ * from base 2 to base 10 (multiply by log10(2) -
+ * 0.3010, but 77/256 is 0.3008, so exp_b10 needs to
+ * be increased. Note that the arithmetic shift
+ * performs a floor() unlike C arithmetic - using a
+ * C multiply would break the following for negative
+ * exponents.
+ */
+ (void)frexp(fp, &exp_b10); /* exponent to base 2 */
+
+ exp_b10 = (exp_b10 * 77) >> 8; /* <= exponent to base 10 */
+
+ /* Avoid underflow here. */
+ base = png_pow10(exp_b10); /* May underflow */
+
+ while (base < DBL_MIN || base < fp)
+ {
+ /* And this may overflow. */
+ double test = png_pow10(exp_b10+1);
+
+ if (test <= DBL_MAX)
+ ++exp_b10, base = test;
+
+ else
+ break;
+ }
+
+ /* Normalize fp and correct exp_b10, after this fp is in the
+ * range [.1,1) and exp_b10 is both the exponent and the digit
+ * *before* which the decimal point should be inserted
+ * (starting with 0 for the first digit). Note that this
+ * works even if 10^exp_b10 is out of range because of the
+ * test on DBL_MAX above.
+ */
+ fp /= base;
+ while (fp >= 1) fp /= 10, ++exp_b10;
+
+ /* Because of the code above fp may, at this point, be
+ * less than .1, this is ok because the code below can
+ * handle the leading zeros this generates, so no attempt
+ * is made to correct that here.
+ */
+
+ {
+ int czero, clead, cdigits;
+ char exponent[10];
+
+ /* Allow up to two leading zeros - this will not lengthen
+ * the number compared to using E-n.
+ */
+ if (exp_b10 < 0 && exp_b10 > -3) /* PLUS 3 TOTAL 4 */
+ {
+ czero = -exp_b10; /* PLUS 2 digits: TOTAL 3 */
+ exp_b10 = 0; /* Dot added below before first output. */
+ }
+ else
+ czero = 0; /* No zeros to add */
+
+ /* Generate the digit list, stripping trailing zeros and
+ * inserting a '.' before a digit if the exponent is 0.
+ */
+ clead = czero; /* Count of leading zeros */
+ cdigits = 0; /* Count of digits in list. */
+
+ do
+ {
+ double d;
+
+ fp *= 10;
+ /* Use modf here, not floor and subtract, so that
+ * the separation is done in one step. At the end
+ * of the loop don't break the number into parts so
+ * that the final digit is rounded.
+ */
+ if (cdigits+czero-clead+1 < (int)precision)
+ fp = modf(fp, &d);
+
+ else
+ {
+ d = floor(fp + .5);
+
+ if (d > 9)
+ {
+ /* Rounding up to 10, handle that here. */
+ if (czero > 0)
+ {
+ --czero, d = 1;
+ if (cdigits == 0) --clead;
+ }
+ else
+ {
+ while (cdigits > 0 && d > 9)
+ {
+ int ch = *--ascii;
+
+ if (exp_b10 != (-1))
+ ++exp_b10;
+
+ else if (ch == 46)
+ {
+ ch = *--ascii, ++size;
+ /* Advance exp_b10 to '1', so that the
+ * decimal point happens after the
+ * previous digit.
+ */
+ exp_b10 = 1;
+ }
+
+ --cdigits;
+ d = ch - 47; /* I.e. 1+(ch-48) */
+ }
+
+ /* Did we reach the beginning? If so adjust the
+ * exponent but take into account the leading
+ * decimal point.
+ */
+ if (d > 9) /* cdigits == 0 */
+ {
+ if (exp_b10 == (-1))
+ {
+ /* Leading decimal point (plus zeros?), if
+ * we lose the decimal point here it must
+ * be reentered below.
+ */
+ int ch = *--ascii;
+
+ if (ch == 46)
+ ++size, exp_b10 = 1;
+
+ /* Else lost a leading zero, so 'exp_b10' is
+ * still ok at (-1)
+ */
+ }
+ else
+ ++exp_b10;
+
+ /* In all cases we output a '1' */
+ d = 1;
+ }
+ }
+ }
+ fp = 0; /* Guarantees termination below. */
+ }
+
+ if (d == 0)
+ {
+ ++czero;
+ if (cdigits == 0) ++clead;
+ }
+ else
+ {
+ /* Included embedded zeros in the digit count. */
+ cdigits += czero - clead;
+ clead = 0;
+
+ while (czero > 0)
+ {
+ /* exp_b10 == (-1) means we just output the decimal
+ * place - after the DP don't adjust 'exp_b10' any
+ * more!
+ */
+ if (exp_b10 != (-1))
+ {
+ if (exp_b10 == 0) *ascii++ = 46, --size;
+ /* PLUS 1: TOTAL 4 */
+ --exp_b10;
+ }
+ *ascii++ = 48, --czero;
+ }
+
+ if (exp_b10 != (-1))
+ {
+ if (exp_b10 == 0) *ascii++ = 46, --size; /* counted
+ above */
+ --exp_b10;
+ }
+ *ascii++ = (char)(48 + (int)d), ++cdigits;
+ }
+ }
+ while (cdigits+czero-clead < (int)precision && fp > DBL_MIN);
+
+ /* The total output count (max) is now 4+precision */
+
+ /* Check for an exponent, if we don't need one we are
+ * done and just need to terminate the string. At
+ * this point exp_b10==(-1) is effectively if flag - it got
+ * to '-1' because of the decrement after outputing
+ * the decimal point above (the exponent required is
+ * *not* -1!)
+ */
+ if (exp_b10 >= (-1) && exp_b10 <= 2)
+ {
+ /* The following only happens if we didn't output the
+ * leading zeros above for negative exponent, so this
+ * doest add to the digit requirement. Note that the
+ * two zeros here can only be output if the two leading
+ * zeros were *not* output, so this doesn't increase
+ * the output count.
+ */
+ while (--exp_b10 >= 0) *ascii++ = 48;
+
+ *ascii = 0;
+
+ /* Total buffer requirement (including the '\0') is
+ * 5+precision - see check at the start.
+ */
+ return;
+ }
+
+ /* Here if an exponent is required, adjust size for
+ * the digits we output but did not count. The total
+ * digit output here so far is at most 1+precision - no
+ * decimal point and no leading or trailing zeros have
+ * been output.
+ */
+ size -= cdigits;
+
+ *ascii++ = 69, --size; /* 'E': PLUS 1 TOTAL 2+precision */
+
+ /* The following use of an unsigned temporary avoids ambiguities in
+ * the signed arithmetic on exp_b10 and permits GCC at least to do
+ * better optimization.
+ */
+ {
+ unsigned int uexp_b10;
+
+ if (exp_b10 < 0)
+ {
+ *ascii++ = 45, --size; /* '-': PLUS 1 TOTAL 3+precision */
+ uexp_b10 = -exp_b10;
+ }
+
+ else
+ uexp_b10 = exp_b10;
+
+ cdigits = 0;
+
+ while (uexp_b10 > 0)
+ {
+ exponent[cdigits++] = (char)(48 + uexp_b10 % 10);
+ uexp_b10 /= 10;
+ }
+ }
+
+ /* Need another size check here for the exponent digits, so
+ * this need not be considered above.
+ */
+ if ((int)size > cdigits)
+ {
+ while (cdigits > 0) *ascii++ = exponent[--cdigits];
+
+ *ascii = 0;
+
+ return;
+ }
+ }
+ }
+ else if (!(fp >= DBL_MIN))
+ {
+ *ascii++ = 48; /* '0' */
+ *ascii = 0;
+ return;
+ }
+ else
+ {
+ *ascii++ = 105; /* 'i' */
+ *ascii++ = 110; /* 'n' */
+ *ascii++ = 102; /* 'f' */
+ *ascii = 0;
+ return;
+ }
+ }
+
+ /* Here on buffer too small. */
+ png_error(png_ptr, "ASCII conversion buffer too small");
+}
+
+# endif /* FLOATING_POINT */
+
+# ifdef PNG_FIXED_POINT_SUPPORTED
+/* Function to format a fixed point value in ASCII.
+ */
+void /* PRIVATE */
+png_ascii_from_fixed(png_const_structrp png_ptr, png_charp ascii,
+ png_size_t size, png_fixed_point fp)
+{
+ /* Require space for 10 decimal digits, a decimal point, a minus sign and a
+ * trailing \0, 13 characters:
+ */
+ if (size > 12)
+ {
+ png_uint_32 num;
+
+ /* Avoid overflow here on the minimum integer. */
+ if (fp < 0)
+ *ascii++ = 45, --size, num = -fp;
+ else
+ num = fp;
+
+ if (num <= 0x80000000) /* else overflowed */
+ {
+ unsigned int ndigits = 0, first = 16 /* flag value */;
+ char digits[10];
+
+ while (num)
+ {
+ /* Split the low digit off num: */
+ unsigned int tmp = num/10;
+ num -= tmp*10;
+ digits[ndigits++] = (char)(48 + num);
+ /* Record the first non-zero digit, note that this is a number
+ * starting at 1, it's not actually the array index.
+ */
+ if (first == 16 && num > 0)
+ first = ndigits;
+ num = tmp;
+ }
+
+ if (ndigits > 0)
+ {
+ while (ndigits > 5) *ascii++ = digits[--ndigits];
+ /* The remaining digits are fractional digits, ndigits is '5' or
+ * smaller at this point. It is certainly not zero. Check for a
+ * non-zero fractional digit:
+ */
+ if (first <= 5)
+ {
+ unsigned int i;
+ *ascii++ = 46; /* decimal point */
+ /* ndigits may be <5 for small numbers, output leading zeros
+ * then ndigits digits to first:
+ */
+ i = 5;
+ while (ndigits < i) *ascii++ = 48, --i;
+ while (ndigits >= first) *ascii++ = digits[--ndigits];
+ /* Don't output the trailing zeros! */
+ }
+ }
+ else
+ *ascii++ = 48;
+
+ /* And null terminate the string: */
+ *ascii = 0;
+ return;
+ }
+ }
+
+ /* Here on buffer too small. */
+ png_error(png_ptr, "ASCII conversion buffer too small");
+}
+# endif /* FIXED_POINT */
+#endif /* READ_SCAL */
+
+#if defined(PNG_FLOATING_POINT_SUPPORTED) && \
+ !defined(PNG_FIXED_POINT_MACRO_SUPPORTED) && \
+ (defined(PNG_gAMA_SUPPORTED) || defined(PNG_cHRM_SUPPORTED) || \
+ defined(PNG_sCAL_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)) || \
+ (defined(PNG_sCAL_SUPPORTED) && \
+ defined(PNG_FLOATING_ARITHMETIC_SUPPORTED))
+png_fixed_point
+png_fixed(png_const_structrp png_ptr, double fp, png_const_charp text)
+{
+ double r = floor(100000 * fp + .5);
+
+ if (r > 2147483647. || r < -2147483648.)
+ png_fixed_error(png_ptr, text);
+
+ return (png_fixed_point)r;
+}
+#endif
+
+#if defined(PNG_READ_GAMMA_SUPPORTED) || \
+ defined(PNG_INCH_CONVERSIONS_SUPPORTED) || defined(PNG_READ_pHYs_SUPPORTED)
+/* muldiv functions */
+/* This API takes signed arguments and rounds the result to the nearest
+ * integer (or, for a fixed point number - the standard argument - to
+ * the nearest .00001). Overflow and divide by zero are signalled in
+ * the result, a boolean - true on success, false on overflow.
+ */
+int
+png_muldiv(png_fixed_point_p res, png_fixed_point a, png_int_32 times,
+ png_int_32 divisor)
+{
+ /* Return a * times / divisor, rounded. */
+ if (divisor != 0)
+ {
+ if (a == 0 || times == 0)
+ {
+ *res = 0;
+ return 1;
+ }
+ else
+ {
+#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = a;
+ r *= times;
+ r /= divisor;
+ r = floor(r+.5);
+
+ /* A png_fixed_point is a 32-bit integer. */
+ if (r <= 2147483647. && r >= -2147483648.)
+ {
+ *res = (png_fixed_point)r;
+ return 1;
+ }
+#else
+ int negative = 0;
+ png_uint_32 A, T, D;
+ png_uint_32 s16, s32, s00;
+
+ if (a < 0)
+ negative = 1, A = -a;
+ else
+ A = a;
+
+ if (times < 0)
+ negative = !negative, T = -times;
+ else
+ T = times;
+
+ if (divisor < 0)
+ negative = !negative, D = -divisor;
+ else
+ D = divisor;
+
+ /* Following can't overflow because the arguments only
+ * have 31 bits each, however the result may be 32 bits.
+ */
+ s16 = (A >> 16) * (T & 0xffff) +
+ (A & 0xffff) * (T >> 16);
+ /* Can't overflow because the a*times bit is only 30
+ * bits at most.
+ */
+ s32 = (A >> 16) * (T >> 16) + (s16 >> 16);
+ s00 = (A & 0xffff) * (T & 0xffff);
+
+ s16 = (s16 & 0xffff) << 16;
+ s00 += s16;
+
+ if (s00 < s16)
+ ++s32; /* carry */
+
+ if (s32 < D) /* else overflow */
+ {
+ /* s32.s00 is now the 64-bit product, do a standard
+ * division, we know that s32 < D, so the maximum
+ * required shift is 31.
+ */
+ int bitshift = 32;
+ png_fixed_point result = 0; /* NOTE: signed */
+
+ while (--bitshift >= 0)
+ {
+ png_uint_32 d32, d00;
+
+ if (bitshift > 0)
+ d32 = D >> (32-bitshift), d00 = D << bitshift;
+
+ else
+ d32 = 0, d00 = D;
+
+ if (s32 > d32)
+ {
+ if (s00 < d00) --s32; /* carry */
+ s32 -= d32, s00 -= d00, result += 1<<bitshift;
+ }
+
+ else
+ if (s32 == d32 && s00 >= d00)
+ s32 = 0, s00 -= d00, result += 1<<bitshift;
+ }
+
+ /* Handle the rounding. */
+ if (s00 >= (D >> 1))
+ ++result;
+
+ if (negative)
+ result = -result;
+
+ /* Check for overflow. */
+ if ((negative && result <= 0) || (!negative && result >= 0))
+ {
+ *res = result;
+ return 1;
+ }
+ }
+#endif
+ }
+ }
+
+ return 0;
+}
+#endif /* READ_GAMMA || INCH_CONVERSIONS */
+
+#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_INCH_CONVERSIONS_SUPPORTED)
+/* The following is for when the caller doesn't much care about the
+ * result.
+ */
+png_fixed_point
+png_muldiv_warn(png_const_structrp png_ptr, png_fixed_point a, png_int_32 times,
+ png_int_32 divisor)
+{
+ png_fixed_point result;
+
+ if (png_muldiv(&result, a, times, divisor))
+ return result;
+
+ png_warning(png_ptr, "fixed point overflow ignored");
+ return 0;
+}
+#endif
+
+#ifdef PNG_GAMMA_SUPPORTED /* more fixed point functions for gamma */
+/* Calculate a reciprocal, return 0 on div-by-zero or overflow. */
+png_fixed_point
+png_reciprocal(png_fixed_point a)
+{
+#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = floor(1E10/a+.5);
+
+ if (r <= 2147483647. && r >= -2147483648.)
+ return (png_fixed_point)r;
+#else
+ png_fixed_point res;
+
+ if (png_muldiv(&res, 100000, 100000, a))
+ return res;
+#endif
+
+ return 0; /* error/overflow */
+}
+
+/* This is the shared test on whether a gamma value is 'significant' - whether
+ * it is worth doing gamma correction.
+ */
+int /* PRIVATE */
+png_gamma_significant(png_fixed_point gamma_val)
+{
+ return gamma_val < PNG_FP_1 - PNG_GAMMA_THRESHOLD_FIXED ||
+ gamma_val > PNG_FP_1 + PNG_GAMMA_THRESHOLD_FIXED;
+}
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* A local convenience routine. */
+static png_fixed_point
+png_product2(png_fixed_point a, png_fixed_point b)
+{
+ /* The required result is 1/a * 1/b; the following preserves accuracy. */
+#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = a * 1E-5;
+ r *= b;
+ r = floor(r+.5);
+
+ if (r <= 2147483647. && r >= -2147483648.)
+ return (png_fixed_point)r;
+#else
+ png_fixed_point res;
+
+ if (png_muldiv(&res, a, b, 100000))
+ return res;
+#endif
+
+ return 0; /* overflow */
+}
+
+/* The inverse of the above. */
+png_fixed_point
+png_reciprocal2(png_fixed_point a, png_fixed_point b)
+{
+ /* The required result is 1/a * 1/b; the following preserves accuracy. */
+#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = 1E15/a;
+ r /= b;
+ r = floor(r+.5);
+
+ if (r <= 2147483647. && r >= -2147483648.)
+ return (png_fixed_point)r;
+#else
+ /* This may overflow because the range of png_fixed_point isn't symmetric,
+ * but this API is only used for the product of file and screen gamma so it
+ * doesn't matter that the smallest number it can produce is 1/21474, not
+ * 1/100000
+ */
+ png_fixed_point res = png_product2(a, b);
+
+ if (res != 0)
+ return png_reciprocal(res);
+#endif
+
+ return 0; /* overflow */
+}
+#endif /* READ_GAMMA */
+
+#ifdef PNG_READ_GAMMA_SUPPORTED /* gamma table code */
+#ifndef PNG_FLOATING_ARITHMETIC_SUPPORTED
+/* Fixed point gamma.
+ *
+ * The code to calculate the tables used below can be found in the shell script
+ * contrib/tools/intgamma.sh
+ *
+ * To calculate gamma this code implements fast log() and exp() calls using only
+ * fixed point arithmetic. This code has sufficient precision for either 8-bit
+ * or 16-bit sample values.
+ *
+ * The tables used here were calculated using simple 'bc' programs, but C double
+ * precision floating point arithmetic would work fine.
+ *
+ * 8-bit log table
+ * This is a table of -log(value/255)/log(2) for 'value' in the range 128 to
+ * 255, so it's the base 2 logarithm of a normalized 8-bit floating point
+ * mantissa. The numbers are 32-bit fractions.
+ */
+static const png_uint_32
+png_8bit_l2[128] =
+{
+ 4270715492U, 4222494797U, 4174646467U, 4127164793U, 4080044201U, 4033279239U,
+ 3986864580U, 3940795015U, 3895065449U, 3849670902U, 3804606499U, 3759867474U,
+ 3715449162U, 3671346997U, 3627556511U, 3584073329U, 3540893168U, 3498011834U,
+ 3455425220U, 3413129301U, 3371120137U, 3329393864U, 3287946700U, 3246774933U,
+ 3205874930U, 3165243125U, 3124876025U, 3084770202U, 3044922296U, 3005329011U,
+ 2965987113U, 2926893432U, 2888044853U, 2849438323U, 2811070844U, 2772939474U,
+ 2735041326U, 2697373562U, 2659933400U, 2622718104U, 2585724991U, 2548951424U,
+ 2512394810U, 2476052606U, 2439922311U, 2404001468U, 2368287663U, 2332778523U,
+ 2297471715U, 2262364947U, 2227455964U, 2192742551U, 2158222529U, 2123893754U,
+ 2089754119U, 2055801552U, 2022034013U, 1988449497U, 1955046031U, 1921821672U,
+ 1888774511U, 1855902668U, 1823204291U, 1790677560U, 1758320682U, 1726131893U,
+ 1694109454U, 1662251657U, 1630556815U, 1599023271U, 1567649391U, 1536433567U,
+ 1505374214U, 1474469770U, 1443718700U, 1413119487U, 1382670639U, 1352370686U,
+ 1322218179U, 1292211689U, 1262349810U, 1232631153U, 1203054352U, 1173618059U,
+ 1144320946U, 1115161701U, 1086139034U, 1057251672U, 1028498358U, 999877854U,
+ 971388940U, 943030410U, 914801076U, 886699767U, 858725327U, 830876614U,
+ 803152505U, 775551890U, 748073672U, 720716771U, 693480120U, 666362667U,
+ 639363374U, 612481215U, 585715177U, 559064263U, 532527486U, 506103872U,
+ 479792461U, 453592303U, 427502463U, 401522014U, 375650043U, 349885648U,
+ 324227938U, 298676034U, 273229066U, 247886176U, 222646516U, 197509248U,
+ 172473545U, 147538590U, 122703574U, 97967701U, 73330182U, 48790236U,
+ 24347096U, 0U
+
+#if 0
+ /* The following are the values for 16-bit tables - these work fine for the
+ * 8-bit conversions but produce very slightly larger errors in the 16-bit
+ * log (about 1.2 as opposed to 0.7 absolute error in the final value). To
+ * use these all the shifts below must be adjusted appropriately.
+ */
+ 65166, 64430, 63700, 62976, 62257, 61543, 60835, 60132, 59434, 58741, 58054,
+ 57371, 56693, 56020, 55352, 54689, 54030, 53375, 52726, 52080, 51439, 50803,
+ 50170, 49542, 48918, 48298, 47682, 47070, 46462, 45858, 45257, 44661, 44068,
+ 43479, 42894, 42312, 41733, 41159, 40587, 40020, 39455, 38894, 38336, 37782,
+ 37230, 36682, 36137, 35595, 35057, 34521, 33988, 33459, 32932, 32408, 31887,
+ 31369, 30854, 30341, 29832, 29325, 28820, 28319, 27820, 27324, 26830, 26339,
+ 25850, 25364, 24880, 24399, 23920, 23444, 22970, 22499, 22029, 21562, 21098,
+ 20636, 20175, 19718, 19262, 18808, 18357, 17908, 17461, 17016, 16573, 16132,
+ 15694, 15257, 14822, 14390, 13959, 13530, 13103, 12678, 12255, 11834, 11415,
+ 10997, 10582, 10168, 9756, 9346, 8937, 8531, 8126, 7723, 7321, 6921, 6523,
+ 6127, 5732, 5339, 4947, 4557, 4169, 3782, 3397, 3014, 2632, 2251, 1872, 1495,
+ 1119, 744, 372
+#endif
+};
+
+static png_int_32
+png_log8bit(unsigned int x)
+{
+ unsigned int lg2 = 0;
+ /* Each time 'x' is multiplied by 2, 1 must be subtracted off the final log,
+ * because the log is actually negate that means adding 1. The final
+ * returned value thus has the range 0 (for 255 input) to 7.994 (for 1
+ * input), return -1 for the overflow (log 0) case, - so the result is
+ * always at most 19 bits.
+ */
+ if ((x &= 0xff) == 0)
+ return -1;
+
+ if ((x & 0xf0) == 0)
+ lg2 = 4, x <<= 4;
+
+ if ((x & 0xc0) == 0)
+ lg2 += 2, x <<= 2;
+
+ if ((x & 0x80) == 0)
+ lg2 += 1, x <<= 1;
+
+ /* result is at most 19 bits, so this cast is safe: */
+ return (png_int_32)((lg2 << 16) + ((png_8bit_l2[x-128]+32768)>>16));
+}
+
+/* The above gives exact (to 16 binary places) log2 values for 8-bit images,
+ * for 16-bit images we use the most significant 8 bits of the 16-bit value to
+ * get an approximation then multiply the approximation by a correction factor
+ * determined by the remaining up to 8 bits. This requires an additional step
+ * in the 16-bit case.
+ *
+ * We want log2(value/65535), we have log2(v'/255), where:
+ *
+ * value = v' * 256 + v''
+ * = v' * f
+ *
+ * So f is value/v', which is equal to (256+v''/v') since v' is in the range 128
+ * to 255 and v'' is in the range 0 to 255 f will be in the range 256 to less
+ * than 258. The final factor also needs to correct for the fact that our 8-bit
+ * value is scaled by 255, whereas the 16-bit values must be scaled by 65535.
+ *
+ * This gives a final formula using a calculated value 'x' which is value/v' and
+ * scaling by 65536 to match the above table:
+ *
+ * log2(x/257) * 65536
+ *
+ * Since these numbers are so close to '1' we can use simple linear
+ * interpolation between the two end values 256/257 (result -368.61) and 258/257
+ * (result 367.179). The values used below are scaled by a further 64 to give
+ * 16-bit precision in the interpolation:
+ *
+ * Start (256): -23591
+ * Zero (257): 0
+ * End (258): 23499
+ */
+static png_int_32
+png_log16bit(png_uint_32 x)
+{
+ unsigned int lg2 = 0;
+
+ /* As above, but now the input has 16 bits. */
+ if ((x &= 0xffff) == 0)
+ return -1;
+
+ if ((x & 0xff00) == 0)
+ lg2 = 8, x <<= 8;
+
+ if ((x & 0xf000) == 0)
+ lg2 += 4, x <<= 4;
+
+ if ((x & 0xc000) == 0)
+ lg2 += 2, x <<= 2;
+
+ if ((x & 0x8000) == 0)
+ lg2 += 1, x <<= 1;
+
+ /* Calculate the base logarithm from the top 8 bits as a 28-bit fractional
+ * value.
+ */
+ lg2 <<= 28;
+ lg2 += (png_8bit_l2[(x>>8)-128]+8) >> 4;
+
+ /* Now we need to interpolate the factor, this requires a division by the top
+ * 8 bits. Do this with maximum precision.
+ */
+ x = ((x << 16) + (x >> 9)) / (x >> 8);
+
+ /* Since we divided by the top 8 bits of 'x' there will be a '1' at 1<<24,
+ * the value at 1<<16 (ignoring this) will be 0 or 1; this gives us exactly
+ * 16 bits to interpolate to get the low bits of the result. Round the
+ * answer. Note that the end point values are scaled by 64 to retain overall
+ * precision and that 'lg2' is current scaled by an extra 12 bits, so adjust
+ * the overall scaling by 6-12. Round at every step.
+ */
+ x -= 1U << 24;
+
+ if (x <= 65536U) /* <= '257' */
+ lg2 += ((23591U * (65536U-x)) + (1U << (16+6-12-1))) >> (16+6-12);
+
+ else
+ lg2 -= ((23499U * (x-65536U)) + (1U << (16+6-12-1))) >> (16+6-12);
+
+ /* Safe, because the result can't have more than 20 bits: */
+ return (png_int_32)((lg2 + 2048) >> 12);
+}
+
+/* The 'exp()' case must invert the above, taking a 20-bit fixed point
+ * logarithmic value and returning a 16 or 8-bit number as appropriate. In
+ * each case only the low 16 bits are relevant - the fraction - since the
+ * integer bits (the top 4) simply determine a shift.
+ *
+ * The worst case is the 16-bit distinction between 65535 and 65534, this
+ * requires perhaps spurious accuracty in the decoding of the logarithm to
+ * distinguish log2(65535/65534.5) - 10^-5 or 17 bits. There is little chance
+ * of getting this accuracy in practice.
+ *
+ * To deal with this the following exp() function works out the exponent of the
+ * frational part of the logarithm by using an accurate 32-bit value from the
+ * top four fractional bits then multiplying in the remaining bits.
+ */
+static const png_uint_32
+png_32bit_exp[16] =
+{
+ /* NOTE: the first entry is deliberately set to the maximum 32-bit value. */
+ 4294967295U, 4112874773U, 3938502376U, 3771522796U, 3611622603U, 3458501653U,
+ 3311872529U, 3171459999U, 3037000500U, 2908241642U, 2784941738U, 2666869345U,
+ 2553802834U, 2445529972U, 2341847524U, 2242560872U
+};
+
+/* Adjustment table; provided to explain the numbers in the code below. */
+#if 0
+for (i=11;i>=0;--i){ print i, " ", (1 - e(-(2^i)/65536*l(2))) * 2^(32-i), "\n"}
+ 11 44937.64284865548751208448
+ 10 45180.98734845585101160448
+ 9 45303.31936980687359311872
+ 8 45364.65110595323018870784
+ 7 45395.35850361789624614912
+ 6 45410.72259715102037508096
+ 5 45418.40724413220722311168
+ 4 45422.25021786898173001728
+ 3 45424.17186732298419044352
+ 2 45425.13273269940811464704
+ 1 45425.61317555035558641664
+ 0 45425.85339951654943850496
+#endif
+
+static png_uint_32
+png_exp(png_fixed_point x)
+{
+ if (x > 0 && x <= 0xfffff) /* Else overflow or zero (underflow) */
+ {
+ /* Obtain a 4-bit approximation */
+ png_uint_32 e = png_32bit_exp[(x >> 12) & 0xf];
+
+ /* Incorporate the low 12 bits - these decrease the returned value by
+ * multiplying by a number less than 1 if the bit is set. The multiplier
+ * is determined by the above table and the shift. Notice that the values
+ * converge on 45426 and this is used to allow linear interpolation of the
+ * low bits.
+ */
+ if (x & 0x800)
+ e -= (((e >> 16) * 44938U) + 16U) >> 5;
+
+ if (x & 0x400)
+ e -= (((e >> 16) * 45181U) + 32U) >> 6;
+
+ if (x & 0x200)
+ e -= (((e >> 16) * 45303U) + 64U) >> 7;
+
+ if (x & 0x100)
+ e -= (((e >> 16) * 45365U) + 128U) >> 8;
+
+ if (x & 0x080)
+ e -= (((e >> 16) * 45395U) + 256U) >> 9;
+
+ if (x & 0x040)
+ e -= (((e >> 16) * 45410U) + 512U) >> 10;
+
+ /* And handle the low 6 bits in a single block. */
+ e -= (((e >> 16) * 355U * (x & 0x3fU)) + 256U) >> 9;
+
+ /* Handle the upper bits of x. */
+ e >>= x >> 16;
+ return e;
+ }
+
+ /* Check for overflow */
+ if (x <= 0)
+ return png_32bit_exp[0];
+
+ /* Else underflow */
+ return 0;
+}
+
+static png_byte
+png_exp8bit(png_fixed_point lg2)
+{
+ /* Get a 32-bit value: */
+ png_uint_32 x = png_exp(lg2);
+
+ /* Convert the 32-bit value to 0..255 by multiplying by 256-1, note that the
+ * second, rounding, step can't overflow because of the first, subtraction,
+ * step.
+ */
+ x -= x >> 8;
+ return (png_byte)((x + 0x7fffffU) >> 24);
+}
+
+static png_uint_16
+png_exp16bit(png_fixed_point lg2)
+{
+ /* Get a 32-bit value: */
+ png_uint_32 x = png_exp(lg2);
+
+ /* Convert the 32-bit value to 0..65535 by multiplying by 65536-1: */
+ x -= x >> 16;
+ return (png_uint_16)((x + 32767U) >> 16);
+}
+#endif /* FLOATING_ARITHMETIC */
+
+png_byte
+png_gamma_8bit_correct(unsigned int value, png_fixed_point gamma_val)
+{
+ if (value > 0 && value < 255)
+ {
+# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = floor(255*pow(value/255.,gamma_val*.00001)+.5);
+ return (png_byte)r;
+# else
+ png_int_32 lg2 = png_log8bit(value);
+ png_fixed_point res;
+
+ if (png_muldiv(&res, gamma_val, lg2, PNG_FP_1))
+ return png_exp8bit(res);
+
+ /* Overflow. */
+ value = 0;
+# endif
+ }
+
+ return (png_byte)value;
+}
+
+png_uint_16
+png_gamma_16bit_correct(unsigned int value, png_fixed_point gamma_val)
+{
+ if (value > 0 && value < 65535)
+ {
+# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ double r = floor(65535*pow(value/65535.,gamma_val*.00001)+.5);
+ return (png_uint_16)r;
+# else
+ png_int_32 lg2 = png_log16bit(value);
+ png_fixed_point res;
+
+ if (png_muldiv(&res, gamma_val, lg2, PNG_FP_1))
+ return png_exp16bit(res);
+
+ /* Overflow. */
+ value = 0;
+# endif
+ }
+
+ return (png_uint_16)value;
+}
+
+/* This does the right thing based on the bit_depth field of the
+ * png_struct, interpreting values as 8-bit or 16-bit. While the result
+ * is nominally a 16-bit value if bit depth is 8 then the result is
+ * 8-bit (as are the arguments.)
+ */
+png_uint_16 /* PRIVATE */
+png_gamma_correct(png_structrp png_ptr, unsigned int value,
+ png_fixed_point gamma_val)
+{
+ if (png_ptr->bit_depth == 8)
+ return png_gamma_8bit_correct(value, gamma_val);
+
+ else
+ return png_gamma_16bit_correct(value, gamma_val);
+}
+
+/* Internal function to build a single 16-bit table - the table consists of
+ * 'num' 256 entry subtables, where 'num' is determined by 'shift' - the amount
+ * to shift the input values right (or 16-number_of_signifiant_bits).
+ *
+ * The caller is responsible for ensuring that the table gets cleaned up on
+ * png_error (i.e. if one of the mallocs below fails) - i.e. the *table argument
+ * should be somewhere that will be cleaned.
+ */
+static void
+png_build_16bit_table(png_structrp png_ptr, png_uint_16pp *ptable,
+ PNG_CONST unsigned int shift, PNG_CONST png_fixed_point gamma_val)
+{
+ /* Various values derived from 'shift': */
+ PNG_CONST unsigned int num = 1U << (8U - shift);
+ PNG_CONST unsigned int max = (1U << (16U - shift))-1U;
+ PNG_CONST unsigned int max_by_2 = 1U << (15U-shift);
+ unsigned int i;
+
+ png_uint_16pp table = *ptable =
+ (png_uint_16pp)png_calloc(png_ptr, num * (sizeof (png_uint_16p)));
+
+ for (i = 0; i < num; i++)
+ {
+ png_uint_16p sub_table = table[i] =
+ (png_uint_16p)png_malloc(png_ptr, 256 * (sizeof (png_uint_16)));
+
+ /* The 'threshold' test is repeated here because it can arise for one of
+ * the 16-bit tables even if the others don't hit it.
+ */
+ if (png_gamma_significant(gamma_val))
+ {
+ /* The old code would overflow at the end and this would cause the
+ * 'pow' function to return a result >1, resulting in an
+ * arithmetic error. This code follows the spec exactly; ig is
+ * the recovered input sample, it always has 8-16 bits.
+ *
+ * We want input * 65535/max, rounded, the arithmetic fits in 32
+ * bits (unsigned) so long as max <= 32767.
+ */
+ unsigned int j;
+ for (j = 0; j < 256; j++)
+ {
+ png_uint_32 ig = (j << (8-shift)) + i;
+# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED
+ /* Inline the 'max' scaling operation: */
+ double d = floor(65535*pow(ig/(double)max, gamma_val*.00001)+.5);
+ sub_table[j] = (png_uint_16)d;
+# else
+ if (shift)
+ ig = (ig * 65535U + max_by_2)/max;
+
+ sub_table[j] = png_gamma_16bit_correct(ig, gamma_val);
+# endif
+ }
+ }
+ else
+ {
+ /* We must still build a table, but do it the fast way. */
+ unsigned int j;
+
+ for (j = 0; j < 256; j++)
+ {
+ png_uint_32 ig = (j << (8-shift)) + i;
+
+ if (shift)
+ ig = (ig * 65535U + max_by_2)/max;
+
+ sub_table[j] = (png_uint_16)ig;
+ }
+ }
+ }
+}
+
+/* NOTE: this function expects the *inverse* of the overall gamma transformation
+ * required.
+ */
+static void
+png_build_16to8_table(png_structrp png_ptr, png_uint_16pp *ptable,
+ PNG_CONST unsigned int shift, PNG_CONST png_fixed_point gamma_val)
+{
+ PNG_CONST unsigned int num = 1U << (8U - shift);
+ PNG_CONST unsigned int max = (1U << (16U - shift))-1U;
+ unsigned int i;
+ png_uint_32 last;
+
+ png_uint_16pp table = *ptable =
+ (png_uint_16pp)png_calloc(png_ptr, num * (sizeof (png_uint_16p)));
+
+ /* 'num' is the number of tables and also the number of low bits of low
+ * bits of the input 16-bit value used to select a table. Each table is
+ * itself index by the high 8 bits of the value.
+ */
+ for (i = 0; i < num; i++)
+ table[i] = (png_uint_16p)png_malloc(png_ptr,
+ 256 * (sizeof (png_uint_16)));
+
+ /* 'gamma_val' is set to the reciprocal of the value calculated above, so
+ * pow(out,g) is an *input* value. 'last' is the last input value set.
+ *
+ * In the loop 'i' is used to find output values. Since the output is
+ * 8-bit there are only 256 possible values. The tables are set up to
+ * select the closest possible output value for each input by finding
+ * the input value at the boundary between each pair of output values
+ * and filling the table up to that boundary with the lower output
+ * value.
+ *
+ * The boundary values are 0.5,1.5..253.5,254.5. Since these are 9-bit
+ * values the code below uses a 16-bit value in i; the values start at
+ * 128.5 (for 0.5) and step by 257, for a total of 254 values (the last
+ * entries are filled with 255). Start i at 128 and fill all 'last'
+ * table entries <= 'max'
+ */
+ last = 0;
+ for (i = 0; i < 255; ++i) /* 8-bit output value */
+ {
+ /* Find the corresponding maximum input value */
+ png_uint_16 out = (png_uint_16)(i * 257U); /* 16-bit output value */
+
+ /* Find the boundary value in 16 bits: */
+ png_uint_32 bound = png_gamma_16bit_correct(out+128U, gamma_val);
+
+ /* Adjust (round) to (16-shift) bits: */
+ bound = (bound * max + 32768U)/65535U + 1U;
+
+ while (last < bound)
+ {
+ table[last & (0xffU >> shift)][last >> (8U - shift)] = out;
+ last++;
+ }
+ }
+
+ /* And fill in the final entries. */
+ while (last < (num << 8))
+ {
+ table[last & (0xff >> shift)][last >> (8U - shift)] = 65535U;
+ last++;
+ }
+}
+
+/* Build a single 8-bit table: same as the 16-bit case but much simpler (and
+ * typically much faster). Note that libpng currently does no sBIT processing
+ * (apparently contrary to the spec) so a 256 entry table is always generated.
+ */
+static void
+png_build_8bit_table(png_structrp png_ptr, png_bytepp ptable,
+ PNG_CONST png_fixed_point gamma_val)
+{
+ unsigned int i;
+ png_bytep table = *ptable = (png_bytep)png_malloc(png_ptr, 256);
+
+ if (png_gamma_significant(gamma_val)) for (i=0; i<256; i++)
+ table[i] = png_gamma_8bit_correct(i, gamma_val);
+
+ else for (i=0; i<256; ++i)
+ table[i] = (png_byte)i;
+}
+
+/* Used from png_read_destroy and below to release the memory used by the gamma
+ * tables.
+ */
+void /* PRIVATE */
+png_destroy_gamma_table(png_structrp png_ptr)
+{
+ png_free(png_ptr, png_ptr->gamma_table);
+ png_ptr->gamma_table = NULL;
+
+ if (png_ptr->gamma_16_table != NULL)
+ {
+ int i;
+ int istop = (1 << (8 - png_ptr->gamma_shift));
+ for (i = 0; i < istop; i++)
+ {
+ png_free(png_ptr, png_ptr->gamma_16_table[i]);
+ }
+ png_free(png_ptr, png_ptr->gamma_16_table);
+ png_ptr->gamma_16_table = NULL;
+ }
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)
+ png_free(png_ptr, png_ptr->gamma_from_1);
+ png_ptr->gamma_from_1 = NULL;
+ png_free(png_ptr, png_ptr->gamma_to_1);
+ png_ptr->gamma_to_1 = NULL;
+
+ if (png_ptr->gamma_16_from_1 != NULL)
+ {
+ int i;
+ int istop = (1 << (8 - png_ptr->gamma_shift));
+ for (i = 0; i < istop; i++)
+ {
+ png_free(png_ptr, png_ptr->gamma_16_from_1[i]);
+ }
+ png_free(png_ptr, png_ptr->gamma_16_from_1);
+ png_ptr->gamma_16_from_1 = NULL;
+ }
+ if (png_ptr->gamma_16_to_1 != NULL)
+ {
+ int i;
+ int istop = (1 << (8 - png_ptr->gamma_shift));
+ for (i = 0; i < istop; i++)
+ {
+ png_free(png_ptr, png_ptr->gamma_16_to_1[i]);
+ }
+ png_free(png_ptr, png_ptr->gamma_16_to_1);
+ png_ptr->gamma_16_to_1 = NULL;
+ }
+#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */
+}
+
+/* We build the 8- or 16-bit gamma tables here. Note that for 16-bit
+ * tables, we don't make a full table if we are reducing to 8-bit in
+ * the future. Note also how the gamma_16 tables are segmented so that
+ * we don't need to allocate > 64K chunks for a full 16-bit table.
+ */
+void /* PRIVATE */
+png_build_gamma_table(png_structrp png_ptr, int bit_depth)
+{
+ png_debug(1, "in png_build_gamma_table");
+
+ /* Remove any existing table; this copes with multiple calls to
+ * png_read_update_info. The warning is because building the gamma tables
+ * multiple times is a performance hit - it's harmless but the ability to call
+ * png_read_update_info() multiple times is new in 1.5.6 so it seems sensible
+ * to warn if the app introduces such a hit.
+ */
+ if (png_ptr->gamma_table != NULL || png_ptr->gamma_16_table != NULL)
+ {
+ png_warning(png_ptr, "gamma table being rebuilt");
+ png_destroy_gamma_table(png_ptr);
+ }
+
+ if (bit_depth <= 8)
+ {
+ png_build_8bit_table(png_ptr, &png_ptr->gamma_table,
+ png_ptr->screen_gamma > 0 ? png_reciprocal2(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma) : PNG_FP_1);
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)
+ if (png_ptr->transformations & (PNG_COMPOSE | PNG_RGB_TO_GRAY))
+ {
+ png_build_8bit_table(png_ptr, &png_ptr->gamma_to_1,
+ png_reciprocal(png_ptr->colorspace.gamma));
+
+ png_build_8bit_table(png_ptr, &png_ptr->gamma_from_1,
+ png_ptr->screen_gamma > 0 ? png_reciprocal(png_ptr->screen_gamma) :
+ png_ptr->colorspace.gamma/* Probably doing rgb_to_gray */);
+ }
+#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */
+ }
+ else
+ {
+ png_byte shift, sig_bit;
+
+ if (png_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ {
+ sig_bit = png_ptr->sig_bit.red;
+
+ if (png_ptr->sig_bit.green > sig_bit)
+ sig_bit = png_ptr->sig_bit.green;
+
+ if (png_ptr->sig_bit.blue > sig_bit)
+ sig_bit = png_ptr->sig_bit.blue;
+ }
+ else
+ sig_bit = png_ptr->sig_bit.gray;
+
+ /* 16-bit gamma code uses this equation:
+ *
+ * ov = table[(iv & 0xff) >> gamma_shift][iv >> 8]
+ *
+ * Where 'iv' is the input color value and 'ov' is the output value -
+ * pow(iv, gamma).
+ *
+ * Thus the gamma table consists of up to 256 256 entry tables. The table
+ * is selected by the (8-gamma_shift) most significant of the low 8 bits of
+ * the color value then indexed by the upper 8 bits:
+ *
+ * table[low bits][high 8 bits]
+ *
+ * So the table 'n' corresponds to all those 'iv' of:
+ *
+ * <all high 8-bit values><n << gamma_shift>..<(n+1 << gamma_shift)-1>
+ *
+ */
+ if (sig_bit > 0 && sig_bit < 16U)
+ shift = (png_byte)(16U - sig_bit); /* shift == insignificant bits */
+
+ else
+ shift = 0; /* keep all 16 bits */
+
+ if (png_ptr->transformations & (PNG_16_TO_8 | PNG_SCALE_16_TO_8))
+ {
+ /* PNG_MAX_GAMMA_8 is the number of bits to keep - effectively
+ * the significant bits in the *input* when the output will
+ * eventually be 8 bits. By default it is 11.
+ */
+ if (shift < (16U - PNG_MAX_GAMMA_8))
+ shift = (16U - PNG_MAX_GAMMA_8);
+ }
+
+ if (shift > 8U)
+ shift = 8U; /* Guarantees at least one table! */
+
+ png_ptr->gamma_shift = shift;
+
+#ifdef PNG_16BIT_SUPPORTED
+ /* NOTE: prior to 1.5.4 this test used to include PNG_BACKGROUND (now
+ * PNG_COMPOSE). This effectively smashed the background calculation for
+ * 16-bit output because the 8-bit table assumes the result will be reduced
+ * to 8 bits.
+ */
+ if (png_ptr->transformations & (PNG_16_TO_8 | PNG_SCALE_16_TO_8))
+#endif
+ png_build_16to8_table(png_ptr, &png_ptr->gamma_16_table, shift,
+ png_ptr->screen_gamma > 0 ? png_product2(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma) : PNG_FP_1);
+
+#ifdef PNG_16BIT_SUPPORTED
+ else
+ png_build_16bit_table(png_ptr, &png_ptr->gamma_16_table, shift,
+ png_ptr->screen_gamma > 0 ? png_reciprocal2(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma) : PNG_FP_1);
+#endif
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)
+ if (png_ptr->transformations & (PNG_COMPOSE | PNG_RGB_TO_GRAY))
+ {
+ png_build_16bit_table(png_ptr, &png_ptr->gamma_16_to_1, shift,
+ png_reciprocal(png_ptr->colorspace.gamma));
+
+ /* Notice that the '16 from 1' table should be full precision, however
+ * the lookup on this table still uses gamma_shift, so it can't be.
+ * TODO: fix this.
+ */
+ png_build_16bit_table(png_ptr, &png_ptr->gamma_16_from_1, shift,
+ png_ptr->screen_gamma > 0 ? png_reciprocal(png_ptr->screen_gamma) :
+ png_ptr->colorspace.gamma/* Probably doing rgb_to_gray */);
+ }
+#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */
+ }
+}
+#endif /* READ_GAMMA */
+
+/* HARDWARE OPTION SUPPORT */
+#ifdef PNG_SET_OPTION_SUPPORTED
+int PNGAPI
+png_set_option(png_structrp png_ptr, int option, int onoff)
+{
+ if (png_ptr != NULL && option >= 0 && option < PNG_OPTION_NEXT &&
+ (option & 1) == 0)
+ {
+ int mask = 3 << option;
+ int setting = (2 + (onoff != 0)) << option;
+ int current = png_ptr->options;
+
+ png_ptr->options = (png_byte)((current & ~mask) | setting);
+
+ return (current & mask) >> option;
+ }
+
+ return PNG_OPTION_INVALID;
+}
+#endif
+
+/* sRGB support */
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\
+ defined(PNG_SIMPLIFIED_WRITE_SUPPORTED)
+/* sRGB conversion tables; these are machine generated with the code in
+ * contrib/tools/makesRGB.c. The actual sRGB transfer curve defined in the
+ * specification (see the article at http://en.wikipedia.org/wiki/SRGB)
+ * is used, not the gamma=1/2.2 approximation use elsewhere in libpng.
+ * The sRGB to linear table is exact (to the nearest 16 bit linear fraction).
+ * The inverse (linear to sRGB) table has accuracies as follows:
+ *
+ * For all possible (255*65535+1) input values:
+ *
+ * error: -0.515566 - 0.625971, 79441 (0.475369%) of readings inexact
+ *
+ * For the input values corresponding to the 65536 16-bit values:
+ *
+ * error: -0.513727 - 0.607759, 308 (0.469978%) of readings inexact
+ *
+ * In all cases the inexact readings are off by one.
+ */
+
+#ifdef PNG_SIMPLIFIED_READ_SUPPORTED
+/* The convert-to-sRGB table is only currently required for read. */
+const png_uint_16 png_sRGB_table[256] =
+{
+ 0,20,40,60,80,99,119,139,
+ 159,179,199,219,241,264,288,313,
+ 340,367,396,427,458,491,526,562,
+ 599,637,677,718,761,805,851,898,
+ 947,997,1048,1101,1156,1212,1270,1330,
+ 1391,1453,1517,1583,1651,1720,1790,1863,
+ 1937,2013,2090,2170,2250,2333,2418,2504,
+ 2592,2681,2773,2866,2961,3058,3157,3258,
+ 3360,3464,3570,3678,3788,3900,4014,4129,
+ 4247,4366,4488,4611,4736,4864,4993,5124,
+ 5257,5392,5530,5669,5810,5953,6099,6246,
+ 6395,6547,6700,6856,7014,7174,7335,7500,
+ 7666,7834,8004,8177,8352,8528,8708,8889,
+ 9072,9258,9445,9635,9828,10022,10219,10417,
+ 10619,10822,11028,11235,11446,11658,11873,12090,
+ 12309,12530,12754,12980,13209,13440,13673,13909,
+ 14146,14387,14629,14874,15122,15371,15623,15878,
+ 16135,16394,16656,16920,17187,17456,17727,18001,
+ 18277,18556,18837,19121,19407,19696,19987,20281,
+ 20577,20876,21177,21481,21787,22096,22407,22721,
+ 23038,23357,23678,24002,24329,24658,24990,25325,
+ 25662,26001,26344,26688,27036,27386,27739,28094,
+ 28452,28813,29176,29542,29911,30282,30656,31033,
+ 31412,31794,32179,32567,32957,33350,33745,34143,
+ 34544,34948,35355,35764,36176,36591,37008,37429,
+ 37852,38278,38706,39138,39572,40009,40449,40891,
+ 41337,41785,42236,42690,43147,43606,44069,44534,
+ 45002,45473,45947,46423,46903,47385,47871,48359,
+ 48850,49344,49841,50341,50844,51349,51858,52369,
+ 52884,53401,53921,54445,54971,55500,56032,56567,
+ 57105,57646,58190,58737,59287,59840,60396,60955,
+ 61517,62082,62650,63221,63795,64372,64952,65535
+};
+
+#endif /* simplified read only */
+
+/* The base/delta tables are required for both read and write (but currently
+ * only the simplified versions.)
+ */
+const png_uint_16 png_sRGB_base[512] =
+{
+ 128,1782,3383,4644,5675,6564,7357,8074,
+ 8732,9346,9921,10463,10977,11466,11935,12384,
+ 12816,13233,13634,14024,14402,14769,15125,15473,
+ 15812,16142,16466,16781,17090,17393,17690,17981,
+ 18266,18546,18822,19093,19359,19621,19879,20133,
+ 20383,20630,20873,21113,21349,21583,21813,22041,
+ 22265,22487,22707,22923,23138,23350,23559,23767,
+ 23972,24175,24376,24575,24772,24967,25160,25352,
+ 25542,25730,25916,26101,26284,26465,26645,26823,
+ 27000,27176,27350,27523,27695,27865,28034,28201,
+ 28368,28533,28697,28860,29021,29182,29341,29500,
+ 29657,29813,29969,30123,30276,30429,30580,30730,
+ 30880,31028,31176,31323,31469,31614,31758,31902,
+ 32045,32186,32327,32468,32607,32746,32884,33021,
+ 33158,33294,33429,33564,33697,33831,33963,34095,
+ 34226,34357,34486,34616,34744,34873,35000,35127,
+ 35253,35379,35504,35629,35753,35876,35999,36122,
+ 36244,36365,36486,36606,36726,36845,36964,37083,
+ 37201,37318,37435,37551,37668,37783,37898,38013,
+ 38127,38241,38354,38467,38580,38692,38803,38915,
+ 39026,39136,39246,39356,39465,39574,39682,39790,
+ 39898,40005,40112,40219,40325,40431,40537,40642,
+ 40747,40851,40955,41059,41163,41266,41369,41471,
+ 41573,41675,41777,41878,41979,42079,42179,42279,
+ 42379,42478,42577,42676,42775,42873,42971,43068,
+ 43165,43262,43359,43456,43552,43648,43743,43839,
+ 43934,44028,44123,44217,44311,44405,44499,44592,
+ 44685,44778,44870,44962,45054,45146,45238,45329,
+ 45420,45511,45601,45692,45782,45872,45961,46051,
+ 46140,46229,46318,46406,46494,46583,46670,46758,
+ 46846,46933,47020,47107,47193,47280,47366,47452,
+ 47538,47623,47709,47794,47879,47964,48048,48133,
+ 48217,48301,48385,48468,48552,48635,48718,48801,
+ 48884,48966,49048,49131,49213,49294,49376,49458,
+ 49539,49620,49701,49782,49862,49943,50023,50103,
+ 50183,50263,50342,50422,50501,50580,50659,50738,
+ 50816,50895,50973,51051,51129,51207,51285,51362,
+ 51439,51517,51594,51671,51747,51824,51900,51977,
+ 52053,52129,52205,52280,52356,52432,52507,52582,
+ 52657,52732,52807,52881,52956,53030,53104,53178,
+ 53252,53326,53400,53473,53546,53620,53693,53766,
+ 53839,53911,53984,54056,54129,54201,54273,54345,
+ 54417,54489,54560,54632,54703,54774,54845,54916,
+ 54987,55058,55129,55199,55269,55340,55410,55480,
+ 55550,55620,55689,55759,55828,55898,55967,56036,
+ 56105,56174,56243,56311,56380,56448,56517,56585,
+ 56653,56721,56789,56857,56924,56992,57059,57127,
+ 57194,57261,57328,57395,57462,57529,57595,57662,
+ 57728,57795,57861,57927,57993,58059,58125,58191,
+ 58256,58322,58387,58453,58518,58583,58648,58713,
+ 58778,58843,58908,58972,59037,59101,59165,59230,
+ 59294,59358,59422,59486,59549,59613,59677,59740,
+ 59804,59867,59930,59993,60056,60119,60182,60245,
+ 60308,60370,60433,60495,60558,60620,60682,60744,
+ 60806,60868,60930,60992,61054,61115,61177,61238,
+ 61300,61361,61422,61483,61544,61605,61666,61727,
+ 61788,61848,61909,61969,62030,62090,62150,62211,
+ 62271,62331,62391,62450,62510,62570,62630,62689,
+ 62749,62808,62867,62927,62986,63045,63104,63163,
+ 63222,63281,63340,63398,63457,63515,63574,63632,
+ 63691,63749,63807,63865,63923,63981,64039,64097,
+ 64155,64212,64270,64328,64385,64443,64500,64557,
+ 64614,64672,64729,64786,64843,64900,64956,65013,
+ 65070,65126,65183,65239,65296,65352,65409,65465
+};
+
+const png_byte png_sRGB_delta[512] =
+{
+ 207,201,158,129,113,100,90,82,77,72,68,64,61,59,56,54,
+ 52,50,49,47,46,45,43,42,41,40,39,39,38,37,36,36,
+ 35,34,34,33,33,32,32,31,31,30,30,30,29,29,28,28,
+ 28,27,27,27,27,26,26,26,25,25,25,25,24,24,24,24,
+ 23,23,23,23,23,22,22,22,22,22,22,21,21,21,21,21,
+ 21,20,20,20,20,20,20,20,20,19,19,19,19,19,19,19,
+ 19,18,18,18,18,18,18,18,18,18,18,17,17,17,17,17,
+ 17,17,17,17,17,17,16,16,16,16,16,16,16,16,16,16,
+ 16,16,16,16,15,15,15,15,15,15,15,15,15,15,15,15,
+ 15,15,15,15,14,14,14,14,14,14,14,14,14,14,14,14,
+ 14,14,14,14,14,14,14,13,13,13,13,13,13,13,13,13,
+ 13,13,13,13,13,13,13,13,13,13,13,13,13,13,12,12,
+ 12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,
+ 12,12,12,12,12,12,12,12,12,12,12,12,11,11,11,11,
+ 11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,
+ 11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,
+ 11,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,
+ 10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,
+ 10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,
+ 10,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
+ 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
+ 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
+ 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
+ 9,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
+ 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
+ 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
+ 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
+ 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
+ 8,8,8,8,8,8,8,8,8,7,7,7,7,7,7,7,
+ 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
+ 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
+ 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7
+};
+#endif /* SIMPLIFIED READ/WRITE sRGB support */
+
+/* SIMPLIFIED READ/WRITE SUPPORT */
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\
+ defined(PNG_SIMPLIFIED_WRITE_SUPPORTED)
+static int
+png_image_free_function(png_voidp argument)
+{
+ png_imagep image = png_voidcast(png_imagep, argument);
+ png_controlp cp = image->opaque;
+ png_control c;
+
+ /* Double check that we have a png_ptr - it should be impossible to get here
+ * without one.
+ */
+ if (cp->png_ptr == NULL)
+ return 0;
+
+ /* First free any data held in the control structure. */
+# ifdef PNG_STDIO_SUPPORTED
+ if (cp->owned_file)
+ {
+ FILE *fp = png_voidcast(FILE*, cp->png_ptr->io_ptr);
+ cp->owned_file = 0;
+
+ /* Ignore errors here. */
+ if (fp != NULL)
+ {
+ cp->png_ptr->io_ptr = NULL;
+ (void)fclose(fp);
+ }
+ }
+# endif
+
+ /* Copy the control structure so that the original, allocated, version can be
+ * safely freed. Notice that a png_error here stops the remainder of the
+ * cleanup, but this is probably fine because that would indicate bad memory
+ * problems anyway.
+ */
+ c = *cp;
+ image->opaque = &c;
+ png_free(c.png_ptr, cp);
+
+ /* Then the structures, calling the correct API. */
+ if (c.for_write)
+ {
+# ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED
+ png_destroy_write_struct(&c.png_ptr, &c.info_ptr);
+# else
+ png_error(c.png_ptr, "simplified write not supported");
+# endif
+ }
+ else
+ {
+# ifdef PNG_SIMPLIFIED_READ_SUPPORTED
+ png_destroy_read_struct(&c.png_ptr, &c.info_ptr, NULL);
+# else
+ png_error(c.png_ptr, "simplified read not supported");
+# endif
+ }
+
+ /* Success. */
+ return 1;
+}
+
+void PNGAPI
+png_image_free(png_imagep image)
+{
+ /* Safely call the real function, but only if doing so is safe at this point
+ * (if not inside an error handling context). Otherwise assume
+ * png_safe_execute will call this API after the return.
+ */
+ if (image != NULL && image->opaque != NULL &&
+ image->opaque->error_buf == NULL)
+ {
+ /* Ignore errors here: */
+ (void)png_safe_execute(image, png_image_free_function, image);
+ image->opaque = NULL;
+ }
+}
+
+int /* PRIVATE */
+png_image_error(png_imagep image, png_const_charp error_message)
+{
+ /* Utility to log an error. */
+ png_safecat(image->message, (sizeof image->message), 0, error_message);
+ image->warning_or_error |= PNG_IMAGE_ERROR;
+ png_image_free(image);
+ return 0;
+}
+
+#endif /* SIMPLIFIED READ/WRITE */
+#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */
diff --git a/ml/dlib/dlib/external/libpng/png.h b/ml/dlib/dlib/external/libpng/png.h
new file mode 100644
index 000000000..527392738
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/png.h
@@ -0,0 +1,3319 @@
+
+/* png.h - header file for PNG reference library
+ *
+ * libpng version 1.6.7 - November 14, 2013
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license (See LICENSE, below)
+ *
+ * Authors and maintainers:
+ * libpng versions 0.71, May 1995, through 0.88, January 1996: Guy Schalnat
+ * libpng versions 0.89c, June 1996, through 0.96, May 1997: Andreas Dilger
+ * libpng versions 0.97, January 1998, through 1.6.7 - November 14, 2013: Glenn
+ * See also "Contributing Authors", below.
+ *
+ * Note about libpng version numbers:
+ *
+ * Due to various miscommunications, unforeseen code incompatibilities
+ * and occasional factors outside the authors' control, version numbering
+ * on the library has not always been consistent and straightforward.
+ * The following table summarizes matters since version 0.89c, which was
+ * the first widely used release:
+ *
+ * source png.h png.h shared-lib
+ * version string int version
+ * ------- ------ ----- ----------
+ * 0.89c "1.0 beta 3" 0.89 89 1.0.89
+ * 0.90 "1.0 beta 4" 0.90 90 0.90 [should have been 2.0.90]
+ * 0.95 "1.0 beta 5" 0.95 95 0.95 [should have been 2.0.95]
+ * 0.96 "1.0 beta 6" 0.96 96 0.96 [should have been 2.0.96]
+ * 0.97b "1.00.97 beta 7" 1.00.97 97 1.0.1 [should have been 2.0.97]
+ * 0.97c 0.97 97 2.0.97
+ * 0.98 0.98 98 2.0.98
+ * 0.99 0.99 98 2.0.99
+ * 0.99a-m 0.99 99 2.0.99
+ * 1.00 1.00 100 2.1.0 [100 should be 10000]
+ * 1.0.0 (from here on, the 100 2.1.0 [100 should be 10000]
+ * 1.0.1 png.h string is 10001 2.1.0
+ * 1.0.1a-e identical to the 10002 from here on, the shared library
+ * 1.0.2 source version) 10002 is 2.V where V is the source code
+ * 1.0.2a-b 10003 version, except as noted.
+ * 1.0.3 10003
+ * 1.0.3a-d 10004
+ * 1.0.4 10004
+ * 1.0.4a-f 10005
+ * 1.0.5 (+ 2 patches) 10005
+ * 1.0.5a-d 10006
+ * 1.0.5e-r 10100 (not source compatible)
+ * 1.0.5s-v 10006 (not binary compatible)
+ * 1.0.6 (+ 3 patches) 10006 (still binary incompatible)
+ * 1.0.6d-f 10007 (still binary incompatible)
+ * 1.0.6g 10007
+ * 1.0.6h 10007 10.6h (testing xy.z so-numbering)
+ * 1.0.6i 10007 10.6i
+ * 1.0.6j 10007 2.1.0.6j (incompatible with 1.0.0)
+ * 1.0.7beta11-14 DLLNUM 10007 2.1.0.7beta11-14 (binary compatible)
+ * 1.0.7beta15-18 1 10007 2.1.0.7beta15-18 (binary compatible)
+ * 1.0.7rc1-2 1 10007 2.1.0.7rc1-2 (binary compatible)
+ * 1.0.7 1 10007 (still compatible)
+ * 1.0.8beta1-4 1 10008 2.1.0.8beta1-4
+ * 1.0.8rc1 1 10008 2.1.0.8rc1
+ * 1.0.8 1 10008 2.1.0.8
+ * 1.0.9beta1-6 1 10009 2.1.0.9beta1-6
+ * 1.0.9rc1 1 10009 2.1.0.9rc1
+ * 1.0.9beta7-10 1 10009 2.1.0.9beta7-10
+ * 1.0.9rc2 1 10009 2.1.0.9rc2
+ * 1.0.9 1 10009 2.1.0.9
+ * 1.0.10beta1 1 10010 2.1.0.10beta1
+ * 1.0.10rc1 1 10010 2.1.0.10rc1
+ * 1.0.10 1 10010 2.1.0.10
+ * 1.0.11beta1-3 1 10011 2.1.0.11beta1-3
+ * 1.0.11rc1 1 10011 2.1.0.11rc1
+ * 1.0.11 1 10011 2.1.0.11
+ * 1.0.12beta1-2 2 10012 2.1.0.12beta1-2
+ * 1.0.12rc1 2 10012 2.1.0.12rc1
+ * 1.0.12 2 10012 2.1.0.12
+ * 1.1.0a-f - 10100 2.1.1.0a-f (branch abandoned)
+ * 1.2.0beta1-2 2 10200 2.1.2.0beta1-2
+ * 1.2.0beta3-5 3 10200 3.1.2.0beta3-5
+ * 1.2.0rc1 3 10200 3.1.2.0rc1
+ * 1.2.0 3 10200 3.1.2.0
+ * 1.2.1beta1-4 3 10201 3.1.2.1beta1-4
+ * 1.2.1rc1-2 3 10201 3.1.2.1rc1-2
+ * 1.2.1 3 10201 3.1.2.1
+ * 1.2.2beta1-6 12 10202 12.so.0.1.2.2beta1-6
+ * 1.0.13beta1 10 10013 10.so.0.1.0.13beta1
+ * 1.0.13rc1 10 10013 10.so.0.1.0.13rc1
+ * 1.2.2rc1 12 10202 12.so.0.1.2.2rc1
+ * 1.0.13 10 10013 10.so.0.1.0.13
+ * 1.2.2 12 10202 12.so.0.1.2.2
+ * 1.2.3rc1-6 12 10203 12.so.0.1.2.3rc1-6
+ * 1.2.3 12 10203 12.so.0.1.2.3
+ * 1.2.4beta1-3 13 10204 12.so.0.1.2.4beta1-3
+ * 1.0.14rc1 13 10014 10.so.0.1.0.14rc1
+ * 1.2.4rc1 13 10204 12.so.0.1.2.4rc1
+ * 1.0.14 10 10014 10.so.0.1.0.14
+ * 1.2.4 13 10204 12.so.0.1.2.4
+ * 1.2.5beta1-2 13 10205 12.so.0.1.2.5beta1-2
+ * 1.0.15rc1-3 10 10015 10.so.0.1.0.15rc1-3
+ * 1.2.5rc1-3 13 10205 12.so.0.1.2.5rc1-3
+ * 1.0.15 10 10015 10.so.0.1.0.15
+ * 1.2.5 13 10205 12.so.0.1.2.5
+ * 1.2.6beta1-4 13 10206 12.so.0.1.2.6beta1-4
+ * 1.0.16 10 10016 10.so.0.1.0.16
+ * 1.2.6 13 10206 12.so.0.1.2.6
+ * 1.2.7beta1-2 13 10207 12.so.0.1.2.7beta1-2
+ * 1.0.17rc1 10 10017 12.so.0.1.0.17rc1
+ * 1.2.7rc1 13 10207 12.so.0.1.2.7rc1
+ * 1.0.17 10 10017 12.so.0.1.0.17
+ * 1.2.7 13 10207 12.so.0.1.2.7
+ * 1.2.8beta1-5 13 10208 12.so.0.1.2.8beta1-5
+ * 1.0.18rc1-5 10 10018 12.so.0.1.0.18rc1-5
+ * 1.2.8rc1-5 13 10208 12.so.0.1.2.8rc1-5
+ * 1.0.18 10 10018 12.so.0.1.0.18
+ * 1.2.8 13 10208 12.so.0.1.2.8
+ * 1.2.9beta1-3 13 10209 12.so.0.1.2.9beta1-3
+ * 1.2.9beta4-11 13 10209 12.so.0.9[.0]
+ * 1.2.9rc1 13 10209 12.so.0.9[.0]
+ * 1.2.9 13 10209 12.so.0.9[.0]
+ * 1.2.10beta1-7 13 10210 12.so.0.10[.0]
+ * 1.2.10rc1-2 13 10210 12.so.0.10[.0]
+ * 1.2.10 13 10210 12.so.0.10[.0]
+ * 1.4.0beta1-5 14 10400 14.so.0.0[.0]
+ * 1.2.11beta1-4 13 10211 12.so.0.11[.0]
+ * 1.4.0beta7-8 14 10400 14.so.0.0[.0]
+ * 1.2.11 13 10211 12.so.0.11[.0]
+ * 1.2.12 13 10212 12.so.0.12[.0]
+ * 1.4.0beta9-14 14 10400 14.so.0.0[.0]
+ * 1.2.13 13 10213 12.so.0.13[.0]
+ * 1.4.0beta15-36 14 10400 14.so.0.0[.0]
+ * 1.4.0beta37-87 14 10400 14.so.14.0[.0]
+ * 1.4.0rc01 14 10400 14.so.14.0[.0]
+ * 1.4.0beta88-109 14 10400 14.so.14.0[.0]
+ * 1.4.0rc02-08 14 10400 14.so.14.0[.0]
+ * 1.4.0 14 10400 14.so.14.0[.0]
+ * 1.4.1beta01-03 14 10401 14.so.14.1[.0]
+ * 1.4.1rc01 14 10401 14.so.14.1[.0]
+ * 1.4.1beta04-12 14 10401 14.so.14.1[.0]
+ * 1.4.1 14 10401 14.so.14.1[.0]
+ * 1.4.2 14 10402 14.so.14.2[.0]
+ * 1.4.3 14 10403 14.so.14.3[.0]
+ * 1.4.4 14 10404 14.so.14.4[.0]
+ * 1.5.0beta01-58 15 10500 15.so.15.0[.0]
+ * 1.5.0rc01-07 15 10500 15.so.15.0[.0]
+ * 1.5.0 15 10500 15.so.15.0[.0]
+ * 1.5.1beta01-11 15 10501 15.so.15.1[.0]
+ * 1.5.1rc01-02 15 10501 15.so.15.1[.0]
+ * 1.5.1 15 10501 15.so.15.1[.0]
+ * 1.5.2beta01-03 15 10502 15.so.15.2[.0]
+ * 1.5.2rc01-03 15 10502 15.so.15.2[.0]
+ * 1.5.2 15 10502 15.so.15.2[.0]
+ * 1.5.3beta01-10 15 10503 15.so.15.3[.0]
+ * 1.5.3rc01-02 15 10503 15.so.15.3[.0]
+ * 1.5.3beta11 15 10503 15.so.15.3[.0]
+ * 1.5.3 [omitted]
+ * 1.5.4beta01-08 15 10504 15.so.15.4[.0]
+ * 1.5.4rc01 15 10504 15.so.15.4[.0]
+ * 1.5.4 15 10504 15.so.15.4[.0]
+ * 1.5.5beta01-08 15 10505 15.so.15.5[.0]
+ * 1.5.5rc01 15 10505 15.so.15.5[.0]
+ * 1.5.5 15 10505 15.so.15.5[.0]
+ * 1.5.6beta01-07 15 10506 15.so.15.6[.0]
+ * 1.5.6rc01-03 15 10506 15.so.15.6[.0]
+ * 1.5.6 15 10506 15.so.15.6[.0]
+ * 1.5.7beta01-05 15 10507 15.so.15.7[.0]
+ * 1.5.7rc01-03 15 10507 15.so.15.7[.0]
+ * 1.5.7 15 10507 15.so.15.7[.0]
+ * 1.6.0beta01-40 16 10600 16.so.16.0[.0]
+ * 1.6.0rc01-08 16 10600 16.so.16.0[.0]
+ * 1.6.0 16 10600 16.so.16.0[.0]
+ * 1.6.1beta01-09 16 10601 16.so.16.1[.0]
+ * 1.6.1rc01 16 10601 16.so.16.1[.0]
+ * 1.6.1 16 10601 16.so.16.1[.0]
+ * 1.6.2beta01 16 10602 16.so.16.2[.0]
+ * 1.6.2rc01-06 16 10602 16.so.16.2[.0]
+ * 1.6.2 16 10602 16.so.16.2[.0]
+ * 1.6.3beta01-11 16 10603 16.so.16.3[.0]
+ * 1.6.3rc01 16 10603 16.so.16.3[.0]
+ * 1.6.3 16 10603 16.so.16.3[.0]
+ * 1.6.4beta01-02 16 10604 16.so.16.4[.0]
+ * 1.6.4rc01 16 10604 16.so.16.4[.0]
+ * 1.6.4 16 10604 16.so.16.4[.0]
+ * 1.6.5 16 10605 16.so.16.5[.0]
+ * 1.6.6 16 10606 16.so.16.6[.0]
+ * 1.6.7beta01-04 16 10607 16.so.16.7[.0]
+ * 1.6.7rc01-02 16 10607 16.so.16.7[.0]
+ * 1.6.7 16 10607 16.so.16.7[.0]
+ *
+ * Henceforth the source version will match the shared-library major
+ * and minor numbers; the shared-library major version number will be
+ * used for changes in backward compatibility, as it is intended. The
+ * PNG_LIBPNG_VER macro, which is not used within libpng but is available
+ * for applications, is an unsigned integer of the form xyyzz corresponding
+ * to the source version x.y.z (leading zeros in y and z). Beta versions
+ * were given the previous public release number plus a letter, until
+ * version 1.0.6j; from then on they were given the upcoming public
+ * release number plus "betaNN" or "rcNN".
+ *
+ * Binary incompatibility exists only when applications make direct access
+ * to the info_ptr or png_ptr members through png.h, and the compiled
+ * application is loaded with a different version of the library.
+ *
+ * DLLNUM will change each time there are forward or backward changes
+ * in binary compatibility (e.g., when a new feature is added).
+ *
+ * See libpng-manual.txt or libpng.3 for more information. The PNG
+ * specification is available as a W3C Recommendation and as an ISO
+ * Specification, <http://www.w3.org/TR/2003/REC-PNG-20031110/
+ */
+
+/*
+ * COPYRIGHT NOTICE, DISCLAIMER, and LICENSE:
+ *
+ * If you modify libpng you may insert additional notices immediately following
+ * this sentence.
+ *
+ * This code is released under the libpng license.
+ *
+ * libpng versions 1.2.6, August 15, 2004, through 1.6.7, November 14, 2013, are
+ * Copyright (c) 2004, 2006-2013 Glenn Randers-Pehrson, and are
+ * distributed according to the same disclaimer and license as libpng-1.2.5
+ * with the following individual added to the list of Contributing Authors:
+ *
+ * Cosmin Truta
+ *
+ * libpng versions 1.0.7, July 1, 2000, through 1.2.5, October 3, 2002, are
+ * Copyright (c) 2000-2002 Glenn Randers-Pehrson, and are
+ * distributed according to the same disclaimer and license as libpng-1.0.6
+ * with the following individuals added to the list of Contributing Authors:
+ *
+ * Simon-Pierre Cadieux
+ * Eric S. Raymond
+ * Gilles Vollant
+ *
+ * and with the following additions to the disclaimer:
+ *
+ * There is no warranty against interference with your enjoyment of the
+ * library or against infringement. There is no warranty that our
+ * efforts or the library will fulfill any of your particular purposes
+ * or needs. This library is provided with all faults, and the entire
+ * risk of satisfactory quality, performance, accuracy, and effort is with
+ * the user.
+ *
+ * libpng versions 0.97, January 1998, through 1.0.6, March 20, 2000, are
+ * Copyright (c) 1998, 1999, 2000 Glenn Randers-Pehrson, and are
+ * distributed according to the same disclaimer and license as libpng-0.96,
+ * with the following individuals added to the list of Contributing Authors:
+ *
+ * Tom Lane
+ * Glenn Randers-Pehrson
+ * Willem van Schaik
+ *
+ * libpng versions 0.89, June 1996, through 0.96, May 1997, are
+ * Copyright (c) 1996, 1997 Andreas Dilger
+ * Distributed according to the same disclaimer and license as libpng-0.88,
+ * with the following individuals added to the list of Contributing Authors:
+ *
+ * John Bowler
+ * Kevin Bracey
+ * Sam Bushell
+ * Magnus Holmgren
+ * Greg Roelofs
+ * Tom Tanner
+ *
+ * libpng versions 0.5, May 1995, through 0.88, January 1996, are
+ * Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.
+ *
+ * For the purposes of this copyright and license, "Contributing Authors"
+ * is defined as the following set of individuals:
+ *
+ * Andreas Dilger
+ * Dave Martindale
+ * Guy Eric Schalnat
+ * Paul Schmidt
+ * Tim Wegner
+ *
+ * The PNG Reference Library is supplied "AS IS". The Contributing Authors
+ * and Group 42, Inc. disclaim all warranties, expressed or implied,
+ * including, without limitation, the warranties of merchantability and of
+ * fitness for any purpose. The Contributing Authors and Group 42, Inc.
+ * assume no liability for direct, indirect, incidental, special, exemplary,
+ * or consequential damages, which may result from the use of the PNG
+ * Reference Library, even if advised of the possibility of such damage.
+ *
+ * Permission is hereby granted to use, copy, modify, and distribute this
+ * source code, or portions hereof, for any purpose, without fee, subject
+ * to the following restrictions:
+ *
+ * 1. The origin of this source code must not be misrepresented.
+ *
+ * 2. Altered versions must be plainly marked as such and must not
+ * be misrepresented as being the original source.
+ *
+ * 3. This Copyright notice may not be removed or altered from
+ * any source or altered source distribution.
+ *
+ * The Contributing Authors and Group 42, Inc. specifically permit, without
+ * fee, and encourage the use of this source code as a component to
+ * supporting the PNG file format in commercial products. If you use this
+ * source code in a product, acknowledgment is not required but would be
+ * appreciated.
+ */
+
+/*
+ * A "png_get_copyright" function is available, for convenient use in "about"
+ * boxes and the like:
+ *
+ * printf("%s", png_get_copyright(NULL));
+ *
+ * Also, the PNG logo (in PNG format, of course) is supplied in the
+ * files "pngbar.png" and "pngbar.jpg (88x31) and "pngnow.png" (98x31).
+ */
+
+/*
+ * Libpng is OSI Certified Open Source Software. OSI Certified is a
+ * certification mark of the Open Source Initiative.
+ */
+
+/*
+ * The contributing authors would like to thank all those who helped
+ * with testing, bug fixes, and patience. This wouldn't have been
+ * possible without all of you.
+ *
+ * Thanks to Frank J. T. Wojcik for helping with the documentation.
+ */
+
+/*
+ * Y2K compliance in libpng:
+ * =========================
+ *
+ * November 14, 2013
+ *
+ * Since the PNG Development group is an ad-hoc body, we can't make
+ * an official declaration.
+ *
+ * This is your unofficial assurance that libpng from version 0.71 and
+ * upward through 1.6.7 are Y2K compliant. It is my belief that
+ * earlier versions were also Y2K compliant.
+ *
+ * Libpng only has two year fields. One is a 2-byte unsigned integer
+ * that will hold years up to 65535. The other, which is deprecated,
+ * holds the date in text format, and will hold years up to 9999.
+ *
+ * The integer is
+ * "png_uint_16 year" in png_time_struct.
+ *
+ * The string is
+ * "char time_buffer[29]" in png_struct. This is no longer used
+ * in libpng-1.6.x and will be removed from libpng-1.7.0.
+ *
+ * There are seven time-related functions:
+ * png.c: png_convert_to_rfc_1123_buffer() in png.c
+ * (formerly png_convert_to_rfc_1123() prior to libpng-1.5.x and
+ * png_convert_to_rfc_1152() in error prior to libpng-0.98)
+ * png_convert_from_struct_tm() in pngwrite.c, called in pngwrite.c
+ * png_convert_from_time_t() in pngwrite.c
+ * png_get_tIME() in pngget.c
+ * png_handle_tIME() in pngrutil.c, called in pngread.c
+ * png_set_tIME() in pngset.c
+ * png_write_tIME() in pngwutil.c, called in pngwrite.c
+ *
+ * All handle dates properly in a Y2K environment. The
+ * png_convert_from_time_t() function calls gmtime() to convert from system
+ * clock time, which returns (year - 1900), which we properly convert to
+ * the full 4-digit year. There is a possibility that libpng applications
+ * are not passing 4-digit years into the png_convert_to_rfc_1123_buffer()
+ * function, or that they are incorrectly passing only a 2-digit year
+ * instead of "year - 1900" into the png_convert_from_struct_tm() function,
+ * but this is not under our control. The libpng documentation has always
+ * stated that it works with 4-digit years, and the APIs have been
+ * documented as such.
+ *
+ * The tIME chunk itself is also Y2K compliant. It uses a 2-byte unsigned
+ * integer to hold the year, and can hold years as large as 65535.
+ *
+ * zlib, upon which libpng depends, is also Y2K compliant. It contains
+ * no date-related code.
+ *
+ * Glenn Randers-Pehrson
+ * libpng maintainer
+ * PNG Development Group
+ */
+
+#ifndef PNG_H
+#define PNG_H
+
+/* This is not the place to learn how to use libpng. The file libpng-manual.txt
+ * describes how to use libpng, and the file example.c summarizes it
+ * with some code on which to build. This file is useful for looking
+ * at the actual function definitions and structure components.
+ *
+ * If you just need to read a PNG file and don't want to read the documentation
+ * skip to the end of this file and read the section entitled 'simplified API'.
+ */
+
+/* Version information for png.h - this should match the version in png.c */
+#define PNG_LIBPNG_VER_STRING "1.6.7"
+#define PNG_HEADER_VERSION_STRING \
+ " libpng version 1.6.7 - November 14, 2013\n"
+
+#define PNG_LIBPNG_VER_SONUM 16
+#define PNG_LIBPNG_VER_DLLNUM 16
+
+/* These should match the first 3 components of PNG_LIBPNG_VER_STRING: */
+#define PNG_LIBPNG_VER_MAJOR 1
+#define PNG_LIBPNG_VER_MINOR 6
+#define PNG_LIBPNG_VER_RELEASE 7
+
+/* This should match the numeric part of the final component of
+ * PNG_LIBPNG_VER_STRING, omitting any leading zero:
+ */
+
+#define PNG_LIBPNG_VER_BUILD 0
+
+/* Release Status */
+#define PNG_LIBPNG_BUILD_ALPHA 1
+#define PNG_LIBPNG_BUILD_BETA 2
+#define PNG_LIBPNG_BUILD_RC 3
+#define PNG_LIBPNG_BUILD_STABLE 4
+#define PNG_LIBPNG_BUILD_RELEASE_STATUS_MASK 7
+
+/* Release-Specific Flags */
+#define PNG_LIBPNG_BUILD_PATCH 8 /* Can be OR'ed with
+ PNG_LIBPNG_BUILD_STABLE only */
+#define PNG_LIBPNG_BUILD_PRIVATE 16 /* Cannot be OR'ed with
+ PNG_LIBPNG_BUILD_SPECIAL */
+#define PNG_LIBPNG_BUILD_SPECIAL 32 /* Cannot be OR'ed with
+ PNG_LIBPNG_BUILD_PRIVATE */
+
+#define PNG_LIBPNG_BUILD_BASE_TYPE PNG_LIBPNG_BUILD_STABLE
+
+/* Careful here. At one time, Guy wanted to use 082, but that would be octal.
+ * We must not include leading zeros.
+ * Versions 0.7 through 1.0.0 were in the range 0 to 100 here (only
+ * version 1.0.0 was mis-numbered 100 instead of 10000). From
+ * version 1.0.1 it's xxyyzz, where x=major, y=minor, z=release
+ */
+#define PNG_LIBPNG_VER 10607 /* 1.6.7 */
+
+/* Library configuration: these options cannot be changed after
+ * the library has been built.
+ */
+#ifndef PNGLCONF_H
+ /* If pnglibconf.h is missing, you can
+ * copy scripts/pnglibconf.h.prebuilt to pnglibconf.h
+ */
+# include "pnglibconf.h"
+#endif
+
+#ifndef PNG_VERSION_INFO_ONLY
+ /* Machine specific configuration. */
+# include "pngconf.h"
+#endif
+
+/*
+ * Added at libpng-1.2.8
+ *
+ * Ref MSDN: Private as priority over Special
+ * VS_FF_PRIVATEBUILD File *was not* built using standard release
+ * procedures. If this value is given, the StringFileInfo block must
+ * contain a PrivateBuild string.
+ *
+ * VS_FF_SPECIALBUILD File *was* built by the original company using
+ * standard release procedures but is a variation of the standard
+ * file of the same version number. If this value is given, the
+ * StringFileInfo block must contain a SpecialBuild string.
+ */
+
+#ifdef PNG_USER_PRIVATEBUILD /* From pnglibconf.h */
+# define PNG_LIBPNG_BUILD_TYPE \
+ (PNG_LIBPNG_BUILD_BASE_TYPE | PNG_LIBPNG_BUILD_PRIVATE)
+#else
+# ifdef PNG_LIBPNG_SPECIALBUILD
+# define PNG_LIBPNG_BUILD_TYPE \
+ (PNG_LIBPNG_BUILD_BASE_TYPE | PNG_LIBPNG_BUILD_SPECIAL)
+# else
+# define PNG_LIBPNG_BUILD_TYPE (PNG_LIBPNG_BUILD_BASE_TYPE)
+# endif
+#endif
+
+#ifndef PNG_VERSION_INFO_ONLY
+
+/* Inhibit C++ name-mangling for libpng functions but not for system calls. */
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* Version information for C files, stored in png.c. This had better match
+ * the version above.
+ */
+#define png_libpng_ver png_get_header_ver(NULL)
+
+/* This file is arranged in several sections:
+ *
+ * 1. Any configuration options that can be specified by for the application
+ * code when it is built. (Build time configuration is in pnglibconf.h)
+ * 2. Type definitions (base types are defined in pngconf.h), structure
+ * definitions.
+ * 3. Exported library functions.
+ * 4. Simplified API.
+ *
+ * The library source code has additional files (principally pngpriv.h) that
+ * allow configuration of the library.
+ */
+/* Section 1: run time configuration
+ * See pnglibconf.h for build time configuration
+ *
+ * Run time configuration allows the application to choose between
+ * implementations of certain arithmetic APIs. The default is set
+ * at build time and recorded in pnglibconf.h, but it is safe to
+ * override these (and only these) settings. Note that this won't
+ * change what the library does, only application code, and the
+ * settings can (and probably should) be made on a per-file basis
+ * by setting the #defines before including png.h
+ *
+ * Use macros to read integers from PNG data or use the exported
+ * functions?
+ * PNG_USE_READ_MACROS: use the macros (see below) Note that
+ * the macros evaluate their argument multiple times.
+ * PNG_NO_USE_READ_MACROS: call the relevant library function.
+ *
+ * Use the alternative algorithm for compositing alpha samples that
+ * does not use division?
+ * PNG_READ_COMPOSITE_NODIV_SUPPORTED: use the 'no division'
+ * algorithm.
+ * PNG_NO_READ_COMPOSITE_NODIV: use the 'division' algorithm.
+ *
+ * How to handle benign errors if PNG_ALLOW_BENIGN_ERRORS is
+ * false?
+ * PNG_ALLOW_BENIGN_ERRORS: map calls to the benign error
+ * APIs to png_warning.
+ * Otherwise the calls are mapped to png_error.
+ */
+
+/* Section 2: type definitions, including structures and compile time
+ * constants.
+ * See pngconf.h for base types that vary by machine/system
+ */
+
+/* This triggers a compiler error in png.c, if png.c and png.h
+ * do not agree upon the version number.
+ */
+typedef char* png_libpng_version_1_6_7;
+
+/* Basic control structions. Read libpng-manual.txt or libpng.3 for more info.
+ *
+ * png_struct is the cache of information used while reading or writing a single
+ * PNG file. One of these is always required, although the simplified API
+ * (below) hides the creation and destruction of it.
+ */
+typedef struct png_struct_def png_struct;
+typedef const png_struct * png_const_structp;
+typedef png_struct * png_structp;
+typedef png_struct * * png_structpp;
+
+/* png_info contains information read from or to be written to a PNG file. One
+ * or more of these must exist while reading or creating a PNG file. The
+ * information is not used by libpng during read but is used to control what
+ * gets written when a PNG file is created. "png_get_" function calls read
+ * information during read and "png_set_" functions calls write information
+ * when creating a PNG.
+ * been moved into a separate header file that is not accessible to
+ * applications. Read libpng-manual.txt or libpng.3 for more info.
+ */
+typedef struct png_info_def png_info;
+typedef png_info * png_infop;
+typedef const png_info * png_const_infop;
+typedef png_info * * png_infopp;
+
+/* Types with names ending 'p' are pointer types. The corresponding types with
+ * names ending 'rp' are identical pointer types except that the pointer is
+ * marked 'restrict', which means that it is the only pointer to the object
+ * passed to the function. Applications should not use the 'restrict' types;
+ * it is always valid to pass 'p' to a pointer with a function argument of the
+ * corresponding 'rp' type. Different compilers have different rules with
+ * regard to type matching in the presence of 'restrict'. For backward
+ * compatibility libpng callbacks never have 'restrict' in their parameters and,
+ * consequentially, writing portable application code is extremely difficult if
+ * an attempt is made to use 'restrict'.
+ */
+typedef png_struct * PNG_RESTRICT png_structrp;
+typedef const png_struct * PNG_RESTRICT png_const_structrp;
+typedef png_info * PNG_RESTRICT png_inforp;
+typedef const png_info * PNG_RESTRICT png_const_inforp;
+
+/* Three color definitions. The order of the red, green, and blue, (and the
+ * exact size) is not important, although the size of the fields need to
+ * be png_byte or png_uint_16 (as defined below).
+ */
+typedef struct png_color_struct
+{
+ png_byte red;
+ png_byte green;
+ png_byte blue;
+} png_color;
+typedef png_color * png_colorp;
+typedef const png_color * png_const_colorp;
+typedef png_color * * png_colorpp;
+
+typedef struct png_color_16_struct
+{
+ png_byte index; /* used for palette files */
+ png_uint_16 red; /* for use in red green blue files */
+ png_uint_16 green;
+ png_uint_16 blue;
+ png_uint_16 gray; /* for use in grayscale files */
+} png_color_16;
+typedef png_color_16 * png_color_16p;
+typedef const png_color_16 * png_const_color_16p;
+typedef png_color_16 * * png_color_16pp;
+
+typedef struct png_color_8_struct
+{
+ png_byte red; /* for use in red green blue files */
+ png_byte green;
+ png_byte blue;
+ png_byte gray; /* for use in grayscale files */
+ png_byte alpha; /* for alpha channel files */
+} png_color_8;
+typedef png_color_8 * png_color_8p;
+typedef const png_color_8 * png_const_color_8p;
+typedef png_color_8 * * png_color_8pp;
+
+/*
+ * The following two structures are used for the in-core representation
+ * of sPLT chunks.
+ */
+typedef struct png_sPLT_entry_struct
+{
+ png_uint_16 red;
+ png_uint_16 green;
+ png_uint_16 blue;
+ png_uint_16 alpha;
+ png_uint_16 frequency;
+} png_sPLT_entry;
+typedef png_sPLT_entry * png_sPLT_entryp;
+typedef const png_sPLT_entry * png_const_sPLT_entryp;
+typedef png_sPLT_entry * * png_sPLT_entrypp;
+
+/* When the depth of the sPLT palette is 8 bits, the color and alpha samples
+ * occupy the LSB of their respective members, and the MSB of each member
+ * is zero-filled. The frequency member always occupies the full 16 bits.
+ */
+
+typedef struct png_sPLT_struct
+{
+ png_charp name; /* palette name */
+ png_byte depth; /* depth of palette samples */
+ png_sPLT_entryp entries; /* palette entries */
+ png_int_32 nentries; /* number of palette entries */
+} png_sPLT_t;
+typedef png_sPLT_t * png_sPLT_tp;
+typedef const png_sPLT_t * png_const_sPLT_tp;
+typedef png_sPLT_t * * png_sPLT_tpp;
+
+#ifdef PNG_TEXT_SUPPORTED
+/* png_text holds the contents of a text/ztxt/itxt chunk in a PNG file,
+ * and whether that contents is compressed or not. The "key" field
+ * points to a regular zero-terminated C string. The "text" fields can be a
+ * regular C string, an empty string, or a NULL pointer.
+ * However, the structure returned by png_get_text() will always contain
+ * the "text" field as a regular zero-terminated C string (possibly
+ * empty), never a NULL pointer, so it can be safely used in printf() and
+ * other string-handling functions. Note that the "itxt_length", "lang", and
+ * "lang_key" members of the structure only exist when the library is built
+ * with iTXt chunk support. Prior to libpng-1.4.0 the library was built by
+ * default without iTXt support. Also note that when iTXt *is* supported,
+ * the "lang" and "lang_key" fields contain NULL pointers when the
+ * "compression" field contains * PNG_TEXT_COMPRESSION_NONE or
+ * PNG_TEXT_COMPRESSION_zTXt. Note that the "compression value" is not the
+ * same as what appears in the PNG tEXt/zTXt/iTXt chunk's "compression flag"
+ * which is always 0 or 1, or its "compression method" which is always 0.
+ */
+typedef struct png_text_struct
+{
+ int compression; /* compression value:
+ -1: tEXt, none
+ 0: zTXt, deflate
+ 1: iTXt, none
+ 2: iTXt, deflate */
+ png_charp key; /* keyword, 1-79 character description of "text" */
+ png_charp text; /* comment, may be an empty string (ie "")
+ or a NULL pointer */
+ png_size_t text_length; /* length of the text string */
+ png_size_t itxt_length; /* length of the itxt string */
+ png_charp lang; /* language code, 0-79 characters
+ or a NULL pointer */
+ png_charp lang_key; /* keyword translated UTF-8 string, 0 or more
+ chars or a NULL pointer */
+} png_text;
+typedef png_text * png_textp;
+typedef const png_text * png_const_textp;
+typedef png_text * * png_textpp;
+#endif
+
+/* Supported compression types for text in PNG files (tEXt, and zTXt).
+ * The values of the PNG_TEXT_COMPRESSION_ defines should NOT be changed. */
+#define PNG_TEXT_COMPRESSION_NONE_WR -3
+#define PNG_TEXT_COMPRESSION_zTXt_WR -2
+#define PNG_TEXT_COMPRESSION_NONE -1
+#define PNG_TEXT_COMPRESSION_zTXt 0
+#define PNG_ITXT_COMPRESSION_NONE 1
+#define PNG_ITXT_COMPRESSION_zTXt 2
+#define PNG_TEXT_COMPRESSION_LAST 3 /* Not a valid value */
+
+/* png_time is a way to hold the time in an machine independent way.
+ * Two conversions are provided, both from time_t and struct tm. There
+ * is no portable way to convert to either of these structures, as far
+ * as I know. If you know of a portable way, send it to me. As a side
+ * note - PNG has always been Year 2000 compliant!
+ */
+typedef struct png_time_struct
+{
+ png_uint_16 year; /* full year, as in, 1995 */
+ png_byte month; /* month of year, 1 - 12 */
+ png_byte day; /* day of month, 1 - 31 */
+ png_byte hour; /* hour of day, 0 - 23 */
+ png_byte minute; /* minute of hour, 0 - 59 */
+ png_byte second; /* second of minute, 0 - 60 (for leap seconds) */
+} png_time;
+typedef png_time * png_timep;
+typedef const png_time * png_const_timep;
+typedef png_time * * png_timepp;
+
+#if defined(PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED) ||\
+ defined(PNG_USER_CHUNKS_SUPPORTED)
+/* png_unknown_chunk is a structure to hold queued chunks for which there is
+ * no specific support. The idea is that we can use this to queue
+ * up private chunks for output even though the library doesn't actually
+ * know about their semantics.
+ *
+ * The data in the structure is set by libpng on read and used on write.
+ */
+typedef struct png_unknown_chunk_t
+{
+ png_byte name[5]; /* Textual chunk name with '\0' terminator */
+ png_byte *data; /* Data, should not be modified on read! */
+ png_size_t size;
+
+ /* On write 'location' must be set using the flag values listed below.
+ * Notice that on read it is set by libpng however the values stored have
+ * more bits set than are listed below. Always treat the value as a
+ * bitmask. On write set only one bit - setting multiple bits may cause the
+ * chunk to be written in multiple places.
+ */
+ png_byte location; /* mode of operation at read time */
+}
+png_unknown_chunk;
+
+typedef png_unknown_chunk * png_unknown_chunkp;
+typedef const png_unknown_chunk * png_const_unknown_chunkp;
+typedef png_unknown_chunk * * png_unknown_chunkpp;
+#endif
+
+/* Flag values for the unknown chunk location byte. */
+#define PNG_HAVE_IHDR 0x01
+#define PNG_HAVE_PLTE 0x02
+#define PNG_AFTER_IDAT 0x08
+
+/* Maximum positive integer used in PNG is (2^31)-1 */
+#define PNG_UINT_31_MAX ((png_uint_32)0x7fffffffL)
+#define PNG_UINT_32_MAX ((png_uint_32)(-1))
+#define PNG_SIZE_MAX ((png_size_t)(-1))
+
+/* These are constants for fixed point values encoded in the
+ * PNG specification manner (x100000)
+ */
+#define PNG_FP_1 100000
+#define PNG_FP_HALF 50000
+#define PNG_FP_MAX ((png_fixed_point)0x7fffffffL)
+#define PNG_FP_MIN (-PNG_FP_MAX)
+
+/* These describe the color_type field in png_info. */
+/* color type masks */
+#define PNG_COLOR_MASK_PALETTE 1
+#define PNG_COLOR_MASK_COLOR 2
+#define PNG_COLOR_MASK_ALPHA 4
+
+/* color types. Note that not all combinations are legal */
+#define PNG_COLOR_TYPE_GRAY 0
+#define PNG_COLOR_TYPE_PALETTE (PNG_COLOR_MASK_COLOR | PNG_COLOR_MASK_PALETTE)
+#define PNG_COLOR_TYPE_RGB (PNG_COLOR_MASK_COLOR)
+#define PNG_COLOR_TYPE_RGB_ALPHA (PNG_COLOR_MASK_COLOR | PNG_COLOR_MASK_ALPHA)
+#define PNG_COLOR_TYPE_GRAY_ALPHA (PNG_COLOR_MASK_ALPHA)
+/* aliases */
+#define PNG_COLOR_TYPE_RGBA PNG_COLOR_TYPE_RGB_ALPHA
+#define PNG_COLOR_TYPE_GA PNG_COLOR_TYPE_GRAY_ALPHA
+
+/* This is for compression type. PNG 1.0-1.2 only define the single type. */
+#define PNG_COMPRESSION_TYPE_BASE 0 /* Deflate method 8, 32K window */
+#define PNG_COMPRESSION_TYPE_DEFAULT PNG_COMPRESSION_TYPE_BASE
+
+/* This is for filter type. PNG 1.0-1.2 only define the single type. */
+#define PNG_FILTER_TYPE_BASE 0 /* Single row per-byte filtering */
+#define PNG_INTRAPIXEL_DIFFERENCING 64 /* Used only in MNG datastreams */
+#define PNG_FILTER_TYPE_DEFAULT PNG_FILTER_TYPE_BASE
+
+/* These are for the interlacing type. These values should NOT be changed. */
+#define PNG_INTERLACE_NONE 0 /* Non-interlaced image */
+#define PNG_INTERLACE_ADAM7 1 /* Adam7 interlacing */
+#define PNG_INTERLACE_LAST 2 /* Not a valid value */
+
+/* These are for the oFFs chunk. These values should NOT be changed. */
+#define PNG_OFFSET_PIXEL 0 /* Offset in pixels */
+#define PNG_OFFSET_MICROMETER 1 /* Offset in micrometers (1/10^6 meter) */
+#define PNG_OFFSET_LAST 2 /* Not a valid value */
+
+/* These are for the pCAL chunk. These values should NOT be changed. */
+#define PNG_EQUATION_LINEAR 0 /* Linear transformation */
+#define PNG_EQUATION_BASE_E 1 /* Exponential base e transform */
+#define PNG_EQUATION_ARBITRARY 2 /* Arbitrary base exponential transform */
+#define PNG_EQUATION_HYPERBOLIC 3 /* Hyperbolic sine transformation */
+#define PNG_EQUATION_LAST 4 /* Not a valid value */
+
+/* These are for the sCAL chunk. These values should NOT be changed. */
+#define PNG_SCALE_UNKNOWN 0 /* unknown unit (image scale) */
+#define PNG_SCALE_METER 1 /* meters per pixel */
+#define PNG_SCALE_RADIAN 2 /* radians per pixel */
+#define PNG_SCALE_LAST 3 /* Not a valid value */
+
+/* These are for the pHYs chunk. These values should NOT be changed. */
+#define PNG_RESOLUTION_UNKNOWN 0 /* pixels/unknown unit (aspect ratio) */
+#define PNG_RESOLUTION_METER 1 /* pixels/meter */
+#define PNG_RESOLUTION_LAST 2 /* Not a valid value */
+
+/* These are for the sRGB chunk. These values should NOT be changed. */
+#define PNG_sRGB_INTENT_PERCEPTUAL 0
+#define PNG_sRGB_INTENT_RELATIVE 1
+#define PNG_sRGB_INTENT_SATURATION 2
+#define PNG_sRGB_INTENT_ABSOLUTE 3
+#define PNG_sRGB_INTENT_LAST 4 /* Not a valid value */
+
+/* This is for text chunks */
+#define PNG_KEYWORD_MAX_LENGTH 79
+
+/* Maximum number of entries in PLTE/sPLT/tRNS arrays */
+#define PNG_MAX_PALETTE_LENGTH 256
+
+/* These determine if an ancillary chunk's data has been successfully read
+ * from the PNG header, or if the application has filled in the corresponding
+ * data in the info_struct to be written into the output file. The values
+ * of the PNG_INFO_<chunk> defines should NOT be changed.
+ */
+#define PNG_INFO_gAMA 0x0001
+#define PNG_INFO_sBIT 0x0002
+#define PNG_INFO_cHRM 0x0004
+#define PNG_INFO_PLTE 0x0008
+#define PNG_INFO_tRNS 0x0010
+#define PNG_INFO_bKGD 0x0020
+#define PNG_INFO_hIST 0x0040
+#define PNG_INFO_pHYs 0x0080
+#define PNG_INFO_oFFs 0x0100
+#define PNG_INFO_tIME 0x0200
+#define PNG_INFO_pCAL 0x0400
+#define PNG_INFO_sRGB 0x0800 /* GR-P, 0.96a */
+#define PNG_INFO_iCCP 0x1000 /* ESR, 1.0.6 */
+#define PNG_INFO_sPLT 0x2000 /* ESR, 1.0.6 */
+#define PNG_INFO_sCAL 0x4000 /* ESR, 1.0.6 */
+#define PNG_INFO_IDAT 0x8000 /* ESR, 1.0.6 */
+
+/* This is used for the transformation routines, as some of them
+ * change these values for the row. It also should enable using
+ * the routines for other purposes.
+ */
+typedef struct png_row_info_struct
+{
+ png_uint_32 width; /* width of row */
+ png_size_t rowbytes; /* number of bytes in row */
+ png_byte color_type; /* color type of row */
+ png_byte bit_depth; /* bit depth of row */
+ png_byte channels; /* number of channels (1, 2, 3, or 4) */
+ png_byte pixel_depth; /* bits per pixel (depth * channels) */
+} png_row_info;
+
+typedef png_row_info * png_row_infop;
+typedef png_row_info * * png_row_infopp;
+
+/* These are the function types for the I/O functions and for the functions
+ * that allow the user to override the default I/O functions with his or her
+ * own. The png_error_ptr type should match that of user-supplied warning
+ * and error functions, while the png_rw_ptr type should match that of the
+ * user read/write data functions. Note that the 'write' function must not
+ * modify the buffer it is passed. The 'read' function, on the other hand, is
+ * expected to return the read data in the buffer.
+ */
+typedef PNG_CALLBACK(void, *png_error_ptr, (png_structp, png_const_charp));
+typedef PNG_CALLBACK(void, *png_rw_ptr, (png_structp, png_bytep, png_size_t));
+typedef PNG_CALLBACK(void, *png_flush_ptr, (png_structp));
+typedef PNG_CALLBACK(void, *png_read_status_ptr, (png_structp, png_uint_32,
+ int));
+typedef PNG_CALLBACK(void, *png_write_status_ptr, (png_structp, png_uint_32,
+ int));
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+typedef PNG_CALLBACK(void, *png_progressive_info_ptr, (png_structp, png_infop));
+typedef PNG_CALLBACK(void, *png_progressive_end_ptr, (png_structp, png_infop));
+
+/* The following callback receives png_uint_32 row_number, int pass for the
+ * png_bytep data of the row. When transforming an interlaced image the
+ * row number is the row number within the sub-image of the interlace pass, so
+ * the value will increase to the height of the sub-image (not the full image)
+ * then reset to 0 for the next pass.
+ *
+ * Use PNG_ROW_FROM_PASS_ROW(row, pass) and PNG_COL_FROM_PASS_COL(col, pass) to
+ * find the output pixel (x,y) given an interlaced sub-image pixel
+ * (row,col,pass). (See below for these macros.)
+ */
+typedef PNG_CALLBACK(void, *png_progressive_row_ptr, (png_structp, png_bytep,
+ png_uint_32, int));
+#endif
+
+#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \
+ defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED)
+typedef PNG_CALLBACK(void, *png_user_transform_ptr, (png_structp, png_row_infop,
+ png_bytep));
+#endif
+
+#ifdef PNG_USER_CHUNKS_SUPPORTED
+typedef PNG_CALLBACK(int, *png_user_chunk_ptr, (png_structp,
+ png_unknown_chunkp));
+#endif
+#ifdef PNG_UNKNOWN_CHUNKS_SUPPORTED
+/* not used anywhere */
+/* typedef PNG_CALLBACK(void, *png_unknown_chunk_ptr, (png_structp)); */
+#endif
+
+#ifdef PNG_SETJMP_SUPPORTED
+/* This must match the function definition in <setjmp.h>, and the application
+ * must include this before png.h to obtain the definition of jmp_buf. The
+ * function is required to be PNG_NORETURN, but this is not checked. If the
+ * function does return the application will crash via an abort() or similar
+ * system level call.
+ *
+ * If you get a warning here while building the library you may need to make
+ * changes to ensure that pnglibconf.h records the calling convention used by
+ * your compiler. This may be very difficult - try using a different compiler
+ * to build the library!
+ */
+PNG_FUNCTION(void, (PNGCAPI *png_longjmp_ptr), PNGARG((jmp_buf, int)), typedef);
+#endif
+
+/* Transform masks for the high-level interface */
+#define PNG_TRANSFORM_IDENTITY 0x0000 /* read and write */
+#define PNG_TRANSFORM_STRIP_16 0x0001 /* read only */
+#define PNG_TRANSFORM_STRIP_ALPHA 0x0002 /* read only */
+#define PNG_TRANSFORM_PACKING 0x0004 /* read and write */
+#define PNG_TRANSFORM_PACKSWAP 0x0008 /* read and write */
+#define PNG_TRANSFORM_EXPAND 0x0010 /* read only */
+#define PNG_TRANSFORM_INVERT_MONO 0x0020 /* read and write */
+#define PNG_TRANSFORM_SHIFT 0x0040 /* read and write */
+#define PNG_TRANSFORM_BGR 0x0080 /* read and write */
+#define PNG_TRANSFORM_SWAP_ALPHA 0x0100 /* read and write */
+#define PNG_TRANSFORM_SWAP_ENDIAN 0x0200 /* read and write */
+#define PNG_TRANSFORM_INVERT_ALPHA 0x0400 /* read and write */
+#define PNG_TRANSFORM_STRIP_FILLER 0x0800 /* write only */
+/* Added to libpng-1.2.34 */
+#define PNG_TRANSFORM_STRIP_FILLER_BEFORE PNG_TRANSFORM_STRIP_FILLER
+#define PNG_TRANSFORM_STRIP_FILLER_AFTER 0x1000 /* write only */
+/* Added to libpng-1.4.0 */
+#define PNG_TRANSFORM_GRAY_TO_RGB 0x2000 /* read only */
+/* Added to libpng-1.5.4 */
+#define PNG_TRANSFORM_EXPAND_16 0x4000 /* read only */
+#define PNG_TRANSFORM_SCALE_16 0x8000 /* read only */
+
+/* Flags for MNG supported features */
+#define PNG_FLAG_MNG_EMPTY_PLTE 0x01
+#define PNG_FLAG_MNG_FILTER_64 0x04
+#define PNG_ALL_MNG_FEATURES 0x05
+
+/* NOTE: prior to 1.5 these functions had no 'API' style declaration,
+ * this allowed the zlib default functions to be used on Windows
+ * platforms. In 1.5 the zlib default malloc (which just calls malloc and
+ * ignores the first argument) should be completely compatible with the
+ * following.
+ */
+typedef PNG_CALLBACK(png_voidp, *png_malloc_ptr, (png_structp,
+ png_alloc_size_t));
+typedef PNG_CALLBACK(void, *png_free_ptr, (png_structp, png_voidp));
+
+/* Section 3: exported functions
+ * Here are the function definitions most commonly used. This is not
+ * the place to find out how to use libpng. See libpng-manual.txt for the
+ * full explanation, see example.c for the summary. This just provides
+ * a simple one line description of the use of each function.
+ *
+ * The PNG_EXPORT() and PNG_EXPORTA() macros used below are defined in
+ * pngconf.h and in the *.dfn files in the scripts directory.
+ *
+ * PNG_EXPORT(ordinal, type, name, (args));
+ *
+ * ordinal: ordinal that is used while building
+ * *.def files. The ordinal value is only
+ * relevant when preprocessing png.h with
+ * the *.dfn files for building symbol table
+ * entries, and are removed by pngconf.h.
+ * type: return type of the function
+ * name: function name
+ * args: function arguments, with types
+ *
+ * When we wish to append attributes to a function prototype we use
+ * the PNG_EXPORTA() macro instead.
+ *
+ * PNG_EXPORTA(ordinal, type, name, (args), attributes);
+ *
+ * ordinal, type, name, and args: same as in PNG_EXPORT().
+ * attributes: function attributes
+ */
+
+/* Returns the version number of the library */
+PNG_EXPORT(1, png_uint_32, png_access_version_number, (void));
+
+/* Tell lib we have already handled the first <num_bytes> magic bytes.
+ * Handling more than 8 bytes from the beginning of the file is an error.
+ */
+PNG_EXPORT(2, void, png_set_sig_bytes, (png_structrp png_ptr, int num_bytes));
+
+/* Check sig[start] through sig[start + num_to_check - 1] to see if it's a
+ * PNG file. Returns zero if the supplied bytes match the 8-byte PNG
+ * signature, and non-zero otherwise. Having num_to_check == 0 or
+ * start > 7 will always fail (ie return non-zero).
+ */
+PNG_EXPORT(3, int, png_sig_cmp, (png_const_bytep sig, png_size_t start,
+ png_size_t num_to_check));
+
+/* Simple signature checking function. This is the same as calling
+ * png_check_sig(sig, n) := !png_sig_cmp(sig, 0, n).
+ */
+#define png_check_sig(sig, n) !png_sig_cmp((sig), 0, (n))
+
+/* Allocate and initialize png_ptr struct for reading, and any other memory. */
+PNG_EXPORTA(4, png_structp, png_create_read_struct,
+ (png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn),
+ PNG_ALLOCATED);
+
+/* Allocate and initialize png_ptr struct for writing, and any other memory */
+PNG_EXPORTA(5, png_structp, png_create_write_struct,
+ (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn,
+ png_error_ptr warn_fn),
+ PNG_ALLOCATED);
+
+PNG_EXPORT(6, png_size_t, png_get_compression_buffer_size,
+ (png_const_structrp png_ptr));
+
+PNG_EXPORT(7, void, png_set_compression_buffer_size, (png_structrp png_ptr,
+ png_size_t size));
+
+/* Moved from pngconf.h in 1.4.0 and modified to ensure setjmp/longjmp
+ * match up.
+ */
+#ifdef PNG_SETJMP_SUPPORTED
+/* This function returns the jmp_buf built in to *png_ptr. It must be
+ * supplied with an appropriate 'longjmp' function to use on that jmp_buf
+ * unless the default error function is overridden in which case NULL is
+ * acceptable. The size of the jmp_buf is checked against the actual size
+ * allocated by the library - the call will return NULL on a mismatch
+ * indicating an ABI mismatch.
+ */
+PNG_EXPORT(8, jmp_buf*, png_set_longjmp_fn, (png_structrp png_ptr,
+ png_longjmp_ptr longjmp_fn, size_t jmp_buf_size));
+# define png_jmpbuf(png_ptr) \
+ (*png_set_longjmp_fn((png_ptr), longjmp, (sizeof (jmp_buf))))
+#else
+# define png_jmpbuf(png_ptr) \
+ (LIBPNG_WAS_COMPILED_WITH__PNG_NO_SETJMP)
+#endif
+/* This function should be used by libpng applications in place of
+ * longjmp(png_ptr->jmpbuf, val). If longjmp_fn() has been set, it
+ * will use it; otherwise it will call PNG_ABORT(). This function was
+ * added in libpng-1.5.0.
+ */
+PNG_EXPORTA(9, void, png_longjmp, (png_const_structrp png_ptr, int val),
+ PNG_NORETURN);
+
+#ifdef PNG_READ_SUPPORTED
+/* Reset the compression stream */
+PNG_EXPORTA(10, int, png_reset_zstream, (png_structrp png_ptr), PNG_DEPRECATED);
+#endif
+
+/* New functions added in libpng-1.0.2 (not enabled by default until 1.2.0) */
+#ifdef PNG_USER_MEM_SUPPORTED
+PNG_EXPORTA(11, png_structp, png_create_read_struct_2,
+ (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn,
+ png_error_ptr warn_fn,
+ png_voidp mem_ptr, png_malloc_ptr malloc_fn, png_free_ptr free_fn),
+ PNG_ALLOCATED);
+PNG_EXPORTA(12, png_structp, png_create_write_struct_2,
+ (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn,
+ png_error_ptr warn_fn,
+ png_voidp mem_ptr, png_malloc_ptr malloc_fn, png_free_ptr free_fn),
+ PNG_ALLOCATED);
+#endif
+
+/* Write the PNG file signature. */
+PNG_EXPORT(13, void, png_write_sig, (png_structrp png_ptr));
+
+/* Write a PNG chunk - size, type, (optional) data, CRC. */
+PNG_EXPORT(14, void, png_write_chunk, (png_structrp png_ptr, png_const_bytep
+ chunk_name, png_const_bytep data, png_size_t length));
+
+/* Write the start of a PNG chunk - length and chunk name. */
+PNG_EXPORT(15, void, png_write_chunk_start, (png_structrp png_ptr,
+ png_const_bytep chunk_name, png_uint_32 length));
+
+/* Write the data of a PNG chunk started with png_write_chunk_start(). */
+PNG_EXPORT(16, void, png_write_chunk_data, (png_structrp png_ptr,
+ png_const_bytep data, png_size_t length));
+
+/* Finish a chunk started with png_write_chunk_start() (includes CRC). */
+PNG_EXPORT(17, void, png_write_chunk_end, (png_structrp png_ptr));
+
+/* Allocate and initialize the info structure */
+PNG_EXPORTA(18, png_infop, png_create_info_struct, (png_const_structrp png_ptr),
+ PNG_ALLOCATED);
+
+/* DEPRECATED: this function allowed init structures to be created using the
+ * default allocation method (typically malloc). Use is deprecated in 1.6.0 and
+ * the API will be removed in the future.
+ */
+PNG_EXPORTA(19, void, png_info_init_3, (png_infopp info_ptr,
+ png_size_t png_info_struct_size), PNG_DEPRECATED);
+
+/* Writes all the PNG information before the image. */
+PNG_EXPORT(20, void, png_write_info_before_PLTE,
+ (png_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(21, void, png_write_info,
+ (png_structrp png_ptr, png_const_inforp info_ptr));
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the information before the actual image data. */
+PNG_EXPORT(22, void, png_read_info,
+ (png_structrp png_ptr, png_inforp info_ptr));
+#endif
+
+#ifdef PNG_TIME_RFC1123_SUPPORTED
+ /* Convert to a US string format: there is no localization support in this
+ * routine. The original implementation used a 29 character buffer in
+ * png_struct, this will be removed in future versions.
+ */
+#if PNG_LIBPNG_VER < 10700
+/* To do: remove this from libpng17 (and from libpng17/png.c and pngstruct.h) */
+PNG_EXPORTA(23, png_const_charp, png_convert_to_rfc1123, (png_structrp png_ptr,
+ png_const_timep ptime),PNG_DEPRECATED);
+#endif
+PNG_EXPORT(241, int, png_convert_to_rfc1123_buffer, (char out[29],
+ png_const_timep ptime));
+#endif
+
+#ifdef PNG_CONVERT_tIME_SUPPORTED
+/* Convert from a struct tm to png_time */
+PNG_EXPORT(24, void, png_convert_from_struct_tm, (png_timep ptime,
+ const struct tm * ttime));
+
+/* Convert from time_t to png_time. Uses gmtime() */
+PNG_EXPORT(25, void, png_convert_from_time_t, (png_timep ptime, time_t ttime));
+#endif /* PNG_CONVERT_tIME_SUPPORTED */
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+/* Expand data to 24-bit RGB, or 8-bit grayscale, with alpha if available. */
+PNG_EXPORT(26, void, png_set_expand, (png_structrp png_ptr));
+PNG_EXPORT(27, void, png_set_expand_gray_1_2_4_to_8, (png_structrp png_ptr));
+PNG_EXPORT(28, void, png_set_palette_to_rgb, (png_structrp png_ptr));
+PNG_EXPORT(29, void, png_set_tRNS_to_alpha, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+/* Expand to 16-bit channels, forces conversion of palette to RGB and expansion
+ * of a tRNS chunk if present.
+ */
+PNG_EXPORT(221, void, png_set_expand_16, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED)
+/* Use blue, green, red order for pixels. */
+PNG_EXPORT(30, void, png_set_bgr, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+/* Expand the grayscale to 24-bit RGB if necessary. */
+PNG_EXPORT(31, void, png_set_gray_to_rgb, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+/* Reduce RGB to grayscale. */
+#define PNG_ERROR_ACTION_NONE 1
+#define PNG_ERROR_ACTION_WARN 2
+#define PNG_ERROR_ACTION_ERROR 3
+#define PNG_RGB_TO_GRAY_DEFAULT (-1)/*for red/green coefficients*/
+
+PNG_FP_EXPORT(32, void, png_set_rgb_to_gray, (png_structrp png_ptr,
+ int error_action, double red, double green))
+PNG_FIXED_EXPORT(33, void, png_set_rgb_to_gray_fixed, (png_structrp png_ptr,
+ int error_action, png_fixed_point red, png_fixed_point green))
+
+PNG_EXPORT(34, png_byte, png_get_rgb_to_gray_status, (png_const_structrp
+ png_ptr));
+#endif
+
+#ifdef PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED
+PNG_EXPORT(35, void, png_build_grayscale_palette, (int bit_depth,
+ png_colorp palette));
+#endif
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+/* How the alpha channel is interpreted - this affects how the color channels of
+ * a PNG file are returned when an alpha channel, or tRNS chunk in a palette
+ * file, is present.
+ *
+ * This has no effect on the way pixels are written into a PNG output
+ * datastream. The color samples in a PNG datastream are never premultiplied
+ * with the alpha samples.
+ *
+ * The default is to return data according to the PNG specification: the alpha
+ * channel is a linear measure of the contribution of the pixel to the
+ * corresponding composited pixel. The gamma encoded color channels must be
+ * scaled according to the contribution and to do this it is necessary to undo
+ * the encoding, scale the color values, perform the composition and reencode
+ * the values. This is the 'PNG' mode.
+ *
+ * The alternative is to 'associate' the alpha with the color information by
+ * storing color channel values that have been scaled by the alpha. The
+ * advantage is that the color channels can be resampled (the image can be
+ * scaled) in this form. The disadvantage is that normal practice is to store
+ * linear, not (gamma) encoded, values and this requires 16-bit channels for
+ * still images rather than the 8-bit channels that are just about sufficient if
+ * gamma encoding is used. In addition all non-transparent pixel values,
+ * including completely opaque ones, must be gamma encoded to produce the final
+ * image. This is the 'STANDARD', 'ASSOCIATED' or 'PREMULTIPLIED' mode (the
+ * latter being the two common names for associated alpha color channels.)
+ *
+ * Since it is not necessary to perform arithmetic on opaque color values so
+ * long as they are not to be resampled and are in the final color space it is
+ * possible to optimize the handling of alpha by storing the opaque pixels in
+ * the PNG format (adjusted for the output color space) while storing partially
+ * opaque pixels in the standard, linear, format. The accuracy required for
+ * standard alpha composition is relatively low, because the pixels are
+ * isolated, therefore typically the accuracy loss in storing 8-bit linear
+ * values is acceptable. (This is not true if the alpha channel is used to
+ * simulate transparency over large areas - use 16 bits or the PNG mode in
+ * this case!) This is the 'OPTIMIZED' mode. For this mode a pixel is
+ * treated as opaque only if the alpha value is equal to the maximum value.
+ *
+ * The final choice is to gamma encode the alpha channel as well. This is
+ * broken because, in practice, no implementation that uses this choice
+ * correctly undoes the encoding before handling alpha composition. Use this
+ * choice only if other serious errors in the software or hardware you use
+ * mandate it; the typical serious error is for dark halos to appear around
+ * opaque areas of the composited PNG image because of arithmetic overflow.
+ *
+ * The API function png_set_alpha_mode specifies which of these choices to use
+ * with an enumerated 'mode' value and the gamma of the required output:
+ */
+#define PNG_ALPHA_PNG 0 /* according to the PNG standard */
+#define PNG_ALPHA_STANDARD 1 /* according to Porter/Duff */
+#define PNG_ALPHA_ASSOCIATED 1 /* as above; this is the normal practice */
+#define PNG_ALPHA_PREMULTIPLIED 1 /* as above */
+#define PNG_ALPHA_OPTIMIZED 2 /* 'PNG' for opaque pixels, else 'STANDARD' */
+#define PNG_ALPHA_BROKEN 3 /* the alpha channel is gamma encoded */
+
+PNG_FP_EXPORT(227, void, png_set_alpha_mode, (png_structrp png_ptr, int mode,
+ double output_gamma))
+PNG_FIXED_EXPORT(228, void, png_set_alpha_mode_fixed, (png_structrp png_ptr,
+ int mode, png_fixed_point output_gamma))
+#endif
+
+#if defined(PNG_GAMMA_SUPPORTED) || defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+/* The output_gamma value is a screen gamma in libpng terminology: it expresses
+ * how to decode the output values, not how they are encoded. The values used
+ * correspond to the normal numbers used to describe the overall gamma of a
+ * computer display system; for example 2.2 for an sRGB conformant system. The
+ * values are scaled by 100000 in the _fixed version of the API (so 220000 for
+ * sRGB.)
+ *
+ * The inverse of the value is always used to provide a default for the PNG file
+ * encoding if it has no gAMA chunk and if png_set_gamma() has not been called
+ * to override the PNG gamma information.
+ *
+ * When the ALPHA_OPTIMIZED mode is selected the output gamma is used to encode
+ * opaque pixels however pixels with lower alpha values are not encoded,
+ * regardless of the output gamma setting.
+ *
+ * When the standard Porter Duff handling is requested with mode 1 the output
+ * encoding is set to be linear and the output_gamma value is only relevant
+ * as a default for input data that has no gamma information. The linear output
+ * encoding will be overridden if png_set_gamma() is called - the results may be
+ * highly unexpected!
+ *
+ * The following numbers are derived from the sRGB standard and the research
+ * behind it. sRGB is defined to be approximated by a PNG gAMA chunk value of
+ * 0.45455 (1/2.2) for PNG. The value implicitly includes any viewing
+ * correction required to take account of any differences in the color
+ * environment of the original scene and the intended display environment; the
+ * value expresses how to *decode* the image for display, not how the original
+ * data was *encoded*.
+ *
+ * sRGB provides a peg for the PNG standard by defining a viewing environment.
+ * sRGB itself, and earlier TV standards, actually use a more complex transform
+ * (a linear portion then a gamma 2.4 power law) than PNG can express. (PNG is
+ * limited to simple power laws.) By saying that an image for direct display on
+ * an sRGB conformant system should be stored with a gAMA chunk value of 45455
+ * (11.3.3.2 and 11.3.3.5 of the ISO PNG specification) the PNG specification
+ * makes it possible to derive values for other display systems and
+ * environments.
+ *
+ * The Mac value is deduced from the sRGB based on an assumption that the actual
+ * extra viewing correction used in early Mac display systems was implemented as
+ * a power 1.45 lookup table.
+ *
+ * Any system where a programmable lookup table is used or where the behavior of
+ * the final display device characteristics can be changed requires system
+ * specific code to obtain the current characteristic. However this can be
+ * difficult and most PNG gamma correction only requires an approximate value.
+ *
+ * By default, if png_set_alpha_mode() is not called, libpng assumes that all
+ * values are unencoded, linear, values and that the output device also has a
+ * linear characteristic. This is only very rarely correct - it is invariably
+ * better to call png_set_alpha_mode() with PNG_DEFAULT_sRGB than rely on the
+ * default if you don't know what the right answer is!
+ *
+ * The special value PNG_GAMMA_MAC_18 indicates an older Mac system (pre Mac OS
+ * 10.6) which used a correction table to implement a somewhat lower gamma on an
+ * otherwise sRGB system.
+ *
+ * Both these values are reserved (not simple gamma values) in order to allow
+ * more precise correction internally in the future.
+ *
+ * NOTE: the following values can be passed to either the fixed or floating
+ * point APIs, but the floating point API will also accept floating point
+ * values.
+ */
+#define PNG_DEFAULT_sRGB -1 /* sRGB gamma and color space */
+#define PNG_GAMMA_MAC_18 -2 /* Old Mac '1.8' gamma and color space */
+#define PNG_GAMMA_sRGB 220000 /* Television standards--matches sRGB gamma */
+#define PNG_GAMMA_LINEAR PNG_FP_1 /* Linear */
+#endif
+
+/* The following are examples of calls to png_set_alpha_mode to achieve the
+ * required overall gamma correction and, where necessary, alpha
+ * premultiplication.
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_DEFAULT_sRGB);
+ * This is the default libpng handling of the alpha channel - it is not
+ * pre-multiplied into the color components. In addition the call states
+ * that the output is for a sRGB system and causes all PNG files without gAMA
+ * chunks to be assumed to be encoded using sRGB.
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_GAMMA_MAC);
+ * In this case the output is assumed to be something like an sRGB conformant
+ * display preceeded by a power-law lookup table of power 1.45. This is how
+ * early Mac systems behaved.
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_STANDARD, PNG_GAMMA_LINEAR);
+ * This is the classic Jim Blinn approach and will work in academic
+ * environments where everything is done by the book. It has the shortcoming
+ * of assuming that input PNG data with no gamma information is linear - this
+ * is unlikely to be correct unless the PNG files where generated locally.
+ * Most of the time the output precision will be so low as to show
+ * significant banding in dark areas of the image.
+ *
+ * png_set_expand_16(pp);
+ * png_set_alpha_mode(pp, PNG_ALPHA_STANDARD, PNG_DEFAULT_sRGB);
+ * This is a somewhat more realistic Jim Blinn inspired approach. PNG files
+ * are assumed to have the sRGB encoding if not marked with a gamma value and
+ * the output is always 16 bits per component. This permits accurate scaling
+ * and processing of the data. If you know that your input PNG files were
+ * generated locally you might need to replace PNG_DEFAULT_sRGB with the
+ * correct value for your system.
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_OPTIMIZED, PNG_DEFAULT_sRGB);
+ * If you just need to composite the PNG image onto an existing background
+ * and if you control the code that does this you can use the optimization
+ * setting. In this case you just copy completely opaque pixels to the
+ * output. For pixels that are not completely transparent (you just skip
+ * those) you do the composition math using png_composite or png_composite_16
+ * below then encode the resultant 8-bit or 16-bit values to match the output
+ * encoding.
+ *
+ * Other cases
+ * If neither the PNG nor the standard linear encoding work for you because
+ * of the software or hardware you use then you have a big problem. The PNG
+ * case will probably result in halos around the image. The linear encoding
+ * will probably result in a washed out, too bright, image (it's actually too
+ * contrasty.) Try the ALPHA_OPTIMIZED mode above - this will probably
+ * substantially reduce the halos. Alternatively try:
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_BROKEN, PNG_DEFAULT_sRGB);
+ * This option will also reduce the halos, but there will be slight dark
+ * halos round the opaque parts of the image where the background is light.
+ * In the OPTIMIZED mode the halos will be light halos where the background
+ * is dark. Take your pick - the halos are unavoidable unless you can get
+ * your hardware/software fixed! (The OPTIMIZED approach is slightly
+ * faster.)
+ *
+ * When the default gamma of PNG files doesn't match the output gamma.
+ * If you have PNG files with no gamma information png_set_alpha_mode allows
+ * you to provide a default gamma, but it also sets the ouput gamma to the
+ * matching value. If you know your PNG files have a gamma that doesn't
+ * match the output you can take advantage of the fact that
+ * png_set_alpha_mode always sets the output gamma but only sets the PNG
+ * default if it is not already set:
+ *
+ * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_DEFAULT_sRGB);
+ * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_GAMMA_MAC);
+ * The first call sets both the default and the output gamma values, the
+ * second call overrides the output gamma without changing the default. This
+ * is easier than achieving the same effect with png_set_gamma. You must use
+ * PNG_ALPHA_PNG for the first call - internal checking in png_set_alpha will
+ * fire if more than one call to png_set_alpha_mode and png_set_background is
+ * made in the same read operation, however multiple calls with PNG_ALPHA_PNG
+ * are ignored.
+ */
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+PNG_EXPORT(36, void, png_set_strip_alpha, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_SWAP_ALPHA_SUPPORTED) || \
+ defined(PNG_WRITE_SWAP_ALPHA_SUPPORTED)
+PNG_EXPORT(37, void, png_set_swap_alpha, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_INVERT_ALPHA_SUPPORTED) || \
+ defined(PNG_WRITE_INVERT_ALPHA_SUPPORTED)
+PNG_EXPORT(38, void, png_set_invert_alpha, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED)
+/* Add a filler byte to 8-bit Gray or 24-bit RGB images. */
+PNG_EXPORT(39, void, png_set_filler, (png_structrp png_ptr, png_uint_32 filler,
+ int flags));
+/* The values of the PNG_FILLER_ defines should NOT be changed */
+# define PNG_FILLER_BEFORE 0
+# define PNG_FILLER_AFTER 1
+/* Add an alpha byte to 8-bit Gray or 24-bit RGB images. */
+PNG_EXPORT(40, void, png_set_add_alpha, (png_structrp png_ptr,
+ png_uint_32 filler, int flags));
+#endif /* PNG_READ_FILLER_SUPPORTED || PNG_WRITE_FILLER_SUPPORTED */
+
+#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED)
+/* Swap bytes in 16-bit depth files. */
+PNG_EXPORT(41, void, png_set_swap, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_PACK_SUPPORTED) || defined(PNG_WRITE_PACK_SUPPORTED)
+/* Use 1 byte per pixel in 1, 2, or 4-bit depth files. */
+PNG_EXPORT(42, void, png_set_packing, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_PACKSWAP_SUPPORTED) || \
+ defined(PNG_WRITE_PACKSWAP_SUPPORTED)
+/* Swap packing order of pixels in bytes. */
+PNG_EXPORT(43, void, png_set_packswap, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED)
+/* Converts files to legal bit depths. */
+PNG_EXPORT(44, void, png_set_shift, (png_structrp png_ptr, png_const_color_8p
+ true_bits));
+#endif
+
+#if defined(PNG_READ_INTERLACING_SUPPORTED) || \
+ defined(PNG_WRITE_INTERLACING_SUPPORTED)
+/* Have the code handle the interlacing. Returns the number of passes.
+ * MUST be called before png_read_update_info or png_start_read_image,
+ * otherwise it will not have the desired effect. Note that it is still
+ * necessary to call png_read_row or png_read_rows png_get_image_height
+ * times for each pass.
+*/
+PNG_EXPORT(45, int, png_set_interlace_handling, (png_structrp png_ptr));
+#endif
+
+#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED)
+/* Invert monochrome files */
+PNG_EXPORT(46, void, png_set_invert_mono, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+/* Handle alpha and tRNS by replacing with a background color. Prior to
+ * libpng-1.5.4 this API must not be called before the PNG file header has been
+ * read. Doing so will result in unexpected behavior and possible warnings or
+ * errors if the PNG file contains a bKGD chunk.
+ */
+PNG_FP_EXPORT(47, void, png_set_background, (png_structrp png_ptr,
+ png_const_color_16p background_color, int background_gamma_code,
+ int need_expand, double background_gamma))
+PNG_FIXED_EXPORT(215, void, png_set_background_fixed, (png_structrp png_ptr,
+ png_const_color_16p background_color, int background_gamma_code,
+ int need_expand, png_fixed_point background_gamma))
+#endif
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+# define PNG_BACKGROUND_GAMMA_UNKNOWN 0
+# define PNG_BACKGROUND_GAMMA_SCREEN 1
+# define PNG_BACKGROUND_GAMMA_FILE 2
+# define PNG_BACKGROUND_GAMMA_UNIQUE 3
+#endif
+
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+/* Scale a 16-bit depth file down to 8-bit, accurately. */
+PNG_EXPORT(229, void, png_set_scale_16, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+#define PNG_READ_16_TO_8 SUPPORTED /* Name prior to 1.5.4 */
+/* Strip the second byte of information from a 16-bit depth file. */
+PNG_EXPORT(48, void, png_set_strip_16, (png_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+/* Turn on quantizing, and reduce the palette to the number of colors
+ * available.
+ */
+PNG_EXPORT(49, void, png_set_quantize, (png_structrp png_ptr,
+ png_colorp palette, int num_palette, int maximum_colors,
+ png_const_uint_16p histogram, int full_quantize));
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* The threshold on gamma processing is configurable but hard-wired into the
+ * library. The following is the floating point variant.
+ */
+#define PNG_GAMMA_THRESHOLD (PNG_GAMMA_THRESHOLD_FIXED*.00001)
+
+/* Handle gamma correction. Screen_gamma=(display_exponent).
+ * NOTE: this API simply sets the screen and file gamma values. It will
+ * therefore override the value for gamma in a PNG file if it is called after
+ * the file header has been read - use with care - call before reading the PNG
+ * file for best results!
+ *
+ * These routines accept the same gamma values as png_set_alpha_mode (described
+ * above). The PNG_GAMMA_ defines and PNG_DEFAULT_sRGB can be passed to either
+ * API (floating point or fixed.) Notice, however, that the 'file_gamma' value
+ * is the inverse of a 'screen gamma' value.
+ */
+PNG_FP_EXPORT(50, void, png_set_gamma, (png_structrp png_ptr,
+ double screen_gamma, double override_file_gamma))
+PNG_FIXED_EXPORT(208, void, png_set_gamma_fixed, (png_structrp png_ptr,
+ png_fixed_point screen_gamma, png_fixed_point override_file_gamma))
+#endif
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+/* Set how many lines between output flushes - 0 for no flushing */
+PNG_EXPORT(51, void, png_set_flush, (png_structrp png_ptr, int nrows));
+/* Flush the current PNG output buffer */
+PNG_EXPORT(52, void, png_write_flush, (png_structrp png_ptr));
+#endif
+
+/* Optional update palette with requested transformations */
+PNG_EXPORT(53, void, png_start_read_image, (png_structrp png_ptr));
+
+/* Optional call to update the users info structure */
+PNG_EXPORT(54, void, png_read_update_info, (png_structrp png_ptr,
+ png_inforp info_ptr));
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read one or more rows of image data. */
+PNG_EXPORT(55, void, png_read_rows, (png_structrp png_ptr, png_bytepp row,
+ png_bytepp display_row, png_uint_32 num_rows));
+#endif
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read a row of data. */
+PNG_EXPORT(56, void, png_read_row, (png_structrp png_ptr, png_bytep row,
+ png_bytep display_row));
+#endif
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the whole image into memory at once. */
+PNG_EXPORT(57, void, png_read_image, (png_structrp png_ptr, png_bytepp image));
+#endif
+
+/* Write a row of image data */
+PNG_EXPORT(58, void, png_write_row, (png_structrp png_ptr,
+ png_const_bytep row));
+
+/* Write a few rows of image data: (*row) is not written; however, the type
+ * is declared as writeable to maintain compatibility with previous versions
+ * of libpng and to allow the 'display_row' array from read_rows to be passed
+ * unchanged to write_rows.
+ */
+PNG_EXPORT(59, void, png_write_rows, (png_structrp png_ptr, png_bytepp row,
+ png_uint_32 num_rows));
+
+/* Write the image data */
+PNG_EXPORT(60, void, png_write_image, (png_structrp png_ptr, png_bytepp image));
+
+/* Write the end of the PNG file. */
+PNG_EXPORT(61, void, png_write_end, (png_structrp png_ptr,
+ png_inforp info_ptr));
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the end of the PNG file. */
+PNG_EXPORT(62, void, png_read_end, (png_structrp png_ptr, png_inforp info_ptr));
+#endif
+
+/* Free any memory associated with the png_info_struct */
+PNG_EXPORT(63, void, png_destroy_info_struct, (png_const_structrp png_ptr,
+ png_infopp info_ptr_ptr));
+
+/* Free any memory associated with the png_struct and the png_info_structs */
+PNG_EXPORT(64, void, png_destroy_read_struct, (png_structpp png_ptr_ptr,
+ png_infopp info_ptr_ptr, png_infopp end_info_ptr_ptr));
+
+/* Free any memory associated with the png_struct and the png_info_structs */
+PNG_EXPORT(65, void, png_destroy_write_struct, (png_structpp png_ptr_ptr,
+ png_infopp info_ptr_ptr));
+
+/* Set the libpng method of handling chunk CRC errors */
+PNG_EXPORT(66, void, png_set_crc_action, (png_structrp png_ptr, int crit_action,
+ int ancil_action));
+
+/* Values for png_set_crc_action() say how to handle CRC errors in
+ * ancillary and critical chunks, and whether to use the data contained
+ * therein. Note that it is impossible to "discard" data in a critical
+ * chunk. For versions prior to 0.90, the action was always error/quit,
+ * whereas in version 0.90 and later, the action for CRC errors in ancillary
+ * chunks is warn/discard. These values should NOT be changed.
+ *
+ * value action:critical action:ancillary
+ */
+#define PNG_CRC_DEFAULT 0 /* error/quit warn/discard data */
+#define PNG_CRC_ERROR_QUIT 1 /* error/quit error/quit */
+#define PNG_CRC_WARN_DISCARD 2 /* (INVALID) warn/discard data */
+#define PNG_CRC_WARN_USE 3 /* warn/use data warn/use data */
+#define PNG_CRC_QUIET_USE 4 /* quiet/use data quiet/use data */
+#define PNG_CRC_NO_CHANGE 5 /* use current value use current value */
+
+/* These functions give the user control over the scan-line filtering in
+ * libpng and the compression methods used by zlib. These functions are
+ * mainly useful for testing, as the defaults should work with most users.
+ * Those users who are tight on memory or want faster performance at the
+ * expense of compression can modify them. See the compression library
+ * header file (zlib.h) for an explination of the compression functions.
+ */
+
+/* Set the filtering method(s) used by libpng. Currently, the only valid
+ * value for "method" is 0.
+ */
+PNG_EXPORT(67, void, png_set_filter, (png_structrp png_ptr, int method,
+ int filters));
+
+/* Flags for png_set_filter() to say which filters to use. The flags
+ * are chosen so that they don't conflict with real filter types
+ * below, in case they are supplied instead of the #defined constants.
+ * These values should NOT be changed.
+ */
+#define PNG_NO_FILTERS 0x00
+#define PNG_FILTER_NONE 0x08
+#define PNG_FILTER_SUB 0x10
+#define PNG_FILTER_UP 0x20
+#define PNG_FILTER_AVG 0x40
+#define PNG_FILTER_PAETH 0x80
+#define PNG_ALL_FILTERS (PNG_FILTER_NONE | PNG_FILTER_SUB | PNG_FILTER_UP | \
+ PNG_FILTER_AVG | PNG_FILTER_PAETH)
+
+/* Filter values (not flags) - used in pngwrite.c, pngwutil.c for now.
+ * These defines should NOT be changed.
+ */
+#define PNG_FILTER_VALUE_NONE 0
+#define PNG_FILTER_VALUE_SUB 1
+#define PNG_FILTER_VALUE_UP 2
+#define PNG_FILTER_VALUE_AVG 3
+#define PNG_FILTER_VALUE_PAETH 4
+#define PNG_FILTER_VALUE_LAST 5
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED /* EXPERIMENTAL */
+/* The "heuristic_method" is given by one of the PNG_FILTER_HEURISTIC_
+ * defines, either the default (minimum-sum-of-absolute-differences), or
+ * the experimental method (weighted-minimum-sum-of-absolute-differences).
+ *
+ * Weights are factors >= 1.0, indicating how important it is to keep the
+ * filter type consistent between rows. Larger numbers mean the current
+ * filter is that many times as likely to be the same as the "num_weights"
+ * previous filters. This is cumulative for each previous row with a weight.
+ * There needs to be "num_weights" values in "filter_weights", or it can be
+ * NULL if the weights aren't being specified. Weights have no influence on
+ * the selection of the first row filter. Well chosen weights can (in theory)
+ * improve the compression for a given image.
+ *
+ * Costs are factors >= 1.0 indicating the relative decoding costs of a
+ * filter type. Higher costs indicate more decoding expense, and are
+ * therefore less likely to be selected over a filter with lower computational
+ * costs. There needs to be a value in "filter_costs" for each valid filter
+ * type (given by PNG_FILTER_VALUE_LAST), or it can be NULL if you aren't
+ * setting the costs. Costs try to improve the speed of decompression without
+ * unduly increasing the compressed image size.
+ *
+ * A negative weight or cost indicates the default value is to be used, and
+ * values in the range [0.0, 1.0) indicate the value is to remain unchanged.
+ * The default values for both weights and costs are currently 1.0, but may
+ * change if good general weighting/cost heuristics can be found. If both
+ * the weights and costs are set to 1.0, this degenerates the WEIGHTED method
+ * to the UNWEIGHTED method, but with added encoding time/computation.
+ */
+PNG_FP_EXPORT(68, void, png_set_filter_heuristics, (png_structrp png_ptr,
+ int heuristic_method, int num_weights, png_const_doublep filter_weights,
+ png_const_doublep filter_costs))
+PNG_FIXED_EXPORT(209, void, png_set_filter_heuristics_fixed,
+ (png_structrp png_ptr, int heuristic_method, int num_weights,
+ png_const_fixed_point_p filter_weights,
+ png_const_fixed_point_p filter_costs))
+#endif /* PNG_WRITE_WEIGHTED_FILTER_SUPPORTED */
+
+/* Heuristic used for row filter selection. These defines should NOT be
+ * changed.
+ */
+#define PNG_FILTER_HEURISTIC_DEFAULT 0 /* Currently "UNWEIGHTED" */
+#define PNG_FILTER_HEURISTIC_UNWEIGHTED 1 /* Used by libpng < 0.95 */
+#define PNG_FILTER_HEURISTIC_WEIGHTED 2 /* Experimental feature */
+#define PNG_FILTER_HEURISTIC_LAST 3 /* Not a valid value */
+
+#ifdef PNG_WRITE_SUPPORTED
+/* Set the library compression level. Currently, valid values range from
+ * 0 - 9, corresponding directly to the zlib compression levels 0 - 9
+ * (0 - no compression, 9 - "maximal" compression). Note that tests have
+ * shown that zlib compression levels 3-6 usually perform as well as level 9
+ * for PNG images, and do considerably fewer caclulations. In the future,
+ * these values may not correspond directly to the zlib compression levels.
+ */
+PNG_EXPORT(69, void, png_set_compression_level, (png_structrp png_ptr,
+ int level));
+
+PNG_EXPORT(70, void, png_set_compression_mem_level, (png_structrp png_ptr,
+ int mem_level));
+
+PNG_EXPORT(71, void, png_set_compression_strategy, (png_structrp png_ptr,
+ int strategy));
+
+/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a
+ * smaller value of window_bits if it can do so safely.
+ */
+PNG_EXPORT(72, void, png_set_compression_window_bits, (png_structrp png_ptr,
+ int window_bits));
+
+PNG_EXPORT(73, void, png_set_compression_method, (png_structrp png_ptr,
+ int method));
+#endif
+
+#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED
+/* Also set zlib parameters for compressing non-IDAT chunks */
+PNG_EXPORT(222, void, png_set_text_compression_level, (png_structrp png_ptr,
+ int level));
+
+PNG_EXPORT(223, void, png_set_text_compression_mem_level, (png_structrp png_ptr,
+ int mem_level));
+
+PNG_EXPORT(224, void, png_set_text_compression_strategy, (png_structrp png_ptr,
+ int strategy));
+
+/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a
+ * smaller value of window_bits if it can do so safely.
+ */
+PNG_EXPORT(225, void, png_set_text_compression_window_bits,
+ (png_structrp png_ptr, int window_bits));
+
+PNG_EXPORT(226, void, png_set_text_compression_method, (png_structrp png_ptr,
+ int method));
+#endif /* PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED */
+
+/* These next functions are called for input/output, memory, and error
+ * handling. They are in the file pngrio.c, pngwio.c, and pngerror.c,
+ * and call standard C I/O routines such as fread(), fwrite(), and
+ * fprintf(). These functions can be made to use other I/O routines
+ * at run time for those applications that need to handle I/O in a
+ * different manner by calling png_set_???_fn(). See libpng-manual.txt for
+ * more information.
+ */
+
+#ifdef PNG_STDIO_SUPPORTED
+/* Initialize the input/output for the PNG file to the default functions. */
+PNG_EXPORT(74, void, png_init_io, (png_structrp png_ptr, png_FILE_p fp));
+#endif
+
+/* Replace the (error and abort), and warning functions with user
+ * supplied functions. If no messages are to be printed you must still
+ * write and use replacement functions. The replacement error_fn should
+ * still do a longjmp to the last setjmp location if you are using this
+ * method of error handling. If error_fn or warning_fn is NULL, the
+ * default function will be used.
+ */
+
+PNG_EXPORT(75, void, png_set_error_fn, (png_structrp png_ptr,
+ png_voidp error_ptr, png_error_ptr error_fn, png_error_ptr warning_fn));
+
+/* Return the user pointer associated with the error functions */
+PNG_EXPORT(76, png_voidp, png_get_error_ptr, (png_const_structrp png_ptr));
+
+/* Replace the default data output functions with a user supplied one(s).
+ * If buffered output is not used, then output_flush_fn can be set to NULL.
+ * If PNG_WRITE_FLUSH_SUPPORTED is not defined at libpng compile time
+ * output_flush_fn will be ignored (and thus can be NULL).
+ * It is probably a mistake to use NULL for output_flush_fn if
+ * write_data_fn is not also NULL unless you have built libpng with
+ * PNG_WRITE_FLUSH_SUPPORTED undefined, because in this case libpng's
+ * default flush function, which uses the standard *FILE structure, will
+ * be used.
+ */
+PNG_EXPORT(77, void, png_set_write_fn, (png_structrp png_ptr, png_voidp io_ptr,
+ png_rw_ptr write_data_fn, png_flush_ptr output_flush_fn));
+
+/* Replace the default data input function with a user supplied one. */
+PNG_EXPORT(78, void, png_set_read_fn, (png_structrp png_ptr, png_voidp io_ptr,
+ png_rw_ptr read_data_fn));
+
+/* Return the user pointer associated with the I/O functions */
+PNG_EXPORT(79, png_voidp, png_get_io_ptr, (png_const_structrp png_ptr));
+
+PNG_EXPORT(80, void, png_set_read_status_fn, (png_structrp png_ptr,
+ png_read_status_ptr read_row_fn));
+
+PNG_EXPORT(81, void, png_set_write_status_fn, (png_structrp png_ptr,
+ png_write_status_ptr write_row_fn));
+
+#ifdef PNG_USER_MEM_SUPPORTED
+/* Replace the default memory allocation functions with user supplied one(s). */
+PNG_EXPORT(82, void, png_set_mem_fn, (png_structrp png_ptr, png_voidp mem_ptr,
+ png_malloc_ptr malloc_fn, png_free_ptr free_fn));
+/* Return the user pointer associated with the memory functions */
+PNG_EXPORT(83, png_voidp, png_get_mem_ptr, (png_const_structrp png_ptr));
+#endif
+
+#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED
+PNG_EXPORT(84, void, png_set_read_user_transform_fn, (png_structrp png_ptr,
+ png_user_transform_ptr read_user_transform_fn));
+#endif
+
+#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED
+PNG_EXPORT(85, void, png_set_write_user_transform_fn, (png_structrp png_ptr,
+ png_user_transform_ptr write_user_transform_fn));
+#endif
+
+#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED
+PNG_EXPORT(86, void, png_set_user_transform_info, (png_structrp png_ptr,
+ png_voidp user_transform_ptr, int user_transform_depth,
+ int user_transform_channels));
+/* Return the user pointer associated with the user transform functions */
+PNG_EXPORT(87, png_voidp, png_get_user_transform_ptr,
+ (png_const_structrp png_ptr));
+#endif
+
+#ifdef PNG_USER_TRANSFORM_INFO_SUPPORTED
+/* Return information about the row currently being processed. Note that these
+ * APIs do not fail but will return unexpected results if called outside a user
+ * transform callback. Also note that when transforming an interlaced image the
+ * row number is the row number within the sub-image of the interlace pass, so
+ * the value will increase to the height of the sub-image (not the full image)
+ * then reset to 0 for the next pass.
+ *
+ * Use PNG_ROW_FROM_PASS_ROW(row, pass) and PNG_COL_FROM_PASS_COL(col, pass) to
+ * find the output pixel (x,y) given an interlaced sub-image pixel
+ * (row,col,pass). (See below for these macros.)
+ */
+PNG_EXPORT(217, png_uint_32, png_get_current_row_number, (png_const_structrp));
+PNG_EXPORT(218, png_byte, png_get_current_pass_number, (png_const_structrp));
+#endif
+
+#ifdef PNG_READ_USER_CHUNKS_SUPPORTED
+/* This callback is called only for *unknown* chunks. If
+ * PNG_HANDLE_AS_UNKNOWN_SUPPORTED is set then it is possible to set known
+ * chunks to be treated as unknown, however in this case the callback must do
+ * any processing required by the chunk (e.g. by calling the appropriate
+ * png_set_ APIs.)
+ *
+ * There is no write support - on write, by default, all the chunks in the
+ * 'unknown' list are written in the specified position.
+ *
+ * The integer return from the callback function is interpreted thus:
+ *
+ * negative: An error occured, png_chunk_error will be called.
+ * zero: The chunk was not handled, the chunk will be saved. A critical
+ * chunk will cause an error at this point unless it is to be saved.
+ * positive: The chunk was handled, libpng will ignore/discard it.
+ *
+ * See "INTERACTION WTIH USER CHUNK CALLBACKS" below for important notes about
+ * how this behavior will change in libpng 1.7
+ */
+PNG_EXPORT(88, void, png_set_read_user_chunk_fn, (png_structrp png_ptr,
+ png_voidp user_chunk_ptr, png_user_chunk_ptr read_user_chunk_fn));
+#endif
+
+#ifdef PNG_USER_CHUNKS_SUPPORTED
+PNG_EXPORT(89, png_voidp, png_get_user_chunk_ptr, (png_const_structrp png_ptr));
+#endif
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+/* Sets the function callbacks for the push reader, and a pointer to a
+ * user-defined structure available to the callback functions.
+ */
+PNG_EXPORT(90, void, png_set_progressive_read_fn, (png_structrp png_ptr,
+ png_voidp progressive_ptr, png_progressive_info_ptr info_fn,
+ png_progressive_row_ptr row_fn, png_progressive_end_ptr end_fn));
+
+/* Returns the user pointer associated with the push read functions */
+PNG_EXPORT(91, png_voidp, png_get_progressive_ptr,
+ (png_const_structrp png_ptr));
+
+/* Function to be called when data becomes available */
+PNG_EXPORT(92, void, png_process_data, (png_structrp png_ptr,
+ png_inforp info_ptr, png_bytep buffer, png_size_t buffer_size));
+
+/* A function which may be called *only* within png_process_data to stop the
+ * processing of any more data. The function returns the number of bytes
+ * remaining, excluding any that libpng has cached internally. A subsequent
+ * call to png_process_data must supply these bytes again. If the argument
+ * 'save' is set to true the routine will first save all the pending data and
+ * will always return 0.
+ */
+PNG_EXPORT(219, png_size_t, png_process_data_pause, (png_structrp, int save));
+
+/* A function which may be called *only* outside (after) a call to
+ * png_process_data. It returns the number of bytes of data to skip in the
+ * input. Normally it will return 0, but if it returns a non-zero value the
+ * application must skip than number of bytes of input data and pass the
+ * following data to the next call to png_process_data.
+ */
+PNG_EXPORT(220, png_uint_32, png_process_data_skip, (png_structrp));
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+/* Function that combines rows. 'new_row' is a flag that should come from
+ * the callback and be non-NULL if anything needs to be done; the library
+ * stores its own version of the new data internally and ignores the passed
+ * in value.
+ */
+PNG_EXPORT(93, void, png_progressive_combine_row, (png_const_structrp png_ptr,
+ png_bytep old_row, png_const_bytep new_row));
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */
+
+PNG_EXPORTA(94, png_voidp, png_malloc, (png_const_structrp png_ptr,
+ png_alloc_size_t size), PNG_ALLOCATED);
+/* Added at libpng version 1.4.0 */
+PNG_EXPORTA(95, png_voidp, png_calloc, (png_const_structrp png_ptr,
+ png_alloc_size_t size), PNG_ALLOCATED);
+
+/* Added at libpng version 1.2.4 */
+PNG_EXPORTA(96, png_voidp, png_malloc_warn, (png_const_structrp png_ptr,
+ png_alloc_size_t size), PNG_ALLOCATED);
+
+/* Frees a pointer allocated by png_malloc() */
+PNG_EXPORT(97, void, png_free, (png_const_structrp png_ptr, png_voidp ptr));
+
+/* Free data that was allocated internally */
+PNG_EXPORT(98, void, png_free_data, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 free_me, int num));
+
+/* Reassign responsibility for freeing existing data, whether allocated
+ * by libpng or by the application; this works on the png_info structure passed
+ * in, it does not change the state for other png_info structures.
+ *
+ * It is unlikely that this function works correctly as of 1.6.0 and using it
+ * may result either in memory leaks or double free of allocated data.
+ */
+PNG_EXPORTA(99, void, png_data_freer, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int freer, png_uint_32 mask), PNG_DEPRECATED);
+
+/* Assignments for png_data_freer */
+#define PNG_DESTROY_WILL_FREE_DATA 1
+#define PNG_SET_WILL_FREE_DATA 1
+#define PNG_USER_WILL_FREE_DATA 2
+/* Flags for png_ptr->free_me and info_ptr->free_me */
+#define PNG_FREE_HIST 0x0008
+#define PNG_FREE_ICCP 0x0010
+#define PNG_FREE_SPLT 0x0020
+#define PNG_FREE_ROWS 0x0040
+#define PNG_FREE_PCAL 0x0080
+#define PNG_FREE_SCAL 0x0100
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+# define PNG_FREE_UNKN 0x0200
+#endif
+/* PNG_FREE_LIST 0x0400 removed in 1.6.0 because it is ignored */
+#define PNG_FREE_PLTE 0x1000
+#define PNG_FREE_TRNS 0x2000
+#define PNG_FREE_TEXT 0x4000
+#define PNG_FREE_ALL 0x7fff
+#define PNG_FREE_MUL 0x4220 /* PNG_FREE_SPLT|PNG_FREE_TEXT|PNG_FREE_UNKN */
+
+#ifdef PNG_USER_MEM_SUPPORTED
+PNG_EXPORTA(100, png_voidp, png_malloc_default, (png_const_structrp png_ptr,
+ png_alloc_size_t size), PNG_ALLOCATED PNG_DEPRECATED);
+PNG_EXPORTA(101, void, png_free_default, (png_const_structrp png_ptr,
+ png_voidp ptr), PNG_DEPRECATED);
+#endif
+
+#ifdef PNG_ERROR_TEXT_SUPPORTED
+/* Fatal error in PNG image of libpng - can't continue */
+PNG_EXPORTA(102, void, png_error, (png_const_structrp png_ptr,
+ png_const_charp error_message), PNG_NORETURN);
+
+/* The same, but the chunk name is prepended to the error string. */
+PNG_EXPORTA(103, void, png_chunk_error, (png_const_structrp png_ptr,
+ png_const_charp error_message), PNG_NORETURN);
+
+#else
+/* Fatal error in PNG image of libpng - can't continue */
+PNG_EXPORTA(104, void, png_err, (png_const_structrp png_ptr), PNG_NORETURN);
+#endif
+
+#ifdef PNG_WARNINGS_SUPPORTED
+/* Non-fatal error in libpng. Can continue, but may have a problem. */
+PNG_EXPORT(105, void, png_warning, (png_const_structrp png_ptr,
+ png_const_charp warning_message));
+
+/* Non-fatal error in libpng, chunk name is prepended to message. */
+PNG_EXPORT(106, void, png_chunk_warning, (png_const_structrp png_ptr,
+ png_const_charp warning_message));
+#endif
+
+#ifdef PNG_BENIGN_ERRORS_SUPPORTED
+/* Benign error in libpng. Can continue, but may have a problem.
+ * User can choose whether to handle as a fatal error or as a warning. */
+PNG_EXPORT(107, void, png_benign_error, (png_const_structrp png_ptr,
+ png_const_charp warning_message));
+
+#ifdef PNG_READ_SUPPORTED
+/* Same, chunk name is prepended to message (only during read) */
+PNG_EXPORT(108, void, png_chunk_benign_error, (png_const_structrp png_ptr,
+ png_const_charp warning_message));
+#endif
+
+PNG_EXPORT(109, void, png_set_benign_errors,
+ (png_structrp png_ptr, int allowed));
+#else
+# ifdef PNG_ALLOW_BENIGN_ERRORS
+# define png_benign_error png_warning
+# define png_chunk_benign_error png_chunk_warning
+# else
+# define png_benign_error png_error
+# define png_chunk_benign_error png_chunk_error
+# endif
+#endif
+
+/* The png_set_<chunk> functions are for storing values in the png_info_struct.
+ * Similarly, the png_get_<chunk> calls are used to read values from the
+ * png_info_struct, either storing the parameters in the passed variables, or
+ * setting pointers into the png_info_struct where the data is stored. The
+ * png_get_<chunk> functions return a non-zero value if the data was available
+ * in info_ptr, or return zero and do not change any of the parameters if the
+ * data was not available.
+ *
+ * These functions should be used instead of directly accessing png_info
+ * to avoid problems with future changes in the size and internal layout of
+ * png_info_struct.
+ */
+/* Returns "flag" if chunk data is valid in info_ptr. */
+PNG_EXPORT(110, png_uint_32, png_get_valid, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, png_uint_32 flag));
+
+/* Returns number of bytes needed to hold a transformed row. */
+PNG_EXPORT(111, png_size_t, png_get_rowbytes, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+/* Returns row_pointers, which is an array of pointers to scanlines that was
+ * returned from png_read_png().
+ */
+PNG_EXPORT(112, png_bytepp, png_get_rows, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Set row_pointers, which is an array of pointers to scanlines for use
+ * by png_write_png().
+ */
+PNG_EXPORT(113, void, png_set_rows, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_bytepp row_pointers));
+#endif
+
+/* Returns number of color channels in image. */
+PNG_EXPORT(114, png_byte, png_get_channels, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+#ifdef PNG_EASY_ACCESS_SUPPORTED
+/* Returns image width in pixels. */
+PNG_EXPORT(115, png_uint_32, png_get_image_width, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image height in pixels. */
+PNG_EXPORT(116, png_uint_32, png_get_image_height, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image bit_depth. */
+PNG_EXPORT(117, png_byte, png_get_bit_depth, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image color_type. */
+PNG_EXPORT(118, png_byte, png_get_color_type, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image filter_type. */
+PNG_EXPORT(119, png_byte, png_get_filter_type, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image interlace_type. */
+PNG_EXPORT(120, png_byte, png_get_interlace_type, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image compression_type. */
+PNG_EXPORT(121, png_byte, png_get_compression_type, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+
+/* Returns image resolution in pixels per meter, from pHYs chunk data. */
+PNG_EXPORT(122, png_uint_32, png_get_pixels_per_meter,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(123, png_uint_32, png_get_x_pixels_per_meter,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(124, png_uint_32, png_get_y_pixels_per_meter,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+
+/* Returns pixel aspect ratio, computed from pHYs chunk data. */
+PNG_FP_EXPORT(125, float, png_get_pixel_aspect_ratio,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr))
+PNG_FIXED_EXPORT(210, png_fixed_point, png_get_pixel_aspect_ratio_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr))
+
+/* Returns image x, y offset in pixels or microns, from oFFs chunk data. */
+PNG_EXPORT(126, png_int_32, png_get_x_offset_pixels,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(127, png_int_32, png_get_y_offset_pixels,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(128, png_int_32, png_get_x_offset_microns,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+PNG_EXPORT(129, png_int_32, png_get_y_offset_microns,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+
+#endif /* PNG_EASY_ACCESS_SUPPORTED */
+
+#ifdef PNG_READ_SUPPORTED
+/* Returns pointer to signature string read from PNG header */
+PNG_EXPORT(130, png_const_bytep, png_get_signature, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr));
+#endif
+
+#ifdef PNG_bKGD_SUPPORTED
+PNG_EXPORT(131, png_uint_32, png_get_bKGD, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_color_16p *background));
+#endif
+
+#ifdef PNG_bKGD_SUPPORTED
+PNG_EXPORT(132, void, png_set_bKGD, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_color_16p background));
+#endif
+
+#ifdef PNG_cHRM_SUPPORTED
+PNG_FP_EXPORT(133, png_uint_32, png_get_cHRM, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, double *white_x, double *white_y, double *red_x,
+ double *red_y, double *green_x, double *green_y, double *blue_x,
+ double *blue_y))
+PNG_FP_EXPORT(230, png_uint_32, png_get_cHRM_XYZ, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, double *red_X, double *red_Y, double *red_Z,
+ double *green_X, double *green_Y, double *green_Z, double *blue_X,
+ double *blue_Y, double *blue_Z))
+PNG_FIXED_EXPORT(134, png_uint_32, png_get_cHRM_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *int_white_x, png_fixed_point *int_white_y,
+ png_fixed_point *int_red_x, png_fixed_point *int_red_y,
+ png_fixed_point *int_green_x, png_fixed_point *int_green_y,
+ png_fixed_point *int_blue_x, png_fixed_point *int_blue_y))
+PNG_FIXED_EXPORT(231, png_uint_32, png_get_cHRM_XYZ_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *int_red_X, png_fixed_point *int_red_Y,
+ png_fixed_point *int_red_Z, png_fixed_point *int_green_X,
+ png_fixed_point *int_green_Y, png_fixed_point *int_green_Z,
+ png_fixed_point *int_blue_X, png_fixed_point *int_blue_Y,
+ png_fixed_point *int_blue_Z))
+#endif
+
+#ifdef PNG_cHRM_SUPPORTED
+PNG_FP_EXPORT(135, void, png_set_cHRM, (png_const_structrp png_ptr,
+ png_inforp info_ptr,
+ double white_x, double white_y, double red_x, double red_y, double green_x,
+ double green_y, double blue_x, double blue_y))
+PNG_FP_EXPORT(232, void, png_set_cHRM_XYZ, (png_const_structrp png_ptr,
+ png_inforp info_ptr, double red_X, double red_Y, double red_Z,
+ double green_X, double green_Y, double green_Z, double blue_X,
+ double blue_Y, double blue_Z))
+PNG_FIXED_EXPORT(136, void, png_set_cHRM_fixed, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_fixed_point int_white_x,
+ png_fixed_point int_white_y, png_fixed_point int_red_x,
+ png_fixed_point int_red_y, png_fixed_point int_green_x,
+ png_fixed_point int_green_y, png_fixed_point int_blue_x,
+ png_fixed_point int_blue_y))
+PNG_FIXED_EXPORT(233, void, png_set_cHRM_XYZ_fixed, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_fixed_point int_red_X, png_fixed_point int_red_Y,
+ png_fixed_point int_red_Z, png_fixed_point int_green_X,
+ png_fixed_point int_green_Y, png_fixed_point int_green_Z,
+ png_fixed_point int_blue_X, png_fixed_point int_blue_Y,
+ png_fixed_point int_blue_Z))
+#endif
+
+#ifdef PNG_gAMA_SUPPORTED
+PNG_FP_EXPORT(137, png_uint_32, png_get_gAMA, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, double *file_gamma))
+PNG_FIXED_EXPORT(138, png_uint_32, png_get_gAMA_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *int_file_gamma))
+#endif
+
+#ifdef PNG_gAMA_SUPPORTED
+PNG_FP_EXPORT(139, void, png_set_gAMA, (png_const_structrp png_ptr,
+ png_inforp info_ptr, double file_gamma))
+PNG_FIXED_EXPORT(140, void, png_set_gAMA_fixed, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_fixed_point int_file_gamma))
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+PNG_EXPORT(141, png_uint_32, png_get_hIST, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_uint_16p *hist));
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+PNG_EXPORT(142, void, png_set_hIST, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_uint_16p hist));
+#endif
+
+PNG_EXPORT(143, png_uint_32, png_get_IHDR, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, png_uint_32 *width, png_uint_32 *height,
+ int *bit_depth, int *color_type, int *interlace_method,
+ int *compression_method, int *filter_method));
+
+PNG_EXPORT(144, void, png_set_IHDR, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 width, png_uint_32 height, int bit_depth,
+ int color_type, int interlace_method, int compression_method,
+ int filter_method));
+
+#ifdef PNG_oFFs_SUPPORTED
+PNG_EXPORT(145, png_uint_32, png_get_oFFs, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, png_int_32 *offset_x, png_int_32 *offset_y,
+ int *unit_type));
+#endif
+
+#ifdef PNG_oFFs_SUPPORTED
+PNG_EXPORT(146, void, png_set_oFFs, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_int_32 offset_x, png_int_32 offset_y,
+ int unit_type));
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+PNG_EXPORT(147, png_uint_32, png_get_pCAL, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_charp *purpose, png_int_32 *X0,
+ png_int_32 *X1, int *type, int *nparams, png_charp *units,
+ png_charpp *params));
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+PNG_EXPORT(148, void, png_set_pCAL, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_charp purpose, png_int_32 X0, png_int_32 X1,
+ int type, int nparams, png_const_charp units, png_charpp params));
+#endif
+
+#ifdef PNG_pHYs_SUPPORTED
+PNG_EXPORT(149, png_uint_32, png_get_pHYs, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, png_uint_32 *res_x, png_uint_32 *res_y,
+ int *unit_type));
+#endif
+
+#ifdef PNG_pHYs_SUPPORTED
+PNG_EXPORT(150, void, png_set_pHYs, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 res_x, png_uint_32 res_y, int unit_type));
+#endif
+
+PNG_EXPORT(151, png_uint_32, png_get_PLTE, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_colorp *palette, int *num_palette));
+
+PNG_EXPORT(152, void, png_set_PLTE, (png_structrp png_ptr,
+ png_inforp info_ptr, png_const_colorp palette, int num_palette));
+
+#ifdef PNG_sBIT_SUPPORTED
+PNG_EXPORT(153, png_uint_32, png_get_sBIT, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_color_8p *sig_bit));
+#endif
+
+#ifdef PNG_sBIT_SUPPORTED
+PNG_EXPORT(154, void, png_set_sBIT, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_color_8p sig_bit));
+#endif
+
+#ifdef PNG_sRGB_SUPPORTED
+PNG_EXPORT(155, png_uint_32, png_get_sRGB, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, int *file_srgb_intent));
+#endif
+
+#ifdef PNG_sRGB_SUPPORTED
+PNG_EXPORT(156, void, png_set_sRGB, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int srgb_intent));
+PNG_EXPORT(157, void, png_set_sRGB_gAMA_and_cHRM, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int srgb_intent));
+#endif
+
+#ifdef PNG_iCCP_SUPPORTED
+PNG_EXPORT(158, png_uint_32, png_get_iCCP, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_charpp name, int *compression_type,
+ png_bytepp profile, png_uint_32 *proflen));
+#endif
+
+#ifdef PNG_iCCP_SUPPORTED
+PNG_EXPORT(159, void, png_set_iCCP, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_charp name, int compression_type,
+ png_const_bytep profile, png_uint_32 proflen));
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+PNG_EXPORT(160, int, png_get_sPLT, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_sPLT_tpp entries));
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+PNG_EXPORT(161, void, png_set_sPLT, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_sPLT_tp entries, int nentries));
+#endif
+
+#ifdef PNG_TEXT_SUPPORTED
+/* png_get_text also returns the number of text chunks in *num_text */
+PNG_EXPORT(162, int, png_get_text, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_textp *text_ptr, int *num_text));
+#endif
+
+/* Note while png_set_text() will accept a structure whose text,
+ * language, and translated keywords are NULL pointers, the structure
+ * returned by png_get_text will always contain regular
+ * zero-terminated C strings. They might be empty strings but
+ * they will never be NULL pointers.
+ */
+
+#ifdef PNG_TEXT_SUPPORTED
+PNG_EXPORT(163, void, png_set_text, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_textp text_ptr, int num_text));
+#endif
+
+#ifdef PNG_tIME_SUPPORTED
+PNG_EXPORT(164, png_uint_32, png_get_tIME, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_timep *mod_time));
+#endif
+
+#ifdef PNG_tIME_SUPPORTED
+PNG_EXPORT(165, void, png_set_tIME, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_timep mod_time));
+#endif
+
+#ifdef PNG_tRNS_SUPPORTED
+PNG_EXPORT(166, png_uint_32, png_get_tRNS, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_bytep *trans_alpha, int *num_trans,
+ png_color_16p *trans_color));
+#endif
+
+#ifdef PNG_tRNS_SUPPORTED
+PNG_EXPORT(167, void, png_set_tRNS, (png_structrp png_ptr,
+ png_inforp info_ptr, png_const_bytep trans_alpha, int num_trans,
+ png_const_color_16p trans_color));
+#endif
+
+#ifdef PNG_sCAL_SUPPORTED
+PNG_FP_EXPORT(168, png_uint_32, png_get_sCAL, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, int *unit, double *width, double *height))
+#if defined(PNG_FLOATING_ARITHMETIC_SUPPORTED) || \
+ defined(PNG_FLOATING_POINT_SUPPORTED)
+/* NOTE: this API is currently implemented using floating point arithmetic,
+ * consequently it can only be used on systems with floating point support.
+ * In any case the range of values supported by png_fixed_point is small and it
+ * is highly recommended that png_get_sCAL_s be used instead.
+ */
+PNG_FIXED_EXPORT(214, png_uint_32, png_get_sCAL_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr, int *unit,
+ png_fixed_point *width, png_fixed_point *height))
+#endif
+PNG_EXPORT(169, png_uint_32, png_get_sCAL_s,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr, int *unit,
+ png_charpp swidth, png_charpp sheight));
+
+PNG_FP_EXPORT(170, void, png_set_sCAL, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int unit, double width, double height))
+PNG_FIXED_EXPORT(213, void, png_set_sCAL_fixed, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int unit, png_fixed_point width,
+ png_fixed_point height))
+PNG_EXPORT(171, void, png_set_sCAL_s, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int unit,
+ png_const_charp swidth, png_const_charp sheight));
+#endif /* PNG_sCAL_SUPPORTED */
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+/* Provide the default handling for all unknown chunks or, optionally, for
+ * specific unknown chunks.
+ *
+ * NOTE: prior to 1.6.0 the handling specified for particular chunks on read was
+ * ignored and the default was used, the per-chunk setting only had an effect on
+ * write. If you wish to have chunk-specific handling on read in code that must
+ * work on earlier versions you must use a user chunk callback to specify the
+ * desired handling (keep or discard.)
+ *
+ * The 'keep' parameter is a PNG_HANDLE_CHUNK_ value as listed below. The
+ * parameter is interpreted as follows:
+ *
+ * READ:
+ * PNG_HANDLE_CHUNK_AS_DEFAULT:
+ * Known chunks: do normal libpng processing, do not keep the chunk (but
+ * see the comments below about PNG_HANDLE_AS_UNKNOWN_SUPPORTED)
+ * Unknown chunks: for a specific chunk use the global default, when used
+ * as the default discard the chunk data.
+ * PNG_HANDLE_CHUNK_NEVER:
+ * Discard the chunk data.
+ * PNG_HANDLE_CHUNK_IF_SAFE:
+ * Keep the chunk data if the chunk is not critical else raise a chunk
+ * error.
+ * PNG_HANDLE_CHUNK_ALWAYS:
+ * Keep the chunk data.
+ *
+ * If the chunk data is saved it can be retrieved using png_get_unknown_chunks,
+ * below. Notice that specifying "AS_DEFAULT" as a global default is equivalent
+ * to specifying "NEVER", however when "AS_DEFAULT" is used for specific chunks
+ * it simply resets the behavior to the libpng default.
+ *
+ * INTERACTION WTIH USER CHUNK CALLBACKS:
+ * The per-chunk handling is always used when there is a png_user_chunk_ptr
+ * callback and the callback returns 0; the chunk is then always stored *unless*
+ * it is critical and the per-chunk setting is other than ALWAYS. Notice that
+ * the global default is *not* used in this case. (In effect the per-chunk
+ * value is incremented to at least IF_SAFE.)
+ *
+ * IMPORTANT NOTE: this behavior will change in libpng 1.7 - the global and
+ * per-chunk defaults will be honored. If you want to preserve the current
+ * behavior when your callback returns 0 you must set PNG_HANDLE_CHUNK_IF_SAFE
+ * as the default - if you don't do this libpng 1.6 will issue a warning.
+ *
+ * If you want unhandled unknown chunks to be discarded in libpng 1.6 and
+ * earlier simply return '1' (handled).
+ *
+ * PNG_HANDLE_AS_UNKNOWN_SUPPORTED:
+ * If this is *not* set known chunks will always be handled by libpng and
+ * will never be stored in the unknown chunk list. Known chunks listed to
+ * png_set_keep_unknown_chunks will have no effect. If it is set then known
+ * chunks listed with a keep other than AS_DEFAULT will *never* be processed
+ * by libpng, in addition critical chunks must either be processed by the
+ * callback or saved.
+ *
+ * The IHDR and IEND chunks must not be listed. Because this turns off the
+ * default handling for chunks that would otherwise be recognized the
+ * behavior of libpng transformations may well become incorrect!
+ *
+ * WRITE:
+ * When writing chunks the options only apply to the chunks specified by
+ * png_set_unknown_chunks (below), libpng will *always* write known chunks
+ * required by png_set_ calls and will always write the core critical chunks
+ * (as required for PLTE).
+ *
+ * Each chunk in the png_set_unknown_chunks list is looked up in the
+ * png_set_keep_unknown_chunks list to find the keep setting, this is then
+ * interpreted as follows:
+ *
+ * PNG_HANDLE_CHUNK_AS_DEFAULT:
+ * Write safe-to-copy chunks and write other chunks if the global
+ * default is set to _ALWAYS, otherwise don't write this chunk.
+ * PNG_HANDLE_CHUNK_NEVER:
+ * Do not write the chunk.
+ * PNG_HANDLE_CHUNK_IF_SAFE:
+ * Write the chunk if it is safe-to-copy, otherwise do not write it.
+ * PNG_HANDLE_CHUNK_ALWAYS:
+ * Write the chunk.
+ *
+ * Note that the default behavior is effectively the opposite of the read case -
+ * in read unknown chunks are not stored by default, in write they are written
+ * by default. Also the behavior of PNG_HANDLE_CHUNK_IF_SAFE is very different
+ * - on write the safe-to-copy bit is checked, on read the critical bit is
+ * checked and on read if the chunk is critical an error will be raised.
+ *
+ * num_chunks:
+ * ===========
+ * If num_chunks is positive, then the "keep" parameter specifies the manner
+ * for handling only those chunks appearing in the chunk_list array,
+ * otherwise the chunk list array is ignored.
+ *
+ * If num_chunks is 0 the "keep" parameter specifies the default behavior for
+ * unknown chunks, as described above.
+ *
+ * If num_chunks is negative, then the "keep" parameter specifies the manner
+ * for handling all unknown chunks plus all chunks recognized by libpng
+ * except for the IHDR, PLTE, tRNS, IDAT, and IEND chunks (which continue to
+ * be processed by libpng.
+ */
+PNG_EXPORT(172, void, png_set_keep_unknown_chunks, (png_structrp png_ptr,
+ int keep, png_const_bytep chunk_list, int num_chunks));
+
+/* The "keep" PNG_HANDLE_CHUNK_ parameter for the specified chunk is returned;
+ * the result is therefore true (non-zero) if special handling is required,
+ * false for the default handling.
+ */
+PNG_EXPORT(173, int, png_handle_as_unknown, (png_const_structrp png_ptr,
+ png_const_bytep chunk_name));
+#endif
+
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+PNG_EXPORT(174, void, png_set_unknown_chunks, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_unknown_chunkp unknowns,
+ int num_unknowns));
+ /* NOTE: prior to 1.6.0 this routine set the 'location' field of the added
+ * unknowns to the location currently stored in the png_struct. This is
+ * invariably the wrong value on write. To fix this call the following API
+ * for each chunk in the list with the correct location. If you know your
+ * code won't be compiled on earlier versions you can rely on
+ * png_set_unknown_chunks(write-ptr, png_get_unknown_chunks(read-ptr)) doing
+ * the correct thing.
+ */
+
+PNG_EXPORT(175, void, png_set_unknown_chunk_location,
+ (png_const_structrp png_ptr, png_inforp info_ptr, int chunk, int location));
+
+PNG_EXPORT(176, int, png_get_unknown_chunks, (png_const_structrp png_ptr,
+ png_inforp info_ptr, png_unknown_chunkpp entries));
+#endif
+
+/* Png_free_data() will turn off the "valid" flag for anything it frees.
+ * If you need to turn it off for a chunk that your application has freed,
+ * you can use png_set_invalid(png_ptr, info_ptr, PNG_INFO_CHNK);
+ */
+PNG_EXPORT(177, void, png_set_invalid, (png_const_structrp png_ptr,
+ png_inforp info_ptr, int mask));
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+/* The "params" pointer is currently not used and is for future expansion. */
+PNG_EXPORT(178, void, png_read_png, (png_structrp png_ptr, png_inforp info_ptr,
+ int transforms, png_voidp params));
+PNG_EXPORT(179, void, png_write_png, (png_structrp png_ptr, png_inforp info_ptr,
+ int transforms, png_voidp params));
+#endif
+
+PNG_EXPORT(180, png_const_charp, png_get_copyright,
+ (png_const_structrp png_ptr));
+PNG_EXPORT(181, png_const_charp, png_get_header_ver,
+ (png_const_structrp png_ptr));
+PNG_EXPORT(182, png_const_charp, png_get_header_version,
+ (png_const_structrp png_ptr));
+PNG_EXPORT(183, png_const_charp, png_get_libpng_ver,
+ (png_const_structrp png_ptr));
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+PNG_EXPORT(184, png_uint_32, png_permit_mng_features, (png_structrp png_ptr,
+ png_uint_32 mng_features_permitted));
+#endif
+
+/* For use in png_set_keep_unknown, added to version 1.2.6 */
+#define PNG_HANDLE_CHUNK_AS_DEFAULT 0
+#define PNG_HANDLE_CHUNK_NEVER 1
+#define PNG_HANDLE_CHUNK_IF_SAFE 2
+#define PNG_HANDLE_CHUNK_ALWAYS 3
+#define PNG_HANDLE_CHUNK_LAST 4
+
+/* Strip the prepended error numbers ("#nnn ") from error and warning
+ * messages before passing them to the error or warning handler.
+ */
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+PNG_EXPORT(185, void, png_set_strip_error_numbers, (png_structrp png_ptr,
+ png_uint_32 strip_mode));
+#endif
+
+/* Added in libpng-1.2.6 */
+#ifdef PNG_SET_USER_LIMITS_SUPPORTED
+PNG_EXPORT(186, void, png_set_user_limits, (png_structrp png_ptr,
+ png_uint_32 user_width_max, png_uint_32 user_height_max));
+PNG_EXPORT(187, png_uint_32, png_get_user_width_max,
+ (png_const_structrp png_ptr));
+PNG_EXPORT(188, png_uint_32, png_get_user_height_max,
+ (png_const_structrp png_ptr));
+/* Added in libpng-1.4.0 */
+PNG_EXPORT(189, void, png_set_chunk_cache_max, (png_structrp png_ptr,
+ png_uint_32 user_chunk_cache_max));
+PNG_EXPORT(190, png_uint_32, png_get_chunk_cache_max,
+ (png_const_structrp png_ptr));
+/* Added in libpng-1.4.1 */
+PNG_EXPORT(191, void, png_set_chunk_malloc_max, (png_structrp png_ptr,
+ png_alloc_size_t user_chunk_cache_max));
+PNG_EXPORT(192, png_alloc_size_t, png_get_chunk_malloc_max,
+ (png_const_structrp png_ptr));
+#endif
+
+#if defined(PNG_INCH_CONVERSIONS_SUPPORTED)
+PNG_EXPORT(193, png_uint_32, png_get_pixels_per_inch,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+
+PNG_EXPORT(194, png_uint_32, png_get_x_pixels_per_inch,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+
+PNG_EXPORT(195, png_uint_32, png_get_y_pixels_per_inch,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr));
+
+PNG_FP_EXPORT(196, float, png_get_x_offset_inches,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr))
+#ifdef PNG_FIXED_POINT_SUPPORTED /* otherwise not implemented. */
+PNG_FIXED_EXPORT(211, png_fixed_point, png_get_x_offset_inches_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr))
+#endif
+
+PNG_FP_EXPORT(197, float, png_get_y_offset_inches, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr))
+#ifdef PNG_FIXED_POINT_SUPPORTED /* otherwise not implemented. */
+PNG_FIXED_EXPORT(212, png_fixed_point, png_get_y_offset_inches_fixed,
+ (png_const_structrp png_ptr, png_const_inforp info_ptr))
+#endif
+
+# ifdef PNG_pHYs_SUPPORTED
+PNG_EXPORT(198, png_uint_32, png_get_pHYs_dpi, (png_const_structrp png_ptr,
+ png_const_inforp info_ptr, png_uint_32 *res_x, png_uint_32 *res_y,
+ int *unit_type));
+# endif /* PNG_pHYs_SUPPORTED */
+#endif /* PNG_INCH_CONVERSIONS_SUPPORTED */
+
+/* Added in libpng-1.4.0 */
+#ifdef PNG_IO_STATE_SUPPORTED
+PNG_EXPORT(199, png_uint_32, png_get_io_state, (png_const_structrp png_ptr));
+
+/* Removed from libpng 1.6; use png_get_io_chunk_type. */
+PNG_REMOVED(200, png_const_bytep, png_get_io_chunk_name, (png_structrp png_ptr),
+ PNG_DEPRECATED)
+
+PNG_EXPORT(216, png_uint_32, png_get_io_chunk_type,
+ (png_const_structrp png_ptr));
+
+/* The flags returned by png_get_io_state() are the following: */
+# define PNG_IO_NONE 0x0000 /* no I/O at this moment */
+# define PNG_IO_READING 0x0001 /* currently reading */
+# define PNG_IO_WRITING 0x0002 /* currently writing */
+# define PNG_IO_SIGNATURE 0x0010 /* currently at the file signature */
+# define PNG_IO_CHUNK_HDR 0x0020 /* currently at the chunk header */
+# define PNG_IO_CHUNK_DATA 0x0040 /* currently at the chunk data */
+# define PNG_IO_CHUNK_CRC 0x0080 /* currently at the chunk crc */
+# define PNG_IO_MASK_OP 0x000f /* current operation: reading/writing */
+# define PNG_IO_MASK_LOC 0x00f0 /* current location: sig/hdr/data/crc */
+#endif /* ?PNG_IO_STATE_SUPPORTED */
+
+/* Interlace support. The following macros are always defined so that if
+ * libpng interlace handling is turned off the macros may be used to handle
+ * interlaced images within the application.
+ */
+#define PNG_INTERLACE_ADAM7_PASSES 7
+
+/* Two macros to return the first row and first column of the original,
+ * full, image which appears in a given pass. 'pass' is in the range 0
+ * to 6 and the result is in the range 0 to 7.
+ */
+#define PNG_PASS_START_ROW(pass) (((1&~(pass))<<(3-((pass)>>1)))&7)
+#define PNG_PASS_START_COL(pass) (((1& (pass))<<(3-(((pass)+1)>>1)))&7)
+
+/* A macro to return the offset between pixels in the output row for a pair of
+ * pixels in the input - effectively the inverse of the 'COL_SHIFT' macro that
+ * follows. Note that ROW_OFFSET is the offset from one row to the next whereas
+ * COL_OFFSET is from one column to the next, within a row.
+ */
+#define PNG_PASS_ROW_OFFSET(pass) ((pass)>2?(8>>(((pass)-1)>>1)):8)
+#define PNG_PASS_COL_OFFSET(pass) (1<<((7-(pass))>>1))
+
+/* Two macros to help evaluate the number of rows or columns in each
+ * pass. This is expressed as a shift - effectively log2 of the number or
+ * rows or columns in each 8x8 tile of the original image.
+ */
+#define PNG_PASS_ROW_SHIFT(pass) ((pass)>2?(8-(pass))>>1:3)
+#define PNG_PASS_COL_SHIFT(pass) ((pass)>1?(7-(pass))>>1:3)
+
+/* Hence two macros to determine the number of rows or columns in a given
+ * pass of an image given its height or width. In fact these macros may
+ * return non-zero even though the sub-image is empty, because the other
+ * dimension may be empty for a small image.
+ */
+#define PNG_PASS_ROWS(height, pass) (((height)+(((1<<PNG_PASS_ROW_SHIFT(pass))\
+ -1)-PNG_PASS_START_ROW(pass)))>>PNG_PASS_ROW_SHIFT(pass))
+#define PNG_PASS_COLS(width, pass) (((width)+(((1<<PNG_PASS_COL_SHIFT(pass))\
+ -1)-PNG_PASS_START_COL(pass)))>>PNG_PASS_COL_SHIFT(pass))
+
+/* For the reader row callbacks (both progressive and sequential) it is
+ * necessary to find the row in the output image given a row in an interlaced
+ * image, so two more macros:
+ */
+#define PNG_ROW_FROM_PASS_ROW(y_in, pass) \
+ (((y_in)<<PNG_PASS_ROW_SHIFT(pass))+PNG_PASS_START_ROW(pass))
+#define PNG_COL_FROM_PASS_COL(x_in, pass) \
+ (((x_in)<<PNG_PASS_COL_SHIFT(pass))+PNG_PASS_START_COL(pass))
+
+/* Two macros which return a boolean (0 or 1) saying whether the given row
+ * or column is in a particular pass. These use a common utility macro that
+ * returns a mask for a given pass - the offset 'off' selects the row or
+ * column version. The mask has the appropriate bit set for each column in
+ * the tile.
+ */
+#define PNG_PASS_MASK(pass,off) ( \
+ ((0x110145AF>>(((7-(off))-(pass))<<2)) & 0xF) | \
+ ((0x01145AF0>>(((7-(off))-(pass))<<2)) & 0xF0))
+
+#define PNG_ROW_IN_INTERLACE_PASS(y, pass) \
+ ((PNG_PASS_MASK(pass,0) >> ((y)&7)) & 1)
+#define PNG_COL_IN_INTERLACE_PASS(x, pass) \
+ ((PNG_PASS_MASK(pass,1) >> ((x)&7)) & 1)
+
+#ifdef PNG_READ_COMPOSITE_NODIV_SUPPORTED
+/* With these routines we avoid an integer divide, which will be slower on
+ * most machines. However, it does take more operations than the corresponding
+ * divide method, so it may be slower on a few RISC systems. There are two
+ * shifts (by 8 or 16 bits) and an addition, versus a single integer divide.
+ *
+ * Note that the rounding factors are NOT supposed to be the same! 128 and
+ * 32768 are correct for the NODIV code; 127 and 32767 are correct for the
+ * standard method.
+ *
+ * [Optimized code by Greg Roelofs and Mark Adler...blame us for bugs. :-) ]
+ */
+
+ /* fg and bg should be in `gamma 1.0' space; alpha is the opacity */
+
+# define png_composite(composite, fg, alpha, bg) \
+ { png_uint_16 temp = (png_uint_16)((png_uint_16)(fg) \
+ * (png_uint_16)(alpha) \
+ + (png_uint_16)(bg)*(png_uint_16)(255 \
+ - (png_uint_16)(alpha)) + 128); \
+ (composite) = (png_byte)((temp + (temp >> 8)) >> 8); }
+
+# define png_composite_16(composite, fg, alpha, bg) \
+ { png_uint_32 temp = (png_uint_32)((png_uint_32)(fg) \
+ * (png_uint_32)(alpha) \
+ + (png_uint_32)(bg)*(65535 \
+ - (png_uint_32)(alpha)) + 32768); \
+ (composite) = (png_uint_16)((temp + (temp >> 16)) >> 16); }
+
+#else /* Standard method using integer division */
+
+# define png_composite(composite, fg, alpha, bg) \
+ (composite) = (png_byte)(((png_uint_16)(fg) * (png_uint_16)(alpha) + \
+ (png_uint_16)(bg) * (png_uint_16)(255 - (png_uint_16)(alpha)) + \
+ 127) / 255)
+
+# define png_composite_16(composite, fg, alpha, bg) \
+ (composite) = (png_uint_16)(((png_uint_32)(fg) * (png_uint_32)(alpha) + \
+ (png_uint_32)(bg)*(png_uint_32)(65535 - (png_uint_32)(alpha)) + \
+ 32767) / 65535)
+#endif /* PNG_READ_COMPOSITE_NODIV_SUPPORTED */
+
+#ifdef PNG_READ_INT_FUNCTIONS_SUPPORTED
+PNG_EXPORT(201, png_uint_32, png_get_uint_32, (png_const_bytep buf));
+PNG_EXPORT(202, png_uint_16, png_get_uint_16, (png_const_bytep buf));
+PNG_EXPORT(203, png_int_32, png_get_int_32, (png_const_bytep buf));
+#endif
+
+PNG_EXPORT(204, png_uint_32, png_get_uint_31, (png_const_structrp png_ptr,
+ png_const_bytep buf));
+/* No png_get_int_16 -- may be added if there's a real need for it. */
+
+/* Place a 32-bit number into a buffer in PNG byte order (big-endian). */
+#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED
+PNG_EXPORT(205, void, png_save_uint_32, (png_bytep buf, png_uint_32 i));
+#endif
+#ifdef PNG_SAVE_INT_32_SUPPORTED
+PNG_EXPORT(206, void, png_save_int_32, (png_bytep buf, png_int_32 i));
+#endif
+
+/* Place a 16-bit number into a buffer in PNG byte order.
+ * The parameter is declared unsigned int, not png_uint_16,
+ * just to avoid potential problems on pre-ANSI C compilers.
+ */
+#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED
+PNG_EXPORT(207, void, png_save_uint_16, (png_bytep buf, unsigned int i));
+/* No png_save_int_16 -- may be added if there's a real need for it. */
+#endif
+
+#ifdef PNG_USE_READ_MACROS
+/* Inline macros to do direct reads of bytes from the input buffer.
+ * The png_get_int_32() routine assumes we are using two's complement
+ * format for negative values, which is almost certainly true.
+ */
+# define PNG_get_uint_32(buf) \
+ (((png_uint_32)(*(buf)) << 24) + \
+ ((png_uint_32)(*((buf) + 1)) << 16) + \
+ ((png_uint_32)(*((buf) + 2)) << 8) + \
+ ((png_uint_32)(*((buf) + 3))))
+
+ /* From libpng-1.4.0 until 1.4.4, the png_get_uint_16 macro (but not the
+ * function) incorrectly returned a value of type png_uint_32.
+ */
+# define PNG_get_uint_16(buf) \
+ ((png_uint_16) \
+ (((unsigned int)(*(buf)) << 8) + \
+ ((unsigned int)(*((buf) + 1)))))
+
+# define PNG_get_int_32(buf) \
+ ((png_int_32)((*(buf) & 0x80) \
+ ? -((png_int_32)((png_get_uint_32(buf) ^ 0xffffffffL) + 1)) \
+ : (png_int_32)png_get_uint_32(buf)))
+
+ /* If PNG_PREFIX is defined the same thing as below happens in pnglibconf.h,
+ * but defining a macro name prefixed with PNG_PREFIX.
+ */
+# ifndef PNG_PREFIX
+# define png_get_uint_32(buf) PNG_get_uint_32(buf)
+# define png_get_uint_16(buf) PNG_get_uint_16(buf)
+# define png_get_int_32(buf) PNG_get_int_32(buf)
+# endif
+#else
+# ifdef PNG_PREFIX
+ /* No macros; revert to the (redefined) function */
+# define PNG_get_uint_32 (png_get_uint_32)
+# define PNG_get_uint_16 (png_get_uint_16)
+# define PNG_get_int_32 (png_get_int_32)
+# endif
+#endif
+
+/*******************************************************************************
+ * SIMPLIFIED API
+ *******************************************************************************
+ *
+ * Please read the documentation in libpng-manual.txt (TODO: write said
+ * documentation) if you don't understand what follows.
+ *
+ * The simplified API hides the details of both libpng and the PNG file format
+ * itself. It allows PNG files to be read into a very limited number of
+ * in-memory bitmap formats or to be written from the same formats. If these
+ * formats do not accomodate your needs then you can, and should, use the more
+ * sophisticated APIs above - these support a wide variety of in-memory formats
+ * and a wide variety of sophisticated transformations to those formats as well
+ * as a wide variety of APIs to manipulate ancillary information.
+ *
+ * To read a PNG file using the simplified API:
+ *
+ * 1) Declare a 'png_image' structure (see below) on the stack and set the
+ * version field to PNG_IMAGE_VERSION.
+ * 2) Call the appropriate png_image_begin_read... function.
+ * 3) Set the png_image 'format' member to the required sample format.
+ * 4) Allocate a buffer for the image and, if required, the color-map.
+ * 5) Call png_image_finish_read to read the image and, if required, the
+ * color-map into your buffers.
+ *
+ * There are no restrictions on the format of the PNG input itself; all valid
+ * color types, bit depths, and interlace methods are acceptable, and the
+ * input image is transformed as necessary to the requested in-memory format
+ * during the png_image_finish_read() step. The only caveat is that if you
+ * request a color-mapped image from a PNG that is full-color or makes
+ * complex use of an alpha channel the transformation is extremely lossy and the
+ * result may look terrible.
+ *
+ * To write a PNG file using the simplified API:
+ *
+ * 1) Declare a 'png_image' structure on the stack and memset() it to all zero.
+ * 2) Initialize the members of the structure that describe the image, setting
+ * the 'format' member to the format of the image samples.
+ * 3) Call the appropriate png_image_write... function with a pointer to the
+ * image and, if necessary, the color-map to write the PNG data.
+ *
+ * png_image is a structure that describes the in-memory format of an image
+ * when it is being read or defines the in-memory format of an image that you
+ * need to write:
+ */
+#define PNG_IMAGE_VERSION 1
+
+typedef struct png_control *png_controlp;
+typedef struct
+{
+ png_controlp opaque; /* Initialize to NULL, free with png_image_free */
+ png_uint_32 version; /* Set to PNG_IMAGE_VERSION */
+ png_uint_32 width; /* Image width in pixels (columns) */
+ png_uint_32 height; /* Image height in pixels (rows) */
+ png_uint_32 format; /* Image format as defined below */
+ png_uint_32 flags; /* A bit mask containing informational flags */
+ png_uint_32 colormap_entries;
+ /* Number of entries in the color-map */
+
+ /* In the event of an error or warning the following field will be set to a
+ * non-zero value and the 'message' field will contain a '\0' terminated
+ * string with the libpng error or warning message. If both warnings and
+ * an error were encountered, only the error is recorded. If there
+ * are multiple warnings, only the first one is recorded.
+ *
+ * The upper 30 bits of this value are reserved, the low two bits contain
+ * a value as follows:
+ */
+# define PNG_IMAGE_WARNING 1
+# define PNG_IMAGE_ERROR 2
+ /*
+ * The result is a two bit code such that a value more than 1 indicates
+ * a failure in the API just called:
+ *
+ * 0 - no warning or error
+ * 1 - warning
+ * 2 - error
+ * 3 - error preceded by warning
+ */
+# define PNG_IMAGE_FAILED(png_cntrl) ((((png_cntrl).warning_or_error)&0x03)>1)
+
+ png_uint_32 warning_or_error;
+
+ char message[64];
+} png_image, *png_imagep;
+
+/* The samples of the image have one to four channels whose components have
+ * original values in the range 0 to 1.0:
+ *
+ * 1: A single gray or luminance channel (G).
+ * 2: A gray/luminance channel and an alpha channel (GA).
+ * 3: Three red, green, blue color channels (RGB).
+ * 4: Three color channels and an alpha channel (RGBA).
+ *
+ * The components are encoded in one of two ways:
+ *
+ * a) As a small integer, value 0..255, contained in a single byte. For the
+ * alpha channel the original value is simply value/255. For the color or
+ * luminance channels the value is encoded according to the sRGB specification
+ * and matches the 8-bit format expected by typical display devices.
+ *
+ * The color/gray channels are not scaled (pre-multiplied) by the alpha
+ * channel and are suitable for passing to color management software.
+ *
+ * b) As a value in the range 0..65535, contained in a 2-byte integer. All
+ * channels can be converted to the original value by dividing by 65535; all
+ * channels are linear. Color channels use the RGB encoding (RGB end-points) of
+ * the sRGB specification. This encoding is identified by the
+ * PNG_FORMAT_FLAG_LINEAR flag below.
+ *
+ * When the simplified API needs to convert between sRGB and linear colorspaces,
+ * the actual sRGB transfer curve defined in the sRGB specification (see the
+ * article at http://en.wikipedia.org/wiki/SRGB) is used, not the gamma=1/2.2
+ * approximation used elsewhere in libpng.
+ *
+ * When an alpha channel is present it is expected to denote pixel coverage
+ * of the color or luminance channels and is returned as an associated alpha
+ * channel: the color/gray channels are scaled (pre-multiplied) by the alpha
+ * value.
+ *
+ * The samples are either contained directly in the image data, between 1 and 8
+ * bytes per pixel according to the encoding, or are held in a color-map indexed
+ * by bytes in the image data. In the case of a color-map the color-map entries
+ * are individual samples, encoded as above, and the image data has one byte per
+ * pixel to select the relevant sample from the color-map.
+ */
+
+/* PNG_FORMAT_*
+ *
+ * #defines to be used in png_image::format. Each #define identifies a
+ * particular layout of sample data and, if present, alpha values. There are
+ * separate defines for each of the two component encodings.
+ *
+ * A format is built up using single bit flag values. All combinations are
+ * valid. Formats can be built up from the flag values or you can use one of
+ * the predefined values below. When testing formats always use the FORMAT_FLAG
+ * macros to test for individual features - future versions of the library may
+ * add new flags.
+ *
+ * When reading or writing color-mapped images the format should be set to the
+ * format of the entries in the color-map then png_image_{read,write}_colormap
+ * called to read or write the color-map and set the format correctly for the
+ * image data. Do not set the PNG_FORMAT_FLAG_COLORMAP bit directly!
+ *
+ * NOTE: libpng can be built with particular features disabled, if you see
+ * compiler errors because the definition of one of the following flags has been
+ * compiled out it is because libpng does not have the required support. It is
+ * possible, however, for the libpng configuration to enable the format on just
+ * read or just write; in that case you may see an error at run time. You can
+ * guard against this by checking for the definition of the appropriate
+ * "_SUPPORTED" macro, one of:
+ *
+ * PNG_SIMPLIFIED_{READ,WRITE}_{BGR,AFIRST}_SUPPORTED
+ */
+#define PNG_FORMAT_FLAG_ALPHA 0x01U /* format with an alpha channel */
+#define PNG_FORMAT_FLAG_COLOR 0x02U /* color format: otherwise grayscale */
+#define PNG_FORMAT_FLAG_LINEAR 0x04U /* 2 byte channels else 1 byte */
+#define PNG_FORMAT_FLAG_COLORMAP 0x08U /* image data is color-mapped */
+
+#ifdef PNG_FORMAT_BGR_SUPPORTED
+# define PNG_FORMAT_FLAG_BGR 0x10U /* BGR colors, else order is RGB */
+#endif
+
+#ifdef PNG_FORMAT_AFIRST_SUPPORTED
+# define PNG_FORMAT_FLAG_AFIRST 0x20U /* alpha channel comes first */
+#endif
+
+/* Commonly used formats have predefined macros.
+ *
+ * First the single byte (sRGB) formats:
+ */
+#define PNG_FORMAT_GRAY 0
+#define PNG_FORMAT_GA PNG_FORMAT_FLAG_ALPHA
+#define PNG_FORMAT_AG (PNG_FORMAT_GA|PNG_FORMAT_FLAG_AFIRST)
+#define PNG_FORMAT_RGB PNG_FORMAT_FLAG_COLOR
+#define PNG_FORMAT_BGR (PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_BGR)
+#define PNG_FORMAT_RGBA (PNG_FORMAT_RGB|PNG_FORMAT_FLAG_ALPHA)
+#define PNG_FORMAT_ARGB (PNG_FORMAT_RGBA|PNG_FORMAT_FLAG_AFIRST)
+#define PNG_FORMAT_BGRA (PNG_FORMAT_BGR|PNG_FORMAT_FLAG_ALPHA)
+#define PNG_FORMAT_ABGR (PNG_FORMAT_BGRA|PNG_FORMAT_FLAG_AFIRST)
+
+/* Then the linear 2-byte formats. When naming these "Y" is used to
+ * indicate a luminance (gray) channel.
+ */
+#define PNG_FORMAT_LINEAR_Y PNG_FORMAT_FLAG_LINEAR
+#define PNG_FORMAT_LINEAR_Y_ALPHA (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_ALPHA)
+#define PNG_FORMAT_LINEAR_RGB (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_COLOR)
+#define PNG_FORMAT_LINEAR_RGB_ALPHA \
+ (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_ALPHA)
+
+/* With color-mapped formats the image data is one byte for each pixel, the byte
+ * is an index into the color-map which is formatted as above. To obtain a
+ * color-mapped format it is sufficient just to add the PNG_FOMAT_FLAG_COLORMAP
+ * to one of the above definitions, or you can use one of the definitions below.
+ */
+#define PNG_FORMAT_RGB_COLORMAP (PNG_FORMAT_RGB|PNG_FORMAT_FLAG_COLORMAP)
+#define PNG_FORMAT_BGR_COLORMAP (PNG_FORMAT_BGR|PNG_FORMAT_FLAG_COLORMAP)
+#define PNG_FORMAT_RGBA_COLORMAP (PNG_FORMAT_RGBA|PNG_FORMAT_FLAG_COLORMAP)
+#define PNG_FORMAT_ARGB_COLORMAP (PNG_FORMAT_ARGB|PNG_FORMAT_FLAG_COLORMAP)
+#define PNG_FORMAT_BGRA_COLORMAP (PNG_FORMAT_BGRA|PNG_FORMAT_FLAG_COLORMAP)
+#define PNG_FORMAT_ABGR_COLORMAP (PNG_FORMAT_ABGR|PNG_FORMAT_FLAG_COLORMAP)
+
+/* PNG_IMAGE macros
+ *
+ * These are convenience macros to derive information from a png_image
+ * structure. The PNG_IMAGE_SAMPLE_ macros return values appropriate to the
+ * actual image sample values - either the entries in the color-map or the
+ * pixels in the image. The PNG_IMAGE_PIXEL_ macros return corresponding values
+ * for the pixels and will always return 1 for color-mapped formats. The
+ * remaining macros return information about the rows in the image and the
+ * complete image.
+ *
+ * NOTE: All the macros that take a png_image::format parameter are compile time
+ * constants if the format parameter is, itself, a constant. Therefore these
+ * macros can be used in array declarations and case labels where required.
+ * Similarly the macros are also pre-processor constants (sizeof is not used) so
+ * they can be used in #if tests.
+ *
+ * First the information about the samples.
+ */
+#define PNG_IMAGE_SAMPLE_CHANNELS(fmt)\
+ (((fmt)&(PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_ALPHA))+1)
+ /* Return the total number of channels in a given format: 1..4 */
+
+#define PNG_IMAGE_SAMPLE_COMPONENT_SIZE(fmt)\
+ ((((fmt) & PNG_FORMAT_FLAG_LINEAR) >> 2)+1)
+ /* Return the size in bytes of a single component of a pixel or color-map
+ * entry (as appropriate) in the image: 1 or 2.
+ */
+
+#define PNG_IMAGE_SAMPLE_SIZE(fmt)\
+ (PNG_IMAGE_SAMPLE_CHANNELS(fmt) * PNG_IMAGE_SAMPLE_COMPONENT_SIZE(fmt))
+ /* This is the size of the sample data for one sample. If the image is
+ * color-mapped it is the size of one color-map entry (and image pixels are
+ * one byte in size), otherwise it is the size of one image pixel.
+ */
+
+#define PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(fmt)\
+ (PNG_IMAGE_SAMPLE_CHANNELS(fmt) * 256)
+ /* The maximum size of the color-map required by the format expressed in a
+ * count of components. This can be used to compile-time allocate a
+ * color-map:
+ *
+ * png_uint_16 colormap[PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(linear_fmt)];
+ *
+ * png_byte colormap[PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(sRGB_fmt)];
+ *
+ * Alternatively use the PNG_IMAGE_COLORMAP_SIZE macro below to use the
+ * information from one of the png_image_begin_read_ APIs and dynamically
+ * allocate the required memory.
+ */
+
+/* Corresponding information about the pixels */
+#define PNG_IMAGE_PIXEL_(test,fmt)\
+ (((fmt)&PNG_FORMAT_FLAG_COLORMAP)?1:test(fmt))
+
+#define PNG_IMAGE_PIXEL_CHANNELS(fmt)\
+ PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_CHANNELS,fmt)
+ /* The number of separate channels (components) in a pixel; 1 for a
+ * color-mapped image.
+ */
+
+#define PNG_IMAGE_PIXEL_COMPONENT_SIZE(fmt)\
+ PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_COMPONENT_SIZE,fmt)
+ /* The size, in bytes, of each component in a pixel; 1 for a color-mapped
+ * image.
+ */
+
+#define PNG_IMAGE_PIXEL_SIZE(fmt) PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_SIZE,fmt)
+ /* The size, in bytes, of a complete pixel; 1 for a color-mapped image. */
+
+/* Information about the whole row, or whole image */
+#define PNG_IMAGE_ROW_STRIDE(image)\
+ (PNG_IMAGE_PIXEL_CHANNELS((image).format) * (image).width)
+ /* Return the total number of components in a single row of the image; this
+ * is the minimum 'row stride', the minimum count of components between each
+ * row. For a color-mapped image this is the minimum number of bytes in a
+ * row.
+ */
+
+#define PNG_IMAGE_BUFFER_SIZE(image, row_stride)\
+ (PNG_IMAGE_PIXEL_COMPONENT_SIZE((image).format)*(image).height*(row_stride))
+ /* Return the size, in bytes, of an image buffer given a png_image and a row
+ * stride - the number of components to leave space for in each row.
+ */
+
+#define PNG_IMAGE_SIZE(image)\
+ PNG_IMAGE_BUFFER_SIZE(image, PNG_IMAGE_ROW_STRIDE(image))
+ /* Return the size, in bytes, of the image in memory given just a png_image;
+ * the row stride is the minimum stride required for the image.
+ */
+
+#define PNG_IMAGE_COLORMAP_SIZE(image)\
+ (PNG_IMAGE_SAMPLE_SIZE((image).format) * (image).colormap_entries)
+ /* Return the size, in bytes, of the color-map of this image. If the image
+ * format is not a color-map format this will return a size sufficient for
+ * 256 entries in the given format; check PNG_FORMAT_FLAG_COLORMAP if
+ * you don't want to allocate a color-map in this case.
+ */
+
+/* PNG_IMAGE_FLAG_*
+ *
+ * Flags containing additional information about the image are held in the
+ * 'flags' field of png_image.
+ */
+#define PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB 0x01
+ /* This indicates the the RGB values of the in-memory bitmap do not
+ * correspond to the red, green and blue end-points defined by sRGB.
+ */
+
+#define PNG_IMAGE_FLAG_FAST 0x02
+ /* On write emphasise speed over compression; the resultant PNG file will be
+ * larger but will be produced significantly faster, particular for large
+ * images. Do not use this option for images which will be distributed, only
+ * used it when producing intermediate files that will be read back in
+ * repeatedly. For a typical 24-bit image the option will double the read
+ * speed at the cost of increasing the image size by 25%, however for many
+ * more compressible images the PNG file can be 10 times larger with only a
+ * slight speed gain.
+ */
+
+#define PNG_IMAGE_FLAG_16BIT_sRGB 0x04
+ /* On read if the image is a 16-bit per component image and there is no gAMA
+ * or sRGB chunk assume that the components are sRGB encoded. Notice that
+ * images output by the simplified API always have gamma information; setting
+ * this flag only affects the interpretation of 16-bit images from an
+ * external source. It is recommended that the application expose this flag
+ * to the user; the user can normally easily recognize the difference between
+ * linear and sRGB encoding. This flag has no effect on write - the data
+ * passed to the write APIs must have the correct encoding (as defined
+ * above.)
+ *
+ * If the flag is not set (the default) input 16-bit per component data is
+ * assumed to be linear.
+ *
+ * NOTE: the flag can only be set after the png_image_begin_read_ call,
+ * because that call initializes the 'flags' field.
+ */
+
+#ifdef PNG_SIMPLIFIED_READ_SUPPORTED
+/* READ APIs
+ * ---------
+ *
+ * The png_image passed to the read APIs must have been initialized by setting
+ * the png_controlp field 'opaque' to NULL (or, safer, memset the whole thing.)
+ */
+#ifdef PNG_STDIO_SUPPORTED
+PNG_EXPORT(234, int, png_image_begin_read_from_file, (png_imagep image,
+ const char *file_name));
+ /* The named file is opened for read and the image header is filled in
+ * from the PNG header in the file.
+ */
+
+PNG_EXPORT(235, int, png_image_begin_read_from_stdio, (png_imagep image,
+ FILE* file));
+ /* The PNG header is read from the stdio FILE object. */
+#endif /* PNG_STDIO_SUPPORTED */
+
+PNG_EXPORT(236, int, png_image_begin_read_from_memory, (png_imagep image,
+ png_const_voidp memory, png_size_t size));
+ /* The PNG header is read from the given memory buffer. */
+
+PNG_EXPORT(237, int, png_image_finish_read, (png_imagep image,
+ png_const_colorp background, void *buffer, png_int_32 row_stride,
+ void *colormap));
+ /* Finish reading the image into the supplied buffer and clean up the
+ * png_image structure.
+ *
+ * row_stride is the step, in byte or 2-byte units as appropriate,
+ * between adjacent rows. A positive stride indicates that the top-most row
+ * is first in the buffer - the normal top-down arrangement. A negative
+ * stride indicates that the bottom-most row is first in the buffer.
+ *
+ * background need only be supplied if an alpha channel must be removed from
+ * a png_byte format and the removal is to be done by compositing on a solid
+ * color; otherwise it may be NULL and any composition will be done directly
+ * onto the buffer. The value is an sRGB color to use for the background,
+ * for grayscale output the green channel is used.
+ *
+ * background must be supplied when an alpha channel must be removed from a
+ * single byte color-mapped output format, in other words if:
+ *
+ * 1) The original format from png_image_begin_read_from_* had
+ * PNG_FORMAT_FLAG_ALPHA set.
+ * 2) The format set by the application does not.
+ * 3) The format set by the application has PNG_FORMAT_FLAG_COLORMAP set and
+ * PNG_FORMAT_FLAG_LINEAR *not* set.
+ *
+ * For linear output removing the alpha channel is always done by compositing
+ * on black and background is ignored.
+ *
+ * colormap must be supplied when PNG_FORMAT_FLAG_COLORMAP is set. It must
+ * be at least the size (in bytes) returned by PNG_IMAGE_COLORMAP_SIZE.
+ * image->colormap_entries will be updated to the actual number of entries
+ * written to the colormap; this may be less than the original value.
+ */
+
+PNG_EXPORT(238, void, png_image_free, (png_imagep image));
+ /* Free any data allocated by libpng in image->opaque, setting the pointer to
+ * NULL. May be called at any time after the structure is initialized.
+ */
+#endif /* PNG_SIMPLIFIED_READ_SUPPORTED */
+
+#ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED
+#ifdef PNG_STDIO_SUPPORTED
+/* WRITE APIS
+ * ----------
+ * For write you must initialize a png_image structure to describe the image to
+ * be written. To do this use memset to set the whole structure to 0 then
+ * initialize fields describing your image.
+ *
+ * version: must be set to PNG_IMAGE_VERSION
+ * opaque: must be initialized to NULL
+ * width: image width in pixels
+ * height: image height in rows
+ * format: the format of the data (image and color-map) you wish to write
+ * flags: set to 0 unless one of the defined flags applies; set
+ * PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB for color format images where the RGB
+ * values do not correspond to the colors in sRGB.
+ * colormap_entries: set to the number of entries in the color-map (0 to 256)
+ */
+PNG_EXPORT(239, int, png_image_write_to_file, (png_imagep image,
+ const char *file, int convert_to_8bit, const void *buffer,
+ png_int_32 row_stride, const void *colormap));
+ /* Write the image to the named file. */
+
+PNG_EXPORT(240, int, png_image_write_to_stdio, (png_imagep image, FILE *file,
+ int convert_to_8_bit, const void *buffer, png_int_32 row_stride,
+ const void *colormap));
+ /* Write the image to the given (FILE*). */
+
+/* With both write APIs if image is in one of the linear formats with 16-bit
+ * data then setting convert_to_8_bit will cause the output to be an 8-bit PNG
+ * gamma encoded according to the sRGB specification, otherwise a 16-bit linear
+ * encoded PNG file is written.
+ *
+ * With color-mapped data formats the colormap parameter point to a color-map
+ * with at least image->colormap_entries encoded in the specified format. If
+ * the format is linear the written PNG color-map will be converted to sRGB
+ * regardless of the convert_to_8_bit flag.
+ *
+ * With all APIs row_stride is handled as in the read APIs - it is the spacing
+ * from one row to the next in component sized units (1 or 2 bytes) and if
+ * negative indicates a bottom-up row layout in the buffer.
+ *
+ * Note that the write API does not support interlacing or sub-8-bit pixels.
+ */
+#endif /* PNG_STDIO_SUPPORTED */
+#endif /* PNG_SIMPLIFIED_WRITE_SUPPORTED */
+/*******************************************************************************
+ * END OF SIMPLIFIED API
+ ******************************************************************************/
+
+#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED
+PNG_EXPORT(242, void, png_set_check_for_invalid_index,
+ (png_structrp png_ptr, int allowed));
+# ifdef PNG_GET_PALETTE_MAX_SUPPORTED
+PNG_EXPORT(243, int, png_get_palette_max, (png_const_structp png_ptr,
+ png_const_infop info_ptr));
+# endif
+#endif /* CHECK_FOR_INVALID_INDEX */
+
+/*******************************************************************************
+ * IMPLEMENTATION OPTIONS
+ *******************************************************************************
+ *
+ * Support for arbitrary implementation-specific optimizations. The API allows
+ * particular options to be turned on or off. 'Option' is the number of the
+ * option and 'onoff' is 0 (off) or non-0 (on). The value returned is given
+ * by the PNG_OPTION_ defines below.
+ *
+ * HARDWARE: normally hardware capabilites, such as the Intel SSE instructions,
+ * are detected at run time, however sometimes it may be impossible
+ * to do this in user mode, in which case it is necessary to discover
+ * the capabilities in an OS specific way. Such capabilities are
+ * listed here when libpng has support for them and must be turned
+ * ON by the application if present.
+ *
+ * SOFTWARE: sometimes software optimizations actually result in performance
+ * decrease on some architectures or systems, or with some sets of
+ * PNG images. 'Software' options allow such optimizations to be
+ * selected at run time.
+ */
+#ifdef PNG_SET_OPTION_SUPPORTED
+#ifdef PNG_ARM_NEON_API_SUPPORTED
+# define PNG_ARM_NEON 0 /* HARDWARE: ARM Neon SIMD instructions supported */
+#endif
+#define PNG_MAXIMUM_INFLATE_WINDOW 2 /* SOFTWARE: force maximum window */
+#define PNG_OPTION_NEXT 4 /* Next option - numbers must be even */
+
+/* Return values: NOTE: there are four values and 'off' is *not* zero */
+#define PNG_OPTION_UNSET 0 /* Unset - defaults to off */
+#define PNG_OPTION_INVALID 1 /* Option number out of range */
+#define PNG_OPTION_OFF 2
+#define PNG_OPTION_ON 3
+
+PNG_EXPORT(244, int, png_set_option, (png_structrp png_ptr, int option,
+ int onoff));
+#endif
+
+/*******************************************************************************
+ * END OF HARDWARE OPTIONS
+ ******************************************************************************/
+
+/* Maintainer: Put new public prototypes here ^, in libpng.3, and project
+ * defs, scripts/pnglibconf.h, and scripts/pnglibconf.h.prebuilt
+ */
+
+/* The last ordinal number (this is the *last* one already used; the next
+ * one to use is one more than this.) Maintainer, remember to add an entry to
+ * scripts/symbols.def as well.
+ */
+#ifdef PNG_EXPORT_LAST_ORDINAL
+ PNG_EXPORT_LAST_ORDINAL(244);
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* PNG_VERSION_INFO_ONLY */
+/* Do not put anything past this line */
+#endif /* PNG_H */
diff --git a/ml/dlib/dlib/external/libpng/pngconf.h b/ml/dlib/dlib/external/libpng/pngconf.h
new file mode 100644
index 000000000..8c5347224
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngconf.h
@@ -0,0 +1,626 @@
+
+/* pngconf.h - machine configurable file for libpng
+ *
+ * libpng version 1.6.7 - November 14, 2013
+ *
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ */
+
+/* Any machine specific code is near the front of this file, so if you
+ * are configuring libpng for a machine, you may want to read the section
+ * starting here down to where it starts to typedef png_color, png_text,
+ * and png_info.
+ */
+
+#ifdef _MSC_VER
+// Disable the following warnings for Visual Studio
+// This is a warning you get from visual studio 2005 about things in the standard C++
+// library being "deprecated." I checked the C++ standard and it doesn't say jack
+// about any of them (I checked the searchable PDF). So this warning is total Bunk.
+#pragma warning(disable : 4996)
+#endif
+
+
+#ifndef PNGCONF_H
+#define PNGCONF_H
+
+/* To do: Do all of this in scripts/pnglibconf.dfa */
+#ifdef PNG_SAFE_LIMITS_SUPPORTED
+# ifdef PNG_USER_WIDTH_MAX
+# undef PNG_USER_WIDTH_MAX
+# define PNG_USER_WIDTH_MAX 1000000L
+# endif
+# ifdef PNG_USER_HEIGHT_MAX
+# undef PNG_USER_HEIGHT_MAX
+# define PNG_USER_HEIGHT_MAX 1000000L
+# endif
+# ifdef PNG_USER_CHUNK_MALLOC_MAX
+# undef PNG_USER_CHUNK_MALLOC_MAX
+# define PNG_USER_CHUNK_MALLOC_MAX 4000000L
+# endif
+# ifdef PNG_USER_CHUNK_CACHE_MAX
+# undef PNG_USER_CHUNK_CACHE_MAX
+# define PNG_USER_CHUNK_CACHE_MAX 128
+# endif
+#endif
+
+#ifndef PNG_BUILDING_SYMBOL_TABLE /* else includes may cause problems */
+
+/* From libpng 1.6.0 libpng requires an ANSI X3.159-1989 ("ISOC90") compliant C
+ * compiler for correct compilation. The following header files are required by
+ * the standard. If your compiler doesn't provide these header files, or they
+ * do not match the standard, you will need to provide/improve them.
+ */
+#include <limits.h>
+#include <stddef.h>
+
+/* Library header files. These header files are all defined by ISOC90; libpng
+ * expects conformant implementations, however, an ISOC90 conformant system need
+ * not provide these header files if the functionality cannot be implemented.
+ * In this case it will be necessary to disable the relevant parts of libpng in
+ * the build of pnglibconf.h.
+ *
+ * Prior to 1.6.0 string.h was included here; the API changes in 1.6.0 to not
+ * include this unnecessary header file.
+ */
+
+#ifdef PNG_STDIO_SUPPORTED
+ /* Required for the definition of FILE: */
+# include <stdio.h>
+#endif
+
+#ifdef PNG_SETJMP_SUPPORTED
+ /* Required for the definition of jmp_buf and the declaration of longjmp: */
+# include <setjmp.h>
+#endif
+
+#ifdef PNG_CONVERT_tIME_SUPPORTED
+ /* Required for struct tm: */
+# include <time.h>
+#endif
+
+#endif /* PNG_BUILDING_SYMBOL_TABLE */
+
+/* Prior to 1.6.0 it was possible to turn off 'const' in declarations using
+ * PNG_NO_CONST; this is no longer supported except for data declarations which
+ * apparently still cause problems in 2011 on some compilers.
+ */
+#define PNG_CONST const /* backward compatibility only */
+
+/* This controls optimization of the reading of 16 and 32 bit values
+ * from PNG files. It can be set on a per-app-file basis - it
+ * just changes whether a macro is used when the function is called.
+ * The library builder sets the default; if read functions are not
+ * built into the library the macro implementation is forced on.
+ */
+#ifndef PNG_READ_INT_FUNCTIONS_SUPPORTED
+# define PNG_USE_READ_MACROS
+#endif
+#if !defined(PNG_NO_USE_READ_MACROS) && !defined(PNG_USE_READ_MACROS)
+# if PNG_DEFAULT_READ_MACROS
+# define PNG_USE_READ_MACROS
+# endif
+#endif
+
+/* COMPILER SPECIFIC OPTIONS.
+ *
+ * These options are provided so that a variety of difficult compilers
+ * can be used. Some are fixed at build time (e.g. PNG_API_RULE
+ * below) but still have compiler specific implementations, others
+ * may be changed on a per-file basis when compiling against libpng.
+ */
+
+/* The PNGARG macro was used in versions of libpng prior to 1.6.0 to protect
+ * against legacy (pre ISOC90) compilers that did not understand function
+ * prototypes. It is not required for modern C compilers.
+ */
+#ifndef PNGARG
+# define PNGARG(arglist) arglist
+#endif
+
+/* Function calling conventions.
+ * =============================
+ * Normally it is not necessary to specify to the compiler how to call
+ * a function - it just does it - however on x86 systems derived from
+ * Microsoft and Borland C compilers ('IBM PC', 'DOS', 'Windows' systems
+ * and some others) there are multiple ways to call a function and the
+ * default can be changed on the compiler command line. For this reason
+ * libpng specifies the calling convention of every exported function and
+ * every function called via a user supplied function pointer. This is
+ * done in this file by defining the following macros:
+ *
+ * PNGAPI Calling convention for exported functions.
+ * PNGCBAPI Calling convention for user provided (callback) functions.
+ * PNGCAPI Calling convention used by the ANSI-C library (required
+ * for longjmp callbacks and sometimes used internally to
+ * specify the calling convention for zlib).
+ *
+ * These macros should never be overridden. If it is necessary to
+ * change calling convention in a private build this can be done
+ * by setting PNG_API_RULE (which defaults to 0) to one of the values
+ * below to select the correct 'API' variants.
+ *
+ * PNG_API_RULE=0 Use PNGCAPI - the 'C' calling convention - throughout.
+ * This is correct in every known environment.
+ * PNG_API_RULE=1 Use the operating system convention for PNGAPI and
+ * the 'C' calling convention (from PNGCAPI) for
+ * callbacks (PNGCBAPI). This is no longer required
+ * in any known environment - if it has to be used
+ * please post an explanation of the problem to the
+ * libpng mailing list.
+ *
+ * These cases only differ if the operating system does not use the C
+ * calling convention, at present this just means the above cases
+ * (x86 DOS/Windows sytems) and, even then, this does not apply to
+ * Cygwin running on those systems.
+ *
+ * Note that the value must be defined in pnglibconf.h so that what
+ * the application uses to call the library matches the conventions
+ * set when building the library.
+ */
+
+/* Symbol export
+ * =============
+ * When building a shared library it is almost always necessary to tell
+ * the compiler which symbols to export. The png.h macro 'PNG_EXPORT'
+ * is used to mark the symbols. On some systems these symbols can be
+ * extracted at link time and need no special processing by the compiler,
+ * on other systems the symbols are flagged by the compiler and just
+ * the declaration requires a special tag applied (unfortunately) in a
+ * compiler dependent way. Some systems can do either.
+ *
+ * A small number of older systems also require a symbol from a DLL to
+ * be flagged to the program that calls it. This is a problem because
+ * we do not know in the header file included by application code that
+ * the symbol will come from a shared library, as opposed to a statically
+ * linked one. For this reason the application must tell us by setting
+ * the magic flag PNG_USE_DLL to turn on the special processing before
+ * it includes png.h.
+ *
+ * Four additional macros are used to make this happen:
+ *
+ * PNG_IMPEXP The magic (if any) to cause a symbol to be exported from
+ * the build or imported if PNG_USE_DLL is set - compiler
+ * and system specific.
+ *
+ * PNG_EXPORT_TYPE(type) A macro that pre or appends PNG_IMPEXP to
+ * 'type', compiler specific.
+ *
+ * PNG_DLL_EXPORT Set to the magic to use during a libpng build to
+ * make a symbol exported from the DLL. Not used in the
+ * public header files; see pngpriv.h for how it is used
+ * in the libpng build.
+ *
+ * PNG_DLL_IMPORT Set to the magic to force the libpng symbols to come
+ * from a DLL - used to define PNG_IMPEXP when
+ * PNG_USE_DLL is set.
+ */
+
+/* System specific discovery.
+ * ==========================
+ * This code is used at build time to find PNG_IMPEXP, the API settings
+ * and PNG_EXPORT_TYPE(), it may also set a macro to indicate the DLL
+ * import processing is possible. On Windows systems it also sets
+ * compiler-specific macros to the values required to change the calling
+ * conventions of the various functions.
+ */
+#if defined(_Windows) || defined(_WINDOWS) || defined(WIN32) ||\
+ defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__)
+ /* Windows system (DOS doesn't support DLLs). Includes builds under Cygwin or
+ * MinGW on any architecture currently supported by Windows. Also includes
+ * Watcom builds but these need special treatment because they are not
+ * compatible with GCC or Visual C because of different calling conventions.
+ */
+# if PNG_API_RULE == 2
+ /* If this line results in an error, either because __watcall is not
+ * understood or because of a redefine just below you cannot use *this*
+ * build of the library with the compiler you are using. *This* build was
+ * build using Watcom and applications must also be built using Watcom!
+ */
+# define PNGCAPI __watcall
+# endif
+
+# if defined(__GNUC__) || (defined(_MSC_VER) && (_MSC_VER >= 800))
+# define PNGCAPI __cdecl
+# if PNG_API_RULE == 1
+ /* If this line results in an error __stdcall is not understood and
+ * PNG_API_RULE should not have been set to '1'.
+ */
+# define PNGAPI __stdcall
+# endif
+# else
+ /* An older compiler, or one not detected (erroneously) above,
+ * if necessary override on the command line to get the correct
+ * variants for the compiler.
+ */
+# ifndef PNGCAPI
+# define PNGCAPI _cdecl
+# endif
+# if PNG_API_RULE == 1 && !defined(PNGAPI)
+# define PNGAPI _stdcall
+# endif
+# endif /* compiler/api */
+
+ /* NOTE: PNGCBAPI always defaults to PNGCAPI. */
+
+# if defined(PNGAPI) && !defined(PNG_USER_PRIVATEBUILD)
+# error "PNG_USER_PRIVATEBUILD must be defined if PNGAPI is changed"
+# endif
+
+# if (defined(_MSC_VER) && _MSC_VER < 800) ||\
+ (defined(__BORLANDC__) && __BORLANDC__ < 0x500)
+ /* older Borland and MSC
+ * compilers used '__export' and required this to be after
+ * the type.
+ */
+# ifndef PNG_EXPORT_TYPE
+# define PNG_EXPORT_TYPE(type) type PNG_IMPEXP
+# endif
+# define PNG_DLL_EXPORT __export
+# else /* newer compiler */
+# define PNG_DLL_EXPORT __declspec(dllexport)
+# ifndef PNG_DLL_IMPORT
+# define PNG_DLL_IMPORT __declspec(dllimport)
+# endif
+# endif /* compiler */
+
+#else /* !Windows */
+# if (defined(__IBMC__) || defined(__IBMCPP__)) && defined(__OS2__)
+# define PNGAPI _System
+# else /* !Windows/x86 && !OS/2 */
+ /* Use the defaults, or define PNG*API on the command line (but
+ * this will have to be done for every compile!)
+ */
+# endif /* other system, !OS/2 */
+#endif /* !Windows/x86 */
+
+/* Now do all the defaulting . */
+#ifndef PNGCAPI
+# define PNGCAPI
+#endif
+#ifndef PNGCBAPI
+# define PNGCBAPI PNGCAPI
+#endif
+#ifndef PNGAPI
+# define PNGAPI PNGCAPI
+#endif
+
+/* PNG_IMPEXP may be set on the compilation system command line or (if not set)
+ * then in an internal header file when building the library, otherwise (when
+ * using the library) it is set here.
+ */
+#ifndef PNG_IMPEXP
+# if defined(PNG_USE_DLL) && defined(PNG_DLL_IMPORT)
+ /* This forces use of a DLL, disallowing static linking */
+# define PNG_IMPEXP PNG_DLL_IMPORT
+# endif
+
+# ifndef PNG_IMPEXP
+# define PNG_IMPEXP
+# endif
+#endif
+
+/* In 1.5.2 the definition of PNG_FUNCTION has been changed to always treat
+ * 'attributes' as a storage class - the attributes go at the start of the
+ * function definition, and attributes are always appended regardless of the
+ * compiler. This considerably simplifies these macros but may cause problems
+ * if any compilers both need function attributes and fail to handle them as
+ * a storage class (this is unlikely.)
+ */
+#ifndef PNG_FUNCTION
+# define PNG_FUNCTION(type, name, args, attributes) attributes type name args
+#endif
+
+#ifndef PNG_EXPORT_TYPE
+# define PNG_EXPORT_TYPE(type) PNG_IMPEXP type
+#endif
+
+ /* The ordinal value is only relevant when preprocessing png.h for symbol
+ * table entries, so we discard it here. See the .dfn files in the
+ * scripts directory.
+ */
+#ifndef PNG_EXPORTA
+
+# define PNG_EXPORTA(ordinal, type, name, args, attributes)\
+ PNG_FUNCTION(PNG_EXPORT_TYPE(type),(PNGAPI name),PNGARG(args), \
+ extern attributes)
+#endif
+
+/* ANSI-C (C90) does not permit a macro to be invoked with an empty argument,
+ * so make something non-empty to satisfy the requirement:
+ */
+#define PNG_EMPTY /*empty list*/
+
+#define PNG_EXPORT(ordinal, type, name, args)\
+ PNG_EXPORTA(ordinal, type, name, args, PNG_EMPTY)
+
+/* Use PNG_REMOVED to comment out a removed interface. */
+#ifndef PNG_REMOVED
+# define PNG_REMOVED(ordinal, type, name, args, attributes)
+#endif
+
+#ifndef PNG_CALLBACK
+# define PNG_CALLBACK(type, name, args) type (PNGCBAPI name) PNGARG(args)
+#endif
+
+/* Support for compiler specific function attributes. These are used
+ * so that where compiler support is available incorrect use of API
+ * functions in png.h will generate compiler warnings.
+ *
+ * Added at libpng-1.2.41.
+ */
+
+#ifndef PNG_NO_PEDANTIC_WARNINGS
+# ifndef PNG_PEDANTIC_WARNINGS_SUPPORTED
+# define PNG_PEDANTIC_WARNINGS_SUPPORTED
+# endif
+#endif
+
+#ifdef PNG_PEDANTIC_WARNINGS_SUPPORTED
+ /* Support for compiler specific function attributes. These are used
+ * so that where compiler support is available, incorrect use of API
+ * functions in png.h will generate compiler warnings. Added at libpng
+ * version 1.2.41. Disabling these removes the warnings but may also produce
+ * less efficient code.
+ */
+# if defined(__GNUC__)
+# ifndef PNG_USE_RESULT
+# define PNG_USE_RESULT __attribute__((__warn_unused_result__))
+# endif
+# ifndef PNG_NORETURN
+# define PNG_NORETURN __attribute__((__noreturn__))
+# endif
+# if __GNUC__ >= 3
+# ifndef PNG_ALLOCATED
+# define PNG_ALLOCATED __attribute__((__malloc__))
+# endif
+# ifndef PNG_DEPRECATED
+# define PNG_DEPRECATED __attribute__((__deprecated__))
+# endif
+# ifndef PNG_PRIVATE
+# if 0 /* Doesn't work so we use deprecated instead*/
+# define PNG_PRIVATE \
+ __attribute__((warning("This function is not exported by libpng.")))
+# else
+# define PNG_PRIVATE \
+ __attribute__((__deprecated__))
+# endif
+# endif
+# if ((__GNUC__ != 3) || !defined(__GNUC_MINOR__) || (__GNUC_MINOR__ >= 1))
+# ifndef PNG_RESTRICT
+# define PNG_RESTRICT __restrict
+# endif
+# endif /* __GNUC__ == 3.0 */
+# endif /* __GNUC__ >= 3 */
+
+# elif defined(_MSC_VER) && (_MSC_VER >= 1300)
+# ifndef PNG_USE_RESULT
+# define PNG_USE_RESULT /* not supported */
+# endif
+# ifndef PNG_NORETURN
+# define PNG_NORETURN __declspec(noreturn)
+# endif
+# ifndef PNG_ALLOCATED
+# if (_MSC_VER >= 1400)
+# define PNG_ALLOCATED __declspec(restrict)
+# endif
+# endif
+# ifndef PNG_DEPRECATED
+# define PNG_DEPRECATED __declspec(deprecated)
+# endif
+# ifndef PNG_PRIVATE
+# define PNG_PRIVATE __declspec(deprecated)
+# endif
+# ifndef PNG_RESTRICT
+# if (_MSC_VER >= 1400)
+# define PNG_RESTRICT __restrict
+# endif
+# endif
+
+# elif defined(__WATCOMC__)
+# ifndef PNG_RESTRICT
+# define PNG_RESTRICT __restrict
+# endif
+# endif /* _MSC_VER */
+#endif /* PNG_PEDANTIC_WARNINGS */
+
+#ifndef PNG_DEPRECATED
+# define PNG_DEPRECATED /* Use of this function is deprecated */
+#endif
+#ifndef PNG_USE_RESULT
+# define PNG_USE_RESULT /* The result of this function must be checked */
+#endif
+#ifndef PNG_NORETURN
+# define PNG_NORETURN /* This function does not return */
+#endif
+#ifndef PNG_ALLOCATED
+# define PNG_ALLOCATED /* The result of the function is new memory */
+#endif
+#ifndef PNG_PRIVATE
+# define PNG_PRIVATE /* This is a private libpng function */
+#endif
+#ifndef PNG_RESTRICT
+# define PNG_RESTRICT /* The C99 "restrict" feature */
+#endif
+#ifndef PNG_FP_EXPORT /* A floating point API. */
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+# define PNG_FP_EXPORT(ordinal, type, name, args)\
+ PNG_EXPORT(ordinal, type, name, args);
+# else /* No floating point APIs */
+# define PNG_FP_EXPORT(ordinal, type, name, args)
+# endif
+#endif
+#ifndef PNG_FIXED_EXPORT /* A fixed point API. */
+# ifdef PNG_FIXED_POINT_SUPPORTED
+# define PNG_FIXED_EXPORT(ordinal, type, name, args)\
+ PNG_EXPORT(ordinal, type, name, args);
+# else /* No fixed point APIs */
+# define PNG_FIXED_EXPORT(ordinal, type, name, args)
+# endif
+#endif
+
+#ifndef PNG_BUILDING_SYMBOL_TABLE
+/* Some typedefs to get us started. These should be safe on most of the common
+ * platforms.
+ *
+ * png_uint_32 and png_int_32 may, currently, be larger than required to hold a
+ * 32-bit value however this is not normally advisable.
+ *
+ * png_uint_16 and png_int_16 should always be two bytes in size - this is
+ * verified at library build time.
+ *
+ * png_byte must always be one byte in size.
+ *
+ * The checks below use constants from limits.h, as defined by the ISOC90
+ * standard.
+ */
+#if CHAR_BIT == 8 && UCHAR_MAX == 255
+ typedef unsigned char png_byte;
+#else
+# error "libpng requires 8 bit bytes"
+#endif
+
+#if INT_MIN == -32768 && INT_MAX == 32767
+ typedef int png_int_16;
+#elif SHRT_MIN == -32768 && SHRT_MAX == 32767
+ typedef short png_int_16;
+#else
+# error "libpng requires a signed 16 bit type"
+#endif
+
+#if UINT_MAX == 65535
+ typedef unsigned int png_uint_16;
+#elif USHRT_MAX == 65535
+ typedef unsigned short png_uint_16;
+#else
+# error "libpng requires an unsigned 16 bit type"
+#endif
+
+#if INT_MIN < -2147483646 && INT_MAX > 2147483646
+ typedef int png_int_32;
+#elif LONG_MIN < -2147483646 && LONG_MAX > 2147483646
+ typedef long int png_int_32;
+#else
+# error "libpng requires a signed 32 bit (or more) type"
+#endif
+
+#if UINT_MAX > 4294967294
+ typedef unsigned int png_uint_32;
+#elif ULONG_MAX > 4294967294
+ typedef unsigned long int png_uint_32;
+#else
+# error "libpng requires an unsigned 32 bit (or more) type"
+#endif
+
+/* Prior to 1.6.0 it was possible to disable the use of size_t, 1.6.0, however,
+ * requires an ISOC90 compiler and relies on consistent behavior of sizeof.
+ */
+typedef size_t png_size_t;
+typedef ptrdiff_t png_ptrdiff_t;
+
+/* libpng needs to know the maximum value of 'size_t' and this controls the
+ * definition of png_alloc_size_t, below. This maximum value of size_t limits
+ * but does not control the maximum allocations the library makes - there is
+ * direct application control of this through png_set_user_limits().
+ */
+#ifndef PNG_SMALL_SIZE_T
+ /* Compiler specific tests for systems where size_t is known to be less than
+ * 32 bits (some of these systems may no longer work because of the lack of
+ * 'far' support; see above.)
+ */
+# if (defined(__TURBOC__) && !defined(__FLAT__)) ||\
+ (defined(_MSC_VER) && defined(MAXSEG_64K))
+# define PNG_SMALL_SIZE_T
+# endif
+#endif
+
+/* png_alloc_size_t is guaranteed to be no smaller than png_size_t, and no
+ * smaller than png_uint_32. Casts from png_size_t or png_uint_32 to
+ * png_alloc_size_t are not necessary; in fact, it is recommended not to use
+ * them at all so that the compiler can complain when something turns out to be
+ * problematic.
+ *
+ * Casts in the other direction (from png_alloc_size_t to png_size_t or
+ * png_uint_32) should be explicitly applied; however, we do not expect to
+ * encounter practical situations that require such conversions.
+ *
+ * PNG_SMALL_SIZE_T must be defined if the maximum value of size_t is less than
+ * 4294967295 - i.e. less than the maximum value of png_uint_32.
+ */
+#ifdef PNG_SMALL_SIZE_T
+ typedef png_uint_32 png_alloc_size_t;
+#else
+ typedef png_size_t png_alloc_size_t;
+#endif
+
+/* Prior to 1.6.0 libpng offered limited support for Microsoft C compiler
+ * implementations of Intel CPU specific support of user-mode segmented address
+ * spaces, where 16-bit pointers address more than 65536 bytes of memory using
+ * separate 'segment' registers. The implementation requires two different
+ * types of pointer (only one of which includes the segment value.)
+ *
+ * If required this support is available in version 1.2 of libpng and may be
+ * available in versions through 1.5, although the correctness of the code has
+ * not been verified recently.
+ */
+
+/* Typedef for floating-point numbers that are converted to fixed-point with a
+ * multiple of 100,000, e.g., gamma
+ */
+typedef png_int_32 png_fixed_point;
+
+/* Add typedefs for pointers */
+typedef void * png_voidp;
+typedef const void * png_const_voidp;
+typedef png_byte * png_bytep;
+typedef const png_byte * png_const_bytep;
+typedef png_uint_32 * png_uint_32p;
+typedef const png_uint_32 * png_const_uint_32p;
+typedef png_int_32 * png_int_32p;
+typedef const png_int_32 * png_const_int_32p;
+typedef png_uint_16 * png_uint_16p;
+typedef const png_uint_16 * png_const_uint_16p;
+typedef png_int_16 * png_int_16p;
+typedef const png_int_16 * png_const_int_16p;
+typedef char * png_charp;
+typedef const char * png_const_charp;
+typedef png_fixed_point * png_fixed_point_p;
+typedef const png_fixed_point * png_const_fixed_point_p;
+typedef png_size_t * png_size_tp;
+typedef const png_size_t * png_const_size_tp;
+
+#ifdef PNG_STDIO_SUPPORTED
+typedef FILE * png_FILE_p;
+#endif
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+typedef double * png_doublep;
+typedef const double * png_const_doublep;
+#endif
+
+/* Pointers to pointers; i.e. arrays */
+typedef png_byte * * png_bytepp;
+typedef png_uint_32 * * png_uint_32pp;
+typedef png_int_32 * * png_int_32pp;
+typedef png_uint_16 * * png_uint_16pp;
+typedef png_int_16 * * png_int_16pp;
+typedef const char * * png_const_charpp;
+typedef char * * png_charpp;
+typedef png_fixed_point * * png_fixed_point_pp;
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+typedef double * * png_doublepp;
+#endif
+
+/* Pointers to pointers to pointers; i.e., pointer to array */
+typedef char * * * png_charppp;
+
+#endif /* PNG_BUILDING_SYMBOL_TABLE */
+
+#endif /* PNGCONF_H */
diff --git a/ml/dlib/dlib/external/libpng/pngdebug.h b/ml/dlib/dlib/external/libpng/pngdebug.h
new file mode 100644
index 000000000..16f81fdd1
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngdebug.h
@@ -0,0 +1,157 @@
+
+/* pngdebug.h - Debugging macros for libpng, also used in pngtest.c
+ *
+ * Copyright (c) 1998-2011 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * Last changed in libpng 1.5.0 [January 6, 2011]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+/* Define PNG_DEBUG at compile time for debugging information. Higher
+ * numbers for PNG_DEBUG mean more debugging information. This has
+ * only been added since version 0.95 so it is not implemented throughout
+ * libpng yet, but more support will be added as needed.
+ *
+ * png_debug[1-2]?(level, message ,arg{0-2})
+ * Expands to a statement (either a simple expression or a compound
+ * do..while(0) statement) that outputs a message with parameter
+ * substitution if PNG_DEBUG is defined to 2 or more. If PNG_DEBUG
+ * is undefined, 0 or 1 every png_debug expands to a simple expression
+ * (actually ((void)0)).
+ *
+ * level: level of detail of message, starting at 0. A level 'n'
+ * message is preceded by 'n' tab characters (not implemented
+ * on Microsoft compilers unless PNG_DEBUG_FILE is also
+ * defined, to allow debug DLL compilation with no standard IO).
+ * message: a printf(3) style text string. A trailing '\n' is added
+ * to the message.
+ * arg: 0 to 2 arguments for printf(3) style substitution in message.
+ */
+#ifndef PNGDEBUG_H
+#define PNGDEBUG_H
+/* These settings control the formatting of messages in png.c and pngerror.c */
+/* Moved to pngdebug.h at 1.5.0 */
+# ifndef PNG_LITERAL_SHARP
+# define PNG_LITERAL_SHARP 0x23
+# endif
+# ifndef PNG_LITERAL_LEFT_SQUARE_BRACKET
+# define PNG_LITERAL_LEFT_SQUARE_BRACKET 0x5b
+# endif
+# ifndef PNG_LITERAL_RIGHT_SQUARE_BRACKET
+# define PNG_LITERAL_RIGHT_SQUARE_BRACKET 0x5d
+# endif
+# ifndef PNG_STRING_NEWLINE
+# define PNG_STRING_NEWLINE "\n"
+# endif
+
+#ifdef PNG_DEBUG
+# if (PNG_DEBUG > 0)
+# if !defined(PNG_DEBUG_FILE) && defined(_MSC_VER)
+# include <crtdbg.h>
+# if (PNG_DEBUG > 1)
+# ifndef _DEBUG
+# define _DEBUG
+# endif
+# ifndef png_debug
+# define png_debug(l,m) _RPT0(_CRT_WARN,m PNG_STRING_NEWLINE)
+# endif
+# ifndef png_debug1
+# define png_debug1(l,m,p1) _RPT1(_CRT_WARN,m PNG_STRING_NEWLINE,p1)
+# endif
+# ifndef png_debug2
+# define png_debug2(l,m,p1,p2) \
+ _RPT2(_CRT_WARN,m PNG_STRING_NEWLINE,p1,p2)
+# endif
+# endif
+# else /* PNG_DEBUG_FILE || !_MSC_VER */
+# ifndef PNG_STDIO_SUPPORTED
+# include <stdio.h> /* not included yet */
+# endif
+# ifndef PNG_DEBUG_FILE
+# define PNG_DEBUG_FILE stderr
+# endif /* PNG_DEBUG_FILE */
+
+# if (PNG_DEBUG > 1)
+/* Note: ["%s"m PNG_STRING_NEWLINE] probably does not work on
+ * non-ISO compilers
+ */
+# ifdef __STDC__
+# ifndef png_debug
+# define png_debug(l,m) \
+ do { \
+ int num_tabs=l; \
+ fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":"")))); \
+ } while (0)
+# endif
+# ifndef png_debug1
+# define png_debug1(l,m,p1) \
+ do { \
+ int num_tabs=l; \
+ fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))),p1); \
+ } while (0)
+# endif
+# ifndef png_debug2
+# define png_debug2(l,m,p1,p2) \
+ do { \
+ int num_tabs=l; \
+ fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))),p1,p2); \
+ } while (0)
+# endif
+# else /* __STDC __ */
+# ifndef png_debug
+# define png_debug(l,m) \
+ do { \
+ int num_tabs=l; \
+ char format[256]; \
+ snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \
+ m,PNG_STRING_NEWLINE); \
+ fprintf(PNG_DEBUG_FILE,format); \
+ } while (0)
+# endif
+# ifndef png_debug1
+# define png_debug1(l,m,p1) \
+ do { \
+ int num_tabs=l; \
+ char format[256]; \
+ snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \
+ m,PNG_STRING_NEWLINE); \
+ fprintf(PNG_DEBUG_FILE,format,p1); \
+ } while (0)
+# endif
+# ifndef png_debug2
+# define png_debug2(l,m,p1,p2) \
+ do { \
+ int num_tabs=l; \
+ char format[256]; \
+ snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \
+ (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \
+ m,PNG_STRING_NEWLINE); \
+ fprintf(PNG_DEBUG_FILE,format,p1,p2); \
+ } while (0)
+# endif
+# endif /* __STDC __ */
+# endif /* (PNG_DEBUG > 1) */
+
+# endif /* _MSC_VER */
+# endif /* (PNG_DEBUG > 0) */
+#endif /* PNG_DEBUG */
+#ifndef png_debug
+# define png_debug(l, m) ((void)0)
+#endif
+#ifndef png_debug1
+# define png_debug1(l, m, p1) ((void)0)
+#endif
+#ifndef png_debug2
+# define png_debug2(l, m, p1, p2) ((void)0)
+#endif
+#endif /* PNGDEBUG_H */
diff --git a/ml/dlib/dlib/external/libpng/pngerror.c b/ml/dlib/dlib/external/libpng/pngerror.c
new file mode 100644
index 000000000..f469206ee
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngerror.c
@@ -0,0 +1,932 @@
+
+/* pngerror.c - stub functions for i/o and memory allocation
+ *
+ * Last changed in libpng 1.6.1 [March 28, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file provides a location for all error handling. Users who
+ * need special error handling are expected to write replacement functions
+ * and use png_set_error_fn() to use those functions. See the instructions
+ * at each function.
+ */
+
+#include "pngpriv.h"
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+
+static PNG_FUNCTION(void, png_default_error,PNGARG((png_const_structrp png_ptr,
+ png_const_charp error_message)),PNG_NORETURN);
+
+#ifdef PNG_WARNINGS_SUPPORTED
+static void /* PRIVATE */
+png_default_warning PNGARG((png_const_structrp png_ptr,
+ png_const_charp warning_message));
+#endif /* PNG_WARNINGS_SUPPORTED */
+
+/* This function is called whenever there is a fatal error. This function
+ * should not be changed. If there is a need to handle errors differently,
+ * you should supply a replacement error function and use png_set_error_fn()
+ * to replace the error function at run-time.
+ */
+#ifdef PNG_ERROR_TEXT_SUPPORTED
+PNG_FUNCTION(void,PNGAPI
+png_error,(png_const_structrp png_ptr, png_const_charp error_message),
+ PNG_NORETURN)
+{
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+ char msg[16];
+ if (png_ptr != NULL)
+ {
+ if (png_ptr->flags&
+ (PNG_FLAG_STRIP_ERROR_NUMBERS|PNG_FLAG_STRIP_ERROR_TEXT))
+ {
+ if (*error_message == PNG_LITERAL_SHARP)
+ {
+ /* Strip "#nnnn " from beginning of error message. */
+ int offset;
+ for (offset = 1; offset<15; offset++)
+ if (error_message[offset] == ' ')
+ break;
+
+ if (png_ptr->flags&PNG_FLAG_STRIP_ERROR_TEXT)
+ {
+ int i;
+ for (i = 0; i < offset - 1; i++)
+ msg[i] = error_message[i + 1];
+ msg[i - 1] = '\0';
+ error_message = msg;
+ }
+
+ else
+ error_message += offset;
+ }
+
+ else
+ {
+ if (png_ptr->flags&PNG_FLAG_STRIP_ERROR_TEXT)
+ {
+ msg[0] = '0';
+ msg[1] = '\0';
+ error_message = msg;
+ }
+ }
+ }
+ }
+#endif
+ if (png_ptr != NULL && png_ptr->error_fn != NULL)
+ (*(png_ptr->error_fn))(png_constcast(png_structrp,png_ptr),
+ error_message);
+
+ /* If the custom handler doesn't exist, or if it returns,
+ use the default handler, which will not return. */
+ png_default_error(png_ptr, error_message);
+}
+#else
+PNG_FUNCTION(void,PNGAPI
+png_err,(png_const_structrp png_ptr),PNG_NORETURN)
+{
+ /* Prior to 1.5.2 the error_fn received a NULL pointer, expressed
+ * erroneously as '\0', instead of the empty string "". This was
+ * apparently an error, introduced in libpng-1.2.20, and png_default_error
+ * will crash in this case.
+ */
+ if (png_ptr != NULL && png_ptr->error_fn != NULL)
+ (*(png_ptr->error_fn))(png_constcast(png_structrp,png_ptr), "");
+
+ /* If the custom handler doesn't exist, or if it returns,
+ use the default handler, which will not return. */
+ png_default_error(png_ptr, "");
+}
+#endif /* PNG_ERROR_TEXT_SUPPORTED */
+
+/* Utility to safely appends strings to a buffer. This never errors out so
+ * error checking is not required in the caller.
+ */
+size_t
+png_safecat(png_charp buffer, size_t bufsize, size_t pos,
+ png_const_charp string)
+{
+ if (buffer != NULL && pos < bufsize)
+ {
+ if (string != NULL)
+ while (*string != '\0' && pos < bufsize-1)
+ buffer[pos++] = *string++;
+
+ buffer[pos] = '\0';
+ }
+
+ return pos;
+}
+
+#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_TIME_RFC1123_SUPPORTED)
+/* Utility to dump an unsigned value into a buffer, given a start pointer and
+ * and end pointer (which should point just *beyond* the end of the buffer!)
+ * Returns the pointer to the start of the formatted string.
+ */
+png_charp
+png_format_number(png_const_charp start, png_charp end, int format,
+ png_alloc_size_t number)
+{
+ int count = 0; /* number of digits output */
+ int mincount = 1; /* minimum number required */
+ int output = 0; /* digit output (for the fixed point format) */
+
+ *--end = '\0';
+
+ /* This is written so that the loop always runs at least once, even with
+ * number zero.
+ */
+ while (end > start && (number != 0 || count < mincount))
+ {
+
+ static const char digits[] = "0123456789ABCDEF";
+
+ switch (format)
+ {
+ case PNG_NUMBER_FORMAT_fixed:
+ /* Needs five digits (the fraction) */
+ mincount = 5;
+ if (output || number % 10 != 0)
+ {
+ *--end = digits[number % 10];
+ output = 1;
+ }
+ number /= 10;
+ break;
+
+ case PNG_NUMBER_FORMAT_02u:
+ /* Expects at least 2 digits. */
+ mincount = 2;
+ /* FALL THROUGH */
+
+ case PNG_NUMBER_FORMAT_u:
+ *--end = digits[number % 10];
+ number /= 10;
+ break;
+
+ case PNG_NUMBER_FORMAT_02x:
+ /* This format expects at least two digits */
+ mincount = 2;
+ /* FALL THROUGH */
+
+ case PNG_NUMBER_FORMAT_x:
+ *--end = digits[number & 0xf];
+ number >>= 4;
+ break;
+
+ default: /* an error */
+ number = 0;
+ break;
+ }
+
+ /* Keep track of the number of digits added */
+ ++count;
+
+ /* Float a fixed number here: */
+ if (format == PNG_NUMBER_FORMAT_fixed) if (count == 5) if (end > start)
+ {
+ /* End of the fraction, but maybe nothing was output? In that case
+ * drop the decimal point. If the number is a true zero handle that
+ * here.
+ */
+ if (output)
+ *--end = '.';
+ else if (number == 0) /* and !output */
+ *--end = '0';
+ }
+ }
+
+ return end;
+}
+#endif
+
+#ifdef PNG_WARNINGS_SUPPORTED
+/* This function is called whenever there is a non-fatal error. This function
+ * should not be changed. If there is a need to handle warnings differently,
+ * you should supply a replacement warning function and use
+ * png_set_error_fn() to replace the warning function at run-time.
+ */
+void PNGAPI
+png_warning(png_const_structrp png_ptr, png_const_charp warning_message)
+{
+ int offset = 0;
+ if (png_ptr != NULL)
+ {
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+ if (png_ptr->flags&
+ (PNG_FLAG_STRIP_ERROR_NUMBERS|PNG_FLAG_STRIP_ERROR_TEXT))
+#endif
+ {
+ if (*warning_message == PNG_LITERAL_SHARP)
+ {
+ for (offset = 1; offset < 15; offset++)
+ if (warning_message[offset] == ' ')
+ break;
+ }
+ }
+ }
+ if (png_ptr != NULL && png_ptr->warning_fn != NULL)
+ (*(png_ptr->warning_fn))(png_constcast(png_structrp,png_ptr),
+ warning_message + offset);
+ else
+ png_default_warning(png_ptr, warning_message + offset);
+}
+
+/* These functions support 'formatted' warning messages with up to
+ * PNG_WARNING_PARAMETER_COUNT parameters. In the format string the parameter
+ * is introduced by @<number>, where 'number' starts at 1. This follows the
+ * standard established by X/Open for internationalizable error messages.
+ */
+void
+png_warning_parameter(png_warning_parameters p, int number,
+ png_const_charp string)
+{
+ if (number > 0 && number <= PNG_WARNING_PARAMETER_COUNT)
+ (void)png_safecat(p[number-1], (sizeof p[number-1]), 0, string);
+}
+
+void
+png_warning_parameter_unsigned(png_warning_parameters p, int number, int format,
+ png_alloc_size_t value)
+{
+ char buffer[PNG_NUMBER_BUFFER_SIZE];
+ png_warning_parameter(p, number, PNG_FORMAT_NUMBER(buffer, format, value));
+}
+
+void
+png_warning_parameter_signed(png_warning_parameters p, int number, int format,
+ png_int_32 value)
+{
+ png_alloc_size_t u;
+ png_charp str;
+ char buffer[PNG_NUMBER_BUFFER_SIZE];
+
+ /* Avoid overflow by doing the negate in a png_alloc_size_t: */
+ u = (png_alloc_size_t)value;
+ if (value < 0)
+ u = ~u + 1;
+
+ str = PNG_FORMAT_NUMBER(buffer, format, u);
+
+ if (value < 0 && str > buffer)
+ *--str = '-';
+
+ png_warning_parameter(p, number, str);
+}
+
+void
+png_formatted_warning(png_const_structrp png_ptr, png_warning_parameters p,
+ png_const_charp message)
+{
+ /* The internal buffer is just 192 bytes - enough for all our messages,
+ * overflow doesn't happen because this code checks! If someone figures
+ * out how to send us a message longer than 192 bytes, all that will
+ * happen is that the message will be truncated appropriately.
+ */
+ size_t i = 0; /* Index in the msg[] buffer: */
+ char msg[192];
+
+ /* Each iteration through the following loop writes at most one character
+ * to msg[i++] then returns here to validate that there is still space for
+ * the trailing '\0'. It may (in the case of a parameter) read more than
+ * one character from message[]; it must check for '\0' and continue to the
+ * test if it finds the end of string.
+ */
+ while (i<(sizeof msg)-1 && *message != '\0')
+ {
+ /* '@' at end of string is now just printed (previously it was skipped);
+ * it is an error in the calling code to terminate the string with @.
+ */
+ if (p != NULL && *message == '@' && message[1] != '\0')
+ {
+ int parameter_char = *++message; /* Consume the '@' */
+ static const char valid_parameters[] = "123456789";
+ int parameter = 0;
+
+ /* Search for the parameter digit, the index in the string is the
+ * parameter to use.
+ */
+ while (valid_parameters[parameter] != parameter_char &&
+ valid_parameters[parameter] != '\0')
+ ++parameter;
+
+ /* If the parameter digit is out of range it will just get printed. */
+ if (parameter < PNG_WARNING_PARAMETER_COUNT)
+ {
+ /* Append this parameter */
+ png_const_charp parm = p[parameter];
+ png_const_charp pend = p[parameter] + (sizeof p[parameter]);
+
+ /* No need to copy the trailing '\0' here, but there is no guarantee
+ * that parm[] has been initialized, so there is no guarantee of a
+ * trailing '\0':
+ */
+ while (i<(sizeof msg)-1 && *parm != '\0' && parm < pend)
+ msg[i++] = *parm++;
+
+ /* Consume the parameter digit too: */
+ ++message;
+ continue;
+ }
+
+ /* else not a parameter and there is a character after the @ sign; just
+ * copy that. This is known not to be '\0' because of the test above.
+ */
+ }
+
+ /* At this point *message can't be '\0', even in the bad parameter case
+ * above where there is a lone '@' at the end of the message string.
+ */
+ msg[i++] = *message++;
+ }
+
+ /* i is always less than (sizeof msg), so: */
+ msg[i] = '\0';
+
+ /* And this is the formatted message. It may be larger than
+ * PNG_MAX_ERROR_TEXT, but that is only used for 'chunk' errors and these
+ * are not (currently) formatted.
+ */
+ png_warning(png_ptr, msg);
+}
+#endif /* PNG_WARNINGS_SUPPORTED */
+
+#ifdef PNG_BENIGN_ERRORS_SUPPORTED
+void PNGAPI
+png_benign_error(png_const_structrp png_ptr, png_const_charp error_message)
+{
+ if (png_ptr->flags & PNG_FLAG_BENIGN_ERRORS_WARN)
+ {
+# ifdef PNG_READ_SUPPORTED
+ if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 &&
+ png_ptr->chunk_name != 0)
+ png_chunk_warning(png_ptr, error_message);
+ else
+# endif
+ png_warning(png_ptr, error_message);
+ }
+
+ else
+ {
+# ifdef PNG_READ_SUPPORTED
+ if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 &&
+ png_ptr->chunk_name != 0)
+ png_chunk_error(png_ptr, error_message);
+ else
+# endif
+ png_error(png_ptr, error_message);
+ }
+}
+
+void /* PRIVATE */
+png_app_warning(png_const_structrp png_ptr, png_const_charp error_message)
+{
+ if (png_ptr->flags & PNG_FLAG_APP_WARNINGS_WARN)
+ png_warning(png_ptr, error_message);
+ else
+ png_error(png_ptr, error_message);
+}
+
+void /* PRIVATE */
+png_app_error(png_const_structrp png_ptr, png_const_charp error_message)
+{
+ if (png_ptr->flags & PNG_FLAG_APP_ERRORS_WARN)
+ png_warning(png_ptr, error_message);
+ else
+ png_error(png_ptr, error_message);
+}
+#endif /* BENIGN_ERRORS */
+
+/* These utilities are used internally to build an error message that relates
+ * to the current chunk. The chunk name comes from png_ptr->chunk_name,
+ * this is used to prefix the message. The message is limited in length
+ * to 63 bytes, the name characters are output as hex digits wrapped in []
+ * if the character is invalid.
+ */
+#define isnonalpha(c) ((c) < 65 || (c) > 122 || ((c) > 90 && (c) < 97))
+static PNG_CONST char png_digit[16] = {
+ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
+ 'A', 'B', 'C', 'D', 'E', 'F'
+};
+
+#define PNG_MAX_ERROR_TEXT 196 /* Currently limited be profile_error in png.c */
+#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_ERROR_TEXT_SUPPORTED)
+static void /* PRIVATE */
+png_format_buffer(png_const_structrp png_ptr, png_charp buffer, png_const_charp
+ error_message)
+{
+ png_uint_32 chunk_name = png_ptr->chunk_name;
+ int iout = 0, ishift = 24;
+
+ while (ishift >= 0)
+ {
+ int c = (int)(chunk_name >> ishift) & 0xff;
+
+ ishift -= 8;
+ if (isnonalpha(c))
+ {
+ buffer[iout++] = PNG_LITERAL_LEFT_SQUARE_BRACKET;
+ buffer[iout++] = png_digit[(c & 0xf0) >> 4];
+ buffer[iout++] = png_digit[c & 0x0f];
+ buffer[iout++] = PNG_LITERAL_RIGHT_SQUARE_BRACKET;
+ }
+
+ else
+ {
+ buffer[iout++] = (char)c;
+ }
+ }
+
+ if (error_message == NULL)
+ buffer[iout] = '\0';
+
+ else
+ {
+ int iin = 0;
+
+ buffer[iout++] = ':';
+ buffer[iout++] = ' ';
+
+ while (iin < PNG_MAX_ERROR_TEXT-1 && error_message[iin] != '\0')
+ buffer[iout++] = error_message[iin++];
+
+ /* iin < PNG_MAX_ERROR_TEXT, so the following is safe: */
+ buffer[iout] = '\0';
+ }
+}
+#endif /* PNG_WARNINGS_SUPPORTED || PNG_ERROR_TEXT_SUPPORTED */
+
+#if defined(PNG_READ_SUPPORTED) && defined(PNG_ERROR_TEXT_SUPPORTED)
+PNG_FUNCTION(void,PNGAPI
+png_chunk_error,(png_const_structrp png_ptr, png_const_charp error_message),
+ PNG_NORETURN)
+{
+ char msg[18+PNG_MAX_ERROR_TEXT];
+ if (png_ptr == NULL)
+ png_error(png_ptr, error_message);
+
+ else
+ {
+ png_format_buffer(png_ptr, msg, error_message);
+ png_error(png_ptr, msg);
+ }
+}
+#endif /* PNG_READ_SUPPORTED && PNG_ERROR_TEXT_SUPPORTED */
+
+#ifdef PNG_WARNINGS_SUPPORTED
+void PNGAPI
+png_chunk_warning(png_const_structrp png_ptr, png_const_charp warning_message)
+{
+ char msg[18+PNG_MAX_ERROR_TEXT];
+ if (png_ptr == NULL)
+ png_warning(png_ptr, warning_message);
+
+ else
+ {
+ png_format_buffer(png_ptr, msg, warning_message);
+ png_warning(png_ptr, msg);
+ }
+}
+#endif /* PNG_WARNINGS_SUPPORTED */
+
+#ifdef PNG_READ_SUPPORTED
+#ifdef PNG_BENIGN_ERRORS_SUPPORTED
+void PNGAPI
+png_chunk_benign_error(png_const_structrp png_ptr, png_const_charp
+ error_message)
+{
+ if (png_ptr->flags & PNG_FLAG_BENIGN_ERRORS_WARN)
+ png_chunk_warning(png_ptr, error_message);
+
+ else
+ png_chunk_error(png_ptr, error_message);
+}
+#endif
+#endif /* PNG_READ_SUPPORTED */
+
+void /* PRIVATE */
+png_chunk_report(png_const_structrp png_ptr, png_const_charp message, int error)
+{
+ /* This is always supported, but for just read or just write it
+ * unconditionally does the right thing.
+ */
+# if defined(PNG_READ_SUPPORTED) && defined(PNG_WRITE_SUPPORTED)
+ if (png_ptr->mode & PNG_IS_READ_STRUCT)
+# endif
+
+# ifdef PNG_READ_SUPPORTED
+ {
+ if (error < PNG_CHUNK_ERROR)
+ png_chunk_warning(png_ptr, message);
+
+ else
+ png_chunk_benign_error(png_ptr, message);
+ }
+# endif
+
+# if defined(PNG_READ_SUPPORTED) && defined(PNG_WRITE_SUPPORTED)
+ else if (!(png_ptr->mode & PNG_IS_READ_STRUCT))
+# endif
+
+# ifdef PNG_WRITE_SUPPORTED
+ {
+ if (error < PNG_CHUNK_WRITE_ERROR)
+ png_app_warning(png_ptr, message);
+
+ else
+ png_app_error(png_ptr, message);
+ }
+# endif
+}
+
+#ifdef PNG_ERROR_TEXT_SUPPORTED
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+PNG_FUNCTION(void,
+png_fixed_error,(png_const_structrp png_ptr, png_const_charp name),PNG_NORETURN)
+{
+# define fixed_message "fixed point overflow in "
+# define fixed_message_ln ((sizeof fixed_message)-1)
+ int iin;
+ char msg[fixed_message_ln+PNG_MAX_ERROR_TEXT];
+ memcpy(msg, fixed_message, fixed_message_ln);
+ iin = 0;
+ if (name != NULL) while (iin < (PNG_MAX_ERROR_TEXT-1) && name[iin] != 0)
+ {
+ msg[fixed_message_ln + iin] = name[iin];
+ ++iin;
+ }
+ msg[fixed_message_ln + iin] = 0;
+ png_error(png_ptr, msg);
+}
+#endif
+#endif
+
+#ifdef PNG_SETJMP_SUPPORTED
+/* This API only exists if ANSI-C style error handling is used,
+ * otherwise it is necessary for png_default_error to be overridden.
+ */
+jmp_buf* PNGAPI
+png_set_longjmp_fn(png_structrp png_ptr, png_longjmp_ptr longjmp_fn,
+ size_t jmp_buf_size)
+{
+ /* From libpng 1.6.0 the app gets one chance to set a 'jmpbuf_size' value
+ * and it must not change after that. Libpng doesn't care how big the
+ * buffer is, just that it doesn't change.
+ *
+ * If the buffer size is no *larger* than the size of jmp_buf when libpng is
+ * compiled a built in jmp_buf is returned; this preserves the pre-1.6.0
+ * semantics that this call will not fail. If the size is larger, however,
+ * the buffer is allocated and this may fail, causing the function to return
+ * NULL.
+ */
+ if (png_ptr == NULL)
+ return NULL;
+
+ if (png_ptr->jmp_buf_ptr == NULL)
+ {
+ png_ptr->jmp_buf_size = 0; /* not allocated */
+
+ if (jmp_buf_size <= (sizeof png_ptr->jmp_buf_local))
+ png_ptr->jmp_buf_ptr = &png_ptr->jmp_buf_local;
+
+ else
+ {
+ png_ptr->jmp_buf_ptr = png_voidcast(jmp_buf *,
+ png_malloc_warn(png_ptr, jmp_buf_size));
+
+ if (png_ptr->jmp_buf_ptr == NULL)
+ return NULL; /* new NULL return on OOM */
+
+ png_ptr->jmp_buf_size = jmp_buf_size;
+ }
+ }
+
+ else /* Already allocated: check the size */
+ {
+ size_t size = png_ptr->jmp_buf_size;
+
+ if (size == 0)
+ {
+ size = (sizeof png_ptr->jmp_buf_local);
+ if (png_ptr->jmp_buf_ptr != &png_ptr->jmp_buf_local)
+ {
+ /* This is an internal error in libpng: somehow we have been left
+ * with a stack allocated jmp_buf when the application regained
+ * control. It's always possible to fix this up, but for the moment
+ * this is a png_error because that makes it easy to detect.
+ */
+ png_error(png_ptr, "Libpng jmp_buf still allocated");
+ /* png_ptr->jmp_buf_ptr = &png_ptr->jmp_buf_local; */
+ }
+ }
+
+ if (size != jmp_buf_size)
+ {
+ png_warning(png_ptr, "Application jmp_buf size changed");
+ return NULL; /* caller will probably crash: no choice here */
+ }
+ }
+
+ /* Finally fill in the function, now we have a satisfactory buffer. It is
+ * valid to change the function on every call.
+ */
+ png_ptr->longjmp_fn = longjmp_fn;
+ return png_ptr->jmp_buf_ptr;
+}
+
+void /* PRIVATE */
+png_free_jmpbuf(png_structrp png_ptr)
+{
+ if (png_ptr != NULL)
+ {
+ jmp_buf *jb = png_ptr->jmp_buf_ptr;
+
+ /* A size of 0 is used to indicate a local, stack, allocation of the
+ * pointer; used here and in png.c
+ */
+ if (jb != NULL && png_ptr->jmp_buf_size > 0)
+ {
+
+ /* This stuff is so that a failure to free the error control structure
+ * does not leave libpng in a state with no valid error handling: the
+ * free always succeeds, if there is an error it gets ignored.
+ */
+ if (jb != &png_ptr->jmp_buf_local)
+ {
+ /* Make an internal, libpng, jmp_buf to return here */
+ jmp_buf free_jmp_buf;
+
+ if (!setjmp(free_jmp_buf))
+ {
+ png_ptr->jmp_buf_ptr = &free_jmp_buf; /* come back here */
+ png_ptr->jmp_buf_size = 0; /* stack allocation */
+ png_ptr->longjmp_fn = longjmp;
+ png_free(png_ptr, jb); /* Return to setjmp on error */
+ }
+ }
+ }
+
+ /* *Always* cancel everything out: */
+ png_ptr->jmp_buf_size = 0;
+ png_ptr->jmp_buf_ptr = NULL;
+ png_ptr->longjmp_fn = 0;
+ }
+}
+#endif
+
+/* This is the default error handling function. Note that replacements for
+ * this function MUST NOT RETURN, or the program will likely crash. This
+ * function is used by default, or if the program supplies NULL for the
+ * error function pointer in png_set_error_fn().
+ */
+static PNG_FUNCTION(void /* PRIVATE */,
+png_default_error,(png_const_structrp png_ptr, png_const_charp error_message),
+ PNG_NORETURN)
+{
+#ifdef PNG_CONSOLE_IO_SUPPORTED
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+ /* Check on NULL only added in 1.5.4 */
+ if (error_message != NULL && *error_message == PNG_LITERAL_SHARP)
+ {
+ /* Strip "#nnnn " from beginning of error message. */
+ int offset;
+ char error_number[16];
+ for (offset = 0; offset<15; offset++)
+ {
+ error_number[offset] = error_message[offset + 1];
+ if (error_message[offset] == ' ')
+ break;
+ }
+
+ if ((offset > 1) && (offset < 15))
+ {
+ error_number[offset - 1] = '\0';
+ fprintf(stderr, "libpng error no. %s: %s",
+ error_number, error_message + offset + 1);
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+
+ else
+ {
+ fprintf(stderr, "libpng error: %s, offset=%d",
+ error_message, offset);
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+ }
+ else
+#endif
+ {
+ fprintf(stderr, "libpng error: %s", error_message ? error_message :
+ "undefined");
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+#else
+ PNG_UNUSED(error_message) /* Make compiler happy */
+#endif
+ png_longjmp(png_ptr, 1);
+}
+
+PNG_FUNCTION(void,PNGAPI
+png_longjmp,(png_const_structrp png_ptr, int val),PNG_NORETURN)
+{
+#ifdef PNG_SETJMP_SUPPORTED
+ if (png_ptr && png_ptr->longjmp_fn && png_ptr->jmp_buf_ptr)
+ png_ptr->longjmp_fn(*png_ptr->jmp_buf_ptr, val);
+#endif
+
+ /* Here if not setjmp support or if png_ptr is null. */
+ PNG_ABORT();
+}
+
+#ifdef PNG_WARNINGS_SUPPORTED
+/* This function is called when there is a warning, but the library thinks
+ * it can continue anyway. Replacement functions don't have to do anything
+ * here if you don't want them to. In the default configuration, png_ptr is
+ * not used, but it is passed in case it may be useful.
+ */
+static void /* PRIVATE */
+png_default_warning(png_const_structrp png_ptr, png_const_charp warning_message)
+{
+#ifdef PNG_CONSOLE_IO_SUPPORTED
+# ifdef PNG_ERROR_NUMBERS_SUPPORTED
+ if (*warning_message == PNG_LITERAL_SHARP)
+ {
+ int offset;
+ char warning_number[16];
+ for (offset = 0; offset < 15; offset++)
+ {
+ warning_number[offset] = warning_message[offset + 1];
+ if (warning_message[offset] == ' ')
+ break;
+ }
+
+ if ((offset > 1) && (offset < 15))
+ {
+ warning_number[offset + 1] = '\0';
+ fprintf(stderr, "libpng warning no. %s: %s",
+ warning_number, warning_message + offset);
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+
+ else
+ {
+ fprintf(stderr, "libpng warning: %s",
+ warning_message);
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+ }
+ else
+# endif
+
+ {
+ fprintf(stderr, "libpng warning: %s", warning_message);
+ fprintf(stderr, PNG_STRING_NEWLINE);
+ }
+#else
+ PNG_UNUSED(warning_message) /* Make compiler happy */
+#endif
+ PNG_UNUSED(png_ptr) /* Make compiler happy */
+}
+#endif /* PNG_WARNINGS_SUPPORTED */
+
+/* This function is called when the application wants to use another method
+ * of handling errors and warnings. Note that the error function MUST NOT
+ * return to the calling routine or serious problems will occur. The return
+ * method used in the default routine calls longjmp(png_ptr->jmp_buf_ptr, 1)
+ */
+void PNGAPI
+png_set_error_fn(png_structrp png_ptr, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warning_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->error_ptr = error_ptr;
+ png_ptr->error_fn = error_fn;
+#ifdef PNG_WARNINGS_SUPPORTED
+ png_ptr->warning_fn = warning_fn;
+#else
+ PNG_UNUSED(warning_fn)
+#endif
+}
+
+
+/* This function returns a pointer to the error_ptr associated with the user
+ * functions. The application should free any memory associated with this
+ * pointer before png_write_destroy and png_read_destroy are called.
+ */
+png_voidp PNGAPI
+png_get_error_ptr(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return NULL;
+
+ return ((png_voidp)png_ptr->error_ptr);
+}
+
+
+#ifdef PNG_ERROR_NUMBERS_SUPPORTED
+void PNGAPI
+png_set_strip_error_numbers(png_structrp png_ptr, png_uint_32 strip_mode)
+{
+ if (png_ptr != NULL)
+ {
+ png_ptr->flags &=
+ ((~(PNG_FLAG_STRIP_ERROR_NUMBERS |
+ PNG_FLAG_STRIP_ERROR_TEXT))&strip_mode);
+ }
+}
+#endif
+
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\
+ defined(PNG_SIMPLIFIED_WRITE_SUPPORTED)
+ /* Currently the above both depend on SETJMP_SUPPORTED, however it would be
+ * possible to implement without setjmp support just so long as there is some
+ * way to handle the error return here:
+ */
+PNG_FUNCTION(void /* PRIVATE */,
+png_safe_error,(png_structp png_nonconst_ptr, png_const_charp error_message),
+ PNG_NORETURN)
+{
+ const png_const_structrp png_ptr = png_nonconst_ptr;
+ png_imagep image = png_voidcast(png_imagep, png_ptr->error_ptr);
+
+ /* An error is always logged here, overwriting anything (typically a warning)
+ * that is already there:
+ */
+ if (image != NULL)
+ {
+ png_safecat(image->message, (sizeof image->message), 0, error_message);
+ image->warning_or_error |= PNG_IMAGE_ERROR;
+
+ /* Retrieve the jmp_buf from within the png_control, making this work for
+ * C++ compilation too is pretty tricky: C++ wants a pointer to the first
+ * element of a jmp_buf, but C doesn't tell us the type of that.
+ */
+ if (image->opaque != NULL && image->opaque->error_buf != NULL)
+ longjmp(png_control_jmp_buf(image->opaque), 1);
+
+ /* Missing longjmp buffer, the following is to help debugging: */
+ {
+ size_t pos = png_safecat(image->message, (sizeof image->message), 0,
+ "bad longjmp: ");
+ png_safecat(image->message, (sizeof image->message), pos,
+ error_message);
+ }
+ }
+
+ /* Here on an internal programming error. */
+ abort();
+}
+
+#ifdef PNG_WARNINGS_SUPPORTED
+void /* PRIVATE */
+png_safe_warning(png_structp png_nonconst_ptr, png_const_charp warning_message)
+{
+ const png_const_structrp png_ptr = png_nonconst_ptr;
+ png_imagep image = png_voidcast(png_imagep, png_ptr->error_ptr);
+
+ /* A warning is only logged if there is no prior warning or error. */
+ if (image->warning_or_error == 0)
+ {
+ png_safecat(image->message, (sizeof image->message), 0, warning_message);
+ image->warning_or_error |= PNG_IMAGE_WARNING;
+ }
+}
+#endif
+
+int /* PRIVATE */
+png_safe_execute(png_imagep image_in, int (*function)(png_voidp), png_voidp arg)
+{
+ volatile png_imagep image = image_in;
+ volatile int result;
+ volatile png_voidp saved_error_buf;
+ jmp_buf safe_jmpbuf;
+
+ /* Safely execute function(arg) with png_error returning to this function. */
+ saved_error_buf = image->opaque->error_buf;
+ result = setjmp(safe_jmpbuf) == 0;
+
+ if (result)
+ {
+
+ image->opaque->error_buf = safe_jmpbuf;
+ result = function(arg);
+ }
+
+ image->opaque->error_buf = saved_error_buf;
+
+ /* And do the cleanup prior to any failure return. */
+ if (!result)
+ png_image_free(image);
+
+ return result;
+}
+#endif /* SIMPLIFIED READ/WRITE */
+#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngget.c b/ml/dlib/dlib/external/libpng/pngget.c
new file mode 100644
index 000000000..aca63a958
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngget.c
@@ -0,0 +1,1177 @@
+
+/* pngget.c - retrieval of values from info struct
+ *
+ * Last changed in libpng 1.6.1 [March 28, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ */
+
+#include "pngpriv.h"
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+
+png_uint_32 PNGAPI
+png_get_valid(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_uint_32 flag)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return(info_ptr->valid & flag);
+
+ return(0);
+}
+
+png_size_t PNGAPI
+png_get_rowbytes(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return(info_ptr->rowbytes);
+
+ return(0);
+}
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+png_bytepp PNGAPI
+png_get_rows(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return(info_ptr->row_pointers);
+
+ return(0);
+}
+#endif
+
+#ifdef PNG_EASY_ACCESS_SUPPORTED
+/* Easy access to info, added in libpng-0.99 */
+png_uint_32 PNGAPI
+png_get_image_width(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->width;
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_image_height(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->height;
+
+ return (0);
+}
+
+png_byte PNGAPI
+png_get_bit_depth(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->bit_depth;
+
+ return (0);
+}
+
+png_byte PNGAPI
+png_get_color_type(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->color_type;
+
+ return (0);
+}
+
+png_byte PNGAPI
+png_get_filter_type(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->filter_type;
+
+ return (0);
+}
+
+png_byte PNGAPI
+png_get_interlace_type(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->interlace_type;
+
+ return (0);
+}
+
+png_byte PNGAPI
+png_get_compression_type(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return info_ptr->compression_type;
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_x_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp
+ info_ptr)
+{
+#ifdef PNG_pHYs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_debug1(1, "in %s retrieval function",
+ "png_get_x_pixels_per_meter");
+
+ if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER)
+ return (info_ptr->x_pixels_per_unit);
+ }
+#endif
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_y_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp
+ info_ptr)
+{
+#ifdef PNG_pHYs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_debug1(1, "in %s retrieval function",
+ "png_get_y_pixels_per_meter");
+
+ if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER)
+ return (info_ptr->y_pixels_per_unit);
+ }
+#endif
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+#ifdef PNG_pHYs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_pixels_per_meter");
+
+ if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER &&
+ info_ptr->x_pixels_per_unit == info_ptr->y_pixels_per_unit)
+ return (info_ptr->x_pixels_per_unit);
+ }
+#endif
+
+ return (0);
+}
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+float PNGAPI
+png_get_pixel_aspect_ratio(png_const_structrp png_ptr, png_const_inforp
+ info_ptr)
+{
+#ifdef PNG_READ_pHYs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_aspect_ratio");
+
+ if (info_ptr->x_pixels_per_unit != 0)
+ return ((float)((float)info_ptr->y_pixels_per_unit
+ /(float)info_ptr->x_pixels_per_unit));
+ }
+#else
+ PNG_UNUSED(png_ptr)
+ PNG_UNUSED(info_ptr)
+#endif
+
+ return ((float)0.0);
+}
+#endif
+
+#ifdef PNG_FIXED_POINT_SUPPORTED
+png_fixed_point PNGAPI
+png_get_pixel_aspect_ratio_fixed(png_const_structrp png_ptr,
+ png_const_inforp info_ptr)
+{
+#ifdef PNG_READ_pHYs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)
+ && info_ptr->x_pixels_per_unit > 0 && info_ptr->y_pixels_per_unit > 0
+ && info_ptr->x_pixels_per_unit <= PNG_UINT_31_MAX
+ && info_ptr->y_pixels_per_unit <= PNG_UINT_31_MAX)
+ {
+ png_fixed_point res;
+
+ png_debug1(1, "in %s retrieval function", "png_get_aspect_ratio_fixed");
+
+ /* The following casts work because a PNG 4 byte integer only has a valid
+ * range of 0..2^31-1; otherwise the cast might overflow.
+ */
+ if (png_muldiv(&res, (png_int_32)info_ptr->y_pixels_per_unit, PNG_FP_1,
+ (png_int_32)info_ptr->x_pixels_per_unit))
+ return res;
+ }
+#else
+ PNG_UNUSED(png_ptr)
+ PNG_UNUSED(info_ptr)
+#endif
+
+ return 0;
+}
+#endif
+
+png_int_32 PNGAPI
+png_get_x_offset_microns(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+#ifdef PNG_oFFs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_x_offset_microns");
+
+ if (info_ptr->offset_unit_type == PNG_OFFSET_MICROMETER)
+ return (info_ptr->x_offset);
+ }
+#endif
+
+ return (0);
+}
+
+png_int_32 PNGAPI
+png_get_y_offset_microns(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+#ifdef PNG_oFFs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_y_offset_microns");
+
+ if (info_ptr->offset_unit_type == PNG_OFFSET_MICROMETER)
+ return (info_ptr->y_offset);
+ }
+#endif
+
+ return (0);
+}
+
+png_int_32 PNGAPI
+png_get_x_offset_pixels(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+#ifdef PNG_oFFs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_x_offset_pixels");
+
+ if (info_ptr->offset_unit_type == PNG_OFFSET_PIXEL)
+ return (info_ptr->x_offset);
+ }
+#endif
+
+ return (0);
+}
+
+png_int_32 PNGAPI
+png_get_y_offset_pixels(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+#ifdef PNG_oFFs_SUPPORTED
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs))
+ {
+ png_debug1(1, "in %s retrieval function", "png_get_y_offset_pixels");
+
+ if (info_ptr->offset_unit_type == PNG_OFFSET_PIXEL)
+ return (info_ptr->y_offset);
+ }
+#endif
+
+ return (0);
+}
+
+#ifdef PNG_INCH_CONVERSIONS_SUPPORTED
+static png_uint_32
+ppi_from_ppm(png_uint_32 ppm)
+{
+#if 0
+ /* The conversion is *(2.54/100), in binary (32 digits):
+ * .00000110100000001001110101001001
+ */
+ png_uint_32 t1001, t1101;
+ ppm >>= 1; /* .1 */
+ t1001 = ppm + (ppm >> 3); /* .1001 */
+ t1101 = t1001 + (ppm >> 1); /* .1101 */
+ ppm >>= 20; /* .000000000000000000001 */
+ t1101 += t1101 >> 15; /* .1101000000000001101 */
+ t1001 >>= 11; /* .000000000001001 */
+ t1001 += t1001 >> 12; /* .000000000001001000000001001 */
+ ppm += t1001; /* .000000000001001000001001001 */
+ ppm += t1101; /* .110100000001001110101001001 */
+ return (ppm + 16) >> 5;/* .00000110100000001001110101001001 */
+#else
+ /* The argument is a PNG unsigned integer, so it is not permitted
+ * to be bigger than 2^31.
+ */
+ png_fixed_point result;
+ if (ppm <= PNG_UINT_31_MAX && png_muldiv(&result, (png_int_32)ppm, 127,
+ 5000))
+ return result;
+
+ /* Overflow. */
+ return 0;
+#endif
+}
+
+png_uint_32 PNGAPI
+png_get_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ return ppi_from_ppm(png_get_pixels_per_meter(png_ptr, info_ptr));
+}
+
+png_uint_32 PNGAPI
+png_get_x_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ return ppi_from_ppm(png_get_x_pixels_per_meter(png_ptr, info_ptr));
+}
+
+png_uint_32 PNGAPI
+png_get_y_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ return ppi_from_ppm(png_get_y_pixels_per_meter(png_ptr, info_ptr));
+}
+
+#ifdef PNG_FIXED_POINT_SUPPORTED
+static png_fixed_point
+png_fixed_inches_from_microns(png_const_structrp png_ptr, png_int_32 microns)
+{
+ /* Convert from metres * 1,000,000 to inches * 100,000, meters to
+ * inches is simply *(100/2.54), so we want *(10/2.54) == 500/127.
+ * Notice that this can overflow - a warning is output and 0 is
+ * returned.
+ */
+ return png_muldiv_warn(png_ptr, microns, 500, 127);
+}
+
+png_fixed_point PNGAPI
+png_get_x_offset_inches_fixed(png_const_structrp png_ptr,
+ png_const_inforp info_ptr)
+{
+ return png_fixed_inches_from_microns(png_ptr,
+ png_get_x_offset_microns(png_ptr, info_ptr));
+}
+#endif
+
+#ifdef PNG_FIXED_POINT_SUPPORTED
+png_fixed_point PNGAPI
+png_get_y_offset_inches_fixed(png_const_structrp png_ptr,
+ png_const_inforp info_ptr)
+{
+ return png_fixed_inches_from_microns(png_ptr,
+ png_get_y_offset_microns(png_ptr, info_ptr));
+}
+#endif
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+float PNGAPI
+png_get_x_offset_inches(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ /* To avoid the overflow do the conversion directly in floating
+ * point.
+ */
+ return (float)(png_get_x_offset_microns(png_ptr, info_ptr) * .00003937);
+}
+#endif
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+float PNGAPI
+png_get_y_offset_inches(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ /* To avoid the overflow do the conversion directly in floating
+ * point.
+ */
+ return (float)(png_get_y_offset_microns(png_ptr, info_ptr) * .00003937);
+}
+#endif
+
+#ifdef PNG_pHYs_SUPPORTED
+png_uint_32 PNGAPI
+png_get_pHYs_dpi(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_uint_32 *res_x, png_uint_32 *res_y, int *unit_type)
+{
+ png_uint_32 retval = 0;
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_debug1(1, "in %s retrieval function", "pHYs");
+
+ if (res_x != NULL)
+ {
+ *res_x = info_ptr->x_pixels_per_unit;
+ retval |= PNG_INFO_pHYs;
+ }
+
+ if (res_y != NULL)
+ {
+ *res_y = info_ptr->y_pixels_per_unit;
+ retval |= PNG_INFO_pHYs;
+ }
+
+ if (unit_type != NULL)
+ {
+ *unit_type = (int)info_ptr->phys_unit_type;
+ retval |= PNG_INFO_pHYs;
+
+ if (*unit_type == 1)
+ {
+ if (res_x != NULL) *res_x = (png_uint_32)(*res_x * .0254 + .50);
+ if (res_y != NULL) *res_y = (png_uint_32)(*res_y * .0254 + .50);
+ }
+ }
+ }
+
+ return (retval);
+}
+#endif /* PNG_pHYs_SUPPORTED */
+#endif /* PNG_INCH_CONVERSIONS_SUPPORTED */
+
+/* png_get_channels really belongs in here, too, but it's been around longer */
+
+#endif /* PNG_EASY_ACCESS_SUPPORTED */
+
+
+png_byte PNGAPI
+png_get_channels(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return(info_ptr->channels);
+
+ return (0);
+}
+
+#ifdef PNG_READ_SUPPORTED
+png_const_bytep PNGAPI
+png_get_signature(png_const_structrp png_ptr, png_const_inforp info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return(info_ptr->signature);
+
+ return (NULL);
+}
+#endif
+
+#ifdef PNG_bKGD_SUPPORTED
+png_uint_32 PNGAPI
+png_get_bKGD(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_color_16p *background)
+{
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD)
+ && background != NULL)
+ {
+ png_debug1(1, "in %s retrieval function", "bKGD");
+
+ *background = &(info_ptr->background);
+ return (PNG_INFO_bKGD);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_cHRM_SUPPORTED
+/* The XYZ APIs were added in 1.5.5 to take advantage of the code added at the
+ * same time to correct the rgb grayscale coefficient defaults obtained from the
+ * cHRM chunk in 1.5.4
+ */
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_cHRM(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ double *white_x, double *white_y, double *red_x, double *red_y,
+ double *green_x, double *green_y, double *blue_x, double *blue_y)
+{
+ /* Quiet API change: this code used to only return the end points if a cHRM
+ * chunk was present, but the end points can also come from iCCP or sRGB
+ * chunks, so in 1.6.0 the png_get_ APIs return the end points regardless and
+ * the png_set_ APIs merely check that set end points are mutually
+ * consistent.
+ */
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS))
+ {
+ png_debug1(1, "in %s retrieval function", "cHRM");
+
+ if (white_x != NULL)
+ *white_x = png_float(png_ptr,
+ info_ptr->colorspace.end_points_xy.whitex, "cHRM white X");
+ if (white_y != NULL)
+ *white_y = png_float(png_ptr,
+ info_ptr->colorspace.end_points_xy.whitey, "cHRM white Y");
+ if (red_x != NULL)
+ *red_x = png_float(png_ptr, info_ptr->colorspace.end_points_xy.redx,
+ "cHRM red X");
+ if (red_y != NULL)
+ *red_y = png_float(png_ptr, info_ptr->colorspace.end_points_xy.redy,
+ "cHRM red Y");
+ if (green_x != NULL)
+ *green_x = png_float(png_ptr,
+ info_ptr->colorspace.end_points_xy.greenx, "cHRM green X");
+ if (green_y != NULL)
+ *green_y = png_float(png_ptr,
+ info_ptr->colorspace.end_points_xy.greeny, "cHRM green Y");
+ if (blue_x != NULL)
+ *blue_x = png_float(png_ptr, info_ptr->colorspace.end_points_xy.bluex,
+ "cHRM blue X");
+ if (blue_y != NULL)
+ *blue_y = png_float(png_ptr, info_ptr->colorspace.end_points_xy.bluey,
+ "cHRM blue Y");
+ return (PNG_INFO_cHRM);
+ }
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_cHRM_XYZ(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ double *red_X, double *red_Y, double *red_Z, double *green_X,
+ double *green_Y, double *green_Z, double *blue_X, double *blue_Y,
+ double *blue_Z)
+{
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS))
+ {
+ png_debug1(1, "in %s retrieval function", "cHRM_XYZ(float)");
+
+ if (red_X != NULL)
+ *red_X = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_X,
+ "cHRM red X");
+ if (red_Y != NULL)
+ *red_Y = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_Y,
+ "cHRM red Y");
+ if (red_Z != NULL)
+ *red_Z = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_Z,
+ "cHRM red Z");
+ if (green_X != NULL)
+ *green_X = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.green_X, "cHRM green X");
+ if (green_Y != NULL)
+ *green_Y = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.green_Y, "cHRM green Y");
+ if (green_Z != NULL)
+ *green_Z = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.green_Z, "cHRM green Z");
+ if (blue_X != NULL)
+ *blue_X = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.blue_X, "cHRM blue X");
+ if (blue_Y != NULL)
+ *blue_Y = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.blue_Y, "cHRM blue Y");
+ if (blue_Z != NULL)
+ *blue_Z = png_float(png_ptr,
+ info_ptr->colorspace.end_points_XYZ.blue_Z, "cHRM blue Z");
+ return (PNG_INFO_cHRM);
+ }
+
+ return (0);
+}
+# endif
+
+# ifdef PNG_FIXED_POINT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_cHRM_XYZ_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *int_red_X, png_fixed_point *int_red_Y,
+ png_fixed_point *int_red_Z, png_fixed_point *int_green_X,
+ png_fixed_point *int_green_Y, png_fixed_point *int_green_Z,
+ png_fixed_point *int_blue_X, png_fixed_point *int_blue_Y,
+ png_fixed_point *int_blue_Z)
+{
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS))
+ {
+ png_debug1(1, "in %s retrieval function", "cHRM_XYZ");
+
+ if (int_red_X != NULL)
+ *int_red_X = info_ptr->colorspace.end_points_XYZ.red_X;
+ if (int_red_Y != NULL)
+ *int_red_Y = info_ptr->colorspace.end_points_XYZ.red_Y;
+ if (int_red_Z != NULL)
+ *int_red_Z = info_ptr->colorspace.end_points_XYZ.red_Z;
+ if (int_green_X != NULL)
+ *int_green_X = info_ptr->colorspace.end_points_XYZ.green_X;
+ if (int_green_Y != NULL)
+ *int_green_Y = info_ptr->colorspace.end_points_XYZ.green_Y;
+ if (int_green_Z != NULL)
+ *int_green_Z = info_ptr->colorspace.end_points_XYZ.green_Z;
+ if (int_blue_X != NULL)
+ *int_blue_X = info_ptr->colorspace.end_points_XYZ.blue_X;
+ if (int_blue_Y != NULL)
+ *int_blue_Y = info_ptr->colorspace.end_points_XYZ.blue_Y;
+ if (int_blue_Z != NULL)
+ *int_blue_Z = info_ptr->colorspace.end_points_XYZ.blue_Z;
+ return (PNG_INFO_cHRM);
+ }
+
+ return (0);
+}
+
+png_uint_32 PNGAPI
+png_get_cHRM_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *white_x, png_fixed_point *white_y, png_fixed_point *red_x,
+ png_fixed_point *red_y, png_fixed_point *green_x, png_fixed_point *green_y,
+ png_fixed_point *blue_x, png_fixed_point *blue_y)
+{
+ png_debug1(1, "in %s retrieval function", "cHRM");
+
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS))
+ {
+ if (white_x != NULL)
+ *white_x = info_ptr->colorspace.end_points_xy.whitex;
+ if (white_y != NULL)
+ *white_y = info_ptr->colorspace.end_points_xy.whitey;
+ if (red_x != NULL)
+ *red_x = info_ptr->colorspace.end_points_xy.redx;
+ if (red_y != NULL)
+ *red_y = info_ptr->colorspace.end_points_xy.redy;
+ if (green_x != NULL)
+ *green_x = info_ptr->colorspace.end_points_xy.greenx;
+ if (green_y != NULL)
+ *green_y = info_ptr->colorspace.end_points_xy.greeny;
+ if (blue_x != NULL)
+ *blue_x = info_ptr->colorspace.end_points_xy.bluex;
+ if (blue_y != NULL)
+ *blue_y = info_ptr->colorspace.end_points_xy.bluey;
+ return (PNG_INFO_cHRM);
+ }
+
+ return (0);
+}
+# endif
+#endif
+
+#ifdef PNG_gAMA_SUPPORTED
+# ifdef PNG_FIXED_POINT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_gAMA_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_fixed_point *file_gamma)
+{
+ png_debug1(1, "in %s retrieval function", "gAMA");
+
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) &&
+ file_gamma != NULL)
+ {
+ *file_gamma = info_ptr->colorspace.gamma;
+ return (PNG_INFO_gAMA);
+ }
+
+ return (0);
+}
+# endif
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_gAMA(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ double *file_gamma)
+{
+ png_debug1(1, "in %s retrieval function", "gAMA(float)");
+
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) &&
+ file_gamma != NULL)
+ {
+ *file_gamma = png_float(png_ptr, info_ptr->colorspace.gamma,
+ "png_get_gAMA");
+ return (PNG_INFO_gAMA);
+ }
+
+ return (0);
+}
+# endif
+#endif
+
+#ifdef PNG_sRGB_SUPPORTED
+png_uint_32 PNGAPI
+png_get_sRGB(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ int *file_srgb_intent)
+{
+ png_debug1(1, "in %s retrieval function", "sRGB");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_sRGB)
+ && file_srgb_intent != NULL)
+ {
+ *file_srgb_intent = info_ptr->colorspace.rendering_intent;
+ return (PNG_INFO_sRGB);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_iCCP_SUPPORTED
+png_uint_32 PNGAPI
+png_get_iCCP(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_charpp name, int *compression_type,
+ png_bytepp profile, png_uint_32 *proflen)
+{
+ png_debug1(1, "in %s retrieval function", "iCCP");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_iCCP)
+ && name != NULL && compression_type != NULL && profile != NULL &&
+ proflen != NULL)
+ {
+ *name = info_ptr->iccp_name;
+ *profile = info_ptr->iccp_profile;
+ *proflen = png_get_uint_32(info_ptr->iccp_profile);
+ /* This is somewhat irrelevant since the profile data returned has
+ * actually been uncompressed.
+ */
+ *compression_type = PNG_COMPRESSION_TYPE_BASE;
+ return (PNG_INFO_iCCP);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+int PNGAPI
+png_get_sPLT(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_sPLT_tpp spalettes)
+{
+ if (png_ptr != NULL && info_ptr != NULL && spalettes != NULL)
+ {
+ *spalettes = info_ptr->splt_palettes;
+ return info_ptr->splt_palettes_num;
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+png_uint_32 PNGAPI
+png_get_hIST(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_uint_16p *hist)
+{
+ png_debug1(1, "in %s retrieval function", "hIST");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST)
+ && hist != NULL)
+ {
+ *hist = info_ptr->hist;
+ return (PNG_INFO_hIST);
+ }
+
+ return (0);
+}
+#endif
+
+png_uint_32 PNGAPI
+png_get_IHDR(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_uint_32 *width, png_uint_32 *height, int *bit_depth,
+ int *color_type, int *interlace_type, int *compression_type,
+ int *filter_type)
+{
+ png_debug1(1, "in %s retrieval function", "IHDR");
+
+ if (png_ptr == NULL || info_ptr == NULL || width == NULL ||
+ height == NULL || bit_depth == NULL || color_type == NULL)
+ return (0);
+
+ *width = info_ptr->width;
+ *height = info_ptr->height;
+ *bit_depth = info_ptr->bit_depth;
+ *color_type = info_ptr->color_type;
+
+ if (compression_type != NULL)
+ *compression_type = info_ptr->compression_type;
+
+ if (filter_type != NULL)
+ *filter_type = info_ptr->filter_type;
+
+ if (interlace_type != NULL)
+ *interlace_type = info_ptr->interlace_type;
+
+ /* This is redundant if we can be sure that the info_ptr values were all
+ * assigned in png_set_IHDR(). We do the check anyhow in case an
+ * application has ignored our advice not to mess with the members
+ * of info_ptr directly.
+ */
+ png_check_IHDR(png_ptr, info_ptr->width, info_ptr->height,
+ info_ptr->bit_depth, info_ptr->color_type, info_ptr->interlace_type,
+ info_ptr->compression_type, info_ptr->filter_type);
+
+ return (1);
+}
+
+#ifdef PNG_oFFs_SUPPORTED
+png_uint_32 PNGAPI
+png_get_oFFs(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_int_32 *offset_x, png_int_32 *offset_y, int *unit_type)
+{
+ png_debug1(1, "in %s retrieval function", "oFFs");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)
+ && offset_x != NULL && offset_y != NULL && unit_type != NULL)
+ {
+ *offset_x = info_ptr->x_offset;
+ *offset_y = info_ptr->y_offset;
+ *unit_type = (int)info_ptr->offset_unit_type;
+ return (PNG_INFO_oFFs);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+png_uint_32 PNGAPI
+png_get_pCAL(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_charp *purpose, png_int_32 *X0, png_int_32 *X1, int *type, int *nparams,
+ png_charp *units, png_charpp *params)
+{
+ png_debug1(1, "in %s retrieval function", "pCAL");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pCAL)
+ && purpose != NULL && X0 != NULL && X1 != NULL && type != NULL &&
+ nparams != NULL && units != NULL && params != NULL)
+ {
+ *purpose = info_ptr->pcal_purpose;
+ *X0 = info_ptr->pcal_X0;
+ *X1 = info_ptr->pcal_X1;
+ *type = (int)info_ptr->pcal_type;
+ *nparams = (int)info_ptr->pcal_nparams;
+ *units = info_ptr->pcal_units;
+ *params = info_ptr->pcal_params;
+ return (PNG_INFO_pCAL);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_sCAL_SUPPORTED
+# ifdef PNG_FIXED_POINT_SUPPORTED
+# if defined(PNG_FLOATING_ARITHMETIC_SUPPORTED) || \
+ defined(PNG_FLOATING_POINT_SUPPORTED)
+png_uint_32 PNGAPI
+png_get_sCAL_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ int *unit, png_fixed_point *width, png_fixed_point *height)
+{
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->valid & PNG_INFO_sCAL))
+ {
+ *unit = info_ptr->scal_unit;
+ /*TODO: make this work without FP support; the API is currently eliminated
+ * if neither floating point APIs nor internal floating point arithmetic
+ * are enabled.
+ */
+ *width = png_fixed(png_ptr, atof(info_ptr->scal_s_width), "sCAL width");
+ *height = png_fixed(png_ptr, atof(info_ptr->scal_s_height),
+ "sCAL height");
+ return (PNG_INFO_sCAL);
+ }
+
+ return(0);
+}
+# endif /* FLOATING_ARITHMETIC */
+# endif /* FIXED_POINT */
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_sCAL(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ int *unit, double *width, double *height)
+{
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->valid & PNG_INFO_sCAL))
+ {
+ *unit = info_ptr->scal_unit;
+ *width = atof(info_ptr->scal_s_width);
+ *height = atof(info_ptr->scal_s_height);
+ return (PNG_INFO_sCAL);
+ }
+
+ return(0);
+}
+# endif /* FLOATING POINT */
+png_uint_32 PNGAPI
+png_get_sCAL_s(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ int *unit, png_charpp width, png_charpp height)
+{
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->valid & PNG_INFO_sCAL))
+ {
+ *unit = info_ptr->scal_unit;
+ *width = info_ptr->scal_s_width;
+ *height = info_ptr->scal_s_height;
+ return (PNG_INFO_sCAL);
+ }
+
+ return(0);
+}
+#endif /* sCAL */
+
+#ifdef PNG_pHYs_SUPPORTED
+png_uint_32 PNGAPI
+png_get_pHYs(png_const_structrp png_ptr, png_const_inforp info_ptr,
+ png_uint_32 *res_x, png_uint_32 *res_y, int *unit_type)
+{
+ png_uint_32 retval = 0;
+
+ png_debug1(1, "in %s retrieval function", "pHYs");
+
+ if (png_ptr != NULL && info_ptr != NULL &&
+ (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ if (res_x != NULL)
+ {
+ *res_x = info_ptr->x_pixels_per_unit;
+ retval |= PNG_INFO_pHYs;
+ }
+
+ if (res_y != NULL)
+ {
+ *res_y = info_ptr->y_pixels_per_unit;
+ retval |= PNG_INFO_pHYs;
+ }
+
+ if (unit_type != NULL)
+ {
+ *unit_type = (int)info_ptr->phys_unit_type;
+ retval |= PNG_INFO_pHYs;
+ }
+ }
+
+ return (retval);
+}
+#endif /* pHYs */
+
+png_uint_32 PNGAPI
+png_get_PLTE(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_colorp *palette, int *num_palette)
+{
+ png_debug1(1, "in %s retrieval function", "PLTE");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_PLTE)
+ && palette != NULL)
+ {
+ *palette = info_ptr->palette;
+ *num_palette = info_ptr->num_palette;
+ png_debug1(3, "num_palette = %d", *num_palette);
+ return (PNG_INFO_PLTE);
+ }
+
+ return (0);
+}
+
+#ifdef PNG_sBIT_SUPPORTED
+png_uint_32 PNGAPI
+png_get_sBIT(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_color_8p *sig_bit)
+{
+ png_debug1(1, "in %s retrieval function", "sBIT");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_sBIT)
+ && sig_bit != NULL)
+ {
+ *sig_bit = &(info_ptr->sig_bit);
+ return (PNG_INFO_sBIT);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_TEXT_SUPPORTED
+int PNGAPI
+png_get_text(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_textp *text_ptr, int *num_text)
+{
+ if (png_ptr != NULL && info_ptr != NULL && info_ptr->num_text > 0)
+ {
+ png_debug1(1, "in 0x%lx retrieval function",
+ (unsigned long)png_ptr->chunk_name);
+
+ if (text_ptr != NULL)
+ *text_ptr = info_ptr->text;
+
+ if (num_text != NULL)
+ *num_text = info_ptr->num_text;
+
+ return info_ptr->num_text;
+ }
+
+ if (num_text != NULL)
+ *num_text = 0;
+
+ return(0);
+}
+#endif
+
+#ifdef PNG_tIME_SUPPORTED
+png_uint_32 PNGAPI
+png_get_tIME(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_timep *mod_time)
+{
+ png_debug1(1, "in %s retrieval function", "tIME");
+
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_tIME)
+ && mod_time != NULL)
+ {
+ *mod_time = &(info_ptr->mod_time);
+ return (PNG_INFO_tIME);
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_tRNS_SUPPORTED
+png_uint_32 PNGAPI
+png_get_tRNS(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_bytep *trans_alpha, int *num_trans, png_color_16p *trans_color)
+{
+ png_uint_32 retval = 0;
+ if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS))
+ {
+ png_debug1(1, "in %s retrieval function", "tRNS");
+
+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (trans_alpha != NULL)
+ {
+ *trans_alpha = info_ptr->trans_alpha;
+ retval |= PNG_INFO_tRNS;
+ }
+
+ if (trans_color != NULL)
+ *trans_color = &(info_ptr->trans_color);
+ }
+
+ else /* if (info_ptr->color_type != PNG_COLOR_TYPE_PALETTE) */
+ {
+ if (trans_color != NULL)
+ {
+ *trans_color = &(info_ptr->trans_color);
+ retval |= PNG_INFO_tRNS;
+ }
+
+ if (trans_alpha != NULL)
+ *trans_alpha = NULL;
+ }
+
+ if (num_trans != NULL)
+ {
+ *num_trans = info_ptr->num_trans;
+ retval |= PNG_INFO_tRNS;
+ }
+ }
+
+ return (retval);
+}
+#endif
+
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+int PNGAPI
+png_get_unknown_chunks(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_unknown_chunkpp unknowns)
+{
+ if (png_ptr != NULL && info_ptr != NULL && unknowns != NULL)
+ {
+ *unknowns = info_ptr->unknown_chunks;
+ return info_ptr->unknown_chunks_num;
+ }
+
+ return (0);
+}
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+png_byte PNGAPI
+png_get_rgb_to_gray_status (png_const_structrp png_ptr)
+{
+ return (png_byte)(png_ptr ? png_ptr->rgb_to_gray_status : 0);
+}
+#endif
+
+#ifdef PNG_USER_CHUNKS_SUPPORTED
+png_voidp PNGAPI
+png_get_user_chunk_ptr(png_const_structrp png_ptr)
+{
+ return (png_ptr ? png_ptr->user_chunk_ptr : NULL);
+}
+#endif
+
+png_size_t PNGAPI
+png_get_compression_buffer_size(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return 0;
+
+# ifdef PNG_WRITE_SUPPORTED
+ if (png_ptr->mode & PNG_IS_READ_STRUCT)
+# endif
+ {
+# ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+ return png_ptr->IDAT_read_size;
+# else
+ return PNG_IDAT_READ_SIZE;
+# endif
+ }
+
+# ifdef PNG_WRITE_SUPPORTED
+ else
+ return png_ptr->zbuffer_size;
+# endif
+}
+
+#ifdef PNG_SET_USER_LIMITS_SUPPORTED
+/* These functions were added to libpng 1.2.6 and were enabled
+ * by default in libpng-1.4.0 */
+png_uint_32 PNGAPI
+png_get_user_width_max (png_const_structrp png_ptr)
+{
+ return (png_ptr ? png_ptr->user_width_max : 0);
+}
+
+png_uint_32 PNGAPI
+png_get_user_height_max (png_const_structrp png_ptr)
+{
+ return (png_ptr ? png_ptr->user_height_max : 0);
+}
+
+/* This function was added to libpng 1.4.0 */
+png_uint_32 PNGAPI
+png_get_chunk_cache_max (png_const_structrp png_ptr)
+{
+ return (png_ptr ? png_ptr->user_chunk_cache_max : 0);
+}
+
+/* This function was added to libpng 1.4.1 */
+png_alloc_size_t PNGAPI
+png_get_chunk_malloc_max (png_const_structrp png_ptr)
+{
+ return (png_ptr ? png_ptr->user_chunk_malloc_max : 0);
+}
+#endif /* ?PNG_SET_USER_LIMITS_SUPPORTED */
+
+/* These functions were added to libpng 1.4.0 */
+#ifdef PNG_IO_STATE_SUPPORTED
+png_uint_32 PNGAPI
+png_get_io_state (png_const_structrp png_ptr)
+{
+ return png_ptr->io_state;
+}
+
+png_uint_32 PNGAPI
+png_get_io_chunk_type (png_const_structrp png_ptr)
+{
+ return png_ptr->chunk_name;
+}
+#endif /* ?PNG_IO_STATE_SUPPORTED */
+
+#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED
+# ifdef PNG_GET_PALETTE_MAX_SUPPORTED
+int PNGAPI
+png_get_palette_max(png_const_structp png_ptr, png_const_infop info_ptr)
+{
+ if (png_ptr != NULL && info_ptr != NULL)
+ return png_ptr->num_palette_max;
+
+ return (-1);
+}
+# endif
+#endif
+
+#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pnginfo.h b/ml/dlib/dlib/external/libpng/pnginfo.h
new file mode 100644
index 000000000..26bf26502
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pnginfo.h
@@ -0,0 +1,260 @@
+
+/* pnginfo.h - header file for PNG reference library
+ *
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * Last changed in libpng 1.6.1 [March 28, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+ /* png_info is a structure that holds the information in a PNG file so
+ * that the application can find out the characteristics of the image.
+ * If you are reading the file, this structure will tell you what is
+ * in the PNG file. If you are writing the file, fill in the information
+ * you want to put into the PNG file, using png_set_*() functions, then
+ * call png_write_info().
+ *
+ * The names chosen should be very close to the PNG specification, so
+ * consult that document for information about the meaning of each field.
+ *
+ * With libpng < 0.95, it was only possible to directly set and read the
+ * the values in the png_info_struct, which meant that the contents and
+ * order of the values had to remain fixed. With libpng 0.95 and later,
+ * however, there are now functions that abstract the contents of
+ * png_info_struct from the application, so this makes it easier to use
+ * libpng with dynamic libraries, and even makes it possible to use
+ * libraries that don't have all of the libpng ancillary chunk-handing
+ * functionality. In libpng-1.5.0 this was moved into a separate private
+ * file that is not visible to applications.
+ *
+ * The following members may have allocated storage attached that should be
+ * cleaned up before the structure is discarded: palette, trans, text,
+ * pcal_purpose, pcal_units, pcal_params, hist, iccp_name, iccp_profile,
+ * splt_palettes, scal_unit, row_pointers, and unknowns. By default, these
+ * are automatically freed when the info structure is deallocated, if they were
+ * allocated internally by libpng. This behavior can be changed by means
+ * of the png_data_freer() function.
+ *
+ * More allocation details: all the chunk-reading functions that
+ * change these members go through the corresponding png_set_*
+ * functions. A function to clear these members is available: see
+ * png_free_data(). The png_set_* functions do not depend on being
+ * able to point info structure members to any of the storage they are
+ * passed (they make their own copies), EXCEPT that the png_set_text
+ * functions use the same storage passed to them in the text_ptr or
+ * itxt_ptr structure argument, and the png_set_rows and png_set_unknowns
+ * functions do not make their own copies.
+ */
+#ifndef PNGINFO_H
+#define PNGINFO_H
+
+struct png_info_def
+{
+ /* The following are necessary for every PNG file */
+ png_uint_32 width; /* width of image in pixels (from IHDR) */
+ png_uint_32 height; /* height of image in pixels (from IHDR) */
+ png_uint_32 valid; /* valid chunk data (see PNG_INFO_ below) */
+ png_size_t rowbytes; /* bytes needed to hold an untransformed row */
+ png_colorp palette; /* array of color values (valid & PNG_INFO_PLTE) */
+ png_uint_16 num_palette; /* number of color entries in "palette" (PLTE) */
+ png_uint_16 num_trans; /* number of transparent palette color (tRNS) */
+ png_byte bit_depth; /* 1, 2, 4, 8, or 16 bits/channel (from IHDR) */
+ png_byte color_type; /* see PNG_COLOR_TYPE_ below (from IHDR) */
+ /* The following three should have been named *_method not *_type */
+ png_byte compression_type; /* must be PNG_COMPRESSION_TYPE_BASE (IHDR) */
+ png_byte filter_type; /* must be PNG_FILTER_TYPE_BASE (from IHDR) */
+ png_byte interlace_type; /* One of PNG_INTERLACE_NONE, PNG_INTERLACE_ADAM7 */
+
+ /* The following are set by png_set_IHDR, called from the application on
+ * write, but the are never actually used by the write code.
+ */
+ png_byte channels; /* number of data channels per pixel (1, 2, 3, 4) */
+ png_byte pixel_depth; /* number of bits per pixel */
+ png_byte spare_byte; /* to align the data, and for future use */
+
+#ifdef PNG_READ_SUPPORTED
+ /* This is never set during write */
+ png_byte signature[8]; /* magic bytes read by libpng from start of file */
+#endif
+
+ /* The rest of the data is optional. If you are reading, check the
+ * valid field to see if the information in these are valid. If you
+ * are writing, set the valid field to those chunks you want written,
+ * and initialize the appropriate fields below.
+ */
+
+#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED)
+ /* png_colorspace only contains 'flags' if neither GAMMA or COLORSPACE are
+ * defined. When COLORSPACE is switched on all the colorspace-defining
+ * chunks should be enabled, when GAMMA is switched on all the gamma-defining
+ * chunks should be enabled. If this is not done it becomes possible to read
+ * inconsistent PNG files and assign a probably incorrect interpretation to
+ * the information. (In other words, by carefully choosing which chunks to
+ * recognize the system configuration can select an interpretation for PNG
+ * files containing ambiguous data and this will result in inconsistent
+ * behavior between different libpng builds!)
+ */
+ png_colorspace colorspace;
+#endif
+
+#ifdef PNG_iCCP_SUPPORTED
+ /* iCCP chunk data. */
+ png_charp iccp_name; /* profile name */
+ png_bytep iccp_profile; /* International Color Consortium profile data */
+ png_uint_32 iccp_proflen; /* ICC profile data length */
+#endif
+
+#ifdef PNG_TEXT_SUPPORTED
+ /* The tEXt, and zTXt chunks contain human-readable textual data in
+ * uncompressed, compressed, and optionally compressed forms, respectively.
+ * The data in "text" is an array of pointers to uncompressed,
+ * null-terminated C strings. Each chunk has a keyword that describes the
+ * textual data contained in that chunk. Keywords are not required to be
+ * unique, and the text string may be empty. Any number of text chunks may
+ * be in an image.
+ */
+ int num_text; /* number of comments read or comments to write */
+ int max_text; /* current size of text array */
+ png_textp text; /* array of comments read or comments to write */
+#endif /* PNG_TEXT_SUPPORTED */
+
+#ifdef PNG_tIME_SUPPORTED
+ /* The tIME chunk holds the last time the displayed image data was
+ * modified. See the png_time struct for the contents of this struct.
+ */
+ png_time mod_time;
+#endif
+
+#ifdef PNG_sBIT_SUPPORTED
+ /* The sBIT chunk specifies the number of significant high-order bits
+ * in the pixel data. Values are in the range [1, bit_depth], and are
+ * only specified for the channels in the pixel data. The contents of
+ * the low-order bits is not specified. Data is valid if
+ * (valid & PNG_INFO_sBIT) is non-zero.
+ */
+ png_color_8 sig_bit; /* significant bits in color channels */
+#endif
+
+#if defined(PNG_tRNS_SUPPORTED) || defined(PNG_READ_EXPAND_SUPPORTED) || \
+defined(PNG_READ_BACKGROUND_SUPPORTED)
+ /* The tRNS chunk supplies transparency data for paletted images and
+ * other image types that don't need a full alpha channel. There are
+ * "num_trans" transparency values for a paletted image, stored in the
+ * same order as the palette colors, starting from index 0. Values
+ * for the data are in the range [0, 255], ranging from fully transparent
+ * to fully opaque, respectively. For non-paletted images, there is a
+ * single color specified that should be treated as fully transparent.
+ * Data is valid if (valid & PNG_INFO_tRNS) is non-zero.
+ */
+ png_bytep trans_alpha; /* alpha values for paletted image */
+ png_color_16 trans_color; /* transparent color for non-palette image */
+#endif
+
+#if defined(PNG_bKGD_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED)
+ /* The bKGD chunk gives the suggested image background color if the
+ * display program does not have its own background color and the image
+ * is needs to composited onto a background before display. The colors
+ * in "background" are normally in the same color space/depth as the
+ * pixel data. Data is valid if (valid & PNG_INFO_bKGD) is non-zero.
+ */
+ png_color_16 background;
+#endif
+
+#ifdef PNG_oFFs_SUPPORTED
+ /* The oFFs chunk gives the offset in "offset_unit_type" units rightwards
+ * and downwards from the top-left corner of the display, page, or other
+ * application-specific co-ordinate space. See the PNG_OFFSET_ defines
+ * below for the unit types. Valid if (valid & PNG_INFO_oFFs) non-zero.
+ */
+ png_int_32 x_offset; /* x offset on page */
+ png_int_32 y_offset; /* y offset on page */
+ png_byte offset_unit_type; /* offset units type */
+#endif
+
+#ifdef PNG_pHYs_SUPPORTED
+ /* The pHYs chunk gives the physical pixel density of the image for
+ * display or printing in "phys_unit_type" units (see PNG_RESOLUTION_
+ * defines below). Data is valid if (valid & PNG_INFO_pHYs) is non-zero.
+ */
+ png_uint_32 x_pixels_per_unit; /* horizontal pixel density */
+ png_uint_32 y_pixels_per_unit; /* vertical pixel density */
+ png_byte phys_unit_type; /* resolution type (see PNG_RESOLUTION_ below) */
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+ /* The hIST chunk contains the relative frequency or importance of the
+ * various palette entries, so that a viewer can intelligently select a
+ * reduced-color palette, if required. Data is an array of "num_palette"
+ * values in the range [0,65535]. Data valid if (valid & PNG_INFO_hIST)
+ * is non-zero.
+ */
+ png_uint_16p hist;
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+ /* The pCAL chunk describes a transformation between the stored pixel
+ * values and original physical data values used to create the image.
+ * The integer range [0, 2^bit_depth - 1] maps to the floating-point
+ * range given by [pcal_X0, pcal_X1], and are further transformed by a
+ * (possibly non-linear) transformation function given by "pcal_type"
+ * and "pcal_params" into "pcal_units". Please see the PNG_EQUATION_
+ * defines below, and the PNG-Group's PNG extensions document for a
+ * complete description of the transformations and how they should be
+ * implemented, and for a description of the ASCII parameter strings.
+ * Data values are valid if (valid & PNG_INFO_pCAL) non-zero.
+ */
+ png_charp pcal_purpose; /* pCAL chunk description string */
+ png_int_32 pcal_X0; /* minimum value */
+ png_int_32 pcal_X1; /* maximum value */
+ png_charp pcal_units; /* Latin-1 string giving physical units */
+ png_charpp pcal_params; /* ASCII strings containing parameter values */
+ png_byte pcal_type; /* equation type (see PNG_EQUATION_ below) */
+ png_byte pcal_nparams; /* number of parameters given in pcal_params */
+#endif
+
+/* New members added in libpng-1.0.6 */
+ png_uint_32 free_me; /* flags items libpng is responsible for freeing */
+
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+ /* Storage for unknown chunks that the library doesn't recognize. */
+ png_unknown_chunkp unknown_chunks;
+
+ /* The type of this field is limited by the type of
+ * png_struct::user_chunk_cache_max, else overflow can occur.
+ */
+ int unknown_chunks_num;
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+ /* Data on sPLT chunks (there may be more than one). */
+ png_sPLT_tp splt_palettes;
+ int splt_palettes_num; /* Match type returned by png_get API */
+#endif
+
+#ifdef PNG_sCAL_SUPPORTED
+ /* The sCAL chunk describes the actual physical dimensions of the
+ * subject matter of the graphic. The chunk contains a unit specification
+ * a byte value, and two ASCII strings representing floating-point
+ * values. The values are width and height corresponsing to one pixel
+ * in the image. Data values are valid if (valid & PNG_INFO_sCAL) is
+ * non-zero.
+ */
+ png_byte scal_unit; /* unit of physical scale */
+ png_charp scal_s_width; /* string containing height */
+ png_charp scal_s_height; /* string containing width */
+#endif
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+ /* Memory has been allocated if (valid & PNG_ALLOCATED_INFO_ROWS)
+ non-zero */
+ /* Data valid if (valid & PNG_INFO_IDAT) non-zero */
+ png_bytepp row_pointers; /* the image bits */
+#endif
+
+};
+#endif /* PNGINFO_H */
diff --git a/ml/dlib/dlib/external/libpng/pnglibconf.h b/ml/dlib/dlib/external/libpng/pnglibconf.h
new file mode 100644
index 000000000..6e9a41152
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pnglibconf.h
@@ -0,0 +1,211 @@
+/* libpng 1.6.7 STANDARD API DEFINITION */
+
+/* pnglibconf.h - library build configuration */
+
+/* Libpng version 1.6.7 - November 14, 2013 */
+
+/* Copyright (c) 1998-2013 Glenn Randers-Pehrson */
+
+/* This code is released under the libpng license. */
+/* For conditions of distribution and use, see the disclaimer */
+/* and license in png.h */
+
+/* pnglibconf.h */
+/* Machine generated file: DO NOT EDIT */
+/* Derived from: scripts/pnglibconf.dfa */
+#ifndef PNGLCONF_H
+#define PNGLCONF_H
+/* options */
+#define PNG_16BIT_SUPPORTED
+#define PNG_ALIGNED_MEMORY_SUPPORTED
+/*#undef PNG_ARM_NEON_API_SUPPORTED*/
+/*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/
+#define PNG_BENIGN_ERRORS_SUPPORTED
+#define PNG_BENIGN_READ_ERRORS_SUPPORTED
+/*#undef PNG_BENIGN_WRITE_ERRORS_SUPPORTED*/
+#define PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED
+#define PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED
+#define PNG_COLORSPACE_SUPPORTED
+#define PNG_CONSOLE_IO_SUPPORTED
+#define PNG_CONVERT_tIME_SUPPORTED
+#define PNG_EASY_ACCESS_SUPPORTED
+/*#undef PNG_ERROR_NUMBERS_SUPPORTED*/
+#define PNG_ERROR_TEXT_SUPPORTED
+#define PNG_FIXED_POINT_SUPPORTED
+#define PNG_FLOATING_ARITHMETIC_SUPPORTED
+#define PNG_FLOATING_POINT_SUPPORTED
+#define PNG_FORMAT_AFIRST_SUPPORTED
+#define PNG_FORMAT_BGR_SUPPORTED
+#define PNG_GAMMA_SUPPORTED
+#define PNG_GET_PALETTE_MAX_SUPPORTED
+#define PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+#define PNG_INCH_CONVERSIONS_SUPPORTED
+#define PNG_INFO_IMAGE_SUPPORTED
+#define PNG_IO_STATE_SUPPORTED
+#define PNG_MNG_FEATURES_SUPPORTED
+#define PNG_POINTER_INDEXING_SUPPORTED
+#define PNG_PROGRESSIVE_READ_SUPPORTED
+#define PNG_READ_16BIT_SUPPORTED
+#define PNG_READ_ALPHA_MODE_SUPPORTED
+#define PNG_READ_ANCILLARY_CHUNKS_SUPPORTED
+#define PNG_READ_BACKGROUND_SUPPORTED
+#define PNG_READ_BGR_SUPPORTED
+#define PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED
+#define PNG_READ_COMPOSITE_NODIV_SUPPORTED
+#define PNG_READ_COMPRESSED_TEXT_SUPPORTED
+#define PNG_READ_EXPAND_16_SUPPORTED
+#define PNG_READ_EXPAND_SUPPORTED
+#define PNG_READ_FILLER_SUPPORTED
+#define PNG_READ_GAMMA_SUPPORTED
+#define PNG_READ_GET_PALETTE_MAX_SUPPORTED
+#define PNG_READ_GRAY_TO_RGB_SUPPORTED
+#define PNG_READ_INTERLACING_SUPPORTED
+#define PNG_READ_INT_FUNCTIONS_SUPPORTED
+#define PNG_READ_INVERT_ALPHA_SUPPORTED
+#define PNG_READ_INVERT_SUPPORTED
+#define PNG_READ_OPT_PLTE_SUPPORTED
+#define PNG_READ_PACKSWAP_SUPPORTED
+#define PNG_READ_PACK_SUPPORTED
+#define PNG_READ_QUANTIZE_SUPPORTED
+#define PNG_READ_RGB_TO_GRAY_SUPPORTED
+#define PNG_READ_SCALE_16_TO_8_SUPPORTED
+#define PNG_READ_SHIFT_SUPPORTED
+#define PNG_READ_STRIP_16_TO_8_SUPPORTED
+#define PNG_READ_STRIP_ALPHA_SUPPORTED
+#define PNG_READ_SUPPORTED
+#define PNG_READ_SWAP_ALPHA_SUPPORTED
+#define PNG_READ_SWAP_SUPPORTED
+#define PNG_READ_TEXT_SUPPORTED
+#define PNG_READ_TRANSFORMS_SUPPORTED
+#define PNG_READ_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_READ_USER_CHUNKS_SUPPORTED
+#define PNG_READ_USER_TRANSFORM_SUPPORTED
+#define PNG_READ_bKGD_SUPPORTED
+#define PNG_READ_cHRM_SUPPORTED
+#define PNG_READ_gAMA_SUPPORTED
+#define PNG_READ_hIST_SUPPORTED
+#define PNG_READ_iCCP_SUPPORTED
+#define PNG_READ_iTXt_SUPPORTED
+#define PNG_READ_oFFs_SUPPORTED
+#define PNG_READ_pCAL_SUPPORTED
+#define PNG_READ_pHYs_SUPPORTED
+#define PNG_READ_sBIT_SUPPORTED
+#define PNG_READ_sCAL_SUPPORTED
+#define PNG_READ_sPLT_SUPPORTED
+#define PNG_READ_sRGB_SUPPORTED
+#define PNG_READ_tEXt_SUPPORTED
+#define PNG_READ_tIME_SUPPORTED
+#define PNG_READ_tRNS_SUPPORTED
+#define PNG_READ_zTXt_SUPPORTED
+/*#undef PNG_SAFE_LIMITS_SUPPORTED*/
+#define PNG_SAVE_INT_32_SUPPORTED
+#define PNG_SAVE_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_SEQUENTIAL_READ_SUPPORTED
+#define PNG_SETJMP_SUPPORTED
+#define PNG_SET_CHUNK_CACHE_LIMIT_SUPPORTED
+#define PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED
+#define PNG_SET_OPTION_SUPPORTED
+#define PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_SET_USER_LIMITS_SUPPORTED
+#define PNG_SIMPLIFIED_READ_AFIRST_SUPPORTED
+#define PNG_SIMPLIFIED_READ_BGR_SUPPORTED
+#define PNG_SIMPLIFIED_READ_SUPPORTED
+#define PNG_SIMPLIFIED_WRITE_AFIRST_SUPPORTED
+#define PNG_SIMPLIFIED_WRITE_BGR_SUPPORTED
+#define PNG_SIMPLIFIED_WRITE_SUPPORTED
+#define PNG_STDIO_SUPPORTED
+#define PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_TEXT_SUPPORTED
+#define PNG_TIME_RFC1123_SUPPORTED
+#define PNG_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_USER_CHUNKS_SUPPORTED
+#define PNG_USER_LIMITS_SUPPORTED
+#define PNG_USER_MEM_SUPPORTED
+#define PNG_USER_TRANSFORM_INFO_SUPPORTED
+#define PNG_USER_TRANSFORM_PTR_SUPPORTED
+#define PNG_WARNINGS_SUPPORTED
+#define PNG_WRITE_16BIT_SUPPORTED
+#define PNG_WRITE_ANCILLARY_CHUNKS_SUPPORTED
+#define PNG_WRITE_BGR_SUPPORTED
+#define PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED
+#define PNG_WRITE_COMPRESSED_TEXT_SUPPORTED
+#define PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED
+#define PNG_WRITE_FILLER_SUPPORTED
+#define PNG_WRITE_FILTER_SUPPORTED
+#define PNG_WRITE_FLUSH_SUPPORTED
+#define PNG_WRITE_GET_PALETTE_MAX_SUPPORTED
+#define PNG_WRITE_INTERLACING_SUPPORTED
+#define PNG_WRITE_INT_FUNCTIONS_SUPPORTED
+#define PNG_WRITE_INVERT_ALPHA_SUPPORTED
+#define PNG_WRITE_INVERT_SUPPORTED
+#define PNG_WRITE_OPTIMIZE_CMF_SUPPORTED
+#define PNG_WRITE_PACKSWAP_SUPPORTED
+#define PNG_WRITE_PACK_SUPPORTED
+#define PNG_WRITE_SHIFT_SUPPORTED
+#define PNG_WRITE_SUPPORTED
+#define PNG_WRITE_SWAP_ALPHA_SUPPORTED
+#define PNG_WRITE_SWAP_SUPPORTED
+#define PNG_WRITE_TEXT_SUPPORTED
+#define PNG_WRITE_TRANSFORMS_SUPPORTED
+#define PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED
+#define PNG_WRITE_USER_TRANSFORM_SUPPORTED
+#define PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+#define PNG_WRITE_bKGD_SUPPORTED
+#define PNG_WRITE_cHRM_SUPPORTED
+#define PNG_WRITE_gAMA_SUPPORTED
+#define PNG_WRITE_hIST_SUPPORTED
+#define PNG_WRITE_iCCP_SUPPORTED
+#define PNG_WRITE_iTXt_SUPPORTED
+#define PNG_WRITE_oFFs_SUPPORTED
+#define PNG_WRITE_pCAL_SUPPORTED
+#define PNG_WRITE_pHYs_SUPPORTED
+#define PNG_WRITE_sBIT_SUPPORTED
+#define PNG_WRITE_sCAL_SUPPORTED
+#define PNG_WRITE_sPLT_SUPPORTED
+#define PNG_WRITE_sRGB_SUPPORTED
+#define PNG_WRITE_tEXt_SUPPORTED
+#define PNG_WRITE_tIME_SUPPORTED
+#define PNG_WRITE_tRNS_SUPPORTED
+#define PNG_WRITE_zTXt_SUPPORTED
+#define PNG_bKGD_SUPPORTED
+#define PNG_cHRM_SUPPORTED
+#define PNG_gAMA_SUPPORTED
+#define PNG_hIST_SUPPORTED
+#define PNG_iCCP_SUPPORTED
+#define PNG_iTXt_SUPPORTED
+#define PNG_oFFs_SUPPORTED
+#define PNG_pCAL_SUPPORTED
+#define PNG_pHYs_SUPPORTED
+#define PNG_sBIT_SUPPORTED
+#define PNG_sCAL_SUPPORTED
+#define PNG_sPLT_SUPPORTED
+#define PNG_sRGB_SUPPORTED
+#define PNG_tEXt_SUPPORTED
+#define PNG_tIME_SUPPORTED
+#define PNG_tRNS_SUPPORTED
+#define PNG_zTXt_SUPPORTED
+/* end of options */
+/* settings */
+#define PNG_API_RULE 0
+#define PNG_CALLOC_SUPPORTED
+#define PNG_COST_SHIFT 3
+#define PNG_DEFAULT_READ_MACROS 1
+#define PNG_GAMMA_THRESHOLD_FIXED 5000
+#define PNG_IDAT_READ_SIZE PNG_ZBUF_SIZE
+#define PNG_INFLATE_BUF_SIZE 1024
+#define PNG_MAX_GAMMA_8 11
+#define PNG_QUANTIZE_BLUE_BITS 5
+#define PNG_QUANTIZE_GREEN_BITS 5
+#define PNG_QUANTIZE_RED_BITS 5
+#define PNG_TEXT_Z_DEFAULT_COMPRESSION (-1)
+#define PNG_TEXT_Z_DEFAULT_STRATEGY 0
+#define PNG_WEIGHT_SHIFT 8
+#define PNG_ZBUF_SIZE 8192
+#define PNG_ZLIB_VERNUM 0 /* unknown */
+#define PNG_Z_DEFAULT_COMPRESSION (-1)
+#define PNG_Z_DEFAULT_NOFILTER_STRATEGY 0
+#define PNG_Z_DEFAULT_STRATEGY 1
+#define PNG_sCAL_PRECISION 5
+#define PNG_sRGB_PROFILE_CHECKS 2
+/* end of settings */
+#endif /* PNGLCONF_H */
diff --git a/ml/dlib/dlib/external/libpng/pngmem.c b/ml/dlib/dlib/external/libpng/pngmem.c
new file mode 100644
index 000000000..b9b3efb44
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngmem.c
@@ -0,0 +1,277 @@
+
+/* pngmem.c - stub functions for memory allocation
+ *
+ * Last changed in libpng 1.6.0 [February 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file provides a location for all memory allocation. Users who
+ * need special memory handling are expected to supply replacement
+ * functions for png_malloc() and png_free(), and to use
+ * png_create_read_struct_2() and png_create_write_struct_2() to
+ * identify the replacement functions.
+ */
+
+#include "pngpriv.h"
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+/* Free a png_struct */
+void /* PRIVATE */
+png_destroy_png_struct(png_structrp png_ptr)
+{
+ if (png_ptr != NULL)
+ {
+ /* png_free might call png_error and may certainly call
+ * png_get_mem_ptr, so fake a temporary png_struct to support this.
+ */
+ png_struct dummy_struct = *png_ptr;
+ memset(png_ptr, 0, (sizeof *png_ptr));
+ png_free(&dummy_struct, png_ptr);
+
+# ifdef PNG_SETJMP_SUPPORTED
+ /* We may have a jmp_buf left to deallocate. */
+ png_free_jmpbuf(&dummy_struct);
+# endif
+ }
+}
+
+/* Allocate memory. For reasonable files, size should never exceed
+ * 64K. However, zlib may allocate more then 64K if you don't tell
+ * it not to. See zconf.h and png.h for more information. zlib does
+ * need to allocate exactly 64K, so whatever you call here must
+ * have the ability to do that.
+ */
+PNG_FUNCTION(png_voidp,PNGAPI
+png_calloc,(png_const_structrp png_ptr, png_alloc_size_t size),PNG_ALLOCATED)
+{
+ png_voidp ret;
+
+ ret = png_malloc(png_ptr, size);
+
+ if (ret != NULL)
+ memset(ret, 0, size);
+
+ return ret;
+}
+
+/* png_malloc_base, an internal function added at libpng 1.6.0, does the work of
+ * allocating memory, taking into account limits and PNG_USER_MEM_SUPPORTED.
+ * Checking and error handling must happen outside this routine; it returns NULL
+ * if the allocation cannot be done (for any reason.)
+ */
+PNG_FUNCTION(png_voidp /* PRIVATE */,
+png_malloc_base,(png_const_structrp png_ptr, png_alloc_size_t size),
+ PNG_ALLOCATED)
+{
+ /* Moved to png_malloc_base from png_malloc_default in 1.6.0; the DOS
+ * allocators have also been removed in 1.6.0, so any 16-bit system now has
+ * to implement a user memory handler. This checks to be sure it isn't
+ * called with big numbers.
+ */
+#ifdef PNG_USER_MEM_SUPPORTED
+ PNG_UNUSED(png_ptr)
+#endif
+ if (size > 0 && size <= PNG_SIZE_MAX
+# ifdef PNG_MAX_MALLOC_64K
+ && size <= 65536U
+# endif
+ )
+ {
+#ifdef PNG_USER_MEM_SUPPORTED
+ if (png_ptr != NULL && png_ptr->malloc_fn != NULL)
+ return png_ptr->malloc_fn(png_constcast(png_structrp,png_ptr), size);
+
+ else
+#endif
+ return malloc((size_t)size); /* checked for truncation above */
+ }
+
+ else
+ return NULL;
+}
+
+/* This is really here only to work round a spurious warning in GCC 4.6 and 4.7
+ * that arises because of the checks in png_realloc_array that are repeated in
+ * png_malloc_array.
+ */
+static png_voidp
+png_malloc_array_checked(png_const_structrp png_ptr, int nelements,
+ size_t element_size)
+{
+ png_alloc_size_t req = nelements; /* known to be > 0 */
+
+ if (req <= PNG_SIZE_MAX/element_size)
+ return png_malloc_base(png_ptr, req * element_size);
+
+ /* The failure case when the request is too large */
+ return NULL;
+}
+
+PNG_FUNCTION(png_voidp /* PRIVATE */,
+png_malloc_array,(png_const_structrp png_ptr, int nelements,
+ size_t element_size),PNG_ALLOCATED)
+{
+ if (nelements <= 0 || element_size == 0)
+ png_error(png_ptr, "internal error: array alloc");
+
+ return png_malloc_array_checked(png_ptr, nelements, element_size);
+}
+
+PNG_FUNCTION(png_voidp /* PRIVATE */,
+png_realloc_array,(png_const_structrp png_ptr, png_const_voidp old_array,
+ int old_elements, int add_elements, size_t element_size),PNG_ALLOCATED)
+{
+ /* These are internal errors: */
+ if (add_elements <= 0 || element_size == 0 || old_elements < 0 ||
+ (old_array == NULL && old_elements > 0))
+ png_error(png_ptr, "internal error: array realloc");
+
+ /* Check for overflow on the elements count (so the caller does not have to
+ * check.)
+ */
+ if (add_elements <= INT_MAX - old_elements)
+ {
+ png_voidp new_array = png_malloc_array_checked(png_ptr,
+ old_elements+add_elements, element_size);
+
+ if (new_array != NULL)
+ {
+ /* Because png_malloc_array worked the size calculations below cannot
+ * overflow.
+ */
+ if (old_elements > 0)
+ memcpy(new_array, old_array, element_size*(unsigned)old_elements);
+
+ memset((char*)new_array + element_size*(unsigned)old_elements, 0,
+ element_size*(unsigned)add_elements);
+
+ return new_array;
+ }
+ }
+
+ return NULL; /* error */
+}
+
+/* Various functions that have different error handling are derived from this.
+ * png_malloc always exists, but if PNG_USER_MEM_SUPPORTED is defined a separate
+ * function png_malloc_default is also provided.
+ */
+PNG_FUNCTION(png_voidp,PNGAPI
+png_malloc,(png_const_structrp png_ptr, png_alloc_size_t size),PNG_ALLOCATED)
+{
+ png_voidp ret;
+
+ if (png_ptr == NULL)
+ return NULL;
+
+ ret = png_malloc_base(png_ptr, size);
+
+ if (ret == NULL)
+ png_error(png_ptr, "Out of memory"); /* 'm' means png_malloc */
+
+ return ret;
+}
+
+#ifdef PNG_USER_MEM_SUPPORTED
+PNG_FUNCTION(png_voidp,PNGAPI
+png_malloc_default,(png_const_structrp png_ptr, png_alloc_size_t size),
+ PNG_ALLOCATED PNG_DEPRECATED)
+{
+ png_voidp ret;
+
+ if (png_ptr == NULL)
+ return NULL;
+
+ /* Passing 'NULL' here bypasses the application provided memory handler. */
+ ret = png_malloc_base(NULL/*use malloc*/, size);
+
+ if (ret == NULL)
+ png_error(png_ptr, "Out of Memory"); /* 'M' means png_malloc_default */
+
+ return ret;
+}
+#endif /* PNG_USER_MEM_SUPPORTED */
+
+/* This function was added at libpng version 1.2.3. The png_malloc_warn()
+ * function will issue a png_warning and return NULL instead of issuing a
+ * png_error, if it fails to allocate the requested memory.
+ */
+PNG_FUNCTION(png_voidp,PNGAPI
+png_malloc_warn,(png_const_structrp png_ptr, png_alloc_size_t size),
+ PNG_ALLOCATED)
+{
+ if (png_ptr != NULL)
+ {
+ png_voidp ret = png_malloc_base(png_ptr, size);
+
+ if (ret != NULL)
+ return ret;
+
+ png_warning(png_ptr, "Out of memory");
+ }
+
+ return NULL;
+}
+
+/* Free a pointer allocated by png_malloc(). If ptr is NULL, return
+ * without taking any action.
+ */
+void PNGAPI
+png_free(png_const_structrp png_ptr, png_voidp ptr)
+{
+ if (png_ptr == NULL || ptr == NULL)
+ return;
+
+#ifdef PNG_USER_MEM_SUPPORTED
+ if (png_ptr->free_fn != NULL)
+ png_ptr->free_fn(png_constcast(png_structrp,png_ptr), ptr);
+
+ else
+ png_free_default(png_ptr, ptr);
+}
+
+PNG_FUNCTION(void,PNGAPI
+png_free_default,(png_const_structrp png_ptr, png_voidp ptr),PNG_DEPRECATED)
+{
+ if (png_ptr == NULL || ptr == NULL)
+ return;
+#endif /* PNG_USER_MEM_SUPPORTED */
+
+ free(ptr);
+}
+
+#ifdef PNG_USER_MEM_SUPPORTED
+/* This function is called when the application wants to use another method
+ * of allocating and freeing memory.
+ */
+void PNGAPI
+png_set_mem_fn(png_structrp png_ptr, png_voidp mem_ptr, png_malloc_ptr
+ malloc_fn, png_free_ptr free_fn)
+{
+ if (png_ptr != NULL)
+ {
+ png_ptr->mem_ptr = mem_ptr;
+ png_ptr->malloc_fn = malloc_fn;
+ png_ptr->free_fn = free_fn;
+ }
+}
+
+/* This function returns a pointer to the mem_ptr associated with the user
+ * functions. The application should free any memory associated with this
+ * pointer before png_write_destroy and png_read_destroy are called.
+ */
+png_voidp PNGAPI
+png_get_mem_ptr(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return NULL;
+
+ return png_ptr->mem_ptr;
+}
+#endif /* PNG_USER_MEM_SUPPORTED */
+#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngpread.c b/ml/dlib/dlib/external/libpng/pngpread.c
new file mode 100644
index 000000000..0169ecb2c
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngpread.c
@@ -0,0 +1,1291 @@
+
+/* pngpread.c - read a png file in push mode
+ *
+ * Last changed in libpng 1.6.0 [February 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+
+/* Push model modes */
+#define PNG_READ_SIG_MODE 0
+#define PNG_READ_CHUNK_MODE 1
+#define PNG_READ_IDAT_MODE 2
+#define PNG_SKIP_MODE 3
+#define PNG_READ_tEXt_MODE 4
+#define PNG_READ_zTXt_MODE 5
+#define PNG_READ_DONE_MODE 6
+#define PNG_READ_iTXt_MODE 7
+#define PNG_ERROR_MODE 8
+
+void PNGAPI
+png_process_data(png_structrp png_ptr, png_inforp info_ptr,
+ png_bytep buffer, png_size_t buffer_size)
+{
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ png_push_restore_buffer(png_ptr, buffer, buffer_size);
+
+ while (png_ptr->buffer_size)
+ {
+ png_process_some_data(png_ptr, info_ptr);
+ }
+}
+
+png_size_t PNGAPI
+png_process_data_pause(png_structrp png_ptr, int save)
+{
+ if (png_ptr != NULL)
+ {
+ /* It's easiest for the caller if we do the save, then the caller doesn't
+ * have to supply the same data again:
+ */
+ if (save)
+ png_push_save_buffer(png_ptr);
+ else
+ {
+ /* This includes any pending saved bytes: */
+ png_size_t remaining = png_ptr->buffer_size;
+ png_ptr->buffer_size = 0;
+
+ /* So subtract the saved buffer size, unless all the data
+ * is actually 'saved', in which case we just return 0
+ */
+ if (png_ptr->save_buffer_size < remaining)
+ return remaining - png_ptr->save_buffer_size;
+ }
+ }
+
+ return 0;
+}
+
+png_uint_32 PNGAPI
+png_process_data_skip(png_structrp png_ptr)
+{
+ png_uint_32 remaining = 0;
+
+ if (png_ptr != NULL && png_ptr->process_mode == PNG_SKIP_MODE &&
+ png_ptr->skip_length > 0)
+ {
+ /* At the end of png_process_data the buffer size must be 0 (see the loop
+ * above) so we can detect a broken call here:
+ */
+ if (png_ptr->buffer_size != 0)
+ png_error(png_ptr,
+ "png_process_data_skip called inside png_process_data");
+
+ /* If is impossible for there to be a saved buffer at this point -
+ * otherwise we could not be in SKIP mode. This will also happen if
+ * png_process_skip is called inside png_process_data (but only very
+ * rarely.)
+ */
+ if (png_ptr->save_buffer_size != 0)
+ png_error(png_ptr, "png_process_data_skip called with saved data");
+
+ remaining = png_ptr->skip_length;
+ png_ptr->skip_length = 0;
+ png_ptr->process_mode = PNG_READ_CHUNK_MODE;
+ }
+
+ return remaining;
+}
+
+/* What we do with the incoming data depends on what we were previously
+ * doing before we ran out of data...
+ */
+void /* PRIVATE */
+png_process_some_data(png_structrp png_ptr, png_inforp info_ptr)
+{
+ if (png_ptr == NULL)
+ return;
+
+ switch (png_ptr->process_mode)
+ {
+ case PNG_READ_SIG_MODE:
+ {
+ png_push_read_sig(png_ptr, info_ptr);
+ break;
+ }
+
+ case PNG_READ_CHUNK_MODE:
+ {
+ png_push_read_chunk(png_ptr, info_ptr);
+ break;
+ }
+
+ case PNG_READ_IDAT_MODE:
+ {
+ png_push_read_IDAT(png_ptr);
+ break;
+ }
+
+ case PNG_SKIP_MODE:
+ {
+ png_push_crc_finish(png_ptr);
+ break;
+ }
+
+ default:
+ {
+ png_ptr->buffer_size = 0;
+ break;
+ }
+ }
+}
+
+/* Read any remaining signature bytes from the stream and compare them with
+ * the correct PNG signature. It is possible that this routine is called
+ * with bytes already read from the signature, either because they have been
+ * checked by the calling application, or because of multiple calls to this
+ * routine.
+ */
+void /* PRIVATE */
+png_push_read_sig(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_size_t num_checked = png_ptr->sig_bytes, /* SAFE, does not exceed 8 */
+ num_to_check = 8 - num_checked;
+
+ if (png_ptr->buffer_size < num_to_check)
+ {
+ num_to_check = png_ptr->buffer_size;
+ }
+
+ png_push_fill_buffer(png_ptr, &(info_ptr->signature[num_checked]),
+ num_to_check);
+ png_ptr->sig_bytes = (png_byte)(png_ptr->sig_bytes + num_to_check);
+
+ if (png_sig_cmp(info_ptr->signature, num_checked, num_to_check))
+ {
+ if (num_checked < 4 &&
+ png_sig_cmp(info_ptr->signature, num_checked, num_to_check - 4))
+ png_error(png_ptr, "Not a PNG file");
+
+ else
+ png_error(png_ptr, "PNG file corrupted by ASCII conversion");
+ }
+ else
+ {
+ if (png_ptr->sig_bytes >= 8)
+ {
+ png_ptr->process_mode = PNG_READ_CHUNK_MODE;
+ }
+ }
+}
+
+void /* PRIVATE */
+png_push_read_chunk(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_uint_32 chunk_name;
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ int keep; /* unknown handling method */
+#endif
+
+ /* First we make sure we have enough data for the 4 byte chunk name
+ * and the 4 byte chunk length before proceeding with decoding the
+ * chunk data. To fully decode each of these chunks, we also make
+ * sure we have enough data in the buffer for the 4 byte CRC at the
+ * end of every chunk (except IDAT, which is handled separately).
+ */
+ if (!(png_ptr->mode & PNG_HAVE_CHUNK_HEADER))
+ {
+ png_byte chunk_length[4];
+ png_byte chunk_tag[4];
+
+ if (png_ptr->buffer_size < 8)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_push_fill_buffer(png_ptr, chunk_length, 4);
+ png_ptr->push_length = png_get_uint_31(png_ptr, chunk_length);
+ png_reset_crc(png_ptr);
+ png_crc_read(png_ptr, chunk_tag, 4);
+ png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(chunk_tag);
+ png_check_chunk_name(png_ptr, png_ptr->chunk_name);
+ png_ptr->mode |= PNG_HAVE_CHUNK_HEADER;
+ }
+
+ chunk_name = png_ptr->chunk_name;
+
+ if (chunk_name == png_IDAT)
+ {
+ if (png_ptr->mode & PNG_AFTER_IDAT)
+ png_ptr->mode |= PNG_HAVE_CHUNK_AFTER_IDAT;
+
+ /* If we reach an IDAT chunk, this means we have read all of the
+ * header chunks, and we can start reading the image (or if this
+ * is called after the image has been read - we have an error).
+ */
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_error(png_ptr, "Missing IHDR before IDAT");
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE &&
+ !(png_ptr->mode & PNG_HAVE_PLTE))
+ png_error(png_ptr, "Missing PLTE before IDAT");
+
+ png_ptr->mode |= PNG_HAVE_IDAT;
+
+ if (!(png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT))
+ if (png_ptr->push_length == 0)
+ return;
+
+ if (png_ptr->mode & PNG_AFTER_IDAT)
+ png_benign_error(png_ptr, "Too many IDATs found");
+ }
+
+ if (chunk_name == png_IHDR)
+ {
+ if (png_ptr->push_length != 13)
+ png_error(png_ptr, "Invalid IHDR length");
+
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_IHDR(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+ else if (chunk_name == png_IEND)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_IEND(png_ptr, info_ptr, png_ptr->push_length);
+
+ png_ptr->process_mode = PNG_READ_DONE_MODE;
+ png_push_have_end(png_ptr, info_ptr);
+ }
+
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_unknown(png_ptr, info_ptr, png_ptr->push_length, keep);
+
+ if (chunk_name == png_PLTE)
+ png_ptr->mode |= PNG_HAVE_PLTE;
+ }
+
+#endif
+ else if (chunk_name == png_PLTE)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+ png_handle_PLTE(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+ else if (chunk_name == png_IDAT)
+ {
+ png_ptr->idat_size = png_ptr->push_length;
+ png_ptr->process_mode = PNG_READ_IDAT_MODE;
+ png_push_have_info(png_ptr, info_ptr);
+ png_ptr->zstream.avail_out =
+ (uInt) PNG_ROWBYTES(png_ptr->pixel_depth,
+ png_ptr->iwidth) + 1;
+ png_ptr->zstream.next_out = png_ptr->row_buf;
+ return;
+ }
+
+#ifdef PNG_READ_gAMA_SUPPORTED
+ else if (png_ptr->chunk_name == png_gAMA)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_gAMA(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_sBIT_SUPPORTED
+ else if (png_ptr->chunk_name == png_sBIT)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_sBIT(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_cHRM_SUPPORTED
+ else if (png_ptr->chunk_name == png_cHRM)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_cHRM(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_sRGB_SUPPORTED
+ else if (chunk_name == png_sRGB)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_sRGB(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_iCCP_SUPPORTED
+ else if (png_ptr->chunk_name == png_iCCP)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_iCCP(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_sPLT_SUPPORTED
+ else if (chunk_name == png_sPLT)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_sPLT(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_tRNS_SUPPORTED
+ else if (chunk_name == png_tRNS)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_tRNS(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_bKGD_SUPPORTED
+ else if (chunk_name == png_bKGD)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_bKGD(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_hIST_SUPPORTED
+ else if (chunk_name == png_hIST)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_hIST(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_pHYs_SUPPORTED
+ else if (chunk_name == png_pHYs)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_pHYs(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_oFFs_SUPPORTED
+ else if (chunk_name == png_oFFs)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_oFFs(png_ptr, info_ptr, png_ptr->push_length);
+ }
+#endif
+
+#ifdef PNG_READ_pCAL_SUPPORTED
+ else if (chunk_name == png_pCAL)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_pCAL(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_sCAL_SUPPORTED
+ else if (chunk_name == png_sCAL)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_sCAL(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_tIME_SUPPORTED
+ else if (chunk_name == png_tIME)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_tIME(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_tEXt_SUPPORTED
+ else if (chunk_name == png_tEXt)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_tEXt(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_zTXt_SUPPORTED
+ else if (chunk_name == png_zTXt)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_zTXt(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+#ifdef PNG_READ_iTXt_SUPPORTED
+ else if (chunk_name == png_iTXt)
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_handle_iTXt(png_ptr, info_ptr, png_ptr->push_length);
+ }
+
+#endif
+ else
+ {
+ if (png_ptr->push_length + 4 > png_ptr->buffer_size)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+ png_handle_unknown(png_ptr, info_ptr, png_ptr->push_length,
+ PNG_HANDLE_CHUNK_AS_DEFAULT);
+ }
+
+ png_ptr->mode &= ~PNG_HAVE_CHUNK_HEADER;
+}
+
+void /* PRIVATE */
+png_push_crc_skip(png_structrp png_ptr, png_uint_32 skip)
+{
+ png_ptr->process_mode = PNG_SKIP_MODE;
+ png_ptr->skip_length = skip;
+}
+
+void /* PRIVATE */
+png_push_crc_finish(png_structrp png_ptr)
+{
+ if (png_ptr->skip_length && png_ptr->save_buffer_size)
+ {
+ png_size_t save_size = png_ptr->save_buffer_size;
+ png_uint_32 skip_length = png_ptr->skip_length;
+
+ /* We want the smaller of 'skip_length' and 'save_buffer_size', but
+ * they are of different types and we don't know which variable has the
+ * fewest bits. Carefully select the smaller and cast it to the type of
+ * the larger - this cannot overflow. Do not cast in the following test
+ * - it will break on either 16 or 64 bit platforms.
+ */
+ if (skip_length < save_size)
+ save_size = (png_size_t)skip_length;
+
+ else
+ skip_length = (png_uint_32)save_size;
+
+ png_calculate_crc(png_ptr, png_ptr->save_buffer_ptr, save_size);
+
+ png_ptr->skip_length -= skip_length;
+ png_ptr->buffer_size -= save_size;
+ png_ptr->save_buffer_size -= save_size;
+ png_ptr->save_buffer_ptr += save_size;
+ }
+ if (png_ptr->skip_length && png_ptr->current_buffer_size)
+ {
+ png_size_t save_size = png_ptr->current_buffer_size;
+ png_uint_32 skip_length = png_ptr->skip_length;
+
+ /* We want the smaller of 'skip_length' and 'current_buffer_size', here,
+ * the same problem exists as above and the same solution.
+ */
+ if (skip_length < save_size)
+ save_size = (png_size_t)skip_length;
+
+ else
+ skip_length = (png_uint_32)save_size;
+
+ png_calculate_crc(png_ptr, png_ptr->current_buffer_ptr, save_size);
+
+ png_ptr->skip_length -= skip_length;
+ png_ptr->buffer_size -= save_size;
+ png_ptr->current_buffer_size -= save_size;
+ png_ptr->current_buffer_ptr += save_size;
+ }
+ if (!png_ptr->skip_length)
+ {
+ if (png_ptr->buffer_size < 4)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_crc_finish(png_ptr, 0);
+ png_ptr->process_mode = PNG_READ_CHUNK_MODE;
+ }
+}
+
+void PNGCBAPI
+png_push_fill_buffer(png_structp png_ptr, png_bytep buffer, png_size_t length)
+{
+ png_bytep ptr;
+
+ if (png_ptr == NULL)
+ return;
+
+ ptr = buffer;
+ if (png_ptr->save_buffer_size)
+ {
+ png_size_t save_size;
+
+ if (length < png_ptr->save_buffer_size)
+ save_size = length;
+
+ else
+ save_size = png_ptr->save_buffer_size;
+
+ memcpy(ptr, png_ptr->save_buffer_ptr, save_size);
+ length -= save_size;
+ ptr += save_size;
+ png_ptr->buffer_size -= save_size;
+ png_ptr->save_buffer_size -= save_size;
+ png_ptr->save_buffer_ptr += save_size;
+ }
+ if (length && png_ptr->current_buffer_size)
+ {
+ png_size_t save_size;
+
+ if (length < png_ptr->current_buffer_size)
+ save_size = length;
+
+ else
+ save_size = png_ptr->current_buffer_size;
+
+ memcpy(ptr, png_ptr->current_buffer_ptr, save_size);
+ png_ptr->buffer_size -= save_size;
+ png_ptr->current_buffer_size -= save_size;
+ png_ptr->current_buffer_ptr += save_size;
+ }
+}
+
+void /* PRIVATE */
+png_push_save_buffer(png_structrp png_ptr)
+{
+ if (png_ptr->save_buffer_size)
+ {
+ if (png_ptr->save_buffer_ptr != png_ptr->save_buffer)
+ {
+ png_size_t i, istop;
+ png_bytep sp;
+ png_bytep dp;
+
+ istop = png_ptr->save_buffer_size;
+ for (i = 0, sp = png_ptr->save_buffer_ptr, dp = png_ptr->save_buffer;
+ i < istop; i++, sp++, dp++)
+ {
+ *dp = *sp;
+ }
+ }
+ }
+ if (png_ptr->save_buffer_size + png_ptr->current_buffer_size >
+ png_ptr->save_buffer_max)
+ {
+ png_size_t new_max;
+ png_bytep old_buffer;
+
+ if (png_ptr->save_buffer_size > PNG_SIZE_MAX -
+ (png_ptr->current_buffer_size + 256))
+ {
+ png_error(png_ptr, "Potential overflow of save_buffer");
+ }
+
+ new_max = png_ptr->save_buffer_size + png_ptr->current_buffer_size + 256;
+ old_buffer = png_ptr->save_buffer;
+ png_ptr->save_buffer = (png_bytep)png_malloc_warn(png_ptr,
+ (png_size_t)new_max);
+
+ if (png_ptr->save_buffer == NULL)
+ {
+ png_free(png_ptr, old_buffer);
+ png_error(png_ptr, "Insufficient memory for save_buffer");
+ }
+
+ memcpy(png_ptr->save_buffer, old_buffer, png_ptr->save_buffer_size);
+ png_free(png_ptr, old_buffer);
+ png_ptr->save_buffer_max = new_max;
+ }
+ if (png_ptr->current_buffer_size)
+ {
+ memcpy(png_ptr->save_buffer + png_ptr->save_buffer_size,
+ png_ptr->current_buffer_ptr, png_ptr->current_buffer_size);
+ png_ptr->save_buffer_size += png_ptr->current_buffer_size;
+ png_ptr->current_buffer_size = 0;
+ }
+ png_ptr->save_buffer_ptr = png_ptr->save_buffer;
+ png_ptr->buffer_size = 0;
+}
+
+void /* PRIVATE */
+png_push_restore_buffer(png_structrp png_ptr, png_bytep buffer,
+ png_size_t buffer_length)
+{
+ png_ptr->current_buffer = buffer;
+ png_ptr->current_buffer_size = buffer_length;
+ png_ptr->buffer_size = buffer_length + png_ptr->save_buffer_size;
+ png_ptr->current_buffer_ptr = png_ptr->current_buffer;
+}
+
+void /* PRIVATE */
+png_push_read_IDAT(png_structrp png_ptr)
+{
+ if (!(png_ptr->mode & PNG_HAVE_CHUNK_HEADER))
+ {
+ png_byte chunk_length[4];
+ png_byte chunk_tag[4];
+
+ /* TODO: this code can be commoned up with the same code in push_read */
+ if (png_ptr->buffer_size < 8)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_push_fill_buffer(png_ptr, chunk_length, 4);
+ png_ptr->push_length = png_get_uint_31(png_ptr, chunk_length);
+ png_reset_crc(png_ptr);
+ png_crc_read(png_ptr, chunk_tag, 4);
+ png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(chunk_tag);
+ png_ptr->mode |= PNG_HAVE_CHUNK_HEADER;
+
+ if (png_ptr->chunk_name != png_IDAT)
+ {
+ png_ptr->process_mode = PNG_READ_CHUNK_MODE;
+
+ if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED))
+ png_error(png_ptr, "Not enough compressed data");
+
+ return;
+ }
+
+ png_ptr->idat_size = png_ptr->push_length;
+ }
+
+ if (png_ptr->idat_size && png_ptr->save_buffer_size)
+ {
+ png_size_t save_size = png_ptr->save_buffer_size;
+ png_uint_32 idat_size = png_ptr->idat_size;
+
+ /* We want the smaller of 'idat_size' and 'current_buffer_size', but they
+ * are of different types and we don't know which variable has the fewest
+ * bits. Carefully select the smaller and cast it to the type of the
+ * larger - this cannot overflow. Do not cast in the following test - it
+ * will break on either 16 or 64 bit platforms.
+ */
+ if (idat_size < save_size)
+ save_size = (png_size_t)idat_size;
+
+ else
+ idat_size = (png_uint_32)save_size;
+
+ png_calculate_crc(png_ptr, png_ptr->save_buffer_ptr, save_size);
+
+ png_process_IDAT_data(png_ptr, png_ptr->save_buffer_ptr, save_size);
+
+ png_ptr->idat_size -= idat_size;
+ png_ptr->buffer_size -= save_size;
+ png_ptr->save_buffer_size -= save_size;
+ png_ptr->save_buffer_ptr += save_size;
+ }
+
+ if (png_ptr->idat_size && png_ptr->current_buffer_size)
+ {
+ png_size_t save_size = png_ptr->current_buffer_size;
+ png_uint_32 idat_size = png_ptr->idat_size;
+
+ /* We want the smaller of 'idat_size' and 'current_buffer_size', but they
+ * are of different types and we don't know which variable has the fewest
+ * bits. Carefully select the smaller and cast it to the type of the
+ * larger - this cannot overflow.
+ */
+ if (idat_size < save_size)
+ save_size = (png_size_t)idat_size;
+
+ else
+ idat_size = (png_uint_32)save_size;
+
+ png_calculate_crc(png_ptr, png_ptr->current_buffer_ptr, save_size);
+
+ png_process_IDAT_data(png_ptr, png_ptr->current_buffer_ptr, save_size);
+
+ png_ptr->idat_size -= idat_size;
+ png_ptr->buffer_size -= save_size;
+ png_ptr->current_buffer_size -= save_size;
+ png_ptr->current_buffer_ptr += save_size;
+ }
+ if (!png_ptr->idat_size)
+ {
+ if (png_ptr->buffer_size < 4)
+ {
+ png_push_save_buffer(png_ptr);
+ return;
+ }
+
+ png_crc_finish(png_ptr, 0);
+ png_ptr->mode &= ~PNG_HAVE_CHUNK_HEADER;
+ png_ptr->mode |= PNG_AFTER_IDAT;
+ png_ptr->zowner = 0;
+ }
+}
+
+void /* PRIVATE */
+png_process_IDAT_data(png_structrp png_ptr, png_bytep buffer,
+ png_size_t buffer_length)
+{
+ /* The caller checks for a non-zero buffer length. */
+ if (!(buffer_length > 0) || buffer == NULL)
+ png_error(png_ptr, "No IDAT data (internal error)");
+
+ /* This routine must process all the data it has been given
+ * before returning, calling the row callback as required to
+ * handle the uncompressed results.
+ */
+ png_ptr->zstream.next_in = buffer;
+ /* TODO: WARNING: TRUNCATION ERROR: DANGER WILL ROBINSON: */
+ png_ptr->zstream.avail_in = (uInt)buffer_length;
+
+ /* Keep going until the decompressed data is all processed
+ * or the stream marked as finished.
+ */
+ while (png_ptr->zstream.avail_in > 0 &&
+ !(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED))
+ {
+ int ret;
+
+ /* We have data for zlib, but we must check that zlib
+ * has someplace to put the results. It doesn't matter
+ * if we don't expect any results -- it may be the input
+ * data is just the LZ end code.
+ */
+ if (!(png_ptr->zstream.avail_out > 0))
+ {
+ /* TODO: WARNING: TRUNCATION ERROR: DANGER WILL ROBINSON: */
+ png_ptr->zstream.avail_out = (uInt)(PNG_ROWBYTES(png_ptr->pixel_depth,
+ png_ptr->iwidth) + 1);
+
+ png_ptr->zstream.next_out = png_ptr->row_buf;
+ }
+
+ /* Using Z_SYNC_FLUSH here means that an unterminated
+ * LZ stream (a stream with a missing end code) can still
+ * be handled, otherwise (Z_NO_FLUSH) a future zlib
+ * implementation might defer output and therefore
+ * change the current behavior (see comments in inflate.c
+ * for why this doesn't happen at present with zlib 1.2.5).
+ */
+ ret = inflate(&png_ptr->zstream, Z_SYNC_FLUSH);
+
+ /* Check for any failure before proceeding. */
+ if (ret != Z_OK && ret != Z_STREAM_END)
+ {
+ /* Terminate the decompression. */
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED;
+ png_ptr->zowner = 0;
+
+ /* This may be a truncated stream (missing or
+ * damaged end code). Treat that as a warning.
+ */
+ if (png_ptr->row_number >= png_ptr->num_rows ||
+ png_ptr->pass > 6)
+ png_warning(png_ptr, "Truncated compressed data in IDAT");
+
+ else
+ png_error(png_ptr, "Decompression error in IDAT");
+
+ /* Skip the check on unprocessed input */
+ return;
+ }
+
+ /* Did inflate output any data? */
+ if (png_ptr->zstream.next_out != png_ptr->row_buf)
+ {
+ /* Is this unexpected data after the last row?
+ * If it is, artificially terminate the LZ output
+ * here.
+ */
+ if (png_ptr->row_number >= png_ptr->num_rows ||
+ png_ptr->pass > 6)
+ {
+ /* Extra data. */
+ png_warning(png_ptr, "Extra compressed data in IDAT");
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED;
+ png_ptr->zowner = 0;
+
+ /* Do no more processing; skip the unprocessed
+ * input check below.
+ */
+ return;
+ }
+
+ /* Do we have a complete row? */
+ if (png_ptr->zstream.avail_out == 0)
+ png_push_process_row(png_ptr);
+ }
+
+ /* And check for the end of the stream. */
+ if (ret == Z_STREAM_END)
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED;
+ }
+
+ /* All the data should have been processed, if anything
+ * is left at this point we have bytes of IDAT data
+ * after the zlib end code.
+ */
+ if (png_ptr->zstream.avail_in > 0)
+ png_warning(png_ptr, "Extra compression data in IDAT");
+}
+
+void /* PRIVATE */
+png_push_process_row(png_structrp png_ptr)
+{
+ /* 1.5.6: row_info moved out of png_struct to a local here. */
+ png_row_info row_info;
+
+ row_info.width = png_ptr->iwidth; /* NOTE: width of current interlaced row */
+ row_info.color_type = png_ptr->color_type;
+ row_info.bit_depth = png_ptr->bit_depth;
+ row_info.channels = png_ptr->channels;
+ row_info.pixel_depth = png_ptr->pixel_depth;
+ row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width);
+
+ if (png_ptr->row_buf[0] > PNG_FILTER_VALUE_NONE)
+ {
+ if (png_ptr->row_buf[0] < PNG_FILTER_VALUE_LAST)
+ png_read_filter_row(png_ptr, &row_info, png_ptr->row_buf + 1,
+ png_ptr->prev_row + 1, png_ptr->row_buf[0]);
+ else
+ png_error(png_ptr, "bad adaptive filter value");
+ }
+
+ /* libpng 1.5.6: the following line was copying png_ptr->rowbytes before
+ * 1.5.6, while the buffer really is this big in current versions of libpng
+ * it may not be in the future, so this was changed just to copy the
+ * interlaced row count:
+ */
+ memcpy(png_ptr->prev_row, png_ptr->row_buf, row_info.rowbytes + 1);
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+ if (png_ptr->transformations)
+ png_do_read_transformations(png_ptr, &row_info);
+#endif
+
+ /* The transformed pixel depth should match the depth now in row_info. */
+ if (png_ptr->transformed_pixel_depth == 0)
+ {
+ png_ptr->transformed_pixel_depth = row_info.pixel_depth;
+ if (row_info.pixel_depth > png_ptr->maximum_pixel_depth)
+ png_error(png_ptr, "progressive row overflow");
+ }
+
+ else if (png_ptr->transformed_pixel_depth != row_info.pixel_depth)
+ png_error(png_ptr, "internal progressive row size calculation error");
+
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* Blow up interlaced rows to full size */
+ if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE))
+ {
+ if (png_ptr->pass < 6)
+ png_do_read_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass,
+ png_ptr->transformations);
+
+ switch (png_ptr->pass)
+ {
+ case 0:
+ {
+ int i;
+ for (i = 0; i < 8 && png_ptr->pass == 0; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr); /* Updates png_ptr->pass */
+ }
+
+ if (png_ptr->pass == 2) /* Pass 1 might be empty */
+ {
+ for (i = 0; i < 4 && png_ptr->pass == 2; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+
+ if (png_ptr->pass == 4 && png_ptr->height <= 4)
+ {
+ for (i = 0; i < 2 && png_ptr->pass == 4; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+
+ if (png_ptr->pass == 6 && png_ptr->height <= 4)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ break;
+ }
+
+ case 1:
+ {
+ int i;
+ for (i = 0; i < 8 && png_ptr->pass == 1; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ if (png_ptr->pass == 2) /* Skip top 4 generated rows */
+ {
+ for (i = 0; i < 4 && png_ptr->pass == 2; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+
+ break;
+ }
+
+ case 2:
+ {
+ int i;
+
+ for (i = 0; i < 4 && png_ptr->pass == 2; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ for (i = 0; i < 4 && png_ptr->pass == 2; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ if (png_ptr->pass == 4) /* Pass 3 might be empty */
+ {
+ for (i = 0; i < 2 && png_ptr->pass == 4; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+
+ break;
+ }
+
+ case 3:
+ {
+ int i;
+
+ for (i = 0; i < 4 && png_ptr->pass == 3; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ if (png_ptr->pass == 4) /* Skip top two generated rows */
+ {
+ for (i = 0; i < 2 && png_ptr->pass == 4; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+
+ break;
+ }
+
+ case 4:
+ {
+ int i;
+
+ for (i = 0; i < 2 && png_ptr->pass == 4; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ for (i = 0; i < 2 && png_ptr->pass == 4; i++)
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ if (png_ptr->pass == 6) /* Pass 5 might be empty */
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ break;
+ }
+
+ case 5:
+ {
+ int i;
+
+ for (i = 0; i < 2 && png_ptr->pass == 5; i++)
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ if (png_ptr->pass == 6) /* Skip top generated row */
+ {
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+
+ break;
+ }
+
+ default:
+ case 6:
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+
+ if (png_ptr->pass != 6)
+ break;
+
+ png_push_have_row(png_ptr, NULL);
+ png_read_push_finish_row(png_ptr);
+ }
+ }
+ }
+ else
+#endif
+ {
+ png_push_have_row(png_ptr, png_ptr->row_buf + 1);
+ png_read_push_finish_row(png_ptr);
+ }
+}
+
+void /* PRIVATE */
+png_read_push_finish_row(png_structrp png_ptr)
+{
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[] = {8, 8, 4, 4, 2, 2, 1};
+
+ /* Start of interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_ystart[] = {0, 0, 4, 0, 2, 0, 1};
+
+ /* Offset to next interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_yinc[] = {8, 8, 8, 4, 4, 2, 2};
+
+ /* Height of interlace block. This is not currently used - if you need
+ * it, uncomment it here and in png.h
+ static PNG_CONST png_byte png_pass_height[] = {8, 8, 4, 4, 2, 2, 1};
+ */
+#endif
+
+ png_ptr->row_number++;
+ if (png_ptr->row_number < png_ptr->num_rows)
+ return;
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ if (png_ptr->interlaced)
+ {
+ png_ptr->row_number = 0;
+ memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1);
+
+ do
+ {
+ png_ptr->pass++;
+ if ((png_ptr->pass == 1 && png_ptr->width < 5) ||
+ (png_ptr->pass == 3 && png_ptr->width < 3) ||
+ (png_ptr->pass == 5 && png_ptr->width < 2))
+ png_ptr->pass++;
+
+ if (png_ptr->pass > 7)
+ png_ptr->pass--;
+
+ if (png_ptr->pass >= 7)
+ break;
+
+ png_ptr->iwidth = (png_ptr->width +
+ png_pass_inc[png_ptr->pass] - 1 -
+ png_pass_start[png_ptr->pass]) /
+ png_pass_inc[png_ptr->pass];
+
+ if (png_ptr->transformations & PNG_INTERLACE)
+ break;
+
+ png_ptr->num_rows = (png_ptr->height +
+ png_pass_yinc[png_ptr->pass] - 1 -
+ png_pass_ystart[png_ptr->pass]) /
+ png_pass_yinc[png_ptr->pass];
+
+ } while (png_ptr->iwidth == 0 || png_ptr->num_rows == 0);
+ }
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+}
+
+void /* PRIVATE */
+png_push_have_info(png_structrp png_ptr, png_inforp info_ptr)
+{
+ if (png_ptr->info_fn != NULL)
+ (*(png_ptr->info_fn))(png_ptr, info_ptr);
+}
+
+void /* PRIVATE */
+png_push_have_end(png_structrp png_ptr, png_inforp info_ptr)
+{
+ if (png_ptr->end_fn != NULL)
+ (*(png_ptr->end_fn))(png_ptr, info_ptr);
+}
+
+void /* PRIVATE */
+png_push_have_row(png_structrp png_ptr, png_bytep row)
+{
+ if (png_ptr->row_fn != NULL)
+ (*(png_ptr->row_fn))(png_ptr, row, png_ptr->row_number,
+ (int)png_ptr->pass);
+}
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+void PNGAPI
+png_progressive_combine_row(png_const_structrp png_ptr, png_bytep old_row,
+ png_const_bytep new_row)
+{
+ if (png_ptr == NULL)
+ return;
+
+ /* new_row is a flag here - if it is NULL then the app callback was called
+ * from an empty row (see the calls to png_struct::row_fn below), otherwise
+ * it must be png_ptr->row_buf+1
+ */
+ if (new_row != NULL)
+ png_combine_row(png_ptr, old_row, 1/*display*/);
+}
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+
+void PNGAPI
+png_set_progressive_read_fn(png_structrp png_ptr, png_voidp progressive_ptr,
+ png_progressive_info_ptr info_fn, png_progressive_row_ptr row_fn,
+ png_progressive_end_ptr end_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->info_fn = info_fn;
+ png_ptr->row_fn = row_fn;
+ png_ptr->end_fn = end_fn;
+
+ png_set_read_fn(png_ptr, progressive_ptr, png_push_fill_buffer);
+}
+
+png_voidp PNGAPI
+png_get_progressive_ptr(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return (NULL);
+
+ return png_ptr->io_ptr;
+}
+#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngpriv.h b/ml/dlib/dlib/external/libpng/pngpriv.h
new file mode 100644
index 000000000..43b2ef4c5
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngpriv.h
@@ -0,0 +1,2047 @@
+
+/* pngpriv.h - private declarations for use inside libpng
+ *
+ * For conditions of distribution and use, see copyright notice in png.h
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * Last changed in libpng 1.6.7 [November 14, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+/* The symbols declared in this file (including the functions declared
+ * as extern) are PRIVATE. They are not part of the libpng public
+ * interface, and are not recommended for use by regular applications.
+ * Some of them may become public in the future; others may stay private,
+ * change in an incompatible way, or even disappear.
+ * Although the libpng users are not forbidden to include this header,
+ * they should be well aware of the issues that may arise from doing so.
+ */
+
+#ifndef PNGPRIV_H
+#define PNGPRIV_H
+
+/* Feature Test Macros. The following are defined here to ensure that correctly
+ * implemented libraries reveal the APIs libpng needs to build and hide those
+ * that are not needed and potentially damaging to the compilation.
+ *
+ * Feature Test Macros must be defined before any system header is included (see
+ * POSIX 1003.1 2.8.2 "POSIX Symbols."
+ *
+ * These macros only have an effect if the operating system supports either
+ * POSIX 1003.1 or C99, or both. On other operating systems (particularly
+ * Windows/Visual Studio) there is no effect; the OS specific tests below are
+ * still required (as of 2011-05-02.)
+ */
+#define _POSIX_SOURCE 1 /* Just the POSIX 1003.1 and C89 APIs */
+
+#ifndef PNG_VERSION_INFO_ONLY
+/* Standard library headers not required by png.h: */
+# include <stdlib.h>
+# include <string.h>
+#endif
+
+#define PNGLIB_BUILD /*libpng is being built, not used*/
+
+/* If HAVE_CONFIG_H is defined during the build then the build system must
+ * provide an appropriate "config.h" file on the include path. The header file
+ * must provide definitions as required below (search for "HAVE_CONFIG_H");
+ * see configure.ac for more details of the requirements. The macro
+ * "PNG_NO_CONFIG_H" is provided for maintainers to test for dependencies on
+ * 'configure'; define this macro to prevent the configure build including the
+ * configure generated config.h. Libpng is expected to compile without *any*
+ * special build system support on a reasonably ANSI-C compliant system.
+ */
+#if defined(HAVE_CONFIG_H) && !defined(PNG_NO_CONFIG_H)
+# include <config.h>
+
+ /* Pick up the definition of 'restrict' from config.h if it was read: */
+# define PNG_RESTRICT restrict
+#endif
+
+/* To support symbol prefixing it is necessary to know *before* including png.h
+ * whether the fixed point (and maybe other) APIs are exported, because if they
+ * are not internal definitions may be required. This is handled below just
+ * before png.h is included, but load the configuration now if it is available.
+ */
+#ifndef PNGLCONF_H
+# include "pnglibconf.h"
+#endif
+
+/* Local renames may change non-exported API functions from png.h */
+#if defined(PNG_PREFIX) && !defined(PNGPREFIX_H)
+# include "pngprefix.h"
+#endif
+
+#ifdef PNG_USER_CONFIG
+# include "pngusr.h"
+ /* These should have been defined in pngusr.h */
+# ifndef PNG_USER_PRIVATEBUILD
+# define PNG_USER_PRIVATEBUILD "Custom libpng build"
+# endif
+# ifndef PNG_USER_DLLFNAME_POSTFIX
+# define PNG_USER_DLLFNAME_POSTFIX "Cb"
+# endif
+#endif
+
+/* Compile time options.
+ * =====================
+ * In a multi-arch build the compiler may compile the code several times for the
+ * same object module, producing different binaries for different architectures.
+ * When this happens configure-time setting of the target host options cannot be
+ * done and this interferes with the handling of the ARM NEON optimizations, and
+ * possibly other similar optimizations. Put additional tests here; in general
+ * this is needed when the same option can be changed at both compile time and
+ * run time depending on the target OS (i.e. iOS vs Android.)
+ *
+ * NOTE: symbol prefixing does not pass $(CFLAGS) to the preprocessor, because
+ * this is not possible with certain compilers (Oracle SUN OS CC), as a result
+ * it is necessary to ensure that all extern functions that *might* be used
+ * regardless of $(CFLAGS) get declared in this file. The test on __ARM_NEON__
+ * below is one example of this behavior because it is controlled by the
+ * presence or not of -mfpu=neon on the GCC command line, it is possible to do
+ * this in $(CC), e.g. "CC=gcc -mfpu=neon", but people who build libpng rarely
+ * do this.
+ */
+#ifndef PNG_ARM_NEON_OPT
+ /* ARM NEON optimizations are being controlled by the compiler settings,
+ * typically the target FPU. If the FPU has been set to NEON (-mfpu=neon
+ * with GCC) then the compiler will define __ARM_NEON__ and we can rely
+ * unconditionally on NEON instructions not crashing, otherwise we must
+ * disable use of NEON instructions:
+ */
+# ifdef __ARM_NEON__
+# define PNG_ARM_NEON_OPT 2
+# else
+# define PNG_ARM_NEON_OPT 0
+# endif
+#endif
+
+#if PNG_ARM_NEON_OPT > 0
+ /* NEON optimizations are to be at least considered by libpng, so enable the
+ * callbacks to do this.
+ */
+# define PNG_FILTER_OPTIMIZATIONS png_init_filter_functions_neon
+
+ /* By default the 'intrinsics' code in arm/filter_neon_intrinsics.c is used
+ * if possible - if __ARM_NEON__ is set and the compiler version is not known
+ * to be broken. This is control by PNG_ARM_NEON_IMPLEMENTATION which can
+ * be:
+ *
+ * 1 The intrinsics code (the default with __ARM_NEON__)
+ * 2 The hand coded assembler (the default without __ARM_NEON__)
+ *
+ * It is possible to set PNG_ARM_NEON_IMPLEMENTATION in CPPFLAGS, however
+ * this is *NOT* supported and may cease to work even after a minor revision
+ * to libpng. It *is* valid to do this for testing purposes, e.g. speed
+ * testing or a new compiler, but the results should be communicated to the
+ * libpng implementation list for incorporation in the next minor release.
+ */
+# ifndef PNG_ARM_NEON_IMPLEMENTATION
+# ifdef __ARM_NEON__
+# if defined(__clang__)
+ /* At present it is unknown by the libpng developers which versions
+ * of clang support the intrinsics, however some or perhaps all
+ * versions do not work with the assembler so this may be
+ * irrelevant, so just use the default (do nothing here.)
+ */
+# elif defined(__GNUC__)
+ /* GCC 4.5.4 NEON support is known to be broken. 4.6.3 is known to
+ * work, so if this *is* GCC, or G++, look for a version >4.5
+ */
+# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 6)
+# define PNG_ARM_NEON_IMPLEMENTATION 2
+# endif /* no GNUC support */
+# endif /* __GNUC__ */
+# else /* !defined __ARM_NEON__ */
+ /* The 'intrinsics' code simply won't compile without this -mfpu=neon:
+ */
+# define PNG_ARM_NEON_IMPLEMENTATION 2
+# endif /* __ARM_NEON__ */
+# endif /* !defined PNG_ARM_NEON_IMPLEMENTATION */
+
+# ifndef PNG_ARM_NEON_IMPLEMENTATION
+ /* Use the intrinsics code by default. */
+# define PNG_ARM_NEON_IMPLEMENTATION 1
+# endif
+#endif /* PNG_ARM_NEON_OPT > 0 */
+
+/* Is this a build of a DLL where compilation of the object modules requires
+ * different preprocessor settings to those required for a simple library? If
+ * so PNG_BUILD_DLL must be set.
+ *
+ * If libpng is used inside a DLL but that DLL does not export the libpng APIs
+ * PNG_BUILD_DLL must not be set. To avoid the code below kicking in build a
+ * static library of libpng then link the DLL against that.
+ */
+#ifndef PNG_BUILD_DLL
+# ifdef DLL_EXPORT
+ /* This is set by libtool when files are compiled for a DLL; libtool
+ * always compiles twice, even on systems where it isn't necessary. Set
+ * PNG_BUILD_DLL in case it is necessary:
+ */
+# define PNG_BUILD_DLL
+# else
+# ifdef _WINDLL
+ /* This is set by the Microsoft Visual Studio IDE in projects that
+ * build a DLL. It can't easily be removed from those projects (it
+ * isn't visible in the Visual Studio UI) so it is a fairly reliable
+ * indication that PNG_IMPEXP needs to be set to the DLL export
+ * attributes.
+ */
+# define PNG_BUILD_DLL
+# else
+# ifdef __DLL__
+ /* This is set by the Borland C system when compiling for a DLL
+ * (as above.)
+ */
+# define PNG_BUILD_DLL
+# else
+ /* Add additional compiler cases here. */
+# endif
+# endif
+# endif
+#endif /* Setting PNG_BUILD_DLL if required */
+
+/* See pngconf.h for more details: the builder of the library may set this on
+ * the command line to the right thing for the specific compilation system or it
+ * may be automagically set above (at present we know of no system where it does
+ * need to be set on the command line.)
+ *
+ * PNG_IMPEXP must be set here when building the library to prevent pngconf.h
+ * setting it to the "import" setting for a DLL build.
+ */
+#ifndef PNG_IMPEXP
+# ifdef PNG_BUILD_DLL
+# define PNG_IMPEXP PNG_DLL_EXPORT
+# else
+ /* Not building a DLL, or the DLL doesn't require specific export
+ * definitions.
+ */
+# define PNG_IMPEXP
+# endif
+#endif
+
+/* No warnings for private or deprecated functions in the build: */
+#ifndef PNG_DEPRECATED
+# define PNG_DEPRECATED
+#endif
+#ifndef PNG_PRIVATE
+# define PNG_PRIVATE
+#endif
+
+/* Symbol preprocessing support.
+ *
+ * To enable listing global, but internal, symbols the following macros should
+ * always be used to declare an extern data or function object in this file.
+ */
+#ifndef PNG_INTERNAL_DATA
+# define PNG_INTERNAL_DATA(type, name, array) extern type name array
+#endif
+
+#ifndef PNG_INTERNAL_FUNCTION
+# define PNG_INTERNAL_FUNCTION(type, name, args, attributes)\
+ extern PNG_FUNCTION(type, name, args, PNG_EMPTY attributes)
+#endif
+
+/* If floating or fixed point APIs are disabled they may still be compiled
+ * internally. To handle this make sure they are declared as the appropriate
+ * internal extern function (otherwise the symbol prefixing stuff won't work and
+ * the functions will be used without definitions.)
+ *
+ * NOTE: although all the API functions are declared here they are not all
+ * actually built! Because the declarations are still made it is necessary to
+ * fake out types that they depend on.
+ */
+#ifndef PNG_FP_EXPORT
+# ifndef PNG_FLOATING_POINT_SUPPORTED
+# define PNG_FP_EXPORT(ordinal, type, name, args)\
+ PNG_INTERNAL_FUNCTION(type, name, args, PNG_EMPTY);
+# ifndef PNG_VERSION_INFO_ONLY
+ typedef struct png_incomplete png_double;
+ typedef png_double* png_doublep;
+ typedef const png_double* png_const_doublep;
+ typedef png_double** png_doublepp;
+# endif
+# endif
+#endif
+#ifndef PNG_FIXED_EXPORT
+# ifndef PNG_FIXED_POINT_SUPPORTED
+# define PNG_FIXED_EXPORT(ordinal, type, name, args)\
+ PNG_INTERNAL_FUNCTION(type, name, args, PNG_EMPTY);
+# endif
+#endif
+
+#include "png.h"
+
+/* pngconf.h does not set PNG_DLL_EXPORT unless it is required, so: */
+#ifndef PNG_DLL_EXPORT
+# define PNG_DLL_EXPORT
+#endif
+
+/* SECURITY and SAFETY:
+ *
+ * By default libpng is built without any internal limits on image size,
+ * individual heap (png_malloc) allocations or the total amount of memory used.
+ * If PNG_SAFE_LIMITS_SUPPORTED is defined, however, the limits below are used
+ * (unless individually overridden). These limits are believed to be fairly
+ * safe, but builders of secure systems should verify the values against the
+ * real system capabilities.
+ */
+#ifdef PNG_SAFE_LIMITS_SUPPORTED
+ /* 'safe' limits */
+# ifndef PNG_USER_WIDTH_MAX
+# define PNG_USER_WIDTH_MAX 1000000
+# endif
+# ifndef PNG_USER_HEIGHT_MAX
+# define PNG_USER_HEIGHT_MAX 1000000
+# endif
+# ifndef PNG_USER_CHUNK_CACHE_MAX
+# define PNG_USER_CHUNK_CACHE_MAX 128
+# endif
+# ifndef PNG_USER_CHUNK_MALLOC_MAX
+# define PNG_USER_CHUNK_MALLOC_MAX 8000000
+# endif
+#else
+ /* values for no limits */
+# ifndef PNG_USER_WIDTH_MAX
+# define PNG_USER_WIDTH_MAX 0x7fffffff
+# endif
+# ifndef PNG_USER_HEIGHT_MAX
+# define PNG_USER_HEIGHT_MAX 0x7fffffff
+# endif
+# ifndef PNG_USER_CHUNK_CACHE_MAX
+# define PNG_USER_CHUNK_CACHE_MAX 0
+# endif
+# ifndef PNG_USER_CHUNK_MALLOC_MAX
+# define PNG_USER_CHUNK_MALLOC_MAX 0
+# endif
+#endif
+
+/* Moved to pngpriv.h at libpng-1.5.0 */
+/* NOTE: some of these may have been used in external applications as
+ * these definitions were exposed in pngconf.h prior to 1.5.
+ */
+
+/* If you are running on a machine where you cannot allocate more
+ * than 64K of memory at once, uncomment this. While libpng will not
+ * normally need that much memory in a chunk (unless you load up a very
+ * large file), zlib needs to know how big of a chunk it can use, and
+ * libpng thus makes sure to check any memory allocation to verify it
+ * will fit into memory.
+ *
+ * zlib provides 'MAXSEG_64K' which, if defined, indicates the
+ * same limit and pngconf.h (already included) sets the limit
+ * if certain operating systems are detected.
+ */
+#if defined(MAXSEG_64K) && !defined(PNG_MAX_MALLOC_64K)
+# define PNG_MAX_MALLOC_64K
+#endif
+
+#ifndef PNG_UNUSED
+/* Unused formal parameter warnings are silenced using the following macro
+ * which is expected to have no bad effects on performance (optimizing
+ * compilers will probably remove it entirely). Note that if you replace
+ * it with something other than whitespace, you must include the terminating
+ * semicolon.
+ */
+# define PNG_UNUSED(param) (void)param;
+#endif
+
+/* Just a little check that someone hasn't tried to define something
+ * contradictory.
+ */
+#if (PNG_ZBUF_SIZE > 65536L) && defined(PNG_MAX_MALLOC_64K)
+# undef PNG_ZBUF_SIZE
+# define PNG_ZBUF_SIZE 65536L
+#endif
+
+/* If warnings or errors are turned off the code is disabled or redirected here.
+ * From 1.5.4 functions have been added to allow very limited formatting of
+ * error and warning messages - this code will also be disabled here.
+ */
+#ifdef PNG_WARNINGS_SUPPORTED
+# define PNG_WARNING_PARAMETERS(p) png_warning_parameters p;
+#else
+# define png_warning(s1,s2) ((void)(s1))
+# define png_chunk_warning(s1,s2) ((void)(s1))
+# define png_warning_parameter(p,number,string) ((void)0)
+# define png_warning_parameter_unsigned(p,number,format,value) ((void)0)
+# define png_warning_parameter_signed(p,number,format,value) ((void)0)
+# define png_formatted_warning(pp,p,message) ((void)(pp))
+# define PNG_WARNING_PARAMETERS(p)
+#endif
+#ifndef PNG_ERROR_TEXT_SUPPORTED
+# define png_error(s1,s2) png_err(s1)
+# define png_chunk_error(s1,s2) png_err(s1)
+# define png_fixed_error(s1,s2) png_err(s1)
+#endif
+
+/* C allows up-casts from (void*) to any pointer and (const void*) to any
+ * pointer to a const object. C++ regards this as a type error and requires an
+ * explicit, static, cast and provides the static_cast<> rune to ensure that
+ * const is not cast away.
+ */
+#ifdef __cplusplus
+# define png_voidcast(type, value) static_cast<type>(value)
+# define png_constcast(type, value) const_cast<type>(value)
+# define png_aligncast(type, value) \
+ static_cast<type>(static_cast<void*>(value))
+# define png_aligncastconst(type, value) \
+ static_cast<type>(static_cast<const void*>(value))
+#else
+# define png_voidcast(type, value) (value)
+# define png_constcast(type, value) ((type)(value))
+# define png_aligncast(type, value) ((void*)(value))
+# define png_aligncastconst(type, value) ((const void*)(value))
+#endif /* __cplusplus */
+
+/* Some fixed point APIs are still required even if not exported because
+ * they get used by the corresponding floating point APIs. This magic
+ * deals with this:
+ */
+#ifdef PNG_FIXED_POINT_SUPPORTED
+# define PNGFAPI PNGAPI
+#else
+# define PNGFAPI /* PRIVATE */
+#endif
+
+#ifndef PNG_VERSION_INFO_ONLY
+/* Other defines specific to compilers can go here. Try to keep
+ * them inside an appropriate ifdef/endif pair for portability.
+ */
+#if defined(PNG_FLOATING_POINT_SUPPORTED) ||\
+ defined(PNG_FLOATING_ARITHMETIC_SUPPORTED)
+ /* png.c requires the following ANSI-C constants if the conversion of
+ * floating point to ASCII is implemented therein:
+ *
+ * DBL_DIG Maximum number of decimal digits (can be set to any constant)
+ * DBL_MIN Smallest normalized fp number (can be set to an arbitrary value)
+ * DBL_MAX Maximum floating point number (can be set to an arbitrary value)
+ */
+# include <float.h>
+
+# if (defined(__MWERKS__) && defined(macintosh)) || defined(applec) || \
+ defined(THINK_C) || defined(__SC__) || defined(TARGET_OS_MAC)
+ /* We need to check that <math.h> hasn't already been included earlier
+ * as it seems it doesn't agree with <fp.h>, yet we should really use
+ * <fp.h> if possible.
+ */
+# if !defined(__MATH_H__) && !defined(__MATH_H) && !defined(__cmath__)
+# include <fp.h>
+# endif
+# else
+# include <math.h>
+# endif
+# if defined(_AMIGA) && defined(__SASC) && defined(_M68881)
+ /* Amiga SAS/C: We must include builtin FPU functions when compiling using
+ * MATH=68881
+ */
+# include <m68881.h>
+# endif
+#endif
+
+/* This provides the non-ANSI (far) memory allocation routines. */
+#if defined(__TURBOC__) && defined(__MSDOS__)
+# include <mem.h>
+# include <alloc.h>
+#endif
+
+#if defined(WIN32) || defined(_Windows) || defined(_WINDOWS) || \
+ defined(_WIN32) || defined(__WIN32__)
+# include <windows.h> /* defines _WINDOWS_ macro */
+#endif
+#endif /* PNG_VERSION_INFO_ONLY */
+
+/* Moved here around 1.5.0beta36 from pngconf.h */
+/* Users may want to use these so they are not private. Any library
+ * functions that are passed far data must be model-independent.
+ */
+
+/* Memory model/platform independent fns */
+#ifndef PNG_ABORT
+# ifdef _WINDOWS_
+# define PNG_ABORT() ExitProcess(0)
+# else
+# define PNG_ABORT() abort()
+# endif
+#endif
+
+/* These macros may need to be architecture dependent. */
+#define PNG_ALIGN_NONE 0 /* do not use data alignment */
+#define PNG_ALIGN_ALWAYS 1 /* assume unaligned accesses are OK */
+#ifdef offsetof
+# define PNG_ALIGN_OFFSET 2 /* use offsetof to determine alignment */
+#else
+# define PNG_ALIGN_OFFSET -1 /* prevent the use of this */
+#endif
+#define PNG_ALIGN_SIZE 3 /* use sizeof to determine alignment */
+
+#ifndef PNG_ALIGN_TYPE
+ /* Default to using aligned access optimizations and requiring alignment to a
+ * multiple of the data type size. Override in a compiler specific fashion
+ * if necessary by inserting tests here:
+ */
+# define PNG_ALIGN_TYPE PNG_ALIGN_SIZE
+#endif
+
+#if PNG_ALIGN_TYPE == PNG_ALIGN_SIZE
+ /* This is used because in some compiler implementations non-aligned
+ * structure members are supported, so the offsetof approach below fails.
+ * Set PNG_ALIGN_SIZE=0 for compiler combinations where unaligned access
+ * is good for performance. Do not do this unless you have tested the result
+ * and understand it.
+ */
+# define png_alignof(type) (sizeof (type))
+#else
+# if PNG_ALIGN_TYPE == PNG_ALIGN_OFFSET
+# define png_alignof(type) offsetof(struct{char c; type t;}, t)
+# else
+# if PNG_ALIGN_TYPE == PNG_ALIGN_ALWAYS
+# define png_alignof(type) (1)
+# endif
+ /* Else leave png_alignof undefined to prevent use thereof */
+# endif
+#endif
+
+/* This implicitly assumes alignment is always to a power of 2. */
+#ifdef png_alignof
+# define png_isaligned(ptr, type)\
+ ((((const char*)ptr-(const char*)0) & (png_alignof(type)-1)) == 0)
+#else
+# define png_isaligned(ptr, type) 0
+#endif
+
+/* End of memory model/platform independent support */
+/* End of 1.5.0beta36 move from pngconf.h */
+
+/* CONSTANTS and UTILITY MACROS
+ * These are used internally by libpng and not exposed in the API
+ */
+
+/* Various modes of operation. Note that after an init, mode is set to
+ * zero automatically when the structure is created. Three of these
+ * are defined in png.h because they need to be visible to applications
+ * that call png_set_unknown_chunk().
+ */
+/* #define PNG_HAVE_IHDR 0x01 (defined in png.h) */
+/* #define PNG_HAVE_PLTE 0x02 (defined in png.h) */
+#define PNG_HAVE_IDAT 0x04
+/* #define PNG_AFTER_IDAT 0x08 (defined in png.h) */
+#define PNG_HAVE_IEND 0x10
+ /* 0x20 (unused) */
+ /* 0x40 (unused) */
+ /* 0x80 (unused) */
+#define PNG_HAVE_CHUNK_HEADER 0x100
+#define PNG_WROTE_tIME 0x200
+#define PNG_WROTE_INFO_BEFORE_PLTE 0x400
+#define PNG_BACKGROUND_IS_GRAY 0x800
+#define PNG_HAVE_PNG_SIGNATURE 0x1000
+#define PNG_HAVE_CHUNK_AFTER_IDAT 0x2000 /* Have another chunk after IDAT */
+ /* 0x4000 (unused) */
+#define PNG_IS_READ_STRUCT 0x8000 /* Else is a write struct */
+
+/* Flags for the transformations the PNG library does on the image data */
+#define PNG_BGR 0x0001
+#define PNG_INTERLACE 0x0002
+#define PNG_PACK 0x0004
+#define PNG_SHIFT 0x0008
+#define PNG_SWAP_BYTES 0x0010
+#define PNG_INVERT_MONO 0x0020
+#define PNG_QUANTIZE 0x0040
+#define PNG_COMPOSE 0x0080 /* Was PNG_BACKGROUND */
+#define PNG_BACKGROUND_EXPAND 0x0100
+#define PNG_EXPAND_16 0x0200 /* Added to libpng 1.5.2 */
+#define PNG_16_TO_8 0x0400 /* Becomes 'chop' in 1.5.4 */
+#define PNG_RGBA 0x0800
+#define PNG_EXPAND 0x1000
+#define PNG_GAMMA 0x2000
+#define PNG_GRAY_TO_RGB 0x4000
+#define PNG_FILLER 0x8000
+#define PNG_PACKSWAP 0x10000
+#define PNG_SWAP_ALPHA 0x20000
+#define PNG_STRIP_ALPHA 0x40000
+#define PNG_INVERT_ALPHA 0x80000
+#define PNG_USER_TRANSFORM 0x100000
+#define PNG_RGB_TO_GRAY_ERR 0x200000
+#define PNG_RGB_TO_GRAY_WARN 0x400000
+#define PNG_RGB_TO_GRAY 0x600000 /* two bits, RGB_TO_GRAY_ERR|WARN */
+#define PNG_ENCODE_ALPHA 0x800000 /* Added to libpng-1.5.4 */
+#define PNG_ADD_ALPHA 0x1000000 /* Added to libpng-1.2.7 */
+#define PNG_EXPAND_tRNS 0x2000000 /* Added to libpng-1.2.9 */
+#define PNG_SCALE_16_TO_8 0x4000000 /* Added to libpng-1.5.4 */
+ /* 0x8000000 unused */
+ /* 0x10000000 unused */
+ /* 0x20000000 unused */
+ /* 0x40000000 unused */
+/* Flags for png_create_struct */
+#define PNG_STRUCT_PNG 0x0001
+#define PNG_STRUCT_INFO 0x0002
+
+/* Scaling factor for filter heuristic weighting calculations */
+#define PNG_WEIGHT_FACTOR (1<<(PNG_WEIGHT_SHIFT))
+#define PNG_COST_FACTOR (1<<(PNG_COST_SHIFT))
+
+/* Flags for the png_ptr->flags rather than declaring a byte for each one */
+#define PNG_FLAG_ZLIB_CUSTOM_STRATEGY 0x0001
+#define PNG_FLAG_ZSTREAM_INITIALIZED 0x0002 /* Added to libpng-1.6.0 */
+ /* 0x0004 unused */
+#define PNG_FLAG_ZSTREAM_ENDED 0x0008 /* Added to libpng-1.6.0 */
+ /* 0x0010 unused */
+ /* 0x0020 unused */
+#define PNG_FLAG_ROW_INIT 0x0040
+#define PNG_FLAG_FILLER_AFTER 0x0080
+#define PNG_FLAG_CRC_ANCILLARY_USE 0x0100
+#define PNG_FLAG_CRC_ANCILLARY_NOWARN 0x0200
+#define PNG_FLAG_CRC_CRITICAL_USE 0x0400
+#define PNG_FLAG_CRC_CRITICAL_IGNORE 0x0800
+#define PNG_FLAG_ASSUME_sRGB 0x1000 /* Added to libpng-1.5.4 */
+#define PNG_FLAG_OPTIMIZE_ALPHA 0x2000 /* Added to libpng-1.5.4 */
+#define PNG_FLAG_DETECT_UNINITIALIZED 0x4000 /* Added to libpng-1.5.4 */
+/* #define PNG_FLAG_KEEP_UNKNOWN_CHUNKS 0x8000 */
+/* #define PNG_FLAG_KEEP_UNSAFE_CHUNKS 0x10000 */
+#define PNG_FLAG_LIBRARY_MISMATCH 0x20000
+#define PNG_FLAG_STRIP_ERROR_NUMBERS 0x40000
+#define PNG_FLAG_STRIP_ERROR_TEXT 0x80000
+#define PNG_FLAG_BENIGN_ERRORS_WARN 0x100000 /* Added to libpng-1.4.0 */
+#define PNG_FLAG_APP_WARNINGS_WARN 0x200000 /* Added to libpng-1.6.0 */
+#define PNG_FLAG_APP_ERRORS_WARN 0x400000 /* Added to libpng-1.6.0 */
+ /* 0x800000 unused */
+ /* 0x1000000 unused */
+ /* 0x2000000 unused */
+ /* 0x4000000 unused */
+ /* 0x8000000 unused */
+ /* 0x10000000 unused */
+ /* 0x20000000 unused */
+ /* 0x40000000 unused */
+
+#define PNG_FLAG_CRC_ANCILLARY_MASK (PNG_FLAG_CRC_ANCILLARY_USE | \
+ PNG_FLAG_CRC_ANCILLARY_NOWARN)
+
+#define PNG_FLAG_CRC_CRITICAL_MASK (PNG_FLAG_CRC_CRITICAL_USE | \
+ PNG_FLAG_CRC_CRITICAL_IGNORE)
+
+#define PNG_FLAG_CRC_MASK (PNG_FLAG_CRC_ANCILLARY_MASK | \
+ PNG_FLAG_CRC_CRITICAL_MASK)
+
+/* Save typing and make code easier to understand */
+
+#define PNG_COLOR_DIST(c1, c2) (abs((int)((c1).red) - (int)((c2).red)) + \
+ abs((int)((c1).green) - (int)((c2).green)) + \
+ abs((int)((c1).blue) - (int)((c2).blue)))
+
+/* Added to libpng-1.6.0: scale a 16-bit value in the range 0..65535 to 0..255
+ * by dividing by 257 *with rounding*. This macro is exact for the given range.
+ * See the discourse in pngrtran.c png_do_scale_16_to_8. The values in the
+ * macro were established by experiment (modifying the added value). The macro
+ * has a second variant that takes a value already scaled by 255 and divides by
+ * 65535 - this has a maximum error of .502. Over the range 0..65535*65535 it
+ * only gives off-by-one errors and only for 0.5% (1 in 200) of the values.
+ */
+#define PNG_DIV65535(v24) (((v24) + 32895) >> 16)
+#define PNG_DIV257(v16) PNG_DIV65535((png_uint_32)(v16) * 255)
+
+/* Added to libpng-1.2.6 JB */
+#define PNG_ROWBYTES(pixel_bits, width) \
+ ((pixel_bits) >= 8 ? \
+ ((png_size_t)(width) * (((png_size_t)(pixel_bits)) >> 3)) : \
+ (( ((png_size_t)(width) * ((png_size_t)(pixel_bits))) + 7) >> 3) )
+
+/* PNG_OUT_OF_RANGE returns true if value is outside the range
+ * ideal-delta..ideal+delta. Each argument is evaluated twice.
+ * "ideal" and "delta" should be constants, normally simple
+ * integers, "value" a variable. Added to libpng-1.2.6 JB
+ */
+#define PNG_OUT_OF_RANGE(value, ideal, delta) \
+ ( (value) < (ideal)-(delta) || (value) > (ideal)+(delta) )
+
+/* Conversions between fixed and floating point, only defined if
+ * required (to make sure the code doesn't accidentally use float
+ * when it is supposedly disabled.)
+ */
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+/* The floating point conversion can't overflow, though it can and
+ * does lose accuracy relative to the original fixed point value.
+ * In practice this doesn't matter because png_fixed_point only
+ * stores numbers with very low precision. The png_ptr and s
+ * arguments are unused by default but are there in case error
+ * checking becomes a requirement.
+ */
+#define png_float(png_ptr, fixed, s) (.00001 * (fixed))
+
+/* The fixed point conversion performs range checking and evaluates
+ * its argument multiple times, so must be used with care. The
+ * range checking uses the PNG specification values for a signed
+ * 32 bit fixed point value except that the values are deliberately
+ * rounded-to-zero to an integral value - 21474 (21474.83 is roughly
+ * (2^31-1) * 100000). 's' is a string that describes the value being
+ * converted.
+ *
+ * NOTE: this macro will raise a png_error if the range check fails,
+ * therefore it is normally only appropriate to use this on values
+ * that come from API calls or other sources where an out of range
+ * error indicates a programming error, not a data error!
+ *
+ * NOTE: by default this is off - the macro is not used - because the
+ * function call saves a lot of code.
+ */
+#ifdef PNG_FIXED_POINT_MACRO_SUPPORTED
+#define png_fixed(png_ptr, fp, s) ((fp) <= 21474 && (fp) >= -21474 ?\
+ ((png_fixed_point)(100000 * (fp))) : (png_fixed_error(png_ptr, s),0))
+#endif
+/* else the corresponding function is defined below, inside the scope of the
+ * cplusplus test.
+ */
+#endif
+
+/* Constants for known chunk types. If you need to add a chunk, define the name
+ * here. For historical reasons these constants have the form png_<name>; i.e.
+ * the prefix is lower case. Please use decimal values as the parameters to
+ * match the ISO PNG specification and to avoid relying on the C locale
+ * interpretation of character values.
+ *
+ * Prior to 1.5.6 these constants were strings, as of 1.5.6 png_uint_32 values
+ * are computed and a new macro (PNG_STRING_FROM_CHUNK) added to allow a string
+ * to be generated if required.
+ *
+ * PNG_32b correctly produces a value shifted by up to 24 bits, even on
+ * architectures where (int) is only 16 bits.
+ */
+#define PNG_32b(b,s) ((png_uint_32)(b) << (s))
+#define PNG_U32(b1,b2,b3,b4) \
+ (PNG_32b(b1,24) | PNG_32b(b2,16) | PNG_32b(b3,8) | PNG_32b(b4,0))
+
+/* Constants for known chunk types.
+ *
+ * MAINTAINERS: If you need to add a chunk, define the name here.
+ * For historical reasons these constants have the form png_<name>; i.e.
+ * the prefix is lower case. Please use decimal values as the parameters to
+ * match the ISO PNG specification and to avoid relying on the C locale
+ * interpretation of character values. Please keep the list sorted.
+ *
+ * Notice that PNG_U32 is used to define a 32-bit value for the 4 byte chunk
+ * type. In fact the specification does not express chunk types this way,
+ * however using a 32-bit value means that the chunk type can be read from the
+ * stream using exactly the same code as used for a 32-bit unsigned value and
+ * can be examined far more efficiently (using one arithmetic compare).
+ *
+ * Prior to 1.5.6 the chunk type constants were expressed as C strings. The
+ * libpng API still uses strings for 'unknown' chunks and a macro,
+ * PNG_STRING_FROM_CHUNK, allows a string to be generated if required. Notice
+ * that for portable code numeric values must still be used; the string "IHDR"
+ * is not portable and neither is PNG_U32('I', 'H', 'D', 'R').
+ *
+ * In 1.7.0 the definitions will be made public in png.h to avoid having to
+ * duplicate the same definitions in application code.
+ */
+#define png_IDAT PNG_U32( 73, 68, 65, 84)
+#define png_IEND PNG_U32( 73, 69, 78, 68)
+#define png_IHDR PNG_U32( 73, 72, 68, 82)
+#define png_PLTE PNG_U32( 80, 76, 84, 69)
+#define png_bKGD PNG_U32( 98, 75, 71, 68)
+#define png_cHRM PNG_U32( 99, 72, 82, 77)
+#define png_fRAc PNG_U32(102, 82, 65, 99) /* registered, not defined */
+#define png_gAMA PNG_U32(103, 65, 77, 65)
+#define png_gIFg PNG_U32(103, 73, 70, 103)
+#define png_gIFt PNG_U32(103, 73, 70, 116) /* deprecated */
+#define png_gIFx PNG_U32(103, 73, 70, 120)
+#define png_hIST PNG_U32(104, 73, 83, 84)
+#define png_iCCP PNG_U32(105, 67, 67, 80)
+#define png_iTXt PNG_U32(105, 84, 88, 116)
+#define png_oFFs PNG_U32(111, 70, 70, 115)
+#define png_pCAL PNG_U32(112, 67, 65, 76)
+#define png_pHYs PNG_U32(112, 72, 89, 115)
+#define png_sBIT PNG_U32(115, 66, 73, 84)
+#define png_sCAL PNG_U32(115, 67, 65, 76)
+#define png_sPLT PNG_U32(115, 80, 76, 84)
+#define png_sRGB PNG_U32(115, 82, 71, 66)
+#define png_sTER PNG_U32(115, 84, 69, 82)
+#define png_tEXt PNG_U32(116, 69, 88, 116)
+#define png_tIME PNG_U32(116, 73, 77, 69)
+#define png_tRNS PNG_U32(116, 82, 78, 83)
+#define png_zTXt PNG_U32(122, 84, 88, 116)
+
+/* The following will work on (signed char*) strings, whereas the get_uint_32
+ * macro will fail on top-bit-set values because of the sign extension.
+ */
+#define PNG_CHUNK_FROM_STRING(s)\
+ PNG_U32(0xff&(s)[0], 0xff&(s)[1], 0xff&(s)[2], 0xff&(s)[3])
+
+/* This uses (char), not (png_byte) to avoid warnings on systems where (char) is
+ * signed and the argument is a (char[]) This macro will fail miserably on
+ * systems where (char) is more than 8 bits.
+ */
+#define PNG_STRING_FROM_CHUNK(s,c)\
+ (void)(((char*)(s))[0]=(char)((c)>>24), ((char*)(s))[1]=(char)((c)>>16),\
+ ((char*)(s))[2]=(char)((c)>>8), ((char*)(s))[3]=(char)((c)))
+
+/* Do the same but terminate with a null character. */
+#define PNG_CSTRING_FROM_CHUNK(s,c)\
+ (void)(PNG_STRING_FROM_CHUNK(s,c), ((char*)(s))[4] = 0)
+
+/* Test on flag values as defined in the spec (section 5.4): */
+#define PNG_CHUNK_ANCILLARY(c) (1 & ((c) >> 29))
+#define PNG_CHUNK_CRITICAL(c) (!PNG_CHUNK_ANCILLARY(c))
+#define PNG_CHUNK_PRIVATE(c) (1 & ((c) >> 21))
+#define PNG_CHUNK_RESERVED(c) (1 & ((c) >> 13))
+#define PNG_CHUNK_SAFE_TO_COPY(c) (1 & ((c) >> 5))
+
+/* Gamma values (new at libpng-1.5.4): */
+#define PNG_GAMMA_MAC_OLD 151724 /* Assume '1.8' is really 2.2/1.45! */
+#define PNG_GAMMA_MAC_INVERSE 65909
+#define PNG_GAMMA_sRGB_INVERSE 45455
+
+/* Almost everything below is C specific; the #defines above can be used in
+ * non-C code (so long as it is C-preprocessed) the rest of this stuff cannot.
+ */
+#ifndef PNG_VERSION_INFO_ONLY
+
+#include "pngstruct.h"
+#include "pnginfo.h"
+
+/* Validate the include paths - the include path used to generate pnglibconf.h
+ * must match that used in the build, or we must be using pnglibconf.h.prebuilt:
+ */
+#if PNG_ZLIB_VERNUM != 0 && PNG_ZLIB_VERNUM != ZLIB_VERNUM
+# error ZLIB_VERNUM != PNG_ZLIB_VERNUM \
+ "-I (include path) error: see the notes in pngpriv.h"
+ /* This means that when pnglibconf.h was built the copy of zlib.h that it
+ * used is not the same as the one being used here. Because the build of
+ * libpng makes decisions to use inflateInit2 and inflateReset2 based on the
+ * zlib version number and because this affects handling of certain broken
+ * PNG files the -I directives must match.
+ *
+ * The most likely explanation is that you passed a -I in CFLAGS, this will
+ * not work; all the preprocessor directories and in particular all the -I
+ * directives must be in CPPFLAGS.
+ */
+#endif
+
+/* This is used for 16 bit gamma tables -- only the top level pointers are
+ * const; this could be changed:
+ */
+typedef const png_uint_16p * png_const_uint_16pp;
+
+/* Added to libpng-1.5.7: sRGB conversion tables */
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\
+ defined(PNG_SIMPLIFIED_WRITE_SUPPORTED)
+#ifdef PNG_SIMPLIFIED_READ_SUPPORTED
+PNG_INTERNAL_DATA(const png_uint_16, png_sRGB_table, [256]);
+ /* Convert from an sRGB encoded value 0..255 to a 16-bit linear value,
+ * 0..65535. This table gives the closest 16-bit answers (no errors).
+ */
+#endif
+
+PNG_INTERNAL_DATA(const png_uint_16, png_sRGB_base, [512]);
+PNG_INTERNAL_DATA(const png_byte, png_sRGB_delta, [512]);
+
+#define PNG_sRGB_FROM_LINEAR(linear) ((png_byte)((png_sRGB_base[(linear)>>15] +\
+ ((((linear)&0x7fff)*png_sRGB_delta[(linear)>>15])>>12)) >> 8))
+ /* Given a value 'linear' in the range 0..255*65535 calculate the 8-bit sRGB
+ * encoded value with maximum error 0.646365. Note that the input is not a
+ * 16-bit value; it has been multiplied by 255! */
+#endif /* PNG_SIMPLIFIED_READ/WRITE */
+
+
+/* Inhibit C++ name-mangling for libpng functions but not for system calls. */
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+/* Internal functions; these are not exported from a DLL however because they
+ * are used within several of the C source files they have to be C extern.
+ *
+ * All of these functions must be declared with PNG_INTERNAL_FUNCTION.
+ */
+
+/* Zlib support */
+#define PNG_UNEXPECTED_ZLIB_RETURN (-7)
+PNG_INTERNAL_FUNCTION(void, png_zstream_error,(png_structrp png_ptr, int ret),
+ PNG_EMPTY);
+ /* Used by the zlib handling functions to ensure that z_stream::msg is always
+ * set before they return.
+ */
+
+#ifdef PNG_WRITE_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_free_buffer_list,(png_structrp png_ptr,
+ png_compression_bufferp *list),PNG_EMPTY);
+ /* Free the buffer list used by the compressed write code. */
+#endif
+
+#if defined(PNG_FLOATING_POINT_SUPPORTED) && \
+ !defined(PNG_FIXED_POINT_MACRO_SUPPORTED) && \
+ (defined(PNG_gAMA_SUPPORTED) || defined(PNG_cHRM_SUPPORTED) || \
+ defined(PNG_sCAL_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)) || \
+ (defined(PNG_sCAL_SUPPORTED) && \
+ defined(PNG_FLOATING_ARITHMETIC_SUPPORTED))
+PNG_INTERNAL_FUNCTION(png_fixed_point,png_fixed,(png_const_structrp png_ptr,
+ double fp, png_const_charp text),PNG_EMPTY);
+#endif
+
+/* Check the user version string for compatibility, returns false if the version
+ * numbers aren't compatible.
+ */
+PNG_INTERNAL_FUNCTION(int,png_user_version_check,(png_structrp png_ptr,
+ png_const_charp user_png_ver),PNG_EMPTY);
+
+/* Internal base allocator - no messages, NULL on failure to allocate. This
+ * does, however, call the application provided allocator and that could call
+ * png_error (although that would be a bug in the application implementation.)
+ */
+PNG_INTERNAL_FUNCTION(png_voidp,png_malloc_base,(png_const_structrp png_ptr,
+ png_alloc_size_t size),PNG_ALLOCATED);
+
+#if defined(PNG_TEXT_SUPPORTED) || defined(PNG_sPLT_SUPPORTED) ||\
+ defined(PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED)
+/* Internal array allocator, outputs no error or warning messages on failure,
+ * just returns NULL.
+ */
+PNG_INTERNAL_FUNCTION(png_voidp,png_malloc_array,(png_const_structrp png_ptr,
+ int nelements, size_t element_size),PNG_ALLOCATED);
+
+/* The same but an existing array is extended by add_elements. This function
+ * also memsets the new elements to 0 and copies the old elements. The old
+ * array is not freed or altered.
+ */
+PNG_INTERNAL_FUNCTION(png_voidp,png_realloc_array,(png_const_structrp png_ptr,
+ png_const_voidp array, int old_elements, int add_elements,
+ size_t element_size),PNG_ALLOCATED);
+#endif /* text, sPLT or unknown chunks */
+
+/* Magic to create a struct when there is no struct to call the user supplied
+ * memory allocators. Because error handling has not been set up the memory
+ * handlers can't safely call png_error, but this is an obscure and undocumented
+ * restriction so libpng has to assume that the 'free' handler, at least, might
+ * call png_error.
+ */
+PNG_INTERNAL_FUNCTION(png_structp,png_create_png_struct,
+ (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn,
+ png_error_ptr warn_fn, png_voidp mem_ptr, png_malloc_ptr malloc_fn,
+ png_free_ptr free_fn),PNG_ALLOCATED);
+
+/* Free memory from internal libpng struct */
+PNG_INTERNAL_FUNCTION(void,png_destroy_png_struct,(png_structrp png_ptr),
+ PNG_EMPTY);
+
+/* Free an allocated jmp_buf (always succeeds) */
+PNG_INTERNAL_FUNCTION(void,png_free_jmpbuf,(png_structrp png_ptr),PNG_EMPTY);
+
+/* Function to allocate memory for zlib. PNGAPI is disallowed. */
+PNG_INTERNAL_FUNCTION(voidpf,png_zalloc,(voidpf png_ptr, uInt items, uInt size),
+ PNG_ALLOCATED);
+
+/* Function to free memory for zlib. PNGAPI is disallowed. */
+PNG_INTERNAL_FUNCTION(void,png_zfree,(voidpf png_ptr, voidpf ptr),PNG_EMPTY);
+
+/* Next four functions are used internally as callbacks. PNGCBAPI is required
+ * but not PNG_EXPORT. PNGAPI added at libpng version 1.2.3, changed to
+ * PNGCBAPI at 1.5.0
+ */
+
+PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_read_data,(png_structp png_ptr,
+ png_bytep data, png_size_t length),PNG_EMPTY);
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_push_fill_buffer,(png_structp png_ptr,
+ png_bytep buffer, png_size_t length),PNG_EMPTY);
+#endif
+
+PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_write_data,(png_structp png_ptr,
+ png_bytep data, png_size_t length),PNG_EMPTY);
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+# ifdef PNG_STDIO_SUPPORTED
+PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_flush,(png_structp png_ptr),
+ PNG_EMPTY);
+# endif
+#endif
+
+/* Reset the CRC variable */
+PNG_INTERNAL_FUNCTION(void,png_reset_crc,(png_structrp png_ptr),PNG_EMPTY);
+
+/* Write the "data" buffer to whatever output you are using */
+PNG_INTERNAL_FUNCTION(void,png_write_data,(png_structrp png_ptr,
+ png_const_bytep data, png_size_t length),PNG_EMPTY);
+
+/* Read and check the PNG file signature */
+PNG_INTERNAL_FUNCTION(void,png_read_sig,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+
+/* Read the chunk header (length + type name) */
+PNG_INTERNAL_FUNCTION(png_uint_32,png_read_chunk_header,(png_structrp png_ptr),
+ PNG_EMPTY);
+
+/* Read data from whatever input you are using into the "data" buffer */
+PNG_INTERNAL_FUNCTION(void,png_read_data,(png_structrp png_ptr, png_bytep data,
+ png_size_t length),PNG_EMPTY);
+
+/* Read bytes into buf, and update png_ptr->crc */
+PNG_INTERNAL_FUNCTION(void,png_crc_read,(png_structrp png_ptr, png_bytep buf,
+ png_uint_32 length),PNG_EMPTY);
+
+/* Read "skip" bytes, read the file crc, and (optionally) verify png_ptr->crc */
+PNG_INTERNAL_FUNCTION(int,png_crc_finish,(png_structrp png_ptr,
+ png_uint_32 skip),PNG_EMPTY);
+
+/* Read the CRC from the file and compare it to the libpng calculated CRC */
+PNG_INTERNAL_FUNCTION(int,png_crc_error,(png_structrp png_ptr),PNG_EMPTY);
+
+/* Calculate the CRC over a section of data. Note that we are only
+ * passing a maximum of 64K on systems that have this as a memory limit,
+ * since this is the maximum buffer size we can specify.
+ */
+PNG_INTERNAL_FUNCTION(void,png_calculate_crc,(png_structrp png_ptr,
+ png_const_bytep ptr, png_size_t length),PNG_EMPTY);
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_flush,(png_structrp png_ptr),PNG_EMPTY);
+#endif
+
+/* Write various chunks */
+
+/* Write the IHDR chunk, and update the png_struct with the necessary
+ * information.
+ */
+PNG_INTERNAL_FUNCTION(void,png_write_IHDR,(png_structrp png_ptr,
+ png_uint_32 width, png_uint_32 height, int bit_depth, int color_type,
+ int compression_method, int filter_method, int interlace_method),PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(void,png_write_PLTE,(png_structrp png_ptr,
+ png_const_colorp palette, png_uint_32 num_pal),PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(void,png_compress_IDAT,(png_structrp png_ptr,
+ png_const_bytep row_data, png_alloc_size_t row_data_length, int flush),
+ PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(void,png_write_IEND,(png_structrp png_ptr),PNG_EMPTY);
+
+#ifdef PNG_WRITE_gAMA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_gAMA_fixed,(png_structrp png_ptr,
+ png_fixed_point file_gamma),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_sBIT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_sBIT,(png_structrp png_ptr,
+ png_const_color_8p sbit, int color_type),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_cHRM_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_cHRM_fixed,(png_structrp png_ptr,
+ const png_xy *xy), PNG_EMPTY);
+ /* The xy value must have been previously validated */
+#endif
+
+#ifdef PNG_WRITE_sRGB_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_sRGB,(png_structrp png_ptr,
+ int intent),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_iCCP_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_iCCP,(png_structrp png_ptr,
+ png_const_charp name, png_const_bytep profile), PNG_EMPTY);
+ /* The profile must have been previously validated for correctness, the
+ * length comes from the first four bytes. Only the base, deflate,
+ * compression is supported.
+ */
+#endif
+
+#ifdef PNG_WRITE_sPLT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_sPLT,(png_structrp png_ptr,
+ png_const_sPLT_tp palette),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_tRNS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_tRNS,(png_structrp png_ptr,
+ png_const_bytep trans, png_const_color_16p values, int number,
+ int color_type),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_bKGD_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_bKGD,(png_structrp png_ptr,
+ png_const_color_16p values, int color_type),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_hIST_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_hIST,(png_structrp png_ptr,
+ png_const_uint_16p hist, int num_hist),PNG_EMPTY);
+#endif
+
+/* Chunks that have keywords */
+#ifdef PNG_WRITE_tEXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_tEXt,(png_structrp png_ptr,
+ png_const_charp key, png_const_charp text, png_size_t text_len),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_zTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_zTXt,(png_structrp png_ptr, png_const_charp
+ key, png_const_charp text, png_size_t text_len, int compression),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_iTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_iTXt,(png_structrp png_ptr,
+ int compression, png_const_charp key, png_const_charp lang,
+ png_const_charp lang_key, png_const_charp text),PNG_EMPTY);
+#endif
+
+#ifdef PNG_TEXT_SUPPORTED /* Added at version 1.0.14 and 1.2.4 */
+PNG_INTERNAL_FUNCTION(int,png_set_text_2,(png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_textp text_ptr, int num_text),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_oFFs_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_oFFs,(png_structrp png_ptr,
+ png_int_32 x_offset, png_int_32 y_offset, int unit_type),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_pCAL_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_pCAL,(png_structrp png_ptr,
+ png_charp purpose, png_int_32 X0, png_int_32 X1, int type, int nparams,
+ png_const_charp units, png_charpp params),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_pHYs_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_pHYs,(png_structrp png_ptr,
+ png_uint_32 x_pixels_per_unit, png_uint_32 y_pixels_per_unit,
+ int unit_type),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_tIME_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_tIME,(png_structrp png_ptr,
+ png_const_timep mod_time),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_sCAL_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_write_sCAL_s,(png_structrp png_ptr,
+ int unit, png_const_charp width, png_const_charp height),PNG_EMPTY);
+#endif
+
+/* Called when finished processing a row of data */
+PNG_INTERNAL_FUNCTION(void,png_write_finish_row,(png_structrp png_ptr),
+ PNG_EMPTY);
+
+/* Internal use only. Called before first row of data */
+PNG_INTERNAL_FUNCTION(void,png_write_start_row,(png_structrp png_ptr),
+ PNG_EMPTY);
+
+/* Combine a row of data, dealing with alpha, etc. if requested. 'row' is an
+ * array of png_ptr->width pixels. If the image is not interlaced or this
+ * is the final pass this just does a memcpy, otherwise the "display" flag
+ * is used to determine whether to copy pixels that are not in the current pass.
+ *
+ * Because 'png_do_read_interlace' (below) replicates pixels this allows this
+ * function to achieve the documented 'blocky' appearance during interlaced read
+ * if display is 1 and the 'sparkle' appearance, where existing pixels in 'row'
+ * are not changed if they are not in the current pass, when display is 0.
+ *
+ * 'display' must be 0 or 1, otherwise the memcpy will be done regardless.
+ *
+ * The API always reads from the png_struct row buffer and always assumes that
+ * it is full width (png_do_read_interlace has already been called.)
+ *
+ * This function is only ever used to write to row buffers provided by the
+ * caller of the relevant libpng API and the row must have already been
+ * transformed by the read transformations.
+ *
+ * The PNG_USE_COMPILE_TIME_MASKS option causes generation of pre-computed
+ * bitmasks for use within the code, otherwise runtime generated masks are used.
+ * The default is compile time masks.
+ */
+#ifndef PNG_USE_COMPILE_TIME_MASKS
+# define PNG_USE_COMPILE_TIME_MASKS 1
+#endif
+PNG_INTERNAL_FUNCTION(void,png_combine_row,(png_const_structrp png_ptr,
+ png_bytep row, int display),PNG_EMPTY);
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+/* Expand an interlaced row: the 'row_info' describes the pass data that has
+ * been read in and must correspond to the pixels in 'row', the pixels are
+ * expanded (moved apart) in 'row' to match the final layout, when doing this
+ * the pixels are *replicated* to the intervening space. This is essential for
+ * the correct operation of png_combine_row, above.
+ */
+PNG_INTERNAL_FUNCTION(void,png_do_read_interlace,(png_row_infop row_info,
+ png_bytep row, int pass, png_uint_32 transformations),PNG_EMPTY);
+#endif
+
+/* GRR TO DO (2.0 or whenever): simplify other internal calling interfaces */
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+/* Grab pixels out of a row for an interlaced pass */
+PNG_INTERNAL_FUNCTION(void,png_do_write_interlace,(png_row_infop row_info,
+ png_bytep row, int pass),PNG_EMPTY);
+#endif
+
+/* Unfilter a row: check the filter value before calling this, there is no point
+ * calling it for PNG_FILTER_VALUE_NONE.
+ */
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row,(png_structrp pp, png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row, int filter),PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_up_neon,(png_row_infop row_info,
+ png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_sub3_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_sub4_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_avg3_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_avg4_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_paeth3_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_filter_row_paeth4_neon,(png_row_infop
+ row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY);
+
+/* Choose the best filter to use and filter the row data */
+PNG_INTERNAL_FUNCTION(void,png_write_find_filter,(png_structrp png_ptr,
+ png_row_infop row_info),PNG_EMPTY);
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_read_IDAT_data,(png_structrp png_ptr,
+ png_bytep output, png_alloc_size_t avail_out),PNG_EMPTY);
+ /* Read 'avail_out' bytes of data from the IDAT stream. If the output buffer
+ * is NULL the function checks, instead, for the end of the stream. In this
+ * case a benign error will be issued if the stream end is not found or if
+ * extra data has to be consumed.
+ */
+PNG_INTERNAL_FUNCTION(void,png_read_finish_IDAT,(png_structrp png_ptr),
+ PNG_EMPTY);
+ /* This cleans up when the IDAT LZ stream does not end when the last image
+ * byte is read; there is still some pending input.
+ */
+
+PNG_INTERNAL_FUNCTION(void,png_read_finish_row,(png_structrp png_ptr),
+ PNG_EMPTY);
+ /* Finish a row while reading, dealing with interlacing passes, etc. */
+#endif
+
+/* Initialize the row buffers, etc. */
+PNG_INTERNAL_FUNCTION(void,png_read_start_row,(png_structrp png_ptr),PNG_EMPTY);
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+/* Optional call to update the users info structure */
+PNG_INTERNAL_FUNCTION(void,png_read_transform_info,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+#endif
+
+/* These are the functions that do the transformations */
+#ifdef PNG_READ_FILLER_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_read_filler,(png_row_infop row_info,
+ png_bytep row, png_uint_32 filler, png_uint_32 flags),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_read_swap_alpha,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_write_swap_alpha,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_read_invert_alpha,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_write_invert_alpha,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#if defined(PNG_WRITE_FILLER_SUPPORTED) || \
+ defined(PNG_READ_STRIP_ALPHA_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_strip_channel,(png_row_infop row_info,
+ png_bytep row, int at_start),PNG_EMPTY);
+#endif
+
+#ifdef PNG_16BIT_SUPPORTED
+#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_swap,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+#endif
+
+#if defined(PNG_READ_PACKSWAP_SUPPORTED) || \
+ defined(PNG_WRITE_PACKSWAP_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_packswap,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+PNG_INTERNAL_FUNCTION(int,png_do_rgb_to_gray,(png_structrp png_ptr,
+ png_row_infop row_info, png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_gray_to_rgb,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_PACK_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_unpack,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_SHIFT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_unshift,(png_row_infop row_info,
+ png_bytep row, png_const_color_8p sig_bits),PNG_EMPTY);
+#endif
+
+#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_invert,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_scale_16_to_8,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_chop,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_quantize,(png_row_infop row_info,
+ png_bytep row, png_const_bytep palette_lookup,
+ png_const_bytep quantize_lookup),PNG_EMPTY);
+
+# ifdef PNG_CORRECT_PALETTE_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_correct_palette,(png_structrp png_ptr,
+ png_colorp palette, int num_palette),PNG_EMPTY);
+# endif
+#endif
+
+#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_bgr,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_PACK_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_pack,(png_row_infop row_info,
+ png_bytep row, png_uint_32 bit_depth),PNG_EMPTY);
+#endif
+
+#ifdef PNG_WRITE_SHIFT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_shift,(png_row_infop row_info,
+ png_bytep row, png_const_color_8p bit_depth),PNG_EMPTY);
+#endif
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_compose,(png_row_infop row_info,
+ png_bytep row, png_structrp png_ptr),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_gamma,(png_row_infop row_info,
+ png_bytep row, png_structrp png_ptr),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_encode_alpha,(png_row_infop row_info,
+ png_bytep row, png_structrp png_ptr),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_expand_palette,(png_row_infop row_info,
+ png_bytep row, png_const_colorp palette, png_const_bytep trans,
+ int num_trans),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_do_expand,(png_row_infop row_info,
+ png_bytep row, png_const_color_16p trans_color),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_expand_16,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+/* The following decodes the appropriate chunks, and does error correction,
+ * then calls the appropriate callback for the chunk if it is valid.
+ */
+
+/* Decode the IHDR chunk */
+PNG_INTERNAL_FUNCTION(void,png_handle_IHDR,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_handle_PLTE,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_handle_IEND,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+
+#ifdef PNG_READ_bKGD_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_bKGD,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_cHRM_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_cHRM,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_gAMA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_gAMA,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_hIST_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_hIST,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_iCCP_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_iCCP,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif /* PNG_READ_iCCP_SUPPORTED */
+
+#ifdef PNG_READ_iTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_iTXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_oFFs_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_oFFs,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_pCAL_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_pCAL,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_pHYs_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_pHYs,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_sBIT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_sBIT,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_sCAL_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_sCAL,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_sPLT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_sPLT,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif /* PNG_READ_sPLT_SUPPORTED */
+
+#ifdef PNG_READ_sRGB_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_sRGB,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_tEXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_tEXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_tIME_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_tIME,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_tRNS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_tRNS,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_zTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_zTXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+#endif
+
+PNG_INTERNAL_FUNCTION(void,png_check_chunk_name,(png_structrp png_ptr,
+ png_uint_32 chunk_name),PNG_EMPTY);
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_handle_unknown,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length, int keep),PNG_EMPTY);
+ /* This is the function that gets called for unknown chunks. The 'keep'
+ * argument is either non-zero for a known chunk that has been set to be
+ * handled as unknown or zero for an unknown chunk. By default the function
+ * just skips the chunk or errors out if it is critical.
+ */
+
+#if defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) ||\
+ defined(PNG_HANDLE_AS_UNKNOWN_SUPPORTED)
+PNG_INTERNAL_FUNCTION(int,png_chunk_unknown_handling,
+ (png_const_structrp png_ptr, png_uint_32 chunk_name),PNG_EMPTY);
+ /* Exactly as the API png_handle_as_unknown() except that the argument is a
+ * 32-bit chunk name, not a string.
+ */
+#endif /* READ_UNKNOWN_CHUNKS || HANDLE_AS_UNKNOWN */
+#endif /* PNG_SET_UNKNOWN_CHUNKS_SUPPORTED */
+
+/* Handle the transformations for reading and writing */
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_read_transformations,(png_structrp png_ptr,
+ png_row_infop row_info),PNG_EMPTY);
+#endif
+#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_write_transformations,(png_structrp png_ptr,
+ png_row_infop row_info),PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_init_read_transformations,(png_structrp png_ptr),
+ PNG_EMPTY);
+#endif
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_push_read_chunk,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_sig,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_check_crc,(png_structrp png_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_crc_skip,(png_structrp png_ptr,
+ png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_crc_finish,(png_structrp png_ptr),
+ PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_save_buffer,(png_structrp png_ptr),
+ PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_restore_buffer,(png_structrp png_ptr,
+ png_bytep buffer, png_size_t buffer_length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_IDAT,(png_structrp png_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_process_IDAT_data,(png_structrp png_ptr,
+ png_bytep buffer, png_size_t buffer_length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_process_row,(png_structrp png_ptr),
+ PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_handle_unknown,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_have_info,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_have_end,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_have_row,(png_structrp png_ptr,
+ png_bytep row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_end,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_process_some_data,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_read_push_finish_row,(png_structrp png_ptr),
+ PNG_EMPTY);
+# ifdef PNG_READ_tEXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_push_handle_tEXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_tEXt,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+# endif
+# ifdef PNG_READ_zTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_push_handle_zTXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_zTXt,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+# endif
+# ifdef PNG_READ_iTXt_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_push_handle_iTXt,(png_structrp png_ptr,
+ png_inforp info_ptr, png_uint_32 length),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_push_read_iTXt,(png_structrp png_ptr,
+ png_inforp info_ptr),PNG_EMPTY);
+# endif
+
+#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_do_read_intrapixel,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_do_write_intrapixel,(png_row_infop row_info,
+ png_bytep row),PNG_EMPTY);
+#endif
+
+/* Added at libpng version 1.6.0 */
+#ifdef PNG_GAMMA_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_colorspace_set_gamma,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_fixed_point gAMA), PNG_EMPTY);
+ /* Set the colorspace gamma with a value provided by the application or by
+ * the gAMA chunk on read. The value will override anything set by an ICC
+ * profile.
+ */
+
+PNG_INTERNAL_FUNCTION(void,png_colorspace_sync_info,(png_const_structrp png_ptr,
+ png_inforp info_ptr), PNG_EMPTY);
+ /* Synchronize the info 'valid' flags with the colorspace */
+
+PNG_INTERNAL_FUNCTION(void,png_colorspace_sync,(png_const_structrp png_ptr,
+ png_inforp info_ptr), PNG_EMPTY);
+ /* Copy the png_struct colorspace to the info_struct and call the above to
+ * synchronize the flags. Checks for NULL info_ptr and does nothing.
+ */
+#endif
+
+/* Added at libpng version 1.4.0 */
+#ifdef PNG_COLORSPACE_SUPPORTED
+/* These internal functions are for maintaining the colorspace structure within
+ * a png_info or png_struct (or, indeed, both).
+ */
+PNG_INTERNAL_FUNCTION(int,png_colorspace_set_chromaticities,
+ (png_const_structrp png_ptr, png_colorspacerp colorspace, const png_xy *xy,
+ int preferred), PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(int,png_colorspace_set_endpoints,
+ (png_const_structrp png_ptr, png_colorspacerp colorspace, const png_XYZ *XYZ,
+ int preferred), PNG_EMPTY);
+
+#ifdef PNG_sRGB_SUPPORTED
+PNG_INTERNAL_FUNCTION(int,png_colorspace_set_sRGB,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, int intent), PNG_EMPTY);
+ /* This does set the colorspace gAMA and cHRM values too, but doesn't set the
+ * flags to write them, if it returns false there was a problem and an error
+ * message has already been output (but the colorspace may still need to be
+ * synced to record the invalid flag).
+ */
+#endif /* sRGB */
+
+#ifdef PNG_iCCP_SUPPORTED
+PNG_INTERNAL_FUNCTION(int,png_colorspace_set_ICC,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_const_charp name,
+ png_uint_32 profile_length, png_const_bytep profile, int color_type),
+ PNG_EMPTY);
+ /* The 'name' is used for information only */
+
+/* Routines for checking parts of an ICC profile. */
+PNG_INTERNAL_FUNCTION(int,png_icc_check_length,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_const_charp name,
+ png_uint_32 profile_length), PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(int,png_icc_check_header,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_const_charp name,
+ png_uint_32 profile_length,
+ png_const_bytep profile /* first 132 bytes only */, int color_type),
+ PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(int,png_icc_check_tag_table,(png_const_structrp png_ptr,
+ png_colorspacerp colorspace, png_const_charp name,
+ png_uint_32 profile_length,
+ png_const_bytep profile /* header plus whole tag table */), PNG_EMPTY);
+#ifdef PNG_sRGB_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_icc_set_sRGB,(
+ png_const_structrp png_ptr, png_colorspacerp colorspace,
+ png_const_bytep profile, uLong adler), PNG_EMPTY);
+ /* 'adler' is the Adler32 checksum of the uncompressed profile data. It may
+ * be zero to indicate that it is not available. It is used, if provided,
+ * as a fast check on the profile when checking to see if it is sRGB.
+ */
+#endif
+#endif /* iCCP */
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_colorspace_set_rgb_coefficients,
+ (png_structrp png_ptr), PNG_EMPTY);
+ /* Set the rgb_to_gray coefficients from the colorspace Y values */
+#endif /* READ_RGB_TO_GRAY */
+#endif /* COLORSPACE */
+
+/* Added at libpng version 1.4.0 */
+PNG_INTERNAL_FUNCTION(void,png_check_IHDR,(png_const_structrp png_ptr,
+ png_uint_32 width, png_uint_32 height, int bit_depth,
+ int color_type, int interlace_type, int compression_type,
+ int filter_type),PNG_EMPTY);
+
+/* Added at libpng version 1.5.10 */
+#if defined(PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED) || \
+ defined(PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_do_check_palette_indexes,
+ (png_structrp png_ptr, png_row_infop row_info),PNG_EMPTY);
+#endif
+
+#if defined(PNG_FLOATING_POINT_SUPPORTED) && defined(PNG_ERROR_TEXT_SUPPORTED)
+PNG_INTERNAL_FUNCTION(void,png_fixed_error,(png_const_structrp png_ptr,
+ png_const_charp name),PNG_NORETURN);
+#endif
+
+/* Puts 'string' into 'buffer' at buffer[pos], taking care never to overwrite
+ * the end. Always leaves the buffer nul terminated. Never errors out (and
+ * there is no error code.)
+ */
+PNG_INTERNAL_FUNCTION(size_t,png_safecat,(png_charp buffer, size_t bufsize,
+ size_t pos, png_const_charp string),PNG_EMPTY);
+
+/* Various internal functions to handle formatted warning messages, currently
+ * only implemented for warnings.
+ */
+#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_TIME_RFC1123_SUPPORTED)
+/* Utility to dump an unsigned value into a buffer, given a start pointer and
+ * and end pointer (which should point just *beyond* the end of the buffer!)
+ * Returns the pointer to the start of the formatted string. This utility only
+ * does unsigned values.
+ */
+PNG_INTERNAL_FUNCTION(png_charp,png_format_number,(png_const_charp start,
+ png_charp end, int format, png_alloc_size_t number),PNG_EMPTY);
+
+/* Convenience macro that takes an array: */
+#define PNG_FORMAT_NUMBER(buffer,format,number) \
+ png_format_number(buffer, buffer + (sizeof buffer), format, number)
+
+/* Suggested size for a number buffer (enough for 64 bits and a sign!) */
+#define PNG_NUMBER_BUFFER_SIZE 24
+
+/* These are the integer formats currently supported, the name is formed from
+ * the standard printf(3) format string.
+ */
+#define PNG_NUMBER_FORMAT_u 1 /* chose unsigned API! */
+#define PNG_NUMBER_FORMAT_02u 2
+#define PNG_NUMBER_FORMAT_d 1 /* chose signed API! */
+#define PNG_NUMBER_FORMAT_02d 2
+#define PNG_NUMBER_FORMAT_x 3
+#define PNG_NUMBER_FORMAT_02x 4
+#define PNG_NUMBER_FORMAT_fixed 5 /* choose the signed API */
+#endif
+
+#ifdef PNG_WARNINGS_SUPPORTED
+/* New defines and members adding in libpng-1.5.4 */
+# define PNG_WARNING_PARAMETER_SIZE 32
+# define PNG_WARNING_PARAMETER_COUNT 8 /* Maximum 9; see pngerror.c */
+
+/* An l-value of this type has to be passed to the APIs below to cache the
+ * values of the parameters to a formatted warning message.
+ */
+typedef char png_warning_parameters[PNG_WARNING_PARAMETER_COUNT][
+ PNG_WARNING_PARAMETER_SIZE];
+
+PNG_INTERNAL_FUNCTION(void,png_warning_parameter,(png_warning_parameters p,
+ int number, png_const_charp string),PNG_EMPTY);
+ /* Parameters are limited in size to PNG_WARNING_PARAMETER_SIZE characters,
+ * including the trailing '\0'.
+ */
+PNG_INTERNAL_FUNCTION(void,png_warning_parameter_unsigned,
+ (png_warning_parameters p, int number, int format, png_alloc_size_t value),
+ PNG_EMPTY);
+ /* Use png_alloc_size_t because it is an unsigned type as big as any we
+ * need to output. Use the following for a signed value.
+ */
+PNG_INTERNAL_FUNCTION(void,png_warning_parameter_signed,
+ (png_warning_parameters p, int number, int format, png_int_32 value),
+ PNG_EMPTY);
+
+PNG_INTERNAL_FUNCTION(void,png_formatted_warning,(png_const_structrp png_ptr,
+ png_warning_parameters p, png_const_charp message),PNG_EMPTY);
+ /* 'message' follows the X/Open approach of using @1, @2 to insert
+ * parameters previously supplied using the above functions. Errors in
+ * specifying the parameters will simply result in garbage substitutions.
+ */
+#endif
+
+#ifdef PNG_BENIGN_ERRORS_SUPPORTED
+/* Application errors (new in 1.6); use these functions (declared below) for
+ * errors in the parameters or order of API function calls on read. The
+ * 'warning' should be used for an error that can be handled completely; the
+ * 'error' for one which can be handled safely but which may lose application
+ * information or settings.
+ *
+ * By default these both result in a png_error call prior to release, while in a
+ * released version the 'warning' is just a warning. However if the application
+ * explicitly disables benign errors (explicitly permitting the code to lose
+ * information) they both turn into warnings.
+ *
+ * If benign errors aren't supported they end up as the corresponding base call
+ * (png_warning or png_error.)
+ */
+PNG_INTERNAL_FUNCTION(void,png_app_warning,(png_const_structrp png_ptr,
+ png_const_charp message),PNG_EMPTY);
+ /* The application provided invalid parameters to an API function or called
+ * an API function at the wrong time, libpng can completely recover.
+ */
+
+PNG_INTERNAL_FUNCTION(void,png_app_error,(png_const_structrp png_ptr,
+ png_const_charp message),PNG_EMPTY);
+ /* As above but libpng will ignore the call, or attempt some other partial
+ * recovery from the error.
+ */
+#else
+# define png_app_warning(pp,s) png_warning(pp,s)
+# define png_app_error(pp,s) png_error(pp,s)
+#endif
+
+PNG_INTERNAL_FUNCTION(void,png_chunk_report,(png_const_structrp png_ptr,
+ png_const_charp message, int error),PNG_EMPTY);
+ /* Report a recoverable issue in chunk data. On read this is used to report
+ * a problem found while reading a particular chunk and the
+ * png_chunk_benign_error or png_chunk_warning function is used as
+ * appropriate. On write this is used to report an error that comes from
+ * data set via an application call to a png_set_ API and png_app_error or
+ * png_app_warning is used as appropriate.
+ *
+ * The 'error' parameter must have one of the following values:
+ */
+#define PNG_CHUNK_WARNING 0 /* never an error */
+#define PNG_CHUNK_WRITE_ERROR 1 /* an error only on write */
+#define PNG_CHUNK_ERROR 2 /* always an error */
+
+/* ASCII to FP interfaces, currently only implemented if sCAL
+ * support is required.
+ */
+#if defined(PNG_sCAL_SUPPORTED)
+/* MAX_DIGITS is actually the maximum number of characters in an sCAL
+ * width or height, derived from the precision (number of significant
+ * digits - a build time settable option) and assumptions about the
+ * maximum ridiculous exponent.
+ */
+#define PNG_sCAL_MAX_DIGITS (PNG_sCAL_PRECISION+1/*.*/+1/*E*/+10/*exponent*/)
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_ascii_from_fp,(png_const_structrp png_ptr,
+ png_charp ascii, png_size_t size, double fp, unsigned int precision),
+ PNG_EMPTY);
+#endif /* FLOATING_POINT */
+
+#ifdef PNG_FIXED_POINT_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_ascii_from_fixed,(png_const_structrp png_ptr,
+ png_charp ascii, png_size_t size, png_fixed_point fp),PNG_EMPTY);
+#endif /* FIXED_POINT */
+#endif /* sCAL */
+
+#if defined(PNG_sCAL_SUPPORTED) || defined(PNG_pCAL_SUPPORTED)
+/* An internal API to validate the format of a floating point number.
+ * The result is the index of the next character. If the number is
+ * not valid it will be the index of a character in the supposed number.
+ *
+ * The format of a number is defined in the PNG extensions specification
+ * and this API is strictly conformant to that spec, not anyone elses!
+ *
+ * The format as a regular expression is:
+ *
+ * [+-]?[0-9]+.?([Ee][+-]?[0-9]+)?
+ *
+ * or:
+ *
+ * [+-]?.[0-9]+(.[0-9]+)?([Ee][+-]?[0-9]+)?
+ *
+ * The complexity is that either integer or fraction must be present and the
+ * fraction is permitted to have no digits only if the integer is present.
+ *
+ * NOTE: The dangling E problem.
+ * There is a PNG valid floating point number in the following:
+ *
+ * PNG floating point numbers are not greedy.
+ *
+ * Working this out requires *TWO* character lookahead (because of the
+ * sign), the parser does not do this - it will fail at the 'r' - this
+ * doesn't matter for PNG sCAL chunk values, but it requires more care
+ * if the value were ever to be embedded in something more complex. Use
+ * ANSI-C strtod if you need the lookahead.
+ */
+/* State table for the parser. */
+#define PNG_FP_INTEGER 0 /* before or in integer */
+#define PNG_FP_FRACTION 1 /* before or in fraction */
+#define PNG_FP_EXPONENT 2 /* before or in exponent */
+#define PNG_FP_STATE 3 /* mask for the above */
+#define PNG_FP_SAW_SIGN 4 /* Saw +/- in current state */
+#define PNG_FP_SAW_DIGIT 8 /* Saw a digit in current state */
+#define PNG_FP_SAW_DOT 16 /* Saw a dot in current state */
+#define PNG_FP_SAW_E 32 /* Saw an E (or e) in current state */
+#define PNG_FP_SAW_ANY 60 /* Saw any of the above 4 */
+
+/* These three values don't affect the parser. They are set but not used.
+ */
+#define PNG_FP_WAS_VALID 64 /* Preceding substring is a valid fp number */
+#define PNG_FP_NEGATIVE 128 /* A negative number, including "-0" */
+#define PNG_FP_NONZERO 256 /* A non-zero value */
+#define PNG_FP_STICKY 448 /* The above three flags */
+
+/* This is available for the caller to store in 'state' if required. Do not
+ * call the parser after setting it (the parser sometimes clears it.)
+ */
+#define PNG_FP_INVALID 512 /* Available for callers as a distinct value */
+
+/* Result codes for the parser (boolean - true meants ok, false means
+ * not ok yet.)
+ */
+#define PNG_FP_MAYBE 0 /* The number may be valid in the future */
+#define PNG_FP_OK 1 /* The number is valid */
+
+/* Tests on the sticky non-zero and negative flags. To pass these checks
+ * the state must also indicate that the whole number is valid - this is
+ * achieved by testing PNG_FP_SAW_DIGIT (see the implementation for why this
+ * is equivalent to PNG_FP_OK above.)
+ */
+#define PNG_FP_NZ_MASK (PNG_FP_SAW_DIGIT | PNG_FP_NEGATIVE | PNG_FP_NONZERO)
+ /* NZ_MASK: the string is valid and a non-zero negative value */
+#define PNG_FP_Z_MASK (PNG_FP_SAW_DIGIT | PNG_FP_NONZERO)
+ /* Z MASK: the string is valid and a non-zero value. */
+ /* PNG_FP_SAW_DIGIT: the string is valid. */
+#define PNG_FP_IS_ZERO(state) (((state) & PNG_FP_Z_MASK) == PNG_FP_SAW_DIGIT)
+#define PNG_FP_IS_POSITIVE(state) (((state) & PNG_FP_NZ_MASK) == PNG_FP_Z_MASK)
+#define PNG_FP_IS_NEGATIVE(state) (((state) & PNG_FP_NZ_MASK) == PNG_FP_NZ_MASK)
+
+/* The actual parser. This can be called repeatedly. It updates
+ * the index into the string and the state variable (which must
+ * be initialized to 0). It returns a result code, as above. There
+ * is no point calling the parser any more if it fails to advance to
+ * the end of the string - it is stuck on an invalid character (or
+ * terminated by '\0').
+ *
+ * Note that the pointer will consume an E or even an E+ and then leave
+ * a 'maybe' state even though a preceding integer.fraction is valid.
+ * The PNG_FP_WAS_VALID flag indicates that a preceding substring was
+ * a valid number. It's possible to recover from this by calling
+ * the parser again (from the start, with state 0) but with a string
+ * that omits the last character (i.e. set the size to the index of
+ * the problem character.) This has not been tested within libpng.
+ */
+PNG_INTERNAL_FUNCTION(int,png_check_fp_number,(png_const_charp string,
+ png_size_t size, int *statep, png_size_tp whereami),PNG_EMPTY);
+
+/* This is the same but it checks a complete string and returns true
+ * only if it just contains a floating point number. As of 1.5.4 this
+ * function also returns the state at the end of parsing the number if
+ * it was valid (otherwise it returns 0.) This can be used for testing
+ * for negative or zero values using the sticky flag.
+ */
+PNG_INTERNAL_FUNCTION(int,png_check_fp_string,(png_const_charp string,
+ png_size_t size),PNG_EMPTY);
+#endif /* pCAL || sCAL */
+
+#if defined(PNG_READ_GAMMA_SUPPORTED) ||\
+ defined(PNG_INCH_CONVERSIONS_SUPPORTED) || defined(PNG_READ_pHYs_SUPPORTED)
+/* Added at libpng version 1.5.0 */
+/* This is a utility to provide a*times/div (rounded) and indicate
+ * if there is an overflow. The result is a boolean - false (0)
+ * for overflow, true (1) if no overflow, in which case *res
+ * holds the result.
+ */
+PNG_INTERNAL_FUNCTION(int,png_muldiv,(png_fixed_point_p res, png_fixed_point a,
+ png_int_32 multiplied_by, png_int_32 divided_by),PNG_EMPTY);
+#endif
+
+#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_INCH_CONVERSIONS_SUPPORTED)
+/* Same deal, but issue a warning on overflow and return 0. */
+PNG_INTERNAL_FUNCTION(png_fixed_point,png_muldiv_warn,
+ (png_const_structrp png_ptr, png_fixed_point a, png_int_32 multiplied_by,
+ png_int_32 divided_by),PNG_EMPTY);
+#endif
+
+#ifdef PNG_GAMMA_SUPPORTED
+/* Calculate a reciprocal - used for gamma values. This returns
+ * 0 if the argument is 0 in order to maintain an undefined value;
+ * there are no warnings.
+ */
+PNG_INTERNAL_FUNCTION(png_fixed_point,png_reciprocal,(png_fixed_point a),
+ PNG_EMPTY);
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* The same but gives a reciprocal of the product of two fixed point
+ * values. Accuracy is suitable for gamma calculations but this is
+ * not exact - use png_muldiv for that. Only required at present on read.
+ */
+PNG_INTERNAL_FUNCTION(png_fixed_point,png_reciprocal2,(png_fixed_point a,
+ png_fixed_point b),PNG_EMPTY);
+#endif
+
+/* Return true if the gamma value is significantly different from 1.0 */
+PNG_INTERNAL_FUNCTION(int,png_gamma_significant,(png_fixed_point gamma_value),
+ PNG_EMPTY);
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* Internal fixed point gamma correction. These APIs are called as
+ * required to convert single values - they don't need to be fast,
+ * they are not used when processing image pixel values.
+ *
+ * While the input is an 'unsigned' value it must actually be the
+ * correct bit value - 0..255 or 0..65535 as required.
+ */
+PNG_INTERNAL_FUNCTION(png_uint_16,png_gamma_correct,(png_structrp png_ptr,
+ unsigned int value, png_fixed_point gamma_value),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(png_uint_16,png_gamma_16bit_correct,(unsigned int value,
+ png_fixed_point gamma_value),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(png_byte,png_gamma_8bit_correct,(unsigned int value,
+ png_fixed_point gamma_value),PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_destroy_gamma_table,(png_structrp png_ptr),
+ PNG_EMPTY);
+PNG_INTERNAL_FUNCTION(void,png_build_gamma_table,(png_structrp png_ptr,
+ int bit_depth),PNG_EMPTY);
+#endif
+
+/* SIMPLIFIED READ/WRITE SUPPORT */
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\
+ defined(PNG_SIMPLIFIED_WRITE_SUPPORTED)
+/* The internal structure that png_image::opaque points to. */
+typedef struct png_control
+{
+ png_structp png_ptr;
+ png_infop info_ptr;
+ png_voidp error_buf; /* Always a jmp_buf at present. */
+
+ png_const_bytep memory; /* Memory buffer. */
+ png_size_t size; /* Size of the memory buffer. */
+
+ unsigned int for_write :1; /* Otherwise it is a read structure */
+ unsigned int owned_file :1; /* We own the file in io_ptr */
+} png_control;
+
+/* Return the pointer to the jmp_buf from a png_control: necessary because C
+ * does not reveal the type of the elements of jmp_buf.
+ */
+#ifdef __cplusplus
+# define png_control_jmp_buf(pc) (((jmp_buf*)((pc)->error_buf))[0])
+#else
+# define png_control_jmp_buf(pc) ((pc)->error_buf)
+#endif
+
+/* Utility to safely execute a piece of libpng code catching and logging any
+ * errors that might occur. Returns true on success, false on failure (either
+ * of the function or as a result of a png_error.)
+ */
+PNG_INTERNAL_FUNCTION(void,png_safe_error,(png_structp png_ptr,
+ png_const_charp error_message),PNG_NORETURN);
+
+#ifdef PNG_WARNINGS_SUPPORTED
+PNG_INTERNAL_FUNCTION(void,png_safe_warning,(png_structp png_ptr,
+ png_const_charp warning_message),PNG_EMPTY);
+#else
+# define png_safe_warning 0/*dummy argument*/
+#endif
+
+PNG_INTERNAL_FUNCTION(int,png_safe_execute,(png_imagep image,
+ int (*function)(png_voidp), png_voidp arg),PNG_EMPTY);
+
+/* Utility to log an error; this also cleans up the png_image; the function
+ * always returns 0 (false).
+ */
+PNG_INTERNAL_FUNCTION(int,png_image_error,(png_imagep image,
+ png_const_charp error_message),PNG_EMPTY);
+
+#ifndef PNG_SIMPLIFIED_READ_SUPPORTED
+/* png_image_free is used by the write code but not exported */
+PNG_INTERNAL_FUNCTION(void, png_image_free, (png_imagep image), PNG_EMPTY);
+#endif /* !SIMPLIFIED_READ */
+
+#endif /* SIMPLIFIED READ/WRITE */
+
+/* These are initialization functions for hardware specific PNG filter
+ * optimizations; list these here then select the appropriate one at compile
+ * time using the macro PNG_FILTER_OPTIMIZATIONS. If the macro is not defined
+ * the generic code is used.
+ */
+#ifdef PNG_FILTER_OPTIMIZATIONS
+PNG_INTERNAL_FUNCTION(void, PNG_FILTER_OPTIMIZATIONS, (png_structp png_ptr,
+ unsigned int bpp), PNG_EMPTY);
+ /* Just declare the optimization that will be used */
+#else
+ /* List *all* the possible optimizations here - this branch is required if
+ * the builder of libpng passes the definition of PNG_FILTER_OPTIMIZATIONS in
+ * CFLAGS in place of CPPFLAGS *and* uses symbol prefixing.
+ */
+PNG_INTERNAL_FUNCTION(void, png_init_filter_functions_neon,
+ (png_structp png_ptr, unsigned int bpp), PNG_EMPTY);
+#endif
+
+/* Maintainer: Put new private prototypes here ^ */
+
+#include "pngdebug.h"
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* PNG_VERSION_INFO_ONLY */
+#endif /* PNGPRIV_H */
diff --git a/ml/dlib/dlib/external/libpng/pngread.c b/ml/dlib/dlib/external/libpng/pngread.c
new file mode 100644
index 000000000..8f96ca23e
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngread.c
@@ -0,0 +1,4000 @@
+
+/* pngread.c - read a PNG file
+ *
+ * Last changed in libpng 1.6.1 [March 28, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file contains routines that an application calls directly to
+ * read a PNG file or stream.
+ */
+
+#include "pngpriv.h"
+#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) && defined(PNG_STDIO_SUPPORTED)
+# include <errno.h>
+#endif
+
+#ifdef PNG_READ_SUPPORTED
+
+/* Create a PNG structure for reading, and allocate any memory needed. */
+PNG_FUNCTION(png_structp,PNGAPI
+png_create_read_struct,(png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn),PNG_ALLOCATED)
+{
+#ifndef PNG_USER_MEM_SUPPORTED
+ png_structp png_ptr = png_create_png_struct(user_png_ver, error_ptr,
+ error_fn, warn_fn, NULL, NULL, NULL);
+#else
+ return png_create_read_struct_2(user_png_ver, error_ptr, error_fn,
+ warn_fn, NULL, NULL, NULL);
+}
+
+/* Alternate create PNG structure for reading, and allocate any memory
+ * needed.
+ */
+PNG_FUNCTION(png_structp,PNGAPI
+png_create_read_struct_2,(png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr,
+ png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED)
+{
+ png_structp png_ptr = png_create_png_struct(user_png_ver, error_ptr,
+ error_fn, warn_fn, mem_ptr, malloc_fn, free_fn);
+#endif /* PNG_USER_MEM_SUPPORTED */
+
+ if (png_ptr != NULL)
+ {
+ png_ptr->mode = PNG_IS_READ_STRUCT;
+
+ /* Added in libpng-1.6.0; this can be used to detect a read structure if
+ * required (it will be zero in a write structure.)
+ */
+# ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+ png_ptr->IDAT_read_size = PNG_IDAT_READ_SIZE;
+# endif
+
+# ifdef PNG_BENIGN_READ_ERRORS_SUPPORTED
+ png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN;
+
+ /* In stable builds only warn if an application error can be completely
+ * handled.
+ */
+# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC
+ png_ptr->flags |= PNG_FLAG_APP_WARNINGS_WARN;
+# endif
+# endif
+
+ /* TODO: delay this, it can be done in png_init_io (if the app doesn't
+ * do it itself) avoiding setting the default function if it is not
+ * required.
+ */
+ png_set_read_fn(png_ptr, NULL, NULL);
+ }
+
+ return png_ptr;
+}
+
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the information before the actual image data. This has been
+ * changed in v0.90 to allow reading a file that already has the magic
+ * bytes read from the stream. You can tell libpng how many bytes have
+ * been read from the beginning of the stream (up to the maximum of 8)
+ * via png_set_sig_bytes(), and we will only check the remaining bytes
+ * here. The application can then have access to the signature bytes we
+ * read if it is determined that this isn't a valid PNG file.
+ */
+void PNGAPI
+png_read_info(png_structrp png_ptr, png_inforp info_ptr)
+{
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ int keep;
+#endif
+
+ png_debug(1, "in png_read_info");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ /* Read and check the PNG file signature. */
+ png_read_sig(png_ptr, info_ptr);
+
+ for (;;)
+ {
+ png_uint_32 length = png_read_chunk_header(png_ptr);
+ png_uint_32 chunk_name = png_ptr->chunk_name;
+
+ /* IDAT logic needs to happen here to simplify getting the two flags
+ * right.
+ */
+ if (chunk_name == png_IDAT)
+ {
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "Missing IHDR before IDAT");
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE &&
+ !(png_ptr->mode & PNG_HAVE_PLTE))
+ png_chunk_error(png_ptr, "Missing PLTE before IDAT");
+
+ else if (png_ptr->mode & PNG_AFTER_IDAT)
+ png_chunk_benign_error(png_ptr, "Too many IDATs found");
+
+ png_ptr->mode |= PNG_HAVE_IDAT;
+ }
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+ /* This should be a binary subdivision search or a hash for
+ * matching the chunk name rather than a linear search.
+ */
+ if (chunk_name == png_IHDR)
+ png_handle_IHDR(png_ptr, info_ptr, length);
+
+ else if (chunk_name == png_IEND)
+ png_handle_IEND(png_ptr, info_ptr, length);
+
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0)
+ {
+ png_handle_unknown(png_ptr, info_ptr, length, keep);
+
+ if (chunk_name == png_PLTE)
+ png_ptr->mode |= PNG_HAVE_PLTE;
+
+ else if (chunk_name == png_IDAT)
+ {
+ png_ptr->idat_size = 0; /* It has been consumed */
+ break;
+ }
+ }
+#endif
+ else if (chunk_name == png_PLTE)
+ png_handle_PLTE(png_ptr, info_ptr, length);
+
+ else if (chunk_name == png_IDAT)
+ {
+ png_ptr->idat_size = length;
+ break;
+ }
+
+#ifdef PNG_READ_bKGD_SUPPORTED
+ else if (chunk_name == png_bKGD)
+ png_handle_bKGD(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_cHRM_SUPPORTED
+ else if (chunk_name == png_cHRM)
+ png_handle_cHRM(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_gAMA_SUPPORTED
+ else if (chunk_name == png_gAMA)
+ png_handle_gAMA(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_hIST_SUPPORTED
+ else if (chunk_name == png_hIST)
+ png_handle_hIST(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_oFFs_SUPPORTED
+ else if (chunk_name == png_oFFs)
+ png_handle_oFFs(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_pCAL_SUPPORTED
+ else if (chunk_name == png_pCAL)
+ png_handle_pCAL(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sCAL_SUPPORTED
+ else if (chunk_name == png_sCAL)
+ png_handle_sCAL(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_pHYs_SUPPORTED
+ else if (chunk_name == png_pHYs)
+ png_handle_pHYs(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sBIT_SUPPORTED
+ else if (chunk_name == png_sBIT)
+ png_handle_sBIT(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sRGB_SUPPORTED
+ else if (chunk_name == png_sRGB)
+ png_handle_sRGB(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_iCCP_SUPPORTED
+ else if (chunk_name == png_iCCP)
+ png_handle_iCCP(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sPLT_SUPPORTED
+ else if (chunk_name == png_sPLT)
+ png_handle_sPLT(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tEXt_SUPPORTED
+ else if (chunk_name == png_tEXt)
+ png_handle_tEXt(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tIME_SUPPORTED
+ else if (chunk_name == png_tIME)
+ png_handle_tIME(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tRNS_SUPPORTED
+ else if (chunk_name == png_tRNS)
+ png_handle_tRNS(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_zTXt_SUPPORTED
+ else if (chunk_name == png_zTXt)
+ png_handle_zTXt(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_iTXt_SUPPORTED
+ else if (chunk_name == png_iTXt)
+ png_handle_iTXt(png_ptr, info_ptr, length);
+#endif
+
+ else
+ png_handle_unknown(png_ptr, info_ptr, length,
+ PNG_HANDLE_CHUNK_AS_DEFAULT);
+ }
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+/* Optional call to update the users info_ptr structure */
+void PNGAPI
+png_read_update_info(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_debug(1, "in png_read_update_info");
+
+ if (png_ptr != NULL)
+ {
+ if ((png_ptr->flags & PNG_FLAG_ROW_INIT) == 0)
+ {
+ png_read_start_row(png_ptr);
+
+# ifdef PNG_READ_TRANSFORMS_SUPPORTED
+ png_read_transform_info(png_ptr, info_ptr);
+# else
+ PNG_UNUSED(info_ptr)
+# endif
+ }
+
+ /* New in 1.6.0 this avoids the bug of doing the initializations twice */
+ else
+ png_app_error(png_ptr,
+ "png_read_update_info/png_start_read_image: duplicate call");
+ }
+}
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Initialize palette, background, etc, after transformations
+ * are set, but before any reading takes place. This allows
+ * the user to obtain a gamma-corrected palette, for example.
+ * If the user doesn't call this, we will do it ourselves.
+ */
+void PNGAPI
+png_start_read_image(png_structrp png_ptr)
+{
+ png_debug(1, "in png_start_read_image");
+
+ if (png_ptr != NULL)
+ {
+ if ((png_ptr->flags & PNG_FLAG_ROW_INIT) == 0)
+ png_read_start_row(png_ptr);
+
+ /* New in 1.6.0 this avoids the bug of doing the initializations twice */
+ else
+ png_app_error(png_ptr,
+ "png_start_read_image/png_read_update_info: duplicate call");
+ }
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+void PNGAPI
+png_read_row(png_structrp png_ptr, png_bytep row, png_bytep dsp_row)
+{
+ png_row_info row_info;
+
+ if (png_ptr == NULL)
+ return;
+
+ png_debug2(1, "in png_read_row (row %lu, pass %d)",
+ (unsigned long)png_ptr->row_number, png_ptr->pass);
+
+ /* png_read_start_row sets the information (in particular iwidth) for this
+ * interlace pass.
+ */
+ if (!(png_ptr->flags & PNG_FLAG_ROW_INIT))
+ png_read_start_row(png_ptr);
+
+ /* 1.5.6: row_info moved out of png_struct to a local here. */
+ row_info.width = png_ptr->iwidth; /* NOTE: width of current interlaced row */
+ row_info.color_type = png_ptr->color_type;
+ row_info.bit_depth = png_ptr->bit_depth;
+ row_info.channels = png_ptr->channels;
+ row_info.pixel_depth = png_ptr->pixel_depth;
+ row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width);
+
+ if (png_ptr->row_number == 0 && png_ptr->pass == 0)
+ {
+ /* Check for transforms that have been set but were defined out */
+#if defined(PNG_WRITE_INVERT_SUPPORTED) && !defined(PNG_READ_INVERT_SUPPORTED)
+ if (png_ptr->transformations & PNG_INVERT_MONO)
+ png_warning(png_ptr, "PNG_READ_INVERT_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_FILLER_SUPPORTED) && !defined(PNG_READ_FILLER_SUPPORTED)
+ if (png_ptr->transformations & PNG_FILLER)
+ png_warning(png_ptr, "PNG_READ_FILLER_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_PACKSWAP_SUPPORTED) && \
+ !defined(PNG_READ_PACKSWAP_SUPPORTED)
+ if (png_ptr->transformations & PNG_PACKSWAP)
+ png_warning(png_ptr, "PNG_READ_PACKSWAP_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_PACK_SUPPORTED) && !defined(PNG_READ_PACK_SUPPORTED)
+ if (png_ptr->transformations & PNG_PACK)
+ png_warning(png_ptr, "PNG_READ_PACK_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_SHIFT_SUPPORTED) && !defined(PNG_READ_SHIFT_SUPPORTED)
+ if (png_ptr->transformations & PNG_SHIFT)
+ png_warning(png_ptr, "PNG_READ_SHIFT_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_BGR_SUPPORTED) && !defined(PNG_READ_BGR_SUPPORTED)
+ if (png_ptr->transformations & PNG_BGR)
+ png_warning(png_ptr, "PNG_READ_BGR_SUPPORTED is not defined");
+#endif
+
+#if defined(PNG_WRITE_SWAP_SUPPORTED) && !defined(PNG_READ_SWAP_SUPPORTED)
+ if (png_ptr->transformations & PNG_SWAP_BYTES)
+ png_warning(png_ptr, "PNG_READ_SWAP_SUPPORTED is not defined");
+#endif
+ }
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* If interlaced and we do not need a new row, combine row and return.
+ * Notice that the pixels we have from previous rows have been transformed
+ * already; we can only combine like with like (transformed or
+ * untransformed) and, because of the libpng API for interlaced images, this
+ * means we must transform before de-interlacing.
+ */
+ if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE))
+ {
+ switch (png_ptr->pass)
+ {
+ case 0:
+ if (png_ptr->row_number & 0x07)
+ {
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 1:
+ if ((png_ptr->row_number & 0x07) || png_ptr->width < 5)
+ {
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 2:
+ if ((png_ptr->row_number & 0x07) != 4)
+ {
+ if (dsp_row != NULL && (png_ptr->row_number & 4))
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 3:
+ if ((png_ptr->row_number & 3) || png_ptr->width < 3)
+ {
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 4:
+ if ((png_ptr->row_number & 3) != 2)
+ {
+ if (dsp_row != NULL && (png_ptr->row_number & 2))
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 5:
+ if ((png_ptr->row_number & 1) || png_ptr->width < 2)
+ {
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ default:
+ case 6:
+ if (!(png_ptr->row_number & 1))
+ {
+ png_read_finish_row(png_ptr);
+ return;
+ }
+ break;
+ }
+ }
+#endif
+
+ if (!(png_ptr->mode & PNG_HAVE_IDAT))
+ png_error(png_ptr, "Invalid attempt to read row data");
+
+ /* Fill the row with IDAT data: */
+ png_read_IDAT_data(png_ptr, png_ptr->row_buf, row_info.rowbytes + 1);
+
+ if (png_ptr->row_buf[0] > PNG_FILTER_VALUE_NONE)
+ {
+ if (png_ptr->row_buf[0] < PNG_FILTER_VALUE_LAST)
+ png_read_filter_row(png_ptr, &row_info, png_ptr->row_buf + 1,
+ png_ptr->prev_row + 1, png_ptr->row_buf[0]);
+ else
+ png_error(png_ptr, "bad adaptive filter value");
+ }
+
+ /* libpng 1.5.6: the following line was copying png_ptr->rowbytes before
+ * 1.5.6, while the buffer really is this big in current versions of libpng
+ * it may not be in the future, so this was changed just to copy the
+ * interlaced count:
+ */
+ memcpy(png_ptr->prev_row, png_ptr->row_buf, row_info.rowbytes + 1);
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) &&
+ (png_ptr->filter_type == PNG_INTRAPIXEL_DIFFERENCING))
+ {
+ /* Intrapixel differencing */
+ png_do_read_intrapixel(&row_info, png_ptr->row_buf + 1);
+ }
+#endif
+
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+ if (png_ptr->transformations)
+ png_do_read_transformations(png_ptr, &row_info);
+#endif
+
+ /* The transformed pixel depth should match the depth now in row_info. */
+ if (png_ptr->transformed_pixel_depth == 0)
+ {
+ png_ptr->transformed_pixel_depth = row_info.pixel_depth;
+ if (row_info.pixel_depth > png_ptr->maximum_pixel_depth)
+ png_error(png_ptr, "sequential row overflow");
+ }
+
+ else if (png_ptr->transformed_pixel_depth != row_info.pixel_depth)
+ png_error(png_ptr, "internal sequential row size calculation error");
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* Blow up interlaced rows to full size */
+ if (png_ptr->interlaced &&
+ (png_ptr->transformations & PNG_INTERLACE))
+ {
+ if (png_ptr->pass < 6)
+ png_do_read_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass,
+ png_ptr->transformations);
+
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, 1/*display*/);
+
+ if (row != NULL)
+ png_combine_row(png_ptr, row, 0/*row*/);
+ }
+
+ else
+#endif
+ {
+ if (row != NULL)
+ png_combine_row(png_ptr, row, -1/*ignored*/);
+
+ if (dsp_row != NULL)
+ png_combine_row(png_ptr, dsp_row, -1/*ignored*/);
+ }
+ png_read_finish_row(png_ptr);
+
+ if (png_ptr->read_row_fn != NULL)
+ (*(png_ptr->read_row_fn))(png_ptr, png_ptr->row_number, png_ptr->pass);
+
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read one or more rows of image data. If the image is interlaced,
+ * and png_set_interlace_handling() has been called, the rows need to
+ * contain the contents of the rows from the previous pass. If the
+ * image has alpha or transparency, and png_handle_alpha()[*] has been
+ * called, the rows contents must be initialized to the contents of the
+ * screen.
+ *
+ * "row" holds the actual image, and pixels are placed in it
+ * as they arrive. If the image is displayed after each pass, it will
+ * appear to "sparkle" in. "display_row" can be used to display a
+ * "chunky" progressive image, with finer detail added as it becomes
+ * available. If you do not want this "chunky" display, you may pass
+ * NULL for display_row. If you do not want the sparkle display, and
+ * you have not called png_handle_alpha(), you may pass NULL for rows.
+ * If you have called png_handle_alpha(), and the image has either an
+ * alpha channel or a transparency chunk, you must provide a buffer for
+ * rows. In this case, you do not have to provide a display_row buffer
+ * also, but you may. If the image is not interlaced, or if you have
+ * not called png_set_interlace_handling(), the display_row buffer will
+ * be ignored, so pass NULL to it.
+ *
+ * [*] png_handle_alpha() does not exist yet, as of this version of libpng
+ */
+
+void PNGAPI
+png_read_rows(png_structrp png_ptr, png_bytepp row,
+ png_bytepp display_row, png_uint_32 num_rows)
+{
+ png_uint_32 i;
+ png_bytepp rp;
+ png_bytepp dp;
+
+ png_debug(1, "in png_read_rows");
+
+ if (png_ptr == NULL)
+ return;
+
+ rp = row;
+ dp = display_row;
+ if (rp != NULL && dp != NULL)
+ for (i = 0; i < num_rows; i++)
+ {
+ png_bytep rptr = *rp++;
+ png_bytep dptr = *dp++;
+
+ png_read_row(png_ptr, rptr, dptr);
+ }
+
+ else if (rp != NULL)
+ for (i = 0; i < num_rows; i++)
+ {
+ png_bytep rptr = *rp;
+ png_read_row(png_ptr, rptr, NULL);
+ rp++;
+ }
+
+ else if (dp != NULL)
+ for (i = 0; i < num_rows; i++)
+ {
+ png_bytep dptr = *dp;
+ png_read_row(png_ptr, NULL, dptr);
+ dp++;
+ }
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the entire image. If the image has an alpha channel or a tRNS
+ * chunk, and you have called png_handle_alpha()[*], you will need to
+ * initialize the image to the current image that PNG will be overlaying.
+ * We set the num_rows again here, in case it was incorrectly set in
+ * png_read_start_row() by a call to png_read_update_info() or
+ * png_start_read_image() if png_set_interlace_handling() wasn't called
+ * prior to either of these functions like it should have been. You can
+ * only call this function once. If you desire to have an image for
+ * each pass of a interlaced image, use png_read_rows() instead.
+ *
+ * [*] png_handle_alpha() does not exist yet, as of this version of libpng
+ */
+void PNGAPI
+png_read_image(png_structrp png_ptr, png_bytepp image)
+{
+ png_uint_32 i, image_height;
+ int pass, j;
+ png_bytepp rp;
+
+ png_debug(1, "in png_read_image");
+
+ if (png_ptr == NULL)
+ return;
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ if (!(png_ptr->flags & PNG_FLAG_ROW_INIT))
+ {
+ pass = png_set_interlace_handling(png_ptr);
+ /* And make sure transforms are initialized. */
+ png_start_read_image(png_ptr);
+ }
+ else
+ {
+ if (png_ptr->interlaced && !(png_ptr->transformations & PNG_INTERLACE))
+ {
+ /* Caller called png_start_read_image or png_read_update_info without
+ * first turning on the PNG_INTERLACE transform. We can fix this here,
+ * but the caller should do it!
+ */
+ png_warning(png_ptr, "Interlace handling should be turned on when "
+ "using png_read_image");
+ /* Make sure this is set correctly */
+ png_ptr->num_rows = png_ptr->height;
+ }
+
+ /* Obtain the pass number, which also turns on the PNG_INTERLACE flag in
+ * the above error case.
+ */
+ pass = png_set_interlace_handling(png_ptr);
+ }
+#else
+ if (png_ptr->interlaced)
+ png_error(png_ptr,
+ "Cannot read interlaced image -- interlace handler disabled");
+
+ pass = 1;
+#endif
+
+ image_height=png_ptr->height;
+
+ for (j = 0; j < pass; j++)
+ {
+ rp = image;
+ for (i = 0; i < image_height; i++)
+ {
+ png_read_row(png_ptr, *rp, NULL);
+ rp++;
+ }
+ }
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+/* Read the end of the PNG file. Will not read past the end of the
+ * file, will verify the end is accurate, and will read any comments
+ * or time information at the end of the file, if info is not NULL.
+ */
+void PNGAPI
+png_read_end(png_structrp png_ptr, png_inforp info_ptr)
+{
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ int keep;
+#endif
+
+ png_debug(1, "in png_read_end");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* If png_read_end is called in the middle of reading the rows there may
+ * still be pending IDAT data and an owned zstream. Deal with this here.
+ */
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ if (!png_chunk_unknown_handling(png_ptr, png_IDAT))
+#endif
+ png_read_finish_IDAT(png_ptr);
+
+#ifdef PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ /* Report invalid palette index; added at libng-1.5.10 */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE &&
+ png_ptr->num_palette_max > png_ptr->num_palette)
+ png_benign_error(png_ptr, "Read palette index exceeding num_palette");
+#endif
+
+ do
+ {
+ png_uint_32 length = png_read_chunk_header(png_ptr);
+ png_uint_32 chunk_name = png_ptr->chunk_name;
+
+ if (chunk_name == png_IHDR)
+ png_handle_IHDR(png_ptr, info_ptr, length);
+
+ else if (chunk_name == png_IEND)
+ png_handle_IEND(png_ptr, info_ptr, length);
+
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+ else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0)
+ {
+ if (chunk_name == png_IDAT)
+ {
+ if ((length > 0) || (png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT))
+ png_benign_error(png_ptr, "Too many IDATs found");
+ }
+ png_handle_unknown(png_ptr, info_ptr, length, keep);
+ if (chunk_name == png_PLTE)
+ png_ptr->mode |= PNG_HAVE_PLTE;
+ }
+#endif
+
+ else if (chunk_name == png_IDAT)
+ {
+ /* Zero length IDATs are legal after the last IDAT has been
+ * read, but not after other chunks have been read.
+ */
+ if ((length > 0) || (png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT))
+ png_benign_error(png_ptr, "Too many IDATs found");
+
+ png_crc_finish(png_ptr, length);
+ }
+ else if (chunk_name == png_PLTE)
+ png_handle_PLTE(png_ptr, info_ptr, length);
+
+#ifdef PNG_READ_bKGD_SUPPORTED
+ else if (chunk_name == png_bKGD)
+ png_handle_bKGD(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_cHRM_SUPPORTED
+ else if (chunk_name == png_cHRM)
+ png_handle_cHRM(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_gAMA_SUPPORTED
+ else if (chunk_name == png_gAMA)
+ png_handle_gAMA(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_hIST_SUPPORTED
+ else if (chunk_name == png_hIST)
+ png_handle_hIST(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_oFFs_SUPPORTED
+ else if (chunk_name == png_oFFs)
+ png_handle_oFFs(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_pCAL_SUPPORTED
+ else if (chunk_name == png_pCAL)
+ png_handle_pCAL(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sCAL_SUPPORTED
+ else if (chunk_name == png_sCAL)
+ png_handle_sCAL(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_pHYs_SUPPORTED
+ else if (chunk_name == png_pHYs)
+ png_handle_pHYs(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sBIT_SUPPORTED
+ else if (chunk_name == png_sBIT)
+ png_handle_sBIT(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sRGB_SUPPORTED
+ else if (chunk_name == png_sRGB)
+ png_handle_sRGB(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_iCCP_SUPPORTED
+ else if (chunk_name == png_iCCP)
+ png_handle_iCCP(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_sPLT_SUPPORTED
+ else if (chunk_name == png_sPLT)
+ png_handle_sPLT(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tEXt_SUPPORTED
+ else if (chunk_name == png_tEXt)
+ png_handle_tEXt(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tIME_SUPPORTED
+ else if (chunk_name == png_tIME)
+ png_handle_tIME(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_tRNS_SUPPORTED
+ else if (chunk_name == png_tRNS)
+ png_handle_tRNS(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_zTXt_SUPPORTED
+ else if (chunk_name == png_zTXt)
+ png_handle_zTXt(png_ptr, info_ptr, length);
+#endif
+
+#ifdef PNG_READ_iTXt_SUPPORTED
+ else if (chunk_name == png_iTXt)
+ png_handle_iTXt(png_ptr, info_ptr, length);
+#endif
+
+ else
+ png_handle_unknown(png_ptr, info_ptr, length,
+ PNG_HANDLE_CHUNK_AS_DEFAULT);
+ } while (!(png_ptr->mode & PNG_HAVE_IEND));
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+/* Free all memory used in the read struct */
+static void
+png_read_destroy(png_structrp png_ptr)
+{
+ png_debug(1, "in png_read_destroy");
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ png_destroy_gamma_table(png_ptr);
+#endif
+
+ png_free(png_ptr, png_ptr->big_row_buf);
+ png_free(png_ptr, png_ptr->big_prev_row);
+ png_free(png_ptr, png_ptr->read_buffer);
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+ png_free(png_ptr, png_ptr->palette_lookup);
+ png_free(png_ptr, png_ptr->quantize_index);
+#endif
+
+ if (png_ptr->free_me & PNG_FREE_PLTE)
+ png_zfree(png_ptr, png_ptr->palette);
+ png_ptr->free_me &= ~PNG_FREE_PLTE;
+
+#if defined(PNG_tRNS_SUPPORTED) || \
+ defined(PNG_READ_EXPAND_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED)
+ if (png_ptr->free_me & PNG_FREE_TRNS)
+ png_free(png_ptr, png_ptr->trans_alpha);
+ png_ptr->free_me &= ~PNG_FREE_TRNS;
+#endif
+
+ inflateEnd(&png_ptr->zstream);
+
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+ png_free(png_ptr, png_ptr->save_buffer);
+#endif
+
+#if defined(PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED) &&\
+ defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED)
+ png_free(png_ptr, png_ptr->unknown_chunk.data);
+#endif
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ png_free(png_ptr, png_ptr->chunk_list);
+#endif
+
+ /* NOTE: the 'setjmp' buffer may still be allocated and the memory and error
+ * callbacks are still set at this point. They are required to complete the
+ * destruction of the png_struct itself.
+ */
+}
+
+/* Free all memory used by the read */
+void PNGAPI
+png_destroy_read_struct(png_structpp png_ptr_ptr, png_infopp info_ptr_ptr,
+ png_infopp end_info_ptr_ptr)
+{
+ png_structrp png_ptr = NULL;
+
+ png_debug(1, "in png_destroy_read_struct");
+
+ if (png_ptr_ptr != NULL)
+ png_ptr = *png_ptr_ptr;
+
+ if (png_ptr == NULL)
+ return;
+
+ /* libpng 1.6.0: use the API to destroy info structs to ensure consistent
+ * behavior. Prior to 1.6.0 libpng did extra 'info' destruction in this API.
+ * The extra was, apparently, unnecessary yet this hides memory leak bugs.
+ */
+ png_destroy_info_struct(png_ptr, end_info_ptr_ptr);
+ png_destroy_info_struct(png_ptr, info_ptr_ptr);
+
+ *png_ptr_ptr = NULL;
+ png_read_destroy(png_ptr);
+ png_destroy_png_struct(png_ptr);
+}
+
+void PNGAPI
+png_set_read_status_fn(png_structrp png_ptr, png_read_status_ptr read_row_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->read_row_fn = read_row_fn;
+}
+
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+void PNGAPI
+png_read_png(png_structrp png_ptr, png_inforp info_ptr,
+ int transforms,
+ voidp params)
+{
+ int row;
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ /* png_read_info() gives us all of the information from the
+ * PNG file before the first IDAT (image data chunk).
+ */
+ png_read_info(png_ptr, info_ptr);
+ if (info_ptr->height > PNG_UINT_32_MAX/(sizeof (png_bytep)))
+ png_error(png_ptr, "Image is too high to process with png_read_png()");
+
+ /* -------------- image transformations start here ------------------- */
+
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+ /* Tell libpng to strip 16-bit/color files down to 8 bits per color.
+ */
+ if (transforms & PNG_TRANSFORM_SCALE_16)
+ {
+ /* Added at libpng-1.5.4. "strip_16" produces the same result that it
+ * did in earlier versions, while "scale_16" is now more accurate.
+ */
+ png_set_scale_16(png_ptr);
+ }
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+ /* If both SCALE and STRIP are required pngrtran will effectively cancel the
+ * latter by doing SCALE first. This is ok and allows apps not to check for
+ * which is supported to get the right answer.
+ */
+ if (transforms & PNG_TRANSFORM_STRIP_16)
+ png_set_strip_16(png_ptr);
+#endif
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+ /* Strip alpha bytes from the input data without combining with
+ * the background (not recommended).
+ */
+ if (transforms & PNG_TRANSFORM_STRIP_ALPHA)
+ png_set_strip_alpha(png_ptr);
+#endif
+
+#if defined(PNG_READ_PACK_SUPPORTED) && !defined(PNG_READ_EXPAND_SUPPORTED)
+ /* Extract multiple pixels with bit depths of 1, 2, or 4 from a single
+ * byte into separate bytes (useful for paletted and grayscale images).
+ */
+ if (transforms & PNG_TRANSFORM_PACKING)
+ png_set_packing(png_ptr);
+#endif
+
+#ifdef PNG_READ_PACKSWAP_SUPPORTED
+ /* Change the order of packed pixels to least significant bit first
+ * (not useful if you are using png_set_packing).
+ */
+ if (transforms & PNG_TRANSFORM_PACKSWAP)
+ png_set_packswap(png_ptr);
+#endif
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ /* Expand paletted colors into true RGB triplets
+ * Expand grayscale images to full 8 bits from 1, 2, or 4 bits/pixel
+ * Expand paletted or RGB images with transparency to full alpha
+ * channels so the data will be available as RGBA quartets.
+ */
+ if (transforms & PNG_TRANSFORM_EXPAND)
+ if ((png_ptr->bit_depth < 8) ||
+ (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) ||
+ (png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS)))
+ png_set_expand(png_ptr);
+#endif
+
+ /* We don't handle background color or gamma transformation or quantizing.
+ */
+
+#ifdef PNG_READ_INVERT_SUPPORTED
+ /* Invert monochrome files to have 0 as white and 1 as black
+ */
+ if (transforms & PNG_TRANSFORM_INVERT_MONO)
+ png_set_invert_mono(png_ptr);
+#endif
+
+#ifdef PNG_READ_SHIFT_SUPPORTED
+ /* If you want to shift the pixel values from the range [0,255] or
+ * [0,65535] to the original [0,7] or [0,31], or whatever range the
+ * colors were originally in:
+ */
+ if ((transforms & PNG_TRANSFORM_SHIFT)
+ && png_get_valid(png_ptr, info_ptr, PNG_INFO_sBIT))
+ {
+ png_color_8p sig_bit;
+
+ png_get_sBIT(png_ptr, info_ptr, &sig_bit);
+ png_set_shift(png_ptr, sig_bit);
+ }
+#endif
+
+#ifdef PNG_READ_BGR_SUPPORTED
+ /* Flip the RGB pixels to BGR (or RGBA to BGRA) */
+ if (transforms & PNG_TRANSFORM_BGR)
+ png_set_bgr(png_ptr);
+#endif
+
+#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED
+ /* Swap the RGBA or GA data to ARGB or AG (or BGRA to ABGR) */
+ if (transforms & PNG_TRANSFORM_SWAP_ALPHA)
+ png_set_swap_alpha(png_ptr);
+#endif
+
+#ifdef PNG_READ_SWAP_SUPPORTED
+ /* Swap bytes of 16-bit files to least significant byte first */
+ if (transforms & PNG_TRANSFORM_SWAP_ENDIAN)
+ png_set_swap(png_ptr);
+#endif
+
+/* Added at libpng-1.2.41 */
+#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED
+ /* Invert the alpha channel from opacity to transparency */
+ if (transforms & PNG_TRANSFORM_INVERT_ALPHA)
+ png_set_invert_alpha(png_ptr);
+#endif
+
+/* Added at libpng-1.2.41 */
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+ /* Expand grayscale image to RGB */
+ if (transforms & PNG_TRANSFORM_GRAY_TO_RGB)
+ png_set_gray_to_rgb(png_ptr);
+#endif
+
+/* Added at libpng-1.5.4 */
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+ if (transforms & PNG_TRANSFORM_EXPAND_16)
+ png_set_expand_16(png_ptr);
+#endif
+
+ /* We don't handle adding filler bytes */
+
+ /* We use png_read_image and rely on that for interlace handling, but we also
+ * call png_read_update_info therefore must turn on interlace handling now:
+ */
+ (void)png_set_interlace_handling(png_ptr);
+
+ /* Optional call to gamma correct and add the background to the palette
+ * and update info structure. REQUIRED if you are expecting libpng to
+ * update the palette for you (i.e., you selected such a transform above).
+ */
+ png_read_update_info(png_ptr, info_ptr);
+
+ /* -------------- image transformations end here ------------------- */
+
+ png_free_data(png_ptr, info_ptr, PNG_FREE_ROWS, 0);
+ if (info_ptr->row_pointers == NULL)
+ {
+ png_uint_32 iptr;
+
+ info_ptr->row_pointers = (png_bytepp)png_malloc(png_ptr,
+ info_ptr->height * (sizeof (png_bytep)));
+ for (iptr=0; iptr<info_ptr->height; iptr++)
+ info_ptr->row_pointers[iptr] = NULL;
+
+ info_ptr->free_me |= PNG_FREE_ROWS;
+
+ for (row = 0; row < (int)info_ptr->height; row++)
+ info_ptr->row_pointers[row] = (png_bytep)png_malloc(png_ptr,
+ png_get_rowbytes(png_ptr, info_ptr));
+ }
+
+ png_read_image(png_ptr, info_ptr->row_pointers);
+ info_ptr->valid |= PNG_INFO_IDAT;
+
+ /* Read rest of file, and get additional chunks in info_ptr - REQUIRED */
+ png_read_end(png_ptr, info_ptr);
+
+ PNG_UNUSED(transforms) /* Quiet compiler warnings */
+ PNG_UNUSED(params)
+
+}
+#endif /* PNG_INFO_IMAGE_SUPPORTED */
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+#ifdef PNG_SIMPLIFIED_READ_SUPPORTED
+/* SIMPLIFIED READ
+ *
+ * This code currently relies on the sequential reader, though it could easily
+ * be made to work with the progressive one.
+ */
+/* Arguments to png_image_finish_read: */
+
+/* Encoding of PNG data (used by the color-map code) */
+/* TODO: change these, dang, ANSI-C reserves the 'E' namespace. */
+# define E_NOTSET 0 /* File encoding not yet known */
+# define E_sRGB 1 /* 8-bit encoded to sRGB gamma */
+# define E_LINEAR 2 /* 16-bit linear: not encoded, NOT pre-multiplied! */
+# define E_FILE 3 /* 8-bit encoded to file gamma, not sRGB or linear */
+# define E_LINEAR8 4 /* 8-bit linear: only from a file value */
+
+/* Color-map processing: after libpng has run on the PNG image further
+ * processing may be needed to conver the data to color-map indicies.
+ */
+#define PNG_CMAP_NONE 0
+#define PNG_CMAP_GA 1 /* Process GA data to a color-map with alpha */
+#define PNG_CMAP_TRANS 2 /* Process GA data to a background index */
+#define PNG_CMAP_RGB 3 /* Process RGB data */
+#define PNG_CMAP_RGB_ALPHA 4 /* Process RGBA data */
+
+/* The following document where the background is for each processing case. */
+#define PNG_CMAP_NONE_BACKGROUND 256
+#define PNG_CMAP_GA_BACKGROUND 231
+#define PNG_CMAP_TRANS_BACKGROUND 254
+#define PNG_CMAP_RGB_BACKGROUND 256
+#define PNG_CMAP_RGB_ALPHA_BACKGROUND 216
+
+typedef struct
+{
+ /* Arguments: */
+ png_imagep image;
+ png_voidp buffer;
+ png_int_32 row_stride;
+ png_voidp colormap;
+ png_const_colorp background;
+ /* Local variables: */
+ png_voidp local_row;
+ png_voidp first_row;
+ ptrdiff_t row_bytes; /* step between rows */
+ int file_encoding; /* E_ values above */
+ png_fixed_point gamma_to_linear; /* For E_FILE, reciprocal of gamma */
+ int colormap_processing; /* PNG_CMAP_ values above */
+} png_image_read_control;
+
+/* Do all the *safe* initialization - 'safe' means that png_error won't be
+ * called, so setting up the jmp_buf is not required. This means that anything
+ * called from here must *not* call png_malloc - it has to call png_malloc_warn
+ * instead so that control is returned safely back to this routine.
+ */
+static int
+png_image_read_init(png_imagep image)
+{
+ if (image->opaque == NULL)
+ {
+ png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, image,
+ png_safe_error, png_safe_warning);
+
+ /* And set the rest of the structure to NULL to ensure that the various
+ * fields are consistent.
+ */
+ memset(image, 0, (sizeof *image));
+ image->version = PNG_IMAGE_VERSION;
+
+ if (png_ptr != NULL)
+ {
+ png_infop info_ptr = png_create_info_struct(png_ptr);
+
+ if (info_ptr != NULL)
+ {
+ png_controlp control = png_voidcast(png_controlp,
+ png_malloc_warn(png_ptr, (sizeof *control)));
+
+ if (control != NULL)
+ {
+ memset(control, 0, (sizeof *control));
+
+ control->png_ptr = png_ptr;
+ control->info_ptr = info_ptr;
+ control->for_write = 0;
+
+ image->opaque = control;
+ return 1;
+ }
+
+ /* Error clean up */
+ png_destroy_info_struct(png_ptr, &info_ptr);
+ }
+
+ png_destroy_read_struct(&png_ptr, NULL, NULL);
+ }
+
+ return png_image_error(image, "png_image_read: out of memory");
+ }
+
+ return png_image_error(image, "png_image_read: opaque pointer not NULL");
+}
+
+/* Utility to find the base format of a PNG file from a png_struct. */
+static png_uint_32
+png_image_format(png_structrp png_ptr)
+{
+ png_uint_32 format = 0;
+
+ if (png_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ format |= PNG_FORMAT_FLAG_COLOR;
+
+ if (png_ptr->color_type & PNG_COLOR_MASK_ALPHA)
+ format |= PNG_FORMAT_FLAG_ALPHA;
+
+ /* Use png_ptr here, not info_ptr, because by examination png_handle_tRNS
+ * sets the png_struct fields; that's all we are interested in here. The
+ * precise interaction with an app call to png_set_tRNS and PNG file reading
+ * is unclear.
+ */
+ else if (png_ptr->num_trans > 0)
+ format |= PNG_FORMAT_FLAG_ALPHA;
+
+ if (png_ptr->bit_depth == 16)
+ format |= PNG_FORMAT_FLAG_LINEAR;
+
+ if (png_ptr->color_type & PNG_COLOR_MASK_PALETTE)
+ format |= PNG_FORMAT_FLAG_COLORMAP;
+
+ return format;
+}
+
+/* Is the given gamma significantly different from sRGB? The test is the same
+ * one used in pngrtran.c when deciding whether to do gamma correction. The
+ * arithmetic optimizes the division by using the fact that the inverse of the
+ * file sRGB gamma is 2.2
+ */
+static int
+png_gamma_not_sRGB(png_fixed_point g)
+{
+ if (g < PNG_FP_1)
+ {
+ /* An uninitialized gamma is assumed to be sRGB for the simplified API. */
+ if (g == 0)
+ return 0;
+
+ return png_gamma_significant((g * 11 + 2)/5 /* i.e. *2.2, rounded */);
+ }
+
+ return 1;
+}
+
+/* Do the main body of a 'png_image_begin_read' function; read the PNG file
+ * header and fill in all the information. This is executed in a safe context,
+ * unlike the init routine above.
+ */
+static int
+png_image_read_header(png_voidp argument)
+{
+ png_imagep image = png_voidcast(png_imagep, argument);
+ png_structrp png_ptr = image->opaque->png_ptr;
+ png_inforp info_ptr = image->opaque->info_ptr;
+
+ png_set_benign_errors(png_ptr, 1/*warn*/);
+ png_read_info(png_ptr, info_ptr);
+
+ /* Do this the fast way; just read directly out of png_struct. */
+ image->width = png_ptr->width;
+ image->height = png_ptr->height;
+
+ {
+ png_uint_32 format = png_image_format(png_ptr);
+
+ image->format = format;
+
+#ifdef PNG_COLORSPACE_SUPPORTED
+ /* Does the colorspace match sRGB? If there is no color endpoint
+ * (colorant) information assume yes, otherwise require the
+ * 'ENDPOINTS_MATCHE_sRGB' colorspace flag to have been set. If the
+ * colorspace has been determined to be invalid ignore it.
+ */
+ if ((format & PNG_FORMAT_FLAG_COLOR) != 0 && ((png_ptr->colorspace.flags
+ & (PNG_COLORSPACE_HAVE_ENDPOINTS|PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB|
+ PNG_COLORSPACE_INVALID)) == PNG_COLORSPACE_HAVE_ENDPOINTS))
+ image->flags |= PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB;
+#endif
+ }
+
+ /* We need the maximum number of entries regardless of the format the
+ * application sets here.
+ */
+ {
+ png_uint_32 cmap_entries;
+
+ switch (png_ptr->color_type)
+ {
+ case PNG_COLOR_TYPE_GRAY:
+ cmap_entries = 1U << png_ptr->bit_depth;
+ break;
+
+ case PNG_COLOR_TYPE_PALETTE:
+ cmap_entries = png_ptr->num_palette;
+ break;
+
+ default:
+ cmap_entries = 256;
+ break;
+ }
+
+ if (cmap_entries > 256)
+ cmap_entries = 256;
+
+ image->colormap_entries = cmap_entries;
+ }
+
+ return 1;
+}
+
+#ifdef PNG_STDIO_SUPPORTED
+int PNGAPI
+png_image_begin_read_from_stdio(png_imagep image, FILE* file)
+{
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ if (file != NULL)
+ {
+ if (png_image_read_init(image))
+ {
+ /* This is slightly evil, but png_init_io doesn't do anything other
+ * than this and we haven't changed the standard IO functions so
+ * this saves a 'safe' function.
+ */
+ image->opaque->png_ptr->io_ptr = file;
+ return png_safe_execute(image, png_image_read_header, image);
+ }
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_begin_read_from_stdio: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_begin_read_from_stdio: incorrect PNG_IMAGE_VERSION");
+
+ return 0;
+}
+
+int PNGAPI
+png_image_begin_read_from_file(png_imagep image, const char *file_name)
+{
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ if (file_name != NULL)
+ {
+ FILE *fp = fopen(file_name, "rb");
+
+ if (fp != NULL)
+ {
+ if (png_image_read_init(image))
+ {
+ image->opaque->png_ptr->io_ptr = fp;
+ image->opaque->owned_file = 1;
+ return png_safe_execute(image, png_image_read_header, image);
+ }
+
+ /* Clean up: just the opened file. */
+ (void)fclose(fp);
+ }
+
+ else
+ return png_image_error(image, strerror(errno));
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_begin_read_from_file: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_begin_read_from_file: incorrect PNG_IMAGE_VERSION");
+
+ return 0;
+}
+#endif /* PNG_STDIO_SUPPORTED */
+
+static void PNGCBAPI
+png_image_memory_read(png_structp png_ptr, png_bytep out, png_size_t need)
+{
+ if (png_ptr != NULL)
+ {
+ png_imagep image = png_voidcast(png_imagep, png_ptr->io_ptr);
+ if (image != NULL)
+ {
+ png_controlp cp = image->opaque;
+ if (cp != NULL)
+ {
+ png_const_bytep memory = cp->memory;
+ png_size_t size = cp->size;
+
+ if (memory != NULL && size >= need)
+ {
+ memcpy(out, memory, need);
+ cp->memory = memory + need;
+ cp->size = size - need;
+ return;
+ }
+
+ png_error(png_ptr, "read beyond end of data");
+ }
+ }
+
+ png_error(png_ptr, "invalid memory read");
+ }
+}
+
+int PNGAPI png_image_begin_read_from_memory(png_imagep image,
+ png_const_voidp memory, png_size_t size)
+{
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ if (memory != NULL && size > 0)
+ {
+ if (png_image_read_init(image))
+ {
+ /* Now set the IO functions to read from the memory buffer and
+ * store it into io_ptr. Again do this in-place to avoid calling a
+ * libpng function that requires error handling.
+ */
+ image->opaque->memory = png_voidcast(png_const_bytep, memory);
+ image->opaque->size = size;
+ image->opaque->png_ptr->io_ptr = image;
+ image->opaque->png_ptr->read_data_fn = png_image_memory_read;
+
+ return png_safe_execute(image, png_image_read_header, image);
+ }
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_begin_read_from_memory: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_begin_read_from_memory: incorrect PNG_IMAGE_VERSION");
+
+ return 0;
+}
+
+/* Utility function to skip chunks that are not used by the simplified image
+ * read functions and an appropriate macro to call it.
+ */
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+static void
+png_image_skip_unused_chunks(png_structrp png_ptr)
+{
+ /* Prepare the reader to ignore all recognized chunks whose data will not
+ * be used, i.e., all chunks recognized by libpng except for those
+ * involved in basic image reading:
+ *
+ * IHDR, PLTE, IDAT, IEND
+ *
+ * Or image data handling:
+ *
+ * tRNS, bKGD, gAMA, cHRM, sRGB, iCCP and sBIT.
+ *
+ * This provides a small performance improvement and eliminates any
+ * potential vulnerability to security problems in the unused chunks.
+ */
+ {
+ static PNG_CONST png_byte chunks_to_process[] = {
+ 98, 75, 71, 68, '\0', /* bKGD */
+ 99, 72, 82, 77, '\0', /* cHRM */
+ 103, 65, 77, 65, '\0', /* gAMA */
+ 105, 67, 67, 80, '\0', /* iCCP */
+ 115, 66, 73, 84, '\0', /* sBIT */
+ 115, 82, 71, 66, '\0', /* sRGB */
+ };
+
+ /* Ignore unknown chunks and all other chunks except for the
+ * IHDR, PLTE, tRNS, IDAT, and IEND chunks.
+ */
+ png_set_keep_unknown_chunks(png_ptr, PNG_HANDLE_CHUNK_NEVER,
+ NULL, -1);
+
+ /* But do not ignore image data handling chunks */
+ png_set_keep_unknown_chunks(png_ptr, PNG_HANDLE_CHUNK_AS_DEFAULT,
+ chunks_to_process, (sizeof chunks_to_process)/5);
+ }
+}
+
+# define PNG_SKIP_CHUNKS(p) png_image_skip_unused_chunks(p)
+#else
+# define PNG_SKIP_CHUNKS(p) ((void)0)
+#endif /* PNG_HANDLE_AS_UNKNOWN_SUPPORTED */
+
+/* The following macro gives the exact rounded answer for all values in the
+ * range 0..255 (it actually divides by 51.2, but the rounding still generates
+ * the correct numbers 0..5
+ */
+#define PNG_DIV51(v8) (((v8) * 5 + 130) >> 8)
+
+/* Utility functions to make particular color-maps */
+static void
+set_file_encoding(png_image_read_control *display)
+{
+ png_fixed_point g = display->image->opaque->png_ptr->colorspace.gamma;
+ if (png_gamma_significant(g))
+ {
+ if (png_gamma_not_sRGB(g))
+ {
+ display->file_encoding = E_FILE;
+ display->gamma_to_linear = png_reciprocal(g);
+ }
+
+ else
+ display->file_encoding = E_sRGB;
+ }
+
+ else
+ display->file_encoding = E_LINEAR8;
+}
+
+static unsigned int
+decode_gamma(png_image_read_control *display, png_uint_32 value, int encoding)
+{
+ if (encoding == E_FILE) /* double check */
+ encoding = display->file_encoding;
+
+ if (encoding == E_NOTSET) /* must be the file encoding */
+ {
+ set_file_encoding(display);
+ encoding = display->file_encoding;
+ }
+
+ switch (encoding)
+ {
+ case E_FILE:
+ value = png_gamma_16bit_correct(value*257, display->gamma_to_linear);
+ break;
+
+ case E_sRGB:
+ value = png_sRGB_table[value];
+ break;
+
+ case E_LINEAR:
+ break;
+
+ case E_LINEAR8:
+ value *= 257;
+ break;
+
+ default:
+ png_error(display->image->opaque->png_ptr,
+ "unexpected encoding (internal error)");
+ break;
+ }
+
+ return value;
+}
+
+static png_uint_32
+png_colormap_compose(png_image_read_control *display,
+ png_uint_32 foreground, int foreground_encoding, png_uint_32 alpha,
+ png_uint_32 background, int encoding)
+{
+ /* The file value is composed on the background, the background has the given
+ * encoding and so does the result, the file is encoded with E_FILE and the
+ * file and alpha are 8-bit values. The (output) encoding will always be
+ * E_LINEAR or E_sRGB.
+ */
+ png_uint_32 f = decode_gamma(display, foreground, foreground_encoding);
+ png_uint_32 b = decode_gamma(display, background, encoding);
+
+ /* The alpha is always an 8-bit value (it comes from the palette), the value
+ * scaled by 255 is what PNG_sRGB_FROM_LINEAR requires.
+ */
+ f = f * alpha + b * (255-alpha);
+
+ if (encoding == E_LINEAR)
+ {
+ /* Scale to 65535; divide by 255, approximately (in fact this is extremely
+ * accurate, it divides by 255.00000005937181414556, with no overflow.)
+ */
+ f *= 257; /* Now scaled by 65535 */
+ f += f >> 16;
+ f = (f+32768) >> 16;
+ }
+
+ else /* E_sRGB */
+ f = PNG_sRGB_FROM_LINEAR(f);
+
+ return f;
+}
+
+/* NOTE: E_LINEAR values to this routine must be 16-bit, but E_FILE values must
+ * be 8-bit.
+ */
+static void
+png_create_colormap_entry(png_image_read_control *display,
+ png_uint_32 ip, png_uint_32 red, png_uint_32 green, png_uint_32 blue,
+ png_uint_32 alpha, int encoding)
+{
+ png_imagep image = display->image;
+ const int output_encoding = (image->format & PNG_FORMAT_FLAG_LINEAR) ?
+ E_LINEAR : E_sRGB;
+ const int convert_to_Y = (image->format & PNG_FORMAT_FLAG_COLOR) == 0 &&
+ (red != green || green != blue);
+
+ if (ip > 255)
+ png_error(image->opaque->png_ptr, "color-map index out of range");
+
+ /* Update the cache with whether the file gamma is significantly different
+ * from sRGB.
+ */
+ if (encoding == E_FILE)
+ {
+ if (display->file_encoding == E_NOTSET)
+ set_file_encoding(display);
+
+ /* Note that the cached value may be E_FILE too, but if it is then the
+ * gamma_to_linear member has been set.
+ */
+ encoding = display->file_encoding;
+ }
+
+ if (encoding == E_FILE)
+ {
+ png_fixed_point g = display->gamma_to_linear;
+
+ red = png_gamma_16bit_correct(red*257, g);
+ green = png_gamma_16bit_correct(green*257, g);
+ blue = png_gamma_16bit_correct(blue*257, g);
+
+ if (convert_to_Y || output_encoding == E_LINEAR)
+ {
+ alpha *= 257;
+ encoding = E_LINEAR;
+ }
+
+ else
+ {
+ red = PNG_sRGB_FROM_LINEAR(red * 255);
+ green = PNG_sRGB_FROM_LINEAR(green * 255);
+ blue = PNG_sRGB_FROM_LINEAR(blue * 255);
+ encoding = E_sRGB;
+ }
+ }
+
+ else if (encoding == E_LINEAR8)
+ {
+ /* This encoding occurs quite frequently in test cases because PngSuite
+ * includes a gAMA 1.0 chunk with most images.
+ */
+ red *= 257;
+ green *= 257;
+ blue *= 257;
+ alpha *= 257;
+ encoding = E_LINEAR;
+ }
+
+ else if (encoding == E_sRGB && (convert_to_Y || output_encoding == E_LINEAR))
+ {
+ /* The values are 8-bit sRGB values, but must be converted to 16-bit
+ * linear.
+ */
+ red = png_sRGB_table[red];
+ green = png_sRGB_table[green];
+ blue = png_sRGB_table[blue];
+ alpha *= 257;
+ encoding = E_LINEAR;
+ }
+
+ /* This is set if the color isn't gray but the output is. */
+ if (encoding == E_LINEAR)
+ {
+ if (convert_to_Y)
+ {
+ /* NOTE: these values are copied from png_do_rgb_to_gray */
+ png_uint_32 y = (png_uint_32)6968 * red + (png_uint_32)23434 * green +
+ (png_uint_32)2366 * blue;
+
+ if (output_encoding == E_LINEAR)
+ y = (y + 16384) >> 15;
+
+ else
+ {
+ /* y is scaled by 32768, we need it scaled by 255: */
+ y = (y + 128) >> 8;
+ y *= 255;
+ y = PNG_sRGB_FROM_LINEAR((y + 64) >> 7);
+ encoding = E_sRGB;
+ }
+
+ blue = red = green = y;
+ }
+
+ else if (output_encoding == E_sRGB)
+ {
+ red = PNG_sRGB_FROM_LINEAR(red * 255);
+ green = PNG_sRGB_FROM_LINEAR(green * 255);
+ blue = PNG_sRGB_FROM_LINEAR(blue * 255);
+ alpha = PNG_DIV257(alpha);
+ encoding = E_sRGB;
+ }
+ }
+
+ if (encoding != output_encoding)
+ png_error(image->opaque->png_ptr, "bad encoding (internal error)");
+
+ /* Store the value. */
+ {
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ const int afirst = (image->format & PNG_FORMAT_FLAG_AFIRST) != 0 &&
+ (image->format & PNG_FORMAT_FLAG_ALPHA) != 0;
+# else
+# define afirst 0
+# endif
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ const int bgr = (image->format & PNG_FORMAT_FLAG_BGR) ? 2 : 0;
+# else
+# define bgr 0
+# endif
+
+ if (output_encoding == E_LINEAR)
+ {
+ png_uint_16p entry = png_voidcast(png_uint_16p, display->colormap);
+
+ entry += ip * PNG_IMAGE_SAMPLE_CHANNELS(image->format);
+
+ /* The linear 16-bit values must be pre-multiplied by the alpha channel
+ * value, if less than 65535 (this is, effectively, composite on black
+ * if the alpha channel is removed.)
+ */
+ switch (PNG_IMAGE_SAMPLE_CHANNELS(image->format))
+ {
+ case 4:
+ entry[afirst ? 0 : 3] = (png_uint_16)alpha;
+ /* FALL THROUGH */
+
+ case 3:
+ if (alpha < 65535)
+ {
+ if (alpha > 0)
+ {
+ blue = (blue * alpha + 32767U)/65535U;
+ green = (green * alpha + 32767U)/65535U;
+ red = (red * alpha + 32767U)/65535U;
+ }
+
+ else
+ red = green = blue = 0;
+ }
+ entry[afirst + (2 ^ bgr)] = (png_uint_16)blue;
+ entry[afirst + 1] = (png_uint_16)green;
+ entry[afirst + bgr] = (png_uint_16)red;
+ break;
+
+ case 2:
+ entry[1 ^ afirst] = (png_uint_16)alpha;
+ /* FALL THROUGH */
+
+ case 1:
+ if (alpha < 65535)
+ {
+ if (alpha > 0)
+ green = (green * alpha + 32767U)/65535U;
+
+ else
+ green = 0;
+ }
+ entry[afirst] = (png_uint_16)green;
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ else /* output encoding is E_sRGB */
+ {
+ png_bytep entry = png_voidcast(png_bytep, display->colormap);
+
+ entry += ip * PNG_IMAGE_SAMPLE_CHANNELS(image->format);
+
+ switch (PNG_IMAGE_SAMPLE_CHANNELS(image->format))
+ {
+ case 4:
+ entry[afirst ? 0 : 3] = (png_byte)alpha;
+ case 3:
+ entry[afirst + (2 ^ bgr)] = (png_byte)blue;
+ entry[afirst + 1] = (png_byte)green;
+ entry[afirst + bgr] = (png_byte)red;
+ break;
+
+ case 2:
+ entry[1 ^ afirst] = (png_byte)alpha;
+ case 1:
+ entry[afirst] = (png_byte)green;
+ break;
+
+ default:
+ break;
+ }
+ }
+
+# ifdef afirst
+# undef afirst
+# endif
+# ifdef bgr
+# undef bgr
+# endif
+ }
+}
+
+static int
+make_gray_file_colormap(png_image_read_control *display)
+{
+ unsigned int i;
+
+ for (i=0; i<256; ++i)
+ png_create_colormap_entry(display, i, i, i, i, 255, E_FILE);
+
+ return i;
+}
+
+static int
+make_gray_colormap(png_image_read_control *display)
+{
+ unsigned int i;
+
+ for (i=0; i<256; ++i)
+ png_create_colormap_entry(display, i, i, i, i, 255, E_sRGB);
+
+ return i;
+}
+#define PNG_GRAY_COLORMAP_ENTRIES 256
+
+static int
+make_ga_colormap(png_image_read_control *display)
+{
+ unsigned int i, a;
+
+ /* Alpha is retained, the output will be a color-map with entries
+ * selected by six levels of alpha. One transparent entry, 6 gray
+ * levels for all the intermediate alpha values, leaving 230 entries
+ * for the opaque grays. The color-map entries are the six values
+ * [0..5]*51, the GA processing uses PNG_DIV51(value) to find the
+ * relevant entry.
+ *
+ * if (alpha > 229) // opaque
+ * {
+ * // The 231 entries are selected to make the math below work:
+ * base = 0;
+ * entry = (231 * gray + 128) >> 8;
+ * }
+ * else if (alpha < 26) // transparent
+ * {
+ * base = 231;
+ * entry = 0;
+ * }
+ * else // partially opaque
+ * {
+ * base = 226 + 6 * PNG_DIV51(alpha);
+ * entry = PNG_DIV51(gray);
+ * }
+ */
+ i = 0;
+ while (i < 231)
+ {
+ unsigned int gray = (i * 256 + 115) / 231;
+ png_create_colormap_entry(display, i++, gray, gray, gray, 255, E_sRGB);
+ }
+
+ /* 255 is used here for the component values for consistency with the code
+ * that undoes premultiplication in pngwrite.c.
+ */
+ png_create_colormap_entry(display, i++, 255, 255, 255, 0, E_sRGB);
+
+ for (a=1; a<5; ++a)
+ {
+ unsigned int g;
+
+ for (g=0; g<6; ++g)
+ png_create_colormap_entry(display, i++, g*51, g*51, g*51, a*51,
+ E_sRGB);
+ }
+
+ return i;
+}
+
+#define PNG_GA_COLORMAP_ENTRIES 256
+
+static int
+make_rgb_colormap(png_image_read_control *display)
+{
+ unsigned int i, r;
+
+ /* Build a 6x6x6 opaque RGB cube */
+ for (i=r=0; r<6; ++r)
+ {
+ unsigned int g;
+
+ for (g=0; g<6; ++g)
+ {
+ unsigned int b;
+
+ for (b=0; b<6; ++b)
+ png_create_colormap_entry(display, i++, r*51, g*51, b*51, 255,
+ E_sRGB);
+ }
+ }
+
+ return i;
+}
+
+#define PNG_RGB_COLORMAP_ENTRIES 216
+
+/* Return a palette index to the above palette given three 8-bit sRGB values. */
+#define PNG_RGB_INDEX(r,g,b) \
+ ((png_byte)(6 * (6 * PNG_DIV51(r) + PNG_DIV51(g)) + PNG_DIV51(b)))
+
+static int
+png_image_read_colormap(png_voidp argument)
+{
+ png_image_read_control *display =
+ png_voidcast(png_image_read_control*, argument);
+ const png_imagep image = display->image;
+
+ const png_structrp png_ptr = image->opaque->png_ptr;
+ const png_uint_32 output_format = image->format;
+ const int output_encoding = (output_format & PNG_FORMAT_FLAG_LINEAR) ?
+ E_LINEAR : E_sRGB;
+
+ unsigned int cmap_entries;
+ unsigned int output_processing; /* Output processing option */
+ unsigned int data_encoding = E_NOTSET; /* Encoding libpng must produce */
+
+ /* Background information; the background color and the index of this color
+ * in the color-map if it exists (else 256).
+ */
+ unsigned int background_index = 256;
+ png_uint_32 back_r, back_g, back_b;
+
+ /* Flags to accumulate things that need to be done to the input. */
+ int expand_tRNS = 0;
+
+ /* Exclude the NYI feature of compositing onto a color-mapped buffer; it is
+ * very difficult to do, the results look awful, and it is difficult to see
+ * what possible use it is because the application can't control the
+ * color-map.
+ */
+ if (((png_ptr->color_type & PNG_COLOR_MASK_ALPHA) != 0 ||
+ png_ptr->num_trans > 0) /* alpha in input */ &&
+ ((output_format & PNG_FORMAT_FLAG_ALPHA) == 0) /* no alpha in output */)
+ {
+ if (output_encoding == E_LINEAR) /* compose on black */
+ back_b = back_g = back_r = 0;
+
+ else if (display->background == NULL /* no way to remove it */)
+ png_error(png_ptr,
+ "a background color must be supplied to remove alpha/transparency");
+
+ /* Get a copy of the background color (this avoids repeating the checks
+ * below.) The encoding is 8-bit sRGB or 16-bit linear, depending on the
+ * output format.
+ */
+ else
+ {
+ back_g = display->background->green;
+ if (output_format & PNG_FORMAT_FLAG_COLOR)
+ {
+ back_r = display->background->red;
+ back_b = display->background->blue;
+ }
+ else
+ back_b = back_r = back_g;
+ }
+ }
+
+ else if (output_encoding == E_LINEAR)
+ back_b = back_r = back_g = 65535;
+
+ else
+ back_b = back_r = back_g = 255;
+
+ /* Default the input file gamma if required - this is necessary because
+ * libpng assumes that if no gamma information is present the data is in the
+ * output format, but the simplified API deduces the gamma from the input
+ * format.
+ */
+ if ((png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) == 0)
+ {
+ /* Do this directly, not using the png_colorspace functions, to ensure
+ * that it happens even if the colorspace is invalid (though probably if
+ * it is the setting will be ignored) Note that the same thing can be
+ * achieved at the application interface with png_set_gAMA.
+ */
+ if (png_ptr->bit_depth == 16 &&
+ (image->flags & PNG_IMAGE_FLAG_16BIT_sRGB) == 0)
+ png_ptr->colorspace.gamma = PNG_GAMMA_LINEAR;
+
+ else
+ png_ptr->colorspace.gamma = PNG_GAMMA_sRGB_INVERSE;
+
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA;
+ }
+
+ /* Decide what to do based on the PNG color type of the input data. The
+ * utility function png_create_colormap_entry deals with most aspects of the
+ * output transformations; this code works out how to produce bytes of
+ * color-map entries from the original format.
+ */
+ switch (png_ptr->color_type)
+ {
+ case PNG_COLOR_TYPE_GRAY:
+ if (png_ptr->bit_depth <= 8)
+ {
+ /* There at most 256 colors in the output, regardless of
+ * transparency.
+ */
+ unsigned int step, i, val, trans = 256/*ignore*/, back_alpha = 0;
+
+ cmap_entries = 1U << png_ptr->bit_depth;
+ if (cmap_entries > image->colormap_entries)
+ png_error(png_ptr, "gray[8] color-map: too few entries");
+
+ step = 255 / (cmap_entries - 1);
+ output_processing = PNG_CMAP_NONE;
+
+ /* If there is a tRNS chunk then this either selects a transparent
+ * value or, if the output has no alpha, the background color.
+ */
+ if (png_ptr->num_trans > 0)
+ {
+ trans = png_ptr->trans_color.gray;
+
+ if ((output_format & PNG_FORMAT_FLAG_ALPHA) == 0)
+ back_alpha = output_encoding == E_LINEAR ? 65535 : 255;
+ }
+
+ /* png_create_colormap_entry just takes an RGBA and writes the
+ * corresponding color-map entry using the format from 'image',
+ * including the required conversion to sRGB or linear as
+ * appropriate. The input values are always either sRGB (if the
+ * gamma correction flag is 0) or 0..255 scaled file encoded values
+ * (if the function must gamma correct them).
+ */
+ for (i=val=0; i<cmap_entries; ++i, val += step)
+ {
+ /* 'i' is a file value. While this will result in duplicated
+ * entries for 8-bit non-sRGB encoded files it is necessary to
+ * have non-gamma corrected values to do tRNS handling.
+ */
+ if (i != trans)
+ png_create_colormap_entry(display, i, val, val, val, 255,
+ E_FILE/*8-bit with file gamma*/);
+
+ /* Else this entry is transparent. The colors don't matter if
+ * there is an alpha channel (back_alpha == 0), but it does no
+ * harm to pass them in; the values are not set above so this
+ * passes in white.
+ *
+ * NOTE: this preserves the full precision of the application
+ * supplied background color when it is used.
+ */
+ else
+ png_create_colormap_entry(display, i, back_r, back_g, back_b,
+ back_alpha, output_encoding);
+ }
+
+ /* We need libpng to preserve the original encoding. */
+ data_encoding = E_FILE;
+
+ /* The rows from libpng, while technically gray values, are now also
+ * color-map indicies; however, they may need to be expanded to 1
+ * byte per pixel. This is what png_set_packing does (i.e., it
+ * unpacks the bit values into bytes.)
+ */
+ if (png_ptr->bit_depth < 8)
+ png_set_packing(png_ptr);
+ }
+
+ else /* bit depth is 16 */
+ {
+ /* The 16-bit input values can be converted directly to 8-bit gamma
+ * encoded values; however, if a tRNS chunk is present 257 color-map
+ * entries are required. This means that the extra entry requires
+ * special processing; add an alpha channel, sacrifice gray level
+ * 254 and convert transparent (alpha==0) entries to that.
+ *
+ * Use libpng to chop the data to 8 bits. Convert it to sRGB at the
+ * same time to minimize quality loss. If a tRNS chunk is present
+ * this means libpng must handle it too; otherwise it is impossible
+ * to do the exact match on the 16-bit value.
+ *
+ * If the output has no alpha channel *and* the background color is
+ * gray then it is possible to let libpng handle the substitution by
+ * ensuring that the corresponding gray level matches the background
+ * color exactly.
+ */
+ data_encoding = E_sRGB;
+
+ if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "gray[16] color-map: too few entries");
+
+ cmap_entries = make_gray_colormap(display);
+
+ if (png_ptr->num_trans > 0)
+ {
+ unsigned int back_alpha;
+
+ if (output_format & PNG_FORMAT_FLAG_ALPHA)
+ back_alpha = 0;
+
+ else
+ {
+ if (back_r == back_g && back_g == back_b)
+ {
+ /* Background is gray; no special processing will be
+ * required.
+ */
+ png_color_16 c;
+ png_uint_32 gray = back_g;
+
+ if (output_encoding == E_LINEAR)
+ {
+ gray = PNG_sRGB_FROM_LINEAR(gray * 255);
+
+ /* And make sure the corresponding palette entry
+ * matches.
+ */
+ png_create_colormap_entry(display, gray, back_g, back_g,
+ back_g, 65535, E_LINEAR);
+ }
+
+ /* The background passed to libpng, however, must be the
+ * sRGB value.
+ */
+ c.index = 0; /*unused*/
+ c.gray = c.red = c.green = c.blue = (png_uint_16)gray;
+
+ /* NOTE: does this work without expanding tRNS to alpha?
+ * It should be the color->gray case below apparently
+ * doesn't.
+ */
+ png_set_background_fixed(png_ptr, &c,
+ PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/,
+ 0/*gamma: not used*/);
+
+ output_processing = PNG_CMAP_NONE;
+ break;
+ }
+
+ back_alpha = output_encoding == E_LINEAR ? 65535 : 255;
+ }
+
+ /* output_processing means that the libpng-processed row will be
+ * 8-bit GA and it has to be processing to single byte color-map
+ * values. Entry 254 is replaced by either a completely
+ * transparent entry or by the background color at full
+ * precision (and the background color is not a simple gray leve
+ * in this case.)
+ */
+ expand_tRNS = 1;
+ output_processing = PNG_CMAP_TRANS;
+ background_index = 254;
+
+ /* And set (overwrite) color-map entry 254 to the actual
+ * background color at full precision.
+ */
+ png_create_colormap_entry(display, 254, back_r, back_g, back_b,
+ back_alpha, output_encoding);
+ }
+
+ else
+ output_processing = PNG_CMAP_NONE;
+ }
+ break;
+
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ /* 8-bit or 16-bit PNG with two channels - gray and alpha. A minimum
+ * of 65536 combinations. If, however, the alpha channel is to be
+ * removed there are only 256 possibilities if the background is gray.
+ * (Otherwise there is a subset of the 65536 possibilities defined by
+ * the triangle between black, white and the background color.)
+ *
+ * Reduce 16-bit files to 8-bit and sRGB encode the result. No need to
+ * worry about tRNS matching - tRNS is ignored if there is an alpha
+ * channel.
+ */
+ data_encoding = E_sRGB;
+
+ if (output_format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "gray+alpha color-map: too few entries");
+
+ cmap_entries = make_ga_colormap(display);
+
+ background_index = PNG_CMAP_GA_BACKGROUND;
+ output_processing = PNG_CMAP_GA;
+ }
+
+ else /* alpha is removed */
+ {
+ /* Alpha must be removed as the PNG data is processed when the
+ * background is a color because the G and A channels are
+ * independent and the vector addition (non-parallel vectors) is a
+ * 2-D problem.
+ *
+ * This can be reduced to the same algorithm as above by making a
+ * colormap containing gray levels (for the opaque grays), a
+ * background entry (for a transparent pixel) and a set of four six
+ * level color values, one set for each intermediate alpha value.
+ * See the comments in make_ga_colormap for how this works in the
+ * per-pixel processing.
+ *
+ * If the background is gray, however, we only need a 256 entry gray
+ * level color map. It is sufficient to make the entry generated
+ * for the background color be exactly the color specified.
+ */
+ if ((output_format & PNG_FORMAT_FLAG_COLOR) == 0 ||
+ (back_r == back_g && back_g == back_b))
+ {
+ /* Background is gray; no special processing will be required. */
+ png_color_16 c;
+ png_uint_32 gray = back_g;
+
+ if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "gray-alpha color-map: too few entries");
+
+ cmap_entries = make_gray_colormap(display);
+
+ if (output_encoding == E_LINEAR)
+ {
+ gray = PNG_sRGB_FROM_LINEAR(gray * 255);
+
+ /* And make sure the corresponding palette entry matches. */
+ png_create_colormap_entry(display, gray, back_g, back_g,
+ back_g, 65535, E_LINEAR);
+ }
+
+ /* The background passed to libpng, however, must be the sRGB
+ * value.
+ */
+ c.index = 0; /*unused*/
+ c.gray = c.red = c.green = c.blue = (png_uint_16)gray;
+
+ png_set_background_fixed(png_ptr, &c,
+ PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/,
+ 0/*gamma: not used*/);
+
+ output_processing = PNG_CMAP_NONE;
+ }
+
+ else
+ {
+ png_uint_32 i, a;
+
+ /* This is the same as png_make_ga_colormap, above, except that
+ * the entries are all opaque.
+ */
+ if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "ga-alpha color-map: too few entries");
+
+ i = 0;
+ while (i < 231)
+ {
+ png_uint_32 gray = (i * 256 + 115) / 231;
+ png_create_colormap_entry(display, i++, gray, gray, gray,
+ 255, E_sRGB);
+ }
+
+ /* NOTE: this preserves the full precision of the application
+ * background color.
+ */
+ background_index = i;
+ png_create_colormap_entry(display, i++, back_r, back_g, back_b,
+ output_encoding == E_LINEAR ? 65535U : 255U, output_encoding);
+
+ /* For non-opaque input composite on the sRGB background - this
+ * requires inverting the encoding for each component. The input
+ * is still converted to the sRGB encoding because this is a
+ * reasonable approximate to the logarithmic curve of human
+ * visual sensitivity, at least over the narrow range which PNG
+ * represents. Consequently 'G' is always sRGB encoded, while
+ * 'A' is linear. We need the linear background colors.
+ */
+ if (output_encoding == E_sRGB) /* else already linear */
+ {
+ /* This may produce a value not exactly matching the
+ * background, but that's ok because these numbers are only
+ * used when alpha != 0
+ */
+ back_r = png_sRGB_table[back_r];
+ back_g = png_sRGB_table[back_g];
+ back_b = png_sRGB_table[back_b];
+ }
+
+ for (a=1; a<5; ++a)
+ {
+ unsigned int g;
+
+ /* PNG_sRGB_FROM_LINEAR expects a 16-bit linear value scaled
+ * by an 8-bit alpha value (0..255).
+ */
+ png_uint_32 alpha = 51 * a;
+ png_uint_32 back_rx = (255-alpha) * back_r;
+ png_uint_32 back_gx = (255-alpha) * back_g;
+ png_uint_32 back_bx = (255-alpha) * back_b;
+
+ for (g=0; g<6; ++g)
+ {
+ png_uint_32 gray = png_sRGB_table[g*51] * alpha;
+
+ png_create_colormap_entry(display, i++,
+ PNG_sRGB_FROM_LINEAR(gray + back_rx),
+ PNG_sRGB_FROM_LINEAR(gray + back_gx),
+ PNG_sRGB_FROM_LINEAR(gray + back_bx), 255, E_sRGB);
+ }
+ }
+
+ cmap_entries = i;
+ output_processing = PNG_CMAP_GA;
+ }
+ }
+ break;
+
+ case PNG_COLOR_TYPE_RGB:
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ /* Exclude the case where the output is gray; we can always handle this
+ * with the cases above.
+ */
+ if ((output_format & PNG_FORMAT_FLAG_COLOR) == 0)
+ {
+ /* The color-map will be grayscale, so we may as well convert the
+ * input RGB values to a simple grayscale and use the grayscale
+ * code above.
+ *
+ * NOTE: calling this apparently damages the recognition of the
+ * transparent color in background color handling; call
+ * png_set_tRNS_to_alpha before png_set_background_fixed.
+ */
+ png_set_rgb_to_gray_fixed(png_ptr, PNG_ERROR_ACTION_NONE, -1,
+ -1);
+ data_encoding = E_sRGB;
+
+ /* The output will now be one or two 8-bit gray or gray+alpha
+ * channels. The more complex case arises when the input has alpha.
+ */
+ if ((png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ png_ptr->num_trans > 0) &&
+ (output_format & PNG_FORMAT_FLAG_ALPHA) != 0)
+ {
+ /* Both input and output have an alpha channel, so no background
+ * processing is required; just map the GA bytes to the right
+ * color-map entry.
+ */
+ expand_tRNS = 1;
+
+ if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "rgb[ga] color-map: too few entries");
+
+ cmap_entries = make_ga_colormap(display);
+ background_index = PNG_CMAP_GA_BACKGROUND;
+ output_processing = PNG_CMAP_GA;
+ }
+
+ else
+ {
+ /* Either the input or the output has no alpha channel, so there
+ * will be no non-opaque pixels in the color-map; it will just be
+ * grayscale.
+ */
+ if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "rgb[gray] color-map: too few entries");
+
+ /* Ideally this code would use libpng to do the gamma correction,
+ * but if an input alpha channel is to be removed we will hit the
+ * libpng bug in gamma+compose+rgb-to-gray (the double gamma
+ * correction bug). Fix this by dropping the gamma correction in
+ * this case and doing it in the palette; this will result in
+ * duplicate palette entries, but that's better than the
+ * alternative of double gamma correction.
+ */
+ if ((png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ png_ptr->num_trans > 0) &&
+ png_gamma_not_sRGB(png_ptr->colorspace.gamma))
+ {
+ cmap_entries = make_gray_file_colormap(display);
+ data_encoding = E_FILE;
+ }
+
+ else
+ cmap_entries = make_gray_colormap(display);
+
+ /* But if the input has alpha or transparency it must be removed
+ */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ png_ptr->num_trans > 0)
+ {
+ png_color_16 c;
+ png_uint_32 gray = back_g;
+
+ /* We need to ensure that the application background exists in
+ * the colormap and that completely transparent pixels map to
+ * it. Achieve this simply by ensuring that the entry
+ * selected for the background really is the background color.
+ */
+ if (data_encoding == E_FILE) /* from the fixup above */
+ {
+ /* The app supplied a gray which is in output_encoding, we
+ * need to convert it to a value of the input (E_FILE)
+ * encoding then set this palette entry to the required
+ * output encoding.
+ */
+ if (output_encoding == E_sRGB)
+ gray = png_sRGB_table[gray]; /* now E_LINEAR */
+
+ gray = PNG_DIV257(png_gamma_16bit_correct(gray,
+ png_ptr->colorspace.gamma)); /* now E_FILE */
+
+ /* And make sure the corresponding palette entry contains
+ * exactly the required sRGB value.
+ */
+ png_create_colormap_entry(display, gray, back_g, back_g,
+ back_g, 0/*unused*/, output_encoding);
+ }
+
+ else if (output_encoding == E_LINEAR)
+ {
+ gray = PNG_sRGB_FROM_LINEAR(gray * 255);
+
+ /* And make sure the corresponding palette entry matches.
+ */
+ png_create_colormap_entry(display, gray, back_g, back_g,
+ back_g, 0/*unused*/, E_LINEAR);
+ }
+
+ /* The background passed to libpng, however, must be the
+ * output (normally sRGB) value.
+ */
+ c.index = 0; /*unused*/
+ c.gray = c.red = c.green = c.blue = (png_uint_16)gray;
+
+ /* NOTE: the following is apparently a bug in libpng. Without
+ * it the transparent color recognition in
+ * png_set_background_fixed seems to go wrong.
+ */
+ expand_tRNS = 1;
+ png_set_background_fixed(png_ptr, &c,
+ PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/,
+ 0/*gamma: not used*/);
+ }
+
+ output_processing = PNG_CMAP_NONE;
+ }
+ }
+
+ else /* output is color */
+ {
+ /* We could use png_quantize here so long as there is no transparent
+ * color or alpha; png_quantize ignores alpha. Easier overall just
+ * to do it once and using PNG_DIV51 on the 6x6x6 reduced RGB cube.
+ * Consequently we always want libpng to produce sRGB data.
+ */
+ data_encoding = E_sRGB;
+
+ /* Is there any transparency or alpha? */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ png_ptr->num_trans > 0)
+ {
+ /* Is there alpha in the output too? If so all four channels are
+ * processed into a special RGB cube with alpha support.
+ */
+ if (output_format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ png_uint_32 r;
+
+ if (PNG_RGB_COLORMAP_ENTRIES+1+27 > image->colormap_entries)
+ png_error(png_ptr, "rgb+alpha color-map: too few entries");
+
+ cmap_entries = make_rgb_colormap(display);
+
+ /* Add a transparent entry. */
+ png_create_colormap_entry(display, cmap_entries, 255, 255,
+ 255, 0, E_sRGB);
+
+ /* This is stored as the background index for the processing
+ * algorithm.
+ */
+ background_index = cmap_entries++;
+
+ /* Add 27 r,g,b entries each with alpha 0.5. */
+ for (r=0; r<256; r = (r << 1) | 0x7f)
+ {
+ png_uint_32 g;
+
+ for (g=0; g<256; g = (g << 1) | 0x7f)
+ {
+ png_uint_32 b;
+
+ /* This generates components with the values 0, 127 and
+ * 255
+ */
+ for (b=0; b<256; b = (b << 1) | 0x7f)
+ png_create_colormap_entry(display, cmap_entries++,
+ r, g, b, 128, E_sRGB);
+ }
+ }
+
+ expand_tRNS = 1;
+ output_processing = PNG_CMAP_RGB_ALPHA;
+ }
+
+ else
+ {
+ /* Alpha/transparency must be removed. The background must
+ * exist in the color map (achieved by setting adding it after
+ * the 666 color-map). If the standard processing code will
+ * pick up this entry automatically that's all that is
+ * required; libpng can be called to do the background
+ * processing.
+ */
+ unsigned int sample_size =
+ PNG_IMAGE_SAMPLE_SIZE(output_format);
+ png_uint_32 r, g, b; /* sRGB background */
+
+ if (PNG_RGB_COLORMAP_ENTRIES+1+27 > image->colormap_entries)
+ png_error(png_ptr, "rgb-alpha color-map: too few entries");
+
+ cmap_entries = make_rgb_colormap(display);
+
+ png_create_colormap_entry(display, cmap_entries, back_r,
+ back_g, back_b, 0/*unused*/, output_encoding);
+
+ if (output_encoding == E_LINEAR)
+ {
+ r = PNG_sRGB_FROM_LINEAR(back_r * 255);
+ g = PNG_sRGB_FROM_LINEAR(back_g * 255);
+ b = PNG_sRGB_FROM_LINEAR(back_b * 255);
+ }
+
+ else
+ {
+ r = back_r;
+ g = back_g;
+ b = back_g;
+ }
+
+ /* Compare the newly-created color-map entry with the one the
+ * PNG_CMAP_RGB algorithm will use. If the two entries don't
+ * match, add the new one and set this as the background
+ * index.
+ */
+ if (memcmp((png_const_bytep)display->colormap +
+ sample_size * cmap_entries,
+ (png_const_bytep)display->colormap +
+ sample_size * PNG_RGB_INDEX(r,g,b),
+ sample_size) != 0)
+ {
+ /* The background color must be added. */
+ background_index = cmap_entries++;
+
+ /* Add 27 r,g,b entries each with created by composing with
+ * the background at alpha 0.5.
+ */
+ for (r=0; r<256; r = (r << 1) | 0x7f)
+ {
+ for (g=0; g<256; g = (g << 1) | 0x7f)
+ {
+ /* This generates components with the values 0, 127
+ * and 255
+ */
+ for (b=0; b<256; b = (b << 1) | 0x7f)
+ png_create_colormap_entry(display, cmap_entries++,
+ png_colormap_compose(display, r, E_sRGB, 128,
+ back_r, output_encoding),
+ png_colormap_compose(display, g, E_sRGB, 128,
+ back_g, output_encoding),
+ png_colormap_compose(display, b, E_sRGB, 128,
+ back_b, output_encoding),
+ 0/*unused*/, output_encoding);
+ }
+ }
+
+ expand_tRNS = 1;
+ output_processing = PNG_CMAP_RGB_ALPHA;
+ }
+
+ else /* background color is in the standard color-map */
+ {
+ png_color_16 c;
+
+ c.index = 0; /*unused*/
+ c.red = (png_uint_16)back_r;
+ c.gray = c.green = (png_uint_16)back_g;
+ c.blue = (png_uint_16)back_b;
+
+ png_set_background_fixed(png_ptr, &c,
+ PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/,
+ 0/*gamma: not used*/);
+
+ output_processing = PNG_CMAP_RGB;
+ }
+ }
+ }
+
+ else /* no alpha or transparency in the input */
+ {
+ /* Alpha in the output is irrelevant, simply map the opaque input
+ * pixels to the 6x6x6 color-map.
+ */
+ if (PNG_RGB_COLORMAP_ENTRIES > image->colormap_entries)
+ png_error(png_ptr, "rgb color-map: too few entries");
+
+ cmap_entries = make_rgb_colormap(display);
+ output_processing = PNG_CMAP_RGB;
+ }
+ }
+ break;
+
+ case PNG_COLOR_TYPE_PALETTE:
+ /* It's already got a color-map. It may be necessary to eliminate the
+ * tRNS entries though.
+ */
+ {
+ unsigned int num_trans = png_ptr->num_trans;
+ png_const_bytep trans = num_trans > 0 ? png_ptr->trans_alpha : NULL;
+ png_const_colorp colormap = png_ptr->palette;
+ const int do_background = trans != NULL &&
+ (output_format & PNG_FORMAT_FLAG_ALPHA) == 0;
+ unsigned int i;
+
+ /* Just in case: */
+ if (trans == NULL)
+ num_trans = 0;
+
+ output_processing = PNG_CMAP_NONE;
+ data_encoding = E_FILE; /* Don't change from color-map indicies */
+ cmap_entries = png_ptr->num_palette;
+ if (cmap_entries > 256)
+ cmap_entries = 256;
+
+ if (cmap_entries > image->colormap_entries)
+ png_error(png_ptr, "palette color-map: too few entries");
+
+ for (i=0; i < cmap_entries; ++i)
+ {
+ if (do_background && i < num_trans && trans[i] < 255)
+ {
+ if (trans[i] == 0)
+ png_create_colormap_entry(display, i, back_r, back_g,
+ back_b, 0, output_encoding);
+
+ else
+ {
+ /* Must compose the PNG file color in the color-map entry
+ * on the sRGB color in 'back'.
+ */
+ png_create_colormap_entry(display, i,
+ png_colormap_compose(display, colormap[i].red, E_FILE,
+ trans[i], back_r, output_encoding),
+ png_colormap_compose(display, colormap[i].green, E_FILE,
+ trans[i], back_g, output_encoding),
+ png_colormap_compose(display, colormap[i].blue, E_FILE,
+ trans[i], back_b, output_encoding),
+ output_encoding == E_LINEAR ? trans[i] * 257U :
+ trans[i],
+ output_encoding);
+ }
+ }
+
+ else
+ png_create_colormap_entry(display, i, colormap[i].red,
+ colormap[i].green, colormap[i].blue,
+ i < num_trans ? trans[i] : 255U, E_FILE/*8-bit*/);
+ }
+
+ /* The PNG data may have indicies packed in fewer than 8 bits, it
+ * must be expanded if so.
+ */
+ if (png_ptr->bit_depth < 8)
+ png_set_packing(png_ptr);
+ }
+ break;
+
+ default:
+ png_error(png_ptr, "invalid PNG color type");
+ /*NOT REACHED*/
+ break;
+ }
+
+ /* Now deal with the output processing */
+ if (expand_tRNS && png_ptr->num_trans > 0 &&
+ (png_ptr->color_type & PNG_COLOR_MASK_ALPHA) == 0)
+ png_set_tRNS_to_alpha(png_ptr);
+
+ switch (data_encoding)
+ {
+ default:
+ png_error(png_ptr, "bad data option (internal error)");
+ break;
+
+ case E_sRGB:
+ /* Change to 8-bit sRGB */
+ png_set_alpha_mode_fixed(png_ptr, PNG_ALPHA_PNG, PNG_GAMMA_sRGB);
+ /* FALL THROUGH */
+
+ case E_FILE:
+ if (png_ptr->bit_depth > 8)
+ png_set_scale_16(png_ptr);
+ break;
+ }
+
+ if (cmap_entries > 256 || cmap_entries > image->colormap_entries)
+ png_error(png_ptr, "color map overflow (BAD internal error)");
+
+ image->colormap_entries = cmap_entries;
+
+ /* Double check using the recorded background index */
+ switch (output_processing)
+ {
+ case PNG_CMAP_NONE:
+ if (background_index != PNG_CMAP_NONE_BACKGROUND)
+ goto bad_background;
+ break;
+
+ case PNG_CMAP_GA:
+ if (background_index != PNG_CMAP_GA_BACKGROUND)
+ goto bad_background;
+ break;
+
+ case PNG_CMAP_TRANS:
+ if (background_index >= cmap_entries ||
+ background_index != PNG_CMAP_TRANS_BACKGROUND)
+ goto bad_background;
+ break;
+
+ case PNG_CMAP_RGB:
+ if (background_index != PNG_CMAP_RGB_BACKGROUND)
+ goto bad_background;
+ break;
+
+ case PNG_CMAP_RGB_ALPHA:
+ if (background_index != PNG_CMAP_RGB_ALPHA_BACKGROUND)
+ goto bad_background;
+ break;
+
+ default:
+ png_error(png_ptr, "bad processing option (internal error)");
+
+ bad_background:
+ png_error(png_ptr, "bad background index (internal error)");
+ }
+
+ display->colormap_processing = output_processing;
+
+ return 1/*ok*/;
+}
+
+/* The final part of the color-map read called from png_image_finish_read. */
+static int
+png_image_read_and_map(png_voidp argument)
+{
+ png_image_read_control *display = png_voidcast(png_image_read_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+ int passes;
+
+ /* Called when the libpng data must be transformed into the color-mapped
+ * form. There is a local row buffer in display->local and this routine must
+ * do the interlace handling.
+ */
+ switch (png_ptr->interlaced)
+ {
+ case PNG_INTERLACE_NONE:
+ passes = 1;
+ break;
+
+ case PNG_INTERLACE_ADAM7:
+ passes = PNG_INTERLACE_ADAM7_PASSES;
+ break;
+
+ default:
+ passes = 0;
+ png_error(png_ptr, "unknown interlace type");
+ }
+
+ {
+ png_uint_32 height = image->height;
+ png_uint_32 width = image->width;
+ int proc = display->colormap_processing;
+ png_bytep first_row = png_voidcast(png_bytep, display->first_row);
+ ptrdiff_t step_row = display->row_bytes;
+ int pass;
+
+ for (pass = 0; pass < passes; ++pass)
+ {
+ unsigned int startx, stepx, stepy;
+ png_uint_32 y;
+
+ if (png_ptr->interlaced == PNG_INTERLACE_ADAM7)
+ {
+ /* The row may be empty for a short image: */
+ if (PNG_PASS_COLS(width, pass) == 0)
+ continue;
+
+ startx = PNG_PASS_START_COL(pass);
+ stepx = PNG_PASS_COL_OFFSET(pass);
+ y = PNG_PASS_START_ROW(pass);
+ stepy = PNG_PASS_ROW_OFFSET(pass);
+ }
+
+ else
+ {
+ y = 0;
+ startx = 0;
+ stepx = stepy = 1;
+ }
+
+ for (; y<height; y += stepy)
+ {
+ png_bytep inrow = png_voidcast(png_bytep, display->local_row);
+ png_bytep outrow = first_row + y * step_row;
+ png_const_bytep end_row = outrow + width;
+
+ /* Read read the libpng data into the temporary buffer. */
+ png_read_row(png_ptr, inrow, NULL);
+
+ /* Now process the row according to the processing option, note
+ * that the caller verifies that the format of the libpng output
+ * data is as required.
+ */
+ outrow += startx;
+ switch (proc)
+ {
+ case PNG_CMAP_GA:
+ for (; outrow < end_row; outrow += stepx)
+ {
+ /* The data is always in the PNG order */
+ unsigned int gray = *inrow++;
+ unsigned int alpha = *inrow++;
+ unsigned int entry;
+
+ /* NOTE: this code is copied as a comment in
+ * make_ga_colormap above. Please update the
+ * comment if you change this code!
+ */
+ if (alpha > 229) /* opaque */
+ {
+ entry = (231 * gray + 128) >> 8;
+ }
+ else if (alpha < 26) /* transparent */
+ {
+ entry = 231;
+ }
+ else /* partially opaque */
+ {
+ entry = 226 + 6 * PNG_DIV51(alpha) + PNG_DIV51(gray);
+ }
+
+ *outrow = (png_byte)entry;
+ }
+ break;
+
+ case PNG_CMAP_TRANS:
+ for (; outrow < end_row; outrow += stepx)
+ {
+ png_byte gray = *inrow++;
+ png_byte alpha = *inrow++;
+
+ if (alpha == 0)
+ *outrow = PNG_CMAP_TRANS_BACKGROUND;
+
+ else if (gray != PNG_CMAP_TRANS_BACKGROUND)
+ *outrow = gray;
+
+ else
+ *outrow = (png_byte)(PNG_CMAP_TRANS_BACKGROUND+1);
+ }
+ break;
+
+ case PNG_CMAP_RGB:
+ for (; outrow < end_row; outrow += stepx)
+ {
+ *outrow = PNG_RGB_INDEX(inrow[0], inrow[1], inrow[2]);
+ inrow += 3;
+ }
+ break;
+
+ case PNG_CMAP_RGB_ALPHA:
+ for (; outrow < end_row; outrow += stepx)
+ {
+ unsigned int alpha = inrow[3];
+
+ /* Because the alpha entries only hold alpha==0.5 values
+ * split the processing at alpha==0.25 (64) and 0.75
+ * (196).
+ */
+
+ if (alpha >= 196)
+ *outrow = PNG_RGB_INDEX(inrow[0], inrow[1],
+ inrow[2]);
+
+ else if (alpha < 64)
+ *outrow = PNG_CMAP_RGB_ALPHA_BACKGROUND;
+
+ else
+ {
+ /* Likewise there are three entries for each of r, g
+ * and b. We could select the entry by popcount on
+ * the top two bits on those architectures that
+ * support it, this is what the code below does,
+ * crudely.
+ */
+ unsigned int back_i = PNG_CMAP_RGB_ALPHA_BACKGROUND+1;
+
+ /* Here are how the values map:
+ *
+ * 0x00 .. 0x3f -> 0
+ * 0x40 .. 0xbf -> 1
+ * 0xc0 .. 0xff -> 2
+ *
+ * So, as above with the explicit alpha checks, the
+ * breakpoints are at 64 and 196.
+ */
+ if (inrow[0] & 0x80) back_i += 9; /* red */
+ if (inrow[0] & 0x40) back_i += 9;
+ if (inrow[0] & 0x80) back_i += 3; /* green */
+ if (inrow[0] & 0x40) back_i += 3;
+ if (inrow[0] & 0x80) back_i += 1; /* blue */
+ if (inrow[0] & 0x40) back_i += 1;
+
+ *outrow = (png_byte)back_i;
+ }
+
+ inrow += 4;
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+ }
+
+ return 1;
+}
+
+static int
+png_image_read_colormapped(png_voidp argument)
+{
+ png_image_read_control *display = png_voidcast(png_image_read_control*,
+ argument);
+ png_imagep image = display->image;
+ png_controlp control = image->opaque;
+ png_structrp png_ptr = control->png_ptr;
+ png_inforp info_ptr = control->info_ptr;
+
+ int passes = 0; /* As a flag */
+
+ PNG_SKIP_CHUNKS(png_ptr);
+
+ /* Update the 'info' structure and make sure the result is as required; first
+ * make sure to turn on the interlace handling if it will be required
+ * (because it can't be turned on *after* the call to png_read_update_info!)
+ */
+ if (display->colormap_processing == PNG_CMAP_NONE)
+ passes = png_set_interlace_handling(png_ptr);
+
+ png_read_update_info(png_ptr, info_ptr);
+
+ /* The expected output can be deduced from the colormap_processing option. */
+ switch (display->colormap_processing)
+ {
+ case PNG_CMAP_NONE:
+ /* Output must be one channel and one byte per pixel, the output
+ * encoding can be anything.
+ */
+ if ((info_ptr->color_type == PNG_COLOR_TYPE_PALETTE ||
+ info_ptr->color_type == PNG_COLOR_TYPE_GRAY) &&
+ info_ptr->bit_depth == 8)
+ break;
+
+ goto bad_output;
+
+ case PNG_CMAP_TRANS:
+ case PNG_CMAP_GA:
+ /* Output must be two channels and the 'G' one must be sRGB, the latter
+ * can be checked with an exact number because it should have been set
+ * to this number above!
+ */
+ if (info_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA &&
+ info_ptr->bit_depth == 8 &&
+ png_ptr->screen_gamma == PNG_GAMMA_sRGB &&
+ image->colormap_entries == 256)
+ break;
+
+ goto bad_output;
+
+ case PNG_CMAP_RGB:
+ /* Output must be 8-bit sRGB encoded RGB */
+ if (info_ptr->color_type == PNG_COLOR_TYPE_RGB &&
+ info_ptr->bit_depth == 8 &&
+ png_ptr->screen_gamma == PNG_GAMMA_sRGB &&
+ image->colormap_entries == 216)
+ break;
+
+ goto bad_output;
+
+ case PNG_CMAP_RGB_ALPHA:
+ /* Output must be 8-bit sRGB encoded RGBA */
+ if (info_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA &&
+ info_ptr->bit_depth == 8 &&
+ png_ptr->screen_gamma == PNG_GAMMA_sRGB &&
+ image->colormap_entries == 244 /* 216 + 1 + 27 */)
+ break;
+
+ /* goto bad_output; */
+ /* FALL THROUGH */
+
+ default:
+ bad_output:
+ png_error(png_ptr, "bad color-map processing (internal error)");
+ }
+
+ /* Now read the rows. Do this here if it is possible to read directly into
+ * the output buffer, otherwise allocate a local row buffer of the maximum
+ * size libpng requires and call the relevant processing routine safely.
+ */
+ {
+ png_voidp first_row = display->buffer;
+ ptrdiff_t row_bytes = display->row_stride;
+
+ /* The following expression is designed to work correctly whether it gives
+ * a signed or an unsigned result.
+ */
+ if (row_bytes < 0)
+ {
+ char *ptr = png_voidcast(char*, first_row);
+ ptr += (image->height-1) * (-row_bytes);
+ first_row = png_voidcast(png_voidp, ptr);
+ }
+
+ display->first_row = first_row;
+ display->row_bytes = row_bytes;
+ }
+
+ if (passes == 0)
+ {
+ int result;
+ png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr));
+
+ display->local_row = row;
+ result = png_safe_execute(image, png_image_read_and_map, display);
+ display->local_row = NULL;
+ png_free(png_ptr, row);
+
+ return result;
+ }
+
+ else
+ {
+ png_alloc_size_t row_bytes = display->row_bytes;
+
+ while (--passes >= 0)
+ {
+ png_uint_32 y = image->height;
+ png_bytep row = png_voidcast(png_bytep, display->first_row);
+
+ while (y-- > 0)
+ {
+ png_read_row(png_ptr, row, NULL);
+ row += row_bytes;
+ }
+ }
+
+ return 1;
+ }
+}
+
+/* Just the row reading part of png_image_read. */
+static int
+png_image_read_composite(png_voidp argument)
+{
+ png_image_read_control *display = png_voidcast(png_image_read_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+ int passes;
+
+ switch (png_ptr->interlaced)
+ {
+ case PNG_INTERLACE_NONE:
+ passes = 1;
+ break;
+
+ case PNG_INTERLACE_ADAM7:
+ passes = PNG_INTERLACE_ADAM7_PASSES;
+ break;
+
+ default:
+ passes = 0;
+ png_error(png_ptr, "unknown interlace type");
+ }
+
+ {
+ png_uint_32 height = image->height;
+ png_uint_32 width = image->width;
+ ptrdiff_t step_row = display->row_bytes;
+ unsigned int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1;
+ int pass;
+
+ for (pass = 0; pass < passes; ++pass)
+ {
+ unsigned int startx, stepx, stepy;
+ png_uint_32 y;
+
+ if (png_ptr->interlaced == PNG_INTERLACE_ADAM7)
+ {
+ /* The row may be empty for a short image: */
+ if (PNG_PASS_COLS(width, pass) == 0)
+ continue;
+
+ startx = PNG_PASS_START_COL(pass) * channels;
+ stepx = PNG_PASS_COL_OFFSET(pass) * channels;
+ y = PNG_PASS_START_ROW(pass);
+ stepy = PNG_PASS_ROW_OFFSET(pass);
+ }
+
+ else
+ {
+ y = 0;
+ startx = 0;
+ stepx = channels;
+ stepy = 1;
+ }
+
+ for (; y<height; y += stepy)
+ {
+ png_bytep inrow = png_voidcast(png_bytep, display->local_row);
+ png_bytep outrow;
+ png_const_bytep end_row;
+
+ /* Read the row, which is packed: */
+ png_read_row(png_ptr, inrow, NULL);
+
+ outrow = png_voidcast(png_bytep, display->first_row);
+ outrow += y * step_row;
+ end_row = outrow + width * channels;
+
+ /* Now do the composition on each pixel in this row. */
+ outrow += startx;
+ for (; outrow < end_row; outrow += stepx)
+ {
+ png_byte alpha = inrow[channels];
+
+ if (alpha > 0) /* else no change to the output */
+ {
+ unsigned int c;
+
+ for (c=0; c<channels; ++c)
+ {
+ png_uint_32 component = inrow[c];
+
+ if (alpha < 255) /* else just use component */
+ {
+ /* This is PNG_OPTIMIZED_ALPHA, the component value
+ * is a linear 8-bit value. Combine this with the
+ * current outrow[c] value which is sRGB encoded.
+ * Arithmetic here is 16-bits to preserve the output
+ * values correctly.
+ */
+ component *= 257*255; /* =65535 */
+ component += (255-alpha)*png_sRGB_table[outrow[c]];
+
+ /* So 'component' is scaled by 255*65535 and is
+ * therefore appropriate for the sRGB to linear
+ * conversion table.
+ */
+ component = PNG_sRGB_FROM_LINEAR(component);
+ }
+
+ outrow[c] = (png_byte)component;
+ }
+ }
+
+ inrow += channels+1; /* components and alpha channel */
+ }
+ }
+ }
+ }
+
+ return 1;
+}
+
+/* The do_local_background case; called when all the following transforms are to
+ * be done:
+ *
+ * PNG_RGB_TO_GRAY
+ * PNG_COMPOSITE
+ * PNG_GAMMA
+ *
+ * This is a work-round for the fact that both the PNG_RGB_TO_GRAY and
+ * PNG_COMPOSITE code performs gamma correction, so we get double gamma
+ * correction. The fix-up is to prevent the PNG_COMPOSITE operation happening
+ * inside libpng, so this routine sees an 8 or 16-bit gray+alpha row and handles
+ * the removal or pre-multiplication of the alpha channel.
+ */
+static int
+png_image_read_background(png_voidp argument)
+{
+ png_image_read_control *display = png_voidcast(png_image_read_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+ png_inforp info_ptr = image->opaque->info_ptr;
+ png_uint_32 height = image->height;
+ png_uint_32 width = image->width;
+ int pass, passes;
+
+ /* Double check the convoluted logic below. We expect to get here with
+ * libpng doing rgb to gray and gamma correction but background processing
+ * left to the png_image_read_background function. The rows libpng produce
+ * might be 8 or 16-bit but should always have two channels; gray plus alpha.
+ */
+ if ((png_ptr->transformations & PNG_RGB_TO_GRAY) == 0)
+ png_error(png_ptr, "lost rgb to gray");
+
+ if ((png_ptr->transformations & PNG_COMPOSE) != 0)
+ png_error(png_ptr, "unexpected compose");
+
+ if (png_get_channels(png_ptr, info_ptr) != 2)
+ png_error(png_ptr, "lost/gained channels");
+
+ /* Expect the 8-bit case to always remove the alpha channel */
+ if ((image->format & PNG_FORMAT_FLAG_LINEAR) == 0 &&
+ (image->format & PNG_FORMAT_FLAG_ALPHA) != 0)
+ png_error(png_ptr, "unexpected 8-bit transformation");
+
+ switch (png_ptr->interlaced)
+ {
+ case PNG_INTERLACE_NONE:
+ passes = 1;
+ break;
+
+ case PNG_INTERLACE_ADAM7:
+ passes = PNG_INTERLACE_ADAM7_PASSES;
+ break;
+
+ default:
+ passes = 0;
+ png_error(png_ptr, "unknown interlace type");
+ }
+
+ switch (png_get_bit_depth(png_ptr, info_ptr))
+ {
+ default:
+ png_error(png_ptr, "unexpected bit depth");
+ break;
+
+ case 8:
+ /* 8-bit sRGB gray values with an alpha channel; the alpha channel is
+ * to be removed by composing on a background: either the row if
+ * display->background is NULL or display->background->green if not.
+ * Unlike the code above ALPHA_OPTIMIZED has *not* been done.
+ */
+ {
+ png_bytep first_row = png_voidcast(png_bytep, display->first_row);
+ ptrdiff_t step_row = display->row_bytes;
+
+ for (pass = 0; pass < passes; ++pass)
+ {
+ png_bytep row = png_voidcast(png_bytep,
+ display->first_row);
+ unsigned int startx, stepx, stepy;
+ png_uint_32 y;
+
+ if (png_ptr->interlaced == PNG_INTERLACE_ADAM7)
+ {
+ /* The row may be empty for a short image: */
+ if (PNG_PASS_COLS(width, pass) == 0)
+ continue;
+
+ startx = PNG_PASS_START_COL(pass);
+ stepx = PNG_PASS_COL_OFFSET(pass);
+ y = PNG_PASS_START_ROW(pass);
+ stepy = PNG_PASS_ROW_OFFSET(pass);
+ }
+
+ else
+ {
+ y = 0;
+ startx = 0;
+ stepx = stepy = 1;
+ }
+
+ if (display->background == NULL)
+ {
+ for (; y<height; y += stepy)
+ {
+ png_bytep inrow = png_voidcast(png_bytep,
+ display->local_row);
+ png_bytep outrow = first_row + y * step_row;
+ png_const_bytep end_row = outrow + width;
+
+ /* Read the row, which is packed: */
+ png_read_row(png_ptr, inrow, NULL);
+
+ /* Now do the composition on each pixel in this row. */
+ outrow += startx;
+ for (; outrow < end_row; outrow += stepx)
+ {
+ png_byte alpha = inrow[1];
+
+ if (alpha > 0) /* else no change to the output */
+ {
+ png_uint_32 component = inrow[0];
+
+ if (alpha < 255) /* else just use component */
+ {
+ /* Since PNG_OPTIMIZED_ALPHA was not set it is
+ * necessary to invert the sRGB transfer
+ * function and multiply the alpha out.
+ */
+ component = png_sRGB_table[component] * alpha;
+ component += png_sRGB_table[outrow[0]] *
+ (255-alpha);
+ component = PNG_sRGB_FROM_LINEAR(component);
+ }
+
+ outrow[0] = (png_byte)component;
+ }
+
+ inrow += 2; /* gray and alpha channel */
+ }
+ }
+ }
+
+ else /* constant background value */
+ {
+ png_byte background8 = display->background->green;
+ png_uint_16 background = png_sRGB_table[background8];
+
+ for (; y<height; y += stepy)
+ {
+ png_bytep inrow = png_voidcast(png_bytep,
+ display->local_row);
+ png_bytep outrow = first_row + y * step_row;
+ png_const_bytep end_row = outrow + width;
+
+ /* Read the row, which is packed: */
+ png_read_row(png_ptr, inrow, NULL);
+
+ /* Now do the composition on each pixel in this row. */
+ outrow += startx;
+ for (; outrow < end_row; outrow += stepx)
+ {
+ png_byte alpha = inrow[1];
+
+ if (alpha > 0) /* else use background */
+ {
+ png_uint_32 component = inrow[0];
+
+ if (alpha < 255) /* else just use component */
+ {
+ component = png_sRGB_table[component] * alpha;
+ component += background * (255-alpha);
+ component = PNG_sRGB_FROM_LINEAR(component);
+ }
+
+ outrow[0] = (png_byte)component;
+ }
+
+ else
+ outrow[0] = background8;
+
+ inrow += 2; /* gray and alpha channel */
+ }
+
+ row += display->row_bytes;
+ }
+ }
+ }
+ }
+ break;
+
+ case 16:
+ /* 16-bit linear with pre-multiplied alpha; the pre-multiplication must
+ * still be done and, maybe, the alpha channel removed. This code also
+ * handles the alpha-first option.
+ */
+ {
+ png_uint_16p first_row = png_voidcast(png_uint_16p,
+ display->first_row);
+ /* The division by two is safe because the caller passed in a
+ * stride which was multiplied by 2 (below) to get row_bytes.
+ */
+ ptrdiff_t step_row = display->row_bytes / 2;
+ int preserve_alpha = (image->format & PNG_FORMAT_FLAG_ALPHA) != 0;
+ unsigned int outchannels = 1+preserve_alpha;
+ int swap_alpha = 0;
+
+ if (preserve_alpha && (image->format & PNG_FORMAT_FLAG_AFIRST))
+ swap_alpha = 1;
+
+ for (pass = 0; pass < passes; ++pass)
+ {
+ unsigned int startx, stepx, stepy;
+ png_uint_32 y;
+
+ /* The 'x' start and step are adjusted to output components here.
+ */
+ if (png_ptr->interlaced == PNG_INTERLACE_ADAM7)
+ {
+ /* The row may be empty for a short image: */
+ if (PNG_PASS_COLS(width, pass) == 0)
+ continue;
+
+ startx = PNG_PASS_START_COL(pass) * outchannels;
+ stepx = PNG_PASS_COL_OFFSET(pass) * outchannels;
+ y = PNG_PASS_START_ROW(pass);
+ stepy = PNG_PASS_ROW_OFFSET(pass);
+ }
+
+ else
+ {
+ y = 0;
+ startx = 0;
+ stepx = outchannels;
+ stepy = 1;
+ }
+
+ for (; y<height; y += stepy)
+ {
+ png_const_uint_16p inrow;
+ png_uint_16p outrow = first_row + y*step_row;
+ png_uint_16p end_row = outrow + width * outchannels;
+
+ /* Read the row, which is packed: */
+ png_read_row(png_ptr, png_voidcast(png_bytep,
+ display->local_row), NULL);
+ inrow = png_voidcast(png_const_uint_16p, display->local_row);
+
+ /* Now do the pre-multiplication on each pixel in this row.
+ */
+ outrow += startx;
+ for (; outrow < end_row; outrow += stepx)
+ {
+ png_uint_32 component = inrow[0];
+ png_uint_16 alpha = inrow[1];
+
+ if (alpha > 0) /* else 0 */
+ {
+ if (alpha < 65535) /* else just use component */
+ {
+ component *= alpha;
+ component += 32767;
+ component /= 65535;
+ }
+ }
+
+ else
+ component = 0;
+
+ outrow[swap_alpha] = (png_uint_16)component;
+ if (preserve_alpha)
+ outrow[1 ^ swap_alpha] = alpha;
+
+ inrow += 2; /* components and alpha channel */
+ }
+ }
+ }
+ }
+ break;
+ }
+
+ return 1;
+}
+
+/* The guts of png_image_finish_read as a png_safe_execute callback. */
+static int
+png_image_read_direct(png_voidp argument)
+{
+ png_image_read_control *display = png_voidcast(png_image_read_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+ png_inforp info_ptr = image->opaque->info_ptr;
+
+ png_uint_32 format = image->format;
+ int linear = (format & PNG_FORMAT_FLAG_LINEAR) != 0;
+ int do_local_compose = 0;
+ int do_local_background = 0; /* to avoid double gamma correction bug */
+ int passes = 0;
+
+ /* Add transforms to ensure the correct output format is produced then check
+ * that the required implementation support is there. Always expand; always
+ * need 8 bits minimum, no palette and expanded tRNS.
+ */
+ png_set_expand(png_ptr);
+
+ /* Now check the format to see if it was modified. */
+ {
+ png_uint_32 base_format = png_image_format(png_ptr) &
+ ~PNG_FORMAT_FLAG_COLORMAP /* removed by png_set_expand */;
+ png_uint_32 change = format ^ base_format;
+ png_fixed_point output_gamma;
+ int mode; /* alpha mode */
+
+ /* Do this first so that we have a record if rgb to gray is happening. */
+ if (change & PNG_FORMAT_FLAG_COLOR)
+ {
+ /* gray<->color transformation required. */
+ if (format & PNG_FORMAT_FLAG_COLOR)
+ png_set_gray_to_rgb(png_ptr);
+
+ else
+ {
+ /* libpng can't do both rgb to gray and
+ * background/pre-multiplication if there is also significant gamma
+ * correction, because both operations require linear colors and
+ * the code only supports one transform doing the gamma correction.
+ * Handle this by doing the pre-multiplication or background
+ * operation in this code, if necessary.
+ *
+ * TODO: fix this by rewriting pngrtran.c (!)
+ *
+ * For the moment (given that fixing this in pngrtran.c is an
+ * enormous change) 'do_local_background' is used to indicate that
+ * the problem exists.
+ */
+ if (base_format & PNG_FORMAT_FLAG_ALPHA)
+ do_local_background = 1/*maybe*/;
+
+ png_set_rgb_to_gray_fixed(png_ptr, PNG_ERROR_ACTION_NONE,
+ PNG_RGB_TO_GRAY_DEFAULT, PNG_RGB_TO_GRAY_DEFAULT);
+ }
+
+ change &= ~PNG_FORMAT_FLAG_COLOR;
+ }
+
+ /* Set the gamma appropriately, linear for 16-bit input, sRGB otherwise.
+ */
+ {
+ png_fixed_point input_gamma_default;
+
+ if ((base_format & PNG_FORMAT_FLAG_LINEAR) &&
+ (image->flags & PNG_IMAGE_FLAG_16BIT_sRGB) == 0)
+ input_gamma_default = PNG_GAMMA_LINEAR;
+ else
+ input_gamma_default = PNG_DEFAULT_sRGB;
+
+ /* Call png_set_alpha_mode to set the default for the input gamma; the
+ * output gamma is set by a second call below.
+ */
+ png_set_alpha_mode_fixed(png_ptr, PNG_ALPHA_PNG, input_gamma_default);
+ }
+
+ if (linear)
+ {
+ /* If there *is* an alpha channel in the input it must be multiplied
+ * out; use PNG_ALPHA_STANDARD, otherwise just use PNG_ALPHA_PNG.
+ */
+ if (base_format & PNG_FORMAT_FLAG_ALPHA)
+ mode = PNG_ALPHA_STANDARD; /* associated alpha */
+
+ else
+ mode = PNG_ALPHA_PNG;
+
+ output_gamma = PNG_GAMMA_LINEAR;
+ }
+
+ else
+ {
+ mode = PNG_ALPHA_PNG;
+ output_gamma = PNG_DEFAULT_sRGB;
+ }
+
+ /* If 'do_local_background' is set check for the presence of gamma
+ * correction; this is part of the work-round for the libpng bug
+ * described above.
+ *
+ * TODO: fix libpng and remove this.
+ */
+ if (do_local_background)
+ {
+ png_fixed_point gtest;
+
+ /* This is 'png_gamma_threshold' from pngrtran.c; the test used for
+ * gamma correction, the screen gamma hasn't been set on png_struct
+ * yet; it's set below. png_struct::gamma, however, is set to the
+ * final value.
+ */
+ if (png_muldiv(&gtest, output_gamma, png_ptr->colorspace.gamma,
+ PNG_FP_1) && !png_gamma_significant(gtest))
+ do_local_background = 0;
+
+ else if (mode == PNG_ALPHA_STANDARD)
+ {
+ do_local_background = 2/*required*/;
+ mode = PNG_ALPHA_PNG; /* prevent libpng doing it */
+ }
+
+ /* else leave as 1 for the checks below */
+ }
+
+ /* If the bit-depth changes then handle that here. */
+ if (change & PNG_FORMAT_FLAG_LINEAR)
+ {
+ if (linear /*16-bit output*/)
+ png_set_expand_16(png_ptr);
+
+ else /* 8-bit output */
+ png_set_scale_16(png_ptr);
+
+ change &= ~PNG_FORMAT_FLAG_LINEAR;
+ }
+
+ /* Now the background/alpha channel changes. */
+ if (change & PNG_FORMAT_FLAG_ALPHA)
+ {
+ /* Removing an alpha channel requires composition for the 8-bit
+ * formats; for the 16-bit it is already done, above, by the
+ * pre-multiplication and the channel just needs to be stripped.
+ */
+ if (base_format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ /* If RGB->gray is happening the alpha channel must be left and the
+ * operation completed locally.
+ *
+ * TODO: fix libpng and remove this.
+ */
+ if (do_local_background)
+ do_local_background = 2/*required*/;
+
+ /* 16-bit output: just remove the channel */
+ else if (linear) /* compose on black (well, pre-multiply) */
+ png_set_strip_alpha(png_ptr);
+
+ /* 8-bit output: do an appropriate compose */
+ else if (display->background != NULL)
+ {
+ png_color_16 c;
+
+ c.index = 0; /*unused*/
+ c.red = display->background->red;
+ c.green = display->background->green;
+ c.blue = display->background->blue;
+ c.gray = display->background->green;
+
+ /* This is always an 8-bit sRGB value, using the 'green' channel
+ * for gray is much better than calculating the luminance here;
+ * we can get off-by-one errors in that calculation relative to
+ * the app expectations and that will show up in transparent
+ * pixels.
+ */
+ png_set_background_fixed(png_ptr, &c,
+ PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/,
+ 0/*gamma: not used*/);
+ }
+
+ else /* compose on row: implemented below. */
+ {
+ do_local_compose = 1;
+ /* This leaves the alpha channel in the output, so it has to be
+ * removed by the code below. Set the encoding to the 'OPTIMIZE'
+ * one so the code only has to hack on the pixels that require
+ * composition.
+ */
+ mode = PNG_ALPHA_OPTIMIZED;
+ }
+ }
+
+ else /* output needs an alpha channel */
+ {
+ /* This is tricky because it happens before the swap operation has
+ * been accomplished; however, the swap does *not* swap the added
+ * alpha channel (weird API), so it must be added in the correct
+ * place.
+ */
+ png_uint_32 filler; /* opaque filler */
+ int where;
+
+ if (linear)
+ filler = 65535;
+
+ else
+ filler = 255;
+
+# ifdef PNG_FORMAT_AFIRST_SUPPORTED
+ if (format & PNG_FORMAT_FLAG_AFIRST)
+ {
+ where = PNG_FILLER_BEFORE;
+ change &= ~PNG_FORMAT_FLAG_AFIRST;
+ }
+
+ else
+# endif
+ where = PNG_FILLER_AFTER;
+
+ png_set_add_alpha(png_ptr, filler, where);
+ }
+
+ /* This stops the (irrelevant) call to swap_alpha below. */
+ change &= ~PNG_FORMAT_FLAG_ALPHA;
+ }
+
+ /* Now set the alpha mode correctly; this is always done, even if there is
+ * no alpha channel in either the input or the output because it correctly
+ * sets the output gamma.
+ */
+ png_set_alpha_mode_fixed(png_ptr, mode, output_gamma);
+
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ if (change & PNG_FORMAT_FLAG_BGR)
+ {
+ /* Check only the output format; PNG is never BGR; don't do this if
+ * the output is gray, but fix up the 'format' value in that case.
+ */
+ if (format & PNG_FORMAT_FLAG_COLOR)
+ png_set_bgr(png_ptr);
+
+ else
+ format &= ~PNG_FORMAT_FLAG_BGR;
+
+ change &= ~PNG_FORMAT_FLAG_BGR;
+ }
+# endif
+
+# ifdef PNG_FORMAT_AFIRST_SUPPORTED
+ if (change & PNG_FORMAT_FLAG_AFIRST)
+ {
+ /* Only relevant if there is an alpha channel - it's particularly
+ * important to handle this correctly because do_local_compose may
+ * be set above and then libpng will keep the alpha channel for this
+ * code to remove.
+ */
+ if (format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ /* Disable this if doing a local background,
+ * TODO: remove this when local background is no longer required.
+ */
+ if (do_local_background != 2)
+ png_set_swap_alpha(png_ptr);
+ }
+
+ else
+ format &= ~PNG_FORMAT_FLAG_AFIRST;
+
+ change &= ~PNG_FORMAT_FLAG_AFIRST;
+ }
+# endif
+
+ /* If the *output* is 16-bit then we need to check for a byte-swap on this
+ * architecture.
+ */
+ if (linear)
+ {
+ PNG_CONST png_uint_16 le = 0x0001;
+
+ if (*(png_const_bytep)&le)
+ png_set_swap(png_ptr);
+ }
+
+ /* If change is not now 0 some transformation is missing - error out. */
+ if (change)
+ png_error(png_ptr, "png_read_image: unsupported transformation");
+ }
+
+ PNG_SKIP_CHUNKS(png_ptr);
+
+ /* Update the 'info' structure and make sure the result is as required; first
+ * make sure to turn on the interlace handling if it will be required
+ * (because it can't be turned on *after* the call to png_read_update_info!)
+ *
+ * TODO: remove the do_local_background fixup below.
+ */
+ if (!do_local_compose && do_local_background != 2)
+ passes = png_set_interlace_handling(png_ptr);
+
+ png_read_update_info(png_ptr, info_ptr);
+
+ {
+ png_uint_32 info_format = 0;
+
+ if (info_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ info_format |= PNG_FORMAT_FLAG_COLOR;
+
+ if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA)
+ {
+ /* do_local_compose removes this channel below. */
+ if (!do_local_compose)
+ {
+ /* do_local_background does the same if required. */
+ if (do_local_background != 2 ||
+ (format & PNG_FORMAT_FLAG_ALPHA) != 0)
+ info_format |= PNG_FORMAT_FLAG_ALPHA;
+ }
+ }
+
+ else if (do_local_compose) /* internal error */
+ png_error(png_ptr, "png_image_read: alpha channel lost");
+
+ if (info_ptr->bit_depth == 16)
+ info_format |= PNG_FORMAT_FLAG_LINEAR;
+
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ if (png_ptr->transformations & PNG_BGR)
+ info_format |= PNG_FORMAT_FLAG_BGR;
+# endif
+
+# ifdef PNG_FORMAT_AFIRST_SUPPORTED
+ if (do_local_background == 2)
+ {
+ if (format & PNG_FORMAT_FLAG_AFIRST)
+ info_format |= PNG_FORMAT_FLAG_AFIRST;
+ }
+
+ if ((png_ptr->transformations & PNG_SWAP_ALPHA) != 0 ||
+ ((png_ptr->transformations & PNG_ADD_ALPHA) != 0 &&
+ (png_ptr->flags & PNG_FLAG_FILLER_AFTER) == 0))
+ {
+ if (do_local_background == 2)
+ png_error(png_ptr, "unexpected alpha swap transformation");
+
+ info_format |= PNG_FORMAT_FLAG_AFIRST;
+ }
+# endif
+
+ /* This is actually an internal error. */
+ if (info_format != format)
+ png_error(png_ptr, "png_read_image: invalid transformations");
+ }
+
+ /* Now read the rows. If do_local_compose is set then it is necessary to use
+ * a local row buffer. The output will be GA, RGBA or BGRA and must be
+ * converted to G, RGB or BGR as appropriate. The 'local_row' member of the
+ * display acts as a flag.
+ */
+ {
+ png_voidp first_row = display->buffer;
+ ptrdiff_t row_bytes = display->row_stride;
+
+ if (linear)
+ row_bytes *= 2;
+
+ /* The following expression is designed to work correctly whether it gives
+ * a signed or an unsigned result.
+ */
+ if (row_bytes < 0)
+ {
+ char *ptr = png_voidcast(char*, first_row);
+ ptr += (image->height-1) * (-row_bytes);
+ first_row = png_voidcast(png_voidp, ptr);
+ }
+
+ display->first_row = first_row;
+ display->row_bytes = row_bytes;
+ }
+
+ if (do_local_compose)
+ {
+ int result;
+ png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr));
+
+ display->local_row = row;
+ result = png_safe_execute(image, png_image_read_composite, display);
+ display->local_row = NULL;
+ png_free(png_ptr, row);
+
+ return result;
+ }
+
+ else if (do_local_background == 2)
+ {
+ int result;
+ png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr));
+
+ display->local_row = row;
+ result = png_safe_execute(image, png_image_read_background, display);
+ display->local_row = NULL;
+ png_free(png_ptr, row);
+
+ return result;
+ }
+
+ else
+ {
+ png_alloc_size_t row_bytes = display->row_bytes;
+
+ while (--passes >= 0)
+ {
+ png_uint_32 y = image->height;
+ png_bytep row = png_voidcast(png_bytep, display->first_row);
+
+ while (y-- > 0)
+ {
+ png_read_row(png_ptr, row, NULL);
+ row += row_bytes;
+ }
+ }
+
+ return 1;
+ }
+}
+
+int PNGAPI
+png_image_finish_read(png_imagep image, png_const_colorp background,
+ void *buffer, png_int_32 row_stride, void *colormap)
+{
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ png_uint_32 check;
+
+ if (row_stride == 0)
+ row_stride = PNG_IMAGE_ROW_STRIDE(*image);
+
+ if (row_stride < 0)
+ check = -row_stride;
+
+ else
+ check = row_stride;
+
+ if (image->opaque != NULL && buffer != NULL &&
+ check >= PNG_IMAGE_ROW_STRIDE(*image))
+ {
+ if ((image->format & PNG_FORMAT_FLAG_COLORMAP) == 0 ||
+ (image->colormap_entries > 0 && colormap != NULL))
+ {
+ int result;
+ png_image_read_control display;
+
+ memset(&display, 0, (sizeof display));
+ display.image = image;
+ display.buffer = buffer;
+ display.row_stride = row_stride;
+ display.colormap = colormap;
+ display.background = background;
+ display.local_row = NULL;
+
+ /* Choose the correct 'end' routine; for the color-map case all the
+ * setup has already been done.
+ */
+ if (image->format & PNG_FORMAT_FLAG_COLORMAP)
+ result =
+ png_safe_execute(image, png_image_read_colormap, &display) &&
+ png_safe_execute(image, png_image_read_colormapped, &display);
+
+ else
+ result =
+ png_safe_execute(image, png_image_read_direct, &display);
+
+ png_image_free(image);
+ return result;
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_finish_read[color-map]: no color-map");
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_finish_read: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_finish_read: damaged PNG_IMAGE_VERSION");
+
+ return 0;
+}
+
+#endif /* PNG_SIMPLIFIED_READ_SUPPORTED */
+#endif /* PNG_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngrio.c b/ml/dlib/dlib/external/libpng/pngrio.c
new file mode 100644
index 000000000..d7864407b
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngrio.c
@@ -0,0 +1,118 @@
+
+/* pngrio.c - functions for data input
+ *
+ * Last changed in libpng 1.6.0 [February 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file provides a location for all input. Users who need
+ * special handling are expected to write a function that has the same
+ * arguments as this and performs a similar function, but that possibly
+ * has a different input method. Note that you shouldn't change this
+ * function, but rather write a replacement function and then make
+ * libpng use it at run time with png_set_read_fn(...).
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_READ_SUPPORTED
+
+/* Read the data from whatever input you are using. The default routine
+ * reads from a file pointer. Note that this routine sometimes gets called
+ * with very small lengths, so you should implement some kind of simple
+ * buffering if you are using unbuffered reads. This should never be asked
+ * to read more then 64K on a 16 bit machine.
+ */
+void /* PRIVATE */
+png_read_data(png_structrp png_ptr, png_bytep data, png_size_t length)
+{
+ png_debug1(4, "reading %d bytes", (int)length);
+
+ if (png_ptr->read_data_fn != NULL)
+ (*(png_ptr->read_data_fn))(png_ptr, data, length);
+
+ else
+ png_error(png_ptr, "Call to NULL read function");
+}
+
+#ifdef PNG_STDIO_SUPPORTED
+/* This is the function that does the actual reading of data. If you are
+ * not reading from a standard C stream, you should create a replacement
+ * read_data function and use it at run time with png_set_read_fn(), rather
+ * than changing the library.
+ */
+void PNGCBAPI
+png_default_read_data(png_structp png_ptr, png_bytep data, png_size_t length)
+{
+ png_size_t check;
+
+ if (png_ptr == NULL)
+ return;
+
+ /* fread() returns 0 on error, so it is OK to store this in a png_size_t
+ * instead of an int, which is what fread() actually returns.
+ */
+ check = fread(data, 1, length, png_voidcast(png_FILE_p, png_ptr->io_ptr));
+
+ if (check != length)
+ png_error(png_ptr, "Read Error");
+}
+#endif
+
+/* This function allows the application to supply a new input function
+ * for libpng if standard C streams aren't being used.
+ *
+ * This function takes as its arguments:
+ *
+ * png_ptr - pointer to a png input data structure
+ *
+ * io_ptr - pointer to user supplied structure containing info about
+ * the input functions. May be NULL.
+ *
+ * read_data_fn - pointer to a new input function that takes as its
+ * arguments a pointer to a png_struct, a pointer to
+ * a location where input data can be stored, and a 32-bit
+ * unsigned int that is the number of bytes to be read.
+ * To exit and output any fatal error messages the new write
+ * function should call png_error(png_ptr, "Error msg").
+ * May be NULL, in which case libpng's default function will
+ * be used.
+ */
+void PNGAPI
+png_set_read_fn(png_structrp png_ptr, png_voidp io_ptr,
+ png_rw_ptr read_data_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->io_ptr = io_ptr;
+
+#ifdef PNG_STDIO_SUPPORTED
+ if (read_data_fn != NULL)
+ png_ptr->read_data_fn = read_data_fn;
+
+ else
+ png_ptr->read_data_fn = png_default_read_data;
+#else
+ png_ptr->read_data_fn = read_data_fn;
+#endif
+
+ /* It is an error to write to a read device */
+ if (png_ptr->write_data_fn != NULL)
+ {
+ png_ptr->write_data_fn = NULL;
+ png_warning(png_ptr,
+ "Can't set both read_data_fn and write_data_fn in the"
+ " same structure");
+ }
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+ png_ptr->output_flush_fn = NULL;
+#endif
+}
+#endif /* PNG_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngrtran.c b/ml/dlib/dlib/external/libpng/pngrtran.c
new file mode 100644
index 000000000..3b7d484fc
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngrtran.c
@@ -0,0 +1,5110 @@
+
+/* pngrtran.c - transforms the data in a row for PNG readers
+ *
+ * Last changed in libpng 1.6.4 [August 21, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file contains functions optionally called by an application
+ * in order to tell libpng how to handle data when reading a PNG.
+ * Transformations that are used in both reading and writing are
+ * in pngtrans.c.
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_READ_SUPPORTED
+
+/* Set the action on getting a CRC error for an ancillary or critical chunk. */
+void PNGAPI
+png_set_crc_action(png_structrp png_ptr, int crit_action, int ancil_action)
+{
+ png_debug(1, "in png_set_crc_action");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* Tell libpng how we react to CRC errors in critical chunks */
+ switch (crit_action)
+ {
+ case PNG_CRC_NO_CHANGE: /* Leave setting as is */
+ break;
+
+ case PNG_CRC_WARN_USE: /* Warn/use data */
+ png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK;
+ png_ptr->flags |= PNG_FLAG_CRC_CRITICAL_USE;
+ break;
+
+ case PNG_CRC_QUIET_USE: /* Quiet/use data */
+ png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK;
+ png_ptr->flags |= PNG_FLAG_CRC_CRITICAL_USE |
+ PNG_FLAG_CRC_CRITICAL_IGNORE;
+ break;
+
+ case PNG_CRC_WARN_DISCARD: /* Not a valid action for critical data */
+ png_warning(png_ptr,
+ "Can't discard critical data on CRC error");
+ case PNG_CRC_ERROR_QUIT: /* Error/quit */
+
+ case PNG_CRC_DEFAULT:
+ default:
+ png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK;
+ break;
+ }
+
+ /* Tell libpng how we react to CRC errors in ancillary chunks */
+ switch (ancil_action)
+ {
+ case PNG_CRC_NO_CHANGE: /* Leave setting as is */
+ break;
+
+ case PNG_CRC_WARN_USE: /* Warn/use data */
+ png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK;
+ png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_USE;
+ break;
+
+ case PNG_CRC_QUIET_USE: /* Quiet/use data */
+ png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK;
+ png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_USE |
+ PNG_FLAG_CRC_ANCILLARY_NOWARN;
+ break;
+
+ case PNG_CRC_ERROR_QUIT: /* Error/quit */
+ png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK;
+ png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_NOWARN;
+ break;
+
+ case PNG_CRC_WARN_DISCARD: /* Warn/discard data */
+
+ case PNG_CRC_DEFAULT:
+ default:
+ png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK;
+ break;
+ }
+}
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+/* Is it OK to set a transformation now? Only if png_start_read_image or
+ * png_read_update_info have not been called. It is not necessary for the IHDR
+ * to have been read in all cases, the parameter allows for this check too.
+ */
+static int
+png_rtran_ok(png_structrp png_ptr, int need_IHDR)
+{
+ if (png_ptr != NULL)
+ {
+ if (png_ptr->flags & PNG_FLAG_ROW_INIT)
+ png_app_error(png_ptr,
+ "invalid after png_start_read_image or png_read_update_info");
+
+ else if (need_IHDR && (png_ptr->mode & PNG_HAVE_IHDR) == 0)
+ png_app_error(png_ptr, "invalid before the PNG header has been read");
+
+ else
+ {
+ /* Turn on failure to initialize correctly for all transforms. */
+ png_ptr->flags |= PNG_FLAG_DETECT_UNINITIALIZED;
+
+ return 1; /* Ok */
+ }
+ }
+
+ return 0; /* no png_error possible! */
+}
+#endif
+
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+/* Handle alpha and tRNS via a background color */
+void PNGFAPI
+png_set_background_fixed(png_structrp png_ptr,
+ png_const_color_16p background_color, int background_gamma_code,
+ int need_expand, png_fixed_point background_gamma)
+{
+ png_debug(1, "in png_set_background_fixed");
+
+ if (!png_rtran_ok(png_ptr, 0) || background_color == NULL)
+ return;
+
+ if (background_gamma_code == PNG_BACKGROUND_GAMMA_UNKNOWN)
+ {
+ png_warning(png_ptr, "Application must supply a known background gamma");
+ return;
+ }
+
+ png_ptr->transformations |= PNG_COMPOSE | PNG_STRIP_ALPHA;
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+
+ png_ptr->background = *background_color;
+ png_ptr->background_gamma = background_gamma;
+ png_ptr->background_gamma_type = (png_byte)(background_gamma_code);
+ if (need_expand)
+ png_ptr->transformations |= PNG_BACKGROUND_EXPAND;
+ else
+ png_ptr->transformations &= ~PNG_BACKGROUND_EXPAND;
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_background(png_structrp png_ptr,
+ png_const_color_16p background_color, int background_gamma_code,
+ int need_expand, double background_gamma)
+{
+ png_set_background_fixed(png_ptr, background_color, background_gamma_code,
+ need_expand, png_fixed(png_ptr, background_gamma, "png_set_background"));
+}
+# endif /* FLOATING_POINT */
+#endif /* READ_BACKGROUND */
+
+/* Scale 16-bit depth files to 8-bit depth. If both of these are set then the
+ * one that pngrtran does first (scale) happens. This is necessary to allow the
+ * TRANSFORM and API behavior to be somewhat consistent, and it's simpler.
+ */
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+void PNGAPI
+png_set_scale_16(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_scale_16");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= PNG_SCALE_16_TO_8;
+}
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+/* Chop 16-bit depth files to 8-bit depth */
+void PNGAPI
+png_set_strip_16(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_strip_16");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= PNG_16_TO_8;
+}
+#endif
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+void PNGAPI
+png_set_strip_alpha(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_strip_alpha");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= PNG_STRIP_ALPHA;
+}
+#endif
+
+#if defined(PNG_READ_ALPHA_MODE_SUPPORTED) || defined(PNG_READ_GAMMA_SUPPORTED)
+static png_fixed_point
+translate_gamma_flags(png_structrp png_ptr, png_fixed_point output_gamma,
+ int is_screen)
+{
+ /* Check for flag values. The main reason for having the old Mac value as a
+ * flag is that it is pretty near impossible to work out what the correct
+ * value is from Apple documentation - a working Mac system is needed to
+ * discover the value!
+ */
+ if (output_gamma == PNG_DEFAULT_sRGB ||
+ output_gamma == PNG_FP_1 / PNG_DEFAULT_sRGB)
+ {
+ /* If there is no sRGB support this just sets the gamma to the standard
+ * sRGB value. (This is a side effect of using this function!)
+ */
+# ifdef PNG_READ_sRGB_SUPPORTED
+ png_ptr->flags |= PNG_FLAG_ASSUME_sRGB;
+# else
+ PNG_UNUSED(png_ptr)
+# endif
+ if (is_screen)
+ output_gamma = PNG_GAMMA_sRGB;
+ else
+ output_gamma = PNG_GAMMA_sRGB_INVERSE;
+ }
+
+ else if (output_gamma == PNG_GAMMA_MAC_18 ||
+ output_gamma == PNG_FP_1 / PNG_GAMMA_MAC_18)
+ {
+ if (is_screen)
+ output_gamma = PNG_GAMMA_MAC_OLD;
+ else
+ output_gamma = PNG_GAMMA_MAC_INVERSE;
+ }
+
+ return output_gamma;
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+static png_fixed_point
+convert_gamma_value(png_structrp png_ptr, double output_gamma)
+{
+ /* The following silently ignores cases where fixed point (times 100,000)
+ * gamma values are passed to the floating point API. This is safe and it
+ * means the fixed point constants work just fine with the floating point
+ * API. The alternative would just lead to undetected errors and spurious
+ * bug reports. Negative values fail inside the _fixed API unless they
+ * correspond to the flag values.
+ */
+ if (output_gamma > 0 && output_gamma < 128)
+ output_gamma *= PNG_FP_1;
+
+ /* This preserves -1 and -2 exactly: */
+ output_gamma = floor(output_gamma + .5);
+
+ if (output_gamma > PNG_FP_MAX || output_gamma < PNG_FP_MIN)
+ png_fixed_error(png_ptr, "gamma value");
+
+ return (png_fixed_point)output_gamma;
+}
+# endif
+#endif /* READ_ALPHA_MODE || READ_GAMMA */
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+void PNGFAPI
+png_set_alpha_mode_fixed(png_structrp png_ptr, int mode,
+ png_fixed_point output_gamma)
+{
+ int compose = 0;
+ png_fixed_point file_gamma;
+
+ png_debug(1, "in png_set_alpha_mode");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ output_gamma = translate_gamma_flags(png_ptr, output_gamma, 1/*screen*/);
+
+ /* Validate the value to ensure it is in a reasonable range. The value
+ * is expected to be 1 or greater, but this range test allows for some
+ * viewing correction values. The intent is to weed out users of this API
+ * who use the inverse of the gamma value accidentally! Since some of these
+ * values are reasonable this may have to be changed.
+ */
+ if (output_gamma < 70000 || output_gamma > 300000)
+ png_error(png_ptr, "output gamma out of expected range");
+
+ /* The default file gamma is the inverse of the output gamma; the output
+ * gamma may be changed below so get the file value first:
+ */
+ file_gamma = png_reciprocal(output_gamma);
+
+ /* There are really 8 possibilities here, composed of any combination
+ * of:
+ *
+ * premultiply the color channels
+ * do not encode non-opaque pixels
+ * encode the alpha as well as the color channels
+ *
+ * The differences disappear if the input/output ('screen') gamma is 1.0,
+ * because then the encoding is a no-op and there is only the choice of
+ * premultiplying the color channels or not.
+ *
+ * png_set_alpha_mode and png_set_background interact because both use
+ * png_compose to do the work. Calling both is only useful when
+ * png_set_alpha_mode is used to set the default mode - PNG_ALPHA_PNG - along
+ * with a default gamma value. Otherwise PNG_COMPOSE must not be set.
+ */
+ switch (mode)
+ {
+ case PNG_ALPHA_PNG: /* default: png standard */
+ /* No compose, but it may be set by png_set_background! */
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+ break;
+
+ case PNG_ALPHA_ASSOCIATED: /* color channels premultiplied */
+ compose = 1;
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+ /* The output is linear: */
+ output_gamma = PNG_FP_1;
+ break;
+
+ case PNG_ALPHA_OPTIMIZED: /* associated, non-opaque pixels linear */
+ compose = 1;
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags |= PNG_FLAG_OPTIMIZE_ALPHA;
+ /* output_gamma records the encoding of opaque pixels! */
+ break;
+
+ case PNG_ALPHA_BROKEN: /* associated, non-linear, alpha encoded */
+ compose = 1;
+ png_ptr->transformations |= PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+ break;
+
+ default:
+ png_error(png_ptr, "invalid alpha mode");
+ }
+
+ /* Only set the default gamma if the file gamma has not been set (this has
+ * the side effect that the gamma in a second call to png_set_alpha_mode will
+ * be ignored.)
+ */
+ if (png_ptr->colorspace.gamma == 0)
+ {
+ png_ptr->colorspace.gamma = file_gamma;
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA;
+ }
+
+ /* But always set the output gamma: */
+ png_ptr->screen_gamma = output_gamma;
+
+ /* Finally, if pre-multiplying, set the background fields to achieve the
+ * desired result.
+ */
+ if (compose)
+ {
+ /* And obtain alpha pre-multiplication by composing on black: */
+ memset(&png_ptr->background, 0, (sizeof png_ptr->background));
+ png_ptr->background_gamma = png_ptr->colorspace.gamma; /* just in case */
+ png_ptr->background_gamma_type = PNG_BACKGROUND_GAMMA_FILE;
+ png_ptr->transformations &= ~PNG_BACKGROUND_EXPAND;
+
+ if (png_ptr->transformations & PNG_COMPOSE)
+ png_error(png_ptr,
+ "conflicting calls to set alpha mode and background");
+
+ png_ptr->transformations |= PNG_COMPOSE;
+ }
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_alpha_mode(png_structrp png_ptr, int mode, double output_gamma)
+{
+ png_set_alpha_mode_fixed(png_ptr, mode, convert_gamma_value(png_ptr,
+ output_gamma));
+}
+# endif
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+/* Dither file to 8-bit. Supply a palette, the current number
+ * of elements in the palette, the maximum number of elements
+ * allowed, and a histogram if possible. If the current number
+ * of colors is greater then the maximum number, the palette will be
+ * modified to fit in the maximum number. "full_quantize" indicates
+ * whether we need a quantizing cube set up for RGB images, or if we
+ * simply are reducing the number of colors in a paletted image.
+ */
+
+typedef struct png_dsort_struct
+{
+ struct png_dsort_struct * next;
+ png_byte left;
+ png_byte right;
+} png_dsort;
+typedef png_dsort * png_dsortp;
+typedef png_dsort * * png_dsortpp;
+
+void PNGAPI
+png_set_quantize(png_structrp png_ptr, png_colorp palette,
+ int num_palette, int maximum_colors, png_const_uint_16p histogram,
+ int full_quantize)
+{
+ png_debug(1, "in png_set_quantize");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= PNG_QUANTIZE;
+
+ if (!full_quantize)
+ {
+ int i;
+
+ png_ptr->quantize_index = (png_bytep)png_malloc(png_ptr,
+ (png_uint_32)(num_palette * (sizeof (png_byte))));
+ for (i = 0; i < num_palette; i++)
+ png_ptr->quantize_index[i] = (png_byte)i;
+ }
+
+ if (num_palette > maximum_colors)
+ {
+ if (histogram != NULL)
+ {
+ /* This is easy enough, just throw out the least used colors.
+ * Perhaps not the best solution, but good enough.
+ */
+
+ int i;
+
+ /* Initialize an array to sort colors */
+ png_ptr->quantize_sort = (png_bytep)png_malloc(png_ptr,
+ (png_uint_32)(num_palette * (sizeof (png_byte))));
+
+ /* Initialize the quantize_sort array */
+ for (i = 0; i < num_palette; i++)
+ png_ptr->quantize_sort[i] = (png_byte)i;
+
+ /* Find the least used palette entries by starting a
+ * bubble sort, and running it until we have sorted
+ * out enough colors. Note that we don't care about
+ * sorting all the colors, just finding which are
+ * least used.
+ */
+
+ for (i = num_palette - 1; i >= maximum_colors; i--)
+ {
+ int done; /* To stop early if the list is pre-sorted */
+ int j;
+
+ done = 1;
+ for (j = 0; j < i; j++)
+ {
+ if (histogram[png_ptr->quantize_sort[j]]
+ < histogram[png_ptr->quantize_sort[j + 1]])
+ {
+ png_byte t;
+
+ t = png_ptr->quantize_sort[j];
+ png_ptr->quantize_sort[j] = png_ptr->quantize_sort[j + 1];
+ png_ptr->quantize_sort[j + 1] = t;
+ done = 0;
+ }
+ }
+
+ if (done)
+ break;
+ }
+
+ /* Swap the palette around, and set up a table, if necessary */
+ if (full_quantize)
+ {
+ int j = num_palette;
+
+ /* Put all the useful colors within the max, but don't
+ * move the others.
+ */
+ for (i = 0; i < maximum_colors; i++)
+ {
+ if ((int)png_ptr->quantize_sort[i] >= maximum_colors)
+ {
+ do
+ j--;
+ while ((int)png_ptr->quantize_sort[j] >= maximum_colors);
+
+ palette[i] = palette[j];
+ }
+ }
+ }
+ else
+ {
+ int j = num_palette;
+
+ /* Move all the used colors inside the max limit, and
+ * develop a translation table.
+ */
+ for (i = 0; i < maximum_colors; i++)
+ {
+ /* Only move the colors we need to */
+ if ((int)png_ptr->quantize_sort[i] >= maximum_colors)
+ {
+ png_color tmp_color;
+
+ do
+ j--;
+ while ((int)png_ptr->quantize_sort[j] >= maximum_colors);
+
+ tmp_color = palette[j];
+ palette[j] = palette[i];
+ palette[i] = tmp_color;
+ /* Indicate where the color went */
+ png_ptr->quantize_index[j] = (png_byte)i;
+ png_ptr->quantize_index[i] = (png_byte)j;
+ }
+ }
+
+ /* Find closest color for those colors we are not using */
+ for (i = 0; i < num_palette; i++)
+ {
+ if ((int)png_ptr->quantize_index[i] >= maximum_colors)
+ {
+ int min_d, k, min_k, d_index;
+
+ /* Find the closest color to one we threw out */
+ d_index = png_ptr->quantize_index[i];
+ min_d = PNG_COLOR_DIST(palette[d_index], palette[0]);
+ for (k = 1, min_k = 0; k < maximum_colors; k++)
+ {
+ int d;
+
+ d = PNG_COLOR_DIST(palette[d_index], palette[k]);
+
+ if (d < min_d)
+ {
+ min_d = d;
+ min_k = k;
+ }
+ }
+ /* Point to closest color */
+ png_ptr->quantize_index[i] = (png_byte)min_k;
+ }
+ }
+ }
+ png_free(png_ptr, png_ptr->quantize_sort);
+ png_ptr->quantize_sort = NULL;
+ }
+ else
+ {
+ /* This is much harder to do simply (and quickly). Perhaps
+ * we need to go through a median cut routine, but those
+ * don't always behave themselves with only a few colors
+ * as input. So we will just find the closest two colors,
+ * and throw out one of them (chosen somewhat randomly).
+ * [We don't understand this at all, so if someone wants to
+ * work on improving it, be our guest - AED, GRP]
+ */
+ int i;
+ int max_d;
+ int num_new_palette;
+ png_dsortp t;
+ png_dsortpp hash;
+
+ t = NULL;
+
+ /* Initialize palette index arrays */
+ png_ptr->index_to_palette = (png_bytep)png_malloc(png_ptr,
+ (png_uint_32)(num_palette * (sizeof (png_byte))));
+ png_ptr->palette_to_index = (png_bytep)png_malloc(png_ptr,
+ (png_uint_32)(num_palette * (sizeof (png_byte))));
+
+ /* Initialize the sort array */
+ for (i = 0; i < num_palette; i++)
+ {
+ png_ptr->index_to_palette[i] = (png_byte)i;
+ png_ptr->palette_to_index[i] = (png_byte)i;
+ }
+
+ hash = (png_dsortpp)png_calloc(png_ptr, (png_uint_32)(769 *
+ (sizeof (png_dsortp))));
+
+ num_new_palette = num_palette;
+
+ /* Initial wild guess at how far apart the farthest pixel
+ * pair we will be eliminating will be. Larger
+ * numbers mean more areas will be allocated, Smaller
+ * numbers run the risk of not saving enough data, and
+ * having to do this all over again.
+ *
+ * I have not done extensive checking on this number.
+ */
+ max_d = 96;
+
+ while (num_new_palette > maximum_colors)
+ {
+ for (i = 0; i < num_new_palette - 1; i++)
+ {
+ int j;
+
+ for (j = i + 1; j < num_new_palette; j++)
+ {
+ int d;
+
+ d = PNG_COLOR_DIST(palette[i], palette[j]);
+
+ if (d <= max_d)
+ {
+
+ t = (png_dsortp)png_malloc_warn(png_ptr,
+ (png_uint_32)(sizeof (png_dsort)));
+
+ if (t == NULL)
+ break;
+
+ t->next = hash[d];
+ t->left = (png_byte)i;
+ t->right = (png_byte)j;
+ hash[d] = t;
+ }
+ }
+ if (t == NULL)
+ break;
+ }
+
+ if (t != NULL)
+ for (i = 0; i <= max_d; i++)
+ {
+ if (hash[i] != NULL)
+ {
+ png_dsortp p;
+
+ for (p = hash[i]; p; p = p->next)
+ {
+ if ((int)png_ptr->index_to_palette[p->left]
+ < num_new_palette &&
+ (int)png_ptr->index_to_palette[p->right]
+ < num_new_palette)
+ {
+ int j, next_j;
+
+ if (num_new_palette & 0x01)
+ {
+ j = p->left;
+ next_j = p->right;
+ }
+ else
+ {
+ j = p->right;
+ next_j = p->left;
+ }
+
+ num_new_palette--;
+ palette[png_ptr->index_to_palette[j]]
+ = palette[num_new_palette];
+ if (!full_quantize)
+ {
+ int k;
+
+ for (k = 0; k < num_palette; k++)
+ {
+ if (png_ptr->quantize_index[k] ==
+ png_ptr->index_to_palette[j])
+ png_ptr->quantize_index[k] =
+ png_ptr->index_to_palette[next_j];
+
+ if ((int)png_ptr->quantize_index[k] ==
+ num_new_palette)
+ png_ptr->quantize_index[k] =
+ png_ptr->index_to_palette[j];
+ }
+ }
+
+ png_ptr->index_to_palette[png_ptr->palette_to_index
+ [num_new_palette]] = png_ptr->index_to_palette[j];
+
+ png_ptr->palette_to_index[png_ptr->index_to_palette[j]]
+ = png_ptr->palette_to_index[num_new_palette];
+
+ png_ptr->index_to_palette[j] =
+ (png_byte)num_new_palette;
+
+ png_ptr->palette_to_index[num_new_palette] =
+ (png_byte)j;
+ }
+ if (num_new_palette <= maximum_colors)
+ break;
+ }
+ if (num_new_palette <= maximum_colors)
+ break;
+ }
+ }
+
+ for (i = 0; i < 769; i++)
+ {
+ if (hash[i] != NULL)
+ {
+ png_dsortp p = hash[i];
+ while (p)
+ {
+ t = p->next;
+ png_free(png_ptr, p);
+ p = t;
+ }
+ }
+ hash[i] = 0;
+ }
+ max_d += 96;
+ }
+ png_free(png_ptr, hash);
+ png_free(png_ptr, png_ptr->palette_to_index);
+ png_free(png_ptr, png_ptr->index_to_palette);
+ png_ptr->palette_to_index = NULL;
+ png_ptr->index_to_palette = NULL;
+ }
+ num_palette = maximum_colors;
+ }
+ if (png_ptr->palette == NULL)
+ {
+ png_ptr->palette = palette;
+ }
+ png_ptr->num_palette = (png_uint_16)num_palette;
+
+ if (full_quantize)
+ {
+ int i;
+ png_bytep distance;
+ int total_bits = PNG_QUANTIZE_RED_BITS + PNG_QUANTIZE_GREEN_BITS +
+ PNG_QUANTIZE_BLUE_BITS;
+ int num_red = (1 << PNG_QUANTIZE_RED_BITS);
+ int num_green = (1 << PNG_QUANTIZE_GREEN_BITS);
+ int num_blue = (1 << PNG_QUANTIZE_BLUE_BITS);
+ png_size_t num_entries = ((png_size_t)1 << total_bits);
+
+ png_ptr->palette_lookup = (png_bytep)png_calloc(png_ptr,
+ (png_uint_32)(num_entries * (sizeof (png_byte))));
+
+ distance = (png_bytep)png_malloc(png_ptr, (png_uint_32)(num_entries *
+ (sizeof (png_byte))));
+
+ memset(distance, 0xff, num_entries * (sizeof (png_byte)));
+
+ for (i = 0; i < num_palette; i++)
+ {
+ int ir, ig, ib;
+ int r = (palette[i].red >> (8 - PNG_QUANTIZE_RED_BITS));
+ int g = (palette[i].green >> (8 - PNG_QUANTIZE_GREEN_BITS));
+ int b = (palette[i].blue >> (8 - PNG_QUANTIZE_BLUE_BITS));
+
+ for (ir = 0; ir < num_red; ir++)
+ {
+ /* int dr = abs(ir - r); */
+ int dr = ((ir > r) ? ir - r : r - ir);
+ int index_r = (ir << (PNG_QUANTIZE_BLUE_BITS +
+ PNG_QUANTIZE_GREEN_BITS));
+
+ for (ig = 0; ig < num_green; ig++)
+ {
+ /* int dg = abs(ig - g); */
+ int dg = ((ig > g) ? ig - g : g - ig);
+ int dt = dr + dg;
+ int dm = ((dr > dg) ? dr : dg);
+ int index_g = index_r | (ig << PNG_QUANTIZE_BLUE_BITS);
+
+ for (ib = 0; ib < num_blue; ib++)
+ {
+ int d_index = index_g | ib;
+ /* int db = abs(ib - b); */
+ int db = ((ib > b) ? ib - b : b - ib);
+ int dmax = ((dm > db) ? dm : db);
+ int d = dmax + dt + db;
+
+ if (d < (int)distance[d_index])
+ {
+ distance[d_index] = (png_byte)d;
+ png_ptr->palette_lookup[d_index] = (png_byte)i;
+ }
+ }
+ }
+ }
+ }
+
+ png_free(png_ptr, distance);
+ }
+}
+#endif /* PNG_READ_QUANTIZE_SUPPORTED */
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+void PNGFAPI
+png_set_gamma_fixed(png_structrp png_ptr, png_fixed_point scrn_gamma,
+ png_fixed_point file_gamma)
+{
+ png_debug(1, "in png_set_gamma_fixed");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ /* New in libpng-1.5.4 - reserve particular negative values as flags. */
+ scrn_gamma = translate_gamma_flags(png_ptr, scrn_gamma, 1/*screen*/);
+ file_gamma = translate_gamma_flags(png_ptr, file_gamma, 0/*file*/);
+
+ /* Checking the gamma values for being >0 was added in 1.5.4 along with the
+ * premultiplied alpha support; this actually hides an undocumented feature
+ * of the previous implementation which allowed gamma processing to be
+ * disabled in background handling. There is no evidence (so far) that this
+ * was being used; however, png_set_background itself accepted and must still
+ * accept '0' for the gamma value it takes, because it isn't always used.
+ *
+ * Since this is an API change (albeit a very minor one that removes an
+ * undocumented API feature) the following checks were only enabled in
+ * libpng-1.6.0.
+ */
+ if (file_gamma <= 0)
+ png_error(png_ptr, "invalid file gamma in png_set_gamma");
+
+ if (scrn_gamma <= 0)
+ png_error(png_ptr, "invalid screen gamma in png_set_gamma");
+
+ /* Set the gamma values unconditionally - this overrides the value in the PNG
+ * file if a gAMA chunk was present. png_set_alpha_mode provides a
+ * different, easier, way to default the file gamma.
+ */
+ png_ptr->colorspace.gamma = file_gamma;
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA;
+ png_ptr->screen_gamma = scrn_gamma;
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_gamma(png_structrp png_ptr, double scrn_gamma, double file_gamma)
+{
+ png_set_gamma_fixed(png_ptr, convert_gamma_value(png_ptr, scrn_gamma),
+ convert_gamma_value(png_ptr, file_gamma));
+}
+# endif /* FLOATING_POINT_SUPPORTED */
+#endif /* READ_GAMMA */
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+/* Expand paletted images to RGB, expand grayscale images of
+ * less than 8-bit depth to 8-bit depth, and expand tRNS chunks
+ * to alpha channels.
+ */
+void PNGAPI
+png_set_expand(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_expand");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS);
+}
+
+/* GRR 19990627: the following three functions currently are identical
+ * to png_set_expand(). However, it is entirely reasonable that someone
+ * might wish to expand an indexed image to RGB but *not* expand a single,
+ * fully transparent palette entry to a full alpha channel--perhaps instead
+ * convert tRNS to the grayscale/RGB format (16-bit RGB value), or replace
+ * the transparent color with a particular RGB value, or drop tRNS entirely.
+ * IOW, a future version of the library may make the transformations flag
+ * a bit more fine-grained, with separate bits for each of these three
+ * functions.
+ *
+ * More to the point, these functions make it obvious what libpng will be
+ * doing, whereas "expand" can (and does) mean any number of things.
+ *
+ * GRP 20060307: In libpng-1.2.9, png_set_gray_1_2_4_to_8() was modified
+ * to expand only the sample depth but not to expand the tRNS to alpha
+ * and its name was changed to png_set_expand_gray_1_2_4_to_8().
+ */
+
+/* Expand paletted images to RGB. */
+void PNGAPI
+png_set_palette_to_rgb(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_palette_to_rgb");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS);
+}
+
+/* Expand grayscale images of less than 8-bit depth to 8 bits. */
+void PNGAPI
+png_set_expand_gray_1_2_4_to_8(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_expand_gray_1_2_4_to_8");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= PNG_EXPAND;
+}
+
+/* Expand tRNS chunks to alpha channels. */
+void PNGAPI
+png_set_tRNS_to_alpha(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_tRNS_to_alpha");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS);
+}
+#endif /* defined(PNG_READ_EXPAND_SUPPORTED) */
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+/* Expand to 16-bit channels, expand the tRNS chunk too (because otherwise
+ * it may not work correctly.)
+ */
+void PNGAPI
+png_set_expand_16(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_expand_16");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ png_ptr->transformations |= (PNG_EXPAND_16 | PNG_EXPAND | PNG_EXPAND_tRNS);
+}
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+void PNGAPI
+png_set_gray_to_rgb(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_gray_to_rgb");
+
+ if (!png_rtran_ok(png_ptr, 0))
+ return;
+
+ /* Because rgb must be 8 bits or more: */
+ png_set_expand_gray_1_2_4_to_8(png_ptr);
+ png_ptr->transformations |= PNG_GRAY_TO_RGB;
+}
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+void PNGFAPI
+png_set_rgb_to_gray_fixed(png_structrp png_ptr, int error_action,
+ png_fixed_point red, png_fixed_point green)
+{
+ png_debug(1, "in png_set_rgb_to_gray");
+
+ /* Need the IHDR here because of the check on color_type below. */
+ /* TODO: fix this */
+ if (!png_rtran_ok(png_ptr, 1))
+ return;
+
+ switch(error_action)
+ {
+ case PNG_ERROR_ACTION_NONE:
+ png_ptr->transformations |= PNG_RGB_TO_GRAY;
+ break;
+
+ case PNG_ERROR_ACTION_WARN:
+ png_ptr->transformations |= PNG_RGB_TO_GRAY_WARN;
+ break;
+
+ case PNG_ERROR_ACTION_ERROR:
+ png_ptr->transformations |= PNG_RGB_TO_GRAY_ERR;
+ break;
+
+ default:
+ png_error(png_ptr, "invalid error action to rgb_to_gray");
+ break;
+ }
+
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ png_ptr->transformations |= PNG_EXPAND;
+#else
+ {
+ /* Make this an error in 1.6 because otherwise the application may assume
+ * that it just worked and get a memory overwrite.
+ */
+ png_error(png_ptr,
+ "Cannot do RGB_TO_GRAY without EXPAND_SUPPORTED");
+
+ /* png_ptr->transformations &= ~PNG_RGB_TO_GRAY; */
+ }
+#endif
+ {
+ if (red >= 0 && green >= 0 && red + green <= PNG_FP_1)
+ {
+ png_uint_16 red_int, green_int;
+
+ /* NOTE: this calculation does not round, but this behavior is retained
+ * for consistency, the inaccuracy is very small. The code here always
+ * overwrites the coefficients, regardless of whether they have been
+ * defaulted or set already.
+ */
+ red_int = (png_uint_16)(((png_uint_32)red*32768)/100000);
+ green_int = (png_uint_16)(((png_uint_32)green*32768)/100000);
+
+ png_ptr->rgb_to_gray_red_coeff = red_int;
+ png_ptr->rgb_to_gray_green_coeff = green_int;
+ png_ptr->rgb_to_gray_coefficients_set = 1;
+ }
+
+ else
+ {
+ if (red >= 0 && green >= 0)
+ png_app_warning(png_ptr,
+ "ignoring out of range rgb_to_gray coefficients");
+
+ /* Use the defaults, from the cHRM chunk if set, else the historical
+ * values which are close to the sRGB/HDTV/ITU-Rec 709 values. See
+ * png_do_rgb_to_gray for more discussion of the values. In this case
+ * the coefficients are not marked as 'set' and are not overwritten if
+ * something has already provided a default.
+ */
+ if (png_ptr->rgb_to_gray_red_coeff == 0 &&
+ png_ptr->rgb_to_gray_green_coeff == 0)
+ {
+ png_ptr->rgb_to_gray_red_coeff = 6968;
+ png_ptr->rgb_to_gray_green_coeff = 23434;
+ /* png_ptr->rgb_to_gray_blue_coeff = 2366; */
+ }
+ }
+ }
+}
+
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+/* Convert a RGB image to a grayscale of the same width. This allows us,
+ * for example, to convert a 24 bpp RGB image into an 8 bpp grayscale image.
+ */
+
+void PNGAPI
+png_set_rgb_to_gray(png_structrp png_ptr, int error_action, double red,
+ double green)
+{
+ png_set_rgb_to_gray_fixed(png_ptr, error_action,
+ png_fixed(png_ptr, red, "rgb to gray red coefficient"),
+ png_fixed(png_ptr, green, "rgb to gray green coefficient"));
+}
+#endif /* FLOATING POINT */
+
+#endif /* RGB_TO_GRAY */
+
+#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \
+ defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED)
+void PNGAPI
+png_set_read_user_transform_fn(png_structrp png_ptr, png_user_transform_ptr
+ read_user_transform_fn)
+{
+ png_debug(1, "in png_set_read_user_transform_fn");
+
+#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED
+ png_ptr->transformations |= PNG_USER_TRANSFORM;
+ png_ptr->read_user_transform_fn = read_user_transform_fn;
+#endif
+}
+#endif
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* In the case of gamma transformations only do transformations on images where
+ * the [file] gamma and screen_gamma are not close reciprocals, otherwise it
+ * slows things down slightly, and also needlessly introduces small errors.
+ */
+static int /* PRIVATE */
+png_gamma_threshold(png_fixed_point screen_gamma, png_fixed_point file_gamma)
+{
+ /* PNG_GAMMA_THRESHOLD is the threshold for performing gamma
+ * correction as a difference of the overall transform from 1.0
+ *
+ * We want to compare the threshold with s*f - 1, if we get
+ * overflow here it is because of wacky gamma values so we
+ * turn on processing anyway.
+ */
+ png_fixed_point gtest;
+ return !png_muldiv(&gtest, screen_gamma, file_gamma, PNG_FP_1) ||
+ png_gamma_significant(gtest);
+}
+#endif
+
+/* Initialize everything needed for the read. This includes modifying
+ * the palette.
+ */
+
+/*For the moment 'png_init_palette_transformations' and
+ * 'png_init_rgb_transformations' only do some flag canceling optimizations.
+ * The intent is that these two routines should have palette or rgb operations
+ * extracted from 'png_init_read_transformations'.
+ */
+static void /* PRIVATE */
+png_init_palette_transformations(png_structrp png_ptr)
+{
+ /* Called to handle the (input) palette case. In png_do_read_transformations
+ * the first step is to expand the palette if requested, so this code must
+ * take care to only make changes that are invariant with respect to the
+ * palette expansion, or only do them if there is no expansion.
+ *
+ * STRIP_ALPHA has already been handled in the caller (by setting num_trans
+ * to 0.)
+ */
+ int input_has_alpha = 0;
+ int input_has_transparency = 0;
+
+ if (png_ptr->num_trans > 0)
+ {
+ int i;
+
+ /* Ignore if all the entries are opaque (unlikely!) */
+ for (i=0; i<png_ptr->num_trans; ++i)
+ {
+ if (png_ptr->trans_alpha[i] == 255)
+ continue;
+ else if (png_ptr->trans_alpha[i] == 0)
+ input_has_transparency = 1;
+ else
+ {
+ input_has_transparency = 1;
+ input_has_alpha = 1;
+ break;
+ }
+ }
+ }
+
+ /* If no alpha we can optimize. */
+ if (!input_has_alpha)
+ {
+ /* Any alpha means background and associative alpha processing is
+ * required, however if the alpha is 0 or 1 throughout OPTIIMIZE_ALPHA
+ * and ENCODE_ALPHA are irrelevant.
+ */
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+
+ if (!input_has_transparency)
+ png_ptr->transformations &= ~(PNG_COMPOSE | PNG_BACKGROUND_EXPAND);
+ }
+
+#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED)
+ /* png_set_background handling - deals with the complexity of whether the
+ * background color is in the file format or the screen format in the case
+ * where an 'expand' will happen.
+ */
+
+ /* The following code cannot be entered in the alpha pre-multiplication case
+ * because PNG_BACKGROUND_EXPAND is cancelled below.
+ */
+ if ((png_ptr->transformations & PNG_BACKGROUND_EXPAND) &&
+ (png_ptr->transformations & PNG_EXPAND))
+ {
+ {
+ png_ptr->background.red =
+ png_ptr->palette[png_ptr->background.index].red;
+ png_ptr->background.green =
+ png_ptr->palette[png_ptr->background.index].green;
+ png_ptr->background.blue =
+ png_ptr->palette[png_ptr->background.index].blue;
+
+#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_INVERT_ALPHA)
+ {
+ if (!(png_ptr->transformations & PNG_EXPAND_tRNS))
+ {
+ /* Invert the alpha channel (in tRNS) unless the pixels are
+ * going to be expanded, in which case leave it for later
+ */
+ int i, istop = png_ptr->num_trans;
+
+ for (i=0; i<istop; i++)
+ png_ptr->trans_alpha[i] = (png_byte)(255 -
+ png_ptr->trans_alpha[i]);
+ }
+ }
+#endif /* PNG_READ_INVERT_ALPHA_SUPPORTED */
+ }
+ } /* background expand and (therefore) no alpha association. */
+#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */
+}
+
+static void /* PRIVATE */
+png_init_rgb_transformations(png_structrp png_ptr)
+{
+ /* Added to libpng-1.5.4: check the color type to determine whether there
+ * is any alpha or transparency in the image and simply cancel the
+ * background and alpha mode stuff if there isn't.
+ */
+ int input_has_alpha = (png_ptr->color_type & PNG_COLOR_MASK_ALPHA) != 0;
+ int input_has_transparency = png_ptr->num_trans > 0;
+
+ /* If no alpha we can optimize. */
+ if (!input_has_alpha)
+ {
+ /* Any alpha means background and associative alpha processing is
+ * required, however if the alpha is 0 or 1 throughout OPTIIMIZE_ALPHA
+ * and ENCODE_ALPHA are irrelevant.
+ */
+# ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+# endif
+
+ if (!input_has_transparency)
+ png_ptr->transformations &= ~(PNG_COMPOSE | PNG_BACKGROUND_EXPAND);
+ }
+
+#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED)
+ /* png_set_background handling - deals with the complexity of whether the
+ * background color is in the file format or the screen format in the case
+ * where an 'expand' will happen.
+ */
+
+ /* The following code cannot be entered in the alpha pre-multiplication case
+ * because PNG_BACKGROUND_EXPAND is cancelled below.
+ */
+ if ((png_ptr->transformations & PNG_BACKGROUND_EXPAND) &&
+ (png_ptr->transformations & PNG_EXPAND) &&
+ !(png_ptr->color_type & PNG_COLOR_MASK_COLOR))
+ /* i.e., GRAY or GRAY_ALPHA */
+ {
+ {
+ /* Expand background and tRNS chunks */
+ int gray = png_ptr->background.gray;
+ int trans_gray = png_ptr->trans_color.gray;
+
+ switch (png_ptr->bit_depth)
+ {
+ case 1:
+ gray *= 0xff;
+ trans_gray *= 0xff;
+ break;
+
+ case 2:
+ gray *= 0x55;
+ trans_gray *= 0x55;
+ break;
+
+ case 4:
+ gray *= 0x11;
+ trans_gray *= 0x11;
+ break;
+
+ default:
+
+ case 8:
+ /* FALL THROUGH (Already 8 bits) */
+
+ case 16:
+ /* Already a full 16 bits */
+ break;
+ }
+
+ png_ptr->background.red = png_ptr->background.green =
+ png_ptr->background.blue = (png_uint_16)gray;
+
+ if (!(png_ptr->transformations & PNG_EXPAND_tRNS))
+ {
+ png_ptr->trans_color.red = png_ptr->trans_color.green =
+ png_ptr->trans_color.blue = (png_uint_16)trans_gray;
+ }
+ }
+ } /* background expand and (therefore) no alpha association. */
+#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */
+}
+
+void /* PRIVATE */
+png_init_read_transformations(png_structrp png_ptr)
+{
+ png_debug(1, "in png_init_read_transformations");
+
+ /* This internal function is called from png_read_start_row in pngrutil.c
+ * and it is called before the 'rowbytes' calculation is done, so the code
+ * in here can change or update the transformations flags.
+ *
+ * First do updates that do not depend on the details of the PNG image data
+ * being processed.
+ */
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ /* Prior to 1.5.4 these tests were performed from png_set_gamma, 1.5.4 adds
+ * png_set_alpha_mode and this is another source for a default file gamma so
+ * the test needs to be performed later - here. In addition prior to 1.5.4
+ * the tests were repeated for the PALETTE color type here - this is no
+ * longer necessary (and doesn't seem to have been necessary before.)
+ */
+ {
+ /* The following temporary indicates if overall gamma correction is
+ * required.
+ */
+ int gamma_correction = 0;
+
+ if (png_ptr->colorspace.gamma != 0) /* has been set */
+ {
+ if (png_ptr->screen_gamma != 0) /* screen set too */
+ gamma_correction = png_gamma_threshold(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma);
+
+ else
+ /* Assume the output matches the input; a long time default behavior
+ * of libpng, although the standard has nothing to say about this.
+ */
+ png_ptr->screen_gamma = png_reciprocal(png_ptr->colorspace.gamma);
+ }
+
+ else if (png_ptr->screen_gamma != 0)
+ /* The converse - assume the file matches the screen, note that this
+ * perhaps undesireable default can (from 1.5.4) be changed by calling
+ * png_set_alpha_mode (even if the alpha handling mode isn't required
+ * or isn't changed from the default.)
+ */
+ png_ptr->colorspace.gamma = png_reciprocal(png_ptr->screen_gamma);
+
+ else /* neither are set */
+ /* Just in case the following prevents any processing - file and screen
+ * are both assumed to be linear and there is no way to introduce a
+ * third gamma value other than png_set_background with 'UNIQUE', and,
+ * prior to 1.5.4
+ */
+ png_ptr->screen_gamma = png_ptr->colorspace.gamma = PNG_FP_1;
+
+ /* We have a gamma value now. */
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA;
+
+ /* Now turn the gamma transformation on or off as appropriate. Notice
+ * that PNG_GAMMA just refers to the file->screen correction. Alpha
+ * composition may independently cause gamma correction because it needs
+ * linear data (e.g. if the file has a gAMA chunk but the screen gamma
+ * hasn't been specified.) In any case this flag may get turned off in
+ * the code immediately below if the transform can be handled outside the
+ * row loop.
+ */
+ if (gamma_correction)
+ png_ptr->transformations |= PNG_GAMMA;
+
+ else
+ png_ptr->transformations &= ~PNG_GAMMA;
+ }
+#endif
+
+ /* Certain transformations have the effect of preventing other
+ * transformations that happen afterward in png_do_read_transformations,
+ * resolve the interdependencies here. From the code of
+ * png_do_read_transformations the order is:
+ *
+ * 1) PNG_EXPAND (including PNG_EXPAND_tRNS)
+ * 2) PNG_STRIP_ALPHA (if no compose)
+ * 3) PNG_RGB_TO_GRAY
+ * 4) PNG_GRAY_TO_RGB iff !PNG_BACKGROUND_IS_GRAY
+ * 5) PNG_COMPOSE
+ * 6) PNG_GAMMA
+ * 7) PNG_STRIP_ALPHA (if compose)
+ * 8) PNG_ENCODE_ALPHA
+ * 9) PNG_SCALE_16_TO_8
+ * 10) PNG_16_TO_8
+ * 11) PNG_QUANTIZE (converts to palette)
+ * 12) PNG_EXPAND_16
+ * 13) PNG_GRAY_TO_RGB iff PNG_BACKGROUND_IS_GRAY
+ * 14) PNG_INVERT_MONO
+ * 15) PNG_SHIFT
+ * 16) PNG_PACK
+ * 17) PNG_BGR
+ * 18) PNG_PACKSWAP
+ * 19) PNG_FILLER (includes PNG_ADD_ALPHA)
+ * 20) PNG_INVERT_ALPHA
+ * 21) PNG_SWAP_ALPHA
+ * 22) PNG_SWAP_BYTES
+ * 23) PNG_USER_TRANSFORM [must be last]
+ */
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+ if ((png_ptr->transformations & PNG_STRIP_ALPHA) &&
+ !(png_ptr->transformations & PNG_COMPOSE))
+ {
+ /* Stripping the alpha channel happens immediately after the 'expand'
+ * transformations, before all other transformation, so it cancels out
+ * the alpha handling. It has the side effect negating the effect of
+ * PNG_EXPAND_tRNS too:
+ */
+ png_ptr->transformations &= ~(PNG_BACKGROUND_EXPAND | PNG_ENCODE_ALPHA |
+ PNG_EXPAND_tRNS);
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+
+ /* Kill the tRNS chunk itself too. Prior to 1.5.4 this did not happen
+ * so transparency information would remain just so long as it wasn't
+ * expanded. This produces unexpected API changes if the set of things
+ * that do PNG_EXPAND_tRNS changes (perfectly possible given the
+ * documentation - which says ask for what you want, accept what you
+ * get.) This makes the behavior consistent from 1.5.4:
+ */
+ png_ptr->num_trans = 0;
+ }
+#endif /* STRIP_ALPHA supported, no COMPOSE */
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+ /* If the screen gamma is about 1.0 then the OPTIMIZE_ALPHA and ENCODE_ALPHA
+ * settings will have no effect.
+ */
+ if (!png_gamma_significant(png_ptr->screen_gamma))
+ {
+ png_ptr->transformations &= ~PNG_ENCODE_ALPHA;
+ png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA;
+ }
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ /* Make sure the coefficients for the rgb to gray conversion are set
+ * appropriately.
+ */
+ if (png_ptr->transformations & PNG_RGB_TO_GRAY)
+ png_colorspace_set_rgb_coefficients(png_ptr);
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED)
+ /* Detect gray background and attempt to enable optimization for
+ * gray --> RGB case.
+ *
+ * Note: if PNG_BACKGROUND_EXPAND is set and color_type is either RGB or
+ * RGB_ALPHA (in which case need_expand is superfluous anyway), the
+ * background color might actually be gray yet not be flagged as such.
+ * This is not a problem for the current code, which uses
+ * PNG_BACKGROUND_IS_GRAY only to decide when to do the
+ * png_do_gray_to_rgb() transformation.
+ *
+ * TODO: this code needs to be revised to avoid the complexity and
+ * interdependencies. The color type of the background should be recorded in
+ * png_set_background, along with the bit depth, then the code has a record
+ * of exactly what color space the background is currently in.
+ */
+ if (png_ptr->transformations & PNG_BACKGROUND_EXPAND)
+ {
+ /* PNG_BACKGROUND_EXPAND: the background is in the file color space, so if
+ * the file was grayscale the background value is gray.
+ */
+ if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR))
+ png_ptr->mode |= PNG_BACKGROUND_IS_GRAY;
+ }
+
+ else if (png_ptr->transformations & PNG_COMPOSE)
+ {
+ /* PNG_COMPOSE: png_set_background was called with need_expand false,
+ * so the color is in the color space of the output or png_set_alpha_mode
+ * was called and the color is black. Ignore RGB_TO_GRAY because that
+ * happens before GRAY_TO_RGB.
+ */
+ if (png_ptr->transformations & PNG_GRAY_TO_RGB)
+ {
+ if (png_ptr->background.red == png_ptr->background.green &&
+ png_ptr->background.red == png_ptr->background.blue)
+ {
+ png_ptr->mode |= PNG_BACKGROUND_IS_GRAY;
+ png_ptr->background.gray = png_ptr->background.red;
+ }
+ }
+ }
+#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */
+#endif /* PNG_READ_GRAY_TO_RGB_SUPPORTED */
+
+ /* For indexed PNG data (PNG_COLOR_TYPE_PALETTE) many of the transformations
+ * can be performed directly on the palette, and some (such as rgb to gray)
+ * can be optimized inside the palette. This is particularly true of the
+ * composite (background and alpha) stuff, which can be pretty much all done
+ * in the palette even if the result is expanded to RGB or gray afterward.
+ *
+ * NOTE: this is Not Yet Implemented, the code behaves as in 1.5.1 and
+ * earlier and the palette stuff is actually handled on the first row. This
+ * leads to the reported bug that the palette returned by png_get_PLTE is not
+ * updated.
+ */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ png_init_palette_transformations(png_ptr);
+
+ else
+ png_init_rgb_transformations(png_ptr);
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) && \
+ defined(PNG_READ_EXPAND_16_SUPPORTED)
+ if ((png_ptr->transformations & PNG_EXPAND_16) &&
+ (png_ptr->transformations & PNG_COMPOSE) &&
+ !(png_ptr->transformations & PNG_BACKGROUND_EXPAND) &&
+ png_ptr->bit_depth != 16)
+ {
+ /* TODO: fix this. Because the expand_16 operation is after the compose
+ * handling the background color must be 8, not 16, bits deep, but the
+ * application will supply a 16-bit value so reduce it here.
+ *
+ * The PNG_BACKGROUND_EXPAND code above does not expand to 16 bits at
+ * present, so that case is ok (until do_expand_16 is moved.)
+ *
+ * NOTE: this discards the low 16 bits of the user supplied background
+ * color, but until expand_16 works properly there is no choice!
+ */
+# define CHOP(x) (x)=((png_uint_16)PNG_DIV257(x))
+ CHOP(png_ptr->background.red);
+ CHOP(png_ptr->background.green);
+ CHOP(png_ptr->background.blue);
+ CHOP(png_ptr->background.gray);
+# undef CHOP
+ }
+#endif /* PNG_READ_BACKGROUND_SUPPORTED && PNG_READ_EXPAND_16_SUPPORTED */
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) && \
+ (defined(PNG_READ_SCALE_16_TO_8_SUPPORTED) || \
+ defined(PNG_READ_STRIP_16_TO_8_SUPPORTED))
+ if ((png_ptr->transformations & (PNG_16_TO_8|PNG_SCALE_16_TO_8)) &&
+ (png_ptr->transformations & PNG_COMPOSE) &&
+ !(png_ptr->transformations & PNG_BACKGROUND_EXPAND) &&
+ png_ptr->bit_depth == 16)
+ {
+ /* On the other hand, if a 16-bit file is to be reduced to 8-bits per
+ * component this will also happen after PNG_COMPOSE and so the background
+ * color must be pre-expanded here.
+ *
+ * TODO: fix this too.
+ */
+ png_ptr->background.red = (png_uint_16)(png_ptr->background.red * 257);
+ png_ptr->background.green =
+ (png_uint_16)(png_ptr->background.green * 257);
+ png_ptr->background.blue = (png_uint_16)(png_ptr->background.blue * 257);
+ png_ptr->background.gray = (png_uint_16)(png_ptr->background.gray * 257);
+ }
+#endif
+
+ /* NOTE: below 'PNG_READ_ALPHA_MODE_SUPPORTED' is presumed to also enable the
+ * background support (see the comments in scripts/pnglibconf.dfa), this
+ * allows pre-multiplication of the alpha channel to be implemented as
+ * compositing on black. This is probably sub-optimal and has been done in
+ * 1.5.4 betas simply to enable external critique and testing (i.e. to
+ * implement the new API quickly, without lots of internal changes.)
+ */
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+# ifdef PNG_READ_BACKGROUND_SUPPORTED
+ /* Includes ALPHA_MODE */
+ png_ptr->background_1 = png_ptr->background;
+# endif
+
+ /* This needs to change - in the palette image case a whole set of tables are
+ * built when it would be quicker to just calculate the correct value for
+ * each palette entry directly. Also, the test is too tricky - why check
+ * PNG_RGB_TO_GRAY if PNG_GAMMA is not set? The answer seems to be that
+ * PNG_GAMMA is cancelled even if the gamma is known? The test excludes the
+ * PNG_COMPOSE case, so apparently if there is no *overall* gamma correction
+ * the gamma tables will not be built even if composition is required on a
+ * gamma encoded value.
+ *
+ * In 1.5.4 this is addressed below by an additional check on the individual
+ * file gamma - if it is not 1.0 both RGB_TO_GRAY and COMPOSE need the
+ * tables.
+ */
+ if ((png_ptr->transformations & PNG_GAMMA)
+ || ((png_ptr->transformations & PNG_RGB_TO_GRAY)
+ && (png_gamma_significant(png_ptr->colorspace.gamma) ||
+ png_gamma_significant(png_ptr->screen_gamma)))
+ || ((png_ptr->transformations & PNG_COMPOSE)
+ && (png_gamma_significant(png_ptr->colorspace.gamma)
+ || png_gamma_significant(png_ptr->screen_gamma)
+# ifdef PNG_READ_BACKGROUND_SUPPORTED
+ || (png_ptr->background_gamma_type == PNG_BACKGROUND_GAMMA_UNIQUE
+ && png_gamma_significant(png_ptr->background_gamma))
+# endif
+ )) || ((png_ptr->transformations & PNG_ENCODE_ALPHA)
+ && png_gamma_significant(png_ptr->screen_gamma))
+ )
+ {
+ png_build_gamma_table(png_ptr, png_ptr->bit_depth);
+
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+ if (png_ptr->transformations & PNG_COMPOSE)
+ {
+ /* Issue a warning about this combination: because RGB_TO_GRAY is
+ * optimized to do the gamma transform if present yet do_background has
+ * to do the same thing if both options are set a
+ * double-gamma-correction happens. This is true in all versions of
+ * libpng to date.
+ */
+ if (png_ptr->transformations & PNG_RGB_TO_GRAY)
+ png_warning(png_ptr,
+ "libpng does not support gamma+background+rgb_to_gray");
+
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ /* We don't get to here unless there is a tRNS chunk with non-opaque
+ * entries - see the checking code at the start of this function.
+ */
+ png_color back, back_1;
+ png_colorp palette = png_ptr->palette;
+ int num_palette = png_ptr->num_palette;
+ int i;
+ if (png_ptr->background_gamma_type == PNG_BACKGROUND_GAMMA_FILE)
+ {
+
+ back.red = png_ptr->gamma_table[png_ptr->background.red];
+ back.green = png_ptr->gamma_table[png_ptr->background.green];
+ back.blue = png_ptr->gamma_table[png_ptr->background.blue];
+
+ back_1.red = png_ptr->gamma_to_1[png_ptr->background.red];
+ back_1.green = png_ptr->gamma_to_1[png_ptr->background.green];
+ back_1.blue = png_ptr->gamma_to_1[png_ptr->background.blue];
+ }
+ else
+ {
+ png_fixed_point g, gs;
+
+ switch (png_ptr->background_gamma_type)
+ {
+ case PNG_BACKGROUND_GAMMA_SCREEN:
+ g = (png_ptr->screen_gamma);
+ gs = PNG_FP_1;
+ break;
+
+ case PNG_BACKGROUND_GAMMA_FILE:
+ g = png_reciprocal(png_ptr->colorspace.gamma);
+ gs = png_reciprocal2(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma);
+ break;
+
+ case PNG_BACKGROUND_GAMMA_UNIQUE:
+ g = png_reciprocal(png_ptr->background_gamma);
+ gs = png_reciprocal2(png_ptr->background_gamma,
+ png_ptr->screen_gamma);
+ break;
+ default:
+ g = PNG_FP_1; /* back_1 */
+ gs = PNG_FP_1; /* back */
+ break;
+ }
+
+ if (png_gamma_significant(gs))
+ {
+ back.red = png_gamma_8bit_correct(png_ptr->background.red,
+ gs);
+ back.green = png_gamma_8bit_correct(png_ptr->background.green,
+ gs);
+ back.blue = png_gamma_8bit_correct(png_ptr->background.blue,
+ gs);
+ }
+
+ else
+ {
+ back.red = (png_byte)png_ptr->background.red;
+ back.green = (png_byte)png_ptr->background.green;
+ back.blue = (png_byte)png_ptr->background.blue;
+ }
+
+ if (png_gamma_significant(g))
+ {
+ back_1.red = png_gamma_8bit_correct(png_ptr->background.red,
+ g);
+ back_1.green = png_gamma_8bit_correct(
+ png_ptr->background.green, g);
+ back_1.blue = png_gamma_8bit_correct(png_ptr->background.blue,
+ g);
+ }
+
+ else
+ {
+ back_1.red = (png_byte)png_ptr->background.red;
+ back_1.green = (png_byte)png_ptr->background.green;
+ back_1.blue = (png_byte)png_ptr->background.blue;
+ }
+ }
+
+ for (i = 0; i < num_palette; i++)
+ {
+ if (i < (int)png_ptr->num_trans &&
+ png_ptr->trans_alpha[i] != 0xff)
+ {
+ if (png_ptr->trans_alpha[i] == 0)
+ {
+ palette[i] = back;
+ }
+ else /* if (png_ptr->trans_alpha[i] != 0xff) */
+ {
+ png_byte v, w;
+
+ v = png_ptr->gamma_to_1[palette[i].red];
+ png_composite(w, v, png_ptr->trans_alpha[i], back_1.red);
+ palette[i].red = png_ptr->gamma_from_1[w];
+
+ v = png_ptr->gamma_to_1[palette[i].green];
+ png_composite(w, v, png_ptr->trans_alpha[i], back_1.green);
+ palette[i].green = png_ptr->gamma_from_1[w];
+
+ v = png_ptr->gamma_to_1[palette[i].blue];
+ png_composite(w, v, png_ptr->trans_alpha[i], back_1.blue);
+ palette[i].blue = png_ptr->gamma_from_1[w];
+ }
+ }
+ else
+ {
+ palette[i].red = png_ptr->gamma_table[palette[i].red];
+ palette[i].green = png_ptr->gamma_table[palette[i].green];
+ palette[i].blue = png_ptr->gamma_table[palette[i].blue];
+ }
+ }
+
+ /* Prevent the transformations being done again.
+ *
+ * NOTE: this is highly dubious; it removes the transformations in
+ * place. This seems inconsistent with the general treatment of the
+ * transformations elsewhere.
+ */
+ png_ptr->transformations &= ~(PNG_COMPOSE | PNG_GAMMA);
+ } /* color_type == PNG_COLOR_TYPE_PALETTE */
+
+ /* if (png_ptr->background_gamma_type!=PNG_BACKGROUND_GAMMA_UNKNOWN) */
+ else /* color_type != PNG_COLOR_TYPE_PALETTE */
+ {
+ int gs_sig, g_sig;
+ png_fixed_point g = PNG_FP_1; /* Correction to linear */
+ png_fixed_point gs = PNG_FP_1; /* Correction to screen */
+
+ switch (png_ptr->background_gamma_type)
+ {
+ case PNG_BACKGROUND_GAMMA_SCREEN:
+ g = png_ptr->screen_gamma;
+ /* gs = PNG_FP_1; */
+ break;
+
+ case PNG_BACKGROUND_GAMMA_FILE:
+ g = png_reciprocal(png_ptr->colorspace.gamma);
+ gs = png_reciprocal2(png_ptr->colorspace.gamma,
+ png_ptr->screen_gamma);
+ break;
+
+ case PNG_BACKGROUND_GAMMA_UNIQUE:
+ g = png_reciprocal(png_ptr->background_gamma);
+ gs = png_reciprocal2(png_ptr->background_gamma,
+ png_ptr->screen_gamma);
+ break;
+
+ default:
+ png_error(png_ptr, "invalid background gamma type");
+ }
+
+ g_sig = png_gamma_significant(g);
+ gs_sig = png_gamma_significant(gs);
+
+ if (g_sig)
+ png_ptr->background_1.gray = png_gamma_correct(png_ptr,
+ png_ptr->background.gray, g);
+
+ if (gs_sig)
+ png_ptr->background.gray = png_gamma_correct(png_ptr,
+ png_ptr->background.gray, gs);
+
+ if ((png_ptr->background.red != png_ptr->background.green) ||
+ (png_ptr->background.red != png_ptr->background.blue) ||
+ (png_ptr->background.red != png_ptr->background.gray))
+ {
+ /* RGB or RGBA with color background */
+ if (g_sig)
+ {
+ png_ptr->background_1.red = png_gamma_correct(png_ptr,
+ png_ptr->background.red, g);
+
+ png_ptr->background_1.green = png_gamma_correct(png_ptr,
+ png_ptr->background.green, g);
+
+ png_ptr->background_1.blue = png_gamma_correct(png_ptr,
+ png_ptr->background.blue, g);
+ }
+
+ if (gs_sig)
+ {
+ png_ptr->background.red = png_gamma_correct(png_ptr,
+ png_ptr->background.red, gs);
+
+ png_ptr->background.green = png_gamma_correct(png_ptr,
+ png_ptr->background.green, gs);
+
+ png_ptr->background.blue = png_gamma_correct(png_ptr,
+ png_ptr->background.blue, gs);
+ }
+ }
+
+ else
+ {
+ /* GRAY, GRAY ALPHA, RGB, or RGBA with gray background */
+ png_ptr->background_1.red = png_ptr->background_1.green
+ = png_ptr->background_1.blue = png_ptr->background_1.gray;
+
+ png_ptr->background.red = png_ptr->background.green
+ = png_ptr->background.blue = png_ptr->background.gray;
+ }
+
+ /* The background is now in screen gamma: */
+ png_ptr->background_gamma_type = PNG_BACKGROUND_GAMMA_SCREEN;
+ } /* color_type != PNG_COLOR_TYPE_PALETTE */
+ }/* png_ptr->transformations & PNG_BACKGROUND */
+
+ else
+ /* Transformation does not include PNG_BACKGROUND */
+#endif /* PNG_READ_BACKGROUND_SUPPORTED */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ /* RGB_TO_GRAY needs to have non-gamma-corrected values! */
+ && ((png_ptr->transformations & PNG_EXPAND) == 0 ||
+ (png_ptr->transformations & PNG_RGB_TO_GRAY) == 0)
+#endif
+ )
+ {
+ png_colorp palette = png_ptr->palette;
+ int num_palette = png_ptr->num_palette;
+ int i;
+
+ /* NOTE: there are other transformations that should probably be in
+ * here too.
+ */
+ for (i = 0; i < num_palette; i++)
+ {
+ palette[i].red = png_ptr->gamma_table[palette[i].red];
+ palette[i].green = png_ptr->gamma_table[palette[i].green];
+ palette[i].blue = png_ptr->gamma_table[palette[i].blue];
+ }
+
+ /* Done the gamma correction. */
+ png_ptr->transformations &= ~PNG_GAMMA;
+ } /* color_type == PALETTE && !PNG_BACKGROUND transformation */
+ }
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+ else
+#endif
+#endif /* PNG_READ_GAMMA_SUPPORTED */
+
+#ifdef PNG_READ_BACKGROUND_SUPPORTED
+ /* No GAMMA transformation (see the hanging else 4 lines above) */
+ if ((png_ptr->transformations & PNG_COMPOSE) &&
+ (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE))
+ {
+ int i;
+ int istop = (int)png_ptr->num_trans;
+ png_color back;
+ png_colorp palette = png_ptr->palette;
+
+ back.red = (png_byte)png_ptr->background.red;
+ back.green = (png_byte)png_ptr->background.green;
+ back.blue = (png_byte)png_ptr->background.blue;
+
+ for (i = 0; i < istop; i++)
+ {
+ if (png_ptr->trans_alpha[i] == 0)
+ {
+ palette[i] = back;
+ }
+
+ else if (png_ptr->trans_alpha[i] != 0xff)
+ {
+ /* The png_composite() macro is defined in png.h */
+ png_composite(palette[i].red, palette[i].red,
+ png_ptr->trans_alpha[i], back.red);
+
+ png_composite(palette[i].green, palette[i].green,
+ png_ptr->trans_alpha[i], back.green);
+
+ png_composite(palette[i].blue, palette[i].blue,
+ png_ptr->trans_alpha[i], back.blue);
+ }
+ }
+
+ png_ptr->transformations &= ~PNG_COMPOSE;
+ }
+#endif /* PNG_READ_BACKGROUND_SUPPORTED */
+
+#ifdef PNG_READ_SHIFT_SUPPORTED
+ if ((png_ptr->transformations & PNG_SHIFT) &&
+ !(png_ptr->transformations & PNG_EXPAND) &&
+ (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE))
+ {
+ int i;
+ int istop = png_ptr->num_palette;
+ int shift = 8 - png_ptr->sig_bit.red;
+
+ png_ptr->transformations &= ~PNG_SHIFT;
+
+ /* significant bits can be in the range 1 to 7 for a meaninful result, if
+ * the number of significant bits is 0 then no shift is done (this is an
+ * error condition which is silently ignored.)
+ */
+ if (shift > 0 && shift < 8)
+ for (i=0; i<istop; ++i)
+ {
+ int component = png_ptr->palette[i].red;
+
+ component >>= shift;
+ png_ptr->palette[i].red = (png_byte)component;
+ }
+
+ shift = 8 - png_ptr->sig_bit.green;
+ if (shift > 0 && shift < 8)
+ for (i=0; i<istop; ++i)
+ {
+ int component = png_ptr->palette[i].green;
+
+ component >>= shift;
+ png_ptr->palette[i].green = (png_byte)component;
+ }
+
+ shift = 8 - png_ptr->sig_bit.blue;
+ if (shift > 0 && shift < 8)
+ for (i=0; i<istop; ++i)
+ {
+ int component = png_ptr->palette[i].blue;
+
+ component >>= shift;
+ png_ptr->palette[i].blue = (png_byte)component;
+ }
+ }
+#endif /* PNG_READ_SHIFT_SUPPORTED */
+}
+
+/* Modify the info structure to reflect the transformations. The
+ * info should be updated so a PNG file could be written with it,
+ * assuming the transformations result in valid PNG data.
+ */
+void /* PRIVATE */
+png_read_transform_info(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_debug(1, "in png_read_transform_info");
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ if (png_ptr->transformations & PNG_EXPAND)
+ {
+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ /* This check must match what actually happens in
+ * png_do_expand_palette; if it ever checks the tRNS chunk to see if
+ * it is all opaque we must do the same (at present it does not.)
+ */
+ if (png_ptr->num_trans > 0)
+ info_ptr->color_type = PNG_COLOR_TYPE_RGB_ALPHA;
+
+ else
+ info_ptr->color_type = PNG_COLOR_TYPE_RGB;
+
+ info_ptr->bit_depth = 8;
+ info_ptr->num_trans = 0;
+ }
+ else
+ {
+ if (png_ptr->num_trans)
+ {
+ if (png_ptr->transformations & PNG_EXPAND_tRNS)
+ info_ptr->color_type |= PNG_COLOR_MASK_ALPHA;
+ }
+ if (info_ptr->bit_depth < 8)
+ info_ptr->bit_depth = 8;
+
+ info_ptr->num_trans = 0;
+ }
+ }
+#endif
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+ /* The following is almost certainly wrong unless the background value is in
+ * the screen space!
+ */
+ if (png_ptr->transformations & PNG_COMPOSE)
+ info_ptr->background = png_ptr->background;
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ /* The following used to be conditional on PNG_GAMMA (prior to 1.5.4),
+ * however it seems that the code in png_init_read_transformations, which has
+ * been called before this from png_read_update_info->png_read_start_row
+ * sometimes does the gamma transform and cancels the flag.
+ *
+ * TODO: this looks wrong; the info_ptr should end up with a gamma equal to
+ * the screen_gamma value. The following probably results in weirdness if
+ * the info_ptr is used by the app after the rows have been read.
+ */
+ info_ptr->colorspace.gamma = png_ptr->colorspace.gamma;
+#endif
+
+ if (info_ptr->bit_depth == 16)
+ {
+# ifdef PNG_READ_16BIT_SUPPORTED
+# ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+ if (png_ptr->transformations & PNG_SCALE_16_TO_8)
+ info_ptr->bit_depth = 8;
+# endif
+
+# ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+ if (png_ptr->transformations & PNG_16_TO_8)
+ info_ptr->bit_depth = 8;
+# endif
+
+# else
+ /* No 16 bit support: force chopping 16-bit input down to 8, in this case
+ * the app program can chose if both APIs are available by setting the
+ * correct scaling to use.
+ */
+# ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+ /* For compatibility with previous versions use the strip method by
+ * default. This code works because if PNG_SCALE_16_TO_8 is already
+ * set the code below will do that in preference to the chop.
+ */
+ png_ptr->transformations |= PNG_16_TO_8;
+ info_ptr->bit_depth = 8;
+# else
+
+# ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+ png_ptr->transformations |= PNG_SCALE_16_TO_8;
+ info_ptr->bit_depth = 8;
+# else
+
+ CONFIGURATION ERROR: you must enable at least one 16 to 8 method
+# endif
+# endif
+#endif /* !READ_16BIT_SUPPORTED */
+ }
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+ if (png_ptr->transformations & PNG_GRAY_TO_RGB)
+ info_ptr->color_type = (png_byte)(info_ptr->color_type |
+ PNG_COLOR_MASK_COLOR);
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ if (png_ptr->transformations & PNG_RGB_TO_GRAY)
+ info_ptr->color_type = (png_byte)(info_ptr->color_type &
+ ~PNG_COLOR_MASK_COLOR);
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+ if (png_ptr->transformations & PNG_QUANTIZE)
+ {
+ if (((info_ptr->color_type == PNG_COLOR_TYPE_RGB) ||
+ (info_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA)) &&
+ png_ptr->palette_lookup && info_ptr->bit_depth == 8)
+ {
+ info_ptr->color_type = PNG_COLOR_TYPE_PALETTE;
+ }
+ }
+#endif
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+ if (png_ptr->transformations & PNG_EXPAND_16 && info_ptr->bit_depth == 8 &&
+ info_ptr->color_type != PNG_COLOR_TYPE_PALETTE)
+ {
+ info_ptr->bit_depth = 16;
+ }
+#endif
+
+#ifdef PNG_READ_PACK_SUPPORTED
+ if ((png_ptr->transformations & PNG_PACK) && (info_ptr->bit_depth < 8))
+ info_ptr->bit_depth = 8;
+#endif
+
+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ info_ptr->channels = 1;
+
+ else if (info_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ info_ptr->channels = 3;
+
+ else
+ info_ptr->channels = 1;
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_STRIP_ALPHA)
+ {
+ info_ptr->color_type = (png_byte)(info_ptr->color_type &
+ ~PNG_COLOR_MASK_ALPHA);
+ info_ptr->num_trans = 0;
+ }
+#endif
+
+ if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA)
+ info_ptr->channels++;
+
+#ifdef PNG_READ_FILLER_SUPPORTED
+ /* STRIP_ALPHA and FILLER allowed: MASK_ALPHA bit stripped above */
+ if ((png_ptr->transformations & PNG_FILLER) &&
+ ((info_ptr->color_type == PNG_COLOR_TYPE_RGB) ||
+ (info_ptr->color_type == PNG_COLOR_TYPE_GRAY)))
+ {
+ info_ptr->channels++;
+ /* If adding a true alpha channel not just filler */
+ if (png_ptr->transformations & PNG_ADD_ALPHA)
+ info_ptr->color_type |= PNG_COLOR_MASK_ALPHA;
+ }
+#endif
+
+#if defined(PNG_USER_TRANSFORM_PTR_SUPPORTED) && \
+defined(PNG_READ_USER_TRANSFORM_SUPPORTED)
+ if (png_ptr->transformations & PNG_USER_TRANSFORM)
+ {
+ if (info_ptr->bit_depth < png_ptr->user_transform_depth)
+ info_ptr->bit_depth = png_ptr->user_transform_depth;
+
+ if (info_ptr->channels < png_ptr->user_transform_channels)
+ info_ptr->channels = png_ptr->user_transform_channels;
+ }
+#endif
+
+ info_ptr->pixel_depth = (png_byte)(info_ptr->channels *
+ info_ptr->bit_depth);
+
+ info_ptr->rowbytes = PNG_ROWBYTES(info_ptr->pixel_depth, info_ptr->width);
+
+ /* Adding in 1.5.4: cache the above value in png_struct so that we can later
+ * check in png_rowbytes that the user buffer won't get overwritten. Note
+ * that the field is not always set - if png_read_update_info isn't called
+ * the application has to either not do any transforms or get the calculation
+ * right itself.
+ */
+ png_ptr->info_rowbytes = info_ptr->rowbytes;
+
+#ifndef PNG_READ_EXPAND_SUPPORTED
+ if (png_ptr)
+ return;
+#endif
+}
+
+/* Transform the row. The order of transformations is significant,
+ * and is very touchy. If you add a transformation, take care to
+ * decide how it fits in with the other transformations here.
+ */
+void /* PRIVATE */
+png_do_read_transformations(png_structrp png_ptr, png_row_infop row_info)
+{
+ png_debug(1, "in png_do_read_transformations");
+
+ if (png_ptr->row_buf == NULL)
+ {
+ /* Prior to 1.5.4 this output row/pass where the NULL pointer is, but this
+ * error is incredibly rare and incredibly easy to debug without this
+ * information.
+ */
+ png_error(png_ptr, "NULL row buffer");
+ }
+
+ /* The following is debugging; prior to 1.5.4 the code was never compiled in;
+ * in 1.5.4 PNG_FLAG_DETECT_UNINITIALIZED was added and the macro
+ * PNG_WARN_UNINITIALIZED_ROW removed. In 1.6 the new flag is set only for
+ * all transformations, however in practice the ROW_INIT always gets done on
+ * demand, if necessary.
+ */
+ if ((png_ptr->flags & PNG_FLAG_DETECT_UNINITIALIZED) != 0 &&
+ !(png_ptr->flags & PNG_FLAG_ROW_INIT))
+ {
+ /* Application has failed to call either png_read_start_image() or
+ * png_read_update_info() after setting transforms that expand pixels.
+ * This check added to libpng-1.2.19 (but not enabled until 1.5.4).
+ */
+ png_error(png_ptr, "Uninitialized row");
+ }
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ if (png_ptr->transformations & PNG_EXPAND)
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ png_do_expand_palette(row_info, png_ptr->row_buf + 1,
+ png_ptr->palette, png_ptr->trans_alpha, png_ptr->num_trans);
+ }
+
+ else
+ {
+ if (png_ptr->num_trans &&
+ (png_ptr->transformations & PNG_EXPAND_tRNS))
+ png_do_expand(row_info, png_ptr->row_buf + 1,
+ &(png_ptr->trans_color));
+
+ else
+ png_do_expand(row_info, png_ptr->row_buf + 1,
+ NULL);
+ }
+ }
+#endif
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+ if ((png_ptr->transformations & PNG_STRIP_ALPHA) &&
+ !(png_ptr->transformations & PNG_COMPOSE) &&
+ (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA))
+ png_do_strip_channel(row_info, png_ptr->row_buf + 1,
+ 0 /* at_start == false, because SWAP_ALPHA happens later */);
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ if (png_ptr->transformations & PNG_RGB_TO_GRAY)
+ {
+ int rgb_error =
+ png_do_rgb_to_gray(png_ptr, row_info,
+ png_ptr->row_buf + 1);
+
+ if (rgb_error)
+ {
+ png_ptr->rgb_to_gray_status=1;
+ if ((png_ptr->transformations & PNG_RGB_TO_GRAY) ==
+ PNG_RGB_TO_GRAY_WARN)
+ png_warning(png_ptr, "png_do_rgb_to_gray found nongray pixel");
+
+ if ((png_ptr->transformations & PNG_RGB_TO_GRAY) ==
+ PNG_RGB_TO_GRAY_ERR)
+ png_error(png_ptr, "png_do_rgb_to_gray found nongray pixel");
+ }
+ }
+#endif
+
+/* From Andreas Dilger e-mail to png-implement, 26 March 1998:
+ *
+ * In most cases, the "simple transparency" should be done prior to doing
+ * gray-to-RGB, or you will have to test 3x as many bytes to check if a
+ * pixel is transparent. You would also need to make sure that the
+ * transparency information is upgraded to RGB.
+ *
+ * To summarize, the current flow is:
+ * - Gray + simple transparency -> compare 1 or 2 gray bytes and composite
+ * with background "in place" if transparent,
+ * convert to RGB if necessary
+ * - Gray + alpha -> composite with gray background and remove alpha bytes,
+ * convert to RGB if necessary
+ *
+ * To support RGB backgrounds for gray images we need:
+ * - Gray + simple transparency -> convert to RGB + simple transparency,
+ * compare 3 or 6 bytes and composite with
+ * background "in place" if transparent
+ * (3x compare/pixel compared to doing
+ * composite with gray bkgrnd)
+ * - Gray + alpha -> convert to RGB + alpha, composite with background and
+ * remove alpha bytes (3x float
+ * operations/pixel compared with composite
+ * on gray background)
+ *
+ * Greg's change will do this. The reason it wasn't done before is for
+ * performance, as this increases the per-pixel operations. If we would check
+ * in advance if the background was gray or RGB, and position the gray-to-RGB
+ * transform appropriately, then it would save a lot of work/time.
+ */
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+ /* If gray -> RGB, do so now only if background is non-gray; else do later
+ * for performance reasons
+ */
+ if ((png_ptr->transformations & PNG_GRAY_TO_RGB) &&
+ !(png_ptr->mode & PNG_BACKGROUND_IS_GRAY))
+ png_do_gray_to_rgb(row_info, png_ptr->row_buf + 1);
+#endif
+
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+ if (png_ptr->transformations & PNG_COMPOSE)
+ png_do_compose(row_info, png_ptr->row_buf + 1, png_ptr);
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if ((png_ptr->transformations & PNG_GAMMA) &&
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ /* Because RGB_TO_GRAY does the gamma transform. */
+ !(png_ptr->transformations & PNG_RGB_TO_GRAY) &&
+#endif
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+ /* Because PNG_COMPOSE does the gamma transform if there is something to
+ * do (if there is an alpha channel or transparency.)
+ */
+ !((png_ptr->transformations & PNG_COMPOSE) &&
+ ((png_ptr->num_trans != 0) ||
+ (png_ptr->color_type & PNG_COLOR_MASK_ALPHA))) &&
+#endif
+ /* Because png_init_read_transformations transforms the palette, unless
+ * RGB_TO_GRAY will do the transform.
+ */
+ (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE))
+ png_do_gamma(row_info, png_ptr->row_buf + 1, png_ptr);
+#endif
+
+#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED
+ if ((png_ptr->transformations & PNG_STRIP_ALPHA) &&
+ (png_ptr->transformations & PNG_COMPOSE) &&
+ (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA ||
+ row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA))
+ png_do_strip_channel(row_info, png_ptr->row_buf + 1,
+ 0 /* at_start == false, because SWAP_ALPHA happens later */);
+#endif
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+ if ((png_ptr->transformations & PNG_ENCODE_ALPHA) &&
+ (row_info->color_type & PNG_COLOR_MASK_ALPHA))
+ png_do_encode_alpha(row_info, png_ptr->row_buf + 1, png_ptr);
+#endif
+
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+ if (png_ptr->transformations & PNG_SCALE_16_TO_8)
+ png_do_scale_16_to_8(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+ /* There is no harm in doing both of these because only one has any effect,
+ * by putting the 'scale' option first if the app asks for scale (either by
+ * calling the API or in a TRANSFORM flag) this is what happens.
+ */
+ if (png_ptr->transformations & PNG_16_TO_8)
+ png_do_chop(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+ if (png_ptr->transformations & PNG_QUANTIZE)
+ {
+ png_do_quantize(row_info, png_ptr->row_buf + 1,
+ png_ptr->palette_lookup, png_ptr->quantize_index);
+
+ if (row_info->rowbytes == 0)
+ png_error(png_ptr, "png_do_quantize returned rowbytes=0");
+ }
+#endif /* PNG_READ_QUANTIZE_SUPPORTED */
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+ /* Do the expansion now, after all the arithmetic has been done. Notice
+ * that previous transformations can handle the PNG_EXPAND_16 flag if this
+ * is efficient (particularly true in the case of gamma correction, where
+ * better accuracy results faster!)
+ */
+ if (png_ptr->transformations & PNG_EXPAND_16)
+ png_do_expand_16(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+ /* NOTE: moved here in 1.5.4 (from much later in this list.) */
+ if ((png_ptr->transformations & PNG_GRAY_TO_RGB) &&
+ (png_ptr->mode & PNG_BACKGROUND_IS_GRAY))
+ png_do_gray_to_rgb(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_INVERT_SUPPORTED
+ if (png_ptr->transformations & PNG_INVERT_MONO)
+ png_do_invert(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_SHIFT_SUPPORTED
+ if (png_ptr->transformations & PNG_SHIFT)
+ png_do_unshift(row_info, png_ptr->row_buf + 1,
+ &(png_ptr->shift));
+#endif
+
+#ifdef PNG_READ_PACK_SUPPORTED
+ if (png_ptr->transformations & PNG_PACK)
+ png_do_unpack(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ /* Added at libpng-1.5.10 */
+ if (row_info->color_type == PNG_COLOR_TYPE_PALETTE &&
+ png_ptr->num_palette_max >= 0)
+ png_do_check_palette_indexes(png_ptr, row_info);
+#endif
+
+#ifdef PNG_READ_BGR_SUPPORTED
+ if (png_ptr->transformations & PNG_BGR)
+ png_do_bgr(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_PACKSWAP)
+ png_do_packswap(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_FILLER_SUPPORTED
+ if (png_ptr->transformations & PNG_FILLER)
+ png_do_read_filler(row_info, png_ptr->row_buf + 1,
+ (png_uint_32)png_ptr->filler, png_ptr->flags);
+#endif
+
+#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_INVERT_ALPHA)
+ png_do_read_invert_alpha(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_SWAP_ALPHA)
+ png_do_read_swap_alpha(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+#ifdef PNG_READ_SWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_SWAP_BYTES)
+ png_do_swap(row_info, png_ptr->row_buf + 1);
+#endif
+#endif
+
+#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED
+ if (png_ptr->transformations & PNG_USER_TRANSFORM)
+ {
+ if (png_ptr->read_user_transform_fn != NULL)
+ (*(png_ptr->read_user_transform_fn)) /* User read transform function */
+ (png_ptr, /* png_ptr */
+ row_info, /* row_info: */
+ /* png_uint_32 width; width of row */
+ /* png_size_t rowbytes; number of bytes in row */
+ /* png_byte color_type; color type of pixels */
+ /* png_byte bit_depth; bit depth of samples */
+ /* png_byte channels; number of channels (1-4) */
+ /* png_byte pixel_depth; bits per pixel (depth*channels) */
+ png_ptr->row_buf + 1); /* start of pixel data for row */
+#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED
+ if (png_ptr->user_transform_depth)
+ row_info->bit_depth = png_ptr->user_transform_depth;
+
+ if (png_ptr->user_transform_channels)
+ row_info->channels = png_ptr->user_transform_channels;
+#endif
+ row_info->pixel_depth = (png_byte)(row_info->bit_depth *
+ row_info->channels);
+
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_info->width);
+ }
+#endif
+}
+
+#ifdef PNG_READ_PACK_SUPPORTED
+/* Unpack pixels of 1, 2, or 4 bits per pixel into 1 byte per pixel,
+ * without changing the actual values. Thus, if you had a row with
+ * a bit depth of 1, you would end up with bytes that only contained
+ * the numbers 0 or 1. If you would rather they contain 0 and 255, use
+ * png_do_shift() after this.
+ */
+void /* PRIVATE */
+png_do_unpack(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_unpack");
+
+ if (row_info->bit_depth < 8)
+ {
+ png_uint_32 i;
+ png_uint_32 row_width=row_info->width;
+
+ switch (row_info->bit_depth)
+ {
+ case 1:
+ {
+ png_bytep sp = row + (png_size_t)((row_width - 1) >> 3);
+ png_bytep dp = row + (png_size_t)row_width - 1;
+ png_uint_32 shift = 7 - (int)((row_width + 7) & 0x07);
+ for (i = 0; i < row_width; i++)
+ {
+ *dp = (png_byte)((*sp >> shift) & 0x01);
+
+ if (shift == 7)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift++;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 2:
+ {
+
+ png_bytep sp = row + (png_size_t)((row_width - 1) >> 2);
+ png_bytep dp = row + (png_size_t)row_width - 1;
+ png_uint_32 shift = (int)((3 - ((row_width + 3) & 0x03)) << 1);
+ for (i = 0; i < row_width; i++)
+ {
+ *dp = (png_byte)((*sp >> shift) & 0x03);
+
+ if (shift == 6)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift += 2;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 4:
+ {
+ png_bytep sp = row + (png_size_t)((row_width - 1) >> 1);
+ png_bytep dp = row + (png_size_t)row_width - 1;
+ png_uint_32 shift = (int)((1 - ((row_width + 1) & 0x01)) << 2);
+ for (i = 0; i < row_width; i++)
+ {
+ *dp = (png_byte)((*sp >> shift) & 0x0f);
+
+ if (shift == 4)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift = 4;
+
+ dp--;
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = (png_byte)(8 * row_info->channels);
+ row_info->rowbytes = row_width * row_info->channels;
+ }
+}
+#endif
+
+#ifdef PNG_READ_SHIFT_SUPPORTED
+/* Reverse the effects of png_do_shift. This routine merely shifts the
+ * pixels back to their significant bits values. Thus, if you have
+ * a row of bit depth 8, but only 5 are significant, this will shift
+ * the values back to 0 through 31.
+ */
+void /* PRIVATE */
+png_do_unshift(png_row_infop row_info, png_bytep row,
+ png_const_color_8p sig_bits)
+{
+ int color_type;
+
+ png_debug(1, "in png_do_unshift");
+
+ /* The palette case has already been handled in the _init routine. */
+ color_type = row_info->color_type;
+
+ if (color_type != PNG_COLOR_TYPE_PALETTE)
+ {
+ int shift[4];
+ int channels = 0;
+ int bit_depth = row_info->bit_depth;
+
+ if (color_type & PNG_COLOR_MASK_COLOR)
+ {
+ shift[channels++] = bit_depth - sig_bits->red;
+ shift[channels++] = bit_depth - sig_bits->green;
+ shift[channels++] = bit_depth - sig_bits->blue;
+ }
+
+ else
+ {
+ shift[channels++] = bit_depth - sig_bits->gray;
+ }
+
+ if (color_type & PNG_COLOR_MASK_ALPHA)
+ {
+ shift[channels++] = bit_depth - sig_bits->alpha;
+ }
+
+ {
+ int c, have_shift;
+
+ for (c = have_shift = 0; c < channels; ++c)
+ {
+ /* A shift of more than the bit depth is an error condition but it
+ * gets ignored here.
+ */
+ if (shift[c] <= 0 || shift[c] >= bit_depth)
+ shift[c] = 0;
+
+ else
+ have_shift = 1;
+ }
+
+ if (!have_shift)
+ return;
+ }
+
+ switch (bit_depth)
+ {
+ default:
+ /* Must be 1bpp gray: should not be here! */
+ /* NOTREACHED */
+ break;
+
+ case 2:
+ /* Must be 2bpp gray */
+ /* assert(channels == 1 && shift[0] == 1) */
+ {
+ png_bytep bp = row;
+ png_bytep bp_end = bp + row_info->rowbytes;
+
+ while (bp < bp_end)
+ {
+ int b = (*bp >> 1) & 0x55;
+ *bp++ = (png_byte)b;
+ }
+ break;
+ }
+
+ case 4:
+ /* Must be 4bpp gray */
+ /* assert(channels == 1) */
+ {
+ png_bytep bp = row;
+ png_bytep bp_end = bp + row_info->rowbytes;
+ int gray_shift = shift[0];
+ int mask = 0xf >> gray_shift;
+
+ mask |= mask << 4;
+
+ while (bp < bp_end)
+ {
+ int b = (*bp >> gray_shift) & mask;
+ *bp++ = (png_byte)b;
+ }
+ break;
+ }
+
+ case 8:
+ /* Single byte components, G, GA, RGB, RGBA */
+ {
+ png_bytep bp = row;
+ png_bytep bp_end = bp + row_info->rowbytes;
+ int channel = 0;
+
+ while (bp < bp_end)
+ {
+ int b = *bp >> shift[channel];
+ if (++channel >= channels)
+ channel = 0;
+ *bp++ = (png_byte)b;
+ }
+ break;
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ case 16:
+ /* Double byte components, G, GA, RGB, RGBA */
+ {
+ png_bytep bp = row;
+ png_bytep bp_end = bp + row_info->rowbytes;
+ int channel = 0;
+
+ while (bp < bp_end)
+ {
+ int value = (bp[0] << 8) + bp[1];
+
+ value >>= shift[channel];
+ if (++channel >= channels)
+ channel = 0;
+ *bp++ = (png_byte)(value >> 8);
+ *bp++ = (png_byte)(value & 0xff);
+ }
+ break;
+ }
+#endif
+ }
+ }
+}
+#endif
+
+#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED
+/* Scale rows of bit depth 16 down to 8 accurately */
+void /* PRIVATE */
+png_do_scale_16_to_8(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_scale_16_to_8");
+
+ if (row_info->bit_depth == 16)
+ {
+ png_bytep sp = row; /* source */
+ png_bytep dp = row; /* destination */
+ png_bytep ep = sp + row_info->rowbytes; /* end+1 */
+
+ while (sp < ep)
+ {
+ /* The input is an array of 16 bit components, these must be scaled to
+ * 8 bits each. For a 16 bit value V the required value (from the PNG
+ * specification) is:
+ *
+ * (V * 255) / 65535
+ *
+ * This reduces to round(V / 257), or floor((V + 128.5)/257)
+ *
+ * Represent V as the two byte value vhi.vlo. Make a guess that the
+ * result is the top byte of V, vhi, then the correction to this value
+ * is:
+ *
+ * error = floor(((V-vhi.vhi) + 128.5) / 257)
+ * = floor(((vlo-vhi) + 128.5) / 257)
+ *
+ * This can be approximated using integer arithmetic (and a signed
+ * shift):
+ *
+ * error = (vlo-vhi+128) >> 8;
+ *
+ * The approximate differs from the exact answer only when (vlo-vhi) is
+ * 128; it then gives a correction of +1 when the exact correction is
+ * 0. This gives 128 errors. The exact answer (correct for all 16 bit
+ * input values) is:
+ *
+ * error = (vlo-vhi+128)*65535 >> 24;
+ *
+ * An alternative arithmetic calculation which also gives no errors is:
+ *
+ * (V * 255 + 32895) >> 16
+ */
+
+ png_int_32 tmp = *sp++; /* must be signed! */
+ tmp += (((int)*sp++ - tmp + 128) * 65535) >> 24;
+ *dp++ = (png_byte)tmp;
+ }
+
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = (png_byte)(8 * row_info->channels);
+ row_info->rowbytes = row_info->width * row_info->channels;
+ }
+}
+#endif
+
+#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED
+void /* PRIVATE */
+/* Simply discard the low byte. This was the default behavior prior
+ * to libpng-1.5.4.
+ */
+png_do_chop(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_chop");
+
+ if (row_info->bit_depth == 16)
+ {
+ png_bytep sp = row; /* source */
+ png_bytep dp = row; /* destination */
+ png_bytep ep = sp + row_info->rowbytes; /* end+1 */
+
+ while (sp < ep)
+ {
+ *dp++ = *sp;
+ sp += 2; /* skip low byte */
+ }
+
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = (png_byte)(8 * row_info->channels);
+ row_info->rowbytes = row_info->width * row_info->channels;
+ }
+}
+#endif
+
+#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED
+void /* PRIVATE */
+png_do_read_swap_alpha(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_read_swap_alpha");
+
+ {
+ png_uint_32 row_width = row_info->width;
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ /* This converts from RGBA to ARGB */
+ if (row_info->bit_depth == 8)
+ {
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_byte save;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ save = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = save;
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ /* This converts from RRGGBBAA to AARRGGBB */
+ else
+ {
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_byte save[2];
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ save[0] = *(--sp);
+ save[1] = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = save[0];
+ *(--dp) = save[1];
+ }
+ }
+#endif
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ /* This converts from GA to AG */
+ if (row_info->bit_depth == 8)
+ {
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_byte save;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ save = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = save;
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ /* This converts from GGAA to AAGG */
+ else
+ {
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_byte save[2];
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ save[0] = *(--sp);
+ save[1] = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = save[0];
+ *(--dp) = save[1];
+ }
+ }
+#endif
+ }
+ }
+}
+#endif
+
+#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED
+void /* PRIVATE */
+png_do_read_invert_alpha(png_row_infop row_info, png_bytep row)
+{
+ png_uint_32 row_width;
+ png_debug(1, "in png_do_read_invert_alpha");
+
+ row_width = row_info->width;
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This inverts the alpha channel in RGBA */
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = (png_byte)(255 - *(--sp));
+
+/* This does nothing:
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ We can replace it with:
+*/
+ sp-=3;
+ dp=sp;
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ /* This inverts the alpha channel in RRGGBBAA */
+ else
+ {
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = (png_byte)(255 - *(--sp));
+ *(--dp) = (png_byte)(255 - *(--sp));
+
+/* This does nothing:
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ We can replace it with:
+*/
+ sp-=6;
+ dp=sp;
+ }
+ }
+#endif
+ }
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This inverts the alpha channel in GA */
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = (png_byte)(255 - *(--sp));
+ *(--dp) = *(--sp);
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ else
+ {
+ /* This inverts the alpha channel in GGAA */
+ png_bytep sp = row + row_info->rowbytes;
+ png_bytep dp = sp;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = (png_byte)(255 - *(--sp));
+ *(--dp) = (png_byte)(255 - *(--sp));
+/*
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+*/
+ sp-=2;
+ dp=sp;
+ }
+ }
+#endif
+ }
+}
+#endif
+
+#ifdef PNG_READ_FILLER_SUPPORTED
+/* Add filler channel if we have RGB color */
+void /* PRIVATE */
+png_do_read_filler(png_row_infop row_info, png_bytep row,
+ png_uint_32 filler, png_uint_32 flags)
+{
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ png_byte hi_filler = (png_byte)((filler>>8) & 0xff);
+#endif
+ png_byte lo_filler = (png_byte)(filler & 0xff);
+
+ png_debug(1, "in png_do_read_filler");
+
+ if (
+ row_info->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ if (flags & PNG_FLAG_FILLER_AFTER)
+ {
+ /* This changes the data from G to GX */
+ png_bytep sp = row + (png_size_t)row_width;
+ png_bytep dp = sp + (png_size_t)row_width;
+ for (i = 1; i < row_width; i++)
+ {
+ *(--dp) = lo_filler;
+ *(--dp) = *(--sp);
+ }
+ *(--dp) = lo_filler;
+ row_info->channels = 2;
+ row_info->pixel_depth = 16;
+ row_info->rowbytes = row_width * 2;
+ }
+
+ else
+ {
+ /* This changes the data from G to XG */
+ png_bytep sp = row + (png_size_t)row_width;
+ png_bytep dp = sp + (png_size_t)row_width;
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = *(--sp);
+ *(--dp) = lo_filler;
+ }
+ row_info->channels = 2;
+ row_info->pixel_depth = 16;
+ row_info->rowbytes = row_width * 2;
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ else if (row_info->bit_depth == 16)
+ {
+ if (flags & PNG_FLAG_FILLER_AFTER)
+ {
+ /* This changes the data from GG to GGXX */
+ png_bytep sp = row + (png_size_t)row_width * 2;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 1; i < row_width; i++)
+ {
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ }
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ row_info->channels = 2;
+ row_info->pixel_depth = 32;
+ row_info->rowbytes = row_width * 4;
+ }
+
+ else
+ {
+ /* This changes the data from GG to XXGG */
+ png_bytep sp = row + (png_size_t)row_width * 2;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ }
+ row_info->channels = 2;
+ row_info->pixel_depth = 32;
+ row_info->rowbytes = row_width * 4;
+ }
+ }
+#endif
+ } /* COLOR_TYPE == GRAY */
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ if (flags & PNG_FLAG_FILLER_AFTER)
+ {
+ /* This changes the data from RGB to RGBX */
+ png_bytep sp = row + (png_size_t)row_width * 3;
+ png_bytep dp = sp + (png_size_t)row_width;
+ for (i = 1; i < row_width; i++)
+ {
+ *(--dp) = lo_filler;
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ }
+ *(--dp) = lo_filler;
+ row_info->channels = 4;
+ row_info->pixel_depth = 32;
+ row_info->rowbytes = row_width * 4;
+ }
+
+ else
+ {
+ /* This changes the data from RGB to XRGB */
+ png_bytep sp = row + (png_size_t)row_width * 3;
+ png_bytep dp = sp + (png_size_t)row_width;
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = lo_filler;
+ }
+ row_info->channels = 4;
+ row_info->pixel_depth = 32;
+ row_info->rowbytes = row_width * 4;
+ }
+ }
+
+#ifdef PNG_READ_16BIT_SUPPORTED
+ else if (row_info->bit_depth == 16)
+ {
+ if (flags & PNG_FLAG_FILLER_AFTER)
+ {
+ /* This changes the data from RRGGBB to RRGGBBXX */
+ png_bytep sp = row + (png_size_t)row_width * 6;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 1; i < row_width; i++)
+ {
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ }
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ row_info->channels = 4;
+ row_info->pixel_depth = 64;
+ row_info->rowbytes = row_width * 8;
+ }
+
+ else
+ {
+ /* This changes the data from RRGGBB to XXRRGGBB */
+ png_bytep sp = row + (png_size_t)row_width * 6;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 0; i < row_width; i++)
+ {
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = *(--sp);
+ *(--dp) = hi_filler;
+ *(--dp) = lo_filler;
+ }
+
+ row_info->channels = 4;
+ row_info->pixel_depth = 64;
+ row_info->rowbytes = row_width * 8;
+ }
+ }
+#endif
+ } /* COLOR_TYPE == RGB */
+}
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+/* Expand grayscale files to RGB, with or without alpha */
+void /* PRIVATE */
+png_do_gray_to_rgb(png_row_infop row_info, png_bytep row)
+{
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ png_debug(1, "in png_do_gray_to_rgb");
+
+ if (row_info->bit_depth >= 8 &&
+ !(row_info->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This changes G to RGB */
+ png_bytep sp = row + (png_size_t)row_width - 1;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 0; i < row_width; i++)
+ {
+ *(dp--) = *sp;
+ *(dp--) = *sp;
+ *(dp--) = *(sp--);
+ }
+ }
+
+ else
+ {
+ /* This changes GG to RRGGBB */
+ png_bytep sp = row + (png_size_t)row_width * 2 - 1;
+ png_bytep dp = sp + (png_size_t)row_width * 4;
+ for (i = 0; i < row_width; i++)
+ {
+ *(dp--) = *sp;
+ *(dp--) = *(sp - 1);
+ *(dp--) = *sp;
+ *(dp--) = *(sp - 1);
+ *(dp--) = *(sp--);
+ *(dp--) = *(sp--);
+ }
+ }
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This changes GA to RGBA */
+ png_bytep sp = row + (png_size_t)row_width * 2 - 1;
+ png_bytep dp = sp + (png_size_t)row_width * 2;
+ for (i = 0; i < row_width; i++)
+ {
+ *(dp--) = *(sp--);
+ *(dp--) = *sp;
+ *(dp--) = *sp;
+ *(dp--) = *(sp--);
+ }
+ }
+
+ else
+ {
+ /* This changes GGAA to RRGGBBAA */
+ png_bytep sp = row + (png_size_t)row_width * 4 - 1;
+ png_bytep dp = sp + (png_size_t)row_width * 4;
+ for (i = 0; i < row_width; i++)
+ {
+ *(dp--) = *(sp--);
+ *(dp--) = *(sp--);
+ *(dp--) = *sp;
+ *(dp--) = *(sp - 1);
+ *(dp--) = *sp;
+ *(dp--) = *(sp - 1);
+ *(dp--) = *(sp--);
+ *(dp--) = *(sp--);
+ }
+ }
+ }
+ row_info->channels = (png_byte)(row_info->channels + 2);
+ row_info->color_type |= PNG_COLOR_MASK_COLOR;
+ row_info->pixel_depth = (png_byte)(row_info->channels *
+ row_info->bit_depth);
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width);
+ }
+}
+#endif
+
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+/* Reduce RGB files to grayscale, with or without alpha
+ * using the equation given in Poynton's ColorFAQ of 1998-01-04 at
+ * <http://www.inforamp.net/~poynton/> (THIS LINK IS DEAD June 2008 but
+ * versions dated 1998 through November 2002 have been archived at
+ * http://web.archive.org/web/20000816232553/http://www.inforamp.net/
+ * ~poynton/notes/colour_and_gamma/ColorFAQ.txt )
+ * Charles Poynton poynton at poynton.com
+ *
+ * Y = 0.212671 * R + 0.715160 * G + 0.072169 * B
+ *
+ * which can be expressed with integers as
+ *
+ * Y = (6969 * R + 23434 * G + 2365 * B)/32768
+ *
+ * Poynton's current link (as of January 2003 through July 2011):
+ * <http://www.poynton.com/notes/colour_and_gamma/>
+ * has changed the numbers slightly:
+ *
+ * Y = 0.2126*R + 0.7152*G + 0.0722*B
+ *
+ * which can be expressed with integers as
+ *
+ * Y = (6966 * R + 23436 * G + 2366 * B)/32768
+ *
+ * Historically, however, libpng uses numbers derived from the ITU-R Rec 709
+ * end point chromaticities and the D65 white point. Depending on the
+ * precision used for the D65 white point this produces a variety of different
+ * numbers, however if the four decimal place value used in ITU-R Rec 709 is
+ * used (0.3127,0.3290) the Y calculation would be:
+ *
+ * Y = (6968 * R + 23435 * G + 2366 * B)/32768
+ *
+ * While this is correct the rounding results in an overflow for white, because
+ * the sum of the rounded coefficients is 32769, not 32768. Consequently
+ * libpng uses, instead, the closest non-overflowing approximation:
+ *
+ * Y = (6968 * R + 23434 * G + 2366 * B)/32768
+ *
+ * Starting with libpng-1.5.5, if the image being converted has a cHRM chunk
+ * (including an sRGB chunk) then the chromaticities are used to calculate the
+ * coefficients. See the chunk handling in pngrutil.c for more information.
+ *
+ * In all cases the calculation is to be done in a linear colorspace. If no
+ * gamma information is available to correct the encoding of the original RGB
+ * values this results in an implicit assumption that the original PNG RGB
+ * values were linear.
+ *
+ * Other integer coefficents can be used via png_set_rgb_to_gray(). Because
+ * the API takes just red and green coefficients the blue coefficient is
+ * calculated to make the sum 32768. This will result in different rounding
+ * to that used above.
+ */
+int /* PRIVATE */
+png_do_rgb_to_gray(png_structrp png_ptr, png_row_infop row_info, png_bytep row)
+
+{
+ int rgb_error = 0;
+
+ png_debug(1, "in png_do_rgb_to_gray");
+
+ if (!(row_info->color_type & PNG_COLOR_MASK_PALETTE) &&
+ (row_info->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ PNG_CONST png_uint_32 rc = png_ptr->rgb_to_gray_red_coeff;
+ PNG_CONST png_uint_32 gc = png_ptr->rgb_to_gray_green_coeff;
+ PNG_CONST png_uint_32 bc = 32768 - rc - gc;
+ PNG_CONST png_uint_32 row_width = row_info->width;
+ PNG_CONST int have_alpha =
+ (row_info->color_type & PNG_COLOR_MASK_ALPHA) != 0;
+
+ if (row_info->bit_depth == 8)
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ /* Notice that gamma to/from 1 are not necessarily inverses (if
+ * there is an overall gamma correction). Prior to 1.5.5 this code
+ * checked the linearized values for equality; this doesn't match
+ * the documentation, the original values must be checked.
+ */
+ if (png_ptr->gamma_from_1 != NULL && png_ptr->gamma_to_1 != NULL)
+ {
+ png_bytep sp = row;
+ png_bytep dp = row;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_byte red = *(sp++);
+ png_byte green = *(sp++);
+ png_byte blue = *(sp++);
+
+ if (red != green || red != blue)
+ {
+ red = png_ptr->gamma_to_1[red];
+ green = png_ptr->gamma_to_1[green];
+ blue = png_ptr->gamma_to_1[blue];
+
+ rgb_error |= 1;
+ *(dp++) = png_ptr->gamma_from_1[
+ (rc*red + gc*green + bc*blue + 16384)>>15];
+ }
+
+ else
+ {
+ /* If there is no overall correction the table will not be
+ * set.
+ */
+ if (png_ptr->gamma_table != NULL)
+ red = png_ptr->gamma_table[red];
+
+ *(dp++) = red;
+ }
+
+ if (have_alpha)
+ *(dp++) = *(sp++);
+ }
+ }
+ else
+#endif
+ {
+ png_bytep sp = row;
+ png_bytep dp = row;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_byte red = *(sp++);
+ png_byte green = *(sp++);
+ png_byte blue = *(sp++);
+
+ if (red != green || red != blue)
+ {
+ rgb_error |= 1;
+ /* NOTE: this is the historical approach which simply
+ * truncates the results.
+ */
+ *(dp++) = (png_byte)((rc*red + gc*green + bc*blue)>>15);
+ }
+
+ else
+ *(dp++) = red;
+
+ if (have_alpha)
+ *(dp++) = *(sp++);
+ }
+ }
+ }
+
+ else /* RGB bit_depth == 16 */
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (png_ptr->gamma_16_to_1 != NULL && png_ptr->gamma_16_from_1 != NULL)
+ {
+ png_bytep sp = row;
+ png_bytep dp = row;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 red, green, blue, w;
+
+ red = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+ green = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+ blue = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+
+ if (red == green && red == blue)
+ {
+ if (png_ptr->gamma_16_table != NULL)
+ w = png_ptr->gamma_16_table[(red&0xff)
+ >> png_ptr->gamma_shift][red>>8];
+
+ else
+ w = red;
+ }
+
+ else
+ {
+ png_uint_16 red_1 = png_ptr->gamma_16_to_1[(red&0xff)
+ >> png_ptr->gamma_shift][red>>8];
+ png_uint_16 green_1 =
+ png_ptr->gamma_16_to_1[(green&0xff) >>
+ png_ptr->gamma_shift][green>>8];
+ png_uint_16 blue_1 = png_ptr->gamma_16_to_1[(blue&0xff)
+ >> png_ptr->gamma_shift][blue>>8];
+ png_uint_16 gray16 = (png_uint_16)((rc*red_1 + gc*green_1
+ + bc*blue_1 + 16384)>>15);
+ w = png_ptr->gamma_16_from_1[(gray16&0xff) >>
+ png_ptr->gamma_shift][gray16 >> 8];
+ rgb_error |= 1;
+ }
+
+ *(dp++) = (png_byte)((w>>8) & 0xff);
+ *(dp++) = (png_byte)(w & 0xff);
+
+ if (have_alpha)
+ {
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ }
+ }
+ }
+ else
+#endif
+ {
+ png_bytep sp = row;
+ png_bytep dp = row;
+ png_uint_32 i;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 red, green, blue, gray16;
+
+ red = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+ green = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+ blue = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2;
+
+ if (red != green || red != blue)
+ rgb_error |= 1;
+
+ /* From 1.5.5 in the 16 bit case do the accurate conversion even
+ * in the 'fast' case - this is because this is where the code
+ * ends up when handling linear 16 bit data.
+ */
+ gray16 = (png_uint_16)((rc*red + gc*green + bc*blue + 16384) >>
+ 15);
+ *(dp++) = (png_byte)((gray16>>8) & 0xff);
+ *(dp++) = (png_byte)(gray16 & 0xff);
+
+ if (have_alpha)
+ {
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ }
+ }
+ }
+ }
+
+ row_info->channels = (png_byte)(row_info->channels - 2);
+ row_info->color_type = (png_byte)(row_info->color_type &
+ ~PNG_COLOR_MASK_COLOR);
+ row_info->pixel_depth = (png_byte)(row_info->channels *
+ row_info->bit_depth);
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width);
+ }
+ return rgb_error;
+}
+#endif
+#endif /* PNG_READ_TRANSFORMS_SUPPORTED */
+
+#ifdef PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED
+/* Build a grayscale palette. Palette is assumed to be 1 << bit_depth
+ * large of png_color. This lets grayscale images be treated as
+ * paletted. Most useful for gamma correction and simplification
+ * of code. This API is not used internally.
+ */
+void PNGAPI
+png_build_grayscale_palette(int bit_depth, png_colorp palette)
+{
+ int num_palette;
+ int color_inc;
+ int i;
+ int v;
+
+ png_debug(1, "in png_do_build_grayscale_palette");
+
+ if (palette == NULL)
+ return;
+
+ switch (bit_depth)
+ {
+ case 1:
+ num_palette = 2;
+ color_inc = 0xff;
+ break;
+
+ case 2:
+ num_palette = 4;
+ color_inc = 0x55;
+ break;
+
+ case 4:
+ num_palette = 16;
+ color_inc = 0x11;
+ break;
+
+ case 8:
+ num_palette = 256;
+ color_inc = 1;
+ break;
+
+ default:
+ num_palette = 0;
+ color_inc = 0;
+ break;
+ }
+
+ for (i = 0, v = 0; i < num_palette; i++, v += color_inc)
+ {
+ palette[i].red = (png_byte)v;
+ palette[i].green = (png_byte)v;
+ palette[i].blue = (png_byte)v;
+ }
+}
+#endif
+
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+/* Replace any alpha or transparency with the supplied background color.
+ * "background" is already in the screen gamma, while "background_1" is
+ * at a gamma of 1.0. Paletted files have already been taken care of.
+ */
+void /* PRIVATE */
+png_do_compose(png_row_infop row_info, png_bytep row, png_structrp png_ptr)
+{
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ png_const_bytep gamma_table = png_ptr->gamma_table;
+ png_const_bytep gamma_from_1 = png_ptr->gamma_from_1;
+ png_const_bytep gamma_to_1 = png_ptr->gamma_to_1;
+ png_const_uint_16pp gamma_16 = png_ptr->gamma_16_table;
+ png_const_uint_16pp gamma_16_from_1 = png_ptr->gamma_16_from_1;
+ png_const_uint_16pp gamma_16_to_1 = png_ptr->gamma_16_to_1;
+ int gamma_shift = png_ptr->gamma_shift;
+ int optimize = (png_ptr->flags & PNG_FLAG_OPTIMIZE_ALPHA) != 0;
+#endif
+
+ png_bytep sp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+ int shift;
+
+ png_debug(1, "in png_do_compose");
+
+ {
+ switch (row_info->color_type)
+ {
+ case PNG_COLOR_TYPE_GRAY:
+ {
+ switch (row_info->bit_depth)
+ {
+ case 1:
+ {
+ sp = row;
+ shift = 7;
+ for (i = 0; i < row_width; i++)
+ {
+ if ((png_uint_16)((*sp >> shift) & 0x01)
+ == png_ptr->trans_color.gray)
+ {
+ unsigned int tmp = *sp & (0x7f7f >> (7 - shift));
+ tmp |= png_ptr->background.gray << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ if (!shift)
+ {
+ shift = 7;
+ sp++;
+ }
+
+ else
+ shift--;
+ }
+ break;
+ }
+
+ case 2:
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_table != NULL)
+ {
+ sp = row;
+ shift = 6;
+ for (i = 0; i < row_width; i++)
+ {
+ if ((png_uint_16)((*sp >> shift) & 0x03)
+ == png_ptr->trans_color.gray)
+ {
+ unsigned int tmp = *sp & (0x3f3f >> (6 - shift));
+ tmp |= png_ptr->background.gray << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ else
+ {
+ unsigned int p = (*sp >> shift) & 0x03;
+ unsigned int g = (gamma_table [p | (p << 2) |
+ (p << 4) | (p << 6)] >> 6) & 0x03;
+ unsigned int tmp = *sp & (0x3f3f >> (6 - shift));
+ tmp |= g << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ if (!shift)
+ {
+ shift = 6;
+ sp++;
+ }
+
+ else
+ shift -= 2;
+ }
+ }
+
+ else
+#endif
+ {
+ sp = row;
+ shift = 6;
+ for (i = 0; i < row_width; i++)
+ {
+ if ((png_uint_16)((*sp >> shift) & 0x03)
+ == png_ptr->trans_color.gray)
+ {
+ unsigned int tmp = *sp & (0x3f3f >> (6 - shift));
+ tmp |= png_ptr->background.gray << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ if (!shift)
+ {
+ shift = 6;
+ sp++;
+ }
+
+ else
+ shift -= 2;
+ }
+ }
+ break;
+ }
+
+ case 4:
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_table != NULL)
+ {
+ sp = row;
+ shift = 4;
+ for (i = 0; i < row_width; i++)
+ {
+ if ((png_uint_16)((*sp >> shift) & 0x0f)
+ == png_ptr->trans_color.gray)
+ {
+ unsigned int tmp = *sp & (0xf0f >> (4 - shift));
+ tmp |= png_ptr->background.gray << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ else
+ {
+ unsigned int p = (*sp >> shift) & 0x0f;
+ unsigned int g = (gamma_table[p | (p << 4)] >> 4) &
+ 0x0f;
+ unsigned int tmp = *sp & (0xf0f >> (4 - shift));
+ tmp |= g << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ if (!shift)
+ {
+ shift = 4;
+ sp++;
+ }
+
+ else
+ shift -= 4;
+ }
+ }
+
+ else
+#endif
+ {
+ sp = row;
+ shift = 4;
+ for (i = 0; i < row_width; i++)
+ {
+ if ((png_uint_16)((*sp >> shift) & 0x0f)
+ == png_ptr->trans_color.gray)
+ {
+ unsigned int tmp = *sp & (0xf0f >> (4 - shift));
+ tmp |= png_ptr->background.gray << shift;
+ *sp = (png_byte)(tmp & 0xff);
+ }
+
+ if (!shift)
+ {
+ shift = 4;
+ sp++;
+ }
+
+ else
+ shift -= 4;
+ }
+ }
+ break;
+ }
+
+ case 8:
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_table != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp++)
+ {
+ if (*sp == png_ptr->trans_color.gray)
+ *sp = (png_byte)png_ptr->background.gray;
+
+ else
+ *sp = gamma_table[*sp];
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp++)
+ {
+ if (*sp == png_ptr->trans_color.gray)
+ *sp = (png_byte)png_ptr->background.gray;
+ }
+ }
+ break;
+ }
+
+ case 16:
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_16 != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 2)
+ {
+ png_uint_16 v;
+
+ v = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+
+ if (v == png_ptr->trans_color.gray)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)((png_ptr->background.gray >> 8)
+ & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.gray
+ & 0xff);
+ }
+
+ else
+ {
+ v = gamma_16[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ }
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 2)
+ {
+ png_uint_16 v;
+
+ v = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+
+ if (v == png_ptr->trans_color.gray)
+ {
+ *sp = (png_byte)((png_ptr->background.gray >> 8)
+ & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.gray
+ & 0xff);
+ }
+ }
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_RGB:
+ {
+ if (row_info->bit_depth == 8)
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_table != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 3)
+ {
+ if (*sp == png_ptr->trans_color.red &&
+ *(sp + 1) == png_ptr->trans_color.green &&
+ *(sp + 2) == png_ptr->trans_color.blue)
+ {
+ *sp = (png_byte)png_ptr->background.red;
+ *(sp + 1) = (png_byte)png_ptr->background.green;
+ *(sp + 2) = (png_byte)png_ptr->background.blue;
+ }
+
+ else
+ {
+ *sp = gamma_table[*sp];
+ *(sp + 1) = gamma_table[*(sp + 1)];
+ *(sp + 2) = gamma_table[*(sp + 2)];
+ }
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 3)
+ {
+ if (*sp == png_ptr->trans_color.red &&
+ *(sp + 1) == png_ptr->trans_color.green &&
+ *(sp + 2) == png_ptr->trans_color.blue)
+ {
+ *sp = (png_byte)png_ptr->background.red;
+ *(sp + 1) = (png_byte)png_ptr->background.green;
+ *(sp + 2) = (png_byte)png_ptr->background.blue;
+ }
+ }
+ }
+ }
+ else /* if (row_info->bit_depth == 16) */
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_16 != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 6)
+ {
+ png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+
+ png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8)
+ + *(sp + 3));
+
+ png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8)
+ + *(sp + 5));
+
+ if (r == png_ptr->trans_color.red &&
+ g == png_ptr->trans_color.green &&
+ b == png_ptr->trans_color.blue)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff);
+ *(sp + 2) = (png_byte)((png_ptr->background.green >> 8)
+ & 0xff);
+ *(sp + 3) = (png_byte)(png_ptr->background.green
+ & 0xff);
+ *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8)
+ & 0xff);
+ *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff);
+ }
+
+ else
+ {
+ png_uint_16 v = gamma_16[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+
+ v = gamma_16[*(sp + 3) >> gamma_shift][*(sp + 2)];
+ *(sp + 2) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 3) = (png_byte)(v & 0xff);
+
+ v = gamma_16[*(sp + 5) >> gamma_shift][*(sp + 4)];
+ *(sp + 4) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 5) = (png_byte)(v & 0xff);
+ }
+ }
+ }
+
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 6)
+ {
+ png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+
+ png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8)
+ + *(sp + 3));
+
+ png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8)
+ + *(sp + 5));
+
+ if (r == png_ptr->trans_color.red &&
+ g == png_ptr->trans_color.green &&
+ b == png_ptr->trans_color.blue)
+ {
+ *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff);
+ *(sp + 2) = (png_byte)((png_ptr->background.green >> 8)
+ & 0xff);
+ *(sp + 3) = (png_byte)(png_ptr->background.green
+ & 0xff);
+ *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8)
+ & 0xff);
+ *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff);
+ }
+ }
+ }
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ {
+ if (row_info->bit_depth == 8)
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_to_1 != NULL && gamma_from_1 != NULL &&
+ gamma_table != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 2)
+ {
+ png_uint_16 a = *(sp + 1);
+
+ if (a == 0xff)
+ *sp = gamma_table[*sp];
+
+ else if (a == 0)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)png_ptr->background.gray;
+ }
+
+ else
+ {
+ png_byte v, w;
+
+ v = gamma_to_1[*sp];
+ png_composite(w, v, a, png_ptr->background_1.gray);
+ if (!optimize)
+ w = gamma_from_1[w];
+ *sp = w;
+ }
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 2)
+ {
+ png_byte a = *(sp + 1);
+
+ if (a == 0)
+ *sp = (png_byte)png_ptr->background.gray;
+
+ else if (a < 0xff)
+ png_composite(*sp, *sp, a, png_ptr->background.gray);
+ }
+ }
+ }
+ else /* if (png_ptr->bit_depth == 16) */
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_16 != NULL && gamma_16_from_1 != NULL &&
+ gamma_16_to_1 != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 4)
+ {
+ png_uint_16 a = (png_uint_16)(((*(sp + 2)) << 8)
+ + *(sp + 3));
+
+ if (a == (png_uint_16)0xffff)
+ {
+ png_uint_16 v;
+
+ v = gamma_16[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ }
+
+ else if (a == 0)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)((png_ptr->background.gray >> 8)
+ & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.gray & 0xff);
+ }
+
+ else
+ {
+ png_uint_16 g, v, w;
+
+ g = gamma_16_to_1[*(sp + 1) >> gamma_shift][*sp];
+ png_composite_16(v, g, a, png_ptr->background_1.gray);
+ if (optimize)
+ w = v;
+ else
+ w = gamma_16_from_1[(v&0xff) >> gamma_shift][v >> 8];
+ *sp = (png_byte)((w >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(w & 0xff);
+ }
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 4)
+ {
+ png_uint_16 a = (png_uint_16)(((*(sp + 2)) << 8)
+ + *(sp + 3));
+
+ if (a == 0)
+ {
+ *sp = (png_byte)((png_ptr->background.gray >> 8)
+ & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.gray & 0xff);
+ }
+
+ else if (a < 0xffff)
+ {
+ png_uint_16 g, v;
+
+ g = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+ png_composite_16(v, g, a, png_ptr->background.gray);
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ }
+ }
+ }
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ {
+ if (row_info->bit_depth == 8)
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_to_1 != NULL && gamma_from_1 != NULL &&
+ gamma_table != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 4)
+ {
+ png_byte a = *(sp + 3);
+
+ if (a == 0xff)
+ {
+ *sp = gamma_table[*sp];
+ *(sp + 1) = gamma_table[*(sp + 1)];
+ *(sp + 2) = gamma_table[*(sp + 2)];
+ }
+
+ else if (a == 0)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)png_ptr->background.red;
+ *(sp + 1) = (png_byte)png_ptr->background.green;
+ *(sp + 2) = (png_byte)png_ptr->background.blue;
+ }
+
+ else
+ {
+ png_byte v, w;
+
+ v = gamma_to_1[*sp];
+ png_composite(w, v, a, png_ptr->background_1.red);
+ if (!optimize) w = gamma_from_1[w];
+ *sp = w;
+
+ v = gamma_to_1[*(sp + 1)];
+ png_composite(w, v, a, png_ptr->background_1.green);
+ if (!optimize) w = gamma_from_1[w];
+ *(sp + 1) = w;
+
+ v = gamma_to_1[*(sp + 2)];
+ png_composite(w, v, a, png_ptr->background_1.blue);
+ if (!optimize) w = gamma_from_1[w];
+ *(sp + 2) = w;
+ }
+ }
+ }
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 4)
+ {
+ png_byte a = *(sp + 3);
+
+ if (a == 0)
+ {
+ *sp = (png_byte)png_ptr->background.red;
+ *(sp + 1) = (png_byte)png_ptr->background.green;
+ *(sp + 2) = (png_byte)png_ptr->background.blue;
+ }
+
+ else if (a < 0xff)
+ {
+ png_composite(*sp, *sp, a, png_ptr->background.red);
+
+ png_composite(*(sp + 1), *(sp + 1), a,
+ png_ptr->background.green);
+
+ png_composite(*(sp + 2), *(sp + 2), a,
+ png_ptr->background.blue);
+ }
+ }
+ }
+ }
+ else /* if (row_info->bit_depth == 16) */
+ {
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ if (gamma_16 != NULL && gamma_16_from_1 != NULL &&
+ gamma_16_to_1 != NULL)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 8)
+ {
+ png_uint_16 a = (png_uint_16)(((png_uint_16)(*(sp + 6))
+ << 8) + (png_uint_16)(*(sp + 7)));
+
+ if (a == (png_uint_16)0xffff)
+ {
+ png_uint_16 v;
+
+ v = gamma_16[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+
+ v = gamma_16[*(sp + 3) >> gamma_shift][*(sp + 2)];
+ *(sp + 2) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 3) = (png_byte)(v & 0xff);
+
+ v = gamma_16[*(sp + 5) >> gamma_shift][*(sp + 4)];
+ *(sp + 4) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 5) = (png_byte)(v & 0xff);
+ }
+
+ else if (a == 0)
+ {
+ /* Background is already in screen gamma */
+ *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff);
+ *(sp + 2) = (png_byte)((png_ptr->background.green >> 8)
+ & 0xff);
+ *(sp + 3) = (png_byte)(png_ptr->background.green
+ & 0xff);
+ *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8)
+ & 0xff);
+ *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff);
+ }
+
+ else
+ {
+ png_uint_16 v, w;
+
+ v = gamma_16_to_1[*(sp + 1) >> gamma_shift][*sp];
+ png_composite_16(w, v, a, png_ptr->background_1.red);
+ if (!optimize)
+ w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >>
+ 8];
+ *sp = (png_byte)((w >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(w & 0xff);
+
+ v = gamma_16_to_1[*(sp + 3) >> gamma_shift][*(sp + 2)];
+ png_composite_16(w, v, a, png_ptr->background_1.green);
+ if (!optimize)
+ w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >>
+ 8];
+
+ *(sp + 2) = (png_byte)((w >> 8) & 0xff);
+ *(sp + 3) = (png_byte)(w & 0xff);
+
+ v = gamma_16_to_1[*(sp + 5) >> gamma_shift][*(sp + 4)];
+ png_composite_16(w, v, a, png_ptr->background_1.blue);
+ if (!optimize)
+ w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >>
+ 8];
+
+ *(sp + 4) = (png_byte)((w >> 8) & 0xff);
+ *(sp + 5) = (png_byte)(w & 0xff);
+ }
+ }
+ }
+
+ else
+#endif
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++, sp += 8)
+ {
+ png_uint_16 a = (png_uint_16)(((png_uint_16)(*(sp + 6))
+ << 8) + (png_uint_16)(*(sp + 7)));
+
+ if (a == 0)
+ {
+ *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff);
+ *(sp + 2) = (png_byte)((png_ptr->background.green >> 8)
+ & 0xff);
+ *(sp + 3) = (png_byte)(png_ptr->background.green
+ & 0xff);
+ *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8)
+ & 0xff);
+ *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff);
+ }
+
+ else if (a < 0xffff)
+ {
+ png_uint_16 v;
+
+ png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1));
+ png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8)
+ + *(sp + 3));
+ png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8)
+ + *(sp + 5));
+
+ png_composite_16(v, r, a, png_ptr->background.red);
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+
+ png_composite_16(v, g, a, png_ptr->background.green);
+ *(sp + 2) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 3) = (png_byte)(v & 0xff);
+
+ png_composite_16(v, b, a, png_ptr->background.blue);
+ *(sp + 4) = (png_byte)((v >> 8) & 0xff);
+ *(sp + 5) = (png_byte)(v & 0xff);
+ }
+ }
+ }
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+}
+#endif /* PNG_READ_BACKGROUND_SUPPORTED || PNG_READ_ALPHA_MODE_SUPPORTED */
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+/* Gamma correct the image, avoiding the alpha channel. Make sure
+ * you do this after you deal with the transparency issue on grayscale
+ * or RGB images. If your bit depth is 8, use gamma_table, if it
+ * is 16, use gamma_16_table and gamma_shift. Build these with
+ * build_gamma_table().
+ */
+void /* PRIVATE */
+png_do_gamma(png_row_infop row_info, png_bytep row, png_structrp png_ptr)
+{
+ png_const_bytep gamma_table = png_ptr->gamma_table;
+ png_const_uint_16pp gamma_16_table = png_ptr->gamma_16_table;
+ int gamma_shift = png_ptr->gamma_shift;
+
+ png_bytep sp;
+ png_uint_32 i;
+ png_uint_32 row_width=row_info->width;
+
+ png_debug(1, "in png_do_gamma");
+
+ if (((row_info->bit_depth <= 8 && gamma_table != NULL) ||
+ (row_info->bit_depth == 16 && gamma_16_table != NULL)))
+ {
+ switch (row_info->color_type)
+ {
+ case PNG_COLOR_TYPE_RGB:
+ {
+ if (row_info->bit_depth == 8)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ *sp = gamma_table[*sp];
+ sp++;
+ *sp = gamma_table[*sp];
+ sp++;
+ *sp = gamma_table[*sp];
+ sp++;
+ }
+ }
+
+ else /* if (row_info->bit_depth == 16) */
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 v;
+
+ v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+
+ v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+
+ v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+ }
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ {
+ if (row_info->bit_depth == 8)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ *sp = gamma_table[*sp];
+ sp++;
+
+ *sp = gamma_table[*sp];
+ sp++;
+
+ *sp = gamma_table[*sp];
+ sp++;
+
+ sp++;
+ }
+ }
+
+ else /* if (row_info->bit_depth == 16) */
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+
+ v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+
+ v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 4;
+ }
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ {
+ if (row_info->bit_depth == 8)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ *sp = gamma_table[*sp];
+ sp += 2;
+ }
+ }
+
+ else /* if (row_info->bit_depth == 16) */
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 4;
+ }
+ }
+ break;
+ }
+
+ case PNG_COLOR_TYPE_GRAY:
+ {
+ if (row_info->bit_depth == 2)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i += 4)
+ {
+ int a = *sp & 0xc0;
+ int b = *sp & 0x30;
+ int c = *sp & 0x0c;
+ int d = *sp & 0x03;
+
+ *sp = (png_byte)(
+ ((((int)gamma_table[a|(a>>2)|(a>>4)|(a>>6)]) ) & 0xc0)|
+ ((((int)gamma_table[(b<<2)|b|(b>>2)|(b>>4)])>>2) & 0x30)|
+ ((((int)gamma_table[(c<<4)|(c<<2)|c|(c>>2)])>>4) & 0x0c)|
+ ((((int)gamma_table[(d<<6)|(d<<4)|(d<<2)|d])>>6) ));
+ sp++;
+ }
+ }
+
+ if (row_info->bit_depth == 4)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i += 2)
+ {
+ int msb = *sp & 0xf0;
+ int lsb = *sp & 0x0f;
+
+ *sp = (png_byte)((((int)gamma_table[msb | (msb >> 4)]) & 0xf0)
+ | (((int)gamma_table[(lsb << 4) | lsb]) >> 4));
+ sp++;
+ }
+ }
+
+ else if (row_info->bit_depth == 8)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ *sp = gamma_table[*sp];
+ sp++;
+ }
+ }
+
+ else if (row_info->bit_depth == 16)
+ {
+ sp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp];
+ *sp = (png_byte)((v >> 8) & 0xff);
+ *(sp + 1) = (png_byte)(v & 0xff);
+ sp += 2;
+ }
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+}
+#endif
+
+#ifdef PNG_READ_ALPHA_MODE_SUPPORTED
+/* Encode the alpha channel to the output gamma (the input channel is always
+ * linear.) Called only with color types that have an alpha channel. Needs the
+ * from_1 tables.
+ */
+void /* PRIVATE */
+png_do_encode_alpha(png_row_infop row_info, png_bytep row, png_structrp png_ptr)
+{
+ png_uint_32 row_width = row_info->width;
+
+ png_debug(1, "in png_do_encode_alpha");
+
+ if (row_info->color_type & PNG_COLOR_MASK_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ PNG_CONST png_bytep table = png_ptr->gamma_from_1;
+
+ if (table != NULL)
+ {
+ PNG_CONST int step =
+ (row_info->color_type & PNG_COLOR_MASK_COLOR) ? 4 : 2;
+
+ /* The alpha channel is the last component: */
+ row += step - 1;
+
+ for (; row_width > 0; --row_width, row += step)
+ *row = table[*row];
+
+ return;
+ }
+ }
+
+ else if (row_info->bit_depth == 16)
+ {
+ PNG_CONST png_uint_16pp table = png_ptr->gamma_16_from_1;
+ PNG_CONST int gamma_shift = png_ptr->gamma_shift;
+
+ if (table != NULL)
+ {
+ PNG_CONST int step =
+ (row_info->color_type & PNG_COLOR_MASK_COLOR) ? 8 : 4;
+
+ /* The alpha channel is the last component: */
+ row += step - 2;
+
+ for (; row_width > 0; --row_width, row += step)
+ {
+ png_uint_16 v;
+
+ v = table[*(row + 1) >> gamma_shift][*row];
+ *row = (png_byte)((v >> 8) & 0xff);
+ *(row + 1) = (png_byte)(v & 0xff);
+ }
+
+ return;
+ }
+ }
+ }
+
+ /* Only get to here if called with a weird row_info; no harm has been done,
+ * so just issue a warning.
+ */
+ png_warning(png_ptr, "png_do_encode_alpha: unexpected call");
+}
+#endif
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+/* Expands a palette row to an RGB or RGBA row depending
+ * upon whether you supply trans and num_trans.
+ */
+void /* PRIVATE */
+png_do_expand_palette(png_row_infop row_info, png_bytep row,
+ png_const_colorp palette, png_const_bytep trans_alpha, int num_trans)
+{
+ int shift, value;
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width=row_info->width;
+
+ png_debug(1, "in png_do_expand_palette");
+
+ if (row_info->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (row_info->bit_depth < 8)
+ {
+ switch (row_info->bit_depth)
+ {
+ case 1:
+ {
+ sp = row + (png_size_t)((row_width - 1) >> 3);
+ dp = row + (png_size_t)row_width - 1;
+ shift = 7 - (int)((row_width + 7) & 0x07);
+ for (i = 0; i < row_width; i++)
+ {
+ if ((*sp >> shift) & 0x01)
+ *dp = 1;
+
+ else
+ *dp = 0;
+
+ if (shift == 7)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift++;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 2:
+ {
+ sp = row + (png_size_t)((row_width - 1) >> 2);
+ dp = row + (png_size_t)row_width - 1;
+ shift = (int)((3 - ((row_width + 3) & 0x03)) << 1);
+ for (i = 0; i < row_width; i++)
+ {
+ value = (*sp >> shift) & 0x03;
+ *dp = (png_byte)value;
+ if (shift == 6)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift += 2;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 4:
+ {
+ sp = row + (png_size_t)((row_width - 1) >> 1);
+ dp = row + (png_size_t)row_width - 1;
+ shift = (int)((row_width & 0x01) << 2);
+ for (i = 0; i < row_width; i++)
+ {
+ value = (*sp >> shift) & 0x0f;
+ *dp = (png_byte)value;
+ if (shift == 4)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift += 4;
+
+ dp--;
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = 8;
+ row_info->rowbytes = row_width;
+ }
+
+ if (row_info->bit_depth == 8)
+ {
+ {
+ if (num_trans > 0)
+ {
+ sp = row + (png_size_t)row_width - 1;
+ dp = row + (png_size_t)(row_width << 2) - 1;
+
+ for (i = 0; i < row_width; i++)
+ {
+ if ((int)(*sp) >= num_trans)
+ *dp-- = 0xff;
+
+ else
+ *dp-- = trans_alpha[*sp];
+
+ *dp-- = palette[*sp].blue;
+ *dp-- = palette[*sp].green;
+ *dp-- = palette[*sp].red;
+ sp--;
+ }
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = 32;
+ row_info->rowbytes = row_width * 4;
+ row_info->color_type = 6;
+ row_info->channels = 4;
+ }
+
+ else
+ {
+ sp = row + (png_size_t)row_width - 1;
+ dp = row + (png_size_t)(row_width * 3) - 1;
+
+ for (i = 0; i < row_width; i++)
+ {
+ *dp-- = palette[*sp].blue;
+ *dp-- = palette[*sp].green;
+ *dp-- = palette[*sp].red;
+ sp--;
+ }
+
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = 24;
+ row_info->rowbytes = row_width * 3;
+ row_info->color_type = 2;
+ row_info->channels = 3;
+ }
+ }
+ }
+ }
+}
+
+/* If the bit depth < 8, it is expanded to 8. Also, if the already
+ * expanded transparency value is supplied, an alpha channel is built.
+ */
+void /* PRIVATE */
+png_do_expand(png_row_infop row_info, png_bytep row,
+ png_const_color_16p trans_color)
+{
+ int shift, value;
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width=row_info->width;
+
+ png_debug(1, "in png_do_expand");
+
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ unsigned int gray = trans_color ? trans_color->gray : 0;
+
+ if (row_info->bit_depth < 8)
+ {
+ switch (row_info->bit_depth)
+ {
+ case 1:
+ {
+ gray = (gray & 0x01) * 0xff;
+ sp = row + (png_size_t)((row_width - 1) >> 3);
+ dp = row + (png_size_t)row_width - 1;
+ shift = 7 - (int)((row_width + 7) & 0x07);
+ for (i = 0; i < row_width; i++)
+ {
+ if ((*sp >> shift) & 0x01)
+ *dp = 0xff;
+
+ else
+ *dp = 0;
+
+ if (shift == 7)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift++;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 2:
+ {
+ gray = (gray & 0x03) * 0x55;
+ sp = row + (png_size_t)((row_width - 1) >> 2);
+ dp = row + (png_size_t)row_width - 1;
+ shift = (int)((3 - ((row_width + 3) & 0x03)) << 1);
+ for (i = 0; i < row_width; i++)
+ {
+ value = (*sp >> shift) & 0x03;
+ *dp = (png_byte)(value | (value << 2) | (value << 4) |
+ (value << 6));
+ if (shift == 6)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift += 2;
+
+ dp--;
+ }
+ break;
+ }
+
+ case 4:
+ {
+ gray = (gray & 0x0f) * 0x11;
+ sp = row + (png_size_t)((row_width - 1) >> 1);
+ dp = row + (png_size_t)row_width - 1;
+ shift = (int)((1 - ((row_width + 1) & 0x01)) << 2);
+ for (i = 0; i < row_width; i++)
+ {
+ value = (*sp >> shift) & 0x0f;
+ *dp = (png_byte)(value | (value << 4));
+ if (shift == 4)
+ {
+ shift = 0;
+ sp--;
+ }
+
+ else
+ shift = 4;
+
+ dp--;
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ row_info->bit_depth = 8;
+ row_info->pixel_depth = 8;
+ row_info->rowbytes = row_width;
+ }
+
+ if (trans_color != NULL)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ gray = gray & 0xff;
+ sp = row + (png_size_t)row_width - 1;
+ dp = row + (png_size_t)(row_width << 1) - 1;
+
+ for (i = 0; i < row_width; i++)
+ {
+ if (*sp == gray)
+ *dp-- = 0;
+
+ else
+ *dp-- = 0xff;
+
+ *dp-- = *sp--;
+ }
+ }
+
+ else if (row_info->bit_depth == 16)
+ {
+ unsigned int gray_high = (gray >> 8) & 0xff;
+ unsigned int gray_low = gray & 0xff;
+ sp = row + row_info->rowbytes - 1;
+ dp = row + (row_info->rowbytes << 1) - 1;
+ for (i = 0; i < row_width; i++)
+ {
+ if (*(sp - 1) == gray_high && *(sp) == gray_low)
+ {
+ *dp-- = 0;
+ *dp-- = 0;
+ }
+
+ else
+ {
+ *dp-- = 0xff;
+ *dp-- = 0xff;
+ }
+
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ }
+ }
+
+ row_info->color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
+ row_info->channels = 2;
+ row_info->pixel_depth = (png_byte)(row_info->bit_depth << 1);
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth,
+ row_width);
+ }
+ }
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB && trans_color)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ png_byte red = (png_byte)(trans_color->red & 0xff);
+ png_byte green = (png_byte)(trans_color->green & 0xff);
+ png_byte blue = (png_byte)(trans_color->blue & 0xff);
+ sp = row + (png_size_t)row_info->rowbytes - 1;
+ dp = row + (png_size_t)(row_width << 2) - 1;
+ for (i = 0; i < row_width; i++)
+ {
+ if (*(sp - 2) == red && *(sp - 1) == green && *(sp) == blue)
+ *dp-- = 0;
+
+ else
+ *dp-- = 0xff;
+
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ }
+ }
+ else if (row_info->bit_depth == 16)
+ {
+ png_byte red_high = (png_byte)((trans_color->red >> 8) & 0xff);
+ png_byte green_high = (png_byte)((trans_color->green >> 8) & 0xff);
+ png_byte blue_high = (png_byte)((trans_color->blue >> 8) & 0xff);
+ png_byte red_low = (png_byte)(trans_color->red & 0xff);
+ png_byte green_low = (png_byte)(trans_color->green & 0xff);
+ png_byte blue_low = (png_byte)(trans_color->blue & 0xff);
+ sp = row + row_info->rowbytes - 1;
+ dp = row + (png_size_t)(row_width << 3) - 1;
+ for (i = 0; i < row_width; i++)
+ {
+ if (*(sp - 5) == red_high &&
+ *(sp - 4) == red_low &&
+ *(sp - 3) == green_high &&
+ *(sp - 2) == green_low &&
+ *(sp - 1) == blue_high &&
+ *(sp ) == blue_low)
+ {
+ *dp-- = 0;
+ *dp-- = 0;
+ }
+
+ else
+ {
+ *dp-- = 0xff;
+ *dp-- = 0xff;
+ }
+
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ *dp-- = *sp--;
+ }
+ }
+ row_info->color_type = PNG_COLOR_TYPE_RGB_ALPHA;
+ row_info->channels = 4;
+ row_info->pixel_depth = (png_byte)(row_info->bit_depth << 2);
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width);
+ }
+ }
+}
+#endif
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+/* If the bit depth is 8 and the color type is not a palette type expand the
+ * whole row to 16 bits. Has no effect otherwise.
+ */
+void /* PRIVATE */
+png_do_expand_16(png_row_infop row_info, png_bytep row)
+{
+ if (row_info->bit_depth == 8 &&
+ row_info->color_type != PNG_COLOR_TYPE_PALETTE)
+ {
+ /* The row have a sequence of bytes containing [0..255] and we need
+ * to turn it into another row containing [0..65535], to do this we
+ * calculate:
+ *
+ * (input / 255) * 65535
+ *
+ * Which happens to be exactly input * 257 and this can be achieved
+ * simply by byte replication in place (copying backwards).
+ */
+ png_byte *sp = row + row_info->rowbytes; /* source, last byte + 1 */
+ png_byte *dp = sp + row_info->rowbytes; /* destination, end + 1 */
+ while (dp > sp)
+ dp[-2] = dp[-1] = *--sp, dp -= 2;
+
+ row_info->rowbytes *= 2;
+ row_info->bit_depth = 16;
+ row_info->pixel_depth = (png_byte)(row_info->channels * 16);
+ }
+}
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+void /* PRIVATE */
+png_do_quantize(png_row_infop row_info, png_bytep row,
+ png_const_bytep palette_lookup, png_const_bytep quantize_lookup)
+{
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width=row_info->width;
+
+ png_debug(1, "in png_do_quantize");
+
+ if (row_info->bit_depth == 8)
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB && palette_lookup)
+ {
+ int r, g, b, p;
+ sp = row;
+ dp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ r = *sp++;
+ g = *sp++;
+ b = *sp++;
+
+ /* This looks real messy, but the compiler will reduce
+ * it down to a reasonable formula. For example, with
+ * 5 bits per color, we get:
+ * p = (((r >> 3) & 0x1f) << 10) |
+ * (((g >> 3) & 0x1f) << 5) |
+ * ((b >> 3) & 0x1f);
+ */
+ p = (((r >> (8 - PNG_QUANTIZE_RED_BITS)) &
+ ((1 << PNG_QUANTIZE_RED_BITS) - 1)) <<
+ (PNG_QUANTIZE_GREEN_BITS + PNG_QUANTIZE_BLUE_BITS)) |
+ (((g >> (8 - PNG_QUANTIZE_GREEN_BITS)) &
+ ((1 << PNG_QUANTIZE_GREEN_BITS) - 1)) <<
+ (PNG_QUANTIZE_BLUE_BITS)) |
+ ((b >> (8 - PNG_QUANTIZE_BLUE_BITS)) &
+ ((1 << PNG_QUANTIZE_BLUE_BITS) - 1));
+
+ *dp++ = palette_lookup[p];
+ }
+
+ row_info->color_type = PNG_COLOR_TYPE_PALETTE;
+ row_info->channels = 1;
+ row_info->pixel_depth = row_info->bit_depth;
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width);
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA &&
+ palette_lookup != NULL)
+ {
+ int r, g, b, p;
+ sp = row;
+ dp = row;
+ for (i = 0; i < row_width; i++)
+ {
+ r = *sp++;
+ g = *sp++;
+ b = *sp++;
+ sp++;
+
+ p = (((r >> (8 - PNG_QUANTIZE_RED_BITS)) &
+ ((1 << PNG_QUANTIZE_RED_BITS) - 1)) <<
+ (PNG_QUANTIZE_GREEN_BITS + PNG_QUANTIZE_BLUE_BITS)) |
+ (((g >> (8 - PNG_QUANTIZE_GREEN_BITS)) &
+ ((1 << PNG_QUANTIZE_GREEN_BITS) - 1)) <<
+ (PNG_QUANTIZE_BLUE_BITS)) |
+ ((b >> (8 - PNG_QUANTIZE_BLUE_BITS)) &
+ ((1 << PNG_QUANTIZE_BLUE_BITS) - 1));
+
+ *dp++ = palette_lookup[p];
+ }
+
+ row_info->color_type = PNG_COLOR_TYPE_PALETTE;
+ row_info->channels = 1;
+ row_info->pixel_depth = row_info->bit_depth;
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width);
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_PALETTE &&
+ quantize_lookup)
+ {
+ sp = row;
+
+ for (i = 0; i < row_width; i++, sp++)
+ {
+ *sp = quantize_lookup[*sp];
+ }
+ }
+ }
+}
+#endif /* PNG_READ_QUANTIZE_SUPPORTED */
+#endif /* PNG_READ_TRANSFORMS_SUPPORTED */
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+/* Undoes intrapixel differencing */
+void /* PRIVATE */
+png_do_read_intrapixel(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_read_intrapixel");
+
+ if (
+ (row_info->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ int bytes_per_pixel;
+ png_uint_32 row_width = row_info->width;
+
+ if (row_info->bit_depth == 8)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ bytes_per_pixel = 3;
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ bytes_per_pixel = 4;
+
+ else
+ return;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel)
+ {
+ *(rp) = (png_byte)((256 + *rp + *(rp + 1)) & 0xff);
+ *(rp+2) = (png_byte)((256 + *(rp + 2) + *(rp + 1)) & 0xff);
+ }
+ }
+ else if (row_info->bit_depth == 16)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ bytes_per_pixel = 6;
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ bytes_per_pixel = 8;
+
+ else
+ return;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel)
+ {
+ png_uint_32 s0 = (*(rp ) << 8) | *(rp + 1);
+ png_uint_32 s1 = (*(rp + 2) << 8) | *(rp + 3);
+ png_uint_32 s2 = (*(rp + 4) << 8) | *(rp + 5);
+ png_uint_32 red = (s0 + s1 + 65536) & 0xffff;
+ png_uint_32 blue = (s2 + s1 + 65536) & 0xffff;
+ *(rp ) = (png_byte)((red >> 8) & 0xff);
+ *(rp + 1) = (png_byte)(red & 0xff);
+ *(rp + 4) = (png_byte)((blue >> 8) & 0xff);
+ *(rp + 5) = (png_byte)(blue & 0xff);
+ }
+ }
+ }
+}
+#endif /* PNG_MNG_FEATURES_SUPPORTED */
+#endif /* PNG_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngrutil.c b/ml/dlib/dlib/external/libpng/pngrutil.c
new file mode 100644
index 000000000..2438384dd
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngrutil.c
@@ -0,0 +1,4475 @@
+
+/* pngrutil.c - utilities to read a PNG file
+ *
+ * Last changed in libpng 1.6.7 [November 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file contains routines that are only called from within
+ * libpng itself during the course of reading an image.
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_READ_SUPPORTED
+
+png_uint_32 PNGAPI
+png_get_uint_31(png_const_structrp png_ptr, png_const_bytep buf)
+{
+ png_uint_32 uval = png_get_uint_32(buf);
+
+ if (uval > PNG_UINT_31_MAX)
+ png_error(png_ptr, "PNG unsigned integer out of range");
+
+ return (uval);
+}
+
+#if defined(PNG_READ_gAMA_SUPPORTED) || defined(PNG_READ_cHRM_SUPPORTED)
+/* The following is a variation on the above for use with the fixed
+ * point values used for gAMA and cHRM. Instead of png_error it
+ * issues a warning and returns (-1) - an invalid value because both
+ * gAMA and cHRM use *unsigned* integers for fixed point values.
+ */
+#define PNG_FIXED_ERROR (-1)
+
+static png_fixed_point /* PRIVATE */
+png_get_fixed_point(png_structrp png_ptr, png_const_bytep buf)
+{
+ png_uint_32 uval = png_get_uint_32(buf);
+
+ if (uval <= PNG_UINT_31_MAX)
+ return (png_fixed_point)uval; /* known to be in range */
+
+ /* The caller can turn off the warning by passing NULL. */
+ if (png_ptr != NULL)
+ png_warning(png_ptr, "PNG fixed point integer out of range");
+
+ return PNG_FIXED_ERROR;
+}
+#endif
+
+#ifdef PNG_READ_INT_FUNCTIONS_SUPPORTED
+/* NOTE: the read macros will obscure these definitions, so that if
+ * PNG_USE_READ_MACROS is set the library will not use them internally,
+ * but the APIs will still be available externally.
+ *
+ * The parentheses around "PNGAPI function_name" in the following three
+ * functions are necessary because they allow the macros to co-exist with
+ * these (unused but exported) functions.
+ */
+
+/* Grab an unsigned 32-bit integer from a buffer in big-endian format. */
+png_uint_32 (PNGAPI
+png_get_uint_32)(png_const_bytep buf)
+{
+ png_uint_32 uval =
+ ((png_uint_32)(*(buf )) << 24) +
+ ((png_uint_32)(*(buf + 1)) << 16) +
+ ((png_uint_32)(*(buf + 2)) << 8) +
+ ((png_uint_32)(*(buf + 3)) ) ;
+
+ return uval;
+}
+
+/* Grab a signed 32-bit integer from a buffer in big-endian format. The
+ * data is stored in the PNG file in two's complement format and there
+ * is no guarantee that a 'png_int_32' is exactly 32 bits, therefore
+ * the following code does a two's complement to native conversion.
+ */
+png_int_32 (PNGAPI
+png_get_int_32)(png_const_bytep buf)
+{
+ png_uint_32 uval = png_get_uint_32(buf);
+ if ((uval & 0x80000000) == 0) /* non-negative */
+ return uval;
+
+ uval = (uval ^ 0xffffffff) + 1; /* 2's complement: -x = ~x+1 */
+ return -(png_int_32)uval;
+}
+
+/* Grab an unsigned 16-bit integer from a buffer in big-endian format. */
+png_uint_16 (PNGAPI
+png_get_uint_16)(png_const_bytep buf)
+{
+ /* ANSI-C requires an int value to accomodate at least 16 bits so this
+ * works and allows the compiler not to worry about possible narrowing
+ * on 32 bit systems. (Pre-ANSI systems did not make integers smaller
+ * than 16 bits either.)
+ */
+ unsigned int val =
+ ((unsigned int)(*buf) << 8) +
+ ((unsigned int)(*(buf + 1)));
+
+ return (png_uint_16)val;
+}
+
+#endif /* PNG_READ_INT_FUNCTIONS_SUPPORTED */
+
+/* Read and check the PNG file signature */
+void /* PRIVATE */
+png_read_sig(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_size_t num_checked, num_to_check;
+
+ /* Exit if the user application does not expect a signature. */
+ if (png_ptr->sig_bytes >= 8)
+ return;
+
+ num_checked = png_ptr->sig_bytes;
+ num_to_check = 8 - num_checked;
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ png_ptr->io_state = PNG_IO_READING | PNG_IO_SIGNATURE;
+#endif
+
+ /* The signature must be serialized in a single I/O call. */
+ png_read_data(png_ptr, &(info_ptr->signature[num_checked]), num_to_check);
+ png_ptr->sig_bytes = 8;
+
+ if (png_sig_cmp(info_ptr->signature, num_checked, num_to_check))
+ {
+ if (num_checked < 4 &&
+ png_sig_cmp(info_ptr->signature, num_checked, num_to_check - 4))
+ png_error(png_ptr, "Not a PNG file");
+ else
+ png_error(png_ptr, "PNG file corrupted by ASCII conversion");
+ }
+ if (num_checked < 3)
+ png_ptr->mode |= PNG_HAVE_PNG_SIGNATURE;
+}
+
+/* Read the chunk header (length + type name).
+ * Put the type name into png_ptr->chunk_name, and return the length.
+ */
+png_uint_32 /* PRIVATE */
+png_read_chunk_header(png_structrp png_ptr)
+{
+ png_byte buf[8];
+ png_uint_32 length;
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_HDR;
+#endif
+
+ /* Read the length and the chunk name.
+ * This must be performed in a single I/O call.
+ */
+ png_read_data(png_ptr, buf, 8);
+ length = png_get_uint_31(png_ptr, buf);
+
+ /* Put the chunk name into png_ptr->chunk_name. */
+ png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(buf+4);
+
+ png_debug2(0, "Reading %lx chunk, length = %lu",
+ (unsigned long)png_ptr->chunk_name, (unsigned long)length);
+
+ /* Reset the crc and run it over the chunk name. */
+ png_reset_crc(png_ptr);
+ png_calculate_crc(png_ptr, buf + 4, 4);
+
+ /* Check to see if chunk name is valid. */
+ png_check_chunk_name(png_ptr, png_ptr->chunk_name);
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_DATA;
+#endif
+
+ return length;
+}
+
+/* Read data, and (optionally) run it through the CRC. */
+void /* PRIVATE */
+png_crc_read(png_structrp png_ptr, png_bytep buf, png_uint_32 length)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_read_data(png_ptr, buf, length);
+ png_calculate_crc(png_ptr, buf, length);
+}
+
+/* Optionally skip data and then check the CRC. Depending on whether we
+ * are reading an ancillary or critical chunk, and how the program has set
+ * things up, we may calculate the CRC on the data and print a message.
+ * Returns '1' if there was a CRC error, '0' otherwise.
+ */
+int /* PRIVATE */
+png_crc_finish(png_structrp png_ptr, png_uint_32 skip)
+{
+ /* The size of the local buffer for inflate is a good guess as to a
+ * reasonable size to use for buffering reads from the application.
+ */
+ while (skip > 0)
+ {
+ png_uint_32 len;
+ png_byte tmpbuf[PNG_INFLATE_BUF_SIZE];
+
+ len = (sizeof tmpbuf);
+ if (len > skip)
+ len = skip;
+ skip -= len;
+
+ png_crc_read(png_ptr, tmpbuf, len);
+ }
+
+ if (png_crc_error(png_ptr))
+ {
+ if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name) ?
+ !(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN) :
+ (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_USE))
+ {
+ png_chunk_warning(png_ptr, "CRC error");
+ }
+
+ else
+ {
+ png_chunk_benign_error(png_ptr, "CRC error");
+ return (0);
+ }
+
+ return (1);
+ }
+
+ return (0);
+}
+
+/* Compare the CRC stored in the PNG file with that calculated by libpng from
+ * the data it has read thus far.
+ */
+int /* PRIVATE */
+png_crc_error(png_structrp png_ptr)
+{
+ png_byte crc_bytes[4];
+ png_uint_32 crc;
+ int need_crc = 1;
+
+ if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name))
+ {
+ if ((png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_MASK) ==
+ (PNG_FLAG_CRC_ANCILLARY_USE | PNG_FLAG_CRC_ANCILLARY_NOWARN))
+ need_crc = 0;
+ }
+
+ else /* critical */
+ {
+ if (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_IGNORE)
+ need_crc = 0;
+ }
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_CRC;
+#endif
+
+ /* The chunk CRC must be serialized in a single I/O call. */
+ png_read_data(png_ptr, crc_bytes, 4);
+
+ if (need_crc)
+ {
+ crc = png_get_uint_32(crc_bytes);
+ return ((int)(crc != png_ptr->crc));
+ }
+
+ else
+ return (0);
+}
+
+/* Manage the read buffer; this simply reallocates the buffer if it is not small
+ * enough (or if it is not allocated). The routine returns a pointer to the
+ * buffer; if an error occurs and 'warn' is set the routine returns NULL, else
+ * it will call png_error (via png_malloc) on failure. (warn == 2 means
+ * 'silent').
+ */
+static png_bytep
+png_read_buffer(png_structrp png_ptr, png_alloc_size_t new_size, int warn)
+{
+ png_bytep buffer = png_ptr->read_buffer;
+
+ if (buffer != NULL && new_size > png_ptr->read_buffer_size)
+ {
+ png_ptr->read_buffer = NULL;
+ png_ptr->read_buffer = NULL;
+ png_ptr->read_buffer_size = 0;
+ png_free(png_ptr, buffer);
+ buffer = NULL;
+ }
+
+ if (buffer == NULL)
+ {
+ buffer = png_voidcast(png_bytep, png_malloc_base(png_ptr, new_size));
+
+ if (buffer != NULL)
+ {
+ png_ptr->read_buffer = buffer;
+ png_ptr->read_buffer_size = new_size;
+ }
+
+ else if (warn < 2) /* else silent */
+ {
+#ifdef PNG_WARNINGS_SUPPORTED
+ if (warn)
+ png_chunk_warning(png_ptr, "insufficient memory to read chunk");
+ else
+#endif
+ {
+#ifdef PNG_ERROR_TEXT_SUPPORTED
+ png_chunk_error(png_ptr, "insufficient memory to read chunk");
+#endif
+ }
+ }
+ }
+
+ return buffer;
+}
+
+/* png_inflate_claim: claim the zstream for some nefarious purpose that involves
+ * decompression. Returns Z_OK on success, else a zlib error code. It checks
+ * the owner but, in final release builds, just issues a warning if some other
+ * chunk apparently owns the stream. Prior to release it does a png_error.
+ */
+static int
+png_inflate_claim(png_structrp png_ptr, png_uint_32 owner)
+{
+ if (png_ptr->zowner != 0)
+ {
+ char msg[64];
+
+ PNG_STRING_FROM_CHUNK(msg, png_ptr->zowner);
+ /* So the message that results is "<chunk> using zstream"; this is an
+ * internal error, but is very useful for debugging. i18n requirements
+ * are minimal.
+ */
+ (void)png_safecat(msg, (sizeof msg), 4, " using zstream");
+# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC
+ png_chunk_warning(png_ptr, msg);
+ png_ptr->zowner = 0;
+# else
+ png_chunk_error(png_ptr, msg);
+# endif
+ }
+
+ /* Implementation note: unlike 'png_deflate_claim' this internal function
+ * does not take the size of the data as an argument. Some efficiency could
+ * be gained by using this when it is known *if* the zlib stream itself does
+ * not record the number; however, this is an illusion: the original writer
+ * of the PNG may have selected a lower window size, and we really must
+ * follow that because, for systems with with limited capabilities, we
+ * would otherwise reject the application's attempts to use a smaller window
+ * size (zlib doesn't have an interface to say "this or lower"!).
+ *
+ * inflateReset2 was added to zlib 1.2.4; before this the window could not be
+ * reset, therefore it is necessary to always allocate the maximum window
+ * size with earlier zlibs just in case later compressed chunks need it.
+ */
+ {
+ int ret; /* zlib return code */
+# if PNG_ZLIB_VERNUM >= 0x1240
+
+# if defined(PNG_SET_OPTION_SUPPORTED) && \
+ defined(PNG_MAXIMUM_INFLATE_WINDOW)
+ int window_bits;
+
+ if (((png_ptr->options >> PNG_MAXIMUM_INFLATE_WINDOW) & 3) ==
+ PNG_OPTION_ON)
+ window_bits = 15;
+
+ else
+ window_bits = 0;
+# else
+# define window_bits 0
+# endif
+# endif
+
+ /* Set this for safety, just in case the previous owner left pointers to
+ * memory allocations.
+ */
+ png_ptr->zstream.next_in = NULL;
+ png_ptr->zstream.avail_in = 0;
+ png_ptr->zstream.next_out = NULL;
+ png_ptr->zstream.avail_out = 0;
+
+ if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED)
+ {
+# if PNG_ZLIB_VERNUM < 0x1240
+ ret = inflateReset(&png_ptr->zstream);
+# else
+ ret = inflateReset2(&png_ptr->zstream, window_bits);
+# endif
+ }
+
+ else
+ {
+# if PNG_ZLIB_VERNUM < 0x1240
+ ret = inflateInit(&png_ptr->zstream);
+# else
+ ret = inflateInit2(&png_ptr->zstream, window_bits);
+# endif
+
+ if (ret == Z_OK)
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_INITIALIZED;
+ }
+
+ if (ret == Z_OK)
+ png_ptr->zowner = owner;
+
+ else
+ png_zstream_error(png_ptr, ret);
+
+ return ret;
+ }
+
+# ifdef window_bits
+# undef window_bits
+# endif
+}
+
+#ifdef PNG_READ_COMPRESSED_TEXT_SUPPORTED
+/* png_inflate now returns zlib error codes including Z_OK and Z_STREAM_END to
+ * allow the caller to do multiple calls if required. If the 'finish' flag is
+ * set Z_FINISH will be passed to the final inflate() call and Z_STREAM_END must
+ * be returned or there has been a problem, otherwise Z_SYNC_FLUSH is used and
+ * Z_OK or Z_STREAM_END will be returned on success.
+ *
+ * The input and output sizes are updated to the actual amounts of data consumed
+ * or written, not the amount available (as in a z_stream). The data pointers
+ * are not changed, so the next input is (data+input_size) and the next
+ * available output is (output+output_size).
+ */
+static int
+png_inflate(png_structrp png_ptr, png_uint_32 owner, int finish,
+ /* INPUT: */ png_const_bytep input, png_uint_32p input_size_ptr,
+ /* OUTPUT: */ png_bytep output, png_alloc_size_t *output_size_ptr)
+{
+ if (png_ptr->zowner == owner) /* Else not claimed */
+ {
+ int ret;
+ png_alloc_size_t avail_out = *output_size_ptr;
+ png_uint_32 avail_in = *input_size_ptr;
+
+ /* zlib can't necessarily handle more than 65535 bytes at once (i.e. it
+ * can't even necessarily handle 65536 bytes) because the type uInt is
+ * "16 bits or more". Consequently it is necessary to chunk the input to
+ * zlib. This code uses ZLIB_IO_MAX, from pngpriv.h, as the maximum (the
+ * maximum value that can be stored in a uInt.) It is possible to set
+ * ZLIB_IO_MAX to a lower value in pngpriv.h and this may sometimes have
+ * a performance advantage, because it reduces the amount of data accessed
+ * at each step and that may give the OS more time to page it in.
+ */
+ png_ptr->zstream.next_in = PNGZ_INPUT_CAST(input);
+ /* avail_in and avail_out are set below from 'size' */
+ png_ptr->zstream.avail_in = 0;
+ png_ptr->zstream.avail_out = 0;
+
+ /* Read directly into the output if it is available (this is set to
+ * a local buffer below if output is NULL).
+ */
+ if (output != NULL)
+ png_ptr->zstream.next_out = output;
+
+ do
+ {
+ uInt avail;
+ Byte local_buffer[PNG_INFLATE_BUF_SIZE];
+
+ /* zlib INPUT BUFFER */
+ /* The setting of 'avail_in' used to be outside the loop; by setting it
+ * inside it is possible to chunk the input to zlib and simply rely on
+ * zlib to advance the 'next_in' pointer. This allows arbitrary
+ * amounts of data to be passed through zlib at the unavoidable cost of
+ * requiring a window save (memcpy of up to 32768 output bytes)
+ * every ZLIB_IO_MAX input bytes.
+ */
+ avail_in += png_ptr->zstream.avail_in; /* not consumed last time */
+
+ avail = ZLIB_IO_MAX;
+
+ if (avail_in < avail)
+ avail = (uInt)avail_in; /* safe: < than ZLIB_IO_MAX */
+
+ avail_in -= avail;
+ png_ptr->zstream.avail_in = avail;
+
+ /* zlib OUTPUT BUFFER */
+ avail_out += png_ptr->zstream.avail_out; /* not written last time */
+
+ avail = ZLIB_IO_MAX; /* maximum zlib can process */
+
+ if (output == NULL)
+ {
+ /* Reset the output buffer each time round if output is NULL and
+ * make available the full buffer, up to 'remaining_space'
+ */
+ png_ptr->zstream.next_out = local_buffer;
+ if ((sizeof local_buffer) < avail)
+ avail = (sizeof local_buffer);
+ }
+
+ if (avail_out < avail)
+ avail = (uInt)avail_out; /* safe: < ZLIB_IO_MAX */
+
+ png_ptr->zstream.avail_out = avail;
+ avail_out -= avail;
+
+ /* zlib inflate call */
+ /* In fact 'avail_out' may be 0 at this point, that happens at the end
+ * of the read when the final LZ end code was not passed at the end of
+ * the previous chunk of input data. Tell zlib if we have reached the
+ * end of the output buffer.
+ */
+ ret = inflate(&png_ptr->zstream, avail_out > 0 ? Z_NO_FLUSH :
+ (finish ? Z_FINISH : Z_SYNC_FLUSH));
+ } while (ret == Z_OK);
+
+ /* For safety kill the local buffer pointer now */
+ if (output == NULL)
+ png_ptr->zstream.next_out = NULL;
+
+ /* Claw back the 'size' and 'remaining_space' byte counts. */
+ avail_in += png_ptr->zstream.avail_in;
+ avail_out += png_ptr->zstream.avail_out;
+
+ /* Update the input and output sizes; the updated values are the amount
+ * consumed or written, effectively the inverse of what zlib uses.
+ */
+ if (avail_out > 0)
+ *output_size_ptr -= avail_out;
+
+ if (avail_in > 0)
+ *input_size_ptr -= avail_in;
+
+ /* Ensure png_ptr->zstream.msg is set (even in the success case!) */
+ png_zstream_error(png_ptr, ret);
+ return ret;
+ }
+
+ else
+ {
+ /* This is a bad internal error. The recovery assigns to the zstream msg
+ * pointer, which is not owned by the caller, but this is safe; it's only
+ * used on errors!
+ */
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("zstream unclaimed");
+ return Z_STREAM_ERROR;
+ }
+}
+
+/*
+ * Decompress trailing data in a chunk. The assumption is that read_buffer
+ * points at an allocated area holding the contents of a chunk with a
+ * trailing compressed part. What we get back is an allocated area
+ * holding the original prefix part and an uncompressed version of the
+ * trailing part (the malloc area passed in is freed).
+ */
+static int
+png_decompress_chunk(png_structrp png_ptr,
+ png_uint_32 chunklength, png_uint_32 prefix_size,
+ png_alloc_size_t *newlength /* must be initialized to the maximum! */,
+ int terminate /*add a '\0' to the end of the uncompressed data*/)
+{
+ /* TODO: implement different limits for different types of chunk.
+ *
+ * The caller supplies *newlength set to the maximum length of the
+ * uncompressed data, but this routine allocates space for the prefix and
+ * maybe a '\0' terminator too. We have to assume that 'prefix_size' is
+ * limited only by the maximum chunk size.
+ */
+ png_alloc_size_t limit = PNG_SIZE_MAX;
+
+# ifdef PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED
+ if (png_ptr->user_chunk_malloc_max > 0 &&
+ png_ptr->user_chunk_malloc_max < limit)
+ limit = png_ptr->user_chunk_malloc_max;
+# elif PNG_USER_CHUNK_MALLOC_MAX > 0
+ if (PNG_USER_CHUNK_MALLOC_MAX < limit)
+ limit = PNG_USER_CHUNK_MALLOC_MAX;
+# endif
+
+ if (limit >= prefix_size + (terminate != 0))
+ {
+ int ret;
+
+ limit -= prefix_size + (terminate != 0);
+
+ if (limit < *newlength)
+ *newlength = limit;
+
+ /* Now try to claim the stream. */
+ ret = png_inflate_claim(png_ptr, png_ptr->chunk_name);
+
+ if (ret == Z_OK)
+ {
+ png_uint_32 lzsize = chunklength - prefix_size;
+
+ ret = png_inflate(png_ptr, png_ptr->chunk_name, 1/*finish*/,
+ /* input: */ png_ptr->read_buffer + prefix_size, &lzsize,
+ /* output: */ NULL, newlength);
+
+ if (ret == Z_STREAM_END)
+ {
+ /* Use 'inflateReset' here, not 'inflateReset2' because this
+ * preserves the previously decided window size (otherwise it would
+ * be necessary to store the previous window size.) In practice
+ * this doesn't matter anyway, because png_inflate will call inflate
+ * with Z_FINISH in almost all cases, so the window will not be
+ * maintained.
+ */
+ if (inflateReset(&png_ptr->zstream) == Z_OK)
+ {
+ /* Because of the limit checks above we know that the new,
+ * expanded, size will fit in a size_t (let alone an
+ * png_alloc_size_t). Use png_malloc_base here to avoid an
+ * extra OOM message.
+ */
+ png_alloc_size_t new_size = *newlength;
+ png_alloc_size_t buffer_size = prefix_size + new_size +
+ (terminate != 0);
+ png_bytep text = png_voidcast(png_bytep, png_malloc_base(png_ptr,
+ buffer_size));
+
+ if (text != NULL)
+ {
+ ret = png_inflate(png_ptr, png_ptr->chunk_name, 1/*finish*/,
+ png_ptr->read_buffer + prefix_size, &lzsize,
+ text + prefix_size, newlength);
+
+ if (ret == Z_STREAM_END)
+ {
+ if (new_size == *newlength)
+ {
+ if (terminate)
+ text[prefix_size + *newlength] = 0;
+
+ if (prefix_size > 0)
+ memcpy(text, png_ptr->read_buffer, prefix_size);
+
+ {
+ png_bytep old_ptr = png_ptr->read_buffer;
+
+ png_ptr->read_buffer = text;
+ png_ptr->read_buffer_size = buffer_size;
+ text = old_ptr; /* freed below */
+ }
+ }
+
+ else
+ {
+ /* The size changed on the second read, there can be no
+ * guarantee that anything is correct at this point.
+ * The 'msg' pointer has been set to "unexpected end of
+ * LZ stream", which is fine, but return an error code
+ * that the caller won't accept.
+ */
+ ret = PNG_UNEXPECTED_ZLIB_RETURN;
+ }
+ }
+
+ else if (ret == Z_OK)
+ ret = PNG_UNEXPECTED_ZLIB_RETURN; /* for safety */
+
+ /* Free the text pointer (this is the old read_buffer on
+ * success)
+ */
+ png_free(png_ptr, text);
+
+ /* This really is very benign, but it's still an error because
+ * the extra space may otherwise be used as a Trojan Horse.
+ */
+ if (ret == Z_STREAM_END &&
+ chunklength - prefix_size != lzsize)
+ png_chunk_benign_error(png_ptr, "extra compressed data");
+ }
+
+ else
+ {
+ /* Out of memory allocating the buffer */
+ ret = Z_MEM_ERROR;
+ png_zstream_error(png_ptr, Z_MEM_ERROR);
+ }
+ }
+
+ else
+ {
+ /* inflateReset failed, store the error message */
+ png_zstream_error(png_ptr, ret);
+
+ if (ret == Z_STREAM_END)
+ ret = PNG_UNEXPECTED_ZLIB_RETURN;
+ }
+ }
+
+ else if (ret == Z_OK)
+ ret = PNG_UNEXPECTED_ZLIB_RETURN;
+
+ /* Release the claimed stream */
+ png_ptr->zowner = 0;
+ }
+
+ else /* the claim failed */ if (ret == Z_STREAM_END) /* impossible! */
+ ret = PNG_UNEXPECTED_ZLIB_RETURN;
+
+ return ret;
+ }
+
+ else
+ {
+ /* Application/configuration limits exceeded */
+ png_zstream_error(png_ptr, Z_MEM_ERROR);
+ return Z_MEM_ERROR;
+ }
+}
+#endif /* PNG_READ_COMPRESSED_TEXT_SUPPORTED */
+
+#ifdef PNG_READ_iCCP_SUPPORTED
+/* Perform a partial read and decompress, producing 'avail_out' bytes and
+ * reading from the current chunk as required.
+ */
+static int
+png_inflate_read(png_structrp png_ptr, png_bytep read_buffer, uInt read_size,
+ png_uint_32p chunk_bytes, png_bytep next_out, png_alloc_size_t *out_size,
+ int finish)
+{
+ if (png_ptr->zowner == png_ptr->chunk_name)
+ {
+ int ret;
+
+ /* next_in and avail_in must have been initialized by the caller. */
+ png_ptr->zstream.next_out = next_out;
+ png_ptr->zstream.avail_out = 0; /* set in the loop */
+
+ do
+ {
+ if (png_ptr->zstream.avail_in == 0)
+ {
+ if (read_size > *chunk_bytes)
+ read_size = (uInt)*chunk_bytes;
+ *chunk_bytes -= read_size;
+
+ if (read_size > 0)
+ png_crc_read(png_ptr, read_buffer, read_size);
+
+ png_ptr->zstream.next_in = read_buffer;
+ png_ptr->zstream.avail_in = read_size;
+ }
+
+ if (png_ptr->zstream.avail_out == 0)
+ {
+ uInt avail = ZLIB_IO_MAX;
+ if (avail > *out_size)
+ avail = (uInt)*out_size;
+ *out_size -= avail;
+
+ png_ptr->zstream.avail_out = avail;
+ }
+
+ /* Use Z_SYNC_FLUSH when there is no more chunk data to ensure that all
+ * the available output is produced; this allows reading of truncated
+ * streams.
+ */
+ ret = inflate(&png_ptr->zstream,
+ *chunk_bytes > 0 ? Z_NO_FLUSH : (finish ? Z_FINISH : Z_SYNC_FLUSH));
+ }
+ while (ret == Z_OK && (*out_size > 0 || png_ptr->zstream.avail_out > 0));
+
+ *out_size += png_ptr->zstream.avail_out;
+ png_ptr->zstream.avail_out = 0; /* Should not be required, but is safe */
+
+ /* Ensure the error message pointer is always set: */
+ png_zstream_error(png_ptr, ret);
+ return ret;
+ }
+
+ else
+ {
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("zstream unclaimed");
+ return Z_STREAM_ERROR;
+ }
+}
+#endif
+
+/* Read and check the IDHR chunk */
+void /* PRIVATE */
+png_handle_IHDR(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte buf[13];
+ png_uint_32 width, height;
+ int bit_depth, color_type, compression_type, filter_type;
+ int interlace_type;
+
+ png_debug(1, "in png_handle_IHDR");
+
+ if (png_ptr->mode & PNG_HAVE_IHDR)
+ png_chunk_error(png_ptr, "out of place");
+
+ /* Check the length */
+ if (length != 13)
+ png_chunk_error(png_ptr, "invalid");
+
+ png_ptr->mode |= PNG_HAVE_IHDR;
+
+ png_crc_read(png_ptr, buf, 13);
+ png_crc_finish(png_ptr, 0);
+
+ width = png_get_uint_31(png_ptr, buf);
+ height = png_get_uint_31(png_ptr, buf + 4);
+ bit_depth = buf[8];
+ color_type = buf[9];
+ compression_type = buf[10];
+ filter_type = buf[11];
+ interlace_type = buf[12];
+
+ /* Set internal variables */
+ png_ptr->width = width;
+ png_ptr->height = height;
+ png_ptr->bit_depth = (png_byte)bit_depth;
+ png_ptr->interlaced = (png_byte)interlace_type;
+ png_ptr->color_type = (png_byte)color_type;
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ png_ptr->filter_type = (png_byte)filter_type;
+#endif
+ png_ptr->compression_type = (png_byte)compression_type;
+
+ /* Find number of channels */
+ switch (png_ptr->color_type)
+ {
+ default: /* invalid, png_set_IHDR calls png_error */
+ case PNG_COLOR_TYPE_GRAY:
+ case PNG_COLOR_TYPE_PALETTE:
+ png_ptr->channels = 1;
+ break;
+
+ case PNG_COLOR_TYPE_RGB:
+ png_ptr->channels = 3;
+ break;
+
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ png_ptr->channels = 2;
+ break;
+
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ png_ptr->channels = 4;
+ break;
+ }
+
+ /* Set up other useful info */
+ png_ptr->pixel_depth = (png_byte)(png_ptr->bit_depth *
+ png_ptr->channels);
+ png_ptr->rowbytes = PNG_ROWBYTES(png_ptr->pixel_depth, png_ptr->width);
+ png_debug1(3, "bit_depth = %d", png_ptr->bit_depth);
+ png_debug1(3, "channels = %d", png_ptr->channels);
+ png_debug1(3, "rowbytes = %lu", (unsigned long)png_ptr->rowbytes);
+ png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth,
+ color_type, interlace_type, compression_type, filter_type);
+}
+
+/* Read and check the palette */
+void /* PRIVATE */
+png_handle_PLTE(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_color palette[PNG_MAX_PALETTE_LENGTH];
+ int num, i;
+#ifdef PNG_POINTER_INDEXING_SUPPORTED
+ png_colorp pal_ptr;
+#endif
+
+ png_debug(1, "in png_handle_PLTE");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ /* Moved to before the 'after IDAT' check below because otherwise duplicate
+ * PLTE chunks are potentially ignored (the spec says there shall not be more
+ * than one PLTE, the error is not treated as benign, so this check trumps
+ * the requirement that PLTE appears before IDAT.)
+ */
+ else if (png_ptr->mode & PNG_HAVE_PLTE)
+ png_chunk_error(png_ptr, "duplicate");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ /* This is benign because the non-benign error happened before, when an
+ * IDAT was encountered in a color-mapped image with no PLTE.
+ */
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ png_ptr->mode |= PNG_HAVE_PLTE;
+
+ if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "ignored in grayscale PNG");
+ return;
+ }
+
+#ifndef PNG_READ_OPT_PLTE_SUPPORTED
+ if (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+#endif
+
+ if (length > 3*PNG_MAX_PALETTE_LENGTH || length % 3)
+ {
+ png_crc_finish(png_ptr, length);
+
+ if (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE)
+ png_chunk_benign_error(png_ptr, "invalid");
+
+ else
+ png_chunk_error(png_ptr, "invalid");
+
+ return;
+ }
+
+ /* The cast is safe because 'length' is less than 3*PNG_MAX_PALETTE_LENGTH */
+ num = (int)length / 3;
+
+#ifdef PNG_POINTER_INDEXING_SUPPORTED
+ for (i = 0, pal_ptr = palette; i < num; i++, pal_ptr++)
+ {
+ png_byte buf[3];
+
+ png_crc_read(png_ptr, buf, 3);
+ pal_ptr->red = buf[0];
+ pal_ptr->green = buf[1];
+ pal_ptr->blue = buf[2];
+ }
+#else
+ for (i = 0; i < num; i++)
+ {
+ png_byte buf[3];
+
+ png_crc_read(png_ptr, buf, 3);
+ /* Don't depend upon png_color being any order */
+ palette[i].red = buf[0];
+ palette[i].green = buf[1];
+ palette[i].blue = buf[2];
+ }
+#endif
+
+ /* If we actually need the PLTE chunk (ie for a paletted image), we do
+ * whatever the normal CRC configuration tells us. However, if we
+ * have an RGB image, the PLTE can be considered ancillary, so
+ * we will act as though it is.
+ */
+#ifndef PNG_READ_OPT_PLTE_SUPPORTED
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+#endif
+ {
+ png_crc_finish(png_ptr, 0);
+ }
+
+#ifndef PNG_READ_OPT_PLTE_SUPPORTED
+ else if (png_crc_error(png_ptr)) /* Only if we have a CRC error */
+ {
+ /* If we don't want to use the data from an ancillary chunk,
+ * we have two options: an error abort, or a warning and we
+ * ignore the data in this chunk (which should be OK, since
+ * it's considered ancillary for a RGB or RGBA image).
+ *
+ * IMPLEMENTATION NOTE: this is only here because png_crc_finish uses the
+ * chunk type to determine whether to check the ancillary or the critical
+ * flags.
+ */
+ if (!(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_USE))
+ {
+ if (png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN)
+ {
+ png_chunk_benign_error(png_ptr, "CRC error");
+ }
+
+ else
+ {
+ png_chunk_warning(png_ptr, "CRC error");
+ return;
+ }
+ }
+
+ /* Otherwise, we (optionally) emit a warning and use the chunk. */
+ else if (!(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN))
+ {
+ png_chunk_warning(png_ptr, "CRC error");
+ }
+ }
+#endif
+
+ /* TODO: png_set_PLTE has the side effect of setting png_ptr->palette to its
+ * own copy of the palette. This has the side effect that when png_start_row
+ * is called (this happens after any call to png_read_update_info) the
+ * info_ptr palette gets changed. This is extremely unexpected and
+ * confusing.
+ *
+ * Fix this by not sharing the palette in this way.
+ */
+ png_set_PLTE(png_ptr, info_ptr, palette, num);
+
+ /* The three chunks, bKGD, hIST and tRNS *must* appear after PLTE and before
+ * IDAT. Prior to 1.6.0 this was not checked; instead the code merely
+ * checked the apparent validity of a tRNS chunk inserted before PLTE on a
+ * palette PNG. 1.6.0 attempts to rigorously follow the standard and
+ * therefore does a benign error if the erroneous condition is detected *and*
+ * cancels the tRNS if the benign error returns. The alternative is to
+ * amend the standard since it would be rather hypocritical of the standards
+ * maintainers to ignore it.
+ */
+#ifdef PNG_READ_tRNS_SUPPORTED
+ if (png_ptr->num_trans > 0 ||
+ (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS) != 0))
+ {
+ /* Cancel this because otherwise it would be used if the transforms
+ * require it. Don't cancel the 'valid' flag because this would prevent
+ * detection of duplicate chunks.
+ */
+ png_ptr->num_trans = 0;
+
+ if (info_ptr != NULL)
+ info_ptr->num_trans = 0;
+
+ png_chunk_benign_error(png_ptr, "tRNS must be after");
+ }
+#endif
+
+#ifdef PNG_READ_hIST_SUPPORTED
+ if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST) != 0)
+ png_chunk_benign_error(png_ptr, "hIST must be after");
+#endif
+
+#ifdef PNG_READ_bKGD_SUPPORTED
+ if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD) != 0)
+ png_chunk_benign_error(png_ptr, "bKGD must be after");
+#endif
+}
+
+void /* PRIVATE */
+png_handle_IEND(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_debug(1, "in png_handle_IEND");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR) || !(png_ptr->mode & PNG_HAVE_IDAT))
+ png_chunk_error(png_ptr, "out of place");
+
+ png_ptr->mode |= (PNG_AFTER_IDAT | PNG_HAVE_IEND);
+
+ png_crc_finish(png_ptr, length);
+
+ if (length != 0)
+ png_chunk_benign_error(png_ptr, "invalid");
+
+ PNG_UNUSED(info_ptr)
+}
+
+#ifdef PNG_READ_gAMA_SUPPORTED
+void /* PRIVATE */
+png_handle_gAMA(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_fixed_point igamma;
+ png_byte buf[4];
+
+ png_debug(1, "in png_handle_gAMA");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ if (length != 4)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 4);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ igamma = png_get_fixed_point(NULL, buf);
+
+ png_colorspace_set_gamma(png_ptr, &png_ptr->colorspace, igamma);
+ png_colorspace_sync(png_ptr, info_ptr);
+}
+#endif
+
+#ifdef PNG_READ_sBIT_SUPPORTED
+void /* PRIVATE */
+png_handle_sBIT(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ unsigned int truelen;
+ png_byte buf[4];
+
+ png_debug(1, "in png_handle_sBIT");
+
+ buf[0] = buf[1] = buf[2] = buf[3] = 0;
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_sBIT))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ truelen = 3;
+
+ else
+ truelen = png_ptr->channels;
+
+ if (length != truelen || length > 4)
+ {
+ png_chunk_benign_error(png_ptr, "invalid");
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, truelen);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ if (png_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ {
+ png_ptr->sig_bit.red = buf[0];
+ png_ptr->sig_bit.green = buf[1];
+ png_ptr->sig_bit.blue = buf[2];
+ png_ptr->sig_bit.alpha = buf[3];
+ }
+
+ else
+ {
+ png_ptr->sig_bit.gray = buf[0];
+ png_ptr->sig_bit.red = buf[0];
+ png_ptr->sig_bit.green = buf[0];
+ png_ptr->sig_bit.blue = buf[0];
+ png_ptr->sig_bit.alpha = buf[1];
+ }
+
+ png_set_sBIT(png_ptr, info_ptr, &(png_ptr->sig_bit));
+}
+#endif
+
+#ifdef PNG_READ_cHRM_SUPPORTED
+void /* PRIVATE */
+png_handle_cHRM(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte buf[32];
+ png_xy xy;
+
+ png_debug(1, "in png_handle_cHRM");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ if (length != 32)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 32);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ xy.whitex = png_get_fixed_point(NULL, buf);
+ xy.whitey = png_get_fixed_point(NULL, buf + 4);
+ xy.redx = png_get_fixed_point(NULL, buf + 8);
+ xy.redy = png_get_fixed_point(NULL, buf + 12);
+ xy.greenx = png_get_fixed_point(NULL, buf + 16);
+ xy.greeny = png_get_fixed_point(NULL, buf + 20);
+ xy.bluex = png_get_fixed_point(NULL, buf + 24);
+ xy.bluey = png_get_fixed_point(NULL, buf + 28);
+
+ if (xy.whitex == PNG_FIXED_ERROR ||
+ xy.whitey == PNG_FIXED_ERROR ||
+ xy.redx == PNG_FIXED_ERROR ||
+ xy.redy == PNG_FIXED_ERROR ||
+ xy.greenx == PNG_FIXED_ERROR ||
+ xy.greeny == PNG_FIXED_ERROR ||
+ xy.bluex == PNG_FIXED_ERROR ||
+ xy.bluey == PNG_FIXED_ERROR)
+ {
+ png_chunk_benign_error(png_ptr, "invalid values");
+ return;
+ }
+
+ /* If a colorspace error has already been output skip this chunk */
+ if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID)
+ return;
+
+ if (png_ptr->colorspace.flags & PNG_COLORSPACE_FROM_cHRM)
+ {
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID;
+ png_colorspace_sync(png_ptr, info_ptr);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM;
+ (void)png_colorspace_set_chromaticities(png_ptr, &png_ptr->colorspace, &xy,
+ 1/*prefer cHRM values*/);
+ png_colorspace_sync(png_ptr, info_ptr);
+}
+#endif
+
+#ifdef PNG_READ_sRGB_SUPPORTED
+void /* PRIVATE */
+png_handle_sRGB(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte intent;
+
+ png_debug(1, "in png_handle_sRGB");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ if (length != 1)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, &intent, 1);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ /* If a colorspace error has already been output skip this chunk */
+ if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID)
+ return;
+
+ /* Only one sRGB or iCCP chunk is allowed, use the HAVE_INTENT flag to detect
+ * this.
+ */
+ if (png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_INTENT)
+ {
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID;
+ png_colorspace_sync(png_ptr, info_ptr);
+ png_chunk_benign_error(png_ptr, "too many profiles");
+ return;
+ }
+
+ (void)png_colorspace_set_sRGB(png_ptr, &png_ptr->colorspace, intent);
+ png_colorspace_sync(png_ptr, info_ptr);
+}
+#endif /* PNG_READ_sRGB_SUPPORTED */
+
+#ifdef PNG_READ_iCCP_SUPPORTED
+void /* PRIVATE */
+png_handle_iCCP(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+/* Note: this does not properly handle profiles that are > 64K under DOS */
+{
+ png_const_charp errmsg = NULL; /* error message output, or no error */
+ int finished = 0; /* crc checked */
+
+ png_debug(1, "in png_handle_iCCP");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ /* Consistent with all the above colorspace handling an obviously *invalid*
+ * chunk is just ignored, so does not invalidate the color space. An
+ * alternative is to set the 'invalid' flags at the start of this routine
+ * and only clear them in they were not set before and all the tests pass.
+ * The minimum 'deflate' stream is assumed to be just the 2 byte header and 4
+ * byte checksum. The keyword must be one character and there is a
+ * terminator (0) byte and the compression method.
+ */
+ if (length < 9)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "too short");
+ return;
+ }
+
+ /* If a colorspace error has already been output skip this chunk */
+ if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ /* Only one sRGB or iCCP chunk is allowed, use the HAVE_INTENT flag to detect
+ * this.
+ */
+ if ((png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_INTENT) == 0)
+ {
+ uInt read_length, keyword_length;
+ char keyword[81];
+
+ /* Find the keyword; the keyword plus separator and compression method
+ * bytes can be at most 81 characters long.
+ */
+ read_length = 81; /* maximum */
+ if (read_length > length)
+ read_length = (uInt)length;
+
+ png_crc_read(png_ptr, (png_bytep)keyword, read_length);
+ length -= read_length;
+
+ keyword_length = 0;
+ while (keyword_length < 80 && keyword_length < read_length &&
+ keyword[keyword_length] != 0)
+ ++keyword_length;
+
+ /* TODO: make the keyword checking common */
+ if (keyword_length >= 1 && keyword_length <= 79)
+ {
+ /* We only understand '0' compression - deflate - so if we get a
+ * different value we can't safely decode the chunk.
+ */
+ if (keyword_length+1 < read_length &&
+ keyword[keyword_length+1] == PNG_COMPRESSION_TYPE_BASE)
+ {
+ read_length -= keyword_length+2;
+
+ if (png_inflate_claim(png_ptr, png_iCCP) == Z_OK)
+ {
+ Byte profile_header[132];
+ Byte local_buffer[PNG_INFLATE_BUF_SIZE];
+ png_alloc_size_t size = (sizeof profile_header);
+
+ png_ptr->zstream.next_in = (Bytef*)keyword + (keyword_length+2);
+ png_ptr->zstream.avail_in = read_length;
+ (void)png_inflate_read(png_ptr, local_buffer,
+ (sizeof local_buffer), &length, profile_header, &size,
+ 0/*finish: don't, because the output is too small*/);
+
+ if (size == 0)
+ {
+ /* We have the ICC profile header; do the basic header checks.
+ */
+ const png_uint_32 profile_length =
+ png_get_uint_32(profile_header);
+
+ if (png_icc_check_length(png_ptr, &png_ptr->colorspace,
+ keyword, profile_length))
+ {
+ /* The length is apparently ok, so we can check the 132
+ * byte header.
+ */
+ if (png_icc_check_header(png_ptr, &png_ptr->colorspace,
+ keyword, profile_length, profile_header,
+ png_ptr->color_type))
+ {
+ /* Now read the tag table; a variable size buffer is
+ * needed at this point, allocate one for the whole
+ * profile. The header check has already validated
+ * that none of these stuff will overflow.
+ */
+ const png_uint_32 tag_count = png_get_uint_32(
+ profile_header+128);
+ png_bytep profile = png_read_buffer(png_ptr,
+ profile_length, 2/*silent*/);
+
+ if (profile != NULL)
+ {
+ memcpy(profile, profile_header,
+ (sizeof profile_header));
+
+ size = 12 * tag_count;
+
+ (void)png_inflate_read(png_ptr, local_buffer,
+ (sizeof local_buffer), &length,
+ profile + (sizeof profile_header), &size, 0);
+
+ /* Still expect a a buffer error because we expect
+ * there to be some tag data!
+ */
+ if (size == 0)
+ {
+ if (png_icc_check_tag_table(png_ptr,
+ &png_ptr->colorspace, keyword, profile_length,
+ profile))
+ {
+ /* The profile has been validated for basic
+ * security issues, so read the whole thing in.
+ */
+ size = profile_length - (sizeof profile_header)
+ - 12 * tag_count;
+
+ (void)png_inflate_read(png_ptr, local_buffer,
+ (sizeof local_buffer), &length,
+ profile + (sizeof profile_header) +
+ 12 * tag_count, &size, 1/*finish*/);
+
+ if (length > 0 && !(png_ptr->flags &
+ PNG_FLAG_BENIGN_ERRORS_WARN))
+ errmsg = "extra compressed data";
+
+ /* But otherwise allow extra data: */
+ else if (size == 0)
+ {
+ if (length > 0)
+ {
+ /* This can be handled completely, so
+ * keep going.
+ */
+ png_chunk_warning(png_ptr,
+ "extra compressed data");
+ }
+
+ png_crc_finish(png_ptr, length);
+ finished = 1;
+
+# ifdef PNG_sRGB_SUPPORTED
+ /* Check for a match against sRGB */
+ png_icc_set_sRGB(png_ptr,
+ &png_ptr->colorspace, profile,
+ png_ptr->zstream.adler);
+# endif
+
+ /* Steal the profile for info_ptr. */
+ if (info_ptr != NULL)
+ {
+ png_free_data(png_ptr, info_ptr,
+ PNG_FREE_ICCP, 0);
+
+ info_ptr->iccp_name = png_voidcast(char*,
+ png_malloc_base(png_ptr,
+ keyword_length+1));
+ if (info_ptr->iccp_name != NULL)
+ {
+ memcpy(info_ptr->iccp_name, keyword,
+ keyword_length+1);
+ info_ptr->iccp_proflen =
+ profile_length;
+ info_ptr->iccp_profile = profile;
+ png_ptr->read_buffer = NULL; /*steal*/
+ info_ptr->free_me |= PNG_FREE_ICCP;
+ info_ptr->valid |= PNG_INFO_iCCP;
+ }
+
+ else
+ {
+ png_ptr->colorspace.flags |=
+ PNG_COLORSPACE_INVALID;
+ errmsg = "out of memory";
+ }
+ }
+
+ /* else the profile remains in the read
+ * buffer which gets reused for subsequent
+ * chunks.
+ */
+
+ if (info_ptr != NULL)
+ png_colorspace_sync(png_ptr, info_ptr);
+
+ if (errmsg == NULL)
+ {
+ png_ptr->zowner = 0;
+ return;
+ }
+ }
+
+ else if (size > 0)
+ errmsg = "truncated";
+
+ else
+ errmsg = png_ptr->zstream.msg;
+ }
+
+ /* else png_icc_check_tag_table output an error */
+ }
+
+ else /* profile truncated */
+ errmsg = png_ptr->zstream.msg;
+ }
+
+ else
+ errmsg = "out of memory";
+ }
+
+ /* else png_icc_check_header output an error */
+ }
+
+ /* else png_icc_check_length output an error */
+ }
+
+ else /* profile truncated */
+ errmsg = png_ptr->zstream.msg;
+
+ /* Release the stream */
+ png_ptr->zowner = 0;
+ }
+
+ else /* png_inflate_claim failed */
+ errmsg = png_ptr->zstream.msg;
+ }
+
+ else
+ errmsg = "bad compression method"; /* or missing */
+ }
+
+ else
+ errmsg = "bad keyword";
+ }
+
+ else
+ errmsg = "too many profiles";
+
+ /* Failure: the reason is in 'errmsg' */
+ if (!finished)
+ png_crc_finish(png_ptr, length);
+
+ png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID;
+ png_colorspace_sync(png_ptr, info_ptr);
+ if (errmsg != NULL) /* else already output */
+ png_chunk_benign_error(png_ptr, errmsg);
+}
+#endif /* PNG_READ_iCCP_SUPPORTED */
+
+#ifdef PNG_READ_sPLT_SUPPORTED
+void /* PRIVATE */
+png_handle_sPLT(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+/* Note: this does not properly handle chunks that are > 64K under DOS */
+{
+ png_bytep entry_start, buffer;
+ png_sPLT_t new_palette;
+ png_sPLT_entryp pp;
+ png_uint_32 data_length;
+ int entry_size, i;
+ png_uint_32 skip = 0;
+ png_uint_32 dl;
+ png_size_t max_dl;
+
+ png_debug(1, "in png_handle_sPLT");
+
+#ifdef PNG_USER_LIMITS_SUPPORTED
+ if (png_ptr->user_chunk_cache_max != 0)
+ {
+ if (png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ if (--png_ptr->user_chunk_cache_max == 1)
+ {
+ png_warning(png_ptr, "No space in chunk cache for sPLT");
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+ }
+#endif
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+#ifdef PNG_MAX_MALLOC_64K
+ if (length > 65535U)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "too large to fit in memory");
+ return;
+ }
+#endif
+
+ buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/);
+ if (buffer == NULL)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+
+ /* WARNING: this may break if size_t is less than 32 bits; it is assumed
+ * that the PNG_MAX_MALLOC_64K test is enabled in this case, but this is a
+ * potential breakage point if the types in pngconf.h aren't exactly right.
+ */
+ png_crc_read(png_ptr, buffer, length);
+
+ if (png_crc_finish(png_ptr, skip))
+ return;
+
+ buffer[length] = 0;
+
+ for (entry_start = buffer; *entry_start; entry_start++)
+ /* Empty loop to find end of name */ ;
+
+ ++entry_start;
+
+ /* A sample depth should follow the separator, and we should be on it */
+ if (entry_start > buffer + length - 2)
+ {
+ png_warning(png_ptr, "malformed sPLT chunk");
+ return;
+ }
+
+ new_palette.depth = *entry_start++;
+ entry_size = (new_palette.depth == 8 ? 6 : 10);
+ /* This must fit in a png_uint_32 because it is derived from the original
+ * chunk data length.
+ */
+ data_length = length - (png_uint_32)(entry_start - buffer);
+
+ /* Integrity-check the data length */
+ if (data_length % entry_size)
+ {
+ png_warning(png_ptr, "sPLT chunk has bad length");
+ return;
+ }
+
+ dl = (png_int_32)(data_length / entry_size);
+ max_dl = PNG_SIZE_MAX / (sizeof (png_sPLT_entry));
+
+ if (dl > max_dl)
+ {
+ png_warning(png_ptr, "sPLT chunk too long");
+ return;
+ }
+
+ new_palette.nentries = (png_int_32)(data_length / entry_size);
+
+ new_palette.entries = (png_sPLT_entryp)png_malloc_warn(
+ png_ptr, new_palette.nentries * (sizeof (png_sPLT_entry)));
+
+ if (new_palette.entries == NULL)
+ {
+ png_warning(png_ptr, "sPLT chunk requires too much memory");
+ return;
+ }
+
+#ifdef PNG_POINTER_INDEXING_SUPPORTED
+ for (i = 0; i < new_palette.nentries; i++)
+ {
+ pp = new_palette.entries + i;
+
+ if (new_palette.depth == 8)
+ {
+ pp->red = *entry_start++;
+ pp->green = *entry_start++;
+ pp->blue = *entry_start++;
+ pp->alpha = *entry_start++;
+ }
+
+ else
+ {
+ pp->red = png_get_uint_16(entry_start); entry_start += 2;
+ pp->green = png_get_uint_16(entry_start); entry_start += 2;
+ pp->blue = png_get_uint_16(entry_start); entry_start += 2;
+ pp->alpha = png_get_uint_16(entry_start); entry_start += 2;
+ }
+
+ pp->frequency = png_get_uint_16(entry_start); entry_start += 2;
+ }
+#else
+ pp = new_palette.entries;
+
+ for (i = 0; i < new_palette.nentries; i++)
+ {
+
+ if (new_palette.depth == 8)
+ {
+ pp[i].red = *entry_start++;
+ pp[i].green = *entry_start++;
+ pp[i].blue = *entry_start++;
+ pp[i].alpha = *entry_start++;
+ }
+
+ else
+ {
+ pp[i].red = png_get_uint_16(entry_start); entry_start += 2;
+ pp[i].green = png_get_uint_16(entry_start); entry_start += 2;
+ pp[i].blue = png_get_uint_16(entry_start); entry_start += 2;
+ pp[i].alpha = png_get_uint_16(entry_start); entry_start += 2;
+ }
+
+ pp[i].frequency = png_get_uint_16(entry_start); entry_start += 2;
+ }
+#endif
+
+ /* Discard all chunk data except the name and stash that */
+ new_palette.name = (png_charp)buffer;
+
+ png_set_sPLT(png_ptr, info_ptr, &new_palette, 1);
+
+ png_free(png_ptr, new_palette.entries);
+}
+#endif /* PNG_READ_sPLT_SUPPORTED */
+
+#ifdef PNG_READ_tRNS_SUPPORTED
+void /* PRIVATE */
+png_handle_tRNS(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte readbuf[PNG_MAX_PALETTE_LENGTH];
+
+ png_debug(1, "in png_handle_tRNS");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ png_byte buf[2];
+
+ if (length != 2)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 2);
+ png_ptr->num_trans = 1;
+ png_ptr->trans_color.gray = png_get_uint_16(buf);
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB)
+ {
+ png_byte buf[6];
+
+ if (length != 6)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, length);
+ png_ptr->num_trans = 1;
+ png_ptr->trans_color.red = png_get_uint_16(buf);
+ png_ptr->trans_color.green = png_get_uint_16(buf + 2);
+ png_ptr->trans_color.blue = png_get_uint_16(buf + 4);
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (!(png_ptr->mode & PNG_HAVE_PLTE))
+ {
+ /* TODO: is this actually an error in the ISO spec? */
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ if (length > png_ptr->num_palette || length > PNG_MAX_PALETTE_LENGTH ||
+ length == 0)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, readbuf, length);
+ png_ptr->num_trans = (png_uint_16)length;
+ }
+
+ else
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid with alpha channel");
+ return;
+ }
+
+ if (png_crc_finish(png_ptr, 0))
+ {
+ png_ptr->num_trans = 0;
+ return;
+ }
+
+ /* TODO: this is a horrible side effect in the palette case because the
+ * png_struct ends up with a pointer to the tRNS buffer owned by the
+ * png_info. Fix this.
+ */
+ png_set_tRNS(png_ptr, info_ptr, readbuf, png_ptr->num_trans,
+ &(png_ptr->trans_color));
+}
+#endif
+
+#ifdef PNG_READ_bKGD_SUPPORTED
+void /* PRIVATE */
+png_handle_bKGD(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ unsigned int truelen;
+ png_byte buf[6];
+ png_color_16 background;
+
+ png_debug(1, "in png_handle_bKGD");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if ((png_ptr->mode & PNG_HAVE_IDAT) ||
+ (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE &&
+ !(png_ptr->mode & PNG_HAVE_PLTE)))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ truelen = 1;
+
+ else if (png_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ truelen = 6;
+
+ else
+ truelen = 2;
+
+ if (length != truelen)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, truelen);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ /* We convert the index value into RGB components so that we can allow
+ * arbitrary RGB values for background when we have transparency, and
+ * so it is easy to determine the RGB values of the background color
+ * from the info_ptr struct.
+ */
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ background.index = buf[0];
+
+ if (info_ptr && info_ptr->num_palette)
+ {
+ if (buf[0] >= info_ptr->num_palette)
+ {
+ png_chunk_benign_error(png_ptr, "invalid index");
+ return;
+ }
+
+ background.red = (png_uint_16)png_ptr->palette[buf[0]].red;
+ background.green = (png_uint_16)png_ptr->palette[buf[0]].green;
+ background.blue = (png_uint_16)png_ptr->palette[buf[0]].blue;
+ }
+
+ else
+ background.red = background.green = background.blue = 0;
+
+ background.gray = 0;
+ }
+
+ else if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR)) /* GRAY */
+ {
+ background.index = 0;
+ background.red =
+ background.green =
+ background.blue =
+ background.gray = png_get_uint_16(buf);
+ }
+
+ else
+ {
+ background.index = 0;
+ background.red = png_get_uint_16(buf);
+ background.green = png_get_uint_16(buf + 2);
+ background.blue = png_get_uint_16(buf + 4);
+ background.gray = 0;
+ }
+
+ png_set_bKGD(png_ptr, info_ptr, &background);
+}
+#endif
+
+#ifdef PNG_READ_hIST_SUPPORTED
+void /* PRIVATE */
+png_handle_hIST(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ unsigned int num, i;
+ png_uint_16 readbuf[PNG_MAX_PALETTE_LENGTH];
+
+ png_debug(1, "in png_handle_hIST");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if ((png_ptr->mode & PNG_HAVE_IDAT) || !(png_ptr->mode & PNG_HAVE_PLTE))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ num = length / 2 ;
+
+ if (num != png_ptr->num_palette || num > PNG_MAX_PALETTE_LENGTH)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ for (i = 0; i < num; i++)
+ {
+ png_byte buf[2];
+
+ png_crc_read(png_ptr, buf, 2);
+ readbuf[i] = png_get_uint_16(buf);
+ }
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ png_set_hIST(png_ptr, info_ptr, readbuf);
+}
+#endif
+
+#ifdef PNG_READ_pHYs_SUPPORTED
+void /* PRIVATE */
+png_handle_pHYs(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte buf[9];
+ png_uint_32 res_x, res_y;
+ int unit_type;
+
+ png_debug(1, "in png_handle_pHYs");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (length != 9)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 9);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ res_x = png_get_uint_32(buf);
+ res_y = png_get_uint_32(buf + 4);
+ unit_type = buf[8];
+ png_set_pHYs(png_ptr, info_ptr, res_x, res_y, unit_type);
+}
+#endif
+
+#ifdef PNG_READ_oFFs_SUPPORTED
+void /* PRIVATE */
+png_handle_oFFs(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte buf[9];
+ png_int_32 offset_x, offset_y;
+ int unit_type;
+
+ png_debug(1, "in png_handle_oFFs");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (length != 9)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 9);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ offset_x = png_get_int_32(buf);
+ offset_y = png_get_int_32(buf + 4);
+ unit_type = buf[8];
+ png_set_oFFs(png_ptr, info_ptr, offset_x, offset_y, unit_type);
+}
+#endif
+
+#ifdef PNG_READ_pCAL_SUPPORTED
+/* Read the pCAL chunk (described in the PNG Extensions document) */
+void /* PRIVATE */
+png_handle_pCAL(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_int_32 X0, X1;
+ png_byte type, nparams;
+ png_bytep buffer, buf, units, endptr;
+ png_charpp params;
+ int i;
+
+ png_debug(1, "in png_handle_pCAL");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_pCAL))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ png_debug1(2, "Allocating and reading pCAL chunk data (%u bytes)",
+ length + 1);
+
+ buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/);
+
+ if (buffer == NULL)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+ png_crc_read(png_ptr, buffer, length);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ buffer[length] = 0; /* Null terminate the last string */
+
+ png_debug(3, "Finding end of pCAL purpose string");
+ for (buf = buffer; *buf; buf++)
+ /* Empty loop */ ;
+
+ endptr = buffer + length;
+
+ /* We need to have at least 12 bytes after the purpose string
+ * in order to get the parameter information.
+ */
+ if (endptr <= buf + 12)
+ {
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_debug(3, "Reading pCAL X0, X1, type, nparams, and units");
+ X0 = png_get_int_32((png_bytep)buf+1);
+ X1 = png_get_int_32((png_bytep)buf+5);
+ type = buf[9];
+ nparams = buf[10];
+ units = buf + 11;
+
+ png_debug(3, "Checking pCAL equation type and number of parameters");
+ /* Check that we have the right number of parameters for known
+ * equation types.
+ */
+ if ((type == PNG_EQUATION_LINEAR && nparams != 2) ||
+ (type == PNG_EQUATION_BASE_E && nparams != 3) ||
+ (type == PNG_EQUATION_ARBITRARY && nparams != 3) ||
+ (type == PNG_EQUATION_HYPERBOLIC && nparams != 4))
+ {
+ png_chunk_benign_error(png_ptr, "invalid parameter count");
+ return;
+ }
+
+ else if (type >= PNG_EQUATION_LAST)
+ {
+ png_chunk_benign_error(png_ptr, "unrecognized equation type");
+ }
+
+ for (buf = units; *buf; buf++)
+ /* Empty loop to move past the units string. */ ;
+
+ png_debug(3, "Allocating pCAL parameters array");
+
+ params = png_voidcast(png_charpp, png_malloc_warn(png_ptr,
+ nparams * (sizeof (png_charp))));
+
+ if (params == NULL)
+ {
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+ /* Get pointers to the start of each parameter string. */
+ for (i = 0; i < nparams; i++)
+ {
+ buf++; /* Skip the null string terminator from previous parameter. */
+
+ png_debug1(3, "Reading pCAL parameter %d", i);
+
+ for (params[i] = (png_charp)buf; buf <= endptr && *buf != 0; buf++)
+ /* Empty loop to move past each parameter string */ ;
+
+ /* Make sure we haven't run out of data yet */
+ if (buf > endptr)
+ {
+ png_free(png_ptr, params);
+ png_chunk_benign_error(png_ptr, "invalid data");
+ return;
+ }
+ }
+
+ png_set_pCAL(png_ptr, info_ptr, (png_charp)buffer, X0, X1, type, nparams,
+ (png_charp)units, params);
+
+ png_free(png_ptr, params);
+}
+#endif
+
+#ifdef PNG_READ_sCAL_SUPPORTED
+/* Read the sCAL chunk */
+void /* PRIVATE */
+png_handle_sCAL(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_bytep buffer;
+ png_size_t i;
+ int state;
+
+ png_debug(1, "in png_handle_sCAL");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (png_ptr->mode & PNG_HAVE_IDAT)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of place");
+ return;
+ }
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_sCAL))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ /* Need unit type, width, \0, height: minimum 4 bytes */
+ else if (length < 4)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_debug1(2, "Allocating and reading sCAL chunk data (%u bytes)",
+ length + 1);
+
+ buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/);
+
+ if (buffer == NULL)
+ {
+ png_chunk_benign_error(png_ptr, "out of memory");
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ png_crc_read(png_ptr, buffer, length);
+ buffer[length] = 0; /* Null terminate the last string */
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ /* Validate the unit. */
+ if (buffer[0] != 1 && buffer[0] != 2)
+ {
+ png_chunk_benign_error(png_ptr, "invalid unit");
+ return;
+ }
+
+ /* Validate the ASCII numbers, need two ASCII numbers separated by
+ * a '\0' and they need to fit exactly in the chunk data.
+ */
+ i = 1;
+ state = 0;
+
+ if (!png_check_fp_number((png_const_charp)buffer, length, &state, &i) ||
+ i >= length || buffer[i++] != 0)
+ png_chunk_benign_error(png_ptr, "bad width format");
+
+ else if (!PNG_FP_IS_POSITIVE(state))
+ png_chunk_benign_error(png_ptr, "non-positive width");
+
+ else
+ {
+ png_size_t heighti = i;
+
+ state = 0;
+ if (!png_check_fp_number((png_const_charp)buffer, length, &state, &i) ||
+ i != length)
+ png_chunk_benign_error(png_ptr, "bad height format");
+
+ else if (!PNG_FP_IS_POSITIVE(state))
+ png_chunk_benign_error(png_ptr, "non-positive height");
+
+ else
+ /* This is the (only) success case. */
+ png_set_sCAL_s(png_ptr, info_ptr, buffer[0],
+ (png_charp)buffer+1, (png_charp)buffer+heighti);
+ }
+}
+#endif
+
+#ifdef PNG_READ_tIME_SUPPORTED
+void /* PRIVATE */
+png_handle_tIME(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_byte buf[7];
+ png_time mod_time;
+
+ png_debug(1, "in png_handle_tIME");
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tIME))
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "duplicate");
+ return;
+ }
+
+ if (png_ptr->mode & PNG_HAVE_IDAT)
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+ if (length != 7)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "invalid");
+ return;
+ }
+
+ png_crc_read(png_ptr, buf, 7);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ mod_time.second = buf[6];
+ mod_time.minute = buf[5];
+ mod_time.hour = buf[4];
+ mod_time.day = buf[3];
+ mod_time.month = buf[2];
+ mod_time.year = png_get_uint_16(buf);
+
+ png_set_tIME(png_ptr, info_ptr, &mod_time);
+}
+#endif
+
+#ifdef PNG_READ_tEXt_SUPPORTED
+/* Note: this does not properly handle chunks that are > 64K under DOS */
+void /* PRIVATE */
+png_handle_tEXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_text text_info;
+ png_bytep buffer;
+ png_charp key;
+ png_charp text;
+ png_uint_32 skip = 0;
+
+ png_debug(1, "in png_handle_tEXt");
+
+#ifdef PNG_USER_LIMITS_SUPPORTED
+ if (png_ptr->user_chunk_cache_max != 0)
+ {
+ if (png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ if (--png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "no space in chunk cache");
+ return;
+ }
+ }
+#endif
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ if (png_ptr->mode & PNG_HAVE_IDAT)
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+#ifdef PNG_MAX_MALLOC_64K
+ if (length > 65535U)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "too large to fit in memory");
+ return;
+ }
+#endif
+
+ buffer = png_read_buffer(png_ptr, length+1, 1/*warn*/);
+
+ if (buffer == NULL)
+ {
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+ png_crc_read(png_ptr, buffer, length);
+
+ if (png_crc_finish(png_ptr, skip))
+ return;
+
+ key = (png_charp)buffer;
+ key[length] = 0;
+
+ for (text = key; *text; text++)
+ /* Empty loop to find end of key */ ;
+
+ if (text != key + length)
+ text++;
+
+ text_info.compression = PNG_TEXT_COMPRESSION_NONE;
+ text_info.key = key;
+ text_info.lang = NULL;
+ text_info.lang_key = NULL;
+ text_info.itxt_length = 0;
+ text_info.text = text;
+ text_info.text_length = strlen(text);
+
+ if (png_set_text_2(png_ptr, info_ptr, &text_info, 1))
+ png_warning(png_ptr, "Insufficient memory to process text chunk");
+}
+#endif
+
+#ifdef PNG_READ_zTXt_SUPPORTED
+/* Note: this does not correctly handle chunks that are > 64K under DOS */
+void /* PRIVATE */
+png_handle_zTXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_const_charp errmsg = NULL;
+ png_bytep buffer;
+ png_uint_32 keyword_length;
+
+ png_debug(1, "in png_handle_zTXt");
+
+#ifdef PNG_USER_LIMITS_SUPPORTED
+ if (png_ptr->user_chunk_cache_max != 0)
+ {
+ if (png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ if (--png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "no space in chunk cache");
+ return;
+ }
+ }
+#endif
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ if (png_ptr->mode & PNG_HAVE_IDAT)
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+ buffer = png_read_buffer(png_ptr, length, 2/*silent*/);
+
+ if (buffer == NULL)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+ png_crc_read(png_ptr, buffer, length);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ /* TODO: also check that the keyword contents match the spec! */
+ for (keyword_length = 0;
+ keyword_length < length && buffer[keyword_length] != 0;
+ ++keyword_length)
+ /* Empty loop to find end of name */ ;
+
+ if (keyword_length > 79 || keyword_length < 1)
+ errmsg = "bad keyword";
+
+ /* zTXt must have some LZ data after the keyword, although it may expand to
+ * zero bytes; we need a '\0' at the end of the keyword, the compression type
+ * then the LZ data:
+ */
+ else if (keyword_length + 3 > length)
+ errmsg = "truncated";
+
+ else if (buffer[keyword_length+1] != PNG_COMPRESSION_TYPE_BASE)
+ errmsg = "unknown compression type";
+
+ else
+ {
+ png_alloc_size_t uncompressed_length = PNG_SIZE_MAX;
+
+ /* TODO: at present png_decompress_chunk imposes a single application
+ * level memory limit, this should be split to different values for iCCP
+ * and text chunks.
+ */
+ if (png_decompress_chunk(png_ptr, length, keyword_length+2,
+ &uncompressed_length, 1/*terminate*/) == Z_STREAM_END)
+ {
+ png_text text;
+
+ /* It worked; png_ptr->read_buffer now looks like a tEXt chunk except
+ * for the extra compression type byte and the fact that it isn't
+ * necessarily '\0' terminated.
+ */
+ buffer = png_ptr->read_buffer;
+ buffer[uncompressed_length+(keyword_length+2)] = 0;
+
+ text.compression = PNG_TEXT_COMPRESSION_zTXt;
+ text.key = (png_charp)buffer;
+ text.text = (png_charp)(buffer + keyword_length+2);
+ text.text_length = uncompressed_length;
+ text.itxt_length = 0;
+ text.lang = NULL;
+ text.lang_key = NULL;
+
+ if (png_set_text_2(png_ptr, info_ptr, &text, 1))
+ errmsg = "insufficient memory";
+ }
+
+ else
+ errmsg = png_ptr->zstream.msg;
+ }
+
+ if (errmsg != NULL)
+ png_chunk_benign_error(png_ptr, errmsg);
+}
+#endif
+
+#ifdef PNG_READ_iTXt_SUPPORTED
+/* Note: this does not correctly handle chunks that are > 64K under DOS */
+void /* PRIVATE */
+png_handle_iTXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length)
+{
+ png_const_charp errmsg = NULL;
+ png_bytep buffer;
+ png_uint_32 prefix_length;
+
+ png_debug(1, "in png_handle_iTXt");
+
+#ifdef PNG_USER_LIMITS_SUPPORTED
+ if (png_ptr->user_chunk_cache_max != 0)
+ {
+ if (png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ return;
+ }
+
+ if (--png_ptr->user_chunk_cache_max == 1)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "no space in chunk cache");
+ return;
+ }
+ }
+#endif
+
+ if (!(png_ptr->mode & PNG_HAVE_IHDR))
+ png_chunk_error(png_ptr, "missing IHDR");
+
+ if (png_ptr->mode & PNG_HAVE_IDAT)
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+ buffer = png_read_buffer(png_ptr, length+1, 1/*warn*/);
+
+ if (buffer == NULL)
+ {
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "out of memory");
+ return;
+ }
+
+ png_crc_read(png_ptr, buffer, length);
+
+ if (png_crc_finish(png_ptr, 0))
+ return;
+
+ /* First the keyword. */
+ for (prefix_length=0;
+ prefix_length < length && buffer[prefix_length] != 0;
+ ++prefix_length)
+ /* Empty loop */ ;
+
+ /* Perform a basic check on the keyword length here. */
+ if (prefix_length > 79 || prefix_length < 1)
+ errmsg = "bad keyword";
+
+ /* Expect keyword, compression flag, compression type, language, translated
+ * keyword (both may be empty but are 0 terminated) then the text, which may
+ * be empty.
+ */
+ else if (prefix_length + 5 > length)
+ errmsg = "truncated";
+
+ else if (buffer[prefix_length+1] == 0 ||
+ (buffer[prefix_length+1] == 1 &&
+ buffer[prefix_length+2] == PNG_COMPRESSION_TYPE_BASE))
+ {
+ int compressed = buffer[prefix_length+1] != 0;
+ png_uint_32 language_offset, translated_keyword_offset;
+ png_alloc_size_t uncompressed_length = 0;
+
+ /* Now the language tag */
+ prefix_length += 3;
+ language_offset = prefix_length;
+
+ for (; prefix_length < length && buffer[prefix_length] != 0;
+ ++prefix_length)
+ /* Empty loop */ ;
+
+ /* WARNING: the length may be invalid here, this is checked below. */
+ translated_keyword_offset = ++prefix_length;
+
+ for (; prefix_length < length && buffer[prefix_length] != 0;
+ ++prefix_length)
+ /* Empty loop */ ;
+
+ /* prefix_length should now be at the trailing '\0' of the translated
+ * keyword, but it may already be over the end. None of this arithmetic
+ * can overflow because chunks are at most 2^31 bytes long, but on 16-bit
+ * systems the available allocaton may overflow.
+ */
+ ++prefix_length;
+
+ if (!compressed && prefix_length <= length)
+ uncompressed_length = length - prefix_length;
+
+ else if (compressed && prefix_length < length)
+ {
+ uncompressed_length = PNG_SIZE_MAX;
+
+ /* TODO: at present png_decompress_chunk imposes a single application
+ * level memory limit, this should be split to different values for
+ * iCCP and text chunks.
+ */
+ if (png_decompress_chunk(png_ptr, length, prefix_length,
+ &uncompressed_length, 1/*terminate*/) == Z_STREAM_END)
+ buffer = png_ptr->read_buffer;
+
+ else
+ errmsg = png_ptr->zstream.msg;
+ }
+
+ else
+ errmsg = "truncated";
+
+ if (errmsg == NULL)
+ {
+ png_text text;
+
+ buffer[uncompressed_length+prefix_length] = 0;
+
+ if (compressed)
+ text.compression = PNG_ITXT_COMPRESSION_NONE;
+
+ else
+ text.compression = PNG_ITXT_COMPRESSION_zTXt;
+
+ text.key = (png_charp)buffer;
+ text.lang = (png_charp)buffer + language_offset;
+ text.lang_key = (png_charp)buffer + translated_keyword_offset;
+ text.text = (png_charp)buffer + prefix_length;
+ text.text_length = 0;
+ text.itxt_length = uncompressed_length;
+
+ if (png_set_text_2(png_ptr, info_ptr, &text, 1))
+ errmsg = "insufficient memory";
+ }
+ }
+
+ else
+ errmsg = "bad compression info";
+
+ if (errmsg != NULL)
+ png_chunk_benign_error(png_ptr, errmsg);
+}
+#endif
+
+#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED
+/* Utility function for png_handle_unknown; set up png_ptr::unknown_chunk */
+static int
+png_cache_unknown_chunk(png_structrp png_ptr, png_uint_32 length)
+{
+ png_alloc_size_t limit = PNG_SIZE_MAX;
+
+ if (png_ptr->unknown_chunk.data != NULL)
+ {
+ png_free(png_ptr, png_ptr->unknown_chunk.data);
+ png_ptr->unknown_chunk.data = NULL;
+ }
+
+# ifdef PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED
+ if (png_ptr->user_chunk_malloc_max > 0 &&
+ png_ptr->user_chunk_malloc_max < limit)
+ limit = png_ptr->user_chunk_malloc_max;
+
+# elif PNG_USER_CHUNK_MALLOC_MAX > 0
+ if (PNG_USER_CHUNK_MALLOC_MAX < limit)
+ limit = PNG_USER_CHUNK_MALLOC_MAX;
+# endif
+
+ if (length <= limit)
+ {
+ PNG_CSTRING_FROM_CHUNK(png_ptr->unknown_chunk.name, png_ptr->chunk_name);
+ /* The following is safe because of the PNG_SIZE_MAX init above */
+ png_ptr->unknown_chunk.size = (png_size_t)length/*SAFE*/;
+ /* 'mode' is a flag array, only the bottom four bits matter here */
+ png_ptr->unknown_chunk.location = (png_byte)png_ptr->mode/*SAFE*/;
+
+ if (length == 0)
+ png_ptr->unknown_chunk.data = NULL;
+
+ else
+ {
+ /* Do a 'warn' here - it is handled below. */
+ png_ptr->unknown_chunk.data = png_voidcast(png_bytep,
+ png_malloc_warn(png_ptr, length));
+ }
+ }
+
+ if (png_ptr->unknown_chunk.data == NULL && length > 0)
+ {
+ /* This is benign because we clean up correctly */
+ png_crc_finish(png_ptr, length);
+ png_chunk_benign_error(png_ptr, "unknown chunk exceeds memory limits");
+ return 0;
+ }
+
+ else
+ {
+ if (length > 0)
+ png_crc_read(png_ptr, png_ptr->unknown_chunk.data, length);
+ png_crc_finish(png_ptr, 0);
+ return 1;
+ }
+}
+#endif /* PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */
+
+/* Handle an unknown, or known but disabled, chunk */
+void /* PRIVATE */
+png_handle_unknown(png_structrp png_ptr, png_inforp info_ptr,
+ png_uint_32 length, int keep)
+{
+ int handled = 0; /* the chunk was handled */
+
+ png_debug(1, "in png_handle_unknown");
+
+#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED
+ /* NOTE: this code is based on the code in libpng-1.4.12 except for fixing
+ * the bug which meant that setting a non-default behavior for a specific
+ * chunk would be ignored (the default was always used unless a user
+ * callback was installed).
+ *
+ * 'keep' is the value from the png_chunk_unknown_handling, the setting for
+ * this specific chunk_name, if PNG_HANDLE_AS_UNKNOWN_SUPPORTED, if not it
+ * will always be PNG_HANDLE_CHUNK_AS_DEFAULT and it needs to be set here.
+ * This is just an optimization to avoid multiple calls to the lookup
+ * function.
+ */
+# ifndef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+# ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ keep = png_chunk_unknown_handling(png_ptr, png_ptr->chunk_name);
+# endif
+# endif
+
+ /* One of the following methods will read the chunk or skip it (at least one
+ * of these is always defined because this is the only way to switch on
+ * PNG_READ_UNKNOWN_CHUNKS_SUPPORTED)
+ */
+# ifdef PNG_READ_USER_CHUNKS_SUPPORTED
+ /* The user callback takes precedence over the chunk keep value, but the
+ * keep value is still required to validate a save of a critical chunk.
+ */
+ if (png_ptr->read_user_chunk_fn != NULL)
+ {
+ if (png_cache_unknown_chunk(png_ptr, length))
+ {
+ /* Callback to user unknown chunk handler */
+ int ret = (*(png_ptr->read_user_chunk_fn))(png_ptr,
+ &png_ptr->unknown_chunk);
+
+ /* ret is:
+ * negative: An error occured, png_chunk_error will be called.
+ * zero: The chunk was not handled, the chunk will be discarded
+ * unless png_set_keep_unknown_chunks has been used to set
+ * a 'keep' behavior for this particular chunk, in which
+ * case that will be used. A critical chunk will cause an
+ * error at this point unless it is to be saved.
+ * positive: The chunk was handled, libpng will ignore/discard it.
+ */
+ if (ret < 0)
+ png_chunk_error(png_ptr, "error in user chunk");
+
+ else if (ret == 0)
+ {
+ /* If the keep value is 'default' or 'never' override it, but
+ * still error out on critical chunks unless the keep value is
+ * 'always' While this is weird it is the behavior in 1.4.12.
+ * A possible improvement would be to obey the value set for the
+ * chunk, but this would be an API change that would probably
+ * damage some applications.
+ *
+ * The png_app_warning below catches the case that matters, where
+ * the application has not set specific save or ignore for this
+ * chunk or global save or ignore.
+ */
+ if (keep < PNG_HANDLE_CHUNK_IF_SAFE)
+ {
+# ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ if (png_ptr->unknown_default < PNG_HANDLE_CHUNK_IF_SAFE)
+ {
+ png_chunk_warning(png_ptr, "Saving unknown chunk:");
+ png_app_warning(png_ptr,
+ "forcing save of an unhandled chunk;"
+ " please call png_set_keep_unknown_chunks");
+ /* with keep = PNG_HANDLE_CHUNK_IF_SAFE */
+ }
+# endif
+ keep = PNG_HANDLE_CHUNK_IF_SAFE;
+ }
+ }
+
+ else /* chunk was handled */
+ {
+ handled = 1;
+ /* Critical chunks can be safely discarded at this point. */
+ keep = PNG_HANDLE_CHUNK_NEVER;
+ }
+ }
+
+ else
+ keep = PNG_HANDLE_CHUNK_NEVER; /* insufficient memory */
+ }
+
+ else
+ /* Use the SAVE_UNKNOWN_CHUNKS code or skip the chunk */
+# endif /* PNG_READ_USER_CHUNKS_SUPPORTED */
+
+# ifdef PNG_SAVE_UNKNOWN_CHUNKS_SUPPORTED
+ {
+ /* keep is currently just the per-chunk setting, if there was no
+ * setting change it to the global default now (not that this may
+ * still be AS_DEFAULT) then obtain the cache of the chunk if required,
+ * if not simply skip the chunk.
+ */
+ if (keep == PNG_HANDLE_CHUNK_AS_DEFAULT)
+ keep = png_ptr->unknown_default;
+
+ if (keep == PNG_HANDLE_CHUNK_ALWAYS ||
+ (keep == PNG_HANDLE_CHUNK_IF_SAFE &&
+ PNG_CHUNK_ANCILLARY(png_ptr->chunk_name)))
+ {
+ if (!png_cache_unknown_chunk(png_ptr, length))
+ keep = PNG_HANDLE_CHUNK_NEVER;
+ }
+
+ else
+ png_crc_finish(png_ptr, length);
+ }
+# else
+# ifndef PNG_READ_USER_CHUNKS_SUPPORTED
+# error no method to support READ_UNKNOWN_CHUNKS
+# endif
+
+ {
+ /* If here there is no read callback pointer set and no support is
+ * compiled in to just save the unknown chunks, so simply skip this
+ * chunk. If 'keep' is something other than AS_DEFAULT or NEVER then
+ * the app has erroneously asked for unknown chunk saving when there
+ * is no support.
+ */
+ if (keep > PNG_HANDLE_CHUNK_NEVER)
+ png_app_error(png_ptr, "no unknown chunk support available");
+
+ png_crc_finish(png_ptr, length);
+ }
+# endif
+
+# ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+ /* Now store the chunk in the chunk list if appropriate, and if the limits
+ * permit it.
+ */
+ if (keep == PNG_HANDLE_CHUNK_ALWAYS ||
+ (keep == PNG_HANDLE_CHUNK_IF_SAFE &&
+ PNG_CHUNK_ANCILLARY(png_ptr->chunk_name)))
+ {
+# ifdef PNG_USER_LIMITS_SUPPORTED
+ switch (png_ptr->user_chunk_cache_max)
+ {
+ case 2:
+ png_ptr->user_chunk_cache_max = 1;
+ png_chunk_benign_error(png_ptr, "no space in chunk cache");
+ /* FALL THROUGH */
+ case 1:
+ /* NOTE: prior to 1.6.0 this case resulted in an unknown critical
+ * chunk being skipped, now there will be a hard error below.
+ */
+ break;
+
+ default: /* not at limit */
+ --(png_ptr->user_chunk_cache_max);
+ /* FALL THROUGH */
+ case 0: /* no limit */
+# endif /* PNG_USER_LIMITS_SUPPORTED */
+ /* Here when the limit isn't reached or when limits are compiled
+ * out; store the chunk.
+ */
+ png_set_unknown_chunks(png_ptr, info_ptr,
+ &png_ptr->unknown_chunk, 1);
+ handled = 1;
+# ifdef PNG_USER_LIMITS_SUPPORTED
+ break;
+ }
+# endif
+ }
+# else /* no store support: the chunk must be handled by the user callback */
+ PNG_UNUSED(info_ptr)
+# endif
+
+ /* Regardless of the error handling below the cached data (if any) can be
+ * freed now. Notice that the data is not freed if there is a png_error, but
+ * it will be freed by destroy_read_struct.
+ */
+ if (png_ptr->unknown_chunk.data != NULL)
+ png_free(png_ptr, png_ptr->unknown_chunk.data);
+ png_ptr->unknown_chunk.data = NULL;
+
+#else /* !PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */
+ /* There is no support to read an unknown chunk, so just skip it. */
+ png_crc_finish(png_ptr, length);
+ PNG_UNUSED(info_ptr)
+ PNG_UNUSED(keep)
+#endif /* !PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */
+
+ /* Check for unhandled critical chunks */
+ if (!handled && PNG_CHUNK_CRITICAL(png_ptr->chunk_name))
+ png_chunk_error(png_ptr, "unhandled critical chunk");
+}
+
+/* This function is called to verify that a chunk name is valid.
+ * This function can't have the "critical chunk check" incorporated
+ * into it, since in the future we will need to be able to call user
+ * functions to handle unknown critical chunks after we check that
+ * the chunk name itself is valid.
+ */
+
+/* Bit hacking: the test for an invalid byte in the 4 byte chunk name is:
+ *
+ * ((c) < 65 || (c) > 122 || ((c) > 90 && (c) < 97))
+ */
+
+void /* PRIVATE */
+png_check_chunk_name(png_structrp png_ptr, png_uint_32 chunk_name)
+{
+ int i;
+
+ png_debug(1, "in png_check_chunk_name");
+
+ for (i=1; i<=4; ++i)
+ {
+ int c = chunk_name & 0xff;
+
+ if (c < 65 || c > 122 || (c > 90 && c < 97))
+ png_chunk_error(png_ptr, "invalid chunk type");
+
+ chunk_name >>= 8;
+ }
+}
+
+/* Combines the row recently read in with the existing pixels in the row. This
+ * routine takes care of alpha and transparency if requested. This routine also
+ * handles the two methods of progressive display of interlaced images,
+ * depending on the 'display' value; if 'display' is true then the whole row
+ * (dp) is filled from the start by replicating the available pixels. If
+ * 'display' is false only those pixels present in the pass are filled in.
+ */
+void /* PRIVATE */
+png_combine_row(png_const_structrp png_ptr, png_bytep dp, int display)
+{
+ unsigned int pixel_depth = png_ptr->transformed_pixel_depth;
+ png_const_bytep sp = png_ptr->row_buf + 1;
+ png_uint_32 row_width = png_ptr->width;
+ unsigned int pass = png_ptr->pass;
+ png_bytep end_ptr = 0;
+ png_byte end_byte = 0;
+ unsigned int end_mask;
+
+ png_debug(1, "in png_combine_row");
+
+ /* Added in 1.5.6: it should not be possible to enter this routine until at
+ * least one row has been read from the PNG data and transformed.
+ */
+ if (pixel_depth == 0)
+ png_error(png_ptr, "internal row logic error");
+
+ /* Added in 1.5.4: the pixel depth should match the information returned by
+ * any call to png_read_update_info at this point. Do not continue if we got
+ * this wrong.
+ */
+ if (png_ptr->info_rowbytes != 0 && png_ptr->info_rowbytes !=
+ PNG_ROWBYTES(pixel_depth, row_width))
+ png_error(png_ptr, "internal row size calculation error");
+
+ /* Don't expect this to ever happen: */
+ if (row_width == 0)
+ png_error(png_ptr, "internal row width error");
+
+ /* Preserve the last byte in cases where only part of it will be overwritten,
+ * the multiply below may overflow, we don't care because ANSI-C guarantees
+ * we get the low bits.
+ */
+ end_mask = (pixel_depth * row_width) & 7;
+ if (end_mask != 0)
+ {
+ /* end_ptr == NULL is a flag to say do nothing */
+ end_ptr = dp + PNG_ROWBYTES(pixel_depth, row_width) - 1;
+ end_byte = *end_ptr;
+# ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_PACKSWAP) /* little-endian byte */
+ end_mask = 0xff << end_mask;
+
+ else /* big-endian byte */
+# endif
+ end_mask = 0xff >> end_mask;
+ /* end_mask is now the bits to *keep* from the destination row */
+ }
+
+ /* For non-interlaced images this reduces to a memcpy(). A memcpy()
+ * will also happen if interlacing isn't supported or if the application
+ * does not call png_set_interlace_handling(). In the latter cases the
+ * caller just gets a sequence of the unexpanded rows from each interlace
+ * pass.
+ */
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE) &&
+ pass < 6 && (display == 0 ||
+ /* The following copies everything for 'display' on passes 0, 2 and 4. */
+ (display == 1 && (pass & 1) != 0)))
+ {
+ /* Narrow images may have no bits in a pass; the caller should handle
+ * this, but this test is cheap:
+ */
+ if (row_width <= PNG_PASS_START_COL(pass))
+ return;
+
+ if (pixel_depth < 8)
+ {
+ /* For pixel depths up to 4 bpp the 8-pixel mask can be expanded to fit
+ * into 32 bits, then a single loop over the bytes using the four byte
+ * values in the 32-bit mask can be used. For the 'display' option the
+ * expanded mask may also not require any masking within a byte. To
+ * make this work the PACKSWAP option must be taken into account - it
+ * simply requires the pixels to be reversed in each byte.
+ *
+ * The 'regular' case requires a mask for each of the first 6 passes,
+ * the 'display' case does a copy for the even passes in the range
+ * 0..6. This has already been handled in the test above.
+ *
+ * The masks are arranged as four bytes with the first byte to use in
+ * the lowest bits (little-endian) regardless of the order (PACKSWAP or
+ * not) of the pixels in each byte.
+ *
+ * NOTE: the whole of this logic depends on the caller of this function
+ * only calling it on rows appropriate to the pass. This function only
+ * understands the 'x' logic; the 'y' logic is handled by the caller.
+ *
+ * The following defines allow generation of compile time constant bit
+ * masks for each pixel depth and each possibility of swapped or not
+ * swapped bytes. Pass 'p' is in the range 0..6; 'x', a pixel index,
+ * is in the range 0..7; and the result is 1 if the pixel is to be
+ * copied in the pass, 0 if not. 'S' is for the sparkle method, 'B'
+ * for the block method.
+ *
+ * With some compilers a compile time expression of the general form:
+ *
+ * (shift >= 32) ? (a >> (shift-32)) : (b >> shift)
+ *
+ * Produces warnings with values of 'shift' in the range 33 to 63
+ * because the right hand side of the ?: expression is evaluated by
+ * the compiler even though it isn't used. Microsoft Visual C (various
+ * versions) and the Intel C compiler are known to do this. To avoid
+ * this the following macros are used in 1.5.6. This is a temporary
+ * solution to avoid destabilizing the code during the release process.
+ */
+# if PNG_USE_COMPILE_TIME_MASKS
+# define PNG_LSR(x,s) ((x)>>((s) & 0x1f))
+# define PNG_LSL(x,s) ((x)<<((s) & 0x1f))
+# else
+# define PNG_LSR(x,s) ((x)>>(s))
+# define PNG_LSL(x,s) ((x)<<(s))
+# endif
+# define S_COPY(p,x) (((p)<4 ? PNG_LSR(0x80088822,(3-(p))*8+(7-(x))) :\
+ PNG_LSR(0xaa55ff00,(7-(p))*8+(7-(x)))) & 1)
+# define B_COPY(p,x) (((p)<4 ? PNG_LSR(0xff0fff33,(3-(p))*8+(7-(x))) :\
+ PNG_LSR(0xff55ff00,(7-(p))*8+(7-(x)))) & 1)
+
+ /* Return a mask for pass 'p' pixel 'x' at depth 'd'. The mask is
+ * little endian - the first pixel is at bit 0 - however the extra
+ * parameter 's' can be set to cause the mask position to be swapped
+ * within each byte, to match the PNG format. This is done by XOR of
+ * the shift with 7, 6 or 4 for bit depths 1, 2 and 4.
+ */
+# define PIXEL_MASK(p,x,d,s) \
+ (PNG_LSL(((PNG_LSL(1U,(d)))-1),(((x)*(d))^((s)?8-(d):0))))
+
+ /* Hence generate the appropriate 'block' or 'sparkle' pixel copy mask.
+ */
+# define S_MASKx(p,x,d,s) (S_COPY(p,x)?PIXEL_MASK(p,x,d,s):0)
+# define B_MASKx(p,x,d,s) (B_COPY(p,x)?PIXEL_MASK(p,x,d,s):0)
+
+ /* Combine 8 of these to get the full mask. For the 1-bpp and 2-bpp
+ * cases the result needs replicating, for the 4-bpp case the above
+ * generates a full 32 bits.
+ */
+# define MASK_EXPAND(m,d) ((m)*((d)==1?0x01010101:((d)==2?0x00010001:1)))
+
+# define S_MASK(p,d,s) MASK_EXPAND(S_MASKx(p,0,d,s) + S_MASKx(p,1,d,s) +\
+ S_MASKx(p,2,d,s) + S_MASKx(p,3,d,s) + S_MASKx(p,4,d,s) +\
+ S_MASKx(p,5,d,s) + S_MASKx(p,6,d,s) + S_MASKx(p,7,d,s), d)
+
+# define B_MASK(p,d,s) MASK_EXPAND(B_MASKx(p,0,d,s) + B_MASKx(p,1,d,s) +\
+ B_MASKx(p,2,d,s) + B_MASKx(p,3,d,s) + B_MASKx(p,4,d,s) +\
+ B_MASKx(p,5,d,s) + B_MASKx(p,6,d,s) + B_MASKx(p,7,d,s), d)
+
+#if PNG_USE_COMPILE_TIME_MASKS
+ /* Utility macros to construct all the masks for a depth/swap
+ * combination. The 's' parameter says whether the format is PNG
+ * (big endian bytes) or not. Only the three odd-numbered passes are
+ * required for the display/block algorithm.
+ */
+# define S_MASKS(d,s) { S_MASK(0,d,s), S_MASK(1,d,s), S_MASK(2,d,s),\
+ S_MASK(3,d,s), S_MASK(4,d,s), S_MASK(5,d,s) }
+
+# define B_MASKS(d,s) { B_MASK(1,d,s), S_MASK(3,d,s), S_MASK(5,d,s) }
+
+# define DEPTH_INDEX(d) ((d)==1?0:((d)==2?1:2))
+
+ /* Hence the pre-compiled masks indexed by PACKSWAP (or not), depth and
+ * then pass:
+ */
+ static PNG_CONST png_uint_32 row_mask[2/*PACKSWAP*/][3/*depth*/][6] =
+ {
+ /* Little-endian byte masks for PACKSWAP */
+ { S_MASKS(1,0), S_MASKS(2,0), S_MASKS(4,0) },
+ /* Normal (big-endian byte) masks - PNG format */
+ { S_MASKS(1,1), S_MASKS(2,1), S_MASKS(4,1) }
+ };
+
+ /* display_mask has only three entries for the odd passes, so index by
+ * pass>>1.
+ */
+ static PNG_CONST png_uint_32 display_mask[2][3][3] =
+ {
+ /* Little-endian byte masks for PACKSWAP */
+ { B_MASKS(1,0), B_MASKS(2,0), B_MASKS(4,0) },
+ /* Normal (big-endian byte) masks - PNG format */
+ { B_MASKS(1,1), B_MASKS(2,1), B_MASKS(4,1) }
+ };
+
+# define MASK(pass,depth,display,png)\
+ ((display)?display_mask[png][DEPTH_INDEX(depth)][pass>>1]:\
+ row_mask[png][DEPTH_INDEX(depth)][pass])
+
+#else /* !PNG_USE_COMPILE_TIME_MASKS */
+ /* This is the runtime alternative: it seems unlikely that this will
+ * ever be either smaller or faster than the compile time approach.
+ */
+# define MASK(pass,depth,display,png)\
+ ((display)?B_MASK(pass,depth,png):S_MASK(pass,depth,png))
+#endif /* !PNG_USE_COMPILE_TIME_MASKS */
+
+ /* Use the appropriate mask to copy the required bits. In some cases
+ * the byte mask will be 0 or 0xff, optimize these cases. row_width is
+ * the number of pixels, but the code copies bytes, so it is necessary
+ * to special case the end.
+ */
+ png_uint_32 pixels_per_byte = 8 / pixel_depth;
+ png_uint_32 mask;
+
+# ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_PACKSWAP)
+ mask = MASK(pass, pixel_depth, display, 0);
+
+ else
+# endif
+ mask = MASK(pass, pixel_depth, display, 1);
+
+ for (;;)
+ {
+ png_uint_32 m;
+
+ /* It doesn't matter in the following if png_uint_32 has more than
+ * 32 bits because the high bits always match those in m<<24; it is,
+ * however, essential to use OR here, not +, because of this.
+ */
+ m = mask;
+ mask = (m >> 8) | (m << 24); /* rotate right to good compilers */
+ m &= 0xff;
+
+ if (m != 0) /* something to copy */
+ {
+ if (m != 0xff)
+ *dp = (png_byte)((*dp & ~m) | (*sp & m));
+ else
+ *dp = *sp;
+ }
+
+ /* NOTE: this may overwrite the last byte with garbage if the image
+ * is not an exact number of bytes wide; libpng has always done
+ * this.
+ */
+ if (row_width <= pixels_per_byte)
+ break; /* May need to restore part of the last byte */
+
+ row_width -= pixels_per_byte;
+ ++dp;
+ ++sp;
+ }
+ }
+
+ else /* pixel_depth >= 8 */
+ {
+ unsigned int bytes_to_copy, bytes_to_jump;
+
+ /* Validate the depth - it must be a multiple of 8 */
+ if (pixel_depth & 7)
+ png_error(png_ptr, "invalid user transform pixel depth");
+
+ pixel_depth >>= 3; /* now in bytes */
+ row_width *= pixel_depth;
+
+ /* Regardless of pass number the Adam 7 interlace always results in a
+ * fixed number of pixels to copy then to skip. There may be a
+ * different number of pixels to skip at the start though.
+ */
+ {
+ unsigned int offset = PNG_PASS_START_COL(pass) * pixel_depth;
+
+ row_width -= offset;
+ dp += offset;
+ sp += offset;
+ }
+
+ /* Work out the bytes to copy. */
+ if (display)
+ {
+ /* When doing the 'block' algorithm the pixel in the pass gets
+ * replicated to adjacent pixels. This is why the even (0,2,4,6)
+ * passes are skipped above - the entire expanded row is copied.
+ */
+ bytes_to_copy = (1<<((6-pass)>>1)) * pixel_depth;
+
+ /* But don't allow this number to exceed the actual row width. */
+ if (bytes_to_copy > row_width)
+ bytes_to_copy = row_width;
+ }
+
+ else /* normal row; Adam7 only ever gives us one pixel to copy. */
+ bytes_to_copy = pixel_depth;
+
+ /* In Adam7 there is a constant offset between where the pixels go. */
+ bytes_to_jump = PNG_PASS_COL_OFFSET(pass) * pixel_depth;
+
+ /* And simply copy these bytes. Some optimization is possible here,
+ * depending on the value of 'bytes_to_copy'. Special case the low
+ * byte counts, which we know to be frequent.
+ *
+ * Notice that these cases all 'return' rather than 'break' - this
+ * avoids an unnecessary test on whether to restore the last byte
+ * below.
+ */
+ switch (bytes_to_copy)
+ {
+ case 1:
+ for (;;)
+ {
+ *dp = *sp;
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ dp += bytes_to_jump;
+ sp += bytes_to_jump;
+ row_width -= bytes_to_jump;
+ }
+
+ case 2:
+ /* There is a possibility of a partial copy at the end here; this
+ * slows the code down somewhat.
+ */
+ do
+ {
+ dp[0] = sp[0], dp[1] = sp[1];
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ sp += bytes_to_jump;
+ dp += bytes_to_jump;
+ row_width -= bytes_to_jump;
+ }
+ while (row_width > 1);
+
+ /* And there can only be one byte left at this point: */
+ *dp = *sp;
+ return;
+
+ case 3:
+ /* This can only be the RGB case, so each copy is exactly one
+ * pixel and it is not necessary to check for a partial copy.
+ */
+ for(;;)
+ {
+ dp[0] = sp[0], dp[1] = sp[1], dp[2] = sp[2];
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ sp += bytes_to_jump;
+ dp += bytes_to_jump;
+ row_width -= bytes_to_jump;
+ }
+
+ default:
+#if PNG_ALIGN_TYPE != PNG_ALIGN_NONE
+ /* Check for double byte alignment and, if possible, use a
+ * 16-bit copy. Don't attempt this for narrow images - ones that
+ * are less than an interlace panel wide. Don't attempt it for
+ * wide bytes_to_copy either - use the memcpy there.
+ */
+ if (bytes_to_copy < 16 /*else use memcpy*/ &&
+ png_isaligned(dp, png_uint_16) &&
+ png_isaligned(sp, png_uint_16) &&
+ bytes_to_copy % (sizeof (png_uint_16)) == 0 &&
+ bytes_to_jump % (sizeof (png_uint_16)) == 0)
+ {
+ /* Everything is aligned for png_uint_16 copies, but try for
+ * png_uint_32 first.
+ */
+ if (png_isaligned(dp, png_uint_32) &&
+ png_isaligned(sp, png_uint_32) &&
+ bytes_to_copy % (sizeof (png_uint_32)) == 0 &&
+ bytes_to_jump % (sizeof (png_uint_32)) == 0)
+ {
+ png_uint_32p dp32 = png_aligncast(png_uint_32p,dp);
+ png_const_uint_32p sp32 = png_aligncastconst(
+ png_const_uint_32p, sp);
+ size_t skip = (bytes_to_jump-bytes_to_copy) /
+ (sizeof (png_uint_32));
+
+ do
+ {
+ size_t c = bytes_to_copy;
+ do
+ {
+ *dp32++ = *sp32++;
+ c -= (sizeof (png_uint_32));
+ }
+ while (c > 0);
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ dp32 += skip;
+ sp32 += skip;
+ row_width -= bytes_to_jump;
+ }
+ while (bytes_to_copy <= row_width);
+
+ /* Get to here when the row_width truncates the final copy.
+ * There will be 1-3 bytes left to copy, so don't try the
+ * 16-bit loop below.
+ */
+ dp = (png_bytep)dp32;
+ sp = (png_const_bytep)sp32;
+ do
+ *dp++ = *sp++;
+ while (--row_width > 0);
+ return;
+ }
+
+ /* Else do it in 16-bit quantities, but only if the size is
+ * not too large.
+ */
+ else
+ {
+ png_uint_16p dp16 = png_aligncast(png_uint_16p, dp);
+ png_const_uint_16p sp16 = png_aligncastconst(
+ png_const_uint_16p, sp);
+ size_t skip = (bytes_to_jump-bytes_to_copy) /
+ (sizeof (png_uint_16));
+
+ do
+ {
+ size_t c = bytes_to_copy;
+ do
+ {
+ *dp16++ = *sp16++;
+ c -= (sizeof (png_uint_16));
+ }
+ while (c > 0);
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ dp16 += skip;
+ sp16 += skip;
+ row_width -= bytes_to_jump;
+ }
+ while (bytes_to_copy <= row_width);
+
+ /* End of row - 1 byte left, bytes_to_copy > row_width: */
+ dp = (png_bytep)dp16;
+ sp = (png_const_bytep)sp16;
+ do
+ *dp++ = *sp++;
+ while (--row_width > 0);
+ return;
+ }
+ }
+#endif /* PNG_ALIGN_ code */
+
+ /* The true default - use a memcpy: */
+ for (;;)
+ {
+ memcpy(dp, sp, bytes_to_copy);
+
+ if (row_width <= bytes_to_jump)
+ return;
+
+ sp += bytes_to_jump;
+ dp += bytes_to_jump;
+ row_width -= bytes_to_jump;
+ if (bytes_to_copy > row_width)
+ bytes_to_copy = row_width;
+ }
+ }
+
+ /* NOT REACHED*/
+ } /* pixel_depth >= 8 */
+
+ /* Here if pixel_depth < 8 to check 'end_ptr' below. */
+ }
+ else
+#endif
+
+ /* If here then the switch above wasn't used so just memcpy the whole row
+ * from the temporary row buffer (notice that this overwrites the end of the
+ * destination row if it is a partial byte.)
+ */
+ memcpy(dp, sp, PNG_ROWBYTES(pixel_depth, row_width));
+
+ /* Restore the overwritten bits from the last byte if necessary. */
+ if (end_ptr != NULL)
+ *end_ptr = (png_byte)((end_byte & end_mask) | (*end_ptr & ~end_mask));
+}
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+void /* PRIVATE */
+png_do_read_interlace(png_row_infop row_info, png_bytep row, int pass,
+ png_uint_32 transformations /* Because these may affect the byte layout */)
+{
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+ /* Offset to next interlace block */
+ static PNG_CONST int png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ png_debug(1, "in png_do_read_interlace");
+ if (row != NULL && row_info != NULL)
+ {
+ png_uint_32 final_width;
+
+ final_width = row_info->width * png_pass_inc[pass];
+
+ switch (row_info->pixel_depth)
+ {
+ case 1:
+ {
+ png_bytep sp = row + (png_size_t)((row_info->width - 1) >> 3);
+ png_bytep dp = row + (png_size_t)((final_width - 1) >> 3);
+ int sshift, dshift;
+ int s_start, s_end, s_inc;
+ int jstop = png_pass_inc[pass];
+ png_byte v;
+ png_uint_32 i;
+ int j;
+
+#ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (transformations & PNG_PACKSWAP)
+ {
+ sshift = (int)((row_info->width + 7) & 0x07);
+ dshift = (int)((final_width + 7) & 0x07);
+ s_start = 7;
+ s_end = 0;
+ s_inc = -1;
+ }
+
+ else
+#endif
+ {
+ sshift = 7 - (int)((row_info->width + 7) & 0x07);
+ dshift = 7 - (int)((final_width + 7) & 0x07);
+ s_start = 0;
+ s_end = 7;
+ s_inc = 1;
+ }
+
+ for (i = 0; i < row_info->width; i++)
+ {
+ v = (png_byte)((*sp >> sshift) & 0x01);
+ for (j = 0; j < jstop; j++)
+ {
+ unsigned int tmp = *dp & (0x7f7f >> (7 - dshift));
+ tmp |= v << dshift;
+ *dp = (png_byte)(tmp & 0xff);
+
+ if (dshift == s_end)
+ {
+ dshift = s_start;
+ dp--;
+ }
+
+ else
+ dshift += s_inc;
+ }
+
+ if (sshift == s_end)
+ {
+ sshift = s_start;
+ sp--;
+ }
+
+ else
+ sshift += s_inc;
+ }
+ break;
+ }
+
+ case 2:
+ {
+ png_bytep sp = row + (png_uint_32)((row_info->width - 1) >> 2);
+ png_bytep dp = row + (png_uint_32)((final_width - 1) >> 2);
+ int sshift, dshift;
+ int s_start, s_end, s_inc;
+ int jstop = png_pass_inc[pass];
+ png_uint_32 i;
+
+#ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (transformations & PNG_PACKSWAP)
+ {
+ sshift = (int)(((row_info->width + 3) & 0x03) << 1);
+ dshift = (int)(((final_width + 3) & 0x03) << 1);
+ s_start = 6;
+ s_end = 0;
+ s_inc = -2;
+ }
+
+ else
+#endif
+ {
+ sshift = (int)((3 - ((row_info->width + 3) & 0x03)) << 1);
+ dshift = (int)((3 - ((final_width + 3) & 0x03)) << 1);
+ s_start = 0;
+ s_end = 6;
+ s_inc = 2;
+ }
+
+ for (i = 0; i < row_info->width; i++)
+ {
+ png_byte v;
+ int j;
+
+ v = (png_byte)((*sp >> sshift) & 0x03);
+ for (j = 0; j < jstop; j++)
+ {
+ unsigned int tmp = *dp & (0x3f3f >> (6 - dshift));
+ tmp |= v << dshift;
+ *dp = (png_byte)(tmp & 0xff);
+
+ if (dshift == s_end)
+ {
+ dshift = s_start;
+ dp--;
+ }
+
+ else
+ dshift += s_inc;
+ }
+
+ if (sshift == s_end)
+ {
+ sshift = s_start;
+ sp--;
+ }
+
+ else
+ sshift += s_inc;
+ }
+ break;
+ }
+
+ case 4:
+ {
+ png_bytep sp = row + (png_size_t)((row_info->width - 1) >> 1);
+ png_bytep dp = row + (png_size_t)((final_width - 1) >> 1);
+ int sshift, dshift;
+ int s_start, s_end, s_inc;
+ png_uint_32 i;
+ int jstop = png_pass_inc[pass];
+
+#ifdef PNG_READ_PACKSWAP_SUPPORTED
+ if (transformations & PNG_PACKSWAP)
+ {
+ sshift = (int)(((row_info->width + 1) & 0x01) << 2);
+ dshift = (int)(((final_width + 1) & 0x01) << 2);
+ s_start = 4;
+ s_end = 0;
+ s_inc = -4;
+ }
+
+ else
+#endif
+ {
+ sshift = (int)((1 - ((row_info->width + 1) & 0x01)) << 2);
+ dshift = (int)((1 - ((final_width + 1) & 0x01)) << 2);
+ s_start = 0;
+ s_end = 4;
+ s_inc = 4;
+ }
+
+ for (i = 0; i < row_info->width; i++)
+ {
+ png_byte v = (png_byte)((*sp >> sshift) & 0x0f);
+ int j;
+
+ for (j = 0; j < jstop; j++)
+ {
+ unsigned int tmp = *dp & (0xf0f >> (4 - dshift));
+ tmp |= v << dshift;
+ *dp = (png_byte)(tmp & 0xff);
+
+ if (dshift == s_end)
+ {
+ dshift = s_start;
+ dp--;
+ }
+
+ else
+ dshift += s_inc;
+ }
+
+ if (sshift == s_end)
+ {
+ sshift = s_start;
+ sp--;
+ }
+
+ else
+ sshift += s_inc;
+ }
+ break;
+ }
+
+ default:
+ {
+ png_size_t pixel_bytes = (row_info->pixel_depth >> 3);
+
+ png_bytep sp = row + (png_size_t)(row_info->width - 1)
+ * pixel_bytes;
+
+ png_bytep dp = row + (png_size_t)(final_width - 1) * pixel_bytes;
+
+ int jstop = png_pass_inc[pass];
+ png_uint_32 i;
+
+ for (i = 0; i < row_info->width; i++)
+ {
+ png_byte v[8]; /* SAFE; pixel_depth does not exceed 64 */
+ int j;
+
+ memcpy(v, sp, pixel_bytes);
+
+ for (j = 0; j < jstop; j++)
+ {
+ memcpy(dp, v, pixel_bytes);
+ dp -= pixel_bytes;
+ }
+
+ sp -= pixel_bytes;
+ }
+ break;
+ }
+ }
+
+ row_info->width = final_width;
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, final_width);
+ }
+#ifndef PNG_READ_PACKSWAP_SUPPORTED
+ PNG_UNUSED(transformations) /* Silence compiler warning */
+#endif
+}
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+
+static void
+png_read_filter_row_sub(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_size_t i;
+ png_size_t istop = row_info->rowbytes;
+ unsigned int bpp = (row_info->pixel_depth + 7) >> 3;
+ png_bytep rp = row + bpp;
+
+ PNG_UNUSED(prev_row)
+
+ for (i = bpp; i < istop; i++)
+ {
+ *rp = (png_byte)(((int)(*rp) + (int)(*(rp-bpp))) & 0xff);
+ rp++;
+ }
+}
+
+static void
+png_read_filter_row_up(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_size_t i;
+ png_size_t istop = row_info->rowbytes;
+ png_bytep rp = row;
+ png_const_bytep pp = prev_row;
+
+ for (i = 0; i < istop; i++)
+ {
+ *rp = (png_byte)(((int)(*rp) + (int)(*pp++)) & 0xff);
+ rp++;
+ }
+}
+
+static void
+png_read_filter_row_avg(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_size_t i;
+ png_bytep rp = row;
+ png_const_bytep pp = prev_row;
+ unsigned int bpp = (row_info->pixel_depth + 7) >> 3;
+ png_size_t istop = row_info->rowbytes - bpp;
+
+ for (i = 0; i < bpp; i++)
+ {
+ *rp = (png_byte)(((int)(*rp) +
+ ((int)(*pp++) / 2 )) & 0xff);
+
+ rp++;
+ }
+
+ for (i = 0; i < istop; i++)
+ {
+ *rp = (png_byte)(((int)(*rp) +
+ (int)(*pp++ + *(rp-bpp)) / 2 ) & 0xff);
+
+ rp++;
+ }
+}
+
+static void
+png_read_filter_row_paeth_1byte_pixel(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ png_bytep rp_end = row + row_info->rowbytes;
+ int a, c;
+
+ /* First pixel/byte */
+ c = *prev_row++;
+ a = *row + c;
+ *row++ = (png_byte)a;
+
+ /* Remainder */
+ while (row < rp_end)
+ {
+ int b, pa, pb, pc, p;
+
+ a &= 0xff; /* From previous iteration or start */
+ b = *prev_row++;
+
+ p = b - c;
+ pc = a - c;
+
+# ifdef PNG_USE_ABS
+ pa = abs(p);
+ pb = abs(pc);
+ pc = abs(p + pc);
+# else
+ pa = p < 0 ? -p : p;
+ pb = pc < 0 ? -pc : pc;
+ pc = (p + pc) < 0 ? -(p + pc) : p + pc;
+# endif
+
+ /* Find the best predictor, the least of pa, pb, pc favoring the earlier
+ * ones in the case of a tie.
+ */
+ if (pb < pa) pa = pb, a = b;
+ if (pc < pa) a = c;
+
+ /* Calculate the current pixel in a, and move the previous row pixel to c
+ * for the next time round the loop
+ */
+ c = b;
+ a += *row;
+ *row++ = (png_byte)a;
+ }
+}
+
+static void
+png_read_filter_row_paeth_multibyte_pixel(png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row)
+{
+ int bpp = (row_info->pixel_depth + 7) >> 3;
+ png_bytep rp_end = row + bpp;
+
+ /* Process the first pixel in the row completely (this is the same as 'up'
+ * because there is only one candidate predictor for the first row).
+ */
+ while (row < rp_end)
+ {
+ int a = *row + *prev_row++;
+ *row++ = (png_byte)a;
+ }
+
+ /* Remainder */
+ rp_end += row_info->rowbytes - bpp;
+
+ while (row < rp_end)
+ {
+ int a, b, c, pa, pb, pc, p;
+
+ c = *(prev_row - bpp);
+ a = *(row - bpp);
+ b = *prev_row++;
+
+ p = b - c;
+ pc = a - c;
+
+# ifdef PNG_USE_ABS
+ pa = abs(p);
+ pb = abs(pc);
+ pc = abs(p + pc);
+# else
+ pa = p < 0 ? -p : p;
+ pb = pc < 0 ? -pc : pc;
+ pc = (p + pc) < 0 ? -(p + pc) : p + pc;
+# endif
+
+ if (pb < pa) pa = pb, a = b;
+ if (pc < pa) a = c;
+
+ c = b;
+ a += *row;
+ *row++ = (png_byte)a;
+ }
+}
+
+static void
+png_init_filter_functions(png_structrp pp)
+ /* This function is called once for every PNG image (except for PNG images
+ * that only use PNG_FILTER_VALUE_NONE for all rows) to set the
+ * implementations required to reverse the filtering of PNG rows. Reversing
+ * the filter is the first transformation performed on the row data. It is
+ * performed in place, therefore an implementation can be selected based on
+ * the image pixel format. If the implementation depends on image width then
+ * take care to ensure that it works correctly if the image is interlaced -
+ * interlacing causes the actual row width to vary.
+ */
+{
+ unsigned int bpp = (pp->pixel_depth + 7) >> 3;
+
+ pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub;
+ pp->read_filter[PNG_FILTER_VALUE_UP-1] = png_read_filter_row_up;
+ pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg;
+ if (bpp == 1)
+ pp->read_filter[PNG_FILTER_VALUE_PAETH-1] =
+ png_read_filter_row_paeth_1byte_pixel;
+ else
+ pp->read_filter[PNG_FILTER_VALUE_PAETH-1] =
+ png_read_filter_row_paeth_multibyte_pixel;
+
+#ifdef PNG_FILTER_OPTIMIZATIONS
+ /* To use this define PNG_FILTER_OPTIMIZATIONS as the name of a function to
+ * call to install hardware optimizations for the above functions; simply
+ * replace whatever elements of the pp->read_filter[] array with a hardware
+ * specific (or, for that matter, generic) optimization.
+ *
+ * To see an example of this examine what configure.ac does when
+ * --enable-arm-neon is specified on the command line.
+ */
+ PNG_FILTER_OPTIMIZATIONS(pp, bpp);
+#endif
+}
+
+void /* PRIVATE */
+png_read_filter_row(png_structrp pp, png_row_infop row_info, png_bytep row,
+ png_const_bytep prev_row, int filter)
+{
+ /* OPTIMIZATION: DO NOT MODIFY THIS FUNCTION, instead #define
+ * PNG_FILTER_OPTIMIZATIONS to a function that overrides the generic
+ * implementations. See png_init_filter_functions above.
+ */
+ if (filter > PNG_FILTER_VALUE_NONE && filter < PNG_FILTER_VALUE_LAST)
+ {
+ if (pp->read_filter[0] == NULL)
+ png_init_filter_functions(pp);
+
+ pp->read_filter[filter-1](row_info, row, prev_row);
+ }
+}
+
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+void /* PRIVATE */
+png_read_IDAT_data(png_structrp png_ptr, png_bytep output,
+ png_alloc_size_t avail_out)
+{
+ /* Loop reading IDATs and decompressing the result into output[avail_out] */
+ png_ptr->zstream.next_out = output;
+ png_ptr->zstream.avail_out = 0; /* safety: set below */
+
+ if (output == NULL)
+ avail_out = 0;
+
+ do
+ {
+ int ret;
+ png_byte tmpbuf[PNG_INFLATE_BUF_SIZE];
+
+ if (png_ptr->zstream.avail_in == 0)
+ {
+ uInt avail_in;
+ png_bytep buffer;
+
+ while (png_ptr->idat_size == 0)
+ {
+ png_crc_finish(png_ptr, 0);
+
+ png_ptr->idat_size = png_read_chunk_header(png_ptr);
+ /* This is an error even in the 'check' case because the code just
+ * consumed a non-IDAT header.
+ */
+ if (png_ptr->chunk_name != png_IDAT)
+ png_error(png_ptr, "Not enough image data");
+ }
+
+ avail_in = png_ptr->IDAT_read_size;
+
+ if (avail_in > png_ptr->idat_size)
+ avail_in = (uInt)png_ptr->idat_size;
+
+ /* A PNG with a gradually increasing IDAT size will defeat this attempt
+ * to minimize memory usage by causing lots of re-allocs, but
+ * realistically doing IDAT_read_size re-allocs is not likely to be a
+ * big problem.
+ */
+ buffer = png_read_buffer(png_ptr, avail_in, 0/*error*/);
+
+ png_crc_read(png_ptr, buffer, avail_in);
+ png_ptr->idat_size -= avail_in;
+
+ png_ptr->zstream.next_in = buffer;
+ png_ptr->zstream.avail_in = avail_in;
+ }
+
+ /* And set up the output side. */
+ if (output != NULL) /* standard read */
+ {
+ uInt out = ZLIB_IO_MAX;
+
+ if (out > avail_out)
+ out = (uInt)avail_out;
+
+ avail_out -= out;
+ png_ptr->zstream.avail_out = out;
+ }
+
+ else /* after last row, checking for end */
+ {
+ png_ptr->zstream.next_out = tmpbuf;
+ png_ptr->zstream.avail_out = (sizeof tmpbuf);
+ }
+
+ /* Use NO_FLUSH; this gives zlib the maximum opportunity to optimize the
+ * process. If the LZ stream is truncated the sequential reader will
+ * terminally damage the stream, above, by reading the chunk header of the
+ * following chunk (it then exits with png_error).
+ *
+ * TODO: deal more elegantly with truncated IDAT lists.
+ */
+ ret = inflate(&png_ptr->zstream, Z_NO_FLUSH);
+
+ /* Take the unconsumed output back. */
+ if (output != NULL)
+ avail_out += png_ptr->zstream.avail_out;
+
+ else /* avail_out counts the extra bytes */
+ avail_out += (sizeof tmpbuf) - png_ptr->zstream.avail_out;
+
+ png_ptr->zstream.avail_out = 0;
+
+ if (ret == Z_STREAM_END)
+ {
+ /* Do this for safety; we won't read any more into this row. */
+ png_ptr->zstream.next_out = NULL;
+
+ png_ptr->mode |= PNG_AFTER_IDAT;
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED;
+
+ if (png_ptr->zstream.avail_in > 0 || png_ptr->idat_size > 0)
+ png_chunk_benign_error(png_ptr, "Extra compressed data");
+ break;
+ }
+
+ if (ret != Z_OK)
+ {
+ png_zstream_error(png_ptr, ret);
+
+ if (output != NULL)
+ png_chunk_error(png_ptr, png_ptr->zstream.msg);
+
+ else /* checking */
+ {
+ png_chunk_benign_error(png_ptr, png_ptr->zstream.msg);
+ return;
+ }
+ }
+ } while (avail_out > 0);
+
+ if (avail_out > 0)
+ {
+ /* The stream ended before the image; this is the same as too few IDATs so
+ * should be handled the same way.
+ */
+ if (output != NULL)
+ png_error(png_ptr, "Not enough image data");
+
+ else /* the deflate stream contained extra data */
+ png_chunk_benign_error(png_ptr, "Too much image data");
+ }
+}
+
+void /* PRIVATE */
+png_read_finish_IDAT(png_structrp png_ptr)
+{
+ /* We don't need any more data and the stream should have ended, however the
+ * LZ end code may actually not have been processed. In this case we must
+ * read it otherwise stray unread IDAT data or, more likely, an IDAT chunk
+ * may still remain to be consumed.
+ */
+ if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED))
+ {
+ /* The NULL causes png_read_IDAT_data to swallow any remaining bytes in
+ * the compressed stream, but the stream may be damaged too, so even after
+ * this call we may need to terminate the zstream ownership.
+ */
+ png_read_IDAT_data(png_ptr, NULL, 0);
+ png_ptr->zstream.next_out = NULL; /* safety */
+
+ /* Now clear everything out for safety; the following may not have been
+ * done.
+ */
+ if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED))
+ {
+ png_ptr->mode |= PNG_AFTER_IDAT;
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED;
+ }
+ }
+
+ /* If the zstream has not been released do it now *and* terminate the reading
+ * of the final IDAT chunk.
+ */
+ if (png_ptr->zowner == png_IDAT)
+ {
+ /* Always do this; the pointers otherwise point into the read buffer. */
+ png_ptr->zstream.next_in = NULL;
+ png_ptr->zstream.avail_in = 0;
+
+ /* Now we no longer own the zstream. */
+ png_ptr->zowner = 0;
+
+ /* The slightly weird semantics of the sequential IDAT reading is that we
+ * are always in or at the end of an IDAT chunk, so we always need to do a
+ * crc_finish here. If idat_size is non-zero we also need to read the
+ * spurious bytes at the end of the chunk now.
+ */
+ (void)png_crc_finish(png_ptr, png_ptr->idat_size);
+ }
+}
+
+void /* PRIVATE */
+png_read_finish_row(png_structrp png_ptr)
+{
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ /* Start of interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1};
+
+ /* Offset to next interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2};
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+
+ png_debug(1, "in png_read_finish_row");
+ png_ptr->row_number++;
+ if (png_ptr->row_number < png_ptr->num_rows)
+ return;
+
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ if (png_ptr->interlaced)
+ {
+ png_ptr->row_number = 0;
+
+ /* TO DO: don't do this if prev_row isn't needed (requires
+ * read-ahead of the next row's filter byte.
+ */
+ memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1);
+
+ do
+ {
+ png_ptr->pass++;
+
+ if (png_ptr->pass >= 7)
+ break;
+
+ png_ptr->iwidth = (png_ptr->width +
+ png_pass_inc[png_ptr->pass] - 1 -
+ png_pass_start[png_ptr->pass]) /
+ png_pass_inc[png_ptr->pass];
+
+ if (!(png_ptr->transformations & PNG_INTERLACE))
+ {
+ png_ptr->num_rows = (png_ptr->height +
+ png_pass_yinc[png_ptr->pass] - 1 -
+ png_pass_ystart[png_ptr->pass]) /
+ png_pass_yinc[png_ptr->pass];
+ }
+
+ else /* if (png_ptr->transformations & PNG_INTERLACE) */
+ break; /* libpng deinterlacing sees every row */
+
+ } while (png_ptr->num_rows == 0 || png_ptr->iwidth == 0);
+
+ if (png_ptr->pass < 7)
+ return;
+ }
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+
+ /* Here after at the end of the last row of the last pass. */
+ png_read_finish_IDAT(png_ptr);
+}
+#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */
+
+void /* PRIVATE */
+png_read_start_row(png_structrp png_ptr)
+{
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ /* Start of interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1};
+
+ /* Offset to next interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2};
+#endif
+
+ int max_pixel_depth;
+ png_size_t row_bytes;
+
+ png_debug(1, "in png_read_start_row");
+
+#ifdef PNG_READ_TRANSFORMS_SUPPORTED
+ png_init_read_transformations(png_ptr);
+#endif
+#ifdef PNG_READ_INTERLACING_SUPPORTED
+ if (png_ptr->interlaced)
+ {
+ if (!(png_ptr->transformations & PNG_INTERLACE))
+ png_ptr->num_rows = (png_ptr->height + png_pass_yinc[0] - 1 -
+ png_pass_ystart[0]) / png_pass_yinc[0];
+
+ else
+ png_ptr->num_rows = png_ptr->height;
+
+ png_ptr->iwidth = (png_ptr->width +
+ png_pass_inc[png_ptr->pass] - 1 -
+ png_pass_start[png_ptr->pass]) /
+ png_pass_inc[png_ptr->pass];
+ }
+
+ else
+#endif /* PNG_READ_INTERLACING_SUPPORTED */
+ {
+ png_ptr->num_rows = png_ptr->height;
+ png_ptr->iwidth = png_ptr->width;
+ }
+
+ max_pixel_depth = png_ptr->pixel_depth;
+
+ /* WARNING: * png_read_transform_info (pngrtran.c) performs a simpliar set of
+ * calculations to calculate the final pixel depth, then
+ * png_do_read_transforms actually does the transforms. This means that the
+ * code which effectively calculates this value is actually repeated in three
+ * separate places. They must all match. Innocent changes to the order of
+ * transformations can and will break libpng in a way that causes memory
+ * overwrites.
+ *
+ * TODO: fix this.
+ */
+#ifdef PNG_READ_PACK_SUPPORTED
+ if ((png_ptr->transformations & PNG_PACK) && png_ptr->bit_depth < 8)
+ max_pixel_depth = 8;
+#endif
+
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ if (png_ptr->transformations & PNG_EXPAND)
+ {
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (png_ptr->num_trans)
+ max_pixel_depth = 32;
+
+ else
+ max_pixel_depth = 24;
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ if (max_pixel_depth < 8)
+ max_pixel_depth = 8;
+
+ if (png_ptr->num_trans)
+ max_pixel_depth *= 2;
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB)
+ {
+ if (png_ptr->num_trans)
+ {
+ max_pixel_depth *= 4;
+ max_pixel_depth /= 3;
+ }
+ }
+ }
+#endif
+
+#ifdef PNG_READ_EXPAND_16_SUPPORTED
+ if (png_ptr->transformations & PNG_EXPAND_16)
+ {
+# ifdef PNG_READ_EXPAND_SUPPORTED
+ /* In fact it is an error if it isn't supported, but checking is
+ * the safe way.
+ */
+ if (png_ptr->transformations & PNG_EXPAND)
+ {
+ if (png_ptr->bit_depth < 16)
+ max_pixel_depth *= 2;
+ }
+ else
+# endif
+ png_ptr->transformations &= ~PNG_EXPAND_16;
+ }
+#endif
+
+#ifdef PNG_READ_FILLER_SUPPORTED
+ if (png_ptr->transformations & (PNG_FILLER))
+ {
+ if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ if (max_pixel_depth <= 8)
+ max_pixel_depth = 16;
+
+ else
+ max_pixel_depth = 32;
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB ||
+ png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (max_pixel_depth <= 32)
+ max_pixel_depth = 32;
+
+ else
+ max_pixel_depth = 64;
+ }
+ }
+#endif
+
+#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED
+ if (png_ptr->transformations & PNG_GRAY_TO_RGB)
+ {
+ if (
+#ifdef PNG_READ_EXPAND_SUPPORTED
+ (png_ptr->num_trans && (png_ptr->transformations & PNG_EXPAND)) ||
+#endif
+#ifdef PNG_READ_FILLER_SUPPORTED
+ (png_ptr->transformations & (PNG_FILLER)) ||
+#endif
+ png_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ if (max_pixel_depth <= 16)
+ max_pixel_depth = 32;
+
+ else
+ max_pixel_depth = 64;
+ }
+
+ else
+ {
+ if (max_pixel_depth <= 8)
+ {
+ if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ max_pixel_depth = 32;
+
+ else
+ max_pixel_depth = 24;
+ }
+
+ else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ max_pixel_depth = 64;
+
+ else
+ max_pixel_depth = 48;
+ }
+ }
+#endif
+
+#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) && \
+defined(PNG_USER_TRANSFORM_PTR_SUPPORTED)
+ if (png_ptr->transformations & PNG_USER_TRANSFORM)
+ {
+ int user_pixel_depth = png_ptr->user_transform_depth *
+ png_ptr->user_transform_channels;
+
+ if (user_pixel_depth > max_pixel_depth)
+ max_pixel_depth = user_pixel_depth;
+ }
+#endif
+
+ /* This value is stored in png_struct and double checked in the row read
+ * code.
+ */
+ png_ptr->maximum_pixel_depth = (png_byte)max_pixel_depth;
+ png_ptr->transformed_pixel_depth = 0; /* calculated on demand */
+
+ /* Align the width on the next larger 8 pixels. Mainly used
+ * for interlacing
+ */
+ row_bytes = ((png_ptr->width + 7) & ~((png_uint_32)7));
+ /* Calculate the maximum bytes needed, adding a byte and a pixel
+ * for safety's sake
+ */
+ row_bytes = PNG_ROWBYTES(max_pixel_depth, row_bytes) +
+ 1 + ((max_pixel_depth + 7) >> 3);
+
+#ifdef PNG_MAX_MALLOC_64K
+ if (row_bytes > (png_uint_32)65536L)
+ png_error(png_ptr, "This image requires a row greater than 64KB");
+#endif
+
+ if (row_bytes + 48 > png_ptr->old_big_row_buf_size)
+ {
+ png_free(png_ptr, png_ptr->big_row_buf);
+ png_free(png_ptr, png_ptr->big_prev_row);
+
+ if (png_ptr->interlaced)
+ png_ptr->big_row_buf = (png_bytep)png_calloc(png_ptr,
+ row_bytes + 48);
+
+ else
+ png_ptr->big_row_buf = (png_bytep)png_malloc(png_ptr, row_bytes + 48);
+
+ png_ptr->big_prev_row = (png_bytep)png_malloc(png_ptr, row_bytes + 48);
+
+#ifdef PNG_ALIGNED_MEMORY_SUPPORTED
+ /* Use 16-byte aligned memory for row_buf with at least 16 bytes
+ * of padding before and after row_buf; treat prev_row similarly.
+ * NOTE: the alignment is to the start of the pixels, one beyond the start
+ * of the buffer, because of the filter byte. Prior to libpng 1.5.6 this
+ * was incorrect; the filter byte was aligned, which had the exact
+ * opposite effect of that intended.
+ */
+ {
+ png_bytep temp = png_ptr->big_row_buf + 32;
+ int extra = (int)((temp - (png_bytep)0) & 0x0f);
+ png_ptr->row_buf = temp - extra - 1/*filter byte*/;
+
+ temp = png_ptr->big_prev_row + 32;
+ extra = (int)((temp - (png_bytep)0) & 0x0f);
+ png_ptr->prev_row = temp - extra - 1/*filter byte*/;
+ }
+
+#else
+ /* Use 31 bytes of padding before and 17 bytes after row_buf. */
+ png_ptr->row_buf = png_ptr->big_row_buf + 31;
+ png_ptr->prev_row = png_ptr->big_prev_row + 31;
+#endif
+ png_ptr->old_big_row_buf_size = row_bytes + 48;
+ }
+
+#ifdef PNG_MAX_MALLOC_64K
+ if (png_ptr->rowbytes > 65535)
+ png_error(png_ptr, "This image requires a row greater than 64KB");
+
+#endif
+ if (png_ptr->rowbytes > (PNG_SIZE_MAX - 1))
+ png_error(png_ptr, "Row has too many bytes to allocate in memory");
+
+ memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1);
+
+ png_debug1(3, "width = %u,", png_ptr->width);
+ png_debug1(3, "height = %u,", png_ptr->height);
+ png_debug1(3, "iwidth = %u,", png_ptr->iwidth);
+ png_debug1(3, "num_rows = %u,", png_ptr->num_rows);
+ png_debug1(3, "rowbytes = %lu,", (unsigned long)png_ptr->rowbytes);
+ png_debug1(3, "irowbytes = %lu",
+ (unsigned long)PNG_ROWBYTES(png_ptr->pixel_depth, png_ptr->iwidth) + 1);
+
+ /* The sequential reader needs a buffer for IDAT, but the progressive reader
+ * does not, so free the read buffer now regardless; the sequential reader
+ * reallocates it on demand.
+ */
+ if (png_ptr->read_buffer)
+ {
+ png_bytep buffer = png_ptr->read_buffer;
+
+ png_ptr->read_buffer_size = 0;
+ png_ptr->read_buffer = NULL;
+ png_free(png_ptr, buffer);
+ }
+
+ /* Finally claim the zstream for the inflate of the IDAT data, use the bits
+ * value from the stream (note that this will result in a fatal error if the
+ * IDAT stream has a bogus deflate header window_bits value, but this should
+ * not be happening any longer!)
+ */
+ if (png_inflate_claim(png_ptr, png_IDAT) != Z_OK)
+ png_error(png_ptr, png_ptr->zstream.msg);
+
+ png_ptr->flags |= PNG_FLAG_ROW_INIT;
+}
+#endif /* PNG_READ_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngset.c b/ml/dlib/dlib/external/libpng/pngset.c
new file mode 100644
index 000000000..7e355d1f4
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngset.c
@@ -0,0 +1,1597 @@
+
+/* pngset.c - storage of image information into info struct
+ *
+ * Last changed in libpng 1.6.3 [July 18, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * The functions here are used during reads to store data from the file
+ * into the info struct, and during writes to store application data
+ * into the info struct for writing into the file. This abstracts the
+ * info struct and allows us to change the structure in the future.
+ */
+
+#include "pngpriv.h"
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+
+#ifdef PNG_bKGD_SUPPORTED
+void PNGAPI
+png_set_bKGD(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_color_16p background)
+{
+ png_debug1(1, "in %s storage function", "bKGD");
+
+ if (png_ptr == NULL || info_ptr == NULL || background == NULL)
+ return;
+
+ info_ptr->background = *background;
+ info_ptr->valid |= PNG_INFO_bKGD;
+}
+#endif
+
+#ifdef PNG_cHRM_SUPPORTED
+void PNGFAPI
+png_set_cHRM_fixed(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_fixed_point white_x, png_fixed_point white_y, png_fixed_point red_x,
+ png_fixed_point red_y, png_fixed_point green_x, png_fixed_point green_y,
+ png_fixed_point blue_x, png_fixed_point blue_y)
+{
+ png_xy xy;
+
+ png_debug1(1, "in %s storage function", "cHRM fixed");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ xy.redx = red_x;
+ xy.redy = red_y;
+ xy.greenx = green_x;
+ xy.greeny = green_y;
+ xy.bluex = blue_x;
+ xy.bluey = blue_y;
+ xy.whitex = white_x;
+ xy.whitey = white_y;
+
+ if (png_colorspace_set_chromaticities(png_ptr, &info_ptr->colorspace, &xy,
+ 2/* override with app values*/))
+ info_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM;
+
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+
+void PNGFAPI
+png_set_cHRM_XYZ_fixed(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_fixed_point int_red_X, png_fixed_point int_red_Y,
+ png_fixed_point int_red_Z, png_fixed_point int_green_X,
+ png_fixed_point int_green_Y, png_fixed_point int_green_Z,
+ png_fixed_point int_blue_X, png_fixed_point int_blue_Y,
+ png_fixed_point int_blue_Z)
+{
+ png_XYZ XYZ;
+
+ png_debug1(1, "in %s storage function", "cHRM XYZ fixed");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ XYZ.red_X = int_red_X;
+ XYZ.red_Y = int_red_Y;
+ XYZ.red_Z = int_red_Z;
+ XYZ.green_X = int_green_X;
+ XYZ.green_Y = int_green_Y;
+ XYZ.green_Z = int_green_Z;
+ XYZ.blue_X = int_blue_X;
+ XYZ.blue_Y = int_blue_Y;
+ XYZ.blue_Z = int_blue_Z;
+
+ if (png_colorspace_set_endpoints(png_ptr, &info_ptr->colorspace, &XYZ, 2))
+ info_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM;
+
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_cHRM(png_const_structrp png_ptr, png_inforp info_ptr,
+ double white_x, double white_y, double red_x, double red_y,
+ double green_x, double green_y, double blue_x, double blue_y)
+{
+ png_set_cHRM_fixed(png_ptr, info_ptr,
+ png_fixed(png_ptr, white_x, "cHRM White X"),
+ png_fixed(png_ptr, white_y, "cHRM White Y"),
+ png_fixed(png_ptr, red_x, "cHRM Red X"),
+ png_fixed(png_ptr, red_y, "cHRM Red Y"),
+ png_fixed(png_ptr, green_x, "cHRM Green X"),
+ png_fixed(png_ptr, green_y, "cHRM Green Y"),
+ png_fixed(png_ptr, blue_x, "cHRM Blue X"),
+ png_fixed(png_ptr, blue_y, "cHRM Blue Y"));
+}
+
+void PNGAPI
+png_set_cHRM_XYZ(png_const_structrp png_ptr, png_inforp info_ptr, double red_X,
+ double red_Y, double red_Z, double green_X, double green_Y, double green_Z,
+ double blue_X, double blue_Y, double blue_Z)
+{
+ png_set_cHRM_XYZ_fixed(png_ptr, info_ptr,
+ png_fixed(png_ptr, red_X, "cHRM Red X"),
+ png_fixed(png_ptr, red_Y, "cHRM Red Y"),
+ png_fixed(png_ptr, red_Z, "cHRM Red Z"),
+ png_fixed(png_ptr, green_X, "cHRM Red X"),
+ png_fixed(png_ptr, green_Y, "cHRM Red Y"),
+ png_fixed(png_ptr, green_Z, "cHRM Red Z"),
+ png_fixed(png_ptr, blue_X, "cHRM Red X"),
+ png_fixed(png_ptr, blue_Y, "cHRM Red Y"),
+ png_fixed(png_ptr, blue_Z, "cHRM Red Z"));
+}
+# endif /* PNG_FLOATING_POINT_SUPPORTED */
+
+#endif /* PNG_cHRM_SUPPORTED */
+
+#ifdef PNG_gAMA_SUPPORTED
+void PNGFAPI
+png_set_gAMA_fixed(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_fixed_point file_gamma)
+{
+ png_debug1(1, "in %s storage function", "gAMA");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ png_colorspace_set_gamma(png_ptr, &info_ptr->colorspace, file_gamma);
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_gAMA(png_const_structrp png_ptr, png_inforp info_ptr, double file_gamma)
+{
+ png_set_gAMA_fixed(png_ptr, info_ptr, png_fixed(png_ptr, file_gamma,
+ "png_set_gAMA"));
+}
+# endif
+#endif
+
+#ifdef PNG_hIST_SUPPORTED
+void PNGAPI
+png_set_hIST(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_uint_16p hist)
+{
+ int i;
+
+ png_debug1(1, "in %s storage function", "hIST");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (info_ptr->num_palette == 0 || info_ptr->num_palette
+ > PNG_MAX_PALETTE_LENGTH)
+ {
+ png_warning(png_ptr,
+ "Invalid palette size, hIST allocation skipped");
+
+ return;
+ }
+
+ png_free_data(png_ptr, info_ptr, PNG_FREE_HIST, 0);
+
+ /* Changed from info->num_palette to PNG_MAX_PALETTE_LENGTH in
+ * version 1.2.1
+ */
+ info_ptr->hist = png_voidcast(png_uint_16p, png_malloc_warn(png_ptr,
+ PNG_MAX_PALETTE_LENGTH * (sizeof (png_uint_16))));
+
+ if (info_ptr->hist == NULL)
+ {
+ png_warning(png_ptr, "Insufficient memory for hIST chunk data");
+ return;
+ }
+
+ info_ptr->free_me |= PNG_FREE_HIST;
+
+ for (i = 0; i < info_ptr->num_palette; i++)
+ info_ptr->hist[i] = hist[i];
+
+ info_ptr->valid |= PNG_INFO_hIST;
+}
+#endif
+
+void PNGAPI
+png_set_IHDR(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_uint_32 width, png_uint_32 height, int bit_depth,
+ int color_type, int interlace_type, int compression_type,
+ int filter_type)
+{
+ png_debug1(1, "in %s storage function", "IHDR");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ info_ptr->width = width;
+ info_ptr->height = height;
+ info_ptr->bit_depth = (png_byte)bit_depth;
+ info_ptr->color_type = (png_byte)color_type;
+ info_ptr->compression_type = (png_byte)compression_type;
+ info_ptr->filter_type = (png_byte)filter_type;
+ info_ptr->interlace_type = (png_byte)interlace_type;
+
+ png_check_IHDR (png_ptr, info_ptr->width, info_ptr->height,
+ info_ptr->bit_depth, info_ptr->color_type, info_ptr->interlace_type,
+ info_ptr->compression_type, info_ptr->filter_type);
+
+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ info_ptr->channels = 1;
+
+ else if (info_ptr->color_type & PNG_COLOR_MASK_COLOR)
+ info_ptr->channels = 3;
+
+ else
+ info_ptr->channels = 1;
+
+ if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA)
+ info_ptr->channels++;
+
+ info_ptr->pixel_depth = (png_byte)(info_ptr->channels * info_ptr->bit_depth);
+
+ info_ptr->rowbytes = PNG_ROWBYTES(info_ptr->pixel_depth, width);
+}
+
+#ifdef PNG_oFFs_SUPPORTED
+void PNGAPI
+png_set_oFFs(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_int_32 offset_x, png_int_32 offset_y, int unit_type)
+{
+ png_debug1(1, "in %s storage function", "oFFs");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ info_ptr->x_offset = offset_x;
+ info_ptr->y_offset = offset_y;
+ info_ptr->offset_unit_type = (png_byte)unit_type;
+ info_ptr->valid |= PNG_INFO_oFFs;
+}
+#endif
+
+#ifdef PNG_pCAL_SUPPORTED
+void PNGAPI
+png_set_pCAL(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_charp purpose, png_int_32 X0, png_int_32 X1, int type,
+ int nparams, png_const_charp units, png_charpp params)
+{
+ png_size_t length;
+ int i;
+
+ png_debug1(1, "in %s storage function", "pCAL");
+
+ if (png_ptr == NULL || info_ptr == NULL || purpose == NULL || units == NULL
+ || (nparams > 0 && params == NULL))
+ return;
+
+ length = strlen(purpose) + 1;
+ png_debug1(3, "allocating purpose for info (%lu bytes)",
+ (unsigned long)length);
+
+ /* TODO: validate format of calibration name and unit name */
+
+ /* Check that the type matches the specification. */
+ if (type < 0 || type > 3)
+ png_error(png_ptr, "Invalid pCAL equation type");
+
+ if (nparams < 0 || nparams > 255)
+ png_error(png_ptr, "Invalid pCAL parameter count");
+
+ /* Validate params[nparams] */
+ for (i=0; i<nparams; ++i)
+ if (params[i] == NULL ||
+ !png_check_fp_string(params[i], strlen(params[i])))
+ png_error(png_ptr, "Invalid format for pCAL parameter");
+
+ info_ptr->pcal_purpose = png_voidcast(png_charp,
+ png_malloc_warn(png_ptr, length));
+
+ if (info_ptr->pcal_purpose == NULL)
+ {
+ png_warning(png_ptr, "Insufficient memory for pCAL purpose");
+ return;
+ }
+
+ memcpy(info_ptr->pcal_purpose, purpose, length);
+
+ png_debug(3, "storing X0, X1, type, and nparams in info");
+ info_ptr->pcal_X0 = X0;
+ info_ptr->pcal_X1 = X1;
+ info_ptr->pcal_type = (png_byte)type;
+ info_ptr->pcal_nparams = (png_byte)nparams;
+
+ length = strlen(units) + 1;
+ png_debug1(3, "allocating units for info (%lu bytes)",
+ (unsigned long)length);
+
+ info_ptr->pcal_units = png_voidcast(png_charp,
+ png_malloc_warn(png_ptr, length));
+
+ if (info_ptr->pcal_units == NULL)
+ {
+ png_warning(png_ptr, "Insufficient memory for pCAL units");
+ return;
+ }
+
+ memcpy(info_ptr->pcal_units, units, length);
+
+ info_ptr->pcal_params = png_voidcast(png_charpp, png_malloc_warn(png_ptr,
+ (png_size_t)((nparams + 1) * (sizeof (png_charp)))));
+
+ if (info_ptr->pcal_params == NULL)
+ {
+ png_warning(png_ptr, "Insufficient memory for pCAL params");
+ return;
+ }
+
+ memset(info_ptr->pcal_params, 0, (nparams + 1) * (sizeof (png_charp)));
+
+ for (i = 0; i < nparams; i++)
+ {
+ length = strlen(params[i]) + 1;
+ png_debug2(3, "allocating parameter %d for info (%lu bytes)", i,
+ (unsigned long)length);
+
+ info_ptr->pcal_params[i] = (png_charp)png_malloc_warn(png_ptr, length);
+
+ if (info_ptr->pcal_params[i] == NULL)
+ {
+ png_warning(png_ptr, "Insufficient memory for pCAL parameter");
+ return;
+ }
+
+ memcpy(info_ptr->pcal_params[i], params[i], length);
+ }
+
+ info_ptr->valid |= PNG_INFO_pCAL;
+ info_ptr->free_me |= PNG_FREE_PCAL;
+}
+#endif
+
+#ifdef PNG_sCAL_SUPPORTED
+void PNGAPI
+png_set_sCAL_s(png_const_structrp png_ptr, png_inforp info_ptr,
+ int unit, png_const_charp swidth, png_const_charp sheight)
+{
+ png_size_t lengthw = 0, lengthh = 0;
+
+ png_debug1(1, "in %s storage function", "sCAL");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ /* Double check the unit (should never get here with an invalid
+ * unit unless this is an API call.)
+ */
+ if (unit != 1 && unit != 2)
+ png_error(png_ptr, "Invalid sCAL unit");
+
+ if (swidth == NULL || (lengthw = strlen(swidth)) == 0 ||
+ swidth[0] == 45 /* '-' */ || !png_check_fp_string(swidth, lengthw))
+ png_error(png_ptr, "Invalid sCAL width");
+
+ if (sheight == NULL || (lengthh = strlen(sheight)) == 0 ||
+ sheight[0] == 45 /* '-' */ || !png_check_fp_string(sheight, lengthh))
+ png_error(png_ptr, "Invalid sCAL height");
+
+ info_ptr->scal_unit = (png_byte)unit;
+
+ ++lengthw;
+
+ png_debug1(3, "allocating unit for info (%u bytes)", (unsigned int)lengthw);
+
+ info_ptr->scal_s_width = png_voidcast(png_charp,
+ png_malloc_warn(png_ptr, lengthw));
+
+ if (info_ptr->scal_s_width == NULL)
+ {
+ png_warning(png_ptr, "Memory allocation failed while processing sCAL");
+ return;
+ }
+
+ memcpy(info_ptr->scal_s_width, swidth, lengthw);
+
+ ++lengthh;
+
+ png_debug1(3, "allocating unit for info (%u bytes)", (unsigned int)lengthh);
+
+ info_ptr->scal_s_height = png_voidcast(png_charp,
+ png_malloc_warn(png_ptr, lengthh));
+
+ if (info_ptr->scal_s_height == NULL)
+ {
+ png_free (png_ptr, info_ptr->scal_s_width);
+ info_ptr->scal_s_width = NULL;
+
+ png_warning(png_ptr, "Memory allocation failed while processing sCAL");
+ return;
+ }
+
+ memcpy(info_ptr->scal_s_height, sheight, lengthh);
+
+ info_ptr->valid |= PNG_INFO_sCAL;
+ info_ptr->free_me |= PNG_FREE_SCAL;
+}
+
+# ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_sCAL(png_const_structrp png_ptr, png_inforp info_ptr, int unit,
+ double width, double height)
+{
+ png_debug1(1, "in %s storage function", "sCAL");
+
+ /* Check the arguments. */
+ if (width <= 0)
+ png_warning(png_ptr, "Invalid sCAL width ignored");
+
+ else if (height <= 0)
+ png_warning(png_ptr, "Invalid sCAL height ignored");
+
+ else
+ {
+ /* Convert 'width' and 'height' to ASCII. */
+ char swidth[PNG_sCAL_MAX_DIGITS+1];
+ char sheight[PNG_sCAL_MAX_DIGITS+1];
+
+ png_ascii_from_fp(png_ptr, swidth, (sizeof swidth), width,
+ PNG_sCAL_PRECISION);
+ png_ascii_from_fp(png_ptr, sheight, (sizeof sheight), height,
+ PNG_sCAL_PRECISION);
+
+ png_set_sCAL_s(png_ptr, info_ptr, unit, swidth, sheight);
+ }
+}
+# endif
+
+# ifdef PNG_FIXED_POINT_SUPPORTED
+void PNGAPI
+png_set_sCAL_fixed(png_const_structrp png_ptr, png_inforp info_ptr, int unit,
+ png_fixed_point width, png_fixed_point height)
+{
+ png_debug1(1, "in %s storage function", "sCAL");
+
+ /* Check the arguments. */
+ if (width <= 0)
+ png_warning(png_ptr, "Invalid sCAL width ignored");
+
+ else if (height <= 0)
+ png_warning(png_ptr, "Invalid sCAL height ignored");
+
+ else
+ {
+ /* Convert 'width' and 'height' to ASCII. */
+ char swidth[PNG_sCAL_MAX_DIGITS+1];
+ char sheight[PNG_sCAL_MAX_DIGITS+1];
+
+ png_ascii_from_fixed(png_ptr, swidth, (sizeof swidth), width);
+ png_ascii_from_fixed(png_ptr, sheight, (sizeof sheight), height);
+
+ png_set_sCAL_s(png_ptr, info_ptr, unit, swidth, sheight);
+ }
+}
+# endif
+#endif
+
+#ifdef PNG_pHYs_SUPPORTED
+void PNGAPI
+png_set_pHYs(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_uint_32 res_x, png_uint_32 res_y, int unit_type)
+{
+ png_debug1(1, "in %s storage function", "pHYs");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ info_ptr->x_pixels_per_unit = res_x;
+ info_ptr->y_pixels_per_unit = res_y;
+ info_ptr->phys_unit_type = (png_byte)unit_type;
+ info_ptr->valid |= PNG_INFO_pHYs;
+}
+#endif
+
+void PNGAPI
+png_set_PLTE(png_structrp png_ptr, png_inforp info_ptr,
+ png_const_colorp palette, int num_palette)
+{
+
+ png_debug1(1, "in %s storage function", "PLTE");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (num_palette < 0 || num_palette > PNG_MAX_PALETTE_LENGTH)
+ {
+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ png_error(png_ptr, "Invalid palette length");
+
+ else
+ {
+ png_warning(png_ptr, "Invalid palette length");
+ return;
+ }
+ }
+
+ if ((num_palette > 0 && palette == NULL) ||
+ (num_palette == 0
+# ifdef PNG_MNG_FEATURES_SUPPORTED
+ && (png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE) == 0
+# endif
+ ))
+ {
+ png_chunk_report(png_ptr, "Invalid palette", PNG_CHUNK_ERROR);
+ return;
+ }
+
+ /* It may not actually be necessary to set png_ptr->palette here;
+ * we do it for backward compatibility with the way the png_handle_tRNS
+ * function used to do the allocation.
+ *
+ * 1.6.0: the above statement appears to be incorrect; something has to set
+ * the palette inside png_struct on read.
+ */
+ png_free_data(png_ptr, info_ptr, PNG_FREE_PLTE, 0);
+
+ /* Changed in libpng-1.2.1 to allocate PNG_MAX_PALETTE_LENGTH instead
+ * of num_palette entries, in case of an invalid PNG file that has
+ * too-large sample values.
+ */
+ png_ptr->palette = png_voidcast(png_colorp, png_calloc(png_ptr,
+ PNG_MAX_PALETTE_LENGTH * (sizeof (png_color))));
+
+ if (num_palette > 0)
+ memcpy(png_ptr->palette, palette, num_palette * (sizeof (png_color)));
+ info_ptr->palette = png_ptr->palette;
+ info_ptr->num_palette = png_ptr->num_palette = (png_uint_16)num_palette;
+
+ info_ptr->free_me |= PNG_FREE_PLTE;
+
+ info_ptr->valid |= PNG_INFO_PLTE;
+}
+
+#ifdef PNG_sBIT_SUPPORTED
+void PNGAPI
+png_set_sBIT(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_color_8p sig_bit)
+{
+ png_debug1(1, "in %s storage function", "sBIT");
+
+ if (png_ptr == NULL || info_ptr == NULL || sig_bit == NULL)
+ return;
+
+ info_ptr->sig_bit = *sig_bit;
+ info_ptr->valid |= PNG_INFO_sBIT;
+}
+#endif
+
+#ifdef PNG_sRGB_SUPPORTED
+void PNGAPI
+png_set_sRGB(png_const_structrp png_ptr, png_inforp info_ptr, int srgb_intent)
+{
+ png_debug1(1, "in %s storage function", "sRGB");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ (void)png_colorspace_set_sRGB(png_ptr, &info_ptr->colorspace, srgb_intent);
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+
+void PNGAPI
+png_set_sRGB_gAMA_and_cHRM(png_const_structrp png_ptr, png_inforp info_ptr,
+ int srgb_intent)
+{
+ png_debug1(1, "in %s storage function", "sRGB_gAMA_and_cHRM");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (png_colorspace_set_sRGB(png_ptr, &info_ptr->colorspace, srgb_intent))
+ {
+ /* This causes the gAMA and cHRM to be written too */
+ info_ptr->colorspace.flags |=
+ PNG_COLORSPACE_FROM_gAMA|PNG_COLORSPACE_FROM_cHRM;
+ }
+
+ png_colorspace_sync_info(png_ptr, info_ptr);
+}
+#endif /* sRGB */
+
+
+#ifdef PNG_iCCP_SUPPORTED
+void PNGAPI
+png_set_iCCP(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_charp name, int compression_type,
+ png_const_bytep profile, png_uint_32 proflen)
+{
+ png_charp new_iccp_name;
+ png_bytep new_iccp_profile;
+ png_size_t length;
+
+ png_debug1(1, "in %s storage function", "iCCP");
+
+ if (png_ptr == NULL || info_ptr == NULL || name == NULL || profile == NULL)
+ return;
+
+ if (compression_type != PNG_COMPRESSION_TYPE_BASE)
+ png_app_error(png_ptr, "Invalid iCCP compression method");
+
+ /* Set the colorspace first because this validates the profile; do not
+ * override previously set app cHRM or gAMA here (because likely as not the
+ * application knows better than libpng what the correct values are.) Pass
+ * the info_ptr color_type field to png_colorspace_set_ICC because in the
+ * write case it has not yet been stored in png_ptr.
+ */
+ {
+ int result = png_colorspace_set_ICC(png_ptr, &info_ptr->colorspace, name,
+ proflen, profile, info_ptr->color_type);
+
+ png_colorspace_sync_info(png_ptr, info_ptr);
+
+ /* Don't do any of the copying if the profile was bad, or inconsistent. */
+ if (!result)
+ return;
+
+ /* But do write the gAMA and cHRM chunks from the profile. */
+ info_ptr->colorspace.flags |=
+ PNG_COLORSPACE_FROM_gAMA|PNG_COLORSPACE_FROM_cHRM;
+ }
+
+ length = strlen(name)+1;
+ new_iccp_name = png_voidcast(png_charp, png_malloc_warn(png_ptr, length));
+
+ if (new_iccp_name == NULL)
+ {
+ png_benign_error(png_ptr, "Insufficient memory to process iCCP chunk");
+ return;
+ }
+
+ memcpy(new_iccp_name, name, length);
+ new_iccp_profile = png_voidcast(png_bytep,
+ png_malloc_warn(png_ptr, proflen));
+
+ if (new_iccp_profile == NULL)
+ {
+ png_free(png_ptr, new_iccp_name);
+ png_benign_error(png_ptr,
+ "Insufficient memory to process iCCP profile");
+ return;
+ }
+
+ memcpy(new_iccp_profile, profile, proflen);
+
+ png_free_data(png_ptr, info_ptr, PNG_FREE_ICCP, 0);
+
+ info_ptr->iccp_proflen = proflen;
+ info_ptr->iccp_name = new_iccp_name;
+ info_ptr->iccp_profile = new_iccp_profile;
+ info_ptr->free_me |= PNG_FREE_ICCP;
+ info_ptr->valid |= PNG_INFO_iCCP;
+}
+#endif
+
+#ifdef PNG_TEXT_SUPPORTED
+void PNGAPI
+png_set_text(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_textp text_ptr, int num_text)
+{
+ int ret;
+ ret = png_set_text_2(png_ptr, info_ptr, text_ptr, num_text);
+
+ if (ret)
+ png_error(png_ptr, "Insufficient memory to store text");
+}
+
+int /* PRIVATE */
+png_set_text_2(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_textp text_ptr, int num_text)
+{
+ int i;
+
+ png_debug1(1, "in %lx storage function", png_ptr == NULL ? "unexpected" :
+ (unsigned long)png_ptr->chunk_name);
+
+ if (png_ptr == NULL || info_ptr == NULL || num_text <= 0 || text_ptr == NULL)
+ return(0);
+
+ /* Make sure we have enough space in the "text" array in info_struct
+ * to hold all of the incoming text_ptr objects. This compare can't overflow
+ * because max_text >= num_text (anyway, subtract of two positive integers
+ * can't overflow in any case.)
+ */
+ if (num_text > info_ptr->max_text - info_ptr->num_text)
+ {
+ int old_num_text = info_ptr->num_text;
+ int max_text;
+ png_textp new_text = NULL;
+
+ /* Calculate an appropriate max_text, checking for overflow. */
+ max_text = old_num_text;
+ if (num_text <= INT_MAX - max_text)
+ {
+ max_text += num_text;
+
+ /* Round up to a multiple of 8 */
+ if (max_text < INT_MAX-8)
+ max_text = (max_text + 8) & ~0x7;
+
+ else
+ max_text = INT_MAX;
+
+ /* Now allocate a new array and copy the old members in, this does all
+ * the overflow checks.
+ */
+ new_text = png_voidcast(png_textp,png_realloc_array(png_ptr,
+ info_ptr->text, old_num_text, max_text-old_num_text,
+ sizeof *new_text));
+ }
+
+ if (new_text == NULL)
+ {
+ png_chunk_report(png_ptr, "too many text chunks",
+ PNG_CHUNK_WRITE_ERROR);
+ return 1;
+ }
+
+ png_free(png_ptr, info_ptr->text);
+
+ info_ptr->text = new_text;
+ info_ptr->free_me |= PNG_FREE_TEXT;
+ info_ptr->max_text = max_text;
+ /* num_text is adjusted below as the entries are copied in */
+
+ png_debug1(3, "allocated %d entries for info_ptr->text", max_text);
+ }
+
+ for (i = 0; i < num_text; i++)
+ {
+ size_t text_length, key_len;
+ size_t lang_len, lang_key_len;
+ png_textp textp = &(info_ptr->text[info_ptr->num_text]);
+
+ if (text_ptr[i].key == NULL)
+ continue;
+
+ if (text_ptr[i].compression < PNG_TEXT_COMPRESSION_NONE ||
+ text_ptr[i].compression >= PNG_TEXT_COMPRESSION_LAST)
+ {
+ png_chunk_report(png_ptr, "text compression mode is out of range",
+ PNG_CHUNK_WRITE_ERROR);
+ continue;
+ }
+
+ key_len = strlen(text_ptr[i].key);
+
+ if (text_ptr[i].compression <= 0)
+ {
+ lang_len = 0;
+ lang_key_len = 0;
+ }
+
+ else
+# ifdef PNG_iTXt_SUPPORTED
+ {
+ /* Set iTXt data */
+
+ if (text_ptr[i].lang != NULL)
+ lang_len = strlen(text_ptr[i].lang);
+
+ else
+ lang_len = 0;
+
+ if (text_ptr[i].lang_key != NULL)
+ lang_key_len = strlen(text_ptr[i].lang_key);
+
+ else
+ lang_key_len = 0;
+ }
+# else /* PNG_iTXt_SUPPORTED */
+ {
+ png_chunk_report(png_ptr, "iTXt chunk not supported",
+ PNG_CHUNK_WRITE_ERROR);
+ continue;
+ }
+# endif
+
+ if (text_ptr[i].text == NULL || text_ptr[i].text[0] == '\0')
+ {
+ text_length = 0;
+# ifdef PNG_iTXt_SUPPORTED
+ if (text_ptr[i].compression > 0)
+ textp->compression = PNG_ITXT_COMPRESSION_NONE;
+
+ else
+# endif
+ textp->compression = PNG_TEXT_COMPRESSION_NONE;
+ }
+
+ else
+ {
+ text_length = strlen(text_ptr[i].text);
+ textp->compression = text_ptr[i].compression;
+ }
+
+ textp->key = png_voidcast(png_charp,png_malloc_base(png_ptr,
+ key_len + text_length + lang_len + lang_key_len + 4));
+
+ if (textp->key == NULL)
+ {
+ png_chunk_report(png_ptr, "text chunk: out of memory",
+ PNG_CHUNK_WRITE_ERROR);
+ return 1;
+ }
+
+ png_debug2(2, "Allocated %lu bytes at %p in png_set_text",
+ (unsigned long)(png_uint_32)
+ (key_len + lang_len + lang_key_len + text_length + 4),
+ textp->key);
+
+ memcpy(textp->key, text_ptr[i].key, key_len);
+ *(textp->key + key_len) = '\0';
+
+ if (text_ptr[i].compression > 0)
+ {
+ textp->lang = textp->key + key_len + 1;
+ memcpy(textp->lang, text_ptr[i].lang, lang_len);
+ *(textp->lang + lang_len) = '\0';
+ textp->lang_key = textp->lang + lang_len + 1;
+ memcpy(textp->lang_key, text_ptr[i].lang_key, lang_key_len);
+ *(textp->lang_key + lang_key_len) = '\0';
+ textp->text = textp->lang_key + lang_key_len + 1;
+ }
+
+ else
+ {
+ textp->lang=NULL;
+ textp->lang_key=NULL;
+ textp->text = textp->key + key_len + 1;
+ }
+
+ if (text_length)
+ memcpy(textp->text, text_ptr[i].text, text_length);
+
+ *(textp->text + text_length) = '\0';
+
+# ifdef PNG_iTXt_SUPPORTED
+ if (textp->compression > 0)
+ {
+ textp->text_length = 0;
+ textp->itxt_length = text_length;
+ }
+
+ else
+# endif
+ {
+ textp->text_length = text_length;
+ textp->itxt_length = 0;
+ }
+
+ info_ptr->num_text++;
+ png_debug1(3, "transferred text chunk %d", info_ptr->num_text);
+ }
+
+ return(0);
+}
+#endif
+
+#ifdef PNG_tIME_SUPPORTED
+void PNGAPI
+png_set_tIME(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_const_timep mod_time)
+{
+ png_debug1(1, "in %s storage function", "tIME");
+
+ if (png_ptr == NULL || info_ptr == NULL || mod_time == NULL ||
+ (png_ptr->mode & PNG_WROTE_tIME))
+ return;
+
+ if (mod_time->month == 0 || mod_time->month > 12 ||
+ mod_time->day == 0 || mod_time->day > 31 ||
+ mod_time->hour > 23 || mod_time->minute > 59 ||
+ mod_time->second > 60)
+ {
+ png_warning(png_ptr, "Ignoring invalid time value");
+ return;
+ }
+
+ info_ptr->mod_time = *mod_time;
+ info_ptr->valid |= PNG_INFO_tIME;
+}
+#endif
+
+#ifdef PNG_tRNS_SUPPORTED
+void PNGAPI
+png_set_tRNS(png_structrp png_ptr, png_inforp info_ptr,
+ png_const_bytep trans_alpha, int num_trans, png_const_color_16p trans_color)
+{
+ png_debug1(1, "in %s storage function", "tRNS");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (trans_alpha != NULL)
+ {
+ /* It may not actually be necessary to set png_ptr->trans_alpha here;
+ * we do it for backward compatibility with the way the png_handle_tRNS
+ * function used to do the allocation.
+ *
+ * 1.6.0: The above statement is incorrect; png_handle_tRNS effectively
+ * relies on png_set_tRNS storing the information in png_struct
+ * (otherwise it won't be there for the code in pngrtran.c).
+ */
+
+ png_free_data(png_ptr, info_ptr, PNG_FREE_TRNS, 0);
+
+ /* Changed from num_trans to PNG_MAX_PALETTE_LENGTH in version 1.2.1 */
+ png_ptr->trans_alpha = info_ptr->trans_alpha = png_voidcast(png_bytep,
+ png_malloc(png_ptr, PNG_MAX_PALETTE_LENGTH));
+
+ if (num_trans > 0 && num_trans <= PNG_MAX_PALETTE_LENGTH)
+ memcpy(info_ptr->trans_alpha, trans_alpha, (png_size_t)num_trans);
+ }
+
+ if (trans_color != NULL)
+ {
+ int sample_max = (1 << info_ptr->bit_depth);
+
+ if ((info_ptr->color_type == PNG_COLOR_TYPE_GRAY &&
+ trans_color->gray > sample_max) ||
+ (info_ptr->color_type == PNG_COLOR_TYPE_RGB &&
+ (trans_color->red > sample_max ||
+ trans_color->green > sample_max ||
+ trans_color->blue > sample_max)))
+ png_warning(png_ptr,
+ "tRNS chunk has out-of-range samples for bit_depth");
+
+ info_ptr->trans_color = *trans_color;
+
+ if (num_trans == 0)
+ num_trans = 1;
+ }
+
+ info_ptr->num_trans = (png_uint_16)num_trans;
+
+ if (num_trans != 0)
+ {
+ info_ptr->valid |= PNG_INFO_tRNS;
+ info_ptr->free_me |= PNG_FREE_TRNS;
+ }
+}
+#endif
+
+#ifdef PNG_sPLT_SUPPORTED
+void PNGAPI
+png_set_sPLT(png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_sPLT_tp entries, int nentries)
+/*
+ * entries - array of png_sPLT_t structures
+ * to be added to the list of palettes
+ * in the info structure.
+ *
+ * nentries - number of palette structures to be
+ * added.
+ */
+{
+ png_sPLT_tp np;
+
+ if (png_ptr == NULL || info_ptr == NULL || nentries <= 0 || entries == NULL)
+ return;
+
+ /* Use the internal realloc function, which checks for all the possible
+ * overflows. Notice that the parameters are (int) and (size_t)
+ */
+ np = png_voidcast(png_sPLT_tp,png_realloc_array(png_ptr,
+ info_ptr->splt_palettes, info_ptr->splt_palettes_num, nentries,
+ sizeof *np));
+
+ if (np == NULL)
+ {
+ /* Out of memory or too many chunks */
+ png_chunk_report(png_ptr, "too many sPLT chunks", PNG_CHUNK_WRITE_ERROR);
+ return;
+ }
+
+ png_free(png_ptr, info_ptr->splt_palettes);
+ info_ptr->splt_palettes = np;
+ info_ptr->free_me |= PNG_FREE_SPLT;
+
+ np += info_ptr->splt_palettes_num;
+
+ do
+ {
+ png_size_t length;
+
+ /* Skip invalid input entries */
+ if (entries->name == NULL || entries->entries == NULL)
+ {
+ /* png_handle_sPLT doesn't do this, so this is an app error */
+ png_app_error(png_ptr, "png_set_sPLT: invalid sPLT");
+ /* Just skip the invalid entry */
+ continue;
+ }
+
+ np->depth = entries->depth;
+
+ /* In the even of out-of-memory just return - there's no point keeping on
+ * trying to add sPLT chunks.
+ */
+ length = strlen(entries->name) + 1;
+ np->name = png_voidcast(png_charp, png_malloc_base(png_ptr, length));
+
+ if (np->name == NULL)
+ break;
+
+ memcpy(np->name, entries->name, length);
+
+ /* IMPORTANT: we have memory now that won't get freed if something else
+ * goes wrong, this code must free it. png_malloc_array produces no
+ * warnings, use a png_chunk_report (below) if there is an error.
+ */
+ np->entries = png_voidcast(png_sPLT_entryp, png_malloc_array(png_ptr,
+ entries->nentries, sizeof (png_sPLT_entry)));
+
+ if (np->entries == NULL)
+ {
+ png_free(png_ptr, np->name);
+ break;
+ }
+
+ np->nentries = entries->nentries;
+ /* This multiply can't overflow because png_malloc_array has already
+ * checked it when doing the allocation.
+ */
+ memcpy(np->entries, entries->entries,
+ entries->nentries * sizeof (png_sPLT_entry));
+
+ /* Note that 'continue' skips the advance of the out pointer and out
+ * count, so an invalid entry is not added.
+ */
+ info_ptr->valid |= PNG_INFO_sPLT;
+ ++(info_ptr->splt_palettes_num);
+ ++np;
+ }
+ while (++entries, --nentries);
+
+ if (nentries > 0)
+ png_chunk_report(png_ptr, "sPLT out of memory", PNG_CHUNK_WRITE_ERROR);
+}
+#endif /* PNG_sPLT_SUPPORTED */
+
+#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED
+static png_byte
+check_location(png_const_structrp png_ptr, int location)
+{
+ location &= (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT);
+
+ /* New in 1.6.0; copy the location and check it. This is an API
+ * change, previously the app had to use the
+ * png_set_unknown_chunk_location API below for each chunk.
+ */
+ if (location == 0 && !(png_ptr->mode & PNG_IS_READ_STRUCT))
+ {
+ /* Write struct, so unknown chunks come from the app */
+ png_app_warning(png_ptr,
+ "png_set_unknown_chunks now expects a valid location");
+ /* Use the old behavior */
+ location = (png_byte)(png_ptr->mode &
+ (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT));
+ }
+
+ /* This need not be an internal error - if the app calls
+ * png_set_unknown_chunks on a read pointer it must get the location right.
+ */
+ if (location == 0)
+ png_error(png_ptr, "invalid location in png_set_unknown_chunks");
+
+ /* Now reduce the location to the top-most set bit by removing each least
+ * significant bit in turn.
+ */
+ while (location != (location & -location))
+ location &= ~(location & -location);
+
+ /* The cast is safe because 'location' is a bit mask and only the low four
+ * bits are significant.
+ */
+ return (png_byte)location;
+}
+
+void PNGAPI
+png_set_unknown_chunks(png_const_structrp png_ptr,
+ png_inforp info_ptr, png_const_unknown_chunkp unknowns, int num_unknowns)
+{
+ png_unknown_chunkp np;
+
+ if (png_ptr == NULL || info_ptr == NULL || num_unknowns <= 0 ||
+ unknowns == NULL)
+ return;
+
+ /* Check for the failure cases where support has been disabled at compile
+ * time. This code is hardly ever compiled - it's here because
+ * STORE_UNKNOWN_CHUNKS is set by both read and write code (compiling in this
+ * code) but may be meaningless if the read or write handling of unknown
+ * chunks is not compiled in.
+ */
+# if !defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) && \
+ defined(PNG_READ_SUPPORTED)
+ if (png_ptr->mode & PNG_IS_READ_STRUCT)
+ {
+ png_app_error(png_ptr, "no unknown chunk support on read");
+ return;
+ }
+# endif
+# if !defined(PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED) && \
+ defined(PNG_WRITE_SUPPORTED)
+ if (!(png_ptr->mode & PNG_IS_READ_STRUCT))
+ {
+ png_app_error(png_ptr, "no unknown chunk support on write");
+ return;
+ }
+# endif
+
+ /* Prior to 1.6.0 this code used png_malloc_warn; however, this meant that
+ * unknown critical chunks could be lost with just a warning resulting in
+ * undefined behavior. Now png_chunk_report is used to provide behavior
+ * appropriate to read or write.
+ */
+ np = png_voidcast(png_unknown_chunkp, png_realloc_array(png_ptr,
+ info_ptr->unknown_chunks, info_ptr->unknown_chunks_num, num_unknowns,
+ sizeof *np));
+
+ if (np == NULL)
+ {
+ png_chunk_report(png_ptr, "too many unknown chunks",
+ PNG_CHUNK_WRITE_ERROR);
+ return;
+ }
+
+ png_free(png_ptr, info_ptr->unknown_chunks);
+ info_ptr->unknown_chunks = np; /* safe because it is initialized */
+ info_ptr->free_me |= PNG_FREE_UNKN;
+
+ np += info_ptr->unknown_chunks_num;
+
+ /* Increment unknown_chunks_num each time round the loop to protect the
+ * just-allocated chunk data.
+ */
+ for (; num_unknowns > 0; --num_unknowns, ++unknowns)
+ {
+ memcpy(np->name, unknowns->name, (sizeof np->name));
+ np->name[(sizeof np->name)-1] = '\0';
+ np->location = check_location(png_ptr, unknowns->location);
+
+ if (unknowns->size == 0)
+ {
+ np->data = NULL;
+ np->size = 0;
+ }
+
+ else
+ {
+ np->data = png_voidcast(png_bytep,
+ png_malloc_base(png_ptr, unknowns->size));
+
+ if (np->data == NULL)
+ {
+ png_chunk_report(png_ptr, "unknown chunk: out of memory",
+ PNG_CHUNK_WRITE_ERROR);
+ /* But just skip storing the unknown chunk */
+ continue;
+ }
+
+ memcpy(np->data, unknowns->data, unknowns->size);
+ np->size = unknowns->size;
+ }
+
+ /* These increments are skipped on out-of-memory for the data - the
+ * unknown chunk entry gets overwritten if the png_chunk_report returns.
+ * This is correct in the read case (the chunk is just dropped.)
+ */
+ ++np;
+ ++(info_ptr->unknown_chunks_num);
+ }
+}
+
+void PNGAPI
+png_set_unknown_chunk_location(png_const_structrp png_ptr, png_inforp info_ptr,
+ int chunk, int location)
+{
+ /* This API is pretty pointless in 1.6.0 because the location can be set
+ * before the call to png_set_unknown_chunks.
+ *
+ * TODO: add a png_app_warning in 1.7
+ */
+ if (png_ptr != NULL && info_ptr != NULL && chunk >= 0 &&
+ chunk < info_ptr->unknown_chunks_num)
+ {
+ if ((location & (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT)) == 0)
+ {
+ png_app_error(png_ptr, "invalid unknown chunk location");
+ /* Fake out the pre 1.6.0 behavior: */
+ if ((location & PNG_HAVE_IDAT)) /* undocumented! */
+ location = PNG_AFTER_IDAT;
+
+ else
+ location = PNG_HAVE_IHDR; /* also undocumented */
+ }
+
+ info_ptr->unknown_chunks[chunk].location =
+ check_location(png_ptr, location);
+ }
+}
+#endif
+
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+png_uint_32 PNGAPI
+png_permit_mng_features (png_structrp png_ptr, png_uint_32 mng_features)
+{
+ png_debug(1, "in png_permit_mng_features");
+
+ if (png_ptr == NULL)
+ return 0;
+
+ png_ptr->mng_features_permitted = mng_features & PNG_ALL_MNG_FEATURES;
+
+ return png_ptr->mng_features_permitted;
+}
+#endif
+
+#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED
+static unsigned int
+add_one_chunk(png_bytep list, unsigned int count, png_const_bytep add, int keep)
+{
+ unsigned int i;
+
+ /* Utility function: update the 'keep' state of a chunk if it is already in
+ * the list, otherwise add it to the list.
+ */
+ for (i=0; i<count; ++i, list += 5) if (memcmp(list, add, 4) == 0)
+ {
+ list[4] = (png_byte)keep;
+ return count;
+ }
+
+ if (keep != PNG_HANDLE_CHUNK_AS_DEFAULT)
+ {
+ ++count;
+ memcpy(list, add, 4);
+ list[4] = (png_byte)keep;
+ }
+
+ return count;
+}
+
+void PNGAPI
+png_set_keep_unknown_chunks(png_structrp png_ptr, int keep,
+ png_const_bytep chunk_list, int num_chunks_in)
+{
+ png_bytep new_list;
+ unsigned int num_chunks, old_num_chunks;
+
+ if (png_ptr == NULL)
+ return;
+
+ if (keep < 0 || keep >= PNG_HANDLE_CHUNK_LAST)
+ {
+ png_app_error(png_ptr, "png_set_keep_unknown_chunks: invalid keep");
+ return;
+ }
+
+ if (num_chunks_in <= 0)
+ {
+ png_ptr->unknown_default = keep;
+
+ /* '0' means just set the flags, so stop here */
+ if (num_chunks_in == 0)
+ return;
+ }
+
+ if (num_chunks_in < 0)
+ {
+ /* Ignore all unknown chunks and all chunks recognized by
+ * libpng except for IHDR, PLTE, tRNS, IDAT, and IEND
+ */
+ static PNG_CONST png_byte chunks_to_ignore[] = {
+ 98, 75, 71, 68, '\0', /* bKGD */
+ 99, 72, 82, 77, '\0', /* cHRM */
+ 103, 65, 77, 65, '\0', /* gAMA */
+ 104, 73, 83, 84, '\0', /* hIST */
+ 105, 67, 67, 80, '\0', /* iCCP */
+ 105, 84, 88, 116, '\0', /* iTXt */
+ 111, 70, 70, 115, '\0', /* oFFs */
+ 112, 67, 65, 76, '\0', /* pCAL */
+ 112, 72, 89, 115, '\0', /* pHYs */
+ 115, 66, 73, 84, '\0', /* sBIT */
+ 115, 67, 65, 76, '\0', /* sCAL */
+ 115, 80, 76, 84, '\0', /* sPLT */
+ 115, 84, 69, 82, '\0', /* sTER */
+ 115, 82, 71, 66, '\0', /* sRGB */
+ 116, 69, 88, 116, '\0', /* tEXt */
+ 116, 73, 77, 69, '\0', /* tIME */
+ 122, 84, 88, 116, '\0' /* zTXt */
+ };
+
+ chunk_list = chunks_to_ignore;
+ num_chunks = (sizeof chunks_to_ignore)/5;
+ }
+
+ else /* num_chunks_in > 0 */
+ {
+ if (chunk_list == NULL)
+ {
+ /* Prior to 1.6.0 this was silently ignored, now it is an app_error
+ * which can be switched off.
+ */
+ png_app_error(png_ptr, "png_set_keep_unknown_chunks: no chunk list");
+ return;
+ }
+
+ num_chunks = num_chunks_in;
+ }
+
+ old_num_chunks = png_ptr->num_chunk_list;
+ if (png_ptr->chunk_list == NULL)
+ old_num_chunks = 0;
+
+ /* Since num_chunks is always restricted to UINT_MAX/5 this can't overflow.
+ */
+ if (num_chunks + old_num_chunks > UINT_MAX/5)
+ {
+ png_app_error(png_ptr, "png_set_keep_unknown_chunks: too many chunks");
+ return;
+ }
+
+ /* If these chunks are being reset to the default then no more memory is
+ * required because add_one_chunk above doesn't extend the list if the 'keep'
+ * parameter is the default.
+ */
+ if (keep)
+ {
+ new_list = png_voidcast(png_bytep, png_malloc(png_ptr,
+ 5 * (num_chunks + old_num_chunks)));
+
+ if (old_num_chunks > 0)
+ memcpy(new_list, png_ptr->chunk_list, 5*old_num_chunks);
+ }
+
+ else if (old_num_chunks > 0)
+ new_list = png_ptr->chunk_list;
+
+ else
+ new_list = NULL;
+
+ /* Add the new chunks together with each one's handling code. If the chunk
+ * already exists the code is updated, otherwise the chunk is added to the
+ * end. (In libpng 1.6.0 order no longer matters because this code enforces
+ * the earlier convention that the last setting is the one that is used.)
+ */
+ if (new_list != NULL)
+ {
+ png_const_bytep inlist;
+ png_bytep outlist;
+ unsigned int i;
+
+ for (i=0; i<num_chunks; ++i)
+ old_num_chunks = add_one_chunk(new_list, old_num_chunks,
+ chunk_list+5*i, keep);
+
+ /* Now remove any spurious 'default' entries. */
+ num_chunks = 0;
+ for (i=0, inlist=outlist=new_list; i<old_num_chunks; ++i, inlist += 5)
+ if (inlist[4])
+ {
+ if (outlist != inlist)
+ memcpy(outlist, inlist, 5);
+ outlist += 5;
+ ++num_chunks;
+ }
+
+ /* This means the application has removed all the specialized handling. */
+ if (num_chunks == 0)
+ {
+ if (png_ptr->chunk_list != new_list)
+ png_free(png_ptr, new_list);
+
+ new_list = NULL;
+ }
+ }
+
+ else
+ num_chunks = 0;
+
+ png_ptr->num_chunk_list = num_chunks;
+
+ if (png_ptr->chunk_list != new_list)
+ {
+ if (png_ptr->chunk_list != NULL)
+ png_free(png_ptr, png_ptr->chunk_list);
+
+ png_ptr->chunk_list = new_list;
+ }
+}
+#endif
+
+#ifdef PNG_READ_USER_CHUNKS_SUPPORTED
+void PNGAPI
+png_set_read_user_chunk_fn(png_structrp png_ptr, png_voidp user_chunk_ptr,
+ png_user_chunk_ptr read_user_chunk_fn)
+{
+ png_debug(1, "in png_set_read_user_chunk_fn");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->read_user_chunk_fn = read_user_chunk_fn;
+ png_ptr->user_chunk_ptr = user_chunk_ptr;
+}
+#endif
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+void PNGAPI
+png_set_rows(png_const_structrp png_ptr, png_inforp info_ptr,
+ png_bytepp row_pointers)
+{
+ png_debug1(1, "in %s storage function", "rows");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (info_ptr->row_pointers && (info_ptr->row_pointers != row_pointers))
+ png_free_data(png_ptr, info_ptr, PNG_FREE_ROWS, 0);
+
+ info_ptr->row_pointers = row_pointers;
+
+ if (row_pointers)
+ info_ptr->valid |= PNG_INFO_IDAT;
+}
+#endif
+
+void PNGAPI
+png_set_compression_buffer_size(png_structrp png_ptr, png_size_t size)
+{
+ if (png_ptr == NULL)
+ return;
+
+ if (size == 0 || size > PNG_UINT_31_MAX)
+ png_error(png_ptr, "invalid compression buffer size");
+
+# ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+ if (png_ptr->mode & PNG_IS_READ_STRUCT)
+ {
+ png_ptr->IDAT_read_size = (png_uint_32)size; /* checked above */
+ return;
+ }
+# endif
+
+# ifdef PNG_WRITE_SUPPORTED
+ if (!(png_ptr->mode & PNG_IS_READ_STRUCT))
+ {
+ if (png_ptr->zowner != 0)
+ {
+ png_warning(png_ptr,
+ "Compression buffer size cannot be changed because it is in use");
+ return;
+ }
+
+ if (size > ZLIB_IO_MAX)
+ {
+ png_warning(png_ptr,
+ "Compression buffer size limited to system maximum");
+ size = ZLIB_IO_MAX; /* must fit */
+ }
+
+ else if (size < 6)
+ {
+ /* Deflate will potentially go into an infinite loop on a SYNC_FLUSH
+ * if this is permitted.
+ */
+ png_warning(png_ptr,
+ "Compression buffer size cannot be reduced below 6");
+ return;
+ }
+
+ if (png_ptr->zbuffer_size != size)
+ {
+ png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list);
+ png_ptr->zbuffer_size = (uInt)size;
+ }
+ }
+# endif
+}
+
+void PNGAPI
+png_set_invalid(png_const_structrp png_ptr, png_inforp info_ptr, int mask)
+{
+ if (png_ptr && info_ptr)
+ info_ptr->valid &= ~mask;
+}
+
+
+#ifdef PNG_SET_USER_LIMITS_SUPPORTED
+/* This function was added to libpng 1.2.6 */
+void PNGAPI
+png_set_user_limits (png_structrp png_ptr, png_uint_32 user_width_max,
+ png_uint_32 user_height_max)
+{
+ /* Images with dimensions larger than these limits will be
+ * rejected by png_set_IHDR(). To accept any PNG datastream
+ * regardless of dimensions, set both limits to 0x7ffffffL.
+ */
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->user_width_max = user_width_max;
+ png_ptr->user_height_max = user_height_max;
+}
+
+/* This function was added to libpng 1.4.0 */
+void PNGAPI
+png_set_chunk_cache_max (png_structrp png_ptr, png_uint_32 user_chunk_cache_max)
+{
+ if (png_ptr)
+ png_ptr->user_chunk_cache_max = user_chunk_cache_max;
+}
+
+/* This function was added to libpng 1.4.1 */
+void PNGAPI
+png_set_chunk_malloc_max (png_structrp png_ptr,
+ png_alloc_size_t user_chunk_malloc_max)
+{
+ if (png_ptr)
+ png_ptr->user_chunk_malloc_max = user_chunk_malloc_max;
+}
+#endif /* ?PNG_SET_USER_LIMITS_SUPPORTED */
+
+
+#ifdef PNG_BENIGN_ERRORS_SUPPORTED
+void PNGAPI
+png_set_benign_errors(png_structrp png_ptr, int allowed)
+{
+ png_debug(1, "in png_set_benign_errors");
+
+ /* If allowed is 1, png_benign_error() is treated as a warning.
+ *
+ * If allowed is 0, png_benign_error() is treated as an error (which
+ * is the default behavior if png_set_benign_errors() is not called).
+ */
+
+ if (allowed)
+ png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN |
+ PNG_FLAG_APP_WARNINGS_WARN | PNG_FLAG_APP_ERRORS_WARN;
+
+ else
+ png_ptr->flags &= ~(PNG_FLAG_BENIGN_ERRORS_WARN |
+ PNG_FLAG_APP_WARNINGS_WARN | PNG_FLAG_APP_ERRORS_WARN);
+}
+#endif /* PNG_BENIGN_ERRORS_SUPPORTED */
+
+#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ /* Whether to report invalid palette index; added at libng-1.5.10.
+ * It is possible for an indexed (color-type==3) PNG file to contain
+ * pixels with invalid (out-of-range) indexes if the PLTE chunk has
+ * fewer entries than the image's bit-depth would allow. We recover
+ * from this gracefully by filling any incomplete palette with zeroes
+ * (opaque black). By default, when this occurs libpng will issue
+ * a benign error. This API can be used to override that behavior.
+ */
+void PNGAPI
+png_set_check_for_invalid_index(png_structrp png_ptr, int allowed)
+{
+ png_debug(1, "in png_set_check_for_invalid_index");
+
+ if (allowed > 0)
+ png_ptr->num_palette_max = 0;
+
+ else
+ png_ptr->num_palette_max = -1;
+}
+#endif
+#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngstruct.h b/ml/dlib/dlib/external/libpng/pngstruct.h
new file mode 100644
index 000000000..d58c02884
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngstruct.h
@@ -0,0 +1,489 @@
+
+/* pngstruct.h - header file for PNG reference library
+ *
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * Last changed in libpng 1.6.1 [March 28, 2013]
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+/* The structure that holds the information to read and write PNG files.
+ * The only people who need to care about what is inside of this are the
+ * people who will be modifying the library for their own special needs.
+ * It should NOT be accessed directly by an application.
+ */
+
+#ifndef PNGSTRUCT_H
+#define PNGSTRUCT_H
+/* zlib.h defines the structure z_stream, an instance of which is included
+ * in this structure and is required for decompressing the LZ compressed
+ * data in PNG files.
+ */
+#ifndef ZLIB_CONST
+ /* We must ensure that zlib uses 'const' in declarations. */
+# define ZLIB_CONST
+#endif
+#include "zlib.h"
+#ifdef const
+ /* zlib.h sometimes #defines const to nothing, undo this. */
+# undef const
+#endif
+
+/* zlib.h has mediocre z_const use before 1.2.6, this stuff is for compatibility
+ * with older builds.
+ */
+#if ZLIB_VERNUM < 0x1260
+# define PNGZ_MSG_CAST(s) png_constcast(char*,s)
+# define PNGZ_INPUT_CAST(b) png_constcast(png_bytep,b)
+#else
+# define PNGZ_MSG_CAST(s) (s)
+# define PNGZ_INPUT_CAST(b) (b)
+#endif
+
+/* zlib.h declares a magic type 'uInt' that limits the amount of data that zlib
+ * can handle at once. This type need be no larger than 16 bits (so maximum of
+ * 65535), this define allows us to discover how big it is, but limited by the
+ * maximuum for png_size_t. The value can be overriden in a library build
+ * (pngusr.h, or set it in CPPFLAGS) and it works to set it to a considerably
+ * lower value (e.g. 255 works). A lower value may help memory usage (slightly)
+ * and may even improve performance on some systems (and degrade it on others.)
+ */
+#ifndef ZLIB_IO_MAX
+# define ZLIB_IO_MAX ((uInt)-1)
+#endif
+
+#ifdef PNG_WRITE_SUPPORTED
+/* The type of a compression buffer list used by the write code. */
+typedef struct png_compression_buffer
+{
+ struct png_compression_buffer *next;
+ png_byte output[1]; /* actually zbuf_size */
+} png_compression_buffer, *png_compression_bufferp;
+
+#define PNG_COMPRESSION_BUFFER_SIZE(pp)\
+ (offsetof(png_compression_buffer, output) + (pp)->zbuffer_size)
+#endif
+
+/* Colorspace support; structures used in png_struct, png_info and in internal
+ * functions to hold and communicate information about the color space.
+ *
+ * PNG_COLORSPACE_SUPPORTED is only required if the application will perform
+ * colorspace corrections, otherwise all the colorspace information can be
+ * skipped and the size of libpng can be reduced (significantly) by compiling
+ * out the colorspace support.
+ */
+#ifdef PNG_COLORSPACE_SUPPORTED
+/* The chromaticities of the red, green and blue colorants and the chromaticity
+ * of the corresponding white point (i.e. of rgb(1.0,1.0,1.0)).
+ */
+typedef struct png_xy
+{
+ png_fixed_point redx, redy;
+ png_fixed_point greenx, greeny;
+ png_fixed_point bluex, bluey;
+ png_fixed_point whitex, whitey;
+} png_xy;
+
+/* The same data as above but encoded as CIE XYZ values. When this data comes
+ * from chromaticities the sum of the Y values is assumed to be 1.0
+ */
+typedef struct png_XYZ
+{
+ png_fixed_point red_X, red_Y, red_Z;
+ png_fixed_point green_X, green_Y, green_Z;
+ png_fixed_point blue_X, blue_Y, blue_Z;
+} png_XYZ;
+#endif /* COLORSPACE */
+
+#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED)
+/* A colorspace is all the above plus, potentially, profile information,
+ * however at present libpng does not use the profile internally so it is only
+ * stored in the png_info struct (if iCCP is supported.) The rendering intent
+ * is retained here and is checked.
+ *
+ * The file gamma encoding information is also stored here and gamma correction
+ * is done by libpng, whereas color correction must currently be done by the
+ * application.
+ */
+typedef struct png_colorspace
+{
+#ifdef PNG_GAMMA_SUPPORTED
+ png_fixed_point gamma; /* File gamma */
+#endif
+
+#ifdef PNG_COLORSPACE_SUPPORTED
+ png_xy end_points_xy; /* End points as chromaticities */
+ png_XYZ end_points_XYZ; /* End points as CIE XYZ colorant values */
+ png_uint_16 rendering_intent; /* Rendering intent of a profile */
+#endif
+
+ /* Flags are always defined to simplify the code. */
+ png_uint_16 flags; /* As defined below */
+} png_colorspace, * PNG_RESTRICT png_colorspacerp;
+
+typedef const png_colorspace * PNG_RESTRICT png_const_colorspacerp;
+
+/* General flags for the 'flags' field */
+#define PNG_COLORSPACE_HAVE_GAMMA 0x0001
+#define PNG_COLORSPACE_HAVE_ENDPOINTS 0x0002
+#define PNG_COLORSPACE_HAVE_INTENT 0x0004
+#define PNG_COLORSPACE_FROM_gAMA 0x0008
+#define PNG_COLORSPACE_FROM_cHRM 0x0010
+#define PNG_COLORSPACE_FROM_sRGB 0x0020
+#define PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB 0x0040
+#define PNG_COLORSPACE_MATCHES_sRGB 0x0080 /* exact match on profile */
+#define PNG_COLORSPACE_INVALID 0x8000
+#define PNG_COLORSPACE_CANCEL(flags) (0xffff ^ (flags))
+#endif /* COLORSPACE || GAMMA */
+
+struct png_struct_def
+{
+#ifdef PNG_SETJMP_SUPPORTED
+ jmp_buf jmp_buf_local; /* New name in 1.6.0 for jmp_buf in png_struct */
+ png_longjmp_ptr longjmp_fn;/* setjmp non-local goto function. */
+ jmp_buf *jmp_buf_ptr; /* passed to longjmp_fn */
+ size_t jmp_buf_size; /* size of the above, if allocated */
+#endif
+ png_error_ptr error_fn; /* function for printing errors and aborting */
+#ifdef PNG_WARNINGS_SUPPORTED
+ png_error_ptr warning_fn; /* function for printing warnings */
+#endif
+ png_voidp error_ptr; /* user supplied struct for error functions */
+ png_rw_ptr write_data_fn; /* function for writing output data */
+ png_rw_ptr read_data_fn; /* function for reading input data */
+ png_voidp io_ptr; /* ptr to application struct for I/O functions */
+
+#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED
+ png_user_transform_ptr read_user_transform_fn; /* user read transform */
+#endif
+
+#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED
+ png_user_transform_ptr write_user_transform_fn; /* user write transform */
+#endif
+
+/* These were added in libpng-1.0.2 */
+#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED
+#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \
+ defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED)
+ png_voidp user_transform_ptr; /* user supplied struct for user transform */
+ png_byte user_transform_depth; /* bit depth of user transformed pixels */
+ png_byte user_transform_channels; /* channels in user transformed pixels */
+#endif
+#endif
+
+ png_uint_32 mode; /* tells us where we are in the PNG file */
+ png_uint_32 flags; /* flags indicating various things to libpng */
+ png_uint_32 transformations; /* which transformations to perform */
+
+ png_uint_32 zowner; /* ID (chunk type) of zstream owner, 0 if none */
+ z_stream zstream; /* decompression structure */
+
+#ifdef PNG_WRITE_SUPPORTED
+ png_compression_bufferp zbuffer_list; /* Created on demand during write */
+ uInt zbuffer_size; /* size of the actual buffer */
+
+ int zlib_level; /* holds zlib compression level */
+ int zlib_method; /* holds zlib compression method */
+ int zlib_window_bits; /* holds zlib compression window bits */
+ int zlib_mem_level; /* holds zlib compression memory level */
+ int zlib_strategy; /* holds zlib compression strategy */
+#endif
+/* Added at libpng 1.5.4 */
+#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED
+ int zlib_text_level; /* holds zlib compression level */
+ int zlib_text_method; /* holds zlib compression method */
+ int zlib_text_window_bits; /* holds zlib compression window bits */
+ int zlib_text_mem_level; /* holds zlib compression memory level */
+ int zlib_text_strategy; /* holds zlib compression strategy */
+#endif
+/* End of material added at libpng 1.5.4 */
+/* Added at libpng 1.6.0 */
+#ifdef PNG_WRITE_SUPPORTED
+ int zlib_set_level; /* Actual values set into the zstream on write */
+ int zlib_set_method;
+ int zlib_set_window_bits;
+ int zlib_set_mem_level;
+ int zlib_set_strategy;
+#endif
+
+ png_uint_32 width; /* width of image in pixels */
+ png_uint_32 height; /* height of image in pixels */
+ png_uint_32 num_rows; /* number of rows in current pass */
+ png_uint_32 usr_width; /* width of row at start of write */
+ png_size_t rowbytes; /* size of row in bytes */
+ png_uint_32 iwidth; /* width of current interlaced row in pixels */
+ png_uint_32 row_number; /* current row in interlace pass */
+ png_uint_32 chunk_name; /* PNG_CHUNK() id of current chunk */
+ png_bytep prev_row; /* buffer to save previous (unfiltered) row.
+ * This is a pointer into big_prev_row
+ */
+ png_bytep row_buf; /* buffer to save current (unfiltered) row.
+ * This is a pointer into big_row_buf
+ */
+#ifdef PNG_WRITE_SUPPORTED
+ png_bytep sub_row; /* buffer to save "sub" row when filtering */
+ png_bytep up_row; /* buffer to save "up" row when filtering */
+ png_bytep avg_row; /* buffer to save "avg" row when filtering */
+ png_bytep paeth_row; /* buffer to save "Paeth" row when filtering */
+#endif
+ png_size_t info_rowbytes; /* Added in 1.5.4: cache of updated row bytes */
+
+ png_uint_32 idat_size; /* current IDAT size for read */
+ png_uint_32 crc; /* current chunk CRC value */
+ png_colorp palette; /* palette from the input file */
+ png_uint_16 num_palette; /* number of color entries in palette */
+
+/* Added at libpng-1.5.10 */
+#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ int num_palette_max; /* maximum palette index found in IDAT */
+#endif
+
+ png_uint_16 num_trans; /* number of transparency values */
+ png_byte compression; /* file compression type (always 0) */
+ png_byte filter; /* file filter type (always 0) */
+ png_byte interlaced; /* PNG_INTERLACE_NONE, PNG_INTERLACE_ADAM7 */
+ png_byte pass; /* current interlace pass (0 - 6) */
+ png_byte do_filter; /* row filter flags (see PNG_FILTER_ below ) */
+ png_byte color_type; /* color type of file */
+ png_byte bit_depth; /* bit depth of file */
+ png_byte usr_bit_depth; /* bit depth of users row: write only */
+ png_byte pixel_depth; /* number of bits per pixel */
+ png_byte channels; /* number of channels in file */
+#ifdef PNG_WRITE_SUPPORTED
+ png_byte usr_channels; /* channels at start of write: write only */
+#endif
+ png_byte sig_bytes; /* magic bytes read/written from start of file */
+ png_byte maximum_pixel_depth;
+ /* pixel depth used for the row buffers */
+ png_byte transformed_pixel_depth;
+ /* pixel depth after read/write transforms */
+#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED)
+ png_uint_16 filler; /* filler bytes for pixel expansion */
+#endif
+
+#if defined(PNG_bKGD_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) ||\
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED)
+ png_byte background_gamma_type;
+ png_fixed_point background_gamma;
+ png_color_16 background; /* background color in screen gamma space */
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ png_color_16 background_1; /* background normalized to gamma 1.0 */
+#endif
+#endif /* PNG_bKGD_SUPPORTED */
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+ png_flush_ptr output_flush_fn; /* Function for flushing output */
+ png_uint_32 flush_dist; /* how many rows apart to flush, 0 - no flush */
+ png_uint_32 flush_rows; /* number of rows written since last flush */
+#endif
+
+#ifdef PNG_READ_GAMMA_SUPPORTED
+ int gamma_shift; /* number of "insignificant" bits in 16-bit gamma */
+ png_fixed_point screen_gamma; /* screen gamma value (display_exponent) */
+
+ png_bytep gamma_table; /* gamma table for 8-bit depth files */
+ png_uint_16pp gamma_16_table; /* gamma table for 16-bit depth files */
+#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \
+ defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \
+ defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)
+ png_bytep gamma_from_1; /* converts from 1.0 to screen */
+ png_bytep gamma_to_1; /* converts from file to 1.0 */
+ png_uint_16pp gamma_16_from_1; /* converts from 1.0 to screen */
+ png_uint_16pp gamma_16_to_1; /* converts from file to 1.0 */
+#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */
+#endif
+
+#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_sBIT_SUPPORTED)
+ png_color_8 sig_bit; /* significant bits in each available channel */
+#endif
+
+#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED)
+ png_color_8 shift; /* shift for significant bit tranformation */
+#endif
+
+#if defined(PNG_tRNS_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) \
+ || defined(PNG_READ_EXPAND_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED)
+ png_bytep trans_alpha; /* alpha values for paletted files */
+ png_color_16 trans_color; /* transparent color for non-paletted files */
+#endif
+
+ png_read_status_ptr read_row_fn; /* called after each row is decoded */
+ png_write_status_ptr write_row_fn; /* called after each row is encoded */
+#ifdef PNG_PROGRESSIVE_READ_SUPPORTED
+ png_progressive_info_ptr info_fn; /* called after header data fully read */
+ png_progressive_row_ptr row_fn; /* called after a prog. row is decoded */
+ png_progressive_end_ptr end_fn; /* called after image is complete */
+ png_bytep save_buffer_ptr; /* current location in save_buffer */
+ png_bytep save_buffer; /* buffer for previously read data */
+ png_bytep current_buffer_ptr; /* current location in current_buffer */
+ png_bytep current_buffer; /* buffer for recently used data */
+ png_uint_32 push_length; /* size of current input chunk */
+ png_uint_32 skip_length; /* bytes to skip in input data */
+ png_size_t save_buffer_size; /* amount of data now in save_buffer */
+ png_size_t save_buffer_max; /* total size of save_buffer */
+ png_size_t buffer_size; /* total amount of available input data */
+ png_size_t current_buffer_size; /* amount of data now in current_buffer */
+ int process_mode; /* what push library is currently doing */
+ int cur_palette; /* current push library palette index */
+
+#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */
+
+#if defined(__TURBOC__) && !defined(_Windows) && !defined(__FLAT__)
+/* For the Borland special 64K segment handler */
+ png_bytepp offset_table_ptr;
+ png_bytep offset_table;
+ png_uint_16 offset_table_number;
+ png_uint_16 offset_table_count;
+ png_uint_16 offset_table_count_free;
+#endif
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+ png_bytep palette_lookup; /* lookup table for quantizing */
+ png_bytep quantize_index; /* index translation for palette files */
+#endif
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ png_byte heuristic_method; /* heuristic for row filter selection */
+ png_byte num_prev_filters; /* number of weights for previous rows */
+ png_bytep prev_filters; /* filter type(s) of previous row(s) */
+ png_uint_16p filter_weights; /* weight(s) for previous line(s) */
+ png_uint_16p inv_filter_weights; /* 1/weight(s) for previous line(s) */
+ png_uint_16p filter_costs; /* relative filter calculation cost */
+ png_uint_16p inv_filter_costs; /* 1/relative filter calculation cost */
+#endif
+
+ /* Options */
+#ifdef PNG_SET_OPTION_SUPPORTED
+ png_byte options; /* On/off state (up to 4 options) */
+#endif
+
+#if PNG_LIBPNG_VER < 10700
+/* To do: remove this from libpng-1.7 */
+#ifdef PNG_TIME_RFC1123_SUPPORTED
+ char time_buffer[29]; /* String to hold RFC 1123 time text */
+#endif
+#endif
+
+/* New members added in libpng-1.0.6 */
+
+ png_uint_32 free_me; /* flags items libpng is responsible for freeing */
+
+#ifdef PNG_USER_CHUNKS_SUPPORTED
+ png_voidp user_chunk_ptr;
+#ifdef PNG_READ_USER_CHUNKS_SUPPORTED
+ png_user_chunk_ptr read_user_chunk_fn; /* user read chunk handler */
+#endif
+#endif
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ int unknown_default; /* As PNG_HANDLE_* */
+ unsigned int num_chunk_list; /* Number of entries in the list */
+ png_bytep chunk_list; /* List of png_byte[5]; the textual chunk name
+ * followed by a PNG_HANDLE_* byte */
+#endif
+
+/* New members added in libpng-1.0.3 */
+#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED
+ png_byte rgb_to_gray_status;
+ /* Added in libpng 1.5.5 to record setting of coefficients: */
+ png_byte rgb_to_gray_coefficients_set;
+ /* These were changed from png_byte in libpng-1.0.6 */
+ png_uint_16 rgb_to_gray_red_coeff;
+ png_uint_16 rgb_to_gray_green_coeff;
+ /* deleted in 1.5.5: rgb_to_gray_blue_coeff; */
+#endif
+
+/* New member added in libpng-1.0.4 (renamed in 1.0.9) */
+#if defined(PNG_MNG_FEATURES_SUPPORTED)
+/* Changed from png_byte to png_uint_32 at version 1.2.0 */
+ png_uint_32 mng_features_permitted;
+#endif
+
+/* New member added in libpng-1.0.9, ifdef'ed out in 1.0.12, enabled in 1.2.0 */
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ png_byte filter_type;
+#endif
+
+/* New members added in libpng-1.2.0 */
+
+/* New members added in libpng-1.0.2 but first enabled by default in 1.2.0 */
+#ifdef PNG_USER_MEM_SUPPORTED
+ png_voidp mem_ptr; /* user supplied struct for mem functions */
+ png_malloc_ptr malloc_fn; /* function for allocating memory */
+ png_free_ptr free_fn; /* function for freeing memory */
+#endif
+
+/* New member added in libpng-1.0.13 and 1.2.0 */
+ png_bytep big_row_buf; /* buffer to save current (unfiltered) row */
+
+#ifdef PNG_READ_QUANTIZE_SUPPORTED
+/* The following three members were added at version 1.0.14 and 1.2.4 */
+ png_bytep quantize_sort; /* working sort array */
+ png_bytep index_to_palette; /* where the original index currently is
+ in the palette */
+ png_bytep palette_to_index; /* which original index points to this
+ palette color */
+#endif
+
+/* New members added in libpng-1.0.16 and 1.2.6 */
+ png_byte compression_type;
+
+#ifdef PNG_USER_LIMITS_SUPPORTED
+ png_uint_32 user_width_max;
+ png_uint_32 user_height_max;
+
+ /* Added in libpng-1.4.0: Total number of sPLT, text, and unknown
+ * chunks that can be stored (0 means unlimited).
+ */
+ png_uint_32 user_chunk_cache_max;
+
+ /* Total memory that a zTXt, sPLT, iTXt, iCCP, or unknown chunk
+ * can occupy when decompressed. 0 means unlimited.
+ */
+ png_alloc_size_t user_chunk_malloc_max;
+#endif
+
+/* New member added in libpng-1.0.25 and 1.2.17 */
+#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED
+ /* Temporary storage for unknown chunk that the library doesn't recognize,
+ * used while reading the chunk.
+ */
+ png_unknown_chunk unknown_chunk;
+#endif
+
+/* New member added in libpng-1.2.26 */
+ png_size_t old_big_row_buf_size;
+
+#ifdef PNG_READ_SUPPORTED
+/* New member added in libpng-1.2.30 */
+ png_bytep read_buffer; /* buffer for reading chunk data */
+ png_alloc_size_t read_buffer_size; /* current size of the buffer */
+#endif
+#ifdef PNG_SEQUENTIAL_READ_SUPPORTED
+ uInt IDAT_read_size; /* limit on read buffer size for IDAT */
+#endif
+
+#ifdef PNG_IO_STATE_SUPPORTED
+/* New member added in libpng-1.4.0 */
+ png_uint_32 io_state;
+#endif
+
+/* New member added in libpng-1.5.6 */
+ png_bytep big_prev_row;
+
+/* New member added in libpng-1.5.7 */
+ void (*read_filter[PNG_FILTER_VALUE_LAST-1])(png_row_infop row_info,
+ png_bytep row, png_const_bytep prev_row);
+
+#ifdef PNG_READ_SUPPORTED
+#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED)
+ png_colorspace colorspace;
+#endif
+#endif
+};
+#endif /* PNGSTRUCT_H */
diff --git a/ml/dlib/dlib/external/libpng/pngtrans.c b/ml/dlib/dlib/external/libpng/pngtrans.c
new file mode 100644
index 000000000..8f8bc5d9e
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngtrans.c
@@ -0,0 +1,841 @@
+
+/* pngtrans.c - transforms the data in a row (used by both readers and writers)
+ *
+ * Last changed in libpng 1.6.2 [April 25, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+
+#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED)
+
+#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED)
+/* Turn on BGR-to-RGB mapping */
+void PNGAPI
+png_set_bgr(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_bgr");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_BGR;
+}
+#endif
+
+#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED)
+/* Turn on 16 bit byte swapping */
+void PNGAPI
+png_set_swap(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_swap");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (png_ptr->bit_depth == 16)
+ png_ptr->transformations |= PNG_SWAP_BYTES;
+}
+#endif
+
+#if defined(PNG_READ_PACK_SUPPORTED) || defined(PNG_WRITE_PACK_SUPPORTED)
+/* Turn on pixel packing */
+void PNGAPI
+png_set_packing(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_packing");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (png_ptr->bit_depth < 8)
+ {
+ png_ptr->transformations |= PNG_PACK;
+ png_ptr->usr_bit_depth = 8;
+ }
+}
+#endif
+
+#if defined(PNG_READ_PACKSWAP_SUPPORTED)||defined(PNG_WRITE_PACKSWAP_SUPPORTED)
+/* Turn on packed pixel swapping */
+void PNGAPI
+png_set_packswap(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_packswap");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (png_ptr->bit_depth < 8)
+ png_ptr->transformations |= PNG_PACKSWAP;
+}
+#endif
+
+#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED)
+void PNGAPI
+png_set_shift(png_structrp png_ptr, png_const_color_8p true_bits)
+{
+ png_debug(1, "in png_set_shift");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_SHIFT;
+ png_ptr->shift = *true_bits;
+}
+#endif
+
+#if defined(PNG_READ_INTERLACING_SUPPORTED) || \
+ defined(PNG_WRITE_INTERLACING_SUPPORTED)
+int PNGAPI
+png_set_interlace_handling(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_interlace handling");
+
+ if (png_ptr && png_ptr->interlaced)
+ {
+ png_ptr->transformations |= PNG_INTERLACE;
+ return (7);
+ }
+
+ return (1);
+}
+#endif
+
+#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED)
+/* Add a filler byte on read, or remove a filler or alpha byte on write.
+ * The filler type has changed in v0.95 to allow future 2-byte fillers
+ * for 48-bit input data, as well as to avoid problems with some compilers
+ * that don't like bytes as parameters.
+ */
+void PNGAPI
+png_set_filler(png_structrp png_ptr, png_uint_32 filler, int filler_loc)
+{
+ png_debug(1, "in png_set_filler");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* In libpng 1.6 it is possible to determine whether this is a read or write
+ * operation and therefore to do more checking here for a valid call.
+ */
+ if (png_ptr->mode & PNG_IS_READ_STRUCT)
+ {
+# ifdef PNG_READ_FILLER_SUPPORTED
+ /* On read png_set_filler is always valid, regardless of the base PNG
+ * format, because other transformations can give a format where the
+ * filler code can execute (basically an 8 or 16-bit component RGB or G
+ * format.)
+ *
+ * NOTE: usr_channels is not used by the read code! (This has led to
+ * confusion in the past.) The filler is only used in the read code.
+ */
+ png_ptr->filler = (png_uint_16)filler;
+# else
+ png_app_error(png_ptr, "png_set_filler not supported on read");
+ PNG_UNUSED(filler) /* not used in the write case */
+ return;
+# endif
+ }
+
+ else /* write */
+ {
+# ifdef PNG_WRITE_FILLER_SUPPORTED
+ /* On write the usr_channels parameter must be set correctly at the
+ * start to record the number of channels in the app-supplied data.
+ */
+ switch (png_ptr->color_type)
+ {
+ case PNG_COLOR_TYPE_RGB:
+ png_ptr->usr_channels = 4;
+ break;
+
+ case PNG_COLOR_TYPE_GRAY:
+ if (png_ptr->bit_depth >= 8)
+ {
+ png_ptr->usr_channels = 2;
+ break;
+ }
+
+ else
+ {
+ /* There simply isn't any code in libpng to strip out bits
+ * from bytes when the components are less than a byte in
+ * size!
+ */
+ png_app_error(png_ptr,
+ "png_set_filler is invalid for low bit depth gray output");
+ return;
+ }
+
+ default:
+ png_app_error(png_ptr,
+ "png_set_filler: inappropriate color type");
+ return;
+ }
+# else
+ png_app_error(png_ptr, "png_set_filler not supported on write");
+ return;
+# endif
+ }
+
+ /* Here on success - libpng supports the operation, set the transformation
+ * and the flag to say where the filler channel is.
+ */
+ png_ptr->transformations |= PNG_FILLER;
+
+ if (filler_loc == PNG_FILLER_AFTER)
+ png_ptr->flags |= PNG_FLAG_FILLER_AFTER;
+
+ else
+ png_ptr->flags &= ~PNG_FLAG_FILLER_AFTER;
+}
+
+/* Added to libpng-1.2.7 */
+void PNGAPI
+png_set_add_alpha(png_structrp png_ptr, png_uint_32 filler, int filler_loc)
+{
+ png_debug(1, "in png_set_add_alpha");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_set_filler(png_ptr, filler, filler_loc);
+ /* The above may fail to do anything. */
+ if (png_ptr->transformations & PNG_FILLER)
+ png_ptr->transformations |= PNG_ADD_ALPHA;
+}
+
+#endif
+
+#if defined(PNG_READ_SWAP_ALPHA_SUPPORTED) || \
+ defined(PNG_WRITE_SWAP_ALPHA_SUPPORTED)
+void PNGAPI
+png_set_swap_alpha(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_swap_alpha");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_SWAP_ALPHA;
+}
+#endif
+
+#if defined(PNG_READ_INVERT_ALPHA_SUPPORTED) || \
+ defined(PNG_WRITE_INVERT_ALPHA_SUPPORTED)
+void PNGAPI
+png_set_invert_alpha(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_invert_alpha");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_INVERT_ALPHA;
+}
+#endif
+
+#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED)
+void PNGAPI
+png_set_invert_mono(png_structrp png_ptr)
+{
+ png_debug(1, "in png_set_invert_mono");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_INVERT_MONO;
+}
+
+/* Invert monochrome grayscale data */
+void /* PRIVATE */
+png_do_invert(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_invert");
+
+ /* This test removed from libpng version 1.0.13 and 1.2.0:
+ * if (row_info->bit_depth == 1 &&
+ */
+ if (row_info->color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ png_bytep rp = row;
+ png_size_t i;
+ png_size_t istop = row_info->rowbytes;
+
+ for (i = 0; i < istop; i++)
+ {
+ *rp = (png_byte)(~(*rp));
+ rp++;
+ }
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA &&
+ row_info->bit_depth == 8)
+ {
+ png_bytep rp = row;
+ png_size_t i;
+ png_size_t istop = row_info->rowbytes;
+
+ for (i = 0; i < istop; i += 2)
+ {
+ *rp = (png_byte)(~(*rp));
+ rp += 2;
+ }
+ }
+
+#ifdef PNG_16BIT_SUPPORTED
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA &&
+ row_info->bit_depth == 16)
+ {
+ png_bytep rp = row;
+ png_size_t i;
+ png_size_t istop = row_info->rowbytes;
+
+ for (i = 0; i < istop; i += 4)
+ {
+ *rp = (png_byte)(~(*rp));
+ *(rp + 1) = (png_byte)(~(*(rp + 1)));
+ rp += 4;
+ }
+ }
+#endif
+}
+#endif
+
+#ifdef PNG_16BIT_SUPPORTED
+#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED)
+/* Swaps byte order on 16 bit depth images */
+void /* PRIVATE */
+png_do_swap(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_swap");
+
+ if (row_info->bit_depth == 16)
+ {
+ png_bytep rp = row;
+ png_uint_32 i;
+ png_uint_32 istop= row_info->width * row_info->channels;
+
+ for (i = 0; i < istop; i++, rp += 2)
+ {
+ png_byte t = *rp;
+ *rp = *(rp + 1);
+ *(rp + 1) = t;
+ }
+ }
+}
+#endif
+#endif
+
+#if defined(PNG_READ_PACKSWAP_SUPPORTED)||defined(PNG_WRITE_PACKSWAP_SUPPORTED)
+static PNG_CONST png_byte onebppswaptable[256] = {
+ 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0,
+ 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0,
+ 0x08, 0x88, 0x48, 0xC8, 0x28, 0xA8, 0x68, 0xE8,
+ 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8,
+ 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, 0x64, 0xE4,
+ 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4,
+ 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC,
+ 0x1C, 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC,
+ 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2,
+ 0x12, 0x92, 0x52, 0xD2, 0x32, 0xB2, 0x72, 0xF2,
+ 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA,
+ 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, 0xFA,
+ 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6,
+ 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6,
+ 0x0E, 0x8E, 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE,
+ 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE,
+ 0x01, 0x81, 0x41, 0xC1, 0x21, 0xA1, 0x61, 0xE1,
+ 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1,
+ 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9,
+ 0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9,
+ 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5,
+ 0x15, 0x95, 0x55, 0xD5, 0x35, 0xB5, 0x75, 0xF5,
+ 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED,
+ 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, 0x7D, 0xFD,
+ 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3,
+ 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3,
+ 0x0B, 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB,
+ 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB,
+ 0x07, 0x87, 0x47, 0xC7, 0x27, 0xA7, 0x67, 0xE7,
+ 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7,
+ 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF,
+ 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF
+};
+
+static PNG_CONST png_byte twobppswaptable[256] = {
+ 0x00, 0x40, 0x80, 0xC0, 0x10, 0x50, 0x90, 0xD0,
+ 0x20, 0x60, 0xA0, 0xE0, 0x30, 0x70, 0xB0, 0xF0,
+ 0x04, 0x44, 0x84, 0xC4, 0x14, 0x54, 0x94, 0xD4,
+ 0x24, 0x64, 0xA4, 0xE4, 0x34, 0x74, 0xB4, 0xF4,
+ 0x08, 0x48, 0x88, 0xC8, 0x18, 0x58, 0x98, 0xD8,
+ 0x28, 0x68, 0xA8, 0xE8, 0x38, 0x78, 0xB8, 0xF8,
+ 0x0C, 0x4C, 0x8C, 0xCC, 0x1C, 0x5C, 0x9C, 0xDC,
+ 0x2C, 0x6C, 0xAC, 0xEC, 0x3C, 0x7C, 0xBC, 0xFC,
+ 0x01, 0x41, 0x81, 0xC1, 0x11, 0x51, 0x91, 0xD1,
+ 0x21, 0x61, 0xA1, 0xE1, 0x31, 0x71, 0xB1, 0xF1,
+ 0x05, 0x45, 0x85, 0xC5, 0x15, 0x55, 0x95, 0xD5,
+ 0x25, 0x65, 0xA5, 0xE5, 0x35, 0x75, 0xB5, 0xF5,
+ 0x09, 0x49, 0x89, 0xC9, 0x19, 0x59, 0x99, 0xD9,
+ 0x29, 0x69, 0xA9, 0xE9, 0x39, 0x79, 0xB9, 0xF9,
+ 0x0D, 0x4D, 0x8D, 0xCD, 0x1D, 0x5D, 0x9D, 0xDD,
+ 0x2D, 0x6D, 0xAD, 0xED, 0x3D, 0x7D, 0xBD, 0xFD,
+ 0x02, 0x42, 0x82, 0xC2, 0x12, 0x52, 0x92, 0xD2,
+ 0x22, 0x62, 0xA2, 0xE2, 0x32, 0x72, 0xB2, 0xF2,
+ 0x06, 0x46, 0x86, 0xC6, 0x16, 0x56, 0x96, 0xD6,
+ 0x26, 0x66, 0xA6, 0xE6, 0x36, 0x76, 0xB6, 0xF6,
+ 0x0A, 0x4A, 0x8A, 0xCA, 0x1A, 0x5A, 0x9A, 0xDA,
+ 0x2A, 0x6A, 0xAA, 0xEA, 0x3A, 0x7A, 0xBA, 0xFA,
+ 0x0E, 0x4E, 0x8E, 0xCE, 0x1E, 0x5E, 0x9E, 0xDE,
+ 0x2E, 0x6E, 0xAE, 0xEE, 0x3E, 0x7E, 0xBE, 0xFE,
+ 0x03, 0x43, 0x83, 0xC3, 0x13, 0x53, 0x93, 0xD3,
+ 0x23, 0x63, 0xA3, 0xE3, 0x33, 0x73, 0xB3, 0xF3,
+ 0x07, 0x47, 0x87, 0xC7, 0x17, 0x57, 0x97, 0xD7,
+ 0x27, 0x67, 0xA7, 0xE7, 0x37, 0x77, 0xB7, 0xF7,
+ 0x0B, 0x4B, 0x8B, 0xCB, 0x1B, 0x5B, 0x9B, 0xDB,
+ 0x2B, 0x6B, 0xAB, 0xEB, 0x3B, 0x7B, 0xBB, 0xFB,
+ 0x0F, 0x4F, 0x8F, 0xCF, 0x1F, 0x5F, 0x9F, 0xDF,
+ 0x2F, 0x6F, 0xAF, 0xEF, 0x3F, 0x7F, 0xBF, 0xFF
+};
+
+static PNG_CONST png_byte fourbppswaptable[256] = {
+ 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70,
+ 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0,
+ 0x01, 0x11, 0x21, 0x31, 0x41, 0x51, 0x61, 0x71,
+ 0x81, 0x91, 0xA1, 0xB1, 0xC1, 0xD1, 0xE1, 0xF1,
+ 0x02, 0x12, 0x22, 0x32, 0x42, 0x52, 0x62, 0x72,
+ 0x82, 0x92, 0xA2, 0xB2, 0xC2, 0xD2, 0xE2, 0xF2,
+ 0x03, 0x13, 0x23, 0x33, 0x43, 0x53, 0x63, 0x73,
+ 0x83, 0x93, 0xA3, 0xB3, 0xC3, 0xD3, 0xE3, 0xF3,
+ 0x04, 0x14, 0x24, 0x34, 0x44, 0x54, 0x64, 0x74,
+ 0x84, 0x94, 0xA4, 0xB4, 0xC4, 0xD4, 0xE4, 0xF4,
+ 0x05, 0x15, 0x25, 0x35, 0x45, 0x55, 0x65, 0x75,
+ 0x85, 0x95, 0xA5, 0xB5, 0xC5, 0xD5, 0xE5, 0xF5,
+ 0x06, 0x16, 0x26, 0x36, 0x46, 0x56, 0x66, 0x76,
+ 0x86, 0x96, 0xA6, 0xB6, 0xC6, 0xD6, 0xE6, 0xF6,
+ 0x07, 0x17, 0x27, 0x37, 0x47, 0x57, 0x67, 0x77,
+ 0x87, 0x97, 0xA7, 0xB7, 0xC7, 0xD7, 0xE7, 0xF7,
+ 0x08, 0x18, 0x28, 0x38, 0x48, 0x58, 0x68, 0x78,
+ 0x88, 0x98, 0xA8, 0xB8, 0xC8, 0xD8, 0xE8, 0xF8,
+ 0x09, 0x19, 0x29, 0x39, 0x49, 0x59, 0x69, 0x79,
+ 0x89, 0x99, 0xA9, 0xB9, 0xC9, 0xD9, 0xE9, 0xF9,
+ 0x0A, 0x1A, 0x2A, 0x3A, 0x4A, 0x5A, 0x6A, 0x7A,
+ 0x8A, 0x9A, 0xAA, 0xBA, 0xCA, 0xDA, 0xEA, 0xFA,
+ 0x0B, 0x1B, 0x2B, 0x3B, 0x4B, 0x5B, 0x6B, 0x7B,
+ 0x8B, 0x9B, 0xAB, 0xBB, 0xCB, 0xDB, 0xEB, 0xFB,
+ 0x0C, 0x1C, 0x2C, 0x3C, 0x4C, 0x5C, 0x6C, 0x7C,
+ 0x8C, 0x9C, 0xAC, 0xBC, 0xCC, 0xDC, 0xEC, 0xFC,
+ 0x0D, 0x1D, 0x2D, 0x3D, 0x4D, 0x5D, 0x6D, 0x7D,
+ 0x8D, 0x9D, 0xAD, 0xBD, 0xCD, 0xDD, 0xED, 0xFD,
+ 0x0E, 0x1E, 0x2E, 0x3E, 0x4E, 0x5E, 0x6E, 0x7E,
+ 0x8E, 0x9E, 0xAE, 0xBE, 0xCE, 0xDE, 0xEE, 0xFE,
+ 0x0F, 0x1F, 0x2F, 0x3F, 0x4F, 0x5F, 0x6F, 0x7F,
+ 0x8F, 0x9F, 0xAF, 0xBF, 0xCF, 0xDF, 0xEF, 0xFF
+};
+
+/* Swaps pixel packing order within bytes */
+void /* PRIVATE */
+png_do_packswap(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_packswap");
+
+ if (row_info->bit_depth < 8)
+ {
+ png_bytep rp;
+ png_const_bytep end, table;
+
+ end = row + row_info->rowbytes;
+
+ if (row_info->bit_depth == 1)
+ table = onebppswaptable;
+
+ else if (row_info->bit_depth == 2)
+ table = twobppswaptable;
+
+ else if (row_info->bit_depth == 4)
+ table = fourbppswaptable;
+
+ else
+ return;
+
+ for (rp = row; rp < end; rp++)
+ *rp = table[*rp];
+ }
+}
+#endif /* PNG_READ_PACKSWAP_SUPPORTED or PNG_WRITE_PACKSWAP_SUPPORTED */
+
+#if defined(PNG_WRITE_FILLER_SUPPORTED) || \
+ defined(PNG_READ_STRIP_ALPHA_SUPPORTED)
+/* Remove a channel - this used to be 'png_do_strip_filler' but it used a
+ * somewhat weird combination of flags to determine what to do. All the calls
+ * to png_do_strip_filler are changed in 1.5.2 to call this instead with the
+ * correct arguments.
+ *
+ * The routine isn't general - the channel must be the channel at the start or
+ * end (not in the middle) of each pixel.
+ */
+void /* PRIVATE */
+png_do_strip_channel(png_row_infop row_info, png_bytep row, int at_start)
+{
+ png_bytep sp = row; /* source pointer */
+ png_bytep dp = row; /* destination pointer */
+ png_bytep ep = row + row_info->rowbytes; /* One beyond end of row */
+
+ /* At the start sp will point to the first byte to copy and dp to where
+ * it is copied to. ep always points just beyond the end of the row, so
+ * the loop simply copies (channels-1) channels until sp reaches ep.
+ *
+ * at_start: 0 -- convert AG, XG, ARGB, XRGB, AAGG, XXGG, etc.
+ * nonzero -- convert GA, GX, RGBA, RGBX, GGAA, RRGGBBXX, etc.
+ */
+
+ /* GA, GX, XG cases */
+ if (row_info->channels == 2)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ if (at_start) /* Skip initial filler */
+ ++sp;
+ else /* Skip initial channel and, for sp, the filler */
+ sp += 2, ++dp;
+
+ /* For a 1 pixel wide image there is nothing to do */
+ while (sp < ep)
+ *dp++ = *sp, sp += 2;
+
+ row_info->pixel_depth = 8;
+ }
+
+ else if (row_info->bit_depth == 16)
+ {
+ if (at_start) /* Skip initial filler */
+ sp += 2;
+ else /* Skip initial channel and, for sp, the filler */
+ sp += 4, dp += 2;
+
+ while (sp < ep)
+ *dp++ = *sp++, *dp++ = *sp, sp += 3;
+
+ row_info->pixel_depth = 16;
+ }
+
+ else
+ return; /* bad bit depth */
+
+ row_info->channels = 1;
+
+ /* Finally fix the color type if it records an alpha channel */
+ if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ row_info->color_type = PNG_COLOR_TYPE_GRAY;
+ }
+
+ /* RGBA, RGBX, XRGB cases */
+ else if (row_info->channels == 4)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ if (at_start) /* Skip initial filler */
+ ++sp;
+ else /* Skip initial channels and, for sp, the filler */
+ sp += 4, dp += 3;
+
+ /* Note that the loop adds 3 to dp and 4 to sp each time. */
+ while (sp < ep)
+ *dp++ = *sp++, *dp++ = *sp++, *dp++ = *sp, sp += 2;
+
+ row_info->pixel_depth = 24;
+ }
+
+ else if (row_info->bit_depth == 16)
+ {
+ if (at_start) /* Skip initial filler */
+ sp += 2;
+ else /* Skip initial channels and, for sp, the filler */
+ sp += 8, dp += 6;
+
+ while (sp < ep)
+ {
+ /* Copy 6 bytes, skip 2 */
+ *dp++ = *sp++, *dp++ = *sp++;
+ *dp++ = *sp++, *dp++ = *sp++;
+ *dp++ = *sp++, *dp++ = *sp, sp += 3;
+ }
+
+ row_info->pixel_depth = 48;
+ }
+
+ else
+ return; /* bad bit depth */
+
+ row_info->channels = 3;
+
+ /* Finally fix the color type if it records an alpha channel */
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ row_info->color_type = PNG_COLOR_TYPE_RGB;
+ }
+
+ else
+ return; /* The filler channel has gone already */
+
+ /* Fix the rowbytes value. */
+ row_info->rowbytes = dp-row;
+}
+#endif
+
+#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED)
+/* Swaps red and blue bytes within a pixel */
+void /* PRIVATE */
+png_do_bgr(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_bgr");
+
+ if ((row_info->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ png_uint_32 row_width = row_info->width;
+ if (row_info->bit_depth == 8)
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += 3)
+ {
+ png_byte save = *rp;
+ *rp = *(rp + 2);
+ *(rp + 2) = save;
+ }
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += 4)
+ {
+ png_byte save = *rp;
+ *rp = *(rp + 2);
+ *(rp + 2) = save;
+ }
+ }
+ }
+
+#ifdef PNG_16BIT_SUPPORTED
+ else if (row_info->bit_depth == 16)
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += 6)
+ {
+ png_byte save = *rp;
+ *rp = *(rp + 4);
+ *(rp + 4) = save;
+ save = *(rp + 1);
+ *(rp + 1) = *(rp + 5);
+ *(rp + 5) = save;
+ }
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += 8)
+ {
+ png_byte save = *rp;
+ *rp = *(rp + 4);
+ *(rp + 4) = save;
+ save = *(rp + 1);
+ *(rp + 1) = *(rp + 5);
+ *(rp + 5) = save;
+ }
+ }
+ }
+#endif
+ }
+}
+#endif /* PNG_READ_BGR_SUPPORTED or PNG_WRITE_BGR_SUPPORTED */
+
+#if defined(PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED) || \
+ defined(PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED)
+/* Added at libpng-1.5.10 */
+void /* PRIVATE */
+png_do_check_palette_indexes(png_structrp png_ptr, png_row_infop row_info)
+{
+ if (png_ptr->num_palette < (1 << row_info->bit_depth) &&
+ png_ptr->num_palette > 0) /* num_palette can be 0 in MNG files */
+ {
+ /* Calculations moved outside switch in an attempt to stop different
+ * compiler warnings. 'padding' is in *bits* within the last byte, it is
+ * an 'int' because pixel_depth becomes an 'int' in the expression below,
+ * and this calculation is used because it avoids warnings that other
+ * forms produced on either GCC or MSVC.
+ */
+ int padding = (-row_info->pixel_depth * row_info->width) & 7;
+ png_bytep rp = png_ptr->row_buf + row_info->rowbytes;
+
+ switch (row_info->bit_depth)
+ {
+ case 1:
+ {
+ /* in this case, all bytes must be 0 so we don't need
+ * to unpack the pixels except for the rightmost one.
+ */
+ for (; rp > png_ptr->row_buf; rp--)
+ {
+ if (*rp >> padding != 0)
+ png_ptr->num_palette_max = 1;
+ padding = 0;
+ }
+
+ break;
+ }
+
+ case 2:
+ {
+ for (; rp > png_ptr->row_buf; rp--)
+ {
+ int i = ((*rp >> padding) & 0x03);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ i = (((*rp >> padding) >> 2) & 0x03);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ i = (((*rp >> padding) >> 4) & 0x03);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ i = (((*rp >> padding) >> 6) & 0x03);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ padding = 0;
+ }
+
+ break;
+ }
+
+ case 4:
+ {
+ for (; rp > png_ptr->row_buf; rp--)
+ {
+ int i = ((*rp >> padding) & 0x0f);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ i = (((*rp >> padding) >> 4) & 0x0f);
+
+ if (i > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = i;
+
+ padding = 0;
+ }
+
+ break;
+ }
+
+ case 8:
+ {
+ for (; rp > png_ptr->row_buf; rp--)
+ {
+ if (*rp > png_ptr->num_palette_max)
+ png_ptr->num_palette_max = (int) *rp;
+ }
+
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+}
+#endif /* PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED */
+
+#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \
+ defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED)
+#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED
+void PNGAPI
+png_set_user_transform_info(png_structrp png_ptr, png_voidp
+ user_transform_ptr, int user_transform_depth, int user_transform_channels)
+{
+ png_debug(1, "in png_set_user_transform_info");
+
+ if (png_ptr == NULL)
+ return;
+
+#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED
+ if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 &&
+ (png_ptr->flags & PNG_FLAG_ROW_INIT) != 0)
+ {
+ png_app_error(png_ptr,
+ "info change after png_start_read_image or png_read_update_info");
+ return;
+ }
+#endif
+
+ png_ptr->user_transform_ptr = user_transform_ptr;
+ png_ptr->user_transform_depth = (png_byte)user_transform_depth;
+ png_ptr->user_transform_channels = (png_byte)user_transform_channels;
+}
+#endif
+
+/* This function returns a pointer to the user_transform_ptr associated with
+ * the user transform functions. The application should free any memory
+ * associated with this pointer before png_write_destroy and png_read_destroy
+ * are called.
+ */
+#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED
+png_voidp PNGAPI
+png_get_user_transform_ptr(png_const_structrp png_ptr)
+{
+ if (png_ptr == NULL)
+ return (NULL);
+
+ return png_ptr->user_transform_ptr;
+}
+#endif
+
+#ifdef PNG_USER_TRANSFORM_INFO_SUPPORTED
+png_uint_32 PNGAPI
+png_get_current_row_number(png_const_structrp png_ptr)
+{
+ /* See the comments in png.h - this is the sub-image row when reading and
+ * interlaced image.
+ */
+ if (png_ptr != NULL)
+ return png_ptr->row_number;
+
+ return PNG_UINT_32_MAX; /* help the app not to fail silently */
+}
+
+png_byte PNGAPI
+png_get_current_pass_number(png_const_structrp png_ptr)
+{
+ if (png_ptr != NULL)
+ return png_ptr->pass;
+ return 8; /* invalid */
+}
+#endif /* PNG_USER_TRANSFORM_INFO_SUPPORTED */
+#endif /* PNG_READ_USER_TRANSFORM_SUPPORTED ||
+ PNG_WRITE_USER_TRANSFORM_SUPPORTED */
+#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngwio.c b/ml/dlib/dlib/external/libpng/pngwio.c
new file mode 100644
index 000000000..e3289dfe4
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngwio.c
@@ -0,0 +1,164 @@
+
+/* pngwio.c - functions for data output
+ *
+ * Last changed in libpng 1.6.0 [February 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ *
+ * This file provides a location for all output. Users who need
+ * special handling are expected to write functions that have the same
+ * arguments as these and perform similar functions, but that possibly
+ * use different output methods. Note that you shouldn't change these
+ * functions, but rather write replacement functions and then change
+ * them at run time with png_set_write_fn(...).
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_WRITE_SUPPORTED
+
+/* Write the data to whatever output you are using. The default routine
+ * writes to a file pointer. Note that this routine sometimes gets called
+ * with very small lengths, so you should implement some kind of simple
+ * buffering if you are using unbuffered writes. This should never be asked
+ * to write more than 64K on a 16 bit machine.
+ */
+
+void /* PRIVATE */
+png_write_data(png_structrp png_ptr, png_const_bytep data, png_size_t length)
+{
+ /* NOTE: write_data_fn must not change the buffer! */
+ if (png_ptr->write_data_fn != NULL )
+ (*(png_ptr->write_data_fn))(png_ptr, png_constcast(png_bytep,data),
+ length);
+
+ else
+ png_error(png_ptr, "Call to NULL write function");
+}
+
+#ifdef PNG_STDIO_SUPPORTED
+/* This is the function that does the actual writing of data. If you are
+ * not writing to a standard C stream, you should create a replacement
+ * write_data function and use it at run time with png_set_write_fn(), rather
+ * than changing the library.
+ */
+void PNGCBAPI
+png_default_write_data(png_structp png_ptr, png_bytep data, png_size_t length)
+{
+ png_size_t check;
+
+ if (png_ptr == NULL)
+ return;
+
+ check = fwrite(data, 1, length, (png_FILE_p)(png_ptr->io_ptr));
+
+ if (check != length)
+ png_error(png_ptr, "Write Error");
+}
+#endif
+
+/* This function is called to output any data pending writing (normally
+ * to disk). After png_flush is called, there should be no data pending
+ * writing in any buffers.
+ */
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+void /* PRIVATE */
+png_flush(png_structrp png_ptr)
+{
+ if (png_ptr->output_flush_fn != NULL)
+ (*(png_ptr->output_flush_fn))(png_ptr);
+}
+
+# ifdef PNG_STDIO_SUPPORTED
+void PNGCBAPI
+png_default_flush(png_structp png_ptr)
+{
+ png_FILE_p io_ptr;
+
+ if (png_ptr == NULL)
+ return;
+
+ io_ptr = png_voidcast(png_FILE_p, (png_ptr->io_ptr));
+ fflush(io_ptr);
+}
+# endif
+#endif
+
+/* This function allows the application to supply new output functions for
+ * libpng if standard C streams aren't being used.
+ *
+ * This function takes as its arguments:
+ * png_ptr - pointer to a png output data structure
+ * io_ptr - pointer to user supplied structure containing info about
+ * the output functions. May be NULL.
+ * write_data_fn - pointer to a new output function that takes as its
+ * arguments a pointer to a png_struct, a pointer to
+ * data to be written, and a 32-bit unsigned int that is
+ * the number of bytes to be written. The new write
+ * function should call png_error(png_ptr, "Error msg")
+ * to exit and output any fatal error messages. May be
+ * NULL, in which case libpng's default function will
+ * be used.
+ * flush_data_fn - pointer to a new flush function that takes as its
+ * arguments a pointer to a png_struct. After a call to
+ * the flush function, there should be no data in any buffers
+ * or pending transmission. If the output method doesn't do
+ * any buffering of output, a function prototype must still be
+ * supplied although it doesn't have to do anything. If
+ * PNG_WRITE_FLUSH_SUPPORTED is not defined at libpng compile
+ * time, output_flush_fn will be ignored, although it must be
+ * supplied for compatibility. May be NULL, in which case
+ * libpng's default function will be used, if
+ * PNG_WRITE_FLUSH_SUPPORTED is defined. This is not
+ * a good idea if io_ptr does not point to a standard
+ * *FILE structure.
+ */
+void PNGAPI
+png_set_write_fn(png_structrp png_ptr, png_voidp io_ptr,
+ png_rw_ptr write_data_fn, png_flush_ptr output_flush_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->io_ptr = io_ptr;
+
+#ifdef PNG_STDIO_SUPPORTED
+ if (write_data_fn != NULL)
+ png_ptr->write_data_fn = write_data_fn;
+
+ else
+ png_ptr->write_data_fn = png_default_write_data;
+#else
+ png_ptr->write_data_fn = write_data_fn;
+#endif
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+# ifdef PNG_STDIO_SUPPORTED
+
+ if (output_flush_fn != NULL)
+ png_ptr->output_flush_fn = output_flush_fn;
+
+ else
+ png_ptr->output_flush_fn = png_default_flush;
+
+# else
+ png_ptr->output_flush_fn = output_flush_fn;
+# endif
+#endif /* PNG_WRITE_FLUSH_SUPPORTED */
+
+ /* It is an error to read while writing a png file */
+ if (png_ptr->read_data_fn != NULL)
+ {
+ png_ptr->read_data_fn = NULL;
+
+ png_warning(png_ptr,
+ "Can't set both read_data_fn and write_data_fn in the"
+ " same structure");
+ }
+}
+#endif /* PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngwrite.c b/ml/dlib/dlib/external/libpng/pngwrite.c
new file mode 100644
index 000000000..b71a3d345
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngwrite.c
@@ -0,0 +1,2330 @@
+
+/* pngwrite.c - general routines to write a PNG file
+ *
+ * Last changed in libpng 1.6.2 [April 25, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+#if defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) && defined(PNG_STDIO_SUPPORTED)
+# include <errno.h>
+#endif
+
+#ifdef PNG_WRITE_SUPPORTED
+
+#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED
+/* Write out all the unknown chunks for the current given location */
+static void
+write_unknown_chunks(png_structrp png_ptr, png_const_inforp info_ptr,
+ unsigned int where)
+{
+ if (info_ptr->unknown_chunks_num)
+ {
+ png_const_unknown_chunkp up;
+
+ png_debug(5, "writing extra chunks");
+
+ for (up = info_ptr->unknown_chunks;
+ up < info_ptr->unknown_chunks + info_ptr->unknown_chunks_num;
+ ++up)
+ if (up->location & where)
+ {
+ /* If per-chunk unknown chunk handling is enabled use it, otherwise
+ * just write the chunks the application has set.
+ */
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ int keep = png_handle_as_unknown(png_ptr, up->name);
+
+ /* NOTE: this code is radically different from the read side in the
+ * matter of handling an ancillary unknown chunk. In the read side
+ * the default behavior is to discard it, in the code below the default
+ * behavior is to write it. Critical chunks are, however, only
+ * written if explicitly listed or if the default is set to write all
+ * unknown chunks.
+ *
+ * The default handling is also slightly weird - it is not possible to
+ * stop the writing of all unsafe-to-copy chunks!
+ *
+ * TODO: REVIEW: this would seem to be a bug.
+ */
+ if (keep != PNG_HANDLE_CHUNK_NEVER &&
+ ((up->name[3] & 0x20) /* safe-to-copy overrides everything */ ||
+ keep == PNG_HANDLE_CHUNK_ALWAYS ||
+ (keep == PNG_HANDLE_CHUNK_AS_DEFAULT &&
+ png_ptr->unknown_default == PNG_HANDLE_CHUNK_ALWAYS)))
+#endif
+ {
+ /* TODO: review, what is wrong with a zero length unknown chunk? */
+ if (up->size == 0)
+ png_warning(png_ptr, "Writing zero-length unknown chunk");
+
+ png_write_chunk(png_ptr, up->name, up->data, up->size);
+ }
+ }
+ }
+}
+#endif /* PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED */
+
+/* Writes all the PNG information. This is the suggested way to use the
+ * library. If you have a new chunk to add, make a function to write it,
+ * and put it in the correct location here. If you want the chunk written
+ * after the image data, put it in png_write_end(). I strongly encourage
+ * you to supply a PNG_INFO_ flag, and check info_ptr->valid before writing
+ * the chunk, as that will keep the code from breaking if you want to just
+ * write a plain PNG file. If you have long comments, I suggest writing
+ * them in png_write_end(), and compressing them.
+ */
+void PNGAPI
+png_write_info_before_PLTE(png_structrp png_ptr, png_const_inforp info_ptr)
+{
+ png_debug(1, "in png_write_info_before_PLTE");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ if (!(png_ptr->mode & PNG_WROTE_INFO_BEFORE_PLTE))
+ {
+ /* Write PNG signature */
+ png_write_sig(png_ptr);
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ if ((png_ptr->mode&PNG_HAVE_PNG_SIGNATURE) && \
+ (png_ptr->mng_features_permitted))
+ {
+ png_warning(png_ptr, "MNG features are not allowed in a PNG datastream");
+ png_ptr->mng_features_permitted = 0;
+ }
+#endif
+
+ /* Write IHDR information. */
+ png_write_IHDR(png_ptr, info_ptr->width, info_ptr->height,
+ info_ptr->bit_depth, info_ptr->color_type, info_ptr->compression_type,
+ info_ptr->filter_type,
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ info_ptr->interlace_type
+#else
+ 0
+#endif
+ );
+
+ /* The rest of these check to see if the valid field has the appropriate
+ * flag set, and if it does, writes the chunk.
+ *
+ * 1.6.0: COLORSPACE support controls the writing of these chunks too, and
+ * the chunks will be written if the WRITE routine is there and information
+ * is available in the COLORSPACE. (See png_colorspace_sync_info in png.c
+ * for where the valid flags get set.)
+ *
+ * Under certain circumstances the colorspace can be invalidated without
+ * syncing the info_struct 'valid' flags; this happens if libpng detects and
+ * error and calls png_error while the color space is being set, yet the
+ * application continues writing the PNG. So check the 'invalid' flag here
+ * too.
+ */
+#ifdef PNG_GAMMA_SUPPORTED
+# ifdef PNG_WRITE_gAMA_SUPPORTED
+ if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_FROM_gAMA) &&
+ (info_ptr->valid & PNG_INFO_gAMA))
+ png_write_gAMA_fixed(png_ptr, info_ptr->colorspace.gamma);
+# endif
+#endif
+
+#ifdef PNG_COLORSPACE_SUPPORTED
+ /* Write only one of sRGB or an ICC profile. If a profile was supplied
+ * and it matches one of the known sRGB ones issue a warning.
+ */
+# ifdef PNG_WRITE_iCCP_SUPPORTED
+ if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) &&
+ (info_ptr->valid & PNG_INFO_iCCP))
+ {
+# ifdef PNG_WRITE_sRGB_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_sRGB)
+ png_app_warning(png_ptr,
+ "profile matches sRGB but writing iCCP instead");
+# endif
+
+ png_write_iCCP(png_ptr, info_ptr->iccp_name,
+ info_ptr->iccp_profile);
+ }
+# ifdef PNG_WRITE_sRGB_SUPPORTED
+ else
+# endif
+# endif
+
+# ifdef PNG_WRITE_sRGB_SUPPORTED
+ if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) &&
+ (info_ptr->valid & PNG_INFO_sRGB))
+ png_write_sRGB(png_ptr, info_ptr->colorspace.rendering_intent);
+# endif /* WRITE_sRGB */
+#endif /* COLORSPACE */
+
+#ifdef PNG_WRITE_sBIT_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_sBIT)
+ png_write_sBIT(png_ptr, &(info_ptr->sig_bit), info_ptr->color_type);
+#endif
+
+#ifdef PNG_COLORSPACE_SUPPORTED
+# ifdef PNG_WRITE_cHRM_SUPPORTED
+ if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) &&
+ (info_ptr->colorspace.flags & PNG_COLORSPACE_FROM_cHRM) &&
+ (info_ptr->valid & PNG_INFO_cHRM))
+ png_write_cHRM_fixed(png_ptr, &info_ptr->colorspace.end_points_xy);
+# endif
+#endif
+
+#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED
+ write_unknown_chunks(png_ptr, info_ptr, PNG_HAVE_IHDR);
+#endif
+
+ png_ptr->mode |= PNG_WROTE_INFO_BEFORE_PLTE;
+ }
+}
+
+void PNGAPI
+png_write_info(png_structrp png_ptr, png_const_inforp info_ptr)
+{
+#if defined(PNG_WRITE_TEXT_SUPPORTED) || defined(PNG_WRITE_sPLT_SUPPORTED)
+ int i;
+#endif
+
+ png_debug(1, "in png_write_info");
+
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ png_write_info_before_PLTE(png_ptr, info_ptr);
+
+ if (info_ptr->valid & PNG_INFO_PLTE)
+ png_write_PLTE(png_ptr, info_ptr->palette,
+ (png_uint_32)info_ptr->num_palette);
+
+ else if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ png_error(png_ptr, "Valid palette required for paletted images");
+
+#ifdef PNG_WRITE_tRNS_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_tRNS)
+ {
+#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED
+ /* Invert the alpha channel (in tRNS) */
+ if ((png_ptr->transformations & PNG_INVERT_ALPHA) &&
+ info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ int j;
+ for (j = 0; j<(int)info_ptr->num_trans; j++)
+ info_ptr->trans_alpha[j] =
+ (png_byte)(255 - info_ptr->trans_alpha[j]);
+ }
+#endif
+ png_write_tRNS(png_ptr, info_ptr->trans_alpha, &(info_ptr->trans_color),
+ info_ptr->num_trans, info_ptr->color_type);
+ }
+#endif
+#ifdef PNG_WRITE_bKGD_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_bKGD)
+ png_write_bKGD(png_ptr, &(info_ptr->background), info_ptr->color_type);
+#endif
+
+#ifdef PNG_WRITE_hIST_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_hIST)
+ png_write_hIST(png_ptr, info_ptr->hist, info_ptr->num_palette);
+#endif
+
+#ifdef PNG_WRITE_oFFs_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_oFFs)
+ png_write_oFFs(png_ptr, info_ptr->x_offset, info_ptr->y_offset,
+ info_ptr->offset_unit_type);
+#endif
+
+#ifdef PNG_WRITE_pCAL_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_pCAL)
+ png_write_pCAL(png_ptr, info_ptr->pcal_purpose, info_ptr->pcal_X0,
+ info_ptr->pcal_X1, info_ptr->pcal_type, info_ptr->pcal_nparams,
+ info_ptr->pcal_units, info_ptr->pcal_params);
+#endif
+
+#ifdef PNG_WRITE_sCAL_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_sCAL)
+ png_write_sCAL_s(png_ptr, (int)info_ptr->scal_unit,
+ info_ptr->scal_s_width, info_ptr->scal_s_height);
+#endif /* sCAL */
+
+#ifdef PNG_WRITE_pHYs_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_pHYs)
+ png_write_pHYs(png_ptr, info_ptr->x_pixels_per_unit,
+ info_ptr->y_pixels_per_unit, info_ptr->phys_unit_type);
+#endif /* pHYs */
+
+#ifdef PNG_WRITE_tIME_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_tIME)
+ {
+ png_write_tIME(png_ptr, &(info_ptr->mod_time));
+ png_ptr->mode |= PNG_WROTE_tIME;
+ }
+#endif /* tIME */
+
+#ifdef PNG_WRITE_sPLT_SUPPORTED
+ if (info_ptr->valid & PNG_INFO_sPLT)
+ for (i = 0; i < (int)info_ptr->splt_palettes_num; i++)
+ png_write_sPLT(png_ptr, info_ptr->splt_palettes + i);
+#endif /* sPLT */
+
+#ifdef PNG_WRITE_TEXT_SUPPORTED
+ /* Check to see if we need to write text chunks */
+ for (i = 0; i < info_ptr->num_text; i++)
+ {
+ png_debug2(2, "Writing header text chunk %d, type %d", i,
+ info_ptr->text[i].compression);
+ /* An internationalized chunk? */
+ if (info_ptr->text[i].compression > 0)
+ {
+#ifdef PNG_WRITE_iTXt_SUPPORTED
+ /* Write international chunk */
+ png_write_iTXt(png_ptr,
+ info_ptr->text[i].compression,
+ info_ptr->text[i].key,
+ info_ptr->text[i].lang,
+ info_ptr->text[i].lang_key,
+ info_ptr->text[i].text);
+#else
+ png_warning(png_ptr, "Unable to write international text");
+#endif
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR;
+ }
+
+ /* If we want a compressed text chunk */
+ else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_zTXt)
+ {
+#ifdef PNG_WRITE_zTXt_SUPPORTED
+ /* Write compressed chunk */
+ png_write_zTXt(png_ptr, info_ptr->text[i].key,
+ info_ptr->text[i].text, 0,
+ info_ptr->text[i].compression);
+#else
+ png_warning(png_ptr, "Unable to write compressed text");
+#endif
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_zTXt_WR;
+ }
+
+ else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_NONE)
+ {
+#ifdef PNG_WRITE_tEXt_SUPPORTED
+ /* Write uncompressed chunk */
+ png_write_tEXt(png_ptr, info_ptr->text[i].key,
+ info_ptr->text[i].text,
+ 0);
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR;
+#else
+ /* Can't get here */
+ png_warning(png_ptr, "Unable to write uncompressed text");
+#endif
+ }
+ }
+#endif /* tEXt */
+
+#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED
+ write_unknown_chunks(png_ptr, info_ptr, PNG_HAVE_PLTE);
+#endif
+}
+
+/* Writes the end of the PNG file. If you don't want to write comments or
+ * time information, you can pass NULL for info. If you already wrote these
+ * in png_write_info(), do not write them again here. If you have long
+ * comments, I suggest writing them here, and compressing them.
+ */
+void PNGAPI
+png_write_end(png_structrp png_ptr, png_inforp info_ptr)
+{
+ png_debug(1, "in png_write_end");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (!(png_ptr->mode & PNG_HAVE_IDAT))
+ png_error(png_ptr, "No IDATs written into file");
+
+#ifdef PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ if (png_ptr->num_palette_max > png_ptr->num_palette)
+ png_benign_error(png_ptr, "Wrote palette index exceeding num_palette");
+#endif
+
+ /* See if user wants us to write information chunks */
+ if (info_ptr != NULL)
+ {
+#ifdef PNG_WRITE_TEXT_SUPPORTED
+ int i; /* local index variable */
+#endif
+#ifdef PNG_WRITE_tIME_SUPPORTED
+ /* Check to see if user has supplied a time chunk */
+ if ((info_ptr->valid & PNG_INFO_tIME) &&
+ !(png_ptr->mode & PNG_WROTE_tIME))
+ png_write_tIME(png_ptr, &(info_ptr->mod_time));
+
+#endif
+#ifdef PNG_WRITE_TEXT_SUPPORTED
+ /* Loop through comment chunks */
+ for (i = 0; i < info_ptr->num_text; i++)
+ {
+ png_debug2(2, "Writing trailer text chunk %d, type %d", i,
+ info_ptr->text[i].compression);
+ /* An internationalized chunk? */
+ if (info_ptr->text[i].compression > 0)
+ {
+#ifdef PNG_WRITE_iTXt_SUPPORTED
+ /* Write international chunk */
+ png_write_iTXt(png_ptr,
+ info_ptr->text[i].compression,
+ info_ptr->text[i].key,
+ info_ptr->text[i].lang,
+ info_ptr->text[i].lang_key,
+ info_ptr->text[i].text);
+#else
+ png_warning(png_ptr, "Unable to write international text");
+#endif
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR;
+ }
+
+ else if (info_ptr->text[i].compression >= PNG_TEXT_COMPRESSION_zTXt)
+ {
+#ifdef PNG_WRITE_zTXt_SUPPORTED
+ /* Write compressed chunk */
+ png_write_zTXt(png_ptr, info_ptr->text[i].key,
+ info_ptr->text[i].text, 0,
+ info_ptr->text[i].compression);
+#else
+ png_warning(png_ptr, "Unable to write compressed text");
+#endif
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_zTXt_WR;
+ }
+
+ else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_NONE)
+ {
+#ifdef PNG_WRITE_tEXt_SUPPORTED
+ /* Write uncompressed chunk */
+ png_write_tEXt(png_ptr, info_ptr->text[i].key,
+ info_ptr->text[i].text, 0);
+#else
+ png_warning(png_ptr, "Unable to write uncompressed text");
+#endif
+
+ /* Mark this chunk as written */
+ info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR;
+ }
+ }
+#endif
+#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED
+ write_unknown_chunks(png_ptr, info_ptr, PNG_AFTER_IDAT);
+#endif
+ }
+
+ png_ptr->mode |= PNG_AFTER_IDAT;
+
+ /* Write end of PNG file */
+ png_write_IEND(png_ptr);
+ /* This flush, added in libpng-1.0.8, removed from libpng-1.0.9beta03,
+ * and restored again in libpng-1.2.30, may cause some applications that
+ * do not set png_ptr->output_flush_fn to crash. If your application
+ * experiences a problem, please try building libpng with
+ * PNG_WRITE_FLUSH_AFTER_IEND_SUPPORTED defined, and report the event to
+ * png-mng-implement at lists.sf.net .
+ */
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+# ifdef PNG_WRITE_FLUSH_AFTER_IEND_SUPPORTED
+ png_flush(png_ptr);
+# endif
+#endif
+}
+
+#ifdef PNG_CONVERT_tIME_SUPPORTED
+void PNGAPI
+png_convert_from_struct_tm(png_timep ptime, PNG_CONST struct tm * ttime)
+{
+ png_debug(1, "in png_convert_from_struct_tm");
+
+ ptime->year = (png_uint_16)(1900 + ttime->tm_year);
+ ptime->month = (png_byte)(ttime->tm_mon + 1);
+ ptime->day = (png_byte)ttime->tm_mday;
+ ptime->hour = (png_byte)ttime->tm_hour;
+ ptime->minute = (png_byte)ttime->tm_min;
+ ptime->second = (png_byte)ttime->tm_sec;
+}
+
+void PNGAPI
+png_convert_from_time_t(png_timep ptime, time_t ttime)
+{
+ struct tm *tbuf;
+
+ png_debug(1, "in png_convert_from_time_t");
+
+ tbuf = gmtime(&ttime);
+ png_convert_from_struct_tm(ptime, tbuf);
+}
+#endif
+
+/* Initialize png_ptr structure, and allocate any memory needed */
+PNG_FUNCTION(png_structp,PNGAPI
+png_create_write_struct,(png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn),PNG_ALLOCATED)
+{
+#ifndef PNG_USER_MEM_SUPPORTED
+ png_structrp png_ptr = png_create_png_struct(user_png_ver, error_ptr,
+ error_fn, warn_fn, NULL, NULL, NULL);
+#else
+ return png_create_write_struct_2(user_png_ver, error_ptr, error_fn,
+ warn_fn, NULL, NULL, NULL);
+}
+
+/* Alternate initialize png_ptr structure, and allocate any memory needed */
+PNG_FUNCTION(png_structp,PNGAPI
+png_create_write_struct_2,(png_const_charp user_png_ver, png_voidp error_ptr,
+ png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr,
+ png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED)
+{
+ png_structrp png_ptr = png_create_png_struct(user_png_ver, error_ptr,
+ error_fn, warn_fn, mem_ptr, malloc_fn, free_fn);
+#endif /* PNG_USER_MEM_SUPPORTED */
+ if (png_ptr != NULL)
+ {
+ /* Set the zlib control values to defaults; they can be overridden by the
+ * application after the struct has been created.
+ */
+ png_ptr->zbuffer_size = PNG_ZBUF_SIZE;
+
+ /* The 'zlib_strategy' setting is irrelevant because png_default_claim in
+ * pngwutil.c defaults it according to whether or not filters will be
+ * used, and ignores this setting.
+ */
+ png_ptr->zlib_strategy = PNG_Z_DEFAULT_STRATEGY;
+ png_ptr->zlib_level = PNG_Z_DEFAULT_COMPRESSION;
+ png_ptr->zlib_mem_level = 8;
+ png_ptr->zlib_window_bits = 15;
+ png_ptr->zlib_method = 8;
+
+#ifdef PNG_WRITE_COMPRESSED_TEXT_SUPPORTED
+ png_ptr->zlib_text_strategy = PNG_TEXT_Z_DEFAULT_STRATEGY;
+ png_ptr->zlib_text_level = PNG_TEXT_Z_DEFAULT_COMPRESSION;
+ png_ptr->zlib_text_mem_level = 8;
+ png_ptr->zlib_text_window_bits = 15;
+ png_ptr->zlib_text_method = 8;
+#endif /* PNG_WRITE_COMPRESSED_TEXT_SUPPORTED */
+
+ /* This is a highly dubious configuration option; by default it is off,
+ * but it may be appropriate for private builds that are testing
+ * extensions not conformant to the current specification, or of
+ * applications that must not fail to write at all costs!
+ */
+#ifdef PNG_BENIGN_WRITE_ERRORS_SUPPORTED
+ png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN;
+ /* In stable builds only warn if an application error can be completely
+ * handled.
+ */
+#endif
+
+ /* App warnings are warnings in release (or release candidate) builds but
+ * are errors during development.
+ */
+#if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC
+ png_ptr->flags |= PNG_FLAG_APP_WARNINGS_WARN;
+#endif
+
+ /* TODO: delay this, it can be done in png_init_io() (if the app doesn't
+ * do it itself) avoiding setting the default function if it is not
+ * required.
+ */
+ png_set_write_fn(png_ptr, NULL, NULL, NULL);
+ }
+
+ return png_ptr;
+}
+
+
+/* Write a few rows of image data. If the image is interlaced,
+ * either you will have to write the 7 sub images, or, if you
+ * have called png_set_interlace_handling(), you will have to
+ * "write" the image seven times.
+ */
+void PNGAPI
+png_write_rows(png_structrp png_ptr, png_bytepp row,
+ png_uint_32 num_rows)
+{
+ png_uint_32 i; /* row counter */
+ png_bytepp rp; /* row pointer */
+
+ png_debug(1, "in png_write_rows");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* Loop through the rows */
+ for (i = 0, rp = row; i < num_rows; i++, rp++)
+ {
+ png_write_row(png_ptr, *rp);
+ }
+}
+
+/* Write the image. You only need to call this function once, even
+ * if you are writing an interlaced image.
+ */
+void PNGAPI
+png_write_image(png_structrp png_ptr, png_bytepp image)
+{
+ png_uint_32 i; /* row index */
+ int pass, num_pass; /* pass variables */
+ png_bytepp rp; /* points to current row */
+
+ if (png_ptr == NULL)
+ return;
+
+ png_debug(1, "in png_write_image");
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* Initialize interlace handling. If image is not interlaced,
+ * this will set pass to 1
+ */
+ num_pass = png_set_interlace_handling(png_ptr);
+#else
+ num_pass = 1;
+#endif
+ /* Loop through passes */
+ for (pass = 0; pass < num_pass; pass++)
+ {
+ /* Loop through image */
+ for (i = 0, rp = image; i < png_ptr->height; i++, rp++)
+ {
+ png_write_row(png_ptr, *rp);
+ }
+ }
+}
+
+/* Called by user to write a row of image data */
+void PNGAPI
+png_write_row(png_structrp png_ptr, png_const_bytep row)
+{
+ /* 1.5.6: moved from png_struct to be a local structure: */
+ png_row_info row_info;
+
+ if (png_ptr == NULL)
+ return;
+
+ png_debug2(1, "in png_write_row (row %u, pass %d)",
+ png_ptr->row_number, png_ptr->pass);
+
+ /* Initialize transformations and other stuff if first time */
+ if (png_ptr->row_number == 0 && png_ptr->pass == 0)
+ {
+ /* Make sure we wrote the header info */
+ if (!(png_ptr->mode & PNG_WROTE_INFO_BEFORE_PLTE))
+ png_error(png_ptr,
+ "png_write_info was never called before png_write_row");
+
+ /* Check for transforms that have been set but were defined out */
+#if !defined(PNG_WRITE_INVERT_SUPPORTED) && defined(PNG_READ_INVERT_SUPPORTED)
+ if (png_ptr->transformations & PNG_INVERT_MONO)
+ png_warning(png_ptr, "PNG_WRITE_INVERT_SUPPORTED is not defined");
+#endif
+
+#if !defined(PNG_WRITE_FILLER_SUPPORTED) && defined(PNG_READ_FILLER_SUPPORTED)
+ if (png_ptr->transformations & PNG_FILLER)
+ png_warning(png_ptr, "PNG_WRITE_FILLER_SUPPORTED is not defined");
+#endif
+#if !defined(PNG_WRITE_PACKSWAP_SUPPORTED) && \
+ defined(PNG_READ_PACKSWAP_SUPPORTED)
+ if (png_ptr->transformations & PNG_PACKSWAP)
+ png_warning(png_ptr,
+ "PNG_WRITE_PACKSWAP_SUPPORTED is not defined");
+#endif
+
+#if !defined(PNG_WRITE_PACK_SUPPORTED) && defined(PNG_READ_PACK_SUPPORTED)
+ if (png_ptr->transformations & PNG_PACK)
+ png_warning(png_ptr, "PNG_WRITE_PACK_SUPPORTED is not defined");
+#endif
+
+#if !defined(PNG_WRITE_SHIFT_SUPPORTED) && defined(PNG_READ_SHIFT_SUPPORTED)
+ if (png_ptr->transformations & PNG_SHIFT)
+ png_warning(png_ptr, "PNG_WRITE_SHIFT_SUPPORTED is not defined");
+#endif
+
+#if !defined(PNG_WRITE_BGR_SUPPORTED) && defined(PNG_READ_BGR_SUPPORTED)
+ if (png_ptr->transformations & PNG_BGR)
+ png_warning(png_ptr, "PNG_WRITE_BGR_SUPPORTED is not defined");
+#endif
+
+#if !defined(PNG_WRITE_SWAP_SUPPORTED) && defined(PNG_READ_SWAP_SUPPORTED)
+ if (png_ptr->transformations & PNG_SWAP_BYTES)
+ png_warning(png_ptr, "PNG_WRITE_SWAP_SUPPORTED is not defined");
+#endif
+
+ png_write_start_row(png_ptr);
+ }
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* If interlaced and not interested in row, return */
+ if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE))
+ {
+ switch (png_ptr->pass)
+ {
+ case 0:
+ if (png_ptr->row_number & 0x07)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 1:
+ if ((png_ptr->row_number & 0x07) || png_ptr->width < 5)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 2:
+ if ((png_ptr->row_number & 0x07) != 4)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 3:
+ if ((png_ptr->row_number & 0x03) || png_ptr->width < 3)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 4:
+ if ((png_ptr->row_number & 0x03) != 2)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 5:
+ if ((png_ptr->row_number & 0x01) || png_ptr->width < 2)
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ case 6:
+ if (!(png_ptr->row_number & 0x01))
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ break;
+
+ default: /* error: ignore it */
+ break;
+ }
+ }
+#endif
+
+ /* Set up row info for transformations */
+ row_info.color_type = png_ptr->color_type;
+ row_info.width = png_ptr->usr_width;
+ row_info.channels = png_ptr->usr_channels;
+ row_info.bit_depth = png_ptr->usr_bit_depth;
+ row_info.pixel_depth = (png_byte)(row_info.bit_depth * row_info.channels);
+ row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width);
+
+ png_debug1(3, "row_info->color_type = %d", row_info.color_type);
+ png_debug1(3, "row_info->width = %u", row_info.width);
+ png_debug1(3, "row_info->channels = %d", row_info.channels);
+ png_debug1(3, "row_info->bit_depth = %d", row_info.bit_depth);
+ png_debug1(3, "row_info->pixel_depth = %d", row_info.pixel_depth);
+ png_debug1(3, "row_info->rowbytes = %lu", (unsigned long)row_info.rowbytes);
+
+ /* Copy user's row into buffer, leaving room for filter byte. */
+ memcpy(png_ptr->row_buf + 1, row, row_info.rowbytes);
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* Handle interlacing */
+ if (png_ptr->interlaced && png_ptr->pass < 6 &&
+ (png_ptr->transformations & PNG_INTERLACE))
+ {
+ png_do_write_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass);
+ /* This should always get caught above, but still ... */
+ if (!(row_info.width))
+ {
+ png_write_finish_row(png_ptr);
+ return;
+ }
+ }
+#endif
+
+#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED
+ /* Handle other transformations */
+ if (png_ptr->transformations)
+ png_do_write_transformations(png_ptr, &row_info);
+#endif
+
+ /* At this point the row_info pixel depth must match the 'transformed' depth,
+ * which is also the output depth.
+ */
+ if (row_info.pixel_depth != png_ptr->pixel_depth ||
+ row_info.pixel_depth != png_ptr->transformed_pixel_depth)
+ png_error(png_ptr, "internal write transform logic error");
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ /* Write filter_method 64 (intrapixel differencing) only if
+ * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and
+ * 2. Libpng did not write a PNG signature (this filter_method is only
+ * used in PNG datastreams that are embedded in MNG datastreams) and
+ * 3. The application called png_permit_mng_features with a mask that
+ * included PNG_FLAG_MNG_FILTER_64 and
+ * 4. The filter_method is 64 and
+ * 5. The color_type is RGB or RGBA
+ */
+ if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) &&
+ (png_ptr->filter_type == PNG_INTRAPIXEL_DIFFERENCING))
+ {
+ /* Intrapixel differencing */
+ png_do_write_intrapixel(&row_info, png_ptr->row_buf + 1);
+ }
+#endif
+
+/* Added at libpng-1.5.10 */
+#ifdef PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED
+ /* Check for out-of-range palette index */
+ if (row_info.color_type == PNG_COLOR_TYPE_PALETTE &&
+ png_ptr->num_palette_max >= 0)
+ png_do_check_palette_indexes(png_ptr, &row_info);
+#endif
+
+ /* Find a filter if necessary, filter the row and write it out. */
+ png_write_find_filter(png_ptr, &row_info);
+
+ if (png_ptr->write_row_fn != NULL)
+ (*(png_ptr->write_row_fn))(png_ptr, png_ptr->row_number, png_ptr->pass);
+}
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+/* Set the automatic flush interval or 0 to turn flushing off */
+void PNGAPI
+png_set_flush(png_structrp png_ptr, int nrows)
+{
+ png_debug(1, "in png_set_flush");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->flush_dist = (nrows < 0 ? 0 : nrows);
+}
+
+/* Flush the current output buffers now */
+void PNGAPI
+png_write_flush(png_structrp png_ptr)
+{
+ png_debug(1, "in png_write_flush");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* We have already written out all of the data */
+ if (png_ptr->row_number >= png_ptr->num_rows)
+ return;
+
+ png_compress_IDAT(png_ptr, NULL, 0, Z_SYNC_FLUSH);
+ png_ptr->flush_rows = 0;
+ png_flush(png_ptr);
+}
+#endif /* PNG_WRITE_FLUSH_SUPPORTED */
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+static void png_reset_filter_heuristics(png_structrp png_ptr);/* forward decl */
+#endif
+
+/* Free any memory used in png_ptr struct without freeing the struct itself. */
+static void
+png_write_destroy(png_structrp png_ptr)
+{
+ png_debug(1, "in png_write_destroy");
+
+ /* Free any memory zlib uses */
+ if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED)
+ deflateEnd(&png_ptr->zstream);
+
+ /* Free our memory. png_free checks NULL for us. */
+ png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list);
+ png_free(png_ptr, png_ptr->row_buf);
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ png_free(png_ptr, png_ptr->prev_row);
+ png_free(png_ptr, png_ptr->sub_row);
+ png_free(png_ptr, png_ptr->up_row);
+ png_free(png_ptr, png_ptr->avg_row);
+ png_free(png_ptr, png_ptr->paeth_row);
+#endif
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ /* Use this to save a little code space, it doesn't free the filter_costs */
+ png_reset_filter_heuristics(png_ptr);
+ png_free(png_ptr, png_ptr->filter_costs);
+ png_free(png_ptr, png_ptr->inv_filter_costs);
+#endif
+
+#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED
+ png_free(png_ptr, png_ptr->chunk_list);
+#endif
+
+ /* The error handling and memory handling information is left intact at this
+ * point: the jmp_buf may still have to be freed. See png_destroy_png_struct
+ * for how this happens.
+ */
+}
+
+/* Free all memory used by the write.
+ * In libpng 1.6.0 this API changed quietly to no longer accept a NULL value for
+ * *png_ptr_ptr. Prior to 1.6.0 it would accept such a value and it would free
+ * the passed in info_structs but it would quietly fail to free any of the data
+ * inside them. In 1.6.0 it quietly does nothing (it has to be quiet because it
+ * has no png_ptr.)
+ */
+void PNGAPI
+png_destroy_write_struct(png_structpp png_ptr_ptr, png_infopp info_ptr_ptr)
+{
+ png_debug(1, "in png_destroy_write_struct");
+
+ if (png_ptr_ptr != NULL)
+ {
+ png_structrp png_ptr = *png_ptr_ptr;
+
+ if (png_ptr != NULL) /* added in libpng 1.6.0 */
+ {
+ png_destroy_info_struct(png_ptr, info_ptr_ptr);
+
+ *png_ptr_ptr = NULL;
+ png_write_destroy(png_ptr);
+ png_destroy_png_struct(png_ptr);
+ }
+ }
+}
+
+/* Allow the application to select one or more row filters to use. */
+void PNGAPI
+png_set_filter(png_structrp png_ptr, int method, int filters)
+{
+ png_debug(1, "in png_set_filter");
+
+ if (png_ptr == NULL)
+ return;
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) &&
+ (method == PNG_INTRAPIXEL_DIFFERENCING))
+ method = PNG_FILTER_TYPE_BASE;
+
+#endif
+ if (method == PNG_FILTER_TYPE_BASE)
+ {
+ switch (filters & (PNG_ALL_FILTERS | 0x07))
+ {
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ case 5:
+ case 6:
+ case 7: png_app_error(png_ptr, "Unknown row filter for method 0");
+ /* FALL THROUGH */
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+ case PNG_FILTER_VALUE_NONE:
+ png_ptr->do_filter = PNG_FILTER_NONE; break;
+
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ case PNG_FILTER_VALUE_SUB:
+ png_ptr->do_filter = PNG_FILTER_SUB; break;
+
+ case PNG_FILTER_VALUE_UP:
+ png_ptr->do_filter = PNG_FILTER_UP; break;
+
+ case PNG_FILTER_VALUE_AVG:
+ png_ptr->do_filter = PNG_FILTER_AVG; break;
+
+ case PNG_FILTER_VALUE_PAETH:
+ png_ptr->do_filter = PNG_FILTER_PAETH; break;
+
+ default:
+ png_ptr->do_filter = (png_byte)filters; break;
+#else
+ default:
+ png_app_error(png_ptr, "Unknown row filter for method 0");
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+ }
+
+ /* If we have allocated the row_buf, this means we have already started
+ * with the image and we should have allocated all of the filter buffers
+ * that have been selected. If prev_row isn't already allocated, then
+ * it is too late to start using the filters that need it, since we
+ * will be missing the data in the previous row. If an application
+ * wants to start and stop using particular filters during compression,
+ * it should start out with all of the filters, and then add and
+ * remove them after the start of compression.
+ */
+ if (png_ptr->row_buf != NULL)
+ {
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ if ((png_ptr->do_filter & PNG_FILTER_SUB) && png_ptr->sub_row == NULL)
+ {
+ png_ptr->sub_row = (png_bytep)png_malloc(png_ptr,
+ (png_ptr->rowbytes + 1));
+ png_ptr->sub_row[0] = PNG_FILTER_VALUE_SUB;
+ }
+
+ if ((png_ptr->do_filter & PNG_FILTER_UP) && png_ptr->up_row == NULL)
+ {
+ if (png_ptr->prev_row == NULL)
+ {
+ png_warning(png_ptr, "Can't add Up filter after starting");
+ png_ptr->do_filter = (png_byte)(png_ptr->do_filter &
+ ~PNG_FILTER_UP);
+ }
+
+ else
+ {
+ png_ptr->up_row = (png_bytep)png_malloc(png_ptr,
+ (png_ptr->rowbytes + 1));
+ png_ptr->up_row[0] = PNG_FILTER_VALUE_UP;
+ }
+ }
+
+ if ((png_ptr->do_filter & PNG_FILTER_AVG) && png_ptr->avg_row == NULL)
+ {
+ if (png_ptr->prev_row == NULL)
+ {
+ png_warning(png_ptr, "Can't add Average filter after starting");
+ png_ptr->do_filter = (png_byte)(png_ptr->do_filter &
+ ~PNG_FILTER_AVG);
+ }
+
+ else
+ {
+ png_ptr->avg_row = (png_bytep)png_malloc(png_ptr,
+ (png_ptr->rowbytes + 1));
+ png_ptr->avg_row[0] = PNG_FILTER_VALUE_AVG;
+ }
+ }
+
+ if ((png_ptr->do_filter & PNG_FILTER_PAETH) &&
+ png_ptr->paeth_row == NULL)
+ {
+ if (png_ptr->prev_row == NULL)
+ {
+ png_warning(png_ptr, "Can't add Paeth filter after starting");
+ png_ptr->do_filter &= (png_byte)(~PNG_FILTER_PAETH);
+ }
+
+ else
+ {
+ png_ptr->paeth_row = (png_bytep)png_malloc(png_ptr,
+ (png_ptr->rowbytes + 1));
+ png_ptr->paeth_row[0] = PNG_FILTER_VALUE_PAETH;
+ }
+ }
+
+ if (png_ptr->do_filter == PNG_NO_FILTERS)
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+ png_ptr->do_filter = PNG_FILTER_NONE;
+ }
+ }
+ else
+ png_error(png_ptr, "Unknown custom filter method");
+}
+
+/* This allows us to influence the way in which libpng chooses the "best"
+ * filter for the current scanline. While the "minimum-sum-of-absolute-
+ * differences metric is relatively fast and effective, there is some
+ * question as to whether it can be improved upon by trying to keep the
+ * filtered data going to zlib more consistent, hopefully resulting in
+ * better compression.
+ */
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED /* GRR 970116 */
+/* Convenience reset API. */
+static void
+png_reset_filter_heuristics(png_structrp png_ptr)
+{
+ /* Clear out any old values in the 'weights' - this must be done because if
+ * the app calls set_filter_heuristics multiple times with different
+ * 'num_weights' values we would otherwise potentially have wrong sized
+ * arrays.
+ */
+ png_ptr->num_prev_filters = 0;
+ png_ptr->heuristic_method = PNG_FILTER_HEURISTIC_UNWEIGHTED;
+ if (png_ptr->prev_filters != NULL)
+ {
+ png_bytep old = png_ptr->prev_filters;
+ png_ptr->prev_filters = NULL;
+ png_free(png_ptr, old);
+ }
+ if (png_ptr->filter_weights != NULL)
+ {
+ png_uint_16p old = png_ptr->filter_weights;
+ png_ptr->filter_weights = NULL;
+ png_free(png_ptr, old);
+ }
+
+ if (png_ptr->inv_filter_weights != NULL)
+ {
+ png_uint_16p old = png_ptr->inv_filter_weights;
+ png_ptr->inv_filter_weights = NULL;
+ png_free(png_ptr, old);
+ }
+
+ /* Leave the filter_costs - this array is fixed size. */
+}
+
+static int
+png_init_filter_heuristics(png_structrp png_ptr, int heuristic_method,
+ int num_weights)
+{
+ if (png_ptr == NULL)
+ return 0;
+
+ /* Clear out the arrays */
+ png_reset_filter_heuristics(png_ptr);
+
+ /* Check arguments; the 'reset' function makes the correct settings for the
+ * unweighted case, but we must handle the weight case by initializing the
+ * arrays for the caller.
+ */
+ if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int i;
+
+ if (num_weights > 0)
+ {
+ png_ptr->prev_filters = (png_bytep)png_malloc(png_ptr,
+ (png_uint_32)((sizeof (png_byte)) * num_weights));
+
+ /* To make sure that the weighting starts out fairly */
+ for (i = 0; i < num_weights; i++)
+ {
+ png_ptr->prev_filters[i] = 255;
+ }
+
+ png_ptr->filter_weights = (png_uint_16p)png_malloc(png_ptr,
+ (png_uint_32)((sizeof (png_uint_16)) * num_weights));
+
+ png_ptr->inv_filter_weights = (png_uint_16p)png_malloc(png_ptr,
+ (png_uint_32)((sizeof (png_uint_16)) * num_weights));
+
+ for (i = 0; i < num_weights; i++)
+ {
+ png_ptr->inv_filter_weights[i] =
+ png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR;
+ }
+
+ /* Safe to set this now */
+ png_ptr->num_prev_filters = (png_byte)num_weights;
+ }
+
+ /* If, in the future, there are other filter methods, this would
+ * need to be based on png_ptr->filter.
+ */
+ if (png_ptr->filter_costs == NULL)
+ {
+ png_ptr->filter_costs = (png_uint_16p)png_malloc(png_ptr,
+ (png_uint_32)((sizeof (png_uint_16)) * PNG_FILTER_VALUE_LAST));
+
+ png_ptr->inv_filter_costs = (png_uint_16p)png_malloc(png_ptr,
+ (png_uint_32)((sizeof (png_uint_16)) * PNG_FILTER_VALUE_LAST));
+ }
+
+ for (i = 0; i < PNG_FILTER_VALUE_LAST; i++)
+ {
+ png_ptr->inv_filter_costs[i] =
+ png_ptr->filter_costs[i] = PNG_COST_FACTOR;
+ }
+
+ /* All the arrays are inited, safe to set this: */
+ png_ptr->heuristic_method = PNG_FILTER_HEURISTIC_WEIGHTED;
+
+ /* Return the 'ok' code. */
+ return 1;
+ }
+ else if (heuristic_method == PNG_FILTER_HEURISTIC_DEFAULT ||
+ heuristic_method == PNG_FILTER_HEURISTIC_UNWEIGHTED)
+ {
+ return 1;
+ }
+ else
+ {
+ png_warning(png_ptr, "Unknown filter heuristic method");
+ return 0;
+ }
+}
+
+/* Provide floating and fixed point APIs */
+#ifdef PNG_FLOATING_POINT_SUPPORTED
+void PNGAPI
+png_set_filter_heuristics(png_structrp png_ptr, int heuristic_method,
+ int num_weights, png_const_doublep filter_weights,
+ png_const_doublep filter_costs)
+{
+ png_debug(1, "in png_set_filter_heuristics");
+
+ /* The internal API allocates all the arrays and ensures that the elements of
+ * those arrays are set to the default value.
+ */
+ if (!png_init_filter_heuristics(png_ptr, heuristic_method, num_weights))
+ return;
+
+ /* If using the weighted method copy in the weights. */
+ if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int i;
+ for (i = 0; i < num_weights; i++)
+ {
+ if (filter_weights[i] <= 0.0)
+ {
+ png_ptr->inv_filter_weights[i] =
+ png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR;
+ }
+
+ else
+ {
+ png_ptr->inv_filter_weights[i] =
+ (png_uint_16)(PNG_WEIGHT_FACTOR*filter_weights[i]+.5);
+
+ png_ptr->filter_weights[i] =
+ (png_uint_16)(PNG_WEIGHT_FACTOR/filter_weights[i]+.5);
+ }
+ }
+
+ /* Here is where we set the relative costs of the different filters. We
+ * should take the desired compression level into account when setting
+ * the costs, so that Paeth, for instance, has a high relative cost at low
+ * compression levels, while it has a lower relative cost at higher
+ * compression settings. The filter types are in order of increasing
+ * relative cost, so it would be possible to do this with an algorithm.
+ */
+ for (i = 0; i < PNG_FILTER_VALUE_LAST; i++) if (filter_costs[i] >= 1.0)
+ {
+ png_ptr->inv_filter_costs[i] =
+ (png_uint_16)(PNG_COST_FACTOR / filter_costs[i] + .5);
+
+ png_ptr->filter_costs[i] =
+ (png_uint_16)(PNG_COST_FACTOR * filter_costs[i] + .5);
+ }
+ }
+}
+#endif /* FLOATING_POINT */
+
+#ifdef PNG_FIXED_POINT_SUPPORTED
+void PNGAPI
+png_set_filter_heuristics_fixed(png_structrp png_ptr, int heuristic_method,
+ int num_weights, png_const_fixed_point_p filter_weights,
+ png_const_fixed_point_p filter_costs)
+{
+ png_debug(1, "in png_set_filter_heuristics_fixed");
+
+ /* The internal API allocates all the arrays and ensures that the elements of
+ * those arrays are set to the default value.
+ */
+ if (!png_init_filter_heuristics(png_ptr, heuristic_method, num_weights))
+ return;
+
+ /* If using the weighted method copy in the weights. */
+ if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int i;
+ for (i = 0; i < num_weights; i++)
+ {
+ if (filter_weights[i] <= 0)
+ {
+ png_ptr->inv_filter_weights[i] =
+ png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR;
+ }
+
+ else
+ {
+ png_ptr->inv_filter_weights[i] = (png_uint_16)
+ ((PNG_WEIGHT_FACTOR*filter_weights[i]+PNG_FP_HALF)/PNG_FP_1);
+
+ png_ptr->filter_weights[i] = (png_uint_16)((PNG_WEIGHT_FACTOR*
+ PNG_FP_1+(filter_weights[i]/2))/filter_weights[i]);
+ }
+ }
+
+ /* Here is where we set the relative costs of the different filters. We
+ * should take the desired compression level into account when setting
+ * the costs, so that Paeth, for instance, has a high relative cost at low
+ * compression levels, while it has a lower relative cost at higher
+ * compression settings. The filter types are in order of increasing
+ * relative cost, so it would be possible to do this with an algorithm.
+ */
+ for (i = 0; i < PNG_FILTER_VALUE_LAST; i++)
+ if (filter_costs[i] >= PNG_FP_1)
+ {
+ png_uint_32 tmp;
+
+ /* Use a 32 bit unsigned temporary here because otherwise the
+ * intermediate value will be a 32 bit *signed* integer (ANSI rules)
+ * and this will get the wrong answer on division.
+ */
+ tmp = PNG_COST_FACTOR*PNG_FP_1 + (filter_costs[i]/2);
+ tmp /= filter_costs[i];
+
+ png_ptr->inv_filter_costs[i] = (png_uint_16)tmp;
+
+ tmp = PNG_COST_FACTOR * filter_costs[i] + PNG_FP_HALF;
+ tmp /= PNG_FP_1;
+
+ png_ptr->filter_costs[i] = (png_uint_16)tmp;
+ }
+ }
+}
+#endif /* FIXED_POINT */
+#endif /* PNG_WRITE_WEIGHTED_FILTER_SUPPORTED */
+
+void PNGAPI
+png_set_compression_level(png_structrp png_ptr, int level)
+{
+ png_debug(1, "in png_set_compression_level");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->zlib_level = level;
+}
+
+void PNGAPI
+png_set_compression_mem_level(png_structrp png_ptr, int mem_level)
+{
+ png_debug(1, "in png_set_compression_mem_level");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->zlib_mem_level = mem_level;
+}
+
+void PNGAPI
+png_set_compression_strategy(png_structrp png_ptr, int strategy)
+{
+ png_debug(1, "in png_set_compression_strategy");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* The flag setting here prevents the libpng dynamic selection of strategy.
+ */
+ png_ptr->flags |= PNG_FLAG_ZLIB_CUSTOM_STRATEGY;
+ png_ptr->zlib_strategy = strategy;
+}
+
+/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a
+ * smaller value of window_bits if it can do so safely.
+ */
+void PNGAPI
+png_set_compression_window_bits(png_structrp png_ptr, int window_bits)
+{
+ if (png_ptr == NULL)
+ return;
+
+ /* Prior to 1.6.0 this would warn but then set the window_bits value, this
+ * meant that negative window bits values could be selected which would cause
+ * libpng to write a non-standard PNG file with raw deflate or gzip
+ * compressed IDAT or ancillary chunks. Such files can be read and there is
+ * no warning on read, so this seems like a very bad idea.
+ */
+ if (window_bits > 15)
+ {
+ png_warning(png_ptr, "Only compression windows <= 32k supported by PNG");
+ window_bits = 15;
+ }
+
+ else if (window_bits < 8)
+ {
+ png_warning(png_ptr, "Only compression windows >= 256 supported by PNG");
+ window_bits = 8;
+ }
+
+ png_ptr->zlib_window_bits = window_bits;
+}
+
+void PNGAPI
+png_set_compression_method(png_structrp png_ptr, int method)
+{
+ png_debug(1, "in png_set_compression_method");
+
+ if (png_ptr == NULL)
+ return;
+
+ /* This would produce an invalid PNG file if it worked, but it doesn't and
+ * deflate will fault it, so it is harmless to just warn here.
+ */
+ if (method != 8)
+ png_warning(png_ptr, "Only compression method 8 is supported by PNG");
+
+ png_ptr->zlib_method = method;
+}
+
+/* The following were added to libpng-1.5.4 */
+#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED
+void PNGAPI
+png_set_text_compression_level(png_structrp png_ptr, int level)
+{
+ png_debug(1, "in png_set_text_compression_level");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->zlib_text_level = level;
+}
+
+void PNGAPI
+png_set_text_compression_mem_level(png_structrp png_ptr, int mem_level)
+{
+ png_debug(1, "in png_set_text_compression_mem_level");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->zlib_text_mem_level = mem_level;
+}
+
+void PNGAPI
+png_set_text_compression_strategy(png_structrp png_ptr, int strategy)
+{
+ png_debug(1, "in png_set_text_compression_strategy");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->zlib_text_strategy = strategy;
+}
+
+/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a
+ * smaller value of window_bits if it can do so safely.
+ */
+void PNGAPI
+png_set_text_compression_window_bits(png_structrp png_ptr, int window_bits)
+{
+ if (png_ptr == NULL)
+ return;
+
+ if (window_bits > 15)
+ {
+ png_warning(png_ptr, "Only compression windows <= 32k supported by PNG");
+ window_bits = 15;
+ }
+
+ else if (window_bits < 8)
+ {
+ png_warning(png_ptr, "Only compression windows >= 256 supported by PNG");
+ window_bits = 8;
+ }
+
+ png_ptr->zlib_text_window_bits = window_bits;
+}
+
+void PNGAPI
+png_set_text_compression_method(png_structrp png_ptr, int method)
+{
+ png_debug(1, "in png_set_text_compression_method");
+
+ if (png_ptr == NULL)
+ return;
+
+ if (method != 8)
+ png_warning(png_ptr, "Only compression method 8 is supported by PNG");
+
+ png_ptr->zlib_text_method = method;
+}
+#endif /* PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED */
+/* end of API added to libpng-1.5.4 */
+
+void PNGAPI
+png_set_write_status_fn(png_structrp png_ptr, png_write_status_ptr write_row_fn)
+{
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->write_row_fn = write_row_fn;
+}
+
+#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED
+void PNGAPI
+png_set_write_user_transform_fn(png_structrp png_ptr, png_user_transform_ptr
+ write_user_transform_fn)
+{
+ png_debug(1, "in png_set_write_user_transform_fn");
+
+ if (png_ptr == NULL)
+ return;
+
+ png_ptr->transformations |= PNG_USER_TRANSFORM;
+ png_ptr->write_user_transform_fn = write_user_transform_fn;
+}
+#endif
+
+
+#ifdef PNG_INFO_IMAGE_SUPPORTED
+void PNGAPI
+png_write_png(png_structrp png_ptr, png_inforp info_ptr,
+ int transforms, voidp params)
+{
+ if (png_ptr == NULL || info_ptr == NULL)
+ return;
+
+ /* Write the file header information. */
+ png_write_info(png_ptr, info_ptr);
+
+ /* ------ these transformations don't touch the info structure ------- */
+
+#ifdef PNG_WRITE_INVERT_SUPPORTED
+ /* Invert monochrome pixels */
+ if (transforms & PNG_TRANSFORM_INVERT_MONO)
+ png_set_invert_mono(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_SHIFT_SUPPORTED
+ /* Shift the pixels up to a legal bit depth and fill in
+ * as appropriate to correctly scale the image.
+ */
+ if ((transforms & PNG_TRANSFORM_SHIFT)
+ && (info_ptr->valid & PNG_INFO_sBIT))
+ png_set_shift(png_ptr, &info_ptr->sig_bit);
+#endif
+
+#ifdef PNG_WRITE_PACK_SUPPORTED
+ /* Pack pixels into bytes */
+ if (transforms & PNG_TRANSFORM_PACKING)
+ png_set_packing(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED
+ /* Swap location of alpha bytes from ARGB to RGBA */
+ if (transforms & PNG_TRANSFORM_SWAP_ALPHA)
+ png_set_swap_alpha(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_FILLER_SUPPORTED
+ /* Pack XRGB/RGBX/ARGB/RGBA into RGB (4 channels -> 3 channels) */
+ if (transforms & PNG_TRANSFORM_STRIP_FILLER_AFTER)
+ png_set_filler(png_ptr, 0, PNG_FILLER_AFTER);
+
+ else if (transforms & PNG_TRANSFORM_STRIP_FILLER_BEFORE)
+ png_set_filler(png_ptr, 0, PNG_FILLER_BEFORE);
+#endif
+
+#ifdef PNG_WRITE_BGR_SUPPORTED
+ /* Flip BGR pixels to RGB */
+ if (transforms & PNG_TRANSFORM_BGR)
+ png_set_bgr(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_SWAP_SUPPORTED
+ /* Swap bytes of 16-bit files to most significant byte first */
+ if (transforms & PNG_TRANSFORM_SWAP_ENDIAN)
+ png_set_swap(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_PACKSWAP_SUPPORTED
+ /* Swap bits of 1, 2, 4 bit packed pixel formats */
+ if (transforms & PNG_TRANSFORM_PACKSWAP)
+ png_set_packswap(png_ptr);
+#endif
+
+#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED
+ /* Invert the alpha channel from opacity to transparency */
+ if (transforms & PNG_TRANSFORM_INVERT_ALPHA)
+ png_set_invert_alpha(png_ptr);
+#endif
+
+ /* ----------------------- end of transformations ------------------- */
+
+ /* Write the bits */
+ if (info_ptr->valid & PNG_INFO_IDAT)
+ png_write_image(png_ptr, info_ptr->row_pointers);
+
+ /* It is REQUIRED to call this to finish writing the rest of the file */
+ png_write_end(png_ptr, info_ptr);
+
+ PNG_UNUSED(transforms) /* Quiet compiler warnings */
+ PNG_UNUSED(params)
+}
+#endif
+
+
+#ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED
+#ifdef PNG_STDIO_SUPPORTED /* currently required for png_image_write_* */
+/* Initialize the write structure - general purpose utility. */
+static int
+png_image_write_init(png_imagep image)
+{
+ png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, image,
+ png_safe_error, png_safe_warning);
+
+ if (png_ptr != NULL)
+ {
+ png_infop info_ptr = png_create_info_struct(png_ptr);
+
+ if (info_ptr != NULL)
+ {
+ png_controlp control = png_voidcast(png_controlp,
+ png_malloc_warn(png_ptr, (sizeof *control)));
+
+ if (control != NULL)
+ {
+ memset(control, 0, (sizeof *control));
+
+ control->png_ptr = png_ptr;
+ control->info_ptr = info_ptr;
+ control->for_write = 1;
+
+ image->opaque = control;
+ return 1;
+ }
+
+ /* Error clean up */
+ png_destroy_info_struct(png_ptr, &info_ptr);
+ }
+
+ png_destroy_write_struct(&png_ptr, NULL);
+ }
+
+ return png_image_error(image, "png_image_write_: out of memory");
+}
+
+/* Arguments to png_image_write_main: */
+typedef struct
+{
+ /* Arguments: */
+ png_imagep image;
+ png_const_voidp buffer;
+ png_int_32 row_stride;
+ png_const_voidp colormap;
+ int convert_to_8bit;
+ /* Local variables: */
+ png_const_voidp first_row;
+ ptrdiff_t row_bytes;
+ png_voidp local_row;
+} png_image_write_control;
+
+/* Write png_uint_16 input to a 16-bit PNG; the png_ptr has already been set to
+ * do any necessary byte swapping. The component order is defined by the
+ * png_image format value.
+ */
+static int
+png_write_image_16bit(png_voidp argument)
+{
+ png_image_write_control *display = png_voidcast(png_image_write_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+
+ png_const_uint_16p input_row = png_voidcast(png_const_uint_16p,
+ display->first_row);
+ png_uint_16p output_row = png_voidcast(png_uint_16p, display->local_row);
+ png_uint_16p row_end;
+ const int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1;
+ int aindex = 0;
+ png_uint_32 y = image->height;
+
+ if (image->format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ if (image->format & PNG_FORMAT_FLAG_AFIRST)
+ {
+ aindex = -1;
+ ++input_row; /* To point to the first component */
+ ++output_row;
+ }
+
+ else
+ aindex = channels;
+ }
+
+ else
+ png_error(png_ptr, "png_write_image: internal call error");
+
+ /* Work out the output row end and count over this, note that the increment
+ * above to 'row' means that row_end can actually be beyond the end of the
+ * row; this is correct.
+ */
+ row_end = output_row + image->width * (channels+1);
+
+ while (y-- > 0)
+ {
+ png_const_uint_16p in_ptr = input_row;
+ png_uint_16p out_ptr = output_row;
+
+ while (out_ptr < row_end)
+ {
+ const png_uint_16 alpha = in_ptr[aindex];
+ png_uint_32 reciprocal = 0;
+ int c;
+
+ out_ptr[aindex] = alpha;
+
+ /* Calculate a reciprocal. The correct calculation is simply
+ * component/alpha*65535 << 15. (I.e. 15 bits of precision); this
+ * allows correct rounding by adding .5 before the shift. 'reciprocal'
+ * is only initialized when required.
+ */
+ if (alpha > 0 && alpha < 65535)
+ reciprocal = ((0xffff<<15)+(alpha>>1))/alpha;
+
+ c = channels;
+ do /* always at least one channel */
+ {
+ png_uint_16 component = *in_ptr++;
+
+ /* The following gives 65535 for an alpha of 0, which is fine,
+ * otherwise if 0/0 is represented as some other value there is more
+ * likely to be a discontinuity which will probably damage
+ * compression when moving from a fully transparent area to a
+ * nearly transparent one. (The assumption here is that opaque
+ * areas tend not to be 0 intensity.)
+ */
+ if (component >= alpha)
+ component = 65535;
+
+ /* component<alpha, so component/alpha is less than one and
+ * component*reciprocal is less than 2^31.
+ */
+ else if (component > 0 && alpha < 65535)
+ {
+ png_uint_32 calc = component * reciprocal;
+ calc += 16384; /* round to nearest */
+ component = (png_uint_16)(calc >> 15);
+ }
+
+ *out_ptr++ = component;
+ }
+ while (--c > 0);
+
+ /* Skip to next component (skip the intervening alpha channel) */
+ ++in_ptr;
+ ++out_ptr;
+ }
+
+ png_write_row(png_ptr, png_voidcast(png_const_bytep, display->local_row));
+ input_row += display->row_bytes/(sizeof (png_uint_16));
+ }
+
+ return 1;
+}
+
+/* Given 16-bit input (1 to 4 channels) write 8-bit output. If an alpha channel
+ * is present it must be removed from the components, the components are then
+ * written in sRGB encoding. No components are added or removed.
+ *
+ * Calculate an alpha reciprocal to reverse pre-multiplication. As above the
+ * calculation can be done to 15 bits of accuracy; however, the output needs to
+ * be scaled in the range 0..255*65535, so include that scaling here.
+ */
+#define UNP_RECIPROCAL(alpha) ((((0xffff*0xff)<<7)+(alpha>>1))/alpha)
+
+static png_byte
+png_unpremultiply(png_uint_32 component, png_uint_32 alpha,
+ png_uint_32 reciprocal/*from the above macro*/)
+{
+ /* The following gives 1.0 for an alpha of 0, which is fine, otherwise if 0/0
+ * is represented as some other value there is more likely to be a
+ * discontinuity which will probably damage compression when moving from a
+ * fully transparent area to a nearly transparent one. (The assumption here
+ * is that opaque areas tend not to be 0 intensity.)
+ *
+ * There is a rounding problem here; if alpha is less than 128 it will end up
+ * as 0 when scaled to 8 bits. To avoid introducing spurious colors into the
+ * output change for this too.
+ */
+ if (component >= alpha || alpha < 128)
+ return 255;
+
+ /* component<alpha, so component/alpha is less than one and
+ * component*reciprocal is less than 2^31.
+ */
+ else if (component > 0)
+ {
+ /* The test is that alpha/257 (rounded) is less than 255, the first value
+ * that becomes 255 is 65407.
+ * NOTE: this must agree with the PNG_DIV257 macro (which must, therefore,
+ * be exact!) [Could also test reciprocal != 0]
+ */
+ if (alpha < 65407)
+ {
+ component *= reciprocal;
+ component += 64; /* round to nearest */
+ component >>= 7;
+ }
+
+ else
+ component *= 255;
+
+ /* Convert the component to sRGB. */
+ return (png_byte)PNG_sRGB_FROM_LINEAR(component);
+ }
+
+ else
+ return 0;
+}
+
+static int
+png_write_image_8bit(png_voidp argument)
+{
+ png_image_write_control *display = png_voidcast(png_image_write_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+
+ png_const_uint_16p input_row = png_voidcast(png_const_uint_16p,
+ display->first_row);
+ png_bytep output_row = png_voidcast(png_bytep, display->local_row);
+ png_uint_32 y = image->height;
+ const int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1;
+
+ if (image->format & PNG_FORMAT_FLAG_ALPHA)
+ {
+ png_bytep row_end;
+ int aindex;
+
+ if (image->format & PNG_FORMAT_FLAG_AFIRST)
+ {
+ aindex = -1;
+ ++input_row; /* To point to the first component */
+ ++output_row;
+ }
+
+ else
+ aindex = channels;
+
+ /* Use row_end in place of a loop counter: */
+ row_end = output_row + image->width * (channels+1);
+
+ while (y-- > 0)
+ {
+ png_const_uint_16p in_ptr = input_row;
+ png_bytep out_ptr = output_row;
+
+ while (out_ptr < row_end)
+ {
+ png_uint_16 alpha = in_ptr[aindex];
+ png_byte alphabyte = (png_byte)PNG_DIV257(alpha);
+ png_uint_32 reciprocal = 0;
+ int c;
+
+ /* Scale and write the alpha channel. */
+ out_ptr[aindex] = alphabyte;
+
+ if (alphabyte > 0 && alphabyte < 255)
+ reciprocal = UNP_RECIPROCAL(alpha);
+
+ c = channels;
+ do /* always at least one channel */
+ *out_ptr++ = png_unpremultiply(*in_ptr++, alpha, reciprocal);
+ while (--c > 0);
+
+ /* Skip to next component (skip the intervening alpha channel) */
+ ++in_ptr;
+ ++out_ptr;
+ } /* while out_ptr < row_end */
+
+ png_write_row(png_ptr, png_voidcast(png_const_bytep,
+ display->local_row));
+ input_row += display->row_bytes/(sizeof (png_uint_16));
+ } /* while y */
+ }
+
+ else
+ {
+ /* No alpha channel, so the row_end really is the end of the row and it
+ * is sufficient to loop over the components one by one.
+ */
+ png_bytep row_end = output_row + image->width * channels;
+
+ while (y-- > 0)
+ {
+ png_const_uint_16p in_ptr = input_row;
+ png_bytep out_ptr = output_row;
+
+ while (out_ptr < row_end)
+ {
+ png_uint_32 component = *in_ptr++;
+
+ component *= 255;
+ *out_ptr++ = (png_byte)PNG_sRGB_FROM_LINEAR(component);
+ }
+
+ png_write_row(png_ptr, output_row);
+ input_row += display->row_bytes/(sizeof (png_uint_16));
+ }
+ }
+
+ return 1;
+}
+
+static void
+png_image_set_PLTE(png_image_write_control *display)
+{
+ const png_imagep image = display->image;
+ const void *cmap = display->colormap;
+ const int entries = image->colormap_entries > 256 ? 256 :
+ (int)image->colormap_entries;
+
+ /* NOTE: the caller must check for cmap != NULL and entries != 0 */
+ const png_uint_32 format = image->format;
+ const int channels = PNG_IMAGE_SAMPLE_CHANNELS(format);
+
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ const int afirst = (format & PNG_FORMAT_FLAG_AFIRST) != 0 &&
+ (format & PNG_FORMAT_FLAG_ALPHA) != 0;
+# else
+# define afirst 0
+# endif
+
+# ifdef PNG_FORMAT_BGR_SUPPORTED
+ const int bgr = (format & PNG_FORMAT_FLAG_BGR) ? 2 : 0;
+# else
+# define bgr 0
+# endif
+
+ int i, num_trans;
+ png_color palette[256];
+ png_byte tRNS[256];
+
+ memset(tRNS, 255, (sizeof tRNS));
+ memset(palette, 0, (sizeof palette));
+
+ for (i=num_trans=0; i<entries; ++i)
+ {
+ /* This gets automatically converted to sRGB with reversal of the
+ * pre-multiplication if the color-map has an alpha channel.
+ */
+ if (format & PNG_FORMAT_FLAG_LINEAR)
+ {
+ png_const_uint_16p entry = png_voidcast(png_const_uint_16p, cmap);
+
+ entry += i * channels;
+
+ if (channels & 1) /* no alpha */
+ {
+ if (channels >= 3) /* RGB */
+ {
+ palette[i].blue = (png_byte)PNG_sRGB_FROM_LINEAR(255 *
+ entry[(2 ^ bgr)]);
+ palette[i].green = (png_byte)PNG_sRGB_FROM_LINEAR(255 *
+ entry[1]);
+ palette[i].red = (png_byte)PNG_sRGB_FROM_LINEAR(255 *
+ entry[bgr]);
+ }
+
+ else /* Gray */
+ palette[i].blue = palette[i].red = palette[i].green =
+ (png_byte)PNG_sRGB_FROM_LINEAR(255 * *entry);
+ }
+
+ else /* alpha */
+ {
+ png_uint_16 alpha = entry[afirst ? 0 : channels-1];
+ png_byte alphabyte = (png_byte)PNG_DIV257(alpha);
+ png_uint_32 reciprocal = 0;
+
+ /* Calculate a reciprocal, as in the png_write_image_8bit code above
+ * this is designed to produce a value scaled to 255*65535 when
+ * divided by 128 (i.e. asr 7).
+ */
+ if (alphabyte > 0 && alphabyte < 255)
+ reciprocal = (((0xffff*0xff)<<7)+(alpha>>1))/alpha;
+
+ tRNS[i] = alphabyte;
+ if (alphabyte < 255)
+ num_trans = i+1;
+
+ if (channels >= 3) /* RGB */
+ {
+ palette[i].blue = png_unpremultiply(entry[afirst + (2 ^ bgr)],
+ alpha, reciprocal);
+ palette[i].green = png_unpremultiply(entry[afirst + 1], alpha,
+ reciprocal);
+ palette[i].red = png_unpremultiply(entry[afirst + bgr], alpha,
+ reciprocal);
+ }
+
+ else /* gray */
+ palette[i].blue = palette[i].red = palette[i].green =
+ png_unpremultiply(entry[afirst], alpha, reciprocal);
+ }
+ }
+
+ else /* Color-map has sRGB values */
+ {
+ png_const_bytep entry = png_voidcast(png_const_bytep, cmap);
+
+ entry += i * channels;
+
+ switch (channels)
+ {
+ case 4:
+ tRNS[i] = entry[afirst ? 0 : 3];
+ if (tRNS[i] < 255)
+ num_trans = i+1;
+ /* FALL THROUGH */
+ case 3:
+ palette[i].blue = entry[afirst + (2 ^ bgr)];
+ palette[i].green = entry[afirst + 1];
+ palette[i].red = entry[afirst + bgr];
+ break;
+
+ case 2:
+ tRNS[i] = entry[1 ^ afirst];
+ if (tRNS[i] < 255)
+ num_trans = i+1;
+ /* FALL THROUGH */
+ case 1:
+ palette[i].blue = palette[i].red = palette[i].green =
+ entry[afirst];
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+
+# ifdef afirst
+# undef afirst
+# endif
+# ifdef bgr
+# undef bgr
+# endif
+
+ png_set_PLTE(image->opaque->png_ptr, image->opaque->info_ptr, palette,
+ entries);
+
+ if (num_trans > 0)
+ png_set_tRNS(image->opaque->png_ptr, image->opaque->info_ptr, tRNS,
+ num_trans, NULL);
+
+ image->colormap_entries = entries;
+}
+
+static int
+png_image_write_main(png_voidp argument)
+{
+ png_image_write_control *display = png_voidcast(png_image_write_control*,
+ argument);
+ png_imagep image = display->image;
+ png_structrp png_ptr = image->opaque->png_ptr;
+ png_inforp info_ptr = image->opaque->info_ptr;
+ png_uint_32 format = image->format;
+
+ int colormap = (format & PNG_FORMAT_FLAG_COLORMAP) != 0;
+ int linear = !colormap && (format & PNG_FORMAT_FLAG_LINEAR) != 0; /* input */
+ int alpha = !colormap && (format & PNG_FORMAT_FLAG_ALPHA) != 0;
+ int write_16bit = linear && !colormap && !display->convert_to_8bit;
+
+# ifdef PNG_BENIGN_ERRORS_SUPPORTED
+ /* Make sure we error out on any bad situation */
+ png_set_benign_errors(png_ptr, 0/*error*/);
+# endif
+
+ /* Default the 'row_stride' parameter if required. */
+ if (display->row_stride == 0)
+ display->row_stride = PNG_IMAGE_ROW_STRIDE(*image);
+
+ /* Set the required transforms then write the rows in the correct order. */
+ if (format & PNG_FORMAT_FLAG_COLORMAP)
+ {
+ if (display->colormap != NULL && image->colormap_entries > 0)
+ {
+ png_uint_32 entries = image->colormap_entries;
+
+ png_set_IHDR(png_ptr, info_ptr, image->width, image->height,
+ entries > 16 ? 8 : (entries > 4 ? 4 : (entries > 2 ? 2 : 1)),
+ PNG_COLOR_TYPE_PALETTE, PNG_INTERLACE_NONE,
+ PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);
+
+ png_image_set_PLTE(display);
+ }
+
+ else
+ png_error(image->opaque->png_ptr,
+ "no color-map for color-mapped image");
+ }
+
+ else
+ png_set_IHDR(png_ptr, info_ptr, image->width, image->height,
+ write_16bit ? 16 : 8,
+ ((format & PNG_FORMAT_FLAG_COLOR) ? PNG_COLOR_MASK_COLOR : 0) +
+ ((format & PNG_FORMAT_FLAG_ALPHA) ? PNG_COLOR_MASK_ALPHA : 0),
+ PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);
+
+ /* Counter-intuitively the data transformations must be called *after*
+ * png_write_info, not before as in the read code, but the 'set' functions
+ * must still be called before. Just set the color space information, never
+ * write an interlaced image.
+ */
+
+ if (write_16bit)
+ {
+ /* The gamma here is 1.0 (linear) and the cHRM chunk matches sRGB. */
+ png_set_gAMA_fixed(png_ptr, info_ptr, PNG_GAMMA_LINEAR);
+
+ if (!(image->flags & PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB))
+ png_set_cHRM_fixed(png_ptr, info_ptr,
+ /* color x y */
+ /* white */ 31270, 32900,
+ /* red */ 64000, 33000,
+ /* green */ 30000, 60000,
+ /* blue */ 15000, 6000
+ );
+ }
+
+ else if (!(image->flags & PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB))
+ png_set_sRGB(png_ptr, info_ptr, PNG_sRGB_INTENT_PERCEPTUAL);
+
+ /* Else writing an 8-bit file and the *colors* aren't sRGB, but the 8-bit
+ * space must still be gamma encoded.
+ */
+ else
+ png_set_gAMA_fixed(png_ptr, info_ptr, PNG_GAMMA_sRGB_INVERSE);
+
+ /* Write the file header. */
+ png_write_info(png_ptr, info_ptr);
+
+ /* Now set up the data transformations (*after* the header is written),
+ * remove the handled transformations from the 'format' flags for checking.
+ *
+ * First check for a little endian system if writing 16 bit files.
+ */
+ if (write_16bit)
+ {
+ PNG_CONST png_uint_16 le = 0x0001;
+
+ if (*(png_const_bytep)&le)
+ png_set_swap(png_ptr);
+ }
+
+# ifdef PNG_SIMPLIFIED_WRITE_BGR_SUPPORTED
+ if (format & PNG_FORMAT_FLAG_BGR)
+ {
+ if (!colormap && (format & PNG_FORMAT_FLAG_COLOR) != 0)
+ png_set_bgr(png_ptr);
+ format &= ~PNG_FORMAT_FLAG_BGR;
+ }
+# endif
+
+# ifdef PNG_SIMPLIFIED_WRITE_AFIRST_SUPPORTED
+ if (format & PNG_FORMAT_FLAG_AFIRST)
+ {
+ if (!colormap && (format & PNG_FORMAT_FLAG_ALPHA) != 0)
+ png_set_swap_alpha(png_ptr);
+ format &= ~PNG_FORMAT_FLAG_AFIRST;
+ }
+# endif
+
+ /* If there are 16 or fewer color-map entries we wrote a lower bit depth
+ * above, but the application data is still byte packed.
+ */
+ if (colormap && image->colormap_entries <= 16)
+ png_set_packing(png_ptr);
+
+ /* That should have handled all (both) the transforms. */
+ if ((format & ~(png_uint_32)(PNG_FORMAT_FLAG_COLOR | PNG_FORMAT_FLAG_LINEAR |
+ PNG_FORMAT_FLAG_ALPHA | PNG_FORMAT_FLAG_COLORMAP)) != 0)
+ png_error(png_ptr, "png_write_image: unsupported transformation");
+
+ {
+ png_const_bytep row = png_voidcast(png_const_bytep, display->buffer);
+ ptrdiff_t row_bytes = display->row_stride;
+
+ if (linear)
+ row_bytes *= (sizeof (png_uint_16));
+
+ if (row_bytes < 0)
+ row += (image->height-1) * (-row_bytes);
+
+ display->first_row = row;
+ display->row_bytes = row_bytes;
+ }
+
+ /* Apply 'fast' options if the flag is set. */
+ if ((image->flags & PNG_IMAGE_FLAG_FAST) != 0)
+ {
+ png_set_filter(png_ptr, PNG_FILTER_TYPE_BASE, PNG_NO_FILTERS);
+ /* NOTE: determined by experiment using pngstest, this reflects some
+ * balance between the time to write the image once and the time to read
+ * it about 50 times. The speed-up in pngstest was about 10-20% of the
+ * total (user) time on a heavily loaded system.
+ */
+ png_set_compression_level(png_ptr, 3);
+ }
+
+ /* Check for the cases that currently require a pre-transform on the row
+ * before it is written. This only applies when the input is 16-bit and
+ * either there is an alpha channel or it is converted to 8-bit.
+ */
+ if ((linear && alpha) || (!colormap && display->convert_to_8bit))
+ {
+ png_bytep row = png_voidcast(png_bytep, png_malloc(png_ptr,
+ png_get_rowbytes(png_ptr, info_ptr)));
+ int result;
+
+ display->local_row = row;
+ if (write_16bit)
+ result = png_safe_execute(image, png_write_image_16bit, display);
+ else
+ result = png_safe_execute(image, png_write_image_8bit, display);
+ display->local_row = NULL;
+
+ png_free(png_ptr, row);
+
+ /* Skip the 'write_end' on error: */
+ if (!result)
+ return 0;
+ }
+
+ /* Otherwise this is the case where the input is in a format currently
+ * supported by the rest of the libpng write code; call it directly.
+ */
+ else
+ {
+ png_const_bytep row = png_voidcast(png_const_bytep, display->first_row);
+ ptrdiff_t row_bytes = display->row_bytes;
+ png_uint_32 y = image->height;
+
+ while (y-- > 0)
+ {
+ png_write_row(png_ptr, row);
+ row += row_bytes;
+ }
+ }
+
+ png_write_end(png_ptr, info_ptr);
+ return 1;
+}
+
+int PNGAPI
+png_image_write_to_stdio(png_imagep image, FILE *file, int convert_to_8bit,
+ const void *buffer, png_int_32 row_stride, const void *colormap)
+{
+ /* Write the image to the given (FILE*). */
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ if (file != NULL)
+ {
+ if (png_image_write_init(image))
+ {
+ png_image_write_control display;
+ int result;
+
+ /* This is slightly evil, but png_init_io doesn't do anything other
+ * than this and we haven't changed the standard IO functions so
+ * this saves a 'safe' function.
+ */
+ image->opaque->png_ptr->io_ptr = file;
+
+ memset(&display, 0, (sizeof display));
+ display.image = image;
+ display.buffer = buffer;
+ display.row_stride = row_stride;
+ display.colormap = colormap;
+ display.convert_to_8bit = convert_to_8bit;
+
+ result = png_safe_execute(image, png_image_write_main, &display);
+ png_image_free(image);
+ return result;
+ }
+
+ else
+ return 0;
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_write_to_stdio: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_write_to_stdio: incorrect PNG_IMAGE_VERSION");
+
+ else
+ return 0;
+}
+
+int PNGAPI
+png_image_write_to_file(png_imagep image, const char *file_name,
+ int convert_to_8bit, const void *buffer, png_int_32 row_stride,
+ const void *colormap)
+{
+ /* Write the image to the named file. */
+ if (image != NULL && image->version == PNG_IMAGE_VERSION)
+ {
+ if (file_name != NULL)
+ {
+ FILE *fp = fopen(file_name, "wb");
+
+ if (fp != NULL)
+ {
+ if (png_image_write_to_stdio(image, fp, convert_to_8bit, buffer,
+ row_stride, colormap))
+ {
+ int error; /* from fflush/fclose */
+
+ /* Make sure the file is flushed correctly. */
+ if (fflush(fp) == 0 && ferror(fp) == 0)
+ {
+ if (fclose(fp) == 0)
+ return 1;
+
+ error = errno; /* from fclose */
+ }
+
+ else
+ {
+ error = errno; /* from fflush or ferror */
+ (void)fclose(fp);
+ }
+
+ (void)remove(file_name);
+ /* The image has already been cleaned up; this is just used to
+ * set the error (because the original write succeeded).
+ */
+ return png_image_error(image, strerror(error));
+ }
+
+ else
+ {
+ /* Clean up: just the opened file. */
+ (void)fclose(fp);
+ (void)remove(file_name);
+ return 0;
+ }
+ }
+
+ else
+ return png_image_error(image, strerror(errno));
+ }
+
+ else
+ return png_image_error(image,
+ "png_image_write_to_file: invalid argument");
+ }
+
+ else if (image != NULL)
+ return png_image_error(image,
+ "png_image_write_to_file: incorrect PNG_IMAGE_VERSION");
+
+ else
+ return 0;
+}
+#endif /* PNG_STDIO_SUPPORTED */
+#endif /* SIMPLIFIED_WRITE */
+#endif /* PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngwtran.c b/ml/dlib/dlib/external/libpng/pngwtran.c
new file mode 100644
index 000000000..98703f8c8
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngwtran.c
@@ -0,0 +1,637 @@
+
+/* pngwtran.c - transforms the data in a row for PNG writers
+ *
+ * Last changed in libpng 1.6.0 [February 14, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_WRITE_SUPPORTED
+
+#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED
+/* Transform the data according to the user's wishes. The order of
+ * transformations is significant.
+ */
+void /* PRIVATE */
+png_do_write_transformations(png_structrp png_ptr, png_row_infop row_info)
+{
+ png_debug(1, "in png_do_write_transformations");
+
+ if (png_ptr == NULL)
+ return;
+
+#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED
+ if (png_ptr->transformations & PNG_USER_TRANSFORM)
+ if (png_ptr->write_user_transform_fn != NULL)
+ (*(png_ptr->write_user_transform_fn)) /* User write transform
+ function */
+ (png_ptr, /* png_ptr */
+ row_info, /* row_info: */
+ /* png_uint_32 width; width of row */
+ /* png_size_t rowbytes; number of bytes in row */
+ /* png_byte color_type; color type of pixels */
+ /* png_byte bit_depth; bit depth of samples */
+ /* png_byte channels; number of channels (1-4) */
+ /* png_byte pixel_depth; bits per pixel (depth*channels) */
+ png_ptr->row_buf + 1); /* start of pixel data for row */
+#endif
+
+#ifdef PNG_WRITE_FILLER_SUPPORTED
+ if (png_ptr->transformations & PNG_FILLER)
+ png_do_strip_channel(row_info, png_ptr->row_buf + 1,
+ !(png_ptr->flags & PNG_FLAG_FILLER_AFTER));
+#endif
+
+#ifdef PNG_WRITE_PACKSWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_PACKSWAP)
+ png_do_packswap(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_WRITE_PACK_SUPPORTED
+ if (png_ptr->transformations & PNG_PACK)
+ png_do_pack(row_info, png_ptr->row_buf + 1,
+ (png_uint_32)png_ptr->bit_depth);
+#endif
+
+#ifdef PNG_WRITE_SWAP_SUPPORTED
+ if (png_ptr->transformations & PNG_SWAP_BYTES)
+ png_do_swap(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_WRITE_SHIFT_SUPPORTED
+ if (png_ptr->transformations & PNG_SHIFT)
+ png_do_shift(row_info, png_ptr->row_buf + 1,
+ &(png_ptr->shift));
+#endif
+
+#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_SWAP_ALPHA)
+ png_do_write_swap_alpha(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED
+ if (png_ptr->transformations & PNG_INVERT_ALPHA)
+ png_do_write_invert_alpha(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_WRITE_BGR_SUPPORTED
+ if (png_ptr->transformations & PNG_BGR)
+ png_do_bgr(row_info, png_ptr->row_buf + 1);
+#endif
+
+#ifdef PNG_WRITE_INVERT_SUPPORTED
+ if (png_ptr->transformations & PNG_INVERT_MONO)
+ png_do_invert(row_info, png_ptr->row_buf + 1);
+#endif
+}
+
+#ifdef PNG_WRITE_PACK_SUPPORTED
+/* Pack pixels into bytes. Pass the true bit depth in bit_depth. The
+ * row_info bit depth should be 8 (one pixel per byte). The channels
+ * should be 1 (this only happens on grayscale and paletted images).
+ */
+void /* PRIVATE */
+png_do_pack(png_row_infop row_info, png_bytep row, png_uint_32 bit_depth)
+{
+ png_debug(1, "in png_do_pack");
+
+ if (row_info->bit_depth == 8 &&
+ row_info->channels == 1)
+ {
+ switch ((int)bit_depth)
+ {
+ case 1:
+ {
+ png_bytep sp, dp;
+ int mask, v;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ sp = row;
+ dp = row;
+ mask = 0x80;
+ v = 0;
+
+ for (i = 0; i < row_width; i++)
+ {
+ if (*sp != 0)
+ v |= mask;
+
+ sp++;
+
+ if (mask > 1)
+ mask >>= 1;
+
+ else
+ {
+ mask = 0x80;
+ *dp = (png_byte)v;
+ dp++;
+ v = 0;
+ }
+ }
+
+ if (mask != 0x80)
+ *dp = (png_byte)v;
+
+ break;
+ }
+
+ case 2:
+ {
+ png_bytep sp, dp;
+ int shift, v;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ sp = row;
+ dp = row;
+ shift = 6;
+ v = 0;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_byte value;
+
+ value = (png_byte)(*sp & 0x03);
+ v |= (value << shift);
+
+ if (shift == 0)
+ {
+ shift = 6;
+ *dp = (png_byte)v;
+ dp++;
+ v = 0;
+ }
+
+ else
+ shift -= 2;
+
+ sp++;
+ }
+
+ if (shift != 6)
+ *dp = (png_byte)v;
+
+ break;
+ }
+
+ case 4:
+ {
+ png_bytep sp, dp;
+ int shift, v;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ sp = row;
+ dp = row;
+ shift = 4;
+ v = 0;
+
+ for (i = 0; i < row_width; i++)
+ {
+ png_byte value;
+
+ value = (png_byte)(*sp & 0x0f);
+ v |= (value << shift);
+
+ if (shift == 0)
+ {
+ shift = 4;
+ *dp = (png_byte)v;
+ dp++;
+ v = 0;
+ }
+
+ else
+ shift -= 4;
+
+ sp++;
+ }
+
+ if (shift != 4)
+ *dp = (png_byte)v;
+
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ row_info->bit_depth = (png_byte)bit_depth;
+ row_info->pixel_depth = (png_byte)(bit_depth * row_info->channels);
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth,
+ row_info->width);
+ }
+}
+#endif
+
+#ifdef PNG_WRITE_SHIFT_SUPPORTED
+/* Shift pixel values to take advantage of whole range. Pass the
+ * true number of bits in bit_depth. The row should be packed
+ * according to row_info->bit_depth. Thus, if you had a row of
+ * bit depth 4, but the pixels only had values from 0 to 7, you
+ * would pass 3 as bit_depth, and this routine would translate the
+ * data to 0 to 15.
+ */
+void /* PRIVATE */
+png_do_shift(png_row_infop row_info, png_bytep row,
+ png_const_color_8p bit_depth)
+{
+ png_debug(1, "in png_do_shift");
+
+ if (row_info->color_type != PNG_COLOR_TYPE_PALETTE)
+ {
+ int shift_start[4], shift_dec[4];
+ int channels = 0;
+
+ if (row_info->color_type & PNG_COLOR_MASK_COLOR)
+ {
+ shift_start[channels] = row_info->bit_depth - bit_depth->red;
+ shift_dec[channels] = bit_depth->red;
+ channels++;
+
+ shift_start[channels] = row_info->bit_depth - bit_depth->green;
+ shift_dec[channels] = bit_depth->green;
+ channels++;
+
+ shift_start[channels] = row_info->bit_depth - bit_depth->blue;
+ shift_dec[channels] = bit_depth->blue;
+ channels++;
+ }
+
+ else
+ {
+ shift_start[channels] = row_info->bit_depth - bit_depth->gray;
+ shift_dec[channels] = bit_depth->gray;
+ channels++;
+ }
+
+ if (row_info->color_type & PNG_COLOR_MASK_ALPHA)
+ {
+ shift_start[channels] = row_info->bit_depth - bit_depth->alpha;
+ shift_dec[channels] = bit_depth->alpha;
+ channels++;
+ }
+
+ /* With low row depths, could only be grayscale, so one channel */
+ if (row_info->bit_depth < 8)
+ {
+ png_bytep bp = row;
+ png_size_t i;
+ unsigned int mask;
+ png_size_t row_bytes = row_info->rowbytes;
+
+ if (bit_depth->gray == 1 && row_info->bit_depth == 2)
+ mask = 0x55;
+
+ else if (row_info->bit_depth == 4 && bit_depth->gray == 3)
+ mask = 0x11;
+
+ else
+ mask = 0xff;
+
+ for (i = 0; i < row_bytes; i++, bp++)
+ {
+ int j;
+ unsigned int v, out;
+
+ v = *bp;
+ out = 0;
+
+ for (j = shift_start[0]; j > -shift_dec[0]; j -= shift_dec[0])
+ {
+ if (j > 0)
+ out |= v << j;
+
+ else
+ out |= (v >> (-j)) & mask;
+ }
+
+ *bp = (png_byte)(out & 0xff);
+ }
+ }
+
+ else if (row_info->bit_depth == 8)
+ {
+ png_bytep bp = row;
+ png_uint_32 i;
+ png_uint_32 istop = channels * row_info->width;
+
+ for (i = 0; i < istop; i++, bp++)
+ {
+
+ const unsigned int c = i%channels;
+ int j;
+ unsigned int v, out;
+
+ v = *bp;
+ out = 0;
+
+ for (j = shift_start[c]; j > -shift_dec[c]; j -= shift_dec[c])
+ {
+ if (j > 0)
+ out |= v << j;
+
+ else
+ out |= v >> (-j);
+ }
+
+ *bp = (png_byte)(out & 0xff);
+ }
+ }
+
+ else
+ {
+ png_bytep bp;
+ png_uint_32 i;
+ png_uint_32 istop = channels * row_info->width;
+
+ for (bp = row, i = 0; i < istop; i++)
+ {
+ const unsigned int c = i%channels;
+ int j;
+ unsigned int value, v;
+
+ v = png_get_uint_16(bp);
+ value = 0;
+
+ for (j = shift_start[c]; j > -shift_dec[c]; j -= shift_dec[c])
+ {
+ if (j > 0)
+ value |= v << j;
+
+ else
+ value |= v >> (-j);
+ }
+ *bp++ = (png_byte)((value >> 8) & 0xff);
+ *bp++ = (png_byte)(value & 0xff);
+ }
+ }
+ }
+}
+#endif
+
+#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED
+void /* PRIVATE */
+png_do_write_swap_alpha(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_write_swap_alpha");
+
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This converts from ARGB to RGBA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ png_byte save = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = save;
+ }
+ }
+
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ else
+ {
+ /* This converts from AARRGGBB to RRGGBBAA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ png_byte save[2];
+ save[0] = *(sp++);
+ save[1] = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = save[0];
+ *(dp++) = save[1];
+ }
+ }
+#endif /* PNG_WRITE_16BIT_SUPPORTED */
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This converts from AG to GA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ png_byte save = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = save;
+ }
+ }
+
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ else
+ {
+ /* This converts from AAGG to GGAA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ png_byte save[2];
+ save[0] = *(sp++);
+ save[1] = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = save[0];
+ *(dp++) = save[1];
+ }
+ }
+#endif /* PNG_WRITE_16BIT_SUPPORTED */
+ }
+ }
+}
+#endif
+
+#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED
+void /* PRIVATE */
+png_do_write_invert_alpha(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_write_invert_alpha");
+
+ {
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This inverts the alpha channel in RGBA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ /* Does nothing
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ */
+ sp+=3; dp = sp;
+ *(dp++) = (png_byte)(255 - *(sp++));
+ }
+ }
+
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ else
+ {
+ /* This inverts the alpha channel in RRGGBBAA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ /* Does nothing
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ */
+ sp+=6; dp = sp;
+ *(dp++) = (png_byte)(255 - *(sp++));
+ *(dp++) = (png_byte)(255 - *(sp++));
+ }
+ }
+#endif /* PNG_WRITE_16BIT_SUPPORTED */
+ }
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ if (row_info->bit_depth == 8)
+ {
+ /* This inverts the alpha channel in GA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ *(dp++) = *(sp++);
+ *(dp++) = (png_byte)(255 - *(sp++));
+ }
+ }
+
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ else
+ {
+ /* This inverts the alpha channel in GGAA */
+ png_bytep sp, dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ for (i = 0, sp = dp = row; i < row_width; i++)
+ {
+ /* Does nothing
+ *(dp++) = *(sp++);
+ *(dp++) = *(sp++);
+ */
+ sp+=2; dp = sp;
+ *(dp++) = (png_byte)(255 - *(sp++));
+ *(dp++) = (png_byte)(255 - *(sp++));
+ }
+ }
+#endif /* PNG_WRITE_16BIT_SUPPORTED */
+ }
+ }
+}
+#endif
+#endif /* PNG_WRITE_TRANSFORMS_SUPPORTED */
+
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+/* Undoes intrapixel differencing */
+void /* PRIVATE */
+png_do_write_intrapixel(png_row_infop row_info, png_bytep row)
+{
+ png_debug(1, "in png_do_write_intrapixel");
+
+ if ((row_info->color_type & PNG_COLOR_MASK_COLOR))
+ {
+ int bytes_per_pixel;
+ png_uint_32 row_width = row_info->width;
+ if (row_info->bit_depth == 8)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ bytes_per_pixel = 3;
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ bytes_per_pixel = 4;
+
+ else
+ return;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel)
+ {
+ *(rp) = (png_byte)((*rp - *(rp + 1)) & 0xff);
+ *(rp + 2) = (png_byte)((*(rp + 2) - *(rp + 1)) & 0xff);
+ }
+ }
+
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ else if (row_info->bit_depth == 16)
+ {
+ png_bytep rp;
+ png_uint_32 i;
+
+ if (row_info->color_type == PNG_COLOR_TYPE_RGB)
+ bytes_per_pixel = 6;
+
+ else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA)
+ bytes_per_pixel = 8;
+
+ else
+ return;
+
+ for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel)
+ {
+ png_uint_32 s0 = (*(rp ) << 8) | *(rp + 1);
+ png_uint_32 s1 = (*(rp + 2) << 8) | *(rp + 3);
+ png_uint_32 s2 = (*(rp + 4) << 8) | *(rp + 5);
+ png_uint_32 red = (png_uint_32)((s0 - s1) & 0xffffL);
+ png_uint_32 blue = (png_uint_32)((s2 - s1) & 0xffffL);
+ *(rp ) = (png_byte)((red >> 8) & 0xff);
+ *(rp + 1) = (png_byte)(red & 0xff);
+ *(rp + 4) = (png_byte)((blue >> 8) & 0xff);
+ *(rp + 5) = (png_byte)(blue & 0xff);
+ }
+ }
+#endif /* PNG_WRITE_16BIT_SUPPORTED */
+ }
+}
+#endif /* PNG_MNG_FEATURES_SUPPORTED */
+#endif /* PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/libpng/pngwutil.c b/ml/dlib/dlib/external/libpng/pngwutil.c
new file mode 100644
index 000000000..49e6a2d21
--- /dev/null
+++ b/ml/dlib/dlib/external/libpng/pngwutil.c
@@ -0,0 +1,3023 @@
+
+/* pngwutil.c - utilities to write a PNG file
+ *
+ * Last changed in libpng 1.6.2 [April 25, 2013]
+ * Copyright (c) 1998-2013 Glenn Randers-Pehrson
+ * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger)
+ * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.)
+ *
+ * This code is released under the libpng license.
+ * For conditions of distribution and use, see the disclaimer
+ * and license in png.h
+ */
+
+#include "pngpriv.h"
+
+#ifdef PNG_WRITE_SUPPORTED
+
+#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED
+/* Place a 32-bit number into a buffer in PNG byte order. We work
+ * with unsigned numbers for convenience, although one supported
+ * ancillary chunk uses signed (two's complement) numbers.
+ */
+void PNGAPI
+png_save_uint_32(png_bytep buf, png_uint_32 i)
+{
+ buf[0] = (png_byte)((i >> 24) & 0xff);
+ buf[1] = (png_byte)((i >> 16) & 0xff);
+ buf[2] = (png_byte)((i >> 8) & 0xff);
+ buf[3] = (png_byte)(i & 0xff);
+}
+
+/* Place a 16-bit number into a buffer in PNG byte order.
+ * The parameter is declared unsigned int, not png_uint_16,
+ * just to avoid potential problems on pre-ANSI C compilers.
+ */
+void PNGAPI
+png_save_uint_16(png_bytep buf, unsigned int i)
+{
+ buf[0] = (png_byte)((i >> 8) & 0xff);
+ buf[1] = (png_byte)(i & 0xff);
+}
+#endif
+
+/* Simple function to write the signature. If we have already written
+ * the magic bytes of the signature, or more likely, the PNG stream is
+ * being embedded into another stream and doesn't need its own signature,
+ * we should call png_set_sig_bytes() to tell libpng how many of the
+ * bytes have already been written.
+ */
+void PNGAPI
+png_write_sig(png_structrp png_ptr)
+{
+ png_byte png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10};
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ /* Inform the I/O callback that the signature is being written */
+ png_ptr->io_state = PNG_IO_WRITING | PNG_IO_SIGNATURE;
+#endif
+
+ /* Write the rest of the 8 byte signature */
+ png_write_data(png_ptr, &png_signature[png_ptr->sig_bytes],
+ (png_size_t)(8 - png_ptr->sig_bytes));
+
+ if (png_ptr->sig_bytes < 3)
+ png_ptr->mode |= PNG_HAVE_PNG_SIGNATURE;
+}
+
+/* Write the start of a PNG chunk. The type is the chunk type.
+ * The total_length is the sum of the lengths of all the data you will be
+ * passing in png_write_chunk_data().
+ */
+static void
+png_write_chunk_header(png_structrp png_ptr, png_uint_32 chunk_name,
+ png_uint_32 length)
+{
+ png_byte buf[8];
+
+#if defined(PNG_DEBUG) && (PNG_DEBUG > 0)
+ PNG_CSTRING_FROM_CHUNK(buf, chunk_name);
+ png_debug2(0, "Writing %s chunk, length = %lu", buf, (unsigned long)length);
+#endif
+
+ if (png_ptr == NULL)
+ return;
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ /* Inform the I/O callback that the chunk header is being written.
+ * PNG_IO_CHUNK_HDR requires a single I/O call.
+ */
+ png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_HDR;
+#endif
+
+ /* Write the length and the chunk name */
+ png_save_uint_32(buf, length);
+ png_save_uint_32(buf + 4, chunk_name);
+ png_write_data(png_ptr, buf, 8);
+
+ /* Put the chunk name into png_ptr->chunk_name */
+ png_ptr->chunk_name = chunk_name;
+
+ /* Reset the crc and run it over the chunk name */
+ png_reset_crc(png_ptr);
+
+ png_calculate_crc(png_ptr, buf + 4, 4);
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ /* Inform the I/O callback that chunk data will (possibly) be written.
+ * PNG_IO_CHUNK_DATA does NOT require a specific number of I/O calls.
+ */
+ png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_DATA;
+#endif
+}
+
+void PNGAPI
+png_write_chunk_start(png_structrp png_ptr, png_const_bytep chunk_string,
+ png_uint_32 length)
+{
+ png_write_chunk_header(png_ptr, PNG_CHUNK_FROM_STRING(chunk_string), length);
+}
+
+/* Write the data of a PNG chunk started with png_write_chunk_header().
+ * Note that multiple calls to this function are allowed, and that the
+ * sum of the lengths from these calls *must* add up to the total_length
+ * given to png_write_chunk_header().
+ */
+void PNGAPI
+png_write_chunk_data(png_structrp png_ptr, png_const_bytep data,
+ png_size_t length)
+{
+ /* Write the data, and run the CRC over it */
+ if (png_ptr == NULL)
+ return;
+
+ if (data != NULL && length > 0)
+ {
+ png_write_data(png_ptr, data, length);
+
+ /* Update the CRC after writing the data,
+ * in case that the user I/O routine alters it.
+ */
+ png_calculate_crc(png_ptr, data, length);
+ }
+}
+
+/* Finish a chunk started with png_write_chunk_header(). */
+void PNGAPI
+png_write_chunk_end(png_structrp png_ptr)
+{
+ png_byte buf[4];
+
+ if (png_ptr == NULL) return;
+
+#ifdef PNG_IO_STATE_SUPPORTED
+ /* Inform the I/O callback that the chunk CRC is being written.
+ * PNG_IO_CHUNK_CRC requires a single I/O function call.
+ */
+ png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_CRC;
+#endif
+
+ /* Write the crc in a single operation */
+ png_save_uint_32(buf, png_ptr->crc);
+
+ png_write_data(png_ptr, buf, (png_size_t)4);
+}
+
+/* Write a PNG chunk all at once. The type is an array of ASCII characters
+ * representing the chunk name. The array must be at least 4 bytes in
+ * length, and does not need to be null terminated. To be safe, pass the
+ * pre-defined chunk names here, and if you need a new one, define it
+ * where the others are defined. The length is the length of the data.
+ * All the data must be present. If that is not possible, use the
+ * png_write_chunk_start(), png_write_chunk_data(), and png_write_chunk_end()
+ * functions instead.
+ */
+static void
+png_write_complete_chunk(png_structrp png_ptr, png_uint_32 chunk_name,
+ png_const_bytep data, png_size_t length)
+{
+ if (png_ptr == NULL)
+ return;
+
+ /* On 64 bit architectures 'length' may not fit in a png_uint_32. */
+ if (length > PNG_UINT_31_MAX)
+ png_error(png_ptr, "length exceeds PNG maxima");
+
+ png_write_chunk_header(png_ptr, chunk_name, (png_uint_32)length);
+ png_write_chunk_data(png_ptr, data, length);
+ png_write_chunk_end(png_ptr);
+}
+
+/* This is the API that calls the internal function above. */
+void PNGAPI
+png_write_chunk(png_structrp png_ptr, png_const_bytep chunk_string,
+ png_const_bytep data, png_size_t length)
+{
+ png_write_complete_chunk(png_ptr, PNG_CHUNK_FROM_STRING(chunk_string), data,
+ length);
+}
+
+/* This is used below to find the size of an image to pass to png_deflate_claim,
+ * so it only needs to be accurate if the size is less than 16384 bytes (the
+ * point at which a lower LZ window size can be used.)
+ */
+static png_alloc_size_t
+png_image_size(png_structrp png_ptr)
+{
+ /* Only return sizes up to the maximum of a png_uint_32, do this by limiting
+ * the width and height used to 15 bits.
+ */
+ png_uint_32 h = png_ptr->height;
+
+ if (png_ptr->rowbytes < 32768 && h < 32768)
+ {
+ if (png_ptr->interlaced)
+ {
+ /* Interlacing makes the image larger because of the replication of
+ * both the filter byte and the padding to a byte boundary.
+ */
+ png_uint_32 w = png_ptr->width;
+ unsigned int pd = png_ptr->pixel_depth;
+ png_alloc_size_t cb_base;
+ int pass;
+
+ for (cb_base=0, pass=0; pass<=6; ++pass)
+ {
+ png_uint_32 pw = PNG_PASS_COLS(w, pass);
+
+ if (pw > 0)
+ cb_base += (PNG_ROWBYTES(pd, pw)+1) * PNG_PASS_ROWS(h, pass);
+ }
+
+ return cb_base;
+ }
+
+ else
+ return (png_ptr->rowbytes+1) * h;
+ }
+
+ else
+ return 0xffffffffU;
+}
+
+#ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED
+ /* This is the code to hack the first two bytes of the deflate stream (the
+ * deflate header) to correct the windowBits value to match the actual data
+ * size. Note that the second argument is the *uncompressed* size but the
+ * first argument is the *compressed* data (and it must be deflate
+ * compressed.)
+ */
+static void
+optimize_cmf(png_bytep data, png_alloc_size_t data_size)
+{
+ /* Optimize the CMF field in the zlib stream. The resultant zlib stream is
+ * still compliant to the stream specification.
+ */
+ if (data_size <= 16384) /* else windowBits must be 15 */
+ {
+ unsigned int z_cmf = data[0]; /* zlib compression method and flags */
+
+ if ((z_cmf & 0x0f) == 8 && (z_cmf & 0xf0) <= 0x70)
+ {
+ unsigned int z_cinfo;
+ unsigned int half_z_window_size;
+
+ z_cinfo = z_cmf >> 4;
+ half_z_window_size = 1U << (z_cinfo + 7);
+
+ if (data_size <= half_z_window_size) /* else no change */
+ {
+ unsigned int tmp;
+
+ do
+ {
+ half_z_window_size >>= 1;
+ --z_cinfo;
+ }
+ while (z_cinfo > 0 && data_size <= half_z_window_size);
+
+ z_cmf = (z_cmf & 0x0f) | (z_cinfo << 4);
+
+ data[0] = (png_byte)z_cmf;
+ tmp = data[1] & 0xe0;
+ tmp += 0x1f - ((z_cmf << 8) + tmp) % 0x1f;
+ data[1] = (png_byte)tmp;
+ }
+ }
+ }
+}
+#else
+# define optimize_cmf(dp,dl) ((void)0)
+#endif /* PNG_WRITE_OPTIMIZE_CMF_SUPPORTED */
+
+/* Initialize the compressor for the appropriate type of compression. */
+static int
+png_deflate_claim(png_structrp png_ptr, png_uint_32 owner,
+ png_alloc_size_t data_size)
+{
+ if (png_ptr->zowner != 0)
+ {
+ char msg[64];
+
+ PNG_STRING_FROM_CHUNK(msg, owner);
+ msg[4] = ':';
+ msg[5] = ' ';
+ PNG_STRING_FROM_CHUNK(msg+6, png_ptr->zowner);
+ /* So the message that results is "<chunk> using zstream"; this is an
+ * internal error, but is very useful for debugging. i18n requirements
+ * are minimal.
+ */
+ (void)png_safecat(msg, (sizeof msg), 10, " using zstream");
+# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC
+ png_warning(png_ptr, msg);
+
+ /* Attempt sane error recovery */
+ if (png_ptr->zowner == png_IDAT) /* don't steal from IDAT */
+ {
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("in use by IDAT");
+ return Z_STREAM_ERROR;
+ }
+
+ png_ptr->zowner = 0;
+# else
+ png_error(png_ptr, msg);
+# endif
+ }
+
+ {
+ int level = png_ptr->zlib_level;
+ int method = png_ptr->zlib_method;
+ int windowBits = png_ptr->zlib_window_bits;
+ int memLevel = png_ptr->zlib_mem_level;
+ int strategy; /* set below */
+ int ret; /* zlib return code */
+
+ if (owner == png_IDAT)
+ {
+ if (png_ptr->flags & PNG_FLAG_ZLIB_CUSTOM_STRATEGY)
+ strategy = png_ptr->zlib_strategy;
+
+ else if (png_ptr->do_filter != PNG_FILTER_NONE)
+ strategy = PNG_Z_DEFAULT_STRATEGY;
+
+ else
+ strategy = PNG_Z_DEFAULT_NOFILTER_STRATEGY;
+ }
+
+ else
+ {
+# ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED
+ level = png_ptr->zlib_text_level;
+ method = png_ptr->zlib_text_method;
+ windowBits = png_ptr->zlib_text_window_bits;
+ memLevel = png_ptr->zlib_text_mem_level;
+ strategy = png_ptr->zlib_text_strategy;
+# else
+ /* If customization is not supported the values all come from the
+ * IDAT values except for the strategy, which is fixed to the
+ * default. (This is the pre-1.6.0 behavior too, although it was
+ * implemented in a very different way.)
+ */
+ strategy = Z_DEFAULT_STRATEGY;
+# endif
+ }
+
+ /* Adjust 'windowBits' down if larger than 'data_size'; to stop this
+ * happening just pass 32768 as the data_size parameter. Notice that zlib
+ * requires an extra 262 bytes in the window in addition to the data to be
+ * able to see the whole of the data, so if data_size+262 takes us to the
+ * next windowBits size we need to fix up the value later. (Because even
+ * though deflate needs the extra window, inflate does not!)
+ */
+ if (data_size <= 16384)
+ {
+ /* IMPLEMENTATION NOTE: this 'half_window_size' stuff is only here to
+ * work round a Microsoft Visual C misbehavior which, contrary to C-90,
+ * widens the result of the following shift to 64-bits if (and,
+ * apparently, only if) it is used in a test.
+ */
+ unsigned int half_window_size = 1U << (windowBits-1);
+
+ while (data_size + 262 <= half_window_size)
+ {
+ half_window_size >>= 1;
+ --windowBits;
+ }
+ }
+
+ /* Check against the previous initialized values, if any. */
+ if ((png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED) &&
+ (png_ptr->zlib_set_level != level ||
+ png_ptr->zlib_set_method != method ||
+ png_ptr->zlib_set_window_bits != windowBits ||
+ png_ptr->zlib_set_mem_level != memLevel ||
+ png_ptr->zlib_set_strategy != strategy))
+ {
+ if (deflateEnd(&png_ptr->zstream) != Z_OK)
+ png_warning(png_ptr, "deflateEnd failed (ignored)");
+
+ png_ptr->flags &= ~PNG_FLAG_ZSTREAM_INITIALIZED;
+ }
+
+ /* For safety clear out the input and output pointers (currently zlib
+ * doesn't use them on Init, but it might in the future).
+ */
+ png_ptr->zstream.next_in = NULL;
+ png_ptr->zstream.avail_in = 0;
+ png_ptr->zstream.next_out = NULL;
+ png_ptr->zstream.avail_out = 0;
+
+ /* Now initialize if required, setting the new parameters, otherwise just
+ * to a simple reset to the previous parameters.
+ */
+ if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED)
+ ret = deflateReset(&png_ptr->zstream);
+
+ else
+ {
+ ret = deflateInit2(&png_ptr->zstream, level, method, windowBits,
+ memLevel, strategy);
+
+ if (ret == Z_OK)
+ png_ptr->flags |= PNG_FLAG_ZSTREAM_INITIALIZED;
+ }
+
+ /* The return code is from either deflateReset or deflateInit2; they have
+ * pretty much the same set of error codes.
+ */
+ if (ret == Z_OK)
+ png_ptr->zowner = owner;
+
+ else
+ png_zstream_error(png_ptr, ret);
+
+ return ret;
+ }
+}
+
+/* Clean up (or trim) a linked list of compression buffers. */
+void /* PRIVATE */
+png_free_buffer_list(png_structrp png_ptr, png_compression_bufferp *listp)
+{
+ png_compression_bufferp list = *listp;
+
+ if (list != NULL)
+ {
+ *listp = NULL;
+
+ do
+ {
+ png_compression_bufferp next = list->next;
+
+ png_free(png_ptr, list);
+ list = next;
+ }
+ while (list != NULL);
+ }
+}
+
+#ifdef PNG_WRITE_COMPRESSED_TEXT_SUPPORTED
+/* This pair of functions encapsulates the operation of (a) compressing a
+ * text string, and (b) issuing it later as a series of chunk data writes.
+ * The compression_state structure is shared context for these functions
+ * set up by the caller to allow access to the relevant local variables.
+ *
+ * compression_buffer (new in 1.6.0) is just a linked list of zbuffer_size
+ * temporary buffers. From 1.6.0 it is retained in png_struct so that it will
+ * be correctly freed in the event of a write error (previous implementations
+ * just leaked memory.)
+ */
+typedef struct
+{
+ png_const_bytep input; /* The uncompressed input data */
+ png_alloc_size_t input_len; /* Its length */
+ png_uint_32 output_len; /* Final compressed length */
+ png_byte output[1024]; /* First block of output */
+} compression_state;
+
+static void
+png_text_compress_init(compression_state *comp, png_const_bytep input,
+ png_alloc_size_t input_len)
+{
+ comp->input = input;
+ comp->input_len = input_len;
+ comp->output_len = 0;
+}
+
+/* Compress the data in the compression state input */
+static int
+png_text_compress(png_structrp png_ptr, png_uint_32 chunk_name,
+ compression_state *comp, png_uint_32 prefix_len)
+{
+ int ret;
+
+ /* To find the length of the output it is necessary to first compress the
+ * input, the result is buffered rather than using the two-pass algorithm
+ * that is used on the inflate side; deflate is assumed to be slower and a
+ * PNG writer is assumed to have more memory available than a PNG reader.
+ *
+ * IMPLEMENTATION NOTE: the zlib API deflateBound() can be used to find an
+ * upper limit on the output size, but it is always bigger than the input
+ * size so it is likely to be more efficient to use this linked-list
+ * approach.
+ */
+ ret = png_deflate_claim(png_ptr, chunk_name, comp->input_len);
+
+ if (ret != Z_OK)
+ return ret;
+
+ /* Set up the compression buffers, we need a loop here to avoid overflowing a
+ * uInt. Use ZLIB_IO_MAX to limit the input. The output is always limited
+ * by the output buffer size, so there is no need to check that. Since this
+ * is ANSI-C we know that an 'int', hence a uInt, is always at least 16 bits
+ * in size.
+ */
+ {
+ png_compression_bufferp *end = &png_ptr->zbuffer_list;
+ png_alloc_size_t input_len = comp->input_len; /* may be zero! */
+ png_uint_32 output_len;
+
+ /* zlib updates these for us: */
+ png_ptr->zstream.next_in = PNGZ_INPUT_CAST(comp->input);
+ png_ptr->zstream.avail_in = 0; /* Set below */
+ png_ptr->zstream.next_out = comp->output;
+ png_ptr->zstream.avail_out = (sizeof comp->output);
+
+ output_len = png_ptr->zstream.avail_out;
+
+ do
+ {
+ uInt avail_in = ZLIB_IO_MAX;
+
+ if (avail_in > input_len)
+ avail_in = (uInt)input_len;
+
+ input_len -= avail_in;
+
+ png_ptr->zstream.avail_in = avail_in;
+
+ if (png_ptr->zstream.avail_out == 0)
+ {
+ png_compression_buffer *next;
+
+ /* Chunk data is limited to 2^31 bytes in length, so the prefix
+ * length must be counted here.
+ */
+ if (output_len + prefix_len > PNG_UINT_31_MAX)
+ {
+ ret = Z_MEM_ERROR;
+ break;
+ }
+
+ /* Need a new (malloc'ed) buffer, but there may be one present
+ * already.
+ */
+ next = *end;
+ if (next == NULL)
+ {
+ next = png_voidcast(png_compression_bufferp, png_malloc_base
+ (png_ptr, PNG_COMPRESSION_BUFFER_SIZE(png_ptr)));
+
+ if (next == NULL)
+ {
+ ret = Z_MEM_ERROR;
+ break;
+ }
+
+ /* Link in this buffer (so that it will be freed later) */
+ next->next = NULL;
+ *end = next;
+ }
+
+ png_ptr->zstream.next_out = next->output;
+ png_ptr->zstream.avail_out = png_ptr->zbuffer_size;
+ output_len += png_ptr->zstream.avail_out;
+
+ /* Move 'end' to the next buffer pointer. */
+ end = &next->next;
+ }
+
+ /* Compress the data */
+ ret = deflate(&png_ptr->zstream,
+ input_len > 0 ? Z_NO_FLUSH : Z_FINISH);
+
+ /* Claw back input data that was not consumed (because avail_in is
+ * reset above every time round the loop).
+ */
+ input_len += png_ptr->zstream.avail_in;
+ png_ptr->zstream.avail_in = 0; /* safety */
+ }
+ while (ret == Z_OK);
+
+ /* There may be some space left in the last output buffer, this needs to
+ * be subtracted from output_len.
+ */
+ output_len -= png_ptr->zstream.avail_out;
+ png_ptr->zstream.avail_out = 0; /* safety */
+ comp->output_len = output_len;
+
+ /* Now double check the output length, put in a custom message if it is
+ * too long. Otherwise ensure the z_stream::msg pointer is set to
+ * something.
+ */
+ if (output_len + prefix_len >= PNG_UINT_31_MAX)
+ {
+ png_ptr->zstream.msg = PNGZ_MSG_CAST("compressed data too long");
+ ret = Z_MEM_ERROR;
+ }
+
+ else
+ png_zstream_error(png_ptr, ret);
+
+ /* Reset zlib for another zTXt/iTXt or image data */
+ png_ptr->zowner = 0;
+
+ /* The only success case is Z_STREAM_END, input_len must be 0, if not this
+ * is an internal error.
+ */
+ if (ret == Z_STREAM_END && input_len == 0)
+ {
+ /* Fix up the deflate header, if required */
+ optimize_cmf(comp->output, comp->input_len);
+
+ /* But Z_OK is returned, not Z_STREAM_END; this allows the claim
+ * function above to return Z_STREAM_END on an error (though it never
+ * does in the current versions of zlib.)
+ */
+ return Z_OK;
+ }
+
+ else
+ return ret;
+ }
+}
+
+/* Ship the compressed text out via chunk writes */
+static void
+png_write_compressed_data_out(png_structrp png_ptr, compression_state *comp)
+{
+ png_uint_32 output_len = comp->output_len;
+ png_const_bytep output = comp->output;
+ png_uint_32 avail = (sizeof comp->output);
+ png_compression_buffer *next = png_ptr->zbuffer_list;
+
+ for (;;)
+ {
+ if (avail > output_len)
+ avail = output_len;
+
+ png_write_chunk_data(png_ptr, output, avail);
+
+ output_len -= avail;
+
+ if (output_len == 0 || next == NULL)
+ break;
+
+ avail = png_ptr->zbuffer_size;
+ output = next->output;
+ next = next->next;
+ }
+
+ /* This is an internal error; 'next' must have been NULL! */
+ if (output_len > 0)
+ png_error(png_ptr, "error writing ancillary chunked compressed data");
+}
+#endif /* PNG_WRITE_COMPRESSED_TEXT_SUPPORTED */
+
+#if defined(PNG_WRITE_TEXT_SUPPORTED) || defined(PNG_WRITE_pCAL_SUPPORTED) || \
+ defined(PNG_WRITE_iCCP_SUPPORTED) || defined(PNG_WRITE_sPLT_SUPPORTED)
+/* Check that the tEXt or zTXt keyword is valid per PNG 1.0 specification,
+ * and if invalid, correct the keyword rather than discarding the entire
+ * chunk. The PNG 1.0 specification requires keywords 1-79 characters in
+ * length, forbids leading or trailing whitespace, multiple internal spaces,
+ * and the non-break space (0x80) from ISO 8859-1. Returns keyword length.
+ *
+ * The 'new_key' buffer must be 80 characters in size (for the keyword plus a
+ * trailing '\0'). If this routine returns 0 then there was no keyword, or a
+ * valid one could not be generated, and the caller must png_error.
+ */
+static png_uint_32
+png_check_keyword(png_structrp png_ptr, png_const_charp key, png_bytep new_key)
+{
+ png_const_charp orig_key = key;
+ png_uint_32 key_len = 0;
+ int bad_character = 0;
+ int space = 1;
+
+ png_debug(1, "in png_check_keyword");
+
+ if (key == NULL)
+ {
+ *new_key = 0;
+ return 0;
+ }
+
+ while (*key && key_len < 79)
+ {
+ png_byte ch = (png_byte)(0xff & *key++);
+
+ if ((ch > 32 && ch <= 126) || (ch >= 161 /*&& ch <= 255*/))
+ *new_key++ = ch, ++key_len, space = 0;
+
+ else if (!space)
+ {
+ /* A space or an invalid character when one wasn't seen immediately
+ * before; output just a space.
+ */
+ *new_key++ = 32, ++key_len, space = 1;
+
+ /* If the character was not a space then it is invalid. */
+ if (ch != 32)
+ bad_character = ch;
+ }
+
+ else if (!bad_character)
+ bad_character = ch; /* just skip it, record the first error */
+ }
+
+ if (key_len > 0 && space) /* trailing space */
+ {
+ --key_len, --new_key;
+ if (!bad_character)
+ bad_character = 32;
+ }
+
+ /* Terminate the keyword */
+ *new_key = 0;
+
+ if (key_len == 0)
+ return 0;
+
+ /* Try to only output one warning per keyword: */
+ if (*key) /* keyword too long */
+ png_warning(png_ptr, "keyword truncated");
+
+ else if (bad_character)
+ {
+ PNG_WARNING_PARAMETERS(p)
+
+ png_warning_parameter(p, 1, orig_key);
+ png_warning_parameter_signed(p, 2, PNG_NUMBER_FORMAT_02x, bad_character);
+
+ png_formatted_warning(png_ptr, p, "keyword \"@1\": bad character '0x@2'");
+ }
+
+ return key_len;
+}
+#endif
+
+/* Write the IHDR chunk, and update the png_struct with the necessary
+ * information. Note that the rest of this code depends upon this
+ * information being correct.
+ */
+void /* PRIVATE */
+png_write_IHDR(png_structrp png_ptr, png_uint_32 width, png_uint_32 height,
+ int bit_depth, int color_type, int compression_type, int filter_type,
+ int interlace_type)
+{
+ png_byte buf[13]; /* Buffer to store the IHDR info */
+
+ png_debug(1, "in png_write_IHDR");
+
+ /* Check that we have valid input data from the application info */
+ switch (color_type)
+ {
+ case PNG_COLOR_TYPE_GRAY:
+ switch (bit_depth)
+ {
+ case 1:
+ case 2:
+ case 4:
+ case 8:
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ case 16:
+#endif
+ png_ptr->channels = 1; break;
+
+ default:
+ png_error(png_ptr,
+ "Invalid bit depth for grayscale image");
+ }
+ break;
+
+ case PNG_COLOR_TYPE_RGB:
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ if (bit_depth != 8 && bit_depth != 16)
+#else
+ if (bit_depth != 8)
+#endif
+ png_error(png_ptr, "Invalid bit depth for RGB image");
+
+ png_ptr->channels = 3;
+ break;
+
+ case PNG_COLOR_TYPE_PALETTE:
+ switch (bit_depth)
+ {
+ case 1:
+ case 2:
+ case 4:
+ case 8:
+ png_ptr->channels = 1;
+ break;
+
+ default:
+ png_error(png_ptr, "Invalid bit depth for paletted image");
+ }
+ break;
+
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ if (bit_depth != 8 && bit_depth != 16)
+ png_error(png_ptr, "Invalid bit depth for grayscale+alpha image");
+
+ png_ptr->channels = 2;
+ break;
+
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ if (bit_depth != 8 && bit_depth != 16)
+#else
+ if (bit_depth != 8)
+#endif
+ png_error(png_ptr, "Invalid bit depth for RGBA image");
+
+ png_ptr->channels = 4;
+ break;
+
+ default:
+ png_error(png_ptr, "Invalid image color type specified");
+ }
+
+ if (compression_type != PNG_COMPRESSION_TYPE_BASE)
+ {
+ png_warning(png_ptr, "Invalid compression type specified");
+ compression_type = PNG_COMPRESSION_TYPE_BASE;
+ }
+
+ /* Write filter_method 64 (intrapixel differencing) only if
+ * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and
+ * 2. Libpng did not write a PNG signature (this filter_method is only
+ * used in PNG datastreams that are embedded in MNG datastreams) and
+ * 3. The application called png_permit_mng_features with a mask that
+ * included PNG_FLAG_MNG_FILTER_64 and
+ * 4. The filter_method is 64 and
+ * 5. The color_type is RGB or RGBA
+ */
+ if (
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ !((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) &&
+ ((png_ptr->mode&PNG_HAVE_PNG_SIGNATURE) == 0) &&
+ (color_type == PNG_COLOR_TYPE_RGB ||
+ color_type == PNG_COLOR_TYPE_RGB_ALPHA) &&
+ (filter_type == PNG_INTRAPIXEL_DIFFERENCING)) &&
+#endif
+ filter_type != PNG_FILTER_TYPE_BASE)
+ {
+ png_warning(png_ptr, "Invalid filter type specified");
+ filter_type = PNG_FILTER_TYPE_BASE;
+ }
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ if (interlace_type != PNG_INTERLACE_NONE &&
+ interlace_type != PNG_INTERLACE_ADAM7)
+ {
+ png_warning(png_ptr, "Invalid interlace type specified");
+ interlace_type = PNG_INTERLACE_ADAM7;
+ }
+#else
+ interlace_type=PNG_INTERLACE_NONE;
+#endif
+
+ /* Save the relevent information */
+ png_ptr->bit_depth = (png_byte)bit_depth;
+ png_ptr->color_type = (png_byte)color_type;
+ png_ptr->interlaced = (png_byte)interlace_type;
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ png_ptr->filter_type = (png_byte)filter_type;
+#endif
+ png_ptr->compression_type = (png_byte)compression_type;
+ png_ptr->width = width;
+ png_ptr->height = height;
+
+ png_ptr->pixel_depth = (png_byte)(bit_depth * png_ptr->channels);
+ png_ptr->rowbytes = PNG_ROWBYTES(png_ptr->pixel_depth, width);
+ /* Set the usr info, so any transformations can modify it */
+ png_ptr->usr_width = png_ptr->width;
+ png_ptr->usr_bit_depth = png_ptr->bit_depth;
+ png_ptr->usr_channels = png_ptr->channels;
+
+ /* Pack the header information into the buffer */
+ png_save_uint_32(buf, width);
+ png_save_uint_32(buf + 4, height);
+ buf[8] = (png_byte)bit_depth;
+ buf[9] = (png_byte)color_type;
+ buf[10] = (png_byte)compression_type;
+ buf[11] = (png_byte)filter_type;
+ buf[12] = (png_byte)interlace_type;
+
+ /* Write the chunk */
+ png_write_complete_chunk(png_ptr, png_IHDR, buf, (png_size_t)13);
+
+ if (!(png_ptr->do_filter))
+ {
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE ||
+ png_ptr->bit_depth < 8)
+ png_ptr->do_filter = PNG_FILTER_NONE;
+
+ else
+ png_ptr->do_filter = PNG_ALL_FILTERS;
+ }
+
+ png_ptr->mode = PNG_HAVE_IHDR; /* not READY_FOR_ZTXT */
+}
+
+/* Write the palette. We are careful not to trust png_color to be in the
+ * correct order for PNG, so people can redefine it to any convenient
+ * structure.
+ */
+void /* PRIVATE */
+png_write_PLTE(png_structrp png_ptr, png_const_colorp palette,
+ png_uint_32 num_pal)
+{
+ png_uint_32 i;
+ png_const_colorp pal_ptr;
+ png_byte buf[3];
+
+ png_debug(1, "in png_write_PLTE");
+
+ if ((
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ !(png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE) &&
+#endif
+ num_pal == 0) || num_pal > 256)
+ {
+ if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ png_error(png_ptr, "Invalid number of colors in palette");
+ }
+
+ else
+ {
+ png_warning(png_ptr, "Invalid number of colors in palette");
+ return;
+ }
+ }
+
+ if (!(png_ptr->color_type&PNG_COLOR_MASK_COLOR))
+ {
+ png_warning(png_ptr,
+ "Ignoring request to write a PLTE chunk in grayscale PNG");
+
+ return;
+ }
+
+ png_ptr->num_palette = (png_uint_16)num_pal;
+ png_debug1(3, "num_palette = %d", png_ptr->num_palette);
+
+ png_write_chunk_header(png_ptr, png_PLTE, (png_uint_32)(num_pal * 3));
+#ifdef PNG_POINTER_INDEXING_SUPPORTED
+
+ for (i = 0, pal_ptr = palette; i < num_pal; i++, pal_ptr++)
+ {
+ buf[0] = pal_ptr->red;
+ buf[1] = pal_ptr->green;
+ buf[2] = pal_ptr->blue;
+ png_write_chunk_data(png_ptr, buf, (png_size_t)3);
+ }
+
+#else
+ /* This is a little slower but some buggy compilers need to do this
+ * instead
+ */
+ pal_ptr=palette;
+
+ for (i = 0; i < num_pal; i++)
+ {
+ buf[0] = pal_ptr[i].red;
+ buf[1] = pal_ptr[i].green;
+ buf[2] = pal_ptr[i].blue;
+ png_write_chunk_data(png_ptr, buf, (png_size_t)3);
+ }
+
+#endif
+ png_write_chunk_end(png_ptr);
+ png_ptr->mode |= PNG_HAVE_PLTE;
+}
+
+/* This is similar to png_text_compress, above, except that it does not require
+ * all of the data at once and, instead of buffering the compressed result,
+ * writes it as IDAT chunks. Unlike png_text_compress it *can* png_error out
+ * because it calls the write interface. As a result it does its own error
+ * reporting and does not return an error code. In the event of error it will
+ * just call png_error. The input data length may exceed 32-bits. The 'flush'
+ * parameter is exactly the same as that to deflate, with the following
+ * meanings:
+ *
+ * Z_NO_FLUSH: normal incremental output of compressed data
+ * Z_SYNC_FLUSH: do a SYNC_FLUSH, used by png_write_flush
+ * Z_FINISH: this is the end of the input, do a Z_FINISH and clean up
+ *
+ * The routine manages the acquire and release of the png_ptr->zstream by
+ * checking and (at the end) clearing png_ptr->zowner, it does some sanity
+ * checks on the 'mode' flags while doing this.
+ */
+void /* PRIVATE */
+png_compress_IDAT(png_structrp png_ptr, png_const_bytep input,
+ png_alloc_size_t input_len, int flush)
+{
+ if (png_ptr->zowner != png_IDAT)
+ {
+ /* First time. Ensure we have a temporary buffer for compression and
+ * trim the buffer list if it has more than one entry to free memory.
+ * If 'WRITE_COMPRESSED_TEXT' is not set the list will never have been
+ * created at this point, but the check here is quick and safe.
+ */
+ if (png_ptr->zbuffer_list == NULL)
+ {
+ png_ptr->zbuffer_list = png_voidcast(png_compression_bufferp,
+ png_malloc(png_ptr, PNG_COMPRESSION_BUFFER_SIZE(png_ptr)));
+ png_ptr->zbuffer_list->next = NULL;
+ }
+
+ else
+ png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list->next);
+
+ /* It is a terminal error if we can't claim the zstream. */
+ if (png_deflate_claim(png_ptr, png_IDAT, png_image_size(png_ptr)) != Z_OK)
+ png_error(png_ptr, png_ptr->zstream.msg);
+
+ /* The output state is maintained in png_ptr->zstream, so it must be
+ * initialized here after the claim.
+ */
+ png_ptr->zstream.next_out = png_ptr->zbuffer_list->output;
+ png_ptr->zstream.avail_out = png_ptr->zbuffer_size;
+ }
+
+ /* Now loop reading and writing until all the input is consumed or an error
+ * terminates the operation. The _out values are maintained across calls to
+ * this function, but the input must be reset each time.
+ */
+ png_ptr->zstream.next_in = PNGZ_INPUT_CAST(input);
+ png_ptr->zstream.avail_in = 0; /* set below */
+ for (;;)
+ {
+ int ret;
+
+ /* INPUT: from the row data */
+ uInt avail = ZLIB_IO_MAX;
+
+ if (avail > input_len)
+ avail = (uInt)input_len; /* safe because of the check */
+
+ png_ptr->zstream.avail_in = avail;
+ input_len -= avail;
+
+ ret = deflate(&png_ptr->zstream, input_len > 0 ? Z_NO_FLUSH : flush);
+
+ /* Include as-yet unconsumed input */
+ input_len += png_ptr->zstream.avail_in;
+ png_ptr->zstream.avail_in = 0;
+
+ /* OUTPUT: write complete IDAT chunks when avail_out drops to zero, note
+ * that these two zstream fields are preserved across the calls, therefore
+ * there is no need to set these up on entry to the loop.
+ */
+ if (png_ptr->zstream.avail_out == 0)
+ {
+ png_bytep data = png_ptr->zbuffer_list->output;
+ uInt size = png_ptr->zbuffer_size;
+
+ /* Write an IDAT containing the data then reset the buffer. The
+ * first IDAT may need deflate header optimization.
+ */
+# ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED
+ if (!(png_ptr->mode & PNG_HAVE_IDAT) &&
+ png_ptr->compression_type == PNG_COMPRESSION_TYPE_BASE)
+ optimize_cmf(data, png_image_size(png_ptr));
+# endif
+
+ png_write_complete_chunk(png_ptr, png_IDAT, data, size);
+ png_ptr->mode |= PNG_HAVE_IDAT;
+
+ png_ptr->zstream.next_out = data;
+ png_ptr->zstream.avail_out = size;
+
+ /* For SYNC_FLUSH or FINISH it is essential to keep calling zlib with
+ * the same flush parameter until it has finished output, for NO_FLUSH
+ * it doesn't matter.
+ */
+ if (ret == Z_OK && flush != Z_NO_FLUSH)
+ continue;
+ }
+
+ /* The order of these checks doesn't matter much; it just effect which
+ * possible error might be detected if multiple things go wrong at once.
+ */
+ if (ret == Z_OK) /* most likely return code! */
+ {
+ /* If all the input has been consumed then just return. If Z_FINISH
+ * was used as the flush parameter something has gone wrong if we get
+ * here.
+ */
+ if (input_len == 0)
+ {
+ if (flush == Z_FINISH)
+ png_error(png_ptr, "Z_OK on Z_FINISH with output space");
+
+ return;
+ }
+ }
+
+ else if (ret == Z_STREAM_END && flush == Z_FINISH)
+ {
+ /* This is the end of the IDAT data; any pending output must be
+ * flushed. For small PNG files we may still be at the beginning.
+ */
+ png_bytep data = png_ptr->zbuffer_list->output;
+ uInt size = png_ptr->zbuffer_size - png_ptr->zstream.avail_out;
+
+# ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED
+ if (!(png_ptr->mode & PNG_HAVE_IDAT) &&
+ png_ptr->compression_type == PNG_COMPRESSION_TYPE_BASE)
+ optimize_cmf(data, png_image_size(png_ptr));
+# endif
+
+ png_write_complete_chunk(png_ptr, png_IDAT, data, size);
+ png_ptr->zstream.avail_out = 0;
+ png_ptr->zstream.next_out = NULL;
+ png_ptr->mode |= PNG_HAVE_IDAT | PNG_AFTER_IDAT;
+
+ png_ptr->zowner = 0; /* Release the stream */
+ return;
+ }
+
+ else
+ {
+ /* This is an error condition. */
+ png_zstream_error(png_ptr, ret);
+ png_error(png_ptr, png_ptr->zstream.msg);
+ }
+ }
+}
+
+/* Write an IEND chunk */
+void /* PRIVATE */
+png_write_IEND(png_structrp png_ptr)
+{
+ png_debug(1, "in png_write_IEND");
+
+ png_write_complete_chunk(png_ptr, png_IEND, NULL, (png_size_t)0);
+ png_ptr->mode |= PNG_HAVE_IEND;
+}
+
+#ifdef PNG_WRITE_gAMA_SUPPORTED
+/* Write a gAMA chunk */
+void /* PRIVATE */
+png_write_gAMA_fixed(png_structrp png_ptr, png_fixed_point file_gamma)
+{
+ png_byte buf[4];
+
+ png_debug(1, "in png_write_gAMA");
+
+ /* file_gamma is saved in 1/100,000ths */
+ png_save_uint_32(buf, (png_uint_32)file_gamma);
+ png_write_complete_chunk(png_ptr, png_gAMA, buf, (png_size_t)4);
+}
+#endif
+
+#ifdef PNG_WRITE_sRGB_SUPPORTED
+/* Write a sRGB chunk */
+void /* PRIVATE */
+png_write_sRGB(png_structrp png_ptr, int srgb_intent)
+{
+ png_byte buf[1];
+
+ png_debug(1, "in png_write_sRGB");
+
+ if (srgb_intent >= PNG_sRGB_INTENT_LAST)
+ png_warning(png_ptr,
+ "Invalid sRGB rendering intent specified");
+
+ buf[0]=(png_byte)srgb_intent;
+ png_write_complete_chunk(png_ptr, png_sRGB, buf, (png_size_t)1);
+}
+#endif
+
+#ifdef PNG_WRITE_iCCP_SUPPORTED
+/* Write an iCCP chunk */
+void /* PRIVATE */
+png_write_iCCP(png_structrp png_ptr, png_const_charp name,
+ png_const_bytep profile)
+{
+ png_uint_32 name_len;
+ png_uint_32 profile_len;
+ png_byte new_name[81]; /* 1 byte for the compression byte */
+ compression_state comp;
+
+ png_debug(1, "in png_write_iCCP");
+
+ /* These are all internal problems: the profile should have been checked
+ * before when it was stored.
+ */
+ if (profile == NULL)
+ png_error(png_ptr, "No profile for iCCP chunk"); /* internal error */
+
+ profile_len = png_get_uint_32(profile);
+
+ if (profile_len < 132)
+ png_error(png_ptr, "ICC profile too short");
+
+ if (profile_len & 0x03)
+ png_error(png_ptr, "ICC profile length invalid (not a multiple of 4)");
+
+ {
+ png_uint_32 embedded_profile_len = png_get_uint_32(profile);
+
+ if (profile_len != embedded_profile_len)
+ png_error(png_ptr, "Profile length does not match profile");
+ }
+
+ name_len = png_check_keyword(png_ptr, name, new_name);
+
+ if (name_len == 0)
+ png_error(png_ptr, "iCCP: invalid keyword");
+
+ new_name[++name_len] = PNG_COMPRESSION_TYPE_BASE;
+
+ /* Make sure we include the NULL after the name and the compression type */
+ ++name_len;
+
+ png_text_compress_init(&comp, profile, profile_len);
+
+ /* Allow for keyword terminator and compression byte */
+ if (png_text_compress(png_ptr, png_iCCP, &comp, name_len) != Z_OK)
+ png_error(png_ptr, png_ptr->zstream.msg);
+
+ png_write_chunk_header(png_ptr, png_iCCP, name_len + comp.output_len);
+
+ png_write_chunk_data(png_ptr, new_name, name_len);
+
+ png_write_compressed_data_out(png_ptr, &comp);
+
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_sPLT_SUPPORTED
+/* Write a sPLT chunk */
+void /* PRIVATE */
+png_write_sPLT(png_structrp png_ptr, png_const_sPLT_tp spalette)
+{
+ png_uint_32 name_len;
+ png_byte new_name[80];
+ png_byte entrybuf[10];
+ png_size_t entry_size = (spalette->depth == 8 ? 6 : 10);
+ png_size_t palette_size = entry_size * spalette->nentries;
+ png_sPLT_entryp ep;
+#ifndef PNG_POINTER_INDEXING_SUPPORTED
+ int i;
+#endif
+
+ png_debug(1, "in png_write_sPLT");
+
+ name_len = png_check_keyword(png_ptr, spalette->name, new_name);
+
+ if (name_len == 0)
+ png_error(png_ptr, "sPLT: invalid keyword");
+
+ /* Make sure we include the NULL after the name */
+ png_write_chunk_header(png_ptr, png_sPLT,
+ (png_uint_32)(name_len + 2 + palette_size));
+
+ png_write_chunk_data(png_ptr, (png_bytep)new_name,
+ (png_size_t)(name_len + 1));
+
+ png_write_chunk_data(png_ptr, &spalette->depth, (png_size_t)1);
+
+ /* Loop through each palette entry, writing appropriately */
+#ifdef PNG_POINTER_INDEXING_SUPPORTED
+ for (ep = spalette->entries; ep<spalette->entries + spalette->nentries; ep++)
+ {
+ if (spalette->depth == 8)
+ {
+ entrybuf[0] = (png_byte)ep->red;
+ entrybuf[1] = (png_byte)ep->green;
+ entrybuf[2] = (png_byte)ep->blue;
+ entrybuf[3] = (png_byte)ep->alpha;
+ png_save_uint_16(entrybuf + 4, ep->frequency);
+ }
+
+ else
+ {
+ png_save_uint_16(entrybuf + 0, ep->red);
+ png_save_uint_16(entrybuf + 2, ep->green);
+ png_save_uint_16(entrybuf + 4, ep->blue);
+ png_save_uint_16(entrybuf + 6, ep->alpha);
+ png_save_uint_16(entrybuf + 8, ep->frequency);
+ }
+
+ png_write_chunk_data(png_ptr, entrybuf, entry_size);
+ }
+#else
+ ep=spalette->entries;
+ for (i = 0; i>spalette->nentries; i++)
+ {
+ if (spalette->depth == 8)
+ {
+ entrybuf[0] = (png_byte)ep[i].red;
+ entrybuf[1] = (png_byte)ep[i].green;
+ entrybuf[2] = (png_byte)ep[i].blue;
+ entrybuf[3] = (png_byte)ep[i].alpha;
+ png_save_uint_16(entrybuf + 4, ep[i].frequency);
+ }
+
+ else
+ {
+ png_save_uint_16(entrybuf + 0, ep[i].red);
+ png_save_uint_16(entrybuf + 2, ep[i].green);
+ png_save_uint_16(entrybuf + 4, ep[i].blue);
+ png_save_uint_16(entrybuf + 6, ep[i].alpha);
+ png_save_uint_16(entrybuf + 8, ep[i].frequency);
+ }
+
+ png_write_chunk_data(png_ptr, entrybuf, entry_size);
+ }
+#endif
+
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_sBIT_SUPPORTED
+/* Write the sBIT chunk */
+void /* PRIVATE */
+png_write_sBIT(png_structrp png_ptr, png_const_color_8p sbit, int color_type)
+{
+ png_byte buf[4];
+ png_size_t size;
+
+ png_debug(1, "in png_write_sBIT");
+
+ /* Make sure we don't depend upon the order of PNG_COLOR_8 */
+ if (color_type & PNG_COLOR_MASK_COLOR)
+ {
+ png_byte maxbits;
+
+ maxbits = (png_byte)(color_type==PNG_COLOR_TYPE_PALETTE ? 8 :
+ png_ptr->usr_bit_depth);
+
+ if (sbit->red == 0 || sbit->red > maxbits ||
+ sbit->green == 0 || sbit->green > maxbits ||
+ sbit->blue == 0 || sbit->blue > maxbits)
+ {
+ png_warning(png_ptr, "Invalid sBIT depth specified");
+ return;
+ }
+
+ buf[0] = sbit->red;
+ buf[1] = sbit->green;
+ buf[2] = sbit->blue;
+ size = 3;
+ }
+
+ else
+ {
+ if (sbit->gray == 0 || sbit->gray > png_ptr->usr_bit_depth)
+ {
+ png_warning(png_ptr, "Invalid sBIT depth specified");
+ return;
+ }
+
+ buf[0] = sbit->gray;
+ size = 1;
+ }
+
+ if (color_type & PNG_COLOR_MASK_ALPHA)
+ {
+ if (sbit->alpha == 0 || sbit->alpha > png_ptr->usr_bit_depth)
+ {
+ png_warning(png_ptr, "Invalid sBIT depth specified");
+ return;
+ }
+
+ buf[size++] = sbit->alpha;
+ }
+
+ png_write_complete_chunk(png_ptr, png_sBIT, buf, size);
+}
+#endif
+
+#ifdef PNG_WRITE_cHRM_SUPPORTED
+/* Write the cHRM chunk */
+void /* PRIVATE */
+png_write_cHRM_fixed(png_structrp png_ptr, const png_xy *xy)
+{
+ png_byte buf[32];
+
+ png_debug(1, "in png_write_cHRM");
+
+ /* Each value is saved in 1/100,000ths */
+ png_save_int_32(buf, xy->whitex);
+ png_save_int_32(buf + 4, xy->whitey);
+
+ png_save_int_32(buf + 8, xy->redx);
+ png_save_int_32(buf + 12, xy->redy);
+
+ png_save_int_32(buf + 16, xy->greenx);
+ png_save_int_32(buf + 20, xy->greeny);
+
+ png_save_int_32(buf + 24, xy->bluex);
+ png_save_int_32(buf + 28, xy->bluey);
+
+ png_write_complete_chunk(png_ptr, png_cHRM, buf, 32);
+}
+#endif
+
+#ifdef PNG_WRITE_tRNS_SUPPORTED
+/* Write the tRNS chunk */
+void /* PRIVATE */
+png_write_tRNS(png_structrp png_ptr, png_const_bytep trans_alpha,
+ png_const_color_16p tran, int num_trans, int color_type)
+{
+ png_byte buf[6];
+
+ png_debug(1, "in png_write_tRNS");
+
+ if (color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (num_trans <= 0 || num_trans > (int)png_ptr->num_palette)
+ {
+ png_app_warning(png_ptr,
+ "Invalid number of transparent colors specified");
+ return;
+ }
+
+ /* Write the chunk out as it is */
+ png_write_complete_chunk(png_ptr, png_tRNS, trans_alpha,
+ (png_size_t)num_trans);
+ }
+
+ else if (color_type == PNG_COLOR_TYPE_GRAY)
+ {
+ /* One 16 bit value */
+ if (tran->gray >= (1 << png_ptr->bit_depth))
+ {
+ png_app_warning(png_ptr,
+ "Ignoring attempt to write tRNS chunk out-of-range for bit_depth");
+
+ return;
+ }
+
+ png_save_uint_16(buf, tran->gray);
+ png_write_complete_chunk(png_ptr, png_tRNS, buf, (png_size_t)2);
+ }
+
+ else if (color_type == PNG_COLOR_TYPE_RGB)
+ {
+ /* Three 16 bit values */
+ png_save_uint_16(buf, tran->red);
+ png_save_uint_16(buf + 2, tran->green);
+ png_save_uint_16(buf + 4, tran->blue);
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ if (png_ptr->bit_depth == 8 && (buf[0] | buf[2] | buf[4]))
+#else
+ if (buf[0] | buf[2] | buf[4])
+#endif
+ {
+ png_app_warning(png_ptr,
+ "Ignoring attempt to write 16-bit tRNS chunk when bit_depth is 8");
+ return;
+ }
+
+ png_write_complete_chunk(png_ptr, png_tRNS, buf, (png_size_t)6);
+ }
+
+ else
+ {
+ png_app_warning(png_ptr, "Can't write tRNS with an alpha channel");
+ }
+}
+#endif
+
+#ifdef PNG_WRITE_bKGD_SUPPORTED
+/* Write the background chunk */
+void /* PRIVATE */
+png_write_bKGD(png_structrp png_ptr, png_const_color_16p back, int color_type)
+{
+ png_byte buf[6];
+
+ png_debug(1, "in png_write_bKGD");
+
+ if (color_type == PNG_COLOR_TYPE_PALETTE)
+ {
+ if (
+#ifdef PNG_MNG_FEATURES_SUPPORTED
+ (png_ptr->num_palette ||
+ (!(png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE))) &&
+#endif
+ back->index >= png_ptr->num_palette)
+ {
+ png_warning(png_ptr, "Invalid background palette index");
+ return;
+ }
+
+ buf[0] = back->index;
+ png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)1);
+ }
+
+ else if (color_type & PNG_COLOR_MASK_COLOR)
+ {
+ png_save_uint_16(buf, back->red);
+ png_save_uint_16(buf + 2, back->green);
+ png_save_uint_16(buf + 4, back->blue);
+#ifdef PNG_WRITE_16BIT_SUPPORTED
+ if (png_ptr->bit_depth == 8 && (buf[0] | buf[2] | buf[4]))
+#else
+ if (buf[0] | buf[2] | buf[4])
+#endif
+ {
+ png_warning(png_ptr,
+ "Ignoring attempt to write 16-bit bKGD chunk when bit_depth is 8");
+
+ return;
+ }
+
+ png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)6);
+ }
+
+ else
+ {
+ if (back->gray >= (1 << png_ptr->bit_depth))
+ {
+ png_warning(png_ptr,
+ "Ignoring attempt to write bKGD chunk out-of-range for bit_depth");
+
+ return;
+ }
+
+ png_save_uint_16(buf, back->gray);
+ png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)2);
+ }
+}
+#endif
+
+#ifdef PNG_WRITE_hIST_SUPPORTED
+/* Write the histogram */
+void /* PRIVATE */
+png_write_hIST(png_structrp png_ptr, png_const_uint_16p hist, int num_hist)
+{
+ int i;
+ png_byte buf[3];
+
+ png_debug(1, "in png_write_hIST");
+
+ if (num_hist > (int)png_ptr->num_palette)
+ {
+ png_debug2(3, "num_hist = %d, num_palette = %d", num_hist,
+ png_ptr->num_palette);
+
+ png_warning(png_ptr, "Invalid number of histogram entries specified");
+ return;
+ }
+
+ png_write_chunk_header(png_ptr, png_hIST, (png_uint_32)(num_hist * 2));
+
+ for (i = 0; i < num_hist; i++)
+ {
+ png_save_uint_16(buf, hist[i]);
+ png_write_chunk_data(png_ptr, buf, (png_size_t)2);
+ }
+
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_tEXt_SUPPORTED
+/* Write a tEXt chunk */
+void /* PRIVATE */
+png_write_tEXt(png_structrp png_ptr, png_const_charp key, png_const_charp text,
+ png_size_t text_len)
+{
+ png_uint_32 key_len;
+ png_byte new_key[80];
+
+ png_debug(1, "in png_write_tEXt");
+
+ key_len = png_check_keyword(png_ptr, key, new_key);
+
+ if (key_len == 0)
+ png_error(png_ptr, "tEXt: invalid keyword");
+
+ if (text == NULL || *text == '\0')
+ text_len = 0;
+
+ else
+ text_len = strlen(text);
+
+ if (text_len > PNG_UINT_31_MAX - (key_len+1))
+ png_error(png_ptr, "tEXt: text too long");
+
+ /* Make sure we include the 0 after the key */
+ png_write_chunk_header(png_ptr, png_tEXt,
+ (png_uint_32)/*checked above*/(key_len + text_len + 1));
+ /*
+ * We leave it to the application to meet PNG-1.0 requirements on the
+ * contents of the text. PNG-1.0 through PNG-1.2 discourage the use of
+ * any non-Latin-1 characters except for NEWLINE. ISO PNG will forbid them.
+ * The NUL character is forbidden by PNG-1.0 through PNG-1.2 and ISO PNG.
+ */
+ png_write_chunk_data(png_ptr, new_key, key_len + 1);
+
+ if (text_len)
+ png_write_chunk_data(png_ptr, (png_const_bytep)text, text_len);
+
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_zTXt_SUPPORTED
+/* Write a compressed text chunk */
+void /* PRIVATE */
+png_write_zTXt(png_structrp png_ptr, png_const_charp key, png_const_charp text,
+ png_size_t text_len, int compression)
+{
+ png_uint_32 key_len;
+ png_byte new_key[81];
+ compression_state comp;
+
+ png_debug(1, "in png_write_zTXt");
+ PNG_UNUSED(text_len) /* Always use strlen */
+
+ if (compression == PNG_TEXT_COMPRESSION_NONE)
+ {
+ png_write_tEXt(png_ptr, key, text, 0);
+ return;
+ }
+
+ if (compression != PNG_TEXT_COMPRESSION_zTXt)
+ png_error(png_ptr, "zTXt: invalid compression type");
+
+ key_len = png_check_keyword(png_ptr, key, new_key);
+
+ if (key_len == 0)
+ png_error(png_ptr, "zTXt: invalid keyword");
+
+ /* Add the compression method and 1 for the keyword separator. */
+ new_key[++key_len] = PNG_COMPRESSION_TYPE_BASE;
+ ++key_len;
+
+ /* Compute the compressed data; do it now for the length */
+ png_text_compress_init(&comp, (png_const_bytep)text,
+ text == NULL ? 0 : strlen(text));
+
+ if (png_text_compress(png_ptr, png_zTXt, &comp, key_len) != Z_OK)
+ png_error(png_ptr, png_ptr->zstream.msg);
+
+ /* Write start of chunk */
+ png_write_chunk_header(png_ptr, png_zTXt, key_len + comp.output_len);
+
+ /* Write key */
+ png_write_chunk_data(png_ptr, new_key, key_len);
+
+ /* Write the compressed data */
+ png_write_compressed_data_out(png_ptr, &comp);
+
+ /* Close the chunk */
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_iTXt_SUPPORTED
+/* Write an iTXt chunk */
+void /* PRIVATE */
+png_write_iTXt(png_structrp png_ptr, int compression, png_const_charp key,
+ png_const_charp lang, png_const_charp lang_key, png_const_charp text)
+{
+ png_uint_32 key_len, prefix_len;
+ png_size_t lang_len, lang_key_len;
+ png_byte new_key[82];
+ compression_state comp;
+
+ png_debug(1, "in png_write_iTXt");
+
+ key_len = png_check_keyword(png_ptr, key, new_key);
+
+ if (key_len == 0)
+ png_error(png_ptr, "iTXt: invalid keyword");
+
+ /* Set the compression flag */
+ switch (compression)
+ {
+ case PNG_ITXT_COMPRESSION_NONE:
+ case PNG_TEXT_COMPRESSION_NONE:
+ compression = new_key[++key_len] = 0; /* no compression */
+ break;
+
+ case PNG_TEXT_COMPRESSION_zTXt:
+ case PNG_ITXT_COMPRESSION_zTXt:
+ compression = new_key[++key_len] = 1; /* compressed */
+ break;
+
+ default:
+ png_error(png_ptr, "iTXt: invalid compression");
+ }
+
+ new_key[++key_len] = PNG_COMPRESSION_TYPE_BASE;
+ ++key_len; /* for the keywod separator */
+
+ /* We leave it to the application to meet PNG-1.0 requirements on the
+ * contents of the text. PNG-1.0 through PNG-1.2 discourage the use of
+ * any non-Latin-1 characters except for NEWLINE. ISO PNG, however,
+ * specifies that the text is UTF-8 and this really doesn't require any
+ * checking.
+ *
+ * The NUL character is forbidden by PNG-1.0 through PNG-1.2 and ISO PNG.
+ *
+ * TODO: validate the language tag correctly (see the spec.)
+ */
+ if (lang == NULL) lang = ""; /* empty language is valid */
+ lang_len = strlen(lang)+1;
+ if (lang_key == NULL) lang_key = ""; /* may be empty */
+ lang_key_len = strlen(lang_key)+1;
+ if (text == NULL) text = ""; /* may be empty */
+
+ prefix_len = key_len;
+ if (lang_len > PNG_UINT_31_MAX-prefix_len)
+ prefix_len = PNG_UINT_31_MAX;
+ else
+ prefix_len = (png_uint_32)(prefix_len + lang_len);
+
+ if (lang_key_len > PNG_UINT_31_MAX-prefix_len)
+ prefix_len = PNG_UINT_31_MAX;
+ else
+ prefix_len = (png_uint_32)(prefix_len + lang_key_len);
+
+ png_text_compress_init(&comp, (png_const_bytep)text, strlen(text));
+
+ if (compression)
+ {
+ if (png_text_compress(png_ptr, png_iTXt, &comp, prefix_len) != Z_OK)
+ png_error(png_ptr, png_ptr->zstream.msg);
+ }
+
+ else
+ {
+ if (comp.input_len > PNG_UINT_31_MAX-prefix_len)
+ png_error(png_ptr, "iTXt: uncompressed text too long");
+
+ /* So the string will fit in a chunk: */
+ comp.output_len = (png_uint_32)/*SAFE*/comp.input_len;
+ }
+
+ png_write_chunk_header(png_ptr, png_iTXt, comp.output_len + prefix_len);
+
+ png_write_chunk_data(png_ptr, new_key, key_len);
+
+ png_write_chunk_data(png_ptr, (png_const_bytep)lang, lang_len);
+
+ png_write_chunk_data(png_ptr, (png_const_bytep)lang_key, lang_key_len);
+
+ if (compression)
+ png_write_compressed_data_out(png_ptr, &comp);
+
+ else
+ png_write_chunk_data(png_ptr, (png_const_bytep)text, comp.input_len);
+
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_oFFs_SUPPORTED
+/* Write the oFFs chunk */
+void /* PRIVATE */
+png_write_oFFs(png_structrp png_ptr, png_int_32 x_offset, png_int_32 y_offset,
+ int unit_type)
+{
+ png_byte buf[9];
+
+ png_debug(1, "in png_write_oFFs");
+
+ if (unit_type >= PNG_OFFSET_LAST)
+ png_warning(png_ptr, "Unrecognized unit type for oFFs chunk");
+
+ png_save_int_32(buf, x_offset);
+ png_save_int_32(buf + 4, y_offset);
+ buf[8] = (png_byte)unit_type;
+
+ png_write_complete_chunk(png_ptr, png_oFFs, buf, (png_size_t)9);
+}
+#endif
+#ifdef PNG_WRITE_pCAL_SUPPORTED
+/* Write the pCAL chunk (described in the PNG extensions document) */
+void /* PRIVATE */
+png_write_pCAL(png_structrp png_ptr, png_charp purpose, png_int_32 X0,
+ png_int_32 X1, int type, int nparams, png_const_charp units,
+ png_charpp params)
+{
+ png_uint_32 purpose_len;
+ png_size_t units_len, total_len;
+ png_size_tp params_len;
+ png_byte buf[10];
+ png_byte new_purpose[80];
+ int i;
+
+ png_debug1(1, "in png_write_pCAL (%d parameters)", nparams);
+
+ if (type >= PNG_EQUATION_LAST)
+ png_error(png_ptr, "Unrecognized equation type for pCAL chunk");
+
+ purpose_len = png_check_keyword(png_ptr, purpose, new_purpose);
+
+ if (purpose_len == 0)
+ png_error(png_ptr, "pCAL: invalid keyword");
+
+ ++purpose_len; /* terminator */
+
+ png_debug1(3, "pCAL purpose length = %d", (int)purpose_len);
+ units_len = strlen(units) + (nparams == 0 ? 0 : 1);
+ png_debug1(3, "pCAL units length = %d", (int)units_len);
+ total_len = purpose_len + units_len + 10;
+
+ params_len = (png_size_tp)png_malloc(png_ptr,
+ (png_alloc_size_t)(nparams * (sizeof (png_size_t))));
+
+ /* Find the length of each parameter, making sure we don't count the
+ * null terminator for the last parameter.
+ */
+ for (i = 0; i < nparams; i++)
+ {
+ params_len[i] = strlen(params[i]) + (i == nparams - 1 ? 0 : 1);
+ png_debug2(3, "pCAL parameter %d length = %lu", i,
+ (unsigned long)params_len[i]);
+ total_len += params_len[i];
+ }
+
+ png_debug1(3, "pCAL total length = %d", (int)total_len);
+ png_write_chunk_header(png_ptr, png_pCAL, (png_uint_32)total_len);
+ png_write_chunk_data(png_ptr, new_purpose, purpose_len);
+ png_save_int_32(buf, X0);
+ png_save_int_32(buf + 4, X1);
+ buf[8] = (png_byte)type;
+ buf[9] = (png_byte)nparams;
+ png_write_chunk_data(png_ptr, buf, (png_size_t)10);
+ png_write_chunk_data(png_ptr, (png_const_bytep)units, (png_size_t)units_len);
+
+ for (i = 0; i < nparams; i++)
+ {
+ png_write_chunk_data(png_ptr, (png_const_bytep)params[i], params_len[i]);
+ }
+
+ png_free(png_ptr, params_len);
+ png_write_chunk_end(png_ptr);
+}
+#endif
+
+#ifdef PNG_WRITE_sCAL_SUPPORTED
+/* Write the sCAL chunk */
+void /* PRIVATE */
+png_write_sCAL_s(png_structrp png_ptr, int unit, png_const_charp width,
+ png_const_charp height)
+{
+ png_byte buf[64];
+ png_size_t wlen, hlen, total_len;
+
+ png_debug(1, "in png_write_sCAL_s");
+
+ wlen = strlen(width);
+ hlen = strlen(height);
+ total_len = wlen + hlen + 2;
+
+ if (total_len > 64)
+ {
+ png_warning(png_ptr, "Can't write sCAL (buffer too small)");
+ return;
+ }
+
+ buf[0] = (png_byte)unit;
+ memcpy(buf + 1, width, wlen + 1); /* Append the '\0' here */
+ memcpy(buf + wlen + 2, height, hlen); /* Do NOT append the '\0' here */
+
+ png_debug1(3, "sCAL total length = %u", (unsigned int)total_len);
+ png_write_complete_chunk(png_ptr, png_sCAL, buf, total_len);
+}
+#endif
+
+#ifdef PNG_WRITE_pHYs_SUPPORTED
+/* Write the pHYs chunk */
+void /* PRIVATE */
+png_write_pHYs(png_structrp png_ptr, png_uint_32 x_pixels_per_unit,
+ png_uint_32 y_pixels_per_unit,
+ int unit_type)
+{
+ png_byte buf[9];
+
+ png_debug(1, "in png_write_pHYs");
+
+ if (unit_type >= PNG_RESOLUTION_LAST)
+ png_warning(png_ptr, "Unrecognized unit type for pHYs chunk");
+
+ png_save_uint_32(buf, x_pixels_per_unit);
+ png_save_uint_32(buf + 4, y_pixels_per_unit);
+ buf[8] = (png_byte)unit_type;
+
+ png_write_complete_chunk(png_ptr, png_pHYs, buf, (png_size_t)9);
+}
+#endif
+
+#ifdef PNG_WRITE_tIME_SUPPORTED
+/* Write the tIME chunk. Use either png_convert_from_struct_tm()
+ * or png_convert_from_time_t(), or fill in the structure yourself.
+ */
+void /* PRIVATE */
+png_write_tIME(png_structrp png_ptr, png_const_timep mod_time)
+{
+ png_byte buf[7];
+
+ png_debug(1, "in png_write_tIME");
+
+ if (mod_time->month > 12 || mod_time->month < 1 ||
+ mod_time->day > 31 || mod_time->day < 1 ||
+ mod_time->hour > 23 || mod_time->second > 60)
+ {
+ png_warning(png_ptr, "Invalid time specified for tIME chunk");
+ return;
+ }
+
+ png_save_uint_16(buf, mod_time->year);
+ buf[2] = mod_time->month;
+ buf[3] = mod_time->day;
+ buf[4] = mod_time->hour;
+ buf[5] = mod_time->minute;
+ buf[6] = mod_time->second;
+
+ png_write_complete_chunk(png_ptr, png_tIME, buf, (png_size_t)7);
+}
+#endif
+
+/* Initializes the row writing capability of libpng */
+void /* PRIVATE */
+png_write_start_row(png_structrp png_ptr)
+{
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ /* Start of interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1};
+
+ /* Offset to next interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2};
+#endif
+
+ png_alloc_size_t buf_size;
+ int usr_pixel_depth;
+
+ png_debug(1, "in png_write_start_row");
+
+ usr_pixel_depth = png_ptr->usr_channels * png_ptr->usr_bit_depth;
+ buf_size = PNG_ROWBYTES(usr_pixel_depth, png_ptr->width) + 1;
+
+ /* 1.5.6: added to allow checking in the row write code. */
+ png_ptr->transformed_pixel_depth = png_ptr->pixel_depth;
+ png_ptr->maximum_pixel_depth = (png_byte)usr_pixel_depth;
+
+ /* Set up row buffer */
+ png_ptr->row_buf = (png_bytep)png_malloc(png_ptr, buf_size);
+
+ png_ptr->row_buf[0] = PNG_FILTER_VALUE_NONE;
+
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ /* Set up filtering buffer, if using this filter */
+ if (png_ptr->do_filter & PNG_FILTER_SUB)
+ {
+ png_ptr->sub_row = (png_bytep)png_malloc(png_ptr, png_ptr->rowbytes + 1);
+
+ png_ptr->sub_row[0] = PNG_FILTER_VALUE_SUB;
+ }
+
+ /* We only need to keep the previous row if we are using one of these. */
+ if (png_ptr->do_filter & (PNG_FILTER_AVG | PNG_FILTER_UP | PNG_FILTER_PAETH))
+ {
+ /* Set up previous row buffer */
+ png_ptr->prev_row = (png_bytep)png_calloc(png_ptr, buf_size);
+
+ if (png_ptr->do_filter & PNG_FILTER_UP)
+ {
+ png_ptr->up_row = (png_bytep)png_malloc(png_ptr,
+ png_ptr->rowbytes + 1);
+
+ png_ptr->up_row[0] = PNG_FILTER_VALUE_UP;
+ }
+
+ if (png_ptr->do_filter & PNG_FILTER_AVG)
+ {
+ png_ptr->avg_row = (png_bytep)png_malloc(png_ptr,
+ png_ptr->rowbytes + 1);
+
+ png_ptr->avg_row[0] = PNG_FILTER_VALUE_AVG;
+ }
+
+ if (png_ptr->do_filter & PNG_FILTER_PAETH)
+ {
+ png_ptr->paeth_row = (png_bytep)png_malloc(png_ptr,
+ png_ptr->rowbytes + 1);
+
+ png_ptr->paeth_row[0] = PNG_FILTER_VALUE_PAETH;
+ }
+ }
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* If interlaced, we need to set up width and height of pass */
+ if (png_ptr->interlaced)
+ {
+ if (!(png_ptr->transformations & PNG_INTERLACE))
+ {
+ png_ptr->num_rows = (png_ptr->height + png_pass_yinc[0] - 1 -
+ png_pass_ystart[0]) / png_pass_yinc[0];
+
+ png_ptr->usr_width = (png_ptr->width + png_pass_inc[0] - 1 -
+ png_pass_start[0]) / png_pass_inc[0];
+ }
+
+ else
+ {
+ png_ptr->num_rows = png_ptr->height;
+ png_ptr->usr_width = png_ptr->width;
+ }
+ }
+
+ else
+#endif
+ {
+ png_ptr->num_rows = png_ptr->height;
+ png_ptr->usr_width = png_ptr->width;
+ }
+}
+
+/* Internal use only. Called when finished processing a row of data. */
+void /* PRIVATE */
+png_write_finish_row(png_structrp png_ptr)
+{
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ /* Start of interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1};
+
+ /* Offset to next interlace block in the y direction */
+ static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2};
+#endif
+
+ png_debug(1, "in png_write_finish_row");
+
+ /* Next row */
+ png_ptr->row_number++;
+
+ /* See if we are done */
+ if (png_ptr->row_number < png_ptr->num_rows)
+ return;
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+ /* If interlaced, go to next pass */
+ if (png_ptr->interlaced)
+ {
+ png_ptr->row_number = 0;
+ if (png_ptr->transformations & PNG_INTERLACE)
+ {
+ png_ptr->pass++;
+ }
+
+ else
+ {
+ /* Loop until we find a non-zero width or height pass */
+ do
+ {
+ png_ptr->pass++;
+
+ if (png_ptr->pass >= 7)
+ break;
+
+ png_ptr->usr_width = (png_ptr->width +
+ png_pass_inc[png_ptr->pass] - 1 -
+ png_pass_start[png_ptr->pass]) /
+ png_pass_inc[png_ptr->pass];
+
+ png_ptr->num_rows = (png_ptr->height +
+ png_pass_yinc[png_ptr->pass] - 1 -
+ png_pass_ystart[png_ptr->pass]) /
+ png_pass_yinc[png_ptr->pass];
+
+ if (png_ptr->transformations & PNG_INTERLACE)
+ break;
+
+ } while (png_ptr->usr_width == 0 || png_ptr->num_rows == 0);
+
+ }
+
+ /* Reset the row above the image for the next pass */
+ if (png_ptr->pass < 7)
+ {
+ if (png_ptr->prev_row != NULL)
+ memset(png_ptr->prev_row, 0,
+ (png_size_t)(PNG_ROWBYTES(png_ptr->usr_channels*
+ png_ptr->usr_bit_depth, png_ptr->width)) + 1);
+
+ return;
+ }
+ }
+#endif
+
+ /* If we get here, we've just written the last row, so we need
+ to flush the compressor */
+ png_compress_IDAT(png_ptr, NULL, 0, Z_FINISH);
+}
+
+#ifdef PNG_WRITE_INTERLACING_SUPPORTED
+/* Pick out the correct pixels for the interlace pass.
+ * The basic idea here is to go through the row with a source
+ * pointer and a destination pointer (sp and dp), and copy the
+ * correct pixels for the pass. As the row gets compacted,
+ * sp will always be >= dp, so we should never overwrite anything.
+ * See the default: case for the easiest code to understand.
+ */
+void /* PRIVATE */
+png_do_write_interlace(png_row_infop row_info, png_bytep row, int pass)
+{
+ /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */
+
+ /* Start of interlace block */
+ static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0};
+
+ /* Offset to next interlace block */
+ static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1};
+
+ png_debug(1, "in png_do_write_interlace");
+
+ /* We don't have to do anything on the last pass (6) */
+ if (pass < 6)
+ {
+ /* Each pixel depth is handled separately */
+ switch (row_info->pixel_depth)
+ {
+ case 1:
+ {
+ png_bytep sp;
+ png_bytep dp;
+ int shift;
+ int d;
+ int value;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ dp = row;
+ d = 0;
+ shift = 7;
+
+ for (i = png_pass_start[pass]; i < row_width;
+ i += png_pass_inc[pass])
+ {
+ sp = row + (png_size_t)(i >> 3);
+ value = (int)(*sp >> (7 - (int)(i & 0x07))) & 0x01;
+ d |= (value << shift);
+
+ if (shift == 0)
+ {
+ shift = 7;
+ *dp++ = (png_byte)d;
+ d = 0;
+ }
+
+ else
+ shift--;
+
+ }
+ if (shift != 7)
+ *dp = (png_byte)d;
+
+ break;
+ }
+
+ case 2:
+ {
+ png_bytep sp;
+ png_bytep dp;
+ int shift;
+ int d;
+ int value;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ dp = row;
+ shift = 6;
+ d = 0;
+
+ for (i = png_pass_start[pass]; i < row_width;
+ i += png_pass_inc[pass])
+ {
+ sp = row + (png_size_t)(i >> 2);
+ value = (*sp >> ((3 - (int)(i & 0x03)) << 1)) & 0x03;
+ d |= (value << shift);
+
+ if (shift == 0)
+ {
+ shift = 6;
+ *dp++ = (png_byte)d;
+ d = 0;
+ }
+
+ else
+ shift -= 2;
+ }
+ if (shift != 6)
+ *dp = (png_byte)d;
+
+ break;
+ }
+
+ case 4:
+ {
+ png_bytep sp;
+ png_bytep dp;
+ int shift;
+ int d;
+ int value;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+
+ dp = row;
+ shift = 4;
+ d = 0;
+ for (i = png_pass_start[pass]; i < row_width;
+ i += png_pass_inc[pass])
+ {
+ sp = row + (png_size_t)(i >> 1);
+ value = (*sp >> ((1 - (int)(i & 0x01)) << 2)) & 0x0f;
+ d |= (value << shift);
+
+ if (shift == 0)
+ {
+ shift = 4;
+ *dp++ = (png_byte)d;
+ d = 0;
+ }
+
+ else
+ shift -= 4;
+ }
+ if (shift != 4)
+ *dp = (png_byte)d;
+
+ break;
+ }
+
+ default:
+ {
+ png_bytep sp;
+ png_bytep dp;
+ png_uint_32 i;
+ png_uint_32 row_width = row_info->width;
+ png_size_t pixel_bytes;
+
+ /* Start at the beginning */
+ dp = row;
+
+ /* Find out how many bytes each pixel takes up */
+ pixel_bytes = (row_info->pixel_depth >> 3);
+
+ /* Loop through the row, only looking at the pixels that matter */
+ for (i = png_pass_start[pass]; i < row_width;
+ i += png_pass_inc[pass])
+ {
+ /* Find out where the original pixel is */
+ sp = row + (png_size_t)i * pixel_bytes;
+
+ /* Move the pixel */
+ if (dp != sp)
+ memcpy(dp, sp, pixel_bytes);
+
+ /* Next pixel */
+ dp += pixel_bytes;
+ }
+ break;
+ }
+ }
+ /* Set new row width */
+ row_info->width = (row_info->width +
+ png_pass_inc[pass] - 1 -
+ png_pass_start[pass]) /
+ png_pass_inc[pass];
+
+ row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth,
+ row_info->width);
+ }
+}
+#endif
+
+/* This filters the row, chooses which filter to use, if it has not already
+ * been specified by the application, and then writes the row out with the
+ * chosen filter.
+ */
+static void png_write_filtered_row(png_structrp png_ptr, png_bytep filtered_row,
+ png_size_t row_bytes);
+
+#define PNG_MAXSUM (((png_uint_32)(-1)) >> 1)
+#define PNG_HISHIFT 10
+#define PNG_LOMASK ((png_uint_32)0xffffL)
+#define PNG_HIMASK ((png_uint_32)(~PNG_LOMASK >> PNG_HISHIFT))
+void /* PRIVATE */
+png_write_find_filter(png_structrp png_ptr, png_row_infop row_info)
+{
+ png_bytep best_row;
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ png_bytep prev_row, row_buf;
+ png_uint_32 mins, bpp;
+ png_byte filter_to_do = png_ptr->do_filter;
+ png_size_t row_bytes = row_info->rowbytes;
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ int num_p_filters = png_ptr->num_prev_filters;
+#endif
+
+ png_debug(1, "in png_write_find_filter");
+
+#ifndef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->row_number == 0 && filter_to_do == PNG_ALL_FILTERS)
+ {
+ /* These will never be selected so we need not test them. */
+ filter_to_do &= ~(PNG_FILTER_UP | PNG_FILTER_PAETH);
+ }
+#endif
+
+ /* Find out how many bytes offset each pixel is */
+ bpp = (row_info->pixel_depth + 7) >> 3;
+
+ prev_row = png_ptr->prev_row;
+#endif
+ best_row = png_ptr->row_buf;
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+ row_buf = best_row;
+ mins = PNG_MAXSUM;
+
+ /* The prediction method we use is to find which method provides the
+ * smallest value when summing the absolute values of the distances
+ * from zero, using anything >= 128 as negative numbers. This is known
+ * as the "minimum sum of absolute differences" heuristic. Other
+ * heuristics are the "weighted minimum sum of absolute differences"
+ * (experimental and can in theory improve compression), and the "zlib
+ * predictive" method (not implemented yet), which does test compressions
+ * of lines using different filter methods, and then chooses the
+ * (series of) filter(s) that give minimum compressed data size (VERY
+ * computationally expensive).
+ *
+ * GRR 980525: consider also
+ *
+ * (1) minimum sum of absolute differences from running average (i.e.,
+ * keep running sum of non-absolute differences & count of bytes)
+ * [track dispersion, too? restart average if dispersion too large?]
+ *
+ * (1b) minimum sum of absolute differences from sliding average, probably
+ * with window size <= deflate window (usually 32K)
+ *
+ * (2) minimum sum of squared differences from zero or running average
+ * (i.e., ~ root-mean-square approach)
+ */
+
+
+ /* We don't need to test the 'no filter' case if this is the only filter
+ * that has been chosen, as it doesn't actually do anything to the data.
+ */
+ if ((filter_to_do & PNG_FILTER_NONE) && filter_to_do != PNG_FILTER_NONE)
+ {
+ png_bytep rp;
+ png_uint_32 sum = 0;
+ png_size_t i;
+ int v;
+
+ for (i = 0, rp = row_buf + 1; i < row_bytes; i++, rp++)
+ {
+ v = *rp;
+ sum += (v < 128) ? v : 256 - v;
+ }
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ png_uint_32 sumhi, sumlo;
+ int j;
+ sumlo = sum & PNG_LOMASK;
+ sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; /* Gives us some footroom */
+
+ /* Reduce the sum if we match any of the previous rows */
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_NONE)
+ {
+ sumlo = (sumlo * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ /* Factor in the cost of this filter (this is here for completeness,
+ * but it makes no sense to have a "cost" for the NONE filter, as
+ * it has the minimum possible computational cost - none).
+ */
+ sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_NONE]) >>
+ PNG_COST_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_NONE]) >>
+ PNG_COST_SHIFT;
+
+ if (sumhi > PNG_HIMASK)
+ sum = PNG_MAXSUM;
+
+ else
+ sum = (sumhi << PNG_HISHIFT) + sumlo;
+ }
+#endif
+ mins = sum;
+ }
+
+ /* Sub filter */
+ if (filter_to_do == PNG_FILTER_SUB)
+ /* It's the only filter so no testing is needed */
+ {
+ png_bytep rp, lp, dp;
+ png_size_t i;
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->sub_row + 1; i < bpp;
+ i++, rp++, dp++)
+ {
+ *dp = *rp;
+ }
+
+ for (lp = row_buf + 1; i < row_bytes;
+ i++, rp++, lp++, dp++)
+ {
+ *dp = (png_byte)(((int)*rp - (int)*lp) & 0xff);
+ }
+
+ best_row = png_ptr->sub_row;
+ }
+
+ else if (filter_to_do & PNG_FILTER_SUB)
+ {
+ png_bytep rp, dp, lp;
+ png_uint_32 sum = 0, lmins = mins;
+ png_size_t i;
+ int v;
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ /* We temporarily increase the "minimum sum" by the factor we
+ * would reduce the sum of this filter, so that we can do the
+ * early exit comparison without scaling the sum each time.
+ */
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 lmhi, lmlo;
+ lmlo = lmins & PNG_LOMASK;
+ lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_SUB)
+ {
+ lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >>
+ PNG_COST_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >>
+ PNG_COST_SHIFT;
+
+ if (lmhi > PNG_HIMASK)
+ lmins = PNG_MAXSUM;
+
+ else
+ lmins = (lmhi << PNG_HISHIFT) + lmlo;
+ }
+#endif
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->sub_row + 1; i < bpp;
+ i++, rp++, dp++)
+ {
+ v = *dp = *rp;
+
+ sum += (v < 128) ? v : 256 - v;
+ }
+
+ for (lp = row_buf + 1; i < row_bytes;
+ i++, rp++, lp++, dp++)
+ {
+ v = *dp = (png_byte)(((int)*rp - (int)*lp) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+
+ if (sum > lmins) /* We are already worse, don't continue. */
+ break;
+ }
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 sumhi, sumlo;
+ sumlo = sum & PNG_LOMASK;
+ sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_SUB)
+ {
+ sumlo = (sumlo * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ sumhi = (sumhi * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ sumlo = (sumlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >>
+ PNG_COST_SHIFT;
+
+ sumhi = (sumhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >>
+ PNG_COST_SHIFT;
+
+ if (sumhi > PNG_HIMASK)
+ sum = PNG_MAXSUM;
+
+ else
+ sum = (sumhi << PNG_HISHIFT) + sumlo;
+ }
+#endif
+
+ if (sum < mins)
+ {
+ mins = sum;
+ best_row = png_ptr->sub_row;
+ }
+ }
+
+ /* Up filter */
+ if (filter_to_do == PNG_FILTER_UP)
+ {
+ png_bytep rp, dp, pp;
+ png_size_t i;
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->up_row + 1,
+ pp = prev_row + 1; i < row_bytes;
+ i++, rp++, pp++, dp++)
+ {
+ *dp = (png_byte)(((int)*rp - (int)*pp) & 0xff);
+ }
+
+ best_row = png_ptr->up_row;
+ }
+
+ else if (filter_to_do & PNG_FILTER_UP)
+ {
+ png_bytep rp, dp, pp;
+ png_uint_32 sum = 0, lmins = mins;
+ png_size_t i;
+ int v;
+
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 lmhi, lmlo;
+ lmlo = lmins & PNG_LOMASK;
+ lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_UP)
+ {
+ lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_UP]) >>
+ PNG_COST_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_UP]) >>
+ PNG_COST_SHIFT;
+
+ if (lmhi > PNG_HIMASK)
+ lmins = PNG_MAXSUM;
+
+ else
+ lmins = (lmhi << PNG_HISHIFT) + lmlo;
+ }
+#endif
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->up_row + 1,
+ pp = prev_row + 1; i < row_bytes; i++)
+ {
+ v = *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+
+ if (sum > lmins) /* We are already worse, don't continue. */
+ break;
+ }
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 sumhi, sumlo;
+ sumlo = sum & PNG_LOMASK;
+ sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_UP)
+ {
+ sumlo = (sumlo * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_UP]) >>
+ PNG_COST_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_UP]) >>
+ PNG_COST_SHIFT;
+
+ if (sumhi > PNG_HIMASK)
+ sum = PNG_MAXSUM;
+
+ else
+ sum = (sumhi << PNG_HISHIFT) + sumlo;
+ }
+#endif
+
+ if (sum < mins)
+ {
+ mins = sum;
+ best_row = png_ptr->up_row;
+ }
+ }
+
+ /* Avg filter */
+ if (filter_to_do == PNG_FILTER_AVG)
+ {
+ png_bytep rp, dp, pp, lp;
+ png_uint_32 i;
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->avg_row + 1,
+ pp = prev_row + 1; i < bpp; i++)
+ {
+ *dp++ = (png_byte)(((int)*rp++ - ((int)*pp++ / 2)) & 0xff);
+ }
+
+ for (lp = row_buf + 1; i < row_bytes; i++)
+ {
+ *dp++ = (png_byte)(((int)*rp++ - (((int)*pp++ + (int)*lp++) / 2))
+ & 0xff);
+ }
+ best_row = png_ptr->avg_row;
+ }
+
+ else if (filter_to_do & PNG_FILTER_AVG)
+ {
+ png_bytep rp, dp, pp, lp;
+ png_uint_32 sum = 0, lmins = mins;
+ png_size_t i;
+ int v;
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 lmhi, lmlo;
+ lmlo = lmins & PNG_LOMASK;
+ lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_AVG)
+ {
+ lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_AVG]) >>
+ PNG_COST_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_AVG]) >>
+ PNG_COST_SHIFT;
+
+ if (lmhi > PNG_HIMASK)
+ lmins = PNG_MAXSUM;
+
+ else
+ lmins = (lmhi << PNG_HISHIFT) + lmlo;
+ }
+#endif
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->avg_row + 1,
+ pp = prev_row + 1; i < bpp; i++)
+ {
+ v = *dp++ = (png_byte)(((int)*rp++ - ((int)*pp++ / 2)) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+ }
+
+ for (lp = row_buf + 1; i < row_bytes; i++)
+ {
+ v = *dp++ =
+ (png_byte)(((int)*rp++ - (((int)*pp++ + (int)*lp++) / 2)) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+
+ if (sum > lmins) /* We are already worse, don't continue. */
+ break;
+ }
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 sumhi, sumlo;
+ sumlo = sum & PNG_LOMASK;
+ sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_NONE)
+ {
+ sumlo = (sumlo * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_AVG]) >>
+ PNG_COST_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_AVG]) >>
+ PNG_COST_SHIFT;
+
+ if (sumhi > PNG_HIMASK)
+ sum = PNG_MAXSUM;
+
+ else
+ sum = (sumhi << PNG_HISHIFT) + sumlo;
+ }
+#endif
+
+ if (sum < mins)
+ {
+ mins = sum;
+ best_row = png_ptr->avg_row;
+ }
+ }
+
+ /* Paeth filter */
+ if (filter_to_do == PNG_FILTER_PAETH)
+ {
+ png_bytep rp, dp, pp, cp, lp;
+ png_size_t i;
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->paeth_row + 1,
+ pp = prev_row + 1; i < bpp; i++)
+ {
+ *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff);
+ }
+
+ for (lp = row_buf + 1, cp = prev_row + 1; i < row_bytes; i++)
+ {
+ int a, b, c, pa, pb, pc, p;
+
+ b = *pp++;
+ c = *cp++;
+ a = *lp++;
+
+ p = b - c;
+ pc = a - c;
+
+#ifdef PNG_USE_ABS
+ pa = abs(p);
+ pb = abs(pc);
+ pc = abs(p + pc);
+#else
+ pa = p < 0 ? -p : p;
+ pb = pc < 0 ? -pc : pc;
+ pc = (p + pc) < 0 ? -(p + pc) : p + pc;
+#endif
+
+ p = (pa <= pb && pa <=pc) ? a : (pb <= pc) ? b : c;
+
+ *dp++ = (png_byte)(((int)*rp++ - p) & 0xff);
+ }
+ best_row = png_ptr->paeth_row;
+ }
+
+ else if (filter_to_do & PNG_FILTER_PAETH)
+ {
+ png_bytep rp, dp, pp, cp, lp;
+ png_uint_32 sum = 0, lmins = mins;
+ png_size_t i;
+ int v;
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 lmhi, lmlo;
+ lmlo = lmins & PNG_LOMASK;
+ lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_PAETH)
+ {
+ lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_PAETH]) >>
+ PNG_COST_SHIFT;
+
+ lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_PAETH]) >>
+ PNG_COST_SHIFT;
+
+ if (lmhi > PNG_HIMASK)
+ lmins = PNG_MAXSUM;
+
+ else
+ lmins = (lmhi << PNG_HISHIFT) + lmlo;
+ }
+#endif
+
+ for (i = 0, rp = row_buf + 1, dp = png_ptr->paeth_row + 1,
+ pp = prev_row + 1; i < bpp; i++)
+ {
+ v = *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+ }
+
+ for (lp = row_buf + 1, cp = prev_row + 1; i < row_bytes; i++)
+ {
+ int a, b, c, pa, pb, pc, p;
+
+ b = *pp++;
+ c = *cp++;
+ a = *lp++;
+
+#ifndef PNG_SLOW_PAETH
+ p = b - c;
+ pc = a - c;
+#ifdef PNG_USE_ABS
+ pa = abs(p);
+ pb = abs(pc);
+ pc = abs(p + pc);
+#else
+ pa = p < 0 ? -p : p;
+ pb = pc < 0 ? -pc : pc;
+ pc = (p + pc) < 0 ? -(p + pc) : p + pc;
+#endif
+ p = (pa <= pb && pa <=pc) ? a : (pb <= pc) ? b : c;
+#else /* PNG_SLOW_PAETH */
+ p = a + b - c;
+ pa = abs(p - a);
+ pb = abs(p - b);
+ pc = abs(p - c);
+
+ if (pa <= pb && pa <= pc)
+ p = a;
+
+ else if (pb <= pc)
+ p = b;
+
+ else
+ p = c;
+#endif /* PNG_SLOW_PAETH */
+
+ v = *dp++ = (png_byte)(((int)*rp++ - p) & 0xff);
+
+ sum += (v < 128) ? v : 256 - v;
+
+ if (sum > lmins) /* We are already worse, don't continue. */
+ break;
+ }
+
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED)
+ {
+ int j;
+ png_uint_32 sumhi, sumlo;
+ sumlo = sum & PNG_LOMASK;
+ sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK;
+
+ for (j = 0; j < num_p_filters; j++)
+ {
+ if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_PAETH)
+ {
+ sumlo = (sumlo * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_weights[j]) >>
+ PNG_WEIGHT_SHIFT;
+ }
+ }
+
+ sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_PAETH]) >>
+ PNG_COST_SHIFT;
+
+ sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_PAETH]) >>
+ PNG_COST_SHIFT;
+
+ if (sumhi > PNG_HIMASK)
+ sum = PNG_MAXSUM;
+
+ else
+ sum = (sumhi << PNG_HISHIFT) + sumlo;
+ }
+#endif
+
+ if (sum < mins)
+ {
+ best_row = png_ptr->paeth_row;
+ }
+ }
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+
+ /* Do the actual writing of the filtered row data from the chosen filter. */
+ png_write_filtered_row(png_ptr, best_row, row_info->rowbytes+1);
+
+#ifdef PNG_WRITE_FILTER_SUPPORTED
+#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED
+ /* Save the type of filter we picked this time for future calculations */
+ if (png_ptr->num_prev_filters > 0)
+ {
+ int j;
+
+ for (j = 1; j < num_p_filters; j++)
+ {
+ png_ptr->prev_filters[j] = png_ptr->prev_filters[j - 1];
+ }
+
+ png_ptr->prev_filters[j] = best_row[0];
+ }
+#endif
+#endif /* PNG_WRITE_FILTER_SUPPORTED */
+}
+
+
+/* Do the actual writing of a previously filtered row. */
+static void
+png_write_filtered_row(png_structrp png_ptr, png_bytep filtered_row,
+ png_size_t full_row_length/*includes filter byte*/)
+{
+ png_debug(1, "in png_write_filtered_row");
+
+ png_debug1(2, "filter = %d", filtered_row[0]);
+
+ png_compress_IDAT(png_ptr, filtered_row, full_row_length, Z_NO_FLUSH);
+
+ /* Swap the current and previous rows */
+ if (png_ptr->prev_row != NULL)
+ {
+ png_bytep tptr;
+
+ tptr = png_ptr->prev_row;
+ png_ptr->prev_row = png_ptr->row_buf;
+ png_ptr->row_buf = tptr;
+ }
+
+ /* Finish row - updates counters and flushes zlib if last row */
+ png_write_finish_row(png_ptr);
+
+#ifdef PNG_WRITE_FLUSH_SUPPORTED
+ png_ptr->flush_rows++;
+
+ if (png_ptr->flush_dist > 0 &&
+ png_ptr->flush_rows >= png_ptr->flush_dist)
+ {
+ png_write_flush(png_ptr);
+ }
+#endif
+}
+#endif /* PNG_WRITE_SUPPORTED */
diff --git a/ml/dlib/dlib/external/pybind11/CMakeLists.txt b/ml/dlib/dlib/external/pybind11/CMakeLists.txt
new file mode 100644
index 000000000..4280ba742
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/CMakeLists.txt
@@ -0,0 +1,155 @@
+# CMakeLists.txt -- Build system for the pybind11 modules
+#
+# Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
+#
+# All rights reserved. Use of this source code is governed by a
+# BSD-style license that can be found in the LICENSE file.
+
+cmake_minimum_required(VERSION 2.8.12)
+
+if (POLICY CMP0048)
+ # cmake warns if loaded from a min-3.0-required parent dir, so silence the warning:
+ cmake_policy(SET CMP0048 NEW)
+endif()
+
+# CMake versions < 3.4.0 do not support try_compile/pthread checks without C as active language.
+if(CMAKE_VERSION VERSION_LESS 3.4.0)
+ project(pybind11)
+else()
+ project(pybind11 CXX)
+endif()
+
+# Check if pybind11 is being used directly or via add_subdirectory
+set(PYBIND11_MASTER_PROJECT OFF)
+if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
+ set(PYBIND11_MASTER_PROJECT ON)
+endif()
+
+option(PYBIND11_INSTALL "Install pybind11 header files?" ${PYBIND11_MASTER_PROJECT})
+option(PYBIND11_TEST "Build pybind11 test suite?" ${PYBIND11_MASTER_PROJECT})
+
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/tools")
+
+include(pybind11Tools)
+
+# Cache variables so pybind11_add_module can be used in parent projects
+set(PYBIND11_INCLUDE_DIR "${CMAKE_CURRENT_LIST_DIR}/include" CACHE INTERNAL "")
+set(PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIRS} CACHE INTERNAL "")
+set(PYTHON_LIBRARIES ${PYTHON_LIBRARIES} CACHE INTERNAL "")
+set(PYTHON_MODULE_PREFIX ${PYTHON_MODULE_PREFIX} CACHE INTERNAL "")
+set(PYTHON_MODULE_EXTENSION ${PYTHON_MODULE_EXTENSION} CACHE INTERNAL "")
+
+# NB: when adding a header don't forget to also add it to setup.py
+set(PYBIND11_HEADERS
+ include/pybind11/detail/class.h
+ include/pybind11/detail/common.h
+ include/pybind11/detail/descr.h
+ include/pybind11/detail/init.h
+ include/pybind11/detail/internals.h
+ include/pybind11/detail/typeid.h
+ include/pybind11/attr.h
+ include/pybind11/buffer_info.h
+ include/pybind11/cast.h
+ include/pybind11/chrono.h
+ include/pybind11/common.h
+ include/pybind11/complex.h
+ include/pybind11/options.h
+ include/pybind11/eigen.h
+ include/pybind11/embed.h
+ include/pybind11/eval.h
+ include/pybind11/functional.h
+ include/pybind11/numpy.h
+ include/pybind11/operators.h
+ include/pybind11/pybind11.h
+ include/pybind11/pytypes.h
+ include/pybind11/stl.h
+ include/pybind11/stl_bind.h
+)
+string(REPLACE "include/" "${CMAKE_CURRENT_SOURCE_DIR}/include/"
+ PYBIND11_HEADERS "${PYBIND11_HEADERS}")
+
+if (PYBIND11_TEST)
+ add_subdirectory(tests)
+endif()
+
+include(GNUInstallDirs)
+include(CMakePackageConfigHelpers)
+
+# extract project version from source
+file(STRINGS "${PYBIND11_INCLUDE_DIR}/pybind11/detail/common.h" pybind11_version_defines
+ REGEX "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) ")
+foreach(ver ${pybind11_version_defines})
+ if (ver MATCHES "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) +([^ ]+)$")
+ set(PYBIND11_VERSION_${CMAKE_MATCH_1} "${CMAKE_MATCH_2}" CACHE INTERNAL "")
+ endif()
+endforeach()
+set(${PROJECT_NAME}_VERSION ${PYBIND11_VERSION_MAJOR}.${PYBIND11_VERSION_MINOR}.${PYBIND11_VERSION_PATCH})
+message(STATUS "pybind11 v${${PROJECT_NAME}_VERSION}")
+
+option (USE_PYTHON_INCLUDE_DIR "Install pybind11 headers in Python include directory instead of default installation prefix" OFF)
+if (USE_PYTHON_INCLUDE_DIR)
+ file(RELATIVE_PATH CMAKE_INSTALL_INCLUDEDIR ${CMAKE_INSTALL_PREFIX} ${PYTHON_INCLUDE_DIRS})
+endif()
+
+if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) # CMake >= 3.0
+ # Build an interface library target:
+ add_library(pybind11 INTERFACE)
+ add_library(pybind11::pybind11 ALIAS pybind11) # to match exported target
+ target_include_directories(pybind11 INTERFACE $<BUILD_INTERFACE:${PYBIND11_INCLUDE_DIR}>
+ $<BUILD_INTERFACE:${PYTHON_INCLUDE_DIRS}>
+ $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
+ target_compile_options(pybind11 INTERFACE $<BUILD_INTERFACE:${PYBIND11_CPP_STANDARD}>)
+
+ add_library(module INTERFACE)
+ add_library(pybind11::module ALIAS module)
+ if(NOT MSVC)
+ target_compile_options(module INTERFACE -fvisibility=hidden)
+ endif()
+ target_link_libraries(module INTERFACE pybind11::pybind11)
+ if(WIN32 OR CYGWIN)
+ target_link_libraries(module INTERFACE $<BUILD_INTERFACE:${PYTHON_LIBRARIES}>)
+ elseif(APPLE)
+ target_link_libraries(module INTERFACE "-undefined dynamic_lookup")
+ endif()
+
+ add_library(embed INTERFACE)
+ add_library(pybind11::embed ALIAS embed)
+ target_link_libraries(embed INTERFACE pybind11::pybind11 $<BUILD_INTERFACE:${PYTHON_LIBRARIES}>)
+endif()
+
+if (PYBIND11_INSTALL)
+ install(DIRECTORY ${PYBIND11_INCLUDE_DIR}/pybind11 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
+ # GNUInstallDirs "DATADIR" wrong here; CMake search path wants "share".
+ set(PYBIND11_CMAKECONFIG_INSTALL_DIR "share/cmake/${PROJECT_NAME}" CACHE STRING "install path for pybind11Config.cmake")
+
+ configure_package_config_file(tools/${PROJECT_NAME}Config.cmake.in
+ "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake"
+ INSTALL_DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR})
+ # Remove CMAKE_SIZEOF_VOID_P from ConfigVersion.cmake since the library does
+ # not depend on architecture specific settings or libraries.
+ set(_PYBIND11_CMAKE_SIZEOF_VOID_P ${CMAKE_SIZEOF_VOID_P})
+ unset(CMAKE_SIZEOF_VOID_P)
+ write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake
+ VERSION ${${PROJECT_NAME}_VERSION}
+ COMPATIBILITY AnyNewerVersion)
+ set(CMAKE_SIZEOF_VOID_P ${_PYBIND11_CMAKE_SIZEOF_VOID_P})
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake
+ ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake
+ tools/FindPythonLibsNew.cmake
+ tools/pybind11Tools.cmake
+ DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR})
+
+ if(NOT (CMAKE_VERSION VERSION_LESS 3.0))
+ if(NOT PYBIND11_EXPORT_NAME)
+ set(PYBIND11_EXPORT_NAME "${PROJECT_NAME}Targets")
+ endif()
+
+ install(TARGETS pybind11 module embed
+ EXPORT "${PYBIND11_EXPORT_NAME}")
+ if(PYBIND11_MASTER_PROJECT)
+ install(EXPORT "${PYBIND11_EXPORT_NAME}"
+ NAMESPACE "${PROJECT_NAME}::"
+ DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR})
+ endif()
+ endif()
+endif()
diff --git a/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md b/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md
new file mode 100644
index 000000000..375735f6c
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md
@@ -0,0 +1,47 @@
+Thank you for your interest in this project! Please refer to the following
+sections on how to contribute code and bug reports.
+
+### Reporting bugs
+
+At the moment, this project is run in the spare time of a single person
+([Wenzel Jakob](http://rgl.epfl.ch/people/wjakob)) with very limited resources
+for issue tracker tickets. Thus, before submitting a question or bug report,
+please take a moment of your time and ensure that your issue isn't already
+discussed in the project documentation provided at
+[http://pybind11.readthedocs.org/en/latest](http://pybind11.readthedocs.org/en/latest).
+
+Assuming that you have identified a previously unknown problem or an important
+question, it's essential that you submit a self-contained and minimal piece of
+code that reproduces the problem. In other words: no external dependencies,
+isolate the function(s) that cause breakage, submit matched and complete C++
+and Python snippets that can be easily compiled and run on my end.
+
+## Pull requests
+Contributions are submitted, reviewed, and accepted using Github pull requests.
+Please refer to [this
+article](https://help.github.com/articles/using-pull-requests) for details and
+adhere to the following rules to make the process as smooth as possible:
+
+* Make a new branch for every feature you're working on.
+* Make small and clean pull requests that are easy to review but make sure they
+ do add value by themselves.
+* Add tests for any new functionality and run the test suite (``make pytest``)
+ to ensure that no existing features break.
+* This project has a strong focus on providing general solutions using a
+ minimal amount of code, thus small pull requests are greatly preferred.
+
+### Licensing of contributions
+
+pybind11 is provided under a BSD-style license that can be found in the
+``LICENSE`` file. By using, distributing, or contributing to this project, you
+agree to the terms and conditions of this license.
+
+You are under no obligation whatsoever to provide any bug fixes, patches, or
+upgrades to the features, functionality or performance of the source code
+("Enhancements") to anyone; however, if you choose to make your Enhancements
+available either publicly, or directly to the author of this software, without
+imposing a separate written license agreement for such Enhancements, then you
+hereby grant the following license: a non-exclusive, royalty-free perpetual
+license to install, use, modify, prepare derivative works, incorporate into
+other computer software, distribute, and sublicense such enhancements or
+derivative works thereof, in binary and source code form.
diff --git a/ml/dlib/dlib/external/pybind11/LICENSE b/ml/dlib/dlib/external/pybind11/LICENSE
new file mode 100644
index 000000000..6f15578cc
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/LICENSE
@@ -0,0 +1,29 @@
+Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors
+ may be used to endorse or promote products derived from this software
+ without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Please also refer to the file CONTRIBUTING.md, which clarifies licensing of
+external contributions to this project including patches, pull requests, etc.
diff --git a/ml/dlib/dlib/external/pybind11/README.md b/ml/dlib/dlib/external/pybind11/README.md
new file mode 100644
index 000000000..447788240
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/README.md
@@ -0,0 +1,129 @@
+![pybind11 logo](https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png)
+
+# pybind11 — Seamless operability between C++11 and Python
+
+[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=master)](http://pybind11.readthedocs.org/en/master/?badge=master)
+[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=stable)](http://pybind11.readthedocs.org/en/stable/?badge=stable)
+[![Gitter chat](https://img.shields.io/gitter/room/gitterHQ/gitter.svg)](https://gitter.im/pybind/Lobby)
+[![Build Status](https://travis-ci.org/pybind/pybind11.svg?branch=master)](https://travis-ci.org/pybind/pybind11)
+[![Build status](https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true)](https://ci.appveyor.com/project/wjakob/pybind11)
+
+**pybind11** is a lightweight header-only library that exposes C++ types in Python
+and vice versa, mainly to create Python bindings of existing C++ code. Its
+goals and syntax are similar to the excellent
+[Boost.Python](http://www.boost.org/doc/libs/1_58_0/libs/python/doc/) library
+by David Abrahams: to minimize boilerplate code in traditional extension
+modules by inferring type information using compile-time introspection.
+
+The main issue with Boost.Python—and the reason for creating such a similar
+project—is Boost. Boost is an enormously large and complex suite of utility
+libraries that works with almost every C++ compiler in existence. This
+compatibility has its cost: arcane template tricks and workarounds are
+necessary to support the oldest and buggiest of compiler specimens. Now that
+C++11-compatible compilers are widely available, this heavy machinery has
+become an excessively large and unnecessary dependency.
+
+Think of this library as a tiny self-contained version of Boost.Python with
+everything stripped away that isn't relevant for binding generation. Without
+comments, the core header files only require ~4K lines of code and depend on
+Python (2.7 or 3.x, or PyPy2.7 >= 5.7) and the C++ standard library. This
+compact implementation was possible thanks to some of the new C++11 language
+features (specifically: tuples, lambda functions and variadic templates). Since
+its creation, this library has grown beyond Boost.Python in many ways, leading
+to dramatically simpler binding code in many common situations.
+
+Tutorial and reference documentation is provided at
+[http://pybind11.readthedocs.org/en/master](http://pybind11.readthedocs.org/en/master).
+A PDF version of the manual is available
+[here](https://media.readthedocs.org/pdf/pybind11/master/pybind11.pdf).
+
+## Core features
+pybind11 can map the following core C++ features to Python
+
+- Functions accepting and returning custom data structures per value, reference, or pointer
+- Instance methods and static methods
+- Overloaded functions
+- Instance attributes and static attributes
+- Arbitrary exception types
+- Enumerations
+- Callbacks
+- Iterators and ranges
+- Custom operators
+- Single and multiple inheritance
+- STL data structures
+- Iterators and ranges
+- Smart pointers with reference counting like ``std::shared_ptr``
+- Internal references with correct reference counting
+- C++ classes with virtual (and pure virtual) methods can be extended in Python
+
+## Goodies
+In addition to the core functionality, pybind11 provides some extra goodies:
+
+- Python 2.7, 3.x, and PyPy (PyPy2.7 >= 5.7) are supported with an
+ implementation-agnostic interface.
+
+- It is possible to bind C++11 lambda functions with captured variables. The
+ lambda capture data is stored inside the resulting Python function object.
+
+- pybind11 uses C++11 move constructors and move assignment operators whenever
+ possible to efficiently transfer custom data types.
+
+- It's easy to expose the internal storage of custom data types through
+ Pythons' buffer protocols. This is handy e.g. for fast conversion between
+ C++ matrix classes like Eigen and NumPy without expensive copy operations.
+
+- pybind11 can automatically vectorize functions so that they are transparently
+ applied to all entries of one or more NumPy array arguments.
+
+- Python's slice-based access and assignment operations can be supported with
+ just a few lines of code.
+
+- Everything is contained in just a few header files; there is no need to link
+ against any additional libraries.
+
+- Binaries are generally smaller by a factor of at least 2 compared to
+ equivalent bindings generated by Boost.Python. A recent pybind11 conversion
+ of PyRosetta, an enormous Boost.Python binding project,
+ [reported](http://graylab.jhu.edu/RosettaCon2016/PyRosetta-4.pdf) a binary
+ size reduction of **5.4x** and compile time reduction by **5.8x**.
+
+- When supported by the compiler, two new C++14 features (relaxed constexpr and
+ return value deduction) are used to precompute function signatures at compile
+ time, leading to smaller binaries.
+
+- With little extra effort, C++ types can be pickled and unpickled similar to
+ regular Python objects.
+
+## Supported compilers
+
+1. Clang/LLVM 3.3 or newer (for Apple Xcode's clang, this is 5.0.0 or newer)
+2. GCC 4.8 or newer
+3. Microsoft Visual Studio 2015 Update 3 or newer
+4. Intel C++ compiler 16 or newer (15 with a [workaround](https://github.com/pybind/pybind11/issues/276))
+5. Cygwin/GCC (tested on 2.5.1)
+
+## About
+
+This project was created by [Wenzel Jakob](http://rgl.epfl.ch/people/wjakob).
+Significant features and/or improvements to the code were contributed by
+Jonas Adler,
+Sylvain Corlay,
+Trent Houliston,
+Axel Huebl,
+@hulucc,
+Sergey Lyskov
+Johan Mabille,
+Tomasz Miąsko,
+Dean Moldovan,
+Ben Pritchard,
+Jason Rhinelander,
+Boris Schäling,
+Pim Schellart,
+Ivan Smirnov, and
+Patrick Stewart.
+
+### License
+
+pybind11 is provided under a BSD-style license that can be found in the
+``LICENSE`` file. By using, distributing, or contributing to this project,
+you agree to the terms and conditions of this license.
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h b/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h
new file mode 100644
index 000000000..dce875a6b
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h
@@ -0,0 +1,489 @@
+/*
+ pybind11/attr.h: Infrastructure for processing custom
+ type and function attributes
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "cast.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// \addtogroup annotations
+/// @{
+
+/// Annotation for methods
+struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
+
+/// Annotation for operators
+struct is_operator { };
+
+/// Annotation for parent scope
+struct scope { handle value; scope(const handle &s) : value(s) { } };
+
+/// Annotation for documentation
+struct doc { const char *value; doc(const char *value) : value(value) { } };
+
+/// Annotation for function names
+struct name { const char *value; name(const char *value) : value(value) { } };
+
+/// Annotation indicating that a function is an overload associated with a given "sibling"
+struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
+
+/// Annotation indicating that a class derives from another given type
+template <typename T> struct base {
+ PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
+ base() { }
+};
+
+/// Keep patient alive while nurse lives
+template <size_t Nurse, size_t Patient> struct keep_alive { };
+
+/// Annotation indicating that a class is involved in a multiple inheritance relationship
+struct multiple_inheritance { };
+
+/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
+struct dynamic_attr { };
+
+/// Annotation which enables the buffer protocol for a type
+struct buffer_protocol { };
+
+/// Annotation which requests that a special metaclass is created for a type
+struct metaclass {
+ handle value;
+
+ PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
+ metaclass() {}
+
+ /// Override pybind11's default metaclass
+ explicit metaclass(handle value) : value(value) { }
+};
+
+/// Annotation that marks a class as local to the module:
+struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
+
+/// Annotation to mark enums as an arithmetic type
+struct arithmetic { };
+
+/** \rst
+ A call policy which places one or more guard variables (``Ts...``) around the function call.
+
+ For example, this definition:
+
+ .. code-block:: cpp
+
+ m.def("foo", foo, py::call_guard<T>());
+
+ is equivalent to the following pseudocode:
+
+ .. code-block:: cpp
+
+ m.def("foo", [](args...) {
+ T scope_guard;
+ return foo(args...); // forwarded arguments
+ });
+ \endrst */
+template <typename... Ts> struct call_guard;
+
+template <> struct call_guard<> { using type = detail::void_type; };
+
+template <typename T>
+struct call_guard<T> {
+ static_assert(std::is_default_constructible<T>::value,
+ "The guard type must be default constructible");
+
+ using type = T;
+};
+
+template <typename T, typename... Ts>
+struct call_guard<T, Ts...> {
+ struct type {
+ T guard{}; // Compose multiple guard types with left-to-right default-constructor order
+ typename call_guard<Ts...>::type next{};
+ };
+};
+
+/// @} annotations
+
+NAMESPACE_BEGIN(detail)
+/* Forward declarations */
+enum op_id : int;
+enum op_type : int;
+struct undefined_t;
+template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
+inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
+
+/// Internal data structure which holds metadata about a keyword argument
+struct argument_record {
+ const char *name; ///< Argument name
+ const char *descr; ///< Human-readable version of the argument value
+ handle value; ///< Associated Python object
+ bool convert : 1; ///< True if the argument is allowed to convert when loading
+ bool none : 1; ///< True if None is allowed when loading
+
+ argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
+ : name(name), descr(descr), value(value), convert(convert), none(none) { }
+};
+
+/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
+struct function_record {
+ function_record()
+ : is_constructor(false), is_new_style_constructor(false), is_stateless(false),
+ is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
+
+ /// Function name
+ char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
+
+ // User-specified documentation string
+ char *doc = nullptr;
+
+ /// Human-readable version of the function signature
+ char *signature = nullptr;
+
+ /// List of registered keyword arguments
+ std::vector<argument_record> args;
+
+ /// Pointer to lambda function which converts arguments and performs the actual call
+ handle (*impl) (function_call &) = nullptr;
+
+ /// Storage for the wrapped function pointer and captured data, if any
+ void *data[3] = { };
+
+ /// Pointer to custom destructor for 'data' (if needed)
+ void (*free_data) (function_record *ptr) = nullptr;
+
+ /// Return value policy associated with this function
+ return_value_policy policy = return_value_policy::automatic;
+
+ /// True if name == '__init__'
+ bool is_constructor : 1;
+
+ /// True if this is a new-style `__init__` defined in `detail/init.h`
+ bool is_new_style_constructor : 1;
+
+ /// True if this is a stateless function pointer
+ bool is_stateless : 1;
+
+ /// True if this is an operator (__add__), etc.
+ bool is_operator : 1;
+
+ /// True if the function has a '*args' argument
+ bool has_args : 1;
+
+ /// True if the function has a '**kwargs' argument
+ bool has_kwargs : 1;
+
+ /// True if this is a method
+ bool is_method : 1;
+
+ /// Number of arguments (including py::args and/or py::kwargs, if present)
+ std::uint16_t nargs;
+
+ /// Python method object
+ PyMethodDef *def = nullptr;
+
+ /// Python handle to the parent scope (a class or a module)
+ handle scope;
+
+ /// Python handle to the sibling function representing an overload chain
+ handle sibling;
+
+ /// Pointer to next overload
+ function_record *next = nullptr;
+};
+
+/// Special data structure which (temporarily) holds metadata about a bound class
+struct type_record {
+ PYBIND11_NOINLINE type_record()
+ : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
+
+ /// Handle to the parent scope
+ handle scope;
+
+ /// Name of the class
+ const char *name = nullptr;
+
+ // Pointer to RTTI type_info data structure
+ const std::type_info *type = nullptr;
+
+ /// How large is the underlying C++ type?
+ size_t type_size = 0;
+
+ /// How large is the type's holder?
+ size_t holder_size = 0;
+
+ /// The global operator new can be overridden with a class-specific variant
+ void *(*operator_new)(size_t) = ::operator new;
+
+ /// Function pointer to class_<..>::init_instance
+ void (*init_instance)(instance *, const void *) = nullptr;
+
+ /// Function pointer to class_<..>::dealloc
+ void (*dealloc)(detail::value_and_holder &) = nullptr;
+
+ /// List of base classes of the newly created type
+ list bases;
+
+ /// Optional docstring
+ const char *doc = nullptr;
+
+ /// Custom metaclass (optional)
+ handle metaclass;
+
+ /// Multiple inheritance marker
+ bool multiple_inheritance : 1;
+
+ /// Does the class manage a __dict__?
+ bool dynamic_attr : 1;
+
+ /// Does the class implement the buffer protocol?
+ bool buffer_protocol : 1;
+
+ /// Is the default (unique_ptr) holder type used?
+ bool default_holder : 1;
+
+ /// Is the class definition local to the module shared object?
+ bool module_local : 1;
+
+ PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
+ auto base_info = detail::get_type_info(base, false);
+ if (!base_info) {
+ std::string tname(base.name());
+ detail::clean_type_id(tname);
+ pybind11_fail("generic_type: type \"" + std::string(name) +
+ "\" referenced unknown base type \"" + tname + "\"");
+ }
+
+ if (default_holder != base_info->default_holder) {
+ std::string tname(base.name());
+ detail::clean_type_id(tname);
+ pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
+ (default_holder ? "does not have" : "has") +
+ " a non-default holder type while its base \"" + tname + "\" " +
+ (base_info->default_holder ? "does not" : "does"));
+ }
+
+ bases.append((PyObject *) base_info->type);
+
+ if (base_info->type->tp_dictoffset != 0)
+ dynamic_attr = true;
+
+ if (caster)
+ base_info->implicit_casts.emplace_back(type, caster);
+ }
+};
+
+inline function_call::function_call(function_record &f, handle p) :
+ func(f), parent(p) {
+ args.reserve(f.nargs);
+ args_convert.reserve(f.nargs);
+}
+
+/// Tag for a new-style `__init__` defined in `detail/init.h`
+struct is_new_style_constructor { };
+
+/**
+ * Partial template specializations to process custom attributes provided to
+ * cpp_function_ and class_. These are either used to initialize the respective
+ * fields in the type_record and function_record data structures or executed at
+ * runtime to deal with custom call policies (e.g. keep_alive).
+ */
+template <typename T, typename SFINAE = void> struct process_attribute;
+
+template <typename T> struct process_attribute_default {
+ /// Default implementation: do nothing
+ static void init(const T &, function_record *) { }
+ static void init(const T &, type_record *) { }
+ static void precall(function_call &) { }
+ static void postcall(function_call &, handle) { }
+};
+
+/// Process an attribute specifying the function's name
+template <> struct process_attribute<name> : process_attribute_default<name> {
+ static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring
+template <> struct process_attribute<doc> : process_attribute_default<doc> {
+ static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring (provided as a C-style string)
+template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
+ static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
+ static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
+};
+template <> struct process_attribute<char *> : process_attribute<const char *> { };
+
+/// Process an attribute indicating the function's return value policy
+template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
+ static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
+};
+
+/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
+template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
+ static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
+};
+
+/// Process an attribute which indicates that this function is a method
+template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
+ static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
+};
+
+/// Process an attribute which indicates the parent scope of a method
+template <> struct process_attribute<scope> : process_attribute_default<scope> {
+ static void init(const scope &s, function_record *r) { r->scope = s.value; }
+};
+
+/// Process an attribute which indicates that this function is an operator
+template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
+ static void init(const is_operator &, function_record *r) { r->is_operator = true; }
+};
+
+template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
+ static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
+};
+
+/// Process a keyword argument attribute (*without* a default value)
+template <> struct process_attribute<arg> : process_attribute_default<arg> {
+ static void init(const arg &a, function_record *r) {
+ if (r->is_method && r->args.empty())
+ r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
+ r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
+ }
+};
+
+/// Process a keyword argument attribute (*with* a default value)
+template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
+ static void init(const arg_v &a, function_record *r) {
+ if (r->is_method && r->args.empty())
+ r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
+
+ if (!a.value) {
+#if !defined(NDEBUG)
+ std::string descr("'");
+ if (a.name) descr += std::string(a.name) + ": ";
+ descr += a.type + "'";
+ if (r->is_method) {
+ if (r->name)
+ descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
+ else
+ descr += " in method of '" + (std::string) str(r->scope) + "'";
+ } else if (r->name) {
+ descr += " in function '" + (std::string) r->name + "'";
+ }
+ pybind11_fail("arg(): could not convert default argument "
+ + descr + " into a Python object (type not registered yet?)");
+#else
+ pybind11_fail("arg(): could not convert default argument "
+ "into a Python object (type not registered yet?). "
+ "Compile in debug mode for more information.");
+#endif
+ }
+ r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
+ }
+};
+
+/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
+template <typename T>
+struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
+ static void init(const handle &h, type_record *r) { r->bases.append(h); }
+};
+
+/// Process a parent class attribute (deprecated, does not support multiple inheritance)
+template <typename T>
+struct process_attribute<base<T>> : process_attribute_default<base<T>> {
+ static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
+};
+
+/// Process a multiple inheritance attribute
+template <>
+struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
+ static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
+};
+
+template <>
+struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
+ static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
+};
+
+template <>
+struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
+ static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
+};
+
+template <>
+struct process_attribute<metaclass> : process_attribute_default<metaclass> {
+ static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
+};
+
+template <>
+struct process_attribute<module_local> : process_attribute_default<module_local> {
+ static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
+};
+
+/// Process an 'arithmetic' attribute for enums (does nothing here)
+template <>
+struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
+
+template <typename... Ts>
+struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
+
+/**
+ * Process a keep_alive call policy -- invokes keep_alive_impl during the
+ * pre-call handler if both Nurse, Patient != 0 and use the post-call handler
+ * otherwise
+ */
+template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+ static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+ static void postcall(function_call &, handle) { }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+ static void precall(function_call &) { }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+ static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
+};
+
+/// Recursively iterate over variadic template arguments
+template <typename... Args> struct process_attributes {
+ static void init(const Args&... args, function_record *r) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+ ignore_unused(unused);
+ }
+ static void init(const Args&... args, type_record *r) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+ ignore_unused(unused);
+ }
+ static void precall(function_call &call) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
+ ignore_unused(unused);
+ }
+ static void postcall(function_call &call, handle fn_ret) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
+ ignore_unused(unused);
+ }
+};
+
+template <typename T>
+using is_call_guard = is_instantiation<call_guard, T>;
+
+/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
+template <typename... Extra>
+using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
+
+/// Check the number of named arguments at compile time
+template <typename... Extra,
+ size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
+ size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
+constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
+ return named == 0 || (self + named + has_args + has_kwargs) == nargs;
+}
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h b/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h
new file mode 100644
index 000000000..9f072fa73
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h
@@ -0,0 +1,108 @@
+/*
+ pybind11/buffer_info.h: Python buffer object interface
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// Information record describing a Python buffer object
+struct buffer_info {
+ void *ptr = nullptr; // Pointer to the underlying storage
+ ssize_t itemsize = 0; // Size of individual items in bytes
+ ssize_t size = 0; // Total number of entries
+ std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
+ ssize_t ndim = 0; // Number of dimensions
+ std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
+ std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
+
+ buffer_info() { }
+
+ buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+ detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+ : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
+ shape(std::move(shape_in)), strides(std::move(strides_in)) {
+ if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
+ pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
+ for (size_t i = 0; i < (size_t) ndim; ++i)
+ size *= shape[i];
+ }
+
+ template <typename T>
+ buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+ : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
+
+ buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
+ : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
+
+ template <typename T>
+ buffer_info(T *ptr, ssize_t size)
+ : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
+
+ explicit buffer_info(Py_buffer *view, bool ownview = true)
+ : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+ {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
+ this->view = view;
+ this->ownview = ownview;
+ }
+
+ buffer_info(const buffer_info &) = delete;
+ buffer_info& operator=(const buffer_info &) = delete;
+
+ buffer_info(buffer_info &&other) {
+ (*this) = std::move(other);
+ }
+
+ buffer_info& operator=(buffer_info &&rhs) {
+ ptr = rhs.ptr;
+ itemsize = rhs.itemsize;
+ size = rhs.size;
+ format = std::move(rhs.format);
+ ndim = rhs.ndim;
+ shape = std::move(rhs.shape);
+ strides = std::move(rhs.strides);
+ std::swap(view, rhs.view);
+ std::swap(ownview, rhs.ownview);
+ return *this;
+ }
+
+ ~buffer_info() {
+ if (view && ownview) { PyBuffer_Release(view); delete view; }
+ }
+
+private:
+ struct private_ctr_tag { };
+
+ buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+ detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
+ : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
+
+ Py_buffer *view = nullptr;
+ bool ownview = false;
+};
+
+NAMESPACE_BEGIN(detail)
+
+template <typename T, typename SFINAE = void> struct compare_buffer_info {
+ static bool compare(const buffer_info& b) {
+ return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
+ }
+};
+
+template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
+ static bool compare(const buffer_info& b) {
+ return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
+ ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
+ ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
+ }
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h b/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h
new file mode 100644
index 000000000..a722a9e81
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h
@@ -0,0 +1,2063 @@
+/*
+ pybind11/cast.h: Partial template specializations to cast between
+ C++ and Python types
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pytypes.h"
+#include "detail/typeid.h"
+#include "detail/descr.h"
+#include "detail/internals.h"
+#include <array>
+#include <limits>
+#include <tuple>
+
+#if defined(PYBIND11_CPP17)
+# if defined(__has_include)
+# if __has_include(<string_view>)
+# define PYBIND11_HAS_STRING_VIEW
+# endif
+# elif defined(_MSC_VER)
+# define PYBIND11_HAS_STRING_VIEW
+# endif
+#endif
+#ifdef PYBIND11_HAS_STRING_VIEW
+#include <string_view>
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// A life support system for temporary objects created by `type_caster::load()`.
+/// Adding a patient will keep it alive up until the enclosing function returns.
+class loader_life_support {
+public:
+ /// A new patient frame is created when a function is entered
+ loader_life_support() {
+ get_internals().loader_patient_stack.push_back(nullptr);
+ }
+
+ /// ... and destroyed after it returns
+ ~loader_life_support() {
+ auto &stack = get_internals().loader_patient_stack;
+ if (stack.empty())
+ pybind11_fail("loader_life_support: internal error");
+
+ auto ptr = stack.back();
+ stack.pop_back();
+ Py_CLEAR(ptr);
+
+ // A heuristic to reduce the stack's capacity (e.g. after long recursive calls)
+ if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2)
+ stack.shrink_to_fit();
+ }
+
+ /// This can only be used inside a pybind11-bound function, either by `argument_loader`
+ /// at argument preparation time or by `py::cast()` at execution time.
+ PYBIND11_NOINLINE static void add_patient(handle h) {
+ auto &stack = get_internals().loader_patient_stack;
+ if (stack.empty())
+ throw cast_error("When called outside a bound function, py::cast() cannot "
+ "do Python -> C++ conversions which require the creation "
+ "of temporary values");
+
+ auto &list_ptr = stack.back();
+ if (list_ptr == nullptr) {
+ list_ptr = PyList_New(1);
+ if (!list_ptr)
+ pybind11_fail("loader_life_support: error allocating list");
+ PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr());
+ } else {
+ auto result = PyList_Append(list_ptr, h.ptr());
+ if (result == -1)
+ pybind11_fail("loader_life_support: error adding patient");
+ }
+ }
+};
+
+// Gets the cache entry for the given type, creating it if necessary. The return value is the pair
+// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was
+// just created.
+inline std::pair<decltype(internals::registered_types_py)::iterator, bool> all_type_info_get_cache(PyTypeObject *type);
+
+// Populates a just-created cache entry.
+PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
+ std::vector<PyTypeObject *> check;
+ for (handle parent : reinterpret_borrow<tuple>(t->tp_bases))
+ check.push_back((PyTypeObject *) parent.ptr());
+
+ auto const &type_dict = get_internals().registered_types_py;
+ for (size_t i = 0; i < check.size(); i++) {
+ auto type = check[i];
+ // Ignore Python2 old-style class super type:
+ if (!PyType_Check((PyObject *) type)) continue;
+
+ // Check `type` in the current set of registered python types:
+ auto it = type_dict.find(type);
+ if (it != type_dict.end()) {
+ // We found a cache entry for it, so it's either pybind-registered or has pre-computed
+ // pybind bases, but we have to make sure we haven't already seen the type(s) before: we
+ // want to follow Python/virtual C++ rules that there should only be one instance of a
+ // common base.
+ for (auto *tinfo : it->second) {
+ // NB: Could use a second set here, rather than doing a linear search, but since
+ // having a large number of immediate pybind11-registered types seems fairly
+ // unlikely, that probably isn't worthwhile.
+ bool found = false;
+ for (auto *known : bases) {
+ if (known == tinfo) { found = true; break; }
+ }
+ if (!found) bases.push_back(tinfo);
+ }
+ }
+ else if (type->tp_bases) {
+ // It's some python type, so keep follow its bases classes to look for one or more
+ // registered types
+ if (i + 1 == check.size()) {
+ // When we're at the end, we can pop off the current element to avoid growing
+ // `check` when adding just one base (which is typical--i.e. when there is no
+ // multiple inheritance)
+ check.pop_back();
+ i--;
+ }
+ for (handle parent : reinterpret_borrow<tuple>(type->tp_bases))
+ check.push_back((PyTypeObject *) parent.ptr());
+ }
+ }
+}
+
+/**
+ * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will
+ * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side
+ * derived class that uses single inheritance. Will contain as many types as required for a Python
+ * class that uses multiple inheritance to inherit (directly or indirectly) from multiple
+ * pybind-registered classes. Will be empty if neither the type nor any base classes are
+ * pybind-registered.
+ *
+ * The value is cached for the lifetime of the Python type.
+ */
+inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
+ auto ins = all_type_info_get_cache(type);
+ if (ins.second)
+ // New cache entry: populate it
+ all_type_info_populate(type, ins.first->second);
+
+ return ins.first->second;
+}
+
+/**
+ * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any
+ * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use
+ * `all_type_info` instead if you want to support multiple bases.
+ */
+PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
+ auto &bases = all_type_info(type);
+ if (bases.size() == 0)
+ return nullptr;
+ if (bases.size() > 1)
+ pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases");
+ return bases.front();
+}
+
+inline detail::type_info *get_local_type_info(const std::type_index &tp) {
+ auto &locals = registered_local_types_cpp();
+ auto it = locals.find(tp);
+ if (it != locals.end())
+ return it->second;
+ return nullptr;
+}
+
+inline detail::type_info *get_global_type_info(const std::type_index &tp) {
+ auto &types = get_internals().registered_types_cpp;
+ auto it = types.find(tp);
+ if (it != types.end())
+ return it->second;
+ return nullptr;
+}
+
+/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
+PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
+ bool throw_if_missing = false) {
+ if (auto ltype = get_local_type_info(tp))
+ return ltype;
+ if (auto gtype = get_global_type_info(tp))
+ return gtype;
+
+ if (throw_if_missing) {
+ std::string tname = tp.name();
+ detail::clean_type_id(tname);
+ pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\"");
+ }
+ return nullptr;
+}
+
+PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) {
+ detail::type_info *type_info = get_type_info(tp, throw_if_missing);
+ return handle(type_info ? ((PyObject *) type_info->type) : nullptr);
+}
+
+struct value_and_holder {
+ instance *inst;
+ size_t index;
+ const detail::type_info *type;
+ void **vh;
+
+ // Main constructor for a found value/holder:
+ value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) :
+ inst{i}, index{index}, type{type},
+ vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]}
+ {}
+
+ // Default constructor (used to signal a value-and-holder not found by get_value_and_holder())
+ value_and_holder() : inst{nullptr} {}
+
+ // Used for past-the-end iterator
+ value_and_holder(size_t index) : index{index} {}
+
+ template <typename V = void> V *&value_ptr() const {
+ return reinterpret_cast<V *&>(vh[0]);
+ }
+ // True if this `value_and_holder` has a non-null value pointer
+ explicit operator bool() const { return value_ptr(); }
+
+ template <typename H> H &holder() const {
+ return reinterpret_cast<H &>(vh[1]);
+ }
+ bool holder_constructed() const {
+ return inst->simple_layout
+ ? inst->simple_holder_constructed
+ : inst->nonsimple.status[index] & instance::status_holder_constructed;
+ }
+ void set_holder_constructed(bool v = true) {
+ if (inst->simple_layout)
+ inst->simple_holder_constructed = v;
+ else if (v)
+ inst->nonsimple.status[index] |= instance::status_holder_constructed;
+ else
+ inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed;
+ }
+ bool instance_registered() const {
+ return inst->simple_layout
+ ? inst->simple_instance_registered
+ : inst->nonsimple.status[index] & instance::status_instance_registered;
+ }
+ void set_instance_registered(bool v = true) {
+ if (inst->simple_layout)
+ inst->simple_instance_registered = v;
+ else if (v)
+ inst->nonsimple.status[index] |= instance::status_instance_registered;
+ else
+ inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered;
+ }
+};
+
+// Container for accessing and iterating over an instance's values/holders
+struct values_and_holders {
+private:
+ instance *inst;
+ using type_vec = std::vector<detail::type_info *>;
+ const type_vec &tinfo;
+
+public:
+ values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {}
+
+ struct iterator {
+ private:
+ instance *inst;
+ const type_vec *types;
+ value_and_holder curr;
+ friend struct values_and_holders;
+ iterator(instance *inst, const type_vec *tinfo)
+ : inst{inst}, types{tinfo},
+ curr(inst /* instance */,
+ types->empty() ? nullptr : (*types)[0] /* type info */,
+ 0, /* vpos: (non-simple types only): the first vptr comes first */
+ 0 /* index */)
+ {}
+ // Past-the-end iterator:
+ iterator(size_t end) : curr(end) {}
+ public:
+ bool operator==(const iterator &other) { return curr.index == other.curr.index; }
+ bool operator!=(const iterator &other) { return curr.index != other.curr.index; }
+ iterator &operator++() {
+ if (!inst->simple_layout)
+ curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs;
+ ++curr.index;
+ curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr;
+ return *this;
+ }
+ value_and_holder &operator*() { return curr; }
+ value_and_holder *operator->() { return &curr; }
+ };
+
+ iterator begin() { return iterator(inst, &tinfo); }
+ iterator end() { return iterator(tinfo.size()); }
+
+ iterator find(const type_info *find_type) {
+ auto it = begin(), endit = end();
+ while (it != endit && it->type != find_type) ++it;
+ return it;
+ }
+
+ size_t size() { return tinfo.size(); }
+};
+
+/**
+ * Extracts C++ value and holder pointer references from an instance (which may contain multiple
+ * values/holders for python-side multiple inheritance) that match the given type. Throws an error
+ * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If
+ * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned,
+ * regardless of type (and the resulting .type will be nullptr).
+ *
+ * The returned object should be short-lived: in particular, it must not outlive the called-upon
+ * instance.
+ */
+PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) {
+ // Optimize common case:
+ if (!find_type || Py_TYPE(this) == find_type->type)
+ return value_and_holder(this, find_type, 0, 0);
+
+ detail::values_and_holders vhs(this);
+ auto it = vhs.find(find_type);
+ if (it != vhs.end())
+ return *it;
+
+ if (!throw_if_missing)
+ return value_and_holder();
+
+#if defined(NDEBUG)
+ pybind11_fail("pybind11::detail::instance::get_value_and_holder: "
+ "type is not a pybind11 base of the given instance "
+ "(compile in debug mode for type details)");
+#else
+ pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" +
+ std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" +
+ std::string(Py_TYPE(this)->tp_name) + "' instance");
+#endif
+}
+
+PYBIND11_NOINLINE inline void instance::allocate_layout() {
+ auto &tinfo = all_type_info(Py_TYPE(this));
+
+ const size_t n_types = tinfo.size();
+
+ if (n_types == 0)
+ pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types");
+
+ simple_layout =
+ n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs();
+
+ // Simple path: no python-side multiple inheritance, and a small-enough holder
+ if (simple_layout) {
+ simple_value_holder[0] = nullptr;
+ simple_holder_constructed = false;
+ simple_instance_registered = false;
+ }
+ else { // multiple base types or a too-large holder
+ // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer,
+ // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool
+ // values that tracks whether each associated holder has been initialized. Each [block] is
+ // padded, if necessary, to an integer multiple of sizeof(void *).
+ size_t space = 0;
+ for (auto t : tinfo) {
+ space += 1; // value pointer
+ space += t->holder_size_in_ptrs; // holder instance
+ }
+ size_t flags_at = space;
+ space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered)
+
+ // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values,
+ // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6
+ // they default to using pymalloc, which is designed to be efficient for small allocations
+ // like the one we're doing here; in earlier versions (and for larger allocations) they are
+ // just wrappers around malloc.
+#if PY_VERSION_HEX >= 0x03050000
+ nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *));
+ if (!nonsimple.values_and_holders) throw std::bad_alloc();
+#else
+ nonsimple.values_and_holders = (void **) PyMem_New(void *, space);
+ if (!nonsimple.values_and_holders) throw std::bad_alloc();
+ std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *));
+#endif
+ nonsimple.status = reinterpret_cast<uint8_t *>(&nonsimple.values_and_holders[flags_at]);
+ }
+ owned = true;
+}
+
+PYBIND11_NOINLINE inline void instance::deallocate_layout() {
+ if (!simple_layout)
+ PyMem_Free(nonsimple.values_and_holders);
+}
+
+PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) {
+ handle type = detail::get_type_handle(tp, false);
+ if (!type)
+ return false;
+ return isinstance(obj, type);
+}
+
+PYBIND11_NOINLINE inline std::string error_string() {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred");
+ return "Unknown internal error occurred";
+ }
+
+ error_scope scope; // Preserve error state
+
+ std::string errorString;
+ if (scope.type) {
+ errorString += handle(scope.type).attr("__name__").cast<std::string>();
+ errorString += ": ";
+ }
+ if (scope.value)
+ errorString += (std::string) str(scope.value);
+
+ PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace);
+
+#if PY_MAJOR_VERSION >= 3
+ if (scope.trace != nullptr)
+ PyException_SetTraceback(scope.value, scope.trace);
+#endif
+
+#if !defined(PYPY_VERSION)
+ if (scope.trace) {
+ PyTracebackObject *trace = (PyTracebackObject *) scope.trace;
+
+ /* Get the deepest trace possible */
+ while (trace->tb_next)
+ trace = trace->tb_next;
+
+ PyFrameObject *frame = trace->tb_frame;
+ errorString += "\n\nAt:\n";
+ while (frame) {
+ int lineno = PyFrame_GetLineNumber(frame);
+ errorString +=
+ " " + handle(frame->f_code->co_filename).cast<std::string>() +
+ "(" + std::to_string(lineno) + "): " +
+ handle(frame->f_code->co_name).cast<std::string>() + "\n";
+ frame = frame->f_back;
+ }
+ }
+#endif
+
+ return errorString;
+}
+
+PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) {
+ auto &instances = get_internals().registered_instances;
+ auto range = instances.equal_range(ptr);
+ for (auto it = range.first; it != range.second; ++it) {
+ for (auto vh : values_and_holders(it->second)) {
+ if (vh.type == type)
+ return handle((PyObject *) it->second);
+ }
+ }
+ return handle();
+}
+
+inline PyThreadState *get_thread_state_unchecked() {
+#if defined(PYPY_VERSION)
+ return PyThreadState_GET();
+#elif PY_VERSION_HEX < 0x03000000
+ return _PyThreadState_Current;
+#elif PY_VERSION_HEX < 0x03050000
+ return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current);
+#elif PY_VERSION_HEX < 0x03050200
+ return (PyThreadState*) _PyThreadState_Current.value;
+#else
+ return _PyThreadState_UncheckedGet();
+#endif
+}
+
+// Forward declarations
+inline void keep_alive_impl(handle nurse, handle patient);
+inline PyObject *make_new_instance(PyTypeObject *type);
+
+class type_caster_generic {
+public:
+ PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
+ : typeinfo(get_type_info(type_info)), cpptype(&type_info) { }
+
+ type_caster_generic(const type_info *typeinfo)
+ : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { }
+
+ bool load(handle src, bool convert) {
+ return load_impl<type_caster_generic>(src, convert);
+ }
+
+ PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent,
+ const detail::type_info *tinfo,
+ void *(*copy_constructor)(const void *),
+ void *(*move_constructor)(const void *),
+ const void *existing_holder = nullptr) {
+ if (!tinfo) // no type info: error will be set already
+ return handle();
+
+ void *src = const_cast<void *>(_src);
+ if (src == nullptr)
+ return none().release();
+
+ auto it_instances = get_internals().registered_instances.equal_range(src);
+ for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) {
+ for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) {
+ if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype))
+ return handle((PyObject *) it_i->second).inc_ref();
+ }
+ }
+
+ auto inst = reinterpret_steal<object>(make_new_instance(tinfo->type));
+ auto wrapper = reinterpret_cast<instance *>(inst.ptr());
+ wrapper->owned = false;
+ void *&valueptr = values_and_holders(wrapper).begin()->value_ptr();
+
+ switch (policy) {
+ case return_value_policy::automatic:
+ case return_value_policy::take_ownership:
+ valueptr = src;
+ wrapper->owned = true;
+ break;
+
+ case return_value_policy::automatic_reference:
+ case return_value_policy::reference:
+ valueptr = src;
+ wrapper->owned = false;
+ break;
+
+ case return_value_policy::copy:
+ if (copy_constructor)
+ valueptr = copy_constructor(src);
+ else
+ throw cast_error("return_value_policy = copy, but the "
+ "object is non-copyable!");
+ wrapper->owned = true;
+ break;
+
+ case return_value_policy::move:
+ if (move_constructor)
+ valueptr = move_constructor(src);
+ else if (copy_constructor)
+ valueptr = copy_constructor(src);
+ else
+ throw cast_error("return_value_policy = move, but the "
+ "object is neither movable nor copyable!");
+ wrapper->owned = true;
+ break;
+
+ case return_value_policy::reference_internal:
+ valueptr = src;
+ wrapper->owned = false;
+ keep_alive_impl(inst, parent);
+ break;
+
+ default:
+ throw cast_error("unhandled return_value_policy: should not happen!");
+ }
+
+ tinfo->init_instance(wrapper, existing_holder);
+
+ return inst.release();
+ }
+
+ // Base methods for generic caster; there are overridden in copyable_holder_caster
+ void load_value(value_and_holder &&v_h) {
+ auto *&vptr = v_h.value_ptr();
+ // Lazy allocation for unallocated values:
+ if (vptr == nullptr) {
+ auto *type = v_h.type ? v_h.type : typeinfo;
+ vptr = type->operator_new(type->type_size);
+ }
+ value = vptr;
+ }
+ bool try_implicit_casts(handle src, bool convert) {
+ for (auto &cast : typeinfo->implicit_casts) {
+ type_caster_generic sub_caster(*cast.first);
+ if (sub_caster.load(src, convert)) {
+ value = cast.second(sub_caster.value);
+ return true;
+ }
+ }
+ return false;
+ }
+ bool try_direct_conversions(handle src) {
+ for (auto &converter : *typeinfo->direct_conversions) {
+ if (converter(src.ptr(), value))
+ return true;
+ }
+ return false;
+ }
+ void check_holder_compat() {}
+
+ PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) {
+ auto caster = type_caster_generic(ti);
+ if (caster.load(src, false))
+ return caster.value;
+ return nullptr;
+ }
+
+ /// Try to load with foreign typeinfo, if available. Used when there is no
+ /// native typeinfo, or when the native one wasn't able to produce a value.
+ PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) {
+ constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID;
+ const auto pytype = src.get_type();
+ if (!hasattr(pytype, local_key))
+ return false;
+
+ type_info *foreign_typeinfo = reinterpret_borrow<capsule>(getattr(pytype, local_key));
+ // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type
+ if (foreign_typeinfo->module_local_load == &local_load
+ || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype)))
+ return false;
+
+ if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) {
+ value = result;
+ return true;
+ }
+ return false;
+ }
+
+ // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant
+ // bits of code between here and copyable_holder_caster where the two classes need different
+ // logic (without having to resort to virtual inheritance).
+ template <typename ThisT>
+ PYBIND11_NOINLINE bool load_impl(handle src, bool convert) {
+ if (!src) return false;
+ if (!typeinfo) return try_load_foreign_module_local(src);
+ if (src.is_none()) {
+ // Defer accepting None to other overloads (if we aren't in convert mode):
+ if (!convert) return false;
+ value = nullptr;
+ return true;
+ }
+
+ auto &this_ = static_cast<ThisT &>(*this);
+ this_.check_holder_compat();
+
+ PyTypeObject *srctype = Py_TYPE(src.ptr());
+
+ // Case 1: If src is an exact type match for the target type then we can reinterpret_cast
+ // the instance's value pointer to the target type:
+ if (srctype == typeinfo->type) {
+ this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder());
+ return true;
+ }
+ // Case 2: We have a derived class
+ else if (PyType_IsSubtype(srctype, typeinfo->type)) {
+ auto &bases = all_type_info(srctype);
+ bool no_cpp_mi = typeinfo->simple_type;
+
+ // Case 2a: the python type is a Python-inherited derived class that inherits from just
+ // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of
+ // the right type and we can use reinterpret_cast.
+ // (This is essentially the same as case 2b, but because not using multiple inheritance
+ // is extremely common, we handle it specially to avoid the loop iterator and type
+ // pointer lookup overhead)
+ if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) {
+ this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder());
+ return true;
+ }
+ // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if
+ // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we
+ // can safely reinterpret_cast to the relevant pointer.
+ else if (bases.size() > 1) {
+ for (auto base : bases) {
+ if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) {
+ this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(base));
+ return true;
+ }
+ }
+ }
+
+ // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match
+ // in the registered bases, above, so try implicit casting (needed for proper C++ casting
+ // when MI is involved).
+ if (this_.try_implicit_casts(src, convert))
+ return true;
+ }
+
+ // Perform an implicit conversion
+ if (convert) {
+ for (auto &converter : typeinfo->implicit_conversions) {
+ auto temp = reinterpret_steal<object>(converter(src.ptr(), typeinfo->type));
+ if (load_impl<ThisT>(temp, false)) {
+ loader_life_support::add_patient(temp);
+ return true;
+ }
+ }
+ if (this_.try_direct_conversions(src))
+ return true;
+ }
+
+ // Failed to match local typeinfo. Try again with global.
+ if (typeinfo->module_local) {
+ if (auto gtype = get_global_type_info(*typeinfo->cpptype)) {
+ typeinfo = gtype;
+ return load(src, false);
+ }
+ }
+
+ // Global typeinfo has precedence over foreign module_local
+ return try_load_foreign_module_local(src);
+ }
+
+
+ // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast
+ // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair
+ // with .second = nullptr. (p.first = nullptr is not an error: it becomes None).
+ PYBIND11_NOINLINE static std::pair<const void *, const type_info *> src_and_type(
+ const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) {
+ if (auto *tpi = get_type_info(cast_type))
+ return {src, const_cast<const type_info *>(tpi)};
+
+ // Not found, set error:
+ std::string tname = rtti_type ? rtti_type->name() : cast_type.name();
+ detail::clean_type_id(tname);
+ std::string msg = "Unregistered type : " + tname;
+ PyErr_SetString(PyExc_TypeError, msg.c_str());
+ return {nullptr, nullptr};
+ }
+
+ const type_info *typeinfo = nullptr;
+ const std::type_info *cpptype = nullptr;
+ void *value = nullptr;
+};
+
+/**
+ * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster
+ * needs to provide `operator T*()` and `operator T&()` operators.
+ *
+ * If the type supports moving the value away via an `operator T&&() &&` method, it should use
+ * `movable_cast_op_type` instead.
+ */
+template <typename T>
+using cast_op_type =
+ conditional_t<std::is_pointer<remove_reference_t<T>>::value,
+ typename std::add_pointer<intrinsic_t<T>>::type,
+ typename std::add_lvalue_reference<intrinsic_t<T>>::type>;
+
+/**
+ * Determine suitable casting operator for a type caster with a movable value. Such a type caster
+ * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be
+ * called in appropriate contexts where the value can be moved rather than copied.
+ *
+ * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro.
+ */
+template <typename T>
+using movable_cast_op_type =
+ conditional_t<std::is_pointer<typename std::remove_reference<T>::type>::value,
+ typename std::add_pointer<intrinsic_t<T>>::type,
+ conditional_t<std::is_rvalue_reference<T>::value,
+ typename std::add_rvalue_reference<intrinsic_t<T>>::type,
+ typename std::add_lvalue_reference<intrinsic_t<T>>::type>>;
+
+// std::is_copy_constructible isn't quite enough: it lets std::vector<T> (and similar) through when
+// T is non-copyable, but code containing such a copy constructor fails to actually compile.
+template <typename T, typename SFINAE = void> struct is_copy_constructible : std::is_copy_constructible<T> {};
+
+// Specialization for types that appear to be copy constructible but also look like stl containers
+// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if
+// so, copy constructability depends on whether the value_type is copy constructible.
+template <typename Container> struct is_copy_constructible<Container, enable_if_t<all_of<
+ std::is_copy_constructible<Container>,
+ std::is_same<typename Container::value_type &, typename Container::reference>
+ >::value>> : is_copy_constructible<typename Container::value_type> {};
+
+#if !defined(PYBIND11_CPP17)
+// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the
+// two types aren't themselves copy constructible).
+template <typename T1, typename T2> struct is_copy_constructible<std::pair<T1, T2>>
+ : all_of<is_copy_constructible<T1>, is_copy_constructible<T2>> {};
+#endif
+
+/// Generic type caster for objects stored on the heap
+template <typename type> class type_caster_base : public type_caster_generic {
+ using itype = intrinsic_t<type>;
+public:
+ static PYBIND11_DESCR name() { return type_descr(_<type>()); }
+
+ type_caster_base() : type_caster_base(typeid(type)) { }
+ explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { }
+
+ static handle cast(const itype &src, return_value_policy policy, handle parent) {
+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
+ policy = return_value_policy::copy;
+ return cast(&src, policy, parent);
+ }
+
+ static handle cast(itype &&src, return_value_policy, handle parent) {
+ return cast(&src, return_value_policy::move, parent);
+ }
+
+ // Returns a (pointer, type_info) pair taking care of necessary RTTI type lookup for a
+ // polymorphic type. If the instance isn't derived, returns the non-RTTI base version.
+ template <typename T = itype, enable_if_t<std::is_polymorphic<T>::value, int> = 0>
+ static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
+ const void *vsrc = src;
+ auto &cast_type = typeid(itype);
+ const std::type_info *instance_type = nullptr;
+ if (vsrc) {
+ instance_type = &typeid(*src);
+ if (!same_type(cast_type, *instance_type)) {
+ // This is a base pointer to a derived type; if it is a pybind11-registered type, we
+ // can get the correct derived pointer (which may be != base pointer) by a
+ // dynamic_cast to most derived type:
+ if (auto *tpi = get_type_info(*instance_type))
+ return {dynamic_cast<const void *>(src), const_cast<const type_info *>(tpi)};
+ }
+ }
+ // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so
+ // don't do a cast
+ return type_caster_generic::src_and_type(vsrc, cast_type, instance_type);
+ }
+
+ // Non-polymorphic type, so no dynamic casting; just call the generic version directly
+ template <typename T = itype, enable_if_t<!std::is_polymorphic<T>::value, int> = 0>
+ static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
+ return type_caster_generic::src_and_type(src, typeid(itype));
+ }
+
+ static handle cast(const itype *src, return_value_policy policy, handle parent) {
+ auto st = src_and_type(src);
+ return type_caster_generic::cast(
+ st.first, policy, parent, st.second,
+ make_copy_constructor(src), make_move_constructor(src));
+ }
+
+ static handle cast_holder(const itype *src, const void *holder) {
+ auto st = src_and_type(src);
+ return type_caster_generic::cast(
+ st.first, return_value_policy::take_ownership, {}, st.second,
+ nullptr, nullptr, holder);
+ }
+
+ template <typename T> using cast_op_type = cast_op_type<T>;
+
+ operator itype*() { return (type *) value; }
+ operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); }
+
+protected:
+ using Constructor = void *(*)(const void *);
+
+ /* Only enabled when the types are {copy,move}-constructible *and* when the type
+ does not have a private operator new implementation. */
+ template <typename T, typename = enable_if_t<is_copy_constructible<T>::value>>
+ static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) {
+ return [](const void *arg) -> void * {
+ return new T(*reinterpret_cast<const T *>(arg));
+ };
+ }
+
+ template <typename T, typename = enable_if_t<std::is_move_constructible<T>::value>>
+ static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast<T *>(x))), Constructor{}) {
+ return [](const void *arg) -> void * {
+ return new T(std::move(*const_cast<T *>(reinterpret_cast<const T *>(arg))));
+ };
+ }
+
+ static Constructor make_copy_constructor(...) { return nullptr; }
+ static Constructor make_move_constructor(...) { return nullptr; }
+};
+
+template <typename type, typename SFINAE = void> class type_caster : public type_caster_base<type> { };
+template <typename type> using make_caster = type_caster<intrinsic_t<type>>;
+
+// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T
+template <typename T> typename make_caster<T>::template cast_op_type<T> cast_op(make_caster<T> &caster) {
+ return caster.operator typename make_caster<T>::template cast_op_type<T>();
+}
+template <typename T> typename make_caster<T>::template cast_op_type<typename std::add_rvalue_reference<T>::type>
+cast_op(make_caster<T> &&caster) {
+ return std::move(caster).operator
+ typename make_caster<T>::template cast_op_type<typename std::add_rvalue_reference<T>::type>();
+}
+
+template <typename type> class type_caster<std::reference_wrapper<type>> {
+private:
+ using caster_t = make_caster<type>;
+ caster_t subcaster;
+ using subcaster_cast_op_type = typename caster_t::template cast_op_type<type>;
+ static_assert(std::is_same<typename std::remove_const<type>::type &, subcaster_cast_op_type>::value,
+ "std::reference_wrapper<T> caster requires T to have a caster with an `T &` operator");
+public:
+ bool load(handle src, bool convert) { return subcaster.load(src, convert); }
+ static PYBIND11_DESCR name() { return caster_t::name(); }
+ static handle cast(const std::reference_wrapper<type> &src, return_value_policy policy, handle parent) {
+ // It is definitely wrong to take ownership of this pointer, so mask that rvp
+ if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic)
+ policy = return_value_policy::automatic_reference;
+ return caster_t::cast(&src.get(), policy, parent);
+ }
+ template <typename T> using cast_op_type = std::reference_wrapper<type>;
+ operator std::reference_wrapper<type>() { return subcaster.operator subcaster_cast_op_type&(); }
+};
+
+#define PYBIND11_TYPE_CASTER(type, py_name) \
+ protected: \
+ type value; \
+ public: \
+ static PYBIND11_DESCR name() { return type_descr(py_name); } \
+ template <typename T_, enable_if_t<std::is_same<type, remove_cv_t<T_>>::value, int> = 0> \
+ static handle cast(T_ *src, return_value_policy policy, handle parent) { \
+ if (!src) return none().release(); \
+ if (policy == return_value_policy::take_ownership) { \
+ auto h = cast(std::move(*src), policy, parent); delete src; return h; \
+ } else { \
+ return cast(*src, policy, parent); \
+ } \
+ } \
+ operator type*() { return &value; } \
+ operator type&() { return value; } \
+ operator type&&() && { return std::move(value); } \
+ template <typename T_> using cast_op_type = pybind11::detail::movable_cast_op_type<T_>
+
+
+template <typename CharT> using is_std_char_type = any_of<
+ std::is_same<CharT, char>, /* std::string */
+ std::is_same<CharT, char16_t>, /* std::u16string */
+ std::is_same<CharT, char32_t>, /* std::u32string */
+ std::is_same<CharT, wchar_t> /* std::wstring */
+>;
+
+template <typename T>
+struct type_caster<T, enable_if_t<std::is_arithmetic<T>::value && !is_std_char_type<T>::value>> {
+ using _py_type_0 = conditional_t<sizeof(T) <= sizeof(long), long, long long>;
+ using _py_type_1 = conditional_t<std::is_signed<T>::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>;
+ using py_type = conditional_t<std::is_floating_point<T>::value, double, _py_type_1>;
+public:
+
+ bool load(handle src, bool convert) {
+ py_type py_value;
+
+ if (!src)
+ return false;
+
+ if (std::is_floating_point<T>::value) {
+ if (convert || PyFloat_Check(src.ptr()))
+ py_value = (py_type) PyFloat_AsDouble(src.ptr());
+ else
+ return false;
+ } else if (PyFloat_Check(src.ptr())) {
+ return false;
+ } else if (std::is_unsigned<py_type>::value) {
+ py_value = as_unsigned<py_type>(src.ptr());
+ } else { // signed integer:
+ py_value = sizeof(T) <= sizeof(long)
+ ? (py_type) PyLong_AsLong(src.ptr())
+ : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr());
+ }
+
+ bool py_err = py_value == (py_type) -1 && PyErr_Occurred();
+ if (py_err || (std::is_integral<T>::value && sizeof(py_type) != sizeof(T) &&
+ (py_value < (py_type) std::numeric_limits<T>::min() ||
+ py_value > (py_type) std::numeric_limits<T>::max()))) {
+ bool type_error = py_err && PyErr_ExceptionMatches(
+#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION)
+ PyExc_SystemError
+#else
+ PyExc_TypeError
+#endif
+ );
+ PyErr_Clear();
+ if (type_error && convert && PyNumber_Check(src.ptr())) {
+ auto tmp = reinterpret_steal<object>(std::is_floating_point<T>::value
+ ? PyNumber_Float(src.ptr())
+ : PyNumber_Long(src.ptr()));
+ PyErr_Clear();
+ return load(tmp, false);
+ }
+ return false;
+ }
+
+ value = (T) py_value;
+ return true;
+ }
+
+ static handle cast(T src, return_value_policy /* policy */, handle /* parent */) {
+ if (std::is_floating_point<T>::value) {
+ return PyFloat_FromDouble((double) src);
+ } else if (sizeof(T) <= sizeof(long)) {
+ if (std::is_signed<T>::value)
+ return PyLong_FromLong((long) src);
+ else
+ return PyLong_FromUnsignedLong((unsigned long) src);
+ } else {
+ if (std::is_signed<T>::value)
+ return PyLong_FromLongLong((long long) src);
+ else
+ return PyLong_FromUnsignedLongLong((unsigned long long) src);
+ }
+ }
+
+ PYBIND11_TYPE_CASTER(T, _<std::is_integral<T>::value>("int", "float"));
+};
+
+template<typename T> struct void_caster {
+public:
+ bool load(handle src, bool) {
+ if (src && src.is_none())
+ return true;
+ return false;
+ }
+ static handle cast(T, return_value_policy /* policy */, handle /* parent */) {
+ return none().inc_ref();
+ }
+ PYBIND11_TYPE_CASTER(T, _("None"));
+};
+
+template <> class type_caster<void_type> : public void_caster<void_type> {};
+
+template <> class type_caster<void> : public type_caster<void_type> {
+public:
+ using type_caster<void_type>::cast;
+
+ bool load(handle h, bool) {
+ if (!h) {
+ return false;
+ } else if (h.is_none()) {
+ value = nullptr;
+ return true;
+ }
+
+ /* Check if this is a capsule */
+ if (isinstance<capsule>(h)) {
+ value = reinterpret_borrow<capsule>(h);
+ return true;
+ }
+
+ /* Check if this is a C++ type */
+ auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr());
+ if (bases.size() == 1) { // Only allowing loading from a single-value type
+ value = values_and_holders(reinterpret_cast<instance *>(h.ptr())).begin()->value_ptr();
+ return true;
+ }
+
+ /* Fail */
+ return false;
+ }
+
+ static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) {
+ if (ptr)
+ return capsule(ptr).release();
+ else
+ return none().inc_ref();
+ }
+
+ template <typename T> using cast_op_type = void*&;
+ operator void *&() { return value; }
+ static PYBIND11_DESCR name() { return type_descr(_("capsule")); }
+private:
+ void *value = nullptr;
+};
+
+template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_t> { };
+
+template <> class type_caster<bool> {
+public:
+ bool load(handle src, bool convert) {
+ if (!src) return false;
+ else if (src.ptr() == Py_True) { value = true; return true; }
+ else if (src.ptr() == Py_False) { value = false; return true; }
+ else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
+ // (allow non-implicit conversion for numpy booleans)
+
+ Py_ssize_t res = -1;
+ if (src.is_none()) {
+ res = 0; // None is implicitly converted to False
+ }
+ #if defined(PYPY_VERSION)
+ // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
+ else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
+ res = PyObject_IsTrue(src.ptr());
+ }
+ #else
+ // Alternate approach for CPython: this does the same as the above, but optimized
+ // using the CPython API so as to avoid an unneeded attribute lookup.
+ else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
+ if (PYBIND11_NB_BOOL(tp_as_number)) {
+ res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
+ }
+ }
+ #endif
+ if (res == 0 || res == 1) {
+ value = (bool) res;
+ return true;
+ }
+ }
+ return false;
+ }
+ static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
+ return handle(src ? Py_True : Py_False).inc_ref();
+ }
+ PYBIND11_TYPE_CASTER(bool, _("bool"));
+};
+
+// Helper class for UTF-{8,16,32} C++ stl strings:
+template <typename StringType, bool IsView = false> struct string_caster {
+ using CharT = typename StringType::value_type;
+
+ // Simplify life by being able to assume standard char sizes (the standard only guarantees
+ // minimums, but Python requires exact sizes)
+ static_assert(!std::is_same<CharT, char>::value || sizeof(CharT) == 1, "Unsupported char size != 1");
+ static_assert(!std::is_same<CharT, char16_t>::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2");
+ static_assert(!std::is_same<CharT, char32_t>::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4");
+ // wchar_t can be either 16 bits (Windows) or 32 (everywhere else)
+ static_assert(!std::is_same<CharT, wchar_t>::value || sizeof(CharT) == 2 || sizeof(CharT) == 4,
+ "Unsupported wchar_t size != 2/4");
+ static constexpr size_t UTF_N = 8 * sizeof(CharT);
+
+ bool load(handle src, bool) {
+#if PY_MAJOR_VERSION < 3
+ object temp;
+#endif
+ handle load_src = src;
+ if (!src) {
+ return false;
+ } else if (!PyUnicode_Check(load_src.ptr())) {
+#if PY_MAJOR_VERSION >= 3
+ return load_bytes(load_src);
+#else
+ if (sizeof(CharT) == 1) {
+ return load_bytes(load_src);
+ }
+
+ // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false
+ if (!PYBIND11_BYTES_CHECK(load_src.ptr()))
+ return false;
+
+ temp = reinterpret_steal<object>(PyUnicode_FromObject(load_src.ptr()));
+ if (!temp) { PyErr_Clear(); return false; }
+ load_src = temp;
+#endif
+ }
+
+ object utfNbytes = reinterpret_steal<object>(PyUnicode_AsEncodedString(
+ load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr));
+ if (!utfNbytes) { PyErr_Clear(); return false; }
+
+ const CharT *buffer = reinterpret_cast<const CharT *>(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr()));
+ size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT);
+ if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32
+ value = StringType(buffer, length);
+
+ // If we're loading a string_view we need to keep the encoded Python object alive:
+ if (IsView)
+ loader_life_support::add_patient(utfNbytes);
+
+ return true;
+ }
+
+ static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) {
+ const char *buffer = reinterpret_cast<const char *>(src.data());
+ ssize_t nbytes = ssize_t(src.size() * sizeof(CharT));
+ handle s = decode_utfN(buffer, nbytes);
+ if (!s) throw error_already_set();
+ return s;
+ }
+
+ PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME));
+
+private:
+ static handle decode_utfN(const char *buffer, ssize_t nbytes) {
+#if !defined(PYPY_VERSION)
+ return
+ UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) :
+ UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) :
+ PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr);
+#else
+ // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version
+ // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a
+ // non-const char * arguments, which is also a nuissance, so bypass the whole thing by just
+ // passing the encoding as a string value, which works properly:
+ return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr);
+#endif
+ }
+
+ // When loading into a std::string or char*, accept a bytes object as-is (i.e.
+ // without any encoding/decoding attempt). For other C++ char sizes this is a no-op.
+ // which supports loading a unicode from a str, doesn't take this path.
+ template <typename C = CharT>
+ bool load_bytes(enable_if_t<sizeof(C) == 1, handle> src) {
+ if (PYBIND11_BYTES_CHECK(src.ptr())) {
+ // We were passed a Python 3 raw bytes; accept it into a std::string or char*
+ // without any encoding attempt.
+ const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr());
+ if (bytes) {
+ value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr()));
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ template <typename C = CharT>
+ bool load_bytes(enable_if_t<sizeof(C) != 1, handle>) { return false; }
+};
+
+template <typename CharT, class Traits, class Allocator>
+struct type_caster<std::basic_string<CharT, Traits, Allocator>, enable_if_t<is_std_char_type<CharT>::value>>
+ : string_caster<std::basic_string<CharT, Traits, Allocator>> {};
+
+#ifdef PYBIND11_HAS_STRING_VIEW
+template <typename CharT, class Traits>
+struct type_caster<std::basic_string_view<CharT, Traits>, enable_if_t<is_std_char_type<CharT>::value>>
+ : string_caster<std::basic_string_view<CharT, Traits>, true> {};
+#endif
+
+// Type caster for C-style strings. We basically use a std::string type caster, but also add the
+// ability to use None as a nullptr char* (which the string caster doesn't allow).
+template <typename CharT> struct type_caster<CharT, enable_if_t<is_std_char_type<CharT>::value>> {
+ using StringType = std::basic_string<CharT>;
+ using StringCaster = type_caster<StringType>;
+ StringCaster str_caster;
+ bool none = false;
+ CharT one_char = 0;
+public:
+ bool load(handle src, bool convert) {
+ if (!src) return false;
+ if (src.is_none()) {
+ // Defer accepting None to other overloads (if we aren't in convert mode):
+ if (!convert) return false;
+ none = true;
+ return true;
+ }
+ return str_caster.load(src, convert);
+ }
+
+ static handle cast(const CharT *src, return_value_policy policy, handle parent) {
+ if (src == nullptr) return pybind11::none().inc_ref();
+ return StringCaster::cast(StringType(src), policy, parent);
+ }
+
+ static handle cast(CharT src, return_value_policy policy, handle parent) {
+ if (std::is_same<char, CharT>::value) {
+ handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr);
+ if (!s) throw error_already_set();
+ return s;
+ }
+ return StringCaster::cast(StringType(1, src), policy, parent);
+ }
+
+ operator CharT*() { return none ? nullptr : const_cast<CharT *>(static_cast<StringType &>(str_caster).c_str()); }
+ operator CharT&() {
+ if (none)
+ throw value_error("Cannot convert None to a character");
+
+ auto &value = static_cast<StringType &>(str_caster);
+ size_t str_len = value.size();
+ if (str_len == 0)
+ throw value_error("Cannot convert empty string to a character");
+
+ // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that
+ // is too high, and one for multiple unicode characters (caught later), so we need to figure
+ // out how long the first encoded character is in bytes to distinguish between these two
+ // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those
+ // can fit into a single char value.
+ if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) {
+ unsigned char v0 = static_cast<unsigned char>(value[0]);
+ size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127
+ (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence
+ (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence
+ 4; // 0b11110xxx - start of 4-byte sequence
+
+ if (char0_bytes == str_len) {
+ // If we have a 128-255 value, we can decode it into a single char:
+ if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx
+ one_char = static_cast<CharT>(((v0 & 3) << 6) + (static_cast<unsigned char>(value[1]) & 0x3F));
+ return one_char;
+ }
+ // Otherwise we have a single character, but it's > U+00FF
+ throw value_error("Character code point not in range(0x100)");
+ }
+ }
+
+ // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a
+ // surrogate pair with total length 2 instantly indicates a range error (but not a "your
+ // string was too long" error).
+ else if (StringCaster::UTF_N == 16 && str_len == 2) {
+ one_char = static_cast<CharT>(value[0]);
+ if (one_char >= 0xD800 && one_char < 0xE000)
+ throw value_error("Character code point not in range(0x10000)");
+ }
+
+ if (str_len != 1)
+ throw value_error("Expected a character, but multi-character string found");
+
+ one_char = value[0];
+ return one_char;
+ }
+
+ static PYBIND11_DESCR name() { return type_descr(_(PYBIND11_STRING_NAME)); }
+ template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+};
+
+// Base implementation for std::tuple and std::pair
+template <template<typename...> class Tuple, typename... Ts> class tuple_caster {
+ using type = Tuple<Ts...>;
+ static constexpr auto size = sizeof...(Ts);
+ using indices = make_index_sequence<size>;
+public:
+
+ bool load(handle src, bool convert) {
+ if (!isinstance<sequence>(src))
+ return false;
+ const auto seq = reinterpret_borrow<sequence>(src);
+ if (seq.size() != size)
+ return false;
+ return load_impl(seq, convert, indices{});
+ }
+
+ template <typename T>
+ static handle cast(T &&src, return_value_policy policy, handle parent) {
+ return cast_impl(std::forward<T>(src), policy, parent, indices{});
+ }
+
+ static PYBIND11_DESCR name() {
+ return type_descr(_("Tuple[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
+ }
+
+ template <typename T> using cast_op_type = type;
+
+ operator type() & { return implicit_cast(indices{}); }
+ operator type() && { return std::move(*this).implicit_cast(indices{}); }
+
+protected:
+ template <size_t... Is>
+ type implicit_cast(index_sequence<Is...>) & { return type(cast_op<Ts>(std::get<Is>(subcasters))...); }
+ template <size_t... Is>
+ type implicit_cast(index_sequence<Is...>) && { return type(cast_op<Ts>(std::move(std::get<Is>(subcasters)))...); }
+
+ static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; }
+
+ template <size_t... Is>
+ bool load_impl(const sequence &seq, bool convert, index_sequence<Is...>) {
+ for (bool r : {std::get<Is>(subcasters).load(seq[Is], convert)...})
+ if (!r)
+ return false;
+ return true;
+ }
+
+ /* Implementation: Convert a C++ tuple into a Python tuple */
+ template <typename T, size_t... Is>
+ static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence<Is...>) {
+ std::array<object, size> entries{{
+ reinterpret_steal<object>(make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), policy, parent))...
+ }};
+ for (const auto &entry: entries)
+ if (!entry)
+ return handle();
+ tuple result(size);
+ int counter = 0;
+ for (auto & entry: entries)
+ PyTuple_SET_ITEM(result.ptr(), counter++, entry.release().ptr());
+ return result.release();
+ }
+
+ Tuple<make_caster<Ts>...> subcasters;
+};
+
+template <typename T1, typename T2> class type_caster<std::pair<T1, T2>>
+ : public tuple_caster<std::pair, T1, T2> {};
+
+template <typename... Ts> class type_caster<std::tuple<Ts...>>
+ : public tuple_caster<std::tuple, Ts...> {};
+
+/// Helper class which abstracts away certain actions. Users can provide specializations for
+/// custom holders, but it's only necessary if the type has a non-standard interface.
+template <typename T>
+struct holder_helper {
+ static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
+};
+
+/// Type caster for holder types like std::shared_ptr, etc.
+template <typename type, typename holder_type>
+struct copyable_holder_caster : public type_caster_base<type> {
+public:
+ using base = type_caster_base<type>;
+ static_assert(std::is_base_of<base, type_caster<type>>::value,
+ "Holder classes are only supported for custom types");
+ using base::base;
+ using base::cast;
+ using base::typeinfo;
+ using base::value;
+
+ bool load(handle src, bool convert) {
+ return base::template load_impl<copyable_holder_caster<type, holder_type>>(src, convert);
+ }
+
+ explicit operator type*() { return this->value; }
+ explicit operator type&() { return *(this->value); }
+ explicit operator holder_type*() { return &holder; }
+
+ // Workaround for Intel compiler bug
+ // see pybind11 issue 94
+ #if defined(__ICC) || defined(__INTEL_COMPILER)
+ operator holder_type&() { return holder; }
+ #else
+ explicit operator holder_type&() { return holder; }
+ #endif
+
+ static handle cast(const holder_type &src, return_value_policy, handle) {
+ const auto *ptr = holder_helper<holder_type>::get(src);
+ return type_caster_base<type>::cast_holder(ptr, &src);
+ }
+
+protected:
+ friend class type_caster_generic;
+ void check_holder_compat() {
+ if (typeinfo->default_holder)
+ throw cast_error("Unable to load a custom holder type from a default-holder instance");
+ }
+
+ bool load_value(value_and_holder &&v_h) {
+ if (v_h.holder_constructed()) {
+ value = v_h.value_ptr();
+ holder = v_h.template holder<holder_type>();
+ return true;
+ } else {
+ throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
+#if defined(NDEBUG)
+ "(compile in debug mode for type information)");
+#else
+ "of type '" + type_id<holder_type>() + "''");
+#endif
+ }
+ }
+
+ template <typename T = holder_type, detail::enable_if_t<!std::is_constructible<T, const T &, type*>::value, int> = 0>
+ bool try_implicit_casts(handle, bool) { return false; }
+
+ template <typename T = holder_type, detail::enable_if_t<std::is_constructible<T, const T &, type*>::value, int> = 0>
+ bool try_implicit_casts(handle src, bool convert) {
+ for (auto &cast : typeinfo->implicit_casts) {
+ copyable_holder_caster sub_caster(*cast.first);
+ if (sub_caster.load(src, convert)) {
+ value = cast.second(sub_caster.value);
+ holder = holder_type(sub_caster.holder, (type *) value);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ static bool try_direct_conversions(handle) { return false; }
+
+
+ holder_type holder;
+};
+
+/// Specialize for the common std::shared_ptr, so users don't need to
+template <typename T>
+class type_caster<std::shared_ptr<T>> : public copyable_holder_caster<T, std::shared_ptr<T>> { };
+
+template <typename type, typename holder_type>
+struct move_only_holder_caster {
+ static_assert(std::is_base_of<type_caster_base<type>, type_caster<type>>::value,
+ "Holder classes are only supported for custom types");
+
+ static handle cast(holder_type &&src, return_value_policy, handle) {
+ auto *ptr = holder_helper<holder_type>::get(src);
+ return type_caster_base<type>::cast_holder(ptr, &src);
+ }
+ static PYBIND11_DESCR name() { return type_caster_base<type>::name(); }
+};
+
+template <typename type, typename deleter>
+class type_caster<std::unique_ptr<type, deleter>>
+ : public move_only_holder_caster<type, std::unique_ptr<type, deleter>> { };
+
+template <typename type, typename holder_type>
+using type_caster_holder = conditional_t<is_copy_constructible<holder_type>::value,
+ copyable_holder_caster<type, holder_type>,
+ move_only_holder_caster<type, holder_type>>;
+
+template <typename T, bool Value = false> struct always_construct_holder { static constexpr bool value = Value; };
+
+/// Create a specialization for custom holder types (silently ignores std::shared_ptr)
+#define PYBIND11_DECLARE_HOLDER_TYPE(type, holder_type, ...) \
+ namespace pybind11 { namespace detail { \
+ template <typename type> \
+ struct always_construct_holder<holder_type> : always_construct_holder<void, ##__VA_ARGS__> { }; \
+ template <typename type> \
+ class type_caster<holder_type, enable_if_t<!is_shared_ptr<holder_type>::value>> \
+ : public type_caster_holder<type, holder_type> { }; \
+ }}
+
+// PYBIND11_DECLARE_HOLDER_TYPE holder types:
+template <typename base, typename holder> struct is_holder_type :
+ std::is_base_of<detail::type_caster_holder<base, holder>, detail::type_caster<holder>> {};
+// Specialization for always-supported unique_ptr holders:
+template <typename base, typename deleter> struct is_holder_type<base, std::unique_ptr<base, deleter>> :
+ std::true_type {};
+
+template <typename T> struct handle_type_name { static PYBIND11_DESCR name() { return _<T>(); } };
+template <> struct handle_type_name<bytes> { static PYBIND11_DESCR name() { return _(PYBIND11_BYTES_NAME); } };
+template <> struct handle_type_name<args> { static PYBIND11_DESCR name() { return _("*args"); } };
+template <> struct handle_type_name<kwargs> { static PYBIND11_DESCR name() { return _("**kwargs"); } };
+
+template <typename type>
+struct pyobject_caster {
+ template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
+ bool load(handle src, bool /* convert */) { value = src; return static_cast<bool>(value); }
+
+ template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
+ bool load(handle src, bool /* convert */) {
+ if (!isinstance<type>(src))
+ return false;
+ value = reinterpret_borrow<type>(src);
+ return true;
+ }
+
+ static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
+ return src.inc_ref();
+ }
+ PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
+};
+
+template <typename T>
+class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caster<T> { };
+
+// Our conditions for enabling moving are quite restrictive:
+// At compile time:
+// - T needs to be a non-const, non-pointer, non-reference type
+// - type_caster<T>::operator T&() must exist
+// - the type must be move constructible (obviously)
+// At run-time:
+// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it
+// must have ref_count() == 1)h
+// If any of the above are not satisfied, we fall back to copying.
+template <typename T> using move_is_plain_type = satisfies_none_of<T,
+ std::is_void, std::is_pointer, std::is_reference, std::is_const
+>;
+template <typename T, typename SFINAE = void> struct move_always : std::false_type {};
+template <typename T> struct move_always<T, enable_if_t<all_of<
+ move_is_plain_type<T>,
+ negation<is_copy_constructible<T>>,
+ std::is_move_constructible<T>,
+ std::is_same<decltype(std::declval<make_caster<T>>().operator T&()), T&>
+>::value>> : std::true_type {};
+template <typename T, typename SFINAE = void> struct move_if_unreferenced : std::false_type {};
+template <typename T> struct move_if_unreferenced<T, enable_if_t<all_of<
+ move_is_plain_type<T>,
+ negation<move_always<T>>,
+ std::is_move_constructible<T>,
+ std::is_same<decltype(std::declval<make_caster<T>>().operator T&()), T&>
+>::value>> : std::true_type {};
+template <typename T> using move_never = none_of<move_always<T>, move_if_unreferenced<T>>;
+
+// Detect whether returning a `type` from a cast on type's type_caster is going to result in a
+// reference or pointer to a local variable of the type_caster. Basically, only
+// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe;
+// everything else returns a reference/pointer to a local variable.
+template <typename type> using cast_is_temporary_value_reference = bool_constant<
+ (std::is_reference<type>::value || std::is_pointer<type>::value) &&
+ !std::is_base_of<type_caster_generic, make_caster<type>>::value
+>;
+
+// When a value returned from a C++ function is being cast back to Python, we almost always want to
+// force `policy = move`, regardless of the return value policy the function/method was declared
+// with. Some classes (most notably Eigen::Ref and related) need to avoid this, and so can do so by
+// specializing this struct.
+template <typename Return, typename SFINAE = void> struct return_value_policy_override {
+ static return_value_policy policy(return_value_policy p) {
+ return !std::is_lvalue_reference<Return>::value && !std::is_pointer<Return>::value
+ ? return_value_policy::move : p;
+ }
+};
+
+// Basic python -> C++ casting; throws if casting fails
+template <typename T, typename SFINAE> type_caster<T, SFINAE> &load_type(type_caster<T, SFINAE> &conv, const handle &handle) {
+ if (!conv.load(handle, true)) {
+#if defined(NDEBUG)
+ throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
+#else
+ throw cast_error("Unable to cast Python instance of type " +
+ (std::string) str(handle.get_type()) + " to C++ type '" + type_id<T>() + "'");
+#endif
+ }
+ return conv;
+}
+// Wrapper around the above that also constructs and returns a type_caster
+template <typename T> make_caster<T> load_type(const handle &handle) {
+ make_caster<T> conv;
+ load_type(conv, handle);
+ return conv;
+}
+
+NAMESPACE_END(detail)
+
+// pytype -> C++ type
+template <typename T, detail::enable_if_t<!detail::is_pyobject<T>::value, int> = 0>
+T cast(const handle &handle) {
+ using namespace detail;
+ static_assert(!cast_is_temporary_value_reference<T>::value,
+ "Unable to cast type to reference: value is local to type caster");
+ return cast_op<T>(load_type<T>(handle));
+}
+
+// pytype -> pytype (calls converting constructor)
+template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
+T cast(const handle &handle) { return T(reinterpret_borrow<object>(handle)); }
+
+// C++ type -> py::object
+template <typename T, detail::enable_if_t<!detail::is_pyobject<T>::value, int> = 0>
+object cast(const T &value, return_value_policy policy = return_value_policy::automatic_reference,
+ handle parent = handle()) {
+ if (policy == return_value_policy::automatic)
+ policy = std::is_pointer<T>::value ? return_value_policy::take_ownership : return_value_policy::copy;
+ else if (policy == return_value_policy::automatic_reference)
+ policy = std::is_pointer<T>::value ? return_value_policy::reference : return_value_policy::copy;
+ return reinterpret_steal<object>(detail::make_caster<T>::cast(value, policy, parent));
+}
+
+template <typename T> T handle::cast() const { return pybind11::cast<T>(*this); }
+template <> inline void handle::cast() const { return; }
+
+template <typename T>
+detail::enable_if_t<!detail::move_never<T>::value, T> move(object &&obj) {
+ if (obj.ref_count() > 1)
+#if defined(NDEBUG)
+ throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references"
+ " (compile in debug mode for details)");
+#else
+ throw cast_error("Unable to move from Python " + (std::string) str(obj.get_type()) +
+ " instance to C++ " + type_id<T>() + " instance: instance has multiple references");
+#endif
+
+ // Move into a temporary and return that, because the reference may be a local value of `conv`
+ T ret = std::move(detail::load_type<T>(obj).operator T&());
+ return ret;
+}
+
+// Calling cast() on an rvalue calls pybind::cast with the object rvalue, which does:
+// - If we have to move (because T has no copy constructor), do it. This will fail if the moved
+// object has multiple references, but trying to copy will fail to compile.
+// - If both movable and copyable, check ref count: if 1, move; otherwise copy
+// - Otherwise (not movable), copy.
+template <typename T> detail::enable_if_t<detail::move_always<T>::value, T> cast(object &&object) {
+ return move<T>(std::move(object));
+}
+template <typename T> detail::enable_if_t<detail::move_if_unreferenced<T>::value, T> cast(object &&object) {
+ if (object.ref_count() > 1)
+ return cast<T>(object);
+ else
+ return move<T>(std::move(object));
+}
+template <typename T> detail::enable_if_t<detail::move_never<T>::value, T> cast(object &&object) {
+ return cast<T>(object);
+}
+
+template <typename T> T object::cast() const & { return pybind11::cast<T>(*this); }
+template <typename T> T object::cast() && { return pybind11::cast<T>(std::move(*this)); }
+template <> inline void object::cast() const & { return; }
+template <> inline void object::cast() && { return; }
+
+NAMESPACE_BEGIN(detail)
+
+// Declared in pytypes.h:
+template <typename T, enable_if_t<!is_pyobject<T>::value, int>>
+object object_or_cast(T &&o) { return pybind11::cast(std::forward<T>(o)); }
+
+struct overload_unused {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro
+template <typename ret_type> using overload_caster_t = conditional_t<
+ cast_is_temporary_value_reference<ret_type>::value, make_caster<ret_type>, overload_unused>;
+
+// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then
+// store the result in the given variable. For other types, this is a no-op.
+template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&o, make_caster<T> &caster) {
+ return cast_op<T>(load_type(caster, o));
+}
+template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&, overload_unused &) {
+ pybind11_fail("Internal error: cast_ref fallback invoked"); }
+
+// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even
+// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in
+// cases where pybind11::cast is valid.
+template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&o) {
+ return pybind11::cast<T>(std::move(o)); }
+template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&) {
+ pybind11_fail("Internal error: cast_safe fallback invoked"); }
+template <> inline void cast_safe<void>(object &&) {}
+
+NAMESPACE_END(detail)
+
+template <return_value_policy policy = return_value_policy::automatic_reference>
+tuple make_tuple() { return tuple(0); }
+
+template <return_value_policy policy = return_value_policy::automatic_reference,
+ typename... Args> tuple make_tuple(Args&&... args_) {
+ constexpr size_t size = sizeof...(Args);
+ std::array<object, size> args {
+ { reinterpret_steal<object>(detail::make_caster<Args>::cast(
+ std::forward<Args>(args_), policy, nullptr))... }
+ };
+ for (size_t i = 0; i < args.size(); i++) {
+ if (!args[i]) {
+#if defined(NDEBUG)
+ throw cast_error("make_tuple(): unable to convert arguments to Python object (compile in debug mode for details)");
+#else
+ std::array<std::string, size> argtypes { {type_id<Args>()...} };
+ throw cast_error("make_tuple(): unable to convert argument of type '" +
+ argtypes[i] + "' to Python object");
+#endif
+ }
+ }
+ tuple result(size);
+ int counter = 0;
+ for (auto &arg_value : args)
+ PyTuple_SET_ITEM(result.ptr(), counter++, arg_value.release().ptr());
+ return result;
+}
+
+/// \ingroup annotations
+/// Annotation for arguments
+struct arg {
+ /// Constructs an argument with the name of the argument; if null or omitted, this is a positional argument.
+ constexpr explicit arg(const char *name = nullptr) : name(name), flag_noconvert(false), flag_none(true) { }
+ /// Assign a value to this argument
+ template <typename T> arg_v operator=(T &&value) const;
+ /// Indicate that the type should not be converted in the type caster
+ arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; }
+ /// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args)
+ arg &none(bool flag = true) { flag_none = flag; return *this; }
+
+ const char *name; ///< If non-null, this is a named kwargs argument
+ bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!)
+ bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument
+};
+
+/// \ingroup annotations
+/// Annotation for arguments with values
+struct arg_v : arg {
+private:
+ template <typename T>
+ arg_v(arg &&base, T &&x, const char *descr = nullptr)
+ : arg(base),
+ value(reinterpret_steal<object>(
+ detail::make_caster<T>::cast(x, return_value_policy::automatic, {})
+ )),
+ descr(descr)
+#if !defined(NDEBUG)
+ , type(type_id<T>())
+#endif
+ { }
+
+public:
+ /// Direct construction with name, default, and description
+ template <typename T>
+ arg_v(const char *name, T &&x, const char *descr = nullptr)
+ : arg_v(arg(name), std::forward<T>(x), descr) { }
+
+ /// Called internally when invoking `py::arg("a") = value`
+ template <typename T>
+ arg_v(const arg &base, T &&x, const char *descr = nullptr)
+ : arg_v(arg(base), std::forward<T>(x), descr) { }
+
+ /// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg&
+ arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; }
+
+ /// Same as `arg::nonone()`, but returns *this as arg_v&, not arg&
+ arg_v &none(bool flag = true) { arg::none(flag); return *this; }
+
+ /// The default value
+ object value;
+ /// The (optional) description of the default value
+ const char *descr;
+#if !defined(NDEBUG)
+ /// The C++ type name of the default value (only available when compiled in debug mode)
+ std::string type;
+#endif
+};
+
+template <typename T>
+arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward<T>(value)}; }
+
+/// Alias for backward compatibility -- to be removed in version 2.0
+template <typename /*unused*/> using arg_t = arg_v;
+
+inline namespace literals {
+/** \rst
+ String literal version of `arg`
+ \endrst */
+constexpr arg operator"" _a(const char *name, size_t) { return arg(name); }
+}
+
+NAMESPACE_BEGIN(detail)
+
+// forward declaration (definition in attr.h)
+struct function_record;
+
+/// Internal data associated with a single function call
+struct function_call {
+ function_call(function_record &f, handle p); // Implementation in attr.h
+
+ /// The function data:
+ const function_record &func;
+
+ /// Arguments passed to the function:
+ std::vector<handle> args;
+
+ /// The `convert` value the arguments should be loaded with
+ std::vector<bool> args_convert;
+
+ /// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if
+ /// present, are also in `args` but without a reference).
+ object args_ref, kwargs_ref;
+
+ /// The parent, if any
+ handle parent;
+
+ /// If this is a call to an initializer, this argument contains `self`
+ handle init_self;
+};
+
+
+/// Helper class which loads arguments for C++ functions called from Python
+template <typename... Args>
+class argument_loader {
+ using indices = make_index_sequence<sizeof...(Args)>;
+
+ template <typename Arg> using argument_is_args = std::is_same<intrinsic_t<Arg>, args>;
+ template <typename Arg> using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>;
+ // Get args/kwargs argument positions relative to the end of the argument list:
+ static constexpr auto args_pos = constexpr_first<argument_is_args, Args...>() - (int) sizeof...(Args),
+ kwargs_pos = constexpr_first<argument_is_kwargs, Args...>() - (int) sizeof...(Args);
+
+ static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1;
+
+ static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function");
+
+public:
+ static constexpr bool has_kwargs = kwargs_pos < 0;
+ static constexpr bool has_args = args_pos < 0;
+
+ static PYBIND11_DESCR arg_names() { return detail::concat(make_caster<Args>::name()...); }
+
+ bool load_args(function_call &call) {
+ return load_impl_sequence(call, indices{});
+ }
+
+ template <typename Return, typename Guard, typename Func>
+ enable_if_t<!std::is_void<Return>::value, Return> call(Func &&f) && {
+ return std::move(*this).template call_impl<Return>(std::forward<Func>(f), indices{}, Guard{});
+ }
+
+ template <typename Return, typename Guard, typename Func>
+ enable_if_t<std::is_void<Return>::value, void_type> call(Func &&f) && {
+ std::move(*this).template call_impl<Return>(std::forward<Func>(f), indices{}, Guard{});
+ return void_type();
+ }
+
+private:
+
+ static bool load_impl_sequence(function_call &, index_sequence<>) { return true; }
+
+ template <size_t... Is>
+ bool load_impl_sequence(function_call &call, index_sequence<Is...>) {
+ for (bool r : {std::get<Is>(argcasters).load(call.args[Is], call.args_convert[Is])...})
+ if (!r)
+ return false;
+ return true;
+ }
+
+ template <typename Return, typename Func, size_t... Is, typename Guard>
+ Return call_impl(Func &&f, index_sequence<Is...>, Guard &&) {
+ return std::forward<Func>(f)(cast_op<Args>(std::move(std::get<Is>(argcasters)))...);
+ }
+
+ std::tuple<make_caster<Args>...> argcasters;
+};
+
+/// Helper class which collects only positional arguments for a Python function call.
+/// A fancier version below can collect any argument, but this one is optimal for simple calls.
+template <return_value_policy policy>
+class simple_collector {
+public:
+ template <typename... Ts>
+ explicit simple_collector(Ts &&...values)
+ : m_args(pybind11::make_tuple<policy>(std::forward<Ts>(values)...)) { }
+
+ const tuple &args() const & { return m_args; }
+ dict kwargs() const { return {}; }
+
+ tuple args() && { return std::move(m_args); }
+
+ /// Call a Python function and pass the collected arguments
+ object call(PyObject *ptr) const {
+ PyObject *result = PyObject_CallObject(ptr, m_args.ptr());
+ if (!result)
+ throw error_already_set();
+ return reinterpret_steal<object>(result);
+ }
+
+private:
+ tuple m_args;
+};
+
+/// Helper class which collects positional, keyword, * and ** arguments for a Python function call
+template <return_value_policy policy>
+class unpacking_collector {
+public:
+ template <typename... Ts>
+ explicit unpacking_collector(Ts &&...values) {
+ // Tuples aren't (easily) resizable so a list is needed for collection,
+ // but the actual function call strictly requires a tuple.
+ auto args_list = list();
+ int _[] = { 0, (process(args_list, std::forward<Ts>(values)), 0)... };
+ ignore_unused(_);
+
+ m_args = std::move(args_list);
+ }
+
+ const tuple &args() const & { return m_args; }
+ const dict &kwargs() const & { return m_kwargs; }
+
+ tuple args() && { return std::move(m_args); }
+ dict kwargs() && { return std::move(m_kwargs); }
+
+ /// Call a Python function and pass the collected arguments
+ object call(PyObject *ptr) const {
+ PyObject *result = PyObject_Call(ptr, m_args.ptr(), m_kwargs.ptr());
+ if (!result)
+ throw error_already_set();
+ return reinterpret_steal<object>(result);
+ }
+
+private:
+ template <typename T>
+ void process(list &args_list, T &&x) {
+ auto o = reinterpret_steal<object>(detail::make_caster<T>::cast(std::forward<T>(x), policy, {}));
+ if (!o) {
+#if defined(NDEBUG)
+ argument_cast_error();
+#else
+ argument_cast_error(std::to_string(args_list.size()), type_id<T>());
+#endif
+ }
+ args_list.append(o);
+ }
+
+ void process(list &args_list, detail::args_proxy ap) {
+ for (const auto &a : ap)
+ args_list.append(a);
+ }
+
+ void process(list &/*args_list*/, arg_v a) {
+ if (!a.name)
+#if defined(NDEBUG)
+ nameless_argument_error();
+#else
+ nameless_argument_error(a.type);
+#endif
+
+ if (m_kwargs.contains(a.name)) {
+#if defined(NDEBUG)
+ multiple_values_error();
+#else
+ multiple_values_error(a.name);
+#endif
+ }
+ if (!a.value) {
+#if defined(NDEBUG)
+ argument_cast_error();
+#else
+ argument_cast_error(a.name, a.type);
+#endif
+ }
+ m_kwargs[a.name] = a.value;
+ }
+
+ void process(list &/*args_list*/, detail::kwargs_proxy kp) {
+ if (!kp)
+ return;
+ for (const auto &k : reinterpret_borrow<dict>(kp)) {
+ if (m_kwargs.contains(k.first)) {
+#if defined(NDEBUG)
+ multiple_values_error();
+#else
+ multiple_values_error(str(k.first));
+#endif
+ }
+ m_kwargs[k.first] = k.second;
+ }
+ }
+
+ [[noreturn]] static void nameless_argument_error() {
+ throw type_error("Got kwargs without a name; only named arguments "
+ "may be passed via py::arg() to a python function call. "
+ "(compile in debug mode for details)");
+ }
+ [[noreturn]] static void nameless_argument_error(std::string type) {
+ throw type_error("Got kwargs without a name of type '" + type + "'; only named "
+ "arguments may be passed via py::arg() to a python function call. ");
+ }
+ [[noreturn]] static void multiple_values_error() {
+ throw type_error("Got multiple values for keyword argument "
+ "(compile in debug mode for details)");
+ }
+
+ [[noreturn]] static void multiple_values_error(std::string name) {
+ throw type_error("Got multiple values for keyword argument '" + name + "'");
+ }
+
+ [[noreturn]] static void argument_cast_error() {
+ throw cast_error("Unable to convert call argument to Python object "
+ "(compile in debug mode for details)");
+ }
+
+ [[noreturn]] static void argument_cast_error(std::string name, std::string type) {
+ throw cast_error("Unable to convert call argument '" + name
+ + "' of type '" + type + "' to Python object");
+ }
+
+private:
+ tuple m_args;
+ dict m_kwargs;
+};
+
+/// Collect only positional arguments for a Python function call
+template <return_value_policy policy, typename... Args,
+ typename = enable_if_t<all_of<is_positional<Args>...>::value>>
+simple_collector<policy> collect_arguments(Args &&...args) {
+ return simple_collector<policy>(std::forward<Args>(args)...);
+}
+
+/// Collect all arguments, including keywords and unpacking (only instantiated when needed)
+template <return_value_policy policy, typename... Args,
+ typename = enable_if_t<!all_of<is_positional<Args>...>::value>>
+unpacking_collector<policy> collect_arguments(Args &&...args) {
+ // Following argument order rules for generalized unpacking according to PEP 448
+ static_assert(
+ constexpr_last<is_positional, Args...>() < constexpr_first<is_keyword_or_ds, Args...>()
+ && constexpr_last<is_s_unpacking, Args...>() < constexpr_first<is_ds_unpacking, Args...>(),
+ "Invalid function call: positional args must precede keywords and ** unpacking; "
+ "* unpacking must precede ** unpacking"
+ );
+ return unpacking_collector<policy>(std::forward<Args>(args)...);
+}
+
+template <typename Derived>
+template <return_value_policy policy, typename... Args>
+object object_api<Derived>::operator()(Args &&...args) const {
+ return detail::collect_arguments<policy>(std::forward<Args>(args)...).call(derived().ptr());
+}
+
+template <typename Derived>
+template <return_value_policy policy, typename... Args>
+object object_api<Derived>::call(Args &&...args) const {
+ return operator()<policy>(std::forward<Args>(args)...);
+}
+
+NAMESPACE_END(detail)
+
+#define PYBIND11_MAKE_OPAQUE(Type) \
+ namespace pybind11 { namespace detail { \
+ template<> class type_caster<Type> : public type_caster_base<Type> { }; \
+ }}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h b/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h
new file mode 100644
index 000000000..95ada76e0
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h
@@ -0,0 +1,162 @@
+/*
+ pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
+
+ Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
+ Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <cmath>
+#include <ctime>
+#include <chrono>
+#include <datetime.h>
+
+// Backport the PyDateTime_DELTA functions from Python3.3 if required
+#ifndef PyDateTime_DELTA_GET_DAYS
+#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
+#endif
+#ifndef PyDateTime_DELTA_GET_SECONDS
+#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
+#endif
+#ifndef PyDateTime_DELTA_GET_MICROSECONDS
+#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+template <typename type> class duration_caster {
+public:
+ typedef typename type::rep rep;
+ typedef typename type::period period;
+
+ typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
+
+ bool load(handle src, bool) {
+ using namespace std::chrono;
+
+ // Lazy initialise the PyDateTime import
+ if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+ if (!src) return false;
+ // If invoked with datetime.delta object
+ if (PyDelta_Check(src.ptr())) {
+ value = type(duration_cast<duration<rep, period>>(
+ days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+ + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+ + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
+ return true;
+ }
+ // If invoked with a float we assume it is seconds and convert
+ else if (PyFloat_Check(src.ptr())) {
+ value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
+ return true;
+ }
+ else return false;
+ }
+
+ // If this is a duration just return it back
+ static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
+ return src;
+ }
+
+ // If this is a time_point get the time_since_epoch
+ template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
+ return src.time_since_epoch();
+ }
+
+ static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
+ using namespace std::chrono;
+
+ // Use overloaded function to get our duration from our source
+ // Works out if it is a duration or time_point and get the duration
+ auto d = get_duration(src);
+
+ // Lazy initialise the PyDateTime import
+ if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+ // Declare these special duration types so the conversions happen with the correct primitive types (int)
+ using dd_t = duration<int, std::ratio<86400>>;
+ using ss_t = duration<int, std::ratio<1>>;
+ using us_t = duration<int, std::micro>;
+
+ auto dd = duration_cast<dd_t>(d);
+ auto subd = d - dd;
+ auto ss = duration_cast<ss_t>(subd);
+ auto us = duration_cast<us_t>(subd - ss);
+ return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
+ }
+
+ PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
+};
+
+// This is for casting times on the system clock into datetime.datetime instances
+template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
+public:
+ typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
+ bool load(handle src, bool) {
+ using namespace std::chrono;
+
+ // Lazy initialise the PyDateTime import
+ if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+ if (!src) return false;
+ if (PyDateTime_Check(src.ptr())) {
+ std::tm cal;
+ cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
+ cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
+ cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
+ cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
+ cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
+ cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
+ cal.tm_isdst = -1;
+
+ value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
+ return true;
+ }
+ else return false;
+ }
+
+ static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
+ using namespace std::chrono;
+
+ // Lazy initialise the PyDateTime import
+ if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+ std::time_t tt = system_clock::to_time_t(src);
+ // this function uses static memory so it's best to copy it out asap just in case
+ // otherwise other code that is using localtime may break this (not just python code)
+ std::tm localtime = *std::localtime(&tt);
+
+ // Declare these special duration types so the conversions happen with the correct primitive types (int)
+ using us_t = duration<int, std::micro>;
+
+ return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
+ localtime.tm_mon + 1,
+ localtime.tm_mday,
+ localtime.tm_hour,
+ localtime.tm_min,
+ localtime.tm_sec,
+ (duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
+ }
+ PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
+};
+
+// Other clocks that are not the system clock are not measured as datetime.datetime objects
+// since they are not measured on calendar time. So instead we just make them timedeltas
+// Or if they have passed us a time as a float we convert that
+template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
+: public duration_caster<std::chrono::time_point<Clock, Duration>> {
+};
+
+template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
+: public duration_caster<std::chrono::duration<Rep, Period>> {
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/common.h b/ml/dlib/dlib/external/pybind11/include/pybind11/common.h
new file mode 100644
index 000000000..6c8a4f1e8
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/common.h
@@ -0,0 +1,2 @@
+#include "detail/common.h"
+#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h b/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h
new file mode 100644
index 000000000..5dac27cc4
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h
@@ -0,0 +1,61 @@
+/*
+ pybind11/complex.h: Complex number support
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <complex>
+
+/// glibc defines I as a macro which breaks things, e.g., boost template names
+#ifdef I
+# undef I
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
+ static constexpr const char c = format_descriptor<T>::c;
+ static constexpr const char value[3] = { 'Z', c, '\0' };
+ static std::string format() { return std::string(value); }
+};
+
+template <typename T> constexpr const char format_descriptor<
+ std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
+
+NAMESPACE_BEGIN(detail)
+
+template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
+ static constexpr bool value = true;
+ static constexpr int index = is_fmt_numeric<T>::index + 3;
+};
+
+template <typename T> class type_caster<std::complex<T>> {
+public:
+ bool load(handle src, bool convert) {
+ if (!src)
+ return false;
+ if (!convert && !PyComplex_Check(src.ptr()))
+ return false;
+ Py_complex result = PyComplex_AsCComplex(src.ptr());
+ if (result.real == -1.0 && PyErr_Occurred()) {
+ PyErr_Clear();
+ return false;
+ }
+ value = std::complex<T>((T) result.real, (T) result.imag);
+ return true;
+ }
+
+ static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
+ return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
+ }
+
+ PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
+};
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h
new file mode 100644
index 000000000..ff06370fa
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h
@@ -0,0 +1,626 @@
+/*
+ pybind11/detail/class.h: Python C API implementation details for py::class_
+
+ Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "../attr.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+#if PY_VERSION_HEX >= 0x03030000
+# define PYBIND11_BUILTIN_QUALNAME
+# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
+#else
+// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type
+// signatures; in 3.3+ this macro expands to nothing:
+# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj)
+#endif
+
+inline PyTypeObject *type_incref(PyTypeObject *type) {
+ Py_INCREF(type);
+ return type;
+}
+
+#if !defined(PYPY_VERSION)
+
+/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
+extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
+ return PyProperty_Type.tp_descr_get(self, cls, cls);
+}
+
+/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
+extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
+ PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
+ return PyProperty_Type.tp_descr_set(self, cls, value);
+}
+
+/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
+ methods are modified to always use the object type instead of a concrete instance.
+ Return value: New reference. */
+inline PyTypeObject *make_static_property_type() {
+ constexpr auto *name = "pybind11_static_property";
+ auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
+
+ /* Danger zone: from now (and until PyType_Ready), make sure to
+ issue no Python C API calls which could potentially invoke the
+ garbage collector (the GC will call type_traverse(), which will in
+ turn find the newly constructed type in an invalid state) */
+ auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
+ if (!heap_type)
+ pybind11_fail("make_static_property_type(): error allocating type!");
+
+ heap_type->ht_name = name_obj.inc_ref().ptr();
+#ifdef PYBIND11_BUILTIN_QUALNAME
+ heap_type->ht_qualname = name_obj.inc_ref().ptr();
+#endif
+
+ auto type = &heap_type->ht_type;
+ type->tp_name = name;
+ type->tp_base = type_incref(&PyProperty_Type);
+ type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+ type->tp_descr_get = pybind11_static_get;
+ type->tp_descr_set = pybind11_static_set;
+
+ if (PyType_Ready(type) < 0)
+ pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
+
+ setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
+ PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
+
+ return type;
+}
+
+#else // PYPY
+
+/** PyPy has some issues with the above C API, so we evaluate Python code instead.
+ This function will only be called once so performance isn't really a concern.
+ Return value: New reference. */
+inline PyTypeObject *make_static_property_type() {
+ auto d = dict();
+ PyObject *result = PyRun_String(R"(\
+ class pybind11_static_property(property):
+ def __get__(self, obj, cls):
+ return property.__get__(self, cls, cls)
+
+ def __set__(self, obj, value):
+ cls = obj if isinstance(obj, type) else type(obj)
+ property.__set__(self, cls, value)
+ )", Py_file_input, d.ptr(), d.ptr()
+ );
+ if (result == nullptr)
+ throw error_already_set();
+ Py_DECREF(result);
+ return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
+}
+
+#endif // PYPY
+
+/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
+ By default, Python replaces the `static_property` itself, but for wrapped C++ types
+ we need to call `static_property.__set__()` in order to propagate the new value to
+ the underlying C++ data structure. */
+extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
+ // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
+ // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
+ PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
+
+ // The following assignment combinations are possible:
+ // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
+ // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
+ // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
+ const auto static_prop = (PyObject *) get_internals().static_property_type;
+ const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
+ && !PyObject_IsInstance(value, static_prop);
+ if (call_descr_set) {
+ // Call `static_property.__set__()` instead of replacing the `static_property`.
+#if !defined(PYPY_VERSION)
+ return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
+#else
+ if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
+ Py_DECREF(result);
+ return 0;
+ } else {
+ return -1;
+ }
+#endif
+ } else {
+ // Replace existing attribute.
+ return PyType_Type.tp_setattro(obj, name, value);
+ }
+}
+
+#if PY_MAJOR_VERSION >= 3
+/**
+ * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
+ * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
+ * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
+ * to do a special case bypass for PyInstanceMethod_Types.
+ */
+extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
+ PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
+ if (descr && PyInstanceMethod_Check(descr)) {
+ Py_INCREF(descr);
+ return descr;
+ }
+ else {
+ return PyType_Type.tp_getattro(obj, name);
+ }
+}
+#endif
+
+/** This metaclass is assigned by default to all pybind11 types and is required in order
+ for static properties to function correctly. Users may override this using `py::metaclass`.
+ Return value: New reference. */
+inline PyTypeObject* make_default_metaclass() {
+ constexpr auto *name = "pybind11_type";
+ auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
+
+ /* Danger zone: from now (and until PyType_Ready), make sure to
+ issue no Python C API calls which could potentially invoke the
+ garbage collector (the GC will call type_traverse(), which will in
+ turn find the newly constructed type in an invalid state) */
+ auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
+ if (!heap_type)
+ pybind11_fail("make_default_metaclass(): error allocating metaclass!");
+
+ heap_type->ht_name = name_obj.inc_ref().ptr();
+#ifdef PYBIND11_BUILTIN_QUALNAME
+ heap_type->ht_qualname = name_obj.inc_ref().ptr();
+#endif
+
+ auto type = &heap_type->ht_type;
+ type->tp_name = name;
+ type->tp_base = type_incref(&PyType_Type);
+ type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+
+ type->tp_setattro = pybind11_meta_setattro;
+#if PY_MAJOR_VERSION >= 3
+ type->tp_getattro = pybind11_meta_getattro;
+#endif
+
+ if (PyType_Ready(type) < 0)
+ pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
+
+ setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
+ PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
+
+ return type;
+}
+
+/// For multiple inheritance types we need to recursively register/deregister base pointers for any
+/// base classes with pointers that are difference from the instance value pointer so that we can
+/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
+inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
+ bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
+ for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
+ if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
+ for (auto &c : parent_tinfo->implicit_casts) {
+ if (c.first == tinfo->cpptype) {
+ auto *parentptr = c.second(valueptr);
+ if (parentptr != valueptr)
+ f(parentptr, self);
+ traverse_offset_bases(parentptr, parent_tinfo, self, f);
+ break;
+ }
+ }
+ }
+ }
+}
+
+inline bool register_instance_impl(void *ptr, instance *self) {
+ get_internals().registered_instances.emplace(ptr, self);
+ return true; // unused, but gives the same signature as the deregister func
+}
+inline bool deregister_instance_impl(void *ptr, instance *self) {
+ auto &registered_instances = get_internals().registered_instances;
+ auto range = registered_instances.equal_range(ptr);
+ for (auto it = range.first; it != range.second; ++it) {
+ if (Py_TYPE(self) == Py_TYPE(it->second)) {
+ registered_instances.erase(it);
+ return true;
+ }
+ }
+ return false;
+}
+
+inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
+ register_instance_impl(valptr, self);
+ if (!tinfo->simple_ancestors)
+ traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
+}
+
+inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
+ bool ret = deregister_instance_impl(valptr, self);
+ if (!tinfo->simple_ancestors)
+ traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
+ return ret;
+}
+
+/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
+/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
+/// to a reference or pointer), and initialization is done by an `__init__` function.
+inline PyObject *make_new_instance(PyTypeObject *type) {
+#if defined(PYPY_VERSION)
+ // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
+ // object is a a plain Python type (i.e. not derived from an extension type). Fix it.
+ ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
+ if (type->tp_basicsize < instance_size) {
+ type->tp_basicsize = instance_size;
+ }
+#endif
+ PyObject *self = type->tp_alloc(type, 0);
+ auto inst = reinterpret_cast<instance *>(self);
+ // Allocate the value/holder internals:
+ inst->allocate_layout();
+
+ inst->owned = true;
+
+ return self;
+}
+
+/// Instance creation function for all pybind11 types. It only allocates space for the
+/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
+extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
+ return make_new_instance(type);
+}
+
+/// An `__init__` function constructs the C++ object. Users should provide at least one
+/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
+/// following default function will be used which simply throws an exception.
+extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
+ PyTypeObject *type = Py_TYPE(self);
+ std::string msg;
+#if defined(PYPY_VERSION)
+ msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
+#endif
+ msg += type->tp_name;
+ msg += ": No constructor defined!";
+ PyErr_SetString(PyExc_TypeError, msg.c_str());
+ return -1;
+}
+
+inline void add_patient(PyObject *nurse, PyObject *patient) {
+ auto &internals = get_internals();
+ auto instance = reinterpret_cast<detail::instance *>(nurse);
+ auto &current_patients = internals.patients[nurse];
+ instance->has_patients = true;
+ for (auto &p : current_patients)
+ if (p == patient)
+ return;
+ Py_INCREF(patient);
+ current_patients.push_back(patient);
+}
+
+inline void clear_patients(PyObject *self) {
+ auto instance = reinterpret_cast<detail::instance *>(self);
+ auto &internals = get_internals();
+ auto pos = internals.patients.find(self);
+ assert(pos != internals.patients.end());
+ // Clearing the patients can cause more Python code to run, which
+ // can invalidate the iterator. Extract the vector of patients
+ // from the unordered_map first.
+ auto patients = std::move(pos->second);
+ internals.patients.erase(pos);
+ instance->has_patients = false;
+ for (PyObject *&patient : patients)
+ Py_CLEAR(patient);
+}
+
+/// Clears all internal data from the instance and removes it from registered instances in
+/// preparation for deallocation.
+inline void clear_instance(PyObject *self) {
+ auto instance = reinterpret_cast<detail::instance *>(self);
+
+ // Deallocate any values/holders, if present:
+ for (auto &v_h : values_and_holders(instance)) {
+ if (v_h) {
+
+ // We have to deregister before we call dealloc because, for virtual MI types, we still
+ // need to be able to get the parent pointers.
+ if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
+ pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
+
+ if (instance->owned || v_h.holder_constructed())
+ v_h.type->dealloc(v_h);
+ }
+ }
+ // Deallocate the value/holder layout internals:
+ instance->deallocate_layout();
+
+ if (instance->weakrefs)
+ PyObject_ClearWeakRefs(self);
+
+ PyObject **dict_ptr = _PyObject_GetDictPtr(self);
+ if (dict_ptr)
+ Py_CLEAR(*dict_ptr);
+
+ if (instance->has_patients)
+ clear_patients(self);
+}
+
+/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
+/// to destroy the C++ object itself, while the rest is Python bookkeeping.
+extern "C" inline void pybind11_object_dealloc(PyObject *self) {
+ clear_instance(self);
+
+ auto type = Py_TYPE(self);
+ type->tp_free(self);
+
+ // `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
+ // as part of a derived type's dealloc, in which case we're not allowed to decref
+ // the type here. For cross-module compatibility, we shouldn't compare directly
+ // with `pybind11_object_dealloc`, but with the common one stashed in internals.
+ auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
+ if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
+ Py_DECREF(type);
+}
+
+/** Create the type which can be used as a common base for all classes. This is
+ needed in order to satisfy Python's requirements for multiple inheritance.
+ Return value: New reference. */
+inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
+ constexpr auto *name = "pybind11_object";
+ auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
+
+ /* Danger zone: from now (and until PyType_Ready), make sure to
+ issue no Python C API calls which could potentially invoke the
+ garbage collector (the GC will call type_traverse(), which will in
+ turn find the newly constructed type in an invalid state) */
+ auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
+ if (!heap_type)
+ pybind11_fail("make_object_base_type(): error allocating type!");
+
+ heap_type->ht_name = name_obj.inc_ref().ptr();
+#ifdef PYBIND11_BUILTIN_QUALNAME
+ heap_type->ht_qualname = name_obj.inc_ref().ptr();
+#endif
+
+ auto type = &heap_type->ht_type;
+ type->tp_name = name;
+ type->tp_base = type_incref(&PyBaseObject_Type);
+ type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
+ type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+
+ type->tp_new = pybind11_object_new;
+ type->tp_init = pybind11_object_init;
+ type->tp_dealloc = pybind11_object_dealloc;
+
+ /* Support weak references (needed for the keep_alive feature) */
+ type->tp_weaklistoffset = offsetof(instance, weakrefs);
+
+ if (PyType_Ready(type) < 0)
+ pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
+
+ setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
+ PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
+
+ assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
+ return (PyObject *) heap_type;
+}
+
+/// dynamic_attr: Support for `d = instance.__dict__`.
+extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
+ PyObject *&dict = *_PyObject_GetDictPtr(self);
+ if (!dict)
+ dict = PyDict_New();
+ Py_XINCREF(dict);
+ return dict;
+}
+
+/// dynamic_attr: Support for `instance.__dict__ = dict()`.
+extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
+ if (!PyDict_Check(new_dict)) {
+ PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
+ Py_TYPE(new_dict)->tp_name);
+ return -1;
+ }
+ PyObject *&dict = *_PyObject_GetDictPtr(self);
+ Py_INCREF(new_dict);
+ Py_CLEAR(dict);
+ dict = new_dict;
+ return 0;
+}
+
+/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
+extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
+ PyObject *&dict = *_PyObject_GetDictPtr(self);
+ Py_VISIT(dict);
+ return 0;
+}
+
+/// dynamic_attr: Allow the GC to clear the dictionary.
+extern "C" inline int pybind11_clear(PyObject *self) {
+ PyObject *&dict = *_PyObject_GetDictPtr(self);
+ Py_CLEAR(dict);
+ return 0;
+}
+
+/// Give instances of this type a `__dict__` and opt into garbage collection.
+inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
+ auto type = &heap_type->ht_type;
+#if defined(PYPY_VERSION)
+ pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
+ "currently not supported in "
+ "conjunction with PyPy!");
+#endif
+ type->tp_flags |= Py_TPFLAGS_HAVE_GC;
+ type->tp_dictoffset = type->tp_basicsize; // place dict at the end
+ type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it
+ type->tp_traverse = pybind11_traverse;
+ type->tp_clear = pybind11_clear;
+
+ static PyGetSetDef getset[] = {
+ {const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
+ {nullptr, nullptr, nullptr, nullptr, nullptr}
+ };
+ type->tp_getset = getset;
+}
+
+/// buffer_protocol: Fill in the view as specified by flags.
+extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
+ // Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
+ type_info *tinfo = nullptr;
+ for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
+ tinfo = get_type_info((PyTypeObject *) type.ptr());
+ if (tinfo && tinfo->get_buffer)
+ break;
+ }
+ if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
+ if (view)
+ view->obj = nullptr;
+ PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
+ return -1;
+ }
+ std::memset(view, 0, sizeof(Py_buffer));
+ buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
+ view->obj = obj;
+ view->ndim = 1;
+ view->internal = info;
+ view->buf = info->ptr;
+ view->itemsize = info->itemsize;
+ view->len = view->itemsize;
+ for (auto s : info->shape)
+ view->len *= s;
+ if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
+ view->format = const_cast<char *>(info->format.c_str());
+ if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
+ view->ndim = (int) info->ndim;
+ view->strides = &info->strides[0];
+ view->shape = &info->shape[0];
+ }
+ Py_INCREF(view->obj);
+ return 0;
+}
+
+/// buffer_protocol: Release the resources of the buffer.
+extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
+ delete (buffer_info *) view->internal;
+}
+
+/// Give this type a buffer interface.
+inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
+ heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
+#if PY_MAJOR_VERSION < 3
+ heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
+#endif
+
+ heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
+ heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
+}
+
+/** Create a brand new Python type according to the `type_record` specification.
+ Return value: New reference. */
+inline PyObject* make_new_python_type(const type_record &rec) {
+ auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
+
+ auto qualname = name;
+ if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
+#if PY_MAJOR_VERSION >= 3
+ qualname = reinterpret_steal<object>(
+ PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
+#else
+ qualname = str(rec.scope.attr("__qualname__").cast<std::string>() + "." + rec.name);
+#endif
+ }
+
+ object module;
+ if (rec.scope) {
+ if (hasattr(rec.scope, "__module__"))
+ module = rec.scope.attr("__module__");
+ else if (hasattr(rec.scope, "__name__"))
+ module = rec.scope.attr("__name__");
+ }
+
+ auto full_name = c_str(
+#if !defined(PYPY_VERSION)
+ module ? str(module).cast<std::string>() + "." + rec.name :
+#endif
+ rec.name);
+
+ char *tp_doc = nullptr;
+ if (rec.doc && options::show_user_defined_docstrings()) {
+ /* Allocate memory for docstring (using PyObject_MALLOC, since
+ Python will free this later on) */
+ size_t size = strlen(rec.doc) + 1;
+ tp_doc = (char *) PyObject_MALLOC(size);
+ memcpy((void *) tp_doc, rec.doc, size);
+ }
+
+ auto &internals = get_internals();
+ auto bases = tuple(rec.bases);
+ auto base = (bases.size() == 0) ? internals.instance_base
+ : bases[0].ptr();
+
+ /* Danger zone: from now (and until PyType_Ready), make sure to
+ issue no Python C API calls which could potentially invoke the
+ garbage collector (the GC will call type_traverse(), which will in
+ turn find the newly constructed type in an invalid state) */
+ auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
+ : internals.default_metaclass;
+
+ auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
+ if (!heap_type)
+ pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
+
+ heap_type->ht_name = name.release().ptr();
+#ifdef PYBIND11_BUILTIN_QUALNAME
+ heap_type->ht_qualname = qualname.inc_ref().ptr();
+#endif
+
+ auto type = &heap_type->ht_type;
+ type->tp_name = full_name;
+ type->tp_doc = tp_doc;
+ type->tp_base = type_incref((PyTypeObject *)base);
+ type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
+ if (bases.size() > 0)
+ type->tp_bases = bases.release().ptr();
+
+ /* Don't inherit base __init__ */
+ type->tp_init = pybind11_object_init;
+
+ /* Supported protocols */
+ type->tp_as_number = &heap_type->as_number;
+ type->tp_as_sequence = &heap_type->as_sequence;
+ type->tp_as_mapping = &heap_type->as_mapping;
+
+ /* Flags */
+ type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+#if PY_MAJOR_VERSION < 3
+ type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
+#endif
+
+ if (rec.dynamic_attr)
+ enable_dynamic_attributes(heap_type);
+
+ if (rec.buffer_protocol)
+ enable_buffer_protocol(heap_type);
+
+ if (PyType_Ready(type) < 0)
+ pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
+
+ assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
+ : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
+
+ /* Register type with the parent scope */
+ if (rec.scope)
+ setattr(rec.scope, rec.name, (PyObject *) type);
+ else
+ Py_INCREF(type); // Keep it alive forever (reference leak)
+
+ if (module) // Needed by pydoc
+ setattr((PyObject *) type, "__module__", module);
+
+ PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
+
+ return (PyObject *) type;
+}
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h
new file mode 100644
index 000000000..7d41cd63b
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h
@@ -0,0 +1,802 @@
+/*
+ pybind11/detail/common.h -- Basic macros
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#if !defined(NAMESPACE_BEGIN)
+# define NAMESPACE_BEGIN(name) namespace name {
+#endif
+#if !defined(NAMESPACE_END)
+# define NAMESPACE_END(name) }
+#endif
+
+// Robust support for some features and loading modules compiled against different pybind versions
+// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on
+// the main `pybind11` namespace.
+#if !defined(PYBIND11_NAMESPACE)
+# ifdef __GNUG__
+# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
+# else
+# define PYBIND11_NAMESPACE pybind11
+# endif
+#endif
+
+#if !defined(_MSC_VER) && !defined(__INTEL_COMPILER)
+# if __cplusplus >= 201402L
+# define PYBIND11_CPP14
+# if __cplusplus > 201402L /* Temporary: should be updated to >= the final C++17 value once known */
+# define PYBIND11_CPP17
+# endif
+# endif
+#elif defined(_MSC_VER)
+// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented)
+# if _MSVC_LANG >= 201402L
+# define PYBIND11_CPP14
+# if _MSVC_LANG > 201402L && _MSC_VER >= 1910
+# define PYBIND11_CPP17
+# endif
+# endif
+#endif
+
+// Compiler version assertions
+#if defined(__INTEL_COMPILER)
+# if __INTEL_COMPILER < 1500
+# error pybind11 requires Intel C++ compiler v15 or newer
+# endif
+#elif defined(__clang__) && !defined(__apple_build_version__)
+# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
+# error pybind11 requires clang 3.3 or newer
+# endif
+#elif defined(__clang__)
+// Apple changes clang version macros to its Xcode version; the first Xcode release based on
+// (upstream) clang 3.3 was Xcode 5:
+# if __clang_major__ < 5
+# error pybind11 requires Xcode/clang 5.0 or newer
+# endif
+#elif defined(__GNUG__)
+# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
+# error pybind11 requires gcc 4.8 or newer
+# endif
+#elif defined(_MSC_VER)
+// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features
+// (e.g. std::negation) added in 2015u3:
+# if _MSC_FULL_VER < 190024210
+# error pybind11 requires MSVC 2015 update 3 or newer
+# endif
+#endif
+
+#if !defined(PYBIND11_EXPORT)
+# if defined(WIN32) || defined(_WIN32)
+# define PYBIND11_EXPORT __declspec(dllexport)
+# else
+# define PYBIND11_EXPORT __attribute__ ((visibility("default")))
+# endif
+#endif
+
+#if defined(_MSC_VER)
+# define PYBIND11_NOINLINE __declspec(noinline)
+#else
+# define PYBIND11_NOINLINE __attribute__ ((noinline))
+#endif
+
+#if defined(PYBIND11_CPP14)
+# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
+#else
+# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
+#endif
+
+#define PYBIND11_VERSION_MAJOR 2
+#define PYBIND11_VERSION_MINOR 2
+#define PYBIND11_VERSION_PATCH 2
+
+/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
+#if defined(_MSC_VER)
+# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
+# define HAVE_ROUND 1
+# endif
+# pragma warning(push)
+# pragma warning(disable: 4510 4610 4512 4005)
+# if defined(_DEBUG)
+# define PYBIND11_DEBUG_MARKER
+# undef _DEBUG
+# endif
+#endif
+
+#include <Python.h>
+#include <frameobject.h>
+#include <pythread.h>
+
+#if defined(_WIN32) && (defined(min) || defined(max))
+# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
+#endif
+
+#if defined(isalnum)
+# undef isalnum
+# undef isalpha
+# undef islower
+# undef isspace
+# undef isupper
+# undef tolower
+# undef toupper
+#endif
+
+#if defined(_MSC_VER)
+# if defined(PYBIND11_DEBUG_MARKER)
+# define _DEBUG
+# undef PYBIND11_DEBUG_MARKER
+# endif
+# pragma warning(pop)
+#endif
+
+#include <cstddef>
+#include <cstring>
+#include <forward_list>
+#include <vector>
+#include <string>
+#include <stdexcept>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+#include <typeindex>
+#include <type_traits>
+
+#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
+#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
+#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
+#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
+#define PYBIND11_BYTES_CHECK PyBytes_Check
+#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
+#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
+#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
+#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
+#define PYBIND11_BYTES_SIZE PyBytes_Size
+#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
+#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
+#define PYBIND11_BYTES_NAME "bytes"
+#define PYBIND11_STRING_NAME "str"
+#define PYBIND11_SLICE_OBJECT PyObject
+#define PYBIND11_FROM_STRING PyUnicode_FromString
+#define PYBIND11_STR_TYPE ::pybind11::str
+#define PYBIND11_BOOL_ATTR "__bool__"
+#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
+#define PYBIND11_PLUGIN_IMPL(name) \
+ extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
+
+#else
+#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
+#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
+#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
+#define PYBIND11_BYTES_CHECK PyString_Check
+#define PYBIND11_BYTES_FROM_STRING PyString_FromString
+#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
+#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
+#define PYBIND11_BYTES_AS_STRING PyString_AsString
+#define PYBIND11_BYTES_SIZE PyString_Size
+#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
+#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
+#define PYBIND11_BYTES_NAME "str"
+#define PYBIND11_STRING_NAME "unicode"
+#define PYBIND11_SLICE_OBJECT PySliceObject
+#define PYBIND11_FROM_STRING PyString_FromString
+#define PYBIND11_STR_TYPE ::pybind11::bytes
+#define PYBIND11_BOOL_ATTR "__nonzero__"
+#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
+#define PYBIND11_PLUGIN_IMPL(name) \
+ static PyObject *pybind11_init_wrapper(); \
+ extern "C" PYBIND11_EXPORT void init##name() { \
+ (void)pybind11_init_wrapper(); \
+ } \
+ PyObject *pybind11_init_wrapper()
+#endif
+
+#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
+extern "C" {
+ struct _Py_atomic_address { void *value; };
+ PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
+}
+#endif
+
+#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
+#define PYBIND11_STRINGIFY(x) #x
+#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
+#define PYBIND11_CONCAT(first, second) first##second
+
+/** \rst
+ ***Deprecated in favor of PYBIND11_MODULE***
+
+ This macro creates the entry point that will be invoked when the Python interpreter
+ imports a plugin library. Please create a `module` in the function body and return
+ the pointer to its underlying Python object at the end.
+
+ .. code-block:: cpp
+
+ PYBIND11_PLUGIN(example) {
+ pybind11::module m("example", "pybind11 example plugin");
+ /// Set up bindings here
+ return m.ptr();
+ }
+\endrst */
+#define PYBIND11_PLUGIN(name) \
+ PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
+ static PyObject *pybind11_init(); \
+ PYBIND11_PLUGIN_IMPL(name) { \
+ int major, minor; \
+ if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
+ PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
+ return nullptr; \
+ } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
+ PyErr_Format(PyExc_ImportError, \
+ "Python version mismatch: module was compiled for " \
+ "version %i.%i, while the interpreter is running " \
+ "version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
+ major, minor); \
+ return nullptr; \
+ } \
+ try { \
+ return pybind11_init(); \
+ } catch (pybind11::error_already_set &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } catch (const std::exception &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } \
+ } \
+ PyObject *pybind11_init()
+
+/** \rst
+ This macro creates the entry point that will be invoked when the Python interpreter
+ imports an extension module. The module name is given as the fist argument and it
+ should not be in quotes. The second macro argument defines a variable of type
+ `py::module` which can be used to initialize the module.
+
+ .. code-block:: cpp
+
+ PYBIND11_MODULE(example, m) {
+ m.doc() = "pybind11 example module";
+
+ // Add bindings here
+ m.def("foo", []() {
+ return "Hello, World!";
+ });
+ }
+\endrst */
+#define PYBIND11_MODULE(name, variable) \
+ static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
+ PYBIND11_PLUGIN_IMPL(name) { \
+ int major, minor; \
+ if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
+ PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
+ return nullptr; \
+ } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
+ PyErr_Format(PyExc_ImportError, \
+ "Python version mismatch: module was compiled for " \
+ "version %i.%i, while the interpreter is running " \
+ "version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
+ major, minor); \
+ return nullptr; \
+ } \
+ auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
+ try { \
+ PYBIND11_CONCAT(pybind11_init_, name)(m); \
+ return m.ptr(); \
+ } catch (pybind11::error_already_set &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } catch (const std::exception &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } \
+ } \
+ void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
+
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+using ssize_t = Py_ssize_t;
+using size_t = std::size_t;
+
+/// Approach used to cast a previously unknown C++ instance into a Python object
+enum class return_value_policy : uint8_t {
+ /** This is the default return value policy, which falls back to the policy
+ return_value_policy::take_ownership when the return value is a pointer.
+ Otherwise, it uses return_value::move or return_value::copy for rvalue
+ and lvalue references, respectively. See below for a description of what
+ all of these different policies do. */
+ automatic = 0,
+
+ /** As above, but use policy return_value_policy::reference when the return
+ value is a pointer. This is the default conversion policy for function
+ arguments when calling Python functions manually from C++ code (i.e. via
+ handle::operator()). You probably won't need to use this. */
+ automatic_reference,
+
+ /** Reference an existing object (i.e. do not create a new copy) and take
+ ownership. Python will call the destructor and delete operator when the
+ object’s reference count reaches zero. Undefined behavior ensues when
+ the C++ side does the same.. */
+ take_ownership,
+
+ /** Create a new copy of the returned object, which will be owned by
+ Python. This policy is comparably safe because the lifetimes of the two
+ instances are decoupled. */
+ copy,
+
+ /** Use std::move to move the return value contents into a new instance
+ that will be owned by Python. This policy is comparably safe because the
+ lifetimes of the two instances (move source and destination) are
+ decoupled. */
+ move,
+
+ /** Reference an existing object, but do not take ownership. The C++ side
+ is responsible for managing the object’s lifetime and deallocating it
+ when it is no longer used. Warning: undefined behavior will ensue when
+ the C++ side deletes an object that is still referenced and used by
+ Python. */
+ reference,
+
+ /** This policy only applies to methods and properties. It references the
+ object without taking ownership similar to the above
+ return_value_policy::reference policy. In contrast to that policy, the
+ function or property’s implicit this argument (called the parent) is
+ considered to be the the owner of the return value (the child).
+ pybind11 then couples the lifetime of the parent to the child via a
+ reference relationship that ensures that the parent cannot be garbage
+ collected while Python is still using the child. More advanced
+ variations of this scheme are also possible using combinations of
+ return_value_policy::reference and the keep_alive call policy */
+ reference_internal
+};
+
+NAMESPACE_BEGIN(detail)
+
+inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
+
+// Returns the size as a multiple of sizeof(void *), rounded up.
+inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
+
+/**
+ * The space to allocate for simple layout instance holders (see below) in multiple of the size of
+ * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
+ * to holder either a std::unique_ptr or std::shared_ptr (which is almost always
+ * sizeof(std::shared_ptr<T>)).
+ */
+constexpr size_t instance_simple_holder_in_ptrs() {
+ static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
+ "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
+ return size_in_ptrs(sizeof(std::shared_ptr<int>));
+}
+
+// Forward declarations
+struct type_info;
+struct value_and_holder;
+
+struct nonsimple_values_and_holders {
+ void **values_and_holders;
+ uint8_t *status;
+};
+
+/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
+struct instance {
+ PyObject_HEAD
+ /// Storage for pointers and holder; see simple_layout, below, for a description
+ union {
+ void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
+ nonsimple_values_and_holders nonsimple;
+ };
+ /// Weak references (needed for keep alive):
+ PyObject *weakrefs;
+ /// If true, the pointer is owned which means we're free to manage it with a holder.
+ bool owned : 1;
+ /**
+ * An instance has two possible value/holder layouts.
+ *
+ * Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
+ * and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
+ * whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
+ * holder will fit in the default space (which is large enough to hold either a std::unique_ptr
+ * or std::shared_ptr).
+ *
+ * Non-simple layout applies when using custom holders that require more space than `shared_ptr`
+ * (which is typically the size of two pointers), or when multiple inheritance is used on the
+ * python side. Non-simple layout allocates the required amount of memory to have multiple
+ * bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
+ * pointer to allocated space of the required space to hold a a sequence of value pointers and
+ * holders followed `status`, a set of bit flags (1 byte each), i.e.
+ * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
+ * `sizeof(void *)`. `nonsimple.holder_constructed` is, for convenience, a pointer to the
+ * beginning of the [bb...] block (but not independently allocated).
+ *
+ * Status bits indicate whether the associated holder is constructed (&
+ * status_holder_constructed) and whether the value pointer is registered (&
+ * status_instance_registered) in `registered_instances`.
+ */
+ bool simple_layout : 1;
+ /// For simple layout, tracks whether the holder has been constructed
+ bool simple_holder_constructed : 1;
+ /// For simple layout, tracks whether the instance is registered in `registered_instances`
+ bool simple_instance_registered : 1;
+ /// If true, get_internals().patients has an entry for this object
+ bool has_patients : 1;
+
+ /// Initializes all of the above type/values/holders data (but not the instance values themselves)
+ void allocate_layout();
+
+ /// Destroys/deallocates all of the above
+ void deallocate_layout();
+
+ /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
+ /// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
+ /// `throw_if_missing` is false.
+ value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
+
+ /// Bit values for the non-simple status flags
+ static constexpr uint8_t status_holder_constructed = 1;
+ static constexpr uint8_t status_instance_registered = 2;
+};
+
+static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
+
+/// from __cpp_future__ import (convenient aliases from C++14/17)
+#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
+using std::enable_if_t;
+using std::conditional_t;
+using std::remove_cv_t;
+using std::remove_reference_t;
+#else
+template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
+template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
+template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
+template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
+#endif
+
+/// Index sequences
+#if defined(PYBIND11_CPP14)
+using std::index_sequence;
+using std::make_index_sequence;
+#else
+template<size_t ...> struct index_sequence { };
+template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
+template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
+template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
+#endif
+
+/// Make an index sequence of the indices of true arguments
+template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
+template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
+ : select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
+template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
+
+/// Backports of std::bool_constant and std::negation to accommodate older compilers
+template <bool B> using bool_constant = std::integral_constant<bool, B>;
+template <typename T> struct negation : bool_constant<!T::value> { };
+
+template <typename...> struct void_t_impl { using type = void; };
+template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
+
+/// Compile-time all/any/none of that check the boolean value of all template types
+#ifdef __cpp_fold_expressions
+template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
+template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
+#elif !defined(_MSC_VER)
+template <bool...> struct bools {};
+template <class... Ts> using all_of = std::is_same<
+ bools<Ts::value..., true>,
+ bools<true, Ts::value...>>;
+template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
+#else
+// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
+// at a slight loss of compilation efficiency).
+template <class... Ts> using all_of = std::conjunction<Ts...>;
+template <class... Ts> using any_of = std::disjunction<Ts...>;
+#endif
+template <class... Ts> using none_of = negation<any_of<Ts...>>;
+
+template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
+template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
+template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
+
+/// Strip the class from a method type
+template <typename T> struct remove_class { };
+template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
+template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
+
+/// Helper template to strip away type modifiers
+template <typename T> struct intrinsic_type { typedef T type; };
+template <typename T> struct intrinsic_type<const T> { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T*> { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T&> { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T&&> { typedef typename intrinsic_type<T>::type type; };
+template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
+template <typename T, size_t N> struct intrinsic_type<T[N]> { typedef typename intrinsic_type<T>::type type; };
+template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
+
+/// Helper type to replace 'void' in some expressions
+struct void_type { };
+
+/// Helper template which holds a list of types
+template <typename...> struct type_list { };
+
+/// Compile-time integer sum
+#ifdef __cpp_fold_expressions
+template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
+#else
+constexpr size_t constexpr_sum() { return 0; }
+template <typename T, typename... Ts>
+constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
+#endif
+
+NAMESPACE_BEGIN(constexpr_impl)
+/// Implementation details for constexpr functions
+constexpr int first(int i) { return i; }
+template <typename T, typename... Ts>
+constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
+
+constexpr int last(int /*i*/, int result) { return result; }
+template <typename T, typename... Ts>
+constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
+NAMESPACE_END(constexpr_impl)
+
+/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
+/// none match.
+template <template<typename> class Predicate, typename... Ts>
+constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
+
+/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
+template <template<typename> class Predicate, typename... Ts>
+constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
+
+/// Return the Nth element from the parameter pack
+template <size_t N, typename T, typename... Ts>
+struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
+template <typename T, typename... Ts>
+struct pack_element<0, T, Ts...> { using type = T; };
+
+/// Return the one and only type which matches the predicate, or Default if none match.
+/// If more than one type matches the predicate, fail at compile-time.
+template <template<typename> class Predicate, typename Default, typename... Ts>
+struct exactly_one {
+ static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
+ static_assert(found <= 1, "Found more than one type matching the predicate");
+
+ static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
+ using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
+};
+template <template<typename> class P, typename Default>
+struct exactly_one<P, Default> { using type = Default; };
+
+template <template<typename> class Predicate, typename Default, typename... Ts>
+using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
+
+/// Defer the evaluation of type T until types Us are instantiated
+template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
+template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
+
+/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
+/// unlike `std::is_base_of`)
+template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
+ std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
+
+template <template<typename...> class Base>
+struct is_template_base_of_impl {
+ template <typename... Us> static std::true_type check(Base<Us...> *);
+ static std::false_type check(...);
+};
+
+/// Check if a template is the base of a type. For example:
+/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
+template <template<typename...> class Base, typename T>
+#if !defined(_MSC_VER)
+using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
+#else // MSVC2015 has trouble with decltype in template aliases
+struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
+#endif
+
+/// Check if T is an instantiation of the template `Class`. For example:
+/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
+template <template<typename...> class Class, typename T>
+struct is_instantiation : std::false_type { };
+template <template<typename...> class Class, typename... Us>
+struct is_instantiation<Class, Class<Us...>> : std::true_type { };
+
+/// Check if T is std::shared_ptr<U> where U can be anything
+template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
+
+/// Check if T looks like an input iterator
+template <typename T, typename = void> struct is_input_iterator : std::false_type {};
+template <typename T>
+struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
+ : std::true_type {};
+
+template <typename T> using is_function_pointer = bool_constant<
+ std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
+
+template <typename F> struct strip_function_object {
+ using type = typename remove_class<decltype(&F::operator())>::type;
+};
+
+// Extracts the function signature from a function, function pointer or lambda.
+template <typename Function, typename F = remove_reference_t<Function>>
+using function_signature_t = conditional_t<
+ std::is_function<F>::value,
+ F,
+ typename conditional_t<
+ std::is_pointer<F>::value || std::is_member_pointer<F>::value,
+ std::remove_pointer<F>,
+ strip_function_object<F>
+ >::type
+>;
+
+/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
+/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
+/// in a place where passing a lambda makes sense.
+template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
+ std::is_function, std::is_pointer, std::is_member_pointer>;
+
+/// Ignore that a variable is unused in compiler warnings
+inline void ignore_unused(const int *) { }
+
+/// Apply a function over each element of a parameter pack
+#ifdef __cpp_fold_expressions
+#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
+#else
+using expand_side_effects = bool[];
+#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false }
+#endif
+
+NAMESPACE_END(detail)
+
+/// C++ bindings of builtin Python exceptions
+class builtin_exception : public std::runtime_error {
+public:
+ using std::runtime_error::runtime_error;
+ /// Set the error using the Python C API
+ virtual void set_error() const = 0;
+};
+
+#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
+ class name : public builtin_exception { public: \
+ using builtin_exception::builtin_exception; \
+ name() : name("") { } \
+ void set_error() const override { PyErr_SetString(type, what()); } \
+ };
+
+PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
+PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
+PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
+PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
+PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
+PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
+PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
+
+[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
+[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
+
+template <typename T, typename SFINAE = void> struct format_descriptor { };
+
+NAMESPACE_BEGIN(detail)
+// Returns the index of the given type in the type char array below, and in the list in numpy.h
+// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
+// complex float,double,long double. Note that the long double types only participate when long
+// double is actually longer than double (it isn't under MSVC).
+// NB: not only the string below but also complex.h and numpy.h rely on this order.
+template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
+template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
+ static constexpr bool value = true;
+ static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
+ std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
+ std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
+};
+NAMESPACE_END(detail)
+
+template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
+ static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
+ static constexpr const char value[2] = { c, '\0' };
+ static std::string format() { return std::string(1, c); }
+};
+
+template <typename T> constexpr const char format_descriptor<
+ T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
+
+/// RAII wrapper that temporarily clears any Python error state
+struct error_scope {
+ PyObject *type, *value, *trace;
+ error_scope() { PyErr_Fetch(&type, &value, &trace); }
+ ~error_scope() { PyErr_Restore(type, value, trace); }
+};
+
+/// Dummy destructor wrapper that can be used to expose classes with a private destructor
+struct nodelete { template <typename T> void operator()(T*) { } };
+
+// overload_cast requires variable templates: C++14
+#if defined(PYBIND11_CPP14)
+#define PYBIND11_OVERLOAD_CAST 1
+
+NAMESPACE_BEGIN(detail)
+template <typename... Args>
+struct overload_cast_impl {
+ constexpr overload_cast_impl() {} // MSVC 2015 needs this
+
+ template <typename Return>
+ constexpr auto operator()(Return (*pf)(Args...)) const noexcept
+ -> decltype(pf) { return pf; }
+
+ template <typename Return, typename Class>
+ constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept
+ -> decltype(pmf) { return pmf; }
+
+ template <typename Return, typename Class>
+ constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept
+ -> decltype(pmf) { return pmf; }
+};
+NAMESPACE_END(detail)
+
+/// Syntax sugar for resolving overloaded function pointers:
+/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
+/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
+template <typename... Args>
+static constexpr detail::overload_cast_impl<Args...> overload_cast = {};
+// MSVC 2015 only accepts this particular initialization syntax for this variable template.
+
+/// Const member function selector for overload_cast
+/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
+/// - sweet: overload_cast<Arg>(&Class::func, const_)
+static constexpr auto const_ = std::true_type{};
+
+#else // no overload_cast: providing something that static_assert-fails:
+template <typename... Args> struct overload_cast {
+ static_assert(detail::deferred_t<std::false_type, Args...>::value,
+ "pybind11::overload_cast<...> requires compiling in C++14 mode");
+};
+#endif // overload_cast
+
+NAMESPACE_BEGIN(detail)
+
+// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
+// any standard container (or C-style array) supporting std::begin/std::end, any singleton
+// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
+template <typename T>
+class any_container {
+ std::vector<T> v;
+public:
+ any_container() = default;
+
+ // Can construct from a pair of iterators
+ template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
+ any_container(It first, It last) : v(first, last) { }
+
+ // Implicit conversion constructor from any arbitrary container type with values convertible to T
+ template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
+ any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
+
+ // initializer_list's aren't deducible, so don't get matched by the above template; we need this
+ // to explicitly allow implicit conversion from one:
+ template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
+ any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
+
+ // Avoid copying if given an rvalue vector of the correct type.
+ any_container(std::vector<T> &&v) : v(std::move(v)) { }
+
+ // Moves the vector out of an rvalue any_container
+ operator std::vector<T> &&() && { return std::move(v); }
+
+ // Dereferencing obtains a reference to the underlying vector
+ std::vector<T> &operator*() { return v; }
+ const std::vector<T> &operator*() const { return v; }
+
+ // -> lets you call methods on the underlying vector
+ std::vector<T> *operator->() { return &v; }
+ const std::vector<T> *operator->() const { return &v; }
+};
+
+NAMESPACE_END(detail)
+
+
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h
new file mode 100644
index 000000000..e3bf2ba97
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h
@@ -0,0 +1,185 @@
+/*
+ pybind11/detail/descr.h: Helper type for concatenating type signatures
+ either at runtime (C++11) or compile time (C++14)
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "common.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/* Concatenate type signatures at compile time using C++14 */
+#if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
+#define PYBIND11_CONSTEXPR_DESCR
+
+template <size_t Size1, size_t Size2> class descr {
+ template <size_t Size1_, size_t Size2_> friend class descr;
+public:
+ constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
+ : descr(text, types,
+ make_index_sequence<Size1>(),
+ make_index_sequence<Size2>()) { }
+
+ constexpr const char *text() const { return m_text; }
+ constexpr const std::type_info * const * types() const { return m_types; }
+
+ template <size_t OtherSize1, size_t OtherSize2>
+ constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
+ return concat(other,
+ make_index_sequence<Size1>(),
+ make_index_sequence<Size2>(),
+ make_index_sequence<OtherSize1>(),
+ make_index_sequence<OtherSize2>());
+ }
+
+protected:
+ template <size_t... Indices1, size_t... Indices2>
+ constexpr descr(
+ char const (&text) [Size1+1],
+ const std::type_info * const (&types) [Size2+1],
+ index_sequence<Indices1...>, index_sequence<Indices2...>)
+ : m_text{text[Indices1]..., '\0'},
+ m_types{types[Indices2]..., nullptr } {}
+
+ template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
+ size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
+ constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
+ concat(const descr<OtherSize1, OtherSize2> &other,
+ index_sequence<Indices1...>, index_sequence<Indices2...>,
+ index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
+ return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
+ { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
+ { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
+ );
+ }
+
+protected:
+ char m_text[Size1 + 1];
+ const std::type_info * m_types[Size2 + 1];
+};
+
+template <size_t Size> constexpr descr<Size - 1, 0> _(char const(&text)[Size]) {
+ return descr<Size - 1, 0>(text, { nullptr });
+}
+
+template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
+template <size_t...Digits> struct int_to_str<0, Digits...> {
+ static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
+};
+
+// Ternary description (like std::conditional)
+template <bool B, size_t Size1, size_t Size2>
+constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
+ return _(text1);
+}
+template <bool B, size_t Size1, size_t Size2>
+constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
+ return _(text2);
+}
+template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
+constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
+template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
+constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
+
+template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
+ return int_to_str<Size / 10, Size % 10>::digits;
+}
+
+template <typename Type> constexpr descr<1, 1> _() {
+ return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
+}
+
+inline constexpr descr<0, 0> concat() { return _(""); }
+template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr) { return descr; }
+template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr, Args&&... args) { return descr + _(", ") + concat(args...); }
+template <size_t Size1, size_t Size2> auto constexpr type_descr(descr<Size1, Size2> descr) { return _("{") + descr + _("}"); }
+
+#define PYBIND11_DESCR constexpr auto
+
+#else /* Simpler C++11 implementation based on run-time memory allocation and copying */
+
+class descr {
+public:
+ PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
+ size_t nChars = len(text), nTypes = len(types);
+ m_text = new char[nChars];
+ m_types = new const std::type_info *[nTypes];
+ memcpy(m_text, text, nChars * sizeof(char));
+ memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
+ }
+
+ PYBIND11_NOINLINE descr operator+(descr &&d2) && {
+ descr r;
+
+ size_t nChars1 = len(m_text), nTypes1 = len(m_types);
+ size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
+
+ r.m_text = new char[nChars1 + nChars2 - 1];
+ r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
+ memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
+ memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
+ memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
+ memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
+
+ delete[] m_text; delete[] m_types;
+ delete[] d2.m_text; delete[] d2.m_types;
+
+ return r;
+ }
+
+ char *text() { return m_text; }
+ const std::type_info * * types() { return m_types; }
+
+protected:
+ PYBIND11_NOINLINE descr() { }
+
+ template <typename T> static size_t len(const T *ptr) { // return length including null termination
+ const T *it = ptr;
+ while (*it++ != (T) 0)
+ ;
+ return static_cast<size_t>(it - ptr);
+ }
+
+ const std::type_info **m_types = nullptr;
+ char *m_text = nullptr;
+};
+
+/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
+
+PYBIND11_NOINLINE inline descr _(const char *text) {
+ const std::type_info *types[1] = { nullptr };
+ return descr(text, types);
+}
+
+template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
+template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
+template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
+template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
+
+template <typename Type> PYBIND11_NOINLINE descr _() {
+ const std::type_info *types[2] = { &typeid(Type), nullptr };
+ return descr("%", types);
+}
+
+template <size_t Size> PYBIND11_NOINLINE descr _() {
+ const std::type_info *types[1] = { nullptr };
+ return descr(std::to_string(Size).c_str(), types);
+}
+
+PYBIND11_NOINLINE inline descr concat() { return _(""); }
+PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
+template <typename... Args> PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward<Args>(args)...); }
+PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
+
+#define PYBIND11_DESCR ::pybind11::detail::descr
+#endif
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h
new file mode 100644
index 000000000..82f740760
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h
@@ -0,0 +1,335 @@
+/*
+ pybind11/detail/init.h: init factory function implementation and support code.
+
+ Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "class.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+template <>
+class type_caster<value_and_holder> {
+public:
+ bool load(handle h, bool) {
+ value = reinterpret_cast<value_and_holder *>(h.ptr());
+ return true;
+ }
+
+ template <typename> using cast_op_type = value_and_holder &;
+ operator value_and_holder &() { return *value; }
+ static PYBIND11_DESCR name() { return type_descr(_<value_and_holder>()); }
+
+private:
+ value_and_holder *value = nullptr;
+};
+
+NAMESPACE_BEGIN(initimpl)
+
+inline void no_nullptr(void *ptr) {
+ if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
+}
+
+// Implementing functions for all forms of py::init<...> and py::init(...)
+template <typename Class> using Cpp = typename Class::type;
+template <typename Class> using Alias = typename Class::type_alias;
+template <typename Class> using Holder = typename Class::holder_type;
+
+template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
+
+// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
+template <typename Class, enable_if_t<Class::has_alias, int> = 0>
+bool is_alias(Cpp<Class> *ptr) {
+ return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
+}
+// Failing fallback version of the above for a no-alias class (always returns false)
+template <typename /*Class*/>
+constexpr bool is_alias(void *) { return false; }
+
+// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
+// back to brace aggregate initiailization so that for aggregate initialization can be used with
+// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
+// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
+// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
+template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
+inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
+template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
+inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
+
+// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
+// an alias to provide only a single Cpp factory function as long as the Alias can be
+// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
+// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
+// inherit all the base class constructors.
+template <typename Class>
+void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
+ value_and_holder &v_h, Cpp<Class> &&base) {
+ v_h.value_ptr() = new Alias<Class>(std::move(base));
+}
+template <typename Class>
+[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
+ value_and_holder &, Cpp<Class> &&) {
+ throw type_error("pybind11::init(): unable to convert returned instance to required "
+ "alias class: no `Alias<Class>(Class &&)` constructor available");
+}
+
+// Error-generating fallback for factories that don't match one of the below construction
+// mechanisms.
+template <typename Class>
+void construct(...) {
+ static_assert(!std::is_same<Class, Class>::value /* always false */,
+ "pybind11::init(): init function must return a compatible pointer, "
+ "holder, or value");
+}
+
+// Pointer return v1: the factory function returns a class pointer for a registered class.
+// If we don't need an alias (because this class doesn't have one, or because the final type is
+// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
+// construct an Alias from the returned base instance.
+template <typename Class>
+void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
+ no_nullptr(ptr);
+ if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
+ // We're going to try to construct an alias by moving the cpp type. Whether or not
+ // that succeeds, we still need to destroy the original cpp pointer (either the
+ // moved away leftover, if the alias construction works, or the value itself if we
+ // throw an error), but we can't just call `delete ptr`: it might have a special
+ // deleter, or might be shared_from_this. So we construct a holder around it as if
+ // it was a normal instance, then steal the holder away into a local variable; thus
+ // the holder and destruction happens when we leave the C++ scope, and the holder
+ // class gets to handle the destruction however it likes.
+ v_h.value_ptr() = ptr;
+ v_h.set_instance_registered(true); // To prevent init_instance from registering it
+ v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
+ Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
+ v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
+ v_h.set_instance_registered(false);
+
+ construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
+ } else {
+ // Otherwise the type isn't inherited, so we don't need an Alias
+ v_h.value_ptr() = ptr;
+ }
+}
+
+// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
+// ownership of the pointer.
+template <typename Class, enable_if_t<Class::has_alias, int> = 0>
+void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
+ no_nullptr(alias_ptr);
+ v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
+}
+
+// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
+// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
+// derived type (through those holder's implicit conversion from derived class holder constructors).
+template <typename Class>
+void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
+ auto *ptr = holder_helper<Holder<Class>>::get(holder);
+ // If we need an alias, check that the held pointer is actually an alias instance
+ if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
+ throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
+ "is not an alias instance");
+
+ v_h.value_ptr() = ptr;
+ v_h.type->init_instance(v_h.inst, &holder);
+}
+
+// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
+// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
+// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
+// need it, we simply move-construct the cpp value into a new instance.
+template <typename Class>
+void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
+ static_assert(std::is_move_constructible<Cpp<Class>>::value,
+ "pybind11::init() return-by-value factory function requires a movable class");
+ if (Class::has_alias && need_alias)
+ construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
+ else
+ v_h.value_ptr() = new Cpp<Class>(std::move(result));
+}
+
+// return-by-value version 2: returning a value of the alias type itself. We move-construct an
+// Alias instance (even if no the python-side inheritance is involved). The is intended for
+// cases where Alias initialization is always desired.
+template <typename Class>
+void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
+ static_assert(std::is_move_constructible<Alias<Class>>::value,
+ "pybind11::init() return-by-alias-value factory function requires a movable alias class");
+ v_h.value_ptr() = new Alias<Class>(std::move(result));
+}
+
+// Implementing class for py::init<...>()
+template <typename... Args>
+struct constructor {
+ template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
+ static void execute(Class &cl, const Extra&... extra) {
+ cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+ v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
+ }, is_new_style_constructor(), extra...);
+ }
+
+ template <typename Class, typename... Extra,
+ enable_if_t<Class::has_alias &&
+ std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
+ static void execute(Class &cl, const Extra&... extra) {
+ cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+ if (Py_TYPE(v_h.inst) == v_h.type->type)
+ v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
+ else
+ v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+ }, is_new_style_constructor(), extra...);
+ }
+
+ template <typename Class, typename... Extra,
+ enable_if_t<Class::has_alias &&
+ !std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
+ static void execute(Class &cl, const Extra&... extra) {
+ cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+ v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
+// Implementing class for py::init_alias<...>()
+template <typename... Args> struct alias_constructor {
+ template <typename Class, typename... Extra,
+ enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
+ static void execute(Class &cl, const Extra&... extra) {
+ cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+ v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
+// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
+template <typename CFunc, typename AFunc = void_type (*)(),
+ typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
+struct factory;
+
+// Specialization for py::init(Func)
+template <typename Func, typename Return, typename... Args>
+struct factory<Func, void_type (*)(), Return(Args...)> {
+ remove_reference_t<Func> class_factory;
+
+ factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
+
+ // The given class either has no alias or has no separate alias factory;
+ // this always constructs the class itself. If the class is registered with an alias
+ // type and an alias instance is needed (i.e. because the final type is a Python class
+ // inheriting from the C++ type) the returned value needs to either already be an alias
+ // instance, or the alias needs to be constructible from a `Class &&` argument.
+ template <typename Class, typename... Extra>
+ void execute(Class &cl, const Extra &...extra) && {
+ #if defined(PYBIND11_CPP14)
+ cl.def("__init__", [func = std::move(class_factory)]
+ #else
+ auto &func = class_factory;
+ cl.def("__init__", [func]
+ #endif
+ (value_and_holder &v_h, Args... args) {
+ construct<Class>(v_h, func(std::forward<Args>(args)...),
+ Py_TYPE(v_h.inst) != v_h.type->type);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
+// Specialization for py::init(Func, AliasFunc)
+template <typename CFunc, typename AFunc,
+ typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
+struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
+ static_assert(sizeof...(CArgs) == sizeof...(AArgs),
+ "pybind11::init(class_factory, alias_factory): class and alias factories "
+ "must have identical argument signatures");
+ static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
+ "pybind11::init(class_factory, alias_factory): class and alias factories "
+ "must have identical argument signatures");
+
+ remove_reference_t<CFunc> class_factory;
+ remove_reference_t<AFunc> alias_factory;
+
+ factory(CFunc &&c, AFunc &&a)
+ : class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
+
+ // The class factory is called when the `self` type passed to `__init__` is the direct
+ // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
+ template <typename Class, typename... Extra>
+ void execute(Class &cl, const Extra&... extra) && {
+ static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
+ "only be used if the class has an alias");
+ #if defined(PYBIND11_CPP14)
+ cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
+ #else
+ auto &class_func = class_factory;
+ auto &alias_func = alias_factory;
+ cl.def("__init__", [class_func, alias_func]
+ #endif
+ (value_and_holder &v_h, CArgs... args) {
+ if (Py_TYPE(v_h.inst) == v_h.type->type)
+ // If the instance type equals the registered type we don't have inheritance, so
+ // don't need the alias and can construct using the class function:
+ construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
+ else
+ construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
+/// Set just the C++ state. Same as `__init__`.
+template <typename Class, typename T>
+void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
+ construct<Class>(v_h, std::forward<T>(result), need_alias);
+}
+
+/// Set both the C++ and Python states
+template <typename Class, typename T, typename O,
+ enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
+void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
+ construct<Class>(v_h, std::move(result.first), need_alias);
+ setattr((PyObject *) v_h.inst, "__dict__", result.second);
+}
+
+/// Implementation for py::pickle(GetState, SetState)
+template <typename Get, typename Set,
+ typename = function_signature_t<Get>, typename = function_signature_t<Set>>
+struct pickle_factory;
+
+template <typename Get, typename Set,
+ typename RetState, typename Self, typename NewInstance, typename ArgState>
+struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
+ static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
+ "The type returned by `__getstate__` must be the same "
+ "as the argument accepted by `__setstate__`");
+
+ remove_reference_t<Get> get;
+ remove_reference_t<Set> set;
+
+ pickle_factory(Get get, Set set)
+ : get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
+
+ template <typename Class, typename... Extra>
+ void execute(Class &cl, const Extra &...extra) && {
+ cl.def("__getstate__", std::move(get));
+
+#if defined(PYBIND11_CPP14)
+ cl.def("__setstate__", [func = std::move(set)]
+#else
+ auto &func = set;
+ cl.def("__setstate__", [func]
+#endif
+ (value_and_holder &v_h, ArgState state) {
+ setstate<Class>(v_h, func(std::forward<ArgState>(state)),
+ Py_TYPE(v_h.inst) != v_h.type->type);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
+NAMESPACE_END(initimpl)
+NAMESPACE_END(detail)
+NAMESPACE_END(pybind11)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h
new file mode 100644
index 000000000..e39f38695
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h
@@ -0,0 +1,249 @@
+/*
+ pybind11/detail/internals.h: Internal data structure and related functions
+
+ Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "../pytypes.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+// Forward declarations
+inline PyTypeObject *make_static_property_type();
+inline PyTypeObject *make_default_metaclass();
+inline PyObject *make_object_base_type(PyTypeObject *metaclass);
+
+// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
+// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
+// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
+// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
+// which works. If not under a known-good stl, provide our own name-based hash and equality
+// functions that use the type name.
+#if defined(__GLIBCXX__)
+inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
+using type_hash = std::hash<std::type_index>;
+using type_equal_to = std::equal_to<std::type_index>;
+#else
+inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
+ return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
+}
+
+struct type_hash {
+ size_t operator()(const std::type_index &t) const {
+ size_t hash = 5381;
+ const char *ptr = t.name();
+ while (auto c = static_cast<unsigned char>(*ptr++))
+ hash = (hash * 33) ^ c;
+ return hash;
+ }
+};
+
+struct type_equal_to {
+ bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
+ return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
+ }
+};
+#endif
+
+template <typename value_type>
+using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
+
+struct overload_hash {
+ inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
+ size_t value = std::hash<const void *>()(v.first);
+ value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
+ return value;
+ }
+};
+
+/// Internal data structure used to track registered instances and types.
+/// Whenever binary incompatible changes are made to this structure,
+/// `PYBIND11_INTERNALS_VERSION` must be incremented.
+struct internals {
+ type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
+ std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
+ std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
+ std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
+ type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
+ std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
+ std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
+ std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
+ std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
+ std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
+ PyTypeObject *static_property_type;
+ PyTypeObject *default_metaclass;
+ PyObject *instance_base;
+#if defined(WITH_THREAD)
+ decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
+ PyInterpreterState *istate = nullptr;
+#endif
+};
+
+/// Additional type information which does not fit into the PyTypeObject.
+/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
+struct type_info {
+ PyTypeObject *type;
+ const std::type_info *cpptype;
+ size_t type_size, holder_size_in_ptrs;
+ void *(*operator_new)(size_t);
+ void (*init_instance)(instance *, const void *);
+ void (*dealloc)(value_and_holder &v_h);
+ std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
+ std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
+ std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
+ buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
+ void *get_buffer_data = nullptr;
+ void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
+ /* A simple type never occurs as a (direct or indirect) parent
+ * of a class that makes use of multiple inheritance */
+ bool simple_type : 1;
+ /* True if there is no multiple inheritance in this type's inheritance tree */
+ bool simple_ancestors : 1;
+ /* for base vs derived holder_type checks */
+ bool default_holder : 1;
+ /* true if this is a type registered with py::module_local */
+ bool module_local : 1;
+};
+
+/// Tracks the `internals` and `type_info` ABI version independent of the main library version
+#define PYBIND11_INTERNALS_VERSION 1
+
+#if defined(WITH_THREAD)
+# define PYBIND11_INTERNALS_KIND ""
+#else
+# define PYBIND11_INTERNALS_KIND "_without_thread"
+#endif
+
+#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
+ PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
+
+#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
+ PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
+
+/// Each module locally stores a pointer to the `internals` data. The data
+/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
+inline internals **&get_internals_pp() {
+ static internals **internals_pp = nullptr;
+ return internals_pp;
+}
+
+/// Return a reference to the current `internals` data
+PYBIND11_NOINLINE inline internals &get_internals() {
+ auto **&internals_pp = get_internals_pp();
+ if (internals_pp && *internals_pp)
+ return **internals_pp;
+
+ constexpr auto *id = PYBIND11_INTERNALS_ID;
+ auto builtins = handle(PyEval_GetBuiltins());
+ if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
+ internals_pp = static_cast<internals **>(capsule(builtins[id]));
+
+ // We loaded builtins through python's builtins, which means that our `error_already_set`
+ // and `builtin_exception` may be different local classes than the ones set up in the
+ // initial exception translator, below, so add another for our local exception classes.
+ //
+ // libstdc++ doesn't require this (types there are identified only by name)
+#if !defined(__GLIBCXX__)
+ (*internals_pp)->registered_exception_translators.push_front(
+ [](std::exception_ptr p) -> void {
+ try {
+ if (p) std::rethrow_exception(p);
+ } catch (error_already_set &e) { e.restore(); return;
+ } catch (const builtin_exception &e) { e.set_error(); return;
+ }
+ }
+ );
+#endif
+ } else {
+ if (!internals_pp) internals_pp = new internals*();
+ auto *&internals_ptr = *internals_pp;
+ internals_ptr = new internals();
+#if defined(WITH_THREAD)
+ PyEval_InitThreads();
+ PyThreadState *tstate = PyThreadState_Get();
+ internals_ptr->tstate = PyThread_create_key();
+ PyThread_set_key_value(internals_ptr->tstate, tstate);
+ internals_ptr->istate = tstate->interp;
+#endif
+ builtins[id] = capsule(internals_pp);
+ internals_ptr->registered_exception_translators.push_front(
+ [](std::exception_ptr p) -> void {
+ try {
+ if (p) std::rethrow_exception(p);
+ } catch (error_already_set &e) { e.restore(); return;
+ } catch (const builtin_exception &e) { e.set_error(); return;
+ } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
+ } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
+ } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
+ } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
+ } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
+ } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
+ } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
+ } catch (...) {
+ PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
+ return;
+ }
+ }
+ );
+ internals_ptr->static_property_type = make_static_property_type();
+ internals_ptr->default_metaclass = make_default_metaclass();
+ internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
+ }
+ return **internals_pp;
+}
+
+/// Works like `internals.registered_types_cpp`, but for module-local registered types:
+inline type_map<type_info *> &registered_local_types_cpp() {
+ static type_map<type_info *> locals{};
+ return locals;
+}
+
+/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
+/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
+/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
+/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
+template <typename... Args>
+const char *c_str(Args &&...args) {
+ auto &strings = get_internals().static_strings;
+ strings.emplace_front(std::forward<Args>(args)...);
+ return strings.front().c_str();
+}
+
+NAMESPACE_END(detail)
+
+/// Returns a named pointer that is shared among all extension modules (using the same
+/// pybind11 version) running in the current interpreter. Names starting with underscores
+/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
+inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
+ auto &internals = detail::get_internals();
+ auto it = internals.shared_data.find(name);
+ return it != internals.shared_data.end() ? it->second : nullptr;
+}
+
+/// Set the shared data that can be later recovered by `get_shared_data()`.
+inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
+ detail::get_internals().shared_data[name] = data;
+ return data;
+}
+
+/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
+/// such entry exists. Otherwise, a new object of default-constructible type `T` is
+/// added to the shared data under the given name and a reference to it is returned.
+template<typename T>
+T &get_or_create_shared_data(const std::string &name) {
+ auto &internals = detail::get_internals();
+ auto it = internals.shared_data.find(name);
+ T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
+ if (!ptr) {
+ ptr = new T();
+ internals.shared_data[name] = ptr;
+ }
+ return *ptr;
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h
new file mode 100644
index 000000000..6f36aab75
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h
@@ -0,0 +1,53 @@
+/*
+ pybind11/detail/typeid.h: Compiler-independent access to type identifiers
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include <cstdio>
+#include <cstdlib>
+
+#if defined(__GNUG__)
+#include <cxxabi.h>
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+/// Erase all occurrences of a substring
+inline void erase_all(std::string &string, const std::string &search) {
+ for (size_t pos = 0;;) {
+ pos = string.find(search, pos);
+ if (pos == std::string::npos) break;
+ string.erase(pos, search.length());
+ }
+}
+
+PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
+#if defined(__GNUG__)
+ int status = 0;
+ std::unique_ptr<char, void (*)(void *)> res {
+ abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
+ if (status == 0)
+ name = res.get();
+#else
+ detail::erase_all(name, "class ");
+ detail::erase_all(name, "struct ");
+ detail::erase_all(name, "enum ");
+#endif
+ detail::erase_all(name, "pybind11::");
+}
+NAMESPACE_END(detail)
+
+/// Return a string representation of a C++ type
+template <typename T> static std::string type_id() {
+ std::string name(typeid(T).name());
+ detail::clean_type_id(name);
+ return name;
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h b/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h
new file mode 100644
index 000000000..693a484dc
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h
@@ -0,0 +1,612 @@
+/*
+ pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "numpy.h"
+
+#if defined(__INTEL_COMPILER)
+# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
+#elif defined(__GNUG__) || defined(__clang__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wconversion"
+# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+# if __GNUC__ >= 7
+# pragma GCC diagnostic ignored "-Wint-in-bool-context"
+# endif
+#endif
+
+#if defined(_MSC_VER)
+# pragma warning(push)
+# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
+#endif
+
+#include <Eigen/Core>
+#include <Eigen/SparseCore>
+
+// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
+// move constructors that break things. We could detect this an explicitly copy, but an extra copy
+// of matrices seems highly undesirable.
+static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
+using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
+template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
+template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
+
+NAMESPACE_BEGIN(detail)
+
+#if EIGEN_VERSION_AT_LEAST(3,3,0)
+using EigenIndex = Eigen::Index;
+#else
+using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
+#endif
+
+// Matches Eigen::Map, Eigen::Ref, blocks, etc:
+template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
+template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
+template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
+template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
+// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
+// basically covers anything that can be assigned to a dense matrix but that don't have a typical
+// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
+// SelfAdjointView fall into this category.
+template <typename T> using is_eigen_other = all_of<
+ is_template_base_of<Eigen::EigenBase, T>,
+ negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
+>;
+
+// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
+template <bool EigenRowMajor> struct EigenConformable {
+ bool conformable = false;
+ EigenIndex rows = 0, cols = 0;
+ EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
+ bool negativestrides = false; // If true, do not use stride!
+
+ EigenConformable(bool fits = false) : conformable{fits} {}
+ // Matrix type:
+ EigenConformable(EigenIndex r, EigenIndex c,
+ EigenIndex rstride, EigenIndex cstride) :
+ conformable{true}, rows{r}, cols{c} {
+ // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
+ if (rstride < 0 || cstride < 0) {
+ negativestrides = true;
+ } else {
+ stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
+ EigenRowMajor ? cstride : rstride /* inner stride */ };
+ }
+ }
+ // Vector type:
+ EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
+ : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
+
+ template <typename props> bool stride_compatible() const {
+ // To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
+ // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
+ return
+ !negativestrides &&
+ (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
+ (EigenRowMajor ? cols : rows) == 1) &&
+ (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
+ (EigenRowMajor ? rows : cols) == 1);
+ }
+ operator bool() const { return conformable; }
+};
+
+template <typename Type> struct eigen_extract_stride { using type = Type; };
+template <typename PlainObjectType, int MapOptions, typename StrideType>
+struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
+template <typename PlainObjectType, int Options, typename StrideType>
+struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
+
+// Helper struct for extracting information from an Eigen type
+template <typename Type_> struct EigenProps {
+ using Type = Type_;
+ using Scalar = typename Type::Scalar;
+ using StrideType = typename eigen_extract_stride<Type>::type;
+ static constexpr EigenIndex
+ rows = Type::RowsAtCompileTime,
+ cols = Type::ColsAtCompileTime,
+ size = Type::SizeAtCompileTime;
+ static constexpr bool
+ row_major = Type::IsRowMajor,
+ vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
+ fixed_rows = rows != Eigen::Dynamic,
+ fixed_cols = cols != Eigen::Dynamic,
+ fixed = size != Eigen::Dynamic, // Fully-fixed size
+ dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
+
+ template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
+ static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
+ outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
+ vector ? size : row_major ? cols : rows>::value;
+ static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
+ static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
+ static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
+
+ // Takes an input array and determines whether we can make it fit into the Eigen type. If
+ // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
+ // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
+ static EigenConformable<row_major> conformable(const array &a) {
+ const auto dims = a.ndim();
+ if (dims < 1 || dims > 2)
+ return false;
+
+ if (dims == 2) { // Matrix type: require exact match (or dynamic)
+
+ EigenIndex
+ np_rows = a.shape(0),
+ np_cols = a.shape(1),
+ np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
+ np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
+ if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
+ return false;
+
+ return {np_rows, np_cols, np_rstride, np_cstride};
+ }
+
+ // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
+ // is used, we want the (single) numpy stride value.
+ const EigenIndex n = a.shape(0),
+ stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
+
+ if (vector) { // Eigen type is a compile-time vector
+ if (fixed && size != n)
+ return false; // Vector size mismatch
+ return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
+ }
+ else if (fixed) {
+ // The type has a fixed size, but is not a vector: abort
+ return false;
+ }
+ else if (fixed_cols) {
+ // Since this isn't a vector, cols must be != 1. We allow this only if it exactly
+ // equals the number of elements (rows is Dynamic, and so 1 row is allowed).
+ if (cols != n) return false;
+ return {1, n, stride};
+ }
+ else {
+ // Otherwise it's either fully dynamic, or column dynamic; both become a column vector
+ if (fixed_rows && rows != n) return false;
+ return {n, 1, stride};
+ }
+ }
+
+ static PYBIND11_DESCR descriptor() {
+ constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
+ constexpr bool show_order = is_eigen_dense_map<Type>::value;
+ constexpr bool show_c_contiguous = show_order && requires_row_major;
+ constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
+
+ return type_descr(_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
+ _("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
+ _(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
+ _("]") +
+ // For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
+ // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
+ // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
+ // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
+ // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
+ // *gave* a numpy.ndarray of the right type and dimensions.
+ _<show_writeable>(", flags.writeable", "") +
+ _<show_c_contiguous>(", flags.c_contiguous", "") +
+ _<show_f_contiguous>(", flags.f_contiguous", "") +
+ _("]")
+ );
+ }
+};
+
+// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
+// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
+template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
+ constexpr ssize_t elem_size = sizeof(typename props::Scalar);
+ array a;
+ if (props::vector)
+ a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
+ else
+ a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
+ src.data(), base);
+
+ if (!writeable)
+ array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
+
+ return a.release();
+}
+
+// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
+// reference the Eigen object's data with `base` as the python-registered base class (if omitted,
+// the base will be set to None, and lifetime management is up to the caller). The numpy array is
+// non-writeable if the given type is const.
+template <typename props, typename Type>
+handle eigen_ref_array(Type &src, handle parent = none()) {
+ // none here is to get past array's should-we-copy detection, which currently always
+ // copies when there is no base. Setting the base to None should be harmless.
+ return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
+}
+
+// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
+// array that references the encapsulated data with a python-side reference to the capsule to tie
+// its destruction to that of any dependent python objects. Const-ness is determined by whether or
+// not the Type of the pointer given is const.
+template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
+handle eigen_encapsulate(Type *src) {
+ capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
+ return eigen_ref_array<props>(*src, base);
+}
+
+// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
+// types.
+template<typename Type>
+struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
+ using Scalar = typename Type::Scalar;
+ using props = EigenProps<Type>;
+
+ bool load(handle src, bool convert) {
+ // If we're in no-convert mode, only load if given an array of the correct type
+ if (!convert && !isinstance<array_t<Scalar>>(src))
+ return false;
+
+ // Coerce into an array, but don't do type conversion yet; the copy below handles it.
+ auto buf = array::ensure(src);
+
+ if (!buf)
+ return false;
+
+ auto dims = buf.ndim();
+ if (dims < 1 || dims > 2)
+ return false;
+
+ auto fits = props::conformable(buf);
+ if (!fits)
+ return false;
+
+ // Allocate the new type, then build a numpy reference into it
+ value = Type(fits.rows, fits.cols);
+ auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
+ if (dims == 1) ref = ref.squeeze();
+ else if (ref.ndim() == 1) buf = buf.squeeze();
+
+ int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
+
+ if (result < 0) { // Copy failed!
+ PyErr_Clear();
+ return false;
+ }
+
+ return true;
+ }
+
+private:
+
+ // Cast implementation
+ template <typename CType>
+ static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
+ switch (policy) {
+ case return_value_policy::take_ownership:
+ case return_value_policy::automatic:
+ return eigen_encapsulate<props>(src);
+ case return_value_policy::move:
+ return eigen_encapsulate<props>(new CType(std::move(*src)));
+ case return_value_policy::copy:
+ return eigen_array_cast<props>(*src);
+ case return_value_policy::reference:
+ case return_value_policy::automatic_reference:
+ return eigen_ref_array<props>(*src);
+ case return_value_policy::reference_internal:
+ return eigen_ref_array<props>(*src, parent);
+ default:
+ throw cast_error("unhandled return_value_policy: should not happen!");
+ };
+ }
+
+public:
+
+ // Normal returned non-reference, non-const value:
+ static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
+ return cast_impl(&src, return_value_policy::move, parent);
+ }
+ // If you return a non-reference const, we mark the numpy array readonly:
+ static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
+ return cast_impl(&src, return_value_policy::move, parent);
+ }
+ // lvalue reference return; default (automatic) becomes copy
+ static handle cast(Type &src, return_value_policy policy, handle parent) {
+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
+ policy = return_value_policy::copy;
+ return cast_impl(&src, policy, parent);
+ }
+ // const lvalue reference return; default (automatic) becomes copy
+ static handle cast(const Type &src, return_value_policy policy, handle parent) {
+ if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
+ policy = return_value_policy::copy;
+ return cast(&src, policy, parent);
+ }
+ // non-const pointer return
+ static handle cast(Type *src, return_value_policy policy, handle parent) {
+ return cast_impl(src, policy, parent);
+ }
+ // const pointer return
+ static handle cast(const Type *src, return_value_policy policy, handle parent) {
+ return cast_impl(src, policy, parent);
+ }
+
+ static PYBIND11_DESCR name() { return props::descriptor(); }
+
+ operator Type*() { return &value; }
+ operator Type&() { return value; }
+ operator Type&&() && { return std::move(value); }
+ template <typename T> using cast_op_type = movable_cast_op_type<T>;
+
+private:
+ Type value;
+};
+
+// Eigen Ref/Map classes have slightly different policy requirements, meaning we don't want to force
+// `move` when a Ref/Map rvalue is returned; we treat Ref<> sort of like a pointer (we care about
+// the underlying data, not the outer shell).
+template <typename Return>
+struct return_value_policy_override<Return, enable_if_t<is_eigen_dense_map<Return>::value>> {
+ static return_value_policy policy(return_value_policy p) { return p; }
+};
+
+// Base class for casting reference/map/block/etc. objects back to python.
+template <typename MapType> struct eigen_map_caster {
+private:
+ using props = EigenProps<MapType>;
+
+public:
+
+ // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
+ // to stay around), but we'll allow it under the assumption that you know what you're doing (and
+ // have an appropriate keep_alive in place). We return a numpy array pointing directly at the
+ // ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note
+ // that this means you need to ensure you don't destroy the object in some other way (e.g. with
+ // an appropriate keep_alive, or with a reference to a statically allocated matrix).
+ static handle cast(const MapType &src, return_value_policy policy, handle parent) {
+ switch (policy) {
+ case return_value_policy::copy:
+ return eigen_array_cast<props>(src);
+ case return_value_policy::reference_internal:
+ return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
+ case return_value_policy::reference:
+ case return_value_policy::automatic:
+ case return_value_policy::automatic_reference:
+ return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
+ default:
+ // move, take_ownership don't make any sense for a ref/map:
+ pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
+ }
+ }
+
+ static PYBIND11_DESCR name() { return props::descriptor(); }
+
+ // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
+ // types but not bound arguments). We still provide them (with an explicitly delete) so that
+ // you end up here if you try anyway.
+ bool load(handle, bool) = delete;
+ operator MapType() = delete;
+ template <typename> using cast_op_type = MapType;
+};
+
+// We can return any map-like object (but can only load Refs, specialized next):
+template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
+ : eigen_map_caster<Type> {};
+
+// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
+// copying (it requires some extra effort in many cases).
+template <typename PlainObjectType, typename StrideType>
+struct type_caster<
+ Eigen::Ref<PlainObjectType, 0, StrideType>,
+ enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
+> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
+private:
+ using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
+ using props = EigenProps<Type>;
+ using Scalar = typename props::Scalar;
+ using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
+ using Array = array_t<Scalar, array::forcecast |
+ ((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
+ (props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
+ static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
+ // Delay construction (these have no default constructor)
+ std::unique_ptr<MapType> map;
+ std::unique_ptr<Type> ref;
+ // Our array. When possible, this is just a numpy array pointing to the source data, but
+ // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
+ // layout, or is an array of a type that needs to be converted). Using a numpy temporary
+ // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
+ // storage order conversion. (Note that we refuse to use this temporary copy when loading an
+ // argument for a Ref<M> with M non-const, i.e. a read-write reference).
+ Array copy_or_ref;
+public:
+ bool load(handle src, bool convert) {
+ // First check whether what we have is already an array of the right type. If not, we can't
+ // avoid a copy (because the copy is also going to do type conversion).
+ bool need_copy = !isinstance<Array>(src);
+
+ EigenConformable<props::row_major> fits;
+ if (!need_copy) {
+ // We don't need a converting copy, but we also need to check whether the strides are
+ // compatible with the Ref's stride requirements
+ Array aref = reinterpret_borrow<Array>(src);
+
+ if (aref && (!need_writeable || aref.writeable())) {
+ fits = props::conformable(aref);
+ if (!fits) return false; // Incompatible dimensions
+ if (!fits.template stride_compatible<props>())
+ need_copy = true;
+ else
+ copy_or_ref = std::move(aref);
+ }
+ else {
+ need_copy = true;
+ }
+ }
+
+ if (need_copy) {
+ // We need to copy: If we need a mutable reference, or we're not supposed to convert
+ // (either because we're in the no-convert overload pass, or because we're explicitly
+ // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
+ if (!convert || need_writeable) return false;
+
+ Array copy = Array::ensure(src);
+ if (!copy) return false;
+ fits = props::conformable(copy);
+ if (!fits || !fits.template stride_compatible<props>())
+ return false;
+ copy_or_ref = std::move(copy);
+ loader_life_support::add_patient(copy_or_ref);
+ }
+
+ ref.reset();
+ map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
+ ref.reset(new Type(*map));
+
+ return true;
+ }
+
+ operator Type*() { return ref.get(); }
+ operator Type&() { return *ref; }
+ template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+
+private:
+ template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
+ Scalar *data(Array &a) { return a.mutable_data(); }
+
+ template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
+ const Scalar *data(Array &a) { return a.data(); }
+
+ // Attempt to figure out a constructor of `Stride` that will work.
+ // If both strides are fixed, use a default constructor:
+ template <typename S> using stride_ctor_default = bool_constant<
+ S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
+ std::is_default_constructible<S>::value>;
+ // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
+ // Eigen::Stride, and use it:
+ template <typename S> using stride_ctor_dual = bool_constant<
+ !stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
+ // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
+ // it (passing whichever stride is dynamic).
+ template <typename S> using stride_ctor_outer = bool_constant<
+ !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
+ S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
+ std::is_constructible<S, EigenIndex>::value>;
+ template <typename S> using stride_ctor_inner = bool_constant<
+ !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
+ S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
+ std::is_constructible<S, EigenIndex>::value>;
+
+ template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
+ static S make_stride(EigenIndex, EigenIndex) { return S(); }
+ template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
+ static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
+ template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
+ static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
+ template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
+ static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
+
+};
+
+// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
+// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
+// load() is not supported, but we can cast them into the python domain by first copying to a
+// regular Eigen::Matrix, then casting that.
+template <typename Type>
+struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
+protected:
+ using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
+ using props = EigenProps<Matrix>;
+public:
+ static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
+ handle h = eigen_encapsulate<props>(new Matrix(src));
+ return h;
+ }
+ static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
+
+ static PYBIND11_DESCR name() { return props::descriptor(); }
+
+ // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
+ // types but not bound arguments). We still provide them (with an explicitly delete) so that
+ // you end up here if you try anyway.
+ bool load(handle, bool) = delete;
+ operator Type() = delete;
+ template <typename> using cast_op_type = Type;
+};
+
+template<typename Type>
+struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
+ typedef typename Type::Scalar Scalar;
+ typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
+ typedef typename Type::Index Index;
+ static constexpr bool rowMajor = Type::IsRowMajor;
+
+ bool load(handle src, bool) {
+ if (!src)
+ return false;
+
+ auto obj = reinterpret_borrow<object>(src);
+ object sparse_module = module::import("scipy.sparse");
+ object matrix_type = sparse_module.attr(
+ rowMajor ? "csr_matrix" : "csc_matrix");
+
+ if (!obj.get_type().is(matrix_type)) {
+ try {
+ obj = matrix_type(obj);
+ } catch (const error_already_set &) {
+ return false;
+ }
+ }
+
+ auto values = array_t<Scalar>((object) obj.attr("data"));
+ auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
+ auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
+ auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
+ auto nnz = obj.attr("nnz").cast<Index>();
+
+ if (!values || !innerIndices || !outerIndices)
+ return false;
+
+ value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
+ shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
+ outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
+
+ return true;
+ }
+
+ static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
+ const_cast<Type&>(src).makeCompressed();
+
+ object matrix_type = module::import("scipy.sparse").attr(
+ rowMajor ? "csr_matrix" : "csc_matrix");
+
+ array data(src.nonZeros(), src.valuePtr());
+ array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
+ array innerIndices(src.nonZeros(), src.innerIndexPtr());
+
+ return matrix_type(
+ std::make_tuple(data, innerIndices, outerIndices),
+ std::make_pair(src.rows(), src.cols())
+ ).release();
+ }
+
+ PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
+ + npy_format_descriptor<Scalar>::name() + _("]"));
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
+
+#if defined(__GNUG__) || defined(__clang__)
+# pragma GCC diagnostic pop
+#elif defined(_MSC_VER)
+# pragma warning(pop)
+#endif
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h b/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h
new file mode 100644
index 000000000..9abc61c34
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h
@@ -0,0 +1,194 @@
+/*
+ pybind11/embed.h: Support for embedding the interpreter
+
+ Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include "eval.h"
+
+#if defined(PYPY_VERSION)
+# error Embedding the interpreter is not supported with PyPy
+#endif
+
+#if PY_MAJOR_VERSION >= 3
+# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
+ extern "C" PyObject *pybind11_init_impl_##name() { \
+ return pybind11_init_wrapper_##name(); \
+ }
+#else
+# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
+ extern "C" void pybind11_init_impl_##name() { \
+ pybind11_init_wrapper_##name(); \
+ }
+#endif
+
+/** \rst
+ Add a new module to the table of builtins for the interpreter. Must be
+ defined in global scope. The first macro parameter is the name of the
+ module (without quotes). The second parameter is the variable which will
+ be used as the interface to add functions and classes to the module.
+
+ .. code-block:: cpp
+
+ PYBIND11_EMBEDDED_MODULE(example, m) {
+ // ... initialize functions and classes here
+ m.def("foo", []() {
+ return "Hello, World!";
+ });
+ }
+ \endrst */
+#define PYBIND11_EMBEDDED_MODULE(name, variable) \
+ static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
+ static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
+ auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
+ try { \
+ PYBIND11_CONCAT(pybind11_init_, name)(m); \
+ return m.ptr(); \
+ } catch (pybind11::error_already_set &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } catch (const std::exception &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } \
+ } \
+ PYBIND11_EMBEDDED_MODULE_IMPL(name) \
+ pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
+ PYBIND11_CONCAT(pybind11_init_impl_, name)); \
+ void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
+
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
+struct embedded_module {
+#if PY_MAJOR_VERSION >= 3
+ using init_t = PyObject *(*)();
+#else
+ using init_t = void (*)();
+#endif
+ embedded_module(const char *name, init_t init) {
+ if (Py_IsInitialized())
+ pybind11_fail("Can't add new modules after the interpreter has been initialized");
+
+ auto result = PyImport_AppendInittab(name, init);
+ if (result == -1)
+ pybind11_fail("Insufficient memory to add a new module");
+ }
+};
+
+NAMESPACE_END(detail)
+
+/** \rst
+ Initialize the Python interpreter. No other pybind11 or CPython API functions can be
+ called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
+ optional parameter can be used to skip the registration of signal handlers (see the
+ Python documentation for details). Calling this function again after the interpreter
+ has already been initialized is a fatal error.
+ \endrst */
+inline void initialize_interpreter(bool init_signal_handlers = true) {
+ if (Py_IsInitialized())
+ pybind11_fail("The interpreter is already running");
+
+ Py_InitializeEx(init_signal_handlers ? 1 : 0);
+
+ // Make .py files in the working directory available by default
+ module::import("sys").attr("path").cast<list>().append(".");
+}
+
+/** \rst
+ Shut down the Python interpreter. No pybind11 or CPython API functions can be called
+ after this. In addition, pybind11 objects must not outlive the interpreter:
+
+ .. code-block:: cpp
+
+ { // BAD
+ py::initialize_interpreter();
+ auto hello = py::str("Hello, World!");
+ py::finalize_interpreter();
+ } // <-- BOOM, hello's destructor is called after interpreter shutdown
+
+ { // GOOD
+ py::initialize_interpreter();
+ { // scoped
+ auto hello = py::str("Hello, World!");
+ } // <-- OK, hello is cleaned up properly
+ py::finalize_interpreter();
+ }
+
+ { // BETTER
+ py::scoped_interpreter guard{};
+ auto hello = py::str("Hello, World!");
+ }
+
+ .. warning::
+
+ The interpreter can be restarted by calling `initialize_interpreter` again.
+ Modules created using pybind11 can be safely re-initialized. However, Python
+ itself cannot completely unload binary extension modules and there are several
+ caveats with regard to interpreter restarting. All the details can be found
+ in the CPython documentation. In short, not all interpreter memory may be
+ freed, either due to reference cycles or user-created global data.
+
+ \endrst */
+inline void finalize_interpreter() {
+ handle builtins(PyEval_GetBuiltins());
+ const char *id = PYBIND11_INTERNALS_ID;
+
+ // Get the internals pointer (without creating it if it doesn't exist). It's possible for the
+ // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
+ // during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
+ detail::internals **internals_ptr_ptr = detail::get_internals_pp();
+ // It could also be stashed in builtins, so look there too:
+ if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
+ internals_ptr_ptr = capsule(builtins[id]);
+
+ Py_Finalize();
+
+ if (internals_ptr_ptr) {
+ delete *internals_ptr_ptr;
+ *internals_ptr_ptr = nullptr;
+ }
+}
+
+/** \rst
+ Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
+ This a move-only guard and only a single instance can exist.
+
+ .. code-block:: cpp
+
+ #include <pybind11/embed.h>
+
+ int main() {
+ py::scoped_interpreter guard{};
+ py::print(Hello, World!);
+ } // <-- interpreter shutdown
+ \endrst */
+class scoped_interpreter {
+public:
+ scoped_interpreter(bool init_signal_handlers = true) {
+ initialize_interpreter(init_signal_handlers);
+ }
+
+ scoped_interpreter(const scoped_interpreter &) = delete;
+ scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
+ scoped_interpreter &operator=(const scoped_interpreter &) = delete;
+ scoped_interpreter &operator=(scoped_interpreter &&) = delete;
+
+ ~scoped_interpreter() {
+ if (is_valid)
+ finalize_interpreter();
+ }
+
+private:
+ bool is_valid = true;
+};
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h b/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h
new file mode 100644
index 000000000..ea85ba1db
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h
@@ -0,0 +1,117 @@
+/*
+ pybind11/exec.h: Support for evaluating Python expressions and statements
+ from strings and files
+
+ Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
+ Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+enum eval_mode {
+ /// Evaluate a string containing an isolated expression
+ eval_expr,
+
+ /// Evaluate a string containing a single statement. Returns \c none
+ eval_single_statement,
+
+ /// Evaluate a string containing a sequence of statement. Returns \c none
+ eval_statements
+};
+
+template <eval_mode mode = eval_expr>
+object eval(str expr, object global = globals(), object local = object()) {
+ if (!local)
+ local = global;
+
+ /* PyRun_String does not accept a PyObject / encoding specifier,
+ this seems to be the only alternative */
+ std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
+
+ int start;
+ switch (mode) {
+ case eval_expr: start = Py_eval_input; break;
+ case eval_single_statement: start = Py_single_input; break;
+ case eval_statements: start = Py_file_input; break;
+ default: pybind11_fail("invalid evaluation mode");
+ }
+
+ PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
+ if (!result)
+ throw error_already_set();
+ return reinterpret_steal<object>(result);
+}
+
+template <eval_mode mode = eval_expr, size_t N>
+object eval(const char (&s)[N], object global = globals(), object local = object()) {
+ /* Support raw string literals by removing common leading whitespace */
+ auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
+ : str(s);
+ return eval<mode>(expr, global, local);
+}
+
+inline void exec(str expr, object global = globals(), object local = object()) {
+ eval<eval_statements>(expr, global, local);
+}
+
+template <size_t N>
+void exec(const char (&s)[N], object global = globals(), object local = object()) {
+ eval<eval_statements>(s, global, local);
+}
+
+template <eval_mode mode = eval_statements>
+object eval_file(str fname, object global = globals(), object local = object()) {
+ if (!local)
+ local = global;
+
+ int start;
+ switch (mode) {
+ case eval_expr: start = Py_eval_input; break;
+ case eval_single_statement: start = Py_single_input; break;
+ case eval_statements: start = Py_file_input; break;
+ default: pybind11_fail("invalid evaluation mode");
+ }
+
+ int closeFile = 1;
+ std::string fname_str = (std::string) fname;
+#if PY_VERSION_HEX >= 0x03040000
+ FILE *f = _Py_fopen_obj(fname.ptr(), "r");
+#elif PY_VERSION_HEX >= 0x03000000
+ FILE *f = _Py_fopen(fname.ptr(), "r");
+#else
+ /* No unicode support in open() :( */
+ auto fobj = reinterpret_steal<object>(PyFile_FromString(
+ const_cast<char *>(fname_str.c_str()),
+ const_cast<char*>("r")));
+ FILE *f = nullptr;
+ if (fobj)
+ f = PyFile_AsFile(fobj.ptr());
+ closeFile = 0;
+#endif
+ if (!f) {
+ PyErr_Clear();
+ pybind11_fail("File \"" + fname_str + "\" could not be opened!");
+ }
+
+#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
+ PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
+ local.ptr());
+ (void) closeFile;
+#else
+ PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
+ local.ptr(), closeFile);
+#endif
+
+ if (!result)
+ throw error_already_set();
+ return reinterpret_steal<object>(result);
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h b/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h
new file mode 100644
index 000000000..eda14ba58
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h
@@ -0,0 +1,85 @@
+/*
+ pybind11/functional.h: std::function<> support
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <functional>
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+template <typename Return, typename... Args>
+struct type_caster<std::function<Return(Args...)>> {
+ using type = std::function<Return(Args...)>;
+ using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
+ using function_type = Return (*) (Args...);
+
+public:
+ bool load(handle src, bool convert) {
+ if (src.is_none()) {
+ // Defer accepting None to other overloads (if we aren't in convert mode):
+ if (!convert) return false;
+ return true;
+ }
+
+ if (!isinstance<function>(src))
+ return false;
+
+ auto func = reinterpret_borrow<function>(src);
+
+ /*
+ When passing a C++ function as an argument to another C++
+ function via Python, every function call would normally involve
+ a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
+ Here, we try to at least detect the case where the function is
+ stateless (i.e. function pointer or lambda function without
+ captured variables), in which case the roundtrip can be avoided.
+ */
+ if (auto cfunc = func.cpp_function()) {
+ auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
+ auto rec = (function_record *) c;
+
+ if (rec && rec->is_stateless &&
+ same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
+ struct capture { function_type f; };
+ value = ((capture *) &rec->data)->f;
+ return true;
+ }
+ }
+
+ value = [func](Args... args) -> Return {
+ gil_scoped_acquire acq;
+ object retval(func(std::forward<Args>(args)...));
+ /* Visual studio 2015 parser issue: need parentheses around this expression */
+ return (retval.template cast<Return>());
+ };
+ return true;
+ }
+
+ template <typename Func>
+ static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
+ if (!f_)
+ return none().inc_ref();
+
+ auto result = f_.template target<function_type>();
+ if (result)
+ return cpp_function(*result, policy).release();
+ else
+ return cpp_function(std::forward<Func>(f_), policy).release();
+ }
+
+ PYBIND11_TYPE_CASTER(type, _("Callable[[") +
+ argument_loader<Args...>::arg_names() + _("], ") +
+ make_caster<retval_type>::name() +
+ _("]"));
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h b/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h
new file mode 100644
index 000000000..a9c27aac1
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h
@@ -0,0 +1,200 @@
+/*
+ pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
+
+ Copyright (c) 2017 Henry F. Schreiner
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+
+#include <streambuf>
+#include <ostream>
+#include <string>
+#include <memory>
+#include <iostream>
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+// Buffer that writes to Python instead of C++
+class pythonbuf : public std::streambuf {
+private:
+ using traits_type = std::streambuf::traits_type;
+
+ char d_buffer[1024];
+ object pywrite;
+ object pyflush;
+
+ int overflow(int c) {
+ if (!traits_type::eq_int_type(c, traits_type::eof())) {
+ *pptr() = traits_type::to_char_type(c);
+ pbump(1);
+ }
+ return sync() ? traits_type::not_eof(c) : traits_type::eof();
+ }
+
+ int sync() {
+ if (pbase() != pptr()) {
+ // This subtraction cannot be negative, so dropping the sign
+ str line(pbase(), static_cast<size_t>(pptr() - pbase()));
+
+ pywrite(line);
+ pyflush();
+
+ setp(pbase(), epptr());
+ }
+ return 0;
+ }
+
+public:
+ pythonbuf(object pyostream)
+ : pywrite(pyostream.attr("write")),
+ pyflush(pyostream.attr("flush")) {
+ setp(d_buffer, d_buffer + sizeof(d_buffer) - 1);
+ }
+
+ /// Sync before destroy
+ ~pythonbuf() {
+ sync();
+ }
+};
+
+NAMESPACE_END(detail)
+
+
+/** \rst
+ This a move-only guard that redirects output.
+
+ .. code-block:: cpp
+
+ #include <pybind11/iostream.h>
+
+ ...
+
+ {
+ py::scoped_ostream_redirect output;
+ std::cout << "Hello, World!"; // Python stdout
+ } // <-- return std::cout to normal
+
+ You can explicitly pass the c++ stream and the python object,
+ for example to guard stderr instead.
+
+ .. code-block:: cpp
+
+ {
+ py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
+ std::cerr << "Hello, World!";
+ }
+ \endrst */
+class scoped_ostream_redirect {
+protected:
+ std::streambuf *old;
+ std::ostream &costream;
+ detail::pythonbuf buffer;
+
+public:
+ scoped_ostream_redirect(
+ std::ostream &costream = std::cout,
+ object pyostream = module::import("sys").attr("stdout"))
+ : costream(costream), buffer(pyostream) {
+ old = costream.rdbuf(&buffer);
+ }
+
+ ~scoped_ostream_redirect() {
+ costream.rdbuf(old);
+ }
+
+ scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
+ scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
+ scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
+ scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
+};
+
+
+/** \rst
+ Like `scoped_ostream_redirect`, but redirects cerr by default. This class
+ is provided primary to make ``py::call_guard`` easier to make.
+
+ .. code-block:: cpp
+
+ m.def("noisy_func", &noisy_func,
+ py::call_guard<scoped_ostream_redirect,
+ scoped_estream_redirect>());
+
+\endrst */
+class scoped_estream_redirect : public scoped_ostream_redirect {
+public:
+ scoped_estream_redirect(
+ std::ostream &costream = std::cerr,
+ object pyostream = module::import("sys").attr("stderr"))
+ : scoped_ostream_redirect(costream,pyostream) {}
+};
+
+
+NAMESPACE_BEGIN(detail)
+
+// Class to redirect output as a context manager. C++ backend.
+class OstreamRedirect {
+ bool do_stdout_;
+ bool do_stderr_;
+ std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
+ std::unique_ptr<scoped_estream_redirect> redirect_stderr;
+
+public:
+ OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
+ : do_stdout_(do_stdout), do_stderr_(do_stderr) {}
+
+ void enter() {
+ if (do_stdout_)
+ redirect_stdout.reset(new scoped_ostream_redirect());
+ if (do_stderr_)
+ redirect_stderr.reset(new scoped_estream_redirect());
+ }
+
+ void exit() {
+ redirect_stdout.reset();
+ redirect_stderr.reset();
+ }
+};
+
+NAMESPACE_END(detail)
+
+/** \rst
+ This is a helper function to add a C++ redirect context manager to Python
+ instead of using a C++ guard. To use it, add the following to your binding code:
+
+ .. code-block:: cpp
+
+ #include <pybind11/iostream.h>
+
+ ...
+
+ py::add_ostream_redirect(m, "ostream_redirect");
+
+ You now have a Python context manager that redirects your output:
+
+ .. code-block:: python
+
+ with m.ostream_redirect():
+ m.print_to_cout_function()
+
+ This manager can optionally be told which streams to operate on:
+
+ .. code-block:: python
+
+ with m.ostream_redirect(stdout=true, stderr=true):
+ m.noisy_function_with_error_printing()
+
+ \endrst */
+inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
+ return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
+ .def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
+ .def("__enter__", &detail::OstreamRedirect::enter)
+ .def("__exit__", [](detail::OstreamRedirect &self, args) { self.exit(); });
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h b/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h
new file mode 100644
index 000000000..b1600dc2e
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h
@@ -0,0 +1,1600 @@
+/*
+ pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include "complex.h"
+#include <numeric>
+#include <algorithm>
+#include <array>
+#include <cstdlib>
+#include <cstring>
+#include <sstream>
+#include <string>
+#include <initializer_list>
+#include <functional>
+#include <utility>
+#include <typeindex>
+
+#if defined(_MSC_VER)
+# pragma warning(push)
+# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+#endif
+
+/* This will be true on all flat address space platforms and allows us to reduce the
+ whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
+ and dimension types (e.g. shape, strides, indexing), instead of inflicting this
+ upon the library user. */
+static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+class array; // Forward declaration
+
+NAMESPACE_BEGIN(detail)
+template <typename type, typename SFINAE = void> struct npy_format_descriptor;
+
+struct PyArrayDescr_Proxy {
+ PyObject_HEAD
+ PyObject *typeobj;
+ char kind;
+ char type;
+ char byteorder;
+ char flags;
+ int type_num;
+ int elsize;
+ int alignment;
+ char *subarray;
+ PyObject *fields;
+ PyObject *names;
+};
+
+struct PyArray_Proxy {
+ PyObject_HEAD
+ char *data;
+ int nd;
+ ssize_t *dimensions;
+ ssize_t *strides;
+ PyObject *base;
+ PyObject *descr;
+ int flags;
+};
+
+struct PyVoidScalarObject_Proxy {
+ PyObject_VAR_HEAD
+ char *obval;
+ PyArrayDescr_Proxy *descr;
+ int flags;
+ PyObject *base;
+};
+
+struct numpy_type_info {
+ PyObject* dtype_ptr;
+ std::string format_str;
+};
+
+struct numpy_internals {
+ std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
+
+ numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
+ auto it = registered_dtypes.find(std::type_index(tinfo));
+ if (it != registered_dtypes.end())
+ return &(it->second);
+ if (throw_if_missing)
+ pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
+ return nullptr;
+ }
+
+ template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
+ return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
+ }
+};
+
+inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
+ ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
+}
+
+inline numpy_internals& get_numpy_internals() {
+ static numpy_internals* ptr = nullptr;
+ if (!ptr)
+ load_numpy_internals(ptr);
+ return *ptr;
+}
+
+struct npy_api {
+ enum constants {
+ NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
+ NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
+ NPY_ARRAY_OWNDATA_ = 0x0004,
+ NPY_ARRAY_FORCECAST_ = 0x0010,
+ NPY_ARRAY_ENSUREARRAY_ = 0x0040,
+ NPY_ARRAY_ALIGNED_ = 0x0100,
+ NPY_ARRAY_WRITEABLE_ = 0x0400,
+ NPY_BOOL_ = 0,
+ NPY_BYTE_, NPY_UBYTE_,
+ NPY_SHORT_, NPY_USHORT_,
+ NPY_INT_, NPY_UINT_,
+ NPY_LONG_, NPY_ULONG_,
+ NPY_LONGLONG_, NPY_ULONGLONG_,
+ NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
+ NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
+ NPY_OBJECT_ = 17,
+ NPY_STRING_, NPY_UNICODE_, NPY_VOID_
+ };
+
+ typedef struct {
+ Py_intptr_t *ptr;
+ int len;
+ } PyArray_Dims;
+
+ static npy_api& get() {
+ static npy_api api = lookup();
+ return api;
+ }
+
+ bool PyArray_Check_(PyObject *obj) const {
+ return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
+ }
+ bool PyArrayDescr_Check_(PyObject *obj) const {
+ return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
+ }
+
+ unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
+ PyObject *(*PyArray_DescrFromType_)(int);
+ PyObject *(*PyArray_NewFromDescr_)
+ (PyTypeObject *, PyObject *, int, Py_intptr_t *,
+ Py_intptr_t *, void *, int, PyObject *);
+ PyObject *(*PyArray_DescrNewFromType_)(int);
+ int (*PyArray_CopyInto_)(PyObject *, PyObject *);
+ PyObject *(*PyArray_NewCopy_)(PyObject *, int);
+ PyTypeObject *PyArray_Type_;
+ PyTypeObject *PyVoidArrType_Type_;
+ PyTypeObject *PyArrayDescr_Type_;
+ PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
+ PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
+ int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
+ bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
+ int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
+ Py_ssize_t *, PyObject **, PyObject *);
+ PyObject *(*PyArray_Squeeze_)(PyObject *);
+ int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
+ PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
+private:
+ enum functions {
+ API_PyArray_GetNDArrayCFeatureVersion = 211,
+ API_PyArray_Type = 2,
+ API_PyArrayDescr_Type = 3,
+ API_PyVoidArrType_Type = 39,
+ API_PyArray_DescrFromType = 45,
+ API_PyArray_DescrFromScalar = 57,
+ API_PyArray_FromAny = 69,
+ API_PyArray_Resize = 80,
+ API_PyArray_CopyInto = 82,
+ API_PyArray_NewCopy = 85,
+ API_PyArray_NewFromDescr = 94,
+ API_PyArray_DescrNewFromType = 9,
+ API_PyArray_DescrConverter = 174,
+ API_PyArray_EquivTypes = 182,
+ API_PyArray_GetArrayParamsFromObject = 278,
+ API_PyArray_Squeeze = 136,
+ API_PyArray_SetBaseObject = 282
+ };
+
+ static npy_api lookup() {
+ module m = module::import("numpy.core.multiarray");
+ auto c = m.attr("_ARRAY_API");
+#if PY_MAJOR_VERSION >= 3
+ void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
+#else
+ void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
+#endif
+ npy_api api;
+#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
+ DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
+ if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
+ pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
+ DECL_NPY_API(PyArray_Type);
+ DECL_NPY_API(PyVoidArrType_Type);
+ DECL_NPY_API(PyArrayDescr_Type);
+ DECL_NPY_API(PyArray_DescrFromType);
+ DECL_NPY_API(PyArray_DescrFromScalar);
+ DECL_NPY_API(PyArray_FromAny);
+ DECL_NPY_API(PyArray_Resize);
+ DECL_NPY_API(PyArray_CopyInto);
+ DECL_NPY_API(PyArray_NewCopy);
+ DECL_NPY_API(PyArray_NewFromDescr);
+ DECL_NPY_API(PyArray_DescrNewFromType);
+ DECL_NPY_API(PyArray_DescrConverter);
+ DECL_NPY_API(PyArray_EquivTypes);
+ DECL_NPY_API(PyArray_GetArrayParamsFromObject);
+ DECL_NPY_API(PyArray_Squeeze);
+ DECL_NPY_API(PyArray_SetBaseObject);
+#undef DECL_NPY_API
+ return api;
+ }
+};
+
+inline PyArray_Proxy* array_proxy(void* ptr) {
+ return reinterpret_cast<PyArray_Proxy*>(ptr);
+}
+
+inline const PyArray_Proxy* array_proxy(const void* ptr) {
+ return reinterpret_cast<const PyArray_Proxy*>(ptr);
+}
+
+inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
+ return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
+}
+
+inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
+ return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
+}
+
+inline bool check_flags(const void* ptr, int flag) {
+ return (flag == (array_proxy(ptr)->flags & flag));
+}
+
+template <typename T> struct is_std_array : std::false_type { };
+template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
+template <typename T> struct is_complex : std::false_type { };
+template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
+
+template <typename T> struct array_info_scalar {
+ typedef T type;
+ static constexpr bool is_array = false;
+ static constexpr bool is_empty = false;
+ static PYBIND11_DESCR extents() { return _(""); }
+ static void append_extents(list& /* shape */) { }
+};
+// Computes underlying type and a comma-separated list of extents for array
+// types (any mix of std::array and built-in arrays). An array of char is
+// treated as scalar because it gets special handling.
+template <typename T> struct array_info : array_info_scalar<T> { };
+template <typename T, size_t N> struct array_info<std::array<T, N>> {
+ using type = typename array_info<T>::type;
+ static constexpr bool is_array = true;
+ static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
+ static constexpr size_t extent = N;
+
+ // appends the extents to shape
+ static void append_extents(list& shape) {
+ shape.append(N);
+ array_info<T>::append_extents(shape);
+ }
+
+ template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
+ static PYBIND11_DESCR extents() {
+ return _<N>();
+ }
+
+ template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
+ static PYBIND11_DESCR extents() {
+ return concat(_<N>(), array_info<T>::extents());
+ }
+};
+// For numpy we have special handling for arrays of characters, so we don't include
+// the size in the array extents.
+template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
+template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
+template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
+template <typename T> using remove_all_extents_t = typename array_info<T>::type;
+
+template <typename T> using is_pod_struct = all_of<
+ std::is_standard_layout<T>, // since we're accessing directly in memory we need a standard layout type
+#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI)
+ // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent
+ // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4).
+ std::is_trivially_copyable<T>,
+#else
+ // GCC 4 doesn't implement is_trivially_copyable, so approximate it
+ std::is_trivially_destructible<T>,
+ satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
+#endif
+ satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
+>;
+
+template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
+template <ssize_t Dim = 0, typename Strides, typename... Ix>
+ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
+ return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
+}
+
+/**
+ * Proxy class providing unsafe, unchecked const access to array data. This is constructed through
+ * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
+ * will be -1 for dimensions determined at runtime.
+ */
+template <typename T, ssize_t Dims>
+class unchecked_reference {
+protected:
+ static constexpr bool Dynamic = Dims < 0;
+ const unsigned char *data_;
+ // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
+ // make large performance gains on big, nested loops, but requires compile-time dimensions
+ conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
+ shape_, strides_;
+ const ssize_t dims_;
+
+ friend class pybind11::array;
+ // Constructor for compile-time dimensions:
+ template <bool Dyn = Dynamic>
+ unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
+ : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
+ for (size_t i = 0; i < (size_t) dims_; i++) {
+ shape_[i] = shape[i];
+ strides_[i] = strides[i];
+ }
+ }
+ // Constructor for runtime dimensions:
+ template <bool Dyn = Dynamic>
+ unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
+ : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
+
+public:
+ /**
+ * Unchecked const reference access to data at the given indices. For a compile-time known
+ * number of dimensions, this requires the correct number of arguments; for run-time
+ * dimensionality, this is not checked (and so is up to the caller to use safely).
+ */
+ template <typename... Ix> const T &operator()(Ix... index) const {
+ static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
+ "Invalid number of indices for unchecked array reference");
+ return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
+ }
+ /**
+ * Unchecked const reference access to data; this operator only participates if the reference
+ * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
+ */
+ template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
+ const T &operator[](ssize_t index) const { return operator()(index); }
+
+ /// Pointer access to the data at the given indices.
+ template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }
+
+ /// Returns the item size, i.e. sizeof(T)
+ constexpr static ssize_t itemsize() { return sizeof(T); }
+
+ /// Returns the shape (i.e. size) of dimension `dim`
+ ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
+
+ /// Returns the number of dimensions of the array
+ ssize_t ndim() const { return dims_; }
+
+ /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
+ template <bool Dyn = Dynamic>
+ enable_if_t<!Dyn, ssize_t> size() const {
+ return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
+ }
+ template <bool Dyn = Dynamic>
+ enable_if_t<Dyn, ssize_t> size() const {
+ return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
+ }
+
+ /// Returns the total number of bytes used by the referenced data. Note that the actual span in
+ /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
+ ssize_t nbytes() const {
+ return size() * itemsize();
+ }
+};
+
+template <typename T, ssize_t Dims>
+class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
+ friend class pybind11::array;
+ using ConstBase = unchecked_reference<T, Dims>;
+ using ConstBase::ConstBase;
+ using ConstBase::Dynamic;
+public:
+ /// Mutable, unchecked access to data at the given indices.
+ template <typename... Ix> T& operator()(Ix... index) {
+ static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
+ "Invalid number of indices for unchecked array reference");
+ return const_cast<T &>(ConstBase::operator()(index...));
+ }
+ /**
+ * Mutable, unchecked access data at the given index; this operator only participates if the
+ * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
+ * exactly equivalent to `obj(index)`.
+ */
+ template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
+ T &operator[](ssize_t index) { return operator()(index); }
+
+ /// Mutable pointer access to the data at the given indices.
+ template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
+};
+
+template <typename T, ssize_t Dim>
+struct type_caster<unchecked_reference<T, Dim>> {
+ static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
+};
+template <typename T, ssize_t Dim>
+struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
+
+NAMESPACE_END(detail)
+
+class dtype : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
+
+ explicit dtype(const buffer_info &info) {
+ dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
+ // If info.itemsize == 0, use the value calculated from the format string
+ m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
+ }
+
+ explicit dtype(const std::string &format) {
+ m_ptr = from_args(pybind11::str(format)).release().ptr();
+ }
+
+ dtype(const char *format) : dtype(std::string(format)) { }
+
+ dtype(list names, list formats, list offsets, ssize_t itemsize) {
+ dict args;
+ args["names"] = names;
+ args["formats"] = formats;
+ args["offsets"] = offsets;
+ args["itemsize"] = pybind11::int_(itemsize);
+ m_ptr = from_args(args).release().ptr();
+ }
+
+ /// This is essentially the same as calling numpy.dtype(args) in Python.
+ static dtype from_args(object args) {
+ PyObject *ptr = nullptr;
+ if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
+ throw error_already_set();
+ return reinterpret_steal<dtype>(ptr);
+ }
+
+ /// Return dtype associated with a C++ type.
+ template <typename T> static dtype of() {
+ return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
+ }
+
+ /// Size of the data type in bytes.
+ ssize_t itemsize() const {
+ return detail::array_descriptor_proxy(m_ptr)->elsize;
+ }
+
+ /// Returns true for structured data types.
+ bool has_fields() const {
+ return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
+ }
+
+ /// Single-character type code.
+ char kind() const {
+ return detail::array_descriptor_proxy(m_ptr)->kind;
+ }
+
+private:
+ static object _dtype_from_pep3118() {
+ static PyObject *obj = module::import("numpy.core._internal")
+ .attr("_dtype_from_pep3118").cast<object>().release().ptr();
+ return reinterpret_borrow<object>(obj);
+ }
+
+ dtype strip_padding(ssize_t itemsize) {
+ // Recursively strip all void fields with empty names that are generated for
+ // padding fields (as of NumPy v1.11).
+ if (!has_fields())
+ return *this;
+
+ struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
+ std::vector<field_descr> field_descriptors;
+
+ for (auto field : attr("fields").attr("items")()) {
+ auto spec = field.cast<tuple>();
+ auto name = spec[0].cast<pybind11::str>();
+ auto format = spec[1].cast<tuple>()[0].cast<dtype>();
+ auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
+ if (!len(name) && format.kind() == 'V')
+ continue;
+ field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
+ }
+
+ std::sort(field_descriptors.begin(), field_descriptors.end(),
+ [](const field_descr& a, const field_descr& b) {
+ return a.offset.cast<int>() < b.offset.cast<int>();
+ });
+
+ list names, formats, offsets;
+ for (auto& descr : field_descriptors) {
+ names.append(descr.name);
+ formats.append(descr.format);
+ offsets.append(descr.offset);
+ }
+ return dtype(names, formats, offsets, itemsize);
+ }
+};
+
+class array : public buffer {
+public:
+ PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
+
+ enum {
+ c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
+ f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
+ forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
+ };
+
+ array() : array({{0}}, static_cast<const double *>(nullptr)) {}
+
+ using ShapeContainer = detail::any_container<ssize_t>;
+ using StridesContainer = detail::any_container<ssize_t>;
+
+ // Constructs an array taking shape/strides from arbitrary container types
+ array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
+ const void *ptr = nullptr, handle base = handle()) {
+
+ if (strides->empty())
+ *strides = c_strides(*shape, dt.itemsize());
+
+ auto ndim = shape->size();
+ if (ndim != strides->size())
+ pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
+ auto descr = dt;
+
+ int flags = 0;
+ if (base && ptr) {
+ if (isinstance<array>(base))
+ /* Copy flags from base (except ownership bit) */
+ flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
+ else
+ /* Writable by default, easy to downgrade later on if needed */
+ flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
+ }
+
+ auto &api = detail::npy_api::get();
+ auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
+ api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
+ const_cast<void *>(ptr), flags, nullptr));
+ if (!tmp)
+ throw error_already_set();
+ if (ptr) {
+ if (base) {
+ api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
+ } else {
+ tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
+ }
+ }
+ m_ptr = tmp.release().ptr();
+ }
+
+ array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
+ : array(dt, std::move(shape), {}, ptr, base) { }
+
+ template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
+ array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
+ : array(dt, {{count}}, ptr, base) { }
+
+ template <typename T>
+ array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
+ : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
+
+ template <typename T>
+ array(ShapeContainer shape, const T *ptr, handle base = handle())
+ : array(std::move(shape), {}, ptr, base) { }
+
+ template <typename T>
+ explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }
+
+ explicit array(const buffer_info &info)
+ : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
+
+ /// Array descriptor (dtype)
+ pybind11::dtype dtype() const {
+ return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
+ }
+
+ /// Total number of elements
+ ssize_t size() const {
+ return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
+ }
+
+ /// Byte size of a single element
+ ssize_t itemsize() const {
+ return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
+ }
+
+ /// Total number of bytes
+ ssize_t nbytes() const {
+ return size() * itemsize();
+ }
+
+ /// Number of dimensions
+ ssize_t ndim() const {
+ return detail::array_proxy(m_ptr)->nd;
+ }
+
+ /// Base object
+ object base() const {
+ return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
+ }
+
+ /// Dimensions of the array
+ const ssize_t* shape() const {
+ return detail::array_proxy(m_ptr)->dimensions;
+ }
+
+ /// Dimension along a given axis
+ ssize_t shape(ssize_t dim) const {
+ if (dim >= ndim())
+ fail_dim_check(dim, "invalid axis");
+ return shape()[dim];
+ }
+
+ /// Strides of the array
+ const ssize_t* strides() const {
+ return detail::array_proxy(m_ptr)->strides;
+ }
+
+ /// Stride along a given axis
+ ssize_t strides(ssize_t dim) const {
+ if (dim >= ndim())
+ fail_dim_check(dim, "invalid axis");
+ return strides()[dim];
+ }
+
+ /// Return the NumPy array flags
+ int flags() const {
+ return detail::array_proxy(m_ptr)->flags;
+ }
+
+ /// If set, the array is writeable (otherwise the buffer is read-only)
+ bool writeable() const {
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
+ }
+
+ /// If set, the array owns the data (will be freed when the array is deleted)
+ bool owndata() const {
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
+ }
+
+ /// Pointer to the contained data. If index is not provided, points to the
+ /// beginning of the buffer. May throw if the index would lead to out of bounds access.
+ template<typename... Ix> const void* data(Ix... index) const {
+ return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
+ }
+
+ /// Mutable pointer to the contained data. If index is not provided, points to the
+ /// beginning of the buffer. May throw if the index would lead to out of bounds access.
+ /// May throw if the array is not writeable.
+ template<typename... Ix> void* mutable_data(Ix... index) {
+ check_writeable();
+ return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
+ }
+
+ /// Byte offset from beginning of the array to a given index (full or partial).
+ /// May throw if the index would lead to out of bounds access.
+ template<typename... Ix> ssize_t offset_at(Ix... index) const {
+ if ((ssize_t) sizeof...(index) > ndim())
+ fail_dim_check(sizeof...(index), "too many indices for an array");
+ return byte_offset(ssize_t(index)...);
+ }
+
+ ssize_t offset_at() const { return 0; }
+
+ /// Item count from beginning of the array to a given index (full or partial).
+ /// May throw if the index would lead to out of bounds access.
+ template<typename... Ix> ssize_t index_at(Ix... index) const {
+ return offset_at(index...) / itemsize();
+ }
+
+ /**
+ * Returns a proxy object that provides access to the array's data without bounds or
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
+ * and the caller must take care not to access invalid dimensions or dimension indices.
+ */
+ template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
+ if (Dims >= 0 && ndim() != Dims)
+ throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
+ "; expected " + std::to_string(Dims));
+ return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
+ }
+
+ /**
+ * Returns a proxy object that provides const access to the array's data without bounds or
+ * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
+ * underlying array have the `writable` flag. Use with care: the array must not be destroyed or
+ * reshaped for the duration of the returned object, and the caller must take care not to access
+ * invalid dimensions or dimension indices.
+ */
+ template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
+ if (Dims >= 0 && ndim() != Dims)
+ throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
+ "; expected " + std::to_string(Dims));
+ return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
+ }
+
+ /// Return a new view with all of the dimensions of length 1 removed
+ array squeeze() {
+ auto& api = detail::npy_api::get();
+ return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
+ }
+
+ /// Resize array to given shape
+ /// If refcheck is true and more that one reference exist to this array
+ /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
+ void resize(ShapeContainer new_shape, bool refcheck = true) {
+ detail::npy_api::PyArray_Dims d = {
+ new_shape->data(), int(new_shape->size())
+ };
+ // try to resize, set ordering param to -1 cause it's not used anyway
+ object new_array = reinterpret_steal<object>(
+ detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
+ );
+ if (!new_array) throw error_already_set();
+ if (isinstance<array>(new_array)) { *this = std::move(new_array); }
+ }
+
+ /// Ensure that the argument is a NumPy array
+ /// In case of an error, nullptr is returned and the Python error is cleared.
+ static array ensure(handle h, int ExtraFlags = 0) {
+ auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
+ if (!result)
+ PyErr_Clear();
+ return result;
+ }
+
+protected:
+ template<typename, typename> friend struct detail::npy_format_descriptor;
+
+ void fail_dim_check(ssize_t dim, const std::string& msg) const {
+ throw index_error(msg + ": " + std::to_string(dim) +
+ " (ndim = " + std::to_string(ndim()) + ")");
+ }
+
+ template<typename... Ix> ssize_t byte_offset(Ix... index) const {
+ check_dimensions(index...);
+ return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
+ }
+
+ void check_writeable() const {
+ if (!writeable())
+ throw std::domain_error("array is not writeable");
+ }
+
+ // Default, C-style strides
+ static std::vector<ssize_t> c_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
+ auto ndim = shape.size();
+ std::vector<ssize_t> strides(ndim, itemsize);
+ for (size_t i = ndim - 1; i > 0; --i)
+ strides[i - 1] = strides[i] * shape[i];
+ return strides;
+ }
+
+ // F-style strides; default when constructing an array_t with `ExtraFlags & f_style`
+ static std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
+ auto ndim = shape.size();
+ std::vector<ssize_t> strides(ndim, itemsize);
+ for (size_t i = 1; i < ndim; ++i)
+ strides[i] = strides[i - 1] * shape[i - 1];
+ return strides;
+ }
+
+ template<typename... Ix> void check_dimensions(Ix... index) const {
+ check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
+ }
+
+ void check_dimensions_impl(ssize_t, const ssize_t*) const { }
+
+ template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
+ if (i >= *shape) {
+ throw index_error(std::string("index ") + std::to_string(i) +
+ " is out of bounds for axis " + std::to_string(axis) +
+ " with size " + std::to_string(*shape));
+ }
+ check_dimensions_impl(axis + 1, shape + 1, index...);
+ }
+
+ /// Create array from any object -- always returns a new reference
+ static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
+ if (ptr == nullptr) {
+ PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
+ return nullptr;
+ }
+ return detail::npy_api::get().PyArray_FromAny_(
+ ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
+ }
+};
+
+template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
+private:
+ struct private_ctor {};
+ // Delegating constructor needed when both moving and accessing in the same constructor
+ array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
+ : array(std::move(shape), std::move(strides), ptr, base) {}
+public:
+ static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
+
+ using value_type = T;
+
+ array_t() : array(0, static_cast<const T *>(nullptr)) {}
+ array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
+ array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
+
+ PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
+ array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
+ if (!m_ptr) PyErr_Clear();
+ if (!is_borrowed) Py_XDECREF(h.ptr());
+ }
+
+ array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
+ if (!m_ptr) throw error_already_set();
+ }
+
+ explicit array_t(const buffer_info& info) : array(info) { }
+
+ array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
+ : array(std::move(shape), std::move(strides), ptr, base) { }
+
+ explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
+ : array_t(private_ctor{}, std::move(shape),
+ ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()),
+ ptr, base) { }
+
+ explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
+ : array({count}, {}, ptr, base) { }
+
+ constexpr ssize_t itemsize() const {
+ return sizeof(T);
+ }
+
+ template<typename... Ix> ssize_t index_at(Ix... index) const {
+ return offset_at(index...) / itemsize();
+ }
+
+ template<typename... Ix> const T* data(Ix... index) const {
+ return static_cast<const T*>(array::data(index...));
+ }
+
+ template<typename... Ix> T* mutable_data(Ix... index) {
+ return static_cast<T*>(array::mutable_data(index...));
+ }
+
+ // Reference to element at a given index
+ template<typename... Ix> const T& at(Ix... index) const {
+ if (sizeof...(index) != ndim())
+ fail_dim_check(sizeof...(index), "index dimension mismatch");
+ return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
+ }
+
+ // Mutable reference to element at a given index
+ template<typename... Ix> T& mutable_at(Ix... index) {
+ if (sizeof...(index) != ndim())
+ fail_dim_check(sizeof...(index), "index dimension mismatch");
+ return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
+ }
+
+ /**
+ * Returns a proxy object that provides access to the array's data without bounds or
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
+ * and the caller must take care not to access invalid dimensions or dimension indices.
+ */
+ template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
+ return array::mutable_unchecked<T, Dims>();
+ }
+
+ /**
+ * Returns a proxy object that provides const access to the array's data without bounds or
+ * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
+ * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
+ * for the duration of the returned object, and the caller must take care not to access invalid
+ * dimensions or dimension indices.
+ */
+ template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
+ return array::unchecked<T, Dims>();
+ }
+
+ /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
+ /// it). In case of an error, nullptr is returned and the Python error is cleared.
+ static array_t ensure(handle h) {
+ auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
+ if (!result)
+ PyErr_Clear();
+ return result;
+ }
+
+ static bool check_(handle h) {
+ const auto &api = detail::npy_api::get();
+ return api.PyArray_Check_(h.ptr())
+ && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
+ }
+
+protected:
+ /// Create array from any object -- always returns a new reference
+ static PyObject *raw_array_t(PyObject *ptr) {
+ if (ptr == nullptr) {
+ PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
+ return nullptr;
+ }
+ return detail::npy_api::get().PyArray_FromAny_(
+ ptr, dtype::of<T>().release().ptr(), 0, 0,
+ detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
+ }
+};
+
+template <typename T>
+struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
+ static std::string format() {
+ return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
+ }
+};
+
+template <size_t N> struct format_descriptor<char[N]> {
+ static std::string format() { return std::to_string(N) + "s"; }
+};
+template <size_t N> struct format_descriptor<std::array<char, N>> {
+ static std::string format() { return std::to_string(N) + "s"; }
+};
+
+template <typename T>
+struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
+ static std::string format() {
+ return format_descriptor<
+ typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
+ }
+};
+
+template <typename T>
+struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
+ static std::string format() {
+ using namespace detail;
+ PYBIND11_DESCR extents = _("(") + array_info<T>::extents() + _(")");
+ return extents.text() + format_descriptor<remove_all_extents_t<T>>::format();
+ }
+};
+
+NAMESPACE_BEGIN(detail)
+template <typename T, int ExtraFlags>
+struct pyobject_caster<array_t<T, ExtraFlags>> {
+ using type = array_t<T, ExtraFlags>;
+
+ bool load(handle src, bool convert) {
+ if (!convert && !type::check_(src))
+ return false;
+ value = type::ensure(src);
+ return static_cast<bool>(value);
+ }
+
+ static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
+ return src.inc_ref();
+ }
+ PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
+};
+
+template <typename T>
+struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
+ static bool compare(const buffer_info& b) {
+ return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
+ }
+};
+
+template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
+private:
+ // NB: the order here must match the one in common.h
+ constexpr static const int values[15] = {
+ npy_api::NPY_BOOL_,
+ npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
+ npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
+ npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
+ npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
+ };
+
+public:
+ static constexpr int value = values[detail::is_fmt_numeric<T>::index];
+
+ static pybind11::dtype dtype() {
+ if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
+ return reinterpret_borrow<pybind11::dtype>(ptr);
+ pybind11_fail("Unsupported buffer format!");
+ }
+ template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<T, bool>::value>(_("bool"),
+ _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
+ }
+ template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
+ _("float") + _<sizeof(T)*8>(), _("longdouble"));
+ }
+ template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
+ _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
+ }
+};
+
+#define PYBIND11_DECL_CHAR_FMT \
+ static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
+ static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
+template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
+template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
+#undef PYBIND11_DECL_CHAR_FMT
+
+template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
+private:
+ using base_descr = npy_format_descriptor<typename array_info<T>::type>;
+public:
+ static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
+
+ static PYBIND11_DESCR name() { return _("(") + array_info<T>::extents() + _(")") + base_descr::name(); }
+ static pybind11::dtype dtype() {
+ list shape;
+ array_info<T>::append_extents(shape);
+ return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape));
+ }
+};
+
+template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
+private:
+ using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
+public:
+ static PYBIND11_DESCR name() { return base_descr::name(); }
+ static pybind11::dtype dtype() { return base_descr::dtype(); }
+};
+
+struct field_descriptor {
+ const char *name;
+ ssize_t offset;
+ ssize_t size;
+ std::string format;
+ dtype descr;
+};
+
+inline PYBIND11_NOINLINE void register_structured_dtype(
+ const std::initializer_list<field_descriptor>& fields,
+ const std::type_info& tinfo, ssize_t itemsize,
+ bool (*direct_converter)(PyObject *, void *&)) {
+
+ auto& numpy_internals = get_numpy_internals();
+ if (numpy_internals.get_type_info(tinfo, false))
+ pybind11_fail("NumPy: dtype is already registered");
+
+ list names, formats, offsets;
+ for (auto field : fields) {
+ if (!field.descr)
+ pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
+ field.name + "` @ " + tinfo.name());
+ names.append(PYBIND11_STR_TYPE(field.name));
+ formats.append(field.descr);
+ offsets.append(pybind11::int_(field.offset));
+ }
+ auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
+
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
+ // not encoded explicitly into the format string. This will supposedly
+ // get fixed in v1.12; for further details, see these:
+ // - https://github.com/numpy/numpy/issues/7797
+ // - https://github.com/numpy/numpy/pull/7798
+ // Because of this, we won't use numpy's logic to generate buffer format
+ // strings and will just do it ourselves.
+ std::vector<field_descriptor> ordered_fields(fields);
+ std::sort(ordered_fields.begin(), ordered_fields.end(),
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
+ ssize_t offset = 0;
+ std::ostringstream oss;
+ // mark the structure as unaligned with '^', because numpy and C++ don't
+ // always agree about alignment (particularly for complex), and we're
+ // explicitly listing all our padding. This depends on none of the fields
+ // overriding the endianness. Putting the ^ in front of individual fields
+ // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
+ oss << "^T{";
+ for (auto& field : ordered_fields) {
+ if (field.offset > offset)
+ oss << (field.offset - offset) << 'x';
+ oss << field.format << ':' << field.name << ':';
+ offset = field.offset + field.size;
+ }
+ if (itemsize > offset)
+ oss << (itemsize - offset) << 'x';
+ oss << '}';
+ auto format_str = oss.str();
+
+ // Sanity check: verify that NumPy properly parses our buffer format string
+ auto& api = npy_api::get();
+ auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
+ if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
+ pybind11_fail("NumPy: invalid buffer descriptor!");
+
+ auto tindex = std::type_index(tinfo);
+ numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
+ get_internals().direct_conversions[tindex].push_back(direct_converter);
+}
+
+template <typename T, typename SFINAE> struct npy_format_descriptor {
+ static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
+
+ static PYBIND11_DESCR name() { return make_caster<T>::name(); }
+
+ static pybind11::dtype dtype() {
+ return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
+ }
+
+ static std::string format() {
+ static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
+ return format_str;
+ }
+
+ static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
+ register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
+ sizeof(T), &direct_converter);
+ }
+
+private:
+ static PyObject* dtype_ptr() {
+ static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
+ return ptr;
+ }
+
+ static bool direct_converter(PyObject *obj, void*& value) {
+ auto& api = npy_api::get();
+ if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
+ return false;
+ if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
+ if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
+ value = ((PyVoidScalarObject_Proxy *) obj)->obval;
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
+# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
+# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
+#else
+
+#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
+ ::pybind11::detail::field_descriptor { \
+ Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
+ ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
+ ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
+ }
+
+// Extract name, offset and format descriptor for a struct field
+#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
+
+// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
+// (C) William Swanson, Paul Fultz
+#define PYBIND11_EVAL0(...) __VA_ARGS__
+#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
+#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
+#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
+#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
+#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
+#define PYBIND11_MAP_END(...)
+#define PYBIND11_MAP_OUT
+#define PYBIND11_MAP_COMMA ,
+#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
+#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
+#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
+#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
+#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
+#define PYBIND11_MAP_LIST_NEXT1(test, next) \
+ PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
+#else
+#define PYBIND11_MAP_LIST_NEXT1(test, next) \
+ PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
+#endif
+#define PYBIND11_MAP_LIST_NEXT(test, next) \
+ PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
+#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
+ f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
+#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
+ f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
+// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
+#define PYBIND11_MAP_LIST(f, t, ...) \
+ PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
+
+#define PYBIND11_NUMPY_DTYPE(Type, ...) \
+ ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
+ ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
+
+#ifdef _MSC_VER
+#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
+ PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
+#else
+#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
+ PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
+#endif
+#define PYBIND11_MAP2_LIST_NEXT(test, next) \
+ PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
+#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
+ f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
+#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
+ f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
+// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
+#define PYBIND11_MAP2_LIST(f, t, ...) \
+ PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))
+
+#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
+ ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
+ ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
+
+#endif // __CLION_IDE__
+
+template <class T>
+using array_iterator = typename std::add_pointer<T>::type;
+
+template <class T>
+array_iterator<T> array_begin(const buffer_info& buffer) {
+ return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
+}
+
+template <class T>
+array_iterator<T> array_end(const buffer_info& buffer) {
+ return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
+}
+
+class common_iterator {
+public:
+ using container_type = std::vector<ssize_t>;
+ using value_type = container_type::value_type;
+ using size_type = container_type::size_type;
+
+ common_iterator() : p_ptr(0), m_strides() {}
+
+ common_iterator(void* ptr, const container_type& strides, const container_type& shape)
+ : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
+ m_strides.back() = static_cast<value_type>(strides.back());
+ for (size_type i = m_strides.size() - 1; i != 0; --i) {
+ size_type j = i - 1;
+ value_type s = static_cast<value_type>(shape[i]);
+ m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
+ }
+ }
+
+ void increment(size_type dim) {
+ p_ptr += m_strides[dim];
+ }
+
+ void* data() const {
+ return p_ptr;
+ }
+
+private:
+ char* p_ptr;
+ container_type m_strides;
+};
+
+template <size_t N> class multi_array_iterator {
+public:
+ using container_type = std::vector<ssize_t>;
+
+ multi_array_iterator(const std::array<buffer_info, N> &buffers,
+ const container_type &shape)
+ : m_shape(shape.size()), m_index(shape.size(), 0),
+ m_common_iterator() {
+
+ // Manual copy to avoid conversion warning if using std::copy
+ for (size_t i = 0; i < shape.size(); ++i)
+ m_shape[i] = shape[i];
+
+ container_type strides(shape.size());
+ for (size_t i = 0; i < N; ++i)
+ init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
+ }
+
+ multi_array_iterator& operator++() {
+ for (size_t j = m_index.size(); j != 0; --j) {
+ size_t i = j - 1;
+ if (++m_index[i] != m_shape[i]) {
+ increment_common_iterator(i);
+ break;
+ } else {
+ m_index[i] = 0;
+ }
+ }
+ return *this;
+ }
+
+ template <size_t K, class T = void> T* data() const {
+ return reinterpret_cast<T*>(m_common_iterator[K].data());
+ }
+
+private:
+
+ using common_iter = common_iterator;
+
+ void init_common_iterator(const buffer_info &buffer,
+ const container_type &shape,
+ common_iter &iterator,
+ container_type &strides) {
+ auto buffer_shape_iter = buffer.shape.rbegin();
+ auto buffer_strides_iter = buffer.strides.rbegin();
+ auto shape_iter = shape.rbegin();
+ auto strides_iter = strides.rbegin();
+
+ while (buffer_shape_iter != buffer.shape.rend()) {
+ if (*shape_iter == *buffer_shape_iter)
+ *strides_iter = *buffer_strides_iter;
+ else
+ *strides_iter = 0;
+
+ ++buffer_shape_iter;
+ ++buffer_strides_iter;
+ ++shape_iter;
+ ++strides_iter;
+ }
+
+ std::fill(strides_iter, strides.rend(), 0);
+ iterator = common_iter(buffer.ptr, strides, shape);
+ }
+
+ void increment_common_iterator(size_t dim) {
+ for (auto &iter : m_common_iterator)
+ iter.increment(dim);
+ }
+
+ container_type m_shape;
+ container_type m_index;
+ std::array<common_iter, N> m_common_iterator;
+};
+
+enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
+
+// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial
+// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
+// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
+// buffer; returns `non_trivial` otherwise.
+template <size_t N>
+broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
+ ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
+ return std::max(res, buf.ndim);
+ });
+
+ shape.clear();
+ shape.resize((size_t) ndim, 1);
+
+ // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
+ // the full size).
+ for (size_t i = 0; i < N; ++i) {
+ auto res_iter = shape.rbegin();
+ auto end = buffers[i].shape.rend();
+ for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
+ const auto &dim_size_in = *shape_iter;
+ auto &dim_size_out = *res_iter;
+
+ // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
+ if (dim_size_out == 1)
+ dim_size_out = dim_size_in;
+ else if (dim_size_in != 1 && dim_size_in != dim_size_out)
+ pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
+ }
+ }
+
+ bool trivial_broadcast_c = true;
+ bool trivial_broadcast_f = true;
+ for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
+ if (buffers[i].size == 1)
+ continue;
+
+ // Require the same number of dimensions:
+ if (buffers[i].ndim != ndim)
+ return broadcast_trivial::non_trivial;
+
+ // Require all dimensions be full-size:
+ if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
+ return broadcast_trivial::non_trivial;
+
+ // Check for C contiguity (but only if previous inputs were also C contiguous)
+ if (trivial_broadcast_c) {
+ ssize_t expect_stride = buffers[i].itemsize;
+ auto end = buffers[i].shape.crend();
+ for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
+ trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
+ if (expect_stride == *stride_iter)
+ expect_stride *= *shape_iter;
+ else
+ trivial_broadcast_c = false;
+ }
+ }
+
+ // Check for Fortran contiguity (if previous inputs were also F contiguous)
+ if (trivial_broadcast_f) {
+ ssize_t expect_stride = buffers[i].itemsize;
+ auto end = buffers[i].shape.cend();
+ for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
+ trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
+ if (expect_stride == *stride_iter)
+ expect_stride *= *shape_iter;
+ else
+ trivial_broadcast_f = false;
+ }
+ }
+ }
+
+ return
+ trivial_broadcast_c ? broadcast_trivial::c_trivial :
+ trivial_broadcast_f ? broadcast_trivial::f_trivial :
+ broadcast_trivial::non_trivial;
+}
+
+template <typename T>
+struct vectorize_arg {
+ static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
+ // The wrapped function gets called with this type:
+ using call_type = remove_reference_t<T>;
+ // Is this a vectorized argument?
+ static constexpr bool vectorize =
+ satisfies_any_of<call_type, std::is_arithmetic, is_complex, std::is_pod>::value &&
+ satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
+ (!std::is_reference<T>::value ||
+ (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
+ // Accept this type: an array for vectorized types, otherwise the type as-is:
+ using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
+};
+
+template <typename Func, typename Return, typename... Args>
+struct vectorize_helper {
+private:
+ static constexpr size_t N = sizeof...(Args);
+ static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
+ static_assert(NVectorized >= 1,
+ "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
+
+public:
+ template <typename T>
+ explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }
+
+ object operator()(typename vectorize_arg<Args>::type... args) {
+ return run(args...,
+ make_index_sequence<N>(),
+ select_indices<vectorize_arg<Args>::vectorize...>(),
+ make_index_sequence<NVectorized>());
+ }
+
+private:
+ remove_reference_t<Func> f;
+
+ template <size_t Index> using param_n_t = typename pack_element<Index, typename vectorize_arg<Args>::call_type...>::type;
+
+ // Runs a vectorized function given arguments tuple and three index sequences:
+ // - Index is the full set of 0 ... (N-1) argument indices;
+ // - VIndex is the subset of argument indices with vectorized parameters, letting us access
+ // vectorized arguments (anything not in this sequence is passed through)
+ // - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
+ // we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
+ // index BIndex in the array).
+ template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
+ typename vectorize_arg<Args>::type &...args,
+ index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {
+
+ // Pointers to values the function was called with; the vectorized ones set here will start
+ // out as array_t<T> pointers, but they will be changed them to T pointers before we make
+ // call the wrapped function. Non-vectorized pointers are left as-is.
+ std::array<void *, N> params{{ &args... }};
+
+ // The array of `buffer_info`s of vectorized arguments:
+ std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};
+
+ /* Determine dimensions parameters of output array */
+ ssize_t nd = 0;
+ std::vector<ssize_t> shape(0);
+ auto trivial = broadcast(buffers, nd, shape);
+ size_t ndim = (size_t) nd;
+
+ size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());
+
+ // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
+ // not wrapped in an array).
+ if (size == 1 && ndim == 0) {
+ PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
+ return cast(f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...));
+ }
+
+ array_t<Return> result;
+ if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
+ else result = array_t<Return>(shape);
+
+ if (size == 0) return result;
+
+ /* Call the function */
+ if (trivial == broadcast_trivial::non_trivial)
+ apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq);
+ else
+ apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
+
+ return result;
+ }
+
+ template <size_t... Index, size_t... VIndex, size_t... BIndex>
+ void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
+ std::array<void *, N> &params,
+ Return *out,
+ size_t size,
+ index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
+
+ // Initialize an array of mutable byte references and sizes with references set to the
+ // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
+ // (except for singletons, which get an increment of 0).
+ std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
+ std::pair<unsigned char *&, const size_t>(
+ reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
+ buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
+ )...
+ }};
+
+ for (size_t i = 0; i < size; ++i) {
+ out[i] = f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...);
+ for (auto &x : vecparams) x.first += x.second;
+ }
+ }
+
+ template <size_t... Index, size_t... VIndex, size_t... BIndex>
+ void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
+ std::array<void *, N> &params,
+ array_t<Return> &output_array,
+ index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
+
+ buffer_info output = output_array.request();
+ multi_array_iterator<NVectorized> input_iter(buffers, output.shape);
+
+ for (array_iterator<Return> iter = array_begin<Return>(output), end = array_end<Return>(output);
+ iter != end;
+ ++iter, ++input_iter) {
+ PYBIND11_EXPAND_SIDE_EFFECTS((
+ params[VIndex] = input_iter.template data<BIndex>()
+ ));
+ *iter = f(*reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
+ }
+ }
+};
+
+template <typename Func, typename Return, typename... Args>
+vectorize_helper<Func, Return, Args...>
+vectorize_extractor(const Func &f, Return (*) (Args ...)) {
+ return detail::vectorize_helper<Func, Return, Args...>(f);
+}
+
+template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
+ static PYBIND11_DESCR name() {
+ return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
+ }
+};
+
+NAMESPACE_END(detail)
+
+// Vanilla pointer vectorizer:
+template <typename Return, typename... Args>
+detail::vectorize_helper<Return (*)(Args...), Return, Args...>
+vectorize(Return (*f) (Args ...)) {
+ return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
+}
+
+// lambda vectorizer:
+template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
+auto vectorize(Func &&f) -> decltype(
+ detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
+ return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
+}
+
+// Vectorize a class method (non-const):
+template <typename Return, typename Class, typename... Args,
+ typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
+Helper vectorize(Return (Class::*f)(Args...)) {
+ return Helper(std::mem_fn(f));
+}
+
+// Vectorize a class method (non-const):
+template <typename Return, typename Class, typename... Args,
+ typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
+Helper vectorize(Return (Class::*f)(Args...) const) {
+ return Helper(std::mem_fn(f));
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h b/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h
new file mode 100644
index 000000000..b3dd62c3b
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h
@@ -0,0 +1,168 @@
+/*
+ pybind11/operator.h: Metatemplates for operator overloading
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+
+#if defined(__clang__) && !defined(__INTEL_COMPILER)
+# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
+#elif defined(_MSC_VER)
+# pragma warning(push)
+# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// Enumeration with all supported operator types
+enum op_id : int {
+ op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
+ op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
+ op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
+ op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
+ op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
+ op_repr, op_truediv, op_itruediv, op_hash
+};
+
+enum op_type : int {
+ op_l, /* base type on left */
+ op_r, /* base type on right */
+ op_u /* unary operator */
+};
+
+struct self_t { };
+static const self_t self = self_t();
+
+/// Type for an unused type slot
+struct undefined_t { };
+
+/// Don't warn about an unused variable
+inline self_t __self() { return self; }
+
+/// base template of operator implementations
+template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
+
+/// Operator implementation generator
+template <op_id id, op_type ot, typename L, typename R> struct op_ {
+ template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
+ using Base = typename Class::type;
+ using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
+ using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
+ using op = op_impl<id, ot, Base, L_type, R_type>;
+ cl.def(op::name(), &op::execute, is_operator(), extra...);
+ #if PY_MAJOR_VERSION < 3
+ if (id == op_truediv || id == op_itruediv)
+ cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
+ &op::execute, is_operator(), extra...);
+ #endif
+ }
+ template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
+ using Base = typename Class::type;
+ using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
+ using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
+ using op = op_impl<id, ot, Base, L_type, R_type>;
+ cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
+ #if PY_MAJOR_VERSION < 3
+ if (id == op_truediv || id == op_itruediv)
+ cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
+ &op::execute, is_operator(), extra...);
+ #endif
+ }
+};
+
+#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
+template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
+ static char const* name() { return "__" #id "__"; } \
+ static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
+ static B execute_cast(const L &l, const R &r) { return B(expr); } \
+}; \
+template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
+ static char const* name() { return "__" #rid "__"; } \
+ static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
+ static B execute_cast(const R &r, const L &l) { return B(expr); } \
+}; \
+inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
+ return op_<op_##id, op_l, self_t, self_t>(); \
+} \
+template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
+ return op_<op_##id, op_l, self_t, T>(); \
+} \
+template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
+ return op_<op_##id, op_r, T, self_t>(); \
+}
+
+#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
+template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
+ static char const* name() { return "__" #id "__"; } \
+ static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
+ static B execute_cast(L &l, const R &r) { return B(expr); } \
+}; \
+template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
+ return op_<op_##id, op_l, self_t, T>(); \
+}
+
+#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
+template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
+ static char const* name() { return "__" #id "__"; } \
+ static auto execute(const L &l) -> decltype(expr) { return expr; } \
+ static B execute_cast(const L &l) { return B(expr); } \
+}; \
+inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
+ return op_<op_##id, op_u, self_t, undefined_t>(); \
+}
+
+PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
+PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
+PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
+PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
+PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
+PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
+PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
+PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
+PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
+PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
+PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
+PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
+PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
+PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
+PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
+PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
+//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
+PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
+PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
+PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
+PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
+PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
+PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
+PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
+PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
+PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
+PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
+PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
+PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
+PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
+PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
+PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
+PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
+PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
+PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
+
+#undef PYBIND11_BINARY_OPERATOR
+#undef PYBIND11_INPLACE_OPERATOR
+#undef PYBIND11_UNARY_OPERATOR
+NAMESPACE_END(detail)
+
+using detail::self;
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
+
+#if defined(_MSC_VER)
+# pragma warning(pop)
+#endif
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/options.h b/ml/dlib/dlib/external/pybind11/include/pybind11/options.h
new file mode 100644
index 000000000..cc1e1f6f0
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/options.h
@@ -0,0 +1,65 @@
+/*
+ pybind11/options.h: global settings that are configurable at runtime.
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+class options {
+public:
+
+ // Default RAII constructor, which leaves settings as they currently are.
+ options() : previous_state(global_state()) {}
+
+ // Class is non-copyable.
+ options(const options&) = delete;
+ options& operator=(const options&) = delete;
+
+ // Destructor, which restores settings that were in effect before.
+ ~options() {
+ global_state() = previous_state;
+ }
+
+ // Setter methods (affect the global state):
+
+ options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
+
+ options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
+
+ options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
+
+ options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
+
+ // Getter methods (return the global state):
+
+ static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
+
+ static bool show_function_signatures() { return global_state().show_function_signatures; }
+
+ // This type is not meant to be allocated on the heap.
+ void* operator new(size_t) = delete;
+
+private:
+
+ struct state {
+ bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
+ bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
+ };
+
+ static state &global_state() {
+ static state instance;
+ return instance;
+ }
+
+ state previous_state;
+};
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h b/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h
new file mode 100644
index 000000000..7723d2a8e
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h
@@ -0,0 +1,1963 @@
+/*
+ pybind11/pybind11.h: Main header file of the C++11 python
+ binding generator library
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#if defined(_MSC_VER)
+# pragma warning(push)
+# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter
+# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+# pragma warning(disable: 4512) // warning C4512: Assignment operator was implicitly defined as deleted
+# pragma warning(disable: 4800) // warning C4800: 'int': forcing value to bool 'true' or 'false' (performance warning)
+# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name
+# pragma warning(disable: 4702) // warning C4702: unreachable code
+# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified
+#elif defined(__INTEL_COMPILER)
+# pragma warning(push)
+# pragma warning(disable: 68) // integer conversion resulted in a change of sign
+# pragma warning(disable: 186) // pointless comparison of unsigned integer with zero
+# pragma warning(disable: 878) // incompatible exception specifications
+# pragma warning(disable: 1334) // the "template" keyword used for syntactic disambiguation may only be used within a template
+# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
+# pragma warning(disable: 1875) // offsetof applied to non-POD (Plain Old Data) types is nonstandard
+# pragma warning(disable: 2196) // warning #2196: routine is both "inline" and "noinline"
+#elif defined(__GNUG__) && !defined(__clang__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wunused-but-set-parameter"
+# pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+# pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+# pragma GCC diagnostic ignored "-Wstrict-aliasing"
+# pragma GCC diagnostic ignored "-Wattributes"
+# if __GNUC__ >= 7
+# pragma GCC diagnostic ignored "-Wnoexcept-type"
+# endif
+#endif
+
+#include "attr.h"
+#include "options.h"
+#include "detail/class.h"
+#include "detail/init.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object
+class cpp_function : public function {
+public:
+ cpp_function() { }
+
+ /// Construct a cpp_function from a vanilla function pointer
+ template <typename Return, typename... Args, typename... Extra>
+ cpp_function(Return (*f)(Args...), const Extra&... extra) {
+ initialize(f, f, extra...);
+ }
+
+ /// Construct a cpp_function from a lambda function (possibly with internal state)
+ template <typename Func, typename... Extra,
+ typename = detail::enable_if_t<detail::is_lambda<Func>::value>>
+ cpp_function(Func &&f, const Extra&... extra) {
+ initialize(std::forward<Func>(f),
+ (detail::function_signature_t<Func> *) nullptr, extra...);
+ }
+
+ /// Construct a cpp_function from a class method (non-const)
+ template <typename Return, typename Class, typename... Arg, typename... Extra>
+ cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) {
+ initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); },
+ (Return (*) (Class *, Arg...)) nullptr, extra...);
+ }
+
+ /// Construct a cpp_function from a class method (const)
+ template <typename Return, typename Class, typename... Arg, typename... Extra>
+ cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) {
+ initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(args...); },
+ (Return (*)(const Class *, Arg ...)) nullptr, extra...);
+ }
+
+ /// Return the function name
+ object name() const { return attr("__name__"); }
+
+protected:
+ /// Space optimization: don't inline this frequently instantiated fragment
+ PYBIND11_NOINLINE detail::function_record *make_function_record() {
+ return new detail::function_record();
+ }
+
+ /// Special internal constructor for functors, lambda functions, etc.
+ template <typename Func, typename Return, typename... Args, typename... Extra>
+ void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
+ using namespace detail;
+
+ struct capture { remove_reference_t<Func> f; };
+
+ /* Store the function including any extra state it might have (e.g. a lambda capture object) */
+ auto rec = make_function_record();
+
+ /* Store the capture object directly in the function record if there is enough space */
+ if (sizeof(capture) <= sizeof(rec->data)) {
+ /* Without these pragmas, GCC warns that there might not be
+ enough space to use the placement new operator. However, the
+ 'if' statement above ensures that this is the case. */
+#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wplacement-new"
+#endif
+ new ((capture *) &rec->data) capture { std::forward<Func>(f) };
+#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6
+# pragma GCC diagnostic pop
+#endif
+ if (!std::is_trivially_destructible<Func>::value)
+ rec->free_data = [](function_record *r) { ((capture *) &r->data)->~capture(); };
+ } else {
+ rec->data[0] = new capture { std::forward<Func>(f) };
+ rec->free_data = [](function_record *r) { delete ((capture *) r->data[0]); };
+ }
+
+ /* Type casters for the function arguments and return value */
+ using cast_in = argument_loader<Args...>;
+ using cast_out = make_caster<
+ conditional_t<std::is_void<Return>::value, void_type, Return>
+ >;
+
+ static_assert(expected_num_args<Extra...>(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs),
+ "The number of argument annotations does not match the number of function arguments");
+
+ /* Dispatch code which converts function arguments and performs the actual function call */
+ rec->impl = [](function_call &call) -> handle {
+ cast_in args_converter;
+
+ /* Try to cast the function arguments into the C++ domain */
+ if (!args_converter.load_args(call))
+ return PYBIND11_TRY_NEXT_OVERLOAD;
+
+ /* Invoke call policy pre-call hook */
+ process_attributes<Extra...>::precall(call);
+
+ /* Get a pointer to the capture object */
+ auto data = (sizeof(capture) <= sizeof(call.func.data)
+ ? &call.func.data : call.func.data[0]);
+ capture *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));
+
+ /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
+ const auto policy = return_value_policy_override<Return>::policy(call.func.policy);
+
+ /* Function scope guard -- defaults to the compile-to-nothing `void_type` */
+ using Guard = extract_guard_t<Extra...>;
+
+ /* Perform the function call */
+ handle result = cast_out::cast(
+ std::move(args_converter).template call<Return, Guard>(cap->f), policy, call.parent);
+
+ /* Invoke call policy post-call hook */
+ process_attributes<Extra...>::postcall(call, result);
+
+ return result;
+ };
+
+ /* Process any user-provided function attributes */
+ process_attributes<Extra...>::init(extra..., rec);
+
+ /* Generate a readable signature describing the function's arguments and return value types */
+ PYBIND11_DESCR signature = _("(") + cast_in::arg_names() + _(") -> ") + cast_out::name();
+
+ /* Register the function with Python from generic (non-templated) code */
+ initialize_generic(rec, signature.text(), signature.types(), sizeof...(Args));
+
+ if (cast_in::has_args) rec->has_args = true;
+ if (cast_in::has_kwargs) rec->has_kwargs = true;
+
+ /* Stash some additional information used by an important optimization in 'functional.h' */
+ using FunctionType = Return (*)(Args...);
+ constexpr bool is_function_ptr =
+ std::is_convertible<Func, FunctionType>::value &&
+ sizeof(capture) == sizeof(void *);
+ if (is_function_ptr) {
+ rec->is_stateless = true;
+ rec->data[1] = const_cast<void *>(reinterpret_cast<const void *>(&typeid(FunctionType)));
+ }
+ }
+
+ /// Register a function call with Python (generic non-templated code goes here)
+ void initialize_generic(detail::function_record *rec, const char *text,
+ const std::type_info *const *types, size_t args) {
+
+ /* Create copies of all referenced C-style strings */
+ rec->name = strdup(rec->name ? rec->name : "");
+ if (rec->doc) rec->doc = strdup(rec->doc);
+ for (auto &a: rec->args) {
+ if (a.name)
+ a.name = strdup(a.name);
+ if (a.descr)
+ a.descr = strdup(a.descr);
+ else if (a.value)
+ a.descr = strdup(a.value.attr("__repr__")().cast<std::string>().c_str());
+ }
+
+ rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
+
+#if !defined(NDEBUG) && !defined(PYBIND11_DISABLE_NEW_STYLE_INIT_WARNING)
+ if (rec->is_constructor && !rec->is_new_style_constructor) {
+ const auto class_name = std::string(((PyTypeObject *) rec->scope.ptr())->tp_name);
+ const auto func_name = std::string(rec->name);
+ PyErr_WarnEx(
+ PyExc_FutureWarning,
+ ("pybind11-bound class '" + class_name + "' is using an old-style "
+ "placement-new '" + func_name + "' which has been deprecated. See "
+ "the upgrade guide in pybind11's docs. This message is only visible "
+ "when compiled in debug mode.").c_str(), 0
+ );
+ }
+#endif
+
+ /* Generate a proper function signature */
+ std::string signature;
+ size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0;
+ while (true) {
+ char c = text[char_index++];
+ if (c == '\0')
+ break;
+
+ if (c == '{') {
+ // Write arg name for everything except *args, **kwargs and return type.
+ if (type_depth == 0 && text[char_index] != '*' && arg_index < args) {
+ if (!rec->args.empty() && rec->args[arg_index].name) {
+ signature += rec->args[arg_index].name;
+ } else if (arg_index == 0 && rec->is_method) {
+ signature += "self";
+ } else {
+ signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0));
+ }
+ signature += ": ";
+ }
+ ++type_depth;
+ } else if (c == '}') {
+ --type_depth;
+ if (type_depth == 0) {
+ if (arg_index < rec->args.size() && rec->args[arg_index].descr) {
+ signature += "=";
+ signature += rec->args[arg_index].descr;
+ }
+ arg_index++;
+ }
+ } else if (c == '%') {
+ const std::type_info *t = types[type_index++];
+ if (!t)
+ pybind11_fail("Internal error while parsing type signature (1)");
+ if (auto tinfo = detail::get_type_info(*t)) {
+ handle th((PyObject *) tinfo->type);
+ signature +=
+ th.attr("__module__").cast<std::string>() + "." +
+ th.attr("__qualname__").cast<std::string>(); // Python 3.3+, but we backport it to earlier versions
+ } else if (rec->is_new_style_constructor && arg_index == 0) {
+ // A new-style `__init__` takes `self` as `value_and_holder`.
+ // Rewrite it to the proper class type.
+ signature +=
+ rec->scope.attr("__module__").cast<std::string>() + "." +
+ rec->scope.attr("__qualname__").cast<std::string>();
+ } else {
+ std::string tname(t->name());
+ detail::clean_type_id(tname);
+ signature += tname;
+ }
+ } else {
+ signature += c;
+ }
+ }
+ if (type_depth != 0 || types[type_index] != nullptr)
+ pybind11_fail("Internal error while parsing type signature (2)");
+
+ #if !defined(PYBIND11_CONSTEXPR_DESCR)
+ delete[] types;
+ delete[] text;
+ #endif
+
+#if PY_MAJOR_VERSION < 3
+ if (strcmp(rec->name, "__next__") == 0) {
+ std::free(rec->name);
+ rec->name = strdup("next");
+ } else if (strcmp(rec->name, "__bool__") == 0) {
+ std::free(rec->name);
+ rec->name = strdup("__nonzero__");
+ }
+#endif
+ rec->signature = strdup(signature.c_str());
+ rec->args.shrink_to_fit();
+ rec->nargs = (std::uint16_t) args;
+
+ if (rec->sibling && PYBIND11_INSTANCE_METHOD_CHECK(rec->sibling.ptr()))
+ rec->sibling = PYBIND11_INSTANCE_METHOD_GET_FUNCTION(rec->sibling.ptr());
+
+ detail::function_record *chain = nullptr, *chain_start = rec;
+ if (rec->sibling) {
+ if (PyCFunction_Check(rec->sibling.ptr())) {
+ auto rec_capsule = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(rec->sibling.ptr()));
+ chain = (detail::function_record *) rec_capsule;
+ /* Never append a method to an overload chain of a parent class;
+ instead, hide the parent's overloads in this case */
+ if (!chain->scope.is(rec->scope))
+ chain = nullptr;
+ }
+ // Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing
+ else if (!rec->sibling.is_none() && rec->name[0] != '_')
+ pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) +
+ "\" with a function of the same name");
+ }
+
+ if (!chain) {
+ /* No existing overload was found, create a new function object */
+ rec->def = new PyMethodDef();
+ std::memset(rec->def, 0, sizeof(PyMethodDef));
+ rec->def->ml_name = rec->name;
+ rec->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
+ rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
+
+ capsule rec_capsule(rec, [](void *ptr) {
+ destruct((detail::function_record *) ptr);
+ });
+
+ object scope_module;
+ if (rec->scope) {
+ if (hasattr(rec->scope, "__module__")) {
+ scope_module = rec->scope.attr("__module__");
+ } else if (hasattr(rec->scope, "__name__")) {
+ scope_module = rec->scope.attr("__name__");
+ }
+ }
+
+ m_ptr = PyCFunction_NewEx(rec->def, rec_capsule.ptr(), scope_module.ptr());
+ if (!m_ptr)
+ pybind11_fail("cpp_function::cpp_function(): Could not allocate function object");
+ } else {
+ /* Append at the end of the overload chain */
+ m_ptr = rec->sibling.ptr();
+ inc_ref();
+ chain_start = chain;
+ if (chain->is_method != rec->is_method)
+ pybind11_fail("overloading a method with both static and instance methods is not supported; "
+ #if defined(NDEBUG)
+ "compile in debug mode for more details"
+ #else
+ "error while attempting to bind " + std::string(rec->is_method ? "instance" : "static") + " method " +
+ std::string(pybind11::str(rec->scope.attr("__name__"))) + "." + std::string(rec->name) + signature
+ #endif
+ );
+ while (chain->next)
+ chain = chain->next;
+ chain->next = rec;
+ }
+
+ std::string signatures;
+ int index = 0;
+ /* Create a nice pydoc rec including all signatures and
+ docstrings of the functions in the overload chain */
+ if (chain && options::show_function_signatures()) {
+ // First a generic signature
+ signatures += rec->name;
+ signatures += "(*args, **kwargs)\n";
+ signatures += "Overloaded function.\n\n";
+ }
+ // Then specific overload signatures
+ bool first_user_def = true;
+ for (auto it = chain_start; it != nullptr; it = it->next) {
+ if (options::show_function_signatures()) {
+ if (index > 0) signatures += "\n";
+ if (chain)
+ signatures += std::to_string(++index) + ". ";
+ signatures += rec->name;
+ signatures += it->signature;
+ signatures += "\n";
+ }
+ if (it->doc && strlen(it->doc) > 0 && options::show_user_defined_docstrings()) {
+ // If we're appending another docstring, and aren't printing function signatures, we
+ // need to append a newline first:
+ if (!options::show_function_signatures()) {
+ if (first_user_def) first_user_def = false;
+ else signatures += "\n";
+ }
+ if (options::show_function_signatures()) signatures += "\n";
+ signatures += it->doc;
+ if (options::show_function_signatures()) signatures += "\n";
+ }
+ }
+
+ /* Install docstring */
+ PyCFunctionObject *func = (PyCFunctionObject *) m_ptr;
+ if (func->m_ml->ml_doc)
+ std::free(const_cast<char *>(func->m_ml->ml_doc));
+ func->m_ml->ml_doc = strdup(signatures.c_str());
+
+ if (rec->is_method) {
+ m_ptr = PYBIND11_INSTANCE_METHOD_NEW(m_ptr, rec->scope.ptr());
+ if (!m_ptr)
+ pybind11_fail("cpp_function::cpp_function(): Could not allocate instance method object");
+ Py_DECREF(func);
+ }
+ }
+
+ /// When a cpp_function is GCed, release any memory allocated by pybind11
+ static void destruct(detail::function_record *rec) {
+ while (rec) {
+ detail::function_record *next = rec->next;
+ if (rec->free_data)
+ rec->free_data(rec);
+ std::free((char *) rec->name);
+ std::free((char *) rec->doc);
+ std::free((char *) rec->signature);
+ for (auto &arg: rec->args) {
+ std::free(const_cast<char *>(arg.name));
+ std::free(const_cast<char *>(arg.descr));
+ arg.value.dec_ref();
+ }
+ if (rec->def) {
+ std::free(const_cast<char *>(rec->def->ml_doc));
+ delete rec->def;
+ }
+ delete rec;
+ rec = next;
+ }
+ }
+
+ /// Main dispatch logic for calls to functions bound using pybind11
+ static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
+ using namespace detail;
+
+ /* Iterator over the list of potentially admissible overloads */
+ function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
+ *it = overloads;
+
+ /* Need to know how many arguments + keyword arguments there are to pick the right overload */
+ const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in);
+
+ handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr,
+ result = PYBIND11_TRY_NEXT_OVERLOAD;
+
+ auto self_value_and_holder = value_and_holder();
+ if (overloads->is_constructor) {
+ const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr());
+ const auto pi = reinterpret_cast<instance *>(parent.ptr());
+ self_value_and_holder = pi->get_value_and_holder(tinfo, false);
+
+ if (!self_value_and_holder.type || !self_value_and_holder.inst) {
+ PyErr_SetString(PyExc_TypeError, "__init__(self, ...) called with invalid `self` argument");
+ return nullptr;
+ }
+
+ // If this value is already registered it must mean __init__ is invoked multiple times;
+ // we really can't support that in C++, so just ignore the second __init__.
+ if (self_value_and_holder.instance_registered())
+ return none().release().ptr();
+ }
+
+ try {
+ // We do this in two passes: in the first pass, we load arguments with `convert=false`;
+ // in the second, we allow conversion (except for arguments with an explicit
+ // py::arg().noconvert()). This lets us prefer calls without conversion, with
+ // conversion as a fallback.
+ std::vector<function_call> second_pass;
+
+ // However, if there are no overloads, we can just skip the no-convert pass entirely
+ const bool overloaded = it != nullptr && it->next != nullptr;
+
+ for (; it != nullptr; it = it->next) {
+
+ /* For each overload:
+ 1. Copy all positional arguments we were given, also checking to make sure that
+ named positional arguments weren't *also* specified via kwarg.
+ 2. If we weren't given enough, try to make up the omitted ones by checking
+ whether they were provided by a kwarg matching the `py::arg("name")` name. If
+ so, use it (and remove it from kwargs; if not, see if the function binding
+ provided a default that we can use.
+ 3. Ensure that either all keyword arguments were "consumed", or that the function
+ takes a kwargs argument to accept unconsumed kwargs.
+ 4. Any positional arguments still left get put into a tuple (for args), and any
+ leftover kwargs get put into a dict.
+ 5. Pack everything into a vector; if we have py::args or py::kwargs, they are an
+ extra tuple or dict at the end of the positional arguments.
+ 6. Call the function call dispatcher (function_record::impl)
+
+ If one of these fail, move on to the next overload and keep trying until we get a
+ result other than PYBIND11_TRY_NEXT_OVERLOAD.
+ */
+
+ function_record &func = *it;
+ size_t pos_args = func.nargs; // Number of positional arguments that we need
+ if (func.has_args) --pos_args; // (but don't count py::args
+ if (func.has_kwargs) --pos_args; // or py::kwargs)
+
+ if (!func.has_args && n_args_in > pos_args)
+ continue; // Too many arguments for this overload
+
+ if (n_args_in < pos_args && func.args.size() < pos_args)
+ continue; // Not enough arguments given, and not enough defaults to fill in the blanks
+
+ function_call call(func, parent);
+
+ size_t args_to_copy = std::min(pos_args, n_args_in);
+ size_t args_copied = 0;
+
+ // 0. Inject new-style `self` argument
+ if (func.is_new_style_constructor) {
+ // The `value` may have been preallocated by an old-style `__init__`
+ // if it was a preceding candidate for overload resolution.
+ if (self_value_and_holder)
+ self_value_and_holder.type->dealloc(self_value_and_holder);
+
+ call.init_self = PyTuple_GET_ITEM(args_in, 0);
+ call.args.push_back(reinterpret_cast<PyObject *>(&self_value_and_holder));
+ call.args_convert.push_back(false);
+ ++args_copied;
+ }
+
+ // 1. Copy any position arguments given.
+ bool bad_arg = false;
+ for (; args_copied < args_to_copy; ++args_copied) {
+ argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr;
+ if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) {
+ bad_arg = true;
+ break;
+ }
+
+ handle arg(PyTuple_GET_ITEM(args_in, args_copied));
+ if (arg_rec && !arg_rec->none && arg.is_none()) {
+ bad_arg = true;
+ break;
+ }
+ call.args.push_back(arg);
+ call.args_convert.push_back(arg_rec ? arg_rec->convert : true);
+ }
+ if (bad_arg)
+ continue; // Maybe it was meant for another overload (issue #688)
+
+ // We'll need to copy this if we steal some kwargs for defaults
+ dict kwargs = reinterpret_borrow<dict>(kwargs_in);
+
+ // 2. Check kwargs and, failing that, defaults that may help complete the list
+ if (args_copied < pos_args) {
+ bool copied_kwargs = false;
+
+ for (; args_copied < pos_args; ++args_copied) {
+ const auto &arg = func.args[args_copied];
+
+ handle value;
+ if (kwargs_in && arg.name)
+ value = PyDict_GetItemString(kwargs.ptr(), arg.name);
+
+ if (value) {
+ // Consume a kwargs value
+ if (!copied_kwargs) {
+ kwargs = reinterpret_steal<dict>(PyDict_Copy(kwargs.ptr()));
+ copied_kwargs = true;
+ }
+ PyDict_DelItemString(kwargs.ptr(), arg.name);
+ } else if (arg.value) {
+ value = arg.value;
+ }
+
+ if (value) {
+ call.args.push_back(value);
+ call.args_convert.push_back(arg.convert);
+ }
+ else
+ break;
+ }
+
+ if (args_copied < pos_args)
+ continue; // Not enough arguments, defaults, or kwargs to fill the positional arguments
+ }
+
+ // 3. Check everything was consumed (unless we have a kwargs arg)
+ if (kwargs && kwargs.size() > 0 && !func.has_kwargs)
+ continue; // Unconsumed kwargs, but no py::kwargs argument to accept them
+
+ // 4a. If we have a py::args argument, create a new tuple with leftovers
+ if (func.has_args) {
+ tuple extra_args;
+ if (args_to_copy == 0) {
+ // We didn't copy out any position arguments from the args_in tuple, so we
+ // can reuse it directly without copying:
+ extra_args = reinterpret_borrow<tuple>(args_in);
+ } else if (args_copied >= n_args_in) {
+ extra_args = tuple(0);
+ } else {
+ size_t args_size = n_args_in - args_copied;
+ extra_args = tuple(args_size);
+ for (size_t i = 0; i < args_size; ++i) {
+ extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i);
+ }
+ }
+ call.args.push_back(extra_args);
+ call.args_convert.push_back(false);
+ call.args_ref = std::move(extra_args);
+ }
+
+ // 4b. If we have a py::kwargs, pass on any remaining kwargs
+ if (func.has_kwargs) {
+ if (!kwargs.ptr())
+ kwargs = dict(); // If we didn't get one, send an empty one
+ call.args.push_back(kwargs);
+ call.args_convert.push_back(false);
+ call.kwargs_ref = std::move(kwargs);
+ }
+
+ // 5. Put everything in a vector. Not technically step 5, we've been building it
+ // in `call.args` all along.
+ #if !defined(NDEBUG)
+ if (call.args.size() != func.nargs || call.args_convert.size() != func.nargs)
+ pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!");
+ #endif
+
+ std::vector<bool> second_pass_convert;
+ if (overloaded) {
+ // We're in the first no-convert pass, so swap out the conversion flags for a
+ // set of all-false flags. If the call fails, we'll swap the flags back in for
+ // the conversion-allowed call below.
+ second_pass_convert.resize(func.nargs, false);
+ call.args_convert.swap(second_pass_convert);
+ }
+
+ // 6. Call the function.
+ try {
+ loader_life_support guard{};
+ result = func.impl(call);
+ } catch (reference_cast_error &) {
+ result = PYBIND11_TRY_NEXT_OVERLOAD;
+ }
+
+ if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
+ break;
+
+ if (overloaded) {
+ // The (overloaded) call failed; if the call has at least one argument that
+ // permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`)
+ // then add this call to the list of second pass overloads to try.
+ for (size_t i = func.is_method ? 1 : 0; i < pos_args; i++) {
+ if (second_pass_convert[i]) {
+ // Found one: swap the converting flags back in and store the call for
+ // the second pass.
+ call.args_convert.swap(second_pass_convert);
+ second_pass.push_back(std::move(call));
+ break;
+ }
+ }
+ }
+ }
+
+ if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
+ // The no-conversion pass finished without success, try again with conversion allowed
+ for (auto &call : second_pass) {
+ try {
+ loader_life_support guard{};
+ result = call.func.impl(call);
+ } catch (reference_cast_error &) {
+ result = PYBIND11_TRY_NEXT_OVERLOAD;
+ }
+
+ if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
+ break;
+ }
+ }
+ } catch (error_already_set &e) {
+ e.restore();
+ return nullptr;
+ } catch (...) {
+ /* When an exception is caught, give each registered exception
+ translator a chance to translate it to a Python exception
+ in reverse order of registration.
+
+ A translator may choose to do one of the following:
+
+ - catch the exception and call PyErr_SetString or PyErr_SetObject
+ to set a standard (or custom) Python exception, or
+ - do nothing and let the exception fall through to the next translator, or
+ - delegate translation to the next translator by throwing a new type of exception. */
+
+ auto last_exception = std::current_exception();
+ auto &registered_exception_translators = get_internals().registered_exception_translators;
+ for (auto& translator : registered_exception_translators) {
+ try {
+ translator(last_exception);
+ } catch (...) {
+ last_exception = std::current_exception();
+ continue;
+ }
+ return nullptr;
+ }
+ PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!");
+ return nullptr;
+ }
+
+ auto append_note_if_missing_header_is_suspected = [](std::string &msg) {
+ if (msg.find("std::") != std::string::npos) {
+ msg += "\n\n"
+ "Did you forget to `#include <pybind11/stl.h>`? Or <pybind11/complex.h>,\n"
+ "<pybind11/functional.h>, <pybind11/chrono.h>, etc. Some automatic\n"
+ "conversions are optional and require extra headers to be included\n"
+ "when compiling your pybind11 module.";
+ }
+ };
+
+ if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
+ if (overloads->is_operator)
+ return handle(Py_NotImplemented).inc_ref().ptr();
+
+ std::string msg = std::string(overloads->name) + "(): incompatible " +
+ std::string(overloads->is_constructor ? "constructor" : "function") +
+ " arguments. The following argument types are supported:\n";
+
+ int ctr = 0;
+ for (function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
+ msg += " "+ std::to_string(++ctr) + ". ";
+
+ bool wrote_sig = false;
+ if (overloads->is_constructor) {
+ // For a constructor, rewrite `(self: Object, arg0, ...) -> NoneType` as `Object(arg0, ...)`
+ std::string sig = it2->signature;
+ size_t start = sig.find('(') + 7; // skip "(self: "
+ if (start < sig.size()) {
+ // End at the , for the next argument
+ size_t end = sig.find(", "), next = end + 2;
+ size_t ret = sig.rfind(" -> ");
+ // Or the ), if there is no comma:
+ if (end >= sig.size()) next = end = sig.find(')');
+ if (start < end && next < sig.size()) {
+ msg.append(sig, start, end - start);
+ msg += '(';
+ msg.append(sig, next, ret - next);
+ wrote_sig = true;
+ }
+ }
+ }
+ if (!wrote_sig) msg += it2->signature;
+
+ msg += "\n";
+ }
+ msg += "\nInvoked with: ";
+ auto args_ = reinterpret_borrow<tuple>(args_in);
+ bool some_args = false;
+ for (size_t ti = overloads->is_constructor ? 1 : 0; ti < args_.size(); ++ti) {
+ if (!some_args) some_args = true;
+ else msg += ", ";
+ msg += pybind11::repr(args_[ti]);
+ }
+ if (kwargs_in) {
+ auto kwargs = reinterpret_borrow<dict>(kwargs_in);
+ if (kwargs.size() > 0) {
+ if (some_args) msg += "; ";
+ msg += "kwargs: ";
+ bool first = true;
+ for (auto kwarg : kwargs) {
+ if (first) first = false;
+ else msg += ", ";
+ msg += pybind11::str("{}={!r}").format(kwarg.first, kwarg.second);
+ }
+ }
+ }
+
+ append_note_if_missing_header_is_suspected(msg);
+ PyErr_SetString(PyExc_TypeError, msg.c_str());
+ return nullptr;
+ } else if (!result) {
+ std::string msg = "Unable to convert function return value to a "
+ "Python type! The signature was\n\t";
+ msg += it->signature;
+ append_note_if_missing_header_is_suspected(msg);
+ PyErr_SetString(PyExc_TypeError, msg.c_str());
+ return nullptr;
+ } else {
+ if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) {
+ auto *pi = reinterpret_cast<instance *>(parent.ptr());
+ self_value_and_holder.type->init_instance(pi, nullptr);
+ }
+ return result.ptr();
+ }
+ }
+};
+
+/// Wrapper for Python extension modules
+class module : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(module, object, PyModule_Check)
+
+ /// Create a new top-level Python module with the given name and docstring
+ explicit module(const char *name, const char *doc = nullptr) {
+ if (!options::show_user_defined_docstrings()) doc = nullptr;
+#if PY_MAJOR_VERSION >= 3
+ PyModuleDef *def = new PyModuleDef();
+ std::memset(def, 0, sizeof(PyModuleDef));
+ def->m_name = name;
+ def->m_doc = doc;
+ def->m_size = -1;
+ Py_INCREF(def);
+ m_ptr = PyModule_Create(def);
+#else
+ m_ptr = Py_InitModule3(name, nullptr, doc);
+#endif
+ if (m_ptr == nullptr)
+ pybind11_fail("Internal error in module::module()");
+ inc_ref();
+ }
+
+ /** \rst
+ Create Python binding for a new function within the module scope. ``Func``
+ can be a plain C++ function, a function pointer, or a lambda function. For
+ details on the ``Extra&& ... extra`` argument, see section :ref:`extras`.
+ \endrst */
+ template <typename Func, typename... Extra>
+ module &def(const char *name_, Func &&f, const Extra& ... extra) {
+ cpp_function func(std::forward<Func>(f), name(name_), scope(*this),
+ sibling(getattr(*this, name_, none())), extra...);
+ // NB: allow overwriting here because cpp_function sets up a chain with the intention of
+ // overwriting (and has already checked internally that it isn't overwriting non-functions).
+ add_object(name_, func, true /* overwrite */);
+ return *this;
+ }
+
+ /** \rst
+ Create and return a new Python submodule with the given name and docstring.
+ This also works recursively, i.e.
+
+ .. code-block:: cpp
+
+ py::module m("example", "pybind11 example plugin");
+ py::module m2 = m.def_submodule("sub", "A submodule of 'example'");
+ py::module m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'");
+ \endrst */
+ module def_submodule(const char *name, const char *doc = nullptr) {
+ std::string full_name = std::string(PyModule_GetName(m_ptr))
+ + std::string(".") + std::string(name);
+ auto result = reinterpret_borrow<module>(PyImport_AddModule(full_name.c_str()));
+ if (doc && options::show_user_defined_docstrings())
+ result.attr("__doc__") = pybind11::str(doc);
+ attr(name) = result;
+ return result;
+ }
+
+ /// Import and return a module or throws `error_already_set`.
+ static module import(const char *name) {
+ PyObject *obj = PyImport_ImportModule(name);
+ if (!obj)
+ throw error_already_set();
+ return reinterpret_steal<module>(obj);
+ }
+
+ /// Reload the module or throws `error_already_set`.
+ void reload() {
+ PyObject *obj = PyImport_ReloadModule(ptr());
+ if (!obj)
+ throw error_already_set();
+ *this = reinterpret_steal<module>(obj);
+ }
+
+ // Adds an object to the module using the given name. Throws if an object with the given name
+ // already exists.
+ //
+ // overwrite should almost always be false: attempting to overwrite objects that pybind11 has
+ // established will, in most cases, break things.
+ PYBIND11_NOINLINE void add_object(const char *name, handle obj, bool overwrite = false) {
+ if (!overwrite && hasattr(*this, name))
+ pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" +
+ std::string(name) + "\"");
+
+ PyModule_AddObject(ptr(), name, obj.inc_ref().ptr() /* steals a reference */);
+ }
+};
+
+/// \ingroup python_builtins
+/// Return a dictionary representing the global variables in the current execution frame,
+/// or ``__main__.__dict__`` if there is no frame (usually when the interpreter is embedded).
+inline dict globals() {
+ PyObject *p = PyEval_GetGlobals();
+ return reinterpret_borrow<dict>(p ? p : module::import("__main__").attr("__dict__").ptr());
+}
+
+NAMESPACE_BEGIN(detail)
+/// Generic support for creating new Python heap types
+class generic_type : public object {
+ template <typename...> friend class class_;
+public:
+ PYBIND11_OBJECT_DEFAULT(generic_type, object, PyType_Check)
+protected:
+ void initialize(const type_record &rec) {
+ if (rec.scope && hasattr(rec.scope, rec.name))
+ pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) +
+ "\": an object with that name is already defined");
+
+ if (rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type))
+ pybind11_fail("generic_type: type \"" + std::string(rec.name) +
+ "\" is already registered!");
+
+ m_ptr = make_new_python_type(rec);
+
+ /* Register supplemental type information in C++ dict */
+ auto *tinfo = new detail::type_info();
+ tinfo->type = (PyTypeObject *) m_ptr;
+ tinfo->cpptype = rec.type;
+ tinfo->type_size = rec.type_size;
+ tinfo->operator_new = rec.operator_new;
+ tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size);
+ tinfo->init_instance = rec.init_instance;
+ tinfo->dealloc = rec.dealloc;
+ tinfo->simple_type = true;
+ tinfo->simple_ancestors = true;
+ tinfo->default_holder = rec.default_holder;
+ tinfo->module_local = rec.module_local;
+
+ auto &internals = get_internals();
+ auto tindex = std::type_index(*rec.type);
+ tinfo->direct_conversions = &internals.direct_conversions[tindex];
+ if (rec.module_local)
+ registered_local_types_cpp()[tindex] = tinfo;
+ else
+ internals.registered_types_cpp[tindex] = tinfo;
+ internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo };
+
+ if (rec.bases.size() > 1 || rec.multiple_inheritance) {
+ mark_parents_nonsimple(tinfo->type);
+ tinfo->simple_ancestors = false;
+ }
+ else if (rec.bases.size() == 1) {
+ auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr());
+ tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
+ }
+
+ if (rec.module_local) {
+ // Stash the local typeinfo and loader so that external modules can access it.
+ tinfo->module_local_load = &type_caster_generic::local_load;
+ setattr(m_ptr, PYBIND11_MODULE_LOCAL_ID, capsule(tinfo));
+ }
+ }
+
+ /// Helper function which tags all parents of a type using mult. inheritance
+ void mark_parents_nonsimple(PyTypeObject *value) {
+ auto t = reinterpret_borrow<tuple>(value->tp_bases);
+ for (handle h : t) {
+ auto tinfo2 = get_type_info((PyTypeObject *) h.ptr());
+ if (tinfo2)
+ tinfo2->simple_type = false;
+ mark_parents_nonsimple((PyTypeObject *) h.ptr());
+ }
+ }
+
+ void install_buffer_funcs(
+ buffer_info *(*get_buffer)(PyObject *, void *),
+ void *get_buffer_data) {
+ PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr;
+ auto tinfo = detail::get_type_info(&type->ht_type);
+
+ if (!type->ht_type.tp_as_buffer)
+ pybind11_fail(
+ "To be able to register buffer protocol support for the type '" +
+ std::string(tinfo->type->tp_name) +
+ "' the associated class<>(..) invocation must "
+ "include the pybind11::buffer_protocol() annotation!");
+
+ tinfo->get_buffer = get_buffer;
+ tinfo->get_buffer_data = get_buffer_data;
+ }
+
+ void def_property_static_impl(const char *name,
+ handle fget, handle fset,
+ detail::function_record *rec_fget) {
+ const auto is_static = !(rec_fget->is_method && rec_fget->scope);
+ const auto has_doc = rec_fget->doc && pybind11::options::show_user_defined_docstrings();
+
+ auto property = handle((PyObject *) (is_static ? get_internals().static_property_type
+ : &PyProperty_Type));
+ attr(name) = property(fget.ptr() ? fget : none(),
+ fset.ptr() ? fset : none(),
+ /*deleter*/none(),
+ pybind11::str(has_doc ? rec_fget->doc : ""));
+ }
+};
+
+/// Set the pointer to operator new if it exists. The cast is needed because it can be overloaded.
+template <typename T, typename = void_t<decltype(static_cast<void *(*)(size_t)>(T::operator new))>>
+void set_operator_new(type_record *r) { r->operator_new = &T::operator new; }
+
+template <typename> void set_operator_new(...) { }
+
+template <typename T, typename SFINAE = void> struct has_operator_delete : std::false_type { };
+template <typename T> struct has_operator_delete<T, void_t<decltype(static_cast<void (*)(void *)>(T::operator delete))>>
+ : std::true_type { };
+template <typename T, typename SFINAE = void> struct has_operator_delete_size : std::false_type { };
+template <typename T> struct has_operator_delete_size<T, void_t<decltype(static_cast<void (*)(void *, size_t)>(T::operator delete))>>
+ : std::true_type { };
+/// Call class-specific delete if it exists or global otherwise. Can also be an overload set.
+template <typename T, enable_if_t<has_operator_delete<T>::value, int> = 0>
+void call_operator_delete(T *p, size_t) { T::operator delete(p); }
+template <typename T, enable_if_t<!has_operator_delete<T>::value && has_operator_delete_size<T>::value, int> = 0>
+void call_operator_delete(T *p, size_t s) { T::operator delete(p, s); }
+
+inline void call_operator_delete(void *p, size_t) { ::operator delete(p); }
+
+NAMESPACE_END(detail)
+
+/// Given a pointer to a member function, cast it to its `Derived` version.
+/// Forward everything else unchanged.
+template <typename /*Derived*/, typename F>
+auto method_adaptor(F &&f) -> decltype(std::forward<F>(f)) { return std::forward<F>(f); }
+
+template <typename Derived, typename Return, typename Class, typename... Args>
+auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { return pmf; }
+
+template <typename Derived, typename Return, typename Class, typename... Args>
+auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { return pmf; }
+
+template <typename type_, typename... options>
+class class_ : public detail::generic_type {
+ template <typename T> using is_holder = detail::is_holder_type<type_, T>;
+ template <typename T> using is_subtype = detail::is_strict_base_of<type_, T>;
+ template <typename T> using is_base = detail::is_strict_base_of<T, type_>;
+ // struct instead of using here to help MSVC:
+ template <typename T> struct is_valid_class_option :
+ detail::any_of<is_holder<T>, is_subtype<T>, is_base<T>> {};
+
+public:
+ using type = type_;
+ using type_alias = detail::exactly_one_t<is_subtype, void, options...>;
+ constexpr static bool has_alias = !std::is_void<type_alias>::value;
+ using holder_type = detail::exactly_one_t<is_holder, std::unique_ptr<type>, options...>;
+
+ static_assert(detail::all_of<is_valid_class_option<options>...>::value,
+ "Unknown/invalid class_ template parameters provided");
+
+ static_assert(!has_alias || std::is_polymorphic<type>::value,
+ "Cannot use an alias class with a non-polymorphic type");
+
+ PYBIND11_OBJECT(class_, generic_type, PyType_Check)
+
+ template <typename... Extra>
+ class_(handle scope, const char *name, const Extra &... extra) {
+ using namespace detail;
+
+ // MI can only be specified via class_ template options, not constructor parameters
+ static_assert(
+ none_of<is_pyobject<Extra>...>::value || // no base class arguments, or:
+ ( constexpr_sum(is_pyobject<Extra>::value...) == 1 && // Exactly one base
+ constexpr_sum(is_base<options>::value...) == 0 && // no template option bases
+ none_of<std::is_same<multiple_inheritance, Extra>...>::value), // no multiple_inheritance attr
+ "Error: multiple inheritance bases must be specified via class_ template options");
+
+ type_record record;
+ record.scope = scope;
+ record.name = name;
+ record.type = &typeid(type);
+ record.type_size = sizeof(conditional_t<has_alias, type_alias, type>);
+ record.holder_size = sizeof(holder_type);
+ record.init_instance = init_instance;
+ record.dealloc = dealloc;
+ record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
+
+ set_operator_new<type>(&record);
+
+ /* Register base classes specified via template arguments to class_, if any */
+ PYBIND11_EXPAND_SIDE_EFFECTS(add_base<options>(record));
+
+ /* Process optional arguments, if any */
+ process_attributes<Extra...>::init(extra..., &record);
+
+ generic_type::initialize(record);
+
+ if (has_alias) {
+ auto &instances = record.module_local ? registered_local_types_cpp() : get_internals().registered_types_cpp;
+ instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))];
+ }
+ }
+
+ template <typename Base, detail::enable_if_t<is_base<Base>::value, int> = 0>
+ static void add_base(detail::type_record &rec) {
+ rec.add_base(typeid(Base), [](void *src) -> void * {
+ return static_cast<Base *>(reinterpret_cast<type *>(src));
+ });
+ }
+
+ template <typename Base, detail::enable_if_t<!is_base<Base>::value, int> = 0>
+ static void add_base(detail::type_record &) { }
+
+ template <typename Func, typename... Extra>
+ class_ &def(const char *name_, Func&& f, const Extra&... extra) {
+ cpp_function cf(method_adaptor<type>(std::forward<Func>(f)), name(name_), is_method(*this),
+ sibling(getattr(*this, name_, none())), extra...);
+ attr(cf.name()) = cf;
+ return *this;
+ }
+
+ template <typename Func, typename... Extra> class_ &
+ def_static(const char *name_, Func &&f, const Extra&... extra) {
+ static_assert(!std::is_member_function_pointer<Func>::value,
+ "def_static(...) called with a non-static member function pointer");
+ cpp_function cf(std::forward<Func>(f), name(name_), scope(*this),
+ sibling(getattr(*this, name_, none())), extra...);
+ attr(cf.name()) = cf;
+ return *this;
+ }
+
+ template <detail::op_id id, detail::op_type ot, typename L, typename R, typename... Extra>
+ class_ &def(const detail::op_<id, ot, L, R> &op, const Extra&... extra) {
+ op.execute(*this, extra...);
+ return *this;
+ }
+
+ template <detail::op_id id, detail::op_type ot, typename L, typename R, typename... Extra>
+ class_ & def_cast(const detail::op_<id, ot, L, R> &op, const Extra&... extra) {
+ op.execute_cast(*this, extra...);
+ return *this;
+ }
+
+ template <typename... Args, typename... Extra>
+ class_ &def(const detail::initimpl::constructor<Args...> &init, const Extra&... extra) {
+ init.execute(*this, extra...);
+ return *this;
+ }
+
+ template <typename... Args, typename... Extra>
+ class_ &def(const detail::initimpl::alias_constructor<Args...> &init, const Extra&... extra) {
+ init.execute(*this, extra...);
+ return *this;
+ }
+
+ template <typename... Args, typename... Extra>
+ class_ &def(detail::initimpl::factory<Args...> &&init, const Extra&... extra) {
+ std::move(init).execute(*this, extra...);
+ return *this;
+ }
+
+ template <typename... Args, typename... Extra>
+ class_ &def(detail::initimpl::pickle_factory<Args...> &&pf, const Extra &...extra) {
+ std::move(pf).execute(*this, extra...);
+ return *this;
+ }
+
+ template <typename Func> class_& def_buffer(Func &&func) {
+ struct capture { Func func; };
+ capture *ptr = new capture { std::forward<Func>(func) };
+ install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* {
+ detail::make_caster<type> caster;
+ if (!caster.load(obj, false))
+ return nullptr;
+ return new buffer_info(((capture *) ptr)->func(caster));
+ }, ptr);
+ return *this;
+ }
+
+ template <typename Return, typename Class, typename... Args>
+ class_ &def_buffer(Return (Class::*func)(Args...)) {
+ return def_buffer([func] (type &obj) { return (obj.*func)(); });
+ }
+
+ template <typename Return, typename Class, typename... Args>
+ class_ &def_buffer(Return (Class::*func)(Args...) const) {
+ return def_buffer([func] (const type &obj) { return (obj.*func)(); });
+ }
+
+ template <typename C, typename D, typename... Extra>
+ class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) {
+ static_assert(std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
+ cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)),
+ fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this));
+ def_property(name, fget, fset, return_value_policy::reference_internal, extra...);
+ return *this;
+ }
+
+ template <typename C, typename D, typename... Extra>
+ class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) {
+ static_assert(std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
+ cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this));
+ def_property_readonly(name, fget, return_value_policy::reference_internal, extra...);
+ return *this;
+ }
+
+ template <typename D, typename... Extra>
+ class_ &def_readwrite_static(const char *name, D *pm, const Extra& ...extra) {
+ cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)),
+ fset([pm](object, const D &value) { *pm = value; }, scope(*this));
+ def_property_static(name, fget, fset, return_value_policy::reference, extra...);
+ return *this;
+ }
+
+ template <typename D, typename... Extra>
+ class_ &def_readonly_static(const char *name, const D *pm, const Extra& ...extra) {
+ cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this));
+ def_property_readonly_static(name, fget, return_value_policy::reference, extra...);
+ return *this;
+ }
+
+ /// Uses return_value_policy::reference_internal by default
+ template <typename Getter, typename... Extra>
+ class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) {
+ return def_property_readonly(name, cpp_function(method_adaptor<type>(fget)),
+ return_value_policy::reference_internal, extra...);
+ }
+
+ /// Uses cpp_function's return_value_policy by default
+ template <typename... Extra>
+ class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) {
+ return def_property(name, fget, cpp_function(), extra...);
+ }
+
+ /// Uses return_value_policy::reference by default
+ template <typename Getter, typename... Extra>
+ class_ &def_property_readonly_static(const char *name, const Getter &fget, const Extra& ...extra) {
+ return def_property_readonly_static(name, cpp_function(fget), return_value_policy::reference, extra...);
+ }
+
+ /// Uses cpp_function's return_value_policy by default
+ template <typename... Extra>
+ class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) {
+ return def_property_static(name, fget, cpp_function(), extra...);
+ }
+
+ /// Uses return_value_policy::reference_internal by default
+ template <typename Getter, typename Setter, typename... Extra>
+ class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) {
+ return def_property(name, fget, cpp_function(method_adaptor<type>(fset)), extra...);
+ }
+ template <typename Getter, typename... Extra>
+ class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) {
+ return def_property(name, cpp_function(method_adaptor<type>(fget)), fset,
+ return_value_policy::reference_internal, extra...);
+ }
+
+ /// Uses cpp_function's return_value_policy by default
+ template <typename... Extra>
+ class_ &def_property(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) {
+ return def_property_static(name, fget, fset, is_method(*this), extra...);
+ }
+
+ /// Uses return_value_policy::reference by default
+ template <typename Getter, typename... Extra>
+ class_ &def_property_static(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) {
+ return def_property_static(name, cpp_function(fget), fset, return_value_policy::reference, extra...);
+ }
+
+ /// Uses cpp_function's return_value_policy by default
+ template <typename... Extra>
+ class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) {
+ auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset);
+ char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */
+ detail::process_attributes<Extra...>::init(extra..., rec_fget);
+ if (rec_fget->doc && rec_fget->doc != doc_prev) {
+ free(doc_prev);
+ rec_fget->doc = strdup(rec_fget->doc);
+ }
+ if (rec_fset) {
+ doc_prev = rec_fset->doc;
+ detail::process_attributes<Extra...>::init(extra..., rec_fset);
+ if (rec_fset->doc && rec_fset->doc != doc_prev) {
+ free(doc_prev);
+ rec_fset->doc = strdup(rec_fset->doc);
+ }
+ }
+ def_property_static_impl(name, fget, fset, rec_fget);
+ return *this;
+ }
+
+private:
+ /// Initialize holder object, variant 1: object derives from enable_shared_from_this
+ template <typename T>
+ static void init_holder(detail::instance *inst, detail::value_and_holder &v_h,
+ const holder_type * /* unused */, const std::enable_shared_from_this<T> * /* dummy */) {
+ try {
+ auto sh = std::dynamic_pointer_cast<typename holder_type::element_type>(
+ v_h.value_ptr<type>()->shared_from_this());
+ if (sh) {
+ new (&v_h.holder<holder_type>()) holder_type(std::move(sh));
+ v_h.set_holder_constructed();
+ }
+ } catch (const std::bad_weak_ptr &) {}
+
+ if (!v_h.holder_constructed() && inst->owned) {
+ new (&v_h.holder<holder_type>()) holder_type(v_h.value_ptr<type>());
+ v_h.set_holder_constructed();
+ }
+ }
+
+ static void init_holder_from_existing(const detail::value_and_holder &v_h,
+ const holder_type *holder_ptr, std::true_type /*is_copy_constructible*/) {
+ new (&v_h.holder<holder_type>()) holder_type(*reinterpret_cast<const holder_type *>(holder_ptr));
+ }
+
+ static void init_holder_from_existing(const detail::value_and_holder &v_h,
+ const holder_type *holder_ptr, std::false_type /*is_copy_constructible*/) {
+ new (&v_h.holder<holder_type>()) holder_type(std::move(*const_cast<holder_type *>(holder_ptr)));
+ }
+
+ /// Initialize holder object, variant 2: try to construct from existing holder object, if possible
+ static void init_holder(detail::instance *inst, detail::value_and_holder &v_h,
+ const holder_type *holder_ptr, const void * /* dummy -- not enable_shared_from_this<T>) */) {
+ if (holder_ptr) {
+ init_holder_from_existing(v_h, holder_ptr, std::is_copy_constructible<holder_type>());
+ v_h.set_holder_constructed();
+ } else if (inst->owned || detail::always_construct_holder<holder_type>::value) {
+ new (&v_h.holder<holder_type>()) holder_type(v_h.value_ptr<type>());
+ v_h.set_holder_constructed();
+ }
+ }
+
+ /// Performs instance initialization including constructing a holder and registering the known
+ /// instance. Should be called as soon as the `type` value_ptr is set for an instance. Takes an
+ /// optional pointer to an existing holder to use; if not specified and the instance is
+ /// `.owned`, a new holder will be constructed to manage the value pointer.
+ static void init_instance(detail::instance *inst, const void *holder_ptr) {
+ auto v_h = inst->get_value_and_holder(detail::get_type_info(typeid(type)));
+ if (!v_h.instance_registered()) {
+ register_instance(inst, v_h.value_ptr(), v_h.type);
+ v_h.set_instance_registered();
+ }
+ init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr<type>());
+ }
+
+ /// Deallocates an instance; via holder, if constructed; otherwise via operator delete.
+ static void dealloc(detail::value_and_holder &v_h) {
+ if (v_h.holder_constructed()) {
+ v_h.holder<holder_type>().~holder_type();
+ v_h.set_holder_constructed(false);
+ }
+ else {
+ detail::call_operator_delete(v_h.value_ptr<type>(), v_h.type->type_size);
+ }
+ v_h.value_ptr() = nullptr;
+ }
+
+ static detail::function_record *get_function_record(handle h) {
+ h = detail::get_function(h);
+ return h ? (detail::function_record *) reinterpret_borrow<capsule>(PyCFunction_GET_SELF(h.ptr()))
+ : nullptr;
+ }
+};
+
+/// Binds an existing constructor taking arguments Args...
+template <typename... Args> detail::initimpl::constructor<Args...> init() { return {}; }
+/// Like `init<Args...>()`, but the instance is always constructed through the alias class (even
+/// when not inheriting on the Python side).
+template <typename... Args> detail::initimpl::alias_constructor<Args...> init_alias() { return {}; }
+
+/// Binds a factory function as a constructor
+template <typename Func, typename Ret = detail::initimpl::factory<Func>>
+Ret init(Func &&f) { return {std::forward<Func>(f)}; }
+
+/// Dual-argument factory function: the first function is called when no alias is needed, the second
+/// when an alias is needed (i.e. due to python-side inheritance). Arguments must be identical.
+template <typename CFunc, typename AFunc, typename Ret = detail::initimpl::factory<CFunc, AFunc>>
+Ret init(CFunc &&c, AFunc &&a) {
+ return {std::forward<CFunc>(c), std::forward<AFunc>(a)};
+}
+
+/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type
+/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`.
+template <typename GetState, typename SetState>
+detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetState &&s) {
+ return {std::forward<GetState>(g), std::forward<SetState>(s)};
+}
+
+/// Binds C++ enumerations and enumeration classes to Python
+template <typename Type> class enum_ : public class_<Type> {
+public:
+ using class_<Type>::def;
+ using class_<Type>::def_property_readonly_static;
+ using Scalar = typename std::underlying_type<Type>::type;
+
+ template <typename... Extra>
+ enum_(const handle &scope, const char *name, const Extra&... extra)
+ : class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
+
+ constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
+
+ auto m_entries_ptr = m_entries.inc_ref().ptr();
+ def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
+ for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
+ if (pybind11::cast<Type>(kv.second) == value)
+ return pybind11::str("{}.{}").format(name, kv.first);
+ }
+ return pybind11::str("{}.???").format(name);
+ });
+ def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) {
+ dict m;
+ for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
+ m[kv.first] = kv.second;
+ return m;
+ }, return_value_policy::copy);
+ def(init([](Scalar i) { return static_cast<Type>(i); }));
+ def("__int__", [](Type value) { return (Scalar) value; });
+ #if PY_MAJOR_VERSION < 3
+ def("__long__", [](Type value) { return (Scalar) value; });
+ #endif
+ def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
+ def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
+ if (is_arithmetic) {
+ def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
+ def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
+ def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
+ def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
+ }
+ if (std::is_convertible<Type, Scalar>::value) {
+ // Don't provide comparison with the underlying type if the enum isn't convertible,
+ // i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
+ // convert Type to Scalar below anyway because this needs to compile).
+ def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
+ def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
+ if (is_arithmetic) {
+ def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
+ def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
+ def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
+ def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
+ def("__invert__", [](const Type &value) { return ~((Scalar) value); });
+ def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
+ def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
+ def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
+ def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
+ def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
+ def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
+ def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
+ def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
+ def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
+ }
+ }
+ def("__hash__", [](const Type &value) { return (Scalar) value; });
+ // Pickling and unpickling -- needed for use with the 'multiprocessing' module
+ def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); },
+ [](tuple t) { return static_cast<Type>(t[0].cast<Scalar>()); }));
+ }
+
+ /// Export enumeration entries into the parent scope
+ enum_& export_values() {
+ for (const auto &kv : m_entries)
+ m_parent.attr(kv.first) = kv.second;
+ return *this;
+ }
+
+ /// Add an enumeration entry
+ enum_& value(char const* name, Type value) {
+ auto v = pybind11::cast(value, return_value_policy::copy);
+ this->attr(name) = v;
+ m_entries[pybind11::str(name)] = v;
+ return *this;
+ }
+
+private:
+ dict m_entries;
+ handle m_parent;
+};
+
+NAMESPACE_BEGIN(detail)
+
+
+inline void keep_alive_impl(handle nurse, handle patient) {
+ if (!nurse || !patient)
+ pybind11_fail("Could not activate keep_alive!");
+
+ if (patient.is_none() || nurse.is_none())
+ return; /* Nothing to keep alive or nothing to be kept alive by */
+
+ auto tinfo = all_type_info(Py_TYPE(nurse.ptr()));
+ if (!tinfo.empty()) {
+ /* It's a pybind-registered type, so we can store the patient in the
+ * internal list. */
+ add_patient(nurse.ptr(), patient.ptr());
+ }
+ else {
+ /* Fall back to clever approach based on weak references taken from
+ * Boost.Python. This is not used for pybind-registered types because
+ * the objects can be destroyed out-of-order in a GC pass. */
+ cpp_function disable_lifesupport(
+ [patient](handle weakref) { patient.dec_ref(); weakref.dec_ref(); });
+
+ weakref wr(nurse, disable_lifesupport);
+
+ patient.inc_ref(); /* reference patient and leak the weak reference */
+ (void) wr.release();
+ }
+}
+
+PYBIND11_NOINLINE inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) {
+ auto get_arg = [&](size_t n) {
+ if (n == 0)
+ return ret;
+ else if (n == 1 && call.init_self)
+ return call.init_self;
+ else if (n <= call.args.size())
+ return call.args[n - 1];
+ return handle();
+ };
+
+ keep_alive_impl(get_arg(Nurse), get_arg(Patient));
+}
+
+inline std::pair<decltype(internals::registered_types_py)::iterator, bool> all_type_info_get_cache(PyTypeObject *type) {
+ auto res = get_internals().registered_types_py
+#ifdef __cpp_lib_unordered_map_try_emplace
+ .try_emplace(type);
+#else
+ .emplace(type, std::vector<detail::type_info *>());
+#endif
+ if (res.second) {
+ // New cache entry created; set up a weak reference to automatically remove it if the type
+ // gets destroyed:
+ weakref((PyObject *) type, cpp_function([type](handle wr) {
+ get_internals().registered_types_py.erase(type);
+ wr.dec_ref();
+ })).release();
+ }
+
+ return res;
+}
+
+template <typename Iterator, typename Sentinel, bool KeyIterator, return_value_policy Policy>
+struct iterator_state {
+ Iterator it;
+ Sentinel end;
+ bool first_or_done;
+};
+
+NAMESPACE_END(detail)
+
+/// Makes a python iterator from a first and past-the-end C++ InputIterator.
+template <return_value_policy Policy = return_value_policy::reference_internal,
+ typename Iterator,
+ typename Sentinel,
+ typename ValueType = decltype(*std::declval<Iterator>()),
+ typename... Extra>
+iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
+ typedef detail::iterator_state<Iterator, Sentinel, false, Policy> state;
+
+ if (!detail::get_type_info(typeid(state), false)) {
+ class_<state>(handle(), "iterator", pybind11::module_local())
+ .def("__iter__", [](state &s) -> state& { return s; })
+ .def("__next__", [](state &s) -> ValueType {
+ if (!s.first_or_done)
+ ++s.it;
+ else
+ s.first_or_done = false;
+ if (s.it == s.end) {
+ s.first_or_done = true;
+ throw stop_iteration();
+ }
+ return *s.it;
+ }, std::forward<Extra>(extra)..., Policy);
+ }
+
+ return cast(state{first, last, true});
+}
+
+/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a
+/// first and past-the-end InputIterator.
+template <return_value_policy Policy = return_value_policy::reference_internal,
+ typename Iterator,
+ typename Sentinel,
+ typename KeyType = decltype((*std::declval<Iterator>()).first),
+ typename... Extra>
+iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
+ typedef detail::iterator_state<Iterator, Sentinel, true, Policy> state;
+
+ if (!detail::get_type_info(typeid(state), false)) {
+ class_<state>(handle(), "iterator", pybind11::module_local())
+ .def("__iter__", [](state &s) -> state& { return s; })
+ .def("__next__", [](state &s) -> KeyType {
+ if (!s.first_or_done)
+ ++s.it;
+ else
+ s.first_or_done = false;
+ if (s.it == s.end) {
+ s.first_or_done = true;
+ throw stop_iteration();
+ }
+ return (*s.it).first;
+ }, std::forward<Extra>(extra)..., Policy);
+ }
+
+ return cast(state{first, last, true});
+}
+
+/// Makes an iterator over values of an stl container or other container supporting
+/// `std::begin()`/`std::end()`
+template <return_value_policy Policy = return_value_policy::reference_internal,
+ typename Type, typename... Extra> iterator make_iterator(Type &value, Extra&&... extra) {
+ return make_iterator<Policy>(std::begin(value), std::end(value), extra...);
+}
+
+/// Makes an iterator over the keys (`.first`) of a stl map-like container supporting
+/// `std::begin()`/`std::end()`
+template <return_value_policy Policy = return_value_policy::reference_internal,
+ typename Type, typename... Extra> iterator make_key_iterator(Type &value, Extra&&... extra) {
+ return make_key_iterator<Policy>(std::begin(value), std::end(value), extra...);
+}
+
+template <typename InputType, typename OutputType> void implicitly_convertible() {
+ struct set_flag {
+ bool &flag;
+ set_flag(bool &flag) : flag(flag) { flag = true; }
+ ~set_flag() { flag = false; }
+ };
+ auto implicit_caster = [](PyObject *obj, PyTypeObject *type) -> PyObject * {
+ static bool currently_used = false;
+ if (currently_used) // implicit conversions are non-reentrant
+ return nullptr;
+ set_flag flag_helper(currently_used);
+ if (!detail::make_caster<InputType>().load(obj, false))
+ return nullptr;
+ tuple args(1);
+ args[0] = obj;
+ PyObject *result = PyObject_Call((PyObject *) type, args.ptr(), nullptr);
+ if (result == nullptr)
+ PyErr_Clear();
+ return result;
+ };
+
+ if (auto tinfo = detail::get_type_info(typeid(OutputType)))
+ tinfo->implicit_conversions.push_back(implicit_caster);
+ else
+ pybind11_fail("implicitly_convertible: Unable to find type " + type_id<OutputType>());
+}
+
+template <typename ExceptionTranslator>
+void register_exception_translator(ExceptionTranslator&& translator) {
+ detail::get_internals().registered_exception_translators.push_front(
+ std::forward<ExceptionTranslator>(translator));
+}
+
+/**
+ * Wrapper to generate a new Python exception type.
+ *
+ * This should only be used with PyErr_SetString for now.
+ * It is not (yet) possible to use as a py::base.
+ * Template type argument is reserved for future use.
+ */
+template <typename type>
+class exception : public object {
+public:
+ exception(handle scope, const char *name, PyObject *base = PyExc_Exception) {
+ std::string full_name = scope.attr("__name__").cast<std::string>() +
+ std::string(".") + name;
+ m_ptr = PyErr_NewException(const_cast<char *>(full_name.c_str()), base, NULL);
+ if (hasattr(scope, name))
+ pybind11_fail("Error during initialization: multiple incompatible "
+ "definitions with name \"" + std::string(name) + "\"");
+ scope.attr(name) = *this;
+ }
+
+ // Sets the current python exception to this exception object with the given message
+ void operator()(const char *message) {
+ PyErr_SetString(m_ptr, message);
+ }
+};
+
+/**
+ * Registers a Python exception in `m` of the given `name` and installs an exception translator to
+ * translate the C++ exception to the created Python exception using the exceptions what() method.
+ * This is intended for simple exception translations; for more complex translation, register the
+ * exception object and translator directly.
+ */
+template <typename CppException>
+exception<CppException> &register_exception(handle scope,
+ const char *name,
+ PyObject *base = PyExc_Exception) {
+ static exception<CppException> ex(scope, name, base);
+ register_exception_translator([](std::exception_ptr p) {
+ if (!p) return;
+ try {
+ std::rethrow_exception(p);
+ } catch (const CppException &e) {
+ ex(e.what());
+ }
+ });
+ return ex;
+}
+
+NAMESPACE_BEGIN(detail)
+PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
+ auto strings = tuple(args.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ strings[i] = str(args[i]);
+ }
+ auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" ");
+ auto line = sep.attr("join")(strings);
+
+ object file;
+ if (kwargs.contains("file")) {
+ file = kwargs["file"].cast<object>();
+ } else {
+ try {
+ file = module::import("sys").attr("stdout");
+ } catch (const error_already_set &) {
+ /* If print() is called from code that is executed as
+ part of garbage collection during interpreter shutdown,
+ importing 'sys' can fail. Give up rather than crashing the
+ interpreter in this case. */
+ return;
+ }
+ }
+
+ auto write = file.attr("write");
+ write(line);
+ write(kwargs.contains("end") ? kwargs["end"] : cast("\n"));
+
+ if (kwargs.contains("flush") && kwargs["flush"].cast<bool>())
+ file.attr("flush")();
+}
+NAMESPACE_END(detail)
+
+template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
+void print(Args &&...args) {
+ auto c = detail::collect_arguments<policy>(std::forward<Args>(args)...);
+ detail::print(c.args(), c.kwargs());
+}
+
+#if defined(WITH_THREAD) && !defined(PYPY_VERSION)
+
+/* The functions below essentially reproduce the PyGILState_* API using a RAII
+ * pattern, but there are a few important differences:
+ *
+ * 1. When acquiring the GIL from an non-main thread during the finalization
+ * phase, the GILState API blindly terminates the calling thread, which
+ * is often not what is wanted. This API does not do this.
+ *
+ * 2. The gil_scoped_release function can optionally cut the relationship
+ * of a PyThreadState and its associated thread, which allows moving it to
+ * another thread (this is a fairly rare/advanced use case).
+ *
+ * 3. The reference count of an acquired thread state can be controlled. This
+ * can be handy to prevent cases where callbacks issued from an external
+ * thread would otherwise constantly construct and destroy thread state data
+ * structures.
+ *
+ * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an
+ * example which uses features 2 and 3 to migrate the Python thread of
+ * execution to another thread (to run the event loop on the original thread,
+ * in this case).
+ */
+
+class gil_scoped_acquire {
+public:
+ PYBIND11_NOINLINE gil_scoped_acquire() {
+ auto const &internals = detail::get_internals();
+ tstate = (PyThreadState *) PyThread_get_key_value(internals.tstate);
+
+ if (!tstate) {
+ tstate = PyThreadState_New(internals.istate);
+ #if !defined(NDEBUG)
+ if (!tstate)
+ pybind11_fail("scoped_acquire: could not create thread state!");
+ #endif
+ tstate->gilstate_counter = 0;
+ #if PY_MAJOR_VERSION < 3
+ PyThread_delete_key_value(internals.tstate);
+ #endif
+ PyThread_set_key_value(internals.tstate, tstate);
+ } else {
+ release = detail::get_thread_state_unchecked() != tstate;
+ }
+
+ if (release) {
+ /* Work around an annoying assertion in PyThreadState_Swap */
+ #if defined(Py_DEBUG)
+ PyInterpreterState *interp = tstate->interp;
+ tstate->interp = nullptr;
+ #endif
+ PyEval_AcquireThread(tstate);
+ #if defined(Py_DEBUG)
+ tstate->interp = interp;
+ #endif
+ }
+
+ inc_ref();
+ }
+
+ void inc_ref() {
+ ++tstate->gilstate_counter;
+ }
+
+ PYBIND11_NOINLINE void dec_ref() {
+ --tstate->gilstate_counter;
+ #if !defined(NDEBUG)
+ if (detail::get_thread_state_unchecked() != tstate)
+ pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!");
+ if (tstate->gilstate_counter < 0)
+ pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!");
+ #endif
+ if (tstate->gilstate_counter == 0) {
+ #if !defined(NDEBUG)
+ if (!release)
+ pybind11_fail("scoped_acquire::dec_ref(): internal error!");
+ #endif
+ PyThreadState_Clear(tstate);
+ PyThreadState_DeleteCurrent();
+ PyThread_delete_key_value(detail::get_internals().tstate);
+ release = false;
+ }
+ }
+
+ PYBIND11_NOINLINE ~gil_scoped_acquire() {
+ dec_ref();
+ if (release)
+ PyEval_SaveThread();
+ }
+private:
+ PyThreadState *tstate = nullptr;
+ bool release = true;
+};
+
+class gil_scoped_release {
+public:
+ explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) {
+ // `get_internals()` must be called here unconditionally in order to initialize
+ // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an
+ // initialization race could occur as multiple threads try `gil_scoped_acquire`.
+ const auto &internals = detail::get_internals();
+ tstate = PyEval_SaveThread();
+ if (disassoc) {
+ auto key = internals.tstate;
+ #if PY_MAJOR_VERSION < 3
+ PyThread_delete_key_value(key);
+ #else
+ PyThread_set_key_value(key, nullptr);
+ #endif
+ }
+ }
+ ~gil_scoped_release() {
+ if (!tstate)
+ return;
+ PyEval_RestoreThread(tstate);
+ if (disassoc) {
+ auto key = detail::get_internals().tstate;
+ #if PY_MAJOR_VERSION < 3
+ PyThread_delete_key_value(key);
+ #endif
+ PyThread_set_key_value(key, tstate);
+ }
+ }
+private:
+ PyThreadState *tstate;
+ bool disassoc;
+};
+#elif defined(PYPY_VERSION)
+class gil_scoped_acquire {
+ PyGILState_STATE state;
+public:
+ gil_scoped_acquire() { state = PyGILState_Ensure(); }
+ ~gil_scoped_acquire() { PyGILState_Release(state); }
+};
+
+class gil_scoped_release {
+ PyThreadState *state;
+public:
+ gil_scoped_release() { state = PyEval_SaveThread(); }
+ ~gil_scoped_release() { PyEval_RestoreThread(state); }
+};
+#else
+class gil_scoped_acquire { };
+class gil_scoped_release { };
+#endif
+
+error_already_set::~error_already_set() {
+ if (type) {
+ gil_scoped_acquire gil;
+ type.release().dec_ref();
+ value.release().dec_ref();
+ trace.release().dec_ref();
+ }
+}
+
+inline function get_type_overload(const void *this_ptr, const detail::type_info *this_type, const char *name) {
+ handle self = detail::get_object_handle(this_ptr, this_type);
+ if (!self)
+ return function();
+ handle type = self.get_type();
+ auto key = std::make_pair(type.ptr(), name);
+
+ /* Cache functions that aren't overloaded in Python to avoid
+ many costly Python dictionary lookups below */
+ auto &cache = detail::get_internals().inactive_overload_cache;
+ if (cache.find(key) != cache.end())
+ return function();
+
+ function overload = getattr(self, name, function());
+ if (overload.is_cpp_function()) {
+ cache.insert(key);
+ return function();
+ }
+
+ /* Don't call dispatch code if invoked from overridden function.
+ Unfortunately this doesn't work on PyPy. */
+#if !defined(PYPY_VERSION)
+ PyFrameObject *frame = PyThreadState_Get()->frame;
+ if (frame && (std::string) str(frame->f_code->co_name) == name &&
+ frame->f_code->co_argcount > 0) {
+ PyFrame_FastToLocals(frame);
+ PyObject *self_caller = PyDict_GetItem(
+ frame->f_locals, PyTuple_GET_ITEM(frame->f_code->co_varnames, 0));
+ if (self_caller == self.ptr())
+ return function();
+ }
+#else
+ /* PyPy currently doesn't provide a detailed cpyext emulation of
+ frame objects, so we have to emulate this using Python. This
+ is going to be slow..*/
+ dict d; d["self"] = self; d["name"] = pybind11::str(name);
+ PyObject *result = PyRun_String(
+ "import inspect\n"
+ "frame = inspect.currentframe()\n"
+ "if frame is not None:\n"
+ " frame = frame.f_back\n"
+ " if frame is not None and str(frame.f_code.co_name) == name and "
+ "frame.f_code.co_argcount > 0:\n"
+ " self_caller = frame.f_locals[frame.f_code.co_varnames[0]]\n"
+ " if self_caller == self:\n"
+ " self = None\n",
+ Py_file_input, d.ptr(), d.ptr());
+ if (result == nullptr)
+ throw error_already_set();
+ if (d["self"].is_none())
+ return function();
+ Py_DECREF(result);
+#endif
+
+ return overload;
+}
+
+template <class T> function get_overload(const T *this_ptr, const char *name) {
+ auto tinfo = detail::get_type_info(typeid(T));
+ return tinfo ? get_type_overload(this_ptr, tinfo, name) : function();
+}
+
+#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
+ pybind11::gil_scoped_acquire gil; \
+ pybind11::function overload = pybind11::get_overload(static_cast<const cname *>(this), name); \
+ if (overload) { \
+ auto o = overload(__VA_ARGS__); \
+ if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
+ static pybind11::detail::overload_caster_t<ret_type> caster; \
+ return pybind11::detail::cast_ref<ret_type>(std::move(o), caster); \
+ } \
+ else return pybind11::detail::cast_safe<ret_type>(std::move(o)); \
+ } \
+ }
+
+#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
+ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
+ return cname::fn(__VA_ARGS__)
+
+#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \
+ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
+ pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\"");
+
+#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \
+ PYBIND11_OVERLOAD_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
+
+#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \
+ PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
+
+#if defined(_MSC_VER)
+# pragma warning(pop)
+#elif defined(__INTEL_COMPILER)
+/* Leave ignored warnings on */
+#elif defined(__GNUG__) && !defined(__clang__)
+# pragma GCC diagnostic pop
+#endif
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h b/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h
new file mode 100644
index 000000000..d7fa17775
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h
@@ -0,0 +1,1332 @@
+/*
+ pybind11/pytypes.h: Convenience wrapper classes for basic Python types
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+#include "buffer_info.h"
+#include <utility>
+#include <type_traits>
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/* A few forward declarations */
+class handle; class object;
+class str; class iterator;
+struct arg; struct arg_v;
+
+NAMESPACE_BEGIN(detail)
+class args_proxy;
+inline bool isinstance_generic(handle obj, const std::type_info &tp);
+
+// Accessor forward declarations
+template <typename Policy> class accessor;
+namespace accessor_policies {
+ struct obj_attr;
+ struct str_attr;
+ struct generic_item;
+ struct sequence_item;
+ struct list_item;
+ struct tuple_item;
+}
+using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
+using str_attr_accessor = accessor<accessor_policies::str_attr>;
+using item_accessor = accessor<accessor_policies::generic_item>;
+using sequence_accessor = accessor<accessor_policies::sequence_item>;
+using list_accessor = accessor<accessor_policies::list_item>;
+using tuple_accessor = accessor<accessor_policies::tuple_item>;
+
+/// Tag and check to identify a class which implements the Python object API
+class pyobject_tag { };
+template <typename T> using is_pyobject = std::is_base_of<pyobject_tag, remove_reference_t<T>>;
+
+/** \rst
+ A mixin class which adds common functions to `handle`, `object` and various accessors.
+ The only requirement for `Derived` is to implement ``PyObject *Derived::ptr() const``.
+\endrst */
+template <typename Derived>
+class object_api : public pyobject_tag {
+ const Derived &derived() const { return static_cast<const Derived &>(*this); }
+
+public:
+ /** \rst
+ Return an iterator equivalent to calling ``iter()`` in Python. The object
+ must be a collection which supports the iteration protocol.
+ \endrst */
+ iterator begin() const;
+ /// Return a sentinel which ends iteration.
+ iterator end() const;
+
+ /** \rst
+ Return an internal functor to invoke the object's sequence protocol. Casting
+ the returned ``detail::item_accessor`` instance to a `handle` or `object`
+ subclass causes a corresponding call to ``__getitem__``. Assigning a `handle`
+ or `object` subclass causes a call to ``__setitem__``.
+ \endrst */
+ item_accessor operator[](handle key) const;
+ /// See above (the only difference is that they key is provided as a string literal)
+ item_accessor operator[](const char *key) const;
+
+ /** \rst
+ Return an internal functor to access the object's attributes. Casting the
+ returned ``detail::obj_attr_accessor`` instance to a `handle` or `object`
+ subclass causes a corresponding call to ``getattr``. Assigning a `handle`
+ or `object` subclass causes a call to ``setattr``.
+ \endrst */
+ obj_attr_accessor attr(handle key) const;
+ /// See above (the only difference is that they key is provided as a string literal)
+ str_attr_accessor attr(const char *key) const;
+
+ /** \rst
+ Matches * unpacking in Python, e.g. to unpack arguments out of a ``tuple``
+ or ``list`` for a function call. Applying another * to the result yields
+ ** unpacking, e.g. to unpack a dict as function keyword arguments.
+ See :ref:`calling_python_functions`.
+ \endrst */
+ args_proxy operator*() const;
+
+ /// Check if the given item is contained within this object, i.e. ``item in obj``.
+ template <typename T> bool contains(T &&item) const;
+
+ /** \rst
+ Assuming the Python object is a function or implements the ``__call__``
+ protocol, ``operator()`` invokes the underlying function, passing an
+ arbitrary set of parameters. The result is returned as a `object` and
+ may need to be converted back into a Python object using `handle::cast()`.
+
+ When some of the arguments cannot be converted to Python objects, the
+ function will throw a `cast_error` exception. When the Python function
+ call fails, a `error_already_set` exception is thrown.
+ \endrst */
+ template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
+ object operator()(Args &&...args) const;
+ template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
+ PYBIND11_DEPRECATED("call(...) was deprecated in favor of operator()(...)")
+ object call(Args&&... args) const;
+
+ /// Equivalent to ``obj is other`` in Python.
+ bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); }
+ /// Equivalent to ``obj is None`` in Python.
+ bool is_none() const { return derived().ptr() == Py_None; }
+ PYBIND11_DEPRECATED("Use py::str(obj) instead")
+ pybind11::str str() const;
+
+ /// Get or set the object's docstring, i.e. ``obj.__doc__``.
+ str_attr_accessor doc() const;
+
+ /// Return the object's current reference count
+ int ref_count() const { return static_cast<int>(Py_REFCNT(derived().ptr())); }
+ /// Return a handle to the Python type object underlying the instance
+ handle get_type() const;
+};
+
+NAMESPACE_END(detail)
+
+/** \rst
+ Holds a reference to a Python object (no reference counting)
+
+ The `handle` class is a thin wrapper around an arbitrary Python object (i.e. a
+ ``PyObject *`` in Python's C API). It does not perform any automatic reference
+ counting and merely provides a basic C++ interface to various Python API functions.
+
+ .. seealso::
+ The `object` class inherits from `handle` and adds automatic reference
+ counting features.
+\endrst */
+class handle : public detail::object_api<handle> {
+public:
+ /// The default constructor creates a handle with a ``nullptr``-valued pointer
+ handle() = default;
+ /// Creates a ``handle`` from the given raw Python object pointer
+ handle(PyObject *ptr) : m_ptr(ptr) { } // Allow implicit conversion from PyObject*
+
+ /// Return the underlying ``PyObject *`` pointer
+ PyObject *ptr() const { return m_ptr; }
+ PyObject *&ptr() { return m_ptr; }
+
+ /** \rst
+ Manually increase the reference count of the Python object. Usually, it is
+ preferable to use the `object` class which derives from `handle` and calls
+ this function automatically. Returns a reference to itself.
+ \endrst */
+ const handle& inc_ref() const & { Py_XINCREF(m_ptr); return *this; }
+
+ /** \rst
+ Manually decrease the reference count of the Python object. Usually, it is
+ preferable to use the `object` class which derives from `handle` and calls
+ this function automatically. Returns a reference to itself.
+ \endrst */
+ const handle& dec_ref() const & { Py_XDECREF(m_ptr); return *this; }
+
+ /** \rst
+ Attempt to cast the Python object into the given C++ type. A `cast_error`
+ will be throw upon failure.
+ \endrst */
+ template <typename T> T cast() const;
+ /// Return ``true`` when the `handle` wraps a valid Python object
+ explicit operator bool() const { return m_ptr != nullptr; }
+ /** \rst
+ Deprecated: Check that the underlying pointers are the same.
+ Equivalent to ``obj1 is obj2`` in Python.
+ \endrst */
+ PYBIND11_DEPRECATED("Use obj1.is(obj2) instead")
+ bool operator==(const handle &h) const { return m_ptr == h.m_ptr; }
+ PYBIND11_DEPRECATED("Use !obj1.is(obj2) instead")
+ bool operator!=(const handle &h) const { return m_ptr != h.m_ptr; }
+ PYBIND11_DEPRECATED("Use handle::operator bool() instead")
+ bool check() const { return m_ptr != nullptr; }
+protected:
+ PyObject *m_ptr = nullptr;
+};
+
+/** \rst
+ Holds a reference to a Python object (with reference counting)
+
+ Like `handle`, the `object` class is a thin wrapper around an arbitrary Python
+ object (i.e. a ``PyObject *`` in Python's C API). In contrast to `handle`, it
+ optionally increases the object's reference count upon construction, and it
+ *always* decreases the reference count when the `object` instance goes out of
+ scope and is destructed. When using `object` instances consistently, it is much
+ easier to get reference counting right at the first attempt.
+\endrst */
+class object : public handle {
+public:
+ object() = default;
+ PYBIND11_DEPRECATED("Use reinterpret_borrow<object>() or reinterpret_steal<object>()")
+ object(handle h, bool is_borrowed) : handle(h) { if (is_borrowed) inc_ref(); }
+ /// Copy constructor; always increases the reference count
+ object(const object &o) : handle(o) { inc_ref(); }
+ /// Move constructor; steals the object from ``other`` and preserves its reference count
+ object(object &&other) noexcept { m_ptr = other.m_ptr; other.m_ptr = nullptr; }
+ /// Destructor; automatically calls `handle::dec_ref()`
+ ~object() { dec_ref(); }
+
+ /** \rst
+ Resets the internal pointer to ``nullptr`` without without decreasing the
+ object's reference count. The function returns a raw handle to the original
+ Python object.
+ \endrst */
+ handle release() {
+ PyObject *tmp = m_ptr;
+ m_ptr = nullptr;
+ return handle(tmp);
+ }
+
+ object& operator=(const object &other) {
+ other.inc_ref();
+ dec_ref();
+ m_ptr = other.m_ptr;
+ return *this;
+ }
+
+ object& operator=(object &&other) noexcept {
+ if (this != &other) {
+ handle temp(m_ptr);
+ m_ptr = other.m_ptr;
+ other.m_ptr = nullptr;
+ temp.dec_ref();
+ }
+ return *this;
+ }
+
+ // Calling cast() on an object lvalue just copies (via handle::cast)
+ template <typename T> T cast() const &;
+ // Calling on an object rvalue does a move, if needed and/or possible
+ template <typename T> T cast() &&;
+
+protected:
+ // Tags for choosing constructors from raw PyObject *
+ struct borrowed_t { };
+ struct stolen_t { };
+
+ template <typename T> friend T reinterpret_borrow(handle);
+ template <typename T> friend T reinterpret_steal(handle);
+
+public:
+ // Only accessible from derived classes and the reinterpret_* functions
+ object(handle h, borrowed_t) : handle(h) { inc_ref(); }
+ object(handle h, stolen_t) : handle(h) { }
+};
+
+/** \rst
+ Declare that a `handle` or ``PyObject *`` is a certain type and borrow the reference.
+ The target type ``T`` must be `object` or one of its derived classes. The function
+ doesn't do any conversions or checks. It's up to the user to make sure that the
+ target type is correct.
+
+ .. code-block:: cpp
+
+ PyObject *p = PyList_GetItem(obj, index);
+ py::object o = reinterpret_borrow<py::object>(p);
+ // or
+ py::tuple t = reinterpret_borrow<py::tuple>(p); // <-- `p` must be already be a `tuple`
+\endrst */
+template <typename T> T reinterpret_borrow(handle h) { return {h, object::borrowed_t{}}; }
+
+/** \rst
+ Like `reinterpret_borrow`, but steals the reference.
+
+ .. code-block:: cpp
+
+ PyObject *p = PyObject_Str(obj);
+ py::str s = reinterpret_steal<py::str>(p); // <-- `p` must be already be a `str`
+\endrst */
+template <typename T> T reinterpret_steal(handle h) { return {h, object::stolen_t{}}; }
+
+NAMESPACE_BEGIN(detail)
+inline std::string error_string();
+NAMESPACE_END(detail)
+
+/// Fetch and hold an error which was already set in Python. An instance of this is typically
+/// thrown to propagate python-side errors back through C++ which can either be caught manually or
+/// else falls back to the function dispatcher (which then raises the captured error back to
+/// python).
+class error_already_set : public std::runtime_error {
+public:
+ /// Constructs a new exception from the current Python error indicator, if any. The current
+ /// Python error indicator will be cleared.
+ error_already_set() : std::runtime_error(detail::error_string()) {
+ PyErr_Fetch(&type.ptr(), &value.ptr(), &trace.ptr());
+ }
+
+ inline ~error_already_set();
+
+ /// Give the currently-held error back to Python, if any. If there is currently a Python error
+ /// already set it is cleared first. After this call, the current object no longer stores the
+ /// error variables (but the `.what()` string is still available).
+ void restore() { PyErr_Restore(type.release().ptr(), value.release().ptr(), trace.release().ptr()); }
+
+ // Does nothing; provided for backwards compatibility.
+ PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated")
+ void clear() {}
+
+ /// Check if the currently trapped error type matches the given Python exception class (or a
+ /// subclass thereof). May also be passed a tuple to search for any exception class matches in
+ /// the given tuple.
+ bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), type.ptr()); }
+
+private:
+ object type, value, trace;
+};
+
+/** \defgroup python_builtins _
+ Unless stated otherwise, the following C++ functions behave the same
+ as their Python counterparts.
+ */
+
+/** \ingroup python_builtins
+ \rst
+ Return true if ``obj`` is an instance of ``T``. Type ``T`` must be a subclass of
+ `object` or a class which was exposed to Python as ``py::class_<T>``.
+\endrst */
+template <typename T, detail::enable_if_t<std::is_base_of<object, T>::value, int> = 0>
+bool isinstance(handle obj) { return T::check_(obj); }
+
+template <typename T, detail::enable_if_t<!std::is_base_of<object, T>::value, int> = 0>
+bool isinstance(handle obj) { return detail::isinstance_generic(obj, typeid(T)); }
+
+template <> inline bool isinstance<handle>(handle obj) = delete;
+template <> inline bool isinstance<object>(handle obj) { return obj.ptr() != nullptr; }
+
+/// \ingroup python_builtins
+/// Return true if ``obj`` is an instance of the ``type``.
+inline bool isinstance(handle obj, handle type) {
+ const auto result = PyObject_IsInstance(obj.ptr(), type.ptr());
+ if (result == -1)
+ throw error_already_set();
+ return result != 0;
+}
+
+/// \addtogroup python_builtins
+/// @{
+inline bool hasattr(handle obj, handle name) {
+ return PyObject_HasAttr(obj.ptr(), name.ptr()) == 1;
+}
+
+inline bool hasattr(handle obj, const char *name) {
+ return PyObject_HasAttrString(obj.ptr(), name) == 1;
+}
+
+inline object getattr(handle obj, handle name) {
+ PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr());
+ if (!result) { throw error_already_set(); }
+ return reinterpret_steal<object>(result);
+}
+
+inline object getattr(handle obj, const char *name) {
+ PyObject *result = PyObject_GetAttrString(obj.ptr(), name);
+ if (!result) { throw error_already_set(); }
+ return reinterpret_steal<object>(result);
+}
+
+inline object getattr(handle obj, handle name, handle default_) {
+ if (PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr())) {
+ return reinterpret_steal<object>(result);
+ } else {
+ PyErr_Clear();
+ return reinterpret_borrow<object>(default_);
+ }
+}
+
+inline object getattr(handle obj, const char *name, handle default_) {
+ if (PyObject *result = PyObject_GetAttrString(obj.ptr(), name)) {
+ return reinterpret_steal<object>(result);
+ } else {
+ PyErr_Clear();
+ return reinterpret_borrow<object>(default_);
+ }
+}
+
+inline void setattr(handle obj, handle name, handle value) {
+ if (PyObject_SetAttr(obj.ptr(), name.ptr(), value.ptr()) != 0) { throw error_already_set(); }
+}
+
+inline void setattr(handle obj, const char *name, handle value) {
+ if (PyObject_SetAttrString(obj.ptr(), name, value.ptr()) != 0) { throw error_already_set(); }
+}
+
+inline ssize_t hash(handle obj) {
+ auto h = PyObject_Hash(obj.ptr());
+ if (h == -1) { throw error_already_set(); }
+ return h;
+}
+
+/// @} python_builtins
+
+NAMESPACE_BEGIN(detail)
+inline handle get_function(handle value) {
+ if (value) {
+#if PY_MAJOR_VERSION >= 3
+ if (PyInstanceMethod_Check(value.ptr()))
+ value = PyInstanceMethod_GET_FUNCTION(value.ptr());
+ else
+#endif
+ if (PyMethod_Check(value.ptr()))
+ value = PyMethod_GET_FUNCTION(value.ptr());
+ }
+ return value;
+}
+
+// Helper aliases/functions to support implicit casting of values given to python accessors/methods.
+// When given a pyobject, this simply returns the pyobject as-is; for other C++ type, the value goes
+// through pybind11::cast(obj) to convert it to an `object`.
+template <typename T, enable_if_t<is_pyobject<T>::value, int> = 0>
+auto object_or_cast(T &&o) -> decltype(std::forward<T>(o)) { return std::forward<T>(o); }
+// The following casting version is implemented in cast.h:
+template <typename T, enable_if_t<!is_pyobject<T>::value, int> = 0>
+object object_or_cast(T &&o);
+// Match a PyObject*, which we want to convert directly to handle via its converting constructor
+inline handle object_or_cast(PyObject *ptr) { return ptr; }
+
+
+template <typename Policy>
+class accessor : public object_api<accessor<Policy>> {
+ using key_type = typename Policy::key_type;
+
+public:
+ accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { }
+ accessor(const accessor &) = default;
+ accessor(accessor &&) = default;
+
+ // accessor overload required to override default assignment operator (templates are not allowed
+ // to replace default compiler-generated assignments).
+ void operator=(const accessor &a) && { std::move(*this).operator=(handle(a)); }
+ void operator=(const accessor &a) & { operator=(handle(a)); }
+
+ template <typename T> void operator=(T &&value) && {
+ Policy::set(obj, key, object_or_cast(std::forward<T>(value)));
+ }
+ template <typename T> void operator=(T &&value) & {
+ get_cache() = reinterpret_borrow<object>(object_or_cast(std::forward<T>(value)));
+ }
+
+ template <typename T = Policy>
+ PYBIND11_DEPRECATED("Use of obj.attr(...) as bool is deprecated in favor of pybind11::hasattr(obj, ...)")
+ explicit operator enable_if_t<std::is_same<T, accessor_policies::str_attr>::value ||
+ std::is_same<T, accessor_policies::obj_attr>::value, bool>() const {
+ return hasattr(obj, key);
+ }
+ template <typename T = Policy>
+ PYBIND11_DEPRECATED("Use of obj[key] as bool is deprecated in favor of obj.contains(key)")
+ explicit operator enable_if_t<std::is_same<T, accessor_policies::generic_item>::value, bool>() const {
+ return obj.contains(key);
+ }
+
+ operator object() const { return get_cache(); }
+ PyObject *ptr() const { return get_cache().ptr(); }
+ template <typename T> T cast() const { return get_cache().template cast<T>(); }
+
+private:
+ object &get_cache() const {
+ if (!cache) { cache = Policy::get(obj, key); }
+ return cache;
+ }
+
+private:
+ handle obj;
+ key_type key;
+ mutable object cache;
+};
+
+NAMESPACE_BEGIN(accessor_policies)
+struct obj_attr {
+ using key_type = object;
+ static object get(handle obj, handle key) { return getattr(obj, key); }
+ static void set(handle obj, handle key, handle val) { setattr(obj, key, val); }
+};
+
+struct str_attr {
+ using key_type = const char *;
+ static object get(handle obj, const char *key) { return getattr(obj, key); }
+ static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); }
+};
+
+struct generic_item {
+ using key_type = object;
+
+ static object get(handle obj, handle key) {
+ PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr());
+ if (!result) { throw error_already_set(); }
+ return reinterpret_steal<object>(result);
+ }
+
+ static void set(handle obj, handle key, handle val) {
+ if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); }
+ }
+};
+
+struct sequence_item {
+ using key_type = size_t;
+
+ static object get(handle obj, size_t index) {
+ PyObject *result = PySequence_GetItem(obj.ptr(), static_cast<ssize_t>(index));
+ if (!result) { throw error_already_set(); }
+ return reinterpret_steal<object>(result);
+ }
+
+ static void set(handle obj, size_t index, handle val) {
+ // PySequence_SetItem does not steal a reference to 'val'
+ if (PySequence_SetItem(obj.ptr(), static_cast<ssize_t>(index), val.ptr()) != 0) {
+ throw error_already_set();
+ }
+ }
+};
+
+struct list_item {
+ using key_type = size_t;
+
+ static object get(handle obj, size_t index) {
+ PyObject *result = PyList_GetItem(obj.ptr(), static_cast<ssize_t>(index));
+ if (!result) { throw error_already_set(); }
+ return reinterpret_borrow<object>(result);
+ }
+
+ static void set(handle obj, size_t index, handle val) {
+ // PyList_SetItem steals a reference to 'val'
+ if (PyList_SetItem(obj.ptr(), static_cast<ssize_t>(index), val.inc_ref().ptr()) != 0) {
+ throw error_already_set();
+ }
+ }
+};
+
+struct tuple_item {
+ using key_type = size_t;
+
+ static object get(handle obj, size_t index) {
+ PyObject *result = PyTuple_GetItem(obj.ptr(), static_cast<ssize_t>(index));
+ if (!result) { throw error_already_set(); }
+ return reinterpret_borrow<object>(result);
+ }
+
+ static void set(handle obj, size_t index, handle val) {
+ // PyTuple_SetItem steals a reference to 'val'
+ if (PyTuple_SetItem(obj.ptr(), static_cast<ssize_t>(index), val.inc_ref().ptr()) != 0) {
+ throw error_already_set();
+ }
+ }
+};
+NAMESPACE_END(accessor_policies)
+
+/// STL iterator template used for tuple, list, sequence and dict
+template <typename Policy>
+class generic_iterator : public Policy {
+ using It = generic_iterator;
+
+public:
+ using difference_type = ssize_t;
+ using iterator_category = typename Policy::iterator_category;
+ using value_type = typename Policy::value_type;
+ using reference = typename Policy::reference;
+ using pointer = typename Policy::pointer;
+
+ generic_iterator() = default;
+ generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { }
+
+ reference operator*() const { return Policy::dereference(); }
+ reference operator[](difference_type n) const { return *(*this + n); }
+ pointer operator->() const { return **this; }
+
+ It &operator++() { Policy::increment(); return *this; }
+ It operator++(int) { auto copy = *this; Policy::increment(); return copy; }
+ It &operator--() { Policy::decrement(); return *this; }
+ It operator--(int) { auto copy = *this; Policy::decrement(); return copy; }
+ It &operator+=(difference_type n) { Policy::advance(n); return *this; }
+ It &operator-=(difference_type n) { Policy::advance(-n); return *this; }
+
+ friend It operator+(const It &a, difference_type n) { auto copy = a; return copy += n; }
+ friend It operator+(difference_type n, const It &b) { return b + n; }
+ friend It operator-(const It &a, difference_type n) { auto copy = a; return copy -= n; }
+ friend difference_type operator-(const It &a, const It &b) { return a.distance_to(b); }
+
+ friend bool operator==(const It &a, const It &b) { return a.equal(b); }
+ friend bool operator!=(const It &a, const It &b) { return !(a == b); }
+ friend bool operator< (const It &a, const It &b) { return b - a > 0; }
+ friend bool operator> (const It &a, const It &b) { return b < a; }
+ friend bool operator>=(const It &a, const It &b) { return !(a < b); }
+ friend bool operator<=(const It &a, const It &b) { return !(a > b); }
+};
+
+NAMESPACE_BEGIN(iterator_policies)
+/// Quick proxy class needed to implement ``operator->`` for iterators which can't return pointers
+template <typename T>
+struct arrow_proxy {
+ T value;
+
+ arrow_proxy(T &&value) : value(std::move(value)) { }
+ T *operator->() const { return &value; }
+};
+
+/// Lightweight iterator policy using just a simple pointer: see ``PySequence_Fast_ITEMS``
+class sequence_fast_readonly {
+protected:
+ using iterator_category = std::random_access_iterator_tag;
+ using value_type = handle;
+ using reference = const handle;
+ using pointer = arrow_proxy<const handle>;
+
+ sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { }
+
+ reference dereference() const { return *ptr; }
+ void increment() { ++ptr; }
+ void decrement() { --ptr; }
+ void advance(ssize_t n) { ptr += n; }
+ bool equal(const sequence_fast_readonly &b) const { return ptr == b.ptr; }
+ ssize_t distance_to(const sequence_fast_readonly &b) const { return ptr - b.ptr; }
+
+private:
+ PyObject **ptr;
+};
+
+/// Full read and write access using the sequence protocol: see ``detail::sequence_accessor``
+class sequence_slow_readwrite {
+protected:
+ using iterator_category = std::random_access_iterator_tag;
+ using value_type = object;
+ using reference = sequence_accessor;
+ using pointer = arrow_proxy<const sequence_accessor>;
+
+ sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) { }
+
+ reference dereference() const { return {obj, static_cast<size_t>(index)}; }
+ void increment() { ++index; }
+ void decrement() { --index; }
+ void advance(ssize_t n) { index += n; }
+ bool equal(const sequence_slow_readwrite &b) const { return index == b.index; }
+ ssize_t distance_to(const sequence_slow_readwrite &b) const { return index - b.index; }
+
+private:
+ handle obj;
+ ssize_t index;
+};
+
+/// Python's dictionary protocol permits this to be a forward iterator
+class dict_readonly {
+protected:
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = std::pair<handle, handle>;
+ using reference = const value_type;
+ using pointer = arrow_proxy<const value_type>;
+
+ dict_readonly() = default;
+ dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); }
+
+ reference dereference() const { return {key, value}; }
+ void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } }
+ bool equal(const dict_readonly &b) const { return pos == b.pos; }
+
+private:
+ handle obj;
+ PyObject *key, *value;
+ ssize_t pos = -1;
+};
+NAMESPACE_END(iterator_policies)
+
+#if !defined(PYPY_VERSION)
+using tuple_iterator = generic_iterator<iterator_policies::sequence_fast_readonly>;
+using list_iterator = generic_iterator<iterator_policies::sequence_fast_readonly>;
+#else
+using tuple_iterator = generic_iterator<iterator_policies::sequence_slow_readwrite>;
+using list_iterator = generic_iterator<iterator_policies::sequence_slow_readwrite>;
+#endif
+
+using sequence_iterator = generic_iterator<iterator_policies::sequence_slow_readwrite>;
+using dict_iterator = generic_iterator<iterator_policies::dict_readonly>;
+
+inline bool PyIterable_Check(PyObject *obj) {
+ PyObject *iter = PyObject_GetIter(obj);
+ if (iter) {
+ Py_DECREF(iter);
+ return true;
+ } else {
+ PyErr_Clear();
+ return false;
+ }
+}
+
+inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
+
+inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); }
+
+class kwargs_proxy : public handle {
+public:
+ explicit kwargs_proxy(handle h) : handle(h) { }
+};
+
+class args_proxy : public handle {
+public:
+ explicit args_proxy(handle h) : handle(h) { }
+ kwargs_proxy operator*() const { return kwargs_proxy(*this); }
+};
+
+/// Python argument categories (using PEP 448 terms)
+template <typename T> using is_keyword = std::is_base_of<arg, T>;
+template <typename T> using is_s_unpacking = std::is_same<args_proxy, T>; // * unpacking
+template <typename T> using is_ds_unpacking = std::is_same<kwargs_proxy, T>; // ** unpacking
+template <typename T> using is_positional = satisfies_none_of<T,
+ is_keyword, is_s_unpacking, is_ds_unpacking
+>;
+template <typename T> using is_keyword_or_ds = satisfies_any_of<T, is_keyword, is_ds_unpacking>;
+
+// Call argument collector forward declarations
+template <return_value_policy policy = return_value_policy::automatic_reference>
+class simple_collector;
+template <return_value_policy policy = return_value_policy::automatic_reference>
+class unpacking_collector;
+
+NAMESPACE_END(detail)
+
+// TODO: After the deprecated constructors are removed, this macro can be simplified by
+// inheriting ctors: `using Parent::Parent`. It's not an option right now because
+// the `using` statement triggers the parent deprecation warning even if the ctor
+// isn't even used.
+#define PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \
+ public: \
+ PYBIND11_DEPRECATED("Use reinterpret_borrow<"#Name">() or reinterpret_steal<"#Name">()") \
+ Name(handle h, bool is_borrowed) : Parent(is_borrowed ? Parent(h, borrowed_t{}) : Parent(h, stolen_t{})) { } \
+ Name(handle h, borrowed_t) : Parent(h, borrowed_t{}) { } \
+ Name(handle h, stolen_t) : Parent(h, stolen_t{}) { } \
+ PYBIND11_DEPRECATED("Use py::isinstance<py::python_type>(obj) instead") \
+ bool check() const { return m_ptr != nullptr && (bool) CheckFun(m_ptr); } \
+ static bool check_(handle h) { return h.ptr() != nullptr && CheckFun(h.ptr()); }
+
+#define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \
+ PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \
+ /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \
+ Name(const object &o) \
+ : Parent(check_(o) ? o.inc_ref().ptr() : ConvertFun(o.ptr()), stolen_t{}) \
+ { if (!m_ptr) throw error_already_set(); } \
+ Name(object &&o) \
+ : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \
+ { if (!m_ptr) throw error_already_set(); } \
+ template <typename Policy_> \
+ Name(const ::pybind11::detail::accessor<Policy_> &a) : Name(object(a)) { }
+
+#define PYBIND11_OBJECT(Name, Parent, CheckFun) \
+ PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \
+ /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \
+ Name(const object &o) : Parent(o) { } \
+ Name(object &&o) : Parent(std::move(o)) { }
+
+#define PYBIND11_OBJECT_DEFAULT(Name, Parent, CheckFun) \
+ PYBIND11_OBJECT(Name, Parent, CheckFun) \
+ Name() : Parent() { }
+
+/// \addtogroup pytypes
+/// @{
+
+/** \rst
+ Wraps a Python iterator so that it can also be used as a C++ input iterator
+
+ Caveat: copying an iterator does not (and cannot) clone the internal
+ state of the Python iterable. This also applies to the post-increment
+ operator. This iterator should only be used to retrieve the current
+ value using ``operator*()``.
+\endrst */
+class iterator : public object {
+public:
+ using iterator_category = std::input_iterator_tag;
+ using difference_type = ssize_t;
+ using value_type = handle;
+ using reference = const handle;
+ using pointer = const handle *;
+
+ PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
+
+ iterator& operator++() {
+ advance();
+ return *this;
+ }
+
+ iterator operator++(int) {
+ auto rv = *this;
+ advance();
+ return rv;
+ }
+
+ reference operator*() const {
+ if (m_ptr && !value.ptr()) {
+ auto& self = const_cast<iterator &>(*this);
+ self.advance();
+ }
+ return value;
+ }
+
+ pointer operator->() const { operator*(); return &value; }
+
+ /** \rst
+ The value which marks the end of the iteration. ``it == iterator::sentinel()``
+ is equivalent to catching ``StopIteration`` in Python.
+
+ .. code-block:: cpp
+
+ void foo(py::iterator it) {
+ while (it != py::iterator::sentinel()) {
+ // use `*it`
+ ++it;
+ }
+ }
+ \endrst */
+ static iterator sentinel() { return {}; }
+
+ friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); }
+ friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); }
+
+private:
+ void advance() {
+ value = reinterpret_steal<object>(PyIter_Next(m_ptr));
+ if (PyErr_Occurred()) { throw error_already_set(); }
+ }
+
+private:
+ object value = {};
+};
+
+class iterable : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(iterable, object, detail::PyIterable_Check)
+};
+
+class bytes;
+
+class str : public object {
+public:
+ PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str)
+
+ str(const char *c, size_t n)
+ : object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate string object!");
+ }
+
+ // 'explicit' is explicitly omitted from the following constructors to allow implicit conversion to py::str from C++ string-like objects
+ str(const char *c = "")
+ : object(PyUnicode_FromString(c), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate string object!");
+ }
+
+ str(const std::string &s) : str(s.data(), s.size()) { }
+
+ explicit str(const bytes &b);
+
+ /** \rst
+ Return a string representation of the object. This is analogous to
+ the ``str()`` function in Python.
+ \endrst */
+ explicit str(handle h) : object(raw_str(h.ptr()), stolen_t{}) { }
+
+ operator std::string() const {
+ object temp = *this;
+ if (PyUnicode_Check(m_ptr)) {
+ temp = reinterpret_steal<object>(PyUnicode_AsUTF8String(m_ptr));
+ if (!temp)
+ pybind11_fail("Unable to extract string contents! (encoding issue)");
+ }
+ char *buffer;
+ ssize_t length;
+ if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length))
+ pybind11_fail("Unable to extract string contents! (invalid type)");
+ return std::string(buffer, (size_t) length);
+ }
+
+ template <typename... Args>
+ str format(Args &&...args) const {
+ return attr("format")(std::forward<Args>(args)...);
+ }
+
+private:
+ /// Return string representation -- always returns a new reference, even if already a str
+ static PyObject *raw_str(PyObject *op) {
+ PyObject *str_value = PyObject_Str(op);
+#if PY_MAJOR_VERSION < 3
+ if (!str_value) throw error_already_set();
+ PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr);
+ Py_XDECREF(str_value); str_value = unicode;
+#endif
+ return str_value;
+ }
+};
+/// @} pytypes
+
+inline namespace literals {
+/** \rst
+ String literal version of `str`
+ \endrst */
+inline str operator"" _s(const char *s, size_t size) { return {s, size}; }
+}
+
+/// \addtogroup pytypes
+/// @{
+class bytes : public object {
+public:
+ PYBIND11_OBJECT(bytes, object, PYBIND11_BYTES_CHECK)
+
+ // Allow implicit conversion:
+ bytes(const char *c = "")
+ : object(PYBIND11_BYTES_FROM_STRING(c), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate bytes object!");
+ }
+
+ bytes(const char *c, size_t n)
+ : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, (ssize_t) n), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate bytes object!");
+ }
+
+ // Allow implicit conversion:
+ bytes(const std::string &s) : bytes(s.data(), s.size()) { }
+
+ explicit bytes(const pybind11::str &s);
+
+ operator std::string() const {
+ char *buffer;
+ ssize_t length;
+ if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length))
+ pybind11_fail("Unable to extract bytes contents!");
+ return std::string(buffer, (size_t) length);
+ }
+};
+
+inline bytes::bytes(const pybind11::str &s) {
+ object temp = s;
+ if (PyUnicode_Check(s.ptr())) {
+ temp = reinterpret_steal<object>(PyUnicode_AsUTF8String(s.ptr()));
+ if (!temp)
+ pybind11_fail("Unable to extract string contents! (encoding issue)");
+ }
+ char *buffer;
+ ssize_t length;
+ if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length))
+ pybind11_fail("Unable to extract string contents! (invalid type)");
+ auto obj = reinterpret_steal<object>(PYBIND11_BYTES_FROM_STRING_AND_SIZE(buffer, length));
+ if (!obj)
+ pybind11_fail("Could not allocate bytes object!");
+ m_ptr = obj.release().ptr();
+}
+
+inline str::str(const bytes& b) {
+ char *buffer;
+ ssize_t length;
+ if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &buffer, &length))
+ pybind11_fail("Unable to extract bytes contents!");
+ auto obj = reinterpret_steal<object>(PyUnicode_FromStringAndSize(buffer, (ssize_t) length));
+ if (!obj)
+ pybind11_fail("Could not allocate string object!");
+ m_ptr = obj.release().ptr();
+}
+
+class none : public object {
+public:
+ PYBIND11_OBJECT(none, object, detail::PyNone_Check)
+ none() : object(Py_None, borrowed_t{}) { }
+};
+
+class bool_ : public object {
+public:
+ PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool)
+ bool_() : object(Py_False, borrowed_t{}) { }
+ // Allow implicit conversion from and to `bool`:
+ bool_(bool value) : object(value ? Py_True : Py_False, borrowed_t{}) { }
+ operator bool() const { return m_ptr && PyLong_AsLong(m_ptr) != 0; }
+
+private:
+ /// Return the truth value of an object -- always returns a new reference
+ static PyObject *raw_bool(PyObject *op) {
+ const auto value = PyObject_IsTrue(op);
+ if (value == -1) return nullptr;
+ return handle(value ? Py_True : Py_False).inc_ref().ptr();
+ }
+};
+
+NAMESPACE_BEGIN(detail)
+// Converts a value to the given unsigned type. If an error occurs, you get back (Unsigned) -1;
+// otherwise you get back the unsigned long or unsigned long long value cast to (Unsigned).
+// (The distinction is critically important when casting a returned -1 error value to some other
+// unsigned type: (A)-1 != (B)-1 when A and B are unsigned types of different sizes).
+template <typename Unsigned>
+Unsigned as_unsigned(PyObject *o) {
+ if (sizeof(Unsigned) <= sizeof(unsigned long)
+#if PY_VERSION_HEX < 0x03000000
+ || PyInt_Check(o)
+#endif
+ ) {
+ unsigned long v = PyLong_AsUnsignedLong(o);
+ return v == (unsigned long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v;
+ }
+ else {
+ unsigned long long v = PyLong_AsUnsignedLongLong(o);
+ return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v;
+ }
+}
+NAMESPACE_END(detail)
+
+class int_ : public object {
+public:
+ PYBIND11_OBJECT_CVT(int_, object, PYBIND11_LONG_CHECK, PyNumber_Long)
+ int_() : object(PyLong_FromLong(0), stolen_t{}) { }
+ // Allow implicit conversion from C++ integral types:
+ template <typename T,
+ detail::enable_if_t<std::is_integral<T>::value, int> = 0>
+ int_(T value) {
+ if (sizeof(T) <= sizeof(long)) {
+ if (std::is_signed<T>::value)
+ m_ptr = PyLong_FromLong((long) value);
+ else
+ m_ptr = PyLong_FromUnsignedLong((unsigned long) value);
+ } else {
+ if (std::is_signed<T>::value)
+ m_ptr = PyLong_FromLongLong((long long) value);
+ else
+ m_ptr = PyLong_FromUnsignedLongLong((unsigned long long) value);
+ }
+ if (!m_ptr) pybind11_fail("Could not allocate int object!");
+ }
+
+ template <typename T,
+ detail::enable_if_t<std::is_integral<T>::value, int> = 0>
+ operator T() const {
+ return std::is_unsigned<T>::value
+ ? detail::as_unsigned<T>(m_ptr)
+ : sizeof(T) <= sizeof(long)
+ ? (T) PyLong_AsLong(m_ptr)
+ : (T) PYBIND11_LONG_AS_LONGLONG(m_ptr);
+ }
+};
+
+class float_ : public object {
+public:
+ PYBIND11_OBJECT_CVT(float_, object, PyFloat_Check, PyNumber_Float)
+ // Allow implicit conversion from float/double:
+ float_(float value) : object(PyFloat_FromDouble((double) value), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate float object!");
+ }
+ float_(double value = .0) : object(PyFloat_FromDouble((double) value), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate float object!");
+ }
+ operator float() const { return (float) PyFloat_AsDouble(m_ptr); }
+ operator double() const { return (double) PyFloat_AsDouble(m_ptr); }
+};
+
+class weakref : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check)
+ explicit weakref(handle obj, handle callback = {})
+ : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate weak reference!");
+ }
+};
+
+class slice : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(slice, object, PySlice_Check)
+ slice(ssize_t start_, ssize_t stop_, ssize_t step_) {
+ int_ start(start_), stop(stop_), step(step_);
+ m_ptr = PySlice_New(start.ptr(), stop.ptr(), step.ptr());
+ if (!m_ptr) pybind11_fail("Could not allocate slice object!");
+ }
+ bool compute(size_t length, size_t *start, size_t *stop, size_t *step,
+ size_t *slicelength) const {
+ return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr,
+ (ssize_t) length, (ssize_t *) start,
+ (ssize_t *) stop, (ssize_t *) step,
+ (ssize_t *) slicelength) == 0;
+ }
+};
+
+class capsule : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(capsule, object, PyCapsule_CheckExact)
+ PYBIND11_DEPRECATED("Use reinterpret_borrow<capsule>() or reinterpret_steal<capsule>()")
+ capsule(PyObject *ptr, bool is_borrowed) : object(is_borrowed ? object(ptr, borrowed_t{}) : object(ptr, stolen_t{})) { }
+
+ explicit capsule(const void *value, const char *name = nullptr, void (*destructor)(PyObject *) = nullptr)
+ : object(PyCapsule_New(const_cast<void *>(value), name, destructor), stolen_t{}) {
+ if (!m_ptr)
+ pybind11_fail("Could not allocate capsule object!");
+ }
+
+ PYBIND11_DEPRECATED("Please pass a destructor that takes a void pointer as input")
+ capsule(const void *value, void (*destruct)(PyObject *))
+ : object(PyCapsule_New(const_cast<void*>(value), nullptr, destruct), stolen_t{}) {
+ if (!m_ptr)
+ pybind11_fail("Could not allocate capsule object!");
+ }
+
+ capsule(const void *value, void (*destructor)(void *)) {
+ m_ptr = PyCapsule_New(const_cast<void *>(value), nullptr, [](PyObject *o) {
+ auto destructor = reinterpret_cast<void (*)(void *)>(PyCapsule_GetContext(o));
+ void *ptr = PyCapsule_GetPointer(o, nullptr);
+ destructor(ptr);
+ });
+
+ if (!m_ptr)
+ pybind11_fail("Could not allocate capsule object!");
+
+ if (PyCapsule_SetContext(m_ptr, (void *) destructor) != 0)
+ pybind11_fail("Could not set capsule context!");
+ }
+
+ capsule(void (*destructor)()) {
+ m_ptr = PyCapsule_New(reinterpret_cast<void *>(destructor), nullptr, [](PyObject *o) {
+ auto destructor = reinterpret_cast<void (*)()>(PyCapsule_GetPointer(o, nullptr));
+ destructor();
+ });
+
+ if (!m_ptr)
+ pybind11_fail("Could not allocate capsule object!");
+ }
+
+ template <typename T> operator T *() const {
+ auto name = this->name();
+ T * result = static_cast<T *>(PyCapsule_GetPointer(m_ptr, name));
+ if (!result) pybind11_fail("Unable to extract capsule contents!");
+ return result;
+ }
+
+ const char *name() const { return PyCapsule_GetName(m_ptr); }
+};
+
+class tuple : public object {
+public:
+ PYBIND11_OBJECT_CVT(tuple, object, PyTuple_Check, PySequence_Tuple)
+ explicit tuple(size_t size = 0) : object(PyTuple_New((ssize_t) size), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate tuple object!");
+ }
+ size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
+ detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
+ detail::tuple_iterator begin() const { return {*this, 0}; }
+ detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
+};
+
+class dict : public object {
+public:
+ PYBIND11_OBJECT_CVT(dict, object, PyDict_Check, raw_dict)
+ dict() : object(PyDict_New(), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate dict object!");
+ }
+ template <typename... Args,
+ typename = detail::enable_if_t<detail::all_of<detail::is_keyword_or_ds<Args>...>::value>,
+ // MSVC workaround: it can't compile an out-of-line definition, so defer the collector
+ typename collector = detail::deferred_t<detail::unpacking_collector<>, Args...>>
+ explicit dict(Args &&...args) : dict(collector(std::forward<Args>(args)...).kwargs()) { }
+
+ size_t size() const { return (size_t) PyDict_Size(m_ptr); }
+ detail::dict_iterator begin() const { return {*this, 0}; }
+ detail::dict_iterator end() const { return {}; }
+ void clear() const { PyDict_Clear(ptr()); }
+ bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; }
+ bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
+
+private:
+ /// Call the `dict` Python type -- always returns a new reference
+ static PyObject *raw_dict(PyObject *op) {
+ if (PyDict_Check(op))
+ return handle(op).inc_ref().ptr();
+ return PyObject_CallFunctionObjArgs((PyObject *) &PyDict_Type, op, nullptr);
+ }
+};
+
+class sequence : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check)
+ size_t size() const { return (size_t) PySequence_Size(m_ptr); }
+ detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
+ detail::sequence_iterator begin() const { return {*this, 0}; }
+ detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
+};
+
+class list : public object {
+public:
+ PYBIND11_OBJECT_CVT(list, object, PyList_Check, PySequence_List)
+ explicit list(size_t size = 0) : object(PyList_New((ssize_t) size), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate list object!");
+ }
+ size_t size() const { return (size_t) PyList_Size(m_ptr); }
+ detail::list_accessor operator[](size_t index) const { return {*this, index}; }
+ detail::list_iterator begin() const { return {*this, 0}; }
+ detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
+ template <typename T> void append(T &&val) const {
+ PyList_Append(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr());
+ }
+};
+
+class args : public tuple { PYBIND11_OBJECT_DEFAULT(args, tuple, PyTuple_Check) };
+class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) };
+
+class set : public object {
+public:
+ PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New)
+ set() : object(PySet_New(nullptr), stolen_t{}) {
+ if (!m_ptr) pybind11_fail("Could not allocate set object!");
+ }
+ size_t size() const { return (size_t) PySet_Size(m_ptr); }
+ template <typename T> bool add(T &&val) const {
+ return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
+ }
+ void clear() const { PySet_Clear(m_ptr); }
+};
+
+class function : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(function, object, PyCallable_Check)
+ handle cpp_function() const {
+ handle fun = detail::get_function(m_ptr);
+ if (fun && PyCFunction_Check(fun.ptr()))
+ return fun;
+ return handle();
+ }
+ bool is_cpp_function() const { return (bool) cpp_function(); }
+};
+
+class buffer : public object {
+public:
+ PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer)
+
+ buffer_info request(bool writable = false) {
+ int flags = PyBUF_STRIDES | PyBUF_FORMAT;
+ if (writable) flags |= PyBUF_WRITABLE;
+ Py_buffer *view = new Py_buffer();
+ if (PyObject_GetBuffer(m_ptr, view, flags) != 0) {
+ delete view;
+ throw error_already_set();
+ }
+ return buffer_info(view);
+ }
+};
+
+class memoryview : public object {
+public:
+ explicit memoryview(const buffer_info& info) {
+ static Py_buffer buf { };
+ // Py_buffer uses signed sizes, strides and shape!..
+ static std::vector<Py_ssize_t> py_strides { };
+ static std::vector<Py_ssize_t> py_shape { };
+ buf.buf = info.ptr;
+ buf.itemsize = info.itemsize;
+ buf.format = const_cast<char *>(info.format.c_str());
+ buf.ndim = (int) info.ndim;
+ buf.len = info.size;
+ py_strides.clear();
+ py_shape.clear();
+ for (size_t i = 0; i < (size_t) info.ndim; ++i) {
+ py_strides.push_back(info.strides[i]);
+ py_shape.push_back(info.shape[i]);
+ }
+ buf.strides = py_strides.data();
+ buf.shape = py_shape.data();
+ buf.suboffsets = nullptr;
+ buf.readonly = false;
+ buf.internal = nullptr;
+
+ m_ptr = PyMemoryView_FromBuffer(&buf);
+ if (!m_ptr)
+ pybind11_fail("Unable to create memoryview from buffer descriptor");
+ }
+
+ PYBIND11_OBJECT_CVT(memoryview, object, PyMemoryView_Check, PyMemoryView_FromObject)
+};
+/// @} pytypes
+
+/// \addtogroup python_builtins
+/// @{
+inline size_t len(handle h) {
+ ssize_t result = PyObject_Length(h.ptr());
+ if (result < 0)
+ pybind11_fail("Unable to compute length of object");
+ return (size_t) result;
+}
+
+inline str repr(handle h) {
+ PyObject *str_value = PyObject_Repr(h.ptr());
+ if (!str_value) throw error_already_set();
+#if PY_MAJOR_VERSION < 3
+ PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr);
+ Py_XDECREF(str_value); str_value = unicode;
+ if (!str_value) throw error_already_set();
+#endif
+ return reinterpret_steal<str>(str_value);
+}
+
+inline iterator iter(handle obj) {
+ PyObject *result = PyObject_GetIter(obj.ptr());
+ if (!result) { throw error_already_set(); }
+ return reinterpret_steal<iterator>(result);
+}
+/// @} python_builtins
+
+NAMESPACE_BEGIN(detail)
+template <typename D> iterator object_api<D>::begin() const { return iter(derived()); }
+template <typename D> iterator object_api<D>::end() const { return iterator::sentinel(); }
+template <typename D> item_accessor object_api<D>::operator[](handle key) const {
+ return {derived(), reinterpret_borrow<object>(key)};
+}
+template <typename D> item_accessor object_api<D>::operator[](const char *key) const {
+ return {derived(), pybind11::str(key)};
+}
+template <typename D> obj_attr_accessor object_api<D>::attr(handle key) const {
+ return {derived(), reinterpret_borrow<object>(key)};
+}
+template <typename D> str_attr_accessor object_api<D>::attr(const char *key) const {
+ return {derived(), key};
+}
+template <typename D> args_proxy object_api<D>::operator*() const {
+ return args_proxy(derived().ptr());
+}
+template <typename D> template <typename T> bool object_api<D>::contains(T &&item) const {
+ return attr("__contains__")(std::forward<T>(item)).template cast<bool>();
+}
+
+template <typename D>
+pybind11::str object_api<D>::str() const { return pybind11::str(derived()); }
+
+template <typename D>
+str_attr_accessor object_api<D>::doc() const { return attr("__doc__"); }
+
+template <typename D>
+handle object_api<D>::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); }
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h b/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h
new file mode 100644
index 000000000..90eb7ea2e
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h
@@ -0,0 +1,370 @@
+/*
+ pybind11/stl.h: Transparent conversion for STL data types
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <set>
+#include <unordered_set>
+#include <map>
+#include <unordered_map>
+#include <iostream>
+#include <list>
+#include <valarray>
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+#endif
+
+#ifdef __has_include
+// std::optional (but including it in c++14 mode isn't allowed)
+# if defined(PYBIND11_CPP17) && __has_include(<optional>)
+# include <optional>
+# define PYBIND11_HAS_OPTIONAL 1
+# endif
+// std::experimental::optional (but not allowed in c++11 mode)
+# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
+ !__has_include(<optional>))
+# include <experimental/optional>
+# define PYBIND11_HAS_EXP_OPTIONAL 1
+# endif
+// std::variant
+# if defined(PYBIND11_CPP17) && __has_include(<variant>)
+# include <variant>
+# define PYBIND11_HAS_VARIANT 1
+# endif
+#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
+# include <optional>
+# include <variant>
+# define PYBIND11_HAS_OPTIONAL 1
+# define PYBIND11_HAS_VARIANT 1
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
+/// forwarding a container element). Typically used indirect via forwarded_type(), below.
+template <typename T, typename U>
+using forwarded_type = conditional_t<
+ std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
+
+/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
+/// used for forwarding a container's elements.
+template <typename T, typename U>
+forwarded_type<T, U> forward_like(U &&u) {
+ return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
+}
+
+template <typename Type, typename Key> struct set_caster {
+ using type = Type;
+ using key_conv = make_caster<Key>;
+
+ bool load(handle src, bool convert) {
+ if (!isinstance<pybind11::set>(src))
+ return false;
+ auto s = reinterpret_borrow<pybind11::set>(src);
+ value.clear();
+ for (auto entry : s) {
+ key_conv conv;
+ if (!conv.load(entry, convert))
+ return false;
+ value.insert(cast_op<Key &&>(std::move(conv)));
+ }
+ return true;
+ }
+
+ template <typename T>
+ static handle cast(T &&src, return_value_policy policy, handle parent) {
+ pybind11::set s;
+ for (auto &&value : src) {
+ auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
+ if (!value_ || !s.add(value_))
+ return handle();
+ }
+ return s.release();
+ }
+
+ PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]"));
+};
+
+template <typename Type, typename Key, typename Value> struct map_caster {
+ using key_conv = make_caster<Key>;
+ using value_conv = make_caster<Value>;
+
+ bool load(handle src, bool convert) {
+ if (!isinstance<dict>(src))
+ return false;
+ auto d = reinterpret_borrow<dict>(src);
+ value.clear();
+ for (auto it : d) {
+ key_conv kconv;
+ value_conv vconv;
+ if (!kconv.load(it.first.ptr(), convert) ||
+ !vconv.load(it.second.ptr(), convert))
+ return false;
+ value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
+ }
+ return true;
+ }
+
+ template <typename T>
+ static handle cast(T &&src, return_value_policy policy, handle parent) {
+ dict d;
+ for (auto &&kv : src) {
+ auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy, parent));
+ auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy, parent));
+ if (!key || !value)
+ return handle();
+ d[key] = value;
+ }
+ return d.release();
+ }
+
+ PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]"));
+};
+
+template <typename Type, typename Value> struct list_caster {
+ using value_conv = make_caster<Value>;
+
+ bool load(handle src, bool convert) {
+ if (!isinstance<sequence>(src))
+ return false;
+ auto s = reinterpret_borrow<sequence>(src);
+ value.clear();
+ reserve_maybe(s, &value);
+ for (auto it : s) {
+ value_conv conv;
+ if (!conv.load(it, convert))
+ return false;
+ value.push_back(cast_op<Value &&>(std::move(conv)));
+ }
+ return true;
+ }
+
+private:
+ template <typename T = Type,
+ enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
+ void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
+ void reserve_maybe(sequence, void *) { }
+
+public:
+ template <typename T>
+ static handle cast(T &&src, return_value_policy policy, handle parent) {
+ list l(src.size());
+ size_t index = 0;
+ for (auto &&value : src) {
+ auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
+ if (!value_)
+ return handle();
+ PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
+ }
+ return l.release();
+ }
+
+ PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]"));
+};
+
+template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
+ : list_caster<std::vector<Type, Alloc>, Type> { };
+
+template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
+ : list_caster<std::list<Type, Alloc>, Type> { };
+
+template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
+ using value_conv = make_caster<Value>;
+
+private:
+ template <bool R = Resizable>
+ bool require_size(enable_if_t<R, size_t> size) {
+ if (value.size() != size)
+ value.resize(size);
+ return true;
+ }
+ template <bool R = Resizable>
+ bool require_size(enable_if_t<!R, size_t> size) {
+ return size == Size;
+ }
+
+public:
+ bool load(handle src, bool convert) {
+ if (!isinstance<list>(src))
+ return false;
+ auto l = reinterpret_borrow<list>(src);
+ if (!require_size(l.size()))
+ return false;
+ size_t ctr = 0;
+ for (auto it : l) {
+ value_conv conv;
+ if (!conv.load(it, convert))
+ return false;
+ value[ctr++] = cast_op<Value &&>(std::move(conv));
+ }
+ return true;
+ }
+
+ template <typename T>
+ static handle cast(T &&src, return_value_policy policy, handle parent) {
+ list l(src.size());
+ size_t index = 0;
+ for (auto &&value : src) {
+ auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
+ if (!value_)
+ return handle();
+ PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
+ }
+ return l.release();
+ }
+
+ PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
+};
+
+template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
+ : array_caster<std::array<Type, Size>, Type, false, Size> { };
+
+template <typename Type> struct type_caster<std::valarray<Type>>
+ : array_caster<std::valarray<Type>, Type, true> { };
+
+template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
+ : set_caster<std::set<Key, Compare, Alloc>, Key> { };
+
+template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
+ : set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
+
+template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
+ : map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
+
+template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
+ : map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
+
+// This type caster is intended to be used for std::optional and std::experimental::optional
+template<typename T> struct optional_caster {
+ using value_conv = make_caster<typename T::value_type>;
+
+ template <typename T_>
+ static handle cast(T_ &&src, return_value_policy policy, handle parent) {
+ if (!src)
+ return none().inc_ref();
+ return value_conv::cast(*std::forward<T_>(src), policy, parent);
+ }
+
+ bool load(handle src, bool convert) {
+ if (!src) {
+ return false;
+ } else if (src.is_none()) {
+ return true; // default-constructed value is already empty
+ }
+ value_conv inner_caster;
+ if (!inner_caster.load(src, convert))
+ return false;
+
+ value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
+ return true;
+ }
+
+ PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]"));
+};
+
+#if PYBIND11_HAS_OPTIONAL
+template<typename T> struct type_caster<std::optional<T>>
+ : public optional_caster<std::optional<T>> {};
+
+template<> struct type_caster<std::nullopt_t>
+ : public void_caster<std::nullopt_t> {};
+#endif
+
+#if PYBIND11_HAS_EXP_OPTIONAL
+template<typename T> struct type_caster<std::experimental::optional<T>>
+ : public optional_caster<std::experimental::optional<T>> {};
+
+template<> struct type_caster<std::experimental::nullopt_t>
+ : public void_caster<std::experimental::nullopt_t> {};
+#endif
+
+/// Visit a variant and cast any found type to Python
+struct variant_caster_visitor {
+ return_value_policy policy;
+ handle parent;
+
+ using result_type = handle; // required by boost::variant in C++11
+
+ template <typename T>
+ result_type operator()(T &&src) const {
+ return make_caster<T>::cast(std::forward<T>(src), policy, parent);
+ }
+};
+
+/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
+/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
+/// automatically using argument-dependent lookup. Users can provide specializations for other
+/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
+template <template<typename...> class Variant>
+struct visit_helper {
+ template <typename... Args>
+ static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
+ return visit(std::forward<Args>(args)...);
+ }
+};
+
+/// Generic variant caster
+template <typename Variant> struct variant_caster;
+
+template <template<typename...> class V, typename... Ts>
+struct variant_caster<V<Ts...>> {
+ static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
+
+ template <typename U, typename... Us>
+ bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
+ auto caster = make_caster<U>();
+ if (caster.load(src, convert)) {
+ value = cast_op<U>(caster);
+ return true;
+ }
+ return load_alternative(src, convert, type_list<Us...>{});
+ }
+
+ bool load_alternative(handle, bool, type_list<>) { return false; }
+
+ bool load(handle src, bool convert) {
+ // Do a first pass without conversions to improve constructor resolution.
+ // E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
+ // slot of the variant. Without two-pass loading `double` would be filled
+ // because it appears first and a conversion is possible.
+ if (convert && load_alternative(src, false, type_list<Ts...>{}))
+ return true;
+ return load_alternative(src, convert, type_list<Ts...>{});
+ }
+
+ template <typename Variant>
+ static handle cast(Variant &&src, return_value_policy policy, handle parent) {
+ return visit_helper<V>::call(variant_caster_visitor{policy, parent},
+ std::forward<Variant>(src));
+ }
+
+ using Type = V<Ts...>;
+ PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
+};
+
+#if PYBIND11_HAS_VARIANT
+template <typename... Ts>
+struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
+#endif
+NAMESPACE_END(detail)
+
+inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
+ os << (std::string) str(obj);
+ return os;
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h b/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h
new file mode 100644
index 000000000..38dd68f69
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h
@@ -0,0 +1,599 @@
+/*
+ pybind11/std_bind.h: Binding generators for STL data types
+
+ Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+#include "operators.h"
+
+#include <algorithm>
+#include <sstream>
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/* SFINAE helper class used by 'is_comparable */
+template <typename T> struct container_traits {
+ template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
+ template <typename T2> static std::false_type test_comparable(...);
+ template <typename T2> static std::true_type test_value(typename T2::value_type *);
+ template <typename T2> static std::false_type test_value(...);
+ template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
+ template <typename T2> static std::false_type test_pair(...);
+
+ static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
+ static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
+ static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
+ static constexpr const bool is_element = !is_pair && !is_vector;
+};
+
+/* Default: is_comparable -> std::false_type */
+template <typename T, typename SFINAE = void>
+struct is_comparable : std::false_type { };
+
+/* For non-map data structures, check whether operator== can be instantiated */
+template <typename T>
+struct is_comparable<
+ T, enable_if_t<container_traits<T>::is_element &&
+ container_traits<T>::is_comparable>>
+ : std::true_type { };
+
+/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
+template <typename T>
+struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
+ static constexpr const bool value =
+ is_comparable<typename T::value_type>::value;
+};
+
+/* For pairs, recursively check the two data types */
+template <typename T>
+struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
+ static constexpr const bool value =
+ is_comparable<typename T::first_type>::value &&
+ is_comparable<typename T::second_type>::value;
+};
+
+/* Fallback functions */
+template <typename, typename, typename... Args> void vector_if_copy_constructible(const Args &...) { }
+template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
+template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
+template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
+
+template<typename Vector, typename Class_>
+void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
+ cl.def(init<const Vector &>(), "Copy constructor");
+}
+
+template<typename Vector, typename Class_>
+void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
+ using T = typename Vector::value_type;
+
+ cl.def(self == self);
+ cl.def(self != self);
+
+ cl.def("count",
+ [](const Vector &v, const T &x) {
+ return std::count(v.begin(), v.end(), x);
+ },
+ arg("x"),
+ "Return the number of times ``x`` appears in the list"
+ );
+
+ cl.def("remove", [](Vector &v, const T &x) {
+ auto p = std::find(v.begin(), v.end(), x);
+ if (p != v.end())
+ v.erase(p);
+ else
+ throw value_error();
+ },
+ arg("x"),
+ "Remove the first item from the list whose value is x. "
+ "It is an error if there is no such item."
+ );
+
+ cl.def("__contains__",
+ [](const Vector &v, const T &x) {
+ return std::find(v.begin(), v.end(), x) != v.end();
+ },
+ arg("x"),
+ "Return true the container contains ``x``"
+ );
+}
+
+// Vector modifiers -- requires a copyable vector_type:
+// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
+// silly to allow deletion but not insertion, so include them here too.)
+template <typename Vector, typename Class_>
+void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
+ using T = typename Vector::value_type;
+ using SizeType = typename Vector::size_type;
+ using DiffType = typename Vector::difference_type;
+
+ cl.def("append",
+ [](Vector &v, const T &value) { v.push_back(value); },
+ arg("x"),
+ "Add an item to the end of the list");
+
+ cl.def(init([](iterable it) {
+ auto v = std::unique_ptr<Vector>(new Vector());
+ v->reserve(len(it));
+ for (handle h : it)
+ v->push_back(h.cast<T>());
+ return v.release();
+ }));
+
+ cl.def("extend",
+ [](Vector &v, const Vector &src) {
+ v.insert(v.end(), src.begin(), src.end());
+ },
+ arg("L"),
+ "Extend the list by appending all the items in the given list"
+ );
+
+ cl.def("insert",
+ [](Vector &v, SizeType i, const T &x) {
+ if (i > v.size())
+ throw index_error();
+ v.insert(v.begin() + (DiffType) i, x);
+ },
+ arg("i") , arg("x"),
+ "Insert an item at a given position."
+ );
+
+ cl.def("pop",
+ [](Vector &v) {
+ if (v.empty())
+ throw index_error();
+ T t = v.back();
+ v.pop_back();
+ return t;
+ },
+ "Remove and return the last item"
+ );
+
+ cl.def("pop",
+ [](Vector &v, SizeType i) {
+ if (i >= v.size())
+ throw index_error();
+ T t = v[i];
+ v.erase(v.begin() + (DiffType) i);
+ return t;
+ },
+ arg("i"),
+ "Remove and return the item at index ``i``"
+ );
+
+ cl.def("__setitem__",
+ [](Vector &v, SizeType i, const T &t) {
+ if (i >= v.size())
+ throw index_error();
+ v[i] = t;
+ }
+ );
+
+ /// Slicing protocol
+ cl.def("__getitem__",
+ [](const Vector &v, slice slice) -> Vector * {
+ size_t start, stop, step, slicelength;
+
+ if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
+ throw error_already_set();
+
+ Vector *seq = new Vector();
+ seq->reserve((size_t) slicelength);
+
+ for (size_t i=0; i<slicelength; ++i) {
+ seq->push_back(v[start]);
+ start += step;
+ }
+ return seq;
+ },
+ arg("s"),
+ "Retrieve list elements using a slice object"
+ );
+
+ cl.def("__setitem__",
+ [](Vector &v, slice slice, const Vector &value) {
+ size_t start, stop, step, slicelength;
+ if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
+ throw error_already_set();
+
+ if (slicelength != value.size())
+ throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
+
+ for (size_t i=0; i<slicelength; ++i) {
+ v[start] = value[i];
+ start += step;
+ }
+ },
+ "Assign list elements using a slice object"
+ );
+
+ cl.def("__delitem__",
+ [](Vector &v, SizeType i) {
+ if (i >= v.size())
+ throw index_error();
+ v.erase(v.begin() + DiffType(i));
+ },
+ "Delete the list elements at index ``i``"
+ );
+
+ cl.def("__delitem__",
+ [](Vector &v, slice slice) {
+ size_t start, stop, step, slicelength;
+
+ if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
+ throw error_already_set();
+
+ if (step == 1 && false) {
+ v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
+ } else {
+ for (size_t i = 0; i < slicelength; ++i) {
+ v.erase(v.begin() + DiffType(start));
+ start += step - 1;
+ }
+ }
+ },
+ "Delete list elements using a slice object"
+ );
+
+}
+
+// If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
+// we have to access by copying; otherwise we return by reference.
+template <typename Vector> using vector_needs_copy = negation<
+ std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
+
+// The usual case: access and iterate by reference
+template <typename Vector, typename Class_>
+void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
+ using T = typename Vector::value_type;
+ using SizeType = typename Vector::size_type;
+ using ItType = typename Vector::iterator;
+
+ cl.def("__getitem__",
+ [](Vector &v, SizeType i) -> T & {
+ if (i >= v.size())
+ throw index_error();
+ return v[i];
+ },
+ return_value_policy::reference_internal // ref + keepalive
+ );
+
+ cl.def("__iter__",
+ [](Vector &v) {
+ return make_iterator<
+ return_value_policy::reference_internal, ItType, ItType, T&>(
+ v.begin(), v.end());
+ },
+ keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+ );
+}
+
+// The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
+template <typename Vector, typename Class_>
+void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
+ using T = typename Vector::value_type;
+ using SizeType = typename Vector::size_type;
+ using ItType = typename Vector::iterator;
+ cl.def("__getitem__",
+ [](const Vector &v, SizeType i) -> T {
+ if (i >= v.size())
+ throw index_error();
+ return v[i];
+ }
+ );
+
+ cl.def("__iter__",
+ [](Vector &v) {
+ return make_iterator<
+ return_value_policy::copy, ItType, ItType, T>(
+ v.begin(), v.end());
+ },
+ keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+ );
+}
+
+template <typename Vector, typename Class_> auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
+ -> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
+ using size_type = typename Vector::size_type;
+
+ cl.def("__repr__",
+ [name](Vector &v) {
+ std::ostringstream s;
+ s << name << '[';
+ for (size_type i=0; i < v.size(); ++i) {
+ s << v[i];
+ if (i != v.size() - 1)
+ s << ", ";
+ }
+ s << ']';
+ return s.str();
+ },
+ "Return the canonical string representation of this list."
+ );
+}
+
+// Provide the buffer interface for vectors if we have data() and we have a format for it
+// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
+template <typename Vector, typename = void>
+struct vector_has_data_and_format : std::false_type {};
+template <typename Vector>
+struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
+
+// Add the buffer interface to a vector
+template <typename Vector, typename Class_, typename... Args>
+enable_if_t<detail::any_of<std::is_same<Args, buffer_protocol>...>::value>
+vector_buffer(Class_& cl) {
+ using T = typename Vector::value_type;
+
+ static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
+
+ // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
+ format_descriptor<T>::format();
+
+ cl.def_buffer([](Vector& v) -> buffer_info {
+ return buffer_info(v.data(), static_cast<ssize_t>(sizeof(T)), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
+ });
+
+ cl.def(init([](buffer buf) {
+ auto info = buf.request();
+ if (info.ndim != 1 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
+ throw type_error("Only valid 1D buffers can be copied to a vector");
+ if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
+ throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
+
+ auto vec = std::unique_ptr<Vector>(new Vector());
+ vec->reserve((size_t) info.shape[0]);
+ T *p = static_cast<T*>(info.ptr);
+ ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
+ T *end = p + info.shape[0] * step;
+ for (; p != end; p += step)
+ vec->push_back(*p);
+ return vec.release();
+ }));
+
+ return;
+}
+
+template <typename Vector, typename Class_, typename... Args>
+enable_if_t<!detail::any_of<std::is_same<Args, buffer_protocol>...>::value> vector_buffer(Class_&) {}
+
+NAMESPACE_END(detail)
+
+//
+// std::vector
+//
+template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
+class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
+ using Class_ = class_<Vector, holder_type>;
+
+ // If the value_type is unregistered (e.g. a converting type) or is itself registered
+ // module-local then make the vector binding module-local as well:
+ using vtype = typename Vector::value_type;
+ auto vtype_info = detail::get_type_info(typeid(vtype));
+ bool local = !vtype_info || vtype_info->module_local;
+
+ Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
+
+ // Declare the buffer interface if a buffer_protocol() is passed in
+ detail::vector_buffer<Vector, Class_, Args...>(cl);
+
+ cl.def(init<>());
+
+ // Register copy constructor (if possible)
+ detail::vector_if_copy_constructible<Vector, Class_>(cl);
+
+ // Register comparison-related operators and functions (if possible)
+ detail::vector_if_equal_operator<Vector, Class_>(cl);
+
+ // Register stream insertion operator (if possible)
+ detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
+
+ // Modifiers require copyable vector value type
+ detail::vector_modifiers<Vector, Class_>(cl);
+
+ // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
+ detail::vector_accessor<Vector, Class_>(cl);
+
+ cl.def("__bool__",
+ [](const Vector &v) -> bool {
+ return !v.empty();
+ },
+ "Check whether the list is nonempty"
+ );
+
+ cl.def("__len__", &Vector::size);
+
+
+
+
+#if 0
+ // C++ style functions deprecated, leaving it here as an example
+ cl.def(init<size_type>());
+
+ cl.def("resize",
+ (void (Vector::*) (size_type count)) & Vector::resize,
+ "changes the number of elements stored");
+
+ cl.def("erase",
+ [](Vector &v, SizeType i) {
+ if (i >= v.size())
+ throw index_error();
+ v.erase(v.begin() + i);
+ }, "erases element at index ``i``");
+
+ cl.def("empty", &Vector::empty, "checks whether the container is empty");
+ cl.def("size", &Vector::size, "returns the number of elements");
+ cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
+ cl.def("pop_back", &Vector::pop_back, "removes the last element");
+
+ cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
+ cl.def("reserve", &Vector::reserve, "reserves storage");
+ cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
+ cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
+
+ cl.def("clear", &Vector::clear, "clears the contents");
+ cl.def("swap", &Vector::swap, "swaps the contents");
+
+ cl.def("front", [](Vector &v) {
+ if (v.size()) return v.front();
+ else throw index_error();
+ }, "access the first element");
+
+ cl.def("back", [](Vector &v) {
+ if (v.size()) return v.back();
+ else throw index_error();
+ }, "access the last element ");
+
+#endif
+
+ return cl;
+}
+
+
+
+//
+// std::map, std::unordered_map
+//
+
+NAMESPACE_BEGIN(detail)
+
+/* Fallback functions */
+template <typename, typename, typename... Args> void map_if_insertion_operator(const Args &...) { }
+template <typename, typename, typename... Args> void map_assignment(const Args &...) { }
+
+// Map assignment when copy-assignable: just copy the value
+template <typename Map, typename Class_>
+void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
+ using KeyType = typename Map::key_type;
+ using MappedType = typename Map::mapped_type;
+
+ cl.def("__setitem__",
+ [](Map &m, const KeyType &k, const MappedType &v) {
+ auto it = m.find(k);
+ if (it != m.end()) it->second = v;
+ else m.emplace(k, v);
+ }
+ );
+}
+
+// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
+template<typename Map, typename Class_>
+void map_assignment(enable_if_t<
+ !std::is_copy_assignable<typename Map::mapped_type>::value &&
+ is_copy_constructible<typename Map::mapped_type>::value,
+ Class_> &cl) {
+ using KeyType = typename Map::key_type;
+ using MappedType = typename Map::mapped_type;
+
+ cl.def("__setitem__",
+ [](Map &m, const KeyType &k, const MappedType &v) {
+ // We can't use m[k] = v; because value type might not be default constructable
+ auto r = m.emplace(k, v);
+ if (!r.second) {
+ // value type is not copy assignable so the only way to insert it is to erase it first...
+ m.erase(r.first);
+ m.emplace(k, v);
+ }
+ }
+ );
+}
+
+
+template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
+-> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {
+
+ cl.def("__repr__",
+ [name](Map &m) {
+ std::ostringstream s;
+ s << name << '{';
+ bool f = false;
+ for (auto const &kv : m) {
+ if (f)
+ s << ", ";
+ s << kv.first << ": " << kv.second;
+ f = true;
+ }
+ s << '}';
+ return s.str();
+ },
+ "Return the canonical string representation of this map."
+ );
+}
+
+
+NAMESPACE_END(detail)
+
+template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
+class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
+ using KeyType = typename Map::key_type;
+ using MappedType = typename Map::mapped_type;
+ using Class_ = class_<Map, holder_type>;
+
+ // If either type is a non-module-local bound type then make the map binding non-local as well;
+ // otherwise (e.g. both types are either module-local or converting) the map will be
+ // module-local.
+ auto tinfo = detail::get_type_info(typeid(MappedType));
+ bool local = !tinfo || tinfo->module_local;
+ if (local) {
+ tinfo = detail::get_type_info(typeid(KeyType));
+ local = !tinfo || tinfo->module_local;
+ }
+
+ Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
+
+ cl.def(init<>());
+
+ // Register stream insertion operator (if possible)
+ detail::map_if_insertion_operator<Map, Class_>(cl, name);
+
+ cl.def("__bool__",
+ [](const Map &m) -> bool { return !m.empty(); },
+ "Check whether the map is nonempty"
+ );
+
+ cl.def("__iter__",
+ [](Map &m) { return make_key_iterator(m.begin(), m.end()); },
+ keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+ );
+
+ cl.def("items",
+ [](Map &m) { return make_iterator(m.begin(), m.end()); },
+ keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+ );
+
+ cl.def("__getitem__",
+ [](Map &m, const KeyType &k) -> MappedType & {
+ auto it = m.find(k);
+ if (it == m.end())
+ throw key_error();
+ return it->second;
+ },
+ return_value_policy::reference_internal // ref + keepalive
+ );
+
+ // Assignment provided only if the type is copyable
+ detail::map_assignment<Map, Class_>(cl);
+
+ cl.def("__delitem__",
+ [](Map &m, const KeyType &k) {
+ auto it = m.find(k);
+ if (it == m.end())
+ throw key_error();
+ m.erase(it);
+ }
+ );
+
+ cl.def("__len__", &Map::size);
+
+ return cl;
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake b/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake
new file mode 100644
index 000000000..9d490c5aa
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake
@@ -0,0 +1,57 @@
+# - Find the Catch test framework or download it (single header)
+#
+# This is a quick module for internal use. It assumes that Catch is
+# REQUIRED and that a minimum version is provided (not EXACT). If
+# a suitable version isn't found locally, the single header file
+# will be downloaded and placed in the build dir: PROJECT_BINARY_DIR.
+#
+# This code sets the following variables:
+# CATCH_INCLUDE_DIR - path to catch.hpp
+# CATCH_VERSION - version number
+
+if(NOT Catch_FIND_VERSION)
+ message(FATAL_ERROR "A version number must be specified.")
+elseif(Catch_FIND_REQUIRED)
+ message(FATAL_ERROR "This module assumes Catch is not required.")
+elseif(Catch_FIND_VERSION_EXACT)
+ message(FATAL_ERROR "Exact version numbers are not supported, only minimum.")
+endif()
+
+# Extract the version number from catch.hpp
+function(_get_catch_version)
+ file(STRINGS "${CATCH_INCLUDE_DIR}/catch.hpp" version_line REGEX "Catch v.*" LIMIT_COUNT 1)
+ if(version_line MATCHES "Catch v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+ set(CATCH_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}" PARENT_SCOPE)
+ endif()
+endfunction()
+
+# Download the single-header version of Catch
+function(_download_catch version destination_dir)
+ message(STATUS "Downloading catch v${version}...")
+ set(url https://github.com/philsquared/Catch/releases/download/v${version}/catch.hpp)
+ file(DOWNLOAD ${url} "${destination_dir}/catch.hpp" STATUS status)
+ list(GET status 0 error)
+ if(error)
+ message(FATAL_ERROR "Could not download ${url}")
+ endif()
+ set(CATCH_INCLUDE_DIR "${destination_dir}" CACHE INTERNAL "")
+endfunction()
+
+# Look for catch locally
+find_path(CATCH_INCLUDE_DIR NAMES catch.hpp PATH_SUFFIXES catch)
+if(CATCH_INCLUDE_DIR)
+ _get_catch_version()
+endif()
+
+# Download the header if it wasn't found or if it's outdated
+if(NOT CATCH_VERSION OR CATCH_VERSION VERSION_LESS ${Catch_FIND_VERSION})
+ if(DOWNLOAD_CATCH)
+ _download_catch(${Catch_FIND_VERSION} "${PROJECT_BINARY_DIR}/catch/")
+ _get_catch_version()
+ else()
+ set(CATCH_FOUND FALSE)
+ return()
+ endif()
+endif()
+
+set(CATCH_FOUND TRUE)
diff --git a/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake b/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake
new file mode 100644
index 000000000..9c546a05d
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake
@@ -0,0 +1,81 @@
+# - Try to find Eigen3 lib
+#
+# This module supports requiring a minimum version, e.g. you can do
+# find_package(Eigen3 3.1.2)
+# to require version 3.1.2 or newer of Eigen3.
+#
+# Once done this will define
+#
+# EIGEN3_FOUND - system has eigen lib with correct version
+# EIGEN3_INCLUDE_DIR - the eigen include directory
+# EIGEN3_VERSION - eigen version
+
+# Copyright (c) 2006, 2007 Montel Laurent, <montel@kde.org>
+# Copyright (c) 2008, 2009 Gael Guennebaud, <g.gael@free.fr>
+# Copyright (c) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
+# Redistribution and use is allowed according to the terms of the 2-clause BSD license.
+
+if(NOT Eigen3_FIND_VERSION)
+ if(NOT Eigen3_FIND_VERSION_MAJOR)
+ set(Eigen3_FIND_VERSION_MAJOR 2)
+ endif(NOT Eigen3_FIND_VERSION_MAJOR)
+ if(NOT Eigen3_FIND_VERSION_MINOR)
+ set(Eigen3_FIND_VERSION_MINOR 91)
+ endif(NOT Eigen3_FIND_VERSION_MINOR)
+ if(NOT Eigen3_FIND_VERSION_PATCH)
+ set(Eigen3_FIND_VERSION_PATCH 0)
+ endif(NOT Eigen3_FIND_VERSION_PATCH)
+
+ set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}")
+endif(NOT Eigen3_FIND_VERSION)
+
+macro(_eigen3_check_version)
+ file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header)
+
+ string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}")
+ set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}")
+ string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}")
+ set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}")
+ string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}")
+ set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}")
+
+ set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION})
+ if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
+ set(EIGEN3_VERSION_OK FALSE)
+ else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
+ set(EIGEN3_VERSION_OK TRUE)
+ endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
+
+ if(NOT EIGEN3_VERSION_OK)
+
+ message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, "
+ "but at least version ${Eigen3_FIND_VERSION} is required")
+ endif(NOT EIGEN3_VERSION_OK)
+endmacro(_eigen3_check_version)
+
+if (EIGEN3_INCLUDE_DIR)
+
+ # in cache already
+ _eigen3_check_version()
+ set(EIGEN3_FOUND ${EIGEN3_VERSION_OK})
+
+else (EIGEN3_INCLUDE_DIR)
+
+ find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library
+ PATHS
+ ${CMAKE_INSTALL_PREFIX}/include
+ ${KDE4_INCLUDE_DIR}
+ PATH_SUFFIXES eigen3 eigen
+ )
+
+ if(EIGEN3_INCLUDE_DIR)
+ _eigen3_check_version()
+ endif(EIGEN3_INCLUDE_DIR)
+
+ include(FindPackageHandleStandardArgs)
+ find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK)
+
+ mark_as_advanced(EIGEN3_INCLUDE_DIR)
+
+endif(EIGEN3_INCLUDE_DIR)
+
diff --git a/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake b/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake
new file mode 100644
index 000000000..b29b287de
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake
@@ -0,0 +1,195 @@
+# - Find python libraries
+# This module finds the libraries corresponding to the Python interpreter
+# FindPythonInterp provides.
+# This code sets the following variables:
+#
+# PYTHONLIBS_FOUND - have the Python libs been found
+# PYTHON_PREFIX - path to the Python installation
+# PYTHON_LIBRARIES - path to the python library
+# PYTHON_INCLUDE_DIRS - path to where Python.h is found
+# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd'
+# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string
+# PYTHON_SITE_PACKAGES - path to installation site-packages
+# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build
+#
+# Thanks to talljimbo for the patch adding the 'LDVERSION' config
+# variable usage.
+
+#=============================================================================
+# Copyright 2001-2009 Kitware, Inc.
+# Copyright 2012 Continuum Analytics, Inc.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the names of Kitware, Inc., the Insight Software Consortium,
+# nor the names of their contributors may be used to endorse or promote
+# products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#=============================================================================
+
+# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`.
+if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION)
+ return()
+endif()
+
+# Use the Python interpreter to find the libs.
+if(PythonLibsNew_FIND_REQUIRED)
+ find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED)
+else()
+ find_package(PythonInterp ${PythonLibsNew_FIND_VERSION})
+endif()
+
+if(NOT PYTHONINTERP_FOUND)
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter
+# testing whether sys has the gettotalrefcount function is a reliable, cross-platform
+# way to detect a CPython debug interpreter.
+#
+# The library suffix is from the config var LDVERSION sometimes, otherwise
+# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows.
+execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
+ "from distutils import sysconfig as s;import sys;import struct;
+print('.'.join(str(v) for v in sys.version_info));
+print(sys.prefix);
+print(s.get_python_inc(plat_specific=True));
+print(s.get_python_lib(plat_specific=True));
+print(s.get_config_var('SO'));
+print(hasattr(sys, 'gettotalrefcount')+0);
+print(struct.calcsize('@P'));
+print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION'));
+print(s.get_config_var('LIBDIR') or '');
+print(s.get_config_var('MULTIARCH') or '');
+"
+ RESULT_VARIABLE _PYTHON_SUCCESS
+ OUTPUT_VARIABLE _PYTHON_VALUES
+ ERROR_VARIABLE _PYTHON_ERROR_VALUE)
+
+if(NOT _PYTHON_SUCCESS MATCHES 0)
+ if(PythonLibsNew_FIND_REQUIRED)
+ message(FATAL_ERROR
+ "Python config failure:\n${_PYTHON_ERROR_VALUE}")
+ endif()
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# Convert the process output into a list
+string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
+string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
+list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST)
+list(GET _PYTHON_VALUES 1 PYTHON_PREFIX)
+list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR)
+list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES)
+list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION)
+list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG)
+list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P)
+list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX)
+list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR)
+list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH)
+
+# Make sure the Python has the same pointer-size as the chosen compiler
+# Skip if CMAKE_SIZEOF_VOID_P is not defined
+if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}"))
+ if(PythonLibsNew_FIND_REQUIRED)
+ math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8")
+ math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8")
+ message(FATAL_ERROR
+ "Python config failure: Python is ${_PYTHON_BITS}-bit, "
+ "chosen compiler is ${_CMAKE_BITS}-bit")
+ endif()
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# The built-in FindPython didn't always give the version numbers
+string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST})
+list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR)
+list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR)
+list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH)
+
+# Make sure all directory separators are '/'
+string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX})
+string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR})
+string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES})
+
+if(CMAKE_HOST_WIN32)
+ set(PYTHON_LIBRARY
+ "${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
+
+ # when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the
+ # original python installation. They may be found relative to PYTHON_INCLUDE_DIR.
+ if(NOT EXISTS "${PYTHON_LIBRARY}")
+ get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY)
+ set(PYTHON_LIBRARY
+ "${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
+ endif()
+
+ # raise an error if the python libs are still not found.
+ if(NOT EXISTS "${PYTHON_LIBRARY}")
+ message(FATAL_ERROR "Python libraries not found")
+ endif()
+
+else()
+ if(PYTHON_MULTIARCH)
+ set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}")
+ else()
+ set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}")
+ endif()
+ #message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}")
+ # Probably this needs to be more involved. It would be nice if the config
+ # information the python interpreter itself gave us were more complete.
+ find_library(PYTHON_LIBRARY
+ NAMES "python${PYTHON_LIBRARY_SUFFIX}"
+ PATHS ${_PYTHON_LIBS_SEARCH}
+ NO_DEFAULT_PATH)
+
+ # If all else fails, just set the name/version and let the linker figure out the path.
+ if(NOT PYTHON_LIBRARY)
+ set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX})
+ endif()
+endif()
+
+MARK_AS_ADVANCED(
+ PYTHON_LIBRARY
+ PYTHON_INCLUDE_DIR
+)
+
+# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the
+# cache entries because they are meant to specify the location of a single
+# library. We now set the variables listed by the documentation for this
+# module.
+SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}")
+SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}")
+SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}")
+
+find_package_message(PYTHON
+ "Found PythonLibs: ${PYTHON_LIBRARY}"
+ "${PYTHON_EXECUTABLE}${PYTHON_VERSION}")
+
+set(PYTHONLIBS_FOUND TRUE)
diff --git a/ml/dlib/dlib/external/pybind11/tools/check-style.sh b/ml/dlib/dlib/external/pybind11/tools/check-style.sh
new file mode 100755
index 000000000..0a9f7d24f
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/check-style.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+#
+# Script to check include/test code for common pybind11 code style errors.
+#
+# This script currently checks for
+#
+# 1. use of tabs instead of spaces
+# 2. MSDOS-style CRLF endings
+# 3. trailing spaces
+# 4. missing space between keyword and parenthesis, e.g.: for(, if(, while(
+# 5. Missing space between right parenthesis and brace, e.g. 'for (...){'
+# 6. opening brace on its own line. It should always be on the same line as the
+# if/while/for/do statement.
+#
+# Invoke as: tools/check-style.sh
+#
+
+check_style_errors=0
+IFS=$'\n'
+
+found="$( GREP_COLORS='mt=41' GREP_COLOR='41' grep $'\t' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )"
+if [ -n "$found" ]; then
+ # The mt=41 sets a red background for matched tabs:
+ echo -e '\033[31;01mError: found tab characters in the following files:\033[0m'
+ check_style_errors=1
+ echo "$found" | sed -e 's/^/ /'
+fi
+
+
+found="$( grep -IUlr $'\r' include tests/*.{cpp,py,h} docs/*.rst --color=always )"
+if [ -n "$found" ]; then
+ echo -e '\033[31;01mError: found CRLF characters in the following files:\033[0m'
+ check_style_errors=1
+ echo "$found" | sed -e 's/^/ /'
+fi
+
+found="$(GREP_COLORS='mt=41' GREP_COLOR='41' grep '[[:blank:]]\+$' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )"
+if [ -n "$found" ]; then
+ # The mt=41 sets a red background for matched trailing spaces
+ echo -e '\033[31;01mError: found trailing spaces in the following files:\033[0m'
+ check_style_errors=1
+ echo "$found" | sed -e 's/^/ /'
+fi
+
+found="$(grep '\<\(if\|for\|while\|catch\)(\|){' include tests/*.{cpp,h} -rn --color=always)"
+if [ -n "$found" ]; then
+ echo -e '\033[31;01mError: found the following coding style problems:\033[0m'
+ check_style_errors=1
+ echo "$found" | sed -e 's/^/ /'
+fi
+
+found="$(awk '
+function prefix(filename, lineno) {
+ return " \033[35m" filename "\033[36m:\033[32m" lineno "\033[36m:\033[0m"
+}
+function mark(pattern, string) { sub(pattern, "\033[01;31m&\033[0m", string); return string }
+last && /^\s*{/ {
+ print prefix(FILENAME, FNR-1) mark("\\)\\s*$", last)
+ print prefix(FILENAME, FNR) mark("^\\s*{", $0)
+ last=""
+}
+{ last = /(if|for|while|catch|switch)\s*\(.*\)\s*$/ ? $0 : "" }
+' $(find include -type f) tests/*.{cpp,h} docs/*.rst)"
+if [ -n "$found" ]; then
+ check_style_errors=1
+ echo -e '\033[31;01mError: braces should occur on the same line as the if/while/.. statement. Found issues in the following files:\033[0m'
+ echo "$found"
+fi
+
+exit $check_style_errors
diff --git a/ml/dlib/dlib/external/pybind11/tools/libsize.py b/ml/dlib/dlib/external/pybind11/tools/libsize.py
new file mode 100644
index 000000000..5dcb8b0d0
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/libsize.py
@@ -0,0 +1,38 @@
+from __future__ import print_function, division
+import os
+import sys
+
+# Internal build script for generating debugging test .so size.
+# Usage:
+# python libsize.py file.so save.txt -- displays the size of file.so and, if save.txt exists, compares it to the
+# size in it, then overwrites save.txt with the new size for future runs.
+
+if len(sys.argv) != 3:
+ sys.exit("Invalid arguments: usage: python libsize.py file.so save.txt")
+
+lib = sys.argv[1]
+save = sys.argv[2]
+
+if not os.path.exists(lib):
+ sys.exit("Error: requested file ({}) does not exist".format(lib))
+
+libsize = os.path.getsize(lib)
+
+print("------", os.path.basename(lib), "file size:", libsize, end='')
+
+if os.path.exists(save):
+ with open(save) as sf:
+ oldsize = int(sf.readline())
+
+ if oldsize > 0:
+ change = libsize - oldsize
+ if change == 0:
+ print(" (no change)")
+ else:
+ print(" (change of {:+} bytes = {:+.2%})".format(change, change / oldsize))
+else:
+ print()
+
+with open(save, 'w') as sf:
+ sf.write(str(libsize))
+
diff --git a/ml/dlib/dlib/external/pybind11/tools/mkdoc.py b/ml/dlib/dlib/external/pybind11/tools/mkdoc.py
new file mode 100644
index 000000000..1fd8cceed
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/mkdoc.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+#
+# Syntax: mkdoc.py [-I<path> ..] [.. a list of header files ..]
+#
+# Extract documentation from C++ header files to use it in Python bindings
+#
+
+import os
+import sys
+import platform
+import re
+import textwrap
+
+from clang import cindex
+from clang.cindex import CursorKind
+from collections import OrderedDict
+from threading import Thread, Semaphore
+from multiprocessing import cpu_count
+
+RECURSE_LIST = [
+ CursorKind.TRANSLATION_UNIT,
+ CursorKind.NAMESPACE,
+ CursorKind.CLASS_DECL,
+ CursorKind.STRUCT_DECL,
+ CursorKind.ENUM_DECL,
+ CursorKind.CLASS_TEMPLATE
+]
+
+PRINT_LIST = [
+ CursorKind.CLASS_DECL,
+ CursorKind.STRUCT_DECL,
+ CursorKind.ENUM_DECL,
+ CursorKind.ENUM_CONSTANT_DECL,
+ CursorKind.CLASS_TEMPLATE,
+ CursorKind.FUNCTION_DECL,
+ CursorKind.FUNCTION_TEMPLATE,
+ CursorKind.CONVERSION_FUNCTION,
+ CursorKind.CXX_METHOD,
+ CursorKind.CONSTRUCTOR,
+ CursorKind.FIELD_DECL
+]
+
+CPP_OPERATORS = {
+ '<=': 'le', '>=': 'ge', '==': 'eq', '!=': 'ne', '[]': 'array',
+ '+=': 'iadd', '-=': 'isub', '*=': 'imul', '/=': 'idiv', '%=':
+ 'imod', '&=': 'iand', '|=': 'ior', '^=': 'ixor', '<<=': 'ilshift',
+ '>>=': 'irshift', '++': 'inc', '--': 'dec', '<<': 'lshift', '>>':
+ 'rshift', '&&': 'land', '||': 'lor', '!': 'lnot', '~': 'bnot',
+ '&': 'band', '|': 'bor', '+': 'add', '-': 'sub', '*': 'mul', '/':
+ 'div', '%': 'mod', '<': 'lt', '>': 'gt', '=': 'assign', '()': 'call'
+}
+
+CPP_OPERATORS = OrderedDict(
+ sorted(CPP_OPERATORS.items(), key=lambda t: -len(t[0])))
+
+job_count = cpu_count()
+job_semaphore = Semaphore(job_count)
+
+output = []
+
+def d(s):
+ return s.decode('utf8')
+
+
+def sanitize_name(name):
+ name = re.sub(r'type-parameter-0-([0-9]+)', r'T\1', name)
+ for k, v in CPP_OPERATORS.items():
+ name = name.replace('operator%s' % k, 'operator_%s' % v)
+ name = re.sub('<.*>', '', name)
+ name = ''.join([ch if ch.isalnum() else '_' for ch in name])
+ name = re.sub('_$', '', re.sub('_+', '_', name))
+ return '__doc_' + name
+
+
+def process_comment(comment):
+ result = ''
+
+ # Remove C++ comment syntax
+ leading_spaces = float('inf')
+ for s in comment.expandtabs(tabsize=4).splitlines():
+ s = s.strip()
+ if s.startswith('/*'):
+ s = s[2:].lstrip('*')
+ elif s.endswith('*/'):
+ s = s[:-2].rstrip('*')
+ elif s.startswith('///'):
+ s = s[3:]
+ if s.startswith('*'):
+ s = s[1:]
+ if len(s) > 0:
+ leading_spaces = min(leading_spaces, len(s) - len(s.lstrip()))
+ result += s + '\n'
+
+ if leading_spaces != float('inf'):
+ result2 = ""
+ for s in result.splitlines():
+ result2 += s[leading_spaces:] + '\n'
+ result = result2
+
+ # Doxygen tags
+ cpp_group = '([\w:]+)'
+ param_group = '([\[\w:\]]+)'
+
+ s = result
+ s = re.sub(r'\\c\s+%s' % cpp_group, r'``\1``', s)
+ s = re.sub(r'\\a\s+%s' % cpp_group, r'*\1*', s)
+ s = re.sub(r'\\e\s+%s' % cpp_group, r'*\1*', s)
+ s = re.sub(r'\\em\s+%s' % cpp_group, r'*\1*', s)
+ s = re.sub(r'\\b\s+%s' % cpp_group, r'**\1**', s)
+ s = re.sub(r'\\ingroup\s+%s' % cpp_group, r'', s)
+ s = re.sub(r'\\param%s?\s+%s' % (param_group, cpp_group),
+ r'\n\n$Parameter ``\2``:\n\n', s)
+ s = re.sub(r'\\tparam%s?\s+%s' % (param_group, cpp_group),
+ r'\n\n$Template parameter ``\2``:\n\n', s)
+
+ for in_, out_ in {
+ 'return': 'Returns',
+ 'author': 'Author',
+ 'authors': 'Authors',
+ 'copyright': 'Copyright',
+ 'date': 'Date',
+ 'remark': 'Remark',
+ 'sa': 'See also',
+ 'see': 'See also',
+ 'extends': 'Extends',
+ 'throw': 'Throws',
+ 'throws': 'Throws'
+ }.items():
+ s = re.sub(r'\\%s\s*' % in_, r'\n\n$%s:\n\n' % out_, s)
+
+ s = re.sub(r'\\details\s*', r'\n\n', s)
+ s = re.sub(r'\\brief\s*', r'', s)
+ s = re.sub(r'\\short\s*', r'', s)
+ s = re.sub(r'\\ref\s*', r'', s)
+
+ s = re.sub(r'\\code\s?(.*?)\s?\\endcode',
+ r"```\n\1\n```\n", s, flags=re.DOTALL)
+
+ # HTML/TeX tags
+ s = re.sub(r'<tt>(.*?)</tt>', r'``\1``', s, flags=re.DOTALL)
+ s = re.sub(r'<pre>(.*?)</pre>', r"```\n\1\n```\n", s, flags=re.DOTALL)
+ s = re.sub(r'<em>(.*?)</em>', r'*\1*', s, flags=re.DOTALL)
+ s = re.sub(r'<b>(.*?)</b>', r'**\1**', s, flags=re.DOTALL)
+ s = re.sub(r'\\f\$(.*?)\\f\$', r'$\1$', s, flags=re.DOTALL)
+ s = re.sub(r'<li>', r'\n\n* ', s)
+ s = re.sub(r'</?ul>', r'', s)
+ s = re.sub(r'</li>', r'\n\n', s)
+
+ s = s.replace('``true``', '``True``')
+ s = s.replace('``false``', '``False``')
+
+ # Re-flow text
+ wrapper = textwrap.TextWrapper()
+ wrapper.expand_tabs = True
+ wrapper.replace_whitespace = True
+ wrapper.drop_whitespace = True
+ wrapper.width = 70
+ wrapper.initial_indent = wrapper.subsequent_indent = ''
+
+ result = ''
+ in_code_segment = False
+ for x in re.split(r'(```)', s):
+ if x == '```':
+ if not in_code_segment:
+ result += '```\n'
+ else:
+ result += '\n```\n\n'
+ in_code_segment = not in_code_segment
+ elif in_code_segment:
+ result += x.strip()
+ else:
+ for y in re.split(r'(?: *\n *){2,}', x):
+ wrapped = wrapper.fill(re.sub(r'\s+', ' ', y).strip())
+ if len(wrapped) > 0 and wrapped[0] == '$':
+ result += wrapped[1:] + '\n'
+ wrapper.initial_indent = \
+ wrapper.subsequent_indent = ' ' * 4
+ else:
+ if len(wrapped) > 0:
+ result += wrapped + '\n\n'
+ wrapper.initial_indent = wrapper.subsequent_indent = ''
+ return result.rstrip().lstrip('\n')
+
+
+def extract(filename, node, prefix):
+ if not (node.location.file is None or
+ os.path.samefile(d(node.location.file.name), filename)):
+ return 0
+ if node.kind in RECURSE_LIST:
+ sub_prefix = prefix
+ if node.kind != CursorKind.TRANSLATION_UNIT:
+ if len(sub_prefix) > 0:
+ sub_prefix += '_'
+ sub_prefix += d(node.spelling)
+ for i in node.get_children():
+ extract(filename, i, sub_prefix)
+ if node.kind in PRINT_LIST:
+ comment = d(node.raw_comment) if node.raw_comment is not None else ''
+ comment = process_comment(comment)
+ sub_prefix = prefix
+ if len(sub_prefix) > 0:
+ sub_prefix += '_'
+ if len(node.spelling) > 0:
+ name = sanitize_name(sub_prefix + d(node.spelling))
+ global output
+ output.append((name, filename, comment))
+
+
+class ExtractionThread(Thread):
+ def __init__(self, filename, parameters):
+ Thread.__init__(self)
+ self.filename = filename
+ self.parameters = parameters
+ job_semaphore.acquire()
+
+ def run(self):
+ print('Processing "%s" ..' % self.filename, file=sys.stderr)
+ try:
+ index = cindex.Index(
+ cindex.conf.lib.clang_createIndex(False, True))
+ tu = index.parse(self.filename, self.parameters)
+ extract(self.filename, tu.cursor, '')
+ finally:
+ job_semaphore.release()
+
+if __name__ == '__main__':
+ parameters = ['-x', 'c++', '-std=c++11']
+ filenames = []
+
+ if platform.system() == 'Darwin':
+ dev_path = '/Applications/Xcode.app/Contents/Developer/'
+ lib_dir = dev_path + 'Toolchains/XcodeDefault.xctoolchain/usr/lib/'
+ sdk_dir = dev_path + 'Platforms/MacOSX.platform/Developer/SDKs'
+ libclang = lib_dir + 'libclang.dylib'
+
+ if os.path.exists(libclang):
+ cindex.Config.set_library_path(os.path.dirname(libclang))
+
+ if os.path.exists(sdk_dir):
+ sysroot_dir = os.path.join(sdk_dir, next(os.walk(sdk_dir))[1][0])
+ parameters.append('-isysroot')
+ parameters.append(sysroot_dir)
+
+ for item in sys.argv[1:]:
+ if item.startswith('-'):
+ parameters.append(item)
+ else:
+ filenames.append(item)
+
+ if len(filenames) == 0:
+ print('Syntax: %s [.. a list of header files ..]' % sys.argv[0])
+ exit(-1)
+
+ print('''/*
+ This file contains docstrings for the Python bindings.
+ Do not edit! These were automatically extracted by mkdoc.py
+ */
+
+#define __EXPAND(x) x
+#define __COUNT(_1, _2, _3, _4, _5, _6, _7, COUNT, ...) COUNT
+#define __VA_SIZE(...) __EXPAND(__COUNT(__VA_ARGS__, 7, 6, 5, 4, 3, 2, 1))
+#define __CAT1(a, b) a ## b
+#define __CAT2(a, b) __CAT1(a, b)
+#define __DOC1(n1) __doc_##n1
+#define __DOC2(n1, n2) __doc_##n1##_##n2
+#define __DOC3(n1, n2, n3) __doc_##n1##_##n2##_##n3
+#define __DOC4(n1, n2, n3, n4) __doc_##n1##_##n2##_##n3##_##n4
+#define __DOC5(n1, n2, n3, n4, n5) __doc_##n1##_##n2##_##n3##_##n4##_##n5
+#define __DOC6(n1, n2, n3, n4, n5, n6) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6
+#define __DOC7(n1, n2, n3, n4, n5, n6, n7) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6##_##n7
+#define DOC(...) __EXPAND(__EXPAND(__CAT2(__DOC, __VA_SIZE(__VA_ARGS__)))(__VA_ARGS__))
+
+#if defined(__GNUG__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#endif
+''')
+
+ output.clear()
+ for filename in filenames:
+ thr = ExtractionThread(filename, parameters)
+ thr.start()
+
+ print('Waiting for jobs to finish ..', file=sys.stderr)
+ for i in range(job_count):
+ job_semaphore.acquire()
+
+ name_ctr = 1
+ name_prev = None
+ for name, _, comment in list(sorted(output, key=lambda x: (x[0], x[1]))):
+ if name == name_prev:
+ name_ctr += 1
+ name = name + "_%i" % name_ctr
+ else:
+ name_prev = name
+ name_ctr = 1
+ print('\nstatic const char *%s =%sR"doc(%s)doc";' %
+ (name, '\n' if '\n' in comment else ' ', comment))
+
+ print('''
+#if defined(__GNUG__)
+#pragma GCC diagnostic pop
+#endif
+''')
diff --git a/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in b/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in
new file mode 100644
index 000000000..3dd1b2c1a
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in
@@ -0,0 +1,100 @@
+# pybind11Config.cmake
+# --------------------
+#
+# PYBIND11 cmake module.
+# This module sets the following variables in your project::
+#
+# pybind11_FOUND - true if pybind11 and all required components found on the system
+# pybind11_VERSION - pybind11 version in format Major.Minor.Release
+# pybind11_INCLUDE_DIRS - Directories where pybind11 and python headers are located.
+# pybind11_INCLUDE_DIR - Directory where pybind11 headers are located.
+# pybind11_DEFINITIONS - Definitions necessary to use pybind11, namely USING_pybind11.
+# pybind11_LIBRARIES - compile flags and python libraries (as needed) to link against.
+# pybind11_LIBRARY - empty.
+# CMAKE_MODULE_PATH - appends location of accompanying FindPythonLibsNew.cmake and
+# pybind11Tools.cmake modules.
+#
+#
+# Available components: None
+#
+#
+# Exported targets::
+#
+# If pybind11 is found, this module defines the following :prop_tgt:`IMPORTED`
+# interface library targets::
+#
+# pybind11::module - for extension modules
+# pybind11::embed - for embedding the Python interpreter
+#
+# Python headers, libraries (as needed by platform), and the C++ standard
+# are attached to the target. Set PythonLibsNew variables to influence
+# python detection and PYBIND11_CPP_STANDARD (-std=c++11 or -std=c++14) to
+# influence standard setting. ::
+#
+# find_package(pybind11 CONFIG REQUIRED)
+# message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}")
+#
+# # Create an extension module
+# add_library(mylib MODULE main.cpp)
+# target_link_libraries(mylib pybind11::module)
+#
+# # Or embed the Python interpreter into an executable
+# add_executable(myexe main.cpp)
+# target_link_libraries(myexe pybind11::embed)
+#
+# Suggested usage::
+#
+# find_package with version info is not recommended except for release versions. ::
+#
+# find_package(pybind11 CONFIG)
+# find_package(pybind11 2.0 EXACT CONFIG REQUIRED)
+#
+#
+# The following variables can be set to guide the search for this package::
+#
+# pybind11_DIR - CMake variable, set to directory containing this Config file
+# CMAKE_PREFIX_PATH - CMake variable, set to root directory of this package
+# PATH - environment variable, set to bin directory of this package
+# CMAKE_DISABLE_FIND_PACKAGE_pybind11 - CMake variable, disables
+# find_package(pybind11) when not REQUIRED, perhaps to force internal build
+
+@PACKAGE_INIT@
+
+set(PN pybind11)
+
+# location of pybind11/pybind11.h
+set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@")
+
+set(${PN}_LIBRARY "")
+set(${PN}_DEFINITIONS USING_${PN})
+
+check_required_components(${PN})
+
+# make detectable the FindPythonLibsNew.cmake module
+list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR})
+
+include(pybind11Tools)
+
+if(NOT (CMAKE_VERSION VERSION_LESS 3.0))
+#-----------------------------------------------------------------------------
+# Don't include targets if this file is being picked up by another
+# project which has already built this as a subproject
+#-----------------------------------------------------------------------------
+if(NOT TARGET ${PN}::pybind11)
+ include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake")
+
+ find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} MODULE REQUIRED)
+ set_property(TARGET ${PN}::pybind11 APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${PYTHON_INCLUDE_DIRS})
+ set_property(TARGET ${PN}::embed APPEND PROPERTY INTERFACE_LINK_LIBRARIES ${PYTHON_LIBRARIES})
+ if(WIN32 OR CYGWIN)
+ set_property(TARGET ${PN}::module APPEND PROPERTY INTERFACE_LINK_LIBRARIES ${PYTHON_LIBRARIES})
+ endif()
+
+ set_property(TARGET ${PN}::pybind11 APPEND PROPERTY INTERFACE_COMPILE_OPTIONS "${PYBIND11_CPP_STANDARD}")
+
+ get_property(_iid TARGET ${PN}::pybind11 PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
+ get_property(_ill TARGET ${PN}::module PROPERTY INTERFACE_LINK_LIBRARIES)
+ set(${PN}_INCLUDE_DIRS ${_iid})
+ set(${PN}_LIBRARIES ${_ico} ${_ill})
+endif()
+endif()
diff --git a/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake b/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake
new file mode 100644
index 000000000..a7c471a07
--- /dev/null
+++ b/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake
@@ -0,0 +1,202 @@
+# tools/pybind11Tools.cmake -- Build system for the pybind11 modules
+#
+# Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
+#
+# All rights reserved. Use of this source code is governed by a
+# BSD-style license that can be found in the LICENSE file.
+
+cmake_minimum_required(VERSION 2.8.12)
+
+# Add a CMake parameter for choosing a desired Python version
+if(NOT PYBIND11_PYTHON_VERSION)
+ set(PYBIND11_PYTHON_VERSION "" CACHE STRING "Python version to use for compiling modules")
+endif()
+
+set(Python_ADDITIONAL_VERSIONS 3.7 3.6 3.5 3.4)
+find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} REQUIRED)
+
+include(CheckCXXCompilerFlag)
+include(CMakeParseArguments)
+
+if(NOT PYBIND11_CPP_STANDARD AND NOT CMAKE_CXX_STANDARD)
+ if(NOT MSVC)
+ check_cxx_compiler_flag("-std=c++14" HAS_CPP14_FLAG)
+
+ if (HAS_CPP14_FLAG)
+ set(PYBIND11_CPP_STANDARD -std=c++14)
+ else()
+ check_cxx_compiler_flag("-std=c++11" HAS_CPP11_FLAG)
+ if (HAS_CPP11_FLAG)
+ set(PYBIND11_CPP_STANDARD -std=c++11)
+ else()
+ message(FATAL_ERROR "Unsupported compiler -- pybind11 requires C++11 support!")
+ endif()
+ endif()
+ elseif(MSVC)
+ set(PYBIND11_CPP_STANDARD /std:c++14)
+ endif()
+
+ set(PYBIND11_CPP_STANDARD ${PYBIND11_CPP_STANDARD} CACHE STRING
+ "C++ standard flag, e.g. -std=c++11, -std=c++14, /std:c++14. Defaults to C++14 mode." FORCE)
+endif()
+
+# Checks whether the given CXX/linker flags can compile and link a cxx file. cxxflags and
+# linkerflags are lists of flags to use. The result variable is a unique variable name for each set
+# of flags: the compilation result will be cached base on the result variable. If the flags work,
+# sets them in cxxflags_out/linkerflags_out internal cache variables (in addition to ${result}).
+function(_pybind11_return_if_cxx_and_linker_flags_work result cxxflags linkerflags cxxflags_out linkerflags_out)
+ set(CMAKE_REQUIRED_LIBRARIES ${linkerflags})
+ check_cxx_compiler_flag("${cxxflags}" ${result})
+ if (${result})
+ set(${cxxflags_out} "${cxxflags}" CACHE INTERNAL "" FORCE)
+ set(${linkerflags_out} "${linkerflags}" CACHE INTERNAL "" FORCE)
+ endif()
+endfunction()
+
+# Internal: find the appropriate link time optimization flags for this compiler
+function(_pybind11_add_lto_flags target_name prefer_thin_lto)
+ if (NOT DEFINED PYBIND11_LTO_CXX_FLAGS)
+ set(PYBIND11_LTO_CXX_FLAGS "" CACHE INTERNAL "")
+ set(PYBIND11_LTO_LINKER_FLAGS "" CACHE INTERNAL "")
+
+ if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+ set(cxx_append "")
+ set(linker_append "")
+ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT APPLE)
+ # Clang Gold plugin does not support -Os; append -O3 to MinSizeRel builds to override it
+ set(linker_append ";$<$<CONFIG:MinSizeRel>:-O3>")
+ elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
+ set(cxx_append ";-fno-fat-lto-objects")
+ endif()
+
+ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND prefer_thin_lto)
+ _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO_THIN
+ "-flto=thin${cxx_append}" "-flto=thin${linker_append}"
+ PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
+ endif()
+
+ if (NOT HAS_FLTO_THIN)
+ _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO
+ "-flto${cxx_append}" "-flto${linker_append}"
+ PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
+ endif()
+ elseif (CMAKE_CXX_COMPILER_ID MATCHES "Intel")
+ # Intel equivalent to LTO is called IPO
+ _pybind11_return_if_cxx_and_linker_flags_work(HAS_INTEL_IPO
+ "-ipo" "-ipo" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
+ elseif(MSVC)
+ # cmake only interprets libraries as linker flags when they start with a - (otherwise it
+ # converts /LTCG to \LTCG as if it was a Windows path). Luckily MSVC supports passing flags
+ # with - instead of /, even if it is a bit non-standard:
+ _pybind11_return_if_cxx_and_linker_flags_work(HAS_MSVC_GL_LTCG
+ "/GL" "-LTCG" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
+ endif()
+
+ if (PYBIND11_LTO_CXX_FLAGS)
+ message(STATUS "LTO enabled")
+ else()
+ message(STATUS "LTO disabled (not supported by the compiler and/or linker)")
+ endif()
+ endif()
+
+ # Enable LTO flags if found, except for Debug builds
+ if (PYBIND11_LTO_CXX_FLAGS)
+ target_compile_options(${target_name} PRIVATE "$<$<NOT:$<CONFIG:Debug>>:${PYBIND11_LTO_CXX_FLAGS}>")
+ endif()
+ if (PYBIND11_LTO_LINKER_FLAGS)
+ target_link_libraries(${target_name} PRIVATE "$<$<NOT:$<CONFIG:Debug>>:${PYBIND11_LTO_LINKER_FLAGS}>")
+ endif()
+endfunction()
+
+# Build a Python extension module:
+# pybind11_add_module(<name> [MODULE | SHARED] [EXCLUDE_FROM_ALL]
+# [NO_EXTRAS] [THIN_LTO] source1 [source2 ...])
+#
+function(pybind11_add_module target_name)
+ set(options MODULE SHARED EXCLUDE_FROM_ALL NO_EXTRAS THIN_LTO)
+ cmake_parse_arguments(ARG "${options}" "" "" ${ARGN})
+
+ if(ARG_MODULE AND ARG_SHARED)
+ message(FATAL_ERROR "Can't be both MODULE and SHARED")
+ elseif(ARG_SHARED)
+ set(lib_type SHARED)
+ else()
+ set(lib_type MODULE)
+ endif()
+
+ if(ARG_EXCLUDE_FROM_ALL)
+ set(exclude_from_all EXCLUDE_FROM_ALL)
+ endif()
+
+ add_library(${target_name} ${lib_type} ${exclude_from_all} ${ARG_UNPARSED_ARGUMENTS})
+
+ target_include_directories(${target_name}
+ PRIVATE ${PYBIND11_INCLUDE_DIR} # from project CMakeLists.txt
+ PRIVATE ${pybind11_INCLUDE_DIR} # from pybind11Config
+ PRIVATE ${PYTHON_INCLUDE_DIRS})
+
+ # The prefix and extension are provided by FindPythonLibsNew.cmake
+ set_target_properties(${target_name} PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}")
+ set_target_properties(${target_name} PROPERTIES SUFFIX "${PYTHON_MODULE_EXTENSION}")
+
+ # -fvisibility=hidden is required to allow multiple modules compiled against
+ # different pybind versions to work properly, and for some features (e.g.
+ # py::module_local). We force it on everything inside the `pybind11`
+ # namespace; also turning it on for a pybind module compilation here avoids
+ # potential warnings or issues from having mixed hidden/non-hidden types.
+ set_target_properties(${target_name} PROPERTIES CXX_VISIBILITY_PRESET "hidden")
+
+ if(WIN32 OR CYGWIN)
+ # Link against the Python shared library on Windows
+ target_link_libraries(${target_name} PRIVATE ${PYTHON_LIBRARIES})
+ elseif(APPLE)
+ # It's quite common to have multiple copies of the same Python version
+ # installed on one's system. E.g.: one copy from the OS and another copy
+ # that's statically linked into an application like Blender or Maya.
+ # If we link our plugin library against the OS Python here and import it
+ # into Blender or Maya later on, this will cause segfaults when multiple
+ # conflicting Python instances are active at the same time (even when they
+ # are of the same version).
+
+ # Windows is not affected by this issue since it handles DLL imports
+ # differently. The solution for Linux and Mac OS is simple: we just don't
+ # link against the Python library. The resulting shared library will have
+ # missing symbols, but that's perfectly fine -- they will be resolved at
+ # import time.
+
+ target_link_libraries(${target_name} PRIVATE "-undefined dynamic_lookup")
+
+ if(ARG_SHARED)
+ # Suppress CMake >= 3.0 warning for shared libraries
+ set_target_properties(${target_name} PROPERTIES MACOSX_RPATH ON)
+ endif()
+ endif()
+
+ # Make sure C++11/14 are enabled
+ target_compile_options(${target_name} PUBLIC ${PYBIND11_CPP_STANDARD})
+
+ if(ARG_NO_EXTRAS)
+ return()
+ endif()
+
+ _pybind11_add_lto_flags(${target_name} ${ARG_THIN_LTO})
+
+ if (NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug)
+ # Strip unnecessary sections of the binary on Linux/Mac OS
+ if(CMAKE_STRIP)
+ if(APPLE)
+ add_custom_command(TARGET ${target_name} POST_BUILD
+ COMMAND ${CMAKE_STRIP} -x $<TARGET_FILE:${target_name}>)
+ else()
+ add_custom_command(TARGET ${target_name} POST_BUILD
+ COMMAND ${CMAKE_STRIP} $<TARGET_FILE:${target_name}>)
+ endif()
+ endif()
+ endif()
+
+ if(MSVC)
+ # /MP enables multithreaded builds (relevant when there are many files), /bigobj is
+ # needed for bigger binding projects due to the limit to 64k addressable sections
+ target_compile_options(${target_name} PRIVATE /MP /bigobj)
+ endif()
+endfunction()
diff --git a/ml/dlib/dlib/external/zlib/README b/ml/dlib/dlib/external/zlib/README
new file mode 100644
index 000000000..5ca9d127e
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/README
@@ -0,0 +1,115 @@
+ZLIB DATA COMPRESSION LIBRARY
+
+zlib 1.2.8 is a general purpose data compression library. All the code is
+thread safe. The data format used by the zlib library is described by RFCs
+(Request for Comments) 1950 to 1952 in the files
+http://tools.ietf.org/html/rfc1950 (zlib format), rfc1951 (deflate format) and
+rfc1952 (gzip format).
+
+All functions of the compression library are documented in the file zlib.h
+(volunteer to write man pages welcome, contact zlib@gzip.org). A usage example
+of the library is given in the file test/example.c which also tests that
+the library is working correctly. Another example is given in the file
+test/minigzip.c. The compression library itself is composed of all source
+files in the root directory.
+
+To compile all files and run the test program, follow the instructions given at
+the top of Makefile.in. In short "./configure; make test", and if that goes
+well, "make install" should work for most flavors of Unix. For Windows, use
+one of the special makefiles in win32/ or contrib/vstudio/ . For VMS, use
+make_vms.com.
+
+Questions about zlib should be sent to <zlib@gzip.org>, or to Gilles Vollant
+<info@winimage.com> for the Windows DLL version. The zlib home page is
+http://zlib.net/ . Before reporting a problem, please check this site to
+verify that you have the latest version of zlib; otherwise get the latest
+version and check whether the problem still exists or not.
+
+PLEASE read the zlib FAQ http://zlib.net/zlib_faq.html before asking for help.
+
+Mark Nelson <markn@ieee.org> wrote an article about zlib for the Jan. 1997
+issue of Dr. Dobb's Journal; a copy of the article is available at
+http://marknelson.us/1997/01/01/zlib-engine/ .
+
+The changes made in version 1.2.8 are documented in the file ChangeLog.
+
+Unsupported third party contributions are provided in directory contrib/ .
+
+zlib is available in Java using the java.util.zip package, documented at
+http://java.sun.com/developer/technicalArticles/Programming/compression/ .
+
+A Perl interface to zlib written by Paul Marquess <pmqs@cpan.org> is available
+at CPAN (Comprehensive Perl Archive Network) sites, including
+http://search.cpan.org/~pmqs/IO-Compress-Zlib/ .
+
+A Python interface to zlib written by A.M. Kuchling <amk@amk.ca> is
+available in Python 1.5 and later versions, see
+http://docs.python.org/library/zlib.html .
+
+zlib is built into tcl: http://wiki.tcl.tk/4610 .
+
+An experimental package to read and write files in .zip format, written on top
+of zlib by Gilles Vollant <info@winimage.com>, is available in the
+contrib/minizip directory of zlib.
+
+
+Notes for some targets:
+
+- For Windows DLL versions, please see win32/DLL_FAQ.txt
+
+- For 64-bit Irix, deflate.c must be compiled without any optimization. With
+ -O, one libpng test fails. The test works in 32 bit mode (with the -n32
+ compiler flag). The compiler bug has been reported to SGI.
+
+- zlib doesn't work with gcc 2.6.3 on a DEC 3000/300LX under OSF/1 2.1 it works
+ when compiled with cc.
+
+- On Digital Unix 4.0D (formely OSF/1) on AlphaServer, the cc option -std1 is
+ necessary to get gzprintf working correctly. This is done by configure.
+
+- zlib doesn't work on HP-UX 9.05 with some versions of /bin/cc. It works with
+ other compilers. Use "make test" to check your compiler.
+
+- gzdopen is not supported on RISCOS or BEOS.
+
+- For PalmOs, see http://palmzlib.sourceforge.net/
+
+
+Acknowledgments:
+
+ The deflate format used by zlib was defined by Phil Katz. The deflate and
+ zlib specifications were written by L. Peter Deutsch. Thanks to all the
+ people who reported problems and suggested various improvements in zlib; they
+ are too numerous to cite here.
+
+Copyright notice:
+
+ (C) 1995-2013 Jean-loup Gailly and Mark Adler
+
+ This software is provided 'as-is', without any express or implied
+ warranty. In no event will the authors be held liable for any damages
+ arising from the use of this software.
+
+ Permission is granted to anyone to use this software for any purpose,
+ including commercial applications, and to alter it and redistribute it
+ freely, subject to the following restrictions:
+
+ 1. The origin of this software must not be misrepresented; you must not
+ claim that you wrote the original software. If you use this software
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original software.
+ 3. This notice may not be removed or altered from any source distribution.
+
+ Jean-loup Gailly Mark Adler
+ jloup@gzip.org madler@alumni.caltech.edu
+
+If you use the zlib library in a product, we would appreciate *not* receiving
+lengthy legal documents to sign. The sources are provided for free but without
+warranty of any kind. The library has been entirely written by Jean-loup
+Gailly and Mark Adler; it does not include third-party code.
+
+If you redistribute modified sources, we would appreciate that you include in
+the file ChangeLog history information documenting your changes. Please read
+the FAQ for more information on the distribution of modified source versions.
diff --git a/ml/dlib/dlib/external/zlib/adler32.c b/ml/dlib/dlib/external/zlib/adler32.c
new file mode 100644
index 000000000..a868f073d
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/adler32.c
@@ -0,0 +1,179 @@
+/* adler32.c -- compute the Adler-32 checksum of a data stream
+ * Copyright (C) 1995-2011 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* @(#) $Id$ */
+
+#include "zutil.h"
+
+#define local static
+
+local uLong adler32_combine_ OF((uLong adler1, uLong adler2, z_off64_t len2));
+
+#define BASE 65521 /* largest prime smaller than 65536 */
+#define NMAX 5552
+/* NMAX is the largest n such that 255n(n+1)/2 + (n+1)(BASE-1) <= 2^32-1 */
+
+#define DO1(buf,i) {adler += (buf)[i]; sum2 += adler;}
+#define DO2(buf,i) DO1(buf,i); DO1(buf,i+1);
+#define DO4(buf,i) DO2(buf,i); DO2(buf,i+2);
+#define DO8(buf,i) DO4(buf,i); DO4(buf,i+4);
+#define DO16(buf) DO8(buf,0); DO8(buf,8);
+
+/* use NO_DIVIDE if your processor does not do division in hardware --
+ try it both ways to see which is faster */
+#ifdef NO_DIVIDE
+/* note that this assumes BASE is 65521, where 65536 % 65521 == 15
+ (thank you to John Reiser for pointing this out) */
+# define CHOP(a) \
+ do { \
+ unsigned long tmp = a >> 16; \
+ a &= 0xffffUL; \
+ a += (tmp << 4) - tmp; \
+ } while (0)
+# define MOD28(a) \
+ do { \
+ CHOP(a); \
+ if (a >= BASE) a -= BASE; \
+ } while (0)
+# define MOD(a) \
+ do { \
+ CHOP(a); \
+ MOD28(a); \
+ } while (0)
+# define MOD63(a) \
+ do { /* this assumes a is not negative */ \
+ z_off64_t tmp = a >> 32; \
+ a &= 0xffffffffL; \
+ a += (tmp << 8) - (tmp << 5) + tmp; \
+ tmp = a >> 16; \
+ a &= 0xffffL; \
+ a += (tmp << 4) - tmp; \
+ tmp = a >> 16; \
+ a &= 0xffffL; \
+ a += (tmp << 4) - tmp; \
+ if (a >= BASE) a -= BASE; \
+ } while (0)
+#else
+# define MOD(a) a %= BASE
+# define MOD28(a) a %= BASE
+# define MOD63(a) a %= BASE
+#endif
+
+/* ========================================================================= */
+uLong ZEXPORT adler32(adler, buf, len)
+ uLong adler;
+ const Bytef *buf;
+ uInt len;
+{
+ unsigned long sum2;
+ unsigned n;
+
+ /* split Adler-32 into component sums */
+ sum2 = (adler >> 16) & 0xffff;
+ adler &= 0xffff;
+
+ /* in case user likes doing a byte at a time, keep it fast */
+ if (len == 1) {
+ adler += buf[0];
+ if (adler >= BASE)
+ adler -= BASE;
+ sum2 += adler;
+ if (sum2 >= BASE)
+ sum2 -= BASE;
+ return adler | (sum2 << 16);
+ }
+
+ /* initial Adler-32 value (deferred check for len == 1 speed) */
+ if (buf == Z_NULL)
+ return 1L;
+
+ /* in case short lengths are provided, keep it somewhat fast */
+ if (len < 16) {
+ while (len--) {
+ adler += *buf++;
+ sum2 += adler;
+ }
+ if (adler >= BASE)
+ adler -= BASE;
+ MOD28(sum2); /* only added so many BASE's */
+ return adler | (sum2 << 16);
+ }
+
+ /* do length NMAX blocks -- requires just one modulo operation */
+ while (len >= NMAX) {
+ len -= NMAX;
+ n = NMAX / 16; /* NMAX is divisible by 16 */
+ do {
+ DO16(buf); /* 16 sums unrolled */
+ buf += 16;
+ } while (--n);
+ MOD(adler);
+ MOD(sum2);
+ }
+
+ /* do remaining bytes (less than NMAX, still just one modulo) */
+ if (len) { /* avoid modulos if none remaining */
+ while (len >= 16) {
+ len -= 16;
+ DO16(buf);
+ buf += 16;
+ }
+ while (len--) {
+ adler += *buf++;
+ sum2 += adler;
+ }
+ MOD(adler);
+ MOD(sum2);
+ }
+
+ /* return recombined sums */
+ return adler | (sum2 << 16);
+}
+
+/* ========================================================================= */
+local uLong adler32_combine_(adler1, adler2, len2)
+ uLong adler1;
+ uLong adler2;
+ z_off64_t len2;
+{
+ unsigned long sum1;
+ unsigned long sum2;
+ unsigned rem;
+
+ /* for negative len, return invalid adler32 as a clue for debugging */
+ if (len2 < 0)
+ return 0xffffffffUL;
+
+ /* the derivation of this formula is left as an exercise for the reader */
+ MOD63(len2); /* assumes len2 >= 0 */
+ rem = (unsigned)len2;
+ sum1 = adler1 & 0xffff;
+ sum2 = rem * sum1;
+ MOD(sum2);
+ sum1 += (adler2 & 0xffff) + BASE - 1;
+ sum2 += ((adler1 >> 16) & 0xffff) + ((adler2 >> 16) & 0xffff) + BASE - rem;
+ if (sum1 >= BASE) sum1 -= BASE;
+ if (sum1 >= BASE) sum1 -= BASE;
+ if (sum2 >= (BASE << 1)) sum2 -= (BASE << 1);
+ if (sum2 >= BASE) sum2 -= BASE;
+ return sum1 | (sum2 << 16);
+}
+
+/* ========================================================================= */
+uLong ZEXPORT adler32_combine(adler1, adler2, len2)
+ uLong adler1;
+ uLong adler2;
+ z_off_t len2;
+{
+ return adler32_combine_(adler1, adler2, len2);
+}
+
+uLong ZEXPORT adler32_combine64(adler1, adler2, len2)
+ uLong adler1;
+ uLong adler2;
+ z_off64_t len2;
+{
+ return adler32_combine_(adler1, adler2, len2);
+}
diff --git a/ml/dlib/dlib/external/zlib/compress.c b/ml/dlib/dlib/external/zlib/compress.c
new file mode 100644
index 000000000..6e9762676
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/compress.c
@@ -0,0 +1,80 @@
+/* compress.c -- compress a memory buffer
+ * Copyright (C) 1995-2005 Jean-loup Gailly.
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* @(#) $Id$ */
+
+#define ZLIB_INTERNAL
+#include "zlib.h"
+
+/* ===========================================================================
+ Compresses the source buffer into the destination buffer. The level
+ parameter has the same meaning as in deflateInit. sourceLen is the byte
+ length of the source buffer. Upon entry, destLen is the total size of the
+ destination buffer, which must be at least 0.1% larger than sourceLen plus
+ 12 bytes. Upon exit, destLen is the actual size of the compressed buffer.
+
+ compress2 returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_BUF_ERROR if there was not enough room in the output buffer,
+ Z_STREAM_ERROR if the level parameter is invalid.
+*/
+int ZEXPORT compress2 (dest, destLen, source, sourceLen, level)
+ Bytef *dest;
+ uLongf *destLen;
+ const Bytef *source;
+ uLong sourceLen;
+ int level;
+{
+ z_stream stream;
+ int err;
+
+ stream.next_in = (z_const Bytef *)source;
+ stream.avail_in = (uInt)sourceLen;
+#ifdef MAXSEG_64K
+ /* Check for source > 64K on 16-bit machine: */
+ if ((uLong)stream.avail_in != sourceLen) return Z_BUF_ERROR;
+#endif
+ stream.next_out = dest;
+ stream.avail_out = (uInt)*destLen;
+ if ((uLong)stream.avail_out != *destLen) return Z_BUF_ERROR;
+
+ stream.zalloc = (alloc_func)0;
+ stream.zfree = (free_func)0;
+ stream.opaque = (voidpf)0;
+
+ err = deflateInit(&stream, level);
+ if (err != Z_OK) return err;
+
+ err = deflate(&stream, Z_FINISH);
+ if (err != Z_STREAM_END) {
+ deflateEnd(&stream);
+ return err == Z_OK ? Z_BUF_ERROR : err;
+ }
+ *destLen = stream.total_out;
+
+ err = deflateEnd(&stream);
+ return err;
+}
+
+/* ===========================================================================
+ */
+int ZEXPORT compress (dest, destLen, source, sourceLen)
+ Bytef *dest;
+ uLongf *destLen;
+ const Bytef *source;
+ uLong sourceLen;
+{
+ return compress2(dest, destLen, source, sourceLen, Z_DEFAULT_COMPRESSION);
+}
+
+/* ===========================================================================
+ If the default memLevel or windowBits for deflateInit() is changed, then
+ this function needs to be updated.
+ */
+uLong ZEXPORT compressBound (sourceLen)
+ uLong sourceLen;
+{
+ return sourceLen + (sourceLen >> 12) + (sourceLen >> 14) +
+ (sourceLen >> 25) + 13;
+}
diff --git a/ml/dlib/dlib/external/zlib/crc32.c b/ml/dlib/dlib/external/zlib/crc32.c
new file mode 100644
index 000000000..979a7190a
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/crc32.c
@@ -0,0 +1,425 @@
+/* crc32.c -- compute the CRC-32 of a data stream
+ * Copyright (C) 1995-2006, 2010, 2011, 2012 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ *
+ * Thanks to Rodney Brown <rbrown64@csc.com.au> for his contribution of faster
+ * CRC methods: exclusive-oring 32 bits of data at a time, and pre-computing
+ * tables for updating the shift register in one step with three exclusive-ors
+ * instead of four steps with four exclusive-ors. This results in about a
+ * factor of two increase in speed on a Power PC G4 (PPC7455) using gcc -O3.
+ */
+
+/* @(#) $Id$ */
+
+/*
+ Note on the use of DYNAMIC_CRC_TABLE: there is no mutex or semaphore
+ protection on the static variables used to control the first-use generation
+ of the crc tables. Therefore, if you #define DYNAMIC_CRC_TABLE, you should
+ first call get_crc_table() to initialize the tables before allowing more than
+ one thread to use crc32().
+
+ DYNAMIC_CRC_TABLE and MAKECRCH can be #defined to write out crc32.h.
+ */
+
+#ifdef MAKECRCH
+# include <stdio.h>
+# ifndef DYNAMIC_CRC_TABLE
+# define DYNAMIC_CRC_TABLE
+# endif /* !DYNAMIC_CRC_TABLE */
+#endif /* MAKECRCH */
+
+#include "zutil.h" /* for STDC and FAR definitions */
+
+#define local static
+
+/* Definitions for doing the crc four data bytes at a time. */
+#if !defined(NOBYFOUR) && defined(Z_U4)
+# define BYFOUR
+#endif
+#ifdef BYFOUR
+ local unsigned long crc32_little OF((unsigned long,
+ const unsigned char FAR *, unsigned));
+ local unsigned long crc32_big OF((unsigned long,
+ const unsigned char FAR *, unsigned));
+# define TBLS 8
+#else
+# define TBLS 1
+#endif /* BYFOUR */
+
+/* Local functions for crc concatenation */
+local unsigned long gf2_matrix_times OF((unsigned long *mat,
+ unsigned long vec));
+local void gf2_matrix_square OF((unsigned long *square, unsigned long *mat));
+local uLong crc32_combine_ OF((uLong crc1, uLong crc2, z_off64_t len2));
+
+
+#ifdef DYNAMIC_CRC_TABLE
+
+local volatile int crc_table_empty = 1;
+local z_crc_t FAR crc_table[TBLS][256];
+local void make_crc_table OF((void));
+#ifdef MAKECRCH
+ local void write_table OF((FILE *, const z_crc_t FAR *));
+#endif /* MAKECRCH */
+/*
+ Generate tables for a byte-wise 32-bit CRC calculation on the polynomial:
+ x^32+x^26+x^23+x^22+x^16+x^12+x^11+x^10+x^8+x^7+x^5+x^4+x^2+x+1.
+
+ Polynomials over GF(2) are represented in binary, one bit per coefficient,
+ with the lowest powers in the most significant bit. Then adding polynomials
+ is just exclusive-or, and multiplying a polynomial by x is a right shift by
+ one. If we call the above polynomial p, and represent a byte as the
+ polynomial q, also with the lowest power in the most significant bit (so the
+ byte 0xb1 is the polynomial x^7+x^3+x+1), then the CRC is (q*x^32) mod p,
+ where a mod b means the remainder after dividing a by b.
+
+ This calculation is done using the shift-register method of multiplying and
+ taking the remainder. The register is initialized to zero, and for each
+ incoming bit, x^32 is added mod p to the register if the bit is a one (where
+ x^32 mod p is p+x^32 = x^26+...+1), and the register is multiplied mod p by
+ x (which is shifting right by one and adding x^32 mod p if the bit shifted
+ out is a one). We start with the highest power (least significant bit) of
+ q and repeat for all eight bits of q.
+
+ The first table is simply the CRC of all possible eight bit values. This is
+ all the information needed to generate CRCs on data a byte at a time for all
+ combinations of CRC register values and incoming bytes. The remaining tables
+ allow for word-at-a-time CRC calculation for both big-endian and little-
+ endian machines, where a word is four bytes.
+*/
+local void make_crc_table()
+{
+ z_crc_t c;
+ int n, k;
+ z_crc_t poly; /* polynomial exclusive-or pattern */
+ /* terms of polynomial defining this crc (except x^32): */
+ static volatile int first = 1; /* flag to limit concurrent making */
+ static const unsigned char p[] = {0,1,2,4,5,7,8,10,11,12,16,22,23,26};
+
+ /* See if another task is already doing this (not thread-safe, but better
+ than nothing -- significantly reduces duration of vulnerability in
+ case the advice about DYNAMIC_CRC_TABLE is ignored) */
+ if (first) {
+ first = 0;
+
+ /* make exclusive-or pattern from polynomial (0xedb88320UL) */
+ poly = 0;
+ for (n = 0; n < (int)(sizeof(p)/sizeof(unsigned char)); n++)
+ poly |= (z_crc_t)1 << (31 - p[n]);
+
+ /* generate a crc for every 8-bit value */
+ for (n = 0; n < 256; n++) {
+ c = (z_crc_t)n;
+ for (k = 0; k < 8; k++)
+ c = c & 1 ? poly ^ (c >> 1) : c >> 1;
+ crc_table[0][n] = c;
+ }
+
+#ifdef BYFOUR
+ /* generate crc for each value followed by one, two, and three zeros,
+ and then the byte reversal of those as well as the first table */
+ for (n = 0; n < 256; n++) {
+ c = crc_table[0][n];
+ crc_table[4][n] = ZSWAP32(c);
+ for (k = 1; k < 4; k++) {
+ c = crc_table[0][c & 0xff] ^ (c >> 8);
+ crc_table[k][n] = c;
+ crc_table[k + 4][n] = ZSWAP32(c);
+ }
+ }
+#endif /* BYFOUR */
+
+ crc_table_empty = 0;
+ }
+ else { /* not first */
+ /* wait for the other guy to finish (not efficient, but rare) */
+ while (crc_table_empty)
+ ;
+ }
+
+#ifdef MAKECRCH
+ /* write out CRC tables to crc32.h */
+ {
+ FILE *out;
+
+ out = fopen("crc32.h", "w");
+ if (out == NULL) return;
+ fprintf(out, "/* crc32.h -- tables for rapid CRC calculation\n");
+ fprintf(out, " * Generated automatically by crc32.c\n */\n\n");
+ fprintf(out, "local const z_crc_t FAR ");
+ fprintf(out, "crc_table[TBLS][256] =\n{\n {\n");
+ write_table(out, crc_table[0]);
+# ifdef BYFOUR
+ fprintf(out, "#ifdef BYFOUR\n");
+ for (k = 1; k < 8; k++) {
+ fprintf(out, " },\n {\n");
+ write_table(out, crc_table[k]);
+ }
+ fprintf(out, "#endif\n");
+# endif /* BYFOUR */
+ fprintf(out, " }\n};\n");
+ fclose(out);
+ }
+#endif /* MAKECRCH */
+}
+
+#ifdef MAKECRCH
+local void write_table(out, table)
+ FILE *out;
+ const z_crc_t FAR *table;
+{
+ int n;
+
+ for (n = 0; n < 256; n++)
+ fprintf(out, "%s0x%08lxUL%s", n % 5 ? "" : " ",
+ (unsigned long)(table[n]),
+ n == 255 ? "\n" : (n % 5 == 4 ? ",\n" : ", "));
+}
+#endif /* MAKECRCH */
+
+#else /* !DYNAMIC_CRC_TABLE */
+/* ========================================================================
+ * Tables of CRC-32s of all single-byte values, made by make_crc_table().
+ */
+#include "crc32.h"
+#endif /* DYNAMIC_CRC_TABLE */
+
+/* =========================================================================
+ * This function can be used by asm versions of crc32()
+ */
+const z_crc_t FAR * ZEXPORT get_crc_table()
+{
+#ifdef DYNAMIC_CRC_TABLE
+ if (crc_table_empty)
+ make_crc_table();
+#endif /* DYNAMIC_CRC_TABLE */
+ return (const z_crc_t FAR *)crc_table;
+}
+
+/* ========================================================================= */
+#define DO1 crc = crc_table[0][((int)crc ^ (*buf++)) & 0xff] ^ (crc >> 8)
+#define DO8 DO1; DO1; DO1; DO1; DO1; DO1; DO1; DO1
+
+/* ========================================================================= */
+unsigned long ZEXPORT crc32(crc, buf, len)
+ unsigned long crc;
+ const unsigned char FAR *buf;
+ uInt len;
+{
+ if (buf == Z_NULL) return 0UL;
+
+#ifdef DYNAMIC_CRC_TABLE
+ if (crc_table_empty)
+ make_crc_table();
+#endif /* DYNAMIC_CRC_TABLE */
+
+#ifdef BYFOUR
+ if (sizeof(void *) == sizeof(ptrdiff_t)) {
+ z_crc_t endian;
+
+ endian = 1;
+ if (*((unsigned char *)(&endian)))
+ return crc32_little(crc, buf, len);
+ else
+ return crc32_big(crc, buf, len);
+ }
+#endif /* BYFOUR */
+ crc = crc ^ 0xffffffffUL;
+ while (len >= 8) {
+ DO8;
+ len -= 8;
+ }
+ if (len) do {
+ DO1;
+ } while (--len);
+ return crc ^ 0xffffffffUL;
+}
+
+#ifdef BYFOUR
+
+/* ========================================================================= */
+#define DOLIT4 c ^= *buf4++; \
+ c = crc_table[3][c & 0xff] ^ crc_table[2][(c >> 8) & 0xff] ^ \
+ crc_table[1][(c >> 16) & 0xff] ^ crc_table[0][c >> 24]
+#define DOLIT32 DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4
+
+/* ========================================================================= */
+local unsigned long crc32_little(crc, buf, len)
+ unsigned long crc;
+ const unsigned char FAR *buf;
+ unsigned len;
+{
+ register z_crc_t c;
+ register const z_crc_t FAR *buf4;
+
+ c = (z_crc_t)crc;
+ c = ~c;
+ while (len && ((ptrdiff_t)buf & 3)) {
+ c = crc_table[0][(c ^ *buf++) & 0xff] ^ (c >> 8);
+ len--;
+ }
+
+ buf4 = (const z_crc_t FAR *)(const void FAR *)buf;
+ while (len >= 32) {
+ DOLIT32;
+ len -= 32;
+ }
+ while (len >= 4) {
+ DOLIT4;
+ len -= 4;
+ }
+ buf = (const unsigned char FAR *)buf4;
+
+ if (len) do {
+ c = crc_table[0][(c ^ *buf++) & 0xff] ^ (c >> 8);
+ } while (--len);
+ c = ~c;
+ return (unsigned long)c;
+}
+
+/* ========================================================================= */
+#define DOBIG4 c ^= *++buf4; \
+ c = crc_table[4][c & 0xff] ^ crc_table[5][(c >> 8) & 0xff] ^ \
+ crc_table[6][(c >> 16) & 0xff] ^ crc_table[7][c >> 24]
+#define DOBIG32 DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4
+
+/* ========================================================================= */
+local unsigned long crc32_big(crc, buf, len)
+ unsigned long crc;
+ const unsigned char FAR *buf;
+ unsigned len;
+{
+ register z_crc_t c;
+ register const z_crc_t FAR *buf4;
+
+ c = ZSWAP32((z_crc_t)crc);
+ c = ~c;
+ while (len && ((ptrdiff_t)buf & 3)) {
+ c = crc_table[4][(c >> 24) ^ *buf++] ^ (c << 8);
+ len--;
+ }
+
+ buf4 = (const z_crc_t FAR *)(const void FAR *)buf;
+ buf4--;
+ while (len >= 32) {
+ DOBIG32;
+ len -= 32;
+ }
+ while (len >= 4) {
+ DOBIG4;
+ len -= 4;
+ }
+ buf4++;
+ buf = (const unsigned char FAR *)buf4;
+
+ if (len) do {
+ c = crc_table[4][(c >> 24) ^ *buf++] ^ (c << 8);
+ } while (--len);
+ c = ~c;
+ return (unsigned long)(ZSWAP32(c));
+}
+
+#endif /* BYFOUR */
+
+#define GF2_DIM 32 /* dimension of GF(2) vectors (length of CRC) */
+
+/* ========================================================================= */
+local unsigned long gf2_matrix_times(mat, vec)
+ unsigned long *mat;
+ unsigned long vec;
+{
+ unsigned long sum;
+
+ sum = 0;
+ while (vec) {
+ if (vec & 1)
+ sum ^= *mat;
+ vec >>= 1;
+ mat++;
+ }
+ return sum;
+}
+
+/* ========================================================================= */
+local void gf2_matrix_square(square, mat)
+ unsigned long *square;
+ unsigned long *mat;
+{
+ int n;
+
+ for (n = 0; n < GF2_DIM; n++)
+ square[n] = gf2_matrix_times(mat, mat[n]);
+}
+
+/* ========================================================================= */
+local uLong crc32_combine_(crc1, crc2, len2)
+ uLong crc1;
+ uLong crc2;
+ z_off64_t len2;
+{
+ int n;
+ unsigned long row;
+ unsigned long even[GF2_DIM]; /* even-power-of-two zeros operator */
+ unsigned long odd[GF2_DIM]; /* odd-power-of-two zeros operator */
+
+ /* degenerate case (also disallow negative lengths) */
+ if (len2 <= 0)
+ return crc1;
+
+ /* put operator for one zero bit in odd */
+ odd[0] = 0xedb88320UL; /* CRC-32 polynomial */
+ row = 1;
+ for (n = 1; n < GF2_DIM; n++) {
+ odd[n] = row;
+ row <<= 1;
+ }
+
+ /* put operator for two zero bits in even */
+ gf2_matrix_square(even, odd);
+
+ /* put operator for four zero bits in odd */
+ gf2_matrix_square(odd, even);
+
+ /* apply len2 zeros to crc1 (first square will put the operator for one
+ zero byte, eight zero bits, in even) */
+ do {
+ /* apply zeros operator for this bit of len2 */
+ gf2_matrix_square(even, odd);
+ if (len2 & 1)
+ crc1 = gf2_matrix_times(even, crc1);
+ len2 >>= 1;
+
+ /* if no more bits set, then done */
+ if (len2 == 0)
+ break;
+
+ /* another iteration of the loop with odd and even swapped */
+ gf2_matrix_square(odd, even);
+ if (len2 & 1)
+ crc1 = gf2_matrix_times(odd, crc1);
+ len2 >>= 1;
+
+ /* if no more bits set, then done */
+ } while (len2 != 0);
+
+ /* return combined crc */
+ crc1 ^= crc2;
+ return crc1;
+}
+
+/* ========================================================================= */
+uLong ZEXPORT crc32_combine(crc1, crc2, len2)
+ uLong crc1;
+ uLong crc2;
+ z_off_t len2;
+{
+ return crc32_combine_(crc1, crc2, len2);
+}
+
+uLong ZEXPORT crc32_combine64(crc1, crc2, len2)
+ uLong crc1;
+ uLong crc2;
+ z_off64_t len2;
+{
+ return crc32_combine_(crc1, crc2, len2);
+}
diff --git a/ml/dlib/dlib/external/zlib/crc32.h b/ml/dlib/dlib/external/zlib/crc32.h
new file mode 100644
index 000000000..9e0c77810
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/crc32.h
@@ -0,0 +1,441 @@
+/* crc32.h -- tables for rapid CRC calculation
+ * Generated automatically by crc32.c
+ */
+
+local const z_crc_t FAR crc_table[TBLS][256] =
+{
+ {
+ 0x00000000UL, 0x77073096UL, 0xee0e612cUL, 0x990951baUL, 0x076dc419UL,
+ 0x706af48fUL, 0xe963a535UL, 0x9e6495a3UL, 0x0edb8832UL, 0x79dcb8a4UL,
+ 0xe0d5e91eUL, 0x97d2d988UL, 0x09b64c2bUL, 0x7eb17cbdUL, 0xe7b82d07UL,
+ 0x90bf1d91UL, 0x1db71064UL, 0x6ab020f2UL, 0xf3b97148UL, 0x84be41deUL,
+ 0x1adad47dUL, 0x6ddde4ebUL, 0xf4d4b551UL, 0x83d385c7UL, 0x136c9856UL,
+ 0x646ba8c0UL, 0xfd62f97aUL, 0x8a65c9ecUL, 0x14015c4fUL, 0x63066cd9UL,
+ 0xfa0f3d63UL, 0x8d080df5UL, 0x3b6e20c8UL, 0x4c69105eUL, 0xd56041e4UL,
+ 0xa2677172UL, 0x3c03e4d1UL, 0x4b04d447UL, 0xd20d85fdUL, 0xa50ab56bUL,
+ 0x35b5a8faUL, 0x42b2986cUL, 0xdbbbc9d6UL, 0xacbcf940UL, 0x32d86ce3UL,
+ 0x45df5c75UL, 0xdcd60dcfUL, 0xabd13d59UL, 0x26d930acUL, 0x51de003aUL,
+ 0xc8d75180UL, 0xbfd06116UL, 0x21b4f4b5UL, 0x56b3c423UL, 0xcfba9599UL,
+ 0xb8bda50fUL, 0x2802b89eUL, 0x5f058808UL, 0xc60cd9b2UL, 0xb10be924UL,
+ 0x2f6f7c87UL, 0x58684c11UL, 0xc1611dabUL, 0xb6662d3dUL, 0x76dc4190UL,
+ 0x01db7106UL, 0x98d220bcUL, 0xefd5102aUL, 0x71b18589UL, 0x06b6b51fUL,
+ 0x9fbfe4a5UL, 0xe8b8d433UL, 0x7807c9a2UL, 0x0f00f934UL, 0x9609a88eUL,
+ 0xe10e9818UL, 0x7f6a0dbbUL, 0x086d3d2dUL, 0x91646c97UL, 0xe6635c01UL,
+ 0x6b6b51f4UL, 0x1c6c6162UL, 0x856530d8UL, 0xf262004eUL, 0x6c0695edUL,
+ 0x1b01a57bUL, 0x8208f4c1UL, 0xf50fc457UL, 0x65b0d9c6UL, 0x12b7e950UL,
+ 0x8bbeb8eaUL, 0xfcb9887cUL, 0x62dd1ddfUL, 0x15da2d49UL, 0x8cd37cf3UL,
+ 0xfbd44c65UL, 0x4db26158UL, 0x3ab551ceUL, 0xa3bc0074UL, 0xd4bb30e2UL,
+ 0x4adfa541UL, 0x3dd895d7UL, 0xa4d1c46dUL, 0xd3d6f4fbUL, 0x4369e96aUL,
+ 0x346ed9fcUL, 0xad678846UL, 0xda60b8d0UL, 0x44042d73UL, 0x33031de5UL,
+ 0xaa0a4c5fUL, 0xdd0d7cc9UL, 0x5005713cUL, 0x270241aaUL, 0xbe0b1010UL,
+ 0xc90c2086UL, 0x5768b525UL, 0x206f85b3UL, 0xb966d409UL, 0xce61e49fUL,
+ 0x5edef90eUL, 0x29d9c998UL, 0xb0d09822UL, 0xc7d7a8b4UL, 0x59b33d17UL,
+ 0x2eb40d81UL, 0xb7bd5c3bUL, 0xc0ba6cadUL, 0xedb88320UL, 0x9abfb3b6UL,
+ 0x03b6e20cUL, 0x74b1d29aUL, 0xead54739UL, 0x9dd277afUL, 0x04db2615UL,
+ 0x73dc1683UL, 0xe3630b12UL, 0x94643b84UL, 0x0d6d6a3eUL, 0x7a6a5aa8UL,
+ 0xe40ecf0bUL, 0x9309ff9dUL, 0x0a00ae27UL, 0x7d079eb1UL, 0xf00f9344UL,
+ 0x8708a3d2UL, 0x1e01f268UL, 0x6906c2feUL, 0xf762575dUL, 0x806567cbUL,
+ 0x196c3671UL, 0x6e6b06e7UL, 0xfed41b76UL, 0x89d32be0UL, 0x10da7a5aUL,
+ 0x67dd4accUL, 0xf9b9df6fUL, 0x8ebeeff9UL, 0x17b7be43UL, 0x60b08ed5UL,
+ 0xd6d6a3e8UL, 0xa1d1937eUL, 0x38d8c2c4UL, 0x4fdff252UL, 0xd1bb67f1UL,
+ 0xa6bc5767UL, 0x3fb506ddUL, 0x48b2364bUL, 0xd80d2bdaUL, 0xaf0a1b4cUL,
+ 0x36034af6UL, 0x41047a60UL, 0xdf60efc3UL, 0xa867df55UL, 0x316e8eefUL,
+ 0x4669be79UL, 0xcb61b38cUL, 0xbc66831aUL, 0x256fd2a0UL, 0x5268e236UL,
+ 0xcc0c7795UL, 0xbb0b4703UL, 0x220216b9UL, 0x5505262fUL, 0xc5ba3bbeUL,
+ 0xb2bd0b28UL, 0x2bb45a92UL, 0x5cb36a04UL, 0xc2d7ffa7UL, 0xb5d0cf31UL,
+ 0x2cd99e8bUL, 0x5bdeae1dUL, 0x9b64c2b0UL, 0xec63f226UL, 0x756aa39cUL,
+ 0x026d930aUL, 0x9c0906a9UL, 0xeb0e363fUL, 0x72076785UL, 0x05005713UL,
+ 0x95bf4a82UL, 0xe2b87a14UL, 0x7bb12baeUL, 0x0cb61b38UL, 0x92d28e9bUL,
+ 0xe5d5be0dUL, 0x7cdcefb7UL, 0x0bdbdf21UL, 0x86d3d2d4UL, 0xf1d4e242UL,
+ 0x68ddb3f8UL, 0x1fda836eUL, 0x81be16cdUL, 0xf6b9265bUL, 0x6fb077e1UL,
+ 0x18b74777UL, 0x88085ae6UL, 0xff0f6a70UL, 0x66063bcaUL, 0x11010b5cUL,
+ 0x8f659effUL, 0xf862ae69UL, 0x616bffd3UL, 0x166ccf45UL, 0xa00ae278UL,
+ 0xd70dd2eeUL, 0x4e048354UL, 0x3903b3c2UL, 0xa7672661UL, 0xd06016f7UL,
+ 0x4969474dUL, 0x3e6e77dbUL, 0xaed16a4aUL, 0xd9d65adcUL, 0x40df0b66UL,
+ 0x37d83bf0UL, 0xa9bcae53UL, 0xdebb9ec5UL, 0x47b2cf7fUL, 0x30b5ffe9UL,
+ 0xbdbdf21cUL, 0xcabac28aUL, 0x53b39330UL, 0x24b4a3a6UL, 0xbad03605UL,
+ 0xcdd70693UL, 0x54de5729UL, 0x23d967bfUL, 0xb3667a2eUL, 0xc4614ab8UL,
+ 0x5d681b02UL, 0x2a6f2b94UL, 0xb40bbe37UL, 0xc30c8ea1UL, 0x5a05df1bUL,
+ 0x2d02ef8dUL
+#ifdef BYFOUR
+ },
+ {
+ 0x00000000UL, 0x191b3141UL, 0x32366282UL, 0x2b2d53c3UL, 0x646cc504UL,
+ 0x7d77f445UL, 0x565aa786UL, 0x4f4196c7UL, 0xc8d98a08UL, 0xd1c2bb49UL,
+ 0xfaefe88aUL, 0xe3f4d9cbUL, 0xacb54f0cUL, 0xb5ae7e4dUL, 0x9e832d8eUL,
+ 0x87981ccfUL, 0x4ac21251UL, 0x53d92310UL, 0x78f470d3UL, 0x61ef4192UL,
+ 0x2eaed755UL, 0x37b5e614UL, 0x1c98b5d7UL, 0x05838496UL, 0x821b9859UL,
+ 0x9b00a918UL, 0xb02dfadbUL, 0xa936cb9aUL, 0xe6775d5dUL, 0xff6c6c1cUL,
+ 0xd4413fdfUL, 0xcd5a0e9eUL, 0x958424a2UL, 0x8c9f15e3UL, 0xa7b24620UL,
+ 0xbea97761UL, 0xf1e8e1a6UL, 0xe8f3d0e7UL, 0xc3de8324UL, 0xdac5b265UL,
+ 0x5d5daeaaUL, 0x44469febUL, 0x6f6bcc28UL, 0x7670fd69UL, 0x39316baeUL,
+ 0x202a5aefUL, 0x0b07092cUL, 0x121c386dUL, 0xdf4636f3UL, 0xc65d07b2UL,
+ 0xed705471UL, 0xf46b6530UL, 0xbb2af3f7UL, 0xa231c2b6UL, 0x891c9175UL,
+ 0x9007a034UL, 0x179fbcfbUL, 0x0e848dbaUL, 0x25a9de79UL, 0x3cb2ef38UL,
+ 0x73f379ffUL, 0x6ae848beUL, 0x41c51b7dUL, 0x58de2a3cUL, 0xf0794f05UL,
+ 0xe9627e44UL, 0xc24f2d87UL, 0xdb541cc6UL, 0x94158a01UL, 0x8d0ebb40UL,
+ 0xa623e883UL, 0xbf38d9c2UL, 0x38a0c50dUL, 0x21bbf44cUL, 0x0a96a78fUL,
+ 0x138d96ceUL, 0x5ccc0009UL, 0x45d73148UL, 0x6efa628bUL, 0x77e153caUL,
+ 0xbabb5d54UL, 0xa3a06c15UL, 0x888d3fd6UL, 0x91960e97UL, 0xded79850UL,
+ 0xc7cca911UL, 0xece1fad2UL, 0xf5facb93UL, 0x7262d75cUL, 0x6b79e61dUL,
+ 0x4054b5deUL, 0x594f849fUL, 0x160e1258UL, 0x0f152319UL, 0x243870daUL,
+ 0x3d23419bUL, 0x65fd6ba7UL, 0x7ce65ae6UL, 0x57cb0925UL, 0x4ed03864UL,
+ 0x0191aea3UL, 0x188a9fe2UL, 0x33a7cc21UL, 0x2abcfd60UL, 0xad24e1afUL,
+ 0xb43fd0eeUL, 0x9f12832dUL, 0x8609b26cUL, 0xc94824abUL, 0xd05315eaUL,
+ 0xfb7e4629UL, 0xe2657768UL, 0x2f3f79f6UL, 0x362448b7UL, 0x1d091b74UL,
+ 0x04122a35UL, 0x4b53bcf2UL, 0x52488db3UL, 0x7965de70UL, 0x607eef31UL,
+ 0xe7e6f3feUL, 0xfefdc2bfUL, 0xd5d0917cUL, 0xcccba03dUL, 0x838a36faUL,
+ 0x9a9107bbUL, 0xb1bc5478UL, 0xa8a76539UL, 0x3b83984bUL, 0x2298a90aUL,
+ 0x09b5fac9UL, 0x10aecb88UL, 0x5fef5d4fUL, 0x46f46c0eUL, 0x6dd93fcdUL,
+ 0x74c20e8cUL, 0xf35a1243UL, 0xea412302UL, 0xc16c70c1UL, 0xd8774180UL,
+ 0x9736d747UL, 0x8e2de606UL, 0xa500b5c5UL, 0xbc1b8484UL, 0x71418a1aUL,
+ 0x685abb5bUL, 0x4377e898UL, 0x5a6cd9d9UL, 0x152d4f1eUL, 0x0c367e5fUL,
+ 0x271b2d9cUL, 0x3e001cddUL, 0xb9980012UL, 0xa0833153UL, 0x8bae6290UL,
+ 0x92b553d1UL, 0xddf4c516UL, 0xc4eff457UL, 0xefc2a794UL, 0xf6d996d5UL,
+ 0xae07bce9UL, 0xb71c8da8UL, 0x9c31de6bUL, 0x852aef2aUL, 0xca6b79edUL,
+ 0xd37048acUL, 0xf85d1b6fUL, 0xe1462a2eUL, 0x66de36e1UL, 0x7fc507a0UL,
+ 0x54e85463UL, 0x4df36522UL, 0x02b2f3e5UL, 0x1ba9c2a4UL, 0x30849167UL,
+ 0x299fa026UL, 0xe4c5aeb8UL, 0xfdde9ff9UL, 0xd6f3cc3aUL, 0xcfe8fd7bUL,
+ 0x80a96bbcUL, 0x99b25afdUL, 0xb29f093eUL, 0xab84387fUL, 0x2c1c24b0UL,
+ 0x350715f1UL, 0x1e2a4632UL, 0x07317773UL, 0x4870e1b4UL, 0x516bd0f5UL,
+ 0x7a468336UL, 0x635db277UL, 0xcbfad74eUL, 0xd2e1e60fUL, 0xf9ccb5ccUL,
+ 0xe0d7848dUL, 0xaf96124aUL, 0xb68d230bUL, 0x9da070c8UL, 0x84bb4189UL,
+ 0x03235d46UL, 0x1a386c07UL, 0x31153fc4UL, 0x280e0e85UL, 0x674f9842UL,
+ 0x7e54a903UL, 0x5579fac0UL, 0x4c62cb81UL, 0x8138c51fUL, 0x9823f45eUL,
+ 0xb30ea79dUL, 0xaa1596dcUL, 0xe554001bUL, 0xfc4f315aUL, 0xd7626299UL,
+ 0xce7953d8UL, 0x49e14f17UL, 0x50fa7e56UL, 0x7bd72d95UL, 0x62cc1cd4UL,
+ 0x2d8d8a13UL, 0x3496bb52UL, 0x1fbbe891UL, 0x06a0d9d0UL, 0x5e7ef3ecUL,
+ 0x4765c2adUL, 0x6c48916eUL, 0x7553a02fUL, 0x3a1236e8UL, 0x230907a9UL,
+ 0x0824546aUL, 0x113f652bUL, 0x96a779e4UL, 0x8fbc48a5UL, 0xa4911b66UL,
+ 0xbd8a2a27UL, 0xf2cbbce0UL, 0xebd08da1UL, 0xc0fdde62UL, 0xd9e6ef23UL,
+ 0x14bce1bdUL, 0x0da7d0fcUL, 0x268a833fUL, 0x3f91b27eUL, 0x70d024b9UL,
+ 0x69cb15f8UL, 0x42e6463bUL, 0x5bfd777aUL, 0xdc656bb5UL, 0xc57e5af4UL,
+ 0xee530937UL, 0xf7483876UL, 0xb809aeb1UL, 0xa1129ff0UL, 0x8a3fcc33UL,
+ 0x9324fd72UL
+ },
+ {
+ 0x00000000UL, 0x01c26a37UL, 0x0384d46eUL, 0x0246be59UL, 0x0709a8dcUL,
+ 0x06cbc2ebUL, 0x048d7cb2UL, 0x054f1685UL, 0x0e1351b8UL, 0x0fd13b8fUL,
+ 0x0d9785d6UL, 0x0c55efe1UL, 0x091af964UL, 0x08d89353UL, 0x0a9e2d0aUL,
+ 0x0b5c473dUL, 0x1c26a370UL, 0x1de4c947UL, 0x1fa2771eUL, 0x1e601d29UL,
+ 0x1b2f0bacUL, 0x1aed619bUL, 0x18abdfc2UL, 0x1969b5f5UL, 0x1235f2c8UL,
+ 0x13f798ffUL, 0x11b126a6UL, 0x10734c91UL, 0x153c5a14UL, 0x14fe3023UL,
+ 0x16b88e7aUL, 0x177ae44dUL, 0x384d46e0UL, 0x398f2cd7UL, 0x3bc9928eUL,
+ 0x3a0bf8b9UL, 0x3f44ee3cUL, 0x3e86840bUL, 0x3cc03a52UL, 0x3d025065UL,
+ 0x365e1758UL, 0x379c7d6fUL, 0x35dac336UL, 0x3418a901UL, 0x3157bf84UL,
+ 0x3095d5b3UL, 0x32d36beaUL, 0x331101ddUL, 0x246be590UL, 0x25a98fa7UL,
+ 0x27ef31feUL, 0x262d5bc9UL, 0x23624d4cUL, 0x22a0277bUL, 0x20e69922UL,
+ 0x2124f315UL, 0x2a78b428UL, 0x2bbade1fUL, 0x29fc6046UL, 0x283e0a71UL,
+ 0x2d711cf4UL, 0x2cb376c3UL, 0x2ef5c89aUL, 0x2f37a2adUL, 0x709a8dc0UL,
+ 0x7158e7f7UL, 0x731e59aeUL, 0x72dc3399UL, 0x7793251cUL, 0x76514f2bUL,
+ 0x7417f172UL, 0x75d59b45UL, 0x7e89dc78UL, 0x7f4bb64fUL, 0x7d0d0816UL,
+ 0x7ccf6221UL, 0x798074a4UL, 0x78421e93UL, 0x7a04a0caUL, 0x7bc6cafdUL,
+ 0x6cbc2eb0UL, 0x6d7e4487UL, 0x6f38fadeUL, 0x6efa90e9UL, 0x6bb5866cUL,
+ 0x6a77ec5bUL, 0x68315202UL, 0x69f33835UL, 0x62af7f08UL, 0x636d153fUL,
+ 0x612bab66UL, 0x60e9c151UL, 0x65a6d7d4UL, 0x6464bde3UL, 0x662203baUL,
+ 0x67e0698dUL, 0x48d7cb20UL, 0x4915a117UL, 0x4b531f4eUL, 0x4a917579UL,
+ 0x4fde63fcUL, 0x4e1c09cbUL, 0x4c5ab792UL, 0x4d98dda5UL, 0x46c49a98UL,
+ 0x4706f0afUL, 0x45404ef6UL, 0x448224c1UL, 0x41cd3244UL, 0x400f5873UL,
+ 0x4249e62aUL, 0x438b8c1dUL, 0x54f16850UL, 0x55330267UL, 0x5775bc3eUL,
+ 0x56b7d609UL, 0x53f8c08cUL, 0x523aaabbUL, 0x507c14e2UL, 0x51be7ed5UL,
+ 0x5ae239e8UL, 0x5b2053dfUL, 0x5966ed86UL, 0x58a487b1UL, 0x5deb9134UL,
+ 0x5c29fb03UL, 0x5e6f455aUL, 0x5fad2f6dUL, 0xe1351b80UL, 0xe0f771b7UL,
+ 0xe2b1cfeeUL, 0xe373a5d9UL, 0xe63cb35cUL, 0xe7fed96bUL, 0xe5b86732UL,
+ 0xe47a0d05UL, 0xef264a38UL, 0xeee4200fUL, 0xeca29e56UL, 0xed60f461UL,
+ 0xe82fe2e4UL, 0xe9ed88d3UL, 0xebab368aUL, 0xea695cbdUL, 0xfd13b8f0UL,
+ 0xfcd1d2c7UL, 0xfe976c9eUL, 0xff5506a9UL, 0xfa1a102cUL, 0xfbd87a1bUL,
+ 0xf99ec442UL, 0xf85cae75UL, 0xf300e948UL, 0xf2c2837fUL, 0xf0843d26UL,
+ 0xf1465711UL, 0xf4094194UL, 0xf5cb2ba3UL, 0xf78d95faUL, 0xf64fffcdUL,
+ 0xd9785d60UL, 0xd8ba3757UL, 0xdafc890eUL, 0xdb3ee339UL, 0xde71f5bcUL,
+ 0xdfb39f8bUL, 0xddf521d2UL, 0xdc374be5UL, 0xd76b0cd8UL, 0xd6a966efUL,
+ 0xd4efd8b6UL, 0xd52db281UL, 0xd062a404UL, 0xd1a0ce33UL, 0xd3e6706aUL,
+ 0xd2241a5dUL, 0xc55efe10UL, 0xc49c9427UL, 0xc6da2a7eUL, 0xc7184049UL,
+ 0xc25756ccUL, 0xc3953cfbUL, 0xc1d382a2UL, 0xc011e895UL, 0xcb4dafa8UL,
+ 0xca8fc59fUL, 0xc8c97bc6UL, 0xc90b11f1UL, 0xcc440774UL, 0xcd866d43UL,
+ 0xcfc0d31aUL, 0xce02b92dUL, 0x91af9640UL, 0x906dfc77UL, 0x922b422eUL,
+ 0x93e92819UL, 0x96a63e9cUL, 0x976454abUL, 0x9522eaf2UL, 0x94e080c5UL,
+ 0x9fbcc7f8UL, 0x9e7eadcfUL, 0x9c381396UL, 0x9dfa79a1UL, 0x98b56f24UL,
+ 0x99770513UL, 0x9b31bb4aUL, 0x9af3d17dUL, 0x8d893530UL, 0x8c4b5f07UL,
+ 0x8e0de15eUL, 0x8fcf8b69UL, 0x8a809decUL, 0x8b42f7dbUL, 0x89044982UL,
+ 0x88c623b5UL, 0x839a6488UL, 0x82580ebfUL, 0x801eb0e6UL, 0x81dcdad1UL,
+ 0x8493cc54UL, 0x8551a663UL, 0x8717183aUL, 0x86d5720dUL, 0xa9e2d0a0UL,
+ 0xa820ba97UL, 0xaa6604ceUL, 0xaba46ef9UL, 0xaeeb787cUL, 0xaf29124bUL,
+ 0xad6fac12UL, 0xacadc625UL, 0xa7f18118UL, 0xa633eb2fUL, 0xa4755576UL,
+ 0xa5b73f41UL, 0xa0f829c4UL, 0xa13a43f3UL, 0xa37cfdaaUL, 0xa2be979dUL,
+ 0xb5c473d0UL, 0xb40619e7UL, 0xb640a7beUL, 0xb782cd89UL, 0xb2cddb0cUL,
+ 0xb30fb13bUL, 0xb1490f62UL, 0xb08b6555UL, 0xbbd72268UL, 0xba15485fUL,
+ 0xb853f606UL, 0xb9919c31UL, 0xbcde8ab4UL, 0xbd1ce083UL, 0xbf5a5edaUL,
+ 0xbe9834edUL
+ },
+ {
+ 0x00000000UL, 0xb8bc6765UL, 0xaa09c88bUL, 0x12b5afeeUL, 0x8f629757UL,
+ 0x37def032UL, 0x256b5fdcUL, 0x9dd738b9UL, 0xc5b428efUL, 0x7d084f8aUL,
+ 0x6fbde064UL, 0xd7018701UL, 0x4ad6bfb8UL, 0xf26ad8ddUL, 0xe0df7733UL,
+ 0x58631056UL, 0x5019579fUL, 0xe8a530faUL, 0xfa109f14UL, 0x42acf871UL,
+ 0xdf7bc0c8UL, 0x67c7a7adUL, 0x75720843UL, 0xcdce6f26UL, 0x95ad7f70UL,
+ 0x2d111815UL, 0x3fa4b7fbUL, 0x8718d09eUL, 0x1acfe827UL, 0xa2738f42UL,
+ 0xb0c620acUL, 0x087a47c9UL, 0xa032af3eUL, 0x188ec85bUL, 0x0a3b67b5UL,
+ 0xb28700d0UL, 0x2f503869UL, 0x97ec5f0cUL, 0x8559f0e2UL, 0x3de59787UL,
+ 0x658687d1UL, 0xdd3ae0b4UL, 0xcf8f4f5aUL, 0x7733283fUL, 0xeae41086UL,
+ 0x525877e3UL, 0x40edd80dUL, 0xf851bf68UL, 0xf02bf8a1UL, 0x48979fc4UL,
+ 0x5a22302aUL, 0xe29e574fUL, 0x7f496ff6UL, 0xc7f50893UL, 0xd540a77dUL,
+ 0x6dfcc018UL, 0x359fd04eUL, 0x8d23b72bUL, 0x9f9618c5UL, 0x272a7fa0UL,
+ 0xbafd4719UL, 0x0241207cUL, 0x10f48f92UL, 0xa848e8f7UL, 0x9b14583dUL,
+ 0x23a83f58UL, 0x311d90b6UL, 0x89a1f7d3UL, 0x1476cf6aUL, 0xaccaa80fUL,
+ 0xbe7f07e1UL, 0x06c36084UL, 0x5ea070d2UL, 0xe61c17b7UL, 0xf4a9b859UL,
+ 0x4c15df3cUL, 0xd1c2e785UL, 0x697e80e0UL, 0x7bcb2f0eUL, 0xc377486bUL,
+ 0xcb0d0fa2UL, 0x73b168c7UL, 0x6104c729UL, 0xd9b8a04cUL, 0x446f98f5UL,
+ 0xfcd3ff90UL, 0xee66507eUL, 0x56da371bUL, 0x0eb9274dUL, 0xb6054028UL,
+ 0xa4b0efc6UL, 0x1c0c88a3UL, 0x81dbb01aUL, 0x3967d77fUL, 0x2bd27891UL,
+ 0x936e1ff4UL, 0x3b26f703UL, 0x839a9066UL, 0x912f3f88UL, 0x299358edUL,
+ 0xb4446054UL, 0x0cf80731UL, 0x1e4da8dfUL, 0xa6f1cfbaUL, 0xfe92dfecUL,
+ 0x462eb889UL, 0x549b1767UL, 0xec277002UL, 0x71f048bbUL, 0xc94c2fdeUL,
+ 0xdbf98030UL, 0x6345e755UL, 0x6b3fa09cUL, 0xd383c7f9UL, 0xc1366817UL,
+ 0x798a0f72UL, 0xe45d37cbUL, 0x5ce150aeUL, 0x4e54ff40UL, 0xf6e89825UL,
+ 0xae8b8873UL, 0x1637ef16UL, 0x048240f8UL, 0xbc3e279dUL, 0x21e91f24UL,
+ 0x99557841UL, 0x8be0d7afUL, 0x335cb0caUL, 0xed59b63bUL, 0x55e5d15eUL,
+ 0x47507eb0UL, 0xffec19d5UL, 0x623b216cUL, 0xda874609UL, 0xc832e9e7UL,
+ 0x708e8e82UL, 0x28ed9ed4UL, 0x9051f9b1UL, 0x82e4565fUL, 0x3a58313aUL,
+ 0xa78f0983UL, 0x1f336ee6UL, 0x0d86c108UL, 0xb53aa66dUL, 0xbd40e1a4UL,
+ 0x05fc86c1UL, 0x1749292fUL, 0xaff54e4aUL, 0x322276f3UL, 0x8a9e1196UL,
+ 0x982bbe78UL, 0x2097d91dUL, 0x78f4c94bUL, 0xc048ae2eUL, 0xd2fd01c0UL,
+ 0x6a4166a5UL, 0xf7965e1cUL, 0x4f2a3979UL, 0x5d9f9697UL, 0xe523f1f2UL,
+ 0x4d6b1905UL, 0xf5d77e60UL, 0xe762d18eUL, 0x5fdeb6ebUL, 0xc2098e52UL,
+ 0x7ab5e937UL, 0x680046d9UL, 0xd0bc21bcUL, 0x88df31eaUL, 0x3063568fUL,
+ 0x22d6f961UL, 0x9a6a9e04UL, 0x07bda6bdUL, 0xbf01c1d8UL, 0xadb46e36UL,
+ 0x15080953UL, 0x1d724e9aUL, 0xa5ce29ffUL, 0xb77b8611UL, 0x0fc7e174UL,
+ 0x9210d9cdUL, 0x2aacbea8UL, 0x38191146UL, 0x80a57623UL, 0xd8c66675UL,
+ 0x607a0110UL, 0x72cfaefeUL, 0xca73c99bUL, 0x57a4f122UL, 0xef189647UL,
+ 0xfdad39a9UL, 0x45115eccUL, 0x764dee06UL, 0xcef18963UL, 0xdc44268dUL,
+ 0x64f841e8UL, 0xf92f7951UL, 0x41931e34UL, 0x5326b1daUL, 0xeb9ad6bfUL,
+ 0xb3f9c6e9UL, 0x0b45a18cUL, 0x19f00e62UL, 0xa14c6907UL, 0x3c9b51beUL,
+ 0x842736dbUL, 0x96929935UL, 0x2e2efe50UL, 0x2654b999UL, 0x9ee8defcUL,
+ 0x8c5d7112UL, 0x34e11677UL, 0xa9362eceUL, 0x118a49abUL, 0x033fe645UL,
+ 0xbb838120UL, 0xe3e09176UL, 0x5b5cf613UL, 0x49e959fdUL, 0xf1553e98UL,
+ 0x6c820621UL, 0xd43e6144UL, 0xc68bceaaUL, 0x7e37a9cfUL, 0xd67f4138UL,
+ 0x6ec3265dUL, 0x7c7689b3UL, 0xc4caeed6UL, 0x591dd66fUL, 0xe1a1b10aUL,
+ 0xf3141ee4UL, 0x4ba87981UL, 0x13cb69d7UL, 0xab770eb2UL, 0xb9c2a15cUL,
+ 0x017ec639UL, 0x9ca9fe80UL, 0x241599e5UL, 0x36a0360bUL, 0x8e1c516eUL,
+ 0x866616a7UL, 0x3eda71c2UL, 0x2c6fde2cUL, 0x94d3b949UL, 0x090481f0UL,
+ 0xb1b8e695UL, 0xa30d497bUL, 0x1bb12e1eUL, 0x43d23e48UL, 0xfb6e592dUL,
+ 0xe9dbf6c3UL, 0x516791a6UL, 0xccb0a91fUL, 0x740cce7aUL, 0x66b96194UL,
+ 0xde0506f1UL
+ },
+ {
+ 0x00000000UL, 0x96300777UL, 0x2c610eeeUL, 0xba510999UL, 0x19c46d07UL,
+ 0x8ff46a70UL, 0x35a563e9UL, 0xa395649eUL, 0x3288db0eUL, 0xa4b8dc79UL,
+ 0x1ee9d5e0UL, 0x88d9d297UL, 0x2b4cb609UL, 0xbd7cb17eUL, 0x072db8e7UL,
+ 0x911dbf90UL, 0x6410b71dUL, 0xf220b06aUL, 0x4871b9f3UL, 0xde41be84UL,
+ 0x7dd4da1aUL, 0xebe4dd6dUL, 0x51b5d4f4UL, 0xc785d383UL, 0x56986c13UL,
+ 0xc0a86b64UL, 0x7af962fdUL, 0xecc9658aUL, 0x4f5c0114UL, 0xd96c0663UL,
+ 0x633d0ffaUL, 0xf50d088dUL, 0xc8206e3bUL, 0x5e10694cUL, 0xe44160d5UL,
+ 0x727167a2UL, 0xd1e4033cUL, 0x47d4044bUL, 0xfd850dd2UL, 0x6bb50aa5UL,
+ 0xfaa8b535UL, 0x6c98b242UL, 0xd6c9bbdbUL, 0x40f9bcacUL, 0xe36cd832UL,
+ 0x755cdf45UL, 0xcf0dd6dcUL, 0x593dd1abUL, 0xac30d926UL, 0x3a00de51UL,
+ 0x8051d7c8UL, 0x1661d0bfUL, 0xb5f4b421UL, 0x23c4b356UL, 0x9995bacfUL,
+ 0x0fa5bdb8UL, 0x9eb80228UL, 0x0888055fUL, 0xb2d90cc6UL, 0x24e90bb1UL,
+ 0x877c6f2fUL, 0x114c6858UL, 0xab1d61c1UL, 0x3d2d66b6UL, 0x9041dc76UL,
+ 0x0671db01UL, 0xbc20d298UL, 0x2a10d5efUL, 0x8985b171UL, 0x1fb5b606UL,
+ 0xa5e4bf9fUL, 0x33d4b8e8UL, 0xa2c90778UL, 0x34f9000fUL, 0x8ea80996UL,
+ 0x18980ee1UL, 0xbb0d6a7fUL, 0x2d3d6d08UL, 0x976c6491UL, 0x015c63e6UL,
+ 0xf4516b6bUL, 0x62616c1cUL, 0xd8306585UL, 0x4e0062f2UL, 0xed95066cUL,
+ 0x7ba5011bUL, 0xc1f40882UL, 0x57c40ff5UL, 0xc6d9b065UL, 0x50e9b712UL,
+ 0xeab8be8bUL, 0x7c88b9fcUL, 0xdf1ddd62UL, 0x492dda15UL, 0xf37cd38cUL,
+ 0x654cd4fbUL, 0x5861b24dUL, 0xce51b53aUL, 0x7400bca3UL, 0xe230bbd4UL,
+ 0x41a5df4aUL, 0xd795d83dUL, 0x6dc4d1a4UL, 0xfbf4d6d3UL, 0x6ae96943UL,
+ 0xfcd96e34UL, 0x468867adUL, 0xd0b860daUL, 0x732d0444UL, 0xe51d0333UL,
+ 0x5f4c0aaaUL, 0xc97c0dddUL, 0x3c710550UL, 0xaa410227UL, 0x10100bbeUL,
+ 0x86200cc9UL, 0x25b56857UL, 0xb3856f20UL, 0x09d466b9UL, 0x9fe461ceUL,
+ 0x0ef9de5eUL, 0x98c9d929UL, 0x2298d0b0UL, 0xb4a8d7c7UL, 0x173db359UL,
+ 0x810db42eUL, 0x3b5cbdb7UL, 0xad6cbac0UL, 0x2083b8edUL, 0xb6b3bf9aUL,
+ 0x0ce2b603UL, 0x9ad2b174UL, 0x3947d5eaUL, 0xaf77d29dUL, 0x1526db04UL,
+ 0x8316dc73UL, 0x120b63e3UL, 0x843b6494UL, 0x3e6a6d0dUL, 0xa85a6a7aUL,
+ 0x0bcf0ee4UL, 0x9dff0993UL, 0x27ae000aUL, 0xb19e077dUL, 0x44930ff0UL,
+ 0xd2a30887UL, 0x68f2011eUL, 0xfec20669UL, 0x5d5762f7UL, 0xcb676580UL,
+ 0x71366c19UL, 0xe7066b6eUL, 0x761bd4feUL, 0xe02bd389UL, 0x5a7ada10UL,
+ 0xcc4add67UL, 0x6fdfb9f9UL, 0xf9efbe8eUL, 0x43beb717UL, 0xd58eb060UL,
+ 0xe8a3d6d6UL, 0x7e93d1a1UL, 0xc4c2d838UL, 0x52f2df4fUL, 0xf167bbd1UL,
+ 0x6757bca6UL, 0xdd06b53fUL, 0x4b36b248UL, 0xda2b0dd8UL, 0x4c1b0aafUL,
+ 0xf64a0336UL, 0x607a0441UL, 0xc3ef60dfUL, 0x55df67a8UL, 0xef8e6e31UL,
+ 0x79be6946UL, 0x8cb361cbUL, 0x1a8366bcUL, 0xa0d26f25UL, 0x36e26852UL,
+ 0x95770cccUL, 0x03470bbbUL, 0xb9160222UL, 0x2f260555UL, 0xbe3bbac5UL,
+ 0x280bbdb2UL, 0x925ab42bUL, 0x046ab35cUL, 0xa7ffd7c2UL, 0x31cfd0b5UL,
+ 0x8b9ed92cUL, 0x1daede5bUL, 0xb0c2649bUL, 0x26f263ecUL, 0x9ca36a75UL,
+ 0x0a936d02UL, 0xa906099cUL, 0x3f360eebUL, 0x85670772UL, 0x13570005UL,
+ 0x824abf95UL, 0x147ab8e2UL, 0xae2bb17bUL, 0x381bb60cUL, 0x9b8ed292UL,
+ 0x0dbed5e5UL, 0xb7efdc7cUL, 0x21dfdb0bUL, 0xd4d2d386UL, 0x42e2d4f1UL,
+ 0xf8b3dd68UL, 0x6e83da1fUL, 0xcd16be81UL, 0x5b26b9f6UL, 0xe177b06fUL,
+ 0x7747b718UL, 0xe65a0888UL, 0x706a0fffUL, 0xca3b0666UL, 0x5c0b0111UL,
+ 0xff9e658fUL, 0x69ae62f8UL, 0xd3ff6b61UL, 0x45cf6c16UL, 0x78e20aa0UL,
+ 0xeed20dd7UL, 0x5483044eUL, 0xc2b30339UL, 0x612667a7UL, 0xf71660d0UL,
+ 0x4d476949UL, 0xdb776e3eUL, 0x4a6ad1aeUL, 0xdc5ad6d9UL, 0x660bdf40UL,
+ 0xf03bd837UL, 0x53aebca9UL, 0xc59ebbdeUL, 0x7fcfb247UL, 0xe9ffb530UL,
+ 0x1cf2bdbdUL, 0x8ac2bacaUL, 0x3093b353UL, 0xa6a3b424UL, 0x0536d0baUL,
+ 0x9306d7cdUL, 0x2957de54UL, 0xbf67d923UL, 0x2e7a66b3UL, 0xb84a61c4UL,
+ 0x021b685dUL, 0x942b6f2aUL, 0x37be0bb4UL, 0xa18e0cc3UL, 0x1bdf055aUL,
+ 0x8def022dUL
+ },
+ {
+ 0x00000000UL, 0x41311b19UL, 0x82623632UL, 0xc3532d2bUL, 0x04c56c64UL,
+ 0x45f4777dUL, 0x86a75a56UL, 0xc796414fUL, 0x088ad9c8UL, 0x49bbc2d1UL,
+ 0x8ae8effaUL, 0xcbd9f4e3UL, 0x0c4fb5acUL, 0x4d7eaeb5UL, 0x8e2d839eUL,
+ 0xcf1c9887UL, 0x5112c24aUL, 0x1023d953UL, 0xd370f478UL, 0x9241ef61UL,
+ 0x55d7ae2eUL, 0x14e6b537UL, 0xd7b5981cUL, 0x96848305UL, 0x59981b82UL,
+ 0x18a9009bUL, 0xdbfa2db0UL, 0x9acb36a9UL, 0x5d5d77e6UL, 0x1c6c6cffUL,
+ 0xdf3f41d4UL, 0x9e0e5acdUL, 0xa2248495UL, 0xe3159f8cUL, 0x2046b2a7UL,
+ 0x6177a9beUL, 0xa6e1e8f1UL, 0xe7d0f3e8UL, 0x2483dec3UL, 0x65b2c5daUL,
+ 0xaaae5d5dUL, 0xeb9f4644UL, 0x28cc6b6fUL, 0x69fd7076UL, 0xae6b3139UL,
+ 0xef5a2a20UL, 0x2c09070bUL, 0x6d381c12UL, 0xf33646dfUL, 0xb2075dc6UL,
+ 0x715470edUL, 0x30656bf4UL, 0xf7f32abbUL, 0xb6c231a2UL, 0x75911c89UL,
+ 0x34a00790UL, 0xfbbc9f17UL, 0xba8d840eUL, 0x79dea925UL, 0x38efb23cUL,
+ 0xff79f373UL, 0xbe48e86aUL, 0x7d1bc541UL, 0x3c2ade58UL, 0x054f79f0UL,
+ 0x447e62e9UL, 0x872d4fc2UL, 0xc61c54dbUL, 0x018a1594UL, 0x40bb0e8dUL,
+ 0x83e823a6UL, 0xc2d938bfUL, 0x0dc5a038UL, 0x4cf4bb21UL, 0x8fa7960aUL,
+ 0xce968d13UL, 0x0900cc5cUL, 0x4831d745UL, 0x8b62fa6eUL, 0xca53e177UL,
+ 0x545dbbbaUL, 0x156ca0a3UL, 0xd63f8d88UL, 0x970e9691UL, 0x5098d7deUL,
+ 0x11a9ccc7UL, 0xd2fae1ecUL, 0x93cbfaf5UL, 0x5cd76272UL, 0x1de6796bUL,
+ 0xdeb55440UL, 0x9f844f59UL, 0x58120e16UL, 0x1923150fUL, 0xda703824UL,
+ 0x9b41233dUL, 0xa76bfd65UL, 0xe65ae67cUL, 0x2509cb57UL, 0x6438d04eUL,
+ 0xa3ae9101UL, 0xe29f8a18UL, 0x21cca733UL, 0x60fdbc2aUL, 0xafe124adUL,
+ 0xeed03fb4UL, 0x2d83129fUL, 0x6cb20986UL, 0xab2448c9UL, 0xea1553d0UL,
+ 0x29467efbUL, 0x687765e2UL, 0xf6793f2fUL, 0xb7482436UL, 0x741b091dUL,
+ 0x352a1204UL, 0xf2bc534bUL, 0xb38d4852UL, 0x70de6579UL, 0x31ef7e60UL,
+ 0xfef3e6e7UL, 0xbfc2fdfeUL, 0x7c91d0d5UL, 0x3da0cbccUL, 0xfa368a83UL,
+ 0xbb07919aUL, 0x7854bcb1UL, 0x3965a7a8UL, 0x4b98833bUL, 0x0aa99822UL,
+ 0xc9fab509UL, 0x88cbae10UL, 0x4f5def5fUL, 0x0e6cf446UL, 0xcd3fd96dUL,
+ 0x8c0ec274UL, 0x43125af3UL, 0x022341eaUL, 0xc1706cc1UL, 0x804177d8UL,
+ 0x47d73697UL, 0x06e62d8eUL, 0xc5b500a5UL, 0x84841bbcUL, 0x1a8a4171UL,
+ 0x5bbb5a68UL, 0x98e87743UL, 0xd9d96c5aUL, 0x1e4f2d15UL, 0x5f7e360cUL,
+ 0x9c2d1b27UL, 0xdd1c003eUL, 0x120098b9UL, 0x533183a0UL, 0x9062ae8bUL,
+ 0xd153b592UL, 0x16c5f4ddUL, 0x57f4efc4UL, 0x94a7c2efUL, 0xd596d9f6UL,
+ 0xe9bc07aeUL, 0xa88d1cb7UL, 0x6bde319cUL, 0x2aef2a85UL, 0xed796bcaUL,
+ 0xac4870d3UL, 0x6f1b5df8UL, 0x2e2a46e1UL, 0xe136de66UL, 0xa007c57fUL,
+ 0x6354e854UL, 0x2265f34dUL, 0xe5f3b202UL, 0xa4c2a91bUL, 0x67918430UL,
+ 0x26a09f29UL, 0xb8aec5e4UL, 0xf99fdefdUL, 0x3accf3d6UL, 0x7bfde8cfUL,
+ 0xbc6ba980UL, 0xfd5ab299UL, 0x3e099fb2UL, 0x7f3884abUL, 0xb0241c2cUL,
+ 0xf1150735UL, 0x32462a1eUL, 0x73773107UL, 0xb4e17048UL, 0xf5d06b51UL,
+ 0x3683467aUL, 0x77b25d63UL, 0x4ed7facbUL, 0x0fe6e1d2UL, 0xccb5ccf9UL,
+ 0x8d84d7e0UL, 0x4a1296afUL, 0x0b238db6UL, 0xc870a09dUL, 0x8941bb84UL,
+ 0x465d2303UL, 0x076c381aUL, 0xc43f1531UL, 0x850e0e28UL, 0x42984f67UL,
+ 0x03a9547eUL, 0xc0fa7955UL, 0x81cb624cUL, 0x1fc53881UL, 0x5ef42398UL,
+ 0x9da70eb3UL, 0xdc9615aaUL, 0x1b0054e5UL, 0x5a314ffcUL, 0x996262d7UL,
+ 0xd85379ceUL, 0x174fe149UL, 0x567efa50UL, 0x952dd77bUL, 0xd41ccc62UL,
+ 0x138a8d2dUL, 0x52bb9634UL, 0x91e8bb1fUL, 0xd0d9a006UL, 0xecf37e5eUL,
+ 0xadc26547UL, 0x6e91486cUL, 0x2fa05375UL, 0xe836123aUL, 0xa9070923UL,
+ 0x6a542408UL, 0x2b653f11UL, 0xe479a796UL, 0xa548bc8fUL, 0x661b91a4UL,
+ 0x272a8abdUL, 0xe0bccbf2UL, 0xa18dd0ebUL, 0x62defdc0UL, 0x23efe6d9UL,
+ 0xbde1bc14UL, 0xfcd0a70dUL, 0x3f838a26UL, 0x7eb2913fUL, 0xb924d070UL,
+ 0xf815cb69UL, 0x3b46e642UL, 0x7a77fd5bUL, 0xb56b65dcUL, 0xf45a7ec5UL,
+ 0x370953eeUL, 0x763848f7UL, 0xb1ae09b8UL, 0xf09f12a1UL, 0x33cc3f8aUL,
+ 0x72fd2493UL
+ },
+ {
+ 0x00000000UL, 0x376ac201UL, 0x6ed48403UL, 0x59be4602UL, 0xdca80907UL,
+ 0xebc2cb06UL, 0xb27c8d04UL, 0x85164f05UL, 0xb851130eUL, 0x8f3bd10fUL,
+ 0xd685970dUL, 0xe1ef550cUL, 0x64f91a09UL, 0x5393d808UL, 0x0a2d9e0aUL,
+ 0x3d475c0bUL, 0x70a3261cUL, 0x47c9e41dUL, 0x1e77a21fUL, 0x291d601eUL,
+ 0xac0b2f1bUL, 0x9b61ed1aUL, 0xc2dfab18UL, 0xf5b56919UL, 0xc8f23512UL,
+ 0xff98f713UL, 0xa626b111UL, 0x914c7310UL, 0x145a3c15UL, 0x2330fe14UL,
+ 0x7a8eb816UL, 0x4de47a17UL, 0xe0464d38UL, 0xd72c8f39UL, 0x8e92c93bUL,
+ 0xb9f80b3aUL, 0x3cee443fUL, 0x0b84863eUL, 0x523ac03cUL, 0x6550023dUL,
+ 0x58175e36UL, 0x6f7d9c37UL, 0x36c3da35UL, 0x01a91834UL, 0x84bf5731UL,
+ 0xb3d59530UL, 0xea6bd332UL, 0xdd011133UL, 0x90e56b24UL, 0xa78fa925UL,
+ 0xfe31ef27UL, 0xc95b2d26UL, 0x4c4d6223UL, 0x7b27a022UL, 0x2299e620UL,
+ 0x15f32421UL, 0x28b4782aUL, 0x1fdeba2bUL, 0x4660fc29UL, 0x710a3e28UL,
+ 0xf41c712dUL, 0xc376b32cUL, 0x9ac8f52eUL, 0xada2372fUL, 0xc08d9a70UL,
+ 0xf7e75871UL, 0xae591e73UL, 0x9933dc72UL, 0x1c259377UL, 0x2b4f5176UL,
+ 0x72f11774UL, 0x459bd575UL, 0x78dc897eUL, 0x4fb64b7fUL, 0x16080d7dUL,
+ 0x2162cf7cUL, 0xa4748079UL, 0x931e4278UL, 0xcaa0047aUL, 0xfdcac67bUL,
+ 0xb02ebc6cUL, 0x87447e6dUL, 0xdefa386fUL, 0xe990fa6eUL, 0x6c86b56bUL,
+ 0x5bec776aUL, 0x02523168UL, 0x3538f369UL, 0x087faf62UL, 0x3f156d63UL,
+ 0x66ab2b61UL, 0x51c1e960UL, 0xd4d7a665UL, 0xe3bd6464UL, 0xba032266UL,
+ 0x8d69e067UL, 0x20cbd748UL, 0x17a11549UL, 0x4e1f534bUL, 0x7975914aUL,
+ 0xfc63de4fUL, 0xcb091c4eUL, 0x92b75a4cUL, 0xa5dd984dUL, 0x989ac446UL,
+ 0xaff00647UL, 0xf64e4045UL, 0xc1248244UL, 0x4432cd41UL, 0x73580f40UL,
+ 0x2ae64942UL, 0x1d8c8b43UL, 0x5068f154UL, 0x67023355UL, 0x3ebc7557UL,
+ 0x09d6b756UL, 0x8cc0f853UL, 0xbbaa3a52UL, 0xe2147c50UL, 0xd57ebe51UL,
+ 0xe839e25aUL, 0xdf53205bUL, 0x86ed6659UL, 0xb187a458UL, 0x3491eb5dUL,
+ 0x03fb295cUL, 0x5a456f5eUL, 0x6d2fad5fUL, 0x801b35e1UL, 0xb771f7e0UL,
+ 0xeecfb1e2UL, 0xd9a573e3UL, 0x5cb33ce6UL, 0x6bd9fee7UL, 0x3267b8e5UL,
+ 0x050d7ae4UL, 0x384a26efUL, 0x0f20e4eeUL, 0x569ea2ecUL, 0x61f460edUL,
+ 0xe4e22fe8UL, 0xd388ede9UL, 0x8a36abebUL, 0xbd5c69eaUL, 0xf0b813fdUL,
+ 0xc7d2d1fcUL, 0x9e6c97feUL, 0xa90655ffUL, 0x2c101afaUL, 0x1b7ad8fbUL,
+ 0x42c49ef9UL, 0x75ae5cf8UL, 0x48e900f3UL, 0x7f83c2f2UL, 0x263d84f0UL,
+ 0x115746f1UL, 0x944109f4UL, 0xa32bcbf5UL, 0xfa958df7UL, 0xcdff4ff6UL,
+ 0x605d78d9UL, 0x5737bad8UL, 0x0e89fcdaUL, 0x39e33edbUL, 0xbcf571deUL,
+ 0x8b9fb3dfUL, 0xd221f5ddUL, 0xe54b37dcUL, 0xd80c6bd7UL, 0xef66a9d6UL,
+ 0xb6d8efd4UL, 0x81b22dd5UL, 0x04a462d0UL, 0x33cea0d1UL, 0x6a70e6d3UL,
+ 0x5d1a24d2UL, 0x10fe5ec5UL, 0x27949cc4UL, 0x7e2adac6UL, 0x494018c7UL,
+ 0xcc5657c2UL, 0xfb3c95c3UL, 0xa282d3c1UL, 0x95e811c0UL, 0xa8af4dcbUL,
+ 0x9fc58fcaUL, 0xc67bc9c8UL, 0xf1110bc9UL, 0x740744ccUL, 0x436d86cdUL,
+ 0x1ad3c0cfUL, 0x2db902ceUL, 0x4096af91UL, 0x77fc6d90UL, 0x2e422b92UL,
+ 0x1928e993UL, 0x9c3ea696UL, 0xab546497UL, 0xf2ea2295UL, 0xc580e094UL,
+ 0xf8c7bc9fUL, 0xcfad7e9eUL, 0x9613389cUL, 0xa179fa9dUL, 0x246fb598UL,
+ 0x13057799UL, 0x4abb319bUL, 0x7dd1f39aUL, 0x3035898dUL, 0x075f4b8cUL,
+ 0x5ee10d8eUL, 0x698bcf8fUL, 0xec9d808aUL, 0xdbf7428bUL, 0x82490489UL,
+ 0xb523c688UL, 0x88649a83UL, 0xbf0e5882UL, 0xe6b01e80UL, 0xd1dadc81UL,
+ 0x54cc9384UL, 0x63a65185UL, 0x3a181787UL, 0x0d72d586UL, 0xa0d0e2a9UL,
+ 0x97ba20a8UL, 0xce0466aaUL, 0xf96ea4abUL, 0x7c78ebaeUL, 0x4b1229afUL,
+ 0x12ac6fadUL, 0x25c6adacUL, 0x1881f1a7UL, 0x2feb33a6UL, 0x765575a4UL,
+ 0x413fb7a5UL, 0xc429f8a0UL, 0xf3433aa1UL, 0xaafd7ca3UL, 0x9d97bea2UL,
+ 0xd073c4b5UL, 0xe71906b4UL, 0xbea740b6UL, 0x89cd82b7UL, 0x0cdbcdb2UL,
+ 0x3bb10fb3UL, 0x620f49b1UL, 0x55658bb0UL, 0x6822d7bbUL, 0x5f4815baUL,
+ 0x06f653b8UL, 0x319c91b9UL, 0xb48adebcUL, 0x83e01cbdUL, 0xda5e5abfUL,
+ 0xed3498beUL
+ },
+ {
+ 0x00000000UL, 0x6567bcb8UL, 0x8bc809aaUL, 0xeeafb512UL, 0x5797628fUL,
+ 0x32f0de37UL, 0xdc5f6b25UL, 0xb938d79dUL, 0xef28b4c5UL, 0x8a4f087dUL,
+ 0x64e0bd6fUL, 0x018701d7UL, 0xb8bfd64aUL, 0xddd86af2UL, 0x3377dfe0UL,
+ 0x56106358UL, 0x9f571950UL, 0xfa30a5e8UL, 0x149f10faUL, 0x71f8ac42UL,
+ 0xc8c07bdfUL, 0xada7c767UL, 0x43087275UL, 0x266fcecdUL, 0x707fad95UL,
+ 0x1518112dUL, 0xfbb7a43fUL, 0x9ed01887UL, 0x27e8cf1aUL, 0x428f73a2UL,
+ 0xac20c6b0UL, 0xc9477a08UL, 0x3eaf32a0UL, 0x5bc88e18UL, 0xb5673b0aUL,
+ 0xd00087b2UL, 0x6938502fUL, 0x0c5fec97UL, 0xe2f05985UL, 0x8797e53dUL,
+ 0xd1878665UL, 0xb4e03addUL, 0x5a4f8fcfUL, 0x3f283377UL, 0x8610e4eaUL,
+ 0xe3775852UL, 0x0dd8ed40UL, 0x68bf51f8UL, 0xa1f82bf0UL, 0xc49f9748UL,
+ 0x2a30225aUL, 0x4f579ee2UL, 0xf66f497fUL, 0x9308f5c7UL, 0x7da740d5UL,
+ 0x18c0fc6dUL, 0x4ed09f35UL, 0x2bb7238dUL, 0xc518969fUL, 0xa07f2a27UL,
+ 0x1947fdbaUL, 0x7c204102UL, 0x928ff410UL, 0xf7e848a8UL, 0x3d58149bUL,
+ 0x583fa823UL, 0xb6901d31UL, 0xd3f7a189UL, 0x6acf7614UL, 0x0fa8caacUL,
+ 0xe1077fbeUL, 0x8460c306UL, 0xd270a05eUL, 0xb7171ce6UL, 0x59b8a9f4UL,
+ 0x3cdf154cUL, 0x85e7c2d1UL, 0xe0807e69UL, 0x0e2fcb7bUL, 0x6b4877c3UL,
+ 0xa20f0dcbUL, 0xc768b173UL, 0x29c70461UL, 0x4ca0b8d9UL, 0xf5986f44UL,
+ 0x90ffd3fcUL, 0x7e5066eeUL, 0x1b37da56UL, 0x4d27b90eUL, 0x284005b6UL,
+ 0xc6efb0a4UL, 0xa3880c1cUL, 0x1ab0db81UL, 0x7fd76739UL, 0x9178d22bUL,
+ 0xf41f6e93UL, 0x03f7263bUL, 0x66909a83UL, 0x883f2f91UL, 0xed589329UL,
+ 0x546044b4UL, 0x3107f80cUL, 0xdfa84d1eUL, 0xbacff1a6UL, 0xecdf92feUL,
+ 0x89b82e46UL, 0x67179b54UL, 0x027027ecUL, 0xbb48f071UL, 0xde2f4cc9UL,
+ 0x3080f9dbUL, 0x55e74563UL, 0x9ca03f6bUL, 0xf9c783d3UL, 0x176836c1UL,
+ 0x720f8a79UL, 0xcb375de4UL, 0xae50e15cUL, 0x40ff544eUL, 0x2598e8f6UL,
+ 0x73888baeUL, 0x16ef3716UL, 0xf8408204UL, 0x9d273ebcUL, 0x241fe921UL,
+ 0x41785599UL, 0xafd7e08bUL, 0xcab05c33UL, 0x3bb659edUL, 0x5ed1e555UL,
+ 0xb07e5047UL, 0xd519ecffUL, 0x6c213b62UL, 0x094687daUL, 0xe7e932c8UL,
+ 0x828e8e70UL, 0xd49eed28UL, 0xb1f95190UL, 0x5f56e482UL, 0x3a31583aUL,
+ 0x83098fa7UL, 0xe66e331fUL, 0x08c1860dUL, 0x6da63ab5UL, 0xa4e140bdUL,
+ 0xc186fc05UL, 0x2f294917UL, 0x4a4ef5afUL, 0xf3762232UL, 0x96119e8aUL,
+ 0x78be2b98UL, 0x1dd99720UL, 0x4bc9f478UL, 0x2eae48c0UL, 0xc001fdd2UL,
+ 0xa566416aUL, 0x1c5e96f7UL, 0x79392a4fUL, 0x97969f5dUL, 0xf2f123e5UL,
+ 0x05196b4dUL, 0x607ed7f5UL, 0x8ed162e7UL, 0xebb6de5fUL, 0x528e09c2UL,
+ 0x37e9b57aUL, 0xd9460068UL, 0xbc21bcd0UL, 0xea31df88UL, 0x8f566330UL,
+ 0x61f9d622UL, 0x049e6a9aUL, 0xbda6bd07UL, 0xd8c101bfUL, 0x366eb4adUL,
+ 0x53090815UL, 0x9a4e721dUL, 0xff29cea5UL, 0x11867bb7UL, 0x74e1c70fUL,
+ 0xcdd91092UL, 0xa8beac2aUL, 0x46111938UL, 0x2376a580UL, 0x7566c6d8UL,
+ 0x10017a60UL, 0xfeaecf72UL, 0x9bc973caUL, 0x22f1a457UL, 0x479618efUL,
+ 0xa939adfdUL, 0xcc5e1145UL, 0x06ee4d76UL, 0x6389f1ceUL, 0x8d2644dcUL,
+ 0xe841f864UL, 0x51792ff9UL, 0x341e9341UL, 0xdab12653UL, 0xbfd69aebUL,
+ 0xe9c6f9b3UL, 0x8ca1450bUL, 0x620ef019UL, 0x07694ca1UL, 0xbe519b3cUL,
+ 0xdb362784UL, 0x35999296UL, 0x50fe2e2eUL, 0x99b95426UL, 0xfcdee89eUL,
+ 0x12715d8cUL, 0x7716e134UL, 0xce2e36a9UL, 0xab498a11UL, 0x45e63f03UL,
+ 0x208183bbUL, 0x7691e0e3UL, 0x13f65c5bUL, 0xfd59e949UL, 0x983e55f1UL,
+ 0x2106826cUL, 0x44613ed4UL, 0xaace8bc6UL, 0xcfa9377eUL, 0x38417fd6UL,
+ 0x5d26c36eUL, 0xb389767cUL, 0xd6eecac4UL, 0x6fd61d59UL, 0x0ab1a1e1UL,
+ 0xe41e14f3UL, 0x8179a84bUL, 0xd769cb13UL, 0xb20e77abUL, 0x5ca1c2b9UL,
+ 0x39c67e01UL, 0x80fea99cUL, 0xe5991524UL, 0x0b36a036UL, 0x6e511c8eUL,
+ 0xa7166686UL, 0xc271da3eUL, 0x2cde6f2cUL, 0x49b9d394UL, 0xf0810409UL,
+ 0x95e6b8b1UL, 0x7b490da3UL, 0x1e2eb11bUL, 0x483ed243UL, 0x2d596efbUL,
+ 0xc3f6dbe9UL, 0xa6916751UL, 0x1fa9b0ccUL, 0x7ace0c74UL, 0x9461b966UL,
+ 0xf10605deUL
+#endif
+ }
+};
diff --git a/ml/dlib/dlib/external/zlib/deflate.c b/ml/dlib/dlib/external/zlib/deflate.c
new file mode 100644
index 000000000..696957705
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/deflate.c
@@ -0,0 +1,1967 @@
+/* deflate.c -- compress data using the deflation algorithm
+ * Copyright (C) 1995-2013 Jean-loup Gailly and Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/*
+ * ALGORITHM
+ *
+ * The "deflation" process depends on being able to identify portions
+ * of the input text which are identical to earlier input (within a
+ * sliding window trailing behind the input currently being processed).
+ *
+ * The most straightforward technique turns out to be the fastest for
+ * most input files: try all possible matches and select the longest.
+ * The key feature of this algorithm is that insertions into the string
+ * dictionary are very simple and thus fast, and deletions are avoided
+ * completely. Insertions are performed at each input character, whereas
+ * string matches are performed only when the previous match ends. So it
+ * is preferable to spend more time in matches to allow very fast string
+ * insertions and avoid deletions. The matching algorithm for small
+ * strings is inspired from that of Rabin & Karp. A brute force approach
+ * is used to find longer strings when a small match has been found.
+ * A similar algorithm is used in comic (by Jan-Mark Wams) and freeze
+ * (by Leonid Broukhis).
+ * A previous version of this file used a more sophisticated algorithm
+ * (by Fiala and Greene) which is guaranteed to run in linear amortized
+ * time, but has a larger average cost, uses more memory and is patented.
+ * However the F&G algorithm may be faster for some highly redundant
+ * files if the parameter max_chain_length (described below) is too large.
+ *
+ * ACKNOWLEDGEMENTS
+ *
+ * The idea of lazy evaluation of matches is due to Jan-Mark Wams, and
+ * I found it in 'freeze' written by Leonid Broukhis.
+ * Thanks to many people for bug reports and testing.
+ *
+ * REFERENCES
+ *
+ * Deutsch, L.P.,"DEFLATE Compressed Data Format Specification".
+ * Available in http://tools.ietf.org/html/rfc1951
+ *
+ * A description of the Rabin and Karp algorithm is given in the book
+ * "Algorithms" by R. Sedgewick, Addison-Wesley, p252.
+ *
+ * Fiala,E.R., and Greene,D.H.
+ * Data Compression with Finite Windows, Comm.ACM, 32,4 (1989) 490-595
+ *
+ */
+
+/* @(#) $Id$ */
+
+#include "deflate.h"
+
+const char deflate_copyright[] =
+ " deflate 1.2.8 Copyright 1995-2013 Jean-loup Gailly and Mark Adler ";
+/*
+ If you use the zlib library in a product, an acknowledgment is welcome
+ in the documentation of your product. If for some reason you cannot
+ include such an acknowledgment, I would appreciate that you keep this
+ copyright string in the executable of your product.
+ */
+
+/* ===========================================================================
+ * Function prototypes.
+ */
+typedef enum {
+ need_more, /* block not completed, need more input or more output */
+ block_done, /* block flush performed */
+ finish_started, /* finish started, need only more output at next deflate */
+ finish_done /* finish done, accept no more input or output */
+} block_state;
+
+typedef block_state (*compress_func) OF((deflate_state *s, int flush));
+/* Compression function. Returns the block state after the call. */
+
+local void fill_window OF((deflate_state *s));
+local block_state deflate_stored OF((deflate_state *s, int flush));
+local block_state deflate_fast OF((deflate_state *s, int flush));
+#ifndef FASTEST
+local block_state deflate_slow OF((deflate_state *s, int flush));
+#endif
+local block_state deflate_rle OF((deflate_state *s, int flush));
+local block_state deflate_huff OF((deflate_state *s, int flush));
+local void lm_init OF((deflate_state *s));
+local void putShortMSB OF((deflate_state *s, uInt b));
+local void flush_pending OF((z_streamp strm));
+local int read_buf OF((z_streamp strm, Bytef *buf, unsigned size));
+#ifdef ASMV
+ void match_init OF((void)); /* asm code initialization */
+ uInt longest_match OF((deflate_state *s, IPos cur_match));
+#else
+local uInt longest_match OF((deflate_state *s, IPos cur_match));
+#endif
+
+#ifdef DEBUG
+local void check_match OF((deflate_state *s, IPos start, IPos match,
+ int length));
+#endif
+
+/* ===========================================================================
+ * Local data
+ */
+
+#define NIL 0
+/* Tail of hash chains */
+
+#ifndef TOO_FAR
+# define TOO_FAR 4096
+#endif
+/* Matches of length 3 are discarded if their distance exceeds TOO_FAR */
+
+/* Values for max_lazy_match, good_match and max_chain_length, depending on
+ * the desired pack level (0..9). The values given below have been tuned to
+ * exclude worst case performance for pathological files. Better values may be
+ * found for specific files.
+ */
+typedef struct config_s {
+ ush good_length; /* reduce lazy search above this match length */
+ ush max_lazy; /* do not perform lazy search above this match length */
+ ush nice_length; /* quit search above this match length */
+ ush max_chain;
+ compress_func func;
+} config;
+
+#ifdef FASTEST
+local const config configuration_table[2] = {
+/* good lazy nice chain */
+/* 0 */ {0, 0, 0, 0, deflate_stored}, /* store only */
+/* 1 */ {4, 4, 8, 4, deflate_fast}}; /* max speed, no lazy matches */
+#else
+local const config configuration_table[10] = {
+/* good lazy nice chain */
+/* 0 */ {0, 0, 0, 0, deflate_stored}, /* store only */
+/* 1 */ {4, 4, 8, 4, deflate_fast}, /* max speed, no lazy matches */
+/* 2 */ {4, 5, 16, 8, deflate_fast},
+/* 3 */ {4, 6, 32, 32, deflate_fast},
+
+/* 4 */ {4, 4, 16, 16, deflate_slow}, /* lazy matches */
+/* 5 */ {8, 16, 32, 32, deflate_slow},
+/* 6 */ {8, 16, 128, 128, deflate_slow},
+/* 7 */ {8, 32, 128, 256, deflate_slow},
+/* 8 */ {32, 128, 258, 1024, deflate_slow},
+/* 9 */ {32, 258, 258, 4096, deflate_slow}}; /* max compression */
+#endif
+
+/* Note: the deflate() code requires max_lazy >= MIN_MATCH and max_chain >= 4
+ * For deflate_fast() (levels <= 3) good is ignored and lazy has a different
+ * meaning.
+ */
+
+#define EQUAL 0
+/* result of memcmp for equal strings */
+
+#ifndef NO_DUMMY_DECL
+struct static_tree_desc_s {int dummy;}; /* for buggy compilers */
+#endif
+
+/* rank Z_BLOCK between Z_NO_FLUSH and Z_PARTIAL_FLUSH */
+#define RANK(f) (((f) << 1) - ((f) > 4 ? 9 : 0))
+
+/* ===========================================================================
+ * Update a hash value with the given input byte
+ * IN assertion: all calls to to UPDATE_HASH are made with consecutive
+ * input characters, so that a running hash key can be computed from the
+ * previous key instead of complete recalculation each time.
+ */
+#define UPDATE_HASH(s,h,c) (h = (((h)<<s->hash_shift) ^ (c)) & s->hash_mask)
+
+
+/* ===========================================================================
+ * Insert string str in the dictionary and set match_head to the previous head
+ * of the hash chain (the most recent string with same hash key). Return
+ * the previous length of the hash chain.
+ * If this file is compiled with -DFASTEST, the compression level is forced
+ * to 1, and no hash chains are maintained.
+ * IN assertion: all calls to to INSERT_STRING are made with consecutive
+ * input characters and the first MIN_MATCH bytes of str are valid
+ * (except for the last MIN_MATCH-1 bytes of the input file).
+ */
+#ifdef FASTEST
+#define INSERT_STRING(s, str, match_head) \
+ (UPDATE_HASH(s, s->ins_h, s->window[(str) + (MIN_MATCH-1)]), \
+ match_head = s->head[s->ins_h], \
+ s->head[s->ins_h] = (Pos)(str))
+#else
+#define INSERT_STRING(s, str, match_head) \
+ (UPDATE_HASH(s, s->ins_h, s->window[(str) + (MIN_MATCH-1)]), \
+ match_head = s->prev[(str) & s->w_mask] = s->head[s->ins_h], \
+ s->head[s->ins_h] = (Pos)(str))
+#endif
+
+/* ===========================================================================
+ * Initialize the hash table (avoiding 64K overflow for 16 bit systems).
+ * prev[] will be initialized on the fly.
+ */
+#define CLEAR_HASH(s) \
+ s->head[s->hash_size-1] = NIL; \
+ zmemzero((Bytef *)s->head, (unsigned)(s->hash_size-1)*sizeof(*s->head));
+
+/* ========================================================================= */
+int ZEXPORT deflateInit_(strm, level, version, stream_size)
+ z_streamp strm;
+ int level;
+ const char *version;
+ int stream_size;
+{
+ return deflateInit2_(strm, level, Z_DEFLATED, MAX_WBITS, DEF_MEM_LEVEL,
+ Z_DEFAULT_STRATEGY, version, stream_size);
+ /* To do: ignore strm->next_in if we use it as window */
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateInit2_(strm, level, method, windowBits, memLevel, strategy,
+ version, stream_size)
+ z_streamp strm;
+ int level;
+ int method;
+ int windowBits;
+ int memLevel;
+ int strategy;
+ const char *version;
+ int stream_size;
+{
+ deflate_state *s;
+ int wrap = 1;
+ static const char my_version[] = ZLIB_VERSION;
+
+ ushf *overlay;
+ /* We overlay pending_buf and d_buf+l_buf. This works since the average
+ * output size for (length,distance) codes is <= 24 bits.
+ */
+
+ if (version == Z_NULL || version[0] != my_version[0] ||
+ stream_size != sizeof(z_stream)) {
+ return Z_VERSION_ERROR;
+ }
+ if (strm == Z_NULL) return Z_STREAM_ERROR;
+
+ strm->msg = Z_NULL;
+ if (strm->zalloc == (alloc_func)0) {
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zalloc = zcalloc;
+ strm->opaque = (voidpf)0;
+#endif
+ }
+ if (strm->zfree == (free_func)0)
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zfree = zcfree;
+#endif
+
+#ifdef FASTEST
+ if (level != 0) level = 1;
+#else
+ if (level == Z_DEFAULT_COMPRESSION) level = 6;
+#endif
+
+ if (windowBits < 0) { /* suppress zlib wrapper */
+ wrap = 0;
+ windowBits = -windowBits;
+ }
+#ifdef GZIP
+ else if (windowBits > 15) {
+ wrap = 2; /* write gzip wrapper instead */
+ windowBits -= 16;
+ }
+#endif
+ if (memLevel < 1 || memLevel > MAX_MEM_LEVEL || method != Z_DEFLATED ||
+ windowBits < 8 || windowBits > 15 || level < 0 || level > 9 ||
+ strategy < 0 || strategy > Z_FIXED) {
+ return Z_STREAM_ERROR;
+ }
+ if (windowBits == 8) windowBits = 9; /* until 256-byte window bug fixed */
+ s = (deflate_state *) ZALLOC(strm, 1, sizeof(deflate_state));
+ if (s == Z_NULL) return Z_MEM_ERROR;
+ strm->state = (struct internal_state FAR *)s;
+ s->strm = strm;
+
+ s->wrap = wrap;
+ s->gzhead = Z_NULL;
+ s->w_bits = windowBits;
+ s->w_size = 1 << s->w_bits;
+ s->w_mask = s->w_size - 1;
+
+ s->hash_bits = memLevel + 7;
+ s->hash_size = 1 << s->hash_bits;
+ s->hash_mask = s->hash_size - 1;
+ s->hash_shift = ((s->hash_bits+MIN_MATCH-1)/MIN_MATCH);
+
+ s->window = (Bytef *) ZALLOC(strm, s->w_size, 2*sizeof(Byte));
+ s->prev = (Posf *) ZALLOC(strm, s->w_size, sizeof(Pos));
+ s->head = (Posf *) ZALLOC(strm, s->hash_size, sizeof(Pos));
+
+ s->high_water = 0; /* nothing written to s->window yet */
+
+ s->lit_bufsize = 1 << (memLevel + 6); /* 16K elements by default */
+
+ overlay = (ushf *) ZALLOC(strm, s->lit_bufsize, sizeof(ush)+2);
+ s->pending_buf = (uchf *) overlay;
+ s->pending_buf_size = (ulg)s->lit_bufsize * (sizeof(ush)+2L);
+
+ if (s->window == Z_NULL || s->prev == Z_NULL || s->head == Z_NULL ||
+ s->pending_buf == Z_NULL) {
+ s->status = FINISH_STATE;
+ strm->msg = ERR_MSG(Z_MEM_ERROR);
+ deflateEnd (strm);
+ return Z_MEM_ERROR;
+ }
+ s->d_buf = overlay + s->lit_bufsize/sizeof(ush);
+ s->l_buf = s->pending_buf + (1+sizeof(ush))*s->lit_bufsize;
+
+ s->level = level;
+ s->strategy = strategy;
+ s->method = (Byte)method;
+
+ return deflateReset(strm);
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateSetDictionary (strm, dictionary, dictLength)
+ z_streamp strm;
+ const Bytef *dictionary;
+ uInt dictLength;
+{
+ deflate_state *s;
+ uInt str, n;
+ int wrap;
+ unsigned avail;
+ z_const unsigned char *next;
+
+ if (strm == Z_NULL || strm->state == Z_NULL || dictionary == Z_NULL)
+ return Z_STREAM_ERROR;
+ s = strm->state;
+ wrap = s->wrap;
+ if (wrap == 2 || (wrap == 1 && s->status != INIT_STATE) || s->lookahead)
+ return Z_STREAM_ERROR;
+
+ /* when using zlib wrappers, compute Adler-32 for provided dictionary */
+ if (wrap == 1)
+ strm->adler = adler32(strm->adler, dictionary, dictLength);
+ s->wrap = 0; /* avoid computing Adler-32 in read_buf */
+
+ /* if dictionary would fill window, just replace the history */
+ if (dictLength >= s->w_size) {
+ if (wrap == 0) { /* already empty otherwise */
+ CLEAR_HASH(s);
+ s->strstart = 0;
+ s->block_start = 0L;
+ s->insert = 0;
+ }
+ dictionary += dictLength - s->w_size; /* use the tail */
+ dictLength = s->w_size;
+ }
+
+ /* insert dictionary into window and hash */
+ avail = strm->avail_in;
+ next = strm->next_in;
+ strm->avail_in = dictLength;
+ strm->next_in = (z_const Bytef *)dictionary;
+ fill_window(s);
+ while (s->lookahead >= MIN_MATCH) {
+ str = s->strstart;
+ n = s->lookahead - (MIN_MATCH-1);
+ do {
+ UPDATE_HASH(s, s->ins_h, s->window[str + MIN_MATCH-1]);
+#ifndef FASTEST
+ s->prev[str & s->w_mask] = s->head[s->ins_h];
+#endif
+ s->head[s->ins_h] = (Pos)str;
+ str++;
+ } while (--n);
+ s->strstart = str;
+ s->lookahead = MIN_MATCH-1;
+ fill_window(s);
+ }
+ s->strstart += s->lookahead;
+ s->block_start = (long)s->strstart;
+ s->insert = s->lookahead;
+ s->lookahead = 0;
+ s->match_length = s->prev_length = MIN_MATCH-1;
+ s->match_available = 0;
+ strm->next_in = next;
+ strm->avail_in = avail;
+ s->wrap = wrap;
+ return Z_OK;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateResetKeep (strm)
+ z_streamp strm;
+{
+ deflate_state *s;
+
+ if (strm == Z_NULL || strm->state == Z_NULL ||
+ strm->zalloc == (alloc_func)0 || strm->zfree == (free_func)0) {
+ return Z_STREAM_ERROR;
+ }
+
+ strm->total_in = strm->total_out = 0;
+ strm->msg = Z_NULL; /* use zfree if we ever allocate msg dynamically */
+ strm->data_type = Z_UNKNOWN;
+
+ s = (deflate_state *)strm->state;
+ s->pending = 0;
+ s->pending_out = s->pending_buf;
+
+ if (s->wrap < 0) {
+ s->wrap = -s->wrap; /* was made negative by deflate(..., Z_FINISH); */
+ }
+ s->status = s->wrap ? INIT_STATE : BUSY_STATE;
+ strm->adler =
+#ifdef GZIP
+ s->wrap == 2 ? crc32(0L, Z_NULL, 0) :
+#endif
+ adler32(0L, Z_NULL, 0);
+ s->last_flush = Z_NO_FLUSH;
+
+ _tr_init(s);
+
+ return Z_OK;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateReset (strm)
+ z_streamp strm;
+{
+ int ret;
+
+ ret = deflateResetKeep(strm);
+ if (ret == Z_OK)
+ lm_init(strm->state);
+ return ret;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateSetHeader (strm, head)
+ z_streamp strm;
+ gz_headerp head;
+{
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ if (strm->state->wrap != 2) return Z_STREAM_ERROR;
+ strm->state->gzhead = head;
+ return Z_OK;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflatePending (strm, pending, bits)
+ unsigned *pending;
+ int *bits;
+ z_streamp strm;
+{
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ if (pending != Z_NULL)
+ *pending = strm->state->pending;
+ if (bits != Z_NULL)
+ *bits = strm->state->bi_valid;
+ return Z_OK;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflatePrime (strm, bits, value)
+ z_streamp strm;
+ int bits;
+ int value;
+{
+ deflate_state *s;
+ int put;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ s = strm->state;
+ if ((Bytef *)(s->d_buf) < s->pending_out + ((Buf_size + 7) >> 3))
+ return Z_BUF_ERROR;
+ do {
+ put = Buf_size - s->bi_valid;
+ if (put > bits)
+ put = bits;
+ s->bi_buf |= (ush)((value & ((1 << put) - 1)) << s->bi_valid);
+ s->bi_valid += put;
+ _tr_flush_bits(s);
+ value >>= put;
+ bits -= put;
+ } while (bits);
+ return Z_OK;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateParams(strm, level, strategy)
+ z_streamp strm;
+ int level;
+ int strategy;
+{
+ deflate_state *s;
+ compress_func func;
+ int err = Z_OK;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ s = strm->state;
+
+#ifdef FASTEST
+ if (level != 0) level = 1;
+#else
+ if (level == Z_DEFAULT_COMPRESSION) level = 6;
+#endif
+ if (level < 0 || level > 9 || strategy < 0 || strategy > Z_FIXED) {
+ return Z_STREAM_ERROR;
+ }
+ func = configuration_table[s->level].func;
+
+ if ((strategy != s->strategy || func != configuration_table[level].func) &&
+ strm->total_in != 0) {
+ /* Flush the last buffer: */
+ err = deflate(strm, Z_BLOCK);
+ if (err == Z_BUF_ERROR && s->pending == 0)
+ err = Z_OK;
+ }
+ if (s->level != level) {
+ s->level = level;
+ s->max_lazy_match = configuration_table[level].max_lazy;
+ s->good_match = configuration_table[level].good_length;
+ s->nice_match = configuration_table[level].nice_length;
+ s->max_chain_length = configuration_table[level].max_chain;
+ }
+ s->strategy = strategy;
+ return err;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateTune(strm, good_length, max_lazy, nice_length, max_chain)
+ z_streamp strm;
+ int good_length;
+ int max_lazy;
+ int nice_length;
+ int max_chain;
+{
+ deflate_state *s;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ s = strm->state;
+ s->good_match = good_length;
+ s->max_lazy_match = max_lazy;
+ s->nice_match = nice_length;
+ s->max_chain_length = max_chain;
+ return Z_OK;
+}
+
+/* =========================================================================
+ * For the default windowBits of 15 and memLevel of 8, this function returns
+ * a close to exact, as well as small, upper bound on the compressed size.
+ * They are coded as constants here for a reason--if the #define's are
+ * changed, then this function needs to be changed as well. The return
+ * value for 15 and 8 only works for those exact settings.
+ *
+ * For any setting other than those defaults for windowBits and memLevel,
+ * the value returned is a conservative worst case for the maximum expansion
+ * resulting from using fixed blocks instead of stored blocks, which deflate
+ * can emit on compressed data for some combinations of the parameters.
+ *
+ * This function could be more sophisticated to provide closer upper bounds for
+ * every combination of windowBits and memLevel. But even the conservative
+ * upper bound of about 14% expansion does not seem onerous for output buffer
+ * allocation.
+ */
+uLong ZEXPORT deflateBound(strm, sourceLen)
+ z_streamp strm;
+ uLong sourceLen;
+{
+ deflate_state *s;
+ uLong complen, wraplen;
+ Bytef *str;
+
+ /* conservative upper bound for compressed data */
+ complen = sourceLen +
+ ((sourceLen + 7) >> 3) + ((sourceLen + 63) >> 6) + 5;
+
+ /* if can't get parameters, return conservative bound plus zlib wrapper */
+ if (strm == Z_NULL || strm->state == Z_NULL)
+ return complen + 6;
+
+ /* compute wrapper length */
+ s = strm->state;
+ switch (s->wrap) {
+ case 0: /* raw deflate */
+ wraplen = 0;
+ break;
+ case 1: /* zlib wrapper */
+ wraplen = 6 + (s->strstart ? 4 : 0);
+ break;
+ case 2: /* gzip wrapper */
+ wraplen = 18;
+ if (s->gzhead != Z_NULL) { /* user-supplied gzip header */
+ if (s->gzhead->extra != Z_NULL)
+ wraplen += 2 + s->gzhead->extra_len;
+ str = s->gzhead->name;
+ if (str != Z_NULL)
+ do {
+ wraplen++;
+ } while (*str++);
+ str = s->gzhead->comment;
+ if (str != Z_NULL)
+ do {
+ wraplen++;
+ } while (*str++);
+ if (s->gzhead->hcrc)
+ wraplen += 2;
+ }
+ break;
+ default: /* for compiler happiness */
+ wraplen = 6;
+ }
+
+ /* if not default parameters, return conservative bound */
+ if (s->w_bits != 15 || s->hash_bits != 8 + 7)
+ return complen + wraplen;
+
+ /* default settings: return tight bound for that case */
+ return sourceLen + (sourceLen >> 12) + (sourceLen >> 14) +
+ (sourceLen >> 25) + 13 - 6 + wraplen;
+}
+
+/* =========================================================================
+ * Put a short in the pending buffer. The 16-bit value is put in MSB order.
+ * IN assertion: the stream state is correct and there is enough room in
+ * pending_buf.
+ */
+local void putShortMSB (s, b)
+ deflate_state *s;
+ uInt b;
+{
+ put_byte(s, (Byte)(b >> 8));
+ put_byte(s, (Byte)(b & 0xff));
+}
+
+/* =========================================================================
+ * Flush as much pending output as possible. All deflate() output goes
+ * through this function so some applications may wish to modify it
+ * to avoid allocating a large strm->next_out buffer and copying into it.
+ * (See also read_buf()).
+ */
+local void flush_pending(strm)
+ z_streamp strm;
+{
+ unsigned len;
+ deflate_state *s = strm->state;
+
+ _tr_flush_bits(s);
+ len = s->pending;
+ if (len > strm->avail_out) len = strm->avail_out;
+ if (len == 0) return;
+
+ zmemcpy(strm->next_out, s->pending_out, len);
+ strm->next_out += len;
+ s->pending_out += len;
+ strm->total_out += len;
+ strm->avail_out -= len;
+ s->pending -= len;
+ if (s->pending == 0) {
+ s->pending_out = s->pending_buf;
+ }
+}
+
+/* ========================================================================= */
+int ZEXPORT deflate (strm, flush)
+ z_streamp strm;
+ int flush;
+{
+ int old_flush; /* value of flush param for previous deflate call */
+ deflate_state *s;
+
+ if (strm == Z_NULL || strm->state == Z_NULL ||
+ flush > Z_BLOCK || flush < 0) {
+ return Z_STREAM_ERROR;
+ }
+ s = strm->state;
+
+ if (strm->next_out == Z_NULL ||
+ (strm->next_in == Z_NULL && strm->avail_in != 0) ||
+ (s->status == FINISH_STATE && flush != Z_FINISH)) {
+ ERR_RETURN(strm, Z_STREAM_ERROR);
+ }
+ if (strm->avail_out == 0) ERR_RETURN(strm, Z_BUF_ERROR);
+
+ s->strm = strm; /* just in case */
+ old_flush = s->last_flush;
+ s->last_flush = flush;
+
+ /* Write the header */
+ if (s->status == INIT_STATE) {
+#ifdef GZIP
+ if (s->wrap == 2) {
+ strm->adler = crc32(0L, Z_NULL, 0);
+ put_byte(s, 31);
+ put_byte(s, 139);
+ put_byte(s, 8);
+ if (s->gzhead == Z_NULL) {
+ put_byte(s, 0);
+ put_byte(s, 0);
+ put_byte(s, 0);
+ put_byte(s, 0);
+ put_byte(s, 0);
+ put_byte(s, s->level == 9 ? 2 :
+ (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2 ?
+ 4 : 0));
+ put_byte(s, OS_CODE);
+ s->status = BUSY_STATE;
+ }
+ else {
+ put_byte(s, (s->gzhead->text ? 1 : 0) +
+ (s->gzhead->hcrc ? 2 : 0) +
+ (s->gzhead->extra == Z_NULL ? 0 : 4) +
+ (s->gzhead->name == Z_NULL ? 0 : 8) +
+ (s->gzhead->comment == Z_NULL ? 0 : 16)
+ );
+ put_byte(s, (Byte)(s->gzhead->time & 0xff));
+ put_byte(s, (Byte)((s->gzhead->time >> 8) & 0xff));
+ put_byte(s, (Byte)((s->gzhead->time >> 16) & 0xff));
+ put_byte(s, (Byte)((s->gzhead->time >> 24) & 0xff));
+ put_byte(s, s->level == 9 ? 2 :
+ (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2 ?
+ 4 : 0));
+ put_byte(s, s->gzhead->os & 0xff);
+ if (s->gzhead->extra != Z_NULL) {
+ put_byte(s, s->gzhead->extra_len & 0xff);
+ put_byte(s, (s->gzhead->extra_len >> 8) & 0xff);
+ }
+ if (s->gzhead->hcrc)
+ strm->adler = crc32(strm->adler, s->pending_buf,
+ s->pending);
+ s->gzindex = 0;
+ s->status = EXTRA_STATE;
+ }
+ }
+ else
+#endif
+ {
+ uInt header = (Z_DEFLATED + ((s->w_bits-8)<<4)) << 8;
+ uInt level_flags;
+
+ if (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2)
+ level_flags = 0;
+ else if (s->level < 6)
+ level_flags = 1;
+ else if (s->level == 6)
+ level_flags = 2;
+ else
+ level_flags = 3;
+ header |= (level_flags << 6);
+ if (s->strstart != 0) header |= PRESET_DICT;
+ header += 31 - (header % 31);
+
+ s->status = BUSY_STATE;
+ putShortMSB(s, header);
+
+ /* Save the adler32 of the preset dictionary: */
+ if (s->strstart != 0) {
+ putShortMSB(s, (uInt)(strm->adler >> 16));
+ putShortMSB(s, (uInt)(strm->adler & 0xffff));
+ }
+ strm->adler = adler32(0L, Z_NULL, 0);
+ }
+ }
+#ifdef GZIP
+ if (s->status == EXTRA_STATE) {
+ if (s->gzhead->extra != Z_NULL) {
+ uInt beg = s->pending; /* start of bytes to update crc */
+
+ while (s->gzindex < (s->gzhead->extra_len & 0xffff)) {
+ if (s->pending == s->pending_buf_size) {
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ flush_pending(strm);
+ beg = s->pending;
+ if (s->pending == s->pending_buf_size)
+ break;
+ }
+ put_byte(s, s->gzhead->extra[s->gzindex]);
+ s->gzindex++;
+ }
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ if (s->gzindex == s->gzhead->extra_len) {
+ s->gzindex = 0;
+ s->status = NAME_STATE;
+ }
+ }
+ else
+ s->status = NAME_STATE;
+ }
+ if (s->status == NAME_STATE) {
+ if (s->gzhead->name != Z_NULL) {
+ uInt beg = s->pending; /* start of bytes to update crc */
+ int val;
+
+ do {
+ if (s->pending == s->pending_buf_size) {
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ flush_pending(strm);
+ beg = s->pending;
+ if (s->pending == s->pending_buf_size) {
+ val = 1;
+ break;
+ }
+ }
+ val = s->gzhead->name[s->gzindex++];
+ put_byte(s, val);
+ } while (val != 0);
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ if (val == 0) {
+ s->gzindex = 0;
+ s->status = COMMENT_STATE;
+ }
+ }
+ else
+ s->status = COMMENT_STATE;
+ }
+ if (s->status == COMMENT_STATE) {
+ if (s->gzhead->comment != Z_NULL) {
+ uInt beg = s->pending; /* start of bytes to update crc */
+ int val;
+
+ do {
+ if (s->pending == s->pending_buf_size) {
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ flush_pending(strm);
+ beg = s->pending;
+ if (s->pending == s->pending_buf_size) {
+ val = 1;
+ break;
+ }
+ }
+ val = s->gzhead->comment[s->gzindex++];
+ put_byte(s, val);
+ } while (val != 0);
+ if (s->gzhead->hcrc && s->pending > beg)
+ strm->adler = crc32(strm->adler, s->pending_buf + beg,
+ s->pending - beg);
+ if (val == 0)
+ s->status = HCRC_STATE;
+ }
+ else
+ s->status = HCRC_STATE;
+ }
+ if (s->status == HCRC_STATE) {
+ if (s->gzhead->hcrc) {
+ if (s->pending + 2 > s->pending_buf_size)
+ flush_pending(strm);
+ if (s->pending + 2 <= s->pending_buf_size) {
+ put_byte(s, (Byte)(strm->adler & 0xff));
+ put_byte(s, (Byte)((strm->adler >> 8) & 0xff));
+ strm->adler = crc32(0L, Z_NULL, 0);
+ s->status = BUSY_STATE;
+ }
+ }
+ else
+ s->status = BUSY_STATE;
+ }
+#endif
+
+ /* Flush as much pending output as possible */
+ if (s->pending != 0) {
+ flush_pending(strm);
+ if (strm->avail_out == 0) {
+ /* Since avail_out is 0, deflate will be called again with
+ * more output space, but possibly with both pending and
+ * avail_in equal to zero. There won't be anything to do,
+ * but this is not an error situation so make sure we
+ * return OK instead of BUF_ERROR at next call of deflate:
+ */
+ s->last_flush = -1;
+ return Z_OK;
+ }
+
+ /* Make sure there is something to do and avoid duplicate consecutive
+ * flushes. For repeated and useless calls with Z_FINISH, we keep
+ * returning Z_STREAM_END instead of Z_BUF_ERROR.
+ */
+ } else if (strm->avail_in == 0 && RANK(flush) <= RANK(old_flush) &&
+ flush != Z_FINISH) {
+ ERR_RETURN(strm, Z_BUF_ERROR);
+ }
+
+ /* User must not provide more input after the first FINISH: */
+ if (s->status == FINISH_STATE && strm->avail_in != 0) {
+ ERR_RETURN(strm, Z_BUF_ERROR);
+ }
+
+ /* Start a new block or continue the current one.
+ */
+ if (strm->avail_in != 0 || s->lookahead != 0 ||
+ (flush != Z_NO_FLUSH && s->status != FINISH_STATE)) {
+ block_state bstate;
+
+ bstate = s->strategy == Z_HUFFMAN_ONLY ? deflate_huff(s, flush) :
+ (s->strategy == Z_RLE ? deflate_rle(s, flush) :
+ (*(configuration_table[s->level].func))(s, flush));
+
+ if (bstate == finish_started || bstate == finish_done) {
+ s->status = FINISH_STATE;
+ }
+ if (bstate == need_more || bstate == finish_started) {
+ if (strm->avail_out == 0) {
+ s->last_flush = -1; /* avoid BUF_ERROR next call, see above */
+ }
+ return Z_OK;
+ /* If flush != Z_NO_FLUSH && avail_out == 0, the next call
+ * of deflate should use the same flush parameter to make sure
+ * that the flush is complete. So we don't have to output an
+ * empty block here, this will be done at next call. This also
+ * ensures that for a very small output buffer, we emit at most
+ * one empty block.
+ */
+ }
+ if (bstate == block_done) {
+ if (flush == Z_PARTIAL_FLUSH) {
+ _tr_align(s);
+ } else if (flush != Z_BLOCK) { /* FULL_FLUSH or SYNC_FLUSH */
+ _tr_stored_block(s, (char*)0, 0L, 0);
+ /* For a full flush, this empty block will be recognized
+ * as a special marker by inflate_sync().
+ */
+ if (flush == Z_FULL_FLUSH) {
+ CLEAR_HASH(s); /* forget history */
+ if (s->lookahead == 0) {
+ s->strstart = 0;
+ s->block_start = 0L;
+ s->insert = 0;
+ }
+ }
+ }
+ flush_pending(strm);
+ if (strm->avail_out == 0) {
+ s->last_flush = -1; /* avoid BUF_ERROR at next call, see above */
+ return Z_OK;
+ }
+ }
+ }
+ Assert(strm->avail_out > 0, "bug2");
+
+ if (flush != Z_FINISH) return Z_OK;
+ if (s->wrap <= 0) return Z_STREAM_END;
+
+ /* Write the trailer */
+#ifdef GZIP
+ if (s->wrap == 2) {
+ put_byte(s, (Byte)(strm->adler & 0xff));
+ put_byte(s, (Byte)((strm->adler >> 8) & 0xff));
+ put_byte(s, (Byte)((strm->adler >> 16) & 0xff));
+ put_byte(s, (Byte)((strm->adler >> 24) & 0xff));
+ put_byte(s, (Byte)(strm->total_in & 0xff));
+ put_byte(s, (Byte)((strm->total_in >> 8) & 0xff));
+ put_byte(s, (Byte)((strm->total_in >> 16) & 0xff));
+ put_byte(s, (Byte)((strm->total_in >> 24) & 0xff));
+ }
+ else
+#endif
+ {
+ putShortMSB(s, (uInt)(strm->adler >> 16));
+ putShortMSB(s, (uInt)(strm->adler & 0xffff));
+ }
+ flush_pending(strm);
+ /* If avail_out is zero, the application will call deflate again
+ * to flush the rest.
+ */
+ if (s->wrap > 0) s->wrap = -s->wrap; /* write the trailer only once! */
+ return s->pending != 0 ? Z_OK : Z_STREAM_END;
+}
+
+/* ========================================================================= */
+int ZEXPORT deflateEnd (strm)
+ z_streamp strm;
+{
+ int status;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+
+ status = strm->state->status;
+ if (status != INIT_STATE &&
+ status != EXTRA_STATE &&
+ status != NAME_STATE &&
+ status != COMMENT_STATE &&
+ status != HCRC_STATE &&
+ status != BUSY_STATE &&
+ status != FINISH_STATE) {
+ return Z_STREAM_ERROR;
+ }
+
+ /* Deallocate in reverse order of allocations: */
+ TRY_FREE(strm, strm->state->pending_buf);
+ TRY_FREE(strm, strm->state->head);
+ TRY_FREE(strm, strm->state->prev);
+ TRY_FREE(strm, strm->state->window);
+
+ ZFREE(strm, strm->state);
+ strm->state = Z_NULL;
+
+ return status == BUSY_STATE ? Z_DATA_ERROR : Z_OK;
+}
+
+/* =========================================================================
+ * Copy the source state to the destination state.
+ * To simplify the source, this is not supported for 16-bit MSDOS (which
+ * doesn't have enough memory anyway to duplicate compression states).
+ */
+int ZEXPORT deflateCopy (dest, source)
+ z_streamp dest;
+ z_streamp source;
+{
+#ifdef MAXSEG_64K
+ return Z_STREAM_ERROR;
+#else
+ deflate_state *ds;
+ deflate_state *ss;
+ ushf *overlay;
+
+
+ if (source == Z_NULL || dest == Z_NULL || source->state == Z_NULL) {
+ return Z_STREAM_ERROR;
+ }
+
+ ss = source->state;
+
+ zmemcpy((voidpf)dest, (voidpf)source, sizeof(z_stream));
+
+ ds = (deflate_state *) ZALLOC(dest, 1, sizeof(deflate_state));
+ if (ds == Z_NULL) return Z_MEM_ERROR;
+ dest->state = (struct internal_state FAR *) ds;
+ zmemcpy((voidpf)ds, (voidpf)ss, sizeof(deflate_state));
+ ds->strm = dest;
+
+ ds->window = (Bytef *) ZALLOC(dest, ds->w_size, 2*sizeof(Byte));
+ ds->prev = (Posf *) ZALLOC(dest, ds->w_size, sizeof(Pos));
+ ds->head = (Posf *) ZALLOC(dest, ds->hash_size, sizeof(Pos));
+ overlay = (ushf *) ZALLOC(dest, ds->lit_bufsize, sizeof(ush)+2);
+ ds->pending_buf = (uchf *) overlay;
+
+ if (ds->window == Z_NULL || ds->prev == Z_NULL || ds->head == Z_NULL ||
+ ds->pending_buf == Z_NULL) {
+ deflateEnd (dest);
+ return Z_MEM_ERROR;
+ }
+ /* following zmemcpy do not work for 16-bit MSDOS */
+ zmemcpy(ds->window, ss->window, ds->w_size * 2 * sizeof(Byte));
+ zmemcpy((voidpf)ds->prev, (voidpf)ss->prev, ds->w_size * sizeof(Pos));
+ zmemcpy((voidpf)ds->head, (voidpf)ss->head, ds->hash_size * sizeof(Pos));
+ zmemcpy(ds->pending_buf, ss->pending_buf, (uInt)ds->pending_buf_size);
+
+ ds->pending_out = ds->pending_buf + (ss->pending_out - ss->pending_buf);
+ ds->d_buf = overlay + ds->lit_bufsize/sizeof(ush);
+ ds->l_buf = ds->pending_buf + (1+sizeof(ush))*ds->lit_bufsize;
+
+ ds->l_desc.dyn_tree = ds->dyn_ltree;
+ ds->d_desc.dyn_tree = ds->dyn_dtree;
+ ds->bl_desc.dyn_tree = ds->bl_tree;
+
+ return Z_OK;
+#endif /* MAXSEG_64K */
+}
+
+/* ===========================================================================
+ * Read a new buffer from the current input stream, update the adler32
+ * and total number of bytes read. All deflate() input goes through
+ * this function so some applications may wish to modify it to avoid
+ * allocating a large strm->next_in buffer and copying from it.
+ * (See also flush_pending()).
+ */
+local int read_buf(strm, buf, size)
+ z_streamp strm;
+ Bytef *buf;
+ unsigned size;
+{
+ unsigned len = strm->avail_in;
+
+ if (len > size) len = size;
+ if (len == 0) return 0;
+
+ strm->avail_in -= len;
+
+ zmemcpy(buf, strm->next_in, len);
+ if (strm->state->wrap == 1) {
+ strm->adler = adler32(strm->adler, buf, len);
+ }
+#ifdef GZIP
+ else if (strm->state->wrap == 2) {
+ strm->adler = crc32(strm->adler, buf, len);
+ }
+#endif
+ strm->next_in += len;
+ strm->total_in += len;
+
+ return (int)len;
+}
+
+/* ===========================================================================
+ * Initialize the "longest match" routines for a new zlib stream
+ */
+local void lm_init (s)
+ deflate_state *s;
+{
+ s->window_size = (ulg)2L*s->w_size;
+
+ CLEAR_HASH(s);
+
+ /* Set the default configuration parameters:
+ */
+ s->max_lazy_match = configuration_table[s->level].max_lazy;
+ s->good_match = configuration_table[s->level].good_length;
+ s->nice_match = configuration_table[s->level].nice_length;
+ s->max_chain_length = configuration_table[s->level].max_chain;
+
+ s->strstart = 0;
+ s->block_start = 0L;
+ s->lookahead = 0;
+ s->insert = 0;
+ s->match_length = s->prev_length = MIN_MATCH-1;
+ s->match_available = 0;
+ s->ins_h = 0;
+#ifndef FASTEST
+#ifdef ASMV
+ match_init(); /* initialize the asm code */
+#endif
+#endif
+}
+
+#ifndef FASTEST
+/* ===========================================================================
+ * Set match_start to the longest match starting at the given string and
+ * return its length. Matches shorter or equal to prev_length are discarded,
+ * in which case the result is equal to prev_length and match_start is
+ * garbage.
+ * IN assertions: cur_match is the head of the hash chain for the current
+ * string (strstart) and its distance is <= MAX_DIST, and prev_length >= 1
+ * OUT assertion: the match length is not greater than s->lookahead.
+ */
+#ifndef ASMV
+/* For 80x86 and 680x0, an optimized version will be provided in match.asm or
+ * match.S. The code will be functionally equivalent.
+ */
+local uInt longest_match(s, cur_match)
+ deflate_state *s;
+ IPos cur_match; /* current match */
+{
+ unsigned chain_length = s->max_chain_length;/* max hash chain length */
+ register Bytef *scan = s->window + s->strstart; /* current string */
+ register Bytef *match; /* matched string */
+ register int len; /* length of current match */
+ int best_len = s->prev_length; /* best match length so far */
+ int nice_match = s->nice_match; /* stop if match long enough */
+ IPos limit = s->strstart > (IPos)MAX_DIST(s) ?
+ s->strstart - (IPos)MAX_DIST(s) : NIL;
+ /* Stop when cur_match becomes <= limit. To simplify the code,
+ * we prevent matches with the string of window index 0.
+ */
+ Posf *prev = s->prev;
+ uInt wmask = s->w_mask;
+
+#ifdef UNALIGNED_OK
+ /* Compare two bytes at a time. Note: this is not always beneficial.
+ * Try with and without -DUNALIGNED_OK to check.
+ */
+ register Bytef *strend = s->window + s->strstart + MAX_MATCH - 1;
+ register ush scan_start = *(ushf*)scan;
+ register ush scan_end = *(ushf*)(scan+best_len-1);
+#else
+ register Bytef *strend = s->window + s->strstart + MAX_MATCH;
+ register Byte scan_end1 = scan[best_len-1];
+ register Byte scan_end = scan[best_len];
+#endif
+
+ /* The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16.
+ * It is easy to get rid of this optimization if necessary.
+ */
+ Assert(s->hash_bits >= 8 && MAX_MATCH == 258, "Code too clever");
+
+ /* Do not waste too much time if we already have a good match: */
+ if (s->prev_length >= s->good_match) {
+ chain_length >>= 2;
+ }
+ /* Do not look for matches beyond the end of the input. This is necessary
+ * to make deflate deterministic.
+ */
+ if ((uInt)nice_match > s->lookahead) nice_match = s->lookahead;
+
+ Assert((ulg)s->strstart <= s->window_size-MIN_LOOKAHEAD, "need lookahead");
+
+ do {
+ Assert(cur_match < s->strstart, "no future");
+ match = s->window + cur_match;
+
+ /* Skip to next match if the match length cannot increase
+ * or if the match length is less than 2. Note that the checks below
+ * for insufficient lookahead only occur occasionally for performance
+ * reasons. Therefore uninitialized memory will be accessed, and
+ * conditional jumps will be made that depend on those values.
+ * However the length of the match is limited to the lookahead, so
+ * the output of deflate is not affected by the uninitialized values.
+ */
+#if (defined(UNALIGNED_OK) && MAX_MATCH == 258)
+ /* This code assumes sizeof(unsigned short) == 2. Do not use
+ * UNALIGNED_OK if your compiler uses a different size.
+ */
+ if (*(ushf*)(match+best_len-1) != scan_end ||
+ *(ushf*)match != scan_start) continue;
+
+ /* It is not necessary to compare scan[2] and match[2] since they are
+ * always equal when the other bytes match, given that the hash keys
+ * are equal and that HASH_BITS >= 8. Compare 2 bytes at a time at
+ * strstart+3, +5, ... up to strstart+257. We check for insufficient
+ * lookahead only every 4th comparison; the 128th check will be made
+ * at strstart+257. If MAX_MATCH-2 is not a multiple of 8, it is
+ * necessary to put more guard bytes at the end of the window, or
+ * to check more often for insufficient lookahead.
+ */
+ Assert(scan[2] == match[2], "scan[2]?");
+ scan++, match++;
+ do {
+ } while (*(ushf*)(scan+=2) == *(ushf*)(match+=2) &&
+ *(ushf*)(scan+=2) == *(ushf*)(match+=2) &&
+ *(ushf*)(scan+=2) == *(ushf*)(match+=2) &&
+ *(ushf*)(scan+=2) == *(ushf*)(match+=2) &&
+ scan < strend);
+ /* The funny "do {}" generates better code on most compilers */
+
+ /* Here, scan <= window+strstart+257 */
+ Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan");
+ if (*scan == *match) scan++;
+
+ len = (MAX_MATCH - 1) - (int)(strend-scan);
+ scan = strend - (MAX_MATCH-1);
+
+#else /* UNALIGNED_OK */
+
+ if (match[best_len] != scan_end ||
+ match[best_len-1] != scan_end1 ||
+ *match != *scan ||
+ *++match != scan[1]) continue;
+
+ /* The check at best_len-1 can be removed because it will be made
+ * again later. (This heuristic is not always a win.)
+ * It is not necessary to compare scan[2] and match[2] since they
+ * are always equal when the other bytes match, given that
+ * the hash keys are equal and that HASH_BITS >= 8.
+ */
+ scan += 2, match++;
+ Assert(*scan == *match, "match[2]?");
+
+ /* We check for insufficient lookahead only every 8th comparison;
+ * the 256th check will be made at strstart+258.
+ */
+ do {
+ } while (*++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ scan < strend);
+
+ Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan");
+
+ len = MAX_MATCH - (int)(strend - scan);
+ scan = strend - MAX_MATCH;
+
+#endif /* UNALIGNED_OK */
+
+ if (len > best_len) {
+ s->match_start = cur_match;
+ best_len = len;
+ if (len >= nice_match) break;
+#ifdef UNALIGNED_OK
+ scan_end = *(ushf*)(scan+best_len-1);
+#else
+ scan_end1 = scan[best_len-1];
+ scan_end = scan[best_len];
+#endif
+ }
+ } while ((cur_match = prev[cur_match & wmask]) > limit
+ && --chain_length != 0);
+
+ if ((uInt)best_len <= s->lookahead) return (uInt)best_len;
+ return s->lookahead;
+}
+#endif /* ASMV */
+
+#else /* FASTEST */
+
+/* ---------------------------------------------------------------------------
+ * Optimized version for FASTEST only
+ */
+local uInt longest_match(s, cur_match)
+ deflate_state *s;
+ IPos cur_match; /* current match */
+{
+ register Bytef *scan = s->window + s->strstart; /* current string */
+ register Bytef *match; /* matched string */
+ register int len; /* length of current match */
+ register Bytef *strend = s->window + s->strstart + MAX_MATCH;
+
+ /* The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16.
+ * It is easy to get rid of this optimization if necessary.
+ */
+ Assert(s->hash_bits >= 8 && MAX_MATCH == 258, "Code too clever");
+
+ Assert((ulg)s->strstart <= s->window_size-MIN_LOOKAHEAD, "need lookahead");
+
+ Assert(cur_match < s->strstart, "no future");
+
+ match = s->window + cur_match;
+
+ /* Return failure if the match length is less than 2:
+ */
+ if (match[0] != scan[0] || match[1] != scan[1]) return MIN_MATCH-1;
+
+ /* The check at best_len-1 can be removed because it will be made
+ * again later. (This heuristic is not always a win.)
+ * It is not necessary to compare scan[2] and match[2] since they
+ * are always equal when the other bytes match, given that
+ * the hash keys are equal and that HASH_BITS >= 8.
+ */
+ scan += 2, match += 2;
+ Assert(*scan == *match, "match[2]?");
+
+ /* We check for insufficient lookahead only every 8th comparison;
+ * the 256th check will be made at strstart+258.
+ */
+ do {
+ } while (*++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ *++scan == *++match && *++scan == *++match &&
+ scan < strend);
+
+ Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan");
+
+ len = MAX_MATCH - (int)(strend - scan);
+
+ if (len < MIN_MATCH) return MIN_MATCH - 1;
+
+ s->match_start = cur_match;
+ return (uInt)len <= s->lookahead ? (uInt)len : s->lookahead;
+}
+
+#endif /* FASTEST */
+
+#ifdef DEBUG
+/* ===========================================================================
+ * Check that the match at match_start is indeed a match.
+ */
+local void check_match(s, start, match, length)
+ deflate_state *s;
+ IPos start, match;
+ int length;
+{
+ /* check that the match is indeed a match */
+ if (zmemcmp(s->window + match,
+ s->window + start, length) != EQUAL) {
+ fprintf(stderr, " start %u, match %u, length %d\n",
+ start, match, length);
+ do {
+ fprintf(stderr, "%c%c", s->window[match++], s->window[start++]);
+ } while (--length != 0);
+ z_error("invalid match");
+ }
+ if (z_verbose > 1) {
+ fprintf(stderr,"\\[%d,%d]", start-match, length);
+ do { putc(s->window[start++], stderr); } while (--length != 0);
+ }
+}
+#else
+# define check_match(s, start, match, length)
+#endif /* DEBUG */
+
+/* ===========================================================================
+ * Fill the window when the lookahead becomes insufficient.
+ * Updates strstart and lookahead.
+ *
+ * IN assertion: lookahead < MIN_LOOKAHEAD
+ * OUT assertions: strstart <= window_size-MIN_LOOKAHEAD
+ * At least one byte has been read, or avail_in == 0; reads are
+ * performed for at least two bytes (required for the zip translate_eol
+ * option -- not supported here).
+ */
+local void fill_window(s)
+ deflate_state *s;
+{
+ register unsigned n, m;
+ register Posf *p;
+ unsigned more; /* Amount of free space at the end of the window. */
+ uInt wsize = s->w_size;
+
+ Assert(s->lookahead < MIN_LOOKAHEAD, "already enough lookahead");
+
+ do {
+ more = (unsigned)(s->window_size -(ulg)s->lookahead -(ulg)s->strstart);
+
+ /* Deal with !@#$% 64K limit: */
+ if (sizeof(int) <= 2) {
+ if (more == 0 && s->strstart == 0 && s->lookahead == 0) {
+ more = wsize;
+
+ } else if (more == (unsigned)(-1)) {
+ /* Very unlikely, but possible on 16 bit machine if
+ * strstart == 0 && lookahead == 1 (input done a byte at time)
+ */
+ more--;
+ }
+ }
+
+ /* If the window is almost full and there is insufficient lookahead,
+ * move the upper half to the lower one to make room in the upper half.
+ */
+ if (s->strstart >= wsize+MAX_DIST(s)) {
+
+ zmemcpy(s->window, s->window+wsize, (unsigned)wsize);
+ s->match_start -= wsize;
+ s->strstart -= wsize; /* we now have strstart >= MAX_DIST */
+ s->block_start -= (long) wsize;
+
+ /* Slide the hash table (could be avoided with 32 bit values
+ at the expense of memory usage). We slide even when level == 0
+ to keep the hash table consistent if we switch back to level > 0
+ later. (Using level 0 permanently is not an optimal usage of
+ zlib, so we don't care about this pathological case.)
+ */
+ n = s->hash_size;
+ p = &s->head[n];
+ do {
+ m = *--p;
+ *p = (Pos)(m >= wsize ? m-wsize : NIL);
+ } while (--n);
+
+ n = wsize;
+#ifndef FASTEST
+ p = &s->prev[n];
+ do {
+ m = *--p;
+ *p = (Pos)(m >= wsize ? m-wsize : NIL);
+ /* If n is not on any hash chain, prev[n] is garbage but
+ * its value will never be used.
+ */
+ } while (--n);
+#endif
+ more += wsize;
+ }
+ if (s->strm->avail_in == 0) break;
+
+ /* If there was no sliding:
+ * strstart <= WSIZE+MAX_DIST-1 && lookahead <= MIN_LOOKAHEAD - 1 &&
+ * more == window_size - lookahead - strstart
+ * => more >= window_size - (MIN_LOOKAHEAD-1 + WSIZE + MAX_DIST-1)
+ * => more >= window_size - 2*WSIZE + 2
+ * In the BIG_MEM or MMAP case (not yet supported),
+ * window_size == input_size + MIN_LOOKAHEAD &&
+ * strstart + s->lookahead <= input_size => more >= MIN_LOOKAHEAD.
+ * Otherwise, window_size == 2*WSIZE so more >= 2.
+ * If there was sliding, more >= WSIZE. So in all cases, more >= 2.
+ */
+ Assert(more >= 2, "more < 2");
+
+ n = read_buf(s->strm, s->window + s->strstart + s->lookahead, more);
+ s->lookahead += n;
+
+ /* Initialize the hash value now that we have some input: */
+ if (s->lookahead + s->insert >= MIN_MATCH) {
+ uInt str = s->strstart - s->insert;
+ s->ins_h = s->window[str];
+ UPDATE_HASH(s, s->ins_h, s->window[str + 1]);
+#if MIN_MATCH != 3
+ Call UPDATE_HASH() MIN_MATCH-3 more times
+#endif
+ while (s->insert) {
+ UPDATE_HASH(s, s->ins_h, s->window[str + MIN_MATCH-1]);
+#ifndef FASTEST
+ s->prev[str & s->w_mask] = s->head[s->ins_h];
+#endif
+ s->head[s->ins_h] = (Pos)str;
+ str++;
+ s->insert--;
+ if (s->lookahead + s->insert < MIN_MATCH)
+ break;
+ }
+ }
+ /* If the whole input has less than MIN_MATCH bytes, ins_h is garbage,
+ * but this is not important since only literal bytes will be emitted.
+ */
+
+ } while (s->lookahead < MIN_LOOKAHEAD && s->strm->avail_in != 0);
+
+ /* If the WIN_INIT bytes after the end of the current data have never been
+ * written, then zero those bytes in order to avoid memory check reports of
+ * the use of uninitialized (or uninitialised as Julian writes) bytes by
+ * the longest match routines. Update the high water mark for the next
+ * time through here. WIN_INIT is set to MAX_MATCH since the longest match
+ * routines allow scanning to strstart + MAX_MATCH, ignoring lookahead.
+ */
+ if (s->high_water < s->window_size) {
+ ulg curr = s->strstart + (ulg)(s->lookahead);
+ ulg init;
+
+ if (s->high_water < curr) {
+ /* Previous high water mark below current data -- zero WIN_INIT
+ * bytes or up to end of window, whichever is less.
+ */
+ init = s->window_size - curr;
+ if (init > WIN_INIT)
+ init = WIN_INIT;
+ zmemzero(s->window + curr, (unsigned)init);
+ s->high_water = curr + init;
+ }
+ else if (s->high_water < (ulg)curr + WIN_INIT) {
+ /* High water mark at or above current data, but below current data
+ * plus WIN_INIT -- zero out to current data plus WIN_INIT, or up
+ * to end of window, whichever is less.
+ */
+ init = (ulg)curr + WIN_INIT - s->high_water;
+ if (init > s->window_size - s->high_water)
+ init = s->window_size - s->high_water;
+ zmemzero(s->window + s->high_water, (unsigned)init);
+ s->high_water += init;
+ }
+ }
+
+ Assert((ulg)s->strstart <= s->window_size - MIN_LOOKAHEAD,
+ "not enough room for search");
+}
+
+/* ===========================================================================
+ * Flush the current block, with given end-of-file flag.
+ * IN assertion: strstart is set to the end of the current match.
+ */
+#define FLUSH_BLOCK_ONLY(s, last) { \
+ _tr_flush_block(s, (s->block_start >= 0L ? \
+ (charf *)&s->window[(unsigned)s->block_start] : \
+ (charf *)Z_NULL), \
+ (ulg)((long)s->strstart - s->block_start), \
+ (last)); \
+ s->block_start = s->strstart; \
+ flush_pending(s->strm); \
+ Tracev((stderr,"[FLUSH]")); \
+}
+
+/* Same but force premature exit if necessary. */
+#define FLUSH_BLOCK(s, last) { \
+ FLUSH_BLOCK_ONLY(s, last); \
+ if (s->strm->avail_out == 0) return (last) ? finish_started : need_more; \
+}
+
+/* ===========================================================================
+ * Copy without compression as much as possible from the input stream, return
+ * the current block state.
+ * This function does not insert new strings in the dictionary since
+ * uncompressible data is probably not useful. This function is used
+ * only for the level=0 compression option.
+ * NOTE: this function should be optimized to avoid extra copying from
+ * window to pending_buf.
+ */
+local block_state deflate_stored(s, flush)
+ deflate_state *s;
+ int flush;
+{
+ /* Stored blocks are limited to 0xffff bytes, pending_buf is limited
+ * to pending_buf_size, and each stored block has a 5 byte header:
+ */
+ ulg max_block_size = 0xffff;
+ ulg max_start;
+
+ if (max_block_size > s->pending_buf_size - 5) {
+ max_block_size = s->pending_buf_size - 5;
+ }
+
+ /* Copy as much as possible from input to output: */
+ for (;;) {
+ /* Fill the window as much as possible: */
+ if (s->lookahead <= 1) {
+
+ Assert(s->strstart < s->w_size+MAX_DIST(s) ||
+ s->block_start >= (long)s->w_size, "slide too late");
+
+ fill_window(s);
+ if (s->lookahead == 0 && flush == Z_NO_FLUSH) return need_more;
+
+ if (s->lookahead == 0) break; /* flush the current block */
+ }
+ Assert(s->block_start >= 0L, "block gone");
+
+ s->strstart += s->lookahead;
+ s->lookahead = 0;
+
+ /* Emit a stored block if pending_buf will be full: */
+ max_start = s->block_start + max_block_size;
+ if (s->strstart == 0 || (ulg)s->strstart >= max_start) {
+ /* strstart == 0 is possible when wraparound on 16-bit machine */
+ s->lookahead = (uInt)(s->strstart - max_start);
+ s->strstart = (uInt)max_start;
+ FLUSH_BLOCK(s, 0);
+ }
+ /* Flush if we may have to slide, otherwise block_start may become
+ * negative and the data will be gone:
+ */
+ if (s->strstart - (uInt)s->block_start >= MAX_DIST(s)) {
+ FLUSH_BLOCK(s, 0);
+ }
+ }
+ s->insert = 0;
+ if (flush == Z_FINISH) {
+ FLUSH_BLOCK(s, 1);
+ return finish_done;
+ }
+ if ((long)s->strstart > s->block_start)
+ FLUSH_BLOCK(s, 0);
+ return block_done;
+}
+
+/* ===========================================================================
+ * Compress as much as possible from the input stream, return the current
+ * block state.
+ * This function does not perform lazy evaluation of matches and inserts
+ * new strings in the dictionary only for unmatched strings or for short
+ * matches. It is used only for the fast compression options.
+ */
+local block_state deflate_fast(s, flush)
+ deflate_state *s;
+ int flush;
+{
+ IPos hash_head; /* head of the hash chain */
+ int bflush; /* set if current block must be flushed */
+
+ for (;;) {
+ /* Make sure that we always have enough lookahead, except
+ * at the end of the input file. We need MAX_MATCH bytes
+ * for the next match, plus MIN_MATCH bytes to insert the
+ * string following the next match.
+ */
+ if (s->lookahead < MIN_LOOKAHEAD) {
+ fill_window(s);
+ if (s->lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) {
+ return need_more;
+ }
+ if (s->lookahead == 0) break; /* flush the current block */
+ }
+
+ /* Insert the string window[strstart .. strstart+2] in the
+ * dictionary, and set hash_head to the head of the hash chain:
+ */
+ hash_head = NIL;
+ if (s->lookahead >= MIN_MATCH) {
+ INSERT_STRING(s, s->strstart, hash_head);
+ }
+
+ /* Find the longest match, discarding those <= prev_length.
+ * At this point we have always match_length < MIN_MATCH
+ */
+ if (hash_head != NIL && s->strstart - hash_head <= MAX_DIST(s)) {
+ /* To simplify the code, we prevent matches with the string
+ * of window index 0 (in particular we have to avoid a match
+ * of the string with itself at the start of the input file).
+ */
+ s->match_length = longest_match (s, hash_head);
+ /* longest_match() sets match_start */
+ }
+ if (s->match_length >= MIN_MATCH) {
+ check_match(s, s->strstart, s->match_start, s->match_length);
+
+ _tr_tally_dist(s, s->strstart - s->match_start,
+ s->match_length - MIN_MATCH, bflush);
+
+ s->lookahead -= s->match_length;
+
+ /* Insert new strings in the hash table only if the match length
+ * is not too large. This saves time but degrades compression.
+ */
+#ifndef FASTEST
+ if (s->match_length <= s->max_insert_length &&
+ s->lookahead >= MIN_MATCH) {
+ s->match_length--; /* string at strstart already in table */
+ do {
+ s->strstart++;
+ INSERT_STRING(s, s->strstart, hash_head);
+ /* strstart never exceeds WSIZE-MAX_MATCH, so there are
+ * always MIN_MATCH bytes ahead.
+ */
+ } while (--s->match_length != 0);
+ s->strstart++;
+ } else
+#endif
+ {
+ s->strstart += s->match_length;
+ s->match_length = 0;
+ s->ins_h = s->window[s->strstart];
+ UPDATE_HASH(s, s->ins_h, s->window[s->strstart+1]);
+#if MIN_MATCH != 3
+ Call UPDATE_HASH() MIN_MATCH-3 more times
+#endif
+ /* If lookahead < MIN_MATCH, ins_h is garbage, but it does not
+ * matter since it will be recomputed at next deflate call.
+ */
+ }
+ } else {
+ /* No match, output a literal byte */
+ Tracevv((stderr,"%c", s->window[s->strstart]));
+ _tr_tally_lit (s, s->window[s->strstart], bflush);
+ s->lookahead--;
+ s->strstart++;
+ }
+ if (bflush) FLUSH_BLOCK(s, 0);
+ }
+ s->insert = s->strstart < MIN_MATCH-1 ? s->strstart : MIN_MATCH-1;
+ if (flush == Z_FINISH) {
+ FLUSH_BLOCK(s, 1);
+ return finish_done;
+ }
+ if (s->last_lit)
+ FLUSH_BLOCK(s, 0);
+ return block_done;
+}
+
+#ifndef FASTEST
+/* ===========================================================================
+ * Same as above, but achieves better compression. We use a lazy
+ * evaluation for matches: a match is finally adopted only if there is
+ * no better match at the next window position.
+ */
+local block_state deflate_slow(s, flush)
+ deflate_state *s;
+ int flush;
+{
+ IPos hash_head; /* head of hash chain */
+ int bflush; /* set if current block must be flushed */
+
+ /* Process the input block. */
+ for (;;) {
+ /* Make sure that we always have enough lookahead, except
+ * at the end of the input file. We need MAX_MATCH bytes
+ * for the next match, plus MIN_MATCH bytes to insert the
+ * string following the next match.
+ */
+ if (s->lookahead < MIN_LOOKAHEAD) {
+ fill_window(s);
+ if (s->lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) {
+ return need_more;
+ }
+ if (s->lookahead == 0) break; /* flush the current block */
+ }
+
+ /* Insert the string window[strstart .. strstart+2] in the
+ * dictionary, and set hash_head to the head of the hash chain:
+ */
+ hash_head = NIL;
+ if (s->lookahead >= MIN_MATCH) {
+ INSERT_STRING(s, s->strstart, hash_head);
+ }
+
+ /* Find the longest match, discarding those <= prev_length.
+ */
+ s->prev_length = s->match_length, s->prev_match = s->match_start;
+ s->match_length = MIN_MATCH-1;
+
+ if (hash_head != NIL && s->prev_length < s->max_lazy_match &&
+ s->strstart - hash_head <= MAX_DIST(s)) {
+ /* To simplify the code, we prevent matches with the string
+ * of window index 0 (in particular we have to avoid a match
+ * of the string with itself at the start of the input file).
+ */
+ s->match_length = longest_match (s, hash_head);
+ /* longest_match() sets match_start */
+
+ if (s->match_length <= 5 && (s->strategy == Z_FILTERED
+#if TOO_FAR <= 32767
+ || (s->match_length == MIN_MATCH &&
+ s->strstart - s->match_start > TOO_FAR)
+#endif
+ )) {
+
+ /* If prev_match is also MIN_MATCH, match_start is garbage
+ * but we will ignore the current match anyway.
+ */
+ s->match_length = MIN_MATCH-1;
+ }
+ }
+ /* If there was a match at the previous step and the current
+ * match is not better, output the previous match:
+ */
+ if (s->prev_length >= MIN_MATCH && s->match_length <= s->prev_length) {
+ uInt max_insert = s->strstart + s->lookahead - MIN_MATCH;
+ /* Do not insert strings in hash table beyond this. */
+
+ check_match(s, s->strstart-1, s->prev_match, s->prev_length);
+
+ _tr_tally_dist(s, s->strstart -1 - s->prev_match,
+ s->prev_length - MIN_MATCH, bflush);
+
+ /* Insert in hash table all strings up to the end of the match.
+ * strstart-1 and strstart are already inserted. If there is not
+ * enough lookahead, the last two strings are not inserted in
+ * the hash table.
+ */
+ s->lookahead -= s->prev_length-1;
+ s->prev_length -= 2;
+ do {
+ if (++s->strstart <= max_insert) {
+ INSERT_STRING(s, s->strstart, hash_head);
+ }
+ } while (--s->prev_length != 0);
+ s->match_available = 0;
+ s->match_length = MIN_MATCH-1;
+ s->strstart++;
+
+ if (bflush) FLUSH_BLOCK(s, 0);
+
+ } else if (s->match_available) {
+ /* If there was no match at the previous position, output a
+ * single literal. If there was a match but the current match
+ * is longer, truncate the previous match to a single literal.
+ */
+ Tracevv((stderr,"%c", s->window[s->strstart-1]));
+ _tr_tally_lit(s, s->window[s->strstart-1], bflush);
+ if (bflush) {
+ FLUSH_BLOCK_ONLY(s, 0);
+ }
+ s->strstart++;
+ s->lookahead--;
+ if (s->strm->avail_out == 0) return need_more;
+ } else {
+ /* There is no previous match to compare with, wait for
+ * the next step to decide.
+ */
+ s->match_available = 1;
+ s->strstart++;
+ s->lookahead--;
+ }
+ }
+ Assert (flush != Z_NO_FLUSH, "no flush?");
+ if (s->match_available) {
+ Tracevv((stderr,"%c", s->window[s->strstart-1]));
+ _tr_tally_lit(s, s->window[s->strstart-1], bflush);
+ s->match_available = 0;
+ }
+ s->insert = s->strstart < MIN_MATCH-1 ? s->strstart : MIN_MATCH-1;
+ if (flush == Z_FINISH) {
+ FLUSH_BLOCK(s, 1);
+ return finish_done;
+ }
+ if (s->last_lit)
+ FLUSH_BLOCK(s, 0);
+ return block_done;
+}
+#endif /* FASTEST */
+
+/* ===========================================================================
+ * For Z_RLE, simply look for runs of bytes, generate matches only of distance
+ * one. Do not maintain a hash table. (It will be regenerated if this run of
+ * deflate switches away from Z_RLE.)
+ */
+local block_state deflate_rle(s, flush)
+ deflate_state *s;
+ int flush;
+{
+ int bflush; /* set if current block must be flushed */
+ uInt prev; /* byte at distance one to match */
+ Bytef *scan, *strend; /* scan goes up to strend for length of run */
+
+ for (;;) {
+ /* Make sure that we always have enough lookahead, except
+ * at the end of the input file. We need MAX_MATCH bytes
+ * for the longest run, plus one for the unrolled loop.
+ */
+ if (s->lookahead <= MAX_MATCH) {
+ fill_window(s);
+ if (s->lookahead <= MAX_MATCH && flush == Z_NO_FLUSH) {
+ return need_more;
+ }
+ if (s->lookahead == 0) break; /* flush the current block */
+ }
+
+ /* See how many times the previous byte repeats */
+ s->match_length = 0;
+ if (s->lookahead >= MIN_MATCH && s->strstart > 0) {
+ scan = s->window + s->strstart - 1;
+ prev = *scan;
+ if (prev == *++scan && prev == *++scan && prev == *++scan) {
+ strend = s->window + s->strstart + MAX_MATCH;
+ do {
+ } while (prev == *++scan && prev == *++scan &&
+ prev == *++scan && prev == *++scan &&
+ prev == *++scan && prev == *++scan &&
+ prev == *++scan && prev == *++scan &&
+ scan < strend);
+ s->match_length = MAX_MATCH - (int)(strend - scan);
+ if (s->match_length > s->lookahead)
+ s->match_length = s->lookahead;
+ }
+ Assert(scan <= s->window+(uInt)(s->window_size-1), "wild scan");
+ }
+
+ /* Emit match if have run of MIN_MATCH or longer, else emit literal */
+ if (s->match_length >= MIN_MATCH) {
+ check_match(s, s->strstart, s->strstart - 1, s->match_length);
+
+ _tr_tally_dist(s, 1, s->match_length - MIN_MATCH, bflush);
+
+ s->lookahead -= s->match_length;
+ s->strstart += s->match_length;
+ s->match_length = 0;
+ } else {
+ /* No match, output a literal byte */
+ Tracevv((stderr,"%c", s->window[s->strstart]));
+ _tr_tally_lit (s, s->window[s->strstart], bflush);
+ s->lookahead--;
+ s->strstart++;
+ }
+ if (bflush) FLUSH_BLOCK(s, 0);
+ }
+ s->insert = 0;
+ if (flush == Z_FINISH) {
+ FLUSH_BLOCK(s, 1);
+ return finish_done;
+ }
+ if (s->last_lit)
+ FLUSH_BLOCK(s, 0);
+ return block_done;
+}
+
+/* ===========================================================================
+ * For Z_HUFFMAN_ONLY, do not look for matches. Do not maintain a hash table.
+ * (It will be regenerated if this run of deflate switches away from Huffman.)
+ */
+local block_state deflate_huff(s, flush)
+ deflate_state *s;
+ int flush;
+{
+ int bflush; /* set if current block must be flushed */
+
+ for (;;) {
+ /* Make sure that we have a literal to write. */
+ if (s->lookahead == 0) {
+ fill_window(s);
+ if (s->lookahead == 0) {
+ if (flush == Z_NO_FLUSH)
+ return need_more;
+ break; /* flush the current block */
+ }
+ }
+
+ /* Output a literal byte */
+ s->match_length = 0;
+ Tracevv((stderr,"%c", s->window[s->strstart]));
+ _tr_tally_lit (s, s->window[s->strstart], bflush);
+ s->lookahead--;
+ s->strstart++;
+ if (bflush) FLUSH_BLOCK(s, 0);
+ }
+ s->insert = 0;
+ if (flush == Z_FINISH) {
+ FLUSH_BLOCK(s, 1);
+ return finish_done;
+ }
+ if (s->last_lit)
+ FLUSH_BLOCK(s, 0);
+ return block_done;
+}
diff --git a/ml/dlib/dlib/external/zlib/deflate.h b/ml/dlib/dlib/external/zlib/deflate.h
new file mode 100644
index 000000000..ce0299edd
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/deflate.h
@@ -0,0 +1,346 @@
+/* deflate.h -- internal compression state
+ * Copyright (C) 1995-2012 Jean-loup Gailly
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* WARNING: this file should *not* be used by applications. It is
+ part of the implementation of the compression library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+/* @(#) $Id$ */
+
+#ifndef DEFLATE_H
+#define DEFLATE_H
+
+#include "zutil.h"
+
+/* define NO_GZIP when compiling if you want to disable gzip header and
+ trailer creation by deflate(). NO_GZIP would be used to avoid linking in
+ the crc code when it is not needed. For shared libraries, gzip encoding
+ should be left enabled. */
+#ifndef NO_GZIP
+# define GZIP
+#endif
+
+/* ===========================================================================
+ * Internal compression state.
+ */
+
+#define LENGTH_CODES 29
+/* number of length codes, not counting the special END_BLOCK code */
+
+#define LITERALS 256
+/* number of literal bytes 0..255 */
+
+#define L_CODES (LITERALS+1+LENGTH_CODES)
+/* number of Literal or Length codes, including the END_BLOCK code */
+
+#define D_CODES 30
+/* number of distance codes */
+
+#define BL_CODES 19
+/* number of codes used to transfer the bit lengths */
+
+#define HEAP_SIZE (2*L_CODES+1)
+/* maximum heap size */
+
+#define MAX_BITS 15
+/* All codes must not exceed MAX_BITS bits */
+
+#define Buf_size 16
+/* size of bit buffer in bi_buf */
+
+#define INIT_STATE 42
+#define EXTRA_STATE 69
+#define NAME_STATE 73
+#define COMMENT_STATE 91
+#define HCRC_STATE 103
+#define BUSY_STATE 113
+#define FINISH_STATE 666
+/* Stream status */
+
+
+/* Data structure describing a single value and its code string. */
+typedef struct ct_data_s {
+ union {
+ ush freq; /* frequency count */
+ ush code; /* bit string */
+ } fc;
+ union {
+ ush dad; /* father node in Huffman tree */
+ ush len; /* length of bit string */
+ } dl;
+} FAR ct_data;
+
+#define Freq fc.freq
+#define Code fc.code
+#define Dad dl.dad
+#define Len dl.len
+
+typedef struct static_tree_desc_s static_tree_desc;
+
+typedef struct tree_desc_s {
+ ct_data *dyn_tree; /* the dynamic tree */
+ int max_code; /* largest code with non zero frequency */
+ static_tree_desc *stat_desc; /* the corresponding static tree */
+} FAR tree_desc;
+
+typedef ush Pos;
+typedef Pos FAR Posf;
+typedef unsigned IPos;
+
+/* A Pos is an index in the character window. We use short instead of int to
+ * save space in the various tables. IPos is used only for parameter passing.
+ */
+
+typedef struct internal_state {
+ z_streamp strm; /* pointer back to this zlib stream */
+ int status; /* as the name implies */
+ Bytef *pending_buf; /* output still pending */
+ ulg pending_buf_size; /* size of pending_buf */
+ Bytef *pending_out; /* next pending byte to output to the stream */
+ uInt pending; /* nb of bytes in the pending buffer */
+ int wrap; /* bit 0 true for zlib, bit 1 true for gzip */
+ gz_headerp gzhead; /* gzip header information to write */
+ uInt gzindex; /* where in extra, name, or comment */
+ Byte method; /* can only be DEFLATED */
+ int last_flush; /* value of flush param for previous deflate call */
+
+ /* used by deflate.c: */
+
+ uInt w_size; /* LZ77 window size (32K by default) */
+ uInt w_bits; /* log2(w_size) (8..16) */
+ uInt w_mask; /* w_size - 1 */
+
+ Bytef *window;
+ /* Sliding window. Input bytes are read into the second half of the window,
+ * and move to the first half later to keep a dictionary of at least wSize
+ * bytes. With this organization, matches are limited to a distance of
+ * wSize-MAX_MATCH bytes, but this ensures that IO is always
+ * performed with a length multiple of the block size. Also, it limits
+ * the window size to 64K, which is quite useful on MSDOS.
+ * To do: use the user input buffer as sliding window.
+ */
+
+ ulg window_size;
+ /* Actual size of window: 2*wSize, except when the user input buffer
+ * is directly used as sliding window.
+ */
+
+ Posf *prev;
+ /* Link to older string with same hash index. To limit the size of this
+ * array to 64K, this link is maintained only for the last 32K strings.
+ * An index in this array is thus a window index modulo 32K.
+ */
+
+ Posf *head; /* Heads of the hash chains or NIL. */
+
+ uInt ins_h; /* hash index of string to be inserted */
+ uInt hash_size; /* number of elements in hash table */
+ uInt hash_bits; /* log2(hash_size) */
+ uInt hash_mask; /* hash_size-1 */
+
+ uInt hash_shift;
+ /* Number of bits by which ins_h must be shifted at each input
+ * step. It must be such that after MIN_MATCH steps, the oldest
+ * byte no longer takes part in the hash key, that is:
+ * hash_shift * MIN_MATCH >= hash_bits
+ */
+
+ long block_start;
+ /* Window position at the beginning of the current output block. Gets
+ * negative when the window is moved backwards.
+ */
+
+ uInt match_length; /* length of best match */
+ IPos prev_match; /* previous match */
+ int match_available; /* set if previous match exists */
+ uInt strstart; /* start of string to insert */
+ uInt match_start; /* start of matching string */
+ uInt lookahead; /* number of valid bytes ahead in window */
+
+ uInt prev_length;
+ /* Length of the best match at previous step. Matches not greater than this
+ * are discarded. This is used in the lazy match evaluation.
+ */
+
+ uInt max_chain_length;
+ /* To speed up deflation, hash chains are never searched beyond this
+ * length. A higher limit improves compression ratio but degrades the
+ * speed.
+ */
+
+ uInt max_lazy_match;
+ /* Attempt to find a better match only when the current match is strictly
+ * smaller than this value. This mechanism is used only for compression
+ * levels >= 4.
+ */
+# define max_insert_length max_lazy_match
+ /* Insert new strings in the hash table only if the match length is not
+ * greater than this length. This saves time but degrades compression.
+ * max_insert_length is used only for compression levels <= 3.
+ */
+
+ int level; /* compression level (1..9) */
+ int strategy; /* favor or force Huffman coding*/
+
+ uInt good_match;
+ /* Use a faster search when the previous match is longer than this */
+
+ int nice_match; /* Stop searching when current match exceeds this */
+
+ /* used by trees.c: */
+ /* Didn't use ct_data typedef below to suppress compiler warning */
+ struct ct_data_s dyn_ltree[HEAP_SIZE]; /* literal and length tree */
+ struct ct_data_s dyn_dtree[2*D_CODES+1]; /* distance tree */
+ struct ct_data_s bl_tree[2*BL_CODES+1]; /* Huffman tree for bit lengths */
+
+ struct tree_desc_s l_desc; /* desc. for literal tree */
+ struct tree_desc_s d_desc; /* desc. for distance tree */
+ struct tree_desc_s bl_desc; /* desc. for bit length tree */
+
+ ush bl_count[MAX_BITS+1];
+ /* number of codes at each bit length for an optimal tree */
+
+ int heap[2*L_CODES+1]; /* heap used to build the Huffman trees */
+ int heap_len; /* number of elements in the heap */
+ int heap_max; /* element of largest frequency */
+ /* The sons of heap[n] are heap[2*n] and heap[2*n+1]. heap[0] is not used.
+ * The same heap array is used to build all trees.
+ */
+
+ uch depth[2*L_CODES+1];
+ /* Depth of each subtree used as tie breaker for trees of equal frequency
+ */
+
+ uchf *l_buf; /* buffer for literals or lengths */
+
+ uInt lit_bufsize;
+ /* Size of match buffer for literals/lengths. There are 4 reasons for
+ * limiting lit_bufsize to 64K:
+ * - frequencies can be kept in 16 bit counters
+ * - if compression is not successful for the first block, all input
+ * data is still in the window so we can still emit a stored block even
+ * when input comes from standard input. (This can also be done for
+ * all blocks if lit_bufsize is not greater than 32K.)
+ * - if compression is not successful for a file smaller than 64K, we can
+ * even emit a stored file instead of a stored block (saving 5 bytes).
+ * This is applicable only for zip (not gzip or zlib).
+ * - creating new Huffman trees less frequently may not provide fast
+ * adaptation to changes in the input data statistics. (Take for
+ * example a binary file with poorly compressible code followed by
+ * a highly compressible string table.) Smaller buffer sizes give
+ * fast adaptation but have of course the overhead of transmitting
+ * trees more frequently.
+ * - I can't count above 4
+ */
+
+ uInt last_lit; /* running index in l_buf */
+
+ ushf *d_buf;
+ /* Buffer for distances. To simplify the code, d_buf and l_buf have
+ * the same number of elements. To use different lengths, an extra flag
+ * array would be necessary.
+ */
+
+ ulg opt_len; /* bit length of current block with optimal trees */
+ ulg static_len; /* bit length of current block with static trees */
+ uInt matches; /* number of string matches in current block */
+ uInt insert; /* bytes at end of window left to insert */
+
+#ifdef DEBUG
+ ulg compressed_len; /* total bit length of compressed file mod 2^32 */
+ ulg bits_sent; /* bit length of compressed data sent mod 2^32 */
+#endif
+
+ ush bi_buf;
+ /* Output buffer. bits are inserted starting at the bottom (least
+ * significant bits).
+ */
+ int bi_valid;
+ /* Number of valid bits in bi_buf. All bits above the last valid bit
+ * are always zero.
+ */
+
+ ulg high_water;
+ /* High water mark offset in window for initialized bytes -- bytes above
+ * this are set to zero in order to avoid memory check warnings when
+ * longest match routines access bytes past the input. This is then
+ * updated to the new high water mark.
+ */
+
+} FAR deflate_state;
+
+/* Output a byte on the stream.
+ * IN assertion: there is enough room in pending_buf.
+ */
+#define put_byte(s, c) {s->pending_buf[s->pending++] = (c);}
+
+
+#define MIN_LOOKAHEAD (MAX_MATCH+MIN_MATCH+1)
+/* Minimum amount of lookahead, except at the end of the input file.
+ * See deflate.c for comments about the MIN_MATCH+1.
+ */
+
+#define MAX_DIST(s) ((s)->w_size-MIN_LOOKAHEAD)
+/* In order to simplify the code, particularly on 16 bit machines, match
+ * distances are limited to MAX_DIST instead of WSIZE.
+ */
+
+#define WIN_INIT MAX_MATCH
+/* Number of bytes after end of data in window to initialize in order to avoid
+ memory checker errors from longest match routines */
+
+ /* in trees.c */
+void ZLIB_INTERNAL _tr_init OF((deflate_state *s));
+int ZLIB_INTERNAL _tr_tally OF((deflate_state *s, unsigned dist, unsigned lc));
+void ZLIB_INTERNAL _tr_flush_block OF((deflate_state *s, charf *buf,
+ ulg stored_len, int last));
+void ZLIB_INTERNAL _tr_flush_bits OF((deflate_state *s));
+void ZLIB_INTERNAL _tr_align OF((deflate_state *s));
+void ZLIB_INTERNAL _tr_stored_block OF((deflate_state *s, charf *buf,
+ ulg stored_len, int last));
+
+#define d_code(dist) \
+ ((dist) < 256 ? _dist_code[dist] : _dist_code[256+((dist)>>7)])
+/* Mapping from a distance to a distance code. dist is the distance - 1 and
+ * must not have side effects. _dist_code[256] and _dist_code[257] are never
+ * used.
+ */
+
+#ifndef DEBUG
+/* Inline versions of _tr_tally for speed: */
+
+#if defined(GEN_TREES_H) || !defined(STDC)
+ extern uch ZLIB_INTERNAL _length_code[];
+ extern uch ZLIB_INTERNAL _dist_code[];
+#else
+ extern const uch ZLIB_INTERNAL _length_code[];
+ extern const uch ZLIB_INTERNAL _dist_code[];
+#endif
+
+# define _tr_tally_lit(s, c, flush) \
+ { uch cc = (c); \
+ s->d_buf[s->last_lit] = 0; \
+ s->l_buf[s->last_lit++] = cc; \
+ s->dyn_ltree[cc].Freq++; \
+ flush = (s->last_lit == s->lit_bufsize-1); \
+ }
+# define _tr_tally_dist(s, distance, length, flush) \
+ { uch len = (length); \
+ ush dist = (distance); \
+ s->d_buf[s->last_lit] = dist; \
+ s->l_buf[s->last_lit++] = len; \
+ dist--; \
+ s->dyn_ltree[_length_code[len]+LITERALS+1].Freq++; \
+ s->dyn_dtree[d_code(dist)].Freq++; \
+ flush = (s->last_lit == s->lit_bufsize-1); \
+ }
+#else
+# define _tr_tally_lit(s, c, flush) flush = _tr_tally(s, 0, c)
+# define _tr_tally_dist(s, distance, length, flush) \
+ flush = _tr_tally(s, distance, length)
+#endif
+
+#endif /* DEFLATE_H */
diff --git a/ml/dlib/dlib/external/zlib/gzclose.c b/ml/dlib/dlib/external/zlib/gzclose.c
new file mode 100644
index 000000000..caeb99a31
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/gzclose.c
@@ -0,0 +1,25 @@
+/* gzclose.c -- zlib gzclose() function
+ * Copyright (C) 2004, 2010 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "gzguts.h"
+
+/* gzclose() is in a separate file so that it is linked in only if it is used.
+ That way the other gzclose functions can be used instead to avoid linking in
+ unneeded compression or decompression routines. */
+int ZEXPORT gzclose(file)
+ gzFile file;
+{
+#ifndef NO_GZCOMPRESS
+ gz_statep state;
+
+ if (file == NULL)
+ return Z_STREAM_ERROR;
+ state = (gz_statep)file;
+
+ return state->mode == GZ_READ ? gzclose_r(file) : gzclose_w(file);
+#else
+ return gzclose_r(file);
+#endif
+}
diff --git a/ml/dlib/dlib/external/zlib/gzguts.h b/ml/dlib/dlib/external/zlib/gzguts.h
new file mode 100644
index 000000000..2bb0b0499
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/gzguts.h
@@ -0,0 +1,219 @@
+/* gzguts.h -- zlib internal header definitions for gz* operations
+ * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#ifdef _MSC_VER
+// Disable the following warnings for Visual Studio
+// This is a warning you get from visual studio 2005 about things in the standard C++
+// library being "deprecated." I checked the C++ standard and it doesn't say jack
+// about any of them (I checked the searchable PDF). So this warning is total Bunk.
+#pragma warning(disable : 4996)
+#endif
+
+#ifdef _LARGEFILE64_SOURCE
+# ifndef _LARGEFILE_SOURCE
+# define _LARGEFILE_SOURCE 1
+# endif
+# ifdef _FILE_OFFSET_BITS
+# undef _FILE_OFFSET_BITS
+# endif
+#endif
+
+#ifdef HAVE_HIDDEN
+# define ZLIB_INTERNAL __attribute__((visibility ("hidden")))
+#else
+# define ZLIB_INTERNAL
+#endif
+
+#include <stdio.h>
+#include "zlib.h"
+#ifdef STDC
+# include <string.h>
+# include <stdlib.h>
+# include <limits.h>
+#endif
+#include <fcntl.h>
+
+#ifdef _WIN32
+# include <stddef.h>
+#endif
+
+#if defined(__TURBOC__) || defined(_MSC_VER) || defined(_WIN32)
+# include <io.h>
+#else
+# include <unistd.h>
+#endif
+
+#ifdef WINAPI_FAMILY
+# define open _open
+# define read _read
+# define write _write
+# define close _close
+#endif
+
+#ifdef NO_DEFLATE /* for compatibility with old definition */
+# define NO_GZCOMPRESS
+#endif
+
+#if defined(STDC99) || (defined(__TURBOC__) && __TURBOC__ >= 0x550)
+# ifndef HAVE_VSNPRINTF
+# define HAVE_VSNPRINTF
+# endif
+#endif
+
+#if defined(__CYGWIN__)
+# ifndef HAVE_VSNPRINTF
+# define HAVE_VSNPRINTF
+# endif
+#endif
+
+#if defined(MSDOS) && defined(__BORLANDC__) && (BORLANDC > 0x410)
+# ifndef HAVE_VSNPRINTF
+# define HAVE_VSNPRINTF
+# endif
+#endif
+
+#ifndef HAVE_VSNPRINTF
+# ifdef MSDOS
+/* vsnprintf may exist on some MS-DOS compilers (DJGPP?),
+ but for now we just assume it doesn't. */
+# define NO_vsnprintf
+# endif
+# ifdef __TURBOC__
+# define NO_vsnprintf
+# endif
+# ifdef WIN32
+/* In Win32, vsnprintf is available as the "non-ANSI" _vsnprintf. */
+# if !defined(vsnprintf) && !defined(NO_vsnprintf)
+# if !defined(_MSC_VER) || ( defined(_MSC_VER) && _MSC_VER < 1500 )
+# define vsnprintf _vsnprintf
+# endif
+# endif
+# endif
+# ifdef __SASC
+# define NO_vsnprintf
+# endif
+# ifdef VMS
+# define NO_vsnprintf
+# endif
+# ifdef __OS400__
+# define NO_vsnprintf
+# endif
+# ifdef __MVS__
+# define NO_vsnprintf
+# endif
+#endif
+
+/* unlike snprintf (which is required in C99, yet still not supported by
+ Microsoft more than a decade later!), _snprintf does not guarantee null
+ termination of the result -- however this is only used in gzlib.c where
+ the result is assured to fit in the space provided */
+#ifdef _MSC_VER
+# define snprintf _snprintf
+#endif
+
+#ifndef local
+# define local static
+#endif
+/* compile with -Dlocal if your debugger can't find static symbols */
+
+/* gz* functions always use library allocation functions */
+#ifndef STDC
+ extern voidp malloc OF((uInt size));
+ extern void free OF((voidpf ptr));
+#endif
+
+/* get errno and strerror definition */
+#if defined UNDER_CE
+# include <windows.h>
+# define zstrerror() gz_strwinerror((DWORD)GetLastError())
+#else
+# ifndef NO_STRERROR
+# include <errno.h>
+# define zstrerror() strerror(errno)
+# else
+# define zstrerror() "stdio error (consult errno)"
+# endif
+#endif
+
+/* provide prototypes for these when building zlib without LFS */
+#if !defined(_LARGEFILE64_SOURCE) || _LFS64_LARGEFILE-0 == 0
+ ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *));
+ ZEXTERN z_off64_t ZEXPORT gzseek64 OF((gzFile, z_off64_t, int));
+ ZEXTERN z_off64_t ZEXPORT gztell64 OF((gzFile));
+ ZEXTERN z_off64_t ZEXPORT gzoffset64 OF((gzFile));
+#endif
+
+/* default memLevel */
+#if MAX_MEM_LEVEL >= 8
+# define DEF_MEM_LEVEL 8
+#else
+# define DEF_MEM_LEVEL MAX_MEM_LEVEL
+#endif
+
+/* default i/o buffer size -- double this for output when reading (this and
+ twice this must be able to fit in an unsigned type) */
+#define GZBUFSIZE 8192
+
+/* gzip modes, also provide a little integrity check on the passed structure */
+#define GZ_NONE 0
+#define GZ_READ 7247
+#define GZ_WRITE 31153
+#define GZ_APPEND 1 /* mode set to GZ_WRITE after the file is opened */
+
+/* values for gz_state how */
+#define LOOK 0 /* look for a gzip header */
+#define COPY 1 /* copy input directly */
+#define GZIP 2 /* decompress a gzip stream */
+
+/* internal gzip file state data structure */
+typedef struct {
+ /* exposed contents for gzgetc() macro */
+ struct gzFile_s x; /* "x" for exposed */
+ /* x.have: number of bytes available at x.next */
+ /* x.next: next output data to deliver or write */
+ /* x.pos: current position in uncompressed data */
+ /* used for both reading and writing */
+ int mode; /* see gzip modes above */
+ int fd; /* file descriptor */
+ char *path; /* path or fd for error messages */
+ unsigned size; /* buffer size, zero if not allocated yet */
+ unsigned want; /* requested buffer size, default is GZBUFSIZE */
+ unsigned char *in; /* input buffer */
+ unsigned char *out; /* output buffer (double-sized when reading) */
+ int direct; /* 0 if processing gzip, 1 if transparent */
+ /* just for reading */
+ int how; /* 0: get header, 1: copy, 2: decompress */
+ z_off64_t start; /* where the gzip data started, for rewinding */
+ int eof; /* true if end of input file reached */
+ int past; /* true if read requested past end */
+ /* just for writing */
+ int level; /* compression level */
+ int strategy; /* compression strategy */
+ /* seek request */
+ z_off64_t skip; /* amount to skip (already rewound if backwards) */
+ int seek; /* true if seek request pending */
+ /* error information */
+ int err; /* error code */
+ char *msg; /* error message */
+ /* zlib inflate or deflate stream */
+ z_stream strm; /* stream structure in-place (not a pointer) */
+} gz_state;
+typedef gz_state FAR *gz_statep;
+
+/* shared functions */
+void ZLIB_INTERNAL gz_error OF((gz_statep, int, const char *));
+#if defined UNDER_CE
+char ZLIB_INTERNAL *gz_strwinerror OF((DWORD error));
+#endif
+
+/* GT_OFF(x), where x is an unsigned value, is true if x > maximum z_off64_t
+ value -- needed when comparing unsigned to z_off64_t, which is signed
+ (possible z_off64_t types off_t, off64_t, and long are all signed) */
+#ifdef INT_MAX
+# define GT_OFF(x) (sizeof(int) == sizeof(z_off64_t) && (x) > INT_MAX)
+#else
+unsigned ZLIB_INTERNAL gz_intmax OF((void));
+# define GT_OFF(x) (sizeof(int) == sizeof(z_off64_t) && (x) > gz_intmax())
+#endif
diff --git a/ml/dlib/dlib/external/zlib/gzlib.c b/ml/dlib/dlib/external/zlib/gzlib.c
new file mode 100644
index 000000000..fae202ef8
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/gzlib.c
@@ -0,0 +1,634 @@
+/* gzlib.c -- zlib functions common to reading and writing gzip files
+ * Copyright (C) 2004, 2010, 2011, 2012, 2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "gzguts.h"
+
+#if defined(_WIN32) && !defined(__BORLANDC__)
+# define LSEEK _lseeki64
+#else
+#if defined(_LARGEFILE64_SOURCE) && _LFS64_LARGEFILE-0
+# define LSEEK lseek64
+#else
+# define LSEEK lseek
+#endif
+#endif
+
+/* Local functions */
+local void gz_reset OF((gz_statep));
+local gzFile gz_open OF((const void *, int, const char *));
+
+#if defined UNDER_CE
+
+/* Map the Windows error number in ERROR to a locale-dependent error message
+ string and return a pointer to it. Typically, the values for ERROR come
+ from GetLastError.
+
+ The string pointed to shall not be modified by the application, but may be
+ overwritten by a subsequent call to gz_strwinerror
+
+ The gz_strwinerror function does not change the current setting of
+ GetLastError. */
+char ZLIB_INTERNAL *gz_strwinerror (error)
+ DWORD error;
+{
+ static char buf[1024];
+
+ wchar_t *msgbuf;
+ DWORD lasterr = GetLastError();
+ DWORD chars = FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM
+ | FORMAT_MESSAGE_ALLOCATE_BUFFER,
+ NULL,
+ error,
+ 0, /* Default language */
+ (LPVOID)&msgbuf,
+ 0,
+ NULL);
+ if (chars != 0) {
+ /* If there is an \r\n appended, zap it. */
+ if (chars >= 2
+ && msgbuf[chars - 2] == '\r' && msgbuf[chars - 1] == '\n') {
+ chars -= 2;
+ msgbuf[chars] = 0;
+ }
+
+ if (chars > sizeof (buf) - 1) {
+ chars = sizeof (buf) - 1;
+ msgbuf[chars] = 0;
+ }
+
+ wcstombs(buf, msgbuf, chars + 1);
+ LocalFree(msgbuf);
+ }
+ else {
+ sprintf(buf, "unknown win32 error (%ld)", error);
+ }
+
+ SetLastError(lasterr);
+ return buf;
+}
+
+#endif /* UNDER_CE */
+
+/* Reset gzip file state */
+local void gz_reset(state)
+ gz_statep state;
+{
+ state->x.have = 0; /* no output data available */
+ if (state->mode == GZ_READ) { /* for reading ... */
+ state->eof = 0; /* not at end of file */
+ state->past = 0; /* have not read past end yet */
+ state->how = LOOK; /* look for gzip header */
+ }
+ state->seek = 0; /* no seek request pending */
+ gz_error(state, Z_OK, NULL); /* clear error */
+ state->x.pos = 0; /* no uncompressed data yet */
+ state->strm.avail_in = 0; /* no input data yet */
+}
+
+/* Open a gzip file either by name or file descriptor. */
+local gzFile gz_open(path, fd, mode)
+ const void *path;
+ int fd;
+ const char *mode;
+{
+ gz_statep state;
+ size_t len;
+ int oflag;
+#ifdef O_CLOEXEC
+ int cloexec = 0;
+#endif
+#ifdef O_EXCL
+ int exclusive = 0;
+#endif
+
+ /* check input */
+ if (path == NULL)
+ return NULL;
+
+ /* allocate gzFile structure to return */
+ state = (gz_statep)malloc(sizeof(gz_state));
+ if (state == NULL)
+ return NULL;
+ state->size = 0; /* no buffers allocated yet */
+ state->want = GZBUFSIZE; /* requested buffer size */
+ state->msg = NULL; /* no error message yet */
+
+ /* interpret mode */
+ state->mode = GZ_NONE;
+ state->level = Z_DEFAULT_COMPRESSION;
+ state->strategy = Z_DEFAULT_STRATEGY;
+ state->direct = 0;
+ while (*mode) {
+ if (*mode >= '0' && *mode <= '9')
+ state->level = *mode - '0';
+ else
+ switch (*mode) {
+ case 'r':
+ state->mode = GZ_READ;
+ break;
+#ifndef NO_GZCOMPRESS
+ case 'w':
+ state->mode = GZ_WRITE;
+ break;
+ case 'a':
+ state->mode = GZ_APPEND;
+ break;
+#endif
+ case '+': /* can't read and write at the same time */
+ free(state);
+ return NULL;
+ case 'b': /* ignore -- will request binary anyway */
+ break;
+#ifdef O_CLOEXEC
+ case 'e':
+ cloexec = 1;
+ break;
+#endif
+#ifdef O_EXCL
+ case 'x':
+ exclusive = 1;
+ break;
+#endif
+ case 'f':
+ state->strategy = Z_FILTERED;
+ break;
+ case 'h':
+ state->strategy = Z_HUFFMAN_ONLY;
+ break;
+ case 'R':
+ state->strategy = Z_RLE;
+ break;
+ case 'F':
+ state->strategy = Z_FIXED;
+ break;
+ case 'T':
+ state->direct = 1;
+ break;
+ default: /* could consider as an error, but just ignore */
+ ;
+ }
+ mode++;
+ }
+
+ /* must provide an "r", "w", or "a" */
+ if (state->mode == GZ_NONE) {
+ free(state);
+ return NULL;
+ }
+
+ /* can't force transparent read */
+ if (state->mode == GZ_READ) {
+ if (state->direct) {
+ free(state);
+ return NULL;
+ }
+ state->direct = 1; /* for empty file */
+ }
+
+ /* save the path name for error messages */
+#ifdef _WIN32
+ if (fd == -2) {
+ len = wcstombs(NULL, path, 0);
+ if (len == (size_t)-1)
+ len = 0;
+ }
+ else
+#endif
+ len = strlen((const char *)path);
+ state->path = (char *)malloc(len + 1);
+ if (state->path == NULL) {
+ free(state);
+ return NULL;
+ }
+#ifdef _WIN32
+ if (fd == -2)
+ if (len)
+ wcstombs(state->path, path, len + 1);
+ else
+ *(state->path) = 0;
+ else
+#endif
+#if !defined(NO_snprintf) && !defined(NO_vsnprintf)
+ snprintf(state->path, len + 1, "%s", (const char *)path);
+#else
+ strcpy(state->path, path);
+#endif
+
+ /* compute the flags for open() */
+ oflag =
+#ifdef O_LARGEFILE
+ O_LARGEFILE |
+#endif
+#ifdef O_BINARY
+ O_BINARY |
+#endif
+#ifdef O_CLOEXEC
+ (cloexec ? O_CLOEXEC : 0) |
+#endif
+ (state->mode == GZ_READ ?
+ O_RDONLY :
+ (O_WRONLY | O_CREAT |
+#ifdef O_EXCL
+ (exclusive ? O_EXCL : 0) |
+#endif
+ (state->mode == GZ_WRITE ?
+ O_TRUNC :
+ O_APPEND)));
+
+ /* open the file with the appropriate flags (or just use fd) */
+ state->fd = fd > -1 ? fd : (
+#ifdef _WIN32
+ fd == -2 ? _wopen(path, oflag, 0666) :
+#endif
+ open((const char *)path, oflag, 0666));
+ if (state->fd == -1) {
+ free(state->path);
+ free(state);
+ return NULL;
+ }
+ if (state->mode == GZ_APPEND)
+ state->mode = GZ_WRITE; /* simplify later checks */
+
+ /* save the current position for rewinding (only if reading) */
+ if (state->mode == GZ_READ) {
+ state->start = LSEEK(state->fd, 0, SEEK_CUR);
+ if (state->start == -1) state->start = 0;
+ }
+
+ /* initialize stream */
+ gz_reset(state);
+
+ /* return stream */
+ return (gzFile)state;
+}
+
+/* -- see zlib.h -- */
+gzFile ZEXPORT gzopen(path, mode)
+ const char *path;
+ const char *mode;
+{
+ return gz_open(path, -1, mode);
+}
+
+/* -- see zlib.h -- */
+gzFile ZEXPORT gzopen64(path, mode)
+ const char *path;
+ const char *mode;
+{
+ return gz_open(path, -1, mode);
+}
+
+/* -- see zlib.h -- */
+gzFile ZEXPORT gzdopen(fd, mode)
+ int fd;
+ const char *mode;
+{
+ char *path; /* identifier for error messages */
+ gzFile gz;
+
+ if (fd == -1 || (path = (char *)malloc(7 + 3 * sizeof(int))) == NULL)
+ return NULL;
+#if !defined(NO_snprintf) && !defined(NO_vsnprintf)
+ snprintf(path, 7 + 3 * sizeof(int), "<fd:%d>", fd); /* for debugging */
+#else
+ sprintf(path, "<fd:%d>", fd); /* for debugging */
+#endif
+ gz = gz_open(path, fd, mode);
+ free(path);
+ return gz;
+}
+
+/* -- see zlib.h -- */
+#ifdef _WIN32
+gzFile ZEXPORT gzopen_w(path, mode)
+ const wchar_t *path;
+ const char *mode;
+{
+ return gz_open(path, -2, mode);
+}
+#endif
+
+/* -- see zlib.h -- */
+int ZEXPORT gzbuffer(file, size)
+ gzFile file;
+ unsigned size;
+{
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return -1;
+
+ /* make sure we haven't already allocated memory */
+ if (state->size != 0)
+ return -1;
+
+ /* check and set requested size */
+ if (size < 2)
+ size = 2; /* need two bytes to check magic header */
+ state->want = size;
+ return 0;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzrewind(file)
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+
+ /* check that we're reading and that there's no error */
+ if (state->mode != GZ_READ ||
+ (state->err != Z_OK && state->err != Z_BUF_ERROR))
+ return -1;
+
+ /* back up and start over */
+ if (LSEEK(state->fd, state->start, SEEK_SET) == -1)
+ return -1;
+ gz_reset(state);
+ return 0;
+}
+
+/* -- see zlib.h -- */
+z_off64_t ZEXPORT gzseek64(file, offset, whence)
+ gzFile file;
+ z_off64_t offset;
+ int whence;
+{
+ unsigned n;
+ z_off64_t ret;
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return -1;
+
+ /* check that there's no error */
+ if (state->err != Z_OK && state->err != Z_BUF_ERROR)
+ return -1;
+
+ /* can only seek from start or relative to current position */
+ if (whence != SEEK_SET && whence != SEEK_CUR)
+ return -1;
+
+ /* normalize offset to a SEEK_CUR specification */
+ if (whence == SEEK_SET)
+ offset -= state->x.pos;
+ else if (state->seek)
+ offset += state->skip;
+ state->seek = 0;
+
+ /* if within raw area while reading, just go there */
+ if (state->mode == GZ_READ && state->how == COPY &&
+ state->x.pos + offset >= 0) {
+ ret = LSEEK(state->fd, offset - state->x.have, SEEK_CUR);
+ if (ret == -1)
+ return -1;
+ state->x.have = 0;
+ state->eof = 0;
+ state->past = 0;
+ state->seek = 0;
+ gz_error(state, Z_OK, NULL);
+ state->strm.avail_in = 0;
+ state->x.pos += offset;
+ return state->x.pos;
+ }
+
+ /* calculate skip amount, rewinding if needed for back seek when reading */
+ if (offset < 0) {
+ if (state->mode != GZ_READ) /* writing -- can't go backwards */
+ return -1;
+ offset += state->x.pos;
+ if (offset < 0) /* before start of file! */
+ return -1;
+ if (gzrewind(file) == -1) /* rewind, then skip to offset */
+ return -1;
+ }
+
+ /* if reading, skip what's in output buffer (one less gzgetc() check) */
+ if (state->mode == GZ_READ) {
+ n = GT_OFF(state->x.have) || (z_off64_t)state->x.have > offset ?
+ (unsigned)offset : state->x.have;
+ state->x.have -= n;
+ state->x.next += n;
+ state->x.pos += n;
+ offset -= n;
+ }
+
+ /* request skip (if not zero) */
+ if (offset) {
+ state->seek = 1;
+ state->skip = offset;
+ }
+ return state->x.pos + offset;
+}
+
+/* -- see zlib.h -- */
+z_off_t ZEXPORT gzseek(file, offset, whence)
+ gzFile file;
+ z_off_t offset;
+ int whence;
+{
+ z_off64_t ret;
+
+ ret = gzseek64(file, (z_off64_t)offset, whence);
+ return ret == (z_off_t)ret ? (z_off_t)ret : -1;
+}
+
+/* -- see zlib.h -- */
+z_off64_t ZEXPORT gztell64(file)
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return -1;
+
+ /* return position */
+ return state->x.pos + (state->seek ? state->skip : 0);
+}
+
+/* -- see zlib.h -- */
+z_off_t ZEXPORT gztell(file)
+ gzFile file;
+{
+ z_off64_t ret;
+
+ ret = gztell64(file);
+ return ret == (z_off_t)ret ? (z_off_t)ret : -1;
+}
+
+/* -- see zlib.h -- */
+z_off64_t ZEXPORT gzoffset64(file)
+ gzFile file;
+{
+ z_off64_t offset;
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return -1;
+
+ /* compute and return effective offset in file */
+ offset = LSEEK(state->fd, 0, SEEK_CUR);
+ if (offset == -1)
+ return -1;
+ if (state->mode == GZ_READ) /* reading */
+ offset -= state->strm.avail_in; /* don't count buffered input */
+ return offset;
+}
+
+/* -- see zlib.h -- */
+z_off_t ZEXPORT gzoffset(file)
+ gzFile file;
+{
+ z_off64_t ret;
+
+ ret = gzoffset64(file);
+ return ret == (z_off_t)ret ? (z_off_t)ret : -1;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzeof(file)
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return 0;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return 0;
+
+ /* return end-of-file state */
+ return state->mode == GZ_READ ? state->past : 0;
+}
+
+/* -- see zlib.h -- */
+const char * ZEXPORT gzerror(file, errnum)
+ gzFile file;
+ int *errnum;
+{
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return NULL;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return NULL;
+
+ /* return error information */
+ if (errnum != NULL)
+ *errnum = state->err;
+ return state->err == Z_MEM_ERROR ? "out of memory" :
+ (state->msg == NULL ? "" : state->msg);
+}
+
+/* -- see zlib.h -- */
+void ZEXPORT gzclearerr(file)
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure and check integrity */
+ if (file == NULL)
+ return;
+ state = (gz_statep)file;
+ if (state->mode != GZ_READ && state->mode != GZ_WRITE)
+ return;
+
+ /* clear error and end-of-file */
+ if (state->mode == GZ_READ) {
+ state->eof = 0;
+ state->past = 0;
+ }
+ gz_error(state, Z_OK, NULL);
+}
+
+/* Create an error message in allocated memory and set state->err and
+ state->msg accordingly. Free any previous error message already there. Do
+ not try to free or allocate space if the error is Z_MEM_ERROR (out of
+ memory). Simply save the error message as a static string. If there is an
+ allocation failure constructing the error message, then convert the error to
+ out of memory. */
+void ZLIB_INTERNAL gz_error(state, err, msg)
+ gz_statep state;
+ int err;
+ const char *msg;
+{
+ /* free previously allocated message and clear */
+ if (state->msg != NULL) {
+ if (state->err != Z_MEM_ERROR)
+ free(state->msg);
+ state->msg = NULL;
+ }
+
+ /* if fatal, set state->x.have to 0 so that the gzgetc() macro fails */
+ if (err != Z_OK && err != Z_BUF_ERROR)
+ state->x.have = 0;
+
+ /* set error code, and if no message, then done */
+ state->err = err;
+ if (msg == NULL)
+ return;
+
+ /* for an out of memory error, return literal string when requested */
+ if (err == Z_MEM_ERROR)
+ return;
+
+ /* construct error message with path */
+ if ((state->msg = (char *)malloc(strlen(state->path) + strlen(msg) + 3)) ==
+ NULL) {
+ state->err = Z_MEM_ERROR;
+ return;
+ }
+#if !defined(NO_snprintf) && !defined(NO_vsnprintf)
+ snprintf(state->msg, strlen(state->path) + strlen(msg) + 3,
+ "%s%s%s", state->path, ": ", msg);
+#else
+ strcpy(state->msg, state->path);
+ strcat(state->msg, ": ");
+ strcat(state->msg, msg);
+#endif
+ return;
+}
+
+#ifndef INT_MAX
+/* portably return maximum value for an int (when limits.h presumed not
+ available) -- we need to do this to cover cases where 2's complement not
+ used, since C standard permits 1's complement and sign-bit representations,
+ otherwise we could just use ((unsigned)-1) >> 1 */
+unsigned ZLIB_INTERNAL gz_intmax()
+{
+ unsigned p, q;
+
+ p = 1;
+ do {
+ q = p;
+ p <<= 1;
+ p++;
+ } while (p > q);
+ return q >> 1;
+}
+#endif
diff --git a/ml/dlib/dlib/external/zlib/gzread.c b/ml/dlib/dlib/external/zlib/gzread.c
new file mode 100644
index 000000000..bf4538eb2
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/gzread.c
@@ -0,0 +1,594 @@
+/* gzread.c -- zlib functions for reading gzip files
+ * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "gzguts.h"
+
+/* Local functions */
+local int gz_load OF((gz_statep, unsigned char *, unsigned, unsigned *));
+local int gz_avail OF((gz_statep));
+local int gz_look OF((gz_statep));
+local int gz_decomp OF((gz_statep));
+local int gz_fetch OF((gz_statep));
+local int gz_skip OF((gz_statep, z_off64_t));
+
+/* Use read() to load a buffer -- return -1 on error, otherwise 0. Read from
+ state->fd, and update state->eof, state->err, and state->msg as appropriate.
+ This function needs to loop on read(), since read() is not guaranteed to
+ read the number of bytes requested, depending on the type of descriptor. */
+local int gz_load(state, buf, len, have)
+ gz_statep state;
+ unsigned char *buf;
+ unsigned len;
+ unsigned *have;
+{
+ int ret;
+
+ *have = 0;
+ do {
+ ret = read(state->fd, buf + *have, len - *have);
+ if (ret <= 0)
+ break;
+ *have += ret;
+ } while (*have < len);
+ if (ret < 0) {
+ gz_error(state, Z_ERRNO, zstrerror());
+ return -1;
+ }
+ if (ret == 0)
+ state->eof = 1;
+ return 0;
+}
+
+/* Load up input buffer and set eof flag if last data loaded -- return -1 on
+ error, 0 otherwise. Note that the eof flag is set when the end of the input
+ file is reached, even though there may be unused data in the buffer. Once
+ that data has been used, no more attempts will be made to read the file.
+ If strm->avail_in != 0, then the current data is moved to the beginning of
+ the input buffer, and then the remainder of the buffer is loaded with the
+ available data from the input file. */
+local int gz_avail(state)
+ gz_statep state;
+{
+ unsigned got;
+ z_streamp strm = &(state->strm);
+
+ if (state->err != Z_OK && state->err != Z_BUF_ERROR)
+ return -1;
+ if (state->eof == 0) {
+ if (strm->avail_in) { /* copy what's there to the start */
+ unsigned char *p = state->in;
+ unsigned const char *q = strm->next_in;
+ unsigned n = strm->avail_in;
+ do {
+ *p++ = *q++;
+ } while (--n);
+ }
+ if (gz_load(state, state->in + strm->avail_in,
+ state->size - strm->avail_in, &got) == -1)
+ return -1;
+ strm->avail_in += got;
+ strm->next_in = state->in;
+ }
+ return 0;
+}
+
+/* Look for gzip header, set up for inflate or copy. state->x.have must be 0.
+ If this is the first time in, allocate required memory. state->how will be
+ left unchanged if there is no more input data available, will be set to COPY
+ if there is no gzip header and direct copying will be performed, or it will
+ be set to GZIP for decompression. If direct copying, then leftover input
+ data from the input buffer will be copied to the output buffer. In that
+ case, all further file reads will be directly to either the output buffer or
+ a user buffer. If decompressing, the inflate state will be initialized.
+ gz_look() will return 0 on success or -1 on failure. */
+local int gz_look(state)
+ gz_statep state;
+{
+ z_streamp strm = &(state->strm);
+
+ /* allocate read buffers and inflate memory */
+ if (state->size == 0) {
+ /* allocate buffers */
+ state->in = (unsigned char *)malloc(state->want);
+ state->out = (unsigned char *)malloc(state->want << 1);
+ if (state->in == NULL || state->out == NULL) {
+ if (state->out != NULL)
+ free(state->out);
+ if (state->in != NULL)
+ free(state->in);
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+ state->size = state->want;
+
+ /* allocate inflate memory */
+ state->strm.zalloc = Z_NULL;
+ state->strm.zfree = Z_NULL;
+ state->strm.opaque = Z_NULL;
+ state->strm.avail_in = 0;
+ state->strm.next_in = Z_NULL;
+ if (inflateInit2(&(state->strm), 15 + 16) != Z_OK) { /* gunzip */
+ free(state->out);
+ free(state->in);
+ state->size = 0;
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+ }
+
+ /* get at least the magic bytes in the input buffer */
+ if (strm->avail_in < 2) {
+ if (gz_avail(state) == -1)
+ return -1;
+ if (strm->avail_in == 0)
+ return 0;
+ }
+
+ /* look for gzip magic bytes -- if there, do gzip decoding (note: there is
+ a logical dilemma here when considering the case of a partially written
+ gzip file, to wit, if a single 31 byte is written, then we cannot tell
+ whether this is a single-byte file, or just a partially written gzip
+ file -- for here we assume that if a gzip file is being written, then
+ the header will be written in a single operation, so that reading a
+ single byte is sufficient indication that it is not a gzip file) */
+ if (strm->avail_in > 1 &&
+ strm->next_in[0] == 31 && strm->next_in[1] == 139) {
+ inflateReset(strm);
+ state->how = GZIP;
+ state->direct = 0;
+ return 0;
+ }
+
+ /* no gzip header -- if we were decoding gzip before, then this is trailing
+ garbage. Ignore the trailing garbage and finish. */
+ if (state->direct == 0) {
+ strm->avail_in = 0;
+ state->eof = 1;
+ state->x.have = 0;
+ return 0;
+ }
+
+ /* doing raw i/o, copy any leftover input to output -- this assumes that
+ the output buffer is larger than the input buffer, which also assures
+ space for gzungetc() */
+ state->x.next = state->out;
+ if (strm->avail_in) {
+ memcpy(state->x.next, strm->next_in, strm->avail_in);
+ state->x.have = strm->avail_in;
+ strm->avail_in = 0;
+ }
+ state->how = COPY;
+ state->direct = 1;
+ return 0;
+}
+
+/* Decompress from input to the provided next_out and avail_out in the state.
+ On return, state->x.have and state->x.next point to the just decompressed
+ data. If the gzip stream completes, state->how is reset to LOOK to look for
+ the next gzip stream or raw data, once state->x.have is depleted. Returns 0
+ on success, -1 on failure. */
+local int gz_decomp(state)
+ gz_statep state;
+{
+ int ret = Z_OK;
+ unsigned had;
+ z_streamp strm = &(state->strm);
+
+ /* fill output buffer up to end of deflate stream */
+ had = strm->avail_out;
+ do {
+ /* get more input for inflate() */
+ if (strm->avail_in == 0 && gz_avail(state) == -1)
+ return -1;
+ if (strm->avail_in == 0) {
+ gz_error(state, Z_BUF_ERROR, "unexpected end of file");
+ break;
+ }
+
+ /* decompress and handle errors */
+ ret = inflate(strm, Z_NO_FLUSH);
+ if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT) {
+ gz_error(state, Z_STREAM_ERROR,
+ "internal error: inflate stream corrupt");
+ return -1;
+ }
+ if (ret == Z_MEM_ERROR) {
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+ if (ret == Z_DATA_ERROR) { /* deflate stream invalid */
+ gz_error(state, Z_DATA_ERROR,
+ strm->msg == NULL ? "compressed data error" : strm->msg);
+ return -1;
+ }
+ } while (strm->avail_out && ret != Z_STREAM_END);
+
+ /* update available output */
+ state->x.have = had - strm->avail_out;
+ state->x.next = strm->next_out - state->x.have;
+
+ /* if the gzip stream completed successfully, look for another */
+ if (ret == Z_STREAM_END)
+ state->how = LOOK;
+
+ /* good decompression */
+ return 0;
+}
+
+/* Fetch data and put it in the output buffer. Assumes state->x.have is 0.
+ Data is either copied from the input file or decompressed from the input
+ file depending on state->how. If state->how is LOOK, then a gzip header is
+ looked for to determine whether to copy or decompress. Returns -1 on error,
+ otherwise 0. gz_fetch() will leave state->how as COPY or GZIP unless the
+ end of the input file has been reached and all data has been processed. */
+local int gz_fetch(state)
+ gz_statep state;
+{
+ z_streamp strm = &(state->strm);
+
+ do {
+ switch(state->how) {
+ case LOOK: /* -> LOOK, COPY (only if never GZIP), or GZIP */
+ if (gz_look(state) == -1)
+ return -1;
+ if (state->how == LOOK)
+ return 0;
+ break;
+ case COPY: /* -> COPY */
+ if (gz_load(state, state->out, state->size << 1, &(state->x.have))
+ == -1)
+ return -1;
+ state->x.next = state->out;
+ return 0;
+ case GZIP: /* -> GZIP or LOOK (if end of gzip stream) */
+ strm->avail_out = state->size << 1;
+ strm->next_out = state->out;
+ if (gz_decomp(state) == -1)
+ return -1;
+ }
+ } while (state->x.have == 0 && (!state->eof || strm->avail_in));
+ return 0;
+}
+
+/* Skip len uncompressed bytes of output. Return -1 on error, 0 on success. */
+local int gz_skip(state, len)
+ gz_statep state;
+ z_off64_t len;
+{
+ unsigned n;
+
+ /* skip over len bytes or reach end-of-file, whichever comes first */
+ while (len)
+ /* skip over whatever is in output buffer */
+ if (state->x.have) {
+ n = GT_OFF(state->x.have) || (z_off64_t)state->x.have > len ?
+ (unsigned)len : state->x.have;
+ state->x.have -= n;
+ state->x.next += n;
+ state->x.pos += n;
+ len -= n;
+ }
+
+ /* output buffer empty -- return if we're at the end of the input */
+ else if (state->eof && state->strm.avail_in == 0)
+ break;
+
+ /* need more data to skip -- load up output buffer */
+ else {
+ /* get more output, looking for header if required */
+ if (gz_fetch(state) == -1)
+ return -1;
+ }
+ return 0;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzread(file, buf, len)
+ gzFile file;
+ voidp buf;
+ unsigned len;
+{
+ unsigned got, n;
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that we're reading and that there's no (serious) error */
+ if (state->mode != GZ_READ ||
+ (state->err != Z_OK && state->err != Z_BUF_ERROR))
+ return -1;
+
+ /* since an int is returned, make sure len fits in one, otherwise return
+ with an error (this avoids the flaw in the interface) */
+ if ((int)len < 0) {
+ gz_error(state, Z_DATA_ERROR, "requested length does not fit in int");
+ return -1;
+ }
+
+ /* if len is zero, avoid unnecessary operations */
+ if (len == 0)
+ return 0;
+
+ /* process a skip request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_skip(state, state->skip) == -1)
+ return -1;
+ }
+
+ /* get len bytes to buf, or less than len if at the end */
+ got = 0;
+ do {
+ /* first just try copying data from the output buffer */
+ if (state->x.have) {
+ n = state->x.have > len ? len : state->x.have;
+ memcpy(buf, state->x.next, n);
+ state->x.next += n;
+ state->x.have -= n;
+ }
+
+ /* output buffer empty -- return if we're at the end of the input */
+ else if (state->eof && strm->avail_in == 0) {
+ state->past = 1; /* tried to read past end */
+ break;
+ }
+
+ /* need output data -- for small len or new stream load up our output
+ buffer */
+ else if (state->how == LOOK || len < (state->size << 1)) {
+ /* get more output, looking for header if required */
+ if (gz_fetch(state) == -1)
+ return -1;
+ continue; /* no progress yet -- go back to copy above */
+ /* the copy above assures that we will leave with space in the
+ output buffer, allowing at least one gzungetc() to succeed */
+ }
+
+ /* large len -- read directly into user buffer */
+ else if (state->how == COPY) { /* read directly */
+ if (gz_load(state, (unsigned char *)buf, len, &n) == -1)
+ return -1;
+ }
+
+ /* large len -- decompress directly into user buffer */
+ else { /* state->how == GZIP */
+ strm->avail_out = len;
+ strm->next_out = (unsigned char *)buf;
+ if (gz_decomp(state) == -1)
+ return -1;
+ n = state->x.have;
+ state->x.have = 0;
+ }
+
+ /* update progress */
+ len -= n;
+ buf = (char *)buf + n;
+ got += n;
+ state->x.pos += n;
+ } while (len);
+
+ /* return number of bytes read into user buffer (will fit in int) */
+ return (int)got;
+}
+
+/* -- see zlib.h -- */
+#ifdef Z_PREFIX_SET
+# undef z_gzgetc
+#else
+# undef gzgetc
+#endif
+int ZEXPORT gzgetc(file)
+ gzFile file;
+{
+ int ret;
+ unsigned char buf[1];
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+
+ /* check that we're reading and that there's no (serious) error */
+ if (state->mode != GZ_READ ||
+ (state->err != Z_OK && state->err != Z_BUF_ERROR))
+ return -1;
+
+ /* try output buffer (no need to check for skip request) */
+ if (state->x.have) {
+ state->x.have--;
+ state->x.pos++;
+ return *(state->x.next)++;
+ }
+
+ /* nothing there -- try gzread() */
+ ret = gzread(file, buf, 1);
+ return ret < 1 ? -1 : buf[0];
+}
+
+int ZEXPORT gzgetc_(file)
+gzFile file;
+{
+ return gzgetc(file);
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzungetc(c, file)
+ int c;
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+
+ /* check that we're reading and that there's no (serious) error */
+ if (state->mode != GZ_READ ||
+ (state->err != Z_OK && state->err != Z_BUF_ERROR))
+ return -1;
+
+ /* process a skip request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_skip(state, state->skip) == -1)
+ return -1;
+ }
+
+ /* can't push EOF */
+ if (c < 0)
+ return -1;
+
+ /* if output buffer empty, put byte at end (allows more pushing) */
+ if (state->x.have == 0) {
+ state->x.have = 1;
+ state->x.next = state->out + (state->size << 1) - 1;
+ state->x.next[0] = c;
+ state->x.pos--;
+ state->past = 0;
+ return c;
+ }
+
+ /* if no room, give up (must have already done a gzungetc()) */
+ if (state->x.have == (state->size << 1)) {
+ gz_error(state, Z_DATA_ERROR, "out of room to push characters");
+ return -1;
+ }
+
+ /* slide output data if needed and insert byte before existing data */
+ if (state->x.next == state->out) {
+ unsigned char *src = state->out + state->x.have;
+ unsigned char *dest = state->out + (state->size << 1);
+ while (src > state->out)
+ *--dest = *--src;
+ state->x.next = dest;
+ }
+ state->x.have++;
+ state->x.next--;
+ state->x.next[0] = c;
+ state->x.pos--;
+ state->past = 0;
+ return c;
+}
+
+/* -- see zlib.h -- */
+char * ZEXPORT gzgets(file, buf, len)
+ gzFile file;
+ char *buf;
+ int len;
+{
+ unsigned left, n;
+ char *str;
+ unsigned char *eol;
+ gz_statep state;
+
+ /* check parameters and get internal structure */
+ if (file == NULL || buf == NULL || len < 1)
+ return NULL;
+ state = (gz_statep)file;
+
+ /* check that we're reading and that there's no (serious) error */
+ if (state->mode != GZ_READ ||
+ (state->err != Z_OK && state->err != Z_BUF_ERROR))
+ return NULL;
+
+ /* process a skip request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_skip(state, state->skip) == -1)
+ return NULL;
+ }
+
+ /* copy output bytes up to new line or len - 1, whichever comes first --
+ append a terminating zero to the string (we don't check for a zero in
+ the contents, let the user worry about that) */
+ str = buf;
+ left = (unsigned)len - 1;
+ if (left) do {
+ /* assure that something is in the output buffer */
+ if (state->x.have == 0 && gz_fetch(state) == -1)
+ return NULL; /* error */
+ if (state->x.have == 0) { /* end of file */
+ state->past = 1; /* read past end */
+ break; /* return what we have */
+ }
+
+ /* look for end-of-line in current output buffer */
+ n = state->x.have > left ? left : state->x.have;
+ eol = (unsigned char *)memchr(state->x.next, '\n', n);
+ if (eol != NULL)
+ n = (unsigned)(eol - state->x.next) + 1;
+
+ /* copy through end-of-line, or remainder if not found */
+ memcpy(buf, state->x.next, n);
+ state->x.have -= n;
+ state->x.next += n;
+ state->x.pos += n;
+ left -= n;
+ buf += n;
+ } while (left && eol == NULL);
+
+ /* return terminated string, or if nothing, end of file */
+ if (buf == str)
+ return NULL;
+ buf[0] = 0;
+ return str;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzdirect(file)
+ gzFile file;
+{
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return 0;
+ state = (gz_statep)file;
+
+ /* if the state is not known, but we can find out, then do so (this is
+ mainly for right after a gzopen() or gzdopen()) */
+ if (state->mode == GZ_READ && state->how == LOOK && state->x.have == 0)
+ (void)gz_look(state);
+
+ /* return 1 if transparent, 0 if processing a gzip stream */
+ return state->direct;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzclose_r(file)
+ gzFile file;
+{
+ int ret, err;
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return Z_STREAM_ERROR;
+ state = (gz_statep)file;
+
+ /* check that we're reading */
+ if (state->mode != GZ_READ)
+ return Z_STREAM_ERROR;
+
+ /* free memory and close file */
+ if (state->size) {
+ inflateEnd(&(state->strm));
+ free(state->out);
+ free(state->in);
+ }
+ err = state->err == Z_BUF_ERROR ? Z_BUF_ERROR : Z_OK;
+ gz_error(state, Z_OK, NULL);
+ free(state->path);
+ ret = close(state->fd);
+ free(state);
+ return ret ? Z_ERRNO : err;
+}
diff --git a/ml/dlib/dlib/external/zlib/gzwrite.c b/ml/dlib/dlib/external/zlib/gzwrite.c
new file mode 100644
index 000000000..aa767fbf6
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/gzwrite.c
@@ -0,0 +1,577 @@
+/* gzwrite.c -- zlib functions for writing gzip files
+ * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "gzguts.h"
+
+/* Local functions */
+local int gz_init OF((gz_statep));
+local int gz_comp OF((gz_statep, int));
+local int gz_zero OF((gz_statep, z_off64_t));
+
+/* Initialize state for writing a gzip file. Mark initialization by setting
+ state->size to non-zero. Return -1 on failure or 0 on success. */
+local int gz_init(state)
+ gz_statep state;
+{
+ int ret;
+ z_streamp strm = &(state->strm);
+
+ /* allocate input buffer */
+ state->in = (unsigned char *)malloc(state->want);
+ if (state->in == NULL) {
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+
+ /* only need output buffer and deflate state if compressing */
+ if (!state->direct) {
+ /* allocate output buffer */
+ state->out = (unsigned char *)malloc(state->want);
+ if (state->out == NULL) {
+ free(state->in);
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+
+ /* allocate deflate memory, set up for gzip compression */
+ strm->zalloc = Z_NULL;
+ strm->zfree = Z_NULL;
+ strm->opaque = Z_NULL;
+ ret = deflateInit2(strm, state->level, Z_DEFLATED,
+ MAX_WBITS + 16, DEF_MEM_LEVEL, state->strategy);
+ if (ret != Z_OK) {
+ free(state->out);
+ free(state->in);
+ gz_error(state, Z_MEM_ERROR, "out of memory");
+ return -1;
+ }
+ }
+
+ /* mark state as initialized */
+ state->size = state->want;
+
+ /* initialize write buffer if compressing */
+ if (!state->direct) {
+ strm->avail_out = state->size;
+ strm->next_out = state->out;
+ state->x.next = strm->next_out;
+ }
+ return 0;
+}
+
+/* Compress whatever is at avail_in and next_in and write to the output file.
+ Return -1 if there is an error writing to the output file, otherwise 0.
+ flush is assumed to be a valid deflate() flush value. If flush is Z_FINISH,
+ then the deflate() state is reset to start a new gzip stream. If gz->direct
+ is true, then simply write to the output file without compressing, and
+ ignore flush. */
+local int gz_comp(state, flush)
+ gz_statep state;
+ int flush;
+{
+ int ret, got;
+ unsigned have;
+ z_streamp strm = &(state->strm);
+
+ /* allocate memory if this is the first time through */
+ if (state->size == 0 && gz_init(state) == -1)
+ return -1;
+
+ /* write directly if requested */
+ if (state->direct) {
+ got = write(state->fd, strm->next_in, strm->avail_in);
+ if (got < 0 || (unsigned)got != strm->avail_in) {
+ gz_error(state, Z_ERRNO, zstrerror());
+ return -1;
+ }
+ strm->avail_in = 0;
+ return 0;
+ }
+
+ /* run deflate() on provided input until it produces no more output */
+ ret = Z_OK;
+ do {
+ /* write out current buffer contents if full, or if flushing, but if
+ doing Z_FINISH then don't write until we get to Z_STREAM_END */
+ if (strm->avail_out == 0 || (flush != Z_NO_FLUSH &&
+ (flush != Z_FINISH || ret == Z_STREAM_END))) {
+ have = (unsigned)(strm->next_out - state->x.next);
+ if (have && ((got = write(state->fd, state->x.next, have)) < 0 ||
+ (unsigned)got != have)) {
+ gz_error(state, Z_ERRNO, zstrerror());
+ return -1;
+ }
+ if (strm->avail_out == 0) {
+ strm->avail_out = state->size;
+ strm->next_out = state->out;
+ }
+ state->x.next = strm->next_out;
+ }
+
+ /* compress */
+ have = strm->avail_out;
+ ret = deflate(strm, flush);
+ if (ret == Z_STREAM_ERROR) {
+ gz_error(state, Z_STREAM_ERROR,
+ "internal error: deflate stream corrupt");
+ return -1;
+ }
+ have -= strm->avail_out;
+ } while (have);
+
+ /* if that completed a deflate stream, allow another to start */
+ if (flush == Z_FINISH)
+ deflateReset(strm);
+
+ /* all done, no errors */
+ return 0;
+}
+
+/* Compress len zeros to output. Return -1 on error, 0 on success. */
+local int gz_zero(state, len)
+ gz_statep state;
+ z_off64_t len;
+{
+ int first;
+ unsigned n;
+ z_streamp strm = &(state->strm);
+
+ /* consume whatever's left in the input buffer */
+ if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1)
+ return -1;
+
+ /* compress len zeros (len guaranteed > 0) */
+ first = 1;
+ while (len) {
+ n = GT_OFF(state->size) || (z_off64_t)state->size > len ?
+ (unsigned)len : state->size;
+ if (first) {
+ memset(state->in, 0, n);
+ first = 0;
+ }
+ strm->avail_in = n;
+ strm->next_in = state->in;
+ state->x.pos += n;
+ if (gz_comp(state, Z_NO_FLUSH) == -1)
+ return -1;
+ len -= n;
+ }
+ return 0;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzwrite(file, buf, len)
+ gzFile file;
+ voidpc buf;
+ unsigned len;
+{
+ unsigned put = len;
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return 0;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return 0;
+
+ /* since an int is returned, make sure len fits in one, otherwise return
+ with an error (this avoids the flaw in the interface) */
+ if ((int)len < 0) {
+ gz_error(state, Z_DATA_ERROR, "requested length does not fit in int");
+ return 0;
+ }
+
+ /* if len is zero, avoid unnecessary operations */
+ if (len == 0)
+ return 0;
+
+ /* allocate memory if this is the first time through */
+ if (state->size == 0 && gz_init(state) == -1)
+ return 0;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return 0;
+ }
+
+ /* for small len, copy to input buffer, otherwise compress directly */
+ if (len < state->size) {
+ /* copy to input buffer, compress when full */
+ do {
+ unsigned have, copy;
+
+ if (strm->avail_in == 0)
+ strm->next_in = state->in;
+ have = (unsigned)((strm->next_in + strm->avail_in) - state->in);
+ copy = state->size - have;
+ if (copy > len)
+ copy = len;
+ memcpy(state->in + have, buf, copy);
+ strm->avail_in += copy;
+ state->x.pos += copy;
+ buf = (const char *)buf + copy;
+ len -= copy;
+ if (len && gz_comp(state, Z_NO_FLUSH) == -1)
+ return 0;
+ } while (len);
+ }
+ else {
+ /* consume whatever's left in the input buffer */
+ if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1)
+ return 0;
+
+ /* directly compress user buffer to file */
+ strm->avail_in = len;
+ strm->next_in = (z_const Bytef *)buf;
+ state->x.pos += len;
+ if (gz_comp(state, Z_NO_FLUSH) == -1)
+ return 0;
+ }
+
+ /* input was all buffered or compressed (put will fit in int) */
+ return (int)put;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzputc(file, c)
+ gzFile file;
+ int c;
+{
+ unsigned have;
+ unsigned char buf[1];
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return -1;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return -1;
+ }
+
+ /* try writing to input buffer for speed (state->size == 0 if buffer not
+ initialized) */
+ if (state->size) {
+ if (strm->avail_in == 0)
+ strm->next_in = state->in;
+ have = (unsigned)((strm->next_in + strm->avail_in) - state->in);
+ if (have < state->size) {
+ state->in[have] = c;
+ strm->avail_in++;
+ state->x.pos++;
+ return c & 0xff;
+ }
+ }
+
+ /* no room in buffer or not initialized, use gz_write() */
+ buf[0] = c;
+ if (gzwrite(file, buf, 1) != 1)
+ return -1;
+ return c & 0xff;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzputs(file, str)
+ gzFile file;
+ const char *str;
+{
+ int ret;
+ unsigned len;
+
+ /* write string */
+ len = (unsigned)strlen(str);
+ ret = gzwrite(file, str, len);
+ return ret == 0 && len != 0 ? -1 : ret;
+}
+
+#if defined(STDC) || defined(Z_HAVE_STDARG_H)
+#include <stdarg.h>
+
+/* -- see zlib.h -- */
+int ZEXPORTVA gzvprintf(gzFile file, const char *format, va_list va)
+{
+ int size, len;
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return 0;
+
+ /* make sure we have some buffer space */
+ if (state->size == 0 && gz_init(state) == -1)
+ return 0;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return 0;
+ }
+
+ /* consume whatever's left in the input buffer */
+ if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1)
+ return 0;
+
+ /* do the printf() into the input buffer, put length in len */
+ size = (int)(state->size);
+ state->in[size - 1] = 0;
+#ifdef NO_vsnprintf
+# ifdef HAS_vsprintf_void
+ (void)vsprintf((char *)(state->in), format, va);
+ for (len = 0; len < size; len++)
+ if (state->in[len] == 0) break;
+# else
+ len = vsprintf((char *)(state->in), format, va);
+# endif
+#else
+# ifdef HAS_vsnprintf_void
+ (void)vsnprintf((char *)(state->in), size, format, va);
+ len = strlen((char *)(state->in));
+# else
+ len = vsnprintf((char *)(state->in), size, format, va);
+# endif
+#endif
+
+ /* check that printf() results fit in buffer */
+ if (len <= 0 || len >= (int)size || state->in[size - 1] != 0)
+ return 0;
+
+ /* update buffer and position, defer compression until needed */
+ strm->avail_in = (unsigned)len;
+ strm->next_in = state->in;
+ state->x.pos += len;
+ return len;
+}
+
+int ZEXPORTVA gzprintf(gzFile file, const char *format, ...)
+{
+ va_list va;
+ int ret;
+
+ va_start(va, format);
+ ret = gzvprintf(file, format, va);
+ va_end(va);
+ return ret;
+}
+
+#else /* !STDC && !Z_HAVE_STDARG_H */
+
+/* -- see zlib.h -- */
+int ZEXPORTVA gzprintf (file, format, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10,
+ a11, a12, a13, a14, a15, a16, a17, a18, a19, a20)
+ gzFile file;
+ const char *format;
+ int a1, a2, a3, a4, a5, a6, a7, a8, a9, a10,
+ a11, a12, a13, a14, a15, a16, a17, a18, a19, a20;
+{
+ int size, len;
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that can really pass pointer in ints */
+ if (sizeof(int) != sizeof(void *))
+ return 0;
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return 0;
+
+ /* make sure we have some buffer space */
+ if (state->size == 0 && gz_init(state) == -1)
+ return 0;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return 0;
+ }
+
+ /* consume whatever's left in the input buffer */
+ if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1)
+ return 0;
+
+ /* do the printf() into the input buffer, put length in len */
+ size = (int)(state->size);
+ state->in[size - 1] = 0;
+#ifdef NO_snprintf
+# ifdef HAS_sprintf_void
+ sprintf((char *)(state->in), format, a1, a2, a3, a4, a5, a6, a7, a8,
+ a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20);
+ for (len = 0; len < size; len++)
+ if (state->in[len] == 0) break;
+# else
+ len = sprintf((char *)(state->in), format, a1, a2, a3, a4, a5, a6, a7, a8,
+ a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20);
+# endif
+#else
+# ifdef HAS_snprintf_void
+ snprintf((char *)(state->in), size, format, a1, a2, a3, a4, a5, a6, a7, a8,
+ a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20);
+ len = strlen((char *)(state->in));
+# else
+ len = snprintf((char *)(state->in), size, format, a1, a2, a3, a4, a5, a6,
+ a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
+ a19, a20);
+# endif
+#endif
+
+ /* check that printf() results fit in buffer */
+ if (len <= 0 || len >= (int)size || state->in[size - 1] != 0)
+ return 0;
+
+ /* update buffer and position, defer compression until needed */
+ strm->avail_in = (unsigned)len;
+ strm->next_in = state->in;
+ state->x.pos += len;
+ return len;
+}
+
+#endif
+
+/* -- see zlib.h -- */
+int ZEXPORT gzflush(file, flush)
+ gzFile file;
+ int flush;
+{
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return -1;
+ state = (gz_statep)file;
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return Z_STREAM_ERROR;
+
+ /* check flush parameter */
+ if (flush < 0 || flush > Z_FINISH)
+ return Z_STREAM_ERROR;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return -1;
+ }
+
+ /* compress remaining data with requested flush */
+ gz_comp(state, flush);
+ return state->err;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzsetparams(file, level, strategy)
+ gzFile file;
+ int level;
+ int strategy;
+{
+ gz_statep state;
+ z_streamp strm;
+
+ /* get internal structure */
+ if (file == NULL)
+ return Z_STREAM_ERROR;
+ state = (gz_statep)file;
+ strm = &(state->strm);
+
+ /* check that we're writing and that there's no error */
+ if (state->mode != GZ_WRITE || state->err != Z_OK)
+ return Z_STREAM_ERROR;
+
+ /* if no change is requested, then do nothing */
+ if (level == state->level && strategy == state->strategy)
+ return Z_OK;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ return -1;
+ }
+
+ /* change compression parameters for subsequent input */
+ if (state->size) {
+ /* flush previous input with previous parameters before changing */
+ if (strm->avail_in && gz_comp(state, Z_PARTIAL_FLUSH) == -1)
+ return state->err;
+ deflateParams(strm, level, strategy);
+ }
+ state->level = level;
+ state->strategy = strategy;
+ return Z_OK;
+}
+
+/* -- see zlib.h -- */
+int ZEXPORT gzclose_w(file)
+ gzFile file;
+{
+ int ret = Z_OK;
+ gz_statep state;
+
+ /* get internal structure */
+ if (file == NULL)
+ return Z_STREAM_ERROR;
+ state = (gz_statep)file;
+
+ /* check that we're writing */
+ if (state->mode != GZ_WRITE)
+ return Z_STREAM_ERROR;
+
+ /* check for seek request */
+ if (state->seek) {
+ state->seek = 0;
+ if (gz_zero(state, state->skip) == -1)
+ ret = state->err;
+ }
+
+ /* flush, free memory, and close file */
+ if (gz_comp(state, Z_FINISH) == -1)
+ ret = state->err;
+ if (state->size) {
+ if (!state->direct) {
+ (void)deflateEnd(&(state->strm));
+ free(state->out);
+ }
+ free(state->in);
+ }
+ gz_error(state, Z_OK, NULL);
+ free(state->path);
+ if (close(state->fd) == -1)
+ ret = Z_ERRNO;
+ free(state);
+ return ret;
+}
diff --git a/ml/dlib/dlib/external/zlib/infback.c b/ml/dlib/dlib/external/zlib/infback.c
new file mode 100644
index 000000000..f3833c2e4
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/infback.c
@@ -0,0 +1,640 @@
+/* infback.c -- inflate using a call-back interface
+ * Copyright (C) 1995-2011 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/*
+ This code is largely copied from inflate.c. Normally either infback.o or
+ inflate.o would be linked into an application--not both. The interface
+ with inffast.c is retained so that optimized assembler-coded versions of
+ inflate_fast() can be used with either inflate.c or infback.c.
+ */
+
+#include "zutil.h"
+#include "inftrees.h"
+#include "inflate.h"
+#include "inffast.h"
+
+/* function prototypes */
+local void fixedtables OF((struct inflate_state FAR *state));
+
+/*
+ strm provides memory allocation functions in zalloc and zfree, or
+ Z_NULL to use the library memory allocation functions.
+
+ windowBits is in the range 8..15, and window is a user-supplied
+ window and output buffer that is 2**windowBits bytes.
+ */
+int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
+z_streamp strm;
+int windowBits;
+unsigned char FAR *window;
+const char *version;
+int stream_size;
+{
+ struct inflate_state FAR *state;
+
+ if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
+ stream_size != (int)(sizeof(z_stream)))
+ return Z_VERSION_ERROR;
+ if (strm == Z_NULL || window == Z_NULL ||
+ windowBits < 8 || windowBits > 15)
+ return Z_STREAM_ERROR;
+ strm->msg = Z_NULL; /* in case we return an error */
+ if (strm->zalloc == (alloc_func)0) {
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zalloc = zcalloc;
+ strm->opaque = (voidpf)0;
+#endif
+ }
+ if (strm->zfree == (free_func)0)
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zfree = zcfree;
+#endif
+ state = (struct inflate_state FAR *)ZALLOC(strm, 1,
+ sizeof(struct inflate_state));
+ if (state == Z_NULL) return Z_MEM_ERROR;
+ Tracev((stderr, "inflate: allocated\n"));
+ strm->state = (struct internal_state FAR *)state;
+ state->dmax = 32768U;
+ state->wbits = windowBits;
+ state->wsize = 1U << windowBits;
+ state->window = window;
+ state->wnext = 0;
+ state->whave = 0;
+ return Z_OK;
+}
+
+/*
+ Return state with length and distance decoding tables and index sizes set to
+ fixed code decoding. Normally this returns fixed tables from inffixed.h.
+ If BUILDFIXED is defined, then instead this routine builds the tables the
+ first time it's called, and returns those tables the first time and
+ thereafter. This reduces the size of the code by about 2K bytes, in
+ exchange for a little execution time. However, BUILDFIXED should not be
+ used for threaded applications, since the rewriting of the tables and virgin
+ may not be thread-safe.
+ */
+local void fixedtables(state)
+struct inflate_state FAR *state;
+{
+#ifdef BUILDFIXED
+ static int virgin = 1;
+ static code *lenfix, *distfix;
+ static code fixed[544];
+
+ /* build fixed huffman tables if first call (may not be thread safe) */
+ if (virgin) {
+ unsigned sym, bits;
+ static code *next;
+
+ /* literal/length table */
+ sym = 0;
+ while (sym < 144) state->lens[sym++] = 8;
+ while (sym < 256) state->lens[sym++] = 9;
+ while (sym < 280) state->lens[sym++] = 7;
+ while (sym < 288) state->lens[sym++] = 8;
+ next = fixed;
+ lenfix = next;
+ bits = 9;
+ inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
+
+ /* distance table */
+ sym = 0;
+ while (sym < 32) state->lens[sym++] = 5;
+ distfix = next;
+ bits = 5;
+ inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
+
+ /* do this just once */
+ virgin = 0;
+ }
+#else /* !BUILDFIXED */
+# include "inffixed.h"
+#endif /* BUILDFIXED */
+ state->lencode = lenfix;
+ state->lenbits = 9;
+ state->distcode = distfix;
+ state->distbits = 5;
+}
+
+/* Macros for inflateBack(): */
+
+/* Load returned state from inflate_fast() */
+#define LOAD() \
+ do { \
+ put = strm->next_out; \
+ left = strm->avail_out; \
+ next = strm->next_in; \
+ have = strm->avail_in; \
+ hold = state->hold; \
+ bits = state->bits; \
+ } while (0)
+
+/* Set state from registers for inflate_fast() */
+#define RESTORE() \
+ do { \
+ strm->next_out = put; \
+ strm->avail_out = left; \
+ strm->next_in = next; \
+ strm->avail_in = have; \
+ state->hold = hold; \
+ state->bits = bits; \
+ } while (0)
+
+/* Clear the input bit accumulator */
+#define INITBITS() \
+ do { \
+ hold = 0; \
+ bits = 0; \
+ } while (0)
+
+/* Assure that some input is available. If input is requested, but denied,
+ then return a Z_BUF_ERROR from inflateBack(). */
+#define PULL() \
+ do { \
+ if (have == 0) { \
+ have = in(in_desc, &next); \
+ if (have == 0) { \
+ next = Z_NULL; \
+ ret = Z_BUF_ERROR; \
+ goto inf_leave; \
+ } \
+ } \
+ } while (0)
+
+/* Get a byte of input into the bit accumulator, or return from inflateBack()
+ with an error if there is no input available. */
+#define PULLBYTE() \
+ do { \
+ PULL(); \
+ have--; \
+ hold += (unsigned long)(*next++) << bits; \
+ bits += 8; \
+ } while (0)
+
+/* Assure that there are at least n bits in the bit accumulator. If there is
+ not enough available input to do that, then return from inflateBack() with
+ an error. */
+#define NEEDBITS(n) \
+ do { \
+ while (bits < (unsigned)(n)) \
+ PULLBYTE(); \
+ } while (0)
+
+/* Return the low n bits of the bit accumulator (n < 16) */
+#define BITS(n) \
+ ((unsigned)hold & ((1U << (n)) - 1))
+
+/* Remove n bits from the bit accumulator */
+#define DROPBITS(n) \
+ do { \
+ hold >>= (n); \
+ bits -= (unsigned)(n); \
+ } while (0)
+
+/* Remove zero to seven bits as needed to go to a byte boundary */
+#define BYTEBITS() \
+ do { \
+ hold >>= bits & 7; \
+ bits -= bits & 7; \
+ } while (0)
+
+/* Assure that some output space is available, by writing out the window
+ if it's full. If the write fails, return from inflateBack() with a
+ Z_BUF_ERROR. */
+#define ROOM() \
+ do { \
+ if (left == 0) { \
+ put = state->window; \
+ left = state->wsize; \
+ state->whave = left; \
+ if (out(out_desc, put, left)) { \
+ ret = Z_BUF_ERROR; \
+ goto inf_leave; \
+ } \
+ } \
+ } while (0)
+
+/*
+ strm provides the memory allocation functions and window buffer on input,
+ and provides information on the unused input on return. For Z_DATA_ERROR
+ returns, strm will also provide an error message.
+
+ in() and out() are the call-back input and output functions. When
+ inflateBack() needs more input, it calls in(). When inflateBack() has
+ filled the window with output, or when it completes with data in the
+ window, it calls out() to write out the data. The application must not
+ change the provided input until in() is called again or inflateBack()
+ returns. The application must not change the window/output buffer until
+ inflateBack() returns.
+
+ in() and out() are called with a descriptor parameter provided in the
+ inflateBack() call. This parameter can be a structure that provides the
+ information required to do the read or write, as well as accumulated
+ information on the input and output such as totals and check values.
+
+ in() should return zero on failure. out() should return non-zero on
+ failure. If either in() or out() fails, than inflateBack() returns a
+ Z_BUF_ERROR. strm->next_in can be checked for Z_NULL to see whether it
+ was in() or out() that caused in the error. Otherwise, inflateBack()
+ returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
+ error, or Z_MEM_ERROR if it could not allocate memory for the state.
+ inflateBack() can also return Z_STREAM_ERROR if the input parameters
+ are not correct, i.e. strm is Z_NULL or the state was not initialized.
+ */
+int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
+z_streamp strm;
+in_func in;
+void FAR *in_desc;
+out_func out;
+void FAR *out_desc;
+{
+ struct inflate_state FAR *state;
+ z_const unsigned char FAR *next; /* next input */
+ unsigned char FAR *put; /* next output */
+ unsigned have, left; /* available input and output */
+ unsigned long hold; /* bit buffer */
+ unsigned bits; /* bits in bit buffer */
+ unsigned copy; /* number of stored or match bytes to copy */
+ unsigned char FAR *from; /* where to copy match bytes from */
+ code here; /* current decoding table entry */
+ code last; /* parent table entry */
+ unsigned len; /* length to copy for repeats, bits to drop */
+ int ret; /* return code */
+ static const unsigned short order[19] = /* permutation of code lengths */
+ {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
+
+ /* Check that the strm exists and that the state was initialized */
+ if (strm == Z_NULL || strm->state == Z_NULL)
+ return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+
+ /* Reset the state */
+ strm->msg = Z_NULL;
+ state->mode = TYPE;
+ state->last = 0;
+ state->whave = 0;
+ next = strm->next_in;
+ have = next != Z_NULL ? strm->avail_in : 0;
+ hold = 0;
+ bits = 0;
+ put = state->window;
+ left = state->wsize;
+
+ /* Inflate until end of block marked as last */
+ for (;;)
+ switch (state->mode) {
+ case TYPE:
+ /* determine and dispatch block type */
+ if (state->last) {
+ BYTEBITS();
+ state->mode = DONE;
+ break;
+ }
+ NEEDBITS(3);
+ state->last = BITS(1);
+ DROPBITS(1);
+ switch (BITS(2)) {
+ case 0: /* stored block */
+ Tracev((stderr, "inflate: stored block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = STORED;
+ break;
+ case 1: /* fixed block */
+ fixedtables(state);
+ Tracev((stderr, "inflate: fixed codes block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = LEN; /* decode codes */
+ break;
+ case 2: /* dynamic block */
+ Tracev((stderr, "inflate: dynamic codes block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = TABLE;
+ break;
+ case 3:
+ strm->msg = (char *)"invalid block type";
+ state->mode = BAD;
+ }
+ DROPBITS(2);
+ break;
+
+ case STORED:
+ /* get and verify stored block length */
+ BYTEBITS(); /* go to byte boundary */
+ NEEDBITS(32);
+ if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
+ strm->msg = (char *)"invalid stored block lengths";
+ state->mode = BAD;
+ break;
+ }
+ state->length = (unsigned)hold & 0xffff;
+ Tracev((stderr, "inflate: stored length %u\n",
+ state->length));
+ INITBITS();
+
+ /* copy stored block from input to output */
+ while (state->length != 0) {
+ copy = state->length;
+ PULL();
+ ROOM();
+ if (copy > have) copy = have;
+ if (copy > left) copy = left;
+ zmemcpy(put, next, copy);
+ have -= copy;
+ next += copy;
+ left -= copy;
+ put += copy;
+ state->length -= copy;
+ }
+ Tracev((stderr, "inflate: stored end\n"));
+ state->mode = TYPE;
+ break;
+
+ case TABLE:
+ /* get dynamic table entries descriptor */
+ NEEDBITS(14);
+ state->nlen = BITS(5) + 257;
+ DROPBITS(5);
+ state->ndist = BITS(5) + 1;
+ DROPBITS(5);
+ state->ncode = BITS(4) + 4;
+ DROPBITS(4);
+#ifndef PKZIP_BUG_WORKAROUND
+ if (state->nlen > 286 || state->ndist > 30) {
+ strm->msg = (char *)"too many length or distance symbols";
+ state->mode = BAD;
+ break;
+ }
+#endif
+ Tracev((stderr, "inflate: table sizes ok\n"));
+
+ /* get code length code lengths (not a typo) */
+ state->have = 0;
+ while (state->have < state->ncode) {
+ NEEDBITS(3);
+ state->lens[order[state->have++]] = (unsigned short)BITS(3);
+ DROPBITS(3);
+ }
+ while (state->have < 19)
+ state->lens[order[state->have++]] = 0;
+ state->next = state->codes;
+ state->lencode = (code const FAR *)(state->next);
+ state->lenbits = 7;
+ ret = inflate_table(CODES, state->lens, 19, &(state->next),
+ &(state->lenbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid code lengths set";
+ state->mode = BAD;
+ break;
+ }
+ Tracev((stderr, "inflate: code lengths ok\n"));
+
+ /* get length and distance code code lengths */
+ state->have = 0;
+ while (state->have < state->nlen + state->ndist) {
+ for (;;) {
+ here = state->lencode[BITS(state->lenbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if (here.val < 16) {
+ DROPBITS(here.bits);
+ state->lens[state->have++] = here.val;
+ }
+ else {
+ if (here.val == 16) {
+ NEEDBITS(here.bits + 2);
+ DROPBITS(here.bits);
+ if (state->have == 0) {
+ strm->msg = (char *)"invalid bit length repeat";
+ state->mode = BAD;
+ break;
+ }
+ len = (unsigned)(state->lens[state->have - 1]);
+ copy = 3 + BITS(2);
+ DROPBITS(2);
+ }
+ else if (here.val == 17) {
+ NEEDBITS(here.bits + 3);
+ DROPBITS(here.bits);
+ len = 0;
+ copy = 3 + BITS(3);
+ DROPBITS(3);
+ }
+ else {
+ NEEDBITS(here.bits + 7);
+ DROPBITS(here.bits);
+ len = 0;
+ copy = 11 + BITS(7);
+ DROPBITS(7);
+ }
+ if (state->have + copy > state->nlen + state->ndist) {
+ strm->msg = (char *)"invalid bit length repeat";
+ state->mode = BAD;
+ break;
+ }
+ while (copy--)
+ state->lens[state->have++] = (unsigned short)len;
+ }
+ }
+
+ /* handle error breaks in while */
+ if (state->mode == BAD) break;
+
+ /* check for end-of-block code (better have one) */
+ if (state->lens[256] == 0) {
+ strm->msg = (char *)"invalid code -- missing end-of-block";
+ state->mode = BAD;
+ break;
+ }
+
+ /* build code tables -- note: do not change the lenbits or distbits
+ values here (9 and 6) without reading the comments in inftrees.h
+ concerning the ENOUGH constants, which depend on those values */
+ state->next = state->codes;
+ state->lencode = (code const FAR *)(state->next);
+ state->lenbits = 9;
+ ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
+ &(state->lenbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid literal/lengths set";
+ state->mode = BAD;
+ break;
+ }
+ state->distcode = (code const FAR *)(state->next);
+ state->distbits = 6;
+ ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
+ &(state->next), &(state->distbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid distances set";
+ state->mode = BAD;
+ break;
+ }
+ Tracev((stderr, "inflate: codes ok\n"));
+ state->mode = LEN;
+
+ case LEN:
+ /* use inflate_fast() if we have enough input and output */
+ if (have >= 6 && left >= 258) {
+ RESTORE();
+ if (state->whave < state->wsize)
+ state->whave = state->wsize - left;
+ inflate_fast(strm, state->wsize);
+ LOAD();
+ break;
+ }
+
+ /* get a literal, length, or end-of-block code */
+ for (;;) {
+ here = state->lencode[BITS(state->lenbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if (here.op && (here.op & 0xf0) == 0) {
+ last = here;
+ for (;;) {
+ here = state->lencode[last.val +
+ (BITS(last.bits + last.op) >> last.bits)];
+ if ((unsigned)(last.bits + here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ DROPBITS(last.bits);
+ }
+ DROPBITS(here.bits);
+ state->length = (unsigned)here.val;
+
+ /* process literal */
+ if (here.op == 0) {
+ Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
+ "inflate: literal '%c'\n" :
+ "inflate: literal 0x%02x\n", here.val));
+ ROOM();
+ *put++ = (unsigned char)(state->length);
+ left--;
+ state->mode = LEN;
+ break;
+ }
+
+ /* process end of block */
+ if (here.op & 32) {
+ Tracevv((stderr, "inflate: end of block\n"));
+ state->mode = TYPE;
+ break;
+ }
+
+ /* invalid code */
+ if (here.op & 64) {
+ strm->msg = (char *)"invalid literal/length code";
+ state->mode = BAD;
+ break;
+ }
+
+ /* length code -- get extra bits, if any */
+ state->extra = (unsigned)(here.op) & 15;
+ if (state->extra != 0) {
+ NEEDBITS(state->extra);
+ state->length += BITS(state->extra);
+ DROPBITS(state->extra);
+ }
+ Tracevv((stderr, "inflate: length %u\n", state->length));
+
+ /* get distance code */
+ for (;;) {
+ here = state->distcode[BITS(state->distbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if ((here.op & 0xf0) == 0) {
+ last = here;
+ for (;;) {
+ here = state->distcode[last.val +
+ (BITS(last.bits + last.op) >> last.bits)];
+ if ((unsigned)(last.bits + here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ DROPBITS(last.bits);
+ }
+ DROPBITS(here.bits);
+ if (here.op & 64) {
+ strm->msg = (char *)"invalid distance code";
+ state->mode = BAD;
+ break;
+ }
+ state->offset = (unsigned)here.val;
+
+ /* get distance extra bits, if any */
+ state->extra = (unsigned)(here.op) & 15;
+ if (state->extra != 0) {
+ NEEDBITS(state->extra);
+ state->offset += BITS(state->extra);
+ DROPBITS(state->extra);
+ }
+ if (state->offset > state->wsize - (state->whave < state->wsize ?
+ left : 0)) {
+ strm->msg = (char *)"invalid distance too far back";
+ state->mode = BAD;
+ break;
+ }
+ Tracevv((stderr, "inflate: distance %u\n", state->offset));
+
+ /* copy match from window to output */
+ do {
+ ROOM();
+ copy = state->wsize - state->offset;
+ if (copy < left) {
+ from = put + copy;
+ copy = left - copy;
+ }
+ else {
+ from = put - state->offset;
+ copy = left;
+ }
+ if (copy > state->length) copy = state->length;
+ state->length -= copy;
+ left -= copy;
+ do {
+ *put++ = *from++;
+ } while (--copy);
+ } while (state->length != 0);
+ break;
+
+ case DONE:
+ /* inflate stream terminated properly -- write leftover output */
+ ret = Z_STREAM_END;
+ if (left < state->wsize) {
+ if (out(out_desc, state->window, state->wsize - left))
+ ret = Z_BUF_ERROR;
+ }
+ goto inf_leave;
+
+ case BAD:
+ ret = Z_DATA_ERROR;
+ goto inf_leave;
+
+ default: /* can't happen, but makes compilers happy */
+ ret = Z_STREAM_ERROR;
+ goto inf_leave;
+ }
+
+ /* Return unused input */
+ inf_leave:
+ strm->next_in = next;
+ strm->avail_in = have;
+ return ret;
+}
+
+int ZEXPORT inflateBackEnd(strm)
+z_streamp strm;
+{
+ if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
+ return Z_STREAM_ERROR;
+ ZFREE(strm, strm->state);
+ strm->state = Z_NULL;
+ Tracev((stderr, "inflate: end\n"));
+ return Z_OK;
+}
diff --git a/ml/dlib/dlib/external/zlib/inffast.c b/ml/dlib/dlib/external/zlib/inffast.c
new file mode 100644
index 000000000..bda59ceb6
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inffast.c
@@ -0,0 +1,340 @@
+/* inffast.c -- fast decoding
+ * Copyright (C) 1995-2008, 2010, 2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "zutil.h"
+#include "inftrees.h"
+#include "inflate.h"
+#include "inffast.h"
+
+#ifndef ASMINF
+
+/* Allow machine dependent optimization for post-increment or pre-increment.
+ Based on testing to date,
+ Pre-increment preferred for:
+ - PowerPC G3 (Adler)
+ - MIPS R5000 (Randers-Pehrson)
+ Post-increment preferred for:
+ - none
+ No measurable difference:
+ - Pentium III (Anderson)
+ - M68060 (Nikl)
+ */
+#ifdef POSTINC
+# define OFF 0
+# define PUP(a) *(a)++
+#else
+# define OFF 1
+# define PUP(a) *++(a)
+#endif
+
+/*
+ Decode literal, length, and distance codes and write out the resulting
+ literal and match bytes until either not enough input or output is
+ available, an end-of-block is encountered, or a data error is encountered.
+ When large enough input and output buffers are supplied to inflate(), for
+ example, a 16K input buffer and a 64K output buffer, more than 95% of the
+ inflate execution time is spent in this routine.
+
+ Entry assumptions:
+
+ state->mode == LEN
+ strm->avail_in >= 6
+ strm->avail_out >= 258
+ start >= strm->avail_out
+ state->bits < 8
+
+ On return, state->mode is one of:
+
+ LEN -- ran out of enough output space or enough available input
+ TYPE -- reached end of block code, inflate() to interpret next block
+ BAD -- error in block data
+
+ Notes:
+
+ - The maximum input bits used by a length/distance pair is 15 bits for the
+ length code, 5 bits for the length extra, 15 bits for the distance code,
+ and 13 bits for the distance extra. This totals 48 bits, or six bytes.
+ Therefore if strm->avail_in >= 6, then there is enough input to avoid
+ checking for available input while decoding.
+
+ - The maximum bytes that a single length/distance pair can output is 258
+ bytes, which is the maximum length that can be coded. inflate_fast()
+ requires strm->avail_out >= 258 for each loop to avoid checking for
+ output space.
+ */
+void ZLIB_INTERNAL inflate_fast(strm, start)
+z_streamp strm;
+unsigned start; /* inflate()'s starting value for strm->avail_out */
+{
+ struct inflate_state FAR *state;
+ z_const unsigned char FAR *in; /* local strm->next_in */
+ z_const unsigned char FAR *last; /* have enough input while in < last */
+ unsigned char FAR *out; /* local strm->next_out */
+ unsigned char FAR *beg; /* inflate()'s initial strm->next_out */
+ unsigned char FAR *end; /* while out < end, enough space available */
+#ifdef INFLATE_STRICT
+ unsigned dmax; /* maximum distance from zlib header */
+#endif
+ unsigned wsize; /* window size or zero if not using window */
+ unsigned whave; /* valid bytes in the window */
+ unsigned wnext; /* window write index */
+ unsigned char FAR *window; /* allocated sliding window, if wsize != 0 */
+ unsigned long hold; /* local strm->hold */
+ unsigned bits; /* local strm->bits */
+ code const FAR *lcode; /* local strm->lencode */
+ code const FAR *dcode; /* local strm->distcode */
+ unsigned lmask; /* mask for first level of length codes */
+ unsigned dmask; /* mask for first level of distance codes */
+ code here; /* retrieved table entry */
+ unsigned op; /* code bits, operation, extra bits, or */
+ /* window position, window bytes to copy */
+ unsigned len; /* match length, unused bytes */
+ unsigned dist; /* match distance */
+ unsigned char FAR *from; /* where to copy match from */
+
+ /* copy state to local variables */
+ state = (struct inflate_state FAR *)strm->state;
+ in = strm->next_in - OFF;
+ last = in + (strm->avail_in - 5);
+ out = strm->next_out - OFF;
+ beg = out - (start - strm->avail_out);
+ end = out + (strm->avail_out - 257);
+#ifdef INFLATE_STRICT
+ dmax = state->dmax;
+#endif
+ wsize = state->wsize;
+ whave = state->whave;
+ wnext = state->wnext;
+ window = state->window;
+ hold = state->hold;
+ bits = state->bits;
+ lcode = state->lencode;
+ dcode = state->distcode;
+ lmask = (1U << state->lenbits) - 1;
+ dmask = (1U << state->distbits) - 1;
+
+ /* decode literals and length/distances until end-of-block or not enough
+ input data or output space */
+ do {
+ if (bits < 15) {
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ }
+ here = lcode[hold & lmask];
+ dolen:
+ op = (unsigned)(here.bits);
+ hold >>= op;
+ bits -= op;
+ op = (unsigned)(here.op);
+ if (op == 0) { /* literal */
+ Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
+ "inflate: literal '%c'\n" :
+ "inflate: literal 0x%02x\n", here.val));
+ PUP(out) = (unsigned char)(here.val);
+ }
+ else if (op & 16) { /* length base */
+ len = (unsigned)(here.val);
+ op &= 15; /* number of extra bits */
+ if (op) {
+ if (bits < op) {
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ }
+ len += (unsigned)hold & ((1U << op) - 1);
+ hold >>= op;
+ bits -= op;
+ }
+ Tracevv((stderr, "inflate: length %u\n", len));
+ if (bits < 15) {
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ }
+ here = dcode[hold & dmask];
+ dodist:
+ op = (unsigned)(here.bits);
+ hold >>= op;
+ bits -= op;
+ op = (unsigned)(here.op);
+ if (op & 16) { /* distance base */
+ dist = (unsigned)(here.val);
+ op &= 15; /* number of extra bits */
+ if (bits < op) {
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ if (bits < op) {
+ hold += (unsigned long)(PUP(in)) << bits;
+ bits += 8;
+ }
+ }
+ dist += (unsigned)hold & ((1U << op) - 1);
+#ifdef INFLATE_STRICT
+ if (dist > dmax) {
+ strm->msg = (char *)"invalid distance too far back";
+ state->mode = BAD;
+ break;
+ }
+#endif
+ hold >>= op;
+ bits -= op;
+ Tracevv((stderr, "inflate: distance %u\n", dist));
+ op = (unsigned)(out - beg); /* max distance in output */
+ if (dist > op) { /* see if copy from window */
+ op = dist - op; /* distance back in window */
+ if (op > whave) {
+ if (state->sane) {
+ strm->msg =
+ (char *)"invalid distance too far back";
+ state->mode = BAD;
+ break;
+ }
+#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR
+ if (len <= op - whave) {
+ do {
+ PUP(out) = 0;
+ } while (--len);
+ continue;
+ }
+ len -= op - whave;
+ do {
+ PUP(out) = 0;
+ } while (--op > whave);
+ if (op == 0) {
+ from = out - dist;
+ do {
+ PUP(out) = PUP(from);
+ } while (--len);
+ continue;
+ }
+#endif
+ }
+ from = window - OFF;
+ if (wnext == 0) { /* very common case */
+ from += wsize - op;
+ if (op < len) { /* some from window */
+ len -= op;
+ do {
+ PUP(out) = PUP(from);
+ } while (--op);
+ from = out - dist; /* rest from output */
+ }
+ }
+ else if (wnext < op) { /* wrap around window */
+ from += wsize + wnext - op;
+ op -= wnext;
+ if (op < len) { /* some from end of window */
+ len -= op;
+ do {
+ PUP(out) = PUP(from);
+ } while (--op);
+ from = window - OFF;
+ if (wnext < len) { /* some from start of window */
+ op = wnext;
+ len -= op;
+ do {
+ PUP(out) = PUP(from);
+ } while (--op);
+ from = out - dist; /* rest from output */
+ }
+ }
+ }
+ else { /* contiguous in window */
+ from += wnext - op;
+ if (op < len) { /* some from window */
+ len -= op;
+ do {
+ PUP(out) = PUP(from);
+ } while (--op);
+ from = out - dist; /* rest from output */
+ }
+ }
+ while (len > 2) {
+ PUP(out) = PUP(from);
+ PUP(out) = PUP(from);
+ PUP(out) = PUP(from);
+ len -= 3;
+ }
+ if (len) {
+ PUP(out) = PUP(from);
+ if (len > 1)
+ PUP(out) = PUP(from);
+ }
+ }
+ else {
+ from = out - dist; /* copy direct from output */
+ do { /* minimum length is three */
+ PUP(out) = PUP(from);
+ PUP(out) = PUP(from);
+ PUP(out) = PUP(from);
+ len -= 3;
+ } while (len > 2);
+ if (len) {
+ PUP(out) = PUP(from);
+ if (len > 1)
+ PUP(out) = PUP(from);
+ }
+ }
+ }
+ else if ((op & 64) == 0) { /* 2nd level distance code */
+ here = dcode[here.val + (hold & ((1U << op) - 1))];
+ goto dodist;
+ }
+ else {
+ strm->msg = (char *)"invalid distance code";
+ state->mode = BAD;
+ break;
+ }
+ }
+ else if ((op & 64) == 0) { /* 2nd level length code */
+ here = lcode[here.val + (hold & ((1U << op) - 1))];
+ goto dolen;
+ }
+ else if (op & 32) { /* end-of-block */
+ Tracevv((stderr, "inflate: end of block\n"));
+ state->mode = TYPE;
+ break;
+ }
+ else {
+ strm->msg = (char *)"invalid literal/length code";
+ state->mode = BAD;
+ break;
+ }
+ } while (in < last && out < end);
+
+ /* return unused bytes (on entry, bits < 8, so in won't go too far back) */
+ len = bits >> 3;
+ in -= len;
+ bits -= len << 3;
+ hold &= (1U << bits) - 1;
+
+ /* update state and return */
+ strm->next_in = in + OFF;
+ strm->next_out = out + OFF;
+ strm->avail_in = (unsigned)(in < last ? 5 + (last - in) : 5 - (in - last));
+ strm->avail_out = (unsigned)(out < end ?
+ 257 + (end - out) : 257 - (out - end));
+ state->hold = hold;
+ state->bits = bits;
+ return;
+}
+
+/*
+ inflate_fast() speedups that turned out slower (on a PowerPC G3 750CXe):
+ - Using bit fields for code structure
+ - Different op definition to avoid & for extra bits (do & for table bits)
+ - Three separate decoding do-loops for direct, window, and wnext == 0
+ - Special case for distance > 1 copies to do overlapped load and store copy
+ - Explicit branch predictions (based on measured branch probabilities)
+ - Deferring match copy and interspersed it with decoding subsequent codes
+ - Swapping literal/length else
+ - Swapping window/direct else
+ - Larger unrolled copy loops (three is about right)
+ - Moving len -= 3 statement into middle of loop
+ */
+
+#endif /* !ASMINF */
diff --git a/ml/dlib/dlib/external/zlib/inffast.h b/ml/dlib/dlib/external/zlib/inffast.h
new file mode 100644
index 000000000..e5c1aa4ca
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inffast.h
@@ -0,0 +1,11 @@
+/* inffast.h -- header to use inffast.c
+ * Copyright (C) 1995-2003, 2010 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* WARNING: this file should *not* be used by applications. It is
+ part of the implementation of the compression library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+void ZLIB_INTERNAL inflate_fast OF((z_streamp strm, unsigned start));
diff --git a/ml/dlib/dlib/external/zlib/inffixed.h b/ml/dlib/dlib/external/zlib/inffixed.h
new file mode 100644
index 000000000..d62832776
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inffixed.h
@@ -0,0 +1,94 @@
+ /* inffixed.h -- table for decoding fixed codes
+ * Generated automatically by makefixed().
+ */
+
+ /* WARNING: this file should *not* be used by applications.
+ It is part of the implementation of this library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+ static const code lenfix[512] = {
+ {96,7,0},{0,8,80},{0,8,16},{20,8,115},{18,7,31},{0,8,112},{0,8,48},
+ {0,9,192},{16,7,10},{0,8,96},{0,8,32},{0,9,160},{0,8,0},{0,8,128},
+ {0,8,64},{0,9,224},{16,7,6},{0,8,88},{0,8,24},{0,9,144},{19,7,59},
+ {0,8,120},{0,8,56},{0,9,208},{17,7,17},{0,8,104},{0,8,40},{0,9,176},
+ {0,8,8},{0,8,136},{0,8,72},{0,9,240},{16,7,4},{0,8,84},{0,8,20},
+ {21,8,227},{19,7,43},{0,8,116},{0,8,52},{0,9,200},{17,7,13},{0,8,100},
+ {0,8,36},{0,9,168},{0,8,4},{0,8,132},{0,8,68},{0,9,232},{16,7,8},
+ {0,8,92},{0,8,28},{0,9,152},{20,7,83},{0,8,124},{0,8,60},{0,9,216},
+ {18,7,23},{0,8,108},{0,8,44},{0,9,184},{0,8,12},{0,8,140},{0,8,76},
+ {0,9,248},{16,7,3},{0,8,82},{0,8,18},{21,8,163},{19,7,35},{0,8,114},
+ {0,8,50},{0,9,196},{17,7,11},{0,8,98},{0,8,34},{0,9,164},{0,8,2},
+ {0,8,130},{0,8,66},{0,9,228},{16,7,7},{0,8,90},{0,8,26},{0,9,148},
+ {20,7,67},{0,8,122},{0,8,58},{0,9,212},{18,7,19},{0,8,106},{0,8,42},
+ {0,9,180},{0,8,10},{0,8,138},{0,8,74},{0,9,244},{16,7,5},{0,8,86},
+ {0,8,22},{64,8,0},{19,7,51},{0,8,118},{0,8,54},{0,9,204},{17,7,15},
+ {0,8,102},{0,8,38},{0,9,172},{0,8,6},{0,8,134},{0,8,70},{0,9,236},
+ {16,7,9},{0,8,94},{0,8,30},{0,9,156},{20,7,99},{0,8,126},{0,8,62},
+ {0,9,220},{18,7,27},{0,8,110},{0,8,46},{0,9,188},{0,8,14},{0,8,142},
+ {0,8,78},{0,9,252},{96,7,0},{0,8,81},{0,8,17},{21,8,131},{18,7,31},
+ {0,8,113},{0,8,49},{0,9,194},{16,7,10},{0,8,97},{0,8,33},{0,9,162},
+ {0,8,1},{0,8,129},{0,8,65},{0,9,226},{16,7,6},{0,8,89},{0,8,25},
+ {0,9,146},{19,7,59},{0,8,121},{0,8,57},{0,9,210},{17,7,17},{0,8,105},
+ {0,8,41},{0,9,178},{0,8,9},{0,8,137},{0,8,73},{0,9,242},{16,7,4},
+ {0,8,85},{0,8,21},{16,8,258},{19,7,43},{0,8,117},{0,8,53},{0,9,202},
+ {17,7,13},{0,8,101},{0,8,37},{0,9,170},{0,8,5},{0,8,133},{0,8,69},
+ {0,9,234},{16,7,8},{0,8,93},{0,8,29},{0,9,154},{20,7,83},{0,8,125},
+ {0,8,61},{0,9,218},{18,7,23},{0,8,109},{0,8,45},{0,9,186},{0,8,13},
+ {0,8,141},{0,8,77},{0,9,250},{16,7,3},{0,8,83},{0,8,19},{21,8,195},
+ {19,7,35},{0,8,115},{0,8,51},{0,9,198},{17,7,11},{0,8,99},{0,8,35},
+ {0,9,166},{0,8,3},{0,8,131},{0,8,67},{0,9,230},{16,7,7},{0,8,91},
+ {0,8,27},{0,9,150},{20,7,67},{0,8,123},{0,8,59},{0,9,214},{18,7,19},
+ {0,8,107},{0,8,43},{0,9,182},{0,8,11},{0,8,139},{0,8,75},{0,9,246},
+ {16,7,5},{0,8,87},{0,8,23},{64,8,0},{19,7,51},{0,8,119},{0,8,55},
+ {0,9,206},{17,7,15},{0,8,103},{0,8,39},{0,9,174},{0,8,7},{0,8,135},
+ {0,8,71},{0,9,238},{16,7,9},{0,8,95},{0,8,31},{0,9,158},{20,7,99},
+ {0,8,127},{0,8,63},{0,9,222},{18,7,27},{0,8,111},{0,8,47},{0,9,190},
+ {0,8,15},{0,8,143},{0,8,79},{0,9,254},{96,7,0},{0,8,80},{0,8,16},
+ {20,8,115},{18,7,31},{0,8,112},{0,8,48},{0,9,193},{16,7,10},{0,8,96},
+ {0,8,32},{0,9,161},{0,8,0},{0,8,128},{0,8,64},{0,9,225},{16,7,6},
+ {0,8,88},{0,8,24},{0,9,145},{19,7,59},{0,8,120},{0,8,56},{0,9,209},
+ {17,7,17},{0,8,104},{0,8,40},{0,9,177},{0,8,8},{0,8,136},{0,8,72},
+ {0,9,241},{16,7,4},{0,8,84},{0,8,20},{21,8,227},{19,7,43},{0,8,116},
+ {0,8,52},{0,9,201},{17,7,13},{0,8,100},{0,8,36},{0,9,169},{0,8,4},
+ {0,8,132},{0,8,68},{0,9,233},{16,7,8},{0,8,92},{0,8,28},{0,9,153},
+ {20,7,83},{0,8,124},{0,8,60},{0,9,217},{18,7,23},{0,8,108},{0,8,44},
+ {0,9,185},{0,8,12},{0,8,140},{0,8,76},{0,9,249},{16,7,3},{0,8,82},
+ {0,8,18},{21,8,163},{19,7,35},{0,8,114},{0,8,50},{0,9,197},{17,7,11},
+ {0,8,98},{0,8,34},{0,9,165},{0,8,2},{0,8,130},{0,8,66},{0,9,229},
+ {16,7,7},{0,8,90},{0,8,26},{0,9,149},{20,7,67},{0,8,122},{0,8,58},
+ {0,9,213},{18,7,19},{0,8,106},{0,8,42},{0,9,181},{0,8,10},{0,8,138},
+ {0,8,74},{0,9,245},{16,7,5},{0,8,86},{0,8,22},{64,8,0},{19,7,51},
+ {0,8,118},{0,8,54},{0,9,205},{17,7,15},{0,8,102},{0,8,38},{0,9,173},
+ {0,8,6},{0,8,134},{0,8,70},{0,9,237},{16,7,9},{0,8,94},{0,8,30},
+ {0,9,157},{20,7,99},{0,8,126},{0,8,62},{0,9,221},{18,7,27},{0,8,110},
+ {0,8,46},{0,9,189},{0,8,14},{0,8,142},{0,8,78},{0,9,253},{96,7,0},
+ {0,8,81},{0,8,17},{21,8,131},{18,7,31},{0,8,113},{0,8,49},{0,9,195},
+ {16,7,10},{0,8,97},{0,8,33},{0,9,163},{0,8,1},{0,8,129},{0,8,65},
+ {0,9,227},{16,7,6},{0,8,89},{0,8,25},{0,9,147},{19,7,59},{0,8,121},
+ {0,8,57},{0,9,211},{17,7,17},{0,8,105},{0,8,41},{0,9,179},{0,8,9},
+ {0,8,137},{0,8,73},{0,9,243},{16,7,4},{0,8,85},{0,8,21},{16,8,258},
+ {19,7,43},{0,8,117},{0,8,53},{0,9,203},{17,7,13},{0,8,101},{0,8,37},
+ {0,9,171},{0,8,5},{0,8,133},{0,8,69},{0,9,235},{16,7,8},{0,8,93},
+ {0,8,29},{0,9,155},{20,7,83},{0,8,125},{0,8,61},{0,9,219},{18,7,23},
+ {0,8,109},{0,8,45},{0,9,187},{0,8,13},{0,8,141},{0,8,77},{0,9,251},
+ {16,7,3},{0,8,83},{0,8,19},{21,8,195},{19,7,35},{0,8,115},{0,8,51},
+ {0,9,199},{17,7,11},{0,8,99},{0,8,35},{0,9,167},{0,8,3},{0,8,131},
+ {0,8,67},{0,9,231},{16,7,7},{0,8,91},{0,8,27},{0,9,151},{20,7,67},
+ {0,8,123},{0,8,59},{0,9,215},{18,7,19},{0,8,107},{0,8,43},{0,9,183},
+ {0,8,11},{0,8,139},{0,8,75},{0,9,247},{16,7,5},{0,8,87},{0,8,23},
+ {64,8,0},{19,7,51},{0,8,119},{0,8,55},{0,9,207},{17,7,15},{0,8,103},
+ {0,8,39},{0,9,175},{0,8,7},{0,8,135},{0,8,71},{0,9,239},{16,7,9},
+ {0,8,95},{0,8,31},{0,9,159},{20,7,99},{0,8,127},{0,8,63},{0,9,223},
+ {18,7,27},{0,8,111},{0,8,47},{0,9,191},{0,8,15},{0,8,143},{0,8,79},
+ {0,9,255}
+ };
+
+ static const code distfix[32] = {
+ {16,5,1},{23,5,257},{19,5,17},{27,5,4097},{17,5,5},{25,5,1025},
+ {21,5,65},{29,5,16385},{16,5,3},{24,5,513},{20,5,33},{28,5,8193},
+ {18,5,9},{26,5,2049},{22,5,129},{64,5,0},{16,5,2},{23,5,385},
+ {19,5,25},{27,5,6145},{17,5,7},{25,5,1537},{21,5,97},{29,5,24577},
+ {16,5,4},{24,5,769},{20,5,49},{28,5,12289},{18,5,13},{26,5,3073},
+ {22,5,193},{64,5,0}
+ };
diff --git a/ml/dlib/dlib/external/zlib/inflate.c b/ml/dlib/dlib/external/zlib/inflate.c
new file mode 100644
index 000000000..870f89bb4
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inflate.c
@@ -0,0 +1,1512 @@
+/* inflate.c -- zlib decompression
+ * Copyright (C) 1995-2012 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/*
+ * Change history:
+ *
+ * 1.2.beta0 24 Nov 2002
+ * - First version -- complete rewrite of inflate to simplify code, avoid
+ * creation of window when not needed, minimize use of window when it is
+ * needed, make inffast.c even faster, implement gzip decoding, and to
+ * improve code readability and style over the previous zlib inflate code
+ *
+ * 1.2.beta1 25 Nov 2002
+ * - Use pointers for available input and output checking in inffast.c
+ * - Remove input and output counters in inffast.c
+ * - Change inffast.c entry and loop from avail_in >= 7 to >= 6
+ * - Remove unnecessary second byte pull from length extra in inffast.c
+ * - Unroll direct copy to three copies per loop in inffast.c
+ *
+ * 1.2.beta2 4 Dec 2002
+ * - Change external routine names to reduce potential conflicts
+ * - Correct filename to inffixed.h for fixed tables in inflate.c
+ * - Make hbuf[] unsigned char to match parameter type in inflate.c
+ * - Change strm->next_out[-state->offset] to *(strm->next_out - state->offset)
+ * to avoid negation problem on Alphas (64 bit) in inflate.c
+ *
+ * 1.2.beta3 22 Dec 2002
+ * - Add comments on state->bits assertion in inffast.c
+ * - Add comments on op field in inftrees.h
+ * - Fix bug in reuse of allocated window after inflateReset()
+ * - Remove bit fields--back to byte structure for speed
+ * - Remove distance extra == 0 check in inflate_fast()--only helps for lengths
+ * - Change post-increments to pre-increments in inflate_fast(), PPC biased?
+ * - Add compile time option, POSTINC, to use post-increments instead (Intel?)
+ * - Make MATCH copy in inflate() much faster for when inflate_fast() not used
+ * - Use local copies of stream next and avail values, as well as local bit
+ * buffer and bit count in inflate()--for speed when inflate_fast() not used
+ *
+ * 1.2.beta4 1 Jan 2003
+ * - Split ptr - 257 statements in inflate_table() to avoid compiler warnings
+ * - Move a comment on output buffer sizes from inffast.c to inflate.c
+ * - Add comments in inffast.c to introduce the inflate_fast() routine
+ * - Rearrange window copies in inflate_fast() for speed and simplification
+ * - Unroll last copy for window match in inflate_fast()
+ * - Use local copies of window variables in inflate_fast() for speed
+ * - Pull out common wnext == 0 case for speed in inflate_fast()
+ * - Make op and len in inflate_fast() unsigned for consistency
+ * - Add FAR to lcode and dcode declarations in inflate_fast()
+ * - Simplified bad distance check in inflate_fast()
+ * - Added inflateBackInit(), inflateBack(), and inflateBackEnd() in new
+ * source file infback.c to provide a call-back interface to inflate for
+ * programs like gzip and unzip -- uses window as output buffer to avoid
+ * window copying
+ *
+ * 1.2.beta5 1 Jan 2003
+ * - Improved inflateBack() interface to allow the caller to provide initial
+ * input in strm.
+ * - Fixed stored blocks bug in inflateBack()
+ *
+ * 1.2.beta6 4 Jan 2003
+ * - Added comments in inffast.c on effectiveness of POSTINC
+ * - Typecasting all around to reduce compiler warnings
+ * - Changed loops from while (1) or do {} while (1) to for (;;), again to
+ * make compilers happy
+ * - Changed type of window in inflateBackInit() to unsigned char *
+ *
+ * 1.2.beta7 27 Jan 2003
+ * - Changed many types to unsigned or unsigned short to avoid warnings
+ * - Added inflateCopy() function
+ *
+ * 1.2.0 9 Mar 2003
+ * - Changed inflateBack() interface to provide separate opaque descriptors
+ * for the in() and out() functions
+ * - Changed inflateBack() argument and in_func typedef to swap the length
+ * and buffer address return values for the input function
+ * - Check next_in and next_out for Z_NULL on entry to inflate()
+ *
+ * The history for versions after 1.2.0 are in ChangeLog in zlib distribution.
+ */
+
+#include "zutil.h"
+#include "inftrees.h"
+#include "inflate.h"
+#include "inffast.h"
+
+#ifdef MAKEFIXED
+# ifndef BUILDFIXED
+# define BUILDFIXED
+# endif
+#endif
+
+/* function prototypes */
+local void fixedtables OF((struct inflate_state FAR *state));
+local int updatewindow OF((z_streamp strm, const unsigned char FAR *end,
+ unsigned copy));
+#ifdef BUILDFIXED
+ void makefixed OF((void));
+#endif
+local unsigned syncsearch OF((unsigned FAR *have, const unsigned char FAR *buf,
+ unsigned len));
+
+int ZEXPORT inflateResetKeep(strm)
+z_streamp strm;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ strm->total_in = strm->total_out = state->total = 0;
+ strm->msg = Z_NULL;
+ if (state->wrap) /* to support ill-conceived Java test suite */
+ strm->adler = state->wrap & 1;
+ state->mode = HEAD;
+ state->last = 0;
+ state->havedict = 0;
+ state->dmax = 32768U;
+ state->head = Z_NULL;
+ state->hold = 0;
+ state->bits = 0;
+ state->lencode = state->distcode = state->next = state->codes;
+ state->sane = 1;
+ state->back = -1;
+ Tracev((stderr, "inflate: reset\n"));
+ return Z_OK;
+}
+
+int ZEXPORT inflateReset(strm)
+z_streamp strm;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ state->wsize = 0;
+ state->whave = 0;
+ state->wnext = 0;
+ return inflateResetKeep(strm);
+}
+
+int ZEXPORT inflateReset2(strm, windowBits)
+z_streamp strm;
+int windowBits;
+{
+ int wrap;
+ struct inflate_state FAR *state;
+
+ /* get the state */
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+
+ /* extract wrap request from windowBits parameter */
+ if (windowBits < 0) {
+ wrap = 0;
+ windowBits = -windowBits;
+ }
+ else {
+ wrap = (windowBits >> 4) + 1;
+#ifdef GUNZIP
+ if (windowBits < 48)
+ windowBits &= 15;
+#endif
+ }
+
+ /* set number of window bits, free window if different */
+ if (windowBits && (windowBits < 8 || windowBits > 15))
+ return Z_STREAM_ERROR;
+ if (state->window != Z_NULL && state->wbits != (unsigned)windowBits) {
+ ZFREE(strm, state->window);
+ state->window = Z_NULL;
+ }
+
+ /* update state and reset the rest of it */
+ state->wrap = wrap;
+ state->wbits = (unsigned)windowBits;
+ return inflateReset(strm);
+}
+
+int ZEXPORT inflateInit2_(strm, windowBits, version, stream_size)
+z_streamp strm;
+int windowBits;
+const char *version;
+int stream_size;
+{
+ int ret;
+ struct inflate_state FAR *state;
+
+ if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
+ stream_size != (int)(sizeof(z_stream)))
+ return Z_VERSION_ERROR;
+ if (strm == Z_NULL) return Z_STREAM_ERROR;
+ strm->msg = Z_NULL; /* in case we return an error */
+ if (strm->zalloc == (alloc_func)0) {
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zalloc = zcalloc;
+ strm->opaque = (voidpf)0;
+#endif
+ }
+ if (strm->zfree == (free_func)0)
+#ifdef Z_SOLO
+ return Z_STREAM_ERROR;
+#else
+ strm->zfree = zcfree;
+#endif
+ state = (struct inflate_state FAR *)
+ ZALLOC(strm, 1, sizeof(struct inflate_state));
+ if (state == Z_NULL) return Z_MEM_ERROR;
+ Tracev((stderr, "inflate: allocated\n"));
+ strm->state = (struct internal_state FAR *)state;
+ state->window = Z_NULL;
+ ret = inflateReset2(strm, windowBits);
+ if (ret != Z_OK) {
+ ZFREE(strm, state);
+ strm->state = Z_NULL;
+ }
+ return ret;
+}
+
+int ZEXPORT inflateInit_(strm, version, stream_size)
+z_streamp strm;
+const char *version;
+int stream_size;
+{
+ return inflateInit2_(strm, DEF_WBITS, version, stream_size);
+}
+
+int ZEXPORT inflatePrime(strm, bits, value)
+z_streamp strm;
+int bits;
+int value;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ if (bits < 0) {
+ state->hold = 0;
+ state->bits = 0;
+ return Z_OK;
+ }
+ if (bits > 16 || state->bits + bits > 32) return Z_STREAM_ERROR;
+ value &= (1L << bits) - 1;
+ state->hold += value << state->bits;
+ state->bits += bits;
+ return Z_OK;
+}
+
+/*
+ Return state with length and distance decoding tables and index sizes set to
+ fixed code decoding. Normally this returns fixed tables from inffixed.h.
+ If BUILDFIXED is defined, then instead this routine builds the tables the
+ first time it's called, and returns those tables the first time and
+ thereafter. This reduces the size of the code by about 2K bytes, in
+ exchange for a little execution time. However, BUILDFIXED should not be
+ used for threaded applications, since the rewriting of the tables and virgin
+ may not be thread-safe.
+ */
+local void fixedtables(state)
+struct inflate_state FAR *state;
+{
+#ifdef BUILDFIXED
+ static int virgin = 1;
+ static code *lenfix, *distfix;
+ static code fixed[544];
+
+ /* build fixed huffman tables if first call (may not be thread safe) */
+ if (virgin) {
+ unsigned sym, bits;
+ static code *next;
+
+ /* literal/length table */
+ sym = 0;
+ while (sym < 144) state->lens[sym++] = 8;
+ while (sym < 256) state->lens[sym++] = 9;
+ while (sym < 280) state->lens[sym++] = 7;
+ while (sym < 288) state->lens[sym++] = 8;
+ next = fixed;
+ lenfix = next;
+ bits = 9;
+ inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
+
+ /* distance table */
+ sym = 0;
+ while (sym < 32) state->lens[sym++] = 5;
+ distfix = next;
+ bits = 5;
+ inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
+
+ /* do this just once */
+ virgin = 0;
+ }
+#else /* !BUILDFIXED */
+# include "inffixed.h"
+#endif /* BUILDFIXED */
+ state->lencode = lenfix;
+ state->lenbits = 9;
+ state->distcode = distfix;
+ state->distbits = 5;
+}
+
+#ifdef MAKEFIXED
+#include <stdio.h>
+
+/*
+ Write out the inffixed.h that is #include'd above. Defining MAKEFIXED also
+ defines BUILDFIXED, so the tables are built on the fly. makefixed() writes
+ those tables to stdout, which would be piped to inffixed.h. A small program
+ can simply call makefixed to do this:
+
+ void makefixed(void);
+
+ int main(void)
+ {
+ makefixed();
+ return 0;
+ }
+
+ Then that can be linked with zlib built with MAKEFIXED defined and run:
+
+ a.out > inffixed.h
+ */
+void makefixed()
+{
+ unsigned low, size;
+ struct inflate_state state;
+
+ fixedtables(&state);
+ puts(" /* inffixed.h -- table for decoding fixed codes");
+ puts(" * Generated automatically by makefixed().");
+ puts(" */");
+ puts("");
+ puts(" /* WARNING: this file should *not* be used by applications.");
+ puts(" It is part of the implementation of this library and is");
+ puts(" subject to change. Applications should only use zlib.h.");
+ puts(" */");
+ puts("");
+ size = 1U << 9;
+ printf(" static const code lenfix[%u] = {", size);
+ low = 0;
+ for (;;) {
+ if ((low % 7) == 0) printf("\n ");
+ printf("{%u,%u,%d}", (low & 127) == 99 ? 64 : state.lencode[low].op,
+ state.lencode[low].bits, state.lencode[low].val);
+ if (++low == size) break;
+ putchar(',');
+ }
+ puts("\n };");
+ size = 1U << 5;
+ printf("\n static const code distfix[%u] = {", size);
+ low = 0;
+ for (;;) {
+ if ((low % 6) == 0) printf("\n ");
+ printf("{%u,%u,%d}", state.distcode[low].op, state.distcode[low].bits,
+ state.distcode[low].val);
+ if (++low == size) break;
+ putchar(',');
+ }
+ puts("\n };");
+}
+#endif /* MAKEFIXED */
+
+/*
+ Update the window with the last wsize (normally 32K) bytes written before
+ returning. If window does not exist yet, create it. This is only called
+ when a window is already in use, or when output has been written during this
+ inflate call, but the end of the deflate stream has not been reached yet.
+ It is also called to create a window for dictionary data when a dictionary
+ is loaded.
+
+ Providing output buffers larger than 32K to inflate() should provide a speed
+ advantage, since only the last 32K of output is copied to the sliding window
+ upon return from inflate(), and since all distances after the first 32K of
+ output will fall in the output data, making match copies simpler and faster.
+ The advantage may be dependent on the size of the processor's data caches.
+ */
+local int updatewindow(strm, end, copy)
+z_streamp strm;
+const Bytef *end;
+unsigned copy;
+{
+ struct inflate_state FAR *state;
+ unsigned dist;
+
+ state = (struct inflate_state FAR *)strm->state;
+
+ /* if it hasn't been done already, allocate space for the window */
+ if (state->window == Z_NULL) {
+ state->window = (unsigned char FAR *)
+ ZALLOC(strm, 1U << state->wbits,
+ sizeof(unsigned char));
+ if (state->window == Z_NULL) return 1;
+ }
+
+ /* if window not in use yet, initialize */
+ if (state->wsize == 0) {
+ state->wsize = 1U << state->wbits;
+ state->wnext = 0;
+ state->whave = 0;
+ }
+
+ /* copy state->wsize or less output bytes into the circular window */
+ if (copy >= state->wsize) {
+ zmemcpy(state->window, end - state->wsize, state->wsize);
+ state->wnext = 0;
+ state->whave = state->wsize;
+ }
+ else {
+ dist = state->wsize - state->wnext;
+ if (dist > copy) dist = copy;
+ zmemcpy(state->window + state->wnext, end - copy, dist);
+ copy -= dist;
+ if (copy) {
+ zmemcpy(state->window, end - copy, copy);
+ state->wnext = copy;
+ state->whave = state->wsize;
+ }
+ else {
+ state->wnext += dist;
+ if (state->wnext == state->wsize) state->wnext = 0;
+ if (state->whave < state->wsize) state->whave += dist;
+ }
+ }
+ return 0;
+}
+
+/* Macros for inflate(): */
+
+/* check function to use adler32() for zlib or crc32() for gzip */
+#ifdef GUNZIP
+# define UPDATE(check, buf, len) \
+ (state->flags ? crc32(check, buf, len) : adler32(check, buf, len))
+#else
+# define UPDATE(check, buf, len) adler32(check, buf, len)
+#endif
+
+/* check macros for header crc */
+#ifdef GUNZIP
+# define CRC2(check, word) \
+ do { \
+ hbuf[0] = (unsigned char)(word); \
+ hbuf[1] = (unsigned char)((word) >> 8); \
+ check = crc32(check, hbuf, 2); \
+ } while (0)
+
+# define CRC4(check, word) \
+ do { \
+ hbuf[0] = (unsigned char)(word); \
+ hbuf[1] = (unsigned char)((word) >> 8); \
+ hbuf[2] = (unsigned char)((word) >> 16); \
+ hbuf[3] = (unsigned char)((word) >> 24); \
+ check = crc32(check, hbuf, 4); \
+ } while (0)
+#endif
+
+/* Load registers with state in inflate() for speed */
+#define LOAD() \
+ do { \
+ put = strm->next_out; \
+ left = strm->avail_out; \
+ next = strm->next_in; \
+ have = strm->avail_in; \
+ hold = state->hold; \
+ bits = state->bits; \
+ } while (0)
+
+/* Restore state from registers in inflate() */
+#define RESTORE() \
+ do { \
+ strm->next_out = put; \
+ strm->avail_out = left; \
+ strm->next_in = next; \
+ strm->avail_in = have; \
+ state->hold = hold; \
+ state->bits = bits; \
+ } while (0)
+
+/* Clear the input bit accumulator */
+#define INITBITS() \
+ do { \
+ hold = 0; \
+ bits = 0; \
+ } while (0)
+
+/* Get a byte of input into the bit accumulator, or return from inflate()
+ if there is no input available. */
+#define PULLBYTE() \
+ do { \
+ if (have == 0) goto inf_leave; \
+ have--; \
+ hold += (unsigned long)(*next++) << bits; \
+ bits += 8; \
+ } while (0)
+
+/* Assure that there are at least n bits in the bit accumulator. If there is
+ not enough available input to do that, then return from inflate(). */
+#define NEEDBITS(n) \
+ do { \
+ while (bits < (unsigned)(n)) \
+ PULLBYTE(); \
+ } while (0)
+
+/* Return the low n bits of the bit accumulator (n < 16) */
+#define BITS(n) \
+ ((unsigned)hold & ((1U << (n)) - 1))
+
+/* Remove n bits from the bit accumulator */
+#define DROPBITS(n) \
+ do { \
+ hold >>= (n); \
+ bits -= (unsigned)(n); \
+ } while (0)
+
+/* Remove zero to seven bits as needed to go to a byte boundary */
+#define BYTEBITS() \
+ do { \
+ hold >>= bits & 7; \
+ bits -= bits & 7; \
+ } while (0)
+
+/*
+ inflate() uses a state machine to process as much input data and generate as
+ much output data as possible before returning. The state machine is
+ structured roughly as follows:
+
+ for (;;) switch (state) {
+ ...
+ case STATEn:
+ if (not enough input data or output space to make progress)
+ return;
+ ... make progress ...
+ state = STATEm;
+ break;
+ ...
+ }
+
+ so when inflate() is called again, the same case is attempted again, and
+ if the appropriate resources are provided, the machine proceeds to the
+ next state. The NEEDBITS() macro is usually the way the state evaluates
+ whether it can proceed or should return. NEEDBITS() does the return if
+ the requested bits are not available. The typical use of the BITS macros
+ is:
+
+ NEEDBITS(n);
+ ... do something with BITS(n) ...
+ DROPBITS(n);
+
+ where NEEDBITS(n) either returns from inflate() if there isn't enough
+ input left to load n bits into the accumulator, or it continues. BITS(n)
+ gives the low n bits in the accumulator. When done, DROPBITS(n) drops
+ the low n bits off the accumulator. INITBITS() clears the accumulator
+ and sets the number of available bits to zero. BYTEBITS() discards just
+ enough bits to put the accumulator on a byte boundary. After BYTEBITS()
+ and a NEEDBITS(8), then BITS(8) would return the next byte in the stream.
+
+ NEEDBITS(n) uses PULLBYTE() to get an available byte of input, or to return
+ if there is no input available. The decoding of variable length codes uses
+ PULLBYTE() directly in order to pull just enough bytes to decode the next
+ code, and no more.
+
+ Some states loop until they get enough input, making sure that enough
+ state information is maintained to continue the loop where it left off
+ if NEEDBITS() returns in the loop. For example, want, need, and keep
+ would all have to actually be part of the saved state in case NEEDBITS()
+ returns:
+
+ case STATEw:
+ while (want < need) {
+ NEEDBITS(n);
+ keep[want++] = BITS(n);
+ DROPBITS(n);
+ }
+ state = STATEx;
+ case STATEx:
+
+ As shown above, if the next state is also the next case, then the break
+ is omitted.
+
+ A state may also return if there is not enough output space available to
+ complete that state. Those states are copying stored data, writing a
+ literal byte, and copying a matching string.
+
+ When returning, a "goto inf_leave" is used to update the total counters,
+ update the check value, and determine whether any progress has been made
+ during that inflate() call in order to return the proper return code.
+ Progress is defined as a change in either strm->avail_in or strm->avail_out.
+ When there is a window, goto inf_leave will update the window with the last
+ output written. If a goto inf_leave occurs in the middle of decompression
+ and there is no window currently, goto inf_leave will create one and copy
+ output to the window for the next call of inflate().
+
+ In this implementation, the flush parameter of inflate() only affects the
+ return code (per zlib.h). inflate() always writes as much as possible to
+ strm->next_out, given the space available and the provided input--the effect
+ documented in zlib.h of Z_SYNC_FLUSH. Furthermore, inflate() always defers
+ the allocation of and copying into a sliding window until necessary, which
+ provides the effect documented in zlib.h for Z_FINISH when the entire input
+ stream available. So the only thing the flush parameter actually does is:
+ when flush is set to Z_FINISH, inflate() cannot return Z_OK. Instead it
+ will return Z_BUF_ERROR if it has not reached the end of the stream.
+ */
+
+int ZEXPORT inflate(strm, flush)
+z_streamp strm;
+int flush;
+{
+ struct inflate_state FAR *state;
+ z_const unsigned char FAR *next; /* next input */
+ unsigned char FAR *put; /* next output */
+ unsigned have, left; /* available input and output */
+ unsigned long hold; /* bit buffer */
+ unsigned bits; /* bits in bit buffer */
+ unsigned in, out; /* save starting available input and output */
+ unsigned copy; /* number of stored or match bytes to copy */
+ unsigned char FAR *from; /* where to copy match bytes from */
+ code here; /* current decoding table entry */
+ code last; /* parent table entry */
+ unsigned len; /* length to copy for repeats, bits to drop */
+ int ret; /* return code */
+#ifdef GUNZIP
+ unsigned char hbuf[4]; /* buffer for gzip header crc calculation */
+#endif
+ static const unsigned short order[19] = /* permutation of code lengths */
+ {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
+
+ if (strm == Z_NULL || strm->state == Z_NULL || strm->next_out == Z_NULL ||
+ (strm->next_in == Z_NULL && strm->avail_in != 0))
+ return Z_STREAM_ERROR;
+
+ state = (struct inflate_state FAR *)strm->state;
+ if (state->mode == TYPE) state->mode = TYPEDO; /* skip check */
+ LOAD();
+ in = have;
+ out = left;
+ ret = Z_OK;
+ for (;;)
+ switch (state->mode) {
+ case HEAD:
+ if (state->wrap == 0) {
+ state->mode = TYPEDO;
+ break;
+ }
+ NEEDBITS(16);
+#ifdef GUNZIP
+ if ((state->wrap & 2) && hold == 0x8b1f) { /* gzip header */
+ state->check = crc32(0L, Z_NULL, 0);
+ CRC2(state->check, hold);
+ INITBITS();
+ state->mode = FLAGS;
+ break;
+ }
+ state->flags = 0; /* expect zlib header */
+ if (state->head != Z_NULL)
+ state->head->done = -1;
+ if (!(state->wrap & 1) || /* check if zlib header allowed */
+#else
+ if (
+#endif
+ ((BITS(8) << 8) + (hold >> 8)) % 31) {
+ strm->msg = (char *)"incorrect header check";
+ state->mode = BAD;
+ break;
+ }
+ if (BITS(4) != Z_DEFLATED) {
+ strm->msg = (char *)"unknown compression method";
+ state->mode = BAD;
+ break;
+ }
+ DROPBITS(4);
+ len = BITS(4) + 8;
+ if (state->wbits == 0)
+ state->wbits = len;
+ else if (len > state->wbits) {
+ strm->msg = (char *)"invalid window size";
+ state->mode = BAD;
+ break;
+ }
+ state->dmax = 1U << len;
+ Tracev((stderr, "inflate: zlib header ok\n"));
+ strm->adler = state->check = adler32(0L, Z_NULL, 0);
+ state->mode = hold & 0x200 ? DICTID : TYPE;
+ INITBITS();
+ break;
+#ifdef GUNZIP
+ case FLAGS:
+ NEEDBITS(16);
+ state->flags = (int)(hold);
+ if ((state->flags & 0xff) != Z_DEFLATED) {
+ strm->msg = (char *)"unknown compression method";
+ state->mode = BAD;
+ break;
+ }
+ if (state->flags & 0xe000) {
+ strm->msg = (char *)"unknown header flags set";
+ state->mode = BAD;
+ break;
+ }
+ if (state->head != Z_NULL)
+ state->head->text = (int)((hold >> 8) & 1);
+ if (state->flags & 0x0200) CRC2(state->check, hold);
+ INITBITS();
+ state->mode = TIME;
+ case TIME:
+ NEEDBITS(32);
+ if (state->head != Z_NULL)
+ state->head->time = hold;
+ if (state->flags & 0x0200) CRC4(state->check, hold);
+ INITBITS();
+ state->mode = OS;
+ case OS:
+ NEEDBITS(16);
+ if (state->head != Z_NULL) {
+ state->head->xflags = (int)(hold & 0xff);
+ state->head->os = (int)(hold >> 8);
+ }
+ if (state->flags & 0x0200) CRC2(state->check, hold);
+ INITBITS();
+ state->mode = EXLEN;
+ case EXLEN:
+ if (state->flags & 0x0400) {
+ NEEDBITS(16);
+ state->length = (unsigned)(hold);
+ if (state->head != Z_NULL)
+ state->head->extra_len = (unsigned)hold;
+ if (state->flags & 0x0200) CRC2(state->check, hold);
+ INITBITS();
+ }
+ else if (state->head != Z_NULL)
+ state->head->extra = Z_NULL;
+ state->mode = EXTRA;
+ case EXTRA:
+ if (state->flags & 0x0400) {
+ copy = state->length;
+ if (copy > have) copy = have;
+ if (copy) {
+ if (state->head != Z_NULL &&
+ state->head->extra != Z_NULL) {
+ len = state->head->extra_len - state->length;
+ zmemcpy(state->head->extra + len, next,
+ len + copy > state->head->extra_max ?
+ state->head->extra_max - len : copy);
+ }
+ if (state->flags & 0x0200)
+ state->check = crc32(state->check, next, copy);
+ have -= copy;
+ next += copy;
+ state->length -= copy;
+ }
+ if (state->length) goto inf_leave;
+ }
+ state->length = 0;
+ state->mode = NAME;
+ case NAME:
+ if (state->flags & 0x0800) {
+ if (have == 0) goto inf_leave;
+ copy = 0;
+ do {
+ len = (unsigned)(next[copy++]);
+ if (state->head != Z_NULL &&
+ state->head->name != Z_NULL &&
+ state->length < state->head->name_max)
+ state->head->name[state->length++] = len;
+ } while (len && copy < have);
+ if (state->flags & 0x0200)
+ state->check = crc32(state->check, next, copy);
+ have -= copy;
+ next += copy;
+ if (len) goto inf_leave;
+ }
+ else if (state->head != Z_NULL)
+ state->head->name = Z_NULL;
+ state->length = 0;
+ state->mode = COMMENT;
+ case COMMENT:
+ if (state->flags & 0x1000) {
+ if (have == 0) goto inf_leave;
+ copy = 0;
+ do {
+ len = (unsigned)(next[copy++]);
+ if (state->head != Z_NULL &&
+ state->head->comment != Z_NULL &&
+ state->length < state->head->comm_max)
+ state->head->comment[state->length++] = len;
+ } while (len && copy < have);
+ if (state->flags & 0x0200)
+ state->check = crc32(state->check, next, copy);
+ have -= copy;
+ next += copy;
+ if (len) goto inf_leave;
+ }
+ else if (state->head != Z_NULL)
+ state->head->comment = Z_NULL;
+ state->mode = HCRC;
+ case HCRC:
+ if (state->flags & 0x0200) {
+ NEEDBITS(16);
+ if (hold != (state->check & 0xffff)) {
+ strm->msg = (char *)"header crc mismatch";
+ state->mode = BAD;
+ break;
+ }
+ INITBITS();
+ }
+ if (state->head != Z_NULL) {
+ state->head->hcrc = (int)((state->flags >> 9) & 1);
+ state->head->done = 1;
+ }
+ strm->adler = state->check = crc32(0L, Z_NULL, 0);
+ state->mode = TYPE;
+ break;
+#endif
+ case DICTID:
+ NEEDBITS(32);
+ strm->adler = state->check = ZSWAP32(hold);
+ INITBITS();
+ state->mode = DICT;
+ case DICT:
+ if (state->havedict == 0) {
+ RESTORE();
+ return Z_NEED_DICT;
+ }
+ strm->adler = state->check = adler32(0L, Z_NULL, 0);
+ state->mode = TYPE;
+ case TYPE:
+ if (flush == Z_BLOCK || flush == Z_TREES) goto inf_leave;
+ case TYPEDO:
+ if (state->last) {
+ BYTEBITS();
+ state->mode = CHECK;
+ break;
+ }
+ NEEDBITS(3);
+ state->last = BITS(1);
+ DROPBITS(1);
+ switch (BITS(2)) {
+ case 0: /* stored block */
+ Tracev((stderr, "inflate: stored block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = STORED;
+ break;
+ case 1: /* fixed block */
+ fixedtables(state);
+ Tracev((stderr, "inflate: fixed codes block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = LEN_; /* decode codes */
+ if (flush == Z_TREES) {
+ DROPBITS(2);
+ goto inf_leave;
+ }
+ break;
+ case 2: /* dynamic block */
+ Tracev((stderr, "inflate: dynamic codes block%s\n",
+ state->last ? " (last)" : ""));
+ state->mode = TABLE;
+ break;
+ case 3:
+ strm->msg = (char *)"invalid block type";
+ state->mode = BAD;
+ }
+ DROPBITS(2);
+ break;
+ case STORED:
+ BYTEBITS(); /* go to byte boundary */
+ NEEDBITS(32);
+ if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
+ strm->msg = (char *)"invalid stored block lengths";
+ state->mode = BAD;
+ break;
+ }
+ state->length = (unsigned)hold & 0xffff;
+ Tracev((stderr, "inflate: stored length %u\n",
+ state->length));
+ INITBITS();
+ state->mode = COPY_;
+ if (flush == Z_TREES) goto inf_leave;
+ case COPY_:
+ state->mode = COPY;
+ case COPY:
+ copy = state->length;
+ if (copy) {
+ if (copy > have) copy = have;
+ if (copy > left) copy = left;
+ if (copy == 0) goto inf_leave;
+ zmemcpy(put, next, copy);
+ have -= copy;
+ next += copy;
+ left -= copy;
+ put += copy;
+ state->length -= copy;
+ break;
+ }
+ Tracev((stderr, "inflate: stored end\n"));
+ state->mode = TYPE;
+ break;
+ case TABLE:
+ NEEDBITS(14);
+ state->nlen = BITS(5) + 257;
+ DROPBITS(5);
+ state->ndist = BITS(5) + 1;
+ DROPBITS(5);
+ state->ncode = BITS(4) + 4;
+ DROPBITS(4);
+#ifndef PKZIP_BUG_WORKAROUND
+ if (state->nlen > 286 || state->ndist > 30) {
+ strm->msg = (char *)"too many length or distance symbols";
+ state->mode = BAD;
+ break;
+ }
+#endif
+ Tracev((stderr, "inflate: table sizes ok\n"));
+ state->have = 0;
+ state->mode = LENLENS;
+ case LENLENS:
+ while (state->have < state->ncode) {
+ NEEDBITS(3);
+ state->lens[order[state->have++]] = (unsigned short)BITS(3);
+ DROPBITS(3);
+ }
+ while (state->have < 19)
+ state->lens[order[state->have++]] = 0;
+ state->next = state->codes;
+ state->lencode = (const code FAR *)(state->next);
+ state->lenbits = 7;
+ ret = inflate_table(CODES, state->lens, 19, &(state->next),
+ &(state->lenbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid code lengths set";
+ state->mode = BAD;
+ break;
+ }
+ Tracev((stderr, "inflate: code lengths ok\n"));
+ state->have = 0;
+ state->mode = CODELENS;
+ case CODELENS:
+ while (state->have < state->nlen + state->ndist) {
+ for (;;) {
+ here = state->lencode[BITS(state->lenbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if (here.val < 16) {
+ DROPBITS(here.bits);
+ state->lens[state->have++] = here.val;
+ }
+ else {
+ if (here.val == 16) {
+ NEEDBITS(here.bits + 2);
+ DROPBITS(here.bits);
+ if (state->have == 0) {
+ strm->msg = (char *)"invalid bit length repeat";
+ state->mode = BAD;
+ break;
+ }
+ len = state->lens[state->have - 1];
+ copy = 3 + BITS(2);
+ DROPBITS(2);
+ }
+ else if (here.val == 17) {
+ NEEDBITS(here.bits + 3);
+ DROPBITS(here.bits);
+ len = 0;
+ copy = 3 + BITS(3);
+ DROPBITS(3);
+ }
+ else {
+ NEEDBITS(here.bits + 7);
+ DROPBITS(here.bits);
+ len = 0;
+ copy = 11 + BITS(7);
+ DROPBITS(7);
+ }
+ if (state->have + copy > state->nlen + state->ndist) {
+ strm->msg = (char *)"invalid bit length repeat";
+ state->mode = BAD;
+ break;
+ }
+ while (copy--)
+ state->lens[state->have++] = (unsigned short)len;
+ }
+ }
+
+ /* handle error breaks in while */
+ if (state->mode == BAD) break;
+
+ /* check for end-of-block code (better have one) */
+ if (state->lens[256] == 0) {
+ strm->msg = (char *)"invalid code -- missing end-of-block";
+ state->mode = BAD;
+ break;
+ }
+
+ /* build code tables -- note: do not change the lenbits or distbits
+ values here (9 and 6) without reading the comments in inftrees.h
+ concerning the ENOUGH constants, which depend on those values */
+ state->next = state->codes;
+ state->lencode = (const code FAR *)(state->next);
+ state->lenbits = 9;
+ ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
+ &(state->lenbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid literal/lengths set";
+ state->mode = BAD;
+ break;
+ }
+ state->distcode = (const code FAR *)(state->next);
+ state->distbits = 6;
+ ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
+ &(state->next), &(state->distbits), state->work);
+ if (ret) {
+ strm->msg = (char *)"invalid distances set";
+ state->mode = BAD;
+ break;
+ }
+ Tracev((stderr, "inflate: codes ok\n"));
+ state->mode = LEN_;
+ if (flush == Z_TREES) goto inf_leave;
+ case LEN_:
+ state->mode = LEN;
+ case LEN:
+ if (have >= 6 && left >= 258) {
+ RESTORE();
+ inflate_fast(strm, out);
+ LOAD();
+ if (state->mode == TYPE)
+ state->back = -1;
+ break;
+ }
+ state->back = 0;
+ for (;;) {
+ here = state->lencode[BITS(state->lenbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if (here.op && (here.op & 0xf0) == 0) {
+ last = here;
+ for (;;) {
+ here = state->lencode[last.val +
+ (BITS(last.bits + last.op) >> last.bits)];
+ if ((unsigned)(last.bits + here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ DROPBITS(last.bits);
+ state->back += last.bits;
+ }
+ DROPBITS(here.bits);
+ state->back += here.bits;
+ state->length = (unsigned)here.val;
+ if ((int)(here.op) == 0) {
+ Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
+ "inflate: literal '%c'\n" :
+ "inflate: literal 0x%02x\n", here.val));
+ state->mode = LIT;
+ break;
+ }
+ if (here.op & 32) {
+ Tracevv((stderr, "inflate: end of block\n"));
+ state->back = -1;
+ state->mode = TYPE;
+ break;
+ }
+ if (here.op & 64) {
+ strm->msg = (char *)"invalid literal/length code";
+ state->mode = BAD;
+ break;
+ }
+ state->extra = (unsigned)(here.op) & 15;
+ state->mode = LENEXT;
+ case LENEXT:
+ if (state->extra) {
+ NEEDBITS(state->extra);
+ state->length += BITS(state->extra);
+ DROPBITS(state->extra);
+ state->back += state->extra;
+ }
+ Tracevv((stderr, "inflate: length %u\n", state->length));
+ state->was = state->length;
+ state->mode = DIST;
+ case DIST:
+ for (;;) {
+ here = state->distcode[BITS(state->distbits)];
+ if ((unsigned)(here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ if ((here.op & 0xf0) == 0) {
+ last = here;
+ for (;;) {
+ here = state->distcode[last.val +
+ (BITS(last.bits + last.op) >> last.bits)];
+ if ((unsigned)(last.bits + here.bits) <= bits) break;
+ PULLBYTE();
+ }
+ DROPBITS(last.bits);
+ state->back += last.bits;
+ }
+ DROPBITS(here.bits);
+ state->back += here.bits;
+ if (here.op & 64) {
+ strm->msg = (char *)"invalid distance code";
+ state->mode = BAD;
+ break;
+ }
+ state->offset = (unsigned)here.val;
+ state->extra = (unsigned)(here.op) & 15;
+ state->mode = DISTEXT;
+ case DISTEXT:
+ if (state->extra) {
+ NEEDBITS(state->extra);
+ state->offset += BITS(state->extra);
+ DROPBITS(state->extra);
+ state->back += state->extra;
+ }
+#ifdef INFLATE_STRICT
+ if (state->offset > state->dmax) {
+ strm->msg = (char *)"invalid distance too far back";
+ state->mode = BAD;
+ break;
+ }
+#endif
+ Tracevv((stderr, "inflate: distance %u\n", state->offset));
+ state->mode = MATCH;
+ case MATCH:
+ if (left == 0) goto inf_leave;
+ copy = out - left;
+ if (state->offset > copy) { /* copy from window */
+ copy = state->offset - copy;
+ if (copy > state->whave) {
+ if (state->sane) {
+ strm->msg = (char *)"invalid distance too far back";
+ state->mode = BAD;
+ break;
+ }
+#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR
+ Trace((stderr, "inflate.c too far\n"));
+ copy -= state->whave;
+ if (copy > state->length) copy = state->length;
+ if (copy > left) copy = left;
+ left -= copy;
+ state->length -= copy;
+ do {
+ *put++ = 0;
+ } while (--copy);
+ if (state->length == 0) state->mode = LEN;
+ break;
+#endif
+ }
+ if (copy > state->wnext) {
+ copy -= state->wnext;
+ from = state->window + (state->wsize - copy);
+ }
+ else
+ from = state->window + (state->wnext - copy);
+ if (copy > state->length) copy = state->length;
+ }
+ else { /* copy from output */
+ from = put - state->offset;
+ copy = state->length;
+ }
+ if (copy > left) copy = left;
+ left -= copy;
+ state->length -= copy;
+ do {
+ *put++ = *from++;
+ } while (--copy);
+ if (state->length == 0) state->mode = LEN;
+ break;
+ case LIT:
+ if (left == 0) goto inf_leave;
+ *put++ = (unsigned char)(state->length);
+ left--;
+ state->mode = LEN;
+ break;
+ case CHECK:
+ if (state->wrap) {
+ NEEDBITS(32);
+ out -= left;
+ strm->total_out += out;
+ state->total += out;
+ if (out)
+ strm->adler = state->check =
+ UPDATE(state->check, put - out, out);
+ out = left;
+ if ((
+#ifdef GUNZIP
+ state->flags ? hold :
+#endif
+ ZSWAP32(hold)) != state->check) {
+ strm->msg = (char *)"incorrect data check";
+ state->mode = BAD;
+ break;
+ }
+ INITBITS();
+ Tracev((stderr, "inflate: check matches trailer\n"));
+ }
+#ifdef GUNZIP
+ state->mode = LENGTH;
+ case LENGTH:
+ if (state->wrap && state->flags) {
+ NEEDBITS(32);
+ if (hold != (state->total & 0xffffffffUL)) {
+ strm->msg = (char *)"incorrect length check";
+ state->mode = BAD;
+ break;
+ }
+ INITBITS();
+ Tracev((stderr, "inflate: length matches trailer\n"));
+ }
+#endif
+ state->mode = DONE;
+ case DONE:
+ ret = Z_STREAM_END;
+ goto inf_leave;
+ case BAD:
+ ret = Z_DATA_ERROR;
+ goto inf_leave;
+ case MEM:
+ return Z_MEM_ERROR;
+ case SYNC:
+ default:
+ return Z_STREAM_ERROR;
+ }
+
+ /*
+ Return from inflate(), updating the total counts and the check value.
+ If there was no progress during the inflate() call, return a buffer
+ error. Call updatewindow() to create and/or update the window state.
+ Note: a memory error from inflate() is non-recoverable.
+ */
+ inf_leave:
+ RESTORE();
+ if (state->wsize || (out != strm->avail_out && state->mode < BAD &&
+ (state->mode < CHECK || flush != Z_FINISH)))
+ if (updatewindow(strm, strm->next_out, out - strm->avail_out)) {
+ state->mode = MEM;
+ return Z_MEM_ERROR;
+ }
+ in -= strm->avail_in;
+ out -= strm->avail_out;
+ strm->total_in += in;
+ strm->total_out += out;
+ state->total += out;
+ if (state->wrap && out)
+ strm->adler = state->check =
+ UPDATE(state->check, strm->next_out - out, out);
+ strm->data_type = state->bits + (state->last ? 64 : 0) +
+ (state->mode == TYPE ? 128 : 0) +
+ (state->mode == LEN_ || state->mode == COPY_ ? 256 : 0);
+ if (((in == 0 && out == 0) || flush == Z_FINISH) && ret == Z_OK)
+ ret = Z_BUF_ERROR;
+ return ret;
+}
+
+int ZEXPORT inflateEnd(strm)
+z_streamp strm;
+{
+ struct inflate_state FAR *state;
+ if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
+ return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ if (state->window != Z_NULL) ZFREE(strm, state->window);
+ ZFREE(strm, strm->state);
+ strm->state = Z_NULL;
+ Tracev((stderr, "inflate: end\n"));
+ return Z_OK;
+}
+
+int ZEXPORT inflateGetDictionary(strm, dictionary, dictLength)
+z_streamp strm;
+Bytef *dictionary;
+uInt *dictLength;
+{
+ struct inflate_state FAR *state;
+
+ /* check state */
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+
+ /* copy dictionary */
+ if (state->whave && dictionary != Z_NULL) {
+ zmemcpy(dictionary, state->window + state->wnext,
+ state->whave - state->wnext);
+ zmemcpy(dictionary + state->whave - state->wnext,
+ state->window, state->wnext);
+ }
+ if (dictLength != Z_NULL)
+ *dictLength = state->whave;
+ return Z_OK;
+}
+
+int ZEXPORT inflateSetDictionary(strm, dictionary, dictLength)
+z_streamp strm;
+const Bytef *dictionary;
+uInt dictLength;
+{
+ struct inflate_state FAR *state;
+ unsigned long dictid;
+ int ret;
+
+ /* check state */
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ if (state->wrap != 0 && state->mode != DICT)
+ return Z_STREAM_ERROR;
+
+ /* check for correct dictionary identifier */
+ if (state->mode == DICT) {
+ dictid = adler32(0L, Z_NULL, 0);
+ dictid = adler32(dictid, dictionary, dictLength);
+ if (dictid != state->check)
+ return Z_DATA_ERROR;
+ }
+
+ /* copy dictionary to window using updatewindow(), which will amend the
+ existing dictionary if appropriate */
+ ret = updatewindow(strm, dictionary + dictLength, dictLength);
+ if (ret) {
+ state->mode = MEM;
+ return Z_MEM_ERROR;
+ }
+ state->havedict = 1;
+ Tracev((stderr, "inflate: dictionary set\n"));
+ return Z_OK;
+}
+
+int ZEXPORT inflateGetHeader(strm, head)
+z_streamp strm;
+gz_headerp head;
+{
+ struct inflate_state FAR *state;
+
+ /* check state */
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ if ((state->wrap & 2) == 0) return Z_STREAM_ERROR;
+
+ /* save header structure */
+ state->head = head;
+ head->done = 0;
+ return Z_OK;
+}
+
+/*
+ Search buf[0..len-1] for the pattern: 0, 0, 0xff, 0xff. Return when found
+ or when out of input. When called, *have is the number of pattern bytes
+ found in order so far, in 0..3. On return *have is updated to the new
+ state. If on return *have equals four, then the pattern was found and the
+ return value is how many bytes were read including the last byte of the
+ pattern. If *have is less than four, then the pattern has not been found
+ yet and the return value is len. In the latter case, syncsearch() can be
+ called again with more data and the *have state. *have is initialized to
+ zero for the first call.
+ */
+local unsigned syncsearch(have, buf, len)
+unsigned FAR *have;
+const unsigned char FAR *buf;
+unsigned len;
+{
+ unsigned got;
+ unsigned next;
+
+ got = *have;
+ next = 0;
+ while (next < len && got < 4) {
+ if ((int)(buf[next]) == (got < 2 ? 0 : 0xff))
+ got++;
+ else if (buf[next])
+ got = 0;
+ else
+ got = 4 - got;
+ next++;
+ }
+ *have = got;
+ return next;
+}
+
+int ZEXPORT inflateSync(strm)
+z_streamp strm;
+{
+ unsigned len; /* number of bytes to look at or looked at */
+ unsigned long in, out; /* temporary to save total_in and total_out */
+ unsigned char buf[4]; /* to restore bit buffer to byte string */
+ struct inflate_state FAR *state;
+
+ /* check parameters */
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ if (strm->avail_in == 0 && state->bits < 8) return Z_BUF_ERROR;
+
+ /* if first time, start search in bit buffer */
+ if (state->mode != SYNC) {
+ state->mode = SYNC;
+ state->hold <<= state->bits & 7;
+ state->bits -= state->bits & 7;
+ len = 0;
+ while (state->bits >= 8) {
+ buf[len++] = (unsigned char)(state->hold);
+ state->hold >>= 8;
+ state->bits -= 8;
+ }
+ state->have = 0;
+ syncsearch(&(state->have), buf, len);
+ }
+
+ /* search available input */
+ len = syncsearch(&(state->have), strm->next_in, strm->avail_in);
+ strm->avail_in -= len;
+ strm->next_in += len;
+ strm->total_in += len;
+
+ /* return no joy or set up to restart inflate() on a new block */
+ if (state->have != 4) return Z_DATA_ERROR;
+ in = strm->total_in; out = strm->total_out;
+ inflateReset(strm);
+ strm->total_in = in; strm->total_out = out;
+ state->mode = TYPE;
+ return Z_OK;
+}
+
+/*
+ Returns true if inflate is currently at the end of a block generated by
+ Z_SYNC_FLUSH or Z_FULL_FLUSH. This function is used by one PPP
+ implementation to provide an additional safety check. PPP uses
+ Z_SYNC_FLUSH but removes the length bytes of the resulting empty stored
+ block. When decompressing, PPP checks that at the end of input packet,
+ inflate is waiting for these length bytes.
+ */
+int ZEXPORT inflateSyncPoint(strm)
+z_streamp strm;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ return state->mode == STORED && state->bits == 0;
+}
+
+int ZEXPORT inflateCopy(dest, source)
+z_streamp dest;
+z_streamp source;
+{
+ struct inflate_state FAR *state;
+ struct inflate_state FAR *copy;
+ unsigned char FAR *window;
+ unsigned wsize;
+
+ /* check input */
+ if (dest == Z_NULL || source == Z_NULL || source->state == Z_NULL ||
+ source->zalloc == (alloc_func)0 || source->zfree == (free_func)0)
+ return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)source->state;
+
+ /* allocate space */
+ copy = (struct inflate_state FAR *)
+ ZALLOC(source, 1, sizeof(struct inflate_state));
+ if (copy == Z_NULL) return Z_MEM_ERROR;
+ window = Z_NULL;
+ if (state->window != Z_NULL) {
+ window = (unsigned char FAR *)
+ ZALLOC(source, 1U << state->wbits, sizeof(unsigned char));
+ if (window == Z_NULL) {
+ ZFREE(source, copy);
+ return Z_MEM_ERROR;
+ }
+ }
+
+ /* copy state */
+ zmemcpy((voidpf)dest, (voidpf)source, sizeof(z_stream));
+ zmemcpy((voidpf)copy, (voidpf)state, sizeof(struct inflate_state));
+ if (state->lencode >= state->codes &&
+ state->lencode <= state->codes + ENOUGH - 1) {
+ copy->lencode = copy->codes + (state->lencode - state->codes);
+ copy->distcode = copy->codes + (state->distcode - state->codes);
+ }
+ copy->next = copy->codes + (state->next - state->codes);
+ if (window != Z_NULL) {
+ wsize = 1U << state->wbits;
+ zmemcpy(window, state->window, wsize);
+ }
+ copy->window = window;
+ dest->state = (struct internal_state FAR *)copy;
+ return Z_OK;
+}
+
+int ZEXPORT inflateUndermine(strm, subvert)
+z_streamp strm;
+int subvert;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR;
+ state = (struct inflate_state FAR *)strm->state;
+ state->sane = !subvert;
+#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR
+ return Z_OK;
+#else
+ state->sane = 1;
+ return Z_DATA_ERROR;
+#endif
+}
+
+long ZEXPORT inflateMark(strm)
+z_streamp strm;
+{
+ struct inflate_state FAR *state;
+
+ if (strm == Z_NULL || strm->state == Z_NULL) return -1L << 16;
+ state = (struct inflate_state FAR *)strm->state;
+ return ((long)(state->back) << 16) +
+ (state->mode == COPY ? state->length :
+ (state->mode == MATCH ? state->was - state->length : 0));
+}
diff --git a/ml/dlib/dlib/external/zlib/inflate.h b/ml/dlib/dlib/external/zlib/inflate.h
new file mode 100644
index 000000000..95f4986d4
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inflate.h
@@ -0,0 +1,122 @@
+/* inflate.h -- internal inflate state definition
+ * Copyright (C) 1995-2009 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* WARNING: this file should *not* be used by applications. It is
+ part of the implementation of the compression library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+/* define NO_GZIP when compiling if you want to disable gzip header and
+ trailer decoding by inflate(). NO_GZIP would be used to avoid linking in
+ the crc code when it is not needed. For shared libraries, gzip decoding
+ should be left enabled. */
+#ifndef NO_GZIP
+# define GUNZIP
+#endif
+
+/* Possible inflate modes between inflate() calls */
+typedef enum {
+ HEAD, /* i: waiting for magic header */
+ FLAGS, /* i: waiting for method and flags (gzip) */
+ TIME, /* i: waiting for modification time (gzip) */
+ OS, /* i: waiting for extra flags and operating system (gzip) */
+ EXLEN, /* i: waiting for extra length (gzip) */
+ EXTRA, /* i: waiting for extra bytes (gzip) */
+ NAME, /* i: waiting for end of file name (gzip) */
+ COMMENT, /* i: waiting for end of comment (gzip) */
+ HCRC, /* i: waiting for header crc (gzip) */
+ DICTID, /* i: waiting for dictionary check value */
+ DICT, /* waiting for inflateSetDictionary() call */
+ TYPE, /* i: waiting for type bits, including last-flag bit */
+ TYPEDO, /* i: same, but skip check to exit inflate on new block */
+ STORED, /* i: waiting for stored size (length and complement) */
+ COPY_, /* i/o: same as COPY below, but only first time in */
+ COPY, /* i/o: waiting for input or output to copy stored block */
+ TABLE, /* i: waiting for dynamic block table lengths */
+ LENLENS, /* i: waiting for code length code lengths */
+ CODELENS, /* i: waiting for length/lit and distance code lengths */
+ LEN_, /* i: same as LEN below, but only first time in */
+ LEN, /* i: waiting for length/lit/eob code */
+ LENEXT, /* i: waiting for length extra bits */
+ DIST, /* i: waiting for distance code */
+ DISTEXT, /* i: waiting for distance extra bits */
+ MATCH, /* o: waiting for output space to copy string */
+ LIT, /* o: waiting for output space to write literal */
+ CHECK, /* i: waiting for 32-bit check value */
+ LENGTH, /* i: waiting for 32-bit length (gzip) */
+ DONE, /* finished check, done -- remain here until reset */
+ BAD, /* got a data error -- remain here until reset */
+ MEM, /* got an inflate() memory error -- remain here until reset */
+ SYNC /* looking for synchronization bytes to restart inflate() */
+} inflate_mode;
+
+/*
+ State transitions between above modes -
+
+ (most modes can go to BAD or MEM on error -- not shown for clarity)
+
+ Process header:
+ HEAD -> (gzip) or (zlib) or (raw)
+ (gzip) -> FLAGS -> TIME -> OS -> EXLEN -> EXTRA -> NAME -> COMMENT ->
+ HCRC -> TYPE
+ (zlib) -> DICTID or TYPE
+ DICTID -> DICT -> TYPE
+ (raw) -> TYPEDO
+ Read deflate blocks:
+ TYPE -> TYPEDO -> STORED or TABLE or LEN_ or CHECK
+ STORED -> COPY_ -> COPY -> TYPE
+ TABLE -> LENLENS -> CODELENS -> LEN_
+ LEN_ -> LEN
+ Read deflate codes in fixed or dynamic block:
+ LEN -> LENEXT or LIT or TYPE
+ LENEXT -> DIST -> DISTEXT -> MATCH -> LEN
+ LIT -> LEN
+ Process trailer:
+ CHECK -> LENGTH -> DONE
+ */
+
+/* state maintained between inflate() calls. Approximately 10K bytes. */
+struct inflate_state {
+ inflate_mode mode; /* current inflate mode */
+ int last; /* true if processing last block */
+ int wrap; /* bit 0 true for zlib, bit 1 true for gzip */
+ int havedict; /* true if dictionary provided */
+ int flags; /* gzip header method and flags (0 if zlib) */
+ unsigned dmax; /* zlib header max distance (INFLATE_STRICT) */
+ unsigned long check; /* protected copy of check value */
+ unsigned long total; /* protected copy of output count */
+ gz_headerp head; /* where to save gzip header information */
+ /* sliding window */
+ unsigned wbits; /* log base 2 of requested window size */
+ unsigned wsize; /* window size or zero if not using window */
+ unsigned whave; /* valid bytes in the window */
+ unsigned wnext; /* window write index */
+ unsigned char FAR *window; /* allocated sliding window, if needed */
+ /* bit accumulator */
+ unsigned long hold; /* input bit accumulator */
+ unsigned bits; /* number of bits in "in" */
+ /* for string and stored block copying */
+ unsigned length; /* literal or length of data to copy */
+ unsigned offset; /* distance back to copy string from */
+ /* for table and code decoding */
+ unsigned extra; /* extra bits needed */
+ /* fixed and dynamic code tables */
+ code const FAR *lencode; /* starting table for length/literal codes */
+ code const FAR *distcode; /* starting table for distance codes */
+ unsigned lenbits; /* index bits for lencode */
+ unsigned distbits; /* index bits for distcode */
+ /* dynamic table building */
+ unsigned ncode; /* number of code length code lengths */
+ unsigned nlen; /* number of length code lengths */
+ unsigned ndist; /* number of distance code lengths */
+ unsigned have; /* number of code lengths in lens[] */
+ code FAR *next; /* next available space in codes[] */
+ unsigned short lens[320]; /* temporary storage for code lengths */
+ unsigned short work[288]; /* work area for code table building */
+ code codes[ENOUGH]; /* space for code tables */
+ int sane; /* if false, allow invalid distance too far */
+ int back; /* bits back of last unprocessed length/lit */
+ unsigned was; /* initial length of match */
+};
diff --git a/ml/dlib/dlib/external/zlib/inftrees.c b/ml/dlib/dlib/external/zlib/inftrees.c
new file mode 100644
index 000000000..44d89cf24
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inftrees.c
@@ -0,0 +1,306 @@
+/* inftrees.c -- generate Huffman trees for efficient decoding
+ * Copyright (C) 1995-2013 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+#include "zutil.h"
+#include "inftrees.h"
+
+#define MAXBITS 15
+
+const char inflate_copyright[] =
+ " inflate 1.2.8 Copyright 1995-2013 Mark Adler ";
+/*
+ If you use the zlib library in a product, an acknowledgment is welcome
+ in the documentation of your product. If for some reason you cannot
+ include such an acknowledgment, I would appreciate that you keep this
+ copyright string in the executable of your product.
+ */
+
+/*
+ Build a set of tables to decode the provided canonical Huffman code.
+ The code lengths are lens[0..codes-1]. The result starts at *table,
+ whose indices are 0..2^bits-1. work is a writable array of at least
+ lens shorts, which is used as a work area. type is the type of code
+ to be generated, CODES, LENS, or DISTS. On return, zero is success,
+ -1 is an invalid code, and +1 means that ENOUGH isn't enough. table
+ on return points to the next available entry's address. bits is the
+ requested root table index bits, and on return it is the actual root
+ table index bits. It will differ if the request is greater than the
+ longest code or if it is less than the shortest code.
+ */
+int ZLIB_INTERNAL inflate_table(type, lens, codes, table, bits, work)
+codetype type;
+unsigned short FAR *lens;
+unsigned codes;
+code FAR * FAR *table;
+unsigned FAR *bits;
+unsigned short FAR *work;
+{
+ unsigned len; /* a code's length in bits */
+ unsigned sym; /* index of code symbols */
+ unsigned min, max; /* minimum and maximum code lengths */
+ unsigned root; /* number of index bits for root table */
+ unsigned curr; /* number of index bits for current table */
+ unsigned drop; /* code bits to drop for sub-table */
+ int left; /* number of prefix codes available */
+ unsigned used; /* code entries in table used */
+ unsigned huff; /* Huffman code */
+ unsigned incr; /* for incrementing code, index */
+ unsigned fill; /* index for replicating entries */
+ unsigned low; /* low bits for current root entry */
+ unsigned mask; /* mask for low root bits */
+ code here; /* table entry for duplication */
+ code FAR *next; /* next available space in table */
+ const unsigned short FAR *base; /* base value table to use */
+ const unsigned short FAR *extra; /* extra bits table to use */
+ int end; /* use base and extra for symbol > end */
+ unsigned short count[MAXBITS+1]; /* number of codes of each length */
+ unsigned short offs[MAXBITS+1]; /* offsets in table for each length */
+ static const unsigned short lbase[31] = { /* Length codes 257..285 base */
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
+ 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0};
+ static const unsigned short lext[31] = { /* Length codes 257..285 extra */
+ 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18,
+ 19, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 16, 72, 78};
+ static const unsigned short dbase[32] = { /* Distance codes 0..29 base */
+ 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
+ 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
+ 8193, 12289, 16385, 24577, 0, 0};
+ static const unsigned short dext[32] = { /* Distance codes 0..29 extra */
+ 16, 16, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22,
+ 23, 23, 24, 24, 25, 25, 26, 26, 27, 27,
+ 28, 28, 29, 29, 64, 64};
+
+ /*
+ Process a set of code lengths to create a canonical Huffman code. The
+ code lengths are lens[0..codes-1]. Each length corresponds to the
+ symbols 0..codes-1. The Huffman code is generated by first sorting the
+ symbols by length from short to long, and retaining the symbol order
+ for codes with equal lengths. Then the code starts with all zero bits
+ for the first code of the shortest length, and the codes are integer
+ increments for the same length, and zeros are appended as the length
+ increases. For the deflate format, these bits are stored backwards
+ from their more natural integer increment ordering, and so when the
+ decoding tables are built in the large loop below, the integer codes
+ are incremented backwards.
+
+ This routine assumes, but does not check, that all of the entries in
+ lens[] are in the range 0..MAXBITS. The caller must assure this.
+ 1..MAXBITS is interpreted as that code length. zero means that that
+ symbol does not occur in this code.
+
+ The codes are sorted by computing a count of codes for each length,
+ creating from that a table of starting indices for each length in the
+ sorted table, and then entering the symbols in order in the sorted
+ table. The sorted table is work[], with that space being provided by
+ the caller.
+
+ The length counts are used for other purposes as well, i.e. finding
+ the minimum and maximum length codes, determining if there are any
+ codes at all, checking for a valid set of lengths, and looking ahead
+ at length counts to determine sub-table sizes when building the
+ decoding tables.
+ */
+
+ /* accumulate lengths for codes (assumes lens[] all in 0..MAXBITS) */
+ for (len = 0; len <= MAXBITS; len++)
+ count[len] = 0;
+ for (sym = 0; sym < codes; sym++)
+ count[lens[sym]]++;
+
+ /* bound code lengths, force root to be within code lengths */
+ root = *bits;
+ for (max = MAXBITS; max >= 1; max--)
+ if (count[max] != 0) break;
+ if (root > max) root = max;
+ if (max == 0) { /* no symbols to code at all */
+ here.op = (unsigned char)64; /* invalid code marker */
+ here.bits = (unsigned char)1;
+ here.val = (unsigned short)0;
+ *(*table)++ = here; /* make a table to force an error */
+ *(*table)++ = here;
+ *bits = 1;
+ return 0; /* no symbols, but wait for decoding to report error */
+ }
+ for (min = 1; min < max; min++)
+ if (count[min] != 0) break;
+ if (root < min) root = min;
+
+ /* check for an over-subscribed or incomplete set of lengths */
+ left = 1;
+ for (len = 1; len <= MAXBITS; len++) {
+ left <<= 1;
+ left -= count[len];
+ if (left < 0) return -1; /* over-subscribed */
+ }
+ if (left > 0 && (type == CODES || max != 1))
+ return -1; /* incomplete set */
+
+ /* generate offsets into symbol table for each length for sorting */
+ offs[1] = 0;
+ for (len = 1; len < MAXBITS; len++)
+ offs[len + 1] = offs[len] + count[len];
+
+ /* sort symbols by length, by symbol order within each length */
+ for (sym = 0; sym < codes; sym++)
+ if (lens[sym] != 0) work[offs[lens[sym]]++] = (unsigned short)sym;
+
+ /*
+ Create and fill in decoding tables. In this loop, the table being
+ filled is at next and has curr index bits. The code being used is huff
+ with length len. That code is converted to an index by dropping drop
+ bits off of the bottom. For codes where len is less than drop + curr,
+ those top drop + curr - len bits are incremented through all values to
+ fill the table with replicated entries.
+
+ root is the number of index bits for the root table. When len exceeds
+ root, sub-tables are created pointed to by the root entry with an index
+ of the low root bits of huff. This is saved in low to check for when a
+ new sub-table should be started. drop is zero when the root table is
+ being filled, and drop is root when sub-tables are being filled.
+
+ When a new sub-table is needed, it is necessary to look ahead in the
+ code lengths to determine what size sub-table is needed. The length
+ counts are used for this, and so count[] is decremented as codes are
+ entered in the tables.
+
+ used keeps track of how many table entries have been allocated from the
+ provided *table space. It is checked for LENS and DIST tables against
+ the constants ENOUGH_LENS and ENOUGH_DISTS to guard against changes in
+ the initial root table size constants. See the comments in inftrees.h
+ for more information.
+
+ sym increments through all symbols, and the loop terminates when
+ all codes of length max, i.e. all codes, have been processed. This
+ routine permits incomplete codes, so another loop after this one fills
+ in the rest of the decoding tables with invalid code markers.
+ */
+
+ /* set up for code type */
+ switch (type) {
+ case CODES:
+ base = extra = work; /* dummy value--not used */
+ end = 19;
+ break;
+ case LENS:
+ base = lbase;
+ base -= 257;
+ extra = lext;
+ extra -= 257;
+ end = 256;
+ break;
+ default: /* DISTS */
+ base = dbase;
+ extra = dext;
+ end = -1;
+ }
+
+ /* initialize state for loop */
+ huff = 0; /* starting code */
+ sym = 0; /* starting code symbol */
+ len = min; /* starting code length */
+ next = *table; /* current table to fill in */
+ curr = root; /* current table index bits */
+ drop = 0; /* current bits to drop from code for index */
+ low = (unsigned)(-1); /* trigger new sub-table when len > root */
+ used = 1U << root; /* use root table entries */
+ mask = used - 1; /* mask for comparing low */
+
+ /* check available table space */
+ if ((type == LENS && used > ENOUGH_LENS) ||
+ (type == DISTS && used > ENOUGH_DISTS))
+ return 1;
+
+ /* process all codes and make table entries */
+ for (;;) {
+ /* create table entry */
+ here.bits = (unsigned char)(len - drop);
+ if ((int)(work[sym]) < end) {
+ here.op = (unsigned char)0;
+ here.val = work[sym];
+ }
+ else if ((int)(work[sym]) > end) {
+ here.op = (unsigned char)(extra[work[sym]]);
+ here.val = base[work[sym]];
+ }
+ else {
+ here.op = (unsigned char)(32 + 64); /* end of block */
+ here.val = 0;
+ }
+
+ /* replicate for those indices with low len bits equal to huff */
+ incr = 1U << (len - drop);
+ fill = 1U << curr;
+ min = fill; /* save offset to next table */
+ do {
+ fill -= incr;
+ next[(huff >> drop) + fill] = here;
+ } while (fill != 0);
+
+ /* backwards increment the len-bit code huff */
+ incr = 1U << (len - 1);
+ while (huff & incr)
+ incr >>= 1;
+ if (incr != 0) {
+ huff &= incr - 1;
+ huff += incr;
+ }
+ else
+ huff = 0;
+
+ /* go to next symbol, update count, len */
+ sym++;
+ if (--(count[len]) == 0) {
+ if (len == max) break;
+ len = lens[work[sym]];
+ }
+
+ /* create new sub-table if needed */
+ if (len > root && (huff & mask) != low) {
+ /* if first time, transition to sub-tables */
+ if (drop == 0)
+ drop = root;
+
+ /* increment past last table */
+ next += min; /* here min is 1 << curr */
+
+ /* determine length of next table */
+ curr = len - drop;
+ left = (int)(1 << curr);
+ while (curr + drop < max) {
+ left -= count[curr + drop];
+ if (left <= 0) break;
+ curr++;
+ left <<= 1;
+ }
+
+ /* check for enough space */
+ used += 1U << curr;
+ if ((type == LENS && used > ENOUGH_LENS) ||
+ (type == DISTS && used > ENOUGH_DISTS))
+ return 1;
+
+ /* point entry in root table to sub-table */
+ low = huff & mask;
+ (*table)[low].op = (unsigned char)curr;
+ (*table)[low].bits = (unsigned char)root;
+ (*table)[low].val = (unsigned short)(next - *table);
+ }
+ }
+
+ /* fill in remaining table entry if code is incomplete (guaranteed to have
+ at most one remaining entry, since if the code is incomplete, the
+ maximum code length that was allowed to get this far is one bit) */
+ if (huff != 0) {
+ here.op = (unsigned char)64; /* invalid code marker */
+ here.bits = (unsigned char)(len - drop);
+ here.val = (unsigned short)0;
+ next[huff] = here;
+ }
+
+ /* set return parameters */
+ *table += used;
+ *bits = root;
+ return 0;
+}
diff --git a/ml/dlib/dlib/external/zlib/inftrees.h b/ml/dlib/dlib/external/zlib/inftrees.h
new file mode 100644
index 000000000..baa53a0b1
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/inftrees.h
@@ -0,0 +1,62 @@
+/* inftrees.h -- header to use inftrees.c
+ * Copyright (C) 1995-2005, 2010 Mark Adler
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* WARNING: this file should *not* be used by applications. It is
+ part of the implementation of the compression library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+/* Structure for decoding tables. Each entry provides either the
+ information needed to do the operation requested by the code that
+ indexed that table entry, or it provides a pointer to another
+ table that indexes more bits of the code. op indicates whether
+ the entry is a pointer to another table, a literal, a length or
+ distance, an end-of-block, or an invalid code. For a table
+ pointer, the low four bits of op is the number of index bits of
+ that table. For a length or distance, the low four bits of op
+ is the number of extra bits to get after the code. bits is
+ the number of bits in this code or part of the code to drop off
+ of the bit buffer. val is the actual byte to output in the case
+ of a literal, the base length or distance, or the offset from
+ the current table to the next table. Each entry is four bytes. */
+typedef struct {
+ unsigned char op; /* operation, extra bits, table bits */
+ unsigned char bits; /* bits in this part of the code */
+ unsigned short val; /* offset in table or code value */
+} code;
+
+/* op values as set by inflate_table():
+ 00000000 - literal
+ 0000tttt - table link, tttt != 0 is the number of table index bits
+ 0001eeee - length or distance, eeee is the number of extra bits
+ 01100000 - end of block
+ 01000000 - invalid code
+ */
+
+/* Maximum size of the dynamic table. The maximum number of code structures is
+ 1444, which is the sum of 852 for literal/length codes and 592 for distance
+ codes. These values were found by exhaustive searches using the program
+ examples/enough.c found in the zlib distribtution. The arguments to that
+ program are the number of symbols, the initial root table size, and the
+ maximum bit length of a code. "enough 286 9 15" for literal/length codes
+ returns returns 852, and "enough 30 6 15" for distance codes returns 592.
+ The initial root table size (9 or 6) is found in the fifth argument of the
+ inflate_table() calls in inflate.c and infback.c. If the root table size is
+ changed, then these maximum sizes would be need to be recalculated and
+ updated. */
+#define ENOUGH_LENS 852
+#define ENOUGH_DISTS 592
+#define ENOUGH (ENOUGH_LENS+ENOUGH_DISTS)
+
+/* Type of code to build for inflate_table() */
+typedef enum {
+ CODES,
+ LENS,
+ DISTS
+} codetype;
+
+int ZLIB_INTERNAL inflate_table OF((codetype type, unsigned short FAR *lens,
+ unsigned codes, code FAR * FAR *table,
+ unsigned FAR *bits, unsigned short FAR *work));
diff --git a/ml/dlib/dlib/external/zlib/trees.c b/ml/dlib/dlib/external/zlib/trees.c
new file mode 100644
index 000000000..1fd7759ef
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/trees.c
@@ -0,0 +1,1226 @@
+/* trees.c -- output deflated data using Huffman coding
+ * Copyright (C) 1995-2012 Jean-loup Gailly
+ * detect_data_type() function provided freely by Cosmin Truta, 2006
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/*
+ * ALGORITHM
+ *
+ * The "deflation" process uses several Huffman trees. The more
+ * common source values are represented by shorter bit sequences.
+ *
+ * Each code tree is stored in a compressed form which is itself
+ * a Huffman encoding of the lengths of all the code strings (in
+ * ascending order by source values). The actual code strings are
+ * reconstructed from the lengths in the inflate process, as described
+ * in the deflate specification.
+ *
+ * REFERENCES
+ *
+ * Deutsch, L.P.,"'Deflate' Compressed Data Format Specification".
+ * Available in ftp.uu.net:/pub/archiving/zip/doc/deflate-1.1.doc
+ *
+ * Storer, James A.
+ * Data Compression: Methods and Theory, pp. 49-50.
+ * Computer Science Press, 1988. ISBN 0-7167-8156-5.
+ *
+ * Sedgewick, R.
+ * Algorithms, p290.
+ * Addison-Wesley, 1983. ISBN 0-201-06672-6.
+ */
+
+/* @(#) $Id$ */
+
+/* #define GEN_TREES_H */
+
+#include "deflate.h"
+
+#ifdef DEBUG
+# include <ctype.h>
+#endif
+
+/* ===========================================================================
+ * Constants
+ */
+
+#define MAX_BL_BITS 7
+/* Bit length codes must not exceed MAX_BL_BITS bits */
+
+#define END_BLOCK 256
+/* end of block literal code */
+
+#define REP_3_6 16
+/* repeat previous bit length 3-6 times (2 bits of repeat count) */
+
+#define REPZ_3_10 17
+/* repeat a zero length 3-10 times (3 bits of repeat count) */
+
+#define REPZ_11_138 18
+/* repeat a zero length 11-138 times (7 bits of repeat count) */
+
+local const int extra_lbits[LENGTH_CODES] /* extra bits for each length code */
+ = {0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0};
+
+local const int extra_dbits[D_CODES] /* extra bits for each distance code */
+ = {0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13};
+
+local const int extra_blbits[BL_CODES]/* extra bits for each bit length code */
+ = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7};
+
+local const uch bl_order[BL_CODES]
+ = {16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15};
+/* The lengths of the bit length codes are sent in order of decreasing
+ * probability, to avoid transmitting the lengths for unused bit length codes.
+ */
+
+/* ===========================================================================
+ * Local data. These are initialized only once.
+ */
+
+#define DIST_CODE_LEN 512 /* see definition of array dist_code below */
+
+#if defined(GEN_TREES_H) || !defined(STDC)
+/* non ANSI compilers may not accept trees.h */
+
+local ct_data static_ltree[L_CODES+2];
+/* The static literal tree. Since the bit lengths are imposed, there is no
+ * need for the L_CODES extra codes used during heap construction. However
+ * The codes 286 and 287 are needed to build a canonical tree (see _tr_init
+ * below).
+ */
+
+local ct_data static_dtree[D_CODES];
+/* The static distance tree. (Actually a trivial tree since all codes use
+ * 5 bits.)
+ */
+
+uch _dist_code[DIST_CODE_LEN];
+/* Distance codes. The first 256 values correspond to the distances
+ * 3 .. 258, the last 256 values correspond to the top 8 bits of
+ * the 15 bit distances.
+ */
+
+uch _length_code[MAX_MATCH-MIN_MATCH+1];
+/* length code for each normalized match length (0 == MIN_MATCH) */
+
+local int base_length[LENGTH_CODES];
+/* First normalized length for each code (0 = MIN_MATCH) */
+
+local int base_dist[D_CODES];
+/* First normalized distance for each code (0 = distance of 1) */
+
+#else
+# include "trees.h"
+#endif /* GEN_TREES_H */
+
+struct static_tree_desc_s {
+ const ct_data *static_tree; /* static tree or NULL */
+ const intf *extra_bits; /* extra bits for each code or NULL */
+ int extra_base; /* base index for extra_bits */
+ int elems; /* max number of elements in the tree */
+ int max_length; /* max bit length for the codes */
+};
+
+local static_tree_desc static_l_desc =
+{static_ltree, extra_lbits, LITERALS+1, L_CODES, MAX_BITS};
+
+local static_tree_desc static_d_desc =
+{static_dtree, extra_dbits, 0, D_CODES, MAX_BITS};
+
+local static_tree_desc static_bl_desc =
+{(const ct_data *)0, extra_blbits, 0, BL_CODES, MAX_BL_BITS};
+
+/* ===========================================================================
+ * Local (static) routines in this file.
+ */
+
+local void tr_static_init OF((void));
+local void init_block OF((deflate_state *s));
+local void pqdownheap OF((deflate_state *s, ct_data *tree, int k));
+local void gen_bitlen OF((deflate_state *s, tree_desc *desc));
+local void gen_codes OF((ct_data *tree, int max_code, ushf *bl_count));
+local void build_tree OF((deflate_state *s, tree_desc *desc));
+local void scan_tree OF((deflate_state *s, ct_data *tree, int max_code));
+local void send_tree OF((deflate_state *s, ct_data *tree, int max_code));
+local int build_bl_tree OF((deflate_state *s));
+local void send_all_trees OF((deflate_state *s, int lcodes, int dcodes,
+ int blcodes));
+local void compress_block OF((deflate_state *s, const ct_data *ltree,
+ const ct_data *dtree));
+local int detect_data_type OF((deflate_state *s));
+local unsigned bi_reverse OF((unsigned value, int length));
+local void bi_windup OF((deflate_state *s));
+local void bi_flush OF((deflate_state *s));
+local void copy_block OF((deflate_state *s, charf *buf, unsigned len,
+ int header));
+
+#ifdef GEN_TREES_H
+local void gen_trees_header OF((void));
+#endif
+
+#ifndef DEBUG
+# define send_code(s, c, tree) send_bits(s, tree[c].Code, tree[c].Len)
+ /* Send a code of the given tree. c and tree must not have side effects */
+
+#else /* DEBUG */
+# define send_code(s, c, tree) \
+ { if (z_verbose>2) fprintf(stderr,"\ncd %3d ",(c)); \
+ send_bits(s, tree[c].Code, tree[c].Len); }
+#endif
+
+/* ===========================================================================
+ * Output a short LSB first on the stream.
+ * IN assertion: there is enough room in pendingBuf.
+ */
+#define put_short(s, w) { \
+ put_byte(s, (uch)((w) & 0xff)); \
+ put_byte(s, (uch)((ush)(w) >> 8)); \
+}
+
+/* ===========================================================================
+ * Send a value on a given number of bits.
+ * IN assertion: length <= 16 and value fits in length bits.
+ */
+#ifdef DEBUG
+local void send_bits OF((deflate_state *s, int value, int length));
+
+local void send_bits(s, value, length)
+ deflate_state *s;
+ int value; /* value to send */
+ int length; /* number of bits */
+{
+ Tracevv((stderr," l %2d v %4x ", length, value));
+ Assert(length > 0 && length <= 15, "invalid length");
+ s->bits_sent += (ulg)length;
+
+ /* If not enough room in bi_buf, use (valid) bits from bi_buf and
+ * (16 - bi_valid) bits from value, leaving (width - (16-bi_valid))
+ * unused bits in value.
+ */
+ if (s->bi_valid > (int)Buf_size - length) {
+ s->bi_buf |= (ush)value << s->bi_valid;
+ put_short(s, s->bi_buf);
+ s->bi_buf = (ush)value >> (Buf_size - s->bi_valid);
+ s->bi_valid += length - Buf_size;
+ } else {
+ s->bi_buf |= (ush)value << s->bi_valid;
+ s->bi_valid += length;
+ }
+}
+#else /* !DEBUG */
+
+#define send_bits(s, value, length) \
+{ int len = length;\
+ if (s->bi_valid > (int)Buf_size - len) {\
+ int val = value;\
+ s->bi_buf |= (ush)val << s->bi_valid;\
+ put_short(s, s->bi_buf);\
+ s->bi_buf = (ush)val >> (Buf_size - s->bi_valid);\
+ s->bi_valid += len - Buf_size;\
+ } else {\
+ s->bi_buf |= (ush)(value) << s->bi_valid;\
+ s->bi_valid += len;\
+ }\
+}
+#endif /* DEBUG */
+
+
+/* the arguments must not have side effects */
+
+/* ===========================================================================
+ * Initialize the various 'constant' tables.
+ */
+local void tr_static_init()
+{
+#if defined(GEN_TREES_H) || !defined(STDC)
+ static int static_init_done = 0;
+ int n; /* iterates over tree elements */
+ int bits; /* bit counter */
+ int length; /* length value */
+ int code; /* code value */
+ int dist; /* distance index */
+ ush bl_count[MAX_BITS+1];
+ /* number of codes at each bit length for an optimal tree */
+
+ if (static_init_done) return;
+
+ /* For some embedded targets, global variables are not initialized: */
+#ifdef NO_INIT_GLOBAL_POINTERS
+ static_l_desc.static_tree = static_ltree;
+ static_l_desc.extra_bits = extra_lbits;
+ static_d_desc.static_tree = static_dtree;
+ static_d_desc.extra_bits = extra_dbits;
+ static_bl_desc.extra_bits = extra_blbits;
+#endif
+
+ /* Initialize the mapping length (0..255) -> length code (0..28) */
+ length = 0;
+ for (code = 0; code < LENGTH_CODES-1; code++) {
+ base_length[code] = length;
+ for (n = 0; n < (1<<extra_lbits[code]); n++) {
+ _length_code[length++] = (uch)code;
+ }
+ }
+ Assert (length == 256, "tr_static_init: length != 256");
+ /* Note that the length 255 (match length 258) can be represented
+ * in two different ways: code 284 + 5 bits or code 285, so we
+ * overwrite length_code[255] to use the best encoding:
+ */
+ _length_code[length-1] = (uch)code;
+
+ /* Initialize the mapping dist (0..32K) -> dist code (0..29) */
+ dist = 0;
+ for (code = 0 ; code < 16; code++) {
+ base_dist[code] = dist;
+ for (n = 0; n < (1<<extra_dbits[code]); n++) {
+ _dist_code[dist++] = (uch)code;
+ }
+ }
+ Assert (dist == 256, "tr_static_init: dist != 256");
+ dist >>= 7; /* from now on, all distances are divided by 128 */
+ for ( ; code < D_CODES; code++) {
+ base_dist[code] = dist << 7;
+ for (n = 0; n < (1<<(extra_dbits[code]-7)); n++) {
+ _dist_code[256 + dist++] = (uch)code;
+ }
+ }
+ Assert (dist == 256, "tr_static_init: 256+dist != 512");
+
+ /* Construct the codes of the static literal tree */
+ for (bits = 0; bits <= MAX_BITS; bits++) bl_count[bits] = 0;
+ n = 0;
+ while (n <= 143) static_ltree[n++].Len = 8, bl_count[8]++;
+ while (n <= 255) static_ltree[n++].Len = 9, bl_count[9]++;
+ while (n <= 279) static_ltree[n++].Len = 7, bl_count[7]++;
+ while (n <= 287) static_ltree[n++].Len = 8, bl_count[8]++;
+ /* Codes 286 and 287 do not exist, but we must include them in the
+ * tree construction to get a canonical Huffman tree (longest code
+ * all ones)
+ */
+ gen_codes((ct_data *)static_ltree, L_CODES+1, bl_count);
+
+ /* The static distance tree is trivial: */
+ for (n = 0; n < D_CODES; n++) {
+ static_dtree[n].Len = 5;
+ static_dtree[n].Code = bi_reverse((unsigned)n, 5);
+ }
+ static_init_done = 1;
+
+# ifdef GEN_TREES_H
+ gen_trees_header();
+# endif
+#endif /* defined(GEN_TREES_H) || !defined(STDC) */
+}
+
+/* ===========================================================================
+ * Genererate the file trees.h describing the static trees.
+ */
+#ifdef GEN_TREES_H
+# ifndef DEBUG
+# include <stdio.h>
+# endif
+
+# define SEPARATOR(i, last, width) \
+ ((i) == (last)? "\n};\n\n" : \
+ ((i) % (width) == (width)-1 ? ",\n" : ", "))
+
+void gen_trees_header()
+{
+ FILE *header = fopen("trees.h", "w");
+ int i;
+
+ Assert (header != NULL, "Can't open trees.h");
+ fprintf(header,
+ "/* header created automatically with -DGEN_TREES_H */\n\n");
+
+ fprintf(header, "local const ct_data static_ltree[L_CODES+2] = {\n");
+ for (i = 0; i < L_CODES+2; i++) {
+ fprintf(header, "{{%3u},{%3u}}%s", static_ltree[i].Code,
+ static_ltree[i].Len, SEPARATOR(i, L_CODES+1, 5));
+ }
+
+ fprintf(header, "local const ct_data static_dtree[D_CODES] = {\n");
+ for (i = 0; i < D_CODES; i++) {
+ fprintf(header, "{{%2u},{%2u}}%s", static_dtree[i].Code,
+ static_dtree[i].Len, SEPARATOR(i, D_CODES-1, 5));
+ }
+
+ fprintf(header, "const uch ZLIB_INTERNAL _dist_code[DIST_CODE_LEN] = {\n");
+ for (i = 0; i < DIST_CODE_LEN; i++) {
+ fprintf(header, "%2u%s", _dist_code[i],
+ SEPARATOR(i, DIST_CODE_LEN-1, 20));
+ }
+
+ fprintf(header,
+ "const uch ZLIB_INTERNAL _length_code[MAX_MATCH-MIN_MATCH+1]= {\n");
+ for (i = 0; i < MAX_MATCH-MIN_MATCH+1; i++) {
+ fprintf(header, "%2u%s", _length_code[i],
+ SEPARATOR(i, MAX_MATCH-MIN_MATCH, 20));
+ }
+
+ fprintf(header, "local const int base_length[LENGTH_CODES] = {\n");
+ for (i = 0; i < LENGTH_CODES; i++) {
+ fprintf(header, "%1u%s", base_length[i],
+ SEPARATOR(i, LENGTH_CODES-1, 20));
+ }
+
+ fprintf(header, "local const int base_dist[D_CODES] = {\n");
+ for (i = 0; i < D_CODES; i++) {
+ fprintf(header, "%5u%s", base_dist[i],
+ SEPARATOR(i, D_CODES-1, 10));
+ }
+
+ fclose(header);
+}
+#endif /* GEN_TREES_H */
+
+/* ===========================================================================
+ * Initialize the tree data structures for a new zlib stream.
+ */
+void ZLIB_INTERNAL _tr_init(s)
+ deflate_state *s;
+{
+ tr_static_init();
+
+ s->l_desc.dyn_tree = s->dyn_ltree;
+ s->l_desc.stat_desc = &static_l_desc;
+
+ s->d_desc.dyn_tree = s->dyn_dtree;
+ s->d_desc.stat_desc = &static_d_desc;
+
+ s->bl_desc.dyn_tree = s->bl_tree;
+ s->bl_desc.stat_desc = &static_bl_desc;
+
+ s->bi_buf = 0;
+ s->bi_valid = 0;
+#ifdef DEBUG
+ s->compressed_len = 0L;
+ s->bits_sent = 0L;
+#endif
+
+ /* Initialize the first block of the first file: */
+ init_block(s);
+}
+
+/* ===========================================================================
+ * Initialize a new block.
+ */
+local void init_block(s)
+ deflate_state *s;
+{
+ int n; /* iterates over tree elements */
+
+ /* Initialize the trees. */
+ for (n = 0; n < L_CODES; n++) s->dyn_ltree[n].Freq = 0;
+ for (n = 0; n < D_CODES; n++) s->dyn_dtree[n].Freq = 0;
+ for (n = 0; n < BL_CODES; n++) s->bl_tree[n].Freq = 0;
+
+ s->dyn_ltree[END_BLOCK].Freq = 1;
+ s->opt_len = s->static_len = 0L;
+ s->last_lit = s->matches = 0;
+}
+
+#define SMALLEST 1
+/* Index within the heap array of least frequent node in the Huffman tree */
+
+
+/* ===========================================================================
+ * Remove the smallest element from the heap and recreate the heap with
+ * one less element. Updates heap and heap_len.
+ */
+#define pqremove(s, tree, top) \
+{\
+ top = s->heap[SMALLEST]; \
+ s->heap[SMALLEST] = s->heap[s->heap_len--]; \
+ pqdownheap(s, tree, SMALLEST); \
+}
+
+/* ===========================================================================
+ * Compares to subtrees, using the tree depth as tie breaker when
+ * the subtrees have equal frequency. This minimizes the worst case length.
+ */
+#define smaller(tree, n, m, depth) \
+ (tree[n].Freq < tree[m].Freq || \
+ (tree[n].Freq == tree[m].Freq && depth[n] <= depth[m]))
+
+/* ===========================================================================
+ * Restore the heap property by moving down the tree starting at node k,
+ * exchanging a node with the smallest of its two sons if necessary, stopping
+ * when the heap property is re-established (each father smaller than its
+ * two sons).
+ */
+local void pqdownheap(s, tree, k)
+ deflate_state *s;
+ ct_data *tree; /* the tree to restore */
+ int k; /* node to move down */
+{
+ int v = s->heap[k];
+ int j = k << 1; /* left son of k */
+ while (j <= s->heap_len) {
+ /* Set j to the smallest of the two sons: */
+ if (j < s->heap_len &&
+ smaller(tree, s->heap[j+1], s->heap[j], s->depth)) {
+ j++;
+ }
+ /* Exit if v is smaller than both sons */
+ if (smaller(tree, v, s->heap[j], s->depth)) break;
+
+ /* Exchange v with the smallest son */
+ s->heap[k] = s->heap[j]; k = j;
+
+ /* And continue down the tree, setting j to the left son of k */
+ j <<= 1;
+ }
+ s->heap[k] = v;
+}
+
+/* ===========================================================================
+ * Compute the optimal bit lengths for a tree and update the total bit length
+ * for the current block.
+ * IN assertion: the fields freq and dad are set, heap[heap_max] and
+ * above are the tree nodes sorted by increasing frequency.
+ * OUT assertions: the field len is set to the optimal bit length, the
+ * array bl_count contains the frequencies for each bit length.
+ * The length opt_len is updated; static_len is also updated if stree is
+ * not null.
+ */
+local void gen_bitlen(s, desc)
+ deflate_state *s;
+ tree_desc *desc; /* the tree descriptor */
+{
+ ct_data *tree = desc->dyn_tree;
+ int max_code = desc->max_code;
+ const ct_data *stree = desc->stat_desc->static_tree;
+ const intf *extra = desc->stat_desc->extra_bits;
+ int base = desc->stat_desc->extra_base;
+ int max_length = desc->stat_desc->max_length;
+ int h; /* heap index */
+ int n, m; /* iterate over the tree elements */
+ int bits; /* bit length */
+ int xbits; /* extra bits */
+ ush f; /* frequency */
+ int overflow = 0; /* number of elements with bit length too large */
+
+ for (bits = 0; bits <= MAX_BITS; bits++) s->bl_count[bits] = 0;
+
+ /* In a first pass, compute the optimal bit lengths (which may
+ * overflow in the case of the bit length tree).
+ */
+ tree[s->heap[s->heap_max]].Len = 0; /* root of the heap */
+
+ for (h = s->heap_max+1; h < HEAP_SIZE; h++) {
+ n = s->heap[h];
+ bits = tree[tree[n].Dad].Len + 1;
+ if (bits > max_length) bits = max_length, overflow++;
+ tree[n].Len = (ush)bits;
+ /* We overwrite tree[n].Dad which is no longer needed */
+
+ if (n > max_code) continue; /* not a leaf node */
+
+ s->bl_count[bits]++;
+ xbits = 0;
+ if (n >= base) xbits = extra[n-base];
+ f = tree[n].Freq;
+ s->opt_len += (ulg)f * (bits + xbits);
+ if (stree) s->static_len += (ulg)f * (stree[n].Len + xbits);
+ }
+ if (overflow == 0) return;
+
+ Trace((stderr,"\nbit length overflow\n"));
+ /* This happens for example on obj2 and pic of the Calgary corpus */
+
+ /* Find the first bit length which could increase: */
+ do {
+ bits = max_length-1;
+ while (s->bl_count[bits] == 0) bits--;
+ s->bl_count[bits]--; /* move one leaf down the tree */
+ s->bl_count[bits+1] += 2; /* move one overflow item as its brother */
+ s->bl_count[max_length]--;
+ /* The brother of the overflow item also moves one step up,
+ * but this does not affect bl_count[max_length]
+ */
+ overflow -= 2;
+ } while (overflow > 0);
+
+ /* Now recompute all bit lengths, scanning in increasing frequency.
+ * h is still equal to HEAP_SIZE. (It is simpler to reconstruct all
+ * lengths instead of fixing only the wrong ones. This idea is taken
+ * from 'ar' written by Haruhiko Okumura.)
+ */
+ for (bits = max_length; bits != 0; bits--) {
+ n = s->bl_count[bits];
+ while (n != 0) {
+ m = s->heap[--h];
+ if (m > max_code) continue;
+ if ((unsigned) tree[m].Len != (unsigned) bits) {
+ Trace((stderr,"code %d bits %d->%d\n", m, tree[m].Len, bits));
+ s->opt_len += ((long)bits - (long)tree[m].Len)
+ *(long)tree[m].Freq;
+ tree[m].Len = (ush)bits;
+ }
+ n--;
+ }
+ }
+}
+
+/* ===========================================================================
+ * Generate the codes for a given tree and bit counts (which need not be
+ * optimal).
+ * IN assertion: the array bl_count contains the bit length statistics for
+ * the given tree and the field len is set for all tree elements.
+ * OUT assertion: the field code is set for all tree elements of non
+ * zero code length.
+ */
+local void gen_codes (tree, max_code, bl_count)
+ ct_data *tree; /* the tree to decorate */
+ int max_code; /* largest code with non zero frequency */
+ ushf *bl_count; /* number of codes at each bit length */
+{
+ ush next_code[MAX_BITS+1]; /* next code value for each bit length */
+ ush code = 0; /* running code value */
+ int bits; /* bit index */
+ int n; /* code index */
+
+ /* The distribution counts are first used to generate the code values
+ * without bit reversal.
+ */
+ for (bits = 1; bits <= MAX_BITS; bits++) {
+ next_code[bits] = code = (code + bl_count[bits-1]) << 1;
+ }
+ /* Check that the bit counts in bl_count are consistent. The last code
+ * must be all ones.
+ */
+ Assert (code + bl_count[MAX_BITS]-1 == (1<<MAX_BITS)-1,
+ "inconsistent bit counts");
+ Tracev((stderr,"\ngen_codes: max_code %d ", max_code));
+
+ for (n = 0; n <= max_code; n++) {
+ int len = tree[n].Len;
+ if (len == 0) continue;
+ /* Now reverse the bits */
+ tree[n].Code = bi_reverse(next_code[len]++, len);
+
+ Tracecv(tree != static_ltree, (stderr,"\nn %3d %c l %2d c %4x (%x) ",
+ n, (isgraph(n) ? n : ' '), len, tree[n].Code, next_code[len]-1));
+ }
+}
+
+/* ===========================================================================
+ * Construct one Huffman tree and assigns the code bit strings and lengths.
+ * Update the total bit length for the current block.
+ * IN assertion: the field freq is set for all tree elements.
+ * OUT assertions: the fields len and code are set to the optimal bit length
+ * and corresponding code. The length opt_len is updated; static_len is
+ * also updated if stree is not null. The field max_code is set.
+ */
+local void build_tree(s, desc)
+ deflate_state *s;
+ tree_desc *desc; /* the tree descriptor */
+{
+ ct_data *tree = desc->dyn_tree;
+ const ct_data *stree = desc->stat_desc->static_tree;
+ int elems = desc->stat_desc->elems;
+ int n, m; /* iterate over heap elements */
+ int max_code = -1; /* largest code with non zero frequency */
+ int node; /* new node being created */
+
+ /* Construct the initial heap, with least frequent element in
+ * heap[SMALLEST]. The sons of heap[n] are heap[2*n] and heap[2*n+1].
+ * heap[0] is not used.
+ */
+ s->heap_len = 0, s->heap_max = HEAP_SIZE;
+
+ for (n = 0; n < elems; n++) {
+ if (tree[n].Freq != 0) {
+ s->heap[++(s->heap_len)] = max_code = n;
+ s->depth[n] = 0;
+ } else {
+ tree[n].Len = 0;
+ }
+ }
+
+ /* The pkzip format requires that at least one distance code exists,
+ * and that at least one bit should be sent even if there is only one
+ * possible code. So to avoid special checks later on we force at least
+ * two codes of non zero frequency.
+ */
+ while (s->heap_len < 2) {
+ node = s->heap[++(s->heap_len)] = (max_code < 2 ? ++max_code : 0);
+ tree[node].Freq = 1;
+ s->depth[node] = 0;
+ s->opt_len--; if (stree) s->static_len -= stree[node].Len;
+ /* node is 0 or 1 so it does not have extra bits */
+ }
+ desc->max_code = max_code;
+
+ /* The elements heap[heap_len/2+1 .. heap_len] are leaves of the tree,
+ * establish sub-heaps of increasing lengths:
+ */
+ for (n = s->heap_len/2; n >= 1; n--) pqdownheap(s, tree, n);
+
+ /* Construct the Huffman tree by repeatedly combining the least two
+ * frequent nodes.
+ */
+ node = elems; /* next internal node of the tree */
+ do {
+ pqremove(s, tree, n); /* n = node of least frequency */
+ m = s->heap[SMALLEST]; /* m = node of next least frequency */
+
+ s->heap[--(s->heap_max)] = n; /* keep the nodes sorted by frequency */
+ s->heap[--(s->heap_max)] = m;
+
+ /* Create a new node father of n and m */
+ tree[node].Freq = tree[n].Freq + tree[m].Freq;
+ s->depth[node] = (uch)((s->depth[n] >= s->depth[m] ?
+ s->depth[n] : s->depth[m]) + 1);
+ tree[n].Dad = tree[m].Dad = (ush)node;
+#ifdef DUMP_BL_TREE
+ if (tree == s->bl_tree) {
+ fprintf(stderr,"\nnode %d(%d), sons %d(%d) %d(%d)",
+ node, tree[node].Freq, n, tree[n].Freq, m, tree[m].Freq);
+ }
+#endif
+ /* and insert the new node in the heap */
+ s->heap[SMALLEST] = node++;
+ pqdownheap(s, tree, SMALLEST);
+
+ } while (s->heap_len >= 2);
+
+ s->heap[--(s->heap_max)] = s->heap[SMALLEST];
+
+ /* At this point, the fields freq and dad are set. We can now
+ * generate the bit lengths.
+ */
+ gen_bitlen(s, (tree_desc *)desc);
+
+ /* The field len is now set, we can generate the bit codes */
+ gen_codes ((ct_data *)tree, max_code, s->bl_count);
+}
+
+/* ===========================================================================
+ * Scan a literal or distance tree to determine the frequencies of the codes
+ * in the bit length tree.
+ */
+local void scan_tree (s, tree, max_code)
+ deflate_state *s;
+ ct_data *tree; /* the tree to be scanned */
+ int max_code; /* and its largest code of non zero frequency */
+{
+ int n; /* iterates over all tree elements */
+ int prevlen = -1; /* last emitted length */
+ int curlen; /* length of current code */
+ int nextlen = tree[0].Len; /* length of next code */
+ int count = 0; /* repeat count of the current code */
+ int max_count = 7; /* max repeat count */
+ int min_count = 4; /* min repeat count */
+
+ if (nextlen == 0) max_count = 138, min_count = 3;
+ tree[max_code+1].Len = (ush)0xffff; /* guard */
+
+ for (n = 0; n <= max_code; n++) {
+ curlen = nextlen; nextlen = tree[n+1].Len;
+ if (++count < max_count && curlen == nextlen) {
+ continue;
+ } else if (count < min_count) {
+ s->bl_tree[curlen].Freq += count;
+ } else if (curlen != 0) {
+ if (curlen != prevlen) s->bl_tree[curlen].Freq++;
+ s->bl_tree[REP_3_6].Freq++;
+ } else if (count <= 10) {
+ s->bl_tree[REPZ_3_10].Freq++;
+ } else {
+ s->bl_tree[REPZ_11_138].Freq++;
+ }
+ count = 0; prevlen = curlen;
+ if (nextlen == 0) {
+ max_count = 138, min_count = 3;
+ } else if (curlen == nextlen) {
+ max_count = 6, min_count = 3;
+ } else {
+ max_count = 7, min_count = 4;
+ }
+ }
+}
+
+/* ===========================================================================
+ * Send a literal or distance tree in compressed form, using the codes in
+ * bl_tree.
+ */
+local void send_tree (s, tree, max_code)
+ deflate_state *s;
+ ct_data *tree; /* the tree to be scanned */
+ int max_code; /* and its largest code of non zero frequency */
+{
+ int n; /* iterates over all tree elements */
+ int prevlen = -1; /* last emitted length */
+ int curlen; /* length of current code */
+ int nextlen = tree[0].Len; /* length of next code */
+ int count = 0; /* repeat count of the current code */
+ int max_count = 7; /* max repeat count */
+ int min_count = 4; /* min repeat count */
+
+ /* tree[max_code+1].Len = -1; */ /* guard already set */
+ if (nextlen == 0) max_count = 138, min_count = 3;
+
+ for (n = 0; n <= max_code; n++) {
+ curlen = nextlen; nextlen = tree[n+1].Len;
+ if (++count < max_count && curlen == nextlen) {
+ continue;
+ } else if (count < min_count) {
+ do { send_code(s, curlen, s->bl_tree); } while (--count != 0);
+
+ } else if (curlen != 0) {
+ if (curlen != prevlen) {
+ send_code(s, curlen, s->bl_tree); count--;
+ }
+ Assert(count >= 3 && count <= 6, " 3_6?");
+ send_code(s, REP_3_6, s->bl_tree); send_bits(s, count-3, 2);
+
+ } else if (count <= 10) {
+ send_code(s, REPZ_3_10, s->bl_tree); send_bits(s, count-3, 3);
+
+ } else {
+ send_code(s, REPZ_11_138, s->bl_tree); send_bits(s, count-11, 7);
+ }
+ count = 0; prevlen = curlen;
+ if (nextlen == 0) {
+ max_count = 138, min_count = 3;
+ } else if (curlen == nextlen) {
+ max_count = 6, min_count = 3;
+ } else {
+ max_count = 7, min_count = 4;
+ }
+ }
+}
+
+/* ===========================================================================
+ * Construct the Huffman tree for the bit lengths and return the index in
+ * bl_order of the last bit length code to send.
+ */
+local int build_bl_tree(s)
+ deflate_state *s;
+{
+ int max_blindex; /* index of last bit length code of non zero freq */
+
+ /* Determine the bit length frequencies for literal and distance trees */
+ scan_tree(s, (ct_data *)s->dyn_ltree, s->l_desc.max_code);
+ scan_tree(s, (ct_data *)s->dyn_dtree, s->d_desc.max_code);
+
+ /* Build the bit length tree: */
+ build_tree(s, (tree_desc *)(&(s->bl_desc)));
+ /* opt_len now includes the length of the tree representations, except
+ * the lengths of the bit lengths codes and the 5+5+4 bits for the counts.
+ */
+
+ /* Determine the number of bit length codes to send. The pkzip format
+ * requires that at least 4 bit length codes be sent. (appnote.txt says
+ * 3 but the actual value used is 4.)
+ */
+ for (max_blindex = BL_CODES-1; max_blindex >= 3; max_blindex--) {
+ if (s->bl_tree[bl_order[max_blindex]].Len != 0) break;
+ }
+ /* Update opt_len to include the bit length tree and counts */
+ s->opt_len += 3*(max_blindex+1) + 5+5+4;
+ Tracev((stderr, "\ndyn trees: dyn %ld, stat %ld",
+ s->opt_len, s->static_len));
+
+ return max_blindex;
+}
+
+/* ===========================================================================
+ * Send the header for a block using dynamic Huffman trees: the counts, the
+ * lengths of the bit length codes, the literal tree and the distance tree.
+ * IN assertion: lcodes >= 257, dcodes >= 1, blcodes >= 4.
+ */
+local void send_all_trees(s, lcodes, dcodes, blcodes)
+ deflate_state *s;
+ int lcodes, dcodes, blcodes; /* number of codes for each tree */
+{
+ int rank; /* index in bl_order */
+
+ Assert (lcodes >= 257 && dcodes >= 1 && blcodes >= 4, "not enough codes");
+ Assert (lcodes <= L_CODES && dcodes <= D_CODES && blcodes <= BL_CODES,
+ "too many codes");
+ Tracev((stderr, "\nbl counts: "));
+ send_bits(s, lcodes-257, 5); /* not +255 as stated in appnote.txt */
+ send_bits(s, dcodes-1, 5);
+ send_bits(s, blcodes-4, 4); /* not -3 as stated in appnote.txt */
+ for (rank = 0; rank < blcodes; rank++) {
+ Tracev((stderr, "\nbl code %2d ", bl_order[rank]));
+ send_bits(s, s->bl_tree[bl_order[rank]].Len, 3);
+ }
+ Tracev((stderr, "\nbl tree: sent %ld", s->bits_sent));
+
+ send_tree(s, (ct_data *)s->dyn_ltree, lcodes-1); /* literal tree */
+ Tracev((stderr, "\nlit tree: sent %ld", s->bits_sent));
+
+ send_tree(s, (ct_data *)s->dyn_dtree, dcodes-1); /* distance tree */
+ Tracev((stderr, "\ndist tree: sent %ld", s->bits_sent));
+}
+
+/* ===========================================================================
+ * Send a stored block
+ */
+void ZLIB_INTERNAL _tr_stored_block(s, buf, stored_len, last)
+ deflate_state *s;
+ charf *buf; /* input block */
+ ulg stored_len; /* length of input block */
+ int last; /* one if this is the last block for a file */
+{
+ send_bits(s, (STORED_BLOCK<<1)+last, 3); /* send block type */
+#ifdef DEBUG
+ s->compressed_len = (s->compressed_len + 3 + 7) & (ulg)~7L;
+ s->compressed_len += (stored_len + 4) << 3;
+#endif
+ copy_block(s, buf, (unsigned)stored_len, 1); /* with header */
+}
+
+/* ===========================================================================
+ * Flush the bits in the bit buffer to pending output (leaves at most 7 bits)
+ */
+void ZLIB_INTERNAL _tr_flush_bits(s)
+ deflate_state *s;
+{
+ bi_flush(s);
+}
+
+/* ===========================================================================
+ * Send one empty static block to give enough lookahead for inflate.
+ * This takes 10 bits, of which 7 may remain in the bit buffer.
+ */
+void ZLIB_INTERNAL _tr_align(s)
+ deflate_state *s;
+{
+ send_bits(s, STATIC_TREES<<1, 3);
+ send_code(s, END_BLOCK, static_ltree);
+#ifdef DEBUG
+ s->compressed_len += 10L; /* 3 for block type, 7 for EOB */
+#endif
+ bi_flush(s);
+}
+
+/* ===========================================================================
+ * Determine the best encoding for the current block: dynamic trees, static
+ * trees or store, and output the encoded block to the zip file.
+ */
+void ZLIB_INTERNAL _tr_flush_block(s, buf, stored_len, last)
+ deflate_state *s;
+ charf *buf; /* input block, or NULL if too old */
+ ulg stored_len; /* length of input block */
+ int last; /* one if this is the last block for a file */
+{
+ ulg opt_lenb, static_lenb; /* opt_len and static_len in bytes */
+ int max_blindex = 0; /* index of last bit length code of non zero freq */
+
+ /* Build the Huffman trees unless a stored block is forced */
+ if (s->level > 0) {
+
+ /* Check if the file is binary or text */
+ if (s->strm->data_type == Z_UNKNOWN)
+ s->strm->data_type = detect_data_type(s);
+
+ /* Construct the literal and distance trees */
+ build_tree(s, (tree_desc *)(&(s->l_desc)));
+ Tracev((stderr, "\nlit data: dyn %ld, stat %ld", s->opt_len,
+ s->static_len));
+
+ build_tree(s, (tree_desc *)(&(s->d_desc)));
+ Tracev((stderr, "\ndist data: dyn %ld, stat %ld", s->opt_len,
+ s->static_len));
+ /* At this point, opt_len and static_len are the total bit lengths of
+ * the compressed block data, excluding the tree representations.
+ */
+
+ /* Build the bit length tree for the above two trees, and get the index
+ * in bl_order of the last bit length code to send.
+ */
+ max_blindex = build_bl_tree(s);
+
+ /* Determine the best encoding. Compute the block lengths in bytes. */
+ opt_lenb = (s->opt_len+3+7)>>3;
+ static_lenb = (s->static_len+3+7)>>3;
+
+ Tracev((stderr, "\nopt %lu(%lu) stat %lu(%lu) stored %lu lit %u ",
+ opt_lenb, s->opt_len, static_lenb, s->static_len, stored_len,
+ s->last_lit));
+
+ if (static_lenb <= opt_lenb) opt_lenb = static_lenb;
+
+ } else {
+ Assert(buf != (char*)0, "lost buf");
+ opt_lenb = static_lenb = stored_len + 5; /* force a stored block */
+ }
+
+#ifdef FORCE_STORED
+ if (buf != (char*)0) { /* force stored block */
+#else
+ if (stored_len+4 <= opt_lenb && buf != (char*)0) {
+ /* 4: two words for the lengths */
+#endif
+ /* The test buf != NULL is only necessary if LIT_BUFSIZE > WSIZE.
+ * Otherwise we can't have processed more than WSIZE input bytes since
+ * the last block flush, because compression would have been
+ * successful. If LIT_BUFSIZE <= WSIZE, it is never too late to
+ * transform a block into a stored block.
+ */
+ _tr_stored_block(s, buf, stored_len, last);
+
+#ifdef FORCE_STATIC
+ } else if (static_lenb >= 0) { /* force static trees */
+#else
+ } else if (s->strategy == Z_FIXED || static_lenb == opt_lenb) {
+#endif
+ send_bits(s, (STATIC_TREES<<1)+last, 3);
+ compress_block(s, (const ct_data *)static_ltree,
+ (const ct_data *)static_dtree);
+#ifdef DEBUG
+ s->compressed_len += 3 + s->static_len;
+#endif
+ } else {
+ send_bits(s, (DYN_TREES<<1)+last, 3);
+ send_all_trees(s, s->l_desc.max_code+1, s->d_desc.max_code+1,
+ max_blindex+1);
+ compress_block(s, (const ct_data *)s->dyn_ltree,
+ (const ct_data *)s->dyn_dtree);
+#ifdef DEBUG
+ s->compressed_len += 3 + s->opt_len;
+#endif
+ }
+ Assert (s->compressed_len == s->bits_sent, "bad compressed size");
+ /* The above check is made mod 2^32, for files larger than 512 MB
+ * and uLong implemented on 32 bits.
+ */
+ init_block(s);
+
+ if (last) {
+ bi_windup(s);
+#ifdef DEBUG
+ s->compressed_len += 7; /* align on byte boundary */
+#endif
+ }
+ Tracev((stderr,"\ncomprlen %lu(%lu) ", s->compressed_len>>3,
+ s->compressed_len-7*last));
+}
+
+/* ===========================================================================
+ * Save the match info and tally the frequency counts. Return true if
+ * the current block must be flushed.
+ */
+int ZLIB_INTERNAL _tr_tally (s, dist, lc)
+ deflate_state *s;
+ unsigned dist; /* distance of matched string */
+ unsigned lc; /* match length-MIN_MATCH or unmatched char (if dist==0) */
+{
+ s->d_buf[s->last_lit] = (ush)dist;
+ s->l_buf[s->last_lit++] = (uch)lc;
+ if (dist == 0) {
+ /* lc is the unmatched char */
+ s->dyn_ltree[lc].Freq++;
+ } else {
+ s->matches++;
+ /* Here, lc is the match length - MIN_MATCH */
+ dist--; /* dist = match distance - 1 */
+ Assert((ush)dist < (ush)MAX_DIST(s) &&
+ (ush)lc <= (ush)(MAX_MATCH-MIN_MATCH) &&
+ (ush)d_code(dist) < (ush)D_CODES, "_tr_tally: bad match");
+
+ s->dyn_ltree[_length_code[lc]+LITERALS+1].Freq++;
+ s->dyn_dtree[d_code(dist)].Freq++;
+ }
+
+#ifdef TRUNCATE_BLOCK
+ /* Try to guess if it is profitable to stop the current block here */
+ if ((s->last_lit & 0x1fff) == 0 && s->level > 2) {
+ /* Compute an upper bound for the compressed length */
+ ulg out_length = (ulg)s->last_lit*8L;
+ ulg in_length = (ulg)((long)s->strstart - s->block_start);
+ int dcode;
+ for (dcode = 0; dcode < D_CODES; dcode++) {
+ out_length += (ulg)s->dyn_dtree[dcode].Freq *
+ (5L+extra_dbits[dcode]);
+ }
+ out_length >>= 3;
+ Tracev((stderr,"\nlast_lit %u, in %ld, out ~%ld(%ld%%) ",
+ s->last_lit, in_length, out_length,
+ 100L - out_length*100L/in_length));
+ if (s->matches < s->last_lit/2 && out_length < in_length/2) return 1;
+ }
+#endif
+ return (s->last_lit == s->lit_bufsize-1);
+ /* We avoid equality with lit_bufsize because of wraparound at 64K
+ * on 16 bit machines and because stored blocks are restricted to
+ * 64K-1 bytes.
+ */
+}
+
+/* ===========================================================================
+ * Send the block data compressed using the given Huffman trees
+ */
+local void compress_block(s, ltree, dtree)
+ deflate_state *s;
+ const ct_data *ltree; /* literal tree */
+ const ct_data *dtree; /* distance tree */
+{
+ unsigned dist; /* distance of matched string */
+ int lc; /* match length or unmatched char (if dist == 0) */
+ unsigned lx = 0; /* running index in l_buf */
+ unsigned code; /* the code to send */
+ int extra; /* number of extra bits to send */
+
+ if (s->last_lit != 0) do {
+ dist = s->d_buf[lx];
+ lc = s->l_buf[lx++];
+ if (dist == 0) {
+ send_code(s, lc, ltree); /* send a literal byte */
+ Tracecv(isgraph(lc), (stderr," '%c' ", lc));
+ } else {
+ /* Here, lc is the match length - MIN_MATCH */
+ code = _length_code[lc];
+ send_code(s, code+LITERALS+1, ltree); /* send the length code */
+ extra = extra_lbits[code];
+ if (extra != 0) {
+ lc -= base_length[code];
+ send_bits(s, lc, extra); /* send the extra length bits */
+ }
+ dist--; /* dist is now the match distance - 1 */
+ code = d_code(dist);
+ Assert (code < D_CODES, "bad d_code");
+
+ send_code(s, code, dtree); /* send the distance code */
+ extra = extra_dbits[code];
+ if (extra != 0) {
+ dist -= base_dist[code];
+ send_bits(s, dist, extra); /* send the extra distance bits */
+ }
+ } /* literal or match pair ? */
+
+ /* Check that the overlay between pending_buf and d_buf+l_buf is ok: */
+ Assert((uInt)(s->pending) < s->lit_bufsize + 2*lx,
+ "pendingBuf overflow");
+
+ } while (lx < s->last_lit);
+
+ send_code(s, END_BLOCK, ltree);
+}
+
+/* ===========================================================================
+ * Check if the data type is TEXT or BINARY, using the following algorithm:
+ * - TEXT if the two conditions below are satisfied:
+ * a) There are no non-portable control characters belonging to the
+ * "black list" (0..6, 14..25, 28..31).
+ * b) There is at least one printable character belonging to the
+ * "white list" (9 {TAB}, 10 {LF}, 13 {CR}, 32..255).
+ * - BINARY otherwise.
+ * - The following partially-portable control characters form a
+ * "gray list" that is ignored in this detection algorithm:
+ * (7 {BEL}, 8 {BS}, 11 {VT}, 12 {FF}, 26 {SUB}, 27 {ESC}).
+ * IN assertion: the fields Freq of dyn_ltree are set.
+ */
+local int detect_data_type(s)
+ deflate_state *s;
+{
+ /* black_mask is the bit mask of black-listed bytes
+ * set bits 0..6, 14..25, and 28..31
+ * 0xf3ffc07f = binary 11110011111111111100000001111111
+ */
+ unsigned long black_mask = 0xf3ffc07fUL;
+ int n;
+
+ /* Check for non-textual ("black-listed") bytes. */
+ for (n = 0; n <= 31; n++, black_mask >>= 1)
+ if ((black_mask & 1) && (s->dyn_ltree[n].Freq != 0))
+ return Z_BINARY;
+
+ /* Check for textual ("white-listed") bytes. */
+ if (s->dyn_ltree[9].Freq != 0 || s->dyn_ltree[10].Freq != 0
+ || s->dyn_ltree[13].Freq != 0)
+ return Z_TEXT;
+ for (n = 32; n < LITERALS; n++)
+ if (s->dyn_ltree[n].Freq != 0)
+ return Z_TEXT;
+
+ /* There are no "black-listed" or "white-listed" bytes:
+ * this stream either is empty or has tolerated ("gray-listed") bytes only.
+ */
+ return Z_BINARY;
+}
+
+/* ===========================================================================
+ * Reverse the first len bits of a code, using straightforward code (a faster
+ * method would use a table)
+ * IN assertion: 1 <= len <= 15
+ */
+local unsigned bi_reverse(code, len)
+ unsigned code; /* the value to invert */
+ int len; /* its bit length */
+{
+ register unsigned res = 0;
+ do {
+ res |= code & 1;
+ code >>= 1, res <<= 1;
+ } while (--len > 0);
+ return res >> 1;
+}
+
+/* ===========================================================================
+ * Flush the bit buffer, keeping at most 7 bits in it.
+ */
+local void bi_flush(s)
+ deflate_state *s;
+{
+ if (s->bi_valid == 16) {
+ put_short(s, s->bi_buf);
+ s->bi_buf = 0;
+ s->bi_valid = 0;
+ } else if (s->bi_valid >= 8) {
+ put_byte(s, (Byte)s->bi_buf);
+ s->bi_buf >>= 8;
+ s->bi_valid -= 8;
+ }
+}
+
+/* ===========================================================================
+ * Flush the bit buffer and align the output on a byte boundary
+ */
+local void bi_windup(s)
+ deflate_state *s;
+{
+ if (s->bi_valid > 8) {
+ put_short(s, s->bi_buf);
+ } else if (s->bi_valid > 0) {
+ put_byte(s, (Byte)s->bi_buf);
+ }
+ s->bi_buf = 0;
+ s->bi_valid = 0;
+#ifdef DEBUG
+ s->bits_sent = (s->bits_sent+7) & ~7;
+#endif
+}
+
+/* ===========================================================================
+ * Copy a stored block, storing first the length and its
+ * one's complement if requested.
+ */
+local void copy_block(s, buf, len, header)
+ deflate_state *s;
+ charf *buf; /* the input data */
+ unsigned len; /* its length */
+ int header; /* true if block header must be written */
+{
+ bi_windup(s); /* align on byte boundary */
+
+ if (header) {
+ put_short(s, (ush)len);
+ put_short(s, (ush)~len);
+#ifdef DEBUG
+ s->bits_sent += 2*16;
+#endif
+ }
+#ifdef DEBUG
+ s->bits_sent += (ulg)len<<3;
+#endif
+ while (len--) {
+ put_byte(s, *buf++);
+ }
+}
diff --git a/ml/dlib/dlib/external/zlib/trees.h b/ml/dlib/dlib/external/zlib/trees.h
new file mode 100644
index 000000000..d35639d82
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/trees.h
@@ -0,0 +1,128 @@
+/* header created automatically with -DGEN_TREES_H */
+
+local const ct_data static_ltree[L_CODES+2] = {
+{{ 12},{ 8}}, {{140},{ 8}}, {{ 76},{ 8}}, {{204},{ 8}}, {{ 44},{ 8}},
+{{172},{ 8}}, {{108},{ 8}}, {{236},{ 8}}, {{ 28},{ 8}}, {{156},{ 8}},
+{{ 92},{ 8}}, {{220},{ 8}}, {{ 60},{ 8}}, {{188},{ 8}}, {{124},{ 8}},
+{{252},{ 8}}, {{ 2},{ 8}}, {{130},{ 8}}, {{ 66},{ 8}}, {{194},{ 8}},
+{{ 34},{ 8}}, {{162},{ 8}}, {{ 98},{ 8}}, {{226},{ 8}}, {{ 18},{ 8}},
+{{146},{ 8}}, {{ 82},{ 8}}, {{210},{ 8}}, {{ 50},{ 8}}, {{178},{ 8}},
+{{114},{ 8}}, {{242},{ 8}}, {{ 10},{ 8}}, {{138},{ 8}}, {{ 74},{ 8}},
+{{202},{ 8}}, {{ 42},{ 8}}, {{170},{ 8}}, {{106},{ 8}}, {{234},{ 8}},
+{{ 26},{ 8}}, {{154},{ 8}}, {{ 90},{ 8}}, {{218},{ 8}}, {{ 58},{ 8}},
+{{186},{ 8}}, {{122},{ 8}}, {{250},{ 8}}, {{ 6},{ 8}}, {{134},{ 8}},
+{{ 70},{ 8}}, {{198},{ 8}}, {{ 38},{ 8}}, {{166},{ 8}}, {{102},{ 8}},
+{{230},{ 8}}, {{ 22},{ 8}}, {{150},{ 8}}, {{ 86},{ 8}}, {{214},{ 8}},
+{{ 54},{ 8}}, {{182},{ 8}}, {{118},{ 8}}, {{246},{ 8}}, {{ 14},{ 8}},
+{{142},{ 8}}, {{ 78},{ 8}}, {{206},{ 8}}, {{ 46},{ 8}}, {{174},{ 8}},
+{{110},{ 8}}, {{238},{ 8}}, {{ 30},{ 8}}, {{158},{ 8}}, {{ 94},{ 8}},
+{{222},{ 8}}, {{ 62},{ 8}}, {{190},{ 8}}, {{126},{ 8}}, {{254},{ 8}},
+{{ 1},{ 8}}, {{129},{ 8}}, {{ 65},{ 8}}, {{193},{ 8}}, {{ 33},{ 8}},
+{{161},{ 8}}, {{ 97},{ 8}}, {{225},{ 8}}, {{ 17},{ 8}}, {{145},{ 8}},
+{{ 81},{ 8}}, {{209},{ 8}}, {{ 49},{ 8}}, {{177},{ 8}}, {{113},{ 8}},
+{{241},{ 8}}, {{ 9},{ 8}}, {{137},{ 8}}, {{ 73},{ 8}}, {{201},{ 8}},
+{{ 41},{ 8}}, {{169},{ 8}}, {{105},{ 8}}, {{233},{ 8}}, {{ 25},{ 8}},
+{{153},{ 8}}, {{ 89},{ 8}}, {{217},{ 8}}, {{ 57},{ 8}}, {{185},{ 8}},
+{{121},{ 8}}, {{249},{ 8}}, {{ 5},{ 8}}, {{133},{ 8}}, {{ 69},{ 8}},
+{{197},{ 8}}, {{ 37},{ 8}}, {{165},{ 8}}, {{101},{ 8}}, {{229},{ 8}},
+{{ 21},{ 8}}, {{149},{ 8}}, {{ 85},{ 8}}, {{213},{ 8}}, {{ 53},{ 8}},
+{{181},{ 8}}, {{117},{ 8}}, {{245},{ 8}}, {{ 13},{ 8}}, {{141},{ 8}},
+{{ 77},{ 8}}, {{205},{ 8}}, {{ 45},{ 8}}, {{173},{ 8}}, {{109},{ 8}},
+{{237},{ 8}}, {{ 29},{ 8}}, {{157},{ 8}}, {{ 93},{ 8}}, {{221},{ 8}},
+{{ 61},{ 8}}, {{189},{ 8}}, {{125},{ 8}}, {{253},{ 8}}, {{ 19},{ 9}},
+{{275},{ 9}}, {{147},{ 9}}, {{403},{ 9}}, {{ 83},{ 9}}, {{339},{ 9}},
+{{211},{ 9}}, {{467},{ 9}}, {{ 51},{ 9}}, {{307},{ 9}}, {{179},{ 9}},
+{{435},{ 9}}, {{115},{ 9}}, {{371},{ 9}}, {{243},{ 9}}, {{499},{ 9}},
+{{ 11},{ 9}}, {{267},{ 9}}, {{139},{ 9}}, {{395},{ 9}}, {{ 75},{ 9}},
+{{331},{ 9}}, {{203},{ 9}}, {{459},{ 9}}, {{ 43},{ 9}}, {{299},{ 9}},
+{{171},{ 9}}, {{427},{ 9}}, {{107},{ 9}}, {{363},{ 9}}, {{235},{ 9}},
+{{491},{ 9}}, {{ 27},{ 9}}, {{283},{ 9}}, {{155},{ 9}}, {{411},{ 9}},
+{{ 91},{ 9}}, {{347},{ 9}}, {{219},{ 9}}, {{475},{ 9}}, {{ 59},{ 9}},
+{{315},{ 9}}, {{187},{ 9}}, {{443},{ 9}}, {{123},{ 9}}, {{379},{ 9}},
+{{251},{ 9}}, {{507},{ 9}}, {{ 7},{ 9}}, {{263},{ 9}}, {{135},{ 9}},
+{{391},{ 9}}, {{ 71},{ 9}}, {{327},{ 9}}, {{199},{ 9}}, {{455},{ 9}},
+{{ 39},{ 9}}, {{295},{ 9}}, {{167},{ 9}}, {{423},{ 9}}, {{103},{ 9}},
+{{359},{ 9}}, {{231},{ 9}}, {{487},{ 9}}, {{ 23},{ 9}}, {{279},{ 9}},
+{{151},{ 9}}, {{407},{ 9}}, {{ 87},{ 9}}, {{343},{ 9}}, {{215},{ 9}},
+{{471},{ 9}}, {{ 55},{ 9}}, {{311},{ 9}}, {{183},{ 9}}, {{439},{ 9}},
+{{119},{ 9}}, {{375},{ 9}}, {{247},{ 9}}, {{503},{ 9}}, {{ 15},{ 9}},
+{{271},{ 9}}, {{143},{ 9}}, {{399},{ 9}}, {{ 79},{ 9}}, {{335},{ 9}},
+{{207},{ 9}}, {{463},{ 9}}, {{ 47},{ 9}}, {{303},{ 9}}, {{175},{ 9}},
+{{431},{ 9}}, {{111},{ 9}}, {{367},{ 9}}, {{239},{ 9}}, {{495},{ 9}},
+{{ 31},{ 9}}, {{287},{ 9}}, {{159},{ 9}}, {{415},{ 9}}, {{ 95},{ 9}},
+{{351},{ 9}}, {{223},{ 9}}, {{479},{ 9}}, {{ 63},{ 9}}, {{319},{ 9}},
+{{191},{ 9}}, {{447},{ 9}}, {{127},{ 9}}, {{383},{ 9}}, {{255},{ 9}},
+{{511},{ 9}}, {{ 0},{ 7}}, {{ 64},{ 7}}, {{ 32},{ 7}}, {{ 96},{ 7}},
+{{ 16},{ 7}}, {{ 80},{ 7}}, {{ 48},{ 7}}, {{112},{ 7}}, {{ 8},{ 7}},
+{{ 72},{ 7}}, {{ 40},{ 7}}, {{104},{ 7}}, {{ 24},{ 7}}, {{ 88},{ 7}},
+{{ 56},{ 7}}, {{120},{ 7}}, {{ 4},{ 7}}, {{ 68},{ 7}}, {{ 36},{ 7}},
+{{100},{ 7}}, {{ 20},{ 7}}, {{ 84},{ 7}}, {{ 52},{ 7}}, {{116},{ 7}},
+{{ 3},{ 8}}, {{131},{ 8}}, {{ 67},{ 8}}, {{195},{ 8}}, {{ 35},{ 8}},
+{{163},{ 8}}, {{ 99},{ 8}}, {{227},{ 8}}
+};
+
+local const ct_data static_dtree[D_CODES] = {
+{{ 0},{ 5}}, {{16},{ 5}}, {{ 8},{ 5}}, {{24},{ 5}}, {{ 4},{ 5}},
+{{20},{ 5}}, {{12},{ 5}}, {{28},{ 5}}, {{ 2},{ 5}}, {{18},{ 5}},
+{{10},{ 5}}, {{26},{ 5}}, {{ 6},{ 5}}, {{22},{ 5}}, {{14},{ 5}},
+{{30},{ 5}}, {{ 1},{ 5}}, {{17},{ 5}}, {{ 9},{ 5}}, {{25},{ 5}},
+{{ 5},{ 5}}, {{21},{ 5}}, {{13},{ 5}}, {{29},{ 5}}, {{ 3},{ 5}},
+{{19},{ 5}}, {{11},{ 5}}, {{27},{ 5}}, {{ 7},{ 5}}, {{23},{ 5}}
+};
+
+const uch ZLIB_INTERNAL _dist_code[DIST_CODE_LEN] = {
+ 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
+ 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
+10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
+11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
+12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13,
+13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
+14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
+14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
+14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 0, 0, 16, 17,
+18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22,
+23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
+26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27,
+27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
+27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
+28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
+28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
+28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
+29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
+29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
+29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29
+};
+
+const uch ZLIB_INTERNAL _length_code[MAX_MATCH-MIN_MATCH+1]= {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12,
+13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
+17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19,
+19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
+21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22,
+22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
+23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
+25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26,
+26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
+27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28
+};
+
+local const int base_length[LENGTH_CODES] = {
+0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56,
+64, 80, 96, 112, 128, 160, 192, 224, 0
+};
+
+local const int base_dist[D_CODES] = {
+ 0, 1, 2, 3, 4, 6, 8, 12, 16, 24,
+ 32, 48, 64, 96, 128, 192, 256, 384, 512, 768,
+ 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576
+};
+
diff --git a/ml/dlib/dlib/external/zlib/uncompr.c b/ml/dlib/dlib/external/zlib/uncompr.c
new file mode 100644
index 000000000..242e9493d
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/uncompr.c
@@ -0,0 +1,59 @@
+/* uncompr.c -- decompress a memory buffer
+ * Copyright (C) 1995-2003, 2010 Jean-loup Gailly.
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* @(#) $Id$ */
+
+#define ZLIB_INTERNAL
+#include "zlib.h"
+
+/* ===========================================================================
+ Decompresses the source buffer into the destination buffer. sourceLen is
+ the byte length of the source buffer. Upon entry, destLen is the total
+ size of the destination buffer, which must be large enough to hold the
+ entire uncompressed data. (The size of the uncompressed data must have
+ been saved previously by the compressor and transmitted to the decompressor
+ by some mechanism outside the scope of this compression library.)
+ Upon exit, destLen is the actual size of the compressed buffer.
+
+ uncompress returns Z_OK if success, Z_MEM_ERROR if there was not
+ enough memory, Z_BUF_ERROR if there was not enough room in the output
+ buffer, or Z_DATA_ERROR if the input data was corrupted.
+*/
+int ZEXPORT uncompress (dest, destLen, source, sourceLen)
+ Bytef *dest;
+ uLongf *destLen;
+ const Bytef *source;
+ uLong sourceLen;
+{
+ z_stream stream;
+ int err;
+
+ stream.next_in = (z_const Bytef *)source;
+ stream.avail_in = (uInt)sourceLen;
+ /* Check for source > 64K on 16-bit machine: */
+ if ((uLong)stream.avail_in != sourceLen) return Z_BUF_ERROR;
+
+ stream.next_out = dest;
+ stream.avail_out = (uInt)*destLen;
+ if ((uLong)stream.avail_out != *destLen) return Z_BUF_ERROR;
+
+ stream.zalloc = (alloc_func)0;
+ stream.zfree = (free_func)0;
+
+ err = inflateInit(&stream);
+ if (err != Z_OK) return err;
+
+ err = inflate(&stream, Z_FINISH);
+ if (err != Z_STREAM_END) {
+ inflateEnd(&stream);
+ if (err == Z_NEED_DICT || (err == Z_BUF_ERROR && stream.avail_in == 0))
+ return Z_DATA_ERROR;
+ return err;
+ }
+ *destLen = stream.total_out;
+
+ err = inflateEnd(&stream);
+ return err;
+}
diff --git a/ml/dlib/dlib/external/zlib/zconf.h b/ml/dlib/dlib/external/zlib/zconf.h
new file mode 100644
index 000000000..9987a7755
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/zconf.h
@@ -0,0 +1,511 @@
+/* zconf.h -- configuration of the zlib compression library
+ * Copyright (C) 1995-2013 Jean-loup Gailly.
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* @(#) $Id$ */
+
+#ifndef ZCONF_H
+#define ZCONF_H
+
+/*
+ * If you *really* need a unique prefix for all types and library functions,
+ * compile with -DZ_PREFIX. The "standard" zlib should be compiled without it.
+ * Even better than compiling with -DZ_PREFIX would be to use configure to set
+ * this permanently in zconf.h using "./configure --zprefix".
+ */
+#ifdef Z_PREFIX /* may be set to #if 1 by ./configure */
+# define Z_PREFIX_SET
+
+/* all linked symbols */
+# define _dist_code z__dist_code
+# define _length_code z__length_code
+# define _tr_align z__tr_align
+# define _tr_flush_bits z__tr_flush_bits
+# define _tr_flush_block z__tr_flush_block
+# define _tr_init z__tr_init
+# define _tr_stored_block z__tr_stored_block
+# define _tr_tally z__tr_tally
+# define adler32 z_adler32
+# define adler32_combine z_adler32_combine
+# define adler32_combine64 z_adler32_combine64
+# ifndef Z_SOLO
+# define compress z_compress
+# define compress2 z_compress2
+# define compressBound z_compressBound
+# endif
+# define crc32 z_crc32
+# define crc32_combine z_crc32_combine
+# define crc32_combine64 z_crc32_combine64
+# define deflate z_deflate
+# define deflateBound z_deflateBound
+# define deflateCopy z_deflateCopy
+# define deflateEnd z_deflateEnd
+# define deflateInit2_ z_deflateInit2_
+# define deflateInit_ z_deflateInit_
+# define deflateParams z_deflateParams
+# define deflatePending z_deflatePending
+# define deflatePrime z_deflatePrime
+# define deflateReset z_deflateReset
+# define deflateResetKeep z_deflateResetKeep
+# define deflateSetDictionary z_deflateSetDictionary
+# define deflateSetHeader z_deflateSetHeader
+# define deflateTune z_deflateTune
+# define deflate_copyright z_deflate_copyright
+# define get_crc_table z_get_crc_table
+# ifndef Z_SOLO
+# define gz_error z_gz_error
+# define gz_intmax z_gz_intmax
+# define gz_strwinerror z_gz_strwinerror
+# define gzbuffer z_gzbuffer
+# define gzclearerr z_gzclearerr
+# define gzclose z_gzclose
+# define gzclose_r z_gzclose_r
+# define gzclose_w z_gzclose_w
+# define gzdirect z_gzdirect
+# define gzdopen z_gzdopen
+# define gzeof z_gzeof
+# define gzerror z_gzerror
+# define gzflush z_gzflush
+# define gzgetc z_gzgetc
+# define gzgetc_ z_gzgetc_
+# define gzgets z_gzgets
+# define gzoffset z_gzoffset
+# define gzoffset64 z_gzoffset64
+# define gzopen z_gzopen
+# define gzopen64 z_gzopen64
+# ifdef _WIN32
+# define gzopen_w z_gzopen_w
+# endif
+# define gzprintf z_gzprintf
+# define gzvprintf z_gzvprintf
+# define gzputc z_gzputc
+# define gzputs z_gzputs
+# define gzread z_gzread
+# define gzrewind z_gzrewind
+# define gzseek z_gzseek
+# define gzseek64 z_gzseek64
+# define gzsetparams z_gzsetparams
+# define gztell z_gztell
+# define gztell64 z_gztell64
+# define gzungetc z_gzungetc
+# define gzwrite z_gzwrite
+# endif
+# define inflate z_inflate
+# define inflateBack z_inflateBack
+# define inflateBackEnd z_inflateBackEnd
+# define inflateBackInit_ z_inflateBackInit_
+# define inflateCopy z_inflateCopy
+# define inflateEnd z_inflateEnd
+# define inflateGetHeader z_inflateGetHeader
+# define inflateInit2_ z_inflateInit2_
+# define inflateInit_ z_inflateInit_
+# define inflateMark z_inflateMark
+# define inflatePrime z_inflatePrime
+# define inflateReset z_inflateReset
+# define inflateReset2 z_inflateReset2
+# define inflateSetDictionary z_inflateSetDictionary
+# define inflateGetDictionary z_inflateGetDictionary
+# define inflateSync z_inflateSync
+# define inflateSyncPoint z_inflateSyncPoint
+# define inflateUndermine z_inflateUndermine
+# define inflateResetKeep z_inflateResetKeep
+# define inflate_copyright z_inflate_copyright
+# define inflate_fast z_inflate_fast
+# define inflate_table z_inflate_table
+# ifndef Z_SOLO
+# define uncompress z_uncompress
+# endif
+# define zError z_zError
+# ifndef Z_SOLO
+# define zcalloc z_zcalloc
+# define zcfree z_zcfree
+# endif
+# define zlibCompileFlags z_zlibCompileFlags
+# define zlibVersion z_zlibVersion
+
+/* all zlib typedefs in zlib.h and zconf.h */
+# define Byte z_Byte
+# define Bytef z_Bytef
+# define alloc_func z_alloc_func
+# define charf z_charf
+# define free_func z_free_func
+# ifndef Z_SOLO
+# define gzFile z_gzFile
+# endif
+# define gz_header z_gz_header
+# define gz_headerp z_gz_headerp
+# define in_func z_in_func
+# define intf z_intf
+# define out_func z_out_func
+# define uInt z_uInt
+# define uIntf z_uIntf
+# define uLong z_uLong
+# define uLongf z_uLongf
+# define voidp z_voidp
+# define voidpc z_voidpc
+# define voidpf z_voidpf
+
+/* all zlib structs in zlib.h and zconf.h */
+# define gz_header_s z_gz_header_s
+# define internal_state z_internal_state
+
+#endif
+
+#if defined(__MSDOS__) && !defined(MSDOS)
+# define MSDOS
+#endif
+#if (defined(OS_2) || defined(__OS2__)) && !defined(OS2)
+# define OS2
+#endif
+#if defined(_WINDOWS) && !defined(WINDOWS)
+# define WINDOWS
+#endif
+#if defined(_WIN32) || defined(_WIN32_WCE) || defined(__WIN32__)
+# ifndef WIN32
+# define WIN32
+# endif
+#endif
+#if (defined(MSDOS) || defined(OS2) || defined(WINDOWS)) && !defined(WIN32)
+# if !defined(__GNUC__) && !defined(__FLAT__) && !defined(__386__)
+# ifndef SYS16BIT
+# define SYS16BIT
+# endif
+# endif
+#endif
+
+/*
+ * Compile with -DMAXSEG_64K if the alloc function cannot allocate more
+ * than 64k bytes at a time (needed on systems with 16-bit int).
+ */
+#ifdef SYS16BIT
+# define MAXSEG_64K
+#endif
+#ifdef MSDOS
+# define UNALIGNED_OK
+#endif
+
+#ifdef __STDC_VERSION__
+# ifndef STDC
+# define STDC
+# endif
+# if __STDC_VERSION__ >= 199901L
+# ifndef STDC99
+# define STDC99
+# endif
+# endif
+#endif
+#if !defined(STDC) && (defined(__STDC__) || defined(__cplusplus))
+# define STDC
+#endif
+#if !defined(STDC) && (defined(__GNUC__) || defined(__BORLANDC__))
+# define STDC
+#endif
+#if !defined(STDC) && (defined(MSDOS) || defined(WINDOWS) || defined(WIN32))
+# define STDC
+#endif
+#if !defined(STDC) && (defined(OS2) || defined(__HOS_AIX__))
+# define STDC
+#endif
+
+#if defined(__OS400__) && !defined(STDC) /* iSeries (formerly AS/400). */
+# define STDC
+#endif
+
+#ifndef STDC
+# ifndef const /* cannot use !defined(STDC) && !defined(const) on Mac */
+# define const /* note: need a more gentle solution here */
+# endif
+#endif
+
+#if defined(ZLIB_CONST) && !defined(z_const)
+# define z_const const
+#else
+# define z_const
+#endif
+
+/* Some Mac compilers merge all .h files incorrectly: */
+#if defined(__MWERKS__)||defined(applec)||defined(THINK_C)||defined(__SC__)
+# define NO_DUMMY_DECL
+#endif
+
+/* Maximum value for memLevel in deflateInit2 */
+#ifndef MAX_MEM_LEVEL
+# ifdef MAXSEG_64K
+# define MAX_MEM_LEVEL 8
+# else
+# define MAX_MEM_LEVEL 9
+# endif
+#endif
+
+/* Maximum value for windowBits in deflateInit2 and inflateInit2.
+ * WARNING: reducing MAX_WBITS makes minigzip unable to extract .gz files
+ * created by gzip. (Files created by minigzip can still be extracted by
+ * gzip.)
+ */
+#ifndef MAX_WBITS
+# define MAX_WBITS 15 /* 32K LZ77 window */
+#endif
+
+/* The memory requirements for deflate are (in bytes):
+ (1 << (windowBits+2)) + (1 << (memLevel+9))
+ that is: 128K for windowBits=15 + 128K for memLevel = 8 (default values)
+ plus a few kilobytes for small objects. For example, if you want to reduce
+ the default memory requirements from 256K to 128K, compile with
+ make CFLAGS="-O -DMAX_WBITS=14 -DMAX_MEM_LEVEL=7"
+ Of course this will generally degrade compression (there's no free lunch).
+
+ The memory requirements for inflate are (in bytes) 1 << windowBits
+ that is, 32K for windowBits=15 (default value) plus a few kilobytes
+ for small objects.
+*/
+
+ /* Type declarations */
+
+#ifndef OF /* function prototypes */
+# ifdef STDC
+# define OF(args) args
+# else
+# define OF(args) ()
+# endif
+#endif
+
+#ifndef Z_ARG /* function prototypes for stdarg */
+# if defined(STDC) || defined(Z_HAVE_STDARG_H)
+# define Z_ARG(args) args
+# else
+# define Z_ARG(args) ()
+# endif
+#endif
+
+/* The following definitions for FAR are needed only for MSDOS mixed
+ * model programming (small or medium model with some far allocations).
+ * This was tested only with MSC; for other MSDOS compilers you may have
+ * to define NO_MEMCPY in zutil.h. If you don't need the mixed model,
+ * just define FAR to be empty.
+ */
+#ifdef SYS16BIT
+# if defined(M_I86SM) || defined(M_I86MM)
+ /* MSC small or medium model */
+# define SMALL_MEDIUM
+# ifdef _MSC_VER
+# define FAR _far
+# else
+# define FAR far
+# endif
+# endif
+# if (defined(__SMALL__) || defined(__MEDIUM__))
+ /* Turbo C small or medium model */
+# define SMALL_MEDIUM
+# ifdef __BORLANDC__
+# define FAR _far
+# else
+# define FAR far
+# endif
+# endif
+#endif
+
+#if defined(WINDOWS) || defined(WIN32)
+ /* If building or using zlib as a DLL, define ZLIB_DLL.
+ * This is not mandatory, but it offers a little performance increase.
+ */
+# ifdef ZLIB_DLL
+# if defined(WIN32) && (!defined(__BORLANDC__) || (__BORLANDC__ >= 0x500))
+# ifdef ZLIB_INTERNAL
+# define ZEXTERN extern __declspec(dllexport)
+# else
+# define ZEXTERN extern __declspec(dllimport)
+# endif
+# endif
+# endif /* ZLIB_DLL */
+ /* If building or using zlib with the WINAPI/WINAPIV calling convention,
+ * define ZLIB_WINAPI.
+ * Caution: the standard ZLIB1.DLL is NOT compiled using ZLIB_WINAPI.
+ */
+# ifdef ZLIB_WINAPI
+# ifdef FAR
+# undef FAR
+# endif
+# include <windows.h>
+ /* No need for _export, use ZLIB.DEF instead. */
+ /* For complete Windows compatibility, use WINAPI, not __stdcall. */
+# define ZEXPORT WINAPI
+# ifdef WIN32
+# define ZEXPORTVA WINAPIV
+# else
+# define ZEXPORTVA FAR CDECL
+# endif
+# endif
+#endif
+
+#if defined (__BEOS__)
+# ifdef ZLIB_DLL
+# ifdef ZLIB_INTERNAL
+# define ZEXPORT __declspec(dllexport)
+# define ZEXPORTVA __declspec(dllexport)
+# else
+# define ZEXPORT __declspec(dllimport)
+# define ZEXPORTVA __declspec(dllimport)
+# endif
+# endif
+#endif
+
+#ifndef ZEXTERN
+# define ZEXTERN extern
+#endif
+#ifndef ZEXPORT
+# define ZEXPORT
+#endif
+#ifndef ZEXPORTVA
+# define ZEXPORTVA
+#endif
+
+#ifndef FAR
+# define FAR
+#endif
+
+#if !defined(__MACTYPES__)
+typedef unsigned char Byte; /* 8 bits */
+#endif
+typedef unsigned int uInt; /* 16 bits or more */
+typedef unsigned long uLong; /* 32 bits or more */
+
+#ifdef SMALL_MEDIUM
+ /* Borland C/C++ and some old MSC versions ignore FAR inside typedef */
+# define Bytef Byte FAR
+#else
+ typedef Byte FAR Bytef;
+#endif
+typedef char FAR charf;
+typedef int FAR intf;
+typedef uInt FAR uIntf;
+typedef uLong FAR uLongf;
+
+#ifdef STDC
+ typedef void const *voidpc;
+ typedef void FAR *voidpf;
+ typedef void *voidp;
+#else
+ typedef Byte const *voidpc;
+ typedef Byte FAR *voidpf;
+ typedef Byte *voidp;
+#endif
+
+#if !defined(Z_U4) && !defined(Z_SOLO) && defined(STDC)
+# include <limits.h>
+# if (UINT_MAX == 0xffffffffUL)
+# define Z_U4 unsigned
+# elif (ULONG_MAX == 0xffffffffUL)
+# define Z_U4 unsigned long
+# elif (USHRT_MAX == 0xffffffffUL)
+# define Z_U4 unsigned short
+# endif
+#endif
+
+#ifdef Z_U4
+ typedef Z_U4 z_crc_t;
+#else
+ typedef unsigned long z_crc_t;
+#endif
+
+#ifdef HAVE_UNISTD_H /* may be set to #if 1 by ./configure */
+# define Z_HAVE_UNISTD_H
+#endif
+
+#ifdef HAVE_STDARG_H /* may be set to #if 1 by ./configure */
+# define Z_HAVE_STDARG_H
+#endif
+
+#ifdef STDC
+# ifndef Z_SOLO
+# include <sys/types.h> /* for off_t */
+# endif
+#endif
+
+#if defined(STDC) || defined(Z_HAVE_STDARG_H)
+# ifndef Z_SOLO
+# include <stdarg.h> /* for va_list */
+# endif
+#endif
+
+#ifdef _WIN32
+# ifndef Z_SOLO
+# include <stddef.h> /* for wchar_t */
+# endif
+#endif
+
+/* a little trick to accommodate both "#define _LARGEFILE64_SOURCE" and
+ * "#define _LARGEFILE64_SOURCE 1" as requesting 64-bit operations, (even
+ * though the former does not conform to the LFS document), but considering
+ * both "#undef _LARGEFILE64_SOURCE" and "#define _LARGEFILE64_SOURCE 0" as
+ * equivalently requesting no 64-bit operations
+ */
+#if defined(_LARGEFILE64_SOURCE) && -_LARGEFILE64_SOURCE - -1 == 1
+# undef _LARGEFILE64_SOURCE
+#endif
+
+#if defined(__WATCOMC__) && !defined(Z_HAVE_UNISTD_H)
+# define Z_HAVE_UNISTD_H
+#endif
+#ifndef Z_SOLO
+# if defined(Z_HAVE_UNISTD_H) || defined(_LARGEFILE64_SOURCE)
+# include <unistd.h> /* for SEEK_*, off_t, and _LFS64_LARGEFILE */
+# ifdef VMS
+# include <unixio.h> /* for off_t */
+# endif
+# ifndef z_off_t
+# define z_off_t off_t
+# endif
+# endif
+#endif
+
+#if defined(_LFS64_LARGEFILE) && _LFS64_LARGEFILE-0
+# define Z_LFS64
+#endif
+
+#if defined(_LARGEFILE64_SOURCE) && defined(Z_LFS64)
+# define Z_LARGE64
+#endif
+
+#if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS-0 == 64 && defined(Z_LFS64)
+# define Z_WANT64
+#endif
+
+#if !defined(SEEK_SET) && !defined(Z_SOLO)
+# define SEEK_SET 0 /* Seek from beginning of file. */
+# define SEEK_CUR 1 /* Seek from current position. */
+# define SEEK_END 2 /* Set file pointer to EOF plus "offset" */
+#endif
+
+#ifndef z_off_t
+# define z_off_t long
+#endif
+
+#if !defined(_WIN32) && defined(Z_LARGE64)
+# define z_off64_t off64_t
+#else
+# if defined(_WIN32) && !defined(__GNUC__) && !defined(Z_SOLO)
+# define z_off64_t __int64
+# else
+# define z_off64_t z_off_t
+# endif
+#endif
+
+/* MVS linker does not support external names larger than 8 bytes */
+#if defined(__MVS__)
+ #pragma map(deflateInit_,"DEIN")
+ #pragma map(deflateInit2_,"DEIN2")
+ #pragma map(deflateEnd,"DEEND")
+ #pragma map(deflateBound,"DEBND")
+ #pragma map(inflateInit_,"ININ")
+ #pragma map(inflateInit2_,"ININ2")
+ #pragma map(inflateEnd,"INEND")
+ #pragma map(inflateSync,"INSY")
+ #pragma map(inflateSetDictionary,"INSEDI")
+ #pragma map(compressBound,"CMBND")
+ #pragma map(inflate_table,"INTABL")
+ #pragma map(inflate_fast,"INFA")
+ #pragma map(inflate_copyright,"INCOPY")
+#endif
+
+#endif /* ZCONF_H */
diff --git a/ml/dlib/dlib/external/zlib/zlib.h b/ml/dlib/dlib/external/zlib/zlib.h
new file mode 100644
index 000000000..3e0c7672a
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/zlib.h
@@ -0,0 +1,1768 @@
+/* zlib.h -- interface of the 'zlib' general purpose compression library
+ version 1.2.8, April 28th, 2013
+
+ Copyright (C) 1995-2013 Jean-loup Gailly and Mark Adler
+
+ This software is provided 'as-is', without any express or implied
+ warranty. In no event will the authors be held liable for any damages
+ arising from the use of this software.
+
+ Permission is granted to anyone to use this software for any purpose,
+ including commercial applications, and to alter it and redistribute it
+ freely, subject to the following restrictions:
+
+ 1. The origin of this software must not be misrepresented; you must not
+ claim that you wrote the original software. If you use this software
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original software.
+ 3. This notice may not be removed or altered from any source distribution.
+
+ Jean-loup Gailly Mark Adler
+ jloup@gzip.org madler@alumni.caltech.edu
+
+
+ The data format used by the zlib library is described by RFCs (Request for
+ Comments) 1950 to 1952 in the files http://tools.ietf.org/html/rfc1950
+ (zlib format), rfc1951 (deflate format) and rfc1952 (gzip format).
+*/
+
+#ifndef ZLIB_H
+#define ZLIB_H
+
+#include "zconf.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define ZLIB_VERSION "1.2.8"
+#define ZLIB_VERNUM 0x1280
+#define ZLIB_VER_MAJOR 1
+#define ZLIB_VER_MINOR 2
+#define ZLIB_VER_REVISION 8
+#define ZLIB_VER_SUBREVISION 0
+
+/*
+ The 'zlib' compression library provides in-memory compression and
+ decompression functions, including integrity checks of the uncompressed data.
+ This version of the library supports only one compression method (deflation)
+ but other algorithms will be added later and will have the same stream
+ interface.
+
+ Compression can be done in a single step if the buffers are large enough,
+ or can be done by repeated calls of the compression function. In the latter
+ case, the application must provide more input and/or consume the output
+ (providing more output space) before each call.
+
+ The compressed data format used by default by the in-memory functions is
+ the zlib format, which is a zlib wrapper documented in RFC 1950, wrapped
+ around a deflate stream, which is itself documented in RFC 1951.
+
+ The library also supports reading and writing files in gzip (.gz) format
+ with an interface similar to that of stdio using the functions that start
+ with "gz". The gzip format is different from the zlib format. gzip is a
+ gzip wrapper, documented in RFC 1952, wrapped around a deflate stream.
+
+ This library can optionally read and write gzip streams in memory as well.
+
+ The zlib format was designed to be compact and fast for use in memory
+ and on communications channels. The gzip format was designed for single-
+ file compression on file systems, has a larger header than zlib to maintain
+ directory information, and uses a different, slower check method than zlib.
+
+ The library does not install any signal handler. The decoder checks
+ the consistency of the compressed data, so the library should never crash
+ even in case of corrupted input.
+*/
+
+typedef voidpf (*alloc_func) OF((voidpf opaque, uInt items, uInt size));
+typedef void (*free_func) OF((voidpf opaque, voidpf address));
+
+struct internal_state;
+
+typedef struct z_stream_s {
+ z_const Bytef *next_in; /* next input byte */
+ uInt avail_in; /* number of bytes available at next_in */
+ uLong total_in; /* total number of input bytes read so far */
+
+ Bytef *next_out; /* next output byte should be put there */
+ uInt avail_out; /* remaining free space at next_out */
+ uLong total_out; /* total number of bytes output so far */
+
+ z_const char *msg; /* last error message, NULL if no error */
+ struct internal_state FAR *state; /* not visible by applications */
+
+ alloc_func zalloc; /* used to allocate the internal state */
+ free_func zfree; /* used to free the internal state */
+ voidpf opaque; /* private data object passed to zalloc and zfree */
+
+ int data_type; /* best guess about the data type: binary or text */
+ uLong adler; /* adler32 value of the uncompressed data */
+ uLong reserved; /* reserved for future use */
+} z_stream;
+
+typedef z_stream FAR *z_streamp;
+
+/*
+ gzip header information passed to and from zlib routines. See RFC 1952
+ for more details on the meanings of these fields.
+*/
+typedef struct gz_header_s {
+ int text; /* true if compressed data believed to be text */
+ uLong time; /* modification time */
+ int xflags; /* extra flags (not used when writing a gzip file) */
+ int os; /* operating system */
+ Bytef *extra; /* pointer to extra field or Z_NULL if none */
+ uInt extra_len; /* extra field length (valid if extra != Z_NULL) */
+ uInt extra_max; /* space at extra (only when reading header) */
+ Bytef *name; /* pointer to zero-terminated file name or Z_NULL */
+ uInt name_max; /* space at name (only when reading header) */
+ Bytef *comment; /* pointer to zero-terminated comment or Z_NULL */
+ uInt comm_max; /* space at comment (only when reading header) */
+ int hcrc; /* true if there was or will be a header crc */
+ int done; /* true when done reading gzip header (not used
+ when writing a gzip file) */
+} gz_header;
+
+typedef gz_header FAR *gz_headerp;
+
+/*
+ The application must update next_in and avail_in when avail_in has dropped
+ to zero. It must update next_out and avail_out when avail_out has dropped
+ to zero. The application must initialize zalloc, zfree and opaque before
+ calling the init function. All other fields are set by the compression
+ library and must not be updated by the application.
+
+ The opaque value provided by the application will be passed as the first
+ parameter for calls of zalloc and zfree. This can be useful for custom
+ memory management. The compression library attaches no meaning to the
+ opaque value.
+
+ zalloc must return Z_NULL if there is not enough memory for the object.
+ If zlib is used in a multi-threaded application, zalloc and zfree must be
+ thread safe.
+
+ On 16-bit systems, the functions zalloc and zfree must be able to allocate
+ exactly 65536 bytes, but will not be required to allocate more than this if
+ the symbol MAXSEG_64K is defined (see zconf.h). WARNING: On MSDOS, pointers
+ returned by zalloc for objects of exactly 65536 bytes *must* have their
+ offset normalized to zero. The default allocation function provided by this
+ library ensures this (see zutil.c). To reduce memory requirements and avoid
+ any allocation of 64K objects, at the expense of compression ratio, compile
+ the library with -DMAX_WBITS=14 (see zconf.h).
+
+ The fields total_in and total_out can be used for statistics or progress
+ reports. After compression, total_in holds the total size of the
+ uncompressed data and may be saved for use in the decompressor (particularly
+ if the decompressor wants to decompress everything in a single step).
+*/
+
+ /* constants */
+
+#define Z_NO_FLUSH 0
+#define Z_PARTIAL_FLUSH 1
+#define Z_SYNC_FLUSH 2
+#define Z_FULL_FLUSH 3
+#define Z_FINISH 4
+#define Z_BLOCK 5
+#define Z_TREES 6
+/* Allowed flush values; see deflate() and inflate() below for details */
+
+#define Z_OK 0
+#define Z_STREAM_END 1
+#define Z_NEED_DICT 2
+#define Z_ERRNO (-1)
+#define Z_STREAM_ERROR (-2)
+#define Z_DATA_ERROR (-3)
+#define Z_MEM_ERROR (-4)
+#define Z_BUF_ERROR (-5)
+#define Z_VERSION_ERROR (-6)
+/* Return codes for the compression/decompression functions. Negative values
+ * are errors, positive values are used for special but normal events.
+ */
+
+#define Z_NO_COMPRESSION 0
+#define Z_BEST_SPEED 1
+#define Z_BEST_COMPRESSION 9
+#define Z_DEFAULT_COMPRESSION (-1)
+/* compression levels */
+
+#define Z_FILTERED 1
+#define Z_HUFFMAN_ONLY 2
+#define Z_RLE 3
+#define Z_FIXED 4
+#define Z_DEFAULT_STRATEGY 0
+/* compression strategy; see deflateInit2() below for details */
+
+#define Z_BINARY 0
+#define Z_TEXT 1
+#define Z_ASCII Z_TEXT /* for compatibility with 1.2.2 and earlier */
+#define Z_UNKNOWN 2
+/* Possible values of the data_type field (though see inflate()) */
+
+#define Z_DEFLATED 8
+/* The deflate compression method (the only one supported in this version) */
+
+#define Z_NULL 0 /* for initializing zalloc, zfree, opaque */
+
+#define zlib_version zlibVersion()
+/* for compatibility with versions < 1.0.2 */
+
+
+ /* basic functions */
+
+ZEXTERN const char * ZEXPORT zlibVersion OF((void));
+/* The application can compare zlibVersion and ZLIB_VERSION for consistency.
+ If the first character differs, the library code actually used is not
+ compatible with the zlib.h header file used by the application. This check
+ is automatically made by deflateInit and inflateInit.
+ */
+
+/*
+ZEXTERN int ZEXPORT deflateInit OF((z_streamp strm, int level));
+
+ Initializes the internal stream state for compression. The fields
+ zalloc, zfree and opaque must be initialized before by the caller. If
+ zalloc and zfree are set to Z_NULL, deflateInit updates them to use default
+ allocation functions.
+
+ The compression level must be Z_DEFAULT_COMPRESSION, or between 0 and 9:
+ 1 gives best speed, 9 gives best compression, 0 gives no compression at all
+ (the input data is simply copied a block at a time). Z_DEFAULT_COMPRESSION
+ requests a default compromise between speed and compression (currently
+ equivalent to level 6).
+
+ deflateInit returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_STREAM_ERROR if level is not a valid compression level, or
+ Z_VERSION_ERROR if the zlib library version (zlib_version) is incompatible
+ with the version assumed by the caller (ZLIB_VERSION). msg is set to null
+ if there is no error message. deflateInit does not perform any compression:
+ this will be done by deflate().
+*/
+
+
+ZEXTERN int ZEXPORT deflate OF((z_streamp strm, int flush));
+/*
+ deflate compresses as much data as possible, and stops when the input
+ buffer becomes empty or the output buffer becomes full. It may introduce
+ some output latency (reading input without producing any output) except when
+ forced to flush.
+
+ The detailed semantics are as follows. deflate performs one or both of the
+ following actions:
+
+ - Compress more input starting at next_in and update next_in and avail_in
+ accordingly. If not all input can be processed (because there is not
+ enough room in the output buffer), next_in and avail_in are updated and
+ processing will resume at this point for the next call of deflate().
+
+ - Provide more output starting at next_out and update next_out and avail_out
+ accordingly. This action is forced if the parameter flush is non zero.
+ Forcing flush frequently degrades the compression ratio, so this parameter
+ should be set only when necessary (in interactive applications). Some
+ output may be provided even if flush is not set.
+
+ Before the call of deflate(), the application should ensure that at least
+ one of the actions is possible, by providing more input and/or consuming more
+ output, and updating avail_in or avail_out accordingly; avail_out should
+ never be zero before the call. The application can consume the compressed
+ output when it wants, for example when the output buffer is full (avail_out
+ == 0), or after each call of deflate(). If deflate returns Z_OK and with
+ zero avail_out, it must be called again after making room in the output
+ buffer because there might be more output pending.
+
+ Normally the parameter flush is set to Z_NO_FLUSH, which allows deflate to
+ decide how much data to accumulate before producing output, in order to
+ maximize compression.
+
+ If the parameter flush is set to Z_SYNC_FLUSH, all pending output is
+ flushed to the output buffer and the output is aligned on a byte boundary, so
+ that the decompressor can get all input data available so far. (In
+ particular avail_in is zero after the call if enough output space has been
+ provided before the call.) Flushing may degrade compression for some
+ compression algorithms and so it should be used only when necessary. This
+ completes the current deflate block and follows it with an empty stored block
+ that is three bits plus filler bits to the next byte, followed by four bytes
+ (00 00 ff ff).
+
+ If flush is set to Z_PARTIAL_FLUSH, all pending output is flushed to the
+ output buffer, but the output is not aligned to a byte boundary. All of the
+ input data so far will be available to the decompressor, as for Z_SYNC_FLUSH.
+ This completes the current deflate block and follows it with an empty fixed
+ codes block that is 10 bits long. This assures that enough bytes are output
+ in order for the decompressor to finish the block before the empty fixed code
+ block.
+
+ If flush is set to Z_BLOCK, a deflate block is completed and emitted, as
+ for Z_SYNC_FLUSH, but the output is not aligned on a byte boundary, and up to
+ seven bits of the current block are held to be written as the next byte after
+ the next deflate block is completed. In this case, the decompressor may not
+ be provided enough bits at this point in order to complete decompression of
+ the data provided so far to the compressor. It may need to wait for the next
+ block to be emitted. This is for advanced applications that need to control
+ the emission of deflate blocks.
+
+ If flush is set to Z_FULL_FLUSH, all output is flushed as with
+ Z_SYNC_FLUSH, and the compression state is reset so that decompression can
+ restart from this point if previous compressed data has been damaged or if
+ random access is desired. Using Z_FULL_FLUSH too often can seriously degrade
+ compression.
+
+ If deflate returns with avail_out == 0, this function must be called again
+ with the same value of the flush parameter and more output space (updated
+ avail_out), until the flush is complete (deflate returns with non-zero
+ avail_out). In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that
+ avail_out is greater than six to avoid repeated flush markers due to
+ avail_out == 0 on return.
+
+ If the parameter flush is set to Z_FINISH, pending input is processed,
+ pending output is flushed and deflate returns with Z_STREAM_END if there was
+ enough output space; if deflate returns with Z_OK, this function must be
+ called again with Z_FINISH and more output space (updated avail_out) but no
+ more input data, until it returns with Z_STREAM_END or an error. After
+ deflate has returned Z_STREAM_END, the only possible operations on the stream
+ are deflateReset or deflateEnd.
+
+ Z_FINISH can be used immediately after deflateInit if all the compression
+ is to be done in a single step. In this case, avail_out must be at least the
+ value returned by deflateBound (see below). Then deflate is guaranteed to
+ return Z_STREAM_END. If not enough output space is provided, deflate will
+ not return Z_STREAM_END, and it must be called again as described above.
+
+ deflate() sets strm->adler to the adler32 checksum of all input read
+ so far (that is, total_in bytes).
+
+ deflate() may update strm->data_type if it can make a good guess about
+ the input data type (Z_BINARY or Z_TEXT). In doubt, the data is considered
+ binary. This field is only for information purposes and does not affect the
+ compression algorithm in any manner.
+
+ deflate() returns Z_OK if some progress has been made (more input
+ processed or more output produced), Z_STREAM_END if all input has been
+ consumed and all output has been produced (only when flush is set to
+ Z_FINISH), Z_STREAM_ERROR if the stream state was inconsistent (for example
+ if next_in or next_out was Z_NULL), Z_BUF_ERROR if no progress is possible
+ (for example avail_in or avail_out was zero). Note that Z_BUF_ERROR is not
+ fatal, and deflate() can be called again with more input and more output
+ space to continue compressing.
+*/
+
+
+ZEXTERN int ZEXPORT deflateEnd OF((z_streamp strm));
+/*
+ All dynamically allocated data structures for this stream are freed.
+ This function discards any unprocessed input and does not flush any pending
+ output.
+
+ deflateEnd returns Z_OK if success, Z_STREAM_ERROR if the
+ stream state was inconsistent, Z_DATA_ERROR if the stream was freed
+ prematurely (some input or output was discarded). In the error case, msg
+ may be set but then points to a static string (which must not be
+ deallocated).
+*/
+
+
+/*
+ZEXTERN int ZEXPORT inflateInit OF((z_streamp strm));
+
+ Initializes the internal stream state for decompression. The fields
+ next_in, avail_in, zalloc, zfree and opaque must be initialized before by
+ the caller. If next_in is not Z_NULL and avail_in is large enough (the
+ exact value depends on the compression method), inflateInit determines the
+ compression method from the zlib header and allocates all data structures
+ accordingly; otherwise the allocation will be deferred to the first call of
+ inflate. If zalloc and zfree are set to Z_NULL, inflateInit updates them to
+ use default allocation functions.
+
+ inflateInit returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_VERSION_ERROR if the zlib library version is incompatible with the
+ version assumed by the caller, or Z_STREAM_ERROR if the parameters are
+ invalid, such as a null pointer to the structure. msg is set to null if
+ there is no error message. inflateInit does not perform any decompression
+ apart from possibly reading the zlib header if present: actual decompression
+ will be done by inflate(). (So next_in and avail_in may be modified, but
+ next_out and avail_out are unused and unchanged.) The current implementation
+ of inflateInit() does not process any header information -- that is deferred
+ until inflate() is called.
+*/
+
+
+ZEXTERN int ZEXPORT inflate OF((z_streamp strm, int flush));
+/*
+ inflate decompresses as much data as possible, and stops when the input
+ buffer becomes empty or the output buffer becomes full. It may introduce
+ some output latency (reading input without producing any output) except when
+ forced to flush.
+
+ The detailed semantics are as follows. inflate performs one or both of the
+ following actions:
+
+ - Decompress more input starting at next_in and update next_in and avail_in
+ accordingly. If not all input can be processed (because there is not
+ enough room in the output buffer), next_in is updated and processing will
+ resume at this point for the next call of inflate().
+
+ - Provide more output starting at next_out and update next_out and avail_out
+ accordingly. inflate() provides as much output as possible, until there is
+ no more input data or no more space in the output buffer (see below about
+ the flush parameter).
+
+ Before the call of inflate(), the application should ensure that at least
+ one of the actions is possible, by providing more input and/or consuming more
+ output, and updating the next_* and avail_* values accordingly. The
+ application can consume the uncompressed output when it wants, for example
+ when the output buffer is full (avail_out == 0), or after each call of
+ inflate(). If inflate returns Z_OK and with zero avail_out, it must be
+ called again after making room in the output buffer because there might be
+ more output pending.
+
+ The flush parameter of inflate() can be Z_NO_FLUSH, Z_SYNC_FLUSH, Z_FINISH,
+ Z_BLOCK, or Z_TREES. Z_SYNC_FLUSH requests that inflate() flush as much
+ output as possible to the output buffer. Z_BLOCK requests that inflate()
+ stop if and when it gets to the next deflate block boundary. When decoding
+ the zlib or gzip format, this will cause inflate() to return immediately
+ after the header and before the first block. When doing a raw inflate,
+ inflate() will go ahead and process the first block, and will return when it
+ gets to the end of that block, or when it runs out of data.
+
+ The Z_BLOCK option assists in appending to or combining deflate streams.
+ Also to assist in this, on return inflate() will set strm->data_type to the
+ number of unused bits in the last byte taken from strm->next_in, plus 64 if
+ inflate() is currently decoding the last block in the deflate stream, plus
+ 128 if inflate() returned immediately after decoding an end-of-block code or
+ decoding the complete header up to just before the first byte of the deflate
+ stream. The end-of-block will not be indicated until all of the uncompressed
+ data from that block has been written to strm->next_out. The number of
+ unused bits may in general be greater than seven, except when bit 7 of
+ data_type is set, in which case the number of unused bits will be less than
+ eight. data_type is set as noted here every time inflate() returns for all
+ flush options, and so can be used to determine the amount of currently
+ consumed input in bits.
+
+ The Z_TREES option behaves as Z_BLOCK does, but it also returns when the
+ end of each deflate block header is reached, before any actual data in that
+ block is decoded. This allows the caller to determine the length of the
+ deflate block header for later use in random access within a deflate block.
+ 256 is added to the value of strm->data_type when inflate() returns
+ immediately after reaching the end of the deflate block header.
+
+ inflate() should normally be called until it returns Z_STREAM_END or an
+ error. However if all decompression is to be performed in a single step (a
+ single call of inflate), the parameter flush should be set to Z_FINISH. In
+ this case all pending input is processed and all pending output is flushed;
+ avail_out must be large enough to hold all of the uncompressed data for the
+ operation to complete. (The size of the uncompressed data may have been
+ saved by the compressor for this purpose.) The use of Z_FINISH is not
+ required to perform an inflation in one step. However it may be used to
+ inform inflate that a faster approach can be used for the single inflate()
+ call. Z_FINISH also informs inflate to not maintain a sliding window if the
+ stream completes, which reduces inflate's memory footprint. If the stream
+ does not complete, either because not all of the stream is provided or not
+ enough output space is provided, then a sliding window will be allocated and
+ inflate() can be called again to continue the operation as if Z_NO_FLUSH had
+ been used.
+
+ In this implementation, inflate() always flushes as much output as
+ possible to the output buffer, and always uses the faster approach on the
+ first call. So the effects of the flush parameter in this implementation are
+ on the return value of inflate() as noted below, when inflate() returns early
+ when Z_BLOCK or Z_TREES is used, and when inflate() avoids the allocation of
+ memory for a sliding window when Z_FINISH is used.
+
+ If a preset dictionary is needed after this call (see inflateSetDictionary
+ below), inflate sets strm->adler to the Adler-32 checksum of the dictionary
+ chosen by the compressor and returns Z_NEED_DICT; otherwise it sets
+ strm->adler to the Adler-32 checksum of all output produced so far (that is,
+ total_out bytes) and returns Z_OK, Z_STREAM_END or an error code as described
+ below. At the end of the stream, inflate() checks that its computed adler32
+ checksum is equal to that saved by the compressor and returns Z_STREAM_END
+ only if the checksum is correct.
+
+ inflate() can decompress and check either zlib-wrapped or gzip-wrapped
+ deflate data. The header type is detected automatically, if requested when
+ initializing with inflateInit2(). Any information contained in the gzip
+ header is not retained, so applications that need that information should
+ instead use raw inflate, see inflateInit2() below, or inflateBack() and
+ perform their own processing of the gzip header and trailer. When processing
+ gzip-wrapped deflate data, strm->adler32 is set to the CRC-32 of the output
+ producted so far. The CRC-32 is checked against the gzip trailer.
+
+ inflate() returns Z_OK if some progress has been made (more input processed
+ or more output produced), Z_STREAM_END if the end of the compressed data has
+ been reached and all uncompressed output has been produced, Z_NEED_DICT if a
+ preset dictionary is needed at this point, Z_DATA_ERROR if the input data was
+ corrupted (input stream not conforming to the zlib format or incorrect check
+ value), Z_STREAM_ERROR if the stream structure was inconsistent (for example
+ next_in or next_out was Z_NULL), Z_MEM_ERROR if there was not enough memory,
+ Z_BUF_ERROR if no progress is possible or if there was not enough room in the
+ output buffer when Z_FINISH is used. Note that Z_BUF_ERROR is not fatal, and
+ inflate() can be called again with more input and more output space to
+ continue decompressing. If Z_DATA_ERROR is returned, the application may
+ then call inflateSync() to look for a good compression block if a partial
+ recovery of the data is desired.
+*/
+
+
+ZEXTERN int ZEXPORT inflateEnd OF((z_streamp strm));
+/*
+ All dynamically allocated data structures for this stream are freed.
+ This function discards any unprocessed input and does not flush any pending
+ output.
+
+ inflateEnd returns Z_OK if success, Z_STREAM_ERROR if the stream state
+ was inconsistent. In the error case, msg may be set but then points to a
+ static string (which must not be deallocated).
+*/
+
+
+ /* Advanced functions */
+
+/*
+ The following functions are needed only in some special applications.
+*/
+
+/*
+ZEXTERN int ZEXPORT deflateInit2 OF((z_streamp strm,
+ int level,
+ int method,
+ int windowBits,
+ int memLevel,
+ int strategy));
+
+ This is another version of deflateInit with more compression options. The
+ fields next_in, zalloc, zfree and opaque must be initialized before by the
+ caller.
+
+ The method parameter is the compression method. It must be Z_DEFLATED in
+ this version of the library.
+
+ The windowBits parameter is the base two logarithm of the window size
+ (the size of the history buffer). It should be in the range 8..15 for this
+ version of the library. Larger values of this parameter result in better
+ compression at the expense of memory usage. The default value is 15 if
+ deflateInit is used instead.
+
+ windowBits can also be -8..-15 for raw deflate. In this case, -windowBits
+ determines the window size. deflate() will then generate raw deflate data
+ with no zlib header or trailer, and will not compute an adler32 check value.
+
+ windowBits can also be greater than 15 for optional gzip encoding. Add
+ 16 to windowBits to write a simple gzip header and trailer around the
+ compressed data instead of a zlib wrapper. The gzip header will have no
+ file name, no extra data, no comment, no modification time (set to zero), no
+ header crc, and the operating system will be set to 255 (unknown). If a
+ gzip stream is being written, strm->adler is a crc32 instead of an adler32.
+
+ The memLevel parameter specifies how much memory should be allocated
+ for the internal compression state. memLevel=1 uses minimum memory but is
+ slow and reduces compression ratio; memLevel=9 uses maximum memory for
+ optimal speed. The default value is 8. See zconf.h for total memory usage
+ as a function of windowBits and memLevel.
+
+ The strategy parameter is used to tune the compression algorithm. Use the
+ value Z_DEFAULT_STRATEGY for normal data, Z_FILTERED for data produced by a
+ filter (or predictor), Z_HUFFMAN_ONLY to force Huffman encoding only (no
+ string match), or Z_RLE to limit match distances to one (run-length
+ encoding). Filtered data consists mostly of small values with a somewhat
+ random distribution. In this case, the compression algorithm is tuned to
+ compress them better. The effect of Z_FILTERED is to force more Huffman
+ coding and less string matching; it is somewhat intermediate between
+ Z_DEFAULT_STRATEGY and Z_HUFFMAN_ONLY. Z_RLE is designed to be almost as
+ fast as Z_HUFFMAN_ONLY, but give better compression for PNG image data. The
+ strategy parameter only affects the compression ratio but not the
+ correctness of the compressed output even if it is not set appropriately.
+ Z_FIXED prevents the use of dynamic Huffman codes, allowing for a simpler
+ decoder for special applications.
+
+ deflateInit2 returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_STREAM_ERROR if any parameter is invalid (such as an invalid
+ method), or Z_VERSION_ERROR if the zlib library version (zlib_version) is
+ incompatible with the version assumed by the caller (ZLIB_VERSION). msg is
+ set to null if there is no error message. deflateInit2 does not perform any
+ compression: this will be done by deflate().
+*/
+
+ZEXTERN int ZEXPORT deflateSetDictionary OF((z_streamp strm,
+ const Bytef *dictionary,
+ uInt dictLength));
+/*
+ Initializes the compression dictionary from the given byte sequence
+ without producing any compressed output. When using the zlib format, this
+ function must be called immediately after deflateInit, deflateInit2 or
+ deflateReset, and before any call of deflate. When doing raw deflate, this
+ function must be called either before any call of deflate, or immediately
+ after the completion of a deflate block, i.e. after all input has been
+ consumed and all output has been delivered when using any of the flush
+ options Z_BLOCK, Z_PARTIAL_FLUSH, Z_SYNC_FLUSH, or Z_FULL_FLUSH. The
+ compressor and decompressor must use exactly the same dictionary (see
+ inflateSetDictionary).
+
+ The dictionary should consist of strings (byte sequences) that are likely
+ to be encountered later in the data to be compressed, with the most commonly
+ used strings preferably put towards the end of the dictionary. Using a
+ dictionary is most useful when the data to be compressed is short and can be
+ predicted with good accuracy; the data can then be compressed better than
+ with the default empty dictionary.
+
+ Depending on the size of the compression data structures selected by
+ deflateInit or deflateInit2, a part of the dictionary may in effect be
+ discarded, for example if the dictionary is larger than the window size
+ provided in deflateInit or deflateInit2. Thus the strings most likely to be
+ useful should be put at the end of the dictionary, not at the front. In
+ addition, the current implementation of deflate will use at most the window
+ size minus 262 bytes of the provided dictionary.
+
+ Upon return of this function, strm->adler is set to the adler32 value
+ of the dictionary; the decompressor may later use this value to determine
+ which dictionary has been used by the compressor. (The adler32 value
+ applies to the whole dictionary even if only a subset of the dictionary is
+ actually used by the compressor.) If a raw deflate was requested, then the
+ adler32 value is not computed and strm->adler is not set.
+
+ deflateSetDictionary returns Z_OK if success, or Z_STREAM_ERROR if a
+ parameter is invalid (e.g. dictionary being Z_NULL) or the stream state is
+ inconsistent (for example if deflate has already been called for this stream
+ or if not at a block boundary for raw deflate). deflateSetDictionary does
+ not perform any compression: this will be done by deflate().
+*/
+
+ZEXTERN int ZEXPORT deflateCopy OF((z_streamp dest,
+ z_streamp source));
+/*
+ Sets the destination stream as a complete copy of the source stream.
+
+ This function can be useful when several compression strategies will be
+ tried, for example when there are several ways of pre-processing the input
+ data with a filter. The streams that will be discarded should then be freed
+ by calling deflateEnd. Note that deflateCopy duplicates the internal
+ compression state which can be quite large, so this strategy is slow and can
+ consume lots of memory.
+
+ deflateCopy returns Z_OK if success, Z_MEM_ERROR if there was not
+ enough memory, Z_STREAM_ERROR if the source stream state was inconsistent
+ (such as zalloc being Z_NULL). msg is left unchanged in both source and
+ destination.
+*/
+
+ZEXTERN int ZEXPORT deflateReset OF((z_streamp strm));
+/*
+ This function is equivalent to deflateEnd followed by deflateInit,
+ but does not free and reallocate all the internal compression state. The
+ stream will keep the same compression level and any other attributes that
+ may have been set by deflateInit2.
+
+ deflateReset returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent (such as zalloc or state being Z_NULL).
+*/
+
+ZEXTERN int ZEXPORT deflateParams OF((z_streamp strm,
+ int level,
+ int strategy));
+/*
+ Dynamically update the compression level and compression strategy. The
+ interpretation of level and strategy is as in deflateInit2. This can be
+ used to switch between compression and straight copy of the input data, or
+ to switch to a different kind of input data requiring a different strategy.
+ If the compression level is changed, the input available so far is
+ compressed with the old level (and may be flushed); the new level will take
+ effect only at the next call of deflate().
+
+ Before the call of deflateParams, the stream state must be set as for
+ a call of deflate(), since the currently available input may have to be
+ compressed and flushed. In particular, strm->avail_out must be non-zero.
+
+ deflateParams returns Z_OK if success, Z_STREAM_ERROR if the source
+ stream state was inconsistent or if a parameter was invalid, Z_BUF_ERROR if
+ strm->avail_out was zero.
+*/
+
+ZEXTERN int ZEXPORT deflateTune OF((z_streamp strm,
+ int good_length,
+ int max_lazy,
+ int nice_length,
+ int max_chain));
+/*
+ Fine tune deflate's internal compression parameters. This should only be
+ used by someone who understands the algorithm used by zlib's deflate for
+ searching for the best matching string, and even then only by the most
+ fanatic optimizer trying to squeeze out the last compressed bit for their
+ specific input data. Read the deflate.c source code for the meaning of the
+ max_lazy, good_length, nice_length, and max_chain parameters.
+
+ deflateTune() can be called after deflateInit() or deflateInit2(), and
+ returns Z_OK on success, or Z_STREAM_ERROR for an invalid deflate stream.
+ */
+
+ZEXTERN uLong ZEXPORT deflateBound OF((z_streamp strm,
+ uLong sourceLen));
+/*
+ deflateBound() returns an upper bound on the compressed size after
+ deflation of sourceLen bytes. It must be called after deflateInit() or
+ deflateInit2(), and after deflateSetHeader(), if used. This would be used
+ to allocate an output buffer for deflation in a single pass, and so would be
+ called before deflate(). If that first deflate() call is provided the
+ sourceLen input bytes, an output buffer allocated to the size returned by
+ deflateBound(), and the flush value Z_FINISH, then deflate() is guaranteed
+ to return Z_STREAM_END. Note that it is possible for the compressed size to
+ be larger than the value returned by deflateBound() if flush options other
+ than Z_FINISH or Z_NO_FLUSH are used.
+*/
+
+ZEXTERN int ZEXPORT deflatePending OF((z_streamp strm,
+ unsigned *pending,
+ int *bits));
+/*
+ deflatePending() returns the number of bytes and bits of output that have
+ been generated, but not yet provided in the available output. The bytes not
+ provided would be due to the available output space having being consumed.
+ The number of bits of output not provided are between 0 and 7, where they
+ await more bits to join them in order to fill out a full byte. If pending
+ or bits are Z_NULL, then those values are not set.
+
+ deflatePending returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent.
+ */
+
+ZEXTERN int ZEXPORT deflatePrime OF((z_streamp strm,
+ int bits,
+ int value));
+/*
+ deflatePrime() inserts bits in the deflate output stream. The intent
+ is that this function is used to start off the deflate output with the bits
+ leftover from a previous deflate stream when appending to it. As such, this
+ function can only be used for raw deflate, and must be used before the first
+ deflate() call after a deflateInit2() or deflateReset(). bits must be less
+ than or equal to 16, and that many of the least significant bits of value
+ will be inserted in the output.
+
+ deflatePrime returns Z_OK if success, Z_BUF_ERROR if there was not enough
+ room in the internal buffer to insert the bits, or Z_STREAM_ERROR if the
+ source stream state was inconsistent.
+*/
+
+ZEXTERN int ZEXPORT deflateSetHeader OF((z_streamp strm,
+ gz_headerp head));
+/*
+ deflateSetHeader() provides gzip header information for when a gzip
+ stream is requested by deflateInit2(). deflateSetHeader() may be called
+ after deflateInit2() or deflateReset() and before the first call of
+ deflate(). The text, time, os, extra field, name, and comment information
+ in the provided gz_header structure are written to the gzip header (xflag is
+ ignored -- the extra flags are set according to the compression level). The
+ caller must assure that, if not Z_NULL, name and comment are terminated with
+ a zero byte, and that if extra is not Z_NULL, that extra_len bytes are
+ available there. If hcrc is true, a gzip header crc is included. Note that
+ the current versions of the command-line version of gzip (up through version
+ 1.3.x) do not support header crc's, and will report that it is a "multi-part
+ gzip file" and give up.
+
+ If deflateSetHeader is not used, the default gzip header has text false,
+ the time set to zero, and os set to 255, with no extra, name, or comment
+ fields. The gzip header is returned to the default state by deflateReset().
+
+ deflateSetHeader returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent.
+*/
+
+/*
+ZEXTERN int ZEXPORT inflateInit2 OF((z_streamp strm,
+ int windowBits));
+
+ This is another version of inflateInit with an extra parameter. The
+ fields next_in, avail_in, zalloc, zfree and opaque must be initialized
+ before by the caller.
+
+ The windowBits parameter is the base two logarithm of the maximum window
+ size (the size of the history buffer). It should be in the range 8..15 for
+ this version of the library. The default value is 15 if inflateInit is used
+ instead. windowBits must be greater than or equal to the windowBits value
+ provided to deflateInit2() while compressing, or it must be equal to 15 if
+ deflateInit2() was not used. If a compressed stream with a larger window
+ size is given as input, inflate() will return with the error code
+ Z_DATA_ERROR instead of trying to allocate a larger window.
+
+ windowBits can also be zero to request that inflate use the window size in
+ the zlib header of the compressed stream.
+
+ windowBits can also be -8..-15 for raw inflate. In this case, -windowBits
+ determines the window size. inflate() will then process raw deflate data,
+ not looking for a zlib or gzip header, not generating a check value, and not
+ looking for any check values for comparison at the end of the stream. This
+ is for use with other formats that use the deflate compressed data format
+ such as zip. Those formats provide their own check values. If a custom
+ format is developed using the raw deflate format for compressed data, it is
+ recommended that a check value such as an adler32 or a crc32 be applied to
+ the uncompressed data as is done in the zlib, gzip, and zip formats. For
+ most applications, the zlib format should be used as is. Note that comments
+ above on the use in deflateInit2() applies to the magnitude of windowBits.
+
+ windowBits can also be greater than 15 for optional gzip decoding. Add
+ 32 to windowBits to enable zlib and gzip decoding with automatic header
+ detection, or add 16 to decode only the gzip format (the zlib format will
+ return a Z_DATA_ERROR). If a gzip stream is being decoded, strm->adler is a
+ crc32 instead of an adler32.
+
+ inflateInit2 returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_VERSION_ERROR if the zlib library version is incompatible with the
+ version assumed by the caller, or Z_STREAM_ERROR if the parameters are
+ invalid, such as a null pointer to the structure. msg is set to null if
+ there is no error message. inflateInit2 does not perform any decompression
+ apart from possibly reading the zlib header if present: actual decompression
+ will be done by inflate(). (So next_in and avail_in may be modified, but
+ next_out and avail_out are unused and unchanged.) The current implementation
+ of inflateInit2() does not process any header information -- that is
+ deferred until inflate() is called.
+*/
+
+ZEXTERN int ZEXPORT inflateSetDictionary OF((z_streamp strm,
+ const Bytef *dictionary,
+ uInt dictLength));
+/*
+ Initializes the decompression dictionary from the given uncompressed byte
+ sequence. This function must be called immediately after a call of inflate,
+ if that call returned Z_NEED_DICT. The dictionary chosen by the compressor
+ can be determined from the adler32 value returned by that call of inflate.
+ The compressor and decompressor must use exactly the same dictionary (see
+ deflateSetDictionary). For raw inflate, this function can be called at any
+ time to set the dictionary. If the provided dictionary is smaller than the
+ window and there is already data in the window, then the provided dictionary
+ will amend what's there. The application must insure that the dictionary
+ that was used for compression is provided.
+
+ inflateSetDictionary returns Z_OK if success, Z_STREAM_ERROR if a
+ parameter is invalid (e.g. dictionary being Z_NULL) or the stream state is
+ inconsistent, Z_DATA_ERROR if the given dictionary doesn't match the
+ expected one (incorrect adler32 value). inflateSetDictionary does not
+ perform any decompression: this will be done by subsequent calls of
+ inflate().
+*/
+
+ZEXTERN int ZEXPORT inflateGetDictionary OF((z_streamp strm,
+ Bytef *dictionary,
+ uInt *dictLength));
+/*
+ Returns the sliding dictionary being maintained by inflate. dictLength is
+ set to the number of bytes in the dictionary, and that many bytes are copied
+ to dictionary. dictionary must have enough space, where 32768 bytes is
+ always enough. If inflateGetDictionary() is called with dictionary equal to
+ Z_NULL, then only the dictionary length is returned, and nothing is copied.
+ Similary, if dictLength is Z_NULL, then it is not set.
+
+ inflateGetDictionary returns Z_OK on success, or Z_STREAM_ERROR if the
+ stream state is inconsistent.
+*/
+
+ZEXTERN int ZEXPORT inflateSync OF((z_streamp strm));
+/*
+ Skips invalid compressed data until a possible full flush point (see above
+ for the description of deflate with Z_FULL_FLUSH) can be found, or until all
+ available input is skipped. No output is provided.
+
+ inflateSync searches for a 00 00 FF FF pattern in the compressed data.
+ All full flush points have this pattern, but not all occurrences of this
+ pattern are full flush points.
+
+ inflateSync returns Z_OK if a possible full flush point has been found,
+ Z_BUF_ERROR if no more input was provided, Z_DATA_ERROR if no flush point
+ has been found, or Z_STREAM_ERROR if the stream structure was inconsistent.
+ In the success case, the application may save the current current value of
+ total_in which indicates where valid compressed data was found. In the
+ error case, the application may repeatedly call inflateSync, providing more
+ input each time, until success or end of the input data.
+*/
+
+ZEXTERN int ZEXPORT inflateCopy OF((z_streamp dest,
+ z_streamp source));
+/*
+ Sets the destination stream as a complete copy of the source stream.
+
+ This function can be useful when randomly accessing a large stream. The
+ first pass through the stream can periodically record the inflate state,
+ allowing restarting inflate at those points when randomly accessing the
+ stream.
+
+ inflateCopy returns Z_OK if success, Z_MEM_ERROR if there was not
+ enough memory, Z_STREAM_ERROR if the source stream state was inconsistent
+ (such as zalloc being Z_NULL). msg is left unchanged in both source and
+ destination.
+*/
+
+ZEXTERN int ZEXPORT inflateReset OF((z_streamp strm));
+/*
+ This function is equivalent to inflateEnd followed by inflateInit,
+ but does not free and reallocate all the internal decompression state. The
+ stream will keep attributes that may have been set by inflateInit2.
+
+ inflateReset returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent (such as zalloc or state being Z_NULL).
+*/
+
+ZEXTERN int ZEXPORT inflateReset2 OF((z_streamp strm,
+ int windowBits));
+/*
+ This function is the same as inflateReset, but it also permits changing
+ the wrap and window size requests. The windowBits parameter is interpreted
+ the same as it is for inflateInit2.
+
+ inflateReset2 returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent (such as zalloc or state being Z_NULL), or if
+ the windowBits parameter is invalid.
+*/
+
+ZEXTERN int ZEXPORT inflatePrime OF((z_streamp strm,
+ int bits,
+ int value));
+/*
+ This function inserts bits in the inflate input stream. The intent is
+ that this function is used to start inflating at a bit position in the
+ middle of a byte. The provided bits will be used before any bytes are used
+ from next_in. This function should only be used with raw inflate, and
+ should be used before the first inflate() call after inflateInit2() or
+ inflateReset(). bits must be less than or equal to 16, and that many of the
+ least significant bits of value will be inserted in the input.
+
+ If bits is negative, then the input stream bit buffer is emptied. Then
+ inflatePrime() can be called again to put bits in the buffer. This is used
+ to clear out bits leftover after feeding inflate a block description prior
+ to feeding inflate codes.
+
+ inflatePrime returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent.
+*/
+
+ZEXTERN long ZEXPORT inflateMark OF((z_streamp strm));
+/*
+ This function returns two values, one in the lower 16 bits of the return
+ value, and the other in the remaining upper bits, obtained by shifting the
+ return value down 16 bits. If the upper value is -1 and the lower value is
+ zero, then inflate() is currently decoding information outside of a block.
+ If the upper value is -1 and the lower value is non-zero, then inflate is in
+ the middle of a stored block, with the lower value equaling the number of
+ bytes from the input remaining to copy. If the upper value is not -1, then
+ it is the number of bits back from the current bit position in the input of
+ the code (literal or length/distance pair) currently being processed. In
+ that case the lower value is the number of bytes already emitted for that
+ code.
+
+ A code is being processed if inflate is waiting for more input to complete
+ decoding of the code, or if it has completed decoding but is waiting for
+ more output space to write the literal or match data.
+
+ inflateMark() is used to mark locations in the input data for random
+ access, which may be at bit positions, and to note those cases where the
+ output of a code may span boundaries of random access blocks. The current
+ location in the input stream can be determined from avail_in and data_type
+ as noted in the description for the Z_BLOCK flush parameter for inflate.
+
+ inflateMark returns the value noted above or -1 << 16 if the provided
+ source stream state was inconsistent.
+*/
+
+ZEXTERN int ZEXPORT inflateGetHeader OF((z_streamp strm,
+ gz_headerp head));
+/*
+ inflateGetHeader() requests that gzip header information be stored in the
+ provided gz_header structure. inflateGetHeader() may be called after
+ inflateInit2() or inflateReset(), and before the first call of inflate().
+ As inflate() processes the gzip stream, head->done is zero until the header
+ is completed, at which time head->done is set to one. If a zlib stream is
+ being decoded, then head->done is set to -1 to indicate that there will be
+ no gzip header information forthcoming. Note that Z_BLOCK or Z_TREES can be
+ used to force inflate() to return immediately after header processing is
+ complete and before any actual data is decompressed.
+
+ The text, time, xflags, and os fields are filled in with the gzip header
+ contents. hcrc is set to true if there is a header CRC. (The header CRC
+ was valid if done is set to one.) If extra is not Z_NULL, then extra_max
+ contains the maximum number of bytes to write to extra. Once done is true,
+ extra_len contains the actual extra field length, and extra contains the
+ extra field, or that field truncated if extra_max is less than extra_len.
+ If name is not Z_NULL, then up to name_max characters are written there,
+ terminated with a zero unless the length is greater than name_max. If
+ comment is not Z_NULL, then up to comm_max characters are written there,
+ terminated with a zero unless the length is greater than comm_max. When any
+ of extra, name, or comment are not Z_NULL and the respective field is not
+ present in the header, then that field is set to Z_NULL to signal its
+ absence. This allows the use of deflateSetHeader() with the returned
+ structure to duplicate the header. However if those fields are set to
+ allocated memory, then the application will need to save those pointers
+ elsewhere so that they can be eventually freed.
+
+ If inflateGetHeader is not used, then the header information is simply
+ discarded. The header is always checked for validity, including the header
+ CRC if present. inflateReset() will reset the process to discard the header
+ information. The application would need to call inflateGetHeader() again to
+ retrieve the header from the next gzip stream.
+
+ inflateGetHeader returns Z_OK if success, or Z_STREAM_ERROR if the source
+ stream state was inconsistent.
+*/
+
+/*
+ZEXTERN int ZEXPORT inflateBackInit OF((z_streamp strm, int windowBits,
+ unsigned char FAR *window));
+
+ Initialize the internal stream state for decompression using inflateBack()
+ calls. The fields zalloc, zfree and opaque in strm must be initialized
+ before the call. If zalloc and zfree are Z_NULL, then the default library-
+ derived memory allocation routines are used. windowBits is the base two
+ logarithm of the window size, in the range 8..15. window is a caller
+ supplied buffer of that size. Except for special applications where it is
+ assured that deflate was used with small window sizes, windowBits must be 15
+ and a 32K byte window must be supplied to be able to decompress general
+ deflate streams.
+
+ See inflateBack() for the usage of these routines.
+
+ inflateBackInit will return Z_OK on success, Z_STREAM_ERROR if any of
+ the parameters are invalid, Z_MEM_ERROR if the internal state could not be
+ allocated, or Z_VERSION_ERROR if the version of the library does not match
+ the version of the header file.
+*/
+
+typedef unsigned (*in_func) OF((void FAR *,
+ z_const unsigned char FAR * FAR *));
+typedef int (*out_func) OF((void FAR *, unsigned char FAR *, unsigned));
+
+ZEXTERN int ZEXPORT inflateBack OF((z_streamp strm,
+ in_func in, void FAR *in_desc,
+ out_func out, void FAR *out_desc));
+/*
+ inflateBack() does a raw inflate with a single call using a call-back
+ interface for input and output. This is potentially more efficient than
+ inflate() for file i/o applications, in that it avoids copying between the
+ output and the sliding window by simply making the window itself the output
+ buffer. inflate() can be faster on modern CPUs when used with large
+ buffers. inflateBack() trusts the application to not change the output
+ buffer passed by the output function, at least until inflateBack() returns.
+
+ inflateBackInit() must be called first to allocate the internal state
+ and to initialize the state with the user-provided window buffer.
+ inflateBack() may then be used multiple times to inflate a complete, raw
+ deflate stream with each call. inflateBackEnd() is then called to free the
+ allocated state.
+
+ A raw deflate stream is one with no zlib or gzip header or trailer.
+ This routine would normally be used in a utility that reads zip or gzip
+ files and writes out uncompressed files. The utility would decode the
+ header and process the trailer on its own, hence this routine expects only
+ the raw deflate stream to decompress. This is different from the normal
+ behavior of inflate(), which expects either a zlib or gzip header and
+ trailer around the deflate stream.
+
+ inflateBack() uses two subroutines supplied by the caller that are then
+ called by inflateBack() for input and output. inflateBack() calls those
+ routines until it reads a complete deflate stream and writes out all of the
+ uncompressed data, or until it encounters an error. The function's
+ parameters and return types are defined above in the in_func and out_func
+ typedefs. inflateBack() will call in(in_desc, &buf) which should return the
+ number of bytes of provided input, and a pointer to that input in buf. If
+ there is no input available, in() must return zero--buf is ignored in that
+ case--and inflateBack() will return a buffer error. inflateBack() will call
+ out(out_desc, buf, len) to write the uncompressed data buf[0..len-1]. out()
+ should return zero on success, or non-zero on failure. If out() returns
+ non-zero, inflateBack() will return with an error. Neither in() nor out()
+ are permitted to change the contents of the window provided to
+ inflateBackInit(), which is also the buffer that out() uses to write from.
+ The length written by out() will be at most the window size. Any non-zero
+ amount of input may be provided by in().
+
+ For convenience, inflateBack() can be provided input on the first call by
+ setting strm->next_in and strm->avail_in. If that input is exhausted, then
+ in() will be called. Therefore strm->next_in must be initialized before
+ calling inflateBack(). If strm->next_in is Z_NULL, then in() will be called
+ immediately for input. If strm->next_in is not Z_NULL, then strm->avail_in
+ must also be initialized, and then if strm->avail_in is not zero, input will
+ initially be taken from strm->next_in[0 .. strm->avail_in - 1].
+
+ The in_desc and out_desc parameters of inflateBack() is passed as the
+ first parameter of in() and out() respectively when they are called. These
+ descriptors can be optionally used to pass any information that the caller-
+ supplied in() and out() functions need to do their job.
+
+ On return, inflateBack() will set strm->next_in and strm->avail_in to
+ pass back any unused input that was provided by the last in() call. The
+ return values of inflateBack() can be Z_STREAM_END on success, Z_BUF_ERROR
+ if in() or out() returned an error, Z_DATA_ERROR if there was a format error
+ in the deflate stream (in which case strm->msg is set to indicate the nature
+ of the error), or Z_STREAM_ERROR if the stream was not properly initialized.
+ In the case of Z_BUF_ERROR, an input or output error can be distinguished
+ using strm->next_in which will be Z_NULL only if in() returned an error. If
+ strm->next_in is not Z_NULL, then the Z_BUF_ERROR was due to out() returning
+ non-zero. (in() will always be called before out(), so strm->next_in is
+ assured to be defined if out() returns non-zero.) Note that inflateBack()
+ cannot return Z_OK.
+*/
+
+ZEXTERN int ZEXPORT inflateBackEnd OF((z_streamp strm));
+/*
+ All memory allocated by inflateBackInit() is freed.
+
+ inflateBackEnd() returns Z_OK on success, or Z_STREAM_ERROR if the stream
+ state was inconsistent.
+*/
+
+ZEXTERN uLong ZEXPORT zlibCompileFlags OF((void));
+/* Return flags indicating compile-time options.
+
+ Type sizes, two bits each, 00 = 16 bits, 01 = 32, 10 = 64, 11 = other:
+ 1.0: size of uInt
+ 3.2: size of uLong
+ 5.4: size of voidpf (pointer)
+ 7.6: size of z_off_t
+
+ Compiler, assembler, and debug options:
+ 8: DEBUG
+ 9: ASMV or ASMINF -- use ASM code
+ 10: ZLIB_WINAPI -- exported functions use the WINAPI calling convention
+ 11: 0 (reserved)
+
+ One-time table building (smaller code, but not thread-safe if true):
+ 12: BUILDFIXED -- build static block decoding tables when needed
+ 13: DYNAMIC_CRC_TABLE -- build CRC calculation tables when needed
+ 14,15: 0 (reserved)
+
+ Library content (indicates missing functionality):
+ 16: NO_GZCOMPRESS -- gz* functions cannot compress (to avoid linking
+ deflate code when not needed)
+ 17: NO_GZIP -- deflate can't write gzip streams, and inflate can't detect
+ and decode gzip streams (to avoid linking crc code)
+ 18-19: 0 (reserved)
+
+ Operation variations (changes in library functionality):
+ 20: PKZIP_BUG_WORKAROUND -- slightly more permissive inflate
+ 21: FASTEST -- deflate algorithm with only one, lowest compression level
+ 22,23: 0 (reserved)
+
+ The sprintf variant used by gzprintf (zero is best):
+ 24: 0 = vs*, 1 = s* -- 1 means limited to 20 arguments after the format
+ 25: 0 = *nprintf, 1 = *printf -- 1 means gzprintf() not secure!
+ 26: 0 = returns value, 1 = void -- 1 means inferred string length returned
+
+ Remainder:
+ 27-31: 0 (reserved)
+ */
+
+#ifndef Z_SOLO
+
+ /* utility functions */
+
+/*
+ The following utility functions are implemented on top of the basic
+ stream-oriented functions. To simplify the interface, some default options
+ are assumed (compression level and memory usage, standard memory allocation
+ functions). The source code of these utility functions can be modified if
+ you need special options.
+*/
+
+ZEXTERN int ZEXPORT compress OF((Bytef *dest, uLongf *destLen,
+ const Bytef *source, uLong sourceLen));
+/*
+ Compresses the source buffer into the destination buffer. sourceLen is
+ the byte length of the source buffer. Upon entry, destLen is the total size
+ of the destination buffer, which must be at least the value returned by
+ compressBound(sourceLen). Upon exit, destLen is the actual size of the
+ compressed buffer.
+
+ compress returns Z_OK if success, Z_MEM_ERROR if there was not
+ enough memory, Z_BUF_ERROR if there was not enough room in the output
+ buffer.
+*/
+
+ZEXTERN int ZEXPORT compress2 OF((Bytef *dest, uLongf *destLen,
+ const Bytef *source, uLong sourceLen,
+ int level));
+/*
+ Compresses the source buffer into the destination buffer. The level
+ parameter has the same meaning as in deflateInit. sourceLen is the byte
+ length of the source buffer. Upon entry, destLen is the total size of the
+ destination buffer, which must be at least the value returned by
+ compressBound(sourceLen). Upon exit, destLen is the actual size of the
+ compressed buffer.
+
+ compress2 returns Z_OK if success, Z_MEM_ERROR if there was not enough
+ memory, Z_BUF_ERROR if there was not enough room in the output buffer,
+ Z_STREAM_ERROR if the level parameter is invalid.
+*/
+
+ZEXTERN uLong ZEXPORT compressBound OF((uLong sourceLen));
+/*
+ compressBound() returns an upper bound on the compressed size after
+ compress() or compress2() on sourceLen bytes. It would be used before a
+ compress() or compress2() call to allocate the destination buffer.
+*/
+
+ZEXTERN int ZEXPORT uncompress OF((Bytef *dest, uLongf *destLen,
+ const Bytef *source, uLong sourceLen));
+/*
+ Decompresses the source buffer into the destination buffer. sourceLen is
+ the byte length of the source buffer. Upon entry, destLen is the total size
+ of the destination buffer, which must be large enough to hold the entire
+ uncompressed data. (The size of the uncompressed data must have been saved
+ previously by the compressor and transmitted to the decompressor by some
+ mechanism outside the scope of this compression library.) Upon exit, destLen
+ is the actual size of the uncompressed buffer.
+
+ uncompress returns Z_OK if success, Z_MEM_ERROR if there was not
+ enough memory, Z_BUF_ERROR if there was not enough room in the output
+ buffer, or Z_DATA_ERROR if the input data was corrupted or incomplete. In
+ the case where there is not enough room, uncompress() will fill the output
+ buffer with the uncompressed data up to that point.
+*/
+
+ /* gzip file access functions */
+
+/*
+ This library supports reading and writing files in gzip (.gz) format with
+ an interface similar to that of stdio, using the functions that start with
+ "gz". The gzip format is different from the zlib format. gzip is a gzip
+ wrapper, documented in RFC 1952, wrapped around a deflate stream.
+*/
+
+typedef struct gzFile_s *gzFile; /* semi-opaque gzip file descriptor */
+
+/*
+ZEXTERN gzFile ZEXPORT gzopen OF((const char *path, const char *mode));
+
+ Opens a gzip (.gz) file for reading or writing. The mode parameter is as
+ in fopen ("rb" or "wb") but can also include a compression level ("wb9") or
+ a strategy: 'f' for filtered data as in "wb6f", 'h' for Huffman-only
+ compression as in "wb1h", 'R' for run-length encoding as in "wb1R", or 'F'
+ for fixed code compression as in "wb9F". (See the description of
+ deflateInit2 for more information about the strategy parameter.) 'T' will
+ request transparent writing or appending with no compression and not using
+ the gzip format.
+
+ "a" can be used instead of "w" to request that the gzip stream that will
+ be written be appended to the file. "+" will result in an error, since
+ reading and writing to the same gzip file is not supported. The addition of
+ "x" when writing will create the file exclusively, which fails if the file
+ already exists. On systems that support it, the addition of "e" when
+ reading or writing will set the flag to close the file on an execve() call.
+
+ These functions, as well as gzip, will read and decode a sequence of gzip
+ streams in a file. The append function of gzopen() can be used to create
+ such a file. (Also see gzflush() for another way to do this.) When
+ appending, gzopen does not test whether the file begins with a gzip stream,
+ nor does it look for the end of the gzip streams to begin appending. gzopen
+ will simply append a gzip stream to the existing file.
+
+ gzopen can be used to read a file which is not in gzip format; in this
+ case gzread will directly read from the file without decompression. When
+ reading, this will be detected automatically by looking for the magic two-
+ byte gzip header.
+
+ gzopen returns NULL if the file could not be opened, if there was
+ insufficient memory to allocate the gzFile state, or if an invalid mode was
+ specified (an 'r', 'w', or 'a' was not provided, or '+' was provided).
+ errno can be checked to determine if the reason gzopen failed was that the
+ file could not be opened.
+*/
+
+ZEXTERN gzFile ZEXPORT gzdopen OF((int fd, const char *mode));
+/*
+ gzdopen associates a gzFile with the file descriptor fd. File descriptors
+ are obtained from calls like open, dup, creat, pipe or fileno (if the file
+ has been previously opened with fopen). The mode parameter is as in gzopen.
+
+ The next call of gzclose on the returned gzFile will also close the file
+ descriptor fd, just like fclose(fdopen(fd, mode)) closes the file descriptor
+ fd. If you want to keep fd open, use fd = dup(fd_keep); gz = gzdopen(fd,
+ mode);. The duplicated descriptor should be saved to avoid a leak, since
+ gzdopen does not close fd if it fails. If you are using fileno() to get the
+ file descriptor from a FILE *, then you will have to use dup() to avoid
+ double-close()ing the file descriptor. Both gzclose() and fclose() will
+ close the associated file descriptor, so they need to have different file
+ descriptors.
+
+ gzdopen returns NULL if there was insufficient memory to allocate the
+ gzFile state, if an invalid mode was specified (an 'r', 'w', or 'a' was not
+ provided, or '+' was provided), or if fd is -1. The file descriptor is not
+ used until the next gz* read, write, seek, or close operation, so gzdopen
+ will not detect if fd is invalid (unless fd is -1).
+*/
+
+ZEXTERN int ZEXPORT gzbuffer OF((gzFile file, unsigned size));
+/*
+ Set the internal buffer size used by this library's functions. The
+ default buffer size is 8192 bytes. This function must be called after
+ gzopen() or gzdopen(), and before any other calls that read or write the
+ file. The buffer memory allocation is always deferred to the first read or
+ write. Two buffers are allocated, either both of the specified size when
+ writing, or one of the specified size and the other twice that size when
+ reading. A larger buffer size of, for example, 64K or 128K bytes will
+ noticeably increase the speed of decompression (reading).
+
+ The new buffer size also affects the maximum length for gzprintf().
+
+ gzbuffer() returns 0 on success, or -1 on failure, such as being called
+ too late.
+*/
+
+ZEXTERN int ZEXPORT gzsetparams OF((gzFile file, int level, int strategy));
+/*
+ Dynamically update the compression level or strategy. See the description
+ of deflateInit2 for the meaning of these parameters.
+
+ gzsetparams returns Z_OK if success, or Z_STREAM_ERROR if the file was not
+ opened for writing.
+*/
+
+ZEXTERN int ZEXPORT gzread OF((gzFile file, voidp buf, unsigned len));
+/*
+ Reads the given number of uncompressed bytes from the compressed file. If
+ the input file is not in gzip format, gzread copies the given number of
+ bytes into the buffer directly from the file.
+
+ After reaching the end of a gzip stream in the input, gzread will continue
+ to read, looking for another gzip stream. Any number of gzip streams may be
+ concatenated in the input file, and will all be decompressed by gzread().
+ If something other than a gzip stream is encountered after a gzip stream,
+ that remaining trailing garbage is ignored (and no error is returned).
+
+ gzread can be used to read a gzip file that is being concurrently written.
+ Upon reaching the end of the input, gzread will return with the available
+ data. If the error code returned by gzerror is Z_OK or Z_BUF_ERROR, then
+ gzclearerr can be used to clear the end of file indicator in order to permit
+ gzread to be tried again. Z_OK indicates that a gzip stream was completed
+ on the last gzread. Z_BUF_ERROR indicates that the input file ended in the
+ middle of a gzip stream. Note that gzread does not return -1 in the event
+ of an incomplete gzip stream. This error is deferred until gzclose(), which
+ will return Z_BUF_ERROR if the last gzread ended in the middle of a gzip
+ stream. Alternatively, gzerror can be used before gzclose to detect this
+ case.
+
+ gzread returns the number of uncompressed bytes actually read, less than
+ len for end of file, or -1 for error.
+*/
+
+ZEXTERN int ZEXPORT gzwrite OF((gzFile file,
+ voidpc buf, unsigned len));
+/*
+ Writes the given number of uncompressed bytes into the compressed file.
+ gzwrite returns the number of uncompressed bytes written or 0 in case of
+ error.
+*/
+
+ZEXTERN int ZEXPORTVA gzprintf Z_ARG((gzFile file, const char *format, ...));
+/*
+ Converts, formats, and writes the arguments to the compressed file under
+ control of the format string, as in fprintf. gzprintf returns the number of
+ uncompressed bytes actually written, or 0 in case of error. The number of
+ uncompressed bytes written is limited to 8191, or one less than the buffer
+ size given to gzbuffer(). The caller should assure that this limit is not
+ exceeded. If it is exceeded, then gzprintf() will return an error (0) with
+ nothing written. In this case, there may also be a buffer overflow with
+ unpredictable consequences, which is possible only if zlib was compiled with
+ the insecure functions sprintf() or vsprintf() because the secure snprintf()
+ or vsnprintf() functions were not available. This can be determined using
+ zlibCompileFlags().
+*/
+
+ZEXTERN int ZEXPORT gzputs OF((gzFile file, const char *s));
+/*
+ Writes the given null-terminated string to the compressed file, excluding
+ the terminating null character.
+
+ gzputs returns the number of characters written, or -1 in case of error.
+*/
+
+ZEXTERN char * ZEXPORT gzgets OF((gzFile file, char *buf, int len));
+/*
+ Reads bytes from the compressed file until len-1 characters are read, or a
+ newline character is read and transferred to buf, or an end-of-file
+ condition is encountered. If any characters are read or if len == 1, the
+ string is terminated with a null character. If no characters are read due
+ to an end-of-file or len < 1, then the buffer is left untouched.
+
+ gzgets returns buf which is a null-terminated string, or it returns NULL
+ for end-of-file or in case of error. If there was an error, the contents at
+ buf are indeterminate.
+*/
+
+ZEXTERN int ZEXPORT gzputc OF((gzFile file, int c));
+/*
+ Writes c, converted to an unsigned char, into the compressed file. gzputc
+ returns the value that was written, or -1 in case of error.
+*/
+
+ZEXTERN int ZEXPORT gzgetc OF((gzFile file));
+/*
+ Reads one byte from the compressed file. gzgetc returns this byte or -1
+ in case of end of file or error. This is implemented as a macro for speed.
+ As such, it does not do all of the checking the other functions do. I.e.
+ it does not check to see if file is NULL, nor whether the structure file
+ points to has been clobbered or not.
+*/
+
+ZEXTERN int ZEXPORT gzungetc OF((int c, gzFile file));
+/*
+ Push one character back onto the stream to be read as the first character
+ on the next read. At least one character of push-back is allowed.
+ gzungetc() returns the character pushed, or -1 on failure. gzungetc() will
+ fail if c is -1, and may fail if a character has been pushed but not read
+ yet. If gzungetc is used immediately after gzopen or gzdopen, at least the
+ output buffer size of pushed characters is allowed. (See gzbuffer above.)
+ The pushed character will be discarded if the stream is repositioned with
+ gzseek() or gzrewind().
+*/
+
+ZEXTERN int ZEXPORT gzflush OF((gzFile file, int flush));
+/*
+ Flushes all pending output into the compressed file. The parameter flush
+ is as in the deflate() function. The return value is the zlib error number
+ (see function gzerror below). gzflush is only permitted when writing.
+
+ If the flush parameter is Z_FINISH, the remaining data is written and the
+ gzip stream is completed in the output. If gzwrite() is called again, a new
+ gzip stream will be started in the output. gzread() is able to read such
+ concatented gzip streams.
+
+ gzflush should be called only when strictly necessary because it will
+ degrade compression if called too often.
+*/
+
+/*
+ZEXTERN z_off_t ZEXPORT gzseek OF((gzFile file,
+ z_off_t offset, int whence));
+
+ Sets the starting position for the next gzread or gzwrite on the given
+ compressed file. The offset represents a number of bytes in the
+ uncompressed data stream. The whence parameter is defined as in lseek(2);
+ the value SEEK_END is not supported.
+
+ If the file is opened for reading, this function is emulated but can be
+ extremely slow. If the file is opened for writing, only forward seeks are
+ supported; gzseek then compresses a sequence of zeroes up to the new
+ starting position.
+
+ gzseek returns the resulting offset location as measured in bytes from
+ the beginning of the uncompressed stream, or -1 in case of error, in
+ particular if the file is opened for writing and the new starting position
+ would be before the current position.
+*/
+
+ZEXTERN int ZEXPORT gzrewind OF((gzFile file));
+/*
+ Rewinds the given file. This function is supported only for reading.
+
+ gzrewind(file) is equivalent to (int)gzseek(file, 0L, SEEK_SET)
+*/
+
+/*
+ZEXTERN z_off_t ZEXPORT gztell OF((gzFile file));
+
+ Returns the starting position for the next gzread or gzwrite on the given
+ compressed file. This position represents a number of bytes in the
+ uncompressed data stream, and is zero when starting, even if appending or
+ reading a gzip stream from the middle of a file using gzdopen().
+
+ gztell(file) is equivalent to gzseek(file, 0L, SEEK_CUR)
+*/
+
+/*
+ZEXTERN z_off_t ZEXPORT gzoffset OF((gzFile file));
+
+ Returns the current offset in the file being read or written. This offset
+ includes the count of bytes that precede the gzip stream, for example when
+ appending or when using gzdopen() for reading. When reading, the offset
+ does not include as yet unused buffered input. This information can be used
+ for a progress indicator. On error, gzoffset() returns -1.
+*/
+
+ZEXTERN int ZEXPORT gzeof OF((gzFile file));
+/*
+ Returns true (1) if the end-of-file indicator has been set while reading,
+ false (0) otherwise. Note that the end-of-file indicator is set only if the
+ read tried to go past the end of the input, but came up short. Therefore,
+ just like feof(), gzeof() may return false even if there is no more data to
+ read, in the event that the last read request was for the exact number of
+ bytes remaining in the input file. This will happen if the input file size
+ is an exact multiple of the buffer size.
+
+ If gzeof() returns true, then the read functions will return no more data,
+ unless the end-of-file indicator is reset by gzclearerr() and the input file
+ has grown since the previous end of file was detected.
+*/
+
+ZEXTERN int ZEXPORT gzdirect OF((gzFile file));
+/*
+ Returns true (1) if file is being copied directly while reading, or false
+ (0) if file is a gzip stream being decompressed.
+
+ If the input file is empty, gzdirect() will return true, since the input
+ does not contain a gzip stream.
+
+ If gzdirect() is used immediately after gzopen() or gzdopen() it will
+ cause buffers to be allocated to allow reading the file to determine if it
+ is a gzip file. Therefore if gzbuffer() is used, it should be called before
+ gzdirect().
+
+ When writing, gzdirect() returns true (1) if transparent writing was
+ requested ("wT" for the gzopen() mode), or false (0) otherwise. (Note:
+ gzdirect() is not needed when writing. Transparent writing must be
+ explicitly requested, so the application already knows the answer. When
+ linking statically, using gzdirect() will include all of the zlib code for
+ gzip file reading and decompression, which may not be desired.)
+*/
+
+ZEXTERN int ZEXPORT gzclose OF((gzFile file));
+/*
+ Flushes all pending output if necessary, closes the compressed file and
+ deallocates the (de)compression state. Note that once file is closed, you
+ cannot call gzerror with file, since its structures have been deallocated.
+ gzclose must not be called more than once on the same file, just as free
+ must not be called more than once on the same allocation.
+
+ gzclose will return Z_STREAM_ERROR if file is not valid, Z_ERRNO on a
+ file operation error, Z_MEM_ERROR if out of memory, Z_BUF_ERROR if the
+ last read ended in the middle of a gzip stream, or Z_OK on success.
+*/
+
+ZEXTERN int ZEXPORT gzclose_r OF((gzFile file));
+ZEXTERN int ZEXPORT gzclose_w OF((gzFile file));
+/*
+ Same as gzclose(), but gzclose_r() is only for use when reading, and
+ gzclose_w() is only for use when writing or appending. The advantage to
+ using these instead of gzclose() is that they avoid linking in zlib
+ compression or decompression code that is not used when only reading or only
+ writing respectively. If gzclose() is used, then both compression and
+ decompression code will be included the application when linking to a static
+ zlib library.
+*/
+
+ZEXTERN const char * ZEXPORT gzerror OF((gzFile file, int *errnum));
+/*
+ Returns the error message for the last error which occurred on the given
+ compressed file. errnum is set to zlib error number. If an error occurred
+ in the file system and not in the compression library, errnum is set to
+ Z_ERRNO and the application may consult errno to get the exact error code.
+
+ The application must not modify the returned string. Future calls to
+ this function may invalidate the previously returned string. If file is
+ closed, then the string previously returned by gzerror will no longer be
+ available.
+
+ gzerror() should be used to distinguish errors from end-of-file for those
+ functions above that do not distinguish those cases in their return values.
+*/
+
+ZEXTERN void ZEXPORT gzclearerr OF((gzFile file));
+/*
+ Clears the error and end-of-file flags for file. This is analogous to the
+ clearerr() function in stdio. This is useful for continuing to read a gzip
+ file that is being written concurrently.
+*/
+
+#endif /* !Z_SOLO */
+
+ /* checksum functions */
+
+/*
+ These functions are not related to compression but are exported
+ anyway because they might be useful in applications using the compression
+ library.
+*/
+
+ZEXTERN uLong ZEXPORT adler32 OF((uLong adler, const Bytef *buf, uInt len));
+/*
+ Update a running Adler-32 checksum with the bytes buf[0..len-1] and
+ return the updated checksum. If buf is Z_NULL, this function returns the
+ required initial value for the checksum.
+
+ An Adler-32 checksum is almost as reliable as a CRC32 but can be computed
+ much faster.
+
+ Usage example:
+
+ uLong adler = adler32(0L, Z_NULL, 0);
+
+ while (read_buffer(buffer, length) != EOF) {
+ adler = adler32(adler, buffer, length);
+ }
+ if (adler != original_adler) error();
+*/
+
+/*
+ZEXTERN uLong ZEXPORT adler32_combine OF((uLong adler1, uLong adler2,
+ z_off_t len2));
+
+ Combine two Adler-32 checksums into one. For two sequences of bytes, seq1
+ and seq2 with lengths len1 and len2, Adler-32 checksums were calculated for
+ each, adler1 and adler2. adler32_combine() returns the Adler-32 checksum of
+ seq1 and seq2 concatenated, requiring only adler1, adler2, and len2. Note
+ that the z_off_t type (like off_t) is a signed integer. If len2 is
+ negative, the result has no meaning or utility.
+*/
+
+ZEXTERN uLong ZEXPORT crc32 OF((uLong crc, const Bytef *buf, uInt len));
+/*
+ Update a running CRC-32 with the bytes buf[0..len-1] and return the
+ updated CRC-32. If buf is Z_NULL, this function returns the required
+ initial value for the crc. Pre- and post-conditioning (one's complement) is
+ performed within this function so it shouldn't be done by the application.
+
+ Usage example:
+
+ uLong crc = crc32(0L, Z_NULL, 0);
+
+ while (read_buffer(buffer, length) != EOF) {
+ crc = crc32(crc, buffer, length);
+ }
+ if (crc != original_crc) error();
+*/
+
+/*
+ZEXTERN uLong ZEXPORT crc32_combine OF((uLong crc1, uLong crc2, z_off_t len2));
+
+ Combine two CRC-32 check values into one. For two sequences of bytes,
+ seq1 and seq2 with lengths len1 and len2, CRC-32 check values were
+ calculated for each, crc1 and crc2. crc32_combine() returns the CRC-32
+ check value of seq1 and seq2 concatenated, requiring only crc1, crc2, and
+ len2.
+*/
+
+
+ /* various hacks, don't look :) */
+
+/* deflateInit and inflateInit are macros to allow checking the zlib version
+ * and the compiler's view of z_stream:
+ */
+ZEXTERN int ZEXPORT deflateInit_ OF((z_streamp strm, int level,
+ const char *version, int stream_size));
+ZEXTERN int ZEXPORT inflateInit_ OF((z_streamp strm,
+ const char *version, int stream_size));
+ZEXTERN int ZEXPORT deflateInit2_ OF((z_streamp strm, int level, int method,
+ int windowBits, int memLevel,
+ int strategy, const char *version,
+ int stream_size));
+ZEXTERN int ZEXPORT inflateInit2_ OF((z_streamp strm, int windowBits,
+ const char *version, int stream_size));
+ZEXTERN int ZEXPORT inflateBackInit_ OF((z_streamp strm, int windowBits,
+ unsigned char FAR *window,
+ const char *version,
+ int stream_size));
+#define deflateInit(strm, level) \
+ deflateInit_((strm), (level), ZLIB_VERSION, (int)sizeof(z_stream))
+#define inflateInit(strm) \
+ inflateInit_((strm), ZLIB_VERSION, (int)sizeof(z_stream))
+#define deflateInit2(strm, level, method, windowBits, memLevel, strategy) \
+ deflateInit2_((strm),(level),(method),(windowBits),(memLevel),\
+ (strategy), ZLIB_VERSION, (int)sizeof(z_stream))
+#define inflateInit2(strm, windowBits) \
+ inflateInit2_((strm), (windowBits), ZLIB_VERSION, \
+ (int)sizeof(z_stream))
+#define inflateBackInit(strm, windowBits, window) \
+ inflateBackInit_((strm), (windowBits), (window), \
+ ZLIB_VERSION, (int)sizeof(z_stream))
+
+#ifndef Z_SOLO
+
+/* gzgetc() macro and its supporting function and exposed data structure. Note
+ * that the real internal state is much larger than the exposed structure.
+ * This abbreviated structure exposes just enough for the gzgetc() macro. The
+ * user should not mess with these exposed elements, since their names or
+ * behavior could change in the future, perhaps even capriciously. They can
+ * only be used by the gzgetc() macro. You have been warned.
+ */
+struct gzFile_s {
+ unsigned have;
+ unsigned char *next;
+ z_off64_t pos;
+};
+ZEXTERN int ZEXPORT gzgetc_ OF((gzFile file)); /* backward compatibility */
+#ifdef Z_PREFIX_SET
+# undef z_gzgetc
+# define z_gzgetc(g) \
+ ((g)->have ? ((g)->have--, (g)->pos++, *((g)->next)++) : gzgetc(g))
+#else
+# define gzgetc(g) \
+ ((g)->have ? ((g)->have--, (g)->pos++, *((g)->next)++) : gzgetc(g))
+#endif
+
+/* provide 64-bit offset functions if _LARGEFILE64_SOURCE defined, and/or
+ * change the regular functions to 64 bits if _FILE_OFFSET_BITS is 64 (if
+ * both are true, the application gets the *64 functions, and the regular
+ * functions are changed to 64 bits) -- in case these are set on systems
+ * without large file support, _LFS64_LARGEFILE must also be true
+ */
+#ifdef Z_LARGE64
+ ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *));
+ ZEXTERN z_off64_t ZEXPORT gzseek64 OF((gzFile, z_off64_t, int));
+ ZEXTERN z_off64_t ZEXPORT gztell64 OF((gzFile));
+ ZEXTERN z_off64_t ZEXPORT gzoffset64 OF((gzFile));
+ ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off64_t));
+ ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off64_t));
+#endif
+
+#if !defined(ZLIB_INTERNAL) && defined(Z_WANT64)
+# ifdef Z_PREFIX_SET
+# define z_gzopen z_gzopen64
+# define z_gzseek z_gzseek64
+# define z_gztell z_gztell64
+# define z_gzoffset z_gzoffset64
+# define z_adler32_combine z_adler32_combine64
+# define z_crc32_combine z_crc32_combine64
+# else
+# define gzopen gzopen64
+# define gzseek gzseek64
+# define gztell gztell64
+# define gzoffset gzoffset64
+# define adler32_combine adler32_combine64
+# define crc32_combine crc32_combine64
+# endif
+# ifndef Z_LARGE64
+ ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *));
+ ZEXTERN z_off_t ZEXPORT gzseek64 OF((gzFile, z_off_t, int));
+ ZEXTERN z_off_t ZEXPORT gztell64 OF((gzFile));
+ ZEXTERN z_off_t ZEXPORT gzoffset64 OF((gzFile));
+ ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off_t));
+ ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off_t));
+# endif
+#else
+ ZEXTERN gzFile ZEXPORT gzopen OF((const char *, const char *));
+ ZEXTERN z_off_t ZEXPORT gzseek OF((gzFile, z_off_t, int));
+ ZEXTERN z_off_t ZEXPORT gztell OF((gzFile));
+ ZEXTERN z_off_t ZEXPORT gzoffset OF((gzFile));
+ ZEXTERN uLong ZEXPORT adler32_combine OF((uLong, uLong, z_off_t));
+ ZEXTERN uLong ZEXPORT crc32_combine OF((uLong, uLong, z_off_t));
+#endif
+
+#else /* Z_SOLO */
+
+ ZEXTERN uLong ZEXPORT adler32_combine OF((uLong, uLong, z_off_t));
+ ZEXTERN uLong ZEXPORT crc32_combine OF((uLong, uLong, z_off_t));
+
+#endif /* !Z_SOLO */
+
+/* hack for buggy compilers */
+#if !defined(ZUTIL_H) && !defined(NO_DUMMY_DECL)
+ struct internal_state {int dummy;};
+#endif
+
+/* undocumented functions */
+ZEXTERN const char * ZEXPORT zError OF((int));
+ZEXTERN int ZEXPORT inflateSyncPoint OF((z_streamp));
+ZEXTERN const z_crc_t FAR * ZEXPORT get_crc_table OF((void));
+ZEXTERN int ZEXPORT inflateUndermine OF((z_streamp, int));
+ZEXTERN int ZEXPORT inflateResetKeep OF((z_streamp));
+ZEXTERN int ZEXPORT deflateResetKeep OF((z_streamp));
+#if defined(_WIN32) && !defined(Z_SOLO)
+ZEXTERN gzFile ZEXPORT gzopen_w OF((const wchar_t *path,
+ const char *mode));
+#endif
+#if defined(STDC) || defined(Z_HAVE_STDARG_H)
+# ifndef Z_SOLO
+ZEXTERN int ZEXPORTVA gzvprintf Z_ARG((gzFile file,
+ const char *format,
+ va_list va));
+# endif
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* ZLIB_H */
diff --git a/ml/dlib/dlib/external/zlib/zutil.c b/ml/dlib/dlib/external/zlib/zutil.c
new file mode 100644
index 000000000..23d2ebef0
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/zutil.c
@@ -0,0 +1,324 @@
+/* zutil.c -- target dependent utility functions for the compression library
+ * Copyright (C) 1995-2005, 2010, 2011, 2012 Jean-loup Gailly.
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* @(#) $Id$ */
+
+#include "zutil.h"
+#ifndef Z_SOLO
+# include "gzguts.h"
+#endif
+
+#ifndef NO_DUMMY_DECL
+struct internal_state {int dummy;}; /* for buggy compilers */
+#endif
+
+z_const char * const z_errmsg[10] = {
+"need dictionary", /* Z_NEED_DICT 2 */
+"stream end", /* Z_STREAM_END 1 */
+"", /* Z_OK 0 */
+"file error", /* Z_ERRNO (-1) */
+"stream error", /* Z_STREAM_ERROR (-2) */
+"data error", /* Z_DATA_ERROR (-3) */
+"insufficient memory", /* Z_MEM_ERROR (-4) */
+"buffer error", /* Z_BUF_ERROR (-5) */
+"incompatible version",/* Z_VERSION_ERROR (-6) */
+""};
+
+
+const char * ZEXPORT zlibVersion()
+{
+ return ZLIB_VERSION;
+}
+
+uLong ZEXPORT zlibCompileFlags()
+{
+ uLong flags;
+
+ flags = 0;
+ switch ((int)(sizeof(uInt))) {
+ case 2: break;
+ case 4: flags += 1; break;
+ case 8: flags += 2; break;
+ default: flags += 3;
+ }
+ switch ((int)(sizeof(uLong))) {
+ case 2: break;
+ case 4: flags += 1 << 2; break;
+ case 8: flags += 2 << 2; break;
+ default: flags += 3 << 2;
+ }
+ switch ((int)(sizeof(voidpf))) {
+ case 2: break;
+ case 4: flags += 1 << 4; break;
+ case 8: flags += 2 << 4; break;
+ default: flags += 3 << 4;
+ }
+ switch ((int)(sizeof(z_off_t))) {
+ case 2: break;
+ case 4: flags += 1 << 6; break;
+ case 8: flags += 2 << 6; break;
+ default: flags += 3 << 6;
+ }
+#ifdef DEBUG
+ flags += 1 << 8;
+#endif
+#if defined(ASMV) || defined(ASMINF)
+ flags += 1 << 9;
+#endif
+#ifdef ZLIB_WINAPI
+ flags += 1 << 10;
+#endif
+#ifdef BUILDFIXED
+ flags += 1 << 12;
+#endif
+#ifdef DYNAMIC_CRC_TABLE
+ flags += 1 << 13;
+#endif
+#ifdef NO_GZCOMPRESS
+ flags += 1L << 16;
+#endif
+#ifdef NO_GZIP
+ flags += 1L << 17;
+#endif
+#ifdef PKZIP_BUG_WORKAROUND
+ flags += 1L << 20;
+#endif
+#ifdef FASTEST
+ flags += 1L << 21;
+#endif
+#if defined(STDC) || defined(Z_HAVE_STDARG_H)
+# ifdef NO_vsnprintf
+ flags += 1L << 25;
+# ifdef HAS_vsprintf_void
+ flags += 1L << 26;
+# endif
+# else
+# ifdef HAS_vsnprintf_void
+ flags += 1L << 26;
+# endif
+# endif
+#else
+ flags += 1L << 24;
+# ifdef NO_snprintf
+ flags += 1L << 25;
+# ifdef HAS_sprintf_void
+ flags += 1L << 26;
+# endif
+# else
+# ifdef HAS_snprintf_void
+ flags += 1L << 26;
+# endif
+# endif
+#endif
+ return flags;
+}
+
+#ifdef DEBUG
+
+# ifndef verbose
+# define verbose 0
+# endif
+int ZLIB_INTERNAL z_verbose = verbose;
+
+void ZLIB_INTERNAL z_error (m)
+ char *m;
+{
+ fprintf(stderr, "%s\n", m);
+ exit(1);
+}
+#endif
+
+/* exported to allow conversion of error code to string for compress() and
+ * uncompress()
+ */
+const char * ZEXPORT zError(err)
+ int err;
+{
+ return ERR_MSG(err);
+}
+
+#if defined(_WIN32_WCE)
+ /* The Microsoft C Run-Time Library for Windows CE doesn't have
+ * errno. We define it as a global variable to simplify porting.
+ * Its value is always 0 and should not be used.
+ */
+ int errno = 0;
+#endif
+
+#ifndef HAVE_MEMCPY
+
+void ZLIB_INTERNAL zmemcpy(dest, source, len)
+ Bytef* dest;
+ const Bytef* source;
+ uInt len;
+{
+ if (len == 0) return;
+ do {
+ *dest++ = *source++; /* ??? to be unrolled */
+ } while (--len != 0);
+}
+
+int ZLIB_INTERNAL zmemcmp(s1, s2, len)
+ const Bytef* s1;
+ const Bytef* s2;
+ uInt len;
+{
+ uInt j;
+
+ for (j = 0; j < len; j++) {
+ if (s1[j] != s2[j]) return 2*(s1[j] > s2[j])-1;
+ }
+ return 0;
+}
+
+void ZLIB_INTERNAL zmemzero(dest, len)
+ Bytef* dest;
+ uInt len;
+{
+ if (len == 0) return;
+ do {
+ *dest++ = 0; /* ??? to be unrolled */
+ } while (--len != 0);
+}
+#endif
+
+#ifndef Z_SOLO
+
+#ifdef SYS16BIT
+
+#ifdef __TURBOC__
+/* Turbo C in 16-bit mode */
+
+# define MY_ZCALLOC
+
+/* Turbo C malloc() does not allow dynamic allocation of 64K bytes
+ * and farmalloc(64K) returns a pointer with an offset of 8, so we
+ * must fix the pointer. Warning: the pointer must be put back to its
+ * original form in order to free it, use zcfree().
+ */
+
+#define MAX_PTR 10
+/* 10*64K = 640K */
+
+local int next_ptr = 0;
+
+typedef struct ptr_table_s {
+ voidpf org_ptr;
+ voidpf new_ptr;
+} ptr_table;
+
+local ptr_table table[MAX_PTR];
+/* This table is used to remember the original form of pointers
+ * to large buffers (64K). Such pointers are normalized with a zero offset.
+ * Since MSDOS is not a preemptive multitasking OS, this table is not
+ * protected from concurrent access. This hack doesn't work anyway on
+ * a protected system like OS/2. Use Microsoft C instead.
+ */
+
+voidpf ZLIB_INTERNAL zcalloc (voidpf opaque, unsigned items, unsigned size)
+{
+ voidpf buf = opaque; /* just to make some compilers happy */
+ ulg bsize = (ulg)items*size;
+
+ /* If we allocate less than 65520 bytes, we assume that farmalloc
+ * will return a usable pointer which doesn't have to be normalized.
+ */
+ if (bsize < 65520L) {
+ buf = farmalloc(bsize);
+ if (*(ush*)&buf != 0) return buf;
+ } else {
+ buf = farmalloc(bsize + 16L);
+ }
+ if (buf == NULL || next_ptr >= MAX_PTR) return NULL;
+ table[next_ptr].org_ptr = buf;
+
+ /* Normalize the pointer to seg:0 */
+ *((ush*)&buf+1) += ((ush)((uch*)buf-0) + 15) >> 4;
+ *(ush*)&buf = 0;
+ table[next_ptr++].new_ptr = buf;
+ return buf;
+}
+
+void ZLIB_INTERNAL zcfree (voidpf opaque, voidpf ptr)
+{
+ int n;
+ if (*(ush*)&ptr != 0) { /* object < 64K */
+ farfree(ptr);
+ return;
+ }
+ /* Find the original pointer */
+ for (n = 0; n < next_ptr; n++) {
+ if (ptr != table[n].new_ptr) continue;
+
+ farfree(table[n].org_ptr);
+ while (++n < next_ptr) {
+ table[n-1] = table[n];
+ }
+ next_ptr--;
+ return;
+ }
+ ptr = opaque; /* just to make some compilers happy */
+ Assert(0, "zcfree: ptr not found");
+}
+
+#endif /* __TURBOC__ */
+
+
+#ifdef M_I86
+/* Microsoft C in 16-bit mode */
+
+# define MY_ZCALLOC
+
+#if (!defined(_MSC_VER) || (_MSC_VER <= 600))
+# define _halloc halloc
+# define _hfree hfree
+#endif
+
+voidpf ZLIB_INTERNAL zcalloc (voidpf opaque, uInt items, uInt size)
+{
+ if (opaque) opaque = 0; /* to make compiler happy */
+ return _halloc((long)items, size);
+}
+
+void ZLIB_INTERNAL zcfree (voidpf opaque, voidpf ptr)
+{
+ if (opaque) opaque = 0; /* to make compiler happy */
+ _hfree(ptr);
+}
+
+#endif /* M_I86 */
+
+#endif /* SYS16BIT */
+
+
+#ifndef MY_ZCALLOC /* Any system without a special alloc function */
+
+#ifndef STDC
+extern voidp malloc OF((uInt size));
+extern voidp calloc OF((uInt items, uInt size));
+extern void free OF((voidpf ptr));
+#endif
+
+voidpf ZLIB_INTERNAL zcalloc (opaque, items, size)
+ voidpf opaque;
+ unsigned items;
+ unsigned size;
+{
+ if (opaque) items += size - size; /* make compiler happy */
+ return sizeof(uInt) > 2 ? (voidpf)malloc(items * size) :
+ (voidpf)calloc(items, size);
+}
+
+void ZLIB_INTERNAL zcfree (opaque, ptr)
+ voidpf opaque;
+ voidpf ptr;
+{
+ free(ptr);
+ if (opaque) return; /* make compiler happy */
+}
+
+#endif /* MY_ZCALLOC */
+
+#endif /* !Z_SOLO */
diff --git a/ml/dlib/dlib/external/zlib/zutil.h b/ml/dlib/dlib/external/zlib/zutil.h
new file mode 100644
index 000000000..24ab06b1c
--- /dev/null
+++ b/ml/dlib/dlib/external/zlib/zutil.h
@@ -0,0 +1,253 @@
+/* zutil.h -- internal interface and configuration of the compression library
+ * Copyright (C) 1995-2013 Jean-loup Gailly.
+ * For conditions of distribution and use, see copyright notice in zlib.h
+ */
+
+/* WARNING: this file should *not* be used by applications. It is
+ part of the implementation of the compression library and is
+ subject to change. Applications should only use zlib.h.
+ */
+
+/* @(#) $Id$ */
+
+#ifndef ZUTIL_H
+#define ZUTIL_H
+
+#ifdef HAVE_HIDDEN
+# define ZLIB_INTERNAL __attribute__((visibility ("hidden")))
+#else
+# define ZLIB_INTERNAL
+#endif
+
+#include "zlib.h"
+
+#if defined(STDC) && !defined(Z_SOLO)
+# if !(defined(_WIN32_WCE) && defined(_MSC_VER))
+# include <stddef.h>
+# endif
+# include <string.h>
+# include <stdlib.h>
+#endif
+
+#ifdef Z_SOLO
+ typedef long ptrdiff_t; /* guess -- will be caught if guess is wrong */
+#endif
+
+#ifndef local
+# define local static
+#endif
+/* compile with -Dlocal if your debugger can't find static symbols */
+
+typedef unsigned char uch;
+typedef uch FAR uchf;
+typedef unsigned short ush;
+typedef ush FAR ushf;
+typedef unsigned long ulg;
+
+extern z_const char * const z_errmsg[10]; /* indexed by 2-zlib_error */
+/* (size given to avoid silly warnings with Visual C++) */
+
+#define ERR_MSG(err) z_errmsg[Z_NEED_DICT-(err)]
+
+#define ERR_RETURN(strm,err) \
+ return (strm->msg = ERR_MSG(err), (err))
+/* To be used only when the state is known to be valid */
+
+ /* common constants */
+
+#ifndef DEF_WBITS
+# define DEF_WBITS MAX_WBITS
+#endif
+/* default windowBits for decompression. MAX_WBITS is for compression only */
+
+#if MAX_MEM_LEVEL >= 8
+# define DEF_MEM_LEVEL 8
+#else
+# define DEF_MEM_LEVEL MAX_MEM_LEVEL
+#endif
+/* default memLevel */
+
+#define STORED_BLOCK 0
+#define STATIC_TREES 1
+#define DYN_TREES 2
+/* The three kinds of block type */
+
+#define MIN_MATCH 3
+#define MAX_MATCH 258
+/* The minimum and maximum match lengths */
+
+#define PRESET_DICT 0x20 /* preset dictionary flag in zlib header */
+
+ /* target dependencies */
+
+#if defined(MSDOS) || (defined(WINDOWS) && !defined(WIN32))
+# define OS_CODE 0x00
+# ifndef Z_SOLO
+# if defined(__TURBOC__) || defined(__BORLANDC__)
+# if (__STDC__ == 1) && (defined(__LARGE__) || defined(__COMPACT__))
+ /* Allow compilation with ANSI keywords only enabled */
+ void _Cdecl farfree( void *block );
+ void *_Cdecl farmalloc( unsigned long nbytes );
+# else
+# include <alloc.h>
+# endif
+# else /* MSC or DJGPP */
+# include <malloc.h>
+# endif
+# endif
+#endif
+
+#ifdef AMIGA
+# define OS_CODE 0x01
+#endif
+
+#if defined(VAXC) || defined(VMS)
+# define OS_CODE 0x02
+# define F_OPEN(name, mode) \
+ fopen((name), (mode), "mbc=60", "ctx=stm", "rfm=fix", "mrs=512")
+#endif
+
+#if defined(ATARI) || defined(atarist)
+# define OS_CODE 0x05
+#endif
+
+#ifdef OS2
+# define OS_CODE 0x06
+# if defined(M_I86) && !defined(Z_SOLO)
+# include <malloc.h>
+# endif
+#endif
+
+#if defined(MACOS) || defined(TARGET_OS_MAC)
+# define OS_CODE 0x07
+# ifndef Z_SOLO
+# if defined(__MWERKS__) && __dest_os != __be_os && __dest_os != __win32_os
+# include <unix.h> /* for fdopen */
+# else
+# ifndef fdopen
+# define fdopen(fd,mode) NULL /* No fdopen() */
+# endif
+# endif
+# endif
+#endif
+
+#ifdef TOPS20
+# define OS_CODE 0x0a
+#endif
+
+#ifdef WIN32
+# ifndef __CYGWIN__ /* Cygwin is Unix, not Win32 */
+# define OS_CODE 0x0b
+# endif
+#endif
+
+#ifdef __50SERIES /* Prime/PRIMOS */
+# define OS_CODE 0x0f
+#endif
+
+#if defined(_BEOS_) || defined(RISCOS)
+# define fdopen(fd,mode) NULL /* No fdopen() */
+#endif
+
+#if (defined(_MSC_VER) && (_MSC_VER > 600)) && !defined __INTERIX
+# if defined(_WIN32_WCE)
+# define fdopen(fd,mode) NULL /* No fdopen() */
+# ifndef _PTRDIFF_T_DEFINED
+ typedef int ptrdiff_t;
+# define _PTRDIFF_T_DEFINED
+# endif
+# else
+# define fdopen(fd,type) _fdopen(fd,type)
+# endif
+#endif
+
+#if defined(__BORLANDC__) && !defined(MSDOS)
+ #pragma warn -8004
+ #pragma warn -8008
+ #pragma warn -8066
+#endif
+
+/* provide prototypes for these when building zlib without LFS */
+#if !defined(_WIN32) && \
+ (!defined(_LARGEFILE64_SOURCE) || _LFS64_LARGEFILE-0 == 0)
+ ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off_t));
+ ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off_t));
+#endif
+
+ /* common defaults */
+
+#ifndef OS_CODE
+# define OS_CODE 0x03 /* assume Unix */
+#endif
+
+#ifndef F_OPEN
+# define F_OPEN(name, mode) fopen((name), (mode))
+#endif
+
+ /* functions */
+
+#if defined(pyr) || defined(Z_SOLO)
+# define NO_MEMCPY
+#endif
+#if defined(SMALL_MEDIUM) && !defined(_MSC_VER) && !defined(__SC__)
+ /* Use our own functions for small and medium model with MSC <= 5.0.
+ * You may have to use the same strategy for Borland C (untested).
+ * The __SC__ check is for Symantec.
+ */
+# define NO_MEMCPY
+#endif
+#if defined(STDC) && !defined(HAVE_MEMCPY) && !defined(NO_MEMCPY)
+# define HAVE_MEMCPY
+#endif
+#ifdef HAVE_MEMCPY
+# ifdef SMALL_MEDIUM /* MSDOS small or medium model */
+# define zmemcpy _fmemcpy
+# define zmemcmp _fmemcmp
+# define zmemzero(dest, len) _fmemset(dest, 0, len)
+# else
+# define zmemcpy memcpy
+# define zmemcmp memcmp
+# define zmemzero(dest, len) memset(dest, 0, len)
+# endif
+#else
+ void ZLIB_INTERNAL zmemcpy OF((Bytef* dest, const Bytef* source, uInt len));
+ int ZLIB_INTERNAL zmemcmp OF((const Bytef* s1, const Bytef* s2, uInt len));
+ void ZLIB_INTERNAL zmemzero OF((Bytef* dest, uInt len));
+#endif
+
+/* Diagnostic functions */
+#ifdef DEBUG
+# include <stdio.h>
+ extern int ZLIB_INTERNAL z_verbose;
+ extern void ZLIB_INTERNAL z_error OF((char *m));
+# define Assert(cond,msg) {if(!(cond)) z_error(msg);}
+# define Trace(x) {if (z_verbose>=0) fprintf x ;}
+# define Tracev(x) {if (z_verbose>0) fprintf x ;}
+# define Tracevv(x) {if (z_verbose>1) fprintf x ;}
+# define Tracec(c,x) {if (z_verbose>0 && (c)) fprintf x ;}
+# define Tracecv(c,x) {if (z_verbose>1 && (c)) fprintf x ;}
+#else
+# define Assert(cond,msg)
+# define Trace(x)
+# define Tracev(x)
+# define Tracevv(x)
+# define Tracec(c,x)
+# define Tracecv(c,x)
+#endif
+
+#ifndef Z_SOLO
+ voidpf ZLIB_INTERNAL zcalloc OF((voidpf opaque, unsigned items,
+ unsigned size));
+ void ZLIB_INTERNAL zcfree OF((voidpf opaque, voidpf ptr));
+#endif
+
+#define ZALLOC(strm, items, size) \
+ (*((strm)->zalloc))((strm)->opaque, (items), (size))
+#define ZFREE(strm, addr) (*((strm)->zfree))((strm)->opaque, (voidpf)(addr))
+#define TRY_FREE(s, p) {if (p) ZFREE(s, p);}
+
+/* Reverse the bytes in a 32-bit value */
+#define ZSWAP32(q) ((((q) >> 24) & 0xff) + (((q) >> 8) & 0xff00) + \
+ (((q) & 0xff00) << 8) + (((q) & 0xff) << 24))
+
+#endif /* ZUTIL_H */
diff --git a/ml/dlib/dlib/filtering.h b/ml/dlib/dlib/filtering.h
new file mode 100644
index 000000000..764d54e81
--- /dev/null
+++ b/ml/dlib/dlib/filtering.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FILTERiNG_HEADER
+#define DLIB_FILTERiNG_HEADER
+
+#include "filtering/kalman_filter.h"
+#include "filtering/rls_filter.h"
+
+#endif // DLIB_FILTERiNG_HEADER
+
+
+
diff --git a/ml/dlib/dlib/filtering/kalman_filter.cpp b/ml/dlib/dlib/filtering/kalman_filter.cpp
new file mode 100644
index 000000000..5d47793a9
--- /dev/null
+++ b/ml/dlib/dlib/filtering/kalman_filter.cpp
@@ -0,0 +1,104 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KALMAN_FiLTER_CPp_
+#define DLIB_KALMAN_FiLTER_CPp_
+
+#include "kalman_filter.h"
+#include "../global_optimization.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<std::vector<double>>& sequences,
+ const double smoothness
+ )
+ {
+ DLIB_CASSERT(sequences.size() != 0);
+ for (auto& vals : sequences)
+ DLIB_CASSERT(vals.size() > 4);
+ DLIB_CASSERT(smoothness >= 0);
+
+ // define the objective function we optimize to find the best filter
+ auto obj = [&](double measurement_noise, double typical_acceleration, double max_measurement_deviation)
+ {
+ running_stats<double> rs;
+ for (auto& vals : sequences)
+ {
+ momentum_filter filt(measurement_noise, typical_acceleration, max_measurement_deviation);
+ double prev_filt = 0;
+ for (size_t i = 0; i < vals.size(); ++i)
+ {
+ // we care about smoothness and fitting the data.
+ if (i > 0)
+ {
+ // the filter should fit the data
+ rs.add(std::abs(vals[i]-filt.get_predicted_next_position()));
+ }
+ double next_filt = filt(vals[i]);
+ if (i > 0)
+ {
+ // the filter should also output a smooth trajectory
+ rs.add(smoothness*std::abs(next_filt-prev_filt));
+ }
+ prev_filt = next_filt;
+ }
+ }
+ return rs.mean();
+ };
+
+ running_stats<double> avgdiff;
+ for (auto& vals : sequences)
+ {
+ for (size_t i = 1; i < vals.size(); ++i)
+ avgdiff.add(vals[i]-vals[i-1]);
+ }
+ const double scale = avgdiff.stddev();
+
+ function_evaluation opt = find_min_global(obj, {scale*0.01, scale*0.0001, 0.00001}, {scale*10, scale*10, 10}, max_function_calls(400));
+
+ momentum_filter filt(opt.x(0), opt.x(1), opt.x(2));
+
+ return filt;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<double>& sequence,
+ const double smoothness
+ )
+ {
+ return find_optimal_momentum_filter({1,sequence}, smoothness);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rect_filter find_optimal_rect_filter (
+ const std::vector<rectangle>& rects,
+ const double smoothness
+ )
+ {
+ DLIB_CASSERT(rects.size() > 4);
+ DLIB_CASSERT(smoothness >= 0);
+
+ std::vector<std::vector<double>> vals(4);
+ for (auto& r : rects)
+ {
+ vals[0].push_back(r.left());
+ vals[1].push_back(r.top());
+ vals[2].push_back(r.right());
+ vals[3].push_back(r.bottom());
+ }
+ return rect_filter(find_optimal_momentum_filter(vals, smoothness));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KALMAN_FiLTER_CPp_
+
diff --git a/ml/dlib/dlib/filtering/kalman_filter.h b/ml/dlib/dlib/filtering/kalman_filter.h
new file mode 100644
index 000000000..30289fa42
--- /dev/null
+++ b/ml/dlib/dlib/filtering/kalman_filter.h
@@ -0,0 +1,382 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KALMAN_FiLTER_Hh_
+#define DLIB_KALMAN_FiLTER_Hh_
+
+#include "kalman_filter_abstract.h"
+#include "../matrix.h"
+#include "../geometry.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long states,
+ long measurements
+ >
+ class kalman_filter
+ {
+ public:
+
+ kalman_filter()
+ {
+ H = 0;
+ A = 0;
+ Q = 0;
+ R = 0;
+ x = 0;
+ xb = 0;
+ P = identity_matrix<double>(states);
+ got_first_meas = false;
+ }
+
+ void set_observation_model ( const matrix<double,measurements,states>& H_) { H = H_; }
+ void set_transition_model ( const matrix<double,states,states>& A_) { A = A_; }
+ void set_process_noise ( const matrix<double,states,states>& Q_) { Q = Q_; }
+ void set_measurement_noise ( const matrix<double,measurements,measurements>& R_) { R = R_; }
+ void set_estimation_error_covariance( const matrix<double,states,states>& P_) { P = P_; }
+ void set_state ( const matrix<double,states,1>& xb_)
+ {
+ xb = xb_;
+ if (!got_first_meas)
+ {
+ x = xb_;
+ got_first_meas = true;
+ }
+ }
+
+ const matrix<double,measurements,states>& get_observation_model (
+ ) const { return H; }
+
+ const matrix<double,states,states>& get_transition_model (
+ ) const { return A; }
+
+ const matrix<double,states,states>& get_process_noise (
+ ) const { return Q; }
+
+ const matrix<double,measurements,measurements>& get_measurement_noise (
+ ) const { return R; }
+
+ void update (
+ )
+ {
+ // propagate estimation error covariance forward
+ P = A*P*trans(A) + Q;
+
+ // propagate state forward
+ x = xb;
+ xb = A*x;
+ }
+
+ void update (const matrix<double,measurements,1>& z)
+ {
+ // propagate estimation error covariance forward
+ P = A*P*trans(A) + Q;
+
+ // compute Kalman gain matrix
+ const matrix<double,states,measurements> K = P*trans(H)*pinv(H*P*trans(H) + R);
+
+ if (got_first_meas)
+ {
+ const matrix<double,measurements,1> res = z - H*xb;
+ // correct the current state estimate
+ x = xb + K*res;
+ }
+ else
+ {
+ // Since we don't have a previous state estimate at the start of filtering,
+ // we will just set the current state to whatever is indicated by the measurement
+ x = pinv(H)*z;
+ got_first_meas = true;
+ }
+
+ // propagate state forward in time
+ xb = A*x;
+
+ // update estimation error covariance since we got a measurement.
+ P = (identity_matrix<double,states>() - K*H)*P;
+ }
+
+ const matrix<double,states,1>& get_current_state(
+ ) const
+ {
+ return x;
+ }
+
+ const matrix<double,states,1>& get_predicted_next_state(
+ ) const
+ {
+ return xb;
+ }
+
+ const matrix<double,states,states>& get_current_estimation_error_covariance(
+ ) const
+ {
+ return P;
+ }
+
+ friend inline void serialize(const kalman_filter& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.got_first_meas, out);
+ serialize(item.x, out);
+ serialize(item.xb, out);
+ serialize(item.P, out);
+ serialize(item.H, out);
+ serialize(item.A, out);
+ serialize(item.Q, out);
+ serialize(item.R, out);
+ }
+
+ friend inline void deserialize(kalman_filter& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object.");
+
+ deserialize(item.got_first_meas, in);
+ deserialize(item.x, in);
+ deserialize(item.xb, in);
+ deserialize(item.P, in);
+ deserialize(item.H, in);
+ deserialize(item.A, in);
+ deserialize(item.Q, in);
+ deserialize(item.R, in);
+ }
+
+ private:
+
+ bool got_first_meas;
+ matrix<double,states,1> x, xb;
+ matrix<double,states,states> P;
+
+ matrix<double,measurements,states> H;
+ matrix<double,states,states> A;
+ matrix<double,states,states> Q;
+ matrix<double,measurements,measurements> R;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class momentum_filter
+ {
+ public:
+
+ momentum_filter(
+ double meas_noise,
+ double acc,
+ double max_meas_dev
+ ) :
+ measurement_noise(meas_noise),
+ typical_acceleration(acc),
+ max_measurement_deviation(max_meas_dev)
+ {
+ DLIB_CASSERT(meas_noise >= 0);
+ DLIB_CASSERT(acc >= 0);
+ DLIB_CASSERT(max_meas_dev >= 0);
+
+ kal.set_observation_model({1, 0});
+ kal.set_transition_model( {1, 1,
+ 0, 1});
+ kal.set_process_noise({0, 0,
+ 0, typical_acceleration*typical_acceleration});
+
+ kal.set_measurement_noise({measurement_noise*measurement_noise});
+ }
+
+ momentum_filter() = default;
+
+ double get_measurement_noise (
+ ) const { return measurement_noise; }
+
+ double get_typical_acceleration (
+ ) const { return typical_acceleration; }
+
+ double get_max_measurement_deviation (
+ ) const { return max_measurement_deviation; }
+
+ void reset()
+ {
+ *this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation);
+ }
+
+ double get_predicted_next_position(
+ ) const
+ {
+ return kal.get_predicted_next_state()(0);
+ }
+
+ double operator()(
+ const double measured_position
+ )
+ {
+ auto x = kal.get_predicted_next_state();
+ const auto max_deviation = max_measurement_deviation*measurement_noise;
+ // Check if measured_position has suddenly jumped in value by a whole lot. This
+ // could happen if the velocity term experiences a much larger than normal
+ // acceleration, e.g. because the underlying object is doing a maneuver. If
+ // this happens then we clamp the state so that the predicted next value is no
+ // more than max_deviation away from measured_position at all times.
+ if (x(0) > measured_position + max_deviation)
+ {
+ x(0) = measured_position + max_deviation;
+ kal.set_state(x);
+ }
+ else if (x(0) < measured_position - max_deviation)
+ {
+ x(0) = measured_position - max_deviation;
+ kal.set_state(x);
+ }
+
+ kal.update({measured_position});
+
+ return kal.get_current_state()(0);
+ }
+
+ friend std::ostream& operator << (std::ostream& out, const momentum_filter& item)
+ {
+ out << "measurement_noise: " << item.measurement_noise << "\n";
+ out << "typical_acceleration: " << item.typical_acceleration << "\n";
+ out << "max_measurement_deviation: " << item.max_measurement_deviation;
+ return out;
+ }
+
+ friend void serialize(const momentum_filter& item, std::ostream& out)
+ {
+ int version = 15;
+ serialize(version, out);
+ serialize(item.measurement_noise, out);
+ serialize(item.typical_acceleration, out);
+ serialize(item.max_measurement_deviation, out);
+ serialize(item.kal, out);
+ }
+
+ friend void deserialize(momentum_filter& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 15)
+ throw serialization_error("Unexpected version found while deserializing momentum_filter.");
+ deserialize(item.measurement_noise, in);
+ deserialize(item.typical_acceleration, in);
+ deserialize(item.max_measurement_deviation, in);
+ deserialize(item.kal, in);
+ }
+
+ private:
+
+ double measurement_noise = 2;
+ double typical_acceleration = 0.1;
+ double max_measurement_deviation = 3; // nominally number of standard deviations
+
+ kalman_filter<2,1> kal;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<std::vector<double>>& sequences,
+ const double smoothness = 1
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<double>& sequence,
+ const double smoothness = 1
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class rect_filter
+ {
+ public:
+ rect_filter() = default;
+
+ rect_filter(
+ double meas_noise,
+ double acc,
+ double max_meas_dev
+ ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {}
+
+ rect_filter(
+ const momentum_filter& filt
+ ) :
+ left(filt),
+ top(filt),
+ right(filt),
+ bottom(filt)
+ {
+ }
+
+ drectangle operator()(const drectangle& r)
+ {
+ return drectangle(left(r.left()),
+ top(r.top()),
+ right(r.right()),
+ bottom(r.bottom()));
+ }
+
+ drectangle operator()(const rectangle& r)
+ {
+ return drectangle(left(r.left()),
+ top(r.top()),
+ right(r.right()),
+ bottom(r.bottom()));
+ }
+
+ const momentum_filter& get_left () const { return left; }
+ momentum_filter& get_left () { return left; }
+ const momentum_filter& get_top () const { return top; }
+ momentum_filter& get_top () { return top; }
+ const momentum_filter& get_right () const { return right; }
+ momentum_filter& get_right () { return right; }
+ const momentum_filter& get_bottom () const { return bottom; }
+ momentum_filter& get_bottom () { return bottom; }
+
+ friend void serialize(const rect_filter& item, std::ostream& out)
+ {
+ int version = 123;
+ serialize(version, out);
+ serialize(item.left, out);
+ serialize(item.top, out);
+ serialize(item.right, out);
+ serialize(item.bottom, out);
+ }
+
+ friend void deserialize(rect_filter& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 123)
+ throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object.");
+ deserialize(item.left, in);
+ deserialize(item.top, in);
+ deserialize(item.right, in);
+ deserialize(item.bottom, in);
+ }
+
+ private:
+
+ momentum_filter left, top, right, bottom;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ rect_filter find_optimal_rect_filter (
+ const std::vector<rectangle>& rects,
+ const double smoothness = 1
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KALMAN_FiLTER_Hh_
+
diff --git a/ml/dlib/dlib/filtering/kalman_filter_abstract.h b/ml/dlib/dlib/filtering/kalman_filter_abstract.h
new file mode 100644
index 000000000..cdac2c569
--- /dev/null
+++ b/ml/dlib/dlib/filtering/kalman_filter_abstract.h
@@ -0,0 +1,492 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KALMAN_FiLTER_ABSTRACT_Hh_
+#ifdef DLIB_KALMAN_FiLTER_ABSTRACT_Hh_
+
+#include "../serialize.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long states,
+ long measurements
+ >
+ class kalman_filter
+ {
+ /*!
+ REQUIREMENTS ON states
+ states > 0
+
+ REQUIREMENTS ON measurements
+ measurements > 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the Kalman filter, which is a tool for
+ recursively estimating the state of a process given measurements
+ related to that process. To use this tool you will have to
+ be familiar with the workings of the Kalman filter. An excellent
+ introduction can be found in the paper:
+ An Introduction to the Kalman Filter
+ by Greg Welch and Gary Bishop
+
+ !*/
+
+ public:
+
+ kalman_filter(
+ );
+ /*!
+ - #get_observation_model() == 0
+ - #get_transition_model() == 0
+ - #get_process_noise() == 0
+ - #get_measurement_noise() == 0
+ - #get_current_state() == 0
+ - #get_predicted_next_state() == 0
+ - #get_current_estimation_error_covariance() == the identity matrix
+ !*/
+
+ void set_observation_model (
+ const matrix<double,measurements,states>& H
+ );
+ /*!
+ ensures
+ - #get_observation_model() == H
+ !*/
+
+ void set_transition_model (
+ const matrix<double,states,states>& A
+ );
+ /*!
+ ensures
+ - #get_transition_model() == A
+ !*/
+
+ void set_process_noise (
+ const matrix<double,states,states>& Q
+ );
+ /*!
+ ensures
+ - #get_process_noise() == Q
+ !*/
+
+ void set_measurement_noise (
+ const matrix<double,measurements,measurements>& R
+ );
+ /*!
+ ensures
+ - #get_measurement_noise() == R
+ !*/
+
+ void set_estimation_error_covariance (
+ const matrix<double,states,states>& P
+ );
+ /*!
+ ensures
+ - #get_current_estimation_error_covariance() == P
+ (Note that you should only set this before you start filtering
+ since the Kalman filter will maintain the value of P on its own.
+ So only set this during initialization unless you are sure you
+ understand what you are doing.)
+ !*/
+
+ void set_state (
+ const matrix<double,states,1>& xb
+ );
+ /*!
+ ensures
+ - This function can be used when the initial state is known, or if the
+ state needs to be corrected before the next update().
+ - #get_predicted_next_state() == xb
+ - If (update() hasn't been called yet) then
+ - #get_current_state() == xb
+ !*/
+
+ const matrix<double,measurements,states>& get_observation_model (
+ ) const;
+ /*!
+ ensures
+ - Returns the matrix "H" which relates process states x to measurements z.
+ The relation is linear, therefore, z = H*x. That is, multiplying a
+ state by H gives the measurement you expect to observe for that state.
+ !*/
+
+ const matrix<double,states,states>& get_transition_model (
+ ) const;
+ /*!
+ ensures
+ - Returns the matrix "A" which determines how process states change over time.
+ The relation is linear, therefore, given a state vector x, the value you
+ expect it to have at the next time step is A*x.
+ !*/
+
+ const matrix<double,states,states>& get_process_noise (
+ ) const;
+ /*!
+ ensures
+ - returns the process noise covariance matrix. You can think of this
+ covariance matrix as a measure of how wrong the assumption of
+ linear state transitions is.
+ !*/
+
+ const matrix<double,measurements,measurements>& get_measurement_noise (
+ ) const;
+ /*!
+ ensures
+ - returns the measurement noise covariance matrix. That is, when we
+ measure a state x we only obtain H*x corrupted by Gaussian noise.
+ The measurement noise is the covariance matrix of this Gaussian
+ noise which corrupts our measurements.
+ !*/
+
+ void update (
+ );
+ /*!
+ ensures
+ - propagates the current state estimate forward in time one
+ time step. In particular:
+ - #get_current_state() == get_predicted_next_state()
+ - #get_predicted_next_state() == get_transition_model()*get_current_state()
+ - #get_current_estimation_error_covariance() == the propagated value of this covariance matrix
+ !*/
+
+ void update (
+ const matrix<double,measurements,1>& z
+ );
+ /*!
+ ensures
+ - propagates the current state estimate forward in time one time step.
+ Also applies a correction based on the given measurement z. In particular:
+ - #get_current_state(), #get_predicted_next_state(), and
+ #get_current_estimation_error_covariance() are updated using the
+ Kalman filter method based on the new measurement in z.
+ !*/
+
+ const matrix<double,states,1>& get_current_state(
+ ) const;
+ /*!
+ ensures
+ - returns the current estimate of the state of the process. This
+ estimate is based on all the measurements supplied to the update()
+ method.
+ !*/
+
+ const matrix<double,states,1>& get_predicted_next_state(
+ ) const;
+ /*!
+ ensures
+ - returns the next expected value of the process state.
+ - Specifically, returns get_transition_model()*get_current_state()
+
+ !*/
+
+ const matrix<double,states,states>& get_current_estimation_error_covariance(
+ ) const;
+ /*!
+ ensures
+ - returns the current state error estimation covariance matrix.
+ This matrix captures our uncertainty about the value of get_current_state().
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const kalman_filter& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ kalman_filter& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class momentum_filter
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple tool for filtering a single scalar value that
+ measures the location of a moving object that has some non-trivial
+ momentum. Importantly, the measurements are noisy and the object can
+ experience sudden unpredictable accelerations. To accomplish this
+ filtering we use a simple Kalman filter with a state transition model of:
+
+ position_{i+1} = position_{i} + velocity_{i}
+ velocity_{i+1} = velocity_{i} + some_unpredictable_acceleration
+
+ and a measurement model of:
+
+ measured_position_{i} = position_{i} + measurement_noise
+
+ Where some_unpredictable_acceleration and measurement_noise are 0 mean Gaussian
+ noise sources with standard deviations of get_typical_acceleration() and
+ get_measurement_noise() respectively.
+
+ To allow for really sudden and large but infrequent accelerations, at each
+ step we check if the current measured position deviates from the predicted
+ filtered position by more than get_max_measurement_deviation()*get_measurement_noise()
+ and if so we adjust the filter's state to keep it within these bounds.
+ This allows the moving object to undergo large unmodeled accelerations, far
+ in excess of what would be suggested by get_typical_acceleration(), without
+ then experiencing a long lag time where the Kalman filter has to "catch
+ up" to the new position.
+ !*/
+
+ public:
+
+ momentum_filter(
+ ) = default;
+ /*!
+ ensures
+ - #get_measurement_noise() == 2
+ - #get_typical_acceleration() == 0.1
+ - #get_max_measurement_deviation() == 3
+ !*/
+
+ momentum_filter(
+ double meas_noise,
+ double acc,
+ double max_meas_dev
+ );
+ /*!
+ requires
+ - meas_noise >= 0
+ - acc >= 0
+ - max_meas_dev >= 0
+ ensures
+ - #get_measurement_noise() == meas_noise
+ - #get_typical_acceleration() == acc
+ - #get_max_measurement_deviation() == max_meas_dev
+ !*/
+
+
+ double get_measurement_noise (
+ ) const;
+ /*!
+ ensures
+ - Returns the standard deviation of the 0 mean Gaussian noise that corrupts
+ measurements of the moving object.
+ !*/
+
+ double get_typical_acceleration (
+ ) const;
+ /*!
+ ensures
+ - We assume that the moving object experiences random accelerations that
+ are distributed by 0 mean Gaussian noise with get_typical_acceleration()
+ standard deviation.
+ !*/
+
+ double get_max_measurement_deviation (
+ ) const;
+ /*!
+ ensures
+ - This object will never let the filtered location of the object deviate
+ from the measured location by much more than
+ get_max_measurement_deviation()*get_measurement_noise().
+ !*/
+
+ void reset(
+ );
+ /*!
+ ensures
+ - Returns this object to the state immediately after construction. To be precise, we do:
+ *this = momentum_filter(get_measurement_noise(), get_typical_acceleration(), get_max_measurement_deviation());
+ !*/
+
+ double operator()(
+ const double measured_position
+ );
+ /*!
+ ensures
+ - Updates the Kalman filter with the new measured position of the object
+ and returns the new filtered estimate of the object's position, now that
+ we have seen the latest measured position.
+ - #get_predicted_next_position() == the prediction for the *next* place we
+ will see the object. That is, where we think it will be in the future
+ rather than where it is now.
+ !*/
+
+ double get_predicted_next_position (
+ ) const;
+ /*!
+ ensures
+ - Returns the Kalman filter's estimate of the next position we will see the object.
+ !*/
+ };
+
+ std::ostream& operator << (std::ostream& out, const momentum_filter& item);
+ void serialize(const momentum_filter& item, std::ostream& out);
+ void deserialize(momentum_filter& item, std::istream& in);
+ /*!
+ Provide printing and serialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<std::vector<double>>& sequences,
+ const double smoothness = 1
+ );
+ /*!
+ requires
+ - sequences.size() != 0
+ - for all valid i: sequences[i].size() > 4
+ - smoothness >= 0
+ ensures
+ - This function finds the "optimal" settings of a momentum_filter based on
+ recorded measurement data stored in sequences. Here we assume that each
+ vector in sequences is a complete track history of some object's measured
+ positions. What we do is find the momentum_filter that minimizes the
+ following objective function:
+ sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1])
+ Where i is a time index.
+ The sum runs over all the data in sequences. So what we do is find the
+ filter settings that produce smooth filtered trajectories but also produce
+ filtered outputs that are as close to the measured positions as possible.
+ The larger the value of smoothness the less jittery the filter outputs will
+ be, but they might become biased or laggy if smoothness is set really high.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ momentum_filter find_optimal_momentum_filter (
+ const std::vector<double>& sequence,
+ const double smoothness = 1
+ );
+ /*!
+ requires
+ - sequence.size() > 4
+ - smoothness >= 0
+ ensures
+ - performs: find_optimal_momentum_filter({1,sequence}, smoothness);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class rect_filter
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object simply contains four momentum_filters and applies them to the
+ 4 components of a dlib::rectangle's position. It therefore allows you to
+ easily filter a sequence of rectangles. For instance, it can be used to
+ smooth the output of an object detector running on a video.
+ !*/
+
+ public:
+ rect_filter(
+ ) = default;
+ /*!
+ ensures
+ - The four momentum_filters in this object are default initialized.
+ !*/
+
+ rect_filter(
+ const momentum_filter& filt
+ );
+ /*!
+ ensures
+ - #get_left() == filt
+ - #get_top() == filt
+ - #get_right() == filt
+ - #get_bottom() == filt
+ !*/
+
+ rect_filter(
+ double meas_noise,
+ double acc,
+ double max_meas_dev
+ ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {}
+ /*!
+ requires
+ - meas_noise >= 0
+ - acc >= 0
+ - max_meas_dev >= 0
+ ensures
+ - Initializes this object with momentum_filter(meas_noise, acc, max_meas_dev)
+ !*/
+
+ drectangle operator()(
+ const drectangle& r
+ );
+ /*!
+ ensures
+ - Runs the given rectangle through the momentum_filters and returns the
+ filtered rectangle location. That is, performs:
+ return drectangle(get_left()(r.left()),
+ get_top()(r.top()),
+ get_right()(r.right()),
+ get_bottom()(r.bottom()));
+ !*/
+
+ drectangle operator()(
+ const rectangle& r
+ );
+ /*!
+ ensures
+ - Runs the given rectangle through the momentum_filters and returns the
+ filtered rectangle location. That is, performs:
+ return drectangle(get_left()(r.left()),
+ get_top()(r.top()),
+ get_right()(r.right()),
+ get_bottom()(r.bottom()));
+ !*/
+
+ const momentum_filter& get_left() const;
+ momentum_filter& get_left();
+ const momentum_filter& get_top() const;
+ momentum_filter& get_top();
+ const momentum_filter& get_right() const;
+ momentum_filter& get_right();
+ const momentum_filter& get_bottom() const;
+ momentum_filter& get_bottom();
+ /*!
+ Provides access to the 4 momentum_filters used to filter the 4 coordinates that define a rectangle.
+ !*/
+ };
+
+ void serialize(const rect_filter& item, std::ostream& out);
+ void deserialize(rect_filter& item, std::istream& in);
+ /*!
+ Provide serialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ rect_filter find_optimal_rect_filter (
+ const std::vector<rectangle>& rects,
+ const double smoothness = 1
+ );
+ /*!
+ requires
+ - rects.size() > 4
+ - smoothness >= 0
+ ensures
+ - This routine simply invokes find_optimal_momentum_filter() to find the
+ momentum_filter that works best on the provided sequence of rectangles. It
+ then constructs a rect_filter using that momentum_filter and returns it.
+ Therefore, this routine finds the rect_filter that is "optimal" for filtering
+ the given sequence of rectangles.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KALMAN_FiLTER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/filtering/rls_filter.h b/ml/dlib/dlib/filtering/rls_filter.h
new file mode 100644
index 000000000..4481ab3f4
--- /dev/null
+++ b/ml/dlib/dlib/filtering/rls_filter.h
@@ -0,0 +1,198 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RLS_FiLTER_Hh_
+#define DLIB_RLS_FiLTER_Hh_
+
+#include "rls_filter_abstract.h"
+#include "../svm/rls.h"
+#include <vector>
+#include "../matrix.h"
+#include "../sliding_buffer.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rls_filter
+ {
+ /*!
+ CONVENTION
+ - data.size() == the number of variables in a measurement
+ - data[i].size() == data[j].size() for all i and j.
+ - data[i].size() == get_window_size()
+ - data[i][0] == most recent measurement of i-th variable given to update.
+ - data[i].back() == oldest measurement of i-th variable given to update
+ (or zero if we haven't seen this much data yet).
+
+ - if (count <= 2) then
+ - count == number of times update(z) has been called
+ !*/
+ public:
+
+ rls_filter()
+ {
+ size = 5;
+ count = 0;
+ filter = rls(0.8, 100);
+ }
+
+ explicit rls_filter (
+ unsigned long size_,
+ double forget_factor = 0.8,
+ double C = 100
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 &&
+ 0 < C && size_ >= 2,
+ "\t rls_filter::rls_filter()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t forget_factor: " << forget_factor
+ << "\n\t C: " << C
+ << "\n\t size_: " << size_
+ << "\n\t this: " << this
+ );
+
+ size = size_;
+ count = 0;
+ filter = rls(forget_factor, C);
+ }
+
+ double get_c(
+ ) const
+ {
+ return filter.get_c();
+ }
+
+ double get_forget_factor(
+ ) const
+ {
+ return filter.get_forget_factor();
+ }
+
+ unsigned long get_window_size (
+ ) const
+ {
+ return size;
+ }
+
+ void update (
+ )
+ {
+ if (filter.get_w().size() == 0)
+ return;
+
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ // Put old predicted value into the circular buffer as if it was
+ // the measurement we just observed. But don't update the rls filter.
+ data[i].push_front(next(i));
+ }
+
+ // predict next state
+ for (long i = 0; i < next.size(); ++i)
+ next(i) = filter(mat(data[i]));
+ }
+
+ template <typename EXP>
+ void update (
+ const matrix_exp<EXP>& z
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(z) == true &&
+ z.size() != 0 &&
+ (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()),
+ "\t void rls_filter::update(z)"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_col_vector(z): " << is_col_vector(z)
+ << "\n\t z.size(): " << z.size()
+ << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size()
+ << "\n\t this: " << this
+ );
+
+ // initialize data if necessary
+ if (data.size() == 0)
+ {
+ data.resize(z.size());
+ for (long i = 0; i < z.size(); ++i)
+ data[i].assign(size, 0);
+ }
+
+
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ // Once there is some stuff in the circular buffer, start
+ // showing it to the rls filter so it can do its thing.
+ if (count >= 2)
+ {
+ filter.train(mat(data[i]), z(i));
+ }
+
+ // keep track of the measurements in our circular buffer
+ data[i].push_front(z(i));
+ }
+
+ // Don't bother with the filter until we have seen two samples
+ if (count >= 2)
+ {
+ // predict next state
+ for (long i = 0; i < z.size(); ++i)
+ next(i) = filter(mat(data[i]));
+ }
+ else
+ {
+ // Use current measurement as the next state prediction
+ // since we don't know any better at this point.
+ ++count;
+ next = matrix_cast<double>(z);
+ }
+ }
+
+ const matrix<double,0,1>& get_predicted_next_state(
+ ) const
+ {
+ return next;
+ }
+
+ friend inline void serialize(const rls_filter& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.count, out);
+ serialize(item.size, out);
+ serialize(item.filter, out);
+ serialize(item.next, out);
+ serialize(item.data, out);
+ }
+
+ friend inline void deserialize(rls_filter& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object.");
+
+ deserialize(item.count, in);
+ deserialize(item.size, in);
+ deserialize(item.filter, in);
+ deserialize(item.next, in);
+ deserialize(item.data, in);
+ }
+
+ private:
+
+ unsigned long count;
+ unsigned long size;
+ rls filter;
+ matrix<double,0,1> next;
+ std::vector<circular_buffer<double> > data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RLS_FiLTER_Hh_
+
diff --git a/ml/dlib/dlib/filtering/rls_filter_abstract.h b/ml/dlib/dlib/filtering/rls_filter_abstract.h
new file mode 100644
index 000000000..0a932cb87
--- /dev/null
+++ b/ml/dlib/dlib/filtering/rls_filter_abstract.h
@@ -0,0 +1,171 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RLS_FiLTER_ABSTRACT_Hh_
+#ifdef DLIB_RLS_FiLTER_ABSTRACT_Hh_
+
+#include "../svm/rls_abstract.h"
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rls_filter
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for doing time series prediction using linear
+ recursive least squares. In particular, this object takes a sequence
+ of points from the user and, at each step, attempts to predict the
+ value of the next point.
+
+ To accomplish this, this object maintains a fixed size buffer of recent
+ points. Each prediction is a linear combination of the points in this
+ history buffer. It uses the recursive least squares algorithm to
+ determine how to best combine the contents of the history buffer to
+ predict each point. Therefore, each time update() is called with
+ a point, recursive least squares updates the linear combination weights,
+ and then it inserts the point into the history buffer. After that, the
+ next prediction is based on these updated weights and the current history
+ buffer.
+ !*/
+
+ public:
+
+ rls_filter(
+ );
+ /*!
+ ensures
+ - #get_window_size() == 5
+ - #get_forget_factor() == 0.8
+ - #get_c() == 100
+ - #get_predicted_next_state().size() == 0
+ !*/
+
+ explicit rls_filter (
+ unsigned long size,
+ double forget_factor = 0.8,
+ double C = 100
+ );
+ /*!
+ requires
+ - 0 < forget_factor <= 1
+ - 0 < C
+ - size >= 2
+ ensures
+ - #get_window_size() == size
+ - #get_forget_factor() == forget_factor
+ - #get_c() == C
+ - #get_predicted_next_state().size() == 0
+ !*/
+
+ double get_c(
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter that determines
+ the trade-off between trying to fit the data points given to update() or
+ allowing more errors but hopefully improving the generalization of the
+ predictions. Larger values encourage exact fitting while smaller values
+ of C may encourage better generalization.
+ !*/
+
+ double get_forget_factor(
+ ) const;
+ /*!
+ ensures
+ - This object uses exponential forgetting in its implementation of recursive
+ least squares. Therefore, this function returns the "forget factor".
+ - if (get_forget_factor() == 1) then
+ - In this case, exponential forgetting is disabled.
+ - The recursive least squares algorithm will implicitly take all previous
+ calls to update(z) into account when estimating the optimal weights for
+ linearly combining the history buffer into a prediction of the next point.
+ - else
+ - Old calls to update(z) are eventually forgotten. That is, the smaller
+ the forget factor, the less recursive least squares will care about
+ attempting to find linear combination weights which would have make
+ good predictions on old points. It will care more about fitting recent
+ points. This is appropriate if the statistical properties of the time
+ series we are modeling are not constant.
+ !*/
+
+ unsigned long get_window_size (
+ ) const;
+ /*!
+ ensures
+ - returns the size of the history buffer. This is the number of points which
+ are linearly combine to make the predictions returned by get_predicted_next_state().
+ !*/
+
+ void update (
+ );
+ /*!
+ ensures
+ - Propagates the prediction forward in time.
+ - In particular, the value in get_predicted_next_state() is inserted
+ into the history buffer and then the next prediction is estimated
+ based on this updated history buffer.
+ - #get_predicted_next_state() == the prediction for the next point
+ in the time series.
+ !*/
+
+ template <typename EXP>
+ void update (
+ const matrix_exp<EXP>& z
+ );
+ /*!
+ requires
+ - is_col_vector(z) == true
+ - z.size() != 0
+ - if (get_predicted_next_state().size() != 0) then
+ - z.size() == get_predicted_next_state().size()
+ (i.e. z must be the same size as all the previous z values given
+ to this function)
+ ensures
+ - Updates the state of this filter based on the current measurement in z.
+ - In particular, the filter weights are updated and z is inserted into
+ the history buffer. Then the next prediction is estimated based on
+ these updated weights and history buffer.
+ - #get_predicted_next_state() == the prediction for the next point
+ in the time series.
+ - #get_predicted_next_state().size() == z.size()
+ !*/
+
+ const matrix<double,0,1>& get_predicted_next_state(
+ ) const;
+ /*!
+ ensures
+ - returns the estimate of the next point we will observe in the
+ time series data.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const rls_filter& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ rls_filter& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_RLS_FiLTER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/float_details.h b/ml/dlib/dlib/float_details.h
new file mode 100644
index 000000000..3dc7eae49
--- /dev/null
+++ b/ml/dlib/dlib/float_details.h
@@ -0,0 +1,161 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FLOAT_DEtAILS_Hh_
+#define DLIB_FLOAT_DEtAILS_Hh_
+
+#include <cmath>
+#include "algs.h"
+#include <limits>
+
+namespace dlib
+{
+ struct float_details
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for converting floating point numbers into an
+ explicit integer representation and then also converting back. In
+ particular, a float_details object represents a floating point number with
+ a 64 bit mantissa and 16 bit exponent. These are stored in the public
+ fields of the same names.
+
+ The main use of this object is to convert floating point values into a
+ known uniform representation so they can be serialized to an output stream.
+ This allows dlib serialization code to work on any system, regardless of
+ the floating point representation used by the hardware. It also means
+ that, for example, a double can be serialized and then deserialized into a
+ float and it will perform the appropriate conversion.
+
+
+ In more detail, this object represents a floating point value equal to
+ mantissa*pow(2,exponent), except when exponent takes on any of the
+ following special values:
+ - is_inf
+ - is_ninf
+ - is_nan
+ These values are used to indicate that the floating point value should be
+ either infinity, negative infinity, or not-a-number respectively.
+ !*/
+
+ float_details(
+ int64 man,
+ int16 exp
+ ) : mantissa(man), exponent(exp) {}
+ /*!
+ ensures
+ - #mantissa == man
+ - #exponent == exp
+ !*/
+
+ float_details() :
+ mantissa(0), exponent(0)
+ {}
+ /*!
+ ensures
+ - this object represents a floating point value of 0
+ !*/
+
+ float_details ( const double& val) { *this = val; }
+ float_details ( const float& val) { *this = val; }
+ float_details ( const long double& val) { *this = val; }
+ /*!
+ ensures
+ - converts the given value into a float_details representation. This
+ means that converting #*this back into a floating point number should
+ recover the input val.
+ !*/
+
+ float_details& operator= ( const double& val) { convert_from_T(val); return *this; }
+ float_details& operator= ( const float& val) { convert_from_T(val); return *this; }
+ float_details& operator= ( const long double& val) { convert_from_T(val); return *this; }
+ /*!
+ ensures
+ - converts the given value into a float_details representation. This
+ means that converting #*this back into a floating point number should
+ recover the input val.
+ !*/
+
+ operator double () const { return convert_to_T<double>(); }
+ operator float () const { return convert_to_T<float>(); }
+ operator long double () const { return convert_to_T<long double>(); }
+ /*!
+ ensures
+ - converts the contents of this float_details object into a floating point number.
+ !*/
+
+ const static int16 is_inf = 32000;
+ const static int16 is_ninf = 32001;
+ const static int16 is_nan = 32002;
+
+ int64 mantissa;
+ int16 exponent;
+
+
+ private:
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// IMPLEMENTATION DETAILS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void convert_from_T (
+ const T& val
+ )
+ {
+ mantissa = 0;
+
+ const int digits = dlib::tmin<std::numeric_limits<T>::digits, 63>::value;
+
+ if (val == std::numeric_limits<T>::infinity())
+ {
+ exponent = is_inf;
+ }
+ else if (val == -std::numeric_limits<T>::infinity())
+ {
+ exponent = is_ninf;
+ }
+ else if (val < std::numeric_limits<T>::infinity())
+ {
+ int exp;
+ mantissa = static_cast<int64>(std::frexp(val, &exp)*(((uint64)1)<<digits));
+ exponent = exp - digits;
+
+ // Compact the representation a bit by shifting off any low order bytes
+ // which are zero in the mantissa. This makes the numbers in mantissa and
+ // exponent generally smaller which can make serialization and other things
+ // more efficient in some cases.
+ for (int i = 0; i < 8 && ((mantissa&0xFF)==0); ++i)
+ {
+ mantissa >>= 8;
+ exponent += 8;
+ }
+ }
+ else
+ {
+ exponent = is_nan;
+ }
+ }
+
+ template <typename T>
+ T convert_to_T (
+ ) const
+ {
+ if (exponent < is_inf)
+ return std::ldexp((T)mantissa, exponent);
+ else if (exponent == is_inf)
+ return std::numeric_limits<T>::infinity();
+ else if (exponent == is_ninf)
+ return -std::numeric_limits<T>::infinity();
+ else
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+
+ };
+
+}
+
+#endif // DLIB_FLOAT_DEtAILS_Hh_
+
diff --git a/ml/dlib/dlib/fstream b/ml/dlib/dlib/fstream
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/fstream
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/general_hash/count_bits.h b/ml/dlib/dlib/general_hash/count_bits.h
new file mode 100644
index 000000000..01acad2ab
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/count_bits.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_COUNT_BiTS_Hh_
+#define DLIB_COUNT_BiTS_Hh_
+
+#include "../algs.h"
+#include <climits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T count_bits (
+ T v
+ )
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - returns the number of bits in v which are set to 1.
+ !*/
+ {
+ COMPILE_TIME_ASSERT(is_unsigned_type<T>::value && sizeof(T) <= 8);
+
+ // This bit of bit trickery is from:
+ // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64
+
+ v = v - ((v >> 1) & (T)~(T)0/3);
+ v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3);
+ v = (v + (v >> 4)) & (T)~(T)0/255*15;
+ return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * CHAR_BIT;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T hamming_distance (
+ const T& a,
+ const T& b
+ )
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - returns the number of bits which differ between a and b.
+ !*/
+ {
+ return count_bits(a^b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T hamming_distance (
+ const std::pair<T,T>& a,
+ const std::pair<T,T>& b
+ )
+ /*!
+ requires
+ - T is an unsigned integral type or a std::pair that, recursively, eventually
+ contains unsigned integral types.
+ ensures
+ - returns the number of bits which differ between a and b.
+ !*/
+ {
+ return hamming_distance(a.first,b.first) + hamming_distance(a.second, b.second);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_COUNT_BiTS_Hh_
+
diff --git a/ml/dlib/dlib/general_hash/count_bits_abstract.h b/ml/dlib/dlib/general_hash/count_bits_abstract.h
new file mode 100644
index 000000000..ff5d4482c
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/count_bits_abstract.h
@@ -0,0 +1,48 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_COUNT_BiTS_ABSTRACT_Hh_
+#ifdef DLIB_COUNT_BiTS_ABSTRACT_Hh_
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T count_bits (
+ T v
+ );
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - returns the number of bits in v which are set to 1.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T hamming_distance (
+ const T& a,
+ const T& b
+ );
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - returns the number of bits which differ between a and b. (I.e. returns
+ count_bits(a^b).)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_COUNT_BiTS_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/general_hash/general_hash.h b/ml/dlib/dlib/general_hash/general_hash.h
new file mode 100644
index 000000000..3de0b2698
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/general_hash.h
@@ -0,0 +1,80 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GENERAL_HASh_
+#define DLIB_GENERAL_HASh_
+
+
+#include <string>
+#include "hash.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------- provide a general hashing function object ----------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class general_hash
+ {
+ public:
+ inline unsigned long operator() (
+ const T& item
+ ) const;
+ };
+ /*!
+ Note that the default behavior of general hash is to attempt to cast
+ an object of type T to an unsigned long and use that as the hash.
+
+ REQUIREMENTS ON general_hash
+ - must have a default constructor
+ - must be a function object which overloads operator() as follows:
+ unsigned long operator()(const T& item)
+ - must take item, compute a hash number and return it
+ - must not throw
+ - must define the hash in such a way that all equivalent objects have
+ the same hash. where equivalent means the following:
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+// ---------------
+
+ template <
+ typename T
+ >
+ unsigned long general_hash<T>::
+ operator() (
+ const T& item
+ ) const
+ {
+ // hash any types that have a conversion to uint64
+ return hash(static_cast<uint64>(item));
+ }
+
+
+// ---------------
+
+ // std::string hash
+ template <>
+ inline unsigned long general_hash<std::string>::
+ operator() (
+ const std::string& item
+ ) const
+ {
+ return hash(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GENERAL_HASh_
+
diff --git a/ml/dlib/dlib/general_hash/hash.h b/ml/dlib/dlib/general_hash/hash.h
new file mode 100644
index 000000000..6edb99b99
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/hash.h
@@ -0,0 +1,142 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HAsH_Hh_
+#define DLIB_HAsH_Hh_
+
+#include "hash_abstract.h"
+#include <vector>
+#include <string>
+#include <map>
+#include "murmur_hash3.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ const std::string& item,
+ uint32 seed = 0
+ )
+ {
+ if (item.size() == 0)
+ return 0;
+ else
+ return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ const std::wstring& item,
+ uint32 seed = 0
+ )
+ {
+ if (item.size() == 0)
+ return 0;
+ else
+ return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ uint32 hash (
+ const std::vector<T,alloc>& item,
+ uint32 seed = 0
+ )
+ {
+ DLIB_ASSERT_HAS_STANDARD_LAYOUT(T);
+
+ if (item.size() == 0)
+ return 0;
+ else
+ return murmur_hash3(&item[0], sizeof(T)*item.size(), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename alloc>
+ uint32 hash (
+ const std::vector<std::pair<T,U>,alloc>& item,
+ uint32 seed = 0
+ )
+ {
+ DLIB_ASSERT_HAS_STANDARD_LAYOUT(T);
+ DLIB_ASSERT_HAS_STANDARD_LAYOUT(U);
+
+ if (item.size() == 0)
+ return 0;
+ else
+ return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename comp, typename alloc>
+ uint32 hash (
+ const std::map<T,U,comp,alloc>& item,
+ uint32 seed = 0
+ )
+ {
+ return hash(std::vector<std::pair<T,U> >(item.begin(), item.end()), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ uint32 val,
+ uint32 seed = 0
+ )
+ {
+ return murmur_hash3_2(val,seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ uint64 val,
+ uint32 seed = 0
+ )
+ {
+ return static_cast<uint32>(murmur_hash3_128bit_3(val,seed,0).first);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ const std::pair<uint64,uint64>& item,
+ uint32 seed = 0
+ )
+ {
+ return static_cast<uint32>(murmur_hash3_128bit_3(item.first,item.second,seed).first);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ const std::pair<uint32,uint32>& item,
+ uint32 seed = 0
+ )
+ {
+ return murmur_hash3_3(item.first,item.second,seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ uint32 hash (
+ const std::pair<T,U>& item,
+ uint32 seed = 0
+ )
+ {
+ return hash(item.first, seed) ^ hash(item.second, seed+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HAsH_Hh_
+
diff --git a/ml/dlib/dlib/general_hash/hash_abstract.h b/ml/dlib/dlib/general_hash/hash_abstract.h
new file mode 100644
index 000000000..8959bbe44
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/hash_abstract.h
@@ -0,0 +1,182 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HAsH_ABSTRACT_Hh_
+#ifdef DLIB_HAsH_ABSTRACT_Hh_
+
+#include "murmur_hash3_abstract.h"
+#include <vector>
+#include <string>
+#include <map>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ uint32 hash (
+ const std::string& item,
+ uint32 seed = 0
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - This routine will always give the same hash value when presented
+ with the same input string.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ uint32 hash (
+ const std::wstring& item,
+ uint32 seed = 0
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - Note that if the memory layout of the elements in item change between
+ hardware platforms then hash() will give different outputs. If you want
+ hash() to always give the same output for the same input then you must
+ ensure that elements of item always have the same layout in memory.
+ Typically this means using fixed width types and performing byte swapping
+ to account for endianness before passing item to hash().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ uint32 hash (
+ const std::vector<T,alloc>& item,
+ uint32 seed = 0
+ );
+ /*!
+ requires
+ - T is a standard layout type (e.g. a POD type like int, float,
+ or a simple struct).
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - Note that if the memory layout of the elements in item change between
+ hardware platforms then hash() will give different outputs. If you want
+ hash() to always give the same output for the same input then you must
+ ensure that elements of item always have the same layout in memory.
+ Typically this means using fixed width types and performing byte swapping
+ to account for endianness before passing item to hash().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename alloc>
+ uint32 hash (
+ const std::vector<std::pair<T,U>,alloc>& item,
+ uint32 seed = 0
+ );
+ /*!
+ requires
+ - T and U are standard layout types (e.g. POD types like int, float,
+ or simple structs).
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - Note that if the memory layout of the elements in item change between
+ hardware platforms then hash() will give different outputs. If you want
+ hash() to always give the same output for the same input then you must
+ ensure that elements of item always have the same layout in memory.
+ Typically this means using fixed width types and performing byte swapping
+ to account for endianness before passing item to hash().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename comp, typename alloc>
+ uint32 hash (
+ const std::map<T,U,comp,alloc>& item,
+ uint32 seed = 0
+ );
+ /*!
+ requires
+ - T and U are standard layout types (e.g. POD types like int, float,
+ or simple structs).
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - Note that if the memory layout of the elements in item change between
+ hardware platforms then hash() will give different outputs. If you want
+ hash() to always give the same output for the same input then you must
+ ensure that elements of item always have the same layout in memory.
+ Typically this means using fixed width types and performing byte swapping
+ to account for endianness before passing item to hash(). However, since
+ you can't modify the keys in a map you may have to copy it into a
+ std::vector and then work from there.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ uint32 item,
+ uint32 seed = 0
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3_2() routine to compute the actual hash.
+ - This routine will always give the same hash value when presented
+ with the same input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 hash (
+ uint64 item,
+ uint32 seed = 0
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3_128bit_3() routine to compute the actual hash.
+ - This routine will always give the same hash value when presented
+ with the same input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ uint32 hash (
+ const std::pair<T,U>& item,
+ uint32 seed = 0
+ );
+ /*!
+ requires
+ - hash() is defined for objects of type T and U
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - if (calling hash() on items of type T and U is always guaranteed to give the
+ same hash values when presented with the same input) then
+ - This routine will always give the same hash value when presented
+ with the same input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HAsH_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/general_hash/murmur_hash3.h b/ml/dlib/dlib/general_hash/murmur_hash3.h
new file mode 100644
index 000000000..b8e37260e
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/murmur_hash3.h
@@ -0,0 +1,519 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MURMUR_HAsH_3_Hh_
+#define DLIB_MURMUR_HAsH_3_Hh_
+
+#include "murmur_hash3_abstract.h"
+#include "../uintn.h"
+#include <utility>
+#include <string.h>
+
+namespace dlib
+{
+ //-----------------------------------------------------------------------------
+ // The original MurmurHash3 code was written by Austin Appleby, and is placed
+ // in the public domain. The author hereby disclaims copyright to this source code.
+ // The code in this particular file was modified by Davis E. King. In
+ // particular, endian-swapping was added along with some other minor code
+ // changes like avoiding strict aliasing violations.
+
+
+ //-----------------------------------------------------------------------------
+ // Platform-specific functions and macros
+
+ // Microsoft Visual Studio
+
+#if defined(_MSC_VER)
+
+#define DLIB_FORCE_INLINE __forceinline
+
+#include <stdlib.h>
+
+#define DLIB_ROTL32(x,y) _rotl(x,y)
+#define DLIB_ROTL64(x,y) _rotl64(x,y)
+
+#define DLIB_BIG_CONSTANT(x) (x)
+
+ // Other compilers
+
+#else // defined(_MSC_VER)
+
+#define DLIB_FORCE_INLINE __attribute__((always_inline)) inline
+
+ inline uint32 murmur_rotl32 ( uint32 x, int8 r )
+ {
+ return (x << r) | (x >> (32 - r));
+ }
+
+ inline uint64 murmur_rotl64 ( uint64 x, int8 r )
+ {
+ return (x << r) | (x >> (64 - r));
+ }
+
+#define DLIB_ROTL32(x,y) dlib::murmur_rotl32(x,y)
+#define DLIB_ROTL64(x,y) dlib::murmur_rotl64(x,y)
+
+#define DLIB_BIG_CONSTANT(x) (x##LLU)
+
+#endif // !defined(_MSC_VER)
+
+// ----------------------------------------------------------------------------------------
+ // Block read - if your platform needs to do endian-swapping or can only
+ // handle aligned reads, do the conversion here
+
+ DLIB_FORCE_INLINE uint32 murmur_getblock ( const uint32 * p, int i )
+ {
+ // The reason we do a memcpy() here instead of simply returning p[i] is because
+ // doing it this way avoids violations of the strict aliasing rule when all these
+ // functions are inlined into the user's code.
+ uint32 temp;
+ memcpy(&temp, p+i, 4);
+ return temp;
+ }
+
+ DLIB_FORCE_INLINE uint32 murmur_getblock_byte_swap ( const uint32 * p, int i )
+ {
+ union
+ {
+ uint8 bytes[4];
+ uint32 val;
+ } temp;
+
+ const uint8* pp = reinterpret_cast<const uint8*>(p + i);
+ temp.bytes[0] = pp[3];
+ temp.bytes[1] = pp[2];
+ temp.bytes[2] = pp[1];
+ temp.bytes[3] = pp[0];
+
+ return temp.val;
+ }
+
+ DLIB_FORCE_INLINE uint64 murmur_getblock ( const uint64 * p, int i )
+ {
+ // The reason we do a memcpy() here instead of simply returning p[i] is because
+ // doing it this way avoids violations of the strict aliasing rule when all these
+ // functions are inlined into the user's code.
+ uint64 temp;
+ memcpy(&temp, p+i, 8);
+ return temp;
+ }
+
+ DLIB_FORCE_INLINE uint64 murmur_getblock_byte_swap ( const uint64 * p, int i )
+ {
+ union
+ {
+ uint8 bytes[8];
+ uint64 val;
+ } temp;
+
+ const uint8* pp = reinterpret_cast<const uint8*>(p + i);
+ temp.bytes[0] = pp[7];
+ temp.bytes[1] = pp[6];
+ temp.bytes[2] = pp[5];
+ temp.bytes[3] = pp[4];
+ temp.bytes[4] = pp[3];
+ temp.bytes[5] = pp[2];
+ temp.bytes[6] = pp[1];
+ temp.bytes[7] = pp[0];
+
+ return temp.val;
+ }
+
+// ----------------------------------------------------------------------------------------
+ // Finalization mix - force all bits of a hash block to avalanche
+
+ DLIB_FORCE_INLINE uint32 murmur_fmix ( uint32 h )
+ {
+ h ^= h >> 16;
+ h *= 0x85ebca6b;
+ h ^= h >> 13;
+ h *= 0xc2b2ae35;
+ h ^= h >> 16;
+
+ return h;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ DLIB_FORCE_INLINE uint64 murmur_fmix ( uint64 k )
+ {
+ k ^= k >> 33;
+ k *= DLIB_BIG_CONSTANT(0xff51afd7ed558ccd);
+ k ^= k >> 33;
+ k *= DLIB_BIG_CONSTANT(0xc4ceb9fe1a85ec53);
+ k ^= k >> 33;
+
+ return k;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 murmur_hash3 (
+ const void * key,
+ const int len,
+ const uint32 seed = 0
+ )
+ {
+ const uint8 * data = (const uint8*)key;
+ const int nblocks = len / 4;
+
+ uint32 h1 = seed;
+
+ uint32 c1 = 0xcc9e2d51;
+ uint32 c2 = 0x1b873593;
+
+ //----------
+ // body
+
+ const uint32 * blocks = (const uint32 *)(data + nblocks*4);
+
+ bool is_little_endian = true;
+ uint32 endian_test = 1;
+ if (*reinterpret_cast<unsigned char*>(&endian_test) != 1)
+ is_little_endian = false;
+
+
+ if (is_little_endian)
+ {
+ for(int i = -nblocks; i; i++)
+ {
+ uint32 k1 = murmur_getblock(blocks,i);
+
+ k1 *= c1;
+ k1 = DLIB_ROTL32(k1,15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = DLIB_ROTL32(h1,13);
+ h1 = h1*5+0xe6546b64;
+ }
+ }
+ else
+ {
+ for(int i = -nblocks; i; i++)
+ {
+ uint32 k1 = murmur_getblock_byte_swap(blocks,i);
+
+ k1 *= c1;
+ k1 = DLIB_ROTL32(k1,15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = DLIB_ROTL32(h1,13);
+ h1 = h1*5+0xe6546b64;
+ }
+ }
+
+ //----------
+ // tail
+
+ const uint8 * tail = (const uint8*)(data + nblocks*4);
+
+ uint32 k1 = 0;
+
+ switch(len & 3)
+ {
+ case 3: k1 ^= tail[2] << 16;
+ // fall through
+ case 2: k1 ^= tail[1] << 8;
+ // fall through
+ case 1: k1 ^= tail[0];
+ k1 *= c1; k1 = DLIB_ROTL32(k1,15); k1 *= c2; h1 ^= k1;
+ };
+
+ //----------
+ // finalization
+
+ h1 ^= len;
+
+ h1 = murmur_fmix(h1);
+
+ return h1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 murmur_hash3_2 (
+ const uint32 v1,
+ const uint32 v2
+ )
+ {
+ uint32 h1 = v2;
+
+ uint32 c1 = 0xcc9e2d51;
+ uint32 c2 = 0x1b873593;
+
+ //----------
+ // body
+
+
+ uint32 k1 = v1;
+
+ k1 *= c1;
+ k1 = DLIB_ROTL32(k1,15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = DLIB_ROTL32(h1,13);
+ h1 = h1*5+0xe6546b64;
+
+
+ //----------
+ // finalization
+
+ h1 ^= 4; // =^ by length in bytes
+
+ h1 = murmur_fmix(h1);
+
+ return h1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 murmur_hash3_3 (
+ const uint32 v1,
+ const uint32 v2,
+ const uint32 v3
+ )
+ {
+
+ uint32 h1 = v3;
+
+ uint32 c1 = 0xcc9e2d51;
+ uint32 c2 = 0x1b873593;
+
+ //----------
+ // body
+
+
+ uint32 k1 = v1;
+
+ k1 *= c1;
+ k1 = DLIB_ROTL32(k1,15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = DLIB_ROTL32(h1,13);
+ h1 = h1*5+0xe6546b64;
+
+ k1 = v2;
+ k1 *= c1;
+ k1 = DLIB_ROTL32(k1,15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = DLIB_ROTL32(h1,13);
+ h1 = h1*5+0xe6546b64;
+
+ //----------
+ // finalization
+
+ h1 ^= 8; // =^ by length in bytes
+
+ h1 = murmur_fmix(h1);
+
+ return h1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::pair<uint64,uint64> murmur_hash3_128bit (
+ const void* key,
+ const int len,
+ const uint32 seed = 0
+ )
+ {
+ const uint8 * data = (const uint8*)key;
+ const int nblocks = len / 16;
+
+ uint64 h1 = seed;
+ uint64 h2 = seed;
+
+ uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5);
+ uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f);
+
+ //----------
+ // body
+
+ const uint64 * blocks = (const uint64 *)(data);
+
+ bool is_little_endian = true;
+ uint32 endian_test = 1;
+ if (*reinterpret_cast<unsigned char*>(&endian_test) != 1)
+ is_little_endian = false;
+
+
+ if (is_little_endian)
+ {
+ for(int i = 0; i < nblocks; i++)
+ {
+ uint64 k1 = murmur_getblock(blocks,i*2+0);
+ uint64 k2 = murmur_getblock(blocks,i*2+1);
+
+ k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1;
+
+ h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729;
+
+ k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2;
+
+ h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5;
+ }
+ }
+ else
+ {
+ for(int i = 0; i < nblocks; i++)
+ {
+ uint64 k1 = murmur_getblock_byte_swap(blocks,i*2+0);
+ uint64 k2 = murmur_getblock_byte_swap(blocks,i*2+1);
+
+ k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1;
+
+ h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729;
+
+ k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2;
+
+ h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5;
+ }
+ }
+
+ //----------
+ // tail
+
+ const uint8 * tail = (const uint8*)(data + nblocks*16);
+
+ uint64 k1 = 0;
+ uint64 k2 = 0;
+
+ switch(len & 15)
+ {
+ case 15: k2 ^= uint64(tail[14]) << 48; // fall through
+ case 14: k2 ^= uint64(tail[13]) << 40; // fall through
+ case 13: k2 ^= uint64(tail[12]) << 32; // fall through
+ case 12: k2 ^= uint64(tail[11]) << 24; // fall through
+ case 11: k2 ^= uint64(tail[10]) << 16; // fall through
+ case 10: k2 ^= uint64(tail[ 9]) << 8; // fall through
+ case 9: k2 ^= uint64(tail[ 8]) << 0;
+ k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2; // fall through
+
+ case 8: k1 ^= uint64(tail[ 7]) << 56; // fall through
+ case 7: k1 ^= uint64(tail[ 6]) << 48; // fall through
+ case 6: k1 ^= uint64(tail[ 5]) << 40; // fall through
+ case 5: k1 ^= uint64(tail[ 4]) << 32; // fall through
+ case 4: k1 ^= uint64(tail[ 3]) << 24; // fall through
+ case 3: k1 ^= uint64(tail[ 2]) << 16; // fall through
+ case 2: k1 ^= uint64(tail[ 1]) << 8; // fall through
+ case 1: k1 ^= uint64(tail[ 0]) << 0;
+ k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1; // fall through
+ };
+
+ //----------
+ // finalization
+
+ h1 ^= len; h2 ^= len;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = murmur_fmix(h1);
+ h2 = murmur_fmix(h2);
+
+ h1 += h2;
+ h2 += h1;
+
+ return std::make_pair(h1,h2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::pair<uint64,uint64> murmur_hash3_128bit (
+ const uint32& v1,
+ const uint32& v2,
+ const uint32& v3,
+ const uint32& v4
+ )
+ {
+ uint64 h1 = 0;
+ uint64 h2 = 0;
+
+ const uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5);
+ const uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f);
+
+ //----------
+ // body
+
+ uint64 k1 = (static_cast<uint64>(v2)<<32)|v1;
+ uint64 k2 = (static_cast<uint64>(v4)<<32)|v3;
+
+ k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2;
+
+ h1 = DLIB_ROTL64(k1,27); h1 = h1*5+0x52dce729;
+
+ k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1;
+
+ h2 = DLIB_ROTL64(k2,31); h2 += h1; h2 = h2*5+0x38495ab5;
+
+ //----------
+ // finalization
+
+ h1 ^= 16; h2 ^= 16;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = murmur_fmix(h1);
+ h2 = murmur_fmix(h2);
+
+ h1 += h2;
+ h2 += h1;
+
+ return std::make_pair(h1,h2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::pair<uint64,uint64> murmur_hash3_128bit_3 (
+ uint64 k1,
+ uint64 k2,
+ uint64 k3
+ )
+ {
+ uint64 h1 = k3;
+ uint64 h2 = k3;
+
+ const uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5);
+ const uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f);
+
+ //----------
+ // body
+
+ k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1;
+
+ h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729;
+
+ k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2;
+
+ h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5;
+
+ //----------
+ // finalization
+
+ h1 ^= 16; h2 ^= 16;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = murmur_fmix(h1);
+ h2 = murmur_fmix(h2);
+
+ h1 += h2;
+ h2 += h1;
+
+ return std::make_pair(h1,h2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MURMUR_HAsH_3_Hh_
+
diff --git a/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h b/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h
new file mode 100644
index 000000000..ef7bc0b41
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h
@@ -0,0 +1,125 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_
+#ifdef DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_
+
+#include "../uintn.h"
+#include <utility>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ uint32 murmur_hash3 (
+ const void* key,
+ const int len,
+ const uint32 seed = 0
+ );
+ /*!
+ requires
+ - key == a pointer to a block of memory len bytes long
+ ensures
+ - returns a 32bit hash of the len bytes pointed to by key.
+ - Each value of seed results in a different hash function being used.
+ (e.g. murmur_hash3(key,len,0) should generally not be equal to murmur_hash3(key,len,1))
+ - This function is machine architecture agnostic and should always give the same
+ hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 murmur_hash3_2 (
+ const uint32 v1,
+ const uint32 v2
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the two integers given to this function.
+ - This function is machine architecture agnostic and should always give the same
+ hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline uint32 murmur_hash3_3 (
+ const uint32 v1,
+ const uint32 v2,
+ const uint32 v3
+ );
+ /*!
+ ensures
+ - returns a 32bit hash of the three integers given to this function.
+ - This function is machine architecture agnostic and should always give the same
+ hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::pair<uint64,uint64> murmur_hash3_128bit (
+ const void* key,
+ const int len,
+ const uint32 seed = 0
+ );
+ /*!
+ requires
+ - key == a pointer to a block of memory len bytes long
+ ensures
+ - returns a 128bit hash (as two 64bit numbers) of the len bytes pointed to by key.
+ - Each value of seed results in a different hash function being used.
+ (e.g. murmur_hash3_128bit(key,len,0) should generally not be equal to
+ murmur_hash3_128bit(key,len,1))
+ - This function is machine architecture agnostic and should always give the same
+ hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::pair<uint64,uint64> murmur_hash3_128bit (
+ const uint32& v1,
+ const uint32& v2,
+ const uint32& v3,
+ const uint32& v4
+ );
+ /*!
+ ensures
+ - returns a 128bit hash (as two 64bit numbers) of the 4 integers given to this
+ function.
+ - This function is machine architecture agnostic and should always give the
+ same hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::pair<uint64,uint64> murmur_hash3_128bit_3 (
+ uint64 k1,
+ uint64 k2,
+ uint64 k3
+ );
+ /*!
+ ensures
+ - returns a 128bit hash (as two 64bit numbers) of the 3 integers given to this
+ function.
+ - This function is machine architecture agnostic and should always give the
+ same hash value when presented with the same inputs.
+ - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128.
+ See: http://code.google.com/p/smhasher/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/general_hash/random_hashing.h b/ml/dlib/dlib/general_hash/random_hashing.h
new file mode 100644
index 000000000..7a06a6878
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/random_hashing.h
@@ -0,0 +1,877 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANDOM_HAsHING_Hh_
+#define DLIB_RANDOM_HAsHING_Hh_
+
+#include "random_hashing_abstract.h"
+#include "murmur_hash3.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline double uniform_random_hash (
+ const uint64& k1,
+ const uint64& k2,
+ const uint64& k3
+ )
+ {
+ const std::pair<uint64,uint64> h = murmur_hash3_128bit_3(k1,k2,k3);
+ const uint64 mask = DLIB_BIG_CONSTANT(0xFFFFFFFFFF);
+ const double max = mask+1;
+ return static_cast<double>(h.first&mask)/max;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double gaussian_random_hash (
+ const uint64& k1,
+ const uint64& k2,
+ const uint64& k3
+ )
+ {
+ const std::pair<uint64,uint64> h = murmur_hash3_128bit_3(k1,k2,k3);
+
+ const static unsigned int max = 4096;
+
+ const static double logvals[max] = {
+ 4.079, 3.905, 3.8, 3.723, 3.663, 3.613, 3.57, 3.532, 3.499, 3.468,
+ 3.441, 3.416, 3.392, 3.37, 3.35, 3.33, 3.312, 3.295, 3.278, 3.263,
+ 3.248, 3.233, 3.219, 3.206, 3.193, 3.181, 3.169, 3.158, 3.147, 3.136,
+ 3.125, 3.115, 3.105, 3.096, 3.086, 3.077, 3.068, 3.059, 3.051, 3.043,
+ 3.035, 3.027, 3.019, 3.011, 3.004, 2.996, 2.989, 2.982, 2.975, 2.968,
+ 2.962, 2.955, 2.949, 2.942, 2.936, 2.93, 2.924, 2.918, 2.912, 2.906,
+ 2.901, 2.895, 2.89, 2.884, 2.879, 2.873, 2.868, 2.863, 2.858, 2.853,
+ 2.848, 2.843, 2.838, 2.833, 2.829, 2.824, 2.819, 2.815, 2.81, 2.806,
+ 2.801, 2.797, 2.792, 2.788, 2.784, 2.78, 2.776, 2.771, 2.767, 2.763,
+ 2.759, 2.755, 2.751, 2.748, 2.744, 2.74, 2.736, 2.732, 2.729, 2.725,
+ 2.721, 2.718, 2.714, 2.71, 2.707, 2.703, 2.7, 2.697, 2.693, 2.69,
+ 2.686, 2.683, 2.68, 2.676, 2.673, 2.67, 2.667, 2.663, 2.66, 2.657,
+ 2.654, 2.651, 2.648, 2.645, 2.642, 2.639, 2.636, 2.633, 2.63, 2.627,
+ 2.624, 2.621, 2.618, 2.615, 2.612, 2.61, 2.607, 2.604, 2.601, 2.599,
+ 2.596, 2.593, 2.59, 2.588, 2.585, 2.582, 2.58, 2.577, 2.574, 2.572,
+ 2.569, 2.567, 2.564, 2.562, 2.559, 2.557, 2.554, 2.552, 2.549, 2.547,
+ 2.544, 2.542, 2.539, 2.537, 2.534, 2.532, 2.53, 2.527, 2.525, 2.523,
+ 2.52, 2.518, 2.516, 2.513, 2.511, 2.509, 2.507, 2.504, 2.502, 2.5,
+ 2.498, 2.495, 2.493, 2.491, 2.489, 2.487, 2.485, 2.482, 2.48, 2.478,
+ 2.476, 2.474, 2.472, 2.47, 2.468, 2.466, 2.464, 2.462, 2.459, 2.457,
+ 2.455, 2.453, 2.451, 2.449, 2.447, 2.445, 2.443, 2.441, 2.439, 2.437,
+ 2.436, 2.434, 2.432, 2.43, 2.428, 2.426, 2.424, 2.422, 2.42, 2.418,
+ 2.416, 2.415, 2.413, 2.411, 2.409, 2.407, 2.405, 2.404, 2.402, 2.4,
+ 2.398, 2.396, 2.394, 2.393, 2.391, 2.389, 2.387, 2.386, 2.384, 2.382,
+ 2.38, 2.379, 2.377, 2.375, 2.373, 2.372, 2.37, 2.368, 2.367, 2.365,
+ 2.363, 2.361, 2.36, 2.358, 2.356, 2.355, 2.353, 2.352, 2.35, 2.348,
+ 2.347, 2.345, 2.343, 2.342, 2.34, 2.338, 2.337, 2.335, 2.334, 2.332,
+ 2.331, 2.329, 2.327, 2.326, 2.324, 2.323, 2.321, 2.32, 2.318, 2.316,
+ 2.315, 2.313, 2.312, 2.31, 2.309, 2.307, 2.306, 2.304, 2.303, 2.301,
+ 2.3, 2.298, 2.297, 2.295, 2.294, 2.292, 2.291, 2.289, 2.288, 2.286,
+ 2.285, 2.284, 2.282, 2.281, 2.279, 2.278, 2.276, 2.275, 2.274, 2.272,
+ 2.271, 2.269, 2.268, 2.266, 2.265, 2.264, 2.262, 2.261, 2.259, 2.258,
+ 2.257, 2.255, 2.254, 2.253, 2.251, 2.25, 2.248, 2.247, 2.246, 2.244,
+ 2.243, 2.242, 2.24, 2.239, 2.238, 2.236, 2.235, 2.234, 2.232, 2.231,
+ 2.23, 2.228, 2.227, 2.226, 2.225, 2.223, 2.222, 2.221, 2.219, 2.218,
+ 2.217, 2.215, 2.214, 2.213, 2.212, 2.21, 2.209, 2.208, 2.207, 2.205,
+ 2.204, 2.203, 2.202, 2.2, 2.199, 2.198, 2.197, 2.195, 2.194, 2.193,
+ 2.192, 2.19, 2.189, 2.188, 2.187, 2.185, 2.184, 2.183, 2.182, 2.181,
+ 2.179, 2.178, 2.177, 2.176, 2.175, 2.173, 2.172, 2.171, 2.17, 2.169,
+ 2.168, 2.166, 2.165, 2.164, 2.163, 2.162, 2.16, 2.159, 2.158, 2.157,
+ 2.156, 2.155, 2.154, 2.152, 2.151, 2.15, 2.149, 2.148, 2.147, 2.146,
+ 2.144, 2.143, 2.142, 2.141, 2.14, 2.139, 2.138, 2.136, 2.135, 2.134,
+ 2.133, 2.132, 2.131, 2.13, 2.129, 2.128, 2.126, 2.125, 2.124, 2.123,
+ 2.122, 2.121, 2.12, 2.119, 2.118, 2.117, 2.116, 2.114, 2.113, 2.112,
+ 2.111, 2.11, 2.109, 2.108, 2.107, 2.106, 2.105, 2.104, 2.103, 2.102,
+ 2.101, 2.1, 2.099, 2.097, 2.096, 2.095, 2.094, 2.093, 2.092, 2.091,
+ 2.09, 2.089, 2.088, 2.087, 2.086, 2.085, 2.084, 2.083, 2.082, 2.081,
+ 2.08, 2.079, 2.078, 2.077, 2.076, 2.075, 2.074, 2.073, 2.072, 2.071,
+ 2.07, 2.069, 2.068, 2.067, 2.066, 2.065, 2.064, 2.063, 2.062, 2.061,
+ 2.06, 2.059, 2.058, 2.057, 2.056, 2.055, 2.054, 2.053, 2.052, 2.051,
+ 2.05, 2.049, 2.048, 2.047, 2.046, 2.045, 2.044, 2.043, 2.042, 2.041,
+ 2.04, 2.039, 2.038, 2.037, 2.036, 2.036, 2.035, 2.034, 2.033, 2.032,
+ 2.031, 2.03, 2.029, 2.028, 2.027, 2.026, 2.025, 2.024, 2.023, 2.022,
+ 2.021, 2.02, 2.02, 2.019, 2.018, 2.017, 2.016, 2.015, 2.014, 2.013,
+ 2.012, 2.011, 2.01, 2.009, 2.008, 2.008, 2.007, 2.006, 2.005, 2.004,
+ 2.003, 2.002, 2.001, 2, 1.999, 1.998, 1.998, 1.997, 1.996, 1.995,
+ 1.994, 1.993, 1.992, 1.991, 1.99, 1.99, 1.989, 1.988, 1.987, 1.986,
+ 1.985, 1.984, 1.983, 1.982, 1.982, 1.981, 1.98, 1.979, 1.978, 1.977,
+ 1.976, 1.975, 1.975, 1.974, 1.973, 1.972, 1.971, 1.97, 1.969, 1.969,
+ 1.968, 1.967, 1.966, 1.965, 1.964, 1.963, 1.963, 1.962, 1.961, 1.96,
+ 1.959, 1.958, 1.957, 1.957, 1.956, 1.955, 1.954, 1.953, 1.952, 1.952,
+ 1.951, 1.95, 1.949, 1.948, 1.947, 1.947, 1.946, 1.945, 1.944, 1.943,
+ 1.942, 1.942, 1.941, 1.94, 1.939, 1.938, 1.937, 1.937, 1.936, 1.935,
+ 1.934, 1.933, 1.933, 1.932, 1.931, 1.93, 1.929, 1.928, 1.928, 1.927,
+ 1.926, 1.925, 1.924, 1.924, 1.923, 1.922, 1.921, 1.92, 1.92, 1.919,
+ 1.918, 1.917, 1.916, 1.916, 1.915, 1.914, 1.913, 1.912, 1.912, 1.911,
+ 1.91, 1.909, 1.908, 1.908, 1.907, 1.906, 1.905, 1.904, 1.904, 1.903,
+ 1.902, 1.901, 1.901, 1.9, 1.899, 1.898, 1.897, 1.897, 1.896, 1.895,
+ 1.894, 1.894, 1.893, 1.892, 1.891, 1.89, 1.89, 1.889, 1.888, 1.887,
+ 1.887, 1.886, 1.885, 1.884, 1.884, 1.883, 1.882, 1.881, 1.88, 1.88,
+ 1.879, 1.878, 1.877, 1.877, 1.876, 1.875, 1.874, 1.874, 1.873, 1.872,
+ 1.871, 1.871, 1.87, 1.869, 1.868, 1.868, 1.867, 1.866, 1.865, 1.865,
+ 1.864, 1.863, 1.862, 1.862, 1.861, 1.86, 1.859, 1.859, 1.858, 1.857,
+ 1.857, 1.856, 1.855, 1.854, 1.854, 1.853, 1.852, 1.851, 1.851, 1.85,
+ 1.849, 1.848, 1.848, 1.847, 1.846, 1.846, 1.845, 1.844, 1.843, 1.843,
+ 1.842, 1.841, 1.84, 1.84, 1.839, 1.838, 1.838, 1.837, 1.836, 1.835,
+ 1.835, 1.834, 1.833, 1.833, 1.832, 1.831, 1.83, 1.83, 1.829, 1.828,
+ 1.828, 1.827, 1.826, 1.825, 1.825, 1.824, 1.823, 1.823, 1.822, 1.821,
+ 1.821, 1.82, 1.819, 1.818, 1.818, 1.817, 1.816, 1.816, 1.815, 1.814,
+ 1.814, 1.813, 1.812, 1.811, 1.811, 1.81, 1.809, 1.809, 1.808, 1.807,
+ 1.807, 1.806, 1.805, 1.805, 1.804, 1.803, 1.802, 1.802, 1.801, 1.8,
+ 1.8, 1.799, 1.798, 1.798, 1.797, 1.796, 1.796, 1.795, 1.794, 1.794,
+ 1.793, 1.792, 1.792, 1.791, 1.79, 1.79, 1.789, 1.788, 1.787, 1.787,
+ 1.786, 1.785, 1.785, 1.784, 1.783, 1.783, 1.782, 1.781, 1.781, 1.78,
+ 1.779, 1.779, 1.778, 1.777, 1.777, 1.776, 1.775, 1.775, 1.774, 1.773,
+ 1.773, 1.772, 1.771, 1.771, 1.77, 1.769, 1.769, 1.768, 1.767, 1.767,
+ 1.766, 1.766, 1.765, 1.764, 1.764, 1.763, 1.762, 1.762, 1.761, 1.76,
+ 1.76, 1.759, 1.758, 1.758, 1.757, 1.756, 1.756, 1.755, 1.754, 1.754,
+ 1.753, 1.752, 1.752, 1.751, 1.751, 1.75, 1.749, 1.749, 1.748, 1.747,
+ 1.747, 1.746, 1.745, 1.745, 1.744, 1.743, 1.743, 1.742, 1.742, 1.741,
+ 1.74, 1.74, 1.739, 1.738, 1.738, 1.737, 1.736, 1.736, 1.735, 1.735,
+ 1.734, 1.733, 1.733, 1.732, 1.731, 1.731, 1.73, 1.729, 1.729, 1.728,
+ 1.728, 1.727, 1.726, 1.726, 1.725, 1.724, 1.724, 1.723, 1.723, 1.722,
+ 1.721, 1.721, 1.72, 1.719, 1.719, 1.718, 1.718, 1.717, 1.716, 1.716,
+ 1.715, 1.715, 1.714, 1.713, 1.713, 1.712, 1.711, 1.711, 1.71, 1.71,
+ 1.709, 1.708, 1.708, 1.707, 1.706, 1.706, 1.705, 1.705, 1.704, 1.703,
+ 1.703, 1.702, 1.702, 1.701, 1.7, 1.7, 1.699, 1.699, 1.698, 1.697,
+ 1.697, 1.696, 1.696, 1.695, 1.694, 1.694, 1.693, 1.692, 1.692, 1.691,
+ 1.691, 1.69, 1.689, 1.689, 1.688, 1.688, 1.687, 1.686, 1.686, 1.685,
+ 1.685, 1.684, 1.683, 1.683, 1.682, 1.682, 1.681, 1.68, 1.68, 1.679,
+ 1.679, 1.678, 1.678, 1.677, 1.676, 1.676, 1.675, 1.675, 1.674, 1.673,
+ 1.673, 1.672, 1.672, 1.671, 1.67, 1.67, 1.669, 1.669, 1.668, 1.667,
+ 1.667, 1.666, 1.666, 1.665, 1.665, 1.664, 1.663, 1.663, 1.662, 1.662,
+ 1.661, 1.66, 1.66, 1.659, 1.659, 1.658, 1.658, 1.657, 1.656, 1.656,
+ 1.655, 1.655, 1.654, 1.653, 1.653, 1.652, 1.652, 1.651, 1.651, 1.65,
+ 1.649, 1.649, 1.648, 1.648, 1.647, 1.647, 1.646, 1.645, 1.645, 1.644,
+ 1.644, 1.643, 1.643, 1.642, 1.641, 1.641, 1.64, 1.64, 1.639, 1.639,
+ 1.638, 1.637, 1.637, 1.636, 1.636, 1.635, 1.635, 1.634, 1.633, 1.633,
+ 1.632, 1.632, 1.631, 1.631, 1.63, 1.629, 1.629, 1.628, 1.628, 1.627,
+ 1.627, 1.626, 1.625, 1.625, 1.624, 1.624, 1.623, 1.623, 1.622, 1.622,
+ 1.621, 1.62, 1.62, 1.619, 1.619, 1.618, 1.618, 1.617, 1.617, 1.616,
+ 1.615, 1.615, 1.614, 1.614, 1.613, 1.613, 1.612, 1.612, 1.611, 1.61,
+ 1.61, 1.609, 1.609, 1.608, 1.608, 1.607, 1.607, 1.606, 1.605, 1.605,
+ 1.604, 1.604, 1.603, 1.603, 1.602, 1.602, 1.601, 1.6, 1.6, 1.599,
+ 1.599, 1.598, 1.598, 1.597, 1.597, 1.596, 1.596, 1.595, 1.594, 1.594,
+ 1.593, 1.593, 1.592, 1.592, 1.591, 1.591, 1.59, 1.59, 1.589, 1.588,
+ 1.588, 1.587, 1.587, 1.586, 1.586, 1.585, 1.585, 1.584, 1.584, 1.583,
+ 1.582, 1.582, 1.581, 1.581, 1.58, 1.58, 1.579, 1.579, 1.578, 1.578,
+ 1.577, 1.577, 1.576, 1.576, 1.575, 1.574, 1.574, 1.573, 1.573, 1.572,
+ 1.572, 1.571, 1.571, 1.57, 1.57, 1.569, 1.569, 1.568, 1.567, 1.567,
+ 1.566, 1.566, 1.565, 1.565, 1.564, 1.564, 1.563, 1.563, 1.562, 1.562,
+ 1.561, 1.561, 1.56, 1.56, 1.559, 1.558, 1.558, 1.557, 1.557, 1.556,
+ 1.556, 1.555, 1.555, 1.554, 1.554, 1.553, 1.553, 1.552, 1.552, 1.551,
+ 1.551, 1.55, 1.55, 1.549, 1.549, 1.548, 1.547, 1.547, 1.546, 1.546,
+ 1.545, 1.545, 1.544, 1.544, 1.543, 1.543, 1.542, 1.542, 1.541, 1.541,
+ 1.54, 1.54, 1.539, 1.539, 1.538, 1.538, 1.537, 1.537, 1.536, 1.536,
+ 1.535, 1.534, 1.534, 1.533, 1.533, 1.532, 1.532, 1.531, 1.531, 1.53,
+ 1.53, 1.529, 1.529, 1.528, 1.528, 1.527, 1.527, 1.526, 1.526, 1.525,
+ 1.525, 1.524, 1.524, 1.523, 1.523, 1.522, 1.522, 1.521, 1.521, 1.52,
+ 1.52, 1.519, 1.519, 1.518, 1.518, 1.517, 1.517, 1.516, 1.516, 1.515,
+ 1.515, 1.514, 1.514, 1.513, 1.512, 1.512, 1.511, 1.511, 1.51, 1.51,
+ 1.509, 1.509, 1.508, 1.508, 1.507, 1.507, 1.506, 1.506, 1.505, 1.505,
+ 1.504, 1.504, 1.503, 1.503, 1.502, 1.502, 1.501, 1.501, 1.5, 1.5,
+ 1.499, 1.499, 1.498, 1.498, 1.497, 1.497, 1.496, 1.496, 1.495, 1.495,
+ 1.494, 1.494, 1.493, 1.493, 1.492, 1.492, 1.491, 1.491, 1.49, 1.49,
+ 1.489, 1.489, 1.488, 1.488, 1.487, 1.487, 1.486, 1.486, 1.485, 1.485,
+ 1.484, 1.484, 1.483, 1.483, 1.482, 1.482, 1.481, 1.481, 1.48, 1.48,
+ 1.48, 1.479, 1.479, 1.478, 1.478, 1.477, 1.477, 1.476, 1.476, 1.475,
+ 1.475, 1.474, 1.474, 1.473, 1.473, 1.472, 1.472, 1.471, 1.471, 1.47,
+ 1.47, 1.469, 1.469, 1.468, 1.468, 1.467, 1.467, 1.466, 1.466, 1.465,
+ 1.465, 1.464, 1.464, 1.463, 1.463, 1.462, 1.462, 1.461, 1.461, 1.46,
+ 1.46, 1.459, 1.459, 1.458, 1.458, 1.458, 1.457, 1.457, 1.456, 1.456,
+ 1.455, 1.455, 1.454, 1.454, 1.453, 1.453, 1.452, 1.452, 1.451, 1.451,
+ 1.45, 1.45, 1.449, 1.449, 1.448, 1.448, 1.447, 1.447, 1.446, 1.446,
+ 1.445, 1.445, 1.444, 1.444, 1.444, 1.443, 1.443, 1.442, 1.442, 1.441,
+ 1.441, 1.44, 1.44, 1.439, 1.439, 1.438, 1.438, 1.437, 1.437, 1.436,
+ 1.436, 1.435, 1.435, 1.434, 1.434, 1.434, 1.433, 1.433, 1.432, 1.432,
+ 1.431, 1.431, 1.43, 1.43, 1.429, 1.429, 1.428, 1.428, 1.427, 1.427,
+ 1.426, 1.426, 1.425, 1.425, 1.424, 1.424, 1.424, 1.423, 1.423, 1.422,
+ 1.422, 1.421, 1.421, 1.42, 1.42, 1.419, 1.419, 1.418, 1.418, 1.417,
+ 1.417, 1.416, 1.416, 1.416, 1.415, 1.415, 1.414, 1.414, 1.413, 1.413,
+ 1.412, 1.412, 1.411, 1.411, 1.41, 1.41, 1.409, 1.409, 1.409, 1.408,
+ 1.408, 1.407, 1.407, 1.406, 1.406, 1.405, 1.405, 1.404, 1.404, 1.403,
+ 1.403, 1.402, 1.402, 1.402, 1.401, 1.401, 1.4, 1.4, 1.399, 1.399,
+ 1.398, 1.398, 1.397, 1.397, 1.396, 1.396, 1.395, 1.395, 1.395, 1.394,
+ 1.394, 1.393, 1.393, 1.392, 1.392, 1.391, 1.391, 1.39, 1.39, 1.389,
+ 1.389, 1.389, 1.388, 1.388, 1.387, 1.387, 1.386, 1.386, 1.385, 1.385,
+ 1.384, 1.384, 1.383, 1.383, 1.383, 1.382, 1.382, 1.381, 1.381, 1.38,
+ 1.38, 1.379, 1.379, 1.378, 1.378, 1.378, 1.377, 1.377, 1.376, 1.376,
+ 1.375, 1.375, 1.374, 1.374, 1.373, 1.373, 1.373, 1.372, 1.372, 1.371,
+ 1.371, 1.37, 1.37, 1.369, 1.369, 1.368, 1.368, 1.367, 1.367, 1.367,
+ 1.366, 1.366, 1.365, 1.365, 1.364, 1.364, 1.363, 1.363, 1.362, 1.362,
+ 1.362, 1.361, 1.361, 1.36, 1.36, 1.359, 1.359, 1.358, 1.358, 1.358,
+ 1.357, 1.357, 1.356, 1.356, 1.355, 1.355, 1.354, 1.354, 1.353, 1.353,
+ 1.353, 1.352, 1.352, 1.351, 1.351, 1.35, 1.35, 1.349, 1.349, 1.349,
+ 1.348, 1.348, 1.347, 1.347, 1.346, 1.346, 1.345, 1.345, 1.344, 1.344,
+ 1.344, 1.343, 1.343, 1.342, 1.342, 1.341, 1.341, 1.34, 1.34, 1.34,
+ 1.339, 1.339, 1.338, 1.338, 1.337, 1.337, 1.336, 1.336, 1.336, 1.335,
+ 1.335, 1.334, 1.334, 1.333, 1.333, 1.332, 1.332, 1.332, 1.331, 1.331,
+ 1.33, 1.33, 1.329, 1.329, 1.328, 1.328, 1.328, 1.327, 1.327, 1.326,
+ 1.326, 1.325, 1.325, 1.324, 1.324, 1.324, 1.323, 1.323, 1.322, 1.322,
+ 1.321, 1.321, 1.32, 1.32, 1.32, 1.319, 1.319, 1.318, 1.318, 1.317,
+ 1.317, 1.316, 1.316, 1.316, 1.315, 1.315, 1.314, 1.314, 1.313, 1.313,
+ 1.312, 1.312, 1.312, 1.311, 1.311, 1.31, 1.31, 1.309, 1.309, 1.309,
+ 1.308, 1.308, 1.307, 1.307, 1.306, 1.306, 1.305, 1.305, 1.305, 1.304,
+ 1.304, 1.303, 1.303, 1.302, 1.302, 1.302, 1.301, 1.301, 1.3, 1.3,
+ 1.299, 1.299, 1.298, 1.298, 1.298, 1.297, 1.297, 1.296, 1.296, 1.295,
+ 1.295, 1.295, 1.294, 1.294, 1.293, 1.293, 1.292, 1.292, 1.291, 1.291,
+ 1.291, 1.29, 1.29, 1.289, 1.289, 1.288, 1.288, 1.288, 1.287, 1.287,
+ 1.286, 1.286, 1.285, 1.285, 1.285, 1.284, 1.284, 1.283, 1.283, 1.282,
+ 1.282, 1.281, 1.281, 1.281, 1.28, 1.28, 1.279, 1.279, 1.278, 1.278,
+ 1.278, 1.277, 1.277, 1.276, 1.276, 1.275, 1.275, 1.275, 1.274, 1.274,
+ 1.273, 1.273, 1.272, 1.272, 1.272, 1.271, 1.271, 1.27, 1.27, 1.269,
+ 1.269, 1.269, 1.268, 1.268, 1.267, 1.267, 1.266, 1.266, 1.266, 1.265,
+ 1.265, 1.264, 1.264, 1.263, 1.263, 1.263, 1.262, 1.262, 1.261, 1.261,
+ 1.26, 1.26, 1.26, 1.259, 1.259, 1.258, 1.258, 1.257, 1.257, 1.257,
+ 1.256, 1.256, 1.255, 1.255, 1.254, 1.254, 1.254, 1.253, 1.253, 1.252,
+ 1.252, 1.251, 1.251, 1.251, 1.25, 1.25, 1.249, 1.249, 1.248, 1.248,
+ 1.248, 1.247, 1.247, 1.246, 1.246, 1.245, 1.245, 1.245, 1.244, 1.244,
+ 1.243, 1.243, 1.242, 1.242, 1.242, 1.241, 1.241, 1.24, 1.24, 1.239,
+ 1.239, 1.239, 1.238, 1.238, 1.237, 1.237, 1.237, 1.236, 1.236, 1.235,
+ 1.235, 1.234, 1.234, 1.234, 1.233, 1.233, 1.232, 1.232, 1.231, 1.231,
+ 1.231, 1.23, 1.23, 1.229, 1.229, 1.228, 1.228, 1.228, 1.227, 1.227,
+ 1.226, 1.226, 1.226, 1.225, 1.225, 1.224, 1.224, 1.223, 1.223, 1.223,
+ 1.222, 1.222, 1.221, 1.221, 1.22, 1.22, 1.22, 1.219, 1.219, 1.218,
+ 1.218, 1.218, 1.217, 1.217, 1.216, 1.216, 1.215, 1.215, 1.215, 1.214,
+ 1.214, 1.213, 1.213, 1.212, 1.212, 1.212, 1.211, 1.211, 1.21, 1.21,
+ 1.21, 1.209, 1.209, 1.208, 1.208, 1.207, 1.207, 1.207, 1.206, 1.206,
+ 1.205, 1.205, 1.204, 1.204, 1.204, 1.203, 1.203, 1.202, 1.202, 1.202,
+ 1.201, 1.201, 1.2, 1.2, 1.199, 1.199, 1.199, 1.198, 1.198, 1.197,
+ 1.197, 1.197, 1.196, 1.196, 1.195, 1.195, 1.194, 1.194, 1.194, 1.193,
+ 1.193, 1.192, 1.192, 1.192, 1.191, 1.191, 1.19, 1.19, 1.189, 1.189,
+ 1.189, 1.188, 1.188, 1.187, 1.187, 1.187, 1.186, 1.186, 1.185, 1.185,
+ 1.184, 1.184, 1.184, 1.183, 1.183, 1.182, 1.182, 1.182, 1.181, 1.181,
+ 1.18, 1.18, 1.179, 1.179, 1.179, 1.178, 1.178, 1.177, 1.177, 1.177,
+ 1.176, 1.176, 1.175, 1.175, 1.175, 1.174, 1.174, 1.173, 1.173, 1.172,
+ 1.172, 1.172, 1.171, 1.171, 1.17, 1.17, 1.17, 1.169, 1.169, 1.168,
+ 1.168, 1.167, 1.167, 1.167, 1.166, 1.166, 1.165, 1.165, 1.165, 1.164,
+ 1.164, 1.163, 1.163, 1.163, 1.162, 1.162, 1.161, 1.161, 1.16, 1.16,
+ 1.16, 1.159, 1.159, 1.158, 1.158, 1.158, 1.157, 1.157, 1.156, 1.156,
+ 1.156, 1.155, 1.155, 1.154, 1.154, 1.153, 1.153, 1.153, 1.152, 1.152,
+ 1.151, 1.151, 1.151, 1.15, 1.15, 1.149, 1.149, 1.149, 1.148, 1.148,
+ 1.147, 1.147, 1.146, 1.146, 1.146, 1.145, 1.145, 1.144, 1.144, 1.144,
+ 1.143, 1.143, 1.142, 1.142, 1.142, 1.141, 1.141, 1.14, 1.14, 1.139,
+ 1.139, 1.139, 1.138, 1.138, 1.137, 1.137, 1.137, 1.136, 1.136, 1.135,
+ 1.135, 1.135, 1.134, 1.134, 1.133, 1.133, 1.133, 1.132, 1.132, 1.131,
+ 1.131, 1.13, 1.13, 1.13, 1.129, 1.129, 1.128, 1.128, 1.128, 1.127,
+ 1.127, 1.126, 1.126, 1.126, 1.125, 1.125, 1.124, 1.124, 1.124, 1.123,
+ 1.123, 1.122, 1.122, 1.121, 1.121, 1.121, 1.12, 1.12, 1.119, 1.119,
+ 1.119, 1.118, 1.118, 1.117, 1.117, 1.117, 1.116, 1.116, 1.115, 1.115,
+ 1.115, 1.114, 1.114, 1.113, 1.113, 1.113, 1.112, 1.112, 1.111, 1.111,
+ 1.11, 1.11, 1.11, 1.109, 1.109, 1.108, 1.108, 1.108, 1.107, 1.107,
+ 1.106, 1.106, 1.106, 1.105, 1.105, 1.104, 1.104, 1.104, 1.103, 1.103,
+ 1.102, 1.102, 1.102, 1.101, 1.101, 1.1, 1.1, 1.099, 1.099, 1.099,
+ 1.098, 1.098, 1.097, 1.097, 1.097, 1.096, 1.096, 1.095, 1.095, 1.095,
+ 1.094, 1.094, 1.093, 1.093, 1.093, 1.092, 1.092, 1.091, 1.091, 1.091,
+ 1.09, 1.09, 1.089, 1.089, 1.089, 1.088, 1.088, 1.087, 1.087, 1.086,
+ 1.086, 1.086, 1.085, 1.085, 1.084, 1.084, 1.084, 1.083, 1.083, 1.082,
+ 1.082, 1.082, 1.081, 1.081, 1.08, 1.08, 1.08, 1.079, 1.079, 1.078,
+ 1.078, 1.078, 1.077, 1.077, 1.076, 1.076, 1.076, 1.075, 1.075, 1.074,
+ 1.074, 1.074, 1.073, 1.073, 1.072, 1.072, 1.072, 1.071, 1.071, 1.07,
+ 1.07, 1.069, 1.069, 1.069, 1.068, 1.068, 1.067, 1.067, 1.067, 1.066,
+ 1.066, 1.065, 1.065, 1.065, 1.064, 1.064, 1.063, 1.063, 1.063, 1.062,
+ 1.062, 1.061, 1.061, 1.061, 1.06, 1.06, 1.059, 1.059, 1.059, 1.058,
+ 1.058, 1.057, 1.057, 1.057, 1.056, 1.056, 1.055, 1.055, 1.055, 1.054,
+ 1.054, 1.053, 1.053, 1.053, 1.052, 1.052, 1.051, 1.051, 1.05, 1.05,
+ 1.05, 1.049, 1.049, 1.048, 1.048, 1.048, 1.047, 1.047, 1.046, 1.046,
+ 1.046, 1.045, 1.045, 1.044, 1.044, 1.044, 1.043, 1.043, 1.042, 1.042,
+ 1.042, 1.041, 1.041, 1.04, 1.04, 1.04, 1.039, 1.039, 1.038, 1.038,
+ 1.038, 1.037, 1.037, 1.036, 1.036, 1.036, 1.035, 1.035, 1.034, 1.034,
+ 1.034, 1.033, 1.033, 1.032, 1.032, 1.032, 1.031, 1.031, 1.03, 1.03,
+ 1.03, 1.029, 1.029, 1.028, 1.028, 1.028, 1.027, 1.027, 1.026, 1.026,
+ 1.026, 1.025, 1.025, 1.024, 1.024, 1.023, 1.023, 1.023, 1.022, 1.022,
+ 1.021, 1.021, 1.021, 1.02, 1.02, 1.019, 1.019, 1.019, 1.018, 1.018,
+ 1.017, 1.017, 1.017, 1.016, 1.016, 1.015, 1.015, 1.015, 1.014, 1.014,
+ 1.013, 1.013, 1.013, 1.012, 1.012, 1.011, 1.011, 1.011, 1.01, 1.01,
+ 1.009, 1.009, 1.009, 1.008, 1.008, 1.007, 1.007, 1.007, 1.006, 1.006,
+ 1.005, 1.005, 1.005, 1.004, 1.004, 1.003, 1.003, 1.003, 1.002, 1.002,
+ 1.001, 1.001, 1.001, 1, 0.9997, 0.9993, 0.9989, 0.9985, 0.9981, 0.9977,
+ 0.9973, 0.9969, 0.9965, 0.9961, 0.9957, 0.9953, 0.9949, 0.9945, 0.9941, 0.9937,
+ 0.9933, 0.9929, 0.9925, 0.9921, 0.9917, 0.9913, 0.9909, 0.9905, 0.9901, 0.9897,
+ 0.9893, 0.9889, 0.9885, 0.9881, 0.9877, 0.9873, 0.9869, 0.9865, 0.9861, 0.9856,
+ 0.9852, 0.9848, 0.9844, 0.984, 0.9836, 0.9832, 0.9828, 0.9824, 0.982, 0.9816,
+ 0.9812, 0.9808, 0.9804, 0.98, 0.9796, 0.9792, 0.9788, 0.9784, 0.978, 0.9776,
+ 0.9772, 0.9768, 0.9764, 0.976, 0.9756, 0.9752, 0.9748, 0.9744, 0.974, 0.9736,
+ 0.9732, 0.9728, 0.9724, 0.972, 0.9716, 0.9712, 0.9707, 0.9703, 0.9699, 0.9695,
+ 0.9691, 0.9687, 0.9683, 0.9679, 0.9675, 0.9671, 0.9667, 0.9663, 0.9659, 0.9655,
+ 0.9651, 0.9647, 0.9643, 0.9639, 0.9635, 0.9631, 0.9627, 0.9623, 0.9619, 0.9615,
+ 0.9611, 0.9607, 0.9603, 0.9599, 0.9595, 0.9591, 0.9587, 0.9583, 0.9579, 0.9574,
+ 0.957, 0.9566, 0.9562, 0.9558, 0.9554, 0.955, 0.9546, 0.9542, 0.9538, 0.9534,
+ 0.953, 0.9526, 0.9522, 0.9518, 0.9514, 0.951, 0.9506, 0.9502, 0.9498, 0.9494,
+ 0.949, 0.9486, 0.9482, 0.9478, 0.9474, 0.947, 0.9466, 0.9462, 0.9457, 0.9453,
+ 0.9449, 0.9445, 0.9441, 0.9437, 0.9433, 0.9429, 0.9425, 0.9421, 0.9417, 0.9413,
+ 0.9409, 0.9405, 0.9401, 0.9397, 0.9393, 0.9389, 0.9385, 0.9381, 0.9377, 0.9373,
+ 0.9369, 0.9365, 0.9361, 0.9356, 0.9352, 0.9348, 0.9344, 0.934, 0.9336, 0.9332,
+ 0.9328, 0.9324, 0.932, 0.9316, 0.9312, 0.9308, 0.9304, 0.93, 0.9296, 0.9292,
+ 0.9288, 0.9284, 0.928, 0.9276, 0.9272, 0.9267, 0.9263, 0.9259, 0.9255, 0.9251,
+ 0.9247, 0.9243, 0.9239, 0.9235, 0.9231, 0.9227, 0.9223, 0.9219, 0.9215, 0.9211,
+ 0.9207, 0.9203, 0.9199, 0.9195, 0.9191, 0.9186, 0.9182, 0.9178, 0.9174, 0.917,
+ 0.9166, 0.9162, 0.9158, 0.9154, 0.915, 0.9146, 0.9142, 0.9138, 0.9134, 0.913,
+ 0.9126, 0.9122, 0.9118, 0.9113, 0.9109, 0.9105, 0.9101, 0.9097, 0.9093, 0.9089,
+ 0.9085, 0.9081, 0.9077, 0.9073, 0.9069, 0.9065, 0.9061, 0.9057, 0.9053, 0.9049,
+ 0.9044, 0.904, 0.9036, 0.9032, 0.9028, 0.9024, 0.902, 0.9016, 0.9012, 0.9008,
+ 0.9004, 0.9, 0.8996, 0.8992, 0.8988, 0.8983, 0.8979, 0.8975, 0.8971, 0.8967,
+ 0.8963, 0.8959, 0.8955, 0.8951, 0.8947, 0.8943, 0.8939, 0.8935, 0.8931, 0.8926,
+ 0.8922, 0.8918, 0.8914, 0.891, 0.8906, 0.8902, 0.8898, 0.8894, 0.889, 0.8886,
+ 0.8882, 0.8878, 0.8873, 0.8869, 0.8865, 0.8861, 0.8857, 0.8853, 0.8849, 0.8845,
+ 0.8841, 0.8837, 0.8833, 0.8829, 0.8825, 0.882, 0.8816, 0.8812, 0.8808, 0.8804,
+ 0.88, 0.8796, 0.8792, 0.8788, 0.8784, 0.878, 0.8775, 0.8771, 0.8767, 0.8763,
+ 0.8759, 0.8755, 0.8751, 0.8747, 0.8743, 0.8739, 0.8735, 0.873, 0.8726, 0.8722,
+ 0.8718, 0.8714, 0.871, 0.8706, 0.8702, 0.8698, 0.8694, 0.869, 0.8685, 0.8681,
+ 0.8677, 0.8673, 0.8669, 0.8665, 0.8661, 0.8657, 0.8653, 0.8649, 0.8644, 0.864,
+ 0.8636, 0.8632, 0.8628, 0.8624, 0.862, 0.8616, 0.8612, 0.8607, 0.8603, 0.8599,
+ 0.8595, 0.8591, 0.8587, 0.8583, 0.8579, 0.8575, 0.857, 0.8566, 0.8562, 0.8558,
+ 0.8554, 0.855, 0.8546, 0.8542, 0.8538, 0.8533, 0.8529, 0.8525, 0.8521, 0.8517,
+ 0.8513, 0.8509, 0.8505, 0.85, 0.8496, 0.8492, 0.8488, 0.8484, 0.848, 0.8476,
+ 0.8472, 0.8467, 0.8463, 0.8459, 0.8455, 0.8451, 0.8447, 0.8443, 0.8439, 0.8434,
+ 0.843, 0.8426, 0.8422, 0.8418, 0.8414, 0.841, 0.8406, 0.8401, 0.8397, 0.8393,
+ 0.8389, 0.8385, 0.8381, 0.8377, 0.8372, 0.8368, 0.8364, 0.836, 0.8356, 0.8352,
+ 0.8348, 0.8343, 0.8339, 0.8335, 0.8331, 0.8327, 0.8323, 0.8319, 0.8314, 0.831,
+ 0.8306, 0.8302, 0.8298, 0.8294, 0.8289, 0.8285, 0.8281, 0.8277, 0.8273, 0.8269,
+ 0.8265, 0.826, 0.8256, 0.8252, 0.8248, 0.8244, 0.824, 0.8235, 0.8231, 0.8227,
+ 0.8223, 0.8219, 0.8215, 0.821, 0.8206, 0.8202, 0.8198, 0.8194, 0.819, 0.8185,
+ 0.8181, 0.8177, 0.8173, 0.8169, 0.8165, 0.816, 0.8156, 0.8152, 0.8148, 0.8144,
+ 0.814, 0.8135, 0.8131, 0.8127, 0.8123, 0.8119, 0.8114, 0.811, 0.8106, 0.8102,
+ 0.8098, 0.8094, 0.8089, 0.8085, 0.8081, 0.8077, 0.8073, 0.8068, 0.8064, 0.806,
+ 0.8056, 0.8052, 0.8047, 0.8043, 0.8039, 0.8035, 0.8031, 0.8026, 0.8022, 0.8018,
+ 0.8014, 0.801, 0.8005, 0.8001, 0.7997, 0.7993, 0.7989, 0.7984, 0.798, 0.7976,
+ 0.7972, 0.7968, 0.7963, 0.7959, 0.7955, 0.7951, 0.7947, 0.7942, 0.7938, 0.7934,
+ 0.793, 0.7926, 0.7921, 0.7917, 0.7913, 0.7909, 0.7904, 0.79, 0.7896, 0.7892,
+ 0.7888, 0.7883, 0.7879, 0.7875, 0.7871, 0.7866, 0.7862, 0.7858, 0.7854, 0.7849,
+ 0.7845, 0.7841, 0.7837, 0.7833, 0.7828, 0.7824, 0.782, 0.7816, 0.7811, 0.7807,
+ 0.7803, 0.7799, 0.7794, 0.779, 0.7786, 0.7782, 0.7777, 0.7773, 0.7769, 0.7765,
+ 0.776, 0.7756, 0.7752, 0.7748, 0.7743, 0.7739, 0.7735, 0.7731, 0.7726, 0.7722,
+ 0.7718, 0.7714, 0.7709, 0.7705, 0.7701, 0.7697, 0.7692, 0.7688, 0.7684, 0.7679,
+ 0.7675, 0.7671, 0.7667, 0.7662, 0.7658, 0.7654, 0.765, 0.7645, 0.7641, 0.7637,
+ 0.7632, 0.7628, 0.7624, 0.762, 0.7615, 0.7611, 0.7607, 0.7602, 0.7598, 0.7594,
+ 0.759, 0.7585, 0.7581, 0.7577, 0.7572, 0.7568, 0.7564, 0.756, 0.7555, 0.7551,
+ 0.7547, 0.7542, 0.7538, 0.7534, 0.7529, 0.7525, 0.7521, 0.7516, 0.7512, 0.7508,
+ 0.7504, 0.7499, 0.7495, 0.7491, 0.7486, 0.7482, 0.7478, 0.7473, 0.7469, 0.7465,
+ 0.746, 0.7456, 0.7452, 0.7447, 0.7443, 0.7439, 0.7434, 0.743, 0.7426, 0.7421,
+ 0.7417, 0.7413, 0.7408, 0.7404, 0.74, 0.7395, 0.7391, 0.7387, 0.7382, 0.7378,
+ 0.7374, 0.7369, 0.7365, 0.7361, 0.7356, 0.7352, 0.7348, 0.7343, 0.7339, 0.7335,
+ 0.733, 0.7326, 0.7321, 0.7317, 0.7313, 0.7308, 0.7304, 0.73, 0.7295, 0.7291,
+ 0.7287, 0.7282, 0.7278, 0.7273, 0.7269, 0.7265, 0.726, 0.7256, 0.7252, 0.7247,
+ 0.7243, 0.7238, 0.7234, 0.723, 0.7225, 0.7221, 0.7216, 0.7212, 0.7208, 0.7203,
+ 0.7199, 0.7195, 0.719, 0.7186, 0.7181, 0.7177, 0.7173, 0.7168, 0.7164, 0.7159,
+ 0.7155, 0.7151, 0.7146, 0.7142, 0.7137, 0.7133, 0.7128, 0.7124, 0.712, 0.7115,
+ 0.7111, 0.7106, 0.7102, 0.7098, 0.7093, 0.7089, 0.7084, 0.708, 0.7075, 0.7071,
+ 0.7066, 0.7062, 0.7058, 0.7053, 0.7049, 0.7044, 0.704, 0.7035, 0.7031, 0.7027,
+ 0.7022, 0.7018, 0.7013, 0.7009, 0.7004, 0.7, 0.6995, 0.6991, 0.6986, 0.6982,
+ 0.6978, 0.6973, 0.6969, 0.6964, 0.696, 0.6955, 0.6951, 0.6946, 0.6942, 0.6937,
+ 0.6933, 0.6928, 0.6924, 0.6919, 0.6915, 0.691, 0.6906, 0.6901, 0.6897, 0.6892,
+ 0.6888, 0.6883, 0.6879, 0.6874, 0.687, 0.6865, 0.6861, 0.6856, 0.6852, 0.6847,
+ 0.6843, 0.6838, 0.6834, 0.6829, 0.6825, 0.682, 0.6816, 0.6811, 0.6807, 0.6802,
+ 0.6798, 0.6793, 0.6789, 0.6784, 0.678, 0.6775, 0.6771, 0.6766, 0.6762, 0.6757,
+ 0.6752, 0.6748, 0.6743, 0.6739, 0.6734, 0.673, 0.6725, 0.6721, 0.6716, 0.6711,
+ 0.6707, 0.6702, 0.6698, 0.6693, 0.6689, 0.6684, 0.668, 0.6675, 0.667, 0.6666,
+ 0.6661, 0.6657, 0.6652, 0.6648, 0.6643, 0.6638, 0.6634, 0.6629, 0.6625, 0.662,
+ 0.6615, 0.6611, 0.6606, 0.6602, 0.6597, 0.6592, 0.6588, 0.6583, 0.6579, 0.6574,
+ 0.6569, 0.6565, 0.656, 0.6556, 0.6551, 0.6546, 0.6542, 0.6537, 0.6532, 0.6528,
+ 0.6523, 0.6519, 0.6514, 0.6509, 0.6505, 0.65, 0.6495, 0.6491, 0.6486, 0.6481,
+ 0.6477, 0.6472, 0.6468, 0.6463, 0.6458, 0.6454, 0.6449, 0.6444, 0.644, 0.6435,
+ 0.643, 0.6426, 0.6421, 0.6416, 0.6412, 0.6407, 0.6402, 0.6397, 0.6393, 0.6388,
+ 0.6383, 0.6379, 0.6374, 0.6369, 0.6365, 0.636, 0.6355, 0.6351, 0.6346, 0.6341,
+ 0.6336, 0.6332, 0.6327, 0.6322, 0.6318, 0.6313, 0.6308, 0.6303, 0.6299, 0.6294,
+ 0.6289, 0.6285, 0.628, 0.6275, 0.627, 0.6266, 0.6261, 0.6256, 0.6251, 0.6247,
+ 0.6242, 0.6237, 0.6232, 0.6228, 0.6223, 0.6218, 0.6213, 0.6208, 0.6204, 0.6199,
+ 0.6194, 0.6189, 0.6185, 0.618, 0.6175, 0.617, 0.6165, 0.6161, 0.6156, 0.6151,
+ 0.6146, 0.6142, 0.6137, 0.6132, 0.6127, 0.6122, 0.6117, 0.6113, 0.6108, 0.6103,
+ 0.6098, 0.6093, 0.6089, 0.6084, 0.6079, 0.6074, 0.6069, 0.6064, 0.606, 0.6055,
+ 0.605, 0.6045, 0.604, 0.6035, 0.603, 0.6026, 0.6021, 0.6016, 0.6011, 0.6006,
+ 0.6001, 0.5996, 0.5992, 0.5987, 0.5982, 0.5977, 0.5972, 0.5967, 0.5962, 0.5957,
+ 0.5952, 0.5948, 0.5943, 0.5938, 0.5933, 0.5928, 0.5923, 0.5918, 0.5913, 0.5908,
+ 0.5903, 0.5898, 0.5894, 0.5889, 0.5884, 0.5879, 0.5874, 0.5869, 0.5864, 0.5859,
+ 0.5854, 0.5849, 0.5844, 0.5839, 0.5834, 0.5829, 0.5824, 0.5819, 0.5814, 0.5809,
+ 0.5804, 0.5799, 0.5794, 0.5789, 0.5784, 0.5779, 0.5774, 0.5769, 0.5764, 0.5759,
+ 0.5754, 0.5749, 0.5744, 0.5739, 0.5734, 0.5729, 0.5724, 0.5719, 0.5714, 0.5709,
+ 0.5704, 0.5699, 0.5694, 0.5689, 0.5684, 0.5679, 0.5674, 0.5669, 0.5664, 0.5659,
+ 0.5654, 0.5649, 0.5644, 0.5639, 0.5633, 0.5628, 0.5623, 0.5618, 0.5613, 0.5608,
+ 0.5603, 0.5598, 0.5593, 0.5588, 0.5582, 0.5577, 0.5572, 0.5567, 0.5562, 0.5557,
+ 0.5552, 0.5547, 0.5541, 0.5536, 0.5531, 0.5526, 0.5521, 0.5516, 0.5511, 0.5505,
+ 0.55, 0.5495, 0.549, 0.5485, 0.548, 0.5474, 0.5469, 0.5464, 0.5459, 0.5454,
+ 0.5448, 0.5443, 0.5438, 0.5433, 0.5428, 0.5422, 0.5417, 0.5412, 0.5407, 0.5402,
+ 0.5396, 0.5391, 0.5386, 0.5381, 0.5375, 0.537, 0.5365, 0.536, 0.5354, 0.5349,
+ 0.5344, 0.5339, 0.5333, 0.5328, 0.5323, 0.5317, 0.5312, 0.5307, 0.5302, 0.5296,
+ 0.5291, 0.5286, 0.528, 0.5275, 0.527, 0.5264, 0.5259, 0.5254, 0.5248, 0.5243,
+ 0.5238, 0.5232, 0.5227, 0.5222, 0.5216, 0.5211, 0.5206, 0.52, 0.5195, 0.5189,
+ 0.5184, 0.5179, 0.5173, 0.5168, 0.5162, 0.5157, 0.5152, 0.5146, 0.5141, 0.5135,
+ 0.513, 0.5124, 0.5119, 0.5114, 0.5108, 0.5103, 0.5097, 0.5092, 0.5086, 0.5081,
+ 0.5075, 0.507, 0.5064, 0.5059, 0.5053, 0.5048, 0.5043, 0.5037, 0.5032, 0.5026,
+ 0.502, 0.5015, 0.5009, 0.5004, 0.4998, 0.4993, 0.4987, 0.4982, 0.4976, 0.4971,
+ 0.4965, 0.496, 0.4954, 0.4948, 0.4943, 0.4937, 0.4932, 0.4926, 0.492, 0.4915,
+ 0.4909, 0.4904, 0.4898, 0.4892, 0.4887, 0.4881, 0.4875, 0.487, 0.4864, 0.4859,
+ 0.4853, 0.4847, 0.4842, 0.4836, 0.483, 0.4825, 0.4819, 0.4813, 0.4807, 0.4802,
+ 0.4796, 0.479, 0.4785, 0.4779, 0.4773, 0.4767, 0.4762, 0.4756, 0.475, 0.4744,
+ 0.4739, 0.4733, 0.4727, 0.4721, 0.4716, 0.471, 0.4704, 0.4698, 0.4692, 0.4687,
+ 0.4681, 0.4675, 0.4669, 0.4663, 0.4657, 0.4652, 0.4646, 0.464, 0.4634, 0.4628,
+ 0.4622, 0.4616, 0.461, 0.4605, 0.4599, 0.4593, 0.4587, 0.4581, 0.4575, 0.4569,
+ 0.4563, 0.4557, 0.4551, 0.4545, 0.4539, 0.4533, 0.4527, 0.4521, 0.4515, 0.451,
+ 0.4504, 0.4498, 0.4491, 0.4485, 0.4479, 0.4473, 0.4467, 0.4461, 0.4455, 0.4449,
+ 0.4443, 0.4437, 0.4431, 0.4425, 0.4419, 0.4413, 0.4407, 0.4401, 0.4394, 0.4388,
+ 0.4382, 0.4376, 0.437, 0.4364, 0.4358, 0.4351, 0.4345, 0.4339, 0.4333, 0.4327,
+ 0.4321, 0.4314, 0.4308, 0.4302, 0.4296, 0.4289, 0.4283, 0.4277, 0.4271, 0.4264,
+ 0.4258, 0.4252, 0.4246, 0.4239, 0.4233, 0.4227, 0.422, 0.4214, 0.4208, 0.4201,
+ 0.4195, 0.4189, 0.4182, 0.4176, 0.4169, 0.4163, 0.4157, 0.415, 0.4144, 0.4137,
+ 0.4131, 0.4125, 0.4118, 0.4112, 0.4105, 0.4099, 0.4092, 0.4086, 0.4079, 0.4073,
+ 0.4066, 0.406, 0.4053, 0.4047, 0.404, 0.4034, 0.4027, 0.402, 0.4014, 0.4007,
+ 0.4001, 0.3994, 0.3987, 0.3981, 0.3974, 0.3967, 0.3961, 0.3954, 0.3947, 0.3941,
+ 0.3934, 0.3927, 0.3921, 0.3914, 0.3907, 0.39, 0.3894, 0.3887, 0.388, 0.3873,
+ 0.3866, 0.386, 0.3853, 0.3846, 0.3839, 0.3832, 0.3825, 0.3819, 0.3812, 0.3805,
+ 0.3798, 0.3791, 0.3784, 0.3777, 0.377, 0.3763, 0.3756, 0.3749, 0.3742, 0.3735,
+ 0.3728, 0.3721, 0.3714, 0.3707, 0.37, 0.3693, 0.3686, 0.3679, 0.3672, 0.3665,
+ 0.3657, 0.365, 0.3643, 0.3636, 0.3629, 0.3622, 0.3614, 0.3607, 0.36, 0.3593,
+ 0.3585, 0.3578, 0.3571, 0.3564, 0.3556, 0.3549, 0.3542, 0.3534, 0.3527, 0.352,
+ 0.3512, 0.3505, 0.3497, 0.349, 0.3483, 0.3475, 0.3468, 0.346, 0.3453, 0.3445,
+ 0.3438, 0.343, 0.3422, 0.3415, 0.3407, 0.34, 0.3392, 0.3384, 0.3377, 0.3369,
+ 0.3361, 0.3354, 0.3346, 0.3338, 0.3331, 0.3323, 0.3315, 0.3307, 0.3299, 0.3292,
+ 0.3284, 0.3276, 0.3268, 0.326, 0.3252, 0.3244, 0.3236, 0.3228, 0.3221, 0.3213,
+ 0.3205, 0.3196, 0.3188, 0.318, 0.3172, 0.3164, 0.3156, 0.3148, 0.314, 0.3132,
+ 0.3123, 0.3115, 0.3107, 0.3099, 0.309, 0.3082, 0.3074, 0.3065, 0.3057, 0.3049,
+ 0.304, 0.3032, 0.3023, 0.3015, 0.3007, 0.2998, 0.2989, 0.2981, 0.2972, 0.2964,
+ 0.2955, 0.2946, 0.2938, 0.2929, 0.292, 0.2912, 0.2903, 0.2894, 0.2885, 0.2877,
+ 0.2868, 0.2859, 0.285, 0.2841, 0.2832, 0.2823, 0.2814, 0.2805, 0.2796, 0.2787,
+ 0.2778, 0.2768, 0.2759, 0.275, 0.2741, 0.2732, 0.2722, 0.2713, 0.2704, 0.2694,
+ 0.2685, 0.2675, 0.2666, 0.2656, 0.2647, 0.2637, 0.2628, 0.2618, 0.2608, 0.2599,
+ 0.2589, 0.2579, 0.2569, 0.256, 0.255, 0.254, 0.253, 0.252, 0.251, 0.25,
+ 0.249, 0.248, 0.2469, 0.2459, 0.2449, 0.2439, 0.2428, 0.2418, 0.2408, 0.2397,
+ 0.2387, 0.2376, 0.2365, 0.2355, 0.2344, 0.2333, 0.2323, 0.2312, 0.2301, 0.229,
+ 0.2279, 0.2268, 0.2257, 0.2246, 0.2235, 0.2223, 0.2212, 0.2201, 0.2189, 0.2178,
+ 0.2166, 0.2155, 0.2143, 0.2132, 0.212, 0.2108, 0.2096, 0.2084, 0.2072, 0.206,
+ 0.2048, 0.2036, 0.2023, 0.2011, 0.1999, 0.1986, 0.1974, 0.1961, 0.1948, 0.1935,
+ 0.1923, 0.191, 0.1896, 0.1883, 0.187, 0.1857, 0.1843, 0.183, 0.1816, 0.1802,
+ 0.1789, 0.1775, 0.1761, 0.1747, 0.1732, 0.1718, 0.1703, 0.1689, 0.1674, 0.1659,
+ 0.1644, 0.1629, 0.1614, 0.1599, 0.1583, 0.1567, 0.1551, 0.1535, 0.1519, 0.1503,
+ 0.1486, 0.147, 0.1453, 0.1436, 0.1418, 0.1401, 0.1383, 0.1365, 0.1347, 0.1329,
+ 0.131, 0.1291, 0.1272, 0.1252, 0.1233, 0.1213, 0.1192, 0.1171, 0.115, 0.1129,
+ 0.1107, 0.1084, 0.1061, 0.1038, 0.1014, 0.09894, 0.09643, 0.09385, 0.0912, 0.08847,
+ 0.08566, 0.08275, 0.07974, 0.0766, 0.07334, 0.06992, 0.06633, 0.06253, 0.05849, 0.05415,
+ 0.04943, 0.0442, 0.03828, 0.03125, 0.0221, -0};
+
+ const static double cosvals[max] = {
+ 1, 1, 1, 1, 1, 1, 0.9999, 0.9999, 0.9999, 0.9999,
+ 0.9999, 0.9998, 0.9998, 0.9998, 0.9997, 0.9997, 0.9997, 0.9996, 0.9996, 0.9995,
+ 0.9995, 0.9994, 0.9994, 0.9993, 0.9993, 0.9992, 0.9991, 0.9991, 0.999, 0.9989,
+ 0.9989, 0.9988, 0.9987, 0.9986, 0.9986, 0.9985, 0.9984, 0.9983, 0.9982, 0.9981,
+ 0.998, 0.9979, 0.9978, 0.9977, 0.9976, 0.9975, 0.9974, 0.9973, 0.9972, 0.9971,
+ 0.9969, 0.9968, 0.9967, 0.9966, 0.9964, 0.9963, 0.9962, 0.996, 0.9959, 0.9958,
+ 0.9956, 0.9955, 0.9953, 0.9952, 0.995, 0.9949, 0.9947, 0.9946, 0.9944, 0.9942,
+ 0.9941, 0.9939, 0.9937, 0.9936, 0.9934, 0.9932, 0.993, 0.9929, 0.9927, 0.9925,
+ 0.9923, 0.9921, 0.9919, 0.9917, 0.9915, 0.9913, 0.9911, 0.9909, 0.9907, 0.9905,
+ 0.9903, 0.9901, 0.9898, 0.9896, 0.9894, 0.9892, 0.989, 0.9887, 0.9885, 0.9883,
+ 0.988, 0.9878, 0.9875, 0.9873, 0.9871, 0.9868, 0.9866, 0.9863, 0.9861, 0.9858,
+ 0.9855, 0.9853, 0.985, 0.9847, 0.9845, 0.9842, 0.9839, 0.9837, 0.9834, 0.9831,
+ 0.9828, 0.9825, 0.9823, 0.982, 0.9817, 0.9814, 0.9811, 0.9808, 0.9805, 0.9802,
+ 0.9799, 0.9796, 0.9793, 0.9789, 0.9786, 0.9783, 0.978, 0.9777, 0.9774, 0.977,
+ 0.9767, 0.9764, 0.976, 0.9757, 0.9754, 0.975, 0.9747, 0.9743, 0.974, 0.9736,
+ 0.9733, 0.9729, 0.9726, 0.9722, 0.9719, 0.9715, 0.9711, 0.9708, 0.9704, 0.97,
+ 0.9697, 0.9693, 0.9689, 0.9685, 0.9681, 0.9678, 0.9674, 0.967, 0.9666, 0.9662,
+ 0.9658, 0.9654, 0.965, 0.9646, 0.9642, 0.9638, 0.9634, 0.963, 0.9625, 0.9621,
+ 0.9617, 0.9613, 0.9609, 0.9604, 0.96, 0.9596, 0.9591, 0.9587, 0.9583, 0.9578,
+ 0.9574, 0.9569, 0.9565, 0.956, 0.9556, 0.9551, 0.9547, 0.9542, 0.9538, 0.9533,
+ 0.9528, 0.9524, 0.9519, 0.9514, 0.951, 0.9505, 0.95, 0.9495, 0.949, 0.9486,
+ 0.9481, 0.9476, 0.9471, 0.9466, 0.9461, 0.9456, 0.9451, 0.9446, 0.9441, 0.9436,
+ 0.9431, 0.9426, 0.9421, 0.9415, 0.941, 0.9405, 0.94, 0.9395, 0.9389, 0.9384,
+ 0.9379, 0.9373, 0.9368, 0.9363, 0.9357, 0.9352, 0.9346, 0.9341, 0.9335, 0.933,
+ 0.9324, 0.9319, 0.9313, 0.9308, 0.9302, 0.9296, 0.9291, 0.9285, 0.9279, 0.9274,
+ 0.9268, 0.9262, 0.9256, 0.925, 0.9245, 0.9239, 0.9233, 0.9227, 0.9221, 0.9215,
+ 0.9209, 0.9203, 0.9197, 0.9191, 0.9185, 0.9179, 0.9173, 0.9167, 0.9161, 0.9154,
+ 0.9148, 0.9142, 0.9136, 0.913, 0.9123, 0.9117, 0.9111, 0.9104, 0.9098, 0.9092,
+ 0.9085, 0.9079, 0.9072, 0.9066, 0.9059, 0.9053, 0.9046, 0.904, 0.9033, 0.9027,
+ 0.902, 0.9013, 0.9007, 0.9, 0.8993, 0.8987, 0.898, 0.8973, 0.8966, 0.896,
+ 0.8953, 0.8946, 0.8939, 0.8932, 0.8925, 0.8918, 0.8911, 0.8904, 0.8897, 0.889,
+ 0.8883, 0.8876, 0.8869, 0.8862, 0.8855, 0.8848, 0.8841, 0.8834, 0.8826, 0.8819,
+ 0.8812, 0.8805, 0.8797, 0.879, 0.8783, 0.8775, 0.8768, 0.8761, 0.8753, 0.8746,
+ 0.8738, 0.8731, 0.8723, 0.8716, 0.8708, 0.8701, 0.8693, 0.8686, 0.8678, 0.867,
+ 0.8663, 0.8655, 0.8647, 0.864, 0.8632, 0.8624, 0.8616, 0.8609, 0.8601, 0.8593,
+ 0.8585, 0.8577, 0.8569, 0.8561, 0.8554, 0.8546, 0.8538, 0.853, 0.8522, 0.8514,
+ 0.8505, 0.8497, 0.8489, 0.8481, 0.8473, 0.8465, 0.8457, 0.8449, 0.844, 0.8432,
+ 0.8424, 0.8416, 0.8407, 0.8399, 0.8391, 0.8382, 0.8374, 0.8365, 0.8357, 0.8349,
+ 0.834, 0.8332, 0.8323, 0.8315, 0.8306, 0.8298, 0.8289, 0.828, 0.8272, 0.8263,
+ 0.8255, 0.8246, 0.8237, 0.8228, 0.822, 0.8211, 0.8202, 0.8193, 0.8185, 0.8176,
+ 0.8167, 0.8158, 0.8149, 0.814, 0.8131, 0.8123, 0.8114, 0.8105, 0.8096, 0.8087,
+ 0.8078, 0.8068, 0.8059, 0.805, 0.8041, 0.8032, 0.8023, 0.8014, 0.8005, 0.7995,
+ 0.7986, 0.7977, 0.7968, 0.7958, 0.7949, 0.794, 0.793, 0.7921, 0.7912, 0.7902,
+ 0.7893, 0.7883, 0.7874, 0.7865, 0.7855, 0.7846, 0.7836, 0.7827, 0.7817, 0.7807,
+ 0.7798, 0.7788, 0.7779, 0.7769, 0.7759, 0.775, 0.774, 0.773, 0.772, 0.7711,
+ 0.7701, 0.7691, 0.7681, 0.7671, 0.7662, 0.7652, 0.7642, 0.7632, 0.7622, 0.7612,
+ 0.7602, 0.7592, 0.7582, 0.7572, 0.7562, 0.7552, 0.7542, 0.7532, 0.7522, 0.7512,
+ 0.7502, 0.7491, 0.7481, 0.7471, 0.7461, 0.7451, 0.744, 0.743, 0.742, 0.741,
+ 0.7399, 0.7389, 0.7379, 0.7368, 0.7358, 0.7347, 0.7337, 0.7327, 0.7316, 0.7306,
+ 0.7295, 0.7285, 0.7274, 0.7264, 0.7253, 0.7242, 0.7232, 0.7221, 0.7211, 0.72,
+ 0.7189, 0.7179, 0.7168, 0.7157, 0.7147, 0.7136, 0.7125, 0.7114, 0.7104, 0.7093,
+ 0.7082, 0.7071, 0.706, 0.7049, 0.7038, 0.7028, 0.7017, 0.7006, 0.6995, 0.6984,
+ 0.6973, 0.6962, 0.6951, 0.694, 0.6929, 0.6918, 0.6907, 0.6895, 0.6884, 0.6873,
+ 0.6862, 0.6851, 0.684, 0.6828, 0.6817, 0.6806, 0.6795, 0.6784, 0.6772, 0.6761,
+ 0.675, 0.6738, 0.6727, 0.6716, 0.6704, 0.6693, 0.6681, 0.667, 0.6659, 0.6647,
+ 0.6636, 0.6624, 0.6613, 0.6601, 0.659, 0.6578, 0.6567, 0.6555, 0.6543, 0.6532,
+ 0.652, 0.6508, 0.6497, 0.6485, 0.6473, 0.6462, 0.645, 0.6438, 0.6427, 0.6415,
+ 0.6403, 0.6391, 0.6379, 0.6368, 0.6356, 0.6344, 0.6332, 0.632, 0.6308, 0.6296,
+ 0.6284, 0.6273, 0.6261, 0.6249, 0.6237, 0.6225, 0.6213, 0.6201, 0.6189, 0.6176,
+ 0.6164, 0.6152, 0.614, 0.6128, 0.6116, 0.6104, 0.6092, 0.6079, 0.6067, 0.6055,
+ 0.6043, 0.6031, 0.6018, 0.6006, 0.5994, 0.5982, 0.5969, 0.5957, 0.5945, 0.5932,
+ 0.592, 0.5908, 0.5895, 0.5883, 0.587, 0.5858, 0.5846, 0.5833, 0.5821, 0.5808,
+ 0.5796, 0.5783, 0.5771, 0.5758, 0.5746, 0.5733, 0.572, 0.5708, 0.5695, 0.5683,
+ 0.567, 0.5657, 0.5645, 0.5632, 0.5619, 0.5607, 0.5594, 0.5581, 0.5568, 0.5556,
+ 0.5543, 0.553, 0.5517, 0.5505, 0.5492, 0.5479, 0.5466, 0.5453, 0.544, 0.5428,
+ 0.5415, 0.5402, 0.5389, 0.5376, 0.5363, 0.535, 0.5337, 0.5324, 0.5311, 0.5298,
+ 0.5285, 0.5272, 0.5259, 0.5246, 0.5233, 0.522, 0.5207, 0.5194, 0.518, 0.5167,
+ 0.5154, 0.5141, 0.5128, 0.5115, 0.5102, 0.5088, 0.5075, 0.5062, 0.5049, 0.5035,
+ 0.5022, 0.5009, 0.4996, 0.4982, 0.4969, 0.4956, 0.4942, 0.4929, 0.4916, 0.4902,
+ 0.4889, 0.4876, 0.4862, 0.4849, 0.4835, 0.4822, 0.4808, 0.4795, 0.4781, 0.4768,
+ 0.4755, 0.4741, 0.4727, 0.4714, 0.47, 0.4687, 0.4673, 0.466, 0.4646, 0.4633,
+ 0.4619, 0.4605, 0.4592, 0.4578, 0.4564, 0.4551, 0.4537, 0.4523, 0.451, 0.4496,
+ 0.4482, 0.4469, 0.4455, 0.4441, 0.4427, 0.4414, 0.44, 0.4386, 0.4372, 0.4359,
+ 0.4345, 0.4331, 0.4317, 0.4303, 0.4289, 0.4276, 0.4262, 0.4248, 0.4234, 0.422,
+ 0.4206, 0.4192, 0.4178, 0.4164, 0.415, 0.4136, 0.4122, 0.4108, 0.4094, 0.408,
+ 0.4066, 0.4052, 0.4038, 0.4024, 0.401, 0.3996, 0.3982, 0.3968, 0.3954, 0.394,
+ 0.3926, 0.3912, 0.3898, 0.3883, 0.3869, 0.3855, 0.3841, 0.3827, 0.3813, 0.3798,
+ 0.3784, 0.377, 0.3756, 0.3742, 0.3727, 0.3713, 0.3699, 0.3685, 0.367, 0.3656,
+ 0.3642, 0.3628, 0.3613, 0.3599, 0.3585, 0.357, 0.3556, 0.3542, 0.3527, 0.3513,
+ 0.3499, 0.3484, 0.347, 0.3455, 0.3441, 0.3427, 0.3412, 0.3398, 0.3383, 0.3369,
+ 0.3354, 0.334, 0.3326, 0.3311, 0.3297, 0.3282, 0.3268, 0.3253, 0.3239, 0.3224,
+ 0.321, 0.3195, 0.318, 0.3166, 0.3151, 0.3137, 0.3122, 0.3108, 0.3093, 0.3078,
+ 0.3064, 0.3049, 0.3035, 0.302, 0.3005, 0.2991, 0.2976, 0.2962, 0.2947, 0.2932,
+ 0.2918, 0.2903, 0.2888, 0.2873, 0.2859, 0.2844, 0.2829, 0.2815, 0.28, 0.2785,
+ 0.277, 0.2756, 0.2741, 0.2726, 0.2711, 0.2697, 0.2682, 0.2667, 0.2652, 0.2638,
+ 0.2623, 0.2608, 0.2593, 0.2578, 0.2563, 0.2549, 0.2534, 0.2519, 0.2504, 0.2489,
+ 0.2474, 0.246, 0.2445, 0.243, 0.2415, 0.24, 0.2385, 0.237, 0.2355, 0.234,
+ 0.2326, 0.2311, 0.2296, 0.2281, 0.2266, 0.2251, 0.2236, 0.2221, 0.2206, 0.2191,
+ 0.2176, 0.2161, 0.2146, 0.2131, 0.2116, 0.2101, 0.2086, 0.2071, 0.2056, 0.2041,
+ 0.2026, 0.2011, 0.1996, 0.1981, 0.1966, 0.1951, 0.1936, 0.1921, 0.1906, 0.1891,
+ 0.1876, 0.1861, 0.1845, 0.183, 0.1815, 0.18, 0.1785, 0.177, 0.1755, 0.174,
+ 0.1725, 0.171, 0.1695, 0.1679, 0.1664, 0.1649, 0.1634, 0.1619, 0.1604, 0.1589,
+ 0.1573, 0.1558, 0.1543, 0.1528, 0.1513, 0.1498, 0.1482, 0.1467, 0.1452, 0.1437,
+ 0.1422, 0.1407, 0.1391, 0.1376, 0.1361, 0.1346, 0.1331, 0.1315, 0.13, 0.1285,
+ 0.127, 0.1255, 0.1239, 0.1224, 0.1209, 0.1194, 0.1178, 0.1163, 0.1148, 0.1133,
+ 0.1117, 0.1102, 0.1087, 0.1072, 0.1056, 0.1041, 0.1026, 0.1011, 0.09954, 0.09802,
+ 0.09649, 0.09496, 0.09344, 0.09191, 0.09038, 0.08885, 0.08733, 0.0858, 0.08427, 0.08274,
+ 0.08121, 0.07968, 0.07815, 0.07662, 0.07509, 0.07356, 0.07203, 0.0705, 0.06897, 0.06744,
+ 0.06591, 0.06438, 0.06285, 0.06132, 0.05979, 0.05826, 0.05673, 0.0552, 0.05366, 0.05213,
+ 0.0506, 0.04907, 0.04754, 0.046, 0.04447, 0.04294, 0.04141, 0.03987, 0.03834, 0.03681,
+ 0.03527, 0.03374, 0.03221, 0.03067, 0.02914, 0.02761, 0.02607, 0.02454, 0.02301, 0.02147,
+ 0.01994, 0.01841, 0.01687, 0.01534, 0.01381, 0.01227, 0.01074, 0.009204, 0.00767, 0.006136,
+ 0.004602, 0.003068, 0.001534, 6.123e-17, -0.001534, -0.003068, -0.004602, -0.006136, -0.00767, -0.009204,
+ -0.01074, -0.01227, -0.01381, -0.01534, -0.01687, -0.01841, -0.01994, -0.02147, -0.02301, -0.02454,
+ -0.02607, -0.02761, -0.02914, -0.03067, -0.03221, -0.03374, -0.03527, -0.03681, -0.03834, -0.03987,
+ -0.04141, -0.04294, -0.04447, -0.046, -0.04754, -0.04907, -0.0506, -0.05213, -0.05366, -0.0552,
+ -0.05673, -0.05826, -0.05979, -0.06132, -0.06285, -0.06438, -0.06591, -0.06744, -0.06897, -0.0705,
+ -0.07203, -0.07356, -0.07509, -0.07662, -0.07815, -0.07968, -0.08121, -0.08274, -0.08427, -0.0858,
+ -0.08733, -0.08885, -0.09038, -0.09191, -0.09344, -0.09496, -0.09649, -0.09802, -0.09954, -0.1011,
+ -0.1026, -0.1041, -0.1056, -0.1072, -0.1087, -0.1102, -0.1117, -0.1133, -0.1148, -0.1163,
+ -0.1178, -0.1194, -0.1209, -0.1224, -0.1239, -0.1255, -0.127, -0.1285, -0.13, -0.1315,
+ -0.1331, -0.1346, -0.1361, -0.1376, -0.1391, -0.1407, -0.1422, -0.1437, -0.1452, -0.1467,
+ -0.1482, -0.1498, -0.1513, -0.1528, -0.1543, -0.1558, -0.1573, -0.1589, -0.1604, -0.1619,
+ -0.1634, -0.1649, -0.1664, -0.1679, -0.1695, -0.171, -0.1725, -0.174, -0.1755, -0.177,
+ -0.1785, -0.18, -0.1815, -0.183, -0.1845, -0.1861, -0.1876, -0.1891, -0.1906, -0.1921,
+ -0.1936, -0.1951, -0.1966, -0.1981, -0.1996, -0.2011, -0.2026, -0.2041, -0.2056, -0.2071,
+ -0.2086, -0.2101, -0.2116, -0.2131, -0.2146, -0.2161, -0.2176, -0.2191, -0.2206, -0.2221,
+ -0.2236, -0.2251, -0.2266, -0.2281, -0.2296, -0.2311, -0.2326, -0.234, -0.2355, -0.237,
+ -0.2385, -0.24, -0.2415, -0.243, -0.2445, -0.246, -0.2474, -0.2489, -0.2504, -0.2519,
+ -0.2534, -0.2549, -0.2563, -0.2578, -0.2593, -0.2608, -0.2623, -0.2638, -0.2652, -0.2667,
+ -0.2682, -0.2697, -0.2711, -0.2726, -0.2741, -0.2756, -0.277, -0.2785, -0.28, -0.2815,
+ -0.2829, -0.2844, -0.2859, -0.2873, -0.2888, -0.2903, -0.2918, -0.2932, -0.2947, -0.2962,
+ -0.2976, -0.2991, -0.3005, -0.302, -0.3035, -0.3049, -0.3064, -0.3078, -0.3093, -0.3108,
+ -0.3122, -0.3137, -0.3151, -0.3166, -0.318, -0.3195, -0.321, -0.3224, -0.3239, -0.3253,
+ -0.3268, -0.3282, -0.3297, -0.3311, -0.3326, -0.334, -0.3354, -0.3369, -0.3383, -0.3398,
+ -0.3412, -0.3427, -0.3441, -0.3455, -0.347, -0.3484, -0.3499, -0.3513, -0.3527, -0.3542,
+ -0.3556, -0.357, -0.3585, -0.3599, -0.3613, -0.3628, -0.3642, -0.3656, -0.367, -0.3685,
+ -0.3699, -0.3713, -0.3727, -0.3742, -0.3756, -0.377, -0.3784, -0.3798, -0.3813, -0.3827,
+ -0.3841, -0.3855, -0.3869, -0.3883, -0.3898, -0.3912, -0.3926, -0.394, -0.3954, -0.3968,
+ -0.3982, -0.3996, -0.401, -0.4024, -0.4038, -0.4052, -0.4066, -0.408, -0.4094, -0.4108,
+ -0.4122, -0.4136, -0.415, -0.4164, -0.4178, -0.4192, -0.4206, -0.422, -0.4234, -0.4248,
+ -0.4262, -0.4276, -0.4289, -0.4303, -0.4317, -0.4331, -0.4345, -0.4359, -0.4372, -0.4386,
+ -0.44, -0.4414, -0.4427, -0.4441, -0.4455, -0.4469, -0.4482, -0.4496, -0.451, -0.4523,
+ -0.4537, -0.4551, -0.4564, -0.4578, -0.4592, -0.4605, -0.4619, -0.4633, -0.4646, -0.466,
+ -0.4673, -0.4687, -0.47, -0.4714, -0.4727, -0.4741, -0.4755, -0.4768, -0.4781, -0.4795,
+ -0.4808, -0.4822, -0.4835, -0.4849, -0.4862, -0.4876, -0.4889, -0.4902, -0.4916, -0.4929,
+ -0.4942, -0.4956, -0.4969, -0.4982, -0.4996, -0.5009, -0.5022, -0.5035, -0.5049, -0.5062,
+ -0.5075, -0.5088, -0.5102, -0.5115, -0.5128, -0.5141, -0.5154, -0.5167, -0.518, -0.5194,
+ -0.5207, -0.522, -0.5233, -0.5246, -0.5259, -0.5272, -0.5285, -0.5298, -0.5311, -0.5324,
+ -0.5337, -0.535, -0.5363, -0.5376, -0.5389, -0.5402, -0.5415, -0.5428, -0.544, -0.5453,
+ -0.5466, -0.5479, -0.5492, -0.5505, -0.5517, -0.553, -0.5543, -0.5556, -0.5568, -0.5581,
+ -0.5594, -0.5607, -0.5619, -0.5632, -0.5645, -0.5657, -0.567, -0.5683, -0.5695, -0.5708,
+ -0.572, -0.5733, -0.5746, -0.5758, -0.5771, -0.5783, -0.5796, -0.5808, -0.5821, -0.5833,
+ -0.5846, -0.5858, -0.587, -0.5883, -0.5895, -0.5908, -0.592, -0.5932, -0.5945, -0.5957,
+ -0.5969, -0.5982, -0.5994, -0.6006, -0.6018, -0.6031, -0.6043, -0.6055, -0.6067, -0.6079,
+ -0.6092, -0.6104, -0.6116, -0.6128, -0.614, -0.6152, -0.6164, -0.6176, -0.6189, -0.6201,
+ -0.6213, -0.6225, -0.6237, -0.6249, -0.6261, -0.6273, -0.6284, -0.6296, -0.6308, -0.632,
+ -0.6332, -0.6344, -0.6356, -0.6368, -0.6379, -0.6391, -0.6403, -0.6415, -0.6427, -0.6438,
+ -0.645, -0.6462, -0.6473, -0.6485, -0.6497, -0.6508, -0.652, -0.6532, -0.6543, -0.6555,
+ -0.6567, -0.6578, -0.659, -0.6601, -0.6613, -0.6624, -0.6636, -0.6647, -0.6659, -0.667,
+ -0.6681, -0.6693, -0.6704, -0.6716, -0.6727, -0.6738, -0.675, -0.6761, -0.6772, -0.6784,
+ -0.6795, -0.6806, -0.6817, -0.6828, -0.684, -0.6851, -0.6862, -0.6873, -0.6884, -0.6895,
+ -0.6907, -0.6918, -0.6929, -0.694, -0.6951, -0.6962, -0.6973, -0.6984, -0.6995, -0.7006,
+ -0.7017, -0.7028, -0.7038, -0.7049, -0.706, -0.7071, -0.7082, -0.7093, -0.7104, -0.7114,
+ -0.7125, -0.7136, -0.7147, -0.7157, -0.7168, -0.7179, -0.7189, -0.72, -0.7211, -0.7221,
+ -0.7232, -0.7242, -0.7253, -0.7264, -0.7274, -0.7285, -0.7295, -0.7306, -0.7316, -0.7327,
+ -0.7337, -0.7347, -0.7358, -0.7368, -0.7379, -0.7389, -0.7399, -0.741, -0.742, -0.743,
+ -0.744, -0.7451, -0.7461, -0.7471, -0.7481, -0.7491, -0.7502, -0.7512, -0.7522, -0.7532,
+ -0.7542, -0.7552, -0.7562, -0.7572, -0.7582, -0.7592, -0.7602, -0.7612, -0.7622, -0.7632,
+ -0.7642, -0.7652, -0.7662, -0.7671, -0.7681, -0.7691, -0.7701, -0.7711, -0.772, -0.773,
+ -0.774, -0.775, -0.7759, -0.7769, -0.7779, -0.7788, -0.7798, -0.7807, -0.7817, -0.7827,
+ -0.7836, -0.7846, -0.7855, -0.7865, -0.7874, -0.7883, -0.7893, -0.7902, -0.7912, -0.7921,
+ -0.793, -0.794, -0.7949, -0.7958, -0.7968, -0.7977, -0.7986, -0.7995, -0.8005, -0.8014,
+ -0.8023, -0.8032, -0.8041, -0.805, -0.8059, -0.8068, -0.8078, -0.8087, -0.8096, -0.8105,
+ -0.8114, -0.8123, -0.8131, -0.814, -0.8149, -0.8158, -0.8167, -0.8176, -0.8185, -0.8193,
+ -0.8202, -0.8211, -0.822, -0.8228, -0.8237, -0.8246, -0.8255, -0.8263, -0.8272, -0.828,
+ -0.8289, -0.8298, -0.8306, -0.8315, -0.8323, -0.8332, -0.834, -0.8349, -0.8357, -0.8365,
+ -0.8374, -0.8382, -0.8391, -0.8399, -0.8407, -0.8416, -0.8424, -0.8432, -0.844, -0.8449,
+ -0.8457, -0.8465, -0.8473, -0.8481, -0.8489, -0.8497, -0.8505, -0.8514, -0.8522, -0.853,
+ -0.8538, -0.8546, -0.8554, -0.8561, -0.8569, -0.8577, -0.8585, -0.8593, -0.8601, -0.8609,
+ -0.8616, -0.8624, -0.8632, -0.864, -0.8647, -0.8655, -0.8663, -0.867, -0.8678, -0.8686,
+ -0.8693, -0.8701, -0.8708, -0.8716, -0.8723, -0.8731, -0.8738, -0.8746, -0.8753, -0.8761,
+ -0.8768, -0.8775, -0.8783, -0.879, -0.8797, -0.8805, -0.8812, -0.8819, -0.8826, -0.8834,
+ -0.8841, -0.8848, -0.8855, -0.8862, -0.8869, -0.8876, -0.8883, -0.889, -0.8897, -0.8904,
+ -0.8911, -0.8918, -0.8925, -0.8932, -0.8939, -0.8946, -0.8953, -0.896, -0.8966, -0.8973,
+ -0.898, -0.8987, -0.8993, -0.9, -0.9007, -0.9013, -0.902, -0.9027, -0.9033, -0.904,
+ -0.9046, -0.9053, -0.9059, -0.9066, -0.9072, -0.9079, -0.9085, -0.9092, -0.9098, -0.9104,
+ -0.9111, -0.9117, -0.9123, -0.913, -0.9136, -0.9142, -0.9148, -0.9154, -0.9161, -0.9167,
+ -0.9173, -0.9179, -0.9185, -0.9191, -0.9197, -0.9203, -0.9209, -0.9215, -0.9221, -0.9227,
+ -0.9233, -0.9239, -0.9245, -0.925, -0.9256, -0.9262, -0.9268, -0.9274, -0.9279, -0.9285,
+ -0.9291, -0.9296, -0.9302, -0.9308, -0.9313, -0.9319, -0.9324, -0.933, -0.9335, -0.9341,
+ -0.9346, -0.9352, -0.9357, -0.9363, -0.9368, -0.9373, -0.9379, -0.9384, -0.9389, -0.9395,
+ -0.94, -0.9405, -0.941, -0.9415, -0.9421, -0.9426, -0.9431, -0.9436, -0.9441, -0.9446,
+ -0.9451, -0.9456, -0.9461, -0.9466, -0.9471, -0.9476, -0.9481, -0.9486, -0.949, -0.9495,
+ -0.95, -0.9505, -0.951, -0.9514, -0.9519, -0.9524, -0.9528, -0.9533, -0.9538, -0.9542,
+ -0.9547, -0.9551, -0.9556, -0.956, -0.9565, -0.9569, -0.9574, -0.9578, -0.9583, -0.9587,
+ -0.9591, -0.9596, -0.96, -0.9604, -0.9609, -0.9613, -0.9617, -0.9621, -0.9625, -0.963,
+ -0.9634, -0.9638, -0.9642, -0.9646, -0.965, -0.9654, -0.9658, -0.9662, -0.9666, -0.967,
+ -0.9674, -0.9678, -0.9681, -0.9685, -0.9689, -0.9693, -0.9697, -0.97, -0.9704, -0.9708,
+ -0.9711, -0.9715, -0.9719, -0.9722, -0.9726, -0.9729, -0.9733, -0.9736, -0.974, -0.9743,
+ -0.9747, -0.975, -0.9754, -0.9757, -0.976, -0.9764, -0.9767, -0.977, -0.9774, -0.9777,
+ -0.978, -0.9783, -0.9786, -0.9789, -0.9793, -0.9796, -0.9799, -0.9802, -0.9805, -0.9808,
+ -0.9811, -0.9814, -0.9817, -0.982, -0.9823, -0.9825, -0.9828, -0.9831, -0.9834, -0.9837,
+ -0.9839, -0.9842, -0.9845, -0.9847, -0.985, -0.9853, -0.9855, -0.9858, -0.9861, -0.9863,
+ -0.9866, -0.9868, -0.9871, -0.9873, -0.9875, -0.9878, -0.988, -0.9883, -0.9885, -0.9887,
+ -0.989, -0.9892, -0.9894, -0.9896, -0.9898, -0.9901, -0.9903, -0.9905, -0.9907, -0.9909,
+ -0.9911, -0.9913, -0.9915, -0.9917, -0.9919, -0.9921, -0.9923, -0.9925, -0.9927, -0.9929,
+ -0.993, -0.9932, -0.9934, -0.9936, -0.9937, -0.9939, -0.9941, -0.9942, -0.9944, -0.9946,
+ -0.9947, -0.9949, -0.995, -0.9952, -0.9953, -0.9955, -0.9956, -0.9958, -0.9959, -0.996,
+ -0.9962, -0.9963, -0.9964, -0.9966, -0.9967, -0.9968, -0.9969, -0.9971, -0.9972, -0.9973,
+ -0.9974, -0.9975, -0.9976, -0.9977, -0.9978, -0.9979, -0.998, -0.9981, -0.9982, -0.9983,
+ -0.9984, -0.9985, -0.9986, -0.9986, -0.9987, -0.9988, -0.9989, -0.9989, -0.999, -0.9991,
+ -0.9991, -0.9992, -0.9993, -0.9993, -0.9994, -0.9994, -0.9995, -0.9995, -0.9996, -0.9996,
+ -0.9997, -0.9997, -0.9997, -0.9998, -0.9998, -0.9998, -0.9999, -0.9999, -0.9999, -0.9999,
+ -0.9999, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, -1, -1, -0.9999, -0.9999, -0.9999, -0.9999, -0.9999, -0.9998,
+ -0.9998, -0.9998, -0.9997, -0.9997, -0.9997, -0.9996, -0.9996, -0.9995, -0.9995, -0.9994,
+ -0.9994, -0.9993, -0.9993, -0.9992, -0.9991, -0.9991, -0.999, -0.9989, -0.9989, -0.9988,
+ -0.9987, -0.9986, -0.9986, -0.9985, -0.9984, -0.9983, -0.9982, -0.9981, -0.998, -0.9979,
+ -0.9978, -0.9977, -0.9976, -0.9975, -0.9974, -0.9973, -0.9972, -0.9971, -0.9969, -0.9968,
+ -0.9967, -0.9966, -0.9964, -0.9963, -0.9962, -0.996, -0.9959, -0.9958, -0.9956, -0.9955,
+ -0.9953, -0.9952, -0.995, -0.9949, -0.9947, -0.9946, -0.9944, -0.9942, -0.9941, -0.9939,
+ -0.9937, -0.9936, -0.9934, -0.9932, -0.993, -0.9929, -0.9927, -0.9925, -0.9923, -0.9921,
+ -0.9919, -0.9917, -0.9915, -0.9913, -0.9911, -0.9909, -0.9907, -0.9905, -0.9903, -0.9901,
+ -0.9898, -0.9896, -0.9894, -0.9892, -0.989, -0.9887, -0.9885, -0.9883, -0.988, -0.9878,
+ -0.9875, -0.9873, -0.9871, -0.9868, -0.9866, -0.9863, -0.9861, -0.9858, -0.9855, -0.9853,
+ -0.985, -0.9847, -0.9845, -0.9842, -0.9839, -0.9837, -0.9834, -0.9831, -0.9828, -0.9825,
+ -0.9823, -0.982, -0.9817, -0.9814, -0.9811, -0.9808, -0.9805, -0.9802, -0.9799, -0.9796,
+ -0.9793, -0.9789, -0.9786, -0.9783, -0.978, -0.9777, -0.9774, -0.977, -0.9767, -0.9764,
+ -0.976, -0.9757, -0.9754, -0.975, -0.9747, -0.9743, -0.974, -0.9736, -0.9733, -0.9729,
+ -0.9726, -0.9722, -0.9719, -0.9715, -0.9711, -0.9708, -0.9704, -0.97, -0.9697, -0.9693,
+ -0.9689, -0.9685, -0.9681, -0.9678, -0.9674, -0.967, -0.9666, -0.9662, -0.9658, -0.9654,
+ -0.965, -0.9646, -0.9642, -0.9638, -0.9634, -0.963, -0.9625, -0.9621, -0.9617, -0.9613,
+ -0.9609, -0.9604, -0.96, -0.9596, -0.9591, -0.9587, -0.9583, -0.9578, -0.9574, -0.9569,
+ -0.9565, -0.956, -0.9556, -0.9551, -0.9547, -0.9542, -0.9538, -0.9533, -0.9528, -0.9524,
+ -0.9519, -0.9514, -0.951, -0.9505, -0.95, -0.9495, -0.949, -0.9486, -0.9481, -0.9476,
+ -0.9471, -0.9466, -0.9461, -0.9456, -0.9451, -0.9446, -0.9441, -0.9436, -0.9431, -0.9426,
+ -0.9421, -0.9415, -0.941, -0.9405, -0.94, -0.9395, -0.9389, -0.9384, -0.9379, -0.9373,
+ -0.9368, -0.9363, -0.9357, -0.9352, -0.9346, -0.9341, -0.9335, -0.933, -0.9324, -0.9319,
+ -0.9313, -0.9308, -0.9302, -0.9296, -0.9291, -0.9285, -0.9279, -0.9274, -0.9268, -0.9262,
+ -0.9256, -0.925, -0.9245, -0.9239, -0.9233, -0.9227, -0.9221, -0.9215, -0.9209, -0.9203,
+ -0.9197, -0.9191, -0.9185, -0.9179, -0.9173, -0.9167, -0.9161, -0.9154, -0.9148, -0.9142,
+ -0.9136, -0.913, -0.9123, -0.9117, -0.9111, -0.9104, -0.9098, -0.9092, -0.9085, -0.9079,
+ -0.9072, -0.9066, -0.9059, -0.9053, -0.9046, -0.904, -0.9033, -0.9027, -0.902, -0.9013,
+ -0.9007, -0.9, -0.8993, -0.8987, -0.898, -0.8973, -0.8966, -0.896, -0.8953, -0.8946,
+ -0.8939, -0.8932, -0.8925, -0.8918, -0.8911, -0.8904, -0.8897, -0.889, -0.8883, -0.8876,
+ -0.8869, -0.8862, -0.8855, -0.8848, -0.8841, -0.8834, -0.8826, -0.8819, -0.8812, -0.8805,
+ -0.8797, -0.879, -0.8783, -0.8775, -0.8768, -0.8761, -0.8753, -0.8746, -0.8738, -0.8731,
+ -0.8723, -0.8716, -0.8708, -0.8701, -0.8693, -0.8686, -0.8678, -0.867, -0.8663, -0.8655,
+ -0.8647, -0.864, -0.8632, -0.8624, -0.8616, -0.8609, -0.8601, -0.8593, -0.8585, -0.8577,
+ -0.8569, -0.8561, -0.8554, -0.8546, -0.8538, -0.853, -0.8522, -0.8514, -0.8505, -0.8497,
+ -0.8489, -0.8481, -0.8473, -0.8465, -0.8457, -0.8449, -0.844, -0.8432, -0.8424, -0.8416,
+ -0.8407, -0.8399, -0.8391, -0.8382, -0.8374, -0.8365, -0.8357, -0.8349, -0.834, -0.8332,
+ -0.8323, -0.8315, -0.8306, -0.8298, -0.8289, -0.828, -0.8272, -0.8263, -0.8255, -0.8246,
+ -0.8237, -0.8228, -0.822, -0.8211, -0.8202, -0.8193, -0.8185, -0.8176, -0.8167, -0.8158,
+ -0.8149, -0.814, -0.8131, -0.8123, -0.8114, -0.8105, -0.8096, -0.8087, -0.8078, -0.8068,
+ -0.8059, -0.805, -0.8041, -0.8032, -0.8023, -0.8014, -0.8005, -0.7995, -0.7986, -0.7977,
+ -0.7968, -0.7958, -0.7949, -0.794, -0.793, -0.7921, -0.7912, -0.7902, -0.7893, -0.7883,
+ -0.7874, -0.7865, -0.7855, -0.7846, -0.7836, -0.7827, -0.7817, -0.7807, -0.7798, -0.7788,
+ -0.7779, -0.7769, -0.7759, -0.775, -0.774, -0.773, -0.772, -0.7711, -0.7701, -0.7691,
+ -0.7681, -0.7671, -0.7662, -0.7652, -0.7642, -0.7632, -0.7622, -0.7612, -0.7602, -0.7592,
+ -0.7582, -0.7572, -0.7562, -0.7552, -0.7542, -0.7532, -0.7522, -0.7512, -0.7502, -0.7491,
+ -0.7481, -0.7471, -0.7461, -0.7451, -0.744, -0.743, -0.742, -0.741, -0.7399, -0.7389,
+ -0.7379, -0.7368, -0.7358, -0.7347, -0.7337, -0.7327, -0.7316, -0.7306, -0.7295, -0.7285,
+ -0.7274, -0.7264, -0.7253, -0.7242, -0.7232, -0.7221, -0.7211, -0.72, -0.7189, -0.7179,
+ -0.7168, -0.7157, -0.7147, -0.7136, -0.7125, -0.7114, -0.7104, -0.7093, -0.7082, -0.7071,
+ -0.706, -0.7049, -0.7038, -0.7028, -0.7017, -0.7006, -0.6995, -0.6984, -0.6973, -0.6962,
+ -0.6951, -0.694, -0.6929, -0.6918, -0.6907, -0.6895, -0.6884, -0.6873, -0.6862, -0.6851,
+ -0.684, -0.6828, -0.6817, -0.6806, -0.6795, -0.6784, -0.6772, -0.6761, -0.675, -0.6738,
+ -0.6727, -0.6716, -0.6704, -0.6693, -0.6681, -0.667, -0.6659, -0.6647, -0.6636, -0.6624,
+ -0.6613, -0.6601, -0.659, -0.6578, -0.6567, -0.6555, -0.6543, -0.6532, -0.652, -0.6508,
+ -0.6497, -0.6485, -0.6473, -0.6462, -0.645, -0.6438, -0.6427, -0.6415, -0.6403, -0.6391,
+ -0.6379, -0.6368, -0.6356, -0.6344, -0.6332, -0.632, -0.6308, -0.6296, -0.6284, -0.6273,
+ -0.6261, -0.6249, -0.6237, -0.6225, -0.6213, -0.6201, -0.6189, -0.6176, -0.6164, -0.6152,
+ -0.614, -0.6128, -0.6116, -0.6104, -0.6092, -0.6079, -0.6067, -0.6055, -0.6043, -0.6031,
+ -0.6018, -0.6006, -0.5994, -0.5982, -0.5969, -0.5957, -0.5945, -0.5932, -0.592, -0.5908,
+ -0.5895, -0.5883, -0.587, -0.5858, -0.5846, -0.5833, -0.5821, -0.5808, -0.5796, -0.5783,
+ -0.5771, -0.5758, -0.5746, -0.5733, -0.572, -0.5708, -0.5695, -0.5683, -0.567, -0.5657,
+ -0.5645, -0.5632, -0.5619, -0.5607, -0.5594, -0.5581, -0.5568, -0.5556, -0.5543, -0.553,
+ -0.5517, -0.5505, -0.5492, -0.5479, -0.5466, -0.5453, -0.544, -0.5428, -0.5415, -0.5402,
+ -0.5389, -0.5376, -0.5363, -0.535, -0.5337, -0.5324, -0.5311, -0.5298, -0.5285, -0.5272,
+ -0.5259, -0.5246, -0.5233, -0.522, -0.5207, -0.5194, -0.518, -0.5167, -0.5154, -0.5141,
+ -0.5128, -0.5115, -0.5102, -0.5088, -0.5075, -0.5062, -0.5049, -0.5035, -0.5022, -0.5009,
+ -0.4996, -0.4982, -0.4969, -0.4956, -0.4942, -0.4929, -0.4916, -0.4902, -0.4889, -0.4876,
+ -0.4862, -0.4849, -0.4835, -0.4822, -0.4808, -0.4795, -0.4781, -0.4768, -0.4755, -0.4741,
+ -0.4727, -0.4714, -0.47, -0.4687, -0.4673, -0.466, -0.4646, -0.4633, -0.4619, -0.4605,
+ -0.4592, -0.4578, -0.4564, -0.4551, -0.4537, -0.4523, -0.451, -0.4496, -0.4482, -0.4469,
+ -0.4455, -0.4441, -0.4427, -0.4414, -0.44, -0.4386, -0.4372, -0.4359, -0.4345, -0.4331,
+ -0.4317, -0.4303, -0.4289, -0.4276, -0.4262, -0.4248, -0.4234, -0.422, -0.4206, -0.4192,
+ -0.4178, -0.4164, -0.415, -0.4136, -0.4122, -0.4108, -0.4094, -0.408, -0.4066, -0.4052,
+ -0.4038, -0.4024, -0.401, -0.3996, -0.3982, -0.3968, -0.3954, -0.394, -0.3926, -0.3912,
+ -0.3898, -0.3883, -0.3869, -0.3855, -0.3841, -0.3827, -0.3813, -0.3798, -0.3784, -0.377,
+ -0.3756, -0.3742, -0.3727, -0.3713, -0.3699, -0.3685, -0.367, -0.3656, -0.3642, -0.3628,
+ -0.3613, -0.3599, -0.3585, -0.357, -0.3556, -0.3542, -0.3527, -0.3513, -0.3499, -0.3484,
+ -0.347, -0.3455, -0.3441, -0.3427, -0.3412, -0.3398, -0.3383, -0.3369, -0.3354, -0.334,
+ -0.3326, -0.3311, -0.3297, -0.3282, -0.3268, -0.3253, -0.3239, -0.3224, -0.321, -0.3195,
+ -0.318, -0.3166, -0.3151, -0.3137, -0.3122, -0.3108, -0.3093, -0.3078, -0.3064, -0.3049,
+ -0.3035, -0.302, -0.3005, -0.2991, -0.2976, -0.2962, -0.2947, -0.2932, -0.2918, -0.2903,
+ -0.2888, -0.2873, -0.2859, -0.2844, -0.2829, -0.2815, -0.28, -0.2785, -0.277, -0.2756,
+ -0.2741, -0.2726, -0.2711, -0.2697, -0.2682, -0.2667, -0.2652, -0.2638, -0.2623, -0.2608,
+ -0.2593, -0.2578, -0.2563, -0.2549, -0.2534, -0.2519, -0.2504, -0.2489, -0.2474, -0.246,
+ -0.2445, -0.243, -0.2415, -0.24, -0.2385, -0.237, -0.2355, -0.234, -0.2326, -0.2311,
+ -0.2296, -0.2281, -0.2266, -0.2251, -0.2236, -0.2221, -0.2206, -0.2191, -0.2176, -0.2161,
+ -0.2146, -0.2131, -0.2116, -0.2101, -0.2086, -0.2071, -0.2056, -0.2041, -0.2026, -0.2011,
+ -0.1996, -0.1981, -0.1966, -0.1951, -0.1936, -0.1921, -0.1906, -0.1891, -0.1876, -0.1861,
+ -0.1845, -0.183, -0.1815, -0.18, -0.1785, -0.177, -0.1755, -0.174, -0.1725, -0.171,
+ -0.1695, -0.1679, -0.1664, -0.1649, -0.1634, -0.1619, -0.1604, -0.1589, -0.1573, -0.1558,
+ -0.1543, -0.1528, -0.1513, -0.1498, -0.1482, -0.1467, -0.1452, -0.1437, -0.1422, -0.1407,
+ -0.1391, -0.1376, -0.1361, -0.1346, -0.1331, -0.1315, -0.13, -0.1285, -0.127, -0.1255,
+ -0.1239, -0.1224, -0.1209, -0.1194, -0.1178, -0.1163, -0.1148, -0.1133, -0.1117, -0.1102,
+ -0.1087, -0.1072, -0.1056, -0.1041, -0.1026, -0.1011, -0.09954, -0.09802, -0.09649, -0.09496,
+ -0.09344, -0.09191, -0.09038, -0.08885, -0.08733, -0.0858, -0.08427, -0.08274, -0.08121, -0.07968,
+ -0.07815, -0.07662, -0.07509, -0.07356, -0.07203, -0.0705, -0.06897, -0.06744, -0.06591, -0.06438,
+ -0.06285, -0.06132, -0.05979, -0.05826, -0.05673, -0.0552, -0.05366, -0.05213, -0.0506, -0.04907,
+ -0.04754, -0.046, -0.04447, -0.04294, -0.04141, -0.03987, -0.03834, -0.03681, -0.03527, -0.03374,
+ -0.03221, -0.03067, -0.02914, -0.02761, -0.02607, -0.02454, -0.02301, -0.02147, -0.01994, -0.01841,
+ -0.01687, -0.01534, -0.01381, -0.01227, -0.01074, -0.009204, -0.00767, -0.006136, -0.004602, -0.003068,
+ -0.001534, -1.837e-16, 0.001534, 0.003068, 0.004602, 0.006136, 0.00767, 0.009204, 0.01074, 0.01227,
+ 0.01381, 0.01534, 0.01687, 0.01841, 0.01994, 0.02147, 0.02301, 0.02454, 0.02607, 0.02761,
+ 0.02914, 0.03067, 0.03221, 0.03374, 0.03527, 0.03681, 0.03834, 0.03987, 0.04141, 0.04294,
+ 0.04447, 0.046, 0.04754, 0.04907, 0.0506, 0.05213, 0.05366, 0.0552, 0.05673, 0.05826,
+ 0.05979, 0.06132, 0.06285, 0.06438, 0.06591, 0.06744, 0.06897, 0.0705, 0.07203, 0.07356,
+ 0.07509, 0.07662, 0.07815, 0.07968, 0.08121, 0.08274, 0.08427, 0.0858, 0.08733, 0.08885,
+ 0.09038, 0.09191, 0.09344, 0.09496, 0.09649, 0.09802, 0.09954, 0.1011, 0.1026, 0.1041,
+ 0.1056, 0.1072, 0.1087, 0.1102, 0.1117, 0.1133, 0.1148, 0.1163, 0.1178, 0.1194,
+ 0.1209, 0.1224, 0.1239, 0.1255, 0.127, 0.1285, 0.13, 0.1315, 0.1331, 0.1346,
+ 0.1361, 0.1376, 0.1391, 0.1407, 0.1422, 0.1437, 0.1452, 0.1467, 0.1482, 0.1498,
+ 0.1513, 0.1528, 0.1543, 0.1558, 0.1573, 0.1589, 0.1604, 0.1619, 0.1634, 0.1649,
+ 0.1664, 0.1679, 0.1695, 0.171, 0.1725, 0.174, 0.1755, 0.177, 0.1785, 0.18,
+ 0.1815, 0.183, 0.1845, 0.1861, 0.1876, 0.1891, 0.1906, 0.1921, 0.1936, 0.1951,
+ 0.1966, 0.1981, 0.1996, 0.2011, 0.2026, 0.2041, 0.2056, 0.2071, 0.2086, 0.2101,
+ 0.2116, 0.2131, 0.2146, 0.2161, 0.2176, 0.2191, 0.2206, 0.2221, 0.2236, 0.2251,
+ 0.2266, 0.2281, 0.2296, 0.2311, 0.2326, 0.234, 0.2355, 0.237, 0.2385, 0.24,
+ 0.2415, 0.243, 0.2445, 0.246, 0.2474, 0.2489, 0.2504, 0.2519, 0.2534, 0.2549,
+ 0.2563, 0.2578, 0.2593, 0.2608, 0.2623, 0.2638, 0.2652, 0.2667, 0.2682, 0.2697,
+ 0.2711, 0.2726, 0.2741, 0.2756, 0.277, 0.2785, 0.28, 0.2815, 0.2829, 0.2844,
+ 0.2859, 0.2873, 0.2888, 0.2903, 0.2918, 0.2932, 0.2947, 0.2962, 0.2976, 0.2991,
+ 0.3005, 0.302, 0.3035, 0.3049, 0.3064, 0.3078, 0.3093, 0.3108, 0.3122, 0.3137,
+ 0.3151, 0.3166, 0.318, 0.3195, 0.321, 0.3224, 0.3239, 0.3253, 0.3268, 0.3282,
+ 0.3297, 0.3311, 0.3326, 0.334, 0.3354, 0.3369, 0.3383, 0.3398, 0.3412, 0.3427,
+ 0.3441, 0.3455, 0.347, 0.3484, 0.3499, 0.3513, 0.3527, 0.3542, 0.3556, 0.357,
+ 0.3585, 0.3599, 0.3613, 0.3628, 0.3642, 0.3656, 0.367, 0.3685, 0.3699, 0.3713,
+ 0.3727, 0.3742, 0.3756, 0.377, 0.3784, 0.3798, 0.3813, 0.3827, 0.3841, 0.3855,
+ 0.3869, 0.3883, 0.3898, 0.3912, 0.3926, 0.394, 0.3954, 0.3968, 0.3982, 0.3996,
+ 0.401, 0.4024, 0.4038, 0.4052, 0.4066, 0.408, 0.4094, 0.4108, 0.4122, 0.4136,
+ 0.415, 0.4164, 0.4178, 0.4192, 0.4206, 0.422, 0.4234, 0.4248, 0.4262, 0.4276,
+ 0.4289, 0.4303, 0.4317, 0.4331, 0.4345, 0.4359, 0.4372, 0.4386, 0.44, 0.4414,
+ 0.4427, 0.4441, 0.4455, 0.4469, 0.4482, 0.4496, 0.451, 0.4523, 0.4537, 0.4551,
+ 0.4564, 0.4578, 0.4592, 0.4605, 0.4619, 0.4633, 0.4646, 0.466, 0.4673, 0.4687,
+ 0.47, 0.4714, 0.4727, 0.4741, 0.4755, 0.4768, 0.4781, 0.4795, 0.4808, 0.4822,
+ 0.4835, 0.4849, 0.4862, 0.4876, 0.4889, 0.4902, 0.4916, 0.4929, 0.4942, 0.4956,
+ 0.4969, 0.4982, 0.4996, 0.5009, 0.5022, 0.5035, 0.5049, 0.5062, 0.5075, 0.5088,
+ 0.5102, 0.5115, 0.5128, 0.5141, 0.5154, 0.5167, 0.518, 0.5194, 0.5207, 0.522,
+ 0.5233, 0.5246, 0.5259, 0.5272, 0.5285, 0.5298, 0.5311, 0.5324, 0.5337, 0.535,
+ 0.5363, 0.5376, 0.5389, 0.5402, 0.5415, 0.5428, 0.544, 0.5453, 0.5466, 0.5479,
+ 0.5492, 0.5505, 0.5517, 0.553, 0.5543, 0.5556, 0.5568, 0.5581, 0.5594, 0.5607,
+ 0.5619, 0.5632, 0.5645, 0.5657, 0.567, 0.5683, 0.5695, 0.5708, 0.572, 0.5733,
+ 0.5746, 0.5758, 0.5771, 0.5783, 0.5796, 0.5808, 0.5821, 0.5833, 0.5846, 0.5858,
+ 0.587, 0.5883, 0.5895, 0.5908, 0.592, 0.5932, 0.5945, 0.5957, 0.5969, 0.5982,
+ 0.5994, 0.6006, 0.6018, 0.6031, 0.6043, 0.6055, 0.6067, 0.6079, 0.6092, 0.6104,
+ 0.6116, 0.6128, 0.614, 0.6152, 0.6164, 0.6176, 0.6189, 0.6201, 0.6213, 0.6225,
+ 0.6237, 0.6249, 0.6261, 0.6273, 0.6284, 0.6296, 0.6308, 0.632, 0.6332, 0.6344,
+ 0.6356, 0.6368, 0.6379, 0.6391, 0.6403, 0.6415, 0.6427, 0.6438, 0.645, 0.6462,
+ 0.6473, 0.6485, 0.6497, 0.6508, 0.652, 0.6532, 0.6543, 0.6555, 0.6567, 0.6578,
+ 0.659, 0.6601, 0.6613, 0.6624, 0.6636, 0.6647, 0.6659, 0.667, 0.6681, 0.6693,
+ 0.6704, 0.6716, 0.6727, 0.6738, 0.675, 0.6761, 0.6772, 0.6784, 0.6795, 0.6806,
+ 0.6817, 0.6828, 0.684, 0.6851, 0.6862, 0.6873, 0.6884, 0.6895, 0.6907, 0.6918,
+ 0.6929, 0.694, 0.6951, 0.6962, 0.6973, 0.6984, 0.6995, 0.7006, 0.7017, 0.7028,
+ 0.7038, 0.7049, 0.706, 0.7071, 0.7082, 0.7093, 0.7104, 0.7114, 0.7125, 0.7136,
+ 0.7147, 0.7157, 0.7168, 0.7179, 0.7189, 0.72, 0.7211, 0.7221, 0.7232, 0.7242,
+ 0.7253, 0.7264, 0.7274, 0.7285, 0.7295, 0.7306, 0.7316, 0.7327, 0.7337, 0.7347,
+ 0.7358, 0.7368, 0.7379, 0.7389, 0.7399, 0.741, 0.742, 0.743, 0.744, 0.7451,
+ 0.7461, 0.7471, 0.7481, 0.7491, 0.7502, 0.7512, 0.7522, 0.7532, 0.7542, 0.7552,
+ 0.7562, 0.7572, 0.7582, 0.7592, 0.7602, 0.7612, 0.7622, 0.7632, 0.7642, 0.7652,
+ 0.7662, 0.7671, 0.7681, 0.7691, 0.7701, 0.7711, 0.772, 0.773, 0.774, 0.775,
+ 0.7759, 0.7769, 0.7779, 0.7788, 0.7798, 0.7807, 0.7817, 0.7827, 0.7836, 0.7846,
+ 0.7855, 0.7865, 0.7874, 0.7883, 0.7893, 0.7902, 0.7912, 0.7921, 0.793, 0.794,
+ 0.7949, 0.7958, 0.7968, 0.7977, 0.7986, 0.7995, 0.8005, 0.8014, 0.8023, 0.8032,
+ 0.8041, 0.805, 0.8059, 0.8068, 0.8078, 0.8087, 0.8096, 0.8105, 0.8114, 0.8123,
+ 0.8131, 0.814, 0.8149, 0.8158, 0.8167, 0.8176, 0.8185, 0.8193, 0.8202, 0.8211,
+ 0.822, 0.8228, 0.8237, 0.8246, 0.8255, 0.8263, 0.8272, 0.828, 0.8289, 0.8298,
+ 0.8306, 0.8315, 0.8323, 0.8332, 0.834, 0.8349, 0.8357, 0.8365, 0.8374, 0.8382,
+ 0.8391, 0.8399, 0.8407, 0.8416, 0.8424, 0.8432, 0.844, 0.8449, 0.8457, 0.8465,
+ 0.8473, 0.8481, 0.8489, 0.8497, 0.8505, 0.8514, 0.8522, 0.853, 0.8538, 0.8546,
+ 0.8554, 0.8561, 0.8569, 0.8577, 0.8585, 0.8593, 0.8601, 0.8609, 0.8616, 0.8624,
+ 0.8632, 0.864, 0.8647, 0.8655, 0.8663, 0.867, 0.8678, 0.8686, 0.8693, 0.8701,
+ 0.8708, 0.8716, 0.8723, 0.8731, 0.8738, 0.8746, 0.8753, 0.8761, 0.8768, 0.8775,
+ 0.8783, 0.879, 0.8797, 0.8805, 0.8812, 0.8819, 0.8826, 0.8834, 0.8841, 0.8848,
+ 0.8855, 0.8862, 0.8869, 0.8876, 0.8883, 0.889, 0.8897, 0.8904, 0.8911, 0.8918,
+ 0.8925, 0.8932, 0.8939, 0.8946, 0.8953, 0.896, 0.8966, 0.8973, 0.898, 0.8987,
+ 0.8993, 0.9, 0.9007, 0.9013, 0.902, 0.9027, 0.9033, 0.904, 0.9046, 0.9053,
+ 0.9059, 0.9066, 0.9072, 0.9079, 0.9085, 0.9092, 0.9098, 0.9104, 0.9111, 0.9117,
+ 0.9123, 0.913, 0.9136, 0.9142, 0.9148, 0.9154, 0.9161, 0.9167, 0.9173, 0.9179,
+ 0.9185, 0.9191, 0.9197, 0.9203, 0.9209, 0.9215, 0.9221, 0.9227, 0.9233, 0.9239,
+ 0.9245, 0.925, 0.9256, 0.9262, 0.9268, 0.9274, 0.9279, 0.9285, 0.9291, 0.9296,
+ 0.9302, 0.9308, 0.9313, 0.9319, 0.9324, 0.933, 0.9335, 0.9341, 0.9346, 0.9352,
+ 0.9357, 0.9363, 0.9368, 0.9373, 0.9379, 0.9384, 0.9389, 0.9395, 0.94, 0.9405,
+ 0.941, 0.9415, 0.9421, 0.9426, 0.9431, 0.9436, 0.9441, 0.9446, 0.9451, 0.9456,
+ 0.9461, 0.9466, 0.9471, 0.9476, 0.9481, 0.9486, 0.949, 0.9495, 0.95, 0.9505,
+ 0.951, 0.9514, 0.9519, 0.9524, 0.9528, 0.9533, 0.9538, 0.9542, 0.9547, 0.9551,
+ 0.9556, 0.956, 0.9565, 0.9569, 0.9574, 0.9578, 0.9583, 0.9587, 0.9591, 0.9596,
+ 0.96, 0.9604, 0.9609, 0.9613, 0.9617, 0.9621, 0.9625, 0.963, 0.9634, 0.9638,
+ 0.9642, 0.9646, 0.965, 0.9654, 0.9658, 0.9662, 0.9666, 0.967, 0.9674, 0.9678,
+ 0.9681, 0.9685, 0.9689, 0.9693, 0.9697, 0.97, 0.9704, 0.9708, 0.9711, 0.9715,
+ 0.9719, 0.9722, 0.9726, 0.9729, 0.9733, 0.9736, 0.974, 0.9743, 0.9747, 0.975,
+ 0.9754, 0.9757, 0.976, 0.9764, 0.9767, 0.977, 0.9774, 0.9777, 0.978, 0.9783,
+ 0.9786, 0.9789, 0.9793, 0.9796, 0.9799, 0.9802, 0.9805, 0.9808, 0.9811, 0.9814,
+ 0.9817, 0.982, 0.9823, 0.9825, 0.9828, 0.9831, 0.9834, 0.9837, 0.9839, 0.9842,
+ 0.9845, 0.9847, 0.985, 0.9853, 0.9855, 0.9858, 0.9861, 0.9863, 0.9866, 0.9868,
+ 0.9871, 0.9873, 0.9875, 0.9878, 0.988, 0.9883, 0.9885, 0.9887, 0.989, 0.9892,
+ 0.9894, 0.9896, 0.9898, 0.9901, 0.9903, 0.9905, 0.9907, 0.9909, 0.9911, 0.9913,
+ 0.9915, 0.9917, 0.9919, 0.9921, 0.9923, 0.9925, 0.9927, 0.9929, 0.993, 0.9932,
+ 0.9934, 0.9936, 0.9937, 0.9939, 0.9941, 0.9942, 0.9944, 0.9946, 0.9947, 0.9949,
+ 0.995, 0.9952, 0.9953, 0.9955, 0.9956, 0.9958, 0.9959, 0.996, 0.9962, 0.9963,
+ 0.9964, 0.9966, 0.9967, 0.9968, 0.9969, 0.9971, 0.9972, 0.9973, 0.9974, 0.9975,
+ 0.9976, 0.9977, 0.9978, 0.9979, 0.998, 0.9981, 0.9982, 0.9983, 0.9984, 0.9985,
+ 0.9986, 0.9986, 0.9987, 0.9988, 0.9989, 0.9989, 0.999, 0.9991, 0.9991, 0.9992,
+ 0.9993, 0.9993, 0.9994, 0.9994, 0.9995, 0.9995, 0.9996, 0.9996, 0.9997, 0.9997,
+ 0.9997, 0.9998, 0.9998, 0.9998, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 1,
+ 1, 1, 1, 1, 1, 1};
+
+ const static uint64 mask = max-1;
+ return logvals[h.first&mask]*cosvals[h.second&mask];
+
+ // Note that we are just using the Box-Muller transform to compute the result. In
+ // particular, we are doing this (where u1 and u2 are uniform random variables in
+ // the range [0,1]):
+ // return sqrt(-2*log(u1)) * cos(2*PI*u2);
+ // It is just that we use table lookups to avoid calling sqrt(), log() and cos().
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOM_HAsHING_Hh_
+
diff --git a/ml/dlib/dlib/general_hash/random_hashing_abstract.h b/ml/dlib/dlib/general_hash/random_hashing_abstract.h
new file mode 100644
index 000000000..3d196d8c0
--- /dev/null
+++ b/ml/dlib/dlib/general_hash/random_hashing_abstract.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANDOM_HAsHING_ABSTRACT_Hh_
+#ifdef DLIB_RANDOM_HAsHING_ABSTRACT_Hh_
+
+#include "random_hashing_abstract.h"
+#include "murmur_hash3.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ double uniform_random_hash (
+ const uint64& k1,
+ const uint64& k2,
+ const uint64& k3
+ );
+ /*!
+ ensures
+ - This function uses hashing to generate uniform random values in the range [0,1).
+ - To define this function precisely, assume we have an arbitrary sequence of
+ input triplets. Then calling uniform_random_hash() on each of them should
+ result in a sequence of double values that look like numbers sampled
+ independently and uniformly at random from the interval [0,1). This is true
+ even if there is some simple pattern in the inputs. For example, (0,0,0),
+ (1,0,0), (2,0,0), (3,0,0), etc.
+ - This function is deterministic. That is, the same output is always returned
+ when given the same input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double gaussian_random_hash (
+ const uint64& k1,
+ const uint64& k2,
+ const uint64& k3
+ );
+ /*!
+ ensures
+ - This function uses hashing to generate Gaussian distributed random values
+ with mean 0 and variance 1.
+ - To define this function precisely, assume we have an arbitrary sequence of
+ input triplets. Then calling gaussian_random_hash() on each of them should
+ result in a sequence of double values that look like numbers sampled
+ independently from a standard normal distribution. This is true even if
+ there is some simple pattern in the inputs. For example, (0,0,0), (1,0,0),
+ (2,0,0), (3,0,0), etc.
+ - This function is deterministic. That is, the same output is always returned
+ when given the same input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOM_HAsHING_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/geometry.h b/ml/dlib/dlib/geometry.h
new file mode 100644
index 000000000..9d326b150
--- /dev/null
+++ b/ml/dlib/dlib/geometry.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GEOMETRy_HEADER
+#define DLIB_GEOMETRy_HEADER
+
+#include "geometry/rectangle.h"
+#include "geometry/drectangle.h"
+#include "geometry/vector.h"
+#include "geometry/border_enumerator.h"
+#include "geometry/point_transforms.h"
+
+#endif // DLIB_GEOMETRy_HEADER
+
+
diff --git a/ml/dlib/dlib/geometry/border_enumerator.h b/ml/dlib/dlib/geometry/border_enumerator.h
new file mode 100644
index 000000000..0c69cc37f
--- /dev/null
+++ b/ml/dlib/dlib/geometry/border_enumerator.h
@@ -0,0 +1,186 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BORDER_EnUMERATOR_H_
+#define DLIB_BORDER_EnUMERATOR_H_
+
+#include "border_enumerator_abstract.h"
+#include "rectangle.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class border_enumerator
+ {
+ public:
+ border_enumerator(
+ )
+ {
+ reset();
+ }
+
+ border_enumerator(
+ const rectangle& rect_,
+ unsigned long border_size
+ ) :
+ rect(rect_),
+ inner_rect(shrink_rect(rect_, border_size))
+ {
+ reset();
+ }
+
+ border_enumerator(
+ const rectangle& rect_,
+ const rectangle& non_border_region
+ ) :
+ rect(rect_),
+ inner_rect(non_border_region.intersect(rect))
+ {
+ reset();
+ }
+
+ void reset (
+ )
+ {
+ // make the four rectangles that surround inner_rect and intersect them
+ // with rect.
+ bleft = rect.intersect(rectangle(std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::min(),
+ inner_rect.left()-1,
+ std::numeric_limits<long>::max()));
+
+ bright = rect.intersect(rectangle(inner_rect.right()+1,
+ std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(),
+ std::numeric_limits<long>::max()));
+
+ btop = rect.intersect(rectangle(inner_rect.left(),
+ std::numeric_limits<long>::min(),
+ inner_rect.right(),
+ inner_rect.top()-1));
+
+ bbottom = rect.intersect(rectangle(inner_rect.left(),
+ inner_rect.bottom()+1,
+ inner_rect.right(),
+ std::numeric_limits<long>::max()));
+
+ p = bleft.tl_corner();
+ p.x() -= 1;
+
+ mode = atleft;
+ }
+
+ bool at_start (
+ ) const
+ {
+ point temp = bleft.tl_corner();
+ temp.x() -=1;
+ return temp == p;
+ }
+
+ bool current_element_valid(
+ ) const
+ {
+ return rect.contains(p);
+ }
+
+ bool move_next()
+ {
+ if (mode == atleft)
+ {
+ if (advance_point(bleft, p))
+ return true;
+
+ mode = attop;
+ p = btop.tl_corner();
+ p.x() -= 1;
+ }
+ if (mode == attop)
+ {
+ if (advance_point(btop, p))
+ return true;
+
+ mode = atright;
+ p = bright.tl_corner();
+ p.x() -= 1;
+ }
+ if (mode == atright)
+ {
+ if (advance_point(bright, p))
+ return true;
+
+ mode = atbottom;
+ p = bbottom.tl_corner();
+ p.x() -= 1;
+ }
+
+ if (advance_point(bbottom, p))
+ return true;
+
+ // put p outside rect since there are no more points to enumerate
+ p = rect.br_corner();
+ p.x() += 1;
+
+ return false;
+ }
+
+ size_t size (
+ ) const
+ {
+ return rect.area() - inner_rect.area();
+ }
+
+ const point& element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_element_valid(),
+ "\t point border_enumerator::element()"
+ << "\n\t This function can't be called unless the element is valid."
+ << "\n\t this: " << this
+ );
+
+ return p;
+ }
+
+ private:
+
+ bool advance_point (
+ const rectangle& r,
+ point& p
+ ) const
+ {
+ p.x() += 1;
+ if (p.x() > r.right())
+ {
+ p.x() = r.left();
+ p.y() += 1;
+ }
+
+ return r.contains(p);
+ }
+
+ point p;
+ rectangle rect;
+ rectangle inner_rect; // the non-border regions of rect
+
+ enum emode
+ {
+ atleft,
+ atright,
+ atbottom,
+ attop
+ };
+
+ emode mode;
+
+ rectangle btop, bleft, bright, bbottom;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BORDER_EnUMERATOR_H_
+
diff --git a/ml/dlib/dlib/geometry/border_enumerator_abstract.h b/ml/dlib/dlib/geometry/border_enumerator_abstract.h
new file mode 100644
index 000000000..11118d571
--- /dev/null
+++ b/ml/dlib/dlib/geometry/border_enumerator_abstract.h
@@ -0,0 +1,126 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BORDER_EnUMERATOR_ABSTRACT_H_
+#ifdef DLIB_BORDER_EnUMERATOR_ABSTRACT_H_
+
+#include "rectangle_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class border_enumerator
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ All operations on this object other than calling element() invalidate
+ pointers and references to internal data.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is an enumerator over the border points of a rectangle.
+ !*/
+ public:
+
+ border_enumerator(
+ );
+ /*!
+ ensures
+ - #move_next() == false
+ (i.e. this object is "empty" and won't enumerate anything)
+ - current_element_valid() == false
+ - at_start() == true
+ - size() == 0
+ !*/
+
+ border_enumerator(
+ const rectangle& rect,
+ unsigned long border_size
+ );
+ /*!
+ ensures
+ - This object will enumerate over the border points which are inside rect
+ but within border_size of the edge. For example, if border_size == 1
+ then it enumerates over the single point wide strip of points all around
+ the interior edge of rect.
+ - current_element_valid() == false
+ - at_start() == true
+ - size() == rect.area() - shrink_rect(rect,border_size).area()
+ (i.e. the number of points in the border area of rect)
+ !*/
+
+ border_enumerator(
+ const rectangle& rect,
+ const rectangle& non_border_region
+ );
+ /*!
+ ensures
+ - This object will enumerate over all points which are in rect but
+ not in non_border_region.
+ - current_element_valid() == false
+ - at_start() == true
+ - size() == rect.area() - rect.intersect(non_border_region).area()
+ !*/
+
+ bool at_start (
+ ) const;
+ /*!
+ ensures
+ - returns true if *this represents one position before the first point
+ (this would also make the current element invalid) else returns false
+ !*/
+
+ void reset (
+ );
+ /*!
+ ensures
+ - #current_element_valid() == false
+ - #at_start() == true
+ !*/
+
+ bool current_element_valid(
+ ) const;
+ /*!
+ ensures
+ - returns true if we are currently at a valid element else
+ returns false
+ !*/
+
+ bool move_next(
+ );
+ /*!
+ ensures
+ - moves to the next element. i.e. #element() will now
+ return the next border point.
+ - the return value will be equal to #current_element_valid()
+ - #at_start() == false
+
+ - returns true if there is another element
+ - returns false if there are no more elements in the container
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of border points
+ !*/
+
+ const point& element (
+ ) const;
+ /*!
+ requires
+ - current_element_valid() == true
+ ensures
+ - returns the current border point
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BORDER_EnUMERATOR_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/geometry/drectangle.h b/ml/dlib/dlib/geometry/drectangle.h
new file mode 100644
index 000000000..9ccc5c0ee
--- /dev/null
+++ b/ml/dlib/dlib/geometry/drectangle.h
@@ -0,0 +1,488 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DRECTANGLe_
+#define DLIB_DRECTANGLe_
+
+#include "drectangle_abstract.h"
+#include "rectangle.h"
+
+namespace dlib
+{
+ class drectangle;
+ drectangle operator* (
+ const drectangle& rect,
+ const double& scale
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class drectangle
+ {
+ public:
+
+ drectangle (
+ ) : l(0), t(0), r(-1), b(-1) {}
+
+ drectangle (
+ double l_,
+ double t_,
+ double r_,
+ double b_
+ ) :
+ l(l_),
+ t(t_),
+ r(r_),
+ b(b_)
+ {}
+
+ drectangle (
+ const dlib::vector<double,2>& p
+ ) :
+ l(p.x()),
+ t(p.y()),
+ r(p.x()),
+ b(p.y())
+ {
+ }
+
+ template <typename T, typename U>
+ drectangle (
+ const vector<T,2>& p1,
+ const vector<U,2>& p2
+ )
+ {
+ *this = drectangle(p1) + drectangle(p2);
+ }
+
+ drectangle (
+ const rectangle& rect
+ ) : l(rect.left()),
+ t(rect.top()),
+ r(rect.right()),
+ b(rect.bottom()) {}
+
+ operator rectangle (
+ ) const
+ {
+ return rectangle((long)std::floor(l+0.5),
+ (long)std::floor(t+0.5),
+ (long)std::floor(r+0.5),
+ (long)std::floor(b+0.5));
+ }
+
+ double left() const { return l; }
+ double top() const { return t; }
+ double right() const { return r; }
+ double bottom() const { return b; }
+
+ double& left() { return l; }
+ double& top() { return t; }
+ double& right() { return r; }
+ double& bottom() { return b; }
+
+ const dlib::vector<double,2> tl_corner (
+ ) const { return dlib::vector<double,2>(left(), top()); }
+
+ const dlib::vector<double,2> bl_corner (
+ ) const { return dlib::vector<double,2>(left(), bottom()); }
+
+ const dlib::vector<double,2> tr_corner (
+ ) const { return dlib::vector<double,2>(right(), top()); }
+
+ const dlib::vector<double,2> br_corner (
+ ) const { return dlib::vector<double,2>(right(), bottom()); }
+
+ double width (
+ ) const
+ {
+ if (is_empty())
+ return 0;
+ else
+ return r - l + 1;
+ }
+
+ double height (
+ ) const
+ {
+ if (is_empty())
+ return 0;
+ else
+ return b - t + 1;
+ }
+
+ double area (
+ ) const
+ {
+ return width()*height();
+ }
+
+ bool is_empty (
+ ) const { return (t > b || l > r); }
+
+ drectangle operator + (
+ const drectangle& rhs
+ ) const
+ {
+ if (rhs.is_empty())
+ return *this;
+ else if (is_empty())
+ return rhs;
+
+ return drectangle (
+ std::min(l,rhs.l),
+ std::min(t,rhs.t),
+ std::max(r,rhs.r),
+ std::max(b,rhs.b)
+ );
+ }
+
+ drectangle intersect (
+ const drectangle& rhs
+ ) const
+ {
+ return drectangle (
+ std::max(l,rhs.l),
+ std::max(t,rhs.t),
+ std::min(r,rhs.r),
+ std::min(b,rhs.b)
+ );
+ }
+
+ bool contains (
+ const dlib::vector<double,2>& p
+ ) const
+ {
+ if (p.x() < l || p.x() > r || p.y() < t || p.y() > b)
+ return false;
+ return true;
+ }
+
+ bool contains (
+ const drectangle& rect
+ ) const
+ {
+ if (rect.is_empty())
+ return true;
+ if (l <= rect.left() &&
+ r >= rect.right() &&
+ t <= rect.top() &&
+ b >= rect.bottom())
+ return true;
+ return false;
+ }
+
+ drectangle& operator *= (
+ const double& scale
+ )
+ {
+ *this = *this*scale;
+ return *this;
+ }
+
+ drectangle& operator /= (
+ const double& scale
+ )
+ {
+ *this = *this*(1.0/scale);
+ return *this;
+ }
+
+ drectangle& operator += (
+ const dlib::vector<double,2>& p
+ )
+ {
+ *this = *this + drectangle(p);
+ return *this;
+ }
+
+ bool operator== (
+ const drectangle& rect
+ ) const
+ {
+ return (l == rect.l) && (t == rect.t) && (r == rect.r) && (b == rect.b);
+ }
+
+ bool operator!= (
+ const drectangle& rect
+ ) const
+ {
+ return !(*this == rect);
+ }
+
+ private:
+ double l;
+ double t;
+ double r;
+ double b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const drectangle& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.left(),out);
+ serialize(item.top(),out);
+ serialize(item.right(),out);
+ serialize(item.bottom(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type drectangle");
+ }
+ }
+
+ inline void deserialize (
+ drectangle& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.left(),in);
+ deserialize(item.top(),in);
+ deserialize(item.right(),in);
+ deserialize(item.bottom(),in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type drectangle");
+ }
+ }
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const drectangle& item
+ )
+ {
+ out << "[(" << item.left() << ", " << item.top() << ") (" << item.right() << ", " << item.bottom() << ")]";
+ return out;
+ }
+
+ inline std::istream& operator>>(
+ std::istream& in,
+ drectangle& item
+ )
+ {
+ // ignore any whitespace
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+ // now eat the leading '[' character
+ if (in.get() != '[')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+ dlib::vector<double,2> p1, p2;
+ in >> p1;
+ in >> p2;
+ item = drectangle(p1) + drectangle(p2);
+
+ // ignore any whitespace
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+ // now eat the trailing ']' character
+ if (in.get() != ']')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ }
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline dlib::vector<double,2> center (
+ const drectangle& rect
+ )
+ {
+ dlib::vector<double,2> temp(rect.left() + rect.right(),
+ rect.top() + rect.bottom());
+
+ return temp/2.0;
+ }
+
+ inline dlib::vector<double,2> dcenter (
+ const drectangle& rect
+ )
+ {
+ return center(rect);
+ }
+
+ inline drectangle operator* (
+ const drectangle& rect,
+ const double& scale
+ )
+ {
+ if (!rect.is_empty())
+ {
+ const double width = (rect.right()-rect.left())*scale;
+ const double height = (rect.bottom()-rect.top())*scale;
+ const dlib::vector<double,2> p = center(rect);
+ return drectangle(p.x()-width/2, p.y()-height/2, p.x()+width/2, p.y()+height/2);
+ }
+ else
+ {
+ return rect;
+ }
+ }
+
+ inline drectangle operator* (
+ const double& scale,
+ const drectangle& rect
+ )
+ {
+ return rect*scale;
+ }
+
+ inline drectangle operator/ (
+ const drectangle& rect,
+ const double& scale
+ )
+ {
+ return rect*(1.0/scale);
+ }
+
+ inline drectangle operator+ (
+ const drectangle& r,
+ const dlib::vector<double,2>& p
+ )
+ {
+ return r + drectangle(p);
+ }
+
+ inline drectangle operator+ (
+ const dlib::vector<double,2>& p,
+ const drectangle& r
+ )
+ {
+ return r + drectangle(p);
+ }
+
+ template <typename T>
+ inline drectangle translate_rect (
+ const drectangle& rect,
+ const dlib::vector<T,2>& p
+ )
+ {
+ drectangle result;
+ result.top () = rect.top() + p.y();
+ result.bottom () = rect.bottom() + p.y();
+ result.left () = rect.left() + p.x();
+ result.right () = rect.right() + p.x();
+ return result;
+ }
+
+ inline drectangle intersect (
+ const drectangle& a,
+ const drectangle& b
+ ) { return a.intersect(b); }
+
+ inline double area (
+ const drectangle& a
+ ) { return a.area(); }
+
+ inline drectangle centered_drect (
+ const dlib::vector<double,2>& p,
+ double width,
+ double height
+ )
+ {
+ width--;
+ height--;
+
+ return drectangle(p.x()-width/2, p.y()-height/2, p.x()+width/2, p.y()+height/2);
+ }
+
+ inline drectangle centered_drect (
+ const drectangle& rect,
+ double width,
+ double height
+ )
+ {
+ return centered_drect(dcenter(rect), width, height);
+ }
+
+ inline const drectangle shrink_rect (
+ const drectangle& rect,
+ double num
+ )
+ {
+ return drectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num);
+ }
+
+ inline const drectangle grow_rect (
+ const drectangle& rect,
+ double num
+ )
+ {
+ return shrink_rect(rect, -num);
+ }
+
+ inline const drectangle shrink_rect (
+ const drectangle& rect,
+ double width,
+ double height
+ )
+ {
+ return drectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height);
+ }
+
+ inline const drectangle grow_rect (
+ const drectangle& rect,
+ double width,
+ double height
+ )
+ {
+ return shrink_rect(rect, -width, -height);
+ }
+
+ inline drectangle set_rect_area (
+ const drectangle& rect,
+ double area
+ )
+ {
+ DLIB_ASSERT(area >= 0, "drectangle can't have a negative area.");
+
+ if (area == 0)
+ return drectangle(dcenter(rect));
+
+ if (rect.area() == 0)
+ {
+ // In this case we will make the output rectangle a square with the requested
+ // area.
+ double scale = std::sqrt(area);
+ return centered_drect(rect, scale, scale);
+ }
+ else
+ {
+ double scale = std::sqrt(area/rect.area());
+ return centered_drect(rect, rect.width()*scale, rect.height()*scale);
+ }
+ }
+
+ inline drectangle set_aspect_ratio (
+ const drectangle& rect,
+ double ratio
+ )
+ {
+ DLIB_ASSERT(ratio > 0,
+ "\t drectangle set_aspect_ratio()"
+ << "\n\t ratio: " << ratio
+ );
+
+ const double h = std::sqrt(rect.area()/ratio);
+ const double w = h*ratio;
+ return centered_drect(rect, w, h);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRECTANGLe_
+
diff --git a/ml/dlib/dlib/geometry/drectangle_abstract.h b/ml/dlib/dlib/geometry/drectangle_abstract.h
new file mode 100644
index 000000000..0f2221353
--- /dev/null
+++ b/ml/dlib/dlib/geometry/drectangle_abstract.h
@@ -0,0 +1,628 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DRECTANGLe_ABSTRACT_H_
+#ifdef DLIB_DRECTANGLe_ABSTRACT_H_
+
+#include "rectangle_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class drectangle
+ {
+ /*!
+ INITIAL VALUE
+ The initial value of this object is defined by its constructor.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is just like dlib::rectangle except that it stores the
+ coordinates of the rectangle using double rather than long variables. As
+ such, this object represents a rectangular region inside an image. The
+ region is the rectangle with its top left corner at position (left(),top())
+ and its bottom right corner at (right(),bottom()).
+
+ Note that the origin of the coordinate system, i.e. (0,0), is located at
+ the upper left corner. That is, points such as (1,1) or (3,5) represent
+ locations that are below and to the right of the origin.
+
+ Also note that rectangles where top() > bottom() or left() > right()
+ represent empty rectangles.
+ !*/
+ public:
+
+ drectangle (
+ );
+ /*!
+ ensures
+ - #left() == 0
+ - #top() == 0
+ - #right() == -1
+ - #bottom() == -1
+ - #is_empty() == true
+ !*/
+
+ drectangle (
+ double left_,
+ double top_,
+ double right_,
+ double bottom_
+ );
+ /*!
+ ensures
+ - #left() == left_
+ - #top() == top_
+ - #right() == right_
+ - #bottom() == bottom_
+ !*/
+
+ drectangle (
+ const vector<double,2>& p
+ );
+ /*!
+ ensures
+ - #left() == p.x()
+ - #top() == p.y()
+ - #right() == p.x()
+ - #bottom() == p.y()
+ !*/
+
+ template <typename T, typename U>
+ drectangle (
+ const vector<T,2>& p1,
+ const vector<U,2>& p2
+ );
+ /*!
+ ensures
+ - #*this == drectangle(p1) + drectangle(p2)
+ !*/
+
+ drectangle (
+ const drectangle& rect
+ );
+ /*!
+ ensures
+ - #*this represents the same rectangle as rect
+ !*/
+
+ drectangle (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - left() == rect.left()
+ - top() == rect.top()
+ - right() == rect.right()
+ - bottom() == rect.bottom()
+ - dcenter(*this) == dcenter(rect)
+ - width() == rect.width()
+ - height() == rect.height()
+ !*/
+
+ operator rectangle (
+ ) const;
+ /*!
+ ensures
+ - returns a rectangle where left(), top(), right(), and bottom() have been
+ rounded to the nearest integer values.
+ !*/
+
+ double left (
+ ) const;
+ /*!
+ ensures
+ - returns the x coordinate for the left side of this rectangle
+ !*/
+
+ double& left (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the x coordinate for the left side
+ of this rectangle
+ !*/
+
+ double top (
+ ) const;
+ /*!
+ ensures
+ - returns the y coordinate for the top of this rectangle
+ !*/
+
+ double& top (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the y coordinate for the
+ top of this rectangle
+ !*/
+
+ double right (
+ ) const;
+ /*!
+ ensures
+ - returns the x coordinate for the right side of this rectangle
+ !*/
+
+ double& right (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the x coordinate for the right
+ side of this rectangle
+ !*/
+
+ double bottom (
+ ) const;
+ /*!
+ ensures
+ - returns the y coordinate for the bottom of this rectangle
+ !*/
+
+ double& bottom (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the y coordinate for the bottom
+ of this rectangle
+ !*/
+
+ const vector<double,2> tl_corner (
+ ) const;
+ /*!
+ ensures
+ - returns vector<double,2>(left(), top())
+ (i.e. returns the top left corner point for this rectangle)
+ !*/
+
+ const vector<double,2> bl_corner (
+ ) const;
+ /*!
+ ensures
+ - returns vector<double,2>(left(), bottom())
+ (i.e. returns the bottom left corner point for this rectangle)
+ !*/
+
+ const vector<double,2> tr_corner (
+ ) const;
+ /*!
+ ensures
+ - returns vector<double,2>(right(), top())
+ (i.e. returns the top right corner point for this rectangle)
+ !*/
+
+ const vector<double,2> br_corner (
+ ) const;
+ /*!
+ ensures
+ - returns vector<double,2>(right(), bottom())
+ (i.e. returns the bottom right corner point for this rectangle)
+ !*/
+
+ double width (
+ ) const;
+ /*!
+ ensures
+ - if (is_empty()) then
+ - returns 0
+ - else
+ - returns the width of this rectangle.
+ (i.e. right() - left() + 1)
+ !*/
+
+ double height (
+ ) const;
+ /*!
+ ensures
+ - if (is_empty()) then
+ - returns 0
+ - else
+ - returns the height of this rectangle.
+ (i.e. bottom() - top() + 1)
+ !*/
+
+ double area (
+ ) const;
+ /*!
+ ensures
+ - returns width()*height()
+ !*/
+
+ bool is_empty (
+ ) const;
+ /*!
+ ensures
+ - if (top() > bottom() || left() > right()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ drectangle operator + (
+ const drectangle& rhs
+ ) const;
+ /*!
+ ensures
+ - if (rhs.is_empty() == false && this->is_empty() == false) then
+ - returns the smallest rectangle that contains both *this and
+ rhs.
+ - if (rhs.is_empty() == true && this->is_empty() == false) then
+ - returns *this
+ - if (rhs.is_empty() == false && this->is_empty() == true) then
+ - returns rhs
+ - if (rhs.is_empty() == true && this->is_empty() == true) then
+ - returns a rectangle that has is_empty() == true
+ !*/
+
+ drectangle intersect (
+ const drectangle& rhs
+ ) const;
+ /*!
+ ensures
+ - if (there is a region of intersection between *this and rhs) then
+ - returns a rectangle that represents the intersection of *this
+ and rhs.
+ - else
+ - returns a rectangle where is_empty() == true
+ !*/
+
+ bool contains (
+ const vector<double,2>& p
+ ) const;
+ /*!
+ ensures
+ - if (the point (p.x(),p.y()) is contained in this rectangle) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool contains (
+ const drectangle& rect
+ ) const
+ /*!
+ ensures
+ - if (rect + *this == *this) then
+ - returns true
+ (i.e. returns true if *this contains the given rectangle)
+ - else
+ - returns false
+ !*/
+
+ drectangle& operator *= (
+ const double& scale
+ );
+ /*!
+ ensures
+ - performs: *this = *this*scale;
+ - returns #*this
+ !*/
+
+ drectangle& operator /= (
+ const double& scale
+ );
+ /*!
+ requires
+ - scale != 0
+ ensures
+ - performs: *this = *this*(1.0/scale);
+ - returns #*this
+ !*/
+
+ drectangle& operator += (
+ const dlib::vector<double,2>& p
+ );
+ /*!
+ ensures
+ - performs: *this = *this + drectangle(p)
+ - returns #*this
+ !*/
+
+ bool operator== (
+ const drectangle& rect
+ ) const;
+ /*!
+ ensures
+ - if (top() == rect.top() && left() == rect.left() &&
+ right() == rect.right() && bottom() == rect.bottom()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator!= (
+ const drectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns !(*this == rect)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const drectangle& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ drectangle& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const drectangle& item
+ );
+ /*!
+ ensures
+ - writes item to out in the form "[(left, top) (right, bottom)]"
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& operator>>(
+ std::istream& in,
+ drectangle& item
+ );
+ /*!
+ ensures
+ - reads a drectangle from the input stream in and stores it in #item. The data
+ in the input stream should be of the form [(left, top) (right, bottom)]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ vector<double,2> center (
+ const drectangle& rect
+ );
+ /*!
+ ensures
+ - returns the center of the given rectangle
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ vector<double,2> dcenter (
+ const drectangle& rect
+ );
+ /*!
+ ensures
+ - returns the center of the given rectangle. (Both center() and dcenter() are
+ identical when applied to drectangle objects)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle operator* (
+ const drectangle& rect,
+ const double& scale
+ );
+ /*!
+ ensures
+ - This function returns a rectangle that has the same center as rect but with
+ dimensions that are scale times larger. That is, we return a new rectangle R
+ such that:
+ - center(R) == center(rect)
+ - R.right()-R.left() == (rect.right()-rect.left())*scale
+ - R.bottom()-R.top() == (rect.bottom()-rect.top())*scale
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle operator* (
+ const double& scale,
+ const drectangle& rect
+ );
+ /*!
+ ensures
+ - returns rect*scale
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle operator/ (
+ const drectangle& rect,
+ const double& scale
+ );
+ /*!
+ ensures
+ - returns rect*(1.0/scale)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle operator+ (
+ const drectangle& r,
+ const vector<double,2>& p
+ );
+ /*!
+ ensures
+ - returns r + drectangle(p)
+ (i.e. returns the rectangle that contains both r and p)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle operator+ (
+ const vector<double,2>& p,
+ const drectangle& r
+ );
+ /*!
+ ensures
+ - returns r + drectangle(p)
+ (i.e. returns the rectangle that contains both r and p)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ drectangle translate_rect (
+ const drectangle& rect,
+ const vector<T,2>& p
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.left() == rect.left() + p.x()
+ - R.right() == rect.right() + p.x()
+ - R.top() == rect.top() + p.y()
+ - R.bottom() == rect.bottom() + p.y()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle intersect (
+ const drectangle& a,
+ const drectangle& b
+ );
+ /*!
+ ensures
+ - returns a.intersect(b)
+ (i.e. returns a rectangle representing the intersection of a and b)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double area (
+ const drectangle& a
+ );
+ /*!
+ ensures
+ - returns a.area()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle centered_drect (
+ const vector<double,2>& p,
+ double width,
+ double height
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - center(R) == p
+ - if (width < 1 || height < 1)
+ - R.width() == 0
+ - R.height() == 0
+ - R.is_empty() == true
+ - else
+ - R.width() == width
+ - R.height() == height
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle centered_drect (
+ const drectangle& rect,
+ double width,
+ double height
+ );
+ /*!
+ ensures
+ - returns centered_drect(center(rect), width, height)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const drectangle shrink_rect (
+ const drectangle& rect,
+ double num
+ );
+ /*!
+ ensures
+ - returns drectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num)
+ (i.e. shrinks the given drectangle by shrinking its border by num)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const drectangle grow_rect (
+ const drectangle& rect,
+ double num
+ );
+ /*!
+ ensures
+ - return shrink_rect(rect, -num)
+ (i.e. grows the given drectangle by expanding its border by num)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const drectangle shrink_rect (
+ const drectangle& rect,
+ double width,
+ double height
+ );
+ /*!
+ ensures
+ - returns drectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height)
+ (i.e. shrinks the given drectangle by shrinking its left and right borders by width
+ and its top and bottom borders by height. )
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const drectangle grow_rect (
+ const drectangle& rect,
+ double width,
+ double height
+ );
+ /*!
+ ensures
+ - return shrink_rect(rect, -width, -height)
+ (i.e. grows the given drectangle by expanding its border)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle set_rect_area (
+ const drectangle& rect,
+ double area
+ );
+ /*!
+ requires
+ - area >= 0
+ ensures
+ - Returns a rectangle R such that:
+ - center(R) == center(rect)
+ - R has the same aspect ratio as rect. If rect.area() == 0 then the
+ returned rect has a 1:1 aspect ratio.
+ - R.area() == area
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ drectangle set_aspect_ratio (
+ const drectangle& rect,
+ double ratio
+ );
+ /*!
+ requires
+ - ratio > 0
+ ensures
+ - This function reshapes the given rectangle so that it has the given aspect
+ ratio. In particular, this means we return a rectangle R such that the
+ following equations are true:
+ - R.width()/R.height() == ratio
+ - R.area() == rect.area()
+ - dcenter(rect) == dcenter(R)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRECTANGLe_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/geometry/point_transforms.h b/ml/dlib/dlib/geometry/point_transforms.h
new file mode 100644
index 000000000..e789fd2a2
--- /dev/null
+++ b/ml/dlib/dlib/geometry/point_transforms.h
@@ -0,0 +1,989 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_POINT_TrANSFORMS_H_
+#define DLIB_POINT_TrANSFORMS_H_
+
+#include "point_transforms_abstract.h"
+#include "../algs.h"
+#include "vector.h"
+#include "../matrix.h"
+#include "../matrix/matrix_la.h"
+#include "../optimization/optimization.h"
+#include "rectangle.h"
+#include "drectangle.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class point_rotator
+ {
+ public:
+ point_rotator (
+ )
+ {
+ sin_angle = 0;
+ cos_angle = 1;
+ }
+
+ point_rotator (
+ const double& angle
+ )
+ {
+ sin_angle = std::sin(angle);
+ cos_angle = std::cos(angle);
+ }
+
+ template <typename T>
+ const dlib::vector<T,2> operator() (
+ const dlib::vector<T,2>& p
+ ) const
+ {
+ double x = cos_angle*p.x() - sin_angle*p.y();
+ double y = sin_angle*p.x() + cos_angle*p.y();
+
+ return dlib::vector<double,2>(x,y);
+ }
+
+ const matrix<double,2,2> get_m(
+ ) const
+ {
+ matrix<double,2,2> temp;
+ temp = cos_angle, -sin_angle,
+ sin_angle, cos_angle;
+ return temp;
+ }
+
+ inline friend void serialize (const point_rotator& item, std::ostream& out)
+ {
+ serialize(item.sin_angle, out);
+ serialize(item.cos_angle, out);
+ }
+
+ inline friend void deserialize (point_rotator& item, std::istream& in)
+ {
+ deserialize(item.sin_angle, in);
+ deserialize(item.cos_angle, in);
+ }
+
+ private:
+ double sin_angle;
+ double cos_angle;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform
+ {
+ public:
+
+ point_transform (
+ )
+ {
+ sin_angle = 0;
+ cos_angle = 1;
+ translate.x() = 0;
+ translate.y() = 0;
+ }
+
+ point_transform (
+ const double& angle,
+ const dlib::vector<double,2>& translate_
+ )
+ {
+ sin_angle = std::sin(angle);
+ cos_angle = std::cos(angle);
+ translate = translate_;
+ }
+
+ template <typename T>
+ const dlib::vector<T,2> operator() (
+ const dlib::vector<T,2>& p
+ ) const
+ {
+ double x = cos_angle*p.x() - sin_angle*p.y();
+ double y = sin_angle*p.x() + cos_angle*p.y();
+
+ return dlib::vector<double,2>(x,y) + translate;
+ }
+
+ const matrix<double,2,2> get_m(
+ ) const
+ {
+ matrix<double,2,2> temp;
+ temp = cos_angle, -sin_angle,
+ sin_angle, cos_angle;
+ return temp;
+ }
+
+ const dlib::vector<double,2> get_b(
+ ) const { return translate; }
+
+ inline friend void serialize (const point_transform& item, std::ostream& out)
+ {
+ serialize(item.sin_angle, out);
+ serialize(item.cos_angle, out);
+ serialize(item.translate, out);
+ }
+
+ inline friend void deserialize (point_transform& item, std::istream& in)
+ {
+ deserialize(item.sin_angle, in);
+ deserialize(item.cos_angle, in);
+ deserialize(item.translate, in);
+ }
+
+ private:
+ double sin_angle;
+ double cos_angle;
+ dlib::vector<double,2> translate;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_affine
+ {
+ public:
+
+ point_transform_affine (
+ )
+ {
+ m = identity_matrix<double>(2);
+ b.x() = 0;
+ b.y() = 0;
+ }
+
+ point_transform_affine (
+ const matrix<double,2,2>& m_,
+ const dlib::vector<double,2>& b_
+ ) :m(m_), b(b_)
+ {
+ }
+
+ const dlib::vector<double,2> operator() (
+ const dlib::vector<double,2>& p
+ ) const
+ {
+ return m*p + b;
+ }
+
+ const matrix<double,2,2>& get_m(
+ ) const { return m; }
+
+ const dlib::vector<double,2>& get_b(
+ ) const { return b; }
+
+ inline friend void serialize (const point_transform_affine& item, std::ostream& out)
+ {
+ serialize(item.m, out);
+ serialize(item.b, out);
+ }
+
+ inline friend void deserialize (point_transform_affine& item, std::istream& in)
+ {
+ deserialize(item.m, in);
+ deserialize(item.b, in);
+ }
+
+ private:
+ matrix<double,2,2> m;
+ dlib::vector<double,2> b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class rectangle_transform
+ {
+ public:
+
+ rectangle_transform (
+ )
+ {
+ }
+
+ rectangle_transform (
+ const point_transform_affine& tform_
+ ) :tform(tform_)
+ {
+ }
+
+ drectangle operator() (
+ const drectangle& r
+ ) const
+ {
+ dpoint tl = r.tl_corner();
+ dpoint tr = r.tr_corner();
+ dpoint bl = r.bl_corner();
+ dpoint br = r.br_corner();
+ // The new rectangle wouold ideally have this area if we could actually rotrate
+ // the box.
+ double new_area = length(tform(tl)-tform(tr))*length(tform(tl)-tform(bl));
+
+ // But if we rotate the coners of the rectangle and then find the rectangle
+ // that contains them we get this, which might have a much larger area than we
+ // want.
+ drectangle temp;
+ temp += tform(tl);
+ temp += tform(tr);
+ temp += tform(bl);
+ temp += tform(br);
+ // so we adjust the area to match the target area and have the same center as
+ // the above box.
+ double scale = std::sqrt(new_area/temp.area());
+
+ return centered_rect(center(temp), static_cast<long>(temp.width()*scale+0.5), static_cast<long>(temp.height()*scale+0.5));
+ }
+
+ rectangle operator() (
+ const rectangle& r
+ ) const
+ {
+ return (*this)(drectangle(r));
+ }
+
+ const point_transform_affine& get_tform(
+ ) const { return tform; }
+
+ inline friend void serialize (const rectangle_transform& item, std::ostream& out)
+ {
+ serialize(item.tform, out);
+ }
+
+ inline friend void deserialize (rectangle_transform& item, std::istream& in)
+ {
+ deserialize(item.tform, in);
+ }
+
+ private:
+ point_transform_affine tform;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine operator* (
+ const point_transform_affine& lhs,
+ const point_transform_affine& rhs
+ )
+ {
+ return point_transform_affine(lhs.get_m()*rhs.get_m(), lhs.get_m()*rhs.get_b()+lhs.get_b());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine inv (
+ const point_transform_affine& trans
+ )
+ {
+ matrix<double,2,2> im = inv(trans.get_m());
+ return point_transform_affine(im, -im*trans.get_b());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ point_transform_affine find_affine_transform (
+ const std::vector<dlib::vector<T,2> >& from_points,
+ const std::vector<dlib::vector<T,2> >& to_points
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(from_points.size() == to_points.size() &&
+ from_points.size() >= 3,
+ "\t point_transform_affine find_affine_transform(from_points, to_points)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t from_points.size(): " << from_points.size()
+ << "\n\t to_points.size(): " << to_points.size()
+ );
+
+ matrix<double,3,0> P(3, from_points.size());
+ matrix<double,2,0> Q(2, from_points.size());
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ P(0,i) = from_points[i].x();
+ P(1,i) = from_points[i].y();
+ P(2,i) = 1;
+
+ Q(0,i) = to_points[i].x();
+ Q(1,i) = to_points[i].y();
+ }
+
+ const matrix<double,2,3> m = Q*pinv(P);
+ return point_transform_affine(subm(m,0,0,2,2), colm(m,2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ point_transform_affine find_similarity_transform (
+ const std::vector<dlib::vector<T,2> >& from_points,
+ const std::vector<dlib::vector<T,2> >& to_points
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(from_points.size() == to_points.size() &&
+ from_points.size() >= 2,
+ "\t point_transform_affine find_similarity_transform(from_points, to_points)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t from_points.size(): " << from_points.size()
+ << "\n\t to_points.size(): " << to_points.size()
+ );
+
+ // We use the formulas from the paper: Least-squares estimation of transformation
+ // parameters between two point patterns by Umeyama. They are equations 34 through
+ // 43.
+
+ dlib::vector<double,2> mean_from, mean_to;
+ double sigma_from = 0, sigma_to = 0;
+ matrix<double,2,2> cov;
+ cov = 0;
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ mean_from += from_points[i];
+ mean_to += to_points[i];
+ }
+ mean_from /= from_points.size();
+ mean_to /= from_points.size();
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ sigma_from += length_squared(from_points[i] - mean_from);
+ sigma_to += length_squared(to_points[i] - mean_to);
+ cov += (to_points[i] - mean_to)*trans(from_points[i] - mean_from);
+ }
+
+ sigma_from /= from_points.size();
+ sigma_to /= from_points.size();
+ cov /= from_points.size();
+
+ matrix<double,2,2> u, v, s, d;
+ svd(cov, u,d,v);
+ s = identity_matrix(cov);
+ if (det(cov) < 0 || (det(cov) == 0 && det(u)*det(v)<0))
+ {
+ if (d(1,1) < d(0,0))
+ s(1,1) = -1;
+ else
+ s(0,0) = -1;
+ }
+
+ matrix<double,2,2> r = u*s*trans(v);
+ double c = 1;
+ if (sigma_from != 0)
+ c = 1.0/sigma_from * trace(d*s);
+ vector<double,2> t = mean_to - c*r*mean_from;
+
+ return point_transform_affine(c*r, t);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_projective
+ {
+ public:
+
+ point_transform_projective (
+ )
+ {
+ m = identity_matrix<double>(3);
+ }
+
+ point_transform_projective (
+ const matrix<double,3,3>& m_
+ ) :m(m_)
+ {
+ }
+
+ point_transform_projective (
+ const point_transform_affine& tran
+ )
+ {
+ set_subm(m, 0,0, 2,2) = tran.get_m();
+ set_subm(m, 0,2, 2,1) = tran.get_b();
+ m(2,0) = 0;
+ m(2,1) = 0;
+ m(2,2) = 1;
+ }
+
+
+ const dlib::vector<double,2> operator() (
+ const dlib::vector<double,2>& p
+ ) const
+ {
+ dlib::vector<double,3> temp(p);
+ temp.z() = 1;
+ temp = m*temp;
+ if (temp.z() != 0)
+ temp = temp/temp.z();
+ return temp;
+ }
+
+ const matrix<double,3,3>& get_m(
+ ) const { return m; }
+
+ inline friend void serialize (const point_transform_projective& item, std::ostream& out)
+ {
+ serialize(item.m, out);
+ }
+
+ inline friend void deserialize (point_transform_projective& item, std::istream& in)
+ {
+ deserialize(item.m, in);
+ }
+
+ private:
+ matrix<double,3,3> m;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_projective operator* (
+ const point_transform_projective& lhs,
+ const point_transform_projective& rhs
+ )
+ {
+ return point_transform_projective(lhs.get_m()*rhs.get_m());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_projective inv (
+ const point_transform_projective& trans
+ )
+ {
+ return point_transform_projective(inv(trans.get_m()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_proj
+ {
+
+ inline point_transform_projective find_projective_transform_basic (
+ const std::vector<dlib::vector<double,2> >& from_points,
+ const std::vector<dlib::vector<double,2> >& to_points
+ )
+ /*!
+ ensures
+ - Uses the system of equations approach to finding a projective transform.
+ This is "Method 3" from Estimating Projective Transformation Matrix by
+ Zhengyou Zhang.
+ - It should be emphasized that the find_projective_transform_basic()
+ routine, which uses the most popular method for finding projective
+ transformations, doesn't really work well when the minimum error solution
+ doesn't have zero error. In this case, it can deviate by a large amount
+ from the proper minimum mean squared error transformation. Therefore,
+ our overall strategy will be to use the solution from
+ find_projective_transform_basic() as a starting point for a BFGS based
+ non-linear optimizer which will optimize the correct mean squared error
+ criterion.
+
+ A great essay on this subject is Homography Estimation by Elan Dubrofsky.
+ !*/
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(from_points.size() == to_points.size() &&
+ from_points.size() >= 4,
+ "\t point_transform_projective find_projective_transform_basic(from_points, to_points)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t from_points.size(): " << from_points.size()
+ << "\n\t to_points.size(): " << to_points.size()
+ );
+
+ matrix<double,9,9> accum, u, v;
+ matrix<double,9,1> w;
+ matrix<double,2,9> B;
+ accum = 0;
+ B = 0;
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ dlib::vector<double,3> f = from_points[i];
+ f.z() = 1;
+ dlib::vector<double,3> t = to_points[i];
+ t.z() = 1;
+
+ set_subm(B,0,0,1,3) = t.y()*trans(f);
+ set_subm(B,1,0,1,3) = trans(f);
+
+ set_subm(B,0,3,1,3) = -t.x()*trans(f);
+ set_subm(B,1,6,1,3) = -t.x()*trans(f);
+
+ accum += trans(B)*B;
+ }
+
+ svd2(true, false, accum, u, w, v);
+ long j = index_of_min(w);
+
+ return point_transform_projective(reshape(colm(u,j),3,3));
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ struct obj
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the objective function we really want to minimize when looking
+ for a transformation matrix. That is, we would like the transformed
+ points to be as close as possible to their "to" points. Here,
+ closeness is measured using Euclidean distance.
+
+ !*/
+ obj(
+ const std::vector<dlib::vector<double,2> >& from_points_,
+ const std::vector<dlib::vector<double,2> >& to_points_
+ ) :
+ from_points(from_points_) ,
+ to_points(to_points_)
+ {}
+ const std::vector<dlib::vector<double,2> >& from_points;
+ const std::vector<dlib::vector<double,2> >& to_points;
+
+ double operator() (
+ const matrix<double,9,1>& p
+ ) const
+ {
+ point_transform_projective tran(reshape(p,3,3));
+
+ double sum = 0;
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ sum += length_squared(tran(from_points[i]) - to_points[i]);
+ }
+ return sum;
+ }
+ };
+
+ struct obj_der
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the derivative of obj.
+ !*/
+
+ obj_der(
+ const std::vector<dlib::vector<double,2> >& from_points_,
+ const std::vector<dlib::vector<double,2> >& to_points_
+ ) :
+ from_points(from_points_) ,
+ to_points(to_points_)
+ {}
+ const std::vector<dlib::vector<double,2> >& from_points;
+ const std::vector<dlib::vector<double,2> >& to_points;
+
+ matrix<double,9,1> operator() (
+ const matrix<double,9,1>& p
+ ) const
+ {
+ const matrix<double,3,3> H = reshape(p,3,3);
+
+ matrix<double,3,3> grad;
+ grad = 0;
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ dlib::vector<double,3> from, to;
+ from = from_points[i];
+ from.z() = 1;
+ to = to_points[i];
+ to.z() = 1;
+
+ matrix<double,3,1> w = H*from;
+ const double scale = (w(2) != 0) ? (1.0/w(2)) : (1);
+ w *= scale;
+ matrix<double,3,1> residual = (w-to)*2*scale;
+
+ grad(0,0) += from.x()*residual(0);
+ grad(0,1) += from.y()*residual(0);
+ grad(0,2) += residual(0);
+
+ grad(1,0) += from.x()*residual(1);
+ grad(1,1) += from.y()*residual(1);
+ grad(1,2) += residual(1);
+
+ grad(2,0) += -(from.x()*w(0)*residual(0) + from.x()*w(1)*residual(1));
+ grad(2,1) += -(from.y()*w(0)*residual(0) + from.y()*w(1)*residual(1));
+ grad(2,2) += -( w(0)*residual(0) + w(1)*residual(1));
+
+ }
+ return reshape_to_column_vector(grad);
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_projective find_projective_transform (
+ const std::vector<dlib::vector<double,2> >& from_points,
+ const std::vector<dlib::vector<double,2> >& to_points
+ )
+ {
+ using namespace impl_proj;
+ // make sure requires clause is not broken
+ DLIB_ASSERT(from_points.size() == to_points.size() &&
+ from_points.size() >= 4,
+ "\t point_transform_projective find_projective_transform(from_points, to_points)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t from_points.size(): " << from_points.size()
+ << "\n\t to_points.size(): " << to_points.size()
+ );
+
+
+ // Find a candidate projective transformation. Also, find the best affine
+ // transform and then compare it with the projective transform estimated using the
+ // direct SVD method. Use whichever one works better as the starting point for a
+ // BFGS based optimizer. If the best solution has large mean squared error and is
+ // also close to affine then find_projective_transform_basic() might give a very
+ // bad initial guess. So also checking for a good affine transformation can
+ // produce a much better final result in many cases.
+ point_transform_projective tran1 = find_projective_transform_basic(from_points, to_points);
+ point_transform_affine tran2 = find_affine_transform(from_points, to_points);
+
+ // check which is best
+ double error1 = 0;
+ double error2 = 0;
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ error1 += length_squared(tran1(from_points[i])-to_points[i]);
+ error2 += length_squared(tran2(from_points[i])-to_points[i]);
+ }
+ matrix<double,9,1> params;
+ // Pick the minimum error solution among the two so far.
+ if (error1 < error2)
+ params = reshape_to_column_vector(tran1.get_m());
+ else
+ params = reshape_to_column_vector(point_transform_projective(tran2).get_m());
+
+
+ // Now refine the transformation matrix so that we can be sure we have
+ // at least a local minimizer.
+ obj o(from_points, to_points);
+ obj_der der(from_points, to_points);
+ find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(1e-6,100),
+ o,
+ der,
+ params,
+ 0);
+
+ return point_transform_projective(reshape(params,3,3));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const dlib::vector<T,2> rotate_point (
+ const dlib::vector<T,2>& center,
+ const dlib::vector<T,2>& p,
+ double angle
+ )
+ {
+ point_rotator rot(angle);
+ return rot(p-center)+center;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline matrix<double,2,2> rotation_matrix (
+ double angle
+ )
+ {
+ const double ca = std::cos(angle);
+ const double sa = std::sin(angle);
+
+ matrix<double,2,2> m;
+ m = ca, -sa,
+ sa, ca;
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_affine3d
+ {
+ public:
+
+ point_transform_affine3d (
+ )
+ {
+ m = identity_matrix<double>(3);
+ b.x() = 0;
+ b.y() = 0;
+ }
+
+ point_transform_affine3d (
+ const matrix<double,3,3>& m_,
+ const dlib::vector<double,3>& b_
+ ) :m(m_), b(b_)
+ {
+ }
+
+ const dlib::vector<double,3> operator() (
+ const dlib::vector<double,3>& p
+ ) const
+ {
+ return m*p + b;
+ }
+
+ const matrix<double,3,3>& get_m(
+ ) const { return m; }
+
+ const dlib::vector<double,3>& get_b(
+ ) const { return b; }
+
+ inline friend void serialize (const point_transform_affine3d& item, std::ostream& out)
+ {
+ serialize(item.m, out);
+ serialize(item.b, out);
+ }
+
+ inline friend void deserialize (point_transform_affine3d& item, std::istream& in)
+ {
+ deserialize(item.m, in);
+ deserialize(item.b, in);
+ }
+
+ private:
+ matrix<double,3,3> m;
+ dlib::vector<double,3> b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d operator* (
+ const point_transform_affine3d& lhs,
+ const point_transform_affine& rhs
+ )
+ {
+ matrix<double,3,3> m;
+ m = 0;
+ set_subm(m, get_rect(rhs.get_m())) = rhs.get_m();
+ vector<double,3> b = rhs.get_b();
+
+ return point_transform_affine3d(lhs.get_m()*m, lhs.get_m()*b+lhs.get_b());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d operator* (
+ const point_transform_affine3d& lhs,
+ const point_transform_affine3d& rhs
+ )
+ {
+ return point_transform_affine3d(lhs.get_m()*rhs.get_m(), lhs.get_m()*rhs.get_b()+lhs.get_b());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d inv (
+ const point_transform_affine3d& trans
+ )
+ {
+ matrix<double,3,3> im = inv(trans.get_m());
+ return point_transform_affine3d(im, -im*trans.get_b());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d rotate_around_x (
+ double angle
+ )
+ {
+ const double ca = std::cos(angle);
+ const double sa = std::sin(angle);
+
+ matrix<double,3,3> m;
+ m = 1, 0, 0,
+ 0, ca, -sa,
+ 0, sa, ca;
+
+ vector<double,3> b;
+
+ return point_transform_affine3d(m,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d rotate_around_y (
+ double angle
+ )
+ {
+ const double ca = std::cos(angle);
+ const double sa = std::sin(angle);
+
+ matrix<double,3,3> m;
+ m = ca, 0, sa,
+ 0, 1, 0,
+ -sa, 0, ca;
+
+ vector<double,3> b;
+
+ return point_transform_affine3d(m,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d rotate_around_z (
+ double angle
+ )
+ {
+ const double ca = std::cos(angle);
+ const double sa = std::sin(angle);
+
+ matrix<double,3,3> m;
+ m = ca, -sa, 0,
+ sa, ca, 0,
+ 0, 0, 1;
+
+ vector<double,3> b;
+
+ return point_transform_affine3d(m,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine3d translate_point (
+ const vector<double,3>& delta
+ )
+ {
+ return point_transform_affine3d(identity_matrix<double>(3),delta);
+ }
+
+ inline point_transform_affine3d translate_point (
+ double x,
+ double y,
+ double z
+ )
+ {
+ return translate_point(vector<double>(x,y,z));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class camera_transform
+ {
+
+ public:
+
+ camera_transform (
+ )
+ {
+ *this = camera_transform(vector<double>(1,1,1),
+ vector<double>(0,0,0),
+ vector<double>(0,0,1),
+ 90,
+ 1);
+ }
+
+ camera_transform (
+ const vector<double>& camera_pos_,
+ const vector<double>& camera_looking_at_,
+ const vector<double>& camera_up_direction_,
+ const double camera_field_of_view_,
+ const unsigned long num_pixels_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(0 < camera_field_of_view_ && camera_field_of_view_ < 180,
+ "\t camera_transform::camera_transform()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t camera_field_of_view_: " << camera_field_of_view_
+ );
+
+ camera_pos = camera_pos_;
+ camera_looking_at = camera_looking_at_;
+ camera_up_direction = camera_up_direction_;
+ camera_field_of_view = camera_field_of_view_;
+ num_pixels = num_pixels_;
+
+ dlib::vector<double> X,Y,Z;
+ Z = (camera_looking_at - camera_pos).normalize();
+ Y = camera_up_direction - dot(camera_up_direction,Z)*Z;
+ Y = Y.normalize();
+ X = Z.cross(Y);
+
+ set_rowm(proj,0) = trans(X);
+ // Minus because images have y axis going down but we want the 3d projection to appear using a normal coordinate system with y going up.
+ set_rowm(proj,1) = -trans(Y);
+ set_rowm(proj,2) = trans(Z);
+
+ width = num_pixels/2.0;
+ dist_scale = width/std::tan(pi/180*camera_field_of_view/2);
+ }
+
+ vector<double> get_camera_pos() const { return camera_pos; }
+ vector<double> get_camera_looking_at() const { return camera_looking_at; }
+ vector<double> get_camera_up_direction()const { return camera_up_direction; }
+ double get_camera_field_of_view() const { return camera_field_of_view; }
+ unsigned long get_num_pixels() const { return num_pixels; }
+
+ inline dpoint operator() (
+ const vector<double>& p,
+ double& scale,
+ double& distance
+ ) const
+ {
+ vector<double> temp = p-camera_pos;
+ temp = proj*temp;
+ distance = temp.z();
+ scale = dist_scale/(temp.z()>0 ? temp.z() : 1e-9);
+ temp.x() = temp.x()*scale + width;
+ temp.y() = temp.y()*scale + width;
+ return temp;
+ }
+
+ dpoint operator() (
+ const vector<double>& p
+ ) const
+ {
+ double scale, distance;
+ return (*this)(p,scale,distance);
+ }
+
+ inline friend void serialize (const camera_transform& item, std::ostream& out)
+ {
+ serialize(item.camera_pos, out);
+ serialize(item.camera_looking_at, out);
+ serialize(item.camera_up_direction, out);
+ serialize(item.camera_field_of_view, out);
+ serialize(item.num_pixels, out);
+ serialize(item.proj, out);
+ serialize(item.dist_scale, out);
+ serialize(item.width, out);
+ }
+
+ inline friend void deserialize (camera_transform& item, std::istream& in)
+ {
+ deserialize(item.camera_pos, in);
+ deserialize(item.camera_looking_at, in);
+ deserialize(item.camera_up_direction, in);
+ deserialize(item.camera_field_of_view, in);
+ deserialize(item.num_pixels, in);
+ deserialize(item.proj, in);
+ deserialize(item.dist_scale, in);
+ deserialize(item.width, in);
+ }
+
+ private:
+
+ vector<double> camera_pos;
+ vector<double> camera_looking_at;
+ vector<double> camera_up_direction;
+ double camera_field_of_view;
+ unsigned long num_pixels;
+ matrix<double,3,3> proj;
+ double dist_scale;
+ double width;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_POINT_TrANSFORMS_H_
+
diff --git a/ml/dlib/dlib/geometry/point_transforms_abstract.h b/ml/dlib/dlib/geometry/point_transforms_abstract.h
new file mode 100644
index 000000000..492ae745c
--- /dev/null
+++ b/ml/dlib/dlib/geometry/point_transforms_abstract.h
@@ -0,0 +1,797 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_
+#ifdef DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include "vector_abstract.h"
+#include "rectangle_abstract.h"
+#include "drectangle_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_affine
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an object that takes 2D points or vectors and
+ applies an affine transformation to them.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+ public:
+
+ point_transform_affine (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a point
+ as input it will return the same point as output.
+ !*/
+
+ point_transform_affine (
+ const matrix<double,2,2>& m,
+ const dlib::vector<double,2>& b
+ );
+ /*!
+ ensures
+ - #get_m() == m
+ - #get_b() == b
+ - When (*this)(p) is invoked it will return a point P such that:
+ - P == m*p + b
+ !*/
+
+ const dlib::vector<double,2> operator() (
+ const dlib::vector<double,2>& p
+ ) const;
+ /*!
+ ensures
+ - applies the affine transformation defined by this object's constructor
+ to p and returns the result.
+ !*/
+
+ const matrix<double,2,2>& get_m(
+ ) const;
+ /*!
+ ensures
+ - returns the transformation matrix used by this object.
+ !*/
+
+ const dlib::vector<double,2>& get_b(
+ ) const;
+ /*!
+ ensures
+ - returns the offset vector used by this object.
+ !*/
+
+ };
+
+ void serialize (const point_transform_affine& item, std::ostream& out);
+ void deserialize (point_transform_affine& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine operator* (
+ const point_transform_affine& lhs,
+ const point_transform_affine& rhs
+ );
+ /*!
+ ensures
+ - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That
+ is, for all valid x: TFORM(x) == lhs(rhs(x)).
+ !*/
+
+ // ----------------------------------------------------------------------------------------
+
+ point_transform_affine inv (
+ const point_transform_affine& trans
+ );
+ /*!
+ ensures
+ - If trans is an invertible transformation then this function returns a new
+ transformation that is the inverse of trans.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ point_transform_affine find_affine_transform (
+ const std::vector<dlib::vector<T,2> >& from_points,
+ const std::vector<dlib::vector<T,2> >& to_points
+ );
+ /*!
+ requires
+ - from_points.size() == to_points.size()
+ - from_points.size() >= 3
+ ensures
+ - returns a point_transform_affine object, T, such that for all valid i:
+ length(T(from_points[i]) - to_points[i])
+ is minimized as often as possible. That is, this function finds the affine
+ transform that maps points in from_points to points in to_points. If no
+ affine transform exists which performs this mapping exactly then the one
+ which minimizes the mean squared error is selected. Additionally, if many
+ equally good transformations exist, then the transformation with the smallest
+ squared parameters is selected (i.e. if you wrote the transformation as a
+ matrix then we say we select the transform with minimum Frobenius norm among
+ all possible solutions).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ point_transform_affine find_similarity_transform (
+ const std::vector<dlib::vector<T,2> >& from_points,
+ const std::vector<dlib::vector<T,2> >& to_points
+ );
+ /*!
+ requires
+ - from_points.size() == to_points.size()
+ - from_points.size() >= 2
+ ensures
+ - This function is just like find_affine_transform() except it finds the best
+ similarity transform instead of a full affine transform. This means that it
+ optimizes over only the space of rotations, scale changes, and translations.
+ So for example, if you mapped the 3 vertices of a triangle through a
+ similarity transform then the output would still be the same triangle.
+ However, the triangle itself may be larger or smaller, rotated, or at a
+ different location in the coordinate system. This is not the case for a
+ general affine transform which can stretch points in ways that cause, for
+ example, an equilateral triangle to turn into an isosceles triangle.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class rectangle_transform
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is just a point_transform_affine wrapped up so that it can
+ transform rectangle objects. It will take a rectangle and transform it
+ according to an affine transformation.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+ public:
+
+ rectangle_transform (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a rectangle
+ as input it will return the same rectangle as output.
+ !*/
+
+ rectangle_transform (
+ const point_transform_affine& tform
+ );
+ /*!
+ ensures
+ - #get_tform() == tform
+ !*/
+
+ drectangle operator() (
+ const drectangle& r
+ ) const;
+ /*!
+ ensures
+ - Applies the transformation get_tform() to r and returns the resulting
+ rectangle. If the transformation doesn't have any rotation then the
+ transformation simply maps the corners of the rectangle according to
+ get_tform() and returns the exact result. However, since
+ dlib::drectangle can't represent rotated rectangles, if there is any
+ rotation in the affine transform we will attempt to produce the most
+ faithful possible outputs by ensuring the output rectangle has the
+ correct center point and that its area and aspect ratio match the correct
+ rotated rectangle's as much as possible.
+ !*/
+
+ rectangle operator() (
+ const rectangle& r
+ ) const;
+ /*!
+ ensures
+ - returns (*this)(drectangle(r))
+ !*/
+
+ const point_transform_affine& get_tform(
+ ) const;
+ /*!
+ ensures
+ - returns the affine transformation this object uses to transform rectangles.
+ !*/
+
+ };
+
+ void serialize (const rectangle_transform& item, std::ostream& out);
+ void deserialize (rectangle_transform& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_projective
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an object that takes 2D points or vectors and
+ applies a projective transformation to them.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+
+ public:
+
+ point_transform_projective (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a point
+ as input it will return the same point as output.
+ !*/
+
+ point_transform_projective (
+ const matrix<double,3,3>& m
+ );
+ /*!
+ ensures
+ - #get_m() == m
+ !*/
+
+ point_transform_projective (
+ const point_transform_affine& tran
+ );
+ /*!
+ ensures
+ - This object will perform exactly the same transformation as the given
+ affine transform.
+ !*/
+
+ const dlib::vector<double,2> operator() (
+ const dlib::vector<double,2>& p
+ ) const;
+ /*!
+ ensures
+ - Applies the projective transformation defined by this object's constructor
+ to p and returns the result. To define this precisely:
+ - let p_h == the point p in homogeneous coordinates. That is:
+ - p_h.x() == p.x()
+ - p_h.y() == p.y()
+ - p_h.z() == 1
+ - let x == get_m()*p_h
+ - Then this function returns the value x/x.z()
+ !*/
+
+ const matrix<double,3,3>& get_m(
+ ) const;
+ /*!
+ ensures
+ - returns the transformation matrix used by this object.
+ !*/
+
+ };
+
+ void serialize (const point_transform_projective& item, std::ostream& out);
+ void deserialize (point_transform_projective& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_projective operator* (
+ const point_transform_projective& lhs,
+ const point_transform_projective& rhs
+ );
+ /*!
+ ensures
+ - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That
+ is, for all valid x: TFORM(x) == lhs(rhs(x)).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_projective inv (
+ const point_transform_projective& trans
+ );
+ /*!
+ ensures
+ - If trans is an invertible transformation then this function returns a new
+ transformation that is the inverse of trans.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_projective find_projective_transform (
+ const std::vector<dlib::vector<double,2> >& from_points,
+ const std::vector<dlib::vector<double,2> >& to_points
+ );
+ /*!
+ requires
+ - from_points.size() == to_points.size()
+ - from_points.size() >= 4
+ ensures
+ - returns a point_transform_projective object, T, such that for all valid i:
+ length(T(from_points[i]) - to_points[i])
+ is minimized as often as possible. That is, this function finds the projective
+ transform that maps points in from_points to points in to_points. If no
+ projective transform exists which performs this mapping exactly then the one
+ which minimizes the mean squared error is selected.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an object that takes 2D points or vectors and
+ rotates them around the origin by a given angle and then
+ translates them.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+ public:
+
+ point_transform (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a point
+ as input it will return the same point as output.
+ !*/
+
+ point_transform (
+ const double& angle,
+ const dlib::vector<double,2>& translate
+ )
+ /*!
+ ensures
+ - When (*this)(p) is invoked it will return a point P such that:
+ - P is the point p rotated counter-clockwise around the origin
+ angle radians and then shifted by having translate added to it.
+ (Note that this is counter clockwise with respect to the normal
+ coordinate system with positive y going up and positive x going
+ to the right)
+ !*/
+
+ template <typename T>
+ const dlib::vector<T,2> operator() (
+ const dlib::vector<T,2>& p
+ ) const;
+ /*!
+ ensures
+ - rotates p, then translates it and returns the result. The output
+ of this function is therefore equal to get_m()*p + get_b().
+ !*/
+
+ const matrix<double,2,2> get_m(
+ ) const;
+ /*!
+ ensures
+ - returns the transformation matrix used by this object.
+ !*/
+
+ const dlib::vector<double,2> get_b(
+ ) const;
+ /*!
+ ensures
+ - returns the offset vector used by this object.
+ !*/
+
+ };
+
+ void serialize (const point_transform& item, std::ostream& out);
+ void deserialize (point_transform& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class point_rotator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an object that takes 2D points or vectors and
+ rotates them around the origin by a given angle.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+ public:
+
+ point_rotator (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a point
+ as input it will return the same point as output.
+ !*/
+
+ point_rotator (
+ const double& angle
+ );
+ /*!
+ ensures
+ - When (*this)(p) is invoked it will return a point P such that:
+ - P is the point p rotated counter-clockwise around the origin
+ angle radians.
+ (Note that this is counter clockwise with respect to the normal
+ coordinate system with positive y going up and positive x going
+ to the right)
+ !*/
+
+ template <typename T>
+ const dlib::vector<T,2> operator() (
+ const dlib::vector<T,2>& p
+ ) const;
+ /*!
+ ensures
+ - rotates p and returns the result. The output of this function is
+ therefore equal to get_m()*p.
+ !*/
+
+ const matrix<double,2,2> get_m(
+ ) const;
+ /*!
+ ensures
+ - returns the transformation matrix used by this object.
+ !*/
+ };
+
+ void serialize (const point_rotator& item, std::ostream& out);
+ void deserialize (point_rotator& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const dlib::vector<T,2> rotate_point (
+ const dlib::vector<T,2> center,
+ const dlib::vector<T,2> p,
+ double angle
+ );
+ /*!
+ ensures
+ - returns a point P such that:
+ - P is the point p rotated counter-clockwise around the given
+ center point by angle radians.
+ (Note that this is counter clockwise with respect to the normal
+ coordinate system with positive y going up and positive x going
+ to the right)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ matrix<double,2,2> rotation_matrix (
+ double angle
+ );
+ /*!
+ ensures
+ - returns a rotation matrix which rotates points around the origin in a
+ counter-clockwise direction by angle radians.
+ (Note that this is counter clockwise with respect to the normal
+ coordinate system with positive y going up and positive x going
+ to the right)
+ Or in other words, this function returns a matrix M such that, given a
+ point P, M*P gives a point which is P rotated by angle radians around
+ the origin in a counter-clockwise direction.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class point_transform_affine3d
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an object that takes 3D points or vectors and
+ applies an affine transformation to them.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+ public:
+
+ point_transform_affine3d (
+ );
+ /*!
+ ensures
+ - This object will perform the identity transform. That is, given a point
+ as input it will return the same point as output.
+ !*/
+
+ point_transform_affine3d (
+ const matrix<double,3,3>& m,
+ const dlib::vector<double,3>& b
+ );
+ /*!
+ ensures
+ - #get_m() == m
+ - #get_b() == b
+ - When (*this)(p) is invoked it will return a point P such that:
+ - P == m*p + b
+ !*/
+
+ const dlib::vector<double,3> operator() (
+ const dlib::vector<double,3>& p
+ ) const;
+ /*!
+ ensures
+ - applies the affine transformation defined by this object's constructor
+ to p and returns the result.
+ !*/
+
+ const matrix<double,3,3>& get_m(
+ ) const;
+ /*!
+ ensures
+ - returns the transformation matrix used by this object.
+ !*/
+
+ const dlib::vector<double,3>& get_b(
+ ) const;
+ /*!
+ ensures
+ - returns the offset vector used by this object.
+ !*/
+
+ };
+
+ void serialize (const point_transform_affine3d& item, std::ostream& out);
+ void deserialize (point_transform_affine3d& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d operator* (
+ const point_transform_affine3d& lhs,
+ const point_transform_affine3d& rhs
+ );
+ /*!
+ ensures
+ - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That
+ is, for all valid x: TFORM(x) == lhs(rhs(x)).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d operator* (
+ const point_transform_affine3d& lhs,
+ const point_transform_affine& rhs
+ );
+ /*!
+ ensures
+ - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That
+ is, for all valid x: TFORM(x) == lhs(rhs(x)).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d inv (
+ const point_transform_affine3d& trans
+ );
+ /*!
+ ensures
+ - If trans is an invertible transformation then this function returns a new
+ transformation that is the inverse of trans.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d rotate_around_x (
+ double angle
+ );
+ /*!
+ ensures
+ - Returns a transformation that rotates a point around the x axis in a
+ counter-clockwise direction by angle radians. That is, the rotation appears
+ counter-clockwise when the x axis points toward the observer, the coordinate
+ system is right-handed, and the angle is positive.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d rotate_around_y (
+ double angle
+ );
+ /*!
+ ensures
+ - Returns a transformation that rotates a point around the y axis in a
+ counter-clockwise direction by angle radians. That is, the rotation appears
+ counter-clockwise when the y axis points toward the observer, the coordinate
+ system is right-handed, and the angle is positive.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d rotate_around_z (
+ double angle
+ );
+ /*!
+ ensures
+ - Returns a transformation that rotates a point around the z axis in a
+ counter-clockwise direction by angle radians. That is, the rotation appears
+ counter-clockwise when the z axis points toward the observer, the coordinate
+ system is right-handed, and the angle is positive.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine3d translate_point (
+ const vector<double,3>& delta
+ );
+ /*!
+ ensures
+ - returns a transformation that simply translates points by adding delta to
+ them. That is, this function returns:
+ point_transform_affine3d(identity_matrix<double>(3),delta);
+ !*/
+
+ point_transform_affine3d translate_point (
+ double x,
+ double y,
+ double z
+ );
+ /*!
+ ensures
+ - returns translate_point(vector<double>(x,y,z))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class camera_transform
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object maps 3D points into the image plane of a camera. Therefore,
+ you can use it to compute 2D representations of 3D data from the point of
+ view of some camera in 3D space.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+
+ public:
+
+ camera_transform (
+ );
+ /*!
+ ensures
+ - #get_camera_pos() == vector<double>(1,1,1)
+ - #get_camera_looking_at() == vector<double>(0,0,0)
+ - #get_camera_up_direction() == vector<double>(0,0,1)
+ - #get_camera_field_of_view() == 90
+ - #get_num_pixels() == 1
+ !*/
+
+ camera_transform (
+ const vector<double>& camera_pos,
+ const vector<double>& camera_looking_at,
+ const vector<double>& camera_up_direction,
+ const double camera_field_of_view,
+ const unsigned long num_pixels
+ );
+ /*!
+ requires
+ - 0 < camera_field_of_view < 180
+ ensures
+ - #get_camera_pos() == camera_pos
+ - #get_camera_looking_at() == camera_looking_at
+ - #get_camera_up_direction() == camera_up_direction
+ - #get_camera_field_of_view() == camera_field_of_view
+ - #get_num_pixels() == num_pixels
+ !*/
+
+ dpoint operator() (
+ const vector<double>& p
+ ) const;
+ /*!
+ ensures
+ - Maps the given 3D point p into the 2D image plane defined by the camera
+ parameters given to this object's constructor. The 2D point in the image
+ plane is returned.
+ !*/
+
+ dpoint operator() (
+ const vector<double>& p,
+ double& scale,
+ double& distance
+ ) const;
+ /*!
+ ensures
+ - Maps the given 3D point p into the 2D image plane defined by the camera
+ parameters given to this object's constructor. The 2D point in the image
+ plane is returned.
+ - #scale == a number that tells you how large things are at the point p.
+ Objects further from the camera appear smaller, in particular, they
+ appear #scale times their normal size.
+ - #distance == how far away the point is from the image plane. Objects in
+ front of the camera will have a positive distance and those behind a
+ negative distance.
+ !*/
+
+ vector<double> get_camera_pos(
+ ) const;
+ /*!
+ ensures
+ - returns the position, in 3D space, of the camera. When operator() is
+ invoked it maps 3D points into the image plane of this camera.
+ !*/
+
+ vector<double> get_camera_looking_at(
+ ) const;
+ /*!
+ ensures
+ - returns the point in 3D space the camera is pointed at.
+ !*/
+
+ vector<double> get_camera_up_direction(
+ ) const;
+ /*!
+ ensures
+ - returns a vector that defines what direction is "up" for the camera.
+ This means that as you travel from the bottom of the image plane to the
+ top you will be traveling in the direction of this vector. Note that
+ get_camera_up_direction() doesn't need to be orthogonal to the camera's
+ line of sight (i.e. get_camera_looking_at()-get_camera_pos()), it just
+ needs to not be an exact multiple of the line of sight. Any necessary
+ orthogonalization will be taken care of internally.
+ !*/
+
+ double get_camera_field_of_view(
+ ) const;
+ /*!
+ ensures
+ - returns the field of view of the camera in degrees.
+ !*/
+
+ unsigned long get_num_pixels(
+ ) const;
+ /*!
+ ensures
+ - 3D points that fall within the field of view of the camera are mapped by
+ operator() into the pixel coordinates of a get_num_pixels() by
+ get_num_pixels() image. Therefore, you can use the output of operator()
+ to index into an image. However, you still need to perform bounds
+ checking as there might be 3D points outside the field of view of the
+ camera and those will be mapped to 2D points outside the image.
+ !*/
+
+ };
+
+ void serialize (const camera_transform& item, std::ostream& out);
+ void deserialize (camera_transform& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/geometry/rectangle.h b/ml/dlib/dlib/geometry/rectangle.h
new file mode 100644
index 000000000..3d67ca8c4
--- /dev/null
+++ b/ml/dlib/dlib/geometry/rectangle.h
@@ -0,0 +1,824 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RECTANGLe_
+#define DLIB_RECTANGLe_
+
+#include "rectangle_abstract.h"
+#include "../algs.h"
+#include <algorithm>
+#include <iostream>
+#include "../serialize.h"
+#include "vector.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rectangle
+ {
+ /*!
+ INITIAL VALUE
+ The initial value of this object is defined by its constructor.
+
+ CONVENTION
+ left() == l
+ top() == t
+ right() == r
+ bottom() == b
+ !*/
+
+ public:
+
+ rectangle (
+ long l_,
+ long t_,
+ long r_,
+ long b_
+ ) :
+ l(l_),
+ t(t_),
+ r(r_),
+ b(b_)
+ {}
+
+ rectangle (
+ unsigned long w,
+ unsigned long h
+ ) :
+ l(0),
+ t(0),
+ r(static_cast<long>(w)-1),
+ b(static_cast<long>(h)-1)
+ {
+ DLIB_ASSERT((w > 0 && h > 0) || (w == 0 && h == 0),
+ "\trectangle(width,height)"
+ << "\n\twidth and height must be > 0 or both == 0"
+ << "\n\twidth: " << w
+ << "\n\theight: " << h
+ << "\n\tthis: " << this
+ );
+ }
+
+ rectangle (
+ const point& p
+ ) :
+ l(p.x()),
+ t(p.y()),
+ r(p.x()),
+ b(p.y())
+ {
+ }
+
+ rectangle (
+ const point& p1,
+ const point& p2
+ )
+ {
+ *this = rectangle(p1) + rectangle(p2);
+ }
+
+ template <typename T>
+ rectangle (
+ const vector<T,2>& p1,
+ const vector<T,2>& p2
+ )
+ {
+ *this = rectangle(p1) + rectangle(p2);
+ }
+
+ rectangle (
+ ) :
+ l(0),
+ t(0),
+ r(-1),
+ b(-1)
+ {}
+
+ long top (
+ ) const { return t; }
+
+ long& top (
+ ) { return t; }
+
+ void set_top (
+ long top_
+ ) { t = top_; }
+
+ long left (
+ ) const { return l; }
+
+ long& left (
+ ) { return l; }
+
+ void set_left (
+ long left_
+ ) { l = left_; }
+
+ long right (
+ ) const { return r; }
+
+ long& right (
+ ) { return r; }
+
+ void set_right (
+ long right_
+ ) { r = right_; }
+
+ long bottom (
+ ) const { return b; }
+
+ long& bottom (
+ ) { return b; }
+
+ void set_bottom (
+ long bottom_
+ ) { b = bottom_; }
+
+ const point tl_corner (
+ ) const { return point(left(), top()); }
+
+ const point bl_corner (
+ ) const { return point(left(), bottom()); }
+
+ const point tr_corner (
+ ) const { return point(right(), top()); }
+
+ const point br_corner (
+ ) const { return point(right(), bottom()); }
+
+ unsigned long width (
+ ) const
+ {
+ if (is_empty())
+ return 0;
+ else
+ return r - l + 1;
+ }
+
+ unsigned long height (
+ ) const
+ {
+ if (is_empty())
+ return 0;
+ else
+ return b - t + 1;
+ }
+
+ unsigned long area (
+ ) const
+ {
+ return width()*height();
+ }
+
+ bool is_empty (
+ ) const { return (t > b || l > r); }
+
+ rectangle operator + (
+ const rectangle& rhs
+ ) const
+ {
+ if (rhs.is_empty())
+ return *this;
+ else if (is_empty())
+ return rhs;
+
+ return rectangle (
+ std::min(l,rhs.l),
+ std::min(t,rhs.t),
+ std::max(r,rhs.r),
+ std::max(b,rhs.b)
+ );
+ }
+
+ rectangle intersect (
+ const rectangle& rhs
+ ) const
+ {
+ return rectangle (
+ std::max(l,rhs.l),
+ std::max(t,rhs.t),
+ std::min(r,rhs.r),
+ std::min(b,rhs.b)
+ );
+ }
+
+ bool contains (
+ const point& p
+ ) const
+ {
+ if (p.x() < l || p.x() > r || p.y() < t || p.y() > b)
+ return false;
+ return true;
+ }
+
+ bool contains (
+ long x,
+ long y
+ ) const
+ {
+ if (x < l || x > r || y < t || y > b)
+ return false;
+ return true;
+ }
+
+ bool contains (
+ const rectangle& rect
+ ) const
+ {
+ return (rect + *this == *this);
+ }
+
+ rectangle& operator+= (
+ const point& p
+ )
+ {
+ *this = *this + rectangle(p);
+ return *this;
+ }
+
+ rectangle& operator+= (
+ const rectangle& rect
+ )
+ {
+ *this = *this + rect;
+ return *this;
+ }
+
+ bool operator== (
+ const rectangle& rect
+ ) const
+ {
+ return (l == rect.l) && (t == rect.t) && (r == rect.r) && (b == rect.b);
+ }
+
+ bool operator!= (
+ const rectangle& rect
+ ) const
+ {
+ return !(*this == rect);
+ }
+
+ inline bool operator< (const dlib::rectangle& b) const
+ {
+ if (left() < b.left()) return true;
+ else if (left() > b.left()) return false;
+ else if (top() < b.top()) return true;
+ else if (top() > b.top()) return false;
+ else if (right() < b.right()) return true;
+ else if (right() > b.right()) return false;
+ else if (bottom() < b.bottom()) return true;
+ else if (bottom() > b.bottom()) return false;
+ else return false;
+ }
+
+ private:
+ long l;
+ long t;
+ long r;
+ long b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const rectangle& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.left(),out);
+ serialize(item.top(),out);
+ serialize(item.right(),out);
+ serialize(item.bottom(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type rectangle");
+ }
+ }
+
+ inline void deserialize (
+ rectangle& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.left(),in);
+ deserialize(item.top(),in);
+ deserialize(item.right(),in);
+ deserialize(item.bottom(),in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type rectangle");
+ }
+ }
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const rectangle& item
+ )
+ {
+ out << "[(" << item.left() << ", " << item.top() << ") (" << item.right() << ", " << item.bottom() << ")]";
+ return out;
+ }
+
+ inline std::istream& operator>>(
+ std::istream& in,
+ rectangle& item
+ )
+ {
+ // ignore any whitespace
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+ // now eat the leading '[' character
+ if (in.get() != '[')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+ point p1, p2;
+ in >> p1;
+ in >> p2;
+ item = rectangle(p1) + rectangle(p2);
+
+ // ignore any whitespace
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+ // now eat the trailing ']' character
+ if (in.get() != ']')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ }
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle centered_rect (
+ long x,
+ long y,
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ rectangle result;
+ result.set_left ( x - static_cast<long>(width) / 2 );
+ result.set_top ( y - static_cast<long>(height) / 2 );
+ result.set_right ( result.left() + width - 1 );
+ result.set_bottom ( result.top() + height - 1 );
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle intersect (
+ const rectangle& a,
+ const rectangle& b
+ ) { return a.intersect(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long area (
+ const rectangle& a
+ ) { return a.area(); }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point center (
+ const dlib::rectangle& rect
+ )
+ {
+ point temp(rect.left() + rect.right() + 1,
+ rect.top() + rect.bottom() + 1);
+
+ if (temp.x() < 0)
+ temp.x() -= 1;
+
+ if (temp.y() < 0)
+ temp.y() -= 1;
+
+ return temp/2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline dlib::vector<double,2> dcenter (
+ const dlib::rectangle& rect
+ )
+ {
+ dlib::vector<double,2> temp(rect.left() + rect.right(),
+ rect.top() + rect.bottom());
+
+ return temp/2.0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline long distance_to_rect_edge (
+ const rectangle& rect,
+ const point& p
+ )
+ {
+ using std::max;
+ using std::min;
+ using std::abs;
+
+ const long dist_x = min(abs(p.x()-rect.left()), abs(p.x()-rect.right()));
+ const long dist_y = min(abs(p.y()-rect.top()), abs(p.y()-rect.bottom()));
+
+ if (rect.contains(p))
+ return min(dist_x,dist_y);
+ else if (rect.left() <= p.x() && p.x() <= rect.right())
+ return dist_y;
+ else if (rect.top() <= p.y() && p.y() <= rect.bottom())
+ return dist_x;
+ else
+ return dist_x + dist_y;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const point nearest_point (
+ const rectangle& rect,
+ const point& p
+ )
+ {
+ point temp(p);
+ if (temp.x() < rect.left())
+ temp.x() = rect.left();
+ else if (temp.x() > rect.right())
+ temp.x() = rect.right();
+
+ if (temp.y() < rect.top())
+ temp.y() = rect.top();
+ else if (temp.y() > rect.bottom())
+ temp.y() = rect.bottom();
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline size_t nearest_rect (
+ const std::vector<rectangle>& rects,
+ const point& p
+ )
+ {
+ DLIB_ASSERT(rects.size() > 0);
+ size_t idx = 0;
+ double best_dist = std::numeric_limits<double>::infinity();
+
+ for (size_t i = 0; i < rects.size(); ++i)
+ {
+ if (rects[i].contains(p))
+ {
+ return i;
+ }
+ else
+ {
+ double dist = (nearest_point(rects[i],p)-p).length();
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ idx = i;
+ }
+ }
+ }
+ return idx;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ double distance_to_line (
+ const std::pair<vector<T,2>,vector<T,2> >& line,
+ const vector<U,2>& p
+ )
+ {
+ const vector<double,2> delta = p-line.second;
+ const double along_dist = (line.first-line.second).normalize().dot(delta);
+ return std::sqrt(std::max(0.0,delta.length_squared() - along_dist*along_dist));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void clip_line_to_rectangle (
+ const rectangle& box,
+ point& p1,
+ point& p2
+ )
+ {
+ // Now clip the line segment so it is contained inside box.
+ if (p1.x() == p2.x())
+ {
+ if (!box.contains(p1))
+ p1.y() = box.top();
+ if (!box.contains(p2))
+ p2.y() = box.bottom();
+ }
+ else if (p1.y() == p2.y())
+ {
+ if (!box.contains(p1))
+ p1.x() = box.left();
+ if (!box.contains(p2))
+ p2.x() = box.right();
+ }
+ else
+ {
+ // We use these relations to find alpha values. These values tell us
+ // how to generate points intersecting the rectangle boundaries. We then
+ // test the resulting points for ones that are inside the rectangle and output
+ // those.
+ //box.left() == alpha1*(p1.x()-p2.x()) + p2.x();
+ //box.right() == alpha2*(p1.x()-p2.x()) + p2.x();
+
+ const point d = p1-p2;
+ double alpha1 = (box.left() -p2.x())/(double)d.x();
+ double alpha2 = (box.right() -p2.x())/(double)d.x();
+ double alpha3 = (box.top() -p2.y())/(double)d.y();
+ double alpha4 = (box.bottom()-p2.y())/(double)d.y();
+
+ const point c1 = alpha1*d + p2;
+ const point c2 = alpha2*d + p2;
+ const point c3 = alpha3*d + p2;
+ const point c4 = alpha4*d + p2;
+
+ if (!box.contains(p1))
+ p1 = c1;
+ if (!box.contains(p2))
+ p2 = c2;
+ if (box.contains(c3))
+ {
+ if (!box.contains(p2))
+ p2 = c3;
+ else if (!box.contains(p1))
+ p1 = c3;
+ }
+ if (box.contains(c4))
+ {
+ if (!box.contains(p2))
+ p2 = c4;
+ else if (!box.contains(p1))
+ p1 = c4;
+ }
+ }
+
+ p1 = nearest_point(box, p1);
+ p2 = nearest_point(box, p2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle centered_rect (
+ const point& p,
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ return centered_rect(p.x(),p.y(),width,height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle centered_rect (
+ const rectangle& rect,
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ return centered_rect((rect.left()+rect.right())/2, (rect.top()+rect.bottom())/2, width, height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle shrink_rect (
+ const rectangle& rect,
+ long num
+ )
+ {
+ return rectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle grow_rect (
+ const rectangle& rect,
+ long num
+ )
+ {
+ return shrink_rect(rect, -num);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle shrink_rect (
+ const rectangle& rect,
+ long width,
+ long height
+ )
+ {
+ return rectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle grow_rect (
+ const rectangle& rect,
+ long width,
+ long height
+ )
+ {
+ return shrink_rect(rect, -width, -height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle translate_rect (
+ const rectangle& rect,
+ const point& p
+ )
+ {
+ rectangle result;
+ result.set_top ( rect.top() + p.y() );
+ result.set_bottom ( rect.bottom() + p.y() );
+ result.set_left ( rect.left() + p.x() );
+ result.set_right ( rect.right() + p.x() );
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle translate_rect (
+ const rectangle& rect,
+ long x,
+ long y
+ )
+ {
+ rectangle result;
+ result.set_top ( rect.top() + y );
+ result.set_bottom ( rect.bottom() + y );
+ result.set_left ( rect.left() + x );
+ result.set_right ( rect.right() + x );
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle resize_rect (
+ const rectangle& rect,
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ return rectangle(rect.left(),rect.top(),
+ rect.left()+width-1,
+ rect.top()+height-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle resize_rect_width (
+ const rectangle& rect,
+ unsigned long width
+ )
+ {
+ return rectangle(rect.left(),rect.top(),
+ rect.left()+width-1,
+ rect.bottom());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle resize_rect_height (
+ const rectangle& rect,
+ unsigned long height
+ )
+ {
+ return rectangle(rect.left(),rect.top(),
+ rect.right(),
+ rect.top()+height-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle move_rect (
+ const rectangle& rect,
+ const point& p
+ )
+ {
+ return rectangle(p.x(), p.y(), p.x()+rect.width()-1, p.y()+rect.height()-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle move_rect (
+ const rectangle& rect,
+ long x,
+ long y
+ )
+ {
+ return rectangle(x, y, x+rect.width()-1, y+rect.height()-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle set_rect_area (
+ const rectangle& rect,
+ unsigned long area
+ )
+ {
+ DLIB_ASSERT(area > 0);
+
+ if (rect.area() == 0)
+ {
+ // In this case we will make the output rectangle a square with the requested
+ // area.
+ unsigned long scale = std::round(std::sqrt(area));
+ return centered_rect(rect, scale, scale);
+ }
+ else
+ {
+ double scale = std::sqrt(area/(double)rect.area());
+ return centered_rect(rect, (long)std::round(rect.width()*scale), (long)std::round(rect.height()*scale));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle set_aspect_ratio (
+ const rectangle& rect,
+ double ratio
+ )
+ {
+ DLIB_ASSERT(ratio > 0,
+ "\t rectangle set_aspect_ratio()"
+ << "\n\t ratio: " << ratio
+ );
+
+ // aspect ratio is w/h
+
+ // we need to find the rectangle that is nearest to rect in area but
+ // with an aspect ratio of ratio.
+
+ // w/h == ratio
+ // w*h == rect.area()
+
+ if (ratio >= 1)
+ {
+ const long h = static_cast<long>(std::sqrt(rect.area()/ratio) + 0.5);
+ const long w = static_cast<long>(h*ratio + 0.5);
+ return centered_rect(rect, w, h);
+ }
+ else
+ {
+ const long w = static_cast<long>(std::sqrt(rect.area()*ratio) + 0.5);
+ const long h = static_cast<long>(w/ratio + 0.5);
+ return centered_rect(rect, w, h);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline const rectangle get_rect (
+ const T& m
+ )
+ {
+ return rectangle(0, 0, num_columns(m)-1, num_rows(m)-1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle operator+ (
+ const rectangle& r,
+ const point& p
+ )
+ {
+ return r + rectangle(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle operator+ (
+ const point& p,
+ const rectangle& r
+ )
+ {
+ return r + rectangle(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RECTANGLe_
+
+
diff --git a/ml/dlib/dlib/geometry/rectangle_abstract.h b/ml/dlib/dlib/geometry/rectangle_abstract.h
new file mode 100644
index 000000000..0ff0f0a8d
--- /dev/null
+++ b/ml/dlib/dlib/geometry/rectangle_abstract.h
@@ -0,0 +1,836 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RECTANGLe_ABSTRACT_
+#ifdef DLIB_RECTANGLe_ABSTRACT_
+
+#include "vector_abstract.h"
+#include <iostream>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ class rectangle
+ {
+ /*!
+ INITIAL VALUE
+ The initial value of this object is defined by its constructor.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a rectangular region inside a Cartesian
+ coordinate system. The region is the rectangle with its top
+ left corner at position (left(),top()) and its bottom right corner
+ at (right(),bottom()).
+
+ Note that the origin of the coordinate system, i.e. (0,0), is located
+ at the upper left corner. That is, points such as (1,1) or (3,5)
+ represent locations that are below and to the right of the origin.
+
+ Also note that rectangles where top() > bottom() or left() > right()
+ represent empty rectangles.
+ !*/
+
+ public:
+
+ rectangle (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - #*this represents the same rectangle as rect
+ !*/
+
+ rectangle (
+ );
+ /*!
+ ensures
+ - #left() == 0
+ - #top() == 0
+ - #right() == -1
+ - #bottom() == -1
+ - #is_empty() == true
+ !*/
+
+ rectangle (
+ long left_,
+ long top_,
+ long right_,
+ long bottom_
+ );
+ /*!
+ ensures
+ - #left() == left_
+ - #top() == top_
+ - #right() == right_
+ - #bottom() == bottom_
+ !*/
+
+ rectangle (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ requires
+ - (width_ > 0 && height_ > 0) || (width_ == 0 && height_ == 0)
+ ensures
+ - #left() == 0
+ - #top() == 0
+ - #width() == width_
+ - #height() == height_
+ !*/
+
+ rectangle (
+ const point& p
+ );
+ /*!
+ ensures
+ - #left() == p.x()
+ - #top() == p.y()
+ - #right() == p.x()
+ - #bottom() == p.y()
+ !*/
+
+ template <typename T>
+ rectangle (
+ const vector<T,2>& p1,
+ const vector<T,2>& p2
+ );
+ /*!
+ ensures
+ - #*this == rectangle(p1) + rectangle(p2)
+ !*/
+
+ long left (
+ ) const;
+ /*!
+ ensures
+ - returns the x coordinate for the left side of this rectangle
+ !*/
+
+ long& left (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the x coordinate for the left side
+ of this rectangle
+ !*/
+
+ void set_left (
+ long left_
+ );
+ /*!
+ ensures
+ - #left() == left_
+ !*/
+
+ long top (
+ ) const;
+ /*!
+ ensures
+ - returns the y coordinate for the top of this rectangle
+ !*/
+
+ long& top (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the y coordinate for the
+ top of this rectangle
+ !*/
+
+ void set_top (
+ long top_
+ );
+ /*!
+ ensures
+ - #top() == top_
+ !*/
+
+ long right (
+ ) const;
+ /*!
+ ensures
+ - returns the x coordinate for the right side of this rectangle
+ !*/
+
+ long& right (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the x coordinate for the right
+ side of this rectangle
+ !*/
+
+ void set_right (
+ long right_
+ );
+ /*!
+ ensures
+ - #right() == right_
+ !*/
+
+ long bottom (
+ ) const;
+ /*!
+ ensures
+ - returns the y coordinate for the bottom of this rectangle
+ !*/
+
+ long& bottom (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the y coordinate for the bottom
+ of this rectangle
+ !*/
+
+ void set_bottom (
+ long bottom_
+ );
+ /*!
+ ensures
+ - #bottom() == bottom_
+ !*/
+
+ const point tl_corner (
+ ) const;
+ /*!
+ ensures
+ - returns point(left(), top())
+ (i.e. returns the top left corner point for this rectangle)
+ !*/
+
+ const point bl_corner (
+ ) const;
+ /*!
+ ensures
+ - returns point(left(), bottom())
+ (i.e. returns the bottom left corner point for this rectangle)
+ !*/
+
+ const point tr_corner (
+ ) const;
+ /*!
+ ensures
+ - returns point(right(), top())
+ (i.e. returns the top right corner point for this rectangle)
+ !*/
+
+ const point br_corner (
+ ) const;
+ /*!
+ ensures
+ - returns point(right(), bottom())
+ (i.e. returns the bottom right corner point for this rectangle)
+ !*/
+
+ bool is_empty (
+ ) const;
+ /*!
+ ensures
+ - if (top() > bottom() || left() > right()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unsigned long width (
+ ) const;
+ /*!
+ ensures
+ - if (is_empty()) then
+ - returns 0
+ - else
+ - returns the width of this rectangle.
+ (i.e. right() - left() + 1)
+ !*/
+
+ unsigned long height (
+ ) const;
+ /*!
+ ensures
+ - if (is_empty()) then
+ - returns 0
+ - else
+ - returns the height of this rectangle.
+ (i.e. bottom() - top() + 1)
+ !*/
+
+ unsigned long area (
+ ) const;
+ /*!
+ ensures
+ - returns width()*height()
+ !*/
+
+ rectangle operator + (
+ const rectangle& rhs
+ ) const;
+ /*!
+ ensures
+ - if (rhs.is_empty() == false && this->is_empty() == false) then
+ - returns the smallest rectangle that contains both *this and
+ rhs.
+ - if (rhs.is_empty() == true && this->is_empty() == false) then
+ - returns *this
+ - if (rhs.is_empty() == false && this->is_empty() == true) then
+ - returns rhs
+ - if (rhs.is_empty() == true && this->is_empty() == true) then
+ - returns a rectangle that has is_empty() == true
+ !*/
+
+ rectangle intersect (
+ const rectangle& rhs
+ ) const;
+ /*!
+ ensures
+ - if (there is a region of intersection between *this and rhs) then
+ - returns a rectangle that represents the intersection of *this
+ and rhs.
+ - else
+ - returns a rectangle where is_empty() == true
+ !*/
+
+ bool contains (
+ long x,
+ long y
+ ) const;
+ /*!
+ ensures
+ - if (the point (x,y) is contained in this rectangle) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool contains (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - if (the point (p.x(),p.y()) is contained in this rectangle) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool contains (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - if (rect + *this == *this) then
+ - returns true
+ (i.e. returns true if *this contains the given rectangle)
+ - else
+ - returns false
+ !*/
+
+ rectangle& operator= (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - #*this represents the same rectangle as rect
+ - returns #*this
+ !*/
+
+ rectangle& operator+= (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - #*this == *this + rect
+ - returns #*this
+ !*/
+
+ bool operator== (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - if (top() == rect.top() && left() == rect.left() &&
+ right() == rect.right() && bottom() == rect.bottom()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator!= (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns !(*this == rect)
+ !*/
+
+ bool operator< (
+ const dlib::rectangle& a,
+ const dlib::rectangle& b
+ ) const;
+ /*!
+ ensures
+ - Defines a total ordering over rectangles so they can be used in
+ associative containers.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const rectangle& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ rectangle& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const rectangle& item
+ );
+ /*!
+ ensures
+ - writes item to out in the form "[(left, top) (right, bottom)]"
+ !*/
+
+ std::istream& operator>>(
+ std::istream& in,
+ rectangle& item
+ );
+ /*!
+ ensures
+ - reads a rectangle from the input stream in and stores it in #item.
+ The data in the input stream should be of the form [(left, top) (right, bottom)]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point center (
+ const dlib::rectangle& rect
+ );
+ /*!
+ ensures
+ - returns the center of the given rectangle
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::vector<double,2> dcenter (
+ const dlib::rectangle& rect
+ );
+ /*!
+ ensures
+ - returns the center of the given rectangle using a real valued vector.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle centered_rect (
+ const point& p,
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - center(R) == p
+ - if (width == 0 || height == 0)
+ - R.width() == 0
+ - R.height() == 0
+ - else
+ - R.width() == width
+ - R.height() == height
+ - R.tl_corner() == point(p.x()-width/2, p.y()-height/2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle centered_rect (
+ long x,
+ long y,
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - center(R) == p
+ - if (width == 0 || height == 0)
+ - R.width() == 0
+ - R.height() == 0
+ - else
+ - R.width() == width
+ - R.height() == height
+ - R.tl_corner() == point(x-width/2, y-height/2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle centered_rect (
+ const rectangle& rect,
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - returns centered_rect( (rect.tl_corner() + rect.br_corner())/2, width, height)
+ (i.e. returns a rectangle centered on rect but with the given width
+ and height)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle set_rect_area (
+ const rectangle& rect,
+ unsigned long area
+ );
+ /*!
+ requires
+ - area > 0
+ ensures
+ - Returns a rectangle R such that:
+ - center(R) == center(rect)
+ - R has the same aspect ratio as rect. If rect.area() == 0 then the
+ returned rect has a 1:1 aspect ratio.
+ - R.area() == area
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle set_aspect_ratio (
+ const rectangle& rect,
+ double ratio
+ );
+ /*!
+ requires
+ - ratio > 0
+ ensures
+ - This function reshapes the given rectangle so that it has the given aspect
+ ratio. In particular, this means we return a rectangle R such that the
+ following equations are as true as possible:
+ - R.width()/R.height() == ratio
+ - R.area() == rect.area()
+ - center(rect) == center(R)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle intersect (
+ const rectangle& a,
+ const rectangle& b
+ );
+ /*!
+ ensures
+ - returns a.intersect(b)
+ (i.e. returns a rectangle representing the intersection of a and b)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline unsigned long area (
+ const rectangle& a
+ );
+ /*!
+ ensures
+ - returns a.area()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle shrink_rect (
+ const rectangle& rect,
+ long num
+ );
+ /*!
+ ensures
+ - returns rectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num)
+ (i.e. shrinks the given rectangle by shrinking its border by num)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle grow_rect (
+ const rectangle& rect,
+ long num
+ );
+ /*!
+ ensures
+ - return shrink_rect(rect, -num)
+ (i.e. grows the given rectangle by expanding its border by num)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle shrink_rect (
+ const rectangle& rect,
+ long width,
+ long height
+ );
+ /*!
+ ensures
+ - returns rectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height)
+ (i.e. shrinks the given rectangle by shrinking its left and right borders by width
+ and its top and bottom borders by height. )
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const rectangle grow_rect (
+ const rectangle& rect,
+ long width,
+ long height
+ );
+ /*!
+ ensures
+ - return shrink_rect(rect, -width, -height)
+ (i.e. grows the given rectangle by expanding its border)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle translate_rect (
+ const rectangle& rect,
+ const point& p
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.left() == rect.left() + p.x()
+ - R.right() == rect.right() + p.x()
+ - R.top() == rect.top() + p.y()
+ - R.bottom() == rect.bottom() + p.y()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle translate_rect (
+ const rectangle& rect,
+ long x,
+ long y
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.left() == rect.left() + x
+ - R.right() == rect.right() + x
+ - R.top() == rect.top() + y
+ - R.bottom() == rect.bottom() + y
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle resize_rect (
+ const rectangle& rect,
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - if (width == 0 || height == 0)
+ - R.width() == 0
+ - R.height() == 0
+ - else
+ - R.width() == width
+ - R.height() == height
+ - R.left() == rect.left()
+ - R.top() == rect.top()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle resize_rect_width (
+ const rectangle& rect,
+ unsigned long width
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.width() == width
+ - R.left() == rect.left()
+ - R.top() == rect.top()
+ - R.bottom() == rect.bottom()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle resize_rect_height (
+ const rectangle& rect,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.height() == height
+ - R.left() == rect.left()
+ - R.top() == rect.top()
+ - R.right() == rect.right()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle move_rect (
+ const rectangle& rect,
+ const point& p
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.width() == rect.width()
+ - R.height() == rect.height()
+ - R.left() == p.x()
+ - R.top() == p.y()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle move_rect (
+ const rectangle& rect,
+ long x,
+ long y
+ );
+ /*!
+ ensures
+ - returns a rectangle R such that:
+ - R.width() == rect.width()
+ - R.height() == rect.height()
+ - R.left() == x
+ - R.top() == y
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const point nearest_point (
+ const rectangle& rect,
+ const point& p
+ );
+ /*!
+ ensures
+ - if (rect.contains(p)) then
+ - returns p
+ - else
+ - returns the point in rect that is closest to p
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline size_t nearest_rect (
+ const std::vector<rectangle>& rects,
+ const point& p
+ );
+ /*!
+ requires
+ - rects.size() > 0
+ ensures
+ - returns the index of the rectangle that is closest to the point p. In
+ particular, this function returns an IDX such that:
+ length(nearest_point(rects[IDX],p) - p)
+ is minimized.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline long distance_to_rect_edge (
+ const rectangle& rect,
+ const point& p
+ );
+ /*!
+ ensures
+ - returns the Manhattan distance between the edge of rect and p.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ double distance_to_line (
+ const std::pair<vector<T,2>,vector<T,2> >& line,
+ const vector<U,2>& p
+ );
+ /*!
+ ensures
+ - returns the euclidean distance between the given line and the point p. That
+ is, given a line that passes though the points line.first and line.second,
+ what is the distance between p and the nearest point on the line? This
+ function returns that distance.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void clip_line_to_rectangle (
+ const rectangle& box,
+ point& p1,
+ point& p2
+ );
+ /*!
+ ensures
+ - clips the line segment that goes from points p1 to p2 so that it is entirely
+ within the given box. In particular, we will have:
+ - box.contains(#p1) == true
+ - box.contains(#p2) == true
+ - The line segment #p1 to #p2 is entirely contained within the line segment
+ p1 to p2. Moreover, #p1 to #p2 is the largest such line segment that
+ fits within the given box.
+ - If the line segment does not intersect the box then the result is some
+ arbitrary line segment inside the box.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const rectangle get_rect (
+ const T& m
+ );
+ /*!
+ requires
+ - It must be possible to determine the number of "rows" and "columns" in m.
+ Either by calling num_rows(m) and num_columns(m) or by calling m.nr() and
+ m.nc() to obtain the number of rows and columns respectively. Moreover,
+ these routines should return longs.
+ ensures
+ - returns rectangle(0, 0, num_columns(m)-1, num_rows(m)-1)
+ (i.e. assuming T represents some kind of rectangular grid, such as
+ the dlib::matrix or dlib::array2d objects, this function returns the
+ bounding rectangle for that gridded object.)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle operator+ (
+ const rectangle& r,
+ const point& p
+ );
+ /*!
+ ensures
+ - returns r + rectangle(p)
+ (i.e. returns the rectangle that contains both r and p)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle operator+ (
+ const point& p,
+ const rectangle& r
+ );
+ /*!
+ ensures
+ - returns r + rectangle(p)
+ (i.e. returns the rectangle that contains both r and p)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RECTANGLe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/geometry/vector.h b/ml/dlib/dlib/geometry/vector.h
new file mode 100644
index 000000000..4ea53799d
--- /dev/null
+++ b/ml/dlib/dlib/geometry/vector.h
@@ -0,0 +1,1330 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_VECTOr_H_
+#define DLIB_VECTOr_H_
+
+#include <cmath>
+#include "vector_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include <functional>
+#include <iostream>
+#include "../matrix/matrix.h"
+#include <limits>
+
+#if defined(_MSC_VER) && _MSC_VER < 1400
+// Despite my efforts to disabuse visual studio of its usual nonsense I can't find a
+// way to make this warning go away without just disabling it. This is the warning:
+// dlib\geometry\vector.h(129) : warning C4805: '==' : unsafe mix of type 'std::numeric_limits<_Ty>::is_integer' and type 'bool' in operation
+//
+#pragma warning(disable:4805)
+#endif
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ long NR = 3
+ >
+ class vector;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename enabled = void>
+ struct vect_promote;
+
+ template <typename T, typename U, bool res = (sizeof(T) <= sizeof(U))>
+ struct largest_type
+ {
+ typedef T type;
+ };
+ template <typename T, typename U>
+ struct largest_type<T,U,true>
+ {
+ typedef U type;
+ };
+
+ template <typename T, typename U>
+ struct vect_promote<T,U, typename enable_if_c<std::numeric_limits<T>::is_integer == std::numeric_limits<U>::is_integer>::type>
+ {
+ // If both T and U are both either integral or non-integral then just
+ // use the biggest one
+ typedef typename largest_type<T,U>::type type;
+ };
+
+ template <typename T, typename U>
+ struct vect_promote<T,U, typename enable_if_c<std::numeric_limits<T>::is_integer != std::numeric_limits<U>::is_integer>::type>
+ {
+ typedef double type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ // This insanity here is to work around a bug in visual studio 8. These two rebind
+ // structures are actually declared at a few points in this file because just having the
+ // one declaration here isn't enough for visual studio. It takes the three spread around
+ // to avoid all its bugs.
+ template <typename T, long N>
+ struct vc_rebind
+ {
+ typedef vector<T,N> type;
+ };
+ template <typename T, typename U, long N>
+ struct vc_rebind_promote
+ {
+ typedef vector<typename vect_promote<T,U>::type,N> type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename enabled = void>
+ struct vector_assign_helper
+ {
+ template <long NR>
+ static void assign (
+ vector<T,2>& dest,
+ const vector<U,NR>& src
+ )
+ {
+ dest.x() = static_cast<T>(src.x());
+ dest.y() = static_cast<T>(src.y());
+ }
+
+ template <long NR>
+ static void assign (
+ vector<T,3>& dest,
+ const vector<U,NR>& src
+ )
+ {
+ dest.x() = static_cast<T>(src.x());
+ dest.y() = static_cast<T>(src.y());
+ dest.z() = static_cast<T>(src.z());
+ }
+
+ template <typename EXP>
+ static void assign (
+ vector<T,2>& dest,
+ const matrix_exp<EXP>& m
+ )
+ {
+ T x = static_cast<T>(m(0));
+ T y = static_cast<T>(m(1));
+ dest.x() = x;
+ dest.y() = y;
+ }
+
+ template <typename EXP>
+ static void assign (
+ vector<T,3>& dest,
+ const matrix_exp<EXP>& m
+ )
+ {
+ T x = static_cast<T>(m(0));
+ T y = static_cast<T>(m(1));
+ T z = static_cast<T>(m(2));
+
+ dest.x() = x;
+ dest.y() = y;
+ dest.z() = z;
+ }
+ };
+
+ // This is an overload for the case where you are converting from a floating point
+ // type to an integral type. These overloads make sure values are rounded to
+ // the nearest integral value.
+ template <typename T, typename U>
+ struct vector_assign_helper<T,U, typename enable_if_c<std::numeric_limits<T>::is_integer == true &&
+ std::numeric_limits<U>::is_integer == false>::type>
+ {
+ template <long NR>
+ static void assign (
+ vector<T,2>& dest,
+ const vector<U,NR>& src
+ )
+ {
+ dest.x() = static_cast<T>(std::floor(src.x() + 0.5));
+ dest.y() = static_cast<T>(std::floor(src.y() + 0.5));
+ }
+
+ template <long NR>
+ static void assign (
+ vector<T,3>& dest,
+ const vector<U,NR>& src
+ )
+ {
+ dest.x() = static_cast<T>(std::floor(src.x() + 0.5));
+ dest.y() = static_cast<T>(std::floor(src.y() + 0.5));
+ dest.z() = static_cast<T>(std::floor(src.z() + 0.5));
+ }
+
+ template <typename EXP>
+ static void assign (
+ vector<T,3>& dest,
+ const matrix_exp<EXP>& m
+ )
+ {
+ dest.x() = static_cast<T>(std::floor(m(0) + 0.5));
+ dest.y() = static_cast<T>(std::floor(m(1) + 0.5));
+ dest.z() = static_cast<T>(std::floor(m(2) + 0.5));
+ }
+
+ template <typename EXP>
+ static void assign (
+ vector<T,2>& dest,
+ const matrix_exp<EXP>& m
+ )
+ {
+ dest.x() = static_cast<T>(std::floor(m(0) + 0.5));
+ dest.y() = static_cast<T>(std::floor(m(1) + 0.5));
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class vector<T,3> : public matrix<T,3,1>
+ {
+ /*!
+ INITIAL VALUE
+ - x() == 0
+ - y() == 0
+ - z() == 0
+
+ CONVENTION
+ - (*this)(0) == x()
+ - (*this)(1) == y()
+ - (*this)(2) == z()
+
+ !*/
+
+ // This insanity here is to work around a bug in visual studio 8.
+ template <typename V, long N>
+ struct vc_rebind
+ {
+ typedef vector<V,N> type;
+ };
+ template <typename V, typename U, long N>
+ struct vc_rebind_promote
+ {
+ typedef vector<typename vect_promote<V,U>::type,N> type;
+ };
+
+ public:
+
+ typedef T type;
+
+ vector (
+ )
+ {
+ x() = 0;
+ y() = 0;
+ z() = 0;
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const T _x,
+ const T _y,
+ const T _z
+ )
+ {
+ x() = _x;
+ y() = _y;
+ z() = _z;
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const vector& item
+ ) : matrix<T,3,1>(item)
+ {
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ vector (
+ const vector<U,2>& item
+ )
+ {
+ // Do this so that we get the appropriate rounding depending on the relative
+ // type of T and U.
+ vector<T,2> temp(item);
+ x() = temp.x();
+ y() = temp.y();
+ z() = 0;
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const vector<T,2>& item
+ )
+ {
+ x() = item.x();
+ y() = item.y();
+ z() = 0;
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ vector (
+ const vector<U,3>& item
+ )
+ {
+ (*this) = item;
+ }
+
+ // ---------------------------------------
+
+ template <typename EXP>
+ vector ( const matrix_exp<EXP>& m)
+ {
+ (*this) = m;
+ }
+
+ // ---------------------------------------
+
+ template <typename EXP>
+ vector& operator = (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only assign vectors with 3 elements to a dlib::vector<T,3> object
+ COMPILE_TIME_ASSERT(EXP::NR*EXP::NC == 3 || EXP::NR*EXP::NC == 0);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT((m.nr() == 1 || m.nc() == 1) && (m.size() == 3),
+ "\t vector(const matrix_exp& m)"
+ << "\n\t the given matrix is of the wrong size"
+ << "\n\t m.nr(): " << m.nr()
+ << "\n\t m.nc(): " << m.nc()
+ << "\n\t m.size(): " << m.size()
+ << "\n\t this: " << this
+ );
+
+ vector_assign_helper<T, typename EXP::type>::assign(*this, m);
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ vector& operator = (
+ const vector<U,N>& item
+ )
+ {
+ vector_assign_helper<T,U>::assign(*this, item);
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator= (
+ const vector& item
+ )
+ {
+ x() = item.x();
+ y() = item.y();
+ z() = item.z();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ double length(
+ ) const
+ {
+ return std::sqrt((double)(x()*x() + y()*y() + z()*z()));
+ }
+
+ // ---------------------------------------
+
+ double length_squared(
+ ) const
+ {
+ return (double)(x()*x() + y()*y() + z()*z());
+ }
+
+ // ---------------------------------------
+
+ typename vc_rebind<double,3>::type normalize (
+ ) const
+ {
+ const double tmp = std::sqrt((double)(x()*x() + y()*y() + z()*z()));
+ return vector<double,3> ( x()/tmp,
+ y()/tmp,
+ z()/tmp
+ );
+ }
+
+ // ---------------------------------------
+
+ T& x (
+ )
+ {
+ return (*this)(0);
+ }
+
+ // ---------------------------------------
+
+ T& y (
+ )
+ {
+ return (*this)(1);
+ }
+
+ // ---------------------------------------
+
+ T& z (
+ )
+ {
+ return (*this)(2);
+ }
+
+ // ---------------------------------------
+
+ const T& x (
+ ) const
+ {
+ return (*this)(0);
+ }
+
+ // ---------------------------------------
+
+ const T& y (
+ ) const
+ {
+ return (*this)(1);
+ }
+
+ // ---------------------------------------
+
+ const T& z (
+ ) const
+ {
+ return (*this)(2);
+ }
+
+ // ---------------------------------------
+
+ T dot (
+ const vector& rhs
+ ) const
+ {
+ return x()*rhs.x() + y()*rhs.y() + z()*rhs.z();
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ typename vect_promote<T,U>::type dot (
+ const vector<U,N>& rhs
+ ) const
+ {
+ return x()*rhs.x() + y()*rhs.y() + z()*rhs.z();
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ typename vc_rebind_promote<T,U,3>::type cross (
+ const vector<U,N>& rhs
+ ) const
+ {
+ typedef vector<typename vect_promote<T,U>::type,3> ret_type;
+
+ return ret_type (
+ y()*rhs.z() - z()*rhs.y(),
+ z()*rhs.x() - x()*rhs.z(),
+ x()*rhs.y() - y()*rhs.x()
+ );
+ }
+
+ // ---------------------------------------
+
+ vector& operator += (
+ const vector& rhs
+ )
+ {
+ x() += rhs.x();
+ y() += rhs.y();
+ z() += rhs.z();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator -= (
+ const vector& rhs
+ )
+ {
+ x() -= rhs.x();
+ y() -= rhs.y();
+ z() -= rhs.z();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator /= (
+ const T& rhs
+ )
+ {
+ x() /= rhs;
+ y() /= rhs;
+ z() /= rhs;
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator *= (
+ const T& rhs
+ )
+ {
+ x() *= rhs;
+ y() *= rhs;
+ z() *= rhs;
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector operator - (
+ ) const
+ {
+ return vector(-x(), -y(), -z());
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ typename vc_rebind_promote<T,U,3>::type operator / (
+ const U& val
+ ) const
+ {
+ typedef vector<typename vect_promote<T,U>::type,3> ret_type;
+ return ret_type(x()/val, y()/val, z()/val);
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long NR2>
+ bool operator== (
+ const vector<U,NR2>& rhs
+ ) const
+ {
+ return x()==rhs.x() && y()==rhs.y() && z()==rhs.z();
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long NR2>
+ bool operator!= (
+ const vector<U,NR2>& rhs
+ ) const
+ {
+ return !(*this == rhs);
+ }
+
+ // ---------------------------------------
+
+ void swap (
+ vector& item
+ )
+ {
+ dlib::exchange(x(), item.x());
+ dlib::exchange(y(), item.y());
+ dlib::exchange(z(), item.z());
+ }
+
+ // ---------------------------------------
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class vector<T,2> : public matrix<T,2,1>
+ {
+ /*!
+ INITIAL VALUE
+ - x() == 0
+ - y() == 0
+
+ CONVENTION
+ - (*this)(0) == x()
+ - (*this)(1) == y()
+ - z() == 0
+ !*/
+
+ // This insanity here is to work around a bug in visual studio 8.
+ template <typename V, long N>
+ struct vc_rebind
+ {
+ typedef vector<V,N> type;
+ };
+ template <typename V, typename U, long N>
+ struct vc_rebind_promote
+ {
+ typedef vector<typename vect_promote<V,U>::type,N> type;
+ };
+
+
+ public:
+
+ typedef T type;
+
+ vector (
+ )
+ {
+ x() = 0;
+ y() = 0;
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const T _x,
+ const T _y
+ )
+ {
+ x() = _x;
+ y() = _y;
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ vector (
+ const vector<U,3>& item
+ )
+ {
+ // Do this so that we get the appropriate rounding depending on the relative
+ // type of T and U.
+ vector<T,3> temp(item);
+ x() = temp.x();
+ y() = temp.y();
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const vector& item
+ ) : matrix<T,2,1>(item)
+ {
+ }
+
+ // ---------------------------------------
+
+ vector (
+ const vector<T,3>& item
+ )
+ {
+ x() = item.x();
+ y() = item.y();
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ vector (
+ const vector<U,2>& item
+ )
+ {
+ (*this) = item;
+ }
+
+ // ---------------------------------------
+
+ template <typename EXP>
+ vector ( const matrix_exp<EXP>& m)
+ {
+ (*this) = m;
+ }
+
+ // ---------------------------------------
+
+ template <typename EXP>
+ vector& operator = (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only assign vectors with 2 elements to a dlib::vector<T,2> object
+ COMPILE_TIME_ASSERT(EXP::NR*EXP::NC == 2 || EXP::NR*EXP::NC == 0);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT((m.nr() == 1 || m.nc() == 1) && (m.size() == 2),
+ "\t vector(const matrix_exp& m)"
+ << "\n\t the given matrix is of the wrong size"
+ << "\n\t m.nr(): " << m.nr()
+ << "\n\t m.nc(): " << m.nc()
+ << "\n\t m.size(): " << m.size()
+ << "\n\t this: " << this
+ );
+
+ vector_assign_helper<T, typename EXP::type>::assign(*this, m);
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ vector& operator = (
+ const vector<U,N>& item
+ )
+ {
+ vector_assign_helper<T,U>::assign(*this, item);
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator= (
+ const vector& item
+ )
+ {
+ x() = item.x();
+ y() = item.y();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ double length(
+ ) const
+ {
+ return std::sqrt((double)(x()*x() + y()*y()));
+ }
+
+ // ---------------------------------------
+
+ double length_squared(
+ ) const
+ {
+ return (double)(x()*x() + y()*y());
+ }
+
+ // ---------------------------------------
+
+ typename vc_rebind<double,2>::type normalize (
+ ) const
+ {
+ const double tmp = std::sqrt((double)(x()*x() + y()*y()));
+ return vector<double,2> ( x()/tmp,
+ y()/tmp
+ );
+ }
+
+ // ---------------------------------------
+
+ T& x (
+ )
+ {
+ return (*this)(0);
+ }
+
+ // ---------------------------------------
+
+ T& y (
+ )
+ {
+ return (*this)(1);
+ }
+
+ // ---------------------------------------
+
+ const T& x (
+ ) const
+ {
+ return (*this)(0);
+ }
+
+ // ---------------------------------------
+
+ const T& y (
+ ) const
+ {
+ return (*this)(1);
+ }
+
+ // ---------------------------------------
+
+ const T z (
+ ) const
+ {
+ return 0;
+ }
+
+ // ---------------------------------------
+
+ T dot (
+ const vector& rhs
+ ) const
+ {
+ return x()*rhs.x() + y()*rhs.y();
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ typename vect_promote<T,U>::type dot (
+ const vector<U,N>& rhs
+ ) const
+ {
+ return x()*rhs.x() + y()*rhs.y() + z()*rhs.z();
+ }
+
+ // ---------------------------------------
+
+ vector& operator += (
+ const vector& rhs
+ )
+ {
+ x() += rhs.x();
+ y() += rhs.y();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator -= (
+ const vector& rhs
+ )
+ {
+ x() -= rhs.x();
+ y() -= rhs.y();
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator /= (
+ const T& rhs
+ )
+ {
+ x() /= rhs;
+ y() /= rhs;
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector& operator *= (
+ const T& rhs
+ )
+ {
+ x() *= rhs;
+ y() *= rhs;
+ return *this;
+ }
+
+ // ---------------------------------------
+
+ vector operator - (
+ ) const
+ {
+ return vector(-x(), -y());
+ }
+
+ // ---------------------------------------
+
+ template <typename U>
+ typename vc_rebind_promote<T,U,2>::type operator / (
+ const U& val
+ ) const
+ {
+ typedef vector<typename vect_promote<T,U>::type,2> ret_type;
+ return ret_type(x()/val, y()/val);
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long NR2>
+ bool operator== (
+ const vector<U,NR2>& rhs
+ ) const
+ {
+ return x()==rhs.x() && y()==rhs.y() && z()==rhs.z();
+ }
+
+ // ---------------------------------------
+
+ bool operator== (
+ const vector& rhs
+ ) const
+ {
+ return x()==rhs.x() && y()==rhs.y();
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long NR2>
+ bool operator!= (
+ const vector<U,NR2>& rhs
+ ) const
+ {
+ return !(*this == rhs);
+ }
+
+ // ---------------------------------------
+
+ bool operator!= (
+ const vector& rhs
+ ) const
+ {
+ return !(*this == rhs);
+ }
+
+ // ---------------------------------------
+
+ void swap (
+ vector& item
+ )
+ {
+ dlib::exchange(x(), item.x());
+ dlib::exchange(y(), item.y());
+ }
+
+ // ---------------------------------------
+
+ template <typename U, long N>
+ typename vc_rebind_promote<T,U,3>::type cross (
+ const vector<U,N>& rhs
+ ) const
+ {
+ typedef vector<typename vect_promote<T,U>::type,3> ret_type;
+ return ret_type (
+ y()*rhs.z(),
+ - x()*rhs.z(),
+ x()*rhs.y() - y()*rhs.x()
+ );
+ }
+
+ // ---------------------------------------
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,2>::type operator+ (
+ const vector<T,2>& lhs,
+ const vector<U,2>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,2>::type ret_type;
+ return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator+ (
+ const vector<T,3>& lhs,
+ const vector<U,3>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), lhs.z()+rhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator+ (
+ const vector<T,2>& lhs,
+ const vector<U,3>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), rhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator+ (
+ const vector<T,3>& lhs,
+ const vector<U,2>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), lhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,2>::type operator- (
+ const vector<T,2>& lhs,
+ const vector<U,2>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,2>::type ret_type;
+ return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator- (
+ const vector<T,3>& lhs,
+ const vector<U,3>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), lhs.z()-rhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator- (
+ const vector<T,2>& lhs,
+ const vector<U,3>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), -rhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline const typename vc_rebind_promote<T,U,3>::type operator- (
+ const vector<T,3>& lhs,
+ const vector<U,2>& rhs
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), lhs.z());
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline typename disable_if<is_matrix<U>, const typename vc_rebind_promote<T,U,2>::type >::type operator* (
+ const vector<T,2>& v,
+ const U& s
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,2>::type ret_type;
+ return ret_type(v.x()*s, v.y()*s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline typename disable_if<is_matrix<U>, const typename vc_rebind_promote<T,U,2>::type >::type operator* (
+ const U& s,
+ const vector<T,2>& v
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,2>::type ret_type;
+ return ret_type(v.x()*s, v.y()*s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline typename disable_if<is_matrix<U>, const typename vc_rebind_promote<T,U,3>::type >::type operator* (
+ const vector<T,3>& v,
+ const U& s
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(v.x()*s, v.y()*s, v.z()*s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ inline typename disable_if<is_matrix<U>, const typename vc_rebind_promote<T,U,3>::type >::type operator* (
+ const U& s,
+ const vector<T,3>& v
+ )
+ {
+ typedef typename vc_rebind_promote<T,U,3>::type ret_type;
+ return ret_type(v.x()*s, v.y()*s, v.z()*s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, long NR>
+ inline void swap (
+ vector<T,NR> & a,
+ vector<T,NR> & b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T>
+ inline void serialize (
+ const vector<T,3>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.x(),out);
+ serialize(item.y(),out);
+ serialize(item.z(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type vector");
+ }
+ }
+
+ template<typename T>
+ inline void deserialize (
+ vector<T,3>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.x(),in);
+ deserialize(item.y(),in);
+ deserialize(item.z(),in);
+ }
+ catch (serialization_error& e)
+ {
+ item.x() = 0;
+ item.y() = 0;
+ item.z() = 0;
+ throw serialization_error(e.info + "\n while deserializing object of type vector");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T>
+ inline void serialize (
+ const vector<T,2>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.x(),out);
+ serialize(item.y(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type vector");
+ }
+ }
+
+ template<typename T>
+ inline void deserialize (
+ vector<T,2>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.x(),in);
+ deserialize(item.y(),in);
+ }
+ catch (serialization_error& e)
+ {
+ item.x() = 0;
+ item.y() = 0;
+ throw serialization_error(e.info + "\n while deserializing object of type vector");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T>
+ std::ostream& operator<< (
+ std::ostream& out,
+ const vector<T,3>& item
+ )
+ {
+ out << "(" << item.x() << ", " << item.y() << ", " << item.z() << ")";
+ return out;
+ }
+
+ template<typename T>
+ std::istream& operator>>(
+ std::istream& in,
+ vector<T,3>& item
+ )
+ {
+
+ // eat all the crap up to the '('
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+
+ // there should be a '(' if not then this is an error
+ if (in.get() != '(')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+ // eat all the crap up to the first number
+ while (in.peek() == ' ' || in.peek() == '\t')
+ in.get();
+ in >> item.x();
+
+ if (!in.good())
+ return in;
+
+ // eat all the crap up to the next number
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',')
+ in.get();
+ in >> item.y();
+
+ if (!in.good())
+ return in;
+
+ // eat all the crap up to the next number
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',')
+ in.get();
+ in >> item.z();
+
+ if (!in.good())
+ return in;
+
+ // eat all the crap up to the ')'
+ while (in.peek() == ' ' || in.peek() == '\t')
+ in.get();
+
+ // there should be a ')' if not then this is an error
+ if (in.get() != ')')
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template<typename T>
+ std::ostream& operator<< (
+ std::ostream& out,
+ const vector<T,2>& item
+ )
+ {
+ out << "(" << item.x() << ", " << item.y() << ")";
+ return out;
+ }
+
+ template<typename T>
+ std::istream& operator>>(
+ std::istream& in,
+ vector<T,2>& item
+ )
+ {
+
+ // eat all the crap up to the '('
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n')
+ in.get();
+
+ // there should be a '(' if not then this is an error
+ if (in.get() != '(')
+ {
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+ // eat all the crap up to the first number
+ while (in.peek() == ' ' || in.peek() == '\t')
+ in.get();
+ in >> item.x();
+
+ if (!in.good())
+ return in;
+
+ // eat all the crap up to the next number
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',')
+ in.get();
+ in >> item.y();
+
+ if (!in.good())
+ return in;
+
+ // eat all the crap up to the ')'
+ while (in.peek() == ' ' || in.peek() == '\t')
+ in.get();
+
+ // there should be a ')' if not then this is an error
+ if (in.get() != ')')
+ in.setstate(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ typedef vector<long,2> point;
+ typedef vector<double,2> dpoint;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+namespace std
+{
+ /*!
+ Define std::less<vector<T,3> > so that you can use vectors in the associative containers.
+ !*/
+ template<typename T>
+ struct less<dlib::vector<T,3> >
+ {
+ typedef dlib::vector<T, 3> first_argument_type;
+ typedef dlib::vector<T, 3> second_argument_type;
+ typedef bool result_type;
+ inline bool operator() (const dlib::vector<T,3> & a, const dlib::vector<T,3> & b) const
+ {
+ if (a.x() < b.x()) return true;
+ else if (a.x() > b.x()) return false;
+ else if (a.y() < b.y()) return true;
+ else if (a.y() > b.y()) return false;
+ else if (a.z() < b.z()) return true;
+ else if (a.z() > b.z()) return false;
+ else return false;
+ }
+ };
+
+ /*!
+ Define std::less<vector<T,2> > so that you can use vector<T,2>s in the associative containers.
+ !*/
+ template<typename T>
+ struct less<dlib::vector<T,2> >
+ {
+ typedef dlib::vector<T, 2> first_argument_type;
+ typedef dlib::vector<T, 2> second_argument_type;
+ typedef bool result_type;
+ inline bool operator() (const dlib::vector<T,2> & a, const dlib::vector<T,2> & b) const
+ {
+ if (a.x() < b.x()) return true;
+ else if (a.x() > b.x()) return false;
+ else if (a.y() < b.y()) return true;
+ else if (a.y() > b.y()) return false;
+ else return false;
+ }
+ };
+}
+
+#if defined(_MSC_VER) && _MSC_VER < 1400
+// turn this warning back on
+#pragma warning(default:4805)
+#endif
+
+#endif // DLIB_VECTOr_H_
+
diff --git a/ml/dlib/dlib/geometry/vector_abstract.h b/ml/dlib/dlib/geometry/vector_abstract.h
new file mode 100644
index 000000000..4aee8e32d
--- /dev/null
+++ b/ml/dlib/dlib/geometry/vector_abstract.h
@@ -0,0 +1,489 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_VECTOR_ABSTRACT_
+#ifdef DLIB_VECTOR_ABSTRACT_
+
+#include "../serialize.h"
+#include <functional>
+#include <iostream>
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ long NR = 3
+ >
+ class vector : public matrix<T,NR,1>
+ {
+ /*!
+ REQUIREMENTS ON T
+ T should be some object that provides an interface that is
+ compatible with double, float, int, long and the like.
+
+ REQUIREMENTS ON NR
+ NR == 3 || NR == 2
+
+ INITIAL VALUE
+ x() == 0
+ y() == 0
+ z() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a three dimensional vector. If NR == 2 then
+ this object is limited to representing points on the XY plane where
+ Z is set to 0.
+
+ Also note that this object performs the appropriate integer and
+ floating point conversions and promotions when vectors of mixed
+ type are used together. For example:
+ vector<int,3> vi;
+ vector<double,2> vd;
+ vd + vi == a vector<double,3> object type since that is what
+ is needed to contain the result of vi+vd without
+ any loss of information.
+ !*/
+
+ public:
+
+ typedef T type;
+
+ vector (
+ );
+ /*!
+ ensures
+ - #*this has been properly initialized
+ !*/
+
+ vector (
+ const T _x,
+ const T _y,
+ const T _z
+ );
+ /*!
+ requires
+ - NR == 3
+ ensures
+ - #x() == _x
+ - #y() == _y
+ - #z() == _z
+ !*/
+
+ vector (
+ const T _x,
+ const T _y
+ );
+ /*!
+ requires
+ - NR == 2
+ ensures
+ - #x() == _x
+ - #y() == _y
+ - #z() == 0
+ !*/
+
+ template <typename U, long NRv>
+ vector (
+ const vector<U,NRv>& v
+ );
+ /*!
+ ensures
+ - Initializes *this with the contents of v and does any rounding if necessary and also
+ takes care of converting between 2 and 3 dimensional vectors.
+ - if (U is a real valued type like float or double and T is an integral type like long) then
+ - if (NR == 3) then
+ - #x() == floor(v.x() + 0.5)
+ - #y() == floor(v.y() + 0.5)
+ - #z() == floor(v.z() + 0.5)
+ - else // NR == 2
+ - #x() == floor(v.x() + 0.5)
+ - #y() == floor(v.y() + 0.5)
+ - #z() == 0
+ - else
+ - if (NR == 3) then
+ - #x() == v.x()
+ - #y() == v.y()
+ - #z() == v.z()
+ - else // NR == 2
+ - #x() == v.x()
+ - #y() == v.y()
+ - #z() == 0
+ !*/
+
+ template <typename EXP>
+ vector (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - m.size() == NR
+ - m.nr() == 1 || m.nc() == 1 (i.e. m must be a row or column matrix)
+ ensures
+ - Initializes *this with the contents of m and does any rounding if necessary and also
+ takes care of converting between 2 and 3 dimensional vectors.
+ - if (m contains real valued values like float or double and T is an integral type like long) then
+ - #x() == floor(m(0) + 0.5)
+ - #y() == floor(m(1) + 0.5)
+ - if (NR == 3) then
+ - #z() == floor(m(2) + 0.5)
+ - else
+ - #z() == 0
+ - else
+ - #x() == m(0)
+ - #y() == m(1)
+ - if (NR == 3) then
+ - #z() == m(2)
+ - else
+ - #z() == 0
+ !*/
+
+ ~vector (
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+
+ double length(
+ ) const;
+ /*!
+ ensures
+ - returns the length of the vector
+ !*/
+
+ double length_squared(
+ ) const;
+ /*!
+ ensures
+ - returns length()*length()
+ !*/
+
+ T& x (
+ );
+ /*!
+ ensures
+ - returns a reference to the x component of the vector
+ !*/
+
+ T& y (
+ );
+ /*!
+ ensures
+ - returns a reference to the y component of the vector
+ !*/
+
+ T& z (
+ );
+ /*!
+ requires
+ - NR == 3 (this function actually doesn't exist when NR != 3)
+ ensures
+ - returns a reference to the z component of the vector
+ !*/
+
+ const T& x (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the x component of the vector
+ !*/
+
+ const T& y (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the y component of the vector
+ !*/
+
+ const T& z (
+ ) const;
+ /*!
+ ensures
+ - if (NR == 3) then
+ - returns a const reference to the z component of the vector
+ - else
+ - return 0
+ (there isn't really a z in this case so we just return 0)
+ !*/
+
+ T dot (
+ const vector& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of the dot product between *this and rhs
+ !*/
+
+ vector<T,3> cross (
+ const vector& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of the cross product between *this and rhs
+ !*/
+
+ vector<double,NR> normalize (
+ ) const;
+ /*!
+ ensures
+ - returns a vector with length() == 1 and in the same direction as *this
+ !*/
+
+ vector operator+ (
+ const vector& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of adding *this to rhs
+ !*/
+
+ vector operator- (
+ const vector& rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of subtracting rhs from *this
+ !*/
+
+ vector operator- (
+ ) const;
+ /*!
+ ensures
+ - returns -1*(*this)
+ !*/
+
+ vector operator/ (
+ const T rhs
+ ) const;
+ /*!
+ ensures
+ - returns the result of dividing *this by rhs
+ !*/
+
+ vector& operator= (
+ const vector& rhs
+ );
+ /*!
+ ensures
+ - #x() == rhs.x()
+ - #y() == rhs.y()
+ - #z() == rhs.z()
+ - returns #*this
+ !*/
+
+ vector& operator += (
+ const vector& rhs
+ );
+ /*!
+ ensures
+ - #*this == *this + rhs
+ - returns #*this
+ !*/
+
+ vector& operator -= (
+ const vector& rhs
+ );
+ /*!
+ ensures
+ - #*this == *this - rhs
+ - returns #*this
+ !*/
+
+ vector& operator *= (
+ const T rhs
+ );
+ /*!
+ ensures
+ - #*this == *this * rhs
+ - returns #*this
+ !*/
+
+ vector& operator /= (
+ const T rhs
+ );
+ /*!
+ ensures
+ - #*this == *this / rhs
+ - returns #*this
+ !*/
+
+ template <typename U, long NR2>
+ bool operator== (
+ const vector<U,NR2>& rhs
+ ) const;
+ /*!
+ ensures
+ - if (x() == rhs.x() && y() == rhs.y() && z() == rhs.z()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template <typename U, long NR2>
+ bool operator!= (
+ const vector<U,NR2>& rhs
+ ) const;
+ /*!
+ ensures
+ - returns !((*this) == rhs)
+ !*/
+
+ void swap (
+ vector& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, typename U, long NR>
+ vector operator* (
+ const vector<T,NR> & lhs,
+ const U rhs
+ );
+ /*!
+ ensures
+ - returns the result of multiplying the scalar rhs by lhs
+ !*/
+
+ template<typename T, typename U, long NR>
+ vector operator* (
+ const U lhs,
+ const vector<T,NR> & rhs
+ );
+ /*!
+ ensures
+ - returns the result of multiplying the scalar lhs by rhs
+ !*/
+
+ template<typename T, long NR>
+ inline void swap (
+ vector<T,NR> & a,
+ vector<T,NR> & b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template<typename T, long NR>
+ void serialize (
+ const vector<T,NR>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template<typename T, long NR>
+ void deserialize (
+ vector<T,NR>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template<typename T>
+ std::ostream& operator<< (
+ std::ostream& out,
+ const vector<T,3>& item
+ );
+ /*!
+ ensures
+ - writes item to out in the form "(x, y, z)"
+ !*/
+
+ template<typename T>
+ std::istream& operator>>(
+ std::istream& in,
+ vector<T,3>& item
+ );
+ /*!
+ ensures
+ - reads a vector from the input stream in and stores it in #item.
+ The data in the input stream should be of the form (x, y, z)
+ !*/
+
+ template<typename T>
+ std::ostream& operator<< (
+ std::ostream& out,
+ const vector<T,2>& item
+ );
+ /*!
+ ensures
+ - writes item to out in the form "(x, y)"
+ !*/
+
+ template<typename T>
+ std::istream& operator>>(
+ std::istream& in,
+ vector<T,2>& item
+ );
+ /*!
+ ensures
+ - reads a vector from the input stream in and stores it in #item.
+ The data in the input stream should be of the form (x, y)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A point
+ This is just a typedef of the vector object.
+ !*/
+
+ typedef vector<long,2> point;
+
+ /*!A dpoint
+ This is just a typedef of the vector object.
+ !*/
+
+ typedef vector<double,2> dpoint;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+namespace std
+{
+ /*!
+ Define std::less<vector<T,3> > so that you can use vectors in the associative containers.
+ !*/
+ template<typename T>
+ struct less<dlib::vector<T,3> > : public binary_function<dlib::vector<T,3> ,dlib::vector<T,3> ,bool>
+ {
+ inline bool operator() (const dlib::vector<T,3> & a, const dlib::vector<T,3> & b) const
+ {
+ if (a.x() < b.x()) return true;
+ else if (a.x() > b.x()) return false;
+ else if (a.y() < b.y()) return true;
+ else if (a.y() > b.y()) return false;
+ else if (a.z() < b.z()) return true;
+ else if (a.z() > b.z()) return false;
+ else return false;
+ }
+ };
+
+ /*!
+ Define std::less<vector<T,2> > so that you can use vector<T,2>s in the associative containers.
+ !*/
+ template<typename T>
+ struct less<dlib::vector<T,2> > : public binary_function<dlib::vector<T,2> ,dlib::vector<T,2> ,bool>
+ {
+ inline bool operator() (const dlib::vector<T,2> & a, const dlib::vector<T,2> & b) const
+ {
+ if (a.x() < b.x()) return true;
+ else if (a.x() > b.x()) return false;
+ else if (a.y() < b.y()) return true;
+ else if (a.y() > b.y()) return false;
+ else return false;
+ }
+ };
+}
+
+#endif // DLIB_VECTOR_ABSTRACT_
+
diff --git a/ml/dlib/dlib/global_optimization.h b/ml/dlib/dlib/global_optimization.h
new file mode 100644
index 000000000..26b40fcd9
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GLOBAL_OPTIMIZATIOn_HEADER
+#define DLIB_GLOBAL_OPTIMIZATIOn_HEADER
+
+#include "global_optimization/upper_bound_function.h"
+#include "global_optimization/global_function_search.h"
+#include "global_optimization/find_max_global.h"
+
+#endif // DLIB_GLOBAL_OPTIMIZATIOn_HEADER
+
+
+
+
diff --git a/ml/dlib/dlib/global_optimization/find_max_global.h b/ml/dlib/dlib/global_optimization/find_max_global.h
new file mode 100644
index 000000000..5356129f5
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/find_max_global.h
@@ -0,0 +1,511 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FiND_GLOBAL_MAXIMUM_hH_
+#define DLIB_FiND_GLOBAL_MAXIMUM_hH_
+
+#include "find_max_global_abstract.h"
+#include "global_function_search.h"
+#include "../metaprogramming.h"
+#include <utility>
+#include <chrono>
+
+namespace dlib
+{
+ namespace gopt_impl
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ class disable_decay_to_scalar
+ {
+ const matrix<double,0,1>& a;
+ public:
+ disable_decay_to_scalar(const matrix<double,0,1>& a) : a(a){}
+ operator const matrix<double,0,1>&() const { return a;}
+ };
+
+
+ template <typename T, size_t... indices>
+ auto _cwv (
+ T&& f,
+ const matrix<double,0,1>& a,
+ compile_time_integer_list<indices...>
+ ) -> decltype(f(a(indices-1)...))
+ {
+ DLIB_CASSERT(a.size() == sizeof...(indices),
+ "You invoked dlib::call_function_and_expand_args(f,a) but the number of arguments expected by f() doesn't match the size of 'a'. "
+ << "Expected " << sizeof...(indices) << " arguments but got " << a.size() << "."
+ );
+ return f(a(indices-1)...);
+ }
+
+ // Visual studio, as of November 2017, doesn't support C++11 and can't compile this code.
+ // So we write the terrible garbage in the #else for visual studio. When Visual Studio supports C++11 I'll update this #ifdef to use the C++11 code.
+#ifndef _MSC_VER
+ template <size_t max_unpack>
+ struct call_function_and_expand_args
+ {
+ template <typename T>
+ static auto go(T&& f, const matrix<double,0,1>& a) -> decltype(_cwv(std::forward<T>(f),a,typename make_compile_time_integer_range<max_unpack>::type()))
+ {
+ return _cwv(std::forward<T>(f),a,typename make_compile_time_integer_range<max_unpack>::type());
+ }
+
+ template <typename T>
+ static auto go(T&& f, const matrix<double,0,1>& a) -> decltype(call_function_and_expand_args<max_unpack-1>::template go(std::forward<T>(f),a))
+ {
+ return call_function_and_expand_args<max_unpack-1>::go(std::forward<T>(f),a);
+ }
+ };
+
+ template <>
+ struct call_function_and_expand_args<0>
+ {
+ template <typename T>
+ static auto go(T&& f, const matrix<double,0,1>& a) -> decltype(f(disable_decay_to_scalar(a)))
+ {
+ return f(disable_decay_to_scalar(a));
+ }
+ };
+#else
+ template <size_t max_unpack>
+ struct call_function_and_expand_args
+ {
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(disable_decay_to_scalar(a))) {return f(disable_decay_to_scalar(a)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0))) { DLIB_CASSERT(a.size() == 1); return f(a(0)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0),a(1))) { DLIB_CASSERT(a.size() == 2); return f(a(0),a(1)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0), a(1), a(2))) { DLIB_CASSERT(a.size() == 3); return f(a(0), a(1),a(2)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0), a(1), a(2), a(3))) { DLIB_CASSERT(a.size() == 4); return f(a(0), a(1), a(2), a(3)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4))) { DLIB_CASSERT(a.size() == 5); return f(a(0), a(1), a(2), a(3), a(4)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4), a(5))) { DLIB_CASSERT(a.size() == 6); return f(a(0), a(1), a(2), a(3), a(4), a(5)); }
+template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4), a(5), a(6))) { DLIB_CASSERT(a.size() == 7); return f(a(0), a(1), a(2), a(3), a(4), a(5), a(6)); }
+ };
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ auto call_function_and_expand_args(
+ T&& f,
+ const matrix<double,0,1>& a
+ ) -> decltype(gopt_impl::call_function_and_expand_args<40>::go(f,a))
+ {
+ // unpack up to 40 parameters when calling f()
+ return gopt_impl::call_function_and_expand_args<40>::go(std::forward<T>(f),a);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct max_function_calls
+ {
+ max_function_calls() = default;
+ explicit max_function_calls(size_t max_calls) : max_calls(max_calls) {}
+ size_t max_calls = std::numeric_limits<size_t>::max();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const auto FOREVER = std::chrono::hours(24*356*290); // 290 years
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename funct
+ >
+ std::pair<size_t,function_evaluation> find_max_global (
+ std::vector<funct>& functions,
+ std::vector<function_spec> specs,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon,
+ double ymult
+ )
+ {
+ // Decide which parameters should be searched on a log scale. Basically, it's
+ // common for machine learning models to have parameters that should be searched on
+ // a log scale (e.g. SVM C). These parameters are usually identifiable because
+ // they have bounds like [1e-5 1e10], that is, they span a very large range of
+ // magnitudes from really small to really big. So there we are going to check for
+ // that and if we find parameters with that kind of bound constraints we will
+ // transform them to a log scale automatically.
+ std::vector<std::vector<bool>> log_scale(specs.size());
+ for (size_t i = 0; i < specs.size(); ++i)
+ {
+ for (long j = 0; j < specs[i].lower.size(); ++j)
+ {
+ if (!specs[i].is_integer_variable[j] && specs[i].lower(j) > 0 && specs[i].upper(j)/specs[i].lower(j) >= 1000)
+ {
+ log_scale[i].push_back(true);
+ specs[i].lower(j) = std::log(specs[i].lower(j));
+ specs[i].upper(j) = std::log(specs[i].upper(j));
+ }
+ else
+ {
+ log_scale[i].push_back(false);
+ }
+ }
+ }
+
+ global_function_search opt(specs);
+ opt.set_solver_epsilon(solver_epsilon);
+
+ const auto time_to_stop = std::chrono::steady_clock::now() + max_runtime;
+
+ // Now run the main solver loop.
+ for (size_t i = 0; i < num.max_calls && std::chrono::steady_clock::now() < time_to_stop; ++i)
+ {
+ auto next = opt.get_next_x();
+ matrix<double,0,1> x = next.x();
+ // Undo any log-scaling that was applied to the variables before we pass them
+ // to the functions being optimized.
+ for (long j = 0; j < x.size(); ++j)
+ {
+ if (log_scale[next.function_idx()][j])
+ x(j) = std::exp(x(j));
+ }
+ double y = ymult*call_function_and_expand_args(functions[next.function_idx()], x);
+ next.set(y);
+ }
+
+
+ matrix<double,0,1> x;
+ double y;
+ size_t function_idx;
+ opt.get_best_function_eval(x,y,function_idx);
+ // Undo any log-scaling that was applied to the variables before we output them.
+ for (long j = 0; j < x.size(); ++j)
+ {
+ if (log_scale[function_idx][j])
+ x(j) = std::exp(x(j));
+ }
+ return std::make_pair(function_idx, function_evaluation(x,y/ymult));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ std::pair<size_t,function_evaluation> find_max_global (
+ std::vector<funct>& functions,
+ std::vector<function_spec> specs,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return impl::find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon, +1);
+ }
+
+ template <
+ typename funct
+ >
+ std::pair<size_t,function_evaluation> find_min_global (
+ std::vector<funct>& functions,
+ std::vector<function_spec> specs,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return impl::find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon, -1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ std::vector<funct> functions(1,std::move(f));
+ std::vector<function_spec> specs(1, function_spec(bound1, bound2, is_integer_variable));
+ return find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon).second;
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ std::vector<funct> functions(1,std::move(f));
+ std::vector<function_spec> specs(1, function_spec(bound1, bound2, is_integer_variable));
+ return find_min_global(functions, std::move(specs), num, max_runtime, solver_epsilon).second;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, is_integer_variable, num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, is_integer_variable, num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FiND_GLOBAL_MAXIMUM_hH_
+
diff --git a/ml/dlib/dlib/global_optimization/find_max_global_abstract.h b/ml/dlib/dlib/global_optimization/find_max_global_abstract.h
new file mode 100644
index 000000000..4be62b154
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/find_max_global_abstract.h
@@ -0,0 +1,496 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_
+#ifdef DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_
+
+#include "upper_bound_function_abstract.h"
+#include "global_function_search_abstract.h"
+#include "../metaprogramming.h"
+#include "../matrix.h"
+#include <utility>
+#include <chrono>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ auto call_function_and_expand_args(
+ T&& f,
+ const matrix<double,0,1>& args
+ ) -> decltype(f(args or args expanded out as discussed below));
+ /*!
+ requires
+ - f is a function object with one of the following signatures:
+ auto f(matrix<double,0,1>)
+ auto f(double)
+ auto f(double,double)
+ auto f(double,double,double)
+ ...
+ auto f(double,double,...,double) // up to 40 double arguments
+ - if (f() explicitly expands its arguments) then
+ - args.size() == the number of arguments taken by f.
+ ensures
+ - This function invokes f() with the given arguments and returns the result.
+ However, the signature of f() is allowed to vary. In particular, if f()
+ takes a matrix<double,0,1> as a single argument then this function simply
+ calls f(args). However, if f() takes double arguments then args is expanded
+ appropriately, i.e. it calls one of the following as appropriate:
+ f(args(0))
+ f(args(0),args(1))
+ ...
+ f(args(0),args(1),...,args(N))
+ and the result of f() is returned.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct max_function_calls
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple typed integer class used to strongly type the "max number
+ of function calls" argument to find_max_global() and find_min_global().
+
+ !*/
+
+ max_function_calls() = default;
+
+ explicit max_function_calls(size_t max_calls) : max_calls(max_calls) {}
+
+ size_t max_calls = std::numeric_limits<size_t>::max();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const auto FOREVER = std::chrono::hours(24*356*290); // 290 years, basically forever
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ std::pair<size_t,function_evaluation> find_max_global (
+ std::vector<funct>& functions,
+ const std::vector<function_spec>& specs,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ );
+ /*!
+ requires
+ - functions.size() != 0
+ - functions.size() == specs.size()
+ - solver_epsilon >= 0
+ - for all valid i:
+ - functions[i] is a real valued multi-variate function object. Moreover,
+ it must be callable via an expression of the form:
+ call_function_and_expand_args(functions[i], specs.lower). This means
+ function[i] should have a signature like one of the following:
+ double f(matrix<double,0,1>)
+ double f(double)
+ double f(double,double)
+ etc.
+ - The range of inputs defined by specs[i] must be valid inputs to
+ functions[i].
+ ensures
+ - This function performs global optimization on the set of given functions.
+ The goal is to maximize the following objective function:
+ max_{i,x_i}: functions[i](x_i)
+ subject to the constraints on x_i defined by specs[i].
+ Once found, the return value of find_max_global() is:
+ make_pair(i, function_evaluation(x_i,functions[i](x_i))).
+ That is, we search for the settings of i and x that return the largest output
+ and return those settings.
+ - The search is performed using the global_function_search object. See its
+ documentation for details of the algorithm.
+ - We set the global_function_search::get_solver_epsilon() parameter to
+ solver_epsilon. Therefore, the search will only attempt to find a global
+ maximizer to at most solver_epsilon accuracy. Once a local maximizer is
+ found to that accuracy the search will focus entirely on finding other maxima
+ elsewhere rather than on further improving the current local optima found so
+ far. That is, once a local maxima is identified to about solver_epsilon
+ accuracy, the algorithm will spend all its time exploring the functions to
+ find other local maxima to investigate. An epsilon of 0 means it will keep
+ solving until it reaches full floating point precision. Larger values will
+ cause it to switch to pure global exploration sooner and therefore might be
+ more effective if your objective function has many local maxima and you don't
+ care about a super high precision solution.
+ - find_max_global() runs until one of the following is true:
+ - The total number of calls to the provided functions is == num.max_calls
+ - More than max_runtime time has elapsed since the start of this function.
+ - Any variables that satisfy the following conditions are optimized on a log-scale:
+ - The lower bound on the variable is > 0
+ - The ratio of the upper bound to lower bound is >= 1000
+ - The variable is not an integer variable
+ We do this because it's common to optimize machine learning models that have
+ parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C
+ parameter) and it's much more appropriate to optimize these kinds of
+ variables on a log scale. So we transform them by applying std::log() to
+ them and then undo the transform via std::exp() before invoking the function
+ being optimized. Therefore, this transformation is invisible to the user
+ supplied functions. In most cases, it improves the efficiency of the
+ optimizer.
+ !*/
+
+ template <
+ typename funct
+ >
+ std::pair<size_t,function_evaluation> find_min_global (
+ std::vector<funct>& functions,
+ const std::vector<function_spec>& specs,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ );
+ /*!
+ This function is identical to the find_max_global() defined immediately above,
+ except that we perform minimization rather than maximization.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ );
+ /*!
+ requires
+ - bound1.size() == bound2.size() == is_integer_variable.size()
+ - for all valid i: bound1(i) != bound2(i)
+ - solver_epsilon >= 0
+ - f() is a real valued multi-variate function object. Moreover, it must be
+ callable via an expression of the form: call_function_and_expand_args(f,
+ bound1). This means f() should have a signature like one of the following:
+ double f(matrix<double,0,1>)
+ double f(double)
+ double f(double,double)
+ etc.
+ - The range of inputs defined by function_spec(bound1,bound2,is_integer_variable)
+ must be valid inputs to f().
+ ensures
+ - This function performs global optimization on the given f() function.
+ The goal is to maximize the following objective function:
+ f(x)
+ subject to the constraints on x defined by function_spec(bound1,bound2,is_integer_variable).
+ Once found, the return value of find_max_global() is:
+ function_evaluation(x,f(x))).
+ That is, we search for the setting of x that returns the largest output and
+ return that setting.
+ - The search is performed using the global_function_search object. See its
+ documentation for details of the algorithm.
+ - We set the global_function_search::get_solver_epsilon() parameter to
+ solver_epsilon. Therefore, the search will only attempt to find a global
+ maximizer to at most solver_epsilon accuracy. Once a local maximizer is
+ found to that accuracy the search will focus entirely on finding other maxima
+ elsewhere rather than on further improving the current local optima found so
+ far. That is, once a local maxima is identified to about solver_epsilon
+ accuracy, the algorithm will spend all its time exploring the function to
+ find other local maxima to investigate. An epsilon of 0 means it will keep
+ solving until it reaches full floating point precision. Larger values will
+ cause it to switch to pure global exploration sooner and therefore might be
+ more effective if your objective function has many local maxima and you don't
+ care about a super high precision solution.
+ - find_max_global() runs until one of the following is true:
+ - The total number of calls to f() is == num.max_calls
+ - More than max_runtime time has elapsed since the start of this function.
+ - Any variables that satisfy the following conditions are optimized on a log-scale:
+ - The lower bound on the variable is > 0
+ - The ratio of the upper bound to lower bound is >= 1000
+ - The variable is not an integer variable
+ We do this because it's common to optimize machine learning models that have
+ parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C
+ parameter) and it's much more appropriate to optimize these kinds of
+ variables on a log scale. So we transform them by applying std::log() to
+ them and then undo the transform via std::exp() before invoking the function
+ being optimized. Therefore, this transformation is invisible to the user
+ supplied functions. In most cases, it improves the efficiency of the
+ optimizer.
+ !*/
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ );
+ /*!
+ This function is identical to the find_max_global() defined immediately above,
+ except that we perform minimization rather than maximization.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// The following functions are just convenient overloads for calling the above defined
+// find_max_global() and find_min_global() routines.
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, std::vector<bool>(bound1.size(),false), num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ const std::chrono::nanoseconds max_runtime = FOREVER,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_max_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, FOREVER, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const max_function_calls num,
+ double solver_epsilon
+ )
+ {
+ return find_min_global(std::move(f), matrix<double,0,1>({bound1}), matrix<double,0,1>({bound2}), num, FOREVER, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const double bound1,
+ const double bound2,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ function_evaluation find_max_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_max_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+ template <
+ typename funct
+ >
+ function_evaluation find_min_global (
+ funct f,
+ const matrix<double,0,1>& bound1,
+ const matrix<double,0,1>& bound2,
+ const std::vector<bool>& is_integer_variable,
+ const std::chrono::nanoseconds max_runtime,
+ double solver_epsilon = 0
+ )
+ {
+ return find_min_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_
+
+
diff --git a/ml/dlib/dlib/global_optimization/global_function_search.cpp b/ml/dlib/dlib/global_optimization/global_function_search.cpp
new file mode 100644
index 000000000..fada289a4
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/global_function_search.cpp
@@ -0,0 +1,942 @@
+
+#include "global_function_search.h"
+#include "upper_bound_function.h"
+#include "../optimization.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace qopt_impl
+ {
+ void fit_quadratic_to_points_mse(
+ const matrix<double>& X,
+ const matrix<double,0,1>& Y,
+ matrix<double>& H,
+ matrix<double,0,1>& g,
+ double& c
+ )
+ {
+ DLIB_CASSERT(X.size() > 0);
+ DLIB_CASSERT(X.nc() == Y.size());
+ DLIB_CASSERT(X.nc() >= (X.nr()+1)*(X.nr()+2)/2);
+
+ const long dims = X.nr();
+ const long M = X.nc();
+
+ matrix<double> W((X.nr()+1)*(X.nr()+2)/2, M);
+
+ set_subm(W, 0,0, dims, M) = X;
+ set_subm(W, dims,0, 1, M) = 1;
+ for (long c = 0; c < X.nc(); ++c)
+ {
+ long wr = dims+1;
+ for (long r = 0; r < X.nr(); ++r)
+ {
+ for (long r2 = r; r2 < X.nr(); ++r2)
+ {
+ W(wr,c) = X(r,c)*X(r2,c);
+ if (r2 == r)
+ W(wr,c) *= 0.5;
+ ++wr;
+ }
+ }
+ }
+
+ matrix<double,0,1> z = pinv(trans(W))*Y;
+
+ c = z(dims);
+ g = rowm(z, range(0,dims-1));
+
+ H.set_size(dims,dims);
+
+ long wr = dims+1;
+ for (long r = 0; r < X.nr(); ++r)
+ {
+ for (long r2 = r; r2 < X.nr(); ++r2)
+ {
+ H(r,r2) = H(r2,r) = z(wr++);
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void fit_quadratic_to_points(
+ const matrix<double>& X,
+ const matrix<double,0,1>& Y,
+ matrix<double>& H,
+ matrix<double,0,1>& g,
+ double& c
+ )
+ /*!
+ requires
+ - X.size() > 0
+ - X.nc() == Y.size()
+ - X.nr()+1 <= X.nc()
+ ensures
+ - This function finds a quadratic function, Q(x), that interpolates the
+ given set of points. If there aren't enough points to uniquely define
+ Q(x) then the Q(x) that fits the given points with the minimum Frobenius
+ norm hessian matrix is selected.
+ - To be precise:
+ - Let: Q(x) == 0.5*trans(x)*H*x + trans(x)*g + c
+ - Then this function finds H, g, and c that minimizes the following:
+ sum(squared(H))
+ such that:
+ Q(colm(X,i)) == Y(i), for all valid i
+ - If there are more points than necessary to constrain Q then the Q
+ that best interpolates the function in the mean squared sense is
+ found.
+ !*/
+ {
+ DLIB_CASSERT(X.size() > 0);
+ DLIB_CASSERT(X.nc() == Y.size());
+ DLIB_CASSERT(X.nr()+1 <= X.nc());
+
+
+ if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2)
+ {
+ fit_quadratic_to_points_mse(X,Y,H,g,c);
+ return;
+ }
+
+
+ const long dims = X.nr();
+ const long M = X.nc();
+
+ /*
+ Our implementation uses the equations 3.9 - 3.12 from the paper:
+ The NEWUOA software for unconstrained optimization without derivatives
+ By M.J.D. Powell, 40th Workshop on Large Scale Nonlinear Optimization (Erice, Italy, 2004)
+ */
+
+ matrix<double> W(M + dims + 1, M + dims + 1);
+
+ set_subm(W, 0, 0, M, M) = 0.5*squared(tmp(trans(X)*X));
+ set_subm(W, 0, M, M, 1) = 1;
+ set_subm(W, M, 0, 1, M) = 1;
+ set_subm(W, M, M, dims+1, dims+1) = 0;
+ set_subm(W, 0, M+1, X.nc(), X.nr()) = trans(X);
+ set_subm(W, M+1, 0, X.nr(), X.nc()) = X;
+
+
+ const matrix<double,0,1> r = join_cols(Y, zeros_matrix<double>(dims+1,1));
+
+ //matrix<double,0,1> z = pinv(W)*r;
+ lu_decomposition<decltype(W)> lu(W);
+ matrix<double,0,1> z = lu.solve(r);
+ //if (lu.is_singular()) std::cout << "WARNING, THE W MATRIX IS SINGULAR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" << std::endl;
+
+ matrix<double,0,1> lambda = rowm(z, range(0,M-1));
+
+ c = z(M);
+ g = rowm(z, range(M+1,z.size()-1));
+ H = X*diagm(lambda)*trans(X);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ struct quad_interp_result
+ {
+ quad_interp_result() = default;
+
+ template <typename EXP>
+ quad_interp_result(
+ const matrix_exp<EXP>& best_x,
+ double predicted_improvement
+ ) : best_x(best_x), predicted_improvement(predicted_improvement) {}
+
+ matrix<double,0,1> best_x;
+ double predicted_improvement = std::numeric_limits<double>::quiet_NaN();
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ quad_interp_result find_max_quadraticly_interpolated_vector (
+ const matrix<double,0,1>& anchor,
+ const double radius,
+ const std::vector<matrix<double,0,1>>& x,
+ const std::vector<double>& y,
+ const matrix<double,0,1>& lower,
+ const matrix<double,0,1>& upper
+ )
+ {
+ DLIB_CASSERT(x.size() == y.size());
+ DLIB_CASSERT(x.size() > 0);
+ for (size_t i = 0; i < x.size(); ++i)
+ DLIB_CASSERT(anchor.size() == x[i].size());
+
+ const long x_size = static_cast<long>(x.size());
+ DLIB_CASSERT(anchor.size()+1 <= x_size && x_size <= (anchor.size()+1)*(anchor.size()+2)/2);
+
+
+ matrix<double> X(anchor.size(), x.size());
+ matrix<double,0,1> Y(x.size());
+ for (size_t i = 0; i < x.size(); ++i)
+ {
+ set_colm(X,i) = x[i] - anchor;
+ Y(i) = y[i];
+ }
+
+ matrix<double> H;
+ matrix<double,0,1> g;
+ double c;
+
+ fit_quadratic_to_points(X, Y, H, g, c);
+
+ matrix<double,0,1> p;
+
+ solve_trust_region_subproblem_bounded(-H,-g, radius, p, 0.001, 500, lower-anchor, upper-anchor);
+
+ // ensure we never move more than radius from the anchor. This might happen if the
+ // trust region subproblem isn't solved accurately enough.
+ if (length(p) >= radius)
+ p *= radius/length(p);
+
+
+ double predicted_improvement = 0.5*trans(p)*H*p + trans(p)*g;
+ return quad_interp_result{clamp(anchor+p,lower,upper), predicted_improvement};
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ quad_interp_result pick_next_sample_using_trust_region (
+ const std::vector<function_evaluation>& samples,
+ double& radius,
+ const matrix<double,0,1>& lower,
+ const matrix<double,0,1>& upper,
+ const std::vector<bool>& is_integer_variable
+ )
+ {
+ DLIB_CASSERT(samples.size() > 0);
+ // We don't use the QP to optimize integer variables. Instead, we just fix them at
+ // their best observed value and use the QP to optimize the real variables. So the
+ // number of dimensions, as far as the QP is concerned, is the number of non-integer
+ // variables.
+ size_t dims = 0;
+ for (auto is_int : is_integer_variable)
+ {
+ if (!is_int)
+ ++dims;
+ }
+
+ DLIB_CASSERT(samples.size() >= dims+1);
+
+ // Use enough points to fill out a quadratic model or the max available if we don't
+ // have quite enough.
+ const long N = std::min(samples.size(), (dims+1)*(dims+2)/2);
+
+
+ // first find the best sample;
+ double best_val = -1e300;
+ matrix<double,0,1> best_x;
+ for (auto& v : samples)
+ {
+ if (v.y > best_val)
+ {
+ best_val = v.y;
+ best_x = v.x;
+ }
+ }
+
+ // if there are only integer variables then there isn't really anything to do. So just
+ // return the best_x and say there is no improvement.
+ if (dims == 0)
+ return quad_interp_result(best_x, 0);
+
+ matrix<long,0,1> active_dims(dims);
+ long j = 0;
+ for (size_t i = 0; i < is_integer_variable.size(); ++i)
+ {
+ if (!is_integer_variable[i])
+ active_dims(j++) = i;
+ }
+
+ // now find the N-1 nearest neighbors of best_x
+ std::vector<std::pair<double,size_t>> distances;
+ for (size_t i = 0; i < samples.size(); ++i)
+ distances.emplace_back(length(best_x-samples[i].x), i);
+ std::sort(distances.begin(), distances.end());
+ distances.resize(N);
+
+ std::vector<matrix<double,0,1>> x;
+ std::vector<double> y;
+ for (auto& idx : distances)
+ {
+ x.emplace_back(rowm(samples[idx.second].x, active_dims));
+ y.emplace_back(samples[idx.second].y);
+ }
+
+ if (radius == 0)
+ {
+ for (auto& idx : distances)
+ radius = std::max(radius, length(rowm(best_x-samples[idx.second].x, active_dims)) );
+ // Shrink the radius a little so we are always going to be making the sampling of
+ // points near the best current point smaller.
+ radius *= 0.95;
+ }
+
+
+ auto tmp = find_max_quadraticly_interpolated_vector(rowm(best_x,active_dims), radius, x, y, rowm(lower,active_dims), rowm(upper,active_dims));
+
+ // stick the integer variables back into the solution
+ for (long i = 0; i < active_dims.size(); ++i)
+ best_x(active_dims(i)) = tmp.best_x(i);
+
+ tmp.best_x = best_x;
+ return tmp;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> make_random_vector(
+ dlib::rand& rnd,
+ const matrix<double,0,1>& lower,
+ const matrix<double,0,1>& upper,
+ const std::vector<bool>& is_integer_variable
+ )
+ {
+ matrix<double,0,1> temp(lower.size());
+ for (long i = 0; i < temp.size(); ++i)
+ {
+ temp(i) = rnd.get_double_in_range(lower(i), upper(i));
+ if (is_integer_variable[i])
+ temp(i) = std::round(temp(i));
+ }
+ return temp;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ struct max_upper_bound_function
+ {
+ max_upper_bound_function() = default;
+
+ template <typename EXP>
+ max_upper_bound_function(
+ const matrix_exp<EXP>& x,
+ double predicted_improvement,
+ double upper_bound
+ ) : x(x), predicted_improvement(predicted_improvement), upper_bound(upper_bound) {}
+
+ matrix<double,0,1> x;
+ double predicted_improvement = 0;
+ double upper_bound = 0;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ max_upper_bound_function pick_next_sample_as_max_upper_bound (
+ dlib::rand& rnd,
+ const upper_bound_function& ub,
+ const matrix<double,0,1>& lower,
+ const matrix<double,0,1>& upper,
+ const std::vector<bool>& is_integer_variable,
+ const size_t num_random_samples
+ )
+ {
+ DLIB_CASSERT(ub.num_points() > 0);
+
+
+
+ // now do a simple random search to find the maximum upper bound
+ double best_ub_so_far = -std::numeric_limits<double>::infinity();
+ matrix<double,0,1> vtemp(lower.size()), v;
+ for (size_t rounds = 0; rounds < num_random_samples; ++rounds)
+ {
+ vtemp = make_random_vector(rnd, lower, upper, is_integer_variable);
+
+ double bound = ub(vtemp);
+ if (bound > best_ub_so_far)
+ {
+ best_ub_so_far = bound;
+ v = vtemp;
+ }
+ }
+
+ double max_value = -std::numeric_limits<double>::infinity();
+ for (auto& v : ub.get_points())
+ max_value = std::max(max_value, v.y);
+
+ return max_upper_bound_function(v, best_ub_so_far - max_value, best_ub_so_far);
+ }
+
+ } // end of namespace qopt_impl;
+
+ using namespace qopt_impl;
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ function_spec::function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2
+ ) :
+ lower(std::move(bound1)), upper(std::move(bound2))
+ {
+ DLIB_CASSERT(lower.size() == upper.size());
+ for (long i = 0; i < lower.size(); ++i)
+ {
+ if (upper(i) < lower(i))
+ std::swap(lower(i), upper(i));
+ DLIB_CASSERT(upper(i) != lower(i), "The upper and lower bounds can't be equal.");
+ }
+ is_integer_variable.assign(lower.size(), false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ function_spec::function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2,
+ std::vector<bool> is_integer
+ ) :
+ function_spec(std::move(bound1),std::move(bound2))
+ {
+ is_integer_variable = std::move(is_integer);
+ DLIB_CASSERT(lower.size() == (long)is_integer_variable.size());
+
+
+ // Make sure any integer variables have integer bounds.
+ for (size_t i = 0; i < is_integer_variable.size(); ++i)
+ {
+ if (is_integer_variable[i])
+ {
+ DLIB_CASSERT(std::round(lower(i)) == lower(i), "If you say a variable is an integer variable then it must have an integer lower bound. \n"
+ << "lower[i] = " << lower(i));
+ DLIB_CASSERT(std::round(upper(i)) == upper(i), "If you say a variable is an integer variable then it must have an integer upper bound. \n"
+ << "upper[i] = " << upper(i));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gopt_impl
+ {
+ upper_bound_function funct_info::build_upper_bound_with_all_function_evals (
+ ) const
+ {
+ upper_bound_function tmp(ub);
+
+ // we are going to add the outstanding evals into this and assume the
+ // outstanding evals are going to take y values equal to their nearest
+ // neighbor complete evals.
+ for (auto& eval : outstanding_evals)
+ {
+ function_evaluation e;
+ e.x = eval.x;
+ e.y = find_nn(ub.get_points(), eval.x);
+ tmp.add(e);
+ }
+
+ return tmp;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ double funct_info::find_nn (
+ const std::vector<function_evaluation>& evals,
+ const matrix<double,0,1>& x
+ )
+ {
+ double best_y = 0;
+ double best_dist = std::numeric_limits<double>::infinity();
+ for (auto& v : evals)
+ {
+ double dist = length_squared(v.x-x);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_y = v.y;
+ }
+ }
+ return best_y;
+ }
+
+ } // end namespace gopt_impl
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ function_evaluation_request::function_evaluation_request(
+ function_evaluation_request&& item
+ )
+ {
+ m_has_been_evaluated = item.m_has_been_evaluated;
+ req = item.req;
+ info = item.info;
+ item.info.reset();
+
+ item.m_has_been_evaluated = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ function_evaluation_request& function_evaluation_request::
+ operator=(
+ function_evaluation_request&& item
+ )
+ {
+ function_evaluation_request(std::move(item)).swap(*this);
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void function_evaluation_request::
+ swap(
+ function_evaluation_request& item
+ )
+ {
+ std::swap(m_has_been_evaluated, item.m_has_been_evaluated);
+ std::swap(req, item.req);
+ std::swap(info, item.info);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ size_t function_evaluation_request::
+ function_idx (
+ ) const
+ {
+ return info->function_idx;
+ }
+
+ const matrix<double,0,1>& function_evaluation_request::
+ x (
+ ) const
+ {
+ return req.x;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool function_evaluation_request::
+ has_been_evaluated (
+ ) const
+ {
+ return m_has_been_evaluated;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ function_evaluation_request::
+ ~function_evaluation_request()
+ {
+ if (!m_has_been_evaluated)
+ {
+ std::lock_guard<std::mutex> lock(*info->m);
+
+ // remove the evaluation request from the outstanding list.
+ auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
+ info->outstanding_evals.erase(i);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void function_evaluation_request::
+ set (
+ double y
+ )
+ {
+ DLIB_CASSERT(has_been_evaluated() == false);
+ std::lock_guard<std::mutex> lock(*info->m);
+
+ m_has_been_evaluated = true;
+
+
+ // move the evaluation from outstanding to complete
+ auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
+ DLIB_CASSERT(i != info->outstanding_evals.end());
+ info->outstanding_evals.erase(i);
+ info->ub.add(function_evaluation(req.x,y));
+
+
+ // Now do trust region radius maintenance and keep track of the best objective
+ // values and all that.
+ if (req.was_trust_region_generated_request)
+ {
+ // Adjust trust region radius based on how good this evaluation
+ // was.
+ double measured_improvement = y-req.anchor_objective_value;
+ double rho = measured_improvement/std::abs(req.predicted_improvement);
+ //std::cout << "rho: "<< rho << std::endl;
+ //std::cout << "radius: "<< info->radius << std::endl;
+ if (rho < 0.25)
+ info->radius *= 0.5;
+ else if (rho > 0.75)
+ info->radius *= 2;
+ }
+
+ if (y > info->best_objective_value)
+ {
+ if (!req.was_trust_region_generated_request && length(req.x - info->best_x) > info->radius*1.001)
+ {
+ //std::cout << "reset radius because of big move, " << length(req.x - info->best_x) << " radius was " << info->radius << std::endl;
+ // reset trust region radius since we made a big move. Doing this will
+ // cause the radius to be reset to the size of the local region.
+ info->radius = 0;
+ }
+ info->best_objective_value = y;
+ info->best_x = std::move(req.x);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ global_function_search::
+ global_function_search(
+ const function_spec& function
+ ) : global_function_search(std::vector<function_spec>(1,function)) {}
+
+// ----------------------------------------------------------------------------------------
+
+ global_function_search::
+ global_function_search(
+ const std::vector<function_spec>& functions_
+ )
+ {
+ DLIB_CASSERT(functions_.size() > 0);
+ m = std::make_shared<std::mutex>();
+ functions.reserve(functions_.size());
+ for (size_t i = 0; i < functions_.size(); ++i)
+ functions.emplace_back(std::make_shared<gopt_impl::funct_info>(functions_[i],i,m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ global_function_search::
+ global_function_search(
+ const std::vector<function_spec>& functions_,
+ const std::vector<std::vector<function_evaluation>>& initial_function_evals,
+ const double relative_noise_magnitude_
+ ) :
+ global_function_search(functions_)
+ {
+ DLIB_CASSERT(functions_.size() > 0);
+ DLIB_CASSERT(functions_.size() == initial_function_evals.size());
+ DLIB_CASSERT(relative_noise_magnitude >= 0);
+ relative_noise_magnitude = relative_noise_magnitude_;
+ for (size_t i = 0; i < initial_function_evals.size(); ++i)
+ {
+ functions[i]->ub = upper_bound_function(initial_function_evals[i], relative_noise_magnitude);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ size_t global_function_search::
+ num_functions(
+ ) const
+ {
+ return functions.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ set_seed (
+ time_t seed
+ )
+ {
+ rnd = dlib::rand(seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ get_function_evaluations (
+ std::vector<function_spec>& specs,
+ std::vector<std::vector<function_evaluation>>& function_evals
+ ) const
+ {
+ std::lock_guard<std::mutex> lock(*m);
+ specs.clear();
+ function_evals.clear();
+ for (size_t i = 0; i < functions.size(); ++i)
+ {
+ specs.emplace_back(functions[i]->spec);
+ function_evals.emplace_back(functions[i]->ub.get_points());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ get_best_function_eval (
+ matrix<double,0,1>& x,
+ double& y,
+ size_t& function_idx
+ ) const
+ {
+ DLIB_CASSERT(num_functions() != 0);
+
+ std::lock_guard<std::mutex> lock(*m);
+
+ // find the largest value
+ auto& info = *best_function(function_idx);
+ y = info.best_objective_value;
+ x = info.best_x;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ function_evaluation_request global_function_search::
+ get_next_x (
+ )
+ {
+ DLIB_CASSERT(num_functions() != 0);
+
+ using namespace gopt_impl;
+
+ std::lock_guard<std::mutex> lock(*m);
+
+
+ // the first thing we do is make sure each function has at least max(3,dimensionality of function) evaluations
+ for (auto& info : functions)
+ {
+ const long dims = info->spec.lower.size();
+ // If this is the very beginning of the optimization process
+ if (info->ub.num_points()+info->outstanding_evals.size() < 1)
+ {
+ outstanding_function_eval_request new_req;
+ new_req.request_id = next_request_id++;
+ // Pick the point right in the center of the bounds to evaluate first since
+ // people will commonly center the bound on a location they think is good.
+ // So might as well try there first.
+ new_req.x = (info->spec.lower + info->spec.upper)/2.0;
+ for (long i = 0; i < new_req.x.size(); ++i)
+ {
+ if (info->spec.is_integer_variable[i])
+ new_req.x(i) = std::round(new_req.x(i));
+ }
+ info->outstanding_evals.emplace_back(new_req);
+ return function_evaluation_request(new_req,info);
+ }
+ else if (info->ub.num_points() < std::max<long>(3,dims))
+ {
+ outstanding_function_eval_request new_req;
+ new_req.request_id = next_request_id++;
+ new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
+ info->outstanding_evals.emplace_back(new_req);
+ return function_evaluation_request(new_req,info);
+ }
+ }
+
+
+ if (do_trust_region_step && !has_outstanding_trust_region_request())
+ {
+ // find the currently best performing function, we will do a trust region
+ // step on it.
+ auto info = best_function();
+ const long dims = info->spec.lower.size();
+ // if we have enough points to do a trust region step
+ if (info->ub.num_points() > dims+1)
+ {
+ auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(),
+ info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
+ //std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
+ if (tmp.predicted_improvement > min_trust_region_epsilon)
+ {
+ do_trust_region_step = false;
+ outstanding_function_eval_request new_req;
+ new_req.request_id = next_request_id++;
+ new_req.x = tmp.best_x;
+ new_req.was_trust_region_generated_request = true;
+ new_req.anchor_objective_value = info->best_objective_value;
+ new_req.predicted_improvement = tmp.predicted_improvement;
+ info->outstanding_evals.emplace_back(new_req);
+ return function_evaluation_request(new_req, info);
+ }
+ }
+ }
+
+ // make it so we alternate between upper bounded and trust region steps.
+ do_trust_region_step = true;
+
+ if (rnd.get_random_double() >= pure_random_search_probability)
+ {
+ // pick a point at random to sample according to the upper bound
+ double best_upper_bound = -std::numeric_limits<double>::infinity();
+ std::shared_ptr<funct_info> best_funct;
+ matrix<double,0,1> next_sample;
+ // so figure out if any function has a good upper bound and if so pick the
+ // function with the largest upper bound for evaluation.
+ for (auto& info : functions)
+ {
+ auto tmp = pick_next_sample_as_max_upper_bound(rnd,
+ info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper,
+ info->spec.is_integer_variable, num_random_samples);
+ if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound)
+ {
+ best_upper_bound = tmp.upper_bound;
+ next_sample = std::move(tmp.x);
+ best_funct = info;
+ }
+ }
+
+ // if we found a good function to evaluate then return that.
+ if (best_funct)
+ {
+ outstanding_function_eval_request new_req;
+ new_req.request_id = next_request_id++;
+ new_req.x = std::move(next_sample);
+ best_funct->outstanding_evals.emplace_back(new_req);
+ return function_evaluation_request(new_req, best_funct);
+ }
+ }
+
+
+ // pick entirely at random
+ size_t function_idx = rnd.get_integer(functions.size());
+ auto info = functions[function_idx];
+ outstanding_function_eval_request new_req;
+ new_req.request_id = next_request_id++;
+ new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
+ info->outstanding_evals.emplace_back(new_req);
+ return function_evaluation_request(new_req, info);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double global_function_search::
+ get_pure_random_search_probability (
+ ) const
+ {
+ return pure_random_search_probability;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ set_pure_random_search_probability (
+ double prob
+ )
+ {
+ DLIB_CASSERT(0 <= prob && prob <= 1);
+ pure_random_search_probability = prob;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double global_function_search::
+ get_solver_epsilon (
+ ) const
+ {
+ return min_trust_region_epsilon;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ set_solver_epsilon (
+ double eps
+ )
+ {
+ DLIB_CASSERT(0 <= eps);
+ min_trust_region_epsilon = eps;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double global_function_search::
+ get_relative_noise_magnitude (
+ ) const
+ {
+ return relative_noise_magnitude;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ set_relative_noise_magnitude (
+ double value
+ )
+ {
+ DLIB_CASSERT(0 <= value);
+ relative_noise_magnitude = value;
+ if (m)
+ {
+ std::lock_guard<std::mutex> lock(*m);
+ // recreate all the upper bound functions with the new relative noise magnitude
+ for (auto& f : functions)
+ f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ size_t global_function_search::
+ get_monte_carlo_upper_bound_sample_num (
+ ) const
+ {
+ return num_random_samples;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void global_function_search::
+ set_monte_carlo_upper_bound_sample_num (
+ size_t num
+ )
+ {
+ DLIB_CASSERT(0 <= num);
+ num_random_samples = num;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::shared_ptr<gopt_impl::funct_info> global_function_search::
+ best_function(
+ ) const
+ {
+ size_t idx = 0;
+ return best_function(idx);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::shared_ptr<gopt_impl::funct_info> global_function_search::
+ best_function(
+ size_t& idx
+ ) const
+ {
+ auto compare = [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b)
+ { return a->best_objective_value < b->best_objective_value; };
+
+ auto i = std::max_element(functions.begin(), functions.end(), compare);
+
+ idx = std::distance(functions.begin(),i);
+ return *i;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool global_function_search::
+ has_outstanding_trust_region_request (
+ ) const
+ {
+ for (auto& f : functions)
+ {
+ for (auto& i : f->outstanding_evals)
+ {
+ if (i.was_trust_region_generated_request)
+ return true;
+ }
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
diff --git a/ml/dlib/dlib/global_optimization/global_function_search.h b/ml/dlib/dlib/global_optimization/global_function_search.h
new file mode 100644
index 000000000..fa036884a
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/global_function_search.h
@@ -0,0 +1,245 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+#define DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+
+#include "global_function_search_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include <mutex>
+#include "../rand.h"
+#include "upper_bound_function.h"
+#include "../test_for_odr_violations.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct function_spec
+ {
+ function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2
+ );
+
+ function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2,
+ std::vector<bool> is_integer
+ );
+
+ matrix<double,0,1> lower;
+ matrix<double,0,1> upper;
+ std::vector<bool> is_integer_variable;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gopt_impl
+ {
+ struct outstanding_function_eval_request
+ {
+ size_t request_id = 0; // unique id for this eval request
+ matrix<double,0,1> x; // function x to evaluate
+
+ // trust region specific stuff
+ bool was_trust_region_generated_request = false;
+ double predicted_improvement = std::numeric_limits<double>::quiet_NaN();
+ double anchor_objective_value = std::numeric_limits<double>::quiet_NaN(); // objective value at center of TR step
+
+ bool operator==(const outstanding_function_eval_request& item) const { return request_id == item.request_id; }
+ };
+
+ struct funct_info
+ {
+ funct_info() = delete;
+ funct_info(const funct_info&) = delete;
+ funct_info& operator=(const funct_info&) = delete;
+
+ funct_info(
+ const function_spec& spec,
+ size_t function_idx,
+ const std::shared_ptr<std::mutex>& m
+ ) :
+ spec(spec), function_idx(function_idx), m(m)
+ {
+ best_x = zeros_matrix(spec.lower);
+ }
+
+ upper_bound_function build_upper_bound_with_all_function_evals (
+ ) const;
+
+ static double find_nn (
+ const std::vector<function_evaluation>& evals,
+ const matrix<double,0,1>& x
+ );
+
+
+ function_spec spec;
+ size_t function_idx = 0;
+ std::shared_ptr<std::mutex> m;
+ upper_bound_function ub;
+ std::vector<outstanding_function_eval_request> outstanding_evals;
+ matrix<double,0,1> best_x;
+ double best_objective_value = -std::numeric_limits<double>::infinity();
+ double radius = 0;
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class function_evaluation_request
+ {
+ public:
+
+ function_evaluation_request() = delete;
+ function_evaluation_request(const function_evaluation_request&) = delete;
+ function_evaluation_request& operator=(const function_evaluation_request&) = delete;
+
+
+ function_evaluation_request(function_evaluation_request&& item);
+ function_evaluation_request& operator=(function_evaluation_request&& item);
+
+ ~function_evaluation_request();
+
+ size_t function_idx (
+ ) const;
+
+ const matrix<double,0,1>& x (
+ ) const;
+
+ bool has_been_evaluated (
+ ) const;
+
+ void set (
+ double y
+ );
+
+ void swap(function_evaluation_request& item);
+
+ private:
+
+ friend class global_function_search;
+
+ explicit function_evaluation_request(
+ const gopt_impl::outstanding_function_eval_request& req,
+ const std::shared_ptr<gopt_impl::funct_info>& info
+ ) : req(req), info(info) {}
+
+ bool m_has_been_evaluated = false;
+ gopt_impl::outstanding_function_eval_request req;
+ std::shared_ptr<gopt_impl::funct_info> info;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class global_function_search
+ {
+ public:
+
+ global_function_search() = default;
+
+ explicit global_function_search(
+ const function_spec& function
+ );
+
+ explicit global_function_search(
+ const std::vector<function_spec>& functions_
+ );
+
+ global_function_search(
+ const std::vector<function_spec>& functions_,
+ const std::vector<std::vector<function_evaluation>>& initial_function_evals,
+ const double relative_noise_magnitude = 0.001
+ );
+
+ global_function_search(const global_function_search&) = delete;
+ global_function_search& operator=(const global_function_search& item) = delete;
+
+ global_function_search(global_function_search&& item) = default;
+ global_function_search& operator=(global_function_search&& item) = default;
+
+ size_t num_functions(
+ ) const;
+
+ void set_seed (
+ time_t seed
+ );
+
+ void get_function_evaluations (
+ std::vector<function_spec>& specs,
+ std::vector<std::vector<function_evaluation>>& function_evals
+ ) const;
+
+ void get_best_function_eval (
+ matrix<double,0,1>& x,
+ double& y,
+ size_t& function_idx
+ ) const;
+
+ function_evaluation_request get_next_x (
+ );
+
+ double get_pure_random_search_probability (
+ ) const;
+
+ void set_pure_random_search_probability (
+ double prob
+ );
+
+ double get_solver_epsilon (
+ ) const;
+
+ void set_solver_epsilon (
+ double eps
+ );
+
+ double get_relative_noise_magnitude (
+ ) const;
+
+ void set_relative_noise_magnitude (
+ double value
+ );
+
+ size_t get_monte_carlo_upper_bound_sample_num (
+ ) const;
+
+ void set_monte_carlo_upper_bound_sample_num (
+ size_t num
+ );
+
+ private:
+
+ std::shared_ptr<gopt_impl::funct_info> best_function(
+ ) const;
+
+ std::shared_ptr<gopt_impl::funct_info> best_function(
+ size_t& idx
+ ) const;
+
+ bool has_outstanding_trust_region_request (
+ ) const;
+
+
+ dlib::rand rnd;
+ double pure_random_search_probability = 0.02;
+ double min_trust_region_epsilon = 0;
+ double relative_noise_magnitude = 0.001;
+ size_t num_random_samples = 5000;
+ bool do_trust_region_step = true;
+
+ size_t next_request_id = 1;
+
+ std::vector<std::shared_ptr<gopt_impl::funct_info>> functions;
+ std::shared_ptr<std::mutex> m;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+
diff --git a/ml/dlib/dlib/global_optimization/global_function_search_abstract.h b/ml/dlib/dlib/global_optimization/global_function_search_abstract.h
new file mode 100644
index 000000000..c8bfc3993
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/global_function_search_abstract.h
@@ -0,0 +1,605 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_
+#ifdef DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "upper_bound_function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct function_spec
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple struct that lets you define the valid inputs to a
+ multivariate function. It lets you define bound constraints for each
+ variable as well as say if a variable is integer valued or not. Therefore,
+ an instance of this struct says that a function takes upper.size() input
+ variables, where the ith variable must be in the range [lower(i) upper(i)]
+ and be an integer if is_integer_variable[i]==true.
+ !*/
+
+ function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2
+ );
+ /*!
+ requires
+ - bound1.size() == bound2.size()
+ - for all valid i: bound1(i) != bound2(i)
+ ensures
+ - #is_integer_variable.size() == bound1.size()
+ - #lower.size() == bound1.size()
+ - #upper.size() == bound1.size()
+ - for all valid i:
+ - #is_integer_variable[i] == false
+ - #lower(i) == min(bound1(i), bound2(i))
+ - #upper(i) == max(bound1(i), bound2(i))
+ !*/
+
+ function_spec(
+ matrix<double,0,1> lower,
+ matrix<double,0,1> upper,
+ std::vector<bool> is_integer
+ );
+ /*!
+ requires
+ - bound1.size() == bound2.size() == is_integer.size()
+ - for all valid i: bound1(i) != bound2(i)
+ ensures
+ - #is_integer_variable.size() == bound1.size()
+ - #lower.size() == bound1.size()
+ - #upper.size() == bound1.size()
+ - for all valid i:
+ - #is_integer_variable[i] == is_integer[i]
+ - #lower(i) == min(bound1(i), bound2(i))
+ - #upper(i) == max(bound1(i), bound2(i))
+ !*/
+
+ matrix<double,0,1> lower;
+ matrix<double,0,1> upper;
+ std::vector<bool> is_integer_variable;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class function_evaluation_request
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a request, by the global_function_search object, to
+ evaluate a real-valued function and report back the results.
+
+ THREAD SAFETY
+ You shouldn't let more than one thread touch a function_evaluation_request
+ at the same time. However, it is safe to send instances of this class to
+ other threads for processing. This lets you evaluate multiple
+ function_evaluation_requests in parallel. Any appropriate synchronization
+ with regard to the originating global_function_search instance is handled
+ automatically.
+ !*/
+
+ public:
+
+ // You can't make or copy this object, the only way to get one is from the
+ // global_function_search class via get_next_x().
+ function_evaluation_request() = delete;
+ function_evaluation_request(const function_evaluation_request&) = delete;
+ function_evaluation_request& operator=(const function_evaluation_request&) = delete;
+
+ // You can however move and swap this object.
+ function_evaluation_request(function_evaluation_request&& item);
+ function_evaluation_request& operator=(function_evaluation_request&& item);
+ /*!
+ ensures
+ - *this takes the state of item.
+ - #item.has_been_evaluated() == true
+ !*/
+
+ ~function_evaluation_request(
+ );
+ /*!
+ ensures
+ - frees all resources associated with this object.
+ - It's fine to destruct function_evaluation_requests even if they haven't
+ been evaluated yet. If this happens it will simply be as if the request
+ was never issued.
+ !*/
+
+ size_t function_idx (
+ ) const;
+ /*!
+ ensures
+ - Returns the function index that identifies which function is to be
+ evaluated.
+ !*/
+
+ const matrix<double,0,1>& x (
+ ) const;
+ /*!
+ ensures
+ - returns the input parameters to the function to be evaluated.
+ !*/
+
+ bool has_been_evaluated (
+ ) const;
+ /*!
+ ensures
+ - If this evaluation request is still outstanding then returns false,
+ otherwise returns true. That is, if the global_function_search is still
+ waiting for you report back by calling set() then
+ has_been_evaluated()==false.
+ !*/
+
+ void set (
+ double y
+ );
+ /*!
+ requires
+ - has_been_evaluated() == false
+ ensures
+ - #has_been_evaluated() == true
+ - Notifies the global_function_search instance that created this object
+ that when the function_idx()th function is evaluated with x() as input
+ then the output is y.
+ !*/
+
+ void swap(
+ function_evaluation_request& item
+ );
+ /*!
+ ensures
+ - swaps the state of *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class global_function_search
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object performs global optimization of a set of user supplied
+ functions. The goal is to maximize the following objective function:
+ max_{function_i,x_i}: function_i(x_i)
+ subject to bound constraints on each element of x_i. Moreover, each
+ element of x_i can be either real valued or integer valued. Each of the
+ functions can also take a different number of variables. Therefore, the
+ final result of the optimization tells you which function produced the
+ largest output and what input (i.e. the x value) to that function is
+ necessary to obtain that maximal value.
+
+ Importantly, the global_function_search object does not require the user to
+ supply derivatives. Moreover, the functions may contain discontinuities,
+ behave stochastically, and have many local maxima. The global_function_search
+ object will attempt to find the global optima in the face of these challenges.
+ It is also designed to use as few function evaluations as possible, making
+ it suitable for optimizing functions that are very expensive to evaluate.
+
+ It does this by alternating between two modes. A global exploration mode
+ and a local optima refinement mode. This is accomplished by building and
+ maintaining two models of the objective function:
+ 1. A global model that upper bounds our objective function. This is a
+ non-parametric piecewise linear model based on all function
+ evaluations ever seen by the global_function_search object.
+ 2. A local quadratic model fit around the best point seen so far.
+
+ The optimization procedure therefore looks like this:
+
+ while(not done)
+ {
+ DO GLOBAL EXPLORE STEP:
+ Find the point that maximizes the upper bounding model since
+ that is the point with the largest possible improvement in the
+ objective function.
+
+ Evaluate the new point and incorporate it into our models.
+
+ DO LOCAL REFINEMENT STEP:
+ Find the optimal solution to the local quadratic model.
+
+ If this point looks like it will improve on the "best point seen
+ so far" by at least get_solver_epsilon() then we evaluate that
+ point and incorporate it into our models, otherwise we ignore
+ it.
+ }
+
+ You can see that we alternate between global search and local refinement,
+ except in the case where the local model seems to have converged to within
+ get_solver_epsilon() accuracy. In that case only global search steps are
+ used. We do this in the hope that the global search will find a new and
+ better local optima to explore, which would then reactivate local
+ refinement when it has something productive to do.
+
+
+ Now let's turn our attention to the specific API defined by the
+ global_function_search object. We will begin by showing a short example of
+ its use:
+
+ // Suppose we want to find which of these functions, F() and G(), have
+ // the largest output and what input is necessary to produce the
+ // maximal output.
+ auto F = [](double a, double b) { return -std::pow(a-2,2.0) - std::pow(b-4,2.0); };
+ auto G = [](double x) { return 2-std::pow(x-5,2.0); };
+
+ // We first define function_spec objects that specify bounds on the
+ // inputs to each function. The search process will only search within
+ // these bounds.
+ function_spec spec_F({-10,-10}, {10,10});
+ function_spec spec_G({-2}, {6});
+ // Then we create a global_function_search object with those function specifications.
+ global_function_search opt({spec_F, spec_G});
+
+ // Here we run 15 iterations of the search process. Note that the user
+ // of global_function_search writes the main solver loop, which is
+ // somewhat unusual. We will discuss why that is in a moment, but for
+ // now let's look at this example.
+ for (int i = 0; i < 15; ++i)
+ {
+ // All we do here is ask the global_function_search object what to
+ // evaluate next, then do what it asked, and then report the
+ // results back by calling function_evaluation_request's set()
+ // method.
+ function_evaluation_request next = opt.get_next_x();
+ // next.function_idx() tells you which of the functions you should
+ // evaluate. We have 2 functions here (F and G) so function_idx()
+ // can take only the values 0 and 1. If, for example, we had 10
+ // functions it would take the values 0 through 9.
+ if (next.function_idx() == 0)
+ {
+ // Call F with the inputs requested by the
+ // global_function_search and report them back.
+ double a = next.x()(0);
+ double b = next.x()(1);
+ next.set(F(a,b)); // Tell the solver what happened.
+ }
+ else
+ {
+ double x = next.x()(0);
+ next.set(G(x));
+ }
+ }
+
+ // Find out what point gave the largest outputs:
+ matrix<double,0,1> x;
+ double y;
+ size_t function_idx;
+ opt.get_best_function_eval(x,y,function_idx);
+
+ cout << "function_idx: "<< function_idx << endl;
+ cout << "y: " << y << endl;
+ cout << "x: " << x << endl;
+
+ The above cout statements will print this:
+
+ function_idx: 1
+ y: 2
+ x: 5
+
+ Which is the correct result since G(5) gives the largest possible output in
+ our example.
+
+ So why does the user write the main loop? Why isn't it embedded inside
+ dlib? Well, there are two answers to this. The first is that it is. Most
+ users should just call dlib::find_max_global() which does exactly that, it
+ runs the loop for you. However, the API shown above gives you the
+ opportunity to run multiple function evaluations in parallel. For
+ instance, it is perfectly valid to call get_next_x() multiple times and
+ send the resulting function_evaluation_request objects to separate threads
+ for processing. Those separate threads can run the functions being
+ optimized (e.g. F and G or whatever) and report back by calling
+ function_evaluation_request::set(). You could even spread the work across
+ a compute cluster if you have one.
+
+ So what happens if you have N outstanding function evaluation requests?
+ Or in other words, what happens if you called get_next_x() N times and
+ haven't yet called their set() methods? Well, 1 of the N requests will be
+ a local refinement step while the N-1 other requests will be global
+ exploration steps generated from the current upper bounding model. This
+ should give you an idea of the usefulness of this kind of parallelism. If
+ for example, your functions being optimized were simple convex functions
+ this kind of parallelism wouldn't help since essentially all the
+ interesting work in the solver is going to be done by the local optimizer.
+ However, if your function has a lot of local optima, running many global
+ exploration steps in parallel might significantly reduce the time it takes
+ to find a good solution.
+
+ It should also be noted that our upper bounding model is implemented by the
+ dlib::upper_bound_function object, which is a tool that allows us to create
+ a tight upper bound on our objective function. This upper bound is
+ non-parametric and gets progressively more accurate as the optimization
+ progresses, but also more and more expensive to maintain. It causes the
+ runtime of the entire optimization procedure to be O(N^2) where N is the
+ number of objective function evaluations. So problems that require millions
+ of function evaluations to find a good solution are not appropriate for the
+ global_function_search tool. However, if your objective function is very
+ expensive to evaluate then this relatively expensive upper bounding model
+ is well worth its computational cost.
+
+ Finally, let's introduce some background literature on this algorithm. The
+ two most relevant papers in the optimization literature are:
+ Global optimization of Lipschitz functions Malherbe, Cédric and Vayatis,
+ Nicolas International Conference on Machine Learning - 2017
+ and
+ The NEWUOA software for unconstrained optimization without derivatives By
+ M.J.D. Powell, 40th Workshop on Large Scale Nonlinear Optimization (Erice,
+ Italy, 2004)
+
+ Our upper bounding model is an extension of the AdaLIPO method in the
+ Malherbe. See the documentation of dlib::upper_bound_function for more
+ details on that, as we make a number of important extensions. The other
+ part of our method, our local refinement model, is essentially the same
+ type of trust region model proposed by Powell in the above paper. That is,
+ each time we do a local refinement step we identify the best point seen so
+ far, fit a quadratic function around it using the function evaluations we
+ have collected so far, and then use a simple trust region procedure to
+ decide the next best point to evaluate based on our quadratic model.
+
+ The method proposed by Malherbe gives excellent global search performance
+ but has terrible convergence properties in the area around a maxima.
+ Powell's method on the other hand has excellent convergence in the area
+ around a local maxima, as expected by a quadratic trust region method, but
+ is aggressively local maxima seeking. It will simply get stuck in the
+ nearest local optima. Combining the two together as we do here gives us
+ excellent performance in both global search and final convergence speed
+ near a local optima. Causing the global_function_search to perform well
+ for functions with many local optima while still giving high precision
+ solutions. For instance, on typical tests problems, like the Holder table
+ function, the global_function_search object can reliably find the globally
+ optimal solution to full floating point precision in under a few hundred
+ steps.
+
+
+ THREAD SAFETY
+ You shouldn't let more than one thread touch a global_function_search
+ instance at the same time.
+ !*/
+
+ public:
+
+ global_function_search(
+ );
+ /*!
+ ensures
+ - #num_functions() == 0
+ - #get_relative_noise_magnitude() == 0.001
+ - #get_solver_epsilon() == 0
+ - #get_monte_carlo_upper_bound_sample_num() == 5000
+ - #get_pure_random_search_probability() == 0.02
+ !*/
+
+ explicit global_function_search(
+ const function_spec& function
+ );
+ /*!
+ ensures
+ - #num_functions() == 1
+ - #get_function_evaluations() will indicate that there are no function evaluations yet.
+ - #get_relative_noise_magnitude() == 0.001
+ - #get_solver_epsilon() == 0
+ - #get_monte_carlo_upper_bound_sample_num() == 5000
+ - #get_pure_random_search_probability() == 0.02
+ !*/
+
+ explicit global_function_search(
+ const std::vector<function_spec>& functions
+ );
+ /*!
+ ensures
+ - #num_functions() == functions.size()
+ - #get_function_evaluations() will indicate that there are no function evaluations yet.
+ - #get_relative_noise_magnitude() == 0.001
+ - #get_solver_epsilon() == 0
+ - #get_monte_carlo_upper_bound_sample_num() == 5000
+ - #get_pure_random_search_probability() == 0.02
+ !*/
+
+ global_function_search(
+ const std::vector<function_spec>& functions,
+ const std::vector<std::vector<function_evaluation>>& initial_function_evals,
+ const double relative_noise_magnitude = 0.001
+ );
+ /*!
+ requires
+ - functions.size() == initial_function_evals.size()
+ - relative_noise_magnitude >= 0
+ ensures
+ - #num_functions() == functions.size()
+ - #get_function_evaluations() will return the provided initial_function_evals.
+ - #get_relative_noise_magnitude() == relative_noise_magnitude
+ - #get_solver_epsilon() == 0
+ - #get_monte_carlo_upper_bound_sample_num() == 5000
+ - #get_pure_random_search_probability() == 0.02
+ !*/
+
+ // This object can't be copied.
+ global_function_search(const global_function_search&) = delete;
+ global_function_search& operator=(const global_function_search& item) = delete;
+ // But it can be moved
+ global_function_search(global_function_search&& item) = default;
+ global_function_search& operator=(global_function_search&& item) = default;
+ /*!
+ ensures
+ - moves the state of item into *this
+ - #item.num_functions() == 0
+ !*/
+
+ void set_seed (
+ time_t seed
+ );
+ /*!
+ ensures
+ - Part of this object's algorithm uses random sampling to decide what
+ points to evaluate next. Calling set_seed() lets you set the seed used
+ by the random number generator. Note that if you don't call set_seed()
+ you will always get the same deterministic behavior.
+ !*/
+
+ size_t num_functions(
+ ) const;
+ /*!
+ ensures
+ - returns the number of functions being optimized.
+ !*/
+
+ void get_function_evaluations (
+ std::vector<function_spec>& specs,
+ std::vector<std::vector<function_evaluation>>& function_evals
+ ) const;
+ /*!
+ ensures
+ - #specs.size() == num_functions()
+ - #function_evals.size() == num_functions()
+ - This function allows you to query the state of the solver. In
+ particular, you can find the function_specs for each function being
+ optimized and their recorded evaluations.
+ - for all valid i:
+ - function_evals[i] == all the function evaluations that have been
+ recorded for the ith function (i.e. the function with the
+ function_spec #specs[i]). That is, this is the record of all the x
+ and y values reported back by function_evaluation_request::set()
+ calls.
+ !*/
+
+ void get_best_function_eval (
+ matrix<double,0,1>& x,
+ double& y,
+ size_t& function_idx
+ ) const;
+ /*!
+ requires
+ - num_functions() != 0
+ ensures
+ - if (no function evaluations have been recorded yet) then
+ - The outputs of this function are in a valid but undefined state.
+ - else
+ - This function tells you which function has produced the largest
+ output seen so far. It also tells you the inputs to that function
+ that leads to those outputs (x) as well as the output value itself (y).
+ - 0 <= #function_idx < num_functions()
+ - #function_idx == the index of the function that produced the largest output seen so far.
+ - #x == the input parameters to the function that produced the largest outputs seen so far.
+ - #y == the largest output seen so far.
+ !*/
+
+ function_evaluation_request get_next_x (
+ );
+ /*!
+ requires
+ - num_functions() != 0
+ ensures
+ - Generates and returns a function evaluation request. See the discussion
+ in the WHAT THIS OBJECT REPRESENTS section above for details.
+ !*/
+
+ double get_pure_random_search_probability (
+ ) const;
+ /*!
+ ensures
+ - When we decide to do a global explore step we will, with probability
+ get_pure_random_search_probability(), sample a point completely at random
+ rather than using the upper bounding model. Therefore, if you set this
+ probability to 0 then we will depend entirely on the upper bounding
+ model. Alternatively, if you set get_pure_random_search_probability() to
+ 1 then we won't use the upper bounding model at all and instead use pure
+ random search to do global exploration. Pure random search is much
+ faster than using the upper bounding model, so if you know that your
+ objective function is especially simple it can be faster to use pure
+ random search. However, if you really know your function that well you
+ should probably use a gradient based optimizer :)
+ !*/
+
+ void set_pure_random_search_probability (
+ double prob
+ );
+ /*!
+ requires
+ - prob >= 0
+ ensures
+ - #get_pure_random_search_probability() == prob
+ !*/
+
+ double get_solver_epsilon (
+ ) const;
+ /*!
+ ensures
+ - As discussed in the WHAT THIS OBJECT REPRESENTS section, we only do a
+ local refinement step if we haven't already found the peak of the current
+ local optima. get_solver_epsilon() sets the tolerance for deciding if
+ the local search method has found the local optima. Therefore, when the
+ local trust region model runs we check if its predicted improvement in
+ the objective function is greater than get_solver_epsilon(). If it isn't
+ then we assume it has converged and we should focus entirely on global
+ search.
+
+ This means that, for instance, setting get_solver_epsilon() to 0
+ essentially instructs the solver to find each local optima to full
+ floating point precision and only then to focus on pure global search.
+ !*/
+
+ void set_solver_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps >= 0
+ ensures
+ - #get_solver_epsilon() == eps
+ !*/
+
+ double get_relative_noise_magnitude (
+ ) const;
+ /*!
+ ensures
+ - Returns the value of the relative noise magnitude parameter to the
+ dlib::upper_bound_function's used by this object. See the
+ upper_bound_function's documentation for a detailed discussion of this
+ parameter's meaning. Most users should leave this value as its default
+ setting.
+ !*/
+
+ void set_relative_noise_magnitude (
+ double value
+ );
+ /*!
+ requires
+ - value >= 0
+ ensures
+ - #get_relative_noise_magnitude() == value
+ !*/
+
+ size_t get_monte_carlo_upper_bound_sample_num (
+ ) const;
+ /*!
+ ensures
+ - To find the point that maximizes the upper bounding model we use
+ get_monte_carlo_upper_bound_sample_num() random evaluations and select
+ the largest upper bound from that set. So this parameter influences how
+ well we estimate the maximum point on the upper bounding model.
+ !*/
+
+ void set_monte_carlo_upper_bound_sample_num (
+ size_t num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_monte_carlo_upper_bound_sample_num() == num
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/global_optimization/upper_bound_function.h b/ml/dlib/dlib/global_optimization/upper_bound_function.h
new file mode 100644
index 000000000..d1957623e
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/upper_bound_function.h
@@ -0,0 +1,286 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_UPPER_bOUND_FUNCTION_Hh_
+#define DLIB_UPPER_bOUND_FUNCTION_Hh_
+
+#include "upper_bound_function_abstract.h"
+#include "../svm/svm_c_linear_dcd_trainer.h"
+#include "../statistics.h"
+#include <limits>
+#include <utility>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct function_evaluation
+ {
+ function_evaluation() = default;
+ function_evaluation(const matrix<double,0,1>& x, double y) :x(x), y(y) {}
+
+ matrix<double,0,1> x;
+ double y = std::numeric_limits<double>::quiet_NaN();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class upper_bound_function
+ {
+
+ public:
+
+ upper_bound_function(
+ ) = default;
+
+ upper_bound_function(
+ const double relative_noise_magnitude,
+ const double solver_eps
+ ) : relative_noise_magnitude(relative_noise_magnitude), solver_eps(solver_eps)
+ {
+ DLIB_CASSERT(relative_noise_magnitude >= 0);
+ DLIB_CASSERT(solver_eps > 0);
+ }
+
+ explicit upper_bound_function(
+ const std::vector<function_evaluation>& _points,
+ const double relative_noise_magnitude = 0.001,
+ const double solver_eps = 0.0001
+ ) : relative_noise_magnitude(relative_noise_magnitude), solver_eps(solver_eps), points(_points)
+ {
+ DLIB_CASSERT(relative_noise_magnitude >= 0);
+ DLIB_CASSERT(solver_eps > 0);
+
+ if (points.size() > 1)
+ {
+ DLIB_CASSERT(points[0].x.size() > 0, "The vectors can't be empty.");
+
+ const long dims = points[0].x.size();
+ for (auto& p : points)
+ DLIB_CASSERT(p.x.size() == dims, "All the vectors given to upper_bound_function must have the same dimensionality.");
+
+ learn_params();
+ }
+
+ }
+
+ void add (
+ const function_evaluation& point
+ )
+ {
+ DLIB_CASSERT(point.x.size() != 0, "The vectors can't be empty.");
+ if (points.size() == 0)
+ {
+ points.push_back(point);
+ return;
+ }
+
+ DLIB_CASSERT(point.x.size() == dimensionality(), "All the vectors given to upper_bound_function must have the same dimensionality.");
+
+ if (points.size() < 4)
+ {
+ points.push_back(point);
+ *this = upper_bound_function(points, relative_noise_magnitude, solver_eps);
+ return;
+ }
+
+ points.push_back(point);
+ // add constraints between the new point and the old points
+ for (size_t i = 0; i < points.size()-1; ++i)
+ active_constraints.push_back(std::make_pair(i,points.size()-1));
+
+ learn_params();
+ }
+
+ long num_points(
+ ) const
+ {
+ return points.size();
+ }
+
+ long dimensionality(
+ ) const
+ {
+ if (points.size() == 0)
+ return 0;
+ else
+ return points[0].x.size();
+ }
+
+ const std::vector<function_evaluation>& get_points(
+ ) const
+ {
+ return points;
+ }
+
+ double operator() (
+ const matrix<double,0,1>& x
+ ) const
+ {
+ DLIB_CASSERT(num_points() > 0);
+ DLIB_CASSERT(x.size() == dimensionality());
+
+
+
+ double upper_bound = std::numeric_limits<double>::infinity();
+
+ for (size_t i = 0; i < points.size(); ++i)
+ {
+ const double local_bound = points[i].y + std::sqrt(offsets[i] + dot(slopes, squared(x-points[i].x)));
+ upper_bound = std::min(upper_bound, local_bound);
+ }
+
+ return upper_bound;
+ }
+
+ private:
+
+ void learn_params (
+ )
+ {
+ const long dims = points[0].x.size();
+
+ using sample_type = std::vector<std::pair<size_t,double>>;
+ using kernel_type = sparse_linear_kernel<sample_type>;
+ std::vector<sample_type> x;
+ std::vector<double> y;
+
+ // We are going to normalize the data so the values aren't extreme. First, we
+ // collect statistics on our data.
+ std::vector<running_stats<double>> x_rs(dims);
+ running_stats<double> y_rs;
+ for (auto& v : points)
+ {
+ for (long i = 0; i < v.x.size(); ++i)
+ x_rs[i].add(v.x(i));
+ y_rs.add(v.y);
+ }
+
+
+ // compute normalization vectors for the data. The only reason we do this is
+ // to make the optimization well conditioned. In particular, scaling the y
+ // values will prevent numerical errors in the 1-diff*diff computation below that
+ // would otherwise result when diff is really big. Also, scaling the xvalues
+ // to be about 1 will similarly make the optimization more stable and it also
+ // has the added benefit of keeping the relative_noise_magnitude's scale
+ // constant regardless of the size of x values.
+ const double yscale = 1.0/y_rs.stddev();
+ std::vector<double> xscale(dims);
+ for (size_t i = 0; i < xscale.size(); ++i)
+ xscale[i] = 1.0/(x_rs[i].stddev()*yscale); // make it so that xscale[i]*yscale == 1/x_rs[i].stddev()
+
+ sample_type samp;
+ auto add_constraint = [&](long i, long j) {
+ samp.clear();
+ for (long k = 0; k < dims; ++k)
+ {
+ double temp = (points[i].x(k) - points[j].x(k))*xscale[k]*yscale;
+ samp.push_back(std::make_pair(k, temp*temp));
+ }
+
+ if (points[i].y > points[j].y)
+ samp.push_back(std::make_pair(dims + j, relative_noise_magnitude));
+ else
+ samp.push_back(std::make_pair(dims + i, relative_noise_magnitude));
+
+ const double diff = (points[i].y - points[j].y)*yscale;
+ samp.push_back(std::make_pair(dims + points.size(), 1-diff*diff));
+
+ x.push_back(samp);
+ y.push_back(1);
+ };
+
+ if (active_constraints.size() == 0)
+ {
+ x.reserve(points.size()*(points.size()-1)/2);
+ y.reserve(points.size()*(points.size()-1)/2);
+ for (size_t i = 0; i < points.size(); ++i)
+ {
+ for (size_t j = i+1; j < points.size(); ++j)
+ {
+ add_constraint(i,j);
+ }
+ }
+ }
+ else
+ {
+ for (auto& p : active_constraints)
+ add_constraint(p.first, p.second);
+ }
+
+
+
+
+ svm_c_linear_dcd_trainer<kernel_type> trainer;
+ trainer.set_c(std::numeric_limits<double>::infinity());
+ //trainer.be_verbose();
+ trainer.force_last_weight_to_1(true);
+ trainer.set_epsilon(solver_eps);
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+ auto df = trainer.train(x,y, state);
+
+ // save the active constraints for later so we can use them inside add() to add
+ // new points efficiently.
+ if (active_constraints.size() == 0)
+ {
+ long k = 0;
+ for (size_t i = 0; i < points.size(); ++i)
+ {
+ for (size_t j = i+1; j < points.size(); ++j)
+ {
+ if (state.get_alpha()[k++] != 0)
+ active_constraints.push_back(std::make_pair(i,j));
+ }
+ }
+ }
+ else
+ {
+ DLIB_CASSERT(state.get_alpha().size() == active_constraints.size());
+ new_active_constraints.clear();
+ for (size_t i = 0; i < state.get_alpha().size(); ++i)
+ {
+ if (state.get_alpha()[i] != 0)
+ new_active_constraints.push_back(active_constraints[i]);
+ }
+ active_constraints.swap(new_active_constraints);
+ }
+
+ //std::cout << "points.size(): " << points.size() << std::endl;
+ //std::cout << "active_constraints.size(): " << active_constraints.size() << std::endl;
+
+
+ const auto& bv = df.basis_vectors(0);
+ slopes.set_size(dims);
+ for (long i = 0; i < dims; ++i)
+ slopes(i) = bv[i].second*xscale[i]*xscale[i];
+
+ //std::cout << "slopes:" << trans(slopes);
+
+ offsets.assign(points.size(),0);
+
+
+ for (size_t i = 0; i < points.size(); ++i)
+ {
+ offsets[i] += bv[slopes.size()+i].second*relative_noise_magnitude;
+ }
+ }
+
+
+
+ double relative_noise_magnitude = 0.001;
+ double solver_eps = 0.0001;
+ std::vector<std::pair<size_t,size_t>> active_constraints, new_active_constraints;
+
+ std::vector<function_evaluation> points;
+ std::vector<double> offsets; // offsets.size() == points.size()
+ matrix<double,0,1> slopes; // slopes.size() == points[0].first.size()
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_UPPER_bOUND_FUNCTION_Hh_
+
+
diff --git a/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h b/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h
new file mode 100644
index 000000000..56b361597
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h
@@ -0,0 +1,212 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_
+#ifdef DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include <limits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct function_evaluation
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object records the output of a real valued function in response to
+ some input.
+
+ In particular, if you have a function F(x) then the function_evaluation is
+ simply a struct that records x and the scalar value F(x).
+ !*/
+
+ function_evaluation() = default;
+ function_evaluation(const matrix<double,0,1>& x, double y) :x(x), y(y) {}
+
+ matrix<double,0,1> x;
+ double y = std::numeric_limits<double>::quiet_NaN();
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class upper_bound_function
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a piecewise linear non-parametric function that can
+ be used to define an upper bound on some more complex and unknown function.
+ To describe this precisely, lets assume there is a function F(x) which you
+ are capable of sampling from but otherwise know nothing about, and that you
+ would like to find an upper bounding function U(x) such that U(x) >= F(x)
+ for any x. It would also be good if U(x)-F(x) was minimal. I.e. we would
+ like U(x) to be a tight upper bound, not something vacuous like U(x) =
+ infinity.
+
+ The upper_bound_function class is a tool for creating this kind of upper
+ bounding function from a set of function_evaluations of F(x). We do this
+ by considering only U(x) of the form:
+ U = [](matrix<double,0,1> x) {
+ double min_ub = infinity;
+ for (size_t i = 0; i < POINTS.size(); ++i) {
+ function_evaluation p = POINTS[i]
+ double local_bound = p.y + sqrt(noise_terms[i] + trans(p.x-x)*M*(p.x-x))
+ min_ub = min(min_ub, local_bound)
+ }
+ return min_ub;
+ }
+ Where POINTS is an array of function_evaluation instances drawn from F(x),
+ M is a diagonal matrix, and noise_terms is an array of scalars.
+
+ To create an upper bound U(x), the upper_bound_function takes a POINTS array
+ containing evaluations of F(x) as input and solves the following quadratic
+ program to find the parameters of U(x):
+
+ min_{M,noise_terms}: sum(squared(M)) + sum(squared(noise_terms/relative_noise_magnitude))
+ s.t. U(POINTS[i].x) >= POINTS[i].y, for all i
+ noise_terms[i] >= 0
+ min(M) >= 0
+ M is a diagonal matrix
+
+ Therefore, the quadratic program finds the U(x) that always upper bounds
+ F(x) on the supplied POINTS, but is otherwise as small as possible.
+
+
+
+ The inspiration for the upper_bound_function object came from the AdaLIPO
+ algorithm from this excellent paper:
+ Global optimization of Lipschitz functions
+ Malherbe, Cédric and Vayatis, Nicolas
+ International Conference on Machine Learning - 2017
+ In that paper, they propose to use a simpler U(x) where noise_terms is
+ always 0 and M is a diagonal matrix where each diagonal element is the same
+ value. Therefore, there is only a single scalar parameter for U(x) in
+ their formulation of the problem. This causes difficulties if F(x) is
+ stochastic or has discontinuities since, without the noise term, M will
+ become really huge and the upper bound becomes vacuously large. It is also
+ problematic if the gradient of F(x) with respect to x contains elements of
+ widely varying magnitude since the simpler formulation of U(x) assumes a
+ uniform rate of change regardless of which dimension is varying.
+ !*/
+
+ public:
+
+ upper_bound_function(
+ );
+ /*!
+ ensures
+ - #num_points() == 0
+ - #dimensionality() == 0
+ !*/
+
+ explicit upper_bound_function(
+ const std::vector<function_evaluation>& points,
+ const double relative_noise_magnitude = 0.001,
+ const double solver_eps = 0.0001
+ );
+ /*!
+ requires
+ - all the x vectors in points must have the same non-zero dimensionality.
+ - relative_noise_magnitude >= 0
+ - solver_eps > 0
+ ensures
+ - Creates an upper bounding function U(x), as described above, assuming that
+ the given points are drawn from F(x).
+ - Uses the provided relative_noise_magnitude when solving the QP, as
+ described above. Note that relative_noise_magnitude can be set to 0. If
+ you do this then all the noise terms are constrained to 0. You should
+ only do this if you know F(x) is non-stochastic and continuous
+ everywhere.
+ - When solving the QP used to find the parameters of U(x), the upper
+ bounding function, we solve the QP to solver_eps accuracy. It's
+ possible that large enough solver_eps can lead to upper bounds that don't
+ upper bound all the supplied points. But for reasonable epsilon values
+ this shouldn't be a problem.
+ - #num_points() == points.size()
+ - #dimensionality() == points[0].x.size()
+ !*/
+
+ upper_bound_function(
+ const double relative_noise_magnitude,
+ const double solver_eps
+ );
+ /*!
+ requires
+ - relative_noise_magnitude >= 0
+ - solver_eps > 0
+ ensures
+ - #num_points() == 0
+ - #dimensionality() == 0
+ - This destructor is the same as calling the above constructor with points.size()==0
+ !*/
+
+
+ void add (
+ const function_evaluation& point
+ );
+ /*!
+ requires
+ - num_points() == 0 || point.x.size() == dimensionality()
+ - point.x.size() != 0
+ ensures
+ - Adds point to get_points().
+ - Incrementally updates the upper bounding function with the given function
+ evaluation. That is, we assume that F(point.x)==point.y and solve the QP
+ described above to find the new U(x) that upper bounds all the points
+ this object knows about (i.e. all the points in get_points() and the new point).
+ - Calling add() is much faster than recreating the upper_bound_function
+ from scratch with all the points. This is because we warm start with the
+ previous solution to the QP. This is done by discarding any non-active
+ constraints and solving the QP again with only the previously active
+ constraints and the new constraints formed by all the pairs of the new
+ point and the old points. This means the QP solved by add() is much
+ smaller than the QP that would be solved by a fresh call to the
+ upper_bound_function constructor.
+ !*/
+
+ const std::vector<function_evaluation>& get_points(
+ ) const;
+ /*!
+ ensures
+ - returns the points from F(x) used to define this upper bounding function.
+ These are all the function_evaluation objects given to this object via
+ its constructor and add().
+ !*/
+
+ long num_points(
+ ) const;
+ /*!
+ ensures
+ - returns the number of points used to define the upper bounding function.
+ (i.e. returns get_points().size())
+ !*/
+
+ long dimensionality(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the input vectors to the upper bounding function.
+ !*/
+
+ double operator() (
+ const matrix<double,0,1>& x
+ ) const;
+ /*!
+ requires
+ - num_points() > 0
+ - x.size() == dimensionality()
+ ensures
+ - return U(x)
+ (i.e. returns the upper bound on F(x) at x given by our upper bounding function)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/graph.h b/ml/dlib/dlib/graph.h
new file mode 100644
index 000000000..39f7cfbb6
--- /dev/null
+++ b/ml/dlib/dlib/graph.h
@@ -0,0 +1,37 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPh_
+#define DLIB_GRAPh_
+
+#include "graph/graph_kernel_1.h"
+
+#include "algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager
+ >
+ class graph
+ {
+ graph() {}
+ public:
+
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef graph_kernel_1<T,E,mem_manager,false>
+ kernel_1a;
+ typedef graph_kernel_1<T,E,mem_manager,true>
+ kernel_1a_c;
+
+ };
+}
+
+#endif // DLIB_GRAPh_
+
+
diff --git a/ml/dlib/dlib/graph/graph_kernel_1.h b/ml/dlib/dlib/graph/graph_kernel_1.h
new file mode 100644
index 000000000..fb0d6e7a6
--- /dev/null
+++ b/ml/dlib/dlib/graph/graph_kernel_1.h
@@ -0,0 +1,629 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_KERNEl_1_
+#define DLIB_GRAPH_KERNEl_1_
+
+#include <memory>
+#include <vector>
+
+#include "../serialize.h"
+#include "../noncopyable.h"
+#include "../std_allocator.h"
+#include "../algs.h"
+#include "graph_kernel_abstract.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename node_type, typename graph, bool is_checked>
+ struct graph_checker_helper
+ {
+ /*!
+ This object is used to check preconditions based on the value of is_checked
+ !*/
+
+ static void check_neighbor (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_neighbors(),
+ "\tnode_type& graph::node_type::neighbor(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_neighbors(): " << self.number_of_neighbors()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_edge (
+ unsigned long edge_index,
+ const node_type& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(edge_index < self.number_of_neighbors(),
+ "\tE& graph::node_type::edge(edge_index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tedge_index: " << edge_index
+ << "\n\tnumber_of_neighbors(): " << self.number_of_neighbors()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_node (
+ unsigned long index,
+ const graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(index < self.number_of_nodes(),
+ "\tnode_type& graph::node(index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tindex: " << index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_has_edge (
+ unsigned long node_index1,
+ unsigned long node_index2,
+ const graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(node_index1 < self.number_of_nodes() &&
+ node_index2 < self.number_of_nodes(),
+ "\tvoid graph::has_edge(node_index1, node_index2)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tnode_index1: " << node_index1
+ << "\n\tnode_index2: " << node_index2
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+
+ static void check_add_edge (
+ unsigned long node_index1,
+ unsigned long node_index2,
+ const graph& self
+ )
+ {
+ DLIB_CASSERT(node_index1 < self.number_of_nodes() &&
+ node_index2 < self.number_of_nodes(),
+ "\tvoid graph::add_edge(node_index1, node_index2)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tnode_index1: " << node_index1
+ << "\n\tnode_index2: " << node_index2
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ DLIB_CASSERT( self.has_edge(node_index1, node_index2) == false,
+ "\tvoid graph::add_edge(node_index1, node_index2)"
+ << "\n\tYou can't add an edge if it already exists in the graph"
+ << "\n\tnode_index1: " << node_index1
+ << "\n\tnode_index2: " << node_index2
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ }
+
+ static void check_remove_edge (
+ unsigned long node_index1,
+ unsigned long node_index2,
+ const graph& self
+ )
+ {
+ DLIB_CASSERT(node_index1 < self.number_of_nodes() &&
+ node_index2 < self.number_of_nodes(),
+ "\tvoid graph::remove_edge(node_index1, node_index2)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tnode_index1: " << node_index1
+ << "\n\tnode_index2: " << node_index2
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ DLIB_CASSERT( self.has_edge(node_index1, node_index2) == true,
+ "\tvoid graph::remove_edge(node_index1, node_index2)"
+ << "\n\tYou can't remove an edge if it isn't in the graph"
+ << "\n\tnode_index1: " << node_index1
+ << "\n\tnode_index2: " << node_index2
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+
+ }
+
+ static void check_remove_node (
+ unsigned long index,
+ const graph& self
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(index < self.number_of_nodes(),
+ "\tvoid graph::remove_node(index)"
+ << "\n\tYou have specified an invalid index"
+ << "\n\tindex: " << index
+ << "\n\tnumber_of_nodes(): " << self.number_of_nodes()
+ << "\n\tthis: " << &self
+ );
+ }
+ };
+
+ template <typename node_type, typename graph>
+ struct graph_checker_helper <node_type, graph, false>
+ {
+ static inline void check_edge ( unsigned long , const node_type& ) { }
+ static inline void check_neighbor ( unsigned long , const node_type& ) { }
+ static inline void check_node ( unsigned long , const graph& ) { }
+ static inline void check_has_edge ( unsigned long , unsigned long , const graph& ) { }
+ static inline void check_add_edge ( unsigned long , unsigned long , const graph& ) { }
+ static inline void check_remove_edge ( unsigned long , unsigned long , const graph& ) { }
+ static inline void check_remove_node ( unsigned long , const graph& ) { }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager,
+ bool is_checked = true
+ >
+ class graph_kernel_1 : noncopyable
+ {
+
+ /*!
+ INITIAL VALUE
+ - nodes.size() == 0
+
+ CONVENTION
+ - nodes.size() == number_of_nodes()
+ - for all valid i:
+ - *nodes[i] == node(i)
+ - nodes[i]->neighbors.size() == nodes[i]->number_of_neighbors(i)
+ - nodes[i]->idx == i == nodes[i]->index()
+ - for all valid n:
+ - nodes[i]->neighbors[n] == pointer to the n'th parent node of i
+ - *nodes[i]->neighbors[n] == node(i).neighbor(n)
+ - *nodes[i]->edges[n] == node(i).edge(n)
+ !*/
+
+ public:
+ struct node_type;
+
+ private:
+ typedef graph_checker_helper<node_type, graph_kernel_1, is_checked> checker;
+
+
+ public:
+
+ typedef T type;
+ typedef E edge_type;
+ typedef mem_manager mem_manager_type;
+
+ graph_kernel_1(
+ ) {}
+
+ virtual ~graph_kernel_1(
+ ) {}
+
+ void clear(
+ ) { nodes.clear(); }
+
+ void set_number_of_nodes (
+ unsigned long new_size
+ );
+
+ unsigned long number_of_nodes (
+ ) const { return nodes.size(); }
+
+ node_type& node (
+ unsigned long index
+ ) { checker::check_node(index,*this); return *nodes[index]; }
+
+ const node_type& node (
+ unsigned long index
+ ) const { checker::check_node(index,*this); return *nodes[index]; }
+
+ bool has_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ ) const;
+
+ void add_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ );
+
+ void remove_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ );
+
+ unsigned long add_node (
+ );
+
+ void remove_node (
+ unsigned long index
+ );
+
+ void swap (
+ graph_kernel_1& item
+ ) { nodes.swap(item.nodes); }
+
+ public:
+
+ struct node_type
+ {
+ T data;
+ typedef graph_kernel_1 graph_type;
+
+ unsigned long index(
+ ) const { return idx; }
+
+ unsigned long number_of_neighbors (
+ ) const { return neighbors.size(); }
+
+ const node_type& neighbor (
+ unsigned long edge_index
+ ) const { checker::check_neighbor(edge_index,*this); return *neighbors[edge_index]; }
+
+ node_type& neighbor (
+ unsigned long edge_index
+ ) { checker::check_neighbor(edge_index,*this); return *neighbors[edge_index]; }
+
+ const E& edge (
+ unsigned long edge_index
+ ) const { checker::check_edge(edge_index,*this); return *edges[edge_index]; }
+
+ E& edge (
+ unsigned long edge_index
+ ) { checker::check_edge(edge_index,*this); return *edges[edge_index]; }
+
+ private:
+ friend class graph_kernel_1;
+ typedef std_allocator<node_type*,mem_manager> alloc_type;
+ typedef std_allocator<std::shared_ptr<E>,mem_manager> alloc_edge_type;
+ std::vector<node_type*,alloc_type> neighbors;
+ std::vector<std::shared_ptr<E>,alloc_edge_type> edges;
+ unsigned long idx;
+ };
+
+ private:
+
+ typedef std_allocator<std::shared_ptr<node_type>,mem_manager> alloc_type;
+ typedef std::vector<std::shared_ptr<node_type>, alloc_type> vector_type;
+ vector_type nodes;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ inline void swap (
+ graph_kernel_1<T,E,mem_manager,is_checked>& a,
+ graph_kernel_1<T,E,mem_manager,is_checked>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ struct is_graph<graph_kernel_1<T,E,mem_manager, is_checked> >
+ {
+ static const bool value = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void serialize (
+ const graph_kernel_1<T,E,mem_manager,is_checked>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.number_of_nodes(), out);
+
+ // serialize each node
+ for (unsigned long i = 0; i < item.number_of_nodes(); ++i)
+ {
+ serialize(item.node(i).data, out);
+
+ // serialize all the edges
+ for (unsigned long n = 0; n < item.node(i).number_of_neighbors(); ++n)
+ {
+ // only serialize edges that we haven't already serialized
+ if (item.node(i).neighbor(n).index() >= i)
+ {
+ serialize(item.node(i).neighbor(n).index(), out);
+ serialize(item.node(i).edge(n), out);
+ }
+ }
+ const unsigned long stop_mark = 0xFFFFFFFF;
+ serialize(stop_mark, out);
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type graph_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void deserialize (
+ graph_kernel_1<T,E,mem_manager,is_checked>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size, in);
+
+ item.clear();
+ item.set_number_of_nodes(size);
+
+ // deserialize each node
+ for (unsigned long i = 0; i < item.number_of_nodes(); ++i)
+ {
+ deserialize(item.node(i).data, in);
+
+ const unsigned long stop_mark = 0xFFFFFFFF;
+ // Add all the edges going to this node's neighbors
+ unsigned long index;
+ deserialize(index, in);
+ while (index != stop_mark)
+ {
+ item.add_edge(i, index);
+ // find the edge
+ unsigned long j = 0;
+ for (j = 0; j < item.node(i).number_of_neighbors(); ++j)
+ if (item.node(i).neighbor(j).index() == index)
+ break;
+
+ deserialize(item.node(i).edge(j), in);
+ deserialize(index, in);
+ }
+
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type graph_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void graph_kernel_1<T,E,mem_manager,is_checked>::
+ set_number_of_nodes (
+ unsigned long new_size
+ )
+ {
+ try
+ {
+ nodes.resize(new_size);
+ for (unsigned long i = 0; i < nodes.size(); ++i)
+ {
+ nodes[i].reset(new node_type);
+ nodes[i]->idx = i;
+ }
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ bool graph_kernel_1<T,E,mem_manager,is_checked>::
+ has_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ ) const
+ {
+ checker::check_has_edge(node_index1, node_index2, *this);
+
+ node_type& n = *nodes[node_index1];
+
+ // search all the child nodes to see if there is a link to the right node
+ for (unsigned long i = 0; i < n.neighbors.size(); ++i)
+ {
+ if (n.neighbors[i]->idx == node_index2)
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void graph_kernel_1<T,E,mem_manager,is_checked>::
+ add_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ )
+ {
+ checker::check_add_edge(node_index1, node_index2, *this);
+ try
+ {
+ node_type& n1 = *nodes[node_index1];
+ node_type& n2 = *nodes[node_index2];
+
+ n1.neighbors.push_back(&n2);
+
+ std::shared_ptr<E> e(new E);
+ n1.edges.push_back(e);
+
+ // don't add this twice if this is an edge from node_index1 back to itself
+ if (node_index1 != node_index2)
+ {
+ n2.neighbors.push_back(&n1);
+ n2.edges.push_back(e);
+ }
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void graph_kernel_1<T,E,mem_manager,is_checked>::
+ remove_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ )
+ {
+ checker::check_remove_edge(node_index1, node_index2, *this);
+
+ node_type& n1 = *nodes[node_index1];
+ node_type& n2 = *nodes[node_index2];
+
+ // remove the record of the link from n1
+ unsigned long pos = static_cast<unsigned long>(find(n1.neighbors.begin(), n1.neighbors.end(), &n2) - n1.neighbors.begin());
+ n1.neighbors.erase(n1.neighbors.begin() + pos);
+ n1.edges.erase(n1.edges.begin() + pos);
+
+ // check if this is an edge that goes from node_index1 back to itself
+ if (node_index1 != node_index2)
+ {
+ // remove the record of the link from n2
+ unsigned long pos = static_cast<unsigned long>(find(n2.neighbors.begin(), n2.neighbors.end(), &n1) - n2.neighbors.begin());
+ n2.neighbors.erase(n2.neighbors.begin() + pos);
+ n2.edges.erase(n2.edges.begin() + pos);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ unsigned long graph_kernel_1<T,E,mem_manager,is_checked>::
+ add_node (
+ )
+ {
+ try
+ {
+ std::shared_ptr<node_type> n(new node_type);
+ n->idx = nodes.size();
+ nodes.push_back(n);
+ return n->idx;
+ }
+ catch (...)
+ {
+ clear();
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager,
+ bool is_checked
+ >
+ void graph_kernel_1<T,E,mem_manager,is_checked>::
+ remove_node (
+ unsigned long index
+ )
+ {
+ checker::check_remove_node(index,*this);
+
+ node_type& n = *nodes[index];
+
+ // remove all edges pointing to this node from its neighbors
+ for (unsigned long i = 0; i < n.neighbors.size(); ++i)
+ {
+ // remove the edge from this specific parent
+ unsigned long pos = static_cast<unsigned long>(find(n.neighbors[i]->neighbors.begin(), n.neighbors[i]->neighbors.end(), &n) -
+ n.neighbors[i]->neighbors.begin());
+ n.neighbors[i]->neighbors.erase(n.neighbors[i]->neighbors.begin() + pos);
+ n.neighbors[i]->edges.erase(n.neighbors[i]->edges.begin() + pos);
+ }
+
+ // now remove this node by replacing it with the last node in the nodes vector
+ nodes[index] = nodes[nodes.size()-1];
+
+ // update the index for the node we just moved
+ nodes[index]->idx = index;
+
+ // now remove the duplicated node at the end of the vector
+ nodes.pop_back();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_KERNEl_1_
+
diff --git a/ml/dlib/dlib/graph/graph_kernel_abstract.h b/ml/dlib/dlib/graph/graph_kernel_abstract.h
new file mode 100644
index 000000000..e6e699332
--- /dev/null
+++ b/ml/dlib/dlib/graph/graph_kernel_abstract.h
@@ -0,0 +1,329 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GRAPH_KERNEl_ABSTRACT_
+#ifdef DLIB_GRAPH_KERNEl_ABSTRACT_
+
+#include "../serialize.h"
+#include "../algs.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename E = char,
+ typename mem_manager = default_memory_manager
+ >
+ class graph : noncopyable
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON E
+ E must be swappable by a global swap() and
+ E must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ The only time pointers or references to nodes or edges become invalid is when
+ they reference nodes or edges that have been removed from a graph.
+
+ INITIAL VALUE
+ number_of_nodes() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an undirected graph which is a set of nodes with undirected
+ edges connecting various nodes.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef E edge_type;
+ typedef mem_manager mem_manager_type;
+
+ graph(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ !*/
+
+ virtual ~graph(
+ );
+ /*!
+ ensures
+ - all resources associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void set_number_of_nodes (
+ unsigned long new_size
+ );
+ /*!
+ ensures
+ - #number_of_nodes() == new_size
+ - for all i < new_size:
+ - number_of_neighbors(i) == 0
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in this graph
+ !*/
+
+ struct node_type
+ {
+ T data;
+ typedef graph graph_type;
+
+ unsigned long index(
+ ) const;
+ /*!
+ ensures
+ - let G be the graph that contains the node *this
+ - returns a number N such that G.node(N) == *this
+ (i.e. returns the index of this node in the graph)
+ !*/
+
+ unsigned long number_of_neighbors (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in this graph that are
+ adjacent to this node. I.e. the number of nodes
+ that are directly connected to this node via an edge.
+ !*/
+
+ const node_type& neighbor (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_neighbors()
+ ensures
+ - returns a const reference to the edge_index'th neighbor of *this
+ !*/
+
+ node_type& neighbor (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_neighbors()
+ ensures
+ - returns a non-const reference to the edge_index'th neighbor of *this
+ !*/
+
+ const E& edge (
+ unsigned long edge_index
+ ) const;
+ /*!
+ requires
+ - edge_index < number_of_neighbors()
+ ensures
+ - returns a const reference to the edge_index'th edge data for the
+ edge connecting to neighbor this->neighbor(edge_index)
+ !*/
+
+ E& edge (
+ unsigned long edge_index
+ );
+ /*!
+ requires
+ - edge_index < number_of_neighbors()
+ ensures
+ - returns a non-const reference to the edge_index'th edge data for the
+ edge connecting to neighbor this->neighbor(edge_index)
+ !*/
+
+ };
+
+ node_type& node (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - returns a non-const reference to the node with the given index
+ !*/
+
+ const node_type& node (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - returns a const reference to the node with the given index
+ !*/
+
+ bool has_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ ) const;
+ /*!
+ requires
+ - node_index1 < number_of_nodes()
+ - node_index2 < number_of_nodes()
+ ensures
+ - if (there is an edge connecting node(node_index1) and node(node_index2)) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void add_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ );
+ /*!
+ requires
+ - node_index1 < number_of_nodes()
+ - node_index2 < number_of_nodes()
+ - has_edge(node_index1, node_index2) == false
+ ensures
+ - #has_edge(node_index1, node_index2) == true
+ throws
+ - std::bad_alloc
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void remove_edge (
+ unsigned long node_index1,
+ unsigned long node_index2
+ );
+ /*!
+ requires
+ - node_index1 < number_of_nodes()
+ - node_index2 < number_of_nodes()
+ - has_edge(node_index1, node_index2) == true
+ ensures
+ - #has_edge(node_index1, node_index2) == false
+ throws
+ - std::bad_alloc
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ unsigned long add_node (
+ );
+ /*!
+ ensures
+ - does not change the index number of existing nodes
+ - adds a node with index N == number_of_nodes() such that:
+ - #node(N).number_of_neighbors() == 0
+ - #number_of_nodes() == number_of_nodes() + 1
+ - returns N
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void remove_node (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < number_of_nodes()
+ ensures
+ - removes the node with the given index from the graph.
+ - removes all edges linking the removed node to the rest
+ of the graph.
+ - the remaining node indexes are remapped so that they remain
+ contiguous. (This means that for all valid N, node(N) doesn't
+ necessarily reference the same node as #node(N))
+ - #number_of_nodes() == number_of_nodes() - 1
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then this object reverts back
+ to its initial state.
+ !*/
+
+ void swap (
+ graph& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager
+ >
+ inline void swap (
+ graph<T,E,mem_manager>& a,
+ graph<T,E,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager
+ >
+ void serialize (
+ const graph<T,E,mem_manager>& item,
+ std::ostream& out
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template <
+ typename T,
+ typename E,
+ typename mem_manager
+ >
+ void deserialize (
+ graph<T,E,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_GRAPH_KERNEl_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/graph_cuts.h b/ml/dlib/dlib/graph_cuts.h
new file mode 100644
index 000000000..c245b2be4
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_CUTs_HEADER_
+#define DLIB_GRAPH_CUTs_HEADER_
+
+#include "graph_cuts/min_cut.h"
+#include "graph_cuts/general_flow_graph.h"
+#include "graph_cuts/find_max_factor_graph_potts.h"
+#include "graph_cuts/graph_labeler.h"
+
+#endif // DLIB_GRAPH_CUTs_HEADER_
+
+
+
diff --git a/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
new file mode 100644
index 000000000..f035442bf
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
@@ -0,0 +1,959 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+#define DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+
+#include "find_max_factor_graph_potts_abstract.h"
+#include "../matrix.h"
+#include "min_cut.h"
+#include "general_potts_problem.h"
+#include "../algs.h"
+#include "../graph_utils.h"
+#include "../array2d.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <
+ typename potts_problem,
+ typename T = void
+ >
+ class flows_container
+ {
+ /*
+ This object notionally represents a matrix of flow values. It's
+ overloaded to represent this matrix efficiently though. In this case
+ it represents the matrix using a sparse representation.
+ */
+
+ typedef typename potts_problem::value_type edge_type;
+ std::vector<std::vector<edge_type> > flows;
+ public:
+
+ void setup(
+ const potts_problem& p
+ )
+ {
+ flows.resize(p.number_of_nodes());
+ for (unsigned long i = 0; i < flows.size(); ++i)
+ {
+ flows[i].resize(p.number_of_neighbors(i));
+ }
+ }
+
+ edge_type& operator() (
+ const long r,
+ const long c
+ ) { return flows[r][c]; }
+
+ const edge_type& operator() (
+ const long r,
+ const long c
+ ) const { return flows[r][c]; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ class flows_container<potts_problem,
+ typename enable_if_c<potts_problem::max_number_of_neighbors!=0>::type>
+ {
+ /*
+ This object notionally represents a matrix of flow values. It's
+ overloaded to represent this matrix efficiently though. In this case
+ it represents the matrix using a dense representation.
+
+ */
+ typedef typename potts_problem::value_type edge_type;
+ const static unsigned long max_number_of_neighbors = potts_problem::max_number_of_neighbors;
+ matrix<edge_type,0,max_number_of_neighbors> flows;
+ public:
+
+ void setup(
+ const potts_problem& p
+ )
+ {
+ flows.set_size(p.number_of_nodes(), max_number_of_neighbors);
+ }
+
+ edge_type& operator() (
+ const long r,
+ const long c
+ ) { return flows(r,c); }
+
+ const edge_type& operator() (
+ const long r,
+ const long c
+ ) const { return flows(r,c); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ class potts_flow_graph
+ {
+ public:
+ typedef typename potts_problem::value_type edge_type;
+ private:
+ /*!
+ This is a utility class used by dlib::min_cut to convert a potts_problem
+ into the kind of flow graph expected by the min_cut object's main block
+ of code.
+
+ Within this object, we will use the convention that one past
+ potts_problem::number_of_nodes() is the source node and two past is
+ the sink node.
+ !*/
+
+ potts_problem& g;
+
+ // flows(i,j) == the flow from node id i to it's jth neighbor
+ flows_container<potts_problem> flows;
+ // source_flows(i,0) == flow from source to node i,
+ // source_flows(i,1) == flow from node i to source
+ matrix<edge_type,0,2> source_flows;
+
+ // sink_flows(i,0) == flow from sink to node i,
+ // sink_flows(i,1) == flow from node i to sink
+ matrix<edge_type,0,2> sink_flows;
+
+ node_label source_label, sink_label;
+ public:
+
+ potts_flow_graph(
+ potts_problem& g_
+ ) : g(g_)
+ {
+ flows.setup(g);
+
+ source_flows.set_size(g.number_of_nodes(), 2);
+ sink_flows.set_size(g.number_of_nodes(), 2);
+ source_flows = 0;
+ sink_flows = 0;
+
+ source_label = FREE_NODE;
+ sink_label = FREE_NODE;
+
+ // setup flows based on factor potentials
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ const edge_type temp = g.factor_value(i);
+ if (temp < 0)
+ source_flows(i,0) = -temp;
+ else
+ sink_flows(i,1) = temp;
+
+ for (unsigned long j = 0; j < g.number_of_neighbors(i); ++j)
+ {
+ flows(i,j) = g.factor_value_disagreement(i, g.get_neighbor(i,j));
+ }
+ }
+ }
+
+ class out_edge_iterator
+ {
+ friend class potts_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ out_edge_iterator(
+ ):idx(0),cnt(0){}
+
+ out_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const out_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ out_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ class in_edge_iterator
+ {
+ friend class potts_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ in_edge_iterator(
+ ):idx(0),cnt(0)
+ {}
+
+
+ in_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const in_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ in_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ unsigned long number_of_nodes (
+ ) const { return g.number_of_nodes() + 2; }
+
+ out_edge_iterator out_begin(
+ const unsigned long& it
+ ) const { return out_edge_iterator(it, 0); }
+
+ in_edge_iterator in_begin(
+ const unsigned long& it
+ ) const { return in_edge_iterator(it, 0); }
+
+ out_edge_iterator out_end(
+ const unsigned long& it
+ ) const
+ {
+ if (it >= g.number_of_nodes())
+ return out_edge_iterator(it, g.number_of_nodes());
+ else
+ return out_edge_iterator(it, g.number_of_neighbors(it)+2);
+ }
+
+ in_edge_iterator in_end(
+ const unsigned long& it
+ ) const
+ {
+ if (it >= g.number_of_nodes())
+ return in_edge_iterator(it, g.number_of_nodes());
+ else
+ return in_edge_iterator(it, g.number_of_neighbors(it)+2);
+ }
+
+
+ template <typename iterator_type>
+ unsigned long node_id (
+ const iterator_type& it
+ ) const
+ {
+ // if this isn't an iterator over the source or sink nodes
+ if (it.idx < g.number_of_nodes())
+ {
+ const unsigned long num = g.number_of_neighbors(it.idx);
+ if (it.cnt < num)
+ return g.get_neighbor(it.idx, it.cnt);
+ else if (it.cnt == num)
+ return g.number_of_nodes();
+ else
+ return g.number_of_nodes()+1;
+ }
+ else
+ {
+ return it.cnt;
+ }
+ }
+
+
+ edge_type get_flow (
+ const unsigned long& it1,
+ const unsigned long& it2
+ ) const
+ {
+ if (it1 >= g.number_of_nodes())
+ {
+ // if it1 is the source
+ if (it1 == g.number_of_nodes())
+ {
+ return source_flows(it2,0);
+ }
+ else // if it1 is the sink
+ {
+ return sink_flows(it2,0);
+ }
+ }
+ else if (it2 >= g.number_of_nodes())
+ {
+ // if it2 is the source
+ if (it2 == g.number_of_nodes())
+ {
+ return source_flows(it1,1);
+ }
+ else // if it2 is the sink
+ {
+ return sink_flows(it1,1);
+ }
+ }
+ else
+ {
+ return flows(it1, g.get_neighbor_idx(it1, it2));
+ }
+
+ }
+
+ edge_type get_flow (
+ const out_edge_iterator& it
+ ) const
+ {
+ if (it.idx < g.number_of_nodes())
+ {
+ const unsigned long num = g.number_of_neighbors(it.idx);
+ if (it.cnt < num)
+ return flows(it.idx, it.cnt);
+ else if (it.cnt == num)
+ return source_flows(it.idx,1);
+ else
+ return sink_flows(it.idx,1);
+ }
+ else
+ {
+ // if it.idx is the source
+ if (it.idx == g.number_of_nodes())
+ {
+ return source_flows(it.cnt,0);
+ }
+ else // if it.idx is the sink
+ {
+ return sink_flows(it.cnt,0);
+ }
+ }
+ }
+
+ edge_type get_flow (
+ const in_edge_iterator& it
+ ) const
+ {
+ return get_flow(node_id(it), it.idx);
+ }
+
+ void adjust_flow (
+ const unsigned long& it1,
+ const unsigned long& it2,
+ const edge_type& value
+ )
+ {
+ if (it1 >= g.number_of_nodes())
+ {
+ // if it1 is the source
+ if (it1 == g.number_of_nodes())
+ {
+ source_flows(it2,0) += value;
+ source_flows(it2,1) -= value;
+ }
+ else // if it1 is the sink
+ {
+ sink_flows(it2,0) += value;
+ sink_flows(it2,1) -= value;
+ }
+ }
+ else if (it2 >= g.number_of_nodes())
+ {
+ // if it2 is the source
+ if (it2 == g.number_of_nodes())
+ {
+ source_flows(it1,1) += value;
+ source_flows(it1,0) -= value;
+ }
+ else // if it2 is the sink
+ {
+ sink_flows(it1,1) += value;
+ sink_flows(it1,0) -= value;
+ }
+ }
+ else
+ {
+ flows(it1, g.get_neighbor_idx(it1, it2)) += value;
+ flows(it2, g.get_neighbor_idx(it2, it1)) -= value;
+ }
+
+ }
+
+ void set_label (
+ const unsigned long& it,
+ node_label value
+ )
+ {
+ if (it < g.number_of_nodes())
+ g.set_label(it, value);
+ else if (it == g.number_of_nodes())
+ source_label = value;
+ else
+ sink_label = value;
+ }
+
+ node_label get_label (
+ const unsigned long& it
+ ) const
+ {
+ if (it < g.number_of_nodes())
+ return g.get_label(it);
+ if (it == g.number_of_nodes())
+ return source_label;
+ else
+ return sink_label;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename label_image_type,
+ typename image_potts_model
+ >
+ class potts_grid_problem
+ {
+ label_image_type& label_img;
+ long nc;
+ long num_nodes;
+ unsigned char* labels;
+ const image_potts_model& model;
+
+ public:
+ const static unsigned long max_number_of_neighbors = 4;
+
+ potts_grid_problem (
+ label_image_type& label_img_,
+ const image_potts_model& image_potts_model_
+ ) :
+ label_img(label_img_),
+ model(image_potts_model_)
+ {
+ num_nodes = model.nr()*model.nc();
+ nc = model.nc();
+ labels = &label_img[0][0];
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return num_nodes; }
+
+ unsigned long number_of_neighbors (
+ unsigned long
+ ) const
+ {
+ return 4;
+ }
+
+ unsigned long get_neighbor_idx (
+ long node_id1,
+ long node_id2
+ ) const
+ {
+ long diff = node_id2-node_id1;
+ if (diff > nc)
+ diff -= (long)number_of_nodes();
+ else if (diff < -nc)
+ diff += (long)number_of_nodes();
+
+ if (diff == 1)
+ return 0;
+ else if (diff == -1)
+ return 1;
+ else if (diff == nc)
+ return 2;
+ else
+ return 3;
+ }
+
+ unsigned long get_neighbor (
+ long node_id,
+ long idx
+ ) const
+ {
+ switch(idx)
+ {
+ case 0:
+ {
+ long temp = node_id+1;
+ if (temp < (long)number_of_nodes())
+ return temp;
+ else
+ return temp - (long)number_of_nodes();
+ }
+ case 1:
+ {
+ long temp = node_id-1;
+ if (node_id >= 1)
+ return temp;
+ else
+ return temp + (long)number_of_nodes();
+ }
+ case 2:
+ {
+ long temp = node_id+nc;
+ if (temp < (long)number_of_nodes())
+ return temp;
+ else
+ return temp - (long)number_of_nodes();
+ }
+ case 3:
+ {
+ long temp = node_id-nc;
+ if (node_id >= nc)
+ return temp;
+ else
+ return temp + (long)number_of_nodes();
+ }
+ }
+ return 0;
+ }
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ )
+ {
+ *(labels+idx) = value;
+ }
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const
+ {
+ return *(labels+idx);
+ }
+
+ typedef typename image_potts_model::value_type value_type;
+
+ value_type factor_value (unsigned long idx) const
+ {
+ return model.factor_value(idx);
+ }
+
+ value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
+ {
+ return model.factor_value_disagreement(idx1,idx2);
+ }
+
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_model
+ >
+ typename potts_model::value_type potts_model_score (
+ const potts_model& prob
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < prob.number_of_neighbors(i); ++jj)
+ {
+ unsigned long j = prob.get_neighbor(i,jj);
+ DLIB_ASSERT(prob.factor_value_disagreement(i,j) >= 0,
+ "\t value_type potts_model_score(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
+ );
+ DLIB_ASSERT(prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i),
+ "\t value_type potts_model_score(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
+ << "\n\t prob.factor_value_disagreement(j,i): " << prob.factor_value_disagreement(j,i)
+ );
+ }
+ }
+#endif
+
+ typename potts_model::value_type score = 0;
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ const bool label = (prob.get_label(i)!=0);
+ if (label)
+ score += prob.factor_value(i);
+ }
+
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < prob.number_of_neighbors(i); ++n)
+ {
+ const unsigned long idx2 = prob.get_neighbor(i,n);
+ const bool label_i = (prob.get_label(i)!=0);
+ const bool label_idx2 = (prob.get_label(idx2)!=0);
+ if (label_i != label_idx2 && i < idx2)
+ score -= prob.factor_value_disagreement(i, idx2);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ typename graph_type::edge_type potts_model_score (
+ const graph_type& g,
+ const std::vector<node_label>& labels
+ )
+ {
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\t edge_type potts_model_score(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ );
+ typedef typename graph_type::edge_type edge_type;
+ typedef typename graph_type::type type;
+
+ // The edges and node's have to use the same type to represent factor weights!
+ COMPILE_TIME_ASSERT((is_same_type<edge_type, type>::value == true));
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj)
+ {
+ unsigned long j = g.node(i).neighbor(jj).index();
+ DLIB_ASSERT(edge(g,i,j) >= 0,
+ "\t edge_type potts_model_score(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t edge(g,i,j): " << edge(g,i,j)
+ );
+ }
+ }
+#endif
+
+ typename graph_type::edge_type score = 0;
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ const bool label = (labels[i]!=0);
+ if (label)
+ score += g.node(i).data;
+ }
+
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long idx2 = g.node(i).neighbor(n).index();
+ const bool label_i = (labels[i]!=0);
+ const bool label_idx2 = (labels[idx2]!=0);
+ if (label_i != label_idx2 && i < idx2)
+ score -= g.node(i).edge(n);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ typename potts_grid_problem::value_type potts_model_score (
+ const potts_grid_problem& prob,
+ const array2d<node_label,mem_manager>& labels
+ )
+ {
+ DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(),
+ "\t value_type potts_model_score(prob,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t prob.nr(): " << labels.nr()
+ << "\n\t prob.nc(): " << labels.nc()
+ );
+ typedef array2d<node_label,mem_manager> image_type;
+ // This const_cast is ok because the model object won't actually modify labels
+ dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(const_cast<image_type&>(labels),prob);
+ return potts_model_score(model);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_model
+ >
+ void find_max_factor_graph_potts (
+ potts_model& prob
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ for (unsigned long node_i = 0; node_i < prob.number_of_nodes(); ++node_i)
+ {
+ for (unsigned long jj = 0; jj < prob.number_of_neighbors(node_i); ++jj)
+ {
+ unsigned long node_j = prob.get_neighbor(node_i,jj);
+ DLIB_ASSERT(prob.get_neighbor_idx(node_j,node_i) < prob.number_of_neighbors(node_j),
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The supplied potts problem defines an invalid graph."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
+ << "\n\t prob.number_of_neighbors(node_j): " << prob.number_of_neighbors(node_j)
+ );
+
+ DLIB_ASSERT(prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)) == jj,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other."
+ << "\n\t node_i: " << node_i
+ << "\n\t jj: " << jj
+ << "\n\t prob.get_neighbor(node_i,jj): " << prob.get_neighbor(node_i,jj)
+ << "\n\t prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)): " << prob.get_neighbor_idx(node_i,node_j)
+ );
+
+ DLIB_ASSERT(prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))==node_i,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
+ << "\n\t prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i)): " << prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))
+ );
+
+ DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) >= 0,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
+ );
+ DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) == prob.factor_value_disagreement(node_j,node_i),
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
+ << "\n\t prob.factor_value_disagreement(node_j,node_i): " << prob.factor_value_disagreement(node_j,node_i)
+ );
+ }
+ }
+#endif
+ COMPILE_TIME_ASSERT(is_signed_type<typename potts_model::value_type>::value);
+ min_cut mc;
+ dlib::impl::potts_flow_graph<potts_model> pfg(prob);
+ mc(pfg, prob.number_of_nodes(), prob.number_of_nodes()+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ void find_max_factor_graph_potts (
+ const graph_type& g,
+ std::vector<node_label>& labels
+ )
+ {
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\t void find_max_factor_graph_potts(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ );
+ typedef typename graph_type::edge_type edge_type;
+ typedef typename graph_type::type type;
+
+ // The edges and node's have to use the same type to represent factor weights!
+ COMPILE_TIME_ASSERT((is_same_type<edge_type, type>::value == true));
+ COMPILE_TIME_ASSERT(is_signed_type<edge_type>::value);
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj)
+ {
+ unsigned long j = g.node(i).neighbor(jj).index();
+ DLIB_ASSERT(edge(g,i,j) >= 0,
+ "\t void find_max_factor_graph_potts(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t edge(g,i,j): " << edge(g,i,j)
+ );
+ }
+ }
+#endif
+
+ dlib::impl::general_potts_problem<graph_type> gg(g, labels);
+ find_max_factor_graph_potts(gg);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ void find_max_factor_graph_potts (
+ const potts_grid_problem& prob,
+ array2d<node_label,mem_manager>& labels
+ )
+ {
+ typedef array2d<node_label,mem_manager> image_type;
+ labels.set_size(prob.nr(), prob.nc());
+ dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(labels,prob);
+ find_max_factor_graph_potts(model);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename pixel_type1,
+ typename pixel_type2,
+ typename model_type
+ >
+ struct potts_grid_image_pair_model
+ {
+ const pixel_type1* data1;
+ const pixel_type2* data2;
+ const model_type& model;
+ const long nr_;
+ const long nc_;
+ template <typename image_type1, typename image_type2>
+ potts_grid_image_pair_model(
+ const model_type& model_,
+ const image_type1& img1,
+ const image_type2& img2
+ ) :
+ model(model_),
+ nr_(img1.nr()),
+ nc_(img1.nc())
+ {
+ data1 = &img1[0][0];
+ data2 = &img2[0][0];
+ }
+
+ typedef typename model_type::value_type value_type;
+
+ long nr() const { return nr_; }
+ long nc() const { return nc_; }
+
+ value_type factor_value (
+ unsigned long idx
+ ) const
+ {
+ return model.factor_value(*(data1 + idx), *(data2 + idx));
+ }
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename model_type
+ >
+ struct potts_grid_image_single_model
+ {
+ const typename image_type::type* data1;
+ const model_type& model;
+ const long nr_;
+ const long nc_;
+ potts_grid_image_single_model(
+ const model_type& model_,
+ const image_type& img1
+ ) :
+ model(model_),
+ nr_(img1.nr()),
+ nc_(img1.nc())
+ {
+ data1 = &img1[0][0];
+ }
+
+ typedef typename model_type::value_type value_type;
+
+ long nr() const { return nr_; }
+ long nc() const { return nc_; }
+
+ value_type factor_value (
+ unsigned long idx
+ ) const
+ {
+ return model.factor_value(*(data1 + idx));
+ }
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
+ }
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pair_image_model,
+ typename pixel_type1,
+ typename pixel_type2,
+ typename mem_manager
+ >
+ impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model> make_potts_grid_problem (
+ const pair_image_model& model,
+ const array2d<pixel_type1,mem_manager>& img1,
+ const array2d<pixel_type2,mem_manager>& img2
+ )
+ {
+ DLIB_ASSERT(get_rect(img1) == get_rect(img2),
+ "\t potts_grid_problem make_potts_grid_problem()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t get_rect(img1): " << get_rect(img1)
+ << "\n\t get_rect(img2): " << get_rect(img2)
+ );
+ typedef impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model> potts_type;
+ return potts_type(model,img1,img2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename single_image_model,
+ typename pixel_type,
+ typename mem_manager
+ >
+ impl::potts_grid_image_single_model<array2d<pixel_type,mem_manager>, single_image_model> make_potts_grid_problem (
+ const single_image_model& model,
+ const array2d<pixel_type,mem_manager>& img
+ )
+ {
+ typedef impl::potts_grid_image_single_model<array2d<pixel_type,mem_manager>, single_image_model> potts_type;
+ return potts_type(model,img);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+
diff --git a/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h
new file mode 100644
index 000000000..69aa59256
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h
@@ -0,0 +1,636 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_
+#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include "min_cut_abstract.h"
+#include "../graph_utils.h"
+#include "../array2d/array2d_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class potts_problem
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a boolean valued factor graph or graphical model
+ that can be efficiently operated on using graph cuts. In particular, this
+ object defines the interface a MAP problem on a factor graph must
+ implement if it is to be solved using the find_max_factor_graph_potts()
+ routine defined at the bottom of this file.
+
+ Note that there is no dlib::potts_problem object. What you are looking
+ at here is simply the interface definition for a Potts problem. You must
+ implement your own version of this object for the problem you wish to
+ solve and then pass it to the find_max_factor_graph_potts() routine.
+
+ Note also that a factor graph should not have any nodes which are
+ neighbors with themselves. Additionally, the graph is undirected. This
+ mean that if A is a neighbor of B then B must be a neighbor of A for
+ the MAP problem to be valid.
+ !*/
+
+ public:
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the factor graph. Or in other words,
+ returns the number of variables in the MAP problem/Potts model.
+ !*/
+
+ unsigned long number_of_neighbors (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns the number of neighbors of node idx.
+ !*/
+
+ // This is an optional variable which specifies a number that is always
+ // greater than or equal to number_of_neighbors(idx). If you don't know
+ // the value at compile time then either don't include max_number_of_neighbors
+ // in your potts_problem object or set it to 0.
+ const static unsigned long max_number_of_neighbors = 0;
+
+ unsigned long get_neighbor (
+ unsigned long idx,
+ unsigned long n
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ - n < number_of_neighbors(idx)
+ ensures
+ - returns the node index value of the n-th neighbor of
+ the node with index value idx.
+ - The neighbor relationship is reciprocal. That is, if
+ get_neighbor(A,i)==B then there is a value of j such
+ that get_neighbor(B,j)==A.
+ - A node is never its own neighbor. That is, there is
+ no i such that get_neighbor(idx,i)==idx.
+ !*/
+
+ unsigned long get_neighbor_idx (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const;
+ /*!
+ requires
+ - idx1 < number_of_nodes()
+ - idx2 < number_of_nodes()
+ ensures
+ - This function is basically the inverse of get_neighbor().
+ - returns a number IDX such that:
+ - get_neighbor(idx1,IDX) == idx2
+ - IDX < number_of_neighbors(idx1)
+ !*/
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ );
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - #get_label(idx) == value
+ !*/
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns the current label for the idx-th node. This is a value which is
+ 0 if the node's label is false and is any other value if it is true.
+
+ Note that this value is not used by factor_value() or factor_value_disagreement().
+ It is simply here to provide a mechanism for find_max_factor_graph_potts()
+ to return its labeled result. Additionally, the reason it returns a
+ node_label rather than a bool is because doing it this way facilitates
+ use of a graph cut algorithm for the solution of the MAP problem. For
+ more of an explanation you should read the paper referenced by the min_cut
+ object.
+ !*/
+
+ // This typedef should be for a type like int or double. It
+ // must also be capable of representing signed values.
+ typedef an_integer_or_real_type value_type;
+
+ value_type factor_value (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns a value which indicates how "good" it is to assign the idx-th
+ node the label of true. The larger the value, the more desirable it is
+ to give it this label. Similarly, a negative value indicates that it is
+ better to give the node a label of false.
+ - It is valid for the returned value to be positive or negative infinity.
+ A value of positive infinity indicates that the idx-th node must be labeled
+ true while negative infinity means it must be labeled false.
+ !*/
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const;
+ /*!
+ requires
+ - idx1 < number_of_nodes()
+ - idx2 < number_of_nodes()
+ - idx1 != idx2
+ - the idx1-th node and idx2-th node are neighbors in the graph. That is,
+ get_neighbor(idx1,i)==idx2 for some value of i.
+ ensures
+ - returns a number >= 0. This is the penalty for giving node idx1 and idx2
+ different labels. Larger values indicate a larger penalty.
+ - this function is symmetric. That is, it is true that:
+ factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
+ - It is valid for the returned value to be positive infinity. Returning
+ infinity indicates that the idx1-th and idx2-th nodes must share the same
+ label.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class potts_grid_problem
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a specialization of a potts_problem to the case where
+ the graph is a regular grid where each node is connected to its four
+ neighbors. An example of this is an image where each pixel is a node
+ and is connected to its four immediate neighboring pixels. Therefore,
+ this object defines the interface this special kind of MAP problem
+ must implement if it is to be solved by the find_max_factor_graph_potts(potts_grid_problem,array2d)
+ routine defined at the end of this file.
+
+
+ Note that all nodes always have four neighbors, even nodes on the edge
+ of the graph. This is because these border nodes are connected to
+ the border nodes on the other side of the graph. That is, the graph
+ "wraps" around at the borders.
+ !*/
+
+ public:
+
+ // This typedef should be for a type like int or double. It
+ // must also be capable of representing signed values.
+ typedef an_integer_or_real_type value_type;
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the grid
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in the grid
+ !*/
+
+ value_type factor_value (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < nr()*nc()
+ ensures
+ - The grid is represented in row-major-order format. Therefore, idx
+ identifies a node according to its position in the row-major-order
+ representation of the grid graph. Or in other words, idx corresponds
+ to the following row and column location in the graph:
+ - row == idx/nc()
+ - col == idx%nc()
+ - returns a value which indicates how "good" it is to assign the idx-th
+ node the label of true. The larger the value, the more desirable it is
+ to give it this label. Similarly, a negative value indicates that it is
+ better to give the node a label of false.
+ - It is valid for the returned value to be positive or negative infinity.
+ A value of positive infinity indicates that the idx-th node must be labeled
+ true while negative infinity means it must be labeled false.
+ !*/
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const;
+ /*!
+ requires
+ - idx1 < nr()*nc()
+ - idx2 < nr()*nc()
+ - idx1 != idx2
+ - the idx1-th node and idx2-th node are neighbors in the grid graph.
+ ensures
+ - The grid is represented in row-major-order format. Therefore, idx1 and
+ idx2 identify nodes according to their positions in the row-major-order
+ representation of the grid graph. For example, idx1 corresponds
+ to the following row and column location in the graph:
+ - row == idx1/nc()
+ - col == idx1%nc()
+ - returns a number >= 0. This is the penalty for giving node idx1 and idx2
+ different labels. Larger values indicate a larger penalty.
+ - this function is symmetric. That is, it is true that:
+ factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
+ - It is valid for the returned value to be positive infinity. Returning
+ infinity indicates that the idx1-th and idx2-th nodes must share the same
+ label.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ typename potts_problem::value_type potts_model_score (
+ const potts_problem& prob
+ );
+ /*!
+ requires
+ - potts_problem == an object with an interface compatible with the potts_problem
+ object defined at the top of this file.
+ - for all valid i and j:
+ - prob.factor_value_disagreement(i,j) >= 0
+ - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
+ ensures
+ - computes the model score for the given potts_problem. We define this
+ precisely below:
+ - let L(i) == the boolean label of the i-th variable in prob. Or in other
+ words, L(i) == (prob.get_label(i) != 0).
+ - let F == the sum of values of prob.factor_value(i) for only i values
+ where L(i) == true.
+ - Let D == the sum of values of prob.factor_value_disagreement(i,j)
+ for only i and j values which meet the following conditions:
+ - i and j are neighbors in the graph defined by prob, that is,
+ it is valid to call prob.factor_value_disagreement(i,j).
+ - L(i) != L(j)
+ - i < j
+ (i.e. We want to make sure to only count the edge between i and j once)
+
+ - Then this function returns F - D
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ typename graph_type::edge_type potts_model_score (
+ const graph_type& g,
+ const std::vector<node_label>& labels
+ );
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::edge_type is some signed type such as int or double
+ - graph_type::type must be the same type as graph_type::edge_type
+ - graph_contains_length_one_cycle(g) == false
+ - for all valid i and j:
+ - edge(g,i,j) >= 0
+ ensures
+ - This function does the same thing as the version of potts_model_score()
+ defined above, except that this version operates on a dlib::graph
+ instead of a potts_problem object.
+ - computes the model score for the given graph and labeling. We define this
+ precisely below:
+ - let L(i) == the boolean label of the i-th variable in g. Or in other
+ words, L(i) == (labels[i] != 0).
+ - let F == the sum of values of g.node(i).data for only i values
+ where L(i) == true.
+ - Let D == the sum of values of edge(g,i,j) for only i and j
+ values which meet the following conditions:
+ - i and j are neighbors in the graph defined by g, that is,
+ it is valid to call edge(g,i,j).
+ - L(i) != L(j)
+ - i < j
+ (i.e. We want to make sure to only count the edge between i and j once)
+
+ - Then this function returns F - D
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ typename potts_grid_problem::value_type potts_model_score (
+ const potts_grid_problem& prob,
+ const array2d<node_label,mem_manager>& labels
+ );
+ /*!
+ requires
+ - prob.nr() == labels.nr()
+ - prob.nc() == labels.nc()
+ - potts_grid_problem == an object with an interface compatible with the
+ potts_grid_problem object defined above.
+ - for all valid i and j:
+ - prob.factor_value_disagreement(i,j) >= 0
+ - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
+ ensures
+ - computes the model score for the given potts_grid_problem. We define this
+ precisely below:
+ - let L(i) == the boolean label of the i-th variable in prob. Or in other
+ words, L(i) == (labels[i/labels.nc()][i%labels.nc()] != 0).
+ - let F == the sum of values of prob.factor_value(i) for only i values
+ where L(i) == true.
+ - Let D == the sum of values of prob.factor_value_disagreement(i,j)
+ for only i and j values which meet the following conditions:
+ - i and j are neighbors in the graph defined by prob, that is,
+ it is valid to call prob.factor_value_disagreement(i,j).
+ - L(i) != L(j)
+ - i < j
+ (i.e. We want to make sure to only count the edge between i and j once)
+
+ - Then this function returns F - D
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ void find_max_factor_graph_potts (
+ potts_problem& prob
+ );
+ /*!
+ requires
+ - potts_problem == an object with an interface compatible with the potts_problem
+ object defined at the top of this file.
+ - for all valid i and j:
+ - prob.factor_value_disagreement(i,j) >= 0
+ - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
+ ensures
+ - This function is a tool for exactly solving the MAP problem in a Potts
+ model. In particular, this means that this function finds the assignments
+ to all the labels in prob which maximizes potts_model_score(#prob).
+ - The optimal labels are stored in #prob.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ void find_max_factor_graph_potts (
+ const graph_type& g,
+ std::vector<node_label>& labels
+ );
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::edge_type is some signed type such as int or double
+ - graph_type::type must be the same type as graph_type::edge_type
+ - graph_contains_length_one_cycle(g) == false
+ - for all valid i and j:
+ - edge(g,i,j) >= 0
+ ensures
+ - This routine simply converts g into a potts_problem and calls the
+ version of find_max_factor_graph_potts() defined above on it. Therefore,
+ this routine is just a convenience wrapper that lets you use a dlib::graph
+ to represent a potts problem. This means that this function maximizes
+ the value of potts_model_score(g, #labels).
+ - #labels.size() == g.number_of_nodes()
+ - for all valid i:
+ - #labels[i] == the optimal label for g.node(i)
+ - The correspondence between g and a potts_problem is the following:
+ - the factor_value() for a node is stored in g.node(i).data.
+ - the factor_value_disagreement(i,j) is stored in edge(g,i,j).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ void find_max_factor_graph_potts (
+ const potts_grid_problem& prob,
+ array2d<node_label,mem_manager>& labels
+ );
+ /*!
+ requires
+ - potts_grid_problem == an object with an interface compatible with the
+ potts_grid_problem object defined above.
+ - for all valid i and j:
+ - prob.factor_value_disagreement(i,j) >= 0
+ - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
+ ensures
+ - This routine solves a version of a potts problem where the graph is a
+ regular grid where each node is connected to its four immediate neighbors.
+ In particular, this means that this function finds the assignments
+ to all the labels in prob which maximizes potts_model_score(prob,#labels).
+ - The optimal labels are stored in #labels.
+ - #labels.nr() == prob.nr()
+ - #labels.nc() == prob.nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// The following functions and interface definitions are convenience routines for use
+// with the potts grid problem version of find_max_factor_graph_potts() defined above.
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct single_image_model
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines a slightly more convenient interface for creating
+ potts_grid_problems which operate on an image. In this case, the goal
+ is to assign a binary label to each pixel in an image. In particular,
+ this object defines the interface used by the make_potts_grid_problem()
+ routine defined below.
+
+ In the following comments, we will refer to the image supplied to
+ make_potts_grid_problem() as IMG.
+ !*/
+
+ // This typedef should be for a type like int or double. It
+ // must also be capable of representing signed values.
+ typedef an_integer_or_real_type value_type;
+
+ template <typename pixel_type>
+ value_type factor_value (
+ const pixel_type& v
+ ) const;
+ /*!
+ requires
+ - v is a pixel value from IMG.
+ ensures
+ - returns a value which indicates how "good" it is to assign the location
+ in IMG corresponding to v with the label of true. The larger the value,
+ the more desirable it is to give it this label. Similarly, a negative
+ value indicates that it is better to give the node a label of false.
+ - It is valid for the returned value to be positive or negative infinity.
+ A value of positive infinity indicates that the pixel must be labeled
+ true while negative infinity means it must be labeled false.
+ !*/
+
+ template <typename pixel_type>
+ value_type factor_value_disagreement (
+ const pixel_type& v1,
+ const pixel_type& v2
+ ) const;
+ /*!
+ requires
+ - v1 and v2 are pixel values from neighboring pixels in the IMG image.
+ ensures
+ - returns a number >= 0. This is the penalty for giving neighboring pixels
+ with values v1 and v2 different labels. Larger values indicate a larger
+ penalty.
+ - this function is symmetric. That is, it is true that:
+ factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
+ - It is valid for the returned value to be positive infinity. Returning
+ infinity indicates that the idx1-th and idx2-th nodes must share the same
+ label.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct pair_image_model
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines a slightly more convenient interface for creating
+ potts_grid_problems which operate on a pair of identically sized images.
+ In this case, the goal is to assign a label to each pixel in the first
+ image of the pair. In particular, this object defines the interface
+ used by the make_potts_grid_problem() routine defined below.
+
+ In the following comments, we will refer to the two images supplied to
+ make_potts_grid_problem() as IMG1 and IMG2. The goal of the potts
+ problem will be to assign labels to each pixel in IMG1 (IMG2 is
+ not labeled, it is simply used as a place to keep auxiliary data).
+ !*/
+
+ // This typedef should be for a type like int or double. It
+ // must also be capable of representing signed values.
+ typedef an_integer_or_real_type value_type;
+
+ template <typename pixel_type1, typename pixel_type2>
+ value_type factor_value (
+ const pixel_type1& v1,
+ const pixel_type2& v2
+ ) const;
+ /*!
+ requires
+ - v1 and v2 are corresponding pixels from IMG1 and IMG2 respectively.
+ That is, both pixel values have the same coordinates in the images.
+ So for example, if v1 is the value of IMG1[4][5] then v2 is the value
+ of IMG2[4][5].
+ ensures
+ - returns a value which indicates how "good" it is to assign the location
+ in IMG1 corresponding to v1 with the label of true. The larger the value,
+ the more desirable it is to give it this label. Similarly, a negative
+ value indicates that it is better to give the node a label of false.
+ - It is valid for the returned value to be positive or negative infinity.
+ A value of positive infinity indicates that the pixel must be labeled
+ true while negative infinity means it must be labeled false.
+ !*/
+
+ template <typename pixel_type>
+ value_type factor_value_disagreement (
+ const pixel_type& v1,
+ const pixel_type& v2
+ ) const;
+ /*!
+ requires
+ - v1 and v2 are pixel values from neighboring pixels in the IMG1 image.
+ ensures
+ - returns a number >= 0. This is the penalty for giving neighboring pixels
+ with values v1 and v2 different labels. Larger values indicate a larger
+ penalty.
+ - this function is symmetric. That is, it is true that:
+ factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
+ - It is valid for the returned value to be positive infinity. Returning
+ infinity indicates that the idx1-th and idx2-th nodes must share the same
+ label.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename single_image_model,
+ typename pixel_type,
+ typename mem_manager
+ >
+ potts_grid_problem make_potts_grid_problem (
+ const single_image_model& model,
+ const array2d<pixel_type,mem_manager>& img
+ );
+ /*!
+ requires
+ - single_image_model == an object with an interface compatible with the
+ single_image_model object defined above.
+ - for all valid i and j:
+ - model.factor_value_disagreement(i,j) >= 0
+ - model.factor_value_disagreement(i,j) == model.factor_value_disagreement(j,i)
+ ensures
+ - returns a potts_grid_problem which can be solved using the
+ find_max_factor_graph_potts(prob,array2d) routine defined above. That is,
+ given an image to store the labels, the following statement would solve the
+ potts problem defined by the given model and image:
+ find_max_factor_graph_potts(make_potts_grid_problem(model,img),labels);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pair_image_model,
+ typename pixel_type1,
+ typename pixel_type2,
+ typename mem_manager
+ >
+ potts_grid_problem make_potts_grid_problem (
+ const pair_image_model& model,
+ const array2d<pixel_type1,mem_manager>& img1,
+ const array2d<pixel_type2,mem_manager>& img2
+ );
+ /*!
+ requires
+ - get_rect(img1) == get_rect(img2)
+ (i.e. img1 and img2 have the same dimensions)
+ - pair_image_model == an object with an interface compatible with the
+ pair_image_model object defined above.
+ - for all valid i and j:
+ - model.factor_value_disagreement(i,j) >= 0
+ - model.factor_value_disagreement(i,j) == model.factor_value_disagreement(j,i)
+ ensures
+ - returns a potts_grid_problem which can be solved using the
+ find_max_factor_graph_potts(prob,array2d) routine defined above. That is,
+ given an image to store the labels, the following statement would solve the
+ potts problem defined by the given model and pair of images:
+ find_max_factor_graph_potts(make_potts_grid_problem(model,img1,img2),labels);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_cuts/general_flow_graph.h b/ml/dlib/dlib/graph_cuts/general_flow_graph.h
new file mode 100644
index 000000000..d0b93e311
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/general_flow_graph.h
@@ -0,0 +1,172 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GENERAL_FLOW_GRaPH_Hh_
+#define DLIB_GENERAL_FLOW_GRaPH_Hh_
+
+#include "../graph_utils.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <
+ typename directed_graph_type
+ >
+ class general_flow_graph
+ {
+ /*!
+ this is a utility class used by dlib::min_cut to convert a directed_graph
+ into the kind of flow graph expected by the min_cut object's main block
+ of code.
+ !*/
+
+ directed_graph_type& g;
+
+ typedef typename directed_graph_type::node_type node_type;
+ typedef typename directed_graph_type::type node_label;
+
+ public:
+
+ general_flow_graph(
+ directed_graph_type& g_
+ ) : g(g_)
+ {
+ }
+
+ class out_edge_iterator
+ {
+ friend class general_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ out_edge_iterator(
+ ):idx(0),cnt(0) {}
+
+ out_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const out_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ out_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ class in_edge_iterator
+ {
+
+ friend class general_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ in_edge_iterator(
+ ):idx(0),cnt(0) {}
+
+ in_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const in_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ in_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ unsigned long number_of_nodes (
+ ) const { return g.number_of_nodes(); }
+
+ out_edge_iterator out_begin(
+ const unsigned long& it
+ ) const { return out_edge_iterator(it, 0); }
+
+ in_edge_iterator in_begin(
+ const unsigned long& it
+ ) const { return in_edge_iterator(it, 0); }
+
+ out_edge_iterator out_end(
+ const unsigned long& it
+ ) const { return out_edge_iterator(it, g.node(it).number_of_children()); }
+
+ in_edge_iterator in_end(
+ const unsigned long& it
+ ) const { return in_edge_iterator(it, g.node(it).number_of_parents()); }
+
+ unsigned long node_id (
+ const out_edge_iterator& it
+ ) const { return g.node(it.idx).child(it.cnt).index(); }
+ unsigned long node_id (
+ const in_edge_iterator& it
+ ) const { return g.node(it.idx).parent(it.cnt).index(); }
+
+ typedef typename directed_graph_type::edge_type edge_type;
+
+ edge_type get_flow (const unsigned long& it1, const unsigned long& it2) const
+ {
+ return edge(g, it1, it2);
+ }
+ edge_type get_flow (const out_edge_iterator& it) const
+ {
+ return g.node(it.idx).child_edge(it.cnt);
+ }
+ edge_type get_flow (const in_edge_iterator& it) const
+ {
+ return g.node(it.idx).parent_edge(it.cnt);
+ }
+
+ void adjust_flow (
+ const unsigned long& it1,
+ const unsigned long& it2,
+ const edge_type& value
+ )
+ {
+ edge(g, it1, it2) += value;
+ edge(g, it2, it1) -= value;
+ }
+
+ void set_label (
+ const unsigned long& it,
+ node_label value
+ )
+ {
+ g.node(it).data = value;
+ }
+
+ node_label get_label (
+ const unsigned long& it
+ ) const
+ {
+ return g.node(it).data;
+ }
+
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GENERAL_FLOW_GRaPH_Hh_
+
diff --git a/ml/dlib/dlib/graph_cuts/general_potts_problem.h b/ml/dlib/dlib/graph_cuts/general_potts_problem.h
new file mode 100644
index 000000000..ebedbc572
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/general_potts_problem.h
@@ -0,0 +1,99 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GENERAL_POTTS_PRoBLEM_Hh_
+#define DLIB_GENERAL_POTTS_PRoBLEM_Hh_
+
+#include "../graph_utils.h"
+#include "min_cut.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename graph_type
+ >
+ class general_potts_problem
+ {
+
+ const graph_type& g;
+ std::vector<node_label>& labels;
+ public:
+ general_potts_problem (
+ const graph_type& g_,
+ std::vector<node_label>& labels_
+ ) : g(g_), labels(labels_)
+ {
+ labels.resize(g.number_of_nodes());
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return g.number_of_nodes(); }
+
+ unsigned long number_of_neighbors (
+ unsigned long idx
+ ) const { return g.node(idx).number_of_neighbors(); }
+
+ unsigned long get_neighbor (
+ unsigned long idx,
+ unsigned long n
+ ) const { return g.node(idx).neighbor(n).index(); }
+
+ unsigned long get_neighbor_idx (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ for (unsigned long i = 0; i < g.node(idx1).number_of_neighbors(); ++i)
+ {
+ if (g.node(idx1).neighbor(i).index() == idx2)
+ return i;
+ }
+
+ // This should never ever execute
+ return 0;
+ }
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ )
+ {
+ labels[idx] = value;
+ }
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const { return labels[idx]; }
+
+ typedef typename graph_type::edge_type value_type;
+
+ value_type factor_value (
+ unsigned long idx
+ ) const
+ {
+ return g.node(idx).data;
+ }
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ return edge(g, idx1, idx2);
+ }
+
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GENERAL_POTTS_PRoBLEM_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_cuts/graph_labeler.h b/ml/dlib/dlib/graph_cuts/graph_labeler.h
new file mode 100644
index 000000000..34c3fcb5e
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/graph_labeler.h
@@ -0,0 +1,211 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_LaBELER_Hh_
+#define DLIB_GRAPH_LaBELER_Hh_
+
+#include "graph_labeler_abstract.h"
+#include "../matrix.h"
+#include "../string.h"
+#include <vector>
+#include "find_max_factor_graph_potts.h"
+#include "../svm/sparse_vector.h"
+#include "../graph.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ class graph_labeler
+ {
+
+ public:
+
+ typedef std::vector<bool> label_type;
+ typedef label_type result_type;
+
+ graph_labeler()
+ {
+ }
+
+ graph_labeler(
+ const vector_type& edge_weights_,
+ const vector_type& node_weights_
+ ) :
+ edge_weights(edge_weights_),
+ node_weights(node_weights_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0,
+ "\t graph_labeler::graph_labeler()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t min(edge_weights): " << min(edge_weights)
+ << "\n\t this: " << this
+ );
+ }
+
+ const vector_type& get_edge_weights (
+ ) const { return edge_weights; }
+
+ const vector_type& get_node_weights (
+ ) const { return node_weights; }
+
+ template <typename graph_type>
+ void operator() (
+ const graph_type& sample,
+ std::vector<bool>& labels
+ ) const
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false,
+ "\t void graph_labeler::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t get_edge_weights().size(): " << get_edge_weights().size()
+ << "\n\t get_node_weights().size(): " << get_node_weights().size()
+ << "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample)
+ << "\n\t this: " << this
+ );
+ for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
+ {
+ if (is_matrix<vector_type>::value &&
+ is_matrix<typename graph_type::type>::value)
+ {
+ // check that dot() is legal.
+ DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(),
+ "\t void graph_labeler::operator()"
+ << "\n\t The size of the node weight vector must match the one in the node."
+ << "\n\t get_node_weights().size(): " << get_node_weights().size()
+ << "\n\t sample.node(i).data.size(): " << sample.node(i).data.size()
+ << "\n\t i: " << i
+ << "\n\t this: " << this
+ );
+ }
+
+ for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
+ {
+ if (is_matrix<vector_type>::value &&
+ is_matrix<typename graph_type::edge_type>::value)
+ {
+ // check that dot() is legal.
+ DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(),
+ "\t void graph_labeler::operator()"
+ << "\n\t The size of the edge weight vector must match the one in graph's edge."
+ << "\n\t get_edge_weights().size(): " << get_edge_weights().size()
+ << "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size()
+ << "\n\t i: " << i
+ << "\n\t this: " << this
+ );
+ }
+
+ DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0,
+ "\t void graph_labeler::operator()"
+ << "\n\t No edge vectors are allowed to have negative elements."
+ << "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n))
+ << "\n\t i: " << i
+ << "\n\t n: " << n
+ << "\n\t this: " << this
+ );
+ }
+ }
+#endif
+
+
+ graph<double,double>::kernel_1a g;
+ copy_graph_structure(sample, g);
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ g.node(i).data = dot(node_weights, sample.node(i).data);
+
+ for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long j = g.node(i).neighbor(n).index();
+ // Don't compute an edge weight more than once.
+ if (i < j)
+ {
+ g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n));
+ }
+ }
+
+ }
+
+ labels.clear();
+ std::vector<node_label> temp;
+ find_max_factor_graph_potts(g, temp);
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ if (temp[i] != 0)
+ labels.push_back(true);
+ else
+ labels.push_back(false);
+ }
+ }
+
+ template <typename graph_type>
+ std::vector<bool> operator() (
+ const graph_type& sample
+ ) const
+ {
+ std::vector<bool> temp;
+ (*this)(sample, temp);
+ return temp;
+ }
+
+ private:
+
+ vector_type edge_weights;
+ vector_type node_weights;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void serialize (
+ const graph_labeler<vector_type>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.get_edge_weights(), out);
+ serialize(item.get_node_weights(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void deserialize (
+ graph_labeler<vector_type>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ {
+ throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " +
+ cast_to_string(version) + ".");
+ }
+
+ vector_type edge_weights, node_weights;
+ deserialize(edge_weights, in);
+ deserialize(node_weights, in);
+
+ item = graph_labeler<vector_type>(edge_weights, node_weights);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_LaBELER_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h
new file mode 100644
index 000000000..a0821b696
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h
@@ -0,0 +1,185 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+#ifdef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+
+#include "find_max_factor_graph_potts_abstract.h"
+#include "../graph/graph_kernel_abstract.h"
+#include "../matrix/matrix_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ class graph_labeler
+ {
+ /*!
+ REQUIREMENTS ON vector_type
+ - vector_type is a dlib::matrix capable of representing column
+ vectors or it is a sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for labeling each node in a graph with a value
+ of true or false, subject to a labeling consistency constraint between
+ nodes that share an edge. In particular, this object is useful for
+ representing a graph labeling model learned via some machine learning
+ method.
+
+ To elaborate, suppose we have a graph we want to label. Moreover,
+ suppose we can assign a score to each node which represents how much
+ we want to label the node as true, and we also have scores for each
+ edge which represent how much we wanted the nodes sharing the edge to
+ have the same label. If we could do this then we could find the optimal
+ labeling using the find_max_factor_graph_potts() routine. Therefore,
+ the graph_labeler is just an object which contains the necessary data
+ to compute these score functions and then call find_max_factor_graph_potts().
+ Additionally, this object uses linear functions to represent these score
+ functions.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads. This is because the const members are purely read-only
+ operations. However, any operation that modifies a graph_labeler is
+ not threadsafe.
+ !*/
+
+ public:
+
+ typedef std::vector<bool> label_type;
+ typedef label_type result_type;
+
+ graph_labeler(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_node_weights() == an initial value of type vector_type.
+ - #get_edge_weights() == an initial value of type vector_type.
+ !*/
+
+ graph_labeler(
+ const vector_type& edge_weights,
+ const vector_type& node_weights
+ );
+ /*!
+ requires
+ - min(edge_weights) >= 0
+ ensures
+ - #get_edge_weights() == edge_weights
+ - #get_node_weights() == node_weights
+ !*/
+
+ const vector_type& get_edge_weights (
+ ) const;
+ /*!
+ ensures
+ - Recall that the score function for an edge is a linear function of
+ the vector stored at that edge. This means there is some vector, E,
+ which we dot product with the vector in the graph to compute the
+ score. Therefore, this function returns that E vector which defines
+ the edge score function.
+ !*/
+
+ const vector_type& get_node_weights (
+ ) const;
+ /*!
+ ensures
+ - Recall that the score function for a node is a linear function of
+ the vector stored in that node. This means there is some vector, W,
+ which we dot product with the vector in the graph to compute the score.
+ Therefore, this function returns that W vector which defines the node
+ score function.
+ !*/
+
+ template <typename graph_type>
+ void operator() (
+ const graph_type& sample,
+ std::vector<bool>& labels
+ ) const;
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::type and graph_type::edge_type must be either matrix objects
+ capable of representing column vectors or some kind of sparse vector
+ type as defined in dlib/svm/sparse_vector_abstract.h.
+ - graph_contains_length_one_cycle(sample) == false
+ - for all valid i and j:
+ - min(edge(sample,i,j)) >= 0
+ - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
+ - it must be legal to call dot(sample.node(i).data, get_node_weights())
+ ensures
+ - Computes a labeling for each node in the given graph and stores the result
+ in #labels.
+ - #labels.size() == sample.number_of_nodes()
+ - for all valid i:
+ - #labels[i] == the label of the node sample.node(i).
+ - The labels are computed by creating a graph, G, with scalar values on each node
+ and edge. The scalar values are calculated according to the following:
+ - for all valid i:
+ - G.node(i).data == dot(get_node_weights(), sample.node(i).data)
+ - for all valid i and j:
+ - edge(G,i,j) == dot(get_edge_weights(), edge(sample,i,j))
+ Then the labels are computed by calling find_max_factor_graph_potts(G,#labels).
+ !*/
+
+ template <typename graph_type>
+ std::vector<bool> operator() (
+ const graph_type& sample
+ ) const;
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_contains_length_one_cycle(sample) == false
+ - for all valid i and j:
+ - min(edge(sample,i,j)) >= 0
+ - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
+ - it must be legal to call dot(sample.node(i).data, get_node_weights())
+ ensures
+ - Performs (*this)(sample, labels); return labels;
+ (i.e. This is just another version of the above operator() routine
+ but instead of returning the labels via the second argument, it
+ returns them as the normal return value).
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void serialize (
+ const graph_labeler<vector_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void deserialize (
+ graph_labeler<vector_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/graph_cuts/min_cut.h b/ml/dlib/dlib/graph_cuts/min_cut.h
new file mode 100644
index 000000000..6bbb57608
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/min_cut.h
@@ -0,0 +1,571 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MIN_CuT_Hh_
+#define DLIB_MIN_CuT_Hh_
+
+#include "min_cut_abstract.h"
+#include "../matrix.h"
+#include "general_flow_graph.h"
+#include "../is_kind.h"
+
+#include <iostream>
+#include <fstream>
+#include <deque>
+
+
+// ----------------------------------------------------------------------------------------
+
+
+namespace dlib
+{
+
+ typedef unsigned char node_label;
+
+// ----------------------------------------------------------------------------------------
+
+ const node_label SOURCE_CUT = 0;
+ const node_label SINK_CUT = 254;
+ const node_label FREE_NODE = 255;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename flow_graph>
+ typename disable_if<is_directed_graph<flow_graph>,typename flow_graph::edge_type>::type
+ graph_cut_score (
+ const flow_graph& g
+ )
+ {
+ typedef typename flow_graph::edge_type edge_weight_type;
+ edge_weight_type score = 0;
+ typedef typename flow_graph::out_edge_iterator out_edge_iterator;
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ if (g.get_label(i) != SOURCE_CUT)
+ continue;
+
+ for (out_edge_iterator n = g.out_begin(i); n != g.out_end(i); ++n)
+ {
+ if (g.get_label(g.node_id(n)) != SOURCE_CUT)
+ {
+ score += g.get_flow(n);
+ }
+ }
+ }
+
+ return score;
+ }
+
+ template <typename directed_graph>
+ typename enable_if<is_directed_graph<directed_graph>,typename directed_graph::edge_type>::type
+ graph_cut_score (
+ const directed_graph& g
+ )
+ {
+ return graph_cut_score(dlib::impl::general_flow_graph<const directed_graph>(g));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class min_cut
+ {
+
+ public:
+
+ min_cut()
+ {
+ }
+
+ min_cut( const min_cut& )
+ {
+ // Intentionally left empty since all the member variables
+ // don't logically contribute to the state of this object.
+ // This copy constructor is here to explicitly avoid the overhead
+ // of copying these transient variables.
+ }
+
+ template <
+ typename directed_graph
+ >
+ typename enable_if<is_directed_graph<directed_graph> >::type operator() (
+ directed_graph& g,
+ const unsigned long source_node,
+ const unsigned long sink_node
+ ) const
+ {
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\t void min_cut::operator()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+ DLIB_ASSERT(graph_has_symmetric_edges(g) == true,
+ "\t void min_cut::operator()"
+ << "\n\t Invalid arguments were given to this function."
+ );
+
+ dlib::impl::general_flow_graph<directed_graph> temp(g);
+ (*this)(temp, source_node, sink_node);
+ }
+
+ template <
+ typename flow_graph
+ >
+ typename disable_if<is_directed_graph<flow_graph> >::type operator() (
+ flow_graph& g,
+ const unsigned long source_node,
+ const unsigned long sink_node
+ ) const
+ {
+#ifdef ENABLE_ASSERTS
+ DLIB_ASSERT(source_node != sink_node &&
+ source_node < g.number_of_nodes() &&
+ sink_node < g.number_of_nodes(),
+ "\t void min_cut::operator()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t g.number_of_nodes(): " << g.number_of_nodes()
+ << "\n\t source_node: " << source_node
+ << "\n\t sink_node: " << sink_node
+ << "\n\t this: " << this
+ );
+
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ typename flow_graph::out_edge_iterator j, end = g.out_end(i);
+ for (j = g.out_begin(i); j != end; ++j)
+ {
+ const unsigned long jj = g.node_id(j);
+ DLIB_ASSERT(g.get_flow(i,jj) >= 0,
+ "\t void min_cut::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: "<< i
+ << "\n\t jj: "<< jj
+ << "\n\t g.get_flow(i,jj): "<< g.get_flow(i,jj)
+ << "\n\t this: "<< this
+ );
+
+ }
+ }
+#endif
+ parent.clear();
+ active.clear();
+ orphans.clear();
+
+ typedef typename flow_graph::edge_type edge_type;
+ COMPILE_TIME_ASSERT(is_signed_type<edge_type>::value);
+
+ typedef typename flow_graph::out_edge_iterator out_edge_iterator;
+ typedef typename flow_graph::in_edge_iterator in_edge_iterator;
+
+ // initialize labels
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ g.set_label(i, FREE_NODE);
+
+ g.set_label(source_node, SOURCE_CUT);
+ g.set_label(sink_node, SINK_CUT);
+
+ // used to indicate "no parent"
+ const unsigned long no_parent = g.number_of_nodes();
+
+ parent.assign(g.number_of_nodes(), no_parent);
+
+ time = 1;
+ dist.assign(g.number_of_nodes(), 0);
+ ts.assign(g.number_of_nodes(), time);
+
+ active.push_back(source_node);
+ active.push_back(sink_node);
+
+
+ in_edge_iterator in_begin = g.in_begin(active[0]);
+ out_edge_iterator out_begin = g.out_begin(active[0]);
+
+ unsigned long source_side, sink_side;
+ while (grow(g,source_side,sink_side, in_begin, out_begin))
+ {
+ ++time;
+ ts[source_node] = time;
+ ts[sink_node] = time;
+
+ augment(g, source_node, sink_node, source_side, sink_side);
+ adopt(g, source_node, sink_node);
+ }
+
+ }
+
+
+ private:
+
+ unsigned long distance_to_origin (
+ const unsigned long no_parent,
+ unsigned long p,
+ unsigned long
+ ) const
+ {
+ unsigned long start = p;
+ unsigned long count = 0;
+ while (p != no_parent)
+ {
+ if (ts[p] == time)
+ {
+ count += dist[p];
+
+ unsigned long count_down = count;
+ // adjust the dist and ts for the nodes on this path.
+ while (start != p)
+ {
+ ts[start] = time;
+ dist[start] = count_down;
+ --count_down;
+ start = parent[start];
+ }
+
+ return count;
+ }
+ p = parent[p];
+ ++count;
+ }
+
+ return std::numeric_limits<unsigned long>::max();
+ }
+
+ template <typename flow_graph>
+ void adopt (
+ flow_graph& g,
+ const unsigned long source,
+ const unsigned long sink
+ ) const
+ {
+ typedef typename flow_graph::out_edge_iterator out_edge_iterator;
+ typedef typename flow_graph::in_edge_iterator in_edge_iterator;
+
+ // used to indicate "no parent"
+ const unsigned long no_parent = g.number_of_nodes();
+
+ while (orphans.size() > 0)
+ {
+ const unsigned long p = orphans.back();
+ orphans.pop_back();
+
+ const unsigned char label_p = g.get_label(p);
+
+ // Try to find a valid parent for p.
+ if (label_p == SOURCE_CUT)
+ {
+ const in_edge_iterator begin(g.in_begin(p));
+ const in_edge_iterator end(g.in_end(p));
+ unsigned long best_dist = std::numeric_limits<unsigned long>::max();
+ unsigned long best_node = 0;
+ for(in_edge_iterator q = begin; q != end; ++q)
+ {
+ const unsigned long id = g.node_id(q);
+
+ if (g.get_label(id) != label_p || g.get_flow(q) <= 0 )
+ continue;
+
+ unsigned long temp = distance_to_origin(no_parent, id,source);
+ if (temp < best_dist)
+ {
+ best_dist = temp;
+ best_node = id;
+ }
+
+ }
+ if (best_dist != std::numeric_limits<unsigned long>::max())
+ {
+ parent[p] = best_node;
+ dist[p] = dist[best_node] + 1;
+ ts[p] = time;
+ }
+
+ // if we didn't find a parent for p
+ if (parent[p] == no_parent)
+ {
+ for(in_edge_iterator q = begin; q != end; ++q)
+ {
+ const unsigned long id = g.node_id(q);
+
+ if (g.get_label(id) != SOURCE_CUT)
+ continue;
+
+ if (g.get_flow(q) > 0)
+ active.push_back(id);
+
+ if (parent[id] == p)
+ {
+ parent[id] = no_parent;
+ orphans.push_back(id);
+ }
+ }
+ g.set_label(p, FREE_NODE);
+ }
+ }
+ else
+ {
+ unsigned long best_node = 0;
+ unsigned long best_dist = std::numeric_limits<unsigned long>::max();
+ const out_edge_iterator begin(g.out_begin(p));
+ const out_edge_iterator end(g.out_end(p));
+ for(out_edge_iterator q = begin; q != end; ++q)
+ {
+ const unsigned long id = g.node_id(q);
+ if (g.get_label(id) != label_p || g.get_flow(q) <= 0)
+ continue;
+
+ unsigned long temp = distance_to_origin(no_parent, id,sink);
+
+ if (temp < best_dist)
+ {
+ best_dist = temp;
+ best_node = id;
+ }
+ }
+
+ if (best_dist != std::numeric_limits<unsigned long>::max())
+ {
+ parent[p] = best_node;
+ dist[p] = dist[best_node] + 1;
+ ts[p] = time;
+ }
+
+ // if we didn't find a parent for p
+ if (parent[p] == no_parent)
+ {
+ for(out_edge_iterator q = begin; q != end; ++q)
+ {
+ const unsigned long id = g.node_id(q);
+
+ if (g.get_label(id) != SINK_CUT)
+ continue;
+
+ if (g.get_flow(q) > 0)
+ active.push_back(id);
+
+ if (parent[id] == p)
+ {
+ parent[id] = no_parent;
+ orphans.push_back(id);
+ }
+ }
+
+ g.set_label(p, FREE_NODE);
+ }
+ }
+
+
+ }
+
+ }
+
+ template <typename flow_graph>
+ void augment (
+ flow_graph& g,
+ const unsigned long& source,
+ const unsigned long& sink,
+ const unsigned long& source_side,
+ const unsigned long& sink_side
+ ) const
+ {
+ typedef typename flow_graph::edge_type edge_type;
+
+ // used to indicate "no parent"
+ const unsigned long no_parent = g.number_of_nodes();
+
+ unsigned long s = source_side;
+ unsigned long t = sink_side;
+ edge_type min_cap = g.get_flow(s,t);
+
+ // find the bottleneck capacity on the current path.
+
+ // check from source_side back to the source for the min capacity link.
+ t = s;
+ while (t != source)
+ {
+ s = parent[t];
+ const edge_type temp = g.get_flow(s, t);
+ if (temp < min_cap)
+ {
+ min_cap = temp;
+ }
+ t = s;
+ }
+
+ // check from sink_side back to the sink for the min capacity link
+ s = sink_side;
+ while (s != sink)
+ {
+ t = parent[s];
+ const edge_type temp = g.get_flow(s, t);
+ if (temp < min_cap)
+ {
+ min_cap = temp;
+ }
+ s = t;
+ }
+
+
+ // now push the max possible amount of flow though the path
+ s = source_side;
+ t = sink_side;
+ g.adjust_flow(t,s, min_cap);
+
+ // trace back towards the source
+ t = s;
+ while (t != source)
+ {
+ s = parent[t];
+ g.adjust_flow(t,s, min_cap);
+ if (g.get_flow(s,t) <= 0)
+ {
+ parent[t] = no_parent;
+ orphans.push_back(t);
+ }
+
+ t = s;
+ }
+
+ // trace back towards the sink
+ s = sink_side;
+ while (s != sink)
+ {
+ t = parent[s];
+ g.adjust_flow(t,s, min_cap);
+ if (g.get_flow(s,t) <= 0)
+ {
+ parent[s] = no_parent;
+ orphans.push_back(s);
+ }
+ s = t;
+ }
+ }
+
+ template <typename flow_graph>
+ bool grow (
+ flow_graph& g,
+ unsigned long& source_side,
+ unsigned long& sink_side,
+ typename flow_graph::in_edge_iterator& in_begin,
+ typename flow_graph::out_edge_iterator& out_begin
+ ) const
+ /*!
+ ensures
+ - if (an augmenting path was found) then
+ - returns true
+ - (#source_side, #sink_side) == the point where the two trees meet.
+ #source_side is part of the source tree and #sink_side is part of
+ the sink tree.
+ - else
+ - returns false
+ !*/
+ {
+ typedef typename flow_graph::out_edge_iterator out_edge_iterator;
+ typedef typename flow_graph::in_edge_iterator in_edge_iterator;
+
+
+ while (active.size() != 0)
+ {
+ // pick an active node
+ const unsigned long A = active[0];
+
+ const unsigned char label_A = g.get_label(A);
+
+ // process its neighbors
+ if (label_A == SOURCE_CUT)
+ {
+ const out_edge_iterator out_end = g.out_end(A);
+ for(out_edge_iterator& i = out_begin; i != out_end; ++i)
+ {
+ if (g.get_flow(i) > 0)
+ {
+ const unsigned long id = g.node_id(i);
+ const unsigned char label_i = g.get_label(id);
+ if (label_i == FREE_NODE)
+ {
+ active.push_back(id);
+ g.set_label(id,SOURCE_CUT);
+ parent[id] = A;
+ ts[id] = ts[A];
+ dist[id] = dist[A] + 1;
+ }
+ else if (label_A != label_i)
+ {
+ source_side = A;
+ sink_side = id;
+ return true;
+ }
+ else if (is_closer(A, id))
+ {
+ parent[id] = A;
+ ts[id] = ts[A];
+ dist[id] = dist[A] + 1;
+ }
+ }
+ }
+ }
+ else if (label_A == SINK_CUT)
+ {
+ const in_edge_iterator in_end = g.in_end(A);
+ for(in_edge_iterator& i = in_begin; i != in_end; ++i)
+ {
+ if (g.get_flow(i) > 0)
+ {
+ const unsigned long id = g.node_id(i);
+ const unsigned char label_i = g.get_label(id);
+ if (label_i == FREE_NODE)
+ {
+ active.push_back(id);
+ g.set_label(id,SINK_CUT);
+ parent[id] = A;
+ ts[id] = ts[A];
+ dist[id] = dist[A] + 1;
+ }
+ else if (label_A != label_i)
+ {
+ sink_side = A;
+ source_side = id;
+ return true;
+ }
+ else if (is_closer(A, id))
+ {
+ parent[id] = A;
+ ts[id] = ts[A];
+ dist[id] = dist[A] + 1;
+ }
+ }
+ }
+ }
+
+ active.pop_front();
+ if (active.size() != 0)
+ {
+ in_begin = g.in_begin(active[0]);
+ out_begin = g.out_begin(active[0]);
+ }
+ }
+
+ return false;
+ }
+
+ inline bool is_closer (
+ unsigned long p,
+ unsigned long q
+ ) const
+ {
+ // return true if p is closer to a terminal than q
+ return ts[q] <= ts[p] && dist[q] > dist[p];
+ }
+
+ mutable std::vector<uint32> dist;
+ mutable std::vector<uint32> ts;
+ mutable uint32 time;
+ mutable std::vector<unsigned long> parent;
+
+ mutable std::deque<unsigned long> active;
+ mutable std::vector<unsigned long> orphans;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_MIN_CuT_Hh_
+
diff --git a/ml/dlib/dlib/graph_cuts/min_cut_abstract.h b/ml/dlib/dlib/graph_cuts/min_cut_abstract.h
new file mode 100644
index 000000000..748aca950
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/min_cut_abstract.h
@@ -0,0 +1,476 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MIN_CuT_ABSTRACT_Hh_
+#ifdef DLIB_MIN_CuT_ABSTRACT_Hh_
+
+#include "../graph_utils.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ /*!A node_label
+ The node_label type is the type used to label which part of a graph cut
+ a node is on. It is used by all the graph cut tools. The three possible
+ values of a node label are SOURCE_CUT, SINK_CUT, or FREE_NODE.
+ !*/
+
+ typedef unsigned char node_label;
+ const node_label SOURCE_CUT = 0;
+ const node_label SINK_CUT = 254;
+ const node_label FREE_NODE = 255;
+
+// ----------------------------------------------------------------------------------------
+
+ class flow_graph
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a flow capacity graph for use with the
+ min_cut algorithm defined below. In particular, this object
+ is a kind of directed graph where the edge weights specify the
+ flow capacities.
+
+ Note that there is no dlib::flow_graph object. What you are
+ looking at here is simply the interface definition for a graph
+ which can be used with the min_cut algorithm. You must implement
+ your own version of this object for the graph you wish to work with
+ and then pass it to the min_cut::operator() routine.
+
+ It's also worth pointing out that this graph has symmetric edge
+ connections. That is, if there is an edge from node A to node B
+ then there must also be an edge from node B to node A.
+ !*/
+
+ public:
+
+ class out_edge_iterator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple forward iterator for iterating over the neighbors
+ of a node in the graph. It also represents the fact that the neighbors
+ are on the end of an outgoing edge. That is, the edge represents
+ the amount of flow which can flow towards the neighbor.
+ !*/
+
+ public:
+ out_edge_iterator(
+ );
+ /*!
+ ensures
+ - constructs an iterator in an undefined state. It can't
+ be used until assigned with a valid iterator.
+ !*/
+
+ out_edge_iterator(
+ const out_edge_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ !*/
+
+ out_edge_iterator& operator=(
+ const out_edge_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - returns #*this
+ !*/
+
+ bool operator!= (
+ const out_edge_iterator& item
+ ) const;
+ /*!
+ requires
+ - *this and item are iterators over the neighbors for the
+ same node.
+ ensures
+ - returns false if *this and item both reference the same
+ node in the graph and true otherwise.
+ !*/
+
+ out_edge_iterator& operator++(
+ );
+ /*!
+ ensures
+ - advances *this to the next neighbor node.
+ - returns a reference to the updated *this
+ (i.e. this is the ++object form of the increment operator)
+ !*/
+ };
+
+ class in_edge_iterator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple forward iterator for iterating over the neighbors
+ of a node in the graph. It also represents the fact that the neighbors
+ are on the end of an incoming edge. That is, the edge represents
+ the amount of flow which can flow out of the neighbor node.
+ !*/
+
+ public:
+
+ in_edge_iterator(
+ );
+ /*!
+ ensures
+ - constructs an iterator in an undefined state. It can't
+ be used until assigned with a valid iterator.
+ !*/
+
+ in_edge_iterator(
+ const in_edge_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ !*/
+
+ in_edge_iterator& operator=(
+ const in_edge_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - returns #*this
+ !*/
+
+ bool operator!= (
+ const in_edge_iterator& item
+ ) const;
+ /*!
+ requires
+ - *this and item are iterators over the neighbors for the
+ same node.
+ ensures
+ - returns false if *this and item both reference the same
+ node in the graph and true otherwise.
+ !*/
+
+ in_edge_iterator& operator++(
+ );
+ /*!
+ ensures
+ - advances *this to the next neighbor node.
+ - returns a reference to the updated *this
+ (i.e. this is the ++object form of the increment operator)
+ !*/
+ };
+
+
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the graph.
+ !*/
+
+ out_edge_iterator out_begin(
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns an iterator pointing to the first neighboring node of
+ the idx-th node. If no such node exists then returns out_end(idx).
+ - The returned iterator also represents the directed edge going from
+ node idx to the neighbor.
+ !*/
+
+ in_edge_iterator in_begin(
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns an iterator pointing to the first neighboring node of
+ the idx-th node. If no such node exists then returns in_end(idx).
+ - The returned iterator also represents the directed edge going from
+ the neighbor to node idx.
+ !*/
+
+ out_edge_iterator out_end(
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns an iterator to one past the last neighboring node of
+ the idx-th node.
+ !*/
+
+ in_edge_iterator in_end(
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns an iterator to one past the last neighboring node of
+ the idx-th node.
+ !*/
+
+
+ unsigned long node_id (
+ const out_edge_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [out_begin(idx), out_end(idx))
+ for some valid idx)
+ ensures
+ - returns a number IDX such that:
+ - 0 <= IDX < number_of_nodes()
+ - IDX == The index which uniquely identifies the node pointed to by the
+ iterator it. This number can be used with any member function in this
+ object which expect a node index. (e.g. get_label(IDX) == the label for the
+ node pointed to by it)
+ !*/
+
+ unsigned long node_id (
+ const in_edge_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [in_begin(idx), in_end(idx))
+ for some valid idx)
+ ensures
+ - returns a number IDX such that:
+ - 0 <= IDX < number_of_nodes()
+ - IDX == The index which uniquely identifies the node pointed to by the
+ iterator it. This number can be used with any member function in this
+ object which expect a node index. (e.g. get_label(IDX) == the label for the
+ node pointed to by it)
+ !*/
+
+ // This typedef should be for a type like int or double. It
+ // must also be capable of representing signed values.
+ typedef an_integer_or_real_type edge_type;
+
+ edge_type get_flow (
+ const unsigned long& idx1,
+ const unsigned long& idx2
+ ) const;
+ /*!
+ requires
+ - idx1 < number_of_nodes()
+ - idx2 < number_of_nodes()
+ - idx1 and idx2 are neighbors in the graph
+ ensures
+ - returns the residual flow capacity from the idx1-th node to the idx2-th node.
+ - It is valid for this function to return a floating point value of infinity.
+ This value means this edge has an unlimited capacity.
+ !*/
+
+ edge_type get_flow (
+ const out_edge_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [out_begin(idx), out_end(idx))
+ for some valid idx)
+ ensures
+ - let IDX = node_id(it)
+ - it represents the directed edge from a node, call it H, to the node IDX. Therefore,
+ this function returns get_flow(H,IDX)
+ - It is valid for this function to return a floating point value of infinity.
+ This value means this edge has an unlimited capacity.
+ !*/
+
+ edge_type get_flow (
+ const in_edge_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [in_begin(idx), in_end(idx))
+ for some valid idx)
+ ensures
+ - let IDX = node_id(it)
+ - it represents the directed edge from node IDX to another node, call it H. Therefore,
+ this function returns get_flow(IDX,H)
+ - It is valid for this function to return a floating point value of infinity.
+ This value means this edge has an unlimited capacity.
+ !*/
+
+ void adjust_flow (
+ const unsigned long& idx1,
+ const unsigned long& idx2,
+ const edge_type& value
+ );
+ /*!
+ requires
+ - idx1 < number_of_nodes()
+ - idx2 < number_of_nodes()
+ - idx1 and idx2 are neighbors in the graph
+ ensures
+ - #get_flow(idx1,idx2) == get_flow(idx1,idx2) + value
+ - #get_flow(idx2,idx1) == get_flow(idx2,idx1) - value
+ !*/
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ );
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - #get_label(idx) == value
+ !*/
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_nodes()
+ ensures
+ - returns the label for the idx-th node in the graph.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename flow_graph
+ >
+ typename flow_graph::edge_type graph_cut_score (
+ const flow_graph& g
+ );
+ /*!
+ requires
+ - flow_graph == an object with an interface compatible with the flow_graph
+ object defined at the top of this file, or, an implementation of
+ dlib/directed_graph/directed_graph_kernel_abstract.h.
+ ensures
+ - returns the sum of the outgoing flows from nodes with a label of SOURCE_CUT
+ to nodes with a label != SOURCE_CUT. Note that for a directed_graph object,
+ the labels are stored in the node's data field.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class min_cut
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object which can be used to find the min cut
+ on a graph.
+
+ The implementation is based on the method described in the following
+ paper:
+ An Experimental Comparison of Min-Cut/Max-Flow Algorithms for
+ Energy Minimization in Vision, by Yuri Boykov and Vladimir Kolmogorov,
+ in PAMI 2004.
+
+ !*/
+
+ public:
+
+ min_cut(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename flow_graph
+ >
+ void operator() (
+ flow_graph& g,
+ const unsigned long source_node,
+ const unsigned long sink_node
+ ) const;
+ /*!
+ requires
+ - flow_graph == an object with an interface compatible with the flow_graph
+ object defined at the top of this file.
+ - source_node != sink_node
+ - source_node < g.number_of_nodes()
+ - sink_node < g.number_of_nodes()
+ - for all valid i and j:
+ - g.get_flow(i,j) >= 0
+ (i.e. all the flow capacities/edge weights are non-negative)
+ - g does not contain any self loops. That is, no nodes are neighbors with
+ themselves.
+ ensures
+ - Finds the minimum cut on the given graph. That is, this function finds
+ a labeling of nodes in g such that graph_cut_score(g) would be minimized. Note
+ that the flow values in #g are modified by this algorithm so if you want
+ to obtain the min cut score you must call min_cut::operator(), then copy
+ the flow values back into #g, and then call graph_cut_score(#g). But in most
+ cases you don't care about the value of the min cut score, rather, you
+ just want the labels in #g.
+ - #g.get_label(source_node) == SOURCE_CUT
+ - #g.get_label(sink_node) == SINK_CUT
+ - for all valid i:
+ - #g.get_label(i) == SOURCE_CUT, SINK_CUT, or FREE_NODE
+ - if (#g.get_label(i) == SOURCE_CUT) then
+ - The minimum cut of g places node i into the source side of the cut.
+ - if (#g.get_label(i) == SINK_CUT) then
+ - The minimum cut of g places node i into the sink side of the cut.
+ - if (#g.get_label(i) == FREE_NODE) then
+ - Node i can be labeled SOURCE_CUT or SINK_CUT. Both labelings
+ result in the same cut score.
+ - When interpreting g as a graph of flow capacities from the source_node
+ to the sink_node we can say that the min cut problem is equivalent to
+ the max flow problem. This equivalent problem is to find out how to push
+ as much "flow" from the source node to the sink node as possible.
+ Upon termination, #g will contain the final flow residuals in addition to
+ the graph cut labels. That is, for all valid i and j:
+ - #g.get_flow(i,j) == the residual flow capacity left after the max
+ possible amount of flow is passing from the source node to the sink
+ node. For example, this means that #g.get_flow(i,j) == 0 whenever
+ node i is in the SOURCE_CUT and j is in the SINK_CUT.
+ - #g.get_flow(i,j) >= 0
+ !*/
+
+ template <
+ typename directed_graph
+ >
+ void operator() (
+ directed_graph& g,
+ const unsigned long source_node,
+ const unsigned long sink_node
+ ) const;
+ /*!
+ requires
+ - directed_graph == an implementation of dlib/directed_graph/directed_graph_kernel_abstract.h
+ - directed_graph::type == node_label
+ - directed_graph::edge_type == and integer or double type
+ - source_node != sink_node
+ - source_node < g.number_of_nodes()
+ - sink_node < g.number_of_nodes()
+ - for all valid i and j:
+ - edge(g,i,j) >= 0
+ (i.e. all the flow capacities/edge weights are positive)
+ - graph_contains_length_one_cycle(g) == false
+ - graph_has_symmetric_edges(g) == true
+ ensures
+ - This routine simply converts g into a flow graph and calls the version
+ of operator() defined above. Note that the conversion is done in O(1)
+ time, it's just an interface adaptor.
+ - edge weights in g correspond to network flows while the .data field of
+ each node in g corresponds to the graph node labels.
+ - upon termination, the flows and labels in g will have been modified
+ as described in the above operator() routine.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MIN_CuT_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils.h b/ml/dlib/dlib/graph_utils.h
new file mode 100644
index 000000000..c79e05b64
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_UTILs_H_
+#define DLIB_GRAPH_UTILs_H_
+
+#include "graph_utils/graph_utils.h"
+#include "graph_utils/edge_list_graphs.h"
+#include "graph_utils/function_objects.h"
+
+#endif // DLIB_GRAPH_UTILs_H_
+
+
diff --git a/ml/dlib/dlib/graph_utils/edge_list_graphs.h b/ml/dlib/dlib/graph_utils/edge_list_graphs.h
new file mode 100644
index 000000000..d2447acdb
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/edge_list_graphs.h
@@ -0,0 +1,593 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EDGE_LIST_GrAPHS_Hh_
+#define DLIB_EDGE_LIST_GrAPHS_Hh_
+
+#include "edge_list_graphs_abstract.h"
+#include <limits>
+#include <vector>
+#include "../string.h"
+#include "../rand.h"
+#include <algorithm>
+#include "sample_pair.h"
+#include "ordered_sample_pair.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_duplicate_edges (
+ vector_type& pairs
+ )
+ {
+ typedef typename vector_type::value_type T;
+ if (pairs.size() > 0)
+ {
+ // sort pairs so that we can avoid duplicates in the loop below
+ std::sort(pairs.begin(), pairs.end(), &order_by_index<T>);
+
+ // now put edges into temp while avoiding duplicates
+ vector_type temp;
+ temp.reserve(pairs.size());
+ temp.push_back(pairs[0]);
+ for (unsigned long i = 1; i < pairs.size(); ++i)
+ {
+ if (pairs[i] != pairs[i-1])
+ {
+ temp.push_back(pairs[i]);
+ }
+ }
+
+ temp.swap(pairs);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename iterator>
+ iterator iterator_of_worst (
+ iterator begin,
+ const iterator& end
+ )
+ /*!
+ ensures
+ - returns an iterator that points to the element in the given range
+ that has the biggest distance
+ !*/
+ {
+ double dist = begin->distance();
+ iterator worst = begin;
+ for (; begin != end; ++begin)
+ {
+ if (begin->distance() > dist)
+ {
+ dist = begin->distance();
+ worst = begin;
+ }
+ }
+
+ return worst;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc,
+ typename T
+ >
+ void find_percent_shortest_edges_randomly (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const double percent,
+ const unsigned long num,
+ const T& random_seed,
+ std::vector<sample_pair, alloc>& out
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 < percent && percent <= 1 &&
+ num > 0,
+ "\t void find_percent_shortest_edges_randomly()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t percent: " << percent
+ << "\n\t num: " << num
+ );
+
+ out.clear();
+
+ if (samples.size() <= 1)
+ {
+ return;
+ }
+
+ std::vector<sample_pair, alloc> edges;
+ edges.reserve(num);
+
+ dlib::rand rnd;
+ rnd.set_seed(cast_to_string(random_seed));
+
+ // randomly sample a bunch of edges
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
+ const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
+ if (idx1 != idx2)
+ {
+ const double dist = dist_funct(samples[idx1], samples[idx2]);
+ if (dist < std::numeric_limits<double>::infinity())
+ {
+ edges.push_back(sample_pair(idx1, idx2, dist));
+ }
+ }
+ }
+
+
+ // now put edges into out while avoiding duplicates
+ if (edges.size() > 0)
+ {
+ remove_duplicate_edges(edges);
+
+ // now sort all the edges by distance and take the percent with the smallest distance
+ std::sort(edges.begin(), edges.end(), &order_by_distance<sample_pair>);
+
+ const unsigned long out_size = std::min<unsigned long>((unsigned long)(num*percent), edges.size());
+ out.assign(edges.begin(), edges.begin() + out_size);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc,
+ typename T
+ >
+ void find_approximate_k_nearest_neighbors (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const unsigned long k,
+ unsigned long num,
+ const T& random_seed,
+ std::vector<sample_pair, alloc>& out
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( num > 0 && k > 0,
+ "\t void find_approximate_k_nearest_neighbors()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t k: " << k
+ << "\n\t num: " << num
+ );
+
+ out.clear();
+
+ if (samples.size() <= 1)
+ {
+ return;
+ }
+
+ // we add each edge twice in the following loop. So multiply num by 2 to account for that.
+ num *= 2;
+
+ std::vector<ordered_sample_pair> edges;
+ edges.reserve(num);
+ std::vector<sample_pair, alloc> temp;
+ temp.reserve(num);
+
+ dlib::rand rnd;
+ rnd.set_seed(cast_to_string(random_seed));
+
+ // randomly sample a bunch of edges
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
+ const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
+ if (idx1 != idx2)
+ {
+ const double dist = dist_funct(samples[idx1], samples[idx2]);
+ if (dist < std::numeric_limits<double>::infinity())
+ {
+ edges.push_back(ordered_sample_pair(idx1, idx2, dist));
+ edges.push_back(ordered_sample_pair(idx2, idx1, dist));
+ }
+ }
+ }
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<ordered_sample_pair>);
+
+ std::vector<ordered_sample_pair>::iterator beg, itr;
+ // now copy edges into temp when they aren't duplicates and also only move in the k shortest for
+ // each index.
+ itr = edges.begin();
+ while (itr != edges.end())
+ {
+ // first find the bounding range for all the edges connected to node itr->index1()
+ beg = itr;
+ while (itr != edges.end() && itr->index1() == beg->index1())
+ ++itr;
+
+ // If the node has more than k edges then sort them by distance so that
+ // we will end up with the k best.
+ if (static_cast<unsigned long>(itr - beg) > k)
+ {
+ std::sort(beg, itr, &order_by_distance_and_index<ordered_sample_pair>);
+ }
+
+ // take the k best unique edges from the range [beg,itr)
+ temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance()));
+ unsigned long prev_index2 = beg->index2();
+ ++beg;
+ unsigned long count = 1;
+ for (; beg != itr && count < k; ++beg)
+ {
+ if (beg->index2() != prev_index2)
+ {
+ temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance()));
+ ++count;
+ }
+ prev_index2 = beg->index2();
+ }
+ }
+
+
+ remove_duplicate_edges(temp);
+ temp.swap(out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc
+ >
+ void find_k_nearest_neighbors (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const unsigned long k,
+ std::vector<sample_pair, alloc>& out
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(k > 0,
+ "\t void find_k_nearest_neighbors()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t k: " << k
+ );
+
+ out.clear();
+
+ if (samples.size() <= 1)
+ {
+ return;
+ }
+
+ using namespace impl;
+ std::vector<sample_pair> edges;
+
+ // Initialize all the edges to an edge with an invalid index
+ edges.resize(samples.size()*k,
+ sample_pair(samples.size(),samples.size(),std::numeric_limits<double>::infinity()));
+
+ // Hold the length for the longest edge for each node. Initially they are all infinity.
+ std::vector<double> worst_dists(samples.size(), std::numeric_limits<double>::infinity());
+
+ std::vector<sample_pair>::iterator begin_i, end_i, begin_j, end_j;
+ begin_i = edges.begin();
+ end_i = begin_i + k;
+
+ // Loop over all combinations of samples. We will maintain the iterator ranges so that
+ // within the inner for loop we have:
+ // [begin_i, end_i) == the range in edges that contains neighbors of samples[i]
+ // [begin_j, end_j) == the range in edges that contains neighbors of samples[j]
+ for (unsigned long i = 0; i+1 < samples.size(); ++i)
+ {
+ begin_j = begin_i;
+ end_j = end_i;
+
+ for (unsigned long j = i+1; j < samples.size(); ++j)
+ {
+ begin_j += k;
+ end_j += k;
+
+ const double dist = dist_funct(samples[i], samples[j]);
+
+ if (dist < worst_dists[i])
+ {
+ *iterator_of_worst(begin_i, end_i) = sample_pair(i, j, dist);
+ worst_dists[i] = iterator_of_worst(begin_i, end_i)->distance();
+ }
+
+ if (dist < worst_dists[j])
+ {
+ *iterator_of_worst(begin_j, end_j) = sample_pair(i, j, dist);
+ worst_dists[j] = iterator_of_worst(begin_j, end_j)->distance();
+ }
+ }
+
+ begin_i += k;
+ end_i += k;
+ }
+
+ // sort the edges so that duplicate edges will be adjacent
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ // if the first edge is valid
+ if (edges[0].index1() < samples.size())
+ {
+ // now put edges into out while avoiding duplicates and any remaining invalid edges.
+ out.reserve(edges.size());
+ out.push_back(edges[0]);
+ for (unsigned long i = 1; i < edges.size(); ++i)
+ {
+ // if we hit an invalid edge then we can stop
+ if (edges[i].index1() >= samples.size())
+ break;
+
+ // if this isn't a duplicate edge
+ if (edges[i] != edges[i-1])
+ {
+ out.push_back(edges[i]);
+ }
+ }
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ bool contains_duplicate_pairs (
+ const vector_type& pairs
+ )
+ {
+ typedef typename vector_type::value_type T;
+ vector_type temp(pairs);
+ std::sort(temp.begin(), temp.end(), &order_by_index<T>);
+
+ for (unsigned long i = 1; i < temp.size(); ++i)
+ {
+ // if we found a duplicate
+ if (temp[i-1] == temp[i])
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ typename enable_if_c<(is_same_type<sample_pair, typename vector_type::value_type>::value ||
+ is_same_type<ordered_sample_pair, typename vector_type::value_type>::value),
+ unsigned long>::type
+ max_index_plus_one (
+ const vector_type& pairs
+ )
+ {
+ if (pairs.size() == 0)
+ {
+ return 0;
+ }
+ else
+ {
+ unsigned long max_idx = 0;
+ for (unsigned long i = 0; i < pairs.size(); ++i)
+ {
+ if (pairs[i].index1() > max_idx)
+ max_idx = pairs[i].index1();
+ if (pairs[i].index2() > max_idx)
+ max_idx = pairs[i].index2();
+ }
+
+ return max_idx + 1;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_long_edges (
+ vector_type& pairs,
+ double distance_threshold
+ )
+ {
+ vector_type temp;
+ temp.reserve(pairs.size());
+
+ // add all the pairs shorter than the given threshold into temp
+ for (unsigned long i = 0; i < pairs.size(); ++i)
+ {
+ if (pairs[i].distance() <= distance_threshold)
+ temp.push_back(pairs[i]);
+ }
+
+ // move temp into the output vector
+ temp.swap(pairs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_short_edges (
+ vector_type& pairs,
+ double distance_threshold
+ )
+ {
+ vector_type temp;
+ temp.reserve(pairs.size());
+
+ // add all the pairs longer than the given threshold into temp
+ for (unsigned long i = 0; i < pairs.size(); ++i)
+ {
+ if (pairs[i].distance() >= distance_threshold)
+ temp.push_back(pairs[i]);
+ }
+
+ // move temp into the output vector
+ temp.swap(pairs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_percent_longest_edges (
+ vector_type& pairs,
+ double percent
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= percent && percent < 1,
+ "\t void remove_percent_longest_edges()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t percent: " << percent
+ );
+
+ typedef typename vector_type::value_type T;
+ std::sort(pairs.begin(), pairs.end(), &order_by_distance<T>);
+
+ const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
+
+ // pick out the num shortest pairs
+ vector_type temp(pairs.begin(), pairs.begin() + num);
+
+ // move temp into the output vector
+ temp.swap(pairs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_percent_shortest_edges (
+ vector_type& pairs,
+ double percent
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= percent && percent < 1,
+ "\t void remove_percent_shortest_edges()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t percent: " << percent
+ );
+
+ typedef typename vector_type::value_type T;
+ std::sort(pairs.rbegin(), pairs.rend(), &order_by_distance<T>);
+
+ const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
+
+ // pick out the num shortest pairs
+ vector_type temp(pairs.begin(), pairs.begin() + num);
+
+ // move temp into the output vector
+ temp.swap(pairs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ bool is_ordered_by_index (
+ const vector_type& edges
+ )
+ {
+ for (unsigned long i = 1; i < edges.size(); ++i)
+ {
+ if (order_by_index(edges[i], edges[i-1]))
+ return false;
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc1,
+ typename alloc2
+ >
+ void find_neighbor_ranges (
+ const std::vector<ordered_sample_pair,alloc1>& edges,
+ std::vector<std::pair<unsigned long, unsigned long>,alloc2>& neighbors
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ordered_by_index(edges),
+ "\t void find_neighbor_ranges()"
+ << "\n\t Invalid inputs were given to this function"
+ );
+
+
+ // setup neighbors so that [neighbors[i].first, neighbors[i].second) is the range
+ // within edges that contains all node i's edges.
+ const unsigned long num_nodes = max_index_plus_one(edges);
+ neighbors.assign(num_nodes, std::make_pair(0,0));
+ unsigned long cur_node = 0;
+ unsigned long start_idx = 0;
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ if (edges[i].index1() != cur_node)
+ {
+ neighbors[cur_node] = std::make_pair(start_idx, i);
+ start_idx = i;
+ cur_node = edges[i].index1();
+ }
+ }
+ if (neighbors.size() != 0)
+ neighbors[cur_node] = std::make_pair(start_idx, (unsigned long)edges.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc1,
+ typename alloc2
+ >
+ void convert_unordered_to_ordered (
+ const std::vector<sample_pair,alloc1>& edges,
+ std::vector<ordered_sample_pair,alloc2>& out_edges
+ )
+ {
+ out_edges.clear();
+ out_edges.reserve(edges.size()*2);
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ out_edges.push_back(ordered_sample_pair(edges[i].index1(), edges[i].index2(), edges[i].distance()));
+ if (edges[i].index1() != edges[i].index2())
+ out_edges.push_back(ordered_sample_pair(edges[i].index2(), edges[i].index1(), edges[i].distance()));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EDGE_LIST_GrAPHS_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h b/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h
new file mode 100644
index 000000000..1f72c2739
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h
@@ -0,0 +1,358 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_
+#ifdef DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_
+
+#include <vector>
+#include "../string.h"
+#include "sample_pair_abstract.h"
+#include "ordered_sample_pair_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc,
+ typename T
+ >
+ void find_percent_shortest_edges_randomly (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const double percent,
+ const unsigned long num,
+ const T& random_seed,
+ std::vector<sample_pair, alloc>& out
+ );
+ /*!
+ requires
+ - 0 < percent <= 1
+ - num > 0
+ - random_seed must be convertible to a string by dlib::cast_to_string()
+ - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
+ to a floating point number
+ ensures
+ - This function randomly samples the space of pairs of integers between
+ 0 and samples.size()-1 inclusive. For each of these pairs, (i,j), a
+ sample_pair is created as follows:
+ sample_pair(i, j, dist_funct(samples[i], samples[j]))
+ num such sample_pair objects are generated, duplicates and pairs with distance
+ values == infinity are removed, and then the top percent of them with the
+ smallest distance are stored into out.
+ - #out.size() <= num*percent
+ - contains_duplicate_pairs(#out) == false
+ - for all valid i:
+ - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
+ - #out[i].distance() < std::numeric_limits<double>::infinity()
+ - random_seed is used to seed the random number generator used by this
+ function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc,
+ typename T
+ >
+ void find_approximate_k_nearest_neighbors (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const unsigned long k,
+ const unsigned long num,
+ const T& random_seed,
+ std::vector<sample_pair, alloc>& out
+ );
+ /*!
+ requires
+ - k > 0
+ - num > 0
+ - random_seed must be convertible to a string by dlib::cast_to_string()
+ - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
+ to a floating point number
+ ensures
+ - This function computes an approximate form of k nearest neighbors. As num grows
+ larger the output of this function converges to the output of the
+ find_k_nearest_neighbors() function defined below.
+ - Specifically, this function randomly samples the space of pairs of integers between
+ 0 and samples.size()-1 inclusive. For each of these pairs, (i,j), a
+ sample_pair is created as follows:
+ sample_pair(i, j, dist_funct(samples[i], samples[j]))
+ num such sample_pair objects are generated and then exact k-nearest-neighbors
+ is performed amongst these sample_pairs and the results are stored into #out.
+ Note that samples with an infinite distance between them are considered to
+ be not connected at all.
+ - contains_duplicate_pairs(#out) == false
+ - for all valid i:
+ - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
+ - #out[i].distance() < std::numeric_limits<double>::infinity()
+ - random_seed is used to seed the random number generator used by this
+ function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename alloc
+ >
+ void find_k_nearest_neighbors (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const unsigned long k,
+ std::vector<sample_pair, alloc>& out
+ );
+ /*!
+ requires
+ - k > 0
+ - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
+ to a floating point number
+ ensures
+ - #out == a set of sample_pair objects that represent all the k nearest
+ neighbors in samples according to the given distance function dist_funct.
+ Note that samples with an infinite distance between them are considered to
+ be not connected at all.
+ - for all valid i:
+ - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
+ - #out[i].distance() < std::numeric_limits<double>::infinity()
+ - contains_duplicate_pairs(#out) == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ bool contains_duplicate_pairs (
+ const vector_type& pairs
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - if (pairs contains any elements that are equal according to operator==) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ unsigned long max_index_plus_one (
+ const vector_type& pairs
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - if (pairs.size() == 0) then
+ - returns 0
+ - else
+ - returns a number N such that:
+ - for all i: pairs[i].index1() < N && pairs[i].index2() < N
+ - for some j: pairs[j].index1()+1 == N || pairs[j].index2()+1 == N
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_long_edges (
+ vector_type& pairs,
+ double distance_threshold
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - Removes all elements of pairs that have a distance value greater than the
+ given threshold.
+ - #pairs.size() <= pairs.size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_short_edges (
+ vector_type& pairs,
+ double distance_threshold
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - Removes all elements of pairs that have a distance value less than the
+ given threshold.
+ - #pairs.size() <= pairs.size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_percent_longest_edges (
+ vector_type& pairs,
+ double percent
+ );
+ /*!
+ requires
+ - 0 <= percent < 1
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - Removes the given upper percentage of the longest edges in pairs. I.e.
+ this function removes the long edges from pairs.
+ - #pairs.size() == (1-percent)*pairs.size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_percent_shortest_edges (
+ vector_type& pairs,
+ double percent
+ );
+ /*!
+ requires
+ - 0 <= percent < 1
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - Removes the given upper percentage of the shortest edges in pairs. I.e.
+ this function removes the short edges from pairs.
+ - #pairs.size() == (1-percent)*pairs.size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void remove_duplicate_edges (
+ vector_type& pairs
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - Removes any duplicate edges from pairs. That is, for all elements of pairs,
+ A and B, such that A == B, only one of A or B will be in pairs after this
+ function terminates.
+ - #pairs.size() <= pairs.size()
+ - is_ordered_by_index(#pairs) == true
+ - contains_duplicate_pairs(#pairs) == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ bool is_ordered_by_index (
+ const vector_type& edges
+ );
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and it
+ must in turn contain objects with an interface compatible with
+ dlib::sample_pair or dlib::ordered_sample_pair.
+ ensures
+ - returns true if and only if the contents of edges are in sorted order
+ according to order_by_index(). That is, we return true if calling
+ std::stable_sort(edges.begin(), edges.end(), &order_by_index<T>) would not
+ change the ordering of elements of edges.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc1,
+ typename alloc2
+ >
+ void find_neighbor_ranges (
+ const std::vector<ordered_sample_pair,alloc1>& edges,
+ std::vector<std::pair<unsigned long, unsigned long>,alloc2>& neighbors
+ );
+ /*!
+ requires
+ - is_ordered_by_index(edges) == true
+ (i.e. edges is sorted so that all the edges for a particular node are grouped
+ together)
+ ensures
+ - This function takes a graph, represented by its list of edges, and finds the
+ ranges that contain the edges for each node in the graph. In particular,
+ #neighbors[i] will tell you which edges correspond to the ith node in the
+ graph.
+ - #neighbors.size() == max_index_plus_one(edges)
+ (i.e. neighbors will have an entry for each node in the graph defined by the
+ list of edges)
+ - for all valid i:
+ - all elements of edges such that their index1() value == i are in the
+ range [neighbors[i].first, neighbors[i].second). That is, for all k such
+ that neighbors[i].first <= k < neighbors[i].second:
+ - edges[k].index1() == i.
+ - all edges outside this range have an index1() value != i
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc1,
+ typename alloc2
+ >
+ void convert_unordered_to_ordered (
+ const std::vector<sample_pair,alloc1>& edges,
+ std::vector<ordered_sample_pair,alloc2>& out_edges
+ );
+ /*!
+ ensures
+ - interprets edges a defining an undirected graph.
+ - This function populates out_edges with a directed graph that represents the
+ same graph as the one in edges. In particular, this means that for all valid
+ i we have the following:
+ - if (edges[i].index1() != edges[i].index2()) then
+ - #out_edges contains two edges corresponding to edges[i]. They
+ represent the two directions of this edge. The distance value from
+ edges[i] is also copied into the output edges.
+ - else
+ - #out_edges contains one edge corresponding to edges[i] since this is
+ a self edge. The distance value from edges[i] is also copied into
+ the output edge.
+ - max_index_plus_one(edges) == max_index_plus_one(#out_edges)
+ (i.e. both graphs have the same number of nodes)
+ - In all but the most trivial cases, we will have is_ordered_by_index(#out_edges) == false
+ - contains_duplicate_pairs(#out_edges) == contains_duplicate_pairs(edges)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h
new file mode 100644
index 000000000..433a588a5
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h
@@ -0,0 +1,217 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_
+#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_
+
+#include "find_k_nearest_neighbors_lsh_abstract.h"
+#include "../threads.h"
+#include "../lsh/hashes.h"
+#include <vector>
+#include <queue>
+#include "sample_pair.h"
+#include "edge_list_graphs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct compare_sample_pair_with_distance
+ {
+ inline bool operator() (const sample_pair& a, const sample_pair& b) const
+ {
+ return a.distance() < b.distance();
+ }
+ };
+
+ template <
+ typename vector_type,
+ typename hash_function_type
+ >
+ class hash_block
+ {
+ public:
+ hash_block(
+ const vector_type& samples_,
+ const hash_function_type& hash_funct_,
+ std::vector<typename hash_function_type::result_type>& hashes_
+ ) :
+ samples(samples_),
+ hash_funct(hash_funct_),
+ hashes(hashes_)
+ {}
+
+ void operator() (long i) const
+ {
+ hashes[i] = hash_funct(samples[i]);
+ }
+
+ const vector_type& samples;
+ const hash_function_type& hash_funct;
+ std::vector<typename hash_function_type::result_type>& hashes;
+ };
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename hash_function_type,
+ typename alloc
+ >
+ class scan_find_k_nearest_neighbors_lsh
+ {
+ public:
+ scan_find_k_nearest_neighbors_lsh (
+ const vector_type& samples_,
+ const distance_function_type& dist_funct_,
+ const hash_function_type& hash_funct_,
+ const unsigned long k_,
+ std::vector<sample_pair, alloc>& edges_,
+ const unsigned long k_oversample_,
+ const std::vector<typename hash_function_type::result_type>& hashes_
+ ) :
+ samples(samples_),
+ dist_funct(dist_funct_),
+ hash_funct(hash_funct_),
+ k(k_),
+ edges(edges_),
+ k_oversample(k_oversample_),
+ hashes(hashes_)
+ {
+ edges.clear();
+ edges.reserve(samples.size()*k/2);
+ }
+
+ mutex m;
+ const vector_type& samples;
+ const distance_function_type& dist_funct;
+ const hash_function_type& hash_funct;
+ const unsigned long k;
+ std::vector<sample_pair, alloc>& edges;
+ const unsigned long k_oversample;
+ const std::vector<typename hash_function_type::result_type>& hashes;
+
+ void operator() (unsigned long i) const
+ {
+ const unsigned long k_hash = k*k_oversample;
+
+ std::priority_queue<std::pair<unsigned long, unsigned long> > best_hashes;
+ std::priority_queue<sample_pair, std::vector<sample_pair>, dlib::impl::compare_sample_pair_with_distance> best_samples;
+ unsigned long worst_distance = std::numeric_limits<unsigned long>::max();
+ // scan over the hashes and find the best matches for hashes[i]
+ for (unsigned long j = 0; j < hashes.size(); ++j)
+ {
+ if (i == j)
+ continue;
+
+ const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]);
+ if (dist < worst_distance || best_hashes.size() < k_hash)
+ {
+ if (best_hashes.size() >= k_hash)
+ best_hashes.pop();
+ best_hashes.push(std::make_pair(dist, j));
+ worst_distance = best_hashes.top().first;
+ }
+ }
+
+ // Now figure out which of the best_hashes are actually the k best matches
+ // according to dist_funct()
+ while (best_hashes.size() != 0)
+ {
+ const unsigned long j = best_hashes.top().second;
+ best_hashes.pop();
+
+ const double dist = dist_funct(samples[i], samples[j]);
+ if (dist < std::numeric_limits<double>::infinity())
+ {
+ if (best_samples.size() >= k)
+ best_samples.pop();
+ best_samples.push(sample_pair(i,j,dist));
+ }
+ }
+
+ // Finally, now put the k best matches according to dist_funct() into edges
+ auto_mutex lock(m);
+ while (best_samples.size() != 0)
+ {
+ edges.push_back(best_samples.top());
+ best_samples.pop();
+ }
+ }
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename hash_function_type
+ >
+ void hash_samples (
+ const vector_type& samples,
+ const hash_function_type& hash_funct,
+ const unsigned long num_threads,
+ std::vector<typename hash_function_type::result_type>& hashes
+ )
+ {
+ hashes.resize(samples.size());
+
+ typedef impl::hash_block<vector_type,hash_function_type> block_type;
+ block_type temp(samples, hash_funct, hashes);
+ parallel_for(num_threads, 0, samples.size(), temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename hash_function_type,
+ typename alloc
+ >
+ void find_k_nearest_neighbors_lsh (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const hash_function_type& hash_funct,
+ const unsigned long k,
+ const unsigned long num_threads,
+ std::vector<sample_pair, alloc>& edges,
+ const unsigned long k_oversample = 20
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(k > 0 && k_oversample > 0,
+ "\t void find_k_nearest_neighbors_lsh()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t k: " << k
+ << "\n\t k_oversample: " << k_oversample
+ );
+
+ edges.clear();
+
+ if (samples.size() <= 1)
+ {
+ return;
+ }
+
+ typedef typename hash_function_type::result_type hash_type;
+ std::vector<hash_type> hashes;
+ hash_samples(samples, hash_funct, num_threads, hashes);
+
+ typedef impl::scan_find_k_nearest_neighbors_lsh<vector_type, distance_function_type,hash_function_type,alloc> scan_type;
+ scan_type temp(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes);
+ parallel_for(num_threads, 0, hashes.size(), temp);
+
+ remove_duplicate_edges(edges);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h
new file mode 100644
index 000000000..1de159be4
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h
@@ -0,0 +1,102 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_
+#ifdef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_
+
+#include "../lsh/hashes_abstract.h"
+#include "sample_pair_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename hash_function_type
+ >
+ void hash_samples (
+ const vector_type& samples,
+ const hash_function_type& hash_funct,
+ const unsigned long num_threads,
+ std::vector<typename hash_function_type::result_type>& hashes
+ );
+ /*!
+ requires
+ - hash_funct() is threadsafe. This means that it must be safe for multiple
+ threads to invoke the member functions of hash_funct() at the same time.
+ - vector_type is any container that looks like a std::vector or dlib::array.
+ - hash_funct must be a function object with an interface compatible with the
+ objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct
+ must be capable of hashing the elements in the samples vector.
+ ensures
+ - This function hashes all the elements in samples and stores the results in
+ hashes. It will also use num_threads concurrent threads to do this. You
+ should set this value equal to the number of processing cores on your
+ computer for maximum speed.
+ - #hashes.size() == 0
+ - for all valid i:
+ - #hashes[i] = hash_funct(samples[i])
+ (i.e. #hashes[i] will contain the hash of samples[i])
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type,
+ typename distance_function_type,
+ typename hash_function_type,
+ typename alloc
+ >
+ void find_k_nearest_neighbors_lsh (
+ const vector_type& samples,
+ const distance_function_type& dist_funct,
+ const hash_function_type& hash_funct,
+ const unsigned long k,
+ const unsigned long num_threads,
+ std::vector<sample_pair, alloc>& edges,
+ const unsigned long k_oversample = 20
+ );
+ /*!
+ requires
+ - hash_funct and dist_funct are threadsafe. This means that it must be safe
+ for multiple threads to invoke the member functions of these objects at the
+ same time.
+ - k > 0
+ - k_oversample > 0
+ - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
+ to a floating point number
+ - vector_type is any container that looks like a std::vector or dlib::array.
+ - hash_funct must be a function object with an interface compatible with the
+ objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct
+ must be capable of hashing the elements in the samples vector.
+ ensures
+ - This function computes an approximate form of a k nearest neighbors graph of
+ the elements in samples. In particular, the way it works is that it first
+ hashes all elements in samples using the provided locality sensitive hash
+ function hash_funct(). Then it performs an exact k nearest neighbors on the
+ hashes which can be done very quickly. For each of these neighbors we
+ compute the true distance using dist_funct() and the k nearest neighbors for
+ each sample are stored into #edges.
+ - Note that samples with an infinite distance between them are considered to be
+ not connected at all. Therefore, we exclude edges with such distances from
+ being output.
+ - for all valid i:
+ - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()])
+ - #edges[i].distance() < std::numeric_limits<double>::infinity()
+ - contains_duplicate_pairs(#edges) == false
+ - This function will use num_threads concurrent threads of processing. You
+ should set this value equal to the number of processing cores on your
+ computer for maximum speed.
+ - The hash based k nearest neighbor step is approximate, however, you can
+ improve the output accuracy by using a larger k value for this first step.
+ Therefore, this function finds k*k_oversample nearest neighbors during the
+ first hashing based step.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/graph_utils/function_objects.h b/ml/dlib/dlib/graph_utils/function_objects.h
new file mode 100644
index 000000000..33b6e51ff
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/function_objects.h
@@ -0,0 +1,129 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MR_FUNCTION_ObJECTS_Hh_
+#define DLIB_MR_FUNCTION_ObJECTS_Hh_
+
+#include "function_objects_abstract.h"
+#include "../matrix.h"
+#include "../svm/sparse_vector.h"
+#include <cmath>
+#include <limits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct squared_euclidean_distance
+ {
+ squared_euclidean_distance (
+ ) :
+ lower(0),
+ upper(std::numeric_limits<double>::infinity())
+ {}
+
+ squared_euclidean_distance (
+ const double l,
+ const double u
+ ) :
+ lower(l),
+ upper(u)
+ {}
+
+ const double lower;
+ const double upper;
+
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ const double len = length_squared(a-b);
+ if (lower <= len && len <= upper)
+ return len;
+ else
+ return std::numeric_limits<double>::infinity();
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct cosine_distance
+ {
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ const double temp = length(a)*length(b);
+ if (temp == 0)
+ return 0;
+ else
+ return 1-dot(a,b)/temp;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct negative_dot_product_distance
+ {
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return -dot(a,b);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct use_weights_of_one
+ {
+ template <typename edge_type>
+ double operator() (
+ const edge_type&
+ ) const
+ {
+ return 1;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct use_gaussian_weights
+ {
+ use_gaussian_weights (
+ )
+ {
+ gamma = 0.1;
+ }
+
+ use_gaussian_weights (
+ double g
+ )
+ {
+ gamma = g;
+ }
+
+ double gamma;
+
+ template <typename edge_type>
+ double operator() (
+ const edge_type& e
+ ) const
+ {
+ return std::exp(-gamma*e.distance());
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MR_FUNCTION_ObJECTS_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils/function_objects_abstract.h b/ml/dlib/dlib/graph_utils/function_objects_abstract.h
new file mode 100644
index 000000000..394b99397
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/function_objects_abstract.h
@@ -0,0 +1,209 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_
+#ifdef DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include <cmath>
+#include "../svm/sparse_vector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct squared_euclidean_distance
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that computes squared euclidean distance
+ between two dlib::matrix objects.
+
+ THREAD SAFETY
+ This object has no mutable members. Therefore, it is safe to call
+ operator() on a single instance of this object simultaneously from multiple
+ threads.
+ !*/
+
+ squared_euclidean_distance (
+ );
+ /*!
+ ensures
+ - #lower == 0
+ - #upper == std::numeric_limits<double>::infinity()
+ !*/
+
+ squared_euclidean_distance (
+ const double l,
+ const double u
+ );
+ /*!
+ ensures
+ - #lower == l
+ - #upper == u
+ !*/
+
+ const double lower;
+ const double upper;
+
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - sample_type should be a kind of dlib::matrix
+ ensures
+ - let LEN = length_squared(a-b)
+ - if (lower <= LEN <= upper) then
+ - returns LEN
+ - else
+ - returns std::numeric_limits<double>::infinity()
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct cosine_distance
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that computes the cosine of the angle
+ between two vectors and returns 1 - this quantity. Moreover, this object
+ works for both sparse and dense vectors.
+
+ THREAD SAFETY
+ This object has no mutable members. Therefore, it is safe to call
+ operator() on a single instance of this object simultaneously from multiple
+ threads.
+ !*/
+
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - sample_type is a dense vector (e.g. a dlib::matrix) or a sparse
+ vector as defined at the top of dlib/svm/sparse_vector_abstract.h
+ ensures
+ - let theta = the angle between a and b.
+ - returns 1 - cos(theta)
+ (e.g. this function returns 0 when a and b have an angle of 0 between
+ each other, 1 if they have a 90 degree angle, and a maximum of 2 if the
+ vectors have a 180 degree angle between each other).
+ - zero length vectors are considered to have angles of 0 between all other
+ vectors.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct negative_dot_product_distance
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that computes the dot product between two
+ vectors and returns the negation of this value. Moreover, this object
+ works for both sparse and dense vectors.
+
+ THREAD SAFETY
+ This object has no mutable members. Therefore, it is safe to call
+ operator() on a single instance of this object simultaneously from multiple
+ threads.
+ !*/
+
+ template <typename sample_type>
+ double operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - sample_type is a dense vector (e.g. a dlib::matrix) or a sparse
+ vector as defined at the top of dlib/svm/sparse_vector_abstract.h
+ ensures
+ - returns -dot(a,b)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct use_weights_of_one
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that takes a single argument
+ and always returns 1
+
+ THREAD SAFETY
+ This object has no mutable members. Therefore, it is safe to call
+ operator() on a single instance of this object simultaneously from multiple
+ threads.
+ !*/
+
+ template <typename edge_type>
+ double operator() (
+ const edge_type&
+ ) const;
+ /*!
+ ensures
+ - returns 1
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct use_gaussian_weights
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple function object that takes a single argument
+ which should be an object similar to dlib::sample_pair.
+
+ THREAD SAFETY
+ This object has no mutable members. Therefore, it is safe to call
+ operator() on a single instance of this object simultaneously from multiple
+ threads.
+ !*/
+
+ use_gaussian_weights (
+ );
+ /*!
+ ensures
+ - #gamma == 0.1
+ !*/
+
+ use_gaussian_weights (
+ double g
+ );
+ /*!
+ ensures
+ - #gamma == g
+ !*/
+
+ double gamma;
+
+ template <typename edge_type>
+ double operator() (
+ const edge_type& e
+ ) const;
+ /*!
+ requires
+ - e.distance() must be a valid expression that returns a number
+ (e.g. edge_type might be dlib::sample_pair)
+ ensures
+ - returns std::exp(-gamma*e.distance());
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/graph_utils/graph_utils.h b/ml/dlib/dlib/graph_utils/graph_utils.h
new file mode 100644
index 000000000..81262b7f5
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/graph_utils.h
@@ -0,0 +1,1227 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_UTILs_
+#define DLIB_GRAPH_UTILs_
+
+#include "../algs.h"
+#include <vector>
+#include "graph_utils_abstract.h"
+#include "../is_kind.h"
+#include "../enable_if.h"
+#include <algorithm>
+#include "../set.h"
+#include "../memory_manager.h"
+#include "../set_utils.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename enable_if<is_graph<T>,typename T::edge_type>::type& edge(
+ T& g,
+ unsigned long idx_i,
+ unsigned long idx_j
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.has_edge(idx_i,idx_j) == true,
+ "\tT::edge_type& edge(g, idx_i, idx_j)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t idx_i: " << idx_i
+ << "\n\t idx_j: " << idx_j
+ );
+
+ for (unsigned long i = 0; i < g.node(idx_i).number_of_neighbors(); ++i)
+ {
+ if (g.node(idx_i).neighbor(i).index() == idx_j)
+ return g.node(idx_i).edge(i);
+ }
+
+ // put this here just so compilers don't complain about a lack of
+ // a return here
+ DLIB_CASSERT(false,
+ "\tT::edge_type& edge(g, idx_i, idx_j)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t idx_i: " << idx_i
+ << "\n\t idx_j: " << idx_j
+ );
+ }
+
+ template <typename T>
+ const typename enable_if<is_graph<T>,typename T::edge_type>::type& edge(
+ const T& g,
+ unsigned long idx_i,
+ unsigned long idx_j
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.has_edge(idx_i,idx_j) == true,
+ "\tT::edge_type& edge(g, idx_i, idx_j)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t idx_i: " << idx_i
+ << "\n\t idx_j: " << idx_j
+ );
+
+ for (unsigned long i = 0; i < g.node(idx_i).number_of_neighbors(); ++i)
+ {
+ if (g.node(idx_i).neighbor(i).index() == idx_j)
+ return g.node(idx_i).edge(i);
+ }
+
+ // put this here just so compilers don't complain about a lack of
+ // a return here
+ DLIB_CASSERT(false,
+ "\tT::edge_type& edge(g, idx_i, idx_j)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t idx_i: " << idx_i
+ << "\n\t idx_j: " << idx_j
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename enable_if<is_directed_graph<T>,typename T::edge_type>::type& edge(
+ T& g,
+ unsigned long parent_idx,
+ unsigned long child_idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.has_edge(parent_idx,child_idx) == true,
+ "\t T::edge_type& edge(g, parent_idx, child_idx)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t parent_idx: " << parent_idx
+ << "\n\t child_idx: " << child_idx
+ );
+
+ for (unsigned long i = 0; i < g.node(parent_idx).number_of_children(); ++i)
+ {
+ if (g.node(parent_idx).child(i).index() == child_idx)
+ return g.node(parent_idx).child_edge(i);
+ }
+
+ // put this here just so compilers don't complain about a lack of
+ // a return here
+ DLIB_CASSERT(false,
+ "\t T::edge_type& edge(g, parent_idx, child_idx)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t parent_idx: " << parent_idx
+ << "\n\t child_idx: " << child_idx
+ );
+ }
+
+ template <typename T>
+ const typename enable_if<is_directed_graph<T>,typename T::edge_type>::type& edge(
+ const T& g,
+ unsigned long parent_idx,
+ unsigned long child_idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.has_edge(parent_idx,child_idx) == true,
+ "\t T::edge_type& edge(g, parent_idx, child_idx)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t parent_idx: " << parent_idx
+ << "\n\t child_idx: " << child_idx
+ );
+
+ for (unsigned long i = 0; i < g.node(parent_idx).number_of_children(); ++i)
+ {
+ if (g.node(parent_idx).child(i).index() == child_idx)
+ return g.node(parent_idx).child_edge(i);
+ }
+
+ // put this here just so compilers don't complain about a lack of
+ // a return here
+ DLIB_ASSERT(false,
+ "\t T::edge_type& edge(g, parent_idx, child_idx)"
+ << "\n\t you have requested an invalid edge"
+ << "\n\t parent_idx: " << parent_idx
+ << "\n\t child_idx: " << child_idx
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace graph_helpers
+ {
+ template <typename T, typename U>
+ inline bool is_same_object (
+ const T& a,
+ const U& b
+ )
+ {
+ if (is_same_type<const T,const U>::value == false)
+ return false;
+ if ((void*)&a == (void*)&b)
+ return true;
+ else
+ return false;
+ }
+
+ template <
+ typename T
+ >
+ bool search_for_directed_cycles (
+ const T& node,
+ std::vector<bool>& visited,
+ std::vector<bool>& temp
+ )
+ /*!
+ requires
+ - visited.size() >= number of nodes in the graph that contains the given node
+ - temp.size() >= number of nodes in the graph that contains the given node
+ - for all i in temp:
+ - temp[i] == false
+ ensures
+ - checks the connected subgraph containing the given node for directed cycles
+ and returns true if any are found and false otherwise.
+ - for all nodes N in the connected subgraph containing the given node:
+ - #visited[N.index()] == true
+ - for all i in temp:
+ - #temp[i] == false
+ !*/
+ {
+ if (temp[node.index()] == true)
+ return true;
+
+ visited[node.index()] = true;
+ temp[node.index()] = true;
+
+ for (unsigned long i = 0; i < node.number_of_children(); ++i)
+ {
+ if (search_for_directed_cycles(node.child(i), visited, temp))
+ return true;
+ }
+
+ temp[node.index()] = false;
+
+ return false;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if<is_directed_graph<typename T::graph_type>,bool>::type search_for_undirected_cycles (
+ const T& node,
+ std::vector<bool>& visited,
+ unsigned long prev = std::numeric_limits<unsigned long>::max()
+ )
+ /*!
+ requires
+ - visited.size() >= number of nodes in the graph that contains the given node
+ - for all nodes N in the connected subgraph containing the given node:
+ - visited[N.index] == false
+ ensures
+ - checks the connected subgraph containing the given node for directed cycles
+ and returns true if any are found and false otherwise.
+ - for all nodes N in the connected subgraph containing the given node:
+ - #visited[N.index()] == true
+ !*/
+ {
+ using namespace std;
+ if (visited[node.index()] == true)
+ return true;
+
+ visited[node.index()] = true;
+
+ for (unsigned long i = 0; i < node.number_of_children(); ++i)
+ {
+ if (node.child(i).index() != prev &&
+ search_for_undirected_cycles(node.child(i), visited, node.index()))
+ return true;
+ }
+
+ for (unsigned long i = 0; i < node.number_of_parents(); ++i)
+ {
+ if (node.parent(i).index() != prev &&
+ search_for_undirected_cycles(node.parent(i), visited, node.index()))
+ return true;
+ }
+
+ return false;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if<is_graph<typename T::graph_type>,bool>::type search_for_undirected_cycles (
+ const T& node,
+ std::vector<bool>& visited,
+ unsigned long prev = std::numeric_limits<unsigned long>::max()
+ )
+ /*!
+ requires
+ - visited.size() >= number of nodes in the graph that contains the given node
+ - for all nodes N in the connected subgraph containing the given node:
+ - visited[N.index] == false
+ ensures
+ - checks the connected subgraph containing the given node for directed cycles
+ and returns true if any are found and false otherwise.
+ - for all nodes N in the connected subgraph containing the given node:
+ - #visited[N.index()] == true
+ !*/
+ {
+ using namespace std;
+ if (visited[node.index()] == true)
+ return true;
+
+ visited[node.index()] = true;
+
+ for (unsigned long i = 0; i < node.number_of_neighbors(); ++i)
+ {
+ if (node.neighbor(i).index() != prev &&
+ search_for_undirected_cycles(node.neighbor(i), visited, node.index()))
+ return true;
+ }
+
+ return false;
+ }
+
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type1,
+ typename graph_type2
+ >
+ typename enable_if<is_graph<graph_type1> >::type copy_graph_structure (
+ const graph_type1& src,
+ graph_type2& dest
+ )
+ {
+ COMPILE_TIME_ASSERT(is_graph<graph_type1>::value);
+ COMPILE_TIME_ASSERT(is_graph<graph_type2>::value);
+ if (graph_helpers::is_same_object(src,dest))
+ return;
+
+ dest.clear();
+ dest.set_number_of_nodes(src.number_of_nodes());
+
+ // copy all the edges from src into dest
+ for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < src.node(i).number_of_neighbors(); ++j)
+ {
+ const unsigned long nidx = src.node(i).neighbor(j).index();
+ if (nidx >= i)
+ {
+ dest.add_edge(i,nidx);
+ }
+ }
+ }
+ }
+
+ template <
+ typename graph_type1,
+ typename graph_type2
+ >
+ typename enable_if<is_directed_graph<graph_type1> >::type copy_graph_structure (
+ const graph_type1& src,
+ graph_type2& dest
+ )
+ {
+ COMPILE_TIME_ASSERT(is_directed_graph<graph_type1>::value);
+ COMPILE_TIME_ASSERT(is_directed_graph<graph_type2>::value || is_graph<graph_type2>::value );
+ if (graph_helpers::is_same_object(src,dest))
+ return;
+
+ dest.clear();
+ dest.set_number_of_nodes(src.number_of_nodes());
+
+ // copy all the edges from src into dest
+ for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < src.node(i).number_of_children(); ++j)
+ {
+ const unsigned long nidx = src.node(i).child(j).index();
+ if (dest.has_edge(i,nidx) == false)
+ {
+ dest.add_edge(i,nidx);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type1,
+ typename graph_type2
+ >
+ typename enable_if<is_graph<graph_type1> >::type copy_graph (
+ const graph_type1& src,
+ graph_type2& dest
+ )
+ {
+ COMPILE_TIME_ASSERT(is_graph<graph_type1>::value);
+ COMPILE_TIME_ASSERT(is_graph<graph_type2>::value);
+ if (graph_helpers::is_same_object(src,dest))
+ return;
+
+ copy_graph_structure(src,dest);
+
+ // copy all the node and edge content
+ for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
+ {
+ dest.node(i).data = src.node(i).data;
+
+ for (unsigned long j = 0; j < src.node(i).number_of_neighbors(); ++j)
+ {
+ const unsigned long nidx = src.node(i).neighbor(j).index();
+ if (nidx >= i)
+ {
+ dest.node(i).edge(j) = src.node(i).edge(j);
+ }
+ }
+ }
+ }
+
+ template <
+ typename graph_type1,
+ typename graph_type2
+ >
+ typename enable_if<is_directed_graph<graph_type1> >::type copy_graph (
+ const graph_type1& src,
+ graph_type2& dest
+ )
+ {
+ COMPILE_TIME_ASSERT(is_directed_graph<graph_type1>::value);
+ COMPILE_TIME_ASSERT(is_directed_graph<graph_type2>::value);
+ if (graph_helpers::is_same_object(src,dest))
+ return;
+
+ copy_graph_structure(src,dest);
+
+ // copy all the node and edge content
+ for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
+ {
+ dest.node(i).data = src.node(i).data;
+ for (unsigned long j = 0; j < src.node(i).number_of_children(); ++j)
+ {
+ dest.node(i).child_edge(j) = src.node(i).child_edge(j);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename S
+ >
+ typename enable_if<is_graph<typename T::graph_type> >::type find_connected_nodes (
+ const T& n,
+ S& visited
+ )
+ {
+ if (visited.is_member(n.index()) == false)
+ {
+ unsigned long temp = n.index();
+ visited.add(temp);
+
+ for (unsigned long i = 0; i < n.number_of_neighbors(); ++i)
+ find_connected_nodes(n.neighbor(i), visited);
+ }
+ }
+
+ template <
+ typename T,
+ typename S
+ >
+ typename enable_if<is_directed_graph<typename T::graph_type> >::type find_connected_nodes (
+ const T& n,
+ S& visited
+ )
+ {
+ if (visited.is_member(n.index()) == false)
+ {
+ unsigned long temp = n.index();
+ visited.add(temp);
+
+ for (unsigned long i = 0; i < n.number_of_parents(); ++i)
+ find_connected_nodes(n.parent(i), visited);
+ for (unsigned long i = 0; i < n.number_of_children(); ++i)
+ find_connected_nodes(n.child(i), visited);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_is_connected (
+ const T& g
+ )
+ {
+ if (g.number_of_nodes() == 0)
+ return true;
+
+ set<unsigned long>::kernel_1b_c visited;
+ find_connected_nodes(g.node(0), visited);
+ return (visited.size() == g.number_of_nodes());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_has_symmetric_edges (
+ const T& graph
+ )
+ {
+ for (unsigned long i = 0; i < graph.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < graph.node(i).number_of_children(); ++j)
+ {
+ const unsigned long jj = graph.node(i).child(j).index();
+ // make sure every edge from a parent to a child has an edge linking back
+ if (graph.has_edge(jj,i) == false)
+ return false;
+ }
+
+ for (unsigned long j = 0; j < graph.node(i).number_of_parents(); ++j)
+ {
+ const unsigned long jj = graph.node(i).parent(j).index();
+ // make sure every edge from a child to a parent has an edge linking back
+ if (graph.has_edge(i,jj) == false)
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_contains_directed_cycle (
+ const T& graph
+ )
+ {
+ using namespace std;
+ using namespace graph_helpers;
+ std::vector<bool> visited(graph.number_of_nodes(), false);
+ std::vector<bool> temp(graph.number_of_nodes(), false);
+
+ while (true)
+ {
+ // find the first node that hasn't been visited yet
+ unsigned long i;
+ for (i = 0; i < visited.size(); ++i)
+ {
+ if (visited[i] == false)
+ break;
+ }
+
+ // if we didn't find any non-visited nodes then we are done
+ if (i == visited.size())
+ return false;
+
+ if (search_for_directed_cycles(graph.node(i), visited, temp))
+ return true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_contains_undirected_cycle (
+ const T& graph
+ )
+ {
+ using namespace std;
+ using namespace graph_helpers;
+ std::vector<bool> visited(graph.number_of_nodes(), false);
+
+ while (true)
+ {
+ // find the first node that hasn't been visited yet
+ unsigned long i;
+ for (i = 0; i < visited.size(); ++i)
+ {
+ if (visited[i] == false)
+ break;
+ }
+
+ // if we didn't find any non-visited nodes then we are done
+ if (i == visited.size())
+ return false;
+
+ if (search_for_undirected_cycles(graph.node(i), visited))
+ return true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename directed_graph_type,
+ typename graph_type
+ >
+ void create_moral_graph (
+ const directed_graph_type& g,
+ graph_type& moral_graph
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_directed_cycle(g) == false,
+ "\tvoid create_moral_graph(g, moral_graph)"
+ << "\n\tYou can only make moral graphs if g doesn't have directed cycles"
+ );
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value);
+ COMPILE_TIME_ASSERT(is_directed_graph<directed_graph_type>::value);
+
+ copy_graph_structure(g, moral_graph);
+
+ // now marry all the parents (i.e. add edges between parent nodes)
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ // loop over all combinations of parents of g.node(i)
+ for (unsigned long j = 0; j < g.node(i).number_of_parents(); ++j)
+ {
+ for (unsigned long k = 0; k < g.node(i).number_of_parents(); ++k)
+ {
+ const unsigned long p1 = g.node(i).parent(j).index();
+ const unsigned long p2 = g.node(i).parent(k).index();
+ if (p1 == p2)
+ continue;
+
+ if (moral_graph.has_edge(p1,p2) == false)
+ moral_graph.add_edge(p1,p2);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename sets_of_int
+ >
+ bool is_clique (
+ const graph_type& g,
+ const sets_of_int& clique
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\tvoid is_clique(g, clique)"
+ << "\n\tinvalid graph"
+ );
+#ifdef ENABLE_ASSERTS
+ clique.reset();
+ while (clique.move_next())
+ {
+ const unsigned long x = clique.element();
+ DLIB_ASSERT( x < g.number_of_nodes(),
+ "\tvoid is_clique(g, clique)"
+ << "\n\tthe clique set contained an invalid node index"
+ << "\n\tx: " << x
+ << "\n\tg.number_of_nodes(): " << g.number_of_nodes()
+ );
+ }
+#endif
+
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value);
+
+ std::vector<unsigned long> v;
+ v.reserve(clique.size());
+ clique.reset();
+ while (clique.move_next())
+ {
+ v.push_back(clique.element());
+ }
+
+ for (unsigned long i = 0; i < v.size(); ++i)
+ {
+ for (unsigned long j = 0; j < v.size(); ++j)
+ {
+ if (v[i] == v[j])
+ continue;
+ if (g.has_edge(v[i], v[j]) == false)
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename sets_of_int
+ >
+ bool is_maximal_clique (
+ const graph_type& g,
+ const sets_of_int& clique
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\tvoid is_maximal_clique(g, clique)"
+ << "\n\tinvalid graph"
+ );
+ DLIB_ASSERT(is_clique(g,clique) == true,
+ "\tvoid is_maximal_clique(g, clique)"
+ << "\n\tinvalid graph"
+ );
+#ifdef ENABLE_ASSERTS
+ clique.reset();
+ while (clique.move_next())
+ {
+ const unsigned long x = clique.element();
+ DLIB_ASSERT( x < g.number_of_nodes(),
+ "\tvoid is_maximal_clique(g, clique)"
+ << "\n\tthe clique set contained an invalid node index"
+ << "\n\tx: " << x
+ << "\n\tg.number_of_nodes(): " << g.number_of_nodes()
+ );
+ }
+#endif
+
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value);
+
+ if (clique.size() == 0)
+ return true;
+
+ // get an element in the clique and make sure that
+ // none of its neighbors that aren't in the clique are connected
+ // to all the elements of the clique.
+ clique.reset();
+ clique.move_next();
+ const unsigned long idx = clique.element();
+
+ for (unsigned long i = 0; i < g.node(idx).number_of_neighbors(); ++i)
+ {
+ const unsigned long n = g.node(idx).neighbor(i).index();
+ if (clique.is_member(n))
+ continue;
+
+ // now loop over all the clique members and make sure they don't all
+ // share an edge with node n
+ bool all_share_edge = true;
+ clique.reset();
+ while (clique.move_next())
+ {
+ if (g.has_edge(clique.element(), n) == false)
+ {
+ all_share_edge = false;
+ break;
+ }
+ }
+
+ if (all_share_edge == true)
+ return false;
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if<is_directed_graph<T>,bool>::type graph_contains_length_one_cycle (
+ const T& graph
+ )
+ {
+ for (unsigned long i = 0; i < graph.number_of_nodes(); ++i)
+ {
+ // make sure none of this guys children are actually itself
+ for (unsigned long n = 0; n < graph.node(i).number_of_children(); ++n)
+ {
+ if (graph.node(i).child(n).index() == i)
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if<is_graph<T>,bool>::type graph_contains_length_one_cycle (
+ const T& graph
+ )
+ {
+ for (unsigned long i = 0; i < graph.number_of_nodes(); ++i)
+ {
+ // make sure none of this guys neighbors are actually itself
+ for (unsigned long n = 0; n < graph.node(i).number_of_neighbors(); ++n)
+ {
+ if (graph.node(i).neighbor(n).index() == i)
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace graph_helpers
+ {
+ struct pair
+ {
+ unsigned long index;
+ unsigned long num_neighbors;
+
+ bool operator< (const pair& p) const { return num_neighbors < p.num_neighbors; }
+ };
+
+ template <
+ typename T,
+ typename S,
+ typename V
+ >
+ void search_graph_for_triangulate (
+ const T& n,
+ S& visited,
+ V& order_visited
+ )
+ {
+ // base case of recursion. stop when we hit a node we have
+ // already visited.
+ if (visited.is_member(n.index()))
+ return;
+
+ // record that we have visited this node
+ order_visited.push_back(n.index());
+ unsigned long temp = n.index();
+ visited.add(temp);
+
+ // we want to visit all the neighbors of this node but do
+ // so by visiting the nodes with the most neighbors first. So
+ // lets make a vector that lists the nodes in the order we
+ // want to visit them
+ std::vector<pair> neighbors;
+ for (unsigned long i = 0; i < n.number_of_neighbors(); ++i)
+ {
+ pair p;
+ p.index = i;
+ p.num_neighbors = n.neighbor(i).number_of_neighbors();
+ neighbors.push_back(p);
+ }
+
+ // now sort the neighbors array so that the neighbors with the
+ // most neighbors come first.
+ std::sort(neighbors.rbegin(), neighbors.rend());
+
+ // now visit all the nodes
+ for (unsigned long i = 0; i < neighbors.size(); ++i)
+ {
+ search_graph_for_triangulate(n.neighbor(neighbors[i].index), visited, order_visited);
+ }
+ }
+ } // end namespace graph_helpers
+
+ template <
+ typename graph_type,
+ typename set_of_sets_of_int
+ >
+ void triangulate_graph_and_find_cliques (
+ graph_type& g,
+ set_of_sets_of_int& cliques
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\tvoid triangulate_graph_and_find_cliques(g, cliques)"
+ << "\n\tInvalid graph"
+ );
+ DLIB_ASSERT(graph_is_connected(g) == true,
+ "\tvoid triangulate_graph_and_find_cliques(g, cliques)"
+ << "\n\tInvalid graph"
+ );
+
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value);
+
+
+ using namespace graph_helpers;
+ using namespace std;
+ typedef typename set_of_sets_of_int::type set_of_int;
+
+ cliques.clear();
+
+ // first we find the node with the most neighbors
+ unsigned long max_index = 0;
+ unsigned long num_neighbors = 0;
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ if (g.node(i).number_of_neighbors() > num_neighbors)
+ {
+ max_index = i;
+ num_neighbors = g.node(i).number_of_neighbors();
+ }
+ }
+
+ // now we do a depth first search of the entire graph starting
+ // with the node we just found. We record the order in which
+ // we visit each node in the vector order_visited.
+ std::vector<unsigned long> order_visited;
+ set_of_int visited;
+ search_graph_for_triangulate(g.node(max_index), visited, order_visited);
+
+ set_of_int clique;
+
+ // now add edges to the graph to make it triangulated
+ while (visited.size() > 0)
+ {
+ // we are going to enumerate over the nodes in the reverse of the
+ // order in which they were visited. So get the last node out.
+ const unsigned long idx = order_visited.back();
+ order_visited.pop_back();
+ visited.destroy(idx);
+
+ // as a start add this node to our current clique
+ unsigned long temp = idx;
+ clique.clear();
+ clique.add(temp);
+
+ // now we want to make a clique that contains node g.node(idx) and
+ // all of its neighbors that are still recorded in the visited set
+ // (except for neighbors that have only one edge).
+ for (unsigned long i = 0; i < g.node(idx).number_of_neighbors(); ++i)
+ {
+ // get the index of the i'th neighbor
+ unsigned long nidx = g.node(idx).neighbor(i).index();
+
+ // add it to the clique if it is still in visited and it isn't
+ // a node with only one neighbor
+ if (visited.is_member(nidx) == true &&
+ g.node(nidx).number_of_neighbors() != 1)
+ {
+ // add edges between this new node and all the nodes
+ // that are already in the clique
+ clique.reset();
+ while (clique.move_next())
+ {
+ if (g.has_edge(nidx, clique.element()) == false)
+ g.add_edge(nidx, clique.element());
+ }
+
+ // now also record that we added this node to the clique
+ clique.add(nidx);
+ }
+ }
+
+ if (cliques.is_member(clique) == false && is_maximal_clique(g,clique) )
+ {
+ cliques.add(clique);
+ }
+
+ // now it is possible that we are missing some cliques of size 2 since
+ // above we didn't add nodes with only one edge to any of our cliques.
+ // Now lets make sure all these nodes are accounted for
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ clique.clear();
+ if (g.node(i).number_of_neighbors() == 1)
+ {
+ unsigned long temp = i;
+ clique.add(temp);
+ temp = g.node(i).neighbor(0).index();
+ clique.add(temp);
+
+ if (cliques.is_member(clique) == false)
+ cliques.add(clique);
+ }
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename join_tree_type
+ >
+ void create_join_tree (
+ const graph_type& g,
+ join_tree_type& join_tree
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\tvoid create_join_tree(g, join_tree)"
+ << "\n\tInvalid graph"
+ );
+ DLIB_ASSERT(graph_is_connected(g) == true,
+ "\tvoid create_join_tree(g, join_tree)"
+ << "\n\tInvalid graph"
+ );
+
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value);
+ COMPILE_TIME_ASSERT(is_graph<join_tree_type>::value);
+
+
+
+ typedef typename join_tree_type::type set_of_int;
+ typedef typename join_tree_type::edge_type set_of_int_edge;
+ typedef typename set<set_of_int>::kernel_1b_c set_of_sets_of_int;
+
+ copy_graph_structure(g, join_tree);
+
+ // don't even bother in this case
+ if (g.number_of_nodes() == 0)
+ return;
+
+ set_of_sets_of_int cliques;
+ set_of_int s;
+
+ triangulate_graph_and_find_cliques(join_tree, cliques);
+
+ join_tree.set_number_of_nodes(cliques.size());
+
+ // copy the cliques into each of the nodes of tree
+ for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
+ {
+ cliques.remove_any(s);
+ s.swap(join_tree.node(i).data);
+ }
+
+ set_of_int_edge e;
+
+ // add all possible edges to the join_tree
+ for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = i+1; j < join_tree.number_of_nodes(); ++j)
+ {
+ set_intersection(
+ join_tree.node(i).data,
+ join_tree.node(j).data,
+ e);
+
+ if (e.size() > 0)
+ {
+ join_tree.add_edge(i,j);
+ edge(join_tree,i,j).swap(e);
+ }
+ }
+ }
+
+ // now we just need to remove the unnecessary edges so that we get a
+ // proper join tree
+ s.clear();
+ set_of_int& good = s; // rename s to something slightly more meaningful
+ // good will contain nodes that have been "approved"
+ unsigned long n = 0;
+ good.add(n);
+
+ std::vector<unsigned long> vtemp;
+
+ while (good.size() < join_tree.number_of_nodes())
+ {
+ // figure out which of the neighbors of nodes in good has the best edge
+ unsigned long best_bad_idx = 0;
+ unsigned long best_good_idx = 0;
+ unsigned long best_overlap = 0;
+ good.reset();
+ while (good.move_next())
+ {
+ // loop over all the neighbors of the current node in good
+ for (unsigned long i = 0; i < join_tree.node(good.element()).number_of_neighbors(); ++i)
+ {
+ const unsigned long idx = join_tree.node(good.element()).neighbor(i).index();
+ if (!good.is_member(idx))
+ {
+ const unsigned long overlap = join_tree.node(good.element()).edge(i).size();
+
+ if (overlap > best_overlap)
+ {
+ best_overlap = overlap;
+ best_bad_idx = idx;
+ best_good_idx = good.element();
+ }
+ }
+ }
+ }
+
+ // now remove all the edges from best_bad_idx to the nodes in good except for the
+ // edge to best_good_idx.
+ for (unsigned long i = 0; i < join_tree.node(best_bad_idx).number_of_neighbors(); ++i)
+ {
+ const unsigned long idx = join_tree.node(best_bad_idx).neighbor(i).index();
+ if (idx != best_good_idx && good.is_member(idx))
+ {
+ vtemp.push_back(idx);
+ }
+ }
+
+ for (unsigned long i = 0; i < vtemp.size(); ++i)
+ join_tree.remove_edge(vtemp[i], best_bad_idx);
+
+ vtemp.clear();
+
+
+ // and finally add this bad index into the good set
+ good.add(best_bad_idx);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace graph_helpers
+ {
+ template <
+ typename T,
+ typename U
+ >
+ bool validate_join_tree (
+ const T& n,
+ U& deads,
+ unsigned long parent = 0xffffffff
+ )
+ /*!
+ this function makes sure that a join tree satisfies the following criterion for paths starting at the given node:
+ - for all valid i and j such that i and j are both < #join_tree.number_of_nodes()
+ - let X be the set of numbers that is contained in both #join_tree.node(i).data
+ and #join_tree.node(j).data
+ - It is the case that all nodes on the unique path between #join_tree.node(i)
+ and #join_tree.node(j) contain the numbers from X in their sets.
+
+ returns true if validation passed and false if there is a problem with the tree
+ !*/
+ {
+ n.data.reset();
+ while (n.data.move_next())
+ {
+ if (deads.is_member(n.data.element()))
+ return false;
+ }
+
+
+ for (unsigned long i = 0; i < n.number_of_neighbors(); ++i)
+ {
+ if (n.neighbor(i).index() == parent)
+ continue;
+
+ // add anything to dead stuff
+ n.data.reset();
+ while (n.data.move_next())
+ {
+ if (n.neighbor(i).data.is_member(n.data.element()) == false)
+ {
+ unsigned long temp = n.data.element();
+ deads.add(temp);
+ }
+ }
+
+ if (validate_join_tree(n.neighbor(i), deads, n.index()) == false)
+ return false;
+
+ // remove this nodes stuff from dead stuff
+ n.data.reset();
+ while (n.data.move_next())
+ {
+ if (n.neighbor(i).data.is_member(n.data.element()) == false)
+ {
+ unsigned long temp = n.data.element();
+ deads.destroy(temp);
+ }
+ }
+ }
+
+ return true;
+ }
+ }
+
+ template <
+ typename graph_type,
+ typename join_tree_type
+ >
+ bool is_join_tree (
+ const graph_type& g,
+ const join_tree_type& join_tree
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\tvoid create_join_tree(g, join_tree)"
+ << "\n\tInvalid graph"
+ );
+ DLIB_ASSERT(graph_is_connected(g) == true,
+ "\tvoid create_join_tree(g, join_tree)"
+ << "\n\tInvalid graph"
+ );
+
+ COMPILE_TIME_ASSERT(is_graph<graph_type>::value || is_directed_graph<graph_type>::value);
+ COMPILE_TIME_ASSERT(is_graph<join_tree_type>::value);
+
+
+ if (graph_contains_undirected_cycle(join_tree))
+ return false;
+
+ if (graph_is_connected(join_tree) == false)
+ return false;
+
+ // verify that the path condition of the join tree is valid
+ for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
+ {
+ typename join_tree_type::type deads;
+ if (graph_helpers::validate_join_tree(join_tree.node(i), deads) == false)
+ return false;
+ }
+
+ typename join_tree_type::edge_type e;
+ typename join_tree_type::edge_type all;
+ // now make sure that the edges contain correct intersections
+ for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
+ {
+ set_union(all,join_tree.node(i).data, all);
+ for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j)
+ {
+ set_intersection(join_tree.node(i).data,
+ join_tree.node(i).neighbor(j).data,
+ e);
+
+ if (!(e == join_tree.node(i).edge(j)))
+ return false;
+ }
+ }
+
+ // and finally check that all the nodes in g show up in the join tree
+ if (all.size() != g.number_of_nodes())
+ return false;
+ all.reset();
+ while (all.move_next())
+ {
+ if (all.element() >= g.number_of_nodes())
+ return false;
+ }
+
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_UTILs_
+
+
diff --git a/ml/dlib/dlib/graph_utils/graph_utils_abstract.h b/ml/dlib/dlib/graph_utils/graph_utils_abstract.h
new file mode 100644
index 000000000..52e170237
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/graph_utils_abstract.h
@@ -0,0 +1,452 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GRAPH_UTILs_ABSTRACT_
+#ifdef DLIB_GRAPH_UTILs_ABSTRACT_
+
+#include "../directed_graph.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename T::edge_type& edge(
+ T& g,
+ unsigned long i,
+ unsigned long j
+ );
+ /*!
+ requires
+ - T is an implementation of graph/graph_kernel_abstract.h
+ - g.has_edge(i,j)
+ ensures
+ - returns a reference to the edge data for the edge connecting nodes i and j
+ (i.e. returns g.node(i).edge(x) such that g.node(i).neighbor(x).index() == j)
+ !*/
+
+ template <
+ typename T
+ >
+ typename const T::edge_type& edge(
+ const T& g,
+ unsigned long i,
+ unsigned long j
+ );
+ /*!
+ requires
+ - T is an implementation of graph/graph_kernel_abstract.h
+ - g.has_edge(i,j)
+ ensures
+ - returns a const reference to the edge data for the edge connecting nodes i and j
+ (i.e. returns g.node(i).edge(x) such that g.node(i).neighbor(x).index() == j)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename T::edge_type& edge(
+ T& g,
+ unsigned long parent_idx,
+ unsigned long child_idx
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - g.has_edge(parent_idx,child_idx)
+ ensures
+ - returns a reference to the edge data for the directed edge connecting parent
+ node g.node(parent_idx) to child node g.node(child_idx).
+ !*/
+
+ template <
+ typename T
+ >
+ typename const T::edge_type& edge(
+ const T& g,
+ unsigned long parent_idx,
+ unsigned long child_idx
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - g.has_edge(parent_idx,child_idx)
+ ensures
+ - returns a const reference to the edge data for the directed edge connecting
+ parent node g.node(parent_idx) to child node g.node(child_idx).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_has_symmetric_edges (
+ const T& graph
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ ensures
+ - if (All nodes have either 0 edges between them or 2 edges between them.
+ That is, if there is an edge pointing from node A to node B then there is
+ also an edge from B to A) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_contains_directed_cycle (
+ const T& graph
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ ensures
+ - if (there is a directed cycle in the given graph) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_contains_undirected_cycle (
+ const T& graph
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ T is an implementation of graph/graph_kernel_abstract.h
+ ensures
+ - if (there is an undirected cycle in the given graph) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_contains_length_one_cycle (
+ const T& graph
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ T is an implementation of graph/graph_kernel_abstract.h
+ ensures
+ - if (it is the case that graph.has_edge(i,i) == true for some i) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename src_type,
+ typename dest_type
+ >
+ void copy_graph_structure (
+ const src_type& src,
+ dest_type& dest
+ );
+ /*!
+ requires
+ - src_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ src_type is an implementation of graph/graph_kernel_abstract.h
+ - dest_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ dest_type is an implementation of graph/graph_kernel_abstract.h
+ - dest_type is not a directed_graph when src_type is a graph
+ ensures
+ - this function copies the graph structure from src into dest
+ - #dest.number_of_nodes() == src.number_of_nodes()
+ - for all valid i: #dest.node(i).item has an initial value for its type
+ - for all valid i and j:
+ - if (src.has_edge(i,j) == true) then
+ - #dest.has_edge(i,j) == true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename src_type,
+ typename dest_type
+ >
+ void copy_graph (
+ const src_type& src,
+ dest_type& dest
+ );
+ /*!
+ requires
+ - src_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ src_type is an implementation of graph/graph_kernel_abstract.h
+ - dest_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ dest_type is an implementation of graph/graph_kernel_abstract.h
+ - src_type and dest_type are both the same kind of graph. That is, they
+ are either both directed or both undirected.
+ - the node and edge data in the graphs are copyable via operator=().
+ ensures
+ - #dest is a complete duplicate of src.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename directed_graph_type,
+ typename graph_type
+ >
+ void create_moral_graph (
+ const directed_graph_type& g,
+ graph_type& moral_graph
+ );
+ /*!
+ requires
+ - directed_graph_type is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ - graph_type is an implementation of graph/graph_kernel_abstract.h
+ - graph_contains_directed_cycle(g) == false
+ ensures
+ - #moral_graph == the moralized version of the directed graph g
+ - #moral_graph.number_of_nodes() == g.number_of_nodes()
+ - for all valid i and j:
+ - if (g.has_edge(i,j) == true) then
+ - #moral_graph.has_edge(i,j) == true
+ (i.e. all the edges that are in g are also in moral_graph)
+ - for all valid i:
+ - for all pairs p1 and p2 such that p1 != p2 and g.node(p1) and g.node(p2) are both
+ parents of node g.node(i):
+ - #moral_graph.has_edge(p1,p2) == true
+ (i.e. all the parents of a node are connected in the moral graph)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename S
+ >
+ void find_connected_nodes (
+ const T& n,
+ S& visited
+ );
+ /*!
+ requires
+ - T is a node_type from an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ T is a node_type from an implementation of graph/graph_kernel_abstract.h
+ - S is an implementation of set/set_kernel_abstract.h
+ ensures
+ - let G be the graph that contains node n
+ - #visited.is_member(n.index()) == true
+ - for all i such that there is an undirected path from n to G.node(i):
+ - #visited.is_member(i) == true
+ - for all i such that visited.is_member(i):
+ - #visited.is_member(i) == true
+ (i.e. this function doesn't remove anything from visited. So if
+ it contains stuff when you call this function then it will still
+ contain those things once the function ends)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool graph_is_connected (
+ const T& g
+ );
+ /*!
+ requires
+ - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ T is an implementation of graph/graph_kernel_abstract.h
+ ensures
+ - every node in g has an undirected path to every other node in g.
+ I.e. g is a connected graph
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename sets_of_int
+ >
+ bool is_clique (
+ const graph_type& g,
+ const sets_of_int& clique
+ );
+ /*!
+ requires
+ - graph_type is an implementation of graph/graph_kernel_abstract.h
+ - sets_of_int is an implementation of set/set_kernel_abstract.h
+ and it contains unsigned long objects.
+ - graph_contains_length_one_cycle(g) == false
+ - for all x such that clique.is_member(x):
+ - x < g.number_of_nodes()
+ ensures
+ - if (it is true that for all i and j such that clique.is_member(i) and
+ clique.is_member(j) then g.has_edge(i,j) == true) then
+ - returns true
+ - else
+ - returns false
+ - if (clique.size() == 0) then
+ - returns true
+ (this is just a special case of the above condition)
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename sets_of_int
+ >
+ bool is_maximal_clique (
+ const graph_type& g,
+ const sets_of_int& clique
+ );
+ /*!
+ requires
+ - graph_type is an implementation of graph/graph_kernel_abstract.h
+ - sets_of_int is an implementation of set/set_kernel_abstract.h
+ and it contains unsigned long objects.
+ - graph_contains_length_one_cycle(g) == false
+ - for all x such that clique.is_member(x):
+ - x < g.number_of_nodes()
+ - is_clique(g,clique) == true
+ ensures
+ - if (there is no x such that clique.is_member(x) == false
+ and g.has_edge(i,x) for all i such that cliques.is_member(i)) then
+ - returns true
+ - else
+ - returns false
+ - if (clique.size() == 0) then
+ - returns true
+ (this is just a special case of the above condition)
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename set_of_sets_of_int
+ >
+ void triangulate_graph_and_find_cliques (
+ graph_type& g,
+ set_of_sets_of_int& cliques
+ );
+ /*!
+ requires
+ - graph_type is an implementation of graph/graph_kernel_abstract.h
+ - set_of_sets_of_int is an implementation of set/set_kernel_abstract.h
+ and it contains another set object which is comparable by operator< and
+ itself contains unsigned long objects.
+ (e.g. set<set<unsigned long>::compare_1a>::kernel_1a)
+ - graph_contains_length_one_cycle(g) == false
+ - graph_is_connected(g) == true
+ ensures
+ - #g.number_of_nodes() == g.number_of_nodes()
+ - all this function does to g is add edges to it until g becomes a
+ chordal graph where a chordal graph is a graph where each cycle
+ in the graph of 4 or more nodes has an edge joining two nodes
+ that are not adjacent in the cycle.
+ - #cliques.size() == the number of maximal cliques in the graph #g
+ - for all valid sets S such that #cliques.is_member(S):
+ - for all valid integers i and j such that S.is_member(i) == true
+ and S.is_member(j) == true and i != j:
+ - #g.has_edge(i,j) == true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename join_tree_type
+ >
+ bool is_join_tree (
+ const graph_type& g,
+ const join_tree_type& join_tree
+ );
+ /*!
+ requires
+ - graph_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or
+ graph_type is an implementation of graph/graph_kernel_abstract.h
+ - join_tree_type is an implementation of graph/graph_kernel_abstract.h
+ - join_tree_type::type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - graph_contains_length_one_cycle(g) == false
+ - graph_is_connected(g) == true
+ ensures
+ - if (join_tree is a valid join tree of graph g. That is, join_tree is a
+ tree decomposition of g) then
+ - returns true
+ - else
+ - returns false
+
+ - a join tree of graph g is defined as follows:
+ - graph_contains_undirected_cycle(join_tree) == false
+ - graph_is_connected(join_tree) == true
+ - for all valid i:
+ - join_tree.node(i).item == a non-empty set containing node indexes
+ from g. That is, this set contains all the nodes from g that are
+ in this cluster in the join tree
+ - for all valid i and j such that i and j are both < join_tree.number_of_nodes()
+ - let X be the set of numbers that is contained in both join_tree.node(i).item
+ and join_tree.node(j).item
+ - It is the case that all nodes on the unique path between join_tree.node(i)
+ and join_tree.node(j) contain the numbers from X in their sets.
+ - edge(join_tree,i,j) == a set containing the intersection of
+ join_tree.node(i).item and join_tree.node(j).item
+ - the node index for every node in g appears in some node in join_tree at
+ least once.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type,
+ typename join_tree_type
+ >
+ void create_join_tree (
+ const graph_type& g,
+ join_tree_type& join_tree
+ );
+ /*!
+ requires
+ - graph_type is an implementation of graph/graph_kernel_abstract.h
+ - join_tree_type is an implementation of graph/graph_kernel_abstract.h
+ - join_tree_type::type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and
+ this set type contains unsigned long objects.
+ - graph_contains_length_one_cycle(g) == false
+ - graph_is_connected(g) == true
+ ensures
+ - #is_join_tree(g, join_tree) == true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_UTILs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/graph_utils/ordered_sample_pair.h b/ml/dlib/dlib/graph_utils/ordered_sample_pair.h
new file mode 100644
index 000000000..7d510e122
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/ordered_sample_pair.h
@@ -0,0 +1,125 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ORDERED_SAMPLE_PaIR_Hh_
+#define DLIB_ORDERED_SAMPLE_PaIR_Hh_
+
+#include "ordered_sample_pair_abstract.h"
+#include <limits>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class ordered_sample_pair
+ {
+ public:
+ ordered_sample_pair(
+ ) :
+ _index1(0),
+ _index2(0)
+ {
+ _distance = 1;
+ }
+
+ ordered_sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2
+ )
+ {
+ _distance = 1;
+ _index1 = idx1;
+ _index2 = idx2;
+ }
+
+ ordered_sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2,
+ const double dist
+ )
+ {
+ _distance = dist;
+ _index1 = idx1;
+ _index2 = idx2;
+ }
+
+ const unsigned long& index1 (
+ ) const { return _index1; }
+
+ const unsigned long& index2 (
+ ) const { return _index2; }
+
+ const double& distance (
+ ) const { return _distance; }
+
+ private:
+ unsigned long _index1;
+ unsigned long _index2;
+ double _distance;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator == (
+ const ordered_sample_pair& a,
+ const ordered_sample_pair& b
+ )
+ {
+ return a.index1() == b.index1() && a.index2() == b.index2();
+ }
+
+ inline bool operator != (
+ const ordered_sample_pair& a,
+ const ordered_sample_pair& b
+ )
+ {
+ return !(a == b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const ordered_sample_pair& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.index1(),out);
+ serialize(item.index2(),out);
+ serialize(item.distance(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type ordered_sample_pair");
+ }
+ }
+
+ inline void deserialize (
+ ordered_sample_pair& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long idx1, idx2;
+ double dist;
+
+ deserialize(idx1,in);
+ deserialize(idx2,in);
+ deserialize(dist,in);
+ item = ordered_sample_pair(idx1, idx2, dist);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type ordered_sample_pair");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ORDERED_SAMPLE_PaIR_Hh_
+
diff --git a/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h b/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h
new file mode 100644
index 000000000..9d150e257
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h
@@ -0,0 +1,128 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_
+#ifdef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_
+
+#include <limits>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class ordered_sample_pair
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is intended to represent an edge in a directed graph which has
+ data samples at its vertices. So it contains two integers (index1 and
+ index2) which represent the identifying indices of the samples at the ends
+ of an edge.
+
+ This object also contains a double which can be used for any purpose.
+ !*/
+
+ public:
+ ordered_sample_pair(
+ );
+ /*!
+ ensures
+ - #index1() == 0
+ - #index2() == 0
+ - #distance() == 1
+ !*/
+
+ ordered_sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2
+ );
+ /*!
+ ensures
+ - #index1() == idx1
+ - #index2() == idx2
+ - #distance() == 1
+ !*/
+
+ ordered_sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2,
+ const double dist
+ );
+ /*!
+ ensures
+ - #index1() == idx1
+ - #index2() == idx2
+ - #distance() == dist
+ !*/
+
+ const unsigned long& index1 (
+ ) const;
+ /*!
+ ensures
+ - returns the first index value stored in this object
+ !*/
+
+ const unsigned long& index2 (
+ ) const;
+ /*!
+ ensures
+ - returns the second index value stored in this object
+ !*/
+
+ const double& distance (
+ ) const;
+ /*!
+ ensures
+ - returns the floating point number stored in this object
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator == (
+ const ordered_sample_pair& a,
+ const ordered_sample_pair& b
+ );
+ /*!
+ ensures
+ - returns a.index1() == b.index1() && a.index2() == b.index2();
+ I.e. returns true if a and b both represent the same pair and false otherwise.
+ Note that the distance field is not involved in this comparison.
+ !*/
+
+ inline bool operator != (
+ const ordered_sample_pair& a,
+ const ordered_sample_pair& b
+ );
+ /*!
+ ensures
+ - returns !(a == b)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const ordered_sample_pair& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ inline void deserialize (
+ ordered_sample_pair& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils/sample_pair.h b/ml/dlib/dlib/graph_utils/sample_pair.h
new file mode 100644
index 000000000..88ad458f6
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/sample_pair.h
@@ -0,0 +1,179 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SAMPLE_PaIR_Hh_
+#define DLIB_SAMPLE_PaIR_Hh_
+
+#include "sample_pair_abstract.h"
+#include <limits>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class sample_pair
+ {
+ public:
+ sample_pair(
+ ) :
+ _index1(0),
+ _index2(0)
+ {
+ _distance = 1;
+ }
+
+ sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2
+ )
+ {
+ _distance = 1;
+ if (idx1 < idx2)
+ {
+ _index1 = idx1;
+ _index2 = idx2;
+ }
+ else
+ {
+ _index1 = idx2;
+ _index2 = idx1;
+ }
+ }
+
+ sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2,
+ const double dist
+ )
+ {
+ _distance = dist;
+ if (idx1 < idx2)
+ {
+ _index1 = idx1;
+ _index2 = idx2;
+ }
+ else
+ {
+ _index1 = idx2;
+ _index2 = idx1;
+ }
+ }
+
+ const unsigned long& index1 (
+ ) const { return _index1; }
+
+ const unsigned long& index2 (
+ ) const { return _index2; }
+
+ const double& distance (
+ ) const { return _distance; }
+
+ private:
+ unsigned long _index1;
+ unsigned long _index2;
+ double _distance;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ inline bool order_by_index (
+ const T& a,
+ const T& b
+ )
+ {
+ return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2());
+ }
+
+ template <typename T>
+ inline bool order_by_distance (
+ const T& a,
+ const T& b
+ )
+ {
+ return a.distance() < b.distance();
+ }
+
+ template <typename T>
+ inline bool order_by_descending_distance (
+ const T& a,
+ const T& b
+ )
+ {
+ return a.distance() > b.distance();
+ }
+
+ template <typename T>
+ bool order_by_distance_and_index (
+ const T& a,
+ const T& b
+ )
+ {
+ return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator == (
+ const sample_pair& a,
+ const sample_pair& b
+ )
+ {
+ return a.index1() == b.index1() && a.index2() == b.index2();
+ }
+
+ inline bool operator != (
+ const sample_pair& a,
+ const sample_pair& b
+ )
+ {
+ return !(a == b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const sample_pair& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.index1(),out);
+ serialize(item.index2(),out);
+ serialize(item.distance(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type sample_pair");
+ }
+ }
+
+ inline void deserialize (
+ sample_pair& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long idx1, idx2;
+ double dist;
+
+ deserialize(idx1,in);
+ deserialize(idx2,in);
+ deserialize(dist,in);
+ item = sample_pair(idx1, idx2, dist);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type sample_pair");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SAMPLE_PaIR_Hh_
+
diff --git a/ml/dlib/dlib/graph_utils/sample_pair_abstract.h b/ml/dlib/dlib/graph_utils/sample_pair_abstract.h
new file mode 100644
index 000000000..3306899e3
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils/sample_pair_abstract.h
@@ -0,0 +1,192 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SAMPLE_PaIR_ABSTRACT_Hh_
+#ifdef DLIB_SAMPLE_PaIR_ABSTRACT_Hh_
+
+#include <limits>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class sample_pair
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is intended to represent an edge in an undirected graph
+ which has data samples at its vertices. So it contains two integers
+ (index1 and index2) which represent the identifying indices of
+ the samples at the ends of an edge. Note that this object enforces
+ the constraint that index1 <= index2. This has the effect of
+ making the edges undirected since a sample_pair is incapable
+ of representing a single edge in more than one way. That is,
+ sample_pair(i,j) == sample_pair(j,i) for any value of i and j.
+
+ This object also contains a double which can be used for any purpose.
+ !*/
+
+ public:
+ sample_pair(
+ );
+ /*!
+ ensures
+ - #index1() == 0
+ - #index2() == 0
+ - #distance() == 1
+ !*/
+
+ sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2
+ );
+ /*!
+ ensures
+ - #index1() == min(idx1,idx2)
+ - #index2() == max(idx1,idx2)
+ - #distance() == 1
+ !*/
+
+ sample_pair (
+ const unsigned long idx1,
+ const unsigned long idx2,
+ const double dist
+ );
+ /*!
+ ensures
+ - #index1() == min(idx1,idx2)
+ - #index2() == max(idx1,idx2)
+ - #distance() == dist
+ !*/
+
+ const unsigned long& index1 (
+ ) const;
+ /*!
+ ensures
+ - returns the first index value stored in this object
+ !*/
+
+ const unsigned long& index2 (
+ ) const;
+ /*!
+ ensures
+ - returns the second index value stored in this object
+ !*/
+
+ const double& distance (
+ ) const;
+ /*!
+ ensures
+ - returns the floating point number stored in this object
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool order_by_index (
+ const T& a,
+ const T& b
+ ) { return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2()); }
+ /*!
+ requires
+ - T is a type with an interface compatible with sample_pair.
+ ensures
+ - provides a total ordering of sample_pair objects that will cause pairs that are
+ equal to be adjacent when sorted. So for example, this function can be used
+ with std::sort() to first sort a sequence of sample_pair objects and then
+ find duplicate edges.
+ !*/
+
+ template <typename T>
+ bool order_by_distance (
+ const T& a,
+ const T& b
+ ) { return a.distance() < b.distance(); }
+ /*!
+ requires
+ - T is a type with an interface compatible with sample_pair.
+ ensures
+ - provides a total ordering of sample_pair objects that causes pairs with
+ smallest distance to be the first in a sorted list. This function can be
+ used with std::sort().
+ !*/
+
+ template <typename T>
+ bool order_by_descending_distance (
+ const T& a,
+ const T& b
+ ) { return a.distance() > b.distance(); }
+ /*!
+ requires
+ - T is a type with an interface compatible with sample_pair.
+ ensures
+ - provides a total ordering of sample_pair objects that causes pairs with
+ largest distance to be the first in a sorted list. This function can be
+ used with std::sort().
+ !*/
+
+ template <typename T>
+ bool order_by_distance_and_index (
+ const T& a,
+ const T& b
+ ) { return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b)); }
+ /*!
+ requires
+ - T is a type with an interface compatible with sample_pair.
+ ensures
+ - provides a total ordering of sample_pair objects that causes pairs with
+ smallest distance to be the first in a sorted list but also orders samples
+ with equal distances according to order_by_index(). This function can be
+ used with std::sort().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator == (
+ const sample_pair& a,
+ const sample_pair& b
+ );
+ /*!
+ ensures
+ - returns a.index1() == b.index1() && a.index2() == b.index2();
+ I.e. returns true if a and b both represent the same pair and false otherwise.
+ Note that the distance field is not involved in this comparison.
+ !*/
+
+ inline bool operator != (
+ const sample_pair& a,
+ const sample_pair& b
+ );
+ /*!
+ ensures
+ - returns !(a == b)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const sample_pair& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ inline void deserialize (
+ sample_pair& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SAMPLE_PaIR_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/graph_utils_threaded.h b/ml/dlib/dlib/graph_utils_threaded.h
new file mode 100644
index 000000000..c9938fd80
--- /dev/null
+++ b/ml/dlib/dlib/graph_utils_threaded.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GRAPH_UTILs_THREADED_H_
+#define DLIB_GRAPH_UTILs_THREADED_H_
+
+#include "graph_utils.h"
+#include "graph_utils/find_k_nearest_neighbors_lsh.h"
+
+#endif // DLIB_GRAPH_UTILs_THREADED_H_
+
+
+
diff --git a/ml/dlib/dlib/gui_core.h b/ml/dlib/dlib/gui_core.h
new file mode 100644
index 000000000..6ba54b11c
--- /dev/null
+++ b/ml/dlib/dlib/gui_core.h
@@ -0,0 +1,20 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORe_
+#define DLIB_GUI_CORe_
+
+
+#include "platform.h"
+
+
+
+#ifdef WIN32
+#include "gui_core/windows.h"
+#else
+#include "gui_core/xlib.h"
+#endif
+
+
+
+#endif // DLIB_GUI_CORe_
+
diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp b/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp
new file mode 100644
index 000000000..2a6efa4c9
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp
@@ -0,0 +1,2204 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEL_1_CPp_
+#define DLIB_GUI_CORE_KERNEL_1_CPp_
+#include "../platform.h"
+
+#ifdef WIN32
+
+#include "gui_core_kernel_1.h"
+
+// tell visual studio to link to the libraries we need if we are
+// in fact using visual studio
+#ifdef _MSC_VER
+#pragma comment (lib, "gdi32.lib")
+#pragma comment (lib, "comctl32.lib")
+#pragma comment (lib, "user32.lib")
+#pragma comment (lib, "imm32.lib")
+#endif
+
+#include <cmath>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "../threads.h"
+#include "../assert.h"
+#include "../queue.h"
+#include "../sync_extension.h"
+#include "../queue.h"
+#include "../logger.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gui_core_kernel_1_globals
+ {
+
+
+ struct user_event_type
+ {
+ HWND w;
+ void* p;
+ int i;
+ };
+
+ typedef sync_extension<binary_search_tree<HWND,base_window*>::kernel_1a>::kernel_1a window_table_type;
+ typedef sync_extension<queue<user_event_type,memory_manager<char>::kernel_1b>::kernel_2a_c>::kernel_1a queue_of_user_events;
+
+
+ enum USER_OFFSETS
+ {
+ CREATE_WINDOW,
+ DESTROY_WINDOW,
+ SET_ACTIVE_WINDOW,
+ QUIT_EVENT_HANDLER_THREAD,
+ USER_EVENTS_READY,
+ CALL_MOVE_WINDOW,
+ SHOW_WINDOW_SHOW,
+ SHOW_WINDOW_HIDE,
+ CALL_SET_WINDOW_TITLE
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ const std::shared_ptr<dlib::mutex>& global_mutex()
+ {
+ static std::shared_ptr<dlib::mutex> m(new dlib::mutex);
+ return m;
+ }
+
+ class event_handler_thread : public threaded_object
+ {
+ public:
+
+ enum et_state
+ {
+ uninitialized,
+ initialized,
+ failure_to_init
+ };
+
+ et_state status;
+
+ queue_of_user_events user_events;
+ queue_of_user_events user_events_temp;
+ logger dlog;
+
+ HINSTANCE hInstance;
+ HWND helper_window;
+ const TCHAR* window_class_name;
+
+ bool quit_windows_loop;
+ bool set_window_title_done;
+ std::wstring window_title;
+ bool move_window_done;
+ HWND move_window_hwnd;
+ int move_window_width;
+ int move_window_height;
+ int move_window_x;
+ int move_window_y;
+ bool request_new_window;
+ DWORD dwStyle;
+ HWND new_window;
+ bool in_ime_composition;
+ bool event_thread_started;
+ // the window_table.get_mutex() mutex locks the above 11 variables
+
+
+ // this variable holds a mapping from window handles to the base_window
+ // objects which represent them. Note that this objects mutex is always locked
+ // when inside the event loop.
+ window_table_type window_table;
+ rsignaler window_close_signaler;
+ rsignaler et_signaler;
+
+ // note that this is the thread that will perform all the event
+ // processing.
+ thread_id_type event_thread_id;
+
+ std::shared_ptr<dlib::mutex> reference_to_global_mutex;
+
+ event_handler_thread(
+ ) :
+ dlog("dlib.gui_core"),
+ hInstance(0),
+ helper_window(0),
+ window_class_name(TEXT ("w3049u6qc2d94thw9m34f4we0gvwa3-tgkser0-b9gm 05")),
+ quit_windows_loop(false),
+ set_window_title_done(true),
+ move_window_done(true),
+ move_window_hwnd(0),
+ move_window_width(0),
+ move_window_height(0),
+ move_window_x(0),
+ move_window_y(0),
+ request_new_window(false),
+ dwStyle(0),
+ new_window(0),
+ in_ime_composition(false),
+ event_thread_started(false),
+ window_close_signaler(window_table.get_mutex()),
+ et_signaler(window_table.get_mutex()),
+ reference_to_global_mutex(global_mutex())
+ {
+ status = uninitialized;
+ }
+
+ void start_event_thread (
+ )
+ /*!
+ we can't call this function from this objects constructor because
+ starting the event thread in windows involves sending messages to the
+ WndProc() and that requires this object to be fully constructed.
+ !*/
+ {
+
+ if (event_thread_started == false)
+ {
+ auto_mutex M(window_table.get_mutex());
+ if (event_thread_started == false)
+ {
+ event_thread_started = true;
+ // start up the event handler thread
+ start();
+
+ // wait for the event thread to get up and running
+ while (status == uninitialized)
+ et_signaler.wait();
+
+ if (status == failure_to_init)
+ throw gui_error("Failed to start event thread");
+ }
+ }
+ }
+
+ ~event_handler_thread ()
+ {
+ using namespace gui_core_kernel_1_globals;
+
+ if (is_alive())
+ {
+ if (PostMessage(helper_window,WM_USER+QUIT_EVENT_HANDLER_THREAD,0,0)==0)
+ {
+ dlog << LWARN << "Unable to schedule function for execution in event handling thread.";
+ // No point calling wait() here since the thread isn't going to
+ // terminate gracefully in this case. So we just let the program
+ // end as it will and hope for the best.
+ }
+ else
+ {
+ // wait for the event handler thread to terminate.
+ wait();
+ }
+ }
+
+ }
+
+ private:
+
+ void thread (
+ )
+ {
+ event_thread_id = get_thread_id();
+
+ hInstance = GetModuleHandle(NULL);
+ if (hInstance == NULL)
+ {
+ dlog << LFATAL << "Error gathering needed resources";
+
+ // signal that an error has occurred
+ window_table.get_mutex().lock();
+ status = failure_to_init;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+ return;
+ }
+
+ // register the main window class
+ WNDCLASS wndclass ;
+
+ wndclass.style = CS_DBLCLKS;
+ wndclass.lpfnWndProc = dlib::gui_core_kernel_1_globals::WndProc ;
+ wndclass.cbClsExtra = 0 ;
+ wndclass.cbWndExtra = 0 ;
+ wndclass.hInstance = hInstance ;
+ wndclass.hIcon = LoadIcon (NULL, IDI_APPLICATION) ;
+ wndclass.hCursor = LoadCursor (NULL, IDC_ARROW) ;
+ wndclass.hbrBackground = 0;
+ wndclass.lpszMenuName = NULL ;
+ wndclass.lpszClassName = window_class_name ;
+
+ if (!RegisterClass (&wndclass))
+ {
+ dlog << LFATAL << "Error registering window class";
+
+ // signal that an error has occurred
+ window_table.get_mutex().lock();
+ status = failure_to_init;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+ return;
+ }
+
+
+ // make the helper window that is used to trigger events in the
+ // event handler loop from other threads
+ TCHAR nothing[] = TEXT("");
+ helper_window = CreateWindow(window_class_name,nothing,WS_DISABLED,0,0,0,0,HWND_MESSAGE,NULL,hInstance,NULL);
+ if (helper_window == NULL)
+ {
+ dlog << LFATAL << "Error gathering needed resources";
+
+ // signal that an error has occurred
+ window_table.get_mutex().lock();
+ status = failure_to_init;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+ return;
+ }
+
+ // signal that the event thread is now up and running
+ window_table.get_mutex().lock();
+ status = initialized;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+
+ // start the event handler loop.
+ /*
+ A note about this quit_windows_loop thing. If the user is holding
+ the mouse button down on the title bar of a window it will cause
+ the PostQuitMessage() function to be ignored!! This extra bool
+ is a work around to prevent that from happening.
+ */
+ MSG msg;
+ while (GetMessage (&msg, NULL, 0, 0) &&
+ quit_windows_loop == false)
+ {
+ TranslateMessage (&msg) ;
+ DispatchMessage (&msg) ;
+ }
+ }
+ };
+
+ // Do all this just to make sure global_mutex() is initialized at program start
+ // and thus hopefully before any threads have the chance to startup and call
+ // global_data() concurrently.
+ struct call_global_mutex { call_global_mutex() { global_mutex(); } };
+ static call_global_mutex call_global_mutex_instance;
+
+ // Note that we need to use dlib::shared_ptr_thread_safe here rather than
+ // std::shared_ptr. This is because the destructor of event_handler_thread
+ // triggers an event in the main event loop telling it to stop and that event loop
+ // holds a shared pointer to the event_handler_thread. So what can happen is
+ // global_data()'s shared pointer gets destructed because the program is
+ // terminating, which causes the event_handler_thread destructor to run, which
+ // eventually causes the event loop to ask global_data() for a handle to the event
+ // thread. This is bad for std::shared_ptr since (at least as of 2017) most
+ // implementations of std::shared_ptr decrement the reference count to 0 before
+ // invoking the event_handler_thread's destructor and then when the event thread
+ // calls global_data() it increments the counter again, then decrements it back to
+ // 0, triggering a double deletion, when the event handler routine finally
+ // finishes.
+ //
+ // dlib::shared_ptr_thread_safe doesn't have this problem. An alternative would be
+ // to somehow avoid this kind of self reference. But it's not obvious how to do
+ // that given the limitations of the Win32 event WndProc() structure imposed by
+ // windows. So in any case, we just use the old dlib::shared_ptr_thread_safe to
+ // avoid this problem.
+ const shared_ptr_thread_safe<event_handler_thread>& global_data()
+ {
+ auto_mutex M(*global_mutex());
+ static shared_ptr_thread_safe<event_handler_thread> p;
+ if (p.get() == 0)
+ {
+ p.reset(new event_handler_thread());
+ M.unlock();
+ p->start_event_thread();
+ }
+ return p;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ struct ebh_param
+ {
+ std::string text;
+ std::string title;
+ };
+
+ static void error_box_helper(void* param)
+ {
+ ebh_param& p = *static_cast<ebh_param*>(param);
+#ifdef UNICODE
+ MessageBox (NULL, convert_mbstring_to_wstring(p.text).c_str(),
+ convert_mbstring_to_wstring(p.title).c_str(), MB_OK|MB_ICONERROR|MB_SYSTEMMODAL
+ );
+#else
+ MessageBox (NULL, p.text.c_str(),
+ p.title.c_str(), MB_OK|MB_ICONERROR|MB_SYSTEMMODAL
+ );
+#endif
+ delete &p;
+ }
+
+ static void error_box (
+ const char* title,
+ const char* text,
+ bool nonblocking = false
+ )
+ {
+ try
+ {
+ if (nonblocking)
+ {
+ ebh_param* param = new ebh_param;
+ param->text = text;
+ param->title = title;
+ dlib::create_new_thread(error_box_helper,param);
+ }
+ else
+ {
+#ifdef UNICODE
+ MessageBox (NULL, convert_mbstring_to_wstring(text).c_str(),
+ convert_mbstring_to_wstring(title).c_str(),
+ MB_OK|MB_ICONERROR|MB_SYSTEMMODAL
+ );
+#else
+ MessageBox (NULL, text,
+ title, MB_OK|MB_ICONERROR|MB_SYSTEMMODAL
+ );
+#endif
+ }
+ }
+ catch (...)
+ {
+ // we are totally screwed if this happens so just quit
+ exit(0);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ static bool map_keys (
+ unsigned long keycode,
+ bool shift,
+ bool caps,
+ unsigned long& result,
+ bool& is_printable
+ )
+ /*!
+ requires
+ - if (shift was down for this key) then
+ - shift == true
+ - if (caps lock was on for this key) then
+ - caps == true
+ - keycode == the keycode from windows that we are to process
+ - keycode < keyboard_keys_size
+ ensures
+ - if (this key should be ignored) then
+ - returns false
+ - else
+ - returns true
+ - #is_printable == true if result is a printable ascii character
+ - #result == the keycode converted into the proper number to tbe
+ returned by the event handler.
+ !*/
+ {
+ is_printable = true;
+
+ if (keycode <= '9' && keycode >= '0')
+ {
+ result = keycode;
+ if (shift)
+ {
+ switch (result)
+ {
+ case '0': result = ')'; break;
+ case '1': result = '!'; break;
+ case '2': result = '@'; break;
+ case '3': result = '#'; break;
+ case '4': result = '$'; break;
+ case '5': result = '%'; break;
+ case '6': result = '^'; break;
+ case '7': result = '&'; break;
+ case '8': result = '*'; break;
+ case '9': result = '('; break;
+ }
+ }
+ }
+ else if (keycode <= 'Z' && keycode >= 'A')
+ {
+ result = keycode;
+
+ // make the result lower case if we need to.
+ if ((shift && caps) || (!caps && !shift))
+ result = result - 'A' + 'a';
+ }
+ else
+ {
+ switch (keycode)
+ {
+ case VK_BACK:
+ is_printable = false;
+ result = base_window::KEY_BACKSPACE;
+ break;
+
+ case VK_SHIFT:
+ is_printable = false;
+ result = base_window::KEY_SHIFT;
+ break;
+
+ case VK_CONTROL:
+ is_printable = false;
+ result = base_window::KEY_CTRL;
+ break;
+
+ case VK_MENU:
+ is_printable = false;
+ result = base_window::KEY_ALT;
+ break;
+
+ case VK_PAUSE:
+ is_printable = false;
+ result = base_window::KEY_PAUSE;
+ break;
+
+ case VK_CAPITAL:
+ is_printable = false;
+ result = base_window::KEY_CAPS_LOCK;
+ break;
+
+ case VK_ESCAPE:
+ is_printable = false;
+ result = base_window::KEY_ESC;
+ break;
+
+ case VK_PRIOR:
+ is_printable = false;
+ result = base_window::KEY_PAGE_UP;
+ break;
+
+ case VK_NEXT:
+ is_printable = false;
+ result = base_window::KEY_PAGE_DOWN;
+ break;
+
+ case VK_END:
+ is_printable = false;
+ result = base_window::KEY_END;
+ break;
+
+ case VK_HOME:
+ is_printable = false;
+ result = base_window::KEY_HOME;
+ break;
+
+ case VK_LEFT:
+ is_printable = false;
+ result = base_window::KEY_LEFT;
+ break;
+
+ case VK_RIGHT:
+ is_printable = false;
+ result = base_window::KEY_RIGHT;
+ break;
+
+ case VK_UP:
+ is_printable = false;
+ result = base_window::KEY_UP;
+ break;
+
+ case VK_DOWN:
+ is_printable = false;
+ result = base_window::KEY_DOWN;
+ break;
+
+ case VK_INSERT:
+ is_printable = false;
+ result = base_window::KEY_INSERT;
+ break;
+
+ case VK_DELETE:
+ is_printable = false;
+ result = base_window::KEY_DELETE;
+ break;
+
+ case 0x91:
+ is_printable = false;
+ result = base_window::KEY_SCROLL_LOCK;
+ break;
+
+ case VK_F1:
+ is_printable = false;
+ result = base_window::KEY_F1;
+ break;
+
+ case VK_F2:
+ is_printable = false;
+ result = base_window::KEY_F2;
+ break;
+
+ case VK_F3:
+ is_printable = false;
+ result = base_window::KEY_F3;
+ break;
+
+ case VK_F4:
+ is_printable = false;
+ result = base_window::KEY_F4;
+ break;
+
+ case VK_F5:
+ is_printable = false;
+ result = base_window::KEY_F5;
+ break;
+
+ case VK_F6:
+ is_printable = false;
+ result = base_window::KEY_F6;
+ break;
+
+ case VK_F7:
+ is_printable = false;
+ result = base_window::KEY_F7;
+ break;
+
+ case VK_F8:
+ is_printable = false;
+ result = base_window::KEY_F8;
+ break;
+
+ case VK_F9:
+ is_printable = false;
+ result = base_window::KEY_F9;
+ break;
+
+ case VK_F10:
+ is_printable = false;
+ result = base_window::KEY_F10;
+ break;
+
+ case VK_F11:
+ is_printable = false;
+ result = base_window::KEY_F11;
+ break;
+
+ case VK_F12:
+ is_printable = false;
+ result = base_window::KEY_F12;
+ break;
+
+
+ case VK_SPACE: result = ' '; break;
+ case VK_TAB: result = '\t'; break;
+ case VK_RETURN: result = '\n'; break;
+ case VK_NUMPAD0: result = '0'; break;
+ case VK_NUMPAD1: result = '1'; break;
+ case VK_NUMPAD2: result = '2'; break;
+ case VK_NUMPAD3: result = '3'; break;
+ case VK_NUMPAD4: result = '4'; break;
+ case VK_NUMPAD5: result = '5'; break;
+ case VK_NUMPAD6: result = '6'; break;
+ case VK_NUMPAD7: result = '7'; break;
+ case VK_NUMPAD8: result = '8'; break;
+ case VK_NUMPAD9: result = '9'; break;
+
+ case VK_MULTIPLY: result = '*'; break;
+ case VK_ADD: result = '+'; break;
+ case VK_SUBTRACT: result = '-'; break;
+ case VK_DECIMAL: result = '.'; break;
+ case VK_DIVIDE: result = '/'; break;
+
+ case VK_OEM_1:
+ if (shift) result = ':';
+ else result = ';';
+ break;
+
+ case VK_OEM_PLUS:
+ if (shift) result = '+';
+ else result = '=';
+ break;
+
+ case VK_OEM_COMMA:
+ if (shift) result = '<';
+ else result = ',';
+ break;
+
+ case VK_OEM_MINUS:
+ if (shift) result = '_';
+ else result = '-';
+ break;
+
+ case VK_OEM_PERIOD:
+ if (shift) result = '>';
+ else result = '.';
+ break;
+
+ case VK_OEM_2:
+ if (shift) result = '?';
+ else result = '/';
+ break;
+
+ case VK_OEM_3:
+ if (shift) result = '~';
+ else result = '`';
+ break;
+
+ case VK_OEM_4:
+ if (shift) result = '{';
+ else result = '[';
+ break;
+
+ case VK_OEM_5:
+ if (shift) result = '|';
+ else result = '\\';
+ break;
+
+ case VK_OEM_6:
+ if (shift) result = '}';
+ else result = ']';
+ break;
+
+ case VK_OEM_7:
+ if (shift) result = '"';
+ else result = '\'';
+ break;
+
+ default:
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ LRESULT CALLBACK WndProc (
+ HWND hwnd,
+ UINT message,
+ WPARAM wParam,
+ LPARAM lParam
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ // Make the event processing thread have a priority slightly above normal.
+ // This makes the GUI smother if you do heavy processing in other threads.
+ HANDLE hand = OpenThread(THREAD_ALL_ACCESS,FALSE,GetCurrentThreadId());
+ SetThreadPriority(hand,THREAD_PRIORITY_ABOVE_NORMAL);
+ CloseHandle(hand);
+
+ auto globals = global_data();
+
+ window_table_type& window_table = globals->window_table;
+ HWND& helper_window = globals->helper_window;
+
+ auto_mutex M(window_table.get_mutex());
+
+ try
+ {
+ std::vector<unsigned char> bitmap_buffer;
+
+ bool is_double = false;
+ unsigned long btn = base_window::NONE;
+
+ switch (message)
+ {
+ case WM_USER+QUIT_EVENT_HANDLER_THREAD:
+ if (hwnd == helper_window)
+ {
+ globals->quit_windows_loop = true;
+ PostQuitMessage(0);
+ }
+ return 0;
+
+ case WM_USER+DESTROY_WINDOW:
+ if (hwnd == helper_window)
+ {
+ DestroyWindow((HWND)wParam);
+ }
+ return 0;
+
+ case WM_USER+CALL_MOVE_WINDOW:
+ if (hwnd == helper_window)
+ {
+ MoveWindow(
+ globals->move_window_hwnd,
+ globals->move_window_x,
+ globals->move_window_y,
+ globals->move_window_width,
+ globals->move_window_height,
+ TRUE);
+ globals->move_window_done = true;
+ globals->et_signaler.broadcast();
+ }
+ return 0;
+
+ case WM_USER+USER_EVENTS_READY:
+ if (hwnd == helper_window)
+ {
+ // this is the signal to look in the user_events queue
+ globals->user_events.lock();
+ globals->user_events.swap(globals->user_events_temp);
+ globals->user_events.unlock();
+ globals->user_events_temp.reset();
+ // now dispatch all these user events
+ while (globals->user_events_temp.move_next())
+ {
+ base_window** win_ = window_table[globals->user_events_temp.element().w];
+ base_window* win;
+ // if this window exists in the window table then dispatch
+ // its event.
+ if (win_)
+ {
+ win = *win_;
+ win->on_user_event(
+ globals->user_events_temp.element().p,
+ globals->user_events_temp.element().i
+ );
+ }
+ }
+ globals->user_events_temp.clear();
+ }
+ return 0;
+
+ case WM_USER+SET_ACTIVE_WINDOW:
+ if (hwnd == helper_window)
+ {
+ SetActiveWindow((HWND)wParam);
+ }
+ return 0;
+
+ case WM_USER+SHOW_WINDOW_SHOW:
+ if (hwnd == helper_window)
+ {
+ ShowWindow((HWND)wParam,SW_SHOW);
+ BringWindowToTop((HWND)wParam);
+ }
+ return 0;
+
+ case WM_USER+SHOW_WINDOW_HIDE:
+ if (hwnd == helper_window)
+ {
+ ShowWindow((HWND)wParam,SW_HIDE);
+ }
+ return 0;
+
+ case WM_USER+CALL_SET_WINDOW_TITLE:
+ if (hwnd == helper_window)
+ {
+ SetWindowTextW((HWND)wParam,globals->window_title.c_str());
+ globals->set_window_title_done = true;
+ globals->et_signaler.broadcast();
+ }
+ return 0;
+
+
+ case WM_USER+CREATE_WINDOW:
+ if (hwnd == helper_window)
+ {
+
+ // if this is stupposed to be a popup window then do the popup window thing
+ if (globals->dwStyle == WS_CHILD)
+ {
+ TCHAR nothing[] = TEXT("");
+ globals->new_window = CreateWindowEx (WS_EX_TOOLWINDOW,globals->window_class_name, nothing,
+ globals->dwStyle,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ helper_window, NULL, globals->hInstance, NULL);
+ SetParent(globals->new_window,NULL);
+ }
+ else
+ {
+ TCHAR nothing[] = TEXT("");
+ globals->new_window = CreateWindow (globals->window_class_name, nothing,
+ globals->dwStyle,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ NULL, NULL, globals->hInstance, NULL);
+ }
+ // use the helper_window to indicate that CreateWindow failed
+ if (globals->new_window == NULL)
+ globals->new_window = helper_window;
+ globals->et_signaler.broadcast();
+ }
+ return 0;
+
+ case WM_SYSKEYDOWN:
+ case WM_KEYDOWN:
+ {
+ if (globals->in_ime_composition) break;
+
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ unsigned long state = 0;
+
+ bool shift = ((GetKeyState(VK_SHIFT)&0x8000)!=0);
+ bool ctrl = ((GetKeyState(VK_CONTROL)&0x8000)!=0);
+ bool caps = ((GetKeyState(VK_CAPITAL)&0x0001)!=0);
+ if(shift)
+ state = base_window::KBD_MOD_SHIFT;
+ if(ctrl)
+ state |= base_window::KBD_MOD_CONTROL;
+ if(caps)
+ state |= base_window::KBD_MOD_CAPS_LOCK;
+ if((GetKeyState(VK_MENU)&0x8000)!=0)
+ state |= base_window::KBD_MOD_ALT;
+ if((GetKeyState(VK_NUMLOCK)&0x0001)!=0)
+ state |= base_window::KBD_MOD_NUM_LOCK;
+ if((GetKeyState(VK_SCROLL)&0x0001)!=0)
+ state |= base_window::KBD_MOD_SCROLL_LOCK;
+
+
+ bool is_printable;
+ unsigned long result;
+
+ if (map_keys(wParam,shift,caps,result,is_printable))
+ {
+ // signal the keyboard event
+ win->on_keydown(result,is_printable,state);
+ }
+
+ }
+ break;
+
+ // treat the user releasing the mouse button on the non client area (e.g. the title bar)
+ // like focus being lost since that is what X11 does
+ case WM_NCLBUTTONUP:
+ case WM_NCMBUTTONUP:
+ case WM_NCRBUTTONUP:
+ case WM_SETFOCUS:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ // signal that the window is gaining focus
+ win->on_focus_gained();
+ }
+ break;
+
+ // treat the user clicking on the non client area (e.g. the title bar)
+ // like focus being lost since that is what X11 does
+ case WM_NCLBUTTONDBLCLK:
+ case WM_NCMBUTTONDBLCLK:
+ case WM_NCRBUTTONDBLCLK:
+ case WM_NCLBUTTONDOWN:
+ case WM_NCMBUTTONDOWN:
+ case WM_NCRBUTTONDOWN:
+ case WM_KILLFOCUS:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ // signal that the window is gaining focus
+ win->on_focus_lost();
+ }
+ break;
+
+ case WM_SIZE:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+
+ // signal that the window has been resized
+ win->on_window_resized();
+
+ }
+ return 0;
+
+ case WM_MOVE:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+
+ // signal that the window has moved
+ win->on_window_moved();
+
+ }
+ return 0;
+
+ case WM_MOUSELEAVE:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+
+ // signal that the mouse has left the window
+ if (win->mouse_in)
+ {
+ win->on_mouse_leave();
+ win->mouse_in = false;
+ }
+
+ }
+ return 0;
+
+ case WM_MOUSEWHEEL:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ unsigned long state = 0;
+ if (wParam & MK_CONTROL)
+ state |= base_window::CONTROL;
+ if (wParam & MK_LBUTTON)
+ state |= base_window::LEFT;
+ if (wParam & MK_MBUTTON)
+ state |= base_window::MIDDLE;
+ if (wParam & MK_RBUTTON)
+ state |= base_window::RIGHT;
+ if (wParam & MK_SHIFT)
+ state |= base_window::SHIFT;
+
+ // signal the mouse wheel event
+ if (GET_WHEEL_DELTA_WPARAM(wParam) > 0)
+ {
+ win->on_wheel_up(state);
+ }
+ else
+ {
+ win->on_wheel_down(state);
+ }
+
+ }
+ return 0;
+
+ case WM_LBUTTONUP:
+ btn = base_window::LEFT;
+ case WM_MBUTTONUP:
+ if (btn == base_window::NONE)
+ btn = base_window::MIDDLE;
+ case WM_RBUTTONUP:
+ if (btn == base_window::NONE)
+ btn = base_window::RIGHT;
+ {
+ // release the mouse capture if the user isn't holding any
+ // other mouse buttons
+ if (!((wParam & MK_LBUTTON) | (wParam & MK_MBUTTON) | (wParam & MK_RBUTTON)))
+ ReleaseCapture();
+
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ unsigned long state = 0;
+ if (wParam & MK_CONTROL)
+ state |= base_window::CONTROL;
+ if (wParam & MK_LBUTTON)
+ state |= base_window::LEFT;
+ if (wParam & MK_MBUTTON)
+ state |= base_window::MIDDLE;
+ if (wParam & MK_RBUTTON)
+ state |= base_window::RIGHT;
+ if (wParam & MK_SHIFT)
+ state |= base_window::SHIFT;
+
+ // remove the clicked button from the state
+ state &= (~btn);
+
+ // signal the mouse click
+ win->on_mouse_up(btn,state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam));
+
+ }
+ return 0;
+
+
+
+ case WM_LBUTTONDBLCLK:
+ if (btn == base_window::NONE)
+ btn = base_window::LEFT;
+ case WM_MBUTTONDBLCLK:
+ if (btn == base_window::NONE)
+ btn = base_window::MIDDLE;
+ case WM_RBUTTONDBLCLK:
+ if (btn == base_window::NONE)
+ btn = base_window::RIGHT;
+ is_double = true;
+ case WM_LBUTTONDOWN:
+ if (btn == base_window::NONE)
+ btn = base_window::LEFT;
+ case WM_MBUTTONDOWN:
+ if (btn == base_window::NONE)
+ btn = base_window::MIDDLE;
+ case WM_RBUTTONDOWN:
+ if (btn == base_window::NONE)
+ btn = base_window::RIGHT;
+ {
+ SetCapture(hwnd);
+
+
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ unsigned long state = 0;
+ if (wParam & MK_CONTROL)
+ state |= base_window::CONTROL;
+ if (wParam & MK_LBUTTON)
+ state |= base_window::LEFT;
+ if (wParam & MK_MBUTTON)
+ state |= base_window::MIDDLE;
+ if (wParam & MK_RBUTTON)
+ state |= base_window::RIGHT;
+ if (wParam & MK_SHIFT)
+ state |= base_window::SHIFT;
+
+ // remove the clicked button from the state
+ state &= (~btn);
+
+ // signal the mouse click
+ win->on_mouse_down(btn,state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam),is_double);
+
+ }
+ return 0;
+
+ case WM_MOUSEMOVE:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+ unsigned long state = 0;
+ bool mouse_button_down = false;
+ if (wParam & MK_CONTROL)
+ state |= base_window::CONTROL;
+ if (wParam & MK_LBUTTON)
+ {
+ state |= base_window::LEFT;
+ mouse_button_down = true;
+ }
+ if (wParam & MK_MBUTTON)
+ {
+ mouse_button_down = true;
+ state |= base_window::MIDDLE;
+ }
+ if (wParam & MK_RBUTTON)
+ {
+ state |= base_window::RIGHT;
+ mouse_button_down = true;
+ }
+ if (wParam & MK_SHIFT)
+ state |= base_window::SHIFT;
+
+ // signal the mouse movement if this mouse event isn't identical to the
+ // last one we got
+ if ( GET_X_LPARAM(lParam) != win->prevx ||
+ GET_Y_LPARAM(lParam) != win->prevy ||
+ state != win->prev_state)
+ {
+ win->on_mouse_move(state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam));
+ }
+
+ // save the event data into the prev* member variables
+ win->prevx = GET_X_LPARAM(lParam);
+ win->prevy = GET_Y_LPARAM(lParam);
+ win->prev_state = state;
+
+ // The following block of code checks if the mouse is moving
+ // into or out of the window.
+ if (mouse_button_down == false)
+ {
+ // if there isn't any mouse button down then the fact that
+ // we are getting a mouse move message means it is in the
+ // window
+ if (win->mouse_in == false)
+ {
+ win->on_mouse_enter();
+ win->mouse_in = true;
+
+ // set the tracker for the mouse
+ TRACKMOUSEEVENT tm;
+ tm.hwndTrack = hwnd;
+ tm.cbSize = sizeof(tm);
+ tm.dwFlags = TME_LEAVE;
+ _TrackMouseEvent(&tm);
+ }
+ }
+ else if (win->mouse_in)
+ {
+ // check if the mouse is currently outside the window
+ const long mouse_x = GET_X_LPARAM(lParam);
+ const long mouse_y = GET_Y_LPARAM(lParam);
+ if (mouse_x < 0 || mouse_y < 0)
+ {
+ // the mouse is not in the window
+ win->mouse_in = false;
+ win->on_mouse_leave();
+ }
+ else
+ {
+ unsigned long width, height;
+ win->get_size(width,height);
+ if (mouse_x >= static_cast<long>(width) ||
+ mouse_y >= static_cast<long>(height))
+ {
+ // the mouse is not in the window
+ win->mouse_in = false;
+ win->on_mouse_leave();
+ }
+ }
+ }
+ else if (win->mouse_in == false)
+ {
+ // at this point we know that the mouse is moving around
+ // with some of its buttons down. So it might be outside the window.
+ // get the window size and see if the mouse is outside
+ // it.
+ const long mouse_x = GET_X_LPARAM(lParam);
+ const long mouse_y = GET_Y_LPARAM(lParam);
+ unsigned long width, height;
+ win->get_size(width,height);
+ if (mouse_x < static_cast<long>(width) &&
+ mouse_y < static_cast<long>(height) &&
+ mouse_x >= 0 &&
+ mouse_y >= 0)
+ {
+ // The mouse has gone inside the window
+ win->mouse_in = true;
+ win->on_mouse_enter();
+
+ // set the tracker for the mouse
+ TRACKMOUSEEVENT tm;
+ tm.hwndTrack = hwnd;
+ tm.cbSize = sizeof(tm);
+ tm.dwFlags = TME_LEAVE;
+ _TrackMouseEvent(&tm);
+ }
+
+ }
+
+
+ }
+ return 0;
+
+ case WM_PAINT :
+ {
+
+ PAINTSTRUCT ps;
+ HDC hdc = NULL;
+
+ hdc = BeginPaint (hwnd, &ps) ;
+
+ try
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+
+
+
+ LONG x = ps.rcPaint.left;
+ LONG y = ps.rcPaint.top;
+ LONG width = ps.rcPaint.right - x;
+ LONG height = ps.rcPaint.bottom - y;
+
+ if (width != 0 && height != 0)
+ {
+
+ BITMAPINFO bmap_info;
+ bmap_info.bmiColors[0].rgbBlue = 0;
+ bmap_info.bmiColors[0].rgbGreen = 0;
+ bmap_info.bmiColors[0].rgbRed = 0;
+ bmap_info.bmiColors[0].rgbReserved = 0;
+ bmap_info.bmiHeader.biSize = sizeof(bmap_info.bmiHeader);
+ bmap_info.bmiHeader.biWidth = width;
+ bmap_info.bmiHeader.biHeight = -1*height;
+ bmap_info.bmiHeader.biPlanes = 1;
+ bmap_info.bmiHeader.biBitCount = 24;
+ bmap_info.bmiHeader.biCompression = BI_RGB;
+ bmap_info.bmiHeader.biSizeImage = 0;
+ bmap_info.bmiHeader.biXPelsPerMeter = 0;
+ bmap_info.bmiHeader.biYPelsPerMeter = 0;
+ bmap_info.bmiHeader.biClrUsed = 0;
+ bmap_info.bmiHeader.biClrImportant = 0;
+
+
+
+ unsigned char* bitmap ;
+ unsigned long size;
+ unsigned long padding = 0;
+ if ((width*3)%sizeof(LONG) != 0)
+ {
+ padding = sizeof(LONG) - (width*3)%sizeof(LONG);
+ size = (width*3+padding)*height;
+ }
+ else
+ {
+ size = width*height*3;
+ }
+
+ if (bitmap_buffer.size() < size)
+ bitmap_buffer.resize(size);
+ bitmap = &bitmap_buffer[0];
+
+ canvas bits(bitmap,padding,x,y,x+width-1,y+height-1);
+
+
+
+ win->paint(bits);
+
+
+
+ SetDIBitsToDevice (
+ hdc,
+ ps.rcPaint.left,
+ ps.rcPaint.top,
+ width,
+ height,
+ 0,
+ 0,
+ 0,
+ height,
+ bitmap,
+ &bmap_info,
+ DIB_RGB_COLORS
+ );
+ }
+
+ EndPaint (hwnd, &ps) ;
+
+ }
+ catch (...)
+ {
+ // make sure EndPaint is called even if an exception
+ // is thrown.
+ if (hdc != NULL)
+ EndPaint (hwnd, &ps);
+ throw;
+ }
+ }
+ return 0 ;
+
+ case WM_ERASEBKGND:
+ return 1;
+
+
+
+
+ case WM_CLOSE:
+ {
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+
+
+ // signal that the window is being closed
+ if (win->on_window_close() == base_window::DO_NOT_CLOSE_WINDOW)
+ {
+ DLIB_ASSERT(win->has_been_destroyed == false,
+ "\tYou called close_window() inside the on_window_close() event but"
+ << "\n\tthen returned DO_NOT_CLOSE_WINDOW. You can do one or the other but not both."
+ << "\n\tthis: " << win
+ );
+ // this happens if the on_window_close() callback
+ // tells us to ignore the close event.
+ return 0;
+ }
+ else
+ {
+ if (window_table[hwnd])
+ {
+ window_table.destroy(hwnd);
+ win->has_been_destroyed = true;
+ win->hwnd = 0;
+ globals->window_close_signaler.broadcast();
+ }
+ else
+ {
+ // in this case the window must have self destructed by
+ // calling delete this;
+ return 0;
+ }
+ }
+
+ }
+ return DefWindowProc (hwnd, message, wParam, lParam);
+
+ case WM_IME_STARTCOMPOSITION:
+ globals->in_ime_composition = true;
+ break;
+
+ case WM_IME_COMPOSITION:
+ {
+ globals->in_ime_composition = false;
+ base_window** win_ = window_table[hwnd];
+ base_window* win;
+ if (win_)
+ win = *win_;
+ else
+ break;
+ HIMC hImc = ImmGetContext(hwnd);
+ if (lParam & GCS_RESULTSTR){
+ WCHAR wc;
+ LONG bufbyte = ImmGetCompositionStringW(hImc, GCS_RESULTSTR, &wc, 0);
+ if (bufbyte != IMM_ERROR_NODATA && bufbyte != IMM_ERROR_GENERAL){
+ bufbyte += sizeof(WCHAR);
+
+
+ WCHAR *buf = new WCHAR[bufbyte / sizeof(WCHAR)];
+ ImmGetCompositionStringW(hImc, GCS_RESULTSTR, buf, bufbyte);
+ buf[bufbyte / sizeof(WCHAR) - 1] = L'\0';
+
+ // signal the putstring event
+ win->on_string_put(std::wstring(buf));
+ delete [] buf;
+ }
+ }
+ ImmReleaseContext(hwnd, hImc);
+ }
+ break;
+
+ default:
+ break;
+
+ } // switch (message)
+
+
+ }
+ catch (std::exception& e)
+ {
+ error_box("Exception thrown in event handler",e.what());
+ globals->quit_windows_loop = true;
+ }
+ catch (...)
+ {
+ error_box("Exception thrown in event handler","Unknown Exception type.");
+ globals->quit_windows_loop = true;
+ }
+
+ return DefWindowProc (hwnd, message, wParam, lParam) ;
+
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void show_window (
+ HWND hwnd
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ PostMessage(global_data()->helper_window,WM_USER+SHOW_WINDOW_SHOW,(WPARAM)hwnd,0);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void hide_window (
+ HWND hwnd
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ PostMessage(global_data()->helper_window,WM_USER+SHOW_WINDOW_HIDE,(WPARAM)hwnd,0);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void give_window_focus (
+ HWND hwnd
+ )
+ /*!
+ ensures
+ - calls SetActiveWindow(hwnd) from the event handling thread.
+ !*/
+ {
+ using namespace gui_core_kernel_1_globals;
+ PostMessage(global_data()->helper_window,WM_USER+SET_ACTIVE_WINDOW,(WPARAM)hwnd,0);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void destroy_window (
+ HWND hwnd
+ )
+ /*!
+ ensures
+ - calls DestroyWindow(hwnd) from the event handling thread.
+ !*/
+ {
+ using namespace gui_core_kernel_1_globals;
+ PostMessage(global_data()->helper_window,WM_USER+DESTROY_WINDOW,(WPARAM)hwnd,0);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ HWND make_window (
+ DWORD dwStyle_
+ )
+ /*!
+ ensures
+ - creates a window by calling CreateWindow and passes on the
+ dwStyle argument.
+ - returns the HWND that is returned by CreateWindow
+ - ensures that CreateWindow is called from the event handler thread
+ - if (it was unable to create a window) then
+ - returns NULL or helper_window
+ !*/
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto globals = global_data();
+ // if we are running in the event handling thread then just call
+ // CreateWindow directly
+ if (get_thread_id() == globals->event_thread_id)
+ {
+ // if this is stupposed to be a popup window then do the popup window thing
+ if (dwStyle_ == WS_CHILD)
+ {
+ TCHAR nothing[] = TEXT("");
+ HWND tmp = CreateWindowEx (WS_EX_TOOLWINDOW|WS_EX_TOPMOST, globals->window_class_name, nothing,
+ dwStyle_,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ globals->helper_window, NULL, globals->hInstance, NULL);
+ SetParent(tmp,NULL);
+ return tmp;
+ }
+ else
+ {
+ TCHAR nothing[] = TEXT("");
+ return CreateWindow (globals->window_class_name, nothing,
+ dwStyle_,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ CW_USEDEFAULT, CW_USEDEFAULT,
+ NULL, NULL, globals->hInstance, NULL);
+ }
+ }
+ else
+ {
+ auto_mutex M(globals->window_table.get_mutex());
+ // wait for our chance to make a new window request
+ while (globals->request_new_window)
+ globals->et_signaler.wait();
+
+
+ globals->dwStyle = dwStyle_;
+ if (PostMessage(globals->helper_window,WM_USER+CREATE_WINDOW,0,0)==0)
+ {
+ throw gui_error("Unable to schedule function for execution in event handling thread.");
+ }
+
+ // wait for our request to be serviced
+ while (globals->new_window == NULL)
+ globals->et_signaler.wait();
+
+ HWND temp = globals->new_window;
+ globals->new_window = NULL;
+ globals->request_new_window = false;
+ globals->et_signaler.broadcast();
+
+ // if make_window() returns the helper_window then it means it failed
+ // to make a new window
+ if (temp == globals->helper_window)
+ temp = NULL;
+
+ return temp;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+ } // end namespace gui_core_kernel_1_globals
+
+// ----------------------------------------------------------------------------------------
+
+ void canvas::
+ fill (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) const
+ {
+ const unsigned long red = red_;
+ const unsigned long green = green_;
+ const unsigned long blue = blue_;
+
+ const LONG block1 = (blue<<24) | (red<<16) | (green<<8) | blue;
+ const LONG block2 = (green<<24) | (blue<<16) | (red<<8) | green;
+ const LONG block3 = (red<<24) | (green<<16) | (blue<<8) | red;
+
+ // remember that row_width is a multiple of 4 because windows
+ // requires that all bitmaps have row widths that are multiples of 4.
+ unsigned long size = row_width/4;
+ for (unsigned long i = 0; i < height_; ++i)
+ {
+ unsigned long padding = size%3;
+ LONG* start = reinterpret_cast<LONG*>(bits+row_width*i);
+ LONG* end = reinterpret_cast<LONG*>(start) + size - padding;
+ while (start != end)
+ {
+ *start = block1;
+ ++start;
+ *start = block2;
+ ++start;
+ *start = block3;
+ ++start;
+ }
+ if (padding)
+ {
+ *start = block1;
+ ++start;
+ --padding;
+ }
+ if (padding)
+ {
+ *start = block2;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ trigger_user_event (
+ void* p,
+ int i
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+
+ user_event_type e;
+ e.w = hwnd;
+ e.p = p;
+ e.i = i;
+ {
+ auto_mutex M(globals->user_events.get_mutex());
+ globals->user_events.enqueue(e);
+ }
+
+ if (PostMessage(globals->helper_window,WM_USER+USER_EVENTS_READY,0,0)==0)
+ {
+ throw gui_error("Unable to schedule function for execution in event handling thread.");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base_window::
+ base_window (
+ bool resizable,
+ bool undecorated
+ ) :
+ globals(gui_core_kernel_1_globals::global_data()),
+ has_been_destroyed(false),
+ prevx(-1),
+ prevy(-1),
+ prev_state(0),
+ wm(globals->window_table.get_mutex())
+ {
+ using namespace gui_core_kernel_1_globals;
+ DLIB_ASSERT(!(undecorated == true && resizable == true),
+ "\tbase_window::base_window()"
+ << "\n\tThere is no such thing as an undecorated window that is resizable by the user."
+ << "\n\tthis: " << this
+ );
+
+ if (resizable)
+ style = WS_OVERLAPPEDWINDOW;
+ else if (undecorated)
+ style = WS_CHILD;
+ else
+ style = WS_OVERLAPPEDWINDOW ^ WS_THICKFRAME ^ WS_MAXIMIZEBOX;
+
+ hwnd = gui_core_kernel_1_globals::make_window(style);
+
+ if (hwnd == NULL)
+ throw gui_error("unable to create base_window");
+
+ auto_mutex M(wm);
+
+ mouse_in = false;
+
+ HWND temp = hwnd;
+ base_window* ttemp = this;
+ globals->window_table.add(temp,ttemp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base_window::
+ ~base_window (
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ close_window();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ close_window (
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == false)
+ {
+ // do this just to make sure no one tries to call this window's
+ // calbacks.
+ globals->window_table.destroy(hwnd);
+ gui_core_kernel_1_globals::destroy_window(hwnd);
+ hwnd = 0;
+ has_been_destroyed = true;
+ globals->window_close_signaler.broadcast();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ wait_until_closed (
+ ) const
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ while (has_been_destroyed == false)
+ globals->window_close_signaler.wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool base_window::
+ is_closed (
+ ) const
+ {
+ auto_mutex M(wm);
+ return has_been_destroyed;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_title (
+ const std::string &title
+ )
+ {
+ set_title(convert_mbstring_to_wstring(title));
+ }
+
+ void base_window::
+ set_title (
+ const ustring &title
+ )
+ {
+ set_title(convert_utf32_to_wstring(title));
+ }
+
+ void base_window::
+ set_title (
+ const std::wstring& title
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+
+ // call the SetWindowText function with our arguments but make sure it is from
+ // the event thread. We have to do this because the SetWindowText() apparently blocks
+ // until something happens in the event thread so we have to
+ // do this to avoid possible deadlocks.
+ if (get_thread_id() == globals->event_thread_id)
+ {
+ SetWindowTextW(hwnd,title.c_str());
+ }
+ else
+ {
+ globals->window_title = title;
+ globals->set_window_title_done = false;
+
+ if (PostMessage(globals->helper_window,WM_USER+CALL_SET_WINDOW_TITLE,(WPARAM)hwnd,0)==0)
+ {
+ throw gui_error("Unable to schedule SetWindowText function for execution in event handling thread.");
+ }
+
+ // wait for any SetWindowText() calls to finish
+ while (globals->set_window_title_done == false)
+ globals->et_signaler.wait();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ show (
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ show_window(hwnd);
+ if (style != WS_CHILD)
+ give_window_focus(hwnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ hide(
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ hide_window(hwnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_size (
+ int width_,
+ int height_
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ if (get_thread_id() == globals->event_thread_id)
+ {
+ RECT info;
+ GetWindowRect(hwnd,&info);
+
+ int x = info.left;
+ int y = info.top;
+ int width;
+ int height;
+
+ RECT rect;
+ rect.top = 0;
+ rect.left = 0;
+ rect.bottom = height_;
+ rect.right = width_;
+ AdjustWindowRectEx(&rect,style,FALSE,0);
+
+ width = std::abs(rect.right - rect.left);
+ height = std::abs(rect.bottom - rect.top);
+
+ MoveWindow(
+ hwnd,
+ x,
+ y,
+ width,
+ height,
+ TRUE);
+ }
+ else
+ {
+ RECT info;
+ GetWindowRect(hwnd,&info);
+
+ int x = info.left;
+ int y = info.top;
+ int width;
+ int height;
+
+ RECT rect;
+ rect.top = 0;
+ rect.left = 0;
+ rect.bottom = height_;
+ rect.right = width_;
+ AdjustWindowRectEx(&rect,style,FALSE,0);
+
+ width = std::abs(rect.right - rect.left);
+ height = std::abs(rect.bottom - rect.top);
+
+ // call the MoveWindow function with our arguments. We
+ // have to do this because the MoveWindow() apparently blocks
+ // until something happens in the event thread so we have to
+ // do this to avoid possible deadlocks.
+ globals->move_window_hwnd = hwnd;
+ globals->move_window_x = x;
+ globals->move_window_y = y;
+ globals->move_window_width = width;
+ globals->move_window_height = height;
+ globals->move_window_done = false;
+
+ if (PostMessage(globals->helper_window,WM_USER+CALL_MOVE_WINDOW,0,0)==0)
+ {
+ throw gui_error("Unable to schedule MoveWindow function for execution in event handling thread.");
+ }
+
+ // wait for any MoveWindow calls to finish
+ while (globals->move_window_done == false)
+ globals->et_signaler.wait();
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_pos (
+ long x_,
+ long y_
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ if (get_thread_id() == globals->event_thread_id)
+ {
+ RECT info;
+ GetWindowRect(hwnd,&info);
+ int width = info.right - info.left;
+ int height = info.bottom - info.top;
+
+ MoveWindow(
+ hwnd,
+ x_,
+ y_,
+ width,
+ height,
+ TRUE);
+
+ }
+ else
+ {
+ RECT info;
+ GetWindowRect(hwnd,&info);
+ int width = info.right - info.left;
+ int height = info.bottom - info.top;
+
+
+
+ // call the MoveWindow function with our arguments. We
+ // have to do this because the MoveWindow() apparently blocks
+ // until something happens in the event thread so we have to
+ // do this to avoid possible deadlocks.
+ globals->move_window_hwnd = hwnd;
+ globals->move_window_x = x_;
+ globals->move_window_y = y_;
+ globals->move_window_width = width;
+ globals->move_window_height = height;
+ globals->move_window_done = false;
+
+ if (PostMessage(globals->helper_window,WM_USER+CALL_MOVE_WINDOW,0,0)==0)
+ {
+ throw gui_error("Unable to schedule MoveWindow function for execution in event handling thread.");
+ }
+
+ // wait for any MoveWindow calls to finish
+ while (globals->move_window_done == false)
+ globals->et_signaler.wait();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_pos (
+ long& x_,
+ long& y_
+ )
+ {
+ auto_mutex M(wm);
+ x_ = 0;
+ y_ = 0;
+ if (has_been_destroyed == true)
+ return;
+
+ POINT p;
+ p.x = 0;
+ p.y = 0;
+ ClientToScreen(hwnd,&p);
+
+ x_ = p.x;
+ y_ = p.y;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_display_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const
+ {
+ auto_mutex M(wm);
+ width = 0;
+ height = 0;
+ if (has_been_destroyed == true)
+ return;
+
+
+ RECT rc;
+ GetWindowRect(hwnd, &rc);
+
+ HMONITOR hMonitor;
+ MONITORINFO mi;
+ //
+ // get the nearest monitor to the passed rect.
+ //
+ hMonitor = MonitorFromRect(&rc, MONITOR_DEFAULTTONEAREST);
+
+ //
+ // get the work area or entire monitor rect.
+ //
+ mi.cbSize = sizeof(mi);
+ GetMonitorInfo(hMonitor, &mi);
+
+ rc = mi.rcMonitor;
+
+ width = static_cast<unsigned long>(rc.right - rc.left);
+ height = static_cast<unsigned long>(rc.bottom - rc.top);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const
+ {
+ auto_mutex M(wm);
+ width = 0;
+ height = 0;
+ if (has_been_destroyed == true)
+ return;
+
+
+ RECT r;
+ GetClientRect(hwnd,&r);
+
+ width = r.right - r.left;
+ height = r.bottom - r.top;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ invalidate_rectangle (
+ const rectangle& rect
+ )
+ {
+ auto_mutex M(wm);
+ if (rect.is_empty() == false && !has_been_destroyed)
+ {
+ RECT info;
+ info.top = rect.top();
+ info.left = rect.left();
+ info.right = rect.right()+1;
+ info.bottom = rect.bottom()+1;
+
+ InvalidateRect(hwnd,&info,FALSE);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_im_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ HIMC hImc = ImmGetContext(hwnd);
+
+ COMPOSITIONFORM cf;
+ cf.dwStyle = CFS_POINT;
+ cf.ptCurrentPos.x = x;
+ cf.ptCurrentPos.y = y;
+ ImmSetCompositionWindow(hImc, &cf);
+ ImmReleaseContext(hwnd, hImc);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void put_on_clipboard (
+ const std::string& str
+ )
+ {
+ put_on_clipboard(convert_mbstring_to_wstring(str));
+ }
+
+ void put_on_clipboard (
+ const dlib::ustring& str
+ )
+ {
+ put_on_clipboard(convert_utf32_to_wstring(str));
+ }
+
+ void put_on_clipboard (
+ const std::wstring& str
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ using namespace std;
+
+ auto globals = global_data();
+
+ if (OpenClipboard(globals->helper_window))
+ {
+ EmptyClipboard();
+ auto_mutex M(globals->window_table.get_mutex());
+
+ const unsigned long newlines = count(str.begin(),str.end(),L'\n');
+
+ HGLOBAL mem = GlobalAlloc(GMEM_MOVEABLE,(str.size()+newlines+1)*sizeof(wchar_t));
+ if (mem != NULL)
+ {
+ wchar_t* buf = reinterpret_cast<wchar_t*>(GlobalLock(mem));
+
+ if (buf != NULL)
+ {
+ // copy str into buf while also replacing all the \n with \r\n
+ for (wstring::size_type i = 0; i < str.size(); ++i)
+ {
+ if (str[i] != L'\n')
+ {
+ *buf = str[i];
+ ++buf;
+ }
+ else
+ {
+ *buf = L'\r';
+ ++buf;
+ *buf = L'\n';
+ ++buf;
+ }
+ }
+ *buf = L'\0';
+ GlobalUnlock(mem);
+ SetClipboardData(CF_UNICODETEXT,mem);
+ }
+ }
+ CloseClipboard();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void get_from_clipboard (
+ std::string& str
+ )
+ {
+ std::wstring wstr;
+ get_from_clipboard(wstr);
+ str = convert_wstring_to_mbstring(wstr);
+ }
+
+ void get_from_clipboard (
+ dlib::ustring& str
+ )
+ {
+ std::wstring wstr;
+ get_from_clipboard(wstr);
+ str = convert_wstring_to_utf32(wstr);
+ }
+
+ void get_from_clipboard (
+ std::wstring& str
+ )
+ {
+ using namespace gui_core_kernel_1_globals;
+ using namespace std;
+ auto globals = global_data();
+
+ auto_mutex M(globals->window_table.get_mutex());
+ if (OpenClipboard(globals->helper_window))
+ {
+
+ HANDLE data = GetClipboardData(CF_UNICODETEXT);
+ if (data != NULL)
+ {
+ wchar_t* buf = reinterpret_cast<wchar_t*>(GlobalLock(data));
+ if (buf != 0)
+ {
+ str.clear();
+
+ // copy the data from buf into str while also removing any '\r'
+ // characters.
+ while (*buf != L'\0')
+ {
+ if (*buf != L'\r')
+ str += *buf;
+ ++buf;
+ }
+
+ GlobalUnlock(data);
+ }
+ else
+ {
+ Beep(500,500);
+ }
+ }
+
+ CloseClipboard();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // WIN32
+
+#endif // DLIB_GUI_CORE_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_1.h b/ml/dlib/dlib/gui_core/gui_core_kernel_1.h
new file mode 100644
index 000000000..b7077ac01
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/gui_core_kernel_1.h
@@ -0,0 +1,420 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEl_1_
+#define DLIB_GUI_CORE_KERNEl_1_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#ifdef DLIB_NO_GUI_SUPPORT
+#error "DLIB_NO_GUI_SUPPORT is defined so you can't use the GUI code. Turn DLIB_NO_GUI_SUPPORT off if you want to use it."
+#endif
+
+#include <string>
+
+#include "../windows_magic.h"
+
+
+#include <windows.h>
+#include <winuser.h>
+#include <windowsx.h>
+#include <commctrl.h>
+
+#include "gui_core_kernel_abstract.h"
+
+#ifdef _MSC_VER
+// Disable the following warnings for Visual Studio
+//
+// These two warnings have to do with converting points to and from the LONG
+// type. But both these types are 32 bits in windows so it is fine.
+#pragma warning(disable: 4244; disable: 4312)
+#endif
+
+#include "../algs.h"
+#include "../sync_extension.h"
+#include "../binary_search_tree.h"
+#include "../threads.h"
+#include "../geometry/rectangle.h"
+#include "../assert.h"
+#include "../queue.h"
+#include "../pixel.h"
+#include "../unicode.h"
+#include "../smart_pointers/shared_ptr_thread_safe.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class base_window;
+ namespace gui_core_kernel_1_globals
+ {
+
+ LRESULT CALLBACK WndProc (HWND, UINT, WPARAM, LPARAM);
+ class event_handler_thread;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class canvas : public rectangle
+ {
+ public:
+ struct pixel
+ {
+ unsigned char blue;
+ unsigned char green;
+ unsigned char red;
+ };
+
+ ~canvas() { }
+
+ inline pixel* operator[] (
+ unsigned long row
+ ) const
+ {
+ DLIB_ASSERT(row < height(),
+ "\tpixel* canvas::operator[]"
+ << "\n\tyou have to give a row that is less than the height()"
+ << "\n\tthis: " << this
+ << "\n\trow: " << row
+ << "\n\theight(): " << height()
+ );
+ unsigned char* temp = bits + row_width*row;
+ return reinterpret_cast<pixel*>(temp);
+ }
+
+ void fill (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) const;
+
+ private:
+
+ friend LRESULT CALLBACK gui_core_kernel_1_globals::WndProc (HWND, UINT, WPARAM, LPARAM);
+
+ canvas (
+ unsigned char* bits_,
+ unsigned long padding_,
+ unsigned long left_,
+ unsigned long top_,
+ unsigned long right_,
+ unsigned long bottom_
+ ) :
+ rectangle(left_,top_,right_,bottom_),
+ bits(bits_),
+ width_(width()),
+ height_(height()),
+ row_width(width_*3+padding_)
+ {}
+
+ // restricted functions
+ canvas(); // normal constructor
+ canvas(canvas&); // copy constructor
+ canvas& operator=(canvas&); // assignment operator
+
+ unsigned char* const bits;
+ const unsigned long width_;
+ const unsigned long height_;
+ const unsigned long row_width;
+ };
+
+ template <>
+ struct pixel_traits<canvas::pixel>
+ {
+ constexpr static bool rgb = true;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static long num = 3;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void put_on_clipboard (
+ const std::string& str
+ );
+
+ void put_on_clipboard (
+ const std::wstring& str
+ );
+
+ void put_on_clipboard (
+ const dlib::ustring& str
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void get_from_clipboard (
+ std::string& str
+ );
+
+ void get_from_clipboard (
+ std::wstring& str
+ );
+
+ void get_from_clipboard (
+ dlib::ustring& str
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class base_window
+ {
+ friend LRESULT CALLBACK gui_core_kernel_1_globals::WndProc (HWND, UINT, WPARAM, LPARAM);
+ dlib::shared_ptr_thread_safe<gui_core_kernel_1_globals::event_handler_thread> globals;
+
+ HWND hwnd;
+ DWORD style;
+ bool has_been_destroyed;
+
+ // This is true if the mouse is in this window. false otherwise.
+ // also note that this variable is only accessed from the event handling thread
+ // (except for being initialized below in the constructor, but that is inside
+ // the window_table mutex so it doesn't matter).
+ bool mouse_in;
+
+ // this is a copy of the last inputs we sent to the on_mouse_move() event.
+ long prevx;
+ long prevy;
+ unsigned long prev_state;
+
+ protected:
+ const rmutex& wm;
+
+ public:
+
+ base_window (
+ bool resizable = true,
+ bool undecorated = false
+ );
+
+ virtual ~base_window (
+ );
+
+ void close_window (
+ );
+
+ bool is_closed (
+ ) const;
+
+ void set_title (
+ const std::string& title
+ );
+
+ void set_title (
+ const std::wstring& title
+ );
+
+ void set_title (
+ const ustring& title
+ );
+
+ virtual void show (
+ );
+
+ virtual void hide(
+ );
+
+ void set_size (
+ int width_,
+ int height_
+ );
+
+ void set_pos (
+ long x_,
+ long y_
+ );
+
+ void get_pos (
+ long& x_,
+ long& y_
+ );
+
+ void get_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const;
+
+ void get_display_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const;
+
+ void invalidate_rectangle (
+ const rectangle& rect
+ );
+
+ void trigger_user_event (
+ void* p,
+ int i
+ );
+
+ void wait_until_closed (
+ ) const;
+
+ void set_im_pos (
+ long x_,
+ long y_
+ );
+
+ enum on_close_return_code
+ {
+ DO_NOT_CLOSE_WINDOW,
+ CLOSE_WINDOW
+ };
+
+ enum mouse_state_masks
+ {
+ NONE = 0,
+ LEFT = 1,
+ RIGHT = 2,
+ MIDDLE = 4,
+ SHIFT = 8,
+ CONTROL = 16
+ };
+
+ enum keyboard_state_masks
+ {
+ KBD_MOD_NONE = 0,
+ KBD_MOD_SHIFT = 1,
+ KBD_MOD_CONTROL = 2,
+ KBD_MOD_ALT = 4,
+ KBD_MOD_META = 8,
+ KBD_MOD_CAPS_LOCK = 16,
+ KBD_MOD_NUM_LOCK = 32,
+ KBD_MOD_SCROLL_LOCK = 64
+ };
+
+ enum non_printable_keyboard_keys
+ {
+ KEY_BACKSPACE,
+ KEY_SHIFT,
+ KEY_CTRL,
+ KEY_ALT,
+ KEY_PAUSE,
+ KEY_CAPS_LOCK,
+ KEY_ESC,
+ KEY_PAGE_UP,
+ KEY_PAGE_DOWN,
+ KEY_END,
+ KEY_HOME,
+ KEY_LEFT, // This is the left arrow key
+ KEY_RIGHT, // This is the right arrow key
+ KEY_UP, // This is the up arrow key
+ KEY_DOWN, // This is the down arrow key
+ KEY_INSERT,
+ KEY_DELETE,
+ KEY_SCROLL_LOCK,
+
+ // Function Keys
+ KEY_F1,
+ KEY_F2,
+ KEY_F3,
+ KEY_F4,
+ KEY_F5,
+ KEY_F6,
+ KEY_F7,
+ KEY_F8,
+ KEY_F9,
+ KEY_F10,
+ KEY_F11,
+ KEY_F12
+ };
+
+ protected:
+
+ virtual on_close_return_code on_window_close(
+ ){return CLOSE_WINDOW;}
+
+ virtual void on_user_event (
+ void* ,
+ int
+ ){}
+
+ virtual void on_window_resized(
+ ){}
+
+ virtual void on_window_moved(
+ ){}
+
+ virtual void on_mouse_down (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long ,
+ bool
+ ){}
+
+ virtual void on_mouse_up (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_move (
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_leave (
+ ){}
+
+ virtual void on_mouse_enter (
+ ){}
+
+ virtual void on_wheel_up (
+ unsigned long
+ ){}
+
+ virtual void on_wheel_down (
+ unsigned long
+ ){}
+
+ virtual void on_focus_gained (
+ ){}
+
+ virtual void on_focus_lost (
+ ){}
+
+ virtual void on_keydown (
+ unsigned long ,
+ bool ,
+ unsigned long
+ ){}
+
+ virtual void on_string_put (
+ const std::wstring&
+ ){}
+
+ private:
+
+ virtual void paint (
+ const canvas&
+ ) =0;
+
+ base_window(base_window&); // copy constructor
+ base_window& operator=(base_window&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#ifdef NO_MAKEFILE
+#include "gui_core_kernel_1.cpp"
+#endif
+
+#endif // DLIB_GUI_CORE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp b/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp
new file mode 100644
index 000000000..feca4bf22
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp
@@ -0,0 +1,1996 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEL_2_CPp_
+#define DLIB_GUI_CORE_KERNEL_2_CPp_
+#include "../platform.h"
+
+#ifdef POSIX
+
+#include "gui_core_kernel_2.h"
+
+#include <cmath>
+#include <cstring>
+#include <iostream>
+#include <vector>
+#include <set>
+
+#include <X11/Xatom.h>
+#include <X11/Xlib.h>
+#include <X11/Xutil.h>
+#include <X11/keysym.h>
+#include <X11/Xlocale.h>
+#include <X11/XKBlib.h>
+
+#include <poll.h>
+
+#include "../assert.h"
+#include "../queue.h"
+#include "../sync_extension.h"
+#include "../logger.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gui_core_kernel_2_globals
+ {
+ void init_keyboard_mod_masks();
+ struct user_event_type
+ {
+ Window w;
+ void* p;
+ int i;
+ };
+
+ typedef sync_extension<queue<user_event_type,memory_manager<char>::kernel_1b>::kernel_2a_c>::kernel_1a queue_of_user_events;
+
+ typedef sync_extension<binary_search_tree<Window,base_window*>::kernel_1a>::kernel_1a
+ window_table_type;
+
+ // ----------------------------------------------------------------------------------------
+
+ const std::shared_ptr<dlib::mutex>& global_mutex()
+ {
+ static std::shared_ptr<dlib::mutex> m(new dlib::mutex);
+ return m;
+ }
+
+ class event_handler_thread : public threaded_object
+ {
+ public:
+
+ enum et_state
+ {
+ uninitialized,
+ initialized,
+ failure_to_init
+ };
+
+ et_state status;
+ logger dlog;
+
+
+ int depth;
+ Display* disp;
+ XIM xim;
+ XIMStyle xim_style;
+ Screen* screen;
+
+ Atom delete_window;
+ Window exit_window;
+ std::wstring clipboard;
+
+ int alt_mask;
+ int meta_mask;
+ int num_lock_mask;
+ int scroll_lock_mask;
+
+ // the mutex in this object is the global mutex used to protect everything
+ // in the gui_core and gui_widgets components.
+ window_table_type window_table;
+
+ rsignaler window_close_signaler;
+ rsignaler et_signaler;
+
+ queue_of_user_events user_events;
+ queue_of_user_events user_events_temp;
+
+ std::shared_ptr<dlib::mutex> reference_to_global_mutex;
+
+ event_handler_thread(
+ ) :
+ dlog("dlib.gui_core"),
+ depth(0),
+ disp(0),
+ xim(0),
+ screen(0),
+ alt_mask(0),
+ meta_mask(0),
+ num_lock_mask(0),
+ scroll_lock_mask(0),
+ window_close_signaler(window_table.get_mutex()),
+ et_signaler(window_table.get_mutex()),
+ reference_to_global_mutex(global_mutex())
+ {
+ auto_mutex M(window_table.get_mutex());
+
+ status = uninitialized;
+
+ // start up the event handler thread
+ start();
+
+ // wait for the event thread to get up and running
+ while (status == uninitialized)
+ et_signaler.wait();
+
+ if (status == failure_to_init)
+ throw gui_error("Failed to initialize X11 resources");
+
+ init_keyboard_mod_masks();
+ }
+
+ ~event_handler_thread ()
+ {
+
+ if (is_alive())
+ {
+
+ if (status != failure_to_init)
+ {
+ XConfigureEvent event;
+ event.type = ConfigureNotify;
+ event.send_event = True;
+ event.display = disp;
+ event.window = exit_window;
+ event.x = 1;
+ XFlush(disp);
+ XPutBackEvent(disp,reinterpret_cast<XEvent*>(&event));
+ XFlush(disp);
+
+ // This should cause XNextEvent() to unblock so that it will see
+ // this ConfigureNotify event we are putting onto the event queue.
+ XSendEvent(disp,exit_window,False,0,reinterpret_cast<XEvent*>(&event));
+ XFlush(disp);
+
+ wait();
+
+ if (xim != NULL)
+ {
+ XCloseIM(xim);
+ }
+
+ XCloseDisplay(disp);
+
+
+ }
+ else
+ {
+
+ wait();
+ }
+ }
+
+
+ }
+
+ private:
+
+ void thread (
+ )
+ {
+ using namespace std;
+ using namespace dlib;
+ try
+ {
+
+ // You are supposed to call this if using XLib in a threaded program. Note
+ // however that at one point I noticed that calling this causes a dead-lock
+ // when using XIM. But I can't reproduce that anymore and not calling it
+ // sometimes causes XCloseDisplay() to hang.
+ if (XInitThreads() == 0)
+ {
+ dlog << LFATAL << "Unable to initialize threading support.";
+ // signal that an error has occurred
+ window_table.get_mutex().lock();
+ status = failure_to_init;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+ return;
+ }
+
+ window_table.get_mutex().lock();
+ disp = XOpenDisplay(NULL);
+ window_table.get_mutex().unlock();
+ if (disp == 0)
+ {
+ window_table.get_mutex().lock();
+ disp = XOpenDisplay(":0.0");
+ window_table.get_mutex().unlock();
+ if (disp == 0)
+ {
+ dlog << LFATAL << "Unable to connect to the X display.";
+ // signal that an error has occurred
+ window_table.get_mutex().lock();
+ status = failure_to_init;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+ return;
+ }
+ }
+
+ window_table.get_mutex().lock();
+ screen = DefaultScreenOfDisplay(disp);
+ depth = DefaultDepthOfScreen(screen);
+ delete_window = XInternAtom(disp,"WM_DELETE_WINDOW",1);
+ window_table.get_mutex().unlock();
+
+ xim = NULL;
+ // I'm disabling XIM usage all together because calling XSetICValues()
+ // in set_im_pos() randomly hangs the application (on Ubuntu 13.10 at
+ // least).
+ /*
+ window_table.get_mutex().lock();
+ std::string saved_locale(setlocale (LC_CTYPE, NULL));
+ if (setlocale( LC_CTYPE, "" ) && XSupportsLocale() && XSetLocaleModifiers(""))
+ xim = XOpenIM(disp, NULL, NULL, NULL);
+ else
+ setlocale( LC_CTYPE, saved_locale.c_str() );
+ window_table.get_mutex().unlock();
+ */
+ if (xim)
+ {
+ const static XIMStyle preedit_styles[] =
+ {XIMPreeditPosition, XIMPreeditNothing, XIMPreeditNone, 0};
+ const static XIMStyle status_styles[] =
+ {XIMStatusNothing, XIMStatusNone, 0};
+ xim_style = 0;
+
+ XIMStyles *xim_styles;
+ window_table.get_mutex().lock();
+
+ XGetIMValues (xim, XNQueryInputStyle, &xim_styles, (const void*)NULL);
+ window_table.get_mutex().unlock();
+ std::set<XIMStyle> xims;
+ for (int i = 0; i < xim_styles->count_styles; ++i){
+ xims.insert(xim_styles->supported_styles[i]);
+ }
+ for (int j = 0; status_styles[j]; ++j){
+ for (int i = 0; preedit_styles[i]; ++i){
+ xim_style = (status_styles[j] | preedit_styles[i]);
+ if (xims.count(xim_style)) break;
+ }
+ if (xim_style) break;
+ }
+ XFree(xim_styles);
+ }
+
+ // make this window just so we can send messages to it and trigger
+ // events in the event thread
+ XSetWindowAttributes attr;
+ window_table.get_mutex().lock();
+ exit_window = XCreateWindow(
+ disp,
+ DefaultRootWindow(disp),
+ 0,
+ 0,
+ 10, // this is the default width of a window
+ 10, // this is the default width of a window
+ 0,
+ depth,
+ InputOutput,
+ CopyFromParent,
+ 0,
+ &attr
+ );
+ window_table.get_mutex().unlock();
+
+ // signal that the event thread is now up and running
+ window_table.get_mutex().lock();
+ status = initialized;
+ et_signaler.broadcast();
+ window_table.get_mutex().unlock();
+
+ // start the event handler
+ event_handler();
+ }
+ catch (std::exception& e)
+ {
+ cout << "\nEXCEPTION THROWN: \n" << e.what() << endl;
+ abort();
+ }
+ catch (...)
+ {
+ cout << "UNKNOWN EXCEPTION THROWN.\n" << endl;
+ abort();
+ }
+ }
+
+ void event_handler();
+ void init_keyboard_mod_masks();
+ };
+
+ struct x11_base_windowstuff
+ {
+ Window hwnd;
+ Time last_click_time;
+ XIC xic;
+ XFontSet fs;
+ std::shared_ptr<event_handler_thread> globals;
+ };
+
+ // Do all this just to make sure global_mutex() is initialized at program start
+ // and thus hopefully before any threads have the chance to startup and call
+ // global_data() concurrently.
+ struct call_global_mutex { call_global_mutex() { global_mutex(); } };
+ static call_global_mutex call_global_mutex_instance;
+
+ const std::shared_ptr<event_handler_thread>& global_data()
+ {
+ auto_mutex M(*global_mutex());
+ static std::shared_ptr<event_handler_thread> p;
+ if (p.get() == 0)
+ p.reset(new event_handler_thread());
+ return p;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ Bool XCheckIfEventPredicate (
+ Display* ,
+ XEvent* event,
+ XPointer arg
+ )
+ /*!
+ ensures
+ - if (event is an Expose event for the window pointed to by arg) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ {
+ if (event->type == Expose)
+ {
+ XExposeEvent* e = reinterpret_cast<XExposeEvent*>(event);
+ Window* win= reinterpret_cast<Window*>(arg);
+ if (e->window == *win)
+ {
+ return 1;
+ }
+ }
+ return 0;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ static bool map_keys (
+ KeySym keycode,
+ bool ,
+ bool ,
+ unsigned long& result,
+ bool& is_printable
+ )
+ /*!
+ requires
+ - if (shift was down for this key) then
+ - shift == true
+ - if (caps lock was on for this key) then
+ - caps == true
+ - keycode == the keycode from windows that we are to process
+ - keycode < keyboard_keys_size
+ ensures
+ - if (this key should be ignored) then
+ - returns false
+ - else
+ - returns true
+ - #is_printable == true if result is a printable ascii character
+ - #result == the keycode converted into the proper number to tbe
+ returned by the event handler.
+ !*/
+ {
+ is_printable = true;
+ if ((keycode <= 'z' && keycode >= 'a') ||
+ (keycode <= 'Z' && keycode >= 'A') ||
+ (keycode <= '9' && keycode >= '0'))
+ {
+ result = keycode;
+ }
+ else
+ {
+ is_printable = false;
+ switch (keycode)
+ {
+ case XK_Home: result = base_window::KEY_HOME; break;
+ case XK_Left: result = base_window::KEY_LEFT; break;
+ case XK_Right: result = base_window::KEY_RIGHT; break;
+ case XK_Down: result = base_window::KEY_DOWN; break;
+ case XK_Up: result = base_window::KEY_UP; break;
+ case XK_Prior: result = base_window::KEY_PAGE_UP; break;
+ case XK_Next: result = base_window::KEY_PAGE_DOWN; break;
+ case XK_End: result = base_window::KEY_END; break;
+ case XK_Escape: result = base_window::KEY_ESC; break;
+
+ case XK_KP_Delete: result = base_window::KEY_DELETE; break;
+ case XK_KP_Prior: result = base_window::KEY_PAGE_UP; break;
+ case XK_KP_Next: result = base_window::KEY_PAGE_DOWN; break;
+
+
+ case XK_F1: result = base_window::KEY_F1; break;
+ case XK_F2: result = base_window::KEY_F2; break;
+ case XK_F3: result = base_window::KEY_F3; break;
+ case XK_F4: result = base_window::KEY_F4; break;
+ case XK_F5: result = base_window::KEY_F5; break;
+ case XK_F6: result = base_window::KEY_F6; break;
+ case XK_F7: result = base_window::KEY_F7; break;
+ case XK_F8: result = base_window::KEY_F8; break;
+ case XK_F9: result = base_window::KEY_F9; break;
+ case XK_F10: result = base_window::KEY_F10; break;
+ case XK_F11: result = base_window::KEY_F11; break;
+ case XK_F12: result = base_window::KEY_F12; break;
+
+
+ case XK_Shift_L: result = base_window::KEY_SHIFT; break;
+ case XK_Shift_R: result = base_window::KEY_SHIFT; break;
+ case XK_Control_L: result = base_window::KEY_CTRL; break;
+ case XK_Control_R: result = base_window::KEY_CTRL; break;
+ case XK_Caps_Lock: result = base_window::KEY_CAPS_LOCK; break;
+ case XK_Alt_L: result = base_window::KEY_ALT; break;
+ case XK_Alt_R: result = base_window::KEY_ALT; break;
+
+
+ case XK_BackSpace: result = base_window::KEY_BACKSPACE; break;
+ case XK_Delete: result = base_window::KEY_DELETE; break;
+ case XK_Scroll_Lock: result = base_window::KEY_SCROLL_LOCK; break;
+ case XK_Pause: result = base_window::KEY_PAUSE; break;
+ case XK_Insert: result = base_window::KEY_INSERT; break;
+ case XK_KP_Insert: result = base_window::KEY_INSERT; break;
+
+
+
+
+ case XK_exclam:
+ is_printable = true;
+ result = '!'; break;
+ case XK_quotedbl:
+ is_printable = true;
+ result = '"'; break;
+ case XK_numbersign:
+ is_printable = true;
+ result = '#'; break;
+ case XK_dollar:
+ is_printable = true;
+ result = '$'; break;
+ case XK_percent:
+ is_printable = true;
+ result = '%'; break;
+ case XK_ampersand:
+ is_printable = true;
+ result = '&'; break;
+ case XK_apostrophe:
+ is_printable = true;
+ result = '\''; break;
+ case XK_parenleft:
+ is_printable = true;
+ result = '('; break;
+ case XK_parenright:
+ is_printable = true;
+ result = ')'; break;
+ case XK_asterisk:
+ is_printable = true;
+ result = '*'; break;
+ case XK_plus:
+ is_printable = true;
+ result = '+'; break;
+ case XK_comma:
+ is_printable = true;
+ result = ','; break;
+ case XK_minus:
+ is_printable = true;
+ result = '-'; break;
+ case XK_period:
+ is_printable = true;
+ result = '.'; break;
+ case XK_slash:
+ is_printable = true;
+ result = '/'; break;
+ case XK_colon:
+ is_printable = true;
+ result = ':'; break;
+ case XK_semicolon:
+ is_printable = true;
+ result = ';'; break;
+ case XK_less:
+ is_printable = true;
+ result = '<'; break;
+ case XK_equal:
+ is_printable = true;
+ result = '='; break;
+ case XK_greater:
+ is_printable = true;
+ result = '>'; break;
+ case XK_question:
+ is_printable = true;
+ result = '?'; break;
+ case XK_at:
+ is_printable = true;
+ result = '@'; break;
+ case XK_grave:
+ is_printable = true;
+ result = '`'; break;
+ case XK_underscore:
+ is_printable = true;
+ result = '_'; break;
+ case XK_asciicircum:
+ is_printable = true;
+ result = '^'; break;
+ case XK_bracketleft:
+ is_printable = true;
+ result = '['; break;
+ case XK_backslash:
+ is_printable = true;
+ result = '\\'; break;
+ case XK_bracketright:
+ is_printable = true;
+ result = ']'; break;
+ case XK_asciitilde:
+ is_printable = true;
+ result = '~'; break;
+ case XK_braceleft:
+ is_printable = true;
+ result = '{'; break;
+ case XK_bar:
+ is_printable = true;
+ result = '|'; break;
+ case XK_braceright:
+ is_printable = true;
+ result = '}'; break;
+
+
+
+
+ case XK_space:
+ is_printable = true;
+ result = ' '; break;
+ case XK_Return:
+ is_printable = true;
+ result = '\n'; break;
+ case XK_Tab:
+ is_printable = true;
+ result = '\t'; break;
+ case XK_KP_Divide:
+ is_printable = true;
+ result = '/'; break;
+ case XK_KP_Decimal:
+ is_printable = true;
+ result = '.'; break;
+ case XK_KP_Subtract:
+ is_printable = true;
+ result = '-'; break;
+ case XK_KP_Add:
+ is_printable = true;
+ result = '+'; break;
+ case XK_KP_Multiply:
+ is_printable = true;
+ result = '*'; break;
+ case XK_KP_Equal:
+ is_printable = true;
+ result = '='; break;
+
+ case XK_KP_0:
+ is_printable = true;
+ result = '0'; break;
+ case XK_KP_1:
+ is_printable = true;
+ result = '1'; break;
+ case XK_KP_2:
+ is_printable = true;
+ result = '2'; break;
+ case XK_KP_3:
+ is_printable = true;
+ result = '3'; break;
+ case XK_KP_4:
+ is_printable = true;
+ result = '4'; break;
+ case XK_KP_5:
+ is_printable = true;
+ result = '5'; break;
+ case XK_KP_6:
+ is_printable = true;
+ result = '6'; break;
+ case XK_KP_7:
+ is_printable = true;
+ result = '7'; break;
+ case XK_KP_8:
+ is_printable = true;
+ result = '8'; break;
+ case XK_KP_9:
+ is_printable = true;
+ result = '9'; break;
+
+ default:
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ void event_handler_thread::
+ event_handler (
+ )
+ /*!
+ ensures
+ - will handle all events and event dispatching
+ !*/
+ {
+ try
+ {
+ std::vector<unsigned char> bitmap_buffer;
+ bool quit_event_loop = false;
+ while (quit_event_loop == false)
+ {
+ // get a lock on the window_table's mutex
+ auto_mutex window_table_locker(window_table.get_mutex());
+
+ XEvent ev;
+ memset(&ev, 0, sizeof(ev));
+ while (XPending(disp) == 0){
+ window_table.get_mutex().unlock();
+ // wait until receiving X11 next event
+ struct pollfd pfd;
+ pfd.fd = ConnectionNumber(disp);
+ pfd.events = POLLIN | POLLPRI;
+ poll(&pfd, 1, -1);
+
+ window_table.get_mutex().lock();
+ }
+ XNextEvent(disp,&ev);
+
+ // pass events to input method.
+ // if this event is needed by input method, XFilterEvent returns True
+ if (XFilterEvent(&ev, None) == True){
+ continue;
+ }
+
+ // if this event is for one of the windows in the window_table
+ // then get that window out of the table and put it into win.
+ XAnyEvent* _ae = reinterpret_cast<XAnyEvent*>(&ev);
+ base_window** win_ = window_table[_ae->window];
+ base_window* win = 0;
+ if (win_)
+ win = *win_;
+
+
+ // ignore messages for unmapped windows
+ if (ev.type != MapNotify && win != 0)
+ {
+ if (win->is_mapped == false)
+ continue;
+ }
+
+
+ switch (ev.type)
+ {
+
+ case SelectionRequest:
+ {
+ Atom a_ct = XInternAtom(disp, "COMPOUND_TEXT", False);
+ XSelectionRequestEvent* req = reinterpret_cast<XSelectionRequestEvent*>(&ev.xselectionrequest);
+ XEvent respond;
+
+ if (req->target == XA_STRING)
+ {
+ XChangeProperty (disp,
+ req->requestor,
+ req->property,
+ XA_STRING,
+ 8,
+ PropModeReplace,
+ reinterpret_cast<const unsigned char*>(convert_wstring_to_mbstring(clipboard).c_str()),
+ clipboard.size()+1);
+ respond.xselection.property=req->property;
+ }
+ else if (req->target == a_ct)
+ {
+ XChangeProperty (disp,
+ req->requestor,
+ req->property,
+ a_ct,
+ sizeof(wchar_t)*8,
+ PropModeReplace,
+ reinterpret_cast<const unsigned char*>(clipboard.c_str()),
+ clipboard.size()+1);
+ respond.xselection.property=req->property;
+ }
+ else
+ {
+ respond.xselection.property= None;
+ }
+ respond.xselection.type= SelectionNotify;
+ respond.xselection.display= req->display;
+ respond.xselection.requestor= req->requestor;
+ respond.xselection.selection=req->selection;
+ respond.xselection.target= req->target;
+ respond.xselection.time = req->time;
+ XSendEvent (disp, req->requestor,0,0,&respond);
+ XFlush (disp);
+
+ } break;
+
+ case MapNotify:
+ {
+ if (win == 0)
+ break;
+
+ win->is_mapped = true;
+
+ if (win->resizable == false)
+ {
+ XSizeHints* hints = XAllocSizeHints();
+ hints->flags = PMinSize|PMaxSize;
+ hints->min_width = win->width;
+ hints->max_width = win->width;
+ hints->max_height = win->height;
+ hints->min_height = win->height;
+ XSetNormalHints(disp,win->x11_stuff.hwnd,hints);
+ XFree(hints);
+ }
+
+
+
+ if (win->has_been_resized)
+ {
+ XResizeWindow(disp,win->x11_stuff.hwnd,win->width,win->height);
+ win->has_been_resized = false;
+ win->on_window_resized();
+ }
+
+ if (win->has_been_moved)
+ {
+ XMoveWindow(disp,win->x11_stuff.hwnd,win->x,win->y);
+ win->has_been_moved = false;
+ win->on_window_moved();
+ }
+ XFlush(disp);
+
+
+ } break;
+
+
+ case KeyPress:
+ {
+ XKeyPressedEvent* e = reinterpret_cast<XKeyPressedEvent*>(&ev);
+
+ if (win == 0)
+ break;
+
+ unsigned long state = 0;
+ bool shift = ((e->state & ShiftMask)!=0);
+ bool ctrl = ((e->state & ControlMask)!=0);
+ bool caps = ((e->state & LockMask)!=0);
+ if(shift)
+ state |= base_window::KBD_MOD_SHIFT;
+ if(ctrl)
+ state |= base_window::KBD_MOD_CONTROL;
+ if(caps)
+ state |= base_window::KBD_MOD_CAPS_LOCK;
+ if((e->state & alt_mask)!=0)
+ state |= base_window::KBD_MOD_ALT;
+ if((e->state & meta_mask)!=0)
+ state |= base_window::KBD_MOD_META;
+ if((e->state & num_lock_mask)!=0)
+ state |= base_window::KBD_MOD_NUM_LOCK;
+ if((e->state & scroll_lock_mask)!=0)
+ state |= base_window::KBD_MOD_SCROLL_LOCK;
+
+ KeySym key;
+ Status status;
+
+ if (win->x11_stuff.xic) {
+ std::wstring wstr;
+ wstr.resize(2);
+ int len = XwcLookupString(win->x11_stuff.xic,e,&wstr[0],wstr.size(),&key,&status);
+ if (status == XBufferOverflow){
+ wstr.resize(len);
+ len = XwcLookupString(win->x11_stuff.xic,e,&wstr[0],wstr.size(),&key,&status);
+ }
+ if (status == XLookupChars){
+ win->on_string_put(wstr);
+ }
+ } else {
+ char buffer[2];
+ XLookupString(e, buffer, sizeof(buffer), &key, NULL);
+ status = XLookupKeySym;
+ }
+
+ if (status == XLookupKeySym || status == XLookupBoth){
+
+ bool is_printable;
+ unsigned long result;
+
+ if (map_keys(key,shift,caps,result,is_printable))
+ {
+ // signal the keyboard event
+ win->on_keydown(result,is_printable,state);
+ }
+ }
+
+ } break;
+
+ case FocusIn:
+ {
+ if (win == 0)
+ break;
+
+ // signal the focus event
+ win->on_focus_gained();
+ } break;
+
+ case FocusOut:
+ {
+ if (win == 0)
+ break;
+
+ // signal the focus event
+ win->on_focus_lost();
+ } break;
+
+ case ButtonPress:
+ case ButtonRelease:
+ {
+ XButtonEvent* e = reinterpret_cast<XButtonEvent*>(&ev);
+
+ if (win == 0)
+ break;
+
+ unsigned long btn = base_window::NONE;
+ if (e->button == Button1)
+ btn = base_window::LEFT;
+ else if (e->button == Button3)
+ btn = base_window::RIGHT;
+ else if (e->button == Button2)
+ btn = base_window::MIDDLE;
+
+ unsigned long state = 0;
+ if (e->state & ControlMask)
+ state |= base_window::CONTROL;
+ if (e->state & Button1Mask)
+ state |= base_window::LEFT;
+ if (e->state & Button2Mask)
+ state |= base_window::MIDDLE;
+ if (e->state & Button3Mask)
+ state |= base_window::RIGHT;
+ if (e->state & ShiftMask)
+ state |= base_window::SHIFT;
+
+ // only send the event if this is a button we support
+ if (btn != (unsigned long)base_window::NONE)
+ {
+
+
+ if (ev.type == ButtonPress)
+ {
+ bool is_double_click = false;
+ if (win->last_click_button == btn &&
+ std::abs((long)win->last_click_x - (long)e->x) < 5 &&
+ std::abs((long)win->last_click_y - (long)e->y) < 5 &&
+ e->time - win->x11_stuff.last_click_time <= 400)
+ {
+ // this is a double click
+ is_double_click = true;
+ // set this to make sure the next click can't be
+ // interpreted as a double click
+ win->last_click_button = base_window::NONE;
+ }
+ else
+ {
+ win->last_click_button = btn;
+ win->last_click_x = e->x;
+ win->last_click_y = e->y;
+ win->x11_stuff.last_click_time = e->time;
+ }
+
+ // remove the clicked button from the state
+ state &= (~btn);
+ win->on_mouse_down(btn,state,e->x,e->y,is_double_click);
+
+ }
+ else
+ {
+ // remove the clicked button from the state
+ state &= (~btn);
+ win->on_mouse_up(btn,state,e->x,e->y);
+ }
+ }
+ else if (e->button == Button4 && ev.type == ButtonPress)
+ {
+ win->on_wheel_up(state);
+ }
+ else if (e->button == Button5 && ev.type == ButtonPress)
+ {
+ win->on_wheel_down(state);
+ }
+
+ } break;
+
+ case LeaveNotify:
+ {
+ if (win == 0)
+ break;
+
+ win->on_mouse_leave();
+
+ } break;
+
+ case EnterNotify:
+ {
+ if (win == 0)
+ break;
+
+ win->on_mouse_enter();
+ } break;
+
+ case MotionNotify:
+ {
+ XMotionEvent* e = reinterpret_cast<XMotionEvent*>(&ev);
+
+ if (win == 0)
+ break;
+
+ unsigned long state = 0;
+ if (e->state & ControlMask)
+ state |= base_window::CONTROL;
+ if (e->state & Button1Mask)
+ state |= base_window::LEFT;
+ if (e->state & Button2Mask)
+ state |= base_window::MIDDLE;
+ if (e->state & Button3Mask)
+ state |= base_window::RIGHT;
+ if (e->state & ShiftMask)
+ state |= base_window::SHIFT;
+
+ win->on_mouse_move(state,e->x,e->y);
+
+ } break;
+
+ case ConfigureNotify:
+ {
+ XConfigureEvent* e = reinterpret_cast<XConfigureEvent*>(&ev);
+ if (e->window == exit_window)
+ {
+ // this is the signal to quit the event handler
+ quit_event_loop = true;
+ break;
+ }
+
+ if (win == 0)
+ break;
+
+ if (win->width != e->width ||
+ win->height != e->height ||
+ win->has_been_resized)
+ {
+ win->has_been_resized = false;
+ // this is a resize
+ win->width = e->width;
+ win->height = e->height;
+ win->on_window_resized();
+ }
+ if (win->x != e->x ||
+ win->y != e->y ||
+ win->has_been_moved)
+ {
+ win->has_been_moved = false;
+ // this is a move
+ win->x = e->x;
+ win->y = e->y;
+ win->on_window_moved();
+ }
+
+ } break;
+
+ case ClientMessage:
+ {
+ XClientMessageEvent* e = reinterpret_cast<XClientMessageEvent*>(&ev);
+ if ((Atom)e->data.l[0] == delete_window)
+ {
+ if (win == 0)
+ break;
+
+
+ if (win->on_window_close() == base_window::DO_NOT_CLOSE_WINDOW)
+ {
+ DLIB_ASSERT(win->has_been_destroyed == false,
+ "\tYou called close_window() inside the on_window_close() event but"
+ << "\n\tthen returned DO_NOT_CLOSE_WINDOW. You can do one or the other but not both."
+ << "\n\tthis: " << win
+ );
+ // the client has decided not to close the window
+ // after all
+ }
+ else
+ {
+ if (window_table[e->window])
+ {
+ window_table.destroy(e->window);
+ XDestroyWindow(disp,e->window);
+ win->has_been_destroyed = true;
+ window_close_signaler.broadcast();
+ }
+ else
+ {
+ // in this case the window must have self destructed by
+ // calling delete this; so we don't have to do anything.
+ }
+ }
+ }
+ } break;
+
+ case Expose:
+ {
+ XExposeEvent* e = reinterpret_cast<XExposeEvent*>(&ev);
+
+ if (win == 0)
+ break;
+
+ // take all the expose events for this window out
+ XEvent etemp;
+ int x = e->x;
+ int y = e->y;
+ int width = e->width;
+ int height = e->height;
+
+
+
+ // What we are doing here with this loop is we are combining
+ // all of the Expose events for this window that are
+ // currently in the queue.
+ while (XCheckIfEvent(disp,&etemp,XCheckIfEventPredicate,reinterpret_cast<XPointer>(&(e->window))))
+ {
+ XExposeEvent* e2 = reinterpret_cast<XExposeEvent*>(&etemp);
+ if (e2->x < x)
+ {
+ width += x - e2->x;
+ x = e2->x;
+ }
+ if (e2->y < y)
+ {
+ height += y - e2->y;
+ y = e2->y;
+ }
+ if (e2->width + e2->x > width + x)
+ {
+ width = e2->width + e2->x - x;
+ }
+ if (e2->height + e2->y > height + y)
+ {
+ height = e2->height + e2->y - y;
+ }
+ }
+
+ // I'm not sure if this sort of thing can happen but
+ // if it does then just ignore this entire event.
+ if (width == 0 || height == 0)
+ {
+ break;
+ }
+
+ if (bitmap_buffer.size() < static_cast<unsigned long>(width*height*4))
+ bitmap_buffer.resize(width*height*4);
+
+ unsigned char* const bitmap = &bitmap_buffer[0];
+ unsigned char* const end = bitmap + width*height*4;
+
+ unsigned char* temp;
+ canvas c(bitmap,x,y,x+width-1,y+height-1);
+
+
+ win->paint(c);
+
+ // the user might have called win->close_window() and if they did
+ // then just stop right here. We don't want to paint the window.
+ if (win->has_been_destroyed)
+ break;
+
+ // if the color depth we are working with isn't 24bits then we need
+ // to transform our image into whatever it is supposed to be.
+ if (depth != 24)
+ {
+ // convert this image into an 8 bit image
+ unsigned int red_bits = 0;
+ unsigned int green_bits = 0;
+ unsigned int blue_bits = 0;
+ if (depth != 16)
+ {
+ unsigned int bits = depth/3;
+ unsigned int extra = depth%3;
+ red_bits = bits;
+ green_bits = bits;
+ blue_bits = bits;
+ if (extra)
+ {
+ ++red_bits;
+ --extra;
+ }
+ if (extra)
+ {
+ ++green_bits;
+ }
+ }
+ else if (depth == 16)
+ {
+ red_bits = 5;
+ green_bits = 6;
+ blue_bits = 5;
+ }
+
+ if (depth == 16)
+ {
+ temp = bitmap;
+ unsigned char *red, *green, *blue;
+ while (temp != end)
+ {
+ blue = temp;
+ ++temp;
+ green = temp;
+ ++temp;
+ red = temp;
+ ++temp;
+ ++temp;
+
+ const unsigned long r = static_cast<unsigned long>(*red)>>(8-red_bits);
+ const unsigned long g = static_cast<unsigned long>(*green)>>(8-green_bits);
+ const unsigned long b = static_cast<unsigned long>(*blue)>>(8-blue_bits);
+
+ unsigned long color = (r<<(depth-red_bits))| (g<<(depth-red_bits-green_bits))| b;
+
+ *blue = (color>>0)&0xFF;
+ *green = (color>>8)&0xFF;
+ }
+ }
+ else if (depth < 24)
+ {
+ temp = bitmap;
+ unsigned char *red, *green, *blue;
+ while (temp != end)
+ {
+ blue = temp;
+ ++temp;
+ green = temp;
+ ++temp;
+ red = temp;
+ ++temp;
+ ++temp;
+
+ const unsigned long r = static_cast<unsigned long>(*red)>>(8-red_bits);
+ const unsigned long g = static_cast<unsigned long>(*green)>>(8-green_bits);
+ const unsigned long b = static_cast<unsigned long>(*blue)>>(8-blue_bits);
+
+ unsigned long color = (b<<(depth-blue_bits))| (g<<(depth-blue_bits-green_bits))| r;
+
+ *blue = (color>>0)&0xFF;
+ *green = (color>>8)&0xFF;
+ *red = (color>>16)&0xFF;
+ }
+ }
+ else if (depth > 24)
+ {
+ temp = bitmap;
+ unsigned char *red, *green, *blue, *four;
+ while (temp != end)
+ {
+ blue = temp;
+ ++temp;
+ green = temp;
+ ++temp;
+ red = temp;
+ ++temp;
+ four = temp;
+ ++temp;
+
+ const unsigned long r = static_cast<unsigned long>(*red)<<(red_bits-8);
+ const unsigned long g = static_cast<unsigned long>(*green)<<(green_bits-8);
+ const unsigned long b = static_cast<unsigned long>(*blue)<<(blue_bits-8);
+
+ unsigned long color = (b<<(depth-blue_bits))| (g<<(depth-blue_bits-green_bits))| r;
+
+ *blue = (color>>0)&0xFF;
+ *green = (color>>8)&0xFF;
+ *red = (color>>16)&0xFF;
+ *four = (color>>24)&0xFF;
+ }
+ }
+ } // if (depth != 24)
+
+
+
+ XImage img;
+ memset(&img,0,sizeof(img));
+ img.width = width;
+ img.height = height;
+ img.depth = depth;
+ img.data = reinterpret_cast<char*>(bitmap);
+ img.bitmap_bit_order = LSBFirst;
+ img.byte_order = LSBFirst;
+ img.format = ZPixmap;
+ img.bitmap_pad = 32;
+ img.bitmap_unit = 32;
+ img.bits_per_pixel = 32;
+
+
+ XInitImage(&img);
+
+ GC gc = XCreateGC(disp, e->window, 0, NULL);
+
+ XPutImage(disp,e->window,gc,&img,0,0,x,y,width,height);
+
+ XFreeGC(disp,gc);
+ } break;
+ } // switch (ev.type)
+ }
+ }
+ catch (std::exception& e)
+ {
+ dlog << LFATAL << "Exception thrown in event handler: " << e.what();
+ }
+ catch (...)
+ {
+ dlog << LFATAL << "Unknown exception thrown in event handler.";
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+
+ int index_to_modmask(unsigned long n)
+ {
+ switch ( n )
+ {
+ case 0:
+ return Mod1Mask;
+ case 1:
+ return Mod2Mask;
+ case 2:
+ return Mod3Mask;
+ case 3:
+ return Mod4Mask;
+ }
+ return Mod5Mask;
+ }
+
+ void event_handler_thread::
+ init_keyboard_mod_masks()
+ {
+ XModifierKeymap* map = XGetModifierMapping( disp );
+ KeyCode* codes = map->modifiermap + map->max_keypermod * Mod1MapIndex;
+ for (int n = 0; n < 5 * map->max_keypermod; n++ )
+ {
+ if ( codes[n] == 0 )
+ continue;
+ switch(XkbKeycodeToKeysym( disp, codes[n], 0, 0 ))
+ {
+ case XK_Alt_L:
+ alt_mask = index_to_modmask(n / map->max_keypermod);
+ continue;
+ case XK_Alt_R:
+ if(alt_mask == 0)
+ alt_mask = index_to_modmask(n / map->max_keypermod);
+ continue;
+ case XK_Meta_L:
+ case XK_Meta_R:
+ meta_mask = index_to_modmask(n / map->max_keypermod);
+ continue;
+ case XK_Scroll_Lock:
+ scroll_lock_mask = index_to_modmask(n / map->max_keypermod);
+ continue;
+ case XK_Num_Lock:
+ num_lock_mask = index_to_modmask(n / map->max_keypermod);
+ default:
+ continue;
+ }
+ }
+ XFreeModifiermap( map );
+ if ( alt_mask == 0 )
+ {
+ dlog << LWARN << "Search for Alt-key faild.";
+ if ( meta_mask != 0 )
+ alt_mask = meta_mask;
+ else
+ alt_mask = Mod1Mask; // resort to guessing
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+
+
+
+
+ } // namespace gui_core_kernel_2_globals
+
+// ----------------------------------------------------------------------------------------
+
+ void canvas::
+ fill (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) const
+ {
+ pixel pixel_value;
+ pixel_value.red = red_;
+ pixel_value.green = green_;
+ pixel_value.blue = blue_;
+ pixel_value._padding = 0;
+
+ pixel* start = reinterpret_cast<pixel*>(bits);
+ pixel* end = start + width_*height_;
+
+ while (start != end)
+ {
+ *start = pixel_value;
+ ++start;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void put_on_clipboard (
+ const std::string& str
+ )
+ {
+ put_on_clipboard(convert_mbstring_to_wstring(str));
+ }
+
+ void put_on_clipboard (
+ const dlib::ustring& str
+ )
+ {
+ put_on_clipboard(convert_utf32_to_wstring(str));
+ }
+
+ void put_on_clipboard (
+ const std::wstring& str
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+
+ std::shared_ptr<event_handler_thread> globals(global_data());
+
+ auto_mutex M(globals->window_table.get_mutex());
+ globals->clipboard = str.c_str();
+
+ XSetSelectionOwner(globals->disp,XA_PRIMARY,globals->exit_window,CurrentTime);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ Bool clip_peek_helper (
+ Display*,
+ XEvent* event,
+ XPointer
+ )
+ {
+ if ( event->type == SelectionNotify)
+ {
+ return True;
+ }
+ else
+ {
+ return False;
+ }
+ }
+
+ void get_from_clipboard (
+ std::string& str
+ )
+ {
+ std::wstring wstr;
+ get_from_clipboard(wstr);
+ str = convert_wstring_to_mbstring(wstr);
+ }
+
+ void get_from_clipboard (
+ dlib::ustring& str
+ )
+ {
+ std::wstring wstr;
+ get_from_clipboard(wstr);
+ str = convert_wstring_to_utf32(wstr);
+ }
+
+ void get_from_clipboard (
+ std::wstring& str
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ std::shared_ptr<event_handler_thread> globals(global_data());
+
+ auto_mutex M(globals->window_table.get_mutex());
+ str.clear();
+ unsigned char *data = 0;
+ wchar_t **plist = 0;
+ Window sown;
+ Atom type;
+ int format, result;
+ unsigned long len, bytes_left, dummy;
+ XEvent e;
+
+ try
+ {
+ Atom atom_ct = XInternAtom(globals->disp, "COMPOUND_TEXT", False);
+ sown = XGetSelectionOwner (globals->disp, XA_PRIMARY);
+ if (sown == globals->exit_window)
+ {
+ // if we are copying from ourselfs then don't fool with the Xwindows junk.
+ str = globals->clipboard.c_str();
+ }
+ else if (sown != None)
+ {
+ // request that the selection be copied into the XA_PRIMARY property
+ // of the exit_window. It doesn't matter what window we put it in
+ // so long as it is one under the control of this process and exit_window
+ // is easy to use here so that is what I'm using.
+ XConvertSelection (globals->disp, XA_PRIMARY, atom_ct, XA_PRIMARY,
+ globals->exit_window, CurrentTime);
+
+ // This will wait until we get a SelectionNotify event which should happen
+ // really soon.
+ XPeekIfEvent(globals->disp,&e,clip_peek_helper,0);
+
+ // See how much data we got
+ XGetWindowProperty (globals->disp, globals->exit_window,
+ XA_PRIMARY, // Tricky..
+ 0, 0, // offset - len
+ 0, // Delete 0==FALSE
+ AnyPropertyType, //flag
+ &type, // return type
+ &format, // return format
+ &len, &bytes_left, //that
+ &data);
+ if (data)
+ {
+ XFree(data);
+ data = 0;
+ }
+ if (bytes_left > 0 && type == atom_ct)
+ {
+ XTextProperty p;
+ result = XGetWindowProperty (globals->disp, globals->exit_window,
+ XA_PRIMARY, 0,bytes_left,0,
+ AnyPropertyType, &p.encoding,&p.format,
+ &p.nitems, &dummy, &p.value);
+ if (result == Success && p.encoding == atom_ct)
+ {
+ int n;
+ XwcTextPropertyToTextList(globals->disp, &p, &plist, &n);
+ str = plist[0];
+ }
+ if (plist)
+ {
+ XwcFreeStringList(plist);
+ plist = 0;
+ }
+ }
+ }
+ }
+ catch (...)
+ {
+ if (data)
+ XFree(data);
+ if (plist)
+ {
+ XwcFreeStringList(plist);
+ plist = 0;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gui_core_kernel_2_globals
+ {
+ void trigger_user_event_threadproc (
+ void*
+ )
+ {
+ std::shared_ptr<event_handler_thread> globals(global_data());
+ auto_mutex M(globals->window_table.get_mutex());
+
+ globals->user_events.lock();
+ globals->user_events.swap(globals->user_events_temp);
+ globals->user_events.unlock();
+
+
+ globals->user_events_temp.reset();
+ // now dispatch all these user events
+ while (globals->user_events_temp.move_next())
+ {
+ base_window** win_ = globals->window_table[globals->user_events_temp.element().w];
+ base_window* win;
+ // if this window exists in the window table then dispatch
+ // its event.
+ if (win_)
+ {
+ win = *win_;
+ win->on_user_event(
+ globals->user_events_temp.element().p,
+ globals->user_events_temp.element().i
+ );
+ }
+ }
+ globals->user_events_temp.clear();
+ }
+ }
+
+ void base_window::
+ trigger_user_event (
+ void* p,
+ int i
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ user_event_type e;
+ e.w = x11_stuff.hwnd;
+ e.p = p;
+ e.i = i;
+ {
+ std::shared_ptr<event_handler_thread> globals(global_data());
+ auto_mutex M(globals->user_events.get_mutex());
+ globals->user_events.enqueue(e);
+
+ // we only need to start a thread to deal with this if there isn't already
+ // one out working on the queue
+ if (globals->user_events.size() == 1)
+ create_new_thread (trigger_user_event_threadproc,0);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base_window::
+ base_window (
+ bool resizable_,
+ bool undecorated
+ ) :
+ x11_stuff(*(new gui_core_kernel_2_globals::x11_base_windowstuff)),
+ is_mapped(false),
+ resizable(resizable_),
+ has_been_destroyed(false),
+ has_been_resized(false),
+ has_been_moved(false),
+ wm(gui_core_kernel_2_globals::global_data()->window_table.get_mutex())
+ {
+ DLIB_ASSERT(!(undecorated == true && resizable_ == true),
+ "\tbase_window::base_window()"
+ << "\n\tThere is no such thing as an undecorated window that is resizable by the user."
+ << "\n\tthis: " << this
+ );
+ using namespace gui_core_kernel_2_globals;
+
+ auto_mutex M(wm);
+
+ x11_stuff.globals = global_data();
+
+ x11_stuff.last_click_time = 0;
+ last_click_x = 0;
+ last_click_y = 0;
+ last_click_button = NONE;
+
+ XSetWindowAttributes attr;
+ memset(&attr,'\0',sizeof(attr));
+
+ unsigned long valuemask = 0;
+ if (undecorated)
+ {
+ attr.override_redirect = True;
+ valuemask = CWOverrideRedirect;
+ }
+
+
+ x11_stuff.hwnd = XCreateWindow(
+ x11_stuff.globals->disp,
+ DefaultRootWindow(x11_stuff.globals->disp),
+ 0,
+ 0,
+ 10, // this is the default width of a window
+ 10, // this is the default width of a window
+ 0,
+ x11_stuff.globals->depth,
+ InputOutput,
+ CopyFromParent,
+ valuemask,
+ &attr
+ );
+
+ x11_stuff.xic = NULL;
+ if (x11_stuff.globals->xim)
+ {
+ XVaNestedList xva_nlist;
+ XPoint xpoint;
+
+ char **mlist;
+ int mcount;
+ char *def_str;
+ char fontset[256];
+ const long native_font_height = 12;
+ sprintf(fontset, "-*-*-medium-r-normal--%lu-*-*-*-", native_font_height);
+ x11_stuff.fs = XCreateFontSet(x11_stuff.globals->disp, fontset, &mlist, &mcount, &def_str);
+ xpoint.x = 0;
+ xpoint.y = 0;
+ xva_nlist = XVaCreateNestedList(0, XNSpotLocation, &xpoint, XNFontSet, x11_stuff.fs, (const void*)NULL);
+ x11_stuff.xic = XCreateIC(
+ x11_stuff.globals->xim,
+ XNInputStyle, x11_stuff.globals->xim_style,
+ XNClientWindow, x11_stuff.hwnd,
+ XNPreeditAttributes, xva_nlist,
+ (const void*)NULL
+ );
+ XFree(xva_nlist);
+ XFreeStringList(mlist);
+ }
+
+ Window temp = x11_stuff.hwnd;
+ base_window* ttemp = this;
+ x11_stuff.globals->window_table.add(temp,ttemp);
+
+ // query event mask required by input method
+ unsigned long event_xim = 0;
+ if (x11_stuff.xic)
+ XGetICValues( x11_stuff.xic, XNFilterEvents, &event_xim, (const void*)NULL );
+
+ XSelectInput(
+ x11_stuff.globals->disp,
+ x11_stuff.hwnd,
+ StructureNotifyMask|ExposureMask|ButtonPressMask|ButtonReleaseMask|
+ PointerMotionMask|LeaveWindowMask|EnterWindowMask|KeyPressMask|
+ KeyReleaseMask| FocusChangeMask | event_xim
+ );
+
+ XSetWMProtocols(
+ x11_stuff.globals->disp,
+ x11_stuff.hwnd,
+ &x11_stuff.globals->delete_window,
+ 1
+ );
+
+
+ // these are just default values
+ x = 0;
+ y = 0;
+ width = 10;
+ height = 10;
+
+ if (resizable == false)
+ {
+ XSizeHints* hints = XAllocSizeHints();
+ hints->flags = PMinSize|PMaxSize;
+ hints->min_width = width;
+ hints->max_width = width;
+ hints->max_height = height;
+ hints->min_height = height;
+ XSetNormalHints(x11_stuff.globals->disp,x11_stuff.hwnd,hints);
+ XFree(hints);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base_window::
+ ~base_window (
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ close_window();
+
+ if (x11_stuff.globals->xim != NULL)
+ {
+ XDestroyIC(x11_stuff.xic);
+ x11_stuff.xic = 0;
+ XFreeFontSet(x11_stuff.globals->disp,x11_stuff.fs);
+ }
+
+ delete &x11_stuff;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ close_window (
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == false)
+ {
+ has_been_destroyed = true;
+
+ x11_stuff.globals->window_table.destroy(x11_stuff.hwnd);
+
+ XDestroyWindow(x11_stuff.globals->disp,x11_stuff.hwnd);
+ x11_stuff.hwnd = 0;
+ x11_stuff.globals->window_close_signaler.broadcast();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool base_window::
+ is_closed (
+ ) const
+ {
+ auto_mutex M(wm);
+ return has_been_destroyed;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_title (
+ const std::string& title_
+ )
+ {
+ set_title(convert_mbstring_to_wstring(title_));
+ }
+
+ void base_window::
+ set_title (
+ const ustring& title_
+ )
+ {
+ set_title(convert_utf32_to_wstring(title_));
+ }
+
+ void base_window::
+ set_title (
+ const std::wstring& title_
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ // I'm pretty sure the pointer won't be modified even though
+ // it isn't const anymore.
+ wchar_t *title = const_cast<wchar_t *>(title_.c_str());
+ XTextProperty property;
+ XwcTextListToTextProperty(x11_stuff.globals->disp,&title,1,XStdICCTextStyle, &property);
+ XSetWMName(x11_stuff.globals->disp,x11_stuff.hwnd,&property);
+ XFree(property.value);
+ XFlush(x11_stuff.globals->disp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ show (
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ XMapRaised(x11_stuff.globals->disp,x11_stuff.hwnd);
+ XFlush(x11_stuff.globals->disp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ wait_until_closed (
+ ) const
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ while (has_been_destroyed == false)
+ x11_stuff.globals->window_close_signaler.wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ hide (
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ XUnmapWindow(x11_stuff.globals->disp,x11_stuff.hwnd);
+ XFlush(x11_stuff.globals->disp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_size (
+ int width_,
+ int height_
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex a(wm);
+ if (has_been_destroyed == true)
+ return;
+
+
+ // do some sanity checking on these values
+ if (width_ < 1)
+ width_ = 1;
+ if (height_ < 1)
+ height_ = 1;
+
+ width = width_;
+ height = height_;
+ has_been_resized = true;
+
+ if (resizable == false)
+ {
+ XSizeHints* hints = XAllocSizeHints();
+ hints->flags = PMinSize|PMaxSize;
+ hints->min_width = width;
+ hints->max_width = width;
+ hints->max_height = height;
+ hints->min_height = height;
+ XSetNormalHints(x11_stuff.globals->disp,x11_stuff.hwnd,hints);
+ XFree(hints);
+ }
+
+ XResizeWindow(x11_stuff.globals->disp,x11_stuff.hwnd,width,height);
+
+ XFlush(x11_stuff.globals->disp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_pos (
+ long x_,
+ long y_
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex a(wm);
+ if (has_been_destroyed == true)
+ return;
+
+
+ x = x_;
+ y = y_;
+
+ has_been_moved = true;
+
+ XMoveWindow(x11_stuff.globals->disp,x11_stuff.hwnd,x,y);
+ XFlush(x11_stuff.globals->disp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_pos (
+ long& x_,
+ long& y_
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex a(wm);
+ x_ = 0;
+ y_ = 0;
+ if (has_been_destroyed == true)
+ return;
+
+ // we can't really trust the values we have for x and y because some window managers
+ // will have reported bogus values back in the ConfigureNotify event. So just to be
+ // on the safe side we will use XTranslateCoordinates()
+ int rx, ry;
+ Window desktop_window = DefaultRootWindow(x11_stuff.globals->disp);
+ Window junk;
+ XTranslateCoordinates(x11_stuff.globals->disp,x11_stuff.hwnd,desktop_window,0,0,&rx, &ry, &junk);
+ x_ = rx;
+ y_ = ry;
+ x = rx;
+ y = ry;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_size (
+ unsigned long& width_,
+ unsigned long& height_
+ ) const
+ {
+ auto_mutex M(wm);
+ width_ = 0;
+ height_ = 0;
+ if (has_been_destroyed == true)
+ return;
+
+
+ width_ = width;
+ height_ = height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ get_display_size (
+ unsigned long& width_,
+ unsigned long& height_
+ ) const
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex M(wm);
+ width_ = 0;
+ height_ = 0;
+ if (has_been_destroyed == true)
+ return;
+
+ int screen_number = XScreenNumberOfScreen(x11_stuff.globals->screen);
+ width_ = DisplayWidth(x11_stuff.globals->disp, screen_number);
+ height_ = DisplayHeight(x11_stuff.globals->disp, screen_number);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ invalidate_rectangle (
+ const rectangle& rect
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex a(wm);
+ if (is_mapped == false)
+ return;
+
+ if (rect.is_empty() == false && !has_been_destroyed)
+ {
+ const long x = rect.left();
+ const long y = rect.top();
+ const unsigned long width = rect.width();
+ const unsigned long height = rect.height();
+
+ XClearArea(x11_stuff.globals->disp,x11_stuff.hwnd,x,y,width,height,1);
+ XFlush(x11_stuff.globals->disp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void base_window::
+ set_im_pos (
+ long x,
+ long y
+ )
+ {
+ using namespace gui_core_kernel_2_globals;
+ auto_mutex a(wm);
+ if (has_been_destroyed == true)
+ return;
+
+ if (!x11_stuff.xic || !(x11_stuff.globals->xim_style & XIMPreeditPosition)) return;
+
+ XVaNestedList xva_nlist;
+ XPoint xpoint;
+
+ xpoint.x = x;
+ xpoint.y = y;
+
+ xva_nlist = XVaCreateNestedList(0, XNSpotLocation, &xpoint, (const void*)NULL);
+ XSetICValues(x11_stuff.xic, XNPreeditAttributes, xva_nlist, (const void*)NULL);
+ XFree(xva_nlist);
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // POSIX
+
+#endif // DLIB_GUI_CORE_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_2.h b/ml/dlib/dlib/gui_core/gui_core_kernel_2.h
new file mode 100644
index 000000000..efcd4ba19
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/gui_core_kernel_2.h
@@ -0,0 +1,419 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEl_2_
+#define DLIB_GUI_CORE_KERNEl_2_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#ifdef DLIB_NO_GUI_SUPPORT
+#error "DLIB_NO_GUI_SUPPORT is defined so you can't use the GUI code. Turn DLIB_NO_GUI_SUPPORT off if you want to use it."
+#error "Also make sure you have libx11-dev installed on your system"
+#endif
+
+#include <string>
+
+#include "gui_core_kernel_abstract.h"
+#include "../algs.h"
+#include "../threads.h"
+#include "../geometry/rectangle.h"
+#include "../binary_search_tree.h"
+#include <string.h>
+#include "../pixel.h"
+#include "../unicode.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gui_core_kernel_2_globals
+ {
+ class event_handler_thread;
+ void trigger_user_event_threadproc (void*);
+
+ // This is a forward declaration for a struct that contains any
+ // X11 variables. This allows me to avoid having any dlib header files
+ // include the X11 headers. Which in turn speeds build times and simplifies
+ // build setups.
+ struct x11_base_windowstuff;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void put_on_clipboard (
+ const std::string& str
+ );
+
+ void put_on_clipboard (
+ const std::wstring& str
+ );
+
+ void put_on_clipboard (
+ const dlib::ustring& str
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void get_from_clipboard (
+ std::string& str
+ );
+
+ void get_from_clipboard (
+ std::wstring& str
+ );
+
+ void get_from_clipboard (
+ dlib::ustring& str
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class canvas : public rectangle
+ {
+ public:
+ struct pixel
+ {
+ unsigned char blue;
+ unsigned char green;
+ unsigned char red;
+ private:
+ friend class canvas;
+ unsigned char _padding;
+ };
+
+ ~canvas() {}
+
+ inline pixel* operator[] (
+ unsigned long row
+ ) const
+ {
+ DLIB_ASSERT(row < height(),
+ "\tpixel* canvas::operator[]"
+ << "\n\tyou have to give a row that is less than the height()"
+ << "\n\tthis: " << this
+ << "\n\trow: " << row
+ << "\n\theight(): " << height()
+ );
+ unsigned char* temp = bits + row_width*row;
+ return reinterpret_cast<pixel*>(temp);
+ }
+
+ void fill (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) const;
+
+ private:
+
+ friend class gui_core_kernel_2_globals::event_handler_thread;
+
+
+ canvas (
+ unsigned char* bits_,
+ unsigned long left_,
+ unsigned long top_,
+ unsigned long right_,
+ unsigned long bottom_
+ ) :
+ rectangle(left_,top_,right_,bottom_),
+ bits(bits_),
+ width_(width()),
+ height_(height()),
+ row_width(width_*4)
+ {}
+
+ // restricted functions
+ canvas(); // normal constructor
+ canvas(canvas&); // copy constructor
+ canvas& operator=(canvas&); // assignment operator
+
+ unsigned char* const bits;
+ const unsigned long width_;
+ const unsigned long height_;
+ const unsigned long row_width;
+ };
+
+ template <>
+ struct pixel_traits<canvas::pixel>
+ {
+ constexpr static bool rgb = true;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static long num = 3;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// -----------------
+
+ class base_window
+ {
+ friend class gui_core_kernel_2_globals::event_handler_thread;
+ friend void gui_core_kernel_2_globals::trigger_user_event_threadproc (void*);
+
+ public:
+
+ enum mouse_state_masks
+ {
+ NONE = 0,
+ LEFT = 1,
+ RIGHT = 2,
+ MIDDLE = 4,
+ SHIFT = 8,
+ CONTROL = 16
+ };
+
+ enum keyboard_state_masks
+ {
+ KBD_MOD_NONE = 0,
+ KBD_MOD_SHIFT = 1,
+ KBD_MOD_CONTROL = 2,
+ KBD_MOD_ALT = 4,
+ KBD_MOD_META = 8,
+ KBD_MOD_CAPS_LOCK = 16,
+ KBD_MOD_NUM_LOCK = 32,
+ KBD_MOD_SCROLL_LOCK = 64
+ };
+
+ enum on_close_return_code
+ {
+ DO_NOT_CLOSE_WINDOW,
+ CLOSE_WINDOW
+ };
+
+ enum non_printable_keyboard_keys
+ {
+ KEY_BACKSPACE,
+ KEY_SHIFT,
+ KEY_CTRL,
+ KEY_ALT,
+ KEY_PAUSE,
+ KEY_CAPS_LOCK,
+ KEY_ESC,
+ KEY_PAGE_UP,
+ KEY_PAGE_DOWN,
+ KEY_END,
+ KEY_HOME,
+ KEY_LEFT, // This is the left arrow key
+ KEY_RIGHT, // This is the right arrow key
+ KEY_UP, // This is the up arrow key
+ KEY_DOWN, // This is the down arrow key
+ KEY_INSERT,
+ KEY_DELETE,
+ KEY_SCROLL_LOCK,
+
+ // Function Keys
+ KEY_F1,
+ KEY_F2,
+ KEY_F3,
+ KEY_F4,
+ KEY_F5,
+ KEY_F6,
+ KEY_F7,
+ KEY_F8,
+ KEY_F9,
+ KEY_F10,
+ KEY_F11,
+ KEY_F12
+ };
+
+ private:
+
+ gui_core_kernel_2_globals::x11_base_windowstuff& x11_stuff;
+
+ int x, y, width, height;
+ bool is_mapped;
+
+ const bool resizable;
+ bool has_been_destroyed;
+ bool has_been_resized; // true if someone called set_size() and the on_window_resized() event
+ // hasn't yet occurred.
+ bool has_been_moved; // true if someone called set_pos() and the on_window_moved() event
+ // hasn't yet occurred.
+
+
+ // The following 3 variables (and x11_stuff.last_click_time) are only accessed from the
+ // event handling loop (except for being initialized below). They record the last
+ // mouse click event details.
+ long last_click_x, last_click_y;
+ unsigned long last_click_button;
+
+
+ protected:
+ const rmutex& wm;
+
+ public:
+
+ base_window (
+ bool resizable_ = true,
+ bool undecorated = false
+ );
+
+ virtual ~base_window (
+ );
+
+ void close_window (
+ );
+
+ void wait_until_closed (
+ ) const;
+
+ void set_im_pos (
+ long x_,
+ long y_
+ );
+
+ bool is_closed (
+ ) const;
+
+ void set_title (
+ const std::string& title_
+ );
+
+ void set_title (
+ const std::wstring& title_
+ );
+
+ void set_title (
+ const dlib::ustring& title_
+ );
+
+ virtual void show (
+ );
+
+ virtual void hide(
+ );
+
+ void set_size (
+ int width_,
+ int height_
+ );
+
+ void set_pos (
+ long x_,
+ long y_
+ );
+
+ void get_pos (
+ long& x_,
+ long& y_
+ );
+
+ void get_size (
+ unsigned long& width_,
+ unsigned long& height_
+ ) const;
+
+ void get_display_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const;
+
+ void invalidate_rectangle (
+ const rectangle& rect
+ );
+
+ void trigger_user_event (
+ void* p,
+ int i
+ );
+
+ protected:
+
+ virtual on_close_return_code on_window_close(
+ ){return CLOSE_WINDOW;}
+
+ virtual void on_window_resized(
+ ){}
+
+ virtual void on_window_moved(
+ ){}
+ virtual void on_user_event (
+ void* ,
+ int
+ ){}
+
+ virtual void on_mouse_down (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long ,
+ bool
+ ){}
+
+ virtual void on_mouse_up (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_move (
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_leave (
+ ){}
+
+ virtual void on_mouse_enter (
+ ){}
+
+ virtual void on_wheel_up (
+ unsigned long
+ ){}
+
+ virtual void on_wheel_down (
+ unsigned long
+ ){}
+
+ virtual void on_focus_gained (
+ ){}
+
+ virtual void on_focus_lost (
+ ){}
+
+ virtual void on_keydown (
+ unsigned long ,
+ bool ,
+ unsigned long
+ ){}
+
+ virtual void on_string_put (
+ const std::wstring&
+ ){}
+
+ private:
+
+ virtual void paint (
+ const canvas& c
+ ) =0;
+
+
+
+ base_window(base_window&); // copy constructor
+ base_window& operator=(base_window&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+
+#ifdef NO_MAKEFILE
+#include "gui_core_kernel_2.cpp"
+#endif
+
+#endif // DLIB_GUI_CORE_KERNEl_2_
+
diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h b/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h
new file mode 100644
index 000000000..a773e4287
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h
@@ -0,0 +1,792 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GUI_CORE_KERNEl_ABSTRACT_
+#ifdef DLIB_GUI_CORE_KERNEl_ABSTRACT_
+
+#include <string>
+#include "../algs.h"
+#include "../geometry/rectangle_abstract.h"
+#include "../unicode/unicode_abstract.h"
+
+namespace dlib
+{
+
+ /*!
+ OVERVIEW:
+ This is a set of objects and functions which provide a very basic
+ framework for manipulating windows. It is intended to provide a
+ portable interface which can be used to build a more complex windowing
+ toolkit.
+
+ EXCEPTIONS
+ Do not let an exception leave any of the base_window event handlers.
+ The results of doing so are undefined.
+
+ THREAD SAFETY
+ Event Handlers
+ All event handlers are executed in a special event handling thread.
+ This means that you must not do anything that will take a long time or
+ block while in an event handler. Doing so will freeze all event
+ processing.
+
+ Also, don't rely on get_thread_id() always returning the same ID from
+ inside event handlers.
+
+ canvas
+ Never access a canvas object outside of the paint() callback
+ that supplied it. Only access a canvas object from the event
+ handling thread. After the paint() event handler has returned do
+ not access that canvas object again.
+
+ base_window
+ All methods for this class are thread safe. You may call them
+ from any thread and do not need to serialize access.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void put_on_clipboard (
+ const std::string& str
+ );
+ /*!
+ ensures
+ - posts the contents of str to the system clipboard
+ throws
+ - std::bad_alloc
+ - dlib::gui_error
+ - dlib::thread_error
+ !*/
+
+ // overloads for wide character strings
+ void put_on_clipboard (const std::wstring& str);
+ void put_on_clipboard (const dlib::ustring& str);
+
+// ----------------------------------------------------------------------------------------
+
+ void get_from_clipboard (
+ std::string& str
+ );
+ /*!
+ ensures
+ - if (there is string data on the system clipboard) then
+ - #str == the data from the clipboard
+ - else
+ - #str == ""
+ throws
+ - std::bad_alloc
+ - dlib::gui_error
+ - dlib::thread_error
+ !*/
+
+ // overloads for wide character strings
+ void get_from_clipboard (std::wtring& str);
+ void get_from_clipboard (dlib::utring& str);
+
+// ----------------------------------------------------------------------------------------
+
+
+ class canvas : public rectangle
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ All functions of this object may invalidate pointers and references
+ to internal data.
+
+ INITIAL VALUE
+ The initial value of each pixel is undefined.
+ is_empty() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a rectangular area of a window that you
+ can draw on.
+
+ Each pixel can be accessed with the following syntax:
+ canvas_instance[y][x].red == the red value for this pixel
+ canvas_instance[y][x].blue == the blue value for this pixel
+ canvas_instance[y][x].green == the green value for this pixel
+
+ The origin, i.e. (0,0), of the x,y coordinate plane of the canvas is in
+ the upper left corner of the canvas. Note that the upper left corner
+ of the canvas appears at the point (left(),top()) in its window.
+ !*/
+
+ public:
+
+ struct pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a single pixel. Each pixel's value
+ ranges from 0 to 255 with 0 indicating that the color is not
+ present in the pixel at all and 255 indicating that the color
+ is present in the pixel with maximum intensity.
+
+ Note that the structure, order, and size of of this struct are
+ implementation dependent. It will always contain fields called
+ red, green, and blue but they may not be in that order and there
+ may be padding.
+
+ Also note that pixel_traits<> is defined for this pixel type,
+ thus you can use it in assign_pixel() calls.
+ !*/
+ unsigned char red;
+ unsigned char green;
+ unsigned char blue;
+ };
+
+
+ pixel* operator[] (
+ unsigned long row
+ ) const;
+ /*!
+ requires
+ - row < height()
+ ensures
+ - returns an array of width() pixel structs that represents the given
+ row of pixels in the canvas.
+ !*/
+
+ void fill (
+ unsigned char red,
+ unsigned char green,
+ unsigned char blue
+ ) const;
+ /*!
+ ensures
+ - for all valid values of x and y:
+ - (#*this)[y][x].red = red
+ - (#*this)[y][x].green = green
+ - (#*this)[y][x].blue = blue
+ !*/
+
+ private:
+
+ // restricted functions
+ canvas(); // normal constructor
+ canvas(canvas&); // copy constructor
+ canvas& operator=(canvas&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class base_window
+ {
+
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a window on the desktop. A window has a "client
+ area" that is a region of the screen that you can draw whatever you like
+ on. You implement the paint() callback and use the canvas object to do
+ this drawing.
+
+ INITIAL STATE
+ - The initial state of the window is to be hidden. This means you need
+ to call show() to make it appear.
+ - is_closed() == false
+
+ paint() callback:
+ This is where you will do all your drawing. It is triggered when
+ part of the window needs to be drawn/redrawn.
+
+ mouse events:
+ It is important to note a few things about the mouse events. First,
+ the on_mouse_move() event is not triggered for each pixel the mouse crosses
+ but rather its frequency and precision is implementation dependent.
+
+ Second, it is possible that a mouse button may be depressed but the
+ corresponding button release event does not go to the window. For instance,
+ if the mouse is outside the window and some other application jumps to the
+ top it is possible that the new application will receive any mouse button
+ release events rather than the original window. But the point is that
+ you should not rely on always getting a button up event for every button
+ down event.
+
+ keydown event:
+ Note that the existence of a typematic action (holding down a key
+ and having it start to repeat itself after a moment) for each key is
+ totally implementation dependent. So don't rely on it for any key
+ and conversely don't assume it isn't present either.
+
+ The base_window::wm mutex
+ This is a reference to a global rmutex. All instances of base_window make
+ reference to the same global rmutex. It is used to synchronize access to
+ the base_window to make it thread safe. It is also always locked before
+ an event handler is called.
+ !*/
+
+ public:
+
+ enum on_close_return_code
+ {
+ DO_NOT_CLOSE_WINDOW,
+ CLOSE_WINDOW
+ };
+
+ enum mouse_state_masks
+ {
+ /*!
+ These constants represent the various buttons referenced by
+ mouse events.
+ !*/
+ NONE = 0,
+ LEFT = 1,
+ RIGHT = 2,
+ MIDDLE = 4,
+ SHIFT = 8,
+ CONTROL = 16
+ };
+
+ enum keyboard_state_masks
+ {
+ /*!
+ These constants represent the various modifier buttons that
+ could be in effect during a key press on the keyboard
+ !*/
+ KBD_MOD_NONE = 0,
+ KBD_MOD_SHIFT = 1,
+ KBD_MOD_CONTROL = 2,
+ KBD_MOD_ALT = 4,
+ KBD_MOD_META = 8,
+ KBD_MOD_CAPS_LOCK = 16,
+ KBD_MOD_NUM_LOCK = 32,
+ KBD_MOD_SCROLL_LOCK = 64
+ };
+
+ enum non_printable_keyboard_keys
+ {
+ KEY_BACKSPACE,
+ KEY_SHIFT,
+ KEY_CTRL,
+ KEY_ALT,
+ KEY_PAUSE,
+ KEY_CAPS_LOCK,
+ KEY_ESC,
+ KEY_PAGE_UP,
+ KEY_PAGE_DOWN,
+ KEY_END,
+ KEY_HOME,
+ KEY_LEFT, // This is the left arrow key
+ KEY_RIGHT, // This is the right arrow key
+ KEY_UP, // This is the up arrow key
+ KEY_DOWN, // This is the down arrow key
+ KEY_INSERT,
+ KEY_DELETE,
+ KEY_SCROLL_LOCK,
+
+ // Function Keys
+ KEY_F1,
+ KEY_F2,
+ KEY_F3,
+ KEY_F4,
+ KEY_F5,
+ KEY_F6,
+ KEY_F7,
+ KEY_F8,
+ KEY_F9,
+ KEY_F10,
+ KEY_F11,
+ KEY_F12
+ };
+
+ base_window (
+ bool resizable = true,
+ bool undecorated = false
+ );
+ /*!
+ requires
+ - if (undecorated == true) then
+ - resizable == false
+ ensures
+ - #*this has been properly initialized
+ - if (resizable == true) then
+ - this window will be resizable by the user
+ - else
+ - this window will not be resizable by the user
+ - if (undecorated == true) then
+ - this window will not have any graphical elements outside
+ of its drawable area or appear in the system task bar. It
+ also won't take the input focus from other windows.
+ (it is suitable for making things such as popup menus)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ - dlib::gui_error
+ This exception is thrown if there is an error while
+ creating this window.
+ !*/
+
+ virtual ~base_window (
+ );
+ /*!
+ ensures
+ - does NOT trigger the on_window_close() event
+ - all resources associated with *this have been released
+ - closes this window
+ !*/
+
+ void close_window (
+ );
+ /*!
+ ensures
+ - #is_closed() == true
+ (i.e. permanently closes this window. The window is removed from the
+ screen and no more events will be dispatched to this window. )
+ - does NOT trigger the on_window_close() event
+ !*/
+
+ void wait_until_closed (
+ ) const;
+ /*!
+ ensures
+ - blocks until is_closed() == true
+ !*/
+
+ bool is_closed (
+ ) const;
+ /*!
+ ensures
+ - returns true if this window has been closed, false otherwise.
+ (Note that closed windows do not receive any callbacks at all.
+ They are also not visible on the screen.)
+ !*/
+
+ void set_title (
+ const std::string& title
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - sets the title of the window
+ !*/
+
+ void set_title (
+ const std::wstring& title
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - sets the title of the window
+ !*/
+
+ void set_title (
+ const dlib::ustring& title
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - sets the title of the window
+ !*/
+
+ virtual void show (
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - this window will appear on the screen
+ !*/
+
+ virtual void hide(
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - the window does not appear on the screen
+ !*/
+
+ void set_size (
+ int width,
+ int height
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - The width of the client area of this window is at least width
+ pixels.
+ - The height of the client area of this window is at least height
+ pixels.
+ - if (the window wasn't already this size) then
+ - triggers the on_window_resized() callback
+ !*/
+
+ void set_pos (
+ long x,
+ long y
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - sets the upper left corner of this window to the position (x,y)
+ on the desktop. Note that the origin (0,0) is at the upper left
+ corner of the desktop.
+ !*/
+
+ void get_pos (
+ long& x,
+ long& y
+ ) const;
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - #x == the x coordinate of the upper left corner of the client area of
+ this window.
+ - #y == the y coordinate of the upper left corner of the client area of
+ this window.
+ - i.e. the point (#x,#y) on the desktop is coincident with the point
+ (0,0) in the client area of this window.
+ - else
+ - #x == 0
+ - #y == 0
+ !*/
+
+ void get_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const;
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - #width == the width of the client area of this window in pixels
+ - #height == the height of the client area of this window in pixels
+ - else
+ - #width == 0
+ - #height == 0
+ !*/
+
+ void get_display_size (
+ unsigned long& width,
+ unsigned long& height
+ ) const;
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - #width == the width in pixels of the display device that contains this window
+ - #height == the height in pixels of the display device that contains this window
+ - else
+ - #width == 0
+ - #height == 0
+ !*/
+
+ void invalidate_rectangle (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - causes the area of this window defined by rect to become invalid.
+ This means that a paint() message will be dispatched to repaint
+ this area of the window. Note that it is possible that the
+ resulting paint() message may include a bigger rectangle than
+ the one defined by rect.
+ !*/
+
+ void trigger_user_event (
+ void* p,
+ int i
+ );
+ /*!
+ ensures
+ - will never block (even if some other thread has a lock on the
+ global mutex referenced by wm.)
+ - if (is_closed() == false) then
+ - causes the on_user_event() event to be called with
+ the given arguments.
+ !*/
+
+ void set_im_pos (
+ long x_,
+ long y_
+ );
+ /*!
+ ensures
+ - if (is_closed() == false) then
+ - sets the left-top position of input method rectangle used
+ for wide character input methods.
+ !*/
+
+ protected:
+ const rmutex& wm;
+
+ // let the window close by default
+ virtual on_close_return_code on_window_close(
+ ){return CLOSE_WINDOW;}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the user attempts to close this window
+ - if (this function returns CLOSE_WINDOW) then
+ - #is_closed() == true (i.e. this window will be closed)
+ - it is safe to call "delete this;" inside on_window_close()
+ if *this was allocated on the heap and no one will try to
+ access *this anymore.
+ - else
+ - this window will not be closed and the attempt to close it
+ by the user will have no effect.
+ - #is_closed() == false
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_user_event (
+ void* p,
+ int i
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called whenever someone calls trigger_user_event()
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_window_resized(
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when this window is resized
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_window_moved(
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when this window's position changes
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the user depresses one of the mouse buttons
+ - btn == the button that was depressed. (either LEFT, MIDDLE, or RIGHT)
+ - state == the bitwise OR of the buttons that are currently depressed
+ excluding the button given by btn. (from the mouse_state_masks enum)
+ - (x,y) == the position of the mouse (relative to the upper left corner
+ of the window) when this event occurred. Note that the mouse may be
+ outside the window.
+ - if (this is the second button press of a double click) then
+ - is_double_click == true
+ - else
+ - is_double_click == false
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the user releases one of the mouse buttons
+ - btn == the button that was released. (either LEFT, MIDDLE, or RIGHT)
+ - state == the bitwise OR of the buttons that are currently depressed
+ (from the mouse_state_masks enum)
+ - (x,y) == the position of the mouse (relative to the upper left corner
+ of the window) when this event occurred. Note that the mouse may be
+ outside the window.
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the user moves the mouse
+ - state == the bitwise OR of the buttons that are currently depressed
+ (from the mouse_state_masks enum)
+ - (x,y) == the position of the mouse (relative to the upper left corner
+ of the window) when this event occurred.
+ - if (the user is holding down one or more of the mouse buttons) then
+ - the mouse move events will continue to track the mouse even if
+ it goes out of the window. This will continue until the user
+ releases all the mouse buttons.
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_mouse_leave (
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the mouse leaves this window
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_mouse_enter (
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when the mouse enters this window
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_focus_gained (
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when this window gains input focus
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_focus_lost (
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when this window loses input focus
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_wheel_up (
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called every time the mouse wheel is scrolled up one notch
+ - state == the bitwise OR of the buttons that are currently depressed
+ (from the mouse_state_masks enum)
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_wheel_down (
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called every time the mouse wheel is scrolled down one notch
+ - state == the bitwise OR of the buttons that are currently depressed
+ (from the mouse_state_masks enum)
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ // do nothing by default
+ virtual void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when a keyboard key is pressed or if a key is held
+ down then this is called repeatedly at a certain rate once the
+ typematic action begins (note that some keys might not have any
+ typematic action on some platforms).
+ - if (is_printable) then
+ - key == the character that was pressed. (e.g. 'a', 'b', '1' etc.)
+ - this is a printable character. Note that ' ', '\t', and
+ '\n' (this is the return/enter key) are all considered printable.
+ - else
+ - key == one of the non_printable_keyboard_keys enums.
+ - state == the bitwise OR of the keyboard modifiers that are currently
+ depressed (taken from keyboard_state_masks).
+ - if (key is not in the range 'a' to 'z' or 'A' to 'Z') then
+ - if (the shift key was down when this key was pressed) then
+ - (state & KBD_MOD_SHIFT) != 0
+ - else
+ - (state & KBD_MOD_SHIFT) == 0
+ - else
+ - the state of the shift key is implementation defined
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ virtual void on_string_put (
+ const std::wstring &str
+ ){}
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when a wide/multibyte character input method determines a string
+ that is being input to the window.
+ - str == the string that is being input
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ private:
+
+ virtual void paint (
+ const canvas& c
+ ) =0;
+ /*!
+ requires
+ - is_closed() == false
+ - mutex wm is locked
+ - is called when part of the window needs to be repainted for
+ any reason.
+ - c == a canvas object that represents the invalid area of this
+ window which needs to be painted.
+ ensures
+ - does not change the state of mutex wm
+ !*/
+
+ base_window(base_window&); // copy constructor
+ base_window& operator=(base_window&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GUI_CORE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/gui_core/windows.h b/ml/dlib/dlib/gui_core/windows.h
new file mode 100644
index 000000000..8d71fd70e
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/windows.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEl_2_
+#include "gui_core_kernel_1.h"
+#endif
+
diff --git a/ml/dlib/dlib/gui_core/xlib.h b/ml/dlib/dlib/gui_core/xlib.h
new file mode 100644
index 000000000..de9c4666b
--- /dev/null
+++ b/ml/dlib/dlib/gui_core/xlib.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GUI_CORE_KERNEl_1_
+#include "gui_core_kernel_2.h"
+#endif
+
diff --git a/ml/dlib/dlib/gui_widgets.h b/ml/dlib/dlib/gui_widgets.h
new file mode 100644
index 000000000..5b243ef8f
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets.h
@@ -0,0 +1,18 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_GUI_WIDGETs_
+#define DLIB_GUI_WIDGETs_
+
+
+
+#include "gui_widgets/widgets.h"
+
+
+
+#endif // DLIB_GUI_WIDGETs_
+
diff --git a/ml/dlib/dlib/gui_widgets/base_widgets.cpp b/ml/dlib/dlib/gui_widgets/base_widgets.cpp
new file mode 100644
index 000000000..2f2eb8e9c
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/base_widgets.cpp
@@ -0,0 +1,3343 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BASE_WIDGETs_CPP_
+#define DLIB_BASE_WIDGETs_CPP_
+
+#include <iostream>
+#include <memory>
+
+#include "base_widgets.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // button object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle min_rect = style->get_min_size(name_,*mfont);
+ // only change the size if it isn't going to be too small to fit the name
+ if (height >= min_rect.height() &&
+ width >= min_rect.width())
+ {
+ rectangle old(rect);
+ rect = resize_rect(rect,width,height);
+ parent.invalidate_rectangle(style->get_invalidation_rect(rect+old));
+ btn_tooltip.set_size(width,height);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ show (
+ )
+ {
+ button_action::show();
+ btn_tooltip.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ hide (
+ )
+ {
+ button_action::hide();
+ btn_tooltip.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ enable (
+ )
+ {
+ button_action::enable();
+ btn_tooltip.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ disable (
+ )
+ {
+ button_action::disable();
+ btn_tooltip.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_tooltip_text (
+ const std::string& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_tooltip_text (
+ const std::wstring& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_tooltip_text (
+ const ustring& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string button::
+ tooltip_text (
+ ) const
+ {
+ return btn_tooltip.text();
+ }
+
+ const std::wstring button::
+ tooltip_wtext (
+ ) const
+ {
+ return btn_tooltip.wtext();
+ }
+
+ const dlib::ustring button::
+ tooltip_utext (
+ ) const
+ {
+ return btn_tooltip.utext();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ set_name(name_);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ button_action::set_pos(x,y);
+ btn_tooltip.set_pos(x,y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ set_name (
+ const std::string& name
+ )
+ {
+ set_name(convert_mbstring_to_wstring(name));
+ }
+
+ void button::
+ set_name (
+ const std::wstring& name
+ )
+ {
+ set_name(convert_wstring_to_utf32(name));
+ }
+
+ void button::
+ set_name (
+ const ustring& name
+ )
+ {
+ auto_mutex M(m);
+ name_ = name;
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ name_[0] = name_[0];
+
+ rectangle old(rect);
+ rect = move_rect(style->get_min_size(name,*mfont),rect.left(),rect.top());
+ btn_tooltip.set_size(rect.width(),rect.height());
+
+ parent.invalidate_rectangle(style->get_invalidation_rect(rect+old));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string button::
+ name (
+ ) const
+ {
+ auto_mutex M(m);
+ std::string temp = convert_wstring_to_mbstring(wname());
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ char c = temp[0];
+ temp[0] = c;
+ return temp;
+ }
+
+ const std::wstring button::
+ wname (
+ ) const
+ {
+ auto_mutex M(m);
+ std::wstring temp = convert_utf32_to_wstring(uname());
+ // do this to get rid of any reference counting that may be present in
+ // the std::wstring implementation.
+ wchar_t w = temp[0];
+ temp[0] = w;
+ return temp;
+ }
+
+ const dlib::ustring button::
+ uname (
+ ) const
+ {
+ auto_mutex M(m);
+ dlib::ustring temp = name_;
+ // do this to get rid of any reference counting that may be present in
+ // the dlib::ustring implementation.
+ temp[0] = name_[0];
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ on_button_up (
+ bool mouse_over
+ )
+ {
+ if (mouse_over)
+ {
+ // this is a valid button click
+ if (event_handler.is_set())
+ event_handler();
+ if (event_handler_self.is_set())
+ event_handler_self(*this);
+ }
+ if (button_up_handler.is_set())
+ button_up_handler(mouse_over);
+ if (button_up_handler_self.is_set())
+ button_up_handler_self(mouse_over,*this);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button::
+ on_button_down (
+ )
+ {
+ if (button_down_handler.is_set())
+ button_down_handler();
+ if (button_down_handler_self.is_set())
+ button_down_handler_self(*this);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // draggable object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ draggable::~draggable() {}
+
+// ----------------------------------------------------------------------------------------
+
+ void draggable::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (drag && (state & base_window::LEFT) && enabled && !hidden)
+ {
+ // the user is trying to drag this object. we should calculate the new
+ // x and y positions for the upper left corner of this object's rectangle
+
+ long new_x = x - this->x;
+ long new_y = y - this->y;
+
+ // make sure these points are inside the draggable area.
+ if (new_x < area.left())
+ new_x = area.left();
+ if (new_x + static_cast<long>(rect.width()) - 1 > area.right())
+ new_x = area.right() - rect.width() + 1;
+
+ if (new_y + static_cast<long>(rect.height()) - 1 > area.bottom())
+ new_y = area.bottom() - rect.height() + 1;
+ if (new_y < area.top())
+ new_y = area.top();
+
+ // now make the new rectangle for this object
+ rectangle new_rect(
+ new_x,
+ new_y,
+ new_x + rect.width() - 1,
+ new_y + rect.height() - 1
+ );
+
+ // only do anything if this is a new rectangle and it is inside area
+ if (new_rect != rect && area.intersect(new_rect) == new_rect)
+ {
+ parent.invalidate_rectangle(new_rect + rect);
+ rect = new_rect;
+
+ // call the on_drag() event handler
+ on_drag();
+ }
+ }
+ else
+ {
+ drag = false;
+ on_drag_stop();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void draggable::
+ on_mouse_up (
+ unsigned long ,
+ unsigned long state,
+ long ,
+ long
+ )
+ {
+ if (drag && (state & base_window::LEFT) == 0)
+ {
+ drag = false;
+ on_drag_stop();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void draggable::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (enabled && !hidden && rect.contains(x,y) && btn == base_window::LEFT)
+ {
+ drag = true;
+ this->x = x - rect.left();
+ this->y = y - rect.top();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // mouse_over_event object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ mouse_over_event::~mouse_over_event() {}
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_over_event::
+ on_mouse_leave (
+ )
+ {
+ if (is_mouse_over_)
+ {
+ is_mouse_over_ = false;
+ on_mouse_not_over();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_over_event::
+ on_mouse_move (
+ unsigned long ,
+ long x,
+ long y
+ )
+ {
+ if (rect.contains(x,y) == false)
+ {
+ if (is_mouse_over_)
+ {
+ is_mouse_over_ = false;
+ on_mouse_not_over();
+ }
+ }
+ else if (is_mouse_over_ == false)
+ {
+ is_mouse_over_ = true;
+ if (enabled && !hidden)
+ on_mouse_over();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool mouse_over_event::
+ is_mouse_over (
+ ) const
+ {
+ // check if the mouse is still really over this button
+ if (is_mouse_over_ && rect.contains(lastx,lasty) == false)
+ {
+ // trigger a user event to call on_mouse_not_over() and repaint this object.
+ // we must do this in another event because someone might call is_mouse_over()
+ // from draw() and you don't want this function to end up calling
+ // parent.invalidate_rectangle(). It would lead to draw() being called over
+ // and over.
+ parent.trigger_user_event((void*)this,drawable::next_free_user_event_number());
+ return false;
+ }
+
+ return is_mouse_over_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_over_event::
+ on_user_event (
+ int num
+ )
+ {
+ if (is_mouse_over_ && num == drawable::next_free_user_event_number())
+ {
+ is_mouse_over_ = false;
+ on_mouse_not_over();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // button_action object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ button_action::~button_action() {}
+
+// ----------------------------------------------------------------------------------------
+
+ void button_action::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (enabled && !hidden && btn == base_window::LEFT && rect.contains(x,y))
+ {
+ is_depressed_ = true;
+ seen_click = true;
+ parent.invalidate_rectangle(rect);
+ on_button_down();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button_action::
+ on_mouse_not_over (
+ )
+ {
+ if (is_depressed_)
+ {
+ is_depressed_ = false;
+ parent.invalidate_rectangle(rect);
+ on_button_up(false);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button_action::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ // forward event to the parent class so it can do it's thing as well as us
+ mouse_over_event::on_mouse_move(state,x,y);
+
+ if (enabled == false || hidden == true)
+ return;
+
+
+ if ((state & base_window::LEFT) == 0)
+ {
+ seen_click = false;
+ if (is_depressed_)
+ {
+ is_depressed_ = false;
+ parent.invalidate_rectangle(rect);
+ on_button_up(false);
+ }
+
+ // the left button isn't down so we don't care about anything else
+ return;
+ }
+
+ if (rect.contains(x,y) == false)
+ {
+ if (is_depressed_)
+ {
+ is_depressed_ = false;
+ parent.invalidate_rectangle(rect);
+ on_button_up(false);
+ }
+ }
+ else if (is_depressed_ == false && seen_click)
+ {
+ is_depressed_ = true;
+ parent.invalidate_rectangle(rect);
+ on_button_down();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button_action::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long x,
+ long y
+ )
+ {
+ if (enabled && !hidden && btn == base_window::LEFT)
+ {
+ if (is_depressed_)
+ {
+ is_depressed_ = false;
+ parent.invalidate_rectangle(rect);
+
+ if (rect.contains(x,y))
+ {
+ on_button_up(true);
+ }
+ else
+ {
+ on_button_up(false);
+ }
+ }
+ else if (seen_click && rect.contains(x,y))
+ {
+ // this case here covers the unlikly event that you click on a button,
+ // move the mouse off the button and then move it back very quickly and
+ // release the mouse button. It is possible that this mouse up event
+ // will occurr before any mouse move event so you might not have set
+ // that the button is depressed yet.
+
+ // So we should say that this triggers an on_button_down() event and
+ // then an on_button_up(true) event.
+
+ parent.invalidate_rectangle(rect);
+
+ on_button_down();
+ on_button_up(true);
+ }
+
+ seen_click = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool button_action::
+ is_depressed (
+ ) const
+ {
+ // check if the mouse is still really over this button
+ if (enabled && !hidden && is_depressed_ && rect.contains(lastx,lasty) == false)
+ {
+ // trigger a user event to call on_button_up() and repaint this object.
+ // we must do this in another event because someone might call is_depressed()
+ // from draw() and you don't want this function to end up calling
+ // parent.invalidate_rectangle(). It would lead to draw() being called over
+ // and over.
+ parent.trigger_user_event((void*)this,mouse_over_event::next_free_user_event_number());
+ return false;
+ }
+
+ return is_depressed_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void button_action::
+ on_user_event (
+ int num
+ )
+ {
+ // forward event to the parent class so it can do it's thing as well as us
+ mouse_over_event::on_user_event(num);
+
+ if (is_depressed_ && num == mouse_over_event::next_free_user_event_number())
+ {
+ is_depressed_ = false;
+ parent.invalidate_rectangle(rect);
+ on_button_up(false);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scroll_bar object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ scroll_bar::
+ scroll_bar(
+ drawable_window& w,
+ bar_orientation orientation
+ ) :
+ drawable(w),
+ b1(w),
+ b2(w),
+ slider(w,*this,&scroll_bar::on_slider_drag),
+ ori(orientation),
+ top_filler(w,*this,&scroll_bar::top_filler_down,&scroll_bar::top_filler_up),
+ bottom_filler(w,*this,&scroll_bar::bottom_filler_down,&scroll_bar::bottom_filler_up),
+ pos(0),
+ max_pos(0),
+ js(10),
+ b1_timer(*this,&scroll_bar::b1_down_t),
+ b2_timer(*this,&scroll_bar::b2_down_t),
+ top_filler_timer(*this,&scroll_bar::top_filler_down_t),
+ bottom_filler_timer(*this,&scroll_bar::bottom_filler_down_t)
+ {
+ set_style(scroll_bar_style_default());
+
+ // don't show the slider when there is no place it can move.
+ slider.hide();
+
+ set_length(100);
+
+ b1.set_button_down_handler(*this,&scroll_bar::b1_down);
+ b2.set_button_down_handler(*this,&scroll_bar::b2_down);
+
+ b1.set_button_up_handler(*this,&scroll_bar::b1_up);
+ b2.set_button_up_handler(*this,&scroll_bar::b2_up);
+ b1.disable();
+ b2.disable();
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ scroll_bar::
+ ~scroll_bar(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ // wait for all the timers to be stopped
+ b1_timer.stop_and_wait();
+ b2_timer.stop_and_wait();
+ top_filler_timer.stop_and_wait();
+ bottom_filler_timer.stop_and_wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ scroll_bar::bar_orientation scroll_bar::
+ orientation (
+ ) const
+ {
+ auto_mutex M(m);
+ return ori;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ set_length (
+ unsigned long length
+ )
+ {
+ auto_mutex M(m);
+ // make the min length be at least 1
+ if (length == 0)
+ {
+ length = 1;
+ }
+
+
+ parent.invalidate_rectangle(rect);
+
+ if (ori == HORIZONTAL)
+ {
+ rect.set_right(rect.left() + length - 1);
+ rect.set_bottom(rect.top() + style->get_width() - 1);
+
+ const long btn_size = style->get_button_length(rect.width(), max_pos);
+
+ b1.set_size(btn_size,style->get_width());
+ b2.set_size(btn_size,style->get_width());
+
+ slider.set_size(get_slider_size(),style->get_width());
+ }
+ else
+ {
+ rect.set_right(rect.left() + style->get_width() - 1);
+ rect.set_bottom(rect.top() + length - 1);
+
+ const long btn_size = style->get_button_length(rect.height(), max_pos);
+
+ b1.set_size(style->get_width(),btn_size);
+ b2.set_size(style->get_width(),btn_size);
+
+ slider.set_size(style->get_width(),get_slider_size());
+ }
+
+ // call this to put everything is in the right spot.
+ set_pos (rect.left(),rect.top());
+
+ if ((b2.get_rect().top() - b1.get_rect().bottom() - 1 <= 8 && ori == VERTICAL) ||
+ (b2.get_rect().left() - b1.get_rect().right() - 1 <= 8 && ori == HORIZONTAL) ||
+ max_pos == 0)
+ {
+ hide_slider();
+ }
+ else if (enabled && !hidden)
+ {
+ show_slider();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ drawable::set_pos(x,y);
+
+ b1.set_pos(rect.left(),rect.top());
+ if (ori == HORIZONTAL)
+ {
+ // make the b2 button appear at the end of the scroll_bar
+ b2.set_pos(rect.right()-b2.get_rect().width() + 1,rect.top());
+
+ if (max_pos != 0)
+ {
+ double range = b2.get_rect().left() - b1.get_rect().right() - slider.get_rect().width() - 1;
+ double slider_pos = pos;
+ slider_pos /= max_pos;
+ slider_pos *= range;
+ slider.set_pos(
+ static_cast<long>(slider_pos)+rect.left() + b1.get_rect().width(),
+ rect.top()
+ );
+
+ // move the draggable area for the slider to the new location
+ rectangle area = rect;
+ area.set_left(area.left() + style->get_width());
+ area.set_right(area.right() - style->get_width());
+ slider.set_draggable_area(area);
+
+ }
+
+
+ }
+ else
+ {
+ // make the b2 button appear at the end of the scroll_bar
+ b2.set_pos(rect.left(), rect.bottom() - b2.get_rect().height() + 1);
+
+ if (max_pos != 0)
+ {
+ double range = b2.get_rect().top() - b1.get_rect().bottom() - slider.get_rect().height() - 1;
+ double slider_pos = pos;
+ slider_pos /= max_pos;
+ slider_pos *= range;
+ slider.set_pos(
+ rect.left(),
+ static_cast<long>(slider_pos) + rect.top() + b1.get_rect().height()
+ );
+
+ // move the draggable area for the slider to the new location
+ rectangle area = rect;
+ area.set_top(area.top() + style->get_width());
+ area.set_bottom(area.bottom() - style->get_width());
+ slider.set_draggable_area(area);
+ }
+ }
+
+ adjust_fillers();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long scroll_bar::
+ get_slider_size (
+ ) const
+ {
+ if (ori == HORIZONTAL)
+ return style->get_slider_length(rect.width(),max_pos);
+ else
+ return style->get_slider_length(rect.height(),max_pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ adjust_fillers (
+ )
+ {
+ rectangle top(rect), bottom(rect);
+
+ if (ori == HORIZONTAL)
+ {
+ if (slider.is_hidden())
+ {
+ top.set_left(b1.get_rect().right()+1);
+ top.set_right(b2.get_rect().left()-1);
+ bottom.set_left(1);
+ bottom.set_right(-1);
+ }
+ else
+ {
+ top.set_left(b1.get_rect().right()+1);
+ top.set_right(slider.get_rect().left()-1);
+ bottom.set_left(slider.get_rect().right()+1);
+ bottom.set_right(b2.get_rect().left()-1);
+ }
+ }
+ else
+ {
+ if (slider.is_hidden())
+ {
+ top.set_top(b1.get_rect().bottom()+1);
+ top.set_bottom(b2.get_rect().top()-1);
+ bottom.set_top(1);
+ bottom.set_bottom(-1);
+ }
+ else
+ {
+ top.set_top(b1.get_rect().bottom()+1);
+ top.set_bottom(slider.get_rect().top()-1);
+ bottom.set_top(slider.get_rect().bottom()+1);
+ bottom.set_bottom(b2.get_rect().top()-1);
+ }
+ }
+
+ top_filler.rect = top;
+ bottom_filler.rect = bottom;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ hide_slider (
+ )
+ {
+ rectangle top(rect), bottom(rect);
+ slider.hide();
+ top_filler.disable();
+ bottom_filler.disable();
+ bottom_filler.hide();
+ if (ori == HORIZONTAL)
+ {
+ top.set_left(b1.get_rect().right()+1);
+ top.set_right(b2.get_rect().left()-1);
+ }
+ else
+ {
+ top.set_top(b1.get_rect().bottom()+1);
+ top.set_bottom(b2.get_rect().top()-1);
+ }
+ top_filler.rect = top;
+ bottom_filler.rect = bottom;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ show_slider (
+ )
+ {
+ if ((b2.get_rect().top() - b1.get_rect().bottom() - 1 <= 8 && ori == VERTICAL) ||
+ (b2.get_rect().left() - b1.get_rect().right() - 1 <= 8 && ori == HORIZONTAL) ||
+ max_pos == 0)
+ return;
+
+ rectangle top(rect), bottom(rect);
+ slider.show();
+ top_filler.enable();
+ bottom_filler.enable();
+ bottom_filler.show();
+ if (ori == HORIZONTAL)
+ {
+ top.set_left(b1.get_rect().right()+1);
+ top.set_right(slider.get_rect().left()-1);
+ bottom.set_left(slider.get_rect().right()+1);
+ bottom.set_right(b2.get_rect().left()-1);
+ }
+ else
+ {
+ top.set_top(b1.get_rect().bottom()+1);
+ top.set_bottom(slider.get_rect().top()-1);
+ bottom.set_top(slider.get_rect().bottom()+1);
+ bottom.set_bottom(b2.get_rect().top()-1);
+ }
+ top_filler.rect = top;
+ bottom_filler.rect = bottom;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scroll_bar::
+ max_slider_pos (
+ ) const
+ {
+ auto_mutex M(m);
+ return max_pos;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ set_max_slider_pos (
+ long mpos
+ )
+ {
+ auto_mutex M(m);
+ max_pos = mpos;
+ if (pos > mpos)
+ pos = mpos;
+
+ if (ori == HORIZONTAL)
+ set_length(rect.width());
+ else
+ set_length(rect.height());
+
+ if (mpos != 0 && enabled)
+ {
+ b1.enable();
+ b2.enable();
+ }
+ else
+ {
+ b1.disable();
+ b2.disable();
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ set_slider_pos (
+ long pos
+ )
+ {
+ auto_mutex M(m);
+ if (pos < 0)
+ pos = 0;
+ if (pos > max_pos)
+ pos = max_pos;
+
+ this->pos = pos;
+
+ // move the slider object to its new position
+ set_pos(rect.left(),rect.top());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scroll_bar::
+ slider_pos (
+ ) const
+ {
+ auto_mutex M(m);
+ return pos;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ on_slider_drag (
+ )
+ {
+ if (ori == HORIZONTAL)
+ {
+ double slider_pos = slider.get_rect().left() - b1.get_rect().right() - 1;
+ double range = b2.get_rect().left() - b1.get_rect().right() - slider.get_rect().width() - 1;
+ slider_pos /= range;
+ slider_pos *= max_pos;
+ pos = static_cast<unsigned long>(slider_pos);
+ }
+ else
+ {
+ double slider_pos = slider.get_rect().top() - b1.get_rect().bottom() - 1;
+ double range = b2.get_rect().top() - b1.get_rect().bottom() - slider.get_rect().height() - 1;
+ slider_pos /= range;
+ slider_pos *= max_pos;
+ pos = static_cast<unsigned long>(slider_pos);
+ }
+
+ adjust_fillers();
+
+ if (scroll_handler.is_set())
+ scroll_handler();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ draw (
+ const canvas&
+ ) const
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b1_down (
+ )
+ {
+ if (pos != 0)
+ {
+ set_slider_pos(pos-1);
+ if (scroll_handler.is_set())
+ scroll_handler();
+
+ if (b1_timer.delay_time() == 1000)
+ b1_timer.set_delay_time(500);
+ else
+ b1_timer.set_delay_time(50);
+ b1_timer.start();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b1_up (
+ bool
+ )
+ {
+ b1_timer.stop();
+ b1_timer.set_delay_time(1000);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b2_down (
+ )
+ {
+ if (pos != max_pos)
+ {
+ set_slider_pos(pos+1);
+ if (scroll_handler.is_set())
+ scroll_handler();
+
+ if (b2_timer.delay_time() == 1000)
+ b2_timer.set_delay_time(500);
+ else
+ b2_timer.set_delay_time(50);
+ b2_timer.start();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b2_up (
+ bool
+ )
+ {
+ b2_timer.stop();
+ b2_timer.set_delay_time(1000);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ top_filler_down (
+ )
+ {
+ // ignore this if the mouse is now outside this object. This could happen
+ // since the timers are also calling this function.
+ if (top_filler.rect.contains(lastx,lasty) == false)
+ {
+ top_filler_up(false);
+ return;
+ }
+
+ if (pos != 0)
+ {
+ if (pos < js)
+ {
+ // if there is less than jump_size() space left then jump the remaining
+ // amount.
+ delayed_set_slider_pos(0);
+ }
+ else
+ {
+ delayed_set_slider_pos(pos-js);
+ }
+
+ if (top_filler_timer.delay_time() == 1000)
+ top_filler_timer.set_delay_time(500);
+ else
+ top_filler_timer.set_delay_time(50);
+ top_filler_timer.start();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ top_filler_up (
+ bool
+ )
+ {
+ top_filler_timer.stop();
+ top_filler_timer.set_delay_time(1000);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ bottom_filler_down (
+ )
+ {
+ // ignore this if the mouse is now outside this object. This could happen
+ // since the timers are also calling this function.
+ if (bottom_filler.rect.contains(lastx,lasty) == false)
+ {
+ bottom_filler_up(false);
+ return;
+ }
+
+ if (pos != max_pos)
+ {
+ if (max_pos - pos < js)
+ {
+ // if there is less than jump_size() space left then jump the remaining
+ // amount.
+ delayed_set_slider_pos(max_pos);
+ }
+ else
+ {
+ delayed_set_slider_pos(pos+js);
+ }
+
+ if (bottom_filler_timer.delay_time() == 1000)
+ bottom_filler_timer.set_delay_time(500);
+ else
+ bottom_filler_timer.set_delay_time(50);
+ bottom_filler_timer.start();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ bottom_filler_up (
+ bool
+ )
+ {
+ bottom_filler_timer.stop();
+ bottom_filler_timer.set_delay_time(1000);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ set_jump_size (
+ long js_
+ )
+ {
+ auto_mutex M(m);
+ if (js_ < 1)
+ js = 1;
+ else
+ js = js_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scroll_bar::
+ jump_size (
+ ) const
+ {
+ auto_mutex M(m);
+ return js;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ on_user_event (
+ int i
+ )
+ {
+ switch (i)
+ {
+ case 0:
+ b1_down();
+ break;
+ case 1:
+ b2_down();
+ break;
+ case 2:
+ top_filler_down();
+ break;
+ case 3:
+ bottom_filler_down();
+ break;
+ case 4:
+ // if the position we are supposed to switch the slider too isn't
+ // already set
+ if (delayed_pos != pos)
+ {
+ set_slider_pos(delayed_pos);
+ if (scroll_handler.is_set())
+ scroll_handler();
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ delayed_set_slider_pos (
+ unsigned long dpos
+ )
+ {
+ delayed_pos = dpos;
+ parent.trigger_user_event(this,4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b1_down_t (
+ )
+ {
+ parent.trigger_user_event(this,0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ b2_down_t (
+ )
+ {
+ parent.trigger_user_event(this,1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ top_filler_down_t (
+ )
+ {
+ parent.trigger_user_event(this,2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar::
+ bottom_filler_down_t (
+ )
+ {
+ parent.trigger_user_event(this,3);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// widget_group object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ empty (
+ )
+ {
+ auto_mutex M(m);
+ widgets.clear();
+ wg_widgets.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ add (
+ drawable& widget,
+ unsigned long x,
+ unsigned long y
+ )
+ {
+ auto_mutex M(m);
+ drawable* w = &widget;
+ relpos rp;
+ rp.x = x;
+ rp.y = y;
+ if (widgets.is_in_domain(w))
+ {
+ widgets[w].x = x;
+ widgets[w].y = y;
+ }
+ else
+ {
+ widgets.add(w,rp);
+ }
+ if (is_hidden())
+ widget.hide();
+ else
+ widget.show();
+
+ if (is_enabled())
+ widget.enable();
+ else
+ widget.disable();
+
+ widget.set_z_order(z_order());
+ widget.set_pos(x+rect.left(),y+rect.top());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ add (
+ widget_group& widget,
+ unsigned long x,
+ unsigned long y
+ )
+ {
+ auto_mutex M(m);
+ drawable& w = widget;
+ add(w, x, y);
+
+ widget_group* wg = &widget;
+ wg_widgets.add(wg);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool widget_group::
+ is_member (
+ const drawable& widget
+ ) const
+ {
+ auto_mutex M(m);
+ drawable* w = const_cast<drawable*>(&widget);
+ return widgets.is_in_domain(w);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ remove (
+ const drawable& widget
+ )
+ {
+ auto_mutex M(m);
+ drawable* w = const_cast<drawable*>(&widget);
+ if (widgets.is_in_domain(w))
+ {
+ widgets.destroy(w);
+
+ // check if we also have an entry in the wg_widgets set and if
+ // so then remove that too
+ widget_group* wg = reinterpret_cast<widget_group*>(w);
+ if (wg_widgets.is_member(wg))
+ {
+ wg_widgets.destroy(wg);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ size_t widget_group::
+ size (
+ ) const
+ {
+ auto_mutex M(m);
+ return widgets.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ {
+ const unsigned long rx = widgets.element().value().x;
+ const unsigned long ry = widgets.element().value().y;
+ widgets.element().key()->set_pos(x+rx,y+ry);
+ }
+ drawable::set_pos(x,y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ set_z_order (
+ long order
+ )
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ widgets.element().key()->set_z_order(order);
+ drawable::set_z_order(order);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ show (
+ )
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ widgets.element().key()->show();
+ drawable::show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ widgets.element().key()->hide();
+ drawable::hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ enable (
+ )
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ widgets.element().key()->enable();
+ drawable::enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ disable ()
+ {
+ auto_mutex M(m);
+ widgets.reset();
+ while (widgets.move_next())
+ widgets.element().key()->disable();
+ drawable::disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void widget_group::
+ fit_to_contents (
+ )
+ {
+ auto_mutex M(m);
+
+ // call fit_to_contents on all the widget_groups we contain
+ wg_widgets.reset();
+ while (wg_widgets.move_next())
+ wg_widgets.element()->fit_to_contents();
+
+ // now accumulate a rectangle that contains everything in this widget_group
+ rectangle r;
+ widgets.reset();
+ while (widgets.move_next())
+ r = r + widgets.element().key()->get_rect();
+
+ if (r.is_empty())
+ {
+ // make sure it is still empty after we set it at the correct position
+ r.set_right(rect.left()-1);
+ r.set_bottom(rect.top()-1);
+ }
+
+ r.set_left(rect.left());
+ r.set_top(rect.top());
+ rect = r;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// class popup_menu
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ popup_menu::
+ popup_menu (
+ ) :
+ base_window(false,true),
+ pad(2),
+ item_pad(3),
+ cur_rect(pad,pad,pad-1,pad-1),
+ left_width(0),
+ middle_width(0),
+ selected_item(0),
+ submenu_open(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ enable_menu_item (
+ unsigned long idx
+ )
+ {
+ DLIB_ASSERT ( idx < size() ,
+ "\tvoid popup_menu::enable_menu_item()"
+ << "\n\tidx: " << idx
+ << "\n\tsize(): " << size()
+ );
+ auto_mutex M(wm);
+ item_enabled[idx] = true;
+ invalidate_rectangle(cur_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ disable_menu_item (
+ unsigned long idx
+ )
+ {
+ DLIB_ASSERT ( idx < size() ,
+ "\tvoid popup_menu::enable_menu_item()"
+ << "\n\tidx: " << idx
+ << "\n\tsize(): " << size()
+ );
+ auto_mutex M(wm);
+ item_enabled[idx] = false;
+ invalidate_rectangle(cur_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ size_t popup_menu::
+ size (
+ ) const
+ {
+ auto_mutex M(wm);
+ return items.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ clear (
+ )
+ {
+ auto_mutex M(wm);
+ hide();
+ cur_rect = rectangle(pad,pad,pad-1,pad-1);
+ win_rect = rectangle();
+ left_width = 0;
+ middle_width = 0;
+ items.clear();
+ item_enabled.clear();
+ left_rects.clear();
+ middle_rects.clear();
+ right_rects.clear();
+ line_rects.clear();
+ submenus.clear();
+ selected_item = 0;
+ submenu_open = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ show (
+ )
+ {
+ auto_mutex M(wm);
+ selected_item = submenus.size();
+ base_window::show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ hide (
+ )
+ {
+ auto_mutex M(wm);
+ // hide ourselves
+ close_submenu();
+ selected_item = submenus.size();
+ base_window::hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ select_first_item (
+ )
+ {
+ auto_mutex M(wm);
+ close_submenu();
+ selected_item = items.size();
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if ((items[i]->has_click_event() || submenus[i]) && item_enabled[i])
+ {
+ selected_item = i;
+ break;
+ }
+ }
+ invalidate_rectangle(cur_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool popup_menu::
+ forwarded_on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ auto_mutex M(wm);
+ // do nothing if this popup menu is empty
+ if (items.size() == 0)
+ return false;
+
+
+ // check if the selected item is a submenu
+ if (selected_item != submenus.size() && submenus[selected_item] != 0 && submenu_open)
+ {
+ // send the key to the submenu and return if that menu used the key
+ if (submenus[selected_item]->forwarded_on_keydown(key,is_printable,state) == true)
+ return true;
+ }
+
+ if (key == KEY_UP)
+ {
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ selected_item = (selected_item + items.size() - 1)%items.size();
+ // only stop looking if this one is enabled and has a click event or is a submenu
+ if (item_enabled[selected_item] && (items[selected_item]->has_click_event() || submenus[selected_item]) )
+ break;
+ }
+ invalidate_rectangle(cur_rect);
+ return true;
+ }
+ else if (key == KEY_DOWN)
+ {
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ selected_item = (selected_item + 1)%items.size();
+ // only stop looking if this one is enabled and has a click event or is a submenu
+ if (item_enabled[selected_item] && (items[selected_item]->has_click_event() || submenus[selected_item]))
+ break;
+ }
+ invalidate_rectangle(cur_rect);
+ return true;
+ }
+ else if (key == KEY_RIGHT && submenu_open == false && display_selected_submenu())
+ {
+ submenus[selected_item]->select_first_item();
+ return true;
+ }
+ else if (key == KEY_LEFT && selected_item != submenus.size() &&
+ submenus[selected_item] != 0 && submenu_open)
+ {
+ close_submenu();
+ return true;
+ }
+ else if (key == '\n')
+ {
+ if (selected_item != submenus.size() && (items[selected_item]->has_click_event() || submenus[selected_item]))
+ {
+ const long idx = selected_item;
+ // only hide this popup window if this isn't a submenu
+ if (submenus[idx] == 0)
+ {
+ hide();
+ hide_handlers.reset();
+ while (hide_handlers.move_next())
+ hide_handlers.element()();
+ }
+ else
+ {
+ display_selected_submenu();
+ submenus[idx]->select_first_item();
+ }
+ items[idx]->on_click();
+ return true;
+ }
+ }
+ else if (is_printable)
+ {
+ // check if there is a hotkey for this key
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (std::tolower(key) == std::tolower(items[i]->get_hot_key()) &&
+ (items[i]->has_click_event() || submenus[i]) && item_enabled[i] )
+ {
+ // only hide this popup window if this isn't a submenu
+ if (submenus[i] == 0)
+ {
+ hide();
+ hide_handlers.reset();
+ while (hide_handlers.move_next())
+ hide_handlers.element()();
+ }
+ else
+ {
+ if (selected_item != items.size())
+ invalidate_rectangle(line_rects[selected_item]);
+
+ selected_item = i;
+ display_selected_submenu();
+ invalidate_rectangle(line_rects[i]);
+ submenus[i]->select_first_item();
+ }
+ items[i]->on_click();
+ }
+ }
+
+ // always say we use a printable key for hotkeys
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ on_submenu_hide (
+ )
+ {
+ hide();
+ hide_handlers.reset();
+ while (hide_handlers.move_next())
+ hide_handlers.element()();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ on_window_resized(
+ )
+ {
+ invalidate_rectangle(win_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long x,
+ long y
+ )
+ {
+ if (cur_rect.contains(x,y) && btn == LEFT)
+ {
+ // figure out which item this was on
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (line_rects[i].contains(x,y) && item_enabled[i] && items[i]->has_click_event())
+ {
+ // only hide this popup window if this isn't a submenu
+ if (submenus[i] == 0)
+ {
+ hide();
+ hide_handlers.reset();
+ while (hide_handlers.move_next())
+ hide_handlers.element()();
+ }
+ items[i]->on_click();
+ break;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ on_mouse_move (
+ unsigned long ,
+ long x,
+ long y
+ )
+ {
+ if (cur_rect.contains(x,y))
+ {
+ // check if the mouse is still in the same rect it was in last time
+ rectangle last_rect;
+ if (selected_item != submenus.size())
+ {
+ last_rect = line_rects[selected_item];
+ }
+
+ // if the mouse isn't in the same rectangle any more
+ if (last_rect.contains(x,y) == false)
+ {
+ if (selected_item != submenus.size())
+ {
+ invalidate_rectangle(last_rect);
+ close_submenu();
+ selected_item = submenus.size();
+ }
+
+
+ // figure out if we should redraw any menu items
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (items[i]->has_click_event() || submenus[i])
+ {
+ if (line_rects[i].contains(x,y))
+ {
+ selected_item = i;
+ break;
+ }
+ }
+ }
+
+ // if we found a rectangle that contains the mouse then
+ // tell it to redraw itself
+ if (selected_item != submenus.size())
+ {
+ display_selected_submenu();
+ invalidate_rectangle(line_rects[selected_item]);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ close_submenu (
+ )
+ {
+ if (selected_item != submenus.size() && submenus[selected_item] && submenu_open)
+ {
+ submenus[selected_item]->hide();
+ submenu_open = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool popup_menu::
+ display_selected_submenu (
+ )
+ {
+ // show the submenu if one exists
+ if (selected_item != submenus.size() && submenus[selected_item])
+ {
+ long wx, wy;
+ get_pos(wx,wy);
+ wx += line_rects[selected_item].right();
+ wy += line_rects[selected_item].top();
+ submenus[selected_item]->set_pos(wx+1,wy-2);
+ submenus[selected_item]->show();
+ submenu_open = true;
+ return true;
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ on_mouse_leave (
+ )
+ {
+ if (selected_item != submenus.size())
+ {
+ // only unhighlight a menu item if it isn't a submenu item
+ if (submenus[selected_item] == 0)
+ {
+ invalidate_rectangle(line_rects[selected_item]);
+ selected_item = submenus.size();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu::
+ paint (
+ const canvas& c
+ )
+ {
+ c.fill(200,200,200);
+ draw_rectangle(c, win_rect);
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ bool is_selected = false;
+ if (selected_item != submenus.size() && i == selected_item &&
+ item_enabled[i])
+ is_selected = true;
+
+ items[i]->draw_background(c,line_rects[i], item_enabled[i], is_selected);
+ items[i]->draw_left(c,left_rects[i], item_enabled[i], is_selected);
+ items[i]->draw_middle(c,middle_rects[i], item_enabled[i], is_selected);
+ items[i]->draw_right(c,right_rects[i], item_enabled[i], is_selected);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// class zoomable_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ zoomable_region::
+ zoomable_region (
+ drawable_window& w,
+ unsigned long events
+ ) :
+ drawable(w,MOUSE_CLICK | MOUSE_WHEEL | MOUSE_MOVE | events),
+ min_scale(0.15),
+ max_scale(1.0),
+ zoom_increment_(0.90),
+ vsb(w, scroll_bar::VERTICAL),
+ hsb(w, scroll_bar::HORIZONTAL)
+ {
+ scale = 1;
+ mouse_drag_screen = false;
+ style.reset(new scrollable_region_style_default());
+
+ hsb.set_scroll_handler(*this,&zoomable_region::on_h_scroll);
+ vsb.set_scroll_handler(*this,&zoomable_region::on_v_scroll);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ zoomable_region::
+ ~zoomable_region()
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ drawable::set_pos(x,y);
+ const long border_size = style->get_border_size();
+ vsb.set_pos(rect.right()-border_size+1-vsb.width(),rect.top()+border_size);
+ hsb.set_pos(rect.left()+border_size,rect.bottom()-border_size+1-hsb.height());
+
+ display_rect_ = rectangle(rect.left()+border_size,
+ rect.top()+border_size,
+ rect.right()-border_size-vsb.width(),
+ rect.bottom()-border_size-hsb.height());
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_zoom_increment (
+ double zi
+ )
+ {
+ DLIB_ASSERT(0.0 < zi && zi < 1.0,
+ "\tvoid zoomable_region::set_zoom_increment(zi)"
+ << "\n\t the zoom increment must be between 0 and 1"
+ << "\n\t zi: " << zi
+ << "\n\t this: " << this
+ );
+
+ auto_mutex M(m);
+ zoom_increment_ = zi;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double zoomable_region::
+ zoom_increment (
+ ) const
+ {
+ auto_mutex M(m);
+ return zoom_increment_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_max_zoom_scale (
+ double ms
+ )
+ {
+ DLIB_ASSERT(ms > 0,
+ "\tvoid zoomable_region::set_max_zoom_scale(ms)"
+ << "\n\t the max zoom scale must be greater than 0"
+ << "\n\t ms: " << ms
+ << "\n\t this: " << this
+ );
+
+ auto_mutex M(m);
+ max_scale = ms;
+ if (scale > ms)
+ {
+ scale = max_scale;
+ lr_point = gui_to_graph_space(point(display_rect_.right(),display_rect_.bottom()));
+ redraw_graph();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_min_zoom_scale (
+ double ms
+ )
+ {
+ DLIB_ASSERT(ms > 0,
+ "\tvoid zoomable_region::set_min_zoom_scale(ms)"
+ << "\n\t the min zoom scale must be greater than 0"
+ << "\n\t ms: " << ms
+ << "\n\t this: " << this
+ );
+
+ auto_mutex M(m);
+ min_scale = ms;
+
+ if (scale < ms)
+ {
+ scale = min_scale;
+ }
+
+ // just call set_size so that everything gets redrawn right
+ set_size(rect.width(), rect.height());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double zoomable_region::
+ min_zoom_scale (
+ ) const
+ {
+ auto_mutex M(m);
+ return min_scale;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double zoomable_region::
+ max_zoom_scale (
+ ) const
+ {
+ auto_mutex M(m);
+ return max_scale;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle old(rect);
+ const long border_size = style->get_border_size();
+ rect = resize_rect(rect,width,height);
+ vsb.set_pos(rect.right()-border_size+1-vsb.width(), rect.top()+border_size);
+ hsb.set_pos(rect.left()+border_size, rect.bottom()-border_size+1-hsb.height());
+
+ display_rect_ = rectangle(rect.left()+border_size,
+ rect.top()+border_size,
+ rect.right()-border_size-vsb.width(),
+ rect.bottom()-border_size-hsb.height());
+ vsb.set_length(display_rect_.height());
+ hsb.set_length(display_rect_.width());
+ parent.invalidate_rectangle(rect+old);
+
+ const double old_scale = scale;
+ const vector<double,2> old_gr_orig(gr_orig);
+ scale = min_scale;
+ gr_orig = vector<double,2>(0,0);
+ lr_point = gui_to_graph_space(point(display_rect_.right(),display_rect_.bottom()));
+ scale = old_scale;
+
+ // call adjust_origin() so that the scroll bars get their max slider positions
+ // setup right
+ const point rect_corner(display_rect_.left(), display_rect_.top());
+ adjust_origin(rect_corner, old_gr_orig);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ show (
+ )
+ {
+ auto_mutex M(m);
+ drawable::show();
+ hsb.show();
+ vsb.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ drawable::hide();
+ hsb.hide();
+ vsb.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ enable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::enable();
+ hsb.enable();
+ vsb.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::disable();
+ hsb.disable();
+ vsb.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_z_order (
+ long order
+ )
+ {
+ auto_mutex M(m);
+ drawable::set_z_order(order);
+ hsb.set_z_order(order);
+ vsb.set_z_order(order);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ point zoomable_region::
+ graph_to_gui_space (
+ const vector<double,2>& p
+ ) const
+ {
+ const point rect_corner(display_rect_.left(), display_rect_.top());
+ return (p - gr_orig)*scale + rect_corner;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ vector<double,2> zoomable_region::
+ gui_to_graph_space (
+ const point& p
+ ) const
+ {
+ const point rect_corner(display_rect_.left(), display_rect_.top());
+ return (p - rect_corner)/scale + gr_orig;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ point zoomable_region::
+ max_graph_point (
+ ) const
+ {
+ return lr_point;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle zoomable_region::
+ display_rect (
+ ) const
+ {
+ return display_rect_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double zoomable_region::
+ zoom_scale (
+ ) const
+ {
+ return scale;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ set_zoom_scale (
+ double new_scale
+ )
+ {
+ // if new_scale isn't in the right range then put it back in range before we do the
+ // rest of this function
+ if (!(min_scale <= new_scale && new_scale <= max_scale))
+ {
+ if (new_scale > max_scale)
+ new_scale = max_scale;
+ else
+ new_scale = min_scale;
+ }
+
+ // find the point in the center of the graph area
+ point center((display_rect_.left()+display_rect_.right())/2, (display_rect_.top()+display_rect_.bottom())/2);
+ point graph_p(gui_to_graph_space(center));
+ scale = new_scale;
+ adjust_origin(center, graph_p);
+ redraw_graph();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ center_display_at_graph_point (
+ const vector<double,2>& p
+ )
+ {
+ // find the point in the center of the graph area
+ point center((display_rect_.left()+display_rect_.right())/2, (display_rect_.top()+display_rect_.bottom())/2);
+ adjust_origin(center, p);
+ redraw_graph();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_wheel_down (
+ unsigned long
+ )
+ {
+ // zoom out
+ if (enabled && !hidden && scale > min_scale && display_rect_.contains(lastx,lasty))
+ {
+ point gui_p(lastx,lasty);
+ point graph_p(gui_to_graph_space(gui_p));
+ const double old_scale = scale;
+ scale *= zoom_increment_;
+ if (scale < min_scale)
+ scale = min_scale;
+ redraw_graph();
+ adjust_origin(gui_p, graph_p);
+
+ if (scale != old_scale)
+ on_view_changed();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_wheel_up (
+ unsigned long
+ )
+ {
+ // zoom in
+ if (enabled && !hidden && scale < max_scale && display_rect_.contains(lastx,lasty))
+ {
+ point gui_p(lastx,lasty);
+ point graph_p(gui_to_graph_space(gui_p));
+ const double old_scale = scale;
+ scale /= zoom_increment_;
+ if (scale > max_scale)
+ scale = max_scale;
+ redraw_graph();
+ adjust_origin(gui_p, graph_p);
+
+ if (scale != old_scale)
+ on_view_changed();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (enabled && !hidden && mouse_drag_screen)
+ {
+ adjust_origin(point(x,y), drag_screen_point);
+ redraw_graph();
+ on_view_changed();
+ }
+
+ // check if the mouse isn't being dragged anymore
+ if ((state & base_window::LEFT) == 0)
+ {
+ mouse_drag_screen = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_mouse_up (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long
+ )
+ {
+ mouse_drag_screen = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (enabled && !hidden && display_rect_.contains(x,y) && btn == base_window::LEFT)
+ {
+ mouse_drag_screen = true;
+ drag_screen_point = gui_to_graph_space(point(x,y));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ draw (
+ const canvas& c
+ ) const
+ {
+ style->draw_scrollable_region_border(c, rect, enabled);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_h_scroll (
+ )
+ {
+ gr_orig.x() = hsb.slider_pos();
+ redraw_graph();
+
+ on_view_changed();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ on_v_scroll (
+ )
+ {
+ gr_orig.y() = vsb.slider_pos();
+ redraw_graph();
+
+ on_view_changed();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ redraw_graph (
+ )
+ {
+ parent.invalidate_rectangle(display_rect_);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void zoomable_region::
+ adjust_origin (
+ const point& gui_p,
+ const vector<double,2>& graph_p
+ )
+ {
+ const point rect_corner(display_rect_.left(), display_rect_.top());
+ const dlib::vector<double,2> v(gui_p - rect_corner);
+ gr_orig = graph_p - v/scale;
+
+
+ // make sure the origin isn't outside the point (0,0)
+ if (gr_orig.x() < 0)
+ gr_orig.x() = 0;
+ if (gr_orig.y() < 0)
+ gr_orig.y() = 0;
+
+ // make sure the lower right corner of the display_rect_ doesn't map to a point beyond lr_point
+ point lr_rect_corner(display_rect_.right(), display_rect_.bottom());
+ point p = graph_to_gui_space(lr_point);
+ vector<double,2> lr_rect_corner_graph_space(gui_to_graph_space(lr_rect_corner));
+ vector<double,2> delta(lr_point - lr_rect_corner_graph_space);
+ if (lr_rect_corner.x() > p.x())
+ {
+ gr_orig.x() += delta.x();
+ }
+
+ if (lr_rect_corner.y() > p.y())
+ {
+ gr_orig.y() += delta.y();
+ }
+
+
+ const vector<double,2> ul_rect_corner_graph_space(gui_to_graph_space(rect_corner));
+ lr_rect_corner_graph_space = gui_to_graph_space(lr_rect_corner);
+ // now adjust the scroll bars
+
+ hsb.set_max_slider_pos((unsigned long)std::max(lr_point.x()-(lr_rect_corner_graph_space.x()-ul_rect_corner_graph_space.x()),0.0));
+ vsb.set_max_slider_pos((unsigned long)std::max(lr_point.y()-(lr_rect_corner_graph_space.y()-ul_rect_corner_graph_space.y()),0.0));
+ // adjust slider position now.
+ hsb.set_slider_pos(static_cast<long>(ul_rect_corner_graph_space.x()));
+ vsb.set_slider_pos(static_cast<long>(ul_rect_corner_graph_space.y()));
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// class scrollable_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ scrollable_region::
+ scrollable_region (
+ drawable_window& w,
+ unsigned long events
+ ) :
+ drawable(w, MOUSE_WHEEL|events|MOUSE_CLICK|MOUSE_MOVE),
+ hsb(w,scroll_bar::HORIZONTAL),
+ vsb(w,scroll_bar::VERTICAL),
+ hscroll_bar_inc(1),
+ vscroll_bar_inc(1),
+ h_wheel_scroll_bar_inc(1),
+ v_wheel_scroll_bar_inc(1),
+ mouse_drag_enabled_(false),
+ user_is_dragging_mouse(false)
+ {
+ style.reset(new scrollable_region_style_default());
+
+ hsb.set_scroll_handler(*this,&scrollable_region::on_h_scroll);
+ vsb.set_scroll_handler(*this,&scrollable_region::on_v_scroll);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ scrollable_region::
+ ~scrollable_region (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ show (
+ )
+ {
+ auto_mutex M(m);
+ drawable::show();
+ if (need_h_scroll())
+ hsb.show();
+ if (need_v_scroll())
+ vsb.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ drawable::hide();
+ hsb.hide();
+ vsb.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ enable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::enable();
+ hsb.enable();
+ vsb.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::disable();
+ hsb.disable();
+ vsb.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_z_order (
+ long order
+ )
+ {
+ auto_mutex M(m);
+ drawable::set_z_order(order);
+ hsb.set_z_order(order);
+ vsb.set_z_order(order);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle old(rect);
+ rect = resize_rect(rect,width,height);
+ vsb.set_pos(rect.right()-style->get_border_size()-vsb.width()+1, rect.top()+style->get_border_size());
+ hsb.set_pos(rect.left()+style->get_border_size(), rect.bottom()-style->get_border_size()-hsb.height()+1);
+
+ // adjust the display_rect_
+ if (need_h_scroll() && need_v_scroll())
+ {
+ // both scroll bars aren't hidden
+ if (!hidden)
+ {
+ vsb.show();
+ hsb.show();
+ }
+ display_rect_ = rectangle( rect.left()+style->get_border_size(),
+ rect.top()+style->get_border_size(),
+ rect.right()-style->get_border_size()-vsb.width(),
+ rect.bottom()-style->get_border_size()-hsb.height());
+
+ // figure out how many scroll bar positions there should be
+ unsigned long hdelta = total_rect_.width()-display_rect_.width();
+ unsigned long vdelta = total_rect_.height()-display_rect_.height();
+ hdelta = (hdelta+hscroll_bar_inc-1)/hscroll_bar_inc;
+ vdelta = (vdelta+vscroll_bar_inc-1)/vscroll_bar_inc;
+
+ hsb.set_max_slider_pos(hdelta);
+ vsb.set_max_slider_pos(vdelta);
+
+ vsb.set_jump_size((display_rect_.height()+vscroll_bar_inc-1)/vscroll_bar_inc/2+1);
+ hsb.set_jump_size((display_rect_.width()+hscroll_bar_inc-1)/hscroll_bar_inc/2+1);
+ }
+ else if (need_h_scroll())
+ {
+ // only hsb is hidden
+ if (!hidden)
+ {
+ hsb.show();
+ vsb.hide();
+ }
+ display_rect_ = rectangle( rect.left()+style->get_border_size(),
+ rect.top()+style->get_border_size(),
+ rect.right()-style->get_border_size(),
+ rect.bottom()-style->get_border_size()-hsb.height());
+
+ // figure out how many scroll bar positions there should be
+ unsigned long hdelta = total_rect_.width()-display_rect_.width();
+ hdelta = (hdelta+hscroll_bar_inc-1)/hscroll_bar_inc;
+
+ hsb.set_max_slider_pos(hdelta);
+ vsb.set_max_slider_pos(0);
+
+ hsb.set_jump_size((display_rect_.width()+hscroll_bar_inc-1)/hscroll_bar_inc/2+1);
+ }
+ else if (need_v_scroll())
+ {
+ // only vsb is hidden
+ if (!hidden)
+ {
+ hsb.hide();
+ vsb.show();
+ }
+ display_rect_ = rectangle( rect.left()+style->get_border_size(),
+ rect.top()+style->get_border_size(),
+ rect.right()-style->get_border_size()-vsb.width(),
+ rect.bottom()-style->get_border_size());
+
+ unsigned long vdelta = total_rect_.height()-display_rect_.height();
+ vdelta = (vdelta+vscroll_bar_inc-1)/vscroll_bar_inc;
+
+ hsb.set_max_slider_pos(0);
+ vsb.set_max_slider_pos(vdelta);
+
+ vsb.set_jump_size((display_rect_.height()+vscroll_bar_inc-1)/vscroll_bar_inc/2+1);
+ }
+ else
+ {
+ // both are hidden
+ if (!hidden)
+ {
+ hsb.hide();
+ vsb.hide();
+ }
+ display_rect_ = rectangle( rect.left()+style->get_border_size(),
+ rect.top()+style->get_border_size(),
+ rect.right()-style->get_border_size(),
+ rect.bottom()-style->get_border_size());
+
+ hsb.set_max_slider_pos(0);
+ vsb.set_max_slider_pos(0);
+ }
+
+ vsb.set_length(display_rect_.height());
+ hsb.set_length(display_rect_.width());
+
+ // adjust the total_rect_ position by trigging the scroll events
+ on_h_scroll();
+ on_v_scroll();
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long scrollable_region::
+ horizontal_mouse_wheel_scroll_increment (
+ ) const
+ {
+ auto_mutex M(m);
+ return h_wheel_scroll_bar_inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long scrollable_region::
+ vertical_mouse_wheel_scroll_increment (
+ ) const
+ {
+ auto_mutex M(m);
+ return v_wheel_scroll_bar_inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_horizontal_mouse_wheel_scroll_increment (
+ unsigned long inc
+ )
+ {
+ auto_mutex M(m);
+ h_wheel_scroll_bar_inc = inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_vertical_mouse_wheel_scroll_increment (
+ unsigned long inc
+ )
+ {
+ auto_mutex M(m);
+ v_wheel_scroll_bar_inc = inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long scrollable_region::
+ horizontal_scroll_increment (
+ ) const
+ {
+ auto_mutex M(m);
+ return hscroll_bar_inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long scrollable_region::
+ vertical_scroll_increment (
+ ) const
+ {
+ auto_mutex M(m);
+ return vscroll_bar_inc;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_horizontal_scroll_increment (
+ unsigned long inc
+ )
+ {
+ auto_mutex M(m);
+ hscroll_bar_inc = inc;
+ // call set_size to reset the scroll bars
+ set_size(rect.width(),rect.height());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_vertical_scroll_increment (
+ unsigned long inc
+ )
+ {
+ auto_mutex M(m);
+ vscroll_bar_inc = inc;
+ // call set_size to reset the scroll bars
+ set_size(rect.width(),rect.height());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scrollable_region::
+ horizontal_scroll_pos (
+ ) const
+ {
+ auto_mutex M(m);
+ return hsb.slider_pos();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scrollable_region::
+ vertical_scroll_pos (
+ ) const
+ {
+ auto_mutex M(m);
+ return vsb.slider_pos();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_horizontal_scroll_pos (
+ long pos
+ )
+ {
+ auto_mutex M(m);
+
+ hsb.set_slider_pos(pos);
+ on_h_scroll();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_vertical_scroll_pos (
+ long pos
+ )
+ {
+ auto_mutex M(m);
+
+ vsb.set_slider_pos(pos);
+ on_v_scroll();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ drawable::set_pos(x,y);
+ vsb.set_pos(rect.right()-style->get_border_size()-vsb.width()+1, rect.top()+style->get_border_size());
+ hsb.set_pos(rect.left()+style->get_border_size(), rect.bottom()-style->get_border_size()-hsb.height()+1);
+
+ const long delta_x = total_rect_.left() - display_rect_.left();
+ const long delta_y = total_rect_.top() - display_rect_.top();
+
+ display_rect_ = move_rect(display_rect_, rect.left()+style->get_border_size(), rect.top()+style->get_border_size());
+
+ total_rect_ = move_rect(total_rect_, display_rect_.left()+delta_x, display_rect_.top()+delta_y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool scrollable_region::
+ mouse_drag_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return mouse_drag_enabled_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ enable_mouse_drag (
+ )
+ {
+ auto_mutex M(m);
+ mouse_drag_enabled_ = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ disable_mouse_drag (
+ )
+ {
+ auto_mutex M(m);
+ mouse_drag_enabled_ = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle& scrollable_region::
+ display_rect (
+ ) const
+ {
+ return display_rect_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ set_total_rect_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ DLIB_ASSERT((width > 0 && height > 0) || (width == 0 && height == 0),
+ "\tvoid scrollable_region::set_total_rect_size(width,height)"
+ << "\n\twidth and height must be > 0 or both == 0"
+ << "\n\twidth: " << width
+ << "\n\theight: " << height
+ << "\n\tthis: " << this
+ );
+
+ total_rect_ = move_rect(rectangle(width,height),
+ display_rect_.left()-static_cast<long>(hsb.slider_pos()),
+ display_rect_.top()-static_cast<long>(vsb.slider_pos()));
+
+ // call this just to reconfigure the scroll bars
+ set_size(rect.width(),rect.height());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rectangle& scrollable_region::
+ total_rect (
+ ) const
+ {
+ return total_rect_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ scroll_to_rect (
+ const rectangle& r_
+ )
+ {
+ const rectangle r(total_rect_.intersect(r_));
+ const rectangle old(total_rect_);
+ // adjust the horizontal scroll bar so that r fits as best as possible
+ if (r.left() < display_rect_.left())
+ {
+ long distance = (r.left()-total_rect_.left())/hscroll_bar_inc;
+ hsb.set_slider_pos(distance);
+ }
+ else if (r.right() > display_rect_.right())
+ {
+ long distance = (r.right()-total_rect_.left()-display_rect_.width()+hscroll_bar_inc)/hscroll_bar_inc;
+ hsb.set_slider_pos(distance);
+ }
+
+ // adjust the vertical scroll bar so that r fits as best as possible
+ if (r.top() < display_rect_.top())
+ {
+ long distance = (r.top()-total_rect_.top())/vscroll_bar_inc;
+ vsb.set_slider_pos(distance);
+ }
+ else if (r.bottom() > display_rect_.bottom())
+ {
+ long distance = (r.bottom()-total_rect_.top()-display_rect_.height()+vscroll_bar_inc)/vscroll_bar_inc;
+ vsb.set_slider_pos(distance);
+ }
+
+
+ // adjust total_rect_ so that it matches where the scroll bars are now
+ total_rect_ = move_rect(total_rect_,
+ display_rect_.left()-hscroll_bar_inc*hsb.slider_pos(),
+ display_rect_.top()-vscroll_bar_inc*vsb.slider_pos());
+
+ // only redraw if we actually changed something
+ if (total_rect_ != old)
+ {
+ parent.invalidate_rectangle(display_rect_);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_wheel_down (
+ unsigned long
+ )
+ {
+ if (rect.contains(lastx,lasty) && enabled && !hidden)
+ {
+ if (need_v_scroll())
+ {
+ long pos = vsb.slider_pos();
+ vsb.set_slider_pos(pos+(long)v_wheel_scroll_bar_inc);
+ on_v_scroll();
+ }
+ else if (need_h_scroll())
+ {
+ long pos = hsb.slider_pos();
+ hsb.set_slider_pos(pos+(long)h_wheel_scroll_bar_inc);
+ on_h_scroll();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (enabled && !hidden && user_is_dragging_mouse && state==base_window::LEFT)
+ {
+ point current_delta = point(x,y) - point(total_rect().left(), total_rect().top());
+ rectangle new_rect(translate_rect(display_rect(), drag_origin - current_delta));
+ new_rect = centered_rect(new_rect, new_rect.width()-hscroll_bar_inc, new_rect.height()-vscroll_bar_inc);
+ scroll_to_rect(new_rect);
+ on_view_changed();
+ }
+ else
+ {
+ user_is_dragging_mouse = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (mouse_drag_enabled_ && enabled && !hidden && display_rect().contains(x,y) && (btn==base_window::LEFT))
+ {
+ drag_origin = point(x,y) - point(total_rect().left(), total_rect().top());
+ user_is_dragging_mouse = true;
+ }
+ else
+ {
+ user_is_dragging_mouse = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_mouse_up (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long
+ )
+ {
+ user_is_dragging_mouse = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_wheel_up (
+ unsigned long
+ )
+ {
+ if (rect.contains(lastx,lasty) && enabled && !hidden)
+ {
+ if (need_v_scroll())
+ {
+ long pos = vsb.slider_pos();
+ vsb.set_slider_pos(pos-(long)v_wheel_scroll_bar_inc);
+ on_v_scroll();
+ }
+ else if (need_h_scroll())
+ {
+ long pos = hsb.slider_pos();
+ hsb.set_slider_pos(pos-(long)h_wheel_scroll_bar_inc);
+ on_h_scroll();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ draw (
+ const canvas& c
+ ) const
+ {
+ style->draw_scrollable_region_border(c, rect, enabled);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool scrollable_region::
+ need_h_scroll (
+ ) const
+ {
+ if (total_rect_.width() > rect.width()-style->get_border_size()*2)
+ {
+ return true;
+ }
+ else
+ {
+ // check if we would need a vertical scroll bar and if adding one would make us need
+ // a horizontal one
+ if (total_rect_.height() > rect.height()-style->get_border_size()*2 &&
+ total_rect_.width() > rect.width()-style->get_border_size()*2-vsb.width())
+ return true;
+ else
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool scrollable_region::
+ need_v_scroll (
+ ) const
+ {
+ if (total_rect_.height() > rect.height()-style->get_border_size()*2)
+ {
+ return true;
+ }
+ else
+ {
+ // check if we would need a horizontal scroll bar and if adding one would make us need
+ // a vertical_scroll_pos one
+ if (total_rect_.width() > rect.width()-style->get_border_size()*2 &&
+ total_rect_.height() > rect.height()-style->get_border_size()*2-hsb.height())
+ return true;
+ else
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_h_scroll (
+ )
+ {
+ total_rect_ = move_rect(total_rect_, display_rect_.left()-hscroll_bar_inc*hsb.slider_pos(), total_rect_.top());
+ parent.invalidate_rectangle(display_rect_);
+ if (events_are_enabled())
+ on_view_changed();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scrollable_region::
+ on_v_scroll (
+ )
+ {
+ total_rect_ = move_rect(total_rect_, total_rect_.left(), display_rect_.top()-vscroll_bar_inc*vsb.slider_pos());
+ parent.invalidate_rectangle(display_rect_);
+ if (events_are_enabled())
+ on_view_changed();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// class popup_menu_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ popup_menu_region::
+ popup_menu_region(
+ drawable_window& w
+ ) :
+ drawable(w,MOUSE_CLICK | KEYBOARD_EVENTS | FOCUS_EVENTS | WINDOW_MOVED),
+ popup_menu_shown(false)
+ {
+
+ menu_.set_on_hide_handler(*this,&popup_menu_region::on_menu_becomes_hidden);
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ popup_menu_region::
+ ~popup_menu_region(
+ )
+ {
+ disable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rect = resize_rect(rect,width,height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ set_rect (
+ const rectangle& new_rect
+ )
+ {
+ auto_mutex M(m);
+ rect = new_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ popup_menu& popup_menu_region::
+ menu (
+ )
+ {
+ return menu_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ drawable::hide();
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::disable();
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ if (enabled && !hidden && popup_menu_shown)
+ {
+ menu_.forwarded_on_keydown(key, is_printable, state);
+ }
+ else if (popup_menu_shown)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+
+ if (key == (unsigned long)base_window::KEY_ESC)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_menu_becomes_hidden (
+ )
+ {
+ popup_menu_shown = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_focus_lost (
+ )
+ {
+ if (popup_menu_shown)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_focus_gained (
+ )
+ {
+ if (popup_menu_shown)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_window_moved(
+ )
+ {
+ if (popup_menu_shown)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (enabled && !hidden && rect.contains(x,y) && btn == base_window::RIGHT)
+ {
+ long orig_x, orig_y;
+ parent.get_pos(orig_x, orig_y);
+ menu_.set_pos(orig_x+x, orig_y+y);
+ menu_.show();
+ popup_menu_shown = true;
+ }
+ else if (popup_menu_shown)
+ {
+ menu_.hide();
+ popup_menu_shown = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void popup_menu_region::
+ draw (
+ const canvas&
+ ) const
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BASE_WIDGETs_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/base_widgets.h b/ml/dlib/dlib/gui_widgets/base_widgets.h
new file mode 100644
index 000000000..7c6d25097
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/base_widgets.h
@@ -0,0 +1,2678 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_BASE_WIDGETs_
+#define DLIB_BASE_WIDGETs_
+
+#include <cctype>
+#include <memory>
+
+#include "base_widgets_abstract.h"
+#include "drawable.h"
+#include "../gui_core.h"
+#include "../algs.h"
+#include "../member_function_pointer.h"
+#include "../timer.h"
+#include "../map.h"
+#include "../set.h"
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../image_transforms/assign_image.h"
+#include "../array.h"
+#include "style.h"
+#include "../unicode.h"
+#include "../any.h"
+
+
+namespace dlib
+{
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class draggable
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class draggable : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - drag == false
+
+ CONVENTION
+ - if (the user is holding the left button down over this object) then
+ - drag == true
+ - x == the x position of the mouse relative to the upper left corner
+ of this object.
+ - y == the y position of the mouse relative to the upper left corner
+ of this object.
+ - else
+ - drag == false
+ !*/
+
+ public:
+
+ draggable(
+ drawable_window& w,
+ unsigned long events = 0
+ ) :
+ drawable(w,events | MOUSE_MOVE | MOUSE_CLICK),
+ drag(false)
+ {}
+
+ virtual ~draggable(
+ ) = 0;
+
+ rectangle draggable_area (
+ ) const { auto_mutex M(m); return area; }
+
+ void set_draggable_area (
+ const rectangle& area_
+ ) { auto_mutex M(m); area = area_; }
+
+ protected:
+
+ bool is_being_dragged (
+ ) const { return drag; }
+
+ virtual void on_drag (
+ ){}
+
+ virtual void on_drag_stop (
+ ){}
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ private:
+
+ rectangle area;
+ bool drag;
+ long x, y;
+
+ // restricted functions
+ draggable(draggable&); // copy constructor
+ draggable& operator=(draggable&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class mouse_over_event
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class mouse_over_event : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - is_mouse_over_ == false
+
+ CONVENTION
+ - is_mouse_over_ == is_mouse_over()
+ !*/
+
+ public:
+
+ mouse_over_event(
+ drawable_window& w,
+ unsigned long events = 0
+ ) :
+ drawable(w,events | MOUSE_MOVE),
+ is_mouse_over_(false)
+ {}
+
+
+ virtual ~mouse_over_event(
+ ) = 0;
+
+ int next_free_user_event_number() const
+ {
+ return drawable::next_free_user_event_number()+1;
+ }
+
+ protected:
+
+ bool is_mouse_over (
+ ) const;
+
+ virtual void on_mouse_over (
+ ){}
+
+ virtual void on_mouse_not_over (
+ ){}
+
+ void on_mouse_leave (
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_user_event (
+ int num
+ );
+
+ private:
+ mutable bool is_mouse_over_;
+
+ // restricted functions
+ mouse_over_event(mouse_over_event&); // copy constructor
+ mouse_over_event& operator=(mouse_over_event&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class button_action
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button_action : public mouse_over_event
+ {
+ /*!
+ INITIAL VALUE
+ - is_depressed_ == false
+ - seen_click == false
+
+ CONVENTION
+ - is_depressed_ == is_depressed()
+ - if (the user has clicked the button but hasn't yet released the
+ left mouse button) then
+ - seen_click == true
+ - else
+ - seen_click == false
+ !*/
+
+ public:
+
+ button_action(
+ drawable_window& w,
+ unsigned long events = 0
+ ) :
+ mouse_over_event(w,events | MOUSE_MOVE | MOUSE_CLICK),
+ is_depressed_(false),
+ seen_click(false)
+ {}
+
+
+ virtual ~button_action(
+ ) = 0;
+
+ int next_free_user_event_number() const
+ {
+ return mouse_over_event::next_free_user_event_number()+1;
+ }
+
+ protected:
+
+ bool is_depressed (
+ ) const;
+
+ virtual void on_button_down (
+ ){}
+
+ virtual void on_button_up (
+ bool
+ ){}
+
+ void on_mouse_not_over (
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long x,
+ long y
+ );
+
+
+ private:
+ mutable bool is_depressed_;
+ bool seen_click;
+
+ void on_user_event (
+ int num
+ );
+
+ // restricted functions
+ button_action(button_action&); // copy constructor
+ button_action& operator=(button_action&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class widget_group
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class widget_group : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ widgets.size() == 0
+
+ CONVENTION
+ - widgets contains all the drawable objects and their relative positions
+ that are in *this.
+ - wg_widgets contains pointers to just the widgets that happen
+ to be widget_group objects.
+ !*/
+
+ struct relpos
+ {
+ unsigned long x;
+ unsigned long y;
+ };
+
+ public:
+ widget_group(
+ drawable_window& w
+ ) : drawable(w) { rect = rectangle(0,0,-1,-1); enable_events();}
+
+ virtual ~widget_group(
+ ){ disable_events(); }
+
+ void empty (
+ );
+
+ void add (
+ drawable& widget,
+ unsigned long x,
+ unsigned long y
+ );
+
+ void add (
+ widget_group& widget,
+ unsigned long x,
+ unsigned long y
+ );
+
+ bool is_member (
+ const drawable& widget
+ ) const;
+
+ void remove (
+ const drawable& widget
+ );
+
+ size_t size (
+ ) const;
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ void set_z_order (
+ long order
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void fit_to_contents (
+ );
+
+ protected:
+
+ // this object doesn't draw anything but also isn't abstract
+ void draw (
+ const canvas&
+ ) const {}
+
+ private:
+
+ map<drawable*,relpos>::kernel_1a_c widgets;
+ set<widget_group*>::kernel_1a_c wg_widgets;
+
+
+ // restricted functions
+ widget_group(widget_group&); // copy constructor
+ widget_group& operator=(widget_group&); // assignment operator
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ class image_widget : public draggable
+ {
+ /*!
+ INITIAL VALUE
+ - img.size() == 0
+
+ CONVENTION
+ - img == the image this object displays
+ !*/
+
+ public:
+
+ image_widget(
+ drawable_window& w
+ ): draggable(w) { enable_events(); }
+
+ ~image_widget(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+ template <
+ typename image_type
+ >
+ void set_image (
+ const image_type& new_img
+ )
+ {
+ auto_mutex M(m);
+ assign_image_scaled(img,new_img);
+ rectangle old(rect);
+ rect.set_right(rect.left()+num_columns(img)-1);
+ rect.set_bottom(rect.top()+num_rows(img)-1);
+ parent.invalidate_rectangle(rect+old);
+ }
+
+ private:
+
+ void draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ draw_image(c, point(rect.left(),rect.top()), img);
+ }
+
+ array2d<rgb_alpha_pixel> img;
+
+ // restricted functions
+ image_widget(image_widget&); // copy constructor
+ image_widget& operator=(image_widget&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class tooltip
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class tooltip : public mouse_over_event
+ {
+ /*!
+ INITIAL VALUE
+ - stuff.get() == 0
+ - events_are_enabled() == false
+
+ CONVENTION
+ - if (events_are_enabled() == true) then
+ - stuff.get() != 0
+ !*/
+
+ public:
+
+ tooltip(
+ drawable_window& w
+ ) :
+ mouse_over_event(w,MOUSE_CLICK)
+ {}
+
+ ~tooltip(
+ ){ disable_events();}
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rect = resize_rect(rect,width,height);
+ }
+
+
+ void set_text (
+ const std::string& str
+ )
+ {
+ set_text(convert_mbstring_to_wstring(str));
+ }
+
+ void set_text (
+ const std::wstring& str
+ )
+ {
+ set_text(convert_wstring_to_utf32(str));
+ }
+
+ void set_text (
+ const ustring& str
+ )
+ {
+ auto_mutex M(m);
+ if (!stuff)
+ {
+ stuff.reset(new data(*this));
+ enable_events();
+ }
+
+ stuff->win.set_text(str);
+ }
+
+ const std::string text (
+ ) const
+ {
+ return convert_wstring_to_mbstring(wtext());
+ }
+
+ const std::wstring wtext (
+ ) const
+ {
+ return convert_utf32_to_wstring(utext());
+ }
+
+ const dlib::ustring utext (
+ ) const
+ {
+ auto_mutex M(m);
+ dlib::ustring temp;
+ if (stuff)
+ {
+ temp = stuff->win.text;
+ }
+ return temp.c_str();
+ }
+
+ void hide (
+ )
+ {
+ auto_mutex M(m);
+ mouse_over_event::hide();
+ if (stuff)
+ {
+ stuff->tt_timer.stop();
+ stuff->win.hide();
+ }
+ }
+
+ void disable (
+ )
+ {
+ auto_mutex M(m);
+ mouse_over_event::disable();
+ if (stuff)
+ {
+ stuff->tt_timer.stop();
+ stuff->win.hide();
+ }
+ }
+
+ protected:
+
+ void on_mouse_over()
+ {
+ stuff->x = lastx;
+ stuff->y = lasty;
+ stuff->tt_timer.start();
+ }
+
+ void on_mouse_not_over ()
+ {
+ stuff->tt_timer.stop();
+ stuff->win.hide();
+ }
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ mouse_over_event::on_mouse_down(btn,state,x,y,is_double_click);
+ stuff->tt_timer.stop();
+ stuff->win.hide();
+ }
+
+ void draw (
+ const canvas&
+ ) const{}
+
+ private:
+
+ class tooltip_window : public base_window
+ {
+ public:
+ tooltip_window (const std::shared_ptr<font>& f) : base_window(false,true), pad(3), mfont(f)
+ {
+ }
+
+ ustring text;
+ rectangle rect_all;
+ rectangle rect_text;
+ const unsigned long pad;
+ const std::shared_ptr<font> mfont;
+
+ void set_text (
+ const std::string& str
+ )
+ {
+ set_text(convert_mbstring_to_wstring(str));
+ }
+
+ void set_text (
+ const std::wstring& str
+ )
+ {
+ set_text(convert_wstring_to_utf32(str));
+ }
+
+ void set_text (
+ const dlib::ustring& str
+ )
+ {
+ text = str.c_str();
+
+ unsigned long width, height;
+ mfont->compute_size(text,width,height);
+
+ set_size(width+pad*2, height+pad*2);
+ rect_all.set_left(0);
+ rect_all.set_top(0);
+ rect_all.set_right(width+pad*2-1);
+ rect_all.set_bottom(height+pad*2-1);
+
+ rect_text = move_rect(rectangle(width,height),pad,pad);
+ }
+
+ void paint(const canvas& c)
+ {
+ c.fill(255,255,150);
+ draw_rectangle(c, rect_all);
+ mfont->draw_string(c,rect_text,text);
+ }
+ };
+
+ void show_tooltip (
+ )
+ {
+ auto_mutex M(m);
+ long x, y;
+ // if the mouse has moved since we started the timer then
+ // keep waiting until the user stops moving it
+ if (lastx != stuff->x || lasty != stuff->y)
+ {
+ stuff->x = lastx;
+ stuff->y = lasty;
+ return;
+ }
+
+ unsigned long display_width, display_height;
+ // stop the timer
+ stuff->tt_timer.stop();
+ parent.get_pos(x,y);
+ x += lastx+15;
+ y += lasty+15;
+
+ // make sure the tooltip isn't going to be off the screen
+ parent.get_display_size(display_width, display_height);
+ rectangle wrect(move_rect(stuff->win.rect_all,x,y));
+ rectangle srect(display_width, display_height);
+ if (srect.contains(wrect) == false)
+ {
+ rectangle temp(srect.intersect(wrect));
+ x -= wrect.width()-temp.width();
+ y -= wrect.height()-temp.height();
+ }
+
+ stuff->win.set_pos(x,y);
+ stuff->win.show();
+ }
+
+ // put all this stuff in data so we can arrange to only
+ // construct it when someone is actually using the tooltip widget
+ // rather than just instantiating it.
+ struct data
+ {
+ data(
+ tooltip& self
+ ) :
+ x(-1),
+ y(-1),
+ win(self.mfont),
+ tt_timer(self,&tooltip::show_tooltip)
+ {
+ tt_timer.set_delay_time(400);
+ }
+
+ long x, y;
+ tooltip_window win;
+ timer<tooltip> tt_timer;
+
+ };
+ friend struct data;
+ std::unique_ptr<data> stuff;
+
+
+
+ // restricted functions
+ tooltip(tooltip&); // copy constructor
+ tooltip& operator=(tooltip&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button : public button_action
+ {
+ public:
+ button(
+ drawable_window& w
+ ) :
+ button_action(w),
+ btn_tooltip(w)
+ {
+ style.reset(new button_style_default());
+ enable_events();
+ }
+
+ ~button() { disable_events(); parent.invalidate_rectangle(style->get_invalidation_rect(rect)); }
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void set_name (
+ const std::string& name_
+ );
+
+ void set_name (
+ const std::wstring& name_
+ );
+
+ void set_name (
+ const dlib::ustring& name_
+ );
+
+ const std::string name (
+ ) const;
+
+ const std::wstring wname (
+ ) const;
+
+ const dlib::ustring uname (
+ ) const;
+
+ void set_tooltip_text (
+ const std::string& text
+ );
+
+ void set_tooltip_text (
+ const std::wstring& text
+ );
+
+ void set_tooltip_text (
+ const dlib::ustring& text
+ );
+
+ void set_pos(
+ long x,
+ long y
+ );
+
+ const std::string tooltip_text (
+ ) const;
+
+ const std::wstring tooltip_wtext (
+ ) const;
+
+ const dlib::ustring tooltip_utext (
+ ) const;
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ rect = move_rect(style->get_min_size(name_,*mfont), rect.left(), rect.top());
+ parent.invalidate_rectangle(style->get_invalidation_rect(rect));
+ }
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler_)()
+ )
+ {
+ auto_mutex M(m);
+ event_handler = make_mfp(object,event_handler_);
+ event_handler_self.clear();
+ }
+
+ void set_click_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ event_handler = event_handler_;
+ event_handler_self.clear();
+ }
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler_)(button&)
+ )
+ {
+ auto_mutex M(m);
+ event_handler_self = make_mfp(object,event_handler_);
+ event_handler.clear();
+ }
+
+ void set_sourced_click_handler (
+ const any_function<void(button&)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ event_handler_self = event_handler_;
+ event_handler.clear();
+ }
+
+ bool is_depressed (
+ ) const
+ {
+ auto_mutex M(m);
+ return button_action::is_depressed();
+ }
+
+ template <
+ typename T
+ >
+ void set_button_down_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ button_down_handler = make_mfp(object,event_handler);
+ }
+
+ void set_button_down_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ button_down_handler = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_button_up_handler (
+ T& object,
+ void (T::*event_handler)(bool mouse_over)
+ )
+ {
+ auto_mutex M(m);
+ button_up_handler = make_mfp(object,event_handler);
+ }
+
+ void set_button_up_handler (
+ const any_function<void(bool)>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ button_up_handler = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_button_down_handler (
+ T& object,
+ void (T::*event_handler)(button&)
+ )
+ {
+ auto_mutex M(m);
+ button_down_handler_self = make_mfp(object,event_handler);
+ }
+
+ void set_sourced_button_down_handler (
+ const any_function<void(button&)>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ button_down_handler_self = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_button_up_handler (
+ T& object,
+ void (T::*event_handler)(bool mouse_over, button&)
+ )
+ {
+ auto_mutex M(m);
+ button_up_handler_self = make_mfp(object,event_handler);
+ }
+
+ void set_sourced_button_up_handler (
+ const any_function<void(bool,button&)>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ button_up_handler_self = event_handler;
+ }
+
+ private:
+
+ // restricted functions
+ button(button&); // copy constructor
+ button& operator=(button&); // assignment operator
+
+ dlib::ustring name_;
+ tooltip btn_tooltip;
+
+ any_function<void()> event_handler;
+ any_function<void(button&)> event_handler_self;
+ any_function<void()> button_down_handler;
+ any_function<void(bool)> button_up_handler;
+ any_function<void(button&)> button_down_handler_self;
+ any_function<void(bool,button&)> button_up_handler_self;
+
+ std::unique_ptr<button_style> style;
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const { style->draw_button(c,rect,enabled,*mfont,lastx,lasty,name_,is_depressed()); }
+
+ void on_button_up (
+ bool mouse_over
+ );
+
+ void on_button_down (
+ );
+
+ void on_mouse_over (
+ ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(style->get_invalidation_rect(rect)); }
+
+ void on_mouse_not_over (
+ ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(style->get_invalidation_rect(rect)); }
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class scroll_bar
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - ori == a value given by the constructor
+ - style == a scroll_bar_style_default object
+ - pos == 0
+ - max_pos == 0
+ - js == 10
+
+ CONVENTION
+ - ori == orientation()
+ - b1 == the button that is near the 0 end of the scroll bar
+ - b2 == the button that is near the max_pos() end of the scroll bar
+
+ - max_pos == max_slider_pos()
+ - pos == slider_pos()
+ - js == jump_size()
+ !*/
+
+ public:
+ enum bar_orientation
+ {
+ HORIZONTAL,
+ VERTICAL
+ };
+
+ scroll_bar(
+ drawable_window& w,
+ bar_orientation orientation_
+ );
+
+ virtual ~scroll_bar(
+ );
+
+ bar_orientation orientation (
+ ) const;
+
+ void set_length (
+ unsigned long length
+ );
+
+ long max_slider_pos (
+ ) const;
+
+ void set_max_slider_pos (
+ long mpos
+ );
+
+ void set_slider_pos (
+ long pos
+ );
+
+ long slider_pos (
+ ) const;
+
+ template <
+ typename T
+ >
+ void set_scroll_handler (
+ T& object,
+ void (T::*eh)()
+ ) { auto_mutex M(m); scroll_handler = make_mfp(object,eh); }
+
+ void set_scroll_handler (
+ const any_function<void()>& eh
+ ) { auto_mutex M(m); scroll_handler = eh; }
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ void enable (
+ )
+ {
+ auto_mutex M(m);
+ if (!hidden)
+ show_slider();
+ if (max_pos != 0)
+ {
+ b1.enable();
+ b2.enable();
+ }
+ drawable::enable();
+ }
+
+ void disable (
+ )
+ {
+ auto_mutex M(m);
+ hide_slider();
+ b1.disable();
+ b2.disable();
+ drawable::disable();
+ }
+
+ void hide (
+ )
+ {
+ auto_mutex M(m);
+ hide_slider();
+ top_filler.hide();
+ bottom_filler.hide();
+ b1.hide();
+ b2.hide();
+ drawable::hide();
+ }
+
+ void show (
+ )
+ {
+ auto_mutex M(m);
+ b1.show();
+ b2.show();
+ drawable::show();
+ top_filler.show();
+ if (enabled)
+ show_slider();
+ }
+
+ void set_z_order (
+ long order
+ )
+ {
+ auto_mutex M(m);
+ slider.set_z_order(order);
+ top_filler.set_z_order(order);
+ bottom_filler.set_z_order(order);
+ b1.set_z_order(order);
+ b2.set_z_order(order);
+ drawable::set_z_order(order);
+ }
+
+ void set_jump_size (
+ long js
+ );
+
+ long jump_size (
+ ) const;
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+
+ if (ori == HORIZONTAL)
+ {
+ b1.set_style(style_.get_left_button_style());
+ b2.set_style(style_.get_right_button_style());
+ set_length(rect.width());
+ }
+ else
+ {
+ b1.set_style(style_.get_up_button_style());
+ b2.set_style(style_.get_down_button_style());
+ set_length(rect.height());
+ }
+
+ }
+
+ private:
+
+ void hide_slider (
+ );
+ /*!
+ ensures
+ - hides the slider and makes any other changes needed so that the
+ scroll_bar still looks right.
+ !*/
+
+ void show_slider (
+ );
+ /*!
+ ensures
+ - shows the slider and makes any other changes needed so that the
+ scroll_bar still looks right.
+ !*/
+
+
+ void on_slider_drag (
+ );
+ /*!
+ requires
+ - is called whenever the user drags the slider
+ !*/
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ void b1_down (
+ );
+
+ void b1_up (
+ bool mouse_over
+ );
+
+ void b2_down (
+ );
+
+ void b2_up (
+ bool mouse_over
+ );
+
+ void top_filler_down (
+ );
+
+ void top_filler_up (
+ bool mouse_over
+ );
+
+ void bottom_filler_down (
+ );
+
+ void bottom_filler_up (
+ bool mouse_over
+ );
+
+ void on_user_event (
+ int i
+ );
+
+ void delayed_set_slider_pos (
+ unsigned long dpos
+ );
+
+ void b1_down_t (
+ );
+
+ void b2_down_t (
+ );
+
+ void top_filler_down_t (
+ );
+
+ void bottom_filler_down_t (
+ );
+
+ friend class filler;
+ class filler : public button_action
+ {
+ friend class scroll_bar;
+ public:
+ filler (
+ drawable_window& w,
+ scroll_bar& object,
+ void (scroll_bar::*down)(),
+ void (scroll_bar::*up)(bool)
+ ):
+ button_action(w),
+ my_scroll_bar(object)
+ {
+ bup = make_mfp(object,up);
+ bdown = make_mfp(object,down);
+
+ enable_events();
+ }
+
+ ~filler (
+ )
+ {
+ disable_events();
+ }
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ rectangle old(rect);
+ const unsigned long x = rect.left();
+ const unsigned long y = rect.top();
+ rect.set_right(x+width-1);
+ rect.set_bottom(y+height-1);
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+ private:
+
+ void draw (
+ const canvas& c
+ ) const
+ {
+ my_scroll_bar.style->draw_scroll_bar_background(c,rect,enabled,lastx,lasty,is_depressed());
+ }
+
+ void on_button_down (
+ ) { bdown(); }
+
+ void on_button_up (
+ bool mouse_over
+ ) { bup(mouse_over); }
+
+ scroll_bar& my_scroll_bar;
+ any_function<void()> bdown;
+ any_function<void(bool)> bup;
+ };
+
+ friend class slider_class;
+ class slider_class : public draggable
+ {
+ friend class scroll_bar;
+ public:
+ slider_class (
+ drawable_window& w,
+ scroll_bar& object,
+ void (scroll_bar::*handler)()
+ ) :
+ draggable(w, MOUSE_MOVE),
+ mouse_in_widget(false),
+ my_scroll_bar(object)
+ {
+ callback = make_mfp(object,handler);
+ enable_events();
+ }
+
+ ~slider_class (
+ )
+ {
+ disable_events();
+ }
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ rectangle old(rect);
+ const unsigned long x = rect.left();
+ const unsigned long y = rect.top();
+ rect.set_right(x+width-1);
+ rect.set_bottom(y+height-1);
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+ private:
+ virtual void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ draggable::on_mouse_move(state,x,y);
+ if (!hidden && my_scroll_bar.style->redraw_on_mouse_over_slider())
+ {
+ if (rect.contains(x,y) && !mouse_in_widget)
+ {
+ mouse_in_widget = true;
+ parent.invalidate_rectangle(rect);
+ }
+ else if (rect.contains(x,y) == false && mouse_in_widget)
+ {
+ mouse_in_widget = false;
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ }
+
+ void on_mouse_leave (
+ )
+ {
+ if (mouse_in_widget && my_scroll_bar.style->redraw_on_mouse_over_slider())
+ {
+ mouse_in_widget = false;
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+ void on_drag_stop (
+ )
+ {
+ if (my_scroll_bar.style->redraw_on_mouse_over_slider())
+ parent.invalidate_rectangle(rect);
+ }
+
+ void on_drag (
+ )
+ {
+ callback();
+ }
+
+ void draw (
+ const canvas& c
+ ) const
+ {
+ my_scroll_bar.style->draw_scroll_bar_slider(c,rect,enabled,lastx,lasty, is_being_dragged());
+ }
+
+ bool mouse_in_widget;
+ scroll_bar& my_scroll_bar;
+ any_function<void()> callback;
+ };
+
+
+ void adjust_fillers (
+ );
+ /*!
+ ensures
+ - top_filler and bottom_filler appear in their correct positions
+ relative to the current positions of the slider and the b1 and
+ b2 buttons
+ !*/
+
+ unsigned long get_slider_size (
+ ) const;
+ /*!
+ ensures
+ - returns the length in pixels the slider should have based on the current
+ state of this scroll bar
+ !*/
+
+
+ button b1, b2;
+ slider_class slider;
+ bar_orientation ori;
+ filler top_filler, bottom_filler;
+ any_function<void()> scroll_handler;
+
+ long pos;
+ long max_pos;
+ long js;
+
+ timer<scroll_bar> b1_timer;
+ timer<scroll_bar> b2_timer;
+ timer<scroll_bar> top_filler_timer;
+ timer<scroll_bar> bottom_filler_timer;
+ long delayed_pos;
+ std::unique_ptr<scroll_bar_style> style;
+
+ // restricted functions
+ scroll_bar(scroll_bar&); // copy constructor
+ scroll_bar& operator=(scroll_bar&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class popup_menu
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class menu_item
+ {
+ public:
+ virtual ~menu_item() {}
+
+ virtual rectangle get_left_size (
+ ) const { return rectangle(); }
+ virtual rectangle get_middle_size (
+ ) const = 0;
+ virtual rectangle get_right_size (
+ ) const { return rectangle(); }
+
+ virtual unichar get_hot_key (
+ ) const { return 0; }
+
+ virtual void draw_background (
+ const canvas& ,
+ const rectangle& ,
+ const bool ,
+ const bool
+ ) const {}
+
+ virtual void draw_left (
+ const canvas& ,
+ const rectangle& ,
+ const bool ,
+ const bool
+ ) const {}
+
+ virtual void draw_middle (
+ const canvas& ,
+ const rectangle& ,
+ const bool ,
+ const bool
+ ) const = 0;
+
+ virtual void draw_right (
+ const canvas& ,
+ const rectangle& ,
+ const bool ,
+ const bool
+ ) const {}
+
+ virtual void on_click (
+ ) const {}
+
+ virtual bool has_click_event (
+ ) const { return false; }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_submenu : public menu_item
+ {
+ void initialize (
+ unichar hk
+ )
+ {
+ const dlib::ustring &str = text;
+ if (hk != 0)
+ {
+ std::string::size_type pos = str.find_first_of(hk);
+ if (pos != std::string::npos)
+ {
+ // now compute the location of the underline bar
+ rectangle r1 = f->compute_cursor_rect( rectangle(100000,100000), str, pos);
+ rectangle r2 = f->compute_cursor_rect( rectangle(100000,100000), str, pos+1);
+
+ underline_p1.x() = r1.left()+1;
+ underline_p2.x() = r2.left()-1;
+ underline_p1.y() = r1.bottom()-f->height()+f->ascender()+2;
+ underline_p2.y() = r2.bottom()-f->height()+f->ascender()+2;
+ }
+ }
+ }
+ public:
+ menu_item_submenu (
+ const std::string& str,
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(hk);
+ }
+
+ menu_item_submenu (
+ const std::wstring& str,
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(str)),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(hk);
+ }
+
+ menu_item_submenu (
+ const dlib::ustring& str,
+ unichar hk = 0
+ ) :
+ text(str),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(hk);
+ }
+
+ virtual unichar get_hot_key (
+ ) const { return hotkey; }
+
+ virtual rectangle get_middle_size (
+ ) const
+ {
+ unsigned long width, height;
+ f->compute_size(text,width,height);
+ return rectangle(width+30,height);
+ }
+
+ virtual rectangle get_right_size (
+ ) const
+ {
+ return rectangle(15, 5);
+ }
+
+ virtual void draw_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ if (enabled && is_selected)
+ {
+ fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(0,200,0,100), rgb_alpha_pixel(0,0,0,100));
+ draw_rectangle(c, rect,rgb_alpha_pixel(0,0,0,100));
+ }
+ }
+
+ virtual void draw_right (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ unsigned char color = 0;
+
+ if (enabled == false)
+ color = 128;
+
+ long x, y;
+ x = rect.right() - 7;
+ y = rect.top() + rect.height()/2;
+
+ for ( unsigned long i = 0; i < 5; ++i)
+ draw_line (c, point(x - i, y + i), point(x - i, y - i), color);
+ }
+
+ virtual void draw_middle (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ if (enabled)
+ {
+ f->draw_string(c,rect,text);
+ }
+ else
+ {
+ f->draw_string(c,rect,text,128);
+ }
+
+ if (underline_p1 != underline_p2)
+ {
+ point base(rect.left(),rect.top());
+ draw_line(c, base+underline_p1, base+underline_p2);
+ }
+ }
+
+ private:
+ dlib::ustring text;
+ const std::shared_ptr<font> f;
+ any_function<void()> action;
+ unichar hotkey;
+ point underline_p1;
+ point underline_p2;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_text : public menu_item
+ {
+ void initialize (
+ const any_function<void()>& event_handler_,
+ unichar hk
+ )
+ {
+ dlib::ustring &str = text;
+ action = event_handler_;
+
+ if (hk != 0)
+ {
+ std::string::size_type pos = str.find_first_of(hk);
+ if (pos != std::string::npos)
+ {
+ // now compute the location of the underline bar
+ rectangle r1 = f->compute_cursor_rect( rectangle(100000,100000), str, pos);
+ rectangle r2 = f->compute_cursor_rect( rectangle(100000,100000), str, pos+1);
+
+ underline_p1.x() = r1.left()+1;
+ underline_p2.x() = r2.left()-1;
+ underline_p1.y() = r1.bottom()-f->height()+f->ascender()+2;
+ underline_p2.y() = r2.bottom()-f->height()+f->ascender()+2;
+ }
+ }
+ }
+
+ public:
+ template <typename T>
+ menu_item_text (
+ const std::string& str,
+ T& object,
+ void (T::*event_handler_)(),
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(make_mfp(object, event_handler_), hk);
+ }
+
+ menu_item_text (
+ const std::string& str,
+ const any_function<void()>& event_handler_,
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(event_handler_, hk);
+ }
+
+ template <typename T>
+ menu_item_text (
+ const std::wstring& str,
+ T& object,
+ void (T::*event_handler_)(),
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(str)),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(make_mfp(object, event_handler_), hk);
+ }
+
+ menu_item_text (
+ const std::wstring& str,
+ const any_function<void()>& event_handler_,
+ unichar hk = 0
+ ) :
+ text(convert_wstring_to_utf32(str)),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(event_handler_, hk);
+ }
+
+ template <typename T>
+ menu_item_text (
+ const dlib::ustring& str,
+ T& object,
+ void (T::*event_handler_)(),
+ unichar hk = 0
+ ) :
+ text(str),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(make_mfp(object, event_handler_), hk);
+ }
+
+ menu_item_text (
+ const dlib::ustring& str,
+ const any_function<void()>& event_handler_,
+ unichar hk = 0
+ ) :
+ text(str),
+ f(default_font::get_font()),
+ hotkey(hk)
+ {
+ initialize(event_handler_, hk);
+ }
+
+ virtual unichar get_hot_key (
+ ) const { return hotkey; }
+
+ virtual rectangle get_middle_size (
+ ) const
+ {
+ unsigned long width, height;
+ f->compute_size(text,width,height);
+ return rectangle(width,height);
+ }
+
+ virtual void draw_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ if (enabled && is_selected)
+ {
+ fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(0,200,0,100), rgb_alpha_pixel(0,0,0,100));
+ draw_rectangle(c, rect,rgb_alpha_pixel(0,0,0,100));
+ }
+ }
+
+ virtual void draw_middle (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ unsigned char color = 0;
+
+ if (enabled == false)
+ color = 128;
+
+ f->draw_string(c,rect,text,color);
+
+ if (underline_p1 != underline_p2)
+ {
+ point base(rect.left(),rect.top());
+ draw_line(c, base+underline_p1, base+underline_p2, color);
+ }
+ }
+
+ virtual void on_click (
+ ) const
+ {
+ action();
+ }
+
+ virtual bool has_click_event (
+ ) const { return true; }
+
+ private:
+ dlib::ustring text;
+ const std::shared_ptr<font> f;
+ any_function<void()> action;
+ unichar hotkey;
+ point underline_p1;
+ point underline_p2;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_separator : public menu_item
+ {
+ public:
+ virtual rectangle get_middle_size (
+ ) const
+ {
+ return rectangle(10,4);
+ }
+
+ virtual void draw_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool ,
+ const bool
+ ) const
+ {
+ if (c.intersect(rect).is_empty())
+ return;
+
+ point p1(rect.left(),rect.top()+rect.height()/2-1);
+ point p2(rect.right(),rect.top()+rect.height()/2-1);
+
+ point p3(rect.left(),rect.top()+rect.height()/2);
+ point p4(rect.right(),rect.top()+rect.height()/2);
+ draw_line(c, p1,p2,128);
+ draw_line(c, p3,p4,255);
+ }
+
+ virtual void draw_middle (
+ const canvas& ,
+ const rectangle& ,
+ const bool ,
+ const bool
+ ) const
+ {
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class popup_menu : public base_window
+ {
+ /*!
+ INITIAL VALUE
+ - pad == 2
+ - item_pad == 3
+ - cur_rect == rectangle(pad,pad,pad-1,pad-1)
+ - left_width == 0
+ - middle_width == 0
+ - selected_item == 0
+ - submenu_open == false
+ - items.size() == 0
+ - item_enabled.size() == 0
+ - left_rects.size() == 0
+ - middle_rects.size() == 0
+ - right_rects.size() == 0
+ - line_rects.size() == 0
+ - submenus.size() == 0
+ - hide_handlers.size() == 0
+
+ CONVENTION
+ - pad = 2
+ - item_pad = 3
+ - all of the following arrays have the same size:
+ - items.size()
+ - item_enabled.size()
+ - left_rects.size()
+ - middle_rects.size()
+ - right_rects.size()
+ - line_rects.size()
+ - submenus.size()
+
+ - win_rect == a rectangle that is the exact size of this window and with
+ its upper left corner at (0,0)
+ - cur_rect == the rect inside which all the menu items are drawn
+
+ - if (a menu_item is supposed to be selected) then
+ - selected_item == the index in menus of the menu_item
+ - else
+ - selected_item == submenus.size()
+
+ - if (there is a selected submenu and it is currently open) then
+ - submenu_open == true
+ - else
+ - submenu_open == false
+
+ - for all valid i:
+ - items[i] == a pointer to the ith menu_item
+ - item_enabled[i] == true if the ith menu_item is enabled, false otherwise
+ - left_rects[i] == the left rectangle for the ith menu item
+ - middle_rects[i] == the middle rectangle for the ith menu item
+ - right_rects[i] == the right rectangle for the ith menu item
+ - line_rects[i] == the rectangle for the entire line on which the ith menu
+ item appears.
+ - if (submenus[i] != 0) then
+ - the ith menu item has a submenu and it is pointed to by submenus[i]
+
+ - hide_handlers == an array of all the on_hide events registered for
+ this popup_menu
+ !*/
+
+ public:
+
+ popup_menu (
+ );
+
+ template <
+ typename menu_item_type
+ >
+ unsigned long add_menu_item (
+ const menu_item_type& new_item
+ )
+ {
+ auto_mutex M(wm);
+ bool t = true;
+ std::unique_ptr<menu_item> item(new menu_item_type(new_item));
+ items.push_back(item);
+ item_enabled.push_back(t);
+
+ // figure out how big the window should be now and what not
+ rectangle left = new_item.get_left_size();
+ rectangle middle = new_item.get_middle_size();
+ rectangle right = new_item.get_right_size();
+
+ bool recalc_rect_positions = false;
+ const rectangle all = left+middle+right;
+
+
+ // make sure left_width contains the max of all the left rectangles
+ if (left.width() > left_width)
+ {
+ left_width = left.width();
+ recalc_rect_positions = true;
+ }
+ // make sure middle_width contains the max of all the middle rectangles
+ if (middle.width() > middle_width)
+ {
+ middle_width = middle.width();
+ recalc_rect_positions = true;
+ }
+
+ // make the current rectangle wider if necessary
+ if (cur_rect.width() < left_width + middle_width + right.width() + 2*item_pad)
+ {
+ cur_rect = resize_rect_width(cur_rect, left_width + middle_width + right.width() + 2*item_pad);
+ recalc_rect_positions = true;
+ }
+
+ const long y = cur_rect.bottom()+1 + item_pad;
+ const long x = cur_rect.left() + item_pad;
+
+ // make the current rectangle taller to account for this new menu item
+ cur_rect.set_bottom(cur_rect.bottom()+all.height() + 2*item_pad);
+
+ // adjust all the saved rectangles since the width of the window changed
+ // or left_width changed
+ if (recalc_rect_positions)
+ {
+ long y = cur_rect.top() + item_pad;
+ for (unsigned long i = 0; i < left_rects.size(); ++i)
+ {
+ middle_rects[i] = move_rect(middle_rects[i], x+left_width, y);
+ right_rects[i] = move_rect(right_rects[i], x+cur_rect.width()-right_rects[i].width()-item_pad, y);
+ line_rects[i] = resize_rect_width(line_rects[i], cur_rect.width());
+
+ y += line_rects[i].height();
+ }
+ }
+
+ // save the rectangles for later use. Also position them at the
+ // right spots
+ left = move_rect(left,x,y);
+ middle = move_rect(middle,x+left_width,y);
+ right = move_rect(right,x+cur_rect.width()-right.width()-item_pad,y);
+ rectangle line(move_rect(rectangle(cur_rect.width(),all.height()+2*item_pad), x-item_pad, y-item_pad));
+
+ // make sure the left, middle, and right rectangles are centered in the
+ // line.
+ if (left.height() < all.height())
+ left = translate_rect(left,0, (all.height()-left.height())/2);
+ if (middle.height() < all.height())
+ middle = translate_rect(middle,0, (all.height()-middle.height())/2);
+ if (right.height() < all.height())
+ right = translate_rect(right,0, (all.height()-right.height())/2);
+
+ left_rects.push_back(left);
+ middle_rects.push_back(middle);
+ right_rects.push_back(right);
+ line_rects.push_back(line);
+
+ popup_menu* junk = 0;
+ submenus.push_back(junk);
+
+ win_rect.set_right(cur_rect.right()+pad);
+ win_rect.set_bottom(cur_rect.bottom()+pad);
+ set_size(win_rect.width(),win_rect.height());
+
+ // make it so that nothing is selected
+ selected_item = submenus.size();
+
+ return items.size()-1;
+ }
+
+ template <
+ typename menu_item_type
+ >
+ unsigned long add_submenu (
+ const menu_item_type& new_item,
+ popup_menu& submenu
+ )
+ {
+ auto_mutex M(wm);
+
+ submenus[add_menu_item(new_item)] = &submenu;
+
+ submenu.set_on_hide_handler(*this,&popup_menu::on_submenu_hide);
+
+ return items.size()-1;
+ }
+
+ void enable_menu_item (
+ unsigned long idx
+ );
+
+ void disable_menu_item (
+ unsigned long idx
+ );
+
+ size_t size (
+ ) const;
+
+ void clear (
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ template <typename T>
+ void set_on_hide_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(wm);
+
+ member_function_pointer<> temp;
+ temp.set(object,event_handler);
+
+ // if this handler isn't already registered then add it
+ bool found_handler = false;
+ for (unsigned long i = 0; i < hide_handlers.size(); ++i)
+ {
+ if (hide_handlers[i] == temp)
+ {
+ found_handler = true;
+ break;
+ }
+ }
+
+ if (found_handler == false)
+ {
+ hide_handlers.push_back(temp);
+ }
+ }
+
+ void select_first_item (
+ );
+
+ bool forwarded_on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ private:
+
+ void on_submenu_hide (
+ );
+
+ void on_window_resized(
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long x,
+ long y
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void close_submenu (
+ );
+
+ bool display_selected_submenu (
+ );
+ /*!
+ ensures
+ - if (submenus[selected_item] isn't null) then
+ - displays the selected submenu
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void on_mouse_leave (
+ );
+
+ void paint (
+ const canvas& c
+ );
+
+ const long pad;
+ const long item_pad;
+ rectangle cur_rect;
+ rectangle win_rect;
+ unsigned long left_width;
+ unsigned long middle_width;
+ array<std::unique_ptr<menu_item> > items;
+ array<bool> item_enabled;
+ array<rectangle> left_rects;
+ array<rectangle> middle_rects;
+ array<rectangle> right_rects;
+ array<rectangle> line_rects;
+ array<popup_menu*> submenus;
+ unsigned long selected_item;
+ bool submenu_open;
+ array<member_function_pointer<> > hide_handlers;
+
+ // restricted functions
+ popup_menu(popup_menu&); // copy constructor
+ popup_menu& operator=(popup_menu&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class zoomable_region : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - min_scale == 0.15
+ - max_scale == 1.0
+ - zoom_increment_ == 0.02
+ - scale == 1.0
+ - mouse_drag_screen == false
+
+
+ CONVENTION
+ - zoom_increment() == zoom_increment_
+ - min_zoom_scale() == min_scale
+ - max_zoom_scale() == max_scale
+ - zoom_scale() == scale
+ - if (the user is currently dragging the graph around via the mouse) then
+ - mouse_drag_screen == true
+ - else
+ - mouse_drag_screen == false
+
+ - max_graph_point() == lr_point
+ - display_rect() == display_rect_
+ - gui_to_graph_space(point(display_rect.left(),display_rect.top())) == gr_orig
+ !*/
+
+ public:
+
+ zoomable_region (
+ drawable_window& w,
+ unsigned long events = 0
+ );
+
+ virtual ~zoomable_region (
+ )= 0;
+
+ virtual void set_pos (
+ long x,
+ long y
+ );
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ hsb.set_style(style_.get_horizontal_scroll_bar_style());
+ vsb.set_style(style_.get_vertical_scroll_bar_style());
+
+ // do this just so that everything gets redrawn right
+ set_size(rect.width(), rect.height());
+ }
+
+ void set_zoom_increment (
+ double zi
+ );
+
+ double zoom_increment (
+ ) const;
+
+ void set_max_zoom_scale (
+ double ms
+ );
+
+ void set_min_zoom_scale (
+ double ms
+ );
+
+ double min_zoom_scale (
+ ) const;
+
+ double max_zoom_scale (
+ ) const;
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void set_z_order (
+ long order
+ );
+
+ protected:
+
+ virtual void on_view_changed () {}
+
+ point graph_to_gui_space (
+ const vector<double,2>& p
+ ) const;
+
+ vector<double,2> gui_to_graph_space (
+ const point& p
+ ) const;
+
+ point max_graph_point (
+ ) const;
+
+ rectangle display_rect (
+ ) const;
+
+ double zoom_scale (
+ ) const;
+
+ void set_zoom_scale (
+ double new_scale
+ );
+
+ void center_display_at_graph_point (
+ const vector<double,2>& p
+ );
+
+ // ----------- event handlers ---------------
+
+ void on_wheel_down (
+ unsigned long state
+ );
+
+ void on_wheel_up (
+ unsigned long state
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ private:
+
+ void on_h_scroll (
+ );
+
+ void on_v_scroll (
+ );
+
+ void redraw_graph (
+ );
+
+ void adjust_origin (
+ const point& gui_p,
+ const vector<double,2>& graph_p
+ );
+ /*!
+ ensures
+ - adjusts gr_orig so that we are as close to the following as possible:
+ - graph_to_gui_space(graph_p) == gui_p
+ - gui_to_graph_space(gui_p) == graph_p
+ !*/
+
+
+ vector<double,2> gr_orig; // point in graph space such that it's gui space point is the upper left of display_rect_
+ vector<double,2> lr_point; // point in graph space such that it is at the lower right corner of the screen at max zoom
+
+ mutable std::ostringstream sout;
+
+ double scale; // 0 < scale <= 1
+ double min_scale;
+ double max_scale;
+ double zoom_increment_;
+ rectangle display_rect_;
+
+ bool mouse_drag_screen; // true if the user is dragging the white background area
+ vector<double,2> drag_screen_point; // the starting point the mouse was at in graph space for the background area drag
+
+ scroll_bar vsb;
+ scroll_bar hsb;
+
+ std::unique_ptr<scrollable_region_style> style;
+
+ // restricted functions
+ zoomable_region(zoomable_region&); // copy constructor
+ zoomable_region& operator=(zoomable_region&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - hscroll_bar_inc == 1
+ - vscroll_bar_inc == 1
+ - h_wheel_scroll_bar_inc == 1
+ - v_wheel_scroll_bar_inc == 1
+ - mouse_drag_enabled_ == false
+ - user_is_dragging_mouse == false
+
+ CONVENTION
+ - mouse_drag_enabled() == mouse_drag_enabled_
+ - horizontal_scroll_increment() == hscroll_bar_inc
+ - vertical_scroll_increment() == vscroll_bar_inc
+ - horizontal_mouse_wheel_scroll_increment() == h_wheel_scroll_bar_inc
+ - vertical_mouse_wheel_scroll_increment() == v_wheel_scroll_bar_inc
+ - vertical_scroll_pos() == vsb.slider_pos()
+ - horizontal_scroll_pos() == hsb.slider_pos()
+ - total_rect() == total_rect_
+ - display_rect() == display_rect_
+
+ - if (the user is currently dragging the total_rect around with a mouse drag) then
+ - user_is_dragging_mouse == true
+ - drag_origin == the point the mouse was at, with respect to total_rect,
+ when the dragging started
+ - else
+ - user_is_dragging_mouse == false
+ !*/
+
+ public:
+
+ scrollable_region (
+ drawable_window& w,
+ unsigned long events = 0
+ );
+
+ virtual ~scrollable_region (
+ ) = 0;
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ hsb.set_style(style_.get_horizontal_scroll_bar_style());
+ vsb.set_style(style_.get_vertical_scroll_bar_style());
+
+ // do this just so that everything gets redrawn right
+ set_size(rect.width(), rect.height());
+ }
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void set_z_order (
+ long order
+ );
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ unsigned long horizontal_mouse_wheel_scroll_increment (
+ ) const;
+
+ unsigned long vertical_mouse_wheel_scroll_increment (
+ ) const;
+
+ void set_horizontal_mouse_wheel_scroll_increment (
+ unsigned long inc
+ );
+
+ void set_vertical_mouse_wheel_scroll_increment (
+ unsigned long inc
+ );
+
+ unsigned long horizontal_scroll_increment (
+ ) const;
+
+ unsigned long vertical_scroll_increment (
+ ) const;
+
+ void set_horizontal_scroll_increment (
+ unsigned long inc
+ );
+
+ void set_vertical_scroll_increment (
+ unsigned long inc
+ );
+
+ long horizontal_scroll_pos (
+ ) const;
+
+ long vertical_scroll_pos (
+ ) const;
+
+ void set_horizontal_scroll_pos (
+ long pos
+ );
+
+ void set_vertical_scroll_pos (
+ long pos
+ );
+
+ virtual void set_pos (
+ long x,
+ long y
+ );
+
+ bool mouse_drag_enabled (
+ ) const;
+
+ void enable_mouse_drag (
+ );
+
+ void disable_mouse_drag (
+ );
+
+ protected:
+
+ virtual void on_view_changed () {}
+
+ const rectangle& display_rect (
+ ) const;
+
+ void set_total_rect_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ const rectangle& total_rect (
+ ) const;
+
+ void scroll_to_rect (
+ const rectangle& r_
+ );
+
+ void on_wheel_down (
+ unsigned long state
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_wheel_up (
+ unsigned long state
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ private:
+
+ bool need_h_scroll (
+ ) const;
+
+ bool need_v_scroll (
+ ) const;
+
+ void on_h_scroll (
+ );
+
+ void on_v_scroll (
+ );
+
+ rectangle total_rect_;
+ rectangle display_rect_;
+ scroll_bar hsb;
+ scroll_bar vsb;
+ unsigned long hscroll_bar_inc;
+ unsigned long vscroll_bar_inc;
+ unsigned long h_wheel_scroll_bar_inc;
+ unsigned long v_wheel_scroll_bar_inc;
+ bool mouse_drag_enabled_;
+ bool user_is_dragging_mouse;
+ point drag_origin;
+ std::unique_ptr<scrollable_region_style> style;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class popup_menu_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class popup_menu_region : public drawable
+ {
+ /*!
+ CONVENTION
+ popup_menu_visible() == popup_menu_shown
+ !*/
+
+ public:
+
+ popup_menu_region(
+ drawable_window& w
+ );
+
+ virtual ~popup_menu_region(
+ );
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void set_rect (
+ const rectangle& new_rect
+ );
+
+ popup_menu& menu (
+ );
+
+ void hide (
+ );
+
+ void disable (
+ );
+
+ bool popup_menu_visible (
+ ) const { auto_mutex M(m); return popup_menu_shown; }
+
+ protected:
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_focus_lost (
+ );
+
+ void on_focus_gained (
+ );
+
+ void on_window_moved(
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_menu_becomes_hidden (
+ );
+
+ void draw (
+ const canvas&
+ ) const;
+
+ private:
+
+ popup_menu menu_;
+ bool popup_menu_shown;
+
+ // restricted functions
+ popup_menu_region(popup_menu_region&); // copy constructor
+ popup_menu_region& operator=(popup_menu_region&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "base_widgets.cpp"
+#endif
+
+#endif // DLIB_BASE_WIDGETs_
+
diff --git a/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h b/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h
new file mode 100644
index 000000000..3dcee0d5a
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h
@@ -0,0 +1,2290 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BASE_WIDGETs_ABSTRACT_
+#ifdef DLIB_BASE_WIDGETs_ABSTRACT_
+
+#include "fonts_abstract.h"
+#include "drawable_abstract.h"
+
+#include "../gui_core.h"
+#include <string>
+
+namespace dlib
+{
+
+ /*!
+ GENERAL REMARKS
+ This file contains objects that are useful for creating complex drawable
+ widgets.
+
+ THREAD SAFETY
+ All objects and functions defined in this file are thread safe. You may
+ call them from any thread without serializing access to them.
+
+ EVENT HANDLERS
+ If you derive from any of the drawable objects and redefine any of the on_*()
+ event handlers then you should ensure that your version calls the same event
+ handler in the base object so that the base class part of your object will also
+ be able to process the event.
+
+ Also note that all event handlers, including the user registered callback
+ functions, are executed in the event handling thread. Additionally,
+ the drawable::m mutex will always be locked while these event handlers
+ are running. Also, don't rely on get_thread_id() always returning the
+ same ID from inside event handlers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class draggable
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class draggable : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ draggable_area() == an initial value for its type
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a drawable object that is draggable by the mouse.
+ You use it by inheriting from it and defining the draw() method and any
+ of the on_*() event handlers you need.
+
+ This object is draggable by the user when is_enabled() == true and
+ not draggable otherwise.
+ !*/
+
+ public:
+
+ draggable(
+ drawable_window& w,
+ unsigned long events = 0
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ - This object will not receive any events or draw() requests until
+ enable_events() is called
+ - the events flags are passed on to the drawable object's
+ constructor.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~draggable(
+ ) = 0;
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ rectangle draggable_area (
+ ) const;
+ /*!
+ ensures
+ - returns the area that this draggable can be dragged around in.
+ !*/
+
+ void set_draggable_area (
+ const rectangle& area
+ );
+ /*!
+ ensures
+ - #draggable_area() == area
+ !*/
+
+ protected:
+
+ bool is_being_dragged (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - if (this widget is currently being dragged by the user) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ // does nothing by default
+ virtual void on_drag (
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - is_enabled() == true
+ - is_hidden() == false
+ - mutex drawable::m is locked
+ - is called when the user drags this object
+ - get_rect() == the rectangle that defines the new position
+ of this object.
+ - is_being_dragged() == true
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ // does nothing by default
+ virtual void on_drag_stop (
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - mutex drawable::m is locked
+ - is called when the user stops dragging this object
+ - is_being_dragged() == false
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ private:
+
+ // restricted functions
+ draggable(draggable&); // copy constructor
+ draggable& operator=(draggable&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class mouse_over_event
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class mouse_over_event : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ is_mouse_over() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a drawable object with the addition of two events
+ that will alert you when the mouse enters or leaves your drawable object.
+
+ You use it by inheriting from it and defining the draw() method and any
+ of the on_*() event handlers you need.
+ !*/
+
+ public:
+
+ mouse_over_event(
+ drawable_window& w,
+ unsigned long events = 0
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ - #*this will not receive any events or draw() requests until
+ enable_events() is called
+ - the events flags are passed on to the drawable object's
+ constructor.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~mouse_over_event(
+ ) = 0;
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ protected:
+
+ bool is_mouse_over (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - if (the mouse is currently over this widget) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ // does nothing by default
+ virtual void on_mouse_over (
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - mutex drawable::m is locked
+ - is_enabled() == true
+ - is_hidden() == false
+ - is called whenever this object transitions from the state where
+ is_mouse_over() == false to is_mouse_over() == true
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ // does nothing by default
+ virtual void on_mouse_not_over (
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - mutex drawable::m is locked
+ - is called whenever this object transitions from the state where
+ is_mouse_over() == true to is_mouse_over() == false
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ private:
+
+ // restricted functions
+ mouse_over_event(mouse_over_event&); // copy constructor
+ mouse_over_event& operator=(mouse_over_event&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class button_action
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button_action : public mouse_over_event
+ {
+ /*!
+ INITIAL VALUE
+ is_depressed() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the clicking action of a push button. It provides
+ simple callbacks that can be used to make various kinds of button
+ widgets.
+
+ You use it by inheriting from it and defining the draw() method and any
+ of the on_*() event handlers you need.
+ !*/
+
+ public:
+
+ button_action(
+ drawable_window& w,
+ unsigned long events = 0
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ - #*this will not receive any events or draw() requests until
+ enable_events() is called
+ - the events flags are passed on to the drawable object's
+ constructor.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~button_action(
+ ) = 0;
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ protected:
+
+ bool is_depressed (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - if (this button is currently in a depressed state) then
+ - the user has left clicked on this drawable and is still
+ holding the left mouse button down over it.
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ // does nothing by default
+ virtual void on_button_down (
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - mutex drawable::m is locked
+ - is_enabled() == true
+ - is_hidden() == false
+ - the area in parent_window() defined by get_rect() has been invalidated.
+ (This means you don't have to call invalidate_rectangle())
+ - is called whenever this object transitions from the state where
+ is_depressed() == false to is_depressed() == true
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ // does nothing by default
+ virtual void on_button_up (
+ bool mouse_over
+ ){}
+ /*!
+ requires
+ - enable_events() has been called
+ - mutex drawable::m is locked
+ - the area in parent_window() defined by get_rect() has been invalidated.
+ (This means you don't have to call invalidate_rectangle())
+ - is called whenever this object transitions from the state where
+ is_depressed() == true to is_depressed() == false
+ - if (the mouse was over this button when this event occurred) then
+ - mouse_over == true
+ - else
+ - mouse_over == false
+ ensures
+ - does not change the state of mutex drawable::m.
+ !*/
+
+ private:
+
+ // restricted functions
+ button_action(button_action&); // copy constructor
+ button_action& operator=(button_action&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button : public button_action
+ {
+ /*!
+ INITIAL VALUE
+ name() == ""
+ tooltip_text() == "" (i.e. there is no tooltip by default)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple button.
+
+ When this object is disabled it means it will not respond to user clicks.
+ !*/
+
+ public:
+
+ button(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~button(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - if (width and height are big enough to contain the name of this button) then
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this button stays the
+ same but its width and height are modified
+ !*/
+
+ void set_name (const std::wstring& name);
+ void set_name (const dlib::ustring& name);
+ void set_name (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - #name() == name
+ - this button has been resized such that it is big enough to contain
+ the new name.
+ throws
+ - std::bad_alloc
+ !*/
+
+ const std::wstring wname () const;
+ const dlib::string uname () const;
+ const std::string name (
+ ) const;
+ /*!
+ ensures
+ - returns the name of this button
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_tooltip_text (const std::wstring& text);
+ void set_tooltip_text (const dlib::ustring& text);
+ void set_tooltip_text (
+ const std::string& text
+ );
+ /*!
+ ensures
+ - #tooltip_text() == text
+ - enables the tooltip for this button
+ !*/
+
+ const dlib::ustring tooltip_utext () const;
+ const std::wstring tooltip_wtext () const;
+ const std::string tooltip_text (
+ ) const;
+ /*!
+ ensures
+ - returns the text that is displayed in the tooltip for this button
+ !*/
+
+ bool is_depressed (
+ ) const;
+ /*!
+ ensures
+ - if (this button is currently in a depressed state) then
+ - the user has left clicked on this widget and is still
+ holding the left mouse button down over it.
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from button_style
+ ensures
+ - this button object will draw itself using the given
+ button style
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the button is
+ clicked by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_click_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the button is clicked by
+ the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)(button& self)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - &self == this
+ - the event_handler function is called on object when the button is
+ clicked by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_sourced_click_handler (
+ const any_function<void(button& self)>& event_handler
+ );
+ /*!
+ ensures
+ - &self == this
+ - the event_handler function is called when the button is clicked by
+ the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_button_down_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user causes
+ the button to go into its depressed state.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_button_down_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user causes the button
+ to go into its depressed state.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_button_up_handler (
+ T& object,
+ void (T::*event_handler)(bool mouse_over)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user causes
+ the button to go into its non-depressed state.
+ - if (the mouse is over this button when this event occurs) then
+ - mouse_over == true
+ - else
+ - mouse_over == false
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_button_up_handler (
+ const any_function<void(bool mouse_over)>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user causes the
+ button to go into its non-depressed state.
+ - if (the mouse is over this button when this event occurs) then
+ - mouse_over == true
+ - else
+ - mouse_over == false
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_button_down_handler (
+ T& object,
+ void (T::*event_handler)(button& self)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - &self == this
+ - the event_handler function is called on object when the user causes
+ the button to go into its depressed state.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_sourced_button_down_handler (
+ const any_function<void(button& self)>& event_handler
+ );
+ /*!
+ ensures
+ - &self == this
+ - the event_handler function is called when the user causes the button
+ to go into its depressed state.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_button_up_handler (
+ T& object,
+ void (T::*event_handler)(bool mouse_over, button& self)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - &self == this
+ - the event_handler function is called on object when the user causes
+ the button to go into its non-depressed state.
+ - if (the mouse is over this button when this event occurs) then
+ - mouse_over == true
+ - else
+ - mouse_over == false
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_sourced_button_up_handler (
+ const any_function<void(bool mouse_over, button& self)>& event_handler
+ );
+ /*!
+ ensures
+ - &self == this
+ - the event_handler function is called when the user causes the
+ button to go into its non-depressed state.
+ - if (the mouse is over this button when this event occurs) then
+ - mouse_over == true
+ - else
+ - mouse_over == false
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ button(button&); // copy constructor
+ button& operator=(button&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class scroll_bar
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ orientation() == a value given to the constructor.
+ max_slider_pos() == 0
+ slider_pos() == 0
+ jump_size() == 10
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a scroll bar. The slider_pos() of the scroll bar
+ ranges from 0 to max_slider_pos(). The 0 position of the scroll_bar is
+ in the top or left side of the scroll_bar depending on its orientation.
+
+ When this object is disabled it means it will not respond to user clicks.
+ !*/
+
+ public:
+ enum bar_orientation
+ {
+ HORIZONTAL,
+ VERTICAL
+ };
+
+ scroll_bar(
+ drawable_window& w,
+ bar_orientation orientation
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #orientation() == orientation
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~scroll_bar(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ bar_orientation orientation (
+ ) const;
+ /*!
+ ensures
+ - returns the orientation of this scroll_bar
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from scroll_bar_style
+ ensures
+ - this scroll_bar object will draw itself using the given
+ scroll bar style
+ !*/
+
+ void set_length (
+ unsigned long length,
+ );
+ /*!
+ ensures
+ - if (orientation() == HORIZONTAL) then
+ - #width() == max(length,1)
+ - else
+ - #height() == max(length,1)
+ !*/
+
+ long max_slider_pos (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum value that slider_pos() can take.
+ !*/
+
+ void set_max_slider_pos (
+ long mpos
+ );
+ /*!
+ ensures
+ - if (mpos < 0) then
+ - #max_slider_pos() == 0
+ - else
+ - #max_slider_pos() == mpos
+ - if (slider_pos() > #max_slider_pos()) then
+ - #slider_pos() == #max_slider_pos()
+ - else
+ - #slider_pos() == slider_pos()
+ !*/
+
+ void set_slider_pos (
+ unsigned long pos
+ );
+ /*!
+ ensures
+ - if (pos < 0) then
+ - #slider_pos() == 0
+ - else if (pos > max_slider_pos()) then
+ - #slider_pos() == max_slider_pos()
+ - else
+ - #slider_pos() == pos
+ !*/
+
+ long slider_pos (
+ ) const;
+ /*!
+ ensures
+ - returns the current position of the slider box within the scroll bar.
+ !*/
+
+ long jump_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of positions that the slider bar will jump when the
+ user clicks on the empty gaps above or below the slider bar.
+ (note that the slider will jump less than the jump size if it hits the
+ end of the scroll bar)
+ !*/
+
+ void set_jump_size (
+ long js
+ );
+ /*!
+ ensures
+ - if (js < 1) then
+ - #jump_size() == 1
+ - else
+ - #jump_size() == js
+ !*/
+
+
+ template <
+ typename T
+ >
+ void set_scroll_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called whenever the user causes the slider box
+ to move.
+ - This event is NOT triggered by calling set_slider_pos()
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_scroll_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - The event_handler function is called whenever the user causes the slider box
+ to move.
+ - This event is NOT triggered by calling set_slider_pos()
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ scroll_bar(scroll_bar&); // copy constructor
+ scroll_bar& operator=(scroll_bar&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class widget_group
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class widget_group : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ size() == 0
+ get_rect().is_empty() == true
+ left() == 0
+ top() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a grouping of drawable widgets. It doesn't draw
+ anything itself, rather it lets you manipulate the position, enabled
+ status, and visibility of a set of widgets as a group.
+ !*/
+
+ public:
+ widget_group(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~widget_group(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released.
+ !*/
+
+ void empty (
+ );
+ /*!
+ ensures
+ - #size() == 0
+ !*/
+
+ void fit_to_contents (
+ );
+ /*!
+ ensures
+ - does not change the position of this object.
+ (i.e. the upper left corner of get_rect() remains at the same position)
+ - if (size() == 0) then
+ - #get_rect().is_empty() == true
+ - else
+ - recursively calls fit_to_contents() on any widget_groups inside
+ this object.
+ - #get_rect() will be the smallest rectangle that contains all the
+ widgets in this group and the upper left corner of get_rect().
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of widgets currently in *this.
+ !*/
+
+ void add (
+ drawable& widget,
+ unsigned long x,
+ unsigned long y
+ );
+ /*!
+ ensures
+ - #is_member(widget) == true
+ - if (is_member(widget) == false) then
+ - #size() == size() + 1
+ - else
+ - #size() == size()
+ - The following conditions apply to this function as well as to all of the
+ following functions so long as is_member(widget) == true:
+ enable(), disable(), hide(), show(), set_z_order(), and set_pos().
+ - #widget.left() == left()+x
+ - #widget.width() == widget.width()
+ - #widget.top() == top()+y
+ - #widget.height() == widget.height()
+ - #widget.is_hidden() == is_hidden()
+ - #widget.is_enabled() == is_enabled()
+ - #widget.z_order() == z_order()
+ throws
+ - std::bad_alloc
+ !*/
+
+ bool is_member (
+ const drawable& widget
+ ) const;
+ /*!
+ ensures
+ - returns true if widget is currently in this object, returns false otherwise.
+ !*/
+
+ void remove (
+ const drawable& widget
+ );
+ /*!
+ ensures
+ - #is_member(widget) == false
+ - if (is_member(widget) == true) then
+ - #size() == size() - 1
+ - else
+ - #size() == size()
+ !*/
+
+ protected:
+
+ // this object doesn't draw anything but also isn't abstract
+ void draw (
+ const canvas& c
+ ) const {}
+
+ private:
+
+ // restricted functions
+ widget_group(widget_group&); // copy constructor
+ widget_group& operator=(widget_group&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class image_widget
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class image_widget : public draggable
+ {
+ /*!
+ INITIAL VALUE
+ draggable_area() == an initial value for its type.
+ This object isn't displaying anything.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a draggable image. You give it an image to display
+ by calling set_image().
+
+ Also note that initially the draggable area is empty so it won't be
+ draggable unless you call set_draggable_area() to some non-empty region.
+
+ The image is drawn such that:
+ - the pixel img[0][0] is the upper left corner of the image.
+ - the pixel img[img.nr()-1][0] is the lower left corner of the image.
+ - the pixel img[0][img.nc()-1] is the upper right corner of the image.
+ - the pixel img[img.nr()-1][img.nc()-1] is the lower right corner of the image.
+
+ !*/
+
+ public:
+
+ image_widget(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~image_widget(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename image_type
+ >
+ void set_image (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_type::type> must be defined
+ ensures
+ - #width() == img.nc()
+ - #height() == img.nr()
+ - #*this widget is now displaying the given image img.
+ !*/
+
+ private:
+
+ // restricted functions
+ image_widget(image_widget&); // copy constructor
+ image_widget& operator=(image_widget&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class tooltip
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class tooltip : public mouse_over_event
+ {
+ /*!
+ INITIAL VALUE
+ - text() == ""
+ - the tooltip is inactive until the text is changed to
+ a non-empty string.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a region on a window where if the user
+ hovers the mouse over this region a tooltip with a message
+ appears.
+ !*/
+
+ public:
+
+ tooltip(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~tooltip(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ void set_text (const std::wstring& str);
+ void set_text (const dlib::ustring& str);
+ void set_text (
+ const std::string& str
+ );
+ /*!
+ ensures
+ - #text() == str
+ - activates the tooltip. i.e. after this function the tooltip
+ will display on the screen when the user hovers the mouse over it
+ !*/
+
+ const std::wstring wtext () const;
+ const dlib::ustring utext () const;
+ const std::string text (
+ ) const;
+ /*!
+ ensures
+ - returns the text that is displayed inside this
+ tooltip
+ !*/
+
+ private:
+
+ // restricted functions
+ tooltip(tooltip&); // copy constructor
+ tooltip& operator=(tooltip&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // popup menu stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class menu_item
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ menu item in a popup_menu must implement.
+
+ Note that a menu_item is drawn as 3 separate pieces:
+ ---------------------------------
+ | left | middle | right |
+ ---------------------------------
+
+ Also note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+
+ public:
+
+ virtual ~menu_item() {}
+
+ virtual void on_click (
+ ) const {}
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - if (has_click_event()) then
+ - this function is called when the user clicks on this menu_item
+ !*/
+
+ virtual bool has_click_event (
+ ) const { return false; }
+ /*!
+ ensures
+ - if (this menu_item wants to receive on_click events) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ virtual unichar get_hot_key (
+ ) const { return 0; }
+ /*!
+ ensures
+ - if (this menu item has a keyboard hot key) then
+ - returns the unicode value of the key
+ - else
+ - returns 0
+ !*/
+
+ virtual rectangle get_left_size (
+ ) const { return rectangle(); } // return empty rect by default
+ /*!
+ ensures
+ - returns the dimensions of the left part of the menu_item
+ !*/
+
+ virtual rectangle get_middle_size (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the dimensions of the middle part of the menu_item
+ !*/
+
+ virtual rectangle get_right_size (
+ ) const { return rectangle(); } // return empty rect by default
+ /*!
+ ensures
+ - returns the dimensions of the right part of the menu_item
+ !*/
+
+ virtual void draw_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const {}
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ requires
+ - c == the canvas to draw on
+ - rect == the rectangle in which we are to draw the background
+ - enabled == true if the menu_item is to be drawn enabled
+ - is_selected == true if the menu_item is to be drawn selected
+ ensures
+ - draws the background of the menu_item on the canvas c at the location
+ given by rect.
+ !*/
+
+ virtual void draw_left (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const {}
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ requires
+ - c == the canvas to draw on
+ - rect == the rectangle in which we are to draw the background
+ - enabled == true if the menu_item is to be drawn enabled
+ - is_selected == true if the menu_item is to be drawn selected
+ ensures
+ - draws the left part of the menu_item on the canvas c at the location
+ given by rect.
+ !*/
+
+ virtual void draw_middle (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ requires
+ - c == the canvas to draw on
+ - rect == the rectangle in which we are to draw the background
+ - enabled == true if the menu_item is to be drawn enabled
+ - is_selected == true if the menu_item is to be drawn selected
+ ensures
+ - draws the middle part of the menu_item on the canvas c at the location
+ given by rect.
+ !*/
+
+ virtual void draw_right (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const bool is_selected
+ ) const {}
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ requires
+ - c == the canvas to draw on
+ - rect == the rectangle in which we are to draw the background
+ - enabled == true if the menu_item is to be drawn enabled
+ - is_selected == true if the menu_item is to be drawn selected
+ ensures
+ - draws the right part of the menu_item on the canvas c at the location
+ given by rect.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_text : public menu_item
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple text menu item
+ !*/
+
+ public:
+
+ template <
+ typename T
+ >
+ menu_item_text (
+ const std::string& str,
+ T& object,
+ void (T::*on_click_handler)(),
+ unichar hotkey = 0
+ );
+ /*!
+ ensures
+ - The text of this menu item will be str
+ - the on_click_handler function is called on object when this menu_item
+ clicked by the user.
+ - #get_hot_key() == hotkey
+ !*/
+
+ menu_item_text (
+ const std::string& str,
+ const any_function<void()>& on_click_handler,
+ unichar hotkey = 0
+ );
+ /*!
+ ensures
+ - The text of this menu item will be str
+ - the on_click_handler function is called when this menu_item
+ clicked by the user.
+ - #get_hot_key() == hotkey
+ !*/
+
+ // overloads for wide character strings
+ template <
+ typename T
+ >
+ menu_item_text (
+ const std::wstring& str,
+ T& object,
+ void (T::*on_click_handler)(),
+ unichar hotkey = 0
+ );
+
+ menu_item_text (
+ const std::wstring& str,
+ const any_function<void()>& on_click_handler,
+ unichar hotkey = 0
+ );
+
+ template <
+ typename T
+ >
+ menu_item_text (
+ const dlib::ustring& str,
+ T& object,
+ void (T::*on_click_handler)(),
+ unichar hotkey = 0
+ );
+
+ template <
+ typename T
+ >
+ menu_item_text (
+ const dlib::ustring& str,
+ const any_function<void()>& on_click_handler,
+ unichar hotkey = 0
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_submenu : public menu_item
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple text item intended to be used with
+ submenus inside a popup_menu.
+ !*/
+
+ public:
+
+ menu_item_submenu (
+ const std::string& str,
+ unichar hotkey = 0
+ );
+ /*!
+ ensures
+ - The text of this menu item will be str
+ - #get_hot_key() == hotkey
+ !*/
+
+ //overloads for wide character strings
+ menu_item_submenu (
+ const std::wstring& str,
+ unichar hotkey = 0
+ );
+
+ menu_item_submenu (
+ const dlib::ustring& str,
+ unichar hotkey = 0
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class menu_item_separator : public menu_item
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a horizontal separator in a popup menu
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class popup_menu : public base_window
+ {
+ /*!
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a popup menu window capable of containing
+ menu_item objects.
+ !*/
+
+ public:
+
+ popup_menu (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ - dlib::gui_error
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ template <
+ typename menu_item_type
+ >
+ unsigned long add_menu_item (
+ const menu_item_type& new_item
+ );
+ /*!
+ requires
+ - menu_item_type == a type that inherits from menu_item
+ ensures
+ - adds new_item onto the bottom of this popup_menu.
+ - returns size()
+ (This is also the index by which this item can be
+ referenced by the enable_menu_item() and disable_menu_item()
+ functions.)
+ !*/
+
+ template <
+ typename menu_item_type
+ >
+ unsigned long add_submenu (
+ const menu_item_type& new_item,
+ popup_menu& submenu
+ );
+ /*!
+ requires
+ - menu_item_type == a type that inherits from menu_item
+ ensures
+ - adds new_item onto the bottom of this popup_menu.
+ - when the user puts the mouse above this menu_item the given
+ submenu popup_menu will be displayed.
+ - returns size()
+ (This is also the index by which this item can be
+ referenced by the enable_menu_item() and disable_menu_item()
+ functions.)
+ !*/
+
+ void enable_menu_item (
+ unsigned long idx
+ );
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - the menu_item in this with the index idx has been enabled
+ !*/
+
+ void disable_menu_item (
+ unsigned long idx
+ );
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - the menu_item in this with the index idx has been disabled
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of menu_item objects in this popup_menu
+ !*/
+
+ template <typename T>
+ void set_on_hide_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ ensures
+ - the event_handler function is called on object when this popup_menu
+ hides itself due to an action by the user.
+ - Note that you can register multiple handlers for this event.
+ !*/
+
+ void select_first_item (
+ );
+ /*!
+ ensures
+ - causes this popup menu to highlight the first
+ menu item that it contains which has a click event
+ and is enabled.
+ !*/
+
+ bool forwarded_on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+ /*!
+ requires
+ - key, is_printable, and state are the variables from the
+ base_window::on_keydown() event
+ ensures
+ - forwards this keyboard event to this popup window so that it
+ may deal with keyboard events from other windows.
+ - if (this popup_menu uses the keyboard event) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ private:
+
+ // restricted functions
+ popup_menu(popup_menu&); // copy constructor
+ popup_menu& operator=(popup_menu&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class popup_menu_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class popup_menu_region : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - popup_menu_visible() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a region on a window where if the user
+ right clicks the mouse over this region a popup_menu pops up.
+
+ Note that this widget doesn't actually draw anything, it just
+ provides a region the user can click on to get a popup menu.
+ !*/
+
+ public:
+
+ popup_menu_region(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~popup_menu_region(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ void set_rect (
+ const rectangle& new_rect
+ );
+ /*!
+ ensures
+ - #get_rect() == new_rect
+ !*/
+
+ bool popup_menu_visible (
+ ) const;
+ /*!
+ ensures
+ - if (the popup menu is currently visible on the screen) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ popup_menu& menu (
+ );
+ /*!
+ ensures
+ - returns a reference to the popup_menu for this object. It is
+ the menu that is displayed when the user right clicks on
+ this widget
+ !*/
+
+ private:
+
+ // restricted functions
+ popup_menu_region(popup_menu_region&); // copy constructor
+ popup_menu_region& operator=(popup_menu_region&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class zoomable_region
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class zoomable_region : public drawable
+ {
+ /*
+ INITIAL VALUE
+ - min_zoom_scale() == 0.15
+ - max_zoom_scale() == 1.0
+ - zoom_increment() == 0.90
+ - zoom_scale() == 1.0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a 2D Cartesian graph that you can zoom into and
+ out of. It is a graphical widget that draws a rectangle with
+ a horizontal and vertical scroll bar that allow the user to scroll
+ around on a Cartesian graph that is much larger than the actual
+ area occupied by this object on the screen. It also allows
+ the user to zoom in and out.
+
+ To use this object you inherit from it and make use of its public and
+ protected member functions. It provides functions for converting between
+ pixel locations and the points in our 2D Cartesian graph so that when the
+ user is scrolling/zooming the widget you can still determine where
+ things are to be placed on the screen and what screen pixels correspond
+ to in the Cartesian graph.
+
+ Note that the Cartesian graph in this object is bounded by the point
+ (0,0), corresponding to the upper left corner when we are zoomed all
+ the way out, and max_graph_point() which corresponds to the lower right
+ corner when zoomed all the way out. The value of max_graph_point() is
+ determined automatically from the size of this object's on screen
+ rectangle and the value of min_zoom_scale() which determines how far
+ out you can zoom.
+ */
+
+ public:
+
+ zoomable_region (
+ drawable_window& w,
+ unsigned long events = 0
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ - This object will not receive any events or draw() requests until
+ enable_events() is called
+ - the events flags are passed on to the drawable object's
+ constructor.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~zoomable_region (
+ ) = 0;
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from scrollable_region_style
+ ensures
+ - this zoomable_region object will draw itself using the given
+ style
+ !*/
+
+ void set_zoom_increment (
+ double zi
+ );
+ /*!
+ requires
+ - 0 < zi < 1
+ ensures
+ - #zoom_increment() == zi
+ !*/
+
+ double zoom_increment (
+ ) const;
+ /*!
+ ensures
+ - When the user zooms in using the mouse wheel:
+ - #zoom_scale() == zoom_scale() / zoom_increment()
+ - When the user zooms out using the mouse wheel:
+ - #zoom_scale() == zoom_scale() * zoom_increment()
+ - So this function returns the number that determines how much the zoom
+ changes when the mouse wheel is moved.
+ !*/
+
+ void set_max_zoom_scale (
+ double ms
+ );
+ /*!
+ requires
+ - ms > 0
+ ensures
+ - #max_zoom_scale() == ms
+ !*/
+
+ void set_min_zoom_scale (
+ double ms
+ );
+ /*!
+ requires
+ - ms > 0
+ ensures
+ - #min_zoom_scale() == ms
+ !*/
+
+ double min_zoom_scale (
+ ) const;
+ /*!
+ ensures
+ - returns the minimum allowed value of zoom_scale()
+ (i.e. this is the number that determines how far out the user is allowed to zoom)
+ !*/
+
+ double max_zoom_scale (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum allowed value of zoom_scale()
+ (i.e. this is the number that determines how far in the user is allowed to zoom)
+ !*/
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this button stays the
+ same but its width and height are modified
+ !*/
+
+ protected:
+
+ rectangle display_rect (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns the rectangle on the screen that contains the Cartesian
+ graph in this widget. I.e. this is the area of this widget minus
+ the area taken up by the scroll bars and border decorations.
+ !*/
+
+ point graph_to_gui_space (
+ const vector<double,2>& graph_point
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns the location of the pixel on the screen that corresponds
+ to the given point in Cartesian graph space
+ !*/
+
+ vector<double,2> gui_to_graph_space (
+ const point& pixel_point
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns the point in Cartesian graph space that corresponds to the given
+ pixel location
+ !*/
+
+ vector<double,2> max_graph_point (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns the pixel farthest from the graph point (0,0) that is still
+ in the graph. I.e. returns the point in graph space that corresponds
+ to the lower right corner of the display_rect() when we are zoomed
+ all the way out.
+ !*/
+
+ double zoom_scale (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns a double Z that represents the current zoom.
+ - Smaller values of Z represent the user zooming out.
+ - Bigger values of Z represent the user zooming in.
+ - The default unzoomed case is when Z == 1
+ - objects should be drawn such that they are zoom_scale()
+ times their normal size
+ !*/
+
+ void set_zoom_scale (
+ double new_scale
+ );
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - invalidates the display_rect() so that it will be redrawn
+ - if (min_zoom_scale() <= new_scale && new_scale <= max_zoom_scale()) then
+ - #zoom_scale() == new_scale
+ - else if (new_scale < min_zoom_scale()) then
+ - #zoom_scale() == min_zoom_scale()
+ - else if (new_scale > max_zoom_scale()) then
+ - #zoom_scale() == max_zoom_scale()
+ !*/
+
+ void center_display_at_graph_point (
+ const vector<double,2>& graph_point
+ );
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - causes the given graph point to be centered in the display
+ if possible
+ - invalidates the display_rect() so that it will be redrawn
+ !*/
+
+ virtual void on_view_changed (
+ ) {}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex drawable::m is locked
+ ensures
+ - on_view_changed() is called whenever the user causes the view of the
+ zoomable_region to change. That is, this function is called when the
+ user scrolls or zooms around in the region.
+ !*/
+
+ // ---------------------------- event handlers ----------------------------
+ // The following event handlers are used in this object. So if you
+ // use any of them in your derived object you should pass the events
+ // back to it so that they still operate unless you wish to hijack the
+ // event for your own reasons (e.g. to override the mouse drag this object
+ // performs)
+
+ void on_wheel_down (unsigned long state);
+ void on_wheel_up (unsigned long state);
+ void on_mouse_move ( unsigned long state, long x, long y);
+ void on_mouse_up ( unsigned long btn, unsigned long state, long x, long y);
+ void on_mouse_down ( unsigned long btn, unsigned long state, long x, long y, bool is_double_click);
+ void draw ( const canvas& c) const;
+
+ private:
+
+ // restricted functions
+ zoomable_region(zoomable_region&); // copy constructor
+ zoomable_region& operator=(zoomable_region&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - horizontal_scroll_pos() == 0
+ - horizontal_scroll_increment() == 1
+ - horizontal_mouse_wheel_scroll_increment() == 1
+ - vertical_scroll_pos() == 0
+ - vertical_scroll_increment() == 1
+ - vertical_mouse_wheel_scroll_increment() == 1
+ - total_rect().empty() == true
+ - mouse_drag_enabled() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a 2D region of arbitrary size that is displayed
+ within a possibly smaller scrollable gui widget. That is, it is a
+ graphical widget that draws a rectangle with a horizontal and vertical
+ scroll bar that allows the user to scroll around on a region that is much
+ larger than the actual area occupied by this object on the screen.
+
+ To use this object you inherit from it and make use of its public and
+ protected member functions. It provides a function, total_rect(), that
+ tells you where the 2D region is on the screen. You draw your stuff
+ inside total_rect() as you would normally except that you only modify
+ pixels that are also inside display_rect(). When the user moves the
+ scroll bars the position of total_rect() is updated accordingly, causing
+ the widget's content to scroll across the screen.
+ !*/
+
+ public:
+ scrollable_region (
+ drawable_window& w,
+ unsigned long events = 0
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ - This object will not receive any events or draw() requests until
+ enable_events() is called
+ - the events flags are passed on to the drawable object's
+ constructor.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~scrollable_region (
+ ) = 0;
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from scrollable_region_style
+ ensures
+ - this scrollable_region object will draw itself using the given
+ style
+ !*/
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified.
+ !*/
+
+ long horizontal_scroll_pos (
+ ) const;
+ /*!
+ ensures
+ - returns the current position of the horizontal scroll bar.
+ 0 means it is at the far left while bigger values represent
+ scroll positions closer to the right.
+ !*/
+
+ long vertical_scroll_pos (
+ ) const;
+ /*!
+ ensures
+ - returns the current position of the vertical scroll bar.
+ 0 means it is at the top and bigger values represent scroll positions
+ closer to the bottom.
+ !*/
+
+ void set_horizontal_scroll_pos (
+ long pos
+ );
+ /*!
+ ensures
+ - if (pos is a valid horizontal scroll position) then
+ - #horizontal_scroll_pos() == pos
+ - else
+ - #horizontal_scroll_pos() == the valid scroll position closest to pos
+ !*/
+
+ void set_vertical_scroll_pos (
+ long pos
+ );
+ /*!
+ ensures
+ - if (pos is a valid vertical scroll position) then
+ - #vertical_scroll_pos() == pos
+ - else
+ - #vertical_scroll_pos() == the valid scroll position closest to pos
+ !*/
+
+ unsigned long horizontal_mouse_wheel_scroll_increment (
+ ) const;
+ /*!
+ ensures
+ - returns the number of positions the horizontal scroll bar
+ moves when the user scrolls the mouse wheel.
+ !*/
+
+ unsigned long vertical_mouse_wheel_scroll_increment (
+ ) const;
+ /*!
+ ensures
+ - returns the number of positions the vertical scroll bar
+ moves when the user scrolls the mouse wheel.
+ !*/
+
+ void set_horizontal_mouse_wheel_scroll_increment (
+ unsigned long inc
+ );
+ /*!
+ ensures
+ - #horizontal_mouse_wheel_scroll_increment() == inc
+ !*/
+
+ void set_vertical_mouse_wheel_scroll_increment (
+ unsigned long inc
+ );
+ /*!
+ ensures
+ - #vertical_mouse_wheel_scroll_increment() == inc
+ !*/
+
+
+ unsigned long horizontal_scroll_increment (
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels that total_rect() is moved by when
+ the horizontal scroll bar moves by one position
+ !*/
+
+ unsigned long vertical_scroll_increment (
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels that total_rect() is moved by when
+ the vertical scroll bar moves by one position
+ !*/
+
+ void set_horizontal_scroll_increment (
+ unsigned long inc
+ );
+ /*!
+ ensures
+ - #horizontal_scroll_increment() == inc
+ !*/
+
+ void set_vertical_scroll_increment (
+ unsigned long inc
+ );
+ /*!
+ ensures
+ - #vertical_scroll_increment() == inc
+ !*/
+
+ bool mouse_drag_enabled (
+ ) const;
+ /*!
+ ensures
+ - if (the user can drag this contents of this widget around by
+ holding down the left mouse button and dragging) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void enable_mouse_drag (
+ );
+ /*!
+ ensures
+ - #mouse_drag_enabled() == true
+ !*/
+
+ void disable_mouse_drag (
+ );
+ /*!
+ ensures
+ - #mouse_drag_enabled() == false
+ !*/
+
+ protected:
+
+ rectangle display_rect (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns the rectangle on the screen that contains the scrollable
+ area in this widget. I.e. this is the area of this widget minus
+ the area taken up by the scroll bars and border decorations.
+ !*/
+
+ void set_total_rect_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ requires
+ - mutex drawable::m is locked
+ - (width > 0 && height > 0) || (width == 0 && height == 0)
+ ensures
+ - #total_rect().width() == width
+ - #total_rect().height() == height
+ - The scroll bars as well as the position of #total_rect()
+ is updated so that the total rect is still in the correct
+ position with respect to the scroll bars.
+ !*/
+
+ const rectangle& total_rect (
+ ) const;
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - returns a rectangle that represents the entire scrollable
+ region inside this widget, even the parts that are outside
+ display_rect().
+ !*/
+
+ void scroll_to_rect (
+ const rectangle& r
+ );
+ /*!
+ requires
+ - mutex drawable::m is locked
+ ensures
+ - Adjusts the scroll bars of this object so that the part of
+ the total_rect() rectangle that overlaps with r is displayed in
+ the display_rect() rectangle on the screen.
+ !*/
+
+ virtual void on_view_changed (
+ ) {}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex drawable::m is locked
+ ensures
+ - on_view_changed() is called whenever the user causes the view of the
+ scrollable_region to change. That is, this function is called when the
+ user scrolls around in the region.
+ !*/
+
+ // ---------------------------- event handlers ----------------------------
+ // The following event handlers are used in this object. So if you
+ // use any of them in your derived object you should pass the events
+ // back to it so that they still operate unless you wish to hijack the
+ // event for your own reasons (e.g. to override the mouse wheel action
+ // this object performs)
+
+ void on_wheel_down (unsigned long state);
+ void on_wheel_up (unsigned long state);
+ void on_mouse_move (unsigned long state, long x, long y);
+ void on_mouse_down (unsigned long btn, unsigned long state, long x, long y, bool is_double_click);
+ void on_mouse_up (unsigned long btn, unsigned long state, long x, long y);
+ void draw (const canvas& c) const;
+
+ private:
+
+ // restricted functions
+ scrollable_region(scrollable_region&); // copy constructor
+ scrollable_region& operator=(scrollable_region&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BASE_WIDGETs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp b/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp
new file mode 100644
index 000000000..0fecd1cd3
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp
@@ -0,0 +1,101 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CANVAS_DRAWINg_CPP_
+#define DLIB_CANVAS_DRAWINg_CPP_
+
+#include "canvas_drawing.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_sunken_rectangle (
+ const canvas& c,
+ const rectangle& border,
+ unsigned char alpha
+ )
+ {
+ rectangle area = border.intersect(c);
+ if (area.is_empty() == false)
+ {
+ const rgb_alpha_pixel dark_gray(64,64,64,alpha);
+ const rgb_alpha_pixel gray(128,128,128,alpha);
+ const rgb_alpha_pixel white(255,255,255,alpha);
+ const rgb_alpha_pixel background(212,208,200,alpha);
+
+ draw_line(c,point(border.left(),border.top()),point(border.right()-1,border.top()),gray);
+
+ draw_line(c,point(border.left(),border.bottom()),point(border.right(),border.bottom()),white);
+ draw_line(c,point(border.left()+1,border.bottom()-1),point(border.right()-1,border.bottom()-1),background);
+
+ draw_line(c,point(border.left(),border.top()+1),point(border.left(),border.bottom()-1),gray);
+
+ draw_line(c,point(border.right(),border.top()),point(border.right(),border.bottom()-1),white);
+ draw_line(c,point(border.right()-1,border.top()+1),point(border.right()-1,border.bottom()-2),background);
+
+ draw_line(c,point(border.left()+1,border.top()+1),point(border.left()+1,border.bottom()-2),dark_gray);
+ draw_line(c,point(border.left()+1,border.top()+1),point(border.right()-2,border.top()+1),dark_gray);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_down (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha
+ )
+ {
+ rectangle area = btn.intersect(c);
+ if (area.is_empty() == false)
+ {
+ const rgb_alpha_pixel dark_gray(64,64,64,alpha);
+ const rgb_alpha_pixel gray(128,128,128,alpha);
+ const rgb_alpha_pixel black(0,0,0,alpha);
+
+ draw_line(c,point(btn.left(),btn.top()),point(btn.right(),btn.top()),black);
+
+ draw_line(c,point(btn.left()+1,btn.bottom()),point(btn.right(),btn.bottom()),dark_gray);
+ draw_line(c,point(btn.left()+1,btn.top()+1),point(btn.right()-1,btn.top()+1),gray);
+
+ draw_line(c,point(btn.left(),btn.top()+1),point(btn.left(),btn.bottom()),black);
+
+ draw_line(c,point(btn.right(),btn.top()+1),point(btn.right(),btn.bottom()-1),dark_gray);
+ draw_line(c,point(btn.left()+1,btn.top()+1),point(btn.left()+1,btn.bottom()-1),gray);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_up (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha
+ )
+ {
+ rectangle area = btn.intersect(c);
+ if (area.is_empty() == false)
+ {
+ const rgb_alpha_pixel dark_gray(64,64,64,alpha);
+ const rgb_alpha_pixel gray(128,128,128,alpha);
+ const rgb_alpha_pixel white(255,255,255,alpha);
+
+ draw_line(c,point(btn.left(),btn.top()),point(btn.right()-1,btn.top()),white);
+
+ draw_line(c,point(btn.left(),btn.bottom()),point(btn.right(),btn.bottom()),dark_gray);
+ draw_line(c,point(btn.left()+1,btn.bottom()-1),point(btn.right()-1,btn.bottom()-1),gray);
+
+ draw_line(c,point(btn.left(),btn.top()+1),point(btn.left(),btn.bottom()-1),white);
+
+ draw_line(c,point(btn.right(),btn.top()),point(btn.right(),btn.bottom()-1),dark_gray);
+ draw_line(c,point(btn.right()-1,btn.top()+1),point(btn.right()-1,btn.bottom()-2),gray);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CANVAS_DRAWINg_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing.h b/ml/dlib/dlib/gui_widgets/canvas_drawing.h
new file mode 100644
index 000000000..61f688112
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/canvas_drawing.h
@@ -0,0 +1,964 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_GUI_CANVAS_DRAWINg_
+#define DLIB_GUI_CANVAS_DRAWINg_
+
+#include "canvas_drawing_abstract.h"
+#include "../gui_core.h"
+#include "../algs.h"
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../image_transforms/assign_image.h"
+#include "../geometry.h"
+#include <cmath>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_line (
+ const canvas& c,
+ const point& p1,
+ const point& p2,
+ const pixel_type& pixel,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ rectangle valid_area(c.intersect(area));
+ long x1 = p1.x();
+ long y1 = p1.y();
+ long x2 = p2.x();
+ long y2 = p2.y();
+ if (x1 == x2)
+ {
+ // if the x coordinate is inside the canvas's area
+ if (x1 <= valid_area.right() && x1 >= valid_area.left())
+ {
+ // make sure y1 comes before y2
+ if (y1 > y2)
+ swap(y1,y2);
+
+ y1 = std::max(y1,valid_area.top());
+ y2 = std::min(y2,valid_area.bottom());
+ // this is a vertical line
+ for (long y = y1; y <= y2; ++y)
+ {
+ assign_pixel(c[y-c.top()][x1-c.left()], pixel);
+ }
+ }
+ }
+ else if (y1 == y2)
+ {
+ // if the y coordinate is inside the canvas's area
+ if (y1 <= valid_area.bottom() && y1 >= valid_area.top())
+ {
+ // make sure x1 comes before x2
+ if (x1 > x2)
+ swap(x1,x2);
+
+ x1 = std::max(x1,valid_area.left());
+ x2 = std::min(x2,valid_area.right());
+ // this is a horizontal line
+ for (long x = x1; x <= x2; ++x)
+ {
+ assign_pixel(c[y1-c.top()][x-c.left()], pixel);
+ }
+ }
+ }
+ else
+ {
+ rgb_alpha_pixel alpha_pixel;
+ assign_pixel(alpha_pixel, pixel);
+ const unsigned char max_alpha = alpha_pixel.alpha;
+
+ const long rise = (((long)y2) - ((long)y1));
+ const long run = (((long)x2) - ((long)x1));
+ if (std::abs(rise) < std::abs(run))
+ {
+ const double slope = ((double)rise)/run;
+
+ double first, last;
+
+ if (x1 > x2)
+ {
+ first = std::max(x2,valid_area.left());
+ last = std::min(x1,valid_area.right());
+ }
+ else
+ {
+ first = std::max(x1,valid_area.left());
+ last = std::min(x2,valid_area.right());
+ }
+
+
+ long y;
+ long x;
+ const double x1f = x1;
+ const double y1f = y1;
+ for (double i = first; i <= last; ++i)
+ {
+ const double dy = slope*(i-x1f) + y1f;
+ const double dx = i;
+
+ y = static_cast<long>(dy);
+ x = static_cast<long>(dx);
+
+
+ if (y >= valid_area.top() && y <= valid_area.bottom())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((1.0-(dy-y))*max_alpha);
+ assign_pixel(c[y-c.top()][x-c.left()], alpha_pixel);
+ }
+ if (y+1 >= valid_area.top() && y+1 <= valid_area.bottom())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((dy-y)*max_alpha);
+ assign_pixel(c[y+1-c.top()][x-c.left()], alpha_pixel);
+ }
+ }
+ }
+ else
+ {
+ const double slope = ((double)run)/rise;
+
+ double first, last;
+
+ if (y1 > y2)
+ {
+ first = std::max(y2,valid_area.top());
+ last = std::min(y1,valid_area.bottom());
+ }
+ else
+ {
+ first = std::max(y1,valid_area.top());
+ last = std::min(y2,valid_area.bottom());
+ }
+
+ long x;
+ long y;
+ const double x1f = x1;
+ const double y1f = y1;
+ for (double i = first; i <= last; ++i)
+ {
+ const double dx = slope*(i-y1f) + x1f;
+ const double dy = i;
+
+ y = static_cast<long>(dy);
+ x = static_cast<long>(dx);
+
+ if (x >= valid_area.left() && x <= valid_area.right())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((1.0-(dx-x))*max_alpha);
+ assign_pixel(c[y-c.top()][x-c.left()], alpha_pixel);
+ }
+ if (x+1 >= valid_area.left() && x+1 <= valid_area.right())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((dx-x)*max_alpha);
+ assign_pixel(c[y-c.top()][x+1-c.left()], alpha_pixel);
+ }
+ }
+ }
+ }
+
+ }
+ inline void draw_line (
+ const canvas& c,
+ const point& p1,
+ const point& p2
+ ){ draw_line(c,p1,p2,0); }
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_sunken_rectangle (
+ const canvas& c,
+ const rectangle& border,
+ unsigned char alpha = 255
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ inline void draw_pixel (
+ const canvas& c,
+ const point& p,
+ const pixel_type& pixel
+ )
+ {
+ if (c.contains(p))
+ {
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],pixel);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_checkered (
+ const canvas& c,
+ const rectangle& a,
+ const pixel_type& pixel1,
+ const pixel_type& pixel2
+ )
+ {
+ rectangle area = a.intersect(c);
+ if (area.is_empty())
+ return;
+
+ for (long i = area.left(); i <= area.right(); ++i)
+ {
+ for (long j = area.top(); j <= area.bottom(); ++j)
+ {
+ canvas::pixel& p = c[j - c.top()][i - c.left()];
+ if ((j&0x1) ^ (i&0x1))
+ {
+ assign_pixel(p,pixel1);
+ }
+ else
+ {
+ assign_pixel(p,pixel2);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_down (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha = 255
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_up (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha = 255
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius,
+ const pixel_type& pixel,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ using std::sqrt;
+ rectangle valid_area(c.intersect(area));
+ const long x = center_point.x();
+ const long y = center_point.y();
+ if (radius > 1)
+ {
+ long first_x = static_cast<long>(x - radius + 0.5);
+ long last_x = static_cast<long>(x + radius + 0.5);
+ const double rs = radius*radius;
+
+ // ensure that we only loop over the part of the x dimension that this
+ // canvas contains.
+ if (first_x < valid_area.left())
+ first_x = valid_area.left();
+ if (last_x > valid_area.right())
+ last_x = valid_area.right();
+
+ long top, bottom;
+
+ top = static_cast<long>(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5);
+ top += y;
+ long last = top;
+
+ // draw the left half of the circle
+ long middle = std::min(x-1,last_x);
+ for (long i = first_x; i <= middle; ++i)
+ {
+ double a = i - x + 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ if (top >= valid_area.top() && top <= valid_area.bottom() )
+ {
+ assign_pixel(c[top-c.top()][i-c.left()],pixel);
+ }
+
+ if (bottom >= valid_area.top() && bottom <= valid_area.bottom() )
+ {
+ assign_pixel(c[bottom-c.top()][i-c.left()],pixel);
+ }
+ --top;
+ }
+
+ last = temp;
+ }
+
+ middle = std::max(x,first_x);
+ top = static_cast<long>(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5);
+ top += y;
+ last = top;
+ // draw the right half of the circle
+ for (long i = last_x; i >= middle; --i)
+ {
+ double a = i - x - 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ if (top >= valid_area.top() && top <= valid_area.bottom() )
+ {
+ assign_pixel(c[top-c.top()][i-c.left()],pixel);
+ }
+
+ if (bottom >= valid_area.top() && bottom <= valid_area.bottom() )
+ {
+ assign_pixel(c[bottom-c.top()][i-c.left()],pixel);
+ }
+ --top;
+ }
+
+ last = temp;
+ }
+ }
+ else if (radius == 1 &&
+ x >= valid_area.left() && x <= valid_area.right() &&
+ y >= valid_area.top() && y <= valid_area.bottom() )
+ {
+ assign_pixel(c[y-c.top()][x-c.left()], pixel);
+ }
+ }
+ inline void draw_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius
+ ){ draw_circle(c, center_point, radius, 0); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_solid_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius,
+ const pixel_type& pixel,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ using std::sqrt;
+ rectangle valid_area(c.intersect(area));
+ const long x = center_point.x();
+ const long y = center_point.y();
+ if (radius > 1)
+ {
+ long first_x = static_cast<long>(x - radius + 0.5);
+ long last_x = static_cast<long>(x + radius + 0.5);
+ const double rs = radius*radius;
+
+ // ensure that we only loop over the part of the x dimension that this
+ // canvas contains.
+ if (first_x < valid_area.left())
+ first_x = valid_area.left();
+ if (last_x > valid_area.right())
+ last_x = valid_area.right();
+
+ long top, bottom;
+
+ top = static_cast<long>(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5);
+ top += y;
+ long last = top;
+
+ // draw the left half of the circle
+ long middle = std::min(x-1,last_x);
+ for (long i = first_x; i <= middle; ++i)
+ {
+ double a = i - x + 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ draw_line(c, point(i,top),point(i,bottom),pixel,area);
+ --top;
+ }
+
+ last = temp;
+ }
+
+ middle = std::max(x,first_x);
+ top = static_cast<long>(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5);
+ top += y;
+ last = top;
+ // draw the right half of the circle
+ for (long i = last_x; i >= middle; --i)
+ {
+ double a = i - x - 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ draw_line(c, point(i,top),point(i,bottom),pixel,area);
+ --top;
+ }
+
+ last = temp;
+ }
+ }
+ else if (radius == 1 &&
+ x >= valid_area.left() && x <= valid_area.right() &&
+ y >= valid_area.top() && y <= valid_area.bottom() )
+ {
+ assign_pixel(c[y-c.top()][x-c.left()], pixel);
+ }
+ }
+ inline void draw_solid_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius
+ ) { draw_solid_circle(c, center_point, radius, 0); }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <typename alloc>
+ void get_convex_polygon_shape (
+ const std::vector<point>& points,
+ const long top,
+ const long bottom,
+ std::vector<double,alloc>& left_boundary,
+ std::vector<double,alloc>& right_boundary
+ )
+ /*!
+ requires
+ - 0 <= top <= bottom
+ ensures
+ - interprets points as the coordinates defining a convex polygon. In
+ particular, we interpret points as a list of the vertices of the polygon
+ and assume they are ordered in clockwise order.
+ - #left_boundary.size() == bottom-top+1
+ - #right_boundary.size() == bottom-top+1
+ - for all top <= y <= bottom:
+ - #left_boundary[y-top] == the x coordinate for the left most side of
+ the polygon at coordinate y.
+ - #right_boundary[y-top] == the x coordinate for the right most side of
+ the polygon at coordinate y.
+ !*/
+ {
+ using std::min;
+ using std::max;
+
+ left_boundary.assign(bottom-top+1, std::numeric_limits<double>::infinity());
+ right_boundary.assign(bottom-top+1, -std::numeric_limits<double>::infinity());
+
+ // trace out the points along the edge of the polynomial and record them
+ for (unsigned long i = 0; i < points.size(); ++i)
+ {
+ const point p1 = points[i];
+ const point p2 = points[(i+1)%points.size()];
+
+ if (p1.y() == p2.y())
+ {
+ if (top <= p1.y() && p1.y() <= bottom)
+ {
+ const long y = p1.y() - top;
+ const double xmin = min(p1.x(), p2.x());
+ const double xmax = min(p1.x(), p2.x());
+ left_boundary[y] = min(left_boundary[y], xmin);
+ right_boundary[y] = max(right_boundary[y], xmax);
+ }
+ }
+ else
+ {
+ // Here we trace out the line from p1 to p2 and record where it hits.
+
+ // x = m*y + b
+ const double m = (p2.x() - p1.x())/(double)(p2.y()-p1.y());
+ const double b = p1.x() - m*p1.y(); // because: x1 = m*y1 + b
+
+ const long ymin = max(top,min(p1.y(), p2.y()));
+ const long ymax = min(bottom,max(p1.y(), p2.y()));
+ for (long y = ymin; y <= ymax; ++y)
+ {
+ const double x = m*y + b;
+ const unsigned long idx = y-top;
+ left_boundary[idx] = min(left_boundary[idx], x);
+ right_boundary[idx] = max(right_boundary[idx], x);
+ }
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+ template <typename pixel_type>
+ void draw_solid_convex_polygon (
+ const canvas& c,
+ const std::vector<point>& polygon,
+ const pixel_type& pixel,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ using std::max;
+ using std::min;
+ const rectangle valid_area(c.intersect(area));
+
+ rectangle bounding_box;
+ for (unsigned long i = 0; i < polygon.size(); ++i)
+ bounding_box += polygon[i];
+
+ // Don't do anything if the polygon is totally outside the area we can draw in
+ // right now.
+ if (bounding_box.intersect(valid_area).is_empty())
+ return;
+
+ rgb_alpha_pixel alpha_pixel;
+ assign_pixel(alpha_pixel, pixel);
+ const unsigned char max_alpha = alpha_pixel.alpha;
+
+ // we will only want to loop over the part of left_boundary that is part of the
+ // valid_area.
+ long top = max(valid_area.top(),bounding_box.top());
+ long bottom = min(valid_area.bottom(),bounding_box.bottom());
+
+ // Since we look at the adjacent rows of boundary information when doing the alpha
+ // blending, we want to make sure we always have some boundary information unless
+ // we are at the absolute edge of the polygon.
+ const long top_offset = (top == bounding_box.top()) ? 0 : 1;
+ const long bottom_offset = (bottom == bounding_box.bottom()) ? 0 : 1;
+ if (top != bounding_box.top())
+ top -= 1;
+ if (bottom != bounding_box.bottom())
+ bottom += 1;
+
+ std::vector<double> left_boundary;
+ std::vector<double> right_boundary;
+ impl::get_convex_polygon_shape(polygon, top, bottom, left_boundary, right_boundary);
+
+
+ // draw the polygon row by row
+ for (unsigned long i = top_offset; i < left_boundary.size(); ++i)
+ {
+ long left_x = static_cast<long>(std::ceil(left_boundary[i]));
+ long right_x = static_cast<long>(std::floor(right_boundary[i]));
+
+ left_x = max(left_x, valid_area.left());
+ right_x = min(right_x, valid_area.right());
+
+ if (i < left_boundary.size()-bottom_offset)
+ {
+ // draw the main body of the polygon
+ for (long x = left_x; x <= right_x; ++x)
+ {
+ const long y = i+top;
+ assign_pixel(c[y-c.top()][x-c.left()], pixel);
+ }
+ }
+
+ if (i == 0)
+ continue;
+
+ // Now draw anti-aliased edges so they don't look all pixely.
+
+ // Alpha blend the edges on the left side.
+ double delta = left_boundary[i-1] - left_boundary[i];
+ if (std::abs(delta) <= 1)
+ {
+ if (std::floor(left_boundary[i]) != left_x)
+ {
+ const point p(static_cast<long>(std::floor(left_boundary[i])), i+top);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = max_alpha-static_cast<unsigned char>((left_boundary[i]-p.x())*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+ else if (delta < 0) // on the bottom side
+ {
+ for (long x = static_cast<long>(std::ceil(left_boundary[i-1])); x < left_x; ++x)
+ {
+ const point p(x, i+top);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = static_cast<unsigned char>((x-left_boundary[i-1])/std::abs(delta)*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+ else // on the top side
+ {
+ const long old_left_x = static_cast<long>(std::ceil(left_boundary[i-1]));
+ for (long x = left_x; x < old_left_x; ++x)
+ {
+ const point p(x, i+top-1);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = static_cast<unsigned char>((x-left_boundary[i])/delta*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+
+
+ // Alpha blend the edges on the right side
+ delta = right_boundary[i-1] - right_boundary[i];
+ if (std::abs(delta) <= 1)
+ {
+ if (std::ceil(right_boundary[i]) != right_x)
+ {
+ const point p(static_cast<long>(std::ceil(right_boundary[i])), i+top);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = max_alpha-static_cast<unsigned char>((p.x()-right_boundary[i])*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+ else if (delta < 0) // on the top side
+ {
+ for (long x = static_cast<long>(std::floor(right_boundary[i-1]))+1; x <= right_x; ++x)
+ {
+ const point p(x, i+top-1);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = static_cast<unsigned char>((right_boundary[i]-x)/std::abs(delta)*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+ else // on the bottom side
+ {
+ const long old_right_x = static_cast<long>(std::floor(right_boundary[i-1]));
+ for (long x = right_x+1; x <= old_right_x; ++x)
+ {
+ const point p(x, i+top);
+ rgb_alpha_pixel temp = alpha_pixel;
+ temp.alpha = static_cast<unsigned char>((right_boundary[i-1]-x)/delta*max_alpha);
+ if (valid_area.contains(p))
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp);
+ }
+ }
+ }
+ }
+ inline void draw_solid_convex_polygon (
+ const canvas& c,
+ const std::vector<point>& polygon
+ ) { draw_solid_convex_polygon(c, polygon, 0); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void draw_image (
+ const canvas& c,
+ const point& p,
+ const image_type& img,
+ const rectangle& area_ = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ const long x = p.x();
+ const long y = p.y();
+ rectangle rect(x,y,num_columns(img)+x-1,num_rows(img)+y-1);
+ rectangle area = c.intersect(rect).intersect(area_);
+ if (area.is_empty())
+ return;
+
+ for (long row = area.top(); row <= area.bottom(); ++row)
+ {
+ for (long col = area.left(); col <= area.right(); ++col)
+ {
+ assign_pixel(c[row-c.top()][col-c.left()], img[row-rect.top()][col-rect.left()]);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void draw_image (
+ const canvas& c,
+ const rectangle& rect,
+ const image_type& img,
+ const rectangle& area_ = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ const rectangle area = c.intersect(rect).intersect(area_);
+ if (area.is_empty() || num_columns(img) * num_rows(img) == 0)
+ return;
+
+ const matrix<long,1> x = matrix_cast<long>(round(linspace(0, num_columns(img)-1, rect.width())));
+ const matrix<long,1> y = matrix_cast<long>(round(linspace(0, num_rows(img)-1, rect.height())));
+
+ for (long row = area.top(); row <= area.bottom(); ++row)
+ {
+ const long r = y(row-rect.top());
+ long cc = area.left() - rect.left();
+ for (long col = area.left(); col <= area.right(); ++col)
+ {
+ assign_pixel(c[row-c.top()][col-c.left()], img[r][x(cc++)]);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_rounded_rectangle (
+ const canvas& c,
+ const rectangle& rect,
+ unsigned radius,
+ const pixel_type& color,
+ const rectangle& area_ = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ if ( rect.intersect ( c ).is_empty() )
+ return;
+
+ draw_line ( c, point(rect.left() + radius + 1, rect.bottom()),
+ point(rect.right() - radius - 1, rect.bottom()), color,area_ );
+
+ draw_line ( c, point(rect.left() + radius + 1, rect.top()),
+ point(rect.right() - radius - 1, rect.top()), color,area_ );
+
+ draw_line ( c, point(rect.left(), rect.top() + radius + 1),
+ point(rect.left(), rect.bottom() - radius - 1), color,area_ );
+
+ draw_line ( c, point(rect.right(), rect.top() + radius + 1),
+ point(rect.right(), rect.bottom() - radius - 1), color,area_ );
+
+ unsigned x = radius, y = 0, old_x = x;
+
+ point p;
+ while ( x > y )
+ {
+ p = point(rect.left() + radius - y, rect.top() + radius - x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + y, rect.top() + radius - x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + y, rect.bottom() - radius + x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.left() + radius - y, rect.bottom() - radius + x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.left() + radius - x, rect.top() + radius - y);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + x, rect.top() + radius - y);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + x, rect.bottom() - radius + y);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.left() + radius - x, rect.bottom() - radius + y);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ y++;
+ old_x = x;
+ x = square_root ( ( radius * radius - y * y ) * 4 ) / 2;
+ }
+
+ if ( x == y && old_x != x )
+ {
+ p = point(rect.left() + radius - y, rect.top() + radius - x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + y, rect.top() + radius - x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.right() - radius + y, rect.bottom() - radius + x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ p = point(rect.left() + radius - y, rect.bottom() - radius + x);
+ if (area_.contains(p)) draw_pixel (c, p , color );
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void fill_gradient_rounded (
+ const canvas& c,
+ const rectangle& rect,
+ unsigned long radius,
+ const pixel_type& top_color,
+ const pixel_type& bottom_color,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+
+ )
+ {
+ rectangle valid_area(c.intersect(area.intersect(rect)));
+ if ( valid_area.is_empty() )
+ return;
+
+
+ unsigned long m_prev = 0, m = radius, c_div = valid_area.height() - 1;
+
+ const long c_top = valid_area.top();
+ const long c_bottom = valid_area.bottom();
+
+ for ( long y = c_top; y <= c_bottom;y++ )
+ {
+
+ unsigned long c_s = y - c_top;
+
+ unsigned long c_t = c_bottom - y;
+
+
+ if ( c_div == 0 )
+ {
+ // only a single round, just take the average color
+ c_div = 2;
+ c_s = c_t = 1;
+ }
+
+ rgb_alpha_pixel color;
+ vector_to_pixel(color,
+ ((pixel_to_vector<unsigned long>(top_color)*c_t + pixel_to_vector<unsigned long>(bottom_color)*c_s)/c_div));
+
+ unsigned long s = y - rect.top();
+
+ unsigned long t = rect.bottom() - y;
+
+ if ( s < radius )
+ {
+ m = radius - square_root ( ( radius * radius - ( radius - s ) * ( radius - s ) ) * 4 ) / 2;
+
+ if ( s == m && m + 1 < m_prev ) // these are hacks to remove distracting artefacts at small radii
+ m++;
+ }
+ else if ( t < radius )
+ {
+ m = radius - square_root ( ( radius * radius - ( radius - t ) * ( radius - t ) ) * 4 ) / 2;
+
+ if ( t == m && m == m_prev )
+ m++;
+ }
+ else
+ {
+ m = 0;
+ }
+
+ m_prev = m;
+
+ draw_line ( c, point(rect.left() + m, y),
+ point(rect.right() - m, y), color, valid_area );
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void draw_rectangle (
+ const canvas& c,
+ rectangle rect,
+ const pixel_type& pixel,
+ const rectangle& area = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ // top line
+ draw_line(c, point(rect.left(),rect.top()),
+ point(rect.right(),rect.top()),
+ pixel, area);
+
+ // bottom line
+ draw_line(c, point(rect.left(),rect.bottom()),
+ point(rect.right(),rect.bottom()),
+ pixel, area);
+
+ // left line
+ draw_line(c, point(rect.left(),rect.top()),
+ point(rect.left(),rect.bottom()),
+ pixel, area);
+
+ // right line
+ draw_line(c, point(rect.right(),rect.top()),
+ point(rect.right(),rect.bottom()),
+ pixel, area);
+ }
+ inline void draw_rectangle (
+ const canvas& c,
+ rectangle rect
+ ){ draw_rectangle(c, rect, 0); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void fill_rect (
+ const canvas& c,
+ const rectangle& rect,
+ const pixel_type& pixel
+ )
+ {
+ rectangle area = rect.intersect(c);
+ for (long y = area.top(); y <= area.bottom(); ++y)
+ {
+ for (long x = area.left(); x <= area.right(); ++x)
+ {
+ assign_pixel(c[y-c.top()][x-c.left()], pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void fill_rect_with_vertical_gradient (
+ const canvas& c,
+ const rectangle& rect,
+ const pixel_type& pixel_top,
+ const pixel_type& pixel_bottom,
+ const rectangle& area_ = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ )
+ {
+ rectangle area = rect.intersect(c).intersect(area_);
+ pixel_type pixel;
+
+ const long s = rect.bottom()-rect.top();
+
+ for (long y = area.top(); y <= area.bottom(); ++y)
+ {
+ const long t = rect.bottom()-y;
+ const long b = y-rect.top();
+ vector_to_pixel(pixel,
+ ((pixel_to_vector<long>(pixel_top)*t +
+ pixel_to_vector<long>(pixel_bottom)*b)/s));
+
+ for (long x = area.left(); x <= area.right(); ++x)
+ {
+ assign_pixel(c[y-c.top()][x-c.left()], pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "canvas_drawing.cpp"
+#endif
+
+#endif // DLIB_GUI_CANVAS_DRAWINg_
+
diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h b/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h
new file mode 100644
index 000000000..e4a298c76
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h
@@ -0,0 +1,364 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_
+#ifdef DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_
+
+#include "../gui_core.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_line (
+ const canvas& c,
+ const point& p1,
+ const point& p2,
+ const pixel_type& pixel = rgb_pixel(0,0,0),
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - draws the part of the line from p1 to p1 that overlaps with
+ the canvas and area onto the canvas.
+ - Uses the given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_rectangle (
+ const canvas& c,
+ rectangle rect,
+ const pixel_type& pixel = rgb_pixel(0,0,0),
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Draws the part of the rectangle that overlaps with
+ the canvas and area onto the canvas.
+ - Uses the given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius,
+ const pixel_type& pixel = rgb_pixel(0,0,0),
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - draws the part of the circle centered at center_point with the given radius
+ that overlaps with the canvas and area onto the canvas.
+ - Uses the given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_pixel (
+ const canvas& c,
+ const point& p,
+ const pixel_type& pixel
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - if (c.contains(p)) then
+ - sets the pixel in c that represents the point p to the
+ given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_solid_circle (
+ const canvas& c,
+ const point& center_point,
+ double radius,
+ const pixel_type& pixel = rgb_pixel(0,0,0),
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - draws the part of the solid circle centered at center_point with the given
+ radius that overlaps with the canvas and area onto the canvas.
+ ("solid" means that the interior is also filled in with the given
+ pixel color)
+ - Uses the given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_solid_convex_polygon (
+ const canvas& c,
+ const std::vector<point>& polygon,
+ const pixel_type& pixel = rgb_pixel(0,0,0),
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Interprets the given std::vector polygon object as defining a convex polygon
+ shape. In particular, the polygon is given by taking the points and drawing
+ lines between them. That is, imagine drawing a line connecting polygon[i]
+ and polygon[(i+1)%polygon.size()], for all valid i, and then filling in the
+ interior of the polygon. That is what this function does.
+ - When drawing the polygon, only the part of the polygon which overlaps both
+ the given canvas and area rectangle is drawn.
+ - Uses the given pixel color to draw the polygon.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_down (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha = 255
+ );
+ /*!
+ requires
+ - 0 <= alpha <= 255
+ ensures
+ - draws the border of a button onto canvas c:
+ - the border will be that of a button that is depressed
+ - only the part of the border that overlaps with the canvas object
+ will be drawn.
+ - the border will be for the button whose area is defined by the
+ rectangle btn.
+ - performs alpha blending such that the button is drawn with full opacity
+ when alpha is 255 and fully transparent when alpha is 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_sunken_rectangle (
+ const canvas& c,
+ const rectangle& border,
+ unsigned char alpha = 255
+ );
+ /*!
+ requires
+ - 0 <= alpha <= 255
+ ensures
+ - draws a sunken rectangle around the given border.
+ (This is the type of border used for text_fields and
+ check_boxes and the like).
+ - performs alpha blending such that the rectangle is drawn with full opacity
+ when alpha is 255 and fully transparent when alpha is 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_button_up (
+ const canvas& c,
+ const rectangle& btn,
+ unsigned char alpha = 255
+ );
+ /*!
+ requires
+ - 0 <= alpha <= 255
+ ensures
+ - draws the border of a button onto canvas c:
+ - the border will be that of a button that is NOT depressed
+ - only the part of the border that overlaps with the canvas object
+ will be drawn.
+ - the border will be for the button whose area is defined by the
+ rectangle btn.
+ - performs alpha blending such that the button is drawn with full opacity
+ when alpha is 255 and fully transparent when alpha is 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_checkered (
+ const canvas& c,
+ const rectangle& area,
+ const pixel_type& pixel1,
+ const pixel_type& pixel2
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - fills the area on the given canvas defined by the rectangle area with a checkers
+ board pattern where every other pixel gets assigned either pixel1 or pixel2.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void draw_image (
+ const canvas& c
+ const point& p,
+ const image_type& image,
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_type::type> is defined
+ ensures
+ - draws the given image object onto the canvas such that the upper left corner of the
+ image will appear at the point p in the canvas's window. (note that the
+ upper left corner of the image is assumed to be the pixel image[0][0] and the
+ lower right corner of the image is assumed to be image[image.nr()-1][image.nc()-1])
+ - only draws the part of the image that overlaps with the area rectangle
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void draw_image (
+ const canvas& c,
+ const rectangle& rect,
+ const image_type& img,
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_type::type> is defined
+ ensures
+ - draws the given image object onto the canvas such that the upper left corner
+ of the image will appear at the point rect.tl_corner() in the canvas's window
+ and the lower right corner of the image will appear at rect.br_corner() in
+ the canvas's window. (note that the upper left corner of the image is
+ assumed to be the pixel image[0][0] and the lower right corner of the image
+ is assumed to be image[image.nr()-1][image.nc()-1])
+ - only draws the part of the image that overlaps with the area rectangle
+ - Uses nearest neighbor interpolation when the given rect isn't the same size
+ as the input image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void fill_rect (
+ const canvas& c,
+ const rectangle& rect,
+ const pixel_type& pixel
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - fills the area defined by rect in the given canvas with the given pixel color.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void fill_rect_with_vertical_gradient (
+ const canvas& c,
+ const rectangle& rect,
+ const pixel_type& pixel_top,
+ const pixel_type& pixel_bottom,
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - fills the rectangle defined by rect in the given canvas with the given colors.
+ The top of the area will have the pixel_top color and will slowly fade
+ towards the pixel_bottom color towards the bottom of rect.
+ - only draws the part of the image that overlaps with the area rectangle
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void fill_gradient_rounded (
+ const canvas& c,
+ const rectangle& rect,
+ unsigned long radius,
+ const pixel_type& top_color,
+ const pixel_type& bottom_color,
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Fills the region defined by rect in the given canvas with the given colors.
+ The top of the region will have the top_color color and will slowly fade
+ towards the bottom_color color towards the bottom of rect.
+ - The drawn rectangle will have rounded corners and with the amount of
+ - rounding given by the radius argument.
+ - only the part of this object that overlaps with area and the canvas
+ will be drawn on the canvas
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pixel_type
+ >
+ void draw_rounded_rectangle (
+ const canvas& c,
+ const rectangle& rect,
+ unsigned radius,
+ const pixel_type& color,
+ const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity)
+ );
+ /*!
+ requires
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Draws the part of the rectangle that overlaps with
+ the canvas onto the canvas.
+ - The drawn rectangle will have rounded corners and with the amount of
+ rounding given by the radius argument.
+ - Uses the given pixel color.
+ - only draws the part of the image that overlaps with the area rectangle
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_
+
diff --git a/ml/dlib/dlib/gui_widgets/drawable.cpp b/ml/dlib/dlib/gui_widgets/drawable.cpp
new file mode 100644
index 000000000..8cf114950
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/drawable.cpp
@@ -0,0 +1,544 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DRAWABLe_CPP_
+#define DLIB_DRAWABLe_CPP_
+
+#include "drawable.h"
+
+#include <algorithm>
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------- drawable_window object ------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ rgb_pixel drawable_window::
+ background_color (
+ ) const
+ {
+ auto_mutex M(wm);
+ return bg_color;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ set_background_color (
+ unsigned long red_,
+ unsigned long green_,
+ unsigned long blue_
+ )
+ {
+ wm.lock();
+ bg_color.red = red_;
+ bg_color.green = green_;
+ bg_color.blue = blue_;
+ wm.unlock();
+ // now repaint the window
+ unsigned long width,height;
+ get_size(width,height);
+ rectangle rect(0,0,width-1,height-1);
+ invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ paint (
+ const canvas& c
+ )
+ {
+ ++event_id;
+ c.fill(bg_color.red,bg_color.green,bg_color.blue);
+
+ widgets.reset();
+ while (widgets.move_next())
+ {
+ widgets.element().value().reset();
+ while (widgets.element().value().move_next())
+ {
+ // only dispatch a draw() call if this widget isn't hidden
+ if (widgets.element().value().element()->hidden == false &&
+ widgets.element().value().element()->event_id != event_id)
+ {
+ widgets.element().value().element()->event_id = event_id;
+ widgets.element().value().element()->draw(c);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_user_event (
+ void* p,
+ int i
+ )
+ {
+ drawable* d = static_cast<drawable*>(p);
+ if (widget_set.is_member(d))
+ {
+ d->on_user_event(i);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_window_moved(
+ )
+ {
+ ++event_id;
+ window_moved.reset();
+ while (window_moved.move_next())
+ {
+ if (window_moved.element()->event_id != event_id)
+ {
+ window_moved.element()->event_id = event_id;
+ window_moved.element()->on_window_moved();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_window_resized(
+ )
+ {
+ ++event_id;
+ window_resized.reset();
+ while (window_resized.move_next())
+ {
+ if (window_resized.element()->event_id != event_id)
+ {
+ window_resized.element()->event_id = event_id;
+ window_resized.element()->on_window_resized();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ ++event_id;
+ keyboard.reset();
+ while (keyboard.move_next())
+ {
+ if (keyboard.element()->event_id != event_id)
+ {
+ keyboard.element()->event_id = event_id;
+ keyboard.element()->on_keydown(key,is_printable,state);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_focus_gained (
+ )
+ {
+ ++event_id;
+ focus.reset();
+ while (focus.move_next())
+ {
+ if (focus.element()->event_id != event_id)
+ {
+ focus.element()->event_id = event_id;
+ focus.element()->on_focus_gained();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_focus_lost (
+ )
+ {
+ ++event_id;
+ focus.reset();
+ while (focus.move_next())
+ {
+ if (focus.element()->event_id != event_id)
+ {
+ focus.element()->event_id = event_id;
+ focus.element()->on_focus_lost();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ lastx = x;
+ lasty = y;
+
+ ++event_id;
+ mouse_click.reset();
+ while (mouse_click.move_next())
+ {
+ if (mouse_click.element()->event_id != event_id)
+ {
+ mouse_click.element()->event_id = event_id;
+ mouse_click.element()->on_mouse_down(btn,state,x,y,is_double_click);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ lastx = x;
+ lasty = y;
+
+ ++event_id;
+ mouse_click.reset();
+ while (mouse_click.move_next())
+ {
+ if (mouse_click.element()->event_id != event_id)
+ {
+ mouse_click.element()->event_id = event_id;
+ mouse_click.element()->on_mouse_up(btn,state,x,y);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ lastx = x;
+ lasty = y;
+
+ ++event_id;
+ mouse_move.reset();
+ while (mouse_move.move_next())
+ {
+ if (mouse_move.element()->event_id != event_id)
+ {
+ mouse_move.element()->event_id = event_id;
+ mouse_move.element()->on_mouse_move(state,x,y);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_mouse_leave (
+ )
+ {
+ lastx = -1;
+ lasty = -1;
+
+ ++event_id;
+ mouse_move.reset();
+ while (mouse_move.move_next())
+ {
+ if (mouse_move.element()->event_id != event_id)
+ {
+ mouse_move.element()->event_id = event_id;
+ mouse_move.element()->on_mouse_leave();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_mouse_enter (
+ )
+ {
+ ++event_id;
+ mouse_move.reset();
+ while (mouse_move.move_next())
+ {
+ if (mouse_move.element()->event_id != event_id)
+ {
+ mouse_move.element()->event_id = event_id;
+ mouse_move.element()->on_mouse_enter();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_wheel_up (
+ unsigned long state
+ )
+ {
+ ++event_id;
+ mouse_wheel.reset();
+ while (mouse_wheel.move_next())
+ {
+ if (mouse_wheel.element()->event_id != event_id)
+ {
+ mouse_wheel.element()->event_id = event_id;
+ mouse_wheel.element()->on_wheel_up(state);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_wheel_down (
+ unsigned long state
+ )
+ {
+ ++event_id;
+ mouse_wheel.reset();
+ while (mouse_wheel.move_next())
+ {
+ if (mouse_wheel.element()->event_id != event_id)
+ {
+ mouse_wheel.element()->event_id = event_id;
+ mouse_wheel.element()->on_wheel_down(state);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable_window::
+ on_string_put (
+ const std::wstring &str
+ )
+ {
+ ++event_id;
+ string_put.reset();
+ while (string_put.move_next())
+ {
+ if (string_put.element()->event_id != event_id)
+ {
+ string_put.element()->event_id = event_id;
+ string_put.element()->on_string_put(str);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------- drawable object ----------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void drawable::
+ enable_events (
+ )
+ {
+ auto_mutex M(m);
+ if (enabled_events == false)
+ {
+ enabled_events = true;
+ drawable* temp = this;
+ long zo = z_order_value;
+
+ drawable_window::set_of_drawables* sod = parent.widgets[zo];
+ if (sod == 0)
+ {
+ // this drawable is the first widget at this z order so we need
+ // to make its containing set
+ drawable_window::set_of_drawables s;
+ s.add(temp);
+ parent.widgets.add(zo,s);
+ }
+ else
+ {
+ sod->add(temp);
+ }
+
+ temp = this;
+ parent.widget_set.add(temp);
+
+ if (events & MOUSE_MOVE)
+ {
+ temp = this;
+ parent.mouse_move.add(temp);
+ }
+ if (events & MOUSE_CLICK)
+ {
+ temp = this;
+ parent.mouse_click.add(temp);
+ }
+ if (events & MOUSE_WHEEL)
+ {
+ temp = this;
+ parent.mouse_wheel.add(temp);
+ }
+ if (events & WINDOW_RESIZED)
+ {
+ temp = this;
+ parent.window_resized.add(temp);
+ }
+ if (events & KEYBOARD_EVENTS)
+ {
+ temp = this;
+ parent.keyboard.add(temp);
+ }
+ if (events & FOCUS_EVENTS)
+ {
+ temp = this;
+ parent.focus.add(temp);
+ }
+ if (events & WINDOW_MOVED)
+ {
+ temp = this;
+ parent.window_moved.add(temp);
+ }
+ if (events & STRING_PUT)
+ {
+ temp = this;
+ parent.string_put.add(temp);
+ }
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable::
+ set_z_order (
+ long order
+ )
+ {
+ auto_mutex M(m);
+ if (order != z_order_value && enabled_events)
+ {
+ // first remove this drawable from widgets
+ drawable_window::set_of_drawables* sod = parent.widgets[z_order_value];
+ drawable* junk;
+ sod->remove(this,junk);
+
+ // if there are no more drawables at this z order then destroy the
+ // set for this order
+ if (sod->size() == 0)
+ parent.widgets.destroy(z_order_value);
+
+ // now add this drawable to its new z order
+ sod = parent.widgets[order];
+ if (sod == 0)
+ {
+ // this drawable is the first widget at this z order so we need
+ // to make its containing set
+ drawable_window::set_of_drawables s, x;
+ s.add(junk);
+ long temp_order = order;
+ parent.widgets.add(temp_order,s);
+ }
+ else
+ {
+ sod->add(junk);
+ }
+ parent.invalidate_rectangle(rect);
+
+ }
+ z_order_value = order;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void drawable::
+ disable_events (
+ )
+ {
+ auto_mutex M(m);
+ if (enabled_events)
+ {
+ enabled_events = false;
+ // first remove this drawable from widgets
+ drawable_window::set_of_drawables* sod = parent.widgets[z_order_value];
+ drawable* junk;
+ sod->remove(this,junk);
+
+ // if there are no more drawables at this z order then destroy the
+ // set for this order
+ if (sod->size() == 0)
+ parent.widgets.destroy(z_order_value);
+
+ parent.widget_set.remove(this,junk);
+
+ // now unregister this drawable from all the events it has registered for.
+ if (events & MOUSE_MOVE)
+ parent.mouse_move.remove(this,junk);
+ if (events & MOUSE_CLICK)
+ parent.mouse_click.remove(this,junk);
+ if (events & MOUSE_WHEEL)
+ parent.mouse_wheel.remove(this,junk);
+ if (events & WINDOW_RESIZED)
+ parent.window_resized.remove(this,junk);
+ if (events & KEYBOARD_EVENTS)
+ parent.keyboard.remove(this,junk);
+ if (events & FOCUS_EVENTS)
+ parent.focus.remove(this,junk);
+ if (events & WINDOW_MOVED)
+ parent.window_moved.remove(this,junk);
+ if (events & STRING_PUT)
+ parent.string_put.remove(this,junk);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ drawable::
+ ~drawable (
+ )
+ {
+ try
+ {
+ DLIB_ASSERT(events_are_enabled() == false,
+ "\tdrawable::~drawable()"
+ << "\n\tYou must disable events for drawable objects in their destructor by calling disable_events()."
+ << "\n\tthis: " << this
+ );
+ }
+ catch (std::exception& e)
+ {
+ std::cerr << e.what() << std::endl;
+ assert(false);
+ abort();
+ }
+ disable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAWABLe_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/drawable.h b/ml/dlib/dlib/gui_widgets/drawable.h
new file mode 100644
index 000000000..a270b53c8
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/drawable.h
@@ -0,0 +1,527 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_DRAWABLe_
+#define DLIB_DRAWABLe_
+
+#include <memory>
+
+#include "drawable_abstract.h"
+#include "../gui_core.h"
+#include "../set.h"
+#include "../binary_search_tree.h"
+#include "../algs.h"
+#include "../pixel.h"
+#include "fonts.h"
+#include "../matrix.h"
+#include "canvas_drawing.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class drawable_window
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class drawable;
+ class drawable_window : public base_window
+ {
+ /*!
+ INITIAL VALUE
+ - lastx == -1
+ - lasty == -1
+ - event_id == 1
+
+ CONVENTION
+ - bg_color == background_color()
+
+ - widgets == this binary search tree contains every drawable that is in
+ this window. It is a mapping of each drawable's z-order to a pointer
+ to said drawable.
+ - widget_set == a set that contains all the widgets in this window and
+ want to receive events.
+
+ - mouse_move == this is a set of drawables that are in this window and
+ want to receive the mouse movement events.
+ - mouse_wheel == this is a set of drawables that are in this window and
+ want to receive the mouse wheel events.
+ - mouse_click == this is a set of drawables that are in this window and
+ want to receive the mouse click events.
+ - window_resized == this is a set of drawables that are in this window and
+ want to receive the window_resized event.
+ - keyboard == this is a set of drawables that are in this window and
+ want to receive keyboard events.
+ - focus == this is a set of drawables that are in this window and
+ want to receive focus events.
+ - window_moved == this is a set of drawables that are in this window and
+ want to receive window move events.
+
+ - lastx == the x coordinate that we last saw the mouse at or -1 if the
+ mouse is outside this window.
+ - lasty == the y coordinate that we last saw the mouse at or -1 if the
+ mouse is outside this window.
+
+ - event_id == a number we use to tag events so we don't end up sending
+ an event to a drawable more than once. This could happen if one of the
+ event handlers does something to reset the enumerator while we are
+ dispatching events (e.g. creating a new widget).
+ !*/
+ public:
+
+ drawable_window(
+ bool resizable = true,
+ bool undecorated = false
+ ) :
+ base_window(resizable,undecorated),
+ bg_color(rgb_pixel(212,208,200)),
+ lastx(-1),
+ lasty(-1),
+ event_id(1)
+ {}
+
+ void set_background_color (
+ unsigned long red,
+ unsigned long green,
+ unsigned long blue
+ );
+
+ rgb_pixel background_color (
+ ) const;
+
+ virtual inline ~drawable_window()=0;
+
+ private:
+
+ void paint (
+ const canvas& c
+ );
+
+ protected:
+
+ void on_window_resized(
+ );
+
+ void on_window_moved(
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_leave (
+ );
+
+ void on_mouse_enter (
+ );
+
+ void on_wheel_up (
+ unsigned long state
+ );
+
+ void on_wheel_down (
+ unsigned long state
+ );
+
+ void on_focus_gained (
+ );
+
+ void on_focus_lost (
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_string_put (
+ const std::wstring &str
+ );
+
+ void on_user_event (
+ void* p,
+ int i
+ );
+
+ private:
+
+ friend class drawable;
+
+
+ rgb_pixel bg_color;
+
+ typedef set<drawable*>::kernel_1a_c set_of_drawables;
+
+ binary_search_tree<long,set_of_drawables>::kernel_1a_c widgets;
+
+ set_of_drawables widget_set;
+ set_of_drawables mouse_move;
+ set_of_drawables mouse_wheel;
+ set_of_drawables mouse_click;
+ set_of_drawables window_resized;
+ set_of_drawables keyboard;
+ set_of_drawables focus;
+ set_of_drawables window_moved;
+ set_of_drawables string_put;
+
+ long lastx, lasty;
+ unsigned long event_id;
+
+
+ // restricted functions
+ drawable_window(drawable_window&); // copy constructor
+ drawable_window& operator=(drawable_window&); // assignment operator
+
+
+ };
+
+ drawable_window::~drawable_window(){ close_window();}
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class drawable
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum
+ {
+ MOUSE_MOVE = 1,
+ MOUSE_CLICK = 2,
+ MOUSE_WHEEL = 4,
+ WINDOW_RESIZED = 8,
+ KEYBOARD_EVENTS = 16,
+ FOCUS_EVENTS = 32,
+ WINDOW_MOVED = 64,
+ STRING_PUT = 128
+ };
+
+ class drawable
+ {
+
+ /*!
+ INITIAL VALUE
+ - enabled_events == false
+ - event_id == 0
+
+ CONVENTION
+ - events == a bitset specifying what events this drawable is to receive.
+
+ - z_order_value == z_order()
+
+ - if (this drawable has been added to the parent window's sets and
+ binary search tree) then
+ - enabled_events == true
+ - else
+ - enabled_events == false
+
+ - event_id == the id of the last event we got from our parent window
+ !*/
+
+ public:
+
+ friend class drawable_window;
+
+ drawable (
+ drawable_window& w,
+ unsigned long events_ = 0
+ ) :
+ m(w.wm),
+ parent(w),
+ hidden(false),
+ enabled(true),
+ lastx(w.lastx),
+ lasty(w.lasty),
+ mfont(default_font::get_font()),
+ z_order_value(0),
+ events(events_),
+ enabled_events(false),
+ event_id(0)
+ {}
+
+ virtual ~drawable (
+ );
+
+ long z_order (
+ ) const
+ {
+ m.lock();
+ long temp = z_order_value;
+ m.unlock();
+ return temp;
+ }
+
+ virtual void set_z_order (
+ long order
+ );
+
+ const rectangle get_rect (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect;
+ }
+
+ long bottom (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.bottom();
+ }
+
+ long top (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.top();
+ }
+
+ long left (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.left();
+ }
+
+ long right (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.right();
+ }
+
+ long width (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.width();
+ }
+
+ long height (
+ ) const
+ {
+ auto_mutex M(m);
+ return rect.height();
+ }
+
+ bool is_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return enabled;
+ }
+
+ virtual void enable (
+ )
+ {
+ auto_mutex M(m);
+ enabled = true;
+ parent.invalidate_rectangle(rect);
+ }
+
+ virtual void disable (
+ )
+ {
+ auto_mutex M(m);
+ enabled = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+ virtual void set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ parent.invalidate_rectangle(rect);
+ }
+
+ const std::shared_ptr<font> main_font (
+ ) const
+ {
+ auto_mutex M(m);
+ return mfont;
+ }
+
+ bool is_hidden (
+ ) const
+ {
+ auto_mutex M(m);
+ return hidden;
+ }
+
+ virtual void set_pos (
+ long x,
+ long y
+ )
+ {
+ m.lock();
+ rectangle old(rect);
+
+ const unsigned long width = rect.width();
+ const unsigned long height = rect.height();
+ rect.set_top(y);
+ rect.set_left(x);
+ rect.set_right(static_cast<long>(x+width)-1);
+ rect.set_bottom(static_cast<long>(y+height)-1);
+
+ parent.invalidate_rectangle(rect+old);
+ m.unlock();
+ }
+
+ virtual void show (
+ )
+ {
+ m.lock();
+ hidden = false;
+ parent.invalidate_rectangle(rect);
+ m.unlock();
+ }
+
+ virtual void hide (
+ )
+ {
+ m.lock();
+ hidden = true;
+ parent.invalidate_rectangle(rect);
+ m.unlock();
+ }
+
+ base_window& parent_window (
+ ) { return parent; }
+
+ const base_window& parent_window (
+ ) const { return parent; }
+
+ virtual int next_free_user_event_number (
+ )const { return 0; }
+
+ protected:
+ rectangle rect;
+ const rmutex& m;
+ drawable_window& parent;
+ bool hidden;
+ bool enabled;
+ const long& lastx;
+ const long& lasty;
+ std::shared_ptr<font> mfont;
+
+
+ void enable_events (
+ );
+
+ bool events_are_enabled (
+ ) const { auto_mutex M(m); return enabled_events; }
+
+ void disable_events (
+ );
+
+ private:
+
+ long z_order_value;
+ const unsigned long events;
+ bool enabled_events;
+ unsigned long event_id;
+
+
+ // restricted functions
+ drawable(drawable&); // copy constructor
+ drawable& operator=(drawable&); // assignment operator
+
+
+ protected:
+
+ virtual void draw (
+ const canvas& c
+ ) const=0;
+
+ virtual void on_user_event (
+ int
+ ){}
+
+ virtual void on_window_resized(
+ ){}
+
+ virtual void on_window_moved(
+ ){}
+
+ virtual void on_mouse_down (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long ,
+ bool
+ ){}
+
+ virtual void on_mouse_up (
+ unsigned long ,
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_move (
+ unsigned long ,
+ long ,
+ long
+ ){}
+
+ virtual void on_mouse_leave (
+ ){}
+
+ virtual void on_mouse_enter (
+ ){}
+
+ virtual void on_wheel_up (
+ unsigned long
+ ){}
+
+ virtual void on_wheel_down (
+ unsigned long
+ ){}
+
+ virtual void on_focus_gained (
+ ){}
+
+ virtual void on_focus_lost (
+ ){}
+
+ virtual void on_keydown (
+ unsigned long ,
+ bool ,
+ unsigned long
+ ){}
+
+ virtual void on_string_put (
+ const std::wstring&
+ ){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "drawable.cpp"
+#endif
+
+#endif // DLIB_DRAWABLe_
+
diff --git a/ml/dlib/dlib/gui_widgets/drawable_abstract.h b/ml/dlib/dlib/gui_widgets/drawable_abstract.h
new file mode 100644
index 000000000..8f741d8bb
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/drawable_abstract.h
@@ -0,0 +1,717 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#undef DLIB_DRAWABLe_ABSTRACT_
+#ifdef DLIB_DRAWABLe_ABSTRACT_
+
+#include "../gui_core.h"
+#include "fonts_abstract.h"
+#include "canvas_drawing_abstract.h"
+
+namespace dlib
+{
+
+ /*!
+ GENERAL REMARKS
+ This file defines the drawable interface class and the drawable_window which
+ is just a window that is capable of displaying drawable objects (i.e. objects
+ that implement the drawable interface).
+
+ The drawable interface is a simple framework for creating more complex
+ graphical widgets. It provides a default set of functionality and a
+ set of events which a gui widget may use.
+
+ THREAD SAFETY
+ All objects and functions defined in this file are thread safe. You may
+ call them from any thread without serializing access to them.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class drawable_window
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class drawable_window : public base_window
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a window on the desktop that is capable of
+ containing drawable objects.
+
+ INITIAL STATE
+ The initial state of the drawable_window is to be hidden. This means
+ you need to call show() to make it appear.
+
+ EVENTS
+ The drawable_window object uses all the events provided by base_window
+ except for the on_window_close() event. This means that if you
+ define handlers for these events yourself you will have to call
+ the drawable_window's version of them so that the drawable_window
+ can continue to process and forward these events to its drawable
+ objects.
+ !*/
+ public:
+
+ drawable_window (
+ bool resizable = true,
+ bool undecorated = false
+ );
+ /*!
+ requires
+ - if (undecorated == true) then
+ - resizable == false
+ ensures
+ - #*this has been properly initialized
+ - #background_color() == rgb_pixel(212,208,200)
+ - if (resizable == true) then
+ - this window will be resizable by the user
+ - else
+ - this window will not be resizable by the user
+ - if (undecorated == true) then
+ - this window will not have any graphical elements outside
+ of its drawable area or appear in the system task bar.
+ (e.g. a popup menu)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ - dlib::gui_error
+ This exception is thrown if there is an error while
+ creating this window.
+ !*/
+
+ virtual ~drawable_window(
+ )=0;
+ /*!
+ ensures
+ - if (this window has not already been closed) then
+ - closes the window
+ - does NOT trigger the on_window_close() event
+ - all resources associated with *this have been released
+ !*/
+
+ void set_background_color (
+ unsigned long red,
+ unsigned long green,
+ unsigned long blue
+ );
+ /*!
+ ensures
+ - #background_color().red == red
+ - #background_color().green == green
+ - #background_color().blue == blue
+ !*/
+
+ rgb_pixel background_color (
+ ) const;
+ /*!
+ ensures
+ - returns the background color this window paints its canvas
+ with before it passes it onto its drawable widgets
+ !*/
+
+ private:
+ // restricted functions
+ drawable_window(drawable_window&); // copy constructor
+ drawable_window& operator=(drawable_window&); // assignment operator
+
+ friend class drawable;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class drawable
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum
+ {
+ MOUSE_MOVE = 1,
+ MOUSE_CLICK = 2,
+ MOUSE_WHEEL = 4,
+ WINDOW_RESIZED = 8,
+ KEYBOARD_EVENTS = 16,
+ FOCUS_EVENTS = 32,
+ WINDOW_MOVED = 64,
+ STRING_PUT = 128
+ };
+
+ class drawable
+ {
+ /*!
+ INITIAL VALUE
+ top() == 0
+ left() == 0
+ right() == -1
+ bottom() == -1
+ get_rect().is_empty() == true
+ is_hidden() == false
+ is_enabled() == true
+ z_order() == 0
+ main_font() == default_font::get_font()
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an interface that all drawable widgets implement. It
+ provides a standard method (draw()) to draw a widget onto a canvas
+ and many other convenient functions for drawable objects.
+
+ EVENT FORWARDING
+ All the events that come to a drawable object are forwarded from its
+ parent window. Additionally, there is no filtering. This means that
+ if a drawable registers to receive a certain kind of event then whenever
+ its parent window receives that event the drawable object will get a
+ forwarded copy of it as well even if the event occurred outside the
+ drawable's rectangle.
+
+ The only events that have anything in the way of filtering are the
+ draw() and on_user_event() events. draw() is only called on a drawable
+ object when that object is not hidden. on_user_event() is only called
+ for drawables that the on_user_event()'s first argument specifically
+ references. All other events are not filtered at all though.
+
+ Z ORDER
+ Z order defines the order in which drawable objects are drawn. The
+ lower numbered drawables are drawn first and then the higher numbered
+ ones. So a drawable with a z order of 0 is drawn before one with a
+ z order of 1 and so on.
+ !*/
+
+ public:
+
+ friend class drawable_window;
+
+ drawable (
+ drawable_window& w,
+ unsigned long events = 0
+ ) :
+ m(w.wm),
+ parent(w),
+ hidden(false),
+ enabled(true)
+ {}
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #parent_window() == w
+ - #*this will not receive any events or draw() requests until
+ enable_events() is called
+ - once events_are_enabled() == true this drawable will receive
+ the on_user_event() event. (i.e. you always get this event, you don't
+ have to enable it by setting something in the events bitset).
+ - if (events & MOUSE_MOVE) then
+ - once events_are_enabled() == true this drawable will receive
+ the following events related to mouse movement: on_mouse_move,
+ on_mouse_leave, and on_mouse_enter.
+ - if (events & MOUSE_CLICK) then
+ - once events_are_enabled() == true this drawable will receive
+ the following events related to mouse clicks: on_mouse_down and
+ on_mouse_up.
+ - if (events & MOUSE_WHEEL) then
+ - once events_are_enabled() == true this drawable will receive
+ the following events related to mouse wheel scrolling:
+ on_wheel_up and on_wheel_down.
+ - if (events & WINDOW_RESIZED) then
+ - once events_are_enabled() == true this drawable will receive
+ the following event related to its parent window resizing:
+ on_window_resized.
+ - if (events & KEYBOARD_EVENTS) then
+ - once events_are_enabled() == true this drawable will receive
+ the following keyboard event: on_keydown.
+ - if (events & FOCUS_EVENTS) then
+ - once events_are_enabled() == true this drawable will receive
+ the following focus events: on_focus_gained and on_focus_lost.
+ - if (events & WINDOW_MOVED) then
+ - once events_are_enabled() == true this drawable will receive
+ the following event related to its parent window moving:
+ on_window_moved.
+ - if (events & STRING_PUT) then
+ - once events_are_enabled() == true this drawable will receive
+ the following event related to wide character string input:
+ on_string_put.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~drawable (
+ );
+ /*!
+ requires
+ - events_are_enabled() == false
+ ensures
+ - any resources associated with *this have been released
+ - *this has been removed from its containing window parent_window() and
+ its parent window will no longer try to dispatch events to it.
+ Note that this does not trigger a redraw of the parent window. If you
+ want to do that you must do it yourself.
+ !*/
+
+ long z_order (
+ ) const;
+ /*!
+ ensures
+ - returns the z order for this drawable.
+ !*/
+
+ virtual void set_z_order (
+ long order
+ );
+ /*!
+ ensures
+ - #z_order() == order
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ throws
+ - std::bad_alloc
+ !*/
+
+ const rectangle get_rect (
+ ) const;
+ /*!
+ ensures
+ - returns the rectangle that defines the area and position of this
+ drawable inside its containing window parent_window().
+ !*/
+
+ long bottom (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().bottom()
+ !*/
+
+ long top (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().top()
+ !*/
+
+ long left (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().left()
+ !*/
+
+ long right (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().right()
+ !*/
+
+ unsigned long width (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().width()
+ !*/
+
+ unsigned long height (
+ ) const;
+ /*!
+ ensures
+ - returns get_rect().height()
+ !*/
+
+ virtual void set_pos (
+ long x,
+ long y
+ );
+ /*!
+ ensures
+ - #top() == y
+ - #left() == x
+ - #width() == width()
+ - #height() == height()
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ - i.e. This just sets the upper left corner of this drawable to the
+ location (x,y)
+ !*/
+
+ bool is_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object is enabled and false otherwise.
+ (it is up to derived classes to define exactly what it means to be
+ "enabled")
+ !*/
+
+ virtual void enable (
+ );
+ /*!
+ ensures
+ - #is_enabled() == true
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ !*/
+
+ virtual void disable (
+ );
+ /*!
+ ensures
+ - #is_enabled() == false
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ !*/
+
+ virtual void set_main_font (
+ const shared_ptr_thread_safe<font>& f
+ );
+ /*!
+ ensures
+ - #main_font() == f
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ !*/
+
+ const shared_ptr_thread_safe<font> main_font (
+ ) const;
+ /*!
+ ensures
+ - returns the current main font being used by this widget
+ !*/
+
+ bool is_hidden (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object is NOT currently displayed on parent_window()
+ and false otherwise.
+ !*/
+
+ virtual void show (
+ );
+ /*!
+ ensures
+ - #is_hidden() == false
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ !*/
+
+ virtual void hide (
+ );
+ /*!
+ ensures
+ - #is_hidden() == true
+ - if (events_are_enabled() == true) then
+ - parent_window() is updated to reflect the new state of #*this
+ !*/
+
+ drawable_window& parent_window (
+ );
+ /*!
+ ensures
+ - returns a reference to the drawable_window that this drawable is
+ being drawn on and receiving events from.
+ !*/
+
+ const drawable_window& parent_window (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the drawable_window that this drawable
+ is being drawn on and receiving events from.
+ !*/
+
+ virtual int next_free_user_event_number (
+ )const { return 0; }
+ /*!
+ ensures
+ - returns the smallest number, i, that is the next user event number you
+ can use in calls to parent.trigger_user_event((void*)this,i).
+ - This function exists because of the following scenario. Suppose
+ you make a class called derived1 that inherits from drawable and
+ in derived1 you use a user event to do something. Then suppose
+ you inherit from derived1 to make derived2. Now in derived2 you
+ may want to use a user event to do something as well. How are you
+ to know which user event numbers are in use already? This function
+ solves that problem. You would define derived1::next_free_user_event_number()
+ so that it returned a number bigger than any user event numbers used by
+ derived1 or its ancestors. Then derived2 could just call
+ derived1::next_free_user_event_number() to find out what numbers it could use.
+ !*/
+
+ protected:
+ /*!A drawable_protected_variables
+
+ These protected members are provided because they are needed to
+ implement drawable widgets.
+ !*/
+
+ // This is the rectangle that is returned by get_rect()
+ rectangle rect;
+
+ // This is the mutex used to serialize access to this class.
+ const rmutex& m;
+
+ // This is the parent window of this drawable
+ drawable_window& parent;
+
+ // This is the bool returned by is_hidden()
+ bool hidden;
+
+ // This is the bool returned by is_enabled()
+ bool enabled;
+
+ // This is the font pointer returned by main_font()
+ shared_ptr_thread_safe<font> mfont;
+
+ // This is the x coordinate that we last saw the mouse at or -1 if the mouse
+ // is outside the parent window.
+ const long& lastx;
+
+ // This is the y coordinate that we last saw the mouse at or -1 if the mouse
+ // is outside the parent window.
+ const long& lasty;
+
+
+ void enable_events (
+ );
+ /*!
+ ensures
+ - #events_are_enabled() == true
+ !*/
+
+ void disable_events (
+ );
+ /*!
+ ensures
+ - #events_are_enabled() == false
+ !*/
+
+ bool events_are_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object is receiving events and draw()
+ requests from its parent window.
+ - returns false otherwise
+ !*/
+
+ // ---------------- EVENT HANDLERS ------------------
+
+ virtual void on_user_event (
+ int i
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - is called whenever the parent window receives an on_user_event(p,i) event
+ where p == this. (i.e. this is just a redirect of on_user_event for
+ cases where the first argument of on_user_event is equal to the
+ this pointer).
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_window_resized(
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_window_resized() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_window_moved(
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_window_moved() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_mouse_down() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_mouse_up() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - x == lastx
+ - y == lasty
+ - this is just the base_window::on_mouse_move() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_mouse_leave (
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_mouse_leave() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_mouse_enter (
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_mouse_enter() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_wheel_up (
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_wheel_up() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_wheel_down (
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_wheel_down() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_focus_gained (
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_focus_gained() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_focus_lost (
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_focus_lost() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_keydown() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void on_string_put (
+ const std::wstring &str
+ ){}
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - this is just the base_window::on_put_string() event forwarded to
+ this object. See the gui_core specs for the details about this event.
+ ensures
+ - does not change the state of mutex m.
+ !*/
+
+ virtual void draw (
+ const canvas& c
+ ) const=0;
+ /*!
+ requires
+ - events_are_enabled() == true
+ - mutex m is locked
+ - is_hidden() == false
+ - is called by parent_window() when it needs to repaint itself.
+ - c == the canvas object for the area of parent_window() that needs
+ to be repainted.
+ ensures
+ - does not change the state of mutex m.
+ - draws the area of *this that intersects with the canvas onto
+ the canvas object c.
+ !*/
+
+ private:
+
+ // restricted functions
+ drawable(drawable&); // copy constructor
+ drawable& operator=(drawable&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAWABLe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/gui_widgets/fonts.cpp b/ml/dlib/dlib/gui_widgets/fonts.cpp
new file mode 100644
index 000000000..dfbf9f720
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/fonts.cpp
@@ -0,0 +1,673 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt, Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FONTs_CPP_
+#define DLIB_FONTs_CPP_
+
+#include "fonts.h"
+
+#include <fstream>
+#include <memory>
+#include <sstream>
+
+#include "../serialize.h"
+#include "../base64.h"
+#include "../compress_stream.h"
+#include "../tokenizer.h"
+#include "nativefont.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string get_decoded_string_with_default_font_data()
+ {
+ dlib::base64::kernel_1a base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ /*
+ SOURCE BDF FILE (helvR12.bdf) COMMENTS
+ COMMENT $XConsortium: helvR12.bdf,v 1.15 95/01/26 18:02:58 gildea Exp $
+ COMMENT $Id: helvR12.bdf,v 1.26 2004-11-28 20:08:46+00 mgk25 Rel $
+ COMMENT
+ COMMENT +
+ COMMENT Copyright 1984-1989, 1994 Adobe Systems Incorporated.
+ COMMENT Copyright 1988, 1994 Digital Equipment Corporation.
+ COMMENT
+ COMMENT Adobe is a trademark of Adobe Systems Incorporated which may be
+ COMMENT registered in certain jurisdictions.
+ COMMENT Permission to use these trademarks is hereby granted only in
+ COMMENT association with the images described in this file.
+ COMMENT
+ COMMENT Permission to use, copy, modify, distribute and sell this software
+ COMMENT and its documentation for any purpose and without fee is hereby
+ COMMENT granted, provided that the above copyright notices appear in all
+ COMMENT copies and that both those copyright notices and this permission
+ COMMENT notice appear in supporting documentation, and that the names of
+ COMMENT Adobe Systems and Digital Equipment Corporation not be used in
+ COMMENT advertising or publicity pertaining to distribution of the software
+ COMMENT without specific, written prior permission. Adobe Systems and
+ COMMENT Digital Equipment Corporation make no representations about the
+ COMMENT suitability of this software for any purpose. It is provided "as
+ COMMENT is" without express or implied warranty.
+ COMMENT -
+ */
+
+ // The base64 encoded data we want to decode and return.
+ sout << "AXF+zOQzCgGitrKiOCGEL4hlIv1ZenWJyjMQ4rJ6f/oPMeHqsZn+8XnpehwFQTz3dtUGlZRAUoOa";
+ sout << "uVo8UiplcFxuK69A+94rpMCMAyEeeOwZ/tRzkX4eKuU3L4xtsJDknMiYUNKaMrYimb1QJ0E+SRqQ";
+ sout << "wATrMTecYNZvJJm02WibiwE4cJ5scvkHNl4KJT5QfdwRdGopTyUVdZvRvtbTLLjsJP0fQEQLqemf";
+ sout << "qPE4kDD79ehrBIwLO1Y6TzxtrrIoQR57zlwTUyLenqRtSN3VLtjWYd82cehRIlTLtuxBg2s+zZVq";
+ sout << "jNlNnYTSM+Swy06qnQgg+Dt0lhtlB9shR1OAlcfCtTW6HKoBk/FGeDmjTGW4bNCGv7RjgM6TlLDg";
+ sout << "ZYSSA6ZCCAKBgE++U32gLHCCiVkPTkkp9P6ioR+e3SSKRNm9p5MHf+ZQ3LJkW8KFJ/K9gKT1yvyv";
+ sout << "F99pAvOOq16tHRFvzBs+xZj/mUpH0lGIS7kLWr9oP2KuccVrz25aJn3kDruwTYoD+CYlOqtPO0Mv";
+ sout << "dEI0LUR0Ykp1M2rWo76fJ/fpzHjV7737hjkNPJ13nO72RMDr4R5V3uG7Dw7Ng+vGX3WgJZ4wh1JX";
+ sout << "pl2VMqC5JXccctzvnQvnuvBvRm7THgwQUgMKKT3WK6afUUVlJy8DHKuU4k1ibfVMxAmrwKdTUX2w";
+ sout << "cje3A05Qji3aop65qEdwgI5O17HIVoRQOG/na+XRMowOfUvI4H8Z4+JGACfRrQctgYDAM9eJzm8i";
+ sout << "PibyutmJfZBGg0a3oC75S5R9lTxEjPocnEyJRYNnmVnVAmKKbTbTsznuaD+D1XhPdr2t3A4bRTsp";
+ sout << "toKKtlFnd9YGwLWwONDwLnoQ/IXwyF7txrRHNSVToh772U0Aih/yn5vnmcMF750eiMzRAgXu5sbR";
+ sout << "VXEOVCiLgVevN5umkvjZt1eGTSSzDMrIvnv4nyOfaFsD+I76wQfgLqd71rheozGtjNc0AOTx4Ggc";
+ sout << "eUSFHTDAVfTExBzckurtyuIAqF986a0JLHCtsDpBa2wWNuiQYOH3/LX1zkdU2hdamhBW774bpEwr";
+ sout << "dguMxxOeDGOBgIlM5gxXGYXSf5IN3fUAEPfOPRxB7T+tpjFnWd7cg+JMabci3zhJ9ANaYT7HGeTX";
+ sout << "bulKnGHjYrR1BxdK3YeliogQRU4ytmxlyL5zlNFU/759mA8XSfIPMEZn9Vxkb00q1htF7REiDcr3";
+ sout << "kW1rtPAc7VQNEhT54vK/YF6rMvjO7kBZ/vLYo7E8e8hDKEnY8ucrC3KGmeo31Gei74BBcEbvJBd3";
+ sout << "/YAaIKgXWwU2wSUw9wLq2RwGwyguvKBx0J/gn27tjcVAHorRBwxzPpk8r+YPyN+SifSzEL7LEy1G";
+ sout << "lPHxmXTrcqnH9qraeAqXJUJvU8SJJpf/tmsAE+XSKD/kpVBnT5qXsJ1SRFS7MtfPjE1j/NYbaQBI";
+ sout << "bOrh81zaYCEJR0IKHWCIsu/MC3zKXfkxFgQ9XpYAuWjSSK64YpgkxSMe8VG8yYvigOw2ODg/z4FU";
+ sout << "+HpnEKF/M/mKfLKK1i/8BV7xcYVHrhEww1QznoFklJs/pEg3Kd5PE1lRii6hvTn6McVAkw+YbH9q";
+ sout << "/sg4gFIAvai64hMcZ1oIZYppj3ZN6KMdyhK5s4++ZS/YOV2nNhW73ovivyi2Tjg7lxjJJtsYrLKb";
+ sout << "zIN1slOICKYwBq42TFBcFXaZ6rf0Czd09tL+q6A1Ztgr3BNuhCenjhWN5ji0LccGYZo6bLTggRG/";
+ sout << "Uz6K3CBBU/byLs79c5qCohrr7rlpDSdbuR+aJgNiWoU6T0i2Tvua6h51LcWEHy5P2n146/Ae2di4";
+ sout << "eh20WQvclrsgm1oFTGD0Oe85GKOTA7vvwKmLBc1wwA0foTuxzVgj0TMTFBiYLTLG4ujUyBYy1N6e";
+ sout << "H8EKi8H+ZAlqezrjABO3BQr33ewdZL5IeJ4w7gdGUDA6+P+7cODcBW50X9++6YTnKctuEw6aXBpy";
+ sout << "GgcMfPE61G8YKBbFGFic3TVvGCLvre1iURv+F+hU4/ee6ILuPnpYnSXX2iCIK/kmkBse8805d4Qe";
+ sout << "DG/8rBW9ojvAgc0jX7CatPEMHGkcz+KIZoKMI7XXK4PJpGQUdq6EdIhJC4koXEynjwwXMeC+jJqH";
+ sout << "agwrlDNssq/8AA==";
+
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+
+ default_font::
+ default_font (
+ )
+ {
+ using namespace std;
+ l = new letter[256];
+
+ try
+ {
+ istringstream sin(get_decoded_string_with_default_font_data());
+
+ for (int i = 0; i < 256; ++i)
+ {
+ deserialize(l[i],sin);
+ }
+
+ }
+ catch (...)
+ {
+ delete [] l;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const letter& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.w,out);
+ serialize(item.count,out);
+
+ for (unsigned long i = 0; i < item.count; ++i)
+ {
+ serialize(item.points[i].x,out);
+ serialize(item.points[i].y,out);
+ }
+ }
+ catch (serialization_error e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type letter");
+ }
+ }
+
+ void deserialize (
+ letter& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ if (item.points)
+ delete [] item.points;
+
+ deserialize(item.w,in);
+ deserialize(item.count,in);
+
+ if (item.count > 0)
+ item.points = new letter::point[item.count];
+ else
+ item.points = 0;
+
+ for (unsigned long i = 0; i < item.count; ++i)
+ {
+ deserialize(item.points[i].x,in);
+ deserialize(item.points[i].y,in);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.w = 0;
+ item.count = 0;
+ item.points = 0;
+ throw serialization_error(e.info + "\n while deserializing object of type letter");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace bdf_font_helpers
+ {
+ class bdf_parser
+ {
+ public:
+ bdf_parser( std::istream& in ) : in_( in )
+ {
+ std::string str_tmp;
+ int int_tmp;
+
+ str_tmp = "STARTFONT"; int_tmp = STARTFONT; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "FONTBOUNDINGBOX";int_tmp = FONTBOUNDINGBOX; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "DWIDTH"; int_tmp = DWIDTH; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "CHARS"; int_tmp = CHARS; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "STARTCHAR"; int_tmp = STARTCHAR; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "ENCODING"; int_tmp = ENCODING; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "BBX"; int_tmp = BBX; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "BITMAP"; int_tmp = BITMAP; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "ENDCHAR"; int_tmp = ENDCHAR; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "ENDFONT"; int_tmp = ENDFONT; keyword_map.add( str_tmp, int_tmp );
+ str_tmp = "DEFAULT_CHAR"; int_tmp = DEFAULT_CHAR; keyword_map.add( str_tmp, int_tmp );
+
+ tokzr.set_identifier_token( tokzr.uppercase_letters(), tokzr.uppercase_letters() + "_" );
+ tokzr.set_stream( in );
+
+ }
+
+ enum bdf_enums
+ {
+ NO_KEYWORD = 0,
+ STARTFONT = 1,
+ FONTBOUNDINGBOX = 2,
+ DWIDTH = 4,
+ DEFAULT_CHAR = 8,
+ CHARS = 16,
+ STARTCHAR = 32,
+ ENCODING = 64,
+ BBX = 128,
+ BITMAP = 256,
+ ENDCHAR = 512,
+ ENDFONT = 1024
+
+ };
+ struct header_info
+ {
+ int FBBx, FBBy, Xoff, Yoff;
+ int dwx0, dwy0;
+ bool has_global_dw;
+ long default_char;
+ };
+ struct char_info
+ {
+ int dwx0, dwy0;
+ int BBw, BBh, BBxoff0x, BByoff0y;
+ array2d<char> bitmap;
+ bool has_dw;
+ };
+ bool parse_header( header_info& info )
+ {
+ if ( required_keyword( STARTFONT ) == false )
+ return false; // parse_error: required keyword missing
+ info.has_global_dw = false;
+ int find = FONTBOUNDINGBOX | DWIDTH | DEFAULT_CHAR;
+ int stop = CHARS | STARTCHAR | ENCODING | BBX | BITMAP | ENDCHAR | ENDFONT;
+ int res;
+ while ( 1 )
+ {
+ res = find_keywords( find | stop );
+ if ( res & FONTBOUNDINGBOX )
+ {
+ in_ >> info.FBBx >> info.FBBy >> info.Xoff >> info.Yoff;
+ if ( in_.fail() )
+ return false; // parse_error
+ find &= ~FONTBOUNDINGBOX;
+ continue;
+ }
+ if ( res & DWIDTH )
+ {
+ in_ >> info.dwx0 >> info.dwy0;
+ if ( in_.fail() )
+ return false; // parse_error
+ find &= ~DWIDTH;
+ info.has_global_dw = true;
+ continue;
+ }
+ if ( res & DEFAULT_CHAR )
+ {
+ in_ >> info.default_char;
+ if ( in_.fail() )
+ return false; // parse_error
+ find &= ~DEFAULT_CHAR;
+ continue;
+ }
+ if ( res & NO_KEYWORD )
+ return false; // parse_error: unexpected EOF
+ break;
+ }
+ if ( res != CHARS || ( find & FONTBOUNDINGBOX ) )
+ return false; // parse_error: required keyword missing or unexpeced keyword
+ return true;
+ }
+ int parse_glyph( char_info& info, unichar& enc )
+ {
+ info.has_dw = false;
+ int e;
+ int res;
+ while ( 1 )
+ {
+ res = find_keywords( ENCODING );
+ if ( res != ENCODING )
+ return 0; // no more glyphs
+ in_ >> e;
+ if ( in_.fail() )
+ return -1; // parse_error
+ if ( e >= static_cast<int>(enc) )
+ break;
+ }
+ int find = BBX | DWIDTH;
+ int stop = STARTCHAR | ENCODING | BITMAP | ENDCHAR | ENDFONT;
+ while ( 1 )
+ {
+ res = find_keywords( find | stop );
+ if ( res & BBX )
+ {
+ in_ >> info.BBw >> info.BBh >> info.BBxoff0x >> info.BByoff0y;
+ if ( in_.fail() )
+ return -1; // parse_error
+ find &= ~BBX;
+ continue;
+ }
+ if ( res & DWIDTH )
+ {
+ in_ >> info.dwx0 >> info.dwy0;
+ if ( in_.fail() )
+ return -1; // parse_error
+ find &= ~DWIDTH;
+ info.has_dw = true;
+ continue;
+ }
+ if ( res & NO_KEYWORD )
+ return -1; // parse_error: unexpected EOF
+ break;
+ }
+ if ( res != BITMAP || ( find != NO_KEYWORD ) )
+ return -1; // parse_error: required keyword missing or unexpeced keyword
+ unsigned h = info.BBh;
+ unsigned w = ( info.BBw + 7 ) / 8 * 2;
+ info.bitmap.set_size( h, w );
+ for ( unsigned r = 0;r < h;r++ )
+ {
+ trim();
+ std::string str = "";
+ extract_hex(str);
+ if(str.size() < w)
+ return -1; // parse_error
+ for ( unsigned c = 0;c < w;c++ )
+ info.bitmap[r][c] = str[c];
+ }
+ if ( in_.fail() )
+ return -1; // parse_error
+ if ( required_keyword( ENDCHAR ) == false )
+ return -1; // parse_error: required keyword missing
+ enc = e;
+ return 1;
+ }
+ private:
+ map<std::string, int>::kernel_1a_c keyword_map;
+ tokenizer::kernel_1a_c tokzr;
+ std::istream& in_;
+ void extract_hex(std::string& str)
+ {
+ int type;
+ std::string token;
+ while ( 1 )
+ {
+ type = tokzr.peek_type();
+ if ( type == tokenizer::kernel_1a_c::IDENTIFIER || type == tokenizer::kernel_1a_c::NUMBER )
+ {
+ tokzr.get_token( type, token );
+ str += token;
+ continue;
+ }
+ break;
+ }
+ }
+ void trim()
+ {
+ int type;
+ std::string token;
+ while ( 1 )
+ {
+ type = tokzr.peek_type();
+ if ( type == tokenizer::kernel_1a_c::WHITE_SPACE || type == tokenizer::kernel_1a_c::END_OF_LINE )
+ {
+ tokzr.get_token( type, token );
+ continue;
+ }
+ break;
+ }
+ }
+ bool required_keyword( int kw )
+ {
+ int type;
+ std::string token;
+ while ( 1 )
+ {
+ tokzr.get_token( type, token );
+ if ( type == tokenizer::kernel_1a_c::WHITE_SPACE || type == tokenizer::kernel_1a_c::END_OF_LINE )
+ continue;
+ if ( type != tokenizer::kernel_1a_c::IDENTIFIER || keyword_map.is_in_domain( token ) == false || ( keyword_map[token] & kw ) == 0 )
+ return false;
+ break;
+ }
+ return true;
+ }
+ int find_keywords( int find )
+ {
+ int type;
+ std::string token;
+ while ( 1 )
+ {
+ tokzr.get_token( type, token );
+ if ( type == tokenizer::kernel_1a_c::END_OF_FILE )
+ return NO_KEYWORD;
+ if ( type == tokenizer::kernel_1a_c::IDENTIFIER && keyword_map.is_in_domain( token ) == true )
+ {
+ int kw = keyword_map[token];
+ if ( kw & find )
+ return kw;
+ }
+ }
+ return true;
+ }
+
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// bdf_font functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ bdf_font::bdf_font(
+ long default_char_
+ ) :
+ default_char(0),
+ is_initialized( false ),
+ right_overflow_( 0 ),
+ has_global_width( false ),
+ specified_default_char( default_char_ )
+ {
+ // make sure gl contains at least one letter
+ gl.resize(1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void bdf_font::adjust_metrics(
+ )
+ {
+ if ( is_initialized == false )
+ return;
+ // set starting values for fbb
+ if ( gl[default_char].num_of_points() > 0 )
+ {
+ letter& g = gl[default_char];
+ fbb.set_top( g[0].y );
+ fbb.set_bottom( g[0].y );
+ fbb.set_left( g[0].x );
+ fbb.set_right( g[0].x );
+ }
+ else
+ {
+ // ok, the default char was a space
+ // let's just choose some safe arbitrary values then...
+ fbb.set_top( 10000 );
+ fbb.set_bottom( -10000 );
+ fbb.set_left( 10000 );
+ fbb.set_right( -10000 );
+ }
+ right_overflow_ = 0;
+ for ( unichar n = 0; n < gl.size(); n++ )
+ {
+ letter& g = gl[n];
+ unsigned short nr_pts = g.num_of_points();
+ for ( unsigned short k = 0;k < nr_pts;k++ )
+ {
+ fbb.set_top( std::min( fbb.top(), (long)g[k].y ) );
+ fbb.set_left( std::min( fbb.left(), (long)g[k].x ) );
+ fbb.set_bottom( std::max( fbb.bottom(), (long)g[k].y ) );
+ fbb.set_right( std::max( fbb.right(), (long)g[k].x ) );
+ right_overflow_ = std::max( right_overflow_, (unsigned long)(g[k].x - g.width()) ); // superfluous?
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long bdf_font::
+ read_bdf_file(
+ std::istream& in,
+ unichar max_enc,
+ unichar min_enc
+ )
+ {
+ using namespace bdf_font_helpers;
+
+ bdf_parser parser( in );
+ bdf_parser::header_info hinfo;
+ bdf_parser::char_info cinfo;
+
+ gl.resize(max_enc+1);
+ hinfo.default_char = - 1;
+ if ( is_initialized == false || static_cast<std::streamoff>(in.tellg()) == std::ios::beg )
+ {
+ if ( parser.parse_header( hinfo ) == false )
+ return 0; // parse_error: invalid or missing header
+ }
+ else
+ {
+ // not start of file, so use values from previous read.
+ hinfo.has_global_dw = has_global_width;
+ hinfo.dwx0 = global_width;
+ }
+ int res;
+ unichar nr_letters_added = 0;
+ unsigned width;
+ for ( unichar n = min_enc; n <= max_enc; n++ )
+ {
+ if ( in.eof() )
+ break;
+ long pos = in.tellg();
+ res = parser.parse_glyph( cinfo, n );
+ if ( res < 0 )
+ return 0; // parse_error
+ if ( res == 0 )
+ continue;
+ if ( n > max_enc )
+ {
+ in.seekg( pos );
+ break;
+ }
+
+ if ( cinfo.has_dw == false )
+ {
+ if ( hinfo.has_global_dw == false )
+ return 0; // neither width info for the glyph, nor for the font as a whole (monospace).
+ width = hinfo.dwx0;
+ }
+ else
+ width = cinfo.dwx0;
+
+
+ if ( bitmap_to_letter( cinfo.bitmap, n, width, cinfo.BBxoff0x, cinfo.BByoff0y ) == false )
+ return 0;
+ nr_letters_added++;
+
+ if ( is_initialized == false )
+ {
+ // Bonding rectangle for the font.
+ fbb.set_top( -( hinfo.Yoff + hinfo.FBBy - 1 ) );
+ fbb.set_bottom( -hinfo.Yoff );
+ fbb.set_left( hinfo.Xoff );
+ fbb.set_right( hinfo.Xoff + hinfo.FBBx - 1 );
+ // We need to compute this after all the glyphs are loaded.
+ right_overflow_ = 0;
+ // set this to something valid now, just in case.
+ default_char = n;
+ // Save any global width in case we later read from the same file.
+ has_global_width = hinfo.has_global_dw;
+ if ( has_global_width )
+ global_width = hinfo.dwx0;
+ // dont override value specified in the constructor with value specified in the file
+ if ( specified_default_char < 0 && hinfo.default_char >= 0 )
+ specified_default_char = hinfo.default_char;
+
+ is_initialized = true;
+ }
+ }
+ if ( is_initialized == false )
+ return 0; // Not a single glyph was found within the specified range.
+
+ if ( specified_default_char >= 0 )
+ default_char = specified_default_char;
+ // no default char specified, try find something sane.
+ else
+ default_char = 0;
+
+ return nr_letters_added;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool bdf_font::
+ bitmap_to_letter(
+ array2d<char>& bitmap,
+ unichar enc,
+ unsigned long width,
+ int x_offset,
+ int y_offset
+ )
+ {
+ unsigned nr_points = 0;
+ bitmap.reset();
+ while ( bitmap.move_next() )
+ {
+ unsigned char ch = bitmap.element();
+ if ( ch > '9' )
+ ch -= 'A' - '9' - 1;
+ ch -= '0';
+ if ( ch > 0xF )
+ return false; // parse error: invalid hex digit
+ bitmap.element() = ch;
+ if ( ch & 8 )
+ nr_points++;
+ if ( ch & 4 )
+ nr_points++;
+ if ( ch & 2 )
+ nr_points++;
+ if ( ch & 1 )
+ nr_points++;
+ }
+
+ letter( width, nr_points ).swap(gl[enc]);
+
+ unsigned index = 0;
+ for ( int r = 0;r < bitmap.nr();r++ )
+ {
+ for ( int c = 0;c < bitmap.nc();c++ )
+ {
+ int x = x_offset + c * 4;
+ int y = -( y_offset + bitmap.nr() - r - 1 );
+ char ch = bitmap[r][c];
+ letter& glyph = gl[enc];
+ if ( ch & 8 )
+ {
+ glyph[index] = letter::point( x, y );
+ right_overflow_ = std::max( right_overflow_, x - width );
+ index++;
+ }
+ if ( ch & 4 )
+ {
+ glyph[index] = letter::point( x + 1, y );
+ right_overflow_ = std::max( right_overflow_, x + 1 - width );
+ index++;
+ }
+ if ( ch & 2 )
+ {
+ glyph[index] = letter::point( x + 2, y );
+ right_overflow_ = std::max( right_overflow_, x + 2 - width );
+ index++;
+ }
+ if ( ch & 1 )
+ {
+ glyph[index] = letter::point( x + 3, y );
+ right_overflow_ = std::max( right_overflow_, x + 3 - width );
+ index++;
+ }
+ }
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::shared_ptr<font> get_native_font (
+ )
+ {
+ return nativefont::native_font::get_font();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FONTs_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/fonts.h b/ml/dlib/dlib/gui_widgets/fonts.h
new file mode 100644
index 000000000..5d3181aaa
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/fonts.h
@@ -0,0 +1,628 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt, Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_FONTs_
+#define DLIB_FONTs_
+
+#include <memory>
+#include <string>
+
+#include "fonts_abstract.h"
+#include "../gui_core.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../unicode.h"
+#include "../array.h"
+#include "../array2d.h"
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class letter
+ {
+ /*!
+ INITIAL VALUE
+ - defined by constructor
+
+ CONVENTION
+ - if (points != 0) then
+ - points == an array of count point structs
+ - w == width()
+ - count == num_of_points()
+ !*/
+ public:
+ struct point
+ {
+ point (){}
+
+ point (
+ signed char x_,
+ signed char y_
+ ) :
+ x(x_),
+ y(y_)
+ {}
+
+ signed char x;
+ signed char y;
+ };
+
+ letter (
+ ) :
+ points(0),
+ w(0),
+ count(0)
+ {}
+
+ letter (
+ unsigned short width_,
+ unsigned short point_count
+ ) :
+ points(new point[point_count]),
+ w(width_),
+ count(point_count)
+ {}
+
+ ~letter(
+ )
+ {
+ if (points)
+ delete [] points;
+ }
+
+ unsigned short width (
+ ) const { return w; }
+
+ unsigned short num_of_points (
+ ) const { return count;}
+
+ point& operator[] (
+ unsigned short i
+ )
+ {
+ DLIB_ASSERT (i < num_of_points(),
+ "\tvoid letter::operator[]()"
+ << "\n\ti: " << i
+ << "\n\tnum_of_points(): " << num_of_points() );
+ return points[i];
+ }
+
+ const point& operator[] (
+ unsigned short i
+ ) const
+ {
+ DLIB_ASSERT (i < num_of_points(),
+ "\tvoid letter::operator[]()"
+ << "\n\ti: " << i
+ << "\n\tnum_of_points(): " << num_of_points() );
+ return points[i];
+ }
+
+ friend void serialize (
+ const letter& item,
+ std::ostream& out
+ );
+
+ friend void deserialize (
+ letter& item,
+ std::istream& in
+ );
+
+ void swap (
+ letter& item
+ )
+ {
+ exchange(points, item.points);
+ exchange(w, item.w);
+ exchange(count, item.count);
+ }
+
+ private:
+ // restricted functions
+ letter(letter&); // copy constructor
+ letter& operator=(letter&); // assignment operator
+
+ point* points;
+ unsigned short w;
+ unsigned short count;
+ };
+
+ inline void swap (
+ letter& a,
+ letter& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ class font
+ {
+ public:
+ virtual ~font() {}
+
+ virtual bool has_character (
+ unichar ch
+ )const=0;
+ bool has_character(char ch) const { return this->has_character(zero_extend_cast<unichar>(ch)); }
+ bool has_character(wchar_t ch) const { return this->has_character(zero_extend_cast<unichar>(ch)); }
+
+ const letter& operator[] (char ch) const { return (*this)[zero_extend_cast<unichar>(ch)]; };
+ const letter& operator[] (wchar_t ch)const { return (*this)[zero_extend_cast<unichar>(ch)]; };
+
+ virtual const letter& operator[] (
+ unichar ch
+ )const=0;
+
+ virtual unsigned long height (
+ ) const = 0;
+
+ virtual unsigned long ascender (
+ ) const = 0;
+
+ virtual unsigned long left_overflow (
+ ) const = 0;
+
+ virtual unsigned long right_overflow (
+ ) const = 0;
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename traits, typename alloc>
+ void compute_size (
+ const std::basic_string<T,traits,alloc>& str,
+ unsigned long& width,
+ unsigned long& height,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = (std::basic_string<T,traits,alloc>::npos)
+ ) const
+ {
+ typedef std::basic_string<T,traits,alloc> string;
+ DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) ,
+ "\tvoid font::compute_size()"
+ << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false")
+ << "\n\tfirst: " << (unsigned long)first
+ << "\n\tlast: " << (unsigned long)last
+ << "\n\tstr.size(): " << (unsigned long)str.size() );
+
+ unsigned long line_width = 0;
+ unsigned long newlines = 0;
+ width = 0;
+ height = 0;
+
+ if (str.size())
+ {
+ if (last == string::npos)
+ last = str.size()-1;
+ const font& f = *this;
+
+ for (typename string::size_type i = first; i <= last; ++i)
+ {
+ // ignore '\r' characters
+ if (str[i] == '\r')
+ continue;
+
+ if (str[i] == '\n')
+ {
+ ++newlines;
+ width = std::max(width,line_width);
+ line_width = 0;
+ }
+ else
+ {
+ if (is_combining_char(str[i]) == false)
+ line_width += f[str[i]].width();
+ }
+ }
+ width = std::max(width,line_width);
+
+ height = (newlines+1)*f.height();
+ width += f.left_overflow() + f.right_overflow();
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename traits, typename alloc, typename pixel_type>
+ void draw_string (
+ const canvas& c,
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ const pixel_type& color,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = (std::basic_string<T,traits,alloc>::npos),
+ const rectangle area_ = rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max())
+ ) const
+ {
+ typedef std::basic_string<T,traits,alloc> string;
+ DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) ,
+ "\tvoid font::draw_string()"
+ << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false")
+ << "\n\tfirst: " << (unsigned long)first
+ << "\n\tlast: " << (unsigned long)last
+ << "\n\tstr.size(): " << (unsigned long)str.size() );
+
+ rectangle area = rect.intersect(c).intersect(area_);
+ if (area.is_empty() || str.size() == 0)
+ return;
+
+ if (last == string::npos)
+ last = str.size()-1;
+
+ const font& f = *this;
+
+ long y_offset = rect.top() + f.ascender() - 1;
+
+ long pos = rect.left()+f.left_overflow();
+ for (typename string::size_type i = first; i <= last; ++i)
+ {
+ // ignore the '\r' character
+ if (str[i] == '\r')
+ continue;
+
+ // A combining character should be applied to the previous character, and we
+ // therefore make one step back. If a combining comes right after a newline,
+ // then there must be some kind of error in the string, and we don't combine.
+ if(is_combining_char(str[i]) &&
+ pos > rect.left() + static_cast<long>(f.left_overflow()))
+ {
+ pos -= f[str[i]].width();
+ }
+
+ if (str[i] == '\n')
+ {
+ y_offset += f.height();
+ pos = rect.left()+f.left_overflow();
+ continue;
+ }
+
+ // only look at letters in the intersection area
+ if (area.bottom() + static_cast<long>(f.height()) < y_offset)
+ {
+ // the string is now below our rectangle so we are done
+ break;
+ }
+ else if (area.left() > pos - static_cast<long>(f.left_overflow()) &&
+ pos + static_cast<long>(f[str[i]].width() + f.right_overflow()) < area.left() )
+ {
+ pos += f[str[i]].width();
+ continue;
+ }
+ else if (area.right() + static_cast<long>(f.right_overflow()) < pos)
+ {
+ // keep looking because there might be a '\n' in the string that
+ // will wrap us around and put us back into our rectangle.
+ continue;
+ }
+
+ // at this point in the loop we know that f[str[i]] overlaps
+ // horizontally with the intersection rectangle area.
+
+ const letter& l = f[str[i]];
+ for (unsigned short i = 0; i < l.num_of_points(); ++i)
+ {
+ const long x = l[i].x + pos;
+ const long y = l[i].y + y_offset;
+ // draw each pixel of the letter if it is inside the intersection
+ // rectangle
+ if (area.contains(x,y))
+ {
+ assign_pixel(c[y-c.top()][x-c.left()], color);
+ }
+ }
+
+ pos += l.width();
+ }
+ }
+ template <typename T, typename traits, typename alloc>
+ void draw_string (
+ const canvas& c,
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str
+ ) const
+ {
+ draw_string(c,rect, str, 0, 0, (std::basic_string<T,traits,alloc>::npos),
+ rectangle(std::numeric_limits<long>::min(), std::numeric_limits<long>::min(),
+ std::numeric_limits<long>::max(), std::numeric_limits<long>::max()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename traits, typename alloc>
+ const rectangle compute_cursor_rect (
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ unsigned long index,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = (std::basic_string<T,traits,alloc>::npos)
+ ) const
+ {
+ typedef std::basic_string<T,traits,alloc> string;
+ DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) ,
+ "\trectangle font::compute_cursor_rect()"
+ << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false")
+ << "\n\tfirst: " << (unsigned long)first
+ << "\n\tlast: " << (unsigned long)last
+ << "\n\tindex: " << index
+ << "\n\tstr.size(): " << (unsigned long)str.size() );
+
+ const font& f = *this;
+
+ if (last == string::npos)
+ last = str.size()-1;
+
+ long x = f.left_overflow();
+ long y = 0;
+ int count = 0;
+
+ if (str.size() != 0)
+ {
+ for (typename string::size_type i = first; i <= last && i < index; ++i)
+ {
+ ++count;
+ if (str[i] == '\n')
+ {
+ x = f.left_overflow();
+ y += f.height();
+ count = 0;
+ }
+ else if (is_combining_char(str[i]) == false &&
+ str[i] != '\r')
+ {
+ x += f[str[i]].width();
+ }
+ }
+ }
+
+ x += rect.left();
+ y += rect.top();
+
+ // if the cursor is at the start of a line then back it up one pixel
+ if (count == 0)
+ --x;
+
+ return rectangle(x,y,x,y+f.height()-1);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename traits, typename alloc>
+ unsigned long compute_cursor_pos (
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ long x,
+ long y,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = (std::basic_string<T,traits,alloc>::npos)
+ ) const
+ {
+ typedef std::basic_string<T,traits,alloc> string;
+ DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) ,
+ "\tunsigned long font::compute_cursor_pos()"
+ << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false")
+ << "\n\tfirst: " << (unsigned long)first
+ << "\n\tlast: " << (unsigned long)last
+ << "\n\tx: " << x
+ << "\n\ty: " << y
+ << "\n\tstr.size(): " << (unsigned long)str.size() );
+ const font& f = *this;
+
+
+ if (str.size() == 0)
+ return 0;
+ else if (first >= str.size())
+ return static_cast<unsigned long>(str.size());
+
+ y -= rect.top();
+ x -= rect.left();
+ if (y < 0)
+ y = 0;
+ if (x < 0)
+ x = 0;
+
+ if (last == string::npos)
+ last = str.size()-1;
+
+
+ // first figure out what line we are on
+ typename string::size_type pos = first;
+ long line = 0;
+ while (static_cast<unsigned long>(y) >= f.height())
+ {
+ ++line;
+ y -= f.height();
+ }
+
+ // find the start of the given line
+ for (typename string::size_type i = first; i <= last && line != 0; ++i)
+ {
+ if (str[i] == '\n')
+ {
+ --line;
+ pos = i + 1;
+ }
+ }
+
+
+ // now str[pos] == the first character of the start of the line
+ // that contains the cursor.
+ const typename string::size_type start_of_line = pos;
+
+
+ long cur_x = f.left_overflow();
+ // set the current cursor position to where the mouse clicked
+ while (pos <= last)
+ {
+ if (x <= cur_x || str[pos] == '\n')
+ break;
+
+ if (is_combining_char(str[pos]) == false &&
+ str[pos] != '\r')
+ {
+ cur_x += f[str[pos]].width();
+ }
+ ++pos;
+ }
+
+ if (x <= cur_x)
+ {
+ if (pos != start_of_line)
+ {
+ // we might actually be closer to the previous character
+ // so check for that and if so then jump us back one.
+ const long width = f[str[pos-1]].width();
+ if (x < cur_x - width/2)
+ --pos;
+ }
+ }
+ return static_cast<unsigned long>(pos);
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const std::shared_ptr<font> get_native_font ();
+
+// ----------------------------------------------------------------------------------------
+
+ class default_font : public font
+ {
+ letter* l;
+
+
+ default_font(
+ );
+ default_font(default_font&); // copy constructor
+ default_font& operator=(default_font&); // assignment operator
+
+
+
+ public:
+ static const std::shared_ptr<font>& get_font (
+ )
+ {
+ static mutex m;
+ static std::shared_ptr<font> f;
+ auto_mutex M(m);
+ if (f.get() == 0)
+ f.reset(new default_font);
+
+ return f;
+ }
+
+ ~default_font(
+ )
+ {
+ delete [] l;
+ }
+
+ unsigned long height (
+ ) const { return 16; }
+
+ unsigned long ascender (
+ ) const { return 12; }
+
+ unsigned long left_overflow (
+ ) const { return 1; }
+
+ unsigned long right_overflow (
+ ) const { return 2; }
+
+ bool has_character (
+ unichar ch
+ )const
+ {
+ if (ch < 256 && (l[ch].width() != 0 || l[ch].num_of_points() != 0))
+ return true;
+ else
+ return false;
+ }
+
+ const letter& operator[] (
+ unichar ch
+ ) const
+ {
+ if(ch < 256)
+ return l[ch];
+ return l[0]; // just return one of the empty characters in this case
+ }
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ class bdf_font : public font
+ {
+
+ public:
+ bdf_font( long default_char_ = -1 );
+
+ long read_bdf_file( std::istream& in, unichar max_enc, unichar min_enc = 0 );
+ unsigned long height() const
+ {
+ return fbb.height();
+ }
+ unsigned long ascender() const
+ {
+ return std::max( 0L, 1 - fbb.top() );
+ }
+ unsigned long left_overflow() const
+ {
+ return std::max( 0L, -fbb.left() );
+ }
+ unsigned long right_overflow() const
+ {
+ return right_overflow_;
+ }
+ const letter& operator[] ( unichar uch ) const
+ {
+ if ( !has_character(uch) )
+ {
+ return gl[default_char];
+ }
+ return gl[uch];
+ }
+
+ bool has_character (
+ unichar ch
+ )const
+ {
+ if (ch < gl.size() && (gl[ch].width() != 0 || gl[ch].num_of_points() != 0))
+ return true;
+ else
+ return false;
+ }
+
+ void adjust_metrics();
+ private:
+
+ bool bitmap_to_letter( array2d<char>& bitmap, unichar enc, unsigned long width, int x_offset, int y_offset );
+
+ array<letter> gl;
+ unichar default_char; // if (is_intialized == true), then this MUST be an actual glyph
+ bool is_initialized;
+ rectangle fbb;
+ unsigned long right_overflow_;
+
+ unsigned global_width;
+ bool has_global_width;
+ long specified_default_char;
+
+ bdf_font( bdf_font& ); // copy constructor
+ bdf_font& operator=( bdf_font& ); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "fonts.cpp"
+#endif
+
+#endif // DLIB_FONTs_
+
diff --git a/ml/dlib/dlib/gui_widgets/fonts_abstract.h b/ml/dlib/dlib/gui_widgets/fonts_abstract.h
new file mode 100644
index 000000000..df194bccd
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/fonts_abstract.h
@@ -0,0 +1,492 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Nils Labugt, Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FONTs_ABSTRACT_
+#ifdef DLIB_FONTs_ABSTRACT_
+
+#include "../gui_core.h"
+#include <string>
+#include "../serialize.h"
+#include "../unicode.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class letter
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a letter in a font. It tells you the nominal
+ width of the letter and which pixels form the letter.
+
+ THREAD SAFETY
+ const versions of this object are thread safe but if you are going to
+ be modifying it then you must serialize access to it.
+ !*/
+ public:
+ struct point
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents one of the pixels of a letter.
+
+ The origin (i.e. (0,0)) of the coordinate plane is at the left
+ side of the letter's baseline. Also note that y is negative when
+ above the baseline and positive below (it is zero on the baseline
+ itself).
+
+ The x value is positive going to the right and negative to the left.
+ The meaning of a negative x value is that any points with a negative
+ x value will overlap with the preceding letter.
+ !*/
+
+ point (
+ );
+ /*!
+ ensures
+ - This constructor does nothing. The value of x and y
+ are undefined after its execution.
+ !*/
+
+ point (
+ signed char x_,
+ signed char y_
+ );
+ /*!
+ ensures
+ - #x == x_
+ - #y == y_
+ !*/
+
+
+ signed char x;
+ signed char y;
+ };
+
+ // ---------------------------------
+
+ letter (
+ );
+ /*!
+ ensures
+ - #width() == 0
+ - #num_of_points() == 0
+ !*/
+
+ letter (
+ unsigned short width_,
+ unsigned short point_count
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #num_of_points() == point_count
+ !*/
+
+ ~letter(
+ );
+ /*!
+ ensures
+ - any resources used by *this have been freed
+ !*/
+
+ const unsigned short width (
+ ) const;
+ /*!
+ ensures
+ - returns the width reserved for this letter in pixels. This is the
+ number of pixels that are reserved for this letter between adjoining
+ letters. It isn't necessarily the width of the actual letter itself.
+ (for example, you can make a letter with a width less than how wide it
+ actually is so that it overlaps with its neighbor letters.)
+ !*/
+
+ const unsigned short num_of_points (
+ ) const;
+ /*!
+ ensures
+ - returns the number of pixels that make up this letter.
+ !*/
+
+ point& operator[] (
+ unsigned short i
+ );
+ /*!
+ requires
+ - i < num_of_points()
+ ensures
+ - returns a non-const reference to the ith point in this letter.
+ !*/
+
+ const point& operator[] (
+ unsigned short i
+ ) const;
+ /*!
+ requires
+ - i < num_of_points()
+ ensures
+ - returns a const reference to the ith point in this letter.
+ !*/
+
+ void swap (
+ letter& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ private:
+
+ // restricted functions
+ letter(letter&); // copy constructor
+ letter& operator=(letter&); // assignment operator
+ };
+
+ inline void swap (
+ letter& a,
+ letter& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const letter& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for letter objects
+ !*/
+
+ void deserialize (
+ letter& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for letter objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class font
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines an interface for a font type. It provides metrics
+ for the font and functions to help you draw strings on a canvas object.
+
+ THREAD SAFETY
+ All the functions in this class are thread safe.
+ !*/
+
+ public:
+
+ virtual bool has_character (
+ unichar ch
+ )const=0;
+ /*!
+ ensures
+ - if (this font has a glyph for the given character) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ bool has_character(char ch) const { return this->has_character(zero_extend_cast<unichar>(ch)); }
+ bool has_character(wchar_t ch) const { return this->has_character(zero_extend_cast<unichar>(ch)); }
+ /* Cast char and wchar_t to unichar correctly when char or wchar_t is a signed type */
+
+ virtual const letter& operator[] (
+ unichar ch
+ )const=0;
+ /*!
+ ensures
+ - if (has_character(ch) == true) then
+ - returns a letter object that tells you how to draw this character.
+ - else
+ - returns some default glyph for characters that aren't in this font.
+ !*/
+ const letter& operator[] (char ch) const { return (*this)[zero_extend_cast<unichar>(ch)]; };
+ const letter& operator[] (wchar_t ch) const { return (*this)[zero_extend_cast<unichar>(ch)]; };
+ /* Cast char and wchar_t to unichar correctly when char or wchar_t is a signed type */
+
+ virtual const unsigned long height (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the height in pixels of the tallest letter in the font
+ !*/
+
+ virtual const unsigned long ascender (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the height() minus the number of pixels below the baseline used
+ by the letter that hangs the lowest below the baseline.
+ !*/
+
+ virtual const unsigned long left_overflow (
+ ) const = 0;
+ /*!
+ ensures
+ - returns how far outside and to the left of its width a letter
+ from this font may set pixels. (i.e. how many extra pixels to its
+ left may a font use)
+ !*/
+
+ virtual const unsigned long right_overflow (
+ ) const = 0;
+ /*!
+ ensures
+ - returns how far outside and to the right of its width a letter
+ from this font may set pixels. (i.e. how many extra pixels to its
+ right may a font use)
+ !*/
+
+ template <typename T, typename traits, typename alloc>
+ void compute_size (
+ const std::basic_string<T,traits,alloc>& str,
+ unsigned long& width,
+ unsigned long& height,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = std::basic_string<T,traits,alloc>::npos
+ ) const;
+ /*!
+ requires
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - first <= last
+ - last < str.size()
+ ensures
+ - all characters in str with an index < first are ignored by this
+ function.
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - all characters in str with an index > last are ignored by
+ this function.
+ - if (str.size() == 0) then
+ - #width == 0
+ - #height == 0
+ - else
+ - #width == sum of the widths of the characters in the widest
+ line in str + left_overflow() + right_overflow().
+ - #height == (count(str.begin(),str.end(),'\n')+1)*height()
+ !*/
+
+ template <typename T, typename traits, typename alloc, typename pixel_type>
+ void draw_string (
+ const canvas& c,
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ const pixel_type& color = rgb_pixel(0,0,0),
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = std::basic_string<T,traits,alloc>::npos,
+ const rectangle area = rectangle(-infinity,-infinity,infinity,infinity)
+ ) const;
+ /*!
+ requires
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - first <= last
+ - last < str.size()
+ ensures
+ - all characters in str with an index < first are ignored by this
+ function.
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - all characters in str with an index > last are ignored by
+ this function.
+ - if (str.size() == 0) then
+ - does nothing
+ - else
+ - draws str on the given canvas at the position defined by rect.
+ Also uses the given pixel colors for the font color.
+ - If the string is too big to fit in rect then the right and
+ bottom sides of it will be clipped to make it fit.
+ - only the part of the string that is contained inside the area
+ rectangle will be drawn
+ !*/
+
+ template <typename T, typename traits, typename alloc>
+ const rectangle compute_cursor_rect (
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ unsigned long index,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = std::basic_string<T,traits,alloc>::npos
+ ) const;
+ /*!
+ requires
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - first <= last
+ - last < str.size()
+ ensures
+ - the returned rectangle has a width of 1 and a
+ height of this->height().
+ - computes the location of the cursor that would sit just before
+ the character str[index] if str were drawn on the screen by
+ draw_string(rect,str,...,first,last). The cursor location is
+ returned in the form of a rectangle.
+ - if (index < first) then
+ - the returned cursor will be just before the character str[first].
+ - if (last != std::basic_string<T,traits,alloc>::npos && index > last) then
+ - the returned cursor will be just after the character str[last]
+ - if (str.size() == 0) then
+ - the returned cursor will be just at the start of the rectangle where
+ str would be drawn if it wasn't empty.
+ - if (index > str.size()-1) then
+ - the returned cursor will be just after the character str[str.size()-1]
+ !*/
+
+ template <typename T, typename traits, typename alloc>
+ const unsigned long compute_cursor_pos (
+ const rectangle& rect,
+ const std::basic_string<T,traits,alloc>& str,
+ long x,
+ long y,
+ typename std::basic_string<T,traits,alloc>::size_type first = 0,
+ typename std::basic_string<T,traits,alloc>::size_type last = std::basic_string<T,traits,alloc>::npos
+ ) const;
+ /*!
+ requires
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - first <= last
+ - last < str.size()
+ ensures
+ - returns a number idx that has the following properties:
+ - if (first < str.size()) then
+ - first <= idx
+ - else
+ - idx == str.size()
+ - if (last != std::basic_string<T,traits,alloc>::npos) then
+ - idx <= last + 1
+ - compute_cursor_rect(rect,str,idx,first,last) == the cursor
+ position that is closest to the pixel (x,y)
+ !*/
+
+
+ private:
+
+ // restricted functions
+ font(font&); // copy constructor
+ font& operator=(font&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class default_font : public font
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the Helvetica 12 point font.
+
+ THREAD SAFETY
+ It is safe to call get_font() and access the returned font from any
+ thread and no synchronization is needed as long as it is called
+ after the main() function has been entered.
+ !*/
+
+ public:
+ static const shared_ptr_thread_safe<font> get_font(
+ );
+ /*!
+ ensures
+ - returns an instance of this font.
+ throws
+ - std::bad_alloc
+ This exception is thrown if there is a problem gathering the needed
+ memory for the font object.
+ !*/
+
+ private:
+
+ // restricted functions
+ default_font(); // normal constructor
+ default_font(default_font&); // copy constructor
+ default_font& operator=(default_font&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bdf_font : public font
+ {
+
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a font object that is capable of loading of loading BDF (Glyph
+ Bitmap Distribution Format) font files.
+
+ THREAD SAFETY
+ If you only access this object via the functions in the parent class font
+ then this object is thread safe. But if you need to call any of the
+ functions introduced in this derived class then you need to serialize
+ access to this object while you call these functions.
+ !*/
+
+ public:
+
+ bdf_font(
+ long default_char = -1
+ );
+ /*!
+ ensures
+ - for all x:
+ - #has_character(x) == false
+ (i.e. this font starts out empty. You have to call read_bdf_file()
+ to load it with data)
+ - if (default_char == -1) then
+ - the letter returned by (*this)[ch] for values of
+ ch where has_character(ch) == false will be the
+ default glyph defined in the bdf file.
+ - else
+ - the letter returned by (*this)[ch] for values of
+ ch where has_character(ch) == false will be the
+ letter (*this)[default_char].
+ !*/
+
+ long read_bdf_file(
+ std::istream& in,
+ unichar max_enc,
+ unichar min_enc = 0
+ );
+ /*!
+ ensures
+ - attempts to read the font data from the given input stream into
+ *this. The input stream is expected to contain a valid BDF file.
+ - reads in characters with encodings in the range min_enc to max_enc
+ into this font. All characters in the font file outside this range
+ are ignored.
+ - returns the number of characters loaded into this font from the
+ given input stream.
+ !*/
+
+ void adjust_metrics();
+ /*!
+ ensures
+ - Computes metrics based on actual glyphs loaded, instead of using
+ the values in the bdf file header. (May be useful if loading glyphs
+ from more than one file or a small part of a file.)
+ !*/
+
+ private:
+
+ bdf_font( bdf_font& ); // copy constructor
+ bdf_font& operator=( bdf_font& ); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const shared_ptr_thread_safe<font> get_native_font(
+ );
+ /*!
+ ensures
+ - returns a font object that uses the local font
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FONTs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/gui_widgets/nativefont.h b/ml/dlib/dlib/gui_widgets/nativefont.h
new file mode 100644
index 000000000..2de0edfaa
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/nativefont.h
@@ -0,0 +1,612 @@
+// Copyright (C) 2006 Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IGG_FONT_RENDERER_H_
+#define DLIB_IGG_FONT_RENDERER_H_
+#include "../platform.h"
+
+
+#include "../gui_widgets.h"
+#include "../unicode.h"
+#include "../uintn.h"
+
+#include <map>
+#include <memory>
+
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include <locale.h>
+
+#if defined(WIN32)
+#include <windows.h>
+#include <mbstring.h>
+#elif defined(POSIX)
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+#include <X11/Xlib.h>
+#include <X11/Xutil.h>
+#include <X11/Xlocale.h>
+#endif
+
+namespace nativefont
+{
+// ----------------------------------------------------------------------------------------
+
+ namespace font_renderer
+ {
+ typedef dlib::uint8 byte;
+
+
+#ifdef WIN32
+ template <typename T> struct input2native_trait{
+ };
+ template <> struct input2native_trait<char>{
+ typedef char type_t;
+ };
+ template <> struct input2native_trait<wchar_t>{
+ typedef wchar_t type_t;
+ };
+ template <> struct input2native_trait<dlib::unichar>{
+ typedef wchar_t type_t;
+ };
+#endif
+ // T : N : sizeof_source_type
+ template <int N> struct size2inner_trait{
+ };
+ template <> struct size2inner_trait<1>{
+ typedef char type_t;
+ };
+ template <> struct size2inner_trait<2>{
+ typedef dlib::uint16 type_t;
+ };
+ template <> struct size2inner_trait<4>{
+ typedef dlib::unichar type_t;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <int N> struct create_helper{ };
+ template <> struct create_helper<1>{
+ typedef char type_t;
+ type_t *istr;
+ int len;
+ create_helper(char *str){
+ len = (int)strlen(str);
+ istr = str;
+ }
+ ~create_helper(){}
+ };
+ template <> struct create_helper<2>{
+ typedef wchar_t type_t;
+ type_t *istr;
+ bool allocated;
+ int len;
+ create_helper(wchar_t *str){
+ allocated = false;
+ len = (int)wcslen(str);
+ istr = str;
+ }
+ create_helper(dlib::unichar *str){
+ allocated = true;
+ len = 0;
+ int unicount = 0;
+ dlib::unichar *p = str;
+ while(*p){
+ if (*p > 0xffff){
+ len += 2;
+ }else{
+ len++;
+ }
+ unicount++;
+ p++;
+ }
+ istr = new wchar_t[len+1];
+ for (int i = 0, wi = 0; i < unicount; ++i){
+ dlib::unichar high, low;
+ if (str[i] > 0xffff){
+ dlib::unichar_to_surrogate_pair(str[i], high, low);
+ istr[wi] = (wchar_t)high, istr[wi+1] = (wchar_t)low;
+ wi += 2;
+ }else{
+ istr[wi] = (wchar_t)str[i];
+ wi += 1;
+ }
+ }
+ istr[len] = L'\0';
+ }
+
+ ~create_helper(){
+ if (allocated) delete[] istr;
+ }
+ };
+ template <> struct create_helper<4>{
+ typedef wchar_t type_t;
+ type_t *istr;
+ int len;
+ create_helper(dlib::unichar *str){
+ len = (int)wcslen((wchar_t *)str);
+ istr = (type_t *)str;
+ }
+ ~create_helper(){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class font_renderer{
+ public:
+
+ struct rgb_type{
+ byte r, g, b;
+ rgb_type() : r(0), g(0), b(0){};
+ rgb_type(byte r_, byte g_, byte b_) : r(r_), g(g_), b(b_){};
+ };
+ private:
+
+ byte *image;
+ int width, height;
+ void destroy(){
+ width = height = 0;
+ delete image;
+ image = 0;
+ }
+ struct vals_internal{
+ int width, height;
+#ifdef WIN32
+ COLORREF rgb2RGB(rgb_type &rgb){
+ return RGB(rgb.r, rgb.g, rgb.b);
+ }
+ HBITMAP hBmp, hBmpOld;
+ HDC hDCBmp;
+ BYTE *pixelint;
+ HFONT hFont, hFontOld;
+ HBRUSH hBrush;
+ int pix_width_prev, pix_height_prev;
+ bool first;
+ int ascender, descender;
+ int height_prev;
+ char attribute_prev;
+
+ template <typename T> void create(T *str, int height_want, bool italic, bool bold, bool fixed, rgb_type &background, rgb_type &foreground){
+ struct inner{
+ inline static BOOL GetTextExtentPoint32(HDC hDC, LPCSTR str, int len, LPSIZE lpsize){
+ return ::GetTextExtentPoint32A(hDC, str, len, lpsize);
+ }
+ inline static BOOL GetTextExtentPoint32(HDC hDC, LPCWSTR str, int len, LPSIZE lpsize){
+ return ::GetTextExtentPoint32W(hDC, str, len, lpsize);
+ }
+ inline static BOOL TextOut(HDC hDC, int nxstart, int nystart, LPCSTR str, int cbstr){
+ return ::TextOutA(hDC, nxstart, nystart, str, cbstr);
+ }
+ inline static BOOL TextOut(HDC hDC, int nxstart, int nystart, LPCWSTR str, int cbstr){
+ return ::TextOutW(hDC, nxstart, nystart, str, cbstr);
+ }
+ };
+
+ create_helper<sizeof(typename input2native_trait<T>::type_t)> ch(str);
+
+ if (hDCBmp == NULL){
+ HWND hWnd = GetDesktopWindow();
+ HDC hDC = GetDC(hWnd);
+ hDCBmp = CreateCompatibleDC(hDC);
+ ReleaseDC(hWnd, hDC);
+ }
+ SetTextColor(hDCBmp, rgb2RGB(foreground));
+ SetBkColor(hDCBmp, rgb2RGB(background));
+
+ char attribute = (italic ? 1 : 0) | (bold ? 2 : 0) | (fixed ? 4 : 0);
+ if (!hFont || height_prev != height || attribute != attribute_prev){
+ attribute_prev = attribute;
+ height_prev = height_want;
+ if (hFont){
+ SelectObject(hDCBmp, hFontOld);
+ DeleteObject(hFont);
+ }
+ hFont = CreateFont(height_want, 0, 0, 0, bold ? FW_BOLD : FW_DONTCARE, italic ? TRUE : FALSE,
+ FALSE, FALSE, DEFAULT_CHARSET, OUT_DEFAULT_PRECIS, CLIP_DEFAULT_PRECIS, DEFAULT_QUALITY,
+ fixed ? (FIXED_PITCH | FF_DONTCARE) : (VARIABLE_PITCH | FF_DONTCARE), NULL);
+ hFontOld = (HFONT)SelectObject(hDCBmp, hFont);
+ }
+
+ {
+ SIZE sz;
+ inner::GetTextExtentPoint32(hDCBmp, ch.istr, ch.len, &sz);
+ width = ((sz.cx + 3) / 4) * 4;
+ height = sz.cy;
+ }
+
+ if (pix_width_prev < width || pix_height_prev < height){
+ if (hBmp){
+ SelectObject(hDCBmp, hBmpOld);
+ DeleteObject(hBmp);
+ }
+ pix_width_prev = width * 2;
+ pix_height_prev = height * 2;
+ BITMAPINFO bi;
+ ZeroMemory(&bi, sizeof(bi));
+ bi.bmiHeader.biSize = sizeof(BITMAPINFOHEADER);
+ bi.bmiHeader.biBitCount = 24;
+ bi.bmiHeader.biPlanes = 1;
+ bi.bmiHeader.biWidth = pix_width_prev;
+ bi.bmiHeader.biHeight = -pix_height_prev;
+ hBmp = CreateDIBSection(NULL, &bi, DIB_RGB_COLORS, (void **)&pixelint, NULL, 0);
+ hBmpOld = (HBITMAP)SelectObject(hDCBmp, hBmp);
+ }
+
+ {
+ HBRUSH hBrush = CreateSolidBrush(rgb2RGB(background));
+ RECT rc;
+ rc.left = rc.top = 0;
+ rc.right = pix_width_prev;
+ rc.bottom = pix_height_prev;
+ FillRect(hDCBmp, &rc, hBrush);
+ }
+
+ inner::TextOut(hDCBmp, 0, 0, ch.istr, ch.len);
+ TEXTMETRICW tm;
+ GetTextMetricsW(hDCBmp,&tm);
+ ascender = tm.tmAscent;
+ descender = tm.tmDescent;
+ }
+
+ template <typename T> vals_internal(T *str, int height_want, bool italic = false,
+ bool bold = false, bool fixed = false, rgb_type background = rgb_type(), rgb_type foreground = rgb_type()){
+ first = true;
+ hFont = NULL;
+ hDCBmp = 0;
+ hBmpOld = 0;
+ hBmp = 0;
+ hDCBmp = 0;
+ pixelint = 0;
+ pix_width_prev = pix_height_prev = 0;
+ height_prev = -1;
+ attribute_prev = 0;
+ create(str, height_want, italic, bold, fixed, background, foreground);
+ first = false;
+ }
+
+ inline int get_ascender(){
+ return ascender;
+ }
+
+ inline int get_descender(){
+ return descender;
+ }
+
+ inline void get_pixel(int x, int y, byte &r, byte &g, byte &b){
+ byte *p = pixelint + (y * pix_width_prev + x) * 3;
+ r = *(p+2), g = *(p+1), b = *p;
+ }
+
+ void destroy(){
+ SelectObject(hDCBmp, hBmpOld);
+ DeleteObject(hBmp);
+ SelectObject(hDCBmp, hFontOld);
+ DeleteObject(hFont);
+ DeleteDC(hDCBmp);
+ hFont = NULL;
+ hDCBmp = 0;
+ hBmpOld = 0;
+ hBmp = 0;
+ hDCBmp = 0;
+ pixelint = 0;
+ }
+ ~vals_internal(){
+ destroy();
+ }
+#elif defined(POSIX)
+ XImage *ximg;
+ Display *d;
+ GC gc;
+ XFontSet fs;
+ Pixmap pix;
+ Colormap cmap;
+ int ascender, descender;
+ int pix_width_prev, pix_height_prev;
+ char fontset_prev[256];
+ unsigned long rgb2color(rgb_type col, Display *d, Colormap &cmap){
+ XColor xcol;
+ xcol.red = col.r * 257;
+ xcol.green = col.g * 257;
+ xcol.blue = col.b * 257;
+ XAllocColor(d, cmap, &xcol);
+ return xcol.pixel;
+ }
+ template <typename T> void create(T *str, int height_want, bool italic, bool bold, bool fixed, rgb_type background, rgb_type foreground){
+ struct inner{
+ inline static int XTextExtents (XFontSet fs, char *str, int len, XRectangle *ink, XRectangle *logical){
+ return XmbTextExtents(fs, str, len, ink, logical);
+ }
+ inline static int XTextExtents (XFontSet fs, wchar_t *str, int len, XRectangle *ink, XRectangle *logical){
+ return XwcTextExtents(fs, str, len, ink, logical);
+ }
+ inline static void XDrawString(Display *d, Window w, XFontSet fs, GC gc, int x, int y, char *str, int num_bytes){
+ XmbDrawString(d, w, fs, gc, x, y, str, num_bytes);
+ }
+ inline static void XDrawString(Display *d, Window w, XFontSet fs, GC gc, int x, int y, wchar_t *str, int num_bytes){
+ XwcDrawString(d, w, fs, gc, x, y, str, num_bytes);
+ }
+ };
+ create_helper<sizeof(T)> ch((typename size2inner_trait<sizeof(T)>::type_t *)str);
+ setlocale(LC_CTYPE, "");
+ if (d == NULL){
+ d = XOpenDisplay(NULL);
+ if (d == 0)
+ {
+ d = XOpenDisplay(":0.0");
+ if (d == 0)
+ {
+ throw dlib::gui_error("Unable to connect to the X display.");
+ }
+ }
+
+ cmap = DefaultColormap(d, DefaultScreen(d));
+ }
+ char fontset[256];
+ {
+ char *p = fontset;
+ p += sprintf(fontset, "-*-*-%s-%c-normal--%d-*-*-*-%c",
+ bold ? "bold" : "medium", italic ? 'i' : 'r', height_want, fixed ? 'c' : 'p');
+ if (fixed){
+ sprintf(p, ",-*-*-%s-%c-normal--%d-*-*-*-m",
+ bold ? "bold" : "medium", italic ? 'i' : 'r', height_want);
+ }
+ }
+ bool equal_font;
+ if (strcmp(fontset, fontset_prev) == 0){
+ equal_font = true;
+ }else{
+ equal_font = false;
+ strcpy(fontset_prev, fontset);
+ }
+
+ char **mlist;
+ int mcount;
+ char *def_str;
+ if (!equal_font){
+ if (fs){
+ XFreeFontSet(d, fs);
+ }
+ fs = XCreateFontSet(d, fontset, &mlist, &mcount, &def_str);
+ if (fs == NULL)
+ throw dlib::gui_error("gui_error: XCreateFontSet() failure");
+
+ XFontSetExtents *extent;
+ extent = XExtentsOfFontSet(fs);
+ ascender = -extent->max_logical_extent.y;
+ descender = extent->max_logical_extent.height - ascender;
+ XFreeStringList(mlist);
+ }
+ XRectangle ink, logical;
+ inner::XTextExtents (fs, ch.istr, ch.len, &ink, &logical);
+ width = logical.width;
+ height = height_want;
+
+ if (pix == None || pix_width_prev < width || pix_height_prev < height){
+ if (pix != None){
+ XFreeGC(d, gc);
+ XFreePixmap(d, pix);
+ }
+ pix_width_prev = width * 2;
+ pix_height_prev = height * 2;
+ pix = XCreatePixmap(d, DefaultRootWindow(d), pix_width_prev, pix_height_prev, XDefaultDepth(d, DefaultScreen(d)));
+ gc = XCreateGC(d, pix, 0, NULL);
+ }
+
+ unsigned long backcolor = rgb2color(background, d, cmap);
+ XSetForeground(d, gc, backcolor);
+ XSetBackground(d, gc, backcolor);
+ XFillRectangle(d, pix, gc, 0, 0, width, height);
+ XSetForeground(d, gc, rgb2color(foreground, d, cmap));
+ inner::XDrawString(d, pix, fs, gc, 0, ascender, ch.istr, ch.len);
+
+ if (ximg) XDestroyImage(ximg);
+ ximg = XGetImage(d, pix, 0, 0, width, height, AllPlanes, ZPixmap );
+ }
+
+ template <typename T> vals_internal(T *str, int height_want, bool italic = false,
+ bool bold = false, bool fixed = false, rgb_type background = rgb_type(), rgb_type foreground = rgb_type()){
+ fontset_prev[0] = '\0';
+ ximg = NULL;
+ d = NULL;
+ pix = None;
+ fs = NULL;
+ ascender = descender = -1;
+ pix_width_prev = pix_height_prev = -1;
+ create(str, height_want, italic, bold, fixed, background, foreground);
+ }
+
+ inline int get_ascender(){
+ return ascender;
+ }
+
+ inline int get_descender(){
+ return descender;
+ }
+
+ std::map<unsigned long,rgb_type> col2rgb;
+ rgb_type color2rgb(unsigned long color, Display *d, Colormap &cmap){
+ if (col2rgb.count(color)){
+ return col2rgb[color];
+ }else{
+ XColor xcol;
+ xcol.pixel = color;
+ XQueryColor(d, cmap, &xcol);
+ rgb_type rgb_((byte)(xcol.red/257), (byte)(xcol.green/257), (byte)(xcol.blue/257));
+ col2rgb[color] = rgb_;
+ return rgb_;
+ }
+ }
+ inline void get_pixel(int x, int y, byte &r, byte &g, byte &b){
+ rgb_type c = color2rgb(XGetPixel(ximg,x,y), d, cmap);
+ r = c.r, g = c.g, b = c.b;
+ }
+
+ ~vals_internal(){
+ XDestroyImage(ximg);
+
+ XFreeGC(d, gc);
+ XFreeFontSet(d, fs);
+ XFreePixmap(d, pix);
+ XCloseDisplay(d);
+ }
+#endif
+ };
+
+ struct image_size_setter{
+ void operator()(int&, int&){
+ }
+ };
+
+ int ascender, descender;
+ vals_internal *vi;
+ public:
+ font_renderer() : image(0), width(0), height(0){
+ ascender = descender = 0;
+ vi = NULL;
+ }
+
+ template<typename T> font_renderer(T *str, int height_want, bool italic = false, bool bold = false, bool fixed = false, rgb_type background = rgb_type(0,0,0), rgb_type foreground = rgb_type(255,255,255)){
+ render(str, height_want, italic, bold, fixed, background, foreground);
+ }
+
+ template<typename T> void render(T *str, int height_want,
+ bool italic = false, bool bold = false, bool fixed = false,
+ rgb_type background = rgb_type(0,0,0), rgb_type foreground = rgb_type(255,255,255)){
+ if (vi == NULL){
+ vi = new vals_internal(str, height_want, italic, bold, fixed, background, foreground);
+ }else{
+ vi->create(str, height_want, italic, bold, fixed, background, foreground);
+ }
+ width = vi->width, height = vi->height;
+ image = new byte[width * height * 3];
+ ascender = vi->get_ascender();
+ descender = vi->get_descender();
+
+ int h = height, w = width;
+ for (int j = 0, i3 = 0; j < h; ++j){
+ for (int i = 0; i < w; ++i, i3 += 3){
+ vi->get_pixel(i, j, image[i3], image[i3+1], image[i3+2]);
+ }
+ }
+ }
+
+ ~font_renderer(){
+ if (vi) delete vi;
+ destroy();
+ }
+ int get_width(){
+ return width;
+ }
+ int get_height(){
+ return height;
+ }
+ inline int get_ascender(){
+ return ascender;
+ }
+ inline int get_descender(){
+ return descender;
+ }
+
+ const byte *get_image(){
+ return image;
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class native_font : public dlib::font
+ {
+ unsigned long ascender_;
+ native_font(){
+ setlocale(LC_CTYPE, "");
+ ascender_ = 0;
+ get_letter((int)('x'));
+ }
+ typedef std::map<int,dlib::letter *> letters_map_type;
+ letters_map_type letters;
+ font_renderer::font_renderer fl;
+ public:
+
+ virtual ~native_font()
+ {
+ // delete all the letter objects we have in our letters map
+ letters_map_type::iterator i;
+ for (i = letters.begin(); i != letters.end(); ++i)
+ {
+ delete i->second;
+ }
+ }
+
+ virtual bool has_character (
+ dlib::unichar ch
+ )const{
+ return (*this)[ch].width() > 0;
+ }
+
+ static const std::shared_ptr<font>& get_font (
+ )
+ {
+ static std::shared_ptr<font> f(new native_font);
+ return f;
+ }
+
+ virtual const dlib::letter& operator[] (dlib::unichar ch) const{
+ return (const_cast<native_font *>(this))->get_letter(ch);
+ }
+
+ dlib::letter& get_letter (
+ dlib::unichar ch
+ ){
+ if (letters.count(ch)){
+ dlib::letter *l = letters.find(ch)->second;
+ return *l;
+ }
+
+ dlib::unichar c[2];
+ c[0] = ch;
+ c[1] = 0;
+
+ fl.render(c, height(),false,false,true);
+ if (ascender_ == 0){
+ ascender_ = fl.get_ascender();
+ }
+ std::vector<dlib::letter::point> v;
+ const font_renderer::byte *bp = fl.get_image();
+ for (int j = 0; j < fl.get_height(); ++j){
+ for (int i = 0; i < fl.get_width(); ++i, bp += 3){
+ if (*bp){
+ v.push_back(dlib::letter::point(i,j-ascender()+1));
+ }
+ }
+ }
+ dlib::letter *l = new dlib::letter(fl.get_width(), (unsigned long)v.size());
+
+ letters.insert(std::make_pair(ch,l));
+ for (int i = 0; i < (int)v.size(); ++i){
+ (*l)[i] = v.at(i);
+ }
+ return *l;
+ }
+
+ virtual unsigned long height (
+ ) const { return 12; }
+
+ virtual unsigned long ascender (
+ ) const { return ascender_; }
+
+ virtual unsigned long left_overflow (
+ ) const { return 1; }
+
+ virtual unsigned long right_overflow (
+ ) const { return 2; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IGG_FONT_RENDERER_H_
+
diff --git a/ml/dlib/dlib/gui_widgets/style.cpp b/ml/dlib/dlib/gui_widgets/style.cpp
new file mode 100644
index 000000000..a3d22d10f
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/style.cpp
@@ -0,0 +1,998 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_WIDGETs_STYLE_CPP_
+#define DLIB_WIDGETs_STYLE_CPP_
+
+#include "style.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // button style stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void button_style_default::draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long ,
+ const long ,
+ const ustring& name,
+ const bool is_depressed
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ fill_rect(c,rect,rgb_pixel(212,208,200));
+
+ unsigned char red, green, blue;
+ if (enabled)
+ {
+ red = 0;
+ green = 0;
+ blue = 0;
+ }
+ else
+ {
+ red = 128;
+ green = 128;
+ blue = 128;
+ }
+
+ // compute the name length if it hasn't already been computed
+ if (name_width == 0)
+ {
+ unsigned long height;
+ mfont.compute_size(name,name_width,height);
+ }
+
+ // figure out where the name string should appear
+ rectangle name_rect;
+ const unsigned long width = name_width;
+ const unsigned long height = mfont.height();
+ name_rect.set_left((rect.right() + rect.left() - width)/2);
+ name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1);
+ name_rect.set_right(name_rect.left()+width-1);
+ name_rect.set_bottom(name_rect.top()+height);
+
+
+ if (is_depressed)
+ {
+ name_rect.set_left(name_rect.left()+1);
+ name_rect.set_right(name_rect.right()+1);
+ name_rect.set_top(name_rect.top()+1);
+ name_rect.set_bottom(name_rect.bottom()+1);
+
+ mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue));
+
+ draw_button_down(c,rect);
+ }
+ else
+ {
+ mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue));
+
+ // now draw the edge of the button
+ draw_button_up(c,rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle button_style_default::
+ get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const
+ {
+
+ unsigned long width;
+ unsigned long height;
+ mfont.compute_size(name,width,height);
+ name_width = width;
+
+ return rectangle(width+2*padding, height+2*padding);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void button_style_toolbar1::draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ const long radius = 4;
+
+ unsigned char red, green, blue;
+ if (enabled)
+ {
+ red = 0;
+ green = 0;
+ blue = 0;
+
+ long d = 0;
+ if (rect.contains(lastx,lasty))
+ d = -70;
+
+ if (is_depressed)
+ d = 20;
+
+ if (d != 0)
+ {
+ rectangle temp(rect);
+ temp.left()--; temp.top()--; temp.right()++; temp.bottom()++;
+ draw_rounded_rectangle(c, temp, radius, rgb_alpha_pixel(255,255,0,120));
+ temp.left()--; temp.top()--; temp.right()++; temp.bottom()++;
+ draw_rounded_rectangle(c, temp, radius, rgb_alpha_pixel(255,255,0,40));
+ }
+
+ fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(255, 255, 255,120-d),
+ rgb_alpha_pixel(255, 255, 255,0));
+ draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(30,30,30,200));
+ }
+ else
+ {
+ red = 128;
+ green = 128;
+ blue = 128;
+ draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(red,green,blue,210));
+ }
+
+
+ // compute the name length if it hasn't already been computed
+ if (name_width == 0)
+ {
+ unsigned long height;
+ mfont.compute_size(name,name_width,height);
+ }
+
+ // figure out where the name string should appear
+ rectangle name_rect;
+ const unsigned long width = name_width;
+ const unsigned long height = mfont.height();
+ name_rect.set_left((rect.right() + rect.left() - width)/2);
+ name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1);
+ name_rect.set_right(name_rect.left()+width-1);
+ name_rect.set_bottom(name_rect.top()+height);
+
+
+ if (is_depressed)
+ {
+ name_rect.set_left(name_rect.left()+1);
+ name_rect.set_right(name_rect.right()+1);
+ name_rect.set_top(name_rect.top()+1);
+ name_rect.set_bottom(name_rect.bottom()+1);
+
+ mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue));
+
+ }
+ else
+ {
+ mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle button_style_toolbar1::
+ get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const
+ {
+
+ unsigned long width;
+ unsigned long height;
+ mfont.compute_size(name,width,height);
+ name_width = width;
+
+ return rectangle(width+2*padding, height+2*padding);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void button_style_toolbar_icon1::draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& ,
+ const long lastx,
+ const long lasty,
+ const ustring& ,
+ const bool is_depressed
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ const long radius = padding;
+
+ if (enabled)
+ {
+ if (rect.contains(lastx,lasty))
+ {
+ if (is_depressed)
+ {
+ fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(100,100,200,150),
+ rgb_alpha_pixel(50,50,100,100));
+ draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(150,150,30,200));
+ }
+ else
+ {
+ fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(150,150,250,130),
+ rgb_alpha_pixel(100,100,150,90));
+ draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(150,150,30,200));
+ }
+ }
+
+ if (is_depressed)
+ {
+ rectangle img_rect(translate_rect(centered_rect(rect,img_mouseover.nc(),img_mouseover.nr()),1,1));
+ point p(img_rect.left(),img_rect.top());
+ draw_image(c,p,img_mouseover);
+ }
+ else
+ {
+ rectangle img_rect(centered_rect(rect,img_normal.nc(),img_normal.nr()));
+ point p(img_rect.left(),img_rect.top());
+ if (rect.contains(lastx,lasty))
+ draw_image(c,p,img_mouseover);
+ else
+ draw_image(c,p,img_normal);
+ }
+
+ }
+ else
+ {
+ rectangle img_rect(centered_rect(rect,img_normal.nc(),img_normal.nr()));
+ point p(img_rect.left(),img_rect.top());
+ draw_image(c,p,img_disabled);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle button_style_toolbar_icon1::
+ get_min_size (
+ const ustring& ,
+ const font&
+ ) const
+ {
+ return rectangle(img_normal.nc()+2*padding, img_normal.nr()+2*padding);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void button_style_arrow::
+ draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& ,
+ const long ,
+ const long ,
+ const ustring& ,
+ const bool is_depressed
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ fill_rect(c,rect,rgb_pixel(212,208,200));
+
+ const long height = rect.height();
+ const long width = rect.width();
+
+ const long smallest = (width < height) ? width : height;
+
+ const long rows = (smallest+3)/4;
+ const long start = rows + rows/2-1;
+ long dep;
+
+ long tip_x = 0;
+ long tip_y = 0;
+ long wy = 0;
+ long hy = 0;
+ long wx = 0;
+ long hx = 0;
+
+ if (is_depressed)
+ {
+ dep = 0;
+
+ // draw the button's border
+ draw_button_down(c,rect);
+ }
+ else
+ {
+ dep = -1;
+
+ // draw the button's border
+ draw_button_up(c,rect);
+ }
+
+
+ switch (dir)
+ {
+ case UP:
+ tip_x = width/2 + rect.left() + dep;
+ tip_y = (height - start)/2 + rect.top() + dep + 1;
+ wy = 0;
+ hy = 1;
+ wx = 1;
+ hx = 0;
+ break;
+
+ case DOWN:
+ tip_x = width/2 + rect.left() + dep;
+ tip_y = rect.bottom() - (height - start)/2 + dep;
+ wy = 0;
+ hy = -1;
+ wx = 1;
+ hx = 0;
+ break;
+
+ case LEFT:
+ tip_x = rect.left() + (width - start)/2 + dep + 1;
+ tip_y = height/2 + rect.top() + dep;
+ wy = 1;
+ hy = 0;
+ wx = 0;
+ hx = 1;
+ break;
+
+ case RIGHT:
+ tip_x = rect.right() - (width - start)/2 + dep;
+ tip_y = height/2 + rect.top() + dep;
+ wy = 1;
+ hy = 0;
+ wx = 0;
+ hx = -1;
+ break;
+ }
+
+
+ rgb_pixel color;
+ if (enabled)
+ {
+ color.red = 0;
+ color.green = 0;
+ color.blue = 0;
+ }
+ else
+ {
+ color.red = 128;
+ color.green = 128;
+ color.blue = 128;
+ }
+
+
+
+ for (long i = 0; i < rows; ++i)
+ {
+ draw_line(c,point(tip_x + wx*i + hx*i, tip_y + wy*i + hy*i),
+ point(tip_x + wx*i*-1 + hx*i, tip_y + wy*i*-1 + hy*i),
+ color);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // toggle button style stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button_style_default::draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long ,
+ const long ,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ fill_rect(c,rect,rgb_pixel(212,208,200));
+
+ unsigned char red, green, blue;
+ if (enabled)
+ {
+ red = 0;
+ green = 0;
+ blue = 0;
+ }
+ else
+ {
+ red = 128;
+ green = 128;
+ blue = 128;
+ }
+
+ // compute the name length if it hasn't already been computed
+ if (name_width == 0)
+ {
+ unsigned long height;
+ mfont.compute_size(name,name_width,height);
+ }
+
+ // figure out where the name string should appear
+ rectangle name_rect;
+ const unsigned long width = name_width;
+ const unsigned long height = mfont.height();
+ name_rect.set_left((rect.right() + rect.left() - width)/2);
+ name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1);
+ name_rect.set_right(name_rect.left()+width-1);
+ name_rect.set_bottom(name_rect.top()+height);
+
+ long d = 0;
+ if (is_checked)
+ d = 1;
+
+ if (is_depressed)
+ d = 2;
+
+ name_rect.set_left(name_rect.left()+d);
+ name_rect.set_right(name_rect.right()+d);
+ name_rect.set_top(name_rect.top()+d);
+ name_rect.set_bottom(name_rect.bottom()+d);
+
+ mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue));
+
+ // now draw the edge of the button
+ if (is_checked || is_depressed)
+ draw_button_down(c,rect);
+ else
+ draw_button_up(c,rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle toggle_button_style_default::
+ get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const
+ {
+
+ unsigned long width;
+ unsigned long height;
+ mfont.compute_size(name,width,height);
+ name_width = width;
+
+ return rectangle(width+2*padding, height+2*padding);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button_style_check_box::draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long ,
+ const long ,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+
+ rgb_pixel color;
+ if (enabled)
+ {
+ color.red = 0;
+ color.green = 0;
+ color.blue = 0;
+ }
+ else
+ {
+ color.red = 128;
+ color.green = 128;
+ color.blue = 128;
+ }
+
+
+ // figure out where the name string should appear
+ rectangle name_rect, box_rect;
+ unsigned long padding = 0;
+ if (mfont.height() < 13)
+ padding = (rect.height() - mfont.height())/2;
+
+ name_rect = rect;
+ name_rect.set_left(rect.left() + 17-1);
+ name_rect.set_top(rect.top() + padding);
+ name_rect.set_bottom(rect.bottom() - padding);
+
+ box_rect = rect;
+ box_rect.set_right(rect.left() + 12);
+ box_rect.set_bottom(rect.top() + 12);
+
+ mfont.draw_string(c,name_rect,name,color);
+
+ if (enabled && is_depressed == false)
+ fill_rect(c, box_rect,rgb_pixel(255,255,255));
+ else
+ fill_rect(c, box_rect,rgb_pixel(212,208,200));
+
+ draw_sunken_rectangle(c, box_rect);
+
+
+ if (is_checked)
+ {
+ const long x = box_rect.left();
+ const long y = box_rect.top();
+ draw_line(c,point(3+x,5+y),point(6+x,8+y),color);
+ draw_line(c,point(3+x,6+y),point(5+x,8+y),color);
+ draw_line(c,point(3+x,7+y),point(5+x,9+y),color);
+ draw_line(c,point(6+x,6+y),point(9+x,3+y),color);
+ draw_line(c,point(6+x,7+y),point(9+x,4+y),color);
+ draw_line(c,point(6+x,8+y),point(9+x,5+y),color);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle toggle_button_style_check_box::
+ get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const
+ {
+ unsigned long width;
+ unsigned long height;
+ mfont.compute_size(name,width,height);
+
+ if (height < 13)
+ height = 13;
+
+ return rectangle(width + 17 -1, height -1);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button_style_radio_button::draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long ,
+ const long ,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+
+ rgb_pixel color;
+
+ // figure out where the name string should appear
+ rectangle name_rect, box_rect;
+ unsigned long padding = 0;
+ if (mfont.height() < 13)
+ padding = (rect.height() - mfont.height())/2;
+
+ name_rect = rect;
+ name_rect.set_left(rect.left() + 17-1);
+ name_rect.set_top(rect.top() + padding);
+ name_rect.set_bottom(rect.bottom() - padding);
+
+ box_rect = rect;
+ box_rect.set_right(rect.left() + 12);
+ box_rect.set_bottom(rect.top() + 12);
+
+
+ const long x = box_rect.left();
+ const long y = box_rect.top();
+
+ if (enabled && is_depressed == false)
+ draw_solid_circle(c,point(rect.left()+5,rect.top()+5),4.5,rgb_pixel(255,255,255));
+ else
+ draw_solid_circle(c,point(rect.left()+5,rect.top()+5),4.5,rgb_pixel(212,208,200));
+
+
+ color = rgb_pixel(128,128,128);
+ draw_line(c,point(0+x,4+y),point(0+x,7+y),color);
+ draw_line(c,point(1+x,2+y),point(1+x,9+y),color);
+ draw_line(c,point(2+x,1+y),point(9+x,1+y),color);
+ draw_line(c,point(4+x,0+y),point(7+x,0+y),color);
+
+ color = rgb_pixel(255,255,255);
+ draw_line(c,point(4+x,11+y),point(7+x,11+y),color);
+ draw_line(c,point(2+x,10+y),point(9+x,10+y),color);
+ draw_line(c,point(10+x,2+y),point(10+x,9+y),color);
+ draw_line(c,point(11+x,4+y),point(11+x,7+y),color);
+
+ color = rgb_pixel(64,64,64);
+ draw_line(c,point(1+x,4+y),point(1+x,7+y),color);
+ draw_line(c,point(4+x,1+y),point(7+x,1+y),color);
+ draw_pixel(c,point(2+x,3+y),color);
+ draw_pixel(c,point(3+x,2+y),color);
+ draw_pixel(c,point(2+x,2+y),color);
+ draw_pixel(c,point(2+x,8+y),color);
+ draw_pixel(c,point(8+x,2+y),color);
+ draw_pixel(c,point(9+x,2+y),color);
+
+ color = rgb_pixel(212,208,200);
+ draw_line(c,point(4+x,10+y),point(7+x,10+y),color);
+ draw_line(c,point(10+x,4+y),point(10+x,7+y),color);
+ draw_pixel(c,point(3+x,9+y),color);
+ draw_pixel(c,point(9+x,3+y),color);
+
+ if (enabled)
+ {
+ color.red = 0;
+ color.green = 0;
+ color.blue = 0;
+ }
+ else
+ {
+ color.red = 128;
+ color.green = 128;
+ color.blue = 128;
+ }
+
+ mfont.draw_string(c,name_rect,name,color);
+
+ if (is_checked)
+ {
+ draw_line(c,point(5+x,4+y),point(6+x,4+y),color);
+ draw_line(c,point(4+x,5+y),point(7+x,5+y),color);
+ draw_line(c,point(4+x,6+y),point(7+x,6+y),color);
+ draw_line(c,point(5+x,7+y),point(6+x,7+y),color);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle toggle_button_style_radio_button::
+ get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const
+ {
+ unsigned long width;
+ unsigned long height;
+ mfont.compute_size(name,width,height);
+
+ if (height < 13)
+ height = 13;
+
+ return rectangle(width + 17 -1, height -1);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scroll bar style stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ long scroll_bar_style_default::
+ get_slider_length (
+ long total_length,
+ long max_pos
+ ) const
+ {
+ // if the length is too small then we have to smash up the arrow buttons
+ // and hide the slider.
+ if (total_length <= get_width()*2)
+ {
+ return 0;
+ }
+ else
+ {
+ double range = total_length - get_button_length(total_length, max_pos)*2;
+
+ double scale_factor = 30.0/(max_pos + 30.0);
+
+ if (scale_factor < 0.1)
+ scale_factor = 0.1;
+
+
+ double fraction = range/(max_pos + range)*scale_factor;
+ double result = fraction * range;
+ long res = static_cast<long>(result);
+ if (res < 8)
+ res = 8;
+ return res;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long scroll_bar_style_default::
+ get_button_length (
+ long total_length,
+ long
+ ) const
+ {
+ // if the length is too small then we have to smash up the arrow buttons
+ // and hide the slider.
+ if (total_length <= get_width()*2)
+ {
+ return total_length/2;
+ }
+ else
+ {
+ return get_width();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar_style_default::
+ draw_scroll_bar_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool ,
+ const long ,
+ const long ,
+ const bool is_depressed
+ ) const
+ {
+ if (is_depressed)
+ draw_checkered(c, rect,rgb_pixel(0,0,0),rgb_pixel(43,47,55));
+ else
+ draw_checkered(c, rect,rgb_pixel(255,255,255),rgb_pixel(212,208,200));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void scroll_bar_style_default::
+ draw_scroll_bar_slider (
+ const canvas& c,
+ const rectangle& rect,
+ const bool ,
+ const long ,
+ const long ,
+ const bool
+ ) const
+ {
+ fill_rect(c, rect, rgb_pixel(212,208,200));
+ draw_button_up(c, rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_field styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ unsigned long text_field_style_default::
+ get_padding (
+ const font& mfont
+ ) const
+ {
+ return mfont.height()-mfont.ascender();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field_style_default::
+ draw_text_field (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const unsigned long cursor_x,
+ const unsigned long text_pos,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+
+ if (enabled)
+ {
+ // first fill our area with the bg_color
+ fill_rect(c, area,bg_color);
+ }
+ else
+ {
+ // first fill our area with gray
+ fill_rect(c, area,rgb_pixel(212,208,200));
+ }
+
+
+ if (enabled)
+ mfont.draw_string(c,text_rect,text,text_color,text_pos);
+ else
+ mfont.draw_string(c,text_rect,text,rgb_pixel(128,128,128),text_pos);
+
+ // now draw the edge of the text_field
+ draw_sunken_rectangle(c, rect);
+
+ if (highlight_start <= highlight_end && enabled)
+ {
+ rectangle highlight_rect = text_rect;
+ unsigned long left_pad = 0, right_pad = mfont.left_overflow();
+
+ long i;
+ for (i = text_pos; i <= highlight_end; ++i)
+ {
+ if (i == highlight_start)
+ left_pad = right_pad;
+
+ right_pad += mfont[text[i]].width();
+ }
+
+ highlight_rect.set_left(text_rect.left()+left_pad);
+ highlight_rect.set_right(text_rect.left()+right_pad);
+
+ // highlight the highlight_rect area
+ highlight_rect = highlight_rect.intersect(c);
+ for (long row = highlight_rect.top(); row <= highlight_rect.bottom(); ++row)
+ {
+ for (long col = highlight_rect.left(); col <= highlight_rect.right(); ++col)
+ {
+ canvas::pixel& pixel = c[row-c.top()][col-c.left()];
+ if (pixel.red == 255 && pixel.green == 255 && pixel.blue == 255)
+ {
+ // this is a background (and white) pixel so set it to a dark
+ // blueish color.
+ pixel.red = 10;
+ pixel.green = 36;
+ pixel.blue = 106;
+ }
+ else
+ {
+ // this should be a pixel that is part of a letter so set it to white
+ pixel.red = 255;
+ pixel.green = 255;
+ pixel.blue = 255;
+ }
+ }
+ }
+ }
+
+ // now draw the cursor if we need to
+ if (cursor_visible && has_focus && enabled)
+ {
+ const unsigned long top = rect.top()+3;
+ const unsigned long bottom = rect.bottom()-3;
+ draw_line(c, point(rect.left()+cursor_x,top),point(rect.left()+cursor_x,bottom));
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_box styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void text_box_style_default::
+ draw_text_box (
+ const canvas& c,
+ const rectangle& display_rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const rectangle& cursor_rect,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const
+ {
+ rectangle area = display_rect.intersect(c);
+
+ if (enabled)
+ {
+ // first fill our area with the bg_color
+ fill_rect(c, area,bg_color);
+ }
+ else
+ {
+ // first fill our area with gray
+ fill_rect(c, area,rgb_pixel(212,208,200));
+ }
+
+
+ if (enabled)
+ mfont.draw_string(c,text_rect,text,text_color, 0, ustring::npos, area);
+ else
+ mfont.draw_string(c,text_rect,text,rgb_pixel(128,128,128), 0, ustring::npos, area);
+
+
+ // now draw the highlight if there is any
+ if (highlight_start <= highlight_end && enabled)
+ {
+ const rectangle first_pos = mfont.compute_cursor_rect(text_rect, text, highlight_start);
+ const rectangle last_pos = mfont.compute_cursor_rect(text_rect, text, highlight_end+1);
+
+ const rgb_alpha_pixel color(10, 30, 106, 90);
+
+ // if the highlighted text is all on one line
+ if (first_pos.top() == last_pos.top())
+ {
+ fill_rect(c, (first_pos + last_pos).intersect(display_rect), color);
+ }
+ else
+ {
+ const rectangle min_boundary(display_rect.left()+4, display_rect.top()+4,
+ display_rect.right()-4, display_rect.bottom()-4);
+ const rectangle boundary( display_rect.intersect(text_rect) + min_boundary);
+
+ rectangle first_row, last_row, middle_rows;
+ first_row += first_pos;
+ first_row += point(boundary.right(), first_pos.top());
+ last_row += last_pos;
+ last_row += point(boundary.left(), last_pos.bottom());
+
+ middle_rows.left() = boundary.left();
+ middle_rows.right() = boundary.right();
+ middle_rows.top() = first_row.bottom()+1;
+ middle_rows.bottom() = last_row.top()-1;
+
+ fill_rect(c, first_row.intersect(display_rect), color);
+ fill_rect(c, middle_rows, color);
+ fill_rect(c, last_row.intersect(display_rect), color);
+ }
+ }
+
+ // now draw the cursor if we need to
+ if (cursor_visible && has_focus && enabled)
+ {
+ draw_line(c, point(cursor_rect.left(), cursor_rect.top()),point(cursor_rect.left(), cursor_rect.bottom()), 0, area);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_WIDGETs_STYLE_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/style.h b/ml/dlib/dlib/gui_widgets/style.h
new file mode 100644
index 000000000..f31caee30
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/style.h
@@ -0,0 +1,825 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_WIDGETs_STYLE_
+#define DLIB_WIDGETs_STYLE_
+
+#include "../algs.h"
+#include "style_abstract.h"
+#include "../gui_core.h"
+#include "canvas_drawing.h"
+#include <string>
+#include <sstream>
+#include "../unicode.h"
+#include "../array2d.h"
+#include "../pixel.h"
+#include "fonts.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // button styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button_style
+ {
+ public:
+
+ button_style()
+ {
+ }
+
+ virtual ~button_style()
+ {}
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return false; }
+
+ virtual rectangle get_invalidation_rect (
+ const rectangle& rect
+ ) const { return rect; }
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const = 0;
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_default : public button_style
+ {
+ public:
+ button_style_default () : padding(4), name_width(0) {}
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const;
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ private:
+
+ // this is the minimum amount of padding that can separate the name from the
+ // edge of the button
+ const unsigned long padding;
+ // this is the width of the name string
+ mutable unsigned long name_width;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_toolbar1 : public button_style
+ {
+ public:
+ button_style_toolbar1 () : padding(4), name_width(0) {}
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const;
+
+ virtual rectangle get_invalidation_rect (
+ const rectangle& rect
+ ) const
+ {
+ rectangle temp(rect);
+ temp.left() -= 2;
+ temp.top() -= 2;
+ temp.right() += 2;
+ temp.bottom() += 2;
+ return temp;
+ }
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return true; }
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ private:
+
+ // this is the minimum amount of padding that can separate the name from the
+ // edge of the button
+ const unsigned long padding;
+ // this is the width of the name string
+ mutable unsigned long name_width;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_toolbar_icon1 : public button_style
+ {
+ public:
+ template <typename image_type>
+ button_style_toolbar_icon1 (const image_type& img_, unsigned long pad = 6) : padding(pad)
+ {
+ assign_image(img_mouseover,img_);
+ make_images();
+ }
+
+ button_style_toolbar_icon1( const button_style_toolbar_icon1& item): button_style(item), padding(item.padding)
+ {
+ assign_image(img_mouseover, item.img_mouseover);
+ assign_image(img_normal, item.img_normal);
+ assign_image(img_disabled, item.img_disabled);
+ }
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const;
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return true; }
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ private:
+
+ void make_images (
+ )
+ {
+ // make the disabled image grayscale and make both non-mouseover images have weaker alpha channels
+ img_disabled.set_size(img_mouseover.nr(), img_mouseover.nc());
+ img_normal.set_size(img_mouseover.nr(), img_mouseover.nc());
+
+ for (long r = 0; r < img_mouseover.nr(); ++r)
+ {
+ for (long c = 0; c < img_mouseover.nc(); ++c)
+ {
+ rgb_alpha_pixel p = img_mouseover[r][c];
+ long avg = p.red;
+ avg += p.green;
+ avg += p.blue;
+ avg /= 3;
+
+ if (p.alpha > 40)
+ p.alpha -= 40;
+ else
+ p.alpha = 0;
+
+ img_normal[r][c] = p;
+
+ if (p.alpha > 80)
+ p.alpha -= 80;
+ else
+ p.alpha = 0;
+
+ p.red = avg;
+ p.green = avg;
+ p.blue = avg;
+ img_disabled[r][c] = p;
+ }
+ }
+ }
+
+ array2d<rgb_alpha_pixel> img_mouseover;
+ array2d<rgb_alpha_pixel> img_normal;
+ array2d<rgb_alpha_pixel> img_disabled;
+
+ // this is the minimum amount of padding that can separate the name from the
+ // edge of the button
+ const unsigned long padding;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_arrow : public button_style
+ {
+
+ public:
+
+ enum arrow_direction
+ {
+ UP,
+ DOWN,
+ LEFT,
+ RIGHT
+ };
+
+ button_style_arrow (
+ arrow_direction dir_
+ ) : dir(dir_) {}
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const;
+
+ virtual rectangle get_min_size (
+ const ustring& ,
+ const font&
+ ) const { return rectangle(); }
+
+ private:
+ arrow_direction dir;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // toggle button styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style
+ {
+ public:
+
+ toggle_button_style()
+ {
+ }
+
+ virtual ~toggle_button_style()
+ {}
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return false; }
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const = 0;
+
+ virtual void draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_default : public toggle_button_style
+ {
+ public:
+ toggle_button_style_default () : padding(4), name_width(0) {}
+
+ virtual void draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const;
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ private:
+
+ // this is the minimum amount of padding that can separate the name from the
+ // edge of the button
+ const unsigned long padding;
+ // this is the width of the name string
+ mutable unsigned long name_width;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_check_box : public toggle_button_style
+ {
+ public:
+ virtual void draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const;
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_radio_button : public toggle_button_style
+ {
+ public:
+ virtual void draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const;
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scroll_bar styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar_style
+ {
+ public:
+
+ virtual ~scroll_bar_style() {}
+
+ virtual bool redraw_on_mouse_over_slider (
+ ) const { return false; }
+
+ virtual long get_width (
+ ) const = 0;
+
+ virtual long get_slider_length (
+ long total_length,
+ long max_pos
+ ) const = 0;
+
+ virtual long get_button_length (
+ long total_length,
+ long max_pos
+ ) const = 0;
+
+ virtual void draw_scroll_bar_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_depressed
+ ) const = 0;
+
+ virtual void draw_scroll_bar_slider (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_being_dragged
+ ) const = 0;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar_style_default : public scroll_bar_style
+ {
+ public:
+ button_style_arrow get_up_button_style (
+ ) const { return button_style_arrow(button_style_arrow::UP); }
+
+ button_style_arrow get_down_button_style (
+ ) const { return button_style_arrow(button_style_arrow::DOWN); }
+
+ button_style_arrow get_left_button_style (
+ ) const { return button_style_arrow(button_style_arrow::LEFT); }
+
+ button_style_arrow get_right_button_style (
+ ) const { return button_style_arrow(button_style_arrow::RIGHT); }
+
+ virtual long get_width (
+ ) const { return 16; }
+
+ virtual long get_slider_length (
+ long total_length,
+ long max_pos
+ ) const;
+
+ virtual long get_button_length (
+ long total_length,
+ long max_pos
+ ) const;
+
+ virtual void draw_scroll_bar_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_depressed
+ ) const;
+
+ virtual void draw_scroll_bar_slider (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_being_dragged
+ ) const;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scrollable_region styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region_style
+ {
+ public:
+
+ virtual ~scrollable_region_style() {}
+
+ virtual long get_border_size (
+ ) const = 0;
+
+ virtual void draw_scrollable_region_border (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled
+ ) const = 0;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region_style_default : public scrollable_region_style
+ {
+ public:
+ scroll_bar_style_default get_horizontal_scroll_bar_style (
+ ) const { return scroll_bar_style_default(); }
+
+ scroll_bar_style_default get_vertical_scroll_bar_style (
+ ) const { return scroll_bar_style_default(); }
+
+ virtual long get_border_size (
+ ) const { return 2; }
+
+ virtual void draw_scrollable_region_border (
+ const canvas& c,
+ const rectangle& rect,
+ const bool
+ ) const { draw_sunken_rectangle(c,rect); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // list_box styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class list_box_style
+ {
+ public:
+
+ virtual ~list_box_style() {}
+
+ virtual void draw_list_box_background (
+ const canvas& c,
+ const rectangle& display_rect,
+ const bool enabled
+ ) const = 0;
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::string& text,
+ const bool is_selected
+ ) const = 0;
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::wstring& text,
+ const bool is_selected
+ ) const = 0;
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const bool is_selected
+ ) const = 0;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class list_box_style_default : public list_box_style
+ {
+ public:
+ scrollable_region_style_default get_scrollable_region_style (
+ ) const { return scrollable_region_style_default(); }
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::string& text,
+ const bool is_selected
+ ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); }
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::wstring& text,
+ const bool is_selected
+ ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); }
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const bool is_selected
+ ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); }
+
+ template <typename string_type>
+ void draw_list_box_item_template (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const string_type& text,
+ const bool is_selected
+ ) const
+ {
+ if (is_selected)
+ {
+ if (enabled)
+ fill_rect_with_vertical_gradient(c,rect,rgb_pixel(110,160,255), rgb_pixel(100,130,250),display_rect);
+ else
+ fill_rect_with_vertical_gradient(c,rect,rgb_pixel(140,190,255), rgb_pixel(130,160,250),display_rect);
+ }
+
+ if (enabled)
+ mfont.draw_string(c,rect,text,rgb_pixel(0,0,0),0,std::string::npos,display_rect);
+ else
+ mfont.draw_string(c,rect,text,rgb_pixel(128,128,128),0,std::string::npos,display_rect);
+ }
+
+ virtual void draw_list_box_background (
+ const canvas& c,
+ const rectangle& display_rect,
+ const bool enabled
+ ) const
+ {
+ if (enabled)
+ {
+ // first fill our area with white
+ fill_rect(c, display_rect,rgb_pixel(255,255,255));
+ }
+ else
+ {
+ // first fill our area with gray
+ fill_rect(c, display_rect,rgb_pixel(212,208,200));
+ }
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_box styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_box_style
+ {
+ public:
+
+ text_box_style()
+ {
+ }
+
+ virtual ~text_box_style()
+ {}
+
+ virtual unsigned long get_padding (
+ const font& mfont
+ ) const = 0;
+
+ virtual void draw_text_box (
+ const canvas& c,
+ const rectangle& display_rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const rectangle& cursor_rect,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class text_box_style_default : public text_box_style
+ {
+ public:
+
+ text_box_style_default()
+ {
+ }
+
+ scrollable_region_style_default get_scrollable_region_style (
+ ) const { return scrollable_region_style_default(); }
+
+ virtual ~text_box_style_default()
+ {}
+
+ virtual unsigned long get_padding (
+ const font&
+ ) const { return 1; }
+
+ virtual void draw_text_box (
+ const canvas& c,
+ const rectangle& display_rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const rectangle& cursor_rect,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_field styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_field_style
+ {
+ public:
+
+ text_field_style()
+ {
+ }
+
+ virtual ~text_field_style()
+ {}
+
+ virtual unsigned long get_padding (
+ const font& mfont
+ ) const = 0;
+
+ virtual void draw_text_field (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const unsigned long cursor_x,
+ const unsigned long text_pos,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class text_field_style_default : public text_field_style
+ {
+ public:
+
+ text_field_style_default()
+ {
+ }
+
+ virtual ~text_field_style_default()
+ {}
+
+ virtual unsigned long get_padding (
+ const font& mfont
+ ) const;
+
+ virtual void draw_text_field (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const unsigned long cursor_x,
+ const unsigned long text_pos,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "style.cpp"
+#endif
+
+#endif // DLIB_WIDGETs_STYLE_
+
+
diff --git a/ml/dlib/dlib/gui_widgets/style_abstract.h b/ml/dlib/dlib/gui_widgets/style_abstract.h
new file mode 100644
index 000000000..e4d3245df
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/style_abstract.h
@@ -0,0 +1,777 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_WIDGETs_STYLE_ABSTRACT_
+#ifdef DLIB_WIDGETs_STYLE_ABSTRACT_
+
+#include "../algs.h"
+#include "../gui_core.h"
+#include "widgets_abstract.h"
+#include "../unicode/unicode_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // button styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class button_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ button style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+
+ public:
+
+ virtual ~button_style() {}
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return false; }
+ /*!
+ ensures
+ - if (this style draws buttons differently when a mouse is over them) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ virtual rectangle get_invalidation_rect (
+ const rectangle& rect
+ ) const { return rect; }
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - rect == the get_rect() that defines where the button is
+ ensures
+ - returns a rectangle that should be invalidated whenever a button
+ needs to redraw itself. (e.g. If you wanted your button style to
+ draw outside the button then you could return a larger rectangle)
+ !*/
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns a rectangle that represents the minimum size of the button
+ given the name and font.
+ !*/
+
+ virtual void draw_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect, enabled, mfont, lastx, and lasty are the variables
+ defined in the protected section of the drawable class.
+ - name == the name of the button to be drawn
+ - is_depressed == true if the button is to be drawn in a depressed state
+ ensures
+ - draws the button on the canvas c at the location given by rect.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_default : public button_style
+ {
+ /*!
+ This is the default style for button objects. It will cause
+ a button to appear as the simple MS Windows 2000 button style.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_toolbar1 : public button_style
+ {
+ /*!
+ This draws a simple toolbar style button that displays its name in the
+ middle of itself. When the mouse moves over it it will light up.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_toolbar_icon1 : public button_style
+ {
+ /*!
+ This draws a simple toolbar style button that displays an image in the
+ middle of itself. When the mouse moves over it it will light up.
+ !*/
+ template <typename image_type>
+ button_style_toolbar_icon1 (
+ const image_type& img,
+ unsigned long border_size = 6
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_type::type> is defined
+ ensures
+ - displays image img in the middle of the button
+ - the distance between the edge of the button and the image
+ will be border_size pixels
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class button_style_arrow : public button_style
+ {
+ public:
+ /*!
+ This draws a simple button with an arrow in it
+ !*/
+
+ enum arrow_direction
+ {
+ UP,
+ DOWN,
+ LEFT,
+ RIGHT
+ };
+
+ button_style_arrow (
+ arrow_direction dir
+ );
+ /*!
+ ensures
+ - the arrow in the button will point in the given direction
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // toggle button styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ toggle button style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+
+ public:
+
+ virtual ~toggle_button_style() {}
+
+ virtual bool redraw_on_mouse_over (
+ ) const { return false; }
+ /*!
+ ensures
+ - if (this style draws buttons differently when a mouse is over them) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ virtual rectangle get_min_size (
+ const ustring& name,
+ const font& mfont
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns a rectangle that represents the minimum size of the button
+ given the name and font.
+ !*/
+
+ virtual void draw_toggle_button (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const font& mfont,
+ const long lastx,
+ const long lasty,
+ const ustring& name,
+ const bool is_depressed,
+ const bool is_checked
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect, enabled, mfont, lastx, and lasty are the variables
+ defined in the protected section of the drawable class.
+ - name == the name of the button to be drawn
+ - is_depressed == true if the button is to be drawn in a depressed state
+ - is_checked == true if the toggle_button is in the checked state
+ ensures
+ - draws the button on the canvas c at the location given by rect.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_default : public toggle_button_style
+ {
+ /*!
+ This is the default style for toggle_button objects. It will cause
+ a button to appear as the simple MS Windows 2000 button style.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_check_box : public toggle_button_style
+ {
+ /*!
+ This draws a simple check box style toggle button that displays its
+ name to the right of a check box.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button_style_radio_button : public toggle_button_style
+ {
+ /*!
+ This draws a simple radio button style toggle button that displays its
+ name to the right of a circular radio button.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scroll_bar styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ scroll_bar style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+
+ There are three parts of a scroll bar, the slider, the background,
+ and the two buttons on its ends. The "slider" is the thing that you
+ drag around on the scroll bar and the "background" is the part
+ in between the slider and the buttons on the ends.
+ !*/
+
+ public:
+
+ virtual ~scroll_bar_style() {}
+
+ virtual bool redraw_on_mouse_over_slider (
+ ) const { return false; }
+ /*!
+ ensures
+ - if (this style draws a scroll_bar's slider differently when a mouse is over it
+ or it is being dragged) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ virtual long get_width (
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns the width in pixels of the scroll bar
+ !*/
+
+ virtual long get_slider_length (
+ long total_length,
+ long max_pos
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - total_length == the total length in pixels of the scroll bar
+ - max_pos == the value of scroll_bar::max_slider_pos() for this
+ scroll bar
+ ensures
+ - returns the length in pixels of the scroll bar's slider
+ !*/
+
+ virtual long get_button_length (
+ long total_length,
+ long max_pos
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - total_length == the total length in pixels of the scroll bar
+ - max_pos == the value of scroll_bar::max_slider_pos() for this
+ scroll bar
+ ensures
+ - returns the length in pixels of each of the scroll bar's
+ buttons
+ !*/
+
+ virtual void draw_scroll_bar_background (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_depressed
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect, enabled, lastx, and lasty are the variables
+ defined in the protected section of the drawable class.
+ - is_depressed == true if the background area of the scroll_bar is to
+ be drawn in a depressed state (because the user is clicking on it)
+ ensures
+ - draws the background part of the scroll_bar on the canvas c at the
+ location given by rect.
+ !*/
+
+ virtual void draw_scroll_bar_slider (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled,
+ const long lastx,
+ const long lasty,
+ const bool is_being_dragged
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect, enabled, lastx, and lasty are the variables
+ defined in the protected section of the drawable class
+ - is_being_dragged == true if the user is dragging the slider
+ ensures
+ - draws the slider part of the scroll_bar on the canvas c at the
+ location given by rect.
+ !*/
+
+ button_style_type get_up_button_style (
+ ) const;
+ /*!
+ ensures
+ - returns the type of button_style to use for a button on the
+ top side of a vertical scroll bar.
+ !*/
+
+ button_style_type get_down_button_style (
+ ) const;
+ /*!
+ ensures
+ - returns the type of button_style to use for a button on the
+ bottom side of a vertical scroll bar.
+ !*/
+
+ button_style_type get_left_button_style (
+ ) const;
+ /*!
+ ensures
+ - returns the type of button_style to use for a button on the
+ left side of a horizontal scroll bar.
+ !*/
+
+ button_style_type get_right_button_style (
+ ) const;
+ /*!
+ ensures
+ - returns the type of button_style to use for a button on the
+ right side of a horizontal scroll bar.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scroll_bar_style_default : public scroll_bar_style
+ {
+ /*!
+ This is the default style for scroll_bar objects. It will cause
+ a scroll_bar to appear as the simple MS Windows 2000 scroll_bar style.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // scrollable_region (and zoomable_region) styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ scrollable_region and zoomable_region style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+ public:
+
+ virtual ~scrollable_region_style() {}
+
+ virtual long get_border_size (
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns the size of the border region in pixels
+ !*/
+
+ virtual void draw_scrollable_region_border (
+ const canvas& c,
+ const rectangle& rect,
+ const bool enabled
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect and enabled are the variables defined in the protected section
+ of the drawable class.
+ ensures
+ - draws the border part of a scrollable_region on the canvas c at the
+ location given by rect.
+ !*/
+
+ scroll_bar_style_type get_horizontal_scroll_bar_style (
+ ) const;
+ /*!
+ ensures
+ - returns the style of scroll_bar to use for the
+ horizontal scroll_bar in this widget.
+ !*/
+
+ scroll_bar_style_type get_vertical_scroll_bar_style (
+ ) const;
+ /*!
+ ensures
+ - returns the style of scroll_bar to use for the
+ vertical scroll_bar in this widget.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class scrollable_region_style_default : public scrollable_region_style
+ {
+ public:
+ /*!
+ This is the default style for scrollable_region and zoomable_region objects.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_box styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_box_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ text_box style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+ public:
+
+ virtual ~text_field_style() {}
+
+ scrollable_region_style_type get_scrollable_region_style (
+ ) const;
+ /*!
+ ensures
+ - returns the style of scrollable_region to use for the
+ text_box.
+ !*/
+
+ virtual unsigned long get_padding (
+ const font& mfont
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns the number of pixels that separate the text in the text_box
+ from the edge of the text_box widget itself.
+ !*/
+
+ virtual void draw_text_box (
+ const canvas& c,
+ const rectangle& display_rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const rectangle& cursor_rect,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - enabled and mfont are the variables defined in the protected section
+ - text_rect == the rectangle in which we should draw the given text
+ of the drawable class.
+ - display_rect == the rectangle returned by scrollable_region::display_rect()
+ - text == the current text in the text_box
+ - cursor_rect == A rectangle of width 1 that represents the current
+ position of the cursor on the screen.
+ - text_color == the color of the text to be drawn
+ - bg_color == the background color of the text field
+ - has_focus == true if this text field has keyboard input focus
+ - cursor_visible == true if the cursor should be drawn
+ - if (highlight_start <= highlight_end) then
+ - text[highlight_start] though text[highlight_end] should be
+ highlighted
+ ensures
+ - draws the text_box on the canvas c at the location given by text_rect.
+ (Note that the scroll bars and borders are drawn by the scrollable_region
+ and therefore the style returned by get_scrollable_region_style()
+ controls how those appear)
+ - doesn't draw anything outside display_rect
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class text_box_style_default : public text_box_style
+ {
+ public:
+ /*!
+ This is the default style for text_box objects.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // list_box styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class list_box_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ list_box style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+ public:
+
+ virtual ~list_box_style() {}
+
+ virtual void draw_list_box_background (
+ const canvas& c,
+ const rectangle& display_rect,
+ const bool enabled
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - display_rect == the display_rect for the list_box. This is the area
+ in which list box items are drawn (see display_rect in the scrollable_region
+ widget for more info)
+ - enabled == true if the list box is enabled
+ ensures
+ - draws the background of a list box on the canvas c at the location given
+ by display_rect.
+ !*/
+
+ scrollable_region_style_type get_scrollable_region_style (
+ ) const;
+ /*!
+ ensures
+ - returns the style of scrollable_region to use for the
+ list_box.
+ !*/
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::string& text,
+ const bool is_selected
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect == the rectangle that defines where on the screen this list box item is.
+ - display_rect == the display_rect for the list_box. This is the area
+ in which list box items are drawn (see display_rect in the scrollable_region
+ widget for more info)
+ - mfont == the font to use to draw the list box item
+ - text == the text of the list box item to be drawn
+ - enabled == true if the list box is enabled
+ - is_selected == true if the item is to be drawn in a selected state
+ ensures
+ - draws the list box item on the canvas c at the location given by rect.
+ !*/
+
+ // wide character overloads
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const std::wstring& text,
+ const bool is_selected
+ ) const = 0;
+
+ virtual void draw_list_box_item (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& display_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const bool is_selected
+ ) const = 0;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class list_box_style_default : public list_box_style
+ {
+ public:
+ /*!
+ This is the default style for list_box objects.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_field styles
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_field_style
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an abstract class that defines the interface a
+ text_field style object must implement.
+
+ Note that derived classes must be copyable via
+ their copy constructors.
+ !*/
+ public:
+
+ virtual ~text_field_style() {}
+
+ virtual unsigned long get_padding (
+ const font& mfont
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ ensures
+ - returns the number of pixels that separate the text in the text_field
+ from the edge of the text_field widget itself.
+ !*/
+
+ virtual void draw_text_field (
+ const canvas& c,
+ const rectangle& rect,
+ const rectangle& text_rect,
+ const bool enabled,
+ const font& mfont,
+ const ustring& text,
+ const unsigned long cursor_x,
+ const unsigned long text_pos,
+ const rgb_pixel& text_color,
+ const rgb_pixel& bg_color,
+ const bool has_focus,
+ const bool cursor_visible,
+ const long highlight_start,
+ const long highlight_end
+ ) const = 0;
+ /*!
+ requires
+ - the mutex drawable::m is locked
+ - c == the canvas to draw on
+ - rect, enabled, and mfont are the variables defined in the protected section
+ of the drawable class.
+ - text == the current text in the text_field
+ - text_rect == the rectangle in which we should draw the given text
+ - cursor_x == the x coordinate of the cursor relative to the left side
+ of rect. i.e. the number of pixels that separate the cursor from the
+ left side of the text_field.
+ - text_pos == the index of the first letter in text that appears in
+ this text field.
+ - text_color == the color of the text to be drawn
+ - bg_color == the background color of the text field
+ - has_focus == true if this text field has keyboard input focus
+ - cursor_visible == true if the cursor should be drawn
+ - if (highlight_start <= highlight_end) then
+ - text[highlight_start] though text[highlight_end] should be
+ highlighted
+ ensures
+ - draws the text_field on the canvas c at the location given by rect.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class text_field_style_default : public text_field_style
+ {
+ public:
+ /*!
+ This is the default style for text_field objects.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_WIDGETs_STYLE_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/gui_widgets/widgets.cpp b/ml/dlib/dlib/gui_widgets/widgets.cpp
new file mode 100644
index 000000000..c460d946d
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/widgets.cpp
@@ -0,0 +1,7341 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_WIDGETs_CPP_
+#define DLIB_WIDGETs_CPP_
+
+#include <algorithm>
+#include <memory>
+
+#include "widgets.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // toggle_button object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle min_rect = style->get_min_size(name_,*mfont);
+ // only change the size if it isn't going to be too small to fit the name
+ if (height >= min_rect.height() &&
+ width >= min_rect.width())
+ {
+ rectangle old(rect);
+ rect = resize_rect(rect,width,height);
+ parent.invalidate_rectangle(rect+old);
+ btn_tooltip.set_size(width,height);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_checked (
+ )
+ {
+ auto_mutex M(m);
+ checked = true;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_unchecked (
+ )
+ {
+ auto_mutex M(m);
+ checked = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool toggle_button::
+ is_checked (
+ ) const
+ {
+ auto_mutex M(m);
+ return checked;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ show (
+ )
+ {
+ button_action::show();
+ btn_tooltip.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ hide (
+ )
+ {
+ button_action::hide();
+ btn_tooltip.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ enable (
+ )
+ {
+ button_action::enable();
+ btn_tooltip.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ disable (
+ )
+ {
+ button_action::disable();
+ btn_tooltip.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_tooltip_text (
+ const std::string& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+ void toggle_button::
+ set_tooltip_text (
+ const std::wstring& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+ void toggle_button::
+ set_tooltip_text (
+ const dlib::ustring& text
+ )
+ {
+ btn_tooltip.set_text(text);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string toggle_button::
+ tooltip_text (
+ ) const
+ {
+ return btn_tooltip.text();
+ }
+
+ const std::wstring toggle_button::
+ tooltip_wtext (
+ ) const
+ {
+ return btn_tooltip.wtext();
+ }
+
+ const dlib::ustring toggle_button::
+ tooltip_utext (
+ ) const
+ {
+ return btn_tooltip.utext();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ set_name(name_);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ button_action::set_pos(x,y);
+ btn_tooltip.set_pos(x,y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ set_name (
+ const std::string& name
+ )
+ {
+ set_name(convert_mbstring_to_wstring(name));
+ }
+
+ void toggle_button::
+ set_name (
+ const std::wstring& name
+ )
+ {
+ set_name(convert_wstring_to_utf32(name));
+ }
+
+ void toggle_button::
+ set_name (
+ const dlib::ustring& name
+ )
+ {
+ auto_mutex M(m);
+ name_ = name;
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ name_[0] = name_[0];
+
+ rectangle old(rect);
+ rect = move_rect(style->get_min_size(name,*mfont),rect.left(),rect.top());
+ btn_tooltip.set_size(rect.width(),rect.height());
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string toggle_button::
+ name (
+ ) const
+ {
+ return convert_wstring_to_mbstring(wname());
+ }
+
+ const std::wstring toggle_button::
+ wname (
+ ) const
+ {
+ return convert_utf32_to_wstring(uname());
+ }
+
+ const dlib::ustring toggle_button::
+ uname (
+ ) const
+ {
+ auto_mutex M(m);
+ dlib::ustring temp = name_;
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ temp[0] = name_[0];
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void toggle_button::
+ on_button_up (
+ bool mouse_over
+ )
+ {
+ if (mouse_over)
+ {
+ checked = !checked;
+ // this is a valid toggle_button click
+ if (event_handler.is_set())
+ event_handler();
+ else if (event_handler_self.is_set())
+ event_handler_self(*this);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // label object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void label::
+ draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty() || text_.size() == 0)
+ return;
+
+ using namespace std;
+ unsigned char r = text_color_.red;
+ unsigned char g = text_color_.green;
+ unsigned char b = text_color_.blue;
+ if (!enabled)
+ {
+ r = 128;
+ g = 128;
+ b = 128;
+ }
+
+ rectangle text_rect(rect);
+
+ string::size_type first, last;
+ first = 0;
+ last = text_.find_first_of('\n');
+ mfont->draw_string(c,text_rect,text_,rgb_pixel(r,g,b),first,last);
+
+ while (last != string::npos)
+ {
+ first = last+1;
+ last = text_.find_first_of('\n',first);
+ text_rect.set_top(text_rect.top()+mfont->height());
+ mfont->draw_string(c,text_rect,text_,rgb_pixel(r,g,b),first,last);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void label::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ set_text(text_);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ void label::
+ set_text (
+ const std::string& text
+ )
+ {
+ set_text(convert_mbstring_to_wstring(text));
+ }
+
+ void label::
+ set_text (
+ const std::wstring& text
+ )
+ {
+ set_text(convert_wstring_to_utf32(text));
+ }
+
+ void label::
+ set_text (
+ const dlib::ustring& text
+ )
+ {
+ using namespace std;
+ auto_mutex M(m);
+ text_ = text;
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ text_[0] = text[0];
+
+ rectangle old(rect);
+
+ unsigned long width;
+ unsigned long height;
+ mfont->compute_size(text,width,height);
+
+ rect.set_right(rect.left() + width - 1);
+ rect.set_bottom(rect.top() + height - 1);
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string label::
+ text (
+ ) const
+ {
+ return convert_wstring_to_mbstring(wtext());
+ }
+
+ const std::wstring label::
+ wtext (
+ ) const
+ {
+ return convert_utf32_to_wstring(utext());
+ }
+
+ const dlib::ustring label::
+ utext (
+ ) const
+ {
+ auto_mutex M(m);
+ dlib::ustring temp = text_;
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ temp[0] = text_[0];
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void label::
+ set_text_color (
+ const rgb_pixel color
+ )
+ {
+ m.lock();
+ text_color_ = color;
+ parent.invalidate_rectangle(rect);
+ m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel label::
+ text_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return text_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_field object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ rectangle text_field::
+ get_text_rect (
+ ) const
+ {
+ // figure out where the text string should appear
+ unsigned long vertical_pad = (rect.height() - mfont->height())/2+1;
+
+ rectangle text_rect;
+ text_rect.set_left(rect.left()+style->get_padding(*mfont));
+ text_rect.set_top(rect.top()+vertical_pad);
+ text_rect.set_right(rect.right()-style->get_padding(*mfont));
+ text_rect.set_bottom(text_rect.top()+mfont->height()-1);
+ return text_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ enable (
+ )
+ {
+ drawable::enable();
+ right_click_menu.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ give_input_focus (
+ )
+ {
+ auto_mutex M(m);
+ has_focus = true;
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ t.start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool text_field::
+ has_input_focus (
+ ) const
+ {
+ auto_mutex M(m);
+ return has_focus;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ select_all_text (
+ )
+ {
+ auto_mutex M(m);
+ on_select_all();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_cut (
+ )
+ {
+ on_copy();
+ on_delete_selected();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_copy (
+ )
+ {
+ if (highlight_start <= highlight_end)
+ {
+ put_on_clipboard(text_.substr(highlight_start, highlight_end-highlight_start+1));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_paste (
+ )
+ {
+ ustring temp_str;
+ get_from_clipboard(temp_str);
+
+ // If this is a multi line string then just take the first line.
+ ustring::size_type pos = temp_str.find_first_of('\n');
+ if (pos != ustring::npos)
+ {
+ temp_str = temp_str.substr(0,pos);
+ }
+
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + temp_str +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+ move_cursor(highlight_start+temp_str.size());
+ highlight_start = 0;
+ highlight_end = -1;
+ parent.invalidate_rectangle(rect);
+ on_no_text_selected();
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + temp_str +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+ move_cursor(cursor_pos+temp_str.size());
+
+ // send out the text modified event
+ if (temp_str.size() != 0 && text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_select_all (
+ )
+ {
+ move_cursor(static_cast<long>(text_.size()));
+ highlight_start = 0;
+ highlight_end = static_cast<long>(text_.size()-1);
+ if (highlight_start <= highlight_end)
+ on_text_is_selected();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_delete_selected (
+ )
+ {
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.erase(highlight_start,highlight_end-highlight_start+1);
+ move_cursor(highlight_start);
+ highlight_start = 0;
+ highlight_end = -1;
+
+ on_no_text_selected();
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_text_is_selected (
+ )
+ {
+ right_click_menu.menu().enable_menu_item(0);
+ right_click_menu.menu().enable_menu_item(1);
+ right_click_menu.menu().enable_menu_item(3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_no_text_selected (
+ )
+ {
+ right_click_menu.menu().disable_menu_item(0);
+ right_click_menu.menu().disable_menu_item(1);
+ right_click_menu.menu().disable_menu_item(3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ show (
+ )
+ {
+ drawable::show();
+ right_click_menu.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ drawable::disable();
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ right_click_menu.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ drawable::hide();
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ // adjust the height of this text field so that it is appropriate for the current
+ // font size
+ rect.set_bottom(rect.top() + mfont->height()+ (style->get_padding(*mfont))*2);
+ set_text(text_);
+ right_click_menu.set_rect(get_text_rect());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ style->draw_text_field(c,rect,get_text_rect(), enabled, *mfont, text_, cursor_x, text_pos,
+ text_color_, bg_color_, has_focus, cursor_visible, highlight_start,
+ highlight_end);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_text (
+ const std::string& text
+ )
+ {
+ set_text(convert_mbstring_to_wstring(text));
+ }
+
+ void text_field::
+ set_text (
+ const std::wstring& text
+ )
+ {
+ set_text(convert_wstring_to_utf32(text));
+ }
+
+ void text_field::
+ set_text (
+ const dlib::ustring& text
+ )
+ {
+ DLIB_ASSERT ( text.find_first_of('\n') == std::string::npos ,
+ "\tvoid text_field::set_text()"
+ << "\n\ttext: " << narrow(text) );
+ auto_mutex M(m);
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ text_ = text.c_str();
+
+ move_cursor(0);
+
+ highlight_start = 0;
+ highlight_end = -1;
+
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string text_field::
+ text (
+ ) const
+ {
+ std::string temp = convert_wstring_to_mbstring(wtext());
+ return temp;
+ }
+
+ const std::wstring text_field::
+ wtext (
+ ) const
+ {
+ std::wstring temp = convert_utf32_to_wstring(utext());
+ return temp;
+ }
+
+ const dlib::ustring text_field::
+ utext (
+ ) const
+ {
+ auto_mutex M(m);
+ // do this to get rid of any reference counting that may be present in
+ // the dlib::ustring implementation.
+ dlib::ustring temp = text_.c_str();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_width (
+ unsigned long width
+ )
+ {
+ auto_mutex M(m);
+ if (width < style->get_padding(*mfont)*2)
+ return;
+
+ rectangle old(rect);
+
+ rect.set_right(rect.left() + width - 1);
+
+ right_click_menu.set_rect(get_text_rect());
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ drawable::set_pos(x,y);
+ right_click_menu.set_rect(get_text_rect());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_background_color (
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ bg_color_ = color;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_field::
+ background_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return bg_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ set_text_color (
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ text_color_ = color;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_field::
+ text_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return text_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (!enabled || hidden || !has_focus)
+ {
+ return;
+ }
+
+ if (state & base_window::LEFT)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+
+ unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y,text_pos);
+ if (static_cast<long>(new_pos) != cursor_pos)
+ {
+ move_cursor(new_pos);
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (shift_pos != -1)
+ {
+ shift_pos = -1;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long ,
+ long
+ )
+ {
+ if (!enabled || hidden)
+ return;
+
+ if (btn == base_window::LEFT)
+ shift_pos = -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool double_clicked
+ )
+ {
+ using namespace std;
+ if (!enabled || hidden || btn != (unsigned long)base_window::LEFT)
+ return;
+
+ if (rect.contains(x,y))
+ {
+ has_focus = true;
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ t.start();
+
+ if (double_clicked)
+ {
+ // highlight the double clicked word
+ string::size_type first, last;
+ const ustring ustr = convert_utf8_to_utf32(std::string(" \t\n"));
+ first = text_.substr(0,cursor_pos).find_last_of(ustr.c_str());
+ last = text_.find_first_of(ustr.c_str(),cursor_pos);
+ long f = static_cast<long>(first);
+ long l = static_cast<long>(last);
+ if (first == string::npos)
+ f = -1;
+ if (last == string::npos)
+ l = static_cast<long>(text_.size());
+
+ ++f;
+ --l;
+
+ move_cursor(l+1);
+ highlight_start = f;
+ highlight_end = l;
+ on_text_is_selected();
+ }
+ else
+ {
+ if (state & base_window::SHIFT)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+ else
+ {
+ shift_pos = cursor_pos;
+ }
+ }
+
+ bool at_end = false;
+ if (cursor_pos == 0 || cursor_pos == static_cast<long>(text_.size()))
+ at_end = true;
+ const long old_pos = cursor_pos;
+
+ unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y,text_pos);
+ if (static_cast<long>(new_pos) != cursor_pos)
+ {
+ move_cursor(new_pos);
+ parent.invalidate_rectangle(rect);
+ }
+ shift_pos = cursor_pos;
+
+ if (at_end && cursor_pos == old_pos)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+ }
+ else if (has_focus)
+ {
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ shift_pos = -1;
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+
+ if (focus_lost_handler.is_set())
+ focus_lost_handler();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ // If the right click menu is up then we don't want to do anything with
+ // the keyboard ourselves. Let the popup menu use the keyboard for now.
+ if (right_click_menu.popup_menu_visible())
+ return;
+
+ const ustring space_str = convert_utf8_to_utf32(std::string(" \t\n"));
+ const bool shift = (state&base_window::KBD_MOD_SHIFT) != 0;
+ const bool ctrl = (state&base_window::KBD_MOD_CONTROL) != 0;
+ if (has_focus && enabled && !hidden)
+ {
+ if (shift && is_printable == false)
+ {
+ if (shift_pos == -1)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+ else
+ {
+ shift_pos = cursor_pos;
+ }
+ }
+ }
+ else
+ {
+ shift_pos = -1;
+ }
+
+ if (key == base_window::KEY_LEFT ||
+ key == base_window::KEY_UP)
+ {
+ if (cursor_pos != 0)
+ {
+ unsigned long new_pos;
+ if (ctrl)
+ {
+ // find the first non-whitespace to our left
+ std::string::size_type pos = text_.find_last_not_of(space_str.c_str(),cursor_pos);
+ if (pos != std::string::npos)
+ {
+ pos = text_.find_last_of(space_str.c_str(),pos);
+ if (pos != std::string::npos)
+ new_pos = static_cast<unsigned long>(pos);
+ else
+ new_pos = 0;
+ }
+ else
+ {
+ new_pos = 0;
+ }
+ }
+ else
+ {
+ new_pos = cursor_pos-1;
+ }
+
+ move_cursor(new_pos);
+ }
+ else if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+
+ }
+ else if (key == base_window::KEY_RIGHT ||
+ key == base_window::KEY_DOWN)
+ {
+ if (cursor_pos != static_cast<long>(text_.size()))
+ {
+ unsigned long new_pos;
+ if (ctrl)
+ {
+ // find the first non-whitespace to our left
+ std::string::size_type pos = text_.find_first_not_of(space_str.c_str(),cursor_pos);
+ if (pos != std::string::npos)
+ {
+ pos = text_.find_first_of(space_str.c_str(),pos);
+ if (pos != std::string::npos)
+ new_pos = static_cast<unsigned long>(pos+1);
+ else
+ new_pos = static_cast<unsigned long>(text_.size());
+ }
+ else
+ {
+ new_pos = static_cast<unsigned long>(text_.size());
+ }
+ }
+ else
+ {
+ new_pos = cursor_pos+1;
+ }
+
+ move_cursor(new_pos);
+ }
+ else if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (is_printable)
+ {
+ if (ctrl)
+ {
+ if (key == 'a')
+ {
+ on_select_all();
+ }
+ else if (key == 'c')
+ {
+ on_copy();
+ }
+ else if (key == 'v')
+ {
+ on_paste();
+ }
+ else if (key == 'x')
+ {
+ on_cut();
+ }
+ }
+ else if (key != '\n')
+ {
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + static_cast<unichar>(key) +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+ move_cursor(highlight_start+1);
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + static_cast<unichar>(key) +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+ move_cursor(cursor_pos+1);
+ }
+ unsigned long height;
+ mfont->compute_size(text_,text_width,height,text_pos);
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else if (key == '\n')
+ {
+ if (enter_key_handler.is_set())
+ enter_key_handler();
+ }
+ }
+ else if (key == base_window::KEY_BACKSPACE)
+ {
+ // if something is highlighted then delete that
+ if (highlight_start <= highlight_end)
+ {
+ on_delete_selected();
+ }
+ else if (cursor_pos != 0)
+ {
+ text_ = text_.erase(cursor_pos-1,1);
+ move_cursor(cursor_pos-1);
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ // do this just so it repaints itself right
+ move_cursor(cursor_pos);
+ }
+ unsigned long height;
+ mfont->compute_size(text_,text_width,height,text_pos);
+ parent.invalidate_rectangle(rect);
+ }
+ else if (key == base_window::KEY_DELETE)
+ {
+ // if something is highlighted then delete that
+ if (highlight_start <= highlight_end)
+ {
+ on_delete_selected();
+ }
+ else if (cursor_pos != static_cast<long>(text_.size()))
+ {
+ text_ = text_.erase(cursor_pos,1);
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ // do this just so it repaints itself right
+ move_cursor(cursor_pos);
+ }
+ parent.invalidate_rectangle(rect);
+
+ unsigned long height;
+ mfont->compute_size(text_,text_width,height,text_pos);
+ }
+ else if (key == base_window::KEY_HOME)
+ {
+ move_cursor(0);
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (key == base_window::KEY_END)
+ {
+ move_cursor(static_cast<unsigned long>(text_.size()));
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ cursor_visible = true;
+ recent_movement = true;
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ on_string_put(
+ const std::wstring &str
+ )
+ {
+ if (has_focus && enabled && !hidden){
+ ustring ustr = convert_wstring_to_utf32(str);
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + ustr +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+ move_cursor(highlight_start+ustr.size());
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + ustr +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+ move_cursor(cursor_pos+ustr.size());
+ }
+ unsigned long height;
+ mfont->compute_size(text_,text_width,height,text_pos);
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_field::
+ move_cursor (
+ unsigned long pos
+ )
+ {
+ using namespace std;
+ const long old_cursor_pos = cursor_pos;
+
+ if (text_pos >= pos)
+ {
+ // the cursor should go all the way to the left side of the text
+ if (pos >= 6)
+ text_pos = pos-6;
+ else
+ text_pos = 0;
+
+ cursor_pos = pos;
+ unsigned long height;
+ mfont->compute_size(text_,text_width,height,text_pos);
+
+ unsigned long width;
+ unsigned long new_x = style->get_padding(*mfont);
+ if (static_cast<long>(cursor_pos)-1 >= static_cast<long>(text_pos))
+ {
+ mfont->compute_size(text_,width,height,text_pos,cursor_pos-1);
+ if (cursor_pos != 0)
+ new_x += width - mfont->right_overflow();
+ }
+
+ cursor_x = new_x;
+ }
+ else
+ {
+ unsigned long height;
+ unsigned long width;
+ mfont->compute_size(text_,width,height,text_pos,pos-1);
+
+ unsigned long new_x = style->get_padding(*mfont) +
+ width - mfont->right_overflow();
+
+ // move the text to the left if necessary
+ if (new_x + 4 > rect.width())
+ {
+ while (new_x > rect.width() - rect.width()/5)
+ {
+ new_x -= (*mfont)[text_[text_pos]].width();
+ ++text_pos;
+ }
+ }
+
+ cursor_x = new_x;
+ cursor_pos = pos;
+ mfont->compute_size(text_,text_width,height,text_pos);
+ }
+
+ parent.set_im_pos(rect.left()+cursor_x, rect.top());
+
+ if (old_cursor_pos != cursor_pos)
+ {
+ if (shift_pos != -1)
+ {
+ highlight_start = std::min(shift_pos,cursor_pos);
+ highlight_end = std::max(shift_pos,cursor_pos)-1;
+ }
+ else
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ }
+
+ if (highlight_start > highlight_end)
+ on_no_text_selected();
+ else
+ on_text_is_selected();
+
+ recent_movement = true;
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// tabbed_display object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ tabbed_display::
+ tabbed_display(
+ drawable_window& w
+ ) :
+ drawable(w,MOUSE_CLICK),
+ selected_tab_(0),
+ left_pad(6),
+ right_pad(4),
+ top_pad(3),
+ bottom_pad(3)
+ {
+ rect = rectangle(0,0,40,mfont->height()+top_pad+bottom_pad);
+ enable_events();
+ tabs.set_max_size(1);
+ tabs.set_size(1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ tabbed_display::
+ ~tabbed_display(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex M(m);
+ // we have to adjust the positions of all the tab rectangles
+ const long xdelta = rect.left() - x;
+ const long ydelta = rect.top() - y;
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ tabs[i].rect.set_left(tabs[i].rect.left()+xdelta);
+ tabs[i].rect.set_right(tabs[i].rect.right()+xdelta);
+
+ tabs[i].rect.set_top(tabs[i].rect.top()+ydelta);
+ tabs[i].rect.set_bottom(tabs[i].rect.bottom()+ydelta);
+
+
+ // adjust the position of the group associated with this tab if it exists
+ if (tabs[i].group)
+ tabs[i].group->set_pos(x+3, y+mfont->height()+top_pad+bottom_pad+3);
+ }
+ drawable::set_pos(x,y);
+ recompute_tabs();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ fit_to_contents (
+ )
+ {
+ auto_mutex M(m);
+ rectangle new_rect;
+ point p(rect.left(),rect.top());
+ new_rect += p;
+
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ if (tabs[i].group)
+ {
+ tabs[i].group->fit_to_contents();
+ new_rect += tabs[i].group->get_rect();
+ }
+ }
+
+ // and give the new rect an additional 4 pixels on the bottom and right sides
+ // so that the contents to hit the edge of the tabbed display
+ new_rect = resize_rect(new_rect, new_rect.width()+4, new_rect.height()+4);
+
+ parent.invalidate_rectangle(new_rect+rect);
+ rect = new_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle old(rect);
+ const long x = rect.left();
+ const long y = rect.top();
+ rect.set_right(x+width-1);
+ rect.set_bottom(y+height-1);
+
+ recompute_tabs();
+
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_number_of_tabs (
+ unsigned long num
+ )
+ {
+ auto_mutex M(m);
+
+ DLIB_ASSERT ( num > 0 ,
+ "\tvoid tabbed_display::set_number_of_tabs()"
+ << "\n\tnum: " << num );
+
+ tabs.set_max_size(num);
+ tabs.set_size(num);
+
+ selected_tab_ = 0;
+
+ recompute_tabs();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long tabbed_display::
+ selected_tab (
+ ) const
+ {
+ auto_mutex M(m);
+ return selected_tab_;
+ }
+
+ unsigned long tabbed_display::
+ number_of_tabs (
+ ) const
+ {
+ auto_mutex M(m);
+ return tabs.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tabbed_display::
+ tab_name (
+ unsigned long idx
+ ) const
+ {
+ return convert_wstring_to_mbstring(tab_wname(idx));
+ }
+
+ const std::wstring tabbed_display::
+ tab_wname (
+ unsigned long idx
+ ) const
+ {
+ return convert_utf32_to_wstring(tab_uname(idx));
+ }
+
+ const dlib::ustring& tabbed_display::
+ tab_uname (
+ unsigned long idx
+ ) const
+ {
+ auto_mutex M(m);
+
+ DLIB_ASSERT ( idx < number_of_tabs() ,
+ "\tvoid tabbed_display::tab_name()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_tabs(): " << number_of_tabs() );
+
+ return tabs[idx].name;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_tab_name (
+ unsigned long idx,
+ const std::string& new_name
+ )
+ {
+ set_tab_name(idx, convert_mbstring_to_wstring(new_name));
+ }
+
+ void tabbed_display::
+ set_tab_name (
+ unsigned long idx,
+ const std::wstring& new_name
+ )
+ {
+ set_tab_name(idx, convert_wstring_to_utf32(new_name));
+ }
+
+ void tabbed_display::
+ set_tab_name (
+ unsigned long idx,
+ const dlib::ustring& new_name
+ )
+ {
+ auto_mutex M(m);
+
+
+ DLIB_ASSERT ( idx < number_of_tabs() ,
+ "\tvoid tabbed_display::set_tab_name()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_tabs(): " << number_of_tabs() );
+
+
+ tabs[idx].name = new_name;
+ // do this so that there isn't any reference counting going on
+ tabs[idx].name[0] = tabs[idx].name[0];
+ unsigned long height;
+ mfont->compute_size(new_name,tabs[idx].width,height);
+
+
+ recompute_tabs();
+
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long,
+ long x,
+ long y,
+ bool
+ )
+ {
+ if (rect.contains(x,y) && btn == base_window::LEFT && enabled && !hidden)
+ {
+ rectangle temp = rect;
+ const long offset = mfont->height() + bottom_pad + top_pad;
+ temp.set_bottom(rect.top()+offset);
+ if (temp.contains(x,y))
+ {
+ // now we have to figure out which tab was clicked
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ if (selected_tab_ != i && tabs[i].rect.contains(x,y) &&
+ tabs[selected_tab_].rect.contains(x,y) == false)
+ {
+ unsigned long old_idx = selected_tab_;
+ selected_tab_ = i;
+ recompute_tabs();
+ parent.invalidate_rectangle(temp);
+
+ // adjust the widget_group objects for these tabs if they exist
+ if (tabs[i].group)
+ tabs[i].group->show();
+ if (tabs[old_idx].group)
+ tabs[old_idx].group->hide();
+
+ if (event_handler.is_set())
+ event_handler(i,old_idx);
+ break;
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_tab_group (
+ unsigned long idx,
+ widget_group& group
+ )
+ {
+ auto_mutex M(m);
+
+ DLIB_ASSERT ( idx < number_of_tabs() ,
+ "\tvoid tabbed_display::set_tab_group()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_tabs(): " << number_of_tabs() );
+
+
+ tabs[idx].group = &group;
+ group.set_pos(rect.left()+3,rect.top()+mfont->height()+top_pad+bottom_pad+2);
+ if (idx == selected_tab_)
+ group.show();
+ else
+ group.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ if (tabs[selected_tab_].group)
+ tabs[selected_tab_].group->disable();
+ drawable::disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ enable (
+ )
+ {
+ auto_mutex M(m);
+ if (tabs[selected_tab_].group)
+ tabs[selected_tab_].group->enable();
+ drawable::enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ if (tabs[selected_tab_].group)
+ tabs[selected_tab_].group->hide();
+ drawable::hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ show (
+ )
+ {
+ auto_mutex M(m);
+ if (tabs[selected_tab_].group)
+ tabs[selected_tab_].group->show();
+ drawable::show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ // draw the main border first
+ rectangle main_box(rect.left(),rect.top()+mfont->height()+top_pad+bottom_pad,rect.right(),rect.bottom());
+ draw_button_up(c,main_box);
+ draw_pixel(c,point(main_box.right()-1,main_box.top()),rgb_pixel(128,128,128));
+
+ rgb_pixel color;
+ if (enabled)
+ {
+ color.red = 0;
+ color.green = 0;
+ color.blue = 0;
+ }
+ else
+ {
+ color.red = 128;
+ color.green = 128;
+ color.blue = 128;
+ }
+
+ // draw the tabs
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ if (selected_tab_ != i)
+ draw_tab(tabs[i].rect,c);
+
+ // draw the name string
+ rectangle temp = tabs[i].rect;
+ temp.set_top(temp.top()+top_pad);
+ temp.set_bottom(temp.bottom()+bottom_pad);
+ temp.set_left(temp.left()+left_pad);
+ temp.set_right(temp.right()+right_pad);
+ mfont->draw_string(c,temp,tabs[i].name,color);
+ }
+ draw_tab(tabs[selected_tab_].rect,c);
+ draw_line(c,
+ point(tabs[selected_tab_].rect.left()+1,
+ tabs[selected_tab_].rect.bottom()),
+ point(tabs[selected_tab_].rect.right()-2,
+ tabs[selected_tab_].rect.bottom()),
+ rgb_pixel(212,208,200));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ draw_tab (
+ const rectangle& tab,
+ const canvas& c
+ ) const
+ {
+ const rgb_pixel white(255,255,255);
+ const rgb_pixel background(212,208,200);
+ const rgb_pixel dark_gray(64,64,64);
+ const rgb_pixel gray(128,128,128);
+ draw_line(c,point(tab.left(),tab.top()+2),point(tab.left(),tab.bottom()),white);
+ draw_line(c,point(tab.left()+1,tab.top()+2),point(tab.left()+1,tab.bottom()),background);
+ draw_line(c,point(tab.right(),tab.top()+2),point(tab.right(),tab.bottom()),dark_gray);
+ draw_line(c,point(tab.right()-1,tab.top()+2),point(tab.right()-1,tab.bottom()),gray);
+ draw_line(c,point(tab.left()+2,tab.top()),point(tab.right()-2,tab.top()),white);
+ draw_pixel(c,point(tab.left()+1,tab.top()+1),white);
+ draw_pixel(c,point(tab.right()-1,tab.top()+1),dark_gray);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ unsigned long height;
+ mfont->compute_size(tabs[i].name,tabs[i].width,height);
+ }
+
+ recompute_tabs();
+ set_pos(rect.left(), rect.top());
+
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tabbed_display::
+ recompute_tabs (
+ )
+ {
+ const long offset = mfont->height() + bottom_pad + top_pad;
+
+
+ // figure out the size and position of all the tabs
+ rectangle sel_tab_rect, other_tab;
+ sel_tab_rect.set_top(rect.top());
+ sel_tab_rect.set_bottom(rect.top()+offset);
+
+ other_tab.set_top(rect.top()+2);
+ other_tab.set_bottom(rect.top()+offset-1);
+
+ long cur_x = rect.left();
+ for (unsigned long i = 0; i < tabs.size(); ++i)
+ {
+ const unsigned long str_width = tabs[i].width;
+ if (selected_tab_ != i)
+ {
+ other_tab.set_left(cur_x);
+ cur_x += left_pad + str_width + right_pad;
+ other_tab.set_right(cur_x);
+ tabs[i].rect = other_tab;
+ ++cur_x;
+
+ }
+ else
+ {
+ if (i != 0)
+ sel_tab_rect.set_left(cur_x-2);
+ else
+ sel_tab_rect.set_left(cur_x);
+
+ cur_x += left_pad + str_width + right_pad;
+
+ if (i != tabs.size()-1)
+ sel_tab_rect.set_right(cur_x+2);
+ else
+ sel_tab_rect.set_right(cur_x);
+ ++cur_x;
+
+ tabs[i].rect = sel_tab_rect;
+ }
+ }
+
+ // make sure this object is wide enough
+ const rectangle& last = tabs[tabs.size()-1].rect;
+ const rectangle& first = tabs[0].rect;
+ rect = last + rect + first;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// named_rectangle object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ named_rectangle::
+ named_rectangle(
+ drawable_window& w
+ ) :
+ drawable(w),
+ name_width(0),
+ name_height(0)
+ {
+ make_name_fit_in_rect();
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ named_rectangle::
+ ~named_rectangle(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ rectangle old(rect);
+ const long x = rect.left();
+ const long y = rect.top();
+ rect.set_right(x+width-1);
+ rect.set_bottom(y+height-1);
+
+ make_name_fit_in_rect();
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ wrap_around (
+ const rectangle& r
+ )
+ {
+ auto_mutex M(m);
+ rectangle old(rect);
+ const unsigned long pad = name_height/2;
+
+ rect = rectangle(r.left()-pad, r.top()-name_height*4/3, r.right()+pad, r.bottom()+pad);
+
+ make_name_fit_in_rect();
+ parent.invalidate_rectangle(rect+old);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ mfont->compute_size(name_,name_width,name_height);
+ make_name_fit_in_rect();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ make_name_fit_in_rect (
+ )
+ {
+ // make sure the named rectangle is big enough to contain the name
+ const unsigned long wtemp = mfont->height() + name_width;
+ const unsigned long htemp = mfont->height() + name_height;
+ if (rect.width() < wtemp)
+ rect.set_right(rect.left() + wtemp - 1 );
+ if (rect.height() < htemp)
+ rect.set_bottom(rect.bottom() + htemp - 1 );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ set_name (
+ const std::string& name
+ )
+ {
+ set_name(convert_mbstring_to_wstring(name));
+ }
+
+ void named_rectangle::
+ set_name (
+ const std::wstring& name
+ )
+ {
+ set_name(convert_wstring_to_utf32(name));
+ }
+
+ void named_rectangle::
+ set_name (
+ const dlib::ustring& name
+ )
+ {
+ auto_mutex M(m);
+ name_ = name.c_str();
+ mfont->compute_size(name_,name_width,name_height);
+
+ make_name_fit_in_rect();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string named_rectangle::
+ name (
+ ) const
+ {
+ return convert_wstring_to_mbstring(wname());
+ }
+
+ const std::wstring named_rectangle::
+ wname (
+ ) const
+ {
+ return convert_utf32_to_wstring(uname());
+ }
+
+ const dlib::ustring named_rectangle::
+ uname (
+ ) const
+ {
+ auto_mutex M(m);
+ return dlib::ustring(name_.c_str());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void named_rectangle::
+ draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ const unsigned long gap = mfont->height()/2;
+ rectangle strrect = rect;
+ strrect.set_left(rect.left() + gap);
+
+ const unsigned long rtop = rect.top() + name_height/2;
+
+ const rgb_pixel white(255,255,255);
+ const rgb_pixel gray(128,128,128);
+
+ mfont->draw_string(c,strrect,name_);
+ draw_line(c,point(rect.left(), rtop),
+ point(rect.left()+gap/2, rtop), gray);
+ draw_line(c,point(rect.left(), rtop),
+ point(rect.left(), rect.bottom()-1), gray);
+ draw_line(c,point(rect.left(), rect.bottom()-1),
+ point(rect.right()-1, rect.bottom()-1), gray);
+ draw_line(c,point(rect.right()-1, rtop),
+ point(rect.right()-1, rect.bottom()-2), gray);
+ draw_line(c,point(strrect.left() + name_width + 2, rtop),
+ point(rect.right()-1, rtop), gray);
+
+ draw_line(c,point(strrect.left() + name_width + 2, rtop+1),
+ point( rect.right()-2, rtop+1), white);
+ draw_line(c,point(rect.right(), rtop),
+ point(rect.right(), rect.bottom()), white);
+ draw_line(c,point(rect.left(), rect.bottom()),
+ point(rect.right(), rect.bottom()), white);
+ draw_line(c,point(rect.left()+1, rtop+1),
+ point(rect.left()+1, rect.bottom()-2), white);
+ draw_line(c,point(rect.left()+1, rtop+1),
+ point(rect.left()+gap/2, rtop+1), white);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class mouse_tracker
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ mouse_tracker::
+ mouse_tracker(
+ drawable_window& w
+ ) :
+ draggable(w),
+ offset(18),
+ nr(w),
+ x_label(w),
+ y_label(w),
+ click_x(-1),
+ click_y(-1)
+ {
+ set_draggable_area(rectangle(0,0,500,500));
+
+
+ x_label.set_text("x: ");
+ y_label.set_text("y: ");
+ nr.set_name("mouse position");
+
+
+ x_label.set_pos(offset,offset);
+ y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3);
+
+ nr.wrap_around(x_label.get_rect() + y_label.get_rect());
+ rect = nr.get_rect();
+
+ set_z_order(2000000000);
+ x_label.set_z_order(2000000001);
+ y_label.set_z_order(2000000001);
+ nr.set_z_order(2000000001);
+
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ mouse_tracker::
+ ~mouse_tracker(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ nr.set_main_font(f);
+ x_label.set_main_font(f);
+ y_label.set_main_font(f);
+ mfont = f;
+ nr.wrap_around(x_label.get_rect() + y_label.get_rect());
+ rect = nr.get_rect();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ draggable::set_pos(x,y);
+ nr.set_pos(x,y);
+ x_label.set_pos(rect.left()+offset,rect.top()+offset);
+ y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ show (
+ )
+ {
+ draggable::show();
+ nr.show();
+ x_label.show();
+ y_label.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ hide (
+ )
+ {
+ draggable::hide();
+ nr.hide();
+ x_label.hide();
+ y_label.hide();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ enable (
+ )
+ {
+ draggable::enable();
+ nr.enable();
+ x_label.enable();
+ y_label.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ disable (
+ )
+ {
+ draggable::disable();
+ nr.disable();
+ x_label.disable();
+ y_label.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool double_clicked
+ )
+ {
+ draggable::on_mouse_down(btn,state,x,y,double_clicked);
+ if ((state & base_window::SHIFT) && (btn == base_window::LEFT) && enabled && !hidden)
+ {
+ parent.invalidate_rectangle(rectangle(x,y,x,y));
+ parent.invalidate_rectangle(rectangle(click_x,click_y,click_x,click_y));
+ click_x = x;
+ click_y = y;
+
+ y_label.set_text("y: 0");
+ x_label.set_text("x: 0");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (!hidden && enabled)
+ {
+ parent.invalidate_rectangle(rect);
+ draggable::on_mouse_move(state,x,y);
+
+ long dx = 0;
+ long dy = 0;
+ if (click_x != -1)
+ dx = click_x;
+ if (click_y != -1)
+ dy = click_y;
+
+ sout.str("");
+ sout << "y: " << y - dy;
+ y_label.set_text(sout.str());
+
+ sout.str("");
+ sout << "x: " << x - dx;
+ x_label.set_text(sout.str());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ on_drag (
+ )
+ {
+ nr.set_pos(rect.left(),rect.top());
+ x_label.set_pos(rect.left()+offset,rect.top()+offset);
+ y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3);
+
+ long x = 0;
+ long y = 0;
+ if (click_x != -1)
+ x = click_x;
+ if (click_y != -1)
+ y = click_y;
+
+ sout.str("");
+ sout << "y: " << lasty - y;
+ y_label.set_text(sout.str());
+
+ sout.str("");
+ sout << "x: " << lastx - x;
+ x_label.set_text(sout.str());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void mouse_tracker::
+ draw (
+ const canvas& c
+ ) const
+ {
+ fill_rect(c, rect,rgb_pixel(212,208,200));
+ draw_pixel(c, point(click_x,click_y),rgb_pixel(255,0,0));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class list_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace list_box_helper{
+ template <typename S>
+ list_box<S>::
+ list_box(
+ drawable_window& w
+ ) :
+ scrollable_region(w,MOUSE_WHEEL|MOUSE_CLICK),
+ ms_enabled(false),
+ last_selected(0)
+ {
+ set_vertical_scroll_increment(mfont->height());
+ set_horizontal_scroll_increment(mfont->height());
+
+ style.reset(new list_box_style_default());
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ list_box<S>::
+ ~list_box(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ // recompute the sizes of all the items
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ mfont->compute_size(items[i].name,items[i].width, items[i].height);
+ }
+ set_vertical_scroll_increment(mfont->height());
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ bool list_box<S>::
+ is_selected (
+ unsigned long index
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( index < size() ,
+ "\tbool list_box::is_selected(index)"
+ << "\n\tindex: " << index
+ << "\n\tsize(): " << size() );
+
+ return items[index].is_selected;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ select (
+ unsigned long index
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( index < size() ,
+ "\tvoid list_box::select(index)"
+ << "\n\tindex: " << index
+ << "\n\tsize(): " << size() );
+
+ last_selected = index;
+ items[index].is_selected = true;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ unselect (
+ unsigned long index
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( index < size() ,
+ "\tvoid list_box::unselect(index)"
+ << "\n\tindex: " << index
+ << "\n\tsize(): " << size() );
+ items[index].is_selected = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const S& list_box<S>::operator [] (
+ unsigned long index
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( index < size() ,
+ "\tconst std::string& list_box::operator[](index)"
+ << "\n\tindex: " << index
+ << "\n\tsize(): " << size() );
+ return items[index].name;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ bool list_box<S>::
+ multiple_select_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return ms_enabled;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ enable_multiple_select (
+ )
+ {
+ auto_mutex M(m);
+ ms_enabled = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ disable_multiple_select (
+ )
+ {
+ auto_mutex M(m);
+ ms_enabled = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ bool list_box<S>::
+ at_start (
+ ) const
+ {
+ auto_mutex M(m);
+ return items.at_start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ reset (
+ ) const
+ {
+ auto_mutex M(m);
+ items.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ bool list_box<S>::
+ current_element_valid (
+ ) const
+ {
+ auto_mutex M(m);
+ return items.current_element_valid();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const S &list_box<S>::
+ element (
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( current_element_valid() ,
+ "\tconst std::string& list_box::element()"
+ );
+ return items.element().name;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const S &list_box<S>::
+ element (
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( current_element_valid() ,
+ "\tconst std::string& list_box::element()"
+ );
+ return items.element().name;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ bool list_box<S>::
+ move_next (
+ ) const
+ {
+ auto_mutex M(m);
+ return items.move_next();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ size_t list_box<S>::
+ size (
+ ) const
+ {
+ auto_mutex M(m);
+ return items.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ draw (
+ const canvas& c
+ ) const
+ {
+ scrollable_region::draw(c);
+
+ rectangle area = display_rect().intersect(c);
+ if (area.is_empty())
+ return;
+
+ style->draw_list_box_background(c, display_rect(), enabled);
+
+ long y = total_rect().top();
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (y+(long)items[i].height <= area.top())
+ {
+ y += items[i].height;
+ continue;
+ }
+
+ rectangle r(total_rect().left(), y, display_rect().right(), y+items[i].height-1);
+
+ style->draw_list_box_item(c,r, display_rect(), enabled, *mfont, items[i].name, items[i].is_selected);
+
+
+ y += items[i].height;
+
+ if (y > area.bottom())
+ break;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ void list_box<S>::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ if (display_rect().contains(x,y) && btn == base_window::LEFT && enabled && !hidden )
+ {
+ if ( ms_enabled == false ||
+ ((!(state&base_window::CONTROL)) && !(state&base_window::SHIFT)))
+ {
+ items.reset();
+ while (items.move_next())
+ {
+ items.element().is_selected = false;
+ }
+ }
+
+ y -= total_rect().top();
+ long h = 0;
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ h += items[i].height;
+ if (h >= y)
+ {
+ if (ms_enabled)
+ {
+ if (state&base_window::CONTROL)
+ {
+ items[i].is_selected = !items[i].is_selected;
+ if (items[i].is_selected)
+ last_selected = i;
+ }
+ else if (state&base_window::SHIFT)
+ {
+ // we want to select everything between (and including) the
+ // current thing clicked and last_selected.
+ const unsigned long first = std::min(i,last_selected);
+ const unsigned long last = std::max(i,last_selected);
+ for (unsigned long j = first; j <= last; ++j)
+ items[j].is_selected = true;
+ }
+ else
+ {
+ items[i].is_selected = true;
+ last_selected = i;
+ if (is_double_click && event_handler.is_set())
+ event_handler(i);
+ else if (single_click_event_handler.is_set())
+ single_click_event_handler(i);
+ }
+ }
+ else
+ {
+ items[i].is_selected = true;
+ last_selected = i;
+ if (is_double_click && event_handler.is_set())
+ event_handler(i);
+ else if (single_click_event_handler.is_set())
+ single_click_event_handler(i);
+ }
+
+ break;
+ }
+ }
+
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ unsigned long list_box<S>::
+ get_selected (
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( multiple_select_enabled() == false,
+ "\tunsigned long list_box::get_selected()"
+ );
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (items[i].is_selected)
+ return i;
+ }
+ return items.size();
+ }
+// ----------------------------------------------------------------------------------------
+
+ // making instance of template
+ template class list_box<std::string>;
+ template class list_box<std::wstring>;
+ template class list_box<dlib::ustring>;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function message_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace message_box_helper
+ {
+ void box_win::
+ initialize (
+ )
+ {
+ msg.set_pos(20,20);
+ msg.set_text(message);
+ rectangle msg_rect = msg.get_rect();
+ btn_ok.set_name("OK");
+ btn_ok.set_size(60,btn_ok.height());
+ if (msg_rect.width() >= 60)
+ btn_ok.set_pos(msg_rect.width()/2+msg_rect.left()-btn_ok.width()/2,msg_rect.bottom()+15);
+ else
+ btn_ok.set_pos(20,msg_rect.bottom()+15);
+ btn_ok.set_click_handler(*this,&box_win::on_click);
+
+ rectangle size = btn_ok.get_rect() + msg_rect;
+ set_size(size.right()+20,size.bottom()+20);
+
+
+ show();
+ set_title(title);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ box_win::
+ box_win (
+ const std::string& title_,
+ const std::string& message_
+ ) :
+ drawable_window(false),
+ title(convert_mbstring_to_wstring(title_)),
+ message(convert_mbstring_to_wstring(message_)),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ box_win::
+ box_win (
+ const std::wstring& title_,
+ const std::wstring& message_
+ ) :
+ drawable_window(false),
+ title(title_),
+ message(message_),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ box_win::
+ box_win (
+ const dlib::ustring& title_,
+ const dlib::ustring& message_
+ ) :
+ drawable_window(false),
+ title(convert_utf32_to_wstring(title_)),
+ message(convert_utf32_to_wstring(message_)),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ box_win::
+ ~box_win (
+ )
+ {
+ close_window();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ deleter_thread (
+ void* param
+ )
+ {
+ // The point of this extra event_handler stuff is to allow the user
+ // to end the program from within the callback. So we want to destroy the
+ // window *before* we call their callback.
+ box_win& w = *static_cast<box_win*>(param);
+ w.close_window();
+ any_function<void()> event_handler(w.event_handler);
+ delete &w;
+ if (event_handler.is_set())
+ event_handler();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_click (
+ )
+ {
+ hide();
+ create_new_thread(&deleter_thread,this);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ base_window::on_close_return_code box_win::
+ on_window_close (
+ )
+ {
+ // The point of this extra event_handler stuff is to allow the user
+ // to end the program within the callback. So we want to destroy the
+ // window *before* we call their callback.
+ any_function<void()> event_handler_copy(event_handler);
+ delete this;
+ if (event_handler_copy.is_set())
+ event_handler_copy();
+ return CLOSE_WINDOW;
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ void blocking_box_win::
+ initialize (
+ )
+ {
+ msg.set_pos(20,20);
+ msg.set_text(message);
+ rectangle msg_rect = msg.get_rect();
+ btn_ok.set_name("OK");
+ btn_ok.set_size(60,btn_ok.height());
+ if (msg_rect.width() >= 60)
+ btn_ok.set_pos(msg_rect.width()/2+msg_rect.left()-btn_ok.width()/2,msg_rect.bottom()+15);
+ else
+ btn_ok.set_pos(20,msg_rect.bottom()+15);
+ btn_ok.set_click_handler(*this,&blocking_box_win::on_click);
+
+ rectangle size = btn_ok.get_rect() + msg_rect;
+ set_size(size.right()+20,size.bottom()+20);
+
+
+ set_title(title);
+ show();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ blocking_box_win::
+ blocking_box_win (
+ const std::string& title_,
+ const std::string& message_
+ ) :
+ drawable_window(false),
+ title(convert_mbstring_to_wstring(title_)),
+ message(convert_mbstring_to_wstring(message_)),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ blocking_box_win::
+ blocking_box_win (
+ const std::wstring& title_,
+ const std::wstring& message_
+ ) :
+ drawable_window(false),
+ title(title_),
+ message(message_),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ blocking_box_win::
+ blocking_box_win (
+ const dlib::ustring& title_,
+ const dlib::ustring& message_
+ ) :
+ drawable_window(false),
+ title(convert_utf32_to_wstring(title_)),
+ message(convert_utf32_to_wstring(message_)),
+ msg(*this),
+ btn_ok(*this)
+ {
+ initialize();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ blocking_box_win::
+ ~blocking_box_win (
+ )
+ {
+ close_window();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void blocking_box_win::
+ on_click (
+ )
+ {
+ close_window();
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function open_file_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace open_file_box_helper
+ {
+ box_win::
+ box_win (
+ const std::string& title,
+ bool has_text_field
+ ) :
+ lbl_dirs(*this),
+ lbl_files(*this),
+ lbl_file_name(*this),
+ lb_dirs(*this),
+ lb_files(*this),
+ btn_ok(*this),
+ btn_cancel(*this),
+ btn_root(*this),
+ tf_file_name(*this)
+ {
+ if (has_text_field == false)
+ {
+ tf_file_name.hide();
+ lbl_file_name.hide();
+ }
+ else
+ {
+ lbl_file_name.set_text("File: ");
+ }
+
+ cur_dir = -1;
+ set_size(500,300);
+
+ lbl_dirs.set_text("Directories:");
+ lbl_files.set_text("Files:");
+ btn_ok.set_name("Ok");
+ btn_cancel.set_name("Cancel");
+ btn_root.set_name("/");
+
+ btn_root.set_click_handler(*this,&box_win::on_root_click);
+ btn_cancel.set_click_handler(*this,&box_win::on_cancel_click);
+ btn_ok.set_click_handler(*this,&box_win::on_open_click);
+ lb_dirs.set_double_click_handler(*this,&box_win::on_dirs_click);
+ lb_files.set_click_handler(*this,&box_win::on_files_click);
+ lb_files.set_double_click_handler(*this,&box_win::on_files_double_click);
+
+
+ btn_root.set_pos(5,5);
+
+ set_sizes();
+ set_title(title);
+
+ on_root_click();
+
+ // make it so that the file box starts out in our current working
+ // directory
+ std::string full_name(get_current_dir());
+
+ while (full_name.size() > 0)
+ {
+ std::string::size_type pos = full_name.find_first_of("\\/");
+ std::string left(full_name.substr(0,pos));
+ if (pos != std::string::npos)
+ full_name = full_name.substr(pos+1);
+ else
+ full_name.clear();
+
+ if (left.size() > 0)
+ enter_folder(left);
+ }
+
+
+ show();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ box_win::
+ ~box_win (
+ )
+ {
+ close_window();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ set_sizes(
+ )
+ {
+ unsigned long width, height;
+ get_size(width,height);
+
+
+ if (lbl_file_name.is_hidden())
+ {
+ lbl_dirs.set_pos(0,btn_root.bottom()+5);
+ lb_dirs.set_pos(0,lbl_dirs.bottom());
+ lb_dirs.set_size(width/2,height-lb_dirs.top()-btn_cancel.height()-10);
+
+ lbl_files.set_pos(lb_dirs.right(),btn_root.bottom()+5);
+ lb_files.set_pos(lb_dirs.right(),lbl_files.bottom());
+ lb_files.set_size(width-lb_files.left(),height-lb_files.top()-btn_cancel.height()-10);
+
+ btn_ok.set_pos(width - btn_ok.width()-25,lb_files.bottom()+5);
+ btn_cancel.set_pos(btn_ok.left() - btn_cancel.width()-5,lb_files.bottom()+5);
+ }
+ else
+ {
+
+ lbl_dirs.set_pos(0,btn_root.bottom()+5);
+ lb_dirs.set_pos(0,lbl_dirs.bottom());
+ lb_dirs.set_size(width/2,height-lb_dirs.top()-btn_cancel.height()-10-tf_file_name.height());
+
+ lbl_files.set_pos(lb_dirs.right(),btn_root.bottom()+5);
+ lb_files.set_pos(lb_dirs.right(),lbl_files.bottom());
+ lb_files.set_size(width-lb_files.left(),height-lb_files.top()-btn_cancel.height()-10-tf_file_name.height());
+
+ lbl_file_name.set_pos(lb_files.left(), lb_files.bottom()+8);
+ tf_file_name.set_pos(lbl_file_name.right(), lb_files.bottom()+5);
+ tf_file_name.set_width(width-tf_file_name.left()-5);
+
+ btn_ok.set_pos(width - btn_ok.width()-25,tf_file_name.bottom()+5);
+ btn_cancel.set_pos(btn_ok.left() - btn_cancel.width()-5,tf_file_name.bottom()+5);
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_window_resized (
+ )
+ {
+ set_sizes();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ deleter_thread (
+ )
+ {
+ close_window();
+ delete this;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ enter_folder (
+ const std::string& folder_name
+ )
+ {
+ if (btn_root.is_checked())
+ btn_root.set_unchecked();
+ if (cur_dir != -1)
+ sob[cur_dir]->set_unchecked();
+
+
+ const std::string old_path = path;
+ const long old_cur_dir = cur_dir;
+
+ std::unique_ptr<toggle_button> new_btn(new toggle_button(*this));
+ new_btn->set_name(folder_name);
+ new_btn->set_click_handler(*this,&box_win::on_path_button_click);
+
+ // remove any path buttons that won't be part of the path anymore
+ if (sob.size())
+ {
+ while (sob.size() > (unsigned long)(cur_dir+1))
+ {
+ std::unique_ptr<toggle_button> junk;
+ sob.remove(cur_dir+1,junk);
+ }
+ }
+
+ if (sob.size())
+ new_btn->set_pos(sob[sob.size()-1]->right()+5,sob[sob.size()-1]->top());
+ else
+ new_btn->set_pos(btn_root.right()+5,btn_root.top());
+
+ cur_dir = sob.size();
+ sob.add(sob.size(),new_btn);
+
+ path += folder_name + directory::get_separator();
+ if (set_dir(prefix + path) == false)
+ {
+ sob.remove(sob.size()-1,new_btn);
+ path = old_path;
+ cur_dir = old_cur_dir;
+ }
+ else
+ {
+
+ sob[cur_dir]->set_checked();
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_dirs_click (
+ unsigned long idx
+ )
+ {
+ enter_folder(lb_dirs[idx]);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_files_click (
+ unsigned long idx
+ )
+ {
+ if (tf_file_name.is_hidden() == false)
+ {
+ tf_file_name.set_text(lb_files[idx]);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_files_double_click (
+ unsigned long
+ )
+ {
+ on_open_click();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_cancel_click (
+ )
+ {
+ hide();
+ create_new_thread<box_win,&box_win::deleter_thread>(*this);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_open_click (
+ )
+ {
+ if (lb_files.get_selected() != lb_files.size() || tf_file_name.text().size() > 0)
+ {
+ if (event_handler.is_set())
+ {
+ if (tf_file_name.is_hidden())
+ event_handler(prefix + path + lb_files[lb_files.get_selected()]);
+ else if (tf_file_name.text().size() > 0)
+ event_handler(prefix + path + tf_file_name.text());
+ }
+ hide();
+ create_new_thread<box_win,&box_win::deleter_thread>(*this);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_path_button_click (
+ toggle_button& btn
+ )
+ {
+ if (btn_root.is_checked())
+ btn_root.set_unchecked();
+ if (cur_dir != -1)
+ sob[cur_dir]->set_unchecked();
+ std::string new_path;
+
+ for (unsigned long i = 0; i < sob.size(); ++i)
+ {
+ new_path += sob[i]->name() + directory::get_separator();
+ if (sob[i].get() == &btn)
+ {
+ cur_dir = i;
+ sob[i]->set_checked();
+ break;
+ }
+ }
+ if (path != new_path)
+ {
+ path = new_path;
+ set_dir(prefix+path);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ struct case_insensitive_compare
+ {
+ bool operator() (
+ const std::string& a,
+ const std::string& b
+ ) const
+ {
+ std::string::size_type i, size;
+ size = std::min(a.size(),b.size());
+ for (i = 0; i < size; ++i)
+ {
+ if (std::tolower(a[i]) < std::tolower(b[i]))
+ return true;
+ else if (std::tolower(a[i]) > std::tolower(b[i]))
+ return false;
+ }
+ if (a.size() < b.size())
+ return true;
+ else
+ return false;
+ }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ bool box_win::
+ set_dir (
+ const std::string& dir
+ )
+ {
+ try
+ {
+ directory d(dir);
+ queue<directory>::kernel_1a_c qod;
+ queue<file>::kernel_1a_c qof;
+ queue<std::string>::sort_1a_c qos;
+ d.get_dirs(qod);
+ d.get_files(qof);
+
+ qod.reset();
+ while (qod.move_next())
+ {
+ std::string temp = qod.element().name();
+ qos.enqueue(temp);
+ }
+ qos.sort(case_insensitive_compare());
+ lb_dirs.load(qos);
+ qos.clear();
+
+ qof.reset();
+ while (qof.move_next())
+ {
+ std::string temp = qof.element().name();
+ qos.enqueue(temp);
+ }
+ qos.sort(case_insensitive_compare());
+ lb_files.load(qos);
+ return true;
+ }
+ catch (directory::listing_error& )
+ {
+ return false;
+ }
+ catch (directory::dir_not_found&)
+ {
+ return false;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void box_win::
+ on_root_click (
+ )
+ {
+ btn_root.set_checked();
+ if (cur_dir != -1)
+ sob[cur_dir]->set_unchecked();
+
+ queue<directory>::kernel_1a_c qod, qod2;
+ queue<file>::kernel_1a_c qof;
+ queue<std::string>::sort_1a_c qos;
+ get_filesystem_roots(qod);
+ path.clear();
+ cur_dir = -1;
+ if (qod.size() == 1)
+ {
+ qod.current().get_files(qof);
+ qod.current().get_dirs(qod2);
+ prefix = qod.current().full_name();
+
+ qod2.reset();
+ while (qod2.move_next())
+ {
+ std::string temp = qod2.element().name();
+ qos.enqueue(temp);
+ }
+ qos.sort(case_insensitive_compare());
+ lb_dirs.load(qos);
+ qos.clear();
+
+ qof.reset();
+ while (qof.move_next())
+ {
+ std::string temp = qof.element().name();
+ qos.enqueue(temp);
+ }
+ qos.sort(case_insensitive_compare());
+ lb_files.load(qos);
+ }
+ else
+ {
+ prefix.clear();
+ qod.reset();
+ while (qod.move_next())
+ {
+ std::string temp = qod.element().full_name();
+ temp = temp.substr(0,temp.size()-1);
+ qos.enqueue(temp);
+ }
+ qos.sort(case_insensitive_compare());
+ lb_dirs.load(qos);
+ qos.clear();
+ lb_files.load(qos);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ base_window::on_close_return_code box_win::
+ on_window_close (
+ )
+ {
+ delete this;
+ return CLOSE_WINDOW;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class menu_bar
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ menu_bar::
+ menu_bar(
+ drawable_window& w
+ ) :
+ drawable(w, 0xFFFF), // listen for all events
+ open_menu(0)
+ {
+ adjust_position();
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ menu_bar::
+ ~menu_bar()
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ adjust_position();
+ compute_menu_geometry();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ set_number_of_menus (
+ unsigned long num
+ )
+ {
+ auto_mutex M(m);
+ menus.set_max_size(num);
+ menus.set_size(num);
+ open_menu = menus.size();
+ compute_menu_geometry();
+
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ menus[i].menu.set_on_hide_handler(*this,&menu_bar::on_popup_hide);
+ }
+
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long menu_bar::
+ number_of_menus (
+ ) const
+ {
+ auto_mutex M(m);
+ return menus.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ set_menu_name (
+ unsigned long idx,
+ const std::string name,
+ char underline_ch
+ )
+ {
+ set_menu_name(idx, convert_mbstring_to_wstring(name), underline_ch);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ set_menu_name (
+ unsigned long idx,
+ const std::wstring name,
+ char underline_ch
+ )
+ {
+ set_menu_name(idx, convert_wstring_to_utf32(name), underline_ch);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ set_menu_name (
+ unsigned long idx,
+ const dlib::ustring name,
+ char underline_ch
+ )
+ {
+ DLIB_ASSERT ( idx < number_of_menus() ,
+ "\tvoid menu_bar::set_menu_name()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_menus(): " << number_of_menus()
+ );
+ auto_mutex M(m);
+ menus[idx].name = name.c_str();
+ menus[idx].underline_pos = name.find_first_of(underline_ch);
+ compute_menu_geometry();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string menu_bar::
+ menu_name (
+ unsigned long idx
+ ) const
+ {
+ return convert_wstring_to_mbstring(menu_wname(idx));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring menu_bar::
+ menu_wname (
+ unsigned long idx
+ ) const
+ {
+ return convert_utf32_to_wstring(menu_uname(idx));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const dlib::ustring menu_bar::
+ menu_uname (
+ unsigned long idx
+ ) const
+ {
+ DLIB_ASSERT ( idx < number_of_menus() ,
+ "\tstd::string menu_bar::menu_name()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_menus(): " << number_of_menus()
+ );
+ auto_mutex M(m);
+ return menus[idx].name.c_str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ popup_menu& menu_bar::
+ menu (
+ unsigned long idx
+ )
+ {
+ DLIB_ASSERT ( idx < number_of_menus() ,
+ "\tpopup_menu& menu_bar::menu()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_menus(): " << number_of_menus()
+ );
+ auto_mutex M(m);
+ return menus[idx].menu;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const popup_menu& menu_bar::
+ menu (
+ unsigned long idx
+ ) const
+ {
+ DLIB_ASSERT ( idx < number_of_menus() ,
+ "\tconst popup_menu& menu_bar::menu()"
+ << "\n\tidx: " << idx
+ << "\n\tnumber_of_menus(): " << number_of_menus()
+ );
+ auto_mutex M(m);
+ return menus[idx].menu;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_window_resized (
+ )
+ {
+ adjust_position();
+ hide_menu();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ draw (
+ const canvas& c
+ ) const
+ {
+ rectangle area(rect.intersect(c));
+ if (area.is_empty())
+ return;
+
+ const unsigned char opacity = 40;
+ fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(255,255,255,opacity),
+ rgb_alpha_pixel(0,0,0,opacity));
+
+ // first draw the border between the menu and the rest of the window
+ draw_line(c, point(rect.left(),rect.bottom()-1),
+ point(rect.right(),rect.bottom()-1), 100);
+ draw_line(c, point(rect.left(),rect.bottom()),
+ point(rect.right(),rect.bottom()), 255);
+
+ // now draw all the menu buttons
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ mfont->draw_string(c,menus[i].rect, menus[i].name );
+ if (menus[i].underline_p1 != menus[i].underline_p2)
+ draw_line(c, menus[i].underline_p1, menus[i].underline_p2);
+
+ if (open_menu == i)
+ {
+ fill_rect_with_vertical_gradient(c, menus[i].bgrect,rgb_alpha_pixel(255,255,0,40), rgb_alpha_pixel(0,0,0,40));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_window_moved (
+ )
+ {
+ hide_menu();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_focus_lost (
+ )
+ {
+ hide_menu();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ )
+ {
+
+ if (rect.contains(x,y) == false || btn != (unsigned long)base_window::LEFT)
+ {
+ hide_menu();
+ return;
+ }
+
+ unsigned long old_menu = menus.size();
+
+ // if a menu is currently open then save its index
+ if (open_menu != menus.size())
+ {
+ old_menu = open_menu;
+ hide_menu();
+ }
+
+ // figure out which menu should be open if any
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ if (menus[i].bgrect.contains(x,y))
+ {
+ if (old_menu != i)
+ show_menu(i);
+
+ break;
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_mouse_move (
+ unsigned long ,
+ long x,
+ long y
+ )
+ {
+ // if the mouse is over the menu_bar and some menu is currently open
+ if (rect.contains(x,y) && open_menu != menus.size())
+ {
+ // if the mouse is still in the same rectangle then don't do anything
+ if (menus[open_menu].bgrect.contains(x,y) == false)
+ {
+ // figure out which menu should be instead
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ if (menus[i].bgrect.contains(x,y))
+ {
+ show_menu(i);
+ break;
+ }
+ }
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ if (state&base_window::KBD_MOD_ALT)
+ {
+ // check if the key matches any of our underlined keys
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ // if we have found a matching key
+ if (is_printable &&
+ menus[i].underline_pos != std::string::npos &&
+ std::tolower(menus[i].name[menus[i].underline_pos]) == std::tolower(key))
+ {
+ show_menu(i);
+ menus[open_menu].menu.select_first_item();
+ return;
+ }
+ }
+ }
+
+ if (open_menu != menus.size())
+ {
+ unsigned long i = open_menu;
+ // if the submenu doesn't use this key for something then we will
+ if (menus[open_menu].menu.forwarded_on_keydown(key,is_printable,state) == false)
+ {
+ if (key == base_window::KEY_LEFT)
+ {
+ i = (i+menus.size()-1)%menus.size();
+ show_menu(i);
+ menus[open_menu].menu.select_first_item();
+ }
+ else if (key == base_window::KEY_RIGHT)
+ {
+ i = (i+1)%menus.size();
+ show_menu(i);
+ menus[open_menu].menu.select_first_item();
+ }
+ else if (key == base_window::KEY_ESC)
+ {
+ hide_menu();
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ show_menu (
+ unsigned long i
+ )
+ {
+ rectangle temp;
+
+ // menu already open so do nothing
+ if (i == open_menu)
+ return;
+
+ // if a menu is currently open
+ if (open_menu != menus.size())
+ {
+ menus[open_menu].menu.hide();
+ temp = menus[open_menu].bgrect;
+ }
+
+ // display the new menu
+ open_menu = i;
+ long wx, wy;
+ parent.get_pos(wx,wy);
+ wx += menus[i].bgrect.left();
+ wy += menus[i].bgrect.bottom()+1;
+ menus[i].menu.set_pos(wx,wy);
+ menus[i].menu.show();
+ parent.invalidate_rectangle(menus[i].bgrect+temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ hide_menu (
+ )
+ {
+ // if a menu is currently open
+ if (open_menu != menus.size())
+ {
+ menus[open_menu].menu.hide();
+ parent.invalidate_rectangle(menus[open_menu].bgrect);
+ open_menu = menus.size();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ on_popup_hide (
+ )
+ {
+ // if a menu is currently open
+ if (open_menu != menus.size())
+ {
+ parent.invalidate_rectangle(menus[open_menu].bgrect);
+ open_menu = menus.size();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ compute_menu_geometry (
+ )
+ {
+ long x = 7;
+ long bg_x = 0;
+ for (unsigned long i = 0; i < menus.size(); ++i)
+ {
+ // compute the locations of the text rectangles
+ menus[i].rect.set_top(5);
+ menus[i].rect.set_left(x);
+ menus[i].rect.set_bottom(rect.bottom()-2);
+
+ unsigned long width, height;
+ mfont->compute_size(menus[i].name,width,height);
+ menus[i].rect = resize_rect_width(menus[i].rect, width);
+ x = menus[i].rect.right()+10;
+
+ menus[i].bgrect.set_top(0);
+ menus[i].bgrect.set_left(bg_x);
+ menus[i].bgrect.set_bottom(rect.bottom()-2);
+ menus[i].bgrect.set_right(x-5);
+ bg_x = menus[i].bgrect.right()+1;
+
+ if (menus[i].underline_pos != std::string::npos)
+ {
+ // now compute the location of the underline bar
+ rectangle r1 = mfont->compute_cursor_rect(
+ menus[i].rect,
+ menus[i].name,
+ menus[i].underline_pos);
+
+ rectangle r2 = mfont->compute_cursor_rect(
+ menus[i].rect,
+ menus[i].name,
+ menus[i].underline_pos+1);
+
+ menus[i].underline_p1.x() = r1.left()+1;
+ menus[i].underline_p2.x() = r2.left()-1;
+ menus[i].underline_p1.y() = r1.bottom()-mfont->height()+mfont->ascender()+2;
+ menus[i].underline_p2.y() = r2.bottom()-mfont->height()+mfont->ascender()+2;
+ }
+ else
+ {
+ // there is no underline in this case
+ menus[i].underline_p1 = menus[i].underline_p2;
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void menu_bar::
+ adjust_position (
+ )
+ {
+ unsigned long width, height;
+ rectangle old(rect);
+ parent.get_size(width,height);
+ rect.set_left(0);
+ rect.set_top(0);
+ rect = resize_rect(rect,width,mfont->height()+10);
+ parent.invalidate_rectangle(old+rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// class text_grid
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ text_grid::
+ text_grid (
+ drawable_window& w
+ ) :
+ scrollable_region(w, KEYBOARD_EVENTS | MOUSE_CLICK | FOCUS_EVENTS ),
+ has_focus(false),
+ cursor_timer(*this,&text_grid::timer_action),
+ border_color_(128,128,128)
+ {
+
+ cursor_timer.set_delay_time(500);
+ set_vertical_scroll_increment(10);
+ set_horizontal_scroll_increment(10);
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ text_grid::
+ ~text_grid (
+ )
+ {
+ // Disable all further events for this drawable object. We have to do this
+ // because we don't want draw() events coming to this object while or after
+ // it has been destructed.
+ disable_events();
+
+ // wait for the timer to stop doing its thing
+ cursor_timer.stop_and_wait();
+ // Tell the parent window to redraw its area that previously contained this
+ // drawable object.
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_grid_size (
+ unsigned long rows,
+ unsigned long cols
+ )
+ {
+ auto_mutex M(m);
+ row_height.set_max_size(rows);
+ row_height.set_size(rows);
+
+ col_width.set_max_size(cols);
+ col_width.set_size(cols);
+
+ grid.set_size(rows,cols);
+
+ for (unsigned long i = 0; i < row_height.size(); ++i)
+ row_height[i] = (mfont->height()*3)/2;
+ for (unsigned long i = 0; i < col_width.size(); ++i)
+ col_width[i] = mfont->height()*5;
+
+ compute_total_rect();
+ compute_bg_rects();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long text_grid::
+ number_of_columns (
+ ) const
+ {
+ auto_mutex M(m);
+ return grid.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long text_grid::
+ number_of_rows (
+ ) const
+ {
+ auto_mutex M(m);
+ return grid.nr();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int text_grid::
+ next_free_user_event_number (
+ ) const
+ {
+ return scrollable_region::next_free_user_event_number()+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rgb_pixel text_grid::
+ border_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return border_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_border_color (
+ rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ border_color_ = color;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string text_grid::
+ text (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ return convert_wstring_to_mbstring(wtext(row, col));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring text_grid::
+ wtext (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ return convert_utf32_to_wstring(utext(row, col));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const dlib::ustring text_grid::
+ utext (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tconst std::string text_grid::text(row,col)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ return grid[row][col].text.c_str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_text (
+ unsigned long row,
+ unsigned long col,
+ const std::string& str
+ )
+ {
+ set_text(row, col, convert_mbstring_to_wstring(str));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_text (
+ unsigned long row,
+ unsigned long col,
+ const std::wstring& str
+ )
+ {
+ set_text(row, col, convert_wstring_to_utf32(str));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_text (
+ unsigned long row,
+ unsigned long col,
+ const dlib::ustring& str
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tvoid text_grid::set_text(row,col)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ grid[row][col].text = str.c_str();
+ parent.invalidate_rectangle(get_text_rect(row,col));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_grid::
+ text_color (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tconst rgb_pixel text_grid::text_color(row,col)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ return grid[row][col].text_color;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_text_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tvoid text_grid::set_text_color(row,col,color)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ grid[row][col].text_color = color;
+ parent.invalidate_rectangle(get_text_rect(row,col));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_grid::
+ background_color (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tconst rgb_pixel text_grid::background_color(row,col,color)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ return grid[row][col].bg_color;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_background_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tvoid text_grid::set_background_color(row,col,color)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ grid[row][col].bg_color = color;
+ parent.invalidate_rectangle(get_bg_rect(row,col));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool text_grid::
+ is_editable (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tbool text_grid::is_editable(row,col)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\tthis: " << this
+ );
+ return grid[row][col].is_editable;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_editable (
+ unsigned long row,
+ unsigned long col,
+ bool editable
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(),
+ "\tvoid text_grid::set_editable(row,col,editable)"
+ << "\n\trow: " << row
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\teditable: " << editable
+ << "\n\tthis: " << this
+ );
+ grid[row][col].is_editable = editable;
+ if (has_focus && active_row == static_cast<long>(row) && active_col == static_cast<long>(col))
+ {
+ drop_input_focus();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_column_width (
+ unsigned long col,
+ unsigned long width
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( col < number_of_columns(),
+ "\tvoid text_grid::set_column_width(col,width)"
+ << "\n\tcol: " << col
+ << "\n\tnumber_of_columns(): " << number_of_columns()
+ << "\n\twidth: " << width
+ << "\n\tthis: " << this
+ );
+ col_width[col] = width;
+ compute_total_rect();
+ compute_bg_rects();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ set_row_height (
+ unsigned long row,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( row < number_of_rows() ,
+ "\tvoid text_grid::set_row_height(row,height)"
+ << "\n\trow: " << row
+ << "\n\tnumber_of_rows(): " << number_of_rows()
+ << "\n\theight: " << height
+ << "\n\tthis: " << this
+ );
+ row_height[row] = height;
+ compute_total_rect();
+ compute_bg_rects();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ scrollable_region::disable();
+ drop_input_focus();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ scrollable_region::hide();
+ drop_input_focus();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ on_user_event (
+ int num
+ )
+ {
+ // ignore this user event if it isn't for us
+ if (num != scrollable_region::next_free_user_event_number())
+ return;
+
+ if (has_focus && !recent_cursor_move && enabled && !hidden)
+ {
+ show_cursor = !show_cursor;
+ parent.invalidate_rectangle(get_text_rect(active_row,active_col));
+ }
+ recent_cursor_move = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ timer_action (
+ )
+ {
+ parent.trigger_user_event(this,scrollable_region::next_free_user_event_number());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ compute_bg_rects (
+ )
+ {
+ // loop over each element in the grid and figure out what its rectangle should be
+ // with respect to the total_rect()
+ point p1, p2;
+ p1.y() = total_rect().top();
+ for (long row = 0; row < grid.nr(); ++row)
+ {
+ p1.x() = total_rect().left();
+ p2.y() = p1.y() + row_height[row]-1;
+ for (long col = 0; col < grid.nc(); ++col)
+ {
+ // if this is the last box in this row make it super wide so that it always
+ // goes to the end of the widget
+ if (col+1 == grid.nc())
+ p2.x() = 1000000;
+ else
+ p2.x() = p1.x() + col_width[col]-1;
+
+ // at this point p1 is the upper left corner of this box and p2 is the
+ // lower right corner of the box;
+ rectangle bg_rect(p1);
+ bg_rect += p2;
+
+ grid[row][col].bg_rect = translate_rect(bg_rect, -total_rect().left(), -total_rect().top());
+
+
+ p1.x() += 1 + col_width[col];
+ }
+ p1.y() += 1 + row_height[row];
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ compute_total_rect (
+ )
+ {
+ if (grid.size() == 0)
+ {
+ set_total_rect_size(0,0);
+ }
+ else
+ {
+ unsigned long width = col_width.size()-1;
+ unsigned long height = row_height.size()-1;
+
+ for (unsigned long i = 0; i < col_width.size(); ++i)
+ width += col_width[i];
+ for (unsigned long i = 0; i < row_height.size(); ++i)
+ height += row_height[i];
+
+ set_total_rect_size(width,height);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ // ignore this event if we are disabled or hidden
+ if (!enabled || hidden)
+ return;
+
+ if (has_focus)
+ {
+ if (is_printable)
+ {
+ // if the user hit the tab key then jump to the next box
+ if (key == '\t')
+ {
+ if (active_col+1 == grid.nc())
+ {
+ if (active_row+1 == grid.nr())
+ move_cursor(0,0,0);
+ else
+ move_cursor(active_row+1,0,0);
+ }
+ else
+ {
+ move_cursor(active_row,active_col+1,0);
+ }
+ }
+ if (key == '\n')
+ {
+ // ignore the enter key
+ }
+ else if (grid[active_row][active_col].is_editable)
+ {
+ // insert the key the user pressed into the string
+ grid[active_row][active_col].text.insert(cursor_pos,1,static_cast<char>(key));
+ move_cursor(active_row,active_col,cursor_pos+1);
+
+ if (text_modified_handler.is_set())
+ text_modified_handler(active_row,active_col);
+ }
+ }
+ else if ((state & base_window::KBD_MOD_CONTROL))
+ {
+ if (key == base_window::KEY_LEFT)
+ move_cursor(active_row,active_col-1,0);
+ else if (key == base_window::KEY_RIGHT)
+ move_cursor(active_row,active_col+1,0);
+ else if (key == base_window::KEY_UP)
+ move_cursor(active_row-1,active_col,0);
+ else if (key == base_window::KEY_DOWN)
+ move_cursor(active_row+1,active_col,0);
+ else if (key == base_window::KEY_END)
+ move_cursor(active_row,active_col,grid[active_row][active_col].text.size());
+ else if (key == base_window::KEY_HOME)
+ move_cursor(active_row,active_col,0);
+ }
+ else
+ {
+ if (key == base_window::KEY_LEFT)
+ move_cursor(active_row,active_col,cursor_pos-1);
+ else if (key == base_window::KEY_RIGHT)
+ move_cursor(active_row,active_col,cursor_pos+1);
+ else if (key == base_window::KEY_UP)
+ move_cursor(active_row-1,active_col,0);
+ else if (key == base_window::KEY_DOWN)
+ move_cursor(active_row+1,active_col,0);
+ else if (key == base_window::KEY_END)
+ move_cursor(active_row,active_col,grid[active_row][active_col].text.size());
+ else if (key == base_window::KEY_HOME)
+ move_cursor(active_row,active_col,0);
+ else if (key == base_window::KEY_BACKSPACE)
+ {
+ if (cursor_pos > 0 && grid[active_row][active_col].is_editable)
+ {
+ grid[active_row][active_col].text.erase(
+ grid[active_row][active_col].text.begin()+cursor_pos-1,
+ grid[active_row][active_col].text.begin()+cursor_pos);
+ move_cursor(active_row,active_col,cursor_pos-1);
+
+ if (text_modified_handler.is_set())
+ text_modified_handler(active_row,active_col);
+ }
+ }
+ else if (key == base_window::KEY_DELETE)
+ {
+ if (cursor_pos < static_cast<long>(grid[active_row][active_col].text.size()) &&
+ grid[active_row][active_col].is_editable)
+ {
+ grid[active_row][active_col].text.erase(
+ grid[active_row][active_col].text.begin()+cursor_pos);
+ move_cursor(active_row,active_col,cursor_pos);
+
+ if (text_modified_handler.is_set())
+ text_modified_handler(active_row,active_col);
+ }
+ }
+ }
+ } // if (has_focus)
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ scrollable_region::on_mouse_down(btn, state, x, y, is_double_click);
+ if (display_rect().contains(x,y) && enabled && !hidden)
+ {
+ // figure out which box this click landed in
+ rectangle hit;
+
+ // find which column we hit
+ unsigned long col = 0;
+ long box_x = total_rect().left();
+ for (unsigned long i = 0; i < col_width.size(); ++i)
+ {
+ if (box_x <= x && (x < box_x+static_cast<long>(col_width[i]) || (i+1 == col_width.size())))
+ {
+ col = i;
+ hit.set_left(box_x);
+ hit.set_right(box_x+col_width[i]-1);
+ break;
+ }
+ else
+ {
+ box_x += col_width[i]+1;
+ }
+ }
+
+ // find which row we hit
+ unsigned long row = 0;
+ long box_y = total_rect().top();
+ for (unsigned long i = 0; i < row_height.size(); ++i)
+ {
+ if (box_y <= y && y < box_y+static_cast<long>(row_height[i]))
+ {
+ row = i;
+ hit.set_top(box_y);
+ hit.set_bottom(box_y+row_height[i]-1);
+ break;
+ }
+ else
+ {
+ box_y += row_height[i]+1;
+ }
+ }
+
+ // if we hit a box
+ if (hit.is_empty() == false)
+ {
+ move_cursor(row,
+ col,
+ mfont->compute_cursor_pos(get_text_rect(row,col), grid[row][col].text, x, y, grid[row][col].first)
+ );
+ }
+ else
+ {
+ drop_input_focus();
+ }
+ }
+ else
+ {
+ drop_input_focus();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ scrollable_region::on_mouse_up(btn, state, x, y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ on_focus_lost (
+ )
+ {
+ drop_input_focus();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ draw (
+ const canvas& c
+ ) const
+ {
+ scrollable_region::draw(c);
+ rectangle area = c.intersect(display_rect());
+ if (area.is_empty() == true)
+ return;
+
+ if (enabled)
+ fill_rect(c, area, 255);
+
+ // don't do anything if the grid is empty
+ if (grid.size() == 0)
+ return;
+
+ // draw all the vertical lines
+ point p1, p2;
+ p1.x() = p2.x() = total_rect().left();
+ p1.y() = total_rect().top();
+ p2.y() = total_rect().bottom();
+ for (unsigned long i = 0; i < col_width.size()-1; ++i)
+ {
+ p1.x() += col_width[i];
+ p2.x() += col_width[i];
+ if (enabled)
+ draw_line(c,p1,p2,border_color_,area);
+ else
+ draw_line(c,p1,p2,128,area);
+ p1.x() += 1;
+ p2.x() += 1;
+ }
+
+ // draw all the horizontal lines
+ p1.y() = p2.y() = total_rect().top();
+ p1.x() = display_rect().left();
+ p2.x() = display_rect().right();
+ for (unsigned long i = 0; i < row_height.size(); ++i)
+ {
+ p1.y() += row_height[i];
+ p2.y() += row_height[i];
+ if (enabled)
+ draw_line(c,p1,p2,border_color_,area);
+ else
+ draw_line(c,p1,p2,128,area);
+ p1.y() += 1;
+ p2.y() += 1;
+ }
+
+ // draw the backgrounds and text for each box
+ for (long row = 0; row < grid.nr(); ++row)
+ {
+ for (long col = 0; col < grid.nc(); ++col)
+ {
+ rectangle bg_rect(get_bg_rect(row,col));
+
+ rectangle text_rect(get_text_rect(row,col));
+
+ if (enabled)
+ {
+ fill_rect(c,bg_rect.intersect(area),grid[row][col].bg_color);
+
+ mfont->draw_string(c,
+ text_rect,
+ grid[row][col].text,
+ grid[row][col].text_color,
+ grid[row][col].first,
+ std::string::npos,
+ area);
+ }
+ else
+ {
+ mfont->draw_string(c,
+ text_rect,
+ grid[row][col].text,
+ 128,
+ grid[row][col].first,
+ std::string::npos,
+ area);
+ }
+
+ // if this box has input focus then draw it with a cursor
+ if (has_focus && active_col == col && active_row == row && show_cursor)
+ {
+ rectangle cursor_rect = mfont->compute_cursor_rect(text_rect,
+ grid[row][col].text,
+ cursor_pos,
+ grid[row][col].first);
+ draw_rectangle(c,cursor_rect,0,area);
+ }
+
+ }
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle text_grid::
+ get_text_rect (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ rectangle bg_rect(get_bg_rect(row,col));
+ long padding = (bg_rect.height() - mfont->height())/2 + (bg_rect.height() - mfont->height())%2;
+ if (padding < 0)
+ padding = 0;
+ bg_rect.set_left(bg_rect.left()+padding);
+ bg_rect.set_top(bg_rect.top()+padding);
+ bg_rect.set_right(bg_rect.right()-padding);
+ bg_rect.set_bottom(bg_rect.bottom()-padding);
+ return bg_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle text_grid::
+ get_bg_rect (
+ unsigned long row,
+ unsigned long col
+ ) const
+ {
+ return translate_rect(grid[row][col].bg_rect, total_rect().left(), total_rect().top());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ drop_input_focus (
+ )
+ {
+ if (has_focus)
+ {
+ parent.invalidate_rectangle(get_text_rect(active_row,active_col));
+ has_focus = false;
+ show_cursor = false;
+ cursor_timer.stop();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_grid::
+ move_cursor (
+ long row,
+ long col,
+ long new_cursor_pos
+ )
+ {
+ // don't do anything if the grid is empty
+ if (grid.size() == 0)
+ {
+ return;
+ }
+
+ if (row < 0)
+ row = 0;
+ if (row >= grid.nr())
+ row = grid.nr()-1;
+ if (col < 0)
+ col = 0;
+ if (col >= grid.nc())
+ col = grid.nc()-1;
+
+ if (new_cursor_pos < 0)
+ {
+ if (col == 0)
+ {
+ new_cursor_pos = 0;
+ }
+ else
+ {
+ --col;
+ new_cursor_pos = grid[row][col].text.size();
+ }
+ }
+
+ if (new_cursor_pos > static_cast<long>(grid[row][col].text.size()))
+ {
+ if (col+1 == grid.nc())
+ {
+ new_cursor_pos = grid[row][col].text.size();
+ }
+ else
+ {
+ ++col;
+ new_cursor_pos = 0;
+ }
+ }
+
+ // if some other box had the input focus then redraw it
+ if (has_focus && (active_row != row || active_col != col ))
+ {
+ parent.invalidate_rectangle(get_text_rect(active_row,active_col));
+ }
+
+ if (has_focus == false)
+ {
+ cursor_timer.start();
+ }
+
+ has_focus = true;
+ recent_cursor_move = true;
+ show_cursor = true;
+ active_row = row;
+ active_col = col;
+ cursor_pos = new_cursor_pos;
+
+ // adjust the first character to draw so that the string is displayed well
+ rectangle text_rect(get_text_rect(active_row,active_col));
+ rectangle cursor_rect = mfont->compute_cursor_rect(text_rect,
+ grid[row][col].text,
+ cursor_pos,
+ grid[row][col].first);
+
+ // if the cursor rect is too far to the left of the string
+ if (cursor_pos < static_cast<long>(grid[row][col].first))
+ {
+ if (cursor_pos > 5)
+ {
+ grid[row][col].first = cursor_pos - 5;
+ }
+ else
+ {
+ grid[row][col].first = 0;
+ }
+ }
+ // if the cursor rect is too far to the right of the string
+ else if (cursor_rect.left() > text_rect.right())
+ {
+ long distance = (cursor_rect.left() - text_rect.right()) + text_rect.width()/3;
+ // find the letter that is distance pixels from the start of the string
+ long sum = 0;
+ for (unsigned long i = grid[row][col].first; i < grid[row][col].text.size(); ++i)
+ {
+ sum += (*mfont)[grid[row][col].text[i]].width();
+ if (sum >= distance)
+ {
+ grid[row][col].first = i;
+ break;
+ }
+ }
+ }
+
+ scroll_to_rect(get_bg_rect(row,col));
+
+ // redraw our box
+ parent.invalidate_rectangle(text_rect);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // text_field object methods
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ rectangle text_box::
+ get_text_rect (
+ ) const
+ {
+ const unsigned long padding = style->get_padding(*mfont);
+
+ rectangle text_rect;
+ text_rect.set_left(total_rect().left()+padding);
+ text_rect.set_top(total_rect().top()+padding);
+ text_rect.set_right(total_rect().right()-padding);
+ text_rect.set_bottom(total_rect().bottom()-padding);
+ return text_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ enable (
+ )
+ {
+ scrollable_region::enable();
+ right_click_menu.enable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_cut (
+ )
+ {
+ on_copy();
+ on_delete_selected();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_copy (
+ )
+ {
+ if (highlight_start <= highlight_end)
+ {
+ put_on_clipboard(text_.substr(highlight_start, highlight_end-highlight_start+1));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_paste (
+ )
+ {
+ ustring temp_str;
+ get_from_clipboard(temp_str);
+
+
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + temp_str +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+ move_cursor(highlight_start+temp_str.size());
+ highlight_start = 0;
+ highlight_end = -1;
+ parent.invalidate_rectangle(rect);
+ on_no_text_selected();
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + temp_str +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+ move_cursor(cursor_pos+temp_str.size());
+
+ // send out the text modified event
+ if (temp_str.size() != 0 && text_modified_handler.is_set())
+ text_modified_handler();
+ }
+
+ adjust_total_rect();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_select_all (
+ )
+ {
+ move_cursor(static_cast<long>(text_.size()));
+ highlight_start = 0;
+ highlight_end = static_cast<long>(text_.size()-1);
+ if (highlight_start <= highlight_end)
+ on_text_is_selected();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_delete_selected (
+ )
+ {
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.erase(highlight_start,highlight_end-highlight_start+1);
+ move_cursor(highlight_start);
+ highlight_start = 0;
+ highlight_end = -1;
+
+ on_no_text_selected();
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+
+ adjust_total_rect();
+
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_text_is_selected (
+ )
+ {
+ right_click_menu.menu().enable_menu_item(0);
+ right_click_menu.menu().enable_menu_item(1);
+ right_click_menu.menu().enable_menu_item(3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_no_text_selected (
+ )
+ {
+ right_click_menu.menu().disable_menu_item(0);
+ right_click_menu.menu().disable_menu_item(1);
+ right_click_menu.menu().disable_menu_item(3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ show (
+ )
+ {
+ scrollable_region::show();
+ right_click_menu.show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ scrollable_region::disable();
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ right_click_menu.disable();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ hide (
+ )
+ {
+ auto_mutex M(m);
+ scrollable_region::hide();
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ adjust_total_rect (
+ )
+ {
+ const unsigned long padding = style->get_padding(*mfont);
+ unsigned long text_width;
+ unsigned long text_height;
+
+ mfont->compute_size(text_, text_width, text_height);
+
+ set_total_rect_size(text_width + padding*2, text_height + padding*2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_main_font (
+ const std::shared_ptr<font>& f
+ )
+ {
+ auto_mutex M(m);
+ mfont = f;
+ adjust_total_rect();
+ right_click_menu.set_rect(display_rect());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ draw (
+ const canvas& c
+ ) const
+ {
+ scrollable_region::draw(c);
+ rectangle area = rect.intersect(c);
+ if (area.is_empty())
+ return;
+
+ const point origin(total_rect().left(), total_rect().top());
+
+ style->draw_text_box(c,display_rect(),get_text_rect(), enabled, *mfont, text_,
+ translate_rect(cursor_rect, origin),
+ text_color_, bg_color_, has_focus, cursor_visible, highlight_start,
+ highlight_end);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_text (
+ const std::string& text
+ )
+ {
+ set_text(convert_mbstring_to_wstring(text));
+ }
+
+ void text_box::
+ set_text (
+ const std::wstring& text
+ )
+ {
+ set_text(convert_wstring_to_utf32(text));
+ }
+
+ void text_box::
+ set_text (
+ const dlib::ustring& text
+ )
+ {
+ auto_mutex M(m);
+ // do this to get rid of any reference counting that may be present in
+ // the std::string implementation.
+ text_ = text.c_str();
+
+ adjust_total_rect();
+ move_cursor(0);
+
+ highlight_start = 0;
+ highlight_end = -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string text_box::
+ text (
+ ) const
+ {
+ std::string temp = convert_wstring_to_mbstring(wtext());
+ return temp;
+ }
+
+ const std::wstring text_box::
+ wtext (
+ ) const
+ {
+ std::wstring temp = convert_utf32_to_wstring(utext());
+ return temp;
+ }
+
+ const dlib::ustring text_box::
+ utext (
+ ) const
+ {
+ auto_mutex M(m);
+ // do this to get rid of any reference counting that may be present in
+ // the dlib::ustring implementation.
+ dlib::ustring temp = text_.c_str();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex M(m);
+ scrollable_region::set_size(width,height);
+ right_click_menu.set_rect(display_rect());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_pos (
+ long x,
+ long y
+ )
+ {
+ scrollable_region::set_pos(x,y);
+ right_click_menu.set_rect(get_text_rect());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_background_color (
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ bg_color_ = color;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_box::
+ background_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return bg_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ set_text_color (
+ const rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ text_color_ = color;
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const rgb_pixel text_box::
+ text_color (
+ ) const
+ {
+ auto_mutex M(m);
+ return text_color_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (!enabled || hidden || !has_focus)
+ {
+ return;
+ }
+
+ if (state & base_window::LEFT)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+
+ unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y);
+ if (static_cast<long>(new_pos) != cursor_pos)
+ {
+ move_cursor(new_pos);
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (shift_pos != -1)
+ {
+ shift_pos = -1;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long,
+ long ,
+ long
+ )
+ {
+ if (!enabled || hidden)
+ return;
+
+ if (btn == base_window::LEFT)
+ shift_pos = -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool double_clicked
+ )
+ {
+ using namespace std;
+ if (!enabled || hidden || btn != (unsigned long)base_window::LEFT)
+ return;
+
+ if (display_rect().contains(x,y))
+ {
+ has_focus = true;
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ t.start();
+
+
+ if (double_clicked)
+ {
+ // highlight the double clicked word
+ string::size_type first, last;
+ const ustring ustr = convert_utf8_to_utf32(std::string(" \t\n"));
+ first = text_.substr(0,cursor_pos).find_last_of(ustr.c_str());
+ last = text_.find_first_of(ustr.c_str(),cursor_pos);
+ long f = static_cast<long>(first);
+ long l = static_cast<long>(last);
+ if (first == string::npos)
+ f = -1;
+ if (last == string::npos)
+ l = static_cast<long>(text_.size());
+
+ ++f;
+ --l;
+
+ move_cursor(l+1);
+ highlight_start = f;
+ highlight_end = l;
+ on_text_is_selected();
+ }
+ else
+ {
+ if (state & base_window::SHIFT)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+ else
+ {
+ shift_pos = cursor_pos;
+ }
+ }
+
+ bool at_end = false;
+ if (cursor_pos == 0 || cursor_pos == static_cast<long>(text_.size()))
+ at_end = true;
+ const long old_pos = cursor_pos;
+
+ unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y);
+ move_cursor(new_pos);
+
+ shift_pos = cursor_pos;
+
+ if (at_end && cursor_pos == old_pos)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ }
+ }
+
+ }
+ else if (has_focus && rect.contains(x,y) == false)
+ {
+ t.stop();
+ has_focus = false;
+ cursor_visible = false;
+ shift_pos = -1;
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+
+ if (focus_lost_handler.is_set())
+ focus_lost_handler();
+ parent.invalidate_rectangle(rect);
+ }
+ else
+ {
+ has_focus = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ // If the right click menu is up then we don't want to do anything with
+ // the keyboard ourselves. Let the popup menu use the keyboard for now.
+ if (right_click_menu.popup_menu_visible())
+ return;
+
+ if (has_focus && enabled && !hidden)
+ {
+ const ustring space_str = convert_utf8_to_utf32(std::string(" \t\n"));
+ const bool shift = (state&base_window::KBD_MOD_SHIFT) != 0;
+ const bool ctrl = (state&base_window::KBD_MOD_CONTROL) != 0;
+
+ if (shift && is_printable == false)
+ {
+ if (shift_pos == -1)
+ {
+ if (highlight_start <= highlight_end)
+ {
+ if (highlight_start == cursor_pos)
+ shift_pos = highlight_end + 1;
+ else
+ shift_pos = highlight_start;
+ }
+ else
+ {
+ shift_pos = cursor_pos;
+ }
+ }
+ }
+ else
+ {
+ shift_pos = -1;
+ }
+
+ if (key == base_window::KEY_LEFT)
+ {
+ if (cursor_pos != 0)
+ {
+ unsigned long new_pos;
+ if (ctrl)
+ {
+ // find the first non-whitespace to our left
+ std::string::size_type pos = text_.find_last_not_of(space_str.c_str(),cursor_pos);
+ if (pos != std::string::npos)
+ {
+ pos = text_.find_last_of(space_str.c_str(),pos);
+ if (pos != std::string::npos)
+ new_pos = static_cast<unsigned long>(pos);
+ else
+ new_pos = 0;
+ }
+ else
+ {
+ new_pos = 0;
+ }
+ }
+ else
+ {
+ new_pos = cursor_pos-1;
+ }
+
+ move_cursor(new_pos);
+ }
+ else if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+
+ }
+ else if (key == base_window::KEY_RIGHT)
+ {
+ if (cursor_pos != static_cast<long>(text_.size()))
+ {
+ unsigned long new_pos;
+ if (ctrl)
+ {
+ // find the first non-whitespace to our left
+ std::string::size_type pos = text_.find_first_not_of(space_str.c_str(),cursor_pos);
+ if (pos != std::string::npos)
+ {
+ pos = text_.find_first_of(space_str.c_str(),pos);
+ if (pos != std::string::npos)
+ new_pos = static_cast<unsigned long>(pos+1);
+ else
+ new_pos = static_cast<unsigned long>(text_.size());
+ }
+ else
+ {
+ new_pos = static_cast<unsigned long>(text_.size());
+ }
+ }
+ else
+ {
+ new_pos = cursor_pos+1;
+ }
+
+ move_cursor(new_pos);
+ }
+ else if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (key == base_window::KEY_UP)
+ {
+ if (ctrl)
+ {
+ move_cursor(0);
+ }
+ else
+ {
+ const point origin(total_rect().left(), total_rect().top());
+ // move the cursor so the position that is just a few pixels above
+ // the current cursor_rect
+ move_cursor(mfont->compute_cursor_pos(
+ get_text_rect(), text_, cursor_rect.left()+origin.x(),
+ cursor_rect.top()+origin.y()-mfont->height()/2));
+
+ }
+
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (key == base_window::KEY_DOWN)
+ {
+ if (ctrl)
+ {
+ move_cursor(static_cast<unsigned long>(text_.size()));
+ }
+ else
+ {
+ const point origin(total_rect().left(), total_rect().top());
+ // move the cursor so the position that is just a few pixels above
+ // the current cursor_rect
+ move_cursor(mfont->compute_cursor_pos(
+ get_text_rect(), text_, cursor_rect.left()+origin.x(),
+ cursor_rect.bottom()+origin.y()+mfont->height()/2));
+ }
+
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (is_printable)
+ {
+ if (ctrl)
+ {
+ if (key == 'a')
+ {
+ on_select_all();
+ }
+ else if (key == 'c')
+ {
+ on_copy();
+ }
+ else if (key == 'v')
+ {
+ on_paste();
+ }
+ else if (key == 'x')
+ {
+ on_cut();
+ }
+ }
+ else
+ {
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + static_cast<unichar>(key) +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+
+ adjust_total_rect();
+ move_cursor(highlight_start+1);
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + static_cast<unichar>(key) +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+ adjust_total_rect();
+ move_cursor(cursor_pos+1);
+ }
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+
+ }
+
+ if (key == '\n')
+ {
+ if (enter_key_handler.is_set())
+ enter_key_handler();
+ }
+ }
+ else if (key == base_window::KEY_BACKSPACE)
+ {
+ // if something is highlighted then delete that
+ if (highlight_start <= highlight_end)
+ {
+ on_delete_selected();
+ }
+ else if (cursor_pos != 0)
+ {
+ text_ = text_.erase(cursor_pos-1,1);
+ adjust_total_rect();
+ move_cursor(cursor_pos-1);
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ // do this just so it repaints itself right
+ move_cursor(cursor_pos);
+ }
+
+ }
+ else if (key == base_window::KEY_DELETE)
+ {
+ // if something is highlighted then delete that
+ if (highlight_start <= highlight_end)
+ {
+ on_delete_selected();
+ }
+ else if (cursor_pos != static_cast<long>(text_.size()))
+ {
+ text_ = text_.erase(cursor_pos,1);
+
+ adjust_total_rect();
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ else
+ {
+ // do this just so it repaints itself right
+ move_cursor(cursor_pos);
+ }
+
+ }
+ else if (key == base_window::KEY_HOME)
+ {
+ if (ctrl)
+ {
+ move_cursor(0);
+ }
+ else if (cursor_pos != 0)
+ {
+ // find the start of the current line
+ ustring::size_type pos = text_.find_last_of('\n',cursor_pos-1);
+ if (pos == ustring::npos)
+ pos = 0;
+ else
+ pos += 1;
+ move_cursor(static_cast<unsigned long>(pos));
+
+ }
+
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (key == base_window::KEY_END)
+ {
+ if (ctrl)
+ {
+ move_cursor(static_cast<unsigned long>(text_.size()));
+ }
+ {
+ ustring::size_type pos = text_.find_first_of('\n',cursor_pos);
+ if (pos == ustring::npos)
+ pos = text_.size();
+
+ move_cursor(static_cast<unsigned long>(pos));
+ }
+
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (key == base_window::KEY_PAGE_DOWN || key == base_window::KEY_PAGE_UP)
+ {
+ long jump_size = display_rect().height() -
+ std::min(mfont->height()*3, display_rect().height()/5);
+
+ // if we are supposed to page up then just jump in the other direction
+ if (key == base_window::KEY_PAGE_UP)
+ jump_size = -jump_size;
+
+ scroll_to_rect(translate_rect(display_rect(), point(0, jump_size )));
+ }
+
+ cursor_visible = true;
+ recent_movement = true;
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ on_string_put(
+ const std::wstring &str
+ )
+ {
+ if (has_focus && enabled && !hidden)
+ {
+ ustring ustr = convert_wstring_to_utf32(str);
+ if (highlight_start <= highlight_end)
+ {
+ text_ = text_.substr(0,highlight_start) + ustr +
+ text_.substr(highlight_end+1,text_.size()-highlight_end-1);
+
+ adjust_total_rect();
+ move_cursor(highlight_start+ustr.size());
+ highlight_start = 0;
+ highlight_end = -1;
+ on_no_text_selected();
+ }
+ else
+ {
+ text_ = text_.substr(0,cursor_pos) + ustr +
+ text_.substr(cursor_pos,text_.size()-cursor_pos);
+
+ adjust_total_rect();
+ move_cursor(cursor_pos+ustr.size());
+ }
+
+
+ // send out the text modified event
+ if (text_modified_handler.is_set())
+ text_modified_handler();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void text_box::
+ move_cursor (
+ unsigned long pos
+ )
+ {
+ using namespace std;
+ const long old_cursor_pos = cursor_pos;
+
+
+
+ // figure out where the cursor is supposed to be
+ cursor_rect = mfont->compute_cursor_rect(get_text_rect(), text_, pos);
+ const point origin(total_rect().left(), total_rect().top());
+
+
+ cursor_pos = pos;
+
+
+ const unsigned long padding = style->get_padding(*mfont);
+
+ // now scroll us so that we can see the current cursor
+ scroll_to_rect(centered_rect(cursor_rect, cursor_rect.width() + padding + 6, cursor_rect.height() + 1));
+
+ // adjust the cursor_rect so that it is relative to the total_rect
+ cursor_rect = translate_rect(cursor_rect, -origin);
+
+ parent.set_im_pos(cursor_rect.left(), cursor_rect.top());
+
+ if (old_cursor_pos != cursor_pos)
+ {
+ if (shift_pos != -1)
+ {
+ highlight_start = std::min(shift_pos,cursor_pos);
+ highlight_end = std::max(shift_pos,cursor_pos)-1;
+ }
+
+ if (highlight_start > highlight_end)
+ on_no_text_selected();
+ else
+ on_text_is_selected();
+
+ recent_movement = true;
+ cursor_visible = true;
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ if (shift_pos == -1)
+ {
+ highlight_start = 0;
+ highlight_end = -1;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// perspective_display member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ perspective_display::
+ perspective_display(
+ drawable_window& w
+ ) :
+ drawable(w,MOUSE_MOVE|MOUSE_CLICK|MOUSE_WHEEL)
+ {
+ clear_overlay();
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ perspective_display::
+ ~perspective_display(
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex lock(m);
+ rectangle old(rect);
+ rect = resize_rect(rect,width,height);
+ tform = camera_transform(tform.get_camera_pos(),
+ tform.get_camera_looking_at(),
+ tform.get_camera_up_direction(),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+ parent.invalidate_rectangle(old+rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ add_overlay (
+ const std::vector<overlay_line>& overlay
+ )
+ {
+ auto_mutex M(m);
+ if (overlay.size() == 0)
+ return;
+ // push this new overlay into our overlay vector
+ overlay_lines.insert(overlay_lines.end(), overlay.begin(), overlay.end());
+
+ for (unsigned long i = 0; i < overlay.size(); ++i)
+ {
+ sum_pts += overlay[i].p1;
+ sum_pts += overlay[i].p2;
+ max_pts.x() = std::max(overlay[i].p1.x(), max_pts.x());
+ max_pts.x() = std::max(overlay[i].p2.x(), max_pts.x());
+ max_pts.y() = std::max(overlay[i].p1.y(), max_pts.y());
+ max_pts.y() = std::max(overlay[i].p2.y(), max_pts.y());
+ max_pts.z() = std::max(overlay[i].p1.z(), max_pts.z());
+ max_pts.z() = std::max(overlay[i].p2.z(), max_pts.z());
+ }
+
+ tform = camera_transform(max_pts,
+ sum_pts/(overlay_lines.size()*2+overlay_dots.size()),
+ vector<double>(0,0,1),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ add_overlay (
+ const std::vector<overlay_dot>& overlay
+ )
+ {
+ auto_mutex M(m);
+ if (overlay.size() == 0)
+ return;
+
+ for (unsigned long i = 0; i < overlay.size(); ++i)
+ {
+ overlay_dots.push_back(overlay[i]);
+
+ sum_pts += overlay[i].p;
+ max_pts.x() = std::max(overlay[i].p.x(), max_pts.x());
+ max_pts.y() = std::max(overlay[i].p.y(), max_pts.y());
+ max_pts.z() = std::max(overlay[i].p.z(), max_pts.z());
+ }
+
+ tform = camera_transform(max_pts,
+ sum_pts/(overlay_lines.size()*2+overlay_dots.size()),
+ vector<double>(0,0,1),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ clear_overlay (
+ )
+ {
+ auto_mutex lock(m);
+ overlay_dots.clear();
+ overlay_lines.clear();
+ sum_pts = vector<double>();
+ max_pts = vector<double>(-std::numeric_limits<double>::infinity(),
+ -std::numeric_limits<double>::infinity(),
+ -std::numeric_limits<double>::infinity());
+
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ set_dot_double_clicked_handler (
+ const any_function<void(const vector<double>&)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ dot_clicked_event_handler = event_handler_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ draw (
+ const canvas& c
+ ) const
+ {
+ if (depth.nr() < (long)c.height() || depth.nc() < (long)c.width())
+ depth.set_size(c.height(), c.width());
+ assign_all_pixels(depth, std::numeric_limits<float>::infinity());
+
+ rectangle area = rect.intersect(c);
+ fill_rect(c, area, 0);
+ for (unsigned long i = 0; i < overlay_lines.size(); ++i)
+ {
+ draw_line(c, tform(overlay_lines[i].p1)+rect.tl_corner(),
+ tform(overlay_lines[i].p2)+rect.tl_corner(),
+ overlay_lines[i].color,
+ area);
+ }
+ for (unsigned long i = 0; i < overlay_dots.size(); ++i)
+ {
+ double scale, distance;
+ point p = tform(overlay_dots[i].p, scale, distance) + rect.tl_corner();
+ if (area.contains(p) && depth[p.y()-c.top()][p.x()-c.left()] > distance)
+ {
+ depth[p.y()-c.top()][p.x()-c.left()] = distance;
+ assign_pixel(c[p.y()-c.top()][p.x()-c.left()], overlay_dots[i].color);
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ on_wheel_up (
+ unsigned long //state
+ )
+ {
+ if (rect.contains(lastx,lasty) == false || hidden || !enabled)
+ return;
+
+ const double alpha = 0.10;
+ const vector<double> delta = alpha*(tform.get_camera_pos() - tform.get_camera_looking_at());
+ tform = camera_transform(
+ tform.get_camera_pos() - delta,
+ tform.get_camera_looking_at(),
+ tform.get_camera_up_direction(),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ on_wheel_down (
+ unsigned long //state
+ )
+ {
+ if (rect.contains(lastx,lasty) == false || hidden || !enabled)
+ return;
+
+ const double alpha = 0.10;
+ const vector<double> delta = alpha*(tform.get_camera_pos() - tform.get_camera_looking_at());
+ tform = camera_transform(
+ tform.get_camera_pos() + delta,
+ tform.get_camera_looking_at(),
+ tform.get_camera_up_direction(),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long, //state
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ if (btn == base_window::LEFT || btn == base_window::RIGHT)
+ {
+ last = point(x,y);
+ }
+ if (is_double_click && btn == base_window::LEFT && enabled && !hidden && overlay_dots.size() != 0)
+ {
+ double best_dist = std::numeric_limits<double>::infinity();
+ unsigned long best_idx = 0;
+ const dpoint pp(x,y);
+ for (unsigned long i = 0; i < overlay_dots.size(); ++i)
+ {
+ dpoint p = tform(overlay_dots[i].p) + rect.tl_corner();
+ double dist = length_squared(p-pp);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ }
+ }
+ if (dot_clicked_event_handler.is_set())
+ dot_clicked_event_handler(overlay_dots[best_idx].p);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void perspective_display::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (!enabled || hidden)
+ return;
+
+ if (state == base_window::LEFT)
+ {
+ const point cur(x, y);
+ dpoint delta = last-cur;
+ last = cur;
+
+ const vector<double> radius = tform.get_camera_pos()-tform.get_camera_looking_at();
+ delta *= 2*pi*length(radius)/600.0;
+ vector<double> tangent_x = tform.get_camera_up_direction().cross(radius).normalize();
+ vector<double> tangent_y = radius.cross(tangent_x).normalize();
+ vector<double> new_pos = tform.get_camera_pos() + tangent_x*delta.x() + tangent_y*-delta.y();
+
+ // now make it have the correct radius relative to the looking at point.
+ new_pos = (new_pos-tform.get_camera_looking_at()).normalize()*length(radius) + tform.get_camera_looking_at();
+
+ tform = camera_transform(new_pos,
+ tform.get_camera_looking_at(),
+ tangent_y,
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+ parent.invalidate_rectangle(rect);
+ }
+ else if (state == (base_window::LEFT|base_window::SHIFT) ||
+ state == base_window::RIGHT)
+ {
+ const point cur(x, y);
+ dpoint delta = last-cur;
+ last = cur;
+
+ const vector<double> radius = tform.get_camera_pos()-tform.get_camera_looking_at();
+ delta *= 2*pi*length(radius)/600.0;
+ vector<double> tangent_x = tform.get_camera_up_direction().cross(radius).normalize();
+ vector<double> tangent_y = radius.cross(tangent_x).normalize();
+
+ vector<double> offset = tangent_x*delta.x() + tangent_y*-delta.y();
+
+
+ tform = camera_transform(
+ tform.get_camera_pos()+offset,
+ tform.get_camera_looking_at()+offset,
+ tform.get_camera_up_direction(),
+ tform.get_camera_field_of_view(),
+ std::min(rect.width(),rect.height()));
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// image_display member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class image_display_functor
+ {
+ const std::string str;
+ const member_function_pointer<const std::string&> mfp;
+ public:
+ image_display_functor (
+ const std::string& str_,
+ const member_function_pointer<const std::string&>& mfp_
+ ) : str(str_),
+ mfp(mfp_)
+ {}
+
+ void operator() (
+ ) const { mfp(str); }
+ };
+ }
+
+ image_display::
+ image_display(
+ drawable_window& w
+ ):
+ scrollable_region(w,KEYBOARD_EVENTS),
+ zoom_in_scale(1),
+ zoom_out_scale(1),
+ drawing_rect(true),
+ rect_is_selected(false),
+ selected_rect(0),
+ default_rect_color(255,0,0,255),
+ parts_menu(w),
+ part_width(100), // "parts" circles are drawn 1.0/part_width size on the screen relative to the size of the bounding rectangle.
+ overlay_editing_enabled(true),
+ highlight_timer(*this, &image_display::timer_event_unhighlight_rect),
+ highlighted_rect(std::numeric_limits<unsigned long>::max()),
+ holding_shift_key(false)
+ {
+ enable_mouse_drag();
+
+ highlight_timer.set_delay_time(250);
+ set_horizontal_scroll_increment(1);
+ set_vertical_scroll_increment(1);
+ set_horizontal_mouse_wheel_scroll_increment(30);
+ set_vertical_mouse_wheel_scroll_increment(30);
+
+ parts_menu.disable();
+
+
+ enable_events();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_part_add (
+ const std::string& part_name
+ )
+ {
+ if (!rect_is_selected)
+ return;
+
+ const point loc = last_right_click_pos;
+
+ // Transform loc from gui window space into the space used by the overlay
+ // rectangles (i.e. relative to the raw image)
+ const point origin(total_rect().tl_corner());
+ point c1 = loc - origin;
+ if (zoom_in_scale != 1)
+ {
+ c1 = c1/(double)zoom_in_scale;
+ }
+ else if (zoom_out_scale != 1)
+ {
+ c1 = c1*(double)zoom_out_scale;
+ }
+
+ overlay_rects[selected_rect].parts[part_name] = c1;
+ parent.invalidate_rectangle(rect);
+
+ if (event_handler.is_set())
+ event_handler();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ image_display::
+ ~image_display(
+ )
+ {
+ highlight_timer.stop_and_wait();
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle image_display::
+ get_image_display_rect (
+ ) const
+ {
+ if (zoom_in_scale != 1)
+ {
+ return rectangle(0,0, img.nc()*zoom_in_scale-1, img.nr()*zoom_in_scale-1);
+ }
+ else if (zoom_out_scale != 1)
+ {
+ return rectangle(0,0, img.nc()/zoom_out_scale-1, img.nr()/zoom_out_scale-1);
+ }
+ else
+ {
+ return dlib::get_rect(img);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const overlay_rect& overlay
+ )
+ {
+ auto_mutex M(m);
+ // push this new overlay into our overlay vector
+ overlay_rects.push_back(overlay);
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const overlay_line& overlay
+ )
+ {
+ auto_mutex M(m);
+
+ // push this new overlay into our overlay vector
+ overlay_lines.push_back(overlay);
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(get_rect_on_screen(rectangle(overlay.p1, overlay.p2)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const overlay_circle& overlay
+ )
+ {
+ auto_mutex M(m);
+
+ // push this new overlay into our overlay vector
+ overlay_circles.push_back(overlay);
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const std::vector<overlay_rect>& overlay
+ )
+ {
+ auto_mutex M(m);
+
+ // push this new overlay into our overlay vector
+ overlay_rects.insert(overlay_rects.end(), overlay.begin(), overlay.end());
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const std::vector<overlay_line>& overlay
+ )
+ {
+ auto_mutex M(m);
+
+ // push this new overlay into our overlay vector
+ overlay_lines.insert(overlay_lines.end(), overlay.begin(), overlay.end());
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_overlay (
+ const std::vector<overlay_circle>& overlay
+ )
+ {
+ auto_mutex M(m);
+
+ // push this new overlay into our overlay vector
+ overlay_circles.insert(overlay_circles.end(), overlay.begin(), overlay.end());
+
+ // make the parent window redraw us now that we changed the overlay
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ clear_overlay (
+ )
+ {
+ auto_mutex M(m);
+ overlay_rects.clear();
+ overlay_lines.clear();
+ overlay_circles.clear();
+ parent.invalidate_rectangle(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle image_display::
+ get_rect_on_screen (
+ rectangle orect
+ ) const
+ {
+ const point origin(total_rect().tl_corner());
+ orect.left() = orect.left()*zoom_in_scale/zoom_out_scale;
+ orect.top() = orect.top()*zoom_in_scale/zoom_out_scale;
+ if (zoom_in_scale != 1)
+ {
+ // make it so the box surrounds the pixels when we zoom in.
+ orect.right() = (orect.right()+1)*zoom_in_scale/zoom_out_scale;
+ orect.bottom() = (orect.bottom()+1)*zoom_in_scale/zoom_out_scale;
+ }
+ else
+ {
+ orect.right() = orect.right()*zoom_in_scale/zoom_out_scale;
+ orect.bottom() = orect.bottom()*zoom_in_scale/zoom_out_scale;
+ }
+
+ return translate_rect(orect, origin);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle image_display::
+ get_rect_on_screen (
+ unsigned long idx
+ ) const
+ {
+ return get_rect_on_screen(overlay_rects[idx].rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ draw (
+ const canvas& c
+ ) const
+ {
+ scrollable_region::draw(c);
+
+ rectangle area = display_rect().intersect(c);
+ if (area.is_empty())
+ return;
+
+ const point origin(total_rect().tl_corner());
+
+ // draw the image on the screen
+ const double scale = zoom_out_scale/(double)zoom_in_scale;
+ const rectangle img_area = total_rect().intersect(area);
+ for (long row = img_area.top(); row <= img_area.bottom(); ++row)
+ {
+ const long rc = row-c.top();
+ const long rimg = (row-origin.y())*scale;
+ for (long col = img_area.left(); col <= img_area.right(); ++col)
+ {
+ assign_pixel(c[rc][col-c.left()],
+ img[rimg][(col-origin.x())*scale]);
+ }
+ }
+
+ // draw the mouse cross-hairs
+ if (holding_shift_key && total_rect().contains(lastx,lasty) )
+ {
+ draw_line(c, point(lastx,-10000), point(lastx,100000),rgb_pixel(255,255,0), area);
+ draw_line(c, point(-10000,lasty), point(100000,lasty),rgb_pixel(255,255,0), area);
+ }
+
+ // now draw all the overlay rectangles
+ for (unsigned long i = 0; i < overlay_rects.size(); ++i)
+ {
+ const rectangle orect = get_rect_on_screen(i);
+ rgb_alpha_pixel color = overlay_rects[i].color;
+ // draw crossed out boxes slightly faded
+ if (overlay_rects[i].crossed_out)
+ color.alpha = 150;
+
+ if (rect_is_selected && selected_rect == i)
+ {
+ draw_rectangle(c, orect, invert_pixel(color), area);
+ }
+ else if (highlighted_rect < overlay_rects.size() && highlighted_rect == i)
+ {
+ // Draw the rectangle wider and with a slightly different color that tapers
+ // out at the edges of the line.
+ hsi_pixel temp;
+ assign_pixel(temp, 0);
+ assign_pixel(temp, overlay_rects[i].color);
+ temp.s = 255;
+ temp.h = temp.h + 20;
+ if (temp.i < 245)
+ temp.i += 10;
+ rgb_pixel p;
+ assign_pixel(p, temp);
+ rgb_alpha_pixel po, po2;
+ assign_pixel(po, p);
+ po.alpha = 160;
+ po2 = po;
+ po2.alpha = 90;
+ draw_rectangle(c, grow_rect(orect,2), po2, area);
+ draw_rectangle(c, grow_rect(orect,1), po, area);
+ draw_rectangle(c, orect, p, area);
+ draw_rectangle(c, shrink_rect(orect,1), po, area);
+ draw_rectangle(c, shrink_rect(orect,2), po2, area);
+ }
+ else
+ {
+ draw_rectangle(c, orect, color, area);
+ }
+
+ if (overlay_rects[i].label.size() != 0)
+ {
+ // make a rectangle that is at the spot we want to draw our string
+ rectangle r(orect.br_corner(), c.br_corner());
+ mfont->draw_string(c, r, overlay_rects[i].label, color, 0,
+ std::string::npos, area);
+ }
+
+
+ // draw circles for each "part" in this overlay rectangle.
+ std::map<std::string,point>::const_iterator itr;
+ for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr)
+ {
+ if (itr->second == OBJECT_PART_NOT_PRESENT)
+ continue;
+
+ const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width));
+ rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size);
+
+ if (rect_is_selected && selected_rect == i &&
+ selected_part_name.size() != 0 && selected_part_name == itr->first)
+ {
+ draw_circle(c, center(temp), temp.width(), invert_pixel(color), area);
+ }
+ else
+ {
+ draw_circle(c, center(temp), temp.width(), color, area);
+ }
+
+ // make a rectangle that is at the spot we want to draw our string
+ rectangle r((temp.br_corner() + temp.bl_corner())/2,
+ c.br_corner());
+ mfont->draw_string(c, r, itr->first, color, 0,
+ std::string::npos, area);
+ }
+
+ if (overlay_rects[i].crossed_out)
+ {
+ if (rect_is_selected && selected_rect == i)
+ {
+ draw_line(c, orect.tl_corner(), orect.br_corner(),invert_pixel(color), area);
+ draw_line(c, orect.bl_corner(), orect.tr_corner(),invert_pixel(color), area);
+ }
+ else
+ {
+ draw_line(c, orect.tl_corner(), orect.br_corner(),color, area);
+ draw_line(c, orect.bl_corner(), orect.tr_corner(),color, area);
+ }
+ }
+ }
+
+ // now draw all the overlay lines
+ for (unsigned long i = 0; i < overlay_lines.size(); ++i)
+ {
+ draw_line(c,
+ zoom_in_scale*overlay_lines[i].p1/zoom_out_scale + origin,
+ zoom_in_scale*overlay_lines[i].p2/zoom_out_scale + origin,
+ overlay_lines[i].color, area);
+ }
+
+ // now draw all the overlay circles
+ for (unsigned long i = 0; i < overlay_circles.size(); ++i)
+ {
+ const point center = zoom_in_scale*overlay_circles[i].center/zoom_out_scale + origin;
+ const int radius = zoom_in_scale*overlay_circles[i].radius/zoom_out_scale;
+ draw_circle(c,
+ center,
+ radius,
+ overlay_circles[i].color, area);
+
+ if (overlay_circles[i].label.size() != 0)
+ {
+ const point temp = center + point(0,radius);
+
+ // make a rectangle that is at the spot we want to draw our string
+ rectangle r(temp, c.br_corner());
+ mfont->draw_string(c, r, overlay_circles[i].label, overlay_circles[i].color, 0,
+ std::string::npos, area);
+ }
+ }
+
+ if (drawing_rect)
+ draw_rectangle(c, rect_to_draw, invert_pixel(default_rect_color), area);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ scrollable_region::on_keydown(key,is_printable, state);
+
+ if (!is_printable && key==base_window::KEY_SHIFT)
+ {
+ if (!holding_shift_key)
+ {
+ holding_shift_key = true;
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (holding_shift_key)
+ {
+ holding_shift_key = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+ if (!is_printable && !hidden && enabled && rect_is_selected &&
+ (key == base_window::KEY_BACKSPACE || key == base_window::KEY_DELETE))
+ {
+ moving_overlay = false;
+ rect_is_selected = false;
+ parts_menu.disable();
+ if (selected_part_name.size() == 0)
+ overlay_rects.erase(overlay_rects.begin() + selected_rect);
+ else
+ overlay_rects[selected_rect].parts.erase(selected_part_name);
+ parent.invalidate_rectangle(rect);
+
+ if (event_handler.is_set())
+ event_handler();
+ }
+
+ if (!hidden && enabled && rect_is_selected &&
+ ((is_printable && key == 'i') || (!is_printable && key==base_window::KEY_END)))
+ {
+ overlay_rects[selected_rect].crossed_out = !overlay_rects[selected_rect].crossed_out;
+ parent.invalidate_rectangle(rect);
+
+ if (event_handler.is_set())
+ event_handler();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ add_labelable_part_name (
+ const std::string& name
+ )
+ {
+ auto_mutex lock(m);
+ if (part_names.insert(name).second)
+ {
+ member_function_pointer<const std::string&> mfp;
+ mfp.set(*this,&image_display::on_part_add);
+ parts_menu.menu().add_menu_item(menu_item_text("Add " + name,impl::image_display_functor(name,mfp)));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ clear_labelable_part_names (
+ )
+ {
+ auto_mutex lock(m);
+ part_names.clear();
+ parts_menu.menu().clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ scrollable_region::on_mouse_down(btn, state, x, y, is_double_click);
+
+ if (state&base_window::SHIFT)
+ {
+ holding_shift_key = true;
+ }
+ else if (holding_shift_key)
+ {
+ holding_shift_key = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+ if (rect.contains(x,y) == false || hidden || !enabled)
+ return;
+
+ if (image_clicked_handler.is_set())
+ {
+ const point origin(total_rect().tl_corner());
+ point p(x,y);
+ p -= origin;
+ if (zoom_in_scale != 1)
+ p = p/zoom_in_scale;
+ else if (zoom_out_scale != 1)
+ p = p*zoom_out_scale;
+
+ if (dlib::get_rect(img).contains(p))
+ image_clicked_handler(p, is_double_click, btn);
+ }
+
+ if (!overlay_editing_enabled)
+ return;
+
+ if (btn == base_window::RIGHT && (state&base_window::SHIFT))
+ {
+ const bool rect_was_selected = rect_is_selected;
+ rect_is_selected = false;
+ parts_menu.disable();
+
+ long best_dist = std::numeric_limits<long>::max();
+ long best_idx = 0;
+ std::string best_part;
+
+ // check if this click landed on any of the overlay rectangles
+ for (unsigned long i = 0; i < overlay_rects.size(); ++i)
+ {
+ const rectangle orect = get_rect_on_screen(i);
+
+ const long dist = distance_to_rect_edge(orect, point(x,y));
+
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ best_part.clear();
+ }
+
+ std::map<std::string,point>::const_iterator itr;
+ for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr)
+ {
+ if (itr->second == OBJECT_PART_NOT_PRESENT)
+ continue;
+
+ const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width));
+ rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size);
+ point c = center(temp);
+
+ // distance from edge of part circle
+ const long dist = static_cast<long>(std::abs(length(c - point(x,y)) + 0.5 - temp.width()));
+ if (dist < best_dist)
+ {
+ best_idx = i;
+ best_dist = dist;
+ best_part = itr->first;
+ }
+ }
+ }
+
+
+ if (best_dist < 13)
+ {
+ moving_overlay = true;
+ moving_rect = best_idx;
+ moving_part_name = best_part;
+ // If we are moving one of the sides of the rectangle rather than one of
+ // the parts circles then we need to figure out which side of the rectangle
+ // we are moving.
+ if (best_part.size() == 0)
+ {
+ // which side is the click closest to?
+ const rectangle orect = get_rect_on_screen(best_idx);
+ const point p = nearest_point(orect,point(x,y));
+ long dist_left = std::abs(p.x()-orect.left());
+ long dist_top = std::abs(p.y()-orect.top());
+ long dist_right = std::abs(p.x()-orect.right());
+ long dist_bottom = std::abs(p.y()-orect.bottom());
+ long min_val = std::min(std::min(dist_left,dist_right),std::min(dist_top,dist_bottom));
+ if (dist_left == min_val)
+ moving_what = MOVING_RECT_LEFT;
+ else if (dist_top == min_val)
+ moving_what = MOVING_RECT_TOP;
+ else if (dist_right == min_val)
+ moving_what = MOVING_RECT_RIGHT;
+ else
+ moving_what = MOVING_RECT_BOTTOM;
+ }
+ else
+ {
+ moving_what = MOVING_PART;
+ }
+ // Do this to make the moving stuff snap to the mouse immediately.
+ on_mouse_move(state|btn,x,y);
+ }
+
+ if (rect_was_selected)
+ parent.invalidate_rectangle(rect);
+
+ return;
+ }
+
+ if (btn == base_window::RIGHT && rect_is_selected)
+ {
+ last_right_click_pos = point(x,y);
+ parts_menu.set_rect(rect);
+ return;
+ }
+
+ if (btn == base_window::LEFT && (state&base_window::CONTROL) && !drawing_rect)
+ {
+ long best_dist = std::numeric_limits<long>::max();
+ long best_idx = 0;
+ // check if this click landed on any of the overlay rectangles
+ for (unsigned long i = 0; i < overlay_rects.size(); ++i)
+ {
+ const rectangle orect = get_rect_on_screen(i);
+ const long dist = distance_to_rect_edge(orect, point(x,y));
+
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ }
+ }
+ if (best_dist < 13)
+ {
+ overlay_rects[best_idx].label = default_rect_label;
+ overlay_rects[best_idx].color = default_rect_color;
+ highlighted_rect = best_idx;
+ highlight_timer.stop();
+ highlight_timer.start();
+ if (event_handler.is_set())
+ event_handler();
+ parent.invalidate_rectangle(rect);
+ }
+ return;
+ }
+
+
+ if (!is_double_click && btn == base_window::LEFT && (state&base_window::SHIFT))
+ {
+ drawing_rect = true;
+ rect_anchor = point(x,y);
+
+ if (rect_is_selected)
+ {
+ rect_is_selected = false;
+ parts_menu.disable();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+ else if (drawing_rect)
+ {
+ if (rect_is_selected)
+ {
+ rect_is_selected = false;
+ parts_menu.disable();
+ }
+
+ drawing_rect = false;
+ parent.invalidate_rectangle(rect);
+ }
+ else if (is_double_click)
+ {
+ const bool rect_was_selected = rect_is_selected;
+ rect_is_selected = false;
+ parts_menu.disable();
+
+ long best_dist = std::numeric_limits<long>::max();
+ long best_idx = 0;
+ std::string best_part;
+
+ // check if this click landed on any of the overlay rectangles
+ for (unsigned long i = 0; i < overlay_rects.size(); ++i)
+ {
+ const rectangle orect = get_rect_on_screen(i);
+
+ const long dist = distance_to_rect_edge(orect, point(x,y));
+
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ best_part.clear();
+ }
+
+ std::map<std::string,point>::const_iterator itr;
+ for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr)
+ {
+ if (itr->second == OBJECT_PART_NOT_PRESENT)
+ continue;
+
+ const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width));
+ rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size);
+ point c = center(temp);
+
+ // distance from edge of part circle
+ const long dist = static_cast<long>(std::abs(length(c - point(x,y)) + 0.5 - temp.width()));
+ if (dist < best_dist)
+ {
+ best_idx = i;
+ best_dist = dist;
+ best_part = itr->first;
+ }
+ }
+ }
+
+
+ if (best_dist < 13)
+ {
+ rect_is_selected = true;
+ if (part_names.size() != 0)
+ parts_menu.enable();
+ selected_rect = best_idx;
+ selected_part_name = best_part;
+ if (orect_selected_event_handler.is_set())
+ orect_selected_event_handler(overlay_rects[best_idx]);
+ }
+
+ if (rect_is_selected || rect_was_selected)
+ parent.invalidate_rectangle(rect);
+ }
+ else if (rect_is_selected)
+ {
+ rect_is_selected = false;
+ parts_menu.disable();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<image_display::overlay_rect> image_display::
+ get_overlay_rects (
+ ) const
+ {
+ auto_mutex lock(m);
+ return overlay_rects;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ set_default_overlay_rect_label (
+ const std::string& label
+ )
+ {
+ auto_mutex lock(m);
+ default_rect_label = label;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string image_display::
+ get_default_overlay_rect_label (
+ ) const
+ {
+ auto_mutex lock(m);
+ return default_rect_label;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ set_default_overlay_rect_color (
+ const rgb_alpha_pixel& color
+ )
+ {
+ auto_mutex lock(m);
+ default_rect_color = color;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ rgb_alpha_pixel image_display::
+ get_default_overlay_rect_color (
+ ) const
+ {
+ auto_mutex lock(m);
+ return default_rect_color;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ scrollable_region::on_mouse_up(btn,state,x,y);
+
+ if (state&base_window::SHIFT)
+ {
+ holding_shift_key = true;
+ }
+ else if (holding_shift_key)
+ {
+ holding_shift_key = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+ if (drawing_rect && btn == base_window::LEFT && (state&base_window::SHIFT) &&
+ !hidden && enabled)
+ {
+ const point origin(total_rect().tl_corner());
+ point c1 = point(x,y) - origin;
+ point c2 = rect_anchor - origin;
+
+ if (zoom_in_scale != 1)
+ {
+ c1 = c1/(double)zoom_in_scale;
+ c2 = c2/(double)zoom_in_scale;
+ }
+ else if (zoom_out_scale != 1)
+ {
+ c1 = c1*(double)zoom_out_scale;
+ c2 = c2*(double)zoom_out_scale;
+ }
+
+ rectangle new_rect(c1,c2);
+ if (zoom_in_scale != 1)
+ {
+ // When we are zoomed in we adjust the rectangles a little so they
+ // are drown surrounding the pixels inside the rect. This adjustment
+ // is necessary to make this code consistent with this goal.
+ new_rect.right() -= 1;
+ new_rect.bottom() -= 1;
+ }
+
+
+ if (new_rect.width() > 0 && new_rect.height() > 0)
+ {
+ add_overlay(overlay_rect(new_rect, default_rect_color, default_rect_label));
+
+ if (event_handler.is_set())
+ event_handler();
+ }
+ }
+
+ if (drawing_rect)
+ {
+ drawing_rect = false;
+ parent.invalidate_rectangle(rect);
+ }
+ if (moving_overlay)
+ {
+ moving_overlay = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ scrollable_region::on_mouse_move(state,x,y);
+
+ if (enabled && !hidden)
+ {
+ if (holding_shift_key)
+ parent.invalidate_rectangle(rect);
+
+ if (state&base_window::SHIFT)
+ holding_shift_key = true;
+ else if (holding_shift_key)
+ holding_shift_key = false;
+ }
+
+ if (drawing_rect)
+ {
+ if ((state&base_window::LEFT) && (state&base_window::SHIFT) && !hidden && enabled)
+ {
+ rectangle new_rect(point(x,y), rect_anchor);
+ parent.invalidate_rectangle(new_rect + rect_to_draw);
+ rect_to_draw = new_rect;
+ }
+ else
+ {
+ drawing_rect = false;
+ parent.invalidate_rectangle(rect);
+ }
+ moving_overlay = false;
+ }
+ else if (moving_overlay)
+ {
+ if ((state&base_window::RIGHT) && (state&base_window::SHIFT) && !hidden && enabled)
+ {
+ // map point(x,y) into the image coordinate space.
+ point p = point(x,y) - total_rect().tl_corner();
+ if (zoom_in_scale != 1)
+ {
+ if (moving_what == MOVING_PART)
+ p = p/(double)zoom_in_scale-dpoint(0.5,0.5);
+ else
+ p = p/(double)zoom_in_scale;
+ }
+ else if (zoom_out_scale != 1)
+ {
+ p = p*(double)zoom_out_scale;
+ }
+
+
+ if (moving_what == MOVING_PART)
+ {
+ if (overlay_rects[moving_rect].parts[moving_part_name] != p)
+ {
+ overlay_rects[moving_rect].parts[moving_part_name] = p;
+ parent.invalidate_rectangle(rect);
+ if (event_handler.is_set())
+ event_handler();
+ }
+ }
+ else
+ {
+ rectangle original = overlay_rects[moving_rect].rect;
+ if (moving_what == MOVING_RECT_LEFT)
+ overlay_rects[moving_rect].rect.left() = std::min(p.x(), overlay_rects[moving_rect].rect.right());
+ else if (moving_what == MOVING_RECT_RIGHT)
+ overlay_rects[moving_rect].rect.right() = std::max(p.x()-1, overlay_rects[moving_rect].rect.left());
+ else if (moving_what == MOVING_RECT_TOP)
+ overlay_rects[moving_rect].rect.top() = std::min(p.y(), overlay_rects[moving_rect].rect.bottom());
+ else
+ overlay_rects[moving_rect].rect.bottom() = std::max(p.y()-1, overlay_rects[moving_rect].rect.top());
+
+ if (original != overlay_rects[moving_rect].rect)
+ {
+ parent.invalidate_rectangle(rect);
+ if (event_handler.is_set())
+ event_handler();
+ }
+ }
+ }
+ else
+ {
+ moving_overlay = false;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_wheel_up (
+ unsigned long state
+ )
+ {
+ // disable mouse wheel if the user is drawing a rectangle
+ if (drawing_rect)
+ return;
+
+ // if CONTROL is not being held down
+ if ((state & base_window::CONTROL) == 0)
+ {
+ scrollable_region::on_wheel_up(state);
+ return;
+ }
+
+ if (rect.contains(lastx,lasty) == false || hidden || !enabled)
+ return;
+
+
+ if (zoom_in_scale < 100 && zoom_out_scale == 1)
+ {
+ const point mouse_loc(lastx, lasty);
+ // the pixel in img that the mouse is over
+ const point pix_loc = (mouse_loc - total_rect().tl_corner())/zoom_in_scale;
+
+ zoom_in_scale = zoom_in_scale*10/9 + 1;
+
+ set_total_rect_size(img.nc()*zoom_in_scale, img.nr()*zoom_in_scale);
+
+ // make is to the pixel under the mouse doesn't move while we zoom
+ const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc*zoom_in_scale);
+ scroll_to_rect(translate_rect(display_rect(), delta));
+ }
+ else if (zoom_out_scale != 1)
+ {
+ const point mouse_loc(lastx, lasty);
+ // the pixel in img that the mouse is over
+ const point pix_loc = (mouse_loc - total_rect().tl_corner())*zoom_out_scale;
+
+ zoom_out_scale = zoom_out_scale*9/10;
+ if (zoom_out_scale == 0)
+ zoom_out_scale = 1;
+
+ set_total_rect_size(img.nc()/zoom_out_scale, img.nr()/zoom_out_scale);
+
+ // make is to the pixel under the mouse doesn't move while we zoom
+ const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc/zoom_out_scale);
+ scroll_to_rect(translate_rect(display_rect(), delta));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_display::
+ on_wheel_down (
+ unsigned long state
+ )
+ {
+ // disable mouse wheel if the user is drawing a rectangle
+ if (drawing_rect)
+ return;
+
+ // if CONTROL is not being held down
+ if ((state & base_window::CONTROL) == 0)
+ {
+ scrollable_region::on_wheel_down(state);
+ return;
+ }
+
+ if (rect.contains(lastx,lasty) == false || hidden || !enabled)
+ return;
+
+
+ if (zoom_in_scale != 1)
+ {
+ const point mouse_loc(lastx, lasty);
+ // the pixel in img that the mouse is over
+ const point pix_loc = (mouse_loc - total_rect().tl_corner())/zoom_in_scale;
+
+ zoom_in_scale = zoom_in_scale*9/10;
+ if (zoom_in_scale == 0)
+ zoom_in_scale = 1;
+
+ set_total_rect_size(img.nc()*zoom_in_scale, img.nr()*zoom_in_scale);
+
+ // make is to the pixel under the mouse doesn't move while we zoom
+ const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc*zoom_in_scale);
+ scroll_to_rect(translate_rect(display_rect(), delta));
+ }
+ else if (std::max(img.nr(), img.nc())/zoom_out_scale > 10)
+ {
+ const point mouse_loc(lastx, lasty);
+ // the pixel in img that the mouse is over
+ const point pix_loc = (mouse_loc - total_rect().tl_corner())*zoom_out_scale;
+
+ zoom_out_scale = zoom_out_scale*10/9 + 1;
+
+ set_total_rect_size(img.nc()/zoom_out_scale, img.nr()/zoom_out_scale);
+
+ // make is to the pixel under the mouse doesn't move while we zoom
+ const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc/zoom_out_scale);
+ scroll_to_rect(translate_rect(display_rect(), delta));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// image_window member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ image_window::
+ image_window(
+ ) :
+ gui_img(*this),
+ window_has_closed(false),
+ have_last_click(false),
+ mouse_btn(0),
+ clicked_signaler(this->wm),
+ tie_input_events(false)
+ {
+
+ gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked);
+ gui_img.disable_overlay_editing();
+ // show this window on the screen
+ show();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ image_window::
+ ~image_window(
+ )
+ {
+ // You should always call close_window() in the destructor of window
+ // objects to ensure that no events will be sent to this window while
+ // it is being destructed.
+ close_window();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ base_window::on_close_return_code image_window::
+ on_window_close(
+ )
+ {
+ window_has_closed = true;
+ clicked_signaler.broadcast();
+ return base_window::CLOSE_WINDOW;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool image_window::
+ get_next_keypress (
+ unsigned long& key,
+ bool& is_printable,
+ unsigned long& state
+ )
+ {
+ auto_mutex lock(wm);
+ while (have_last_keypress == false && !window_has_closed &&
+ (have_last_click == false || !tie_input_events))
+ {
+ clicked_signaler.wait();
+ }
+
+ if (window_has_closed)
+ return false;
+
+ if (have_last_keypress)
+ {
+ // Mark that we are taking the key click so the next call to get_next_keypress()
+ // will have to wait for another click.
+ have_last_keypress = false;
+ key = next_key;
+ is_printable = next_is_printable;
+ state = next_state;
+ return true;
+ }
+ else
+ {
+ key = 0;
+ is_printable = true;
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ dlib::drawable_window::on_keydown(key,is_printable,state);
+
+ have_last_keypress = true;
+ next_key = key;
+ next_is_printable = is_printable;
+ next_state = state;
+ clicked_signaler.signal();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ tie_events (
+ )
+ {
+ auto_mutex lock(wm);
+ tie_input_events = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ untie_events (
+ )
+ {
+ auto_mutex lock(wm);
+ tie_input_events = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool image_window::
+ events_tied (
+ ) const
+ {
+ auto_mutex lock(wm);
+ return tie_input_events;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool image_window::
+ get_next_double_click (
+ point& p,
+ unsigned long& mouse_button
+ )
+ {
+ p = point(-1,-1);
+
+ auto_mutex lock(wm);
+ while (have_last_click == false && !window_has_closed &&
+ (have_last_keypress==false || !tie_input_events))
+ {
+ clicked_signaler.wait();
+ }
+
+ if (window_has_closed)
+ return false;
+
+ if (have_last_click)
+ {
+ // Mark that we are taking the point click so the next call to
+ // get_next_double_click() will have to wait for another click.
+ have_last_click = false;
+ mouse_button = mouse_btn;
+ p = last_clicked_point;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ on_image_clicked (
+ const point& p,
+ bool is_double_click,
+ unsigned long btn
+ )
+ {
+ if (is_double_click)
+ {
+ have_last_click = true;
+ last_clicked_point = p;
+ mouse_btn = btn;
+ clicked_signaler.signal();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const overlay_rect& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const overlay_line& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const overlay_circle& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const std::vector<overlay_rect>& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const std::vector<overlay_line>& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ add_overlay (
+ const std::vector<overlay_circle>& overlay
+ )
+ {
+ gui_img.add_overlay(overlay);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ clear_overlay (
+ )
+ {
+ gui_img.clear_overlay();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void image_window::
+ on_window_resized(
+ )
+ {
+ drawable_window::on_window_resized();
+ unsigned long width, height;
+ get_size(width,height);
+ gui_img.set_size(width, height);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_WIDGETs_CPP_
+
diff --git a/ml/dlib/dlib/gui_widgets/widgets.h b/ml/dlib/dlib/gui_widgets/widgets.h
new file mode 100644
index 000000000..305d7dc00
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/widgets.h
@@ -0,0 +1,4165 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_WIDGETs_
+#define DLIB_WIDGETs_
+
+#include <cctype>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "../algs.h"
+#include "widgets_abstract.h"
+#include "drawable.h"
+#include "../gui_core.h"
+#include "fonts.h"
+#include "../timer.h"
+#include "base_widgets.h"
+#include "../member_function_pointer.h"
+#include "../array.h"
+#include "../array2d.h"
+#include "../sequence.h"
+#include "../dir_nav.h"
+#include "../queue.h"
+#include "style.h"
+#include "../string.h"
+#include "../misc_api.h"
+#include "../any.h"
+#include "../image_processing/full_object_detection.h"
+
+#ifdef _MSC_VER
+// This #pragma directive is also located in the algs.h file but for whatever
+// reason visual studio 9 just ignores it when it is only there.
+
+// this is to disable the "'this' : used in base member initializer list"
+// warning you get from some of the GUI objects since all the objects
+// require that their parent class be passed into their constructor.
+// In this case though it is totally safe so it is ok to disable this warning.
+#pragma warning(disable : 4355)
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class label
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class label : public drawable
+ {
+ public:
+ label(
+ drawable_window& w
+ ) :
+ drawable(w),
+ text_color_(0,0,0)
+ {
+ enable_events();
+ }
+
+ ~label()
+ { disable_events(); parent.invalidate_rectangle(rect); }
+
+ void set_text (
+ const std::string& text
+ );
+
+ void set_text (
+ const std::wstring& text
+ );
+
+ void set_text (
+ const dlib::ustring& text
+ );
+
+ const std::string text (
+ ) const;
+
+ const std::wstring wtext (
+ ) const;
+
+ const dlib::ustring utext (
+ ) const;
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+
+ const rgb_pixel text_color (
+ ) const;
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ private:
+ dlib::ustring text_;
+ rgb_pixel text_color_;
+
+
+ // restricted functions
+ label(label&); // copy constructor
+ label& operator=(label&); // assignment operator
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class toggle_button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button : public button_action
+ {
+ /*!
+ INITIAL VALUE
+ - checked == false
+
+ CONVENTION
+ - is_checked() == checked
+ !*/
+
+ public:
+
+ toggle_button(
+ drawable_window& w
+ ) :
+ button_action(w),
+ btn_tooltip(w),
+ checked(false)
+ {
+ style.reset(new toggle_button_style_default());
+ enable_events();
+ }
+
+ ~toggle_button() { disable_events(); parent.invalidate_rectangle(rect); }
+
+ void set_name (
+ const std::string& name
+ );
+
+ void set_name (
+ const std::wstring& name
+ );
+
+ void set_name (
+ const dlib::ustring& name
+ );
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+
+ void set_tooltip_text (
+ const std::string& text
+ );
+
+ void set_tooltip_text (
+ const std::wstring& text
+ );
+
+ void set_tooltip_text (
+ const ustring& text
+ );
+
+ const std::string tooltip_text (
+ ) const;
+
+ const std::wstring tooltip_wtext (
+ ) const;
+
+ const dlib::ustring tooltip_utext (
+ ) const;
+
+ bool is_checked (
+ ) const;
+
+ const std::string name (
+ ) const;
+
+ const std::wstring wname (
+ ) const;
+
+ const dlib::ustring uname (
+ ) const;
+
+ void set_checked (
+ );
+
+ void set_unchecked (
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ rect = move_rect(style->get_min_size(name_,*mfont), rect.left(), rect.top());
+ parent.invalidate_rectangle(rect);
+ }
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler_)()
+ )
+ {
+ auto_mutex M(m);
+ event_handler = make_mfp(object,event_handler_);
+ event_handler_self.clear();
+ }
+
+ void set_click_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ event_handler = event_handler_;
+ event_handler_self.clear();
+ }
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler_)(toggle_button&)
+ )
+ {
+ auto_mutex M(m);
+ event_handler_self = make_mfp(object,event_handler_);
+ event_handler.clear();
+ }
+
+ void set_sourced_click_handler (
+ const any_function<void(toggle_button&)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ event_handler_self = event_handler_;
+ event_handler.clear();
+ }
+
+ private:
+
+ // restricted functions
+ toggle_button(toggle_button&); // copy constructor
+ toggle_button& operator=(toggle_button&); // assignment operator
+
+ dlib::ustring name_;
+ tooltip btn_tooltip;
+ bool checked;
+
+ any_function<void()> event_handler;
+ any_function<void(toggle_button&)> event_handler_self;
+
+ std::unique_ptr<toggle_button_style> style;
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const { style->draw_toggle_button(c,rect,enabled,*mfont,lastx,lasty,name_,is_depressed(),checked); }
+
+ void on_button_up (
+ bool mouse_over
+ );
+
+ void on_mouse_over (
+ ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(rect); }
+
+ void on_mouse_not_over (
+ ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(rect); }
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class text_field
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_field : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ text_color_ == rgb_pixel(0,0,0)
+ bg_color_ == rgb_pixel(255,255,255)
+ cursor_pos == 0
+ text_width == 0
+ text_ == ""
+ has_focus == false
+ cursor_visible == false
+ recent_movement == false
+ highlight_start == 0
+ highlight_end == -1
+ shift_pos == -1
+ text_pos == 0
+
+ CONVENTION
+ - cursor_pos == the position of the cursor in the string text_. The
+ cursor appears before the letter text_[cursor_pos]
+ - cursor_x == the x coordinate of the cursor relative to the left side
+ of rect. i.e. the number of pixels that separate the cursor from the
+ left side of the text_field.
+ - has_focus == true if this text field has keyboard input focus
+ - cursor_visible == true if the cursor should be painted
+ - text_ == text()
+ - text_pos == the index of the first letter in text_ that appears in
+ this text field.
+ - text_width == the width of text_[text_pos] though text_[text.size()-1]
+
+ - if (has_focus && the user has recently moved the cursor) then
+ - recent_movement == true
+ - else
+ - recent_movement == false
+
+ - if (highlight_start <= highlight_end) then
+ - text[highlight_start] though text[highlight_end] should be
+ highlighted
+
+ - if (shift_pos != -1) then
+ - has_focus == true
+ - the shift key is being held down or the left mouse button is
+ being held down.
+ - shift_pos == the position of the cursor when the shift or mouse key
+ was first pressed.
+
+ - text_color() == text_color_
+ - background_color() == bg_color_
+ !*/
+
+ public:
+ text_field(
+ drawable_window& w
+ ) :
+ drawable(w,MOUSE_CLICK | KEYBOARD_EVENTS | MOUSE_MOVE | STRING_PUT),
+ text_color_(0,0,0),
+ bg_color_(255,255,255),
+ text_width(0),
+ text_pos(0),
+ recent_movement(false),
+ has_focus(false),
+ cursor_visible(false),
+ cursor_pos(0),
+ highlight_start(0),
+ highlight_end(-1),
+ shift_pos(-1),
+ t(*this,&text_field::timer_action),
+ right_click_menu(w)
+ {
+ style.reset(new text_field_style_default());
+ rect.set_bottom(mfont->height()+ (style->get_padding(*mfont))*2);
+ rect.set_right((style->get_padding(*mfont))*2);
+ cursor_x = style->get_padding(*mfont);
+
+ right_click_menu.menu().add_menu_item(menu_item_text("Cut",*this,&text_field::on_cut,'t'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Copy",*this,&text_field::on_copy,'C'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Paste",*this,&text_field::on_paste,'P'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Delete",*this,&text_field::on_delete_selected,'D'));
+ right_click_menu.menu().add_menu_item(menu_item_separator());
+ right_click_menu.menu().add_menu_item(menu_item_text("Select All",*this,&text_field::on_select_all,'A'));
+
+ right_click_menu.set_rect(get_text_rect());
+ enable_events();
+
+ t.set_delay_time(500);
+ }
+
+ ~text_field (
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ t.stop_and_wait();
+ }
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ // call this just so that this widget redraws itself with the new style
+ set_main_font(mfont);
+ }
+
+ void set_text (
+ const std::string& text_
+ );
+
+ void set_text (
+ const std::wstring& text_
+ );
+
+ void give_input_focus (
+ );
+
+ bool has_input_focus (
+ ) const;
+
+ void select_all_text (
+ );
+
+ void set_text (
+ const dlib::ustring& text_
+ );
+
+ const std::string text (
+ ) const;
+
+ const std::wstring wtext (
+ ) const;
+
+ const dlib::ustring utext (
+ ) const;
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+
+ const rgb_pixel text_color (
+ ) const;
+
+ void set_background_color (
+ const rgb_pixel color
+ );
+
+ const rgb_pixel background_color (
+ ) const;
+
+ void set_width (
+ unsigned long width
+ );
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ int next_free_user_event_number (
+ ) const
+ {
+ return drawable::next_free_user_event_number()+1;
+ }
+
+ void disable (
+ );
+
+ void enable (
+ );
+
+ void hide (
+ );
+
+ void show (
+ );
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ text_modified_handler = make_mfp(object,event_handler);
+ }
+
+ template <
+ typename T
+ >
+ void set_enter_key_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ enter_key_handler = make_mfp(object,event_handler);
+ }
+
+ void set_text_modified_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ text_modified_handler = event_handler;
+ }
+
+ void set_enter_key_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ enter_key_handler = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_focus_lost_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ focus_lost_handler = make_mfp(object,event_handler);
+ }
+
+ void set_focus_lost_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ focus_lost_handler = event_handler;
+ }
+
+ private:
+
+ void on_cut (
+ );
+
+ void on_copy (
+ );
+
+ void on_paste (
+ );
+
+ void on_select_all (
+ );
+
+ void on_delete_selected (
+ );
+
+ void on_text_is_selected (
+ );
+
+ void on_no_text_selected (
+ );
+
+ void on_user_event (
+ int num
+ )
+ {
+ // ignore this user event if it isn't for us
+ if (num != drawable::next_free_user_event_number())
+ return;
+
+ if (recent_movement == false)
+ {
+ cursor_visible = !cursor_visible;
+ parent.invalidate_rectangle(rect);
+ }
+ else
+ {
+ if (cursor_visible == false)
+ {
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ }
+ recent_movement = false;
+ }
+ }
+
+ void timer_action (
+ ) { parent.trigger_user_event(this,drawable::next_free_user_event_number()); }
+ /*!
+ ensures
+ - flips the state of cursor_visible
+ !*/
+
+ void move_cursor (
+ unsigned long pos
+ );
+ /*!
+ requires
+ - pos <= text_.size()
+ ensures
+ - moves the cursor to the position given by pos and moves the text
+ in the text box if necessary
+ - if the position changes then the parent window will be updated
+ !*/
+
+ rectangle get_text_rect (
+ ) const;
+ /*!
+ ensures
+ - returns the rectangle that should contain the text in this widget
+ !*/
+
+ dlib::ustring text_;
+ rgb_pixel text_color_;
+ rgb_pixel bg_color_;
+
+ unsigned long text_width;
+ unsigned long text_pos;
+
+
+ bool recent_movement;
+ bool has_focus;
+ bool cursor_visible;
+ long cursor_pos;
+ unsigned long cursor_x;
+
+ // this tells you what part of the text is highlighted
+ long highlight_start;
+ long highlight_end;
+ long shift_pos;
+ any_function<void()> text_modified_handler;
+ any_function<void()> enter_key_handler;
+ any_function<void()> focus_lost_handler;
+
+ std::unique_ptr<text_field_style> style;
+
+ timer<text_field> t;
+
+ popup_menu_region right_click_menu;
+
+ // restricted functions
+ text_field(text_field&); // copy constructor
+ text_field& operator=(text_field&); // assignment operator
+
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_string_put (
+ const std::wstring &str
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class text_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_box : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ text_color_ == rgb_pixel(0,0,0)
+ bg_color_ == rgb_pixel(255,255,255)
+ cursor_pos == 0
+ text_ == ""
+ has_focus == false
+ cursor_visible == false
+ recent_movement == false
+ highlight_start == 0
+ highlight_end == -1
+ shift_pos == -1
+
+ CONVENTION
+ - cursor_pos == the position of the cursor in the string text_. The
+ cursor appears before the letter text_[cursor_pos]
+ - cursor_rect == The rectangle that should be drawn for the cursor.
+ The position is relative to total_rect().
+ - has_focus == true if this text field has keyboard input focus
+ - cursor_visible == true if the cursor should be painted
+ - text_ == text()
+
+ - if (has_focus && the user has recently moved the cursor) then
+ - recent_movement == true
+ - else
+ - recent_movement == false
+
+ - if (highlight_start <= highlight_end) then
+ - text[highlight_start] though text[highlight_end] should be
+ highlighted
+
+ - if (shift_pos != -1) then
+ - has_focus == true
+ - the shift key is being held down or the left mouse button is
+ being held down.
+ - shift_pos == the position of the cursor when the shift or mouse key
+ was first pressed.
+
+ - text_color() == text_color_
+ - background_color() == bg_color_
+ !*/
+
+ public:
+ text_box(
+ drawable_window& w
+ ) :
+ scrollable_region(w,MOUSE_CLICK | KEYBOARD_EVENTS | MOUSE_MOVE | STRING_PUT),
+ text_color_(0,0,0),
+ bg_color_(255,255,255),
+ recent_movement(false),
+ has_focus(false),
+ cursor_visible(false),
+ cursor_pos(0),
+ highlight_start(0),
+ highlight_end(-1),
+ shift_pos(-1),
+ t(*this,&text_box::timer_action),
+ right_click_menu(w)
+ {
+ style.reset(new text_box_style_default());
+
+ const long padding = static_cast<long>(style->get_padding(*mfont));
+ cursor_rect = mfont->compute_cursor_rect(rectangle(padding,padding,1000000,1000000), text_, 0);
+
+ adjust_total_rect();
+
+ set_vertical_mouse_wheel_scroll_increment(mfont->height());
+ set_horizontal_mouse_wheel_scroll_increment(mfont->height());
+
+ right_click_menu.menu().add_menu_item(menu_item_text("Cut",*this,&text_box::on_cut,'t'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Copy",*this,&text_box::on_copy,'C'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Paste",*this,&text_box::on_paste,'P'));
+ right_click_menu.menu().add_menu_item(menu_item_text("Delete",*this,&text_box::on_delete_selected,'D'));
+ right_click_menu.menu().add_menu_item(menu_item_separator());
+ right_click_menu.menu().add_menu_item(menu_item_text("Select All",*this,&text_box::on_select_all,'A'));
+
+ right_click_menu.set_rect(get_text_rect());
+
+ set_size(100,100);
+
+ enable_events();
+
+ t.set_delay_time(500);
+ }
+
+ ~text_box (
+ )
+ {
+ disable_events();
+ parent.invalidate_rectangle(rect);
+ t.stop_and_wait();
+ }
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+
+ scrollable_region::set_style(style_.get_scrollable_region_style());
+ // call this just so that this widget redraws itself with the new style
+ set_main_font(mfont);
+ }
+
+ void set_text (
+ const std::string& text_
+ );
+
+ void set_text (
+ const std::wstring& text_
+ );
+
+ void set_text (
+ const dlib::ustring& text_
+ );
+
+ const std::string text (
+ ) const;
+
+ const std::wstring wtext (
+ ) const;
+
+ const dlib::ustring utext (
+ ) const;
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+
+ const rgb_pixel text_color (
+ ) const;
+
+ void set_background_color (
+ const rgb_pixel color
+ );
+
+ const rgb_pixel background_color (
+ ) const;
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ int next_free_user_event_number (
+ ) const
+ {
+ return scrollable_region::next_free_user_event_number()+1;
+ }
+
+ void disable (
+ );
+
+ void enable (
+ );
+
+ void hide (
+ );
+
+ void show (
+ );
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ text_modified_handler = make_mfp(object,event_handler);
+ }
+
+ void set_text_modified_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ text_modified_handler = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_enter_key_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ enter_key_handler = make_mfp(object,event_handler);
+ }
+
+ void set_enter_key_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ enter_key_handler = event_handler;
+ }
+
+ template <
+ typename T
+ >
+ void set_focus_lost_handler (
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ auto_mutex M(m);
+ focus_lost_handler = make_mfp(object,event_handler);
+ }
+
+ void set_focus_lost_handler (
+ const any_function<void()>& event_handler
+ )
+ {
+ auto_mutex M(m);
+ focus_lost_handler = event_handler;
+ }
+
+ private:
+
+ void on_cut (
+ );
+
+ void on_copy (
+ );
+
+ void on_paste (
+ );
+
+ void on_select_all (
+ );
+
+ void on_delete_selected (
+ );
+
+ void on_text_is_selected (
+ );
+
+ void on_no_text_selected (
+ );
+
+ void on_user_event (
+ int num
+ )
+ {
+ // ignore this user event if it isn't for us
+ if (num != scrollable_region::next_free_user_event_number())
+ return;
+
+ if (recent_movement == false)
+ {
+ cursor_visible = !cursor_visible;
+ parent.invalidate_rectangle(rect);
+ }
+ else
+ {
+ if (cursor_visible == false)
+ {
+ cursor_visible = true;
+ parent.invalidate_rectangle(rect);
+ }
+ recent_movement = false;
+ }
+ }
+
+ // The reason for using user actions here rather than just having the timer just call
+ // what it needs directly is to avoid a potential deadlock during destruction of this widget.
+ void timer_action (
+ ) { parent.trigger_user_event(this,scrollable_region::next_free_user_event_number()); }
+ /*!
+ ensures
+ - flips the state of cursor_visible
+ !*/
+
+ void move_cursor (
+ unsigned long pos
+ );
+ /*!
+ requires
+ - pos <= text_.size()
+ ensures
+ - moves the cursor to the position given by pos and moves the text
+ in the text box if necessary
+ - if the position changes then the parent window will be updated
+ !*/
+
+ rectangle get_text_rect (
+ ) const;
+ /*!
+ ensures
+ - returns the rectangle that should contain the text in this widget
+ !*/
+
+ void adjust_total_rect (
+ );
+ /*!
+ ensures
+ - adjusts total_rect() so that it is big enough to contain the text
+ currently in this object.
+ !*/
+
+ dlib::ustring text_;
+ rgb_pixel text_color_;
+ rgb_pixel bg_color_;
+
+
+
+ bool recent_movement;
+ bool has_focus;
+ bool cursor_visible;
+ long cursor_pos;
+ rectangle cursor_rect;
+
+ // this tells you what part of the text is highlighted
+ long highlight_start;
+ long highlight_end;
+ long shift_pos;
+ any_function<void()> text_modified_handler;
+ any_function<void()> enter_key_handler;
+ any_function<void()> focus_lost_handler;
+
+ std::unique_ptr<text_box_style> style;
+
+ timer<text_box> t;
+
+ popup_menu_region right_click_menu;
+
+ // restricted functions
+ text_box(text_box&); // copy constructor
+ text_box& operator=(text_box&); // assignment operator
+
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_string_put (
+ const std::wstring &str
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class check_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class check_box : public toggle_button
+ {
+ public:
+ check_box(
+ drawable_window& w
+ ) : toggle_button(w)
+ {
+ set_style(toggle_button_style_check_box());
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class radio_button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class radio_button : public toggle_button
+ {
+ public:
+ radio_button (
+ drawable_window& w
+ ) : toggle_button(w)
+ {
+ set_style(toggle_button_style_radio_button());
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class tabbed_display
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class tabbed_display : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - tabs.size() == 0
+ - selected_tab_ == 0
+
+ CONVENTION
+ - number_of_tabs() == tabs.size()
+ - tab_name(idx) == tabs[idx]
+ - if (tabs.size() > 0) then
+ - selected_tab_ == the index of the tab that is currently selected
+
+ - for all valid i:
+ - tabs[i].width == mfont->compute_size(tabs[i].name)
+ - tabs[i].rect == the rectangle that defines where this tab is
+ - if (tabs[i].group != 0) then
+ - tabs[i].group == a pointer to the widget_group for this tab.
+
+ - left_pad == the amount of padding in a tab to the left of the name string.
+ - right_pad == the amount of padding in a tab to the right of the name string.
+ - top_pad == the amount of padding in a tab to the top of the name string.
+ - bottom_pad == the amount of padding in a tab to the bottom of the name string.
+
+ - if (event_handler.is_set()) then
+ - event_handler() is what is called to process click events
+ on this object.
+ !*/
+
+ public:
+
+ tabbed_display(
+ drawable_window& w
+ );
+
+ virtual ~tabbed_display(
+ );
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void set_number_of_tabs (
+ unsigned long num
+ );
+
+ unsigned long selected_tab (
+ ) const;
+
+ unsigned long number_of_tabs (
+ ) const;
+
+ const std::string tab_name (
+ unsigned long idx
+ ) const;
+
+ const std::wstring tab_wname (
+ unsigned long idx
+ ) const;
+
+ const dlib::ustring& tab_uname (
+ unsigned long idx
+ ) const;
+
+ void set_tab_name (
+ unsigned long idx,
+ const std::string& new_name
+ );
+
+ void set_tab_name (
+ unsigned long idx,
+ const std::wstring& new_name
+ );
+
+ void set_tab_name (
+ unsigned long idx,
+ const dlib::ustring& new_name
+ );
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*eh)(unsigned long new_idx,unsigned long old_idx)
+ )
+ {
+ auto_mutex M(m);
+ event_handler = make_mfp(object,eh);
+ }
+
+ void set_click_handler (
+ const any_function<void(unsigned long,unsigned long)>& eh
+ )
+ {
+ auto_mutex M(m);
+ event_handler = eh;
+ }
+
+ void set_tab_group (
+ unsigned long idx,
+ widget_group& group
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ void fit_to_contents (
+ );
+
+ protected:
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ private:
+ void recompute_tabs (
+ );
+ /*!
+ ensures
+ - recomputes the rectangles for all the tabs and makes this object
+ wider if needed
+ !*/
+
+ void draw_tab (
+ const rectangle& tab,
+ const canvas& c
+ ) const;
+ /*!
+ ensures
+ - draws the outline of a tab as given by the rectangle onto c
+ !*/
+
+ struct tab_data
+ {
+ tab_data() : width(0), group(0) {}
+
+ dlib::ustring name;
+ unsigned long width;
+ rectangle rect;
+ widget_group* group;
+ };
+
+ unsigned long selected_tab_;
+
+ array<tab_data> tabs;
+
+ const long left_pad;
+ const long right_pad;
+ const long top_pad;
+ const long bottom_pad;
+
+ any_function<void(unsigned long,unsigned long)> event_handler;
+
+ // restricted functions
+ tabbed_display(tabbed_display&); // copy constructor
+ tabbed_display& operator=(tabbed_display&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class named_rectangle
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class named_rectangle : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ name == ""
+
+ CONVENTION
+ name_ == name()
+ !*/
+
+ public:
+
+ named_rectangle(
+ drawable_window& w
+ );
+
+ virtual ~named_rectangle(
+ );
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ void set_name (
+ const std::string& name
+ );
+
+ void set_name (
+ const std::wstring& name
+ );
+
+ void set_name (
+ const dlib::ustring& name
+ );
+
+ const std::string name (
+ ) const;
+
+ const std::wstring wname (
+ ) const;
+
+ const dlib::ustring uname (
+ ) const;
+
+ void wrap_around (
+ const rectangle& rect
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ protected:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ private:
+
+ void make_name_fit_in_rect (
+ );
+
+ dlib::ustring name_;
+ unsigned long name_width;
+ unsigned long name_height;
+
+ // restricted functions
+ named_rectangle(named_rectangle&); // copy constructor
+ named_rectangle& operator=(named_rectangle&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class mouse_tracker
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class mouse_tracker : public draggable
+ {
+
+ public:
+
+ mouse_tracker(
+ drawable_window& w
+ );
+
+ ~mouse_tracker(
+ );
+
+ void show (
+ );
+
+ void hide (
+ );
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ void set_pos (
+ long x,
+ long y
+ );
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ protected:
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_drag (
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+
+ private:
+
+ const long offset;
+ named_rectangle nr;
+ label x_label;
+ label y_label;
+ std::ostringstream sout;
+
+ long click_x, click_y;
+
+ // restricted functions
+ mouse_tracker(mouse_tracker&); // copy constructor
+ mouse_tracker& operator=(mouse_tracker&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function message_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace message_box_helper
+ {
+ class box_win : public drawable_window
+ {
+ void initialize (
+ );
+ public:
+ box_win (
+ const std::string& title_,
+ const std::string& message_
+ );
+
+ box_win (
+ const std::wstring& title_,
+ const std::wstring& message_
+ );
+
+ box_win (
+ const dlib::ustring& title_,
+ const dlib::ustring& message_
+ );
+
+ ~box_win (
+ );
+
+ void set_click_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(wm);
+ event_handler = event_handler_;
+ }
+
+ private:
+
+ static void deleter_thread (
+ void* param
+ );
+
+ void on_click (
+ );
+
+ on_close_return_code on_window_close (
+ );
+
+ const std::wstring title;
+ const std::wstring message;
+ label msg;
+ button btn_ok;
+
+ any_function<void()> event_handler;
+ };
+
+ class blocking_box_win : public drawable_window
+ {
+ void initialize (
+ );
+
+ public:
+ blocking_box_win (
+ const std::string& title_,
+ const std::string& message_
+ );
+
+ blocking_box_win (
+ const std::wstring& title_,
+ const std::wstring& message_
+ );
+
+ blocking_box_win (
+ const dlib::ustring& title_,
+ const dlib::ustring& message_
+ );
+
+ ~blocking_box_win (
+ );
+
+ private:
+
+ void on_click (
+ );
+
+ const std::wstring title;
+ const std::wstring message;
+ label msg;
+ button btn_ok;
+ };
+ }
+
+ template <
+ typename T
+ >
+ void message_box (
+ const std::string& title,
+ const std::string& message,
+ T& object,
+ void (T::*event_handler)()
+ )
+ {
+ using namespace message_box_helper;
+ box_win* win = new box_win(title,message);
+ win->set_click_handler(make_mfp(object,event_handler));
+ }
+
+ inline void message_box (
+ const std::string& title,
+ const std::string& message,
+ const any_function<void()>& event_handler
+ )
+ {
+ using namespace message_box_helper;
+ box_win* win = new box_win(title,message);
+ win->set_click_handler(event_handler);
+ }
+
+ inline void message_box (
+ const std::string& title,
+ const std::string& message
+ )
+ {
+ using namespace message_box_helper;
+ new box_win(title,message);
+ }
+
+ inline void message_box_blocking (
+ const std::string& title,
+ const std::string& message
+ )
+ {
+ using namespace message_box_helper;
+ blocking_box_win w(title,message);
+ w.wait_until_closed();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class list_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace list_box_helper{
+ template <typename S = std::string>
+ class list_box : public scrollable_region,
+ public enumerable<const S>
+ {
+ /*!
+ INITIAL VALUE
+ - ms_enabled == false
+ - items.size() == 0
+ - last_selected = 0
+
+ CONVENTION
+ - size() == items.size()
+ - (*this)[i] == items[i].name
+ - is_selected(i) == items[i].is_selected
+ - ms_enabled == multiple_select_enabled()
+
+ - items[i].width == the width of items[i].name as given by font::compute_size()
+ - items[i].height == the height of items[i].name as given by font::compute_size()
+
+ - last_selected == the last item the user selected
+ !*/
+
+ public:
+
+ list_box(
+ drawable_window& w
+ );
+
+ ~list_box(
+ );
+
+ bool is_selected (
+ unsigned long index
+ ) const;
+
+ void select (
+ unsigned long index
+ );
+
+ void unselect (
+ unsigned long index
+ );
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style_
+ )
+ {
+ auto_mutex M(m);
+ style.reset(new style_type(style_));
+ scrollable_region::set_style(style_.get_scrollable_region_style());
+ parent.invalidate_rectangle(rect);
+ }
+
+ template <typename T>
+ void get_selected (
+ T& list
+ ) const
+ {
+ auto_mutex M(m);
+ list.clear();
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (items[i].is_selected)
+ {
+ unsigned long idx = i;
+ list.enqueue(idx);
+ }
+ }
+ }
+
+ template <typename T>
+ void load (
+ const T& list
+ )
+ {
+ auto_mutex M(m);
+ items.clear();
+ unsigned long i = 0;
+ items.set_max_size(list.size());
+ items.set_size(list.size());
+ list.reset();
+ unsigned long max_width = 0;
+ unsigned long total_height = 0;
+ while (list.move_next())
+ {
+ items[i].is_selected = false;
+ items[i].name = list.element();
+ mfont->compute_size(items[i].name,items[i].width, items[i].height);
+
+ if (items[i].width > max_width)
+ max_width = items[i].width;
+ total_height += items[i].height;
+
+ ++i;
+ }
+ set_total_rect_size(max_width, total_height);
+
+ parent.invalidate_rectangle(rect);
+ last_selected = 0;
+ }
+
+ const S& operator[] (
+ unsigned long index
+ ) const;
+
+ bool multiple_select_enabled (
+ ) const;
+
+ void enable_multiple_select (
+ );
+
+ void disable_multiple_select (
+ );
+
+ template <
+ typename T
+ >
+ void set_double_click_handler (
+ T& object,
+ void (T::*eh)(unsigned long index)
+ ) { auto_mutex M(m); event_handler = make_mfp(object,eh); }
+
+ void set_double_click_handler (
+ const any_function<void(unsigned long)>& eh
+ ) { auto_mutex M(m); event_handler = eh; }
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*eh)(unsigned long index)
+ ) { auto_mutex M(m); single_click_event_handler = make_mfp(object,eh); }
+
+ void set_click_handler (
+ const any_function<void(unsigned long)>& eh
+ ) { auto_mutex M(m); single_click_event_handler = eh; }
+
+ bool at_start (
+ ) const;
+
+ void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const S& element (
+ ) const;
+
+ const S& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ size_t size (
+ ) const;
+
+ unsigned long get_selected (
+ ) const;
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ private:
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ template <typename SS>
+ struct data
+ {
+ SS name;
+ bool is_selected;
+ unsigned long width;
+ unsigned long height;
+ };
+
+ bool ms_enabled;
+ array<data<S> > items;
+ any_function<void(unsigned long)> event_handler;
+ any_function<void(unsigned long)> single_click_event_handler;
+ unsigned long last_selected;
+
+ std::unique_ptr<list_box_style> style;
+
+ // restricted functions
+ list_box(list_box&); // copy constructor
+ list_box& operator=(list_box&); // assignment operator
+ };
+ }
+ typedef list_box_helper::list_box<std::string> list_box;
+ typedef list_box_helper::list_box<std::wstring> wlist_box;
+ typedef list_box_helper::list_box<dlib::ustring> ulist_box;
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function open_file_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace open_file_box_helper
+ {
+ class box_win : public drawable_window
+ {
+ public:
+ box_win (
+ const std::string& title,
+ bool has_text_field = false
+ );
+
+ ~box_win (
+ );
+
+ void set_click_handler (
+ const any_function<void(const std::string&)>& event_handler_
+ )
+ {
+ auto_mutex M(wm);
+ event_handler = event_handler_;
+ }
+
+ private:
+
+ void set_sizes(
+ );
+
+ void on_window_resized (
+ );
+
+ void deleter_thread (
+ );
+
+ void enter_folder (
+ const std::string& folder_name
+ );
+
+ void on_dirs_click (
+ unsigned long idx
+ );
+
+ void on_files_click (
+ unsigned long idx
+ );
+
+ void on_files_double_click (
+ unsigned long
+ );
+
+ void on_cancel_click (
+ );
+
+ void on_open_click (
+ );
+
+ void on_path_button_click (
+ toggle_button& btn
+ );
+
+ bool set_dir (
+ const std::string& dir
+ );
+
+ void on_root_click (
+ );
+
+ on_close_return_code on_window_close (
+ );
+
+ label lbl_dirs;
+ label lbl_files;
+ label lbl_file_name;
+ list_box lb_dirs;
+ list_box lb_files;
+ button btn_ok;
+ button btn_cancel;
+ toggle_button btn_root;
+ text_field tf_file_name;
+ std::string path;
+ std::string prefix;
+ int cur_dir;
+
+ any_function<void(const std::string&)> event_handler;
+ sequence<std::unique_ptr<toggle_button> >::kernel_2a_c sob;
+ };
+ }
+
+ template <
+ typename T
+ >
+ void open_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Open File",true);
+ win->set_click_handler(make_mfp(object,event_handler));
+ }
+
+ inline void open_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Open File",true);
+ win->set_click_handler(event_handler);
+ }
+
+ template <
+ typename T
+ >
+ void open_existing_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Open File");
+ win->set_click_handler(make_mfp(object,event_handler));
+ }
+
+ inline void open_existing_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Open File");
+ win->set_click_handler(event_handler);
+ }
+
+ template <
+ typename T
+ >
+ void save_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Save File",true);
+ win->set_click_handler(make_mfp(object,event_handler));
+ }
+
+ inline void save_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ )
+ {
+ using namespace open_file_box_helper;
+ box_win* win = new box_win("Save File",true);
+ win->set_click_handler(event_handler);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class menu_bar
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class menu_bar : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - menus.size() == 0
+ - open_menu == 0
+
+ CONVENTION
+ - size() == menus.size()
+ - all menu data is stored in menus
+ - menus[x].name == the name of the xth menu
+ - if (menus[x].underline_pos != std::string::npos) then
+ - menus[x].underline_pos == the position of the character in the
+ menu name that should be underlined
+ - menus[x].underline_p1 != menus[x].underline_p2
+ and these two points define the underline bar
+ - else
+ - menus[x].underline_p1 == menus[x].underline_p2
+ - menus[x].menu == menu(x)
+ - menus[x].rect == the rectangle in which menus[x].name is drawn
+ - menus[x].bgrect == the rectangle for the xth menu button
+
+ - if (there is an open menu on the screen) then
+ - open_menu == the index of the open menu from menus
+ - else
+ - open_menu == menus.size()
+ !*/
+
+ public:
+ menu_bar(
+ drawable_window& w
+ );
+
+ ~menu_bar();
+
+ // this function does nothing
+ void set_pos(long,long){}
+
+ void set_main_font (
+ const std::shared_ptr<font>& f
+ );
+
+ void set_number_of_menus (
+ unsigned long num
+ );
+
+ unsigned long number_of_menus (
+ ) const;
+
+ void set_menu_name (
+ unsigned long idx,
+ const std::string name,
+ char underline_ch = '\0'
+ );
+
+ void set_menu_name (
+ unsigned long idx,
+ const std::wstring name,
+ char underline_ch = '\0'
+ );
+
+ void set_menu_name (
+ unsigned long idx,
+ const dlib::ustring name,
+ char underline_ch = '\0'
+ );
+
+ const std::string menu_name (
+ unsigned long idx
+ ) const;
+
+ const std::wstring menu_wname (
+ unsigned long idx
+ ) const;
+
+ const dlib::ustring menu_uname (
+ unsigned long idx
+ ) const;
+
+ popup_menu& menu (
+ unsigned long idx
+ );
+
+ const popup_menu& menu (
+ unsigned long idx
+ ) const;
+
+ protected:
+
+ void on_window_resized (
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ void on_window_moved (
+ );
+
+ void on_focus_lost (
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long ,
+ long x,
+ long y,
+ bool
+ );
+
+ void on_mouse_move (
+ unsigned long ,
+ long x,
+ long y
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ private:
+
+ void show_menu (
+ unsigned long i
+ );
+
+ void hide_menu (
+ );
+
+ void on_popup_hide (
+ );
+
+ void compute_menu_geometry (
+ );
+
+ void adjust_position (
+ );
+
+ struct menu_data
+ {
+ menu_data():underline_pos(dlib::ustring::npos){}
+
+ dlib::ustring name;
+ dlib::ustring::size_type underline_pos;
+ popup_menu menu;
+ rectangle rect;
+ rectangle bgrect;
+ point underline_p1;
+ point underline_p2;
+ };
+
+ array<menu_data> menus;
+ unsigned long open_menu;
+
+ // restricted functions
+ menu_bar(menu_bar&); // copy constructor
+ menu_bar& operator=(menu_bar&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class directed_graph_drawer
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename graph_type>
+ class directed_graph_drawer : public zoomable_region
+ {
+ /*!
+ INITIAL VALUE
+ - edge_selected == false
+ - mouse_drag == false
+ - selected_node == 0
+ - graph_.number_of_nodes() == 0
+ - external_graph.number_of_nodes() == 0
+ - radius == 25
+ - last_mouse_click_in_display == false
+
+ CONVENTION
+ - radius == the radius of the nodes when they aren't zoomed
+ - external_graph and graph_ have the same graph structure
+ - external_graph == graph()
+ - external_graph.node(i) == graph_node(i)
+
+ - if (one of the nodes is selected) then
+ - selected_node < graph_.number_of_nodes()
+ - graph_.node(selected_node) == the selected node
+ - else
+ - selected_node == graph_.number_of_nodes()
+
+ - if (the user is dragging a node with the mouse) then
+ - mouse_drag == true
+ - drag_offset == the vector from the mouse position to the
+ center of the node
+ - else
+ - mouse_drag == false
+
+ - if (the user has selected an edge) then
+ - edge_selected == true
+ - the parent node is graph_.node(selected_edge_parent)
+ - the child node is graph_.node(selected_edge_parent)
+ - else
+ - edge_selected == false
+
+ - for all valid i:
+ - graph_.node(i).data.p == the center of the node in graph space
+ - graph_.node(i).data.name == node_label(i)
+ - graph_.node(i).data.color == node_color(i)
+ - graph_.node(i).data.str_rect == a rectangle sized to contain graph_.node(i).data.name
+
+ - if (the last mouse click in our parent window as in our display_rect_ ) then
+ - last_mouse_click_in_display == true
+ - else
+ - last_mouse_click_in_display == false
+ !*/
+
+ public:
+ directed_graph_drawer (
+ drawable_window& w
+ ) :
+ zoomable_region(w,MOUSE_CLICK | MOUSE_WHEEL | KEYBOARD_EVENTS),
+ radius(25),
+ edge_selected(false),
+ last_mouse_click_in_display(false)
+ {
+ mouse_drag = false;
+ selected_node = 0;
+
+ // Whenever you make your own drawable (or inherit from draggable or button_action)
+ // you have to remember to call this function to enable the events. The idea
+ // here is that you can perform whatever setup you need to do to get your
+ // object into a valid state without needing to worry about event handlers
+ // triggering before you are ready.
+ enable_events();
+ }
+
+ ~directed_graph_drawer (
+ )
+ {
+ // Disable all further events for this drawable object. We have to do this
+ // because we don't want draw() events coming to this object while or after
+ // it has been destructed.
+ disable_events();
+
+ // Tell the parent window to redraw its area that previously contained this
+ // drawable object.
+ parent.invalidate_rectangle(rect);
+ }
+
+ void clear_graph (
+ )
+ {
+ auto_mutex M(m);
+ graph_.clear();
+ external_graph.clear();
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ const typename graph_type::node_type& graph_node (
+ unsigned long i
+ ) const
+ {
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tgraph_type::node_type& directed_graph_drawer::graph_node(i)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ return external_graph.node(i);
+ }
+
+ typename graph_type::node_type& graph_node (
+ unsigned long i
+ )
+ {
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tgraph_type::node_type& directed_graph_drawer::graph_node(i)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ return external_graph.node(i);
+ }
+
+ const graph_type& graph (
+ ) const
+ {
+ return external_graph;
+ }
+
+ void save_graph (
+ std::ostream& out
+ )
+ {
+ auto_mutex M(m);
+ serialize(external_graph, out);
+ serialize(graph_, out);
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ void load_graph (
+ std::istream& in
+ )
+ {
+ auto_mutex M(m);
+ deserialize(external_graph, in);
+ deserialize(graph_, in);
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ auto_mutex M(m);
+ return graph_.number_of_nodes();
+ }
+
+ void set_node_label (
+ unsigned long i,
+ const std::string& label
+ )
+ {
+ set_node_label(i, convert_mbstring_to_wstring(label));
+ }
+
+ void set_node_label (
+ unsigned long i,
+ const std::wstring& label
+ )
+ {
+ set_node_label(i, convert_wstring_to_utf32(label));
+ }
+
+ void set_node_label (
+ unsigned long i,
+ const dlib::ustring& label
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tvoid directed_graph_drawer::set_node_label(i,label)"
+ << "\n\ti: " << i
+ << "\n\tlabel: " << narrow(label)
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ graph_.node(i).data.name = label.c_str();
+ unsigned long width, height;
+ mfont->compute_size(label,width,height);
+ graph_.node(i).data.str_rect = rectangle(width,height);
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ void set_node_color (
+ unsigned long i,
+ rgb_pixel color
+ )
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tvoid directed_graph_drawer::set_node_color(i,label)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ graph_.node(i).data.color = color;
+ parent.invalidate_rectangle(display_rect());
+ }
+
+ rgb_pixel node_color (
+ unsigned long i
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\trgb_pixel directed_graph_drawer::node_color(i)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ return graph_.node(i).data.color;
+ }
+
+ const std::string node_label (
+ unsigned long i
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tconst std::ustring directed_graph_drawer::node_label(i)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ return narrow(graph_.node(i).data.name);
+ }
+
+ const std::wstring node_wlabel (
+ unsigned long i
+ ) const
+ {
+ return convert_utf32_to_wstring(node_ulabel(i));
+ }
+
+ const dlib::ustring node_ulabel (
+ unsigned long i
+ ) const
+ {
+ auto_mutex M(m);
+ DLIB_ASSERT ( i < number_of_nodes() ,
+ "\tconst std::ustring directed_graph_drawer::node_label(i)"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_nodes(): " << number_of_nodes()
+ );
+ return graph_.node(i).data.name.c_str();
+ }
+
+ template <
+ typename T
+ >
+ void set_node_selected_handler (
+ T& object,
+ void (T::*event_handler_)(unsigned long)
+ )
+ {
+ auto_mutex M(m);
+ node_selected_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_node_selected_handler (
+ const any_function<void(unsigned long)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ node_selected_handler = event_handler_;
+ }
+
+ template <
+ typename T
+ >
+ void set_node_deselected_handler (
+ T& object,
+ void (T::*event_handler_)(unsigned long)
+ )
+ {
+ auto_mutex M(m);
+ node_deselected_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_node_deselected_handler (
+ const any_function<void(unsigned long)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ node_deselected_handler = event_handler_;
+ }
+
+ template <
+ typename T
+ >
+ void set_node_deleted_handler (
+ T& object,
+ void (T::*event_handler_)()
+ )
+ {
+ auto_mutex M(m);
+ node_deleted_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_node_deleted_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ node_deleted_handler = event_handler_;
+ }
+
+ template <
+ typename T
+ >
+ void set_graph_modified_handler (
+ T& object,
+ void (T::*event_handler_)()
+ )
+ {
+ auto_mutex M(m);
+ graph_modified_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_graph_modified_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ graph_modified_handler = event_handler_;
+ }
+
+ protected:
+
+ void on_keydown (
+ unsigned long key,
+ bool ,
+ unsigned long
+ )
+ {
+ // ignore all keyboard input if the last thing the user clicked on
+ // wasn't the display area
+ if (last_mouse_click_in_display == false)
+ return;
+
+ // if a node is selected
+ if (selected_node != graph_.number_of_nodes())
+ {
+ // deselect the node if the user hits escape
+ if (key == base_window::KEY_ESC)
+ {
+ parent.invalidate_rectangle(display_rect());
+ if (node_deselected_handler.is_set())
+ node_deselected_handler(selected_node);
+ selected_node = graph_.number_of_nodes();
+ }
+
+ // delete the node if the user hits delete
+ if (key == base_window::KEY_DELETE || key == base_window::KEY_BACKSPACE)
+ {
+ parent.invalidate_rectangle(display_rect());
+ graph_.remove_node(selected_node);
+ external_graph.remove_node(selected_node);
+ selected_node = graph_.number_of_nodes();
+ mouse_drag = false;
+ if (graph_modified_handler.is_set())
+ graph_modified_handler();
+ if (node_deleted_handler.is_set())
+ node_deleted_handler();
+ }
+ }
+
+ // if an edge is selected
+ if (edge_selected)
+ {
+ // deselect the node if the user hits escape
+ if (key == base_window::KEY_ESC)
+ {
+ parent.invalidate_rectangle(display_rect());
+ edge_selected = false;
+ }
+
+ // delete the node if the user hits delete
+ if (key == base_window::KEY_DELETE || key == base_window::KEY_BACKSPACE)
+ {
+ parent.invalidate_rectangle(display_rect());
+ graph_.remove_edge(selected_edge_parent, selected_edge_child);
+ external_graph.remove_edge(selected_edge_parent, selected_edge_child);
+ edge_selected = false;
+
+ if (graph_modified_handler.is_set())
+ graph_modified_handler();
+ }
+ }
+ }
+
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ if (mouse_drag)
+ {
+ const point p(nearest_point(display_rect(),point(x,y)));
+
+ point center = drag_offset + p;
+ graph_.node(selected_node).data.p = gui_to_graph_space(center);
+ parent.invalidate_rectangle(display_rect());
+ }
+ else
+ {
+ zoomable_region::on_mouse_move(state,x,y);
+ }
+
+ // check if the mouse isn't being dragged anymore
+ if ((state & base_window::LEFT) == 0)
+ {
+ mouse_drag = false;
+ }
+ }
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ )
+ {
+ mouse_drag = false;
+ zoomable_region::on_mouse_up(btn,state,x,y);
+ }
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ )
+ {
+ bool redraw = false;
+
+ if (display_rect().contains(x,y) &&
+ (btn == base_window::RIGHT || btn == base_window::LEFT) &&
+ (state & base_window::SHIFT) == 0 )
+ {
+ // start out saying no edge is selected
+ if (edge_selected)
+ {
+ edge_selected = false;
+ redraw = true;
+ }
+
+ bool click_hit_node = false;
+ dlib::vector<double,2> p(gui_to_graph_space(point(x,y)));
+ // check if this click is on an existing node
+ for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i)
+ {
+ dlib::vector<double,2> n(graph_.node(i).data.p);
+ if ((p-n).length() < radius)
+ {
+ click_hit_node = true;
+ point center = graph_to_gui_space(graph_.node(i).data.p);
+ mouse_drag = true;
+ drag_offset = center - point(x,y);
+
+ // only do something if the click isn't on the currently
+ // selected node
+ if (selected_node != i)
+ {
+ // send out the deselected event if appropriate
+ if (selected_node != graph_.number_of_nodes() && node_deselected_handler.is_set())
+ node_deselected_handler(selected_node);
+
+ selected_node = i;
+ redraw = true;
+ if (node_selected_handler.is_set())
+ node_selected_handler(selected_node);
+ }
+ break;
+ }
+ }
+
+ // if the click didn't hit any node then make sure nothing is selected
+ if (click_hit_node == false && selected_node != graph_.number_of_nodes())
+ {
+ if (node_deselected_handler.is_set())
+ node_deselected_handler(selected_node);
+ selected_node = graph_.number_of_nodes();
+ redraw = true;
+ }
+
+
+ // check if this click is on an edge if we didn't click on a node
+ if (click_hit_node == false)
+ {
+ for (unsigned long n = 0; n < graph_.number_of_nodes() && edge_selected == false; ++n)
+ {
+ const dlib::vector<double,2> parent_center(graph_to_gui_space(graph_.node(n).data.p));
+ for (unsigned long e = 0; e < graph_.node(n).number_of_children() && edge_selected == false; ++e)
+ {
+ const dlib::vector<double,2> child_center(graph_to_gui_space(graph_.node(n).child(e).data.p));
+
+ rectangle area;
+ area += parent_center;
+ area += child_center;
+ // if the point(x,y) is between the two nodes then lets consider it further
+ if (area.contains(point(x,y)))
+ {
+ p = point(x,y);
+ const dlib::vector<double> z(0,0,1);
+ // find the distance from the line between the two nodes
+ const dlib::vector<double,2> perpendicular(z.cross(parent_center-child_center).normalize());
+ double distance = std::abs((child_center-p).dot(perpendicular));
+ if (distance < 8)
+ {
+ edge_selected = true;
+ selected_edge_parent = n;
+ selected_edge_child = graph_.node(n).child(e).index();
+ redraw = true;
+ }
+ }
+ }
+ }
+ }
+
+
+ // if the click didn't land on any node then add a new one if this was
+ // a right mouse button click
+ if (click_hit_node == false && btn == base_window::RIGHT)
+ {
+ const unsigned long n = graph_.add_node();
+ external_graph.add_node();
+
+ graph_.node(n).data.p = gui_to_graph_space(point(x,y));
+
+ redraw = true;
+ selected_node = n;
+ mouse_drag = false;
+ if (graph_modified_handler.is_set())
+ graph_modified_handler();
+
+ if (node_selected_handler.is_set())
+ node_selected_handler(selected_node);
+
+ }
+ else if (selected_node == graph_.number_of_nodes())
+ {
+ // in this case the click landed in the white area between nodes
+ zoomable_region::on_mouse_down( btn, state, x, y, is_double_click);
+ }
+ }
+
+ // If the user is shift clicking with the mouse then see if we
+ // should add a new edge.
+ if (display_rect().contains(x,y) &&
+ btn == base_window::LEFT &&
+ (state & base_window::SHIFT) &&
+ selected_node != graph_.number_of_nodes() )
+ {
+ dlib::vector<double,2> p(gui_to_graph_space(point(x,y)));
+ // check if this click is on an existing node
+ for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i)
+ {
+ dlib::vector<double,2> n(graph_.node(i).data.p);
+ if ((p-n).length() < radius)
+ {
+ // add the edge if it doesn't already exist and isn't an edge back to
+ // the same node
+ if (graph_.has_edge(selected_node,i) == false && selected_node != i &&
+ graph_.has_edge(i, selected_node) == false)
+ {
+ graph_.add_edge(selected_node,i);
+ external_graph.add_edge(selected_node,i);
+ redraw = true;
+
+ if (graph_modified_handler.is_set())
+ graph_modified_handler();
+ }
+ break;
+ }
+ }
+ }
+
+
+ if (redraw)
+ parent.invalidate_rectangle(display_rect());
+
+
+ if (display_rect().contains(x,y) == false)
+ last_mouse_click_in_display = false;
+ else
+ last_mouse_click_in_display = true;
+ }
+
+ void draw (
+ const canvas& c
+ ) const
+ {
+ zoomable_region::draw(c);
+
+ rectangle area = c.intersect(display_rect());
+ if (area.is_empty() == true)
+ return;
+
+
+ if (enabled)
+ fill_rect(c,display_rect(),255);
+ else
+ fill_rect(c,display_rect(),128);
+
+
+ const unsigned long rad = static_cast<unsigned long>(radius*zoom_scale());
+ point center;
+
+
+ // first draw all the edges
+ for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i)
+ {
+ center = graph_to_gui_space(graph_.node(i).data.p);
+ const rectangle circle_area(centered_rect(center,2*(rad+8),2*(rad+8)));
+
+ // draw lines to all this node's parents
+ const dlib::vector<double> z(0,0,1);
+ for (unsigned long j = 0; j < graph_.node(i).number_of_parents(); ++j)
+ {
+ point p(graph_to_gui_space(graph_.node(i).parent(j).data.p));
+
+ rgb_pixel color(0,0,0);
+ // if this is the selected edge then draw it with red instead of black
+ if (edge_selected && selected_edge_child == i && selected_edge_parent == graph_.node(i).parent(j).index())
+ {
+ color.red = 255;
+ // we need to be careful when drawing this line to not draw it over the node dots since it
+ // has a different color from them and would look weird
+ dlib::vector<double,2> v(p-center);
+ v = v.normalize()*rad;
+ draw_line(c,center+v,p-v ,color, area);
+ }
+ else
+ {
+ draw_line(c,center,p ,color, area);
+ }
+
+
+ // draw the triangle pointing to this node
+ if (area.intersect(circle_area).is_empty() == false)
+ {
+ dlib::vector<double,2> v(p-center);
+ v = v.normalize();
+
+ dlib::vector<double,2> cross = z.cross(v).normalize();
+ dlib::vector<double,2> r(center + v*rad);
+ for (double i = 0; i < 8*zoom_scale(); i += 0.1)
+ draw_line(c,(r+v*i)+cross*i, (r+v*i)-cross*i,color,area);
+ }
+ }
+ }
+
+
+ // now draw all the node dots
+ for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i)
+ {
+ center = graph_to_gui_space(graph_.node(i).data.p);
+ const rectangle circle_area(centered_rect(center,2*(rad+8),2*(rad+8)));
+
+ // draw the actual dot for this node
+ if (area.intersect(circle_area).is_empty()==false)
+ {
+ rgb_alpha_pixel color;
+ assign_pixel(color, graph_.node(i).data.color);
+ // this node is in area so lets draw it and all of it's edges as well
+ draw_solid_circle(c,center,rad-3,color,area);
+ color.alpha = 240;
+ draw_circle(c,center,rad-3,color,area);
+ color.alpha = 200;
+ draw_circle(c,center,rad-2.5,color,area);
+ color.alpha = 160;
+ draw_circle(c,center,rad-2.0,color,area);
+ color.alpha = 120;
+ draw_circle(c,center,rad-1.5,color,area);
+ color.alpha = 80;
+ draw_circle(c,center,rad-1.0,color,area);
+ color.alpha = 40;
+ draw_circle(c,center,rad-0.5,color,area);
+
+ }
+
+
+ if (i == selected_node)
+ draw_circle(c,center,rad+5,rgb_pixel(0,0,255),area);
+ }
+
+
+ // now draw all the strings last
+ for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i)
+ {
+ center = graph_to_gui_space(graph_.node(i).data.p);
+ rectangle circle_area(centered_rect(center,2*rad+3,2*rad+3));
+ if (area.intersect(circle_area).is_empty()==false)
+ {
+ rgb_pixel color = graph_.node(i).data.color;
+ // invert this color
+ color.red = 255-color.red;
+ color.green = 255-color.green;
+ color.blue = 255-color.blue;
+ sout << i;
+ unsigned long width, height;
+ mfont->compute_size(sout.str(),width,height);
+ rectangle str_rect(centered_rect(center, width,height));
+ if (circle_area.contains(str_rect))
+ {
+ mfont->draw_string(c,str_rect,sout.str(),color,0,std::string::npos,area);
+
+ // draw the label for this node if it isn't empty
+ if(graph_.node(i).data.name.size() > 0)
+ {
+ rectangle str_rect(graph_.node(i).data.str_rect);
+ str_rect = centered_rect(center.x(), center.y()-rad-mfont->height(), str_rect.width(), str_rect.height());
+ mfont->draw_string(c,str_rect,graph_.node(i).data.name,0,0,std::string::npos,area);
+ }
+ }
+ sout.str("");
+ }
+ }
+ }
+
+ private:
+
+ struct data
+ {
+ data() : color(0,0,0) {}
+ vector<double> p;
+ dlib::ustring name;
+ rectangle str_rect;
+ rgb_pixel color;
+ };
+
+ friend void serialize(const data& item, std::ostream& out)
+ {
+ serialize(item.p, out);
+ serialize(item.name, out);
+ serialize(item.str_rect, out);
+ serialize(item.color, out);
+ }
+
+ friend void deserialize(data& item, std::istream& in)
+ {
+ deserialize(item.p, in);
+ deserialize(item.name, in);
+ deserialize(item.str_rect, in);
+ deserialize(item.color, in);
+ }
+
+ mutable std::ostringstream sout;
+
+ const double radius;
+ unsigned long selected_node;
+ bool mouse_drag; // true if the user is dragging a node
+ point drag_offset;
+
+ bool edge_selected;
+ unsigned long selected_edge_parent;
+ unsigned long selected_edge_child;
+
+ any_function<void(unsigned long)> node_selected_handler;
+ any_function<void(unsigned long)> node_deselected_handler;
+ any_function<void()> node_deleted_handler;
+ any_function<void()> graph_modified_handler;
+
+ graph_type external_graph;
+ // rebind the graph_ type to make us a graph_ of data structs
+ typename graph_type::template rebind<data,char, typename graph_type::mem_manager_type>::other graph_;
+
+ bool last_mouse_click_in_display;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class text_grid
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_grid : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ - has_focus == false
+ - vertical_scroll_increment() == 10
+ - horizontal_scroll_increment() == 10
+ - border_color_ == rgb_pixel(128,128,128)
+
+ CONVENTION
+ - grid.nr() == row_height.size()
+ - grid.nc() == col_width.size()
+ - border_color() == border_color_
+ - text(r,c) == grid[r][c].text
+ - text_color(r,c) == grid[r][c].text_color
+ - background_color(r,c) == grid[r][c].bg_color
+
+ - if (the user has clicked on this widget and caused one of the
+ boxes to have input focus) then
+ - has_focus == true
+ - grid[active_row][active_col] == the active text box
+ - cursor_pos == the position of the cursor in the above box
+ - if (the cursor should be displayed) then
+ - show_cursor == true
+ - else
+ - show_cursor == false
+ - else
+ - has_focus == false
+ !*/
+
+ public:
+ text_grid (
+ drawable_window& w
+ );
+
+ ~text_grid (
+ );
+
+ void set_grid_size (
+ unsigned long rows,
+ unsigned long cols
+ );
+
+ unsigned long number_of_columns (
+ ) const;
+
+ unsigned long number_of_rows (
+ ) const;
+
+ int next_free_user_event_number (
+ ) const;
+
+ rgb_pixel border_color (
+ ) const;
+
+ void set_border_color (
+ rgb_pixel color
+ );
+
+ const std::string text (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ const std::wstring wtext (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ const dlib::ustring utext (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ void set_text (
+ unsigned long row,
+ unsigned long col,
+ const std::string& str
+ );
+
+ void set_text (
+ unsigned long row,
+ unsigned long col,
+ const std::wstring& str
+ );
+
+ void set_text (
+ unsigned long row,
+ unsigned long col,
+ const dlib::ustring& str
+ );
+
+ const rgb_pixel text_color (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ void set_text_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ );
+
+ const rgb_pixel background_color (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ void set_background_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ );
+
+ bool is_editable (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ void set_editable (
+ unsigned long row,
+ unsigned long col,
+ bool editable
+ );
+
+ void set_column_width (
+ unsigned long col,
+ unsigned long width
+ );
+
+ void set_row_height (
+ unsigned long row,
+ unsigned long height
+ );
+
+ void disable (
+ );
+
+ void hide (
+ );
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*eh)(unsigned long, unsigned long)
+ ) { text_modified_handler = make_mfp(object,eh); }
+
+ void set_text_modified_handler (
+ const any_function<void(unsigned long, unsigned long)>& eh
+ ) { text_modified_handler = eh; }
+
+ private:
+
+ void on_user_event (
+ int num
+ );
+
+ void timer_action (
+ );
+ /*!
+ ensures
+ - flips the state of show_cursor
+ !*/
+
+ void compute_bg_rects (
+ );
+
+ void compute_total_rect (
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_focus_lost (
+ );
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ rectangle get_text_rect (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ rectangle get_bg_rect (
+ unsigned long row,
+ unsigned long col
+ ) const;
+
+ struct data_type
+ {
+ data_type(): text_color(0,0,0), bg_color(255,255,255),
+ first(0), is_editable(true)
+ {}
+
+ dlib::ustring text;
+ rgb_pixel text_color;
+ rgb_pixel bg_color;
+ rectangle bg_rect;
+ dlib::ustring::size_type first;
+ bool is_editable;
+ };
+
+ void drop_input_focus (
+ );
+
+ void move_cursor (
+ long row,
+ long col,
+ long new_cursor_pos
+ );
+
+ array2d<data_type> grid;
+ array<unsigned long> col_width;
+ array<unsigned long> row_height;
+ bool has_focus;
+ long active_col;
+ long active_row;
+ long cursor_pos;
+ bool show_cursor;
+ bool recent_cursor_move;
+ timer<text_grid> cursor_timer;
+ rgb_pixel border_color_;
+ any_function<void(unsigned long, unsigned long)> text_modified_handler;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class image_display : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ - img.size() == 0
+ - overlay_rects.size() == 0
+ - overlay_lines.size() == 0
+ - drawing_rect == false
+ - rect_is_selected == false
+
+ CONVENTION
+ - img == the image this object displays
+ - overlay_rects == the overlay rectangles this object displays
+ - overlay_lines == the overlay lines this object displays
+
+ - if (drawing_rect) then
+ - the user is drawing a rectangle on the screen and is
+ thus holding down CTRL and the left mouse button.
+ - rect_anchor == the point on the screen where the user
+ clicked to begin drawing the rectangle.
+ - rect_to_draw == the rectangle which should appear on the screen.
+
+ - if (rect_is_selected) then
+ - selected_rect == the index in overlay_rects of the user selected
+ rectangle.
+ - last_right_click_pos == the last place we saw the user right click
+ the mouse.
+ - parts_menu.is_enabled() == true
+ - if (it is actually a part of this rect that is selected) then
+ - selected_part_name == the name of the part in overlay_rects[selected_rect].parts
+ that is selected.
+ - else
+ - selected_part_name.size() == 0
+ - else
+ - parts_menu.is_enabled() == false
+ - selected_part_name.size() == 0
+
+ - if (moving_overlay) then
+ - moving_rect == the index in overlay_rects that the move applies to.
+ - if (moving_what == MOVING_PART) then
+ - moving_part_name == the name of the part in
+ overlay_rects[moving_rect] that is being moved around with the
+ mouse.
+ - else
+ - moving_what will tell us which side of the rectangle in
+ overlay_rects[moving_rect] is being moved by the mouse.
+ !*/
+
+ public:
+
+ image_display(
+ drawable_window& w
+ );
+
+ ~image_display(
+ );
+
+ template <
+ typename image_type
+ >
+ void set_image (
+ const image_type& new_img
+ )
+ {
+ auto_mutex M(m);
+
+ // if the new image has a different size when compared to the previous image
+ // then we should readjust the total rectangle size.
+ if (num_rows(new_img) != img.nr() || num_columns(new_img) != img.nc())
+ {
+ if (zoom_in_scale != 1)
+ set_total_rect_size(num_columns(new_img)*zoom_in_scale, num_rows(new_img)*zoom_in_scale);
+ else
+ set_total_rect_size(num_columns(new_img)/zoom_out_scale, num_rows(new_img)/zoom_out_scale);
+ }
+ else
+ {
+ parent.invalidate_rectangle(rect);
+ }
+
+ highlighted_rect = std::numeric_limits<unsigned long>::max();
+ rect_is_selected = false;
+ parts_menu.disable();
+ assign_image_scaled(img,new_img);
+ }
+
+ virtual void set_pos (
+ long x,
+ long y
+ )
+ {
+ auto_mutex lock(m);
+ scrollable_region::set_pos(x,y);
+ parts_menu.set_rect(rect);
+ }
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ auto_mutex lock(m);
+ scrollable_region::set_size(width,height);
+ parts_menu.set_rect(rect);
+ }
+
+ struct overlay_rect
+ {
+ overlay_rect() :crossed_out(false) { assign_pixel(color, 0);}
+
+ template <typename pixel_type>
+ overlay_rect(const rectangle& r, pixel_type p)
+ : rect(r),crossed_out(false) { assign_pixel(color, p); }
+
+ template <typename pixel_type>
+ overlay_rect(const rectangle& r, pixel_type p, const std::string& l)
+ : rect(r),label(l),crossed_out(false) { assign_pixel(color, p); }
+
+ template <typename pixel_type>
+ overlay_rect(const rectangle& r, pixel_type p, const std::string& l, const std::map<std::string,point>& parts_)
+ : rect(r),label(l),parts(parts_),crossed_out(false) { assign_pixel(color, p); }
+
+ rectangle rect;
+ rgb_alpha_pixel color;
+ std::string label;
+ std::map<std::string,point> parts;
+ bool crossed_out;
+ };
+
+ struct overlay_line
+ {
+ overlay_line() { assign_pixel(color, 0);}
+
+ template <typename pixel_type>
+ overlay_line(const point& p1_, const point& p2_, pixel_type p)
+ : p1(p1_), p2(p2_) { assign_pixel(color, p); }
+
+ point p1;
+ point p2;
+ rgb_alpha_pixel color;
+ };
+
+ struct overlay_circle
+ {
+ overlay_circle():radius(0) { assign_pixel(color, 0);}
+
+ template <typename pixel_type>
+ overlay_circle(const point& center_, const int radius_, pixel_type p)
+ : center(center_), radius(radius_) { assign_pixel(color, p); }
+
+ template <typename pixel_type>
+ overlay_circle(const point& center_, const int radius_, pixel_type p, const std::string& l)
+ : center(center_), radius(radius_), label(l) { assign_pixel(color, p); }
+
+ point center;
+ int radius;
+ rgb_alpha_pixel color;
+ std::string label;
+ };
+
+ void add_overlay (
+ const overlay_rect& overlay
+ );
+
+ void add_overlay (
+ const overlay_line& overlay
+ );
+
+ void add_overlay (
+ const overlay_circle& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_rect>& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_circle>& overlay
+ );
+
+ void clear_overlay (
+ );
+
+ rectangle get_image_display_rect (
+ ) const;
+
+ std::vector<overlay_rect> get_overlay_rects (
+ ) const;
+
+ void set_default_overlay_rect_label (
+ const std::string& label
+ );
+
+ std::string get_default_overlay_rect_label (
+ ) const;
+
+ void set_default_overlay_rect_color (
+ const rgb_alpha_pixel& color
+ );
+
+ rgb_alpha_pixel get_default_overlay_rect_color (
+ ) const;
+
+ template <
+ typename T
+ >
+ void set_overlay_rects_changed_handler (
+ T& object,
+ void (T::*event_handler_)()
+ )
+ {
+ auto_mutex M(m);
+ event_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_overlay_rects_changed_handler (
+ const any_function<void()>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ event_handler = event_handler_;
+ }
+
+ template <
+ typename T
+ >
+ void set_overlay_rect_selected_handler (
+ T& object,
+ void (T::*event_handler_)(const overlay_rect& orect)
+ )
+ {
+ auto_mutex M(m);
+ orect_selected_event_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_overlay_rect_selected_handler (
+ const any_function<void(const overlay_rect& orect)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ orect_selected_event_handler = event_handler_;
+ }
+
+ template <
+ typename T
+ >
+ void set_image_clicked_handler (
+ T& object,
+ void (T::*event_handler_)(const point& p, bool is_double_click, unsigned long btn)
+ )
+ {
+ auto_mutex M(m);
+ image_clicked_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_image_clicked_handler (
+ const any_function<void(const point& p, bool is_double_click, unsigned long btn)>& event_handler_
+ )
+ {
+ auto_mutex M(m);
+ image_clicked_handler = event_handler_;
+ }
+
+ void add_labelable_part_name (
+ const std::string& name
+ );
+
+ void clear_labelable_part_names (
+ );
+
+ void enable_overlay_editing (
+ ) { auto_mutex M(m); overlay_editing_enabled = true; }
+
+ void disable_overlay_editing (
+ )
+ {
+ auto_mutex M(m);
+ overlay_editing_enabled = false;
+ rect_is_selected = false;
+ drawing_rect = false;
+ parent.invalidate_rectangle(rect);
+ }
+
+ bool overlay_editing_is_enabled (
+ ) const { auto_mutex M(m); return overlay_editing_enabled; }
+
+ private:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ void on_wheel_up (
+ unsigned long state
+ );
+
+ void on_wheel_down (
+ unsigned long state
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_up (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ void on_part_add (
+ const std::string& part_name
+ );
+
+ rectangle get_rect_on_screen (
+ unsigned long idx
+ ) const;
+
+ rectangle get_rect_on_screen (
+ rectangle orect
+ ) const;
+
+ rgb_alpha_pixel invert_pixel (const rgb_alpha_pixel& p) const
+ { return rgb_alpha_pixel(255-p.red, 255-p.green, 255-p.blue, p.alpha); }
+
+ virtual int next_free_user_event_number (
+ ) const { return scrollable_region::next_free_user_event_number()+1; }
+ // The reason for using user actions here rather than just having the timer just call
+ // what it needs directly is to avoid a potential deadlock during destruction of this widget.
+ void timer_event_unhighlight_rect()
+ {
+ highlight_timer.stop();
+ parent.trigger_user_event(this,scrollable_region::next_free_user_event_number());
+ }
+ void on_user_event (int num)
+ {
+ // ignore this user event if it isn't for us
+ if (num != scrollable_region::next_free_user_event_number())
+ return;
+ if (highlighted_rect < overlay_rects.size())
+ {
+ highlighted_rect = std::numeric_limits<unsigned long>::max();
+ parent.invalidate_rectangle(rect);
+ }
+ }
+
+
+ array2d<rgb_alpha_pixel> img;
+
+
+ std::vector<overlay_rect> overlay_rects;
+ std::vector<overlay_line> overlay_lines;
+ std::vector<overlay_circle> overlay_circles;
+
+ long zoom_in_scale;
+ long zoom_out_scale;
+ bool drawing_rect;
+ point rect_anchor;
+ rectangle rect_to_draw;
+ bool rect_is_selected;
+ std::string selected_part_name;
+ unsigned long selected_rect;
+ rgb_alpha_pixel default_rect_color;
+ std::string default_rect_label;
+ any_function<void()> event_handler;
+ any_function<void(const overlay_rect& orect)> orect_selected_event_handler;
+ any_function<void(const point& p, bool is_double_click, unsigned long btn)> image_clicked_handler;
+ popup_menu_region parts_menu;
+ point last_right_click_pos;
+ const double part_width;
+ std::set<std::string> part_names;
+ bool overlay_editing_enabled;
+ timer<image_display> highlight_timer;
+ unsigned long highlighted_rect;
+ bool holding_shift_key;
+
+ bool moving_overlay;
+ unsigned long moving_rect;
+ enum {
+ MOVING_RECT_LEFT,
+ MOVING_RECT_TOP,
+ MOVING_RECT_RIGHT,
+ MOVING_RECT_BOTTOM,
+ MOVING_PART
+ } moving_what;
+ std::string moving_part_name;
+
+ // restricted functions
+ image_display(image_display&); // copy constructor
+ image_display& operator=(image_display&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class perspective_display : public drawable, noncopyable
+ {
+ public:
+
+ perspective_display(
+ drawable_window& w
+ );
+
+ ~perspective_display(
+ );
+
+ virtual void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ struct overlay_line
+ {
+ overlay_line() { assign_pixel(color, 0);}
+
+ overlay_line(const vector<double>& p1_, const vector<double>& p2_)
+ : p1(p1_), p2(p2_) { assign_pixel(color, 255); }
+
+ template <typename pixel_type>
+ overlay_line(const vector<double>& p1_, const vector<double>& p2_, pixel_type p)
+ : p1(p1_), p2(p2_) { assign_pixel(color, p); }
+
+ vector<double> p1;
+ vector<double> p2;
+ rgb_pixel color;
+ };
+
+ struct overlay_dot
+ {
+ overlay_dot() { assign_pixel(color, 0);}
+
+ overlay_dot(const vector<double>& p_)
+ : p(p_) { assign_pixel(color, 255); }
+
+ template <typename pixel_type>
+ overlay_dot(const vector<double>& p_, pixel_type color_)
+ : p(p_) { assign_pixel(color, color_); }
+
+ vector<double> p;
+ rgb_pixel color;
+ };
+
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_dot>& overlay
+ );
+
+ void clear_overlay (
+ );
+
+ template <
+ typename T
+ >
+ void set_dot_double_clicked_handler (
+ T& object,
+ void (T::*event_handler_)(const vector<double>&)
+ )
+ {
+ auto_mutex M(m);
+ dot_clicked_event_handler = make_mfp(object,event_handler_);
+ }
+
+ void set_dot_double_clicked_handler (
+ const any_function<void(const vector<double>&)>& event_handler_
+ );
+
+ private:
+
+ void draw (
+ const canvas& c
+ ) const;
+
+ void on_wheel_up (
+ unsigned long state
+ );
+
+ void on_wheel_down (
+ unsigned long state
+ );
+
+ void on_mouse_down (
+ unsigned long btn,
+ unsigned long state,
+ long x,
+ long y,
+ bool is_double_click
+ );
+
+ void on_mouse_move (
+ unsigned long state,
+ long x,
+ long y
+ );
+
+ static bool compare_second (
+ const std::pair<overlay_dot,float>& a,
+ const std::pair<overlay_dot,float>& b
+ ) { return a.second < b.second; }
+
+
+ point last;
+ std::vector<overlay_line> overlay_lines;
+ std::vector<overlay_dot> overlay_dots;
+
+ camera_transform tform;
+ vector<double> sum_pts;
+ vector<double> max_pts;
+ any_function<void(const vector<double>&)> dot_clicked_event_handler;
+ mutable array2d<float> depth;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class perspective_window : public drawable_window, noncopyable
+ {
+ public:
+
+ typedef perspective_display::overlay_line overlay_line;
+ typedef perspective_display::overlay_dot overlay_dot;
+
+ perspective_window(
+ ) : disp(*this)
+ {
+ set_size(100,100);
+ on_window_resized();
+ show();
+ }
+
+ perspective_window(
+ const std::vector<dlib::vector<double> >& point_cloud
+ ) :
+ disp(*this)
+ {
+ set_size(100,100);
+ on_window_resized();
+ add_overlay(point_cloud);
+ show();
+ }
+
+ perspective_window(
+ const std::vector<dlib::vector<double> >& point_cloud,
+ const std::string& title
+ ) :
+ disp(*this)
+ {
+ set_size(100,100);
+ on_window_resized();
+ add_overlay(point_cloud);
+ set_title(title);
+ show();
+ }
+
+ ~perspective_window(
+ )
+ {
+ // You should always call close_window() in the destructor of window
+ // objects to ensure that no events will be sent to this window while
+ // it is being destructed.
+ close_window();
+ }
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ )
+ {
+ disp.add_overlay(overlay);
+ }
+
+ void add_overlay (
+ const std::vector<overlay_dot>& overlay
+ )
+ {
+ disp.add_overlay(overlay);
+ }
+
+ void clear_overlay (
+ )
+ {
+ disp.clear_overlay();
+ }
+
+ template <typename pixel_type>
+ void add_overlay(const vector<double>& p1, const vector<double>& p2, pixel_type p)
+ {
+ add_overlay(std::vector<overlay_line>(1,overlay_line(p1,p2,p)));
+ }
+
+ void add_overlay(const std::vector<dlib::vector<double> >& d)
+ {
+ add_overlay(d, 255);
+ }
+
+ template <typename pixel_type>
+ void add_overlay(const std::vector<dlib::vector<double> >& d, pixel_type p)
+ {
+ std::vector<overlay_dot> temp;
+ temp.resize(d.size());
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ temp[i] = overlay_dot(d[i], p);
+
+ add_overlay(temp);
+ }
+
+ template <
+ typename T
+ >
+ void set_dot_double_clicked_handler (
+ T& object,
+ void (T::*event_handler_)(const vector<double>&)
+ )
+ {
+ disp.set_dot_double_clicked_handler(object,event_handler_);
+ }
+
+ void set_dot_double_clicked_handler (
+ const any_function<void(const vector<double>&)>& event_handler_
+ )
+ {
+ disp.set_dot_double_clicked_handler(event_handler_);
+ }
+
+ private:
+
+ void on_window_resized(
+ )
+ {
+ drawable_window::on_window_resized();
+ unsigned long width, height;
+ get_size(width,height);
+ disp.set_pos(0,0);
+ disp.set_size(width, height);
+ }
+
+ perspective_display disp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class image_window : public drawable_window
+ {
+ public:
+
+ typedef image_display::overlay_rect overlay_rect;
+ typedef image_display::overlay_line overlay_line;
+ typedef image_display::overlay_circle overlay_circle;
+
+ image_window(
+ );
+
+ template < typename image_type >
+ image_window(
+ const image_type& img
+ ) :
+ gui_img(*this),
+ window_has_closed(false),
+ have_last_click(false),
+ mouse_btn(0),
+ clicked_signaler(this->wm),
+ have_last_keypress(false),
+ tie_input_events(false)
+ {
+ gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked);
+ gui_img.disable_overlay_editing();
+ set_image(img);
+ show();
+ }
+
+ template < typename image_type >
+ image_window(
+ const image_type& img,
+ const std::string& title
+ ) :
+ gui_img(*this),
+ window_has_closed(false),
+ have_last_click(false),
+ mouse_btn(0),
+ clicked_signaler(this->wm),
+ have_last_keypress(false),
+ tie_input_events(false)
+ {
+ gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked);
+ gui_img.disable_overlay_editing();
+ set_image(img);
+ set_title(title);
+ show();
+ }
+
+
+ ~image_window(
+ );
+
+ template < typename image_type >
+ void set_image (
+ const image_type& img
+ )
+ {
+ const unsigned long padding = scrollable_region_style_default().get_border_size();
+ auto_mutex M(wm);
+ gui_img.set_image(img);
+
+ // Only ever mess with the size of the window if the user is giving us an image
+ // that is a different size. Otherwise we assume that they will have already
+ // sized the window to whatever they feel is reasonable for an image of the
+ // current size.
+ if (previous_image_size != get_rect(img))
+ {
+ const rectangle r = gui_img.get_image_display_rect();
+ if (image_rect != r)
+ {
+ // set the size of this window to match the size of the input image
+ set_size(r.width()+padding*2,r.height()+padding*2);
+
+ // call this to make sure everything else is setup properly
+ on_window_resized();
+
+ image_rect = r;
+ }
+ previous_image_size = get_rect(img);
+ }
+ }
+
+ void add_overlay (
+ const overlay_rect& overlay
+ );
+
+ template <typename pixel_type>
+ void add_overlay(const rectangle& r, pixel_type p)
+ { add_overlay(image_display::overlay_rect(r,p)); }
+
+ void add_overlay(const rectangle& r)
+ { add_overlay(image_display::overlay_rect(r,rgb_pixel(255,0,0))); }
+
+ template <typename pixel_type>
+ void add_overlay(const rectangle& r, pixel_type p, const std::string& l)
+ { add_overlay(image_display::overlay_rect(r,p,l)); }
+
+ template <typename pixel_type>
+ void add_overlay(const std::vector<rectangle>& r, pixel_type p)
+ {
+ std::vector<overlay_rect> temp;
+ temp.resize(r.size());
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ temp[i] = overlay_rect(r[i], p);
+
+ add_overlay(temp);
+ }
+
+ void add_overlay(const std::vector<rectangle>& r)
+ { add_overlay(r, rgb_pixel(255,0,0)); }
+
+ void add_overlay(
+ const full_object_detection& object,
+ const std::vector<std::string>& part_names
+ )
+ {
+
+ add_overlay(overlay_rect(object.get_rect(), rgb_pixel(255,0,0)));
+
+ std::vector<overlay_circle> temp;
+ temp.reserve(object.num_parts());
+ for (unsigned long i = 0; i < object.num_parts(); ++i)
+ {
+ if (object.part(i) != OBJECT_PART_NOT_PRESENT)
+ {
+ if (i < part_names.size())
+ temp.push_back(overlay_circle(object.part(i), 7, rgb_pixel(0,255,0), part_names[i]));
+ else
+ temp.push_back(overlay_circle(object.part(i), 7, rgb_pixel(0,255,0)));
+ }
+ }
+
+ add_overlay(temp);
+ }
+
+ void add_overlay(
+ const full_object_detection& object
+ )
+ {
+ std::vector<std::string> part_names;
+ add_overlay(object, part_names);
+ }
+
+ void add_overlay(
+ const std::vector<full_object_detection>& objects,
+ const std::vector<std::string>& part_names
+ )
+ {
+ std::vector<overlay_rect> rtemp;
+ rtemp.reserve(objects.size());
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ rtemp.push_back(overlay_rect(objects[i].get_rect(), rgb_pixel(255,0,0)));
+ }
+
+ add_overlay(rtemp);
+
+ std::vector<overlay_circle> temp;
+
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < objects[i].num_parts(); ++j)
+ {
+ if (objects[i].part(j) != OBJECT_PART_NOT_PRESENT)
+ {
+ if (j < part_names.size())
+ temp.push_back(overlay_circle(objects[i].part(j), 7, rgb_pixel(0,255,0),part_names[j]));
+ else
+ temp.push_back(overlay_circle(objects[i].part(j), 7, rgb_pixel(0,255,0)));
+ }
+ }
+ }
+
+ add_overlay(temp);
+ }
+
+ void add_overlay(
+ const std::vector<full_object_detection>& objects
+ )
+ {
+ std::vector<std::string> part_names;
+ add_overlay(objects, part_names);
+ }
+
+ void add_overlay (
+ const overlay_line& overlay
+ );
+
+ void add_overlay (
+ const overlay_circle& overlay
+ );
+
+ template <typename pixel_type>
+ void add_overlay(const point& p1, const point& p2, pixel_type p)
+ { add_overlay(image_display::overlay_line(p1,p2,p)); }
+
+ void add_overlay (
+ const std::vector<overlay_rect>& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+
+ void add_overlay (
+ const std::vector<overlay_circle>& overlay
+ );
+
+ void clear_overlay (
+ );
+
+ bool get_next_double_click (
+ point& p,
+ unsigned long& mouse_button
+ );
+
+ void tie_events (
+ );
+
+ void untie_events (
+ );
+
+ bool events_tied (
+ ) const;
+
+ bool get_next_double_click (
+ point& p
+ )
+ {
+ unsigned long mouse_button;
+ return get_next_double_click(p, mouse_button);
+ }
+
+ bool get_next_keypress (
+ unsigned long& key,
+ bool& is_printable,
+ unsigned long& state
+ );
+
+ bool get_next_keypress (
+ unsigned long& key,
+ bool& is_printable
+ )
+ {
+ unsigned long state;
+ return get_next_keypress(key,is_printable,state);
+ }
+
+ private:
+
+ virtual base_window::on_close_return_code on_window_close(
+ );
+
+ void on_window_resized(
+ );
+
+ void on_image_clicked (
+ const point& p,
+ bool is_double_click,
+ unsigned long btn
+ );
+
+ virtual void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ );
+
+ // restricted functions
+ image_window(image_window&);
+ image_window& operator= (image_window&);
+
+ image_display gui_img;
+ rectangle image_rect;
+ rectangle previous_image_size;
+ bool window_has_closed;
+ bool have_last_click;
+ point last_clicked_point;
+ unsigned long mouse_btn;
+ rsignaler clicked_signaler;
+
+ bool have_last_keypress;
+ unsigned long next_key;
+ bool next_is_printable;
+ unsigned long next_state;
+ bool tie_input_events;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "widgets.cpp"
+#endif
+
+#endif // DLIB_WIDGETs_
+
diff --git a/ml/dlib/dlib/gui_widgets/widgets_abstract.h b/ml/dlib/dlib/gui_widgets/widgets_abstract.h
new file mode 100644
index 000000000..2b4dc4486
--- /dev/null
+++ b/ml/dlib/dlib/gui_widgets/widgets_abstract.h
@@ -0,0 +1,3461 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#undef DLIB_WIDGETs_ABSTRACT_
+#ifdef DLIB_WIDGETs_ABSTRACT_
+
+#include "fonts_abstract.h"
+#include "drawable_abstract.h"
+#include "base_widgets_abstract.h"
+
+#include "../gui_core.h"
+#include <string>
+#include <map>
+#include "../interfaces/enumerable.h"
+#include "style_abstract.h"
+#include "../image_processing/full_object_detection_abstract.h"
+
+namespace dlib
+{
+
+ /*!
+ GENERAL REMARKS
+ This component is a collection of various windowing widgets such as buttons,
+ labels, text boxes, and so on. This component also includes the drawable
+ interface, drawable_window, and font handling objects. The file you are
+ currently viewing defines all the high level graphical widgets which are
+ provided by this component that can appear in a drawable_window. To view
+ the specifications for the other members of this component look at
+ fonts_abstract.h, base_widgets_abstract.h, and drawable_abstract.h
+
+ THREAD SAFETY
+ All objects and functions defined in this file are thread safe. You may
+ call them from any thread without serializing access to them.
+
+ EVENT HANDLERS
+ Note that all event handlers, including the user registered callback
+ functions, are executed in the event handling thread. Additionally,
+ the drawable::m mutex will always be locked while these event handlers
+ are running. Also, don't rely on get_thread_id() always returning the
+ same ID from inside event handlers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function open_file_box(), open_existing_file_box(), and save_file_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void open_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ );
+ /*!
+ requires
+ - event_handler == a valid pointer to a member function of object T.
+ ensures
+ - Displays a window titled "Open File" that will allow the user to select a
+ file.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called on object if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+ void open_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ );
+ /*!
+ ensures
+ - Displays a window titled "Open File" that will allow the user to select a
+ file.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void open_existing_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ );
+ /*!
+ requires
+ - event_handler == a valid pointer to a member function of object T.
+ ensures
+ - Displays a window titled "Open File" that will allow the user to select
+ a file. But only a file that already exists.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called on object if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+ void open_existing_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ );
+ /*!
+ ensures
+ - Displays a window titled "Open File" that will allow the user to select
+ a file. But only a file that already exists.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void save_file_box (
+ T& object,
+ void (T::*event_handler)(const std::string&)
+ );
+ /*!
+ requires
+ - event_handler == a valid pointer to a member function of object T.
+ ensures
+ - Displays a window titled "Save File" that will allow the user to select
+ a file.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called on object if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+ void save_file_box (
+ const any_function<void(const std::string&)>& event_handler
+ );
+ /*!
+ ensures
+ - Displays a window titled "Save File" that will allow the user to select
+ a file.
+ - The displayed window will start out showing the directory get_current_dir()
+ (i.e. it starts in the current working directory)
+ - The event_handler function is called if the user selects
+ a file. If the user closes the window without selecting a file
+ then nothing occurs.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // function message_box()
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void message_box (
+ const std::string& title,
+ const std::string& message
+ );
+ /*!
+ ensures
+ - displays a message box with the given title and message. It will have a
+ single button and when the user clicks it the message box will go away.
+ - this function does not block but instead returns immediately.
+ !*/
+
+ void message_box_blocking (
+ const std::string& title,
+ const std::string& message
+ );
+ /*!
+ ensures
+ - displays a message box with the given title and message. It will have a
+ single button and when the user clicks it the message box will go away.
+ - this function blocks until the user clicks on the message box and
+ causes it to go away.
+ !*/
+
+ template <
+ typename T
+ >
+ void message_box (
+ const std::string& title,
+ const std::string& message,
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler == a valid pointer to a member function of object T.
+ ensures
+ - Displays a message box with the given title and message. It will have a
+ single button and when the user clicks it the message box will go away.
+ - The event_handler function is called on object when the user clicks
+ ok or otherwise closes the message box window.
+ - this function does not block but instead returns immediately.
+ !*/
+
+ void message_box (
+ const std::string& title,
+ const std::string& message,
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - Displays a message box with the given title and message. It will have a
+ single button and when the user clicks it the message box will go away.
+ - The event_handler function is called when the user clicks
+ ok or otherwise closes the message box window.
+ - this function does not block but instead returns immediately.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class label
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class label : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ text() == ""
+ the text color will be black
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple text label. The size of the label
+ is automatically set to be just big enough to contain its text.
+ !*/
+
+ public:
+
+ label(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~label(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_text (const std::wstring& text);
+ void set_text (const dlib::ustring& text);
+ void set_text (
+ const std::string& text
+ );
+ /*!
+ ensures
+ - #text() == text
+ throws
+ - std::bad_alloc
+ !*/
+
+ const std::wstring wtext () const;
+ const dlib::ustring utext () const;
+ const std::string text (
+ ) const;
+ /*!
+ ensures
+ - returns the text of this label
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+ /*!
+ ensures
+ - #text_color() == color
+ !*/
+
+ const rgb_pixel text_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color used to draw the text in this widget
+ !*/
+
+ private:
+
+ // restricted functions
+ label(label&); // copy constructor
+ label& operator=(label&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class toggle_button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class toggle_button : public button_action
+ {
+ /*!
+ INITIAL VALUE
+ name() == ""
+ is_checked() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple two state toggle button. Is is either
+ in the checked or unchecked state and when a user clicks on it it toggles its
+ state.
+
+ When this object is disabled it means it will not respond to user clicks.
+ !*/
+
+ public:
+
+ toggle_button(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~toggle_button(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_name (const std::wstring& name);
+ void set_name (const dlib::ustring& name);
+ void set_name (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - #name() == name
+ - this toggle_button has been resized such that it is big enough to contain
+ the new name.
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - if (width and height are big enough to contain the name of this button) then
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this button stays the
+ same but its width and height are modified
+ !*/
+
+ void set_tooltip_text (const std::wstring& text);
+ void set_tooltip_text (const dlib::ustring& text);
+ void set_tooltip_text (
+ const std::string& text
+ );
+ /*!
+ ensures
+ - #tooltip_text() == text
+ - enables the tooltip for this toggle_button
+ !*/
+
+ const dlib::ustring tooltip_utext () const;
+ const std::wstring tooltip_wtext () const;
+ const std::string tooltip_text (
+ ) const;
+ /*!
+ ensures
+ - returns the text that is displayed in the tooltip for this toggle_button
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from toggle_button_style
+ ensures
+ - this toggle_button object will draw itself using the given
+ button style
+ !*/
+
+ bool is_checked (
+ ) const;
+ /*!
+ ensures
+ - if (this box is currently checked) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ const std::wstring wname () const;
+ const dlib::ustring uname () const;
+ const std::string name (
+ ) const;
+ /*!
+ ensures
+ - returns the name of this toggle_button. The name is a string
+ that appears to the right of the actual check box.
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_checked (
+ );
+ /*!
+ ensures
+ - #is_checked() == true
+ !*/
+
+ void set_unchecked (
+ );
+ /*!
+ ensures
+ - #is_checked() == false
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the toggle_button is
+ toggled by the user.
+ - this event is NOT triggered by calling set_checked() or set_unchecked().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_click_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the toggle_button is
+ toggled by the user.
+ - this event is NOT triggered by calling set_checked() or set_unchecked().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)(toggle_button& self)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T.
+ ensures
+ - the event_handler function is called on object when the toggle_button is
+ toggled by the user. self will be a reference to the toggle_button object
+ that the user clicked.
+ - this event is NOT triggered by calling set_checked() or set_unchecked().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_sourced_click_handler (
+ const any_function<void(toggle_button& self)>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the toggle_button is
+ toggled by the user. self will be a reference to the toggle_button object
+ that the user clicked.
+ - this event is NOT triggered by calling set_checked() or set_unchecked().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ toggle_button(toggle_button&); // copy constructor
+ toggle_button& operator=(toggle_button&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class text_field
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_field : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - text() == ""
+ - width() == 10
+ - height() == a height appropriate for the font used. The text color will
+ be black.
+ - has_input_focus() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple one line text input field.
+ !*/
+
+ public:
+
+ text_field(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~text_field(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from text_field_style
+ ensures
+ - this text_field object will draw itself using the given
+ text field style
+ !*/
+
+ void set_text (const std::wstring& text);
+ void set_text (const dlib::ustring& text);
+ void set_text (
+ const std::string& text
+ );
+ /*!
+ requires
+ - text.find_first_of('\n') == std::string::npos
+ (i.e. there aren't any new lines in text)
+ ensures
+ - #text() == text
+ throws
+ - std::bad_alloc
+ !*/
+
+ const std::wstring wtext () const;
+ const dlib::ustring utext () const;
+ const std::string text (
+ ) const;
+ /*!
+ ensures
+ - returns the text of this text_field
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_width (
+ unsigned long width_
+ );
+ /*!
+ ensures
+ - if (width >= 10) then
+ - #width() == width_
+ - #height() == height()
+ - #top() == top()
+ - #left() == left()
+ - i.e. The width of this drawable is set to the given width but
+ nothing else changes.
+ !*/
+
+ void give_input_focus (
+ );
+ /*!
+ ensures
+ - #has_input_focus() == true
+ !*/
+
+ bool has_input_focus (
+ );
+ /*!
+ ensures
+ - Returns true if this txt field has input keyboard focus. If this
+ is the case then it means that when the user types on the keyboard
+ the output will appear inside the text field.
+ !*/
+
+ void select_all_text (
+ );
+ /*!
+ ensures
+ - causes all the text in the text field to become selected.
+ (note that it doesn't give input focus)
+ !*/
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+ /*!
+ ensures
+ - #text_color() == color
+ !*/
+
+ const rgb_pixel text_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color used to draw the text in this widget
+ !*/
+
+ void set_background_color (
+ const rgb_pixel color
+ );
+ /*!
+ ensures
+ - #background_color() == color
+ !*/
+
+ const rgb_pixel background_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color used to fill in the background of this widget
+ !*/
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the text
+ in this text_field is modified by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_text_modified_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the text in this text_field
+ is modified by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_enter_key_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when this text field
+ has input focus and the user hits the enter key on their keyboard.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_enter_key_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when this text field has input
+ focus and the user hits the enter key on their keyboard.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_focus_lost_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when this object
+ loses input focus due to the user clicking outside the text field
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_focus_lost_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when this object loses input
+ focus due to the user clicking outside the text field
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ text_field(text_field&); // copy constructor
+ text_field& operator=(text_field&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class text_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class text_box : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ - text() == ""
+ - The text color will be black.
+ - width() == 100
+ - height() == 100
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple multi-line text input box.
+ !*/
+
+ public:
+
+ text_box(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~text_box(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from text_box_style
+ ensures
+ - this text_box object will draw itself using the given
+ text box style
+ !*/
+
+ void set_text (const std::wstring& text);
+ void set_text (const dlib::ustring& text);
+ void set_text (
+ const std::string& text
+ );
+ /*!
+ ensures
+ - #text() == text
+ throws
+ - std::bad_alloc
+ !*/
+
+ const std::wstring wtext () const;
+ const dlib::ustring utext () const;
+ const std::string text (
+ ) const;
+ /*!
+ ensures
+ - returns the text of this text_box
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ void set_text_color (
+ const rgb_pixel color
+ );
+ /*!
+ ensures
+ - #text_color() == color
+ !*/
+
+ const rgb_pixel text_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color used to draw the text in this widget
+ !*/
+
+ void set_background_color (
+ const rgb_pixel color
+ );
+ /*!
+ ensures
+ - #background_color() == color
+ !*/
+
+ const rgb_pixel background_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color used to fill in the background of this widget
+ !*/
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the text
+ in this text_box is modified by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_text_modified_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the text in this text_box
+ is modified by the user.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_enter_key_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when this text box
+ has input focus and the user hits the enter key on their keyboard.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_enter_key_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when this text box has input
+ focus and the user hits the enter key on their keyboard.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_focus_lost_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when this object
+ loses input focus due to the user clicking outside the text box
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_focus_lost_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when this object loses input
+ focus due to the user clicking outside the text box
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ text_box(text_box&); // copy constructor
+ text_box& operator=(text_box&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class check_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class check_box : public toggle_button
+ {
+ /*!
+ This is just a toggle button with the style set to
+ toggle_button_style_check_box.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class radio_button
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class radio_button : public toggle_button
+ {
+ /*!
+ This is just a toggle button with the style set to
+ toggle_button_style_radio_button.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class tabbed_display
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class tabbed_display : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ number_of_tabs() == 1
+ selected_tab() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a row of tabs that are user selectable.
+
+ When this object is disabled it means it will not respond to user clicks.
+ !*/
+
+ public:
+
+ tabbed_display(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~tabbed_display(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - if (width and height are big enough to contain the tabs) then
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ void set_number_of_tabs (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #number_of_tabs() == num
+ - no tabs have any widget_groups associated with them.
+ - for all valid idx:
+ - #tab_name(idx) == ""
+ throws
+ - std::bad_alloc
+ !*/
+
+ unsigned long selected_tab (
+ ) const;
+ /*!
+ ensures
+ - returns the index of the currently selected tab
+ !*/
+
+ unsigned long number_of_tabs (
+ ) const;
+ /*!
+ ensures
+ - returns the number of tabs in this tabbed_display
+ !*/
+
+ const std::wstring& tab_wname (unsigned long idx) const;
+ const dlib::ustring& tab_uname (unsigned long idx) const;
+ const std::string& tab_name (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_tabs()
+ ensures
+ - returns a const reference to the name of the tab given by idx
+ !*/
+
+ void set_tab_name (unsigned long idx, const std::wstring& new_name);
+ void set_tab_name (unsigned long idx, const dlib::ustring& new_name);
+ void set_tab_name (
+ unsigned long idx,
+ const std::string& new_name
+ );
+ /*!
+ requires
+ - idx < number_of_tabs()
+ ensures
+ - #tab_name(idx) == new_name
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_tab_group (
+ unsigned long idx,
+ widget_group& group
+ );
+ /*!
+ requires
+ - idx < number_of_tabs()
+ ensures
+ - if (is_hidden()) then
+ - group.is_hidden() == true
+ - else
+ - whenever the tab with index idx is selected group.is_hidden() == false
+ - whenever the tab with index idx is deselected group.is_hidden() == true
+ - whenever the position of *this changes the position of group will be
+ updated so that it is still inside the tabbed_display. The position of group
+ will also be updated after this call to set_tab_group().
+ - any previous calls to set_tab_group() with this index are overridden by this
+ new call. (i.e. you can only have one widget_group associated with a single
+ tab at a time)
+ !*/
+
+ void fit_to_contents (
+ );
+ /*!
+ ensures
+ - Adjusts the size this tabbed_display so that it nicely contains
+ all of its widget_group objects.
+ - does not change the position of this object.
+ (i.e. the upper left corner of get_rect() remains at the same position)
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long new_idx, unsigned long old_idx)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called on object when the user clicks
+ on a tab that isn't already selected. new_idx will give the index of
+ the newly selected tab and old_idx will give the index of the tab
+ that was previously selected.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_click_handler (
+ const any_function<void(unsigned long new_idx, unsigned long old_idx)>& eh
+ );
+ /*!
+ ensures
+ - The event_handler function is called when the user clicks on a tab
+ that isn't already selected. new_idx will give the index of the
+ newly selected tab and old_idx will give the index of the tab that
+ was previously selected.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ tabbed_display(tabbed_display&); // copy constructor
+ tabbed_display& operator=(tabbed_display&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class named_rectangle
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class named_rectangle : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ name() == ""
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple named rectangle.
+ !*/
+
+ public:
+
+ named_rectangle(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~named_rectangle(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ void wrap_around (
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - This object will be repositioned and sized so that it fits
+ around the given rectangle.
+ !*/
+
+ void set_name (const std::wstring& name);
+ void set_name (const dlib::ustring& name);
+ void set_name (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - #name() == name
+ throws
+ - std::bad_alloc
+ !*/
+
+ const std::wstring wname () const;
+ const dlib::ustring uname () const;
+ const std::string name (
+ ) const;
+ /*!
+ ensures
+ - returns the name of this named_rectangle
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ named_rectangle(named_rectangle&); // copy constructor
+ named_rectangle& operator=(named_rectangle&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class mouse_tracker
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class mouse_tracker : public draggable
+ {
+ /*!
+ INITIAL VALUE
+ draggable_area() == rectangle(0,0,500,500)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple draggable box that displays the
+ current location of the mouse.
+
+ Also, if you hold shift and left click on the parent window then the
+ mouse_tracker will place a single red pixel where you clicked and will
+ display the mouse position relative to that point.
+ !*/
+
+ public:
+
+ mouse_tracker(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~mouse_tracker(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ private:
+
+ // restricted functions
+ mouse_tracker(mouse_tracker&); // copy constructor
+ mouse_tracker& operator=(mouse_tracker&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class list_box
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class list_box : public scrollable_region,
+ public enumerable<const std::string>
+ {
+ /*!
+ INITIAL VALUE
+ multiple_select_enabled() == false
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the list_box from
+ the 0th element to the (size()-1)th element. i.e. (*this)[0] to
+ (*this)[size()-1].
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple textual list box. It contains a
+ vertical list of strings which the user may select from.
+ !*/
+
+ public:
+
+ list_box(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~list_box(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename style_type
+ >
+ void set_style (
+ const style_type& style
+ );
+ /*!
+ requires
+ - style_type == a type that inherits from list_box_style
+ ensures
+ - this list_box object will draw itself using the given style
+ !*/
+
+ void set_size (
+ unsigned long width_,
+ unsigned long height_
+ );
+ /*!
+ ensures
+ - #width() == width_
+ - #height() == height_
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified
+ !*/
+
+ bool is_selected (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - if (the item given by index is currently selected) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void select (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < size()
+ ensures
+ - #is_selected(index) == true
+ !*/
+
+ void unselect (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < size()
+ ensures
+ - #is_selected(index) == false
+ !*/
+
+ template <typename T>
+ void get_selected (
+ T& list
+ ) const;
+ /*!
+ requires
+ - T == an implementation of dlib/queue/queue_kernel_abstract.h
+ - T::type == unsigned long
+ ensures
+ - #list == a list of all the currently selected indices for this list_box.
+ !*/
+
+ unsigned long get_selected (
+ ) const;
+ /*!
+ requires
+ - multiple_select_enabled() == false
+ ensures
+ - if (there is currently something selected) then
+ - returns the index of the selected item
+ - else
+ - returns size()
+ !*/
+
+ template <typename T>
+ void load (
+ const T& list
+ );
+ /*!
+ requires
+ - T == compatible with dlib::enumerable<std::string>
+ ensures
+ - #size() == list.size()
+ - Copies all the strings from list into *this in the order in which they are enumerated.
+ (i.e. The first one goes into (*this)[0], the second into (*this)[1], and so on...)
+ !*/
+
+ const std::string& operator[] (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns the name of the indexth item/row in this list box.
+ !*/
+
+ bool multiple_select_enabled (
+ ) const;
+ /*!
+ ensures
+ - if (this object will allow the user to select more than one item at a time) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void enable_multiple_select (
+ );
+ /*!
+ ensures
+ - #multiple_select_enabled() == true
+ !*/
+
+ void disable_multiple_select (
+ );
+ /*!
+ ensures
+ - #multiple_select_enabled() == false
+ !*/
+
+ template <
+ typename T
+ >
+ void set_double_click_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long index)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T.
+ ensures
+ - The event_handler function is called on object when the user double
+ clicks on one of the rows in this list box. index gives the row
+ number for the item the user clicked.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_double_click_handler (
+ const any_function<void(unsigned long index)>& event_handler
+ );
+ /*!
+ ensures
+ - The event_handler function is called when the user double clicks on
+ one of the rows in this list box. index gives the row number for
+ the item the user clicked.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_click_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long index)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T.
+ ensures
+ - The event_handler function is called on object when the user
+ clicks on one of the rows in this list box. index gives the row
+ number for the item the user clicked. (Note that the second click
+ in a double click triggers the double click handler above instead
+ of this event)
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_click_handler (
+ const any_function<void(unsigned long index)>& event_handler
+ );
+ /*!
+ ensures
+ - The event_handler function is called when the user clicks on one
+ of the rows in this list box. index gives the row number for the
+ item the user clicked. (Note that the second click in a double
+ click triggers the double click handler above instead of this event)
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ list_box(list_box&); // copy constructor
+ list_box& operator=(list_box&); // assignment operator
+ };
+
+ class wlist_box : public scrollable_region,
+ public enumerable<const std::wstring>;
+ /*!
+ same as list_box except for std::wstring instead of std::string
+ !*/
+
+ class ulist_box : public scrollable_region,
+ public enumerable<const dlib::ustring>;
+ /*!
+ same as list_box except for dlib::ustring instead of std::string
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // class menu_bar
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class menu_bar : public drawable
+ {
+ /*!
+ INITIAL VALUE
+ - number_of_menus() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a menu bar that appears at the top of a
+ window.
+ !*/
+
+ public:
+
+ menu_bar(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~menu_bar(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_number_of_menus (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #number_of_menus() == num
+ !*/
+
+ unsigned long number_of_menus (
+ ) const;
+ /*!
+ ensures
+ - returns the number of menus in this menu_bar
+ !*/
+
+ void set_menu_name (unsigned long idx, const std::wstring name, char underline_ch = '\0');
+ void set_menu_name (unsigned long idx, const dlib::ustring name, char underline_ch = '\0');
+ void set_menu_name (
+ unsigned long idx,
+ const std::string name,
+ char underline_ch = '\0'
+ );
+ /*!
+ requires
+ - idx < number_of_menus()
+ ensures
+ - #menu_name(idx) == name
+ - if (underline_ch is present in name) then
+ - The menu with index idx will have the first underline_ch character
+ in its name underlined and users will be able to activate the menu
+ by hitting alt+underline_char
+ !*/
+
+ const std::wstring menu_wname (unsigned long idx) const;
+ const dlib::ustring menu_uname (unsigned long idx) const;
+ const std::string menu_name (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_menus()
+ ensures
+ - returns the name of the menu with index idx
+ !*/
+
+ popup_menu& menu (
+ unsigned long idx
+ );
+ /*!
+ requires
+ - idx < number_of_menus()
+ ensures
+ - returns a non-const reference to the popup_menu for the menu with
+ index idx.
+ !*/
+
+ const popup_menu& menu (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < number_of_menus()
+ ensures
+ - returns a const reference to the popup_menu for the menu with
+ index idx.
+ !*/
+
+ private:
+
+ // restricted functions
+ menu_bar(menu_bar&); // copy constructor
+ menu_bar& operator=(menu_bar&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ class directed_graph_drawer : public zoomable_region
+ {
+ /*!
+ REQUIREMENTS ON graph_type
+ - must be an implementation of directed_graph/directed_graph_kernel_abstract.h
+
+ INITIAL VALUE
+ - get_graph().size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a graphical widget that allows the user to draw
+ a directed graph.
+
+ The user can create nodes by right clicking on the draw area and add
+ edges by selecting a node (via left clicking on it) and then holding
+ shift and clicking on the node that is to be the child node of the
+ selected node.
+ !*/
+
+ public:
+
+ directed_graph_drawer (
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~directed_graph_drawer (
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ const graph_type& graph (
+ ) const;
+ /*!
+ requires
+ - drawable::m is locked
+ ensures
+ - returns a const reference to the graph that this widget has been drawing
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns graph().number_of_nodes()
+ !*/
+
+ void clear_graph (
+ );
+ /*!
+ ensures
+ - #number_of_nodes() == 0
+ !*/
+
+ const typename graph_type::node_type& graph_node (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - drawable::m is locked
+ - i < number_of_nodes()
+ ensures
+ - returns a const reference to get_graph().node(i)
+ !*/
+
+ typename graph_type::node_type& graph_node (
+ unsigned long i
+ );
+ /*!
+ requires
+ - drawable::m is locked
+ - i < number_of_nodes()
+ ensures
+ - returns a non-const reference to get_graph().node(i)
+ !*/
+
+ void save_graph (
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - saves the state of the graph to the output stream. Does so in a
+ way that not only preserves the state of the graph this->graph()
+ but also preserves the graphical layout of the graph in this
+ GUI widget.
+ - Also, the first part of the saved state is a serialized
+ version of this->graph(). Thus, you can deserialize just the
+ this->graph() object from the serialized data if you like.
+ !*/
+
+ void load_graph (
+ std::istream& in
+ );
+ /*!
+ ensures
+ - loads a saved graph from the given input stream.
+ !*/
+
+ void set_node_label (unsigned long i, const std::wstring& label);
+ void set_node_label (unsigned long i, const dlib::ustring& label);
+ void set_node_label (
+ unsigned long i,
+ const std::string& label
+ );
+ /*!
+ requires
+ - i < number_of_nodes()
+ ensures
+ - #node_label(i) == label
+ !*/
+
+ void set_node_color (
+ unsigned long i,
+ rgb_pixel color
+ );
+ /*!
+ requires
+ - i < number_of_nodes()
+ ensures
+ - #node_color(i) == color
+ !*/
+
+ rgb_pixel node_color (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - i < number_of_nodes()
+ ensures
+ - returns the color used to draw node graph_node(i)
+ !*/
+
+ const std::wstring node_wlabel (unsigned long i) const;
+ const dlib::ustring node_ulabel (unsigned long i) const;
+ const std::string node_label (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - i < number_of_nodes()
+ ensures
+ - returns the text label for node graph_node(i)
+ !*/
+
+ template <
+ typename T
+ >
+ void set_node_selected_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long node_index)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user selects
+ a node.
+ - node_index == the index of the node that was selected
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_node_selected_handler (
+ const any_function<void(unsigned long node_index)>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user selects
+ a node.
+ - node_index == the index of the node that was selected
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_node_deselected_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long node_index)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user
+ deselects a node.
+ - node_index == the index of the node that was deselected
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_node_deselected_handler (
+ const any_function<void(unsigned long node_index)>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user deselects a node.
+ - node_index == the index of the node that was deselected
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_node_deleted_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user
+ deletes a node.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_node_deleted_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user deletes a node.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_graph_modified_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user
+ modifies the graph (i.e. adds or removes a node or edge)
+ - the event_handler function is not called when the user just
+ moves nodes around on the screen.
+ - This event is always dispatched before any more specific event
+ that results from the user modifying the graph.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_graph_modified_handler (
+ const any_function<void()>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user modifies
+ the graph (i.e. adds or removes a node or edge)
+ - the event_handler function is not called when the user just
+ moves nodes around on the screen.
+ - This event is always dispatched before any more specific event
+ that results from the user modifying the graph.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ directed_graph_drawer(directed_graph_drawer&); // copy constructor
+ directed_graph_drawer& operator=(directed_graph_drawer&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class text_grid : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ - vertical_scroll_increment() == 10
+ - horizontal_scroll_increment() == 10
+ - border_color() == rgb_pixel(128,128,128)
+ - number_of_columns() == 0
+ - number_of_rows() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple grid of square text fields that
+ looks more or less like a spreadsheet grid.
+ !*/
+
+ public:
+
+ text_grid (
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~text_grid (
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_grid_size (
+ unsigned long rows,
+ unsigned long cols
+ );
+ /*!
+ ensures
+ - #number_of_rows() == rows
+ - #number_of_columns() == cols
+ - for all valid r and c:
+ - #text(r,c) == ""
+ - #text_color(r,c) == rgb_pixel(0,0,0)
+ - #background_color(r,c) == rgb_pixel(255,255,255)
+ - #is_editable(r,c) == true
+ !*/
+
+ unsigned long number_of_columns (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns contained in this grid
+ !*/
+
+ unsigned long number_of_rows (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows contained in this grid
+ !*/
+
+ rgb_pixel border_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color of the lines drawn between the grid elements
+ !*/
+
+ void set_border_color (
+ rgb_pixel color
+ );
+ /*!
+ ensures
+ - #border_color() == color
+ !*/
+
+ const std::wstring wtext (unsigned long row, unsigned long col) const;
+ const dlib::ustring utext (unsigned long row, unsigned long col) const;
+ const std::string text (
+ unsigned long row,
+ unsigned long col
+ ) const;
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - returns the text in the given grid location
+ !*/
+
+ void set_text (unsigned long row, unsigned long col, const std::wstring& str);
+ void set_text (unsigned long row, unsigned long col, const dlib::ustring& str);
+ void set_text (
+ unsigned long row,
+ unsigned long col,
+ const std::string& str
+ );
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - #text(row,col) == str
+ !*/
+
+ const rgb_pixel text_color (
+ unsigned long row,
+ unsigned long col
+ ) const;
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - returns the color of the text in the given grid location
+ !*/
+
+ void set_text_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ );
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - #text_color(row,col) == color
+ !*/
+
+ const rgb_pixel background_color (
+ unsigned long row,
+ unsigned long col
+ ) const;
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - returns the background color of the given grid location
+ !*/
+
+ void set_background_color (
+ unsigned long row,
+ unsigned long col,
+ const rgb_pixel color
+ );
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - #background_color(row,col) == color
+ !*/
+
+ bool is_editable (
+ unsigned long row,
+ unsigned long col
+ ) const;
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - if (the given grid location is editable by the user) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void set_editable (
+ unsigned long row,
+ unsigned long col,
+ bool editable
+ );
+ /*!
+ requires
+ - row < number_of_rows()
+ - col < number_of_columns()
+ ensures
+ - #is_editable(row,col) == editable
+ !*/
+
+ void set_column_width (
+ unsigned long col,
+ unsigned long width
+ );
+ /*!
+ requires
+ - col < number_of_columns()
+ ensures
+ - the given column will be displayed such that it is width pixels wide
+ !*/
+
+ void set_row_height (
+ unsigned long row,
+ unsigned long height
+ );
+ /*!
+ requires
+ - row < number_of_rows()
+ ensures
+ - the given row will be displayed such that it is height pixels wide
+ !*/
+
+ template <
+ typename T
+ >
+ void set_text_modified_handler (
+ T& object,
+ void (T::*event_handler)(unsigned long row, unsigned long col)
+ );
+ /*!
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user selects
+ a node.
+ - row == row will give the row of the grid item that was modified
+ - col == col will give the column of the grid item that was modified
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_text_modified_handler (
+ const any_function<void(unsigned long row, unsigned long col)>& event_handler
+ );
+ /*!
+ ensures
+ - the event_handler function is called when the user selects a node.
+ - row == row will give the row of the grid item that was modified
+ - col == col will give the column of the grid item that was modified
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ text_grid(text_grid&); // copy constructor
+ text_grid& operator=(text_grid&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class image_display : public scrollable_region
+ {
+ /*!
+ INITIAL VALUE
+ - This object isn't displaying anything.
+ - get_overlay_rects().size() == 0
+ - get_default_overlay_rect_label() == ""
+ - get_default_overlay_rect_color() == rgb_alpha_pixel(255,0,0,255) (i.e. RED)
+ - This object does not have any user labelable parts defined.
+ - overlay_editing_is_enabled() == true
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an image inside a scrollable region.
+ You give it an image to display by calling set_image().
+ This widget also allows you to add rectangle and line overlays that
+ will be drawn on top of the image.
+
+ If you hold the Ctrl key you can zoom in and out using the mouse wheel.
+ You can also add new overlay rectangles by holding shift, left clicking,
+ and dragging the mouse. Additionally, you can delete an overlay rectangle
+ by double clicking on it and hitting delete or backspace. Finally, you
+ can also add part labels (if they have been defined by calling add_labelable_part_name())
+ by selecting an overlay rectangle with the mouse and then right clicking
+ on the part. If you want to move any rectangle or an object part then
+ shift+right click and drag it.
+
+ Finally, if you hold Ctrl and left click an overlay rectangle it will
+ change its label to get_default_overlay_rect_label() and color to
+ get_default_overlay_rect_color().
+
+ The image is drawn such that:
+ - the pixel img[0][0] is the upper left corner of the image.
+ - the pixel img[img.nr()-1][0] is the lower left corner of the image.
+ - the pixel img[0][img.nc()-1] is the upper right corner of the image.
+ - the pixel img[img.nr()-1][img.nc()-1] is the lower right corner of the image.
+ !*/
+
+ public:
+
+ image_display(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ !*/
+
+ ~image_display(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ template <
+ typename image_type
+ >
+ void set_image (
+ const image_type& new_img
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h or
+ a dlib::matrix or something convertible to a matrix via mat()
+ - pixel_traits<typename image_type::type> must be defined
+ ensures
+ - #*this widget is now displaying the given image new_img.
+ !*/
+
+ struct overlay_rect
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a rectangle that is drawn on top of the
+ image shown by this object. Each rectangle is represented by
+ a rectangle object as well as a color and text label. The label
+ is drawn below the lower right corner of the rectangle.
+
+ Moreover, the rectangle can have sub-parts. Each part is listed
+ in the parts member variable. This variable maps the name of the
+ part to its position.
+
+ Rectangles with crossed_out == true will be drawn with an X through
+ them.
+ !*/
+
+ rectangle rect;
+ rgb_alpha_pixel color;
+ std::string label;
+ std::map<std::string,point> parts;
+ bool crossed_out;
+
+ overlay_rect(
+ );
+ /*!
+ ensures
+ - #color == rgb_alpha_pixel(0,0,0,0)
+ - #rect == rectangle()
+ - #label.size() == 0
+ - #crossed_out == false
+ !*/
+
+ template <typename pixel_type>
+ overlay_rect(
+ const rectangle& r,
+ pixel_type p
+ );
+ /*!
+ ensures
+ - #rect == r
+ - performs assign_pixel(color, p)
+ - #label.size() == 0
+ - #crossed_out == false
+ !*/
+
+ template <typename pixel_type>
+ overlay_rect(
+ const rectangle& r,
+ pixel_type p,
+ const std::string& l
+ );
+ /*!
+ ensures
+ - #rect == r
+ - performs assign_pixel(color, p)
+ - #label == l
+ - #crossed_out == false
+ !*/
+
+ template <typename pixel_type>
+ overlay_rect(
+ const rectangle& r,
+ pixel_type p,
+ const std::string& l,
+ const std::map<std::string,point>& parts_
+ );
+ /*!
+ ensures
+ - #rect == r
+ - performs assign_pixel(color, p)
+ - #label == l
+ - #parts == parts_
+ - #crossed_out == false
+ !*/
+
+ };
+
+ struct overlay_line
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a line that is drawn on top of the
+ image shown by this object. Each line is represented by
+ its two end points (p1 and p2) as well as a color.
+ !*/
+
+ point p1;
+ point p2;
+ rgb_alpha_pixel color;
+
+ overlay_line(
+ );
+ /*!
+ ensures
+ - #color == rgb_alpha_pixel(0,0,0,0)
+ - #p1 == point()
+ - #p2 == point()
+ !*/
+
+ template <typename pixel_type>
+ overlay_line(
+ const point& p1_,
+ const point& p2_,
+ pixel_type p
+ );
+ /*!
+ ensures
+ - performs assign_pixel(color, p)
+ - #p1 == p1_
+ - #p2 == p2_
+ !*/
+
+ };
+
+ struct overlay_circle
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a circle that is drawn on top of the
+ image shown by this object. Each circle is represented by
+ its center, radius, and color. It can also have an optional
+ text label which will appear below the circle.
+ !*/
+
+ point center;
+ int radius;
+ rgb_alpha_pixel color;
+ std::string label;
+
+ overlay_circle(
+ );
+ /*!
+ ensures
+ - #center == point(0,0)
+ - #radius == 0
+ - #color == rgb_alpha_pixel(0,0,0,0)
+ - #label.size() == 0
+ !*/
+
+ template <typename pixel_type>
+ overlay_circle(
+ const point& center_,
+ const int radius_,
+ pixel_type p
+ );
+ /*!
+ ensures
+ - performs assign_pixel(color, p)
+ - #center == center_
+ - #radius == radius_
+ !*/
+
+ template <typename pixel_type>
+ overlay_circle(
+ const point& center_,
+ const int radius_,
+ pixel_type p,
+ const std::string& label_
+ );
+ /*!
+ ensures
+ - performs assign_pixel(color, p)
+ - #center == center_
+ - #radius == radius_
+ - #label == label_
+ !*/
+
+ };
+
+ void add_overlay (
+ const overlay_rect& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay rectangle into this object such
+ that it will be displayed.
+ !*/
+
+ void add_overlay (
+ const overlay_line& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay line into this object such
+ that it will be displayed.
+ !*/
+
+ void add_overlay (
+ const overlay_circle& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay circle into this object such
+ that it will be displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_rect>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay rectangles into this object such
+ that they will be displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay lines into this object such
+ that they will be displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_circle>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay circles into this object such
+ that they will be displayed.
+ !*/
+
+ void clear_overlay (
+ );
+ /*!
+ ensures
+ - removes all overlays from this object.
+ - #get_overlay_rects().size() == 0
+ !*/
+
+ std::vector<overlay_rect> get_overlay_rects (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of all the overlay_rect objects currently displayed.
+ !*/
+
+ void set_default_overlay_rect_label (
+ const std::string& label
+ );
+ /*!
+ ensures
+ - #get_default_overlay_rect_label() == label
+ !*/
+
+ std::string get_default_overlay_rect_label (
+ ) const;
+ /*!
+ ensures
+ - returns the label given to new overlay rectangles created by the user
+ (i.e. when the user holds shift and adds them with the mouse)
+ !*/
+
+ void set_default_overlay_rect_color (
+ const rgb_alpha_pixel& color
+ );
+ /*!
+ ensures
+ - #get_default_overlay_rect_color() == color
+ !*/
+
+ rgb_alpha_pixel get_default_overlay_rect_color (
+ ) const;
+ /*!
+ ensures
+ - returns the color given to new overlay rectangles created by the user
+ (i.e. when the user holds shift and adds them with the mouse)
+ !*/
+
+ void add_labelable_part_name (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - adds a user labelable part with the given name. If the name has
+ already been added then this function has no effect.
+ - These parts can be added by the user by selecting an overlay box
+ and then right clicking anywhere in it. A popup menu will appear
+ listing the parts. The user can then click a part name and it will
+ add it into the overlay_rect::parts variable and also show it on the
+ screen.
+ !*/
+
+ void clear_labelable_part_names (
+ );
+ /*!
+ ensures
+ - removes all use labelable parts. Calling this function undoes
+ all previous calls to add_labelable_part_name(). Therefore, the
+ user won't be able to label any parts after clear_labelable_part_names()
+ is called.
+ !*/
+
+ rectangle get_image_display_rect (
+ ) const;
+ /*!
+ ensures
+ - returns a rectangle R that tells you how big the image inside the
+ display is when it appears on the screen. Note that it takes the
+ current zoom level into account.
+ - R.width() == the width of the displayed image
+ - R.height() == the height of the displayed image
+ - R.tl_corner() == (0,0)
+ !*/
+
+ void enable_overlay_editing (
+ );
+ /*!
+ ensures
+ - #overlay_editing_is_enabled() == true
+ !*/
+
+ void disable_overlay_editing (
+ );
+ /*!
+ ensures
+ - #overlay_editing_is_enabled() == false
+ !*/
+
+ bool overlay_editing_is_enabled (
+ ) const;
+ /*!
+ ensures
+ - if this function returns true then it is possible for the user to add or
+ remove overlay objects (e.g. rectangles) using the mouse and keyboard.
+ If it returns false then the overlay is not user editable.
+ !*/
+
+ template <
+ typename T
+ >
+ void set_overlay_rects_changed_handler (
+ T& object,
+ void (T::*event_handler)()
+ );
+ /*
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - the event_handler function is called on object when the user adds,
+ removes, or modifies an overlay rectangle.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ */
+
+ void set_overlay_rects_changed_handler (
+ const any_function<void()>& event_handler
+ );
+ /*
+ ensures
+ - the event_handler function is called when the user adds or removes
+ an overlay rectangle.
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ */
+
+ template <
+ typename T
+ >
+ void set_overlay_rect_selected_handler (
+ T& object,
+ void (T::*event_handler)(const overlay_rect& orect)
+ );
+ /*
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called on object when the user selects
+ an overlay rectangle by double clicking on it. The selected rectangle
+ will be passed to event_handler().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ */
+
+ void set_overlay_rect_selected_handler (
+ const any_function<void(const overlay_rect& orect)>& event_handler
+ );
+ /*
+ ensures
+ - The event_handler function is called when the user selects an overlay
+ rectangle by double clicking on it. The selected rectangle will be
+ passed to event_handler().
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ */
+
+ template <
+ typename T
+ >
+ void set_image_clicked_handler (
+ T& object,
+ void (T::*event_handler)(const point& p, bool is_double_click, unsigned long btn)
+ );
+ /*
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called on object when the user left clicks
+ anywhere on the image. When they do so this callback is called with the
+ location of the image pixel which was clicked. The is_double_click bool
+ will also tell you if it was a double click or single click.
+ - btn == the button that was released. (either base_window::LEFT, base_window::MIDDLE, or base_window::RIGHT)
+ - any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ throws
+ - std::bad_alloc
+ */
+
+ void set_image_clicked_handler (
+ const any_function<void(const point& p, bool is_double_click, unsigned long btn)>& event_handler
+ );
+ /*
+ ensures
+ - The event_handler function is called when the user left clicks anywhere
+ on the image. When they do so this callback is called with the location
+ of the image pixel which was clicked. The is_double_click bool will also
+ tell you if it was a double click or single click.
+ - btn == the button that was released. (either base_window::LEFT, base_window::MIDDLE, or base_window::RIGHT)
+ - Any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this event at a
+ time)
+ throws
+ - std::bad_alloc
+ */
+
+ private:
+
+ // restricted functions
+ image_display(image_display&); // copy constructor
+ image_display& operator=(image_display&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class image_window : public drawable_window
+ {
+ /*!
+ INITIAL VALUE
+ - initially, this object is visible on the screen
+ - events_tied() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple window that is just a container for an image_display.
+ It exists to make it easy to throw image_displays onto the screen
+ without having to put together your own drawable_window objects.
+ !*/
+ public:
+
+ typedef image_display::overlay_rect overlay_rect;
+ typedef image_display::overlay_line overlay_line;
+ typedef image_display::overlay_circle overlay_circle;
+
+ image_window(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <typename image_type>
+ image_window(
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h or
+ a dlib::matrix or something convertible to a matrix via mat()
+ - pixel_traits<typename image_type::type> must be defined
+ ensures
+ - this object is properly initialized
+ - #*this window is now displaying the given image img.
+ !*/
+
+ template < typename image_type>
+ image_window(
+ const image_type& img,
+ const std::string& title
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h or
+ a dlib::matrix or something convertible to a matrix via mat()
+ - pixel_traits<typename image_type::type> must be defined
+ ensures
+ - this object is properly initialized
+ - #*this window is now displaying the given image img.
+ - The title of the window will be set to the given title string.
+ !*/
+
+ ~image_window(
+ );
+ /*!
+ ensures
+ - any resources associated with this object have been released
+ !*/
+
+ template <typename image_type>
+ void set_image (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_type::type> must be defined
+ ensures
+ - #*this window is now displaying the given image img.
+ !*/
+
+ void add_overlay (
+ const overlay_rect& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay rectangle into this object such
+ that it will be displayed.
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const rectangle& r,
+ pixel_type p = rgb_pixel(255,0,0)
+ );
+ /*!
+ ensures
+ - performs: add_overlay(overlay_rect(r,p));
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const rectangle& r,
+ pixel_type p,
+ const std::string& l
+ );
+ /*!
+ ensures
+ - performs: add_overlay(overlay_rect(r,p,l));
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const std::vector<rectangle>& r,
+ pixel_type p = rgb_pixel(255,0,0)
+ );
+ /*!
+ ensures
+ - adds the given set of rectangles into this object such
+ that they will be displayed with the color specific by p.
+ !*/
+
+ void add_overlay(
+ const full_object_detection& object,
+ const std::vector<std::string>& part_names
+ );
+ /*!
+ ensures
+ - adds the given full_object_detection to the overlays
+ and shows it on the screen. This includes any of its
+ parts that are not set equal to OBJECT_PART_NOT_PRESENT.
+ - for all valid i < part_names.size():
+ - the part object.part(i) will be labeled with the string
+ part_names[i].
+ !*/
+
+ void add_overlay(
+ const full_object_detection& object
+ );
+ /*!
+ ensures
+ - adds the given full_object_detection to the overlays
+ and shows it on the screen. This includes any of its
+ parts that are not set equal to OBJECT_PART_NOT_PRESENT.
+ !*/
+
+ void add_overlay(
+ const std::vector<full_object_detection>& objects,
+ const std::vector<std::string>& part_names
+ );
+ /*!
+ ensures
+ - calling this function is equivalent to calling the following
+ sequence of functions, for all valid i:
+ - add_overlay(objects[i], part_names);
+ !*/
+
+ void add_overlay(
+ const std::vector<full_object_detection>& objects
+ );
+ /*!
+ ensures
+ - calling this function is equivalent to calling the following
+ sequence of functions, for all valid i:
+ - add_overlay(objects[i]);
+ !*/
+
+ void add_overlay (
+ const overlay_line& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay line into this object such
+ that it will be displayed.
+ !*/
+
+ void add_overlay (
+ const overlay_circle& overlay
+ );
+ /*!
+ ensures
+ - adds the given overlay circle into this object such
+ that it will be displayed.
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const point& p1,
+ const point& p2,
+ pixel_type p
+ );
+ /*!
+ ensures
+ - performs: add_overlay(overlay_line(p1,p2,p));
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_rect>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay rectangles into this object such
+ that they will be displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay lines into this object such
+ that they will be displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_circle>& overlay
+ );
+ /*!
+ ensures
+ - adds the given set of overlay circles into this object such
+ that they will be displayed.
+ !*/
+
+ void clear_overlay (
+ );
+ /*!
+ ensures
+ - removes all overlays from this object.
+ !*/
+
+ void tie_events (
+ );
+ /*!
+ ensures
+ - #events_tied() == true
+ !*/
+
+ void untie_events (
+ );
+ /*!
+ ensures
+ - #events_tied() == false
+ !*/
+
+ bool events_tied (
+ ) const;
+ /*!
+ ensures
+ - returns true if and only if the get_next_double_click() and
+ get_next_keypress() events are tied together. If they are tied it means
+ that you can use a loop of the following form to listen for both events
+ simultaneously:
+ while (mywindow.get_next_double_click(p) || mywindow.get_next_keypress(key,printable))
+ {
+ if (p.x() < 0)
+ // Do something with the keyboard event
+ else
+ // Do something with the mouse event
+ }
+ !*/
+
+ bool get_next_double_click (
+ point& p
+ );
+ /*!
+ ensures
+ - This function blocks until the user double clicks on the image or the
+ window is closed by the user. It will also unblock for a keyboard key
+ press if events_tied() == true.
+ - if (this function returns true) then
+ - This means the user double clicked the mouse.
+ - #p == the next image pixel the user clicked.
+ - else
+ - #p == point(-1,1)
+ !*/
+
+ bool get_next_double_click (
+ point& p,
+ unsigned long& mouse_button
+ );
+ /*!
+ ensures
+ - This function blocks until the user double clicks on the image or the
+ window is closed by the user. It will also unblock for a keyboard key
+ press if events_tied() == true.
+ - if (this function returns true) then
+ - This means the user double clicked the mouse.
+ - #p == the next image pixel the user clicked.
+ - #mouse_button == the mouse button which was used to double click.
+ This will be either dlib::base_window::LEFT,
+ dlib::base_window::MIDDLE, or dlib::base_window::RIGHT
+ - else
+ - #p == point(-1,1)
+ (Note that this point is outside any possible image)
+ !*/
+
+ bool get_next_keypress (
+ unsigned long& key,
+ bool& is_printable,
+ unsigned long& state
+ );
+ /*!
+ ensures
+ - This function blocks until the user presses a keyboard key or the
+ window is closed by the user. It will also unblock for a mouse double
+ click if events_tied() == true.
+ - if (this function returns true) then
+ - This means the user pressed a keyboard key.
+ - The keyboard button press is recorded into #key, #is_printable, and
+ #state. In particular, these variables are populated with the three
+ identically named arguments to the base_window::on_keydown(key,is_printable,state)
+ event.
+ !*/
+
+ bool get_next_keypress (
+ unsigned long& key,
+ bool& is_printable
+ );
+ /*!
+ ensures
+ - This function blocks until the user presses a keyboard key or the
+ window is closed by the user. It will also unblock for a mouse double
+ click if events_tied() == true.
+ - This function is the equivalent to calling get_next_keypress(key,is_printable,temp)
+ and then discarding temp.
+ !*/
+
+ private:
+
+ // restricted functions
+ image_window(image_window&);
+ image_window& operator= (image_window&);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class perspective_display : public drawable, noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for displaying 3D point clouds on a screen. You can
+ navigate the display with the mouse. Left click and drag rotates the
+ camera around the displayed data. Scrolling the mouse wheel zooms and
+ shift+left click (or just right click) and drag pans the view around.
+ !*/
+
+ public:
+
+ perspective_display(
+ drawable_window& w
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #*this has been added to window w
+ - #parent_window() == w
+ !*/
+
+ ~perspective_display(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ ensures
+ - #width() == width
+ - #height() == height
+ - #top() == top()
+ - #left() == left()
+ - i.e. The location of the upper left corner of this widget stays the
+ same but its width and height are modified.
+ !*/
+
+ struct overlay_line
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a line that is drawn on the screen. Each line
+ is represented by its two end points (p1 and p2) as well as a color.
+ !*/
+
+ overlay_line() { assign_pixel(color, 0);}
+
+ overlay_line(const vector<double>& p1_, const vector<double>& p2_)
+ : p1(p1_), p2(p2_) { assign_pixel(color, 255); }
+
+ template <typename pixel_type>
+ overlay_line(const vector<double>& p1_, const vector<double>& p2_, pixel_type p)
+ : p1(p1_), p2(p2_) { assign_pixel(color, p); }
+
+ vector<double> p1;
+ vector<double> p2;
+ rgb_pixel color;
+ };
+
+ struct overlay_dot
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a dot that is drawn on the screen. Each dot is
+ represented by one point and a color.
+ !*/
+
+ overlay_dot() { assign_pixel(color, 0);}
+
+ overlay_dot(const vector<double>& p_)
+ : p(p_) { assign_pixel(color, 255); }
+
+ template <typename pixel_type>
+ overlay_dot(const vector<double>& p_, pixel_type color_)
+ : p(p_) { assign_pixel(color, color_); }
+
+ vector<double> p; // The location of the dot
+ rgb_pixel color;
+ };
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+ /*!
+ ensures
+ - Adds the given overlay lines into this object such that it will be
+ displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_dot>& overlay
+ );
+ /*!
+ ensures
+ - Adds the given overlay dots into this object such that it will be
+ displayed.
+ !*/
+
+ void clear_overlay (
+ );
+ /*!
+ ensures
+ - Removes all overlays from this object. The display will be empty.
+ !*/
+
+ template <typename T>
+ void set_dot_double_clicked_handler (
+ T& object,
+ void (T::*event_handler)(const vector<double>&)
+ );
+ /*
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called on object when the user double
+ clicks on one of the overlay dots. The selected dot will be passed to
+ event_handler().
+ - Any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ */
+
+ void set_dot_double_clicked_handler (
+ const any_function<void(const vector<double>&)>& event_handler
+ );
+ /*
+ ensures
+ - The event_handler function is called when the user double clicks on one
+ of the overlay dots. The selected dot will be passed to event_handler().
+ - Any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ */
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class perspective_window : public drawable_window, noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple window that is just a container for a perspective_display.
+ It exists to make it easy to throw perspective_displays onto the screen
+ without having to put together your own drawable_window objects.
+ !*/
+ public:
+
+ typedef perspective_display::overlay_line overlay_line;
+ typedef perspective_display::overlay_dot overlay_dot;
+
+ perspective_window(
+ );
+ /*!
+ ensures
+ - The window is displayed on the screen and is 100x100 pixels in size.
+ !*/
+
+ perspective_window(
+ const std::vector<dlib::vector<double> >& point_cloud
+ );
+ /*!
+ ensures
+ - The window is displayed on the screen and is 100x100 pixels in size.
+ - This window will have point_cloud added to it via add_overlay() and the
+ points will all be white.
+ !*/
+
+ perspective_window(
+ const std::vector<dlib::vector<double> >& point_cloud,
+ const std::string& title
+ );
+ /*!
+ ensures
+ - The window is displayed on the screen and is 100x100 pixels in size.
+ - This window will have point_cloud added to it via add_overlay() and the
+ points will all be white.
+ - The title of the window will be set to the given title string.
+ !*/
+
+ ~perspective_window(
+ );
+ /*!
+ ensures
+ - any resources associated with this object have been released
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_line>& overlay
+ );
+ /*!
+ ensures
+ - Adds the given overlay lines into this object such that it will be
+ displayed.
+ !*/
+
+ void add_overlay (
+ const std::vector<overlay_dot>& overlay
+ );
+ /*!
+ ensures
+ - Adds the given overlay dots into this object such that it will be
+ displayed.
+ !*/
+
+ void clear_overlay (
+ );
+ /*!
+ ensures
+ - Removes all overlays from this object. The display will be empty.
+ !*/
+
+ void add_overlay(
+ const std::vector<dlib::vector<double> >& d
+ );
+ /*!
+ ensures
+ - Adds the given dots into this object such that it will be
+ displayed. They will be colored white.
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const std::vector<dlib::vector<double> >& d,
+ pixel_type p
+ );
+ /*!
+ ensures
+ - Adds the given dots into this object such that it will be
+ displayed. They will be colored by pixel color p.
+ !*/
+
+ template <typename pixel_type>
+ void add_overlay(
+ const vector<double>& p1,
+ const vector<double>& p2,
+ pixel_type color
+ );
+ /*!
+ ensures
+ - Adds an overlay line going from p1 to p2 with the given color.
+ !*/
+
+ template < typename T >
+ void set_dot_double_clicked_handler (
+ T& object,
+ void (T::*event_handler)(const vector<double>&)
+ );
+ /*
+ requires
+ - event_handler is a valid pointer to a member function in T
+ ensures
+ - The event_handler function is called on object when the user double
+ clicks on one of the overlay dots. The selected dot will be passed to
+ event_handler().
+ - Any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ */
+
+ void set_dot_double_clicked_handler (
+ const any_function<void(const vector<double>&)>& event_handler
+ );
+ /*
+ ensures
+ - The event_handler function is called when the user double clicks on one
+ of the overlay dots. The selected dot will be passed to event_handler().
+ - Any previous calls to this function are overridden by this new call.
+ (i.e. you can only have one event handler associated with this
+ event at a time)
+ */
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_WIDGETs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/hash.h b/ml/dlib/dlib/hash.h
new file mode 100644
index 000000000..5a018b438
--- /dev/null
+++ b/ml/dlib/dlib/hash.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASh_
+#define DLIB_HASh_
+
+
+#include "general_hash/hash.h"
+#include "general_hash/random_hashing.h"
+#include "general_hash/count_bits.h"
+
+
+#endif // DLIB_HASh_
+
+
diff --git a/ml/dlib/dlib/hash_map.h b/ml/dlib/dlib/hash_map.h
new file mode 100644
index 000000000..225ebd466
--- /dev/null
+++ b/ml/dlib/dlib/hash_map.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_MAp_
+#define DLIB_HASH_MAp_
+
+#include "hash_map/hash_map_kernel_1.h"
+#include "hash_map/hash_map_kernel_c.h"
+
+#include "hash_table.h"
+#include "algs.h"
+
+#include "algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class hash_map
+ {
+ hash_map() {}
+
+ typedef typename hash_table<domain,range,mem_manager,compare>::kernel_1a
+ hash_table_1;
+ typedef typename hash_table<domain,range,mem_manager,compare>::kernel_2a
+ hash_table_2;
+ typedef typename hash_table<domain,range,mem_manager,compare>::kernel_2b
+ hash_table_3;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef hash_map_kernel_1<domain,range,expnum,hash_table_1,mem_manager>
+ kernel_1a;
+ typedef hash_map_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_1b
+ typedef hash_map_kernel_1<domain,range,expnum,hash_table_2,mem_manager>
+ kernel_1b;
+ typedef hash_map_kernel_c<kernel_1b>
+ kernel_1b_c;
+
+ // kernel_1c
+ typedef hash_map_kernel_1<domain,range,expnum,hash_table_3,mem_manager>
+ kernel_1c;
+ typedef hash_map_kernel_c<kernel_1c>
+ kernel_1c_c;
+
+
+ };
+}
+
+#endif // DLIB_HASH_MAp_
+
diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_1.h b/ml/dlib/dlib/hash_map/hash_map_kernel_1.h
new file mode 100644
index 000000000..e8b1d6d71
--- /dev/null
+++ b/ml/dlib/dlib/hash_map/hash_map_kernel_1.h
@@ -0,0 +1,460 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_MAP_KERNEl_1_
+#define DLIB_HASH_MAP_KERNEl_1_
+
+#include "hash_map_kernel_abstract.h"
+#include "../algs.h"
+#include "../general_hash/general_hash.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/remover.h"
+#include "../assert.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager = default_memory_manager
+ >
+ class hash_map_kernel_1 : public enumerable<map_pair<domain,range> >,
+ public pair_remover<domain,range>
+ {
+
+ /*!
+ REQUIREMENTS ON hash_table
+ hash_table is instantiated with domain and range and
+ T_is_POD must be set to false and
+ implements hash_table/hash_table_kernel_abstract.h
+
+ INITIAL VALUE
+ table.size() == 0
+
+ CONVENTION
+ table.size() = size() == the number of elements in the map
+ the elements in this hash_map are stored in table
+ !*/
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef typename hash_table::compare_type compare_type;
+ typedef mem_manager mem_manager_type;
+
+ hash_map_kernel_1(
+ ) :
+ table(expnum)
+ {
+ COMPILE_TIME_ASSERT(expnum < 32);
+ }
+
+ virtual ~hash_map_kernel_1(
+ )
+ {}
+
+ inline void clear(
+ );
+
+ void add (
+ domain& d,
+ range& r
+ );
+
+ inline bool is_in_domain (
+ const domain& d
+ ) const;
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ range& operator[] (
+ const domain& d
+ );
+
+ const range& operator[] (
+ const domain& d
+ ) const;
+
+ inline void swap (
+ hash_map_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ domain& d,
+ range& r
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const map_pair<domain,range>& element (
+ ) const;
+
+ inline map_pair<domain,range>& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+ private:
+
+ hash_table table;
+
+ // restricted functions
+ hash_map_kernel_1(hash_map_kernel_1&);
+ hash_map_kernel_1& operator= ( hash_map_kernel_1&);
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ inline void swap (
+ hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>& a,
+ hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void deserialize (
+ hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type hash_map_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ clear (
+ )
+ {
+ table.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+ table.add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ is_in_domain(
+ const domain& d
+ ) const
+ {
+ return (table[d] != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ remove_any (
+ domain& d,
+ range& r
+ )
+ {
+ table.remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ remove(
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ table.remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ destroy(
+ const domain& d
+ )
+ {
+ table.destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ range& hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ operator[](
+ const domain& d
+ )
+ {
+ return *table[d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ const range& hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ operator[](
+ const domain& d
+ ) const
+ {
+ return *table[d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ size_t hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ size (
+ ) const
+ {
+ return table.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ swap (
+ hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>& item
+ )
+ {
+ table.swap(item.table);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ at_start (
+ ) const
+ {
+ return table.at_start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ reset (
+ ) const
+ {
+ table.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return table.current_element_valid();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ const map_pair<domain,range>& hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ element (
+ ) const
+ {
+ return table.element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ map_pair<domain,range>& hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ element (
+ )
+ {
+ return table.element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_map_kernel_1<domain,range,expnum,hash_table,mem_manager>::
+ move_next (
+ ) const
+ {
+ return table.move_next();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_MAP_KERNEl_1_
+
diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h b/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h
new file mode 100644
index 000000000..cee404f54
--- /dev/null
+++ b/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h
@@ -0,0 +1,247 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HASH_MAP_KERNEl_ABSTRACT_
+#ifdef DLIB_HASH_MAP_KERNEl_ABSTRACT_
+
+#include "../general_hash/general_hash.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../interfaces/map_pair.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<T>
+ >
+ class hash_map : public enumerable<map_pair<domain,range> >,
+ public pair_remover<domain,range>
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ domain must be hashable by general_hash
+ (general_hash is defined in dlib/general_hash) and
+ domain must be swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range must be swappable by a global swap() and
+ range must have a default constructor
+
+ REQUIREMENTS ON expnum
+ expnum < 32
+ 2^expnum is the number of buckets to hash items of type T into.
+ Note that this is really just a suggestion to the hash table.
+ Implementations are free to manage the table size however is most
+ appropriate.
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap(), is_in_domain(), and operator[] functions do
+ not invalidate pointers or references to internal data.
+ All other functions have no such guarantees.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ No order is specified. Only that each element will be visited once
+ and only once.
+
+ WHAT THIS OBJECT REPRESENTS
+ hash_map contains items of type domain and range
+
+ This object is similar an array. It maps items of type domain on to
+ items of type range.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ hash_map(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ !*/
+
+ virtual ~hash_map(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - is_in_domain(d) == false
+ ensures
+ - #is_in_domain(d) == true
+ - #operator[](d) == r
+ - #d and #r have initial values for their types
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if add() throws then it has no effect
+ !*/
+
+ bool is_in_domain (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - returns whether or not an element equivalent to d is in the
+ domain of *this
+ !*/
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - &d != &d_copy (i.e. d and d_copy cannot be the same variable)
+ - &r != &d_copy (i.e. r and d_copy cannot be the same variable)
+ - is_in_domain(d) == true
+ ensures
+ - #is_in_domain(d) == false
+ - #d_copy is equivalent to d
+ - the element in the range of *this associated with #d_copy has
+ been swapped into #r
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const domain& d
+ );
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - #is_in_domain(d) == false
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ range& operator[] (
+ const domain& d
+ );
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - returns a non-const reference to the element in the range of *this
+ associated with the element equivalent to d
+ !*/
+
+ const range& operator[] (
+ const domain& d
+ ) const;
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - returns a const reference to the element in the range of *this
+ associated with the element equivalent to d
+ !*/
+
+ void swap (
+ hash_map& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ private:
+
+ // restricted functions
+ hash_map(hash_map&);
+ hash_map& operator=(hash_map&);
+ };
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ hash_map<domain,range,expnum,mem_manager,compare>& a,
+ hash_map<domain,range,expnum,mem_manager,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename domain,
+ typename range,
+ unsigned long expnum,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ hash_map<domain,range,expnum,mem_manager,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_HASH_MAP_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_c.h b/ml/dlib/dlib/hash_map/hash_map_kernel_c.h
new file mode 100644
index 000000000..4db937249
--- /dev/null
+++ b/ml/dlib/dlib/hash_map/hash_map_kernel_c.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_MAP_KERNEl_C_
+#define DLIB_HASH_MAP_KERNEl_C_
+
+#include "hash_map_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename hash_map_base
+ >
+ class hash_map_kernel_c : public hash_map_base
+ {
+
+ typedef typename hash_map_base::domain_type domain;
+ typedef typename hash_map_base::range_type range;
+
+
+ public:
+ void add (
+ domain& d,
+ range& r
+ );
+
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ range& operator[] (
+ const domain& d
+ );
+
+ const range& operator[] (
+ const domain& d
+ ) const;
+
+ const map_pair<domain,range>& element (
+ ) const;
+
+ map_pair<domain,range>& element (
+ );
+ };
+
+ template <
+ typename hash_map_base
+ >
+ inline void swap (
+ hash_map_kernel_c<hash_map_base>& a,
+ hash_map_kernel_c<hash_map_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ void hash_map_kernel_c<hash_map_base>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (!this->is_in_domain(d)) &&
+ (static_cast<void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid hash_map::add"
+ << "\n\tdomain element being added must not already be in the hash_map"
+ << "\n\tand d and r must not be the same variable"
+ << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false")
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ );
+
+
+ // call the real function
+ hash_map_base::add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ void hash_map_kernel_c<hash_map_base>::
+ remove_any (
+ domain& d,
+ range& r
+ )
+ {
+
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->size() > 0) &&
+ (static_cast<void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid hash_map::remove_any"
+ << "\n\tsize() must be greater than zero if something is going to be removed"
+ << "\n\tand d and r must not be the same variable."
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ );
+
+
+ // call the real function
+ hash_map_base::remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ void hash_map_kernel_c<hash_map_base>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->is_in_domain(d)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)) &&
+ (static_cast<void*>(&r) != static_cast<void*>(&d_copy)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&d_copy)),
+ "\tvoid hash_map::remove"
+ << "\n\tcan't remove something that isn't in the hash_map or if the paremeters"
+ << "\n\tare actually the same variable. Either way can't remove."
+ << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false")
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<const void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ << "\n\t&d_copy: " << static_cast<void*>(&d_copy)
+ );
+
+
+ // call the real function
+ hash_map_base::remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ void hash_map_kernel_c<hash_map_base>::
+ destroy (
+ const domain& d
+ )
+ {
+
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_in_domain(d),
+ "\tvoid hash_map::destroy"
+ << "\n\tcan't remove something that isn't in the hash_map"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<const void*>(&d)
+ );
+
+
+ // call the real function
+ hash_map_base::destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ typename hash_map_base::range_type& hash_map_kernel_c<hash_map_base>::
+ operator[] (
+ const domain& d
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_in_domain(d),
+ "\trange& hash_map::operator[]"
+ << "\n\td must be in the domain of the hash_map"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return hash_map_base::operator[](d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ const typename hash_map_base::range_type& hash_map_kernel_c<hash_map_base>::
+ operator[] (
+ const domain& d
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_in_domain(d),
+ "\tconst range& hash_map::operator[]"
+ << "\n\td must be in the domain of the hash_map"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return hash_map_base::operator[](d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ const map_pair<typename hash_map_base::domain_type,typename hash_map_base::range_type>& hash_map_kernel_c<hash_map_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst map_pair<domain,range>& hash_map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return hash_map_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_map_base
+ >
+ map_pair<typename hash_map_base::domain_type,typename hash_map_base::range_type>& hash_map_kernel_c<hash_map_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tmap_pair<domain,range>& hash_map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return hash_map_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_MAP_KERNEl_C_
+
diff --git a/ml/dlib/dlib/hash_set.h b/ml/dlib/dlib/hash_set.h
new file mode 100644
index 000000000..90a51fd20
--- /dev/null
+++ b/ml/dlib/dlib/hash_set.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_SEt_
+#define DLIB_HASH_SEt_
+
+#include "hash_set/hash_set_kernel_1.h"
+#include "hash_set/hash_set_kernel_c.h"
+
+#include "hash_table.h"
+#include "algs.h"
+
+
+#include "algs.h"
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<T>
+ >
+ class hash_set
+ {
+ hash_set() {}
+
+ typedef typename hash_table<T,char,mem_manager,compare>::kernel_1a ht1a;
+ typedef typename hash_table<T,char,mem_manager,compare>::kernel_1a ht2a;
+ typedef typename hash_table<T,char,mem_manager,compare>::kernel_1a ht2b;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef hash_set_kernel_1<T,expnum,ht1a,mem_manager>
+ kernel_1a;
+ typedef hash_set_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_1b
+ typedef hash_set_kernel_1<T,expnum,ht2a,mem_manager>
+ kernel_1b;
+ typedef hash_set_kernel_c<kernel_1b>
+ kernel_1b_c;
+
+ // kernel_1c
+ typedef hash_set_kernel_1<T,expnum,ht2b,mem_manager>
+ kernel_1c;
+ typedef hash_set_kernel_c<kernel_1c>
+ kernel_1c_c;
+
+
+
+
+ };
+}
+
+#endif // DLIB_HASH_SEt_
+
diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_1.h b/ml/dlib/dlib/hash_set/hash_set_kernel_1.h
new file mode 100644
index 000000000..c40770b91
--- /dev/null
+++ b/ml/dlib/dlib/hash_set/hash_set_kernel_1.h
@@ -0,0 +1,391 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_SET_KERNEl_1_
+#define DLIB_HASH_SET_KERNEl_1_
+
+#include "hash_set_kernel_abstract.h"
+#include "../algs.h"
+#include "../general_hash/general_hash.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../assert.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager = default_memory_manager
+ >
+ class hash_set_kernel_1 : public enumerable<const T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON hash_table
+ hash_table is instantiated with <domain=T,range=char> and
+ T_is_POD must be set to false and
+ is an implementation of hash_table/hash_table_kernel_abstract.h
+
+ INITIAL VALUE
+ table.size() == 0
+
+ CONVENTION
+ table.size() = size() == the number of elements in the set and
+ the elements in this hash_set are stored in table
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef typename hash_table::compare_type compare_type;
+ typedef mem_manager mem_manager_type;
+
+ hash_set_kernel_1(
+ ) :
+ table(expnum)
+ {
+ COMPILE_TIME_ASSERT(expnum < 32);
+ }
+
+ virtual ~hash_set_kernel_1(
+ )
+ {}
+
+ inline void clear(
+ );
+
+ inline void add (
+ T& item
+ );
+
+ inline bool is_member (
+ const T& item
+ ) const;
+
+ inline void remove (
+ const T& item,
+ T& item_copy
+ );
+
+ inline void destroy (
+ const T& item
+ );
+
+ inline void swap (
+ hash_set_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline const T& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+ private:
+
+ hash_table table;
+ char junk;
+
+ // restricted functions
+ hash_set_kernel_1(hash_set_kernel_1&);
+ hash_set_kernel_1& operator= ( hash_set_kernel_1&);
+
+ };
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ inline void swap (
+ hash_set_kernel_1<T,expnum,hash_table,mem_manager>& a,
+ hash_set_kernel_1<T,expnum,hash_table,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void deserialize (
+ hash_set_kernel_1<T,expnum,hash_table,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.add(temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type hash_set_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ clear (
+ )
+ {
+ table.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ add (
+ T& item
+ )
+ {
+ table.add(item,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ is_member(
+ const T& item
+ ) const
+ {
+ return (table[item] != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ table.remove_any(item,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ remove(
+ const T& item,
+ T& item_copy
+ )
+ {
+ table.remove(item,item_copy,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ destroy(
+ const T& item
+ )
+ {
+ table.destroy(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ size_t hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ size (
+ ) const
+ {
+ return table.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ swap (
+ hash_set_kernel_1<T,expnum,hash_table,mem_manager>& item
+ )
+ {
+ table.swap(item.table);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ at_start (
+ ) const
+ {
+ return table.at_start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ void hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ reset (
+ ) const
+ {
+ table.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return table.current_element_valid();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ const T& hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ element (
+ ) const
+ {
+ return table.element().key();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ const T& hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ element (
+ )
+ {
+ return table.element().key();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename hash_table,
+ typename mem_manager
+ >
+ bool hash_set_kernel_1<T,expnum,hash_table,mem_manager>::
+ move_next (
+ ) const
+ {
+ return table.move_next();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_SET_KERNEl_1_
+
diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h b/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h
new file mode 100644
index 000000000..35d1de74e
--- /dev/null
+++ b/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h
@@ -0,0 +1,207 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HASH_SET_KERNEl_ABSTRACT_
+#ifdef DLIB_HASH_SET_KERNEl_ABSTRACT_
+
+#include "../general_hash/general_hash.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<T>
+ >
+ class hash_set : public enumerable<const T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ T must be hashable by general_hash
+ (general_hash is defined in dlib/general_hash) and
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON expnum
+ expnum < 32
+ 2^expnum is the number of buckets to hash items of type T into.
+ Note that this is really just a suggestion to the hash table.
+ Implementations are free to manage the table size however is most
+ appropriate.
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and is_member() functions do not invalidate
+ pointers or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ No order is specified. Only that each element will be visited once
+ and only once.
+
+ WHAT THIS OBJECT REPRESENTS
+ hash_set contains items of type T
+
+ This object represents an unaddressed collection
+ of items. Every element in a hash_set is unique.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ hash_set(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ virtual ~hash_set(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ T& item
+ );
+ /*!
+ requires
+ - is_member(item) == false
+ ensures
+ - #is_member(item) == true
+ - #item has an initial value for its type
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if add() throws then it has no effect
+ !*/
+
+ bool is_member (
+ const T& item
+ ) const;
+ /*!
+ ensures
+ - returns whether or not there is an element in *this equivalent
+ to item
+ !*/
+
+ void remove (
+ const T& item,
+ T& item_copy
+ );
+ /*!
+ requires
+ - is_member(item) == true
+ - &item != &item_copy (i.e. item and item_copy cannot be the
+ same variable)
+ ensures
+ - #is_member(item) == false
+ - the element in *this equivalent to item has been removed and
+ swapped into #item_copy
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const T& item
+ );
+ /*!
+ requires
+ - is_member(item) == true
+ ensures
+ - #is_member(item) == false
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void swap (
+ hash_set& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ hash_set(hash_set&); // copy constructor
+ hash_set& operator=(hash_set&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ hash_set<T,expnum,mem_manager,compare>& a,
+ hash_set<T,expnum,mem_manager,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ unsigned long expnum,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ hash_set<T,expnum,mem_manager,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_HASH_SET_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_c.h b/ml/dlib/dlib/hash_set/hash_set_kernel_c.h
new file mode 100644
index 000000000..bd0abd848
--- /dev/null
+++ b/ml/dlib/dlib/hash_set/hash_set_kernel_c.h
@@ -0,0 +1,190 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_SET_KERNEl_C_
+#define DLIB_HASH_SET_KERNEl_C_
+
+#include "hash_set_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename hash_set_base
+ >
+ class hash_set_kernel_c : public hash_set_base
+ {
+ typedef typename hash_set_base::type T;
+ public:
+
+ void add (
+ T& item
+ );
+
+ void remove_any (
+ T& item
+ );
+
+ void remove (
+ const T& item,
+ T& item_copy
+ );
+
+ void destroy (
+ const T& item
+ );
+
+ const T& element (
+ ) const;
+
+ const T& element (
+ );
+
+
+ };
+
+
+ template <
+ typename hash_set_base
+ >
+ inline void swap (
+ hash_set_kernel_c<hash_set_base>& a,
+ hash_set_kernel_c<hash_set_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ void hash_set_kernel_c<hash_set_base>::
+ add(
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !this->is_member(item),
+ "\tvoid hash_set::add"
+ << "\n\titem being added must not already be in the hash_set"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ hash_set_base::add(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ void hash_set_kernel_c<hash_set_base>::
+ remove (
+ const T& item,
+ T& item_copy
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_member(item) &&
+ (static_cast<const void*>(&item) != static_cast<void*>(&item_copy)),
+ "\tvoid hash_set::remove"
+ << "\n\titem should be in the hash_set if it's going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&item: " << &item
+ << "\n\t&item_copy: " << &item_copy
+ );
+
+ // call the real function
+ hash_set_base::remove(item,item_copy);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ void hash_set_kernel_c<hash_set_base>::
+ destroy (
+ const T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_member(item),
+ "\tvoid hash_set::destroy"
+ << "\n\titem should be in the hash_set if it's going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&item: " << &item
+ );
+
+ // call the real function
+ hash_set_base::destroy(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ void hash_set_kernel_c<hash_set_base>::
+ remove_any (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->size() != 0,
+ "\tvoid hash_set::remove_any"
+ << "\n\tsize must be greater than zero if an item is to be removed"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ hash_set_base::remove_any(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ const typename hash_set_base::type& hash_set_kernel_c<hash_set_base>::
+ element (
+ ) const
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& hash_set::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return hash_set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hash_set_base
+ >
+ const typename hash_set_base::type& hash_set_kernel_c<hash_set_base>::
+ element (
+ )
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tT& hash_set::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return hash_set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_SET_KERNEl_C_
+
diff --git a/ml/dlib/dlib/hash_table.h b/ml/dlib/dlib/hash_table.h
new file mode 100644
index 000000000..b49feb916
--- /dev/null
+++ b/ml/dlib/dlib/hash_table.h
@@ -0,0 +1,60 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_TABLe_
+#define DLIB_HASH_TABLe_
+
+
+#include "hash_table/hash_table_kernel_1.h"
+#include "hash_table/hash_table_kernel_2.h"
+#include "hash_table/hash_table_kernel_c.h"
+#include "algs.h"
+
+#include "binary_search_tree.h"
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class hash_table
+ {
+ hash_table() {}
+
+ typedef typename binary_search_tree<domain,range,mem_manager,compare>::kernel_1a
+ bst_1;
+ typedef typename binary_search_tree<domain,range,mem_manager,compare>::kernel_2a
+ bst_2;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef hash_table_kernel_1<domain,range,mem_manager,compare>
+ kernel_1a;
+ typedef hash_table_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef hash_table_kernel_2<domain,range,bst_1,mem_manager,compare>
+ kernel_2a;
+ typedef hash_table_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+ // kernel_2b
+ typedef hash_table_kernel_2<domain,range,bst_2,mem_manager,compare>
+ kernel_2b;
+ typedef hash_table_kernel_c<kernel_2b>
+ kernel_2b_c;
+ };
+}
+
+#endif // DLIB_HASH_TABLe_
+
diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_1.h b/ml/dlib/dlib/hash_table/hash_table_kernel_1.h
new file mode 100644
index 000000000..f06847a72
--- /dev/null
+++ b/ml/dlib/dlib/hash_table/hash_table_kernel_1.h
@@ -0,0 +1,819 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_TABLE_KERNEl_1_
+#define DLIB_HASH_TABLE_KERNEl_1_
+
+#include "hash_table_kernel_abstract.h"
+#include "../general_hash/general_hash.h"
+#include "../algs.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../assert.h"
+#include "../serialize.h"
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class hash_table_kernel_1 : public enumerable<map_pair<domain, range> >,
+ public pair_remover<domain,range>
+ {
+
+ /*!
+ INITIAL VALUE
+ hash_size == 0
+ table == pointer to an array of num_of_buckets node pointers
+ num_of_buckets == the number of buckets in the hash table
+ current_element == 0
+ at_start_ == true
+ mask == num_of_buckets-1
+
+ CONVENTION
+ current_element_valid() == (current_element != 0)
+ element() == current_element->d and current_element->r
+ at_start_ == at_start()
+ if (current_element != 0) then
+ table[current_bucket] == a pointer to the linked list that contains
+ the node pointed to by current_element
+
+ mask == num_of_buckets-1
+
+
+
+ hash_size = size() == the number of elements in the hash_table and
+ table == pointer to an array of num_of_buckets node pointers and
+ num_of_buckets == the number of buckets in the hash table and
+ for all i:
+ table[i] == pointer to the first node in a linked list or
+ table[i] == 0 if this bucket is currently not in use
+
+
+ for all nodes:
+ d == the domain element stored in this node
+ r == the range element stored in this node which is associated with
+ d.
+ next == pointer to the next node in the linked list or
+ next == 0 if this is the last node in the linked list
+
+ !*/
+
+ struct node
+ {
+ node* next;
+ domain d;
+ range r;
+ };
+
+
+ class mpair : public map_pair<domain,range>
+ {
+ public:
+ const domain* d;
+ range* r;
+
+ const domain& key(
+ ) const { return *d; }
+
+ const range& value(
+ ) const { return *r; }
+
+ range& value(
+ ) { return *r; }
+ };
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ explicit hash_table_kernel_1(
+ unsigned long expnum
+ );
+
+ virtual ~hash_table_kernel_1(
+ );
+
+ void clear(
+ );
+
+ unsigned long count (
+ const domain& item
+ ) const;
+
+ void add (
+ domain& d,
+ range& r
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ const range* operator[] (
+ const domain& d
+ ) const;
+
+ range* operator[] (
+ const domain& d
+ );
+
+ void swap (
+ hash_table_kernel_1& item
+ );
+
+ // functions from the remover interface
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const map_pair<domain,range>& element (
+ ) const;
+
+ map_pair<domain,range>& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ // data members
+ typename mem_manager::template rebind<node>::other pool;
+ typename mem_manager::template rebind<node*>::other ppool;
+ unsigned long hash_size;
+ node** table;
+ general_hash<domain> hash;
+ unsigned long num_of_buckets;
+ unsigned long mask;
+
+ mutable mpair p;
+
+ mutable unsigned long current_bucket;
+ mutable node* current_element;
+ mutable bool at_start_;
+ compare comp;
+
+ // restricted functions
+ hash_table_kernel_1(hash_table_kernel_1&);
+ hash_table_kernel_1& operator=(hash_table_kernel_1&);
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ hash_table_kernel_1<domain,range,mem_manager,compare>& a,
+ hash_table_kernel_1<domain,range,mem_manager,compare>& b
+ ) { a.swap(b); }
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ hash_table_kernel_1<domain,range,mem_manager,compare>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type hash_table_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ hash_table_kernel_1<domain,range,mem_manager,compare>::
+ hash_table_kernel_1(
+ unsigned long expnum
+ ) :
+ hash_size(0),
+ current_element(0),
+ at_start_(true)
+ {
+
+ num_of_buckets = 1;
+ while (expnum != 0)
+ {
+ --expnum;
+ num_of_buckets <<= 1;
+ }
+ mask = num_of_buckets-1;
+
+ table = ppool.allocate_array(num_of_buckets);
+ for (unsigned long i = 0; i < num_of_buckets; ++i)
+ {
+ table[i] = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ hash_table_kernel_1<domain,range,mem_manager,compare>::
+ ~hash_table_kernel_1(
+ )
+ {
+ for (unsigned long i = 0; i < num_of_buckets; ++i)
+ {
+ // delete this linked list
+ node* temp = table[i];
+ while (temp)
+ {
+ node* t = temp;
+ temp = temp->next;
+ pool.deallocate(t);
+ }
+ table[i] = 0;
+ }
+ ppool.deallocate_array(table);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ clear(
+ )
+ {
+ if (hash_size > 0)
+ {
+ for (unsigned long i = 0; i < num_of_buckets; ++i)
+ {
+ // delete this linked list
+ node* temp = table[i];
+ while (temp)
+ {
+ node* t = temp;
+ temp = temp->next;
+ pool.deallocate(t);
+ }
+ table[i] = 0;
+ }
+ hash_size = 0;
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ size_t hash_table_kernel_1<domain,range,mem_manager,compare>::
+ size(
+ ) const
+ {
+ return hash_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long hash_table_kernel_1<domain,range,mem_manager,compare>::
+ count(
+ const domain& d
+ ) const
+ {
+ unsigned long items_found = 0;
+ node* temp = table[hash(d)&mask];
+
+ while (temp != 0)
+ {
+ // look for an element equivalent to d
+ if ( !(comp(temp->d , d) || comp(d , temp->d)) )
+ {
+ ++items_found;
+ }
+ temp = temp->next;
+ }
+
+ return items_found;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ add(
+ domain& d,
+ range& r
+ )
+ {
+ unsigned long hash_value = hash(d)&mask;
+
+ // make a new node for this item
+ node& temp = *(pool.allocate());
+ exchange(d,temp.d);
+ exchange(r,temp.r);
+
+ // add this new node to the head of the linked list in bucket number hash_value
+ temp.next = table[hash_value];
+ table[hash_value] = &temp;
+
+ ++hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ destroy(
+ const domain& d
+ )
+ {
+ node* last;
+ const unsigned long hash_value = hash(d)&mask;
+ node* temp = table[hash_value];
+
+ // if there is more than one thing in this bucket
+ if (temp->next != 0)
+ {
+ // start looking with the second item in the list
+ last = temp;
+ temp = temp->next;
+ while (true)
+ {
+ // if we hit the end of the list without finding item then it must
+ // be the first element in the list so splice it out
+ if (temp == 0)
+ {
+ temp = table[hash_value];
+ table[hash_value] = temp->next;
+
+ break;
+ }
+
+ // look for an element equivalent to item
+ if ( !(comp(temp->d , d) || comp(d , temp->d)) )
+ {
+ // splice out the node we want to remove
+ last->next = temp->next;
+ break;
+ }
+
+ last = temp;
+ temp = temp->next;
+ }
+
+ }
+ // else there is only one node in this linked list
+ else
+ {
+ table[hash_value] = 0;
+ }
+
+ pool.deallocate(temp);
+
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ remove(
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ node* last;
+ const unsigned long hash_value = hash(d)&mask;
+ node* temp = table[hash_value];
+
+ // if there is more than one thing in this bucket
+ if (temp->next != 0)
+ {
+ // start looking with the second item in the list
+ last = temp;
+ temp = temp->next;
+ while (true)
+ {
+ // if we hit the end of the list without finding item then it must
+ // be the first element in the list so splice it out
+ if (temp == 0)
+ {
+ temp = table[hash_value];
+ table[hash_value] = temp->next;
+
+ break;
+ }
+
+ // look for an element equivalent to item
+ if ( !(comp(temp->d , d) || comp(d , temp->d)) )
+ {
+ // splice out the node we want to remove
+ last->next = temp->next;
+ break;
+ }
+
+ last = temp;
+ temp = temp->next;
+ }
+
+ }
+ // else there is only one node in this linked list
+ else
+ {
+ table[hash_value] = 0;
+ }
+
+
+ exchange(d_copy,temp->d);
+ exchange(r,temp->r);
+ pool.deallocate(temp);
+
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ remove_any(
+ domain& d,
+ range& r
+ )
+ {
+ unsigned long i = 0;
+
+ // while the ith bucket is empty keep looking
+ while (table[i] == 0)
+ {
+ ++i;
+ }
+
+ // remove the first node in the linked list in the ith bucket
+ node& temp = *(table[i]);
+
+ exchange(temp.d,d);
+ exchange(temp.r,r);
+ table[i] = temp.next;
+
+ pool.deallocate(&temp);
+
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const range* hash_table_kernel_1<domain,range,mem_manager,compare>::
+ operator[](
+ const domain& d
+ ) const
+ {
+ node* temp = table[hash(d)&mask];
+
+ while (temp != 0)
+ {
+ // look for an element equivalent to item
+ if ( !(comp(temp->d , d) || comp(d , temp->d)) )
+ return &(temp->r);
+
+ temp = temp->next;
+ }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ range* hash_table_kernel_1<domain,range,mem_manager,compare>::
+ operator[](
+ const domain& d
+ )
+ {
+ node* temp = table[hash(d)&mask];
+
+ while (temp != 0)
+ {
+ // look for an element equivalent to item
+ if ( !(comp(temp->d , d) || comp(d , temp->d)) )
+ return &(temp->r);
+
+ temp = temp->next;
+ }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ swap(
+ hash_table_kernel_1<domain,range,mem_manager,compare>& item
+ )
+ {
+ exchange(mask,item.mask);
+ exchange(table,item.table);
+ exchange(hash_size,item.hash_size);
+ exchange(num_of_buckets,item.num_of_buckets);
+ exchange(current_bucket,item.current_bucket);
+ exchange(current_element,item.current_element);
+ exchange(at_start_,item.at_start_);
+ pool.swap(item.pool);
+ ppool.swap(item.ppool);
+ exchange(p,item.p);
+ exchange(comp,item.comp);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_1<domain,range,mem_manager,compare>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_1<domain,range,mem_manager,compare>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_1<domain,range,mem_manager,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ const map_pair<domain,range>& hash_table_kernel_1<domain,range,mem_manager,compare>::
+ element (
+ ) const
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ map_pair<domain,range>& hash_table_kernel_1<domain,range,mem_manager,compare>::
+ element (
+ )
+ {
+ p.d = &(current_element->d);
+ p.r = &(current_element->r);
+ return p;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_1<domain,range,mem_manager,compare>::
+ move_next (
+ ) const
+ {
+ if (at_start_)
+ {
+ at_start_ = false;
+ // if the queue is empty then there is nothing to do
+ if (hash_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the first element in the hash table
+ for (current_bucket = 0; true ; ++current_bucket)
+ {
+ if (table[current_bucket] != 0)
+ {
+ current_element = table[current_bucket];
+ break;
+ }
+ }
+ return true;
+ }
+ }
+ else
+ {
+ // if we have already enumerated every element
+ if (current_element == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the next element if it exists
+ if (current_element->next != 0)
+ {
+ current_element = current_element->next;
+ return true;
+ }
+ else
+ {
+ // find next bucket with something in it
+ for (current_bucket+=1; current_bucket<num_of_buckets; ++current_bucket)
+ {
+ if (table[current_bucket] != 0)
+ {
+ // we just found the next bucket
+ current_element = table[current_bucket];
+ break;
+ }
+ }
+ // make sure we actually found another nonempty bucket
+ if (current_bucket == num_of_buckets)
+ {
+ // we didn't find anything
+ current_element = 0;
+ return false;
+ }
+ else
+ {
+ // we found another bucket
+ return true;
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_TABLE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_2.h b/ml/dlib/dlib/hash_table/hash_table_kernel_2.h
new file mode 100644
index 000000000..58646413c
--- /dev/null
+++ b/ml/dlib/dlib/hash_table/hash_table_kernel_2.h
@@ -0,0 +1,612 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_TABLE_KERNEl_2_
+#define DLIB_HASH_TABLE_KERNEl_2_
+
+#include "hash_table_kernel_abstract.h"
+#include "../general_hash/general_hash.h"
+#include "../algs.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../assert.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class hash_table_kernel_2 : public enumerable<map_pair<domain,range> >,
+ public pair_remover<domain,range>
+ {
+
+ /*!
+ REQUIREMENTS ON bst_base
+ bst_base is instantiated with domain and range and
+ implements binray_search_tree/binary_search_tree_kernel_abstract.h
+
+ INITIAL VALUE
+ hash_size == 0
+ table == pointer to an array of num_of_buckets bst_base objects
+ num_of_buckets == the number of buckets in the hash table
+ current_bucket == 0
+ at_start_ == true
+
+ CONVENTION
+ current_element_valid() == (current_bucket != 0)
+ element() == current_bucket->element()
+ at_start_ == at_start()
+
+ mask == num_of_buckets-1
+
+ for all integers i where &table[i] != current_bucket
+ table[i].at_start() == true
+
+
+ hash_size = size() == the number of elements in the hash_table and
+ table == pointer to an array of num_of_buckets bst_base objects
+ num_of_buckets == the number of buckets in the hash table and
+ the elements in this hash table are stored in the bst_base objects in the
+ array table
+
+ !*/
+
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ explicit hash_table_kernel_2(
+ unsigned long expnum
+ );
+
+ virtual ~hash_table_kernel_2(
+ )
+ { pool.deallocate_array(table); }
+
+ void clear(
+ );
+
+ unsigned long count (
+ const domain& item
+ ) const;
+
+ inline void add (
+ domain& d,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ const range* operator[] (
+ const domain& item
+ ) const;
+
+ range* operator[] (
+ const domain& item
+ );
+
+ inline void swap (
+ hash_table_kernel_2& item
+ );
+
+ // functions from the remover interface
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ inline const map_pair<domain,range>& element (
+ ) const;
+
+ inline map_pair<domain,range>& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ // data members
+ typename mem_manager::template rebind<bst_base>::other pool;
+ unsigned long mask;
+ unsigned long hash_size;
+ unsigned long num_of_buckets;
+ bst_base* table;
+ general_hash<domain> hash;
+ mutable bst_base* current_bucket;
+ mutable bool at_start_;
+ compare comp;
+
+ // restricted functions
+ hash_table_kernel_2(hash_table_kernel_2&);
+ hash_table_kernel_2& operator=(hash_table_kernel_2&);
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>& a,
+ hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>& b
+ ) { a.swap(b); }
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type hash_table_kernel_2");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ hash_table_kernel_2(
+ unsigned long expnum
+ ) :
+ hash_size(0),
+ current_bucket(0),
+ at_start_(true)
+ {
+
+ num_of_buckets = 1;
+ while (expnum != 0)
+ {
+ --expnum;
+ num_of_buckets <<= 1;
+ }
+ mask = num_of_buckets-1;
+
+ table = pool.allocate_array(num_of_buckets);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ clear(
+ )
+ {
+ if (hash_size != 0)
+ {
+ hash_size = 0;
+ for (unsigned long i = 0; i < num_of_buckets; ++i)
+ table[i].clear();
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ size_t hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ size(
+ ) const
+ {
+ return hash_size;
+ }
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ unsigned long hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ count(
+ const domain& item
+ ) const
+ {
+ return table[hash(item)&mask].count(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ destroy(
+ const domain& item
+ )
+ {
+ table[hash(item)&mask].destroy(item);
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ add(
+ domain& d,
+ range& r
+ )
+ {
+ table[hash(d)&mask].add(d,r);
+ ++hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ remove(
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ table[hash(d)&mask].remove(d,d_copy,r);
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ remove_any(
+ domain& d,
+ range& r
+ )
+ {
+ unsigned long i = 0;
+ while (table[i].size() == 0)
+ {
+ ++i;
+ }
+ table[i].remove_any(d,r);
+ --hash_size;
+
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ const range* hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ operator[](
+ const domain& d
+ ) const
+ {
+ return table[hash(d)&mask][d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ range* hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ operator[](
+ const domain& d
+ )
+ {
+ return table[hash(d)&mask][d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ swap(
+ hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>& item
+ )
+ {
+ pool.swap(item.pool);
+ exchange(mask,item.mask);
+ exchange(hash_size,item.hash_size);
+ exchange(num_of_buckets,item.num_of_buckets);
+ exchange(table,item.table);
+ exchange(current_bucket,item.current_bucket);
+ exchange(at_start_,item.at_start_);
+ exchange(comp,item.comp);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ void hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ if (current_bucket != 0)
+ {
+ current_bucket->reset();
+ current_bucket = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (current_bucket != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ const map_pair<domain,range>& hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ element (
+ ) const
+ {
+ return current_bucket->element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ map_pair<domain,range>& hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ element (
+ )
+ {
+ return current_bucket->element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager,
+ typename compare
+ >
+ bool hash_table_kernel_2<domain,range,bst_base,mem_manager,compare>::
+ move_next (
+ ) const
+ {
+ if (at_start_)
+ {
+ at_start_ = false;
+ // if the queue is empty then there is nothing to do
+ if (hash_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the first element in the hash table
+ current_bucket = table;
+ while (current_bucket->size() == 0)
+ {
+ ++current_bucket;
+ }
+
+ current_bucket->move_next();
+
+ return true;
+ }
+ }
+ else
+ {
+ // if we have already enumerated every element
+ if (current_bucket == 0)
+ {
+ return false;
+ }
+ else
+ {
+ if (current_bucket->move_next())
+ {
+ // if there is another element in this current bucket then use that
+ return true;
+ }
+ else
+ {
+ // find the next bucket
+ bst_base* end = table + num_of_buckets;
+ current_bucket->reset();
+
+ while (true)
+ {
+ ++current_bucket;
+ // if we ran out of buckets and didn't find anything
+ if (current_bucket == end)
+ {
+ current_bucket = 0;
+ return false;
+ }
+ if (current_bucket->size() > 0)
+ {
+ current_bucket->move_next();
+ return true;
+ }
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_TABLE_KERNEl_2_
+
diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h b/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h
new file mode 100644
index 000000000..52480e91c
--- /dev/null
+++ b/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h
@@ -0,0 +1,253 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HASH_TABLE_KERNEl_ABSTRACT_
+#ifdef DLIB_HASH_TABLE_KERNEl_ABSTRACT_
+
+#include "../interfaces/map_pair.h"
+#include "../general_hash/general_hash.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class hash_table : public enumerable<map_pair<domain,range> >,
+ public pair_remover<domain,range>
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ domain must be hashable by general_hash
+ (general_hash is defined in dlib/general_hash) and
+ domain must be swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range must be swappable by a global swap() and
+ range must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap(), count(), and operator[] functions do
+ not invalidate pointers or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ No order is specified. Only that each element will be visited once
+ and only once.
+
+ WHAT THIS OBJECT REPRESENTS
+ hash_table contains items of type T
+
+ This object represents a data dictionary that is built on top of some
+ kind of hash table. The number of buckets in the hash table is
+ defined by the constructor argument and is some power of 2.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ NOTE:
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ explicit hash_table(
+ unsigned long expnum
+ );
+ /*!
+ requires
+ - expnum < 32
+ ensures
+ - #*this is properly initialized
+ - #*this will use 2^expnum as a suggestion for the initial number
+ of buckets.
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ !*/
+
+ virtual ~hash_table(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ unsigned long count (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements in the domain of *this that are
+ equivalent to d
+ !*/
+
+ void add (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ ensures
+ - adds a mapping between d and r to *this
+ - if (count(d) == 0) then
+ - #*(*this)[d] == r
+ - else
+ - #(*this)[d] != 0
+ - #d and #r have initial values for their types
+ - #count(d) == count(d) + 1
+ - #at_start() == true
+ - #size() == size() + 1
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if add() throws then it has no effect
+ !*/
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - (*this)[d] != 0
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - &d != &d_copy (i.e. d and d_copy cannot be the same variable)
+ - &r != &d_copy (i.e. r and d_copy cannot be the same variable)
+ ensures
+ - some element in the domain of *this that is equivalent to d has
+ been removed and swapped into #d_copy. Additionally, its
+ associated range element has been removed and swapped into #r.
+ - #count(d) = count(d) - 1
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const domain& d
+ );
+ /*!
+ requires
+ - (*this)[d] != 0
+ ensures
+ - an element in the domain of *this equivalent to d has been removed.
+ The element in the range of *this associated with d has also been
+ removed.
+ - #count(d) == count(d) - 1
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ const range* operator[] (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ range* operator[] (
+ const domain& d
+ );
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ void swap (
+ hash_table& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ hash_table(hash_table&);
+ hash_table& operator=(hash_table&);
+
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager
+ >
+ inline void swap (
+ hash_table<domain,range,mem_manager>& a,
+ hash_table<domain,range,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager
+ >
+ void deserialize (
+ hash_table<domain,range,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_HASH_TABLE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_c.h b/ml/dlib/dlib/hash_table/hash_table_kernel_c.h
new file mode 100644
index 000000000..93ef16d0b
--- /dev/null
+++ b/ml/dlib/dlib/hash_table/hash_table_kernel_c.h
@@ -0,0 +1,194 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASH_TABLE_KERNEl_C_
+#define DLIB_HASH_TABLE_KERNEl_C_
+
+#include "hash_table_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/map_pair.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename ht_base
+ >
+ class hash_table_kernel_c : public ht_base
+ {
+ typedef typename ht_base::domain_type domain;
+ typedef typename ht_base::range_type range;
+ public:
+
+ explicit hash_table_kernel_c (
+ unsigned long expnum
+ ) :
+ ht_base(expnum)
+ {
+ DLIB_CASSERT(expnum < 32,
+ "\thash_table::hash_table(unsigned long)"
+ << "\n\tyou can't set expnum >= 32"
+ << "\n\tthis: " << this
+ << "\n\texpnum: " << expnum
+ );
+ }
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ void add (
+ domain& d,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ const map_pair<domain,range>& element (
+ ) const
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst map_pair<domain,range>& hash_table::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return ht_base::element();
+ }
+
+ map_pair<domain,range>& element (
+ )
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tmap_pair<domain,range>& hash_table::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return ht_base::element();
+ }
+
+
+ };
+
+
+ template <
+ typename ht_base
+ >
+ inline void swap (
+ hash_table_kernel_c<ht_base>& a,
+ hash_table_kernel_c<ht_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ht_base
+ >
+ void hash_table_kernel_c<ht_base>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->operator[](d) != 0 &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&d_copy)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)) &&
+ (static_cast<const void*>(&r) != static_cast<void*>(&d_copy)),
+ "\tvoid binary_search_tree::remove"
+ << "\n\tthe element must be in the table for it to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&d_copy: " << &d_copy
+ << "\n\t&r: " << &r
+ );
+
+ ht_base::remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ht_base
+ >
+ void hash_table_kernel_c<ht_base>::
+ add(
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT( static_cast<const void*>(&d) != static_cast<void*>(&r),
+ "\tvoid binary_search_tree::add"
+ << "\n\tyou can't call add() and give the same object to both arguments."
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&r: " << &r
+ << "\n\tsize(): " << this->size()
+ );
+
+ ht_base::add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ht_base
+ >
+ void hash_table_kernel_c<ht_base>::
+ destroy(
+ const domain& d
+ )
+ {
+ DLIB_CASSERT((*this)[d] != 0,
+ "\tvoid hash_table::destroy"
+ << "\n\tthe element must be in the table for it to be destroyed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ );
+
+ ht_base::destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ht_base
+ >
+ void hash_table_kernel_c<ht_base>::
+ remove_any(
+ domain& d,
+ range& r
+ )
+ {
+ DLIB_CASSERT(this->size() != 0 &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid hash_table::remove_any"
+ << "\n\ttable must not be empty if something is going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << &d
+ << "\n\t&r: " << &r
+ );
+
+ ht_base::remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASH_TABLE_KERNEl_C_
+
diff --git a/ml/dlib/dlib/http_client/http_client.cpp b/ml/dlib/dlib/http_client/http_client.cpp
new file mode 100644
index 000000000..75e838a73
--- /dev/null
+++ b/ml/dlib/dlib/http_client/http_client.cpp
@@ -0,0 +1,743 @@
+
+
+#include "../sockets.h"
+#include "../string.h"
+#include "../logger.h"
+#include "../sockstreambuf.h"
+#include "../timeout.h"
+#include "http_client.h"
+#include <time.h>
+#include <stdio.h>
+#include <fstream>
+#include <sstream>
+#include <iostream>
+
+namespace dlib
+{
+
+ typedef std::shared_ptr<dlib::timeout> timeout_ptr;
+
+
+#ifdef _MSC_VER
+#define BR_CASECMP strnicmp
+#else
+#define BR_CASECMP strncasecmp
+#endif
+// Default timeout after 60 seconds
+#define DEFAULT_TIMEOUT 60000
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool isXdigit( char c )
+ {
+ return (c >= '0' && c <= '9') ||
+ (c >= 'A' && c <= 'Z') ||
+ (c >= 'a' && c <= 'z');
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::urldecode( const std::string& s )
+ {
+ std::stringstream ss;
+
+ for ( char const * p_read = s.c_str(), * p_end = (s.c_str() + s.size()); p_read < p_end; p_read++ )
+ {
+ if ( p_read[0] == '%' && p_read+1 != p_end && p_read+2 != p_end && isXdigit(p_read[1]) && isXdigit(p_read[2]) )
+ {
+ ss << static_cast<char>((( (p_read[1] & 0xf) + ((p_read[1] >= 'A') ? 9 : 0) ) << 4 ) | ( (p_read[2] & 0xf) + ((p_read[2] >= 'A') ? 9 : 0) ));
+ p_read += 2;
+ }
+ else if ( p_read[0] == '+' )
+ {
+ // Undo the encoding that replaces spaces with plus signs.
+ ss << ' ';
+ }
+ else
+ {
+ ss << p_read[0];
+ }
+ }
+
+ return ss.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+//! \return modified string ``s'' with spaces trimmed from left
+ inline std::string& triml(std::string& s)
+ {
+ int pos(0);
+ for ( ; s[pos] == ' ' || s[pos] == '\t' || s[pos] == '\r' || s[pos] == '\n' ; ++pos );
+ s.erase(0, pos);
+ return s;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+//! \return modified string ``s'' with spaces trimmed from right
+ inline std::string& trimr(std::string& s)
+ {
+ int pos(s.size());
+ for ( ; pos && (s[pos-1] == ' ' || s[pos-1] == '\t' || s[pos-1] == '\r' || s[pos-1] == '\n') ; --pos );
+ s.erase(pos, s.size()-pos);
+ return s;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+//! \return modified string ``s'' with spaces trimmed from edges
+ inline std::string& trim(std::string& s)
+ {
+ return triml(trimr(s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ http_client::
+ http_client(
+ ) :
+ http_return(0),
+ timeout(DEFAULT_TIMEOUT),
+ OnDownload(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::get_header(const std::string& header_name) const
+ {
+ stringmap::const_iterator ci = headers.find(header_name);
+ return ci != headers.end() ? ci->second : std::string();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::set_header(const std::string& header_name, long header_value)
+ {
+ char buf[21] = { 0 };
+#ifdef __WXMSW__
+ ::ltoa(header_value, buf, 10);
+#else
+ sprintf(buf, "%ld", header_value);
+#endif
+ set_header(header_name, buf);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::set_header(const std::string& header_name, const std::string& header_value)
+ {
+ headers[header_name] = header_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool http_client::is_header_set(const std::string& header_name) const
+ {
+ stringmap::const_iterator ci = headers.find(header_name);
+ return ci != headers.end() && !ci->second.empty();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::remove_header(const std::string& header_name)
+ {
+ headers.erase(header_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::set_cookie(const std::string& cookie_name, long cookie_value)
+ {
+ char buf[21] = { 0 };
+#ifdef __WXMSW__
+ ::ltoa(cookie_value, buf, 10);
+#else
+ sprintf(buf, "%ld", cookie_value);
+#endif
+ set_cookie(cookie_name, buf);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::set_cookie(const std::string& cookie_name, const std::string& cookie_value)
+ {
+ cookies[cookie_name] = cookie_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::remove_cookie(const std::string& cookie_name)
+ {
+ cookies.erase(cookie_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// POST
+ const std::string& http_client::post_url (const std::string& url, const string_to_stringmap& postvars, const string_to_stringmap& filenames)
+ {
+ std::string CT;
+ std::string postBody = build_post(CT, postvars, filenames);
+ set_header("Content-Type", CT);
+ set_header("Content-Length", static_cast<long>(postBody.size()));
+
+ grab_url(url, "POST", postBody);
+
+ return returned_body;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string& http_client::post_url (const std::string& url, const std::string& postbuffer)
+ {
+ if ( !is_header_set("Content-Type") ) // Maybe they just forgot it?
+ set_header("Content-Type", "application/x-www-form-urlencoded");
+
+ set_header("Content-Length", static_cast<long>(postbuffer.size()));
+
+ grab_url(url, "POST", postbuffer);
+
+ return returned_body;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::get_random_string( size_t length ) const
+ {
+ static bool has_seeded(false);
+ static std::string allowed_chars("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789");
+
+ if ( !has_seeded )
+ {
+ has_seeded = true;
+ ::srand( static_cast<unsigned int>(::time(NULL)) );
+ }
+
+ std::string retVal; retVal.reserve(length);
+ while ( retVal.size() < length )
+ {
+ retVal += allowed_chars[(rand() % allowed_chars.size())];
+ }
+
+ return retVal;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// static
+ std::string http_client::urlencode(const std::string& in, bool post_encode)
+ {
+ static std::string allowed_chars("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
+
+ std::stringstream ss;
+ ss << std::hex;
+ for (std::string::const_iterator ci = in.begin(); ci != in.end(); ++ci)
+ {
+ if ( allowed_chars.find(*ci) != std::string::npos )
+ {
+ ss << *ci;
+ }
+ else if ( post_encode && *ci == ' ' )
+ {
+ ss << '+';
+ }
+ else
+ {
+ ss << '%' << std::setfill('0') << std::setw(2) << std::right << static_cast<int>(*ci);
+ }
+ }
+
+ return ss.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::get_basename( const std::string& filename ) const
+ {
+ std::string::size_type pos = filename.find_last_of("\\/");
+ if ( pos == std::string::npos )
+ return filename;
+ else
+ return filename.substr(pos+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool http_client::parse_url(
+ const std::string& url,
+ std::string& scheme,
+ std::string& user,
+ std::string& pass,
+ std::string& host,
+ short& port,
+ std::string& path
+ ) const
+ {
+ scheme.clear();
+ user.clear();
+ pass.clear();
+ host.clear();
+ path.clear();
+ port = 0;
+
+ // Find scheme
+ std::string::size_type pos_scheme = url.find("://");
+ if ( pos_scheme == std::string::npos )
+ {
+ pos_scheme = 0;
+ }
+ else
+ {
+ scheme = strtolower(url.substr(0, pos_scheme));
+ pos_scheme += 3;
+ }
+
+ std::string::size_type pos_path = url.find('/', pos_scheme);
+ if ( pos_path == std::string::npos )
+ {
+ host = url.substr(pos_scheme);
+ }
+ else
+ {
+ host = url.substr(pos_scheme, pos_path - pos_scheme);
+ path = url.substr(pos_path);
+ }
+
+ std::string::size_type pos_at = host.find('@');
+ if ( pos_at != std::string::npos )
+ {
+ std::string::size_type pos_dp = host.find(':');
+ if ( pos_dp != std::string::npos && pos_dp < pos_at )
+ {
+ user = host.substr(0, pos_dp);
+ pass = host.substr(pos_dp+1, pos_at-pos_dp-1);
+ }
+ else
+ {
+ user = host.substr(0, pos_at);
+ }
+ host = host.substr(pos_at+1);
+ }
+
+ std::string::size_type pos_dp = host.find(':');
+ if ( pos_dp != std::string::npos )
+ {
+ port = dlib::string_cast<short>(host.substr(pos_dp+1));
+ host = host.substr(0, pos_dp);
+ }
+
+ host = strtolower(host);
+
+ if ( port == 0 )
+ {
+ if ( scheme == "http" )
+ port = 80;
+ else if ( scheme == "ftp" )
+ port = 21;
+ else if ( scheme == "https" )
+ port = 443;
+ }
+
+ return !host.empty();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::strtolower(const std::string& in) const
+ {
+ std::string retVal = in;
+
+ for (std::string::iterator ii = retVal.begin(); ii != retVal.end(); ++ii)
+ {
+ *ii = ::tolower(*ii);
+ }
+
+ return retVal;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::strtoupper(const std::string& in) const
+ {
+ std::string retVal = in;
+
+ for (std::string::iterator ii = retVal.begin(); ii != retVal.end(); ++ii)
+ {
+ *ii = ::toupper(*ii);
+ }
+
+ return retVal;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// GET
+ const std::string& http_client::get_url(const std::string& url)
+ {
+ std::string CT = get_header("Content-Type");
+
+ // You do a GET with a POST header??
+ if ( CT == "application/x-www-form-urlencoded" || CT == "multipart/form-data" )
+ remove_header("Content-Type");
+
+ grab_url(url);
+
+ return returned_body;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string http_client::build_post(std::string& content_type, const string_to_stringmap& postvars, const string_to_stringmap& filenames_in) const
+ {
+ if ( postvars.empty() && filenames_in.empty() )
+ return std::string();
+
+ string_to_stringmap filenames = filenames_in;
+
+ // sanitize the files
+ if ( !filenames.empty() )
+ {
+ string_to_stringmap::iterator var_names = filenames.begin();
+ while (var_names != filenames.end())
+ {
+ stringmap::iterator fnames = var_names->second.begin();
+
+ while( fnames != var_names->second.end() )
+ {
+ FILE *fp = ::fopen(fnames->second.c_str(), "rb");
+ if ( fp == NULL )
+ {
+ stringmap::iterator old_one = fnames++;
+ var_names->second.erase(old_one);
+ }
+ else
+ {
+ fclose(fp);
+ ++fnames;
+ }
+ }
+
+ if ( fnames->second.empty() )
+ {
+ string_to_stringmap::iterator old_one = var_names++;
+ filenames.erase(old_one);
+ }
+ else
+ {
+ ++var_names;
+ }
+ }
+ }
+
+ content_type = !filenames.empty() ? "multipart/form-data" : "application/x-www-form-urlencoded";
+ std::stringstream postBody;
+ if ( !filenames.empty() )
+ {
+ std::string mime_boundary = get_random_string(32);
+
+ // First add the form vars
+ for (string_to_stringmap::const_iterator ci = postvars.begin(); ci != postvars.end(); ++ci)
+ {
+ for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si)
+ {
+ postBody << "--" << mime_boundary << "\r\n"
+ "Content-Disposition: form-data; name=\"" << ci->first << "\"\r\n\r\n"
+ << si->second << "\r\n";
+ }
+ }
+
+ // Then add the files
+ for (string_to_stringmap::const_iterator ci = filenames.begin(); ci != filenames.end(); ++ci)
+ {
+ for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si)
+ {
+ std::ifstream in(si->second.c_str());
+ postBody << "--" << mime_boundary << "\r\n"
+ "Content-Disposition: form-data; name=\"" << ci->first << "\"; filename=\"" << get_basename(si->second) << "\"\r\n\r\n"
+ << in.rdbuf() << "\r\n";
+ }
+ }
+
+ postBody << "--" << mime_boundary << "--\r\n";
+ }
+ else
+ {
+ // No files...
+ for (string_to_stringmap::const_iterator ci = postvars.begin(); ci != postvars.end(); ++ci)
+ {
+ for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si)
+ {
+ postBody << urlencode(ci->first) << '=' << urlencode(si->second) << '&';
+ }
+ }
+
+ // read the last '&'
+ char c;
+ postBody.read(&c, 1);
+ }
+
+ return postBody.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool http_client::grab_url(const std::string& url, const std::string& method, const std::string& post_body)
+ {
+ error_field.clear();
+ returned_headers.clear();
+ http_return = 0;
+ returned_body.clear();
+
+ std::string to_use_method = strtoupper(method);
+
+ std::string scheme, user, pass, host, path;
+ short port;
+ if ( !parse_url(url, scheme, user, pass, host, port, path) )
+ {
+ error_field = "Couldn't parse the URL!";
+ return false;
+ }
+
+ // Build request
+ std::stringstream ret;
+ ret << to_use_method << ' ' << path << " HTTP/1.0\r\n"
+ << "Host: " << host;
+ if (port != 80 && port != 443) ret << ':' << port;
+ ret << "\r\n";
+
+ bool content_length_said = false;
+
+ set_header("Connection", "Close");
+ for (stringmap::iterator ci = headers.begin(); ci != headers.end(); ++ci)
+ {
+ std::string head = strtolower(ci->first);
+
+ if ( head == "content-length" )
+ {
+ content_length_said = true;
+ }
+
+ ret << ci->first << ':' << ' ' << ci->second << "\r\n";
+ }
+
+ if ( !content_length_said && to_use_method != "GET" )
+ ret << "Content-Length: " << static_cast<unsigned int>(post_body.size()) << "\r\n";
+
+ std::stringstream cookie_ss;
+ for (stringmap::iterator ci = cookies.begin(); ci != cookies.end(); ++ci)
+ {
+ std::string var = ci->first ; trim(var);
+ std::string val = ci->second; trim(val);
+
+ if ( val.empty() || var.empty() )
+ continue;
+
+ if ( !cookie_ss.str().empty() )
+ cookie_ss << ';' << ' ';
+
+ cookie_ss << urlencode(var) << '=' << urlencode(val);
+ }
+
+ if ( !cookie_ss.str().empty() )
+ ret << "Cookie: " << cookie_ss.str() << "\r\n";
+
+ ret << "\r\n";
+ ret << post_body;
+
+ std::string request_build = ret.str();
+
+ std::stringstream ss;
+ {
+ dlib::connection * conn(0);
+ try
+ {
+ if (timeout > 0)
+ conn = dlib::connect(host, port, timeout);
+ else
+ conn = dlib::connect(host, port);
+ }
+ catch (const dlib::socket_error& e)
+ {
+ error_field = e.what();
+ return false;
+ }
+
+ // Implement a timeout
+ timeout_ptr t;
+ if ( timeout > 0 )
+ t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) );
+
+ // Write our request
+ conn->write(request_build.c_str(), static_cast<long>(request_build.size()));
+
+ t.reset();
+
+ // And read the response
+ char buf[512];
+ long bytes_read(0), bytes_total(0);
+ bool read_headers(true);
+
+ if ( timeout > 0 )
+ t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) );
+
+ while ( (bytes_read = conn->read(buf, 512)) > 0 )
+ {
+ ss.write(buf, bytes_read);
+
+ // Incremental read headers
+ if ( read_headers )
+ {
+ std::string body_with_headers = ss.str();
+ std::string::size_type ctr(0);
+
+ while ( true )
+ {
+ std::string::size_type pos = body_with_headers.find("\r\n", ctr);
+ if ( pos == std::string::npos )
+ {
+ // This is our last position of "\r\n"
+ ss.str("");
+ ss.write( body_with_headers.substr(ctr).c_str(), body_with_headers.size() - ctr );
+ break;
+ }
+
+ std::string header = body_with_headers.substr(ctr, pos-ctr);
+ if ( header.empty() )
+ {
+ // Ok, we're done reading the headers
+ read_headers = false;
+ // What follows now is the body
+ ss.str("");
+ ss.write( body_with_headers.substr(pos + 2).c_str(), body_with_headers.size() - pos - 2 );
+ break;
+ }
+ ctr = pos + 2;
+
+ if ( returned_headers.empty() )
+ {
+ if (
+ header[0] == 'H' &&
+ header[1] == 'T' &&
+ header[2] == 'T' &&
+ header[3] == 'P' &&
+ header[4] == '/' &&
+ (header[5] >= '0' && header[5] <= '9') &&
+ header[6] == '.' &&
+ (header[7] >= '0' && header[7] <= '9') &&
+ header[8] == ' '
+ )
+ {
+ http_return = (header[9 ] - '0') * 100 +
+ (header[10] - '0') * 10 +
+ (header[11] - '0');
+ continue;
+ }
+ }
+
+ std::string::size_type pos_dp = header.find_first_of(':');
+ std::string header_name, header_value;
+ if ( pos_dp == std::string::npos )
+ {
+ // **TODO** what should I do here??
+ header_name = header;
+ }
+ else
+ {
+ header_name = trim(header.substr(0, pos_dp));
+ header_value = trim(header.substr(pos_dp+1));
+ }
+
+ returned_headers[ header_name ].push_back(header_value);
+
+ if ( BR_CASECMP(header_name.c_str(), "Content-Length", 14) == 0 )
+ {
+ bytes_total = atol( header_value.c_str() );
+ }
+ else if ( BR_CASECMP(header_name.c_str(), "Set-Cookie", 10) == 0 )
+ {
+ std::string::size_type cur_pos(0), pos_pk, pos_is;
+ std::string work, var, val;
+ for ( cur_pos = 0; cur_pos < header_value.size(); cur_pos++ )
+ {
+ pos_pk = header_value.find(';', cur_pos);
+ work = trim( header_value.substr(cur_pos, pos_pk - cur_pos) );
+
+ pos_is = work.find('=');
+ if ( pos_is != std::string::npos )
+ { // Hmmm? what in the else case?
+ var = trim( http_client::urldecode( work.substr(0, pos_is) ) );
+ val = trim( http_client::urldecode( work.substr(pos_is + 1) ) );
+
+ if ( var != "expires" && var != "domain" && var != "path" )
+ set_cookie( var, val );
+ }
+ cur_pos = pos_pk == std::string::npos ? pos_pk - 1 : pos_pk;
+ }
+ } // Set-Cookie?
+
+ } // while (true)
+ } // read_headers?
+
+ // Call the OnDownload function if it's set
+ if ( OnDownload && !read_headers )
+ {
+ if ( (*OnDownload)(static_cast<long>(ss.tellp()), bytes_total, user_info) == false )
+ {
+ t.reset();
+ break;
+ }
+ }
+
+ if ( bytes_total != 0 && static_cast<long>(ss.tellp()) == bytes_total )
+ {
+ t.reset();
+ break;
+ }
+
+ if ( timeout > 0 )
+ t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) );
+ } // while still data to read
+
+ t.reset();
+
+ delete conn;
+
+
+ switch ( bytes_read )
+ {
+ case dlib::TIMEOUT: error_field = "Timeout"; return false; break;
+ case dlib::WOULDBLOCK: error_field = "Would block"; return false; break;
+ case dlib::OTHER_ERROR: error_field = "Other error"; return false; break;
+ case dlib::SHUTDOWN: error_field = "Timeout"; return false; break;
+ case dlib::PORTINUSE: error_field = "Port in use"; return false; break;
+ }
+ }
+
+ returned_body = ss.str();
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::clear()
+ {
+ headers.clear();
+ cookies.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void http_client::prepare_for_next_url( )
+ {
+ remove_header("Content-Type");
+ remove_header("Content-Length");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/http_client/http_client.h b/ml/dlib/dlib/http_client/http_client.h
new file mode 100644
index 000000000..fd2996ab1
--- /dev/null
+++ b/ml/dlib/dlib/http_client/http_client.h
@@ -0,0 +1,101 @@
+#ifndef DLIB_BROWSERhH
+#define DLIB_BROWSERhH
+
+
+#include <map>
+#include <string>
+#include <vector>
+#include "http_client_abstract.h"
+
+
+// Default timeout after 60 seconds
+#define DEFAULT_TIMEOUT 60000
+
+namespace dlib
+{
+
+ // Function which is called when there is data available.
+ // Return false to stop the download process...
+ typedef bool (*fnOnDownload)(long already_downloaded, long total_to_download, void * userInfo);
+
+
+ class http_client
+ {
+ public:
+ http_client();
+
+ typedef std::map< std::string, std::string > stringmap;
+ typedef std::map< std::string, stringmap > string_to_stringmap;
+ typedef std::map< std::string, std::vector<std::string> > string_to_stringvector;
+
+ // Header functions
+ void set_header(const std::string& header_name, const std::string& header_value);
+ void set_header(const std::string& header_name, long header_value);
+ std::string get_header(const std::string& header_name) const;
+ void remove_header(const std::string& header_name);
+ bool is_header_set(const std::string& header_name) const;
+
+ // This function will clear out all cookies & headers set until now
+ void clear();
+ // This function will clear out the Content-Type header
+ void prepare_for_next_url();
+
+ void set_callback_function( fnOnDownload od, void * _user_info ) { OnDownload = od; user_info = _user_info; }
+
+ void set_cookie(const std::string& cookie_name, const std::string& cookie_value);
+ void set_cookie(const std::string& cookie_name, long cookie_value);
+ void remove_cookie(const std::string& cookie_name);
+
+ void set_user_agent(const std::string& new_agent) { set_header("User-Agent", new_agent); }
+
+
+ void set_timeout( unsigned int milliseconds = DEFAULT_TIMEOUT ) { timeout = milliseconds; }
+
+
+ string_to_stringvector get_returned_headers() const { return returned_headers; }
+ short get_http_return () const { return http_return; }
+ const std::string& get_body () const { return returned_body; }
+
+ // POST
+ const std::string& post_url (const std::string& url, const string_to_stringmap& postvars, const string_to_stringmap& filenames = string_to_stringmap());
+ const std::string& post_url (const std::string& url, const std::string& postbuffer);
+ // GET
+ const std::string& get_url (const std::string& url);
+
+ bool has_error( ) const { return !error_field.empty(); }
+ const std::string& get_error( ) const { return error_field; }
+
+ static std::string urlencode(const std::string& in, bool post_encode = false);
+ static std::string urldecode(const std::string& in);
+ private:
+ bool grab_url(const std::string& url, const std::string& method = "GET", const std::string& post_body = "");
+ std::string build_post(std::string& content_type, const string_to_stringmap& postvars, const string_to_stringmap& filenames) const;
+
+ std::string get_random_string( size_t length = 32 ) const;
+ std::string get_basename( const std::string& filename ) const;
+ std::string strtolower(const std::string& in) const;
+ std::string strtoupper(const std::string& in) const;
+
+ bool parse_url(const std::string& url, std::string& scheme, std::string& user, std::string& pass, std::string& host, short& port, std::string& path) const;
+
+ stringmap headers;
+ stringmap cookies;
+
+ string_to_stringvector returned_headers;
+ short http_return;
+ std::string returned_body, error_field;
+
+ unsigned int timeout;
+
+ fnOnDownload OnDownload;
+ void * user_info;
+ };
+
+}
+
+#ifdef NO_MAKEFILE
+#include "http_client.cpp"
+#endif
+
+#endif // DLIB_BROWSERhH
+
diff --git a/ml/dlib/dlib/http_client/http_client_abstract.h b/ml/dlib/dlib/http_client/http_client_abstract.h
new file mode 100644
index 000000000..5e9d1d5e6
--- /dev/null
+++ b/ml/dlib/dlib/http_client/http_client_abstract.h
@@ -0,0 +1,218 @@
+#undef DLIB_BROWSER_ABSTRACh_
+#ifdef DLIB_BROWSER_ABSTRACh_
+
+
+
+namespace dlib
+{
+
+ // Function which is called when there is data available.
+ // Return false to stop the download process...
+ typedef bool (*fnOnDownload)(long already_downloaded, long total_to_download, void * userInfo);
+
+
+// ----------------------------------------------------------------------------------------
+/*
+TODO:
+- Timed cookie support
+- POSTing files: check it!
+- Don't timeout when still downloading!
+*/
+// ----------------------------------------------------------------------------------------
+
+
+ class Browser
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a possibility for the end user to download webpages (HTTP/1.0)
+ from the internet like a normal webbrowser would do.
+ !*/
+
+ public:
+
+ Browser(
+ );
+ /*!
+ Constructor
+ !*/
+
+ void set_header(
+ const std::string& header_name,
+ const std::string& header_value
+ );
+ /*!
+ Set a header to a certain value
+ Example: set_header("User-Agent", "Internet Explorer")
+ !*/
+
+
+ void set_header(
+ const std::string& header_name,
+ long header_value
+ );
+ /*!
+ Set a header to a certain number
+ Example: set_header("Content-Length", 1234)
+ !*/
+
+ std::string get_header(
+ const std::string& header_name
+ ) const;
+ /*!
+ Get the value of the header or an empty string when it's not set.
+ Example: get_header("Content-Length") would return "1234"
+ !*/
+
+ void remove_header(
+ const std::string& header_name
+ );
+ /*!
+ Removes a certain header
+ !*/
+
+ bool is_header_set(
+ const std::string& header_name
+ ) const;
+ /*!
+ Returns when a header is set and is not empty
+ !*/
+
+ void set_user_agent(
+ const std::string& new_agent
+ ) { set_header("User-Agent", new_agent); }
+ /*!
+ Convenience function for setting a user agent
+ !*/
+
+ void clear(
+ );
+ /*!
+ Clear out all cookies & headers set until now
+ !*/
+
+ void prepare_for_next_url(
+ );
+ /*!
+ Clear out any header and/or cookie which would obstruct getting a next page.
+ At this moment this is cleared:
+ - the Content-Type header
+ !*/
+
+ void set_callback_function(
+ fnOnDownload od,
+ void * _user_info
+ );
+ /*!
+ Set a callback function for one of the following events:
+ - OnDownload: this will tell you how much is downloaded and how much will need to be downloaded
+ !*/
+
+ void set_cookie(
+ const std::string& cookie_name,
+ const std::string& cookie_value
+ );
+ /*!
+ Set a cookie
+ !*/
+
+ void set_cookie(
+ const std::string& cookie_name,
+ long cookie_value
+ );
+ /*!
+ Set a cookie
+ !*/
+
+ void remove_cookie(
+ const std::string& cookie_name
+ );
+ /*!
+ Remove a cookie if it's set
+ !*/
+
+ void set_timeout(
+ unsigned int milliseconds
+ );
+ /*!
+ Set the maximum time how long a request can take. Setting this to 0 disables
+ this behavior.
+ !*/
+
+ string_to_stringvector get_returned_headers(
+ ) const;
+ /*!
+ Returns all the headers which are returned in the download of the webpage.
+ !*/
+
+ short get_http_return (
+ ) const;
+ /*!
+ Retrieves the HTTP return code.
+ !*/
+
+ const std::string& get_body (
+ ) const;
+ /*!
+ Retrieves the HTTP body.
+ !*/
+
+ const std::string& post_url (
+ const std::string& url,
+ const string_to_stringmap& postvars,
+ const string_to_stringmap& filenames = string_to_stringmap()
+ );
+ /*!
+ POST an url to the internet.
+ You can pass the post variables as well as a list of filenames
+ !*/
+
+ const std::string& post_url (
+ const std::string& url,
+ const std::string& postbuffer
+ );
+ /*!
+ POST an url to the internet.
+ In this function you have constructed the POST string yourselves
+ !*/
+
+ const std::string& get_url (
+ const std::string& url
+ );
+ /*!
+ GET an url from the internet.
+ !*/
+
+ bool has_error(
+ ) const;
+ /*!
+ Has there happened an error?
+ !*/
+
+ const std::string& get_error(
+ ) const;
+ /*!
+ Get the error explanation
+ !*/
+
+ static std::string urlencode(
+ const std::string& in,
+ bool post_encode = false
+ );
+ /*!
+ Convenience function to URLencode a string
+ !*/
+
+ static std::string urldecode(
+ const std::string& in
+ );
+ /*!
+ Convenience function to URLdecode a string
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_BROWSER_ABSTRACh_
+
diff --git a/ml/dlib/dlib/image_io.h b/ml/dlib/dlib/image_io.h
new file mode 100644
index 000000000..4fdc79881
--- /dev/null
+++ b/ml/dlib/dlib/image_io.h
@@ -0,0 +1,20 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_IMAGe_IO_
+#define DLIB_IMAGe_IO_
+
+#include "image_loader/image_loader.h"
+#include "image_loader/png_loader.h"
+#include "image_loader/jpeg_loader.h"
+#include "image_loader/load_image.h"
+#include "image_saver/image_saver.h"
+#include "image_saver/save_png.h"
+#include "image_saver/save_jpeg.h"
+
+#endif // DLIB_IMAGe_IO_
+
diff --git a/ml/dlib/dlib/image_keypoint.h b/ml/dlib/dlib/image_keypoint.h
new file mode 100644
index 000000000..335d620e8
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint.h
@@ -0,0 +1,16 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_KEYPOINt_H_
+#define DLIB_IMAGE_KEYPOINt_H_
+
+#include "image_keypoint/surf.h"
+#include "image_keypoint/hessian_pyramid.h"
+#include "image_keypoint/hog.h"
+#include "image_keypoint/poly_image.h"
+#include "image_keypoint/fine_hog_image.h"
+#include "image_keypoint/hashed_feature_image.h"
+#include "image_keypoint/nearest_neighbor_feature_image.h"
+#include "image_keypoint/binned_vector_feature_image.h"
+
+#endif // DLIB_IMAGE_KEYPOINt_H_
+
diff --git a/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h
new file mode 100644
index 000000000..019a12739
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h
@@ -0,0 +1,433 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_
+#define DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_
+
+#include "../lsh/projection_hash.h"
+#include "binned_vector_feature_image_abstract.h"
+#include <vector>
+#include "../algs.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type_ = projection_hash
+ >
+ class binned_vector_feature_image : noncopyable
+ {
+
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef hash_function_type_ hash_function_type;
+
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ binned_vector_feature_image (
+ );
+
+ void clear (
+ );
+
+ void set_hash (
+ const hash_function_type& hash_
+ );
+
+ const hash_function_type& get_hash (
+ ) const;
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+
+ void copy_configuration (
+ const binned_vector_feature_image& item
+ );
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+
+ inline size_t size (
+ ) const;
+
+ inline long nr (
+ ) const;
+
+ inline long nc (
+ ) const;
+
+ inline long get_num_dimensions (
+ ) const;
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+
+ inline const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+
+ inline const point image_to_feat_space (
+ const point& p
+ ) const;
+
+ inline const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+
+ inline const point feat_to_image_space (
+ const point& p
+ ) const;
+
+ inline const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+
+ template <typename T>
+ friend void serialize (
+ const binned_vector_feature_image<T>& item,
+ std::ostream& out
+ );
+
+ template <typename T>
+ friend void deserialize (
+ binned_vector_feature_image<T>& item,
+ std::istream& in
+ );
+
+ private:
+
+ array2d<descriptor_type> feats;
+ feature_extractor fe;
+ hash_function_type phash;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const binned_vector_feature_image<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.feats, out);
+ serialize(item.fe, out);
+ serialize(item.phash, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ binned_vector_feature_image<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Unexpected version found while deserializing dlib::binned_vector_feature_image");
+ deserialize(item.feats, in);
+ deserialize(item.fe, in);
+ deserialize(item.phash, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// binned_vector_feature_image member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ binned_vector_feature_image<feature_extractor,hash_function_type>::
+ binned_vector_feature_image (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void binned_vector_feature_image<feature_extractor,hash_function_type>::
+ clear (
+ )
+ {
+ fe.clear();
+ phash = hash_function_type();
+ feats.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void binned_vector_feature_image<feature_extractor,hash_function_type>::
+ set_hash (
+ const hash_function_type& hash_
+ )
+ {
+ phash = hash_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const hash_function_type& binned_vector_feature_image<feature_extractor,hash_function_type>::
+ get_hash (
+ ) const
+ {
+ return phash;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void binned_vector_feature_image<feature_extractor,hash_function_type>::
+ copy_configuration (
+ const feature_extractor& item
+ )
+ {
+ fe.copy_configuration(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void binned_vector_feature_image<feature_extractor,hash_function_type>::
+ copy_configuration (
+ const binned_vector_feature_image& item
+ )
+ {
+ fe.copy_configuration(item.fe);
+ phash = item.phash;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ template <
+ typename image_type
+ >
+ void binned_vector_feature_image<feature_extractor,hash_function_type>::
+ load (
+ const image_type& img
+ )
+ {
+ fe.load(img);
+
+ if (fe.size() != 0)
+ {
+ feats.set_size(fe.nr(), fe.nc());
+ for (long r = 0; r < feats.nr(); ++r)
+ {
+ for (long c = 0; c < feats.nc(); ++c)
+ {
+ feats[r][c].clear();
+ feats[r][c].reserve(fe.get_num_dimensions()+1);
+ const typename feature_extractor::descriptor_type& des = fe(r,c);
+ const unsigned long idx = phash(des);
+ const unsigned long offset = idx*(fe.get_num_dimensions()+1);
+
+ for (long i = 0; i < des.size(); ++i)
+ {
+ feats[r][c].push_back(std::make_pair(offset + i, des(i)));
+ }
+ feats[r][c].push_back(std::make_pair(offset + des.size(), 1.0));
+ }
+ }
+ }
+ else
+ {
+ feats.set_size(0,0);
+ }
+
+ fe.unload();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ size_t binned_vector_feature_image<feature_extractor,hash_function_type>::
+ size (
+ ) const
+ {
+ return feats.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long binned_vector_feature_image<feature_extractor,hash_function_type>::
+ nr (
+ ) const
+ {
+ return feats.nr();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long binned_vector_feature_image<feature_extractor,hash_function_type>::
+ nc (
+ ) const
+ {
+ return feats.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long binned_vector_feature_image<feature_extractor,hash_function_type>::
+ get_num_dimensions (
+ ) const
+ {
+ return phash.num_hash_bins()*(fe.get_num_dimensions()+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const std::vector<std::pair<unsigned int,double> >& binned_vector_feature_image<feature_extractor,hash_function_type>::
+ operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type binned_vector_feature_image::operator(row,col)"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ return feats[row][col];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle binned_vector_feature_image<feature_extractor,hash_function_type>::
+ get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ return fe.get_block_rect(row,col);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const point binned_vector_feature_image<feature_extractor,hash_function_type>::
+ image_to_feat_space (
+ const point& p
+ ) const
+ {
+ return fe.image_to_feat_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle binned_vector_feature_image<feature_extractor,hash_function_type>::
+ image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.image_to_feat_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const point binned_vector_feature_image<feature_extractor,hash_function_type>::
+ feat_to_image_space (
+ const point& p
+ ) const
+ {
+ return fe.feat_to_image_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle binned_vector_feature_image<feature_extractor,hash_function_type>::
+ feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.feat_to_image_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h
new file mode 100644
index 000000000..6bd6cdbb8
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h
@@ -0,0 +1,287 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_
+#ifdef DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_
+
+#include "../lsh/projection_hash_abstract.h"
+#include <vector>
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type_ = projection_hash
+ >
+ class binned_vector_feature_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ - must be an object with an interface compatible with dlib::hog_image
+
+ REQUIREMENTS ON hash_function_type_
+ - must be an object with an interface compatible with projection_hash
+
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing image feature extraction. In
+ particular, it wraps another image feature extractor and converts the
+ wrapped image feature vectors into a high dimensional sparse vector. For
+ example, if the lower level feature extractor outputs the vector [3,4,5]
+ and this vector is hashed into the second bin of four bins then the output
+ sparse vector is:
+ [0,0,0,0, 3,4,5,1, 0,0,0,0, 0,0,0,0].
+ That is, the output vector has a dimensionality that is equal to the number
+ of hash bins times the dimensionality of the lower level vector plus one.
+ The value in the extra dimension concatenated onto the end of the vector is
+ always a constant value of of 1 and serves as a bias value. This means
+ that, if there are N hash bins, these vectors are capable of representing N
+ different linear functions, each operating on the vectors that fall into
+ their corresponding hash bin.
+
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be
+ protected by a mutex lock except for the case where you are copying the
+ configuration (via copy_configuration()) of a binned_vector_feature_image
+ object to many other threads. In this case, it is safe to copy the
+ configuration of a shared object so long as no other operations are
+ performed on it.
+
+
+ NOTATION
+ let BASE_FE denote the base feature_extractor object contained inside the
+ binned_vector_feature_image.
+ !*/
+
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef hash_function_type_ hash_function_type;
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ binned_vector_feature_image (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void set_hash (
+ const hash_function_type& hash
+ );
+ /*!
+ ensures
+ - #get_hash() == hash
+ !*/
+
+ const hash_function_type& get_hash (
+ ) const;
+ /*!
+ ensures
+ - returns the hash function used by this object to hash
+ base feature vectors into integers.
+ !*/
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+ /*!
+ ensures
+ - performs BASE_FE.copy_configuration(item)
+ !*/
+
+ void copy_configuration (
+ const binned_vector_feature_image& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two binned_vector_feature_image
+ objects H1 and H2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == any type that can be supplied to feature_extractor::load()
+ ensures
+ - performs BASE_FE.load(img)
+ i.e. does feature extraction. The features can be accessed using
+ operator() as defined below.
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.size()
+ !*/
+
+ long nr (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nr()
+ !*/
+
+ long nc (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nc()
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the feature vectors returned by operator().
+ In this case, this is the number of hash bins times the dimensionality of
+ the features produced by BASE_FE plus one. That is, this function
+ returns get_hash().num_hash_bins()*(BASE_FE.get_num_dimensions()+1)
+ !*/
+
+ const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ - It must be legal to evaluate expressions of the form: get_hash()(BASE_FE(row,col))
+ (e.g. the hash function must be properly configured to process the feature
+ vectors produced by the base feature extractor)
+ ensures
+ - hashes BASE_FE(row,col) and returns the resulting sparse vector. In
+ particular, we return a vector that is a copy of BASE_FE(row,col) that
+ has been shifted into the part of the sparse vector indicated by the hash
+ function. It will also have a constant bias value of 1 appended to it.
+ - To be precise, this function returns a sparse vector V such that:
+ - V.size() == BASE_FE.get_num_dimensions()+1
+ - let IDX = get_hash()(BASE_FE(row,col))
+ - for i where 0 <= i < BASE_FE.get_num_dimensions():
+ - V[i].first == IDX*(BASE_FE.get_num_dimensions()+1) + i
+ - V[i].second == BASE_FE(row,col)(i)
+ - V[BASE_FE.get_num_dimensions()].first == IDX*(BASE_FE.get_num_dimensions()+1) + BASE_FE.get_num_dimensions()
+ - V[BASE_FE.get_num_dimensions()].second == 1
+ !*/
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.get_block_rect(row,col)
+ I.e. returns a rectangle that tells you what part of the original image is associated
+ with a particular feature vector.
+ !*/
+
+ const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(p)
+ I.e. Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(rect)
+ I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(p)
+ I.e. returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(rect)
+ I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void serialize (
+ const binned_vector_feature_image<T,U>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void deserialize (
+ binned_vector_feature_image<T,U>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h b/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h
new file mode 100644
index 000000000..aea59067d
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h
@@ -0,0 +1,186 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_
+#define DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_
+
+#include "../matrix.h"
+#include "surf.h"
+#include "../uintn.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ typedef std::pair<matrix<double,0,1>, matrix<double,0,1> > separable_filter_type;
+ typedef std::pair<matrix<int32,0,1>, matrix<int32,0,1> > separable_int32_filter_type;
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<std::vector<separable_filter_type> > build_separable_poly_filters (
+ const long order,
+ const long window_size
+ )
+ /*!
+ requires
+ - 1 <= order <= 6
+ - window_size >= 3 && window_size is odd
+ ensures
+ - the "first" element is the row_filter, the second is the col_filter.
+ - Some filters are not totally separable and that's why they are grouped
+ into vectors of vectors. The groups are all the parts of a partially
+ separable filter.
+ !*/
+ {
+ long num_filters = 6;
+ switch (order)
+ {
+ case 1: num_filters = 3; break;
+ case 2: num_filters = 6; break;
+ case 3: num_filters = 10; break;
+ case 4: num_filters = 15; break;
+ case 5: num_filters = 21; break;
+ case 6: num_filters = 28; break;
+ }
+
+ matrix<double> X(window_size*window_size,num_filters);
+ matrix<double,0,1> G(window_size*window_size,1);
+ const double sigma = window_size/4.0;
+
+
+ long cnt = 0;
+ for (double x = -window_size/2; x <= window_size/2; ++x)
+ {
+ for (double y = -window_size/2; y <= window_size/2; ++y)
+ {
+ X(cnt, 0) = 1;
+ X(cnt, 1) = x;
+ X(cnt, 2) = y;
+
+ if (X.nc() > 5)
+ {
+ X(cnt, 3) = x*x;
+ X(cnt, 4) = x*y;
+ X(cnt, 5) = y*y;
+ }
+ if (X.nc() > 9)
+ {
+ X(cnt, 6) = x*x*x;
+ X(cnt, 7) = y*x*x;
+ X(cnt, 8) = y*y*x;
+ X(cnt, 9) = y*y*y;
+ }
+ if (X.nc() > 14)
+ {
+ X(cnt, 10) = x*x*x*x;
+ X(cnt, 11) = y*x*x*x;
+ X(cnt, 12) = y*y*x*x;
+ X(cnt, 13) = y*y*y*x;
+ X(cnt, 14) = y*y*y*y;
+ }
+ if (X.nc() > 20)
+ {
+ X(cnt, 15) = x*x*x*x*x;
+ X(cnt, 16) = y*x*x*x*x;
+ X(cnt, 17) = y*y*x*x*x;
+ X(cnt, 18) = y*y*y*x*x;
+ X(cnt, 19) = y*y*y*y*x;
+ X(cnt, 20) = y*y*y*y*y;
+ }
+ if (X.nc() > 27)
+ {
+ X(cnt, 21) = x*x*x*x*x*x;
+ X(cnt, 22) = y*x*x*x*x*x;
+ X(cnt, 23) = y*y*x*x*x*x;
+ X(cnt, 24) = y*y*y*x*x*x;
+ X(cnt, 25) = y*y*y*y*x*x;
+ X(cnt, 26) = y*y*y*y*y*x;
+ X(cnt, 27) = y*y*y*y*y*y;
+ }
+
+ G(cnt) = std::sqrt(gaussian(x,y,sigma));
+ ++cnt;
+ }
+ }
+
+ X = diagm(G)*X;
+
+ const matrix<double> S = inv(trans(X)*X)*trans(X)*diagm(G);
+
+ matrix<double,0,1> row_filter, col_filter;
+
+ matrix<double> u,v, temp;
+ matrix<double,0,1> w;
+
+ std::vector<std::vector<separable_filter_type> > results(num_filters);
+
+ for (long r = 0; r < S.nr(); ++r)
+ {
+ temp = reshape(rowm(S,r), window_size, window_size);
+ svd3(temp,u,w,v);
+ const double thresh = max(w)*1e-8;
+ for (long i = 0; i < w.size(); ++i)
+ {
+ if (w(i) > thresh)
+ {
+ col_filter = std::sqrt(w(i))*colm(u,i);
+ row_filter = std::sqrt(w(i))*colm(v,i);
+ results[r].push_back(std::make_pair(row_filter, col_filter));
+ }
+ }
+ }
+
+ return results;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<std::vector<separable_int32_filter_type> > build_separable_int32_poly_filters (
+ const long order,
+ const long window_size,
+ const double max_range = 300.0
+ )
+ /*!
+ requires
+ - 1 <= order <= 6
+ - window_size >= 3 && window_size is odd
+ - max_range > 1
+ ensures
+ - the "first" element is the row_filter, the second is the col_filter.
+ !*/
+ {
+ const std::vector<std::vector<separable_filter_type> >& filters = build_separable_poly_filters(order, window_size);
+ std::vector<std::vector<separable_int32_filter_type> > int_filters(filters.size());
+
+ for (unsigned long i = 0; i < filters.size(); ++i)
+ {
+
+ double max_val = 0;
+ for (unsigned long j = 0; j < filters[i].size(); ++j)
+ {
+ const separable_filter_type& filt = filters[i][j];
+ max_val = std::max(max_val, max(abs(filt.first)));
+ max_val = std::max(max_val, max(abs(filt.second)));
+ }
+ if (max_val == 0)
+ max_val = 1;
+
+ int_filters[i].resize(filters[i].size());
+ for (unsigned long j = 0; j < filters[i].size(); ++j)
+ {
+ const separable_filter_type& filt = filters[i][j];
+ int_filters[i][j].first = matrix_cast<int32>(round(filt.first*max_range/max_val));
+ int_filters[i][j].second = matrix_cast<int32>(round(filt.second*max_range/max_val));
+ }
+ }
+
+ return int_filters;
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/draw_surf_points.h b/ml/dlib/dlib/image_keypoint/draw_surf_points.h
new file mode 100644
index 000000000..b16c28f5d
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/draw_surf_points.h
@@ -0,0 +1,40 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DRAW_SURf_POINTS_H_
+#define DLIB_DRAW_SURf_POINTS_H_
+
+#include "surf.h"
+#include "../gui_widgets.h"
+#include "draw_surf_points_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+ inline void draw_surf_points (
+ image_window& win,
+ const std::vector<surf_point>& sp
+ )
+ {
+ for (unsigned long i = 0; i < sp.size(); ++i)
+ {
+ const unsigned long radius = static_cast<unsigned long>(sp[i].p.scale*3);
+ const point center(sp[i].p.center);
+ point direction = center + point(radius,0);
+ // SURF descriptors are rotated by sp[i].angle. So we want to include a visual
+ // indication of this rotation on our overlay.
+ direction = rotate_point(center, direction, sp[i].angle);
+
+ win.add_overlay(image_display::overlay_circle(center, radius, rgb_pixel(0,255,0)));
+ // Draw a line showing the orientation of the SURF descriptor.
+ win.add_overlay(center, direction, rgb_pixel(255,0,0));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAW_SURf_POINTS_H_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h b/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h
new file mode 100644
index 000000000..86a66ef49
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h
@@ -0,0 +1,30 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DRAW_SURf_POINTS_ABSTRACT_H_
+#ifdef DLIB_DRAW_SURf_POINTS_ABSTRACT_H_
+
+#include "surf_abstract.h"
+#include "../gui_widgets.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void draw_surf_points (
+ image_window& win,
+ const std::vector<surf_point>& sp
+ );
+ /*!
+ ensures
+ - draws all the SURF points in sp onto the given image_window. They
+ are drawn as overlay circles with extra lines to indicate the rotation
+ of the SURF descriptor.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAW_SURf_POINTS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/image_keypoint/fine_hog_image.h b/ml/dlib/dlib/image_keypoint/fine_hog_image.h
new file mode 100644
index 000000000..a421ffe7c
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/fine_hog_image.h
@@ -0,0 +1,378 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FINE_HOG_IMaGE_Hh_
+#define DLIB_FINE_HOG_IMaGE_Hh_
+
+#include "fine_hog_image_abstract.h"
+#include "../array2d.h"
+#include "../matrix.h"
+#include "hog.h"
+
+
+namespace dlib
+{
+ template <
+ unsigned long cell_size_,
+ unsigned long block_size_,
+ unsigned long pixel_stride_,
+ unsigned char num_orientation_bins_,
+ int gradient_type_
+ >
+ class fine_hog_image : noncopyable
+ {
+ COMPILE_TIME_ASSERT(cell_size_ > 1);
+ COMPILE_TIME_ASSERT(block_size_ > 0);
+ COMPILE_TIME_ASSERT(pixel_stride_ > 0);
+ COMPILE_TIME_ASSERT(num_orientation_bins_ > 0);
+
+ COMPILE_TIME_ASSERT( gradient_type_ == hog_signed_gradient ||
+ gradient_type_ == hog_unsigned_gradient);
+
+
+ public:
+
+ const static unsigned long cell_size = cell_size_;
+ const static unsigned long block_size = block_size_;
+ const static unsigned long pixel_stride = pixel_stride_;
+ const static unsigned long num_orientation_bins = num_orientation_bins_;
+ const static int gradient_type = gradient_type_;
+
+ const static long min_size = cell_size*block_size+2;
+
+ typedef matrix<double, block_size*block_size*num_orientation_bins, 1> descriptor_type;
+
+ fine_hog_image (
+ ) :
+ num_block_rows(0),
+ num_block_cols(0)
+ {}
+
+ void clear (
+ )
+ {
+ num_block_rows = 0;
+ num_block_cols = 0;
+ hist_counts.clear();
+ }
+
+ void copy_configuration (
+ const fine_hog_image&
+ ){}
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false );
+ load_impl(mat(img));
+ }
+
+ inline void unload(
+ ) { clear(); }
+
+ inline size_t size (
+ ) const { return static_cast<size_t>(nr()*nc()); }
+
+ inline long nr (
+ ) const { return num_block_rows; }
+
+ inline long nc (
+ ) const { return num_block_cols; }
+
+ long get_num_dimensions (
+ ) const
+ {
+ return block_size*block_size*num_orientation_bins;
+ }
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type fine_hog_image::operator()()"
+ << "\n\t invalid row or col argument"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ row *= pixel_stride;
+ col *= pixel_stride;
+
+ des = 0;
+ unsigned long off = 0;
+ for (unsigned long r = 0; r < block_size; ++r)
+ {
+ for (unsigned long c = 0; c < block_size; ++c)
+ {
+ for (unsigned long rr = 0; rr < cell_size; ++rr)
+ {
+ for (unsigned long cc = 0; cc < cell_size; ++cc)
+ {
+ const histogram_count& hist = hist_counts[row + r*cell_size + rr][col + c*cell_size + cc];
+ des(off + hist.quantized_angle_lower) += hist.lower_strength;
+ des(off + hist.quantized_angle_upper) += hist.upper_strength;
+ }
+ }
+
+ off += num_orientation_bins;
+ }
+ }
+
+ des /= length(des) + 1e-8;
+
+ return des;
+ }
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ row *= pixel_stride;
+ col *= pixel_stride;
+
+ // do this to account for the 1 pixel padding we use all around the image
+ ++row;
+ ++col;
+
+ return rectangle(col, row, col+cell_size*block_size-1, row+cell_size*block_size-1);
+ }
+
+ const point image_to_feat_space (
+ const point& p
+ ) const
+ {
+ const long border_size = 1 + cell_size*block_size/2;
+ return (p-point(border_size,border_size))/(long)pixel_stride;
+ }
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ }
+
+ const point feat_to_image_space (
+ const point& p
+ ) const
+ {
+ const long border_size = 1 + cell_size*block_size/2;
+ return p*(long)pixel_stride + point(border_size,border_size);
+ }
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ }
+
+
+
+ // these _PRIVATE_ functions are only here as a workaround for a bug in visual studio 2005.
+ void _PRIVATE_serialize (std::ostream& out) const
+ {
+ // serialize hist_counts
+ serialize(hist_counts.nc(),out);
+ serialize(hist_counts.nr(),out);
+ hist_counts.reset();
+ while (hist_counts.move_next())
+ hist_counts.element().serialize(out);
+ hist_counts.reset();
+
+
+ serialize(num_block_rows, out);
+ serialize(num_block_cols, out);
+ }
+
+ void _PRIVATE_deserialize (std::istream& in )
+ {
+ // deserialize item.hist_counts
+ long nc, nr;
+ deserialize(nc,in);
+ deserialize(nr,in);
+ hist_counts.set_size(nr,nc);
+ while (hist_counts.move_next())
+ hist_counts.element().deserialize(in);
+ hist_counts.reset();
+
+
+ deserialize(num_block_rows, in);
+ deserialize(num_block_cols, in);
+ }
+
+ private:
+
+ template <
+ typename image_type
+ >
+ void load_impl (
+ const image_type& img
+ )
+ {
+ // Note that we keep a border of 1 pixel all around the image so that we don't have
+ // to worry about running outside the image when computing the horizontal and vertical
+ // gradients.
+
+
+
+ // check if the window is just too small
+ if (img.nr() < min_size || img.nc() < min_size)
+ {
+ // If the image is smaller than our windows then there aren't any descriptors at all!
+ num_block_rows = 0;
+ num_block_cols = 0;
+ hist_counts.clear();
+ return;
+ }
+
+ hist_counts.set_size(img.nr()-2, img.nc()-2);
+
+
+
+
+ for (long r = 0; r < hist_counts.nr(); ++r)
+ {
+ for (long c = 0; c < hist_counts.nc(); ++c)
+ {
+ unsigned long left;
+ unsigned long right;
+ unsigned long top;
+ unsigned long bottom;
+
+ assign_pixel(left, img(r+1,c));
+ assign_pixel(right, img(r+1,c+2));
+ assign_pixel(top, img(r ,c+1));
+ assign_pixel(bottom, img(r+2,c+1));
+
+ double grad_x = (long)right-(long)left;
+ double grad_y = (long)top-(long)bottom;
+
+ // obtain the angle of the gradient. Make sure it is scaled between 0 and 1.
+ double angle = std::max(0.0, std::atan2(grad_y, grad_x)/pi + 1)/2;
+
+
+ if (gradient_type == hog_unsigned_gradient)
+ {
+ angle *= 2;
+ if (angle >= 1)
+ angle -= 1;
+ }
+
+
+ // now scale angle to between 0 and num_orientation_bins
+ angle *= num_orientation_bins;
+
+
+ const double strength = std::sqrt(grad_y*grad_y + grad_x*grad_x);
+
+
+ unsigned char quantized_angle_lower = static_cast<unsigned char>(std::floor(angle));
+ unsigned char quantized_angle_upper = static_cast<unsigned char>(std::ceil(angle));
+
+ quantized_angle_lower %= num_orientation_bins;
+ quantized_angle_upper %= num_orientation_bins;
+
+ const double angle_split = (angle-std::floor(angle));
+ const double upper_strength = angle_split*strength;
+ const double lower_strength = (1-angle_split)*strength;
+
+ // Stick into gradient counts. Note that we linearly interpolate between neighboring
+ // histogram buckets.
+ hist_counts[r][c].quantized_angle_lower = quantized_angle_lower;
+ hist_counts[r][c].quantized_angle_upper = quantized_angle_upper;
+ hist_counts[r][c].lower_strength = lower_strength;
+ hist_counts[r][c].upper_strength = upper_strength;
+
+ }
+ }
+
+
+ // Now figure out how many feature extraction blocks we should have.
+ num_block_rows = (hist_counts.nr() - block_size*cell_size + 1)/(long)pixel_stride;
+ num_block_cols = (hist_counts.nc() - block_size*cell_size + 1)/(long)pixel_stride;
+
+ }
+
+ struct histogram_count
+ {
+ unsigned char quantized_angle_lower;
+ unsigned char quantized_angle_upper;
+ float lower_strength;
+ float upper_strength;
+
+ void serialize(std::ostream& out) const
+ {
+ dlib::serialize(quantized_angle_lower, out);
+ dlib::serialize(quantized_angle_upper, out);
+ dlib::serialize(lower_strength, out);
+ dlib::serialize(upper_strength, out);
+ }
+ void deserialize(std::istream& in)
+ {
+ dlib::deserialize(quantized_angle_lower, in);
+ dlib::deserialize(quantized_angle_upper, in);
+ dlib::deserialize(lower_strength, in);
+ dlib::deserialize(upper_strength, in);
+ }
+ };
+
+ array2d<histogram_count> hist_counts;
+
+ mutable descriptor_type des;
+
+ long num_block_rows;
+ long num_block_cols;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned char T4,
+ int T5
+ >
+ void serialize (
+ const fine_hog_image<T1,T2,T3,T4,T5>& item,
+ std::ostream& out
+ )
+ {
+ item._PRIVATE_serialize(out);
+ }
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned char T4,
+ int T5
+ >
+ void deserialize (
+ fine_hog_image<T1,T2,T3,T4,T5>& item,
+ std::istream& in
+ )
+ {
+ item._PRIVATE_deserialize(in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FINE_HOG_IMaGE_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h b/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h
new file mode 100644
index 000000000..50be85afe
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_
+#ifdef DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_
+
+#include "../array2d.h"
+#include "../matrix.h"
+#include "hog_abstract.h"
+
+
+namespace dlib
+{
+ template <
+ unsigned long cell_size_,
+ unsigned long block_size_,
+ unsigned long pixel_stride_,
+ unsigned char num_orientation_bins_,
+ int gradient_type_
+ >
+ class fine_hog_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE PARAMETERS
+ - cell_size_ > 1
+ - block_size_ > 0
+ - pixel_stride_ > 0
+ - num_orientation_bins_ > 0
+ - gradient_type_ == hog_signed_gradient or hog_unsigned_gradient
+
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a version of the hog_image that allows you to extract HOG
+ features at a finer resolution. The hog_image can only extract HOG features
+ cell_size_ pixels apart. However, this object, the fine_hog_image can
+ extract HOG features from every pixel location.
+
+ The template arguments to this class have the same meaning as they do for
+ the hog_image, except for pixel_stride_. This controls the stepping between
+ HOG extraction locations. A value of 1 indicates HOG features should be
+ extracted from every pixel location. A value of 2 indicates every other pixel
+ location, etc.
+
+ Finally, note that the interpolation used by this object is equivalent
+ to using hog_angle_interpolation with hog_image.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a fine_hog_image object to many other threads.
+ In this case, it is safe to copy the configuration of a shared object so long
+ as no other operations are performed on it.
+ !*/
+
+ public:
+
+ const static unsigned long cell_size = cell_size_;
+ const static unsigned long block_size = block_size_;
+ const static unsigned long pixel_stride = pixel_stride_;
+ const static unsigned long num_orientation_bins = num_orientation_bins_;
+ const static int gradient_type = gradient_type_;
+
+ const static long min_size = cell_size*block_size+2;
+
+ typedef matrix<double, block_size*block_size*num_orientation_bins, 1> descriptor_type;
+
+ fine_hog_image (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void copy_configuration (
+ const fine_hog_image&
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two fine_hog_image
+ objects H1 and H2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type is a dlib::matrix or something convertible to a matrix
+ via mat()
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ ensures
+ - if (img.nr() < min_size || img.nc() < min_size) then
+ - the image is too small so we don't compute anything on it
+ - #size() == 0
+ - else
+ - generates a HOG image from the given image.
+ - #size() > 0
+ !*/
+
+ inline void unload(
+ );
+ /*!
+ ensures
+ - #nr() == 0
+ - #nc() == 0
+ - clears only the state information which is populated by load(). For
+ example, let H be a fine_hog_image object. Then consider the two
+ sequences of instructions:
+ Sequence 1:
+ H.load(img);
+ H.unload();
+ H.load(img);
+
+ Sequence 2:
+ H.load(img);
+ Both sequence 1 and sequence 2 should have the same effect on H.
+ !*/
+
+ inline size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ !*/
+
+ inline long nr (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this HOG image
+ !*/
+
+ inline long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this HOG image
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the number of dimensions in the feature vectors generated by
+ this object.
+ - In particular, returns the value block_size*block_size*num_orientation_bins
+ !*/
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ ensures
+ - returns the descriptor for the HOG block at the given row and column. This descriptor
+ will include information from a window that is located at get_block_rect(row,col) in
+ the original image given to load().
+ - The returned descriptor vector will have get_num_dimensions() elements.
+ !*/
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns a rectangle that tells you what part of the original image is associated
+ with a particular HOG block. That is, what part of the input image is associated
+ with (*this)(row,col).
+ - The returned rectangle will be cell_size*block_size pixels wide and tall.
+ !*/
+
+ const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned char T4,
+ int T5
+ >
+ void serialize (
+ const fine_hog_image<T1,T2,T3,T4,T5>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned char T4,
+ int T5
+ >
+ void deserialize (
+ fine_hog_image<T1,T2,T3,T4,T5>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/hashed_feature_image.h b/ml/dlib/dlib/image_keypoint/hashed_feature_image.h
new file mode 100644
index 000000000..80f429330
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hashed_feature_image.h
@@ -0,0 +1,518 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HASHED_IMAGE_FEATUrES_Hh_
+#define DLIB_HASHED_IMAGE_FEATUrES_Hh_
+
+#include "../lsh/projection_hash.h"
+#include "hashed_feature_image_abstract.h"
+#include <vector>
+#include "../algs.h"
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type_ = projection_hash
+ >
+ class hashed_feature_image : noncopyable
+ {
+
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef hash_function_type_ hash_function_type;
+
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ hashed_feature_image (
+ );
+
+ void clear (
+ );
+
+ void set_hash (
+ const hash_function_type& hash_
+ );
+
+ const hash_function_type& get_hash (
+ ) const;
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+
+ void copy_configuration (
+ const hashed_feature_image& item
+ );
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+
+ inline size_t size (
+ ) const;
+
+ inline long nr (
+ ) const;
+
+ inline long nc (
+ ) const;
+
+ inline long get_num_dimensions (
+ ) const;
+
+ void use_relative_feature_weights (
+ );
+
+ void use_uniform_feature_weights (
+ );
+
+ bool uses_uniform_feature_weights (
+ ) const;
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+
+ inline const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+
+ inline const point image_to_feat_space (
+ const point& p
+ ) const;
+
+ inline const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+
+ inline const point feat_to_image_space (
+ const point& p
+ ) const;
+
+ inline const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+
+ template <typename T>
+ friend void serialize (
+ const hashed_feature_image<T>& item,
+ std::ostream& out
+ );
+
+ template <typename T>
+ friend void deserialize (
+ hashed_feature_image<T>& item,
+ std::istream& in
+ );
+
+ private:
+
+ array2d<unsigned long> feats;
+ feature_extractor fe;
+ hash_function_type phash;
+ std::vector<float> feat_counts;
+ bool uniform_feature_weights;
+
+
+ // This is a transient variable. It is just here so it doesn't have to be
+ // reallocated over and over inside operator()
+ mutable descriptor_type hash_feats;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const hashed_feature_image<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.feats, out);
+ serialize(item.fe, out);
+ serialize(item.phash, out);
+ serialize(item.feat_counts, out);
+ serialize(item.uniform_feature_weights, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ hashed_feature_image<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing a dlib::hashed_feature_image object.");
+
+ deserialize(item.feats, in);
+ deserialize(item.fe, in);
+ deserialize(item.phash, in);
+ deserialize(item.feat_counts, in);
+ deserialize(item.uniform_feature_weights, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// hashed_feature_image member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ hashed_feature_image<feature_extractor,hash_function_type>::
+ hashed_feature_image (
+ )
+ {
+ clear();
+ hash_feats.resize(1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ clear (
+ )
+ {
+ fe.clear();
+ phash = hash_function_type();
+ feats.clear();
+ feat_counts.clear();
+ uniform_feature_weights = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ set_hash (
+ const hash_function_type& hash_
+ )
+ {
+ phash = hash_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const hash_function_type& hashed_feature_image<feature_extractor,hash_function_type>::
+ get_hash (
+ ) const
+ {
+ return phash;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ copy_configuration (
+ const feature_extractor& item
+ )
+ {
+ fe.copy_configuration(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ copy_configuration (
+ const hashed_feature_image& item
+ )
+ {
+ fe.copy_configuration(item.fe);
+ phash = item.phash;
+ uniform_feature_weights = item.uniform_feature_weights;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ template <
+ typename image_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ load (
+ const image_type& img
+ )
+ {
+ fe.load(img);
+
+ if (fe.size() != 0)
+ {
+ feats.set_size(fe.nr(), fe.nc());
+ feat_counts.assign(phash.num_hash_bins(),1);
+ if (uniform_feature_weights)
+ {
+ for (long r = 0; r < feats.nr(); ++r)
+ {
+ for (long c = 0; c < feats.nc(); ++c)
+ {
+ feats[r][c] = phash(fe(r,c));
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < feats.nr(); ++r)
+ {
+ for (long c = 0; c < feats.nc(); ++c)
+ {
+ feats[r][c] = phash(fe(r,c));
+ feat_counts[feats[r][c]]++;
+ }
+ }
+ }
+ }
+ else
+ {
+ feats.set_size(0,0);
+ }
+
+ if (!uniform_feature_weights)
+ {
+ // use the inverse frequency as the scale for each feature. We also scale
+ // these counts so that they are invariant to the size of the image (we scale
+ // them so they all look like they come from a 500x400 images).
+ const double scale = image_size(img)/(500.0*400.0);
+ for (unsigned long i = 0; i < feat_counts.size(); ++i)
+ {
+ feat_counts[i] = scale/feat_counts[i];
+ }
+ }
+
+ fe.unload();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ size_t hashed_feature_image<feature_extractor,hash_function_type>::
+ size (
+ ) const
+ {
+ return feats.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long hashed_feature_image<feature_extractor,hash_function_type>::
+ nr (
+ ) const
+ {
+ return feats.nr();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long hashed_feature_image<feature_extractor,hash_function_type>::
+ nc (
+ ) const
+ {
+ return feats.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ long hashed_feature_image<feature_extractor,hash_function_type>::
+ get_num_dimensions (
+ ) const
+ {
+ return phash.num_hash_bins();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ use_relative_feature_weights (
+ )
+ {
+ uniform_feature_weights = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ void hashed_feature_image<feature_extractor,hash_function_type>::
+ use_uniform_feature_weights (
+ )
+ {
+ uniform_feature_weights = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ bool hashed_feature_image<feature_extractor,hash_function_type>::
+ uses_uniform_feature_weights (
+ ) const
+ {
+ return uniform_feature_weights;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const std::vector<std::pair<unsigned int,double> >& hashed_feature_image<feature_extractor,hash_function_type>::
+ operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type hashed_feature_image::operator(row,col)"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ hash_feats[0] = std::make_pair(feats[row][col],feat_counts[feats[row][col]]);
+ return hash_feats;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle hashed_feature_image<feature_extractor,hash_function_type>::
+ get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ return fe.get_block_rect(row,col);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const point hashed_feature_image<feature_extractor,hash_function_type>::
+ image_to_feat_space (
+ const point& p
+ ) const
+ {
+ return fe.image_to_feat_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle hashed_feature_image<feature_extractor,hash_function_type>::
+ image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.image_to_feat_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const point hashed_feature_image<feature_extractor,hash_function_type>::
+ feat_to_image_space (
+ const point& p
+ ) const
+ {
+ return fe.feat_to_image_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type
+ >
+ const rectangle hashed_feature_image<feature_extractor,hash_function_type>::
+ feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.feat_to_image_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASHED_IMAGE_FEATUrES_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h
new file mode 100644
index 000000000..90c1348c5
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h
@@ -0,0 +1,303 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_
+#ifdef DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_
+
+#include "../lsh/projection_hash_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor,
+ typename hash_function_type_ = projection_hash
+ >
+ class hashed_feature_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ - must be an object with an interface compatible with dlib::hog_image
+
+ REQUIREMENTS ON hash_function_type_
+ - must be an object with an interface compatible with projection_hash
+
+ INITIAL VALUE
+ - size() == 0
+ - uses_uniform_feature_weights() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing image feature extraction. In
+ particular, it wraps another image feature extractor and converts the
+ wrapped image feature vectors into sparse indicator vectors. It does this
+ by hashing each feature vector into the range [0, get_num_dimensions()-1]
+ and then returns a new vector which is zero everywhere except for the
+ position determined by the hash.
+
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a hashed_feature_image object to many other threads.
+ In this case, it is safe to copy the configuration of a shared object so long
+ as no other operations are performed on it.
+
+
+ NOTATION
+ let BASE_FE denote the base feature_extractor object contained inside
+ the hashed_feature_image.
+ !*/
+
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef hash_function_type_ hash_function_type;
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ hashed_feature_image (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void set_hash (
+ const hash_function_type& hash
+ );
+ /*!
+ ensures
+ - #get_hash() == hash
+ !*/
+
+ const hash_function_type& get_hash (
+ ) const;
+ /*!
+ ensures
+ - returns the hash function used by this object to hash
+ base feature vectors into integers.
+ !*/
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+ /*!
+ ensures
+ - performs BASE_FE.copy_configuration(item)
+ !*/
+
+ void copy_configuration (
+ const hashed_feature_image& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two hashed_feature_image
+ objects H1 and H2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == any type that can be supplied to feature_extractor::load()
+ ensures
+ - performs BASE_FE.load(img)
+ i.e. does feature extraction. The features can be accessed using
+ operator() as defined below.
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.size()
+ !*/
+
+ long nr (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nr()
+ !*/
+
+ long nc (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nc()
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the feature vectors returned by operator().
+ In this case, this is the number of hash bins. That is, get_hash().num_hash_bins()
+ !*/
+
+ void use_relative_feature_weights (
+ );
+ /*!
+ ensures
+ - #uses_uniform_feature_weights() == false
+ !*/
+
+ void use_uniform_feature_weights (
+ );
+ /*!
+ ensures
+ - #uses_uniform_feature_weights() == true
+ !*/
+
+ bool uses_uniform_feature_weights (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object weights each feature with a value of 1 and
+ false if it uses a weighting of 1/N where N is the number of occurrences
+ of the feature in an image (note that we normalize N so that it is
+ invariant to the size of the image given to load()).
+ !*/
+
+ const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ - It must be legal to evaluate expressions of the form: get_hash()(BASE_FE(row,col))
+ (e.g. the hash function must be properly configured to process the feature
+ vectors produced by the base feature extractor)
+ ensures
+ - hashes BASE_FE(row,col) and returns the resulting indicator vector.
+ - To be precise, this function returns a sparse vector V such that:
+ - V.size() == 1
+ - V[0].first == get_hash()(BASE_FE(row,col))
+ - if (uses_uniform_feature_weights()) then
+ - V[0].second == 1
+ - else
+ - V[0].second == 1/N where N is the number of times a feature in
+ hash bin V[0].first was observed in the image given to load().
+ Note that we scale all the counts so that they are invariant to
+ the size of the image.
+ !*/
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.get_block_rect(row,col)
+ I.e. returns a rectangle that tells you what part of the original image is associated
+ with a particular feature vector.
+ !*/
+
+ const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(p)
+ I.e. Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(rect)
+ I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(p)
+ I.e. returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(rect)
+ I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void serialize (
+ const hashed_feature_image<T,U>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void deserialize (
+ hashed_feature_image<T,U>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/image_keypoint/hessian_pyramid.h b/ml/dlib/dlib/image_keypoint/hessian_pyramid.h
new file mode 100644
index 000000000..2e672c0d0
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hessian_pyramid.h
@@ -0,0 +1,531 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HESSIAN_PYRAMId_Hh_
+#define DLIB_HESSIAN_PYRAMId_Hh_
+
+#include "hessian_pyramid_abstract.h"
+#include "../algs.h"
+#include "../image_transforms/integral_image.h"
+#include "../array.h"
+#include "../array2d.h"
+#include "../noncopyable.h"
+#include "../matrix.h"
+#include "../stl_checked.h"
+#include <algorithm>
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct interest_point
+ {
+ interest_point() : scale(0), score(0), laplacian(0) {}
+
+ dlib::vector<double,2> center;
+ double scale;
+ double score;
+ double laplacian;
+
+ bool operator < (const interest_point& p) const { return score < p.score; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize(
+ const interest_point& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.center,out);
+ serialize(item.scale,out);
+ serialize(item.score,out);
+ serialize(item.laplacian,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type interest_point");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize(
+ interest_point& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.center,in);
+ deserialize(item.scale,in);
+ deserialize(item.score,in);
+ deserialize(item.laplacian,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type interest_point");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class hessian_pyramid : noncopyable
+ {
+ public:
+ hessian_pyramid()
+ {
+ num_octaves = 0;
+ num_intervals = 0;
+ initial_step_size = 0;
+ }
+
+ template <typename integral_image_type>
+ void build_pyramid (
+ const integral_image_type& img,
+ long num_octaves,
+ long num_intervals,
+ long initial_step_size
+ )
+ {
+ DLIB_ASSERT(num_octaves > 0 && num_intervals > 0 && initial_step_size > 0,
+ "\tvoid build_pyramid()"
+ << "\n\tAll arguments to this function must be > 0"
+ << "\n\t this: " << this
+ << "\n\t num_octaves: " << num_octaves
+ << "\n\t num_intervals: " << num_intervals
+ << "\n\t initial_step_size: " << initial_step_size
+ );
+
+ this->num_octaves = num_octaves;
+ this->num_intervals = num_intervals;
+ this->initial_step_size = initial_step_size;
+
+ // allocate space for the pyramid
+ pyramid.resize(num_octaves*num_intervals);
+ for (long o = 0; o < num_octaves; ++o)
+ {
+ const long step_size = get_step_size(o);
+ for (long i = 0; i < num_intervals; ++i)
+ {
+ pyramid[num_intervals*o + i].set_size(img.nr()/step_size, img.nc()/step_size);
+ }
+ }
+
+ // now fill out the pyramid with data
+ for (long o = 0; o < num_octaves; ++o)
+ {
+ const long step_size = get_step_size(o);
+
+ for (long i = 0; i < num_intervals; ++i)
+ {
+ const long border_size = get_border_size(i)*step_size;
+ const long lobe_size = static_cast<long>(std::pow(2.0, o+1.0)+0.5)*(i+1) + 1;
+ const double area_inv = 1.0/std::pow(3.0*lobe_size, 2.0);
+
+ const long lobe_offset = lobe_size/2+1;
+ const point tl(-lobe_offset,-lobe_offset);
+ const point tr(lobe_offset,-lobe_offset);
+ const point bl(-lobe_offset,lobe_offset);
+ const point br(lobe_offset,lobe_offset);
+
+ for (long r = border_size; r < img.nr() - border_size; r += step_size)
+ {
+ for (long c = border_size; c < img.nc() - border_size; c += step_size)
+ {
+ const point p(c,r);
+
+ double Dxx = img.get_sum_of_area(centered_rect(p, lobe_size*3, 2*lobe_size-1)) -
+ img.get_sum_of_area(centered_rect(p, lobe_size, 2*lobe_size-1))*3.0;
+
+ double Dyy = img.get_sum_of_area(centered_rect(p, 2*lobe_size-1, lobe_size*3)) -
+ img.get_sum_of_area(centered_rect(p, 2*lobe_size-1, lobe_size))*3.0;
+
+ double Dxy = img.get_sum_of_area(centered_rect(p+bl, lobe_size, lobe_size)) +
+ img.get_sum_of_area(centered_rect(p+tr, lobe_size, lobe_size)) -
+ img.get_sum_of_area(centered_rect(p+tl, lobe_size, lobe_size)) -
+ img.get_sum_of_area(centered_rect(p+br, lobe_size, lobe_size));
+
+ // now we normalize the filter responses
+ Dxx *= area_inv;
+ Dyy *= area_inv;
+ Dxy *= area_inv;
+
+
+ double sign_of_laplacian = +1;
+ if (Dxx + Dyy < 0)
+ sign_of_laplacian = -1;
+
+ double determinant = Dxx*Dyy - 0.81*Dxy*Dxy;
+
+ // If the determinant is negative then just blank it out by setting
+ // it to zero.
+ if (determinant < 0)
+ determinant = 0;
+
+ // Save the determinant of the Hessian into our image pyramid. Also
+ // pack the laplacian sign into the value so we can get it out later.
+ pyramid[o*num_intervals + i][r/step_size][c/step_size] = sign_of_laplacian*determinant;
+
+ }
+ }
+
+ }
+ }
+ }
+
+ long get_border_size (
+ long interval
+ ) const
+ {
+ DLIB_ASSERT(0 <= interval && interval < intervals(),
+ "\tlong get_border_size(interval)"
+ << "\n\tInvalid interval value"
+ << "\n\t this: " << this
+ << "\n\t interval: " << interval
+ );
+
+ const double lobe_size = 2.0*(interval+1) + 1;
+ const double filter_size = 3*lobe_size;
+
+ const long bs = static_cast<long>(std::ceil(filter_size/2.0));
+ return bs;
+ }
+
+ long get_step_size (
+ long octave
+ ) const
+ {
+ DLIB_ASSERT(0 <= octave && octave < octaves(),
+ "\tlong get_step_size(octave)"
+ << "\n\tInvalid octave value"
+ << "\n\t this: " << this
+ << "\n\t octave: " << octave
+ );
+
+ return initial_step_size*static_cast<long>(std::pow(2.0, (double)octave)+0.5);
+ }
+
+ long nr (
+ long octave
+ ) const
+ {
+ DLIB_ASSERT(0 <= octave && octave < octaves(),
+ "\tlong nr(octave)"
+ << "\n\tInvalid octave value"
+ << "\n\t this: " << this
+ << "\n\t octave: " << octave
+ );
+
+ return pyramid[num_intervals*octave].nr();
+ }
+
+ long nc (
+ long octave
+ ) const
+ {
+ DLIB_ASSERT(0 <= octave && octave < octaves(),
+ "\tlong nc(octave)"
+ << "\n\tInvalid octave value"
+ << "\n\t this: " << this
+ << "\n\t octave: " << octave
+ );
+
+ return pyramid[num_intervals*octave].nc();
+ }
+
+ double get_value (
+ long octave,
+ long interval,
+ long r,
+ long c
+ ) const
+ {
+ DLIB_ASSERT(0 <= octave && octave < octaves() &&
+ 0 <= interval && interval < intervals() &&
+ get_border_size(interval) <= r && r < nr(octave)-get_border_size(interval) &&
+ get_border_size(interval) <= c && c < nc(octave)-get_border_size(interval),
+ "\tdouble get_value(octave, interval, r, c)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\t this: " << this
+ << "\n\t octave: " << octave
+ << "\n\t interval: " << interval
+ << "\n\t octaves: " << octaves()
+ << "\n\t intervals: " << intervals()
+ << "\n\t r: " << r
+ << "\n\t c: " << c
+ << "\n\t nr(octave): " << nr(octave)
+ << "\n\t nc(octave): " << nc(octave)
+ << "\n\t get_border_size(interval): " << get_border_size(interval)
+ );
+
+ return std::abs(pyramid[num_intervals*octave + interval][r][c]);
+ }
+
+ double get_laplacian (
+ long octave,
+ long interval,
+ long r,
+ long c
+ ) const
+ {
+ DLIB_ASSERT(0 <= octave && octave < octaves() &&
+ 0 <= interval && interval < intervals() &&
+ get_border_size(interval) <= r && r < nr(octave)-get_border_size(interval) &&
+ get_border_size(interval) <= c && c < nc(octave)-get_border_size(interval),
+ "\tdouble get_laplacian(octave, interval, r, c)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\t this: " << this
+ << "\n\t octave: " << octave
+ << "\n\t interval: " << interval
+ << "\n\t octaves: " << octaves()
+ << "\n\t intervals: " << intervals()
+ << "\n\t r: " << r
+ << "\n\t c: " << c
+ << "\n\t nr(octave): " << nr(octave)
+ << "\n\t nc(octave): " << nc(octave)
+ << "\n\t get_border_size(interval): " << get_border_size(interval)
+ );
+
+ // return the sign of the laplacian
+ if (pyramid[num_intervals*octave + interval][r][c] > 0)
+ return +1;
+ else
+ return -1;
+ }
+
+ long octaves (
+ ) const { return num_octaves; }
+
+ long intervals (
+ ) const { return num_intervals; }
+
+ private:
+
+ long num_octaves;
+ long num_intervals;
+ long initial_step_size;
+
+ typedef array2d<double> image_type;
+ typedef array<image_type> pyramid_type;
+
+ pyramid_type pyramid;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace hessian_pyramid_helpers
+ {
+ inline bool is_maximum_in_region(
+ const hessian_pyramid& pyr,
+ long o,
+ long i,
+ long r,
+ long c
+ )
+ {
+ // First check if this point is near the edge of the octave
+ // If it is then we say it isn't a maximum as these points are
+ // not as reliable.
+ if (i <= 0 || i+1 >= pyr.intervals())
+ {
+ return false;
+ }
+
+ const double val = pyr.get_value(o,i,r,c);
+
+ // now check if there are any bigger values around this guy
+ for (long ii = i-1; ii <= i+1; ++ii)
+ {
+ for (long rr = r-1; rr <= r+1; ++rr)
+ {
+ for (long cc = c-1; cc <= c+1; ++cc)
+ {
+ if (pyr.get_value(o,ii,rr,cc) > val)
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const matrix<double,3,1> get_hessian_gradient (
+ const hessian_pyramid& pyr,
+ long o,
+ long i,
+ long r,
+ long c
+ )
+ {
+ matrix<double,3,1> grad;
+ grad(0) = (pyr.get_value(o,i,r,c+1) - pyr.get_value(o,i,r,c-1))/2.0;
+ grad(1) = (pyr.get_value(o,i,r+1,c) - pyr.get_value(o,i,r-1,c))/2.0;
+ grad(2) = (pyr.get_value(o,i+1,r,c) - pyr.get_value(o,i-1,r,c))/2.0;
+ return grad;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const matrix<double,3,3> get_hessian_hessian (
+ const hessian_pyramid& pyr,
+ long o,
+ long i,
+ long r,
+ long c
+ )
+ {
+ matrix<double,3,3> hess;
+ const double val = pyr.get_value(o,i,r,c);
+
+ double Dxx = (pyr.get_value(o,i,r,c+1) + pyr.get_value(o,i,r,c-1)) - 2*val;
+ double Dyy = (pyr.get_value(o,i,r+1,c) + pyr.get_value(o,i,r-1,c)) - 2*val;
+ double Dss = (pyr.get_value(o,i+1,r,c) + pyr.get_value(o,i-1,r,c)) - 2*val;
+
+ double Dxy = (pyr.get_value(o,i,r+1,c+1) + pyr.get_value(o,i,r-1,c-1) -
+ pyr.get_value(o,i,r-1,c+1) - pyr.get_value(o,i,r+1,c-1)) / 4.0;
+
+ double Dxs = (pyr.get_value(o,i+1,r,c+1) + pyr.get_value(o,i-1,r,c-1) -
+ pyr.get_value(o,i-1,r,c+1) - pyr.get_value(o,i+1,r,c-1)) / 4.0;
+
+ double Dys = (pyr.get_value(o,i+1,r+1,c) + pyr.get_value(o,i-1,r-1,c) -
+ pyr.get_value(o,i-1,r+1,c) - pyr.get_value(o,i+1,r-1,c)) / 4.0;
+
+
+ hess = Dxx, Dxy, Dxs,
+ Dxy, Dyy, Dys,
+ Dxs, Dys, Dss;
+
+ return hess;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const interest_point interpolate_point (
+ const hessian_pyramid& pyr,
+ long o,
+ long i,
+ long r,
+ long c
+ )
+ {
+ dlib::vector<double,2> p(c,r);
+
+ dlib::vector<double,3> start_point(c,r,i);
+ dlib::vector<double,3> interpolated_point = -inv(get_hessian_hessian(pyr,o,i,r,c))*get_hessian_gradient(pyr,o,i,r,c);
+
+ //cout << "inter: " << trans(interpolated_point);
+
+ interest_point temp;
+ if (max(abs(interpolated_point)) < 0.5)
+ {
+ p = (start_point+interpolated_point)*pyr.get_step_size(o);
+ const double lobe_size = std::pow(2.0, o+1.0)*(i+interpolated_point.z()+1) + 1;
+ const double filter_size = 3*lobe_size;
+ const double scale = 1.2/9.0 * filter_size;
+
+ temp.center = p;
+ temp.scale = scale;
+ temp.score = pyr.get_value(o,i,r,c);
+ temp.laplacian = pyr.get_laplacian(o,i,r,c);
+ }
+ else
+ {
+ // this indicates to the caller that no interest point was found.
+ temp.score = -1;
+ }
+
+ return temp;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename Alloc>
+ void get_interest_points (
+ const hessian_pyramid& pyr,
+ double threshold,
+ std::vector<interest_point,Alloc>& result_points
+ )
+ {
+ DLIB_ASSERT(threshold >= 0,
+ "\tvoid get_interest_points()"
+ << "\n\t Invalid arguments to this function"
+ << "\n\t threshold: " << threshold
+ );
+ using namespace std;
+ using namespace hessian_pyramid_helpers;
+
+ result_points.clear();
+
+ for (long o = 0; o < pyr.octaves(); ++o)
+ {
+ const long nr = pyr.nr(o);
+ const long nc = pyr.nc(o);
+
+ // do non-maximum suppression on all the intervals in the current octave and
+ // accumulate the results in result_points
+ for (long i = 1; i < pyr.intervals()-1; i += 1)
+ {
+ const long border_size = pyr.get_border_size(i+1);
+ for (long r = border_size+1; r < nr - border_size-1; r += 1)
+ {
+ for (long c = border_size+1; c < nc - border_size-1; c += 1)
+ {
+ double max_val = pyr.get_value(o,i,r,c);
+ long max_i = i;
+ long max_r = r;
+ long max_c = c;
+
+
+ // If the max point we found is really a maximum in its own region and
+ // is big enough then add it to the results.
+ if (max_val >= threshold && is_maximum_in_region(pyr, o, max_i, max_r, max_c))
+ {
+ //cout << max_val << endl;
+ interest_point sp = interpolate_point (pyr, o, max_i, max_r, max_c);
+ if (sp.score >= threshold)
+ {
+ result_points.push_back(sp);
+ }
+ }
+
+ }
+ }
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename Alloc>
+ void get_interest_points (
+ const hessian_pyramid& pyr,
+ double threshold,
+ std_vector_c<interest_point,Alloc>& result_points
+ )
+ /*!
+ This function is just an overload that automatically casts std_vector_c objects
+ into std::vector objects. (Usually this is automatic but the template argument
+ there messes up the conversion so we have to do it explicitly)
+ !*/
+ {
+ std::vector<interest_point,Alloc>& v = result_points;
+ get_interest_points(pyr, threshold, v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HESSIAN_PYRAMId_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h b/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h
new file mode 100644
index 000000000..2db39c210
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h
@@ -0,0 +1,244 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_
+#ifdef DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_
+
+#include "../image_transforms/integral_image_abstract.h"
+#include "../noncopyable.h"
+#include <vector>
+
+namespace dlib
+{
+
+ class hessian_pyramid : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - octaves() == 0
+ - intervals() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an image pyramid where each level in the
+ pyramid holds determinants of Hessian matrices for the original
+ input image. This object can be used to find stable interest
+ points in an image. For further details consult the following
+ papers.
+
+ This object is an implementation of the fast Hessian pyramid
+ as described in the paper:
+ SURF: Speeded Up Robust Features
+ By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool
+
+ This implementation was also influenced by the very well documented
+ OpenSURF library and its corresponding description of how the fast
+ Hessian algorithm functions:
+ Notes on the OpenSURF Library
+ Christopher Evans
+ !*/
+ public:
+
+ template <typename integral_image_type>
+ void build_pyramid (
+ const integral_image_type& img,
+ long num_octaves,
+ long num_intervals,
+ long initial_step_size
+ );
+ /*!
+ requires
+ - num_octaves > 0
+ - num_intervals > 0
+ - initial_step_size > 0
+ - integral_image_type == an object such as dlib::integral_image or another
+ type that implements the interface defined in image_transforms/integral_image_abstract.h
+ ensures
+ - #get_step_size(0) == initial_step_size
+ - #octaves() == num_octaves
+ - #intervals() == num_intervals
+ - creates a Hessian pyramid from the given input image.
+ !*/
+
+ long octaves (
+ ) const;
+ /*!
+ ensures
+ - returns the number of octaves in this pyramid
+ !*/
+
+ long intervals (
+ ) const;
+ /*!
+ ensures
+ - returns the number of intervals in this pyramid
+ !*/
+
+ long get_border_size (
+ long interval
+ ) const;
+ /*!
+ requires
+ - 0 <= interval < intervals()
+ ensures
+ - Each interval of the pyramid has a certain sized border region where we
+ can't compute the Hessian values since they are too close to the edge
+ of the input image. This function returns the size of that border.
+ !*/
+
+ long get_step_size (
+ long octave
+ ) const;
+ /*!
+ requires
+ - 0 <= octave < octaves()
+ ensures
+ - Each octave has a step size value. This value determines how many
+ input image pixels separate each pixel in the given pyramid octave.
+ As the octave gets larger (i.e. as it goes to the top of the pyramid) the
+ step size gets bigger and thus the pyramid narrows.
+ !*/
+
+ long nr (
+ long octave
+ ) const;
+ /*!
+ requires
+ - 0 <= octave < octaves()
+ ensures
+ - returns the number of rows there are per layer in the given
+ octave of pyramid
+ !*/
+
+ long nc (
+ long octave
+ ) const;
+ /*!
+ requires
+ - 0 <= octave < octaves()
+ ensures
+ - returns the number of columns there are per layer in the given
+ octave of pyramid
+ !*/
+
+ double get_value (
+ long octave,
+ long interval,
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= octave < octaves()
+ - 0 <= interval < intervals()
+ - Let BS == get_border_size(interval): then
+ - BS <= r < nr(octave)-BS
+ - BS <= c < nc(octave)-BS
+ ensures
+ - returns the determinant of the Hessian from the given octave and interval
+ of the pyramid. The specific point sampled at this pyramid level is
+ the one that corresponds to the input image point (point(r,c)*get_step_size(octave)).
+ !*/
+
+ double get_laplacian (
+ long octave,
+ long interval,
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= octave < octaves()
+ - 0 <= interval < intervals()
+ - Let BS == get_border_size(interval): then
+ - BS <= r < nr(octave)-BS
+ - BS <= c < nc(octave)-BS
+ ensures
+ - returns the sign of the laplacian for the given octave and interval
+ of the pyramid. The specific point sampled at this pyramid level is
+ the one that corresponds to the input image point (point(r,c)*get_step_size(octave)).
+ - The laplacian is the trace of the Hessian at the given point. So this
+ function returns either +1 or -1 depending on this number's sign. This
+ value can be used to distinguish bright blobs on dark backgrounds from
+ the reverse.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct interest_point
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object contains the interest points found using the
+ hessian_pyramid object. Its fields have the following
+ meanings:
+ - center == the x/y location of the center of the interest point
+ (in image space coordinates. y gives the row and x gives the
+ column in the image)
+ - scale == the scale at which the point was detected. This is a number
+ >= 1. If it is 1 then it means the interest point was detected at
+ the lowest scale in the image pyramid. Larger numbers indicate that
+ the interest point is from high up in the image pyramid. For
+ example, a scale of 4 would mean the interest point was located at a
+ point in the pyramid where the image had been shrunk by a factor of 4.
+ - score == the determinant of the Hessian for the interest point
+ - laplacian == the sign of the laplacian for the interest point
+ !*/
+
+ interest_point() : scale(0), score(0), laplacian(0) {}
+
+ dlib::vector<double,2> center;
+ double scale;
+ double score;
+ double laplacian;
+
+ bool operator < (const interest_point& p) const { return score < p.score; }
+ /*!
+ This function is here so you can sort interest points according to
+ their scores
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const interest_point& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ interest_point& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename Alloc>
+ void get_interest_points (
+ const hessian_pyramid& pyr,
+ double threshold,
+ std::vector<interest_point,Alloc>& result_points
+ )
+ /*!
+ requires
+ - threshold >= 0
+ ensures
+ - extracts interest points from the pyramid pyr and stores them into
+ result_points (note that result_points is cleared before these new interest
+ points are added to it).
+ - Only interest points with determinant values in the pyramid larger than
+ threshold are output.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/hog.h b/ml/dlib/dlib/image_keypoint/hog.h
new file mode 100644
index 000000000..823c25d6d
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hog.h
@@ -0,0 +1,514 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HoG_Hh_
+#define DLIB_HoG_Hh_
+
+#include "hog_abstract.h"
+#include "../algs.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../geometry.h"
+#include <cmath>
+
+namespace dlib
+{
+ enum
+ {
+ hog_no_interpolation,
+ hog_angle_interpolation,
+ hog_full_interpolation,
+ hog_signed_gradient,
+ hog_unsigned_gradient
+ };
+
+ template <
+ unsigned long cell_size_,
+ unsigned long block_size_,
+ unsigned long cell_stride_,
+ unsigned long num_orientation_bins_,
+ int gradient_type_,
+ int interpolation_type_
+ >
+ class hog_image : noncopyable
+ {
+ COMPILE_TIME_ASSERT(cell_size_ > 1);
+ COMPILE_TIME_ASSERT(block_size_ > 0);
+ COMPILE_TIME_ASSERT(cell_stride_ > 0);
+ COMPILE_TIME_ASSERT(num_orientation_bins_ > 0);
+
+ COMPILE_TIME_ASSERT( gradient_type_ == hog_signed_gradient ||
+ gradient_type_ == hog_unsigned_gradient);
+
+ COMPILE_TIME_ASSERT( interpolation_type_ == hog_no_interpolation ||
+ interpolation_type_ == hog_angle_interpolation ||
+ interpolation_type_ == hog_full_interpolation );
+
+
+ public:
+
+ const static unsigned long cell_size = cell_size_;
+ const static unsigned long block_size = block_size_;
+ const static unsigned long cell_stride = cell_stride_;
+ const static unsigned long num_orientation_bins = num_orientation_bins_;
+ const static int gradient_type = gradient_type_;
+ const static int interpolation_type = interpolation_type_;
+
+ const static long min_size = cell_size*block_size+2;
+
+ typedef matrix<double, block_size*block_size*num_orientation_bins, 1> descriptor_type;
+
+ hog_image (
+ ) :
+ num_block_rows(0),
+ num_block_cols(0)
+ {}
+
+ void clear (
+ )
+ {
+ num_block_rows = 0;
+ num_block_cols = 0;
+ hist_cells.clear();
+ }
+
+ void copy_configuration (
+ const hog_image&
+ ){}
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false );
+ load_impl(mat(img));
+ }
+
+ inline void unload(
+ ) { clear(); }
+
+ inline size_t size (
+ ) const { return static_cast<size_t>(nr()*nc()); }
+
+ inline long nr (
+ ) const { return num_block_rows; }
+
+ inline long nc (
+ ) const { return num_block_cols; }
+
+ long get_num_dimensions (
+ ) const
+ {
+ return block_size*block_size*num_orientation_bins;
+ }
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type hog_image::operator()()"
+ << "\n\t invalid row or col argument"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ row *= cell_stride;
+ col *= cell_stride;
+ ++row;
+ ++col;
+
+ int feat = 0;
+ for (unsigned long r = 0; r < block_size; ++r)
+ {
+ for (unsigned long c = 0; c < block_size; ++c)
+ {
+ for (unsigned long i = 0; i < num_orientation_bins; ++i)
+ {
+ des(feat++) = hist_cells[row+r][col+c].values[i];
+ }
+ }
+ }
+
+ des /= length(des) + 1e-8;
+
+ return des;
+ }
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ row *= cell_stride;
+ col *= cell_stride;
+
+ row *= cell_size;
+ col *= cell_size;
+
+ // do this to account for the 1 pixel padding we use all around the image
+ ++row;
+ ++col;
+
+ return rectangle(col, row, col+cell_size*block_size-1, row+cell_size*block_size-1);
+ }
+
+ const point image_to_feat_space (
+ const point& p
+ ) const
+ {
+
+ const long half_block = block_size/2;
+ if ((block_size%2) == 0)
+ {
+ return point(((p.x()-1)/(long)cell_size - half_block)/(long)cell_stride,
+ ((p.y()-1)/(long)cell_size - half_block)/(long)cell_stride);
+ }
+ else
+ {
+ return point(((p.x()-1-(long)cell_size/2)/(long)cell_size - half_block)/(long)cell_stride,
+ ((p.y()-1-(long)cell_size/2)/(long)cell_size - half_block)/(long)cell_stride);
+ }
+ }
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ }
+
+ const point feat_to_image_space (
+ const point& p
+ ) const
+ {
+ const long half_block = block_size/2;
+ if ((block_size%2) == 0)
+ {
+ return point((p.x()*cell_stride + half_block)*cell_size + 1,
+ (p.y()*cell_stride + half_block)*cell_size + 1);
+ }
+ else
+ {
+ return point((p.x()*cell_stride + half_block)*cell_size + 1 + cell_size/2,
+ (p.y()*cell_stride + half_block)*cell_size + 1 + cell_size/2);
+ }
+ }
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ }
+
+
+
+ // these _PRIVATE_ functions are only here as a workaround for a bug in visual studio 2005.
+ void _PRIVATE_serialize (std::ostream& out) const
+ {
+ // serialize hist_cells
+ serialize(hist_cells.nc(),out);
+ serialize(hist_cells.nr(),out);
+ hist_cells.reset();
+ while (hist_cells.move_next())
+ serialize(hist_cells.element().values,out);
+ hist_cells.reset();
+
+
+ serialize(num_block_rows, out);
+ serialize(num_block_cols, out);
+ }
+
+ void _PRIVATE_deserialize (std::istream& in )
+ {
+ // deserialize item.hist_cells
+ long nc, nr;
+ deserialize(nc,in);
+ deserialize(nr,in);
+ hist_cells.set_size(nr,nc);
+ while (hist_cells.move_next())
+ deserialize(hist_cells.element().values,in);
+ hist_cells.reset();
+
+
+ deserialize(num_block_rows, in);
+ deserialize(num_block_cols, in);
+ }
+
+ private:
+
+ template <
+ typename image_type
+ >
+ void load_impl (
+ const image_type& img
+ )
+ {
+ // Note that we keep a border of 1 pixel all around the image so that we don't have
+ // to worry about running outside the image when computing the horizontal and vertical
+ // gradients.
+
+ // Note also that we have a border of unused cells around the hist_cells array so that we
+ // don't have to worry about edge effects when doing the interpolation in the main loop
+ // below.
+
+
+ // check if the window is just too small
+ if (img.nr() < min_size || img.nc() < min_size)
+ {
+ // If the image is smaller than our windows then there aren't any descriptors at all!
+ num_block_rows = 0;
+ num_block_cols = 0;
+ return;
+ }
+
+ // Make sure we have the right number of cell histograms and that they are
+ // all set to zero.
+ hist_cells.set_size((img.nr()-2)/cell_size+2, (img.nc()-2)/cell_size+2);
+ for (long r = 0; r < hist_cells.nr(); ++r)
+ {
+ for (long c = 0; c < hist_cells.nc(); ++c)
+ {
+ hist_cells[r][c].zero();
+ }
+ }
+
+
+ // loop over all the histogram cells and fill them out
+ for (long rh = 1; rh < hist_cells.nr()-1; ++rh)
+ {
+ for (long ch = 1; ch < hist_cells.nc()-1; ++ch)
+ {
+ // Fill out the current histogram cell.
+ // First, figure out the row and column offsets into the image for the current histogram cell.
+ const long roff = (rh-1)*cell_size + 1;
+ const long coff = (ch-1)*cell_size + 1;
+
+ for (long r = 0; r < (long)cell_size; ++r)
+ {
+ for (long c = 0; c < (long)cell_size; ++c)
+ {
+ unsigned long left;
+ unsigned long right;
+ unsigned long top;
+ unsigned long bottom;
+
+ assign_pixel(left, img(r+roff,c+coff-1));
+ assign_pixel(right, img(r+roff,c+coff+1));
+ assign_pixel(top, img(r+roff-1,c+coff));
+ assign_pixel(bottom, img(r+roff+1,c+coff));
+
+ double grad_x = (long)right-(long)left;
+ double grad_y = (long)top-(long)bottom;
+
+ // obtain the angle of the gradient. Make sure it is scaled between 0 and 1.
+ double angle = std::max(0.0, std::atan2(grad_y, grad_x)/pi + 1)/2;
+
+
+ if (gradient_type == hog_unsigned_gradient)
+ {
+ angle *= 2;
+ if (angle >= 1)
+ angle -= 1;
+ }
+
+
+ // now scale angle to between 0 and num_orientation_bins
+ angle *= num_orientation_bins;
+
+
+ const double strength = std::sqrt(grad_y*grad_y + grad_x*grad_x);
+
+
+ if (interpolation_type == hog_no_interpolation)
+ {
+ // no interpolation
+ hist_cells[rh][ch].values[round_to_int(angle)%num_orientation_bins] += strength;
+ }
+ else // if we should do some interpolation
+ {
+ unsigned long quantized_angle_lower = static_cast<unsigned long>(std::floor(angle));
+ unsigned long quantized_angle_upper = static_cast<unsigned long>(std::ceil(angle));
+
+ quantized_angle_lower %= num_orientation_bins;
+ quantized_angle_upper %= num_orientation_bins;
+
+ const double angle_split = (angle-std::floor(angle));
+ const double upper_strength = angle_split*strength;
+ const double lower_strength = (1-angle_split)*strength;
+
+ if (interpolation_type == hog_angle_interpolation)
+ {
+ // Stick into gradient histogram. Note that we linearly interpolate between neighboring
+ // histogram buckets.
+ hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength;
+ hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength;
+ }
+ else // here we do hog_full_interpolation
+ {
+ const double center_r = (cell_size-1)/2.0;
+ const double center_c = (cell_size-1)/2.0;
+
+ const double lin_neighbor_r = std::abs(center_r - r)/cell_size;
+ const double lin_main_r = 1-lin_neighbor_r;
+
+ const double lin_neighbor_c = std::abs(center_c - c)/cell_size;
+ const double lin_main_c = 1-lin_neighbor_c;
+
+ // Which neighboring cells we interpolate into depends on which
+ // corner of our main cell we are nearest.
+ if (r < center_r)
+ {
+ if (c < center_c)
+ {
+ hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c;
+ hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c;
+
+ hist_cells[rh-1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c;
+ hist_cells[rh-1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c;
+
+ hist_cells[rh][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r;
+ hist_cells[rh][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r;
+
+ hist_cells[rh-1][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r;
+ hist_cells[rh-1][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r;
+ }
+ else
+ {
+ hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c;
+ hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c;
+
+ hist_cells[rh-1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c;
+ hist_cells[rh-1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c;
+
+ hist_cells[rh][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r;
+ hist_cells[rh][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r;
+
+ hist_cells[rh-1][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r;
+ hist_cells[rh-1][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r;
+ }
+ }
+ else
+ {
+ if (c < center_c)
+ {
+ hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c;
+ hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c;
+
+ hist_cells[rh+1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c;
+ hist_cells[rh+1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c;
+
+ hist_cells[rh][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r;
+ hist_cells[rh][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r;
+
+ hist_cells[rh+1][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r;
+ hist_cells[rh+1][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r;
+ }
+ else
+ {
+ hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c;
+ hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c;
+
+ hist_cells[rh+1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c;
+ hist_cells[rh+1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c;
+
+ hist_cells[rh][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r;
+ hist_cells[rh][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r;
+
+ hist_cells[rh+1][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r;
+ hist_cells[rh+1][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r;
+ }
+ }
+ }
+ }
+
+
+ }
+ }
+ }
+ }
+
+
+ // Now figure out how many blocks we should have. Note again that the hist_cells has a border of
+ // unused cells (thats where that -2 comes from).
+ num_block_rows = (hist_cells.nr()-2 - (block_size-1) + cell_stride - 1)/cell_stride;
+ num_block_cols = (hist_cells.nc()-2 - (block_size-1) + cell_stride - 1)/cell_stride;
+
+ }
+
+ unsigned long round_to_int(
+ double val
+ ) const
+ {
+ return static_cast<unsigned long>(std::floor(val + 0.5));
+ }
+
+ struct histogram
+ {
+ void zero()
+ {
+ for (unsigned long i = 0; i < num_orientation_bins; ++i)
+ values[i] = 0;
+ }
+ double values[num_orientation_bins];
+ };
+
+ array2d<histogram> hist_cells;
+
+ mutable descriptor_type des;
+
+ long num_block_rows;
+ long num_block_cols;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned long T4,
+ int T5,
+ int T6
+ >
+ void serialize (
+ const hog_image<T1,T2,T3,T4,T5,T6>& item,
+ std::ostream& out
+ )
+ {
+ item._PRIVATE_serialize(out);
+ }
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned long T4,
+ int T5,
+ int T6
+ >
+ void deserialize (
+ hog_image<T1,T2,T3,T4,T5,T6>& item,
+ std::istream& in
+ )
+ {
+ item._PRIVATE_deserialize(in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HoG_Hh_
+
diff --git a/ml/dlib/dlib/image_keypoint/hog_abstract.h b/ml/dlib/dlib/image_keypoint/hog_abstract.h
new file mode 100644
index 000000000..26c8cab64
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/hog_abstract.h
@@ -0,0 +1,335 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HoG_ABSTRACT_Hh_
+#ifdef DLIB_HoG_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../geometry.h"
+#include <cmath>
+
+namespace dlib
+{
+ enum
+ {
+ hog_no_interpolation,
+ hog_angle_interpolation,
+ hog_full_interpolation,
+ hog_signed_gradient,
+ hog_unsigned_gradient
+ };
+
+ template <
+ unsigned long cell_size_,
+ unsigned long block_size_,
+ unsigned long cell_stride_,
+ unsigned long num_orientation_bins_,
+ int gradient_type_,
+ int interpolation_type_
+ >
+ class hog_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE PARAMETERS
+ - cell_size_ > 1
+ - block_size_ > 0
+ - cell_stride_ > 0
+ - num_orientation_bins_ > 0
+ - gradient_type_ == hog_signed_gradient or hog_unsigned_gradient
+ - interpolation_type_ == hog_no_interpolation, hog_angle_interpolation, or
+ hog_full_interpolation
+
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing the image feature extraction algorithm
+ described in the following paper:
+ Histograms of Oriented Gradients for Human Detection
+ by Navneet Dalal and Bill Triggs
+
+
+ To summarize the technique, this object tiles non-overlapping cells over an
+ image. Each of these cells is a box that is cell_size by cell_size pixels
+ in size. Each cell contains an array of size num_orientation_bins. The array
+ in a cell is used to store a histogram of all the edge orientations contained
+ within the cell's image region.
+
+ Once the grid of cells and their histograms has been computed (via load())
+ you can obtain descriptors for each "block" in the image. A block is just a
+ group of cells and blocks are allowed to overlap. Each block is square and
+ made up of block_size*block_size cells. So when you call operator()(r,c)
+ what you obtain is a vector that is just a bunch of cell histograms that
+ have been concatenated (and length normalized).
+
+ The template arguments control the various parameters of this algorithm.
+
+ The interpolation_type parameter controls the amount of interpolation
+ that happens during the creation of the edge orientation histograms. It
+ varies from no interpolation at all to full spatial and angle interpolation.
+
+ Angle interpolation means that an edge doesn't just go into its nearest
+ histogram bin but instead gets interpolated into its two nearest neighbors.
+ Similarly, spatial interpolation means that an edge doesn't just go into
+ the cell it is in but it also contributes to nearby cells depending on how
+ close they are.
+
+ The gradient_type parameter controls how edge orientations are measured.
+ Consider the following ASCII art:
+ signed gradients: unsigned gradients:
+ /\ |
+ || |
+ <--- ----> ------+------
+ || |
+ \/ |
+
+ An image is full of gradients caused by edges between objects. The direction
+ of a gradient is determined by which end of it has pixels of highest intensity.
+ So for example, suppose you had a picture containing black and white stripes.
+ Then the magnitude of the gradient at each point in the image tells you if you
+ are on the edge of a stripe and the gradient's orientation tells you which
+ direction you have to move get into the white stripe.
+
+ Signed gradients preserve this direction information while unsigned gradients
+ do not. An unsigned gradient will only tell you the orientation of the stripe
+ but not which direction leads to the white stripe.
+
+ Finally, the cell_stride parameter controls how much overlap you get between
+ blocks. The maximum amount of overlap is obtained when cell_stride == 1.
+ At the other extreme, you would have no overlap if cell_stride == block_size.
+
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a hog_image object to many other threads.
+ In this case, it is safe to copy the configuration of a shared object so long
+ as no other operations are performed on it.
+ !*/
+
+ public:
+
+ const static unsigned long cell_size = cell_size_;
+ const static unsigned long block_size = block_size_;
+ const static unsigned long cell_stride = cell_stride_;
+ const static unsigned long num_orientation_bins = num_orientation_bins_;
+ const static int gradient_type = gradient_type_;
+ const static int interpolation_type = interpolation_type_;
+
+ const static long min_size = cell_size*block_size+2;
+
+ typedef matrix<double, block_size*block_size*num_orientation_bins, 1> descriptor_type;
+
+ hog_image (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void copy_configuration (
+ const hog_image& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two hog_image
+ objects H1 and H2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type is a dlib::matrix or something convertible to a matrix
+ via mat().
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ ensures
+ - if (img.nr() < min_size || img.nc() < min_size) then
+ - the image is too small so we don't compute anything on it
+ - #size() == 0
+ - else
+ - generates a HOG image from the given image.
+ - #size() > 0
+ !*/
+
+ inline void unload (
+ );
+ /*!
+ ensures
+ - #nr() == 0
+ - #nc() == 0
+ - clears only the state information which is populated by load(). For
+ example, let H be a hog_image object. Then consider the two sequences
+ of instructions:
+ Sequence 1:
+ H.load(img);
+ H.unload();
+ H.load(img);
+
+ Sequence 2:
+ H.load(img);
+ Both sequence 1 and sequence 2 should have the same effect on H.
+ !*/
+
+ inline size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ !*/
+
+ inline long nr (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this HOG image
+ !*/
+
+ inline long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this HOG image
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the number of dimensions in the feature vectors generated by
+ this object.
+ - In particular, returns the value block_size*block_size*num_orientation_bins
+ !*/
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ ensures
+ - returns the descriptor for the HOG block at the given row and column. This descriptor
+ will include information from a window that is located at get_block_rect(row,col) in
+ the original image given to load().
+ - The returned descriptor vector will have get_num_dimensions() elements.
+ !*/
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns a rectangle that tells you what part of the original image is associated
+ with a particular HOG block. That is, what part of the input image is associated
+ with (*this)(row,col).
+ - The returned rectangle will be cell_size*block_size pixels wide and tall.
+ !*/
+
+ const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned long T4,
+ int T5,
+ int T6
+ >
+ void serialize (
+ const hog_image<T1,T2,T3,T4,T5,T6>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ unsigned long T1,
+ unsigned long T2,
+ unsigned long T3,
+ unsigned long T4,
+ int T5,
+ int T6
+ >
+ void deserialize (
+ hog_image<T1,T2,T3,T4,T5,T6>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_HoG_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h
new file mode 100644
index 000000000..2ee45da2f
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h
@@ -0,0 +1,408 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_
+#define DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_
+
+#include "nearest_neighbor_feature_image_abstract.h"
+#include <vector>
+#include "../algs.h"
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class nearest_neighbor_feature_image : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - nn_feats.size() == 1
+
+ CONVENTION
+ - nn_feats.size() == 1
+
+ !*/
+
+ public:
+
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ nearest_neighbor_feature_image (
+ );
+
+ void clear (
+ );
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+
+ void copy_configuration (
+ const nearest_neighbor_feature_image& item
+ );
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+
+ inline size_t size (
+ ) const;
+
+ inline long nr (
+ ) const;
+
+ inline long nc (
+ ) const;
+
+ inline long get_num_dimensions (
+ ) const;
+
+ template <typename vector_type>
+ void set_basis (
+ const vector_type& new_basis
+ );
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+
+ inline const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+
+ inline const point image_to_feat_space (
+ const point& p
+ ) const;
+
+ inline const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+
+ inline const point feat_to_image_space (
+ const point& p
+ ) const;
+
+ inline const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+
+ template <typename T>
+ friend void serialize (
+ const nearest_neighbor_feature_image<T>& item,
+ std::ostream& out
+ );
+
+ template <typename T>
+ friend void deserialize (
+ nearest_neighbor_feature_image<T>& item,
+ std::istream& in
+ );
+
+ private:
+
+ array2d<unsigned long> feats;
+ feature_extractor fe;
+ std::vector<typename feature_extractor::descriptor_type> basis;
+
+ // This is a transient variable. It is just here so it doesn't have to be
+ // reallocated over and over inside operator()
+ mutable descriptor_type nn_feats;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const nearest_neighbor_feature_image<T>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.feats, out);
+ serialize(item.fe, out);
+ serialize(item.basis, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ nearest_neighbor_feature_image<T>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.feats, in);
+ deserialize(item.fe, in);
+ deserialize(item.basis, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// nearest_neighbor_feature_image member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ nearest_neighbor_feature_image<feature_extractor>::
+ nearest_neighbor_feature_image (
+ )
+ {
+ nn_feats.resize(1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void nearest_neighbor_feature_image<feature_extractor>::
+ clear (
+ )
+ {
+ feats.clear();
+ fe.clear();
+ basis.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void nearest_neighbor_feature_image<feature_extractor>::
+ copy_configuration (
+ const feature_extractor& item
+ )
+ {
+ fe.copy_configuration(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void nearest_neighbor_feature_image<feature_extractor>::
+ copy_configuration (
+ const nearest_neighbor_feature_image& item
+ )
+ {
+ fe.copy_configuration(item.fe);
+ basis = item.basis;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ template <
+ typename image_type
+ >
+ void nearest_neighbor_feature_image<feature_extractor>::
+ load (
+ const image_type& img
+ )
+ {
+ fe.load(img);
+
+ feats.set_size(fe.nr(), fe.nc());
+
+ // find the nearest neighbor for each feature vector and store the
+ // result in feats.
+ for (long r = 0; r < feats.nr(); ++r)
+ {
+ for (long c = 0; c < feats.nc(); ++c)
+ {
+ const typename feature_extractor::descriptor_type& local_feat = fe(r,c);
+
+ double best_dist = std::numeric_limits<double>::infinity();
+ unsigned long best_idx = 0;
+ for (unsigned long i = 0; i < basis.size(); ++i)
+ {
+ double dist = length_squared(local_feat - basis[i]);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ }
+ }
+
+ feats[r][c] = best_idx;
+ }
+ }
+
+ fe.unload();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ size_t nearest_neighbor_feature_image<feature_extractor>::
+ size (
+ ) const
+ {
+ return feats.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ long nearest_neighbor_feature_image<feature_extractor>::
+ nr (
+ ) const
+ {
+ return feats.nr();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ long nearest_neighbor_feature_image<feature_extractor>::
+ nc (
+ ) const
+ {
+ return feats.nc();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ long nearest_neighbor_feature_image<feature_extractor>::
+ get_num_dimensions (
+ ) const
+ {
+ return basis.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename feature_extractor>
+ template <typename vector_type>
+ void nearest_neighbor_feature_image<feature_extractor>::
+ set_basis (
+ const vector_type& new_basis
+ )
+ {
+ basis.assign(new_basis.begin(), new_basis.end());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const typename nearest_neighbor_feature_image<feature_extractor>::descriptor_type&
+ nearest_neighbor_feature_image<feature_extractor>::
+ operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type nearest_neighbor_feature_image::operator(row,col)"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ nn_feats[0] = std::make_pair(feats[row][col],1);
+ return nn_feats;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const rectangle nearest_neighbor_feature_image<feature_extractor>::
+ get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ return fe.get_block_rect(row,col);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const point nearest_neighbor_feature_image<feature_extractor>::
+ image_to_feat_space (
+ const point& p
+ ) const
+ {
+ return fe.image_to_feat_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const rectangle nearest_neighbor_feature_image<feature_extractor>::
+ image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.image_to_feat_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const point nearest_neighbor_feature_image<feature_extractor>::
+ feat_to_image_space (
+ const point& p
+ ) const
+ {
+ return fe.feat_to_image_space(p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ const rectangle nearest_neighbor_feature_image<feature_extractor>::
+ feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return fe.feat_to_image_space(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h
new file mode 100644
index 000000000..59d7cfeb7
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h
@@ -0,0 +1,254 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_
+#ifdef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_
+
+#include <vector>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class nearest_neighbor_feature_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ - must be an object with an interface compatible with dlib::hog_image
+
+ INITIAL VALUE
+ - size() == 0
+ - get_num_dimensions() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing image feature extraction. In
+ particular, it wraps another image feature extractor and converts
+ the wrapped image feature vectors into sparse indicator vectors. It does
+ this by finding the nearest neighbor for each feature vector and returning an
+ indicator vector that is zero everywhere except for the position indicated by
+ the nearest neighbor.
+
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a nearest_neighbor_feature_image object to many other
+ threads. In this case, it is safe to copy the configuration of a shared object so
+ long as no other operations are performed on it.
+
+
+ NOTATION
+ let BASE_FE denote the base feature_extractor object contained inside
+ the nearest_neighbor_feature_image.
+ !*/
+
+ public:
+
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ nearest_neighbor_feature_image (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+ /*!
+ ensures
+ - performs BASE_FE.copy_configuration(item)
+ !*/
+
+ void copy_configuration (
+ const nearest_neighbor_feature_image& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two
+ nearest_neighbor_feature_image objects H1 and H2, the following sequence
+ of instructions should always result in both of them having the exact
+ same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == any type that can be supplied to feature_extractor::load()
+ ensures
+ - performs BASE_FE.load(img)
+ i.e. does feature extraction. The features can be accessed using
+ operator() as defined below.
+ !*/
+
+ inline size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.size()
+ !*/
+
+ inline long nr (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nr()
+ !*/
+
+ inline long nc (
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.nc()
+ !*/
+
+ inline long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the feature vectors returned by operator().
+ In this case, this is the number of basis elements. That is, it is the number
+ of vectors given to the set_basis() member function.
+ !*/
+
+ template <typename vector_type>
+ void set_basis (
+ const vector_type& new_basis
+ );
+ /*!
+ ensures
+ - #get_num_dimensions() == new_basis.size()
+ - The operator() member function defined below will use new_basis to
+ determine nearest neighbors.
+ !*/
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ - get_num_dimensions() > 0
+ ensures
+ - determines which basis element is nearest to BASE_FE(row,col) and returns a sparse
+ indicator vector identifying the nearest neighbor.
+ - To be precise, this function returns a sparse vector V such that:
+ - V.size() == 1
+ - V[0].first == The basis element index for the basis vector nearest to BASE_FE(row,col).
+ "nearness" is determined using Euclidean distance.
+ - V[0].second == 1
+ !*/
+
+ inline const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.get_block_rect(row,col)
+ I.e. returns a rectangle that tells you what part of the original image is associated
+ with a particular feature vector.
+ !*/
+
+ inline const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(p)
+ I.e. Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ inline const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.image_to_feat_space(rect)
+ I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ inline const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(p)
+ I.e. returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ inline const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns BASE_FE.feat_to_image_space(rect)
+ I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const nearest_neighbor_feature_image<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename T>
+ void deserialize (
+ nearest_neighbor_feature_image<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/image_keypoint/poly_image.h b/ml/dlib/dlib/image_keypoint/poly_image.h
new file mode 100644
index 000000000..8abb912f0
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/poly_image.h
@@ -0,0 +1,649 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_POLY_ImAGE_Hh_
+#define DLIB_POLY_ImAGE_Hh_
+
+#include "poly_image_abstract.h"
+#include "build_separable_poly_filters.h"
+#include "../algs.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../geometry.h"
+#include <cmath>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long Downsample
+ >
+ class poly_image : noncopyable
+ {
+ COMPILE_TIME_ASSERT(Downsample >= 1);
+ public:
+ const static long downsample = Downsample;
+ typedef matrix<double, 0, 1> descriptor_type;
+
+ poly_image(
+ long order_,
+ long window_size_,
+ bool normalization = true,
+ bool rotation_invariance_ = false
+ )
+ {
+ setup(order_, window_size_);
+ set_uses_normalization(normalization);
+ set_is_rotationally_invariant(rotation_invariance_);
+ }
+
+ poly_image (
+ )
+ {
+ clear();
+ }
+
+ void clear (
+ )
+ {
+ normalize = true;
+ rotation_invariance = false;
+ poly_coef.clear();
+ order = 3;
+ window_size = 13;
+ border_size = (long)std::ceil(std::floor(window_size/2.0)/downsample);
+ num_rows = 0;
+ num_cols = 0;
+ filters = build_separable_poly_filters(order, window_size);
+ }
+
+ long get_order (
+ ) const
+ {
+ return order;
+ }
+
+ long get_window_size (
+ ) const
+ {
+ return window_size;
+ }
+
+ void setup (
+ long order_,
+ long window_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= order_ && order_ <= 6 &&
+ window_size_ >= 3 && (window_size_%2) == 1,
+ "\t descriptor_type poly_image::setup()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t order_: " << order_
+ << "\n\t window_size_: " << window_size_
+ << "\n\t this: " << this
+ );
+
+
+ poly_coef.clear();
+ order = order_;
+ window_size = window_size_;
+ border_size = (long)std::ceil(std::floor(window_size/2.0)/downsample);
+ num_rows = 0;
+ num_cols = 0;
+ filters = build_separable_poly_filters(order, window_size);
+ }
+
+ bool uses_normalization (
+ ) const { return normalize; }
+
+ void set_uses_normalization (
+ bool normalization
+ )
+ {
+ normalize = normalization;
+ }
+
+ bool is_rotationally_invariant (
+ ) const { return rotation_invariance; }
+
+ void set_is_rotationally_invariant (
+ bool rotation_invariance_
+ )
+ {
+ rotation_invariance = rotation_invariance_;
+ }
+
+ void copy_configuration (
+ const poly_image& item
+ )
+ {
+ normalize = item.normalize;
+ rotation_invariance = item.rotation_invariance;
+ if (order != item.order ||
+ window_size != item.window_size)
+ {
+ order = item.order;
+ window_size = item.window_size;
+ border_size = item.border_size;
+ filters = item.filters;
+ }
+ }
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false );
+
+ poly_coef.resize(get_num_dimensions());
+ des.set_size(get_num_dimensions());
+
+
+ if (normalize)
+ {
+ array2d<float> coef0;
+ rectangle rect = filter_image(img, coef0, filters[0]);
+ num_rows = rect.height();
+ num_cols = rect.width();
+
+ for (unsigned long i = 1; i < filters.size(); ++i)
+ {
+ filter_image(img, poly_coef[i-1], filters[i]);
+
+ // intensity normalize everything
+ for (long r = 0; r < coef0.nr(); ++r)
+ {
+ for (long c = 0; c < coef0.nc(); ++c)
+ {
+ if (coef0[r][c] >= 1)
+ poly_coef[i-1][r][c] /= coef0[r][c];
+ else
+ poly_coef[i-1][r][c] = 0;
+ }
+ }
+ }
+
+ if (rotation_invariance)
+ rotate_polys(rect);
+ }
+ else
+ {
+ rectangle rect;
+ for (unsigned long i = 0; i < filters.size(); ++i)
+ {
+ rect = filter_image(img, poly_coef[i], filters[i]);
+ }
+ num_rows = rect.height();
+ num_cols = rect.width();
+
+ if (rotation_invariance)
+ rotate_polys(rect);
+ }
+ }
+
+ void unload()
+ {
+ poly_coef.clear();
+ num_rows = 0;
+ num_cols = 0;
+ }
+
+ inline size_t size (
+ ) const { return static_cast<unsigned long>(nr()*nc()); }
+
+ inline long nr (
+ ) const { return num_rows; }
+
+ inline long nc (
+ ) const { return num_cols; }
+
+ long get_num_dimensions (
+ ) const
+ {
+ if (normalize)
+ {
+ // -1 because we discard the constant term of the polynomial.
+ return filters.size()-1;
+ }
+ else
+ {
+ return filters.size();
+ }
+ }
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= row && row < nr() &&
+ 0 <= col && col < nc(),
+ "\t descriptor_type poly_image::operator()()"
+ << "\n\t invalid row or col argument"
+ << "\n\t row: " << row
+ << "\n\t col: " << col
+ << "\n\t nr(): " << nr()
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ // add because of the zero border around the poly_coef images
+ row += border_size;
+ col += border_size;
+
+ for (long i = 0; i < des.size(); ++i)
+ des(i) = poly_coef[i][row][col];
+
+ return des;
+ }
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const
+ {
+ return centered_rect(Downsample*point(col+border_size, row+border_size),
+ window_size, window_size);
+ }
+
+ const point image_to_feat_space (
+ const point& p
+ ) const
+ {
+ return p/Downsample - point(border_size, border_size);
+ }
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ }
+
+ const point feat_to_image_space (
+ const point& p
+ ) const
+ {
+ return (p + point(border_size, border_size))*Downsample;
+ }
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const
+ {
+ return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ }
+
+
+
+ friend void serialize (const poly_image& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.poly_coef, out);
+ serialize(item.order, out);
+ serialize(item.window_size, out);
+ serialize(item.border_size, out);
+ serialize(item.num_rows, out);
+ serialize(item.num_cols, out);
+ serialize(item.normalize, out);
+ serialize(item.rotation_invariance, out);
+ serialize(item.filters, out);
+ }
+
+ friend void deserialize (poly_image& item, std::istream& in )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Unexpected version found while deserializing dlib::poly_image");
+
+ deserialize(item.poly_coef, in);
+ deserialize(item.order, in);
+ deserialize(item.window_size, in);
+ deserialize(item.border_size, in);
+ deserialize(item.num_rows, in);
+ deserialize(item.num_cols, in);
+ deserialize(item.normalize, in);
+ deserialize(item.rotation_invariance, in);
+ deserialize(item.filters, in);
+ }
+
+ private:
+
+ matrix<float,2,1> rotate_order_1 (
+ const matrix<float,2,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ matrix<double,2,2> M;
+ M = w1, w2,
+ w2, -w1;
+
+ matrix<double,2,1> x;
+ x = cos_theta,
+ sin_theta;
+
+ return matrix_cast<float>(M*x);
+ }
+
+ matrix<float,3,1> rotate_order_2 (
+ const matrix<float,3,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ const double w3 = w(2);
+ matrix<double,3,3> M;
+ M = w1, w2, w3,
+ w2, (2*w3-2*w1), -w2,
+ w3, -w2, w1;
+
+ matrix<double,3,1> x;
+ x = std::pow(cos_theta,2.0),
+ cos_theta*sin_theta,
+ std::pow(sin_theta,2.0);
+
+ return matrix_cast<float>(M*x);
+ }
+
+ matrix<float,4,1> rotate_order_3 (
+ const matrix<float,4,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ const double w3 = w(2);
+ const double w4 = w(3);
+ matrix<double,4,4> M;
+ M = w1, w2, w3, w4,
+ w2, (2*w3-3*w1), (3*w4-2*w2), -w3,
+ w3, (3*w4-2*w2), (3*w1-2*w3), w2,
+ w4, -w3, w2, -w1;
+
+ matrix<double,4,1> x;
+ x = std::pow(cos_theta,3.0),
+ std::pow(cos_theta,2.0)*sin_theta,
+ cos_theta*std::pow(sin_theta,2.0),
+ std::pow(sin_theta,3.0);
+
+ return matrix_cast<float>(M*x);
+ }
+
+ matrix<float,5,1> rotate_order_4 (
+ const matrix<float,5,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ const double w3 = w(2);
+ const double w4 = w(3);
+ const double w5 = w(4);
+ matrix<double,5,5> M;
+ M = w1, w2, w3, w4, w5,
+ w2, (2*w3-4*w1), (3*w4-3*w2), (4*w5-2*w3), -w4,
+ w3, (3*w4-3*w2), (6*w1-4*w3+6*w5), (3*w2-3*w4), w3,
+ w4, (4*w5-2*w3), (3*w2-3*w4), (2*w3-4*w1), -w2,
+ w5, -w4, w3, -w2, w1;
+
+ matrix<double,5,1> x;
+ x = std::pow(cos_theta,4.0),
+ std::pow(cos_theta,3.0)*sin_theta,
+ std::pow(cos_theta,2.0)*std::pow(sin_theta,2.0),
+ cos_theta*std::pow(sin_theta,3.0),
+ std::pow(sin_theta,4.0);
+
+ return matrix_cast<float>(M*x);
+ }
+
+ matrix<float,6,1> rotate_order_5 (
+ const matrix<float,6,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ const double w3 = w(2);
+ const double w4 = w(3);
+ const double w5 = w(4);
+ const double w6 = w(5);
+ matrix<double,6,6> M;
+ M = w1, w2, w3, w4, w5, w6,
+ w2, (2*w3-5*w1), (3*w4-4*w2), (4*w5-3*w3), (5*w6-2*w4), -w5,
+ w3, (3*w4-4*w2), (10*w1-6*w3+6*w5), (6*w2-6*w4+10*w6), (3*w3-4*w5), w4,
+ w4, (4*w5-3*w3), (6*w2-6*w4+10*w6), (-10*w1+6*w3-6*w5), (3*w4-4*w2), -w3,
+ w5, (5*w6-2*w4), (3*w3-4*w5), (3*w4-4*w2), (5*w1-2*w3), w2,
+ w6, -w5, w4, -w3, w2, -w1;
+
+ matrix<double,6,1> x;
+ x = std::pow(cos_theta,5.0),
+ std::pow(cos_theta,4.0)*sin_theta,
+ std::pow(cos_theta,3.0)*std::pow(sin_theta,2.0),
+ std::pow(cos_theta,2.0)*std::pow(sin_theta,3.0),
+ cos_theta*std::pow(sin_theta,4.0),
+ std::pow(sin_theta,5.0);
+
+ return matrix_cast<float>(M*x);
+ }
+
+ matrix<float,7,1> rotate_order_6 (
+ const matrix<float,7,1>& w,
+ double cos_theta,
+ double sin_theta
+ ) const
+ {
+ const double w1 = w(0);
+ const double w2 = w(1);
+ const double w3 = w(2);
+ const double w4 = w(3);
+ const double w5 = w(4);
+ const double w6 = w(5);
+ const double w7 = w(6);
+ matrix<double,7,7> M;
+ M = w1, w2, w3, w4, w5, w6, w7,
+ w2, (2*w3-6*w1), (3*w4-5*w2), (4*w5-4*w3), (5*w6-3*w4), (6*w7-2*w5), -w6,
+ w3, (3*w4-5*w2), (15*w1-8*w3+ 6*w5), ( 10*w2 -9*w4+10*w6), ( 6*w3-8*w5+15*w7), (3*w4-5*w6), w5,
+ w4, (4*w5-4*w3), (10*w2-9*w4+10*w6), (-20*w1+12*w3-12*w5+20*w7), (-10*w2+9*w4-10*w6), (4*w5-4*w3), -w4,
+ w5, (5*w6-3*w4), ( 6*w3-8*w5+15*w7), (-10*w2 +9*w4-10*w6), ( 15*w1-8*w3 +6*w5), (5*w2-3*w4), w3,
+ w6, (6*w7-2*w5), (3*w4-5*w6), (4*w5-4*w3), (5*w2-3*w4), (2*w3-6*w1), -w2,
+ w7, -w6, w5, -w4, w3, -w2, w1;
+
+ matrix<double,7,1> x;
+ x = std::pow(cos_theta,6.0),
+ std::pow(cos_theta,5.0)*sin_theta,
+ std::pow(cos_theta,4.0)*std::pow(sin_theta,2.0),
+ std::pow(cos_theta,3.0)*std::pow(sin_theta,3.0),
+ std::pow(cos_theta,2.0)*std::pow(sin_theta,4.0),
+ cos_theta*std::pow(sin_theta,5.0),
+ std::pow(sin_theta,6.0);
+
+ return matrix_cast<float>(M*x);
+ }
+
+ void rotate_polys (
+ const rectangle& rect
+ )
+ /*!
+ ensures
+ - rotates all the polynomials in poly_coef so that they are
+ rotationally invariant
+ !*/
+ {
+ // The idea here is to use a rotation matrix to rotate the
+ // coordinate system for the polynomial so that the x axis
+ // always lines up with the gradient vector (or direction of
+ // max curvature). This way we can make the representation
+ // rotation invariant.
+
+ // Note that the rotation matrix is given by:
+ // [ cos_theta -sin_theta ]
+ // [ sin_theta cos_theta ]
+
+ // need to offset poly_coef to get past the constant term if there isn't any normalization.
+ const int off = (normalize) ? 0 : 1;
+
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ dlib::vector<double,2> g(poly_coef[off+0][r][c],
+ poly_coef[off+1][r][c]);
+
+ const double len = g.length();
+ if (len != 0)
+ {
+ g /= len;
+ }
+ else
+ {
+ g.x() = 1;
+ g.y() = 0;
+ }
+ // since we normalized g we can find the sin/cos of its angle easily.
+ const double cos_theta = g.x();
+ const double sin_theta = g.y();
+
+ if (order >= 1)
+ {
+ matrix<float,2,1> w;
+ w = poly_coef[off+0][r][c],
+ poly_coef[off+1][r][c];
+ w = rotate_order_1(w, cos_theta, sin_theta);
+ poly_coef[off+0][r][c] = w(0);
+ poly_coef[off+1][r][c] = w(1);
+ }
+ if (order >= 2)
+ {
+ matrix<float,3,1> w;
+ w = poly_coef[off+2][r][c],
+ poly_coef[off+3][r][c],
+ poly_coef[off+4][r][c];
+ w = rotate_order_2(w, cos_theta, sin_theta);
+ poly_coef[off+2][r][c] = w(0);
+ poly_coef[off+3][r][c] = w(1);
+ poly_coef[off+4][r][c] = w(2);
+ }
+ if (order >= 3)
+ {
+ matrix<float,4,1> w;
+ w = poly_coef[off+5][r][c],
+ poly_coef[off+6][r][c],
+ poly_coef[off+7][r][c],
+ poly_coef[off+8][r][c];
+ w = rotate_order_3(w, cos_theta, sin_theta);
+ poly_coef[off+5][r][c] = w(0);
+ poly_coef[off+6][r][c] = w(1);
+ poly_coef[off+7][r][c] = w(2);
+ poly_coef[off+8][r][c] = w(3);
+ }
+ if (order >= 4)
+ {
+ matrix<float,5,1> w;
+ w = poly_coef[off+9][r][c],
+ poly_coef[off+10][r][c],
+ poly_coef[off+11][r][c],
+ poly_coef[off+12][r][c],
+ poly_coef[off+13][r][c];
+ w = rotate_order_4(w, cos_theta, sin_theta);
+ poly_coef[off+9][r][c] = w(0);
+ poly_coef[off+10][r][c] = w(1);
+ poly_coef[off+11][r][c] = w(2);
+ poly_coef[off+12][r][c] = w(3);
+ poly_coef[off+13][r][c] = w(4);
+ }
+ if (order >= 5)
+ {
+ matrix<float,6,1> w;
+ w = poly_coef[off+14][r][c],
+ poly_coef[off+15][r][c],
+ poly_coef[off+16][r][c],
+ poly_coef[off+17][r][c],
+ poly_coef[off+18][r][c],
+ poly_coef[off+19][r][c];
+ w = rotate_order_5(w, cos_theta, sin_theta);
+ poly_coef[off+14][r][c] = w(0);
+ poly_coef[off+15][r][c] = w(1);
+ poly_coef[off+16][r][c] = w(2);
+ poly_coef[off+17][r][c] = w(3);
+ poly_coef[off+18][r][c] = w(4);
+ poly_coef[off+19][r][c] = w(5);
+ }
+ if (order >= 6)
+ {
+ matrix<float,7,1> w;
+ w = poly_coef[off+20][r][c],
+ poly_coef[off+21][r][c],
+ poly_coef[off+22][r][c],
+ poly_coef[off+23][r][c],
+ poly_coef[off+24][r][c],
+ poly_coef[off+25][r][c],
+ poly_coef[off+26][r][c];
+ w = rotate_order_6(w, cos_theta, sin_theta);
+ poly_coef[off+20][r][c] = w(0);
+ poly_coef[off+21][r][c] = w(1);
+ poly_coef[off+22][r][c] = w(2);
+ poly_coef[off+23][r][c] = w(3);
+ poly_coef[off+24][r][c] = w(4);
+ poly_coef[off+25][r][c] = w(5);
+ poly_coef[off+26][r][c] = w(6);
+ }
+ }
+ }
+
+ }
+
+ template <typename image_type>
+ rectangle filter_image (
+ const image_type& img,
+ array2d<float>& out,
+ const std::vector<separable_filter_type>& filter
+ ) const
+ {
+ rectangle rect = spatially_filter_image_separable_down(downsample, img, out, filter[0].first, filter[0].second);
+ for (unsigned long i = 1; i < filter.size(); ++i)
+ {
+ spatially_filter_image_separable_down(downsample, img, out, filter[i].first, filter[i].second, 1, false, true);
+ }
+ return rect;
+ }
+
+
+
+ std::vector<std::vector<separable_filter_type> > filters;
+
+ dlib::array<array2d<float> > poly_coef;
+ long order;
+ long window_size;
+ long border_size;
+ long num_rows;
+ long num_cols;
+
+ bool normalize;
+ bool rotation_invariance;
+
+ mutable descriptor_type des;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_POLY_ImAGE_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/poly_image_abstract.h b/ml/dlib/dlib/image_keypoint/poly_image_abstract.h
new file mode 100644
index 000000000..2f17bb31e
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/poly_image_abstract.h
@@ -0,0 +1,335 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_POLY_ImAGE_ABSTRACT_Hh_
+#ifdef DLIB_POLY_ImAGE_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "../matrix.h"
+#include "../geometry/rectangle_abstract.h"
+#include <cmath>
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ template <
+ long Downsample
+ >
+ class poly_image : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON TEMPLATE PARAMETERS
+ - Downsample >= 1
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for extracting local feature descriptors from an image.
+ In particular, it fits polynomials to local pixel patches and allows you to
+ query the coefficients of these polynomials. Additionally, the coefficients
+ may be intensity normalized by dividing them by the constant term of the fitted
+ polynomial and then the constant term is discarded.
+
+ Finally, the user can specify a downsampling rate. If the template argument
+ Downsample is set to 1 then feature extraction is performed at every pixel of
+ an input image (except for a small area around the image border). However,
+ if Downsample is set to 2 then feature extraction is only performed at every
+ other pixel location. More generally, if Downsample is set to N then feature
+ extraction is performed only every N pixels.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a poly_image object to many other threads.
+ In this case, it is safe to copy the configuration of a shared object so long
+ as no other operations are performed on it.
+ !*/
+
+ public:
+
+ typedef matrix<double, 0, 1> descriptor_type;
+ const static long downsample = Downsample;
+
+ poly_image (
+ );
+ /*!
+ ensures
+ - #get_order() == 3
+ - #get_window_size() == 13
+ - #size() == 0
+ - #uses_normalization() == true
+ - #is_rotationally_invariant() == false
+ !*/
+
+ poly_image(
+ long order,
+ long window_size,
+ bool normalization = true,
+ bool rotation_invariance = false
+ );
+ /*!
+ requires
+ - 1 <= order <= 6
+ - window_size >= 3 && window_size is odd
+ ensures
+ - #get_order() == order
+ - #get_window_size() == window_size
+ - #size() == 0
+ - #uses_normalization() == normalization
+ - #is_rotationally_invariant() == rotation_invariance
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object will have its initial value
+ !*/
+
+ void setup (
+ long order,
+ long window_size
+ );
+ /*!
+ requires
+ - 1 <= order <= 6
+ - window_size >= 3 && window_size is odd
+ ensures
+ - #get_order() == order
+ - #get_window_size() == window_size
+ !*/
+
+ long get_order (
+ ) const;
+ /*!
+ ensures
+ - returns the order of the polynomial that will be fitted to
+ each local pixel patch during feature extraction.
+ !*/
+
+ long get_window_size (
+ ) const;
+ /*!
+ ensures
+ - returns the size of the window used for local feature extraction.
+ This is the width and height of the window in pixels.
+ !*/
+
+ bool uses_normalization (
+ ) const;
+ /*!
+ ensures
+ - returns true if the polynomial coefficients are intensity normalized
+ and false otherwise.
+ !*/
+
+ void set_uses_normalization (
+ bool normalization
+ );
+ /*!
+ ensures
+ - #uses_normalization() == normalization
+ !*/
+
+ bool is_rotationally_invariant (
+ );
+ /*!
+ ensures
+ - returns true if the feature extractor will adjust the output so that it
+ is rotationally invariant. This is done by rotating each patch such that
+ the gradient vector always points in the same direction.
+ !*/
+
+ void set_is_rotationally_invariant (
+ bool rotation_invariance
+ );
+ /*!
+ ensures
+ - #is_rotationally_invariant() == rotation_invariance
+ !*/
+
+ void copy_configuration (
+ const poly_image& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two poly_image
+ objects H1 and H2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ H2.copy_configuration(H1);
+ H1.load(img);
+ H2.load(img);
+ !*/
+
+ template <
+ typename image_type
+ >
+ inline void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ ensures
+ - Performs the feature extraction described in the WHAT THIS OBJECT REPRESENTS
+ section above. This means after load() finishes you can call (*this)(row,col)
+ to obtain the polynomial coefficients for an order get_order() polynomial which
+ was fitted to the image patch get_block_rect(row,col).
+ - #size() > 0
+ !*/
+
+ void unload(
+ );
+ /*!
+ ensures
+ - #nr() == 0
+ - #nc() == 0
+ - clears only the state information which is populated by load(). For
+ example, let H be a poly_image object. Then consider the two sequences
+ of instructions:
+ Sequence 1:
+ H.load(img);
+ H.unload();
+ H.load(img);
+
+ Sequence 2:
+ H.load(img);
+ Both sequence 1 and sequence 2 should have the same effect on H.
+ !*/
+
+ inline size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ !*/
+
+ inline long nr (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this polynomial feature image
+ !*/
+
+ inline long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this polynomial feature image
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the number of dimensions in the feature vectors generated by
+ this object.
+ - In this case, this will be the number of coefficients in an order
+ get_order() polynomial, except for the constant term of the polynomial
+ if uses_normalization() == true.
+ !*/
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ ensures
+ - returns the descriptor for the polynomial filtering block at the given row and column.
+ This vector will contain the polynomial coefficients for a polynomial fitted to the
+ image patch located at get_block_rect(row,col) in the original image given to load().
+ - The returned descriptor vector will have get_num_dimensions() elements.
+ !*/
+
+ const rectangle get_block_rect (
+ long row,
+ long col
+ ) const;
+ /*!
+ ensures
+ - returns a rectangle that tells you what part of the original image is associated
+ with a particular polynomial filter block. That is, what part of the input image
+ is associated with (*this)(row,col).
+ - The returned rectangle will be get_window_size() pixels wide and tall.
+ !*/
+
+ const point image_to_feat_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - Each local feature is extracted from a certain point in the input image.
+ This function returns the identity of the local feature corresponding
+ to the image location p. Or in other words, let P == image_to_feat_space(p),
+ then (*this)(P.y(),P.x()) == the local feature closest to, or centered at,
+ the point p in the input image. Note that some image points might not have
+ corresponding feature locations. E.g. border points or points outside the
+ image. In these cases the returned point will be outside get_rect(*this).
+ !*/
+
+ const rectangle image_to_feat_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner()));
+ (i.e. maps a rectangle from image space to feature space)
+ !*/
+
+ const point feat_to_image_space (
+ const point& p
+ ) const;
+ /*!
+ ensures
+ - returns the location in the input image space corresponding to the center
+ of the local feature at point p. In other words, this function computes
+ the inverse of image_to_feat_space(). Note that it may only do so approximately,
+ since more than one image location might correspond to the same local feature.
+ That is, image_to_feat_space() might not be invertible so this function gives
+ the closest possible result.
+ !*/
+
+ const rectangle feat_to_image_space (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner()));
+ (i.e. maps a rectangle from feature space to image space)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long downsample
+ >
+ void serialize (
+ const poly_image<downsample>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ long downsample
+ >
+ void deserialize (
+ poly_image<downsample>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_POLY_ImAGE_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_keypoint/surf.h b/ml/dlib/dlib/image_keypoint/surf.h
new file mode 100644
index 000000000..d12b30840
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/surf.h
@@ -0,0 +1,295 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SURf_H_
+#define DLIB_SURf_H_
+
+#include "surf_abstract.h"
+#include "hessian_pyramid.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct surf_point
+ {
+ interest_point p;
+ matrix<double,64,1> des;
+ double angle;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize(
+ const surf_point& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.p,out);
+ serialize(item.des,out);
+ serialize(item.angle,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type surf_point");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize(
+ surf_point& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.p,in);
+ deserialize(item.des,in);
+ deserialize(item.angle,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type surf_point");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double gaussian (double x, double y, double sig)
+ {
+ DLIB_ASSERT(sig > 0,
+ "\tdouble gaussian()"
+ << "\n\t sig must be bigger than 0"
+ << "\n\t sig: " << sig
+ );
+ const double sqrt_2_pi = 2.5066282746310002416123552393401041626930;
+ return 1.0/(sig*sqrt_2_pi) * std::exp( -(x*x + y*y)/(2*sig*sig));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type, typename T>
+ double compute_dominant_angle (
+ const integral_image_type& img,
+ const dlib::vector<T,2>& center,
+ const double& scale
+ )
+ {
+ DLIB_ASSERT(get_rect(img).contains(centered_rect(center, (unsigned long)(17*scale),(unsigned long)(17*scale))) == true &&
+ scale > 0,
+ "\tdouble compute_dominant_angle(img, center, scale)"
+ << "\n\tAll arguments to this function must be > 0"
+ << "\n\t get_rect(img): " << get_rect(img)
+ << "\n\t center: " << center
+ << "\n\t scale: " << scale
+ );
+
+
+ std::vector<double> ang;
+ std::vector<dlib::vector<double,2> > samples;
+
+ const long sc = static_cast<long>(scale+0.5);
+
+ // accumulate a bunch of angle and vector samples
+ dlib::vector<double,2> vect;
+ for (long r = -6; r <= 6; ++r)
+ {
+ for (long c = -6; c <= 6; ++c)
+ {
+ if (r*r + c*c < 36)
+ {
+ // compute a Gaussian weighted gradient and the gradient's angle.
+ const double gauss = gaussian(c,r, 2.5);
+ vect.x() = gauss*haar_x(img, sc*point(c,r)+center, 4*sc);
+ vect.y() = gauss*haar_y(img, sc*point(c,r)+center, 4*sc);
+ samples.push_back(vect);
+ ang.push_back(atan2(vect.y(), vect.x()));
+ }
+ }
+ }
+
+
+ // now find the dominant direction
+ double max_length = 0;
+ double best_ang = 0;
+ // look at a bunch of pie shaped slices of a circle
+ const long slices = 45;
+ const double ang_step = (2*pi)/slices;
+ for (long ang_i = 0; ang_i < slices; ++ang_i)
+ {
+ // compute the bounding angles
+ double ang1 = ang_step*ang_i - pi;
+ double ang2 = ang1 + pi/3;
+
+
+ // compute sum of all vectors that are within the above two angles
+ vect.x() = 0;
+ vect.y() = 0;
+ for (unsigned long i = 0; i < ang.size(); ++i)
+ {
+ if (ang1 <= ang[i] && ang[i] <= ang2)
+ {
+ vect += samples[i];
+ }
+ else if (ang2 > pi && (ang[i] >= ang1 || ang[i] <= (-2*pi+ang2)))
+ {
+ vect += samples[i];
+ }
+ }
+
+
+ // record the angle of the best vectors
+ if (length_squared(vect) > max_length)
+ {
+ max_length = length_squared(vect);
+ best_ang = atan2(vect.y(), vect.x());
+ }
+ }
+
+ return best_ang;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type, typename T, typename MM, typename L>
+ void compute_surf_descriptor (
+ const integral_image_type& img,
+ const dlib::vector<T,2>& center,
+ const double scale,
+ const double angle,
+ matrix<double,64,1,MM,L>& des
+ )
+ {
+ DLIB_ASSERT(get_rect(img).contains(centered_rect(center, (unsigned long)(32*scale),(unsigned long)(32*scale))) == true &&
+ scale > 0,
+ "\tvoid compute_surf_descriptor(img, center, scale, angle)"
+ << "\n\tAll arguments to this function must be > 0"
+ << "\n\t get_rect(img): " << get_rect(img)
+ << "\n\t center: " << center
+ << "\n\t scale: " << scale
+ );
+
+ point_rotator rot(angle);
+ point_rotator inv_rot(-angle);
+
+ const long sc = static_cast<long>(scale+0.5);
+ long count = 0;
+
+ // loop over the 4x4 grid of histogram buckets
+ for (long r = -10; r < 10; r += 5)
+ {
+ for (long c = -10; c < 10; c += 5)
+ {
+ dlib::vector<double,2> vect, abs_vect, temp;
+
+ // now loop over 25 points in this bucket and sum their features. Note
+ // that we include 1 pixels worth of padding around the outside of each 5x5
+ // cell. This is to help neighboring cells interpolate their counts into
+ // each other a little bit.
+ for (long y = r-1; y < r+5+1; ++y)
+ {
+ if (y < -10 || y >= 10)
+ continue;
+ for (long x = c-1; x < c+5+1; ++x)
+ {
+ if (x < -10 || x >= 10)
+ continue;
+
+ // get the rotated point for this extraction point
+ point p(rot(point(x,y)*scale) + center);
+
+ // Give points farther from the center of the bucket a lower weight.
+ const long center_r = r+2;
+ const long center_c = c+2;
+ const double weight = 1.0/(4+std::abs(center_r-y) + std::abs(center_c-x));
+
+ temp.x() = weight*haar_x(img, p, 2*sc);
+ temp.y() = weight*haar_y(img, p, 2*sc);
+
+ // rotate this vector into alignment with the surf descriptor box
+ temp = inv_rot(temp);
+
+ vect += temp;
+ abs_vect += abs(temp);
+ }
+ }
+
+ des(count++) = vect.x();
+ des(count++) = vect.y();
+ des(count++) = abs_vect.x();
+ des(count++) = abs_vect.y();
+ }
+ }
+
+ // Return the length normalized descriptor. Add a small number
+ // to guard against division by zero.
+ const double len = length(des) + 1e-7;
+ des = des/len;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ const std::vector<surf_point> get_surf_points (
+ const image_type& img,
+ long max_points = 10000,
+ double detection_threshold = 30.0
+ )
+ {
+ DLIB_ASSERT(max_points > 0 && detection_threshold >= 0,
+ "\t std::vector<surf_point> get_surf_points()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t max_points: " << max_points
+ << "\n\t detection_threshold: " << detection_threshold
+ );
+
+ // Figure out the proper scalar type we should use to work with these pixels.
+ typedef typename pixel_traits<typename image_traits<image_type>::pixel_type>::basic_pixel_type bp_type;
+ typedef typename promote<bp_type>::type working_pixel_type;
+
+ // make an integral image first
+ integral_image_generic<working_pixel_type> int_img;
+ int_img.load(img);
+
+ // now make a hessian pyramid
+ hessian_pyramid pyr;
+ pyr.build_pyramid(int_img, 4, 6, 2);
+
+ // now get all the interest points from the hessian pyramid
+ std::vector<interest_point> points;
+ get_interest_points(pyr, detection_threshold, points);
+ std::vector<surf_point> spoints;
+
+ // sort all the points by how strong their detect is
+ std::sort(points.rbegin(), points.rend());
+
+ // now extract SURF descriptors for the points
+ surf_point sp;
+ for (unsigned long i = 0; i < std::min((size_t)max_points,points.size()); ++i)
+ {
+ // ignore points that are close to the edge of the image
+ const double border = 32;
+ const unsigned long border_size = static_cast<unsigned long>(border*points[i].scale);
+ if (get_rect(int_img).contains(centered_rect(points[i].center, border_size, border_size)))
+ {
+ sp.angle = compute_dominant_angle(int_img, points[i].center, points[i].scale);
+ compute_surf_descriptor(int_img, points[i].center, points[i].scale, sp.angle, sp.des);
+ sp.p = points[i];
+
+ spoints.push_back(sp);
+ }
+ }
+
+ return spoints;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SURf_H_
+
diff --git a/ml/dlib/dlib/image_keypoint/surf_abstract.h b/ml/dlib/dlib/image_keypoint/surf_abstract.h
new file mode 100644
index 000000000..e539f3e24
--- /dev/null
+++ b/ml/dlib/dlib/image_keypoint/surf_abstract.h
@@ -0,0 +1,163 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SURf_ABSTRACT_H_
+#ifdef DLIB_SURf_ABSTRACT_H_
+
+#include "hessian_pyramid_abstract.h"
+#include "../geometry/vector_abstract.h"
+#include "../matrix/matrix_abstract.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ /*
+ The functions in this file implement the components of the SURF algorithm
+ for extracting scale invariant feature descriptors from images.
+
+ For the full story on what this algorithm does and how it works
+ you should refer to the following papers.
+
+ This is the original paper which introduced the algorithm:
+ SURF: Speeded Up Robust Features
+ By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool
+
+ This paper provides a nice detailed overview of how the algorithm works:
+ Notes on the OpenSURF Library by Christopher Evans
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ double gaussian (
+ double x,
+ double y,
+ double sig
+ );
+ /*!
+ requires
+ - sig > 0
+ ensures
+ - computes and returns the value of a 2D Gaussian function with mean 0
+ and standard deviation sig at the given (x,y) point.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type, typename T>
+ double compute_dominant_angle (
+ const integral_image_type& img,
+ const dlib::vector<T,2>& center,
+ const double& scale
+ );
+ /*!
+ requires
+ - integral_image_type == an object such as dlib::integral_image or another
+ type that implements the interface defined in image_transforms/integral_image_abstract.h
+ - scale > 0
+ - get_rect(img).contains(centered_rect(center, 17*scale, 17*scale)) == true
+ (i.e. center can't be within 17*scale pixels of the edge of the image)
+ ensures
+ - computes and returns the dominant angle (i.e. the angle of the dominant gradient)
+ at the given center point and scale in img.
+ - The returned angle is in radians. Specifically, if the angle is described by
+ a vector vect then the angle is exactly the value of std::atan2(vect.y(), vect.x())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type, typename T, typename MM, typename L>
+ void compute_surf_descriptor (
+ const integral_image_type& img,
+ const dlib::vector<T,2>& center,
+ const double scale,
+ const double angle,
+ matrix<double,64,1,MM,L>& des
+ )
+ /*!
+ requires
+ - integral_image_type == an object such as dlib::integral_image or another
+ type that implements the interface defined in image_transforms/integral_image_abstract.h
+ - scale > 0
+ - get_rect(img).contains(centered_rect(center, 32*scale, 32*scale)) == true
+ (i.e. center can't be within 32*scale pixels of the edge of the image)
+ ensures
+ - computes the 64 dimensional SURF descriptor vector of a box centered
+ at the given center point, tilted at an angle determined by the given
+ angle, and sized according to the given scale.
+ - #des == the computed SURF descriptor vector extracted from the img object.
+ - The angle is measured in radians and measures the degree of counter-clockwise
+ rotation around the center point. This is the same kind of rotation as is
+ performed by the dlib::rotate_point() function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct surf_point
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a detected SURF point. The meanings of
+ its fields are defined below in the get_surf_points() function.
+ !*/
+
+ interest_point p;
+ matrix<double,64,1> des;
+ double angle;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const surf_point& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ surf_point& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ const std::vector<surf_point> get_surf_points (
+ const image_type& img,
+ long max_points = 10000,
+ double detection_threshold = 30.0
+ );
+ /*!
+ requires
+ - max_points > 0
+ - detection_threshold >= 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - Let P denote the type of pixel in img, then we require:
+ - pixel_traits<P>::has_alpha == false
+ ensures
+ - This function runs the complete SURF algorithm on the given input image and
+ returns the points it found.
+ - returns a vector V such that:
+ - V.size() <= max_points
+ - for all valid i:
+ - V[i] == a SURF point found in the given input image img
+ - V[i].p == the interest_point extracted from the hessian pyramid for this
+ SURF point.
+ - V[i].des == the SURF descriptor for this point (calculated using
+ compute_surf_descriptor())
+ - V[i].angle == the angle of the SURF box at this point (calculated using
+ compute_dominant_angle())
+ - V[i].p.score >= detection_threshold
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SURf_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/image_loader/image_loader.h b/ml/dlib/dlib/image_loader/image_loader.h
new file mode 100644
index 000000000..4fa29dab2
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/image_loader.h
@@ -0,0 +1,863 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_LOADEr_
+#define DLIB_IMAGE_LOADEr_
+
+#include "image_loader_abstract.h"
+#include <iostream>
+#include <sstream>
+#include "../algs.h"
+#include "../pixel.h"
+#include "../image_saver/dng_shared.h"
+#include "../entropy_decoder_model.h"
+#include "../entropy_decoder.h"
+#include "../uintn.h"
+#include "../image_transforms/assign_image.h"
+#include <algorithm>
+#include "../vectorstream.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_load_error : public dlib::error {
+ public: image_load_error(const std::string& str) : error(EIMAGE_LOAD,str){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_bmp (
+ image_type& image_,
+ std::istream& in_
+ )
+ {
+ image_view<image_type> image(image_);
+ try
+ {
+ unsigned long bytes_read_so_far = 0;
+ unsigned long bfSize;
+ unsigned long bfOffBits;
+ unsigned long bfReserved;
+ unsigned long biSize;
+ unsigned long biWidth;
+ unsigned long biHeight;
+ unsigned short biBitCount;
+ unsigned long biCompression;
+ /*
+ unsigned long biSizeImage;
+ unsigned long biClrUsed;
+ unsigned long biClrImportant;
+ */
+ unsigned long a, b, c, d, i;
+
+ using namespace std;
+
+ streambuf& in = *in_.rdbuf();
+ // streamsize num;
+ unsigned char buf[100];
+
+
+ // first make sure the BMP starts with BM
+ if (in.sgetn(reinterpret_cast<char*>(buf),2) != 2)
+ throw image_load_error("bmp load error 1: header error");
+ bytes_read_so_far += 2;
+
+ if (buf[0] != 'B' || buf[1] != 'M')
+ throw image_load_error("bmp load error 2: header error");
+
+ // now read the BITMAPFILEHEADER
+ if (in.sgetn(reinterpret_cast<char*>(buf),12) != 12)
+ throw image_load_error("bmp load error 3: header error");
+
+ bytes_read_so_far += 12;
+
+ i = 0;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ bfSize = a | (b<<8) | (c<<16) | (d<<24);
+
+ i = 4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ bfReserved = a | (b<<8) | (c<<16) | (d<<24);
+
+ i = 8;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ bfOffBits = a | (b<<8) | (c<<16) | (d<<24);
+
+ // if this value isn't zero then there is something wrong
+ // with this bitmap.
+ if (bfReserved != 0)
+ throw image_load_error("bmp load error 4: reserved area not zero");
+
+
+ // load the BITMAPINFOHEADER
+ if (in.sgetn(reinterpret_cast<char*>(buf),40) != 40)
+ throw image_load_error("bmp load error 5: file too short");
+ bytes_read_so_far += 40;
+
+
+ i = 0;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biSize = a | (b<<8) | (c<<16) | (d<<24);
+
+ i += 4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biWidth = a | (b<<8) | (c<<16) | (d<<24);
+
+ i += 4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biHeight = a | (b<<8) | (c<<16) | (d<<24);
+
+ i += 4+2;
+ a = buf[i]; b = buf[i+1];
+ biBitCount = static_cast<unsigned short>(a | (b<<8));
+
+ i += 2;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biCompression = a | (b<<8) | (c<<16) | (d<<24);
+
+ /*
+ i += 4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biSizeImage = a | (b<<8) | (c<<16) | (d<<24);
+
+ i += 4+4+4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biClrUsed = a | (b<<8) | (c<<16) | (d<<24);
+
+ i += 4;
+ a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3];
+ biClrImportant = a | (b<<8) | (c<<16) | (d<<24);
+ */
+
+
+ if (biSize != 40)
+ throw image_load_error("bmp load error 6: header too small");
+
+ // read and discard any extra bytes that are part of the header
+ if (biSize > 40)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),biSize-40) != static_cast<long>(biSize - 40))
+ {
+ throw image_load_error("bmp load error 7: header too small");
+ }
+ bytes_read_so_far += biSize-40;
+ }
+
+ image.set_size(biHeight, biWidth);
+
+ switch (biBitCount)
+ {
+ case 1:
+ {
+ // figure out how the pixels are packed
+ long padding;
+ if (bfSize - bfOffBits == biWidth*biHeight/8)
+ padding = 0;
+ else
+ padding = 4 - ((biWidth+7)/8)%4;
+
+ const unsigned int palette_size = 2;
+ unsigned char red[palette_size];
+ unsigned char green[palette_size];
+ unsigned char blue[palette_size];
+
+ for (unsigned int i = 0; i < palette_size; ++i)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),4) != 4)
+ {
+ throw image_load_error("bmp load error 20: color palette missing");
+ }
+ bytes_read_so_far += 4;
+ blue[i] = buf[0];
+ green[i] = buf[1];
+ red[i] = buf[2];
+ }
+
+
+ // seek to the start of the pixel data
+ while (bytes_read_so_far != bfOffBits)
+ {
+ const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf));
+ if (in.sgetn(reinterpret_cast<char*>(buf), to_read) != to_read)
+ {
+ throw image_load_error("bmp load error: missing data");
+ }
+ bytes_read_so_far += to_read;
+ }
+
+ // load the image data
+ for (long row = biHeight-1; row >= 0; --row)
+ {
+ for (unsigned long col = 0; col < biWidth; col+=8)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),1) != 1)
+ {
+ throw image_load_error("bmp load error 21.6: file too short");
+ }
+
+ unsigned char pixels[8];
+
+ pixels[0] = (buf[0]>>7);
+ pixels[1] = ((buf[0]>>6)&0x01);
+ pixels[2] = ((buf[0]>>5)&0x01);
+ pixels[3] = ((buf[0]>>4)&0x01);
+ pixels[4] = ((buf[0]>>3)&0x01);
+ pixels[5] = ((buf[0]>>2)&0x01);
+ pixels[6] = ((buf[0]>>1)&0x01);
+ pixels[7] = ((buf[0])&0x01);
+
+ for (int i = 0; i < 8 && col+i < biWidth; ++i)
+ {
+ rgb_pixel p;
+ p.red = red[pixels[i]];
+ p.green = green[pixels[i]];
+ p.blue = blue[pixels[i]];
+ assign_pixel(image[row][col+i],p);
+ }
+ }
+ if (in.sgetn(reinterpret_cast<char*>(buf),padding) != padding)
+ throw image_load_error("bmp load error 9: file too short");
+ }
+
+
+
+ } break;
+ case 4:
+ {
+ // figure out how the pixels are packed
+ long padding;
+ if (bfSize - bfOffBits == biWidth*biHeight/2)
+ padding = 0;
+ else
+ padding = 4 - ((biWidth+1)/2)%4;
+
+ const unsigned int palette_size = 16;
+ unsigned char red[palette_size];
+ unsigned char green[palette_size];
+ unsigned char blue[palette_size];
+
+ for (unsigned int i = 0; i < palette_size; ++i)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),4) != 4)
+ {
+ throw image_load_error("bmp load error 20: color palette missing");
+ }
+ bytes_read_so_far += 4;
+ blue[i] = buf[0];
+ green[i] = buf[1];
+ red[i] = buf[2];
+ }
+
+
+ // seek to the start of the pixel data
+ while (bytes_read_so_far != bfOffBits)
+ {
+ const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf));
+ if (in.sgetn(reinterpret_cast<char*>(buf), to_read) != to_read)
+ {
+ throw image_load_error("bmp load error: missing data");
+ }
+ bytes_read_so_far += to_read;
+ }
+
+ // load the image data
+ for (long row = biHeight-1; row >= 0; --row)
+ {
+ for (unsigned long col = 0; col < biWidth; col+=2)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),1) != 1)
+ {
+ throw image_load_error("bmp load error 21.7: file too short");
+ }
+
+ const unsigned char pixel1 = (buf[0]>>4);
+ const unsigned char pixel2 = (buf[0]&0x0F);
+
+ rgb_pixel p;
+ p.red = red[pixel1];
+ p.green = green[pixel1];
+ p.blue = blue[pixel1];
+ assign_pixel(image[row][col], p);
+
+ if (col+1 < biWidth)
+ {
+ p.red = red[pixel2];
+ p.green = green[pixel2];
+ p.blue = blue[pixel2];
+ assign_pixel(image[row][col+1], p);
+ }
+ }
+ if (in.sgetn(reinterpret_cast<char*>(buf),padding) != padding)
+ throw image_load_error("bmp load error 9: file too short");
+ }
+
+
+
+ } break;
+ case 8:
+ {
+ // figure out how the pixels are packed
+ long padding;
+ if (bfSize - bfOffBits == biWidth*biHeight)
+ padding = 0;
+ else
+ padding = 4 - biWidth%4;
+
+ // check for this case. It shouldn't happen but some BMP writers screw up the files
+ // so we have to do this.
+ if (biHeight*(biWidth+padding) > bfSize - bfOffBits)
+ padding = 0;
+
+ const unsigned int palette_size = 256;
+ unsigned char red[palette_size];
+ unsigned char green[palette_size];
+ unsigned char blue[palette_size];
+
+ for (unsigned int i = 0; i < palette_size; ++i)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),4) != 4)
+ {
+ throw image_load_error("bmp load error 20: color palette missing");
+ }
+ bytes_read_so_far += 4;
+ blue[i] = buf[0];
+ green[i] = buf[1];
+ red[i] = buf[2];
+ }
+
+
+ // seek to the start of the pixel data
+ while (bytes_read_so_far != bfOffBits)
+ {
+ const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf));
+ if (in.sgetn(reinterpret_cast<char*>(buf), to_read) != to_read)
+ {
+ throw image_load_error("bmp load error: missing data");
+ }
+ bytes_read_so_far += to_read;
+ }
+
+ // Next we load the image data.
+
+ // if there is no RLE compression
+ if (biCompression == 0)
+ {
+ for (long row = biHeight-1; row >= 0; --row)
+ {
+ for (unsigned long col = 0; col < biWidth; ++col)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),1) != 1)
+ {
+ throw image_load_error("bmp load error 21.8: file too short");
+ }
+
+ rgb_pixel p;
+ p.red = red[buf[0]];
+ p.green = green[buf[0]];
+ p.blue = blue[buf[0]];
+ assign_pixel(image[row][col],p);
+ }
+ if (in.sgetn(reinterpret_cast<char*>(buf),padding) != padding)
+ throw image_load_error("bmp load error 9: file too short");
+ }
+ }
+ else
+ {
+ // Here we deal with the psychotic RLE used by BMP files.
+
+ // First zero the image since the RLE sometimes jumps over
+ // pixels and assumes the image has been zero initialized.
+ assign_all_pixels(image, 0);
+
+ long row = biHeight-1;
+ long col = 0;
+ while (true)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),2) != 2)
+ {
+ throw image_load_error("bmp load error 21.9: file too short");
+ }
+
+ const unsigned char count = buf[0];
+ const unsigned char command = buf[1];
+
+ if (count == 0 && command == 0)
+ {
+ // This is an escape code that means go to the next row
+ // of the image
+ --row;
+ col = 0;
+ continue;
+ }
+ else if (count == 0 && command == 1)
+ {
+ // This is the end of the image. So quit this loop.
+ break;
+ }
+ else if (count == 0 && command == 2)
+ {
+ // This is the escape code for the command to jump to
+ // a new part of the image relative to where we are now.
+ if (in.sgetn(reinterpret_cast<char*>(buf),2) != 2)
+ {
+ throw image_load_error("bmp load error 21.1: file too short");
+ }
+ col += buf[0];
+ row -= buf[1];
+ continue;
+ }
+ else if (count == 0)
+ {
+ // This is the escape code for a run of uncompressed bytes
+
+ if (row < 0 || col + command > image.nc())
+ {
+ // If this is just some padding bytes at the end then ignore them
+ if (row >= 0 && col + count <= image.nc() + padding)
+ continue;
+
+ throw image_load_error("bmp load error 21.2: file data corrupt");
+ }
+
+ // put the bytes into the image
+ for (unsigned int i = 0; i < command; ++i)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),1) != 1)
+ {
+ throw image_load_error("bmp load error 21.3: file too short");
+ }
+ rgb_pixel p;
+ p.red = red[buf[0]];
+ p.green = green[buf[0]];
+ p.blue = blue[buf[0]];
+ assign_pixel(image[row][col],p);
+
+ ++col;
+ }
+
+ // if we read an uneven number of bytes then we need to read and
+ // discard the next byte.
+ if ((command&1) != 1)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),1) != 1)
+ {
+ throw image_load_error("bmp load error 21.4: file too short");
+ }
+ }
+
+ continue;
+ }
+
+ rgb_pixel p;
+
+ if (row < 0 || col + count > image.nc())
+ {
+ // If this is just some padding bytes at the end then ignore them
+ if (row >= 0 && col + count <= image.nc() + padding)
+ continue;
+
+ throw image_load_error("bmp load error 21.5: file data corrupt");
+ }
+
+ // put the bytes into the image
+ for (unsigned int i = 0; i < count; ++i)
+ {
+ p.red = red[command];
+ p.green = green[command];
+ p.blue = blue[command];
+ assign_pixel(image[row][col],p);
+
+ ++col;
+ }
+ }
+ }
+
+
+
+ }
+ break;
+ case 16:
+ throw image_load_error ("16 bit BMP images not supported");
+ case 24:
+ {
+ // figure out how the pixels are packed
+ long padding;
+ if (bfSize - bfOffBits == biWidth*biHeight*3)
+ padding = 0;
+ else
+ padding = 4 - (biWidth*3)%4;
+
+ // check for this case. It shouldn't happen but some BMP writers screw up the files
+ // so we have to do this.
+ if (biHeight*(biWidth*3+padding) > bfSize - bfOffBits)
+ padding = 0;
+
+ // seek to the start of the pixel data
+ while (bytes_read_so_far != bfOffBits)
+ {
+ const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf));
+ if (in.sgetn(reinterpret_cast<char*>(buf), to_read) != to_read)
+ {
+ throw image_load_error("bmp load error: missing data");
+ }
+ bytes_read_so_far += to_read;
+ }
+
+ // load the image data
+ for (long row = biHeight-1; row >= 0; --row)
+ {
+ for (unsigned long col = 0; col < biWidth; ++col)
+ {
+ if (in.sgetn(reinterpret_cast<char*>(buf),3) != 3)
+ {
+ throw image_load_error("bmp load error 8: file too short");
+ }
+
+ rgb_pixel p;
+ p.red = buf[2];
+ p.green = buf[1];
+ p.blue = buf[0];
+ assign_pixel(image[row][col], p);
+
+ }
+ if (in.sgetn(reinterpret_cast<char*>(buf),padding) != padding)
+ throw image_load_error("bmp load error 9: file too short");
+ }
+
+ break;
+ }
+ case 32:
+ throw image_load_error ("32 bit BMP images not supported");
+ default:
+ throw image_load_error("bmp load error 10: unknown color depth");
+
+ }
+ }
+ catch (...)
+ {
+ image.clear();
+ throw;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_dng (
+ image_type& image_,
+ std::istream& in
+ )
+ {
+ image_view<image_type> image(image_);
+ using namespace dng_helpers_namespace;
+ try
+ {
+ if (in.get() != 'D' || in.get() != 'N' || in.get() != 'G')
+ throw image_load_error("the stream does not contain a dng image file");
+
+ unsigned long version;
+ deserialize(version,in);
+ if (version != 1)
+ throw image_load_error("You need the new version of the dlib library to read this dng file");
+
+ unsigned long type;
+ deserialize(type,in);
+
+ long width, height;
+ deserialize(width,in);
+ deserialize(height,in);
+
+ if (width > 0 && height > 0)
+ image.set_size(height,width);
+ else
+ image.clear();
+
+ if (type != grayscale_float)
+ {
+ typedef entropy_decoder::kernel_2a decoder_type;
+ decoder_type decoder;
+ decoder.set_stream(in);
+
+ entropy_decoder_model<256,decoder_type>::kernel_5a edm(decoder);
+ unsigned long symbol;
+ rgb_pixel p_rgb;
+ rgb_alpha_pixel p_rgba;
+ hsi_pixel p_hsi;
+ switch (type)
+ {
+ case rgb_alpha_paeth:
+
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ p_rgba = predictor_rgb_alpha_paeth(image,r,c);
+ edm.decode(symbol);
+ p_rgba.red += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.green += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.blue += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.alpha += static_cast<unsigned char>(symbol);
+
+ assign_pixel(image[r][c],p_rgba);
+ }
+ }
+ break;
+
+ case rgb_alpha:
+
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ p_rgba = predictor_rgb_alpha(image,r,c);
+ edm.decode(symbol);
+ p_rgba.red += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.green += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.blue += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgba.alpha += static_cast<unsigned char>(symbol);
+
+ assign_pixel(image[r][c],p_rgba);
+ }
+ }
+ break;
+
+ case rgb_paeth:
+
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ p_rgb = predictor_rgb_paeth(image,r,c);
+ edm.decode(symbol);
+ p_rgb.red += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgb.green += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgb.blue += static_cast<unsigned char>(symbol);
+
+ assign_pixel(image[r][c],p_rgb);
+ }
+ }
+ break;
+
+ case rgb:
+
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ p_rgb = predictor_rgb(image,r,c);
+ edm.decode(symbol);
+ p_rgb.red += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgb.green += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_rgb.blue += static_cast<unsigned char>(symbol);
+
+ assign_pixel(image[r][c],p_rgb);
+ }
+ }
+ break;
+
+ case hsi:
+
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ p_hsi = predictor_hsi(image,r,c);
+ edm.decode(symbol);
+ p_hsi.h += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_hsi.s += static_cast<unsigned char>(symbol);
+
+ edm.decode(symbol);
+ p_hsi.i += static_cast<unsigned char>(symbol);
+
+ assign_pixel(image[r][c],p_hsi);
+ }
+ }
+ break;
+
+ case grayscale:
+ {
+ unsigned char p;
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ edm.decode(symbol);
+ p = static_cast<unsigned char>(symbol);
+ p += predictor_grayscale(image,r,c);
+ assign_pixel(image[r][c],p);
+ }
+ }
+ }
+ break;
+
+ case grayscale_16bit:
+ {
+ uint16 p;
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ edm.decode(symbol);
+ p = static_cast<uint16>(symbol);
+ p <<= 8;
+ edm.decode(symbol);
+ p |= static_cast<uint16>(symbol);
+
+ p += predictor_grayscale_16(image,r,c);
+ assign_pixel(image[r][c],p);
+ }
+ }
+ }
+ break;
+
+ default:
+ throw image_load_error("corruption detected in the dng file");
+ } // switch (type)
+
+ edm.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ }
+ else // if this is a grayscale_float type image
+ {
+ std::vector<int64> man(image.size());
+ std::vector<char> expbuf;
+ // get the mantissa data
+ for (unsigned long j = 0; j < man.size(); ++j)
+ deserialize(man[j], in);
+ // get the compressed exponent data
+ deserialize(expbuf, in);
+ typedef entropy_decoder::kernel_2a decoder_type;
+ typedef entropy_decoder_model<256,decoder_type>::kernel_4a edm_exp_type;
+ vectorstream inexp(expbuf);
+ decoder_type decoder;
+ decoder.set_stream(inexp);
+
+ edm_exp_type edm_exp(decoder);
+ float_details prev;
+ unsigned long i = 0;
+ // fill out the image
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ unsigned long exp1, exp2;
+ edm_exp.decode(exp1);
+ edm_exp.decode(exp2);
+
+ float_details cur(man[i++],(exp2<<8) | exp1);
+ cur.exponent += prev.exponent;
+ cur.mantissa += prev.mantissa;
+ prev = cur;
+
+ // Only use long double precision if the target image contains long
+ // doubles because it's slower to use those.
+ if (!is_same_type<typename image_traits<image_type>::pixel_type,long double>::value)
+ {
+ double temp = cur;
+ assign_pixel(image[r][c],temp);
+ }
+ else
+ {
+ long double temp = cur;
+ assign_pixel(image[r][c],temp);
+ }
+ }
+ }
+ unsigned long symbol;
+ edm_exp.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm_exp.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm_exp.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ edm_exp.decode(symbol);
+ if (symbol != dng_magic_byte)
+ throw image_load_error("corruption detected in the dng file");
+ }
+ }
+ catch (...)
+ {
+ image.clear();
+ throw;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void load_bmp (
+ image_type& image,
+ const std::string& file_name
+ )
+ {
+ std::ifstream fin(file_name.c_str(), std::ios::binary);
+ if (!fin)
+ throw image_load_error("Unable to open " + file_name + " for reading.");
+ load_bmp(image, fin);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void load_dng (
+ image_type& image,
+ const std::string& file_name
+ )
+ {
+ std::ifstream fin(file_name.c_str(), std::ios::binary);
+ if (!fin)
+ throw image_load_error("Unable to open " + file_name + " for reading.");
+ load_dng(image, fin);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_LOADEr_
+
+
+
diff --git a/ml/dlib/dlib/image_loader/image_loader_abstract.h b/ml/dlib/dlib/image_loader/image_loader_abstract.h
new file mode 100644
index 000000000..cd66b3699
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/image_loader_abstract.h
@@ -0,0 +1,136 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_IMAGE_LOADEr_ABSTRACT_
+#ifdef DLIB_IMAGE_LOADEr_ABSTRACT_
+
+#include <iosfwd>
+#include "../algs.h"
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ class image_load_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an exception used to indicate a failure to load an image.
+ Its type member variable will be set to EIMAGE_LOAD.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_bmp (
+ image_type& image,
+ std::istream& in
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - #image == the image of the MS Windows BMP file that was available
+ in the input stream in.
+ - #image[0][0] will be the upper left corner of the image
+ - #image[image.nr()-1][image.nc()-1] will be the lower right
+ corner of the image
+ - Performs any color space conversion necessary to convert the
+ BMP image data into the pixel type used by the given image
+ object.
+ throws
+ - image_load_error
+ This exception is thrown if there is an error that prevents us
+ from loading the image. If this exception is thrown then
+ #image will have an initial value for its type.
+ - std::bad_alloc
+ If this exception is thrown then #image will have an initial
+ value for its type.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_bmp (
+ image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - opens the file indicated by file_name with an input file stream named fin
+ and performs:
+ load_bmp(image,fin);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ dlib dng file format:
+ This is a file format I created for this library. It is a lossless
+ compressed image format that is similar to the PNG format but uses
+ the dlib PPM compression algorithms instead of the DEFLATE algorithm.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load_dng (
+ image_type& image,
+ std::istream& in
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - #image == the image of the dlib dng file that was available
+ in the input stream in.
+ - #image[0][0] will be the upper left corner of the image
+ - #image[image.nr()-1][image.nc()-1] will be the lower right
+ corner of the image
+ - Performs any color space conversion necessary to convert the
+ dng image data into the pixel type used by the given image
+ object.
+ throws
+ - image_load_error
+ This exception is thrown if there is an error that prevents us
+ from loading the image. If this exception is thrown then
+ #image will have an initial value for its type.
+ - std::bad_alloc
+ If this exception is thrown then #image will have an initial
+ value for its type.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_dng (
+ image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - opens the file indicated by file_name with an input file stream named fin
+ and performs:
+ load_dng(image,fin);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_LOADEr_ABSTRACT_
+
diff --git a/ml/dlib/dlib/image_loader/jpeg_loader.cpp b/ml/dlib/dlib/image_loader/jpeg_loader.cpp
new file mode 100644
index 000000000..710d7586f
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/jpeg_loader.cpp
@@ -0,0 +1,173 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_JPEG_LOADER_CPp_
+#define DLIB_JPEG_LOADER_CPp_
+
+// only do anything with this file if DLIB_JPEG_SUPPORT is defined
+#ifdef DLIB_JPEG_SUPPORT
+
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "jpeg_loader.h"
+#include <stdio.h>
+#ifdef DLIB_JPEG_STATIC
+# include "../external/libjpeg/jpeglib.h"
+#else
+# include <jpeglib.h>
+#endif
+#include <sstream>
+#include <setjmp.h>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ jpeg_loader::
+ jpeg_loader( const char* filename ) : height_( 0 ), width_( 0 ), output_components_(0)
+ {
+ read_image( filename );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ jpeg_loader::
+ jpeg_loader( const std::string& filename ) : height_( 0 ), width_( 0 ), output_components_(0)
+ {
+ read_image( filename.c_str() );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ jpeg_loader::
+ jpeg_loader( const dlib::file& f ) : height_( 0 ), width_( 0 ), output_components_(0)
+ {
+ read_image( f.full_name().c_str() );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool jpeg_loader::is_gray() const
+ {
+ return (output_components_ == 1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool jpeg_loader::is_rgb() const
+ {
+ return (output_components_ == 3);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool jpeg_loader::is_rgba() const
+ {
+ return (output_components_ == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ struct jpeg_loader_error_mgr
+ {
+ jpeg_error_mgr pub; /* "public" fields */
+ jmp_buf setjmp_buffer; /* for return to caller */
+ };
+
+ void jpeg_loader_error_exit (j_common_ptr cinfo)
+ {
+ /* cinfo->err really points to a jpeg_loader_error_mgr struct, so coerce pointer */
+ jpeg_loader_error_mgr* myerr = (jpeg_loader_error_mgr*) cinfo->err;
+
+ /* Return control to the setjmp point */
+ longjmp(myerr->setjmp_buffer, 1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void jpeg_loader::read_image( const char* filename )
+ {
+ if ( filename == NULL )
+ {
+ throw image_load_error("jpeg_loader: invalid filename, it is NULL");
+ }
+ FILE *fp = fopen( filename, "rb" );
+ if ( !fp )
+ {
+ throw image_load_error(std::string("jpeg_loader: unable to open file ") + filename);
+ }
+
+ jpeg_decompress_struct cinfo;
+ jpeg_loader_error_mgr jerr;
+
+ cinfo.err = jpeg_std_error(&jerr.pub);
+
+ jerr.pub.error_exit = jpeg_loader_error_exit;
+
+ /* Establish the setjmp return context for my_error_exit to use. */
+ if (setjmp(jerr.setjmp_buffer))
+ {
+ /* If we get here, the JPEG code has signaled an error.
+ * We need to clean up the JPEG object, close the input file, and return.
+ */
+ jpeg_destroy_decompress(&cinfo);
+ fclose(fp);
+ throw image_load_error(std::string("jpeg_loader: error while reading ") + filename);
+ }
+
+
+ jpeg_create_decompress(&cinfo);
+
+ jpeg_stdio_src(&cinfo, fp);
+
+ jpeg_read_header(&cinfo, TRUE);
+
+ jpeg_start_decompress(&cinfo);
+
+ height_ = cinfo.output_height;
+ width_ = cinfo.output_width;
+ output_components_ = cinfo.output_components;
+
+ if (output_components_ != 1 &&
+ output_components_ != 3 &&
+ output_components_ != 4)
+ {
+ fclose( fp );
+ jpeg_destroy_decompress(&cinfo);
+ std::ostringstream sout;
+ sout << "jpeg_loader: Unsupported number of colors (" << output_components_ << ") in file " << filename;
+ throw image_load_error(sout.str());
+ }
+
+ std::vector<unsigned char*> rows;
+ rows.resize(height_);
+
+ // size the image buffer
+ data.resize(height_*width_*output_components_);
+
+ // setup pointers to each row
+ for (unsigned long i = 0; i < rows.size(); ++i)
+ rows[i] = &data[i*width_*output_components_];
+
+ // read the data into the buffer
+ while (cinfo.output_scanline < cinfo.output_height)
+ {
+ jpeg_read_scanlines(&cinfo, &rows[cinfo.output_scanline], 100);
+ }
+
+ jpeg_finish_decompress(&cinfo);
+ jpeg_destroy_decompress(&cinfo);
+
+ fclose( fp );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_JPEG_SUPPORT
+
+#endif // DLIB_JPEG_LOADER_CPp_
+
+
diff --git a/ml/dlib/dlib/image_loader/jpeg_loader.h b/ml/dlib/dlib/image_loader/jpeg_loader.h
new file mode 100644
index 000000000..097a461f8
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/jpeg_loader.h
@@ -0,0 +1,109 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_JPEG_IMPORT
+#define DLIB_JPEG_IMPORT
+
+#include <vector>
+
+#include "jpeg_loader_abstract.h"
+#include "image_loader.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "../test_for_odr_violations.h"
+
+namespace dlib
+{
+
+ class jpeg_loader : noncopyable
+ {
+ public:
+
+ jpeg_loader( const char* filename );
+ jpeg_loader( const std::string& filename );
+ jpeg_loader( const dlib::file& f );
+
+ bool is_gray() const;
+ bool is_rgb() const;
+ bool is_rgba() const;
+
+ template<typename T>
+ void get_image( T& t_) const
+ {
+#ifndef DLIB_JPEG_SUPPORT
+ /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ You are getting this error because you are trying to use the jpeg_loader
+ object but you haven't defined DLIB_JPEG_SUPPORT. You must do so to use
+ this object. You must also make sure you set your build environment
+ to link against the libjpeg library.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+ COMPILE_TIME_ASSERT(sizeof(T) == 0);
+#endif
+ image_view<T> t(t_);
+ t.set_size( height_, width_ );
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const unsigned char* v = get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ if ( is_gray() )
+ {
+ unsigned char p = v[m];
+ assign_pixel( t[n][m], p );
+ }
+ else if ( is_rgba() ) {
+ rgb_alpha_pixel p;
+ p.red = v[m*4];
+ p.green = v[m*4+1];
+ p.blue = v[m*4+2];
+ p.alpha = v[m*4+3];
+ assign_pixel( t[n][m], p );
+ }
+ else // if ( is_rgb() )
+ {
+ rgb_pixel p;
+ p.red = v[m*3];
+ p.green = v[m*3+1];
+ p.blue = v[m*3+2];
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ }
+
+ private:
+ const unsigned char* get_row( unsigned long i ) const
+ {
+ return &data[i*width_*output_components_];
+ }
+
+ void read_image( const char* filename );
+ unsigned long height_;
+ unsigned long width_;
+ unsigned long output_components_;
+ std::vector<unsigned char> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_jpeg (
+ image_type& image,
+ const std::string& file_name
+ )
+ {
+ jpeg_loader(file_name).get_image(image);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "jpeg_loader.cpp"
+#endif
+
+#endif // DLIB_JPEG_IMPORT
+
+
diff --git a/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h b/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h
new file mode 100644
index 000000000..48b5bb031
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h
@@ -0,0 +1,133 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_JPEG_IMPORT_ABSTRACT
+#ifdef DLIB_JPEG_IMPORT_ABSTRACT
+
+#include "image_loader_abstract.h"
+#include "../algs.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+ class jpeg_loader : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ Defined by the constructors
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a class capable of loading JPEG image files.
+ Once an instance of it is created to contain a JPEG file from
+ disk you can obtain the image stored in it via get_image().
+ !*/
+
+ public:
+
+ jpeg_loader(
+ const char* filename
+ );
+ /*!
+ ensures
+ - loads the JPEG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given JPEG file.
+ !*/
+
+ jpeg_loader(
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - loads the JPEG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given JPEG file.
+ !*/
+
+ jpeg_loader(
+ const dlib::file& f
+ );
+ /*!
+ ensures
+ - loads the JPEG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given JPEG file.
+ !*/
+
+ ~jpeg_loader(
+ );
+ /*!
+ ensures
+ - all resources associated with *this has been released
+ !*/
+
+ bool is_gray(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a grayscale image) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_rgb(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a 3 channel RGB image) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template<
+ typename image_type
+ >
+ void get_image(
+ image_type& img
+ ) const;
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - loads the JPEG image stored in this object into img
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_jpeg (
+ image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - performs: jpeg_loader(file_name).get_image(image);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_JPEG_IMPORT_ABSTRACT
+
diff --git a/ml/dlib/dlib/image_loader/load_image.h b/ml/dlib/dlib/image_loader/load_image.h
new file mode 100644
index 000000000..64ccea9f2
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/load_image.h
@@ -0,0 +1,226 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net), Nils Labugt, Changjiang Yang (yangcha@leidos.com)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOAd_IMAGE_Hh_
+#define DLIB_LOAd_IMAGE_Hh_
+
+#include "load_image_abstract.h"
+#include "../string.h"
+#include "png_loader.h"
+#include "jpeg_loader.h"
+#include "image_loader.h"
+#include <fstream>
+#include <sstream>
+#ifdef DLIB_GIF_SUPPORT
+#include <gif_lib.h>
+#endif
+
+namespace dlib
+{
+ namespace image_file_type
+ {
+ enum type
+ {
+ BMP,
+ JPG,
+ PNG,
+ DNG,
+ GIF,
+ UNKNOWN
+ };
+
+ inline type read_type(const std::string& file_name)
+ {
+ std::ifstream file(file_name.c_str(), std::ios::in|std::ios::binary);
+ if (!file)
+ throw image_load_error("Unable to open file: " + file_name);
+
+ char buffer[9];
+ file.read((char*)buffer, 8);
+ buffer[8] = 0;
+
+ // Determine the true image type using link:
+ // http://en.wikipedia.org/wiki/List_of_file_signatures
+
+ if (strcmp(buffer, "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A") == 0)
+ return PNG;
+ else if(buffer[0]=='\xff' && buffer[1]=='\xd8' && buffer[2]=='\xff')
+ return JPG;
+ else if(buffer[0]=='B' && buffer[1]=='M')
+ return BMP;
+ else if(buffer[0]=='D' && buffer[1]=='N' && buffer[2] == 'G')
+ return DNG;
+ else if(buffer[0]=='G' && buffer[1]=='I' && buffer[2] == 'F')
+ return GIF;
+
+ return UNKNOWN;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+// handle the differences in API between libgif v5 and older.
+#if defined(GIFLIB_MAJOR) && GIFLIB_MAJOR >= 5
+#define DLIB_GIFLIB_HANDLE_DIFF_VERSIONS ,0
+#else
+#define DLIB_GIFLIB_HANDLE_DIFF_VERSIONS
+#endif
+
+ template <typename image_type>
+ void load_image (
+ image_type& image,
+ const std::string& file_name
+ )
+ {
+ const image_file_type::type im_type = image_file_type::read_type(file_name);
+ switch (im_type)
+ {
+ case image_file_type::BMP: load_bmp(image, file_name); return;
+ case image_file_type::DNG: load_dng(image, file_name); return;
+#ifdef DLIB_PNG_SUPPORT
+ case image_file_type::PNG: load_png(image, file_name); return;
+#endif
+#ifdef DLIB_JPEG_SUPPORT
+ case image_file_type::JPG: load_jpeg(image, file_name); return;
+#endif
+#ifdef DLIB_GIF_SUPPORT
+ case image_file_type::GIF:
+ {
+ image_view<image_type> img(image);
+ GifFileType* gif = DGifOpenFileName(file_name.c_str() DLIB_GIFLIB_HANDLE_DIFF_VERSIONS);
+ try
+ {
+ if (gif == 0) throw image_load_error("Couldn't open file " + file_name);
+ if (DGifSlurp(gif) != GIF_OK)
+ throw image_load_error("Error reading from " + file_name);
+
+ if (gif->ImageCount != 1) throw image_load_error("Dlib only supports reading GIF files containing one image.");
+ if (gif->SavedImages == 0) throw image_load_error("Unsupported GIF format 1.");
+
+ ColorMapObject* cmo=gif->SColorMap?gif->SColorMap:gif->SavedImages->ImageDesc.ColorMap;
+
+ if (cmo==0) throw image_load_error("Unsupported GIF format 2.");
+ if (cmo->Colors == 0) throw image_load_error("Unsupported GIF format 3.");
+ if (gif->SavedImages->ImageDesc.Width != gif->SWidth) throw image_load_error("Unsupported GIF format 4.");
+ if (gif->SavedImages->ImageDesc.Height != gif->SHeight) throw image_load_error("Unsupported GIF format 5.");
+ if (gif->SavedImages->RasterBits == 0) throw image_load_error("Unsupported GIF format 6.");
+ if (gif->Image.Top != 0) throw image_load_error("Unsupported GIF format 7.");
+ if (gif->Image.Left != 0) throw image_load_error("Unsupported GIF format 8.");
+
+ img.set_size(gif->SHeight, gif->SWidth);
+ unsigned char* raster = gif->SavedImages->RasterBits;
+ GifColorType* colormap = cmo->Colors;
+ if (gif->Image.Interlace)
+ {
+ const long interlaced_offset[] = { 0, 4, 2, 1 };
+ const long interlaced_jumps[] = { 8, 8, 4, 2 };
+ for (int i = 0; i < 4; ++i)
+ {
+ for (long r = interlaced_offset[i]; r < img.nr(); r += interlaced_jumps[i])
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ if (*raster >= cmo->ColorCount)
+ throw image_load_error("Invalid GIF color value");
+ rgb_pixel p;
+ p.red = colormap[*raster].Red;
+ p.green = colormap[*raster].Green;
+ p.blue = colormap[*raster].Blue;
+ assign_pixel(img[r][c], p);
+ ++raster;
+ }
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ if (*raster >= cmo->ColorCount)
+ throw image_load_error("Invalid GIF color value");
+ rgb_pixel p;
+ p.red = colormap[*raster].Red;
+ p.green = colormap[*raster].Green;
+ p.blue = colormap[*raster].Blue;
+ assign_pixel(img[r][c], p);
+ ++raster;
+ }
+ }
+ }
+ DGifCloseFile(gif DLIB_GIFLIB_HANDLE_DIFF_VERSIONS);
+ }
+ catch(...)
+ {
+ if (gif)
+ DGifCloseFile(gif DLIB_GIFLIB_HANDLE_DIFF_VERSIONS);
+ throw;
+ }
+ return;
+ }
+#endif
+ default: ;
+ }
+
+ if (im_type == image_file_type::JPG)
+ {
+ std::ostringstream sout;
+ sout << "Unable to load image in file " + file_name + ".\n" +
+ "You must #define DLIB_JPEG_SUPPORT and link to libjpeg to read JPEG files.\n" +
+ "Do this by following the instructions at http://dlib.net/compile.html.\n\n";
+#ifdef _MSC_VER
+ sout << "Note that you must cause DLIB_JPEG_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n";
+ sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application.";
+#else
+ sout << "Note that you must cause DLIB_JPEG_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_JPEG_SUPPORT\n";
+ sout << "so it takes effect for your entire application.";
+#endif
+ throw image_load_error(sout.str());
+ }
+ else if (im_type == image_file_type::PNG)
+ {
+ std::ostringstream sout;
+ sout << "Unable to load image in file " + file_name + ".\n" +
+ "You must #define DLIB_PNG_SUPPORT and link to libpng to read PNG files.\n" +
+ "Do this by following the instructions at http://dlib.net/compile.html.\n\n";
+#ifdef _MSC_VER
+ sout << "Note that you must cause DLIB_PNG_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n";
+ sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application.\n";
+#else
+ sout << "Note that you must cause DLIB_PNG_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_PNG_SUPPORT\n";
+ sout << "so it takes effect for your entire application.";
+#endif
+ throw image_load_error(sout.str());
+ }
+ else if (im_type == image_file_type::GIF)
+ {
+ std::ostringstream sout;
+ sout << "Unable to load image in file " + file_name + ".\n" +
+ "You must #define DLIB_GIF_SUPPORT and link to libgif to read GIF files.\n\n";
+#ifdef _MSC_VER
+ sout << "Note that you must cause DLIB_GIF_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n";
+ sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application.\n";
+#else
+ sout << "Note that you must cause DLIB_GIF_SUPPORT to be defined for your entire project.\n";
+ sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_GIF_SUPPORT\n";
+ sout << "so it takes effect for your entire application.";
+#endif
+ throw image_load_error(sout.str());
+ }
+ else
+ {
+ throw image_load_error("Unknown image file format: Unable to load image in file " + file_name);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LOAd_IMAGE_Hh_
+
diff --git a/ml/dlib/dlib/image_loader/load_image_abstract.h b/ml/dlib/dlib/image_loader/load_image_abstract.h
new file mode 100644
index 000000000..f357bb278
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/load_image_abstract.h
@@ -0,0 +1,37 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LOAd_IMAGE_ABSTRACT_
+#ifdef DLIB_LOAd_IMAGE_ABSTRACT_
+
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ template <typename image_type>
+ void load_image (
+ image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - This function loads an image from disk, in the indicated file file_name, and
+ writes it to the indicated image object.
+ - It is capable of reading the PNG, JPEG, BMP, GIF, and DNG image formats. It
+ is always capable of reading BMP and DNG images. However, for PNG, JPEG, and
+ GIF you must #define DLIB_PNG_SUPPORT, DLIB_JPEG_SUPPORT, and
+ DLIB_GIF_SUPPORT respectively and link your program to libpng, libjpeg, and
+ libgif respectively.
+ throws
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given image file.
+ !*/
+
+}
+
+#endif // DLIB_LOAd_IMAGE_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/image_loader/png_loader.cpp b/ml/dlib/dlib/image_loader/png_loader.cpp
new file mode 100644
index 000000000..3346ddb6a
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/png_loader.cpp
@@ -0,0 +1,222 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PNG_LOADER_CPp_
+#define DLIB_PNG_LOADER_CPp_
+
+// only do anything with this file if DLIB_PNG_SUPPORT is defined
+#ifdef DLIB_PNG_SUPPORT
+
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "png_loader.h"
+#include <png.h>
+#include "../string.h"
+#include "../byte_orderer.h"
+#include <sstream>
+#include <cstring>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct LibpngData
+ {
+ png_bytep* row_pointers_;
+ png_structp png_ptr_;
+ png_infop info_ptr_;
+ png_infop end_info_;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ png_loader::
+ png_loader( const char* filename ) : height_( 0 ), width_( 0 )
+ {
+ read_image( filename );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ png_loader::
+ png_loader( const std::string& filename ) : height_( 0 ), width_( 0 )
+ {
+ read_image( filename.c_str() );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ png_loader::
+ png_loader( const dlib::file& f ) : height_( 0 ), width_( 0 )
+ {
+ read_image( f.full_name().c_str() );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const unsigned char* png_loader::get_row( unsigned i ) const
+ {
+ return ld_->row_pointers_[i];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ png_loader::~png_loader()
+ {
+ if ( ld_ && ld_->row_pointers_ != NULL )
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool png_loader::is_gray() const
+ {
+ return ( color_type_ == PNG_COLOR_TYPE_GRAY );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool png_loader::is_graya() const
+ {
+ return ( color_type_ == PNG_COLOR_TYPE_GRAY_ALPHA );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool png_loader::is_rgb() const
+ {
+ return ( color_type_ == PNG_COLOR_TYPE_RGB );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool png_loader::is_rgba() const
+ {
+ return ( color_type_ == PNG_COLOR_TYPE_RGB_ALPHA );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Don't do anything when libpng calls us to tell us about an error. Just return to
+ // our own code and throw an exception (at the long jump target).
+ void png_loader_user_error_fn_silent(png_structp png_struct, png_const_charp )
+ {
+ longjmp(png_jmpbuf(png_struct),1);
+ }
+ void png_loader_user_warning_fn_silent(png_structp , png_const_charp )
+ {
+ }
+
+ void png_loader::read_image( const char* filename )
+ {
+ ld_.reset(new LibpngData);
+ if ( filename == NULL )
+ {
+ throw image_load_error("png_loader: invalid filename, it is NULL");
+ }
+ FILE *fp = fopen( filename, "rb" );
+ if ( !fp )
+ {
+ throw image_load_error(std::string("png_loader: unable to open file ") + filename);
+ }
+ png_byte sig[8];
+ if (fread( sig, 1, 8, fp ) != 8)
+ {
+ fclose( fp );
+ throw image_load_error(std::string("png_loader: error reading file ") + filename);
+ }
+ if ( png_sig_cmp( sig, 0, 8 ) != 0 )
+ {
+ fclose( fp );
+ throw image_load_error(std::string("png_loader: format error in file ") + filename);
+ }
+ ld_->png_ptr_ = png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, &png_loader_user_error_fn_silent, &png_loader_user_warning_fn_silent );
+ if ( ld_->png_ptr_ == NULL )
+ {
+ fclose( fp );
+ std::ostringstream sout;
+ sout << "Error, unable to allocate png structure while opening file " << filename << std::endl;
+ const char* runtime_version = png_get_header_ver(NULL);
+ if (runtime_version && std::strcmp(PNG_LIBPNG_VER_STRING, runtime_version) != 0)
+ {
+ sout << "This is happening because you compiled against one version of libpng, but then linked to another." << std::endl;
+ sout << "Compiled against libpng version: " << PNG_LIBPNG_VER_STRING << std::endl;
+ sout << "Linking to this version of libpng: " << runtime_version << std::endl;
+ }
+ throw image_load_error(sout.str());
+ }
+ ld_->info_ptr_ = png_create_info_struct( ld_->png_ptr_ );
+ if ( ld_->info_ptr_ == NULL )
+ {
+ fclose( fp );
+ png_destroy_read_struct( &( ld_->png_ptr_ ), ( png_infopp )NULL, ( png_infopp )NULL );
+ throw image_load_error(std::string("png_loader: parse error in file ") + filename);
+ }
+ ld_->end_info_ = png_create_info_struct( ld_->png_ptr_ );
+ if ( ld_->end_info_ == NULL )
+ {
+ fclose( fp );
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), ( png_infopp )NULL );
+ throw image_load_error(std::string("png_loader: parse error in file ") + filename);
+ }
+
+ if (setjmp(png_jmpbuf(ld_->png_ptr_)))
+ {
+ // If we get here, we had a problem writing the file
+ fclose(fp);
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) );
+ throw image_load_error(std::string("png_loader: parse error in file ") + filename);
+ }
+
+ png_set_palette_to_rgb(ld_->png_ptr_);
+
+ png_init_io( ld_->png_ptr_, fp );
+ png_set_sig_bytes( ld_->png_ptr_, 8 );
+ // flags force one byte per channel output
+ byte_orderer bo;
+ int png_transforms = PNG_TRANSFORM_PACKING;
+ if (bo.host_is_little_endian())
+ png_transforms |= PNG_TRANSFORM_SWAP_ENDIAN;
+ png_read_png( ld_->png_ptr_, ld_->info_ptr_, png_transforms, NULL );
+ height_ = png_get_image_height( ld_->png_ptr_, ld_->info_ptr_ );
+ width_ = png_get_image_width( ld_->png_ptr_, ld_->info_ptr_ );
+ bit_depth_ = png_get_bit_depth( ld_->png_ptr_, ld_->info_ptr_ );
+ color_type_ = png_get_color_type( ld_->png_ptr_, ld_-> info_ptr_ );
+
+
+ if (color_type_ != PNG_COLOR_TYPE_GRAY &&
+ color_type_ != PNG_COLOR_TYPE_RGB &&
+ color_type_ != PNG_COLOR_TYPE_RGB_ALPHA &&
+ color_type_ != PNG_COLOR_TYPE_GRAY_ALPHA)
+ {
+ fclose( fp );
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) );
+ throw image_load_error(std::string("png_loader: unsupported color type in file ") + filename);
+ }
+
+ if (bit_depth_ != 8 && bit_depth_ != 16)
+ {
+ fclose( fp );
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) );
+ throw image_load_error("png_loader: unsupported bit depth of " + cast_to_string(bit_depth_) + " in file " + std::string(filename));
+ }
+
+ ld_->row_pointers_ = png_get_rows( ld_->png_ptr_, ld_->info_ptr_ );
+
+ fclose( fp );
+ if ( ld_->row_pointers_ == NULL )
+ {
+ png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) );
+ throw image_load_error(std::string("png_loader: parse error in file ") + filename);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PNG_SUPPORT
+
+#endif // DLIB_PNG_LOADER_CPp_
+
diff --git a/ml/dlib/dlib/image_loader/png_loader.h b/ml/dlib/dlib/image_loader/png_loader.h
new file mode 100644
index 000000000..291d3fddd
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/png_loader.h
@@ -0,0 +1,223 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PNG_IMPORT
+#define DLIB_PNG_IMPORT
+
+#include <memory>
+
+#include "png_loader_abstract.h"
+#include "image_loader.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "../test_for_odr_violations.h"
+
+namespace dlib
+{
+
+ struct LibpngData;
+ class png_loader : noncopyable
+ {
+ public:
+
+ png_loader( const char* filename );
+ png_loader( const std::string& filename );
+ png_loader( const dlib::file& f );
+ ~png_loader();
+
+ bool is_gray() const;
+ bool is_graya() const;
+ bool is_rgb() const;
+ bool is_rgba() const;
+
+ unsigned int bit_depth () const { return bit_depth_; }
+
+ template<typename T>
+ void get_image( T& t_) const
+ {
+#ifndef DLIB_PNG_SUPPORT
+ /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ You are getting this error because you are trying to use the png_loader
+ object but you haven't defined DLIB_PNG_SUPPORT. You must do so to use
+ this object. You must also make sure you set your build environment
+ to link against the libpng library.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+ COMPILE_TIME_ASSERT(sizeof(T) == 0);
+#endif
+
+ typedef typename image_traits<T>::pixel_type pixel_type;
+ image_view<T> t(t_);
+ t.set_size( height_, width_ );
+
+
+ if (is_gray() && bit_depth_ == 8)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const unsigned char* v = get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ unsigned char p = v[m];
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ else if (is_gray() && bit_depth_ == 16)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const uint16* v = (uint16*)get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ dlib::uint16 p = v[m];
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ else if (is_graya() && bit_depth_ == 8)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const unsigned char* v = get_row( n );
+ for ( unsigned m = 0; m < width_; m++ )
+ {
+ unsigned char p = v[m*2];
+ if (!pixel_traits<pixel_type>::has_alpha)
+ {
+ assign_pixel( t[n][m], p );
+ }
+ else
+ {
+ unsigned char pa = v[m*2+1];
+ rgb_alpha_pixel pix;
+ assign_pixel(pix, p);
+ assign_pixel(pix.alpha, pa);
+ assign_pixel(t[n][m], pix);
+ }
+ }
+ }
+ }
+ else if (is_graya() && bit_depth_ == 16)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const uint16* v = (uint16*)get_row( n );
+ for ( unsigned m = 0; m < width_; m++ )
+ {
+ dlib::uint16 p = v[m*2];
+ if (!pixel_traits<pixel_type>::has_alpha)
+ {
+ assign_pixel( t[n][m], p );
+ }
+ else
+ {
+ dlib::uint16 pa = v[m*2+1];
+ rgb_alpha_pixel pix;
+ assign_pixel(pix, p);
+ assign_pixel(pix.alpha, pa);
+ assign_pixel(t[n][m], pix);
+ }
+ }
+ }
+ }
+ else if (is_rgb() && bit_depth_ == 8)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const unsigned char* v = get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ rgb_pixel p;
+ p.red = v[m*3];
+ p.green = v[m*3+1];
+ p.blue = v[m*3+2];
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ else if (is_rgb() && bit_depth_ == 16)
+ {
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const uint16* v = (uint16*)get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ rgb_pixel p;
+ p.red = static_cast<uint8>(v[m*3]);
+ p.green = static_cast<uint8>(v[m*3+1]);
+ p.blue = static_cast<uint8>(v[m*3+2]);
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ else if (is_rgba() && bit_depth_ == 8)
+ {
+ if (!pixel_traits<pixel_type>::has_alpha)
+ assign_all_pixels(t,0);
+
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const unsigned char* v = get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ rgb_alpha_pixel p;
+ p.red = v[m*4];
+ p.green = v[m*4+1];
+ p.blue = v[m*4+2];
+ p.alpha = v[m*4+3];
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ else if (is_rgba() && bit_depth_ == 16)
+ {
+ if (!pixel_traits<pixel_type>::has_alpha)
+ assign_all_pixels(t,0);
+
+ for ( unsigned n = 0; n < height_;n++ )
+ {
+ const uint16* v = (uint16*)get_row( n );
+ for ( unsigned m = 0; m < width_;m++ )
+ {
+ rgb_alpha_pixel p;
+ p.red = static_cast<uint8>(v[m*4]);
+ p.green = static_cast<uint8>(v[m*4+1]);
+ p.blue = static_cast<uint8>(v[m*4+2]);
+ p.alpha = static_cast<uint8>(v[m*4+3]);
+ assign_pixel( t[n][m], p );
+ }
+ }
+ }
+ }
+
+ private:
+ const unsigned char* get_row( unsigned i ) const;
+ void read_image( const char* filename );
+ unsigned height_, width_;
+ unsigned bit_depth_;
+ int color_type_;
+ std::unique_ptr<LibpngData> ld_;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_png (
+ image_type& image,
+ const std::string& file_name
+ )
+ {
+ png_loader(file_name).get_image(image);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "png_loader.cpp"
+#endif
+
+#endif // DLIB_PNG_IMPORT
+
diff --git a/ml/dlib/dlib/image_loader/png_loader_abstract.h b/ml/dlib/dlib/image_loader/png_loader_abstract.h
new file mode 100644
index 000000000..d81e7f83a
--- /dev/null
+++ b/ml/dlib/dlib/image_loader/png_loader_abstract.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_PNG_IMPORT_ABSTRACT
+#ifdef DLIB_PNG_IMPORT_ABSTRACT
+
+#include "image_loader_abstract.h"
+#include "../algs.h"
+#include "../pixel.h"
+#include "../dir_nav.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+ class png_loader : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ Defined by the constructors
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a class capable of loading PNG image files.
+ Once an instance of it is created to contain a PNG file from
+ disk you can obtain the image stored in it via get_image().
+ !*/
+
+ public:
+
+ png_loader(
+ const char* filename
+ );
+ /*!
+ ensures
+ - loads the PNG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given PNG file.
+ !*/
+
+ png_loader(
+ const std::string& filename
+ );
+ /*!
+ ensures
+ - loads the PNG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given PNG file.
+ !*/
+
+ png_loader(
+ const dlib::file& f
+ );
+ /*!
+ ensures
+ - loads the PNG file with the given file name into this object
+ throws
+ - std::bad_alloc
+ - image_load_error
+ This exception is thrown if there is some error that prevents
+ us from loading the given PNG file.
+ !*/
+
+ ~png_loader(
+ );
+ /*!
+ ensures
+ - all resources associated with *this has been released
+ !*/
+
+ bool is_gray(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a grayscale image without an alpha channel) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_graya(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a grayscale image with an alpha channel) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_rgb(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a 3 channel RGB image) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_rgba(
+ ) const;
+ /*!
+ ensures
+ - if (this object contains a 4 channel RGB alpha image) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unsigned int bit_depth (
+ ) const;
+ /*!
+ ensures
+ - returns the number of bits per channel in the image contained by this
+ object. The possible values are 8 or 16.
+ !*/
+
+ template<
+ typename image_type
+ >
+ void get_image(
+ image_type& img
+ ) const;
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - loads the PNG image stored in this object into img
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void load_png (
+ image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - performs: png_loader(file_name).get_image(image);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PNG_IMPORT_ABSTRACT
+
+
diff --git a/ml/dlib/dlib/image_processing.h b/ml/dlib/dlib/image_processing.h
new file mode 100644
index 000000000..a53f4a9d1
--- /dev/null
+++ b/ml/dlib/dlib/image_processing.h
@@ -0,0 +1,28 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_IMAGE_PROCESSInG_H_h_
+#define DLIB_IMAGE_PROCESSInG_H_h_
+
+#include "image_processing/scan_image.h"
+#include "image_processing/scan_image_pyramid.h"
+#include "image_processing/detection_template_tools.h"
+#include "image_processing/object_detector.h"
+#include "image_processing/box_overlap_testing.h"
+#include "image_processing/scan_image_pyramid_tools.h"
+#include "image_processing/setup_hashed_features.h"
+#include "image_processing/scan_image_boxes.h"
+#include "image_processing/scan_image_custom.h"
+#include "image_processing/remove_unobtainable_rectangles.h"
+#include "image_processing/scan_fhog_pyramid.h"
+#include "image_processing/shape_predictor.h"
+#include "image_processing/shape_predictor_trainer.h"
+#include "image_processing/correlation_tracker.h"
+
+#endif // DLIB_IMAGE_PROCESSInG_H_h_
+
+
diff --git a/ml/dlib/dlib/image_processing/box_overlap_testing.h b/ml/dlib/dlib/image_processing/box_overlap_testing.h
new file mode 100644
index 000000000..32409d13e
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/box_overlap_testing.h
@@ -0,0 +1,215 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BOX_OVERlAP_TESTING_Hh_
+#define DLIB_BOX_OVERlAP_TESTING_Hh_
+
+#include "box_overlap_testing_abstract.h"
+#include "../geometry.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_intersection_over_union (
+ const drectangle& a,
+ const drectangle& b
+ )
+ {
+ const double inner = a.intersect(b).area();
+ if (inner == 0)
+ return 0;
+ const double outer = (a+b).area();
+ return inner/outer;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_intersection_over_union (
+ const rectangle& a,
+ const rectangle& b
+ )
+ {
+ return box_intersection_over_union(drectangle(a),drectangle(b));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_percent_covered (
+ const drectangle& a,
+ const drectangle& b
+ )
+ {
+ const double inner = a.intersect(b).area();
+ if (inner == 0)
+ return 0;
+ return std::max(inner/a.area(), inner/b.area());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_percent_covered (
+ const rectangle& a,
+ const rectangle& b
+ )
+ {
+ return box_percent_covered(drectangle(a), drectangle(b));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_box_overlap
+ {
+ public:
+ test_box_overlap (
+ ) : iou_thresh(0.5), percent_covered_thresh(1.0)
+ {}
+
+ explicit test_box_overlap (
+ double iou_thresh_,
+ double percent_covered_thresh_ = 1.0
+ ) : iou_thresh(iou_thresh_), percent_covered_thresh(percent_covered_thresh_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= iou_thresh && iou_thresh <= 1 &&
+ 0 <= percent_covered_thresh && percent_covered_thresh <= 1,
+ "\t test_box_overlap::test_box_overlap(iou_thresh, percent_covered_thresh)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t iou_thresh: " << iou_thresh
+ << "\n\t percent_covered_thresh: " << percent_covered_thresh
+ << "\n\t this: " << this
+ );
+
+ }
+
+ bool operator() (
+ const dlib::rectangle& a,
+ const dlib::rectangle& b
+ ) const
+ {
+ const double inner = a.intersect(b).area();
+ if (inner == 0)
+ return false;
+
+ const double outer = (a+b).area();
+ if (inner/outer > iou_thresh ||
+ inner/a.area() > percent_covered_thresh ||
+ inner/b.area() > percent_covered_thresh)
+ return true;
+ else
+ return false;
+ }
+
+ double get_percent_covered_thresh (
+ ) const
+ {
+ return percent_covered_thresh;
+ }
+
+ double get_iou_thresh (
+ ) const
+ {
+ return iou_thresh;
+ }
+
+ private:
+ double iou_thresh;
+ double percent_covered_thresh;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const test_box_overlap& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.get_iou_thresh(), out);
+ serialize(item.get_percent_covered_thresh(), out);
+ }
+
+ inline void deserialize (
+ test_box_overlap& item,
+ std::istream& in
+ )
+ {
+ double percent_covered_thresh, iou_thresh;
+ deserialize(iou_thresh, in);
+ deserialize(percent_covered_thresh, in);
+ item = test_box_overlap(iou_thresh, percent_covered_thresh);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline test_box_overlap find_tight_overlap_tester (
+ const std::vector<std::vector<rectangle> >& rects
+ )
+ {
+ double max_pcov = 0;
+ double max_iou_score = 0;
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < rects[i].size(); ++j)
+ {
+ for (unsigned long k = j+1; k < rects[i].size(); ++k)
+ {
+ const rectangle a = rects[i][j];
+ const rectangle b = rects[i][k];
+ const double iou_score = (a.intersect(b)).area()/(double)(a+b).area();
+ const double pcov_a = (a.intersect(b)).area()/(double)(a).area();
+ const double pcov_b = (a.intersect(b)).area()/(double)(b).area();
+
+ if (iou_score > max_iou_score)
+ max_iou_score = iou_score;
+
+ if (pcov_a > max_pcov)
+ max_pcov = pcov_a;
+ if (pcov_b > max_pcov)
+ max_pcov = pcov_b;
+ }
+ }
+ }
+
+ // Relax these thresholds very slightly. We do this because on some systems the
+ // boxes that generated the max values erroneously trigger a box overlap iou even
+ // though their percent covered and iou values are *equal* to the thresholds but
+ // not greater. That is, sometimes when double values get moved around they change
+ // their values slightly, so this avoids the problems that can create.
+ max_iou_score = std::min(1.0000001*max_iou_score, 1.0);
+ max_pcov = std::min(1.0000001*max_pcov, 1.0);
+ return test_box_overlap(max_iou_score, max_pcov);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool overlaps_any_box (
+ const test_box_overlap& tester,
+ const std::vector<rectangle>& rects,
+ const rectangle& rect
+ )
+ {
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ if (tester(rects[i],rect))
+ return true;
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool overlaps_any_box (
+ const std::vector<rectangle>& rects,
+ const rectangle& rect
+ )
+ {
+ return overlaps_any_box(test_box_overlap(),rects,rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOX_OVERlAP_TESTING_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h b/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h
new file mode 100644
index 000000000..1bb4a28ae
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h
@@ -0,0 +1,201 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_
+#ifdef DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_
+
+#include "../geometry.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_intersection_over_union (
+ const drectangle& a,
+ const drectangle& b
+ );
+ /*!
+ ensures
+ - returns area of the intersection of a and b divided by (a+b).area(). If both
+ boxes are empty then returns 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_intersection_over_union (
+ const rectangle& a,
+ const rectangle& b
+ );
+ /*!
+ ensures
+ - returns area of the intersection of a and b divided by (a+b).area(). If both
+ boxes are empty then returns 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_percent_covered (
+ const drectangle& a,
+ const drectangle& b
+ );
+ /*!
+ ensures
+ - let OVERLAP = a.intersect(b).area()
+ - This function returns max(OVERLAP/a.area(), OVERLAP/b.area())
+ e.g. If one box entirely contains another then this function returns 1, if
+ they don't overlap at all it returns 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double box_percent_covered (
+ const rectangle& a,
+ const rectangle& b
+ );
+ /*!
+ ensures
+ - let OVERLAP = a.intersect(b).area()
+ - This function returns max(OVERLAP/a.area(), OVERLAP/b.area())
+ e.g. If one box entirely contains another then this function returns 1, if
+ they don't overlap at all it returns 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class test_box_overlap
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple function object for determining if two rectangles
+ overlap.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is safe provided that
+ only const member functions are invoked. Otherwise, access must be
+ protected by a mutex lock.
+ !*/
+
+ public:
+ test_box_overlap (
+ );
+ /*!
+ ensures
+ - #get_iou_thresh() == 0.5
+ - #get_percent_covered_thresh() == 1.0
+ !*/
+
+ explicit test_box_overlap (
+ double iou_thresh,
+ double percent_covered_thresh = 1.0
+ );
+ /*!
+ requires
+ - 0 <= iou_thresh <= 1
+ - 0 <= percent_covered_thresh <= 1
+ ensures
+ - #get_iou_thresh() == iou_thresh
+ - #get_percent_covered_thresh() == percent_covered_thresh
+ !*/
+
+ bool operator() (
+ const dlib::rectangle& a,
+ const dlib::rectangle& b
+ ) const;
+ /*!
+ ensures
+ - returns true if a and b overlap "enough". This is defined precisely below.
+ - if (a.intersect(b).area()/(a+b).area() > get_iou_thresh() ||
+ a.intersect(b).area()/a.area() > get_percent_covered_thresh() ||
+ a.intersect(b).area()/b.area() > get_percent_covered_thresh() ) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ double get_iou_thresh (
+ ) const;
+ /*!
+ ensures
+ - returns the threshold used to determine if two rectangle's intersection
+ over union value is big enough to be considered a match. Note that the
+ iou score varies from 0 to 1 and only becomes 1 when two rectangles are
+ identical.
+ !*/
+
+ double get_percent_covered_thresh (
+ ) const;
+ /*!
+ ensures
+ - returns the threshold used to determine if two rectangles overlap. This
+ value is the percent of a rectangle's area covered by another rectangle.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const test_box_overlap& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ test_box_overlap& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ test_box_overlap find_tight_overlap_tester (
+ const std::vector<std::vector<rectangle> >& rects
+ );
+ /*!
+ ensures
+ - This function finds the most restrictive test_box_overlap object possible
+ that is consistent with the given set of sets of rectangles.
+ - To be precise, this function finds and returns a test_box_overlap object
+ TBO such that:
+ - TBO.get_iou_thresh() and TBO.get_percent_covered_thresh() are as small
+ as possible such that the following conditions are satisfied.
+ - for all valid i:
+ - for all distinct rectangles A and B in rects[i]:
+ - TBO(A,B) == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool overlaps_any_box (
+ const test_box_overlap& tester,
+ const std::vector<rectangle>& rects,
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - returns true if rect overlaps any box in rects and false otherwise. Overlap
+ is determined based on the given tester object.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool overlaps_any_box (
+ const std::vector<rectangle>& rects,
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - returns overlaps_any_box(test_box_overlap(), rects, rect)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/correlation_tracker.h b/ml/dlib/dlib/image_processing/correlation_tracker.h
new file mode 100644
index 000000000..f005ddc7b
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/correlation_tracker.h
@@ -0,0 +1,404 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CORRELATION_TrACKER_H_
+#define DLIB_CORRELATION_TrACKER_H_
+
+#include "correlation_tracker_abstract.h"
+#include "../geometry.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../image_transforms/assign_image.h"
+#include "../image_transforms/interpolation.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class correlation_tracker
+ {
+ public:
+
+ explicit correlation_tracker (unsigned long filter_size = 6,
+ unsigned long num_scale_levels = 5,
+ unsigned long scale_window_size = 23,
+ double regularizer_space = 0.001,
+ double nu_space = 0.025,
+ double regularizer_scale = 0.001,
+ double nu_scale = 0.025,
+ double scale_pyramid_alpha = 1.020
+ )
+ : filter_size(1 << filter_size), num_scale_levels(1 << num_scale_levels),
+ scale_window_size(scale_window_size),
+ regularizer_space(regularizer_space), nu_space(nu_space),
+ regularizer_scale(regularizer_scale), nu_scale(nu_scale),
+ scale_pyramid_alpha(scale_pyramid_alpha)
+ {
+ // Create the cosine mask used for space filtering.
+ mask = make_cosine_mask();
+
+ // Create the cosine mask used for the scale filtering.
+ scale_cos_mask.resize(get_num_scale_levels());
+ const long max_level = get_num_scale_levels()/2;
+ for (unsigned long k = 0; k < get_num_scale_levels(); ++k)
+ {
+ double dist = std::abs((double)k-max_level)/max_level*pi/2;
+ dist = std::min(dist, pi/2);
+ scale_cos_mask[k] = std::cos(dist);
+ }
+ }
+
+ template <typename image_type>
+ void start_track (
+ const image_type& img,
+ const drectangle& p
+ )
+ {
+ DLIB_CASSERT(p.is_empty() == false,
+ "\t void correlation_tracker::start_track()"
+ << "\n\t You can't give an empty rectangle."
+ );
+
+ B.set_size(0,0);
+
+ point_transform_affine tform = inv(make_chip(img, p, F));
+ for (unsigned long i = 0; i < F.size(); ++i)
+ fft_inplace(F[i]);
+ make_target_location_image(tform(center(p)), G);
+ A.resize(F.size());
+ for (unsigned long i = 0; i < F.size(); ++i)
+ {
+ A[i] = pointwise_multiply(G, F[i]);
+ B += squared(real(F[i]))+squared(imag(F[i]));
+ }
+
+ position = p;
+
+ // now do the scale space stuff
+ make_scale_space(img, Fs);
+ for (unsigned long i = 0; i < Fs.size(); ++i)
+ fft_inplace(Fs[i]);
+ make_scale_target_location_image(get_num_scale_levels()/2, Gs);
+ Bs.set_size(0);
+ As.resize(Fs.size());
+ for (unsigned long i = 0; i < Fs.size(); ++i)
+ {
+ As[i] = pointwise_multiply(Gs, Fs[i]);
+ Bs += squared(real(Fs[i]))+squared(imag(Fs[i]));
+ }
+ }
+
+
+ unsigned long get_filter_size (
+ ) const { return filter_size; }
+
+ unsigned long get_num_scale_levels(
+ ) const { return num_scale_levels; }
+
+ unsigned long get_scale_window_size (
+ ) const { return scale_window_size; }
+
+ double get_regularizer_space (
+ ) const { return regularizer_space; }
+ inline double get_nu_space (
+ ) const { return nu_space;}
+
+ double get_regularizer_scale (
+ ) const { return regularizer_scale; }
+ double get_nu_scale (
+ ) const { return nu_scale;}
+
+ drectangle get_position (
+ ) const
+ {
+ return position;
+ }
+
+ double get_scale_pyramid_alpha (
+ ) const { return scale_pyramid_alpha; }
+
+
+ template <typename image_type>
+ double update_noscale(
+ const image_type& img,
+ const drectangle& guess
+ )
+ {
+ DLIB_CASSERT(get_position().is_empty() == false,
+ "\t double correlation_tracker::update()"
+ << "\n\t You must call start_track() first before calling update()."
+ );
+
+
+ const point_transform_affine tform = make_chip(img, guess, F);
+ for (unsigned long i = 0; i < F.size(); ++i)
+ fft_inplace(F[i]);
+
+ // use the current filter to predict the object's location
+ G = 0;
+ for (unsigned long i = 0; i < F.size(); ++i)
+ G += pointwise_multiply(F[i],conj(A[i]));
+ G = pointwise_multiply(G, reciprocal(B+get_regularizer_space()));
+ ifft_inplace(G);
+ const dlib::vector<double,2> pp = max_point_interpolated(real(G));
+
+
+ // Compute the peak to side lobe ratio.
+ const point p = pp;
+ running_stats<double> rs;
+ const rectangle peak = centered_rect(p, 8,8);
+ for (long r = 0; r < G.nr(); ++r)
+ {
+ for (long c = 0; c < G.nc(); ++c)
+ {
+ if (!peak.contains(point(c,r)))
+ rs.add(G(r,c).real());
+ }
+ }
+ const double psr = (G(p.y(),p.x()).real()-rs.mean())/rs.stddev();
+
+ // update the position of the object
+ position = translate_rect(guess, tform(pp)-center(guess));
+
+ // now update the position filters
+ make_target_location_image(pp, G);
+ B *= (1-get_nu_space());
+ for (unsigned long i = 0; i < F.size(); ++i)
+ {
+ A[i] = get_nu_space()*pointwise_multiply(G, F[i]) + (1-get_nu_space())*A[i];
+ B += get_nu_space()*(squared(real(F[i]))+squared(imag(F[i])));
+ }
+
+ return psr;
+ }
+
+ template <typename image_type>
+ double update (
+ const image_type& img,
+ const drectangle& guess
+ )
+ {
+ double psr = update_noscale(img, guess);
+
+ // Now predict the scale change
+ make_scale_space(img, Fs);
+ for (unsigned long i = 0; i < Fs.size(); ++i)
+ fft_inplace(Fs[i]);
+ Gs = 0;
+ for (unsigned long i = 0; i < Fs.size(); ++i)
+ Gs += pointwise_multiply(Fs[i],conj(As[i]));
+ Gs = pointwise_multiply(Gs, reciprocal(Bs+get_regularizer_scale()));
+ ifft_inplace(Gs);
+ const double pos = max_point_interpolated(real(Gs)).y();
+
+ // update the rectangle's scale
+ position *= std::pow(get_scale_pyramid_alpha(), pos-(double)get_num_scale_levels()/2);
+
+
+
+ // Now update the scale filters
+ make_scale_target_location_image(pos, Gs);
+ Bs *= (1-get_nu_scale());
+ for (unsigned long i = 0; i < Fs.size(); ++i)
+ {
+ As[i] = get_nu_scale()*pointwise_multiply(Gs, Fs[i]) + (1-get_nu_scale())*As[i];
+ Bs += get_nu_scale()*(squared(real(Fs[i]))+squared(imag(Fs[i])));
+ }
+
+
+ return psr;
+ }
+
+ template <typename image_type>
+ double update_noscale (
+ const image_type& img
+ )
+ {
+ return update_noscale(img, get_position());
+ }
+
+ template <typename image_type>
+ double update(
+ const image_type& img
+ )
+ {
+ return update(img, get_position());
+ }
+
+ private:
+
+ template <typename image_type>
+ void make_scale_space(
+ const image_type& img,
+ std::vector<matrix<std::complex<double>,0,1> >& Fs
+ ) const
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ // Make an image pyramid and put it into the chips array.
+ const long chip_size = get_scale_window_size();
+ drectangle ppp = position*std::pow(get_scale_pyramid_alpha(), -(double)get_num_scale_levels()/2);
+ dlib::array<array2d<pixel_type> > chips;
+ std::vector<dlib::vector<double,2> > from_points, to_points;
+ from_points.push_back(point(0,0));
+ from_points.push_back(point(chip_size-1,0));
+ from_points.push_back(point(chip_size-1,chip_size-1));
+ for (unsigned long i = 0; i < get_num_scale_levels(); ++i)
+ {
+ array2d<pixel_type> chip(chip_size,chip_size);
+
+ // pull box into chip
+ to_points.clear();
+ to_points.push_back(ppp.tl_corner());
+ to_points.push_back(ppp.tr_corner());
+ to_points.push_back(ppp.br_corner());
+ transform_image(img,chip,interpolate_bilinear(),find_affine_transform(from_points, to_points));
+
+ chips.push_back(chip);
+ ppp *= get_scale_pyramid_alpha();
+ }
+
+
+ // extract HOG for each chip
+ dlib::array<dlib::array<array2d<float> > > hogs(chips.size());
+ for (unsigned long i = 0; i < chips.size(); ++i)
+ {
+ extract_fhog_features(chips[i], hogs[i], 4);
+ hogs[i].resize(32);
+ assign_image(hogs[i][31], chips[i]);
+ assign_image(hogs[i][31], mat(hogs[i][31])/255.0);
+ }
+
+ // Now copy the hog features into the Fs outputs and also apply the cosine
+ // windowing.
+ Fs.resize(hogs[0].size()*hogs[0][0].size());
+ unsigned long i = 0;
+ for (long r = 0; r < hogs[0][0].nr(); ++r)
+ {
+ for (long c = 0; c < hogs[0][0].nc(); ++c)
+ {
+ for (unsigned long j = 0; j < hogs[0].size(); ++j)
+ {
+ Fs[i].set_size(hogs.size());
+ for (unsigned long k = 0; k < hogs.size(); ++k)
+ {
+ Fs[i](k) = hogs[k][j][r][c]*scale_cos_mask[k];
+ }
+ ++i;
+ }
+ }
+ }
+ }
+
+ template <typename image_type>
+ point_transform_affine make_chip (
+ const image_type& img,
+ drectangle p,
+ std::vector<matrix<std::complex<double> > >& chip
+ ) const
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ array2d<pixel_type> temp;
+ const double padding = 1.4;
+ const chip_details details(p*padding, chip_dims(get_filter_size(), get_filter_size()));
+ extract_image_chip(img, details, temp);
+
+
+ chip.resize(32);
+ dlib::array<array2d<float> > hog;
+ extract_fhog_features(temp, hog, 1, 3,3 );
+ for (unsigned long i = 0; i < hog.size(); ++i)
+ assign_image(chip[i], pointwise_multiply(matrix_cast<double>(mat(hog[i])), mask));
+
+ assign_image(chip[31], temp);
+ assign_image(chip[31], pointwise_multiply(mat(chip[31]), mask)/255.0);
+
+ return inv(get_mapping_to_chip(details));
+ }
+
+ void make_target_location_image (
+ const dlib::vector<double,2>& p,
+ matrix<std::complex<double> >& g
+ ) const
+ {
+ g.set_size(get_filter_size(), get_filter_size());
+ g = 0;
+ rectangle area = centered_rect(p, 21,21).intersect(get_rect(g));
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ double dist = length(point(c,r)-p);
+ g(r,c) = std::exp(-dist/3.0);
+ }
+ }
+ fft_inplace(g);
+ g = conj(g);
+ }
+
+
+ void make_scale_target_location_image (
+ const double scale,
+ matrix<std::complex<double>,0,1>& g
+ ) const
+ {
+ g.set_size(get_num_scale_levels());
+ for (long i = 0; i < g.size(); ++i)
+ {
+ double dist = std::pow((i-scale),2.0);
+ g(i) = std::exp(-dist/1.000);
+ }
+ fft_inplace(g);
+ g = conj(g);
+ }
+
+ matrix<double> make_cosine_mask (
+ ) const
+ {
+ const long size = get_filter_size();
+ matrix<double> temp(size,size);
+ point cent = center(get_rect(temp));
+ for (long r = 0; r < temp.nr(); ++r)
+ {
+ for (long c = 0; c < temp.nc(); ++c)
+ {
+ point delta = point(c,r)-cent;
+ double dist = length(delta)/(size/2.0)*(pi/2);
+ dist = std::min(dist*1.0, pi/2);
+
+ temp(r,c) = std::cos(dist);
+ }
+ }
+ return temp;
+ }
+
+
+ std::vector<matrix<std::complex<double> > > A, F;
+ matrix<double> B;
+
+ std::vector<matrix<std::complex<double>,0,1> > As, Fs;
+ matrix<double,0,1> Bs;
+ drectangle position;
+
+ matrix<double> mask;
+ std::vector<double> scale_cos_mask;
+
+ // G and Gs do not logically contribute to the state of this object. They are
+ // here just so we can void reallocating them over and over.
+ matrix<std::complex<double> > G;
+ matrix<std::complex<double>,0,1> Gs;
+
+ unsigned long filter_size;
+ unsigned long num_scale_levels;
+ unsigned long scale_window_size;
+ double regularizer_space;
+ double nu_space;
+ double regularizer_scale;
+ double nu_scale;
+ double scale_pyramid_alpha;
+ };
+}
+
+#endif // DLIB_CORRELATION_TrACKER_H_
+
diff --git a/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h b/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h
new file mode 100644
index 000000000..5514f5e76
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CORRELATION_TrACKER_ABSTRACT_H_
+#ifdef DLIB_CORRELATION_TrACKER_ABSTRACT_H_
+
+#include "../geometry/drectangle_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class correlation_tracker
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for tracking moving objects in a video stream. You give it
+ the bounding box of an object in the first frame and it attempts to track the
+ object in the box from frame to frame.
+
+ This tool is an implementation of the method described in the following paper:
+ Danelljan, Martin, et al. "Accurate scale estimation for robust visual
+ tracking." Proceedings of the British Machine Vision Conference BMVC. 2014.
+ !*/
+
+ public:
+
+ explicit correlation_tracker (unsigned long filter_size = 6,
+ unsigned long num_scale_levels = 5,
+ unsigned long scale_window_size = 23,
+ double regularizer_space = 0.001,
+ double nu_space = 0.025,
+ double regularizer_scale = 0.001,
+ double nu_scale = 0.025,
+ double scale_pyramid_alpha = 1.020
+ );
+ /*!
+ requires
+ - p.is_empty() == false
+ ensures
+ - Initializes correlation_tracker. Higher value of filter_size and
+ num_scale_levels increases tracking precision but requires more CPU
+ for processing. Recommended values for filter_size = 5-7,
+ default = 6, for num_scale_levels = 4-6, default = 5
+ - #get_position().is_empty() == true
+ !*/
+
+ template <
+ typename image_type
+ >
+ void start_track (
+ const image_type& img,
+ const drectangle& p
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - p.is_empty() == false
+ ensures
+ - This object will start tracking the thing inside the bounding box in the
+ given image. That is, if you call update() with subsequent video frames
+ then it will try to keep track of the position of the object inside p.
+ - #get_position() == p
+ !*/
+
+ drectangle get_position (
+ ) const;
+ /*!
+ ensures
+ - returns the predicted position of the object under track.
+ !*/
+
+ template <
+ typename image_type
+ >
+ double update_noscale (
+ const image_type& img,
+ const drectangle& guess
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - get_position().is_empty() == false
+ (i.e. you must have started tracking by calling start_track())
+ ensures
+ - When searching for the object in img, we search in the area around the
+ provided guess. This function only tracks object position without trying
+ to track the scale
+ - #get_position() == the new predicted location of the object in img. This
+ location will be a copy of guess that has been translated and NOT scaled
+ appropriately based on the content of img so that it, hopefully, bounds
+ the object in img.
+ - Returns the peak to side-lobe ratio. This is a number that measures how
+ confident the tracker is that the object is inside #get_position().
+ Larger values indicate higher confidence.
+ !*/
+
+ template <
+ typename image_type
+ >
+ double update (
+ const image_type& img,
+ const drectangle& guess
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - get_position().is_empty() == false
+ (i.e. you must have started tracking by calling start_track())
+ ensures
+ - When searching for the object in img, we search in the area around the
+ provided guess.
+ - #get_position() == the new predicted location of the object in img. This
+ location will be a copy of guess that has been translated and scaled
+ appropriately based on the content of img so that it, hopefully, bounds
+ the object in img.
+ - Returns the peak to side-lobe ratio. This is a number that measures how
+ confident the tracker is that the object is inside #get_position().
+ Larger values indicate higher confidence.
+ !*/
+
+ template <
+ typename image_type
+ >
+ double update_noscale (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - get_position().is_empty() == false
+ (i.e. you must have started tracking by calling start_track())
+ ensures
+ - performs: return update_noscale(img, get_position())
+ !*/
+ template <
+ typename image_type
+ >
+ double update (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - get_position().is_empty() == false
+ (i.e. you must have started tracking by calling start_track())
+ ensures
+ - performs: return update(img, get_position())
+ !*/
+
+ };
+}
+
+#endif // DLIB_CORRELATION_TrACKER_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/image_processing/detection_template_tools.h b/ml/dlib/dlib/image_processing/detection_template_tools.h
new file mode 100644
index 000000000..b22c109fe
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/detection_template_tools.h
@@ -0,0 +1,113 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DETECTION_TEMPlATE_TOOLS_Hh_
+#define DLIB_DETECTION_TEMPlATE_TOOLS_Hh_
+
+#include "detection_template_tools_abstract.h"
+#include "../geometry.h"
+#include "../matrix.h"
+#include <utility>
+#include <vector>
+#include <cmath>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle compute_box_dimensions (
+ const double width_to_height_ratio,
+ const double area
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(width_to_height_ratio > 0 && area > 0,
+ "\t rectangle compute_box_dimensions()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t width_to_height_ratio: " << width_to_height_ratio
+ << "\n\t area: " << area
+ );
+
+ /*
+ width*height == area
+ width/height == width_to_height_ratio
+ */
+ using namespace std;
+
+ const int height = (int)std::floor(std::sqrt(area/width_to_height_ratio) + 0.5);
+ const int width = (int)std::floor(area/height + 0.5);
+
+ return centered_rect(0,0,width,height);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<rectangle> create_single_box_detection_template (
+ const rectangle& object_box
+ )
+ {
+ std::vector<rectangle> temp;
+ temp.push_back(object_box);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<rectangle> create_overlapped_2x2_detection_template (
+ const rectangle& object_box
+ )
+ {
+ std::vector<rectangle> result;
+
+ const point c = center(object_box);
+
+ result.push_back(rectangle() + c + object_box.tl_corner() + object_box.tr_corner());
+ result.push_back(rectangle() + c + object_box.bl_corner() + object_box.br_corner());
+ result.push_back(rectangle() + c + object_box.tl_corner() + object_box.bl_corner());
+ result.push_back(rectangle() + c + object_box.tr_corner() + object_box.br_corner());
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<rectangle> create_grid_detection_template (
+ const rectangle& object_box,
+ unsigned int cells_x,
+ unsigned int cells_y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cells_x > 0 && cells_y > 0,
+ "\t std::vector<rectangle> create_grid_detection_template()"
+ << "\n\t The number of cells along a dimension can't be zero. "
+ << "\n\t cells_x: " << cells_x
+ << "\n\t cells_y: " << cells_y
+ );
+
+ std::vector<rectangle> result;
+
+ const matrix<double,1> x = linspace(object_box.left(), object_box.right(), cells_x+1);
+ const matrix<double,1> y = linspace(object_box.top(), object_box.bottom(), cells_y+1);
+
+ for (long j = 0; j+1 < y.size(); ++j)
+ {
+ for (long i = 0; i+1 < x.size(); ++i)
+ {
+ const dlib::vector<double,2> tl(x(i),y(j));
+ const dlib::vector<double,2> br(x(i+1),y(j+1));
+ result.push_back(rectangle(tl,br));
+ }
+ }
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_DETECTION_TEMPlATE_TOOLS_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h b/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h
new file mode 100644
index 000000000..30b0ad5b9
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h
@@ -0,0 +1,95 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_
+#ifdef DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_
+
+#include "../geometry.h"
+#include <utility>
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ rectangle compute_box_dimensions (
+ const double width_to_height_ratio,
+ const double area
+ );
+ /*!
+ requires
+ - area > 0
+ - width_to_height_ratio > 0
+ ensures
+ - returns a rectangle with the given area and width_to_height_ratio.
+ - In particular, returns a rectangle R such that:
+ - R.area() == area (to within integer precision)
+ - R.width()/R.height() == width_to_height_ratio (to within integer precision)
+ - center(R) == point(0,0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<rectangle> create_single_box_detection_template (
+ const rectangle& object_box
+ );
+ /*!
+ ensures
+ - returns a vector that contains only object_box.
+ - In particular, returns a vector V such that:
+ - V.size() == 1
+ - V[0] == object_box
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<rectangle> create_overlapped_2x2_detection_template (
+ const rectangle& object_box
+ );
+ /*!
+ ensures
+ - Divides object_box up into four overlapping regions, the
+ top half, bottom half, left half, and right half. These
+ four rectangles are returned inside a std::vector.
+ - In particular, returns a vector V such that:
+ - V.size() == 4
+ - V[0] == top half of object_box
+ - V[1] == bottom half of object_box
+ - V[2] == left half of object_box
+ - V[3] == right half of object_box
+ - for all valid i: object_box.contains(V[i]) == true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<rectangle> create_grid_detection_template (
+ const rectangle& object_box,
+ unsigned int cells_x,
+ unsigned int cells_y
+ );
+ /*!
+ requires
+ - cells_x > 0
+ - cells_y > 0
+ ensures
+ - Divides object_box up into a grid and returns a vector
+ containing all the rectangles corresponding to elements
+ of the grid. Moreover, the grid will be cells_x elements
+ wide and cells_y elements tall.
+ - In particular, returns a vector V such that:
+ - V.size() == cells_x*cells_y
+ - for all valid i:
+ - object_box.contains(V[i]) == true
+ - V[i] == The rectangle corresponding to the ith grid
+ element.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/image_processing/frontal_face_detector.h b/ml/dlib/dlib/image_processing/frontal_face_detector.h
new file mode 100644
index 000000000..3f4b59769
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/frontal_face_detector.h
@@ -0,0 +1,2373 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FRONTAL_FACE_DETECTOr_Hh_
+#define DLIB_FRONTAL_FACE_DETECTOr_Hh_
+
+#include "frontal_face_detector_abstract.h"
+#include "../image_processing/object_detector.h"
+#include "../image_processing/scan_fhog_pyramid.h"
+#include <sstream>
+#include "../compress_stream.h"
+#include "../base64.h"
+
+namespace dlib
+{
+ typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > frontal_face_detector;
+ inline const std::string get_serialized_frontal_faces();
+
+ inline frontal_face_detector get_frontal_face_detector()
+ {
+ std::istringstream sin(get_serialized_frontal_faces());
+ frontal_face_detector detector;
+ deserialize(detector, sin);
+ return detector;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ It is built out of 5 HOG filters. A front looking, left looking, right looking,
+ front looking but rotated left, and finally a front looking but rotated right one.
+
+ Moreover, here is the training log and parameters used to generate the filters:
+ The front detector:
+ trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml
+ upsampled each image by 2:1
+ used pyramid_down<6>
+ loss per missed target: 1
+ epsilon: 0.05
+ padding: 0
+ detection window size: 80 80
+ C: 700
+ nuclear norm regularizer: 9
+ cell_size: 8
+ num filters: 78
+ num images: 4748
+ Train detector (precision,recall,AP): 0.999793 0.895517 0.895368
+ singular value threshold: 0.15
+
+ The left detector:
+ trained on labeled_faces_in_the_wild/left_faces.xml
+ upsampled each image by 2:1
+ used pyramid_down<6>
+ loss per missed target: 2
+ epsilon: 0.05
+ padding: 0
+ detection window size: 80 80
+ C: 250
+ nuclear norm regularizer: 8
+ cell_size: 8
+ num filters: 63
+ num images: 493
+ Train detector (precision,recall,AP): 0.991803 0.86019 0.859486
+ singular value threshold: 0.15
+
+ The right detector:
+ trained left-right flip of labeled_faces_in_the_wild/left_faces.xml
+ upsampled each image by 2:1
+ used pyramid_down<6>
+ loss per missed target: 2
+ epsilon: 0.05
+ padding: 0
+ detection window size: 80 80
+ C: 250
+ nuclear norm regularizer: 8
+ cell_size: 8
+ num filters: 66
+ num images: 493
+ Train detector (precision,recall,AP): 0.991781 0.85782 0.857341
+ singular value threshold: 0.19
+
+ The front-rotate-left detector:
+ trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml
+ upsampled each image by 2:1
+ used pyramid_down<6>
+ rotated left 27 degrees
+ loss per missed target: 1
+ epsilon: 0.05
+ padding: 0
+ detection window size: 80 80
+ C: 700
+ nuclear norm regularizer: 9
+ cell_size: 8
+ num images: 4748
+ singular value threshold: 0.12
+
+ The front-rotate-right detector:
+ trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml
+ upsampled each image by 2:1
+ used pyramid_down<6>
+ rotated right 27 degrees
+ loss per missed target: 1
+ epsilon: 0.05
+ padding: 0
+ detection window size: 80 80
+ C: 700
+ nuclear norm regularizer: 9
+ cell_size: 8
+ num filters: 89
+ num images: 4748
+ Train detector (precision,recall,AP): 1 0.897369 0.897369
+ singular value threshold: 0.15
+ */
+ inline const std::string get_serialized_frontal_faces()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'object_detector.dat' we want to decode and return.
+ sout << "AW2B5ZIvv09mlKLVYjKqbJC05yeR2KsCpPGEGOgn2QlwM92S4UT4HgQkV0V9WqYRf6xETTSVKz7Z";
+ sout << "YcJ84Jc4C3+VdPgZDhV+LDt6qAt3OI4nA9zN4Y9cCIb6ivlETkN/JMmapbOAUW2mrSzDif5zjAaq";
+ sout << "+NFvw/5V0Jciopw9tR6nYtV41unWGvyyfsO9CcqvDy81QIydToHh0a7UaL0jCtA2DYzkViDufxyv";
+ sout << "Kpsn4xMyiU0haM1ge3UktIO48io/gSzjEKu0YYAffbD2YO1IE34tUH15Z3Z9NjkBFxTytDgrMxk8";
+ sout << "i9MYq+Nl9nS421aogmec3ugExJYjLZMHs4KAk71jvG8vtJyJEA3qyLY6lvONt98gzQwGQ9+2B6de";
+ sout << "ocb/DDJUza6mvudHQNJBYraR4gCWcIn9gFu2rJiRHf4IiqP4GEB3B1zKiHfJRo9jZbhxQUitAxAx";
+ sout << "U/E2SuuHGZDilqK9AJ4K41RAudraxF9li/Bs4f+CK3G8Z/c97P7WLVekJL2ws+MsCdL9ObHE5ePD";
+ sout << "uLLQWBy5NUbgPVM6HEnhnOiZk3rA4DYNqbABy3uemablAln9BLGkk4wrm2UcicacnzY8Aq054Ttb";
+ sout << "3CCTcG4SOSPfePl/7T1M6Uy1hOesp5MpXfUR8gBKr4466dbdXCDHSahI05gra6NzxkOpOo2mOqBg";
+ sout << "LYNGZUkHK4tdRyyD12N1MH+nJiMJbgk+qj54t5i3AuEr/71HTRXoTT8AEYbvc9y4f2WAlliQYXPn";
+ sout << "O2Uaza3lKYrH7mFjKMNhLfvrezy9fe+1asbSlRKelnU3eY4lhD6fTVJjXqZypBfMnfmGQQJ0Q7g5";
+ sout << "1Z/9GzpRyZnPSzQljtJgzVp8Gk0z3fuKiXPO9g+s4XL2cEuxBOFij0KGTy4eNitM0gcPc6xzp3tz";
+ sout << "6Wv0W2h7w4h+V8Bzvyn8ag1sbEO0G1Lf2BrDVM9+pNxFoWFxYHqdoOmJPVvb8PRQqoC5bkqhplFr";
+ sout << "TR5l3XsQedgwsnkadxNZQ3MbRJyo0JU0kvV1cfphLcn24MIIKqAnw3daXqbJaba+oCUep5GTuzI7";
+ sout << "nad7ykHNN0iFkgYXMmXJl+F5TsS8y+izuHlXAX6wX1qRVzWJwCpM5oVVG/5eYTzg0J9C1bCcNyHL";
+ sout << "2w5TJFYrD8bq3O+Y3fiO5LJ8F5/vsu2EBUMi1+eP1WfsTwd6N9jFtF5gA5sHX3zI925aDqVx9byr";
+ sout << "j4X5yr68p5P6f8wSLL8jzW8i4a0yP3zXlqN6QQDY1ssfNsMf43tOTtmbBlmxviL2egs4gvadD7Gd";
+ sout << "fRNowL71P3mkqRmnrnihlI01NbDl+Trzsh3EOn43PRC9nl8yo+fYVH8GqS8JGy1xOw4G479vOifI";
+ sout << "9GC4BGnSDJdKgSnBwI1AJQ2TT8EZ//56lkRlgusg25TwC7uQ1zreeL6baYdgfXSggx3ULdNDGl5o";
+ sout << "ftRK9LDaop6XvB6I0ITsLYvAoGP/5sHfttDj6HlQW/LlzkSPmzY/FtV6h6bE+k1gG7BANrQjwOW5";
+ sout << "sfHNYadD1v4zIFdt2su3docGbGP/iDMvM+BmYIBP86zX5eIlTYwDmxXht95T6GCCjS/XuMMy12hd";
+ sout << "Fdb6lm1O42ieM4KQ/2EOFy3Ij+YOIapzYA6p6Jz9dtINpCojgUHyo6xc4HTNnEKRy+YN+awhb1l2";
+ sout << "FJdy2/QI3xGVNNTnWcQrsvjGZb/Z3VaZUltrIbnCeEZOeOCM0TxkBEhqFfI3qwMx8PUj+imUlTDM";
+ sout << "7N+p5sxmKLliHHovOO32ajBTKUSI9IMQzf3QY6dZDts4JkMYQ1xc6lpm679s1KMVVrWuOqiAU5Vs";
+ sout << "qehfnl+oMRngi0G0BnMne45CjU5RECvhg+Vkkxx0kAp38+9pY3XiO/DuyIxpOSPip2o0+9rZLF1Z";
+ sout << "cAUGnG85CFEXl96wpxvqVlIULUV2+pNJxdU+q1MkCsxDeXrvfjhEAJpPE38dUb3t4blsNUZ3wJ2w";
+ sout << "s6cXe0nEPWNkZlmEsXcFpw5zHe0Gd7YpXigz7Z+IVhvplpv686TJiLTpVPW2T1uJvSmMuG/FqvT5";
+ sout << "JIIMg2of1ydicw5EbWrqhIUzllX3l0u00gFziPmKAioiqCxjWojd9l3Q0Q6IsaZAH+WzV2xFabbY";
+ sout << "4b8SwoFvhe4qnUQLFdOSTbzeDIKP9B8bSiQwbjUBg3jYEWUrMz+eR9lpGu8603vChIEXaTxyMrO5";
+ sout << "SCeaVOgPE77potDoSUV1hsoW7ZqGCFH+AGyVTohitS0iqZbIxC7+7rnVP8XfXw5YpSajF94z2TSd";
+ sout << "jW0KpmuCZ88DTCPFamf5zh917qp/PzQOGTdalr+Ov+ogvrJraDnoE+ONWrdHqBm7Adgn8/wy5vzX";
+ sout << "fNu1AT14eYrEmWmXvt6JDAbBYqP8Aw8b1QRZff11MblUh0IpztedWhifGy/RFJUN0/e66Mh0cKeF";
+ sout << "plmK6NqchTzOQMKJVq9jxdyurcjcA0uu4dVJ1XXkAtxBim2J2m0zcwX/+HcRe9VbeNehmDbUC49o";
+ sout << "ktNvrwbbB1IUV/c0MNCruV359DVINXskQTK12g2X5qprOLW+YPO6CnTFpJRsiFBoLllF1sUTjROH";
+ sout << "SrHHRYp3W5t5gqfT4afBxmtTmpJEG0oG4eNfMhxEhQ7HjoVhahOM6px9Be9S+4ca/w+zII7NnUkY";
+ sout << "Iaas+FW7vhOIDOiV82SpJqBjdY9eIP//XGR1DFQKI5cLKmT2/DF8tB9XcqTgmVWNMVt9Xw21CaeR";
+ sout << "eYeoWvLHlm8o7ahtJCSQ0iHypTZMA16wdJ5IJD5WoYd50rUn58RBa9sTXT/t/KhxJfG5OWXl55eq";
+ sout << "abYojSlluFyvFSk7Z/wu/EqFUEBD8r4OIrlJCMZl6kKy4EncmjUrb3mG6aDKxsaRBRBkRRya9t77";
+ sout << "epMG60v3MRCcY9E+n9sXAOUpf+ErN7iD6FY4XFpq5R0Z+6MiLRE0af/JQ9R42quTl8CLH7609DDd";
+ sout << "s8+8bKA2zjvSJhWbwGURRCW8SK9tNKuemwkt3Eutm+xMJemP2JIVFVXYxjCvDmxIIODneu1vmcSy";
+ sout << "XadKkyjtYDwacddFAqGh0kLqHX9i/WoedVKC1Vuup+AYPkyZ1lPraGVqjq0nsiwp/vxm9c/+4/wS";
+ sout << "hW99Q+zoAZ0IWWeYAqcXGdZqvd58gx0/fmU/Pq4FqtCdJ2qnoUDMvjZeyWE7lA/Xf7AdLcz4XHNz";
+ sout << "VAidxMj7/K8p3KdK+XqED94Ey1WzpUQ2mH+10Zq/6jebtoYJlht9meMsjvjWxg4nwFIZY1QAMZPV";
+ sout << "phcmEwrLA+Z/Xjo+FEq4hKD8pIriQi4xT4uAoPzOFGp/ziwBbAb/EfYsspnVxpnERKblbDsV9bFK";
+ sout << "df5VSgeqg7p2auZBk/WkX/wOeXkulbiJA5lXTsgInJGoREJ+uaudFadnLD1pmjqq+VFJW6XOT++I";
+ sout << "arRHT1sYJY5mhqFztTeUGH5VXZNtRGl1nWpewvmgyK6T5XLUcuZqsyZVtzkkQ0eSR4h+nBuRAQZK";
+ sout << "EmqzcPrRKObVC7Xv+kMcnM+M2+zCuZoUSO/zt7OOXNt/B51oQ2DRqthPxgzUrWOvoOgZleeayImR";
+ sout << "rqG2QnkA8+Kb/mhxJ+SAqOjsIJATfLc13SzVIKVumz+uX5jUiZXWfWF/e1cdS1w3Nf+dNinnGQ2v";
+ sout << "Vf2SxiRlTfDTZTZXcLlT8VrCOP4UYvg0QrzBqU2myM41lZDUS2X/WOzvrNrRpEoCS6/OcbvMj5gf";
+ sout << "dzoZ1oqvaL3dosQ9/QAwI7wPC2/QQTRyDIbl4EuhcX/ebyueLqlxKRLPrmq/mE2YU6aU7cf+t8PM";
+ sout << "LX7J0eyNGl00TEWN0R7ui2xlfdnLfmILs+lNNthYvmUbtbncoqx0sCWgjk1Iqagp9uFlWA+6vMa6";
+ sout << "nx7Qg0Jz+Qn2u4iZpyGDZKUWmHgDYhRcKfsnjnzbztyNms4tEnmwtIeLwDqFPlC4BKefz9gja+tt";
+ sout << "0Om7TAtwcmWHve1ENSOQSKTLNvVwhjRLgmfK9SFUUjKeVSMp1g27Rf2WCoDyBXhRauHpqCdj7GH4";
+ sout << "AmlI3EUHSjS+/1ZoT/2DnHwuN+GKVh0d8k/7sGrln94r3JuxEPvfyQFvPSRlFkWYyKPdxb+H37L7";
+ sout << "DCYzgf43vxZ82JGYVXB7QBpO0VeEXiQcIgXV7uOsGKiXfFUwueL4kdNknk2hAfrFFQcpoBiQSz89";
+ sout << "sNRT1tH5Ipbuf19R66cogiaAesXnm42jLjQpnkMwd3S4F6P8zDL7m13u6ahcrWvZiyDuBc/+Th2+";
+ sout << "Swex+BSv2TkUorR6nwVoozQEA16/MlZB1acIFSrD0kSr22Vubdbo05svEAZ7DKIdQjDu2wTTOOvL";
+ sout << "IVJOYPjdxXZytBv1jRIhUyyYBvRtaFsl57ZWAmvbFEXZXLihnrBskTqxrxNPqhg+bLxicTDlFyHI";
+ sout << "UIipzL4AvofdYWolB8RvFyom18/szC67Flr1OW4axZ+k5S7249Y64eqtU1hk98joIhOdWaWBHxkL";
+ sout << "nP+ooeHeEvB5hNRIA23Yxoh4zzsWUB1KKvg1XRzjt2CBQ2FPaCfHsOKf52aaj0W2FByC81rpryrm";
+ sout << "Ye51T5zP5/N5j7yA+a907774PwIS3eYYyJRUCSh0ywfQ8rgkbjBdf2rKa0alzokz7Kmo2Iswnid0";
+ sout << "WzpliQr9KaPwAk7hkLjprMjzdJIug3KOKVAgygXP7rkgETIfTfZRG49EdJOjlW8mlmHZsO+arTFW";
+ sout << "vj0FgJCAQrrX0X9BOQ0MPu0friAGK0TNGsFs17lcHjaRNHXz3v6dY/MSR2TY82iyEkRofvNY/Xjr";
+ sout << "FRB2KM6Aq2pPpIjY4EuSQS5sU9ur5oxrKo68jNzoB9iRvmKhQq5HRSYKL5ACBF85HM2oWtyVl23y";
+ sout << "TqTW+jwNHfF+sc0FPS6xfwr3yvHi/OVlW046gnNLKOxO3RjGntaJeX8EVXGGpXDHLyle0UaZE0iG";
+ sout << "1xeLZjyKq5wRJ/Q1MPry/JJbfCXSIHeO2Uznqn7O5rcs5v3Z1PdlF7BfPUhnP7+Wcqryfi4xJ8rS";
+ sout << "BzyJkibOCzegvXnKTTw7q5/lrgh7LxfrY/4G6/Js8ibrUU9NGqBkOHUmxa9P7UPK43pz/bS7SWtl";
+ sout << "yA/3hBa0bv6hN0OeXVaBxtr8sMfS7FcvR3wtvmtKn4BlIYer3LMSvigPCK3K5seTPH3cx0J2uGzf";
+ sout << "SlPZus5idN8MnFCEiBUbs4W1M/BSw9EYA9rJyhDTyOYqKr6s1kagBUoVCXVlEPVgrJoppc4vLghu";
+ sout << "NMgUpcakhT8SAulssCjPb1UWPF92XPpn8/byK8dJoSFe1lfFb5Yog5YZMjgoKKbokk0n3eMlrbm2";
+ sout << "AGwIh0acdOXRR+lpeJQ240N/Waw3e+FhAI+AYfOkIXodtQcod08+F8uHCAAcd9dZvYyxZxNKjbCc";
+ sout << "aYTFUYPN53OZEwEyCIFWwPf0QhdhlpyAGCj9gqVU4N9b5FJYX2ZqVAl5JF4nl9yDWrJ3zmhwL4r1";
+ sout << "P8Pdv02ysNeZu76Y60+ffPXCqmjHjllu082gde9BXIEWdS1sd5qaH0qb8KRpV8WAYaM7/ccGTHQ+";
+ sout << "H+0C5o2904WS3MG8rR6LI6EqO2fcBnJzZ5BJX2bHv4kNHhQiW2tZjBlwKjuMH8Ddayd1BVqzjeuH";
+ sout << "5dfcL8xV4su36eRT/Vmanq/NZ80+KXsXZO1k88RIfQwwZdt5XribJfUSwzKGsKQrhu+8iUCjGP8l";
+ sout << "ScrIRdj3gjy2brM+zBr9z9pvFZR5NLjYN1Ko2BptMbEDxdjnYkYWix8BF1P+/PtSEJeGATIyl2al";
+ sout << "rAlEHX3ysdDjUic86ZNUx9c6N59ZcQkIr7IwFl6kc5sbuthroXmAnbW0A2UIO/LN6KFbbE53Up4Q";
+ sout << "KwoMeMHxlgEwundK+LV5WZ136K5JoA6SpvxzuKhCckg0Ev4+KtyA+1wlna6AHOQaj24BzblSd9k4";
+ sout << "2lWsVOwAOtGxFIRIxpou7S4yqPrvS93KPtVkrHDBqIveGcwoGfyw2ZSX+5o5SIZ5PUG3mFM/sNWw";
+ sout << "twketaHdV/ndITa3aJyGpqChs3hcwMOgODnpC+vjtY1D8zdp3pn4MBgb33jxc5kOCpDktiKGyaQN";
+ sout << "sQ7oaOy4aKmr4TFfWbrH3qeR9gz0utGL/iVHcgSlfl8rw4BFncc8HIB0SGJJhYE+lfEYpsP8H+1p";
+ sout << "pfG0yzIA752vcaOIWIGt+C/EvuXl5PP8qyE0aBe637yQd1aMyRhf46rsAIlhzwZ28wPYZ9KCaC41";
+ sout << "ap2+7/EJMw3HramAAo5OVqA6M5cV2V+MlGifSoVgTN+5TaY9EnqexQy2Gqw+9484Tv7QNVaEtwtY";
+ sout << "/O+aQ8nzc6H+clWCWJkWDvoqIrIqP4jFUaJ/FnqlPEb2GkPoNluJV92HqQj6fD2Iz0TQKCVQkVWq";
+ sout << "D/QuFVq8c+EC8Bz2j1cI0D30iwROmneb6XHTYVwn4yHkZ6LAoOz28fjT6dwJFdYo+Ci4Hhyl62tW";
+ sout << "P0al3X19i/IjH4Xi+ZH+lISFmA0oEJo4AG/oAklXtGtRtIwfKGIuIzqqEztmX9tY+INu7PtgH/FP";
+ sout << "z7d2f3CBZTZY4qTPMEPQ/th8jnjHrROZIM7Cej4v+zYms7NPlJ7x/k+eX5ISG7xEbWr8j+kr+R70";
+ sout << "bjGaz/rED73YxTMBmhQSKMDUjNaW/qclrQuvaNwUgM/VCtnY7NANztFMXhCa2hGjZaG/bp8Yc9IN";
+ sout << "T20nhrbTX+KPkcEmQjsHwyK8hT9XN6J+TD4iwdnb4A/KQI9JwaqpYPp0S1d99j0iqXlirvdPcotu";
+ sout << "AsUmNf1YOlK1I5KxaFA42emXXmg1vr7USuKiX62IslSjknRY0+bPxOcn09P0VK+HTTdLIZ+8p9+k";
+ sout << "fKsgY8ajl6qZ/LP5qbZ2KgHJHJNArwRSxNn4CR5ish97R+1A3DglEaWJ4vVuu4oaFIHc9eSgRMdJ";
+ sout << "IPJ52p+8SKIpjM3Tnig/Gw35R+sPcuZlpauFplYb3vcIoY7vN/f6+RKxPtWnuOfBh1iPJxJJfz+H";
+ sout << "MDVZihR471I+DLXgGrZ0fgMQZqVFelhF5eszKMOxB81TbRuPqUmneRijtWvR8QAySqzV3o+OoM/n";
+ sout << "fpoLxmcVQm3LanGF1VfCbk7X9dhocgWTpk9XjDzVjIjPceJ2IPuFjHcrNtu6L/fe/6sMqkKWNRQH";
+ sout << "8GGbrJU0/kqXIeYch5gXjiKFTIU/QIRt0e4YKlNV0Rpqkh8vY/X3OL8xbNCBd0bM4nCXMp5Ytwyw";
+ sout << "DEyjzBl4SvxgGSqG6ehAzY1LrZ3bHU0Bn/Q7vD6RIEr/WcxUvdr8oy1JuIey4PllgfcCaDdW1+wG";
+ sout << "YCz/81Acw7xOiG//gLZ+tApj+tGpMP3Z/vnC7bZmXAmXWCfZeWDwIcxX/V5Sco8G21PpzMYPM7k0";
+ sout << "1MEkR1PgNhpKv5he6criGZ6D/xVAfVJbxc7blkovBkLh109MFBCAGiA8zk3MAShzI7cynZVbyWGw";
+ sout << "+x5Bvl8/6xSUyG0MLFANqDilWFIEBpT8h6G6StX3WXoEfqJrO2pYMjQdqOg08AvXKWJg/xj4U2Mw";
+ sout << "m8+nK+zX8aXHC333WcQ+1eG918/0TEDoQAXep1atGq3wir0iBvurJHbOXffjGQalMd3AeFCLWaFn";
+ sout << "7tYSTWYcPnWwWuA47FxTSOPezm1PrihBIC7CyVjGHAGvtBdh2EjCVptJHYgft9Ivp0YpPFaGtT1c";
+ sout << "IsaiWl+dF+Yg0K4FVIpNqRq/g1EEpGni0mrTmlTKeeSiKzAXdGnjOZ/9woea73BFZkY0kAqMlrn4";
+ sout << "6AOIuXh9Af1UJe0OxzhcYKFjFuzj7Imjv0SgNaah5XePYFfLyqNUCctmTlFna9nZWZ4/Q/N0tqN1";
+ sout << "QJwMtZOdFdKoSwFcDrSuMBc2kKNCEgnXAB9azTyR6Frs6RDNbCOdmMKEIF0Ra6v4fqO/rzc+m8nM";
+ sout << "2GAyE9yBNQq1THcQSqlataFHDe9KkmlQ41F9hKifZEPJ2eMe4WbpMdXmjT0nNmxif9OiPMKR28EQ";
+ sout << "pcqtuJxTE2oQArxmoOD6uUSUpm+Xc190raj1/JA7kfFQPkONEkNn9fYRh7J9VvPk58RIkyDL3RfG";
+ sout << "SjlzXsvz0d2uU14U8ppyPSOUEgcvxUu9Zk/TcwZkWvQeJTPd/i7jbUyAHTPXy0secfKXWSoF4T1S";
+ sout << "AuuRuErtEIXmJm6bd3v6ozR+Vc494q5Nu80EGIEy+09XWaDi8E0ChYGPUn15jWmkw9aZ2SUGju+0";
+ sout << "OS7eaGTSBcbS7l3IaP+053oZvh6NN+iYo5Lb0rs+bog58fqpXLFLeaJFnHUmZipr4oX2EfpI3FuE";
+ sout << "1I7xjgdZiWMq57u9UId2PuNahTVN62Du790tZhGfoAACZxKx9xxi8nRwxz1Rh8uFosXdHJridfzF";
+ sout << "gzZhDTmxJjYCq7tg6769BcDtHxT2G+JOh2hMFV+aieGZEkBEfj6EWhuot2jR+VVjpLUhys154Fj1";
+ sout << "NLN0d0zMnJDThQlNGigIaHgVdQ+l/lNtN9ovAuVJRib/fYnSDRBpOQpOU5NuwyHjeHnYg20iKuBT";
+ sout << "ZWphFPD7M+zYlrVVH9Lg0AsaY5Yt0U7g+TXLuT/bi2tUz3rrrk/5bY7iLkGbEmOFZmxzXXqWdm8h";
+ sout << "ENOKVj/yrgSa/l93WqyOESSZk7hLMvP+OVkSj8qAKKBQ1+XyqVLODZZae0volbIcZe3HAAIjdYTm";
+ sout << "+JQfIGAWgkqcHwgv3WJiGPhq1WOVi4FSq2Dgxi5/J6cRg1Smsr9aCx0uNC2x362lI8Jd9yKn8m+m";
+ sout << "te+3Zx6sx0NCnYKxaWcH3V7BfF0hp0WQJ3vQbPG20PD/ACHvMEgmo2dDFit6m4yfWAQxHzQZE/3N";
+ sout << "E5TLT4EMnZxi00F6sV0G25nElE2t9CrGkLNxTUbK2sGKx+ybsveIWpoNtQty7hY8NF2KIICOd8QJ";
+ sout << "FsAKxGHbydI+9NV/8KyW2UID4JpoNJOQkh4B8pp/1bkBRPsikKLyowC6RWuWmVBm/DCPSIwkiV2A";
+ sout << "jNHVRmSoDO+U3eTMxbamjBV/H+xWgrBzBu+4aaFGH0MbKNtXG5COeCVMCtA5v9pmR65GLD/DYWcM";
+ sout << "JltMV+H82nUN8qDVTMpCSzrlkiv4Gmvh6b9HkxZC03g+IrBKAkXWkhIl3iYkIjLYNudFSUddDb6g";
+ sout << "/wHCk0lGJlbYim9VV0uRYJITZenRrzsMcb6g6Cm22cB9awV0qpixCGVW+jms3MfgcstzqdN36KPw";
+ sout << "C0IDdKjN2Bu9aNqHqWafK8Vl+oTYVU6foPJSOmHD3MhFHhuZk0oPtptRs/0aSZKH3FI+jz6KyTM0";
+ sout << "E9UDooIsxYAo7og8Ka1QCVel8cH4mmTBWTGLNNxuVwvHYgQc+j+QgKJ8DX3XzEHJQVL2fxCmm3i/";
+ sout << "tjJTltGK8o3S66OO3dN5g1KaYyxCDkmKjsGpyqAKGhdrzQzwLru4oof/b4cM6E/3aqGWH9pI27G2";
+ sout << "8jNYhu6r5LhYMpczurY9gssS94+RdUn4UFMt2zZSlpFsCY9E0NNaGwQ5sX0pcyk9r/FKWAWxT+e4";
+ sout << "b/buzfSIVsHWrzytkKOYCHMylaPd+juDOWX/Y1x5IBmR/VnpsIWbuYFjRlK7bvNoVcwitIZyI1Ku";
+ sout << "vmkH9u5YzndkbH9fj8FroFgMdZumIeRSFz2448yoIh/1+2wyEUXUvof32q5kktEutky9XtKCTIen";
+ sout << "LlWO9/7k0Kcz2Cp1S8bugmULKSLHEWMTtScZhEOl/o3jyMjLpbHhSfY5IHwZXVp6MO/bxpk4F4ur";
+ sout << "C2eAlsHUW0484VZFIm/GtgNRKq5H4MTRSmlzHxh0o5KnK87ZZNGKv2sGFoxhOKT8g9s8uz2ZfkI8";
+ sout << "HS0VWQ5y0Y9dY00ShJj9FShAuForC1EW8TBcgW2wjk7uN4CjXgupadGHC4hMFxVjJJ4tPj53PX+w";
+ sout << "KKTety47QKF0aeNAXeiNkzo0e/H8XYKYvyRKPpUhWbj5rzdkSev920dKjpq731kGRLUP8kljqmx4";
+ sout << "j/1ukvHqJrarb/U0LWECXe1mHUjehedJNCuDsXmlX3OIT4557z3W9vMbzKyu+0R+LN2YtYUFTdXA";
+ sout << "z442W3oqy7cJRIMfioDLTO4ry26sNo4uyq83j3iFx2iY4Wc41ZUGg9cwh5TVKg8XEh5US5xlsqVO";
+ sout << "kDR3XfYXA3GwuKklaNN6vImd+oP4g2ZYSl51f7tj5hd9xpTSRwIy3RJJ5VoTz+36jpT4Y3fnlppg";
+ sout << "GqBmWhJrY5UemTIoZbJ5X12NjQjW2HiKsuiCpLS9Wm0IXYWcRSYfiWYBLP+QZFyRA2VqVpwmY2X1";
+ sout << "EafYVxAjG2au8TnbfK+PLccuRg+kYNExJfD/hLUMyVg4wkLxP95L8CB85+g/1VomueeKJFnlrnkO";
+ sout << "ezCBls31aI/r4fMbdISFALkRwPav4rVwi8M67zuhxx/K97+5I4ONkaSU8/DI4SpqjfEIzl/y07Rg";
+ sout << "VUou00laGIhidjtfwENl18fyXGmjmLI/Mn+/H8gU1mW4Z0stSN/NkPYZjTx1AnvjG/LgaY1750yS";
+ sout << "4dk+ygLr07oWPhGB3BhIElS7VDxZnnPo2MFIPXTqWHqZ1/lNq8DE2EqgHgpFQGmp2MZVi060DA0Y";
+ sout << "En5g8zk1NXq0irzIv/hXYLbDEnL4ieulF+BlWN1oeERYelY8VkqgMtqGwBlwiO/qN488MVobHHAk";
+ sout << "VARDBpkSyX2bsF0KS4BCwybuQtNPVCaozYKWd8Q0RSNvsK72afBC+snd/y2KrFhcE4mE9ZhAwV7R";
+ sout << "LRR4IBmgNkDPDi7YXFEVZ4No5G5dJYL3yfsZy4b0kBEplbOoIjYxwz2dXYtX5Wc3hcKzRblKZG2i";
+ sout << "GkmTHabzN4BTwbGBxmCTbbyecAIO6MFJGlnxW6tQfdiQbcBbt1utUTpjVhZPVkGolN4VgU/qFPCj";
+ sout << "UyO9bO+RUapMvtwhI9+1KPcGiTbQsAX/V9+dSCjQIgD5sLRjfVQcKmK6/R0VSppo3ab0+XHDv55p";
+ sout << "FOPkhAKiKvI4Wl1JcKcsx8mwxCoTSchCxp5JhNn+WYBoINpTlmdRKI2hfXvfY+YXUzbATuTLKIZX";
+ sout << "IsHeRrzLNmntT4lzgHtEArSwEcYDRXLKBd+L13FZBV8iMX3ON8vUBMLU8QKoSDXatEI//h8RcI2R";
+ sout << "pOba7GU2f5TWFy5lB74tBKpcllmmid9w6jE2T3yhxU0E5GFWxWv64oSJDCfyD5GRfY7L2dOVBVwA";
+ sout << "H1DuC3NeBQfgaY+DPYFyC2gR6vEihtW5biK4HZoQkEHaBD8nREBdMlh8DcGuXwsTwEH1co2xFaNz";
+ sout << "53QpwalF61MYqPbQuFXZBvFEruliv3cYHUgIqtFo902pwFOK447zzj81l+5XzdVZHsA6dCGAjSqW";
+ sout << "J//PGJo3M48ERSqeURrEwNN6lD6nOqs9XAkQyGp0xLcv1/EVyzMoYTWazSaTkHbocIh58BOJVDya";
+ sout << "rjRytgcV9cAKzYvY9O3NBPvWMBSybUG0weBGTpWXNlydqxlAc7PBND1DfOL4XA6aDHpra2rRpJ/t";
+ sout << "xQJvaFWVNRYBOpR34GsrLpczGcf/z5hhR1gpE5y9//b26xf7V66n3kn0w2qGADZz4eL+Y7Wl1rIJ";
+ sout << "QXs4U95d6lfp26TVY7MsmQRf1GaO4keltA6LW8XkS9zXro/Ydl49AWToXe7suuJk6OGzaUqJImLB";
+ sout << "fI1w0xXDoVdNfY1SgepZyQxrW7PqtQUlLTHccsTDUJqVdu9ZUMnCVlo+6fQNz5lS7wvRbv5iqgkz";
+ sout << "DMyynFxFQvzk2L3sZUt1+xTw9r2d7urJ9VmGpj0arjR2+qb+2mfFqH0HaldqN+DGEiibZ7w9PmCT";
+ sout << "MNDZvjC0zm2N87yPuRBSbwn4JoAD979lNhFSpExOt7v2zucluinLIqwESRQjWnyun+xTZbu1MAka";
+ sout << "JAut97DUpQb5ALQ7TLqKOfk4vSSP07cVRJPSH6K3XnR+ZFX9W+7kb1mYRhJ60r3uKUYAoYJIdGqL";
+ sout << "jgbNfvqdTZqUOVq/Sfc2/B2T3kY0W6facFDev+/YnpwWe95pYSfUbewbM35nEZGJ0HVSRHnBTWIO";
+ sout << "n7C6Xeg9e29pfohDW3jy7vPL9HU7+GdvhZYMUfNeQTe0zYKuY0+/UtMIuMFzDJ1J9tBy/cLPuI4K";
+ sout << "oyPNxmTBGCcf33xcff6ZvAePZPBFgjmbV8izFio89if2qmyhGPDi6LH2NxYGpjC0f+aPj3j3H7Ua";
+ sout << "yX5PEPGDl+3l5jZjuY+sLwwqgrUV0skzdcjAyEbLPClOkj2BG5dELl4VcD8ESsOwyk7Yyb2mt44n";
+ sout << "GKgKNGm2+EwSNyvECcoEksksg7gaE6ZNXazytt+kRITYczq/v57+U7/tSjyTRL5qPLxWX5OwESUw";
+ sout << "y2zx0ulSrfH44+Xxr7ZnI82X5IgDehZJQvNPBmtPTB6JvDuUhJMd+hQF1lboLwEHAfZKpcN4v7FB";
+ sout << "GEZi7Sp/iWCZQwtALzUDY4YKGUuS7uOyjpHcQp+hxIlbhXY9byIyhvvVy361/nbVwOnEHo5BaKYE";
+ sout << "csaN1xi8WvBN108lpddsUUDRgBW7oKXoiDI06pfubDTDZHSJDABQlnor5sTsIQBMs35yYGuq0lMN";
+ sout << "lDJ5h6Nb8r6h2HhenA8tSBmMXoq3j4IAq0jUDpeR9TXX4pBbGfN1HgWpbIrAKmSh9L6Pxa/tB97o";
+ sout << "D5seIFPmORWWemSfAMoAs28YqCise3933/HPnCk83PcWH/4S7+KITJx0tIgF8ssoS36XP4J9L/1Y";
+ sout << "ym65j7ffiEEDH2rDgip/UQ7utwOHAIW6rOjkComtHS0F+m+yOcdoIwecU9J4rPGLgpq7KW7oSdD8";
+ sout << "5/ckw1VAwgUvvo4YT9forakPHIPB9BgroHxUjvbvABXJEkYGw7xUwn51NSj+7LieWTsE+IVs2kWx";
+ sout << "TAU1B6hsUPr3f4Z2g74JI1AVGc9KSuJhTtognYLM0amQd7HkR9Y4gmTRYYrSbE1yCWj/gYd5Sn+W";
+ sout << "/NUvdGmfqjcmItvBAkDr7lf79aevcKySCPfP5ZzDBfM6aJw/T6EC3KwBpY4obv/Zgx9dLKZhA9Uh";
+ sout << "jCKQEEpTnfOOGI7D92zvtySthJNjrGzN8ZVdJHyzXMSYWgHEElfM3bB3LdAe54vVG/XyHag1EMMH";
+ sout << "DnH9JOUvMeXHOLRDnkI0RNlGg21wNjl3HTxSiXkIwpANPsBpcoow32KWYqrygnB9iF30IdwfVTbz";
+ sout << "OhTNM/4qyrwjdzxSTX4IeMQrviMB+gi12mTcB4G1ggqXuz1q6uFqfxrlmMx+gDAuoEbR0vFF/bXg";
+ sout << "M+8PXQ/oKyGYtptl5gM50TsI2CxNaBAU7SUTG2zH9pDkoko6VO3mXfRblwFH4vjJ3XETsr2uAJlS";
+ sout << "7wOiJOWfMj9dKFMH6efJkuZPegH2WtkRjomYXO3l/UVkWwW2KuLJgJhAgqJcI3ODJd/kVYoR+THn";
+ sout << "IiPnJvoBXfLTKJ6r2lbjeImNg/CwzRVhDVVPJi401mloyrMU6JQ4DtwoqeiS6qAcLDlJMcu2A0bG";
+ sout << "+F1isgRc72oPo86rpVxR7oJDEVRSsQqkOhv8O/lVaziMsLCBuqXUGfuohPNE/+mCdSrZZ5LzkKSe";
+ sout << "iTlYATHr66c+jnkOETWaEPjUpwQ6ABfit3mbttnnONmDSnenmUnHUf20QyUonJGpFMsWB+DSPFs6";
+ sout << "7zI9eOAhqKQh///VoPYY37AdsdacucmhBJY6lmIHHiDyT804IvuNqWLSt5/Cko8t8thgShjeSM8J";
+ sout << "9med4U5W4XJ1fEoialm7jil6e/fr23OJJf3VJp8JEaibvAk+rbAbc5VIzwaUG1duo1O4783HLJu5";
+ sout << "4we8QCONtekxRwXi5R2gUi//qD1kJKbKxnkYOXaKNWkUbEXPOSy+evfT6jfbYcdk7VvmfA0qioZc";
+ sout << "B9nVWevTpoC/1sE/aSX7dqAWjOd5WH+KsReDpJAtB+uMhu4iyzaHV/gPuCAUdm16nVmHgcBP65Ix";
+ sout << "Rz93awPYE4aI2yoWBvJnIN/GgUJPBW5rHFcTncTV1LSUMePStawPL7BZaY/V/HRBUOnQ3V+p3xwE";
+ sout << "QqFY44ilI49X0t0OR04upM1hjEnx1lVyd/2bSw03lCDr+y6oNwi/hrk389KdFRtlxT0/rItCg+gX";
+ sout << "JVV3Q5LyY4WEE5on7coii0m/ZyXMxNT/RitPnLWc2aPtEKhbVOWpVuQGw00eYcKTs+AN1SuqcsrA";
+ sout << "7mVcPVIYhQ9/rDYzAzdG3HTcuiFrDWkGIOQp++BZEYitA7zEexC0xZPQZsqcKoH8RieRidrtPNXS";
+ sout << "ihFNGNNuxYQWVchClJMvEBrl1ankneT/fJOLTob7xAG+o/n1zdtSeTUtXPN4O4ym3GiubaONLzLL";
+ sout << "Z1TzkBa+H/t9NkI+Vp3kVRswJr3cEu6K607OPm6yGAxw/BwpQRBli6uf6SMAdNcAPMwZr9xV7Est";
+ sout << "pLz8ibkfNdfMj6fMY9WKJ5CJhajqg0WPFTlnSnaRs5ERtBLK9r8Ip/XS2VUT9/rqeFivpq8OsInl";
+ sout << "yV5iKygaW5OyZOtBbI4SrhN30LZZaoP4D4fjXqc5/EzHyzGCYCfgjKrytefR2F6CqUUdBOn0nVtH";
+ sout << "Z4xjlb6IBw80vupy3KEjpjsl8eAiYM9JsV9aw4Fd2hjCdeg6yPCNN56pm59Yamga80+31oINYjri";
+ sout << "3OcSjN6tVNwdf1Jr0s6Y1+0VgrXT+AbHuMkdPQbhTgQr/AAqHzUr+5FhrIZ7xM2vF3PrqUDakSkQ";
+ sout << "P8xIrxYawDr6fXVDWeVPOlVhUSihPBMHjc187YnXDd8Hun9Lww0wUuzOPc9P3Wb8wBTFY2HiXNL1";
+ sout << "ciWhFec1G2O1lNNBgYSeclowdwMNrC5z9lk0jhLKLrX2Ji+B5ypECjWGE7ZMNSuETIucCTh4wl/Z";
+ sout << "fLIB5I8Lx2D0asU17GjJQk1UdQa9uWdNgpG07osHpTH5FoWxZcQSBl7cvfkqXltox1ItArv9yuKo";
+ sout << "3gDp6AgZTFOqYhSdagGzYHdzB6KkEpIUJJlvMZsRzlSNIUtHJ4muh5SbP/X0AAGWnNjNZj95Yf4L";
+ sout << "IS5+ZQRnfrzIl7Nvb4KkxbQicPMrtXCcZkWJ0zN+xlNOX4Ph69XZEpmkzj5OBi7H59Kcw6ZB8yEc";
+ sout << "3SIw3oNS+6XAIMU1TvhPexpfDTyQNBbIgyycOPYaeA7eSgg6yz/4z1RfNMVZEj8PgPri6IzZc7h7";
+ sout << "AzIGqSzGJWAiWCtBFSmDQ3KbDXDMAaG6e8g+zzdm5dnujiAJ+s3PneWlapo5dIvjh4MaL3w6iy6w";
+ sout << "T62tjz17F9eEnJD8IM36+Wn13OSPk0iFfPKZfBDZPhEAGRYG7tzc/HKJ/d4m0hEg2GTY3M6pEjZj";
+ sout << "nQIcccSE/e76TSkeBNrZGp7lplsixpLjBdRFSFQ59D4juFAU/8tf+MmgtxWd2VVPU/mtkYU9QXzq";
+ sout << "JeDq/+MOHtQoMdFuxJvlEj0EE6Aa4E4Ya31LBEoSbp7ln5dDcP/R1LvqaHZr7+XU73GMzpgMec3D";
+ sout << "7UY6Rnip+AXpYOaWcfz6XX6y6lLA3kdIsNptHnc+f85kigFZ1RsCXZxugGLjxcFWXVieSKv6PwVG";
+ sout << "oQdmyR7KlT+tdjXvfGTP9AwdU3S4QGHi77l1FSZaebpelVkrMgWhcug3s1Ed0/1c55yvaZi/ymXU";
+ sout << "OYEtOmPmMAB5wcOagBZsTBP4/6w8Zrfy+27SKh0W1vD/rHQMP083Xsv7HCqWppSVZMWOGJqyUkUV";
+ sout << "rBnbjjEmLTyHXr0e5DP7TRCpx4MWMIFUI+fkRdrWSXxNJqPlST0J4BzbQl3XSppj3iURoccQBDpN";
+ sout << "VoLZA+61XgLAwY4a+1HGcvVRLzlEhHCmWEaLIeWMfIpb6L7U/LQyfG4nDiMqt5HtmMbLHhSf1Iqy";
+ sout << "swpB3hHI+UjL+bFOD5XkCRclzVPucjimVLtsH1QXWKYIcrC23dh5/tfoR4SUxzYn30LEbmcMNctf";
+ sout << "ETO3ebcGC9+hvFGH1CwVowGxDhQfxf8tjORR4Vv+L0Xfp7yr0Li/Q8wPnrIQQkbasy2O7PwodnVx";
+ sout << "W3iohD3htvNaL90vc5WOf7o0YUKcAWfM5ryT2OJoFFYCSJaI4qLPb/b9OLg2WUL0jV7lYLrK1mBE";
+ sout << "JJhgJRQr1OiI4TnIIWrUQsdvkrRVYiUWzVS6RsE9dLpw0ThoXtu3Bi1gWJInLknUCJ/yUPkNqwQS";
+ sout << "hKult4TehcOZBfHVc+BOtTdcLNEzQVWy+HPWssvhSNIYtoWp839hdGzzLoFxsIGik40aHU78d+cq";
+ sout << "ksDCHIFvnbgBvPkpLxrmvXrKofATp+ywoYFeV0g808/Pl4kX+27zXT6ggZWAO/I9anenVXNcgtvx";
+ sout << "ICdGbYeHkfXRr1/IVvTgtN2kaS4sSRbf9LPij72aJoftCIa5EknOgNSOQuqpEDidYdaXZXl+6tg9";
+ sout << "lz1qKU3i0ivIjIcaGmxLEi1pBM1LwgEYWcovusXNcv5+pm6SXUgVzQkHu0Iz5MJEdrgsSsc4NN+2";
+ sout << "swZHdmviqcDDIk7fuOSwmj4IdjAWUq5lmgYZbpZLZ0pPsmTjqX5uaBFXqmlpKVj/vEIKiFOCGtZu";
+ sout << "uEek9ZEpH8aTYqjf+tGKsNNANsDNOFVwLsdD5edStQS0c3U9f2Q1KGKXw16BM7pArxVx6KxFjI4D";
+ sout << "LQxcYx18Gm6V4sCn2J0ahj7IO389LWJQJBcfJNyNSFhfaRbVha6itGi8UaBr7Q5LvkCvV01WUcJu";
+ sout << "AsuyKjRBScPvjzypYoCxSZp3ln/sXB58RGCVZ6c7UXeZnGs2ABzXEIRYIJyrsNNVky0aGKSHFRem";
+ sout << "r2gsZ/RYPQBVw+xt8kGwAkM2km4waF7nHbkN5SYq3VIedvw1gU3UIkbpno8zJeJcwrnoVT9n686i";
+ sout << "aE/9ltlEfn/OW7XUGFK4jXB9GBJ455E/9iUejULkvx7iqfRsDnhbI7UsVDn7Q4snN82f65MGUtU0";
+ sout << "w9UaxqWKQUZqvP6rX/4u+2IfBKWAksqUG/Rl3O4krkxQRuuOS2KA+u512w/JhgdR/9O0BNG1YuBd";
+ sout << "C14QpgMmqPdGEfNrXUZN7uSWJSdBiwqwh+yPFVqoclcjencYDg3ZzKNrfUMun9eKKRBJ2UqPNmJ0";
+ sout << "zM06doKb85m39v5GBACWExd6vWsrP8JxcHeWtDkH3Bt6qhZ4YjU38qiTD1avmK12ti8n3lpzOpNB";
+ sout << "ObN1g9F9JLpGQs4RrVt5xH7Xv2LsAC3UxcKZMW8nr+QVc9BykqYIU96dVj3kffJlvfM1fTyAtTN0";
+ sout << "4016YMvWy9OzdbSeaW9c3ua91Eq0w9Ve7rR4D7rUm0DGwaPPNaCAQP41DDP23U2RkaV2yhcS2ntN";
+ sout << "95eArvlyyr0JKWkMochvrYl2iHN/4cv15vog9n/9pUP15ttJdZdqiE+qwBaGA+B78y7kCf4X8dAE";
+ sout << "ab3I0Gbc7FLCNEsJGcQlFTSmB+2ccRQtRh74pirUXd5BPNNQUEkZWXyD4tVcDCkWRqHbkxlneyqS";
+ sout << "ziQQlbRyiM3MmvSmsUOWYYlq6iKu3rzomWTRukwdFP/LbClcTaMW611t5rXM2Jrl2JXWer6HsRk4";
+ sout << "Z65qxcLwDmjokz84nvJ1zpNeDcCq19jnxNqEOaDIpDVQxtM6RY5L96Sn395ecZHJFpL/E1TEHSbk";
+ sout << "V9ZnTcZwiga4d4FaVDa/L26ckv1+93o7KzMuxgopuJqf9GJ2c9inY+Y0m+71i6MNharLdfuGkLwr";
+ sout << "/iEeyeu8K8QcwNyDDd8QfStXRGgIGhNR573Q6ARI6a9Uft8y0hUYrJSTcaryLZRnV0xYAGAq11p0";
+ sout << "YOLr8U2iOncUfz0+Cfc3cu7nytOEpr+jDM70ojkQxjU7DmdCCYqdgik4v093Rv9hTdPEBFtzNNqh";
+ sout << "hWlaCaul6E6Pe+PdjMStduOGI8+eNFpOJ7/K7IlXuQLMLcggaXELqeqzUTaXGnMigQeiNsXUhXJr";
+ sout << "7g9PJYDPXXLNYIR1TuLPXCc6L5KuF+fjWQ9CwUxT6F0xCBMVMUVHQSkNoqngCeaHsQbIjpkFBbxH";
+ sout << "BXjl24RA5TKttRF9mUgNUQK4VV9LU93FJFqPAegUWA8A5AbaRkJbylwWT26qCwRcNcMm+wVFjxdI";
+ sout << "BYvJqx7TCrvo5ytBIlgRVx1HLjUcyITmeQ9CCl1j/Tfb4RwslDCgCeW0LHocZde6lCdknTwlOre8";
+ sout << "FHdSxxvQAImwiZKBSxPYqsRLXEGMtcFxkpAjbUZjITfHP99qD+h18ywpMp1xQz1FAa6QBaDjFQUy";
+ sout << "q6DWqkCYI2cOpovwq+eU9y0HKT/CxiAclgINEMJ26zRPgJDBK2vYmft5gfB0CHj8zUzBuYCK8n+u";
+ sout << "6RiFrok3YKqfszQbsWj/M1nBKReS75d282S02qdrnm+OwlbRFmkUX5VNUsI48fffdLUzQAVOD42L";
+ sout << "O04+nlesMESB1w0GezPWmsG+eNiUghaLxLnptRrcATVWseeWuqcdcDG9ct3BkyPEjaNSZGbvqN0q";
+ sout << "8S0H1IjsnmKrctlK+1ELXGBaND79Uq4HC8NtNseb4gEQFBTj2rI85gXxRTqPwtYvB26mpBPrWmgz";
+ sout << "JfLIOawOFX/GEe3W3NelU+CoBpxvGv2wmgqW2quks6TmilBZaQX1ewR61jVhKaI5oG96e4uUksjW";
+ sout << "I/cXdAP2GFl20jLWia2m76GGbTimCyffGDV9v3uzu+lpoZ87JgudGUKn/jdJK1uqTad/YQj5t2N2";
+ sout << "Y0PnVN59l6Do19E0YwY5vx5vDXs6q0SiWoy/zuY2hqcZ7paY1iJhIaanMAJjFK+3FelY/IH0Xo33";
+ sout << "Uhv+4k5pPRuL3AU0nA+rg4JCT9YS3OwOpcwv4kNyQL1xRg2DnlcryXMHHCZPYeVApEruSVmO8nTE";
+ sout << "g9cbFMz5nmDHKWdF+KH1Zs9jMTI7tfaOGS5qwX+gBwcMGUPeHf8OMiJ6y0zXo26vHzLJ61wPWLrP";
+ sout << "juyapqVcV/YgQl9Ok54s8YGCv9ZB4rxnxjiABMls7ZZquK5kNJ2ShQWW5F9ibtec8EcjMuK4i20B";
+ sout << "tTgq8ymele0eOIWVE5AZZ3yD64qqbY7kuCLJG1LxvQRe2zr3FYVYqWhfyRvywRIRFcequso1Bc9b";
+ sout << "HKoPDpwJDaTrQn7uhEsb02WZKg7Q4XKZVRXxDGHHcBqwa4fnNM3IPScvskFpmyrZAvR1QrjNLTii";
+ sout << "NdjPVn0Klyr0sQlppC50bu1eCy/Wt14QKDUA8OFYOmA5mEjVGtFktwxy9wLssgpAD8LoyPQSxuzD";
+ sout << "cgiB2HqNtOwNGPvlE4/ZDfT/N/j8s3lk0q0cZmrAUCBXsBDiAHNbkm3WeDEBDY9+Un6fFF1U5chM";
+ sout << "tKKpyvjeZd4bjQXK6zzZNavXSJzOvqVb3OKlH6OTvP11rgv3pMdHYo9T28C0onwMHN53QGPsWbzO";
+ sout << "57SmomSDGYs+ERJRpGEVUXgj3D+Q4O4v/fR+XMAtiSOVmz1c3c2y7Ys9Pq9pFX3UF4q6DLIWfBmE";
+ sout << "6omLA6O/y/Y6p++EZleosnni/RH3hMH8TvZRYFW4EojNCm7Ss+eyuktlVXQhcPUOuxQ/lK1SUx4k";
+ sout << "vCBOk1YCMcl67xok/WdgM/lWJvovLTLqylpQlhsHIM3I8ccuOcJ++lhPABcaXmInXnPMEV9K20kk";
+ sout << "d08Q72uJPoU4rjT6SMbBbq8L9UA9Ba7U2cOdK5dr3rUZwQyFZBswroQ7A9cZnuCD4ugJ4l8AnUQo";
+ sout << "A7ghXFgGkmzItOKtHoFYz9HmxZa+3qX23EQk3jVdml/8fFh6VjpTK73RPKwrbZyclJs0pyN09eDb";
+ sout << "RjZ153ucSgBH0jflZbSIoQhPdmmi+xQBqWV+YjqVYzyYJMJypf5ZrLCb682KW0KF2dpleDJUoX+Q";
+ sout << "4pQNhHSZWXtdJTUcRdmM8Wl3AjY0QNsbjyqe8rYj/o7FJaX/Y8b85y1fGLF3qqPJDxZQgR+jKfTg";
+ sout << "0vlbJdh3EbV5L//jZX7EcOTDU+dvkXOSyx8zQeS+5xwVXWPmTmaTNIriV/6EvNJBPQ0vmYCjsUoU";
+ sout << "6hD4EOcBuuOXADFmEcgRZl9z46qgDwqRasacBLwpaICbLCnpc8Q7QrBhpbmeHsWmqYtK6SzfiQ4j";
+ sout << "e19bNsz4SP4zzFzdEpLl/J7PeNWURM4SUBwZcNOZbDbD8as2KcD78JSCs0sG9zWL12JPjZ9lJ7AY";
+ sout << "lN9vqJl+2N6H1VGiJR3eO8Zrb/lYX2LkSz87AzRggZxDXdv1DjnPzGj740McdjWa96DZeexzAKtj";
+ sout << "DVHwn2PFsYDKzvmwr71zUNYLxwcK2U3ayJf8nuuP5nPkRDwl3b8ttN9QHNo2JWabnVJKWid0Dmiu";
+ sout << "zRus2qOuzcXLkGbgE5DdmONYcg7qznU3ostY+QyHO4/UZbDpqOPG+uXuk3SVhp6yBEmO2yE3T/We";
+ sout << "WCWw3dfW4DlOTxb4m+nf29ST7WIBoNR8omWSyxyZZodAXRy6NfOnpkYrgFAorXCprqzNiRLSbi8e";
+ sout << "hZrbJNNoqEUTrw1R/hXMJHftJH8GotFVuFuXTBV0wcm9eM0UeN1cWvuT/0dc7ORsLdbhNW7X9Uke";
+ sout << "tYEHwobKiM0mfa1dCdTWFvee2XkrmYsNjHfMNoQRUN/w+1VHqnFy1Q/qc1MoU/C6L1rPUjth3gNC";
+ sout << "oNr7jNNWH/tXMAEGqsHPP3+Hw9pqk4XE/B3QSbeQYrZqcZojBhWcwLJQIbSyJmya3w+QYKqie3k4";
+ sout << "/LNyngQb4on/sr1vLNCNc4c8yJAxV/nMoNx5cVDBVg/HNAh+qwhQ9tVhi7xRUpIrAPICcaWaX+RH";
+ sout << "Wl9jEeq6PdI1bUaTdBvF7DgSvAriB22oHxG6Jy8X5WIycn4FHMF0/ZlCfwccg8HcjvZzDlvFHfbs";
+ sout << "lepXFAJ21XIOWHwDzG19VnizLorKXM/FmOnFwClhG/+yTREVbWCjaTkDsOlWL6JGOehVHckhxWRM";
+ sout << "03eFtiNGh2k/oqsqHkDxYtt0rMmGly3vLlLN82Eiijq2iNo19EF+euIVAl2h1iEmOGXQf6lkXCCZ";
+ sout << "yscrDfP3XEbe0grXP0+/ETyFrAAl3/zoENKR5MUYxVziTQ2cy733o/8aX8J/X/hFTm3ZFVKKmZRp";
+ sout << "vS7HzVzjj8i9zZNshaRzWt1jYnQxKtJ9w9BEn0VLGcb2spLOKyctfmsughG6DDA7wjjkmNDNPFMH";
+ sout << "mefa7RvWjXog4gPP6SiKITag5LwuBDotZn966sOaBOWK85QelE/NBsk4hCsf5LNPkMDFJYwB2ZUK";
+ sout << "g/+WbwWHIGFI2O/sNsM8W4xDyia/k6ZFpaki6uOZ48uR8uLOu9mMR7NSJ19gRoe/aHeOAi40FC8v";
+ sout << "K6Bs4rbxF82hbJmSfp64b0d8pNPevQ4X0UEbQI8d9o8RjmDqgVwV49InO/hobuZNWyY5sIz+8b+0";
+ sout << "swgWi4uhT8bsQvMswskKsWmV9bxrrQ4EJFOGtzCIeC1X0Kzm25gSf7biU85dr8/3dXDQcEqdL3x/";
+ sout << "BVrjeWsXd3ko8dQTJzD2kNLqi2yNmymIarjj5qzTxlnAZYbJFtFxtPPO4bfPDeRRQ+D5PhiZVZn6";
+ sout << "a9WUtuDFmwKmoZKZIaKQfQtwkMQq6F03sU/EsO5UuglsfN9gZmLVZrNR89YPC10gM3cSXmABMYcx";
+ sout << "OHnaivD81i4KmkX23r4rltxlqsgzdUKiGvEpPhfRwD1bKlKb+dTFgA6x5cYaOQ+2/KqeGn0JvRHB";
+ sout << "HWmQQG0aJvlLelva7sG2mahqaTpsRGunwr6EkeTDwSzY711r2cNcLBRq5VGIg9ODw/Pn2eMN1Sza";
+ sout << "0xBt8eEzdGYywXR7zpcalcJfOOKUdpm5D/Lr0Q2y9qGhqr0yambYW4ltxSreDuBWLTp9lxjbPmeg";
+ sout << "tpdUAqTqOEIfskeG6FVSzfTSzUN+q9BjZw4RE7aWCtqKm1M7hF7o9FRLMhqws47tamk3AZZTIC3x";
+ sout << "sGaCMG1/h/gu5bH/VQUZpzIj7KWgf7PnJEL/WhAsjjgx7XRcEW4OE02pwmhe5C0WQehHYOdTSTQ3";
+ sout << "Y6djpbeCYLGJNSW42W9N5Qvrp3GTYyqDlRcFgsZFDK0DmBonUAsSh6Ytr8pxPSNWAejairTkKoi+";
+ sout << "gon4K6TAdRrK8VQSSf+eWE9oTlcteftAn4iWQzY99aisuJP0MzOJr6gZgp6s2GlsaiAA0KObTlwm";
+ sout << "/SptTPhSn9K+d0Os8QMMXYHlhF6waJ4i2NrCiMulOp0vHYPOKmfyCI0+hQt9R3ArNvY1pqhwBHYo";
+ sout << "+PMNqnEJOH28aU5s9HyHPOkQOOSvMTYnUo0kOns0sa2dSxuaZOq9kb7aqV4qG6+ZXzFL9FbAhth+";
+ sout << "4FWqkdDxPUbYUh9pKNQucyvCOtJdlrQbrJgMOTHShadJl9g+eF5boVBDAMZI5WP08py+U2sw19IE";
+ sout << "/uju0H1VxjRPA4xips+lxnZdrgWQu0zG2nHeLObhdS/gbO3R25LWZxUILdNWpVbxuQrE5dRWlIaj";
+ sout << "aB5qbQO7zrwCwjB/ZjDc2dNEH/4lvPv8vnEehJBa/sIoieseVzcQgFRLK0n20sdaXil9vJ3B5qyA";
+ sout << "UUJBjOT9I8dtvWAP41Y53UOjkkuqWOoAbfd0sp1wxpmjOgzm3BhPoGmLhJRlJ58LkVYTF2Ix/Hwt";
+ sout << "ZvLSvOwNE8HdPFOU8BALTzAPtVpfQCKKDD0iOdAzDJC8NX77Ar/8UGHkgcJP+9LiycmuPnxOzhGa";
+ sout << "tsZH//NolNfZ3+4HA2wir+hHAjw0ia6d7OSIJyLV03kjAsL6I/CajovAsT8uAkegs8ydmvLryqGB";
+ sout << "+98uO61GMTuZNyrv6TIG0oRDpNiKywoYmXoqdum/Dj2UtsE0A+XAb9KK8Nl6RIM27j6vkFWjgC5p";
+ sout << "7Y9u+TIHFd18fxQ20A8RuTKooQaC12h11+nFUR2kAUqkJq0e9LWklG9jiq2kIsfHr6jORWi5b0a5";
+ sout << "Tdu2IDMtMCS/58L3+RggG0BAs+hc+uX0txQJUYEUCmJDlzReIAwTWj/5j6YRFpq534fp38muNl0k";
+ sout << "OV5OE0UU/K+TLtfXsUQ1QCbKOvEW7RcDICz+5sC2/FIGC0B+5KY8GUeQILlr33+LIxJTKAQwdp3y";
+ sout << "G5BZk4EXDjoLiis8OAO6TzgfktYXL/KwAicyHdRzSUAuzQ921QA8TeKMAPd+4UrzhUp3S/cgGxNF";
+ sout << "sF9DVFCaYR7t8Wg/ecyZ/Kn71D8ClkIigOWuWyN2XIuHgLVb81DF3gDdcoiMQ9qroTmrCbJCukft";
+ sout << "4zE53jcVbaVnhr7hK56cEbe3F1eFW6pEUFNLOxFrFOwyiwngw18cy/PUThn6kpenQoYS5f4FXPul";
+ sout << "qlFnVMTXc/j5mWhRSNXddfti53WsTB9eh6UI/MSWNE9z8kjdMUWn4z2perojF4VUavHqea9U+6ms";
+ sout << "samqyF9mL9CxEhDNL0+Uof3C+yPo4KSqMC3or2rqSe8IyrAbemjcfKGqMjsxBpSGH0YcRSNoHGw9";
+ sout << "wC6wJ37LuOncQ3VLJM9f8zgTpgln3y5NnKz9nMrgXE017K+nW+I/U8o0XdPAJMgXlY66pfkJsShw";
+ sout << "JmidbowoBQGAXWO35cYQ92Avtitrs8kvq0GIuUGdIsej9PogzuW1ZLTSNJkee/S9t39cXyQ9YpzC";
+ sout << "KSj+H4uhCrGCQJ6Wts6l7kEhHG24Eu/mwSEw0fTTNqL5maeimxJsOdmR/IEVC/pps38ZEUuTJXuq";
+ sout << "J2uTsoZmiVs3ImbIkjHnN7NF0I4ryYf11Qd4ULqa3eIjraoiDn1KOZIJK41CkKOYAYIts6aFS2iw";
+ sout << "+34SlRxtIsIrn29zkOG3L0yc0o+fRT5gXHXPKv3eVmNca7kPGaqSrBzv+c1Fz6ShHK6tBjiPVtwp";
+ sout << "S7k2/05j6t0tMny1O3QzQZDTdaAi0Xx5yg+1oyK3M8XrgLsEU5aN7XNRorjMSHv4XT+xhe7F2B7V";
+ sout << "//JUutejh1StxytV/AswArYDSz63Q/7Vs1Eel0/Fkc/Rz9gYPKDqYo0UAFi/baWoUrSYZjYlIvBb";
+ sout << "fQn7nAgDTWd8Ry8dx/j8kUSR/CBWx+xz+tNPRl0TjIOWZJtT/MOqLZyODDBisgCpV2en/f9VUr0o";
+ sout << "V68sh5TYTczlcSA/lngZLj3uNH6n5Vh7u1kjINhOiYzYgaUzbruQJycDF096vEGTZ8JWnz8PNzT+";
+ sout << "gkDfWt8JtyCsQg3UHfO58YSspPbapErvM6zc+nOzOde7dLkjj10pxRtQMGc+fES5ER9QLpNF8wIR";
+ sout << "zGF4ENLZ63jytxO/Y3klAguefc14abKtLaOi/SfWseX0HoNpW+cA80GVq+oYU7krKhh4n7mg7NHX";
+ sout << "PBsUV4gmzhj0DkVAgmaGd9qT9KlZPWHU4vf1pGE5P8mIoDIwWmpFsN2sz/mEcugGWAT2JhT6CP/J";
+ sout << "sPOCAB66Ln2SoLzSJvBQwmtzdBZm0a0NF6TOloCdC18cJZfqYFuEJm2gUozzFMPjLKonOR3Zhr2w";
+ sout << "CpBUdXL90wQKDIoTIjnOIorXzjFZv+O9O7y4OBDpjDsFrcdQG3AZPtHcQU49W42aH+6XcOUqlxRc";
+ sout << "DYsJdg75GZliymNIWgfGGuQxJ5kysL8iPjmCxVGNnXM5FvYW/JYzcXy4A2DXqO7eyVBw6KXKHxq0";
+ sout << "SEF7IIjjj0TgUp7tLKY7/L0aIVuvlNblobJNZDAHYheG6OG4dAAnwaggRO/dneeiQ6hDtFbenFmQ";
+ sout << "SWdiYK09ZQIpXbdlXUlPiJIG+mtHudHcbaaZwVlYS5iG3NIOx/IYdKzYnStqv+5NpkBEPySa0sYy";
+ sout << "npsOyO7Hz/09SMXbSOo+RBkLYrYXiLtH9qiSj+Lv2N2ueJw1fN2PeCd3I05jcLJwqCEsZ0ch7FfZ";
+ sout << "1MaXVqdKYBD3BxBs+SY3qtHgffqUrGXDIbFARc807a65++J1NLuoD5ZXI4empkMcD1O2j46tQqdt";
+ sout << "MtApg+jZVUSdVCJl9YArIMwOuUdgH0oPcOItEBlmYRmDDcWX0E/fuPfa6p7u4X6zwQNY5PdI8nG9";
+ sout << "+Qp+PhYXKt4fB1xEivvnHXy1x3xMmDwDH7qyZEmvpiNPxZSO5ZGUbDdzUH5AwTCxVRA2/Yd9B0GZ";
+ sout << "+WkLsX0Ds1K55I2cTSN5Tcb2IgT2VQnFb7ceok9bY3ABeIipxABdKzhLiDeYDfeI3is/pftpDxGC";
+ sout << "vkyX1ZUnB7cpC9lwlwwbRdXhQ8WidAEFZNmt77xdTVeyU4cVby6WCBelfRiKd0+wXtQsSp8s7iRY";
+ sout << "rNHSq8FQwmt5LPdhIsvLlD6254b5AV95lYKnriWhCJ3frsNaqWc001OhiyRT2FTtYIb45pxEtFw+";
+ sout << "fY3kLR/xs0V7KWpXJCyt7uI51vRNCIlI2PbfEe6epAMg7oLrbXOtjJ+z6f6C5DPlVUf5pxNNg7PT";
+ sout << "EEGuVTVEkWOD/hKQB6iJfhLqIbgfOnglwI+pR8JzGJSJg+yuK82g9rRvzhPvIQm8Y/xmsMoQTvg4";
+ sout << "zBpA2kgegUxYS+0VXG6ZEbO9CjSdmtUvcIpa1yw6emAmjR1TgiDqtTtQVWksXhRNiQ2XBal/YlBE";
+ sout << "+45tqfGRjSiInj056XBdqsQqcQVuXcxXMT90/uslRAEV7oTmVX/XwSahsuP1VIevdVI+rC4liNj8";
+ sout << "yWmDEdVfQOM0u0MgevyXmDj8BMQHznlHulIVNQRhEUmX13ibjVIDLBHDyozBbvPneuY7hBejxnV+";
+ sout << "0kaG1VGjxSaTpH+jOhXwK2UvlXPTN+FM4UPLS93zEzEiJ4EWDUR1wL9VUXEfx7IveXQSYixy/rMy";
+ sout << "zpSX+UpoqFyiiKaG8DGxg2hu4CPjNXfKVI7/IZAFIT6KHcxPrrO/3BZ2HTqwpPUpD7SCiVQkqzwd";
+ sout << "gA5zjgo5/vrSFtah1TLBrd3PiN2RBsGn4EzHBdR2IIrDQcE3uB3nEQRyWPhRcO2pkVry8RtcGh4b";
+ sout << "hkH0XKI26GAIC3cEuRzdwWU70JlrbuN4beXlL3z3xqxz6J88vpg599N6Offn7V0Ki/z0w+kIgHJN";
+ sout << "H2CSc9uBNeKfQMGNdQzutSckMpg/9neWQVhgnegiV9hJ+L2lQODdnUCJp6nBYPHqkgtbOAEFRL+K";
+ sout << "tkKQcyc7YkJs20WUsItmeW36T26+0T3uyN72dbmGqBaZIonm32MKPgh+sVTbUspUiLdPGSFz6zuF";
+ sout << "89/dwHkVzv/dndOHh5yb+itB1LaCjsDRmdHw9uBYto6eJDdPa+XYOBe2w2qLyOHph/3NaHrQI8ia";
+ sout << "gfvPuVAZ+V4XXQgXAaWjPF0XtyBmRuEm+C0J9nWznqrKjja+HXr+hnmlEuJrHGFIJeTcpNrrbcRi";
+ sout << "RBdh1oiBTCG1tipAQeDf4fqRy/zrvugJg5Rx6/4fA39XHXW+IjqLMuOpA1qCpOWf/4llleXazoM1";
+ sout << "J392TqpJVxxg43Jfxehz+XXl//IZPMCoiNeuztiSaWuGsw8ZjBAc8mMnzehBVPv28axiOixuiofH";
+ sout << "WF61WGuAUXWOZ5Jqsg+7dMuHhTZznTPwGVj6hzw0cyL4KEoj02zic8JaSnnprKzhM442UFIADacs";
+ sout << "m6PchnrjUqfG3D9DwbGIwwmVFXePtscyV6vb19r5nITmXaQHm2zJdkpVmwe8rB0Vusg+d3Crpxe/";
+ sout << "kOWVPy11Gv/rkYx/1BsKd7SQJfTdsZFABUTjLIaJb9pmry7FnI+pNxWNfIyEu4wHTyNSs8ZWPOlM";
+ sout << "0oXY2CqzRqdnzt0rOdEp/ncMhjN0p8q93nMhfwSlH3Fvse9KZ5Vrk4Np8HDTOEjjgCQO3vpDveL4";
+ sout << "S53oq7IfTsp6lKSyzR3h9iYqa+fnr2G5wDoIlLBPIRXOx9+l9OH64rrNITKFntHEkFFJi8rnkr54";
+ sout << "0892mQk/ia9ELfn7V7pxKnEgXnSqDwldLYAWSD6hhGmvUYEvEKbS+4/D/zsfnauYNFkc+cdtPuFx";
+ sout << "uo6oFl+6Th5VYBwrmk7xq7j0vu9xxY0NfYXkRcd2gp7tuxAtMn0gJVbfu+IfKW7cneAT1fszraXX";
+ sout << "+qVo8SX9+d7phb7a4srtpA6513DcZSBSObMJ9RT+HZeYDn+3l986vZGAOJoQMGIohkU79sFTdqfR";
+ sout << "i0lfN9X/JohokHdD9eroMq/F4vMWEDj4ax+OR1wvJ1T9syDfS7YT9IXXt1eXYL78aHwoCJ352IQf";
+ sout << "Rv3phfVYK45OdBn0BqAQiUhWdQxPcwuvt9O4cs1JdChTjSSGz7bKtOPTBxst3pj0butIF1ORrs/h";
+ sout << "n9iLG4GVA///xeK2m1OFKDuk9AQK8hX+Nm6lEu5zKZO6NBJjSzUX928LJ+Wx93h6F/I2sy6hmBNS";
+ sout << "0bhas2TNmGX2Eiu/EslEzPudjbp8uBKLCyHvM5P0kiOdzOqfDD8L0qyuOLNumauOEVB4EuSoROZg";
+ sout << "uGPoG2xyiVGdkKIXIfvb9YjLpU3mebY73zeInPcaQCVY0PcbvEYIs/wMXch9CU8FP4wyJRYsp5rP";
+ sout << "XELbGL9uDF4FnhmFDjVOTX5j/zV/jVNNWOv98Ue2JkVZ1o1gMweglPMk5Gnxn+twxl/OY785uZkB";
+ sout << "CJdsVSNbXanSF4dFQWhkEsmjXyA0PVsw7uD0tfIaHB/BekjR2hWcwbUmsAUzIm4EVvepSk3Ec3rc";
+ sout << "v0ZNE6FO3KD0kNueAt8RvhpxbO2AOL5pm3vCt3q7lbz86dalSf4mFuv23Di0Dyl8iZqNAHzsrToW";
+ sout << "50Wua4dgrUEhqSgn8ohZq5Y5tI7m3uiP2LTLAfVDag0TdA7IAREMvTT9F3N2f/2Ff3j3b9HW3MEI";
+ sout << "2PU2HZOvUzJyyOfuc0TnK2DXpyXwGWfgoeAQ1X8c4CZhEfecXs/FXM+BfUyJEBAxnAkHxSV2oKW3";
+ sout << "PPismlBkfTS/XPxnChKJu57uy88MTaL9cEfVexPwEiCWC8VOBAd2nQUYouKadJLVx2UliM/dKQu2";
+ sout << "oYzUBJKlphHkLMGgUUuhXmGFa0KFkpisUVwFtyX1ey5myWvdyjm3xGkJJW7+6KzKIuFZ8oF6QEBS";
+ sout << "thWf7/v+I+TMG2FI3cl96g0zhRnguGVWeQ1NCzPRsDmK0fDZcgnCGyQGtQkdVgRBysC5hLqgnq0W";
+ sout << "+2S+oRxhL7AY+DaVGOHfYnI+7jyf+NVcHhBfqw9qmVH/rZk6cuG23Lj3nS4IzdR96a70EWufMtHo";
+ sout << "Mu1zlfyTEcZ0HgZ8vZdaBzCDRT7nkqSou4/uomjQzFUBRkZYD+etcrNb70zi1Pa+AR9kYXVJIO04";
+ sout << "TTRLgikAT4Ja36SlxjBbWDS0swHyWy3ecHhdKwVvZcklA8yxzLvdBbiANv51v/o4O8vP2AAqAQV8";
+ sout << "AcOjIeo46sslQzYC0aYgaigzGmsx6nYGgQ2zFH8bpWR0KNnP394GdKHx0AuIjmC1/FKtN0KVVQeE";
+ sout << "xcU6gbb8Lbe9Mw/WZYDiGtSVXrNBlD59+0dNG/SAyDZj1LqU1Fp0thu+CzID2LVAJtA7z4Wlx8f3";
+ sout << "t3iMWYIq9+3fDuCta6hORK3L/FUG89sB3Rqbi6OaYW7nu60+tm+qRP7ECwoNcyVFtGLOOMDonoyM";
+ sout << "uHGfkqQcEeQaAfPeWl7sEbnzmBbdqZ+xUj63WTpGQ42Ceqaa+LeoUJT3zdt8tDK4Hc9j39rZh0zX";
+ sout << "HtwdVVVlozsSQO6HvvElDrBwO8EY/YPczI4dUGCc28KcHNgEQQ+M7UenP4mbqoTIbo4Q0i7FCvEu";
+ sout << "9/TrBv0JQ88ae2xds6Lk5xsQomtH4ITr4VYliRfDO9bD60zxQBhhiHo5joRQ7sW/t3ms1tDzsco0";
+ sout << "YG4xVvyLzUHdDh4FBNVXJ5ZHnHnTBnppmNZ3M3ucVQpk3mEXauKEdPt7AWwdX6rCqNR81/ZasFJh";
+ sout << "qTU34x7ZH4Lg3Ut+95CHI1qubhh5W9feqPOLbvUaJzuqxlcsTZBhY43N1PSCerMFFez/b1/Rtw0m";
+ sout << "OAmNLDlwUfsqtrs4q3OnKlfXyHlwuIDe3fClgcE4r+9i6QSjVzPDaPfRwE+eTFRKPm9uhlLQYAnP";
+ sout << "YvVqOtnUCbLQod0SIakYVb24cwnvuHfGwx1tikIizumhn07N/8LXxtVGth+prxi0Lu8b3hMI52e0";
+ sout << "LXdzbJfXUa44mXoyEL0ARMyhaombmUt6BX/HaBLx19n6j+munyoheRlpqIdIMWllIAaEMbxCoyXO";
+ sout << "ExRZf8NPFANOB0hvtXhN5U17hleB7U8ri/6Uf/j30M+jglbucV015gMXP3h+E8qg8aOy/RoxhkWi";
+ sout << "lJBmD8OIGytd0GcoSmpzYzpoOQ+DVNMvVlBJxzG2sQy87PtAiJW0u22n5LGUVEQLxAJrRDdEiKcO";
+ sout << "zp/wdKOlFD102SjWbLp1mQIrqkjVgGRs+M8kIzTSyUw3HQPOvucyewaf30EkILRHwNOpUq18WxRt";
+ sout << "dE3BXUtm6v0Xc1rn71+Kbka3pXmU3qoLLQaCLLFQXcDBnKGt0olk84XYOb0/cTlnxBJiips2Q0z1";
+ sout << "CD8+ny5LOekJtzcze0PcYimx7Rl1EilhFrq9fF2N1cdyF2skG9UfoPVor/6U4dDGzN1qTzN/uflE";
+ sout << "harEkP1c3/Dk4BqYx39Vym1nvTSs9dYVYli+6K9EPaxjUu6ng+3ad48uZWE0xA5i8DfEYU+VBYSg";
+ sout << "ssaHO3OrblCccri8HgN1U17xnDR5C/zBxpcgoQ5BZfs6GYsjsDmIHz4Kq7dCvZ5xXStqZ9FEPnNf";
+ sout << "oeAMmWp6H15Ci6pqj12u9lZRG4rj1kQvaBvnpKgWOiscli6Q6mCTEsgtIYvmQhds27mbO6Y+uwDd";
+ sout << "ZP01VsjBxmPZaLxpUabmGIZiwAsAZKLN9qQgRt7EdMIDsyLehadWOzcjiYTHrWTrQQ8R7zN3e/eT";
+ sout << "NYJ1wsWhxMslZ+Q96vKEmejIyWAvcVmxkCuiifp6VfXCF8BEYRsf0lY+Y1VFe7yW1BaOfy+8Q4z6";
+ sout << "jtBh12vD54vp7dOo9xGlFYZD/w3KQKdKyVnnqm/tbT/pCrc/X6vbrgsf7sIbV7nMUJ12FRKYYX8J";
+ sout << "p0C9T9nkvrPZkxNhUKIMkra9+NT4tMsBPw4MyQQsMmyUjLjZMTSRiIyupPjo8U01tXEXElUdaNxj";
+ sout << "feSkHmcQS1PBXV8oQfQnwEIXLGUX84fe6q6i46YrhrEhjf0crtCdpYrDOuhJPAKnu46i0sHS+xRG";
+ sout << "B6/685zGr8K+1ET/1RqwR7zo8t0vJf4zqhAY8OpWgvY5ASN2f0gK2dzBMfuqrai05tWaJaUhYRqL";
+ sout << "7FufcaIyxbmDc1laARn/mvN6pYA8C0VxexvsQqF0OlOhpLOsY31GPy722kMJeKxFLuFYYp9DeHQS";
+ sout << "Br1KtYMU2S/RnQ/sIbzZ15aSbYstM0K8J+9hxHCi4aq5K7KGzzNCSsuMTAz/LEDEV5UKCFd7se7E";
+ sout << "rs4gwaKuHqH5b9n2cAb1wMpM+VMttHtz5oQplOyrZaz1r5sUV2KGkuU4oWHiHKSgu/tvrBvU9LqT";
+ sout << "1TB0RV405rPe3HjtIAzSPD8/VUHZy8pNBjGi/vqZR/ybRNP8GQFWZZ+vt6l7YMX83DqWbTbfk4oR";
+ sout << "C06g1YnZcYt9UsV0GldCGJQZ5oJ/UHLM4bfRzSaKiRAWTwOuDv5hHL+AKJcYOvDub5pCp+EsZh7h";
+ sout << "kMS9v85LQNmI5vUHbhrS8XE/2Zuki86Oms0AtR3aVWWGWM28oQopU35Ayu0rQE4oexb5NDjsKcP5";
+ sout << "A38n8gS7O4cbHdQ3XAIQaPoUx6TVK4xL5Pc5bUJ5aRsyivvPGd6meP4AGW8YBKObMMTL6TSRBD3R";
+ sout << "e0B/yqLXd30yRNiazNfw/aajOIAeHfbNuFvNLeqe7IKdpRrDyLg29Fh5n3Pg4Y5+vdaJigi2Iy39";
+ sout << "FKvhs6zp5SE0tJubMJuFDUfgYqmW44XVNv8z+ZUIJCYQScuSkCfMmexS3LczxvOTqTm4yShuNr5A";
+ sout << "+o4wXi9n6jEa0pcjTVlU45mbOXiX5ELwfbSpYDorlP2BJTZqMAvnfMdxdX7cs123ysQUHxMqKK/Y";
+ sout << "fSUaJORhU5zWNpmiSFeTCsH0uVuWMj11H2ttk1+gotSz7I/RCvotGNzuS7SqqOZzb3h6LpjHtplh";
+ sout << "04W+r6/FtyAq5z9gOphIZQUhoyO2tvbY7aCciCBBX1lugJWKgCQ7Ui60G80225lAU4b9QKOGzwN6";
+ sout << "hqKtxP6c9Dh+rxiERh4NoE6rctk+CfBoCAxDOEFVM/bcWlEbed+KvM49mqrDF9epNJP81oIZZb4i";
+ sout << "afVphDQaM8LuV9LNhPSMGLOoPeYNPAxyCHnlNyAsOrZK7otQ4lFc4dLM66NFJBbSA5iBTYaW6MqM";
+ sout << "qGCTCWxu5MzZbgNgwMv7QrTgPy4LBH1qwST759B3Um80nBimBa+8HbddXPbsqC/KzjW2yjodKmqx";
+ sout << "2t6vnINn3v5Q9dFFKv4oaetAj2dqSfbfCAvcMvS/Tl0Ont6PZt8iFpH2NmwcEbrwSY/2cKKqZIh0";
+ sout << "sVj8a5fxwXHAkMcov5J692QQf7P+OUzmkI/P/+TlBoWYW5Zor+79wh05NnrDHqdVMY24aSRKrdIE";
+ sout << "qOeME0ZEnPkMyOHJRUzHoPPw5+r7Oowvbske31mERKc3Es0cCjhJ3bab+OSaoXt2xfzgLd1e2eXm";
+ sout << "DtOlfyyLT4vK55WTN+IP8xpw+e+D5N4bdGDn4UC13jvhrviBYxasYYU/EtFWpFGonD17FPZ22PAX";
+ sout << "pGNpKbM9Eet+Cabzi3hJjAXwKYCmTXHk8+aHceC6DtWExR06Dj1QKoxSJg8vcbuik11BP16e8fEB";
+ sout << "UiPiV0hzsLwtHdbg7uZQgvd/e8qoyivF/NqZ9e6A0s4r0NoRO+6s8OPAL1QTpavvLtDhF37g+7HH";
+ sout << "vjGkDbTbzU2TOdzl0OvxmSezNfASeloEOSXqIkjMStawtwC+DgLRXzMhACjjVuc+MY0D+/zGt7lp";
+ sout << "hQlKeLuLWLDEUYNn3Dvb7G8/6beGyyjKHlHSOr9WK90JsS6OQT+xkEZnOibrIzgZBwmws4SPrhuf";
+ sout << "wwwQCiHH0xjOXF7iT/vi5qEKbJePKxh4m6kbcxxaPmP56M7TfmPSYJ62RTfZbnq/FIPurztCgSHG";
+ sout << "jmTJvBi4m1qzKgiwHVBoWjHK3wnjQxTuerxN6CMuTA8mv+ayvrabdbRwpVRJgjSXbb/Vzgj2e0PX";
+ sout << "IeHUocfZtrz9J0D9f5D5nhIlqv/aSodiEn91GhsasFWwDWwSdvKwAnTtCaC8EF/+rLx/qzCneheg";
+ sout << "4gGRJ/VgaQ4ib9Dx0+0SnOEWoE0GJAceSDJImKJ7yi3Oo7Dxy3J/tsJFST+JKNuakqFkNCnTetmm";
+ sout << "FkVNEXbo8tKzbeuXlextGxSr3ZCyOYfxXsfNQUJawLDxpL6fVpaDbf6Ot7ip2nejQHL5wt1AqQIA";
+ sout << "GNwV3Gszf6q/bx3vXKCTrkL+AFc1VmJR4ZShMwVlVz5pdzMWJWmal5ivqLFkRQ4q4HXA1bPVhS8O";
+ sout << "tfnYwzNTBMRaNMpsKGTL7MOOHdmXUHDPKhoMki5bX7Oy66C0x35qEGyTozFQgygLOMFe6eOleXF0";
+ sout << "D40Yu6VhauEQko7l/VN0rICXaTJSaemFIgmPLT7lMfrd3Ta6XNKYxXB449SOC63A4doQBIFJnekm";
+ sout << "VO2t09QLePM0/ztdbA3toWuK8cTSNZgelZRbSlpgE0+d5rqxjlxmKom9I/8kj5bjS8FrQAbbsnR3";
+ sout << "pLEArb7kuWg3W3Gc0pdlONrgWVvoYpTtwQsrY8QfZmIVmQ+Dkcz4q+LqfYRU+TEIM4Y06Gm67JzA";
+ sout << "ncMkUYcVm5YS2Adqe/VFsctKTliHEPhPRFpo9Sfn7csGTexvc+IYB1mFIWy2DjRSsfIorQJ25wAy";
+ sout << "FX+E8PlVJPf7S9IKNM7NmNX4k0aqJYJt9VVuDidECJrYiTW9vsJ/SDM39qR2bQ01w42aG7/ksF6B";
+ sout << "tWPS6j9Gr6KDOtdAK9UOof2Oq6uI1ENoajWJuDMxe9fT3J9/1ruJMMDE4dn8eSXCi09ip7qbnrxv";
+ sout << "5splSpCC7MvfBr1bpzbQ6gsZF7tPXz9R6FfM1idioIEJITqqaaYdABv+qwad9Pwv84maHHslw8O0";
+ sout << "/SPCYyCpfOAbbYuVpw+cEryeS7RD7LNom1M2Qo6A5Orh6Qmd+xTExUOJCxkPNvyxUDezWtVCPxHk";
+ sout << "lu2vn9lqjj3UaUjALvD9xB0WsGhdJbOUtlQw3e/gEloGvQ8rLamwCBAelQqTTwM10iHkvx/AzA4p";
+ sout << "TZSpfZOyXY+3HEjROeej+gvP6VXL1lWcYSrEexizEzy2eN/H7kVTEgnHOzYwp2A8x6vlKHWL1gnU";
+ sout << "lGV1R9H2bzmhj/iGFom/TMYWmiv6mtUe3PTRTIo3W2yjlLKnBiBFhWMh8TjfcLz63dadzCLEV3hL";
+ sout << "D9DPn8fnJYh2HZyd/N2/lVDtEDnxqyd8C5lZnxY28bew7dB/n5euYTo1H6zWvpIHlbbbmXrHd8VQ";
+ sout << "iaJew6T+G+bBxuUpUQlu1NscX6mcHZBoT+/g8Ng1Q8bmjiT7/c+Pl7wpRlF/0jwhWxPIL/5mlmPf";
+ sout << "+pHBfYj5BWXTDrZNchJKF5Pvwdnl0GbmJsPPiQQ+Wvt+y3lt0wv1gwefzFdGhi77EJyRY841JGSh";
+ sout << "LU/TVYCUFn1MdrtVLhGdomxo7Bl6Zj4bZaYlKXf3Qru1bRjD0Ug2tIePfjX6TPymfgxBkgGQ0zVx";
+ sout << "dJuLeeRNmhIFVH8T5jJTMcmd/lS9u0hzXrEjvpSFKPmH+Fmv6LJF/D0bl/V0rYEI2gKmpAWJLd1M";
+ sout << "iI6/XFF/bt6tQTRu+g4WgPWk8HG3VffsHUSAI7jDK8hjrwrJm5AHIa1a/PRxiSCH5Hi2EHCPRfkB";
+ sout << "grfcTEwLwHJA+j0zK0tV668Bal1NBw5bJGMzQSONvQqB/vnVqcRDxhCyztjW4oAnEeeBKPlAmAsy";
+ sout << "MlrkLbkVwjiaYbpzkYCsO/KlIAgoahAJEOLghLQGBuhP0voDuN99bKp5JZFaGlaaycEuYFHyk9f7";
+ sout << "zo6Xfa+hhmjfgWJ8KCFKKvr50ndwdqRTafQVxjZRe7QkwYB5ZIpwYehKsXlr5SfTLp+ANmZqP832";
+ sout << "yXRF39OZLGs7xtsMmPdI6WX/xsOKiCc0oLHdzH7y7tYzLVQVSFdRZtN7HA2r2Aun8gTbLExrbfRW";
+ sout << "+16+go4YQEubnMwW3FeRQqbpM+0GCsiuTlBPTgnvNEWi1n3JHGHeJLs/0HRNo6ba+jVE3vIZO6B8";
+ sout << "NN1s+3CSDGA0uqpGg/51P21F/Q7wHDgQDxRuTP1KQGNm4u8A7Kbdw5rGZUb0ZREWwGPfUpMU0qXi";
+ sout << "TgY2XA4rESFQRvn3Uo8jPP7UOZgKM5bsdToC7Rp3wuKXJwBlmdTZgdly6CmELOILvHVNtGHg4yxO";
+ sout << "XA9Co0RF7RsHOg1b9XzfeqibiK27ZiQnZ6sIvl+EdmsnxOY+Qa8oUZdrQ0JnFUiZjC0+P0760Si4";
+ sout << "i36AOc4h+GzZ6WbG0yPZlFAFEcPiBbyOCp6VoY76H32BskEuSBJx0U5+iFKvkvK/6lyvr9WDTSX5";
+ sout << "n74qoW6/CduCm96RcKbi2ywZ0Sk8vi6V+tWYUnrDaepg/iY84UA0qpAw/QsNZN99BqOlNEWftRVP";
+ sout << "SZhPmpxzE2xg/SLUwR5y7NRijA8QalWWQf4G9TotViLV1IkftS1d9TvTMb38vGW4o55pWw5SZlzW";
+ sout << "kFf9fbqlQ5ktTZVol/v/n2pwBjDeKu9wbNFlI7bhl2xhg/3nN2e3mW0HikB6K81JywLpxL+sa8+l";
+ sout << "SW6bdGwbgj3BTAt9evqejNkQHVb8ALTkmZJUcPP3CmW0DWrKioqolNXrfVRjl9DLcxgQfhG7osjc";
+ sout << "zE7/8gU0JYIk9Tr/XDHdk9zzOsMZAHu6nUbKpIL/5mytR193QuNE3x/eCQpsG1dMC/alF8EGn77V";
+ sout << "QK8dhMFI3BQCXlo5TmLXd4RWOUmsNGt061t6ZQ/jvP7oPRskGSB6YUtr9Z6CN92wbnBUIf5VFJPD";
+ sout << "qf0mPU0vKaMsbwQmD8+rNG/bYItW8lRAZLa4gJR5r6lULMKvcMiHeP9742lYL97T5w5VTo1c0Ec/";
+ sout << "YTUMJKSvhfGPq4IfdtYmMOX9pmqd0VKlwnQ5w4pKiGRY3lj8zpXAHjraMA+46wTo9ZJ7Foq1XR1z";
+ sout << "eCY3xmnahD5uHBR3Va66rw3W0WGRgbuCk63WTKDULa0dWuvNw0RWy+MkZzbn76QLX9dYGrWYUZWM";
+ sout << "9mnQiC9yDObAz9vRcITf5PhEKduGFQgAl9CU0k/ZgmSmAYZ/wFJTb3U/roNqnptdm/KGuegjdQqT";
+ sout << "5utMQZwwcgV8DCuE3/Y9DZcnEMoeSjZF5ugB4Hw30EJt5xAvCUpVRA1tygRwwWD5XiKt9srDjBrM";
+ sout << "nmkQCbblTb/HTPqOvdkBO+4yHSGtpXfA3G0DdSFtFPO7OTwPLrUxf6rM6CQTgpzP6/eoupLM/I7z";
+ sout << "tlfmNGcq+JYCy8csjKaNlPTJextxMVqNScbpeYwG/9SpjmhqObv+oYs/uga5UBWSyS5ls+eF/Tph";
+ sout << "QSD9FkMCTY15Twa1LqEJy2TMZaDkE/UVJJB0Xe9eRQRXEgkM/7Olj2qsyN6VaLpzNBhcXT+Q7FDE";
+ sout << "NyZ3C6b4MeoE7zRM0/KEXc25ma/uiHu142x4Ar2icGm7NYTp/gxVRUrSpqOOK2EYoHzCTFilsryB";
+ sout << "JhzQqEIsDpNqmqyjy0BR5+rrfBJFxXAaEDryGMntT9XoLOF41ZMUAsypcv6yujSkbCemiaCTJ3gc";
+ sout << "9p0IkAplKKgb0m1mJ3SrW2WR56ZGhZuVIQ0p24yHi1ECqcEK4GJq1jMpHpqIGfxVLUD8yZ6wFLDl";
+ sout << "NP890I+aZC1kFSBSuj0Jee1VbQzTykN9oj6+8Bo8v9Yv/qt67WH1h/oIP5bxmCrjTzbncAT0zs5w";
+ sout << "qzi+jN+AcQS4qh+IiuCke78wgaLtnMJIb1yA0GkDM6t9crv7vKAJwGkEFZSSg+p9TucFv2RqCCjQ";
+ sout << "D5v0X3g0OpBql7+5IqXmuTiRCJE6BIBeeXVOtD0R6JORJ0e5Hy7SvT1LrIxISqZwat7//kb168WS";
+ sout << "cdpRlCa5c0+ds81SVFkjTKViZAnFkOHPGVgzxrsdlfTt1T/2fx9MZO8HT3p5V9eM9XTASzXM/jMN";
+ sout << "5X63GR7qz/5hpM12SMJvgcznXESiudVq7d7HYd0NeId04fzMlFaGtgGLDRr+15G0eWPCk7mcF0np";
+ sout << "7XLz9ytRsWC/WXaVx2/ucmztaGzksA02If8fHmKuIHbfJ4YkjjKjVRrthMCOLsGYzfb9Hat6WbXv";
+ sout << "jkdPQZd/EYGM/ZUd0wfuWo8W7yLidl+zMgMcJ+LeJjvRdBkHU3o70srnlKBaneUv7Ly/r2xPcsj6";
+ sout << "wLMH9iuRZ7jZJ45EM3htNEHJzrT/9PxaGKiK3caGIilUnFPpyb+HgousVcBbn4XSdkIjy32UPdmQ";
+ sout << "aG5w8GLvMEJ+SDAseg5/B2poqEiGstsA7EbJyJn3ptHk6I0lUFnpCgHI1H/vp7f5vhSIVkDyP4SD";
+ sout << "GPPTMLfV3mmMzS/IeIdOLdshYQCrL9I/C1sgaASFLuh0fZoWnNmSxfS1dVA7fU597PnuEzGn4TlM";
+ sout << "CVb8UdlIBfEmrFXegOnBEKu8RKspYa2TdawhfTZs1lyPJOxOewWH054k09q5m+p/QLHfmK4kr6wF";
+ sout << "eShsc8VFczeBC2MtoBVF+a2gGsz+F2Q8WhtG35kJuTqZK+lqB8clFAVSljny5eVo3WJ83QHoZMO0";
+ sout << "L8XspRqErV0+T/n7Uf9p72CQbw059rmyc2X1I1Mka7+cdqewt3RUZbnB7YD17+lg+do05keNDDea";
+ sout << "kkjLqmK+oKKOJUyDULGM2syzbhMrJn3WY1v8RuEVsLIUrCxgvMzD/Hwcl+lC3idwjonUrQuWWDyK";
+ sout << "dTjkimDohgLx3uoRs2ke/SCb7ERVtB0DMpLF3zytvY3D4ZGIdet4BGQ9H2nXNWMK9ahTfo4GFMxB";
+ sout << "EBdLiQLAqBJteGuoJ1CWdYGj6BDbYoy9sXSdnY4Hw0a8ydeBmFqoRm5wZy6ozEpxX0IeEc3W0GfV";
+ sout << "+4EojUBei+nfqmpeaDTWFgITbhplcgSU4snYS6dHxO4U66xa6N58YokpK0oSHkL42y9+Ap9qxiTk";
+ sout << "eDRtTd1wgP4m4EZS6ld1rTROZMhwgwZoyKizgzW3cc6o3rS7wNBvrdRMtOPVo0YJD0UZJtwCyDzT";
+ sout << "/LMXSVs5EQByRsZkKxNFMEky5F+D1IKlVny8ylC3oOI0VNJVxr+hIcuYjOf3adwKeslTQ2AO//HN";
+ sout << "4WEyXck1inZFxnDybu7NGa/WhNdet+PzgzEI43go+Nn/9AYLjR3m7nAcBOGgtMqmrwOTPJl0Dhnh";
+ sout << "QWv3JWK+OXqz3RFJ2q0lZP7A7EhU+3G5OHdHl6yFOh+Mb9TNh1ofYXPXL/mnnUhxCvUNYEXVdHGD";
+ sout << "RALydLzf6oY5id6VdYgCOcQY6c32WcJUEk2zCeMFNxyZ1ntzmn2KlawbF6MPWHqemMEvJM7YCFXl";
+ sout << "ObWzVlcxcnecWLltZB5xeRqDBcWetCMJvG5crQJk9V2jV3iXi56gFzOta7HzrEkKsjuYtys2c2aN";
+ sout << "EMcOMedGcMH4q3xkzca56eoR5b0mXeZRnFRnK82LczMxJjitaFLOMTA84YAgWWpWt7ENpJ1qA7bu";
+ sout << "QFdAdnQC+DKTJuLykS9eGpMZIuOarnlRH00AbK0j98XyjkaOhmL4ygSKNJQghgmFg2vyOfji/XNR";
+ sout << "F2M6c1HdfQK3Im0qLTXaMZlEjkWPIZTcDtVt+CdUtiEZaqPQwW1H37Eqk9gZAs1TnLe5AOHAMWh3";
+ sout << "mS458zVeG1d59r1BBc7pxrju9CWB1PjF19pxCUc71R8sh2D98ss2W02R8dqCAT2wdUuEC1WQwpL3";
+ sout << "GGGEYk8jB/4ccjN/+EdYqAulxV/Bp+q1jspEm10/EU3Cpu6RkNFZj310v0t643E1MqzLjpg7ZSko";
+ sout << "FkzcMRM2E7Z7WEzOQwbr4KVgXcXZ+76Xc1ahlvMqwkUQV+/PnYj68Ogv95pj6biD9vweY46XV2dF";
+ sout << "uxp1FG2j2e/nzzFJgFfIcsB1411szIaEqmHTUPBo26qg/EyWg0775HTFSZwlw8iDOFG13OBlz7kQ";
+ sout << "kTvrxy0iwvM6cTm7IqJUZi8hhC/eAX+Sdx8S6pKhOcy1q0fadSKF5cTynAQtrSYwDK0/xYRjVl28";
+ sout << "4oTks7lHxGE3Jdu0+j2ufxNuem5fcEgWFSKZnZGM4XLvUnpVbdjPMEjf+KJ0w8LqDbTbmZYXW69q";
+ sout << "Apl9sQ7AdAUcS0ZhGtACegAyg1QcJUjFvDKDhXwx5d1+0ZKChwYt8FFzqht6nWY9WjDT1i61XmbS";
+ sout << "m2Emc68sgA9VMv1AL2bOhXq44//vlGYj0bX0gj7ZDEdUhCVNi5aXobYv5fn4nGfTNKp8njHyqGrZ";
+ sout << "CT1xO+YrGMY/4qAGiBsOMGtbRrIcFHeeCMKOO/5so80js3nh2fw0L/XNVlY9RJXJA8WkEsalm5O8";
+ sout << "MUDzmTukLJQu6YIjFhCHjT0FA9R2agaQnGFZXiiZ+0GnsRoUSrGC5mdkZ9H6jg7Odm+aB5fD7I5M";
+ sout << "pHzEi++Jt8xHcbX0jyaQd4Y4pdku3xx1MrdgZht2WmCpUNxg66syuutjS6iWKQlvWSdDtlyf5zVU";
+ sout << "rRn2ESMesAxpmuFd4sG0j8Qm/8jzAF0Fwolv5GmJFlwKx5X+RW5Mb3HZdJYg3xou0Uwb3FBgV6aK";
+ sout << "mLA7GQUK92IMgRKKNXcDAPJdaF6wT4lin82Ss1QpWBBnWwLNTslobUbn1TJsg+MgurruWUB9byjb";
+ sout << "56mNxeoTkMh9h9fEuDZYL7svjxh8U13AgnJCgchaDj7iK0TOd6b+sbz2DadE2SoyFYo4XDTbXcpE";
+ sout << "LS8TbCb1rH7G2WAS2a4JKRyahhELZhOB/X+jPRtre+MNSMvoYOFYtHYXjIzoEpOgEGdSPOsLcA8s";
+ sout << "JHCoPcU7bFQcnIlB/HZVqbDZwVzAtUxKoven135MSyoRVu8j8nTIcM0RBT1VjQa0L7eHcgQVHPvR";
+ sout << "uhtFmhtUJJT6lDJmoolIJt+nxeT45+ndSTz16YrT4pODPMuawO3XIULSvkPDWNhQybkF+2jmUqlX";
+ sout << "brWKyviTX/IZsHS/0YMkgXZRX5McLzUH3Vs1PcjNBoj2q6tdCq2HCcIdBKH0Mokm9DLqEi0oY4mq";
+ sout << "kEsB6HFUnYZm4xezpAoY/wHkXRq9Y/6lOa4DtgBz8ng6/WwFwKwCIkT3/aTHzKDQSewYiU5jmQGe";
+ sout << "R4+wtbbRqGm4yfRc7z0xEsMESF2FJnODPSSSBAFT60rKGdGK4ai0IaYbWoWaHeDxhHVeOUYbwUlr";
+ sout << "dBHstgs0/k6t3QKbHDo0KwMv6LN0sYYDK0fD5cLA+92Pf8mOlOAS2YNlR0sG2GrSa8M2gZhpgTcf";
+ sout << "dz4j66DDlLmQaC9UozZNf2PxDFLUN8NEHu2XnfzYEyuyt0BQcTQyqs8DvrayzIwV/qiD4o/xRI0T";
+ sout << "Ma5RvRpagaW5EUApHgu09rkvFXk6ijL+2L+7j+lDQTgkvn+wHYwbTADm/b9tZs2egY+6pSDDnK97";
+ sout << "invYOsWUNUDtgvRxgFTDDrjUNX4xNifDUnGCJrODnhg41dOG4M6ST3gBgxtKkrmuHkcgafuohkCm";
+ sout << "w8bPHam0dnd55H+lj0J74uyhw/LreRAthXdcR3RbizO0rWIX5EDP5Y+JYxQfejANE9APE2zFHg6f";
+ sout << "6y7DE+GVQHO9iN/N85TOUhN5ya9MZ9fk8M2JR/440Q7gqoqNJ7w1PpzfHCyQBSszPQTPBr5WeirK";
+ sout << "3D9axZGpND29nMOQ8iZy/7b+5IinED5SDZq8MTekt6kgW/ScvByBVOeBbxh38DRAcClRnn7CdFT3";
+ sout << "Ey+kaJAYlahuErS/gjU3VkqPKGds02dkPwumMWd1G6s37r4hQdPoO0F73+5dxWwqF1dX75nQ5tyz";
+ sout << "J6f2pbTFnJcp4+CGMYs/SEoPzODOqMbMt5hDqQHZuEekNwMNI/X80PhnIOiMa7oK7h+8cOb9v4Fc";
+ sout << "07AR19kD1T6r53s77u7o8wYLp7/uf7NLT7tA+bc6Qeaiss+JLJznbb0LHa9uSUZ7dK0hvHFAYBjh";
+ sout << "fGqkQ2JI+adCJ/eUQLs9zvo/0GvEV4P3dZlRcuXw+y0uj/HhbMnTYgnD8AUbRYLCLS5Lx00m+Vrg";
+ sout << "lVyzOLy3d84zSmZMR/xoX1Q27biAyEnvZK5vkX9CmMY4lhP3XnEBfcK/UVxxZDQf5HfZIamjAbOP";
+ sout << "ZRyNX/+VYpWja1qGHZmW8Tc3aJfH9GsIKdls9/FKfuydYwfwr2kPuG+PTwXxiHKWWdzUXluLrmB5";
+ sout << "T9UCUmf1mkBM+Jv5r5R7qUPldDHzJP3njQfT4SUrNNXL/ooR2NTX/jYMqkTlwSCM5w7jYa19IDe5";
+ sout << "lYzspMfRfAnjG/DPX7m8H3fYqHXo6JKfWGL9NsuFzFPzbeVqVac5AQ7XcoEDXj8b5OmGu915yRqO";
+ sout << "zs+Oag12V6PRX1w6EKKyhQT9BPvW9E8I9kSKGZrO6rSq9pBaZk5DWES/UX0U7yCSHaBAlYZP2fmt";
+ sout << "tvEMLhzKwu5vjk35yr3Nm2ICYBow40OjKI6ft7EpneSgGbkBuHH7X2RxfCVcg0+7THSBMEMgnBzC";
+ sout << "+8JERFD94nZDWrX7P/QTaejq4L8yWMb7uIChSHDnIIUleV3h5kGuM0MCRChSUPvXUESKuhkezPTF";
+ sout << "KRt5t3r1CbOHFwnuIjB5JAMUhCUF0IEJtk1b0gnCY109aCl/sefPp3yznMH8tDAWTL6G/aJbRLxv";
+ sout << "DpBb20+ZyK78kfTrFaJ6xGgAyi/fkcggLa8ANOwU+ZbVKvVEQOzU9d/8ygn1SjNBASKaHfVja6Ym";
+ sout << "TffBgjcGFtnz2NkjUBccd83WqOyZOxsFqOugfmoavukdhDAaVtcB/KijSoBYQVhNITRC7/RLRJSE";
+ sout << "/obYwojRxTfq//cIE33pXW5RCDP0odyGIr4YVCHioG+li9NK2/xrduroOiLdaKly5eAaFZgiFof2";
+ sout << "jVYA3mXRxwH83YBdKUSwHDhItv7R2lAT0nsfXopxcJBQVNwPvCabAmsCf71FTl+Wn/LfALLRgtvl";
+ sout << "b9JBLI/gPf7tKXx/MhCyz0Lq33VoyOGm246DvalV2QChpolNafxuxK5AnvLxYox3U+53A88SddqM";
+ sout << "P+k7oftTdX+u7RhDmZtjHLYn4ikjLwK+afDTW4hCBGFR/bmDsZvAyp6OOVnnvsj9qSOWpeI0VGQe";
+ sout << "tCaSgoPEwP8GtY4AX+Dhy1uIYPmoXZUPyNQ+ng7MlCk5S6HAzLlsd5giVhyzN4b4N7f6vseMubhI";
+ sout << "/lG0Hfcg+rMX2VvTF54FMPgPjkf6CWr7QorPDB/Puw8c3dt69dr09Q/r16t6aJmtyVa0EzAxiqsR";
+ sout << "n1POTejlH9K3Ul98w2mthFuaFVwk0ILJjCZZVgY1GxlvYRzPs1/qN+VQ6GdLYFaaBWN3mI20aPND";
+ sout << "wcKX7LEodSgu63Da0BBrKBkoihTYg+IbO9m473LWf9U4dZDQmSjJeISGLa7HDoh/JAnfp0Ej3EVG";
+ sout << "K9oe/ewohkPFTgXIV/qg4LDRny+zIJ+jsEOHeryCzBfml0mR+WWwwYYcPur4cCABhkZfKGVH5L9o";
+ sout << "S2pTxevQeY8hDj9dUblfU7W9RBl9t0j6pYm3wP9wwP1yh5AKezedXy5VavBPk82SzWBqtsyTpODJ";
+ sout << "AWQzgF5G6yLvPeUChHRDGedP6KzrPxkSxpyKFZq6dTNl8g+yLg/FHwOZH/R8GesbKKCxjjeZqD+4";
+ sout << "UjB68IayH6ttcLEjncK/rLOWn8lWrnTsQdOf7GznjWDHILJ9B5hxEsKMKqOfYLKosPMpIJFk8UE6";
+ sout << "CYffnCZFI5/M4stj/2A4zYt2kcqIMN5rIBJ0hJtBlfZRTMweBmJjhAKbM3b/ER6jQOlbPBzNNEmN";
+ sout << "DvcgFIL1c9/mqfI041LVldLTpHeaNDRVVXzho5+iVi/R5UfYYTvvDU48E6cKh/ErrdGHXURNYRXZ";
+ sout << "ghNAKIOA3YeXZk0SpcdzxRhQaFwLcim6U1j3kVLEYnk17GavItFxjv+22zr382ivLNfjVdx9vxBj";
+ sout << "bRqj1fy/goe9ilHC5sxCphuoe9vfAe4Le0b2hqn0Y3LyO/lvzJEVbP2hkFbU06YXUOktXJVCH5bk";
+ sout << "JO928XrFA98vOciWwqtRXhO26A5yELcjimSmqZ9zwt+TKz5cG53+CXufCssHPHDSaZCbF2gxST9/";
+ sout << "TooDJUbagLQ7VfNZMBIwDVoLpl1dG1Z7GgobQ70lkUAA6Rlo7JfhNM5d8N8UgM2dPPukCAQ5yuMl";
+ sout << "0KTHpW9a+qp3B3VE3+8mPoX16p+zpRlw5Y+jA5ctu84Pez0FqZcTmM4lFcevPe6C9LCxJGzKJJer";
+ sout << "yGyryVbEzPmVk6nsuWnHFodPve6NeeYnlTkifu3FBfP10DzXYSUO6nPwSZiGiG47AYnyQXmIcj0N";
+ sout << "iSwKDzHcPf5IVgT4qZswNoy+t8QBATl4run7klhzNdE1QSAp8eNZnjtdmZdGY4UhB3/mto70MIPV";
+ sout << "OBRoyQCHkzWx8tPcn7NMUW9lDJdFFgNiHngYBe47fz+Tkuv1iBxRppU32ByeXHOSKMgaTyb5f1XS";
+ sout << "ern+rWMqDivMybInX6vMV8BfQ8eOcnl9MS2OblGPj7XDsbLmr/eOe4cFoKdBKg791eavES07Kr6/";
+ sout << "LXw/mELHEEXBqkuafGTUlUfatUz5OmHHM6ssKoQM1uUrGqzgcdiAe8vNxuMsapJYJjIb80meaYia";
+ sout << "EyAvl8dEYiGSx5etaf8hwCnXAOzEfHNi+qjIC6Dqr5s87SQpYPDexkuUv1hEGSDnEbWuDo7BSwu5";
+ sout << "hZCA4T6yoSSki1hhXMMsl0tR45bm6UlB74VX1IsL7VD6HiPY3BcveCMbk9aNCLt61KYwY+4+xEWk";
+ sout << "smbgq1BOC31buf7xQxX9UtsyBHQig3dCScPoFBm0nKaEHK1fuf8zN2WIi7+p3pEQqRTphOxHtsVk";
+ sout << "QyC7PmZJS4merPZpki9Mh/jo5MM76vCXgESUr91CY5neMliK6mlVDQjWBPenKw6JwM2UeDjjQHTa";
+ sout << "5w1vSg47kgvLT4KUfqQjS/JDx+u/RIrKGZbsj7WnZlJXsBvoFYn5N2UGBj/SOLVmCXYhBfUHZxqH";
+ sout << "GvrtVWHffKoSHcVlm0G0u0EKYYvrYCydASHOegavtcc66k4YIqge/pv8j5lTBiswafWuBz5OwbDe";
+ sout << "zxzAcApflzsfgNLol/DNUutk3QpBTk2qW2XXt4r6tkfCG89NsMcr1L3E0+NqWc1IbIIxJuSBhq9F";
+ sout << "wNvPOBUYmMD08jguGFHHFLFTVAhcxP90IP4NN6/ImZWqUmZ7DsQz27ritt7RAgjR0aXUJ2AipCss";
+ sout << "LWr6mIptUwcOwPHwqwrb+s5xuLIGsT+hXzwpaNrOq7AjKx4GbRcL2UhNwDJ3jAQe8K6uJxffzABu";
+ sout << "x1FLhLe7N63S+rKYHMLSSQMKCdf8uxVRz2NX44QRIdxXpV5/mHXgDd8qysZkv26IGp7DG6Pggag0";
+ sout << "DtuCocsV49swsov9/gw2yR4r8ziyO0oUMOygSA0Uirda3B1MOiYAvCULRc/HyzndANKNjef+uagl";
+ sout << "U9r+o2c2pqApAZY9JeKNj0vbW5mmIZPU9B+/LNGnLAmJvpbrcRTpkHI/0VUV541n26da7jXoJ9y9";
+ sout << "ejOkZUY9xMjkLl4aqgwi4cfg3TJNC+0GXjM7aRlJPk2uPNujEiLSkrPk5GX1+jlZIdBmWSfIkNeo";
+ sout << "XVQf2dVu1YlWKqmRHsgnJecopec6lhGHVIeamgsyyO9GEmQBG0cZ0lkkq3Fc3Mhakcbnw8InWFYe";
+ sout << "skGjFM2kSHQ8ImGUuGx9UW/cKMZbq++oZJQ8y/rmTDfzASJWXezdCOJsOAEAqstGrVZXO0Aoq2Qr";
+ sout << "rwO6oJYlxdCpc4MLBzriX0qlA/kj+3qc/lAb7svIugXusyTAshRsSeKgBb4qNUdJ+31j5UoZbwT1";
+ sout << "LQt6WHxVFoqYzdjnYsdbnSbV6llixlEfdmdg+gyAQuYciR5z4B4D/dydNjmbFU2BT9hbeMqkam/3";
+ sout << "vZSjziXOOsqMnGj5ZA9yIPwKcU9IqvlhxSsLPyFN6eTq7EUV6/njn97pMQhew7hrhXWpoq66A6lG";
+ sout << "5kE6M2y39eZuQnh1GmWhF2Pr+0Y0hk2WHx+h/6sRNVTH2qOm0wFh3g2ZqCPgQSFSdam5f+Jo5yYi";
+ sout << "IFxYEsif6HxxKGg100Yt59WpeKCVjYO26uc+bclSKZ9VNFsowfyCI3ZtFOp5aqZne3aB7tNkqUh0";
+ sout << "387ffCRsqMoQTLLZcycLiIpcCgmOCb8ASCMi5WgPrtZxKA9oB44xUgosSQlRgchUdu5qtROn1Q7f";
+ sout << "PHqeNv9EyJWiwHUyLp26f8pus1iHAYIJWvoFj09Q0vXiyvPM4iSd4tFX7kIXxOSNy3ch1FFHWc6j";
+ sout << "FDiXSs9Lf1iWpvF0O9iqXZcMdez9dx3lnNuXBttsbNFg6vBABk8K3DFtwDVWg6a4CCFlVmucQD4F";
+ sout << "Nf2eNdlVUwwpJn//rtDlvPnZXqpqx8JCEyBcYt0/oYT5vJdrhOKRNiteeUuWf51JxdJh4XKnLE+e";
+ sout << "8JAmuTL3AJ/gp6Bxpi+4P3jAR9mZfA/sQnTqasw32kz427/v+hGWMq5j+Ak0wGrKdTlDeA7KLBGl";
+ sout << "vbZUrSR60wvpL7SJSYkM5XraJzR5ICO+IxE+hfsDR2uTSrzfcsb0M6wnVatnXjOFNLL5AY6jpHc5";
+ sout << "sHLhUHSziNpnsMsWJ4h22wbZmrbO0zeD6wTHvUtrTwOTiPfGkgvz3DLcNBzcn4oqfaEk5JnXoE06";
+ sout << "eH4+bqe1/kZ7tlZLzjd+7lFAO8dlAA1KoIGOvAL1L5aqnPbCuJVaak81GSiwHlaztibjOB9Qolm7";
+ sout << "y/ln+qnZ5nbMYuWnn3lN4m3Gnk692nuMfiqh4RQuwwxzR1FZtiCzGBPkVD50VMzUsxs5d170C4MH";
+ sout << "ER71dvl0U9lLf8WKsHyiU79tJv4wTOpzSMszbf+WCpWZXCJrfyJ6ylRO3dvk1QfjPmFkUwiuQy67";
+ sout << "q+lFfRo9N50ZvVB2bYNsne/tlNHPk9/WdeP1rtTyMBMdl+RfJUf27wduC7RQzMtdUuhFLYxdDGxm";
+ sout << "1ZZBolFeteOW4IboI2KanbMlv+GOAwxkAWC9ymADZ9YBecvz1unjZkLSGzge1MY7ORdnLBA8fyUv";
+ sout << "JsigTjg7KuX2kpkuKxQHYs2l78Tlkdwo3x2VDKEfNegyiZMGMr8CSksLfxnI9dHwbUmyBQWMZBOR";
+ sout << "gWU+s45M6S9+HP3R+lwsT0Pn12HpljhnPde8ZpemkMsTbJW0rCQnYsrdryfL/X1GUxUxhNm5WJk4";
+ sout << "blqylbW9X7IbbUCB3+8A0phl522/wLWusDD9ndRsiiXj2EDO2Ah5MXodILg7FBo8PpHDsgKBLZeK";
+ sout << "WlYVQlWNtg6m/jC5004csHtx2ra6MjSxAhxyt2Y5qf90KBDt/61iJfKC0N9hWm5KKMsGntwWZBVe";
+ sout << "h8eJjOzM8vP4azMfymC9U87PqPuI2/J+IB+JNVnjLsx2z2NumqgAKZc4mVUs0PiNK2sSVRlMU4uF";
+ sout << "DdbEn+oK+iQ7ILBnkRhDu/C4h1eSyAQwF8XENwyYaclxhoVGRTCRkqrOaLh3SSb+YFSFvbDz49jo";
+ sout << "ZcfxsSvLNgIeW0MnmYUH9ydCunC+y9aiO2sZTfKrl95BeX8aEjtxCjpqVj4daFz0diNEQSiWIWlU";
+ sout << "/EqYRDfOIZoA5QkPBp2rioNMsElAHaosb90vh+LiDXpOj5cz0S9cKvQw7hI5Ma0gVmhQbAqFrnOD";
+ sout << "Me/94YZgsabKZ5IxOJrajCoF1WO6GgdJCkdSUrfRZTkxe0y4qBPAbVg46BGOTjnOkSxUUzXAjXDE";
+ sout << "wPf86SdrIWq0NDJ3N+D2lu0MloB/tabrVHmBXQwNDqk24uMVQPBQRqYUGCTzkjr6awEz6oICB2iC";
+ sout << "14KqcsiwuSxIsY2jyL0VG32aHA/XNQTZetsvMq5fee0jzSZWAqA0EFzw6UFW/kUmVEsBdX5kpYpS";
+ sout << "XuNVhtxVGey4zANa9P7nc40VmxDM++pP/45XNbpyIUOPxfcBM5YoHjnpAwZlTTUsBi4Kdx2D9oRx";
+ sout << "PGlK+TGQF9T9ZvI0mxDTXiotRScPozQ8jom/oINYIVzdH7EWLEG+nmKib+3icO9A09pOYWw8TSZj";
+ sout << "HNAxPczCwc98w5FdlLe4zDwHnILl4NwfnAz8x0QU+VyW6kugntyk+G1NvFHIGvOnFnZA/Ku1nqM6";
+ sout << "3jErxhCl0Ii8h4dD7+HVUu7FWUMOM91LQHsgmphdHA+NQXm+/J/zj2WxamzK8pfpnq9SCYpNfjGe";
+ sout << "NOJGOuytsDuG/Ct5fik3DquGqccuhve3sw9H5RDhaFqjWOEPT2rhlu66UfRsyy6Y2FwRWcB5YuLZ";
+ sout << "+96O5+VUhbSynvOuygLQfmCZb4RR7h1x4RLgimV+8DdDaH8abjuhKv9bFlSFUNpireUlxkYk1NoK";
+ sout << "q+3XpmpJyDI4N/qB35pf7RUdp0mlTU8/cF7g4OJ05ryskis2/tURAdoKcePaWSzn7xoA1tCGYC1y";
+ sout << "+Z9m/7erwcD8n+1ZHu3499wbywGzscGo0q7Drc4SxiwUUDSH2FWpaoUmKEUjsYLMwx57QPZVIiH8";
+ sout << "WjLd9hcjaht7OqgIQI7ihdSY5TTvihLif9siWmrJk2kpl7Qpx9B2VSVfFBrie6JU+6Ecn4XIMY4n";
+ sout << "GYVF/bao4WKkobcbSYCmUfTr1gx7hh8VkIpmuPfiyBVHXI1dQtY/BDxkLIBiUIk509Kb/U0xA45N";
+ sout << "+NPG833fYqJ6p9ygAJCQJfmymr3TXL/TxfXtSQrO42/upQob6+KEjSBpewaULLu3+6diY1/P9n79";
+ sout << "gLnJOnGncD988L0KCthNKXM9Rm9hzGjnObcxg2IHckRu8F5Bn298963sUe8x5G5BKCLWTTPFbBxO";
+ sout << "Mr/zSVZjlKaafRQYBWKE3BPQRbLKBYN+XaNuEQ/cfxnsN+gUimbvpZUueMU0Jeo1ReHqZpZnjvAE";
+ sout << "Upf5Lrmjpa0Iyy3irUQ8HEVEisLMdpg1Qo5RwuWanqTVmGxW0sCjfrCOhV3TwIQqxeWPfc8o4cRO";
+ sout << "uN7IkWSyck2HR/4ncq+tfaFXehMA9rxigFkeXNh3divWzV3W+U/SeAAQcrgUgecpcxqmGkQjqV3k";
+ sout << "Lttwa2MF5SCMfM8LwEG3Dl5JPvPfTY15dIUGNjEIFJ25UuJ1vK9K/ns4ifG5QxNKGFXLvgdv2WXy";
+ sout << "0968qdLYeIMXJAv3ieTFm2ivd6OVs7c+jp7k2Ondx5e/WeC8ciCOt7JKyAbJcCZYbasFlaOOeNOO";
+ sout << "tm7/SmLUp1hZALfjTeoDc26EtdeNf2x0HzFsGivu+/qd+eYiQXLGqStyYPVc2IZ08AYLFClR2r+L";
+ sout << "to5+RouZ3CCXiBqoaoZVsGiqyM9Dgt1E3PlaXtgLFzhpXNw/l4FXFKPiqS9DZh0X1aNQE1KAqLgM";
+ sout << "F/VYjhOpbzycTrRhiJ0vhzaQUiCCe/8QUlSqCVVNZ9w6Dvc7uuNNp+exxN+IjPcWoSYDe2QW4MHi";
+ sout << "3Sk99Bhu+5H4tVqc3ehExt01lIjQI/M1A1CpuwGv9Sz3/tspG0hYpOT+Mo7sHZ7ojYfkfLHOfp9J";
+ sout << "gtUT2uqV0LtDgctrD5vx6NEgfMzP2dNnj7wopH9P5YsTv2XPW+bKxUp95K5KMz0Ea2xLmrZhZqG9";
+ sout << "HXhubQS90G74ibC9p5nvDts1DoIC5swJpo6ZzVude5M6GmzjVUHTaS7LSjIbxQBcjyFrx+panbfB";
+ sout << "TLLL2VaFjfxw82RGjjQtaeCeZt0BsnkfL1A4Vb5ItP+qhayO+9lqnSXdECHa1qFTXKDVPvD+SQC1";
+ sout << "cXQnxXvCxuLmHf/ao8i5QFXd7tHr96ioCtng0PWHxbXyr3NQhriX2zbwD2kkrrL4udvZo2TRIe3D";
+ sout << "/L46KWG8S+zpo/iXc68pt4sjfuc6DNIXDtx8xkFAre/q9dRKmuV1BrtlYiIBQlhDByyQuqD4Cdwj";
+ sout << "1QMd3JkiF/WXCc3vsuLH5v54EMHM02MkS0TjzLjOinkMT9s8u+ssSyp0NUE83XRdUoND1Kw8DbI7";
+ sout << "e8izxcz3VIpR0PE6MYUKS6SR1Gq+haGlgnZYxBCbrhSkTRn0VQ02rnBF9An6rmXRHbqs1rJZ4yXX";
+ sout << "9/jj2dPwYrFW5NMXL35Za8fTfG6pXaDjPnqJokXt2rKX8X2/DVWhYQwBLpZytXTc9RpDis4VA6ao";
+ sout << "0ZG/u1qgQeqsmhkwTayh6S+VFpwrQA5upwIL9KI6JdbCCsQhazF7qsEDoc1GeAntjyqntc+mml41";
+ sout << "I1ZCYtP0sSOWfWiQ5yZavEsrp94P01thtx7s2UsrMtrYQc5hxs/J8VWtIIpg9AIqDW7g50qsBv4w";
+ sout << "id8XQPCcfyXQszUJi+w4Kubu+MPwAMiAVkDmlJGAC9t5oXIIwIbvNZHrq3C1ixd84+VqhvZZgw+/";
+ sout << "izIXi9IrhvkcOJmuwxQ9FCyd726/ks2e4H7TafuTXyE1RQ+6Ju3odB4meFU8SIylbCw2jwoI3kTX";
+ sout << "XXqHbAcdgMWoxoDXPZ0XzVqC3utfaOARHji4+RXEseqA3tJhS5CPHj+D+6yXJivrVC6KWSwPqLyW";
+ sout << "6JI3rVeT9B94EYtGH6qWvcA2GaSJjxiQUZQyqahHk/FxSkmnMTPwtvV6AcmYZjQESUWYI9UQyHvy";
+ sout << "5e/ULXf2bVF3dmTWoKOsDchTNOQ/Vtswf5avr9tuz841NKT+1fojuuAhhYP7uvRruvMJqebeiMIL";
+ sout << "QBzv27ji9+tQ6D42+Z5U40UW5XtytFKcGR1NWtaxNkWCPMGRLcC4mFLuQMYq8cK3ZZDhpzyVeyIi";
+ sout << "nLJRFigzQfHlBeJmXm8UccbJYo6oruyFUEagjILnb+fb4y+uE7vN6WoqPcg4vvr4RzSkuHueZxHL";
+ sout << "5ks+O6XyEywn5hU3kqOmI21d+C/HBRjJHbiL39BcHKdzhnHJ2huEtwqSV6SsDiqhxSUL7zYEDECj";
+ sout << "zs+9/3LHJ/9TbZZ104stYMWXIWuLc2NSS0bF5NHarE4XedBAcXaJzKzlj2o6NYb4Ifxcs0shAxl4";
+ sout << "6ptwCMOtvcTz1t5QB1vG5CHrc7emzriP7+H1tukmCRA2i7cZPumoiw81AM7SPv/A/9yr5EFRdBmv";
+ sout << "FCOwJNDKfD9Mfb6U1Lk+ikWFR1wm3MSqq4ZSy1Q1s4F0npfVh1J5nWZTxd0Z/TqDhSAUU47cC9do";
+ sout << "qRMBg/bBPxfrWGeDEOLActWvxMr2g6O4tdWLxTtWcGsGiatS9TahYclwL40gnzAC+4g1yOgc4/u4";
+ sout << "Ye0T1KyUWUnAfpWmFibwphrKOrVj4ZBVzahAWMG4jWzDcxTaKoZvAtQZi2FOn4V/+bKMURZChDl8";
+ sout << "DZpqrBSfSyY7DnPiGRgxOmTKK6/kb36nlas7zwcQvSgQTw3vPJSU+LJg4gKmDycFpl5deILsDRBp";
+ sout << "78NSCvjJqK7UVNfdAG/2C9nQHNk9nHEf3jo4Er7ogKGkWOHshaIvzL98Nz8BHbr115gtKo06mD+C";
+ sout << "0EJGtywTGXAzkO8J3T/gEuW35KH+kOIBMXzQcJNY1lNud+fALQxNv2Rw5u17KBXqLtABWVaOABgz";
+ sout << "IVCkpM90yG5CwxkzyKO+CIxb8JyNs6Shk+CPNoiU8mbqMFigPgXw+yWEEdOMvTdzJuj5dKX28e+I";
+ sout << "jwOz6wcXPYod1yzF0kPUlol0XaDR1m3abfEiTwquKsTY7upWPjU+Q4yUnrP0/+W4Oe6TgH7JvEvy";
+ sout << "fhqLicE67NkotZ1/+AMZVDYWZOScrGmi8c9Yj5dg3lWi5OtfLYg74TSQ8AOiH6jL+wLc9J+QwOaL";
+ sout << "gibI6x/dWW6eJGHxInVjQrt0xmP6I90cDPOnx8JC9tMAEv1MrLMt9VKpjdMut/Gok2doLdEHyTKL";
+ sout << "wBNiLGq0viiPXgm9/sbi3KGrWhOdypAbECm0emMVQK/WdkSFZu6PSArdOTvHthXOvyTb58fm7wS6";
+ sout << "oWvK8J4f/T9t9aNb4lBAYFtufB5S6ndUr0hEU0nirxgfsXl9ZsRpsQidq58dK6l6WI5CmUpYboL2";
+ sout << "6T66Tndz5jnvtPAybpIus/Kbmh3tPsaVtI7e6IRl5mzbEngtlOqfbwlyJyUiC5UuGJAUCtqoKadJ";
+ sout << "QKGtTuLv2VqavDmNIuZ1mRX7g+2Kbf6y+7L42iqivvX0Lmo8YSJFTuNkWTNlu7s+yYH91uC8BdT6";
+ sout << "3DqKRgiRWVfgKuuJvDp72gyJknKesDa36+iCp2qaZZ+yU0XhK/6U2phZ6XeNJlQY5H8JD8pe+RJk";
+ sout << "qgi9XqP8qvLIIv0RkfnDVJtl4V8pi6I6UiaD3dNUeHdN1YHCu+Ub225CFQIDsko0Ly87Iv4oULoI";
+ sout << "1gmb6dLkogavqkQfz69GYinC/z5/En7ekQ2TnOgFiPV3SaF70p7RmyMnn1/ON9r8vd4NfFPXXCCL";
+ sout << "NBSihPK2Tt3T5i+TRXipLYq4+kghSyTRhT2YRIbpvT3yYs1EGbEekCQQmIltN78NpyWqBwBks9Gt";
+ sout << "xpIS1Mj7Ff20DpfjgaSok+XhDSDA8pjNzuM4Hv55IJbF5se3Nj27glEhioHCvKzx84FUuInAjQtT";
+ sout << "0+5hsgLkTppvJTdSUg1hzhU7wa6NixwuZgfUenls7CwH9K5WYB7KN6ls8gzbK0V3QMoyZBRIoSuO";
+ sout << "7yyzgY5erfXbwfcYQ2LliJQQyYyLiEcm9HCxJGMuBL/41LQxxm5KBXMZhJx4c/S/54GJmTMDfU+3";
+ sout << "xceNujbOdtlz9ahcxUBPGOU5iTYBGVUPFscRy0nxuHhOu2VI9rYBhoIUZs01Sh3BDgksu4uUnGof";
+ sout << "crDdCqjpR+1UVnpoEmxbW/Z1Ux03tiG90lCn58/g872xPsO/O/DDX4wfNU1yeibFvWrW57k0FND/";
+ sout << "lNNeLErDW/nBpQSbaWal3O8lfmjD2SBwk3xtJ2qXx538rFZRM8DQJUOfUjvhycRGITPP4GZBuiE2";
+ sout << "cVpeof6Vqbe6J4I3E9Bcrnm3moAyKGiPDJ6UqthQuILGNTH2GOOXAl2I1X0qcJ4G2H5xgLNliett";
+ sout << "T+uB6q/kG7Ab+4QNVm0j0zdPdVZdUUtGceMJViJ+ZoRVjunK3QuzwgBHYr89OqNxvULc7tImHva3";
+ sout << "xan2frwQ2ZzJy8hRCIdaNSNgdZp4WONm+5Bw8/MU0l3PjjXjtlzjUE+rHSwdv5abDy81XQwkDP0y";
+ sout << "CS2C7MMahBsySi7v+PrLljFQbyG+deR6ch15UfFjtyd6Mctr8FcW2umvPYSnl1g2Dl0jE50ZLOHj";
+ sout << "2gRAl1EvFXIuPjxcA7gsrlATb2SgCUtrelJSDOJpf8a3WMKrLRYp2MONLXacFRpkI0iU4fNhbzgB";
+ sout << "o/CdP3O/AMGGnaZN5kIkmOwm+naBEF3mv1xFE3ZZ5VldoFrWrxZ+9s4lOAScPYzVpABjbJxzOzSo";
+ sout << "fz7+pEOY7/OY4oNgqWA4MEtO2pMbjUpdTbSiVuXD87KHJYB48ffxkA5aPKFWpdRFCJHM+J2Zjkaq";
+ sout << "ZRRUeBZVzo2rRvFBWrpbvAkkaXEDk+k+N3mOn1LF78Qbm5FWAn8aw8QL7oQnmrkEqUj9EAiYTx+q";
+ sout << "v3oRQ/x10A2LctzTyy0mCuM4bR+WQW2h8dGoMUhjJpGPQzeR6owMDi1n09T84aVUDkxfR2PpIHxs";
+ sout << "HBag5ryae61mzdzqNmyFzySbIbD83r7ML6FqMCW4xpWF6ru2DvNykb9GYL/kquN1rjNOe9XYEQM8";
+ sout << "pyYvcipUqJdEeYj9sjgJBXkuDr8fCBpmvJuklMOeuAwAMXwxL1b53VM0XNgVokPmFTSF5zqesA/p";
+ sout << "ah/7f2URRjrP9RR344YUdzGoccdhpNnFQ+g4KRdmd7FuPTJYePyDncJFnVZNT1TP026QrRPzg7H8";
+ sout << "sFMoZQuzZxjLy8gSU3JYcBCnfqBdcvrtF5Wb+pKkEZxROE4qURll4Y7zYVxPv03Z2PvA6eovQyap";
+ sout << "Pr2jTJPFgyEiAtcdsluxtc7NvY+B0Oxgk+oUyaXAhkihKKSGT2pPRUvd0LqDYU29/7v148UhItLJ";
+ sout << "minqiK/TinqUPHOe8oyljHRIO6awrl1oc7HbFs/8ZrK8hd07LG27lMDcrBogEyUZ+mcBRykohy5w";
+ sout << "vEPLCUA/hLP4NHaUts+q33Wtrq5n0m+4k79SRjwuvmcSXjf+Uf09CQD32oicQr0UEfo8uhrUPwta";
+ sout << "WukfxXE/VF/jfynSXwEomRZopcNoc20gWzRlk+M7HpjY1X3TCi0T6f8qAJuTjSSbUZTFC926jhsk";
+ sout << "a7Z67jTJ0woPzPFjK41KLWQcoArKTuqOytbd5fKao7qpxiFnUPd30dq6Xhpnt8oW2QESiOPq/4mT";
+ sout << "Bdtd+wl8c3bzgxRpOGBKakq0iI5Vc30wc3vTUw8Mpae/beCHolCN52hJF1LAKRJT4xYWMHLpKwNl";
+ sout << "ZqpKMybMAGM7O5X3xqcE4nH+bPkpkMILmNN2x4JfWYfy+667pQdIed0jEaDRHueSfyuXCp+Yu7sV";
+ sout << "q7CTe3bwrUSklM7zpU6NxuX++oMNuruKzZ7GwsP/NkepKp8MGUyoT8EOz7yIq7k9z61OLxz0Ylxr";
+ sout << "0irk3K8oCPttZz8y3u4ZmkAGcYOs41JR3Lu7D+/V0hjdGtNdb3ALJON1AB7r+OQ2LW3SEP87POnT";
+ sout << "R3wayKImuVIpIjZi+yBSMqomZMtH3nExOSKTHzLzkSyefGR3WTSIXI5GqtWNetBTk9y2eeRzbSu5";
+ sout << "+7c0XsuaWRQd+2qfns49HNwQocfbM5kNMKa8npRCg8YUBMimQiihwEgM6NJQWJ3ig6vwXncIOAYI";
+ sout << "kwcYWkLBme6vSiaeyLRG8y6A78zeoc93lGy87yjRWYkr9CeARb8B3tHX4E1APF0Dcs3/6MTZhnmv";
+ sout << "eMdEj9TW/MFATPy0BELuV7nLBpSrH2Intvo1jzPU7dGsw/36MbHr2LFO9a2PrZkMBCixQ5kApER9";
+ sout << "HQCKGangWrdU/jO76/8j+mcrxFNUEc7y0OUUFK8ZXwHMXQp7MOew6OVUmRuFSl/jv22Zdz16njaZ";
+ sout << "Vd7fVrgOdTG+EXRseg0LvvzZ40qWpRahwg7zqP/sAWp1ZhH1yJVoalrCFuTLgabjWGo5Amp4vyk4";
+ sout << "zzseJ+O10oZZrWsFOg1lbPjs4j+U1IhQ48qNS/zk4Y2kaiM7/0q7PkaUxdwdcdHwbOmp7zszAyfu";
+ sout << "rwXlt1PBiRyEsBKKKsBLwUzP/KJEoRvFvgc2VgzDxxe7pOsX3E6+ey7gKt08W7RNTTGx70YF2uVG";
+ sout << "4cEQHQNqLygscv4DllKnH7Jz+2cHnSES9us1PaWqRwpQ07XG/sSI6t9o8P0xMBVkR6XB9JkdKP1o";
+ sout << "Bj1MgxVN8pMbv3ANGQzTbmCQ5IczVK1Io/OQbSEVlZb8ZOatdQ9nq2YZvedMelM3C+t00b5AaQa6";
+ sout << "DSslaJ3FWzJLguvwYQgYUKvJ+jOC8dMr74YNAxt06b5RaijSib2TNDjnZ3tIySixM/dYOnygj2QA";
+ sout << "gubdj4tuvNpWh04VszgcZd7TbLntvd2IA8b+FDDzMkyT067BC51elULM0CWKoZgaaMglt3LqR43C";
+ sout << "H4TQMRoitKUyfjAPE07UmUi6lEGYcrCyawT+Si9YQy+T5vGMpryrpy9gaHAmFQTGCJgit/0g/c5k";
+ sout << "mbyZ4pilxX69VO9VHrc23YN0CMQ0W3pP/DpPNYQTcW3jeG2T3d8b48B+iY9boq4ONEGcx6y4wWok";
+ sout << "SuOSzK6rTR834/xqlNGMLNZWSOFjLdq7G4oQLmVZVRIkUXrQImDl5W+1EjdglFE0jAnnuxjT6+DU";
+ sout << "B9VIMbIzVlVId6xVrp6Kg+Fb5gcxGsKcpH83c18DbuuIoLsywmIEiWQfrSV7bPVuPKJVof/t+Mar";
+ sout << "Q32TiRHMO2ZsUXC8w46fkI7EsH7mWMWOMfhkvnY3AVoAj2i6+orDHNwnTmWQiGj7FR9Aw5CQcCbh";
+ sout << "W57bPF88lkQMxdEHBd0Hk1c8Pm+p3x71xzdIofXjTnCd71R8FIIvoduLv6XQttPFatTHY0CeRmjI";
+ sout << "49yWBFJ+w+BSEP6cnTQeSo7uAOyit2S3EdLudn7XEBr9eN4wIVAF5lhti7HqJqV/pfqpkgaAQeEv";
+ sout << "mbTG0L3+6Mr4RPAKaa1H1aQXcYT32jdC7w4BKaxvACcyjNuukUwhi8GxxXxIktPrOJnwNCs719MY";
+ sout << "TbFkKOxloVafybwRZx5DdrEKMq+jKyCffePTaM77bvL1hNyK/78Q9TdEO4hx2q0m853zlnibtbwM";
+ sout << "YcO9adPVjghDSU5amF4Ul5+rMxroEoaGKLW8+gL9DbmcW6Uj/vg7TRQohnGbzmKwPWT0hu22ID7w";
+ sout << "+ILYIhGhG337Pr0Osu6m80nETubjTijMQHAb1/Pwv2eNOOdu8X2EM2awT+qtxqHCq6il4YFSbUdi";
+ sout << "+i8vY32x0cXZ1EdKfgciuKtK180q1S57ZNZTCwe26z5eLXomvjtxLxBZl5z2TJyPEnJprnZWNOo7";
+ sout << "5nuXsYypcrV3W2KlRmybUjRfc+/6nMKKBtU8U2phCUg1uVbzAXgyKJuGRmwxUOXXWeiS6f+X02yn";
+ sout << "ZpJda7aMY2VBMEuxMRw7P+S18C/drO223DaUROCv25mTOZGKRI8iXQYEa9UxLvXgRfUB+eEGOuAS";
+ sout << "7N4NVz2kfzGWlQdKqRoh7fWDkM6R8J3BgWhXvl6dacJcZyeRQGjpR0ZEKDFTtLK8Dr9NFYXk6Pnq";
+ sout << "N9kMrQpwK3bDmdv++XISq0WHnf6SBJcNhSx74TUlUDBxsYAcoFefI1FNbFFobGmT7lEX+L7uQNmS";
+ sout << "/TMxpzGwSeRavNmjkanqTTZFpWNvRXrfsod+TmJPe53ApSSZpiuOAedN0/3AGsVuqjNtcxf4DUCZ";
+ sout << "13v1FAYWluQfSXtsM0weQ0q6ZvhM2SyX3VjNleA0/X/PPmOLLK423iG/d4m5+D95t3NRAbU+TEUp";
+ sout << "Nff+BWEiVSlWCTwbP7mTaJPnAUWND7rjif1IIrQLL9IdyauKUOP04yMgCk/LQs0HsDkLxSexCMHZ";
+ sout << "rK0S8Bh+/2V/dNF54MVJEqln+Urz2REIGmzqSjly38FZqDwG8Rn3dPYImGASFgGgrCAbYza3ZyEb";
+ sout << "yaPzMiO8wN+1F+DbMDfvMtlfkY3fO8LvJ5csTKK94rSw2erFACsU0cXqThvGma5gOcpL1H9CwUsa";
+ sout << "zKLfBclmCCGrP4rX1/SN7OBpNjbrCHZnkEebf2TrzxT2moacNtJt1cwVUVGhiJ/DYdZawCRUE8IS";
+ sout << "IKDq/tSGREdGSMLEwJdXyRk+4LyVF1V+YF3Jp55yFh5skh3Be2lP5YY4+YiaZDb4IjJh7qxjjO0R";
+ sout << "RmvES9HZTw0GhCa3ep3x8dvbHTp1NB4//onnMx9JheeyZRZOdEK/Suw5BlXNgCvc8UJJR4ZLtyuU";
+ sout << "7HxX7orjWgkHNdix0VnJKSFYfn2uJ0LTPrz3sARjeVnQ3W0On0DIXqNe4thTZDwUFWUKklmQGV2d";
+ sout << "bgVLPmuKff8DHP0xgGyiy1kHollAqwEUzD6bW3Tkrtu2f+LzTN2yAQxKkVsiW+sqEKqUtzNoXVQD";
+ sout << "36tnDjAFotkirlS5Bzk5pXgDpbAxFnVsTDq/TcT0XbIR7QvBd0bXq4am1uoZjaLGA6bF0l3Ej6Ub";
+ sout << "pB0ZMs2nLo/8EHNIgx4cBgXL2FN0DWyb2CzchL+jsPU4s7xBYuVvGA1bLaM/nlRV7DohA29pE5jw";
+ sout << "MXQPYIDzraGVsOdJ0ghhlqqnwPgO0oCaGDxIlBehjzkJcgXPhWarYGGBdk19SUjt9y+aqDQmxShW";
+ sout << "PrJagqBImi4S4N99hOHnIxGa46BFlk72ddnCvaJ4n5LahIZXXAhqOrxRiCKGKBX20/vzOh2ICqUn";
+ sout << "MgfLtH5T87Mho0k7hh5gVFpYktOqOowwqPJp7Q7Qlwty07GrnkriILYvkp1bftSHHyepPxAAgaJo";
+ sout << "jeyuGTtsa+K0jHc8WaDSZncIqmuKs8oVUZ/1YD2EkVlVXIh5NjyDc6RFfYhP5/w043WKp7xqeRwq";
+ sout << "zaMTKjrUr02omLUduz14J+ka9uUErd5O2xPaV2O0LpLuZaMjoZ83ZirW6U6ipV6d7c0B3kC4CcVk";
+ sout << "kUdzqBoH3VVbTav1vugceepwBjBiXsRqReTJBzNRJK4PPu/ApBKQKp6jyZoWCfhJiuJNB1sEjCbc";
+ sout << "o6XEZWbBEjqRaCujvVzo/dLLGHUUOXyGA8sKJq/iAM2dqvDoaduLBnCsp2oBIFw1q0s2psDRLdvY";
+ sout << "KiTAKEAypmqRB37gV45Mrmn0zYeDR2/EsWDoo1HFu7Kh3FX/y2fL0155dtiBJxkHDJIXj7POIea9";
+ sout << "Bl/qheUN9FwqXsVheoacBp+Uh9NulkH01NsKCyX3nV9ec1sdzwh8hj3zWZIpD+ATVH3S9Hh0pt2+";
+ sout << "FHzfu7z+YtwgpWfSqim/C0/yfDF1QoRws09DqLTke3TmnvxoRfkdOwsc5y1s4YR72xFXVrmIRpkW";
+ sout << "pqHx43/MxHaQ8WrEABai43+FqGJmR70fvCNecLDPm0ZlGySkOCrfgMH9lK018QD+DRu5D3n6alPI";
+ sout << "3bEqruPXlP/N8eNnJFEeDbxQ0aQFJFxSOmTkYWr3LZBjtjX1Cc4qujxcCaUH2hnKvdyxejlt8fEa";
+ sout << "h/I2aSaD3ejPS2xz8y5D7Ve0mxx6znMEO8E61RbXztaIRzze57GvfwpTa+NbAjtX43ltW48u0yeu";
+ sout << "5UM25BNbXDdmfUzzVFpiaflImZj4Ga/k8TjlTVURwRXX1GzB7lfdEukFRdQnY6VQcsjzo3O0gKDA";
+ sout << "+ZUtku8C2JpdWVYYsTVoyR+5rNlgHdaGqbqzADSrSqNBBs0OLaEH+pqoVvN7gNNk/hCHN7tvmi8n";
+ sout << "LZQHQ+0Ukp9p3bPTUapH+h+NJ/SXUHdjHrJpaFJR+ANmN8a79zacX7v09ig91M4yRTSxzMzsbQef";
+ sout << "xnNKm4+2JxpFLH2fmM4FM+Rx30KcqLmUdBGOaJYNHQqBVJYYMNVfdRCo/KwBX7r+V9SRlLQ5ijED";
+ sout << "zKbJlHi7LyjYOq0iv0R4MIiW67OR+5O0aHeoXQ5RJ6Rs1Xytaiipn7aXphPVPo+fZ2xH6nLTQfFq";
+ sout << "fEutvtOrYIR8G0dGUSYsbZEhdPC/gpSdHsO0bOJlxgXhNGNeOFtYohwIKP9Z7Giv+lM15GPlxnm6";
+ sout << "usrSaqlZWs75RuyOn7etmGBSg86+J5wE3ihgVXcsprkgs/y5mnP9iPifs1UBfDO+2L6UnmceygSm";
+ sout << "sIOIuMzGUkrM6iOLYHtp15txD6bUYZsIEeCCnjaIJujQc0lvbU9bh5AVae/HG+SfuWqE5nzp2TD0";
+ sout << "jug6A2KoPK/ufGzjCJJ6VIe9v2bG3WXApQGWyj8lWCO+d6YtKFA4agCPTv4FSnS+Ip3Bws18ZKxj";
+ sout << "M+AFJKpr7lmSJURUl38hITa//w1TKSA3vSoMaEFGEH8OFkOCDFOHoGK/z/ievfyBkC0/WgvwcSbj";
+ sout << "uvaOd+M4umm/BCNqcPjZ86HSej9mSb24Kl3Rn2pxywK74b8P3z41Z2yVbRIjK1Dwvo6t0zPp8vtD";
+ sout << "QHW/9v2uZ0VhBdbRAL2+qFD6qKg/sacjR0kE/o/zRR5vsW4K0efFZ8uerG5hLs4rYhZPtiDOHr/M";
+ sout << "xOGIfBpL30xzMdMs2zTnbXF7wECjcklZcwlefbO5ubaT0doKkVSZ6YkCQqVXL9LNVNHqULonUgBs";
+ sout << "Xrbe3KKGFrToxCljsw4+3+RUVnNN48CklcKfEfTHEXoUNZ/p5msufJct4JL4K+35Ecfauq0GOsLX";
+ sout << "sTkVaLfkelv8Uiju3VEggHb+lLdrY7oKT7aC+kyQu8SjtJQY7b2aKQNELwrX9u95mPNZkySc5vaA";
+ sout << "xwL/6ar2Y5/UA/Lh+5d4hGSkQ0PbSacBStFaLzJ8AVdOcXihah1a1ddYNtfH6kNsck9XJqz9Y4G0";
+ sout << "YdFr6DGHNLBlu2ksVL0ta/tfPZy9BVFo6Vz3tbpnBwkPJ1vo9EZrJxo5GJ309oeLz6mYG+0ctXJr";
+ sout << "PC90nYz7s4RxXkWZBzq3RbNiBGNgP1iOU3x9INOf1CbKBhNSp3b5QJDfM+hVzL0Xd78pBpO/XWQq";
+ sout << "hdl2dL5XZtd3XYslKjHDaZQyZgnaoWD9hhRlk9lhweUSB5bFjHOH/FZjlAH8et4Eri7N9Y/OFHFP";
+ sout << "+a+FoLLT0YiBGFsJiqe9udoECrxUR/A1Yi3Wen2xOQv+eMMQ6IcvA5xYONHWrS8WgMT/I7oELwnx";
+ sout << "GK3kwjKy8lQJ9HOmVa0yuQsV5bbrtjQcP72TmuWKOcyNWf9vHvLbWSZUDEXchz4HfADCJEehJzh2";
+ sout << "vIYiVvotktpiC/MzHmHoTw0Sf/hOT8AALmyPqVFGrkZHhLU4GZTqcJEzBNUvkvuwo3qMHzOc3zfk";
+ sout << "VtoQHlg0pOcKczFo/44eMRlpf0Cn4KTcU6vEmUncnvx1nNMHEYKQNGp/IXHpXp+PN7s9zFDFbQHY";
+ sout << "iXURtZZ/cF6jW5TshrT/CgpldM4oi/gn3yTm8AJoqATWeF83ft0AphzHDfz+B0qf2wF9Te8adqcb";
+ sout << "ekFiLrNf0VfA/wMCE3drkqLhZKO6L99EK9SZqsly4Q1rwLr+sIgOrI5v5g82DOFpY9zL8Sxvg7Gp";
+ sout << "03KDCEsW/2eFfxNtGGZ2P6JJOWLIe/6Zsx+1F5lcbYxPhSirbDYp+1bMZvbm4uLWHttHS6e5NQH9";
+ sout << "pW9+qFTlk68lJWtpI5Nzy2CutML4mI1uVDgIctZLSHZOTlWW3LICiFxYVLjvE7JExUZMI5J+3VZx";
+ sout << "OTTeFGiuzoEveGJ34tOMrHxoj3dIMkRN9kh2dN10crD8MKoo1sb2EqZvhKu3yb+XxKKl/RkHTuBv";
+ sout << "JZjCnoTnGa9JltcMDgzO6demBWSN9950HXwS0cZRcDIgG+eDQnfSfHGGJ0UuL9ydA0KYWYAPxVMt";
+ sout << "tqR+S8lKMzk18k7FuGBsFjloOSgsychm0GSTKu+AVUJJohpnT5KhE9Z7Y1iSc1pBOxyf4iS+FF5v";
+ sout << "AzB7BEivwSPqcjPplzSQ1uGNZsa/mCfGy77XBBm5wYUpKDshd4l8mOHPFofPW9WrpI4X2e3LuiMJ";
+ sout << "8BijYVutWZv/8BrhhFuUnngfLsaJBhO+KL1m/JsQEvWIJ53jPt5qgEOn/4Tei953J2hanoVzKelM";
+ sout << "l2OhgHFUxvIYGvKMwLRMPQ4kUH3C1vvBJivvyIZ/JWcHRq/0CCjvg1aPUXSXhBrgkqFwQduoNW9G";
+ sout << "D3BowPV9v7Mk+IPddFxHdQeyHdaquVKxWoWHzlDnv0gbzlv/EdZVUd7k8rbqYgUdPl9os0dbh8gq";
+ sout << "nRpOW6MTBdFojgzrdOHPhdHM+Rs/TIntPPd2YmryOjdHhAyrVgFTS/rx+yDRrP+ChjTPAF+k8lG+";
+ sout << "qvN+Z8gDFq3B2vfIgAUs0rbSMRfOahM4+EH0pX7fe9ylY9yPHtWKs9i3mU+uVTgB5ZtRKuSKkpnO";
+ sout << "MgWOp94H7Ui/tI7b7hAC673dprCx1PZgdIncutPRIlxzv8exEVhjhH5AsW4gRNxLX9pukC5fVdXF";
+ sout << "nWif4xcZbEcsW+mm6IV8NE7iokkhEC0bahxpHeS0H0RB28nZY2idhHP6hLDMkSg9b8mlJgqcGBtj";
+ sout << "Q2Vd396H38VydtreELEV2FgC6Ski/PFP37rc3/NDStzcY7d8twoEiWa1msraDn7zlvORuk+Cu/Ec";
+ sout << "VaHwuWQdyqf7dhOhrH+/wK0RTPMKHDyLxkwJmO31Ka0LEOvqMcsUzEaxAMUNz406l0r6pLW2MbP4";
+ sout << "AdGu/pH2HwxZ7b1ps+8JiOu+U3SDrmhMZH7kS8LUhCNpH21Ty6QzgJcaw2jC9HYUoNeobJSQaOFu";
+ sout << "kiOHd69hw7VbD752KMiVur1AYmRegfhkEm9850zS/AjT+SntphKTtFNsc5cR074M9iucA3cH6Dv2";
+ sout << "AUSKvl15U83cbs/B8NA/4iSp1OTEb58O97rM0iLsWyBLSq2Lx5ZblFZeLi/rPkoVGMo/o9UV8pNa";
+ sout << "Cy/ushHj3CMku0REQZQfcxl3OgQ6U6avNywmwdncjoDCwX2j7MwoZ9SrAKvxSedrSkdRhMvVkuJO";
+ sout << "l9c6Mp86g8KafJQIp2ZS7CMC16uoiaUZCqdcdIEsKvzMreCoM+XGxXE589e5BsfV2gTuxpvnWVCs";
+ sout << "8BaKUgLnLKCHzyioUVC7bzz8Ov1/Wqa1LHzsrobPw7DV+dIDHqAS0d7TRra2q6rXAqYA/0P99NTp";
+ sout << "uhJkixu4ACmGbuBO2ob+t8cdczxIBSRPHHZNCXQG1t6A1IYpozEWiFjATpoD+gEL7Mv8AZW7AF3Y";
+ sout << "wFQEt2JapH373jpvJ/lAIt33yTMgyTQHKQ6bsTOWpK1EK8NvnOSOeiFfmtseS85l33K8vl8pyXBL";
+ sout << "Fx3dqms9fdbwhUmUWGcpzynCrKyV6DVGFznQtpr4QCl67Oz7RaaYXGe+lOmyrBgrnjKcQk4ue99C";
+ sout << "ucZa94y9SaRBTFNbXtpRm9jjdzoDmtd0JPCGzSACGikNArk4gC2AVKMJpFdJrS3e11Ew+pprc2uR";
+ sout << "fq/nf3C9WQ9lQ+KbpwEokeXHWNmWBwDZwFCf2w41zxYqjTu2XjGt27ea7nleb1IA1PtQtcSz/qoC";
+ sout << "6A7G4EFk2WE6VNL7/Wv81DjrdTdYy8DFQBzab3Qc7T0+T+aNZdNfoZSj5YK35BB0mmPbxw2U4sfx";
+ sout << "1nahYLG8OEQgoOVG/3vG+EmA+iCAcCBHPObthul5KdW5JXwPw3AMgL0UL/3GIpXRrzE4JTPrFfzX";
+ sout << "zFmbikWuDkDw9sP8UNfHAOGWE+OorN/Y43rcSfVA6xNgMhqr0pBbvVTIXbeIy/nR34IJGR5xN6YS";
+ sout << "G1SriivqytwRmJLC+Z4OoSVtqNPzPvUBzs2wHPF00gGAE1FpUa/CfxOfXhbx8VoZWMcuysBgpRCg";
+ sout << "Xne4u7qvre3CRkuXLazM6qKDt47OA1is3gG6ZA/XCvcI8EA61YccN6zTpFNT2cBGQNQ11tK5jqEV";
+ sout << "JeONoq3L4SCF6qCiCKMg0nsc6sk/xJ5Tam9Rl292lzxTV1Xa3be+lfSjUdiWnmbU4ml2Y7FEVbWv";
+ sout << "4A3WHted4mKb0HiCz8wgbHrjTnNQyIntFcXuNAUUzri4nGsV128EuJrPmbYXqWuKIe1wPT2Lbco+";
+ sout << "7qH3xg/zxl+t/bqlyqghqoYQMPw1FJXz7Z1TMI9TNbxUtCd+J+mX3xkrYI5AtCbTZh3lMvbwWJZa";
+ sout << "4AOY4WRavT6XHsBs7hKl0iHOmTyMe1bkM6sg2BChZ2B4T3BoeFISya1Qn2U3R0Y61plw8OFrTsAT";
+ sout << "FKgD4pSc2VaV5bOPM/vqg4B1In01SOw4itnoIC9N4+IvHvzQFYzeqiueAcHYGss+uaTMqqc9c+gz";
+ sout << "cgiJCv480X1iiauIKK1Szm/06GiT5WWnkyqWEqxmNyMPrj125XXg75AZeRDazTWXjObzZkJ7UA6i";
+ sout << "YQRFCFH0enKb75qIjHQ7635m/Xegoboat9VCJDGGVNVG58hQAVMAWHmnIb6oK9t1aaqXmAYXXgFo";
+ sout << "3uchUubvOjQXN+txW64ZRXzM5lAw4TklqmhDjxruXRdddoBONkvSPSh/XiDTkdmO6QKTCh/t6NLo";
+ sout << "Jb9Z7nfmQUd4adkS/+aUfb9LNOPd0GHxq2O/y9XXHvRTBASUV6xlJzpnOBAtTJSDrbLj9NIhiWpF";
+ sout << "nsQ+Nq9aEthi7RHvc6rujDGF1gqfJjsdPUTxveNk22Xc0fe4Qm+07HEg0HvhfFWJdFJJHirD5j82";
+ sout << "bZ+QZTppInMAQRh+dxZhlCzPKdqS7WPS8dpXp45xBpYH4GsG7SaRgbCBm8R96j5iSW3oYd+ql99U";
+ sout << "OC5fYGDbbGT8zl+mtYi23CAV6AY/izASsZHq7KzSV5taKzEZyJXf9D+wRaopIZ5JPPtmi/fr2HAC";
+ sout << "XyKoAPNAywLXUTtDezuKTJiyy/DBlHKxloa3uVTeG9YBJfiloyym8p9GsCOcz1p0s4K1Oyph1NCo";
+ sout << "CyklDAHvlMNjTA4hkWIHTyIStVmgvVkNN93qpPYYNnV+fIETW0tvJDnyJ2OnHVP4h4VBWbntk+uE";
+ sout << "JDG+DLleC5n87Kk+e/lurPTik1Yqpicwe2ct0AgA9XJCrDO6CYGT7QzFhEujgOYqzUmFe+KGn9u8";
+ sout << "4QMY5HbYXY1P9V3pR/QyoyxtuZNU/caTzRM58Hxe62q7BHmsL53db7cDUE5dYVR3Lx2uK/pm3QWH";
+ sout << "IIoyUhcLmk04dLwoi7YAhnao9IJQ12KlFISFoa/3WvhWbb5M/GttUDXGX1A9JuGrhjzXMIfGyS/u";
+ sout << "wugpWw/5SnWaCis9sNGmmMNC2rniqemMrjAUv5mGPYA57KyKzvVGGKK7VM8qyrtwSKOQ7fWRevDE";
+ sout << "NVAwhR5NnzdQmL1x2UrFcEkARtk/8RU/tfy3nXqg6Bi8g+myCn4hH4BOtK5E9GItI4R68YjNDVBR";
+ sout << "MiW9KS5gXcigTV1fXrChyjn4Enh+GrdLBpWDJQr7MKSVSmqGXntmhyTAD8Hgi111whRmKXgwi7OC";
+ sout << "k1Jn/NAURbtbi18YCjCc/+NfStTpuMv01C7+iZSQbO2rzP2aqaYSSvIx2e3/e1fWfZfTJBMQzcfs";
+ sout << "zITZqhDilOIDxPPYv9xAkBxnwGXIkPAi5133yWC5gcNh9hzEq146cKvQ8BFi1RHGL/yOiDjADrns";
+ sout << "Ky2PGu1fBfFNfw7tqF03f6zNB5IVWVjmnKtXhM2bIGqI/EM6Qx79bcYKYXcFg8yYbTc3z9KAcgp8";
+ sout << "es3R4Rwfg9irMdI0DK1CyYnpoa+zWhUXUzpTbSlFx0zZ0Vfx4xnPh7ZJqJ7/U/QNAU3lUP8LXVAP";
+ sout << "+jMIEj/H2mx80xaTNCDp/Sl+QcU1jZYx1Ycl1B4gH2/wsYYr59q/mDD0gE70hQpsE9ODlZzx9hoE";
+ sout << "HFRmDUTI0um1u5fLykqAJFVPcfk920vT5TqEZSMX0WvMVxqvN92IUxy2JNpXXMlBiekCZhEKKQqg";
+ sout << "D3MqZCUlGFkFpO/zc35vlTBUCuTvbcnak5HVcREFV0yNbTyukhEJzcdZArcCZs0xIbM873u1y/mZ";
+ sout << "+dUcW4zZJ4jNSa/Vz3W6sj41hSQKrQfPyeeKpFd5iKxJPXDxV2iuv+ZBgJi584KVpYSVOUA8Bzi0";
+ sout << "0kQh/NO5BdU1qaekY5lfrpA+O/tTzkTL1ZslaaBimJapn5KZcTtsZR0YmrE0E5okd4DxUccbElFb";
+ sout << "81SuaH/YCvib5kUbkXA0ooelAU5UmmMCZ1ArVVgwlLti65IqtSWD3LB8rIOUT+QKutHVEOpNMOrJ";
+ sout << "A8DnPVgF2kmeOngOlOwb8Hgpg9I9YCuu2cojbmROM8BZh2V6E9oiBTeKPgqgZ2tkgqGSf27SEpPo";
+ sout << "xHPyKMhQewkcq3WQ132hIb0s96PdEfniPUSig2MLn+5qBZDFLHHuB6K+pl/6OF/h7ZaAaCa4NCA9";
+ sout << "32ecTvY2MzUHmBM6CI/QKLRtBjmE09TB8s/5/X++k3cde6xPE77c0/G6qLrOUKFITVb+iLVADv3O";
+ sout << "zhj95j7URImm87lmo9ZBDTm6bbKoQtNAzEpBAiGs/VDuFfA5ZoQcwTwEnhmOs/4bT4591uwosx/T";
+ sout << "G4whvzgfcbPt9yHc2nnSSrIdYTSAwFZ1Ksh36RgvDXxaOHozLvlduvPaJkYSk9HijbmULXMUcyvD";
+ sout << "6IrzOnn/8NY6m++hE3KxgFM+hlrlAp3qQBDEGj5FAYFu7AcCT2r7YB5kiuwjM/7J8KyJYWg7Z6Nh";
+ sout << "zJbYleW8QA54s40d7D6AmOBX4+5Mx4e0roOM7WEFoqvXqK0ebAOrgwZ7MmqtGwe+5/bqkfNqg06m";
+ sout << "o0UCVDLmF/SzqAsxgIJt1P4/l8HzWtrFINUcBwcDY4ZhV1fbWICOmej8aBMJ4j/iUHLkxbnzQFBr";
+ sout << "xuNgC5/5I+iLjiYTd5+bdsITz2aRpXgUHx94KdOtcdEJ7uLjUeo7jeYvQVt50x/V2thmd8zHebNA";
+ sout << "6Yj14+3dHQH8c29EghfihdNobQHSyw9eVo8RyVI40oUxykuk65ILT1mLbnASJcKIInvh57YMVTRv";
+ sout << "rpPmXm/b5ZUfTMbzpTwO5DThvD+TTGvLY+eC0awfepSJj9CjiBTqNBJDbd5yKVjH+oM0wTufEUgP";
+ sout << "Fr0fondoPNTB61iwLzHrGVbzU9Rl8rsuiQOr6Yz/Dj9OjgzyuvDb3zMilAEcLydUFmkZlq3uVCl+";
+ sout << "LTYgz2tAQIl0+6SWZCel1bGyRKY50U2bJO0dbHawgGrUgZ/OO2k2HJnC+0+WKZeSj4s3v1OfCZbq";
+ sout << "F1nQDtk64Q2X857gH4kTx3MQ8gZCok7M+sd3p3eA43nUuis/SwComUQODyRA+TmuoE9AdASWTZQw";
+ sout << "Vz/To25UixKQtov8has8TRMEPOWqtZGHg/1tGGyZKkTbHQlulKwyb9uusuzC+FAQUEtpTcXkQzDE";
+ sout << "RNHKa/I1/zTSLVnaADsLgrQA13Igh5w86Bpn1bFgE8iEW5zQoNs6KDYwsRQ7mMiJI75ylVIvfbN1";
+ sout << "Cp7tpGCKCL/8xtWWII5ygWogcGhi5abUDxG0WQOaYk3+kPU6w8OPpbaT7h9TDO94XhcwUb5q11O7";
+ sout << "LAKIZY5+/7DhdVfMnplHpLNftDihrGwf8r46FUrZ8MRHAREnHSsch5dYY3boGxsFoNLc5r+fbJVD";
+ sout << "TxslZ5TEbJCFrBtFqbT3xuLReyo4FsOPZd0N3ZtFXt/ChTutUudGtCLfD3ucoJohZuz0Y5Mq5sHQ";
+ sout << "Kb1vAx9sRwpFnpO5o+RKYaGh9QkdrPXqUlcyXkSrXeNVfGGyET2vZB4VtQEAlG0nF2dkiF2uxKbc";
+ sout << "6swYBEYWLaftOzle9xXK6NrwQCrH/guAo10Ct+7uLmhK6P5MZW4kxyLi5nhOPkB/jiASFRPN1NOw";
+ sout << "KdNeaakEO7PEKPdcIhWjvw3sePbhOyHnmdJGHRB5WUvjOqD553+DFu4pgpUPHx2PV2IfEF2Wtfvk";
+ sout << "1J4pPO9DvXw5kBSFmoR/VdU1bGkVfkDMafaHt2diRtBZwbJb6XgkzJsuSiEtzaZ3BLaieOYJKCFf";
+ sout << "LK0/Jj4SEEDYj9NIHq5MsWH9PyVYvV3+IpJa/AB/WHXF6XAMOsLR7ulEmViNjv3UYOiWYEyT2v2o";
+ sout << "hYYg+6/lbwOg3frwE4tHu8IMfIZZzoFHbfT6MnP9ajoZWuSbbIj5KwvJ27wLUsQPevjbTypUXIEv";
+ sout << "W5WLBcBCegLZyIdElFV+1kJmSF5dZhY9RJeDOs2/CpgRctMX8hUGtMjLIAMjm2yzS6FgTtuuPJdr";
+ sout << "ErWoYhOWgGSZEitvtLZn6tGl+JZghT2F7dfmMaJH8y86txR7YCaPQgThH1zCh+yE5hvZPXW7+OCj";
+ sout << "D/MUIXp648v2+9ItiAlAZ+ho+JzhIA1jgC3h5fWWmEc7RnWvtyiatDFAGo9Z4UCz1qtRqoKdlhr9";
+ sout << "loA5N8Q/E1vPO/G3X++gNdEU6/ykRYyAODNhxYTnkiT47qRUzqSgOwsi6ePyQVuLDbSCwcpZK0ps";
+ sout << "qR803wEQBh1/nA3ZUeO9g7ZGaybxNoTj6Qs0u5vRNudI6OkjlUMMTdXVWf5ZyV338dihwHho3RQT";
+ sout << "Y6nYvFbud7IP9HMA5ga98s8z1ImzGbKGTMqy15rL5/PVkyddf7MArGuMSKWR4F0njVU40ILJh4Ns";
+ sout << "xkuzzPGZHb68JdMit3OazsaLvkZ6kBlq3KwWH6BISfiFv8eLYRxCKSfGWoqcw1qlUaFM1MQTJ0Zf";
+ sout << "4IK+R9+6/GGQJmSEDCyK1t4jbWDr6pjoXOIsihE/hsJj+Uy0mcQ/pRQpxGc9DvL2/7rA09Lr+Hk1";
+ sout << "xmm09tLOMEtd6mNexizTXHotFeN0+M/fMRwjIUY3lZSOHFVhb3AcNRcVE8IhsOnrUu9SuZCP5cLI";
+ sout << "xpc66aGsMww2jY/TEMBmWn5IUpX3vNe26iJ8kXZYYqPKz67I1BmH1wOXBaHZVznvMHXBMF7WhwhN";
+ sout << "gq7Pf38r16eou+L80O0aAaHE840lzt+jynl/vNb5tZ6F75G6jIPW+ARUy002bD+n/k1lnmwZ1WgP";
+ sout << "WnxwFloLan9gTO8dZG6uPyDHL8HaXU7cpty67AFbDMnYnajqpixUyDYCZeeYDvcCv+lqDm3YgKPx";
+ sout << "XQPEQQP+qpwvPcwLIebl3F15EhdJuy1gdwt3Dyca13/+aQOq4Rd4AYFl9U8ehlJsrUOus7Q6xohm";
+ sout << "J6wNJJzJpnRwNH+qekXnqKqRdTUwcQ8JwcN9u4sUE4pmBWS4/hJSYHpFVaUbMPmun0Dou54UPzod";
+ sout << "HFGXXZd76cZ+1MqaoyofxHU95d3rh/47EjPtQInAKAXmQlud1wYLjOhdDwGyvhPN3XfohEPrH+s0";
+ sout << "n/tMDShXPyFIaS+WLQZbaid5LuarlwA1rQc5cJ16xtSrW9sRjpEM22MxGmCyyuR5mMhbYLlULR1R";
+ sout << "b9UJ6/BbJ9iqO2LOh4geqx7Pl+kvdiOwAfXLP3mFZ8PHTHk9mJMa1TlA/fLSo8O0YqHMTnFe7OtW";
+ sout << "qMNzs0Pz4ucJ5uGzptFUXyCqq4vmSS3Cx4eolZv36mgGa7Ll2TJDCJckb6m3lsroRCTs6zF4B2yN";
+ sout << "VfTekTPXttFqhi4GI85Y+ZBvMoOIAsvysWCsV31AzKg/qlb2+dOumyMScT0GX5+l3OMO0iqG1QXG";
+ sout << "3fJZOpDQHq07WuCisGH2DOO1sEE43+d8fZHb1l7PQ2w25w6hq0IA8u73oDWRyTr5IqY+oDM5d2K4";
+ sout << "Ysuzkl+odreYxlUvxCjwgG1UkXw31DQ50xIwpChB8rEd/5gW10lnXK8XTbVgL3fPuy0cd74toIhY";
+ sout << "JFg1jyA1ZST25w82oJBYq8f0NQtHc7bK01cS4OM0hmOGd7PDKiRGiIrsLcjgfw/J/jx3mBuEaR5J";
+ sout << "X0NRxcrJZN0RdwhmwfRIybRGdSX46JVUCrf15yD6vFi08cwkMxqpdn6fDfh0BTXzhmdJU0+219/a";
+ sout << "JFJBFD8iktFAl629X8denKKlc2sMyBT0IODc2OYlS1+xn1t13uEeLar2l/EgG15/eeBsELvM+M61";
+ sout << "7uSI7t1szDmeZmWrO5yVcOBq1HFoEmHIlDJNtYQ7sOqRiaJehiRNhnj/9umPDOWCmBdnIXYAxWMB";
+ sout << "TJXpWzOOeRvpFNXUHN8MDttwWjx7Osmal/5ZqIbX66+jdQrLih+L4//J8oLyCBjawxIj+4BVLSpI";
+ sout << "/jyMUEbzBmRM3HiOoqqgiTFlLeRAM666O6oGhNmrFC/Zf32boZ6YmU0OicO1P1zsaRLJkvMbCYaU";
+ sout << "Dmi4reAoKBG1PBqaWEhDE2VmjiPPCs2WqZfll85wyGQDgKKdyboLm26UWwW44oM698E6stYbg8cv";
+ sout << "P8eyOR8awT6WwPgOS3fJZFuKTgN5LXJxk2BVCzx1F3gnmMQ5FHhIVm1srDDFIIwSkuP9mgjN9lEI";
+ sout << "SGCLsFYdEhVedNZuOWMYCvV4RChI5XCPh0ZRilVbCP4kkj0Xaqq3SbNWD90lwwJvFdIkZOEPeYc3";
+ sout << "xfLARYjHqHbNkYZLpnTw6BPH8L4BCDl4v7ndyl5Ef5zvShk7L2VIyKUK+3lboiVjwDoI+k6A7iNC";
+ sout << "UW5f3fTA9D/s1hxS/LxaFU+9Z/bOMjiRaPLQ/ioOkwyrUJpsy8/K+T78Vv5XXqtOct0aXrtUydgP";
+ sout << "uWZKhAFxJe5467wlC1aja4hHwyEeQmn0P8ilBq2G2hDYGEnKEF3XN7bwyFS9RGFdU06jkKlop8cm";
+ sout << "s9bVHFI4NeItHNO48LVlRFQaVdq1p5vD5elrH1h+J0s0BzDauJaraINeTlF/PrpGzdxiA/vq7sB2";
+ sout << "DHPBfrxe/yTTu3YTnSzMYR5mvPdjKh9mH6PmzTYyu/lUN780SUyW7mmEqoNT3LcgPg6QGUfg7hqn";
+ sout << "HP7ml/WtBRhZh1eLok4o0lVOhX5S5vMkXFa7K2pObSkGyyItCW+Kk9SkH2BboW7aYomw/3tAd2ce";
+ sout << "FUm8H6Q9NP8KDpaebg/67KZijw5kCSH9dBi/XzfZY8vzWGPW6Uj6zpbcDfRQDb1jwrPKZRhQHHiT";
+ sout << "dU6ZscvJMqSdHXsCNByBtKrUX37K8C59AaEmCdshNlaIOQLvceIIpJ7XTfmtnjSSvK0zJzBWt+xf";
+ sout << "RdpAeZwDNK66MCl0DIJwG0aRK0nwuePf0f2k51xlsrnl5nY5rM+hfblfsm2BKH6BH/SYlUSLxQlC";
+ sout << "MhdXA3MGhxN9yuHCm0m8WVc4/4g+SCKnp70ZghQORzaijlU/0KS6P3e1po+y554gEqicyXXndnwA";
+ sout << "qmRp58+1pxbwORmCQIdgriAZxPQl77hXxCMES9WQkBVb0RKNnbWCWvVdkwxxYHBGhO+m4kZVjFwW";
+ sout << "0jxlEcKNa7WdQStRXCap2VDFA0c5kUS7WN6R9HbM/u8xMNLzzE6jrH8l2h8QrLkjHE2R5y/NIVAX";
+ sout << "G2+cQp2X9XxHVWQmfflpXPKcDGGYR70ALFrgeRkQQAxfR4WBDLLQitSat8H/q+yyl6hdGj5nmYXW";
+ sout << "BRy7URAh0ibk4K/Bs4d2mN8ivkxahOB1wPZS2FFHP/I9C7nikHmQwlg4VlPJTwLp25JDsOypt2sH";
+ sout << "J8CK02T35fcv89Dg6GKoIp9RFkJ2RcSUXyWP3JWFxH/u1rdTmOlVP7mAW4MgT8AQdXWXkfaJxq4S";
+ sout << "n0ZRo9RT68RvAcrdv2z6Hwh4klmlRw9qFtokZMGLk8BOtPGSOU/Av3k0RjEsqHLiQaRo0JpnBCw7";
+ sout << "JcdRwIuqvfWRJMpGF6DhHkIQCFZcTSftnNDKUBLVBn41O/VnUCakShsRIy0a3hjQpg0M/7hFbZI1";
+ sout << "wqiEu/aR4jZE3wixAJhXFC86bb59UwlH6TZ1Eszn7eH0inEqCSlMMRKTo5Bz5kybjP8As1pkPS4Z";
+ sout << "/A8L6SnY/HIm6jy1MiZJmzdimBLwv14JNZ2oUM/TWMzgnFayE0C1dwiX+75KOrJqLl5hnM/5r+Rz";
+ sout << "dB3emOb6i0OQI444QHxHLK6ZkLlJr8kY5opZlAfjyzuWzJUG9p1CXDfL4t99GOoq8TebYYdjmbyV";
+ sout << "BLn51c8VBQ63UM/dzKKgL6n5aUWAenJ9sNV3/b+fmN4cPWAv6TZQTDt6XZxPz+GT74rFhMv5YcwV";
+ sout << "zkuiYPllmx7WE0pegs+TV7BpKef2DBNlo6D/7ELNsr7L3+rz6w3l+m95i8jtIpeoPRzNxMWfvFZw";
+ sout << "OP46naSo+j+8Euy3CnMnx5HQVwBpEeGbqjwyhsKmRyev5wzgYGe4mxNTgrTKsVh4exFOYkFJovWm";
+ sout << "l9VKTRY4/Zd0H4dodHVquSWT/i6ogNNwB1O4T0I8hs8T8i4XybD4hM//dNvbPxq82HobB1R77rqz";
+ sout << "7wGi/X9MmX3ct7FIpU/sYUT2DeHZ8QFMfiBhSLqoHcJj2RgNr4cX9ZbFXk2nzw4Dz/ggTDOx8XRO";
+ sout << "qLovxOk8Sw7KXBDVmW1/kBhLVN/yffDJjzwwk/JqiUI3kPaLOiiJnnLRH+fS52mC/+M/xg7JHde/";
+ sout << "oGmlCCGQb/1q2j7he+nBCFZM8F6BVHnlA7DoGCzSetTLkOLo4OZssfDw3OSU/2vxTxkZ9qsarMCF";
+ sout << "uIl0MrtO59MbXtiMtGS5gcZL7Tu0SMDJUQYFtoZ6+QnDgHD3P3BsqyXfjRR+DTGvcVSyh878HpJV";
+ sout << "3R786KM3ZwW8iVxFabz1OOFx0Ox1fCspZa3L1ckftSTD6ypZAN4f5ox3Ad2iY3l2J7iPeE6/a9sk";
+ sout << "MnHQmuMy7pMum4VVepnd3lfqJljkCG5ptnznCTmp2wPgkqmvzXl1UU7kyYrdMt1jTzg6fiKsbROp";
+ sout << "wGSxXLbRGT2sFjU0KExb/rjG21a0NVVp2PnYMQuyF6glGmCeSqohLsV3nv7eCbUd42IacVWAR0qN";
+ sout << "c+5Dj8hdAmRMU5DJrbaLpuHeg0gouaBYrmCYGrcMqcI77VLyeQGF5hzq6t6IhiPk9uLrW2tpF5Df";
+ sout << "gHr/I94WMSy4Wf0RnACjVxSeRewTXJDuA5pKq0EqNs2RzmiB7LoktPBzqBgMjCRyKlCAzwZ63XOt";
+ sout << "vOvKGWCbjXcVEOZGAY3s4lQRiNIJ1Wat62u0MSj9BERVU8kd/NWueMn+tbdGTSR6Hgyk9bvHF+Q8";
+ sout << "CdFQntiNHHU2Qx5ZbzPPoNNjqwy+FGicBBG8/xV2bPd/JWm3uo/8BFPqzORzGAnG/BHSARPbEvkJ";
+ sout << "ix1Nw+gepwJ/M8YNp2aE+hYTBUBJueaSPyPR5FS6ocR5oLd5j6biPlTv9gHGg03Wc4W3BqrC4/He";
+ sout << "R5sSsRThAGj/wbDu6SMIoT+2gUNTbUZ7Lx+8rVcWGY4OLaySvre1BzyhGPpNDuIhmbW7n2jfEO9C";
+ sout << "MYbtPL7skgqV4L9JlLup+vbCNC5ccOpxm0Y+XSA/Gidd+aZu9D4R9Ddz66ns1Ida6bWLFM7T8BpE";
+ sout << "dYgcQBJ63ixc2nYknR7w5Ez6JPnmfUSDjXTrY+tnXlbyWIrF1fI1Og7V5i0d3G2jcuitqs0v6SwR";
+ sout << "rgEDVnVz0O+cd56TOssSv1d2KPIY6tv22oH9P1sNzjBYN9np1jRH7e1GpnoEvY7MA2gNyBIWdrKA";
+ sout << "uUs0c5gYr88/DbJvZtSnmPUMRr2Lp/RtOHnabXkbAevAovdIvmh/BfnXnSjrOkH+BH3Ici47ohO2";
+ sout << "Da/7OlZj4wAESZO3WO0SKN4fUwC6M6LUM2L7kg1tvqLA5KtYokFOc2dMSRs8qz6qpneGUv/f8NMu";
+ sout << "PphJtFjOixUrhCOZVv+0dN94L3dNPRyjXlDp4EejgUxWdeYVPxiyXor+LSIP/s4yPNmn6wIYWF4l";
+ sout << "hIFjh3drkdxwsaS6QWeYL3knA6I2rR0NERayOsVKxcVftkonwoZJHpdRkhDhBj0Sa8TicGcSgQt5";
+ sout << "kqnjLvOqNA3A4XkXj445r/HiNWVbd/v3DRYn0lh09HU7ihGo/qrC1DDjIQjRm5MO64mM9xZY58yU";
+ sout << "2fHWbglRconR2QeXmQ1D71aEvNZwBs1G6s1VqbCNgpeA+7kvjlzwFzyvfIqMuLTlvUo+4uCHf03O";
+ sout << "im1AxMxq1zvmRPM844aIenx1bhx4vKmeEwTeHKYvjUkTCMFqOXxq0UeK5XSKJkXrK0F98MORKeEI";
+ sout << "lhpQXL7ky8QZChU2k8DDBaocWzF9oRtrXDY6vfuvcituo7gbfVIDzQrp4ran2eY/pvucYyn/Maai";
+ sout << "W5IQzzordTdlVKX99bpoGJrmFzYrrpUa7xxb5Hs+YvwicR8IE/2ePoZ0t04dBaIwPP3ZuPO0G3EA";
+ sout << "2XGdG7f9AARr1098ES4qU/JRanR/5XZY16deNuVPRKMdUEBMKijfHoQCCorvA2HoQz/jD1nhg2ZR";
+ sout << "cB19h0shhjnz/OVgb60bN4izBugSDZfyiuZvvdpnCbc7bReVBCM8456EG4mHdI1NSdGJfATpvYbT";
+ sout << "lKBwrESSbmkYIVCcBxn845i5lrPHMdlt0s0qwtcqBTvjBmhRpezoF/U1BnPdETs303TMgApql/xS";
+ sout << "MWS/Jme3ioI6E35BzkTXABoxIxyjdFkFTququq686oqyGCSpCxlJA2OyXwUc7XkBI4Ki35ewV3Yh";
+ sout << "DbONeeMe9NAred9EhUBg+uLO3W8oI3094ZbMa7N9b0tVpewsJuqP7nfW4hZtLSpFEtjlQVGsqfQv";
+ sout << "TqljECWVoZ4cbclYUXTMqtd/5okENdDEk4BJnmd5b091PoGNeSqYPf94LVrlyVSM6G9DmQm1UTFr";
+ sout << "bUipf0LcSMZ+DKYhwIJLpP/GppEMORYKRhDtgPy66GNfhiNKzLP+fyPpSf6LbtCUbE3Ni+PJ8zux";
+ sout << "k810i0by9lVlFUW9tNnUcSE7IPrcFs8fozpLwNSI/gTKFy0mNEnHJ4hlSgaVB157PLKmHSv6WJIe";
+ sout << "6LE89xQbc/Rl2TRjh2SxZXB62FoRLXApUFcSdCl4ZRdfA5p5zXcQVWxl3RUdJbpUhsqNA7vmHT9q";
+ sout << "qinoGtKuPZvANZr9Yn5hB8zzHLiRvk9NHhFdR/Ibw2gGXq6XMuJmrShSFRMZrTiIVtQeBWZNzQBL";
+ sout << "jViE75H59XWkpeUM8hAUUawEH+CFcYynOsbTECPSbgIbt1w6DRB096vYoMnxnttX3J7SIf4JtpEZ";
+ sout << "iSI/Xzdcu9L3QX7bQaHzaf/FIqWvdl3quN2CrSnBd/UbGCdgylgCQyBtQjcr2lP9O3QKav9laEz2";
+ sout << "YJbjjXldgdumq2W/0rTopOIX3V3xCHB6xkZq1SS773LCHWAaoPRMwgfrJplzRSNIaD9fRsVQuS79";
+ sout << "TUwphVg4iGrl/0u8YLEr9TmSTojfTS3M23HNJU/K8l1zOB1gfCYwmj6jGxTCuvDHvgpJ/UnViT2j";
+ sout << "fnQjGXg/WUwEoaD+66ZePipeu+PlIHDSX3Jj2+jHAmMMsFX7PZlP3xhWUyG6U0hJ1guDe4HOcMWQ";
+ sout << "97gJDdYzwSE+S7vB1LQWKNdEnZPmZ+rQ79gpjNYVfzE2djCO9Sn0rQD5W6O7Wd/QCV+qu619i1wN";
+ sout << "IziA40bBSmDLiEtZv9Aviu7q7Au24X7ZcXxRqruz9zKPSFd6dhpWAK9lD8PCNwUrDXAvlHT8pS+U";
+ sout << "8K2ISz2knDjOiHIb1P4jF374K89RNSNvefvkGBNrJ8H7JiseiuhumYfAiSVqBQwGUeh8lnPq4aeQ";
+ sout << "V0gFbCj8kNEs8/cDLBnnwCGZbqRfSGgSXTKWu3/oSmn/bMralWdW35AgMM2tJ3Fo5fIkfzsq4ykT";
+ sout << "ywlPUtOOHIS8ImXHCQgVt1+/OPbTpgCly1i3ssNeIkjPY5DyGk08WWYiih7IAKxjFeHSavuYKy1P";
+ sout << "1QMI0dJ3y2Wp717Zg46c+iKxlg+oUR7It9UT+ePKJCm/UAayU7LyOWThMfP73hM9k5G/h+8oENw2";
+ sout << "D/Kj5IkwB3iK6iC/6QrZaIHTWVRNYioSzsEndB0X9NhBMlcTmyvPJfctFydhK5JjQwjjx96k4FvS";
+ sout << "P3C7zh1ww1K/x1J5vK/2NViZxgSFmIz7L3NbwUb87eXPictC8NVReqskmSw8hxNmgTHG/WEU6NZi";
+ sout << "XRI+HIV5erFIeNRIixP7gp1z/o/5bL1QH5ALgBk4TRc1B9qG+fcWfJg7o0NckPbkkueM3u1J79mu";
+ sout << "pcs18OQ2A8zehkFZeaTCC6Q6BRijJFBpPNnDIxJ7RDazpi3i366Hds6mD/r8IDRT8TbWqedF8bNv";
+ sout << "vGRPs1aEWykMumZdtkRZXfByK14T0gF0CSjUeIRnPuhXbfMBznBUGue7eCj3jHnLynEKdtuA7CyU";
+ sout << "FlXmJWrixsWkohwNSY+2jCNgmIWsUsvKS+dlNktC7OdSo+QEIDnCfPio5AwCTWBBvUaxwR90GeWW";
+ sout << "yk3Rob5grT0wgxQGoCoew1J9DJT6iLEuy/jJ4FjFSB581f5186b4lOYaMoSCpOp+uWEz+I0c2G0o";
+ sout << "DZKFloHMTpZyEt8LggzchgxypR0SGSFxm1X0/lhdPwa+8lKRdcXmhe2ai6H0R3lhR9ZTBaCWhMMo";
+ sout << "GJ6SvT1ttMF3rg9qfFc1RKjK4+TsRxwYOP1XDz1rjEp7KWHLxZ/JKzp4IXomPm2f6exs+miIPME/";
+ sout << "yo+Dtmev/tYJLd9poc+LAVkFIDqQ5ID08LgCGwAwLz7J+nLxSww2OjOxhbVfRIm0Bj0ERLmvy0jv";
+ sout << "9Oz6hHcYVHPqSGN2aMsPtpWUWOgND2ofYfAJqqUMbwalMYUyaimuIU9ODyoT0xIMH3VeN/nrIHDB";
+ sout << "u1J1iY/Lshk4okLUB8OnbtJY8t3KATibzVtoLaoOcmcQzOJahRLw861zYSlCs7SQpAqVEoRjYtZC";
+ sout << "5JPGKLY9zqoogM41u3uy16DfVEvDYr7XZm9PRYY7H33QZiOYhPRgzZvs9nJZAJ7QGRDYiNQG5V6a";
+ sout << "Zc/qLzF9UaRbUOxVKJ6FkdhuIuK8d38cN1ky2uUM2m2cFAyJmzKosMn1+ZKo85F097lSyuDqgXvF";
+ sout << "hD9fu2ze0hTo2zqFmHHoDvA/3dN3Ilt5omMSSf/HQdvOeXj/2vtmJhpnERGJo0NCiOZfKI5QG4qP";
+ sout << "313wIpwQPZxAvJzh9YEHdb3fv6GpwlEygu93G8Uu4WZkT7P+XXqhH5PL3cZ2MYQeYGR70nvUzJWM";
+ sout << "lz7g/DvPE2aQHxGSlH+W6JEabNL1YUvAim6/uLWXvjbU+QdKDUQsml0P/auI6oHYMAgXnifheH85";
+ sout << "xyT4TLqA33Bc0SRpwPV6lIigxQg8OEX1czyeYInQ7Y+YhaPF9Dh+YR2Kf6xLS/qnhOJrAr102gf1";
+ sout << "apYuPnSogxRcir2pFtltl7k7xjaAcfcOUfM10VkYOoq92X9XTfbEBbeEAhFtw5AgWJw/iryapVb0";
+ sout << "nNPvXwZ6g8e0LM8cky8/CR+68D7KCkxTQ+C3cq6Fco2niuPO1/IbF4nnpLB7vDS7qjrgqU4/t0bq";
+ sout << "qMPMFgX4MhBikxwzWh6xnQgylA47Rbzy0EKMD1j4NLSkySg7zpm29sTTvEh68OnSBaaiNm7yoIpv";
+ sout << "41yj7kewoFyIY6wqbc6bxfAc+cKLdxYIxwwXRknJQuZW2mx4wpsjXYqEwZ6OCDldMgYkQWuLd1r0";
+ sout << "LNY80XRGlKZw4bV3eBiEAo8E7yU6VBqi/C+fTWCdaD3kjC1mu7V9XX2tMg4G2kaSqc6JMjOZ9VF1";
+ sout << "06FJujVmJxzZ8t4Z8yI3JGUpmWgYnvY3baanLpy54Fd7M2JgoXJrElrfsEMFsci44CmK54NpntkB";
+ sout << "wFoMnSxymuW02SsBU3DEBbZJmXspP2yqCUVHyqpp2Wez21u9T2LV2yNLpKHm8uNRf++WJ49zGBDY";
+ sout << "vJ2xirDEOgRWCQ/z4J7fg6AwMrG0PhqX5Rw1NStM8rPEWqKd9+6KKTL3lMj0rvvVjUCylKvev2Vw";
+ sout << "aWeVNhBWgelJ9v/Kdp8lQRsSyqJ7JP9CQ9iRicBTKpKio0NTaYn8LEo0yH5B8HVycOVJmPjjYCg6";
+ sout << "eiZ6Vxr0laxchTlNPwRAU9hsZky0eO68C/cLrfgqB2IxJeZH0z1XWmsA5dGcYAWMDamqBNPTm1jW";
+ sout << "USE/0iR1KPxvIemOxt2V565of+jareTk/Q77E+jkopN7K2AmKOGouPeRUGEmfrtZQY2sarl0tF2r";
+ sout << "f3TvsoqK47umT9s3R208qwO7nC1TQjaEdP2KIfcPxbyuZ38d+GxQcyuID5+qS3IQR4X1k+lyUxY9";
+ sout << "wtfNYYbmsXnqAXnM/mKlfGlYalVxalvxo9eJdugBIVbaQ3JmEovseS8IV0dZz9ItdrW3HtHquix1";
+ sout << "xKf4VBhFqpCdVXBY0/mMlcWws6nItmS20yCf8i6jJ1HcaHGuPQvhyiv+xL8kELcu2avfpf93i4m5";
+ sout << "KaJb4xkjSXNAcBerhBlZ+DO58TBlEZSI8UTw1+yuzymxOyNbq2o5rwqAqMDxmfJlJvbum5OnTWCc";
+ sout << "2AgSoTUVOUmg1br3xDxH5xX81Ty+fMaizY/8P+dFg6MGikZZuUky2R/nJRilgi4thJV2q9iCm3sX";
+ sout << "KVW0vt0AHRnUb+QkaLl4ZJwGUcx37S2YcsFCAk/GG9F3g1zo3zUNgbl7WelqICfme9vwcl+LDnvK";
+ sout << "q8jgEUNwGe4ZcgDuPAxiOmIo0W0iRDbt97R56EnRs9n0E4W+/DWBvHY09nBYtVrQWpodvZ3ZkEp7";
+ sout << "9zg0d0rwaLtb21iZPH1swSnf3hTYC2sxq8ppPJKENrTOYj6qiaye/kqg/u+rZmhd/LwNckkjFy4H";
+ sout << "9OA+Te/sjSoRCoSSGXtGSexAhh9WfFMnivoJ+9kAIGRUCJPQCx+NPQ5s1+d5sMZHyGs6CsR2Hpjf";
+ sout << "ZvXMpbZVEV/4Fgf7VMtSexPTV27hFxD0wHrhU4lsQp0T3QcI20VYx9kTz025uiMWUD29Fm3b7peD";
+ sout << "K2bXo7y3G3k4ykOPrF4kR0VSkjB1NbhXftQImcpZiSDce0MZOZZvQ3lH0pZ7LaU5d++Dyp8xsxLt";
+ sout << "4toEQgoyLaS22tEe1jSQsT7v4AWwjMBAQngr8MsqmcWyu1UYshp3kmosoxHZ5N0w+RN6P8dnRWuR";
+ sout << "N3gIp6dBvllTNVG1Nf1zJ1r0w8KITbzsS4bBLijKlUilL4b/P8N81fEUDaXCrcdcL9b0fzv64vg4";
+ sout << "9ZXsKIfw3rjw+/ZiDt6a4WtwCZvCyI7cRRB2YGxaeynx/Yyqq9xwnWThHA0a/SAwkokg/zbxfi3u";
+ sout << "kzPOBau+RpDKYDL6DBkYaxMeTnbF9dRZnckTEhu8kcSh5uHhVNSeyYSngJ0YXZV+tEzNi8yI0wA+";
+ sout << "6M9OX5gcHv1pbem3Fyx/3Rqu/LbjD19TtdwuyhXnnGcCXPGzqFrpEdN6504IdM7WDRrkdjpyTlOR";
+ sout << "5hzk1GymO/IPAk8uRPEHeU+873xDLoxeuNseeNDRuNwFa5FcIKmIAEGeSydVPv2uPH6l0VQezasr";
+ sout << "XMXlWOCOFo9eGGJW0p4tz4qX35tClsU+0ml1uaRR/fFeJyJ7+WrqyqDdHpTqnGWp9OdyyCiowxFD";
+ sout << "T7MJ/Py4+ytZ++sArWxfC5KATdztRSzuquLbzKfLtF5x3vihfHzonsdHzZNUTi6py1JV0uQO6by/";
+ sout << "1HjNlxiNqfOEaQxfLe/KwZvvvSVasZQz+64BSVLPlmwFnha6iUXYl0m8pNQ/E3QlysaeDvf6j6nC";
+ sout << "By2Ft84JgN/2Gy/XJuQRobhPiDh2Fro4ebBwOjIIfgQx9Q9FNzKRDHAS0/HXB+P41/CzXWr2eTu2";
+ sout << "PXZ0O8TshRiD5gH7Su23qpwB7qlNrPbULE06oc8ftAJtuKAGDPFXbom3zFVjsj/yeTmS2HvU/b80";
+ sout << "8FCgSSxx8cvIqfPHLLOSPBUxJAVyCFPvDIUzHE3sBPyw9OY40bgSyaYoC+hoKuf0lHgt20r1Tz5A";
+ sout << "Em0sAZfnbWRESfaM8lafrA+OjaSDQQW7wAI4+XO+HbWYxz6Zj3hu4qfmID2Iz2kuo57Ci1RLOUzl";
+ sout << "sWcTkVqUrdnoGzjtese/OnbbsCUQIfCM5ygX1KKgmlb331Er2So/dqrYpCFAqBo3SmctlCtmERuC";
+ sout << "hkB59TSoGqV6ePxidHGWKBsEHudcQyRv+pwNhnRkW+l1XmpzMmdQqeBjqcw0RAVJgDmHS1SMrpPE";
+ sout << "GZmNu+LmuN8xdGV1tbylOTC4TVND5NOWaPDE+vx0YaRi2orc7KgrLc9OhtAQdTC7hlk/8jlmYxFp";
+ sout << "xl0Fe0oW1fU0avYFk0p+dUXfYCA6HXTetcQJdo7Ai4JveZIUdZ5xW2snKTZXd/IAE1UG1o5Uz0LA";
+ sout << "/y3Q2+1hBR+4cVnmqVftEf4ZQQoJlY2CN5Wltg8kWf3nsgSmu1JNtRXEV0DKODWTAGSizMJJleG7";
+ sout << "ilV6GJU335MapNMUrlHYVNicy1NY8CX/PhZtbGzqNM3gfwS0KhWmKN6HS2BYwfPllMYJk4koLHWS";
+ sout << "R9dpGgz8n+PDrKOcLG4U8xYocSvLUVJ6YmdbYWtWunGa0I0egn3YmUc6RTflJ6G2rRpjY/B//CY/";
+ sout << "tAKUuY+4e+4/VTH/sm1z3yHSycLOorAeMjKnPpDfN5d4hyrqeHgG/PKdZ/QlMZYw0M2Kb9fpvefl";
+ sout << "cqeu2IZaH1E7G/YbqnL4wRlPq6HSQcT7C4r7vJxwYGGwZ2wClOXdvvbdjAOxZITv116e6jaMQ3vN";
+ sout << "NNd09s9zZlXyXDyBfdLgFebbKLjrxC+2eB16aURJkorv9jQydbdMMsZ7DnW75DppDplDyPvMC3TT";
+ sout << "0o4PDSd+YBeWo6CAZcWNR+lB/ZB36tA/RIyCopqeUOL/P8cfeJSX+3k1yghKm4HeY6VaWJSALiiP";
+ sout << "/0nEgGc4dbVmspZGzaORpGOTqqGnHm7TiZmOJW0PFrCVY0N0VqHcSkTOZn6PtuQKFXtFHYEtpepB";
+ sout << "0d332dQEz4AxgF0qrhrviqEZpXTxG1bWE4uC3H0tDWIJhyJe16xF96vE3QN6z4P2+esYwjsfRdQp";
+ sout << "3CYcxcjagfWBngiEfRsseTLhVf1KsFIXgafpK36gWlZan3O/7uyqhOYyg3W7aJ20bvZWuZZ94Gls";
+ sout << "7VLJyTKS4LNKHcb9Y2PtX7nzosb7tL2hcYgGChva73FTDSuf56HcC5+KM+t/MXf+BPBr4FUsgN3M";
+ sout << "1QrQiRM4dZte2DfAvSaZmWQuqtsH2uOLimHU2AAsfZav6k+7Qb4mQs6zdRPTMPcd8cUv8MWf3aOW";
+ sout << "skdmBW3ausbs+iQFW1MUqr/4lwe4DUekn8nfniBCI129E9p0JLXLw9NrT+669orbCX3UOGc5UHTg";
+ sout << "0SvFTMRqw7Y9mw4NMLOMCn8/mINMrmmaTvtEwLhHSOTwYjRCJ1A353ENYX3T551pA3QS1CrWfIrI";
+ sout << "apV92iQiSH5A/rzOjhq3BDCU8HIXihgVCsR/w5Kcy7nCkI4pvbsIjPNivZsDhx8H+IKNTUNpSRt/";
+ sout << "CPzh885kcdofM5k/nqVvvJhQzhpk18+IRuxTp5Dqa62MYZwgEgURS1lVVzWQX86wJl6/4cbIVEcb";
+ sout << "1dowX6FrycYfy81NkQdF06egVOZVYCaZdW7f1W4Xn4l7PWb6pb1HRb3/0Wu8YuYjm8pSY/wA70cJ";
+ sout << "G141hphqUxQeblVInhyk40zcaroWoWWGTkiHUnAp2cQgjZJRHbrfoJgF5UVmdDJrwdwiMREmR7zZ";
+ sout << "uSS4xTZ85fXMH1NGcOAu3tFt97bY2uh/Nzk+XucsND69TUp4s7RBmiahbRw8XQa5izZMRPgYkVss";
+ sout << "A4ggkTd9wiuotgjMtjd+NsitkGhuoQiQHLW0R0PVPEfG5mtdlp6/PlqCsTum8bRuxivsPkCcjEB5";
+ sout << "9gdgEeSfSOSRTe0hnm/gDQXs4BnFeyG/YzHm1pHHIP2VHkr3GlWVu0w7tLPXFLXF8kRESZgo1Np5";
+ sout << "WFYh5OmLS+J3xhJMfzNjUJ0bzfN5BnlQMDT8vttfAovg4IO7zFTexfLBxgKLWJZ2aQnoPeoBIJFk";
+ sout << "uPBxTjUY29133ygl2fAylnqP9zIipFbgkLEw0NrfKrfKIPK2Kf9VI1+adXQlsHo0vsI7l/uWbW5n";
+ sout << "XMP598ROkdbP5aSkpXPdld++HGzYtfD76teXubxsFcCpmm3OSbcLo2vAOMDQdpfFTOLM0ig1VhC6";
+ sout << "tCC32GUIWix0VLb2dZCdidXJXu40BNL+ASvXUAipVBD+qBL0uSHZXjns8B6XFD88zrW+reKH9ZO/";
+ sout << "gRZI+FH7eqxkZlu5ChFHJeZVSK1MwHk1ikQxAe8jVXTzzVAMcK3EVGvH/l0RKX/SJvU58t0SuHOr";
+ sout << "HVule85Z1pwxt5DWS5IYXibDgxOWfrcKp1iTXhhKsmOmqLA4M8iGbtJ59RjGp3ndoc8k5lgx0YXO";
+ sout << "CZKSNjMhF/swTPJF3fvVwzBU9vvWe2HCUTejmLi+bivqW3n+qtHiI1y9VWPY9B90N9OalQHJD2OB";
+ sout << "myHYFjs7AVdwFYoQZIBFtaKpozISrGZBa47RfYaS3zH2aNzYnT6zDsgKJYdA8skR/tNgMQPdgEtD";
+ sout << "d/NZPBXg0YnjIJNX9z3kcvXuwNnpv/ZafofCuHJQDRlWe2hvfHXZJf22IWrkeFjTGAzVl7bGry9d";
+ sout << "zaaFJKt73ogJ4IM36xHJ4pgGuP+daAnif+2sdHoEavTV6zvkV8T2UXgdP5LmPVnoz0zue+DBTt9q";
+ sout << "Lco7T8RhIgRDQL3K/PB23bc/gYdGKh2ISHY1klB2pUWu30sBa+g1ap7cTFIC5eDQ0L19jaJNlLWf";
+ sout << "MQR0Tshn9dRsznAfmjIm4RvhMvf97Wer3t9CjFOOaQfSXxdPQ0KpnxeyfyuBpRYScUX2cCEkIXjA";
+ sout << "IiOHn7aTJ22O3S7vZ2a4jVygPosfhvd/nNfWhAukM9UfVIFi/kG8CXvuPjKhn77k1ddfIbsp7Q2L";
+ sout << "Na7xoC6HiPus1PI5kd/p1NkclRSCiMBCl2qQfHqVIioK25jAjHl8SxsiPpkGDgiBJTO+HR39syAt";
+ sout << "iyjvIv5KH/yAzl+w13S68pDbAGZ7U8hDI/NKlhgVycVmkfKZzGB/9wzJVemHhvosb5tY/P9ZHqgB";
+ sout << "HuIUgdxusvm5IyzBBl0DbaMxwdHevHQgJ7NmxtvGDDofkHQ8JQXkr+N7+Oswa8KAEdJc2t9fST82";
+ sout << "X2NPa6s0wEQ9PvZhUQd45UcFtXCBbkCoisIEA4X5Af1xG7gqwo/PqEW3g0hYwmbWDgNTnEHaHSsI";
+ sout << "4k18Jcem/+3cNZx6zFKEZ6u6/JlvGJsNjvoZQlRluwF64lvxtmvxwD7ak+UmFc9Ll+cYhiHhdy7h";
+ sout << "e2RAF5HLTVf4BZkUs+NI4h4c5z+pU4dxM4JOVNQ5OoBisyMeFbFDlBd2x6wRMtiREIF4h/yDm5/g";
+ sout << "QIQN6sXCt4zrBLoPvFBvC7UvVPaLEpW4Ay7yeYoFhTUcZ7G11HdyJHIrWhXj8mazBuT33wiXPXD0";
+ sout << "fgDO2G5kVrYv837OEagI68UWoYwHFVQyKy7rXAqnJOH6rUOiSeXwrAs0VXCOo/iWz0+H7Vgkjvo3";
+ sout << "a+RGXFU0PakfqhrS4elT76gTx27rIG2VuKjWnNWi91g2MnD9dwS4xGNu0nB/HiOTDKKfzQ3uoE/D";
+ sout << "oP/Fn1wFPtBDC6935HgjJA30RBpTT4lZtqaVhy4F/pFfqFBL/RUU0jmlG8WvHSF28NkXMJFxqJqr";
+ sout << "FdZj/ImUTWIBtAOXIMgu8OQ9LXpd6kyh7dpy2lLG5ik8X18yJnJS6WZnRAx6nX3Egg7zi+jq6P+B";
+ sout << "AGQecbLJ+ngHGqQd6fP8hOU7ujU4yyAEKX3CWnvjK3DKN+fH9enZ3VHAgBS6uy4BicQeqNJx0D2k";
+ sout << "ab3+B1WPW98ieaG+G6TjyFGNFwqA7ComKR7X745Px3MW0qoQXJVK6FkV1x5zOjrI09hc3G/NqzC8";
+ sout << "5hR5EA4ebicmVwYgc2rftfhZ9FY8vTslWklZHaZosv3/QsoFuoDNMA82bg0/tDUo7G47EAETe7lj";
+ sout << "7MRS+dqjw9rMDlry6erLVu9vOvz5gVxhNM2+Q2VKx2f2ugkn8/VGq3ufEEoEkKtnLI0aJuli0ZFM";
+ sout << "yEDEHqombS/JdKacUNhM/zevqMA+F63pABb/wTCYhNIZWZM5SRCoYiduOSQYs7eR4ex+Y9a2T2cc";
+ sout << "UyRYpcFRySR4XB1Nf5GVdBmuuVqvg7wuRwCKXDkxcesKSrg7HLlUT9xaIReIHoFsZj5w3RcZvSCT";
+ sout << "q+d61nOKgP5ym2c5BwxytP42yQ3SisRyykroBgipHm7QmrJnrsw0OZvAkFh/bo/mhstLDPKkoA89";
+ sout << "Dg/kr5VTRG7Id11Qb6CKvSb9hq0mtnGAGnYa7XTmmojYhVXeFF1aj7E94bUVjtl1rJHlpNSBRl8U";
+ sout << "LffnXik6Ju9PK+gw/WXTgcyOM0Gmw91IqJsBiNVmLMB8VPpunL+J27aoPVdSxIpR1H5PjWhWyqNz";
+ sout << "jDTlxcIKpM0PBs7PyNJ3a0QSQYQrGDYI9VkTQJpxwYhPOAWdIjTKDBfx/wl5UyxbcAzD2e4mXkGq";
+ sout << "535EjhmfhG0SRvXP4heNN4Vzl7KMDjO9VNMB4/L7500528rL2b1wC81igCwNWpUL7Pn8/XTK8jRz";
+ sout << "KCcb92liptbbjGVb1V9hh6o2aR5l6ZQtS2MHlkZG5rucr+NT/SFYoVKB93lLnq2CbGPlzZZdPHLv";
+ sout << "+RG2+80boVipNp8XD53hLWtOvWMQGWYwVWSqQ6jrO24yXdRjZCn+F7mn9mP/8U2dEi546WLWvaOi";
+ sout << "RzQJdz+5IcL8RBPWCcUqR2VSyPIT5Fzw4D2+W6G3cvc7mWgiEjT9Vzjpeo5D5rWM3pAPl/kI+gMg";
+ sout << "42h/H45O4HY7CTb1Isa5GyX4L9vPfLXyfqU0N/8v8/UvQyB8TyKaDMedoiA7+hdWtXRMJ3cFxepR";
+ sout << "WqHTCK3o0Lhh29qPjMh7hKtC4xNVL5/Cz9KmIu8ZYAC4aWXX0g7wzBb9o/Xb6mJ+ZEDmtI25DiFy";
+ sout << "McNqoVFxoxYhTrUYNOaAqJYVTOauV8N3Mv9WiR+s00mjLAIEOrcExjt6WrCUVAI6XCqsWPibNFTM";
+ sout << "tSBPtg/w6jDsX9H13vPzwv2Wt1gGzejC2nuqaEVXlr1UKExiXjQGAJyiw13PjRHSOD5zUWOhkPBr";
+ sout << "h7lzK/HcCTA6S8PzM3yIPWYUtlzfFT1a1gIoBvO6dx3EbGGRYcePQ+xCln/07SFsUXfbKu43/RIV";
+ sout << "LUXeP19XnLIFDk9NRlK577rEfl3rZN2TkeFXMESPA2E3WtUXLw/rQxAi+E9n4pHJ1dfLJDrXcuH4";
+ sout << "iY3QbOnJj/ay/dY0eDCjoR8XBlbvqDCvuKDZd/SPNLseIwXXTSFxnRBONagPDgQTwM6s0bqn99DJ";
+ sout << "CEVWRIrkCH9o44FsWgF/gZEcEwFbZifnkjRSbiYPRn+euaJrvU8t+x86gbdywBJg/NdAPdn7gmyR";
+ sout << "V0mxYX035zo6xzjvYlOVP8cDewPK76P5sFFM/00tE134eK0jCmw2uMJ+AX2mEm3RhtmUdShzk8Y1";
+ sout << "lVbWyEn8VGdiZmsaB9SbeWAX4vj763YnH8pogS5AfRtR7+Z/B8dB7wBowpfs0Kvtm6LzPAqpyUYI";
+ sout << "OqvwEK1/f+fwq70lm3K0jFwda0ycYlBWhmGhpuFWYh60ictb+b19JLmdT59NrIme9aHqB8+SrEnN";
+ sout << "C5eh/Ee6eAU7fcsai415NdkTphEqSj/0IJNNqpoAm6WYXyNcFeXDma7VCa/6mXZFTHi18+qWqdYa";
+ sout << "FjFmYunmwGz5eEXqWwZrN0RQlZLzGTjAkNuVnrDhUqfPtTQ3Jg9Q/g514VBc/e1f4tcmY4lBDeGZ";
+ sout << "yAfWnvwG76t3EfpMmnfgTyHXf+cf2g6OUsha7eTpLNh6DUT6TBPtMElNMOWMDlCRrNUAvUGaiDCk";
+ sout << "dkZ0LcNkPZmrS9YTvzVVXU/DfKQk2Lc4kJw08UoAFRTXRPA3reX47HbrXeo7bU4PAhA6Flkzays6";
+ sout << "NWx7RVOmdSvo0Jc0Ec2rzSfShigqkmjXo/CSxvmuemnMZaBJb6o+R3iVOdMpCsw5+PwCtlG9ghH+";
+ sout << "egfrWDE2MA4hENDmgbONs1XrTLrEWl26avRL15OcLNfjb4Q85Db+r+8Q8ApGlnw3Sk76VrAoqY5W";
+ sout << "RIsD0jH/CD5VBtJnFNU+qFtYVNp3kBNKcNTP9RDpi9BxY0qZuEcRGtfGJMIAtqrjo9STIe7RxdfE";
+ sout << "wy30LUHKaHZSURLZgUYeY5h+OTOYDpTMjwm+xT9r2GyOOQSmNL+2QVXyrBq7BxtbTwDUOd+a/9SU";
+ sout << "rCqZfwqmA6OrPbLHkRkQHYJmtsR9k/+C3R4bSl28PQDEf4jfiA2ztVZ3Ot3FLHooVKF5Vk8kIVsp";
+ sout << "ZnbVJA52R4CLnmUyLsqqwDaUJYVfiOQ3zuqR8lDzhNDKYuTzkS2g+X9b91wF58wA+Mg/PyhUpWFs";
+ sout << "iYsELBbYlTW6eZYdDsIlI2x7khrgJuralZPQpwt2+3cg4NThKyKTpCIDQee6stuiUXqh1NKE4F/k";
+ sout << "U2M6WifViFrXuKqIChGe9DSraR3UfTyjDm4l3uGX6oE2MWXRyigPCfkHCqunq34357V0Bds/Q341";
+ sout << "hA6uDgedbPjXqPj5MF7/bZUHvqnnwD22WkGXfzgCpGAwN+WOWFxUZL2adBjJoEC7hYwx+umcXQZa";
+ sout << "H+p04+MqMvyzyfnI8d20mZQaZef/nY6vdNy5rXna6ZyFJnKR757m/OUaL9wrvO5Wlk0nzhQYnexX";
+ sout << "KyOvQYBbd4Hu5tWQ518f2RtMZzm7ReM7Xr78leb7GVB0claKWD+7ptwWwFVn7eiphI0rnxmSY7Fm";
+ sout << "BABTROBL//vZp3c2VqmpdZbhgeSyWEvyZ8KUheldonXlpbm0OahvTVNZHafDX2aUph5zYHSzT8YE";
+ sout << "cvm85ivZA1XdQfLkQot0Y3H6rQKiW89VE3qNoU2IReWEYZYbTk2FHIgIEicG9LECWuoPaHMLf+tr";
+ sout << "8dwW4q37Bz4IpwhsPrlBmXDdlPJnWDvq/C5AWpDkgr5JEg2E1CL2QA0YsUqLMt3ogOyqhr8jvAFI";
+ sout << "mWgYRweFTHjCMSbsRHRKr/WRRljRHqIs4sV1X7wzsXXR9V0GwYnkwHWu7SekTSivJHoedUspgymz";
+ sout << "Xk/OkmO3hJlvow4lyNXI/QsF4ObXSF5xvUuUtNHy4EF54IRqc1chr8E073MVyHdYbK/LYDoHJ1Kn";
+ sout << "wx+A4+BWLmeVS5weTfUHronUZlRIv1YgrSzw8M3Uxf3tA9+Vi9Pzu29X2IeNt7DbiZzGzMZWkgeD";
+ sout << "ye0enuQfIDcZ/AV+5B2mG6uxjLnMFp3S5HxSQEfbCwyvpxHEwImq2NR+9/J4UcxXNsqe2k92Wz8H";
+ sout << "jPlROh4/lFbmzxE2kT49ro3ewacP+bvpRpWyXpuVQo/oMQ0NuljpBtIoig4csJHhp/cREvXVsp2O";
+ sout << "q3nF/kxgs6BrfptKS6RE5R+gz1s9WNJ0yJk2QkurEWQReloMW/w0ktcfENcFevr/HkOtTr28gP1i";
+ sout << "vMleP2QxZrdNad2iBLdpJRs9SH663kVa67RTfBczDeXTqQiFtCpI16Dr7GZWnBoYtT8d2AXtCsNQ";
+ sout << "wUfRuiVc23aXnOkgD5PtRmcp7WvBDczaGtBBmTbKSy4DhlWol+B8kUcajMRGnrKDC6j1a5mYtWkK";
+ sout << "XzGrqT7Bsg0Tffj61OjYU6nCDLPm2Mm2XLHmhf8Ud3WJWANjsiWQKFTNH7eZUhzL3XhK/Iet+dPk";
+ sout << "Y1h1o4tGWeZPgFm9lLIhXuT/SkLR+XAGNNgrbAMGSlyTP6d2EPn5WvAL8Txt+1EaMITteGO0uKwM";
+ sout << "X+vpdPJC8cocJroQOethNqV+vK7brbtz7NZsdgFov5VYoLNWpwQLigIcaDxORQBC5yMh1ygeDQjD";
+ sout << "exiUJVF6Y0blMVyLqBz5OwIO1ftDAxeirjE5CIwSAs/E/CslH8dTJAL6hmTJ5Ncw70pf0NvzXkJz";
+ sout << "GUFSaWzRNyrpqFh7vtXVSqy/YAPOmvbU6TxUR8I/1O5afmTKkETJQP90BiuzjYs1N84IYoVt17YD";
+ sout << "xPWkQWH4LxwDdA51WSGCIwE/svX3BMYH/0TB8EMdL3FwVpNMgjCVW4srZP0KShXr6oJ18/RVSe8e";
+ sout << "Ls5HmsOL5vbvGarpKLvt0IXmveJx1Lr7h+9FAOG/n0HHD9S3Ug1JrgweszosdeIho7KDV0tGCEZR";
+ sout << "lLtI0p0KkoNEc6Hcm7VZyMOx6d4aN12NuVL3l6DFtd4SCFwBrIgqyDIf6ZAb4xeH2RckVs/uTBQo";
+ sout << "tNCZQKMaxcgc3Eee/IOjDVzFhjW4lo33FHhTrzYzvCISWREM4eDr4/ce+UZ6E00IsClZI2h2HkiO";
+ sout << "3k6ydBh51dHWHf5CZPp6gU218nIU7usc+XttaZpsodclSmh+LpXGroacfKyUGTB+I7lf+yzzBrmk";
+ sout << "Eo+WFBgemz3Rbcyem+zZaRS8YXKi4NlUX5lS3heqA2ogFwFDbQJB42FBCEMi3WVr33GELllqWSqw";
+ sout << "FZPBsn6RrIcIwI0/ehYLjO01/iOmI2Z+YAg/DRuxYDI/eV49z/xv0fCU2Xs3vYv8q7ZH3fjwyusW";
+ sout << "R7lKhO5nksDaJ1OS3bf0u237OCQYoc7dB/BkC9BHX2eo2Ou4x3pKBn5HQnldPNqA8Om/tdCgQeVx";
+ sout << "izQq9KJefumUdCmn4ChiVdPs7SK1DfhBnyL2xVHck9bC4fb3MuWZYL2XgwMwQ4OOOtUY+/sX8z/U";
+ sout << "O1EM/l936v45oXW1LJJYQtIhBf/rlqxqc/3l2YtHR7Sa40RknaaukggYYQJB2PJqXNBs3kI3FwIO";
+ sout << "z1mMafTy12d1ROBlqBee7AF5QJ25fCM90GgxqRwuF8yrGyXwv9pbnqmLBC6tz5DjQcYWpwm6qric";
+ sout << "pQu/xFI6RKtMM+tq+ETci06pbHQIcQaTvvW6iX5tEJ1iMFJaCueVEx7Q0MOgNhux+338+BEe7Gwq";
+ sout << "5g8xr553nO6KTDIrcTAgkRnQIL29WN3AOoFl5JWtP4uo1/cKqKpN3JcAIW9yQTjKtVOfSvZgSRxg";
+ sout << "Eir9UhU9SgMrM0/cdNLZed3+P8U6AYn5LxZDNWIlFzmxCBRh0SuJJ+yXtPkqApgvzso7d60FhUPa";
+ sout << "H0LYx+zhjx7jPhABHXhu3lY5uzsMcRnndIEOx5+N27vFUwxJC+KmgOi31g6XcPIBqwaWV5YxdVf3";
+ sout << "7cYOF9FYgRaCwiE0OKGADJJA6mbTYv7s4tzcPtHEyg3XK4UuT3lJjaxjsl/VLXbWW3JKKXwOuYGQ";
+ sout << "w7ExwYxyfm0tbXgwlK5nctw7Xg9U5CJA09jh2RXOQku+PmwbtseSP1HbWFC0pOskZatQneNQ6jd7";
+ sout << "kcYZ49MxSI5q5+rZuEOke5NytoWvPIeCFsKMUC9HqsRPivSrczNf51yX0fIuSTXts/ch5fakqj8f";
+ sout << "cA3Ln8e3wASkQo6qtG3PP2tAyEOP3qj39DAZT9/a7e5m49vjCx+ui8Kq7tj5ehGuKhrteERJ02q2";
+ sout << "ObOq3gYffJ4LDddCB1HQlrK1vV0ZSDctR4QZbEhD5PC0ToKsarT+FKHCDN2GngDIh7zSmEkq1YHk";
+ sout << "VXzD3V406Uij54UzOXv+6riqiv139xC8jo5TntvqR7S7/CNQu7D2ix+l+WMlUbFWy3PT6pL+Al2g";
+ sout << "ydAx24ZbJ3PQLivg44d64nA87Ii+eDsrJNIFt7jitB7w8tDfpJRsCyGWDOkA81Y6Us8udhe5KDUR";
+ sout << "7jekzGrKuX5VfUGAQi5+Vmk9W4d6C8U9Ma/D4wHIWkU+M0ffMVEgJjETYNMDyefi1gcCwh+OMCgC";
+ sout << "wbPlRBhdqaUMFlMBv7ysGeYJUQ6VoDw/TFcNtDMZV7G5mMwtsMhfJQV32hxLDhscF8nsafoeWIhU";
+ sout << "9aoihwKib0H9EvT2gg3KyvAIwNM/EjpninJXaaMNGbEFJ/7gEAZ2W/jAJ/eGmkLjYvv7Hh2oAkYn";
+ sout << "5UOk7TabvEFSsoCdeTjR4YjQS18zZcz3Ta7ffCgp2htZaAyJwVeHbBBL6KoJAlsgKi1yVtP8HXY/";
+ sout << "lpulZcwuaKXjQJVcHS/RcgZsaam9WxNbYh+X/CBouHSoZ2/LbtimWXb0WbQE3xk2Ntq3j7FJArKl";
+ sout << "gLGaykVfPzTfXFPKF6BduSPmxLCmgLgeZU2rIYBC0UXMVtEtizbU/4MF/Qo+0jjbO2BZGb/6tyea";
+ sout << "PRE/tXWyjGHkcx2T0ZUYTTpPtWI4MQN2LxopbcmDkZszXvxH9MlWC1jkhv2xBnddamb3jc1d4qmP";
+ sout << "YetgmnmQzoozyXiu95T4s6V+H6gNZVhle4+QmwxOe+zNzZaLhn9GBiwfLDGfBqc2usa2FvFWpqke";
+ sout << "PGE7LZ5HN+A4XFJ7NrEva9PtoPUOCM3FU9gyGdzKUyyW2ilrPX6rAil7+ZF6KDAVYOmnQQCtp5Ls";
+ sout << "P8t9MlqGMM5uwELErplaidbu2fNVTXvySFbZkga6AYWWV+xQfekzRYoN3LSs+Y0OuUU86RBqyOe3";
+ sout << "+5F9fvPztVD8Zc/c1jfY7/3lgJMtgEO/+TDa+gDQAZoalgf/SGpvv66MuiABp8LALajU0a2k1DYT";
+ sout << "otIdBHlVLj/JKD2tfHDi0FevTo3l2N2do9NEx2tyMGoLQwuJplvP40vHr3aII9Acuim4GEqh8cXT";
+ sout << "XWT5KzFF5L8f+QK27m4lgotZVF1ujXVuQRT+CHFwv5LTiwHaZ/KFgVVY3gIkjCRoP0XRARMk5oIh";
+ sout << "F7G9FnoQM7oV8f2b4mTKVvqMMv3uWzd4zr0Tx7GTSgL5YfcJox9my7ibHgpBpBNtej8uX6MtzHXT";
+ sout << "SbBoG3lDrdZ4fhMx8d9+/oACjVThf3PPa0exxvme5s6T3GaTnZt2BgzLrSknMaRZ1Rx2/kT/1ecn";
+ sout << "GALfb5irpfKmQPLdGJv8EIAB/Br50DqHGC7DfCtZ1BshNTwgY77mI9cus0SGrgmPRXYuuEUVhFHf";
+ sout << "OfrUwDNX/L8+ztdvlgJ4NnE5tUnegEz0ApOyZWywtxDodl3AZoZfBV5ODhoHNsCyZCqD/HYZXHb+";
+ sout << "u3ZJ2PB7wcBbxSCf3mLSo8Kknq1d87o9YAQdyyNR0s2XvQO9IakDz3HfBFugDO0LQn0ewgxAFhMs";
+ sout << "O6hA4M0Q9NS3X+Dr/Roy7wn51BkJhrLvh9TSO0+VCa1EBxo3mI5vM0m+Aji9HwySVBcd9L5FcNIm";
+ sout << "PTCkaKyAd3DsqBIoHFhdLhauODisNzUFC32TZiz6VxSwxFnNvyrRQUDidLEwWWcp8Kmm//Q4nq1f";
+ sout << "F2Qq30IqkLC9nObGbVSjZb+n+EwqMlE8mJxqbgkh5TQ94V+zQY1ykiRa5a98cFA5MLazMniE0JXv";
+ sout << "YqUtmkvvbyFjbAZ+/RPUqG3eQxyvWxBSQSC/zpbG8rJ58ZUzuPkmmXe2z3PIzFpb2BMVi7U9t8Kh";
+ sout << "+ywyagKP5gLgUTqOS6Sz6p30dvtDsayhngpRtVfmOoJ+Px1iw8f+3o+JEdJGnXEcecaGIzABCOST";
+ sout << "mEC/TQQ5ecV4yEwN/cN695Nu76+cSiHR+7UyWNSVV1A9WJhsaW6NJl/YY7dZDlMYWCtlB8kUCfWr";
+ sout << "4bhrudo8xW5XeM1EtmmpONd8HeMN0479Y0KqMHYCnxYz+UnC7XZBJkweyFPa6TroZoqjaIZMEyRm";
+ sout << "YizuyvyaX1eYGM3Q58/qti2hIgrRovGl8VYnJWJbaU41OB2CY+R/TsuQo4EcqfvAUCgNxZwwmvkG";
+ sout << "brIAJ5+fOVfI4Fa6swJ2PulbJBXXS+fCrjCEvBAqg2j4GxupCDk92IF6cUH0dgmx1TBFgRTNDfzf";
+ sout << "ENcymno2Q8HKz6uvMenEfNDAsmV+giNq1jIY6gm09kMwtDk2hs1c0hqFuno7pOrkS1RFbd8onyGC";
+ sout << "qoN7GtGpMk6q82mLqrEc4chaIOF5UpPSH3sRa8tbiP44q8ItnbXnMlzox3ZhZqJN6QgrNWnQa3wD";
+ sout << "4qp2gUxGkFE4i242i/9xzvfYu8h461f7e3Es8rW0sB5PZ+FQ8SImlpE5vh+Q1qhXoNof+I0fGGb7";
+ sout << "fbTZmz4ZZrhCS6KgP6HBmbohfPXCKubfc8Q+F9PCaAN31Fp9ycTV8FAfXySbQzdShgnnhnddJl/T";
+ sout << "Vhx2PFPythxz7GEv62hwjLuZdxZq3nCykVHZ/yTE5yoLBaYUvsVVZn9rv1YD5+oSYeJZTbjR6Hif";
+ sout << "t7geARjzf6oe85YZdQvODK9SMHYo/QwdcPwwnSf4Dtmvvg5XnSG41s16X0n7o7g5ZgFr+HTLwMzR";
+ sout << "k1Zk3offnOMpkcpEXvM8OhFvFiXoN91CoCvin7f2gX7IWrNv0mSj92fbKzQ9Qa7N5cH1pldFuDcA";
+ sout << "kQ9wHFmz+rrnzgEl03Br7NqKsB6koxE5YEzHkNLHmg+pTOyu+yDREbopMHBi+jTUReVS+5fKFMLY";
+ sout << "LEB6T96s1i+ygPVSt/4H/DQsU+0caW2dCfrm0onr04auJcU+oBYjbR3OlM+6SD9UaUjt2PtNEQ8j";
+ sout << "fZocgKhSxt855hYm0y5qvfBLoDaADhziF6Exuh7YM/G6ywKMi04Ab7qecuut0c2bWHXYrQPhmmCE";
+ sout << "NXCiCX64ZIGfsFR1pf8eCChyiA/GW+rQvxv23bl1RxbDFk2ZO9o5581NjEPjQKJ8AyE57W0bZjol";
+ sout << "8s4S1ZNwMF46UgZEJwCPLuqJEJvZyLBwahHwC1B7dgoDdZ4hjaJdH+jYEZusfsvV8ZC7hS5e8V6U";
+ sout << "rkOARJZwmYES8JJlVtfbfpCqBXZ0gTPuBKt2yYIlGTqGwGu+biiNMVt3spu2Ov6MlVoxY37UZax8";
+ sout << "B0Bl1cbrqIBC3LN3uWsmnR99chFqEmhi1cRQR7+3DtD9gpkt1Uo0STq0YCgbVmr4azLoJ3kcEdRa";
+ sout << "cfZ4LQ+7zlFTAqOhh/kBPsL51ZLQ83ipFUgz+c1f71aM0eCFharOyhIPy6YAAPHpm0FdMv9Q53V2";
+ sout << "p7flk/h/s2eqe6M2jbtud3bxxUdercWJvaioQZWMkNSXcP4jKaRN88np7Xc44pjUnPgO+AYc3Kl0";
+ sout << "2DXf19IXDjAegihkFzqvAKV4zttcsT5gi8wN3M+zdoG1TcnG5gJfXumw+Bu5giUudQ9P0mUZPc4t";
+ sout << "vMWo3r7ajnqeiUaN3PZaItRimu7vzun5+qOEIT5nxHNDuwKVN8ZTigvOrFRUxYkyzoUzUsOPUdla";
+ sout << "nE3MUbT1bj2cs6ih5jqtbqgutSNQSGupC8Xve/W8tiQcrfRPBMdtvJIp6Iz+Z7g/SXpunCoFaBSa";
+ sout << "mAuCHeTbUrCQ5xbQ5k8G5qIYRglHRADsWO/oM+vWtc1jEEBZ9UwLdFZxIdj1ytNnTwIZWstliRQw";
+ sout << "hwNPyxfJQwng2WNJFG9LPvzxxnpCwYG4Fk75wPd05j14gLMcERPyWpwYurfqeQG673BTCuBeF/DV";
+ sout << "zZHI0nIV9CLhSSbuo46Vkb1+FaZwwpsD5sBg4+EwKcglHUzrqiskq83hIryVcvSKRNdb2R8VDLCa";
+ sout << "pPblCQx811bg68OJ9UHqbw8FPcSu8M5zbahisVzxszdjmyZ7fqWR1twCBsJu+kLjNEsgC1cmC1ca";
+ sout << "OLyaDYs0sad92UjkYqqBQijASoQgafhSPuJ7sEoXvLXrPP9GWb1DV6S6mD1ZBRicwlo7pWb6LREQ";
+ sout << "HFXvVkX4/seCn3O76cNuqenVr3EtIOhKC+lKDCBbnKZ2ggyoIqgO5BIteeFcs5rz4yuQRM+lB7gT";
+ sout << "Nvdwa7iYLDOsYOrtBpNMD5KO95LQw6m/Okd9RKQ9yKTcV3jQFGfAlYcvXt3g7tfz7Zyj8ED/thPO";
+ sout << "pcrsgQE1E99T07ozCCDNRaPc+DhgHyyIEHNNxmgGUWziRBKqhsnei3nPCg123K00ifq005MgO5wF";
+ sout << "oHUXDMjbWBF+ELX8MWGaMh+a5OWq50xs6wPw1WhPug4vnEJa1j1rJxeGRSa3SsELXmeY2QoAX/9b";
+ sout << "S8HeDJq2O9OdVy8bMKdfxxvxL+Q/5m35EGdExoccnU+LcpGc1fIKqLfJn5oKPS6BlNKZnMpkyV+8";
+ sout << "z6VQA7zKvYP/Coj8NSmZkBJm4SoUkTBt0hZ9KEXEAAzR7Xf7BzR710lXcFdRCYWxz4KEjsu8z1e6";
+ sout << "NcjhBpcgvFTnowMThMLtaWotUJ4KnGU2ys5wMNh9E94+pZNkEQ8xShAEt9le6/gSLLuXpssWJIZE";
+ sout << "JR6PlTgQtFoSwYfx4+/iqbMVtbH7f6tBpLbnrIzQXRUY07Uw13p7kblM28k91GxRDqbu0FR3cWnI";
+ sout << "nTZBplbiYcOM/SlFJFfThLbgAtTO0RWgJkSr0n0n4duBka7ZzfUiMXjyAHvBAQTGkHIb+YidnC0w";
+ sout << "eY86YsFjo3jTqJpLYuciQLJ/ZUI1g9v681JJVICEw4CxgguvGJJNOgQDh2CEZI3AGLlbm81ftzRs";
+ sout << "MpKPE/RoI5kzoo43tcr3WCUImlVdrNak6K4gaNXH4J8aui6MN3kmZ0LWS0l/7Of/bVs56dEOozSm";
+ sout << "PcGfVLmvLExSbXU90kmknOftcUcTG6J9r1S6PR6qaF2w/G1t3wx4mdberrM9hOIMoXQhXEFShNHn";
+ sout << "uZ/DLE6t4Dl1pvD707gtoY6UpHpv9oV5xVVPVnq7smExwf3gQsuiHnDDGrPlz+DmLnkZ1XYCaKSf";
+ sout << "EQleJCCuZApUfEl9F0EjFjSDBgPoPpwQ4mhDXFjhYjYhdWmOfANCDsfXm2PUoBm4JJNI5tV6fa1+";
+ sout << "S+yaFaIw3TiqWvaOGIA0GyicfP4SaHodizLUVo+yfYJhycZrWqclh3OkuT2SVndIDLs+8Xc2oSBY";
+ sout << "uJr8P3Rm+Du02X7X3D94XvNI+tl22SU9CMMfG5E2kReYpkvtVsejWaeg7QKil8gyRLtkH1kIb4Bl";
+ sout << "Nu1oeg25ufTr+pyX5nqCNqeStAsIMH6ynVRecIVr0R+AF0+winGiFVVoeedIQzWc2pe1nD1F8LwD";
+ sout << "H1NTNdcm8ukf6AA3O4pl1uvLPIZIwC3Qk+AW3u/Xwhjq6AxEpkCGoLrUztB4Xj+uzoBpQ2ka3Lu8";
+ sout << "K5XooiNcanVGjWC+0/IBmB9gioabhNXit99vFEdTQCsZzlDY8D9le1IZ5lXIGgSZR5LMPOAzSaTA";
+ sout << "b7vXFps/tixvj3n9wxtOTAMREIVaGkJgunwSNST9lQ3tcveAJWnowLztcn861ystgFGuDgW65xgF";
+ sout << "Fo1EHK/Sq38jEw9frNatPCJK3eIJRih53VdvB9A+viAX9IoSK+KH70Jo25Bcrv1c+6Gr2H/Rhv4o";
+ sout << "quxm7d9XpPAA5dGZb20fFloxerWOvWO+QHghXZjyD81o/1hhR6ZkOiK0trPkvDLcLOfPl++YIjPP";
+ sout << "vI1hZ7uTmn6qlN20FCN6P5H2YVJBnhSZrKx8gFkULKgXdT5SpVSR194vK5cwoOUY0V5b4EdSAkWa";
+ sout << "3cQm6ibiokDAktBSDS8vpitwP0sLI5K/7QfLuXjihWXONCSf+RobFWqWPXzYW51q43+MP7xb0gbM";
+ sout << "39I93rn8s5XNNwVdO/IXzd/HmMHiElOr6xEHMzCUErUuQQQ59NbJ/N/iVuv2c4JZ7ppFWcz1ZQRz";
+ sout << "dgtOzq49kBMIxksFoFuEWuRV8LEqBbNBQkW2zFNNWpV59mqwpikzNsraFEvUPBUJKz5JNSuXVLF5";
+ sout << "vq+hYqVmhd3UC5RU50BWbQgMgRm1Bw98ScOPxfMMqBFPiMntTDcd2UuFptT066sbXg8slg/itSSO";
+ sout << "/CQmixHP0Km3Lz7K5OqVdIr8GYAeP1M3pjxZCEsaLGF8YGrPnIjMyjyFxFhyteYsgU6330hc0TGN";
+ sout << "LKJw1F8lbEyGt5lujwnumUFn4Wdg+gtd91ewchV+uCkT8Atrrd5gvCc2AEWxRceQHrUTuW+u6UCf";
+ sout << "LZvF/BC/8PJszVQsoy6dHruw4E7UwlX+TCUvP3G/wMeDf9SBsgnn/KsRr0aQmzuIJpWNNULXGyXZ";
+ sout << "0OHSxs3fJYq/X1FB2EgjJMXrAczU6aKHAj1AbCUbBIaJcB2umY4YAQqt4x2q/iOJDbJ+k3K28+Ry";
+ sout << "vW5w7JdiinwwjNpQkF5OO+4bQ+rqPzJare1PT1eBn5OxRpVlZYpn6AZIt6/La8Ir2iWtAdUu9Krm";
+ sout << "9F4qrhoykAlcXnI2GZi+U3I5ftBgWfDVodX8hxzyS5Soo2LFCwmG4aYayxu+RigV4YRyuXdHe3vD";
+ sout << "L6Vs2qmyKQ4JE+IbtJQq9jkC2PrFMwfP3nOstccqgnlGMrKMmN/UX2R0emMXsImuMGTn+rKctWJn";
+ sout << "sXIysVIEupHUUD4qC1RuZGvdGk/LsV7mQO4SkPJWD1JjUADB+8Cw2nkMVCxn8gFTf1JU+ovxgHl6";
+ sout << "nMrlYnN46eRJSFIFvZP0S0gZf18aJaxkgZt/YPxm7m8ZT6NLqLE54+st+e+sWHlP+9VeqA5u4aLh";
+ sout << "XEJGaJozsDZxkNfKN9lGMkacM7LZzR3lhjX1o7cjZ9D5l4m40M7DS8sbMg0/5wXn3WxlPimZFLvC";
+ sout << "IBLsEAoUM+oWB1xRq58SAPIAgic3jETi8vYhDR3vlN8ubWsXyTqOJKIXJ8O9Tz+aKkgLq1P/1y21";
+ sout << "Py1VM0VW3XRnnzs1mN1NLO0A9Ot7NMluCUEJekVYi3jrhhLRG/tGYkLUiYcUq3ozoUZjBI3nX+g3";
+ sout << "4d9uDIXJT7GH+ovsh1q5Rw7H9bijVF/qT2TNJH0DoI25kuplXnihc1+PAMp9XseXwA3lPZhxMVLU";
+ sout << "ftPlCMbdP/LUPrfoq2MaAQ9oKokaK99YSEumcKzgcNyx6wUrB7ILTO3zqFDPsxTO6pxaPgOXC+yM";
+ sout << "mLm+LXoEFMsByb+F8xXsrtCFDwNBR8XU92UgjMQLbvymvFE/fEeYMwOd5VOPcPeOwPs8Jwn6K1TK";
+ sout << "Gl/QNu8NQOgVK9rIISGbRdpmcLRW7JY4sp84q8CHBaZ+uVRBYzdFDw8VokNyWestrK06wQG7UOMa";
+ sout << "AZGqcttVsN0aePOk13IfTG1llEL2fczoNioOpCaVqzzUSU4wA0QBOa33h1VHOY2Py3x7PWI2QIJt";
+ sout << "v1IJSEmnpKgVzYjSnpt7kUlB2dYv7w3qwDeiE47NuK4ywMZAUuONlL/ZMpR3+I/EIu3VJRo0RknW";
+ sout << "jKkYUjOdo0FrGhQWjuISkZB8prJxAFrTDxI4aiHSnfMB0q+L2LJU8LMLLI3Ok2NIxWmdIJaFKckH";
+ sout << "f3JUFpn784RDpdHlIUU5EOzfqal9flGrMlz03gMrvUiq+9ZtJR44EXoUwLLdAWN+FO8a9Rrj3PUw";
+ sout << "E9q2KZ9M6BqLc/YdZFHlpZYOiIlE5x34OqkVqcQOmlWmPVP4bSwbXwmk6oWjruRxmaA2YQ/sLIHm";
+ sout << "XBZZQdMe+i8TqEjte84I5EebJRi0oiAzTFUlf89NWonCzSnzdFQPOiB6Ik6dX8o4607SkyijR81F";
+ sout << "5ZiusAkWNx6zksMfeZzCxnS93TqiOSmFVrbxMx6BkArmc1IWPCFu0OAQYC2af0KciOEndo00R0xf";
+ sout << "wQAtJyKyfM68kaHebkMhRWDwk9HEwznfn52y6evH3n7NnDttgw8PnphD4/l19wWJdZ2l4tcmY4eX";
+ sout << "ly2zjWCcaIsx4QMz1gEGc7lvMq0J6RbzsoH4oa6Zdu6kaCN0Fy246vkDDBaW4n/fZt+RvT9YbiqO";
+ sout << "fzMs1dcXd9mCOz2yB0/i+XzzG7awbdiR7bafVk+7/8H4y3qVyfGrhilzYSP7qEjF0ossijkvU1bo";
+ sout << "OblpwSw4RLL1npCLQDtXh5ugQGcsEesLEBq2uFpxPTSLKWgqTr3P/XL+EwNJgkWDIXdMgG0Ft52X";
+ sout << "bNQpevk0mqynVmh3HiVnAZ/mUaRXZG40kqcfQzxtH1gvf41S71sDXT+EOQDdwxygNV9BhOfyA4eO";
+ sout << "gNeFlhuCPE3cAzKDUrE3ITYl1++4JIpP3BAcsT03GlBlSL4N1nXi8YuTCliEmN3TRF6FLJ5DkSN6";
+ sout << "tCk/xGsBkaZlQXjj2P4FV/BdS//Uqh6EBRnRMoRH7LgLl0QJxJrj9RxDdRwW91fNOAJjg53IIBX/";
+ sout << "XCsZr09+jke+Ci0r1NWHRhmF6qNY27CVKs3Wub6dYS/GifyrN++qpK4aTOLkkXQv5v6Zhs48qFAZ";
+ sout << "4XORp6I1s7vCQ/yU6R4UeBNoRAT4naUFGc3gvmszzFb0Cp+li0RM9wgW25miE5Y1PMpne0nBvV8a";
+ sout << "k1RYRJA1ZyOEQbmFxLMqanwITgwxZT7tK+nSITSmHAo7/m9IaopgYOumkPGdy/KsrjkR2gTq0XxL";
+ sout << "OSjhTS/sOrZhSV0HJ4pBgjZMIAlEwv6k9XsefXyREo+rudBS9CiXW6TzIViAWCMc/hQo5zo5tUeA";
+ sout << "vmTrWxkRgUmdsj6FGS21aQFPADj3oPrAXsTiH1ZcwpyzJrT42+4P/SLFw/h0NBYRozI01AVVaxQ8";
+ sout << "sbJ8VUFliiZnB++Mh46aAF1PzomK1NiHC7n+llPAdEqpqiWqVH5YK7CpKdmCPnTUuOb4a7A2Ylow";
+ sout << "rodSKCOX4VHwxzARShJxnQz+FIfVIzpiCFs9ZJKM6x89ybXTKVRrmwTbTAXRbihUKa9d6t1vmoDP";
+ sout << "07z8WyFDoeWoSmK4XZfegZeZMIXF3hrDPCYcZEluxujpPG3ZPPw+/pGjBOx3/QFJTQdeHBz/PQ/G";
+ sout << "wM7DkD3NM3fVApST0LbPQFKxlEXx9tCL3GBul7YnnUmF+N3OIs8vDoJQqNqVi5ZQnGt2uFLOYRo/";
+ sout << "67XsBQv/ZwYNC88VlJykWEbbVDN3rY6i1JLr/9WRygS6aqF2JJBTgLwXpjfvaq7Ygttm3WiT/dxj";
+ sout << "QtLg9L/y3Q+LoLOKZ5syIdk1+oQhYnSqUgTNEuw4dcTXj3v2lHm5oxzst2dtierbluemr//Bnclq";
+ sout << "8FdWTT2JC9N/YI+CZi6K1lz1Bq2yblxL+BC9UE3rtsq81Rm8kJxSIVQ9HsS2ligaxnEfCIPNvMhV";
+ sout << "orLVKxoWuOfhJt++IepP8kap2JZtLdmdZSr9/6suSkInlANS9SkwaNWB+EiTUm5cfu7s90V+m4J2";
+ sout << "SEiAMESJKLywmeQi/JcV+8lEiThyYTxhm5huK2dUGPWnzKK0HsdOuqRs0RnHn8SQIgLndjyTzcV9";
+ sout << "2mpVtjcwCRKVBFiPdyeqyVI1j7srk8WrRc0GozuUkSBMVk5KHQa3/qZqiV1w9KN7eMRaYiT46bP4";
+ sout << "2dpdvqNrms7/qmyv0q+Bxp7JYv8bYqlQPaXgHxS78I7HqFsLvP4kZVWxmhCXIoIObX304w3KRWH3";
+ sout << "WlQEj/llet0DPOhF6EPfcuMeSvH1UhucSS2pnYwo/P0BvMYvj6gcbjTowWMk3vmjiZV6a1xq0Nm3";
+ sout << "sgrrS/RqYJeGPM6THw80LulT4e0aflcyFN3QQ+C01vGHjpaw0OwIW/iWAMQolVwcZznJPq+6Aya/";
+ sout << "aYbiphirZC6fZtUMIXiKRrUxQDyfgI0y8XzJgf28AxgdyCP9Dashm9uelYMGU0Jglf/iwa2K1pkL";
+ sout << "FTS7zfGJmz66x2+7agbLuCHd7AQhZb/rfOwauykZXpCE9ooLMLgW+BlKpqgVynLrt+b0I3KqQZy9";
+ sout << "KkFxFwC9cumQJPtKGKUigDLV6+orLNwgtVmNal9RFHppuj95RWG3XgwNLLGoE+K5TiLhLY1na4AF";
+ sout << "1NXsHL3pFjM2VHEU6cOqVLqhEWwGhrmBoxIGrbmT//lUiGwSLcmDHCvv3Qig43HbRmuQWW90qFzc";
+ sout << "Sp7atxofyP8SqmW2aRw2SRH/nvCTTpKw4s9bhDEWCukZKo+bpfp1ti0P4DzSiUxTy9uSjWs3+7Cf";
+ sout << "nWsTsoUxU+YdC6rDJYq2TAcn3Zl6YBU9XZl3YLckRqpUGUQ0IkF0Fc0bb5kQCJH8h9qc0Olnqa6s";
+ sout << "sjHyjTwkzZWFmuvJPfxEZzsZZFu9qWdk/Pzx/46e3J7/VaQhtyYidoCdV+wtN9IOuxnbyIV0A+vx";
+ sout << "ZAjS2f7y5aPxzc2XAmlsmNLBqrRfMm9dvxN+gYVD3UcWgzTVxfIXRyZqmdAvuHoTq7BV8FYPMIqb";
+ sout << "SKpCcdEztosRjzsCxioB0JPCGxMhHgPU4FCNBy24fa5RIz3rZkR7xFdTNw4LTkJ15BRKMsBxHaiS";
+ sout << "E30FSLCvFvBVlUITI7aOrvc8+2IgX5cXLxHQeqaSTyKm76ioORshYyRkp8K1Ao0hZXCW8rks9fxm";
+ sout << "FeRen1ureYGmkoKJUWq92iAPwNsFvvxTm4IprFoUDn4s56Ung85NTvcEJEyc4yVyevqb7hh4axAu";
+ sout << "h9N2uFPhx9YPLbOHA9m84PWgt/IPOfdwhlylIL50MZ9qitPPG02an2gjybEvicqckLYZt4cY1Utu";
+ sout << "pslZJgCrFZ7ITUBixaxhkuXW1wxwkF8RGosHdfB5zycnXkvXhdV41+QDvhZ1tR2E10jwqjVR/7W+";
+ sout << "suB35kUFqrlvWs/bBTQHwwPdNmGiNFi+uq6hIBQCm/J0MbCaYcPPfhtO3LIsd2KeiTz5oV8AqBVu";
+ sout << "NPQav6McdWJjg6dBQw4G+raxkUf9Qf9TPQspS2Ll/Kmt8WwbqiRHutPOaxaBBeE8Igx6BT5kIASc";
+ sout << "OsAnqIzudE0cclrkxN4aeTglqmwwq/eeZUSOzkv6Ge0xtpYDgqfGbe53fxtYvItYYExMgy5HdojY";
+ sout << "1zDaNXsfHTebUUjPnsXitzIQQpRobTCe5ttWhfeeXOvE+MPLRt19KxSrFKHaRlQtjYsmxdg78zc3";
+ sout << "8ceNlnoQbh/puSKv+QWnnELfDnmgVZnU7lIKs2xM2B+pNy4Q2YX1wmIqlC53rsTzU/Ik7Db7Z59s";
+ sout << "i5QWZ0R1ll6rdkCsUCTkK8aqsF+mSvWNbMd269Nch0fSx1D1q2NN/4lHDDoOX1JjLPMxVNbdxLMt";
+ sout << "gP8gpiQrFJWZRmTgYBW+biZPtiKgOztsDrZ+eolNWpsGnLTK9Xe11lZrS2fPton4L4jnr+R5AnHk";
+ sout << "F6tBVhSBRHtZdF/xNcCsSOKvTOZJCtftMO6XaKMfsI/yexeLz41D5qfMVSBHY4bIgL3Jgh952tlH";
+ sout << "2blbe7lcZxZN2XkUXXaVaX60XDTLgZUbGzqJo0+x/TJIAuCAlGF1XeF52jTY7X1veE63l423NPsf";
+ sout << "kCmiRcepJT1ihXQDd/VU37tJfeGZrN8FjJzx53PVp70r7eLJMedMbBUJBYmOofO2F+OU4Rfyab3Q";
+ sout << "Wy1ngiJXrfjF1DVqmKNfY79SQWWjlzST97ldjTICsyQdHWn8j+Bof6sL5k/3d/cT/sohhSSl0Bux";
+ sout << "GssAFa27VYsJYMjnriFZD13ZYmjuesSXyQ+EkUsHn56MB2xEwLRVEhq8ygDP2eO82PnsIXh+uVFe";
+ sout << "0ZQ+eGSiGZtqeL1aoePOA51s2eE4DxGBHWdSNr/7Nc43oUU0P7sV2lut9VmhgDleV+Du8IaZUMIb";
+ sout << "2MS/TkvMW5S/ebAoXaL68zmsU32tk3C0FG5oxbdZD1tcZ4yHNJ38HPSx39uOfRG+FrmDqomfs3k9";
+ sout << "pu6KaH/Be6Urzd7mr4EvxaXHz58hsJGVH31tOdo1jhg7i+6lCJAuJX01FWfnClLz0mSJnIrcPge4";
+ sout << "zQrhghg31+fLZw4FtNmiJJcU/hJPGhUZ5LDJWFHUi04u3+dfP1pYR4e+CEOyvlK2YXx8CHcvjWJ2";
+ sout << "kCW8TMN9I5arBUTXdFfOg/v/EH4cw54bmTgzjpn10HlKqfJn/mJPPbBGCCS+pdGmNIzItVME9IES";
+ sout << "Gco5a+DT/a/TThfkRVKhJShe7ccCC6uQPdU10P+DV2KPxOfKvar2ww3I+VZ3FKxhC0AcMZBbN0V/";
+ sout << "0tTOJVH4GxjH84c9Jr6niQJ/2w4G6tuTcJ2RLb+HnzfingM7J9VOr9C3IfG0jndHCww96Exjjk0j";
+ sout << "bE+Aq4H39i5pH3jEvdMspm1B3VY7EN2s+fC5jpmZpA5kRb6Vj+CYCbyEYNicyAJqCHdJLrtHDiGN";
+ sout << "fIhpfrkpc4+g5xHHel+nnJKgduqzs5eecM68vOtBTdmLa/nxUz/VBywqWNmmN0v1/g0Jlw3U+IGb";
+ sout << "bZOvv2Kg0FO5rJmnRIUsokPkcVxnfz9iyf+l1d2QIBmeR+YeNtLS3XUfIjHvHwSrZ3lJdfXNnnwr";
+ sout << "zZcKmHHNAoRLIhM/0ZYQXMbOvfEqUBZ0GCzMdwiEuNwVPtbDkgmcWjYR6ew4YjRZR9alIas8o+qF";
+ sout << "GkL+PT4LEHfASjQPnCZIhUISIpwzuf7Ii1I2Fk/v8jBbL47msvPsQcVH8JiFiWBGCBfSB8IfSSw8";
+ sout << "sgc4bo2hwHMIES5W+rXXoQMSkGFP6tUfd4GvS+8t60UyDUePQbBGj1lt3ETsinD0btygvMOqDhD6";
+ sout << "H+E5gRWBXOV6QrXtQSPG4ZSPlDvATmDdgLCBtQBBMTLMg36Kvia0W+Yyu0bFtFEJre9N/bmyyvRR";
+ sout << "P4tHn8uOTi3ry5PT2PQiJ9c9utByvU3ydcBhoIlcdQz0f1Rf8y0eyKj6qt1/OhnzVXuLrkqXLylC";
+ sout << "sthAWqyIavUltofLnsAP1uMHZ0pHeX6R1lwpIROxtRH6p/0+//OK4lIZx+7D8aMPIaefQAB6Cf3C";
+ sout << "jKnvyCpR3QEh9JLv22OghKwswHT50P/+z2XbyoQsRySXNTUbTsGJuYBazgG10fe6YHQOTrUDUp24";
+ sout << "6PS/wEco60z1OF49dVlYak9zGA6kTRndkjWDu49NFhJRY8sBd7TVxeiU34NVyjIDLrLlx7DcuZ6P";
+ sout << "1/XcA8czU9m2n6VDgkkAB6eaeL0XhB6XZfOrg9lY6R+xjVYo50Fg9A8AGen5T6+m0PeUr20e3WV1";
+ sout << "H4KANYIvY6+zN+Fe6K37VaO+CbDZfvFMzifvzTQWQ3kDDTWX/BLfFgLGZ6QBWrTwF7MNiQ6fCpG9";
+ sout << "BAYy9V+QWy4Iz9lqnp19J4Q+cqlIUDBp8b8vBNyrOjjLjAC7ezFUujvB/7RrKtnbaYmS8vYF1f9+";
+ sout << "mGqA78owo1zfad8DsvGEnr5J5mC8d10rXhVXXb/udkiU5iEhYPnSxRy62tgLbKHJOvW+R28r7lpc";
+ sout << "yDK1++NYKRpDIHYMuaZ13oDNdIoQD+d42Su0NsP0wECAoMmQKkWUaJyYUyFqzmQ2FJmvl6hDSbrs";
+ sout << "lDipexBL2U3slCAJbX/PyvE0KrBPBe+vT/w1Z5s2GBvSoGwmFlN/oa5I4TfcTA5W4ie1rBHUKqPr";
+ sout << "az/4un38eXTaF0Gfiw86pwlgJXWr+D9qvXQApo6KmaJhjKo+4/MaPw+iEQrU3IM44eaqq9exmiNF";
+ sout << "/SgBv76gC8hXFrucFFp5znYrl6ISQUedvk81jnI7ce0Up1jYsh8fpfp+0V54IUNKqxT9YlgfIkbG";
+ sout << "THuQQ3p2F3gLaflJWfZEo/lPtKc6RqVadizBP/oxl7q+zTriiyK1JWiJojr97xQcNpj5j1QF6uqK";
+ sout << "+7NQgfVevfP9FGVKOlIgvsqc9fnqWd7pzWKEt+kfjbpMGpr0XFERPld0+EqLts1kAj/ejnMymrXY";
+ sout << "LF9QJotJBy/QvCuoaxdptGBtrC+qJIuoYKTNTwW4A+kaf3RrNr4ohuLBcH8RRpHFU+3fEDW/kOx0";
+ sout << "fdRlEUJoxdfK4Za4moLfsmI8pfl5rBvlh9oumT0sgvPg1P3zp70UXJI0tSHisCb/uK+Hggdk5AJC";
+ sout << "gC3NM9uf3AubM+CW8rJXM6qz+xwOeTuMawJTGTrSg6bB0jREYOXfvwZ7oNgiSdqe4+s7tIhd21Hd";
+ sout << "A4cJw2aRWk0ft7jp9O/6A8II+hJoxReOnEFn54/kdM5LFasSjd/wwxsdtGevYixIItCTJ36K48dU";
+ sout << "nndygDo7qFWaryImEjgErq5NT9VUfg/VdIhCfAro6xE/1C5mMZ083LOjvTdxdRO2qtgKYCvL3HWO";
+ sout << "Q0k2/WqcaAKyh19cHsnoVKd1gXCummMSRahTGpmK82fAE2vP7MThuNel1CGti+nKVvzu7Yf3i9eJ";
+ sout << "rUapLrviTwh0cQzVAWHXxgE1PB5O1Hj1/hxwRVStuKNsPjZKFmNu+j02TT98hpYCFQ3F2hE/jsIT";
+ sout << "VN7hr7k16mXUqWbCzNWO/xo86NyCS/8HYFXMomE9KRLlP1cw1KrmBltivEnesHGCi9sAyHnE9bcX";
+ sout << "mWXL/tg8wP96Eh10KeJKzx80xfSPz929wa2Z+fAfjHqVgroE6AJQNKengtaLjTo9vBo2E4mFbs+O";
+ sout << "AUy5hDVB7Mm/koguhT3+BxefAXoUbaUkNlIewCbDUlLIT2o7ZLYhoR1WTBv5K0eQIt1wTFytuCVq";
+ sout << "Ldoni/6ad1IzqamnGBam4hFQlbQgpezUea29Idj4dhUqTmksk0fV6sXYaT7P9ELB+W/OnF6IY/d3";
+ sout << "6TUUvnSPfyfOHHxts3HnXQjDqrRhPeJown/p6tzdtytVCLeUtXu2o1SgfGJrs9uAtT5+3mt/NjFb";
+ sout << "Z+VX86HZQ0A5XgYWiBDvd/Lu+ptgu5kMX/QAzBZp4Ubxo/sF57ZuA3TJMx3CjUFgNBwHgdWsiMHf";
+ sout << "X+sBqB17McUR4Ar0OVoRewmYlRL/J9vN3ZJ/P24Ctp/7Y2Ozk6vJZBIKnTvcapys6mah6rzYurq1";
+ sout << "vQ2xWL0+hq2nsbkZ6bMl0ummFaCrPrMmBLGAQTa0Qa6QqO8sCtepBHRzJdCH8InI/jRBJKrGM2jt";
+ sout << "s/bM1qvX3U8jLY+vGwt8URt+H9VrF13rKxlgTJmu15oC0duyGyc0ejeDsopX4NkOQ6xScBsPAMV3";
+ sout << "ObV53ImrQQxdypGM8UWByxISLsLxMDDB49DhgBvOqfJCJ62M/m3zZNV7PzUkUcI3iB8QDpovj4Iz";
+ sout << "CmTXnRgtsARcRc977luV5QiiofiyUOVQFbBY7obGei9EnfRspuXIwDsShcRIEHEz49K5SdToZ3Ky";
+ sout << "XLZUEdbMsZDXWjmE1A+hN8G1oAh+fkFoXav5S1xvmYqr28vAff3UhxXH1ZKVGM0ePGoaE+AgPIuC";
+ sout << "NawQwcyWFoPqLjaWg9bgf1K+gbhxH4ot5ehVVMb5YiAbMrIp03ONvVFLGWzI3tzFMMTOemMgM7UI";
+ sout << "1T5rYrWbm0q0IBIhEpRCpR1Smtgf3wK4pUX5y+1+xGk3bIl+Sk2Niyl0DmrmqZlenbbts1n/hDyX";
+ sout << "ZS+148JBO7FJvG6L/oD7LaIzjBAPMWB2TLARdbc/ShtzxMY8l58fxDoOs8gZXcpLPHpdgEALIJWi";
+ sout << "GHRB0U1osylCUfS64GUzB/mVA4677H/n31R1WaghcDrZRgZs/aeyy8DIO8fprsM8MwamVSDoXU4y";
+ sout << "YNyDupAJA6GKeYC98bCgZ64TdmXX3pDp++8TOyDJ0VDNrp6LfdqOaMNZBoG4G8KduFcUssKyGs8D";
+ sout << "mGZ49omOt6rvrlZpGgBxi6afE/7tg4ac2hHd368gOLuPyLY8UKBEqhRGG4POlbg+v4AEegzLQdmx";
+ sout << "GrtxnwcojpDFU3k+HPD7wQwv6dSpzCGJy3d0682y/x3muJLZ/bQaZUV5yj2SCfgSIaIZcmVW6YFE";
+ sout << "dB6GXs8kxAZgfLLPT49+5jIL2XXkcmK/LK4dY3ah3zGKd/Gl+NUxbZmgvOzG2LFmPvp1Mr73WZ4L";
+ sout << "q/K4XId/B35/LLnT1bboTQ0L4BaJR0r8uURHEUO2fo9HtByhvRmaDXy/+EAxT01G0sjy4Nhwd0oj";
+ sout << "AZrryg3KSdCTwUoi/CMme7l7udmRxUMSg+L6+pTXgONqyBNDwkksQ5aw5YP0q02WH8pd4kVHHLkc";
+ sout << "oVv3diXYApXi3VBhtVmpyjAm9xy2SPL/iaoitL7/0d9nIlPymzGh3Ko+ghijUD0Ft1CJHr2pn4QI";
+ sout << "zWL0rF8DTAP2TcD1dl/I1VNW6e+GLtWgI5+/iDElGyrELfamWyu63q14Yp/Uk/PrMTsxRxZBzZ6j";
+ sout << "whO3UnC70N0EWPc5M1q2HdWuUadl9tbXFj4eX5F2OABfnMgQHDo9F+pux+TxpCZcNFWHsd3frjF4";
+ sout << "HEg6qkqxyaqgZLUM7e90wLxYi2/XV59BWFuZC6kVAkBCNWMZm+jisGYdf+kLWVrbZ091mYTllIpt";
+ sout << "qP2NaF948T+/rfiT1TbhmAGwj3NI4dd3azqqJuobPSjj9pdK9JEKhyC2QlkZ4HgRqyKTH0xPA1Bo";
+ sout << "mBCwzYhubIB2Ro/oxS152TwESUPDvZXTsDbmJXGsdCrrmfQy4NCXNQEDQrjrRY40qXKi9Cxr7wjn";
+ sout << "W6NrSshxpsV1NwlSGT1Omn7RUMTpe1JaKLINxTUJCCPOuNCAgbkmPfB2L/vzOq/PJ3/EGbolcvCc";
+ sout << "b9zvJpUeKfNOK9oqBZ2dZQqgGAD70uTPitKu0/5pA8I+14sLkIfVpVAiJI/54Jl7cz9lhMQ/X6oh";
+ sout << "ILDtsAHQjcDXP9BCfByca38PJy1k56vOjfg6Tuc+0hPcDkhWobXqi5xhJaS8WftAOhqwiYZKsimZ";
+ sout << "yFuYCz0EeT8EGFa4APSgjPSmsV7jCpOWoY0RtUYvhNMLSmFuwOVOiUHLrlsb8gFXQ9nmw/K8hxHc";
+ sout << "9kP96Rx3f+y79BSrhUcQyUrCSV50Dkvq07wfcAEvz2dzSKvz6zTDIKUFxfz6ejsAvrcx/7UgP0i+";
+ sout << "1rtTnLXH/Qy6rlJAkxXknkQQ6YG8egWa0mX9Wh7RxpTIBHPaUU8gnIxo5/RZEPOqrF9DtYoJMOkB";
+ sout << "T6davyvC84bSPrk5QW0DMqzZtHDDuzqiRYZt9PfEvURu7iI0IJxcVWYQcXJL/ZCp2GVS8689pmYU";
+ sout << "zIse1tFwwd64kJJfyBqN/vW9A4aI/PutNu14PwgPVr+NeH8wQk1L6MhOvcnJwYDWm504takuedAe";
+ sout << "hB3lePRnMEBawImStuD73hcZV1KzTKFvu+ebNR8414Wj6gVJYfPvHFv333u0ROqJ4/Mhb8occomU";
+ sout << "iedo9jM8ZJCMNbPF2Vxxgg0ahEwTKyq49Qht9fhwUTm6pMQKUXOWU/rD3x2BwrKQ7NZoOUqTlfGI";
+ sout << "/czGxGeP9PqmLcleyMhHuC4GdlvI+ulMTlngO760teQeuF7EYqbwwuwqtSV3E/pbXm15OFGSr+Jo";
+ sout << "FrOdb5/WWAQ7OZPwdgYiWxbLELUCjNu22PoVAN+PlAYShXG/qR2lO6I8Mh6TXZ2Oo1tP68lGp83c";
+ sout << "zpgoFpNuaQzLCWIjC4ka5v5k+Y06crZnzCzxaPsanjgOtAlQ+BOykH38ErkDlM4M2SGs8IAZXW0R";
+ sout << "zHgZygNho43FDRQHPiYxFtGV2ktVmtGNTG9YtWQsBcnd/T8xtUP8DI3tR3nCg7Q9esbsZOgbzCIL";
+ sout << "bbtn92I2iiTkjuPSuuYoWPzVS+hi4dB/BQWeHtnwdNAmqU8IAJki4jNkjh/6NfTgVepkT2Nc58pT";
+ sout << "V2DzfqFrhOW+whXsJIFkMSW0dXp5TCuJBbXTIiChMjXI3c+Dzes1/6CF7l2lFA9Ol+AiPwjPvhhQ";
+ sout << "/9WCwWzfsmp3+w+9nLLt5CtWRePIB4LC4nugM985fxf10qYES/K0vxKR1W/Ox90s3D4aG57SHpsX";
+ sout << "8frJ3HU/ouV73ZOtCg/lOfdocNCQa8KsoKhy0T7tDnWoKJvk0tT0RhFWltVQHV+sVteGXOImaxcZ";
+ sout << "MXDU6mnRzp5PEkQnNnBDeynPvs0vWfVSYpzQnO1tyIM2YYwfyf4vdn82ikk/Bq7onvdohQH02/jU";
+ sout << "VQjCfYj7X93of1bK2ja4vdEtcFpH34YTnSyZ7Rc+Np8BjO915a5XsUagea3ZHQbC9bWM75QKp9f3";
+ sout << "HuDpEn4tNg6RTcwWf8s2G2vZCfIDkH0geal//39BsVC9LqIyEF6urWV10NerZXv8fumNmX/5JraN";
+ sout << "aZ4GGdnhHPnryfHkHG62ESpmUCBO+85oDXrZcI1riSFKCYP1eBn/KCWkjMf0oV/pYljo1KrZgQay";
+ sout << "6vXXMdX3Ur6UUBO8iFht7ETaE5BeQ+3CZQEe3tdz1Z87V8rI1Qe0Mkq40CwniIxvk86JigLkt1Yo";
+ sout << "iQBWhcgJXCdVPSB1XpOdtZjgMPQNJqJtcbu3SDHsWQ3TWA6BDtVbHH7xY7CPVN2si2K8B8/Czxmq";
+ sout << "xptTLBIMfU0xaBZTXLSmRKNTEj9D0H+PwHvzPper0tQU5NEs0tZ5h6FZER1mgFzEmbZ6PUVgc79f";
+ sout << "ZnMaLmKpirOVc7HwFfLuwxAlvZRWtykLc8AvkZAp+bYGHERo3uHgdIKkeW82HQUKsUMN1VglX/YJ";
+ sout << "gT3Htmy3dEuYGNh+OKulcvAoRvLmTMYgACr/0ZNWoYjqAeilI7ZoqBWHYswLrvefHlIz6WAAHwKC";
+ sout << "xkbk1H0uWe9IRcf0DgEvWkXSrWRFHSS7uTePFxtLM7OQW/GmGfS3wt4/YLomS1enK1L4tH8fEPoL";
+ sout << "ZeZgjNYiMg4gQLm0+FyseqNjiUnQPkoFQg1jztPOZSxxZToCSdeUIawgr/3KKQ5+d2w3LEOQNAy+";
+ sout << "Vf+o5CjsfXpYJ8HVxpLsEC3xwVEzhpeY/04yUW2ygAfZqbsLARAAWtW2GcPN77MnGAmm46PeCSe/";
+ sout << "2QS7gM0ygtJewBCaxSfij+Xg4HQhXiKDGuwDYqUDmLInHskRXyVuuojyCbmuNmtyj+Uh+HsdFeOU";
+ sout << "fqTBHfEQzzAFVyORf14oEhGBwDuvMzLjciKSfLu6KcMCFo0TGgW3Niu7kDdbxOxg6Kwa4wiaWWk5";
+ sout << "mJpG6uhNFaDqpfVLY5D/o0Hn/QfMZwR/MzmiFjstPVEd9ZzuRvKKSDOM9Ult2lTU6z6Ci2b4unu/";
+ sout << "qfLIE8CjUAZrZnEFbC/KWx9I9zuUuoXS8elD9aG4if80uCyJ/tT4MXjx8iRsbvZ9B2MoFdO9gAUn";
+ sout << "qafo/IYR+5BOderTedEwe6FanPX4opAcpmpXM/cbD51QA9LCLyNtSoSPlwji/ZtYA97UUV6k5c47";
+ sout << "TNFB+QtUaW7/DpPZbxsU92jE19ztQ8VV5Ylm8v9RdvZDQ6T37KVEYu/z2bHX5L0r/bkEtQQ3gf/e";
+ sout << "TRfpy9j5f1lVMpN/UccOtsZfQpW1WvDcbAbOjcuPuuwNOudfzztYg8vjJNn0pMATK1+p4adLcxJ6";
+ sout << "uCh+zVgxAS9W+z/t5X3YTMnFDULRXqs8lsN6H/1t+gVBp1uG04BBDO8YMXZsUB8dLyb5BHbkjPiU";
+ sout << "c/zt703lv2U6So90+H+6hlnWDBvfLHI1W1C/RkelOCD8WVZ0rk/sgI9UYf+jk/SkIkZgbtVTB5S5";
+ sout << "gd5IXkexjN/kP372AvWgX1t+uKVmKRCrYlG8TZzDrjRD4BjS1M72OGwm3MDT6GPRcaGc+jtSGAxd";
+ sout << "eW+6BPofMujBNOBsrxKejRcLsddfOrGTBqP0Wc8B3cNOjEg/ITJ3LOu3FFA5TyLh0BvOEkuiD7Yf";
+ sout << "SpB8L9jDYG9nCcgFC19CA0Ji/Oc+aYFDMiDM4be+gG9QUdAxk7Yg13HdJirKRiOvJkf++ayrJBbZ";
+ sout << "xs8NEcd6lN3XoogW4gChg9mwlUNehzzjIl56wG01T+29FhD8Z8ebywHq0xEbP5aVGrEzwj4HF7kz";
+ sout << "erUrxlRZbmRgleNhldZDjqUpMcHyJzF+OxZ7W2AfpjJf5FaF3E4Rb9IG6vBtPhZbk3X41VpKx5Ih";
+ sout << "DedeX1BVHzRYAPt6JchdpOa+F0Y4Evke5XR97NecmGoYXrSkLrArSXDu2qqeQVx6XJ2Qginw4711";
+ sout << "fJwfVb8QMMRouBWEYz7k2LflgRG3c3XL1KP8npyz3yNw+6b3/K5JSq4hRFzXfsIWoLArSWgfboMa";
+ sout << "w/wIVdPgwPnUh3bfYKEXsQSenrazuqhPosJRMsX8qdZ1HB9PBwMKMHqDessyUX8u4kFMZwo1egfw";
+ sout << "7x8xkykzBeY8Iy0y6uOVpowXPgEPg9PZnquqW/r7qQAiqgoW/4oMU8DZtcfT9fTEN8BI0YPPNPFY";
+ sout << "1N02ocPS4kd5V8/bcTWkYmtCWIqlDtyPCEcEoBQbhneOdw6EmOmTqJVArAwydyN8TLsJq5A4WOf5";
+ sout << "wWO407yhoizBhRjGlqsp+LW9DlG4LHKLWyKCXb+iozkdD8Bh9RX3iYz4RWNSm552G1u8YQtZkeN6";
+ sout << "ISOxeylEI9OFkxvueA6u7juBnwpmt3sjcvCOLbbwwpc0JyWfAhiRdBfC7aXLjHm0NcwiULATnoJ8";
+ sout << "AMfNt/qzosP8LYQCvRYwJYShbskncZCd5amEmN4eStNaebKyFX++T+XC9EiG9FysF2VZdRf3OhzG";
+ sout << "dDBtLbkdALyBU3A6GhCueeZ4c5vEsHC8RMH5iJtYHXwLc0OWzUPC/DVlhVi82avRChrbJQ55qHza";
+ sout << "DBnr3I00eUsfGxijrqaQ4bfLLv6S5e4jP6C6waRbf4RY/Q6kTb/9oECBuXgayar3WIvLp9txdLEi";
+ sout << "dc0sokRgOMXjrheq/gZ+WHBm88SLRpecDwc4D0kGo7UQC/PIajFQ6b+rTJBKwYxeaaVpA2gaakqR";
+ sout << "BeQurXrdpAW7hWwBHjzWtxG/qKwO/w1+x/zAICr9N2/1Bex8yhnnIEcltD2qN5ykzqBzrKrbAzaT";
+ sout << "+6+ZV4ZTvaeeR9ELOBDYzkkUF6cgTWAZnwT0dKySp2C6FSPkpNseYrrgwg4Ddp49pN7ehwCE6V99";
+ sout << "ebgMYWGWqJkBJ7ID8NSMDFjOTvuWzz08QpaygfGSGgipHcuEKdHEe+waRMOMZSPVD7uDdV67805a";
+ sout << "67OTB9j4JiA9zhyXqx3STc/ICLFkSjZbl0VC5XAyVSrC0ssNFCVLp7OQ1REJYnj2bCPMpZCX7VHF";
+ sout << "RY8HRrA7V+X0HGBVPAcCrBn47lFvImWGBWGAWFO+TltuyKwvuiBOU1VfG1mxFCJJj/+dv7jIdDp5";
+ sout << "+VEE6pCWOJxzveqnJ31PvwrlGjxY8WondS64rNe2qnQ+grQCp69T4vHe7zuYfta+QyjtQxHiXSBV";
+ sout << "Zyu+R1M648V5jmGpQUT1bek5FLy5LoDDjkHITygysq4NzrytxR7Ncmgns3tdB9LV5zcHBG2I2gzc";
+ sout << "lHRBaE17Z3Lo6zdwTlGpfumcYid4Rlj7OgZIZTz3ogfaWU3OblHZ7Kqs3GGvctnHnaEtK9oqQQFP";
+ sout << "zEvAG1nL5qjUKB70e9YgDiUYKAC7xyWVoB0BCIjIXdJh8JlwXMRQDQwEwAeV8FAYKQffZB/2JEeR";
+ sout << "NnlaRxoQVuiNom9Dh7AokAEHuitMI+KD0n/bijH01pr9cxGoz9uimYQwc13mw+ZS0GsWeeEqYSWV";
+ sout << "kUsJVrisNxDeJMKW0hQuZmNTkjve2aX70+V9ouwgMlYQusDxVGugu0HNUQBgWbnE7CW9fh8GQV6d";
+ sout << "42/8i28kLjJ5wiiReXFbXDQ1XBz2lwyXXm4ITq5e3tqn/BDX5S6yQg5yKGl+2pISvDvZipWVNj0i";
+ sout << "vNeVWDITOZw8OAceKLZJ0WDSaV5W+F2Io89LI4mdzDdvOk/Ct9OpxFAacJ+H3Kx2S6b70hcdrx2u";
+ sout << "0Z+LB5tDRmxItNYMfz7br1fZKyAO8hUxeWN3EF6v5Kh8t0R7ZZaCEJgt19e/3sFnR8ssvG3yrdXG";
+ sout << "RT6qvp2JAUyP53dm4Z8+6e/k7YA7WR1sv7d7omc56Q3LXhpvmJJPbAJ/qBXHog0pO1ypN8szYVUo";
+ sout << "MbTlTlz0TECPbqdThQFSps4oXTeCcR+5fj7SR2K3FVjbxHfWhrUfw35Nh7nPfzyDHwfmX+CMdwoj";
+ sout << "oosswTADZSwvEh46PBffUydxEG7X70qyn8VhwtZGHTFNTjjncB9AXeEpYpqxo9ZzOXcCM4JwI+jC";
+ sout << "eDsx+HOZdie9keAE27XUcvKhNYv+G5Hq3rAVNe2GRgDdt47Ysz9vECFOBxO4RVAl0jr3TmfKoqmR";
+ sout << "3WBN07KGMf2q1osu+RZ0I2wAQBwQOFG4c7IjxqEh1tlBsFUQBE207K8zhrYyURXulKfdHL+iWMFs";
+ sout << "LxDsVCo3kG1L1WRBXynobD9DGvbx3EJf8eIgrKuHSEF46SHUWs6HjACerAjryYxvgHUoUez5cD88";
+ sout << "TcQwFBS662AAs2ZAp/uA2Y5Pc+mnluC+5HYwxYI1hh9TVEQPXRPLccrfkitZYn/kV1nf7AzK8lEm";
+ sout << "lavQkdl8e0rKXTunMcuayFgtPCyqiUPn16MkaQBfaNM9B+gw/fFeRctuRq46xbF0fTzBS8iQBkoB";
+ sout << "jTV7UmXA0Ysi+iVFAod6e4o9eyW/WEhGUCrAWefyUqbytPJDH7/1VnTdlN0bpeXXSyqWgTpsWYay";
+ sout << "GhdYTsPDdfTSzkEU/0wdKm/w8Mjcf3+ZcVNJhPtJpEXzoFsSL0VVGym0kDsJooOizfSqJPyczVD/";
+ sout << "b3vN7al2kuDcNZjVyWiZCfPyED1UQTSQFVZBygAfmjg/spVyg3zbiW6BjcDuxOxOHmBGXuE2XOjM";
+ sout << "Qa4Qcw7ioTk8dw1VobafnWiGN2S01drTqLFry6ZOG2es7IejQh6HsGig3Iis3xAxWWGS/Wazhw8v";
+ sout << "M3OQzSmbW7j0L0KcLOMtSxiYo5LHXxhti/pAaI0VZSLkQYgyMD/mBdbs0B8i3diXoAM/BzqfoU7L";
+ sout << "7SNslWLwXXscPrJsCwFqe2UiahyQyIvf+qXqBhKuFnIHVt0cF5cFAeEh1kYzAd0JDbi/LB2hYsVX";
+ sout << "rosavz3xN715cRmjpjuQvlIDUGfmmJPPvUIbbXiT6q+TrDdQaPw56G7vz5llNd17QS7UUZymWjY8";
+ sout << "p/8bq6MieVzWeOmiBPbmWRtmZ4GPTc0bmau8FH77JxwgrkdH/7jDL6uhcd3grTrsm8X8ZX9FGCZy";
+ sout << "QZQ52pU3tWbfy7KsWnBZjkRwzZVRY3EVjg9V1PMQYKAqps7mU0LrDBPwrgHW0c9EvuYa9SIShUgX";
+ sout << "JtiWswoJVbbJXltdhclBmvuDsyK/MfLJrSE7qnPOD009hILCojdSCmqdaFgpmw88EiYAVDqwcFzw";
+ sout << "EiDdvqg91OFplvWoGEQx5wsjOKaQst1Z1LSc/C1ZZH+mmBMgHfYmr3ZiX/eYUjGzYpdKGa00v8gk";
+ sout << "dDEkUCLQ1I0ayiaz9cMPWkZ3avnYb7TdUKa3iwERVVFEJWfhpWe/4KP5qBL3Ih8qLicgtdkt6FOM";
+ sout << "ymOpNQmfe2X+P3U6042OQPIU1vmig0euGeKQmsE4K3NoSOEy2QheNv/2GSSNRzc925zhjAk1zxXQ";
+ sout << "04KKCDA4HJKFK0tly5P6KjSVTiQXZ6e+IWH9CQfAJWumCVtHNCulD2epjCPlTEjHH332gl5c7h7u";
+ sout << "eJRVbsJKZm9z1TgFN5w1iDWI1cBLZipzM0BznWEvEoYNixyZot+Yfm6IC0SZi8DA7c4Ngn9V0ydg";
+ sout << "c8JNp+A7T7LqB/LWgma/E4+wv/cdoXiKi3mwg71ZQRM1d5oWknFfm8hq0KNOl2RsSOYO92PIF8Lp";
+ sout << "kG2KPK2nBU7VkbNtaJwjmJpadvkAEPPpCbYDdRjrvKaSw9yehsWKfLVuCLQHlO7uJmGe2LHWKojj";
+ sout << "BtYYau7+Vfyjj8K4eNqABvny8hj8zV0HyToirIxVpBmH/6/8p2imqYx8hc7yBE84LSNCEac3UJYf";
+ sout << "8sxZluHs/ZD3t6s9jVisgtx67iv8u3A4toAOlYjBqj/RXVQRAb30MzAFy1YCdxedXIvKBa/jXMwk";
+ sout << "0DsQ+YcRsnKmnGQ2yOxNVfiJZIg+KqN2Ork0y2QNEy7aTv25OEUM8U2wKaRkiyNYh9VmvZpYk5xM";
+ sout << "OPOsReX2FOJOFTCUGZ+w4Lp13OI+QLc9TBllim4bhtbBscCMN78WOpnUccS17gXHnqtpK1Nx7R3N";
+ sout << "dwpVTYBCxVVUSCHAREPvmPlHj7c7nnIfO5LRxuGxrg0TfCBUwCNiJvuUpiuqNyJha+eJCPS+SlXy";
+ sout << "BN9m4jvEM4KGbkGSwelXon6LWOKq3cya+8xOqQ6KliFrdQaNk0Tdg1Mi2oRjBlqjyb55oXYNeRgT";
+ sout << "wQ8wAPqY5TM/z/Zk9ELGeQz7JjS/ru5Q1LGNUU3tL0AjqUdU6adovph7CC0/g8lX/NqYg1oRT575";
+ sout << "7m4ihHqVTIN3pMkB8TJBZNAa3i1rlom4qcOFGdjTi4WTRd9r3ll0y4PB7V81SddpOk8RPTOIDKR8";
+ sout << "pgIa6xFOQ/GTIaB8i7uLax+IMs4dIq+gzpxE9g35BP7J5cbIYJJeMFIm92l7wJB93ANTT84yQABm";
+ sout << "oBnlfu/fjcbGvfHw41CTRpMWraO5d9nABxuIX4wFKEzxse0iAVsJA21RbpdBtIHLBcbYKCh1bUW/";
+ sout << "hRSk1yislMs9Mtjx/VSUn50NAompaF3NLpG10/hGd7pp1co4wZm7+fiW6sIPBCqLtkg6A1+qGUxG";
+ sout << "nUU6NqUvmyWo5B4Re3mDggASRSIFdD8ckv7LlAbyDgVyIxsBrf80ISQN7jRDDd7MjkyAJ5CRlrUu";
+ sout << "vFCwCo11+7E/Yfk2dadQJTjGvuyyDXW72l4ze1PuBrjnql0vBe7CCPZKgKJWWDzbGIEKGbBqDfBV";
+ sout << "EVA0NrVpWYBu3uhj8sHi+cAMwiXy2T1ar59Cz8bvuIA+9egegV19YAay92S86BsJbhVfb/THW/G9";
+ sout << "sbTgF3MWTUgvDFmaJzYn7mkR/2xsBn6UQsLYPZPxOQ4VaEUzAKFSP0zwGZfMYE/REnDIbdOR1ai+";
+ sout << "zrA5gWXmbLlr09hkM7tIn2Tw/+X7zEzZ66oDyjDs6g4a+CBt8OtOLi7Ga/dO1DP2y4n1YPAsGADW";
+ sout << "Jf24rVZGv8IWyGmlYJZWX2tBKV++HDUTxnmbsDSI7Jy3r1UXNWskyDC9zYWetsp8UxAVRYBDhwfv";
+ sout << "fqlchxnmTpm6ozULqYI6ExncbbbK1IBSbY2A+VBUrEqdH7Rxg45S0HKRbvB2jAkksOyhbXUwngM8";
+ sout << "6uG1TUcRdnxLvZniKptRcYJthyx0HDHCzLoTT1AlyTB9kvrktDArOUa1zmF2hCZAYP1swqb+zBjP";
+ sout << "CHd4aExlUf/UokEBduXCWomJLGXsJpVG1EVz6kTZb2Twdt5x5vZJK3O/LCBKgZNMuS7/xLFZpvsB";
+ sout << "cUA0MoO577pa7SECGSrUrSblkliCMKmIECRrQoEJaNKCs1r991ptT9aWnLSbl6HTTWXKcceN0vXQ";
+ sout << "tKwTpiM5nhBmea6eOQn5JP5oeytyDul1Jd0SHR4WILgQvYQSrqdyQOf5zDBQHe2/EsinsfonJy50";
+ sout << "4M4l2ArPeWEjJKtQBpgQaLF7CRGTjnv0TPJ1Xjo2Cy2OFvWpkA1ZjK3Q/I5PJsMrlkYkTSgQql/k";
+ sout << "HwmLiFL0KGdVfo6F1HT5J/5ZxCDOWhO+DmvLXtvTQSEE2asJVZWDo1rpKOq6H3AhObvUgp8AtPec";
+ sout << "+KXWiyzTH9YYq5RFU4JE8a+lt8+bGtYHxGOE6AuVgQsvUvAJ2kgqgi+kLYF19FZjeYbvuxzI3t7j";
+ sout << "+qCO3v59HfjZWfQ7FOJ/ciLVAAEM65r1sX1JpZRodNnZUvbxT7QoVOm0LwjW3qw2VxZCbVU/jPkc";
+ sout << "XQ+8lxiZpHO9mSxBXvQkwRW4K9MZ9dBl0oYdvq8IK33SkyTcaHBqlp2oszKYwDFyLo2C9uiE2BSE";
+ sout << "gj9UpPG4M7XOUjUMJ3gs7aitPsTdI5oHG1xklLOTvyWYIJ6vk2/uBrMImGSExPSMMGw0NEbHFEX0";
+ sout << "Fh4Td6+cO6QgCQeobzV6Hxsi+HTk5mtAbMGXiIKi7bAiSC3tkoiCeDfBxzqROYgSVQU9fW8PeyW9";
+ sout << "xEmwZSABCNRog1xAwHaOuosBqYDMidyh7F5ID9XeYHH0qaMxLmzYkZ0SAl1qECFzJ7wUSiwRzvzC";
+ sout << "VfatrCs84cDUKQIpr82LPouxQBeVFQAlj0Plcuo6CxrJgVR8q6ciC8eCnZ5WaF27pKHrfPc7ZjdL";
+ sout << "h4AwTLGOpUg0N0/aUJxvFGOVbm/g8pH7duFC3+ycmYrOXZPbhz2qthsxf894+lkU6TfmN3OOoTwG";
+ sout << "EXrda/bR2NQNvqJWBRTv//VdXIr1RmlNMDcz2d9lL3rwF9Gq5EupgYIfe1FhiiqDjyrTzh7fw2Qw";
+ sout << "P5T1SfKa5Ww9mGxa1psmYlJ/IzOQlSlSzSHxMVvD7c6UDpWwKWFlzzRPjl3WPS9Dmzk3tzUo2ZYx";
+ sout << "RdP7RBZg7Cveb3lkIn+gU5XII4cD79YzKPesFWpWR1ejkLVGW1YQB7jtRaJmoz8T81UTuqhxhBt4";
+ sout << "NswM1L0VtG8asAidWd9nT3Y/WWI2ydJ9YgiepKeMdtpt8aYqoutZUFzPKGhdZsZIYikb9r1Svrd3";
+ sout << "MAyceFX5RIKLVcyJJC9n1QlJVSXz2e47rrj99prVvMN4ROB4hpORryHzGuPMMhpn21ZywDk1aH9D";
+ sout << "0WGZq8Nm4YZEjneUtkkW0nXcF0nxJ0wTFBmOrXEgjpk6jzfQ+9+b803WlwtqRUJ6S6FYirRyTv6X";
+ sout << "r8xCqyXgNioZycZSgjxzLcoFqq/65u935jK7BEMtBGWa4zQcwq38IcmZu9xfCrQLCfHpIfydjxBs";
+ sout << "8PrNV84ccG5Yfs4zN6oaLOahJBukoFHUZLLM+67oALQzsStEXS+HQhWMJCRj5M/8/SwoxZM0ACfQ";
+ sout << "XY5G7OnF21/5NwpsUoYboT1wNFT2r8TKrFOx9bEIbKV8xTUgdrnZwbKGPemsypnnsFHoA3BwIWKx";
+ sout << "+w0vNWVqK2vsbL/pTiZiov1lxvfKFV25Q9uylplUaSnzYuaDQFLOPBFr8nhcmnfZA8r4ljcjMNw6";
+ sout << "0AUHO03MyUDfu0BUVYQEyGXpEvwnI5JmVOI/y3/TqnVnReemG4D+diuwm2Q5UgyY8KyykxJam4Zv";
+ sout << "W+Tn3DHT0emJNWv3imi5itW0rUIaTpt/9a0LvqIWi0Q1OOvrxGgUcrsRzfUVi+ru+tX9YZrBOYXe";
+ sout << "Ut+KwUCvITG88r9m7o8aovNyH9V7bC4axTrIxmtpRiRP5e0Z/IiH6b560Z9ixQ0MRv5SUn/lhX/l";
+ sout << "AP2rTsqxLmt2mWz9b3GqtEO0iW8XisORLiHzupEX9jRbcqShrBoG9bu+4DUO9hGdemq5lX1782dh";
+ sout << "FeZOjhIE0/8JttTmz7EZwzp7SeHup1yqlForPR0hhkFzfOOQk5CwfYR1tD+0ImdmbL/QyE/TRsnP";
+ sout << "NoDB0NVyXcmlJQzRH+s/dK9kJjHFn4FAUJ7Eaw5xi8/O5VokSz3gfwwPHbKFexrY18fDiyeZ404M";
+ sout << "cAGYSGyYzxkpa1HwuZE7UzUNlyzpuja8RR/8XaUv1XRK1lIkmDFA9cKT81vtXJBY8/04Jh2OsPBT";
+ sout << "mFRnFF6+vItwlsmOaF/WUstZg2XxP+MePh2jGkKj3b4c4Vw9YwzD0/xNlkCk22jnN60wtM5WHiLT";
+ sout << "Opu4ZnP6/woOP+7uOOHdjEJvo2d/v3TOmNmfx38TBCBdAbRpMzy4XDZdsWpP7l0DN9/OnbErxD30";
+ sout << "2sECWhkxKygkNxeq+CoeiS6Zl3mzEKkQ9q/XHccCioWgiODFt6+Wg9MtPHGaR1cnaDDBBp4YQWJG";
+ sout << "c0D1Urc9H4Tpn0lmJkx0p09nBoMHGIBUKmHuhMOnuEuZp9wDYSQ6UnjtLy/+QJcCB8QDvTCg9mhu";
+ sout << "W2HNZPqa9DFD2tmV5e3+pTojHI4O0tOwA5B2OYmqFzdvMWLuFn0uMHOu5285KE0ZDykOOv4Nupq9";
+ sout << "z/rjTtrxInMTU+a7hynQ7Ra75F5nCnEefKXL2lgD90IkuNDpmxHBX5OgOYb9RUwqcVGAFSZq1jYm";
+ sout << "rg7sxuDJu/honmXsDohqfo3/vBVca9U28wnCJ2IsuJXRQwzBLOJAz44ijj6+Bx4R/7U3IbN6nSNW";
+ sout << "5fXiqV1JD/r3AMAf6THZKnYy88mv6kMzt0ZXPd4PlyHCYhLTIzw3AZP36xPjEWBRP8FGaJw5Pm02";
+ sout << "Kh91uM2QdF2jowNV/Ago7rHS09R3B/aWjmGdVTyhCjdogPtrhVHdaq33OftcCuCF7q/wZ8kL56HD";
+ sout << "2T2BxxmsZY0XqSlU3757R/VoeGdk7XjAGbpOeyxn+yrP8/stbyIwxq8PSv6YhRN2c3+H9AYka8Lr";
+ sout << "Vrb16X3aEEP49jxNwEOL4hM2rcgS7f3uayVLwsuviWxiUcQdWcmjTVgTos/11pdXpSaPqSAvpAMs";
+ sout << "m79R8fdq1V5+kmZZ8UZKUdQwjJxeg83dCTDpYRTXAdSdKDLXeD4Kkvd2ebKOGb+1j345x4HUnuef";
+ sout << "ijTom/Xz3fEw1jLgqZd/nVeUCBYcXpozL50LwzYDZzjxERSVCMCgmt6VJBgpvwFBXyAYJql5972d";
+ sout << "vIP6cFU9K1AGAUipBxgaVR/5NBFTIjgo2CEOJulAQkwWXE1TOUXHF6hD+Tp9Zf3nZRmTBU2OXA/U";
+ sout << "8M5Q6nyDRaV/qRdB3GMrYPHZIXcy+uh9xI6OYDWSfv6uZxGfKDzN2JjMNLuK5CEIu4hiLj3JHC5o";
+ sout << "vHJ77Vimbm7wa/cTYx4ztSvSSFwSXy82nSd3o6d3Z7vL4FlzDjxkW+6AfP+8SQLvC8yiKiGr6U9/";
+ sout << "cqFdi3xOg3Ska+3F1sn/OIrDZVLkP5eyqo2aDC++WnVv4Aig14Xf7lgD/nOV16X4sn6BmpIysM/i";
+ sout << "Cjqh7AuHgDkB9kgr7jQXFeGAhdGCNby93jgIRBdWTY61JRX6Isek98cdoL1XqCheIxKckNMmljnr";
+ sout << "lHL4Fi8okkonI43IKV+8NJ5eH/JmUEtrcOuaB9M4rEI1NsEQK72dIWyuuIYyocaeA87R2dt2biia";
+ sout << "SsZLV5jE2O9YE/AWY5fDGAH9PA40TD8bC6bxgaik/QFegoo4tRw2yR+GwtUh9Utsu9LUofRnq6KZ";
+ sout << "hq+3Bqzu3U9grOYNU9MUxNo5jZS52U28A8NSuHBDeS2870sdUvEM2RftmNusMnCDR1bw/tru/i33";
+ sout << "iYIUoaX6EPN2UOJMm1TGqIOMs4QWkDTXYAsIdpBgq2nMfXVDb5y/nqpmQH2eJCgz/7Ly0VtZISf9";
+ sout << "mkDrrok3cIEMKMgvM5X/dA4F6Bv929nkXqqSi6AtUER/jJJARkpAJHJTj0IxUgY2unVx4KP4OuL5";
+ sout << "b1HeIkWdlln4V64OPJnBiKl17p4pX5BGtEXngsMR3kdDBJwliLv6ciLvPKybgggvLSQbQ/vqnZdh";
+ sout << "1x1GlnvrpKfoLFH1S6f4HiDicapOP3wKOp8ECyNZp2IAVbRmmPhjKjPry1EdgySgpXFZUHVEHdR3";
+ sout << "R7cGnH1AGhZinYuoN74cFsNooe+LxT74F4KWSW08+S659M7862DADzaxlHAbd9BL79Tu8RGS5CUB";
+ sout << "i2ATQ4rUgRvzC9OGcwCpOAfbp33+jqLmAWRllB5f5uBJuO2nG7OiKHV9jeunOyHxNnp4UTpblCBW";
+ sout << "yKx4EVA7p5T7qH8krIftUpBbUx38XhCtPMcsIZIMp2w+nmrlZeHyPGVGFrA+Z1+VjmRtiwGCghXG";
+ sout << "JjZ9j+HPIniiHK78MNAPoRlt4G5mtBopMopDn/1cdVroLuJkxX1CpCOu3MmkQQghl/km3jphiehr";
+ sout << "Bj/E9IZI3/LesK3Esv1J5VJCFl+Eyn8i90S2ij6sBBPF+eqY7TAMYAr7MK9mFtfM9orEpIJTvG+U";
+ sout << "Uun4ibgHRw0t6UQ4nnvZKHL4qmZXZB4clGgvF2Sl1hj3FpUhQLr4YWJbMeBVddvi8Ifsl/KpN/J8";
+ sout << "k0iL9nhnyDfLBNy9iT6Fnyrt6wKEJKAmvMTUS/U5YZ3CFAgWEmqI04ESYkZ1F/pVjwgfDtwjbkPV";
+ sout << "Vr+9zaxAa2gT5yZZ8m6CqMJJ+Hr1857AgVw+msikx4c3G0zFZpUEv4o3PiprOr0t0bGSzr4rWnLb";
+ sout << "rPNfMq0mQI2bWIMeZDtiAeMLOwrGdlnBUekp8NydJeL1XNAMMJjHKvNEnopNqkZ9Gt1Kmvl3Eqcx";
+ sout << "TuwbsX2ew5yNJOUD7T9q4lYRdKSNs88fdTFzE6heGTV9UGs6z0PMP2HoLXSShBhvPAIX3t92KK+9";
+ sout << "8986KZvRZ/QVN93Aqt1mczK4SEgo2X6Tm5814X6WcCWiZpH8CcwQjW23psQyguC4PyLhGJo92XJY";
+ sout << "GRVVisGEARz1qjm/GY/jW9kE6alXImoLE+F0opY6QTQXAo+KKleXAWmMEUs972E/aFQtmR2V2mxo";
+ sout << "Ga6EtSNXobw7jP+kMkYoSiav9pYhII3t/v/badrNBhKksNt56Zx6BYFOnfhPrwGkR+YoRz2+x+gR";
+ sout << "GMHLwRFe8lMPR584WNMrC8r0r3i+qrp2TeG5vFM+VKwWLSBUoG+bgZShS7ITSJiSjkuDMpa7jy3w";
+ sout << "1s1AIgtNyvxrKPTKr6C2FjVuhDkpSInVX+Cdr7utecsExirDn7uyb46oI46oNbsD/6HW3sIkvrVH";
+ sout << "xWY1bwvcKlT1+aFWKYtdZEOk8JE60pwiat+MpxgOk9TC5EUdJvcxOP8M8zQsIuQhWlkgNhsm3G7X";
+ sout << "1xWL302yhLrl66DlkKHjCH8+ee7/RCRm3l4nS62XeVQ9ZlzEIUbPP1tnnzg2aFi/VxMxk+EfF23P";
+ sout << "auIEXL5jALZVnr1D93LlSiqGwKoAq7s1bblpIGyxcXOsuaZ8Ls/J7FsxcLD8BQiK9v3wotEzaOOD";
+ sout << "jSfsH5DDUQ/nIdlBwZk1J8aOTSIo7DZ1wZRBSz66ptR9sSqcwrXvFndHulLtva04DfNDPzPNaCoH";
+ sout << "cAjxsOZRKl+ZPi9X/qsVzCr6DijJYeWT+UqbhTOGX/Vrl3M3Mt7q4Pcs0vcdXH+SUV8YlF/zmaD/";
+ sout << "xxZcFz5DPz8NRAculqklupW4Wc1D421TC78Kka2NTJKLzivgisWROR9c6g4ml5cqbrizhraRbvXP";
+ sout << "ilWQ0PSOq+uyJqeo+1JOSH3KWoyGWfOfGxlRffmvsOOuY2oblonsSCz6sYKySsnKeTezRcSP9PFg";
+ sout << "6EGCjFg1WodhnA/KWixJtQ8QsfhaRDoIYklDBWlPBEKxgvHi3yVF3vTyakoY/VR82B8UWfzxW+gL";
+ sout << "ldM+gn3H/DTpzR8/h7xBK2yoOw2Gpv2TcEB5oMtPsXsN58QdoqAsgBcHI/GNdfqaQDK0lt+npV2M";
+ sout << "8NAnFcK3r/wdAJ6OwJYtXPEIbhHxRkEY7GLUT1IwLeV2PIV8gy3Eh03ULZDcmii5xO/MsPgrUSKM";
+ sout << "Z9s+JmBmJhQISUjumZ2SBkAU5V3SG3T8c8EaeN/9yNYMFmSXX/jAjlAViNK8bCfnPUkrmnTAMaMv";
+ sout << "blW6tw1yhXV1WSAe075/Zl9slrZhcCJHYX+mZ1yKDpt7x/QxayGmwypWDo9ukTnE66UKiZ0udTiE";
+ sout << "awN5uFe+qvj+zZLrnM4Bf4ZjMjKzjd67W8UaxMJCVtx2SHeNC7Ffjo9xN4bicDUnpqTt5phMV+hX";
+ sout << "lBRiuUYA067qUb58MbkIE02F36n/ioHNYj15Eteh9wv1IMIpMLz/1UGm6+97m8xIAKH5B691KOsD";
+ sout << "k1ovwfjzGXmFX12KsdSplPSmH9UCo6CsuJmFkhAkYd5vCw99+JVu7RNqhJGF4h6Lg22AqisjfHde";
+ sout << "5CFUks0DocGrgDw3GcmVMFAV1Ix29iekfAgEJcOphJBGJ3FvwRmGkflrIJW5x2sPMaRHeU8Fkzso";
+ sout << "o/IF8537YleWU45S1m+GXNdomkGIFVXZCKwiQ8bRr1MmzMcs+0EkU0r6ei7yxi+6LlKfp2NQOv/Q";
+ sout << "owQ74v4YHvyNJcI+YLLMHaZySYlxIjL33H6b8vjrzsDntYXii8lVRRXhykAiLldwH+89cFk2tbeO";
+ sout << "am9kZeEIlCX889URxsw58TRlwd0Lfag6IhNO/hHpi652lgEYcZu5NPaDQPRRC6vkl/+ocoMmTDhX";
+ sout << "uGhphdczLsCBpIejuQVqI1dhrnko128Jzl6yIe6mpuhzRT1ReFL80F6VRjegnDa67Srztleu6qzi";
+ sout << "OUgU3EDRfE8IJ4i97SsnTFJR7Eyou2/GP1U+ai1U7JTknHbg7fNeD5wjTIyQVE31oOR2WTNELWa0";
+ sout << "/GWO3P6Em3J7vZJifeKTGYWf1124uh+oBYgv+j3vUuX01i1P9kmeUP0nmVgzWs6AGIvbIp6vF+gV";
+ sout << "Uc+8TvNYH+0f19q/lPOBnBT46URHpJ9EZxodLin8K5aTCB3sLjpXiqQDPD9fbMzL2r0s1Z0YPWgH";
+ sout << "Mllgxh7l6bIhAmymXIUX3+Iw9rnW/6ODO1huBbvgZOAUrLMnTA9j54trOL/n2DxwAU7JJ/fCeH5u";
+ sout << "PKH6nM9l/PvCTslH7ahvT/uqGD3/++3FeHok90oU0QxyDVbsMDX+ksj5Hn15ZQZQtYqToONe/OI2";
+ sout << "F+xJHU1CWejRZDJgDICcBYQSI9e6crc1vdf0y7aUeNRthj2xLC0zbtfy7PaIEy/R2E93I3jOoAq2";
+ sout << "Mz+0icoUb3Hckf2St+HqfYPL59yCrEyqoFONN5XMArC/MTYpMn/XQfCQL2X9xV9T7WaScs9TAqQ2";
+ sout << "/qZrrcfRxUaPCVIBtG5V+of/vfYoZzwQrLTioTohrX6SC0WvxS05UlbYo9sDwrchmIxPmrf5vI10";
+ sout << "T4Sc1b7sRfiijqZA+YTnYyBPG5c2qqsmSXLmtctpzUn94XPhS76qLj6uQEcoaYTOwFK2dVTvFpfu";
+ sout << "zlRNDzCCfzXc64lA9oKwEhn2tTH4ddClyYobly+2HH3xz3NQMrLxeYXblIKBFJzn6IlY26a+77qM";
+ sout << "srMzEwRmnzUojuRBO6n7BxGFfP1JuP4FhIpKLTgA1Ql5mf7Z9LrsM2lORzMuO/og7LvjrO5jVs/w";
+ sout << "+ZKmQzSzNBjEgq09UDk3gMG5OV38SgLXQCkUnCUThGU6En5OTQkSAF0IDQ9Bmmyadqp/+V6WiViv";
+ sout << "Py/WYbFDSAb95ME/J9YGfTg9VWncLDEU4otj6Milbr8OLjlR5YbbGkieo2I4jYfuKHR+/C1uXVWk";
+ sout << "HEOZBv2lYRKteXx18m8DLOZ1PO0PC20KRwcep1QSmw8SH0NJ2Gj5L2WvkpfCSwKYJOJECrAc2PcI";
+ sout << "Bzf363+QgmckMylMuHJ0Kf6AAP2xFtKYw3qM7tTUjt14QkC1Bi0l7ViNGt1vvyTq8erWFJYveVwm";
+ sout << "lJQMmyvrHq5g+hbj9ftT8cPxbKJ/EPYGKrtrM8AGxTxrtA0FHqmzjPTm9rHXhGYfm3ALWw+frbVc";
+ sout << "d06D+iiyRXpTTDZoEcbMaZYCSyHI1QQmTR9Wfr64r6XSk+/jDLNlL/Syp/Hzs5dK3P2YR/UjTB3I";
+ sout << "8U70K/yeV0v9rDUI41COSK/akYmpybimqmexTop7zqNUIFOqFHbH7TCqRxfq76DFebeHbE1Nq7l9";
+ sout << "p8HRuOuavv9wKGBJx5BoYl0h+ujVfqWIJHas9QuCbJK8Z2eNht/j/FFSvdlPfWIrRu20Anay6EHN";
+ sout << "bEDl6/voathXXz8Rtc+/xkb4A2D0+WwXHrYh0onectahnSjI6sK35wtf8UF653KENWSOVFiFAwjr";
+ sout << "XHV4YZb+VWi5jQ7470jOdd+6eg9v47AymA+6Z6ZeTWazbxSgfV8hCsAepHyZ+U2Z0B1Zgeuc5rKp";
+ sout << "MhBQZ8xf5x+yM1G9Jtu4Z5kgQ8NdliztMLkquT+mbY9U4eri/WsAQa0BvC6hisueLg7IGqjQGgIJ";
+ sout << "29Y6iqD9fJZuGAgbEbW7qu7J3DyYzGaKyAx4+GOyctio9ChW1v4uq7gkc8XNpmYiGl/90mo20hzR";
+ sout << "vEXjcg3Hh/aHJlVXH0hJ6W4kOuBFkZxptoDFi4vKdSCVWAgnQgWhmhkeI492Y9bUaDTxejYhTfDX";
+ sout << "N08UiRXnAN96KK9iL1fJZxB5MUX7cJ/GvW7RjTVP1MNzy/FdlAN+k6Sb8vyst3Sz3eGHOgO1hpTq";
+ sout << "Wz9Moc4BVabx5NLlb74ScQHCHX3dnKXle6sPCzfbPjM2Z8rbm8NvB/VH1FZ/EzriephAUYdra8OV";
+ sout << "+f9V3o9zEYMGujKxqZysPfheqaA3gKKDVgKGhfUn51dEq5YC61aWqxjO1suUrW6rQpBHN/V/VadA";
+ sout << "hpuSHr2UHCeL8yUzC2nnXxpHCifGCrx1/R+A7EXBjTlrxd0n83pEV5paIK1AB8Vsv7RNrT0bEPxx";
+ sout << "uNxgTVdFUVHVLZf2zuhlqQCSr6zKzQgdlsW2Rbkzqm/RpdwRmONCWkS5IRTwdqLM3rfStz3zt3U2";
+ sout << "PklWaPeSMTUzREfa9xfaxA99Yn8QcQZU0GK+zEe0d/iyeDr+7XPYH6ZOgQbFp3BgHWk9tAgeVx3d";
+ sout << "hOwbut1u5WRa4azjGkYUUlvXLJpehNwKEiaej7u9jKrdQMTRi7gtSYliGdNr0nnw4ADV60LXD0Fi";
+ sout << "uYvPQwJnbICCYdy/Yx0X/HBu0dUS9D15JuHuT9z1D/dqirMKNiWhXlEzC99gX4/mGGM4Q6l53SQw";
+ sout << "0205XkJ21NasDxznXp8UE2e4GsI+N40b6mLvLZaauw9dB2IeNhwZvn7mkGDraXjDnpnnmX0iV0Y3";
+ sout << "OH6OCq5ZBlk9/IIVJpS6Xtec6MiuVBYrhKj2MVn0tI7rRalXXqk+7Plmg/9S/Vh6U5RM3pCjTTf4";
+ sout << "ZEAxFH7LhT/JRpkk7fPxcncaHQVrnrbM8xEYc9/2nptKYriJdXHXecsiH8xYjjzw8ACs/FeFyPsK";
+ sout << "84MjAvnsSWnPVpJe9n51BIQ68dUG87+igoKtNjWomVVSAMlBKtbiNrzFN+4AvESKCufoTC4gpGPD";
+ sout << "2sFkTO5yg8+v3rwFK4aAjBLA6eKzSGkcJqcThom+NfRjrvCtPysEisAdXHc40muhThMiIiOFXwbH";
+ sout << "T4722yaSHEMGNwDqY9dlixhGz8L5G/B3hbrlqCPe2+GHcmmMI58oFMgZhEh4u/1t5nkCDizZ1COg";
+ sout << "+CDWA4n4lcOXfVMwp/8u7GXZMMLuGSJUy38DgjLzy6f8bji0+EEZ8rbFsvMjssgcx25YJ9Jy23eq";
+ sout << "rcwGu8iMKUQ6t8J2sZFPXuFrn9cyFP/vP6qFNWZ1CtiN2fRVbqwtBj6DXmSae7NDMmL8XfGlTD5i";
+ sout << "aiIWK5SPqYzS0K6oPFagYWYjIKqkNRyIAseq0fqTLKnyKmv8oj///TlXGawbehr9clexoKbAM5Mv";
+ sout << "1RKu3/pJiPVwOPXcwCIlQ+BemtjmsnBzJySAkA1I9mpgkPm4Pg4qblWhgq/aVulwN5uSkJP4ZU6n";
+ sout << "reWj3khTO6Svsua3toMPYm6FOoioKOOrMMXQCoIjsybBmSaEHeAFxZmHODtbf/WrGL4YbhUMal5o";
+ sout << "90Ay+c0apxZIV+8d7Bm5L0dRXmQfSUCnP1CKyUEGPaMAPK+DYRZNU97lzpowFRdjTlrJX5aN3jrk";
+ sout << "WN8/xaU0PTVo1LtIesYvyG8n8gBSmj90hw5QLFt/94H1NDvPYMvw5A6oGoQUYdCXTXzrRJ2OEow1";
+ sout << "0/qEzLCWExSS9q8nbsf9Ne5ZC24XA9KHmiJot5v7MZ4KBORjPy/Ub223U3xuJqVGHlSjYuHqmOFC";
+ sout << "bhqGY4yNo/7PU1/cvg6LbtQSriMj8+85GQSheMBI73EnwIHgPOIV7EGEr/GHrfyM1jI763x7Pr0H";
+ sout << "Bcsg5o34glMy6YWSMLd16/Djq277MW4XiMoGPK7Rptwh9ahtgjS+SAl92VTVyx+kr7gZXfecCivl";
+ sout << "gXUmA8sQugTC3Zf7QqR5eJvYA/i+PMRWkLw4YCogDOYXK7tKY86djtBHftkYob6bMj8a7/1/XGzm";
+ sout << "Xz1OGinw9yYz7N19wgkVZEubmK08ZnKINN7A9IU98iAauGaLizHRlQ1y84Sz13coFKe9lRtgnNHr";
+ sout << "zsG5yqiTSIlH1jLyxVh1O94Qj6RlNT65LPJa714IW1ot0XsS8QgX7Jazdb8mqyiBky3/w+VoLWGx";
+ sout << "kplBp4VZN5gRvTIHRBq4LaoGgzc1C5us3loLsxNUDgKBCBjUkywWAD+l7sCB+2btJo2HAcmgbNBQ";
+ sout << "zITjO+dRirzIS6GerX9+zbzv4lC2d5K8ZmDUMHV2x2NKaXIU3mP6II/iBaP7W2VRx5qfAAChmMAP";
+ sout << "B6zDwzc7LNopXp/GgOKhYzjctxbbpOcMjP1gnQ/2F4EAyXOU+iMZeXqlmbJcXk9S+bCK/Q8i96kW";
+ sout << "k8lbmB/ecQxbYJKynBJoijaSzeRYL5wC/IZyqScchpePu8J84p7DEu4ECqd9+vhjp28/u6aCgnKe";
+ sout << "T0mE9aIOVS9DLgIURa5qmWUFvolbrZ698AwXlEV8vkMgSDZyxgRSBdyZcmQawO26rEVQWeAr7dw2";
+ sout << "zQIGGtM4wACd+Vv3e1FgjvkytZOwyo0NLlURyqdykYhZx+youz/Kmmri50XG/hpD1+5gXtVFWUhO";
+ sout << "V5G9Aw6GG3+trxtMJnjCdkaA7DTjTIU8nUjNnx3TsCqXK0OqDMvem3e2TTeXorRPe2zYbhYB5pVY";
+ sout << "lSJ5yYpUrIzJCb3xPGNxQP1qVv/tCEq7IIKGqPl2RS0NihXE1uz8xPMu3cxu1juHsw4Mx3qi2C3z";
+ sout << "KH4LvPoEbP7fO9uELQD6ipi+AmmB7bKabD2AwsHFurh8oFtaCt+KuVy+676ggvwGEFm0j28Cj/ff";
+ sout << "SQzCut2j0HQjychGdngbYIF3HLJQNr8rbh7k1FPpJhdpJc5lH4QXMddVxuydPbwQuB5zGOWZdc6y";
+ sout << "0O4zjCgP0ibSPyEdhkolB5T7Sm6ftL3dctp49kJKtIyGd5scyxMreHGJxeaGADdgqken2Ahx9izm";
+ sout << "VoVdjhj2tvLlJ5FbAlHqir5ZcZIu+prlMkAnW20mXjeWdh5lSCUFGcOTTyxirXkUfyDRFVwfMnOi";
+ sout << "Y8j4v6X/LyNIKgBPY5qFKfbwAjYWyAPYYQVBurLEu5gsYGSYhmtm6BGZ39gnxbBfhB/JkCyE68Pp";
+ sout << "wwAA3SA3Uxha5yMTYK75ClGNsWyI1AwnkfYuYgEyCbBv/psf+I/jSpT3yvcEDKiI7FwToriSgA61";
+ sout << "3ksBBfw9E7Y66wmwTHq9O4alqsltIAVLkj/Lp+DxSc6tZn5lI66aOj7paImjnZac/GoJAxVUfcKL";
+ sout << "q8pBZEwlJoAZb6T3wCuEdgZlYp3MZ5uInqAZSOsqQQ1/S1vCj5pwdzHrHF6Q0FJemS1AO2/GCk8p";
+ sout << "dEaniaohFtWT+AMdEQi5TLVmtpltYP8J1Sm+U0/TazrQStLUXm7Wpc2DkMdU3cZw+ncK5znAzzn4";
+ sout << "dQGBHVR7TMfIpx8anoWGhShI6oHh4zPXYtRtiV0Y514GWxAMB7oeLOjnJTWeiwO8VS9IPCjcKEbO";
+ sout << "zTpPAwMiBrZEnglnKmIGxQXRmGMpag1JZfPRj/XyucX+LreWYF+PabGLPrjSnG5e2D4aDlNx9wGH";
+ sout << "raYpvkG/LJJp81U+s3VsS2vUA6uWgz2XiHonASuttcV+a5vo5tcKUDXDaFckDzjpIJe2rn+Z1Xxo";
+ sout << "X8OomKyPCCGmMEmE3SxPkGCkTIDrbb5Pc9j0NDVXbRyjVZUf2hrZ7pR+H8+jzUVgt4IfnB8Q7/Yq";
+ sout << "uKW+YvA6w3hYv94BqXPHeO6FMLhH14iBLVLp1Yp2aOp1ehJUQ21OgrhNGjMmF04P/2EF8DXy2V9s";
+ sout << "h+ohPrdH13fyVrQWRUkG+7Fn+GAYcpD1K5jvRavlli9pHys/axwdu089ivNpA5D3D+t+SHVGEw/L";
+ sout << "ZcWpKRQrUZnOLTNjJEpEIAkryuxqTcGDzZlB6ngOVxK7oUiwzocC+zjKtmafaH1BqxhX3RxkXy/8";
+ sout << "sXA9pDQ2HIn82Y7vv1+9L+88KXcTLTIUSI6iwrL43h6IhM018jTNfCgdVx2Uav7Vax1kmu9O8mhT";
+ sout << "NOxDk0uYoMJfLPhr/nSWlKjgojoEj4IBmALtiMtqe8L4zNpHTRxgvLW8wMQAyj/woN1mLxLb7Lxr";
+ sout << "UIY+ECqQomA4TLPWXuTsLMJXQyFgd4nQgqVTL3PUdAtwes3xrHTODTYv9f5CLEiTBV3BP5V6lo5O";
+ sout << "Wj50oMM6yJWV4jZypwEp8osWHksLQjPEubu5CiBucOv3abc3Zvteygzp8J0la8wjOT3cd4+LHFwg";
+ sout << "hQjO7JuNXwM1xhgdostGxjYExHC6KB9HDG3OUQ0wtxsrh5QITF3hmlvkE//xrWsyDzTcbWCgG5Wo";
+ sout << "NRy0RpsXsMImYHlzO3hV5wQHz1Xjswc0ATMArZ+YJvXieYgEyhkIIzYlFlc1/GgnScyJs5Al2mvs";
+ sout << "5GFjwMfC7wnMNRmV97SuDzQjKOiwMFochKCR8pSbMzJ4+dW0W9LMvasWGVAnwvQdfIm+E4+O0W2N";
+ sout << "ArEB7utReECczQcBVWjdKqK6/N6QbXobZfvLto8AqTTw5CjOgTpoEybY4og+zNemINlmYgzuU7Jk";
+ sout << "IwiyzZLohI6UaTrPWd9Ck2lLn9kJf07TnWVw5+2gIqhr5M0d4Yz1Xlgy7EfAWMhPEu6MuD0WKKB9";
+ sout << "1MSe1zl2pbq14zDGyICSCXCLTTnLjc5Yj+E92s76igXq71sobtFNGzjQfZtK9VLTrAaeKf4Ql4jO";
+ sout << "KWQJTFYbvd1ScHzbPyF3S/T0CHPDZssQQTd/VBOuspVr48Vl2/MRdLpdzjd2b9gwCh2EwGeJIy3o";
+ sout << "R+AdmgHfA6PwhJKBkIliNMMR1DEhwN6QDY+b4GKnRHe3+k6tChrTNUqZpZ7Eyhret/Hz//LhmPn5";
+ sout << "EJ7PSeJR7K8PKE5WoCzr0h/o7SeiHi9qi+2STfd6J940/2D4HBkYvJOEH4Tu+1XBPi589dh6vnhP";
+ sout << "aQkicmf2fG4KuHO9f9Jlt21I4ksrEaibH+BKFFw2Rq2cHME8Bg2cG6HO+TyRwIROhy9yxtWQEYqY";
+ sout << "g+zU+DM/WCXBCNDCPqy2q8nY19izJrU3R/9ZCUF6Ji4GjEzjjb+mErVhPcWpaLk1FWAPxtC/A+By";
+ sout << "rEiAVG7+asI05YWpap2g7apUyGoyIdnHQ2g8QucKspJQhW5BiF4pi7Kh/zzPQPXkuxQ3biy1pOQ6";
+ sout << "/6EEXB6W8dLKHTHn3cXCphbbQzqnti+yQdGuvwDiC/GWZbB7ePUsUSJ3+b7gcR54EjMeIX6rBUxF";
+ sout << "2xXAs8El87oP8PCp023NStXDjpiRFCyQ9mYweS0bRh7MewCbvgDIGBuJfXL46eVRfj1sdwlnpS1w";
+ sout << "aN7ePzPyTR66zIu6kwZzfv5zw10c2LC9cj/WcujPSKir/nCMFaAUroinmmNiH8C4M/CL4mU7SBkZ";
+ sout << "/hidsIh9YDq7cQRP+vBwLgoVewGxnS2Q8yfZ0K9lAEsEo1Lpxg/Bhr6Ei21nnfaWcMQEVj3Gwpdb";
+ sout << "fqKrvOzQrUbQBADmmksC7FSgEBUTYze282+qXvS5jnhFg23iV6bafe7LOnZmqLivJGA6S0Y2D0nN";
+ sout << "HNP+0t/IutdM96z446RKbwvUrfNiPZF1CJWu+T7Gifwib9XOmMCvamYUqI+dplpbgCxmwbSnMwJM";
+ sout << "yK2++dILo2W4OfcPyM9eSBgY/fWKdcv330fbSE74lTUiInCJr0hECkFD6gThcDVRY1MJY1hX2Dak";
+ sout << "E+r/OKBGTM5i0AlXeZoi5+RyIiA6wRTKEzjznz3eHhAe8gRKj57DUbhXd0L7BK1ZPs/F4y3/GHtZ";
+ sout << "rfh8fPWKi3TX77vcIOANXRmGtpDm3CDxlhcWq/bWxEL2u6wk4Cqqn7xxYLNmvnZHB3GF8aFfPRdW";
+ sout << "iqdGfvdkoEWNNAd+OJd5dmu+kdqFXnoGtlItp5myxabhN/I5sdFxalICuAm+kAl0ocG98Q7v9YVI";
+ sout << "PaajznXTdN79Fn3cFSHlgrJ7yDb5T5mzpk6ud45ux/SxLw4ERY0jykYeAYV7NQp87gSxDfp/xyJu";
+ sout << "LKgD5aTz1RUrkC6VnCDdij8ZnqWx6FlSUdXt7+mDe7GAKFBHsKBUY0B7/nqnP+wS9GU7SqPuHa54";
+ sout << "fxgnq2fQi/VaMA4ZTugMOvCKP6PgHWuzhFOorZKnkN/Au+/iOH4z+1NvTzBiBjCH1ZPzXEU8pXMC";
+ sout << "d+XzQyzL1fd1xS5+FTOq/FV9fm5ihRm3f4qVxdJFmSnNKBdUt6+BnrDoReWLDOpNiOF9oOfADJMM";
+ sout << "c55TRsb7ikJpOUhXHQsldLRMCBQOB1YRiDjLgRXnPM2Kh1/YUFV7dOS/tLUj6Th4Tmtq833I1AWD";
+ sout << "CLADK2WJXlXL5s9HzGBwR4eVwzuvQsTfIFZmbnsjb+pIDR1Ek7JH3eyAEwy3WWATtUfSC0HXjEgo";
+ sout << "lKpcSGD6BZsJ7IkSo7Yf/pVcQbpqOV0W0QyGcnu9IUu8q1O8iF7Nu5LNEFWBm+hpN/sUgzE+XIBn";
+ sout << "3sInc0pG0ehnOT3HznAl24ljHBudSqVqoT/O/K2aswIhxOl6RHw20PypAGvTEDQThB3prX1RuTeg";
+ sout << "kmQnQir2ffAp7pxEtxvQJ1kjA7Xp88ZdE+xOFX/8yvCnCWh7jzidGLOZrh6T/OlPIhd0Ipj5GiM/";
+ sout << "CZqxIuPwEyKbZJ612Ny57dLmPK2J/J5AuWaX/wmMlItloLwNSvTwQ2tSiLUraT98oIVqhu/GSRs5";
+ sout << "pRbWOeTspkUZVC1306DklVBAKqWSDalxyilnxToN3l1+DG+UgJtSevkNqlDpNfH9LqrcRrfmokTW";
+ sout << "cymYpib926UvmWf9lsMzJfBqFg3QU5Sa/eF+BwdMgB8Cno/p7FsZ3hKytbOM5fnPzm2MYR7Kw5Mp";
+ sout << "xPcptC3Al6ammAMjlp51IMU/1wyfDEZslTCZrBCGSeYBke5D9Qz31gl1Oyp+RV2lxhNngKZcBKez";
+ sout << "HjNzzJHm/PvqfxyAILoth+5Cb5LxcBp49XjcuPhZ3sFIfNUs8IhjnT4RUZxJHs0PMFHwx9NwgryB";
+ sout << "4UqLFn0+Xves6a9xW6mIzgSJZXxgKcWmA4MZFR1welVzD7l0Y6Rl+Wjw+qrklobDSNgntToguXwR";
+ sout << "mrSM3jnAYy00HnoxJVcDFZ47Fa1aYpZvAlZHbgqJ42LrR1KvI4Qfe+cJVyV2iq3FNB+4aO6DUVkn";
+ sout << "2dBbf2jmR8fNIjz4XkSCXPpwx9OZuYB2C6JjUkQ9lGfDGPfFcDNA0mjkX/FmKf67UaGkJsG+DV2m";
+ sout << "sxojbh+zm48COhWz0xrBGSdj0viVa10iqcKs4+izaPgDSHMFdez7nT1aIBqf/ys792hN9h8kybET";
+ sout << "GZw6HGYgD0u5+hbXeXVC11f+8aUi2vjGW9gc+R0cAnAjcLNBH6IPIhAzC3E3IJ3R37wqQ5huxyrm";
+ sout << "qZ5AWYXo3PRswYjxwlXvUSwEbc+GxcFR+jirERKaqgwoVOw/2F2WHFQ5yQ5R0nI6uUhNMbTjlF9x";
+ sout << "2MYa3H19/dgtY4UC496uQ0uVJ06wlVRCL1SCRpSn8y9IIvAEpkhX1b2o8NAMxR1duq4U0fUhCDo7";
+ sout << "GtzXFiUKyLgKT6tEZdww8+FztcaEB5GmoIxy25ReE9Y2Yz1w/IjIvyOQvqDxqNMC64ETTqYYlF7o";
+ sout << "oOGEBqO3JnTVdOtwbv18JsefEaqWs6hyUboe3zxvZqnrmSrF2Ezydw6jFhjHyTdI2rRNK7mTCg2i";
+ sout << "/1fxkFC4Rlw3U4NctvjVXqhVG7/NafRynbtBxSVq17MH1Sz3IRDsKS1/HtORRuAYX5/KYcgDjNyL";
+ sout << "ew/mVxTedPDQ+NDsj/5k4gjpHB28TQrYD4R8DheQ9liqV7R+shaUz98cdqoamEcpi+Vvo9eftqpo";
+ sout << "/tfL8h88W/BFEBPqG5u0CrSasq88l2EWUflOfdHaVjPiUpA03DhJ6d7WEXIgKqR4KQ2QUFqhft1V";
+ sout << "x6XQ5c6HvkYEA3JXB9WY7pqwhh807ucN0Kp30K87OElOIDXe7FWotG1Nu0IjnhIK93NW9W4g9CRw";
+ sout << "jDvID7dpdENEY4wzLG3Q+PX6srxE/bjKkWXRrIVe4egRq1yc1BdTPd8SuxHKgVvQYGNlDHZww/Cm";
+ sout << "sdtNn9gO2AN+zouZkXsG5JdkJVoGZq7SoLF1JNiBi0asFYTD3y6q6wvfonilbZSnc6V70VVmjY53";
+ sout << "NUCkcTs9H8acP44DwzIAMSm+IWWynZnEmul7B9+ViqgqsHW63Q7AdBCd1O9oqYaLKa16iFiHXAGS";
+ sout << "AAwYVBuirfpq/yGh2obX9w5LuIfc1Ohmm22xT1tL5Tbx1y6rOFL/LirpgtqOKACxNbxiCzEvUEpB";
+ sout << "XHMNDcX/fRaVFYss25guCVEj9/YsVcDz+lf0+qb88TM8LGqYCb0A2lVHbL1O3jdvBtfw3MXHX+3f";
+ sout << "UqH4kjiyAkgn/TLXzlbEjkGu9EdImol1ikOMjffpoXri7J3Rakpq19cJcXZEfN8Nl/J0IxjghnQc";
+ sout << "/Ge2ls1WqEwMUrKyQfRwoJL8Pe/8JErvyUuIIIjG1PHDPAUDmEHVXWU7wDms6GPAxh5hWM9Gwghx";
+ sout << "xDxn6q999jMjBmut7Xvl2yozhYpzs0KCNYzpFwSOyIjfwU/S1SvpXbV/fOlzGLqM8uDMveoDeLTG";
+ sout << "5LmSt1eI2ZMnKoDquGtw9Up/Wrj191RIHkurW0RubmX95xu9KJaLOSQ07cDP/FDiME9LyrHlT1e7";
+ sout << "4DkaBHyVw9fNhbeEWjF9dCOT/IQz3dQhiNizidYZgJX3d/coX3xjbBKhI18DAtQPLdx3EP6IFcPe";
+ sout << "AiU6yzb6lcthTbZ+DAntjRMbEh4sCd/OtjC1HHKm0foDGYcpsN6R3pxA2fJ13yU0ZPLYSJT+3hCE";
+ sout << "vyjvV1go70WzRw5TJ02F1ROusAkZnL+LVJj5cOIKQ+MNifUgd+jtwdiUhpFs4HWEnDxq2tWLNf5d";
+ sout << "xQ3igiJjm1MTfsrex6ehzwmyueSZKxTBgVRQdLTcSSSiivdx0zO+/TqN/0s5WkgCufzMm6uMf/1o";
+ sout << "eWPkERpbGs8c740XtOwHi/3JxDC135mV1Gfau8qPUA4VRqfTKMKwcCk2j7SYzx0RaQvaNDa94aNd";
+ sout << "a8gVcQhLQcStjdUOR8smPQtqtEoIsHqLLXDZF/3gWIEjQfYjbchQcscXEi5iM4jMQCP1/Gf/oR2s";
+ sout << "BY/eEjhzUNnDxyikEt+idxkKLnuRYJoRVTzfCk3Gdlc556Vz14By+uKeAmmgVVH2+U9WtTMOO7J6";
+ sout << "vlpJ3lMEQ9ANonhWe788Jz9pzjzIHhx2UfmZM20I/t4U90+iTJ29/AoDc7g73bHfzugK+2WpYocd";
+ sout << "ecImcUaQLdMqL8gyK882q967InpmujoXoDo4hrqpaLsowHhWiAbcPTIH6hrvkNYJ++PYo3QGCYwf";
+ sout << "z7z1vc6pUEuKEx02TvYO3PX6e/i7RVBQ5HBjVsuNiAw9ai0fsc9sKjwOml7//NR/o6hdFvNnn7wE";
+ sout << "Qk7pP/kasS/nVGYvsxAcqQBBXRDfDEB1mpkZX2bcNzZ/PrP5t9aS9t7LQYfqA7lWVgY+brfPTdmE";
+ sout << "haXezSHgRgy8iQjBAVdBl+Y972kmR4zZuW/SiK3mkOe92GD0dVXz6l4/HL5VFflONDAHCpM8s0zR";
+ sout << "HIlzwxAHl2wwiKqW2LtUIfnHA8Tb36Uix1nyKcGiYjHtv3+mc13GzYU4KZ4sj4suUs6Yr/XEtAXm";
+ sout << "WYO9T8LJ3zzbxMNJZgsVHWx1UWKixJwGCv1HrEAzQNmm5aDBGYrutpNtcC8B2DKIBex+IGlpTwrJ";
+ sout << "Avo0T5uJCfw3hCNwVpoSvpaOpP/hFVJYI5RLtDQbgshmFI4iFz4pIN55qCA98tYAg7a1m2g4CH0l";
+ sout << "1R1ErZK5A08UpTFnxKWS4pVB1XAy7hqtEd8YTADZvNpFTkocwBFsYSsDU92hTS1GzEQ3+NPIj+pg";
+ sout << "i9G8bDLaAY/bojHUktrnaYL47+BmrKQ7N9YWtFsjFrOaS39NIw6WUf4nPd5uxcoqlaM59SrOyEBT";
+ sout << "r4emeEzPDzF7GJd69qFOek/DrjX87M0rDcyCOKABNJfIBaeEJKKnyniWOx+ZN4fb9meuM5bKTqr5";
+ sout << "p6JLFbf7WFzqgs3h7tOcMhSvpVv4ve5Ap4yx4w8ftImJrW8f2j/OUiX6jtA9shoYI+4DOqYL1mRS";
+ sout << "xYK5WKMnkftUB6lK1ldv0ogi8Z2Niil3ZBpWqhY4+tuEvTLAT8nQtX+p9/gtHxNCyGUV9nQk/T/N";
+ sout << "lN/kBTdZs+KzJ0K+AIi7K1t4dVKggZ4vmqYVPj2WTKHboSzNg+oxAL8VmZAxnY2bS5AFR4wRSLgF";
+ sout << "8yQj9xfyDKTPRK4D7b+Axh0s/ra4XzE0S+M2wk0cmRbb0I8/pG7yXSFzYPM0bVhRuiTVT+mrxcpA";
+ sout << "qSJRKzFQ6h+KQLSYUVAspQeGnIQX70fF1/BcYviKGzwiN+T6zLH+thLJnZX/lgHwZZOtHYnCaBBc";
+ sout << "t1WFi3G/Ex2+UHo+nPBCFPq9TNEWcAArNxPwPb9e2H+ZH7dNgLZjZtGQoEiEI2uwI5l7/X1++/3N";
+ sout << "EWAKPziqUxX1sZeW/T9cLMiDsVajYh6wqn11+RM+lCUKvYZym15JcEKAxxsFxo3S3gbILFXSy6IL";
+ sout << "AkGGUsKtITt1LB3PO3sIare5OP3YAk+PzEVZcwp80V9dHq72cSskofr8hvUkzsBZ/iH9dC1dupTF";
+ sout << "UbI4YOrA0nQRnvx9AQVrPldyzgxv9y7mVakcgLLpd1YOUDS7We9Wh7exElCVuB73RtXfcd7hEjRb";
+ sout << "fu0xg2KT3wOTQy7WDAnidLZLtWeKRqDLxmO/W7370N1TgXS2sKLG6lGk6UY5mEOmoLsBtu6G2bRj";
+ sout << "QIZnC0q9x7sUVMyg0TOZ3698MgxWXpvNqocP7+TYdJeptTkg0RZdgON1WLlCs/yBSm9DYpN6u1ak";
+ sout << "pr7GDx/Yira5gaDGvnLtoL/6L4/0bUH3UJmeNACcLec9V+qxUJnYjQPYPjGjALRfRNPV8YRsjWBn";
+ sout << "eEP8OJG4i1dQAif8ehmV1Q++qHcmf5je2IQwcsrgj3+UIw6urMoRWVwDISw5Xy5Y1HwitWZ6VW2Y";
+ sout << "eeeLRYBIWAsoJAw7rSyLqKjaihO8rlGw3IeVrAHt3k12wX2t2f7eLF/MKv/8JjK7d/rr7EbNKrGA";
+ sout << "iyp1zQFtenTYrjr1/Wh12/lmTnipKS8ZDCJuMM0i8NTXB1fuisqNI35HCuXXoJQmc0Yfu0jA99jL";
+ sout << "T87FUwkmKQA5+Q0HV2C70CfE01ALVU9ZbdP80mvHlMhiLURPjRmYjXl++FKpHU4wvRmflXG2K7KD";
+ sout << "aKccgkYqc+Pw5T2AGX5LiJscIykcNaMLVtH/QYPasG8jfgVH3dENDXagJjuYDwkROZ7Dm36vvMZG";
+ sout << "QkRDHXApzY0T9GizL6V9WWflM0lopSWELw4HNPExg+ZMSE1eAebW58HHTDZEnu/Yrr/itfWUid8t";
+ sout << "od265bhRmkCVeEwIL6Kt7DBdDVDu268mwE/OlCRmE0Z1XHDIR8ggVbDjoQHV8Fzr9YCM/36qexHL";
+ sout << "v5VxMx6PzZGBMIGuPHIiqnqaDPHhVpMGbD6Xowd50iE82GANJrgulQOlIg66T5znMtXg6g6hOqh4";
+ sout << "WL4qrPm7T/2mH+T5PVIR/D+VJSoEtnRLdnO6zZnIZp7tj3jXjhNSo+YGzZmdp/H1iuSqC6GuWKue";
+ sout << "agUBFs+p5Zz/yMrYI31yx1URhmsLJrIQU8llICpkRo3uqhPzpXcZs/MZIS5pWniqIfGnqUtzeusS";
+ sout << "bouAVAOgf2rModIu5NWIwdnf+FEsDsw4sYDEI0jkJD8StwBgTg3uiq1jV3FP8nHWgu/QoWWXGtpx";
+ sout << "u7OR7HYi+Aq4lK6sVu2qld+deujeO27r4atK+wQfmbuPyjJugTMJEdLfT2SkkW9mN2Qx4jPmABq+";
+ sout << "5tCUgibAdWz5SqJBDr/jxW8WZ5YCfAhLifSbTdcM/IUF5Nq+NCT9vrcWaUak2K2uhfvoJjN7DWVW";
+ sout << "6KZMzeN5XIRKMZc2ItCW7sUmAy120f3LWzGZtlX49LIY5nsawUDOrPIENNab3CWEPAuKAciOsFU7";
+ sout << "CV6BvOxRekHvH/6clgwrfSCGgJsru77XV43H1/wZciyR/4QDMXPaewDGxdH1FNddnkZVAuYWrRzq";
+ sout << "PF763f2SQmwM5GERk2bxwr4eeeBWBRqMCK4ZNWpdYBseupT694Rrk/bnUKjuo3kwYrs2LzRXDJVz";
+ sout << "Fvj0UjorGcr63LSXNyLhyBGo/Uhn7dgtVUaYzsY5rChqvLthaBhab+268cfGyYhacN/E3vezmySL";
+ sout << "J1Bj4twQ+kcbWWQE/Jx5zGkgorM3Yu1vbjilzv/UbmNb+Cul6MgnKDRFS+cGQ5UMw2DD9OyTKJZ4";
+ sout << "5k+BcnLih6kE/s4UmFkeC4QRqLMBBhTlScS3wXS2lFsvMd/BXNXWtyFtk+dUP0AH5tgrRsaY2/h8";
+ sout << "xTgN0sMK70yvSzojbnExkON5mDszehIEHnfBlL14eqImQRiUKlGQlP55T2AKSYBbpKFJtzy+/aXp";
+ sout << "SxmkYMQcvFD+AH0oaqJFaB9JBLT7JW5D4ZwTHpMX7AS6W0LxgriTgPZnSqVgosl2DZLKBUt1zXjw";
+ sout << "dbSmrROYcUAXJxMOAj6jNlv7nMkn2xV28QS7lFE+aq5yBIX+UBTF1uqk1XI6OddMKrtzh0SH9Dey";
+ sout << "q2NWoOiH5YIxHHbWx2c4fiFrSCo9v2JQlk/rv0GKz5jX4GEI9+UPWzr4MA6GzJrYS9cydN/NXwY8";
+ sout << "lvHx0e+zI0aiQkbZb3hNnWsjQ5q2gk2pj4yT4zCFhdL9NBTAwWrWqjSi/Zf8BU5R23WHHMMPBfEj";
+ sout << "1gX3IdAZXw6tNUseFnBithp1Zo52/Rv+OUyNIC5o9whYx7T+e1Yyr2ytlKeiAHOWmiP8h+Mqnb0l";
+ sout << "BXYuP1fiYcGClI5Krh/05hb4CCW0n385qHMIGDiFTw1emVktV9OTiM0Z5M8PUHDi7IGUyrJcwXyr";
+ sout << "Bjh5HwdrsCA7NdLqE+8QPYnAPj+rMJWw+xrrlcPSpRBTTPTPiVnn/PXL6khN9lBiwz6Kn+X61bhO";
+ sout << "8yN+OP1IWu/DZUzauOoVqhEPEARq316Uj15VXoPpy2oa9/L0Pqzna3d+8/VxnSlWWo7i7bWEXbc5";
+ sout << "N9OKxooZQ6XAFihvJMIsVvkYp0oAQ4AIoOLUSm0w0ejFIWsQ/goNmLj+CQneb3+YUpdYtpLODhln";
+ sout << "/4HgBsPjGnEug/fzuWbEqWXb/e+VL8RUX/r9nXe2rGeQtdQkdqsNj8avFShnnth+UNk/2TRAq9Cd";
+ sout << "eDYJMWOQgP9NSecba1B8FGmBdvC6gRsp06uKfGesrktOPlzgKmZPoqsjHhrLLV58KwTavZSoIriF";
+ sout << "Mzj6ZAalNMDEqUGuhtXZrhTVoRiYsy+3jjJ8IubSS/RxiQfPS3+3nudS8jzA5u4uHJcy42t04z/l";
+ sout << "NkN5/kjem3F1fAu79/16XeDD/wk9A7FyjqsnyvNXsiX1yKcEC1Bw/yxCh33mgcAlh5ilP+gWc7sg";
+ sout << "QQXi/dnqQebktoqyUYklPN/cYQpCQhhY+YRp1wwbLJYWYrBXiN3TbACI1koRROZw8pT8YKG4utqc";
+ sout << "P+IjOrAVYg94bzGeuRVI/T+mpmGHs+lsgUY/WLx52PqnmLgBORLHZzM1sFp50ALnbYah3gAR19rz";
+ sout << "twGYVQXIIhK8rewz/RwJJYwSvilMiJNx5UqJc8VFaEnwKcr2QTqlJ9FezDpuE5kvqlYl4tywsnwU";
+ sout << "ta1MB0FXAC86i+4pXWfaRt5Y1g48PGEI3TdjkpiRJwEQ6YGSoDngFX5K5l+rZePE6FHNLBoRcIp1";
+ sout << "5aNK7cS3MPyb1MN8zhjX400wFp208a28cBpCH0KmHvt8+ip26ejS1C1/aLvdz9j9sQtixw7JU4Jb";
+ sout << "ytPOVlxKjAzFFFdaQnVEAkFCKetYCkMzrARSzcLQJcBhhChb8NQDZoQWxe2286/y9p1ImU+jHlzC";
+ sout << "iTREz2ADRBGQO3JGz2h5GYjU6YjZn6ojyUyMECTZASV8RbvNy0/8yYPNONYWm8v4/LLOUM0fq8iv";
+ sout << "xwag9U6L0Y+TWtW1YD13ThbTjTrg1lTHKzyp6KeGWNWGH141FgXz9+r04Qkvulcn+rGimg4SuvNJ";
+ sout << "XLotrbWuhULd/esyJIAqfTtHzHZMSBoQ0WauyWV+so3c2rGDjo197k87SJkfGp+l78kqj03jh6ev";
+ sout << "WyjhQZizdnuNdBpm26GnuDYmZdEaOAADCtkR9dvMWdeNsDCjyRcCbFEPiUX2xduSbw6aulm6mW4l";
+ sout << "CckSOv3l0/pfw8bTKEszFqeXiPEI2SwIbSx9R1nWd62VVDrYmAkAXzdJ33QOlJm6YJT97+QuU+61";
+ sout << "RVXpRanpmC4Eh3jEZKV2BWIuGPXzrDGgAHoHNhRG6vJ0Kp3RqaxFv3cPrLZdxDnQhZQH01O/dZ8S";
+ sout << "swcpb3/D5GdWvUS6MsIPhO+nhoabkyxwvbBNzyaLVDoXZK6taVT2GWFyYqguu3ubVBGd0H0Mksh9";
+ sout << "luAh4ogG6drswOFxJhR4S8VIMkF3x8PDUj5w8WTC2nlhxIYVh1fxvdj8fEfIH7WzBElS8vpZ1tXC";
+ sout << "1JSgHngiEYafDFX4OJeHxKFu994g18YvLkJDdQKzTDvOiD9TaKnnaSGGXZuasKTtD9z87GX4O84P";
+ sout << "rye5DvbYUP3bZx8GOmhf5YsaLI6b5m+iX0Mm8ulB0hOla8uE6EfLBovorLCKke5iPFU/gynxeCkR";
+ sout << "8RL9zw58ATEm3YuzE6nDNaIekqV7Q0MqCpnpAdBPxjtEGo9yRRPeu8SSSGL7IGKUUcr/6Xp8ooyK";
+ sout << "iK4f+mB0hgO0pQqffQHkEXxGg7Zd/eTGAO/n3Acq6bjGr03T6LEO0KFJl1/m96ZEBQK/iiV4m5+V";
+ sout << "bZvr6xOrErUE+ih4g6vk30cSQqFvyH574K93bx/uyydqzyPEQTZ0oJmT/KoUrYti7CxlGwrqPDTZ";
+ sout << "AuEOwQGqqxoW097Ql1bpxekOevJPSL21snAT+Lf7JZ/79YwZu2WZnWoDQRNNm4nwn5M6IPmErESJ";
+ sout << "szsF5VmUs71cj2/gnK7NDU4c1kLjXSsVVnN3/VL1mOnYV25Nh/ktPRfTLn1TDNG+rkdJDjGcjjSd";
+ sout << "JY9Ro1Rab18UKm9nE2tWjgbNUQbzsisy9P4F7cBLBd5K6y47wRYZ7MamnAUZMUEDmIiO1SR1uLey";
+ sout << "gugfd3Sb93PYjImZCjulVD/w64IThgqDIgVd1BOdAN23laLSQpfuR0xW5k6GqGtiP98VwULbV9rJ";
+ sout << "thalp5t6mwsOdo46101N512/p70XcZuX/VyFxw3bGO597RbO2nj+gM4UanMOfTEQch/kaJsk/WTu";
+ sout << "m4QMSxd9SAn9/aRXKR55Im98Fx/m9q2GB0vn8RZqfqHAqjmO6F3usDuUFxk0kn5Hq+8jXBephw4/";
+ sout << "YCOP8eBKnY9V2k2bnQO/9BWhCyuwun8huNz6keaC+qa9PFFwXmAf54BIYdgr+NspegaRej7bOISM";
+ sout << "WpHZ77YuKL04uozXUz6B+0Jy/zDzI9EgLTCk9L1CUpGPSoNGpDfb9VqkOdTvjQphSOPB5PxD7NNC";
+ sout << "tV4tl1nhB46JCCEalhhsip0/ZEtx0Hd+UNxYRA+qjaGiEXDPPDyjbYYHWNRhYKo1SDJLVZG8hoVh";
+ sout << "stl5hebAI8t8rCBUF3x1HrMRdrRX/GEzLfx8xZgk/YAegRw68TfN9V1bY63aRS712i/twY6KDxD0";
+ sout << "V0DeubXqkK+N1ERAn8ygvQ7EJNvGbp0OTu765R4zv8y9u+GqdmpA5E1Ti6+2l/L6008zbYQVWG6z";
+ sout << "ejkRAp4UjMSDTf2zTNjP76zzQQOQwsmwodggF1paT6UlXpb5HHp9osSrdwGXb06FiXf2uVo1rtSh";
+ sout << "Stzfee1i+BebVOqKM38qGjBhXwGiiiZh/mwuG4Fa4wOKlrqaTVEHgGm5N0uxr9by1E78BXj7Nha7";
+ sout << "XuDtqivJf0brsiAUNNtfUjBT5WJrj+iYT16Zs7ds6Fm60mdF7vSBrxXPNw7ZuueSA5xuN2bkeP/m";
+ sout << "CQ/3N3lGgTQlMNkMUzgY37ln3xYCXGcMpQ67ACMf4gf3l6FhI0S+JMwzxD4XMZBTxgt/2YX410Ia";
+ sout << "ZijtokULoOBQRqLjxREcBBhm+tpSSU5cRb071EVhz4K+aSN7PfcFOTMC1mVGy4cibxeRIq5Un1o0";
+ sout << "NYSkHLiVCA4OCLkemWvzxswd9Fduc9mK11xyqBPu688yHH9y2e4Fo5YjiKiZynd/kydWuNsYM9i0";
+ sout << "U9ukHpMFmrDmIPZJMGDgZm0I2QPvoieRcpitQwNYh5x3cvsEs3eAAc4mfxQZUR+v0gClhfu7wSn2";
+ sout << "XBPPNVyzE16RK8q96Z+w1paaFKHR2t4TRGhmeCcxR3vS+jvqpGaYFJ+YJe8hvJYK0IMZTBSPg6Qx";
+ sout << "Cx/cwlr7gStMvrFBKjLGDMAsrw0OyDbIe6/E3Erq43GHsAplYSxYAvOFZUwX0b6VkfAB8CsBnRFb";
+ sout << "jnZmZP+cZPIySLLGJJm/oIs/Hy1wYuyn2XpI/uCuBDZifE6fnNzdpzQ0BpH3qjaIF4OV4gghhruD";
+ sout << "toj7yNaS1FZl8IhQ59Z7P1eSvQNulTHYsfBw6E8RE780AO+aRHAxVtuS08RWGrde4wV3ma1NzMaT";
+ sout << "6fHsCr9IYY1qACpBtyGz389+cnXpeN2VZczNw7JBx3VthteVNdpX7cpftA0e/mazMuKLX0vcBfus";
+ sout << "/p3Vj1SxPevb5pslgkfBTVcFxmEM9Doy5dSKQ32SVFgYwTppEuyviNKuUzQWwO/XhsL50gcuUhyU";
+ sout << "ELpBLIkpABrbGRa9K9ye5qLWYYzzBzlPoyhEwt8cWi4hVnafg+RHEJ7lXXQ5GlEMlzxPJw18LwZs";
+ sout << "jjbC/PE4abuILrfHlOygD1CqQeYzzBN+gwwaTbmquBtZyIMvGr5FrRqz6sRP/W2OvOBN5iVSUa50";
+ sout << "VjKdnNUwSkeLCU14Lp1zN0WQQVVNeR3o1JfJsGvQNjD74MnSlHTFSUYloSFgHbM2ANeKgiYN3uB+";
+ sout << "xLb+hRJPxIwiTshZyclnLmv5mh63OWU+7afdA07BMdmnK77uFB1rdXFJIWUBiUG8u2yqUk3PKVg6";
+ sout << "9rJYvVu3eA7mjBqRvL//koo7VJt6g1Hj+8lnqu71OwsodU/t+W4tmcIKP9S0eafHEQC6hYCu8b5v";
+ sout << "QgVuWf1LnTXwTGbm1ohar6OTxyZV07nJGC0xscQ2yzhxnssaCoAw5xFUzFrgo8a0J1Nj9aOaV40C";
+ sout << "3y5lP6NgiO+3r2J8wUcKWEVFLHDAYMXiMfe3XMp9/ERswAGSkzfXSxZN8T4RPorh+pjPH7sZJExm";
+ sout << "o4ZOdqaWLdLS03+0kRuPXm6FvZCaMY7Hz0e2Zv9X22KYmXluGqcgErFTjYWhQFhcpt8oZA5T3mU5";
+ sout << "NDE3qI+bgtkN9vvNbhIrrOH7mKm0af/zrHA8nnU33sc6Gnv0wfCvolpiyTqla+HU7rwapX0+Ge6j";
+ sout << "SZZht4l8OPe3/mSmwGjX9U8xWxiIhhr9VbVxZZ9UTnAhfcceB3TrDA7F2SxGmfBSRvQ50uRp9Yi4";
+ sout << "14Sasx4kgF6e2wJh7TBRfDMQM6dew3rmIn1p7l8f43AcBe7T3rnWB6UJ5eh14+xb6HwyT260Njvy";
+ sout << "vF8Z4nTq/1MCLv3c4lQoO8yuhFA+J3rWlYTQmc2ilfWbYpdiAOxT7k44QYoHTZDS58anIywl9J8Z";
+ sout << "uBYa+giIPDINPDLGNobNEPl3kerfygAFfVkxd33eO9LwDAMGAZmbywt/OwRmraCFhT8Qz9jygFse";
+ sout << "mT1tZ7+hF2rfTvyF4OiTz2nutfzVvYQ45DLjmbHFQsHMVGzU6xdLQlZYWq+h5tfVng4zhwIge8tB";
+ sout << "mA9ndgrTZx5tspm45fVKTjE3MYL5qNsmzghyfasarz1MOD/2wml6uzWZjtljHIzap5gtIchaIF3f";
+ sout << "JMOvFu8Y4R0Fu2UNkbOfy/YLRnKlUHEChwTeEavSkaHBkprvRZJuMsIWcvEZHDIWdRte0qfNm9V9";
+ sout << "zWmj4B3RchqvXh5XsoDkMqergDisI4gCiHFIxzzfJBsX8TQDUQIvAGk0pc3ib1/6xwCvLs+NXC+8";
+ sout << "0sI9e/aruBe9uEXGya5n6Vqy8z9soOoBCcFqc4B+1ZUOyN2zpIxVFd9/20NKmp1W8RjRforBvF7F";
+ sout << "2E4IkWh/3RcL8tNyPpWkAA7Snlm8oO8L5cc+lMx0ph7DzYX5NcIc6V7otI+Bd7w8xLSVxzfX3Awu";
+ sout << "S6bjrU2af39KbPFq2LpGOakR+1gAh15sd6WYscozdkSZu7kGpuEppnuMxYFmbMrCxtw2lDt4IwMY";
+ sout << "ttBHEP2XwF5z6DQpRqsWBMcuBbztZ9BNa3y2aR8pASldncrnRjYQ1Iq6KGrHeau/uq80ui3VPbOD";
+ sout << "IrV/JeNFGet12rX0zneHBGYUxG+Pbso07l1fJxd6OgaIrwVksovd1psBmD6qYsbc3V3dJLacNOdt";
+ sout << "fvFnHWz6EDWivzMDWvumzAO1JPVFB1TebaaWaINSxaxS3h3IKUVPTu+2Ytw4Plwt6fc/LakedkFH";
+ sout << "fxE4uN3LR8Y1bKQN+/VvTHXGUtf+UyZx/VWHaxuiqoxXUazpT+l2muie+/2l6v8JUCBTUYwa+2W7";
+ sout << "WWvrBxyliD2V+KaN4dybWey+3RIp68UgV0gP8h579kezx0+Q2Ku60vjCIKJ4vJB8Gt06a15lHOmx";
+ sout << "WcJq2DmQuyWTXsqaJPuljuHUhqYWMPnqY9gi53ddw9SwcpgE8cCC0cG65FKn23C1Ihh5NnuBohYp";
+ sout << "ly/QVOFi/rGW9kKrF9FW7C93p2DxgBhygKC5E582EIjRXm15bQ5bUSwxfYfQ8VDCPlL1cR4yhpqf";
+ sout << "CpAd5/ls0MJ3SFgUS4e96011A5intr3a3UMPqdMKWrwCW6cP+7TJKC22+LwvAvyS2dI+NdyLt0qY";
+ sout << "o2t9/ml4u0bbBjwl24o/CIvm7T+045fz+3h4yrKYYPSbHPgGiN0+oTcGhtjBgtil6WCBMuzgEn4Y";
+ sout << "zVaTD4gQV2l6Cnpi6fi3ZIOwuEi8aqfRdNrPc8+Rtik2wWdmbYZ66d52HweeALKwfzIm1DESb/Dd";
+ sout << "qc+pz+yVNmYp5fffsjf70WrDyGWMDqSVvj3tXMkxKiKRetutfs0Tf3vGDdUEdTmtw4rEca/1K8p1";
+ sout << "eNpotuLSio7xp4/gz4cmq3IUEXzk7b++TAs6GAq5ddjbcTJoyRSnOhu2oetJGzWJSglrWwGpAAoC";
+ sout << "l5jHFfMXIyi4lggUHjIcpEosLLxntEai5cP9bcGzTw3bnnQMo2mI8bCKSrw3uzKapI6MdDJ7fXZD";
+ sout << "aR+zuSlVjO98VU+BugUHWJcrHneh7tRZ5ijux8t2d6y/iGeJm0gs2C6ukA9PTwSV7UxnmVOCsGqp";
+ sout << "kh37fINlLQWIjS+fVQZwYcm5XyzPzsyHa/x1DZ9CJgSaLrOCm8PDPk2BHFSdoYMrUGn9CSL6zRNN";
+ sout << "TL8KprKSZLrhrk4wtgRrTAkO9twMovO8Pjq/xgO6djfJygflU/Yn3VPYjg31O0SQLXGQLTL2SMcf";
+ sout << "NlIqmLEov/m6A5cs5pJF7be1yX0XneDEa7arl7v4tcvMw9Ot57OytQUTbFu2iUb6zviTVWMEMJST";
+ sout << "eLj7c/+RNy+lhBjlxG756YIjtWULrrdPBkie+AYHMSJ7PnJ/QAl3D126W7D9ubCjagGVZ2cLX/ft";
+ sout << "/pLcIuPRcRUeL/jTGshsfijbq8us89JPjLS08HMiaaPhTtSOKbyqHwVEvfu++thQZsu/3wM8f2a8";
+ sout << "CnKdaRAkLethTUyv2xNyQmx36Qg3LVGf5W/tzC9U2AQ1XRE5l/IoL1AJ08JdQDqHSAEAiw0LWMIW";
+ sout << "a2alrr54YOp0qiPcRTLKo0gZLr6Sy4jKmQqEFrPCxqPSGMe2pc/pzVuxlYhW9fDMjE1vvfzmWnhU";
+ sout << "SSTXpxONjtMDtFWol0pYyjOLyJcATXCrOK3PXPX50e+S2Xjl1iJezlgtdqWVR6mK8yJ63GyAtjzT";
+ sout << "oBIBQjydPb9tpuVY3Qy129f/Pnw/MT/6wggYCZDuhhNmOqmSRbm3runLfwSX3G8XbN8LEw94wy+n";
+ sout << "yMs17FWzTwWk2fMEIV88TuqsxDPW0Ko+a+OuiOQ+7TGVdOLxtqmUDW/3n5QA6Ryc2kbaFhZAhONU";
+ sout << "zJXyJzll6UPZfUsmAiX3u/qyTWlbB3EKngamgmD/1WNGX6ysD3RYPu2olopo9wOqihwweOtxfQYr";
+ sout << "QdXhUKreENZAZ/lOjn0tSfAFsB21/MuDNpuz37dI0ACfvKyPdGnFpNocIXHWmSWZbOB6OypSbzMP";
+ sout << "5r1fNPRk9qMY01lKWRz+lkFGzl1vbfKfSO17q073cCPouHosT+RpE5HBBf6E0yG84m5o3u1bLUBb";
+ sout << "8V4YyMREOIa/2wkUH99l3YksFc3eeUMzHSIoHrvCQq1CTyaed3tNbZw0rHPLzwSSgQMZAhsi5mxp";
+ sout << "ZnMABMsXSHdvgBRkVn6PH5rjrKAcSnikddL8F9/FXvJR21nKfo78g6RULWdlcc8tPQR0E+PyOoLW";
+ sout << "VR3yw7L8Ddl1zHu7hoerz+cM4gW+67kjehlzV+BAA70h43CfxsmvZm0Kaw9GOkxjDs8F5b5XRtdN";
+ sout << "gxh9Y5o/vlxI8f1w388jL7od20NPAYSqvONaJjF4nthDN+CqlNWjy3K44Xeghk1o4smFlbJM1ZdM";
+ sout << "taj+ZlPozWzYEugvDe7/9woqdFtYqCWDlIAcMGqw6QGFGE9EVwP25LeGWnM4bcSWkdP/Fc+Bpssq";
+ sout << "QaRpCt/+Igw9C26fM3zSM3y+RfDs6dseYokawF7xp7oo++opFuMO18Flo9dJw9HeXMZ6R8lXrNrD";
+ sout << "l1RWaWakmYp2KgPJYT9USAV6S0JI0m9MxL0R0t7XT23cWgtQsG1m/gLb7HT+CFNm4SYnWMGpza7N";
+ sout << "4w6hXTItZLuRrC3pH6AF/m8MuwdlKJEkTn0Mmu0e+MVe/DJmnA2w46oxML03wNy4qZMj0LywOnIu";
+ sout << "Ip8p1oLIviccDGpnXqZORZGVMqzuidIfNBSgalxT1A+ySEAVroXtGjtW/L7sedcTElNyK+aspa3l";
+ sout << "O10XSF3gjgcZnN7W1j3wiqePmtzL/RlZj81ruAoYJRWaNUDT28p67/k75Uhl4Sh7qWwc/mJnG9qm";
+ sout << "rvZSX5GiW11EJs35JJ08G3UHAUa98VEIOesPMaL1Wr0CVHFVC9eW2iHwEq6vk3JrGfsKARrZgg9B";
+ sout << "YDhkyO+3c+qNgTk1WMC3MXaSZg+mwqwx0TGeJRg6/+rcKr0Y+y0ZEDLz3kMD/aihZ0ictCZlcYAz";
+ sout << "8pqdirBd1Dbi8AaMCJTWHGWaOIxmrdU2XyVWVk5Vpc0rmzyc3ya8WfXjkdDzXKUS1IIFi34b3/7G";
+ sout << "qYpJNqtVAPT1TXnnM+2koRgGaCARu/RoKzl169YpDYPHVJuRJDjy6lfrKNEqPzJY9phbh6WhmZ4X";
+ sout << "CZMPrgnrs+0Bs///61zNUJA3fArNyessRIkZqSxnKjVOHz6/LUB7/UzJXAd/XfMYi4yNw72BSAjJ";
+ sout << "Z2AdnjGIO9p0E2KIOVx39bETt8l8JKOyT4JbimwFT0IN1AcvwH4wgwtKYAJ29L0kSNoBmMwaKCAj";
+ sout << "1I1b8usj6EM/OUjn+e9xmLMkeAUo8HM9BvLP33TuKJwLINhNcggbZVeaFSwBsvj6hEQ8dN06eLJ/";
+ sout << "yIGaRsni2n6oXeRUo8/TVWmNIbloEE7mQUuttR1pVBdkbJsHGWm/S5hwEmsKKOpvQo44Rao7fsDK";
+ sout << "cCtKATplclwuCDCPIogRMkpHRSSrHUKE4tzdA7rWWj/jaYc1Sw6hG7OaDFSWuP09Qq7klmC3F+0U";
+ sout << "CIAtQbKJevYhCSdWHfo4THHFR5CeuA+shfbdWyPNmP1HRA3pvo52jSsJi8HtPBp6XmGmM8dN8Vaz";
+ sout << "pG5d28HH76fj2Ny9McHrxGFaV/WhhoST0GDqN4PnqV1X6d05aBKfwRbqdWo2XvmrC2YNI7Ou6+3M";
+ sout << "Ovp0NH7MHXpLY0B+GMf+x2X8hoXkffnGCa5CQ1SwTbZRkeGjejwGhpOybsOl7AaSA/fqymEW9JMF";
+ sout << "1E9XEtZtzyATkj0cBUlofwC6N3N//altqBAbmDKrHThV3SDsvn+tpPiZG4+pQQOWqQomy3QAWwVV";
+ sout << "nZTorQIvoPMUdJWPwsVLaRldOrUEjmK0OF36+QysxIJLgj9x5uQDymixyS0Pvu6ybdB5Z+ggH+hs";
+ sout << "A4wt9pW2wbaOGU6jjYN42C/ehSOClqGP9JE1qejqcLRAXm+JDDFZGmCBfZRw7W6R39cNfrbbACVF";
+ sout << "de9yicnqvNbpaeAWAmfjrwg0QvVVmmlvsgDinO2OdZ+TLx0KKj+4+Zrt0fFfXQgUa2oIFGoeCuQv";
+ sout << "eyjv6ZUnkYHtJLeMDfuhaxaV6dIG7US6tdGcxKxw0/Z3Cl1AqyDSOSeZIfABudI/tNrQF3bk/x3g";
+ sout << "Kb/i4VNKLwEq6O1T8HrZSJzTWus7XcQncGAOJWTsuC0/0Cv/Eg7pY87c6VtF2/kltcQeH0qWNix8";
+ sout << "qHpkEC5Nmjm2erJhQ2gsdVQFrVjxv+irSaO7GLhuVwFgj0pMzKDQTFvd49DRMyrwb9cml+u9MZgs";
+ sout << "Uqwmc0zXzCCT51uF3LMqVbeDEgt/Q59+03U1N2QAJXtTScjJO97PYZSJvMEPONjSI5AeRlRqpD23";
+ sout << "4O5xrQGkzb31WrZP54ZyZ7Gt4T4bxdwVIZSrgSrGJt8RjREWl6LuDpuRsPgnt85EKNRjf+icgibb";
+ sout << "sGxnWWdqFdOLHdPKDyeCTiJb/5dc9YvjVxWINlgOMxdYpUU1OZ1hX3OmDDepo9HP5OWA4GnGXR8Q";
+ sout << "cNS85y2Q+DS7RF+qxSbIO7hOcNL2PyFxScBXYyrkLUWWxzfc0BnyCaCg4lrMr6uH8vnM+7u16IsY";
+ sout << "0f/y4nxYNholE4Q9JlbhXVTBL/ZZaC4qU+/mXrEBkUVY/xj8NqYSicBJpv2c0IYwtf6hwOFdMm8H";
+ sout << "y+l26vbWkXEGS9XV/dKKKLbW8EqhKxtlRGB+Tle1HnOp7HhcINVwH97RL9vS8EurSNlSWJCHc/1P";
+ sout << "Ui+ZsmkNIQtV6tATb57g/AaResCtNWa7znBxzJUW4stpxdxjD7fLZxyP1dSAlXDluOaJ3vaUpdnt";
+ sout << "JqkG/gHcvI/0DK84ddN4VLKIZ0XTrCxvTQJLJcUhA0ENgzdIPQj96CN1XvJ/DdFxx4YKBnb5qyA3";
+ sout << "51z6P84v5wm+ciOVpTFERffXwtPPfro5z7/ZWazbUjB4AcOFxj5Sk4EweqTrVSD2Ujv3quBT8SUj";
+ sout << "A8Mc7LBMtANusOZeMURs5vqoI5O5hwicxsnqHJeGe7FMRRVH9lrM/+3W5dfm+jJ/ntoZJArHCrWQ";
+ sout << "wJqFWJksHTraWTJ8crxkGSgZzFhSHv5pkz/FfVPRW91XprikZvTlVhgyTFbB5gVGeEu1X+M1m+29";
+ sout << "8Gqo/AYNTsXS6KfOWnfc+jfQRkkjiLsBS2vsnV255hPBz7Fvr0mqaN6AawiESUSXcFBnfzWQbOdN";
+ sout << "wzA468JTsVoOz1MUXZ71hOcXpGMVdKN5ABH+8imVfDYGVUUTXv1e1Z41jrzXVphOQcQA+bJYc3mP";
+ sout << "K0HVQQfyuwo73PTJ1pINwbi4OrAJvTP1vEuo9xMtyc2d9zIN5lVg0gTlEE7uwGWVctS67onu5vd9";
+ sout << "cQ4/+1axVD2yy/Z+s9JGRZ78DA8rIUqVhOTWNlqmqquxf/mh+nY9Q3PdHfcpNoqc4g465Xe3CvO1";
+ sout << "cBGJT0vnZw1vV7rtQ2FqAB1LPAp0f8nIT3VExu/mABWauAImnDE7ZiEnOuiuq2vdS8WspLMjktCg";
+ sout << "vlvwF+7AGxfwDCImKSzh4Ph5AA1VuYW3sVQ5I3XuONfappCoJjZwiV8jyKC91rHY1/jP4eMb5vwy";
+ sout << "8T5ciRyzv977mW5MK5sx3ujnlVFWr8y8VDhp4qmhnvvswLIKBRbrmW7jkG8Jg1YB7mjY0rP0sVJl";
+ sout << "zICdb33DVPWxgRHDjyO189HiKWkzbH7JTQfKOBQwufTNQ4v94HlyUYNTzHg9S6nqI7LiBq8zO784";
+ sout << "hGYDRvfB37TziuwR5oHeQ0IiV3x79V0+qLUHMY7ZFpQB+LnGcxG41kvposK0UBUCPDLKsLNQDSn3";
+ sout << "ltl6K1uTz/4gjdLjNVlkpO05xJXjK4JEoG+siPuABINp/7iLAgYGHq2G9p8KmKiu2CLtN/g5Aa37";
+ sout << "pewz9BkOuvCHm+J9/V8+rMHTCxD3QT962/XOO+iguzglwO+YlN3BpWTnr0VhpQN4ZZZcZcucEze4";
+ sout << "PgoLS6CcGSP6xrMdnasvqM1kFCvgOn2ROITQEJzaakPDaB7fzVIw6WyusM85BSNnyc8XT2RkDJaf";
+ sout << "VTGEMba9qIAi36u+pC/PH6XTEW+x6nehFJu8q8e7B+VrnN9lkpvVHVDxjci3CHvG1QLPGl1GaGYJ";
+ sout << "oLeINwB1VTKtRtEyKENgc0MS2FMk2G81MoKGKVKId2oddu9ACvHqzQMbOqnzp0l7RJewHez830/f";
+ sout << "8oKpsUvl1N/b9wfTXJwN77MF/n2zAg3VSv4HInhkdGA9HRSg1JylghqBAOrqpRsh6r/TG6ByJ1uQ";
+ sout << "rv0aYvVdVaHQ0kHH+ma5Mlc73U0x+ncDjlicbrIca7x6y/b90emfyRcN0uEKr0fZWtcbK0+1XJXX";
+ sout << "WWMGGDUO5b5cJOpDzW7GJe6i1m6yIPPlfp4AWLa4M6F/ys1ONLrbXqjYCFJY4y62OeRK32Ff2zN2";
+ sout << "/Lw3hMt6iPTm7oiJOJjDcamybQlWOoYIv8tkEQYN2nCX7k7DD6z183yX7KtJF8Sj95nOeYpzAnLG";
+ sout << "YOfY00SA1cM4aLRUPG+FYrzn8UFF8Bl9UFrLXoJ/IxXOWKqrbjhsKxMvC4xriww6KZqHx0L5FWGO";
+ sout << "lcFkhdpx2fFeswyfjQwtyqsGsiY7MXGsWA7fA1TOv+FZhJbgCixrx1rxwJaNel8uMo6pGHsZt1N5";
+ sout << "Cr4CkiA5ZZ0ITa63oOGM0tD30yaSLkSDZahPT7xo5s33HNAgXNxZmYnv8M6p421yuNPraabZyJ5P";
+ sout << "3FtNwzPpw7Q4BPSEBoKc/UbY+mdzm8ES0nKPI6Oeg1tzc9NgP5JaNVgWymA408iLfTjxaxS7FqOw";
+ sout << "BCYBbz0EdJigWmF73di27QmfsygTWbHrfdqxCafzsOPuk7l+W26tfJTVxP7S5snKfhbqVNrV/4M3";
+ sout << "uhpZthruy0mSHtBx5CuXAtGNQoesk06k22scTmX89BgNsYJAshikS3IHg+LMMfaxBe8KQaaySUX9";
+ sout << "j/ULsdT1YZKTeV3NVn/Z01BxF/ZPvV33Dbj4bjEcf2N9byWwicz2CFpkFyQlXaSPPSzMYYtSPCZH";
+ sout << "QqXL7ZKh1Xkgh68h5WyuoKkzDXGuZcl0bGVWXSOfO6gK1Mz22to6vJ3lfssNBUV1mxA7Yos1bUb+";
+ sout << "tGBxyzoBmphWy8dFCdE5vN/WS+w7VV6O/JQ6j8qbpBn596kNhL4iaK9Lzh0yXRR85nqJbi8bkIZZ";
+ sout << "FWskXD/A31jO2IUZKCdB/iqtIuEFZQnq4LLgCTlwF3GbxPeZTlJ1yKcy/2e7nIHZ7xHeG5tzyqjP";
+ sout << "oi2AAttJbrL9n4Ebk20aqzYxdYUkR8DcflVhCjPDfeczGwzwJxwznHHhEV/OuxZpWYRvTQDOuUzJ";
+ sout << "Qpu+291rSMqntqfrUo7uxCGJLmyRwWqfesAluH9y3GQco+V9yYzqf7ekdNpuygq+htLNc2lG6deM";
+ sout << "02oxJHjYyjDlG48zlMEFy9OhLLLPvFVmQWVL/D4FwAwKSQfYWmkUbgqgB0QwHBFSA+xqZBWZbzoC";
+ sout << "3xLt42o0hhdlZkXEQtGhtaqMfU4rBbKWxKx1gKhHYk/Gzba+PSi3J0e1A0fC4chtUfWc9UTZT0HO";
+ sout << "9QLCXzuC90KrYTYOeaPSoxfOm2aboXBWQfyxZGX9fKRdLgfJ9+8EHw7rD2TUTJueSdfvVahIRi+t";
+ sout << "eiEAG3CWoToa3S7i3tpR+YczkNVTl9DCuwD/QMAuPo2mul4bdVQt2NVnoz7PKMWim/JYSFPlJBXP";
+ sout << "RaKg8manuhMoDfHTVx3FcToGnIDBd2in/rfbEjE46d+98FRs8p/zhmyQ6/+xw34DKX/NrnU42qli";
+ sout << "1+erGDbJFhLmIoU3OhYTiABG5JVefsXQFgIL4kSF+Scratkn+s1aW28PLoooaClnGF14imxCA1Dc";
+ sout << "Y/O0HqBFR5lAn+57fsKP195CJBvmyPR+UE8pBE818w09TvBmN6onwBCFtsrEnzTmstINfz0TX5lP";
+ sout << "7eOD6+XiZM3y+5YBEK140cgAUae91wD0Hjz8chAlXmgE8xZfalAjHnS1c7ROuPxV9zhj5AjnXV/Q";
+ sout << "7Bllgaqoyl3q8Sc6eXgEpJMz8fyZ+LfMkKkMBycdjKG0AZAkF5902y9b1VPVIhgIHueHDt5irhQE";
+ sout << "2slLN6pSkRg/3FxPb97BOt0/zt96EvehPVGoKs7OlOd5C6tQLWio1Eh9IL67ldEsXOiwaSMw8QLs";
+ sout << "wAw3Vx1VfWnDCReIT7/yxZGHmEGhjdtqs+gGPtZ5jkIp0l4QPQwLhVL1tC4RSWlt8BR2Nq0Ivon0";
+ sout << "E4t8jhH5zWeLYZxLIVkyW0SKj5xi7s4Xo2XOZL95m43Kv2zrxzNngy1oPOXTAEUaEwuRKmo90B5k";
+ sout << "xdxVS3x9BarxlfUKwNBvwQ8OVmJu6VPWQ5hx++9ZLDfwK63V0gH0tIU69q3diMOjXv+fmvQe28cx";
+ sout << "iUh9o3Em/w+7P3DTZu35eSPUD6yzP67VH3pDvahh4uLzes3bdTYmpo39kD+gFc7ADodPWp2E4787";
+ sout << "JZaS2A+efTMt/OPen2TV+wMSSK80nbAta3VE1KSQogWp2bFPi0NXqC6GZ+bHV1DVnw5YGs28V5Zd";
+ sout << "RuuIzTLauK958DBeAc57QbcAun5sEsPtg+CICI1pcOGhv8hdLzT+YKcuUhji95EAYzCj41nxMM1b";
+ sout << "wP1rH0TG7jaHvET8IY4dQb5VvzNXVne0FExHIyIFGIUWoyQgjROX+cHDLe39tvjZo2xs27/CAhZH";
+ sout << "dsrgamGD6wj8ZMH+8gHzu6b1yy2crBAw1Ot2JdNhW0kp4DNdFgjqMcZVCfDM8qRZ71DBO82U9QU9";
+ sout << "hXh0DCIAyryS/6W9mY2I4IxPrbSL4RkmKzhepcIoui74zu3p+lKY9LgUxYbGNOK46S2olWwJYf7Y";
+ sout << "qkOPWcN6yen8rO/9Wkvmxl06qSDDhcYdsc5S409X/VeWz67rJ2vNDiyTKjUJaDWcSU/znMnhtosm";
+ sout << "QT5LyX0ZVd++zQvXfjSjBLkJrki5vvKMZSUKhiYL3mIapZpcnsOQ1WRwnmSEave/MidONmZEP4rH";
+ sout << "OgS00fnCK7eAWXsZb/rk7tq0cjAFsNeJsHe53E2hJK4TyaYcUi1IM7ZNSQ1yqMYbcFn5bEHBpuxS";
+ sout << "+TUy1QppziHZ9xoLn8nOxRstSGxrhBG3WjfMhOwR6JNOe4Y1RDfxVHYeo2BAmtKUutztKsf04UNy";
+ sout << "XTP5HSaKx1tZVMDn2fkIj6W+s07imc8sg68Sd1lwJx09bG9ggOw6D0eouy/IeJi9YQAyZT2BqBQ7";
+ sout << "2D5CrrD6hqhVF4IVhBPCkJ8+9hbFs5EA8XhItGVBPyXg4/alIJq2ZeBsjhK+tuKMunwOl13sxzn7";
+ sout << "iSO4+wO3geQ3+SzKShzywLk+uGm2BdS6teWxatVdm5hKNDirghiEzqdHdwGSoZUlLMYp3KTdzpMg";
+ sout << "eAAPELcLl78jObmQkun0ECGaV0paIRhytIRpV3yJGmrRIxWNI5WrqLtocN7hnm4fTrzdlSGnDojf";
+ sout << "CBmISIwp60UVjJijLDTVYBIP3VcwaoT9V2F2HkDZ7ifQUTvHmxRDFHuquws0bNRjLSBSPa4jh6th";
+ sout << "6h62qlgbHQ/BBLIUXi/I12CwYwlg3ute/qQJIpTCsRxf4rgtEjvV2dG9ltlo8PPMhOd9fkqQcvua";
+ sout << "BJaZYLYw9rxq4U4TY/IxFQe8MGJ6It9BOFHD83v3awKPKP84inEWXAkFyI0YYmzFjW/gofECo69R";
+ sout << "Wk4/2wv+K5MFEEIzBUGsyrjz0TmiKvzEJ08KZOuWtFXeF2LUGvYeL3TO8FyLRpPnCAoErjZzVAhI";
+ sout << "9Sd3ukmv2oZ5oYk52HrUQZ9WywqVgFL8r1KKEouCd3iNhmb+2qUSh8VcvAncWTQbn8yjpGe+kwUf";
+ sout << "HssRskQ0nKLIvCz6JwWEjoaglBC3PiJ0EOwDLNP1fKXqPtcPglnu6aGaTrb/lOTCX70yvk2XpdPE";
+ sout << "HzsxZHKsV2LxqDI/IT8yIxvM67v8b9CSGOv64a8MSBZ5zYOQR0hunA/iYWfjzuvX5dgW90GnM5ZR";
+ sout << "VWoq9zxcBEr+Nk0HlWHpPm3o08OPGmfx/ndMbLLtuH85q4/7Kelm3NSXLXfOxjkRuSZkcAeiRk8e";
+ sout << "64X1nXADtf6z1Uaj8H0JfREe91GPSUfyZjwI0pQLDnnsgGoG0DGn01qshIL3ANXsLTHDsdgjaFGP";
+ sout << "zZ2f/7V/NKLrJ6HrVpWoQiwhNGY5OYv/KXa2q2XMGybUaoONCdnjGKzPYyXA6C0utQNy7JTKcbgV";
+ sout << "/71QilRGXbV8OPBFLLN1VNHV2rQJDSH6UL4RYDdhRSsItpsl2tMJGGS2RLH/FwEv0jDo9e/218Fd";
+ sout << "0P3DIy3wAevcgLQu6Ex2SFlTpZcvO+yidCgCv56rJHlVxgN7eSbcKS/l7CiAcafTsjsNm5VtgTx0";
+ sout << "CM3ehbIeuPrV5GqbZPgbF1akOkP9vzeNsGGW5kep471Q+xn5o78MZBz7YsMT3IAHd6RViNn/Rrbj";
+ sout << "gLWyleDnUQ13gY1xNXl1SD8xaQVvWvvpmlZKyB7uIPEL2VrDZSxmFYyhdhK64yb4Nvvawge8jIWM";
+ sout << "a43Sgazodlobk0CkFktZz10gZEUobBos3idIWk0hQP+nGOLhiGO//fwpiplh+g6NBMLj6vI3GoHZ";
+ sout << "MKoqqYh4LPDaEtnvILpMCP5NJoBbxNLrUALf17ObB5Kgz9YQsxRIhtrXoyK0XdU98ws4zxDGn5Od";
+ sout << "Mj4Zpw/QpI2cCvc6fiue8+l/6BsFse0JL8pvOZDYhy73kaVarhJhQgoZabwnyATtN5VltM8niedj";
+ sout << "reSbVYi75bWEdYZrO1PESHmhr7lx+s4gBJQd7kz8wbRADk4dJBf8QDccgMO2yxf6biY6ezjB9GtZ";
+ sout << "DnGrIuG867zq9fMbooU1PUrXEMYkeeccI3d3ClrF11AiOxFAh/tSSj05pwl7UWVKDXJrP/T9sl37";
+ sout << "qW5Uy21+/Mi+kNMCaunB+fgJHsjvoQpwSmEb9k7b35H04v/qe1deEKPua/xXlawyTt4vv9OFwLFV";
+ sout << "ovcegUF44BJCRF7zlZnLMzyveCE2Uizux/z5zXOW21bFEaODnvc9N8oTl7XkQAyDT67JrBC7WbH3";
+ sout << "g8lzwSjSyC4paGxi87EifET4QFZ/wdhP/mP8lZXzkq8xieb2ad9dSCHS1T36XB40s2QmiSu/dXNN";
+ sout << "yaGFiWtFCKYhie16fkeBuCuzcW98RXOUQYBVCbSaDVM7nNTXU5+2B5NfM6ts6+7fO+ovsJh74pGx";
+ sout << "HrPXVdWPNb1Y0v7YhX7fLu77zpFbr5Mo1tdL3H1P8k/y+ep5/vsBE+git78nNPeoZajrbKuf3QtR";
+ sout << "uKx/sDQrnt3opQr2k0BEt2HZOPeDw0HuPJnpUNg0u31w+xWZN5YsTO6HBpfg9x5hsYX9YybRFoWJ";
+ sout << "bpfcNHrNzlc940Pbn7cy/XKiIsegvY8JKScIsmP0pnv/3rr3M+FDocBoe9QDnj7a044KpD2ghnwl";
+ sout << "4uJjqFHv7RoGkY4365yAFnD/vm7tEvO09ibX74jzreb5lcSu/f0hy3IA2IbFEYiY1MBL8MPDBh6h";
+ sout << "dKL6Jf+UHWrz7MM2bSfDzf467cr65Pt+N3bmX/BR5yd5sPyaiqQhEewGImg/bOiaqpRnIKIFNrF1";
+ sout << "dHziLHvI5n/RvsyKPc1ScVFFHdGs9VMXR3ODzG4Q4wNyGYm8Bt9ciTw+Cd1VuAODX8fjGFhCRq5m";
+ sout << "wlIDo6rwxR8ME68IXNMqztq1JPQU1VqEXWGAOBIewAYPLOepF4PR28SUgEGqn32oSG6hutM8UAK/";
+ sout << "4LEsqVBnc1XrCczlpF2vad4Oux/Qd27JfspaVThUqSc2x0kY1Ca8Keggn3j0s5GCSXxzsQh0CuYq";
+ sout << "jSJjKBVxbOP7XU9EHt6w4/9SOnEmd2QVJ26x5o/znpYpW1zTwrIG1ZhvU5AhBtzFtlLEOs1WzBH9";
+ sout << "c2jFEPc221taxwBmLaOqa22VULgiW0zJ0nFkXZ4vfoJghqWOK9N9ux14hpKyWUtPPwPHT+oqctBf";
+ sout << "RzNzQ/ODTSIkbFdJzskMR10TAYT3zsY27dWLgdAnmwXE1xr4SCbilLOEZcsJJDEX+XkRUsMbxKxp";
+ sout << "Cr/IrdTuAzx5MogSWKsUj0W3cT+EG8Nv+iHw/q5PIag2y98b99gk73KWnswo+QXx6nTCwuSuON2s";
+ sout << "KlVN2mvKk/t6dNJCbLNQoEnIxoHumQimOrsmgtD5rTXK2fR8rZlnBzUG5Bp/45YBEK2a4JMWivLu";
+ sout << "nY+zNxxg0wslFV8OJaH3yfssBMr7lzi3jX0D1dCD86mSslQZ9hAx/pkgNPogz8KDbmoQjozc05Dr";
+ sout << "nUKPCCTYq7OUxpGXu9JAFsQsbqnWIkqyxAg3+s4M2cc9+P+dFpb99cQAHDIn0LSBR57f6o8lojFg";
+ sout << "nPtdvUF4gNJd1EVqy2ERrWmwRQ8gnzLHVI6wJYlJVUa7p6gk2yZBz4yRak1JJUgxp3l1goGG/0Ci";
+ sout << "PsE78UDJQW+3FmtOwLuphs51+Clq65O/n5RIeLQpQ89mh1govqglh7o7nHyyulhQTrEpBEvy2GWh";
+ sout << "LhFNp5FWc1ZV/yl4SwQEIbSpGk/MTmlPM2vZcOwTGNIxHia1PISF2LNDMDMmqQ0N8y/NlaKha2T/";
+ sout << "ZfRhOQdR9nOgI0svhZQtke/ExHkpYNHsF4bbGisnGyzmtrKgr47s3l527s8lP8Pxe2g5cgX1focZ";
+ sout << "aeI+CIeKPFynBjP5GmbCoBpOBpFfOgPrMU5cPMvBktXq27wF0eKbZ1pi3WvJo5Yt5SNlRbs8cX6V";
+ sout << "dM0WDcqw39ff+2ztBdDR+aeFRmHelayiPJ6Rmd3dWgH5Zk/vIGEaSsc7/KfYuJQXkca+gR1SVE67";
+ sout << "wez78yJbRDcoXiWhFOC4ilWR1DDKgTy7ej+qsbEtgEQcAqLukxeAnICNaIYzrQsPacqSvE6MLBiN";
+ sout << "DcoQ5TnAjtFPGGkkhYH8TwD4kTMQgSDXt5yqADxbAMLu+XsmgOJzyAL597+yMJm0w/COAOvVYPRw";
+ sout << "kgsDb3IXuTpW+bL/xGC9RvOefz1UoxpwP7znW1IYeltHBqT9JLiGLMQSRTl2x+OudBurZ66LiqQf";
+ sout << "TnX+HXrN7MnTdshMCQ339YwnY6b1baYKnxVwNdIUcLnpyVqMCKZNu92bPJsdj4af5nZuIP8HMwoR";
+ sout << "g29BOeXRb8UNtWd/EMGiuHiNwMMhpDBN/L4H1f6GS6ZeNobN9k613HrCeJmlXQDBWCHE66vVsAhX";
+ sout << "X/Qwi0w5PCLwwpoIsgMDnZDLa9ZCyfBez3VP48ur6Qs0llb1j5y2gHz24yGH/fVlWPdnc0pbKvFR";
+ sout << "KKW9jVKicLU9NGaOq5vLvBmFMnx2YLwhGMIlOgnnDLUkB7wCY3ADrQau1lPLfQ4QHr919kbUpCmq";
+ sout << "3syJxq1j5248HvzVQRCQV9H35RIQGwd0fjwtK8Y60fxa6W0tD7B/hf1b9lrodaHQL4gYt9/lrCQU";
+ sout << "LezUF2Y3HN9LxAWQNfoCbYaz9HebHTfCM+ynCUXAXabGngBLCat1GxbYaWlbotAXawMKxRplVYnB";
+ sout << "mORrYoF3NaBlDPznNhgiU0Uz6t7Nh03r3wR8NMho/VNPv4XggWVXyyoNvqY1NDZVtHEPVaSlVq12";
+ sout << "yUOiQ2qEpATWET8P+19RL1KAtCoRzPcuY2Ikir9MR8sTbog1FWyTqZN42P6Bx+darDn7eQTRmaHv";
+ sout << "DrC97SY8szr6hZEA2ei3vM0JpNWKYTOHo7pTlaDWt8gEDt/oQFPUSk7WJmJJ3DeXCGOdnRABlCXI";
+ sout << "NXDZx79pTuebQNZWIysOeisk2h2TXVvghE7J7QyqIYh80zqZn6nuAy05mtUT5igW4twr21r/X+W+";
+ sout << "Eu+dnXFU/Cw04SFh4MVLMaZ8kCerQ21GQn0NSgFCPd/bfeJyonCdXaO4rVATPw8tgbsI93po//q3";
+ sout << "Ps0KQWV//ZPdqERqZJYkZKZm7CNJRC9mcLsPAfnE8PRF191kr/A7BjLJu4UgxcGZlyQbi/uOJFd2";
+ sout << "mErOsNiJIC7MuRTh3SF5baHXsjiVg+l9m9lBGXmaSHAHBsKpGqpFHF2UuJnkWihfK1QoNhEkUhBV";
+ sout << "QOI63RPKJlgU7h7hi8fEzxzDbXbpWhsMwcKQ+j3UBSBzUGbRYznWar5vHxf8GQ1LklLRr+8jBJwZ";
+ sout << "fIPGjtKXXyMNJACNv0iulszYy3XUM56cIdLXqjVgEu41sMUfJGvbi3aP0tX6wl2GO5gptrX40j6d";
+ sout << "Qys2Evu2rH26dj/oCWFOqK0EKcH/Nfj/+EY+gm1GI6rSTgKq8RtD5J9fEe1nro4Sz+B9J230FWd9";
+ sout << "G89GSwEnpPOE6hRzubd/NyFZqe9sxVIFdNqRlOsPbXC/gJw7Gx9B83I5j1dpYAsqWyBKhalrGRz4";
+ sout << "UMLThFH6rJkGO0YWO2FF3Y59cWk3rS7taKomqL37407f3VWOtRH3YNhvxXI/6m0LV6czsmFHrSp1";
+ sout << "lkpYyq/UlH4kL0c/m/P2w8JrXwD+1SwgMtMjuwIQsETng1TeEgdyRlkRnmKXVFDfSLQSppfq3U4Y";
+ sout << "zrlnUiKPk511ZcXkg7TOkHxIMUpkhlSihUaBPwT02QpbcJR73i50BzQNDVQbQfMrXny8LEv2D331";
+ sout << "mtjKrWCWThul7ReXcJV7p3LpWuJWD8h/6dwx8WuBg4/zKPXZst6WHD/xoEjYG8QNoDDOE+SylSks";
+ sout << "AwuI4F5meYJNpFAoCj7qQXJKKyzyeTz2Y0nISbP34Ue1H01zu4SBN7OMvhP/Kl7RSFpc6JkV2brz";
+ sout << "zksgJZ4vpuqdij2S4lQfJbtdrk8smc1VWsVtkRKZzRfbPmECuz/0aNHsWe7sFyqISk2Pe5kwuj3O";
+ sout << "3mvWDzXBK05J63FYRnpemGdnvIkKLddEmxNecMuotq2s7fIWy2xeu1Ge6/8ILJmtwm9Bq3afgFfw";
+ sout << "nRBQ7dj40h9kCgGUoAx0a2oEwTdvdgHp3eEjlyGMqdodEC5F7Cn+uSh89UsHUYbxAI/qTpYekaVh";
+ sout << "5MRydNwdIg+3XCHQi+QGEr41DECSXUAoL+mCG6W87WE70G1BsvA3ofdGtKUxTXskdWVtyNip/qXC";
+ sout << "f7Fez1i7Cd5U8rzkdCZD56P4KKx7H9N10aU5jIyfba6IKOG6WEqgf92SISTXCVuQSHqVBNHoGiG3";
+ sout << "r1xxH8KEzApRHJF+whOvt+pP8eDVfVIdLrzDH7YP5nzQvnD1eVv6iZw0izQu7ArxMoqgSXOZlKBW";
+ sout << "L3nkAohU98uR9oYPGck8HNiyKzuiI+DCeu7BOg7unI0WJsn0qE/F6EFEFYoqusBXRgDVwmySLyAo";
+ sout << "+0eB/sUawsnUh/w0lRHBso36uynLjUSMBCwppQUU84pJ8fqr5aUfdGfWzogc/q0pFxppiexM6xHd";
+ sout << "P8/lVcbgKEclaNJRkE401wDUWHDmtVAHiaZA/532FWyrucLOp6IRgM+eZAtao2/hQUtwOOcaAQNa";
+ sout << "DcCN488KJmPQntnz9zn7Ep2ZZvsrx6fg9ubEEjHHIiRYSUDClLUqqFAJZaGrGsXmrSDxF0AEfQgJ";
+ sout << "aZF7CJ0WBl4qyUBcIIeWr/+u7Xm9WDnD3Aowl3/RVxH4L9ffGDEMkCeY/rmNp2SDuWfti7vSrxge";
+ sout << "NSQEaqiPPMHrgQ0vuGU7ipACRYtW8MgxSwxMc8AaAbhPQu8L/ZEC+/pfxDBy9mtMKnF8Jj2scJtB";
+ sout << "GKhZ4s1s7F1fqA/m3Ts4rfA0HLY7bAWeXa2sNAdvkjoQHMLAiP/5BB9B4vNqVGLSrislMJ0sYfcr";
+ sout << "t/O0wfihW++wYxK5HJWn/g7U7+lu4OZHco87HB62wHcvy24LqwNKY1w033SQJzkLb6Nkqj4HU9cw";
+ sout << "E82PHpVUAXqR6FRqOSGX9bMMqWL/e66vl0G3hbhIwQs5toUgg3XVv13F1o2sTzjiQ+22wf3PEuD+";
+ sout << "aYIYKQiE4R8kl3/96og7QmVfWoPxZiVEGBzM9bTDHx88W1D/CkwPkH389oV1Xky/L9UHd8adnI6l";
+ sout << "kfuHNP/geDSReFNJ3vtqUo6RnWHAFaTj+x5GDJDtol2hC0KaDLJCETOQRXUOOW371wB3I6slEKjG";
+ sout << "gANfqlMyzsO/7bxUw8X71uDns683wpW+Kwr73bJwe0OeP4dgwgTdTLQApMZncaphcpuZCEn9iSpw";
+ sout << "GX3EO37EZR2RvLBMvlHhkZSAWr8TqOALJeFEuYiduuUmyM2xh7mLw73AqGA0Wm6rnc7oZwXhwaBa";
+ sout << "VhwUZyv5rGMNQl/p5bMO49+tQiaIP5UXxheMCUN8eD6m7qDyHsljVQ+Madsj5iaJw5rkNcO4oquG";
+ sout << "mkrQOI62kZGM7reBAZjYYjf33ttAZspzXJVFkNt8geb8e9/AFBdfB2nkJebf8aAAie13umwDR09c";
+ sout << "44fUO0jtn43FxMBMXRt8SN3zVf4ycLW0WBYliqRrfhbOlN4wfzHNnII0JBO4wuDMC0zXnurwcLA3";
+ sout << "SWng9i7SI/XH9YZLdEn1cEtkTNulFFsZJGbLci2EvHEnG6D/m041JrCBRP7oCGonxdi2zaKe02B2";
+ sout << "LHaMeX7Of28XebQgQ3ZQrpSxW+h5FqIVNNo9EYxnqDDKyB7TeQnuISmSOwrcARqs7EgNAQ3Wbepu";
+ sout << "Fciu2AFnzjsW37Di4Z+L8tVWDN4nZstII68QCzlx78pYFiWD7lBQvuMYXGPujFgNGXIW+uk8NM/H";
+ sout << "ZWKV4UopvDx32VBudzMGJn3XPfrFqeMDukDYZXe+8G/jo8poz/kPfQmBqMSoHZyoKZ6yPNEeo1pD";
+ sout << "JyR5/neJsSu8rCPX61ZMuInitxMKRTS8OD7RhcO5o0kNsFwXxe2AhJm2HbOCqSnGktciP5M3gvpL";
+ sout << "d25K7cTMNvZIpmwH8eQmOoicl8IRKpj6qzxZt4m8yZU5yFiYzkMfEPjUwCmJs7tCZelrNnaLoFpD";
+ sout << "LxraAAsxLtsTcIL9ajBequXTFW3SFL7jsG8WiYRSYmu60gWPgUcmpaK+dk6HFy8SFCsaxlTmCn4G";
+ sout << "cbIFfJU1O45770/tHsDxNOBu/vsdp5yZw97L/xl6jSM9gTMdFbTdYw4Va3nr3L99zk9WnSYVZtwl";
+ sout << "C5LtQR0xUAldgCgSUfM06Y85VismlOis/GxVfYqYSwN17u3Rv5xj8JDpTxV8N2CwRXxrq013NXoj";
+ sout << "eDztUxX0NWUEDZuVvZUIBv9IO+9UaBHLZ834PTzKolkznv0dfedf47H18EWGxhXcJs/bqIP7qlm2";
+ sout << "B3bpaGxX3BbRVzuGrWRccOO2K6K1wxDPBSh03z/qH2ddZsTPArbhqlvRiD7kmMDGJzX/Y2i0SuCw";
+ sout << "nwhTENF7U/QfBFD8hWqi6/5rHArt/YSfPhaJyBOoRhTWgGc5RPlvMKviM7fE41Yhyd0vD4BqylQh";
+ sout << "iOKpy350GOo7g72UBTjlUHq0uWTX9rMpghHkZrMt2bvNzUn9SCx31vC/H7dVjcuEjCeUH4xkVpWu";
+ sout << "g820oyLQSkreplnhgE8IyG9OjJ07t1A2YcKHq0sr3t+TMLp23O6D2b/ZAYcnBFVIu1T2LIuX7NAp";
+ sout << "tOmFBbO315b3PV8bQdywYBSlIcILFc9AGlttfB2ksqGF9Juhx7UOPczWeBgaTOhBFFcBoSELVoqJ";
+ sout << "YP1KMe2EreYgzjVdUEAlbUh5R7mOs/pPF6XjoP++qnu9YE82r567jLBo3fmMGQJqURjgAIxr12l1";
+ sout << "+O9rTfxhjumcmJo6iOeK6/wCPuiD2QosXIlYHi/nFlS3qJ8d5z+tU3R9mi6v6I70GRSB4TkFuHmy";
+ sout << "5AX3LA05BcrjPH5LO/ulytgvkKsNIgOA4lRgZ9Z+whgI/MxvgGEzmTEyM2nxUFfzxlM4Abi210+2";
+ sout << "WFu1QMDXI6MQ54KGwi17BQOckYEFa+0PSrxwf0ntQcgcGnpPNCSlSATUQZKCN6E67UTP1KRbak1K";
+ sout << "nMvD3D9t061X0fAWW1txiUca/iO3KLYtMqAKgW6DZXuUzJdgtk2aqUYTKoI7sVkyVHo03aBBr3fT";
+ sout << "optSJzBz8UIRxzRauuWo6o1mCyja2eM8fn7TMvSsZx4QkIF9XZA6UNdeUpbSzwa4XkdTqdzmoa7J";
+ sout << "V3z+mkCQ/AaaEctE6Oi63jOlzfPmAOobY9Vp9i+ojaeo77SXu/7cwoWgeASY6zOHdN66e/zedP/L";
+ sout << "SNYH+B7QfAvUgBLK7QuP26mtws0qeeGWPYCB3x5+pbeS/PHqIF578D3lasqUUXwa3Nz+OzyBGXBl";
+ sout << "AB587DXXvbqK5gsoPE+mh4CJ0qYcUGWOhDejTpuEG50OQiGORtaswXqr86d+3vWdsmoJq8YFEWTR";
+ sout << "Pd+TwEF9FO3qrN+R73YWEZ5nihD5eG2Q9ABWg+C9Us9KqtlPYQOeUjwrEBhAJZ1j3UT+OKHXuR6e";
+ sout << "6dFh0fmlpXewt/gZ2z2WMZT3okC5JyKIMROWN7m+W+mBMmYg19gvib1KdBeBLO4sxdB/mElqdYYt";
+ sout << "k+mwhcuKUWYmaSmnzXdU8/Lxyvfr+hpPIPJHYLXftm1DpNHQjLY2bX+UROQ67weTACWOtfz0VJGv";
+ sout << "auzt3Kd5UG7m9nLXgk8ZibGOie2MWMY8C2reyrfU5sqOnLO9VNwazWvKb34cFEPKrGU1QqVELim9";
+ sout << "kq/0gFdbsyiotlds/Va5AkeVAVC81yYhYRnQm2W23NY62dMMgorLZoxuJV3RPVUoAa5RKRPq346V";
+ sout << "HnC7z8pqOsVfoslTCivd2CEzNQK/cYT/LJ7n2Kw3bhT5niBsVNyAjZaK4oI5fqzMoqC75ipQz1IP";
+ sout << "JwLqRuBTBjipG1Boum2TahnMriKcAhV1vAiHcB1GkYf9Bwnne8KiWpVdN54drWw7otSWxpsD7oyi";
+ sout << "nPfN7o9CrAcERG94IxkZJD3uY748ly7UrScLvIS0M/zcC3f0FuBHb3k1LKOa53B4Q/wo9FfffpYA";
+ sout << "RoixLpFQRoFkC4l2wfMenZEgaYkYyfWoWEfGD9ucOiu8h5IrffBDzs8Ck8/y32oHvL/FlouHrmpz";
+ sout << "5SKBIez1dkvdFTI/378ArjAwoJ0i2QsGyS/uO7FqgTAgd/9iIqeGonMnJFRPmH/VdT5mZwwj1BhY";
+ sout << "2/L0tu2kF3uwGX4c7fSDN8AkPK+DexcvcRBwG6cdmFs1871Ltv0A+i2UTgdk6/Q7MGnhRxnDUp56";
+ sout << "7cTxaOSfGTgKcJJ/US0PZjyWLN6tinlrYxkDq7uODDifVo8kEIOzYn/Zo7YbrC2FEXSqw7khbirb";
+ sout << "5ErpT55vaggbD/HtahzrQeYIc6lF5oVtW5TuAAdBBwnPwtYrkAVYQ2KZeGh68JKCEbKfNoAeKr3Y";
+ sout << "hhU51cHOBC8p6hm5T5p4JZppu15YmKxbLZIscvWCfriYVwPycUde10Nu+AixLGcaJQCr5LdWm06+";
+ sout << "ZJPovwV5yErl40B+WEa4YTBm9HJqYXOI3JRAEkwvioR9TotiOOqCaoo8YtIHRU4y3YOUcAdqvtIW";
+ sout << "Sj9eLPFLgTQ33kOtsZKDDzBJ/eVl6eSv559u23AuaeuI0ynarOjvQPw/543W/lCKibBGhuaKbYJy";
+ sout << "igAy49KqZ2IBHGJEOxd/7e4pnBgW2Yg7T7oeruGrIlJk3BFKQxEOO9Usc7XYSiXMVXwzGDA4huZt";
+ sout << "2zq/k193xAGK86lvndpmzpJ/eyVqxqGSN7yyr9o9Mq8mfVJ6CYGXgb+CybtXX1r5FfNTT7ikJreQ";
+ sout << "koaSxOGLvHfys0h09POLbTZ4jXBVrcuLuYoN2Dg6vpnMQSEoZzY85NOGxyCPUHXf9KFRfUe+r5zp";
+ sout << "O4NnC66UUzrBan3eTYs5X7DEksg+bPx4KGMQ/byJm4Jr8Jzeug6S1qV1diRpkgR60KCS8Q2zCmqa";
+ sout << "Y42riGdaBXcaZ3MM1onSKNWiIWa62QGu/GuwY64JUXkSTZViCgkPbieNk4J9WSjwKWPHwPq97qBy";
+ sout << "UcJb7MNuA5cxGXOUpUgo1I6ewBWZJvTIU4o/b+hBbBntfgacsZDooFmU2xUYycL0/ornfsEuJczW";
+ sout << "1dI+thnc10hcf/WWhUpOuMLQK3f34AUL7tUWtsuYFaRN+1dabnldREIXMFnZs6G+gmbaX0c4TDLh";
+ sout << "i31s8ch2irdU8aSTT/1J7EkMvyDZt7NLHeTRgRRrVtGg8nYbBrB/MxwqGTjtD8X7+MYzloLLiUrc";
+ sout << "7ST+dShb/DY4JS5RA8I0RBJbqOcMYzdv61UaxfYOnvZHtbCGlj5MmIMaPCDe4ZTSPsk1+NhUZSKr";
+ sout << "NcbSth2EUMRZsJqDLN5Xh8LFoerrFLNfbb6r/Zo1I4T7NcVaSKim7u03ti6xCQj4Ds4yqDLrql0P";
+ sout << "8IP3w1WmY07b+YRN9eqS7gFN4ReqCPVlk1qnZOzzPjH8pUPqgb+0cjuvQyJslbT98FrPljGtXbJ5";
+ sout << "si6Cp+zT+aIEy9HTtQ8Wq3ojrVLtjMl/aRNPb94A0ncTZYuoZCnOC75i2lLru7HOwHCdA7Acyt4M";
+ sout << "a7xgqqPiV13DaPpOQYp5rwKHNelo56FyhrmhBC2qQznvwqLJPRcnMEvfsg3Prs26ncYOOwXJ8DK0";
+ sout << "GzXkJVznLs71ggUJY00mOvTS5+5sPWwxyjKhdckQfSil9orOJJWeDExnCVZYuIWBij5VXQiACLyN";
+ sout << "WBKo52tef5NJwqNbjInxKmFChkX8N6en0uHyEAPxNLrYrn5I8+HibsvccOOIwb9yeT3TV/DB4SiF";
+ sout << "rKq/nLdDdQjbReZfZUHhjKz7i/QB7eHinoBVgLcmAkN5lDJVec1Oq9GNwqKTpM9hzHl9KZueqC0h";
+ sout << "EJILZgoJDs6Vy9NgjK+FYD6O/2Uv9CGfnCL47ziWfyWgLX2UPFX8DTrnWG71HfNjdlWO0oAmwJxh";
+ sout << "0zoURT4B03KBdm54pk6TNeoThxauGpNkK+vU6qcILlI64ZFxaO42b5rjQXtnmYEGL9d5MroM39vC";
+ sout << "hykPHjc2j2qFIRUQ2Eg4mtSR++AAriXGpDu3WuIygTiKMNy4Jbt8AnInqnHe5IeWgzE2eAomR8Yp";
+ sout << "ZkI4LCBOIPRQbFZ+DnX9Ntxqx/6rJmXE9njSqgbepHd4e38SMP9iMBR6u5fX9joI6rcRikf9W5nu";
+ sout << "G/TiDJNhEwvSOtgdEu3pv/urCx7CFwbkhSVz3r/HbQwup/g5FDuTxteiAXYn9TKCWXomd4oKtip5";
+ sout << "aZw1aIIYHKR4zNPsNPw6uJUpwKU6aaJjAjXJtaB3Aq/fcvuHdCYL0AkL5AAItGnCq3TOli3VAHg0";
+ sout << "zLm6OpVaCUravHHyYSuwXHB58w/zjz88c3tlOLRJO2CoFIUzvesT7FkKN8+y9itHmIdvgZCCWbm8";
+ sout << "uoYwp4U7dy6bBVvcxjca4A9f5BcsnZ0u3h8JRxdhMpwk3H1kt6PwEFTrk2Gb4IemrBh4GugYC1NU";
+ sout << "ZGgtiPD9Lk9pXqPb9LHcY4nWCEioiMUBd9IJlcWa5NOnZe+xeEHgVpy0vYrtqVHsoRAiXzRTxgaa";
+ sout << "2RlSne0P1g0NgU4YCnRf12ZIS+LvPMbYe3mXfDpn099AvuQiUngIFGKm//F59pxxMjZg/ADElluv";
+ sout << "0eOBWtYiaYVDL7D0hiffIcBcKQotLJ5p4q7t5v+7oowQ7cMyM4hvL5IjS3BqCpKNwWg3PA9jnhun";
+ sout << "sAVHJa9lLMk3cdXix+BKNhfq3IfqpzXAMdlk7GeKbEl3F1/+q4sxVq7sd9X9XLNR2lhJVy93soM8";
+ sout << "HfWHR7qaHnA/7XlDMDORDxuT7H3IKxDa+TkORiI9ldqwCEzC0ZGDBU3rBBxNK1jzXH7vzQ1vX7x7";
+ sout << "eqYl74NRYH8j1jDxWeQx6bQWSy0V6N+c8uC9N9HPVuH/18KArH+sEDxGRW+Hxi1Wkls6+QyBiweA";
+ sout << "tlHueW6h53s3Hc79AEK7gM+40nwvjCOG9vojSoI7bc2ndc+gdqRxYdUz8mu79bx8CI9Y/WlGO/Vk";
+ sout << "opsB2ETODt3OdS5xMTUZfSO732Hj3nQUTmVVheFbxDpVPUXNx6AAp63iRR1oYv8b4w6rXrjpj2ru";
+ sout << "8U+EXxv2W29LDJFC/JRBwj8JyTR9lDhARPqehW6EZmUYERihI4NZbM9mhgszbC+XDiM5rL80FVz+";
+ sout << "4eIA3mJrPNgmevZC2Y8osV71pnmt54foH2yERar9tqSS+q0q7PWX3+5jIrM/IDT/TBrHQ+Boo0y3";
+ sout << "bz8JHPnwCKXv34Kjg4NxbDjFT0G8hL1+LsNle+d+PyYPAKSLOjtzDicknXz3M6DgjwMrzD3zLf0O";
+ sout << "jUlwIR13B4f1mwH/cOatF5LfXQRCmUxJt/16/VtPsB9FuC0sziuda+4/CGAIjJ5GB0gqzwN7DSOt";
+ sout << "7cNdIO7w0CTNOOfVq9C1zPi4fndOnX9F3M+sPSzv7YjeUbkchXzJhip96etsRY855grSR//reW8P";
+ sout << "XNodO/FvibrfkHKpG7fZOU+CRLqnikcNgIGBDYumpgdWA1ynvAOFPgAk5tqtLZKmrIe3ceptlG6X";
+ sout << "E3UDGPMNVgExa9GFR1e5bIq0GZCNeIyGMq8cGLny0uRC/u9nPu0NoTMxFb9rr2iEp1DJB/zyiped";
+ sout << "EoTpBPDptrXVQ9DOLvpZnplsmuuCMELZ7eP10CaDzNXodNNSFr6bk5oDTHJG3foIB/VHzdH7WDDG";
+ sout << "DKMcvwwXAB+gAe0Uz3C+W3DlvI2xAHxpjx4hPvATyJwVgYLus9pHzKzZCJ/gLYMgcgyArdTPI7JT";
+ sout << "ExgJcLp8QTosu4zuYCJ6GgPmtYlVtUdLWiPz6lHCW77bkkFmz+c6u9WF3yzzY1nyhecgnt/7jbyA";
+ sout << "iSeDme7x/u9K3n3CTHc7+ZXHLT6HSdxKlwZ6811c5Z12X8aP6vvL/ysGTVwa8lRngJOOiGQuX6GT";
+ sout << "PMohKZIhYGudZundQNXPOCavN+dpdT5m3AVMN9IrU/ht8o9eH+Hmhs0gw/iEwos7GgEJDl7IUcAg";
+ sout << "TZb8blBMrDfdTyAKjYcuj6iA8aDYes9sB/W8R6AtNRG4S9vQcJuCD9JBKjxwd/5jG4hrpibV5bMD";
+ sout << "F7dW4/5kQsc3AEdYIeGITaDjXorQ8UZ4zF9Zy81kTw6RAoTKdMYtDgnEffkDSRWGUqySn1CkTs/p";
+ sout << "cdY3SRk1HMd98JwlVLpED+mWbcLwRuROBimfERC048b/vfUXzGSCgdYZ9AGQhHPzu8y/nr7Lp74h";
+ sout << "I0ghlojt4zEkgBgK1JRMXD/dxwqIu+PIOo6s+8vwOMCLytvqCcTR8hqs63JgVGJ0c4/O8R1nbwFg";
+ sout << "iQiCrQ2EjBbGaBMq8vhYNA1r9QzZwHejoHm2Z4a+JXoef2al/GqOWnuOUHBotsp70MwvCS8Q/EtU";
+ sout << "TFGRttRKcrm7OhNmeiFLAr4P6C6O7gJxMBu78s0bIzpSIDTZFCZRYU7NWdSysmRq1ngerD9inrKg";
+ sout << "zt/qQLnry4PmY8NjPwt7QYcQziZLN2ruCtbvk2DrPVSx24HDHGvB+0g8jLNLQDfS552/Xs2G4zc1";
+ sout << "wwPpayeQy7HFGT7Gtf3oCI4jb35g8MTCXoOMgM1Dk+EYJIyDT2PBae1AcDcDVj2JKOI1UbvanFI2";
+ sout << "sFpsLkIgj8SVUI5dX2PEEl7AjmoFJqyeVpqmBgNEfNwPEqfNNLUp6mic3WusrWO/jQwFtUIdEWAc";
+ sout << "tjyIqyXeF1e1zO5po9/kJOrRJFAFxMFI8PVZNSQJ6KIQDhjPJHi/z5cCaf+lxqSo6OUt0D7z9o5A";
+ sout << "hfA5GwHX6oy8tGqommBrb6nF1AWVInk19mynB6E17TQRo3T5W2OFG+BaMRO1xlg1znAurKht21fr";
+ sout << "6VZscHHtctXifsg4dequ2SkchmjJ5yWpxqRqckES9B3nVTSiqn71/T2O/poDieJhqADWNw3gj96c";
+ sout << "GI4vlUR4Zo1drdvtgwHFFR+RKndNuzv2BIYvB4nm5SBqx0nuKcvbbxkEEhIRVOb/GiMJ8ZwjeVxy";
+ sout << "b07lDWMdBxLyxytbDoNZbefVc8N0YFrsFvrc16W8y4XKS2NG7Kfa4u977Rr6Yd6+u2+DIYwfsrzM";
+ sout << "DWs67DC1+vHawIhIQ5NyB/R6hx/Tym42s7IK7l9w80SaBvpN2A6vaDlNlOD7ANeZsmaAG8L6zFsb";
+ sout << "FLBPRNFxm+HxHYjpG3+dBySqcIyK0tSFfykLVffPMFCtgzTwDM3+5VJEghCyVlGEhDl5tv97jwVI";
+ sout << "aTYM84hH/JqKcDUQwmilCXqbb8bh/kTEWKzsuuyaLxr3D6+5MhVMakjG3RkIKfjddRukZC341efa";
+ sout << "fFJVWanSIf3lSrhFjsiH5k63EjbZmAmuMpfiLZ/Eh+zjBbh0dE1spYJ62S+PwkIXMTjCxmoJW46O";
+ sout << "5Bdl71SqZW7lnQUubNcgRlxAZGKY3aQlIV2qHptzubV04SMqdcyVYPrPmvZSM5U40eDoJB9UYB/D";
+ sout << "JaxlWFMYrW1ntJTv+LrtYepKdD3uC9SLKi5VYrWjx9WzRMxxkDcNt85Ak2cfpA3MQXI+uT7lQVoG";
+ sout << "vDM472l4Ruls317G1X95HA3lcJNseSQnFsAeE9GhE6XXmxDwIDqUQ6Mk+F71rvDufNIj+GaUTISI";
+ sout << "+lwC4tqse9dg+7wcqSTqe1cx0K7tZo3ZwR+Y3JEe+6HyURImluo6BRV0n6URS2MLpgN4geV0m+6z";
+ sout << "ZIML0oUAKyJVdyj4XWEq5Q/pPDvz7jDfKyqcHTCwBCuBQEnhsxi7Pr3yEac/dwYDKRSkX+7Q9mdo";
+ sout << "lAjgzkEnxlR2JUJxtrsEuhTiWql6HjmrkT6ohOeR/k24Y5cEiKx3QM7gBcB57IoxxMVaDJut+Q3x";
+ sout << "o4HQ7J/IRtkKCPp0LLSK4Ue+nSgJY5LWtD2a4KzmVnahmXeYXCuAsKm3+7dGFISapet16Grci2+j";
+ sout << "ETdZGgs0DEdZWncLxrCnhnvXHrfsSsVyqocOtTs1KiwUgGgcEhlSBZyH6bpF7IYhQhPzML/SIwy9";
+ sout << "ByDnFZIvESXwQObXa1iYfR37SyV0Yw/75bcWAnx/+jbodjLZF+RMskqs/i+PnEasmlC5wrjYo55o";
+ sout << "E0tqAakxpTNP8+TrRgNMLFP4tP25lfDLUka+pUXoeLC7gWQ0MTA8LDx5GyMlblaIFGHx/eMZv05r";
+ sout << "he2nLJgKs5ZtTMh+8pMNkEBWXQDKyJn0oVLcRQL80bPfpUfsnTW+jrEkmH8azM7CN028QyPEhBPx";
+ sout << "lvDsEjVbElUqBH+eo3o0CYLpnABUQ28KvsJQhHAsk2WVScBv/hm/T4c2zABqaDOl2LkRCK9ZeF2W";
+ sout << "kqLNjVH/iZBVQcjo49++lINm4Nupi0GPnWpsuMnNoOdZzK7Oygqbxs9Zz5VqdjRCpxQGQgOhAD2x";
+ sout << "aifYz6RSk7i0RB33QkjE3u+MMEQLL51zl2EOREXed6AJT0U9D0+hSLH39wqMbdEV9bIFcJ7XprJP";
+ sout << "ueO/zZOmCHIBfz+47Kec5BAWkufEOfKxOxwV/Wvh9+VA2RcbA0exvVj3Fy/hPmRaw6E+SBZTJoi0";
+ sout << "aQltfhFmfLQtT4SRl2FjDwlpNJHjmbmVOALldNBRhxSz4oWfurPqAB1A4MazDorJI74vqElH1SMQ";
+ sout << "7RZI0hDhS100VqCdgnYhtMa/We1kNuLcLi7cSq28JPUE0tkIE8SVLldq2Ekn90XBP3mY8LJe0WMR";
+ sout << "J/LzfvopGKYrqday/eYuiXl+lihXp0dpph9ZUpRzmAYWtns/djnzFSGS3bF7VOOaFuGxwYHi/1p1";
+ sout << "SaiBJV2a4U1vr0mh2OU3wMuqd/t5nfl+Sv75zG3YOAvcoiUH76xxHvRnn+D97TegC72k4mUgBCPs";
+ sout << "NZJUhjyKF5S3rerk5uhxFmlAmy+XMo9VJJaYs6OxiIeSQa3iAnbnWDzY3m1chupbDYK2fs8tBg3c";
+ sout << "ta7LM5SgWZLMN5Lj1RtEbFBpAylXWrb3JXGuaK/C2biFNTSFU10Xoit4NcqZuYTXu6OhNahZqayp";
+ sout << "IYXTMXUy1KPyHHMFWZ9gXXB/igOg80NZ2O+EYHHfRiqSOFJ7xhnY+H4X0kh9v6N7lxlv1WOudRnJ";
+ sout << "lWg4FktfX+JVDflDvJ0IqKuM77O5fgtBpRL3tpS2dn4U227ZiXRybiv9zCnbpK7cfq7/t69AVgUG";
+ sout << "qjyTGZ+YG+PDjVBWC0KGsJWIBIHwCoTgIZO5kgU/831jcG+LyuHWf1Q9bIEkOXy6UzexKYCJgIaf";
+ sout << "EH8IsTAVe+uUvVSBpr/Eis7LJkyx6KOAbDR3eXjXjhXJTSTK2efctB2Bb3GraweN/LUZ7suO16Ng";
+ sout << "odHLPnS7jsgag4B7sLNCE3NeGfd4VIE6bYg/9+Ci8k26WMJzOHIO3H+Kc7cTPCdohAQS93epRcLz";
+ sout << "bz957op4v9NBlgMcIKW3Fp1opHX/+vApvBecnNDF2oJ0bUq4pFuM8Ag+2PdSDlJfcdnq6S0lNA8w";
+ sout << "H/65WlKeVTHRt0nm/p3jJ8JNUlIZa8iZyGa/vpCaLqymASIcIfYEdfi4ugCOvbKgBOT//4cl8kji";
+ sout << "3BpcImX1IRx1UzgcIcK4lqO8WoSpGqhClHHi+/ezdj06JLKP27zKeHZwhTMAyYifDnCOJG2XtuY6";
+ sout << "gg4LTyI7BcZ5jAogPeibS+GdwlEfN+hOk1idevzrgEdzJTIR3i6p7ugPS4unLC2+waQfdJMs4Vjd";
+ sout << "aCmDdalKnQlpHL0Smr7Ezu0P1jPQcyO2hGupDOHl/TslPpoPcA1N6OSTraByTWL4yo9aKG1QDIOX";
+ sout << "NLGvos5UGUwmPxJMTeBGkxT7hyKmNzMmhfKB/yfqDaBaV6fosCm7S1FVsdnU4PvcFxa61Jrtok6n";
+ sout << "/lwuY3BM5VCG9/kZQtVUtImeRB4hYy8fETZI/WH+WSGhZZaum+Z7r4Ngg8/DC21U33wyR/WElnVE";
+ sout << "+Ny+yHoK9xkiLpLlx4ns79XnKV+S511udI6XniRyg36kCNdRWTfN81HfbS6h/KXzl2ByPVo0Q9Op";
+ sout << "lTF2Ox2WJ0x0RMrE1SyS1KMyL1ff+ToYLKarrgalvFUvSMI7cxPlc6phczmtYojlm9R7JdJDpuGB";
+ sout << "k9qA/UEnH0v+wl1+5rt+RZ4dHtD4FOXvnqW3GdN7XvISIFJLhq+UMhBPiI4M39ZRxGpuROiF3Dzd";
+ sout << "gD19YX9c930yj3JEhkbYzjzwa9CWLcnXbSYe9sayBJmCjpUxYmfrWo+mH66POo9ZxJeeEOPiVS77";
+ sout << "HXv52am6J9m3gm2Buruibflo5CGo6Ngw9slSm2rcP+pGEAyKjQoLrU6fb0TxqUnRB8q4PVvbBRFj";
+ sout << "c1y5MU03KShFhFfcM7iND5az9RcBJwyNZtCZr7OxlUDVDJP5IJubY/5dahCfiX3EfMXT/dR6R8Yq";
+ sout << "jnUNqjVGnIVm456KF1xlEZBDwCL5LNUhGCV+kp90STd77lypcQPP1q8iFDo1HhHnJDj8is+5eMY6";
+ sout << "5bfMP1EP13tCqPnHgD4hK7Ega8bkvoEKZInxohWoaS6F6o3lslr267EKKvjHDmnA9uGWzxss7nuj";
+ sout << "3PlaWuYpU9NfJp6Eq0xYNYcf4qLTN3Nhd/GRpptxzdMCuu8x2CYKBkHKMy5iJELoJ8ntEdLtVdd+";
+ sout << "46L46j8oPIUyswAbmodfp82VIodnJ8xvJUyJWRCXpCa2esfid02xvel9h8w1V8Uc7BKu9HMsFCDn";
+ sout << "+FbebdVIahMe67eUI3i9+mSZ3BNq0qbkn/dDabGwixw9lsCU5uyHZrzafG6lIA/MiepNzIRLO2Wp";
+ sout << "iwhMc4f4l8ZRb89HhL4bjl0Wo44VEBP+Sqf+jWMPS8qo5st+MZbtu89RObDH4fOr+1v5dvGkl5/d";
+ sout << "YROeKTCWyAWWSSoinCHoR/IcBi8J31l34Pk3nNrZ7BAyMn28ZY7mue6AF6hpieJ+o/0RRv5m6VM1";
+ sout << "1OpAFyajq7E9/BjrAfRSpCtoicoIx+APfjrLbCkEehMCBZoOM9USTMi92Id+Fm2iI+/AMStEU6tJ";
+ sout << "zumgP60vNm3VwZw6ep1m1DSb+2lduAxtuDtaspztXYx3QhksLmbzeKD91/v2Mu/JveEfy35tf1kY";
+ sout << "fcFtGY/puCXnAIxEeKY26aiZroOWt6HLIafNvD4wg//A/9iTszmQauiTWQbBLxcK/KreAE2TOpJ4";
+ sout << "klwh8NfDhULKVKoJ/AgNS4I/Y/LFZSX1VerAImW/WQmpkqvjnOUrpLNkuV2NjA2tKP1G45dInmG/";
+ sout << "nAnaPoGs/bAdQLIJayWx42xsKM1pil6q86WvoKcEH1WrRRXsRJgw0dU+Z1QIFlC+xKqP7xS0LZEs";
+ sout << "fjyeED07SGKBhUj1kbSp9E3PMr32yDHj23zmnFyBcwihktaRV8fEquoE0CacrSoxfdxzdWIDy9Ya";
+ sout << "6tVBbCcxWoz4CSGzBKSrEXVjedHR6H3etlrGwzPWYwbiFQQu3oSNml6ORLh2fu87qIjz8VrO2804";
+ sout << "dAqaUcJBtHOOB5KnNMSyZWA1+62WX3P/yJN8Be3AcHY7xf9bP1HUxyq5cWcAbtx9yJV1dhXyvPz6";
+ sout << "D4k1sntDiWgivLkY7l3OBVRs/QKGOOTj1uLei/v6X1vR54lH5wsRwq6niPYsyEcCnWk2zzTKOcK/";
+ sout << "OBhLQHtj5Wxf+NQ/S1zZnIBS0cww5BqZD0aATVsHfifZLEw2eRM+vap33o0ozqXIdbCzzsotLwlx";
+ sout << "CfFT2a/MUDatTSTuGjOPpjbmVtK7qWkWmHAF0j/GQFs73K/8PGmAWomFb2WWB8frlHkkJoEPyHVR";
+ sout << "N6hlsAEd8ylb6SwVawZglxdpqfhQpn71NaRU/ZJnk7fnSVmKwcGVeh1sppPYKEu/Q7i4YrAGtLDp";
+ sout << "Osipwqh5gL/skuxQwsBI/VNqoxIELFbJKhGrVqr2kHWB3yE9cB5puaP/MPPvIwbPSNRZVVbvGV+m";
+ sout << "Mmi4LDSKPlxDmjl3wWnzdldnNUb281O5dA7lMuPHgibIFjl8EeVHBFLocpnFbSLRiCZx+xh8sNRX";
+ sout << "ti+5Zl/pGvlloP2k7yVNRFpB3Av30jzQePY/DOy5DoF2k4eY0qvSR7HtV2IG+9wpBAUCbIaD/2t3";
+ sout << "FBcQT2cWEx0TPZbBqPLAxruNGcv1z23BdWUeQpeEV6ij78xVILQegtNiJlCgMTCmIRr1meCfuVJ5";
+ sout << "vbvNLxRiL4JSZbjKgfonK98os72YhGB0Y2oL7ZpcthGFKCGAjo76XrMB/hoeXlx6UKYlLknZImWO";
+ sout << "NQj5SAhZElv+LKdSZ1oc9U+0Qm8K4G9+8OnQ4wS61R0MpZKXnfFAMp99XHqN0tD6W5Br22/IOyj9";
+ sout << "x2QH50KAfIYZFrJaLt0MMylpVhHPQ1RBembk0wDcwDt2phnU5nm3m22r81uUHp5Sc0Or/vrzpMEc";
+ sout << "9vqhEFhFG74o+EGTigUiyzCGUHJu9ITqb55xDmPm829KOqkOEV3L4XX/jLwRFVkbAHaijjmg17xN";
+ sout << "yJZkFlYG/RntbFd+p+10JuhI9ahBvvCjFnkC+8/F64Umx4jlWJRPIADPr3yH6tt+nVkC2Cbx+U4y";
+ sout << "bgINetMWF1b3syNyJh4qJy/KXX5UckzU9A1Xkz2MvXWiB8Jp0f+Q/bZ0inTmecpyBGWCsD0xRYPv";
+ sout << "Ktm2zqD2HL3Qtfpyi0khXGmUDKwvuFnjX5V57ustxKCq2h/ylqTDCMpqS/QQ122Tr2DuHHVjGwGG";
+ sout << "PajaD64n0qEj+zaqyUgpDZatEVMgFxpLpbJ+sTyBaYKYPYqGxoo+XycDYIWoc7K+7d2NT76qQcI/";
+ sout << "LHo0kGXRTNiciphMn4EX3PkH7zUqR/yCWBRsP4wtPax6aM2m6IpY3bDbOfwMzXDTZDLAfi3is8Lb";
+ sout << "H5AfK2DDG45w4olFJuLKO27zafuz9GIAgoTsUwpJIJeR4xdtlP9EtMjG6SZdp22+MBSh6w4CXe82";
+ sout << "iIquldB5HzaHvBr3UP5LGQR3W/Wky4udFj2+XCpKsDUll+4MZCeGz1VM5ZYjDKfRDzSaLjJspoVH";
+ sout << "TT05yuagEN3+w/ROxo419ZEnvR+X4QLyp0JofhFpSsr+J9fQwaz3IGpPMs7wA+Q9CfEw3FkrFfeH";
+ sout << "022qfnq3naUsXyh8xAeI4r+MP6WGlscgFkWPZhM/vDq2IcIbCwZ10Ivx8r9bwi6bAfPE5cru0oRf";
+ sout << "YYs7kU0l9z3yzeSvjVjkdZbmgRNPpED67KlUQyRibkWe34KGglSuSClFu6PWrfFxb/9ch0KTFVQS";
+ sout << "fo3NDryMrF6oeLHe7osiyJ1ismBimKN3idmZ7yUknYbpR7/Z2K0LbAIttidcQ8LpgUzy/6jROtHE";
+ sout << "MRLxKl4ba6tx6BP8WpXWVzx6xYox42qMX24SkhhNVR/9ETntqyFEeDXXlS88Nm14AZARVAp8wJ3n";
+ sout << "WaA9OydaCeVzFVZ614BRC41g8VoHz9iZxHF/7fL5gz+YTJH7ccf18hxv4U/y0ROiZAQflTW6+lYI";
+ sout << "kIO8t9A9eQ7Djh7cOHCyA7E391UriFUZH8kBzKjnMeF27uYx5U9acC5YUVtDk7O7Xq3wqeGaxS1b";
+ sout << "oTVKitvXCbdmvIiCK8N8RZeVlteuYhJ20u7LTHsO87c1Tv83pu9OiYDMwEpAFLwfHwi5YfZecAIA";
+ sout << "bkkveDusgnh80pxc/24zp7W42Hm6hfq852ZLLWD2bHm3YZdey2WIbGsjvQS+io4xoEKGLHUOi7Vs";
+ sout << "/Te80Yujo61Y3BjY8w58xNplBAXvK4Fn2EenYcDz839o62+xeUuGN4wSUpJF94aU9//hEF93J9Zf";
+ sout << "CxFPlooKe1W7nU3sl2ziEh0PeM4RImmQ9zVNkDEx8ahdmvWWSWlmlv1K754XAdwp68cdfKUIY8kk";
+ sout << "M666TRZrSCDUvDufPL4y1LTpQUmIhtNLkrYZg0WEXx5QRF00x90VrsIFcXA/3Sjc2MPi0iMck6vM";
+ sout << "kh4kYbIqlvgQSwTqc6XmXEM/W3HRuVCH8kQKIXjN5my0C3RbbYIr4vAPaQL8PCWaIjbLSHDcutk1";
+ sout << "LBjyPnJQzn2/vvm+3oKN2JXPxeKvTbaoYu7EvWtXxkFJlrKY+o+FQfxmtN8qxniGWVLuLX6k8keB";
+ sout << "wCLZV6wF1XG5FsqBoVh8qvAV/Tykk0VO5FJnoJ0sVwOveUgbCFKsReM24lZlgzHsHFaTN4qSWHyi";
+ sout << "BGg3nr0dETiOyEfLP/ckF4GhGMMQZmT9jMFyUEgDfszoDgF+jxA1buvCKTRm4iVgFT16pi4wTG8I";
+ sout << "fCOM+SU2UltyA9EHdOW3rm8P8w4z0D0/IzH9Vd/O8Y2kB4TUOW4yHdB95MBCJbvEGYXR/A1eTsMW";
+ sout << "J96ZVGCmlcOzWqPsX61LShr57TLbE7+hJzxUMEW+ri+goqnSehcpGurw0+d4K3+5jFHxUpNaT5n+";
+ sout << "t0p0S42lOvrgigSZrQiJozBCCpm4sCCM91bzGge0VpyuIGTVkzbzXfjzF5J/wliLSVyzqSNq+HqL";
+ sout << "06BqMB0leDtLT/msAUV6SRKr7/n6qMgkrIfrAri4d13kZTpexQTRTFXxcoIgwUelzsxhdwZQ4yMX";
+ sout << "dAeL/oKOQdjCgnsVFD5BD7clhTEOA8gVVxASHinziac7ZkdYaHwESDZBxJhv4j1nkKNLWZLifph7";
+ sout << "NdJBK51yqlBvkZKROBbTiQi7UCe/YGXusDPCBsu55iieo/o65HkF7olpANKxvHxBrQuCTxYnL1tl";
+ sout << "pMjxmZc1BA2T7vyd4AzXRJ6tMUdqUqnjPVFY7OIKMgUn7qRieMqt95vzJqh8+jdApnY+xwnIKosv";
+ sout << "ox55mijPLs9oUBzAJPpD3nDLR9pnTVIkY2RmVRQFUN/kuJHYbNtc0PRIAv6iDiZhe+jeCkTx/dXC";
+ sout << "sSVwD5hp/v0TvaPa0XSPr1BbqlvK6KjtdsVJsUOjHFskNm/8qlIGKp9F5QCtLOBhp1eoy2AZlNlN";
+ sout << "+eYQRzwMSsJNxq44rixF97d7qeiOkC/Uu3wNk7aL11AR5iS7gau10LHLs3YhMbUcb+4kf2j9NpWG";
+ sout << "wqMklOYYJag/XNyoQs8g44qAha1rVyeq4eXodi0JegvjkXWEB4Mq8jBuHXbYjYiRiHoL68/9mry5";
+ sout << "nlN2Duwp7g5yl982CZLZc0k7uSjKaDkWyynH60MwLnmVj2sA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+}
+
+#endif // DLIB_FRONTAL_FACE_DETECTOr_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h b/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h
new file mode 100644
index 000000000..20815cd0e
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h
@@ -0,0 +1,25 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_
+#ifdef DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_
+
+#include "object_detector_abstract.h"
+#include "scan_fhog_pyramid_abstract.h"
+#include "../image_transforms/image_pyramid_abstract.h"
+
+namespace dlib
+{
+ typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > frontal_face_detector;
+
+ frontal_face_detector get_frontal_face_detector(
+ );
+ /*!
+ ensures
+ - returns an object_detector that is configured to find human faces that are
+ looking more or less towards the camera.
+ !*/
+
+}
+
+#endif // DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/full_object_detection.h b/ml/dlib/dlib/image_processing/full_object_detection.h
new file mode 100644
index 000000000..1dfc99b2d
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/full_object_detection.h
@@ -0,0 +1,191 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FULL_OBJECT_DeTECTION_Hh_
+#define DLIB_FULL_OBJECT_DeTECTION_Hh_
+
+#include "../geometry.h"
+#include "full_object_detection_abstract.h"
+#include <vector>
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF,
+ 0x7FFFFFFF);
+
+// ----------------------------------------------------------------------------------------
+
+ class full_object_detection
+ {
+ public:
+ full_object_detection(
+ const rectangle& rect_,
+ const std::vector<point>& parts_
+ ) : rect(rect_), parts(parts_) {}
+
+ full_object_detection(){}
+
+ explicit full_object_detection(
+ const rectangle& rect_
+ ) : rect(rect_) {}
+
+ const rectangle& get_rect() const { return rect; }
+ rectangle& get_rect() { return rect; }
+ unsigned long num_parts() const { return parts.size(); }
+
+ const point& part(
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < num_parts(),
+ "\t point full_object_detection::part()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t idx: " << idx
+ << "\n\t num_parts(): " << num_parts()
+ << "\n\t this: " << this
+ );
+ return parts[idx];
+ }
+
+ point& part(
+ unsigned long idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < num_parts(),
+ "\t point full_object_detection::part()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t idx: " << idx
+ << "\n\t num_parts(): " << num_parts()
+ << "\n\t this: " << this
+ );
+ return parts[idx];
+ }
+
+ friend void serialize (
+ const full_object_detection& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.rect, out);
+ serialize(item.parts, out);
+ }
+
+ friend void deserialize (
+ full_object_detection& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version encountered while deserializing dlib::full_object_detection.");
+
+ deserialize(item.rect, in);
+ deserialize(item.parts, in);
+ }
+
+ bool operator==(
+ const full_object_detection& rhs
+ ) const
+ {
+ if (rect != rhs.rect)
+ return false;
+ if (parts.size() != rhs.parts.size())
+ return false;
+ for (size_t i = 0; i < parts.size(); ++i)
+ {
+ if (parts[i] != rhs.parts[i])
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ rectangle rect;
+ std::vector<point> parts;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool all_parts_in_rect (
+ const full_object_detection& obj
+ )
+ {
+ for (unsigned long i = 0; i < obj.num_parts(); ++i)
+ {
+ if (obj.get_rect().contains(obj.part(i)) == false &&
+ obj.part(i) != OBJECT_PART_NOT_PRESENT)
+ return false;
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ struct mmod_rect
+ {
+ mmod_rect() = default;
+ mmod_rect(const rectangle& r) : rect(r) {}
+ mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
+ mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {}
+
+ rectangle rect;
+ double detection_confidence = 0;
+ bool ignore = false;
+ std::string label;
+
+ operator rectangle() const { return rect; }
+ bool operator == (const mmod_rect& rhs) const
+ {
+ return rect == rhs.rect
+ && detection_confidence == rhs.detection_confidence
+ && ignore == rhs.ignore
+ && label == rhs.label;
+ }
+ };
+
+ inline mmod_rect ignored_mmod_rect(const rectangle& r)
+ {
+ mmod_rect temp(r);
+ temp.ignore = true;
+ return temp;
+ }
+
+ inline void serialize(const mmod_rect& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.rect, out);
+ serialize(item.detection_confidence, out);
+ serialize(item.ignore, out);
+ serialize(item.label, out);
+ }
+
+ inline void deserialize(mmod_rect& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1 && version != 2)
+ throw serialization_error("Unexpected version found while deserializing dlib::mmod_rect");
+ deserialize(item.rect, in);
+ deserialize(item.detection_confidence, in);
+ deserialize(item.ignore, in);
+ if (version == 2)
+ deserialize(item.label, in);
+ else
+ item.label = "";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FULL_OBJECT_DeTECTION_H_
+
diff --git a/ml/dlib/dlib/image_processing/full_object_detection_abstract.h b/ml/dlib/dlib/image_processing/full_object_detection_abstract.h
new file mode 100644
index 000000000..099ee01b0
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/full_object_detection_abstract.h
@@ -0,0 +1,203 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_
+#ifdef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_
+
+#include <vector>
+#include "../geometry.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF,
+ 0x7FFFFFFF);
+
+// ----------------------------------------------------------------------------------------
+
+ class full_object_detection
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the location of an object in an image along with the
+ positions of each of its constituent parts.
+ !*/
+
+ public:
+
+ full_object_detection(
+ const rectangle& rect,
+ const std::vector<point>& parts
+ );
+ /*!
+ ensures
+ - #get_rect() == rect
+ - #num_parts() == parts.size()
+ - for all valid i:
+ - part(i) == parts[i]
+ !*/
+
+ full_object_detection(
+ );
+ /*!
+ ensures
+ - #get_rect().is_empty() == true
+ - #num_parts() == 0
+ !*/
+
+ explicit full_object_detection(
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - #get_rect() == rect
+ - #num_parts() == 0
+ !*/
+
+ const rectangle& get_rect(
+ ) const;
+ /*!
+ ensures
+ - returns the rectangle that indicates where this object is. In general,
+ this should be the bounding box for the object.
+ !*/
+
+ rectangle& get_rect(
+ );
+ /*!
+ ensures
+ - returns the rectangle that indicates where this object is. In general,
+ this should be the bounding box for the object.
+ !*/
+
+ unsigned long num_parts(
+ ) const;
+ /*!
+ ensures
+ - returns the number of parts in this object.
+ !*/
+
+ const point& part(
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < num_parts()
+ ensures
+ - returns the location of the center of the idx-th part of this object.
+ Note that it is valid for a part to be "not present". This is indicated
+ when the return value of part() is equal to OBJECT_PART_NOT_PRESENT.
+ This is useful for modeling object parts that are not always observed.
+ !*/
+
+ point& part(
+ unsigned long idx
+ );
+ /*!
+ requires
+ - idx < num_parts()
+ ensures
+ - returns the location of the center of the idx-th part of this object.
+ Note that it is valid for a part to be "not present". This is indicated
+ when the return value of part() is equal to OBJECT_PART_NOT_PRESENT.
+ This is useful for modeling object parts that are not always observed.
+ !*/
+
+ bool operator==(
+ const full_object_detection& rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if and only if *this and rhs have identical state.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const full_object_detection& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ full_object_detection& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool all_parts_in_rect (
+ const full_object_detection& obj
+ );
+ /*!
+ ensures
+ - returns true if all the parts in obj are contained within obj.get_rect().
+ That is, returns true if and only if, for all valid i, the following is
+ always true:
+ obj.get_rect().contains(obj.part(i)) == true || obj.part(i) == OBJECT_PART_NOT_PRESENT
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct mmod_rect
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that is used to give training data and receive detections
+ from the Max-Margin Object Detection loss layer loss_mmod_ object.
+ !*/
+
+ mmod_rect() = default;
+ mmod_rect(const rectangle& r) : rect(r) {}
+ mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
+ mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score),label(label) {}
+
+ rectangle rect;
+ double detection_confidence = 0;
+ bool ignore = false;
+ std::string label;
+
+ operator rectangle() const { return rect; }
+
+ bool operator == (const mmod_rect& rhs) const;
+ /*!
+ ensures
+ - returns true if and only if all the elements of this object compare equal
+ to the corresponding elements of rhs.
+ !*/
+ };
+
+ mmod_rect ignored_mmod_rect(
+ const rectangle& r
+ );
+ /*!
+ ensures
+ - returns a mmod_rect R such that:
+ - R.rect == r
+ - R.ignore == true
+ - R.detection_confidence == 0
+ - R.label == ""
+ !*/
+
+ void serialize(const mmod_rect& item, std::ostream& out);
+ void deserialize(mmod_rect& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/generic_image.h b/ml/dlib/dlib/image_processing/generic_image.h
new file mode 100644
index 000000000..362277368
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/generic_image.h
@@ -0,0 +1,431 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GeNERIC_IMAGE_Hh_
+#define DLIB_GeNERIC_IMAGE_Hh_
+
+#include "../assert.h"
+
+namespace dlib
+{
+
+ /*!
+ In dlib, an "image" is any object that implements the generic image interface. In
+ particular, this simply means that an image type (let's refer to it as image_type
+ from here on) has the following seven global functions defined for it:
+ - long num_rows (const image_type& img)
+ - long num_columns (const image_type& img)
+ - void set_image_size( image_type& img, long rows, long cols)
+ - void* image_data ( image_type& img)
+ - const void* image_data (const image_type& img)
+ - long width_step (const image_type& img)
+ - void swap ( image_type& a, image_type& b)
+ And also provides a specialization of the image_traits template that looks like:
+ namespace dlib
+ {
+ template <>
+ struct image_traits<image_type>
+ {
+ typedef the_type_of_pixel_used_in_image_type pixel_type;
+ };
+ }
+
+ Additionally, an image object must be default constructable. This means that
+ expressions of the form:
+ image_type img;
+ Must be legal.
+
+ Finally, the type of pixel in image_type must have a pixel_traits specialization.
+ That is, pixel_traits<typename image_traits<image_type>::pixel_type> must be one of
+ the specializations of pixel_traits.
+
+
+ To be very precise, the seven functions defined above are defined thusly:
+
+ long num_rows(
+ const image_type& img
+ );
+ /!*
+ ensures
+ - returns the number of rows in the given image
+ *!/
+
+ long num_columns(
+ const image_type& img
+ );
+ /!*
+ ensures
+ - returns the number of columns in the given image
+ *!/
+
+ void set_image_size(
+ image_type& img,
+ long rows,
+ long cols
+ );
+ /!*
+ requires
+ - rows >= 0 && cols >= 0
+ ensures
+ - num_rows(#img) == rows
+ - num_columns(#img) == cols
+ *!/
+
+ void* image_data(
+ image_type& img
+ );
+ /!*
+ ensures
+ - returns a non-const pointer to the pixel at row and column position 0,0
+ in the given image. Or if the image has zero rows or columns in it
+ then this function returns NULL.
+ - The image lays pixels down in row major order. However, there might
+ be padding at the end of each row. The amount of padding is given by
+ width_step(img).
+ *!/
+
+ const void* image_data(
+ const image_type& img
+ );
+ /!*
+ ensures
+ - returns a const pointer to the pixel at row and column position 0,0 in
+ the given image. Or if the image has zero rows or columns in it then
+ this function returns NULL.
+ - The image lays pixels down in row major order. However, there might
+ be padding at the end of each row. The amount of padding is given by
+ width_step(img).
+ *!/
+
+ long width_step(
+ const image_type& img
+ );
+ /!*
+ ensures
+ - returns the size of one row of the image, in bytes. More precisely,
+ return a number N such that: (char*)image_data(img) + N*R == a
+ pointer to the first pixel in the R-th row of the image. This means
+ that the image must lay its pixels down in row major order.
+ *!/
+
+ void swap(
+ image_type& a,
+ image_type& b
+ );
+ /!*
+ ensures
+ - swaps the state of a and b
+ *!/
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ struct image_traits;
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a traits class for generic image objects. You can use it to find out
+ the pixel type contained within an image via an expression of the form:
+ image_traits<image_type>::pixel_type
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// UTILITIES TO MAKE ACCESSING IMAGE PIXELS SIMPLER
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ class image_view
+ {
+ /*!
+ REQUIREMENTS ON image_type
+ image_type must be an image object as defined at the top of this file.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object takes an image object and wraps it with an interface that makes
+ it look like a dlib::array2d. That is, it makes it look similar to a
+ regular 2-dimensional C style array, making code which operates on the
+ pixels simple to read.
+
+ Note that an image_view instance is valid until the image given to its
+ constructor is modified through an interface other than the image_view
+ instance. This is because, for example, someone might cause the underlying
+ image object to reallocate its memory, thus invalidating the pointer to its
+ pixel data stored in the image_view.
+
+ As an side, the reason why this object stores a pointer to the image
+ object's data and uses that pointer instead of calling image_data() each
+ time a pixel is accessed is to allow for image objects to implement
+ complex, and possibly slow, image_data() functions. For example, an image
+ object might perform some kind of synchronization between a GPU and the
+ host memory during a call to image_data(). Therefore, we call image_data()
+ only in image_view's constructor to avoid the performance penalty of
+ calling it for each pixel access.
+ !*/
+
+ public:
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ image_view(
+ image_type& img
+ ) :
+ _data((char*)image_data(img)),
+ _width_step(width_step(img)),
+ _nr(num_rows(img)),
+ _nc(num_columns(img)),
+ _img(&img)
+ {}
+
+ long nr() const { return _nr; }
+ /*!
+ ensures
+ - returns the number of rows in this image.
+ !*/
+
+ long nc() const { return _nc; }
+ /*!
+ ensures
+ - returns the number of columns in this image.
+ !*/
+
+ unsigned long size() const { return static_cast<unsigned long>(nr()*nc()); }
+ /*!
+ ensures
+ - returns the number of pixels in this image.
+ !*/
+
+#ifndef ENABLE_ASSERTS
+ pixel_type* operator[] (long row) { return (pixel_type*)(_data+_width_step*row); }
+ /*!
+ requires
+ - 0 <= row < nr()
+ ensures
+ - returns a pointer to the first pixel in the row-th row. Therefore, the
+ pixel at row and column position r,c can be accessed via (*this)[r][c].
+ !*/
+
+ const pixel_type* operator[] (long row) const { return (const pixel_type*)(_data+_width_step*row); }
+ /*!
+ requires
+ - 0 <= row < nr()
+ ensures
+ - returns a const pointer to the first pixel in the row-th row. Therefore,
+ the pixel at row and column position r,c can be accessed via
+ (*this)[r][c].
+ !*/
+#else
+ // If asserts are enabled then we need to return a proxy class so we can make sure
+ // the column accesses don't go out of bounds.
+ struct pix_row
+ {
+ pix_row(pixel_type* data_, long nc_) : data(data_),_nc(nc_) {}
+ const pixel_type& operator[] (long col) const
+ {
+ DLIB_ASSERT(0 <= col && col < _nc,
+ "\t The given column index is out of range."
+ << "\n\t col: " << col
+ << "\n\t _nc: " << _nc);
+ return data[col];
+ }
+ pixel_type& operator[] (long col)
+ {
+ DLIB_ASSERT(0 <= col && col < _nc,
+ "\t The given column index is out of range."
+ << "\n\t col: " << col
+ << "\n\t _nc: " << _nc);
+ return data[col];
+ }
+ private:
+ pixel_type* const data;
+ const long _nc;
+ };
+ pix_row operator[] (long row)
+ {
+ DLIB_ASSERT(0 <= row && row < _nr,
+ "\t The given row index is out of range."
+ << "\n\t row: " << row
+ << "\n\t _nr: " << _nr);
+ return pix_row((pixel_type*)(_data+_width_step*row), _nc);
+ }
+ const pix_row operator[] (long row) const
+ {
+ DLIB_ASSERT(0 <= row && row < _nr,
+ "\t The given row index is out of range."
+ << "\n\t row: " << row
+ << "\n\t _nr: " << _nr);
+ return pix_row((pixel_type*)(_data+_width_step*row), _nc);
+ }
+#endif
+
+ void set_size(long rows, long cols)
+ /*!
+ requires
+ - rows >= 0 && cols >= 0
+ ensures
+ - Tells the underlying image to resize itself to have the given number of
+ rows and columns.
+ - #nr() == rows
+ - #nc() == cols
+ !*/
+ {
+ DLIB_ASSERT((cols >= 0 && rows >= 0),
+ "\t image_view::set_size(long rows, long cols)"
+ << "\n\t The images can't have negative rows or columns."
+ << "\n\t cols: " << cols
+ << "\n\t rows: " << rows
+ );
+ set_image_size(*_img, rows, cols); *this = *_img;
+ }
+
+ void clear() { set_size(0,0); }
+ /*!
+ ensures
+ - sets the image to have 0 pixels in it.
+ !*/
+
+ private:
+
+ char* _data;
+ long _width_step;
+ long _nr;
+ long _nc;
+ image_type* _img;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ class const_image_view
+ {
+ /*!
+ REQUIREMENTS ON image_type
+ image_type must be an image object as defined at the top of this file.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is just like the image_view except that it provides a "const"
+ view into an image. That is, it has the same interface as image_view
+ except that you can't modify the image through a const_image_view.
+ !*/
+
+ public:
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ const_image_view(
+ const image_type& img
+ ) :
+ _data((char*)image_data(img)),
+ _width_step(width_step(img)),
+ _nr(num_rows(img)),
+ _nc(num_columns(img))
+ {}
+
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ unsigned long size() const { return static_cast<unsigned long>(nr()*nc()); }
+#ifndef ENABLE_ASSERTS
+ const pixel_type* operator[] (long row) const { return (const pixel_type*)(_data+_width_step*row); }
+#else
+ // If asserts are enabled then we need to return a proxy class so we can make sure
+ // the column accesses don't go out of bounds.
+ struct pix_row
+ {
+ pix_row(pixel_type* data_, long nc_) : data(data_),_nc(nc_) {}
+ const pixel_type& operator[] (long col) const
+ {
+ DLIB_ASSERT(0 <= col && col < _nc,
+ "\t The given column index is out of range."
+ << "\n\t col: " << col
+ << "\n\t _nc: " << _nc);
+ return data[col];
+ }
+ private:
+ pixel_type* const data;
+ const long _nc;
+ };
+ const pix_row operator[] (long row) const
+ {
+ DLIB_ASSERT(0 <= row && row < _nr,
+ "\t The given row index is out of range."
+ << "\n\t row: " << row
+ << "\n\t _nr: " << _nr);
+ return pix_row((pixel_type*)(_data+_width_step*row), _nc);
+ }
+#endif
+
+ private:
+ const char* _data;
+ long _width_step;
+ long _nr;
+ long _nc;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ image_view<image_type> make_image_view ( image_type& img)
+ { return image_view<image_type>(img); }
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined at the
+ top of this file.
+ ensures
+ - constructs an image_view from an image object
+ !*/
+
+ template <typename image_type>
+ const_image_view<image_type> make_image_view (const image_type& img)
+ { return const_image_view<image_type>(img); }
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined at the
+ top of this file.
+ ensures
+ - constructs a const_image_view from an image object
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ inline unsigned long image_size(
+ const image_type& img
+ ) { return num_columns(img)*num_rows(img); }
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined at the
+ top of this file.
+ ensures
+ - returns the number of pixels in the given image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ inline long num_rows(
+ const image_type& img
+ ) { return img.nr(); }
+ /*!
+ ensures
+ - By default, try to use the member function .nr() to determine the number
+ of rows in an image. However, as stated at the top of this file, image
+ objects should provide their own overload of num_rows() if needed.
+ !*/
+
+ template <typename image_type>
+ inline long num_columns(
+ const image_type& img
+ ) { return img.nc(); }
+ /*!
+ ensures
+ - By default, try to use the member function .nc() to determine the number
+ of columns in an image. However, as stated at the top of this file, image
+ objects should provide their own overload of num_rows() if needed.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GeNERIC_IMAGE_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/object_detector.h b/ml/dlib/dlib/image_processing/object_detector.h
new file mode 100644
index 000000000..9f78abd19
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/object_detector.h
@@ -0,0 +1,626 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OBJECT_DeTECTOR_Hh_
+#define DLIB_OBJECT_DeTECTOR_Hh_
+
+#include "object_detector_abstract.h"
+#include "../geometry.h"
+#include <vector>
+#include "box_overlap_testing.h"
+#include "full_object_detection.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct rect_detection
+ {
+ double detection_confidence;
+ unsigned long weight_index;
+ rectangle rect;
+
+ bool operator<(const rect_detection& item) const { return detection_confidence < item.detection_confidence; }
+ };
+
+ struct full_detection
+ {
+ double detection_confidence;
+ unsigned long weight_index;
+ full_object_detection rect;
+
+ bool operator<(const full_detection& item) const { return detection_confidence < item.detection_confidence; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_scanner_type>
+ struct processed_weight_vector
+ {
+ processed_weight_vector(){}
+
+ typedef typename image_scanner_type::feature_vector_type feature_vector_type;
+
+ void init (
+ const image_scanner_type&
+ )
+ /*!
+ requires
+ - w has already been assigned its value. Note that the point of this
+ function is to allow an image scanner to overload the
+ processed_weight_vector template and provide some different kind of
+ object as the output of get_detect_argument(). For example, the
+ scan_fhog_pyramid object uses an overload that causes
+ get_detect_argument() to return the special fhog_filterbank object
+ instead of a feature_vector_type. This avoids needing to construct the
+ fhog_filterbank during each call to detect and therefore speeds up
+ detection.
+ !*/
+ {}
+
+ // return the first argument to image_scanner_type::detect()
+ const feature_vector_type& get_detect_argument() const { return w; }
+
+ feature_vector_type w;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type_
+ >
+ class object_detector
+ {
+ public:
+ typedef image_scanner_type_ image_scanner_type;
+ typedef typename image_scanner_type::feature_vector_type feature_vector_type;
+
+ object_detector (
+ );
+
+ object_detector (
+ const object_detector& item
+ );
+
+ object_detector (
+ const image_scanner_type& scanner_,
+ const test_box_overlap& overlap_tester_,
+ const feature_vector_type& w_
+ );
+
+ object_detector (
+ const image_scanner_type& scanner_,
+ const test_box_overlap& overlap_tester_,
+ const std::vector<feature_vector_type>& w_
+ );
+
+ explicit object_detector (
+ const std::vector<object_detector>& detectors
+ );
+
+ unsigned long num_detectors (
+ ) const { return w.size(); }
+
+ const feature_vector_type& get_w (
+ unsigned long idx = 0
+ ) const { return w[idx].w; }
+
+ const processed_weight_vector<image_scanner_type>& get_processed_w (
+ unsigned long idx = 0
+ ) const { return w[idx]; }
+
+ const test_box_overlap& get_overlap_tester (
+ ) const;
+
+ const image_scanner_type& get_scanner (
+ ) const;
+
+ object_detector& operator= (
+ const object_detector& item
+ );
+
+ template <
+ typename image_type
+ >
+ std::vector<rectangle> operator() (
+ const image_type& img,
+ double adjust_threshold = 0
+ );
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<std::pair<double, rectangle> >& final_dets,
+ double adjust_threshold = 0
+ );
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<std::pair<double, full_object_detection> >& final_dets,
+ double adjust_threshold = 0
+ );
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<full_object_detection>& final_dets,
+ double adjust_threshold = 0
+ );
+
+ // These typedefs are here for backwards compatibility with previous versions of
+ // dlib.
+ typedef ::dlib::rect_detection rect_detection;
+ typedef ::dlib::full_detection full_detection;
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<rect_detection>& final_dets,
+ double adjust_threshold = 0
+ );
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<full_detection>& final_dets,
+ double adjust_threshold = 0
+ );
+
+ template <typename T>
+ friend void serialize (
+ const object_detector<T>& item,
+ std::ostream& out
+ );
+
+ template <typename T>
+ friend void deserialize (
+ object_detector<T>& item,
+ std::istream& in
+ );
+
+ private:
+
+ bool overlaps_any_box (
+ const std::vector<rect_detection>& rects,
+ const dlib::rectangle& rect
+ ) const
+ {
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ if (boxes_overlap(rects[i].rect, rect))
+ return true;
+ }
+ return false;
+ }
+
+ test_box_overlap boxes_overlap;
+ std::vector<processed_weight_vector<image_scanner_type> > w;
+ image_scanner_type scanner;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const object_detector<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 2;
+ serialize(version, out);
+
+ T scanner;
+ scanner.copy_configuration(item.scanner);
+ serialize(scanner, out);
+ serialize(item.boxes_overlap, out);
+ // serialize all the weight vectors
+ serialize(item.w.size(), out);
+ for (unsigned long i = 0; i < item.w.size(); ++i)
+ serialize(item.w[i].w, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void deserialize (
+ object_detector<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version == 1)
+ {
+ deserialize(item.scanner, in);
+ item.w.resize(1);
+ deserialize(item.w[0].w, in);
+ item.w[0].init(item.scanner);
+ deserialize(item.boxes_overlap, in);
+ }
+ else if (version == 2)
+ {
+ deserialize(item.scanner, in);
+ deserialize(item.boxes_overlap, in);
+ unsigned long num_detectors = 0;
+ deserialize(num_detectors, in);
+ item.w.resize(num_detectors);
+ for (unsigned long i = 0; i < item.w.size(); ++i)
+ {
+ deserialize(item.w[i].w, in);
+ item.w[i].init(item.scanner);
+ }
+ }
+ else
+ {
+ throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object.");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// object_detector member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>::
+ object_detector (
+ )
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>::
+ object_detector (
+ const object_detector& item
+ )
+ {
+ boxes_overlap = item.boxes_overlap;
+ w = item.w;
+ scanner.copy_configuration(item.scanner);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>::
+ object_detector (
+ const image_scanner_type& scanner_,
+ const test_box_overlap& overlap_tester,
+ const feature_vector_type& w_
+ ) :
+ boxes_overlap(overlap_tester)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 &&
+ w_.size() == scanner_.get_num_dimensions() + 1,
+ "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
+ << "\n\t w_.size(): " << w_.size()
+ << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ scanner.copy_configuration(scanner_);
+ w.resize(1);
+ w[0].w = w_;
+ w[0].init(scanner);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>::
+ object_detector (
+ const image_scanner_type& scanner_,
+ const test_box_overlap& overlap_tester,
+ const std::vector<feature_vector_type>& w_
+ ) :
+ boxes_overlap(overlap_tester)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0,
+ "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
+ << "\n\t w_.size(): " << w_.size()
+ << "\n\t this: " << this
+ );
+
+ for (unsigned long i = 0; i < w_.size(); ++i)
+ {
+ DLIB_CASSERT(w_[i].size() == scanner_.get_num_dimensions() + 1,
+ "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
+ << "\n\t w_["<<i<<"].size(): " << w_[i].size()
+ << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
+ << "\n\t this: " << this
+ );
+ }
+
+ scanner.copy_configuration(scanner_);
+ w.resize(w_.size());
+ for (unsigned long i = 0; i < w.size(); ++i)
+ {
+ w[i].w = w_[i];
+ w[i].init(scanner);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>::
+ object_detector (
+ const std::vector<object_detector>& detectors
+ )
+ {
+ DLIB_CASSERT(detectors.size() != 0,
+ "\t object_detector::object_detector(detectors)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t this: " << this
+ );
+ std::vector<feature_vector_type> weights;
+ weights.reserve(detectors.size());
+ for (unsigned long i = 0; i < detectors.size(); ++i)
+ {
+ for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j)
+ weights.push_back(detectors[i].get_w(j));
+ }
+
+ *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ object_detector<image_scanner_type>& object_detector<image_scanner_type>::
+ operator= (
+ const object_detector& item
+ )
+ {
+ if (this == &item)
+ return *this;
+
+ boxes_overlap = item.boxes_overlap;
+ w = item.w;
+ scanner.copy_configuration(item.scanner);
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ void object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ std::vector<rect_detection>& final_dets,
+ double adjust_threshold
+ )
+ {
+ scanner.load(img);
+ std::vector<std::pair<double, rectangle> > dets;
+ std::vector<rect_detection> dets_accum;
+ for (unsigned long i = 0; i < w.size(); ++i)
+ {
+ const double thresh = w[i].w(scanner.get_num_dimensions());
+ scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold);
+ for (unsigned long j = 0; j < dets.size(); ++j)
+ {
+ rect_detection temp;
+ temp.detection_confidence = dets[j].first-thresh;
+ temp.weight_index = i;
+ temp.rect = dets[j].second;
+ dets_accum.push_back(temp);
+ }
+ }
+
+ // Do non-max suppression
+ final_dets.clear();
+ if (w.size() > 1)
+ std::sort(dets_accum.rbegin(), dets_accum.rend());
+ for (unsigned long i = 0; i < dets_accum.size(); ++i)
+ {
+ if (overlaps_any_box(final_dets, dets_accum[i].rect))
+ continue;
+
+ final_dets.push_back(dets_accum[i]);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ void object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ std::vector<full_detection>& final_dets,
+ double adjust_threshold
+ )
+ {
+ std::vector<rect_detection> dets;
+ (*this)(img,dets,adjust_threshold);
+
+ final_dets.resize(dets.size());
+
+ // convert all the rectangle detections into full_object_detections.
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ final_dets[i].detection_confidence = dets[i].detection_confidence;
+ final_dets[i].weight_index = dets[i].weight_index;
+ final_dets[i].rect = scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ std::vector<rectangle> object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ double adjust_threshold
+ )
+ {
+ std::vector<rect_detection> dets;
+ (*this)(img,dets,adjust_threshold);
+
+ std::vector<rectangle> final_dets(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ final_dets[i] = dets[i].rect;
+
+ return final_dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ void object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ std::vector<std::pair<double, rectangle> >& final_dets,
+ double adjust_threshold
+ )
+ {
+ std::vector<rect_detection> dets;
+ (*this)(img,dets,adjust_threshold);
+
+ final_dets.resize(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ void object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ std::vector<std::pair<double, full_object_detection> >& final_dets,
+ double adjust_threshold
+ )
+ {
+ std::vector<rect_detection> dets;
+ (*this)(img,dets,adjust_threshold);
+
+ final_dets.clear();
+ final_dets.reserve(dets.size());
+
+ // convert all the rectangle detections into full_object_detections.
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ final_dets.push_back(std::make_pair(dets[i].detection_confidence,
+ scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ template <
+ typename image_type
+ >
+ void object_detector<image_scanner_type>::
+ operator() (
+ const image_type& img,
+ std::vector<full_object_detection>& final_dets,
+ double adjust_threshold
+ )
+ {
+ std::vector<rect_detection> dets;
+ (*this)(img,dets,adjust_threshold);
+
+ final_dets.clear();
+ final_dets.reserve(dets.size());
+
+ // convert all the rectangle detections into full_object_detections.
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ const test_box_overlap& object_detector<image_scanner_type>::
+ get_overlap_tester (
+ ) const
+ {
+ return boxes_overlap;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ const image_scanner_type& object_detector<image_scanner_type>::
+ get_scanner (
+ ) const
+ {
+ return scanner;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OBJECT_DeTECTOR_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/object_detector_abstract.h b/ml/dlib/dlib/image_processing/object_detector_abstract.h
new file mode 100644
index 000000000..9578d8b03
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/object_detector_abstract.h
@@ -0,0 +1,404 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_
+#ifdef DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_
+
+#include "../geometry.h"
+#include <vector>
+#include "box_overlap_testing_abstract.h"
+#include "full_object_detection_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct rect_detection
+ {
+ double detection_confidence;
+ unsigned long weight_index;
+ rectangle rect;
+ };
+
+ struct full_detection
+ {
+ double detection_confidence;
+ unsigned long weight_index;
+ full_object_detection rect;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type_
+ >
+ class object_detector
+ {
+ /*!
+ REQUIREMENTS ON image_scanner_type_
+ image_scanner_type_ must be an implementation of
+ dlib/image_processing/scan_image_pyramid_abstract.h or
+ dlib/image_processing/scan_fhog_pyramid.h or
+ dlib/image_processing/scan_image_custom.h or
+ dlib/image_processing/scan_image_boxes_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for detecting the positions of objects in an image.
+ In particular, it is a simple container to aggregate an instance of an image
+ scanner (i.e. scan_image_pyramid, scan_fhog_pyramid, scan_image_custom, or
+ scan_image_boxes), the weight vector needed by one of these image scanners,
+ and finally an instance of test_box_overlap. The test_box_overlap object
+ is used to perform non-max suppression on the output of the image scanner
+ object.
+
+ Note further that this object can contain multiple weight vectors. In this
+ case, it will run the image scanner multiple times, once with each of the
+ weight vectors. Then it will aggregate the results from all runs, perform
+ non-max suppression and then return the results. Therefore, the object_detector
+ can also be used as a container for a set of object detectors that all use
+ the same image scanner but different weight vectors. This is useful since
+ the object detection procedure has two parts. A loading step where the
+ image is loaded into the scanner, then a detect step which uses the weight
+ vector to locate objects in the image. Since the loading step is independent
+ of the weight vector it is most efficient to run multiple detectors by
+ performing one load into a scanner followed by multiple detect steps. This
+ avoids unnecessarily loading the same image into the scanner multiple times.
+ !*/
+ public:
+ typedef image_scanner_type_ image_scanner_type;
+ typedef typename image_scanner_type::feature_vector_type feature_vector_type;
+
+ object_detector (
+ );
+ /*!
+ ensures
+ - This detector won't generate any detections when
+ presented with an image.
+ - #num_detectors() == 0
+ !*/
+
+ object_detector (
+ const object_detector& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - #get_scanner() == item.get_scanner()
+ (note that only the "configuration" of item.get_scanner() is copied.
+ I.e. the copy is done using copy_configuration())
+ !*/
+
+ object_detector (
+ const image_scanner_type& scanner,
+ const test_box_overlap& overlap_tester,
+ const feature_vector_type& w
+ );
+ /*!
+ requires
+ - w.size() == scanner.get_num_dimensions() + 1
+ - scanner.get_num_detection_templates() > 0
+ ensures
+ - When the operator() member function is called it will
+ invoke scanner.detect(w,dets,w(w.size()-1)), suppress
+ overlapping detections, and then report the results.
+ - when #*this is used to detect objects, the set of
+ output detections will never contain any overlaps
+ with respect to overlap_tester. That is, for all
+ pairs of returned detections A and B, we will always
+ have: overlap_tester(A,B) == false
+ - #get_w() == w
+ - #get_overlap_tester() == overlap_tester
+ - #get_scanner() == scanner
+ (note that only the "configuration" of scanner is copied.
+ I.e. the copy is done using copy_configuration())
+ - #num_detectors() == 1
+ !*/
+
+ object_detector (
+ const image_scanner_type& scanner,
+ const test_box_overlap& overlap_tester,
+ const std::vector<feature_vector_type>& w
+ );
+ /*!
+ requires
+ - for all valid i:
+ - w[i].size() == scanner.get_num_dimensions() + 1
+ - scanner.get_num_detection_templates() > 0
+ - w.size() > 0
+ ensures
+ - When the operator() member function is called it will invoke
+ get_scanner().detect(w[i],dets,w[i](w[i].size()-1)) for all valid i. Then it
+ will take all the detections output by the calls to detect() and suppress
+ overlapping detections, and finally report the results.
+ - when #*this is used to detect objects, the set of output detections will
+ never contain any overlaps with respect to overlap_tester. That is, for
+ all pairs of returned detections A and B, we will always have:
+ overlap_tester(A,B) == false
+ - for all valid i:
+ - #get_w(i) == w[i]
+ - #num_detectors() == w.size()
+ - #get_overlap_tester() == overlap_tester
+ - #get_scanner() == scanner
+ (note that only the "configuration" of scanner is copied.
+ I.e. the copy is done using copy_configuration())
+ !*/
+
+ explicit object_detector (
+ const std::vector<object_detector>& detectors
+ );
+ /*!
+ requires
+ - detectors.size() != 0
+ - All the detectors must use compatibly configured scanners. That is, it
+ must make sense for the weight vector from one detector to be used with
+ the scanner from any other.
+ - for all valid i:
+ - detectors[i].get_scanner().get_num_dimensions() == detectors[0].get_scanner().get_num_dimensions()
+ (i.e. all the detectors use scanners that use the same kind of feature vectors.)
+ ensures
+ - Very much like the above constructor, this constructor takes all the
+ given detectors and packs them into #*this. That is, invoking operator()
+ on #*this will run all the detectors, perform non-max suppression, and
+ then report the results.
+ - When #*this is used to detect objects, the set of output detections will
+ never contain any overlaps with respect to overlap_tester. That is, for
+ all pairs of returned detections A and B, we will always have:
+ overlap_tester(A,B) == false
+ - #num_detectors() == The sum of detectors[i].num_detectors() for all valid i.
+ - #get_overlap_tester() == detectors[0].get_overlap_tester()
+ - #get_scanner() == detectors[0].get_scanner()
+ (note that only the "configuration" of scanner is copied. I.e. the copy
+ is done using copy_configuration())
+ !*/
+
+ unsigned long num_detectors (
+ ) const;
+ /*!
+ ensures
+ - returns the number of weight vectors in this object. Since each weight
+ vector logically represents an object detector, this returns the number
+ of object detectors contained in this object.
+ !*/
+
+ const feature_vector_type& get_w (
+ unsigned long idx = 0
+ ) const;
+ /*!
+ requires
+ - idx < num_detectors()
+ ensures
+ - returns the idx-th weight vector loaded into this object. All the weight vectors
+ have the same dimension and logically each represents a different detector.
+ !*/
+
+ const test_box_overlap& get_overlap_tester (
+ ) const;
+ /*!
+ ensures
+ - returns the overlap tester used by this object
+ !*/
+
+ const image_scanner_type& get_scanner (
+ ) const;
+ /*!
+ ensures
+ - returns the image scanner used by this object.
+ !*/
+
+ object_detector& operator= (
+ const object_detector& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - #get_scanner() == item.get_scanner()
+ (note that only the "configuration" of item.get_scanner() is
+ copied. I.e. the copy is done using copy_configuration())
+ - returns #*this
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<rect_detection>& dets,
+ double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - Performs object detection on the given image and stores the detected
+ objects into #dets. In particular, we will have that:
+ - #dets is sorted such that the highest confidence detections come
+ first. E.g. element 0 is the best detection, element 1 the next
+ best, and so on.
+ - #dets.size() == the number of detected objects.
+ - #dets[i].detection_confidence == The strength of the i-th detection.
+ Larger values indicate that the detector is more confident that
+ #dets[i] is a correct detection rather than being a false alarm.
+ Moreover, the detection_confidence is equal to the detection value
+ output by the scanner minus the threshold value stored at the end of
+ the weight vector in get_w(#dets[i].weight_index).
+ - #dets[i].weight_index == the index for the weight vector that
+ generated this detection.
+ - #dets[i].rect == the bounding box for the i-th detection.
+ - #get_scanner() will have been loaded with img. Therefore, you can call
+ #get_scanner().get_feature_vector() to obtain the feature vectors or
+ #get_scanner().get_full_object_detection() to get the
+ full_object_detections for the resulting object detection boxes.
+ - The detection threshold is adjusted by having adjust_threshold added to
+ it. Therefore, an adjust_threshold value > 0 makes detecting objects
+ harder while a negative value makes it easier. Moreover, the following
+ will be true for all valid i:
+ - #dets[i].detection_confidence >= adjust_threshold
+ This means that, for example, you can obtain the maximum possible number
+ of detections by setting adjust_threshold equal to negative infinity.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<full_detection>& dets,
+ double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - This function is identical to the above operator() routine, except that
+ it outputs full_object_detections instead of rectangles. This means that
+ the output includes part locations. In particular, calling this function
+ is the same as calling the above operator() routine and then using
+ get_scanner().get_full_object_detection() to resolve all the rectangles
+ into full_object_detections. Therefore, this version of operator() is
+ simply a convenience function for performing this set of operations.
+ !*/
+
+ template <
+ typename image_type
+ >
+ std::vector<rectangle> operator() (
+ const image_type& img,
+ const adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - This function is identical to the above operator() routine, except that
+ it returns a std::vector<rectangle> which contains just the bounding
+ boxes of all the detections.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<std::pair<double, rectangle> >& dets,
+ double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - performs object detection on the given image and stores the
+ detected objects into #dets. In particular, we will have that:
+ - #dets is sorted such that the highest confidence detections
+ come first. E.g. element 0 is the best detection, element 1
+ the next best, and so on.
+ - #dets.size() == the number of detected objects.
+ - #dets[i].first gives the "detection confidence", of the i-th
+ detection. This is the detection value output by the scanner minus
+ the threshold value stored at the end of the weight vector in get_w().
+ - #dets[i].second == the bounding box for the i-th detection.
+ - #get_scanner() will have been loaded with img. Therefore, you can call
+ #get_scanner().get_feature_vector() to obtain the feature vectors or
+ #get_scanner().get_full_object_detection() to get the
+ full_object_detections for the resulting object detection boxes.
+ - The detection threshold is adjusted by having adjust_threshold added to
+ it. Therefore, an adjust_threshold value > 0 makes detecting objects
+ harder while a negative value makes it easier. Moreover, the following
+ will be true for all valid i:
+ - #dets[i].first >= adjust_threshold
+ This means that, for example, you can obtain the maximum possible number
+ of detections by setting adjust_threshold equal to negative infinity.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<std::pair<double, full_object_detection> >& dets,
+ double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - This function is identical to the above operator() routine, except that
+ it outputs full_object_detections instead of rectangles. This means that
+ the output includes part locations. In particular, calling this function
+ is the same as calling the above operator() routine and then using
+ get_scanner().get_full_object_detection() to resolve all the rectangles
+ into full_object_detections. Therefore, this version of operator() is
+ simply a convenience function for performing this set of operations.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ const image_type& img,
+ std::vector<full_object_detection>& dets,
+ double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - img == an object which can be accepted by image_scanner_type::load()
+ ensures
+ - This function is identical to the above operator() routine, except that
+ it doesn't include a double valued score. That is, it just outputs the
+ full_object_detections.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const object_detector<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support. Note that this function only saves the
+ configuration part of item.get_scanner(). That is, we use the scanner's
+ copy_configuration() function to get a copy of the scanner that doesn't contain any
+ loaded image data and we then save just the configuration part of the scanner.
+ This means that any serialized object_detectors won't remember any images they have
+ processed but will otherwise contain all their state and be able to detect objects
+ in new images.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void deserialize (
+ object_detector<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h
new file mode 100644
index 000000000..95ab4f353
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h
@@ -0,0 +1,317 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_
+#define DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_
+
+#include "remove_unobtainable_rectangles_abstract.h"
+#include "scan_image_pyramid.h"
+#include "scan_image_boxes.h"
+#include "scan_image_custom.h"
+#include "scan_fhog_pyramid.h"
+#include "../svm/structural_object_detection_trainer.h"
+#include "../geometry.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool matches_rect (
+ const std::vector<rectangle>& rects,
+ const rectangle& rect,
+ const double eps
+ )
+ {
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ const double score = (rect.intersect(rects[i])).area()/(double)(rect+rects[i]).area();
+ if (score > eps)
+ return true;
+ }
+
+ return false;
+ }
+
+ inline rectangle get_best_matching_rect (
+ const std::vector<rectangle>& rects,
+ const rectangle& rect
+ )
+ {
+ double best_score = -1;
+ rectangle best_rect;
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ const double score = (rect.intersect(rects[i])).area()/(double)(rect+rects[i]).area();
+ if (score > best_score)
+ {
+ best_score = score;
+ best_rect = rects[i];
+ }
+ }
+ return best_rect;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename image_scanner_type
+ >
+ std::vector<std::vector<rectangle> > pyramid_remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<image_scanner_type>& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ using namespace dlib::impl;
+ // make sure requires clause is not broken
+ DLIB_ASSERT(images.size() == object_locations.size(),
+ "\t std::vector<std::vector<rectangle>> remove_unobtainable_rectangles()"
+ << "\n\t Invalid inputs were given to this function."
+ );
+
+
+ std::vector<std::vector<rectangle> > rejects(images.size());
+
+ // If the trainer is setup to automatically fit the overlap tester to the data then
+ // we should use the loosest possible overlap tester here. Otherwise we should use
+ // the tester the trainer will use.
+ test_box_overlap boxes_overlap(0.9999999,1);
+ if (!trainer.auto_set_overlap_tester())
+ boxes_overlap = trainer.get_overlap_tester();
+
+ for (unsigned long k = 0; k < images.size(); ++k)
+ {
+ std::vector<rectangle> objs = object_locations[k];
+
+ // First remove things that don't have any matches with the candidate object
+ // locations.
+ std::vector<rectangle> good_rects;
+ for (unsigned long j = 0; j < objs.size(); ++j)
+ {
+ const rectangle rect = trainer.get_scanner().get_best_matching_rect(objs[j]);
+ const double score = (objs[j].intersect(rect)).area()/(double)(objs[j] + rect).area();
+ if (score > trainer.get_match_eps())
+ good_rects.push_back(objs[j]);
+ else
+ rejects[k].push_back(objs[j]);
+ }
+ object_locations[k] = good_rects;
+
+
+ // Remap these rectangles to the ones that can come out of the scanner. That
+ // way when we compare them to each other in the following loop we will know if
+ // any distinct truth rectangles get mapped to overlapping boxes.
+ objs.resize(good_rects.size());
+ for (unsigned long i = 0; i < good_rects.size(); ++i)
+ objs[i] = trainer.get_scanner().get_best_matching_rect(good_rects[i]);
+
+ good_rects.clear();
+ // now check for truth rects that are too close together.
+ for (unsigned long i = 0; i < objs.size(); ++i)
+ {
+ // check if objs[i] hits another box
+ bool hit_box = false;
+ for (unsigned long j = i+1; j < objs.size(); ++j)
+ {
+ if (boxes_overlap(objs[i], objs[j]))
+ {
+ hit_box = true;
+ break;
+ }
+ }
+ if (hit_box)
+ rejects[k].push_back(object_locations[k][i]);
+ else
+ good_rects.push_back(object_locations[k][i]);
+ }
+ object_locations[k] = good_rects;
+ }
+
+ return rejects;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<scan_image_pyramid<Pyramid_type, Feature_extractor_type> >& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ return impl::pyramid_remove_unobtainable_rectangles(trainer, images, object_locations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<scan_fhog_pyramid<Pyramid_type,Feature_extractor_type> >& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ return impl::pyramid_remove_unobtainable_rectangles(trainer, images, object_locations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename image_array_type,
+ typename scanner_type,
+ typename get_boxes_functor
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ get_boxes_functor& bg,
+ const structural_object_detection_trainer<scanner_type>& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ using namespace dlib::impl;
+ // make sure requires clause is not broken
+ DLIB_ASSERT(images.size() == object_locations.size(),
+ "\t std::vector<std::vector<rectangle>> remove_unobtainable_rectangles()"
+ << "\n\t Invalid inputs were given to this function."
+ );
+
+ std::vector<rectangle> rects;
+
+ std::vector<std::vector<rectangle> > rejects(images.size());
+
+ // If the trainer is setup to automatically fit the overlap tester to the data then
+ // we should use the loosest possible overlap tester here. Otherwise we should use
+ // the tester the trainer will use.
+ test_box_overlap boxes_overlap(0.9999999,1);
+ if (!trainer.auto_set_overlap_tester())
+ boxes_overlap = trainer.get_overlap_tester();
+
+ for (unsigned long k = 0; k < images.size(); ++k)
+ {
+ std::vector<rectangle> objs = object_locations[k];
+ // Don't even bother computing the candidate rectangles if there aren't any
+ // object locations for this image since there isn't anything to do anyway.
+ if (objs.size() == 0)
+ continue;
+
+ bg(images[k], rects);
+
+
+ // First remove things that don't have any matches with the candidate object
+ // locations.
+ std::vector<rectangle> good_rects;
+ for (unsigned long j = 0; j < objs.size(); ++j)
+ {
+ if (matches_rect(rects, objs[j], trainer.get_match_eps()))
+ good_rects.push_back(objs[j]);
+ else
+ rejects[k].push_back(objs[j]);
+ }
+ object_locations[k] = good_rects;
+
+
+ // Remap these rectangles to the ones that can come out of the scanner. That
+ // way when we compare them to each other in the following loop we will know if
+ // any distinct truth rectangles get mapped to overlapping boxes.
+ objs.resize(good_rects.size());
+ for (unsigned long i = 0; i < good_rects.size(); ++i)
+ objs[i] = get_best_matching_rect(rects, good_rects[i]);
+
+ good_rects.clear();
+ // now check for truth rects that are too close together.
+ for (unsigned long i = 0; i < objs.size(); ++i)
+ {
+ // check if objs[i] hits another box
+ bool hit_box = false;
+ for (unsigned long j = i+1; j < objs.size(); ++j)
+ {
+ if (boxes_overlap(objs[i], objs[j]))
+ {
+ hit_box = true;
+ break;
+ }
+ }
+ if (hit_box)
+ rejects[k].push_back(object_locations[k][i]);
+ else
+ good_rects.push_back(object_locations[k][i]);
+ }
+ object_locations[k] = good_rects;
+ }
+
+ return rejects;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct load_to_functor
+ {
+ load_to_functor(T& obj_) : obj(obj_) {}
+ T& obj;
+
+ template <typename U, typename V>
+ void operator()(const U& u, V& v)
+ {
+ obj.load(u,v);
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename feature_extractor,
+ typename box_generator
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<scan_image_boxes<feature_extractor, box_generator> >& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ box_generator bg = trainer.get_scanner().get_box_generator();
+ return impl::remove_unobtainable_rectangles(bg, trainer, images, object_locations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename feature_extractor
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<scan_image_custom<feature_extractor> >& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ feature_extractor fe;
+ fe.copy_configuration(trainer.get_scanner().get_feature_extractor());
+ impl::load_to_functor<feature_extractor> bg(fe);
+ return impl::remove_unobtainable_rectangles(bg, trainer, images, object_locations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h
new file mode 100644
index 000000000..328326f1c
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h
@@ -0,0 +1,56 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_
+#ifdef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_
+
+#include "scan_image_pyramid_abstract.h"
+#include "scan_image_boxes_abstract.h"
+#include "scan_image_custom_abstract.h"
+#include "scan_fhog_pyramid_abstract.h"
+#include "../svm/structural_object_detection_trainer_abstract.h"
+#include "../geometry.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type,
+ typename image_array_type
+ >
+ std::vector<std::vector<rectangle> > remove_unobtainable_rectangles (
+ const structural_object_detection_trainer<image_scanner_type>& trainer,
+ const image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ );
+ /*!
+ requires
+ - image_scanner_type must be either scan_image_boxes, scan_image_pyramid,
+ scan_image_custom, or scan_fhog_pyramid.
+ - images.size() == object_locations.size()
+ ensures
+ - Recall that the image scanner objects can't produce all possible rectangles
+ as object detections since they only consider a limited subset of all possible
+ object positions. Moreover, the structural_object_detection_trainer requires
+ its input training data to not contain any object positions which are unobtainable
+ by its scanner object. Therefore, remove_unobtainable_rectangles() is a tool
+ to filter out these unobtainable rectangles from the training data before giving
+ it to a structural_object_detection_trainer.
+ - This function interprets object_locations[i] as the set of object positions for
+ image[i], for all valid i.
+ - In particular, this function removes unobtainable rectangles from object_locations
+ and also returns a vector V such that:
+ - V.size() == object_locations.size()
+ - for all valid i:
+ - V[i] == the set of rectangles removed from object_locations[i]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/render_face_detections.h b/ml/dlib/dlib/image_processing/render_face_detections.h
new file mode 100644
index 000000000..96ff8971f
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/render_face_detections.h
@@ -0,0 +1,99 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RENDER_FACE_DeTECTIONS_H_
+#define DLIB_RENDER_FACE_DeTECTIONS_H_
+
+#include "full_object_detection.h"
+#include "../gui_widgets.h"
+#include "render_face_detections_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+ inline std::vector<image_window::overlay_line> render_face_detections (
+ const std::vector<full_object_detection>& dets,
+ const rgb_pixel color = rgb_pixel(0,255,0)
+ )
+ {
+ std::vector<image_window::overlay_line> lines;
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ DLIB_CASSERT(dets[i].num_parts() == 68 || dets[i].num_parts() == 5,
+ "\t std::vector<image_window::overlay_line> render_face_detections()"
+ << "\n\t You have to give either a 5 point or 68 point face landmarking output to this function. "
+ << "\n\t dets["<<i<<"].num_parts(): " << dets[i].num_parts()
+ );
+
+ const full_object_detection& d = dets[i];
+
+ if (d.num_parts() == 5)
+ {
+ lines.push_back(image_window::overlay_line(d.part(0), d.part(1), color));
+ lines.push_back(image_window::overlay_line(d.part(1), d.part(4), color));
+ lines.push_back(image_window::overlay_line(d.part(4), d.part(3), color));
+ lines.push_back(image_window::overlay_line(d.part(3), d.part(2), color));
+ }
+ else
+ {
+ // Around Chin. Ear to Ear
+ for (unsigned long i = 1; i <= 16; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+
+ // Line on top of nose
+ for (unsigned long i = 28; i <= 30; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+
+ // left eyebrow
+ for (unsigned long i = 18; i <= 21; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ // Right eyebrow
+ for (unsigned long i = 23; i <= 26; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ // Bottom part of the nose
+ for (unsigned long i = 31; i <= 35; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ // Line from the nose to the bottom part above
+ lines.push_back(image_window::overlay_line(d.part(30), d.part(35), color));
+
+ // Left eye
+ for (unsigned long i = 37; i <= 41; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ lines.push_back(image_window::overlay_line(d.part(36), d.part(41), color));
+
+ // Right eye
+ for (unsigned long i = 43; i <= 47; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ lines.push_back(image_window::overlay_line(d.part(42), d.part(47), color));
+
+ // Lips outer part
+ for (unsigned long i = 49; i <= 59; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ lines.push_back(image_window::overlay_line(d.part(48), d.part(59), color));
+
+ // Lips inside part
+ for (unsigned long i = 61; i <= 67; ++i)
+ lines.push_back(image_window::overlay_line(d.part(i), d.part(i-1), color));
+ lines.push_back(image_window::overlay_line(d.part(60), d.part(67), color));
+ }
+ }
+ return lines;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<image_window::overlay_line> render_face_detections (
+ const full_object_detection& det,
+ const rgb_pixel color = rgb_pixel(0,255,0)
+ )
+ {
+ std::vector<full_object_detection> dets;
+ dets.push_back(det);
+ return render_face_detections(dets, color);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RENDER_FACE_DeTECTIONS_H_
+
diff --git a/ml/dlib/dlib/image_processing/render_face_detections_abstract.h b/ml/dlib/dlib/image_processing/render_face_detections_abstract.h
new file mode 100644
index 000000000..f609c8e8c
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/render_face_detections_abstract.h
@@ -0,0 +1,59 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_
+#ifdef DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_
+
+#include "full_object_detection_abstract.h"
+#include "../gui_widgets.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<image_window::overlay_line> render_face_detections (
+ const std::vector<full_object_detection>& dets,
+ const rgb_pixel color = rgb_pixel(0,255,0)
+ );
+ /*!
+ requires
+ - for all valid i:
+ - dets[i].num_parts() == 68 || dets[i].num_parts() == 5
+ ensures
+ - Interprets the given objects as face detections with parts annotated using
+ either the iBUG face landmark scheme or a 5 point face annotation. We then
+ return a set of overlay lines that will draw the objects onto the screen in a
+ way that properly draws the outline of the face features defined by the part
+ locations.
+ - returns a vector with dets.size() elements, each containing the lines
+ necessary to render a face detection from dets.
+ - The 5 point face annotation scheme is assumed to be:
+ - det part 0 == left eye corner, outside part of eye.
+ - det part 1 == left eye corner, inside part of eye.
+ - det part 2 == right eye corner, outside part of eye.
+ - det part 3 == right eye corner, inside part of eye.
+ - det part 4 == immediately under the nose, right at the top of the philtrum.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<image_window::overlay_line> render_face_detections (
+ const full_object_detection& det,
+ const rgb_pixel color = rgb_pixel(0,255,0)
+ );
+ /*!
+ requires
+ - det.num_parts() == 68 || det.num_parts() == 5
+ ensures
+ - This function is identical to the above render_face_detections() routine
+ except that it takes just a single full_object_detection instead of a
+ std::vector of them.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h b/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h
new file mode 100644
index 000000000..5ae0310af
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h
@@ -0,0 +1,1348 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_fHOG_PYRAMID_Hh_
+#define DLIB_SCAN_fHOG_PYRAMID_Hh_
+
+#include "scan_fhog_pyramid_abstract.h"
+#include "../matrix.h"
+#include "../image_transforms.h"
+#include "../array.h"
+#include "../array2d.h"
+#include "object_detector.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class default_fhog_feature_extractor
+ {
+ public:
+ inline rectangle image_to_feats (
+ const rectangle& rect,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const
+ {
+ return image_to_fhog(rect, cell_size, filter_rows_padding, filter_cols_padding);
+ }
+
+ inline rectangle feats_to_image (
+ const rectangle& rect,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const
+ {
+ return fhog_to_image(rect, cell_size, filter_rows_padding, filter_cols_padding);
+ }
+
+ template <
+ typename image_type
+ >
+ void operator()(
+ const image_type& img,
+ dlib::array<array2d<float> >& hog,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const
+ {
+ extract_fhog_features(img,hog,cell_size,filter_rows_padding,filter_cols_padding);
+ }
+
+ inline unsigned long get_num_planes (
+ ) const
+ {
+ return 31;
+ }
+ };
+
+ inline void serialize (const default_fhog_feature_extractor&, std::ostream&) {}
+ inline void deserialize (default_fhog_feature_extractor&, std::istream&) {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type = default_fhog_feature_extractor
+ >
+ class scan_fhog_pyramid : noncopyable
+ {
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+
+ typedef Pyramid_type pyramid_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_fhog_pyramid (
+ );
+
+ explicit scan_fhog_pyramid (
+ const feature_extractor_type& fe_
+ );
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+
+ inline bool is_loaded_with_image (
+ ) const;
+
+ inline void copy_configuration (
+ const scan_fhog_pyramid& item
+ );
+
+ void set_detection_window_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(width > 0 && height > 0,
+ "\t void scan_fhog_pyramid::set_detection_window_size()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t width: " << width
+ << "\n\t height: " << height
+ << "\n\t this: " << this
+ );
+
+ window_width = width;
+ window_height = height;
+ feats.clear();
+ }
+
+ inline unsigned long get_detection_window_width (
+ ) const { return window_width; }
+ inline unsigned long get_detection_window_height (
+ ) const { return window_height; }
+
+ inline unsigned long get_num_detection_templates (
+ ) const;
+
+ inline unsigned long get_num_movable_components_per_detection_template (
+ ) const;
+
+ void set_padding (
+ unsigned long new_padding
+ )
+ {
+ padding = new_padding;
+ feats.clear();
+ }
+
+ unsigned long get_padding (
+ ) const { return padding; }
+
+ void set_cell_size (
+ unsigned long new_cell_size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(new_cell_size > 0 ,
+ "\t void scan_fhog_pyramid::set_cell_size()"
+ << "\n\t You can't have zero sized fHOG cells. "
+ << "\n\t this: " << this
+ );
+
+ cell_size = new_cell_size;
+ feats.clear();
+ }
+
+ unsigned long get_cell_size (
+ ) const { return cell_size; }
+
+ inline long get_num_dimensions (
+ ) const;
+
+ unsigned long get_max_pyramid_levels (
+ ) const;
+
+ const feature_extractor_type& get_feature_extractor(
+ ) const { return fe; }
+
+ void set_max_pyramid_levels (
+ unsigned long max_levels
+ );
+
+ void set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ inline unsigned long get_min_pyramid_layer_width (
+ ) const;
+
+ inline unsigned long get_min_pyramid_layer_height (
+ ) const;
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ w.size() >= get_num_dimensions(),
+ "\t void scan_fhog_pyramid::detect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t w.size(): " << w.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ fhog_filterbank temp = build_fhog_filterbank(w);
+ detect(temp, dets, thresh);
+ }
+
+ class fhog_filterbank
+ {
+ friend class scan_fhog_pyramid;
+ public:
+ inline long get_num_dimensions() const
+ {
+ unsigned long dims = 0;
+ for (unsigned long i = 0; i < filters.size(); ++i)
+ {
+ dims += filters[i].size();
+ }
+ return dims;
+ }
+
+ const std::vector<matrix<float> >& get_filters() const { return filters;}
+
+ unsigned long num_separable_filters() const
+ {
+ unsigned long num = 0;
+ for (unsigned long i = 0; i < row_filters.size(); ++i)
+ {
+ num += row_filters[i].size();
+ }
+ return num;
+ }
+
+ std::vector<matrix<float> > filters;
+ std::vector<std::vector<matrix<float,0,1> > > row_filters, col_filters;
+ };
+
+ fhog_filterbank build_fhog_filterbank (
+ const feature_vector_type& weights
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weights.size() >= get_num_dimensions(),
+ "\t fhog_filterbank scan_fhog_pyramid::build_fhog_filterbank()"
+ << "\n\t The number of weights isn't enough to fill out the filterbank. "
+ << "\n\t weights.size(): " << weights.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ fhog_filterbank temp;
+ temp.filters.resize(fe.get_num_planes());
+ temp.row_filters.resize(fe.get_num_planes());
+ temp.col_filters.resize(fe.get_num_planes());
+
+ // load filters from w
+ unsigned long width, height;
+ compute_fhog_window_size(width, height);
+ const long size = width*height;
+ for (unsigned long i = 0; i < temp.filters.size(); ++i)
+ {
+ matrix<double> u,v,w,f;
+ f = reshape(rowm(weights, range(i*size, (i+1)*size-1)), height, width);
+ temp.filters[i] = matrix_cast<float>(f);
+
+ svd3(f, u,w,v);
+
+ matrix<double> w2 = w;
+ rsort_columns(u,w);
+ rsort_columns(v,w2);
+
+ double thresh = std::max(1e-4, max(w)*0.001);
+ w = round_zeros(w, thresh);
+
+
+ for (long j = 0; j < w.size(); ++j)
+ {
+ if (w(j) != 0)
+ {
+ temp.col_filters[i].push_back(matrix_cast<float>(colm(u,j)*std::sqrt(w(j))));
+ temp.row_filters[i].push_back(matrix_cast<float>(colm(v,j)*std::sqrt(w(j))));
+ }
+ }
+ }
+
+ return temp;
+ }
+
+ void detect (
+ const fhog_filterbank& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+
+ double get_nuclear_norm_regularization_strength (
+ ) const { return nuclear_norm_regularization_strength; }
+
+ void set_nuclear_norm_regularization_strength (
+ double strength
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(strength >= 0 ,
+ "\t void scan_fhog_pyramid::set_nuclear_norm_regularization_strength()"
+ << "\n\t You can't have a negative regularization strength."
+ << "\n\t strength: " << strength
+ << "\n\t this: " << this
+ );
+
+ nuclear_norm_regularization_strength = strength;
+ }
+
+ unsigned long get_fhog_window_width (
+ ) const
+ {
+ unsigned long width, height;
+ compute_fhog_window_size(width, height);
+ return width;
+ }
+
+ unsigned long get_fhog_window_height (
+ ) const
+ {
+ unsigned long width, height;
+ compute_fhog_window_size(width, height);
+ return height;
+ }
+
+ template <typename T, typename U>
+ friend void serialize (
+ const scan_fhog_pyramid<T,U>& item,
+ std::ostream& out
+ );
+
+ template <typename T, typename U>
+ friend void deserialize (
+ scan_fhog_pyramid<T,U>& item,
+ std::istream& in
+ );
+
+ private:
+ inline void compute_fhog_window_size(
+ unsigned long& width,
+ unsigned long& height
+ ) const
+ {
+ const rectangle rect = centered_rect(point(0,0),window_width,window_height);
+ const rectangle temp = grow_rect(fe.image_to_feats(rect, cell_size, 1, 1), padding);
+ width = temp.width();
+ height = temp.height();
+ }
+
+ void get_mapped_rect_and_metadata (
+ const unsigned long number_pyramid_levels,
+ const rectangle& rect,
+ rectangle& mapped_rect,
+ rectangle& fhog_rect,
+ unsigned long& best_level
+ ) const;
+
+ double get_match_score (
+ rectangle r1,
+ rectangle r2
+ ) const
+ {
+ // make the rectangles overlap as much as possible before computing the match score.
+ r1 = move_rect(r1, r2.tl_corner());
+ return (r1.intersect(r2).area())/(double)(r1 + r2).area();
+ }
+
+ typedef array<array2d<float> > fhog_image;
+
+ feature_extractor_type fe;
+ array<fhog_image> feats;
+ int cell_size;
+ unsigned long padding;
+ unsigned long window_width;
+ unsigned long window_height;
+ unsigned long max_pyramid_levels;
+ unsigned long min_pyramid_layer_width;
+ unsigned long min_pyramid_layer_height;
+ double nuclear_norm_regularization_strength;
+
+ void init()
+ {
+ cell_size = 8;
+ padding = 1;
+ window_width = 64;
+ window_height = 64;
+ max_pyramid_levels = 1000;
+ min_pyramid_layer_width = 64;
+ min_pyramid_layer_height = 64;
+ nuclear_norm_regularization_strength = 0;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename fhog_filterbank>
+ rectangle apply_filters_to_fhog (
+ const fhog_filterbank& w,
+ const array<array2d<float> >& feats,
+ array2d<float>& saliency_image
+ )
+ {
+ const unsigned long num_separable_filters = w.num_separable_filters();
+ rectangle area;
+ // use the separable filters if they would be faster than running the regular filters.
+ if (num_separable_filters > w.filters.size()*std::min(w.filters[0].nr(),w.filters[0].nc())/3.0)
+ {
+ area = spatially_filter_image(feats[0], saliency_image, w.filters[0]);
+ for (unsigned long i = 1; i < w.filters.size(); ++i)
+ {
+ // now we filter but the output adds to saliency_image rather than
+ // overwriting it.
+ spatially_filter_image(feats[i], saliency_image, w.filters[i], 1, false, true);
+ }
+ }
+ else
+ {
+ saliency_image.clear();
+ array2d<float> scratch;
+
+ // find the first filter to apply
+ unsigned long i = 0;
+ while (i < w.row_filters.size() && w.row_filters[i].size() == 0)
+ ++i;
+
+ for (; i < w.row_filters.size(); ++i)
+ {
+ for (unsigned long j = 0; j < w.row_filters[i].size(); ++j)
+ {
+ if (saliency_image.size() == 0)
+ area = float_spatially_filter_image_separable(feats[i], saliency_image, w.row_filters[i][j], w.col_filters[i][j],scratch,false);
+ else
+ area = float_spatially_filter_image_separable(feats[i], saliency_image, w.row_filters[i][j], w.col_filters[i][j],scratch,true);
+ }
+ }
+ if (saliency_image.size() == 0)
+ {
+ saliency_image.set_size(feats[0].nr(), feats[0].nc());
+ assign_all_pixels(saliency_image, 0);
+ }
+ }
+ return area;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void serialize (
+ const scan_fhog_pyramid<T,U>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.fe, out);
+ serialize(item.feats, out);
+ serialize(item.cell_size, out);
+ serialize(item.padding, out);
+ serialize(item.window_width, out);
+ serialize(item.window_height, out);
+ serialize(item.max_pyramid_levels, out);
+ serialize(item.min_pyramid_layer_width, out);
+ serialize(item.min_pyramid_layer_height, out);
+ serialize(item.nuclear_norm_regularization_strength, out);
+ serialize(item.get_num_dimensions(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void deserialize (
+ scan_fhog_pyramid<T,U>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unsupported version found when deserializing a scan_fhog_pyramid object.");
+
+ deserialize(item.fe, in);
+ deserialize(item.feats, in);
+ deserialize(item.cell_size, in);
+ deserialize(item.padding, in);
+ deserialize(item.window_width, in);
+ deserialize(item.window_height, in);
+ deserialize(item.max_pyramid_levels, in);
+ deserialize(item.min_pyramid_layer_width, in);
+ deserialize(item.min_pyramid_layer_height, in);
+ deserialize(item.nuclear_norm_regularization_strength, in);
+
+ // When developing some feature extractor, it's easy to accidentally change its
+ // number of dimensions and then try to deserialize data from an older version of
+ // your extractor into the current code. This check is here to catch that kind of
+ // user error.
+ long dims;
+ deserialize(dims, in);
+ if (item.get_num_dimensions() != dims)
+ throw serialization_error("Number of dimensions in serialized scan_fhog_pyramid doesn't match the expected number.");
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// scan_fhog_pyramid member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ scan_fhog_pyramid (
+ )
+ {
+ init();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ scan_fhog_pyramid (
+ const feature_extractor_type& fe_
+ )
+ {
+ init();
+ fe = fe_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename pyramid_type,
+ typename image_type,
+ typename feature_extractor_type
+ >
+ void create_fhog_pyramid (
+ const image_type& img,
+ const feature_extractor_type& fe,
+ array<array<array2d<float> > >& feats,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding,
+ unsigned long min_pyramid_layer_width,
+ unsigned long min_pyramid_layer_height,
+ unsigned long max_pyramid_levels
+ )
+ {
+ unsigned long levels = 0;
+ rectangle rect = get_rect(img);
+
+ // figure out how many pyramid levels we should be using based on the image size
+ pyramid_type pyr;
+ do
+ {
+ rect = pyr.rect_down(rect);
+ ++levels;
+ } while (rect.width() >= min_pyramid_layer_width && rect.height() >= min_pyramid_layer_height &&
+ levels < max_pyramid_levels);
+
+ if (feats.max_size() < levels)
+ feats.set_max_size(levels);
+ feats.set_size(levels);
+
+
+
+ // build our feature pyramid
+ fe(img, feats[0], cell_size,filter_rows_padding,filter_cols_padding);
+ DLIB_ASSERT(feats[0].size() == fe.get_num_planes(),
+ "Invalid feature extractor used with dlib::scan_fhog_pyramid. The output does not have the \n"
+ "indicated number of planes.");
+
+ if (feats.size() > 1)
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ array2d<pixel_type> temp1, temp2;
+ pyr(img, temp1);
+ fe(temp1, feats[1], cell_size,filter_rows_padding,filter_cols_padding);
+ swap(temp1,temp2);
+
+ for (unsigned long i = 2; i < feats.size(); ++i)
+ {
+ pyr(temp2, temp1);
+ fe(temp1, feats[i], cell_size,filter_rows_padding,filter_cols_padding);
+ swap(temp1,temp2);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ template <
+ typename image_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ load (
+ const image_type& img
+ )
+ {
+ unsigned long width, height;
+ compute_fhog_window_size(width,height);
+ impl::create_fhog_pyramid<Pyramid_type>(img, fe, feats, cell_size, height,
+ width, min_pyramid_layer_width, min_pyramid_layer_height,
+ max_pyramid_levels);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ bool scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ is_loaded_with_image (
+ ) const
+ {
+ return feats.size() != 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ copy_configuration (
+ const scan_fhog_pyramid& item
+ )
+ {
+ cell_size = item.cell_size;
+ padding = item.padding;
+ window_width = item.window_width;
+ window_height = item.window_height;
+ max_pyramid_levels = item.max_pyramid_levels;
+ min_pyramid_layer_width = item.min_pyramid_layer_width;
+ min_pyramid_layer_height = item.min_pyramid_layer_height;
+ nuclear_norm_regularization_strength = item.nuclear_norm_regularization_strength;
+ fe = item.fe;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_num_detection_templates (
+ ) const
+ {
+ return 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_num_movable_components_per_detection_template (
+ ) const
+ {
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_num_dimensions (
+ ) const
+ {
+ unsigned long width, height;
+ compute_fhog_window_size(width,height);
+ return width*height*fe.get_num_planes();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_max_pyramid_levels (
+ ) const
+ {
+ return max_pyramid_levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ set_max_pyramid_levels (
+ unsigned long max_levels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_levels > 0 ,
+ "\t void scan_fhog_pyramid::set_max_pyramid_levels()"
+ << "\n\t You can't have zero levels. "
+ << "\n\t max_levels: " << max_levels
+ << "\n\t this: " << this
+ );
+
+ max_pyramid_levels = max_levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool compare_pair_rect (
+ const std::pair<double, rectangle>& a,
+ const std::pair<double, rectangle>& b
+ )
+ {
+ return a.first < b.first;
+ }
+
+ template <
+ typename pyramid_type,
+ typename feature_extractor_type,
+ typename fhog_filterbank
+ >
+ void detect_from_fhog_pyramid (
+ const array<array<array2d<float> > >& feats,
+ const feature_extractor_type& fe,
+ const fhog_filterbank& w,
+ const double thresh,
+ const unsigned long det_box_height,
+ const unsigned long det_box_width,
+ const int cell_size,
+ const int filter_rows_padding,
+ const int filter_cols_padding,
+ std::vector<std::pair<double, rectangle> >& dets
+ )
+ {
+ dets.clear();
+
+ array2d<float> saliency_image;
+ pyramid_type pyr;
+
+ // for all pyramid levels
+ for (unsigned long l = 0; l < feats.size(); ++l)
+ {
+ const rectangle area = apply_filters_to_fhog(w, feats[l], saliency_image);
+
+ // now search the saliency image for any detections
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ // if we found a detection
+ if (saliency_image[r][c] >= thresh)
+ {
+ rectangle rect = fe.feats_to_image(centered_rect(point(c,r),det_box_width,det_box_height),
+ cell_size, filter_rows_padding, filter_cols_padding);
+ rect = pyr.rect_up(rect, l);
+ dets.push_back(std::make_pair(saliency_image[r][c], rect));
+ }
+ }
+ }
+ }
+
+ std::sort(dets.rbegin(), dets.rend(), compare_pair_rect);
+ }
+
+ inline bool overlaps_any_box (
+ const test_box_overlap& tester,
+ const std::vector<rect_detection>& rects,
+ const rect_detection& rect
+ )
+ {
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ // Only compare detections from the same detector. That is, we don't want
+ // the output of one detector to stop on the output of another detector.
+ if (rects[i].weight_index == rect.weight_index && tester(rects[i].rect, rect.rect))
+ return true;
+ }
+ return false;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ detect (
+ const fhog_filterbank& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ w.get_num_dimensions() == get_num_dimensions(),
+ "\t void scan_fhog_pyramid::detect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t w.get_num_dimensions(): " << w.get_num_dimensions()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ unsigned long width, height;
+ compute_fhog_window_size(width,height);
+
+ impl::detect_from_fhog_pyramid<pyramid_type>(feats, fe, w, thresh,
+ height-2*padding, width-2*padding, cell_size, height, width, dets);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ const rectangle scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_best_matching_rect (
+ const rectangle& rect
+ ) const
+ {
+ rectangle mapped_rect, fhog_rect;
+ unsigned long best_level;
+ get_mapped_rect_and_metadata(max_pyramid_levels, rect, mapped_rect, fhog_rect, best_level);
+ return mapped_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_mapped_rect_and_metadata (
+ const unsigned long number_pyramid_levels,
+ const rectangle& rect,
+ rectangle& mapped_rect,
+ rectangle& fhog_rect,
+ unsigned long& best_level
+ ) const
+ {
+ pyramid_type pyr;
+ best_level = 0;
+ double best_match_score = -1;
+
+
+ unsigned long width, height;
+ compute_fhog_window_size(width,height);
+
+ // Figure out the pyramid level which best matches rect against our detection
+ // window.
+ for (unsigned long l = 0; l < number_pyramid_levels; ++l)
+ {
+ const rectangle rect_fhog_space = fe.image_to_feats(pyr.rect_down(rect,l), cell_size, height,width);
+
+ const rectangle win_image_space = pyr.rect_up(fe.feats_to_image(centered_rect(center(rect_fhog_space),width-2*padding,height-2*padding), cell_size, height,width), l);
+
+ const double match_score = get_match_score(win_image_space, rect);
+ if (match_score > best_match_score)
+ {
+ best_match_score = match_score;
+ best_level = l;
+ fhog_rect = centered_rect(center(rect_fhog_space), width, height);
+ }
+
+ if (rect_fhog_space.area() <= 1)
+ break;
+ }
+ mapped_rect = pyr.rect_up(fe.feats_to_image(shrink_rect(fhog_rect,padding), cell_size,height,width),best_level);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ full_object_detection scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type&
+ ) const
+ {
+ return full_object_detection(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ psi.size() >= get_num_dimensions() &&
+ obj.num_parts() == 0,
+ "\t void scan_fhog_pyramid::get_feature_vector()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t psi.size(): " << psi.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t obj.num_parts(): " << obj.num_parts()
+ << "\n\t this: " << this
+ );
+
+
+
+ rectangle mapped_rect;
+ unsigned long best_level;
+ rectangle fhog_rect;
+ get_mapped_rect_and_metadata(feats.size(), obj.get_rect(), mapped_rect, fhog_rect, best_level);
+
+
+ long i = 0;
+ for (unsigned long ii = 0; ii < feats[best_level].size(); ++ii)
+ {
+ const rectangle rect = get_rect(feats[best_level][0]);
+ for (long r = fhog_rect.top(); r <= fhog_rect.bottom(); ++r)
+ {
+ for (long c = fhog_rect.left(); c <= fhog_rect.right(); ++c)
+ {
+ if (rect.contains(c,r))
+ psi(i) += feats[best_level][ii][r][c];
+ ++i;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ void scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(width > 0 && height > 0 ,
+ "\t void scan_fhog_pyramid::set_min_pyramid_layer_size()"
+ << "\n\t These sizes can't be zero. "
+ << "\n\t width: " << width
+ << "\n\t height: " << height
+ << "\n\t this: " << this
+ );
+
+ min_pyramid_layer_width = width;
+ min_pyramid_layer_height = height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_min_pyramid_layer_width (
+ ) const
+ {
+ return min_pyramid_layer_width;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::
+ get_min_pyramid_layer_height (
+ ) const
+ {
+ return min_pyramid_layer_height;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ matrix<unsigned char> draw_fhog (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ const unsigned long weight_index = 0,
+ const long cell_draw_size = 15
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weight_index < detector.num_detectors(),
+ "\t matrix draw_fhog()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t weight_index: " << weight_index
+ << "\n\t detector.num_detectors(): " << detector.num_detectors()
+ );
+ DLIB_ASSERT(cell_draw_size > 0 && detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions(),
+ "\t matrix draw_fhog()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t cell_draw_size: " << cell_draw_size
+ << "\n\t weight_index: " << weight_index
+ << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size()
+ << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions()
+ );
+
+ typename scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::fhog_filterbank fb = detector.get_scanner().build_fhog_filterbank(detector.get_w(weight_index));
+ return draw_fhog(fb.get_filters(),cell_draw_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long num_separable_filters (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ const unsigned long weight_index = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weight_index < detector.num_detectors(),
+ "\t unsigned long num_separable_filters()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t weight_index: " << weight_index
+ << "\n\t detector.num_detectors(): " << detector.num_detectors()
+ );
+ DLIB_ASSERT(detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() ,
+ "\t unsigned long num_separable_filters()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size()
+ << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions()
+ );
+
+ typename scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::fhog_filterbank fb = detector.get_scanner().build_fhog_filterbank(detector.get_w(weight_index));
+ return fb.num_separable_filters();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> > threshold_filter_singular_values (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ double thresh,
+ const unsigned long weight_index = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(thresh >= 0 ,
+ "\t object_detector threshold_filter_singular_values()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t thresh: " << thresh
+ );
+
+ DLIB_ASSERT(weight_index < detector.num_detectors(),
+ "\t object_detector threshold_filter_singular_values()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t weight_index: " << weight_index
+ << "\n\t detector.num_detectors(): " << detector.num_detectors()
+ );
+ DLIB_ASSERT(detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() ,
+ "\t object_detector threshold_filter_singular_values()"
+ << "\n\t Invalid arguments were given to this function. "
+ << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size()
+ << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions()
+ );
+
+
+ const unsigned long width = detector.get_scanner().get_fhog_window_width();
+ const unsigned long height = detector.get_scanner().get_fhog_window_height();
+ const long num_planes = detector.get_scanner().get_feature_extractor().get_num_planes();
+ const long size = width*height;
+
+ std::vector<matrix<double,0,1> > detector_weights;
+ for (unsigned long j = 0; j < detector.num_detectors(); ++j)
+ {
+ matrix<double,0,1> weights = detector.get_w(j);
+
+ if (j == weight_index)
+ {
+ matrix<double> u,v,w,f;
+ for (long i = 0; i < num_planes; ++i)
+ {
+ f = reshape(rowm(weights, range(i*size, (i+1)*size-1)), height, width);
+
+ svd3(f, u,w,v);
+ const double scaled_thresh = std::max(1e-3, max(w)*thresh);
+ w = round_zeros(w, scaled_thresh);
+ f = u*diagm(w)*trans(v);
+
+ set_rowm(weights,range(i*size, (i+1)*size-1)) = reshape_to_column_vector(f);
+ }
+ }
+ detector_weights.push_back(weights);
+ }
+
+ return object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >(detector.get_scanner(),
+ detector.get_overlap_tester(),
+ detector_weights);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type,
+ typename svm_struct_prob_type
+ >
+ void configure_nuclear_norm_regularizer (
+ const scan_fhog_pyramid<Pyramid_type,feature_extractor_type>& scanner,
+ svm_struct_prob_type& prob
+ )
+ {
+ const double strength = scanner.get_nuclear_norm_regularization_strength();
+ const long num_planes = scanner.get_feature_extractor().get_num_planes();
+ if (strength != 0)
+ {
+ const unsigned long width = scanner.get_fhog_window_width();
+ const unsigned long height = scanner.get_fhog_window_height();
+ for (long i = 0; i < num_planes; ++i)
+ {
+ prob.add_nuclear_norm_regularizer(i*width*height, height, width, strength);
+ }
+ prob.set_cache_based_epsilon(0.001);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ struct processed_weight_vector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >
+ {
+ processed_weight_vector(){}
+
+ typedef matrix<double,0,1> feature_vector_type;
+ typedef typename scan_fhog_pyramid<Pyramid_type,feature_extractor_type>::fhog_filterbank fhog_filterbank;
+
+ void init (
+ const scan_fhog_pyramid<Pyramid_type,feature_extractor_type>& scanner
+ )
+ {
+ fb = scanner.build_fhog_filterbank(w);
+ }
+
+ const fhog_filterbank& get_detect_argument() const { return fb; }
+
+ feature_vector_type w;
+ fhog_filterbank fb;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_type
+ >
+ void evaluate_detectors (
+ const std::vector<object_detector<scan_fhog_pyramid<pyramid_type> > >& detectors,
+ const image_type& img,
+ std::vector<rect_detection>& dets,
+ const double adjust_threshold = 0
+ )
+ {
+ typedef scan_fhog_pyramid<pyramid_type> scanner_type;
+
+ dets.clear();
+ if (detectors.size() == 0)
+ return;
+
+ const unsigned long cell_size = detectors[0].get_scanner().get_cell_size();
+
+ // Find the maximum sized filters and also most extreme pyramiding settings used.
+ unsigned long max_filter_width = 0;
+ unsigned long max_filter_height = 0;
+ unsigned long min_pyramid_layer_width = std::numeric_limits<unsigned long>::max();
+ unsigned long min_pyramid_layer_height = std::numeric_limits<unsigned long>::max();
+ unsigned long max_pyramid_levels = 0;
+ bool all_cell_sizes_the_same = true;
+ for (unsigned long i = 0; i < detectors.size(); ++i)
+ {
+ const scanner_type& scanner = detectors[i].get_scanner();
+ max_filter_width = std::max(max_filter_width, scanner.get_fhog_window_width());
+ max_filter_height = std::max(max_filter_height, scanner.get_fhog_window_height());
+ max_pyramid_levels = std::max(max_pyramid_levels, scanner.get_max_pyramid_levels());
+ min_pyramid_layer_width = std::min(min_pyramid_layer_width, scanner.get_min_pyramid_layer_width());
+ min_pyramid_layer_height = std::min(min_pyramid_layer_height, scanner.get_min_pyramid_layer_height());
+ if (cell_size != scanner.get_cell_size())
+ all_cell_sizes_the_same = false;
+ }
+
+ std::vector<rect_detection> dets_accum;
+ // Do to the HOG feature extraction to make the fhog pyramid. Again, note that we
+ // are making a pyramid that will work with any of the detectors. But only if all
+ // the cell sizes are the same. If they aren't then we have to calculate the
+ // pyramid for each detector individually.
+ array<array<array2d<float> > > feats;
+ if (all_cell_sizes_the_same)
+ {
+ impl::create_fhog_pyramid<pyramid_type>(img,
+ detectors[0].get_scanner().get_feature_extractor(), feats, cell_size,
+ max_filter_height, max_filter_width, min_pyramid_layer_width,
+ min_pyramid_layer_height, max_pyramid_levels);
+ }
+
+ std::vector<std::pair<double, rectangle> > temp_dets;
+ for (unsigned long i = 0; i < detectors.size(); ++i)
+ {
+ const scanner_type& scanner = detectors[i].get_scanner();
+ if (!all_cell_sizes_the_same)
+ {
+ impl::create_fhog_pyramid<pyramid_type>(img,
+ scanner.get_feature_extractor(), feats, scanner.get_cell_size(),
+ max_filter_height, max_filter_width, min_pyramid_layer_width,
+ min_pyramid_layer_height, max_pyramid_levels);
+ }
+
+ const unsigned long det_box_width = scanner.get_fhog_window_width() - 2*scanner.get_padding();
+ const unsigned long det_box_height = scanner.get_fhog_window_height() - 2*scanner.get_padding();
+ // A single detector object might itself have multiple weight vectors in it. So
+ // we need to evaluate all of them.
+ for (unsigned d = 0; d < detectors[i].num_detectors(); ++d)
+ {
+ const double thresh = detectors[i].get_processed_w(d).w(scanner.get_num_dimensions());
+
+ impl::detect_from_fhog_pyramid<pyramid_type>(feats, scanner.get_feature_extractor(),
+ detectors[i].get_processed_w(d).get_detect_argument(), thresh+adjust_threshold,
+ det_box_height, det_box_width, cell_size, max_filter_height,
+ max_filter_width, temp_dets);
+
+ for (unsigned long j = 0; j < temp_dets.size(); ++j)
+ {
+ rect_detection temp;
+ temp.detection_confidence = temp_dets[j].first-thresh;
+ temp.weight_index = i;
+ temp.rect = temp_dets[j].second;
+ dets_accum.push_back(temp);
+ }
+ }
+ }
+
+
+ // Do non-max suppression
+ if (detectors.size() > 1)
+ std::sort(dets_accum.rbegin(), dets_accum.rend());
+ for (unsigned long i = 0; i < dets_accum.size(); ++i)
+ {
+ const test_box_overlap tester = detectors[dets_accum[i].weight_index].get_overlap_tester();
+ if (impl::overlaps_any_box(tester, dets, dets_accum[i]))
+ continue;
+
+ dets.push_back(dets_accum[i]);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename image_type
+ >
+ std::vector<rectangle> evaluate_detectors (
+ const std::vector<object_detector<scan_fhog_pyramid<Pyramid_type> > >& detectors,
+ const image_type& img,
+ const double adjust_threshold = 0
+ )
+ {
+ std::vector<rectangle> out_dets;
+ std::vector<rect_detection> dets;
+ evaluate_detectors(detectors, img, dets, adjust_threshold);
+ out_dets.reserve(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ out_dets.push_back(dets[i].rect);
+ return out_dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_fHOG_PYRAMID_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h b/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h
new file mode 100644
index 000000000..d12a2b2b8
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h
@@ -0,0 +1,784 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_
+
+#include <vector>
+#include "../image_transforms/fhog_abstract.h"
+#include "object_detector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ matrix<unsigned char> draw_fhog (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ const unsigned long weight_index = 0,
+ const long cell_draw_size = 15
+ );
+ /*!
+ requires
+ - cell_draw_size > 0
+ - weight_index < detector.num_detectors()
+ - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions()
+ (i.e. the detector must have been populated with a HOG filter)
+ ensures
+ - Converts the HOG filters in the given detector (specifically, the filters in
+ detector.get_w(weight_index)) into an image suitable for display on the
+ screen. In particular, we draw all the HOG cells into a grayscale image in a
+ way that shows the magnitude and orientation of the gradient energy in each
+ cell. The resulting image is then returned.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ unsigned long num_separable_filters (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ const unsigned long weight_index = 0
+ );
+ /*!
+ requires
+ - weight_index < detector.num_detectors()
+ - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions()
+ (i.e. the detector must have been populated with a HOG filter)
+ ensures
+ - Returns the number of separable filters necessary to represent the HOG
+ filters in the given detector's weight_index'th filter. This is the filter
+ defined by detector.get_w(weight_index).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename feature_extractor_type
+ >
+ object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> > threshold_filter_singular_values (
+ const object_detector<scan_fhog_pyramid<Pyramid_type,feature_extractor_type> >& detector,
+ double thresh,
+ const unsigned long weight_index = 0
+ );
+ /*!
+ requires
+ - thresh >= 0
+ - weight_index < detector.num_detectors()
+ - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions()
+ (i.e. the detector must have been populated with a HOG filter)
+ ensures
+ - Removes all components of the filters in the given detector that have
+ singular values that are smaller than the given threshold. Therefore, this
+ function allows you to control how many separable filters are in a detector.
+ In particular, as thresh gets larger the quantity
+ num_separable_filters(threshold_filter_singular_values(detector,thresh,weight_index),weight_index)
+ will generally get smaller and therefore give a faster running detector.
+ However, note that at some point a large enough thresh will drop too much
+ information from the filters and their accuracy will suffer.
+ - returns the updated detector
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class default_fhog_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ The scan_fhog_pyramid object defined below is primarily meant to be used
+ with the feature extraction technique implemented by extract_fhog_features().
+ This technique can generally be understood as taking an input image and
+ outputting a multi-planed output image of floating point numbers that
+ somehow describe the image contents. Since there are many ways to define
+ how this feature mapping is performed, the scan_fhog_pyramid allows you to
+ replace the extract_fhog_features() method with a customized method of your
+ choosing. To do this you implement a class with the same interface as
+ default_fhog_feature_extractor.
+
+ Therefore, the point of default_fhog_feature_extractor is two fold. First,
+ it provides the default FHOG feature extraction method used by scan_fhog_pyramid.
+ Second, it serves to document the interface you need to implement to define
+ your own custom HOG style feature extraction.
+ !*/
+
+ public:
+
+ rectangle image_to_feats (
+ const rectangle& rect,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const { return image_to_fhog(rect, cell_size, filter_rows_padding, filter_cols_padding); }
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - Maps a rectangle from the coordinates in an input image to the corresponding
+ area in the output feature image.
+ !*/
+
+ rectangle feats_to_image (
+ const rectangle& rect,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const { return fhog_to_image(rect, cell_size, filter_rows_padding, filter_cols_padding); }
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - Maps a rectangle from the coordinates of the hog feature image back to
+ the input image.
+ - Mapping from feature space to image space is an invertible
+ transformation. That is, for any rectangle R we have:
+ R == image_to_feats(feats_to_image(R,cell_size,filter_rows_padding,filter_cols_padding),
+ cell_size,filter_rows_padding,filter_cols_padding).
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator()(
+ const image_type& img,
+ dlib::array<array2d<float> >& hog,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ ) const { extract_fhog_features(img,hog,cell_size,filter_rows_padding,filter_cols_padding); }
+ /*!
+ requires
+ - image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - img contains some kind of pixel type.
+ (i.e. pixel_traits<typename image_type::type> is defined)
+ ensures
+ - Extracts FHOG features by calling extract_fhog_features(). The results are
+ stored into #hog. Note that if you are implementing your own feature extractor you can
+ pretty much do whatever you want in terms of feature extraction so long as the following
+ conditions are met:
+ - #hog.size() == get_num_planes()
+ - Each image plane in #hog has the same dimensions.
+ - for all valid i, r, and c:
+ - #hog[i][r][c] == a feature value describing the image content centered at the
+ following pixel location in img:
+ feats_to_image(point(c,r),cell_size,filter_rows_padding,filter_cols_padding)
+ !*/
+
+ inline unsigned long get_num_planes (
+ ) const { return 31; }
+ /*!
+ ensures
+ - returns the number of planes in the hog image output by the operator()
+ method.
+ !*/
+ };
+
+ inline void serialize (const default_fhog_feature_extractor&, std::ostream&) {}
+ inline void deserialize (default_fhog_feature_extractor&, std::istream&) {}
+ /*!
+ Provides serialization support. Note that there is no state in the default hog
+ feature extractor so these functions do nothing. But if you define a custom
+ feature extractor then make sure you remember to serialize any state in your
+ feature extractor.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type = default_fhog_feature_extractor
+ >
+ class scan_fhog_pyramid : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON Pyramid_type
+ - Must be one of the pyramid_down objects defined in
+ dlib/image_transforms/image_pyramid_abstract.h or an object with a
+ compatible interface
+
+ REQUIREMENTS ON Feature_extractor_type
+ - Must be a type with an interface compatible with the
+ default_fhog_feature_extractor.
+
+ INITIAL VALUE
+ - get_padding() == 1
+ - get_cell_size() == 8
+ - get_detection_window_width() == 64
+ - get_detection_window_height() == 64
+ - get_max_pyramid_levels() == 1000
+ - get_min_pyramid_layer_width() == 64
+ - get_min_pyramid_layer_height() == 64
+ - get_nuclear_norm_regularization_strength() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for running a fixed sized sliding window classifier
+ over an image pyramid. In particular, it slides a linear classifier over
+ a HOG pyramid as discussed in the paper:
+ Histograms of Oriented Gradients for Human Detection by Navneet Dalal
+ and Bill Triggs, CVPR 2005
+ However, we augment the method slightly to use the version of HOG features
+ from:
+ Object Detection with Discriminatively Trained Part Based Models by
+ P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan
+ IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010
+ Since these HOG features have been shown to give superior performance.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be
+ protected by a mutex lock except for the case where you are copying the
+ configuration (via copy_configuration()) of a scan_fhog_pyramid object to
+ many other threads. In this case, it is safe to copy the configuration of
+ a shared object so long as no other operations are performed on it.
+ !*/
+
+ public:
+ typedef matrix<double,0,1> feature_vector_type;
+ typedef Pyramid_type pyramid_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_fhog_pyramid (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ explicit scan_fhog_pyramid (
+ const feature_extractor_type& fe
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_feature_extractor() == fe
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - img contains some kind of pixel type.
+ (i.e. pixel_traits<typename image_type::type> is defined)
+ ensures
+ - #is_loaded_with_image() == true
+ - This object is ready to run a classifier over img to detect object
+ locations. Call detect() to do this.
+ !*/
+
+ const feature_extractor_type& get_feature_extractor(
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the feature extractor used by this object.
+ !*/
+
+ bool is_loaded_with_image (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with an image to process and
+ false otherwise.
+ !*/
+
+ void copy_configuration (
+ const scan_fhog_pyramid& item
+ );
+ /*!
+ ensures
+ - Copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two scan_fhog_pyramid
+ objects S1 and S2, the following sequence of instructions should always
+ result in both of them having the exact same state:
+ S2.copy_configuration(S1);
+ S1.load(img);
+ S2.load(img);
+ !*/
+
+ void set_detection_window_size (
+ unsigned long window_width,
+ unsigned long window_height
+ );
+ /*!
+ requires
+ - window_width > 0
+ - window_height > 0
+ ensures
+ - When detect() is called, this object scans a window that is of the given
+ width and height (in pixels) over each layer in an image pyramid. This
+ means that the rectangle detections which come out of detect() will have
+ a width to height ratio approximately equal to window_width/window_height
+ and will be approximately window_width*window_height pixels in area or
+ larger. Therefore, the smallest object that can be detected is roughly
+ window_width by window_height pixels in size.
+ - #get_detection_window_width() == window_width
+ - #get_detection_window_height() == window_height
+ - Since we use a HOG feature representation, the detection procedure works
+ as follows:
+ Step 1. Make an image pyramid.
+ Step 2. Convert each layer of the image pyramid into a multi-planed HOG "image".
+ (the number of bands is given by get_feature_extractor().get_num_planes())
+ Step 3. Scan a linear classifier over each HOG image in the pyramid.
+ Moreover, the HOG features quantize the input image into a grid of cells,
+ each cell being get_cell_size() by get_cell_size() pixels in size. So
+ when we scan the object detector over the pyramid we are scanning an
+ appropriately sized window over these smaller quantized HOG features. In
+ particular, the size of the window we scan over the HOG feature pyramid
+ is #get_fhog_window_width() by #get_fhog_window_height() HOG cells in
+ size.
+ - #is_loaded_with_image() == false
+ !*/
+
+ unsigned long get_detection_window_width (
+ ) const;
+ /*!
+ ensures
+ - returns the width, in pixels, of the detection window that is scanned
+ over the image when detect() is called.
+ !*/
+
+ inline unsigned long get_detection_window_height (
+ ) const;
+ /*!
+ ensures
+ - returns the height, in pixels, of the detection window that is scanned
+ over the image when detect() is called.
+ !*/
+
+ unsigned long get_fhog_window_width (
+ ) const;
+ /*!
+ ensures
+ - Returns the width of the HOG scanning window in terms of HOG cell blocks.
+ Note that this is a function of get_detection_window_width(), get_cell_size(),
+ and get_padding() and is therefore not something you set directly.
+ - #get_fhog_window_width() is approximately equal to the number of HOG cells
+ that fit into get_detection_window_width() pixels plus 2*get_padding()
+ since we include additional padding around each window to add context.
+ !*/
+
+ unsigned long get_fhog_window_height (
+ ) const;
+ /*!
+ ensures
+ - Returns the height of the HOG scanning window in terms of HOG cell blocks.
+ Note that this is a function of get_detection_window_height(), get_cell_size(),
+ and get_padding() and is therefore not something you set directly.
+ - #get_fhog_window_height() is approximately equal to the number of HOG cells
+ that fit into get_detection_window_height() pixels plus 2*get_padding()
+ since we include additional padding around each window to add context.
+ !*/
+
+ void set_padding (
+ unsigned long new_padding
+ );
+ /*!
+ ensures
+ - #get_padding() == new_padding
+ - #is_loaded_with_image() == false
+ !*/
+
+ unsigned long get_padding (
+ ) const;
+ /*!
+ ensures
+ - The HOG windows scanned over the HOG pyramid can include additional HOG
+ cells outside the detection window. This can help add context and
+ improve detection accuracy. This function returns the number of extra
+ HOG cells added onto the border of the HOG windows which are scanned by
+ detect().
+ !*/
+
+ unsigned long get_cell_size (
+ ) const;
+ /*!
+ ensures
+ - Returns the size of the HOG cells. Each HOG cell is square and contains
+ get_cell_size()*get_cell_size() pixels.
+ !*/
+
+ void set_cell_size (
+ unsigned long new_cell_size
+ );
+ /*!
+ requires
+ - new_cell_size > 0
+ ensures
+ - #get_cell_size() == new_cell_size
+ - #is_loaded_with_image() == false
+ !*/
+
+ inline long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns get_fhog_window_width()*get_fhog_window_height()*get_feature_extractor().get_num_planes()
+ (i.e. The number of features is equal to the size of the HOG window times
+ the number of planes output by the feature extractor. )
+ !*/
+
+ inline unsigned long get_num_detection_templates (
+ ) const { return 1; }
+ /*!
+ ensures
+ - returns 1. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Notionally, its return value indicates
+ that a scan_fhog_pyramid object is always ready to detect objects once
+ an image has been loaded.
+ !*/
+
+ inline unsigned long get_num_movable_components_per_detection_template (
+ ) const { return 0; }
+ /*!
+ ensures
+ - returns 0. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Its return value means that this object
+ does not support using movable part models.
+ !*/
+
+ unsigned long get_max_pyramid_levels (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of image pyramid levels this object will use.
+ Note that #get_max_pyramid_levels() == 1 indicates that no image pyramid
+ will be used at all. That is, only the original image will be processed
+ and no lower scale versions will be created.
+ !*/
+
+ void set_max_pyramid_levels (
+ unsigned long max_levels
+ );
+ /*!
+ requires
+ - max_levels > 0
+ ensures
+ - #get_max_pyramid_levels() == max_levels
+ !*/
+
+ void set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ requires
+ - width > 0
+ - height > 0
+ ensures
+ - #get_min_pyramid_layer_width() == width
+ - #get_min_pyramid_layer_height() == height
+ !*/
+
+ inline unsigned long get_min_pyramid_layer_width (
+ ) const;
+ /*!
+ ensures
+ - returns the smallest allowable width of an image in the image pyramid.
+ All pyramids will always include the original input image, however, no
+ pyramid levels will be created which have a width smaller than the
+ value returned by this function.
+ !*/
+
+ inline unsigned long get_min_pyramid_layer_height (
+ ) const;
+ /*!
+ ensures
+ - returns the smallest allowable height of an image in the image pyramid.
+ All pyramids will always include the original input image, however, no
+ pyramid levels will be created which have a height smaller than the
+ value returned by this function.
+ !*/
+
+ fhog_filterbank build_fhog_filterbank (
+ const feature_vector_type& weights
+ ) const;
+ /*!
+ requires
+ - weights.size() >= get_num_dimensions()
+ ensures
+ - Creates and then returns a fhog_filterbank object FB such that:
+ - FB.get_num_dimensions() == get_num_dimensions()
+ - FB.get_filters() == the values in weights unpacked into get_feature_extractor().get_num_planes() filters.
+ - FB.num_separable_filters() == the number of separable filters necessary to
+ represent all the filters in FB.get_filters().
+ !*/
+
+ class fhog_filterbank
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a HOG filter bank. That is, the classifier that is
+ slid over a HOG pyramid is a set of get_feature_extractor().get_num_planes()
+ linear filters, each get_fhog_window_width() rows by get_fhog_window_height()
+ columns in size. This object contains that set of filters.
+ !*/
+
+ public:
+ long get_num_dimensions(
+ ) const;
+ /*!
+ ensures
+ - Returns the total number of values in the filters.
+ !*/
+
+ const std::vector<matrix<float> >& get_filters(
+ ) const;
+ /*!
+ ensures
+ - returns the set of HOG filters in this object.
+ !*/
+
+ unsigned long num_separable_filters(
+ ) const;
+ /*!
+ ensures
+ - returns the number of separable filters necessary to represent all
+ the filters in get_filters().
+ !*/
+ };
+
+ void detect (
+ const fhog_filterbank& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+ /*!
+ requires
+ - w.get_num_dimensions() == get_num_dimensions()
+ - is_loaded_with_image() == true
+ ensures
+ - Scans the HOG filter defined by w over the HOG pyramid that was populated
+ by the last call to load() and stores all object detections into #dets.
+ - for all valid i:
+ - #dets[i].second == The object box which produced this detection. This rectangle gives
+ the location of the detection. Note that the rectangle will have been converted back into
+ the original image input space. That is, if this detection was made at a low level in the
+ image pyramid then the object box will have been automatically mapped up the pyramid layers
+ to the original image space. Or in other words, if you plot #dets[i].second on top of the
+ image given to load() it will show up in the right place.
+ - #dets[i].first == The score for this detection. This value is equal to dot(w, feature vector
+ for this sliding window location).
+ - #dets[i].first >= thresh
+ - #dets will be sorted in descending order. (i.e. #dets[i].first >= #dets[j].first for all i, and j>i)
+ - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only the first
+ get_num_dimensions() are used.
+ - Note that no form of non-max suppression is performed. If a window has a score >= thresh
+ then it is reported in #dets.
+ !*/
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ - is_loaded_with_image() == true
+ ensures
+ - performs: detect(build_fhog_filterbank(w), dets, thresh)
+ !*/
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+ /*!
+ requires
+ - obj.num_parts() == 0
+ - is_loaded_with_image() == true
+ - psi.size() >= get_num_dimensions()
+ (i.e. psi must have preallocated its memory before this function is called)
+ ensures
+ - This function allows you to determine the feature vector used for an
+ object detection output from detect(). Note that this vector is
+ added to psi. Note also that you can use get_full_object_detection() to
+ convert a rectangle from detect() into the needed full_object_detection.
+ - The dimensionality of the vector added to psi is get_num_dimensions(). This
+ means that elements of psi after psi(get_num_dimensions()-1) are not modified.
+ - Since scan_fhog_pyramid only searches a limited set of object locations,
+ not all possible rectangles can be output by detect(). So in the case
+ where obj.get_rect() could not arise from a call to detect(), this
+ function will map obj.get_rect() to the nearest possible rectangle and
+ then add the feature vector for the mapped rectangle into #psi.
+ - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect()
+ gets mapped to for feature extraction.
+ !*/
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+ /*!
+ ensures
+ - returns full_object_detection(rect)
+ (This function is here only for compatibility with the scan_image_pyramid
+ object)
+ !*/
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+ /*!
+ ensures
+ - Since scan_fhog_pyramid only searches a limited set of object locations,
+ not all possible rectangles can be represented. Therefore, this function
+ allows you to supply a rectangle and obtain the nearest possible
+ candidate object location rectangle.
+ !*/
+
+ double get_nuclear_norm_regularization_strength (
+ ) const;
+ /*!
+ ensures
+ - If the number of separable filters in a fhog_filterbank is small then the
+ filter bank can be scanned over an image much faster than a normal set of
+ filters. Therefore, this object provides the option to encourage
+ machine learning methods that learn a HOG filter bank (i.e.
+ structural_object_detection_trainer) to select filter banks that have
+ this beneficial property. In particular, the value returned by
+ get_nuclear_norm_regularization_strength() is a multiplier on a nuclear
+ norm regularizer which will encourage the selection of filters that use a
+ small number of separable components. Larger values encourage tend to
+ give a smaller number of separable filters.
+ - if (get_nuclear_norm_regularization_strength() == 0) then
+ - This feature is disabled
+ - else
+ - A nuclear norm regularizer will be added when
+ structural_object_detection_trainer is used to learn a HOG filter
+ bank. Note that this can make the training process take
+ significantly longer (but can result in faster object detectors).
+ !*/
+
+ void set_nuclear_norm_regularization_strength (
+ double strength
+ );
+ /*!
+ requires
+ - strength >= 0
+ ensures
+ - #get_nuclear_norm_regularization_strength() == strength
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const scan_fhog_pyramid<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void deserialize (
+ scan_fhog_pyramid<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_type
+ >
+ void evaluate_detectors (
+ const std::vector<object_detector<scan_fhog_pyramid<pyramid_type>>>& detectors,
+ const image_type& img,
+ std::vector<rect_detection>& dets,
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - img contains some kind of pixel type.
+ (i.e. pixel_traits<typename image_type::type> is defined)
+ ensures
+ - This function runs each of the provided object_detector objects over img and
+ stores the resulting detections into #dets. Importantly, this function is
+ faster than running each detector individually because it computes the HOG
+ features only once and then reuses them for each detector. However, it is
+ important to note that this speedup is only possible if all the detectors use
+ the same cell_size parameter that determines how HOG features are computed.
+ If different cell_size values are used then this function will not be any
+ faster than running the detectors individually.
+ - This function applies non-max suppression individually to the output of each
+ detector. Therefore, the output is the same as if you ran each detector
+ individually and then concatenated the results.
+ - To be precise, this function performs object detection on the given image and
+ stores the detected objects into #dets. In particular, we will have that:
+ - #dets is sorted such that the highest confidence detections come first.
+ E.g. element 0 is the best detection, element 1 the next best, and so on.
+ - #dets.size() == the number of detected objects.
+ - #dets[i].detection_confidence == The strength of the i-th detection.
+ Larger values indicate that the detector is more confident that #dets[i]
+ is a correct detection rather than being a false alarm. Moreover, the
+ detection_confidence is equal to the detection value output by the
+ scanner minus the threshold value stored at the end of the weight vector.
+ - #dets[i].rect == the bounding box for the i-th detection.
+ - The detection #dets[i].rect was produced by detectors[#dets[i].weight_index].
+ - The detection threshold is adjusted by having adjust_threshold added to it.
+ Therefore, an adjust_threshold value > 0 makes detecting objects harder while
+ a negative value makes it easier. Moreover, the following will be true for
+ all valid i:
+ - #dets[i].detection_confidence >= adjust_threshold
+ This means that, for example, you can obtain the maximum possible number of
+ detections by setting adjust_threshold equal to negative infinity.
+ - This function is threadsafe in the sense that multiple threads can call
+ evaluate_detectors() with the same instances of detectors and img without
+ requiring a mutex lock.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_type
+ >
+ std::vector<rectangle> evaluate_detectors (
+ const std::vector<object_detector<scan_fhog_pyramid<pyramid_type>>>& detectors,
+ const image_type& img,
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - img contains some kind of pixel type.
+ (i.e. pixel_traits<typename image_type::type> is defined)
+ ensures
+ - This function just calls the above evaluate_detectors() routine and copies
+ the output dets into a vector<rectangle> object and returns it. Therefore,
+ this function is provided for convenience.
+ - This function is threadsafe in the sense that multiple threads can call
+ evaluate_detectors() with the same instances of detectors and img without
+ requiring a mutex lock.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image.h b/ml/dlib/dlib/image_processing/scan_image.h
new file mode 100644
index 000000000..1a9c46eda
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image.h
@@ -0,0 +1,368 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_iMAGE_Hh_
+#define DLIB_SCAN_iMAGE_Hh_
+
+#include <vector>
+#include <utility>
+#include "scan_image_abstract.h"
+#include "../matrix.h"
+#include "../algs.h"
+#include "../rand.h"
+#include "../array2d.h"
+#include "../image_transforms/spatial_filtering.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ inline rectangle bounding_box_of_rects (
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const point& position
+ )
+ /*!
+ ensures
+ - returns the smallest rectangle that contains all the
+ rectangles in rects. That is, returns the rectangle that
+ contains translate_rect(rects[i].second,position) for all valid i.
+ !*/
+ {
+ rectangle rect;
+
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ rect += translate_rect(rects[i].second,position);
+ }
+
+ return rect;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ bool all_images_same_size (
+ const image_array_type& images
+ )
+ {
+ if (images.size() == 0)
+ return true;
+
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ if (num_rows(images[0]) != num_rows(images[i]) ||
+ num_columns(images[0]) != num_columns(images[i]))
+ return false;
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ double sum_of_rects_in_images (
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const point& position
+ )
+ {
+ DLIB_ASSERT(all_images_same_size(images),
+ "\t double sum_of_rects_in_images()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t all_images_same_size(images): " << all_images_same_size(images)
+ );
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ DLIB_ASSERT(rects[i].first < images.size(),
+ "\t double sum_of_rects_in_images()"
+ << "\n\t rects["<<i<<"].first must refer to a valid image."
+ << "\n\t rects["<<i<<"].first: " << rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ }
+#endif
+
+
+ typedef typename image_traits<typename image_array_type::type>::pixel_type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ ptype temp = 0;
+
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[rects[i].first];
+ const rectangle rect = get_rect(img).intersect(translate_rect(rects[i].second,position));
+ temp += sum(matrix_cast<ptype>(subm(mat(img), rect)));
+ }
+
+ return static_cast<double>(temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ double sum_of_rects_in_images_movable_parts (
+ const image_array_type& images,
+ const rectangle& window,
+ const std::vector<std::pair<unsigned int, rectangle> >& fixed_rects,
+ const std::vector<std::pair<unsigned int, rectangle> >& movable_rects,
+ const point& position
+ )
+ {
+ DLIB_ASSERT(all_images_same_size(images) && center(window) == point(0,0),
+ "\t double sum_of_rects_in_images_movable_parts()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t all_images_same_size(images): " << all_images_same_size(images)
+ << "\n\t center(window): " << center(window)
+ );
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < fixed_rects.size(); ++i)
+ {
+ DLIB_ASSERT(fixed_rects[i].first < images.size(),
+ "\t double sum_of_rects_in_images_movable_parts()"
+ << "\n\t fixed_rects["<<i<<"].first must refer to a valid image."
+ << "\n\t fixed_rects["<<i<<"].first: " << fixed_rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ }
+ for (unsigned long i = 0; i < movable_rects.size(); ++i)
+ {
+ DLIB_ASSERT(movable_rects[i].first < images.size(),
+ "\t double sum_of_rects_in_images_movable_parts()"
+ << "\n\t movable_rects["<<i<<"].first must refer to a valid image."
+ << "\n\t movable_rects["<<i<<"].first: " << movable_rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ DLIB_ASSERT(center(movable_rects[i].second) == point(0,0),
+ "\t double sum_of_rects_in_images_movable_parts()"
+ << "\n\t movable_rects["<<i<<"].second: " << movable_rects[i].second
+ );
+ }
+#endif
+ typedef typename image_traits<typename image_array_type::type>::pixel_type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ ptype temp = 0;
+
+ // compute TOTAL_FIXED part
+ for (unsigned long i = 0; i < fixed_rects.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[fixed_rects[i].first];
+ const rectangle rect = get_rect(img).intersect(translate_rect(fixed_rects[i].second,position));
+ temp += sum(matrix_cast<ptype>(subm(mat(img), rect)));
+ }
+
+ if (images.size() > 0)
+ {
+ // compute TOTAL_MOVABLE part
+ array2d<ptype> tempimg(images[0].nr(), images[0].nc());
+ for (unsigned long i = 0; i < movable_rects.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[movable_rects[i].first];
+
+ sum_filter_assign(img, tempimg, movable_rects[i].second);
+
+ const rectangle rect = get_rect(tempimg).intersect(translate_rect(window,position));
+ if (rect.is_empty() == false)
+ temp += std::max(0,max(matrix_cast<ptype>(subm(mat(tempimg), rect))));
+ }
+ }
+
+ return static_cast<double>(temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void find_points_above_thresh (
+ std::vector<std::pair<double, point> >& dets,
+ const image_type& img_,
+ const double thresh,
+ const unsigned long max_dets
+ )
+ {
+ const_image_view<image_type> img(img_);
+ typedef typename image_traits<image_type>::pixel_type ptype;
+
+ dets.clear();
+ if (max_dets == 0)
+ return;
+
+ unsigned long count = 0;
+ dlib::rand rnd;
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ const ptype val = img[r][c];
+ if (val >= thresh)
+ {
+ ++count;
+
+ if (dets.size() < max_dets)
+ {
+ dets.push_back(std::make_pair(val, point(c,r)));
+ }
+ else
+ {
+ // The idea here is to cause us to randomly sample possible detection
+ // locations throughout the image rather than just stopping the detection
+ // procedure once we hit the max_dets limit. So this method will result
+ // in a random subsample of all the detections >= thresh being in dets
+ // at the end of scan_image().
+ const unsigned long random_index = rnd.get_random_32bit_number()%count;
+ if (random_index < dets.size())
+ {
+ dets[random_index] = std::make_pair(val, point(c,r));
+ }
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const double thresh,
+ const unsigned long max_dets
+ )
+ {
+ DLIB_ASSERT(images.size() > 0 && rects.size() > 0 && all_images_same_size(images),
+ "\t void scan_image()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t rects.size(): " << rects.size()
+ << "\n\t all_images_same_size(images): " << all_images_same_size(images)
+ );
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ DLIB_ASSERT(rects[i].first < images.size(),
+ "\t void scan_image()"
+ << "\n\t rects["<<i<<"].first must refer to a valid image."
+ << "\n\t rects["<<i<<"].first: " << rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ }
+#endif
+
+
+
+
+ typedef typename image_traits<typename image_array_type::type>::pixel_type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ array2d<ptype> accum(images[0].nr(), images[0].nc());
+ assign_all_pixels(accum, 0);
+
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ sum_filter(images[rects[i].first], accum, rects[i].second);
+
+ find_points_above_thresh(dets, accum, thresh, max_dets);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image_movable_parts (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const rectangle& window,
+ const std::vector<std::pair<unsigned int, rectangle> >& fixed_rects,
+ const std::vector<std::pair<unsigned int, rectangle> >& movable_rects,
+ const double thresh,
+ const unsigned long max_dets
+ )
+ {
+ DLIB_ASSERT(images.size() > 0 && all_images_same_size(images) &&
+ center(window) == point(0,0) && window.area() > 0,
+ "\t void scan_image_movable_parts()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t all_images_same_size(images): " << all_images_same_size(images)
+ << "\n\t center(window): " << center(window)
+ << "\n\t window.area(): " << window.area()
+ << "\n\t images.size(): " << images.size()
+ );
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < fixed_rects.size(); ++i)
+ {
+ DLIB_ASSERT(fixed_rects[i].first < images.size(),
+ "\t void scan_image_movable_parts()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t fixed_rects["<<i<<"].first must refer to a valid image."
+ << "\n\t fixed_rects["<<i<<"].first: " << fixed_rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ }
+ for (unsigned long i = 0; i < movable_rects.size(); ++i)
+ {
+ DLIB_ASSERT(movable_rects[i].first < images.size(),
+ "\t void scan_image_movable_parts()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t movable_rects["<<i<<"].first must refer to a valid image."
+ << "\n\t movable_rects["<<i<<"].first: " << movable_rects[i].first
+ << "\n\t images.size(): " << images.size()
+ );
+ DLIB_ASSERT(center(movable_rects[i].second) == point(0,0) &&
+ movable_rects[i].second.area() > 0,
+ "\t void scan_image_movable_parts()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t movable_rects["<<i<<"].second: " << movable_rects[i].second
+ << "\n\t movable_rects["<<i<<"].second.area(): " << movable_rects[i].second.area()
+ );
+ }
+#endif
+
+ if (movable_rects.size() == 0 && fixed_rects.size() == 0)
+ return;
+
+ typedef typename image_traits<typename image_array_type::type>::pixel_type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ array2d<ptype> accum(images[0].nr(), images[0].nc());
+ assign_all_pixels(accum, 0);
+
+ for (unsigned long i = 0; i < fixed_rects.size(); ++i)
+ sum_filter(images[fixed_rects[i].first], accum, fixed_rects[i].second);
+
+ array2d<ptype> temp(accum.nr(), accum.nc());
+ for (unsigned long i = 0; i < movable_rects.size(); ++i)
+ {
+ const rectangle rect = movable_rects[i].second;
+ sum_filter_assign(images[movable_rects[i].first], temp, rect);
+ max_filter(temp, accum, window.width(), window.height(), 0);
+ }
+
+ find_points_above_thresh(dets, accum, thresh, max_dets);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_iMAGE_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_abstract.h b/ml/dlib/dlib/image_processing/scan_image_abstract.h
new file mode 100644
index 000000000..fe2fc51ac
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_abstract.h
@@ -0,0 +1,227 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_iMAGE_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_iMAGE_ABSTRACT_Hh_
+
+#include <vector>
+#include <utility>
+#include "../algs.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ bool all_images_same_size (
+ const image_array_type& images
+ );
+ /*!
+ requires
+ - image_array_type == an implementation of array/array_kernel_abstract.h
+ - image_array_type::type == an image object that implements the interface
+ defined in dlib/image_processing/generic_image.h
+ ensures
+ - if (all elements of images have the same dimensions (i.e.
+ for all i and j: get_rect(images[i]) == get_rect(images[j]))) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ double sum_of_rects_in_images (
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const point& position
+ );
+ /*!
+ requires
+ - image_array_type == an implementation of array/array_kernel_abstract.h
+ - image_array_type::type == an image object that implements the interface
+ defined in dlib/image_processing/generic_image.h. Moreover, these objects must
+ contain a scalar pixel type (e.g. int rather than rgb_pixel)
+ - all_images_same_size(images) == true
+ - for all valid i: rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ ensures
+ - returns the sum of the pixels inside the given rectangles. To be precise,
+ let RECT_SUM[i] = sum of pixels inside the rectangle translate_rect(rects[i].second, position)
+ from the image images[rects[i].first]. Then this function returns the
+ sum of RECT_SUM[i] for all the valid values of i.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ double sum_of_rects_in_images_movable_parts (
+ const image_array_type& images,
+ const rectangle& window,
+ const std::vector<std::pair<unsigned int, rectangle> >& fixed_rects,
+ const std::vector<std::pair<unsigned int, rectangle> >& movable_rects,
+ const point& position
+ );
+ /*!
+ requires
+ - image_array_type == an implementation of array/array_kernel_abstract.h
+ - image_array_type::type == an image object that implements the interface
+ defined in dlib/image_processing/generic_image.h. Moreover, these objects must
+ contain a scalar pixel type (e.g. int rather than rgb_pixel)
+ - all_images_same_size(images) == true
+ - center(window) == point(0,0)
+ - for all valid i:
+ - fixed_rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ - for all valid i:
+ - movable_rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ - center(movable_rects[i].second) == point(0,0)
+ ensures
+ - returns the sum of the pixels inside fixed_rects as well as the sum of the pixels
+ inside movable_rects when these latter rectangles are placed at their highest
+ scoring locations inside the given window. To be precise:
+ - let RECT_SUM(r,x) = sum of pixels inside the rectangle translate_rect(r.second, x)
+ from the image images[r.first].
+ - let WIN_MAX(i) = The maximum value of RECT_SUM(movable_rects[i],X) when maximizing
+ over all the X such that translate_rect(window,position).contains(X) == true.
+
+ - let TOTAL_FIXED == sum over all elements R in fixed_rects of: RECT_SUM(R,position)
+ - let TOTAL_MOVABLE == sum over all valid i of: max(WIN_MAX(i), 0)
+
+ Then this function returns TOTAL_FIXED + TOTAL_MOVABLE.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void find_points_above_thresh (
+ std::vector<std::pair<double, point> >& dets,
+ const image_type& img,
+ const double thresh,
+ const unsigned long max_dets
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h. Moreover, these it must contain a
+ scalar pixel type (e.g. int rather than rgb_pixel)
+ ensures
+ - #dets == a list of points from img which had pixel values >= thresh.
+ - Specifically, we have:
+ - #dets.size() <= max_dets
+ (note that dets is cleared before new detections are added by find_points_above_thresh())
+ - for all valid i:
+ - #dets[i].first == img[#dets[i].second.y()][#dets[i].second.x()]
+ (i.e. the first field contains the value of the pixel at this detection location)
+ - #dets[i].first >= thresh
+ - if (there are more than max_dets locations that pass the above threshold test) then
+ - #dets == a random subsample of all the locations which passed the threshold
+ test.
+ - else
+ - #dets == all the points which passed the threshold test.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const double thresh,
+ const unsigned long max_dets
+ );
+ /*!
+ requires
+ - image_array_type == an implementation of array/array_kernel_abstract.h
+ - image_array_type::type == an image object that implements the interface
+ defined in dlib/image_processing/generic_image.h. Moreover, these objects must
+ contain a scalar pixel type (e.g. int rather than rgb_pixel)
+ - images.size() > 0
+ - rects.size() > 0
+ - all_images_same_size(images) == true
+ - for all valid i: rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ ensures
+ - slides the set of rectangles over the image space and reports the locations
+ which give a sum bigger than thresh.
+ - Specifically, we have:
+ - #dets.size() <= max_dets
+ (note that dets is cleared before new detections are added by scan_image())
+ - for all valid i:
+ - #dets[i].first == sum_of_rects_in_images(images,rects,#dets[i].second) >= thresh
+ - if (there are more than max_dets locations that pass the threshold test) then
+ - #dets == a random subsample of all the locations which passed the threshold
+ test.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image_movable_parts (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const rectangle& window,
+ const std::vector<std::pair<unsigned int, rectangle> >& fixed_rects,
+ const std::vector<std::pair<unsigned int, rectangle> >& movable_rects,
+ const double thresh,
+ const unsigned long max_dets
+ );
+ /*!
+ requires
+ - image_array_type == an implementation of array/array_kernel_abstract.h
+ - image_array_type::type == an image object that implements the interface
+ defined in dlib/image_processing/generic_image.h. Moreover, these objects must
+ contain a scalar pixel type (e.g. int rather than rgb_pixel)
+ - images.size() > 0
+ - all_images_same_size(images) == true
+ - center(window) == point(0,0)
+ - window.area() > 0
+ - for all valid i:
+ - fixed_rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ - for all valid i:
+ - movable_rects[i].first < images.size()
+ (i.e. all the rectangles must reference valid elements of images)
+ - center(movable_rects[i].second) == point(0,0)
+ - movable_rects[i].second.area() > 0
+ ensures
+ - Scans the given window over the images and reports the locations with a score bigger
+ than thresh.
+ - Specifically, we have:
+ - #dets.size() <= max_dets
+ (note that dets is cleared before new detections are added by scan_image_movable_parts())
+ - for all valid i:
+ - #dets[i].first == sum_of_rects_in_images_movable_parts(images,
+ window,
+ fixed_rects,
+ movable_rects,
+ #dets[i].second) >= thresh
+ - if (there are more than max_dets locations that pass the above threshold test) then
+ - #dets == a random subsample of all the locations which passed the threshold
+ test.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_iMAGE_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_boxes.h b/ml/dlib/dlib/image_processing/scan_image_boxes.h
new file mode 100644
index 000000000..f4549565c
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_boxes.h
@@ -0,0 +1,630 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_IMAGE_bOXES_Hh_
+#define DLIB_SCAN_IMAGE_bOXES_Hh_
+
+#include "scan_image_boxes_abstract.h"
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../array2d.h"
+#include <vector>
+#include "../image_processing/full_object_detection.h"
+#include "../image_transforms.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class default_box_generator
+ {
+ public:
+ template <typename image_type>
+ void operator() (
+ const image_type& img,
+ std::vector<rectangle>& rects
+ ) const
+ {
+ rects.clear();
+ find_candidate_object_locations(img, rects);
+ }
+ };
+
+ inline void serialize(const default_box_generator&, std::ostream& ) {}
+ inline void deserialize(default_box_generator&, std::istream& ) {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator = default_box_generator
+ >
+ class scan_image_boxes : noncopyable
+ {
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+
+ typedef Feature_extractor_type feature_extractor_type;
+ typedef Box_generator box_generator;
+
+ scan_image_boxes (
+ );
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+
+ inline bool is_loaded_with_image (
+ ) const;
+
+ inline void copy_configuration(
+ const feature_extractor_type& fe
+ );
+
+ inline void copy_configuration(
+ const box_generator& bg
+ );
+
+ const box_generator& get_box_generator (
+ ) const { return detect_boxes; }
+
+ const Feature_extractor_type& get_feature_extractor (
+ ) const { return feats; }
+
+ inline void copy_configuration (
+ const scan_image_boxes& item
+ );
+
+ inline long get_num_dimensions (
+ ) const;
+
+ unsigned long get_num_spatial_pyramid_levels (
+ ) const;
+
+ void set_num_spatial_pyramid_levels (
+ unsigned long levels
+ );
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - is_loaded_with_image() == true
+ !*/
+
+ inline unsigned long get_num_detection_templates (
+ ) const { return 1; }
+
+ inline unsigned long get_num_movable_components_per_detection_template (
+ ) const { return 0; }
+
+ template <typename T, typename U>
+ friend void serialize (
+ const scan_image_boxes<T,U>& item,
+ std::ostream& out
+ );
+
+ template <typename T, typename U>
+ friend void deserialize (
+ scan_image_boxes<T,U>& item,
+ std::istream& in
+ );
+
+ private:
+ static bool compare_pair_rect (
+ const std::pair<double, rectangle>& a,
+ const std::pair<double, rectangle>& b
+ )
+ {
+ return a.first < b.first;
+ }
+
+ void test_coordinate_transforms()
+ {
+ for (long x = -10; x <= 10; x += 10)
+ {
+ for (long y = -10; y <= 10; y += 10)
+ {
+ const rectangle rect = centered_rect(x,y,5,6);
+ rectangle a;
+
+ a = feats.image_to_feat_space(rect);
+ if (a.width() > 10000000 || a.height() > 10000000 )
+ {
+ DLIB_CASSERT(false, "The image_to_feat_space() routine is outputting rectangles of an implausibly "
+ << "\nlarge size. This means there is probably a bug in your feature extractor.");
+ }
+ a = feats.feat_to_image_space(rect);
+ if (a.width() > 10000000 || a.height() > 10000000 )
+ {
+ DLIB_CASSERT(false, "The feat_to_image_space() routine is outputting rectangles of an implausibly "
+ << "\nlarge size. This means there is probably a bug in your feature extractor.");
+ }
+ }
+ }
+
+ }
+
+ static void add_grid_rects (
+ std::vector<rectangle>& rects,
+ const rectangle& object_box,
+ unsigned int cells_x,
+ unsigned int cells_y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cells_x > 0 && cells_y > 0,
+ "\t void add_grid_rects()"
+ << "\n\t The number of cells along a dimension can't be zero. "
+ << "\n\t cells_x: " << cells_x
+ << "\n\t cells_y: " << cells_y
+ );
+
+ const matrix_range_exp<double>& x = linspace(object_box.left(), object_box.right(), cells_x+1);
+ const matrix_range_exp<double>& y = linspace(object_box.top(), object_box.bottom(), cells_y+1);
+
+ for (long j = 0; j+1 < y.size(); ++j)
+ {
+ for (long i = 0; i+1 < x.size(); ++i)
+ {
+ const dlib::vector<double,2> tl(x(i),y(j));
+ const dlib::vector<double,2> br(x(i+1),y(j+1));
+ rects.push_back(rectangle(tl,br));
+ }
+ }
+ }
+
+ void get_feature_extraction_regions (
+ const rectangle& rect,
+ std::vector<rectangle>& regions
+ ) const
+ /*!
+ ensures
+ - #regions.size() is always the same number no matter what the input is. The
+ regions also have a consistent ordering.
+ - all the output rectangles are contained within rect.
+ !*/
+ {
+ regions.clear();
+
+ for (unsigned int l = 1; l <= num_spatial_pyramid_levels; ++l)
+ {
+ const int cells = (int)std::pow(2.0, l-1.0);
+ add_grid_rects(regions, rect, cells, cells);
+ }
+ }
+
+ unsigned int get_num_components_per_detection_template(
+ ) const
+ {
+ return (unsigned int)(std::pow(4.0,(double)num_spatial_pyramid_levels)-1)/3;
+ }
+
+ feature_extractor_type feats;
+ std::vector<rectangle> search_rects;
+ bool loaded_with_image;
+ unsigned int num_spatial_pyramid_levels;
+ box_generator detect_boxes;
+
+ const long box_sizedims;
+ const long box_maxsize;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void serialize (
+ const scan_image_boxes<T,U>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.feats, out);
+ serialize(item.search_rects, out);
+ serialize(item.loaded_with_image, out);
+ serialize(item.num_spatial_pyramid_levels, out);
+ serialize(item.detect_boxes, out);
+ serialize(item.get_num_dimensions(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void deserialize (
+ scan_image_boxes<T,U>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unsupported version found when deserializing a scan_image_boxes object.");
+
+ deserialize(item.feats, in);
+ deserialize(item.search_rects, in);
+ deserialize(item.loaded_with_image, in);
+ deserialize(item.num_spatial_pyramid_levels, in);
+ deserialize(item.detect_boxes, in);
+
+ // When developing some feature extractor, it's easy to accidentally change its
+ // number of dimensions and then try to deserialize data from an older version of
+ // your extractor into the current code. This check is here to catch that kind of
+ // user error.
+ long dims;
+ deserialize(dims, in);
+ if (item.get_num_dimensions() != dims)
+ throw serialization_error("Number of dimensions in serialized scan_image_boxes doesn't match the expected number.");
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// scan_image_boxes member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ scan_image_boxes<Feature_extractor_type,Box_generator>::
+ scan_image_boxes (
+ ) :
+ loaded_with_image(false),
+ num_spatial_pyramid_levels(3),
+ box_sizedims(20),
+ box_maxsize(1200)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ template <
+ typename image_type
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ load (
+ const image_type& img
+ )
+ {
+ feats.load(img);
+ detect_boxes(img, search_rects);
+ loaded_with_image = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ bool scan_image_boxes<Feature_extractor_type,Box_generator>::
+ is_loaded_with_image (
+ ) const
+ {
+ return loaded_with_image;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ copy_configuration(
+ const feature_extractor_type& fe
+ )
+ {
+ test_coordinate_transforms();
+ feats.copy_configuration(fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ copy_configuration(
+ const box_generator& bg
+ )
+ {
+ detect_boxes = bg;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ copy_configuration (
+ const scan_image_boxes& item
+ )
+ {
+ feats.copy_configuration(item.feats);
+ detect_boxes = item.detect_boxes;
+ num_spatial_pyramid_levels = item.num_spatial_pyramid_levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ unsigned long scan_image_boxes<Feature_extractor_type,Box_generator>::
+ get_num_spatial_pyramid_levels (
+ ) const
+ {
+ return num_spatial_pyramid_levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ set_num_spatial_pyramid_levels (
+ unsigned long levels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(levels > 0,
+ "\t void scan_image_boxes::set_num_spatial_pyramid_levels()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t levels: " << levels
+ << "\n\t this: " << this
+ );
+
+
+ num_spatial_pyramid_levels = levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ long scan_image_boxes<Feature_extractor_type,Box_generator>::
+ get_num_dimensions (
+ ) const
+ {
+ return feats.get_num_dimensions()*get_num_components_per_detection_template() + box_sizedims*2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ w.size() >= get_num_dimensions(),
+ "\t void scan_image_boxes::detect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t w.size(): " << w.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ dets.clear();
+
+ array<integral_image_generic<double> > saliency_images(get_num_components_per_detection_template());
+
+ array2d<double> temp_img(feats.nr(), feats.nc());
+
+ // build saliency images
+ for (unsigned long i = 0; i < saliency_images.size(); ++i)
+ {
+ const unsigned long offset = 2*box_sizedims + feats.get_num_dimensions()*i;
+
+ // make the basic saliency image for the i-th feature extraction region
+ for (long r = 0; r < feats.nr(); ++r)
+ {
+ for (long c = 0; c < feats.nc(); ++c)
+ {
+ const typename feature_extractor_type::descriptor_type& descriptor = feats(r,c);
+
+ double sum = 0;
+ for (unsigned long k = 0; k < descriptor.size(); ++k)
+ {
+ sum += w(descriptor[k].first + offset)*descriptor[k].second;
+ }
+ temp_img[r][c] = sum;
+ }
+ }
+
+ // now convert base saliency image into final integral image
+ saliency_images[i].load(temp_img);
+ }
+
+
+ // now search the saliency images
+ std::vector<rectangle> regions;
+ const rectangle bounds = get_rect(feats);
+ for (unsigned long i = 0; i < search_rects.size(); ++i)
+ {
+ const rectangle rect = feats.image_to_feat_space(search_rects[i]).intersect(bounds);
+ if (rect.is_empty())
+ continue;
+ get_feature_extraction_regions(rect, regions);
+ double score = 0;
+ for (unsigned long k = 0; k < regions.size(); ++k)
+ {
+ score += saliency_images[k].get_sum_of_area(regions[k]);
+ }
+ const double width = search_rects[i].width();
+ const double height = search_rects[i].height();
+
+ score += dot(linpiece(width, linspace(0, box_maxsize, box_sizedims+1)), rowm(w, range(0,box_sizedims-1)));
+ score += dot(linpiece(height, linspace(0, box_maxsize, box_sizedims+1)), rowm(w, range(box_sizedims,2*box_sizedims-1)));
+
+ if (score >= thresh)
+ {
+ dets.push_back(std::make_pair(score, search_rects[i]));
+ }
+ }
+
+ std::sort(dets.rbegin(), dets.rend(), compare_pair_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ const rectangle scan_image_boxes<Feature_extractor_type,Box_generator>::
+ get_best_matching_rect (
+ const rectangle& rect
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image(),
+ "\t const rectangle scan_image_boxes::get_best_matching_rect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t this: " << this
+ );
+
+
+ double best_score = -1;
+ rectangle best_rect;
+ for (unsigned long i = 0; i < search_rects.size(); ++i)
+ {
+ const double score = (rect.intersect(search_rects[i])).area()/(double)(rect+search_rects[i]).area();
+ if (score > best_score)
+ {
+ best_score = score;
+ best_rect = search_rects[i];
+ }
+ }
+ return best_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ full_object_detection scan_image_boxes<Feature_extractor_type,Box_generator>::
+ get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& /*w*/
+ ) const
+ {
+ return full_object_detection(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void scan_image_boxes<Feature_extractor_type,Box_generator>::
+ get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ psi.size() >= get_num_dimensions() &&
+ obj.num_parts() == 0,
+ "\t void scan_image_boxes::get_feature_vector()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t psi.size(): " << psi.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t obj.num_parts(): " << obj.num_parts()
+ << "\n\t this: " << this
+ );
+
+
+
+ const rectangle best_rect = get_best_matching_rect(obj.get_rect());
+ const rectangle mapped_rect = feats.image_to_feat_space(best_rect).intersect(get_rect(feats));
+ if (mapped_rect.is_empty())
+ return;
+
+ std::vector<rectangle> regions;
+ get_feature_extraction_regions(mapped_rect, regions);
+
+ // pull features out of all the boxes in regions.
+ for (unsigned long j = 0; j < regions.size(); ++j)
+ {
+ const rectangle rect = regions[j];
+
+ const unsigned long template_region_id = j;
+ const unsigned long offset = box_sizedims*2 + feats.get_num_dimensions()*template_region_id;
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ const typename feature_extractor_type::descriptor_type& descriptor = feats(r,c);
+ for (unsigned long k = 0; k < descriptor.size(); ++k)
+ {
+ psi(descriptor[k].first + offset) += descriptor[k].second;
+ }
+ }
+ }
+ }
+
+ const double width = best_rect.width();
+ const double height = best_rect.height();
+ set_rowm(psi, range(0,box_sizedims-1)) += linpiece(width, linspace(0, box_maxsize, box_sizedims+1));
+ set_rowm(psi, range(box_sizedims,box_sizedims*2-1)) += linpiece(height, linspace(0, box_maxsize, box_sizedims+1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMAGE_bOXES_Hh_
+
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h b/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h
new file mode 100644
index 000000000..e2f16aa76
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h
@@ -0,0 +1,394 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../image_processing.h"
+#include "../array2d.h"
+#include "full_object_detection_abstract.h"
+#include "../image_transforms/segment_image_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class default_box_generator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object that takes in an image and outputs a set of
+ candidate object locations. It is also the default box generator used by
+ the scan_image_boxes object defined below.
+ !*/
+
+ public:
+
+ template <typename image_type>
+ void operator() (
+ const image_type& img,
+ std::vector<rectangle>& rects
+ ) const
+ /*!
+ ensures
+ - #rects == the set of candidate object locations which should be searched
+ inside img. That is, these are the rectangles which might contain
+ objects of interest within the given image.
+ !*/
+ {
+ rects.clear();
+ find_candidate_object_locations(img, rects);
+ }
+ };
+
+ inline void serialize (const default_box_generator&, std::ostream& ) {}
+ inline void deserialize( default_box_generator&, std::istream& ) {}
+ /*!
+ ensures
+ - provides serialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator = default_box_generator
+ >
+ class scan_image_boxes : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON Feature_extractor_type
+ - must be an object with an interface compatible with the hashed_feature_image
+ object defined in dlib/image_keypoint/hashed_feature_image_abstract.h or
+ with the nearest_neighbor_feature_image object defined in
+ dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h
+
+ REQUIREMENTS ON Box_generator
+ - must be an object with an interface compatible with the
+ default_box_generator object defined at the top of this file.
+
+ INITIAL VALUE
+ - get_num_spatial_pyramid_levels() == 3
+ - is_loaded_with_image() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for running a classifier over an image with the goal
+ of localizing each object present. The localization is in the form of the
+ bounding box around each object of interest.
+
+ Unlike the scan_image_pyramid object which scans a fixed sized window over
+ an image pyramid, the scan_image_boxes tool allows you to define your own
+ list of "candidate object locations" which should be evaluated. This is
+ simply a list of rectangle objects which might contain objects of interest.
+ The scan_image_boxes object will then evaluate the classifier at each of
+ these locations and return the subset of rectangles which appear to have
+ objects in them. The candidate object location generation is provided by
+ the Box_generator that is passed in as a template argument.
+
+ This object can also be understood as a general tool for implementing the
+ spatial pyramid models described in the paper:
+ Beyond Bags of Features: Spatial Pyramid Matching for Recognizing
+ Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid,
+ and Jean Ponce
+
+
+ The classifiers used by this object have three parts:
+ 1. The underlying feature extraction provided by Feature_extractor_type
+ objects, which associate a vector with each location in an image.
+
+ 2. A rule for extracting a feature vector from a candidate object
+ location. In this object we use the spatial pyramid matching method.
+ This means we cut an object's detection window into a set of "feature
+ extraction regions" and extract a bag-of-words vector from each
+ before finally concatenating them to form the final feature vector
+ representing the entire object window. The set of feature extraction
+ regions can be configured by the user by calling
+ set_num_spatial_pyramid_levels(). To be a little more precise, the
+ feature vector for a candidate object window is defined as follows:
+ - Let N denote the number of feature extraction zones.
+ - Let M denote the dimensionality of the vectors output by
+ Feature_extractor_type objects.
+ - Let F(i) == the M dimensional vector which is the sum of all
+ vectors given by our Feature_extractor_type object inside the
+ i-th feature extraction zone. So this is notionally a
+ bag-of-words vector from the i-th zone.
+ - Then the feature vector for an object window is an M*N
+ dimensional vector [F(1) F(2) F(3) ... F(N)] (i.e. it is a
+ concatenation of the N vectors). This feature vector can be
+ thought of as a collection of N bags-of-words, each bag coming
+ from a spatial location determined by one of the feature
+ extraction zones.
+
+ 3. A weight vector and a threshold value. The dot product between the
+ weight vector and the feature vector for a candidate object location
+ gives the score of the location. If this score is greater than the
+ threshold value then the candidate object location is output as a
+ detection.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be
+ protected by a mutex lock except for the case where you are copying the
+ configuration (via copy_configuration()) of a scan_image_boxes object to
+ many other threads. In this case, it is safe to copy the configuration of
+ a shared object so long as no other operations are performed on it.
+ !*/
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+
+ typedef Feature_extractor_type feature_extractor_type;
+ typedef Box_generator box_generator;
+
+ scan_image_boxes (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type must be a type with the following properties:
+ - image_type objects can be loaded into Feature_extractor_type
+ objects via Feature_extractor_type::load().
+ - image_type objects can be passed to the first argument of
+ Box_generator::operator()
+ ensures
+ - #is_loaded_with_image() == true
+ - This object is ready to run a classifier over img to detect object
+ locations. Call detect() to do this.
+ !*/
+
+ bool is_loaded_with_image (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with an image to process and
+ false otherwise.
+ !*/
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the feature_extractor_type object used
+ internally for local feature extraction.
+ !*/
+
+ void copy_configuration(
+ const feature_extractor_type& fe
+ );
+ /*!
+ ensures
+ - This function performs the equivalent of
+ get_feature_extractor().copy_configuration(fe) (i.e. this function allows
+ you to configure the parameters of the underlying feature extractor used
+ by a scan_image_boxes object)
+ !*/
+
+ void copy_configuration(
+ const box_generator& bg
+ );
+ /*!
+ ensures
+ - #get_box_generator() == bg
+ (i.e. this function allows you to configure the parameters of the
+ underlying box generator used by a scan_image_boxes object)
+ !*/
+
+ const box_generator& get_box_generator (
+ ) const;
+ /*!
+ ensures
+ - returns the box_generator used by this object to generate candidate
+ object locations.
+ !*/
+
+ void copy_configuration (
+ const scan_image_boxes& item
+ );
+ /*!
+ ensures
+ - Copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two scan_image_boxes
+ objects S1 and S2, the following sequence of instructions should always
+ result in both of them having the exact same state:
+ S2.copy_configuration(S1);
+ S1.load(img);
+ S2.load(img);
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the number of dimensions in the feature vector for a candidate
+ object location. This value is the dimensionality of the underlying
+ feature vectors produced by Feature_extractor_type times the number of
+ feature extraction regions used. Note that the number of feature
+ extraction regions used is a function of
+ get_num_spatial_pyramid_levels().
+ !*/
+
+ unsigned long get_num_spatial_pyramid_levels (
+ ) const;
+ /*!
+ ensures
+ - returns the number of layers in the spatial pyramid. For example, if
+ this function returns 1 then it means we use a simple bag-of-words
+ representation over the whole object window. If it returns 2 then it
+ means the feature representation is the concatenation of 5 bag-of-words
+ vectors, one from the entire object window and 4 others from 4 different
+ parts of the object window. If it returns 3 then there are 1+4+16
+ bag-of-words vectors concatenated together in the feature representation,
+ and so on.
+ !*/
+
+ void set_num_spatial_pyramid_levels (
+ unsigned long levels
+ );
+ /*!
+ requires
+ - levels > 0
+ ensures
+ - #get_num_spatial_pyramid_levels() == levels
+ !*/
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ - is_loaded_with_image() == true
+ ensures
+ - Scans over all the candidate object locations as discussed in the WHAT
+ THIS OBJECT REPRESENTS section and stores all detections into #dets.
+ - for all valid i:
+ - #dets[i].second == The candidate object location which produced this
+ detection. This rectangle gives the location of the detection.
+ - #dets[i].first == The score for this detection. This value is equal
+ to dot(w, feature vector for this candidate object location).
+ - #dets[i].first >= thresh
+ - #dets will be sorted in descending order.
+ (i.e. #dets[i].first >= #dets[j].first for all i, and j>i)
+ - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only
+ the first get_num_dimensions() are used.
+ - Note that no form of non-max suppression is performed. If a locations
+ has a score >= thresh then it is reported in #dets.
+ !*/
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+ /*!
+ requires
+ - obj.num_parts() == 0
+ - is_loaded_with_image() == true
+ - psi.size() >= get_num_dimensions()
+ (i.e. psi must have preallocated its memory before this function is called)
+ ensures
+ - This function allows you to determine the feature vector used for a
+ candidate object location output from detect(). Note that this vector is
+ added to psi. Note also that you must use get_full_object_detection() to
+ convert a rectangle from detect() into the needed full_object_detection.
+ - The dimensionality of the vector added to psi is get_num_dimensions(). This
+ means that elements of psi after psi(get_num_dimensions()-1) are not modified.
+ - Since scan_image_boxes only searches a limited set of object locations,
+ not all possible rectangles can be output by detect(). So in the case
+ where obj.get_rect() could not arise from a call to detect(), this
+ function will map obj.get_rect() to the nearest possible rectangle and
+ then add the feature vector for the mapped rectangle into #psi.
+ - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect()
+ gets mapped to for feature extraction.
+ !*/
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+ /*!
+ ensures
+ - returns full_object_detection(rect)
+ (This function is here only for compatibility with the scan_image_pyramid
+ object)
+ !*/
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - is_loaded_with_image() == true
+ ensures
+ - Since scan_image_boxes only searches a limited set of object locations,
+ not all possible rectangles can be represented. Therefore, this function
+ allows you to supply a rectangle and obtain the nearest possible
+ candidate object location rectangle.
+ !*/
+
+ unsigned long get_num_detection_templates (
+ ) const { return 1; }
+ /*!
+ ensures
+ - returns 1. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Notionally, its return value indicates
+ that a scan_image_boxes object is always ready to detect objects once
+ an image has been loaded.
+ !*/
+
+ unsigned long get_num_movable_components_per_detection_template (
+ ) const { return 0; }
+ /*!
+ ensures
+ - returns 0. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Its return value means that this object
+ does not support using movable part models.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void serialize (
+ const scan_image_boxes<Feature_extractor_type,Box_generator>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename Feature_extractor_type,
+ typename Box_generator
+ >
+ void deserialize (
+ scan_image_boxes<Feature_extractor_type,Box_generator>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_custom.h b/ml/dlib/dlib/image_processing/scan_image_custom.h
new file mode 100644
index 000000000..29b969fca
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_custom.h
@@ -0,0 +1,401 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_IMAGE_CuSTOM_Hh_
+#define DLIB_SCAN_IMAGE_CuSTOM_Hh_
+
+#include "scan_image_custom_abstract.h"
+#include "../matrix.h"
+#include "../geometry.h"
+#include <vector>
+#include "../image_processing/full_object_detection.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ class scan_image_custom : noncopyable
+ {
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_image_custom (
+ );
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+
+ inline bool is_loaded_with_image (
+ ) const;
+
+ inline void copy_configuration(
+ const feature_extractor_type& fe
+ );
+
+ const Feature_extractor_type& get_feature_extractor (
+ ) const { return feats; }
+
+ inline void copy_configuration (
+ const scan_image_custom& item
+ );
+
+ inline long get_num_dimensions (
+ ) const;
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+
+ inline unsigned long get_num_detection_templates (
+ ) const { return 1; }
+
+ inline unsigned long get_num_movable_components_per_detection_template (
+ ) const { return 0; }
+
+ template <typename T>
+ friend void serialize (
+ const scan_image_custom<T>& item,
+ std::ostream& out
+ );
+
+ template <typename T>
+ friend void deserialize (
+ scan_image_custom<T>& item,
+ std::istream& in
+ );
+
+ private:
+ static bool compare_pair_rect (
+ const std::pair<double, rectangle>& a,
+ const std::pair<double, rectangle>& b
+ )
+ {
+ return a.first < b.first;
+ }
+
+
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(
+ has_compute_object_score,
+ double,
+ compute_object_score,
+ ( const matrix<double,0,1>& w, const rectangle& obj) const
+ );
+
+ template <typename fe_type>
+ typename enable_if<has_compute_object_score<fe_type> >::type compute_all_rect_scores (
+ const fe_type& feats,
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ for (unsigned long i = 0; i < search_rects.size(); ++i)
+ {
+ const double score = feats.compute_object_score(w, search_rects[i]);
+ if (score >= thresh)
+ {
+ dets.push_back(std::make_pair(score, search_rects[i]));
+ }
+ }
+ }
+
+ template <typename fe_type>
+ typename disable_if<has_compute_object_score<fe_type> >::type compute_all_rect_scores (
+ const fe_type& feats,
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ matrix<double,0,1> psi(w.size());
+ psi = 0;
+ double prev_dot = 0;
+ for (unsigned long i = 0; i < search_rects.size(); ++i)
+ {
+ // Reset these back to zero every so often to avoid the accumulation of
+ // rounding error. Note that the only reason we do this loop in this
+ // complex way is to avoid needing to zero the psi vector every iteration.
+ if ((i%500) == 499)
+ {
+ psi = 0;
+ prev_dot = 0;
+ }
+
+ feats.get_feature_vector(search_rects[i], psi);
+ const double cur_dot = dot(psi, w);
+ const double score = cur_dot - prev_dot;
+ if (score >= thresh)
+ {
+ dets.push_back(std::make_pair(score, search_rects[i]));
+ }
+ prev_dot = cur_dot;
+ }
+ }
+
+
+ feature_extractor_type feats;
+ std::vector<rectangle> search_rects;
+ bool loaded_with_image;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const scan_image_custom<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.feats, out);
+ serialize(item.search_rects, out);
+ serialize(item.loaded_with_image, out);
+ serialize(item.get_num_dimensions(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void deserialize (
+ scan_image_custom<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unsupported version found when deserializing a scan_image_custom object.");
+
+ deserialize(item.feats, in);
+ deserialize(item.search_rects, in);
+ deserialize(item.loaded_with_image, in);
+
+ // When developing some feature extractor, it's easy to accidentally change its
+ // number of dimensions and then try to deserialize data from an older version of
+ // your extractor into the current code. This check is here to catch that kind of
+ // user error.
+ long dims;
+ deserialize(dims, in);
+ if (item.get_num_dimensions() != dims)
+ throw serialization_error("Number of dimensions in serialized scan_image_custom doesn't match the expected number.");
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// scan_image_custom member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ scan_image_custom<Feature_extractor_type>::
+ scan_image_custom (
+ ) :
+ loaded_with_image(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ template <
+ typename image_type
+ >
+ void scan_image_custom<Feature_extractor_type>::
+ load (
+ const image_type& img
+ )
+ {
+ feats.load(img, search_rects);
+ loaded_with_image = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ bool scan_image_custom<Feature_extractor_type>::
+ is_loaded_with_image (
+ ) const
+ {
+ return loaded_with_image;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ void scan_image_custom<Feature_extractor_type>::
+ copy_configuration(
+ const feature_extractor_type& fe
+ )
+ {
+ feats.copy_configuration(fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ void scan_image_custom<Feature_extractor_type>::
+ copy_configuration (
+ const scan_image_custom& item
+ )
+ {
+ feats.copy_configuration(item.feats);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ long scan_image_custom<Feature_extractor_type>::
+ get_num_dimensions (
+ ) const
+ {
+ return feats.get_num_dimensions();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ void scan_image_custom<Feature_extractor_type>::
+ detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ w.size() >= get_num_dimensions(),
+ "\t void scan_image_custom::detect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t w.size(): " << w.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ dets.clear();
+ compute_all_rect_scores(feats, w,dets,thresh);
+ std::sort(dets.rbegin(), dets.rend(), compare_pair_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ const rectangle scan_image_custom<Feature_extractor_type>::
+ get_best_matching_rect (
+ const rectangle& rect
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image(),
+ "\t const rectangle scan_image_custom::get_best_matching_rect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t this: " << this
+ );
+
+
+ double best_score = -1;
+ rectangle best_rect;
+ for (unsigned long i = 0; i < search_rects.size(); ++i)
+ {
+ const double score = (rect.intersect(search_rects[i])).area()/(double)(rect+search_rects[i]).area();
+ if (score > best_score)
+ {
+ best_score = score;
+ best_rect = search_rects[i];
+ }
+ }
+ return best_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ full_object_detection scan_image_custom<Feature_extractor_type>::
+ get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& /*w*/
+ ) const
+ {
+ return full_object_detection(rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ void scan_image_custom<Feature_extractor_type>::
+ get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_loaded_with_image() &&
+ psi.size() >= get_num_dimensions() &&
+ obj.num_parts() == 0,
+ "\t void scan_image_custom::get_feature_vector()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t psi.size(): " << psi.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t obj.num_parts(): " << obj.num_parts()
+ << "\n\t this: " << this
+ );
+
+
+ feats.get_feature_vector(get_best_matching_rect(obj.get_rect()), psi);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMAGE_CuSTOM_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h b/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h
new file mode 100644
index 000000000..ca3ba402a
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h
@@ -0,0 +1,390 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../image_processing/full_object_detection_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class example_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a feature extractor must implement if it
+ is to be used with the scan_image_custom object defined at the bottom of
+ this file.
+
+ In this case, the purpose of a feature extractor is to associated a
+ complete feature vector with each rectangle in an image. In particular,
+ each rectangle is scored by taking the dot product between this feature
+ vector and a weight vector. If this score is greater than a threshold then
+ the rectangle is output as a detection.
+ !*/
+
+ public:
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& image,
+ std::vector<rectangle>& candidate_objects
+ );
+ /*!
+ ensures
+ - Loads the given image into this feature extractor. This means that
+ subsequent calls to get_feature_vector() will return the feature vector
+ corresponding to locations in the image given to load().
+ - #candidate_objects == a set of bounding boxes in the given image that
+ might contain objects of interest. These are the locations that will be
+ checked for the presents of objects when this feature extractor is used
+ with the scan_image_custom object.
+
+ !*/
+
+ void copy_configuration (
+ const feature_extractor& item
+ );
+ /*!
+ ensures
+ - Copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two
+ feature extractor objects S1 and S2, the following sequence of
+ instructions should always result in both of them having the exact same
+ state:
+ S2.copy_configuration(S1);
+ S1.load(img, temp);
+ S2.load(img, temp);
+ !*/
+
+ unsigned long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the feature vectors output by this object.
+ !*/
+
+ void get_feature_vector (
+ const rectangle& obj,
+ matrix<double,0,1>& psi
+ ) const;
+ /*!
+ requires
+ - psi.size() >= get_num_dimensions()
+ (i.e. psi must have preallocated its memory before this function is called)
+ ensures
+ - This function computes the feature vector associated with the given rectangle
+ in obj. This rectangle is interpreted as a bounding box within the last image
+ given to this->load() and a feature vector describing that bounding box is
+ output into psi.
+ - The feature vector is added into psi. That is, it does not overwrite the
+ previous contents of psi, but instead, it adds the vector to psi.
+ - The dimensionality of the vector added to psi is get_num_dimensions(). This
+ means that elements of psi after psi(get_num_dimensions()-1) are not modified.
+ - #psi.size() == psi.size()
+ (i.e. this function does not change the size of the psi vector)
+ !*/
+
+ double compute_object_score (
+ const matrix<double,0,1>& w,
+ const rectangle& obj
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ ensures
+ - This function returns the dot product between the feature vector for
+ object box obj and the given w vector. That is, this function computes
+ the same number as the following code snippet:
+ matrix<double,0,1> psi(w.size());
+ psi = 0;
+ get_feature_vector(obj, psi);
+ return dot(psi, w);
+ The point of the compute_object_score() routine is to compute this dot
+ product in a much more efficient way than directly calling
+ get_feature_vector() and dot(). Therefore, compute_object_score() is an
+ optional function. If you can't think of a faster way to compute these
+ scores then do not implement compute_object_score() and the
+ scan_image_custom object will simply compute these scores for you.
+ However, it is often the case that there is something clever you can do
+ to make this computation faster. If that is the case, then you can
+ provide an implementation of this function with your feature extractor
+ and then scan_image_custom will use it instead of using the default
+ calculation method shown in the above code snippet.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const feature_extractor& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize(
+ feature_extractor& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Feature_extractor_type
+ >
+ class scan_image_custom : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON Feature_extractor_type
+ - must be an object with an interface compatible with the
+ example_feature_extractor defined at the top of this file.
+
+ INITIAL VALUE
+ - is_loaded_with_image() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for running a classifier over an image with the goal
+ of localizing each object present. The localization is in the form of the
+ bounding box around each object of interest.
+
+ Unlike the scan_image_pyramid and scan_image_boxes objects, this image
+ scanner delegates all the work of constructing the object feature vector to
+ its Feature_extractor_type template argument. That is, scan_image_custom
+ simply asks the supplied feature extractor what boxes in the image we
+ should investigate and then asks the feature extractor for the complete
+ feature vector for each box. That is, scan_image_custom does not apply any
+ kind of pyramiding or other higher level processing to the features coming
+ out of the feature extractor. That means that when you use
+ scan_image_custom it is completely up to you to define the feature vector
+ used with each image box.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be
+ protected by a mutex lock except for the case where you are copying the
+ configuration (via copy_configuration()) of a scan_image_custom object to
+ many other threads. In this case, it is safe to copy the configuration of
+ a shared object so long as no other operations are performed on it.
+ !*/
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_image_custom (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type must be a type with the following properties:
+ - image_type objects can be loaded into Feature_extractor_type
+ objects via Feature_extractor_type::load().
+ ensures
+ - #is_loaded_with_image() == true
+ - Calls get_feature_extractor().load() on the given image. That is, we
+ will have loaded the image into the feature extractor in this
+ scan_image_custom object. We will also have stored the candidate
+ object locations generated by the feature extractor and will scan
+ over them when this->detect() is called.
+ - This object is ready to run a classifier over img to detect object
+ locations. Call detect() to do this.
+ !*/
+
+ bool is_loaded_with_image (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with an image to process and
+ false otherwise.
+ !*/
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the feature_extractor_type object used
+ internally for local feature extraction.
+ !*/
+
+ void copy_configuration(
+ const feature_extractor_type& fe
+ );
+ /*!
+ ensures
+ - This function performs the equivalent of
+ get_feature_extractor().copy_configuration(fe) (i.e. this function allows
+ you to configure the parameters of the underlying feature extractor used
+ by a scan_image_custom object)
+ !*/
+
+ void copy_configuration (
+ const scan_image_custom& item
+ );
+ /*!
+ ensures
+ - Copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two
+ scan_image_custom objects S1 and S2, the following sequence of
+ instructions should always result in both of them having the exact same
+ state:
+ S2.copy_configuration(S1);
+ S1.load(img);
+ S2.load(img);
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ ensures
+ - returns the number of dimensions in the feature vector for a candidate
+ object location. That is, this function returns get_feature_extractor().get_num_dimensions().
+ !*/
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ - is_loaded_with_image() == true
+ ensures
+ - Scans over all the candidate object locations produced by the feature
+ extractor during image loading and stores all detections into #dets.
+ - for all valid i:
+ - #dets[i].second == The candidate object location which produced this
+ detection. This rectangle gives the location of the detection.
+ - #dets[i].first == The score for this detection. This value is equal
+ to dot(w, feature vector for this candidate object location).
+ - #dets[i].first >= thresh
+ - #dets will be sorted in descending order.
+ (i.e. #dets[i].first >= #dets[j].first for all i, and j>i)
+ - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only
+ the first get_num_dimensions() are used.
+ - Note that no form of non-max suppression is performed. If a locations
+ has a score >= thresh then it is reported in #dets.
+ !*/
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+ /*!
+ requires
+ - obj.num_parts() == 0
+ - is_loaded_with_image() == true
+ - psi.size() >= get_num_dimensions()
+ (i.e. psi must have preallocated its memory before this function is called)
+ ensures
+ - This function allows you to determine the feature vector used for a
+ candidate object location output from detect(). Note that this vector is
+ added to psi. Note also that you must use get_full_object_detection() to
+ convert a rectangle from detect() into the needed full_object_detection.
+ - The dimensionality of the vector added to psi is get_num_dimensions(). This
+ means that elements of psi after psi(get_num_dimensions()-1) are not modified.
+ - Since scan_image_custom only searches a limited set of object locations,
+ not all possible rectangles can be output by detect(). So in the case
+ where obj.get_rect() could not arise from a call to detect(), this
+ function will map obj.get_rect() to the nearest possible rectangle and
+ then add the feature vector for the mapped rectangle into #psi.
+ - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect()
+ gets mapped to for feature extraction.
+ !*/
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+ /*!
+ ensures
+ - returns full_object_detection(rect)
+ (This function is here only for compatibility with the scan_image_pyramid
+ object)
+ !*/
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - is_loaded_with_image() == true
+ ensures
+ - Since scan_image_custom only searches a limited set of object locations,
+ not all possible rectangles can be represented. Therefore, this function
+ allows you to supply a rectangle and obtain the nearest possible
+ candidate object location rectangle.
+ !*/
+
+ unsigned long get_num_detection_templates (
+ ) const { return 1; }
+ /*!
+ ensures
+ - returns 1. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Notionally, its return value indicates
+ that a scan_image_custom object is always ready to detect objects once an
+ image has been loaded.
+ !*/
+
+ unsigned long get_num_movable_components_per_detection_template (
+ ) const { return 0; }
+ /*!
+ ensures
+ - returns 0. Note that this function is here only for compatibility with
+ the scan_image_pyramid object. Its return value means that this object
+ does not support using movable part models.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const scan_image_custom<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename T>
+ void deserialize (
+ scan_image_custom<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid.h b/ml/dlib/dlib/image_processing/scan_image_pyramid.h
new file mode 100644
index 000000000..455f1a649
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_pyramid.h
@@ -0,0 +1,1101 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_IMaGE_PYRAMID_Hh_
+#define DLIB_SCAN_IMaGE_PYRAMID_Hh_
+
+#include "scan_image_pyramid_abstract.h"
+#include "../matrix.h"
+#include "../geometry.h"
+#include "scan_image.h"
+#include "../array2d.h"
+#include <vector>
+#include "full_object_detection.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ class scan_image_pyramid : noncopyable
+ {
+
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+
+ typedef Pyramid_type pyramid_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_image_pyramid (
+ );
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+
+ inline bool is_loaded_with_image (
+ ) const;
+
+ inline void copy_configuration(
+ const feature_extractor_type& fe
+ );
+
+ inline void copy_configuration (
+ const scan_image_pyramid& item
+ );
+
+ const Feature_extractor_type& get_feature_extractor (
+ ) const { return feats_config; }
+
+ void add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions,
+ const std::vector<rectangle>& movable_feature_extraction_regions
+ );
+
+ void add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions
+ );
+
+ inline unsigned long get_num_detection_templates (
+ ) const;
+
+ inline unsigned long get_num_movable_components_per_detection_template (
+ ) const;
+
+ inline unsigned long get_num_stationary_components_per_detection_template (
+ ) const;
+
+ inline unsigned long get_num_components_per_detection_template (
+ ) const;
+
+ inline long get_num_dimensions (
+ ) const;
+
+ unsigned long get_max_pyramid_levels (
+ ) const;
+
+ void set_max_pyramid_levels (
+ unsigned long max_levels
+ );
+
+ inline unsigned long get_max_detections_per_template (
+ ) const;
+
+ void set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ );
+
+ inline unsigned long get_min_pyramid_layer_width (
+ ) const;
+
+ inline unsigned long get_min_pyramid_layer_height (
+ ) const;
+
+ void set_max_detections_per_template (
+ unsigned long max_dets
+ );
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+
+ template <typename T, typename U>
+ friend void serialize (
+ const scan_image_pyramid<T,U>& item,
+ std::ostream& out
+ );
+
+ template <typename T, typename U>
+ friend void deserialize (
+ scan_image_pyramid<T,U>& item,
+ std::istream& in
+ );
+
+ private:
+ static bool compare_pair_rect (
+ const std::pair<double, rectangle>& a,
+ const std::pair<double, rectangle>& b
+ )
+ {
+ return a.first < b.first;
+ }
+
+ struct detection_template
+ {
+ rectangle object_box; // always centered at (0,0)
+ std::vector<rectangle> rects; // template with respect to (0,0)
+ std::vector<rectangle> movable_rects;
+ };
+
+ friend void serialize(const detection_template& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.object_box, out);
+ serialize(item.rects, out);
+ serialize(item.movable_rects, out);
+ }
+ friend void deserialize(detection_template& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing a dlib::scan_image_pyramid::detection_template object.");
+
+ deserialize(item.object_box, in);
+ deserialize(item.rects, in);
+ deserialize(item.movable_rects, in);
+ }
+
+ void get_mapped_rect_and_metadata (
+ const unsigned long number_pyramid_levels,
+ rectangle rect,
+ rectangle& mapped_rect,
+ detection_template& best_template,
+ rectangle& object_box,
+ unsigned long& best_level,
+ unsigned long& detection_template_idx
+ ) const;
+
+ double get_match_score (
+ rectangle r1,
+ rectangle r2
+ ) const
+ {
+ // make the rectangles overlap as much as possible before computing the match score.
+ r1 = move_rect(r1, r2.tl_corner());
+ return (r1.intersect(r2).area())/(double)(r1 + r2).area();
+ }
+
+ void test_coordinate_transforms()
+ {
+ for (long x = -10; x <= 10; x += 10)
+ {
+ for (long y = -10; y <= 10; y += 10)
+ {
+ const rectangle rect = centered_rect(x,y,5,6);
+ rectangle a;
+
+ a = feats_config.image_to_feat_space(rect);
+ if (a.width() > 10000000 || a.height() > 10000000 )
+ {
+ DLIB_CASSERT(false, "The image_to_feat_space() routine is outputting rectangles of an implausibly "
+ << "\nlarge size. This means there is probably a bug in your feature extractor.");
+ }
+ a = feats_config.feat_to_image_space(rect);
+ if (a.width() > 10000000 || a.height() > 10000000 )
+ {
+ DLIB_CASSERT(false, "The feat_to_image_space() routine is outputting rectangles of an implausibly "
+ << "\nlarge size. This means there is probably a bug in your feature extractor.");
+ }
+ }
+ }
+
+ }
+
+ feature_extractor_type feats_config; // just here to hold configuration. use it to populate the feats elements.
+ array<feature_extractor_type> feats;
+ std::vector<detection_template> det_templates;
+ unsigned long max_dets_per_template;
+ unsigned long max_pyramid_levels;
+ unsigned long min_pyramid_layer_width;
+ unsigned long min_pyramid_layer_height;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void serialize (
+ const scan_image_pyramid<T,U>& item,
+ std::ostream& out
+ )
+ {
+ int version = 3;
+ serialize(version, out);
+ serialize(item.feats_config, out);
+ serialize(item.feats, out);
+ serialize(item.det_templates, out);
+ serialize(item.max_dets_per_template, out);
+ serialize(item.max_pyramid_levels, out);
+ serialize(item.min_pyramid_layer_width, out);
+ serialize(item.min_pyramid_layer_height, out);
+ serialize(item.get_num_dimensions(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void deserialize (
+ scan_image_pyramid<T,U>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 3)
+ throw serialization_error("Unsupported version found when deserializing a scan_image_pyramid object.");
+
+ deserialize(item.feats_config, in);
+ deserialize(item.feats, in);
+ deserialize(item.det_templates, in);
+ deserialize(item.max_dets_per_template, in);
+ deserialize(item.max_pyramid_levels, in);
+ deserialize(item.min_pyramid_layer_width, in);
+ deserialize(item.min_pyramid_layer_height, in);
+
+ // When developing some feature extractor, it's easy to accidentally change its
+ // number of dimensions and then try to deserialize data from an older version of
+ // your extractor into the current code. This check is here to catch that kind of
+ // user error.
+ long dims;
+ deserialize(dims, in);
+ if (item.get_num_dimensions() != dims)
+ throw serialization_error("Number of dimensions in serialized scan_image_pyramid doesn't match the expected number.");
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// scan_image_pyramid member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ scan_image_pyramid (
+ ) :
+ max_dets_per_template(10000),
+ max_pyramid_levels(1000),
+ min_pyramid_layer_width(20),
+ min_pyramid_layer_height(20)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ template <
+ typename image_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ load (
+ const image_type& img
+ )
+ {
+ unsigned long levels = 0;
+ rectangle rect = get_rect(img);
+
+ // figure out how many pyramid levels we should be using based on the image size
+ pyramid_type pyr;
+ do
+ {
+ rect = pyr.rect_down(rect);
+ ++levels;
+ } while (rect.width() >= min_pyramid_layer_width && rect.height() >= min_pyramid_layer_height &&
+ levels < max_pyramid_levels);
+
+ if (feats.max_size() < levels)
+ feats.set_max_size(levels);
+ feats.set_size(levels);
+
+ for (unsigned long i = 0; i < feats.size(); ++i)
+ feats[i].copy_configuration(feats_config);
+
+ // build our feature pyramid
+ feats[0].load(img);
+ if (feats.size() > 1)
+ {
+ image_type temp1, temp2;
+ pyr(img, temp1);
+ feats[1].load(temp1);
+ swap(temp1,temp2);
+
+ for (unsigned long i = 2; i < feats.size(); ++i)
+ {
+ pyr(temp2, temp1);
+ feats[i].load(temp1);
+ swap(temp1,temp2);
+ }
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_max_detections_per_template (
+ ) const
+ {
+ return max_dets_per_template;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ set_max_detections_per_template (
+ unsigned long max_dets
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_dets > 0 ,
+ "\t void scan_image_pyramid::set_max_detections_per_template()"
+ << "\n\t The max number of possible detections can't be zero. "
+ << "\n\t max_dets: " << max_dets
+ << "\n\t this: " << this
+ );
+
+ max_dets_per_template = max_dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ bool scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ is_loaded_with_image (
+ ) const
+ {
+ return feats.size() != 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ copy_configuration(
+ const feature_extractor_type& fe
+ )
+ {
+ test_coordinate_transforms();
+ feats_config.copy_configuration(fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ copy_configuration (
+ const scan_image_pyramid& item
+ )
+ {
+ feats_config.copy_configuration(item.feats_config);
+ det_templates = item.det_templates;
+ max_dets_per_template = item.max_dets_per_template;
+ max_pyramid_levels = item.max_pyramid_levels;
+ min_pyramid_layer_width = item.min_pyramid_layer_width;
+ min_pyramid_layer_height = item.min_pyramid_layer_height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions,
+ const std::vector<rectangle>& movable_feature_extraction_regions
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ // make sure requires clause is not broken
+ DLIB_ASSERT((get_num_detection_templates() == 0 ||
+ (get_num_stationary_components_per_detection_template() == stationary_feature_extraction_regions.size() &&
+ get_num_movable_components_per_detection_template() == movable_feature_extraction_regions.size())) &&
+ center(object_box) == point(0,0),
+ "\t void scan_image_pyramid::add_detection_template()"
+ << "\n\t The number of rects in this new detection template doesn't match "
+ << "\n\t the number in previous detection templates."
+ << "\n\t get_num_stationary_components_per_detection_template(): " << get_num_stationary_components_per_detection_template()
+ << "\n\t stationary_feature_extraction_regions.size(): " << stationary_feature_extraction_regions.size()
+ << "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template()
+ << "\n\t movable_feature_extraction_regions.size(): " << movable_feature_extraction_regions.size()
+ << "\n\t this: " << this
+ );
+
+ for (unsigned long i = 0; i < movable_feature_extraction_regions.size(); ++i)
+ {
+ DLIB_ASSERT(center(movable_feature_extraction_regions[i]) == point(0,0),
+ "Invalid inputs were given to this function."
+ << "\n\t center(movable_feature_extraction_regions["<<i<<"]): " << center(movable_feature_extraction_regions[i])
+ << "\n\t this: " << this
+ );
+ }
+#endif
+
+ detection_template temp;
+ temp.object_box = object_box;
+ temp.rects = stationary_feature_extraction_regions;
+ temp.movable_rects = movable_feature_extraction_regions;
+ det_templates.push_back(temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions
+ )
+ {
+ // an empty set of movable feature regions
+ const std::vector<rectangle> movable_feature_extraction_regions;
+ add_detection_template(object_box, stationary_feature_extraction_regions,
+ movable_feature_extraction_regions);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_num_detection_templates (
+ ) const
+ {
+ return det_templates.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_num_stationary_components_per_detection_template (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 ,
+ "\t unsigned long scan_image_pyramid::get_num_stationary_components_per_detection_template()"
+ << "\n\t You need to give some detection templates before calling this function. "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t this: " << this
+ );
+
+ return det_templates[0].rects.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_num_movable_components_per_detection_template (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 ,
+ "\t unsigned long scan_image_pyramid::get_num_movable_components_per_detection_template()"
+ << "\n\t You need to give some detection templates before calling this function. "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t this: " << this
+ );
+
+ return det_templates[0].movable_rects.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_num_components_per_detection_template (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 ,
+ "\t unsigned long scan_image_pyramid::get_num_components_per_detection_template()"
+ << "\n\t You need to give some detection templates before calling this function. "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t this: " << this
+ );
+
+ return get_num_movable_components_per_detection_template() +
+ get_num_stationary_components_per_detection_template();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_num_dimensions (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 ,
+ "\t long scan_image_pyramid::get_num_dimensions()"
+ << "\n\t You need to give some detection templates before calling this function. "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t this: " << this
+ );
+
+ return feats_config.get_num_dimensions()*get_num_components_per_detection_template() + get_num_detection_templates();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_max_pyramid_levels (
+ ) const
+ {
+ return max_pyramid_levels;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ set_max_pyramid_levels (
+ unsigned long max_levels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_levels > 0 ,
+ "\t void scan_image_pyramid::set_max_pyramid_levels()"
+ << "\n\t You can't have zero levels. "
+ << "\n\t max_levels: " << max_levels
+ << "\n\t this: " << this
+ );
+
+ max_pyramid_levels = max_levels;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 &&
+ is_loaded_with_image() &&
+ w.size() >= get_num_dimensions(),
+ "\t void scan_image_pyramid::detect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t w.size(): " << w.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ dets.clear();
+
+ array<array2d<double> > saliency_images;
+ saliency_images.set_max_size(get_num_components_per_detection_template());
+ saliency_images.set_size(get_num_components_per_detection_template());
+ std::vector<std::pair<unsigned int,rectangle> > stationary_region_rects(get_num_stationary_components_per_detection_template());
+ std::vector<std::pair<unsigned int,rectangle> > movable_region_rects(get_num_movable_components_per_detection_template());
+ pyramid_type pyr;
+ std::vector<std::pair<double, point> > point_dets;
+
+ // for all pyramid levels
+ for (unsigned long l = 0; l < feats.size(); ++l)
+ {
+ for (unsigned long i = 0; i < saliency_images.size(); ++i)
+ {
+ saliency_images[i].set_size(feats[l].nr(), feats[l].nc());
+ const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*i;
+
+ // build saliency images for pyramid level l
+ for (long r = 0; r < feats[l].nr(); ++r)
+ {
+ for (long c = 0; c < feats[l].nc(); ++c)
+ {
+ const typename feature_extractor_type::descriptor_type& descriptor = feats[l](r,c);
+
+ double sum = 0;
+ for (unsigned long k = 0; k < descriptor.size(); ++k)
+ {
+ sum += w(descriptor[k].first + offset)*descriptor[k].second;
+ }
+ saliency_images[i][r][c] = sum;
+ }
+ }
+ }
+
+ // now search the saliency images
+ for (unsigned long i = 0; i < det_templates.size(); ++i)
+ {
+ const point offset = -feats[l].image_to_feat_space(point(0,0));
+ for (unsigned long j = 0; j < stationary_region_rects.size(); ++j)
+ {
+ stationary_region_rects[j] = std::make_pair(j, translate_rect(feats[l].image_to_feat_space(det_templates[i].rects[j]),offset));
+ }
+ for (unsigned long j = 0; j < movable_region_rects.size(); ++j)
+ {
+ // Scale the size of the movable rectangle but make sure its center
+ // stays at point(0,0).
+ const rectangle temp = feats[l].image_to_feat_space(det_templates[i].movable_rects[j]);
+ movable_region_rects[j] = std::make_pair(j+stationary_region_rects.size(),
+ centered_rect(point(0,0),temp.width(), temp.height()));
+ }
+
+ // Scale the object box into the feature extraction image, but keeping it
+ // centered at point(0,0).
+ rectangle scaled_object_box = feats[l].image_to_feat_space(det_templates[i].object_box);
+ scaled_object_box = centered_rect(point(0,0),scaled_object_box.width(), scaled_object_box.height());
+
+ // Each detection template gets its own special threshold in addition to
+ // the global detection threshold. This allows us to model the fact that
+ // some detection templates might be more prone to false alarming or since
+ // their size is different naturally require a larger or smaller threshold
+ // (since they integrate over a larger or smaller region of the image).
+ const double template_specific_thresh = w(i);
+
+ scan_image_movable_parts(point_dets, saliency_images, scaled_object_box,
+ stationary_region_rects, movable_region_rects,
+ thresh+template_specific_thresh, max_dets_per_template);
+
+ // convert all the point detections into rectangles at the original image scale and coordinate system
+ for (unsigned long j = 0; j < point_dets.size(); ++j)
+ {
+ const double score = point_dets[j].first-template_specific_thresh;
+ point p = point_dets[j].second;
+ p = feats[l].feat_to_image_space(p);
+ rectangle rect = translate_rect(det_templates[i].object_box, p);
+ rect = pyr.rect_up(rect, l);
+
+ dets.push_back(std::make_pair(score, rect));
+ }
+ }
+ }
+
+ std::sort(dets.rbegin(), dets.rend(), compare_pair_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ const rectangle scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_best_matching_rect (
+ const rectangle& rect
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 ,
+ "\t const rectangle scan_image_pyramid::get_best_matching_rect()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t this: " << this
+ );
+
+ rectangle mapped_rect, object_box;
+ detection_template best_template;
+ unsigned long best_level, junk;
+ get_mapped_rect_and_metadata(max_pyramid_levels, rect, mapped_rect, best_template, object_box, best_level, junk);
+ return mapped_rect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_mapped_rect_and_metadata (
+ const unsigned long number_pyramid_levels,
+ rectangle rect,
+ rectangle& mapped_rect,
+ detection_template& best_template,
+ rectangle& object_box,
+ unsigned long& best_level,
+ unsigned long& detection_template_idx
+ ) const
+ {
+ pyramid_type pyr;
+ // Figure out the pyramid level which best matches rect against one of our
+ // detection template object boxes.
+ best_level = 0;
+ double best_match_score = -1;
+
+
+ // Find the best matching detection template for rect
+ for (unsigned long l = 0; l < number_pyramid_levels; ++l)
+ {
+ const rectangle temp = pyr.rect_down(rect,l);
+ if (temp.area() <= 1)
+ break;
+
+ // At this pyramid level, what matches best?
+ for (unsigned long t = 0; t < det_templates.size(); ++t)
+ {
+ const double match_score = get_match_score(det_templates[t].object_box, temp);
+ if (match_score > best_match_score)
+ {
+ best_match_score = match_score;
+ best_level = l;
+ best_template = det_templates[t];
+ detection_template_idx = t;
+ }
+ }
+ }
+
+
+ // Now we translate best_template into the right spot (it should be centered at the location
+ // determined by rect) and convert it into the feature image coordinate system.
+ rect = pyr.rect_down(rect,best_level);
+ const point offset = -feats_config.image_to_feat_space(point(0,0));
+ const point origin = feats_config.image_to_feat_space(center(rect)) + offset;
+ for (unsigned long k = 0; k < best_template.rects.size(); ++k)
+ {
+ rectangle temp = best_template.rects[k];
+ temp = feats_config.image_to_feat_space(temp);
+ temp = translate_rect(temp, origin);
+ best_template.rects[k] = temp;
+ }
+ for (unsigned long k = 0; k < best_template.movable_rects.size(); ++k)
+ {
+ rectangle temp = best_template.movable_rects[k];
+ temp = feats_config.image_to_feat_space(temp);
+ temp = centered_rect(point(0,0), temp.width(), temp.height());
+ best_template.movable_rects[k] = temp;
+ }
+
+ const rectangle scaled_object_box = feats_config.image_to_feat_space(best_template.object_box);
+ object_box = centered_rect(origin-offset, scaled_object_box.width(), scaled_object_box.height());
+
+ // The input rectangle was mapped to one of the detection templates. Reverse the process
+ // to figure out what the mapped rectangle is in the original input space.
+ mapped_rect = translate_rect(best_template.object_box, feats_config.feat_to_image_space(origin-offset));
+ mapped_rect = pyr.rect_up(mapped_rect, best_level);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ full_object_detection scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const
+ {
+ // fill in movable part positions.
+
+ rectangle mapped_rect;
+ detection_template best_template;
+ unsigned long best_level, junk;
+ rectangle object_box;
+ get_mapped_rect_and_metadata(feats.size(), rect, mapped_rect, best_template, object_box, best_level, junk);
+
+ Pyramid_type pyr;
+
+ array2d<double> saliency_image, sum_img;
+
+ double total_temp_score = 0;
+ // convert into feature space.
+ object_box = object_box.intersect(get_rect(feats[best_level]));
+
+ std::vector<point> movable_parts;
+ movable_parts.reserve(get_num_movable_components_per_detection_template());
+ for (unsigned long i = 0; i < get_num_movable_components_per_detection_template(); ++i)
+ {
+ // make the saliency_image for the ith movable part.
+
+ const rectangle part_rect = best_template.movable_rects[i];
+ const rectangle area = grow_rect(object_box,
+ part_rect.width()/2,
+ part_rect.height()/2).intersect(get_rect(feats[best_level]));
+
+ saliency_image.set_size(area.height(), area.width());
+ const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*(i+get_num_stationary_components_per_detection_template());
+
+ // build saliency image for pyramid level best_level
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ const typename feature_extractor_type::descriptor_type& descriptor = feats[best_level](r,c);
+
+ double sum = 0;
+ for (unsigned long k = 0; k < descriptor.size(); ++k)
+ {
+ sum += w(descriptor[k].first + offset)*descriptor[k].second;
+ }
+ saliency_image[r-area.top()][c-area.left()] = sum;
+ }
+ }
+
+ sum_img.set_size(saliency_image.nr(), saliency_image.nc());
+ sum_filter_assign(saliency_image, sum_img, part_rect);
+ // Figure out where the maximizer is in sum_img. Note that we
+ // only look in the part of sum_img that corresponds to a location inside
+ // object_box.
+ rectangle valid_area = get_rect(sum_img);
+ valid_area.left() += object_box.left() - area.left();
+ valid_area.top() += object_box.top() - area.top();
+ valid_area.right() += object_box.right() - area.right();
+ valid_area.bottom() += object_box.bottom() - area.bottom();
+ double max_val = 0;
+ point max_loc;
+ for (long r = valid_area.top(); r <= valid_area.bottom(); ++r)
+ {
+ for (long c = valid_area.left(); c <= valid_area.right(); ++c)
+ {
+ if (sum_img[r][c] > max_val)
+ {
+ //if (object_box.contains(point(c,r) + area.tl_corner()))
+ {
+ max_loc = point(c,r);
+ max_val = sum_img[r][c];
+ }
+ }
+ }
+ }
+
+ if (max_val <= 0)
+ {
+ max_loc = OBJECT_PART_NOT_PRESENT;
+ }
+ else
+ {
+ total_temp_score += max_val;
+ // convert max_loc back into feature image space from our cropped image.
+ max_loc += area.tl_corner();
+
+ // now convert from feature space to image space.
+ max_loc = feats[best_level].feat_to_image_space(max_loc);
+ max_loc = pyr.point_up(max_loc, best_level);
+ max_loc = nearest_point(rect, max_loc);
+ }
+
+ movable_parts.push_back(max_loc);
+ }
+
+ return full_object_detection(rect, movable_parts);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_detection_templates() > 0 &&
+ is_loaded_with_image() &&
+ psi.size() >= get_num_dimensions() &&
+ obj.num_parts() == get_num_movable_components_per_detection_template(),
+ "\t void scan_image_pyramid::get_feature_vector()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
+ << "\n\t is_loaded_with_image(): " << is_loaded_with_image()
+ << "\n\t psi.size(): " << psi.size()
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template()
+ << "\n\t obj.num_parts(): " << obj.num_parts()
+ << "\n\t this: " << this
+ );
+ DLIB_ASSERT(all_parts_in_rect(obj),
+ "\t void scan_image_pyramid::get_feature_vector()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t obj.get_rect(): " << obj.get_rect()
+ << "\n\t this: " << this
+ );
+
+
+
+ rectangle mapped_rect;
+ detection_template best_template;
+ unsigned long best_level, detection_template_idx;
+ rectangle object_box;
+ get_mapped_rect_and_metadata(feats.size(), obj.get_rect(), mapped_rect, best_template, object_box, best_level, detection_template_idx);
+
+ psi(detection_template_idx) -= 1;
+
+ Pyramid_type pyr;
+
+ // put the movable rects at the places indicated by obj.
+ std::vector<rectangle> rects = best_template.rects;
+ for (unsigned long i = 0; i < obj.num_parts(); ++i)
+ {
+ if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
+ {
+ // map from the original image to scaled feature space.
+ point loc = feats[best_level].image_to_feat_space(pyr.point_down(obj.part(i), best_level));
+ // Make sure the movable part always stays within the object_box.
+ // Otherwise it would be at a place that the detect() function can never
+ // look.
+ loc = nearest_point(object_box, loc);
+ rects.push_back(translate_rect(best_template.movable_rects[i], loc));
+ }
+ else
+ {
+ // add an empty rectangle since this part wasn't observed.
+ rects.push_back(rectangle());
+ }
+ }
+
+ // pull features out of all the boxes in rects.
+ for (unsigned long j = 0; j < rects.size(); ++j)
+ {
+ const rectangle rect = rects[j].intersect(get_rect(feats[best_level]));
+ const unsigned long template_region_id = j;
+ const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*template_region_id;
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ const typename feature_extractor_type::descriptor_type& descriptor = feats[best_level](r,c);
+ for (unsigned long k = 0; k < descriptor.size(); ++k)
+ {
+ psi(descriptor[k].first + offset) += descriptor[k].second;
+ }
+ }
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(width > 0 && height > 0 ,
+ "\t void scan_image_pyramid::set_min_pyramid_layer_size()"
+ << "\n\t These sizes can't be zero. "
+ << "\n\t width: " << width
+ << "\n\t height: " << height
+ << "\n\t this: " << this
+ );
+
+ min_pyramid_layer_width = width;
+ min_pyramid_layer_height = height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_min_pyramid_layer_width (
+ ) const
+ {
+ return min_pyramid_layer_width;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ unsigned long scan_image_pyramid<Pyramid_type,Feature_extractor_type>::
+ get_min_pyramid_layer_height (
+ ) const
+ {
+ return min_pyramid_layer_height;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMaGE_PYRAMID_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h
new file mode 100644
index 000000000..e985a3f32
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h
@@ -0,0 +1,495 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../image_processing.h"
+#include "../array2d.h"
+#include <vector>
+#include "full_object_detection_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ class scan_image_pyramid : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON Pyramid_type
+ - must be one of the pyramid_down objects defined in
+ dlib/image_transforms/image_pyramid_abstract.h or an object with
+ a compatible interface
+
+ REQUIREMENTS ON Feature_extractor_type
+ - must be an object with an interface compatible with the hashed_feature_image
+ object defined in dlib/image_keypoint/hashed_feature_image_abstract.h or
+ with the nearest_neighbor_feature_image object defined in
+ dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h
+
+ INITIAL VALUE
+ - get_num_detection_templates() == 0
+ - is_loaded_with_image() == false
+ - get_max_detections_per_template() == 10000
+ - get_max_pyramid_levels() == 1000
+ - get_min_pyramid_layer_width() == 20
+ - get_min_pyramid_layer_height() == 20
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for running a sliding window classifier over
+ an image pyramid. This object can also be understood as a general
+ tool for implementing the spatial pyramid models described in the paper:
+ Beyond Bags of Features: Spatial Pyramid Matching for Recognizing
+ Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid,
+ and Jean Ponce
+ It also includes the ability to represent movable part models.
+
+
+
+
+ The sliding window classifiers used by this object have three parts:
+ 1. The underlying feature extraction provided by Feature_extractor_type
+ objects, which associate a vector with each location in an image.
+
+ 2. A detection template. This is a rectangle which defines the shape of a
+ sliding window (i.e. the object_box), as well as a set of rectangular feature
+ extraction regions inside it. This set of regions defines the spatial
+ structure of the overall feature extraction within a sliding window. In
+ particular, each location of a sliding window has a feature vector
+ associated with it. This feature vector is defined as follows:
+ - Let N denote the number of feature extraction zones.
+ - Let M denote the dimensionality of the vectors output by Feature_extractor_type
+ objects.
+ - Let F(i) == the M dimensional vector which is the sum of all vectors
+ given by our Feature_extractor_type object inside the i-th feature extraction
+ zone.
+ - Then the feature vector for a sliding window is an M*N dimensional vector
+ [F(1) F(2) F(3) ... F(N)] (i.e. it is a concatenation of the N vectors).
+ This feature vector can be thought of as a collection of N "bags of features",
+ each bag coming from a spatial location determined by one of the rectangular
+ feature extraction zones.
+
+ 3. A weight vector and a threshold value. The dot product between the weight
+ vector and the feature vector for a sliding window location gives the score
+ of the window. If this score is greater than the threshold value then the
+ window location is output as a detection.
+
+ Finally, the sliding window classifiers described above are applied to every level of
+ an image pyramid. Moreover, some of the feature extraction zones are allowed to move
+ freely within the object box. This means that when we are sliding the classifier over
+ an image, some feature extraction zones are stationary (i.e. always in the same place
+ relative to the object box) while others are allowed to move anywhere within the object
+ box. In particular, the movable regions are placed at the locations that maximize the
+ score of the classifier. Note further that each of the movable feature extraction
+ zones must pass a threshold test for it to be included. That is, if the score that a
+ movable zone would contribute to the overall score for a sliding window location is not
+ positive then that zone is not included in the feature vector (i.e. its part of the
+ feature vector is set to zero. This way the length of the feature vector stays
+ constant). This movable region construction allows us to represent objects with parts
+ that move around relative to the object box. For example, a human has hands but they
+ aren't always in the same place relative to a person's bounding box.
+
+ THREAD SAFETY
+ Concurrent access to an instance of this object is not safe and should be protected
+ by a mutex lock except for the case where you are copying the configuration
+ (via copy_configuration()) of a scan_image_pyramid object to many other threads.
+ In this case, it is safe to copy the configuration of a shared object so long
+ as no other operations are performed on it.
+ !*/
+ public:
+
+ typedef matrix<double,0,1> feature_vector_type;
+
+ typedef Pyramid_type pyramid_type;
+ typedef Feature_extractor_type feature_extractor_type;
+
+ scan_image_pyramid (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename image_type
+ >
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type must be a type with the following properties:
+ - image_type is default constructable.
+ - image_type is swappable by the global swap() function.
+ - image_type logically represents some kind of image and therefore its
+ number of rows and columns can be queried via num_rows(img) and
+ num_columns(img) respectively.
+ - image_type objects can be loaded into Feature_extractor_type
+ objects via Feature_extractor_type::load().
+ - image_type objects can be used with Pyramid_type. That is,
+ if pyr is an object of type Pyramid_type while img1 and img2
+ are objects of image_type, then pyr(img1,img2) should be
+ a valid expression which downsamples img1 into img2.
+ ensures
+ - #is_loaded_with_image() == true
+ - This object is ready to run sliding window classifiers over img. Call
+ detect() to do this.
+ !*/
+
+ bool is_loaded_with_image (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with an image to process
+ and false otherwise.
+ !*/
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the feature_extractor_type object used
+ internally for local feature extraction.
+ !*/
+
+ void copy_configuration(
+ const feature_extractor_type& fe
+ );
+ /*!
+ ensures
+ - This function performs the equivalent of
+ get_feature_extractor().copy_configuration(fe) (i.e. this function allows
+ you to configure the parameters of the underlying feature extractor used
+ by a scan_image_pyramid object)
+ !*/
+
+ void copy_configuration (
+ const scan_image_pyramid& item
+ );
+ /*!
+ ensures
+ - copies all the state information of item into *this, except for state
+ information populated by load(). More precisely, given two scan_image_pyramid
+ objects S1 and S2, the following sequence of instructions should always
+ result in both of them having the exact same state.
+ S2.copy_configuration(S1);
+ S1.load(img);
+ S2.load(img);
+ !*/
+
+ void add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions,
+ const std::vector<rectangle>& movable_feature_extraction_regions
+ );
+ /*!
+ requires
+ - center(object_box) == point(0,0)
+ - for all valid i:
+ - center(movable_feature_extraction_regions[i]) == point(0,0)
+ - if (get_num_detection_templates() > 0) then
+ - get_num_stationary_components_per_detection_template() == stationary_feature_extraction_regions.size()
+ - get_num_movable_components_per_detection_template() == movable_feature_extraction_regions.size()
+ (i.e. if you already have detection templates in this object, then
+ any new detection template must declare a consistent number of
+ feature extraction regions)
+ ensures
+ - Adds another detection template to this object. In particular, object_box
+ defines the size and shape of a sliding window while stationary_feature_extraction_regions
+ and movable_feature_extraction_regions defines the locations for feature extraction as
+ discussed in the WHAT THIS OBJECT REPRESENTS section above. Note also that the locations of
+ the stationary feature extraction regions are relative to the object_box.
+ - #get_num_detection_templates() == get_num_detection_templates() + 1
+ - The order of rectangles in stationary_feature_extraction_regions and
+ movable_feature_extraction_regions matters. Recall that each rectangle
+ gets its own set of features. So given two different templates, their
+ i-th rectangles will both share the same part of the weight vector (i.e. the w
+ supplied to detect()). So there should be some reasonable correspondence
+ between the rectangle ordering in different detection templates. For,
+ example, different detection templates should place corresponding feature
+ extraction regions in roughly the same part of the object_box.
+ - #get_num_stationary_components_per_detection_template() = stationary_feature_extraction_regions.size()
+ - #get_num_movable_components_per_detection_template() = movable_feature_extraction_regions.size()
+ !*/
+
+ void add_detection_template (
+ const rectangle& object_box,
+ const std::vector<rectangle>& stationary_feature_extraction_regions
+ );
+ /*!
+ ensures
+ - calls add_detection_template(object_box, stationary_feature_extraction_regions, empty_list)
+ where empty_list is a vector of size 0. I.e. this function is just a convenience
+ routine for adding detection templates with no movable regions.
+ !*/
+
+ unsigned long get_num_detection_templates (
+ ) const;
+ /*!
+ ensures
+ - returns the number of detection templates in this object
+ !*/
+
+ unsigned long get_num_stationary_components_per_detection_template (
+ ) const;
+ /*!
+ requires
+ - get_num_detection_templates() > 0
+ ensures
+ - A detection template is a rectangle which defines the shape of a sliding
+ window (the object_box), as well as a set of rectangles which define
+ feature extraction zones. This function returns the number of stationary
+ feature extraction zones in the detection templates used by this object.
+ !*/
+
+ unsigned long get_num_movable_components_per_detection_template (
+ ) const;
+ /*!
+ requires
+ - get_num_detection_templates() > 0
+ ensures
+ - A detection template is a rectangle which defines the shape of a sliding
+ window (the object_box), as well as a set of rectangles which define
+ feature extraction zones. This function returns the number of movable
+ feature extraction zones in the detection templates used by this object.
+ !*/
+
+ unsigned long get_num_components_per_detection_template (
+ ) const;
+ /*!
+ requires
+ - get_num_detection_templates() > 0
+ ensures
+ - returns the total number of feature extraction zones in the detection
+ templates used by this object. That is, returns the following:
+ - get_num_movable_components_per_detection_template() +
+ get_num_stationary_components_per_detection_template()
+ !*/
+
+ long get_num_dimensions (
+ ) const;
+ /*!
+ requires
+ - get_num_detection_templates() > 0
+ ensures
+ - returns the number of dimensions in the feature vector for a sliding window
+ location. This value is the dimensionality of the underlying feature vectors
+ produced by Feature_extractor_type times (get_num_stationary_components_per_detection_template() +
+ get_num_movable_components_per_detection_template()).
+ !*/
+
+ unsigned long get_max_pyramid_levels (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of image pyramid levels this object will use.
+ Note that #get_max_pyramid_levels() == 1 indicates that no image pyramid
+ will be used at all. That is, only the original image will be processed
+ and no lower scale versions will be created.
+ !*/
+
+ void set_max_pyramid_levels (
+ unsigned long max_levels
+ );
+ /*!
+ requires
+ - max_levels > 0
+ ensures
+ - #get_max_pyramid_levels() == max_levels
+ !*/
+
+ void set_min_pyramid_layer_size (
+ unsigned long width,
+ unsigned long height
+ );
+ /*!
+ requires
+ - width > 0
+ - height > 0
+ ensures
+ - #get_min_pyramid_layer_width() == width
+ - #get_min_pyramid_layer_height() == height
+ !*/
+
+ inline unsigned long get_min_pyramid_layer_width (
+ ) const;
+ /*!
+ ensures
+ - returns the smallest allowable width of an image in the image pyramid.
+ All pyramids will always include the original input image, however, no
+ pyramid levels will be created which have a width smaller than the
+ value returned by this function.
+ !*/
+
+ inline unsigned long get_min_pyramid_layer_height (
+ ) const;
+ /*!
+ ensures
+ - returns the smallest allowable height of an image in the image pyramid.
+ All pyramids will always include the original input image, however, no
+ pyramid levels will be created which have a height smaller than the
+ value returned by this function.
+ !*/
+
+ unsigned long get_max_detections_per_template (
+ ) const;
+ /*!
+ ensures
+ - For each image pyramid layer and detection template, this object scans a sliding
+ window classifier over an image and produces a number of detections. This
+ function returns a number which defines a hard upper limit on the number of
+ detections allowed by a single scan. This means that the total number of
+ possible detections produced by detect() is get_max_detections_per_template()*
+ get_num_detection_templates()*(number of image pyramid layers). Additionally,
+ if the maximum number of detections is reached during a scan then this object
+ will return a random subsample of all detections which are above the detection
+ threshold.
+ !*/
+
+ void set_max_detections_per_template (
+ unsigned long max_dets
+ );
+ /*!
+ requires
+ - max_dets > 0
+ ensures
+ - #get_max_detections_per_template() == max_dets
+ !*/
+
+ void detect (
+ const feature_vector_type& w,
+ std::vector<std::pair<double, rectangle> >& dets,
+ const double thresh
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ - is_loaded_with_image() == true
+ - get_num_detection_templates() > 0
+ ensures
+ - Scans all the detection templates over all pyramid layers as discussed in the
+ WHAT THIS OBJECT REPRESENTS section and stores all detections into #dets.
+ - for all valid i:
+ - #dets[i].second == The object box which produced this detection. This rectangle gives
+ the location of the detection. Note that the rectangle will have been converted back into
+ the original image input space. That is, if this detection was made at a low level in the
+ image pyramid then the object box will have been automatically mapped up the pyramid layers
+ to the original image space. Or in other words, if you plot #dets[i].second on top of the
+ image given to load() it will show up in the right place.
+ - #dets[i].first == The score for this detection. This value is equal to dot(w, feature vector
+ for this sliding window location).
+ - #dets[i].first >= thresh
+ - #dets will be sorted in descending order. (i.e. #dets[i].first >= #dets[j].first for all i, and j>i)
+ - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only the first
+ get_num_dimensions() are used.
+ - Note that no form of non-max suppression is performed. If a window has a score >= thresh
+ then it is reported in #dets (assuming the limit imposed by get_max_detections_per_template() hasn't
+ been reached).
+ !*/
+
+ const rectangle get_best_matching_rect (
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - get_num_detection_templates() > 0
+ ensures
+ - Since scan_image_pyramid is a sliding window classifier system, not all possible rectangles
+ can be represented. Therefore, this function allows you to supply a rectangle and obtain the
+ nearest possible sliding window rectangle.
+ !*/
+
+ void get_feature_vector (
+ const full_object_detection& obj,
+ feature_vector_type& psi
+ ) const;
+ /*!
+ requires
+ - all_parts_in_rect(obj) == true
+ - obj.num_parts() == get_num_movable_components_per_detection_template()
+ - is_loaded_with_image() == true
+ - get_num_detection_templates() > 0
+ - psi.size() >= get_num_dimensions()
+ (i.e. psi must have preallocated its memory before this function is called)
+ ensures
+ - This function allows you to determine the feature vector used for a
+ sliding window location. Note that this vector is added to psi. Note
+ also that you must use get_full_object_detection() to convert a rect from
+ detect() into the needed full_object_detection.
+ - The dimensionality of the vector added to psi is get_num_dimensions(). This
+ means that elements of psi after psi(get_num_dimensions()-1) are not modified.
+ - Since scan_image_pyramid is a sliding window classifier system, not all
+ possible rectangles can be output by detect(). So in the case where
+ obj.get_rect() could not arise from a call to detect(), this function
+ will map obj.get_rect() to the nearest possible object box and then add
+ the feature vector for the mapped rectangle into #psi.
+ - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect()
+ gets mapped to for feature extraction.
+ !*/
+
+ full_object_detection get_full_object_detection (
+ const rectangle& rect,
+ const feature_vector_type& w
+ ) const;
+ /*!
+ requires
+ - w.size() >= get_num_dimensions()
+ - is_loaded_with_image() == true
+ - get_num_detection_templates() > 0
+ ensures
+ - This function allows you to determine the full_object_detection
+ corresponding to a sliding window location. Note that the detect()
+ routine doesn't return the locations of the movable parts in a detected
+ object. Therefore, if you are using any movable parts in your model you
+ must use get_full_object_detection() to find out where the movable parts
+ were detected. To do this, you supply the w and detected rectangle.
+ Then the corresponding fully populated full_object_detection will be
+ returned.
+ - returns a full_object_detection, OBJ, such that:
+ - OBJ.get_rect() == rect
+ - OBJ.num_parts() == get_num_movable_components_per_detection_template()
+ - OBJ.part(i) == the location of the i-th movable part inside this detection,
+ or OBJECT_PART_NOT_PRESENT if the part was not found.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void serialize (
+ const scan_image_pyramid<Pyramid_type,Feature_extractor_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename Pyramid_type,
+ typename Feature_extractor_type
+ >
+ void deserialize (
+ scan_image_pyramid<Pyramid_type,Feature_extractor_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h
new file mode 100644
index 000000000..874b995b4
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h
@@ -0,0 +1,180 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_
+#define DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_
+
+#include "scan_image_pyramid_tools_abstract.h"
+#include "../statistics.h"
+#include <list>
+#include "../geometry.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool compare_first (
+ const std::pair<unsigned long,rectangle>& a,
+ const std::pair<unsigned long,rectangle>& b
+ )
+ {
+ return a.first < b.first;
+ }
+ }
+
+
+ template <typename image_scanner_type>
+ std::vector<rectangle> determine_object_boxes (
+ const image_scanner_type& scanner,
+ const std::vector<rectangle>& rects,
+ double min_match_score
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < min_match_score && min_match_score <= 1,
+ "\t std::vector<rectangle> determine_object_boxes()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t min_match_score: " << min_match_score
+ );
+
+ typename image_scanner_type::pyramid_type pyr;
+
+ typedef std::list<std::pair<unsigned long, rectangle> > list_type;
+
+ unsigned long max_area = 0;
+
+ // Copy rects into sorted_rects and sort them in order of increasing area. But
+ // only include the rectangles that aren't already obtainable by the scanner.
+ list_type sorted_rects;
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ if (scanner.get_num_detection_templates() > 0)
+ {
+ rectangle temp = scanner.get_best_matching_rect(rects[i]);
+ const double match_score = (rects[i].intersect(temp).area())/(double)(rects[i] + temp).area();
+ // skip this rectangle if it's already matched well enough.
+ if (match_score > min_match_score)
+ continue;
+ }
+ max_area = std::max(rects[i].area(), max_area);
+ sorted_rects.push_back(std::make_pair(rects[i].area(), rects[i]));
+ }
+ sorted_rects.sort(dlib::impl::compare_first);
+
+ // Make sure this area value is comfortably larger than all the
+ // rectangles' areas.
+ max_area = 3*max_area + 100;
+
+ std::vector<rectangle> object_boxes;
+
+ while (sorted_rects.size() != 0)
+ {
+ rectangle cur = sorted_rects.front().second;
+ sorted_rects.pop_front();
+ object_boxes.push_back(centered_rect(point(0,0), cur.width(), cur.height()));
+
+ // Scale cur up the image pyramid and remove any rectangles which match.
+ // But also stop when cur gets large enough to not match anything.
+ for (unsigned long itr = 0;
+ itr < scanner.get_max_pyramid_levels() && cur.area() < max_area;
+ ++itr)
+ {
+ list_type::iterator i = sorted_rects.begin();
+ while (i != sorted_rects.end())
+ {
+ const rectangle temp = move_rect(i->second, cur.tl_corner());
+ const double match_score = (cur.intersect(temp).area())/(double)(cur + temp).area();
+ if (match_score > min_match_score)
+ {
+ i = sorted_rects.erase(i);
+ }
+ else
+ {
+ ++i;
+ }
+ }
+
+ cur = pyr.rect_up(cur);
+ }
+
+ }
+
+ return object_boxes;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_scanner_type>
+ std::vector<rectangle> determine_object_boxes (
+ const image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ double min_match_score
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < min_match_score && min_match_score <= 1,
+ "\t std::vector<rectangle> determine_object_boxes()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t min_match_score: " << min_match_score
+ );
+
+ std::vector<rectangle> temp;
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < rects[i].size(); ++j)
+ {
+ temp.push_back(rects[i][j]);
+ }
+ }
+
+ return determine_object_boxes(scanner, temp, min_match_score);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_scanner_type>
+ void setup_grid_detection_templates (
+ image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ unsigned int cells_x,
+ unsigned int cells_y,
+ double min_match_score = 0.75
+ )
+ {
+ const std::vector<rectangle>& object_boxes = determine_object_boxes(scanner, rects, min_match_score);
+ for (unsigned long i = 0; i < object_boxes.size(); ++i)
+ {
+ scanner.add_detection_template(object_boxes[i], create_grid_detection_template(object_boxes[i], cells_x, cells_y));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_scanner_type>
+ void setup_grid_detection_templates_verbose (
+ image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ unsigned int cells_x,
+ unsigned int cells_y,
+ double min_match_score = 0.75
+ )
+ {
+ const std::vector<rectangle>& object_boxes = determine_object_boxes(scanner, rects, min_match_score);
+ std::cout << "number of detection templates: "<< object_boxes.size() << std::endl;
+ for (unsigned long i = 0; i < object_boxes.size(); ++i)
+ {
+ std::cout << " object box " << i << ": width: " << object_boxes[i].width()
+ << " height: "<< object_boxes[i].height() << std::endl;
+ scanner.add_detection_template(object_boxes[i], create_grid_detection_template(object_boxes[i], cells_x, cells_y));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h
new file mode 100644
index 000000000..83a572df7
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_
+#ifdef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_
+
+#include "scan_image_pyramid_abstract.h"
+#include <vector>
+#include "../geometry.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ std::vector<rectangle> determine_object_boxes (
+ const image_scanner_type& scanner,
+ const std::vector<rectangle>& rects,
+ double min_match_score
+ );
+ /*!
+ requires
+ - 0 < min_match_score <= 1
+ - image_scanner_type == an implementation of the scan_image_pyramid
+ object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h
+ ensures
+ - returns a set of object boxes which, when used as detection templates with
+ the given scanner, can attain at least min_match_score alignment with every
+ element of rects. Note that the alignment between two rectangles A and B is
+ defined as:
+ (A.intersect(B).area())/(double)(A+B).area()
+ - Only elements of rects which are not already well matched by the scanner are
+ considered. That is, if the scanner already has some detection templates in
+ it then the contents of rects will be checked against those detection
+ templates and elements with a match better than min_match_score are ignore.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ std::vector<rectangle> determine_object_boxes (
+ const image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ double min_match_score
+ );
+ /*!
+ requires
+ - 0 < min_match_score <= 1
+ - image_scanner_type == an implementation of the scan_image_pyramid
+ object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h
+ ensures
+ - copies all rectangles in rects into a std::vector<rectangle> object, call it
+ R. Then this function returns determine_object_boxes(scanner,R,min_match_score).
+ That is, it just called the version of determine_object_boxes() defined above
+ and returns the results.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ void setup_grid_detection_templates (
+ image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ unsigned int cells_x,
+ unsigned int cells_y,
+ double min_match_score = 0.75
+ );
+ /*!
+ requires
+ - cells_x > 0
+ - cells_y > 0
+ - 0 < min_match_score <= 1
+ - image_scanner_type == an implementation of the scan_image_pyramid
+ object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h
+ ensures
+ - uses determine_object_boxes(scanner,rects,min_match_score) to obtain a set of
+ object boxes and then adds them to the given scanner object as detection templates.
+ Also uses create_grid_detection_template(object_box, cells_x, cells_y) to create
+ each feature extraction region. Therefore, the detection templates will extract
+ features from a regular grid inside each object box.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ void setup_grid_detection_templates_verbose (
+ image_scanner_type& scanner,
+ const std::vector<std::vector<rectangle> >& rects,
+ unsigned int cells_x,
+ unsigned int cells_y,
+ double min_match_score = 0.75
+ );
+ /*!
+ requires
+ - cells_x > 0
+ - cells_y > 0
+ - 0 < min_match_score <= 1
+ - image_scanner_type == an implementation of the scan_image_pyramid
+ object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h
+ ensures
+ - this function is identical to setup_grid_detection_templates() except
+ that it also outputs the selected detection templates to standard out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_processing/setup_hashed_features.h b/ml/dlib/dlib/image_processing/setup_hashed_features.h
new file mode 100644
index 000000000..5b82cecb4
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/setup_hashed_features.h
@@ -0,0 +1,219 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SETUP_HAShED_FEATURES_Hh_
+#define DLIB_SETUP_HAShED_FEATURES_Hh_
+
+#include "setup_hashed_features_abstract.h"
+#include "scan_image_pyramid.h"
+#include "scan_image_boxes.h"
+#include "../lsh.h"
+#include "../statistics.h"
+#include "../image_keypoint.h"
+#include "../geometry.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_hash_construction_failure : public error
+ {
+ public:
+ image_hash_construction_failure(
+ const std::string& a
+ ): error(a) {}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner
+ >
+ void use_uniform_feature_weights (
+ image_scanner& scanner
+ )
+ {
+ typename image_scanner::feature_extractor_type fe;
+ fe.copy_configuration(scanner.get_feature_extractor());
+ fe.use_uniform_feature_weights();
+ scanner.copy_configuration(fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner
+ >
+ void use_relative_feature_weights (
+ image_scanner& scanner
+ )
+ {
+ typename image_scanner::feature_extractor_type fe;
+ fe.copy_configuration(scanner.get_feature_extractor());
+ fe.use_relative_feature_weights();
+ scanner.copy_configuration(fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// stuff for scan_image_pyramid
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename pyramid,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image
+ >
+ void setup_hashed_features (
+ scan_image_pyramid<pyramid, feature_image<feature_extractor, projection_hash> >& scanner,
+ const image_array& images,
+ const feature_extractor& fe,
+ int bits,
+ unsigned long num_samples = 200000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ num_samples > 1 &&
+ images.size() > 0,
+ "\t void setup_hashed_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t bits: " << bits
+ << "\n\t num_samples: " << num_samples
+ << "\n\t images.size(): " << images.size()
+ );
+
+ pyramid pyr;
+
+ const random_subset_selector<typename feature_extractor::descriptor_type>& samps =
+ randomly_sample_image_features(images, pyr, fe, num_samples);
+
+ if (samps.size() <= 1)
+ throw dlib::image_hash_construction_failure("Images too small, not able to gather enough samples to make hash");
+
+ projection_hash phash = create_random_projection_hash(samps, bits);
+
+ feature_image<feature_extractor, projection_hash> hfe;
+ hfe.copy_configuration(scanner.get_feature_extractor());
+ hfe.set_hash(phash);
+ hfe.copy_configuration(fe);
+ scanner.copy_configuration(hfe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename pyramid,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image
+ >
+ void setup_hashed_features (
+ scan_image_pyramid<pyramid, feature_image<feature_extractor, projection_hash> >& scanner,
+ const image_array& images,
+ int bits,
+ unsigned long num_samples = 200000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ num_samples > 1 &&
+ images.size() > 0,
+ "\t void setup_hashed_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t bits: " << bits
+ << "\n\t num_samples: " << num_samples
+ << "\n\t images.size(): " << images.size()
+ );
+
+ feature_extractor fe;
+ setup_hashed_features(scanner, images, fe, bits, num_samples);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// stuff for scan_image_boxes
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image,
+ typename box_generator
+ >
+ void setup_hashed_features (
+ scan_image_boxes<feature_image<feature_extractor, projection_hash>,box_generator >& scanner,
+ const image_array& images,
+ const feature_extractor& fe,
+ int bits,
+ unsigned long num_samples = 200000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ num_samples > 1 &&
+ images.size() > 0,
+ "\t void setup_hashed_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t bits: " << bits
+ << "\n\t num_samples: " << num_samples
+ << "\n\t images.size(): " << images.size()
+ );
+
+ pyramid_disable pyr;
+
+ const random_subset_selector<typename feature_extractor::descriptor_type>& samps =
+ randomly_sample_image_features(images, pyr, fe, num_samples);
+
+ if (samps.size() <= 1)
+ throw dlib::image_hash_construction_failure("Images too small, not able to gather enough samples to make hash");
+
+ projection_hash phash = create_random_projection_hash(samps, bits);
+
+ feature_image<feature_extractor, projection_hash> hfe;
+ hfe.copy_configuration(scanner.get_feature_extractor());
+ hfe.set_hash(phash);
+ hfe.copy_configuration(fe);
+ scanner.copy_configuration(hfe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image,
+ typename box_generator
+ >
+ void setup_hashed_features (
+ scan_image_boxes<feature_image<feature_extractor, projection_hash>,box_generator>& scanner,
+ const image_array& images,
+ int bits,
+ unsigned long num_samples = 200000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ num_samples > 1 &&
+ images.size() > 0,
+ "\t void setup_hashed_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t bits: " << bits
+ << "\n\t num_samples: " << num_samples
+ << "\n\t images.size(): " << images.size()
+ );
+
+ feature_extractor fe;
+ setup_hashed_features(scanner, images, fe, bits, num_samples);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SETUP_HAShED_FEATURES_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h b/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h
new file mode 100644
index 000000000..886411cd4
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h
@@ -0,0 +1,210 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_
+#ifdef DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_
+
+#include "scan_image_pyramid_abstract.h"
+#include "scan_image_boxes_abstract.h"
+#include "../lsh/projection_hash_abstract.h"
+#include "../image_keypoint/hashed_feature_image_abstract.h"
+#include "../image_keypoint/binned_vector_feature_image_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_hash_construction_failure : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception object used by the routines in this file.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner
+ >
+ void use_uniform_feature_weights (
+ image_scanner& scanner
+ );
+ /*!
+ requires
+ - image_scanner should be either scan_image_pyramid or scan_image_boxes and
+ should use the hashed_feature_image as its local feature extractor.
+ ensures
+ - #scanner.get_feature_extractor().uses_uniform_feature_weights() == true
+ (i.e. Make the scanner's feature extractor use the uniform feature weighting
+ scheme)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner
+ >
+ void use_relative_feature_weights (
+ image_scanner& scanner
+ );
+ /*!
+ requires
+ - image_scanner should be either scan_image_pyramid or scan_image_boxes and
+ should use the hashed_feature_image as its local feature extractor.
+ ensures
+ - #scanner.get_feature_extractor().uses_uniform_feature_weights() == false
+ (i.e. Make the scanner's feature extractor use the relative feature weighting
+ scheme)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename pyramid,
+ typename feature_extractor
+ template <typename fe, typename hash> class feature_image
+ >
+ void setup_hashed_features (
+ scan_image_pyramid<pyramid, feature_image<feature_extractor, projection_hash> >& scanner,
+ const image_array& images,
+ const feature_extractor& fe,
+ int bits,
+ unsigned long num_samples = 200000
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - num_samples > 1
+ - images.size() > 0
+ - it must be valid to pass images[0] into scanner.load().
+ (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h)
+ - feature_image == must be either hashed_feature_image, binned_vector_feature_image,
+ or a type with a compatible interface.
+ ensures
+ - Creates a projection_hash suitable for hashing the feature vectors produced by
+ fe and then configures scanner to use this hash function.
+ - The hash function will map vectors into integers in the range [0, pow(2,bits))
+ - The hash function will be setup so that it hashes a random sample of num_samples
+ vectors from fe such that each bin ends up with roughly the same number of
+ elements in it.
+ throws
+ - image_hash_construction_failure
+ This exception is thrown if there is a problem creating the projection_hash.
+ This should only happen the images are so small they contain less than 2
+ feature vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename pyramid,
+ typename feature_extractor
+ template <typename fe, typename hash> class feature_image
+ >
+ void setup_hashed_features (
+ scan_image_pyramid<pyramid, feature_image<feature_extractor, projection_hash> >& scanner,
+ const image_array& images,
+ int bits,
+ unsigned long num_samples = 200000
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - num_samples > 1
+ - images.size() > 0
+ - it must be valid to pass images[0] into scanner.load().
+ (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h)
+ - feature_image == must be either hashed_feature_image, binned_vector_feature_image,
+ or a type with a compatible interface.
+ ensures
+ - performs: setup_hashed_features(scanner, images, feature_extractor(), bits, num_samples)
+ throws
+ - image_hash_construction_failure
+ This exception is thrown if there is a problem creating the projection_hash.
+ This should only happen the images are so small they contain less than 2
+ feature vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image
+ typename box_generator
+ >
+ void setup_hashed_features (
+ scan_image_boxes<feature_image<feature_extractor, projection_hash>,box_generator>& scanner,
+ const image_array& images,
+ const feature_extractor& fe,
+ int bits,
+ unsigned long num_samples = 200000
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - num_samples > 1
+ - images.size() > 0
+ - it must be valid to pass images[0] into scanner.load().
+ (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h)
+ - feature_image == must be either hashed_feature_image, binned_vector_feature_image,
+ or a type with a compatible interface.
+ ensures
+ - Creates a projection_hash suitable for hashing the feature vectors produced by
+ fe and then configures scanner to use this hash function.
+ - The hash function will map vectors into integers in the range [0, pow(2,bits))
+ - The hash function will be setup so that it hashes a random sample of num_samples
+ vectors from fe such that each bin ends up with roughly the same number of
+ elements in it.
+ throws
+ - image_hash_construction_failure
+ This exception is thrown if there is a problem creating the projection_hash.
+ This should only happen the images are so small they contain less than 2
+ feature vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array,
+ typename feature_extractor,
+ template <typename fe, typename hash> class feature_image
+ typename box_generator
+ >
+ void setup_hashed_features (
+ scan_image_boxes<feature_image<feature_extractor, projection_hash>,box_generator>& scanner,
+ const image_array& images,
+ int bits,
+ unsigned long num_samples = 200000
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - num_samples > 1
+ - images.size() > 0
+ - it must be valid to pass images[0] into scanner.load().
+ (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h)
+ - feature_image == must be either hashed_feature_image, binned_vector_feature_image,
+ or a type with a compatible interface.
+ ensures
+ - performs: setup_hashed_features(scanner, images, feature_extractor(), bits, num_samples)
+ throws
+ - image_hash_construction_failure
+ This exception is thrown if there is a problem creating the projection_hash.
+ This should only happen the images are so small they contain less than 2
+ feature vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_processing/shape_predictor.h b/ml/dlib/dlib/image_processing/shape_predictor.h
new file mode 100644
index 000000000..05e9a60fd
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/shape_predictor.h
@@ -0,0 +1,524 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SHAPE_PREDICToR_H_
+#define DLIB_SHAPE_PREDICToR_H_
+
+#include "shape_predictor_abstract.h"
+#include "full_object_detection.h"
+#include "../algs.h"
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../pixel.h"
+#include "../statistics.h"
+#include <utility>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct split_feature
+ {
+ unsigned long idx1;
+ unsigned long idx2;
+ float thresh;
+
+ friend inline void serialize (const split_feature& item, std::ostream& out)
+ {
+ dlib::serialize(item.idx1, out);
+ dlib::serialize(item.idx2, out);
+ dlib::serialize(item.thresh, out);
+ }
+ friend inline void deserialize (split_feature& item, std::istream& in)
+ {
+ dlib::deserialize(item.idx1, in);
+ dlib::deserialize(item.idx2, in);
+ dlib::deserialize(item.thresh, in);
+ }
+ };
+
+
+ // a tree is just a std::vector<impl::split_feature>. We use this function to navigate the
+ // tree nodes
+ inline unsigned long left_child (unsigned long idx) { return 2*idx + 1; }
+ /*!
+ ensures
+ - returns the index of the left child of the binary tree node idx
+ !*/
+ inline unsigned long right_child (unsigned long idx) { return 2*idx + 2; }
+ /*!
+ ensures
+ - returns the index of the left child of the binary tree node idx
+ !*/
+
+ struct regression_tree
+ {
+ std::vector<split_feature> splits;
+ std::vector<matrix<float,0,1> > leaf_values;
+
+ unsigned long num_leaves() const { return leaf_values.size(); }
+
+ inline const matrix<float,0,1>& operator()(
+ const std::vector<float>& feature_pixel_values,
+ unsigned long& i
+ ) const
+ /*!
+ requires
+ - All the index values in splits are less than feature_pixel_values.size()
+ - leaf_values.size() is a power of 2.
+ (i.e. we require a tree with all the levels fully filled out.
+ - leaf_values.size() == splits.size()+1
+ (i.e. there needs to be the right number of leaves given the number of splits in the tree)
+ ensures
+ - runs through the tree and returns the vector at the leaf we end up in.
+ - #i == the selected leaf node index.
+ !*/
+ {
+ i = 0;
+ while (i < splits.size())
+ {
+ if ((float)feature_pixel_values[splits[i].idx1] - (float)feature_pixel_values[splits[i].idx2] > splits[i].thresh)
+ i = left_child(i);
+ else
+ i = right_child(i);
+ }
+ i = i - splits.size();
+ return leaf_values[i];
+ }
+
+ friend void serialize (const regression_tree& item, std::ostream& out)
+ {
+ dlib::serialize(item.splits, out);
+ dlib::serialize(item.leaf_values, out);
+ }
+ friend void deserialize (regression_tree& item, std::istream& in)
+ {
+ dlib::deserialize(item.splits, in);
+ dlib::deserialize(item.leaf_values, in);
+ }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ inline vector<float,2> location (
+ const matrix<float,0,1>& shape,
+ unsigned long idx
+ )
+ /*!
+ requires
+ - idx < shape.size()/2
+ - shape.size()%2 == 0
+ ensures
+ - returns the idx-th point from the shape vector.
+ !*/
+ {
+ return vector<float,2>(shape(idx*2), shape(idx*2+1));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline unsigned long nearest_shape_point (
+ const matrix<float,0,1>& shape,
+ const dlib::vector<float,2>& pt
+ )
+ {
+ // find the nearest part of the shape to this pixel
+ float best_dist = std::numeric_limits<float>::infinity();
+ const unsigned long num_shape_parts = shape.size()/2;
+ unsigned long best_idx = 0;
+ for (unsigned long j = 0; j < num_shape_parts; ++j)
+ {
+ const float dist = length_squared(location(shape,j)-pt);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = j;
+ }
+ }
+ return best_idx;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void create_shape_relative_encoding (
+ const matrix<float,0,1>& shape,
+ const std::vector<dlib::vector<float,2> >& pixel_coordinates,
+ std::vector<unsigned long>& anchor_idx,
+ std::vector<dlib::vector<float,2> >& deltas
+ )
+ /*!
+ requires
+ - shape.size()%2 == 0
+ - shape.size() > 0
+ ensures
+ - #anchor_idx.size() == pixel_coordinates.size()
+ - #deltas.size() == pixel_coordinates.size()
+ - for all valid i:
+ - pixel_coordinates[i] == location(shape,#anchor_idx[i]) + #deltas[i]
+ !*/
+ {
+ anchor_idx.resize(pixel_coordinates.size());
+ deltas.resize(pixel_coordinates.size());
+
+
+ for (unsigned long i = 0; i < pixel_coordinates.size(); ++i)
+ {
+ anchor_idx[i] = nearest_shape_point(shape, pixel_coordinates[i]);
+ deltas[i] = pixel_coordinates[i] - location(shape,anchor_idx[i]);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline point_transform_affine find_tform_between_shapes (
+ const matrix<float,0,1>& from_shape,
+ const matrix<float,0,1>& to_shape
+ )
+ {
+ DLIB_ASSERT(from_shape.size() == to_shape.size() && (from_shape.size()%2) == 0 && from_shape.size() > 0,"");
+ std::vector<vector<float,2> > from_points, to_points;
+ const unsigned long num = from_shape.size()/2;
+ from_points.reserve(num);
+ to_points.reserve(num);
+ if (num == 1)
+ {
+ // Just use an identity transform if there is only one landmark.
+ return point_transform_affine();
+ }
+
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ from_points.push_back(location(from_shape,i));
+ to_points.push_back(location(to_shape,i));
+ }
+ return find_similarity_transform(from_points, to_points);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline point_transform_affine normalizing_tform (
+ const rectangle& rect
+ )
+ /*!
+ ensures
+ - returns a transform that maps rect.tl_corner() to (0,0) and rect.br_corner()
+ to (1,1).
+ !*/
+ {
+ std::vector<vector<float,2> > from_points, to_points;
+ from_points.push_back(rect.tl_corner()); to_points.push_back(point(0,0));
+ from_points.push_back(rect.tr_corner()); to_points.push_back(point(1,0));
+ from_points.push_back(rect.br_corner()); to_points.push_back(point(1,1));
+ return find_affine_transform(from_points, to_points);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline point_transform_affine unnormalizing_tform (
+ const rectangle& rect
+ )
+ /*!
+ ensures
+ - returns a transform that maps (0,0) to rect.tl_corner() and (1,1) to
+ rect.br_corner().
+ !*/
+ {
+ std::vector<vector<float,2> > from_points, to_points;
+ to_points.push_back(rect.tl_corner()); from_points.push_back(point(0,0));
+ to_points.push_back(rect.tr_corner()); from_points.push_back(point(1,0));
+ to_points.push_back(rect.br_corner()); from_points.push_back(point(1,1));
+ return find_affine_transform(from_points, to_points);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename image_type, typename feature_type>
+ void extract_feature_pixel_values (
+ const image_type& img_,
+ const rectangle& rect,
+ const matrix<float,0,1>& current_shape,
+ const matrix<float,0,1>& reference_shape,
+ const std::vector<unsigned long>& reference_pixel_anchor_idx,
+ const std::vector<dlib::vector<float,2> >& reference_pixel_deltas,
+ std::vector<feature_type>& feature_pixel_values
+ )
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - reference_pixel_anchor_idx.size() == reference_pixel_deltas.size()
+ - current_shape.size() == reference_shape.size()
+ - reference_shape.size()%2 == 0
+ - max(mat(reference_pixel_anchor_idx)) < reference_shape.size()/2
+ ensures
+ - #feature_pixel_values.size() == reference_pixel_deltas.size()
+ - for all valid i:
+ - #feature_pixel_values[i] == the value of the pixel in img_ that
+ corresponds to the pixel identified by reference_pixel_anchor_idx[i]
+ and reference_pixel_deltas[i] when the pixel is located relative to
+ current_shape rather than reference_shape.
+ !*/
+ {
+ const matrix<float,2,2> tform = matrix_cast<float>(find_tform_between_shapes(reference_shape, current_shape).get_m());
+ const point_transform_affine tform_to_img = unnormalizing_tform(rect);
+
+ const rectangle area = get_rect(img_);
+
+ const_image_view<image_type> img(img_);
+ feature_pixel_values.resize(reference_pixel_deltas.size());
+ for (unsigned long i = 0; i < feature_pixel_values.size(); ++i)
+ {
+ // Compute the point in the current shape corresponding to the i-th pixel and
+ // then map it from the normalized shape space into pixel space.
+ point p = tform_to_img(tform*reference_pixel_deltas[i] + location(current_shape, reference_pixel_anchor_idx[i]));
+ if (area.contains(p))
+ feature_pixel_values[i] = get_pixel_intensity(img[p.y()][p.x()]);
+ else
+ feature_pixel_values[i] = 0;
+ }
+ }
+
+ } // end namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ class shape_predictor
+ {
+ public:
+
+
+ shape_predictor (
+ )
+ {}
+
+ shape_predictor (
+ const matrix<float,0,1>& initial_shape_,
+ const std::vector<std::vector<impl::regression_tree> >& forests_,
+ const std::vector<std::vector<dlib::vector<float,2> > >& pixel_coordinates
+ ) : initial_shape(initial_shape_), forests(forests_)
+ /*!
+ requires
+ - initial_shape.size()%2 == 0
+ - forests.size() == pixel_coordinates.size() == the number of cascades
+ - for all valid i:
+ - all the index values in forests[i] are less than pixel_coordinates[i].size()
+ - for all valid i and j:
+ - forests[i][j].leaf_values.size() is a power of 2.
+ (i.e. we require a tree with all the levels fully filled out.
+ - forests[i][j].leaf_values.size() == forests[i][j].splits.size()+1
+ (i.e. there need to be the right number of leaves given the number of splits in the tree)
+ !*/
+ {
+ anchor_idx.resize(pixel_coordinates.size());
+ deltas.resize(pixel_coordinates.size());
+ // Each cascade uses a different set of pixels for its features. We compute
+ // their representations relative to the initial shape now and save it.
+ for (unsigned long i = 0; i < pixel_coordinates.size(); ++i)
+ impl::create_shape_relative_encoding(initial_shape, pixel_coordinates[i], anchor_idx[i], deltas[i]);
+ }
+
+ unsigned long num_parts (
+ ) const
+ {
+ return initial_shape.size()/2;
+ }
+
+ unsigned long num_features (
+ ) const
+ {
+ unsigned long num = 0;
+ for (unsigned long iter = 0; iter < forests.size(); ++iter)
+ for (unsigned long i = 0; i < forests[iter].size(); ++i)
+ num += forests[iter][i].num_leaves();
+ return num;
+ }
+
+ template <typename image_type>
+ full_object_detection operator()(
+ const image_type& img,
+ const rectangle& rect
+ ) const
+ {
+ using namespace impl;
+ matrix<float,0,1> current_shape = initial_shape;
+ std::vector<float> feature_pixel_values;
+ for (unsigned long iter = 0; iter < forests.size(); ++iter)
+ {
+ extract_feature_pixel_values(img, rect, current_shape, initial_shape,
+ anchor_idx[iter], deltas[iter], feature_pixel_values);
+ unsigned long leaf_idx;
+ // evaluate all the trees at this level of the cascade.
+ for (unsigned long i = 0; i < forests[iter].size(); ++i)
+ current_shape += forests[iter][i](feature_pixel_values, leaf_idx);
+ }
+
+ // convert the current_shape into a full_object_detection
+ const point_transform_affine tform_to_img = unnormalizing_tform(rect);
+ std::vector<point> parts(current_shape.size()/2);
+ for (unsigned long i = 0; i < parts.size(); ++i)
+ parts[i] = tform_to_img(location(current_shape, i));
+ return full_object_detection(rect, parts);
+ }
+
+ template <typename image_type, typename T, typename U>
+ full_object_detection operator()(
+ const image_type& img,
+ const rectangle& rect,
+ std::vector<std::pair<T,U> >& feats
+ ) const
+ {
+ feats.clear();
+ using namespace impl;
+ matrix<float,0,1> current_shape = initial_shape;
+ std::vector<float> feature_pixel_values;
+ unsigned long feat_offset = 0;
+ for (unsigned long iter = 0; iter < forests.size(); ++iter)
+ {
+ extract_feature_pixel_values(img, rect, current_shape, initial_shape,
+ anchor_idx[iter], deltas[iter], feature_pixel_values);
+ // evaluate all the trees at this level of the cascade.
+ for (unsigned long i = 0; i < forests[iter].size(); ++i)
+ {
+ unsigned long leaf_idx;
+ current_shape += forests[iter][i](feature_pixel_values, leaf_idx);
+
+ feats.push_back(std::make_pair(feat_offset+leaf_idx, 1));
+ feat_offset += forests[iter][i].num_leaves();
+ }
+ }
+
+ // convert the current_shape into a full_object_detection
+ const point_transform_affine tform_to_img = unnormalizing_tform(rect);
+ std::vector<point> parts(current_shape.size()/2);
+ for (unsigned long i = 0; i < parts.size(); ++i)
+ parts[i] = tform_to_img(location(current_shape, i));
+ return full_object_detection(rect, parts);
+ }
+
+ friend void serialize (const shape_predictor& item, std::ostream& out);
+
+ friend void deserialize (shape_predictor& item, std::istream& in);
+
+ private:
+ matrix<float,0,1> initial_shape;
+ std::vector<std::vector<impl::regression_tree> > forests;
+ std::vector<std::vector<unsigned long> > anchor_idx;
+ std::vector<std::vector<dlib::vector<float,2> > > deltas;
+ };
+
+ inline void serialize (const shape_predictor& item, std::ostream& out)
+ {
+ int version = 1;
+ dlib::serialize(version, out);
+ dlib::serialize(item.initial_shape, out);
+ dlib::serialize(item.forests, out);
+ dlib::serialize(item.anchor_idx, out);
+ dlib::serialize(item.deltas, out);
+ }
+
+ inline void deserialize (shape_predictor& item, std::istream& in)
+ {
+ int version = 0;
+ dlib::deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::shape_predictor.");
+ dlib::deserialize(item.initial_shape, in);
+ dlib::deserialize(item.forests, in);
+ dlib::deserialize(item.anchor_idx, in);
+ dlib::deserialize(item.deltas, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array
+ >
+ double test_shape_predictor (
+ const shape_predictor& sp,
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects,
+ const std::vector<std::vector<double> >& scales
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ DLIB_CASSERT( images.size() == objects.size() ,
+ "\t double test_shape_predictor()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ DLIB_CASSERT(objects[i][j].num_parts() == sp.num_parts(),
+ "\t double test_shape_predictor()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t objects["<<i<<"]["<<j<<"].num_parts(): " << objects[i][j].num_parts()
+ << "\n\t sp.num_parts(): " << sp.num_parts()
+ );
+ }
+ if (scales.size() != 0)
+ {
+ DLIB_CASSERT(objects[i].size() == scales[i].size(),
+ "\t double test_shape_predictor()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t objects["<<i<<"].size(): " << objects[i].size()
+ << "\n\t scales["<<i<<"].size(): " << scales[i].size()
+ );
+
+ }
+ }
+#endif
+
+ running_stats<double> rs;
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ // Just use a scale of 1 (i.e. no scale at all) if the caller didn't supply
+ // any scales.
+ const double scale = scales.size()==0 ? 1 : scales[i][j];
+
+ full_object_detection det = sp(images[i], objects[i][j].get_rect());
+
+ for (unsigned long k = 0; k < det.num_parts(); ++k)
+ {
+ if (objects[i][j].part(k) != OBJECT_PART_NOT_PRESENT)
+ {
+ double score = length(det.part(k) - objects[i][j].part(k))/scale;
+ rs.add(score);
+ }
+ }
+ }
+ }
+ return rs.mean();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array
+ >
+ double test_shape_predictor (
+ const shape_predictor& sp,
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects
+ )
+ {
+ std::vector<std::vector<double> > no_scales;
+ return test_shape_predictor(sp, images, objects, no_scales);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHAPE_PREDICToR_H_
+
diff --git a/ml/dlib/dlib/image_processing/shape_predictor_abstract.h b/ml/dlib/dlib/image_processing/shape_predictor_abstract.h
new file mode 100644
index 000000000..718b4952e
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/shape_predictor_abstract.h
@@ -0,0 +1,195 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SHAPE_PREDICToR_ABSTRACT_H_
+#ifdef DLIB_SHAPE_PREDICToR_ABSTRACT_H_
+
+#include "full_object_detection_abstract.h"
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class shape_predictor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool that takes in an image region containing some object
+ and outputs a set of point locations that define the pose of the object.
+ The classic example of this is human face pose prediction, where you take
+ an image of a human face as input and are expected to identify the
+ locations of important facial landmarks such as the corners of the mouth
+ and eyes, tip of the nose, and so forth.
+
+ To create useful instantiations of this object you need to use the
+ shape_predictor_trainer object defined in the
+ shape_predictor_trainer_abstract.h file to train a shape_predictor using a
+ set of training images, each annotated with shapes you want to predict.
+
+ THREAD SAFETY
+ No synchronization is required when using this object. In particular, a
+ single instance of this object can be used from multiple threads at the
+ same time.
+ !*/
+
+ public:
+
+ shape_predictor (
+ );
+ /*!
+ ensures
+ - #num_parts() == 0
+ - #num_features() == 0
+ !*/
+
+ unsigned long num_parts (
+ ) const;
+ /*!
+ ensures
+ - returns the number of parts in the shapes predicted by this object.
+ !*/
+
+ unsigned long num_features (
+ ) const;
+ /*!
+ ensures
+ - Returns the dimensionality of the feature vector output by operator().
+ This number is the total number of trees in this object times the number
+ of leaves on each tree.
+ !*/
+
+ template <typename image_type, typename T, typename U>
+ full_object_detection operator()(
+ const image_type& img,
+ const rectangle& rect,
+ std::vector<std::pair<T,U> >& feats
+ ) const;
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - T is some unsigned integral type (e.g. unsigned int).
+ - U is any scalar type capable of storing the value 1 (e.g. float).
+ ensures
+ - Runs the shape prediction algorithm on the part of the image contained in
+ the given bounding rectangle. So it will try and fit the shape model to
+ the contents of the given rectangle in the image. For example, if there
+ is a human face inside the rectangle and you use a face landmarking shape
+ model then this function will return the locations of the face landmarks
+ as the parts. So the return value is a full_object_detection DET such
+ that:
+ - DET.get_rect() == rect
+ - DET.num_parts() == num_parts()
+ - for all valid i:
+ - DET.part(i) == the location in img for the i-th part of the shape
+ predicted by this object.
+ - #feats == a sparse vector that records which leaf each tree used to make
+ the shape prediction. Moreover, it is an indicator vector, Therefore,
+ for all valid i:
+ - #feats[i].second == 1
+ Further, #feats is a vector from the space of num_features() dimensional
+ vectors. The output shape positions can be represented as the dot
+ product between #feats and a weight vector. Therefore, #feats encodes
+ all the information from img that was used to predict the returned shape
+ object.
+ !*/
+
+ template <typename image_type>
+ full_object_detection operator()(
+ const image_type& img,
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - Calling this function is equivalent to calling (*this)(img, rect, ignored)
+ where the 3d argument is discarded.
+ !*/
+
+ };
+
+ void serialize (const shape_predictor& item, std::ostream& out);
+ void deserialize (shape_predictor& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array
+ >
+ double test_shape_predictor (
+ const shape_predictor& sp,
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects,
+ const std::vector<std::vector<double> >& scales
+ );
+ /*!
+ requires
+ - image_array is a dlib::array of image objects where each image object
+ implements the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - for all valid i and j:
+ - objects[i][j].num_parts() == sp.num_parts()
+ - if (scales.size() != 0) then
+ - There must be a scale value for each full_object_detection in objects.
+ That is, it must be the case that:
+ - scales.size() == objects.size()
+ - for all valid i:
+ - scales[i].size() == objects[i].size()
+ ensures
+ - Tests the given shape_predictor by running it on each of the given objects and
+ checking how well it recovers the part positions. In particular, for all
+ valid i and j we perform:
+ sp(images[i], objects[i][j].get_rect())
+ and compare the result with the truth part positions in objects[i][j]. We
+ then return the average distance (measured in pixels) between a predicted
+ part location and its true position.
+ - Note that any parts in objects that are set to OBJECT_PART_NOT_PRESENT are
+ simply ignored.
+ - if (scales.size() != 0) then
+ - Each time we compute the distance between a predicted part location and
+ its true location in objects[i][j] we divide the distance by
+ scales[i][j]. Therefore, if you want the reported error to be the
+ average pixel distance then give an empty scales vector, but if you want
+ the returned value to be something else like the average distance
+ normalized by some feature of each object (e.g. the interocular distance)
+ then you can supply those normalizing values via scales.
+ !*/
+
+ template <
+ typename image_array
+ >
+ double test_shape_predictor (
+ const shape_predictor& sp,
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects
+ );
+ /*!
+ requires
+ - image_array is a dlib::array of image objects where each image object
+ implements the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - for all valid i and j:
+ - objects[i][j].num_parts() == sp.num_parts()
+ ensures
+ - returns test_shape_predictor(sp, images, objects, no_scales) where no_scales
+ is an empty vector. So this is just a convenience function for calling the
+ above test_shape_predictor() routine without a scales argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHAPE_PREDICToR_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/image_processing/shape_predictor_trainer.h b/ml/dlib/dlib/image_processing/shape_predictor_trainer.h
new file mode 100644
index 000000000..3090998f9
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/shape_predictor_trainer.h
@@ -0,0 +1,852 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SHAPE_PREDICToR_TRAINER_H_
+#define DLIB_SHAPE_PREDICToR_TRAINER_H_
+
+#include "shape_predictor_trainer_abstract.h"
+#include "shape_predictor.h"
+#include "../console_progress_indicator.h"
+#include "../threads.h"
+#include "../data_io/image_dataset_metadata.h"
+#include "box_overlap_testing.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class shape_predictor_trainer
+ {
+ /*!
+ This thing really only works with unsigned char or rgb_pixel images (since we assume the threshold
+ should be in the range [-128,128]).
+ !*/
+ public:
+
+ enum padding_mode_t
+ {
+ bounding_box_relative,
+ landmark_relative
+ };
+
+ shape_predictor_trainer (
+ )
+ {
+ _cascade_depth = 10;
+ _tree_depth = 4;
+ _num_trees_per_cascade_level = 500;
+ _nu = 0.1;
+ _oversampling_amount = 20;
+ _feature_pool_size = 400;
+ _lambda = 0.1;
+ _num_test_splits = 20;
+ _feature_pool_region_padding = 0;
+ _verbose = false;
+ _num_threads = 0;
+ _padding_mode = landmark_relative;
+ }
+
+ unsigned long get_cascade_depth (
+ ) const { return _cascade_depth; }
+
+ void set_cascade_depth (
+ unsigned long depth
+ )
+ {
+ DLIB_CASSERT(depth > 0,
+ "\t void shape_predictor_trainer::set_cascade_depth()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t depth: " << depth
+ );
+
+ _cascade_depth = depth;
+ }
+
+ unsigned long get_tree_depth (
+ ) const { return _tree_depth; }
+
+ void set_tree_depth (
+ unsigned long depth
+ )
+ {
+ DLIB_CASSERT(depth > 0,
+ "\t void shape_predictor_trainer::set_tree_depth()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t depth: " << depth
+ );
+
+ _tree_depth = depth;
+ }
+
+ unsigned long get_num_trees_per_cascade_level (
+ ) const { return _num_trees_per_cascade_level; }
+
+ void set_num_trees_per_cascade_level (
+ unsigned long num
+ )
+ {
+ DLIB_CASSERT( num > 0,
+ "\t void shape_predictor_trainer::set_num_trees_per_cascade_level()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t num: " << num
+ );
+ _num_trees_per_cascade_level = num;
+ }
+
+ double get_nu (
+ ) const { return _nu; }
+ void set_nu (
+ double nu
+ )
+ {
+ DLIB_CASSERT(0 < nu && nu <= 1,
+ "\t void shape_predictor_trainer::set_nu()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t nu: " << nu
+ );
+
+ _nu = nu;
+ }
+
+ std::string get_random_seed (
+ ) const { return rnd.get_seed(); }
+ void set_random_seed (
+ const std::string& seed
+ ) { rnd.set_seed(seed); }
+
+ unsigned long get_oversampling_amount (
+ ) const { return _oversampling_amount; }
+ void set_oversampling_amount (
+ unsigned long amount
+ )
+ {
+ DLIB_CASSERT(amount > 0,
+ "\t void shape_predictor_trainer::set_oversampling_amount()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t amount: " << amount
+ );
+
+ _oversampling_amount = amount;
+ }
+
+ unsigned long get_feature_pool_size (
+ ) const { return _feature_pool_size; }
+ void set_feature_pool_size (
+ unsigned long size
+ )
+ {
+ DLIB_CASSERT(size > 1,
+ "\t void shape_predictor_trainer::set_feature_pool_size()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t size: " << size
+ );
+
+ _feature_pool_size = size;
+ }
+
+ double get_lambda (
+ ) const { return _lambda; }
+ void set_lambda (
+ double lambda
+ )
+ {
+ DLIB_CASSERT(lambda > 0,
+ "\t void shape_predictor_trainer::set_lambda()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t lambda: " << lambda
+ );
+
+ _lambda = lambda;
+ }
+
+ unsigned long get_num_test_splits (
+ ) const { return _num_test_splits; }
+ void set_num_test_splits (
+ unsigned long num
+ )
+ {
+ DLIB_CASSERT(num > 0,
+ "\t void shape_predictor_trainer::set_num_test_splits()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t num: " << num
+ );
+
+ _num_test_splits = num;
+ }
+
+ void set_padding_mode (
+ padding_mode_t mode
+ )
+ {
+ _padding_mode = mode;
+ }
+
+ padding_mode_t get_padding_mode (
+ ) const { return _padding_mode; }
+
+ double get_feature_pool_region_padding (
+ ) const { return _feature_pool_region_padding; }
+ void set_feature_pool_region_padding (
+ double padding
+ )
+ {
+ DLIB_CASSERT(padding > -0.5,
+ "\t void shape_predictor_trainer::set_feature_pool_region_padding()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t padding: " << padding
+ );
+
+ _feature_pool_region_padding = padding;
+ }
+
+ void be_verbose (
+ )
+ {
+ _verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ _verbose = false;
+ }
+
+ unsigned long get_num_threads (
+ ) const { return _num_threads; }
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ _num_threads = num;
+ }
+
+ template <typename image_array>
+ shape_predictor train (
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects
+ ) const
+ {
+ using namespace impl;
+ DLIB_CASSERT(images.size() == objects.size() && images.size() > 0,
+ "\t shape_predictor shape_predictor_trainer::train()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+ // make sure the objects agree on the number of parts and that there is at
+ // least one full_object_detection.
+ unsigned long num_parts = 0;
+ std::vector<int> part_present;
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ if (num_parts == 0)
+ {
+ num_parts = objects[i][j].num_parts();
+ DLIB_CASSERT(objects[i][j].num_parts() != 0,
+ "\t shape_predictor shape_predictor_trainer::train()"
+ << "\n\t You can't give objects that don't have any parts to the trainer."
+ );
+ part_present.resize(num_parts);
+ }
+ else
+ {
+ DLIB_CASSERT(objects[i][j].num_parts() == num_parts,
+ "\t shape_predictor shape_predictor_trainer::train()"
+ << "\n\t All the objects must agree on the number of parts. "
+ << "\n\t objects["<<i<<"]["<<j<<"].num_parts(): " << objects[i][j].num_parts()
+ << "\n\t num_parts: " << num_parts
+ );
+ }
+ for (unsigned long p = 0; p < objects[i][j].num_parts(); ++p)
+ {
+ if (objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT)
+ part_present[p] = 1;
+ }
+ }
+ }
+ DLIB_CASSERT(num_parts != 0,
+ "\t shape_predictor shape_predictor_trainer::train()"
+ << "\n\t You must give at least one full_object_detection if you want to train a shape model and it must have parts."
+ );
+ DLIB_CASSERT(sum(mat(part_present)) == (long)num_parts,
+ "\t shape_predictor shape_predictor_trainer::train()"
+ << "\n\t Each part must appear at least once in this training data. That is, "
+ << "\n\t you can't have a part that is always set to OBJECT_PART_NOT_PRESENT."
+ );
+
+ // creating thread pool. if num_threads <= 1, trainer should work in caller thread
+ thread_pool tp(_num_threads > 1 ? _num_threads : 0);
+
+ // determining the type of features used for this type of images
+ typedef typename std::remove_const<typename std::remove_reference<decltype(images[0])>::type>::type image_type;
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ typedef typename pixel_traits<pixel_type>::basic_pixel_type feature_type;
+
+ rnd.set_seed(get_random_seed());
+
+ std::vector<training_sample<feature_type>> samples;
+ const matrix<float,0,1> initial_shape = populate_training_sample_shapes(objects, samples);
+ const std::vector<std::vector<dlib::vector<float,2> > > pixel_coordinates = randomly_sample_pixel_coordinates(initial_shape);
+
+ unsigned long trees_fit_so_far = 0;
+ console_progress_indicator pbar(get_cascade_depth()*get_num_trees_per_cascade_level());
+ if (_verbose)
+ std::cout << "Fitting trees..." << std::endl;
+
+ std::vector<std::vector<impl::regression_tree> > forests(get_cascade_depth());
+ // Now start doing the actual training by filling in the forests
+ for (unsigned long cascade = 0; cascade < get_cascade_depth(); ++cascade)
+ {
+ // Each cascade uses a different set of pixels for its features. We compute
+ // their representations relative to the initial shape first.
+ std::vector<unsigned long> anchor_idx;
+ std::vector<dlib::vector<float,2> > deltas;
+ create_shape_relative_encoding(initial_shape, pixel_coordinates[cascade], anchor_idx, deltas);
+
+ // First compute the feature_pixel_values for each training sample at this
+ // level of the cascade.
+ parallel_for(tp, 0, samples.size(), [&](unsigned long i)
+ {
+ impl::extract_feature_pixel_values(images[samples[i].image_idx], samples[i].rect,
+ samples[i].current_shape, initial_shape, anchor_idx,
+ deltas, samples[i].feature_pixel_values);
+ }, 1);
+
+ // Now start building the trees at this cascade level.
+ for (unsigned long i = 0; i < get_num_trees_per_cascade_level(); ++i)
+ {
+ forests[cascade].push_back(make_regression_tree(tp, samples, pixel_coordinates[cascade]));
+
+ if (_verbose)
+ {
+ ++trees_fit_so_far;
+ pbar.print_status(trees_fit_so_far);
+ }
+ }
+ }
+
+ if (_verbose)
+ std::cout << "Training complete " << std::endl;
+
+ return shape_predictor(initial_shape, forests, pixel_coordinates);
+ }
+
+ private:
+
+ static void object_to_shape (
+ const full_object_detection& obj,
+ matrix<float,0,1>& shape,
+ matrix<float,0,1>& present // a mask telling which elements of #shape are present.
+ )
+ {
+ shape.set_size(obj.num_parts()*2);
+ present.set_size(obj.num_parts()*2);
+ const point_transform_affine tform_from_img = impl::normalizing_tform(obj.get_rect());
+ for (unsigned long i = 0; i < obj.num_parts(); ++i)
+ {
+ if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
+ {
+ vector<float,2> p = tform_from_img(obj.part(i));
+ shape(2*i) = p.x();
+ shape(2*i+1) = p.y();
+ present(2*i) = 1;
+ present(2*i+1) = 1;
+
+ if (length(p) > 100)
+ {
+ std::cout << "Warning, one of your objects has parts that are way outside its bounding box! This is probably an error in your annotation." << std::endl;
+ }
+ }
+ else
+ {
+ shape(2*i) = 0;
+ shape(2*i+1) = 0;
+ present(2*i) = 0;
+ present(2*i+1) = 0;
+ }
+ }
+ }
+
+ template<typename feature_type>
+ struct training_sample
+ {
+ /*!
+
+ CONVENTION
+ - feature_pixel_values.size() == get_feature_pool_size()
+ - feature_pixel_values[j] == the value of the j-th feature pool
+ pixel when you look it up relative to the shape in current_shape.
+
+ - target_shape == The truth shape. Stays constant during the whole
+ training process (except for the parts that are not present, those are
+ always equal to the current_shape values).
+ - present == 0/1 mask saying which parts of target_shape are present.
+ - rect == the position of the object in the image_idx-th image. All shape
+ coordinates are coded relative to this rectangle.
+ - diff_shape == temporary value for holding difference between current
+ shape and target shape
+ !*/
+
+ unsigned long image_idx;
+ rectangle rect;
+ matrix<float,0,1> target_shape;
+ matrix<float,0,1> present;
+
+ matrix<float,0,1> current_shape;
+ matrix<float,0,1> diff_shape;
+ std::vector<feature_type> feature_pixel_values;
+
+ void swap(training_sample& item)
+ {
+ std::swap(image_idx, item.image_idx);
+ std::swap(rect, item.rect);
+ target_shape.swap(item.target_shape);
+ present.swap(item.present);
+ current_shape.swap(item.current_shape);
+ diff_shape.swap(item.diff_shape);
+ feature_pixel_values.swap(item.feature_pixel_values);
+ }
+ };
+
+ template<typename feature_type>
+ impl::regression_tree make_regression_tree (
+ thread_pool& tp,
+ std::vector<training_sample<feature_type>>& samples,
+ const std::vector<dlib::vector<float,2> >& pixel_coordinates
+ ) const
+ {
+ using namespace impl;
+ std::deque<std::pair<unsigned long, unsigned long> > parts;
+ parts.push_back(std::make_pair(0, (unsigned long)samples.size()));
+
+ impl::regression_tree tree;
+
+ // walk the tree in breadth first order
+ const unsigned long num_split_nodes = static_cast<unsigned long>(std::pow(2.0, (double)get_tree_depth())-1);
+ std::vector<matrix<float,0,1> > sums(num_split_nodes*2+1);
+ if (tp.num_threads_in_pool() > 1)
+ {
+ // Here we need to calculate shape differences and store sum of differences into sums[0]
+ // to make it. I am splitting samples into blocks, each block will be processed by
+ // separate thread, and the sum of differences of each block is stored into separate
+ // place in block_sums
+
+ const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool());
+ const unsigned long num = samples.size();
+ const unsigned long block_size = std::max(1UL, (num + num_workers - 1) / num_workers);
+ std::vector<matrix<float,0,1> > block_sums(num_workers);
+
+ parallel_for(tp, 0, num_workers, [&](unsigned long block)
+ {
+ const unsigned long block_begin = block * block_size;
+ const unsigned long block_end = std::min(num, block_begin + block_size);
+ for (unsigned long i = block_begin; i < block_end; ++i)
+ {
+ samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape;
+ block_sums[block] += samples[i].diff_shape;
+ }
+ }, 1);
+
+ // now calculate the total result from separate blocks
+ for (unsigned long i = 0; i < block_sums.size(); ++i)
+ sums[0] += block_sums[i];
+ }
+ else
+ {
+ // synchronous implementation
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape;
+ sums[0] += samples[i].diff_shape;
+ }
+ }
+
+ for (unsigned long i = 0; i < num_split_nodes; ++i)
+ {
+ std::pair<unsigned long,unsigned long> range = parts.front();
+ parts.pop_front();
+
+ const impl::split_feature split = generate_split(tp, samples, range.first,
+ range.second, pixel_coordinates, sums[i], sums[left_child(i)],
+ sums[right_child(i)]);
+ tree.splits.push_back(split);
+ const unsigned long mid = partition_samples(split, samples, range.first, range.second);
+
+ parts.push_back(std::make_pair(range.first, mid));
+ parts.push_back(std::make_pair(mid, range.second));
+ }
+
+ // Now all the parts contain the ranges for the leaves so we can use them to
+ // compute the average leaf values.
+ matrix<float,0,1> present_counts(samples[0].target_shape.size());
+ tree.leaf_values.resize(parts.size());
+ for (unsigned long i = 0; i < parts.size(); ++i)
+ {
+ // Get the present counts for each dimension so we can divide each
+ // dimension by the number of observations we have on it to find the mean
+ // displacement in each leaf.
+ present_counts = 0;
+ for (unsigned long j = parts[i].first; j < parts[i].second; ++j)
+ present_counts += samples[j].present;
+ present_counts = dlib::reciprocal(present_counts);
+
+ if (parts[i].second != parts[i].first)
+ tree.leaf_values[i] = pointwise_multiply(present_counts,sums[num_split_nodes+i]*get_nu());
+ else
+ tree.leaf_values[i] = zeros_matrix(samples[0].target_shape);
+
+ // now adjust the current shape based on these predictions
+ parallel_for(tp, parts[i].first, parts[i].second, [&](unsigned long j)
+ {
+ samples[j].current_shape += tree.leaf_values[i];
+ // For parts that aren't present in the training data, we just make
+ // sure that the target shape always matches and therefore gives zero
+ // error. So this makes the algorithm simply ignore non-present
+ // landmarks.
+ for (long k = 0; k < samples[j].present.size(); ++k)
+ {
+ // if this part is not present
+ if (samples[j].present(k) == 0)
+ samples[j].target_shape(k) = samples[j].current_shape(k);
+ }
+ }, 1);
+ }
+
+ return tree;
+ }
+
+ impl::split_feature randomly_generate_split_feature (
+ const std::vector<dlib::vector<float,2> >& pixel_coordinates
+ ) const
+ {
+ const double lambda = get_lambda();
+ impl::split_feature feat;
+ const size_t max_iters = get_feature_pool_size()*get_feature_pool_size();
+ for (size_t i = 0; i < max_iters; ++i)
+ {
+ feat.idx1 = rnd.get_integer(get_feature_pool_size());
+ feat.idx2 = rnd.get_integer(get_feature_pool_size());
+ while (feat.idx1 == feat.idx2)
+ feat.idx2 = rnd.get_integer(get_feature_pool_size());
+ const double dist = length(pixel_coordinates[feat.idx1]-pixel_coordinates[feat.idx2]);
+ const double accept_prob = std::exp(-dist/lambda);
+ if (accept_prob > rnd.get_random_double())
+ break;
+ }
+
+ feat.thresh = (rnd.get_random_double()*256 - 128)/2.0;
+
+ return feat;
+ }
+
+ template<typename feature_type>
+ impl::split_feature generate_split (
+ thread_pool& tp,
+ const std::vector<training_sample<feature_type>>& samples,
+ unsigned long begin,
+ unsigned long end,
+ const std::vector<dlib::vector<float,2> >& pixel_coordinates,
+ const matrix<float,0,1>& sum,
+ matrix<float,0,1>& left_sum,
+ matrix<float,0,1>& right_sum
+ ) const
+ {
+ // generate a bunch of random splits and test them and return the best one.
+
+ const unsigned long num_test_splits = get_num_test_splits();
+
+ // sample the random features we test in this function
+ std::vector<impl::split_feature> feats;
+ feats.reserve(num_test_splits);
+ for (unsigned long i = 0; i < num_test_splits; ++i)
+ feats.push_back(randomly_generate_split_feature(pixel_coordinates));
+
+ std::vector<matrix<float,0,1> > left_sums(num_test_splits);
+ std::vector<unsigned long> left_cnt(num_test_splits);
+
+ const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool());
+ const unsigned long block_size = std::max(1UL, (num_test_splits + num_workers - 1) / num_workers);
+
+ // now compute the sums of vectors that go left for each feature
+ parallel_for(tp, 0, num_workers, [&](unsigned long block)
+ {
+ const unsigned long block_begin = block * block_size;
+ const unsigned long block_end = std::min(block_begin + block_size, num_test_splits);
+
+ for (unsigned long j = begin; j < end; ++j)
+ {
+ for (unsigned long i = block_begin; i < block_end; ++i)
+ {
+ if ((float)samples[j].feature_pixel_values[feats[i].idx1] - (float)samples[j].feature_pixel_values[feats[i].idx2] > feats[i].thresh)
+ {
+ left_sums[i] += samples[j].diff_shape;
+ ++left_cnt[i];
+ }
+ }
+ }
+
+ }, 1);
+
+ // now figure out which feature is the best
+ double best_score = -1;
+ unsigned long best_feat = 0;
+ matrix<float,0,1> temp;
+ for (unsigned long i = 0; i < num_test_splits; ++i)
+ {
+ // check how well the feature splits the space.
+ double score = 0;
+ unsigned long right_cnt = end-begin-left_cnt[i];
+ if (left_cnt[i] != 0 && right_cnt != 0)
+ {
+ temp = sum - left_sums[i];
+ score = dot(left_sums[i],left_sums[i])/left_cnt[i] + dot(temp,temp)/right_cnt;
+ if (score > best_score)
+ {
+ best_score = score;
+ best_feat = i;
+ }
+ }
+ }
+
+ left_sums[best_feat].swap(left_sum);
+ if (left_sum.size() != 0)
+ {
+ right_sum = sum - left_sum;
+ }
+ else
+ {
+ right_sum = sum;
+ left_sum = zeros_matrix(sum);
+ }
+ return feats[best_feat];
+ }
+
+ template<typename feature_type>
+ unsigned long partition_samples (
+ const impl::split_feature& split,
+ std::vector<training_sample<feature_type>>& samples,
+ unsigned long begin,
+ unsigned long end
+ ) const
+ {
+ // splits samples based on split (sorta like in quick sort) and returns the mid
+ // point. make sure you return the mid in a way compatible with how we walk
+ // through the tree.
+
+ unsigned long i = begin;
+ for (unsigned long j = begin; j < end; ++j)
+ {
+ if ((float)samples[j].feature_pixel_values[split.idx1] - (float)samples[j].feature_pixel_values[split.idx2] > split.thresh)
+ {
+ samples[i].swap(samples[j]);
+ ++i;
+ }
+ }
+ return i;
+ }
+
+
+
+ template<typename feature_type>
+ matrix<float,0,1> populate_training_sample_shapes(
+ const std::vector<std::vector<full_object_detection> >& objects,
+ std::vector<training_sample<feature_type>>& samples
+ ) const
+ {
+ samples.clear();
+ matrix<float,0,1> mean_shape;
+ matrix<float,0,1> count;
+ // first fill out the target shapes
+ for (unsigned long i = 0; i < objects.size(); ++i)
+ {
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ training_sample<feature_type> sample;
+ sample.image_idx = i;
+ sample.rect = objects[i][j].get_rect();
+ object_to_shape(objects[i][j], sample.target_shape, sample.present);
+ for (unsigned long itr = 0; itr < get_oversampling_amount(); ++itr)
+ samples.push_back(sample);
+ mean_shape += sample.target_shape;
+ count += sample.present;
+ }
+ }
+
+ mean_shape = pointwise_multiply(mean_shape,reciprocal(count));
+
+ // now go pick random initial shapes
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if ((i%get_oversampling_amount()) == 0)
+ {
+ // The mean shape is what we really use as an initial shape so always
+ // include it in the training set as an example starting shape.
+ samples[i].current_shape = mean_shape;
+ }
+ else
+ {
+ samples[i].current_shape.set_size(0);
+
+ matrix<float,0,1> hits(mean_shape.size());
+ hits = 0;
+
+ int iter = 0;
+ // Pick a few samples at random and randomly average them together to
+ // make the initial shape. Note that we make sure we get at least one
+ // observation (i.e. non-OBJECT_PART_NOT_PRESENT) on each part
+ // location.
+ while(min(hits) == 0 || iter < 2)
+ {
+ ++iter;
+ const unsigned long rand_idx = rnd.get_random_32bit_number()%samples.size();
+ const double alpha = rnd.get_random_double()+0.1;
+ samples[i].current_shape += alpha*samples[rand_idx].target_shape;
+ hits += alpha*samples[rand_idx].present;
+ }
+ samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits));
+ }
+
+ }
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (long k = 0; k < samples[i].present.size(); ++k)
+ {
+ // if this part is not present
+ if (samples[i].present(k) == 0)
+ samples[i].target_shape(k) = samples[i].current_shape(k);
+ }
+ }
+
+
+ return mean_shape;
+ }
+
+
+ void randomly_sample_pixel_coordinates (
+ std::vector<dlib::vector<float,2> >& pixel_coordinates,
+ const double min_x,
+ const double min_y,
+ const double max_x,
+ const double max_y
+ ) const
+ /*!
+ ensures
+ - #pixel_coordinates.size() == get_feature_pool_size()
+ - for all valid i:
+ - pixel_coordinates[i] == a point in the box defined by the min/max x/y arguments.
+ !*/
+ {
+ pixel_coordinates.resize(get_feature_pool_size());
+ for (unsigned long i = 0; i < get_feature_pool_size(); ++i)
+ {
+ pixel_coordinates[i].x() = rnd.get_random_double()*(max_x-min_x) + min_x;
+ pixel_coordinates[i].y() = rnd.get_random_double()*(max_y-min_y) + min_y;
+ }
+ }
+
+ std::vector<std::vector<dlib::vector<float,2> > > randomly_sample_pixel_coordinates (
+ const matrix<float,0,1>& initial_shape
+ ) const
+ {
+ const double padding = get_feature_pool_region_padding();
+ // Figure out the bounds on the object shapes. We will sample uniformly
+ // from this box.
+ matrix<float> temp = reshape(initial_shape, initial_shape.size()/2, 2);
+ double min_x = min(colm(temp,0));
+ double min_y = min(colm(temp,1));
+ double max_x = max(colm(temp,0));
+ double max_y = max(colm(temp,1));
+
+ if (get_padding_mode() == bounding_box_relative)
+ {
+ min_x = std::min(0.0, min_x);
+ min_y = std::min(0.0, min_y);
+ max_x = std::max(1.0, max_x);
+ max_y = std::max(1.0, max_y);
+ }
+
+ min_x -= padding;
+ min_y -= padding;
+ max_x += padding;
+ max_y += padding;
+
+ std::vector<std::vector<dlib::vector<float,2> > > pixel_coordinates;
+ pixel_coordinates.resize(get_cascade_depth());
+ for (unsigned long i = 0; i < get_cascade_depth(); ++i)
+ randomly_sample_pixel_coordinates(pixel_coordinates[i], min_x, min_y, max_x, max_y);
+ return pixel_coordinates;
+ }
+
+
+
+ mutable dlib::rand rnd;
+
+ unsigned long _cascade_depth;
+ unsigned long _tree_depth;
+ unsigned long _num_trees_per_cascade_level;
+ double _nu;
+ unsigned long _oversampling_amount;
+ unsigned long _feature_pool_size;
+ double _lambda;
+ unsigned long _num_test_splits;
+ double _feature_pool_region_padding;
+ bool _verbose;
+ unsigned long _num_threads;
+ padding_mode_t _padding_mode;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename some_type_of_rectangle
+ >
+ image_dataset_metadata::dataset make_bounding_box_regression_training_data (
+ const image_dataset_metadata::dataset& truth,
+ const std::vector<std::vector<some_type_of_rectangle>>& detections
+ )
+ {
+ DLIB_CASSERT(truth.images.size() == detections.size(),
+ "truth.images.size(): "<< truth.images.size() <<
+ "\tdetections.size(): "<< detections.size()
+ );
+ image_dataset_metadata::dataset result = truth;
+
+ for (size_t i = 0; i < truth.images.size(); ++i)
+ {
+ result.images[i].boxes.clear();
+ for (auto truth_box : truth.images[i].boxes)
+ {
+ if (truth_box.ignore)
+ continue;
+
+ // Find the detection that best matches the current truth_box.
+ auto det = max_scoring_element(detections[i], [&truth_box](const rectangle& r) { return box_intersection_over_union(r, truth_box.rect); });
+ if (det.second > 0.5)
+ {
+ // Remove any existing parts and replace them with the truth_box corners.
+ truth_box.parts.clear();
+ auto b = truth_box.rect;
+ truth_box.parts["left"] = (b.tl_corner()+b.bl_corner())/2;
+ truth_box.parts["right"] = (b.tr_corner()+b.br_corner())/2;
+ truth_box.parts["top"] = (b.tl_corner()+b.tr_corner())/2;
+ truth_box.parts["bottom"] = (b.bl_corner()+b.br_corner())/2;
+ truth_box.parts["middle"] = center(b);
+
+ // Now replace the bounding truth_box with the detector's bounding truth_box.
+ truth_box.rect = det.first;
+
+ result.images[i].boxes.push_back(truth_box);
+ }
+ }
+ }
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHAPE_PREDICToR_TRAINER_H_
+
diff --git a/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h b/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h
new file mode 100644
index 000000000..278b97842
--- /dev/null
+++ b/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h
@@ -0,0 +1,418 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_
+#ifdef DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_
+
+#include "shape_predictor_abstract.h"
+#include "../data_io/image_dataset_metadata.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class shape_predictor_trainer
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for training shape_predictors based on annotated training
+ images. Its implementation uses the algorithm described in:
+ One Millisecond Face Alignment with an Ensemble of Regression Trees
+ by Vahid Kazemi and Josephine Sullivan, CVPR 2014
+
+ !*/
+
+ public:
+
+ shape_predictor_trainer (
+ );
+ /*!
+ ensures
+ - #get_cascade_depth() == 10
+ - #get_tree_depth() == 4
+ - #get_num_trees_per_cascade_level() == 500
+ - #get_nu() == 0.1
+ - #get_oversampling_amount() == 20
+ - #get_feature_pool_size() == 400
+ - #get_lambda() == 0.1
+ - #get_num_test_splits() == 20
+ - #get_feature_pool_region_padding() == 0
+ - #get_random_seed() == ""
+ - #get_num_threads() == 0
+ - #get_padding_mode() == landmark_relative
+ - This object will not be verbose
+ !*/
+
+ unsigned long get_cascade_depth (
+ ) const;
+ /*!
+ ensures
+ - returns the number of cascades created when you train a model. This
+ means that the total number of trees in the learned model is equal to
+ get_cascade_depth()*get_num_trees_per_cascade_level().
+ !*/
+
+ void set_cascade_depth (
+ unsigned long depth
+ );
+ /*!
+ requires
+ - depth > 0
+ ensures
+ - #get_cascade_depth() == depth
+ !*/
+
+ unsigned long get_tree_depth (
+ ) const;
+ /*!
+ ensures
+ - returns the depth of the trees used in the cascade. In particular, there
+ are pow(2,get_tree_depth()) leaves in each tree.
+ !*/
+
+ void set_tree_depth (
+ unsigned long depth
+ );
+ /*!
+ requires
+ - depth > 0
+ ensures
+ - #get_tree_depth() == depth
+ !*/
+
+ unsigned long get_num_trees_per_cascade_level (
+ ) const;
+ /*!
+ ensures
+ - returns the number of trees created for each cascade. This means that
+ the total number of trees in the learned model is equal to
+ get_cascade_depth()*get_num_trees_per_cascade_level().
+ !*/
+
+ void set_num_trees_per_cascade_level (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_num_trees_per_cascade_level() == num
+ !*/
+
+ double get_nu (
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. Larger values of this parameter
+ will cause the algorithm to fit the training data better but may also
+ cause overfitting.
+ !*/
+
+ void set_nu (
+ double nu
+ );
+ /*!
+ requires
+ - 0 < nu <= 1
+ ensures
+ - #get_nu() == nu
+ !*/
+
+ std::string get_random_seed (
+ ) const;
+ /*!
+ ensures
+ - returns the random seed used by the internal random number generator.
+ Since this algorithm is a random forest style algorithm it relies on a
+ random number generator for generating the trees. So each setting of the
+ random seed will produce slightly different outputs.
+ !*/
+
+ void set_random_seed (
+ const std::string& seed
+ );
+ /*!
+ ensures
+ - #get_random_seed() == seed
+ !*/
+
+ unsigned long get_oversampling_amount (
+ ) const;
+ /*!
+ ensures
+ - You give annotated images to this object as training examples. You
+ can effectively increase the amount of training data by adding in each
+ training example multiple times but with a randomly selected deformation
+ applied to it. That is what this parameter controls. That is, if you
+ supply N training samples to train() then the algorithm runs internally
+ with N*get_oversampling_amount() training samples. So the bigger this
+ parameter the better (excepting that larger values make training take
+ longer). In terms of the Kazemi paper, this parameter is the number of
+ randomly selected initial starting points sampled for each training
+ example.
+ !*/
+
+ void set_oversampling_amount (
+ unsigned long amount
+ );
+ /*!
+ requires
+ - amount > 0
+ ensures
+ - #get_oversampling_amount() == amount
+ !*/
+
+ unsigned long get_feature_pool_size (
+ ) const;
+ /*!
+ ensures
+ - At each level of the cascade we randomly sample get_feature_pool_size()
+ pixels from the image. These pixels are used to generate features for
+ the random trees. So in general larger settings of this parameter give
+ better accuracy but make the algorithm run slower.
+ !*/
+
+ void set_feature_pool_size (
+ unsigned long size
+ );
+ /*!
+ requires
+ - size > 1
+ ensures
+ - #get_feature_pool_size() == size
+ !*/
+
+ enum padding_mode_t
+ {
+ bounding_box_relative,
+ landmark_relative
+ };
+
+ padding_mode_t get_padding_mode (
+ ) const;
+ /*!
+ ensures
+ - returns the current padding mode. See get_feature_pool_region_padding()
+ for a discussion of the modes.
+ !*/
+
+ void set_padding_mode (
+ padding_mode_t mode
+ );
+ /*!
+ ensures
+ - #get_padding_mode() == mode
+ !*/
+
+ double get_feature_pool_region_padding (
+ ) const;
+ /*!
+ ensures
+ - This algorithm works by comparing the relative intensity of pairs of
+ pixels in the input image. To decide which pixels to look at, the
+ training algorithm randomly selects pixels from a box roughly centered
+ around the object of interest. We call this box the feature pool region
+ box.
+
+ Each object of interest is defined by a full_object_detection, which
+ contains a bounding box and a list of landmarks. If
+ get_padding_mode()==landmark_relative then the feature pool region box is
+ the tightest box that contains the landmarks inside the
+ full_object_detection. In this mode the full_object_detection's bounding
+ box is ignored. Otherwise, if the padding mode is bounding_box_relative
+ then the feature pool region box is the tightest box that contains BOTH
+ the landmarks and the full_object_detection's bounding box.
+
+ Additionally, you can adjust the size of the feature pool padding region
+ by setting get_feature_pool_region_padding() to some value. If
+ get_feature_pool_region_padding()==0 then the feature pool region box is
+ unmodified and defined exactly as stated above. However, you can expand
+ the size of the box by setting the padding > 0 or shrink it by setting it
+ to something < 0.
+
+ To explain this precisely, for a padding of 0 we say that the pixels are
+ sampled from a box of size 1x1. The padding value is added to each side
+ of the box. So a padding of 0.5 would cause the algorithm to sample
+ pixels from a box that was 2x2, effectively multiplying the area pixels
+ are sampled from by 4. Similarly, setting the padding to -0.2 would
+ cause it to sample from a box 0.6x0.6 in size.
+ !*/
+
+ void set_feature_pool_region_padding (
+ double padding
+ );
+ /*!
+ requires
+ - padding > -0.5
+ ensures
+ - #get_feature_pool_region_padding() == padding
+ !*/
+
+ double get_lambda (
+ ) const;
+ /*!
+ ensures
+ - To decide how to split nodes in the regression trees the algorithm looks
+ at pairs of pixels in the image. These pixel pairs are sampled randomly
+ but with a preference for selecting pixels that are near each other.
+ get_lambda() controls this "nearness" preference. In particular, smaller
+ values of get_lambda() will make the algorithm prefer to select pixels
+ close together and larger values of get_lambda() will make it care less
+ about picking nearby pixel pairs.
+
+ Note that this is the inverse of how it is defined in the Kazemi paper.
+ For this object, you should think of lambda as "the fraction of the
+ bounding box will we traverse to find a neighboring pixel". Nominally,
+ this is normalized between 0 and 1. So reasonable settings of lambda are
+ values in the range 0 < lambda < 1.
+ !*/
+
+ void set_lambda (
+ double lambda
+ );
+ /*!
+ requires
+ - lambda > 0
+ ensures
+ - #get_lambda() == lambda
+ !*/
+
+ unsigned long get_num_test_splits (
+ ) const;
+ /*!
+ ensures
+ - When generating the random trees we randomly sample get_num_test_splits()
+ possible split features at each node and pick the one that gives the best
+ split. Larger values of this parameter will usually give more accurate
+ outputs but take longer to train.
+ !*/
+
+ void set_num_test_splits (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_num_test_splits() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - When running training process, it is possible to make some parts of it parallel
+ using CPU threads with #parallel_for() extension and creating #thread_pool internally
+ When get_num_threads() == 0, trainer will not create threads and all processing will
+ be done in the calling thread
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num >= 0
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - This object will not print anything to standard out
+ !*/
+
+ template <typename image_array>
+ shape_predictor train (
+ const image_array& images,
+ const std::vector<std::vector<full_object_detection> >& objects
+ ) const;
+ /*!
+ requires
+ - image_array is a dlib::array of image objects where each image object
+ implements the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - images.size() > 0
+ - for some i: objects[i].size() != 0
+ (i.e. there has to be at least one full_object_detection in the training set)
+ - for all valid p, there must exist i and j such that:
+ objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT.
+ (i.e. You can't define a part that is always set to OBJECT_PART_NOT_PRESENT.)
+ - for all valid i,j,k,l:
+ - objects[i][j].num_parts() == objects[k][l].num_parts()
+ (i.e. all objects must agree on the number of parts)
+ - objects[i][j].num_parts() > 0
+ ensures
+ - This object will try to learn to predict the locations of an object's parts
+ based on the object bounding box (i.e. full_object_detection::get_rect())
+ and the image pixels in that box. That is, we will try to learn a
+ shape_predictor, SP, such that:
+ SP(images[i], objects[i][j].get_rect()) == objects[i][j]
+ This learned SP object is then returned.
+ - Not all parts are required to be observed for all objects. So if you
+ have training instances with missing parts then set the part positions
+ equal to OBJECT_PART_NOT_PRESENT and this algorithm will basically ignore
+ those missing parts.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename some_type_of_rectangle
+ >
+ image_dataset_metadata::dataset make_bounding_box_regression_training_data (
+ const image_dataset_metadata::dataset& truth,
+ const std::vector<std::vector<some_type_of_rectangle>>& detections
+ );
+ /*!
+ requires
+ - truth.images.size() == detections.size()
+ - some_type_of_rectangle == rectangle, drectangle, mmod_rect, or any other type
+ that is convertible to a rectangle.
+ ensures
+ - Suppose you have an object detector that can roughly locate objects in an
+ image. This means your detector draws boxes around objects, but these are
+ *rough* boxes in the sense that they aren't positioned super accurately. For
+ instance, HOG based detectors usually have a stride of 8 pixels. So the
+ positional accuracy is going to be, at best, +/-8 pixels.
+
+ If you want to get better positional accuracy one easy thing to do is train a
+ shape_predictor to give you the location of the object's box. The
+ make_bounding_box_regression_training_data() routine helps you do this by
+ creating an appropriate training dataset. It does this by taking the dataset
+ you used to train your detector (given by the truth object), and combining
+ that with the output of your detector on each image in the training dataset
+ (given by the detections object). In particular, it will create a new
+ annotated dataset where each object box is one of the rectangles from
+ detections and that object has 5 part annotations. These annotations
+ identify the sides and middle of the truth rectangle corresponding to the
+ detection rectangle. You can then take the returned dataset and train a
+ shape_predictor on it. The resulting shape_predictor can then be used to do
+ bounding box regression.
+
+ As an aside, the reason we create 5 part annotations in this way is because
+ it gives the best shape_predictor when trained. If instead you used the 4
+ corners it wouldn't work as well, due to tedious vagaries of the shape_predictor
+ training process.
+
+ - We assume that detections[i] contains object detections corresponding to
+ the image truth.images[i].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/image_saver/dng_shared.h b/ml/dlib/dlib/image_saver/dng_shared.h
new file mode 100644
index 000000000..d098851b3
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/dng_shared.h
@@ -0,0 +1,288 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNG_SHAREd_
+#define DLIB_DNG_SHAREd_
+
+#include "../pixel.h"
+#include <cmath>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ namespace dng_helpers_namespace
+ {
+ enum
+ {
+ grayscale = 1,
+ rgb,
+ hsi,
+ rgb_paeth,
+ rgb_alpha,
+ rgb_alpha_paeth,
+ grayscale_16bit,
+ grayscale_float
+ };
+
+ const unsigned long dng_magic_byte = 100;
+
+ template <typename T>
+ rgb_pixel predictor_rgb_paeth (const T& img, long row, long col)
+ /*
+ This is similar to the Paeth filter from the PNG image format.
+ */
+ {
+ // a = left, b = above, c = upper left
+ rgb_pixel a(0,0,0), b(0,0,0), c(0,0,0);
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+ else
+ assign_pixel(a,(unsigned char)0);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+ else
+ assign_pixel(c,(unsigned char)0);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+ else
+ assign_pixel(b,(unsigned char)0);
+
+
+ rgb_pixel p;
+ p.red = a.red + b.red - c.red;
+ p.green = a.green + b.green - c.green;
+ p.blue = a.blue + b.blue - c.blue;
+
+ short pa = std::abs((short)p.red - (short)a.red) +
+ std::abs((short)p.green - (short)a.green) +
+ std::abs((short)p.blue - (short)a.blue);
+ short pb = std::abs((short)p.red - (short)b.red) +
+ std::abs((short)p.green - (short)b.green) +
+ std::abs((short)p.blue - (short)b.blue);
+ short pc = std::abs((short)p.red - (short)c.red) +
+ std::abs((short)p.green - (short)c.green) +
+ std::abs((short)p.blue - (short)c.blue);
+
+ if (pa <= pb && pa <= pc)
+ return a;
+ else if (pb <= pc)
+ return b;
+ else
+ return c;
+ }
+
+
+ template <typename T>
+ rgb_pixel predictor_rgb (const T& img, long row, long col)
+ {
+ // a = left, b = above, c = upper left
+ rgb_pixel a(0,0,0), b(0,0,0), c(0,0,0);
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+ else
+ assign_pixel(a,(unsigned char)0);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+ else
+ assign_pixel(c,(unsigned char)0);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+ else
+ assign_pixel(b,(unsigned char)0);
+
+
+ rgb_pixel p;
+ p.red = a.red + b.red - c.red;
+ p.green = a.green + b.green - c.green;
+ p.blue = a.blue + b.blue - c.blue;
+ return p;
+ }
+
+ template <typename T>
+ rgb_alpha_pixel predictor_rgb_alpha_paeth (const T& img, long row, long col)
+ /*
+ This is similar to the Paeth filter from the PNG image format.
+ */
+ {
+ // a = left, b = above, c = upper left
+ rgb_alpha_pixel a, b, c;
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+ else
+ assign_pixel(a,(unsigned char)0);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+ else
+ assign_pixel(c,(unsigned char)0);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+ else
+ assign_pixel(b,(unsigned char)0);
+
+
+ rgb_alpha_pixel p;
+ p.red = a.red + b.red - c.red;
+ p.green = a.green + b.green - c.green;
+ p.blue = a.blue + b.blue - c.blue;
+
+ short pa = std::abs((short)p.red - (short)a.red) +
+ std::abs((short)p.green - (short)a.green) +
+ std::abs((short)p.blue - (short)a.blue);
+ short pb = std::abs((short)p.red - (short)b.red) +
+ std::abs((short)p.green - (short)b.green) +
+ std::abs((short)p.blue - (short)b.blue);
+ short pc = std::abs((short)p.red - (short)c.red) +
+ std::abs((short)p.green - (short)c.green) +
+ std::abs((short)p.blue - (short)c.blue);
+
+ if (pa <= pb && pa <= pc)
+ return a;
+ else if (pb <= pc)
+ return b;
+ else
+ return c;
+ }
+
+
+ template <typename T>
+ rgb_alpha_pixel predictor_rgb_alpha (const T& img, long row, long col)
+ {
+ // a = left, b = above, c = upper left
+ rgb_alpha_pixel a, b, c;
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+ else
+ assign_pixel(a,(unsigned char)0);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+ else
+ assign_pixel(c,(unsigned char)0);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+ else
+ assign_pixel(b,(unsigned char)0);
+
+
+ rgb_alpha_pixel p;
+ p.red = a.red + b.red - c.red;
+ p.green = a.green + b.green - c.green;
+ p.blue = a.blue + b.blue - c.blue;
+ p.alpha = a.alpha + b.alpha - c.alpha;
+ return p;
+ }
+
+
+ template <typename T>
+ hsi_pixel predictor_hsi (const T& img, long row, long col)
+ {
+ // a = left, b = above, c = upper left
+ hsi_pixel a(0,0,0), b(0,0,0), c(0,0,0);
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+ else
+ assign_pixel(a,(unsigned char)0);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+ else
+ assign_pixel(c,(unsigned char)0);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+ else
+ assign_pixel(b,(unsigned char)0);
+
+
+ hsi_pixel p;
+ p.h = a.h + b.h - c.h;
+ p.s = a.s + b.s - c.s;
+ p.i = a.i + b.i - c.i;
+ return p;
+ }
+
+ template <typename T>
+ unsigned char predictor_grayscale (const T& img, long row, long col)
+ {
+ // a = left, b = above, c = upper left
+ unsigned char a = 0, b = 0, c = 0;
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+
+
+ unsigned char p = a + b - c;
+ return p;
+ }
+
+ template <typename T>
+ uint16 predictor_grayscale_16 (const T& img, long row, long col)
+ {
+ // a = left, b = above, c = upper left
+ uint16 a = 0, b = 0, c = 0;
+
+
+ const long c1 = col-1;
+ const long r1 = row-1;
+
+ if (c1 >= 0)
+ assign_pixel(a, img[row][c1]);
+
+ if (c1 >= 0 && r1 >= 0)
+ assign_pixel(c, img[r1][c1]);
+
+ if (r1 >= 0)
+ assign_pixel(b, img[r1][col]);
+
+
+ uint16 p = a + b - c;
+ return p;
+ }
+
+ }
+}
+
+#endif // DLIB_DNG_SHAREd_
+
diff --git a/ml/dlib/dlib/image_saver/image_saver.h b/ml/dlib/dlib/image_saver/image_saver.h
new file mode 100644
index 000000000..43a2717af
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/image_saver.h
@@ -0,0 +1,688 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_SAVEr_
+#define DLIB_IMAGE_SAVEr_
+
+#include "image_saver_abstract.h"
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include "../algs.h"
+#include "../pixel.h"
+#include "../byte_orderer.h"
+#include "../entropy_encoder.h"
+#include "../entropy_encoder_model.h"
+#include "dng_shared.h"
+#include "../uintn.h"
+#include "../dir_nav.h"
+#include "../float_details.h"
+#include "../vectorstream.h"
+#include "../matrix/matrix_exp.h"
+#include "../image_transforms/assign_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class image_save_error : public dlib::error {
+ public: image_save_error(const std::string& str) : error(EIMAGE_SAVE,str){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ bool grayscale = pixel_traits<typename image_traits<image_type>::pixel_type>::grayscale
+ >
+ struct save_bmp_helper;
+
+
+ template <typename image_type>
+ struct save_bmp_helper<image_type,false>
+ {
+ static void save_bmp (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ // we are going to write out a 24bit color image.
+ byte_orderer::kernel_1a bo;
+
+ out.write("BM",2);
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+
+
+ unsigned long pad = 4 - (image.nc()*3)%4;
+ if (pad == 4)
+ pad = 0;
+
+ unsigned long bfSize = 14 + 40 + (image.nc()*3 + pad)*image.nr();
+ unsigned long bfReserved = 0;
+ unsigned long bfOffBits = 14 + 40;
+ unsigned long biSize = 40;
+ unsigned long biWidth = image.nc();
+ unsigned long biHeight = image.nr();
+ unsigned short biPlanes = 1;
+ unsigned short biBitCount = 24;
+ unsigned long biCompression = 0;
+ unsigned long biSizeImage = 0;
+ unsigned long biXPelsPerMeter = 0;
+ unsigned long biYPelsPerMeter = 0;
+ unsigned long biClrUsed = 0;
+ unsigned long biClrImportant = 0;
+
+ bo.host_to_little(bfSize);
+ bo.host_to_little(bfOffBits);
+ bo.host_to_little(biSize);
+ bo.host_to_little(biWidth);
+ bo.host_to_little(biHeight);
+ bo.host_to_little(biPlanes);
+ bo.host_to_little(biBitCount);
+
+ out.write((char*)&bfSize,4);
+ out.write((char*)&bfReserved,4);
+ out.write((char*)&bfOffBits,4);
+ out.write((char*)&biSize,4);
+ out.write((char*)&biWidth,4);
+ out.write((char*)&biHeight,4);
+ out.write((char*)&biPlanes,2);
+ out.write((char*)&biBitCount,2);
+ out.write((char*)&biCompression,4);
+ out.write((char*)&biSizeImage,4);
+ out.write((char*)&biXPelsPerMeter,4);
+ out.write((char*)&biYPelsPerMeter,4);
+ out.write((char*)&biClrUsed,4);
+ out.write((char*)&biClrImportant,4);
+
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+
+ // now we write out the pixel data
+ for (long row = image.nr()-1; row >= 0; --row)
+ {
+ for (long col = 0; col < image.nc(); ++col)
+ {
+ rgb_pixel p;
+ p.red = 0;
+ p.green = 0;
+ p.blue = 0;
+ assign_pixel(p,image[row][col]);
+ out.write((char*)&p.blue,1);
+ out.write((char*)&p.green,1);
+ out.write((char*)&p.red,1);
+ }
+
+ // write out some zeros so that this line is a multiple of 4 bytes
+ for (unsigned long i = 0; i < pad; ++i)
+ {
+ unsigned char p = 0;
+ out.write((char*)&p,1);
+ }
+ }
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+ }
+ };
+
+ template <typename image_type>
+ struct save_bmp_helper<image_type,true>
+ {
+ static void save_bmp (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ // we are going to write out an 8bit color image.
+ byte_orderer::kernel_1a bo;
+
+ out.write("BM",2);
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+
+ unsigned long pad = 4 - image.nc()%4;
+ if (pad == 4)
+ pad = 0;
+
+ unsigned long bfSize = 14 + 40 + (image.nc() + pad)*image.nr() + 256*4;
+ unsigned long bfReserved = 0;
+ unsigned long bfOffBits = 14 + 40 + 256*4;
+ unsigned long biSize = 40;
+ unsigned long biWidth = image.nc();
+ unsigned long biHeight = image.nr();
+ unsigned short biPlanes = 1;
+ unsigned short biBitCount = 8;
+ unsigned long biCompression = 0;
+ unsigned long biSizeImage = 0;
+ unsigned long biXPelsPerMeter = 0;
+ unsigned long biYPelsPerMeter = 0;
+ unsigned long biClrUsed = 0;
+ unsigned long biClrImportant = 0;
+
+ bo.host_to_little(bfSize);
+ bo.host_to_little(bfOffBits);
+ bo.host_to_little(biSize);
+ bo.host_to_little(biWidth);
+ bo.host_to_little(biHeight);
+ bo.host_to_little(biPlanes);
+ bo.host_to_little(biBitCount);
+
+ out.write((char*)&bfSize,4);
+ out.write((char*)&bfReserved,4);
+ out.write((char*)&bfOffBits,4);
+ out.write((char*)&biSize,4);
+ out.write((char*)&biWidth,4);
+ out.write((char*)&biHeight,4);
+ out.write((char*)&biPlanes,2);
+ out.write((char*)&biBitCount,2);
+ out.write((char*)&biCompression,4);
+ out.write((char*)&biSizeImage,4);
+ out.write((char*)&biXPelsPerMeter,4);
+ out.write((char*)&biYPelsPerMeter,4);
+ out.write((char*)&biClrUsed,4);
+ out.write((char*)&biClrImportant,4);
+
+
+ // write out the color palette
+ for (unsigned int i = 0; i <= 255; ++i)
+ {
+ unsigned char ch = static_cast<unsigned char>(i);
+ out.write((char*)&ch,1);
+ out.write((char*)&ch,1);
+ out.write((char*)&ch,1);
+ ch = 0;
+ out.write((char*)&ch,1);
+ }
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+
+ // now we write out the pixel data
+ for (long row = image.nr()-1; row >= 0; --row)
+ {
+ for (long col = 0; col < image.nc(); ++col)
+ {
+ unsigned char p = 0;
+ assign_pixel(p,image[row][col]);
+ out.write((char*)&p,1);
+ }
+
+ // write out some zeros so that this line is a multiple of 4 bytes
+ for (unsigned long i = 0; i < pad; ++i)
+ {
+ unsigned char p = 0;
+ out.write((char*)&p,1);
+ }
+ }
+
+ if (!out)
+ throw image_save_error("error writing image to output stream");
+
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ inline typename disable_if<is_matrix<image_type> >::type save_bmp (
+ const image_type& image,
+ std::ostream& out
+ )
+ {
+ save_bmp_helper<image_type>::save_bmp(image,out);
+ }
+
+ template <
+ typename EXP
+ >
+ inline void save_bmp (
+ const matrix_exp<EXP>& image,
+ std::ostream& out
+ )
+ {
+ array2d<typename EXP::type> temp;
+ assign_image(temp, image);
+ save_bmp_helper<array2d<typename EXP::type> >::save_bmp(temp,out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace dng_helpers_namespace
+ {
+ template <
+ typename image_type,
+ typename enabled = void
+ >
+ struct save_dng_helper;
+
+ typedef entropy_encoder::kernel_2a encoder_type;
+ typedef entropy_encoder_model<256,encoder_type>::kernel_5a eem_type;
+
+ typedef entropy_encoder_model<256,encoder_type>::kernel_4a eem_exp_type;
+
+ template <typename image_type >
+ struct save_dng_helper<image_type, typename enable_if<is_float_type<typename image_traits<image_type>::pixel_type> >::type >
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+ unsigned long type = grayscale_float;
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+
+ // Write the compressed exponent data into expbuf. We will append it
+ // to the stream at the end of the loops.
+ std::vector<char> expbuf;
+ expbuf.reserve(image.size()*2);
+ vectorstream outexp(expbuf);
+ encoder_type encoder;
+ encoder.set_stream(outexp);
+
+ eem_exp_type eem_exp(encoder);
+ float_details prev;
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ float_details cur = image[r][c];
+ int16 exp = cur.exponent-prev.exponent;
+ int64 man = cur.mantissa-prev.mantissa;
+ prev = cur;
+
+ unsigned char ebyte1 = exp&0xFF;
+ unsigned char ebyte2 = exp>>8;
+ eem_exp.encode(ebyte1);
+ eem_exp.encode(ebyte2);
+
+ serialize(man, out);
+ }
+ }
+ // write out the magic byte to mark the end of the compressed data.
+ eem_exp.encode(dng_magic_byte);
+ eem_exp.encode(dng_magic_byte);
+ eem_exp.encode(dng_magic_byte);
+ eem_exp.encode(dng_magic_byte);
+
+ encoder.clear();
+ serialize(expbuf, out);
+ }
+ };
+
+
+ template <typename image_type>
+ struct is_non_float_non8bit_grayscale
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ const static bool value = pixel_traits<pixel_type>::grayscale &&
+ sizeof(pixel_type) != 1 &&
+ !is_float_type<pixel_type>::value;
+ };
+
+ template <typename image_type >
+ struct save_dng_helper<image_type, typename enable_if<is_non_float_non8bit_grayscale<image_type> >::type>
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+ unsigned long type = grayscale_16bit;
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+ encoder_type encoder;
+ encoder.set_stream(out);
+
+ eem_type eem(encoder);
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ uint16 cur;
+ assign_pixel(cur, image[r][c]);
+ cur -= predictor_grayscale_16(image,r,c);
+ unsigned char byte1 = cur&0xFF;
+ unsigned char byte2 = cur>>8;
+ eem.encode(byte2);
+ eem.encode(byte1);
+ }
+ }
+ // write out the magic byte to mark the end of the data
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ }
+ };
+
+ template <typename image_type>
+ struct is_8bit_grayscale
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ const static bool value = pixel_traits<pixel_type>::grayscale && sizeof(pixel_type) == 1;
+ };
+
+ template <typename image_type>
+ struct save_dng_helper<image_type, typename enable_if<is_8bit_grayscale<image_type> >::type>
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+ unsigned long type = grayscale;
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+ encoder_type encoder;
+ encoder.set_stream(out);
+
+ eem_type eem(encoder);
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ unsigned char cur;
+ assign_pixel(cur, image[r][c]);
+ cur -= predictor_grayscale(image,r,c);
+ eem.encode(cur);
+ }
+ }
+ // write out the magic byte to mark the end of the data
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ }
+ };
+
+ template <typename image_type>
+ struct is_rgb_image
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ const static bool value = pixel_traits<pixel_type>::rgb;
+ };
+
+ template <typename image_type>
+ struct save_dng_helper<image_type,typename enable_if<is_rgb_image<image_type> >::type>
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+
+ unsigned long type = rgb;
+ // if this is a small image then we will use a different predictor
+ if (image.size() < 4000)
+ type = rgb_paeth;
+
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+ encoder_type encoder;
+ encoder.set_stream(out);
+
+ rgb_pixel pre, cur;
+ eem_type eem(encoder);
+
+ if (type == rgb)
+ {
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ pre = predictor_rgb(image,r,c);
+ assign_pixel(cur, image[r][c]);
+
+ eem.encode((unsigned char)(cur.red - pre.red));
+ eem.encode((unsigned char)(cur.green - pre.green));
+ eem.encode((unsigned char)(cur.blue - pre.blue));
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ pre = predictor_rgb_paeth(image,r,c);
+ assign_pixel(cur, image[r][c]);
+
+ eem.encode((unsigned char)(cur.red - pre.red));
+ eem.encode((unsigned char)(cur.green - pre.green));
+ eem.encode((unsigned char)(cur.blue - pre.blue));
+ }
+ }
+ }
+ // write out the magic byte to mark the end of the data
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ }
+ };
+
+ template <typename image_type>
+ struct is_rgb_alpha_image
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ const static bool value = pixel_traits<pixel_type>::rgb_alpha;
+ };
+
+ template <typename image_type>
+ struct save_dng_helper<image_type,typename enable_if<is_rgb_alpha_image<image_type> >::type>
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+
+ unsigned long type = rgb_alpha;
+ // if this is a small image then we will use a different predictor
+ if (image.size() < 4000)
+ type = rgb_alpha_paeth;
+
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+ encoder_type encoder;
+ encoder.set_stream(out);
+
+ rgb_alpha_pixel pre, cur;
+ eem_type eem(encoder);
+
+ if (type == rgb_alpha)
+ {
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ pre = predictor_rgb_alpha(image,r,c);
+ assign_pixel(cur, image[r][c]);
+
+ eem.encode((unsigned char)(cur.red - pre.red));
+ eem.encode((unsigned char)(cur.green - pre.green));
+ eem.encode((unsigned char)(cur.blue - pre.blue));
+ eem.encode((unsigned char)(cur.alpha - pre.alpha));
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ pre = predictor_rgb_alpha_paeth(image,r,c);
+ assign_pixel(cur, image[r][c]);
+
+ eem.encode((unsigned char)(cur.red - pre.red));
+ eem.encode((unsigned char)(cur.green - pre.green));
+ eem.encode((unsigned char)(cur.blue - pre.blue));
+ eem.encode((unsigned char)(cur.alpha - pre.alpha));
+ }
+ }
+ }
+ // write out the magic byte to mark the end of the data
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ }
+ };
+
+ template <typename image_type>
+ struct is_hsi_image
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ const static bool value = pixel_traits<pixel_type>::hsi;
+ };
+
+ template <typename image_type>
+ struct save_dng_helper<image_type,typename enable_if<is_hsi_image<image_type> >::type>
+ {
+ static void save_dng (
+ const image_type& image_,
+ std::ostream& out
+ )
+ {
+ const_image_view<image_type> image(image_);
+ out.write("DNG",3);
+ unsigned long version = 1;
+ serialize(version,out);
+ unsigned long type = hsi;
+ serialize(type,out);
+ serialize(image.nc(),out);
+ serialize(image.nr(),out);
+
+ encoder_type encoder;
+ encoder.set_stream(out);
+
+ hsi_pixel pre, cur;
+ eem_type eem(encoder);
+ for (long r = 0; r < image.nr(); ++r)
+ {
+ for (long c = 0; c < image.nc(); ++c)
+ {
+ pre = predictor_hsi(image,r,c);
+ assign_pixel(cur, image[r][c]);
+
+ eem.encode((unsigned char)(cur.h - pre.h));
+ eem.encode((unsigned char)(cur.s - pre.s));
+ eem.encode((unsigned char)(cur.i - pre.i));
+ }
+ }
+ // write out the magic byte to mark the end of the data
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ eem.encode(dng_magic_byte);
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ inline typename disable_if<is_matrix<image_type> >::type save_dng (
+ const image_type& image,
+ std::ostream& out
+ )
+ {
+ using namespace dng_helpers_namespace;
+ save_dng_helper<image_type>::save_dng(image,out);
+ }
+
+ template <
+ typename EXP
+ >
+ inline void save_dng (
+ const matrix_exp<EXP>& image,
+ std::ostream& out
+ )
+ {
+ array2d<typename EXP::type> temp;
+ assign_image(temp, image);
+ using namespace dng_helpers_namespace;
+ save_dng_helper<array2d<typename EXP::type> >::save_dng(temp,out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void save_dng (
+ const image_type& image,
+ const std::string& file_name
+ )
+ {
+ std::ofstream fout(file_name.c_str(), std::ios::binary);
+ if (!fout)
+ throw image_save_error("Unable to open " + file_name + " for writing.");
+ save_dng(image, fout);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void save_bmp (
+ const image_type& image,
+ const std::string& file_name
+ )
+ {
+ std::ofstream fout(file_name.c_str(), std::ios::binary);
+ if (!fout)
+ throw image_save_error("Unable to open " + file_name + " for writing.");
+ save_bmp(image, fout);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_SAVEr_
+
+
+
+
diff --git a/ml/dlib/dlib/image_saver/image_saver_abstract.h b/ml/dlib/dlib/image_saver/image_saver_abstract.h
new file mode 100644
index 000000000..82f91ed45
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/image_saver_abstract.h
@@ -0,0 +1,129 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_IMAGE_SAVEr_ABSTRACT_
+#ifdef DLIB_IMAGE_SAVEr_ABSTRACT_
+
+#include <iosfwd>
+#include "../algs.h"
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ class image_save_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an exception used to indicate a failure to save an image.
+ Its type member variable will be set to EIMAGE_SAVE.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void save_bmp (
+ const image_type& image,
+ std::ostream& out
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any kind of matrix expression.
+ ensures
+ - writes the image to the out stream in the Microsoft Windows BMP format.
+ - image[0][0] will be in the upper left corner of the image.
+ - image[image.nr()-1][image.nc()-1] will be in the lower right
+ corner of the image.
+ - This routine can save images containing any type of pixel. However, it
+ will convert all color pixels into rgb_pixel and grayscale pixels into
+ uint8 type before saving to disk.
+ throws
+ - image_save_error
+ This exception is thrown if there is an error that prevents us
+ from saving the image.
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void save_bmp (
+ const image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any kind of matrix expression.
+ ensures
+ - opens the file indicated by file_name with an output file stream named fout
+ and performs:
+ save_bmp(image,fout);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ dlib dng file format:
+ This is a file format I created for this library. It is a lossless
+ compressed image format that is similar to the PNG format but uses
+ the dlib PPM compression algorithms instead of the DEFLATE algorithm.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void save_dng (
+ const image_type& image,
+ std::ostream& out
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any kind of matrix expression.
+ ensures
+ - writes the image to the out stream in the dlib dng format.
+ - image[0][0] will be in the upper left corner of the image.
+ - image[image.nr()-1][image.nc()-1] will be in the lower right
+ corner of the image.
+ - This routine can save images containing any type of pixel. However, the DNG
+ format can natively store only the following pixel types: rgb_pixel,
+ hsi_pixel, rgb_alpha_pixel, uint8, uint16, float, and double.
+ All other pixel types will be converted into one of these types as
+ appropriate before being saved to disk.
+ throws
+ - image_save_error
+ This exception is thrown if there is an error that prevents us
+ from saving the image.
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void save_dng (
+ const image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any kind of matrix expression.
+ ensures
+ - opens the file indicated by file_name with an output file stream named fout
+ and performs:
+ save_dng(image,fout);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_SAVEr_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/image_saver/save_jpeg.cpp b/ml/dlib/dlib/image_saver/save_jpeg.cpp
new file mode 100644
index 000000000..ef637fa7a
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_jpeg.cpp
@@ -0,0 +1,175 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net), Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_JPEG_SAVER_CPp_
+#define DLIB_JPEG_SAVER_CPp_
+
+// only do anything with this file if DLIB_JPEG_SUPPORT is defined
+#ifdef DLIB_JPEG_SUPPORT
+
+#include "../array2d.h"
+#include "../pixel.h"
+#include "save_jpeg.h"
+#include <stdio.h>
+#include <sstream>
+#include <setjmp.h>
+#include "image_saver.h"
+
+#ifdef DLIB_JPEG_STATIC
+# include "../external/libjpeg/jpeglib.h"
+#else
+# include <jpeglib.h>
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct jpeg_saver_error_mgr
+ {
+ jpeg_error_mgr pub; /* "public" fields */
+ jmp_buf setjmp_buffer; /* for return to caller */
+ };
+
+ void jpeg_saver_error_exit (j_common_ptr cinfo)
+ {
+ /* cinfo->err really points to a jpeg_saver_error_mgr struct, so coerce pointer */
+ jpeg_saver_error_mgr* myerr = (jpeg_saver_error_mgr*) cinfo->err;
+
+ /* Return control to the setjmp point */
+ longjmp(myerr->setjmp_buffer, 1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void save_jpeg (
+ const array2d<rgb_pixel>& img,
+ const std::string& filename,
+ int quality
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(img.size() != 0,
+ "\t save_jpeg()"
+ << "\n\t You can't save an empty image as a JPEG."
+ );
+ DLIB_CASSERT(0 <= quality && quality <= 100,
+ "\t save_jpeg()"
+ << "\n\t Invalid quality value."
+ << "\n\t quality: " << quality
+ );
+
+ FILE* outfile = fopen(filename.c_str(), "wb");
+ if (!outfile)
+ throw image_save_error("Can't open file " + filename + " for writing.");
+
+ jpeg_compress_struct cinfo;
+
+ jpeg_saver_error_mgr jerr;
+ cinfo.err = jpeg_std_error(&jerr.pub);
+ jerr.pub.error_exit = jpeg_saver_error_exit;
+ /* Establish the setjmp return context for my_error_exit to use. */
+ if (setjmp(jerr.setjmp_buffer))
+ {
+ /* If we get here, the JPEG code has signaled an error.
+ * We need to clean up the JPEG object, close the input file, and return.
+ */
+ jpeg_destroy_compress(&cinfo);
+ fclose(outfile);
+ throw image_save_error("save_jpeg: error while writing " + filename);
+ }
+
+ jpeg_create_compress(&cinfo);
+ jpeg_stdio_dest(&cinfo, outfile);
+
+ cinfo.image_width = img.nc();
+ cinfo.image_height = img.nr();
+ cinfo.input_components = 3;
+ cinfo.in_color_space = JCS_RGB;
+ jpeg_set_defaults(&cinfo);
+ jpeg_set_quality (&cinfo, quality, TRUE);
+ jpeg_start_compress(&cinfo, TRUE);
+
+ // now write out the rows one at a time
+ while (cinfo.next_scanline < cinfo.image_height) {
+ JSAMPROW row_pointer = (JSAMPROW) &img[cinfo.next_scanline][0];
+ jpeg_write_scanlines(&cinfo, &row_pointer, 1);
+ }
+
+ jpeg_finish_compress(&cinfo);
+ jpeg_destroy_compress(&cinfo);
+ fclose( outfile );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void save_jpeg (
+ const array2d<unsigned char>& img,
+ const std::string& filename,
+ int quality
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(img.size() != 0,
+ "\t save_jpeg()"
+ << "\n\t You can't save an empty image as a JPEG."
+ );
+ DLIB_CASSERT(0 <= quality && quality <= 100,
+ "\t save_jpeg()"
+ << "\n\t Invalid quality value."
+ << "\n\t quality: " << quality
+ );
+
+
+ FILE* outfile = fopen(filename.c_str(), "wb");
+ if (!outfile)
+ throw image_save_error("Can't open file " + filename + " for writing.");
+
+ jpeg_compress_struct cinfo;
+
+ jpeg_saver_error_mgr jerr;
+ cinfo.err = jpeg_std_error(&jerr.pub);
+ jerr.pub.error_exit = jpeg_saver_error_exit;
+ /* Establish the setjmp return context for my_error_exit to use. */
+ if (setjmp(jerr.setjmp_buffer))
+ {
+ /* If we get here, the JPEG code has signaled an error.
+ * We need to clean up the JPEG object, close the input file, and return.
+ */
+ jpeg_destroy_compress(&cinfo);
+ fclose(outfile);
+ throw image_save_error("save_jpeg: error while writing " + filename);
+ }
+
+ jpeg_create_compress(&cinfo);
+ jpeg_stdio_dest(&cinfo, outfile);
+
+ cinfo.image_width = img.nc();
+ cinfo.image_height = img.nr();
+ cinfo.input_components = 1;
+ cinfo.in_color_space = JCS_GRAYSCALE;
+ jpeg_set_defaults(&cinfo);
+ jpeg_set_quality (&cinfo, quality, TRUE);
+ jpeg_start_compress(&cinfo, TRUE);
+
+ // now write out the rows one at a time
+ while (cinfo.next_scanline < cinfo.image_height) {
+ JSAMPROW row_pointer = (JSAMPROW) &img[cinfo.next_scanline][0];
+ jpeg_write_scanlines(&cinfo, &row_pointer, 1);
+ }
+
+ jpeg_finish_compress(&cinfo);
+ jpeg_destroy_compress(&cinfo);
+ fclose( outfile );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_JPEG_SUPPORT
+
+#endif // DLIB_JPEG_SAVER_CPp_
+
+
+
diff --git a/ml/dlib/dlib/image_saver/save_jpeg.h b/ml/dlib/dlib/image_saver/save_jpeg.h
new file mode 100644
index 000000000..fb1808c44
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_jpeg.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SAVE_JPEG_Hh_
+#define DLIB_SAVE_JPEG_Hh_
+
+#include "save_jpeg_abstract.h"
+
+#include "../enable_if.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+#include <string>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void save_jpeg (
+ const array2d<rgb_pixel>& img,
+ const std::string& filename,
+ int quality = 75
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void save_jpeg (
+ const array2d<unsigned char>& img,
+ const std::string& filename,
+ int quality = 75
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ typename disable_if<is_matrix<image_type> >::type save_jpeg(
+ const image_type& img,
+ const std::string& filename,
+ int quality = 75
+ )
+ {
+ // Convert any kind of grayscale image to an unsigned char image
+ if (pixel_traits<typename image_traits<image_type>::pixel_type>::grayscale)
+ {
+ array2d<unsigned char> temp;
+ assign_image(temp, img);
+ save_jpeg(temp, filename, quality);
+ }
+ else
+ {
+ // This is some other kind of color image so just save it as an RGB image.
+ array2d<rgb_pixel> temp;
+ assign_image(temp, img);
+ save_jpeg(temp, filename, quality);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ void save_jpeg(
+ const matrix_exp<EXP>& img,
+ const std::string& file_name,
+ int quality = 75
+ )
+ {
+ array2d<typename EXP::type> temp;
+ assign_image(temp, img);
+ save_jpeg(temp, file_name, quality);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SAVE_JPEG_Hh_
+
diff --git a/ml/dlib/dlib/image_saver/save_jpeg_abstract.h b/ml/dlib/dlib/image_saver/save_jpeg_abstract.h
new file mode 100644
index 000000000..f441339b9
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_jpeg_abstract.h
@@ -0,0 +1,52 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SAVE_JPEG_ABSTRACT_Hh_
+#ifdef DLIB_SAVE_JPEG_ABSTRACT_Hh_
+
+#include "../image_processing/generic_image.h"
+#include "../pixel.h"
+#include <string>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void save_jpeg (
+ const image_type& img,
+ const std::string& filename,
+ int quality = 75
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or a matrix expression
+ - image.size() != 0
+ - 0 <= quality <= 100
+ ensures
+ - writes the image to the file indicated by file_name in the JPEG format.
+ - image[0][0] will be in the upper left corner of the image.
+ - image[image.nr()-1][image.nc()-1] will be in the lower right corner of the
+ image.
+ - This routine can save images containing any type of pixel. However,
+ save_jpeg() can only natively store rgb_pixel and uint8 pixel types. All
+ other pixel types will be converted into one of these types as appropriate
+ before being saved to disk.
+ - The quality value determines how lossy the compression is. Larger quality
+ values result in larger output images but the images will look better.
+ throws
+ - image_save_error
+ This exception is thrown if there is an error that prevents us from saving
+ the image.
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SAVE_JPEG_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_saver/save_png.cpp b/ml/dlib/dlib/image_saver/save_png.cpp
new file mode 100644
index 000000000..1c96b929c
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_png.cpp
@@ -0,0 +1,124 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SAVE_PnG_CPPh_
+#define DLIB_SAVE_PnG_CPPh_
+
+// only do anything with this file if DLIB_PNG_SUPPORT is defined
+#ifdef DLIB_PNG_SUPPORT
+
+#include "save_png.h"
+#include <cstdio>
+#include <png.h>
+#include "../byte_orderer.h"
+
+namespace dlib
+{
+ // Don't do anything when libpng calls us to tell us about an error. Just return to
+ // our own code and throw an exception (at the long jump target).
+ void png_reader_user_error_fn_silent(png_structp png_struct, png_const_charp )
+ {
+ longjmp(png_jmpbuf(png_struct),1);
+ }
+ void png_reader_user_warning_fn_silent(png_structp , png_const_charp )
+ {
+ }
+
+ namespace impl
+ {
+ void impl_save_png (
+ const std::string& file_name,
+ std::vector<unsigned char*>& row_pointers,
+ const long width,
+ const png_type type,
+ const int bit_depth
+ )
+ {
+
+ FILE *fp;
+ png_structp png_ptr;
+ png_infop info_ptr;
+
+ /* Open the file */
+ fp = fopen(file_name.c_str(), "wb");
+ if (fp == NULL)
+ throw image_save_error("Unable to open " + file_name + " for writing.");
+
+ /* Create and initialize the png_struct with the desired error handler
+ * functions. If you want to use the default stderr and longjump method,
+ * you can supply NULL for the last three parameters. We also check that
+ * the library version is compatible with the one used at compile time,
+ * in case we are using dynamically linked libraries. REQUIRED.
+ */
+ png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, &png_reader_user_error_fn_silent, &png_reader_user_warning_fn_silent);
+
+ if (png_ptr == NULL)
+ {
+ fclose(fp);
+ throw image_save_error("Error while writing PNG file " + file_name);
+ }
+
+ /* Allocate/initialize the image information data. REQUIRED */
+ info_ptr = png_create_info_struct(png_ptr);
+ if (info_ptr == NULL)
+ {
+ fclose(fp);
+ png_destroy_write_struct(&png_ptr, NULL);
+ throw image_save_error("Error while writing PNG file " + file_name);
+ }
+
+ /* Set error handling. REQUIRED if you aren't supplying your own
+ * error handling functions in the png_create_write_struct() call.
+ */
+ if (setjmp(png_jmpbuf(png_ptr)))
+ {
+ /* If we get here, we had a problem writing the file */
+ fclose(fp);
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ throw image_save_error("Error while writing PNG file " + file_name);
+ }
+
+ int color_type = 0;
+ switch(type)
+ {
+ case png_type_rgb: color_type = PNG_COLOR_TYPE_RGB; break;
+ case png_type_rgb_alpha: color_type = PNG_COLOR_TYPE_RGB_ALPHA; break;
+ case png_type_gray: color_type = PNG_COLOR_TYPE_GRAY; break;
+ default:
+ {
+ fclose(fp);
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ throw image_save_error("Invalid color type");
+ }
+ }
+
+
+ /* Set up the output control if you are using standard C streams */
+ png_init_io(png_ptr, fp);
+
+
+ int png_transforms = PNG_TRANSFORM_IDENTITY;
+ byte_orderer bo;
+ if (bo.host_is_little_endian())
+ png_transforms |= PNG_TRANSFORM_SWAP_ENDIAN;
+
+ const long height = row_pointers.size();
+
+
+ png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type, PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, PNG_FILTER_TYPE_DEFAULT);
+ png_set_rows(png_ptr, info_ptr, &row_pointers[0]);
+ png_write_png(png_ptr, info_ptr, png_transforms, NULL);
+
+ /* Clean up after the write, and free any memory allocated */
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+
+ /* Close the file */
+ fclose(fp);
+ }
+ }
+}
+
+#endif // DLIB_PNG_SUPPORT
+
+#endif // DLIB_SAVE_PnG_CPPh_
+
+
diff --git a/ml/dlib/dlib/image_saver/save_png.h b/ml/dlib/dlib/image_saver/save_png.h
new file mode 100644
index 000000000..cddf03ff6
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_png.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SAVE_PnG_Hh_
+#define DLIB_SAVE_PnG_Hh_
+
+#include "save_png_abstract.h"
+#include "image_saver.h"
+#include "../array2d.h"
+#include <vector>
+#include <string>
+#include "../pixel.h"
+#include "../matrix/matrix_exp.h"
+#include "../image_transforms/assign_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ enum png_type
+ {
+ png_type_rgb,
+ png_type_rgb_alpha,
+ png_type_gray,
+ };
+
+ void impl_save_png (
+ const std::string& file_name,
+ std::vector<unsigned char*>& row_pointers,
+ const long width,
+ const png_type type,
+ const int bit_depth
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ typename disable_if<is_matrix<image_type> >::type save_png(
+ const image_type& img_,
+ const std::string& file_name
+ )
+ {
+ const_image_view<image_type> img(img_);
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(img.size() != 0,
+ "\t save_png()"
+ << "\n\t You can't save an empty image as a PNG"
+ );
+
+
+#ifndef DLIB_PNG_SUPPORT
+ /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ You are getting this error because you are trying to use save_png()
+ but you haven't defined DLIB_PNG_SUPPORT. You must do so to use
+ this function. You must also make sure you set your build environment
+ to link against the libpng library.
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
+ COMPILE_TIME_ASSERT(sizeof(image_type) == 0);
+#else
+ std::vector<unsigned char*> row_pointers(img.nr());
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ if (is_same_type<rgb_pixel,pixel_type>::value)
+ {
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb, 8);
+ }
+ else if (is_same_type<rgb_alpha_pixel,pixel_type>::value)
+ {
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb_alpha, 8);
+ }
+ else if (pixel_traits<pixel_type>::lab || pixel_traits<pixel_type>::hsi || pixel_traits<pixel_type>::rgb)
+ {
+ // convert from Lab or HSI to RGB (Or potentially RGB pixels that aren't laid out as R G B)
+ array2d<rgb_pixel> temp_img;
+ assign_image(temp_img, img_);
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&temp_img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb, 8);
+ }
+ else if (pixel_traits<pixel_type>::rgb_alpha)
+ {
+ // convert from RGBA pixels that aren't laid out as R G B A
+ array2d<rgb_alpha_pixel> temp_img;
+ assign_image(temp_img, img_);
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&temp_img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb_alpha, 8);
+ }
+ else // this is supposed to be grayscale
+ {
+ DLIB_CASSERT(pixel_traits<pixel_type>::grayscale, "impossible condition detected");
+
+ if (pixel_traits<pixel_type>::is_unsigned && sizeof(pixel_type) == 1)
+ {
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 8);
+ }
+ else if (pixel_traits<pixel_type>::is_unsigned && sizeof(pixel_type) == 2)
+ {
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 16);
+ }
+ else
+ {
+ // convert from whatever this is to 16bit grayscale
+ array2d<dlib::uint16> temp_img;
+ assign_image(temp_img, img_);
+ for (unsigned long i = 0; i < row_pointers.size(); ++i)
+ row_pointers[i] = (unsigned char*)(&temp_img[i][0]);
+
+ impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 16);
+ }
+ }
+
+
+#endif
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ void save_png(
+ const matrix_exp<EXP>& img,
+ const std::string& file_name
+ )
+ {
+ array2d<typename EXP::type> temp;
+ assign_image(temp, img);
+ save_png(temp, file_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "save_png.cpp"
+#endif
+
+#endif // DLIB_SAVE_PnG_Hh_
+
diff --git a/ml/dlib/dlib/image_saver/save_png_abstract.h b/ml/dlib/dlib/image_saver/save_png_abstract.h
new file mode 100644
index 000000000..ae495d1f2
--- /dev/null
+++ b/ml/dlib/dlib/image_saver/save_png_abstract.h
@@ -0,0 +1,50 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SAVE_PnG_ABSTRACT_
+#ifdef DLIB_SAVE_PnG_ABSTRACT_
+
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void save_png (
+ const image_type& image,
+ const std::string& file_name
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or a matrix expression
+ - image.size() != 0
+ ensures
+ - writes the image to the file indicated by file_name in the PNG (Portable Network Graphics)
+ format.
+ - image[0][0] will be in the upper left corner of the image.
+ - image[image.nr()-1][image.nc()-1] will be in the lower right
+ corner of the image.
+ - This routine can save images containing any type of pixel. However, save_png() can
+ only natively store the following pixel types: rgb_pixel, rgb_alpha_pixel, uint8,
+ and uint16. All other pixel types will be converted into one of these types as
+ appropriate before being saved to disk.
+ throws
+ - image_save_error
+ This exception is thrown if there is an error that prevents us from saving
+ the image.
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SAVE_PnG_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/image_transforms.h b/ml/dlib/dlib/image_transforms.h
new file mode 100644
index 000000000..89b4e0db6
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms.h
@@ -0,0 +1,31 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_IMAGE_TRANSFORMs_
+#define DLIB_IMAGE_TRANSFORMs_
+
+#include "image_transforms/assign_image.h"
+#include "image_transforms/equalize_histogram.h"
+#include "image_transforms/morphological_operations.h"
+#include "image_transforms/spatial_filtering.h"
+#include "image_transforms/thresholding.h"
+#include "image_transforms/edge_detector.h"
+#include "image_transforms/draw.h"
+#include "image_transforms/integral_image.h"
+#include "image_transforms/image_pyramid.h"
+#include "image_transforms/hough_transform.h"
+#include "image_transforms/label_connected_blobs.h"
+#include "image_transforms/colormaps.h"
+#include "image_transforms/segment_image.h"
+#include "image_transforms/interpolation.h"
+#include "image_transforms/fhog.h"
+#include "image_transforms/lbp.h"
+#include "image_transforms/random_color_transform.h"
+#include "image_transforms/random_cropper.h"
+
+#endif // DLIB_IMAGE_TRANSFORMs_
+
diff --git a/ml/dlib/dlib/image_transforms/assign_image.h b/ml/dlib/dlib/image_transforms/assign_image.h
new file mode 100644
index 000000000..c69878efa
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/assign_image.h
@@ -0,0 +1,385 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ASSIGN_IMAGe_
+#define DLIB_ASSIGN_IMAGe_
+
+#include "../pixel.h"
+#include "assign_image_abstract.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void impl_assign_image (
+ image_view<dest_image_type>& dest,
+ const src_image_type& src
+ )
+ {
+ dest.set_size(src.nr(),src.nc());
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ assign_pixel(dest[r][c], src(r,c));
+ }
+ }
+ }
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void impl_assign_image (
+ dest_image_type& dest_,
+ const src_image_type& src
+ )
+ {
+ image_view<dest_image_type> dest(dest_);
+ impl_assign_image(dest, src);
+ }
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void assign_image (
+ dest_image_type& dest,
+ const src_image_type& src
+ )
+ {
+ // check for the case where dest is the same object as src
+ if (is_same_object(dest,src))
+ return;
+
+ impl_assign_image(dest, mat(src));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void impl_assign_image_scaled (
+ image_view<dest_image_type>& dest,
+ const src_image_type& src,
+ const double thresh
+ )
+ {
+ DLIB_ASSERT( thresh > 0,
+ "\tvoid assign_image_scaled()"
+ << "\n\t You have given an threshold value"
+ << "\n\t thresh: " << thresh
+ );
+
+
+ typedef typename image_traits<dest_image_type>::pixel_type dest_pixel;
+
+ // If the destination has a dynamic range big enough to contain the source image data then just do a
+ // regular assign_image()
+ if (pixel_traits<dest_pixel>::max() >= pixel_traits<typename src_image_type::type>::max() &&
+ pixel_traits<dest_pixel>::min() <= pixel_traits<typename src_image_type::type>::min() )
+ {
+ impl_assign_image(dest, src);
+ return;
+ }
+
+ dest.set_size(src.nr(),src.nc());
+
+ if (src.size() == 0)
+ return;
+
+ if (src.size() == 1)
+ {
+ impl_assign_image(dest, src);
+ return;
+ }
+
+ // gather image statistics
+ running_stats<double> rs;
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ rs.add(get_pixel_intensity(src(r,c)));
+ }
+ }
+ typedef typename pixel_traits<typename src_image_type::type>::basic_pixel_type spix_type;
+
+ if (std::numeric_limits<spix_type>::is_integer)
+ {
+ // If the destination has a dynamic range big enough to contain the source image data then just do a
+ // regular assign_image()
+ if (pixel_traits<dest_pixel>::max() >= rs.max() &&
+ pixel_traits<dest_pixel>::min() <= rs.min() )
+ {
+ impl_assign_image(dest, src);
+ return;
+ }
+ }
+
+ // Figure out the range of pixel values based on image statistics. There might be some huge
+ // outliers so don't just pick the min and max values.
+ const double upper = std::min(rs.mean() + thresh*rs.stddev(), rs.max());
+ const double lower = std::max(rs.mean() - thresh*rs.stddev(), rs.min());
+
+
+ const double dest_min = pixel_traits<dest_pixel>::min();
+ const double dest_max = pixel_traits<dest_pixel>::max();
+
+ const double scale = (upper!=lower)? ((dest_max - dest_min) / (upper - lower)) : 0;
+
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ const double val = get_pixel_intensity(src(r,c)) - lower;
+
+ assign_pixel(dest[r][c], scale*val + dest_min);
+ }
+ }
+ }
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void impl_assign_image_scaled (
+ dest_image_type& dest_,
+ const src_image_type& src,
+ const double thresh
+ )
+ {
+ image_view<dest_image_type> dest(dest_);
+ impl_assign_image_scaled(dest, src, thresh);
+ }
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void assign_image_scaled (
+ dest_image_type& dest,
+ const src_image_type& src,
+ const double thresh = 4
+ )
+ {
+ // check for the case where dest is the same object as src
+ if (is_same_object(dest,src))
+ return;
+
+ impl_assign_image_scaled(dest, mat(src),thresh);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_pixel_type
+ >
+ void assign_all_pixels (
+ image_view<dest_image_type>& dest_img,
+ const src_pixel_type& src_pixel
+ )
+ {
+ for (long r = 0; r < dest_img.nr(); ++r)
+ {
+ for (long c = 0; c < dest_img.nc(); ++c)
+ {
+ assign_pixel(dest_img[r][c], src_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_pixel_type
+ >
+ void assign_all_pixels (
+ dest_image_type& dest_img_,
+ const src_pixel_type& src_pixel
+ )
+ {
+ image_view<dest_image_type> dest_img(dest_img_);
+ assign_all_pixels(dest_img, src_pixel);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void assign_border_pixels (
+ image_view<image_type>& img,
+ long x_border_size,
+ long y_border_size,
+ const typename image_traits<image_type>::pixel_type& p
+ )
+ {
+ DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0,
+ "\tvoid assign_border_pixels(img, p, border_size)"
+ << "\n\tYou have given an invalid border_size"
+ << "\n\tx_border_size: " << x_border_size
+ << "\n\ty_border_size: " << y_border_size
+ );
+
+ y_border_size = std::min(y_border_size, img.nr()/2+1);
+ x_border_size = std::min(x_border_size, img.nc()/2+1);
+
+ // assign the top border
+ for (long r = 0; r < y_border_size; ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = p;
+ }
+ }
+
+ // assign the bottom border
+ for (long r = img.nr()-y_border_size; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = p;
+ }
+ }
+
+ // now assign the two sides
+ for (long r = y_border_size; r < img.nr()-y_border_size; ++r)
+ {
+ // left border
+ for (long c = 0; c < x_border_size; ++c)
+ img[r][c] = p;
+
+ // right border
+ for (long c = img.nc()-x_border_size; c < img.nc(); ++c)
+ img[r][c] = p;
+ }
+ }
+
+ template <
+ typename image_type
+ >
+ void assign_border_pixels (
+ image_type& img_,
+ long x_border_size,
+ long y_border_size,
+ const typename image_traits<image_type>::pixel_type& p
+ )
+ {
+ image_view<image_type> img(img_);
+ assign_border_pixels(img, x_border_size, y_border_size, p);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_type& img,
+ long x_border_size,
+ long y_border_size
+ )
+ {
+ DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0,
+ "\tvoid zero_border_pixels(img, p, border_size)"
+ << "\n\tYou have given an invalid border_size"
+ << "\n\tx_border_size: " << x_border_size
+ << "\n\ty_border_size: " << y_border_size
+ );
+
+ typename image_traits<image_type>::pixel_type zero_pixel;
+ assign_pixel_intensity(zero_pixel, 0);
+ assign_border_pixels(img, x_border_size, y_border_size, zero_pixel);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_view<image_type>& img,
+ long x_border_size,
+ long y_border_size
+ )
+ {
+ DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0,
+ "\tvoid zero_border_pixels(img, p, border_size)"
+ << "\n\tYou have given an invalid border_size"
+ << "\n\tx_border_size: " << x_border_size
+ << "\n\ty_border_size: " << y_border_size
+ );
+
+ typename image_traits<image_type>::pixel_type zero_pixel;
+ assign_pixel_intensity(zero_pixel, 0);
+ assign_border_pixels(img, x_border_size, y_border_size, zero_pixel);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_view<image_type>& img,
+ rectangle inside
+ )
+ {
+ inside = inside.intersect(get_rect(img));
+ if (inside.is_empty())
+ {
+ assign_all_pixels(img, 0);
+ return;
+ }
+
+ for (long r = 0; r < inside.top(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ assign_pixel(img[r][c], 0);
+ }
+ for (long r = inside.top(); r <= inside.bottom(); ++r)
+ {
+ for (long c = 0; c < inside.left(); ++c)
+ assign_pixel(img[r][c], 0);
+ for (long c = inside.right()+1; c < img.nc(); ++c)
+ assign_pixel(img[r][c], 0);
+ }
+ for (long r = inside.bottom()+1; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ assign_pixel(img[r][c], 0);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_type& img_,
+ const rectangle& inside
+ )
+ {
+ image_view<image_type> img(img_);
+ zero_border_pixels(img, inside);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ASSIGN_IMAGe_
+
+
+
diff --git a/ml/dlib/dlib/image_transforms/assign_image_abstract.h b/ml/dlib/dlib/image_transforms/assign_image_abstract.h
new file mode 100644
index 000000000..5ba262ba5
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/assign_image_abstract.h
@@ -0,0 +1,196 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ASSIGN_IMAGe_ABSTRACT
+#ifdef DLIB_ASSIGN_IMAGe_ABSTRACT
+
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void assign_image (
+ dest_image_type& dest_img,
+ const src_image_type& src_img
+ );
+ /*!
+ requires
+ - src_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any object convertible to a matrix
+ via mat().
+ - dest_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view.
+ ensures
+ - #dest_img.nc() == src_img.nc()
+ - #dest_img.nr() == src_img.nr()
+ - for all valid r and c:
+ - performs assign_pixel(#dest_img[r][c],src_img[r][c])
+ (i.e. copies the src image to dest image)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_image_type
+ >
+ void assign_image_scaled (
+ dest_image_type& dest_img,
+ const src_image_type& src_img,
+ const double thresh = 4
+ );
+ /*!
+ requires
+ - src_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or any object convertible to a matrix
+ via mat().
+ - dest_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view.
+ - thresh > 0
+ ensures
+ - #dest_img.nc() == src_img.nc()
+ - #dest_img.nr() == src_img.nr()
+ - if (dest_img's pixels have a wide enough dynamic range to contain all the
+ pixels in src_img. (Note that dynamic range is determined by the min() and
+ max() pixel_traits properties)) then
+ - performs: assign_image(dest_img, src_img)
+ (i.e. in this case, no scaling is performed. Just a normal color space
+ conversion and copy )
+ - else
+ - #dest_img will be converted to a grayscale image
+ - scales the contents of src_img into the dynamic range of dest_img and then
+ assigns the result into dest_img. The thresh parameter is used to filter
+ source pixel values which are outliers. These outliers will saturate
+ at the edge of the destination image's dynamic range.
+ - Specifically, for all valid r and c:
+ - scales get_pixel_intensity(src_img[r][c]) into the dynamic range
+ of the dest_img. This is done by computing the mean and standard
+ deviation of src_img. Call the mean M and the standard deviation
+ D. Then the scaling from src_img to dest_img is performed using
+ the following mapping:
+ let SRC_UPPER = min(M + thresh*D, max(mat(src_img)))
+ let SRC_LOWER = max(M - thresh*D, min(mat(src_img)))
+ let DEST_UPPER = pixel_traits<image_traits<dest_image_type>::pixel_type>::max()
+ let DEST_LOWER = pixel_traits<image_traits<dest_image_type>::pixel_type>::min()
+
+ MAPPING: [SRC_LOWER, SRC_UPPER] -> [DEST_LOWER, DEST_UPPER]
+
+ Where this mapping is a linear mapping of values from the left range
+ into the right range of values. Source pixel values outside the left
+ range are modified to be at the appropriate end of the range.
+
+ The scaled pixel is then stored in dest_img[r][c].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dest_image_type,
+ typename src_pixel_type
+ >
+ void assign_all_pixels (
+ dest_image_type& dest_img,
+ const src_pixel_type& src_pixel
+ );
+ /*!
+ requires
+ - dest_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view.
+ - pixel_traits<src_pixel_type> is defined
+ ensures
+ - #dest_img.nc() == dest_img.nc()
+ - #dest_img.nr() == dest_img.nr()
+ (i.e. the size of dest_img isn't changed by this function)
+ - for all valid r and c:
+ - performs assign_pixel(#dest_img[r][c],src_pixel)
+ (i.e. assigns the src pixel to every pixel in the dest image)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void assign_border_pixels (
+ image_type& img,
+ long x_border_size,
+ long y_border_size,
+ const typename image_traits<image_type>::pixel_type& p
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view
+ - x_border_size >= 0
+ - y_border_size >= 0
+ ensures
+ - #img.nc() == img.nc()
+ - #img.nr() == img.nr()
+ (i.e. the size of img isn't changed by this function)
+ - for all valid r such that r+y_border_size or r-y_border_size gives an invalid row
+ - for all valid c such that c+x_border_size or c-x_border_size gives an invalid column
+ - performs assign_pixel(#img[r][c],p)
+ (i.e. assigns the given pixel to every pixel in the border of img)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_type& img,
+ long x_border_size,
+ long y_border_size
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view
+ - x_border_size >= 0
+ - y_border_size >= 0
+ ensures
+ - #img.nc() == img.nc()
+ - #img.nr() == img.nr()
+ (i.e. the size of img isn't changed by this function)
+ - for all valid r such that r+y_border_size or r-y_border_size gives an invalid row
+ - for all valid c such that c+x_border_size or c-x_border_size gives an invalid column
+ - performs assign_pixel(#img[r][c], 0 )
+ (i.e. assigns 0 to every pixel in the border of img)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void zero_border_pixels (
+ image_type& img,
+ rectangle inside
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or an image_view
+ ensures
+ - #img.nc() == img.nc()
+ - #img.nr() == img.nr()
+ (i.e. the size of img isn't changed by this function)
+ - All the pixels in img that are not contained inside the inside rectangle
+ given to this function are set to 0. That is, anything not "inside" is on
+ the border and set to 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ASSIGN_IMAGe_ABSTRACT
+
+
diff --git a/ml/dlib/dlib/image_transforms/colormaps.h b/ml/dlib/dlib/image_transforms/colormaps.h
new file mode 100644
index 000000000..813d1ff75
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/colormaps.h
@@ -0,0 +1,269 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANDOMLY_COlOR_IMAGE_Hh_
+#define DLIB_RANDOMLY_COlOR_IMAGE_Hh_
+
+#include "colormaps_abstract.h"
+#include "../hash.h"
+#include "../pixel.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_randomly_color_image : does_not_alias
+ {
+ op_randomly_color_image( const T& img_) : img(img_){}
+
+ const T& img;
+
+ const static long cost = 7;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef rgb_pixel type;
+ typedef const rgb_pixel const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const
+ {
+ const unsigned long gray = get_pixel_intensity(mat(img)(r,c));
+ if (gray != 0)
+ {
+ const uint32 h = murmur_hash3_2(gray,0);
+ rgb_pixel pix;
+ pix.red = static_cast<unsigned char>(h)%200 + 55;
+ pix.green = static_cast<unsigned char>(h>>8)%200 + 55;
+ pix.blue = static_cast<unsigned char>(h>>16)%200 + 55;
+ return pix;
+ }
+ else
+ {
+ // keep black pixels black
+ return rgb_pixel(0,0,0);
+ }
+ }
+
+ long nr () const { return num_rows(img); }
+ long nc () const { return num_columns(img); }
+ };
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_randomly_color_image<image_type> >
+ randomly_color_image (
+ const image_type& img
+ )
+ {
+ typedef op_randomly_color_image<image_type> op;
+ return matrix_op<op>(op(img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rgb_pixel colormap_heat (
+ double value,
+ double min_val,
+ double max_val
+ )
+ {
+ // scale the gray value into the range [0, 1]
+ const double gray = put_in_range(0, 1, (value - min_val)/(max_val-min_val));
+ rgb_pixel pix(0,0,0);
+
+ pix.red = static_cast<unsigned char>(std::min(gray/0.4,1.0)*255 + 0.5);
+
+ if (gray > 0.4)
+ {
+ pix.green = static_cast<unsigned char>(std::min((gray-0.4)/0.4,1.0)*255 + 0.5);
+ }
+ if (gray > 0.8)
+ {
+ pix.blue = static_cast<unsigned char>(std::min((gray-0.8)/0.2,1.0)*255 + 0.5);
+ }
+
+ return pix;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_heatmap : does_not_alias
+ {
+ op_heatmap(
+ const T& img_,
+ const double max_val_,
+ const double min_val_
+ ) : img(img_), max_val(max_val_), min_val(min_val_){}
+
+ const T& img;
+
+ const double max_val;
+ const double min_val;
+
+ const static long cost = 7;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef rgb_pixel type;
+ typedef const rgb_pixel const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const
+ {
+ return colormap_heat(get_pixel_intensity(mat(img)(r,c)), min_val, max_val);
+ }
+
+ long nr () const { return num_rows(img); }
+ long nc () const { return num_columns(img); }
+ };
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_heatmap<image_type> >
+ heatmap (
+ const image_type& img,
+ double max_val,
+ double min_val = 0
+ )
+ {
+ typedef op_heatmap<image_type> op;
+ return matrix_op<op>(op(img,max_val,min_val));
+ }
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_heatmap<image_type> >
+ heatmap (
+ const image_type& img
+ )
+ {
+ typedef op_heatmap<image_type> op;
+ if (num_columns(img) * num_rows(img) != 0)
+ return matrix_op<op>(op(img,max(mat(img)),min(mat(img))));
+ else
+ return matrix_op<op>(op(img,0,0));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rgb_pixel colormap_jet (
+ double value,
+ double min_val,
+ double max_val
+ )
+ {
+ // scale the gray value into the range [0, 8]
+ const double gray = 8*put_in_range(0, 1, (value - min_val)/(max_val-min_val));
+ rgb_pixel pix;
+ // s is the slope of color change
+ const double s = 1.0/2.0;
+
+ if (gray <= 1)
+ {
+ pix.red = 0;
+ pix.green = 0;
+ pix.blue = static_cast<unsigned char>((gray+1)*s*255 + 0.5);
+ }
+ else if (gray <= 3)
+ {
+ pix.red = 0;
+ pix.green = static_cast<unsigned char>((gray-1)*s*255 + 0.5);
+ pix.blue = 255;
+ }
+ else if (gray <= 5)
+ {
+ pix.red = static_cast<unsigned char>((gray-3)*s*255 + 0.5);
+ pix.green = 255;
+ pix.blue = static_cast<unsigned char>((5-gray)*s*255 + 0.5);
+ }
+ else if (gray <= 7)
+ {
+ pix.red = 255;
+ pix.green = static_cast<unsigned char>((7-gray)*s*255 + 0.5);
+ pix.blue = 0;
+ }
+ else
+ {
+ pix.red = static_cast<unsigned char>((9-gray)*s*255 + 0.5);
+ pix.green = 0;
+ pix.blue = 0;
+ }
+
+ return pix;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_jet : does_not_alias
+ {
+ op_jet(
+ const T& img_,
+ const double max_val_,
+ const double min_val_
+ ) : img(img_), max_val(max_val_), min_val(min_val_){}
+
+ const T& img;
+
+ const double max_val;
+ const double min_val;
+
+ const static long cost = 7;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef rgb_pixel type;
+ typedef const rgb_pixel const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const
+ {
+ return colormap_jet(get_pixel_intensity(mat(img)(r,c)), min_val, max_val);
+ }
+
+ long nr () const { return num_rows(img); }
+ long nc () const { return num_columns(img); }
+ };
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_jet<image_type> >
+ jet (
+ const image_type& img,
+ double max_val,
+ double min_val = 0
+ )
+ {
+ typedef op_jet<image_type> op;
+ return matrix_op<op>(op(img,max_val,min_val));
+ }
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_jet<image_type> >
+ jet (
+ const image_type& img
+ )
+ {
+ typedef op_jet<image_type> op;
+ if (num_columns(img) * num_rows(img) != 0)
+ return matrix_op<op>(op(img,max(mat(img)),min(mat(img))));
+ else
+ return matrix_op<op>(op(img,0,0));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOMLY_COlOR_IMAGE_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/colormaps_abstract.h b/ml/dlib/dlib/image_transforms/colormaps_abstract.h
new file mode 100644
index 000000000..41a7784ba
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/colormaps_abstract.h
@@ -0,0 +1,152 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_
+#ifdef DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_
+
+#include "../hash.h"
+#include "../pixel.h"
+#include "../matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ const matrix_exp randomly_color_image (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h, or something convertible to a matrix
+ via mat().
+ ensures
+ - randomly generates a mapping from gray level pixel values
+ to the RGB pixel space and then uses this mapping to create
+ a colored version of img. Returns a matrix which represents
+ this colored version of img.
+ - black pixels in img will remain black in the output image.
+ - The returned matrix will have the same dimensions as img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ rgb_pixel colormap_heat (
+ double value,
+ double min_val,
+ double max_val
+ );
+ /*!
+ requires
+ - min_val <= max_val
+ ensures
+ - Maps value to a color. In particular, we use a heatmap color scheme where
+ values <= min_val are black and larger values become more red, then yellow,
+ and then white as they approach max_val.
+ !*/
+
+ template <
+ typename image_type
+ >
+ const matrix_exp heatmap (
+ const image_type& img,
+ double max_val,
+ double min_val = 0
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h, or something convertible to a matrix
+ via mat().
+ ensures
+ - Interprets img as a grayscale image and returns a new matrix which represents
+ a colored version of img. In particular, the colormap is defined by
+ out_color = colormap_heat(grayscale_pixel_value, min_val, max_val).
+ - The returned matrix will have the same dimensions as img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ const matrix_exp heatmap (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h, or something convertible to a matrix
+ via mat().
+ ensures
+ - returns heatmap(img, max(mat(img)), min(mat(img)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ rgb_pixel colormap_jet (
+ double value,
+ double min_val,
+ double max_val
+ );
+ /*!
+ requires
+ - min_val <= max_val
+ ensures
+ - Maps value to a color. In particular, we use a jet color scheme where
+ values <= min_val are dark blue and larger values become light blue, then
+ yellow, and then finally red as they approach max_val.
+ !*/
+
+ template <
+ typename image_type
+ >
+ const matrix_exp jet (
+ const image_type& img,
+ double max_val,
+ double min_val = 0
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h, or something convertible to a matrix
+ via mat().
+ ensures
+ - Interprets img as a grayscale image and returns a new matrix which represents
+ a colored version of img. In particular, the colormap is defined by
+ out_color = colormap_jet(grayscale_pixel_value, min_val, max_val).
+ - The returned matrix will have the same dimensions as img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ const matrix_exp jet (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h, or something convertible to a matrix
+ via mat().
+ ensures
+ - returns jet(img, max(mat(img)), min(mat(img)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_transforms/draw.h b/ml/dlib/dlib/image_transforms/draw.h
new file mode 100644
index 000000000..66737b215
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/draw.h
@@ -0,0 +1,396 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DRAW_IMAGe_
+#define DLIB_DRAW_IMAGe_
+
+#include "draw_abstract.h"
+#include "../algs.h"
+#include "../pixel.h"
+#include "../matrix.h"
+#include <cmath>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_line (
+ long x1,
+ long y1,
+ long x2,
+ long y2,
+ image_type& c_,
+ const pixel_type& val
+ )
+ {
+ image_view<image_type> c(c_);
+ if (x1 == x2)
+ {
+ // make sure y1 comes before y2
+ if (y1 > y2)
+ swap(y1,y2);
+
+ if (x1 < 0 || x1 >= c.nc())
+ return;
+
+
+ // this is a vertical line
+ for (long y = y1; y <= y2; ++y)
+ {
+ if (y < 0 || y >= c.nr())
+ continue;
+
+ assign_pixel(c[y][x1], val);
+ }
+ }
+ else if (y1 == y2)
+ {
+
+ // make sure x1 comes before x2
+ if (x1 > x2)
+ swap(x1,x2);
+
+ if (y1 < 0 || y1 >= c.nr())
+ return;
+
+ // this is a horizontal line
+ for (long x = x1; x <= x2; ++x)
+ {
+ if (x < 0 || x >= c.nc())
+ continue;
+
+ assign_pixel(c[y1][x] , val);
+ }
+ }
+ else
+ {
+ // This part is a little more complicated because we are going to perform alpha
+ // blending so the diagonal lines look nice.
+ const rectangle valid_area = get_rect(c);
+ rgb_alpha_pixel alpha_pixel;
+ assign_pixel(alpha_pixel, val);
+ const unsigned char max_alpha = alpha_pixel.alpha;
+
+ const long rise = (((long)y2) - ((long)y1));
+ const long run = (((long)x2) - ((long)x1));
+ if (std::abs(rise) < std::abs(run))
+ {
+ const double slope = ((double)rise)/run;
+
+
+ double first, last;
+
+
+ if (x1 > x2)
+ {
+ first = std::max(x2,valid_area.left());
+ last = std::min(x1,valid_area.right());
+ }
+ else
+ {
+ first = std::max(x1,valid_area.left());
+ last = std::min(x2,valid_area.right());
+ }
+
+ long y;
+ long x;
+ const double x1f = x1;
+ const double y1f = y1;
+ for (double i = first; i <= last; ++i)
+ {
+ const double dy = slope*(i-x1f) + y1f;
+ const double dx = i;
+
+ y = static_cast<long>(dy);
+ x = static_cast<long>(dx);
+
+
+ if (y >= valid_area.top() && y <= valid_area.bottom())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((1.0-(dy-y))*max_alpha);
+ assign_pixel(c[y][x], alpha_pixel);
+ }
+ if (y+1 >= valid_area.top() && y+1 <= valid_area.bottom())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((dy-y)*max_alpha);
+ assign_pixel(c[y+1][x], alpha_pixel);
+ }
+ }
+ }
+ else
+ {
+ const double slope = ((double)run)/rise;
+
+
+ double first, last;
+
+
+ if (y1 > y2)
+ {
+ first = std::max(y2,valid_area.top());
+ last = std::min(y1,valid_area.bottom());
+ }
+ else
+ {
+ first = std::max(y1,valid_area.top());
+ last = std::min(y2,valid_area.bottom());
+ }
+
+ long x;
+ long y;
+ const double x1f = x1;
+ const double y1f = y1;
+ for (double i = first; i <= last; ++i)
+ {
+ const double dx = slope*(i-y1f) + x1f;
+ const double dy = i;
+
+ y = static_cast<long>(dy);
+ x = static_cast<long>(dx);
+
+ if (x >= valid_area.left() && x <= valid_area.right())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((1.0-(dx-x))*max_alpha);
+ assign_pixel(c[y][x], alpha_pixel);
+ }
+ if (x+1 >= valid_area.left() && x+1 <= valid_area.right())
+ {
+ alpha_pixel.alpha = static_cast<unsigned char>((dx-x)*max_alpha);
+ assign_pixel(c[y][x+1], alpha_pixel);
+ }
+ }
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_line (
+ image_type& c,
+ const point& p1,
+ const point& p2,
+ const pixel_type& val
+ )
+ {
+ draw_line(p1.x(),p1.y(),p2.x(),p2.y(),c,val);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_rectangle (
+ image_type& c,
+ const rectangle& rect,
+ const pixel_type& val
+ )
+ {
+ draw_line(c, rect.tl_corner(), rect.tr_corner(), val);
+ draw_line(c, rect.bl_corner(), rect.br_corner(), val);
+ draw_line(c, rect.tl_corner(), rect.bl_corner(), val);
+ draw_line(c, rect.tr_corner(), rect.br_corner(), val);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_rectangle (
+ image_type& c,
+ const rectangle& rect,
+ const pixel_type& val,
+ unsigned int thickness
+ )
+ {
+ for (unsigned int i = 0; i < thickness; ++i)
+ {
+ if ((i%2)==0)
+ draw_rectangle(c,shrink_rect(rect,(i+1)/2),val);
+ else
+ draw_rectangle(c,grow_rect(rect,(i+1)/2),val);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void fill_rect (
+ image_type& img_,
+ const rectangle& rect,
+ const pixel_type& pixel
+ )
+ {
+ image_view<image_type> img(img_);
+ rectangle area = rect.intersect(get_rect(img));
+
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ assign_pixel(img[r][c], pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ matrix<typename image_traits<typename image_array_type::value_type>::pixel_type> tile_images (
+ const image_array_type& images
+ )
+ {
+ typedef typename image_traits<typename image_array_type::value_type>::pixel_type T;
+
+ if (images.size() == 0)
+ return matrix<T>();
+
+ const unsigned long size_nc = square_root(images.size());
+ const unsigned long size_nr = (size_nc*(size_nc-1)>=images.size())? size_nc-1 : size_nc;
+ // Figure out the size we have to use for each chip in the big main image. We will
+ // use the largest dimensions seen across all the chips.
+ long nr = 0;
+ long nc = 0;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ nr = std::max(num_rows(images[i]), nr);
+ nc = std::max(num_columns(images[i]), nc);
+ }
+
+ matrix<T> temp(size_nr*nr, size_nc*nc);
+ T background_color;
+ assign_pixel(background_color, 0);
+ temp = background_color;
+ unsigned long idx = 0;
+ for (unsigned long r = 0; r < size_nr; ++r)
+ {
+ for (unsigned long c = 0; c < size_nc; ++c)
+ {
+ if (idx < images.size())
+ {
+ set_subm(temp, r*nr, c*nc, nr, nc) = mat(images[idx]);
+ }
+ ++idx;
+ }
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_solid_circle (
+ image_type& img_,
+ const dpoint& center_point,
+ double radius,
+ const pixel_type& pixel
+ )
+ {
+ image_view<image_type> img(img_);
+ using std::sqrt;
+ const rectangle valid_area(get_rect(img));
+ const double x = center_point.x();
+ const double y = center_point.y();
+ const point cp(center_point);
+ if (radius > 1)
+ {
+ long first_x = static_cast<long>(x - radius + 0.5);
+ long last_x = static_cast<long>(x + radius + 0.5);
+ const double rs = radius*radius;
+
+ // ensure that we only loop over the part of the x dimension that this
+ // image contains.
+ if (first_x < valid_area.left())
+ first_x = valid_area.left();
+ if (last_x > valid_area.right())
+ last_x = valid_area.right();
+
+ long top, bottom;
+
+ top = static_cast<long>(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5);
+ top += y;
+ long last = top;
+
+ // draw the left half of the circle
+ long middle = std::min(cp.x()-1,last_x);
+ for (long i = first_x; i <= middle; ++i)
+ {
+ double a = i - x + 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ draw_line(img_, point(i,top),point(i,bottom),pixel);
+ --top;
+ }
+
+ last = temp;
+ }
+
+ middle = std::max(cp.x(),first_x);
+ top = static_cast<long>(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5);
+ top += y;
+ last = top;
+ // draw the right half of the circle
+ for (long i = last_x; i >= middle; --i)
+ {
+ double a = i - x - 0.5;
+ // find the top of the arc
+ top = static_cast<long>(sqrt(std::max(rs - a*a,0.0))+0.5);
+ top += y;
+ long temp = top;
+
+ while(top >= last)
+ {
+ bottom = y - top + y;
+ draw_line(img_, point(i,top),point(i,bottom),pixel);
+ --top;
+ }
+
+ last = temp;
+ }
+ }
+ else if (valid_area.contains(cp))
+ {
+ // For circles smaller than a pixel we will just alpha blend them in proportion
+ // to how small they are.
+ rgb_alpha_pixel temp;
+ assign_pixel(temp, pixel);
+ temp.alpha = static_cast<unsigned char>(255*radius + 0.5);
+ assign_pixel(img[cp.y()][cp.x()], temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAW_IMAGe_
+
+
+
+
diff --git a/ml/dlib/dlib/image_transforms/draw_abstract.h b/ml/dlib/dlib/image_transforms/draw_abstract.h
new file mode 100644
index 000000000..6631f8d8f
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/draw_abstract.h
@@ -0,0 +1,150 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DRAW_IMAGe_ABSTRACT
+#ifdef DLIB_DRAW_IMAGe_ABSTRACT
+
+#include "../matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_line (
+ image_type& img,
+ const point& p1,
+ const point& p2,
+ const pixel_type& val
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - #img.nr() == img.nr() && #img.nc() == img.nc()
+ (i.e. the dimensions of the input image are not changed)
+ - for all valid r and c that are on the line between point p1 and p2:
+ - performs assign_pixel(img[r][c], val)
+ (i.e. it draws the line from p1 to p2 onto the image)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_line (
+ long x1,
+ long y1,
+ long x2,
+ long y2,
+ image_type& img,
+ const pixel_type& val
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - performs draw_line(img, point(x1,y1), point(x2,y2), val)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_rectangle (
+ image_type& img,
+ const rectangle& rect,
+ const pixel_type& val,
+ unsigned int thickness = 1
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Draws the given rectangle onto the image img. It does this by calling
+ draw_line() four times to draw the four sides of the rectangle.
+ - The rectangle is drawn with the color given by val.
+ - The drawn rectangle will have edges that are thickness pixels wide.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void draw_solid_circle (
+ image_type& img,
+ const dpoint& center_point,
+ double radius,
+ const pixel_type& pixel
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - Draws a fully filled in circle onto image that is centered at center_point
+ and has the given radius. The circle will be filled by assigning the given
+ pixel value to each element of the circle.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pixel_type
+ >
+ void fill_rect (
+ image_type& img,
+ const rectangle& rect,
+ const pixel_type& pixel
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - fills the area defined by rect in the given image with the given pixel value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ matrix<typename image_traits<typename image_array_type::value_type>::pixel_type> tile_images (
+ const image_array_type& images
+ );
+ /*!
+ requires
+ - image_array_type is a dlib::array of image objects where each image object
+ implements the interface defined in dlib/image_processing/generic_image.h
+ ensures
+ - This function takes the given images and tiles them into a single large
+ square image and returns this new big tiled image. Therefore, it is a useful
+ method to visualize many small images at once.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DRAW_IMAGe_ABSTRACT
+
+
+
diff --git a/ml/dlib/dlib/image_transforms/edge_detector.h b/ml/dlib/dlib/image_transforms/edge_detector.h
new file mode 100644
index 000000000..2fa898fed
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/edge_detector.h
@@ -0,0 +1,302 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EDGE_DETECTOr_
+#define DLIB_EDGE_DETECTOr_
+
+#include "edge_detector_abstract.h"
+#include "../pixel.h"
+#include "../array2d.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline char edge_orientation (
+ const T& x_,
+ const T& y_
+ )
+ {
+
+ // if this is a perfectly horizontal gradient then return right away
+ if (x_ == 0)
+ {
+ return '|';
+ }
+ else if (y_ == 0) // if this is a perfectly vertical gradient then return right away
+ {
+ return '-';
+ }
+
+ // Promote x so that when we multiply by 128 later we know overflow won't happen.
+ typedef typename promote<T>::type type;
+ type x = x_;
+ type y = y_;
+
+ if (x < 0)
+ {
+ x = -x;
+ if (y < 0)
+ {
+ y = -y;
+ x *= 128;
+ const type temp = x/y;
+ if (temp > 309)
+ return '-';
+ else if (temp > 53)
+ return '/';
+ else
+ return '|';
+ }
+ else
+ {
+ x *= 128;
+ const type temp = x/y;
+ if (temp > 309)
+ return '-';
+ else if (temp > 53)
+ return '\\';
+ else
+ return '|';
+ }
+ }
+ else
+ {
+ if (y < 0)
+ {
+ y = -y;
+ x *= 128;
+
+ const type temp = x/y;
+ if (temp > 309)
+ return '-';
+ else if (temp > 53)
+ return '\\';
+ else
+ return '|';
+ }
+ else
+ {
+ x *= 128;
+
+ const type temp = x/y;
+ if (temp > 309)
+ return '-';
+ else if (temp > 53)
+ return '/';
+ else
+ return '|';
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void sobel_edge_detector (
+ const in_image_type& in_img_,
+ out_image_type& horz_,
+ out_image_type& vert_
+ )
+ {
+ typedef typename image_traits<out_image_type>::pixel_type pixel_type;
+ COMPILE_TIME_ASSERT(pixel_traits<pixel_type>::is_unsigned == false);
+ DLIB_ASSERT( !is_same_object(in_img_,horz_) && !is_same_object(in_img_,vert_) &&
+ !is_same_object(horz_,vert_),
+ "\tvoid sobel_edge_detector(in_img_, horz_, vert_)"
+ << "\n\t You can't give the same image as more than one argument"
+ << "\n\t is_same_object(in_img_,horz_): " << is_same_object(in_img_,horz_)
+ << "\n\t is_same_object(in_img_,vert_): " << is_same_object(in_img_,vert_)
+ << "\n\t is_same_object(horz_,vert_): " << is_same_object(horz_,vert_)
+ );
+
+
+ const int vert_filter[3][3] = {{-1,-2,-1},
+ {0,0,0},
+ {1,2,1}};
+ const int horz_filter[3][3] = { {-1,0,1},
+ {-2,0,2},
+ {-1,0,1}};
+
+ const long M = 3;
+ const long N = 3;
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> horz(horz_);
+ image_view<out_image_type> vert(vert_);
+
+ horz.set_size(in_img.nr(),in_img.nc());
+ vert.set_size(in_img.nr(),in_img.nc());
+
+ assign_border_pixels(horz,1,1,0);
+ assign_border_pixels(vert,1,1,0);
+
+ // figure out the range that we should apply the filter to
+ const long first_row = M/2;
+ const long first_col = N/2;
+ const long last_row = in_img.nr() - M/2;
+ const long last_col = in_img.nc() - N/2;
+
+
+ // apply the filter to the image
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ typedef typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type bp_type;
+
+ typename promote<bp_type>::type p, horz_temp, vert_temp;
+ horz_temp = 0;
+ vert_temp = 0;
+ for (long m = 0; m < M; ++m)
+ {
+ for (long n = 0; n < N; ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = get_pixel_intensity(in_img[r-M/2+m][c-N/2+n]);
+
+ horz_temp += p*horz_filter[m][n];
+ vert_temp += p*vert_filter[m][n];
+ }
+ }
+
+ assign_pixel(horz[r][c] , horz_temp);
+ assign_pixel(vert[r][c] , vert_temp);
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ typename promote<T>::type square (const T& a)
+ {
+ return static_cast<T>(a)*static_cast<T>(a);
+ }
+ }
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void suppress_non_maximum_edges (
+ const in_image_type& horz_,
+ const in_image_type& vert_,
+ out_image_type& out_img_
+ )
+ {
+ const_image_view<in_image_type> horz(horz_);
+ const_image_view<in_image_type> vert(vert_);
+ image_view<out_image_type> out_img(out_img_);
+
+ COMPILE_TIME_ASSERT(is_signed_type<typename image_traits<in_image_type>::pixel_type>::value);
+ DLIB_ASSERT( horz.nr() == vert.nr() && horz.nc() == vert.nc(),
+ "\tvoid suppress_non_maximum_edges(horz, vert, out_img)"
+ << "\n\tYou have to give horz and vert gradient images that are the same size"
+ << "\n\thorz.nr(): " << horz.nr()
+ << "\n\thorz.nc(): " << horz.nc()
+ << "\n\tvert.nr(): " << vert.nr()
+ << "\n\tvert.nc(): " << vert.nc()
+ );
+ DLIB_ASSERT( !is_same_object(out_img_,horz_) && !is_same_object(out_img_,vert_),
+ "\tvoid suppress_non_maximum_edges(horz_, vert_, out_img_)"
+ << "\n\t out_img can't be the same as one of the input images."
+ << "\n\t is_same_object(out_img_,horz_): " << is_same_object(out_img_,horz_)
+ << "\n\t is_same_object(out_img_,vert_): " << is_same_object(out_img_,vert_)
+ );
+
+ using std::min;
+ using std::abs;
+
+
+ // if there isn't any input image then don't do anything
+ if (horz.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(horz.nr(),horz.nc());
+
+ zero_border_pixels(out_img,1,1);
+
+ // now do non maximum suppression while we copy the
+ const long M = 3;
+ const long N = 3;
+
+ // figure out the range that we should apply the filter to
+ const long first_row = M/2;
+ const long first_col = N/2;
+ const long last_row = horz.nr() - M/2;
+ const long last_col = horz.nc() - N/2;
+
+
+ // apply the filter to the image
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ typedef typename promote<typename image_traits<in_image_type>::pixel_type>::type T;
+ const T y = horz[r][c];
+ const T x = vert[r][c];
+
+ using impl::square;
+
+ const T val = square(horz[r][c]) + square(vert[r][c]);
+
+ const char ori = edge_orientation(x,y);
+ const unsigned char zero = 0;
+ switch (ori)
+ {
+ case '-':
+ if (square(horz[r-1][c])+square(vert[r-1][c]) > val || square(horz[r+1][c]) + square(vert[r+1][c]) > val)
+ assign_pixel(out_img[r][c] , zero);
+ else
+ assign_pixel(out_img[r][c] , std::sqrt((double)val));
+ break;
+
+ case '|':
+ if (square(horz[r][c-1]) + square(vert[r][c-1]) > val || square(horz[r][c+1]) + square(vert[r][c+1]) > val)
+ assign_pixel(out_img[r][c] , zero);
+ else
+ assign_pixel(out_img[r][c] , std::sqrt((double)val));
+ break;
+
+ case '/':
+ if (square(horz[r-1][c-1]) + square(vert[r-1][c-1]) > val || square(horz[r+1][c+1]) + square(vert[r+1][c+1]) > val)
+ assign_pixel(out_img[r][c] , zero);
+ else
+ assign_pixel(out_img[r][c] , std::sqrt((double)val));
+ break;
+
+ case '\\':
+ if (square(horz[r+1][c-1]) + square(vert[r+1][c-1]) > val || square(horz[r-1][c+1]) + square(vert[r-1][c+1]) > val)
+ assign_pixel(out_img[r][c] , zero);
+ else
+ assign_pixel(out_img[r][c] , std::sqrt((double)val));
+ break;
+
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EDGE_DETECTOr_
+
+
+
diff --git a/ml/dlib/dlib/image_transforms/edge_detector_abstract.h b/ml/dlib/dlib/image_transforms/edge_detector_abstract.h
new file mode 100644
index 000000000..42c991665
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/edge_detector_abstract.h
@@ -0,0 +1,112 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_EDGE_DETECTOr_ABSTRACT_
+#ifdef DLIB_EDGE_DETECTOr_ABSTRACT_
+
+#include "../pixel.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline char edge_orientation (
+ const T& x,
+ const T& y
+ );
+ /*!
+ ensures
+ - returns the orientation of the line drawn from the origin to the point (x,y).
+ The orientation is represented pictorially using the four ascii
+ characters /,|,\, and -.
+ - if (the line is horizontal) then
+ returns '-'
+ - if (the line is vertical) then
+ returns '|'
+ - if (the line is diagonal with a positive slope) then
+ returns '/'
+ - if (the line is diagonal with a negative slope) then
+ returns '\\'
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void sobel_edge_detector (
+ const in_image_type& in_img,
+ out_image_type& horz,
+ out_image_type& vert
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type must use signed grayscale pixels
+ - is_same_object(in_img,horz) == false
+ - is_same_object(in_img,vert) == false
+ - is_same_object(horz,vert) == false
+ ensures
+ - Applies the sobel edge detector to the given input image and stores the resulting
+ edge detections in the horz and vert images
+ - #horz.nr() == in_img.nr()
+ - #horz.nc() == in_img.nc()
+ - #vert.nr() == in_img.nr()
+ - #vert.nc() == in_img.nc()
+ - for all valid r and c:
+ - #horz[r][c] == the magnitude of the horizontal gradient at the point in_img[r][c]
+ - #vert[r][c] == the magnitude of the vertical gradient at the point in_img[r][c]
+ - edge_orientation(#vert[r][c], #horz[r][c]) == the edge direction at this point in
+ the image
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void suppress_non_maximum_edges (
+ const in_image_type& horz,
+ const in_image_type& vert,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - horz.nr() == vert.nr()
+ - horz.nc() == vert.nc()
+ - is_same_object(out_img, horz) == false
+ - is_same_object(out_img, vert) == false
+ - image_traits<in_image_type>::pixel_type == A signed scalar type (e.g. int, double, etc.)
+ ensures
+ - #out_img.nr() = horz.nr()
+ - #out_img.nc() = horz.nc()
+ - let edge_strength(r,c) == sqrt(pow(horz[r][c],2) + pow(vert[r][c],2))
+ (i.e. The Euclidean norm of the gradient)
+ - for all valid r and c:
+ - if (edge_strength(r,c) is at a maximum with respect to its 2 neighboring
+ pixels along the line given by edge_orientation(vert[r][c],horz[r][c])) then
+ - performs assign_pixel(#out_img[r][c], edge_strength(r,c))
+ - else
+ - performs assign_pixel(#out_img[r][c], 0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EDGE_DETECTOr_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/image_transforms/equalize_histogram.h b/ml/dlib/dlib/image_transforms/equalize_histogram.h
new file mode 100644
index 000000000..dd048759a
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/equalize_histogram.h
@@ -0,0 +1,143 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EQUALIZE_HISTOGRAm_
+#define DLIB_EQUALIZE_HISTOGRAm_
+
+#include "../pixel.h"
+#include "equalize_histogram_abstract.h"
+#include <vector>
+#include "../enable_if.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ---------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ long R,
+ long C,
+ typename MM
+ >
+ void get_histogram (
+ const in_image_type& in_img_,
+ matrix<unsigned long,R,C,MM>& hist
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<pixel_type>::is_unsigned == true );
+
+ typedef typename pixel_traits<pixel_type>::basic_pixel_type in_image_basic_pixel_type;
+ COMPILE_TIME_ASSERT( sizeof(in_image_basic_pixel_type) <= 2);
+
+ // make sure hist is the right size
+ if (R == 1)
+ hist.set_size(1,pixel_traits<pixel_type>::max()+1);
+ else
+ hist.set_size(pixel_traits<pixel_type>::max()+1,1);
+
+
+ set_all_elements(hist,0);
+
+ const_image_view<in_image_type> in_img(in_img_);
+ // compute the histogram
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ unsigned long p = get_pixel_intensity(in_img[r][c]);
+ ++hist(p);
+ }
+ }
+ }
+
+// ---------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void equalize_histogram (
+ const in_image_type& in_img_,
+ out_image_type& out_img_
+ )
+ {
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::is_unsigned == true );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::is_unsigned == true );
+
+ typedef typename pixel_traits<in_pixel_type>::basic_pixel_type in_image_basic_pixel_type;
+ COMPILE_TIME_ASSERT( sizeof(in_image_basic_pixel_type) <= 2);
+
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ unsigned long p;
+
+ matrix<unsigned long,1,0> histogram;
+ get_histogram(in_img_, histogram);
+ in_img = in_img_;
+
+ double scale = pixel_traits<out_pixel_type>::max();
+ if (in_img.size() > histogram(0))
+ scale /= in_img.size()-histogram(0);
+ else
+ scale = 0;
+
+ // make the black pixels remain black in the output image
+ histogram(0) = 0;
+
+ // compute the transform function
+ for (long i = 1; i < histogram.size(); ++i)
+ histogram(i) += histogram(i-1);
+ // scale so that it is in the range [0,pixel_traits<out_pixel_type>::max()]
+ for (long i = 0; i < histogram.size(); ++i)
+ histogram(i) = static_cast<unsigned long>(histogram(i)*scale);
+
+ // now do the transform
+ for (long row = 0; row < in_img.nr(); ++row)
+ {
+ for (long col = 0; col < in_img.nc(); ++col)
+ {
+ p = histogram(get_pixel_intensity(in_img[row][col]));
+ assign_pixel(out_img[row][col], in_img[row][col]);
+ assign_pixel_intensity(out_img[row][col],p);
+ }
+ }
+
+ }
+
+ template <
+ typename image_type
+ >
+ void equalize_histogram (
+ image_type& img
+ )
+ {
+ equalize_histogram(img,img);
+ }
+
+// ---------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EQUALIZE_HISTOGRAm_
+
+
+
diff --git a/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h b/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h
new file mode 100644
index 000000000..2592aef1a
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h
@@ -0,0 +1,91 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_
+#ifdef DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_
+
+#include "../pixel.h"
+#include "../matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ---------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void equalize_histogram (
+ const in_image_type& in_img,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - Let pixel_type be the type of pixel in either input or output images, then we
+ must have:
+ - pixel_traits<pixel_type>::has_alpha == false
+ - pixel_traits<pixel_type>::is_unsigned == true
+ - For the input image pixel type, we have the additional requirement that:
+ - pixel_traits<pixel_type>::max() <= 65535
+ ensures
+ - #out_img == the histogram equalized version of in_img
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+ template <
+ typename image_type
+ >
+ void equalize_histogram (
+ image_type& img
+ );
+ /*!
+ requires
+ - it is valid to call equalize_histogram(img,img)
+ ensures
+ - calls equalize_histogram(img,img);
+ !*/
+
+// ---------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ long R,
+ long C,
+ typename MM
+ >
+ void get_histogram (
+ const in_image_type& in_img,
+ matrix<unsigned long,R,C,MM>& hist
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - Let pixel_type denote the type of pixel in in_img, then we must have:
+ - pixel_traits<pixel_type>::is_unsigned == true
+ - pixel_traits<pixel_type>::max() <= 65535
+ - hist must be capable of representing a column vector of length
+ pixel_traits<typename in_image_type>::max(). I.e. if R and C are nonzero
+ then they must be values that don't conflict with the previous sentence.
+ ensures
+ - #hist.size() == pixel_traits<typename in_image_type>::max()
+ - #hist.nc() == 1 || #hist.nr() == 1 (i.e. hist is either a row or column vector)
+ - #hist == the histogram for in_img. I.e. it is the case that for all
+ valid i:
+ - hist(i) == the number of times a pixel with intensity i appears
+ in in_img
+ !*/
+
+// ---------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/image_transforms/fhog.h b/ml/dlib/dlib/image_transforms/fhog.h
new file mode 100644
index 000000000..d99973adf
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/fhog.h
@@ -0,0 +1,1404 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_fHOG_Hh_
+#define DLIB_fHOG_Hh_
+
+#include "fhog_abstract.h"
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../array.h"
+#include "../geometry.h"
+#include "assign_image.h"
+#include "draw.h"
+#include "interpolation.h"
+#include "../simd.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_fhog
+ {
+ template <typename image_type, typename T>
+ inline typename dlib::enable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient (
+ const int r,
+ const int c,
+ const image_type& img,
+ matrix<T,2,1>& grad,
+ T& len
+ )
+ {
+ matrix<T, 2, 1> grad2, grad3;
+ // get the red gradient
+ grad(0) = (int)img[r][c+1].red-(int)img[r][c-1].red;
+ grad(1) = (int)img[r+1][c].red-(int)img[r-1][c].red;
+ len = length_squared(grad);
+
+ // get the green gradient
+ grad2(0) = (int)img[r][c+1].green-(int)img[r][c-1].green;
+ grad2(1) = (int)img[r+1][c].green-(int)img[r-1][c].green;
+ T v2 = length_squared(grad2);
+
+ // get the blue gradient
+ grad3(0) = (int)img[r][c+1].blue-(int)img[r][c-1].blue;
+ grad3(1) = (int)img[r+1][c].blue-(int)img[r-1][c].blue;
+ T v3 = length_squared(grad3);
+
+ // pick color with strongest gradient
+ if (v2 > len)
+ {
+ len = v2;
+ grad = grad2;
+ }
+ if (v3 > len)
+ {
+ len = v3;
+ grad = grad3;
+ }
+ }
+
+ template <typename image_type>
+ inline typename dlib::enable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient (
+ const int r,
+ const int c,
+ const image_type& img,
+ simd4f& grad_x,
+ simd4f& grad_y,
+ simd4f& len
+ )
+ {
+ simd4i rleft((int)img[r][c-1].red,
+ (int)img[r][c].red,
+ (int)img[r][c+1].red,
+ (int)img[r][c+2].red);
+ simd4i rright((int)img[r][c+1].red,
+ (int)img[r][c+2].red,
+ (int)img[r][c+3].red,
+ (int)img[r][c+4].red);
+ simd4i rtop((int)img[r-1][c].red,
+ (int)img[r-1][c+1].red,
+ (int)img[r-1][c+2].red,
+ (int)img[r-1][c+3].red);
+ simd4i rbottom((int)img[r+1][c].red,
+ (int)img[r+1][c+1].red,
+ (int)img[r+1][c+2].red,
+ (int)img[r+1][c+3].red);
+
+ simd4i gleft((int)img[r][c-1].green,
+ (int)img[r][c].green,
+ (int)img[r][c+1].green,
+ (int)img[r][c+2].green);
+ simd4i gright((int)img[r][c+1].green,
+ (int)img[r][c+2].green,
+ (int)img[r][c+3].green,
+ (int)img[r][c+4].green);
+ simd4i gtop((int)img[r-1][c].green,
+ (int)img[r-1][c+1].green,
+ (int)img[r-1][c+2].green,
+ (int)img[r-1][c+3].green);
+ simd4i gbottom((int)img[r+1][c].green,
+ (int)img[r+1][c+1].green,
+ (int)img[r+1][c+2].green,
+ (int)img[r+1][c+3].green);
+
+ simd4i bleft((int)img[r][c-1].blue,
+ (int)img[r][c].blue,
+ (int)img[r][c+1].blue,
+ (int)img[r][c+2].blue);
+ simd4i bright((int)img[r][c+1].blue,
+ (int)img[r][c+2].blue,
+ (int)img[r][c+3].blue,
+ (int)img[r][c+4].blue);
+ simd4i btop((int)img[r-1][c].blue,
+ (int)img[r-1][c+1].blue,
+ (int)img[r-1][c+2].blue,
+ (int)img[r-1][c+3].blue);
+ simd4i bbottom((int)img[r+1][c].blue,
+ (int)img[r+1][c+1].blue,
+ (int)img[r+1][c+2].blue,
+ (int)img[r+1][c+3].blue);
+
+ simd4i grad_x_red = rright-rleft;
+ simd4i grad_y_red = rbottom-rtop;
+ simd4i grad_x_green = gright-gleft;
+ simd4i grad_y_green = gbottom-gtop;
+ simd4i grad_x_blue = bright-bleft;
+ simd4i grad_y_blue = bbottom-btop;
+
+ simd4i rlen = grad_x_red*grad_x_red + grad_y_red*grad_y_red;
+ simd4i glen = grad_x_green*grad_x_green + grad_y_green*grad_y_green;
+ simd4i blen = grad_x_blue*grad_x_blue + grad_y_blue*grad_y_blue;
+
+ simd4i cmp = rlen>glen;
+ simd4i tgrad_x = select(cmp,grad_x_red,grad_x_green);
+ simd4i tgrad_y = select(cmp,grad_y_red,grad_y_green);
+ simd4i tlen = select(cmp,rlen,glen);
+
+ cmp = tlen>blen;
+ grad_x = select(cmp,tgrad_x,grad_x_blue);
+ grad_y = select(cmp,tgrad_y,grad_y_blue);
+ len = select(cmp,tlen,blen);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ inline typename dlib::enable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient(
+ const int r,
+ const int c,
+ const image_type& img,
+ simd8f& grad_x,
+ simd8f& grad_y,
+ simd8f& len
+ )
+ {
+ simd8i rleft((int)img[r][c - 1].red,
+ (int)img[r][c].red,
+ (int)img[r][c + 1].red,
+ (int)img[r][c + 2].red,
+ (int)img[r][c + 3].red,
+ (int)img[r][c + 4].red,
+ (int)img[r][c + 5].red,
+ (int)img[r][c + 6].red);
+ simd8i rright((int)img[r][c + 1].red,
+ (int)img[r][c + 2].red,
+ (int)img[r][c + 3].red,
+ (int)img[r][c + 4].red,
+ (int)img[r][c + 5].red,
+ (int)img[r][c + 6].red,
+ (int)img[r][c + 7].red,
+ (int)img[r][c + 8].red);
+ simd8i rtop((int)img[r - 1][c].red,
+ (int)img[r - 1][c + 1].red,
+ (int)img[r - 1][c + 2].red,
+ (int)img[r - 1][c + 3].red,
+ (int)img[r - 1][c + 4].red,
+ (int)img[r - 1][c + 5].red,
+ (int)img[r - 1][c + 6].red,
+ (int)img[r - 1][c + 7].red);
+ simd8i rbottom((int)img[r + 1][c].red,
+ (int)img[r + 1][c + 1].red,
+ (int)img[r + 1][c + 2].red,
+ (int)img[r + 1][c + 3].red,
+ (int)img[r + 1][c + 4].red,
+ (int)img[r + 1][c + 5].red,
+ (int)img[r + 1][c + 6].red,
+ (int)img[r + 1][c + 7].red);
+
+ simd8i gleft((int)img[r][c - 1].green,
+ (int)img[r][c].green,
+ (int)img[r][c + 1].green,
+ (int)img[r][c + 2].green,
+ (int)img[r][c + 3].green,
+ (int)img[r][c + 4].green,
+ (int)img[r][c + 5].green,
+ (int)img[r][c + 6].green);
+ simd8i gright((int)img[r][c + 1].green,
+ (int)img[r][c + 2].green,
+ (int)img[r][c + 3].green,
+ (int)img[r][c + 4].green,
+ (int)img[r][c + 5].green,
+ (int)img[r][c + 6].green,
+ (int)img[r][c + 7].green,
+ (int)img[r][c + 8].green);
+ simd8i gtop((int)img[r - 1][c].green,
+ (int)img[r - 1][c + 1].green,
+ (int)img[r - 1][c + 2].green,
+ (int)img[r - 1][c + 3].green,
+ (int)img[r - 1][c + 4].green,
+ (int)img[r - 1][c + 5].green,
+ (int)img[r - 1][c + 6].green,
+ (int)img[r - 1][c + 7].green);
+ simd8i gbottom((int)img[r + 1][c].green,
+ (int)img[r + 1][c + 1].green,
+ (int)img[r + 1][c + 2].green,
+ (int)img[r + 1][c + 3].green,
+ (int)img[r + 1][c + 4].green,
+ (int)img[r + 1][c + 5].green,
+ (int)img[r + 1][c + 6].green,
+ (int)img[r + 1][c + 7].green);
+
+ simd8i bleft((int)img[r][c - 1].blue,
+ (int)img[r][c].blue,
+ (int)img[r][c + 1].blue,
+ (int)img[r][c + 2].blue,
+ (int)img[r][c + 3].blue,
+ (int)img[r][c + 4].blue,
+ (int)img[r][c + 5].blue,
+ (int)img[r][c + 6].blue);
+ simd8i bright((int)img[r][c + 1].blue,
+ (int)img[r][c + 2].blue,
+ (int)img[r][c + 3].blue,
+ (int)img[r][c + 4].blue,
+ (int)img[r][c + 5].blue,
+ (int)img[r][c + 6].blue,
+ (int)img[r][c + 7].blue,
+ (int)img[r][c + 8].blue);
+ simd8i btop((int)img[r - 1][c].blue,
+ (int)img[r - 1][c + 1].blue,
+ (int)img[r - 1][c + 2].blue,
+ (int)img[r - 1][c + 3].blue,
+ (int)img[r - 1][c + 4].blue,
+ (int)img[r - 1][c + 5].blue,
+ (int)img[r - 1][c + 6].blue,
+ (int)img[r - 1][c + 7].blue);
+ simd8i bbottom((int)img[r + 1][c].blue,
+ (int)img[r + 1][c + 1].blue,
+ (int)img[r + 1][c + 2].blue,
+ (int)img[r + 1][c + 3].blue,
+ (int)img[r + 1][c + 4].blue,
+ (int)img[r + 1][c + 5].blue,
+ (int)img[r + 1][c + 6].blue,
+ (int)img[r + 1][c + 7].blue);
+
+ simd8i grad_x_red = rright - rleft;
+ simd8i grad_y_red = rbottom - rtop;
+ simd8i grad_x_green = gright - gleft;
+ simd8i grad_y_green = gbottom - gtop;
+ simd8i grad_x_blue = bright - bleft;
+ simd8i grad_y_blue = bbottom - btop;
+
+ simd8i rlen = grad_x_red*grad_x_red + grad_y_red*grad_y_red;
+ simd8i glen = grad_x_green*grad_x_green + grad_y_green*grad_y_green;
+ simd8i blen = grad_x_blue*grad_x_blue + grad_y_blue*grad_y_blue;
+
+ simd8i cmp = rlen > glen;
+ simd8i tgrad_x = select(cmp, grad_x_red, grad_x_green);
+ simd8i tgrad_y = select(cmp, grad_y_red, grad_y_green);
+ simd8i tlen = select(cmp, rlen, glen);
+
+ cmp = tlen > blen;
+ grad_x = select(cmp, tgrad_x, grad_x_blue);
+ grad_y = select(cmp, tgrad_y, grad_y_blue);
+ len = select(cmp, tlen, blen);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename image_type, typename T>
+ inline typename dlib::disable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient (
+ const int r,
+ const int c,
+ const image_type& img,
+ matrix<T, 2, 1>& grad,
+ T& len
+ )
+ {
+ grad(0) = (int)get_pixel_intensity(img[r][c+1])-(int)get_pixel_intensity(img[r][c-1]);
+ grad(1) = (int)get_pixel_intensity(img[r+1][c])-(int)get_pixel_intensity(img[r-1][c]);
+ len = length_squared(grad);
+ }
+
+ template <typename image_type>
+ inline typename dlib::disable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient (
+ int r,
+ int c,
+ const image_type& img,
+ simd4f& grad_x,
+ simd4f& grad_y,
+ simd4f& len
+ )
+ {
+ simd4i left((int)get_pixel_intensity(img[r][c-1]),
+ (int)get_pixel_intensity(img[r][c]),
+ (int)get_pixel_intensity(img[r][c+1]),
+ (int)get_pixel_intensity(img[r][c+2]));
+ simd4i right((int)get_pixel_intensity(img[r][c+1]),
+ (int)get_pixel_intensity(img[r][c+2]),
+ (int)get_pixel_intensity(img[r][c+3]),
+ (int)get_pixel_intensity(img[r][c+4]));
+
+ simd4i top((int)get_pixel_intensity(img[r-1][c]),
+ (int)get_pixel_intensity(img[r-1][c+1]),
+ (int)get_pixel_intensity(img[r-1][c+2]),
+ (int)get_pixel_intensity(img[r-1][c+3]));
+ simd4i bottom((int)get_pixel_intensity(img[r+1][c]),
+ (int)get_pixel_intensity(img[r+1][c+1]),
+ (int)get_pixel_intensity(img[r+1][c+2]),
+ (int)get_pixel_intensity(img[r+1][c+3]));
+
+ grad_x = right-left;
+ grad_y = bottom-top;
+
+ len = (grad_x*grad_x + grad_y*grad_y);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ inline typename dlib::disable_if_c<pixel_traits<typename image_type::pixel_type>::rgb>::type get_gradient(
+ int r,
+ int c,
+ const image_type& img,
+ simd8f& grad_x,
+ simd8f& grad_y,
+ simd8f& len
+ )
+ {
+ simd8i left((int)get_pixel_intensity(img[r][c - 1]),
+ (int)get_pixel_intensity(img[r][c]),
+ (int)get_pixel_intensity(img[r][c + 1]),
+ (int)get_pixel_intensity(img[r][c + 2]),
+ (int)get_pixel_intensity(img[r][c + 3]),
+ (int)get_pixel_intensity(img[r][c + 4]),
+ (int)get_pixel_intensity(img[r][c + 5]),
+ (int)get_pixel_intensity(img[r][c + 6]));
+ simd8i right((int)get_pixel_intensity(img[r][c + 1]),
+ (int)get_pixel_intensity(img[r][c + 2]),
+ (int)get_pixel_intensity(img[r][c + 3]),
+ (int)get_pixel_intensity(img[r][c + 4]),
+ (int)get_pixel_intensity(img[r][c + 5]),
+ (int)get_pixel_intensity(img[r][c + 6]),
+ (int)get_pixel_intensity(img[r][c + 7]),
+ (int)get_pixel_intensity(img[r][c + 8]));
+
+ simd8i top((int)get_pixel_intensity(img[r - 1][c]),
+ (int)get_pixel_intensity(img[r - 1][c + 1]),
+ (int)get_pixel_intensity(img[r - 1][c + 2]),
+ (int)get_pixel_intensity(img[r - 1][c + 3]),
+ (int)get_pixel_intensity(img[r - 1][c + 4]),
+ (int)get_pixel_intensity(img[r - 1][c + 5]),
+ (int)get_pixel_intensity(img[r - 1][c + 6]),
+ (int)get_pixel_intensity(img[r - 1][c + 7]));
+ simd8i bottom((int)get_pixel_intensity(img[r + 1][c]),
+ (int)get_pixel_intensity(img[r + 1][c + 1]),
+ (int)get_pixel_intensity(img[r + 1][c + 2]),
+ (int)get_pixel_intensity(img[r + 1][c + 3]),
+ (int)get_pixel_intensity(img[r + 1][c + 4]),
+ (int)get_pixel_intensity(img[r + 1][c + 5]),
+ (int)get_pixel_intensity(img[r + 1][c + 6]),
+ (int)get_pixel_intensity(img[r + 1][c + 7]));
+
+ grad_x = right - left;
+ grad_y = bottom - top;
+
+ len = (grad_x*grad_x + grad_y*grad_y);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename mm1, typename mm2>
+ inline void set_hog (
+ dlib::array<array2d<T,mm1>,mm2>& hog,
+ int o,
+ int x,
+ int y,
+ const float& value
+ )
+ {
+ hog[o][y][x] = value;
+ }
+
+ template <typename T, typename mm1, typename mm2>
+ void init_hog (
+ dlib::array<array2d<T,mm1>,mm2>& hog,
+ int hog_nr,
+ int hog_nc,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ const int num_hog_bands = 27+4;
+ hog.resize(num_hog_bands);
+ for (int i = 0; i < num_hog_bands; ++i)
+ {
+ hog[i].set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1);
+ rectangle rect = get_rect(hog[i]);
+ rect.top() += (filter_rows_padding-1)/2;
+ rect.left() += (filter_cols_padding-1)/2;
+ rect.right() -= filter_cols_padding/2;
+ rect.bottom() -= filter_rows_padding/2;
+ zero_border_pixels(hog[i],rect);
+ }
+ }
+
+ template <typename T, typename mm1, typename mm2>
+ void init_hog_zero_everything (
+ dlib::array<array2d<T,mm1>,mm2>& hog,
+ int hog_nr,
+ int hog_nc,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ const int num_hog_bands = 27+4;
+ hog.resize(num_hog_bands);
+ for (int i = 0; i < num_hog_bands; ++i)
+ {
+ hog[i].set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1);
+ assign_all_pixels(hog[i], 0);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename mm>
+ inline void set_hog (
+ array2d<matrix<T,31,1>,mm>& hog,
+ int o,
+ int x,
+ int y,
+ const float& value
+ )
+ {
+ hog[y][x](o) = value;
+ }
+
+ template <typename T, typename mm>
+ void init_hog (
+ array2d<matrix<T,31,1>,mm>& hog,
+ int hog_nr,
+ int hog_nc,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ hog.set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1);
+
+ // now zero out the border region
+ rectangle rect = get_rect(hog);
+ rect.top() += (filter_rows_padding-1)/2;
+ rect.left() += (filter_cols_padding-1)/2;
+ rect.right() -= filter_cols_padding/2;
+ rect.bottom() -= filter_rows_padding/2;
+ border_enumerator be(get_rect(hog),rect);
+ while (be.move_next())
+ {
+ const point p = be.element();
+ set_all_elements(hog[p.y()][p.x()], 0);
+ }
+ }
+
+ template <typename T, typename mm>
+ void init_hog_zero_everything (
+ array2d<matrix<T,31,1>,mm>& hog,
+ int hog_nr,
+ int hog_nc,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ hog.set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1);
+
+ for (long r = 0; r < hog.nr(); ++r)
+ {
+ for (long c = 0; c < hog.nc(); ++c)
+ {
+ set_all_elements(hog[r][c], 0);
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename out_type
+ >
+ void impl_extract_fhog_features_cell_size_1(
+ const image_type& img_,
+ out_type& hog,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ const_image_view<image_type> img(img_);
+ // make sure requires clause is not broken
+ DLIB_ASSERT( filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t void extract_fhog_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ /*
+ This function is an optimized version of impl_extract_fhog_features() for
+ the case where cell_size == 1.
+ */
+
+
+ // unit vectors used to compute gradient orientation
+ matrix<float,2,1> directions[9];
+ directions[0] = 1.0000, 0.0000;
+ directions[1] = 0.9397, 0.3420;
+ directions[2] = 0.7660, 0.6428;
+ directions[3] = 0.500, 0.8660;
+ directions[4] = 0.1736, 0.9848;
+ directions[5] = -0.1736, 0.9848;
+ directions[6] = -0.5000, 0.8660;
+ directions[7] = -0.7660, 0.6428;
+ directions[8] = -0.9397, 0.3420;
+
+
+
+ if (img.nr() <= 2 || img.nc() <= 2)
+ {
+ hog.clear();
+ return;
+ }
+
+ array2d<unsigned char> angle(img.nr(), img.nc());
+
+ array2d<float> norm(img.nr(), img.nc());
+ zero_border_pixels(norm,1,1);
+
+ // memory for HOG features
+ const long hog_nr = img.nr()-2;
+ const long hog_nc = img.nc()-2;
+
+ const int padding_rows_offset = (filter_rows_padding-1)/2;
+ const int padding_cols_offset = (filter_cols_padding-1)/2;
+ init_hog_zero_everything(hog, hog_nr, hog_nc, filter_rows_padding, filter_cols_padding);
+
+
+ const int visible_nr = img.nr()-1;
+ const int visible_nc = img.nc()-1;
+
+ // First populate the gradient histograms
+ for (int y = 1; y < visible_nr; y++)
+ {
+ int x;
+ for (x = 1; x < visible_nc - 7; x += 8)
+ {
+ // v will be the length of the gradient vectors.
+ simd8f grad_x, grad_y, v;
+ get_gradient(y, x, img, grad_x, grad_y, v);
+
+ float _vv[8];
+ v.store(_vv);
+
+ // Now snap the gradient to one of 18 orientations
+ simd8f best_dot = 0;
+ simd8f best_o = 0;
+ for (int o = 0; o < 9; o++)
+ {
+ simd8f dot = grad_x*directions[o](0) + grad_y*directions[o](1);
+ simd8f_bool cmp = dot>best_dot;
+ best_dot = select(cmp, dot, best_dot);
+ dot *= -1;
+ best_o = select(cmp, o, best_o);
+
+ cmp = dot > best_dot;
+ best_dot = select(cmp, dot, best_dot);
+ best_o = select(cmp, o + 9, best_o);
+ }
+
+ int32 _best_o[8]; simd8i(best_o).store(_best_o);
+
+ norm[y][x + 0] = _vv[0];
+ norm[y][x + 1] = _vv[1];
+ norm[y][x + 2] = _vv[2];
+ norm[y][x + 3] = _vv[3];
+ norm[y][x + 4] = _vv[4];
+ norm[y][x + 5] = _vv[5];
+ norm[y][x + 6] = _vv[6];
+ norm[y][x + 7] = _vv[7];
+
+ angle[y][x + 0] = _best_o[0];
+ angle[y][x + 1] = _best_o[1];
+ angle[y][x + 2] = _best_o[2];
+ angle[y][x + 3] = _best_o[3];
+ angle[y][x + 4] = _best_o[4];
+ angle[y][x + 5] = _best_o[5];
+ angle[y][x + 6] = _best_o[6];
+ angle[y][x + 7] = _best_o[7];
+ }
+ // Now process the right columns that don't fit into simd registers.
+ for (; x < visible_nc; x++)
+ {
+ matrix<float,2,1> grad;
+ float v;
+ get_gradient(y,x,img,grad,v);
+
+ // snap to one of 18 orientations
+ float best_dot = 0;
+ int best_o = 0;
+ for (int o = 0; o < 9; o++)
+ {
+ const float dot = dlib::dot(directions[o], grad);
+ if (dot > best_dot)
+ {
+ best_dot = dot;
+ best_o = o;
+ }
+ else if (-dot > best_dot)
+ {
+ best_dot = -dot;
+ best_o = o+9;
+ }
+ }
+
+ norm[y][x] = v;
+ angle[y][x] = best_o;
+ }
+ }
+
+ const float eps = 0.0001;
+ // compute features
+ for (int y = 0; y < hog_nr; y++)
+ {
+ const int yy = y+padding_rows_offset;
+ for (int x = 0; x < hog_nc; x++)
+ {
+ const simd4f z1(norm[y+1][x+1],
+ norm[y][x+1],
+ norm[y+1][x],
+ norm[y][x]);
+
+ const simd4f z2(norm[y+1][x+2],
+ norm[y][x+2],
+ norm[y+1][x+1],
+ norm[y][x+1]);
+
+ const simd4f z3(norm[y+2][x+1],
+ norm[y+1][x+1],
+ norm[y+2][x],
+ norm[y+1][x]);
+
+ const simd4f z4(norm[y+2][x+2],
+ norm[y+1][x+2],
+ norm[y+2][x+1],
+ norm[y+1][x+1]);
+
+ const simd4f temp0 = std::sqrt(norm[y+1][x+1]);
+ const simd4f nn = 0.2*sqrt(z1+z2+z3+z4+eps);
+ const simd4f n = 0.1/nn;
+
+ simd4f t = 0;
+
+ const int xx = x+padding_cols_offset;
+
+ simd4f h0 = min(temp0,nn)*n;
+ const float vv = sum(h0);
+ set_hog(hog,angle[y+1][x+1],xx,yy, vv);
+ t += h0;
+
+ t *= 2*0.2357;
+
+ // contrast-insensitive features
+ set_hog(hog,angle[y+1][x+1]%9+18,xx,yy, vv);
+
+
+ float temp[4];
+ t.store(temp);
+
+ // texture features
+ set_hog(hog,27,xx,yy, temp[0]);
+ set_hog(hog,28,xx,yy, temp[1]);
+ set_hog(hog,29,xx,yy, temp[2]);
+ set_hog(hog,30,xx,yy, temp[3]);
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename out_type
+ >
+ void impl_extract_fhog_features(
+ const image_type& img_,
+ out_type& hog,
+ int cell_size,
+ int filter_rows_padding,
+ int filter_cols_padding
+ )
+ {
+ const_image_view<image_type> img(img_);
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_size > 0 &&
+ filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t void extract_fhog_features()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_size: " << cell_size
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ /*
+ This function implements the HOG feature extraction method described in
+ the paper:
+ P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan
+ Object Detection with Discriminatively Trained Part Based Models
+ IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010
+
+ Moreover, this function is derived from the HOG feature extraction code
+ from the features.cc file in the voc-releaseX code (see
+ http://people.cs.uchicago.edu/~rbg/latent/) which is has the following
+ license (note that the code has been modified to work with grayscale and
+ color as well as planar and interlaced input and output formats):
+
+ Copyright (C) 2011, 2012 Ross Girshick, Pedro Felzenszwalb
+ Copyright (C) 2008, 2009, 2010 Pedro Felzenszwalb, Ross Girshick
+ Copyright (C) 2007 Pedro Felzenszwalb, Deva Ramanan
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+ if (cell_size == 1)
+ {
+ impl_extract_fhog_features_cell_size_1(img_,hog,filter_rows_padding,filter_cols_padding);
+ return;
+ }
+
+ // unit vectors used to compute gradient orientation
+ matrix<float,2,1> directions[9];
+ directions[0] = 1.0000, 0.0000;
+ directions[1] = 0.9397, 0.3420;
+ directions[2] = 0.7660, 0.6428;
+ directions[3] = 0.500, 0.8660;
+ directions[4] = 0.1736, 0.9848;
+ directions[5] = -0.1736, 0.9848;
+ directions[6] = -0.5000, 0.8660;
+ directions[7] = -0.7660, 0.6428;
+ directions[8] = -0.9397, 0.3420;
+
+
+
+ // First we allocate memory for caching orientation histograms & their norms.
+ const int cells_nr = (int)((float)img.nr()/(float)cell_size + 0.5);
+ const int cells_nc = (int)((float)img.nc()/(float)cell_size + 0.5);
+
+ if (cells_nr == 0 || cells_nc == 0)
+ {
+ hog.clear();
+ return;
+ }
+
+ // We give hist extra padding around the edges (1 cell all the way around the
+ // edge) so we can avoid needing to do boundary checks when indexing into it
+ // later on. So some statements assign to the boundary but those values are
+ // never used.
+ array2d<matrix<float,18,1> > hist(cells_nr+2, cells_nc+2);
+ for (long r = 0; r < hist.nr(); ++r)
+ {
+ for (long c = 0; c < hist.nc(); ++c)
+ {
+ hist[r][c] = 0;
+ }
+ }
+
+ array2d<float> norm(cells_nr, cells_nc);
+ assign_all_pixels(norm, 0);
+
+ // memory for HOG features
+ const int hog_nr = std::max(cells_nr-2, 0);
+ const int hog_nc = std::max(cells_nc-2, 0);
+ if (hog_nr == 0 || hog_nc == 0)
+ {
+ hog.clear();
+ return;
+ }
+ const int padding_rows_offset = (filter_rows_padding-1)/2;
+ const int padding_cols_offset = (filter_cols_padding-1)/2;
+ init_hog(hog, hog_nr, hog_nc, filter_rows_padding, filter_cols_padding);
+
+ const int visible_nr = std::min((long)cells_nr*cell_size,img.nr())-1;
+ const int visible_nc = std::min((long)cells_nc*cell_size,img.nc())-1;
+
+ // First populate the gradient histograms
+ for (int y = 1; y < visible_nr; y++)
+ {
+ const float yp = ((float)y+0.5)/(float)cell_size - 0.5;
+ const int iyp = (int)std::floor(yp);
+ const float vy0 = yp - iyp;
+ const float vy1 = 1.0 - vy0;
+ int x;
+ for (x = 1; x < visible_nc - 7; x += 8)
+ {
+ simd8f xx(x, x + 1, x + 2, x + 3, x + 4, x + 5, x + 6, x + 7);
+ // v will be the length of the gradient vectors.
+ simd8f grad_x, grad_y, v;
+ get_gradient(y, x, img, grad_x, grad_y, v);
+
+ // We will use bilinear interpolation to add into the histogram bins.
+ // So first we precompute the values needed to determine how much each
+ // pixel votes into each bin.
+ simd8f xp = (xx + 0.5) / (float)cell_size + 0.5;
+ simd8i ixp = simd8i(xp);
+ simd8f vx0 = xp - ixp;
+ simd8f vx1 = 1.0f - vx0;
+
+ v = sqrt(v);
+
+ // Now snap the gradient to one of 18 orientations
+ simd8f best_dot = 0;
+ simd8f best_o = 0;
+ for (int o = 0; o < 9; o++)
+ {
+ simd8f dot = grad_x*directions[o](0) + grad_y*directions[o](1);
+ simd8f_bool cmp = dot>best_dot;
+ best_dot = select(cmp, dot, best_dot);
+ dot *= -1;
+ best_o = select(cmp, o, best_o);
+
+ cmp = dot > best_dot;
+ best_dot = select(cmp, dot, best_dot);
+ best_o = select(cmp, o + 9, best_o);
+ }
+
+
+ // Add the gradient magnitude, v, to 4 histograms around pixel using
+ // bilinear interpolation.
+ vx1 *= v;
+ vx0 *= v;
+ // The amounts for each bin
+ simd8f v11 = vy1*vx1;
+ simd8f v01 = vy0*vx1;
+ simd8f v10 = vy1*vx0;
+ simd8f v00 = vy0*vx0;
+
+ int32 _best_o[8]; simd8i(best_o).store(_best_o);
+ int32 _ixp[8]; ixp.store(_ixp);
+ float _v11[8]; v11.store(_v11);
+ float _v01[8]; v01.store(_v01);
+ float _v10[8]; v10.store(_v10);
+ float _v00[8]; v00.store(_v00);
+
+ hist[iyp + 1][_ixp[0]](_best_o[0]) += _v11[0];
+ hist[iyp + 1 + 1][_ixp[0]](_best_o[0]) += _v01[0];
+ hist[iyp + 1][_ixp[0] + 1](_best_o[0]) += _v10[0];
+ hist[iyp + 1 + 1][_ixp[0] + 1](_best_o[0]) += _v00[0];
+
+ hist[iyp + 1][_ixp[1]](_best_o[1]) += _v11[1];
+ hist[iyp + 1 + 1][_ixp[1]](_best_o[1]) += _v01[1];
+ hist[iyp + 1][_ixp[1] + 1](_best_o[1]) += _v10[1];
+ hist[iyp + 1 + 1][_ixp[1] + 1](_best_o[1]) += _v00[1];
+
+ hist[iyp + 1][_ixp[2]](_best_o[2]) += _v11[2];
+ hist[iyp + 1 + 1][_ixp[2]](_best_o[2]) += _v01[2];
+ hist[iyp + 1][_ixp[2] + 1](_best_o[2]) += _v10[2];
+ hist[iyp + 1 + 1][_ixp[2] + 1](_best_o[2]) += _v00[2];
+
+ hist[iyp + 1][_ixp[3]](_best_o[3]) += _v11[3];
+ hist[iyp + 1 + 1][_ixp[3]](_best_o[3]) += _v01[3];
+ hist[iyp + 1][_ixp[3] + 1](_best_o[3]) += _v10[3];
+ hist[iyp + 1 + 1][_ixp[3] + 1](_best_o[3]) += _v00[3];
+
+ hist[iyp + 1][_ixp[4]](_best_o[4]) += _v11[4];
+ hist[iyp + 1 + 1][_ixp[4]](_best_o[4]) += _v01[4];
+ hist[iyp + 1][_ixp[4] + 1](_best_o[4]) += _v10[4];
+ hist[iyp + 1 + 1][_ixp[4] + 1](_best_o[4]) += _v00[4];
+
+ hist[iyp + 1][_ixp[5]](_best_o[5]) += _v11[5];
+ hist[iyp + 1 + 1][_ixp[5]](_best_o[5]) += _v01[5];
+ hist[iyp + 1][_ixp[5] + 1](_best_o[5]) += _v10[5];
+ hist[iyp + 1 + 1][_ixp[5] + 1](_best_o[5]) += _v00[5];
+
+ hist[iyp + 1][_ixp[6]](_best_o[6]) += _v11[6];
+ hist[iyp + 1 + 1][_ixp[6]](_best_o[6]) += _v01[6];
+ hist[iyp + 1][_ixp[6] + 1](_best_o[6]) += _v10[6];
+ hist[iyp + 1 + 1][_ixp[6] + 1](_best_o[6]) += _v00[6];
+
+ hist[iyp + 1][_ixp[7]](_best_o[7]) += _v11[7];
+ hist[iyp + 1 + 1][_ixp[7]](_best_o[7]) += _v01[7];
+ hist[iyp + 1][_ixp[7] + 1](_best_o[7]) += _v10[7];
+ hist[iyp + 1 + 1][_ixp[7] + 1](_best_o[7]) += _v00[7];
+ }
+ // Now process the right columns that don't fit into simd registers.
+ for (; x < visible_nc; x++)
+ {
+ matrix<float, 2, 1> grad;
+ float v;
+ get_gradient(y,x,img,grad,v);
+
+ // snap to one of 18 orientations
+ float best_dot = 0;
+ int best_o = 0;
+ for (int o = 0; o < 9; o++)
+ {
+ const float dot = dlib::dot(directions[o], grad);
+ if (dot > best_dot)
+ {
+ best_dot = dot;
+ best_o = o;
+ }
+ else if (-dot > best_dot)
+ {
+ best_dot = -dot;
+ best_o = o+9;
+ }
+ }
+
+ v = std::sqrt(v);
+ // add to 4 histograms around pixel using bilinear interpolation
+ const float xp = ((double)x + 0.5) / (double)cell_size - 0.5;
+ const int ixp = (int)std::floor(xp);
+ const float vx0 = xp - ixp;
+ const float vx1 = 1.0 - vx0;
+
+ hist[iyp+1][ixp+1](best_o) += vy1*vx1*v;
+ hist[iyp+1+1][ixp+1](best_o) += vy0*vx1*v;
+ hist[iyp+1][ixp+1+1](best_o) += vy1*vx0*v;
+ hist[iyp+1+1][ixp+1+1](best_o) += vy0*vx0*v;
+ }
+ }
+
+ // compute energy in each block by summing over orientations
+ for (int r = 0; r < cells_nr; ++r)
+ {
+ for (int c = 0; c < cells_nc; ++c)
+ {
+ for (int o = 0; o < 9; o++)
+ {
+ norm[r][c] += (hist[r+1][c+1](o) + hist[r+1][c+1](o+9)) * (hist[r+1][c+1](o) + hist[r+1][c+1](o+9));
+ }
+ }
+ }
+
+ const float eps = 0.0001;
+ // compute features
+ for (int y = 0; y < hog_nr; y++)
+ {
+ const int yy = y+padding_rows_offset;
+ for (int x = 0; x < hog_nc; x++)
+ {
+ const simd4f z1(norm[y+1][x+1],
+ norm[y][x+1],
+ norm[y+1][x],
+ norm[y][x]);
+
+ const simd4f z2(norm[y+1][x+2],
+ norm[y][x+2],
+ norm[y+1][x+1],
+ norm[y][x+1]);
+
+ const simd4f z3(norm[y+2][x+1],
+ norm[y+1][x+1],
+ norm[y+2][x],
+ norm[y+1][x]);
+
+ const simd4f z4(norm[y+2][x+2],
+ norm[y+1][x+2],
+ norm[y+2][x+1],
+ norm[y+1][x+1]);
+
+ const simd4f nn = 0.2*sqrt(z1+z2+z3+z4+eps);
+ const simd4f n = 0.1/nn;
+
+ simd4f t = 0;
+
+ const int xx = x+padding_cols_offset;
+
+ // contrast-sensitive features
+ for (int o = 0; o < 18; o+=3)
+ {
+ simd4f temp0(hist[y+1+1][x+1+1](o));
+ simd4f temp1(hist[y+1+1][x+1+1](o+1));
+ simd4f temp2(hist[y+1+1][x+1+1](o+2));
+ simd4f h0 = min(temp0,nn)*n;
+ simd4f h1 = min(temp1,nn)*n;
+ simd4f h2 = min(temp2,nn)*n;
+ set_hog(hog,o,xx,yy, sum(h0));
+ set_hog(hog,o+1,xx,yy, sum(h1));
+ set_hog(hog,o+2,xx,yy, sum(h2));
+ t += h0+h1+h2;
+ }
+
+ t *= 2*0.2357;
+
+ // contrast-insensitive features
+ for (int o = 0; o < 9; o+=3)
+ {
+ simd4f temp0 = hist[y+1+1][x+1+1](o) + hist[y+1+1][x+1+1](o+9);
+ simd4f temp1 = hist[y+1+1][x+1+1](o+1) + hist[y+1+1][x+1+1](o+9+1);
+ simd4f temp2 = hist[y+1+1][x+1+1](o+2) + hist[y+1+1][x+1+1](o+9+2);
+ simd4f h0 = min(temp0,nn)*n;
+ simd4f h1 = min(temp1,nn)*n;
+ simd4f h2 = min(temp2,nn)*n;
+ set_hog(hog,o+18,xx,yy, sum(h0));
+ set_hog(hog,o+18+1,xx,yy, sum(h1));
+ set_hog(hog,o+18+2,xx,yy, sum(h2));
+ }
+
+
+ float temp[4];
+ t.store(temp);
+
+ // texture features
+ set_hog(hog,27,xx,yy, temp[0]);
+ set_hog(hog,28,xx,yy, temp[1]);
+ set_hog(hog,29,xx,yy, temp[2]);
+ set_hog(hog,30,xx,yy, temp[3]);
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void create_fhog_bar_images (
+ dlib::array<matrix<float> >& mbars,
+ const long w
+ )
+ {
+ const long bdims = 9;
+ // Make the oriented lines we use to draw on each HOG cell.
+ mbars.resize(bdims);
+ dlib::array<array2d<unsigned char> > bars(bdims);
+ array2d<unsigned char> temp(w,w);
+ for (unsigned long i = 0; i < bars.size(); ++i)
+ {
+ assign_all_pixels(temp, 0);
+ draw_line(temp, point(w/2,0), point(w/2,w-1), 255);
+ rotate_image(temp, bars[i], i*-pi/bars.size());
+
+ mbars[i] = subm(matrix_cast<float>(mat(bars[i])), centered_rect(get_rect(bars[i]),w,w) );
+ }
+ }
+
+ } // end namespace impl_fhog
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T,
+ typename mm1,
+ typename mm2
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ dlib::array<array2d<T,mm1>,mm2>& hog,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ impl_fhog::impl_extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding);
+ // If the image is too small then the above function outputs an empty feature map.
+ // But to make things very uniform in usage we require the output to still have the
+ // 31 planes (but they are just empty).
+ if (hog.size() == 0)
+ hog.resize(31);
+ }
+
+ template <
+ typename image_type,
+ typename T,
+ typename mm
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ array2d<matrix<T,31,1>,mm>& hog,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ impl_fhog::impl_extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ matrix<T,0,1>& feats,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ dlib::array<array2d<T> > hog;
+ extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding);
+ feats.set_size(hog.size()*hog[0].size());
+ for (unsigned long i = 0; i < hog.size(); ++i)
+ {
+ const long size = hog[i].size();
+ set_rowm(feats, range(i*size, (i+1)*size-1)) = reshape_to_column_vector(mat(hog[i]));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ matrix<double,0,1> extract_fhog_features(
+ const image_type& img,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ matrix<double, 0, 1> feats;
+ extract_fhog_features(img, feats, cell_size, filter_rows_padding, filter_cols_padding);
+ return feats;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ inline point image_to_fhog (
+ point p,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_size > 0 &&
+ filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t point image_to_fhog()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_size: " << cell_size
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ // There is a one pixel border around the image.
+ p -= point(1,1);
+ // There is also a 1 "cell" border around the HOG image formation.
+ return p/cell_size - point(1,1) + point((filter_cols_padding-1)/2,(filter_rows_padding-1)/2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle image_to_fhog (
+ const rectangle& rect,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_size > 0 &&
+ filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t rectangle image_to_fhog()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_size: " << cell_size
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ return rectangle(image_to_fhog(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding),
+ image_to_fhog(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline point fhog_to_image (
+ point p,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_size > 0 &&
+ filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t point fhog_to_image()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_size: " << cell_size
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ // Convert to image space and then set to the center of the cell.
+ point offset;
+
+ p = (p+point(1,1)-point((filter_cols_padding-1)/2,(filter_rows_padding-1)/2))*cell_size + point(1,1);
+ if (p.x() >= 0 && p.y() >= 0) offset = point(cell_size/2,cell_size/2);
+ if (p.x() < 0 && p.y() >= 0) offset = point(-cell_size/2,cell_size/2);
+ if (p.x() >= 0 && p.y() < 0) offset = point(cell_size/2,-cell_size/2);
+ if (p.x() < 0 && p.y() < 0) offset = point(-cell_size/2,-cell_size/2);
+ return p + offset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle fhog_to_image (
+ const rectangle& rect,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_size > 0 &&
+ filter_rows_padding > 0 &&
+ filter_cols_padding > 0 ,
+ "\t rectangle fhog_to_image()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_size: " << cell_size
+ << "\n\t filter_rows_padding: " << filter_rows_padding
+ << "\n\t filter_cols_padding: " << filter_cols_padding
+ );
+
+ return rectangle(fhog_to_image(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding),
+ fhog_to_image(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mm1,
+ typename mm2
+ >
+ matrix<unsigned char> draw_fhog(
+ const dlib::array<array2d<T,mm1>,mm2>& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_draw_size > 0 && hog.size()==31,
+ "\t matrix<unsigned char> draw_fhog()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_draw_size: " << cell_draw_size
+ << "\n\t hog.size(): " << hog.size()
+ );
+
+ dlib::array<matrix<float> > mbars;
+ impl_fhog::create_fhog_bar_images(mbars,cell_draw_size);
+
+ // now draw the bars onto the HOG cells
+ matrix<float> himg(hog[0].nr()*cell_draw_size, hog[0].nc()*cell_draw_size);
+ himg = 0;
+ for (unsigned long d = 0; d < mbars.size(); ++d)
+ {
+ for (long r = 0; r < himg.nr(); r+=cell_draw_size)
+ {
+ for (long c = 0; c < himg.nc(); c+=cell_draw_size)
+ {
+ const float val = hog[d][r/cell_draw_size][c/cell_draw_size] +
+ hog[d+mbars.size()][r/cell_draw_size][c/cell_draw_size] +
+ hog[d+mbars.size()*2][r/cell_draw_size][c/cell_draw_size];
+ if (val > min_response_threshold)
+ {
+ set_subm(himg, r, c, cell_draw_size, cell_draw_size) += val*mbars[d%mbars.size()];
+ }
+ }
+ }
+ }
+
+ const float thresh = mean(himg) + 4 * stddev(himg);
+ if (thresh != 0)
+ return matrix_cast<unsigned char>(upperbound(round(himg*255/thresh),255));
+ else
+ return matrix_cast<unsigned char>(himg);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<unsigned char> draw_fhog (
+ const std::vector<matrix<T> >& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_draw_size > 0 && hog.size()==31,
+ "\t matrix<unsigned char> draw_fhog()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_draw_size: " << cell_draw_size
+ << "\n\t hog.size(): " << hog.size()
+ );
+
+ // Just convert the input into the right object and then call the above draw_fhog()
+ // function on it.
+ dlib::array<array2d<T> > temp(hog.size());
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ temp[i].set_size(hog[i].nr(), hog[i].nc());
+ for (long r = 0; r < hog[i].nr(); ++r)
+ {
+ for (long c = 0; c < hog[i].nc(); ++c)
+ {
+ temp[i][r][c] = hog[i](r,c);
+ }
+ }
+ }
+ return draw_fhog(temp,cell_draw_size, min_response_threshold);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mm
+ >
+ matrix<unsigned char> draw_fhog(
+ const array2d<matrix<T,31,1>,mm>& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( cell_draw_size > 0,
+ "\t matrix<unsigned char> draw_fhog()"
+ << "\n\t Invalid inputs were given to this function. "
+ << "\n\t cell_draw_size: " << cell_draw_size
+ );
+
+ dlib::array<matrix<float> > mbars;
+ impl_fhog::create_fhog_bar_images(mbars,cell_draw_size);
+
+ // now draw the bars onto the HOG cells
+ matrix<float> himg(hog.nr()*cell_draw_size, hog.nc()*cell_draw_size);
+ himg = 0;
+ for (unsigned long d = 0; d < mbars.size(); ++d)
+ {
+ for (long r = 0; r < himg.nr(); r+=cell_draw_size)
+ {
+ for (long c = 0; c < himg.nc(); c+=cell_draw_size)
+ {
+ const float val = hog[r/cell_draw_size][c/cell_draw_size](d) +
+ hog[r/cell_draw_size][c/cell_draw_size](d+mbars.size()) +
+ hog[r/cell_draw_size][c/cell_draw_size](d+mbars.size()*2);
+ if (val > min_response_threshold)
+ {
+ set_subm(himg, r, c, cell_draw_size, cell_draw_size) += val*mbars[d%mbars.size()];
+ }
+ }
+ }
+ }
+
+ const float thresh = mean(himg) + 4 * stddev(himg);
+ if (thresh != 0)
+ return matrix_cast<unsigned char>(upperbound(round(himg*255/thresh),255));
+ else
+ return matrix_cast<unsigned char>(himg);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_fHOG_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/fhog_abstract.h b/ml/dlib/dlib/image_transforms/fhog_abstract.h
new file mode 100644
index 000000000..f66c5d55a
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/fhog_abstract.h
@@ -0,0 +1,346 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_fHOG_ABSTRACT_Hh_
+#ifdef DLIB_fHOG_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include "../array2d/array2d_kernel_abstract.h"
+#include "../array/array_kernel_abstract.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T,
+ typename mm
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ array2d<matrix<T,31,1>,mm>& hog,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - T should be float or double
+ ensures
+ - This function implements the HOG feature extraction method described in
+ the paper:
+ Object Detection with Discriminatively Trained Part Based Models by
+ P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan
+ IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010
+ This means that it takes an input image img and outputs Felzenszwalb's
+ 31 dimensional version of HOG features, which are stored into #hog.
+ - The input image is broken into cells that are cell_size by cell_size pixels
+ and within each cell we compute a 31 dimensional FHOG vector. This vector
+ describes the gradient structure within the cell.
+ - A common task is to convolve each channel of the hog image with a linear
+ filter. This is made more convenient if the contents of #hog includes extra
+ rows and columns of zero padding along the borders. This extra padding
+ allows for more efficient convolution code since the code does not need to
+ perform expensive boundary checking. Therefore, you can set
+ filter_rows_padding and filter_cols_padding to indicate the size of the
+ filter you wish to use and this function will ensure #hog has the appropriate
+ extra zero padding along the borders. In particular, it will include the
+ following extra padding:
+ - (filter_rows_padding-1)/2 extra rows of zeros on the top of #hog.
+ - (filter_cols_padding-1)/2 extra columns of zeros on the left of #hog.
+ - filter_rows_padding/2 extra rows of zeros on the bottom of #hog.
+ - filter_cols_padding/2 extra columns of zeros on the right of #hog.
+ Therefore, the extra padding is done such that functions like
+ spatially_filter_image() apply their filters to the entire content containing
+ area of a hog image (note that you should use the following planar version of
+ extract_fhog_features() instead of the interlaced version if you want to use
+ spatially_filter_image() on a hog image).
+ - #hog.nr() == max(round(img.nr()/(double)cell_size)-2,0) + filter_rows_padding-1.
+ - #hog.nc() == max(round(img.nc()/(double)cell_size)-2,0) + filter_cols_padding-1.
+ (i.e. Each output dimension is roughly 1/cell_size the original size but
+ there is a one cell_size border all around the image that is lost and then we
+ add on any additional padding that is requested.)
+ - for all valid r and c:
+ - #hog[r][c] == the FHOG vector describing the cell centered at the pixel location
+ fhog_to_image(point(c,r),cell_size,filter_rows_padding,filter_cols_padding) in img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T,
+ typename mm1,
+ typename mm2
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ dlib::array<array2d<T,mm1>,mm2>& hog,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - T should be float or double
+ ensures
+ - This function is identical to the above extract_fhog_features() routine
+ except that it outputs the results in a planar format rather than the
+ interlaced format used above. That is, each element of the hog vector is
+ placed into one of 31 images inside #hog. To be precise, if vhog is the
+ output of the above interlaced version of extract_fhog_features() then we
+ will have, for all valid r and c:
+ - #hog[i][r][c] == vhog[r][c](i)
+ (where 0 <= i < 31)
+ - #hog.size() == 31
+ - for all valid i:
+ - #hog[i].nr() == hog[0].nr()
+ - #hog[i].nc() == hog[0].nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ matrix<double,0,1> extract_fhog_features(
+ const image_type& img,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - This function calls the above extract_fhog_features() routine and simply
+ packages the entire output into a dlib::matrix. The matrix is constructed
+ using the planar version of extract_fhog_features() and then each output
+ plane is converted into a column vector and subsequently all 31 column
+ vectors are concatenated together and returned.
+ - Each plane is converted into a column vector using reshape_to_column_vector(),
+ and is therefore represented in row major order inside the returned vector.
+ - If H is the array<array2d<double>> object output by the planar
+ extract_fhog_features() then the returned vector is composed by concatenating
+ H[0], then H[1], then H[2], and so on in ascending index order.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_fhog_features(
+ const image_type& img,
+ matrix<T,0,1>& feats,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - T is float, double, or long double
+ ensures
+ - This function is identical to the above version of extract_fhog_features()
+ that returns a matrix<double,0,1> except that it returns the matrix here
+ through a reference argument instead of returning it by value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline point image_to_fhog (
+ point p,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - When using extract_fhog_features(), each FHOG cell is extracted from a
+ certain region in the input image. image_to_fhog() returns the identity of
+ the FHOG cell containing the image pixel at location p. Or in other words,
+ let P == image_to_fhog(p) and hog be a FHOG feature map output by
+ extract_fhog_features(), then hog[P.y()][P.x()] == the FHOG vector/cell
+ containing the point p in the input image. Note that some image points
+ might not have corresponding feature locations. E.g. border points or points
+ outside the image. In these cases the returned point will be outside the
+ input image.
+ - Note that you should use the same values of cell_size, filter_rows_padding,
+ and filter_cols_padding that you used with extract_fhog_features().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle image_to_fhog (
+ const rectangle& rect,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - maps a rectangle from image space to fhog space. In particular this function returns:
+ rectangle(image_to_fhog(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding),
+ image_to_fhog(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline point fhog_to_image (
+ point p,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - Maps a pixel in a FHOG image (produced by extract_fhog_features()) back to the
+ corresponding original input pixel. Note that since FHOG images are
+ spatially downsampled by aggregation into cells the mapping is not totally
+ invertible. Therefore, the returned location will be the center of the cell
+ in the original image that contained the FHOG vector at position p. Moreover,
+ cell_size, filter_rows_padding, and filter_cols_padding should be set to the
+ values used by the call to extract_fhog_features().
+ - Mapping from fhog space to image space is an invertible transformation. That
+ is, for any point P we have P == image_to_fhog(fhog_to_image(P,cell_size,filter_rows_padding,filter_cols_padding),
+ cell_size,filter_rows_padding,filter_cols_padding).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline rectangle fhog_to_image (
+ const rectangle& rect,
+ int cell_size = 8,
+ int filter_rows_padding = 1,
+ int filter_cols_padding = 1
+ );
+ /*!
+ requires
+ - cell_size > 0
+ - filter_rows_padding > 0
+ - filter_cols_padding > 0
+ ensures
+ - maps a rectangle from fhog space to image space. In particular this function returns:
+ rectangle(fhog_to_image(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding),
+ fhog_to_image(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding))
+ - Mapping from fhog space to image space is an invertible transformation. That
+ is, for any rectangle R we have R == image_to_fhog(fhog_to_image(R,cell_size,filter_rows_padding,filter_cols_padding),
+ cell_size,filter_rows_padding,filter_cols_padding).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mm1,
+ typename mm2
+ >
+ matrix<unsigned char> draw_fhog(
+ const dlib::array<array2d<T,mm1>,mm2>& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ );
+ /*!
+ requires
+ - cell_draw_size > 0
+ - hog.size() == 31
+ ensures
+ - Interprets hog as a FHOG feature map output by extract_fhog_features() and
+ converts it into an image suitable for display on the screen. In particular,
+ we draw all the hog cells into a grayscale image in a way that shows the
+ magnitude and orientation of the gradient energy in each cell. The result is
+ then returned.
+ - The size of the cells in the output image will be rendered as cell_draw_size
+ pixels wide and tall.
+ - HOG cells with a response value less than min_response_threshold are not
+ drawn.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<unsigned char> draw_fhog (
+ const std::vector<matrix<T> >& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ );
+ /*!
+ requires
+ - cell_draw_size > 0
+ - hog.size() == 31
+ ensures
+ - This function just converts the given hog object into an array<array2d<T>>
+ and passes it to the above draw_fhog() routine and returns the results.
+ - HOG cells with a response value less than min_response_threshold are not
+ drawn.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mm
+ >
+ matrix<unsigned char> draw_fhog(
+ const array2d<matrix<T,31,1>,mm>& hog,
+ const long cell_draw_size = 15,
+ const float min_response_threshold = 0.0
+ );
+ /*!
+ requires
+ - cell_draw_size > 0
+ ensures
+ - Interprets hog as a FHOG feature map output by extract_fhog_features() and
+ converts it into an image suitable for display on the screen. In particular,
+ we draw all the hog cells into a grayscale image in a way that shows the
+ magnitude and orientation of the gradient energy in each cell. The result is
+ then returned.
+ - The size of the cells in the output image will be rendered as cell_draw_size
+ pixels wide and tall.
+ - HOG cells with a response value less than min_response_threshold are not
+ drawn.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_fHOG_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_transforms/hough_transform.h b/ml/dlib/dlib/image_transforms/hough_transform.h
new file mode 100644
index 000000000..477b4dc2b
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/hough_transform.h
@@ -0,0 +1,358 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_HOUGH_tRANSFORM_Hh_
+#define DLIB_HOUGH_tRANSFORM_Hh_
+
+#include "hough_transform_abstract.h"
+#include "../image_processing/generic_image.h"
+#include "../geometry.h"
+#include "../algs.h"
+#include "assign_image.h"
+#include <limits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class hough_transform
+ {
+
+ public:
+ explicit hough_transform (
+ unsigned long size_
+ ) : _size(size_)
+ {
+ DLIB_CASSERT(size_ > 0,
+ "\t hough_transform::hough_transform(size_)"
+ << "\n\t Invalid arguments given to this function."
+ );
+
+ even_size = _size - (_size%2);
+
+ const point cent = center(rectangle(0,0,size_-1,size_-1));
+ xcos_theta.set_size(size_, size_);
+ ysin_theta.set_size(size_, size_);
+
+ std::vector<double> cos_theta(size_), sin_theta(size_);
+ const double scale = 1<<16;
+ for (unsigned long t = 0; t < size_; ++t)
+ {
+ double theta = t*pi/even_size;
+
+ cos_theta[t] = scale*std::cos(theta)/sqrt_2;
+ sin_theta[t] = scale*std::sin(theta)/sqrt_2;
+ }
+ const double offset = scale*even_size/4.0 + 0.5;
+
+ for (unsigned long c = 0; c < size_; ++c)
+ {
+ const long x = c - cent.x();
+ for (unsigned long t = 0; t < size_; ++t)
+ xcos_theta(c,t) = static_cast<int32>(x*cos_theta[t] + offset);
+ }
+ for (unsigned long r = 0; r < size_; ++r)
+ {
+ const long y = r - cent.y();
+ for (unsigned long t = 0; t < size_; ++t)
+ ysin_theta(r,t) = static_cast<int32>(y*sin_theta[t] + offset);
+ }
+ }
+
+ unsigned long size(
+ ) const { return _size; }
+
+ long nr(
+ ) const { return _size; }
+
+ long nc(
+ ) const { return _size; }
+
+ std::pair<point, point> get_line (
+ const point& p
+ ) const
+ {
+ DLIB_ASSERT(rectangle(0,0,size()-1,size()-1).contains(p) == true,
+ "\t pair<point,point> hough_transform::get_line(point)"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t p: " << p
+ << "\n\t size(): " << size()
+ );
+
+ // First we compute the radius measured in pixels from the center and the theta
+ // angle in radians.
+ typedef dlib::vector<double,2> vect;
+ const rectangle box(0,0,size()-1,size()-1);
+ const vect cent = center(box);
+ double theta = p.x()-cent.x();
+ double radius = p.y()-cent.y();
+ theta = theta*pi/even_size;
+ radius = radius*sqrt_2 + 0.5;
+
+ // now make a line segment on the line.
+ vect v1 = cent + vect(size()+1000,0) + vect(0,radius);
+ vect v2 = cent - vect(size()+1000,0) + vect(0,radius);
+ point p1 = rotate_point(cent, v1, theta);
+ point p2 = rotate_point(cent, v2, theta);
+
+ clip_line_to_rectangle(box, p1, p2);
+
+ return std::make_pair(p1,p2);
+ }
+
+ template <
+ typename image_type
+ >
+ point get_best_hough_point (
+ const point& p,
+ const image_type& himg_
+ )
+ {
+ const const_image_view<image_type> himg(himg_);
+
+ DLIB_ASSERT(himg.nr() == size() && himg.nc() == size() &&
+ rectangle(0,0,size()-1,size()-1).contains(p) == true,
+ "\t point hough_transform::get_best_hough_point()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t himg.nr(): " << himg.nr()
+ << "\n\t himg.nc(): " << himg.nc()
+ << "\n\t size(): " << size()
+ << "\n\t p: " << p
+ );
+
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ COMPILE_TIME_ASSERT(pixel_traits<pixel_type>::grayscale == true);
+ pixel_type best_val = std::numeric_limits<pixel_type>::min();
+ point best_point;
+
+
+ const long max_n8 = (himg.nc()/8)*8;
+ const long max_n4 = (himg.nc()/4)*4;
+ const long r = p.y();
+ const long c = p.x();
+
+ const int32* ysin = &ysin_theta(r,0);
+ const int32* xcos = &xcos_theta(c,0);
+ long t = 0;
+ while(t < max_n8)
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ long rr1 = (*xcos++ + *ysin++)>>16;
+ long rr2 = (*xcos++ + *ysin++)>>16;
+ long rr3 = (*xcos++ + *ysin++)>>16;
+ long rr4 = (*xcos++ + *ysin++)>>16;
+ long rr5 = (*xcos++ + *ysin++)>>16;
+ long rr6 = (*xcos++ + *ysin++)>>16;
+ long rr7 = (*xcos++ + *ysin++)>>16;
+
+ if (himg[rr0][t++] > best_val)
+ {
+ best_val = himg[rr0][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr0;
+ }
+ if (himg[rr1][t++] > best_val)
+ {
+ best_val = himg[rr1][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr1;
+ }
+ if (himg[rr2][t++] > best_val)
+ {
+ best_val = himg[rr2][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr2;
+ }
+ if (himg[rr3][t++] > best_val)
+ {
+ best_val = himg[rr3][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr3;
+ }
+ if (himg[rr4][t++] > best_val)
+ {
+ best_val = himg[rr4][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr4;
+ }
+ if (himg[rr5][t++] > best_val)
+ {
+ best_val = himg[rr5][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr5;
+ }
+ if (himg[rr6][t++] > best_val)
+ {
+ best_val = himg[rr6][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr6;
+ }
+ if (himg[rr7][t++] > best_val)
+ {
+ best_val = himg[rr7][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr7;
+ }
+ }
+ while(t < max_n4)
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ long rr1 = (*xcos++ + *ysin++)>>16;
+ long rr2 = (*xcos++ + *ysin++)>>16;
+ long rr3 = (*xcos++ + *ysin++)>>16;
+ if (himg[rr0][t++] > best_val)
+ {
+ best_val = himg[rr0][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr0;
+ }
+ if (himg[rr1][t++] > best_val)
+ {
+ best_val = himg[rr1][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr1;
+ }
+ if (himg[rr2][t++] > best_val)
+ {
+ best_val = himg[rr2][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr2;
+ }
+ if (himg[rr3][t++] > best_val)
+ {
+ best_val = himg[rr3][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr3;
+ }
+ }
+ while(t < himg.nc())
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ if (himg[rr0][t++] > best_val)
+ {
+ best_val = himg[rr0][t-1];
+ best_point.x() = t-1;
+ best_point.y() = rr0;
+ }
+ }
+
+ return best_point;
+ }
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ const in_image_type& img_,
+ const rectangle& box,
+ out_image_type& himg_
+ ) const
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+
+ DLIB_CASSERT(box.width() == size() && box.height() == size(),
+ "\t hough_transform::hough_transform(size_)"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t box.width(): " << box.width()
+ << "\n\t box.height(): " << box.height()
+ << "\n\t size(): " << size()
+ );
+
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale == true);
+ COMPILE_TIME_ASSERT(pixel_traits<out_pixel_type>::grayscale == true);
+
+ const_image_view<in_image_type> img(img_);
+ image_view<out_image_type> himg(himg_);
+
+ himg.set_size(size(), size());
+ assign_all_pixels(himg, 0);
+
+ const rectangle area = box.intersect(get_rect(img));
+
+ const long max_n8 = (himg.nc()/8)*8;
+ const long max_n4 = (himg.nc()/4)*4;
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ const int32* ysin_base = &ysin_theta(r-box.top(),0);
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ const out_pixel_type val = static_cast<out_pixel_type>(img[r][c]);
+ if (val != 0)
+ {
+ /*
+ // The code in this comment is equivalent to the more complex but
+ // faster code below. We keep this simple version of the Hough
+ // transform implementation here just to document what it's doing
+ // more clearly.
+ const point cent = center(box);
+ const long x = c - cent.x();
+ const long y = r - cent.y();
+ for (long t = 0; t < himg.nc(); ++t)
+ {
+ double theta = t*pi/even_size;
+ double radius = (x*std::cos(theta) + y*std::sin(theta))/sqrt_2 + even_size/2 + 0.5;
+ long rr = static_cast<long>(radius);
+ himg[rr][t] += val;
+ }
+ continue;
+ */
+
+ // Run the speed optimized version of the code in the above
+ // comment.
+ const int32* ysin = ysin_base;
+ const int32* xcos = &xcos_theta(c-box.left(),0);
+ long t = 0;
+ while(t < max_n8)
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ long rr1 = (*xcos++ + *ysin++)>>16;
+ long rr2 = (*xcos++ + *ysin++)>>16;
+ long rr3 = (*xcos++ + *ysin++)>>16;
+ long rr4 = (*xcos++ + *ysin++)>>16;
+ long rr5 = (*xcos++ + *ysin++)>>16;
+ long rr6 = (*xcos++ + *ysin++)>>16;
+ long rr7 = (*xcos++ + *ysin++)>>16;
+
+ himg[rr0][t++] += val;
+ himg[rr1][t++] += val;
+ himg[rr2][t++] += val;
+ himg[rr3][t++] += val;
+ himg[rr4][t++] += val;
+ himg[rr5][t++] += val;
+ himg[rr6][t++] += val;
+ himg[rr7][t++] += val;
+ }
+ while(t < max_n4)
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ long rr1 = (*xcos++ + *ysin++)>>16;
+ long rr2 = (*xcos++ + *ysin++)>>16;
+ long rr3 = (*xcos++ + *ysin++)>>16;
+ himg[rr0][t++] += val;
+ himg[rr1][t++] += val;
+ himg[rr2][t++] += val;
+ himg[rr3][t++] += val;
+ }
+ while(t < himg.nc())
+ {
+ long rr0 = (*xcos++ + *ysin++)>>16;
+ himg[rr0][t++] += val;
+ }
+ }
+ }
+ }
+ }
+
+ private:
+
+ unsigned long _size;
+ unsigned long even_size; // equal to _size if _size is even, otherwise equal to _size-1.
+ matrix<int32> xcos_theta, ysin_theta;
+ };
+}
+
+#endif // DLIB_HOUGH_tRANSFORM_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/hough_transform_abstract.h b/ml/dlib/dlib/image_transforms/hough_transform_abstract.h
new file mode 100644
index 000000000..f0ff2b550
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/hough_transform_abstract.h
@@ -0,0 +1,145 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_
+#ifdef DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_
+
+#include "../geometry.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class hough_transform
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for computing the line finding version of the Hough
+ transform given some kind of edge detection image as input. It also allows
+ the edge pixels to be weighted such that higher weighted edge pixels
+ contribute correspondingly more to the output of the Hough transform,
+ allowing stronger edges to create correspondingly stronger line detections
+ in the final Hough transform.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent accesses to this object
+ without synchronization.
+ !*/
+
+ public:
+
+ explicit hough_transform (
+ unsigned long size_
+ );
+ /*!
+ requires
+ - size_ > 0
+ ensures
+ - This object will compute Hough transforms that are size_ by size_ pixels.
+ This is in terms of both the Hough accumulator array size as well as the
+ input image size.
+ - #size() == size_
+ !*/
+
+ unsigned long size(
+ ) const;
+ /*!
+ ensures
+ - returns the size of the Hough transforms generated by this object. In
+ particular, this object creates Hough transform images that are size() by
+ size() pixels in size.
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns size()
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns size()
+ !*/
+
+ std::pair<point, point> get_line (
+ const point& p
+ ) const;
+ /*!
+ requires
+ - rectangle(0,0,size()-1,size()-1).contains(p) == true
+ (i.e. p must be a point inside the Hough accumulator array)
+ ensures
+ - returns the line segment in the original image space corresponding
+ to Hough transform point p.
+ - The returned points are inside rectangle(0,0,size()-1,size()-1).
+ !*/
+
+ template <
+ typename image_type
+ >
+ point get_best_hough_point (
+ const point& p,
+ const image_type& himg
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - himg.nr() == size()
+ - himg.nc() == size()
+ - rectangle(0,0,size()-1,size()-1).contains(p) == true
+ ensures
+ - This function interprets himg as a Hough image and p as a point in the
+ original image space. Given this, it finds the maximum scoring line that
+ passes though p. That is, it checks all the Hough accumulator bins in
+ himg corresponding to lines though p and returns the location with the
+ largest score.
+ - returns a point X such that get_rect(himg).contains(X) == true
+ !*/
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ const in_image_type& img,
+ const rectangle& box,
+ out_image_type& himg
+ ) const;
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - box.width() == size()
+ - box.height() == size()
+ ensures
+ - Computes the Hough transform of the part of img contained within box.
+ In particular, we do a grayscale version of the Hough transform where any
+ non-zero pixel in img is treated as a potential component of a line and
+ accumulated into the Hough accumulator #himg. However, rather than
+ adding 1 to each relevant accumulator bin we add the value of the pixel
+ in img to each Hough accumulator bin. This means that, if all the
+ pixels in img are 0 or 1 then this routine performs a normal Hough
+ transform. However, if some pixels have larger values then they will be
+ weighted correspondingly more in the resulting Hough transform.
+ - #himg.nr() == size()
+ - #himg.nc() == size()
+ - #himg is the Hough transform of the part of img contained in box. Each
+ point in #himg corresponds to a line in the input box. In particular,
+ the line for #himg[y][x] is given by get_line(point(x,y)). Also, when
+ viewing the #himg image, the x-axis gives the angle of the line and the
+ y-axis the distance of the line from the center of the box.
+ !*/
+
+ };
+}
+
+#endif // DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_transforms/image_pyramid.h b/ml/dlib/dlib/image_transforms/image_pyramid.h
new file mode 100644
index 000000000..3efed30d8
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/image_pyramid.h
@@ -0,0 +1,1238 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_PYRaMID_Hh_
+#define DLIB_IMAGE_PYRaMID_Hh_
+
+#include "image_pyramid_abstract.h"
+#include "../pixel.h"
+#include "../array2d.h"
+#include "../geometry.h"
+#include "spatial_filtering.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class pyramid_disable : noncopyable
+ {
+ public:
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>&
+ ) const
+ {
+ return vector<double,2>(0,0);
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>&
+ ) const
+ {
+ return vector<double,2>(0,0);
+ }
+
+ // -----------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ if (levels == 0)
+ return p;
+ else
+ return vector<double,2>(0,0);
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ if (levels == 0)
+ return p;
+ else
+ return vector<double,2>(0,0);
+ }
+
+ // -----------------------------
+
+ drectangle rect_up (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner()));
+ }
+
+ drectangle rect_up (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ drectangle rect_down (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner()));
+ }
+
+ drectangle rect_down (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ public:
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ // we do this #ifdef stuff to avoid compiler warnings about unused variables.
+#ifdef ENABLE_ASSERTS
+ const in_image_type& original,
+#else
+ const in_image_type& ,
+#endif
+ out_image_type& down
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_same_object(original, down) == false,
+ "\t void pyramid_disable::operator()"
+ << "\n\t is_same_object(original, down): " << is_same_object(original, down)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ set_image_size(down, 0, 0);
+ }
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ image_type& img
+ ) const
+ {
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<pixel_type>::has_alpha == false );
+ set_image_size(img, 0, 0);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ class pyramid_down_2_1 : noncopyable
+ {
+ public:
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p
+ ) const
+ {
+ return p/2.0 - vector<double,2>(1.25,0.75);
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p
+ ) const
+ {
+ return (p + vector<T,2>(1.25,0.75))*2;
+ }
+
+ // -----------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_down(temp);
+ return temp;
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_up(temp);
+ return temp;
+ }
+
+ // -----------------------------
+
+ drectangle rect_up (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner()));
+ }
+
+ drectangle rect_up (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ drectangle rect_down (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner()));
+ }
+
+ drectangle rect_down (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ private:
+ template <typename T, typename U>
+ struct both_images_rgb
+ {
+ typedef typename image_traits<T>::pixel_type T_pix;
+ typedef typename image_traits<U>::pixel_type U_pix;
+ const static bool value = pixel_traits<T_pix>::rgb && pixel_traits<U_pix>::rgb;
+ };
+ public:
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ typename disable_if<both_images_rgb<in_image_type,out_image_type> >::type operator() (
+ const in_image_type& original_,
+ out_image_type& down_
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(original_, down_) == false,
+ "\t void pyramid_down_2_1::operator()"
+ << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ const_image_view<in_image_type> original(original_);
+ image_view<out_image_type> down(down_);
+
+ if (original.nr() <= 8 || original.nc() <= 8)
+ {
+ down.clear();
+ return;
+ }
+
+ typedef typename pixel_traits<in_pixel_type>::basic_pixel_type bp_type;
+ typedef typename promote<bp_type>::type ptype;
+ array2d<ptype> temp_img;
+ temp_img.set_size(original.nr(), (original.nc()-3)/2);
+ down.set_size((original.nr()-3)/2, (original.nc()-3)/2);
+
+
+ // This function applies a 5x5 Gaussian filter to the image. It
+ // does this by separating the filter into its horizontal and vertical
+ // components and then downsamples the image by dropping every other
+ // row and column. Note that we can do these things all together in
+ // one step.
+
+ // apply row filter
+ for (long r = 0; r < temp_img.nr(); ++r)
+ {
+ long oc = 0;
+ for (long c = 0; c < temp_img.nc(); ++c)
+ {
+ ptype pix1;
+ ptype pix2;
+ ptype pix3;
+ ptype pix4;
+ ptype pix5;
+
+ assign_pixel(pix1, original[r][oc]);
+ assign_pixel(pix2, original[r][oc+1]);
+ assign_pixel(pix3, original[r][oc+2]);
+ assign_pixel(pix4, original[r][oc+3]);
+ assign_pixel(pix5, original[r][oc+4]);
+
+ pix2 *= 4;
+ pix3 *= 6;
+ pix4 *= 4;
+
+ assign_pixel(temp_img[r][c], pix1 + pix2 + pix3 + pix4 + pix5);
+ oc += 2;
+ }
+ }
+
+
+ // apply column filter
+ long dr = 0;
+ for (long r = 2; r < temp_img.nr()-2; r += 2)
+ {
+ for (long c = 0; c < temp_img.nc(); ++c)
+ {
+ ptype temp = temp_img[r-2][c] +
+ temp_img[r-1][c]*4 +
+ temp_img[r ][c]*6 +
+ temp_img[r+1][c]*4 +
+ temp_img[r+2][c];
+
+ assign_pixel(down[dr][c],temp/256);
+ }
+ ++dr;
+ }
+
+ }
+
+ private:
+ struct rgbptype
+ {
+ uint16 red;
+ uint16 green;
+ uint16 blue;
+ };
+ public:
+ // ------------------------------------------
+ // OVERLOAD FOR RGB TO RGB IMAGES
+ // ------------------------------------------
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ typename enable_if<both_images_rgb<in_image_type,out_image_type> >::type operator() (
+ const in_image_type& original_,
+ out_image_type& down_
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(original_, down_) == false,
+ "\t void pyramid_down_2_1::operator()"
+ << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ const_image_view<in_image_type> original(original_);
+ image_view<out_image_type> down(down_);
+
+ if (original.nr() <= 8 || original.nc() <= 8)
+ {
+ down.clear();
+ return;
+ }
+
+ array2d<rgbptype> temp_img;
+ temp_img.set_size(original.nr(), (original.nc()-3)/2);
+ down.set_size((original.nr()-3)/2, (original.nc()-3)/2);
+
+
+ // This function applies a 5x5 Gaussian filter to the image. It
+ // does this by separating the filter into its horizontal and vertical
+ // components and then downsamples the image by dropping every other
+ // row and column. Note that we can do these things all together in
+ // one step.
+
+ // apply row filter
+ for (long r = 0; r < temp_img.nr(); ++r)
+ {
+ long oc = 0;
+ for (long c = 0; c < temp_img.nc(); ++c)
+ {
+ rgbptype pix1;
+ rgbptype pix2;
+ rgbptype pix3;
+ rgbptype pix4;
+ rgbptype pix5;
+
+ pix1.red = original[r][oc].red;
+ pix2.red = original[r][oc+1].red;
+ pix3.red = original[r][oc+2].red;
+ pix4.red = original[r][oc+3].red;
+ pix5.red = original[r][oc+4].red;
+ pix1.green = original[r][oc].green;
+ pix2.green = original[r][oc+1].green;
+ pix3.green = original[r][oc+2].green;
+ pix4.green = original[r][oc+3].green;
+ pix5.green = original[r][oc+4].green;
+ pix1.blue = original[r][oc].blue;
+ pix2.blue = original[r][oc+1].blue;
+ pix3.blue = original[r][oc+2].blue;
+ pix4.blue = original[r][oc+3].blue;
+ pix5.blue = original[r][oc+4].blue;
+
+ pix2.red *= 4;
+ pix3.red *= 6;
+ pix4.red *= 4;
+
+ pix2.green *= 4;
+ pix3.green *= 6;
+ pix4.green *= 4;
+
+ pix2.blue *= 4;
+ pix3.blue *= 6;
+ pix4.blue *= 4;
+
+ rgbptype temp;
+ temp.red = pix1.red + pix2.red + pix3.red + pix4.red + pix5.red;
+ temp.green = pix1.green + pix2.green + pix3.green + pix4.green + pix5.green;
+ temp.blue = pix1.blue + pix2.blue + pix3.blue + pix4.blue + pix5.blue;
+
+ temp_img[r][c] = temp;
+
+ oc += 2;
+ }
+ }
+
+
+ // apply column filter
+ long dr = 0;
+ for (long r = 2; r < temp_img.nr()-2; r += 2)
+ {
+ for (long c = 0; c < temp_img.nc(); ++c)
+ {
+ rgbptype temp;
+ temp.red = temp_img[r-2][c].red +
+ temp_img[r-1][c].red*4 +
+ temp_img[r ][c].red*6 +
+ temp_img[r+1][c].red*4 +
+ temp_img[r+2][c].red;
+ temp.green = temp_img[r-2][c].green +
+ temp_img[r-1][c].green*4 +
+ temp_img[r ][c].green*6 +
+ temp_img[r+1][c].green*4 +
+ temp_img[r+2][c].green;
+ temp.blue = temp_img[r-2][c].blue +
+ temp_img[r-1][c].blue*4 +
+ temp_img[r ][c].blue*6 +
+ temp_img[r+1][c].blue*4 +
+ temp_img[r+2][c].blue;
+
+ down[dr][c].red = temp.red/256;
+ down[dr][c].green = temp.green/256;
+ down[dr][c].blue = temp.blue/256;
+ }
+ ++dr;
+ }
+
+ }
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ image_type& img
+ ) const
+ {
+ image_type temp;
+ (*this)(img, temp);
+ swap(temp, img);
+ }
+
+ private:
+
+
+ };
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ class pyramid_down_3_2 : noncopyable
+ {
+ public:
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p
+ ) const
+ {
+ const double ratio = 2.0/3.0;
+ return p*ratio - vector<double,2>(1,1);
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p
+ ) const
+ {
+ const double ratio = 3.0/2.0;
+ return p*ratio + vector<T,2>(ratio,ratio);
+ }
+
+ // -----------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_down(temp);
+ return temp;
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_up(temp);
+ return temp;
+ }
+
+ // -----------------------------
+
+ drectangle rect_up (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner()));
+ }
+
+ drectangle rect_up (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ drectangle rect_down (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner()));
+ }
+
+ drectangle rect_down (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ private:
+ template <typename T, typename U>
+ struct both_images_rgb
+ {
+ typedef typename image_traits<T>::pixel_type T_pix;
+ typedef typename image_traits<U>::pixel_type U_pix;
+ const static bool value = pixel_traits<T_pix>::rgb && pixel_traits<U_pix>::rgb;
+ };
+ public:
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ typename disable_if<both_images_rgb<in_image_type,out_image_type> >::type operator() (
+ const in_image_type& original_,
+ out_image_type& down_
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_same_object(original_, down_) == false,
+ "\t void pyramid_down_3_2::operator()"
+ << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ const_image_view<in_image_type> original(original_);
+ image_view<out_image_type> down(down_);
+
+ if (original.nr() <= 8 || original.nc() <= 8)
+ {
+ down.clear();
+ return;
+ }
+
+ const long size_in = 3;
+ const long size_out = 2;
+
+ typedef typename pixel_traits<in_pixel_type>::basic_pixel_type bp_type;
+ typedef typename promote<bp_type>::type ptype;
+ const long full_nr = size_out*((original.nr()-2)/size_in);
+ const long part_nr = (size_out*(original.nr()-2))/size_in;
+ const long full_nc = size_out*((original.nc()-2)/size_in);
+ const long part_nc = (size_out*(original.nc()-2))/size_in;
+ down.set_size(part_nr, part_nc);
+
+
+ long rr = 1;
+ long r;
+ for (r = 0; r < full_nr; r+=size_out)
+ {
+ long cc = 1;
+ long c;
+ for (c = 0; c < full_nc; c+=size_out)
+ {
+ ptype block[size_in][size_in];
+ separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate block
+ assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256));
+ assign_pixel(down[r][c+1] , (block[0][2]*9 + block[1][2]*3 + block[0][1]*3 + block[1][1])/(16*256));
+ assign_pixel(down[r+1][c] , (block[2][0]*9 + block[1][0]*3 + block[2][1]*3 + block[1][1])/(16*256));
+ assign_pixel(down[r+1][c+1] , (block[2][2]*9 + block[1][2]*3 + block[2][1]*3 + block[1][1])/(16*256));
+
+ cc += size_in;
+ }
+ if (part_nc - full_nc == 1)
+ {
+ ptype block[size_in][2];
+ separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256));
+ assign_pixel(down[r+1][c] , (block[2][0]*9 + block[1][0]*3 + block[2][1]*3 + block[1][1])/(16*256));
+ }
+ rr += size_in;
+ }
+ if (part_nr - full_nr == 1)
+ {
+ long cc = 1;
+ long c;
+ for (c = 0; c < full_nc; c+=size_out)
+ {
+ ptype block[2][size_in];
+ separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256));
+ assign_pixel(down[r][c+1] , (block[0][2]*9 + block[1][2]*3 + block[0][1]*3 + block[1][1])/(16*256));
+
+ cc += size_in;
+ }
+ if (part_nc - full_nc == 1)
+ {
+ ptype block[2][2];
+ separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256));
+ }
+ }
+
+ }
+
+ private:
+ struct rgbptype
+ {
+ uint32 red;
+ uint32 green;
+ uint32 blue;
+ };
+
+ public:
+ // ------------------------------------------
+ // OVERLOAD FOR RGB TO RGB IMAGES
+ // ------------------------------------------
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ typename enable_if<both_images_rgb<in_image_type,out_image_type> >::type operator() (
+ const in_image_type& original_,
+ out_image_type& down_
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(original_, down_) == false,
+ "\t void pyramid_down_3_2::operator()"
+ << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ const_image_view<in_image_type> original(original_);
+ image_view<out_image_type> down(down_);
+
+ if (original.nr() <= 8 || original.nc() <= 8)
+ {
+ down.clear();
+ return;
+ }
+
+ const long size_in = 3;
+ const long size_out = 2;
+
+ const long full_nr = size_out*((original.nr()-2)/size_in);
+ const long part_nr = (size_out*(original.nr()-2))/size_in;
+ const long full_nc = size_out*((original.nc()-2)/size_in);
+ const long part_nc = (size_out*(original.nc()-2))/size_in;
+ down.set_size(part_nr, part_nc);
+
+
+ long rr = 1;
+ long r;
+ for (r = 0; r < full_nr; r+=size_out)
+ {
+ long cc = 1;
+ long c;
+ for (c = 0; c < full_nc; c+=size_out)
+ {
+ rgbptype block[size_in][size_in];
+ separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate block
+ down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+
+ down[r][c+1].red = (block[0][2].red*9 + block[1][2].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c+1].green = (block[0][2].green*9 + block[1][2].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c+1].blue = (block[0][2].blue*9 + block[1][2].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+
+ down[r+1][c].red = (block[2][0].red*9 + block[1][0].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256);
+ down[r+1][c].green = (block[2][0].green*9 + block[1][0].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256);
+ down[r+1][c].blue = (block[2][0].blue*9 + block[1][0].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256);
+
+ down[r+1][c+1].red = (block[2][2].red*9 + block[1][2].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256);
+ down[r+1][c+1].green = (block[2][2].green*9 + block[1][2].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256);
+ down[r+1][c+1].blue = (block[2][2].blue*9 + block[1][2].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256);
+
+ cc += size_in;
+ }
+ if (part_nc - full_nc == 1)
+ {
+ rgbptype block[size_in][2];
+ separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+
+ down[r+1][c].red = (block[2][0].red*9 + block[1][0].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256);
+ down[r+1][c].green = (block[2][0].green*9 + block[1][0].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256);
+ down[r+1][c].blue = (block[2][0].blue*9 + block[1][0].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256);
+ }
+ rr += size_in;
+ }
+ if (part_nr - full_nr == 1)
+ {
+ long cc = 1;
+ long c;
+ for (c = 0; c < full_nc; c+=size_out)
+ {
+ rgbptype block[2][size_in];
+ separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+
+ down[r][c+1].red = (block[0][2].red*9 + block[1][2].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c+1].green = (block[0][2].green*9 + block[1][2].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c+1].blue = (block[0][2].blue*9 + block[1][2].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+
+ cc += size_in;
+ }
+ if (part_nc - full_nc == 1)
+ {
+ rgbptype block[2][2];
+ separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2);
+
+ // bi-linearly interpolate partial block
+ down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256);
+ down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256);
+ down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256);
+ }
+ }
+ }
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ image_type& img
+ ) const
+ {
+ image_type temp;
+ (*this)(img, temp);
+ swap(temp, img);
+ }
+ private:
+
+
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned int N
+ >
+ class pyramid_down : noncopyable
+ {
+ public:
+
+ COMPILE_TIME_ASSERT(N > 0);
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p
+ ) const
+ {
+ const double ratio = (N-1.0)/N;
+ return (p - 0.3)*ratio;
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p
+ ) const
+ {
+ const double ratio = N/(N-1.0);
+ return p*ratio + 0.3;
+ }
+
+ // -----------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_down(temp);
+ return temp;
+ }
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const
+ {
+ vector<double,2> temp = p;
+ for (unsigned int i = 0; i < levels; ++i)
+ temp = point_up(temp);
+ return temp;
+ }
+
+ // -----------------------------
+
+ drectangle rect_up (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner()));
+ }
+
+ drectangle rect_up (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels));
+ }
+
+ // -----------------------------
+
+ drectangle rect_down (
+ const drectangle& rect
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner()));
+ }
+
+ drectangle rect_down (
+ const drectangle& rect,
+ unsigned int levels
+ ) const
+ {
+ return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels));
+ }
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ const in_image_type& original,
+ out_image_type& down
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_same_object(original, down) == false,
+ "\t void pyramid_down::operator()"
+ << "\n\t is_same_object(original, down): " << is_same_object(original, down)
+ << "\n\t this: " << this
+ );
+
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+
+ set_image_size(down, ((N-1)*num_rows(original))/N+0.5, ((N-1)*num_columns(original))/N+0.5);
+ resize_image(original, down);
+ }
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ image_type& img
+ ) const
+ {
+ image_type temp;
+ (*this)(img, temp);
+ swap(temp, img);
+ }
+ };
+
+ template <>
+ class pyramid_down<1> : public pyramid_disable {};
+
+ template <>
+ class pyramid_down<2> : public dlib::impl::pyramid_down_2_1 {};
+
+ template <>
+ class pyramid_down<3> : public dlib::impl::pyramid_down_3_2 {};
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned int N>
+ double pyramid_rate(const pyramid_down<N>&)
+ {
+ return (N-1.0)/N;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned int N>
+ void find_pyramid_down_output_image_size(
+ const pyramid_down<N>& pyr,
+ long& nr,
+ long& nc
+ )
+ {
+ const double rate = pyramid_rate(pyr);
+ nr = std::floor(rate*nr);
+ nc = std::floor(rate*nc);
+ }
+
+ inline void find_pyramid_down_output_image_size(
+ const pyramid_down<3>& /*pyr*/,
+ long& nr,
+ long& nc
+ )
+ {
+ nr = 2*(nr-2)/3;
+ nc = 2*(nc-2)/3;
+ }
+
+ inline void find_pyramid_down_output_image_size(
+ const pyramid_down<2>& /*pyr*/,
+ long& nr,
+ long& nc
+ )
+ {
+ nr = (nr-3)/2;
+ nc = (nc-3)/2;
+ }
+
+ inline void find_pyramid_down_output_image_size(
+ const pyramid_down<1>& /*pyr*/,
+ long& nr,
+ long& nc
+ )
+ {
+ nr = 0;
+ nc = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename pyramid_type>
+ void compute_tiled_image_pyramid_details (
+ const pyramid_type& pyr,
+ long nr,
+ long nc,
+ const unsigned long padding,
+ const unsigned long outer_padding,
+ std::vector<rectangle>& rects,
+ long& pyramid_image_nr,
+ long& pyramid_image_nc
+ )
+ {
+ rects.clear();
+ if (nr*nc == 0)
+ {
+ pyramid_image_nr = 0;
+ pyramid_image_nc = 0;
+ return;
+ }
+
+ const long min_height = 5;
+ rects.reserve(100);
+ rects.push_back(rectangle(nc,nr));
+ // build the whole pyramid
+ while(true)
+ {
+ find_pyramid_down_output_image_size(pyr, nr, nc);
+ if (nr*nc == 0 || nr < min_height)
+ break;
+ rects.push_back(rectangle(nc,nr));
+ }
+
+ // figure out output image size
+ long total_height = 0;
+ for (auto&& i : rects)
+ total_height += i.height()+padding;
+ total_height -= padding*2; // don't add unnecessary padding to the very right side.
+ long height = 0;
+ long prev_width = 0;
+ for (auto&& i : rects)
+ {
+ // Figure out how far we go on the first column. We go until the next image can
+ // fit next to the previous one, which means we can double back for the second
+ // column of images.
+ if (i.width() <= rects[0].width()-prev_width-(long)padding &&
+ (height-rects[0].height())*2 >= (total_height-rects[0].height()))
+ {
+ break;
+ }
+ height += i.height() + padding;
+ prev_width = i.width();
+ }
+ height -= padding; // don't add unnecessary padding to the very right side.
+
+ const long width = rects[0].width();
+ pyramid_image_nr = height+outer_padding*2;
+ pyramid_image_nc = width+outer_padding*2;
+
+
+ long y = outer_padding;
+ size_t i = 0;
+ while(y < height+(long)outer_padding && i < rects.size())
+ {
+ rects[i] = translate_rect(rects[i],point(outer_padding,y));
+ DLIB_ASSERT(rectangle(pyramid_image_nc,pyramid_image_nr).contains(rects[i]));
+ y += rects[i].height()+padding;
+ ++i;
+ }
+ y -= padding;
+ while (i < rects.size())
+ {
+ point p1(outer_padding+width-1,y-1);
+ point p2 = p1 - rects[i].br_corner();
+ rectangle rect(p1,p2);
+ DLIB_ASSERT(rectangle(pyramid_image_nc,pyramid_image_nr).contains(rect));
+ // don't keep going on the last row if it would intersect the original image.
+ if (!rects[0].intersect(rect).is_empty())
+ break;
+
+ rects[i] = rect;
+ y -= rects[i].height()+padding;
+ ++i;
+ }
+
+ // Delete any extraneous rectangles if we broke out of the above loop early due to
+ // intersection with the original image.
+ rects.resize(i);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_type1,
+ typename image_type2
+ >
+ void create_tiled_pyramid (
+ const image_type1& img,
+ image_type2& out_img,
+ std::vector<rectangle>& rects,
+ const unsigned long padding = 10,
+ const unsigned long outer_padding = 0
+ )
+ {
+ DLIB_ASSERT(!is_same_object(img, out_img));
+
+ long out_nr, out_nc;
+ pyramid_type pyr;
+ impl::compute_tiled_image_pyramid_details(pyr, img.nr(), img.nc(), padding, outer_padding, rects, out_nr, out_nc);
+
+ set_image_size(out_img, out_nr, out_nc);
+ assign_all_pixels(out_img, 0);
+
+ if (rects.size() == 0)
+ return;
+
+ // now build the image pyramid into out_img
+ auto si = sub_image(out_img, rects[0]);
+ assign_image(si, img);
+ for (size_t i = 1; i < rects.size(); ++i)
+ {
+ auto s1 = sub_image(out_img, rects[i-1]);
+ auto s2 = sub_image(out_img, rects[i]);
+ pyr(s1,s2);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ dpoint image_to_tiled_pyramid (
+ const std::vector<rectangle>& rects,
+ double scale,
+ dpoint p
+ )
+ {
+ DLIB_CASSERT(rects.size() > 0);
+ DLIB_CASSERT(0 < scale && scale <= 1);
+ pyramid_type pyr;
+ // This scale factor maps this many levels down the pyramid
+ long pyramid_down_iter = static_cast<long>(std::log(scale)/std::log(pyramid_rate(pyr))+0.5);
+ pyramid_down_iter = put_in_range(0, (long)rects.size()-1, pyramid_down_iter);
+
+ return rects[pyramid_down_iter].tl_corner() + pyr.point_down(p, pyramid_down_iter);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ drectangle image_to_tiled_pyramid (
+ const std::vector<rectangle>& rects,
+ double scale,
+ drectangle r
+ )
+ {
+ DLIB_ASSERT(rects.size() > 0);
+ DLIB_ASSERT(0 < scale && scale <= 1);
+ return drectangle(image_to_tiled_pyramid<pyramid_type>(rects, scale, r.tl_corner()),
+ image_to_tiled_pyramid<pyramid_type>(rects, scale, r.br_corner()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ dpoint tiled_pyramid_to_image (
+ const std::vector<rectangle>& rects,
+ dpoint p
+ )
+ {
+ DLIB_CASSERT(rects.size() > 0);
+
+ size_t pyramid_down_iter = nearest_rect(rects, p);
+
+ p -= rects[pyramid_down_iter].tl_corner();
+ pyramid_type pyr;
+ return pyr.point_up(p, pyramid_down_iter);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ drectangle tiled_pyramid_to_image (
+ const std::vector<rectangle>& rects,
+ drectangle r
+ )
+ {
+ DLIB_CASSERT(rects.size() > 0);
+
+ size_t pyramid_down_iter = nearest_rect(rects, dcenter(r));
+
+ dpoint origin = rects[pyramid_down_iter].tl_corner();
+ r = drectangle(r.tl_corner()-origin, r.br_corner()-origin);
+ pyramid_type pyr;
+ return pyr.rect_up(r, pyramid_down_iter);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_PYRaMID_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h b/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h
new file mode 100644
index 000000000..a61b275fd
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h
@@ -0,0 +1,384 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_
+#ifdef DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_
+
+#include "../pixel.h"
+#include "../array2d.h"
+#include "../geometry.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+ template <
+ unsigned int N
+ >
+ class pyramid_down : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON N
+ N > 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple functor to help create image pyramids. In particular, it
+ downsamples images at a ratio of N to N-1.
+
+ Note that setting N to 1 means that this object functions like
+ pyramid_disable (defined at the bottom of this file).
+
+ WARNING, when mapping rectangles from one layer of a pyramid
+ to another you might end up with rectangles which extend slightly
+ outside your images. This is because points on the border of an
+ image at a higher pyramid layer might correspond to points outside
+ images at lower layers. So just keep this in mind. Note also
+ that it's easy to deal with. Just say something like this:
+ rect = rect.intersect(get_rect(my_image)); // keep rect inside my_image
+ !*/
+ public:
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ const in_image_type& original,
+ out_image_type& down
+ ) const;
+ /*!
+ requires
+ - is_same_object(original, down) == false
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - for both pixel types P in the input and output images, we require:
+ - pixel_traits<P>::has_alpha == false
+ ensures
+ - #down will contain an image that is roughly (N-1)/N times the size of the
+ original image.
+ - If both input and output images contain RGB pixels then the downsampled image will
+ be in color. Otherwise, the downsampling will be performed in a grayscale mode.
+ - The location of a point P in original image will show up at point point_down(P)
+ in the #down image.
+ - Note that some points on the border of the original image might correspond to
+ points outside the #down image.
+ !*/
+
+ template <
+ typename image_type
+ >
+ void operator() (
+ image_type& img
+ ) const;
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ ensures
+ - This function downsamples the given image and stores the results in #img.
+ In particular, it is equivalent to performing:
+ (*this)(img, temp);
+ swap(img, temp);
+ !*/
+
+ // -------------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p
+ ) const;
+ /*!
+ ensures
+ - interprets p as a point in a parent image and returns the
+ point in a downsampled image which corresponds to p.
+ - This function is the inverse of point_up(). I.e. for a point P:
+ point_down(point_up(P)) == P
+ !*/
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p
+ ) const;
+ /*!
+ ensures
+ - interprets p as a point in a downsampled image and returns the
+ point in a parent image which corresponds to p.
+ - This function is the inverse of point_down(). I.e. for a point P:
+ point_up(point_down(P)) == P
+ !*/
+
+ drectangle rect_down (
+ const drectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner()));
+ (i.e. maps rect into a downsampled)
+ !*/
+
+ drectangle rect_up (
+ const drectangle& rect
+ ) const;
+ /*!
+ ensures
+ - returns drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner()));
+ (i.e. maps rect into a parent image)
+ !*/
+
+ // -------------------------------
+
+ template <typename T>
+ vector<double,2> point_down (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const;
+ /*!
+ ensures
+ - applies point_down() to p levels times and returns the result.
+ (i.e. point_down(p,2) == point_down(point_down(p)),
+ point_down(p,1) == point_down(p),
+ point_down(p,0) == p, etc. )
+ !*/
+
+ template <typename T>
+ vector<double,2> point_up (
+ const vector<T,2>& p,
+ unsigned int levels
+ ) const;
+ /*!
+ ensures
+ - applies point_up() to p levels times and returns the result.
+ (i.e. point_up(p,2) == point_up(point_up(p)),
+ point_up(p,1) == point_up(p),
+ point_up(p,0) == p, etc. )
+ !*/
+
+ drectangle rect_down (
+ const drectangle& rect,
+ unsigned int levels
+ ) const;
+ /*!
+ ensures
+ - returns drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels));
+ (i.e. Basically applies rect_down() to rect levels times and returns the result.)
+ !*/
+
+ drectangle rect_up (
+ const drectangle& rect,
+ unsigned int levels
+ ) const;
+ /*!
+ ensures
+ - returns drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels));
+ (i.e. Basically applies rect_up() to rect levels times and returns the result.)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class pyramid_disable : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object with an interface identical to pyramid_down (defined
+ at the top of this file) except that it downsamples images at a ratio of infinity
+ to 1. That means it always outputs images of size 0 regardless of the size
+ of the inputs.
+
+ This is useful because it can be supplied to routines which take a pyramid_down
+ function object and it will essentially disable pyramid processing. This way,
+ a pyramid oriented function can be turned into a regular routine which processes
+ just the original undownsampled image.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned int N
+ >
+ double pyramid_rate(
+ const pyramid_down<N>& pyr
+ );
+ /*!
+ ensures
+ - returns (N-1.0)/N
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned int N
+ >
+ void find_pyramid_down_output_image_size(
+ const pyramid_down<N>& pyr,
+ long& nr,
+ long& nc
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ ensures
+ - If pyr() were called on an image with nr by nc rows and columns, what would
+ be the size of the output image? This function finds the size of the output
+ image and stores it back into #nr and #nc.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_type1,
+ typename image_type2
+ >
+ void create_tiled_pyramid (
+ const image_type1& img,
+ image_type2& out_img,
+ std::vector<rectangle>& rects,
+ const unsigned long padding = 10,
+ const unsigned long outer_padding = 0
+ );
+ /*!
+ requires
+ - pyramid_type == one of the dlib::pyramid_down template instances defined above.
+ - is_same_object(img, out_img) == false
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - for both pixel types P in the input and output images, we require:
+ - pixel_traits<P>::has_alpha == false
+ ensures
+ - Creates an image pyramid from the input image img. The pyramid is made using
+ pyramid_type. The highest resolution image is img and then all further
+ pyramid levels are generated from pyramid_type's downsampling. The entire
+ resulting pyramid is packed into a single image and stored in out_img.
+ - When packing pyramid levels into out_img, there will be padding pixels of
+ space between each sub-image. There will also be outer_padding pixels of
+ padding around the edge of the image. All padding pixels have a value of 0.
+ - The resulting pyramid will be composed of #rects.size() images packed into
+ out_img. Moreover, #rects[i] is the location inside out_img of the i-th
+ pyramid level.
+ - #rects.size() > 0
+ - #rects[0] == get_rect(img). I.e. the first rectangle is the highest
+ resolution pyramid layer. Subsequent elements of #rects correspond to
+ smaller and smaller pyramid layers inside out_img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ dpoint image_to_tiled_pyramid (
+ const std::vector<rectangle>& rects,
+ double scale,
+ dpoint p
+ );
+ /*!
+ requires
+ - pyramid_type == one of the dlib::pyramid_down template instances defined above.
+ - 0 < scale <= 1
+ - rects.size() > 0
+ ensures
+ - The function create_tiled_pyramid() converts an image, img, to a "tiled
+ pyramid" called out_img. It also outputs a vector of rectangles, rect, that
+ show where each pyramid layer appears in out_img. Therefore,
+ image_to_tiled_pyramid() allows you to map from coordinates in img (i.e. p)
+ to coordinates in the tiled pyramid out_img, when given the rects metadata.
+
+ So given a point p in img, you can ask, what coordinate in out_img
+ corresponds to img[p.y()][p.x()] when things are scale times smaller? This
+ new coordinate is a location in out_img and is what is returned by this
+ function.
+ - A scale of 1 means we don't move anywhere in the pyramid scale space relative
+ to the input image while smaller values of scale mean we move down the
+ pyramid.
+ - Assumes pyramid_type is the pyramid class used to produce the tiled image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ drectangle image_to_tiled_pyramid (
+ const std::vector<rectangle>& rects,
+ double scale,
+ drectangle r
+ );
+ /*!
+ requires
+ - pyramid_type == one of the dlib::pyramid_down template instances defined above.
+ - 0 < scale <= 1
+ - rects.size() > 0
+ ensures
+ - This function maps from input image space to tiled pyramid coordinate space
+ just as the above image_to_tiled_pyramid() does, except it operates on
+ rectangle objects instead of points.
+ - Assumes pyramid_type is the pyramid class used to produce the tiled image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ dpoint tiled_pyramid_to_image (
+ const std::vector<rectangle>& rects,
+ dpoint p
+ );
+ /*!
+ requires
+ - pyramid_type == one of the dlib::pyramid_down template instances defined above.
+ - rects.size() > 0
+ ensures
+ - This function maps from a coordinate in a tiled pyramid to the corresponding
+ input image coordinate. Therefore, it is essentially the inverse of
+ image_to_tiled_pyramid().
+ - It should be noted that this function isn't always an inverse of
+ image_to_tiled_pyramid(). This is because you can ask
+ image_to_tiled_pyramid() for the coordinates of points outside the input
+ image and they will be mapped to somewhere that doesn't have an inverse. But
+ for points actually inside the image this function performs an approximate
+ inverse mapping.
+ - Assumes pyramid_type is the pyramid class used to produce the tiled image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type
+ >
+ drectangle tiled_pyramid_to_image (
+ const std::vector<rectangle>& rects,
+ drectangle r
+ );
+ /*!
+ requires
+ - pyramid_type == one of the dlib::pyramid_down template instances defined above.
+ - rects.size() > 0
+ ensures
+ - This function maps from a coordinate in a tiled pyramid to the corresponding
+ input image coordinate. Therefore, it is essentially the inverse of
+ image_to_tiled_pyramid().
+ - It should be noted that this function isn't always an inverse of
+ image_to_tiled_pyramid(). This is because you can ask
+ image_to_tiled_pyramid() for the coordinates of points outside the input
+ image and they will be mapped to somewhere that doesn't have an inverse. But
+ for points actually inside the image this function performs an approximate
+ inverse mapping.
+ - Assumes pyramid_type is the pyramid class used to produce the tiled image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_transforms/integral_image.h b/ml/dlib/dlib/image_transforms/integral_image.h
new file mode 100644
index 000000000..2ae47d921
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/integral_image.h
@@ -0,0 +1,190 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_INTEGRAL_IMAGE
+#define DLIB_INTEGRAL_IMAGE
+
+#include "integral_image_abstract.h"
+
+#include "../algs.h"
+#include "../assert.h"
+#include "../geometry.h"
+#include "../array2d.h"
+#include "../matrix.h"
+#include "../pixel.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class integral_image_generic : noncopyable
+ {
+ public:
+ typedef T value_type;
+
+ long nr() const { return int_img.nr(); }
+ long nc() const { return int_img.nc(); }
+
+ template <typename image_type>
+ void load (
+ const image_type& img_
+ )
+ {
+ const_image_view<image_type> img(img_);
+ T pixel;
+ int_img.set_size(img.nr(), img.nc());
+
+ // compute the first row of the integral image
+ T temp = 0;
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ assign_pixel(pixel, img[0][c]);
+ temp += pixel;
+ int_img[0][c] = temp;
+ }
+
+ // now compute the rest of the integral image
+ for (long r = 1; r < img.nr(); ++r)
+ {
+ temp = 0;
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ assign_pixel(pixel, img[r][c]);
+ temp += pixel;
+ int_img[r][c] = temp + int_img[r-1][c];
+ }
+ }
+
+ }
+
+ value_type get_sum_of_area (
+ const rectangle& rect
+ ) const
+ {
+ DLIB_ASSERT(get_rect(*this).contains(rect) == true && rect.is_empty() == false,
+ "\tvalue_type get_sum_of_area(rect)"
+ << "\n\tYou have given a rectangle that goes outside the image"
+ << "\n\tthis: " << this
+ << "\n\trect.is_empty(): " << rect.is_empty()
+ << "\n\trect: " << rect
+ << "\n\tget_rect(*this): " << get_rect(*this)
+ );
+
+ T top_left = 0, top_right = 0, bottom_left = 0, bottom_right = 0;
+
+ bottom_right = int_img[rect.bottom()][rect.right()];
+ if (rect.left()-1 >= 0 && rect.top()-1 >= 0)
+ {
+ top_left = int_img[rect.top()-1][rect.left()-1];
+ bottom_left = int_img[rect.bottom()][rect.left()-1];
+ top_right = int_img[rect.top()-1][rect.right()];
+ }
+ else if (rect.left()-1 >= 0)
+ {
+ bottom_left = int_img[rect.bottom()][rect.left()-1];
+ }
+ else if (rect.top()-1 >= 0)
+ {
+ top_right = int_img[rect.top()-1][rect.right()];
+ }
+
+ return bottom_right - bottom_left - top_right + top_left;
+ }
+
+ void swap(integral_image_generic& item)
+ {
+ int_img.swap(item.int_img);
+ }
+
+ private:
+
+ array2d<T> int_img;
+ };
+
+
+ template <
+ typename T
+ >
+ void swap (
+ integral_image_generic<T>& a,
+ integral_image_generic<T>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ typedef integral_image_generic<long> integral_image;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type>
+ typename integral_image_type::value_type haar_x (
+ const integral_image_type& img,
+ const point& p,
+ long width
+ )
+ {
+ DLIB_ASSERT(get_rect(img).contains(centered_rect(p,width,width)) == true,
+ "\tlong haar_x(img,p,width)"
+ << "\n\tYou have given a point and with that goes outside the image"
+ << "\n\tget_rect(img): " << get_rect(img)
+ << "\n\tp: " << p
+ << "\n\twidth: " << width
+ );
+
+ rectangle left_rect;
+ left_rect.set_left ( p.x() - width / 2 );
+ left_rect.set_top ( p.y() - width / 2 );
+ left_rect.set_right ( p.x()-1 );
+ left_rect.set_bottom ( left_rect.top() + width - 1 );
+
+ rectangle right_rect;
+ right_rect.set_left ( p.x() );
+ right_rect.set_top ( left_rect.top() );
+ right_rect.set_right ( left_rect.left() + width -1 );
+ right_rect.set_bottom ( left_rect.bottom() );
+
+ return img.get_sum_of_area(right_rect) - img.get_sum_of_area(left_rect);
+ }
+
+ // ----------------------------------------------------------------------------
+
+ template <typename integral_image_type>
+ typename integral_image_type::value_type haar_y (
+ const integral_image_type& img,
+ const point& p,
+ long width
+ )
+ {
+ DLIB_ASSERT(get_rect(img).contains(centered_rect(p,width,width)) == true,
+ "\tlong haar_y(img,p,width)"
+ << "\n\tYou have given a point and with that goes outside the image"
+ << "\n\tget_rect(img): " << get_rect(img)
+ << "\n\tp: " << p
+ << "\n\twidth: " << width
+ );
+
+ rectangle top_rect;
+ top_rect.set_left ( p.x() - width / 2 );
+ top_rect.set_top ( p.y() - width / 2 );
+ top_rect.set_right ( top_rect.left() + width - 1 );
+ top_rect.set_bottom ( p.y()-1 );
+
+ rectangle bottom_rect;
+ bottom_rect.set_left ( top_rect.left() );
+ bottom_rect.set_top ( p.y() );
+ bottom_rect.set_right ( top_rect.right() );
+ bottom_rect.set_bottom ( top_rect.top() + width - 1 );
+
+ return img.get_sum_of_area(bottom_rect) - img.get_sum_of_area(top_rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_INTEGRAL_IMAGE
+
diff --git a/ml/dlib/dlib/image_transforms/integral_image_abstract.h b/ml/dlib/dlib/image_transforms/integral_image_abstract.h
new file mode 100644
index 000000000..583fa0375
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/integral_image_abstract.h
@@ -0,0 +1,169 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_INTEGRAL_IMAGe_ABSTRACT_
+#ifdef DLIB_INTEGRAL_IMAGe_ABSTRACT_
+
+#include "../geometry/rectangle_abstract.h"
+#include "../array2d/array2d_kernel_abstract.h"
+#include "../pixel.h"
+#include "../noncopyable.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class integral_image_generic : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON T
+ T should be a built in scalar type. Moreover, it should
+ be capable of storing sums of whatever kind of pixel
+ you will be dealing with.
+
+ INITIAL VALUE
+ - nr() == 0
+ - nc() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is an alternate way of representing image data
+ that allows for very fast computations of sums of pixels in
+ rectangular regions. To use this object you load it with a
+ normal image and then you can use the get_sum_of_area()
+ function to compute sums of pixels in a given area in
+ constant time.
+ !*/
+ public:
+ typedef T value_type;
+
+ const long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this integral image object
+ !*/
+
+ const long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this integral image object
+ !*/
+
+ template <typename image_type>
+ void load (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - Let P denote the type of pixel in img, then we require:
+ - pixel_traits<P>::has_alpha == false
+ ensures
+ - #nr() == img.nr()
+ - #nc() == img.nc()
+ - #*this will now contain an "integral image" representation of the
+ given input image.
+ !*/
+
+ value_type get_sum_of_area (
+ const rectangle& rect
+ ) const;
+ /*!
+ requires
+ - rect.is_empty() == false
+ - get_rect(*this).contains(rect) == true
+ (i.e. rect must not be outside the integral image)
+ ensures
+ - Let O denote the image this integral image was generated from.
+ Then this function returns sum(subm(mat(O),rect)).
+ That is, this function returns the sum of the pixels in O that
+ are contained within the given rectangle.
+ !*/
+
+ void swap(
+ integral_image_generic& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template < typename T >
+ void swap (
+ integral_image_generic<T>& a,
+ integral_image_generic<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ typedef integral_image_generic<long> integral_image;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type>
+ typename integral_image_type::value_type haar_x (
+ const integral_image_type& img,
+ const point& p,
+ long width
+ )
+ /*!
+ requires
+ - get_rect(img).contains(centered_rect(p,width,width)) == true
+ - integral_image_type == a type that implements the integral_image_generic
+ interface defined above
+ ensures
+ - returns the response of a Haar wavelet centered at the point p
+ with the given width. The wavelet is oriented along the X axis
+ and has the following shape:
+ ----++++
+ ----++++
+ ----++++
+ ----++++
+ That is, the wavelet is square and computes the sum of pixels on the
+ right minus the sum of pixels on the left.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename integral_image_type>
+ typename integral_image_type::value_type haar_y (
+ const integral_image_type& img,
+ const point& p,
+ long width
+ )
+ /*!
+ requires
+ - get_rect(img).contains(centered_rect(p,width,width)) == true
+ - integral_image_type == a type that implements the integral_image_generic
+ interface defined above
+ ensures
+ - returns the response of a Haar wavelet centered at the point p
+ with the given width in the given image. The wavelet is oriented
+ along the Y axis and has the following shape:
+ --------
+ --------
+ ++++++++
+ ++++++++
+ That is, the wavelet is square and computes the sum of pixels on the
+ bottom minus the sum of pixels on the top.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_INTEGRAL_IMAGe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/image_transforms/interpolation.h b/ml/dlib/dlib/image_transforms/interpolation.h
new file mode 100644
index 000000000..11c561e2d
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/interpolation.h
@@ -0,0 +1,2193 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_INTERPOlATIONh_
+#define DLIB_INTERPOlATIONh_
+
+#include "interpolation_abstract.h"
+#include "../pixel.h"
+#include "../matrix.h"
+#include "assign_image.h"
+#include "image_pyramid.h"
+#include "../simd.h"
+#include "../image_processing/full_object_detection.h"
+#include <limits>
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct sub_image_proxy
+ {
+ sub_image_proxy() = default;
+
+ sub_image_proxy (
+ T& img,
+ rectangle rect
+ )
+ {
+ rect = rect.intersect(get_rect(img));
+ typedef typename image_traits<T>::pixel_type pixel_type;
+
+ _nr = rect.height();
+ _nc = rect.width();
+ _width_step = width_step(img);
+ _data = (char*)image_data(img) + sizeof(pixel_type)*rect.left() + rect.top()*_width_step;
+ }
+
+ void* _data = 0;
+ long _width_step = 0;
+ long _nr = 0;
+ long _nc = 0;
+ };
+
+ template <typename T>
+ struct const_sub_image_proxy
+ {
+ const_sub_image_proxy() = default;
+
+ const_sub_image_proxy (
+ const T& img,
+ rectangle rect
+ )
+ {
+ rect = rect.intersect(get_rect(img));
+ typedef typename image_traits<T>::pixel_type pixel_type;
+
+ _nr = rect.height();
+ _nc = rect.width();
+ _width_step = width_step(img);
+ _data = (const char*)image_data(img) + sizeof(pixel_type)*rect.left() + rect.top()*_width_step;
+ }
+
+ const void* _data = 0;
+ long _width_step = 0;
+ long _nr = 0;
+ long _nc = 0;
+ };
+
+ template <typename T>
+ struct image_traits<sub_image_proxy<T> >
+ {
+ typedef typename image_traits<T>::pixel_type pixel_type;
+ };
+ template <typename T>
+ struct image_traits<const sub_image_proxy<T> >
+ {
+ typedef typename image_traits<T>::pixel_type pixel_type;
+ };
+ template <typename T>
+ struct image_traits<const_sub_image_proxy<T> >
+ {
+ typedef typename image_traits<T>::pixel_type pixel_type;
+ };
+ template <typename T>
+ struct image_traits<const const_sub_image_proxy<T> >
+ {
+ typedef typename image_traits<T>::pixel_type pixel_type;
+ };
+
+ template <typename T>
+ inline long num_rows( const sub_image_proxy<T>& img) { return img._nr; }
+ template <typename T>
+ inline long num_columns( const sub_image_proxy<T>& img) { return img._nc; }
+
+ template <typename T>
+ inline long num_rows( const const_sub_image_proxy<T>& img) { return img._nr; }
+ template <typename T>
+ inline long num_columns( const const_sub_image_proxy<T>& img) { return img._nc; }
+
+ template <typename T>
+ inline void* image_data( sub_image_proxy<T>& img)
+ {
+ return img._data;
+ }
+ template <typename T>
+ inline const void* image_data( const sub_image_proxy<T>& img)
+ {
+ return img._data;
+ }
+
+ template <typename T>
+ inline const void* image_data( const const_sub_image_proxy<T>& img)
+ {
+ return img._data;
+ }
+
+ template <typename T>
+ inline long width_step(
+ const sub_image_proxy<T>& img
+ ) { return img._width_step; }
+
+ template <typename T>
+ inline long width_step(
+ const const_sub_image_proxy<T>& img
+ ) { return img._width_step; }
+
+ template <typename T>
+ void set_image_size(sub_image_proxy<T>& img, long rows, long cols)
+ {
+ DLIB_CASSERT(img._nr == rows && img._nc == cols, "A sub_image can't be resized."
+ << "\n\t img._nr: "<< img._nr
+ << "\n\t img._nc: "<< img._nc
+ << "\n\t rows: "<< rows
+ << "\n\t cols: "<< cols
+ );
+ }
+
+ template <
+ typename image_type
+ >
+ sub_image_proxy<image_type> sub_image (
+ image_type& img,
+ const rectangle& rect
+ )
+ {
+ return sub_image_proxy<image_type>(img,rect);
+ }
+
+ template <
+ typename image_type
+ >
+ const const_sub_image_proxy<image_type> sub_image (
+ const image_type& img,
+ const rectangle& rect
+ )
+ {
+ return const_sub_image_proxy<image_type>(img,rect);
+ }
+
+ template <typename T>
+ inline sub_image_proxy<matrix<T>> sub_image (
+ T* img,
+ long nr,
+ long nc,
+ long row_stride
+ )
+ {
+ sub_image_proxy<matrix<T>> tmp;
+ tmp._data = img;
+ tmp._nr = nr;
+ tmp._nc = nc;
+ tmp._width_step = row_stride*sizeof(T);
+ return tmp;
+ }
+
+ template <typename T>
+ inline const const_sub_image_proxy<matrix<T>> sub_image (
+ const T* img,
+ long nr,
+ long nc,
+ long row_stride
+ )
+ {
+ const_sub_image_proxy<matrix<T>> tmp;
+ tmp._data = img;
+ tmp._nr = nr;
+ tmp._nc = nc;
+ tmp._width_step = row_stride*sizeof(T);
+ return tmp;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_nearest_neighbor
+ {
+ public:
+
+ template <typename image_view_type, typename pixel_type>
+ bool operator() (
+ const image_view_type& img,
+ const dlib::point& p,
+ pixel_type& result
+ ) const
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_view_type::pixel_type>::has_alpha == false);
+
+ if (get_rect(img).contains(p))
+ {
+ assign_pixel(result, img[p.y()][p.x()]);
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_bilinear
+ {
+ template <typename T>
+ struct is_rgb_image
+ {
+ const static bool value = pixel_traits<typename T::pixel_type>::rgb;
+ };
+
+ public:
+
+ template <typename T, typename image_view_type, typename pixel_type>
+ typename disable_if<is_rgb_image<image_view_type>,bool>::type operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_view_type::pixel_type>::has_alpha == false);
+
+ const long left = static_cast<long>(std::floor(p.x()));
+ const long top = static_cast<long>(std::floor(p.y()));
+ const long right = left+1;
+ const long bottom = top+1;
+
+
+ // if the interpolation goes outside img
+ if (!(left >= 0 && top >= 0 && right < img.nc() && bottom < img.nr()))
+ return false;
+
+ const double lr_frac = p.x() - left;
+ const double tb_frac = p.y() - top;
+
+ double tl = 0, tr = 0, bl = 0, br = 0;
+
+ assign_pixel(tl, img[top][left]);
+ assign_pixel(tr, img[top][right]);
+ assign_pixel(bl, img[bottom][left]);
+ assign_pixel(br, img[bottom][right]);
+
+ double temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ assign_pixel(result, temp);
+ return true;
+ }
+
+ template <typename T, typename image_view_type, typename pixel_type>
+ typename enable_if<is_rgb_image<image_view_type>,bool>::type operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_view_type::pixel_type>::has_alpha == false);
+
+ const long left = static_cast<long>(std::floor(p.x()));
+ const long top = static_cast<long>(std::floor(p.y()));
+ const long right = left+1;
+ const long bottom = top+1;
+
+
+ // if the interpolation goes outside img
+ if (!(left >= 0 && top >= 0 && right < img.nc() && bottom < img.nr()))
+ return false;
+
+ const double lr_frac = p.x() - left;
+ const double tb_frac = p.y() - top;
+
+ double tl, tr, bl, br;
+
+ tl = img[top][left].red;
+ tr = img[top][right].red;
+ bl = img[bottom][left].red;
+ br = img[bottom][right].red;
+ const double red = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ tl = img[top][left].green;
+ tr = img[top][right].green;
+ bl = img[bottom][left].green;
+ br = img[bottom][right].green;
+ const double green = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ tl = img[top][left].blue;
+ tr = img[top][right].blue;
+ bl = img[bottom][left].blue;
+ br = img[bottom][right].blue;
+ const double blue = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ rgb_pixel temp;
+ assign_pixel(temp.red, red);
+ assign_pixel(temp.green, green);
+ assign_pixel(temp.blue, blue);
+ assign_pixel(result, temp);
+ return true;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_quadratic
+ {
+ template <typename T>
+ struct is_rgb_image
+ {
+ const static bool value = pixel_traits<typename T::pixel_type>::rgb;
+ };
+
+ public:
+
+ template <typename T, typename image_view_type, typename pixel_type>
+ typename disable_if<is_rgb_image<image_view_type>,bool>::type operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_view_type::pixel_type>::has_alpha == false);
+
+ const point pp(p);
+
+ // if the interpolation goes outside img
+ if (!get_rect(img).contains(grow_rect(pp,1)))
+ return false;
+
+ const long r = pp.y();
+ const long c = pp.x();
+
+ const double temp = interpolate(p-pp,
+ img[r-1][c-1],
+ img[r-1][c ],
+ img[r-1][c+1],
+ img[r ][c-1],
+ img[r ][c ],
+ img[r ][c+1],
+ img[r+1][c-1],
+ img[r+1][c ],
+ img[r+1][c+1]);
+
+ assign_pixel(result, temp);
+ return true;
+ }
+
+ template <typename T, typename image_view_type, typename pixel_type>
+ typename enable_if<is_rgb_image<image_view_type>,bool>::type operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_view_type::pixel_type>::has_alpha == false);
+
+ const point pp(p);
+
+ // if the interpolation goes outside img
+ if (!get_rect(img).contains(grow_rect(pp,1)))
+ return false;
+
+ const long r = pp.y();
+ const long c = pp.x();
+
+ const double red = interpolate(p-pp,
+ img[r-1][c-1].red,
+ img[r-1][c ].red,
+ img[r-1][c+1].red,
+ img[r ][c-1].red,
+ img[r ][c ].red,
+ img[r ][c+1].red,
+ img[r+1][c-1].red,
+ img[r+1][c ].red,
+ img[r+1][c+1].red);
+ const double green = interpolate(p-pp,
+ img[r-1][c-1].green,
+ img[r-1][c ].green,
+ img[r-1][c+1].green,
+ img[r ][c-1].green,
+ img[r ][c ].green,
+ img[r ][c+1].green,
+ img[r+1][c-1].green,
+ img[r+1][c ].green,
+ img[r+1][c+1].green);
+ const double blue = interpolate(p-pp,
+ img[r-1][c-1].blue,
+ img[r-1][c ].blue,
+ img[r-1][c+1].blue,
+ img[r ][c-1].blue,
+ img[r ][c ].blue,
+ img[r ][c+1].blue,
+ img[r+1][c-1].blue,
+ img[r+1][c ].blue,
+ img[r+1][c+1].blue);
+
+
+ rgb_pixel temp;
+ assign_pixel(temp.red, red);
+ assign_pixel(temp.green, green);
+ assign_pixel(temp.blue, blue);
+ assign_pixel(result, temp);
+
+ return true;
+ }
+
+ private:
+
+ /* tl tm tr
+ ml mm mr
+ bl bm br
+ */
+ // The above is the pixel layout in our little 3x3 neighborhood. interpolate() will
+ // fit a quadratic to these 9 pixels and then use that quadratic to find the interpolated
+ // value at point p.
+ inline double interpolate(
+ const dlib::vector<double,2>& p,
+ double tl, double tm, double tr,
+ double ml, double mm, double mr,
+ double bl, double bm, double br
+ ) const
+ {
+ matrix<double,6,1> w;
+ // x
+ w(0) = (tr + mr + br - tl - ml - bl)*0.16666666666;
+ // y
+ w(1) = (bl + bm + br - tl - tm - tr)*0.16666666666;
+ // x^2
+ w(2) = (tl + tr + ml + mr + bl + br)*0.16666666666 - (tm + mm + bm)*0.333333333;
+ // x*y
+ w(3) = (tl - tr - bl + br)*0.25;
+ // y^2
+ w(4) = (tl + tm + tr + bl + bm + br)*0.16666666666 - (ml + mm + mr)*0.333333333;
+ // 1 (constant term)
+ w(5) = (tm + ml + mr + bm)*0.222222222 - (tl + tr + bl + br)*0.11111111 + (mm)*0.55555556;
+
+ const double x = p.x();
+ const double y = p.y();
+
+ matrix<double,6,1> z;
+ z = x, y, x*x, x*y, y*y, 1.0;
+
+ return dot(w,z);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class black_background
+ {
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& p) const { assign_pixel(p, 0); }
+ };
+
+ class white_background
+ {
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& p) const { assign_pixel(p, 255); }
+ };
+
+ class no_background
+ {
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& ) const { }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type,
+ typename background_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point,
+ const background_type& set_background,
+ const rectangle& area
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( get_rect(out_img).contains(area) == true &&
+ is_same_object(in_img, out_img) == false ,
+ "\t void transform_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t get_rect(out_img).contains(area): " << get_rect(out_img).contains(area)
+ << "\n\t get_rect(out_img): " << get_rect(out_img)
+ << "\n\t area: " << area
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ const_image_view<image_type1> imgv(in_img);
+ image_view<image_type2> out_imgv(out_img);
+
+ for (long r = area.top(); r <= area.bottom(); ++r)
+ {
+ for (long c = area.left(); c <= area.right(); ++c)
+ {
+ if (!interp(imgv, map_point(dlib::vector<double,2>(c,r)), out_imgv[r][c]))
+ set_background(out_imgv[r][c]);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type,
+ typename background_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point,
+ const background_type& set_background
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void transform_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ transform_image(in_img, out_img, interp, map_point, set_background, get_rect(out_img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void transform_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+
+ transform_image(in_img, out_img, interp, map_point, black_background(), get_rect(out_img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ point_transform_affine rotate_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ double angle,
+ const interpolation_type& interp
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t point_transform_affine rotate_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ const rectangle rimg = get_rect(in_img);
+
+
+ // figure out bounding box for rotated rectangle
+ rectangle rect;
+ rect += rotate_point(center(rimg), rimg.tl_corner(), -angle);
+ rect += rotate_point(center(rimg), rimg.tr_corner(), -angle);
+ rect += rotate_point(center(rimg), rimg.bl_corner(), -angle);
+ rect += rotate_point(center(rimg), rimg.br_corner(), -angle);
+ set_image_size(out_img, rect.height(), rect.width());
+
+ const matrix<double,2,2> R = rotation_matrix(angle);
+
+ point_transform_affine trans = point_transform_affine(R, -R*dcenter(get_rect(out_img)) + dcenter(rimg));
+ transform_image(in_img, out_img, interp, trans);
+ return inv(trans);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ point_transform_affine rotate_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ double angle
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t point_transform_affine rotate_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ return rotate_image(in_img, out_img, angle, interpolate_quadratic());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class helper_resize_image
+ {
+ public:
+ helper_resize_image(
+ double x_scale_,
+ double y_scale_
+ ):
+ x_scale(x_scale_),
+ y_scale(y_scale_)
+ {}
+
+ dlib::vector<double,2> operator() (
+ const dlib::vector<double,2>& p
+ ) const
+ {
+ return dlib::vector<double,2>(p.x()*x_scale, p.y()*y_scale);
+ }
+
+ private:
+ const double x_scale;
+ const double y_scale;
+ };
+ }
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void resize_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ const double x_scale = (num_columns(in_img)-1)/(double)std::max<long>((num_columns(out_img)-1),1);
+ const double y_scale = (num_rows(in_img)-1)/(double)std::max<long>((num_rows(out_img)-1),1);
+ transform_image(in_img, out_img, interp,
+ dlib::impl::helper_resize_image(x_scale,y_scale));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ struct is_rgb_image { const static bool value = pixel_traits<typename image_traits<image_type>::pixel_type>::rgb; };
+ template <typename image_type>
+ struct is_grayscale_image { const static bool value = pixel_traits<typename image_traits<image_type>::pixel_type>::grayscale; };
+
+ // This is an optimized version of resize_image for the case where bilinear
+ // interpolation is used.
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ typename disable_if_c<(is_rgb_image<image_type1>::value&&is_rgb_image<image_type2>::value) ||
+ (is_grayscale_image<image_type1>::value&&is_grayscale_image<image_type2>::value)>::type
+ resize_image (
+ const image_type1& in_img_,
+ image_type2& out_img_,
+ interpolate_bilinear
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img_, out_img_) == false ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_)
+ );
+
+ const_image_view<image_type1> in_img(in_img_);
+ image_view<image_type2> out_img(out_img_);
+
+ if (out_img.size() == 0 || in_img.size() == 0)
+ return;
+
+
+ typedef typename image_traits<image_type1>::pixel_type T;
+ typedef typename image_traits<image_type2>::pixel_type U;
+ const double x_scale = (in_img.nc()-1)/(double)std::max<long>((out_img.nc()-1),1);
+ const double y_scale = (in_img.nr()-1)/(double)std::max<long>((out_img.nr()-1),1);
+ double y = -y_scale;
+ for (long r = 0; r < out_img.nr(); ++r)
+ {
+ y += y_scale;
+ const long top = static_cast<long>(std::floor(y));
+ const long bottom = std::min(top+1, in_img.nr()-1);
+ const double tb_frac = y - top;
+ double x = -x_scale;
+ if (pixel_traits<U>::grayscale)
+ {
+ for (long c = 0; c < out_img.nc(); ++c)
+ {
+ x += x_scale;
+ const long left = static_cast<long>(std::floor(x));
+ const long right = std::min(left+1, in_img.nc()-1);
+ const double lr_frac = x - left;
+
+ double tl = 0, tr = 0, bl = 0, br = 0;
+
+ assign_pixel(tl, in_img[top][left]);
+ assign_pixel(tr, in_img[top][right]);
+ assign_pixel(bl, in_img[bottom][left]);
+ assign_pixel(br, in_img[bottom][right]);
+
+ double temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ assign_pixel(out_img[r][c], temp);
+ }
+ }
+ else
+ {
+ for (long c = 0; c < out_img.nc(); ++c)
+ {
+ x += x_scale;
+ const long left = static_cast<long>(std::floor(x));
+ const long right = std::min(left+1, in_img.nc()-1);
+ const double lr_frac = x - left;
+
+ const T tl = in_img[top][left];
+ const T tr = in_img[top][right];
+ const T bl = in_img[bottom][left];
+ const T br = in_img[bottom][right];
+
+ T temp;
+ assign_pixel(temp, 0);
+ vector_to_pixel(temp,
+ (1-tb_frac)*((1-lr_frac)*pixel_to_vector<double>(tl) + lr_frac*pixel_to_vector<double>(tr)) +
+ tb_frac*((1-lr_frac)*pixel_to_vector<double>(bl) + lr_frac*pixel_to_vector<double>(br)));
+ assign_pixel(out_img[r][c], temp);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ struct images_have_same_pixel_types
+ {
+ typedef typename image_traits<image_type1>::pixel_type ptype1;
+ typedef typename image_traits<image_type2>::pixel_type ptype2;
+ const static bool value = is_same_type<ptype1, ptype2>::value;
+ };
+
+ template <
+ typename image_type,
+ typename image_type2
+ >
+ typename enable_if_c<is_grayscale_image<image_type>::value && is_grayscale_image<image_type2>::value && images_have_same_pixel_types<image_type,image_type2>::value>::type
+ resize_image (
+ const image_type& in_img_,
+ image_type2& out_img_,
+ interpolate_bilinear
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img_, out_img_) == false ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_)
+ );
+
+ const_image_view<image_type> in_img(in_img_);
+ image_view<image_type2> out_img(out_img_);
+
+ if (out_img.size() == 0 || in_img.size() == 0)
+ return;
+
+ typedef typename image_traits<image_type>::pixel_type T;
+ const double x_scale = (in_img.nc()-1)/(double)std::max<long>((out_img.nc()-1),1);
+ const double y_scale = (in_img.nr()-1)/(double)std::max<long>((out_img.nr()-1),1);
+ double y = -y_scale;
+ for (long r = 0; r < out_img.nr(); ++r)
+ {
+ y += y_scale;
+ const long top = static_cast<long>(std::floor(y));
+ const long bottom = std::min(top+1, in_img.nr()-1);
+ const double tb_frac = y - top;
+ double x = -4*x_scale;
+
+ const simd4f _tb_frac = tb_frac;
+ const simd4f _inv_tb_frac = 1-tb_frac;
+ const simd4f _x_scale = 4*x_scale;
+ simd4f _x(x, x+x_scale, x+2*x_scale, x+3*x_scale);
+ long c = 0;
+ for (;; c+=4)
+ {
+ _x += _x_scale;
+ simd4i left = simd4i(_x);
+
+ simd4f _lr_frac = _x-left;
+ simd4f _inv_lr_frac = 1-_lr_frac;
+ simd4i right = left+1;
+
+ simd4f tlf = _inv_tb_frac*_inv_lr_frac;
+ simd4f trf = _inv_tb_frac*_lr_frac;
+ simd4f blf = _tb_frac*_inv_lr_frac;
+ simd4f brf = _tb_frac*_lr_frac;
+
+ int32 fleft[4];
+ int32 fright[4];
+ left.store(fleft);
+ right.store(fright);
+
+ if (fright[3] >= in_img.nc())
+ break;
+ simd4f tl(in_img[top][fleft[0]], in_img[top][fleft[1]], in_img[top][fleft[2]], in_img[top][fleft[3]]);
+ simd4f tr(in_img[top][fright[0]], in_img[top][fright[1]], in_img[top][fright[2]], in_img[top][fright[3]]);
+ simd4f bl(in_img[bottom][fleft[0]], in_img[bottom][fleft[1]], in_img[bottom][fleft[2]], in_img[bottom][fleft[3]]);
+ simd4f br(in_img[bottom][fright[0]], in_img[bottom][fright[1]], in_img[bottom][fright[2]], in_img[bottom][fright[3]]);
+
+ simd4f out = simd4f(tlf*tl + trf*tr + blf*bl + brf*br);
+ float fout[4];
+ out.store(fout);
+
+ out_img[r][c] = static_cast<T>(fout[0]);
+ out_img[r][c+1] = static_cast<T>(fout[1]);
+ out_img[r][c+2] = static_cast<T>(fout[2]);
+ out_img[r][c+3] = static_cast<T>(fout[3]);
+ }
+ x = -x_scale + c*x_scale;
+ for (; c < out_img.nc(); ++c)
+ {
+ x += x_scale;
+ const long left = static_cast<long>(std::floor(x));
+ const long right = std::min(left+1, in_img.nc()-1);
+ const float lr_frac = x - left;
+
+ float tl = 0, tr = 0, bl = 0, br = 0;
+
+ assign_pixel(tl, in_img[top][left]);
+ assign_pixel(tr, in_img[top][right]);
+ assign_pixel(bl, in_img[bottom][left]);
+ assign_pixel(br, in_img[bottom][right]);
+
+ float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
+ tb_frac*((1-lr_frac)*bl + lr_frac*br);
+
+ assign_pixel(out_img[r][c], temp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ typename enable_if<is_rgb_image<image_type> >::type resize_image (
+ const image_type& in_img_,
+ image_type& out_img_,
+ interpolate_bilinear
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img_, out_img_) == false ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_)
+ );
+
+ const_image_view<image_type> in_img(in_img_);
+ image_view<image_type> out_img(out_img_);
+
+ if (out_img.size() == 0 || in_img.size() == 0)
+ return;
+
+
+ typedef typename image_traits<image_type>::pixel_type T;
+ const double x_scale = (in_img.nc()-1)/(double)std::max<long>((out_img.nc()-1),1);
+ const double y_scale = (in_img.nr()-1)/(double)std::max<long>((out_img.nr()-1),1);
+ double y = -y_scale;
+ for (long r = 0; r < out_img.nr(); ++r)
+ {
+ y += y_scale;
+ const long top = static_cast<long>(std::floor(y));
+ const long bottom = std::min(top+1, in_img.nr()-1);
+ const double tb_frac = y - top;
+ double x = -4*x_scale;
+
+ const simd4f _tb_frac = tb_frac;
+ const simd4f _inv_tb_frac = 1-tb_frac;
+ const simd4f _x_scale = 4*x_scale;
+ simd4f _x(x, x+x_scale, x+2*x_scale, x+3*x_scale);
+ long c = 0;
+ for (;; c+=4)
+ {
+ _x += _x_scale;
+ simd4i left = simd4i(_x);
+ simd4f lr_frac = _x-left;
+ simd4f _inv_lr_frac = 1-lr_frac;
+ simd4i right = left+1;
+
+ simd4f tlf = _inv_tb_frac*_inv_lr_frac;
+ simd4f trf = _inv_tb_frac*lr_frac;
+ simd4f blf = _tb_frac*_inv_lr_frac;
+ simd4f brf = _tb_frac*lr_frac;
+
+ int32 fleft[4];
+ int32 fright[4];
+ left.store(fleft);
+ right.store(fright);
+
+ if (fright[3] >= in_img.nc())
+ break;
+ simd4f tl(in_img[top][fleft[0]].red, in_img[top][fleft[1]].red, in_img[top][fleft[2]].red, in_img[top][fleft[3]].red);
+ simd4f tr(in_img[top][fright[0]].red, in_img[top][fright[1]].red, in_img[top][fright[2]].red, in_img[top][fright[3]].red);
+ simd4f bl(in_img[bottom][fleft[0]].red, in_img[bottom][fleft[1]].red, in_img[bottom][fleft[2]].red, in_img[bottom][fleft[3]].red);
+ simd4f br(in_img[bottom][fright[0]].red, in_img[bottom][fright[1]].red, in_img[bottom][fright[2]].red, in_img[bottom][fright[3]].red);
+
+ simd4i out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br);
+ int32 fout[4];
+ out.store(fout);
+
+ out_img[r][c].red = static_cast<unsigned char>(fout[0]);
+ out_img[r][c+1].red = static_cast<unsigned char>(fout[1]);
+ out_img[r][c+2].red = static_cast<unsigned char>(fout[2]);
+ out_img[r][c+3].red = static_cast<unsigned char>(fout[3]);
+
+
+ tl = simd4f(in_img[top][fleft[0]].green, in_img[top][fleft[1]].green, in_img[top][fleft[2]].green, in_img[top][fleft[3]].green);
+ tr = simd4f(in_img[top][fright[0]].green, in_img[top][fright[1]].green, in_img[top][fright[2]].green, in_img[top][fright[3]].green);
+ bl = simd4f(in_img[bottom][fleft[0]].green, in_img[bottom][fleft[1]].green, in_img[bottom][fleft[2]].green, in_img[bottom][fleft[3]].green);
+ br = simd4f(in_img[bottom][fright[0]].green, in_img[bottom][fright[1]].green, in_img[bottom][fright[2]].green, in_img[bottom][fright[3]].green);
+ out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br);
+ out.store(fout);
+ out_img[r][c].green = static_cast<unsigned char>(fout[0]);
+ out_img[r][c+1].green = static_cast<unsigned char>(fout[1]);
+ out_img[r][c+2].green = static_cast<unsigned char>(fout[2]);
+ out_img[r][c+3].green = static_cast<unsigned char>(fout[3]);
+
+
+ tl = simd4f(in_img[top][fleft[0]].blue, in_img[top][fleft[1]].blue, in_img[top][fleft[2]].blue, in_img[top][fleft[3]].blue);
+ tr = simd4f(in_img[top][fright[0]].blue, in_img[top][fright[1]].blue, in_img[top][fright[2]].blue, in_img[top][fright[3]].blue);
+ bl = simd4f(in_img[bottom][fleft[0]].blue, in_img[bottom][fleft[1]].blue, in_img[bottom][fleft[2]].blue, in_img[bottom][fleft[3]].blue);
+ br = simd4f(in_img[bottom][fright[0]].blue, in_img[bottom][fright[1]].blue, in_img[bottom][fright[2]].blue, in_img[bottom][fright[3]].blue);
+ out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br);
+ out.store(fout);
+ out_img[r][c].blue = static_cast<unsigned char>(fout[0]);
+ out_img[r][c+1].blue = static_cast<unsigned char>(fout[1]);
+ out_img[r][c+2].blue = static_cast<unsigned char>(fout[2]);
+ out_img[r][c+3].blue = static_cast<unsigned char>(fout[3]);
+ }
+ x = -x_scale + c*x_scale;
+ for (; c < out_img.nc(); ++c)
+ {
+ x += x_scale;
+ const long left = static_cast<long>(std::floor(x));
+ const long right = std::min(left+1, in_img.nc()-1);
+ const double lr_frac = x - left;
+
+ const T tl = in_img[top][left];
+ const T tr = in_img[top][right];
+ const T bl = in_img[bottom][left];
+ const T br = in_img[bottom][right];
+
+ T temp;
+ assign_pixel(temp, 0);
+ vector_to_pixel(temp,
+ (1-tb_frac)*((1-lr_frac)*pixel_to_vector<double>(tl) + lr_frac*pixel_to_vector<double>(tr)) +
+ tb_frac*((1-lr_frac)*pixel_to_vector<double>(bl) + lr_frac*pixel_to_vector<double>(br)));
+ assign_pixel(out_img[r][c], temp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void resize_image (
+ const image_type1& in_img,
+ image_type2& out_img
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ resize_image(in_img, out_img, interpolate_bilinear());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void resize_image (
+ double size_scale,
+ image_type& img
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( size_scale > 0 ,
+ "\t void resize_image()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t size_scale: " << size_scale
+ );
+
+ image_type temp;
+ set_image_size(temp, std::round(size_scale*num_rows(img)), std::round(size_scale*num_columns(img)));
+ resize_image(img, temp);
+ swap(img, temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ point_transform_affine flip_image_left_right (
+ const image_type1& in_img,
+ image_type2& out_img
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void flip_image_left_right()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ assign_image(out_img, fliplr(mat(in_img)));
+ std::vector<dlib::vector<double,2> > from, to;
+ rectangle r = get_rect(in_img);
+ from.push_back(r.tl_corner()); to.push_back(r.tr_corner());
+ from.push_back(r.bl_corner()); to.push_back(r.br_corner());
+ from.push_back(r.tr_corner()); to.push_back(r.tl_corner());
+ from.push_back(r.br_corner()); to.push_back(r.bl_corner());
+ return find_affine_transform(from,to);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ point_transform_affine flip_image_left_right (
+ image_type& img
+ )
+ {
+ image_type temp;
+ auto tform = flip_image_left_right(img, temp);
+ swap(temp,img);
+ return tform;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void flip_image_up_down (
+ const image_type1& in_img,
+ image_type2& out_img
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void flip_image_up_down()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ assign_image(out_img, flipud(mat(in_img)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline rectangle flip_rect_left_right (
+ const rectangle& rect,
+ const rectangle& window
+ )
+ {
+ rectangle temp;
+ temp.top() = rect.top();
+ temp.bottom() = rect.bottom();
+
+ const long left_dist = rect.left()-window.left();
+
+ temp.right() = window.right()-left_dist;
+ temp.left() = temp.right()-rect.width()+1;
+ return temp;
+ }
+
+ inline rectangle tform_object (
+ const point_transform_affine& tran,
+ const rectangle& rect
+ )
+ {
+ return centered_rect(tran(center(rect)), rect.width(), rect.height());
+ }
+
+ inline mmod_rect tform_object (
+ const point_transform_affine& tran,
+ mmod_rect rect
+ )
+ {
+ rect.rect = tform_object(tran, rect.rect);
+ return rect;
+ }
+
+ inline full_object_detection tform_object(
+ const point_transform_affine& tran,
+ const full_object_detection& obj
+ )
+ {
+ std::vector<point> parts;
+ parts.reserve(obj.num_parts());
+ for (unsigned long i = 0; i < obj.num_parts(); ++i)
+ {
+ if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
+ parts.push_back(tran(obj.part(i)));
+ else
+ parts.push_back(OBJECT_PART_NOT_PRESENT);
+ }
+ return full_object_detection(tform_object(tran,obj.get_rect()), parts);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename T
+ >
+ void add_image_left_right_flips (
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size(),
+ "\t void add_image_left_right_flips()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+
+ typename image_array_type::value_type temp;
+ std::vector<T> rects;
+
+ const unsigned long num = images.size();
+ for (unsigned long j = 0; j < num; ++j)
+ {
+ const point_transform_affine tran = flip_image_left_right(images[j], temp);
+
+ rects.clear();
+ for (unsigned long i = 0; i < objects[j].size(); ++i)
+ rects.push_back(impl::tform_object(tran, objects[j][i]));
+
+ images.push_back(std::move(temp));
+ objects.push_back(rects);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename T,
+ typename U
+ >
+ void add_image_left_right_flips (
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects,
+ std::vector<std::vector<U> >& objects2
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size() &&
+ images.size() == objects2.size(),
+ "\t void add_image_left_right_flips()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ << "\n\t objects2.size(): " << objects2.size()
+ );
+
+ typename image_array_type::value_type temp;
+ std::vector<T> rects;
+ std::vector<U> rects2;
+
+ const unsigned long num = images.size();
+ for (unsigned long j = 0; j < num; ++j)
+ {
+ const point_transform_affine tran = flip_image_left_right(images[j], temp);
+ images.push_back(std::move(temp));
+
+ rects.clear();
+ for (unsigned long i = 0; i < objects[j].size(); ++i)
+ rects.push_back(impl::tform_object(tran, objects[j][i]));
+ objects.push_back(rects);
+
+ rects2.clear();
+ for (unsigned long i = 0; i < objects2[j].size(); ++i)
+ rects2.push_back(impl::tform_object(tran, objects2[j][i]));
+ objects2.push_back(rects2);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_array_type>
+ void flip_image_dataset_left_right (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size(),
+ "\t void flip_image_dataset_left_right()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+
+ typename image_array_type::value_type temp;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ flip_image_left_right(images[i], temp);
+ swap(temp,images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ objects[i][j] = impl::flip_rect_left_right(objects[i][j], get_rect(images[i]));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_array_type>
+ void flip_image_dataset_left_right (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size() &&
+ images.size() == objects2.size(),
+ "\t void flip_image_dataset_left_right()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ << "\n\t objects2.size(): " << objects2.size()
+ );
+
+ typename image_array_type::value_type temp;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ flip_image_left_right(images[i], temp);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ objects[i][j] = impl::flip_rect_left_right(objects[i][j], get_rect(images[i]));
+ }
+ for (unsigned long j = 0; j < objects2[i].size(); ++j)
+ {
+ objects2[i][j] = impl::flip_rect_left_right(objects2[i][j], get_rect(images[i]));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_array_type
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size(),
+ "\t void upsample_image_dataset()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+
+ typename image_array_type::value_type temp;
+ pyramid_type pyr;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ const unsigned long img_size = num_rows(images[i])*num_columns(images[i]);
+ if (img_size <= max_image_size)
+ {
+ pyramid_up(images[i], temp, pyr);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ objects[i][j] = pyr.rect_up(objects[i][j]);
+ }
+ }
+ }
+ }
+
+ template <
+ typename pyramid_type,
+ typename image_array_type
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<mmod_rect>>& objects,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size(),
+ "\t void upsample_image_dataset()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+
+ typename image_array_type::value_type temp;
+ pyramid_type pyr;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ const unsigned long img_size = num_rows(images[i])*num_columns(images[i]);
+ if (img_size <= max_image_size)
+ {
+ pyramid_up(images[i], temp, pyr);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ objects[i][j].rect = pyr.rect_up(objects[i][j].rect);
+ }
+ }
+ }
+ }
+
+ template <
+ typename pyramid_type,
+ typename image_array_type
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size() &&
+ images.size() == objects2.size(),
+ "\t void upsample_image_dataset()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ << "\n\t objects2.size(): " << objects2.size()
+ );
+
+ typename image_array_type::value_type temp;
+ pyramid_type pyr;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ const unsigned long img_size = num_rows(images[i])*num_columns(images[i]);
+ if (img_size <= max_image_size)
+ {
+ pyramid_up(images[i], temp, pyr);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ objects[i][j] = pyr.rect_up(objects[i][j]);
+ }
+ for (unsigned long j = 0; j < objects2[i].size(); ++j)
+ {
+ objects2[i][j] = pyr.rect_up(objects2[i][j]);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_array_type>
+ void rotate_image_dataset (
+ double angle,
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size(),
+ "\t void rotate_image_dataset()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ );
+
+ typename image_array_type::value_type temp;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ const point_transform_affine tran = rotate_image(images[i], temp, angle);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ const rectangle rect = objects[i][j];
+ objects[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height());
+ }
+ }
+ }
+
+ template <typename image_array_type>
+ void rotate_image_dataset (
+ double angle,
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( images.size() == objects.size() &&
+ images.size() == objects2.size(),
+ "\t void rotate_image_dataset()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ << "\n\t objects2.size(): " << objects2.size()
+ );
+
+ typename image_array_type::value_type temp;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ const point_transform_affine tran = rotate_image(images[i], temp, angle);
+ swap(temp, images[i]);
+ for (unsigned long j = 0; j < objects[i].size(); ++j)
+ {
+ const rectangle rect = objects[i][j];
+ objects[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height());
+ }
+ for (unsigned long j = 0; j < objects2[i].size(); ++j)
+ {
+ const rectangle rect = objects2[i][j];
+ objects2[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height());
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename EXP,
+ typename T,
+ typename U
+ >
+ void add_image_rotations (
+ const matrix_exp<EXP>& angles,
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects,
+ std::vector<std::vector<U> >& objects2
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_vector(angles) && angles.size() > 0 &&
+ images.size() == objects.size() &&
+ images.size() == objects2.size(),
+ "\t void add_image_rotations()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_vector(angles): " << is_vector(angles)
+ << "\n\t angles.size(): " << angles.size()
+ << "\n\t images.size(): " << images.size()
+ << "\n\t objects.size(): " << objects.size()
+ << "\n\t objects2.size(): " << objects2.size()
+ );
+
+ image_array_type new_images;
+ std::vector<std::vector<T> > new_objects;
+ std::vector<std::vector<U> > new_objects2;
+
+ using namespace impl;
+
+ std::vector<T> objtemp;
+ std::vector<U> objtemp2;
+ typename image_array_type::value_type temp;
+ for (long i = 0; i < angles.size(); ++i)
+ {
+ for (unsigned long j = 0; j < images.size(); ++j)
+ {
+ const point_transform_affine tran = rotate_image(images[j], temp, angles(i));
+ new_images.push_back(std::move(temp));
+
+ objtemp.clear();
+ for (unsigned long k = 0; k < objects[j].size(); ++k)
+ objtemp.push_back(tform_object(tran, objects[j][k]));
+ new_objects.push_back(objtemp);
+
+ objtemp2.clear();
+ for (unsigned long k = 0; k < objects2[j].size(); ++k)
+ objtemp2.push_back(tform_object(tran, objects2[j][k]));
+ new_objects2.push_back(objtemp2);
+ }
+ }
+
+ new_images.swap(images);
+ new_objects.swap(objects);
+ new_objects2.swap(objects2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename EXP,
+ typename T
+ >
+ void add_image_rotations (
+ const matrix_exp<EXP>& angles,
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects
+ )
+ {
+ std::vector<std::vector<T> > objects2(objects.size());
+ add_image_rotations(angles, images, objects, objects2);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename pyramid_type,
+ typename interpolation_type
+ >
+ void pyramid_up (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const pyramid_type& pyr,
+ const interpolation_type& interp
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void pyramid_up()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ if (image_size(in_img) == 0)
+ {
+ set_image_size(out_img, 0, 0);
+ return;
+ }
+
+ rectangle rect = get_rect(in_img);
+ rectangle uprect = pyr.rect_up(rect);
+ if (uprect.is_empty())
+ {
+ set_image_size(out_img, 0, 0);
+ return;
+ }
+ set_image_size(out_img, uprect.bottom()+1, uprect.right()+1);
+
+ resize_image(in_img, out_img, interp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename pyramid_type
+ >
+ void pyramid_up (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const pyramid_type& pyr
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_same_object(in_img, out_img) == false ,
+ "\t void pyramid_up()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img)
+ );
+
+ pyramid_up(in_img, out_img, pyr, interpolate_bilinear());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pyramid_type
+ >
+ void pyramid_up (
+ image_type& img,
+ const pyramid_type& pyr
+ )
+ {
+ image_type temp;
+ pyramid_up(img, temp, pyr);
+ swap(temp, img);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void pyramid_up (
+ image_type& img
+ )
+ {
+ pyramid_down<2> pyr;
+ pyramid_up(img, pyr);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct chip_dims
+ {
+ chip_dims (
+ unsigned long rows_,
+ unsigned long cols_
+ ) : rows(rows_), cols(cols_) { }
+
+ unsigned long rows;
+ unsigned long cols;
+ };
+
+ struct chip_details
+ {
+ chip_details() : angle(0), rows(0), cols(0) {}
+ chip_details(const rectangle& rect_) : rect(rect_),angle(0), rows(rect_.height()), cols(rect_.width()) {}
+ chip_details(const drectangle& rect_) : rect(rect_),angle(0),
+ rows((unsigned long)(rect_.height()+0.5)), cols((unsigned long)(rect_.width()+0.5)) {}
+ chip_details(const drectangle& rect_, unsigned long size) : rect(rect_),angle(0)
+ { compute_dims_from_size(size); }
+ chip_details(const drectangle& rect_, unsigned long size, double angle_) : rect(rect_),angle(angle_)
+ { compute_dims_from_size(size); }
+
+ chip_details(const drectangle& rect_, const chip_dims& dims) :
+ rect(rect_),angle(0),rows(dims.rows), cols(dims.cols) {}
+ chip_details(const drectangle& rect_, const chip_dims& dims, double angle_) :
+ rect(rect_),angle(angle_),rows(dims.rows), cols(dims.cols) {}
+
+ template <typename T>
+ chip_details(
+ const std::vector<dlib::vector<T,2> >& chip_points,
+ const std::vector<dlib::vector<T,2> >& img_points,
+ const chip_dims& dims
+ ) :
+ rows(dims.rows), cols(dims.cols)
+ {
+ DLIB_CASSERT( chip_points.size() == img_points.size() && chip_points.size() >= 2,
+ "\t chip_details::chip_details(chip_points,img_points,dims)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t chip_points.size(): " << chip_points.size()
+ << "\n\t img_points.size(): " << img_points.size()
+ );
+
+ const point_transform_affine tform = find_similarity_transform(chip_points,img_points);
+ dlib::vector<double,2> p(1,0);
+ p = tform.get_m()*p;
+
+ // There are only 3 things happening in a similarity transform. There is a
+ // rescaling, a rotation, and a translation. So here we pick out the scale and
+ // rotation parameters.
+ angle = std::atan2(p.y(),p.x());
+ // Note that the translation and scale part are represented by the extraction
+ // rectangle. So here we build the appropriate rectangle.
+ const double scale = length(p);
+ rect = centered_drect(tform(point(dims.cols,dims.rows)/2.0),
+ dims.cols*scale,
+ dims.rows*scale);
+ }
+
+
+ drectangle rect;
+ double angle;
+ unsigned long rows;
+ unsigned long cols;
+
+ inline unsigned long size() const
+ {
+ return rows*cols;
+ }
+
+ private:
+ void compute_dims_from_size (
+ unsigned long size
+ )
+ {
+ const double relative_size = std::sqrt(size/(double)rect.area());
+ rows = static_cast<unsigned long>(rect.height()*relative_size + 0.5);
+ cols = static_cast<unsigned long>(size/(double)rows + 0.5);
+ rows = std::max(1ul,rows);
+ cols = std::max(1ul,cols);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline point_transform_affine get_mapping_to_chip (
+ const chip_details& details
+ )
+ {
+ std::vector<dlib::vector<double,2> > from, to;
+ point p1(0,0);
+ point p2(details.cols-1,0);
+ point p3(details.cols-1, details.rows-1);
+ to.push_back(p1);
+ from.push_back(rotate_point<double>(center(details.rect),details.rect.tl_corner(),details.angle));
+ to.push_back(p2);
+ from.push_back(rotate_point<double>(center(details.rect),details.rect.tr_corner(),details.angle));
+ to.push_back(p3);
+ from.push_back(rotate_point<double>(center(details.rect),details.rect.br_corner(),details.angle));
+ return find_affine_transform(from, to);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline full_object_detection map_det_to_chip(
+ const full_object_detection& det,
+ const chip_details& details
+ )
+ {
+ point_transform_affine tform = get_mapping_to_chip(details);
+ full_object_detection res(det);
+ // map the parts
+ for (unsigned long l = 0; l < det.num_parts(); ++l)
+ {
+ if (det.part(l) != OBJECT_PART_NOT_PRESENT)
+ res.part(l) = tform(det.part(l));
+ else
+ res.part(l) = OBJECT_PART_NOT_PRESENT;
+ }
+ // map the main rectangle
+ rectangle rect;
+ rect += tform(det.get_rect().tl_corner());
+ rect += tform(det.get_rect().tr_corner());
+ rect += tform(det.get_rect().bl_corner());
+ rect += tform(det.get_rect().br_corner());
+ res.get_rect() = rect;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void basic_extract_image_chip (
+ const image_type1& img,
+ const rectangle& location,
+ image_type2& chip
+ )
+ /*!
+ ensures
+ - This function doesn't do any scaling or rotating. It just pulls out the
+ chip in the given rectangle. This also means the output image has the
+ same dimensions as the location rectangle.
+ !*/
+ {
+ const_image_view<image_type1> vimg(img);
+ image_view<image_type2> vchip(chip);
+
+ vchip.set_size(location.height(), location.width());
+
+ // location might go outside img so clip it
+ rectangle area = location.intersect(get_rect(img));
+
+ // find the part of the chip that corresponds to area in img.
+ rectangle chip_area = translate_rect(area, -location.tl_corner());
+
+ zero_border_pixels(chip, chip_area);
+ // now pull out the contents of area/chip_area.
+ for (long r = chip_area.top(), rr = area.top(); r <= chip_area.bottom(); ++r,++rr)
+ {
+ for (long c = chip_area.left(), cc = area.left(); c <= chip_area.right(); ++c,++cc)
+ {
+ assign_pixel(vchip[r][c], vimg[rr][cc]);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void extract_image_chips (
+ const image_type1& img,
+ const std::vector<chip_details>& chip_locations,
+ dlib::array<image_type2>& chips,
+ const interpolation_type& interp
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < chip_locations.size(); ++i)
+ {
+ DLIB_CASSERT(chip_locations[i].size() != 0 &&
+ chip_locations[i].rect.is_empty() == false,
+ "\t void extract_image_chips()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t chip_locations["<<i<<"].size(): " << chip_locations[i].size()
+ << "\n\t chip_locations["<<i<<"].rect.is_empty(): " << chip_locations[i].rect.is_empty()
+ );
+ }
+#endif
+
+ pyramid_down<2> pyr;
+ long max_depth = 0;
+ // If the chip is supposed to be much smaller than the source subwindow then you
+ // can't just extract it using bilinear interpolation since at a high enough
+ // downsampling amount it would effectively turn into nearest neighbor
+ // interpolation. So we use an image pyramid to make sure the interpolation is
+ // fast but also high quality. The first thing we do is figure out how deep the
+ // image pyramid needs to be.
+ rectangle bounding_box;
+ for (unsigned long i = 0; i < chip_locations.size(); ++i)
+ {
+ long depth = 0;
+ double grow = 2;
+ drectangle rect = pyr.rect_down(chip_locations[i].rect);
+ while (rect.area() > chip_locations[i].size())
+ {
+ rect = pyr.rect_down(rect);
+ ++depth;
+ // We drop the image size by a factor of 2 each iteration and then assume a
+ // border of 2 pixels is needed to avoid any border effects of the crop.
+ grow = grow*2 + 2;
+ }
+ drectangle rot_rect;
+ const vector<double,2> cent = center(chip_locations[i].rect);
+ rot_rect += rotate_point<double>(cent,chip_locations[i].rect.tl_corner(),chip_locations[i].angle);
+ rot_rect += rotate_point<double>(cent,chip_locations[i].rect.tr_corner(),chip_locations[i].angle);
+ rot_rect += rotate_point<double>(cent,chip_locations[i].rect.bl_corner(),chip_locations[i].angle);
+ rot_rect += rotate_point<double>(cent,chip_locations[i].rect.br_corner(),chip_locations[i].angle);
+ bounding_box += grow_rect(rot_rect, grow).intersect(get_rect(img));
+ max_depth = std::max(depth,max_depth);
+ }
+ //std::cout << "max_depth: " << max_depth << std::endl;
+ //std::cout << "crop amount: " << bounding_box.area()/(double)get_rect(img).area() << std::endl;
+
+ // now make an image pyramid
+ dlib::array<array2d<typename image_traits<image_type1>::pixel_type> > levels(max_depth);
+ if (levels.size() != 0)
+ pyr(sub_image(img,bounding_box),levels[0]);
+ for (unsigned long i = 1; i < levels.size(); ++i)
+ pyr(levels[i-1],levels[i]);
+
+ std::vector<dlib::vector<double,2> > from, to;
+
+ // now pull out the chips
+ chips.resize(chip_locations.size());
+ for (unsigned long i = 0; i < chips.size(); ++i)
+ {
+ // If the chip doesn't have any rotation or scaling then use the basic version
+ // of chip extraction that just does a fast copy.
+ if (chip_locations[i].angle == 0 &&
+ chip_locations[i].rows == chip_locations[i].rect.height() &&
+ chip_locations[i].cols == chip_locations[i].rect.width())
+ {
+ impl::basic_extract_image_chip(img, chip_locations[i].rect, chips[i]);
+ }
+ else
+ {
+ set_image_size(chips[i], chip_locations[i].rows, chip_locations[i].cols);
+
+ // figure out which level in the pyramid to use to extract the chip
+ int level = -1;
+ drectangle rect = translate_rect(chip_locations[i].rect, -bounding_box.tl_corner());
+ while (pyr.rect_down(rect).area() > chip_locations[i].size())
+ {
+ ++level;
+ rect = pyr.rect_down(rect);
+ }
+
+ // find the appropriate transformation that maps from the chip to the input
+ // image
+ from.clear();
+ to.clear();
+ from.push_back(get_rect(chips[i]).tl_corner()); to.push_back(rotate_point<double>(center(rect),rect.tl_corner(),chip_locations[i].angle));
+ from.push_back(get_rect(chips[i]).tr_corner()); to.push_back(rotate_point<double>(center(rect),rect.tr_corner(),chip_locations[i].angle));
+ from.push_back(get_rect(chips[i]).bl_corner()); to.push_back(rotate_point<double>(center(rect),rect.bl_corner(),chip_locations[i].angle));
+ point_transform_affine trns = find_affine_transform(from,to);
+
+ // now extract the actual chip
+ if (level == -1)
+ transform_image(sub_image(img,bounding_box),chips[i],interp,trns);
+ else
+ transform_image(levels[level],chips[i],interp,trns);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void extract_image_chips(
+ const image_type1& img,
+ const std::vector<chip_details>& chip_locations,
+ dlib::array<image_type2>& chips
+ )
+ {
+ extract_image_chips(img, chip_locations, chips, interpolate_bilinear());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void extract_image_chip (
+ const image_type1& img,
+ const chip_details& location,
+ image_type2& chip,
+ const interpolation_type& interp
+ )
+ {
+ // If the chip doesn't have any rotation or scaling then use the basic version of
+ // chip extraction that just does a fast copy.
+ if (location.angle == 0 &&
+ location.rows == location.rect.height() &&
+ location.cols == location.rect.width())
+ {
+ impl::basic_extract_image_chip(img, location.rect, chip);
+ }
+ else
+ {
+ std::vector<chip_details> chip_locations(1,location);
+ dlib::array<image_type2> chips;
+ extract_image_chips(img, chip_locations, chips, interp);
+ swap(chips[0], chip);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void extract_image_chip (
+ const image_type1& img,
+ const chip_details& location,
+ image_type2& chip
+ )
+ {
+ extract_image_chip(img, location, chip, interpolate_bilinear());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline chip_details get_face_chip_details (
+ const full_object_detection& det,
+ const unsigned long size = 200,
+ const double padding = 0.2
+ )
+ {
+ DLIB_CASSERT(det.num_parts() == 68 || det.num_parts() == 5,
+ "\t chip_details get_face_chip_details()"
+ << "\n\t You have to give either a 5 point or 68 point face landmarking output to this function. "
+ << "\n\t det.num_parts(): " << det.num_parts()
+ );
+ DLIB_CASSERT(padding >= 0 && size > 0,
+ "\t chip_details get_face_chip_details()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t padding: " << padding
+ << "\n\t size: " << size
+ );
+
+
+ std::vector<dpoint> from_points, to_points;
+ if (det.num_parts() == 5)
+ {
+ dpoint p0(0.8595674595992, 0.2134981538014);
+ dpoint p1(0.6460604764104, 0.2289674387677);
+ dpoint p2(0.1205750620789, 0.2137274526848);
+ dpoint p3(0.3340850613712, 0.2290642403242);
+ dpoint p4(0.4901123135679, 0.6277975316475);
+
+
+ p0 = (padding+p0)/(2*padding+1);
+ p1 = (padding+p1)/(2*padding+1);
+ p2 = (padding+p2)/(2*padding+1);
+ p3 = (padding+p3)/(2*padding+1);
+ p4 = (padding+p4)/(2*padding+1);
+
+ from_points.push_back(p0*size);
+ to_points.push_back(det.part(0));
+
+ from_points.push_back(p1*size);
+ to_points.push_back(det.part(1));
+
+ from_points.push_back(p2*size);
+ to_points.push_back(det.part(2));
+
+ from_points.push_back(p3*size);
+ to_points.push_back(det.part(3));
+
+ from_points.push_back(p4*size);
+ to_points.push_back(det.part(4));
+ }
+ else
+ {
+ // Average positions of face points 17-67
+ const double mean_face_shape_x[] = {
+ 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, 0.799124,
+ 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, 0.36688, 0.426036,
+ 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, 0.265825, 0.334606, 0.260918,
+ 0.182743, 0.645647, 0.714428, 0.793132, 0.858516, 0.79751, 0.719335, 0.254149,
+ 0.340985, 0.428858, 0.490127, 0.551395, 0.639268, 0.726104, 0.642159, 0.556721,
+ 0.490127, 0.423532, 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874,
+ 0.553364, 0.490127, 0.42689
+ };
+ const double mean_face_shape_y[] = {
+ 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891,
+ 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, 0.587326,
+ 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, 0.179852, 0.231733,
+ 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, 0.216423, 0.244077, 0.245099,
+ 0.780233, 0.745405, 0.727388, 0.742578, 0.727388, 0.745405, 0.780233, 0.864805,
+ 0.902192, 0.909281, 0.902192, 0.864805, 0.784792, 0.778746, 0.785343, 0.778746,
+ 0.784792, 0.824182, 0.831803, 0.824182
+ };
+
+ COMPILE_TIME_ASSERT(sizeof(mean_face_shape_x)/sizeof(double) == 68-17);
+
+ for (unsigned long i = 17; i < det.num_parts(); ++i)
+ {
+ // Ignore the lower lip
+ if ((55 <= i && i <= 59) || (65 <= i && i <= 67))
+ continue;
+ // Ignore the eyebrows
+ if (17 <= i && i <= 26)
+ continue;
+
+ dpoint p;
+ p.x() = (padding+mean_face_shape_x[i-17])/(2*padding+1);
+ p.y() = (padding+mean_face_shape_y[i-17])/(2*padding+1);
+ from_points.push_back(p*size);
+ to_points.push_back(det.part(i));
+ }
+ }
+
+ return chip_details(from_points, to_points, chip_dims(size,size));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::vector<chip_details> get_face_chip_details (
+ const std::vector<full_object_detection>& dets,
+ const unsigned long size = 200,
+ const double padding = 0.2
+ )
+ {
+ std::vector<chip_details> res;
+ res.reserve(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ res.push_back(get_face_chip_details(dets[i], size, padding));
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ image_type jitter_image(
+ const image_type& img,
+ dlib::rand& rnd
+ )
+ {
+ DLIB_CASSERT(num_rows(img)*num_columns(img) != 0);
+ DLIB_CASSERT(num_rows(img)==num_columns(img));
+
+ const double max_rotation_degrees = 3;
+ const double min_object_height = 0.97;
+ const double max_object_height = 0.99999;
+ const double translate_amount = 0.02;
+
+
+ const auto rect = shrink_rect(get_rect(img),3);
+
+ // perturb the location of the crop by a small fraction of the object's size.
+ const point rand_translate = dpoint(rnd.get_double_in_range(-translate_amount,translate_amount)*rect.width(),
+ rnd.get_double_in_range(-translate_amount,translate_amount)*rect.height());
+
+ // perturb the scale of the crop by a fraction of the object's size
+ const double rand_scale_perturb = rnd.get_double_in_range(min_object_height, max_object_height);
+
+ const long box_size = rect.height()/rand_scale_perturb;
+ const auto crop_rect = centered_rect(center(rect)+rand_translate, box_size, box_size);
+ const double angle = rnd.get_double_in_range(-max_rotation_degrees, max_rotation_degrees)*pi/180;
+ image_type crop;
+ extract_image_chip(img, chip_details(crop_rect, chip_dims(img.nr(),img.nc()), angle), crop);
+ if (rnd.get_random_double() > 0.5)
+ flip_image_left_right(crop);
+
+ return crop;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_INTERPOlATIONh_
+
diff --git a/ml/dlib/dlib/image_transforms/interpolation_abstract.h b/ml/dlib/dlib/image_transforms/interpolation_abstract.h
new file mode 100644
index 000000000..f2da2fb02
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/interpolation_abstract.h
@@ -0,0 +1,1480 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_INTERPOlATION_ABSTRACT_
+#ifdef DLIB_INTERPOlATION_ABSTRACT_
+
+#include "../pixel.h"
+#include "../image_processing/full_object_detection_abstract.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_nearest_neighbor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing nearest neighbor interpolation
+ on an image.
+ !*/
+
+ public:
+
+ template <
+ typename image_view_type,
+ typename pixel_type
+ >
+ bool operator() (
+ const image_view_type& img,
+ const dlib::point& p,
+ pixel_type& result
+ ) const;
+ /*!
+ requires
+ - image_view_type == an image_view or const_image_view object.
+ - pixel_traits<typename image_view_type::pixel_type>::has_alpha == false
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - if (p is located inside img) then
+ - #result == img[p.y()][p.x()]
+ (This assignment is done using assign_pixel(#result, img[p.y()][p.x()]),
+ therefore any necessary color space conversion will be performed)
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_bilinear
+ {
+
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing bilinear interpolation
+ on an image. This is performed by looking at the 4 pixels
+ nearest to a point and deriving an interpolated value from them.
+ !*/
+
+ public:
+
+ template <
+ typename T,
+ typename image_view_type,
+ typename pixel_type
+ >
+ bool operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const;
+ /*!
+ requires
+ - image_view_type == an image_view or const_image_view object
+ - pixel_traits<typename image_view_type::pixel_type>::has_alpha == false
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - if (there is an interpolatable image location at point p in img) then
+ - #result == the interpolated pixel value from img at point p.
+ - assign_pixel() will be used to write to #result, therefore any
+ necessary color space conversion will be performed.
+ - returns true
+ - if img contains RGB pixels then the interpolation will be in color.
+ Otherwise, the interpolation will be performed in a grayscale mode.
+ - else
+ - returns false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class interpolate_quadratic
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing quadratic interpolation
+ on an image. This is performed by looking at the 9 pixels
+ nearest to a point and deriving an interpolated value from them.
+ !*/
+
+ public:
+
+ template <
+ typename T,
+ typename image_view_type,
+ typename pixel_type
+ >
+ bool operator() (
+ const image_view_type& img,
+ const dlib::vector<T,2>& p,
+ pixel_type& result
+ ) const;
+ /*!
+ requires
+ - image_view_type == an image_view or const_image_view object.
+ - pixel_traits<typename image_view_type::pixel_type>::has_alpha == false
+ - pixel_traits<pixel_type> is defined
+ ensures
+ - if (there is an interpolatable image location at point p in img) then
+ - #result == the interpolated pixel value from img at point p
+ - assign_pixel() will be used to write to #result, therefore any
+ necessary color space conversion will be performed.
+ - returns true
+ - if img contains RGB pixels then the interpolation will be in color.
+ Otherwise, the interpolation will be performed in a grayscale mode.
+ - else
+ - returns false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class black_background
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object which simply sets a pixel
+ to have a black value.
+ !*/
+
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& p) const { assign_pixel(p, 0); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class white_background
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object which simply sets a pixel
+ to have a white value.
+ !*/
+
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& p) const { assign_pixel(p, 255); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class no_background
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object which does nothing. It is useful
+ when used with the transform_image() routine defined below
+ if no modification of uninterpolated output pixels is desired.
+ !*/
+ public:
+ template <typename pixel_type>
+ void operator() ( pixel_type& ) const { }
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type,
+ typename background_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point,
+ const background_type& set_background,
+ const rectangle& area
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - map_point should be a function which takes dlib::vector<T,2> objects and
+ returns dlib::vector<T,2> objects. An example is point_transform_affine.
+ - set_background should be a function which can take a single argument of
+ type image_traits<image_type2>::pixel_type. Examples are black_background,
+ white_background, and no_background.
+ - get_rect(out_img).contains(area) == true
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - The map_point function defines a mapping from pixels in out_img to pixels
+ in in_img. transform_image() uses this mapping, along with the supplied
+ interpolation routine interp, to fill the region of out_img defined by
+ area with an interpolated copy of in_img.
+ - This function does not change the size of out_img.
+ - Only pixels inside the region defined by area in out_img are modified.
+ - For all locations r and c such that area.contains(c,r) but have no corresponding
+ locations in in_img:
+ - set_background(out_img[r][c]) is invoked
+ (i.e. some parts of out_img might correspond to areas outside in_img and
+ therefore can't supply interpolated values. In these cases, these
+ pixels can be assigned a value by the supplied set_background() routine)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type,
+ typename background_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point,
+ const background_type& set_background
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - map_point should be a function which takes dlib::vector<T,2> objects and
+ returns dlib::vector<T,2> objects. An example is point_transform_affine.
+ - set_background should be a function which can take a single argument of
+ type image_traits<image_type2>::pixel_type. Examples are black_background, white_background,
+ and no_background.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - performs:
+ transform_image(in_img, out_img, interp, map_point, set_background, get_rect(out_img));
+ (i.e. runs transform_image() on the entire out_img)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type,
+ typename point_mapping_type
+ >
+ void transform_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp,
+ const point_mapping_type& map_point
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - map_point should be a function which takes dlib::vector<T,2> objects and
+ returns dlib::vector<T,2> objects. An example is point_transform_affine.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - performs:
+ transform_image(in_img, out_img, interp, map_point, black_background(), get_rect(out_img));
+ (i.e. runs transform_image() on the entire out_img and sets non-interpolated
+ pixels to black)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ point_transform_affine rotate_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ double angle,
+ const interpolation_type& interp
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img == a copy of in_img which has been rotated angle radians counter clockwise.
+ The rotation is performed with respect to the center of the image.
+ - Parts of #out_img which have no corresponding locations in in_img are set to black.
+ - uses the supplied interpolation routine interp to perform the necessary
+ pixel interpolation.
+ - returns a transformation object that maps points in in_img into their corresponding
+ location in #out_img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ point_transform_affine rotate_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ double angle
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type1>::pixel_type>::has_alpha == false
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img == a copy of in_img which has been rotated angle radians counter clockwise.
+ The rotation is performed with respect to the center of the image.
+ - Parts of #out_img which have no corresponding locations in in_img are set to black.
+ - uses the interpolate_quadratic object to perform the necessary pixel interpolation.
+ - returns a transformation object that maps points in in_img into their corresponding
+ location in #out_img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void resize_image (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const interpolation_type& interp
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img == A copy of in_img which has been stretched so that it
+ fits exactly into out_img.
+ - The size of out_img is not modified. I.e.
+ - #out_img.nr() == out_img.nr()
+ - #out_img.nc() == out_img.nc()
+ - uses the supplied interpolation routine interp to perform the necessary
+ pixel interpolation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void resize_image (
+ const image_type1& in_img,
+ image_type2& out_img
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type1>::pixel_type>::has_alpha == false
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img == A copy of in_img which has been stretched so that it
+ fits exactly into out_img.
+ - The size of out_img is not modified. I.e.
+ - #out_img.nr() == out_img.nr()
+ - #out_img.nc() == out_img.nc()
+ - Uses the bilinear interpolation to perform the necessary pixel interpolation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void resize_image (
+ double size_scale,
+ image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ ensures
+ - Resizes img so that each of it's dimensions are size_scale times larger than img.
+ In particular, we will have:
+ - #img.nr() == std::round(size_scale*img.nr())
+ - #img.nc() == std::round(size_scale*img.nc())
+ - #img == a bilinearly interpolated copy of the input image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ point_transform_affine flip_image_left_right (
+ const image_type1& in_img,
+ image_type2& out_img
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img.nr() == in_img.nr()
+ - #out_img.nc() == in_img.nc()
+ - #out_img == a copy of in_img which has been flipped from left to right.
+ (i.e. it is flipped as if viewed though a mirror)
+ - returns a transformation object that maps points in in_img into their
+ corresponding location in #out_img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ point_transform_affine flip_image_left_right (
+ image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - This function is identical to the above version of flip_image_left_right()
+ except that it operates in-place.
+ - #img.nr() == img.nr()
+ - #img.nc() == img.nc()
+ - #img == a copy of img which has been flipped from left to right.
+ (i.e. it is flipped as if viewed though a mirror)
+ - returns a transformation object that maps points in img into their
+ corresponding location in #img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename T
+ >
+ void add_image_left_right_flips (
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - T == rectangle, full_object_detection, or mmod_rect
+ - images.size() == objects.size()
+ ensures
+ - This function computes all the left/right flips of the contents of images and
+ then appends them onto the end of the images array. It also finds the
+ left/right flips of the rectangles in objects and similarly appends them into
+ objects. That is, we assume objects[i] is the set of bounding boxes in
+ images[i] and we flip the bounding boxes so that they still bound the same
+ objects in the new flipped images.
+ - #images.size() == images.size()*2
+ - #objects.size() == objects.size()*2
+ - All the original elements of images and objects are left unmodified. That
+ is, this function only appends new elements to each of these containers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename T,
+ typename U
+ >
+ void add_image_left_right_flips (
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects,
+ std::vector<std::vector<U> >& objects2
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - images.size() == objects2.size()
+ - T == rectangle, full_object_detection, or mmod_rect
+ - U == rectangle, full_object_detection, or mmod_rect
+ ensures
+ - This function computes all the left/right flips of the contents of images and
+ then appends them onto the end of the images array. It also finds the
+ left/right flips of the rectangles in objects and objects2 and similarly
+ appends them into objects and objects2 respectively. That is, we assume
+ objects[i] is the set of bounding boxes in images[i] and we flip the bounding
+ boxes so that they still bound the same objects in the new flipped images.
+ We similarly flip the boxes in objects2.
+ - #images.size() == images.size()*2
+ - #objects.size() == objects.size()*2
+ - #objects2.size() == objects2.size()*2
+ - All the original elements of images, objects, and objects2 are left unmodified.
+ That is, this function only appends new elements to each of these containers.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename EXP,
+ typename T,
+ typename U
+ >
+ void add_image_rotations (
+ const matrix_exp<EXP>& angles,
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects,
+ std::vector<std::vector<U> >& objects2
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - is_vector(angles) == true
+ - angles.size() > 0
+ - images.size() == objects.size()
+ - images.size() == objects2.size()
+ - T == rectangle, full_object_detection, or mmod_rect
+ - U == rectangle, full_object_detection, or mmod_rect
+ ensures
+ - This function computes angles.size() different rotations of all the given
+ images and then replaces the contents of images with those rotations of the
+ input dataset. We will also adjust the rectangles inside objects and
+ objects2 so that they still bound the same objects in the new rotated images.
+ That is, we assume objects[i] and objects2[i] are bounding boxes for things
+ in images[i]. So we will adjust the positions of the boxes in objects and
+ objects2 accordingly.
+ - The elements of angles are interpreted as angles in radians and we will
+ rotate the images around their center using the values in angles. Moreover,
+ the rotation is done counter clockwise.
+ - #images.size() == images.size()*angles.size()
+ - #objects.size() == objects.size()*angles.size()
+ - #objects2.size() == objects2.size()*angles.size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename EXP,
+ typename T
+ >
+ void add_image_rotations (
+ const matrix_exp<EXP>& angles,
+ image_array_type& images,
+ std::vector<std::vector<T> >& objects
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - is_vector(angles) == true
+ - angles.size() > 0
+ - images.size() == objects.size()
+ - T == rectangle, full_object_detection, or mmod_rect
+ ensures
+ - This function is identical to the add_image_rotations() define above except
+ that it doesn't have objects2 as an argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void flip_image_dataset_left_right (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ ensures
+ - This function replaces each image in images with the left/right flipped
+ version of the image. Therefore, #images[i] will contain the left/right
+ flipped version of images[i]. It also flips all the rectangles in objects so
+ that they still bound the same visual objects in each image.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void flip_image_dataset_left_right (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - images.size() == objects2.size()
+ ensures
+ - This function replaces each image in images with the left/right flipped
+ version of the image. Therefore, #images[i] will contain the left/right
+ flipped version of images[i]. It also flips all the rectangles in objects
+ and objects2 so that they still bound the same visual objects in each image.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - #objects2.size() == objects2.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ - for all valid i:
+ #objects2[i].size() == objects2[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_array_type
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ ensures
+ - This function replaces each image in images with an upsampled version of that
+ image. Each image is upsampled using pyramid_up() and the given
+ pyramid_type. Therefore, #images[i] will contain the larger upsampled
+ version of images[i]. It also adjusts all the rectangles in objects so that
+ they still bound the same visual objects in each image.
+ - Input images already containing more than max_image_size pixels are not upsampled.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_array_type
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<mmod_rect>>& objects,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ ensures
+ - This function replaces each image in images with an upsampled version of that
+ image. Each image is upsampled using pyramid_up() and the given
+ pyramid_type. Therefore, #images[i] will contain the larger upsampled
+ version of images[i]. It also adjusts all the rectangles in objects so that
+ they still bound the same visual objects in each image.
+ - Input images already containing more than max_image_size pixels are not upsampled.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pyramid_type,
+ typename image_array_type,
+ >
+ void upsample_image_dataset (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2,
+ unsigned long max_image_size = std::numeric_limits<unsigned long>::max()
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - images.size() == objects2.size()
+ ensures
+ - This function replaces each image in images with an upsampled version of that
+ image. Each image is upsampled using pyramid_up() and the given
+ pyramid_type. Therefore, #images[i] will contain the larger upsampled
+ version of images[i]. It also adjusts all the rectangles in objects and
+ objects2 so that they still bound the same visual objects in each image.
+ - Input images already containing more than max_image_size pixels are not upsampled.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - #objects2.size() == objects2.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ - for all valid i:
+ #objects2[i].size() == objects2[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_array_type>
+ void rotate_image_dataset (
+ double angle,
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ ensures
+ - This function replaces each image in images with a rotated version of that
+ image. In particular, each image is rotated using
+ rotate_image(original,rotated,angle). Therefore, the images are rotated
+ angle radians counter clockwise around their centers. That is, #images[i]
+ will contain the rotated version of images[i]. It also adjusts all
+ the rectangles in objects so that they still bound the same visual objects in
+ each image.
+ - All the rectangles will still have the same sizes and aspect ratios after
+ rotation. They will simply have had their positions adjusted so they still
+ fall on the same objects.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_array_type>
+ void rotate_image_dataset (
+ double angle,
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& objects,
+ std::vector<std::vector<rectangle> >& objects2
+ );
+ /*!
+ requires
+ - image_array_type == a dlib::array or std::vector of image objects that each
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - images.size() == objects.size()
+ - images.size() == objects2.size()
+ ensures
+ - This function replaces each image in images with a rotated version of that
+ image. In particular, each image is rotated using
+ rotate_image(original,rotated,angle). Therefore, the images are rotated
+ angle radians counter clockwise around their centers. That is, #images[i]
+ will contain the rotated version of images[i]. It also adjusts all
+ the rectangles in objects and objects2 so that they still bound the same
+ visual objects in each image.
+ - All the rectangles will still have the same sizes and aspect ratios after
+ rotation. They will simply have had their positions adjusted so they still
+ fall on the same objects.
+ - #images.size() == image.size()
+ - #objects.size() == objects.size()
+ - #objects2.size() == objects2.size()
+ - for all valid i:
+ #objects[i].size() == objects[i].size()
+ - for all valid i:
+ #objects2[i].size() == objects2[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void flip_image_up_down (
+ const image_type1& in_img,
+ image_type2& out_img
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img.nr() == in_img.nr()
+ - #out_img.nc() == in_img.nc()
+ - #out_img == a copy of in_img which has been flipped upside down.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename pyramid_type,
+ typename interpolation_type
+ >
+ void pyramid_up (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const pyramid_type& pyr,
+ const interpolation_type& interp
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pyramid_type == a type compatible with the image pyramid objects defined
+ in dlib/image_transforms/image_pyramid_abstract.h
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - This function inverts the downsampling transformation performed by pyr().
+ In particular, it attempts to make an image, out_img, which would result
+ in in_img when downsampled with pyr().
+ - #out_img == An upsampled copy of in_img. In particular, downsampling
+ #out_img 1 time with pyr() should result in a final image which looks like
+ in_img.
+ - Uses the supplied interpolation routine interp to perform the necessary
+ pixel interpolation.
+ - Note that downsampling an image with pyr() and then upsampling it with
+ pyramid_up() will not necessarily result in a final image which is
+ the same size as the original. This is because the exact size of the
+ original image cannot be determined based on the downsampled image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename pyramid_type
+ >
+ void pyramid_up (
+ const image_type1& in_img,
+ image_type2& out_img,
+ const pyramid_type& pyr
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pyramid_type == a type compatible with the image pyramid objects defined
+ in dlib/image_transforms/image_pyramid_abstract.h
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - performs: pyramid_up(in_img, out_img, pyr, interpolate_bilinear());
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename pyramid_type
+ >
+ void pyramid_up (
+ image_type& img,
+ const pyramid_type& pyr
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pyramid_type == a type compatible with the image pyramid objects defined
+ in dlib/image_transforms/image_pyramid_abstract.h
+ ensures
+ - Performs an in-place version of pyramid_up() on the given image. In
+ particular, this function is equivalent to:
+ pyramid_up(img, temp, pyr);
+ temp.swap(img);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void pyramid_up (
+ image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - performs: pyramid_up(img, pyramid_down<2>());
+ (i.e. it upsamples the given image and doubles it in size.)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct chip_dims
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple tool for passing in a pair of row and column values to the
+ chip_details constructor.
+ !*/
+
+ chip_dims (
+ unsigned long rows_,
+ unsigned long cols_
+ ) : rows(rows_), cols(cols_) { }
+
+ unsigned long rows;
+ unsigned long cols;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct chip_details
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object describes where an image chip is to be extracted from within
+ another image. In particular, it specifies that the image chip is
+ contained within the rectangle this->rect and that prior to extraction the
+ image should be rotated counter-clockwise by this->angle radians. Finally,
+ the extracted chip should have this->rows rows and this->cols columns in it
+ regardless of the shape of this->rect. This means that the extracted chip
+ will be stretched to fit via bilinear interpolation when necessary.
+ !*/
+
+ chip_details(
+ );
+ /*!
+ ensures
+ - #rect.is_empty() == true
+ - #size() == 0
+ - #angle == 0
+ - #rows == 0
+ - #cols == 0
+ !*/
+
+ chip_details(
+ const drectangle& rect_
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == rect_.area()
+ - #angle == 0
+ - #rows == rect_.height()
+ - #cols == rect_.width()
+ !*/
+
+ chip_details(
+ const rectangle& rect_
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == rect_.area()
+ - #angle == 0
+ - #rows == rect_.height()
+ - #cols == rect_.width()
+ !*/
+
+ chip_details(
+ const drectangle& rect_,
+ unsigned long size_
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == size_
+ - #angle == 0
+ - #rows and #cols is set such that the total size of the chip is as close
+ to size_ as possible but still matches the aspect ratio of rect_.
+ - As long as size_ and the aspect ratio of of rect_ stays constant then
+ #rows and #cols will always have the same values. This means that, for
+ example, if you want all your chips to have the same dimensions then
+ ensure that size_ is always the same and also that rect_ always has the
+ same aspect ratio. Otherwise the calculated values of #rows and #cols
+ may be different for different chips. Alternatively, you can use the
+ chip_details constructor below that lets you specify the exact values for
+ rows and cols.
+ !*/
+
+ chip_details(
+ const drectangle& rect_,
+ unsigned long size_,
+ double angle_
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == size_
+ - #angle == angle_
+ - #rows and #cols is set such that the total size of the chip is as close
+ to size_ as possible but still matches the aspect ratio of rect_.
+ - As long as size_ and the aspect ratio of of rect_ stays constant then
+ #rows and #cols will always have the same values. This means that, for
+ example, if you want all your chips to have the same dimensions then
+ ensure that size_ is always the same and also that rect_ always has the
+ same aspect ratio. Otherwise the calculated values of #rows and #cols
+ may be different for different chips. Alternatively, you can use the
+ chip_details constructor below that lets you specify the exact values for
+ rows and cols.
+ !*/
+
+ chip_details(
+ const drectangle& rect_,
+ const chip_dims& dims
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == dims.rows*dims.cols
+ - #angle == 0
+ - #rows == dims.rows
+ - #cols == dims.cols
+ !*/
+
+ chip_details(
+ const drectangle& rect_,
+ const chip_dims& dims,
+ double angle_
+ );
+ /*!
+ ensures
+ - #rect == rect_
+ - #size() == dims.rows*dims.cols
+ - #angle == angle_
+ - #rows == dims.rows
+ - #cols == dims.cols
+ !*/
+
+ template <typename T>
+ chip_details(
+ const std::vector<dlib::vector<T,2> >& chip_points,
+ const std::vector<dlib::vector<T,2> >& img_points,
+ const chip_dims& dims
+ );
+ /*!
+ requires
+ - chip_points.size() == img_points.size()
+ - chip_points.size() >= 2
+ ensures
+ - The chip will be extracted such that the pixel locations chip_points[i]
+ in the chip are mapped to img_points[i] in the original image by a
+ similarity transform. That is, if you know the pixelwize mapping you
+ want between the chip and the original image then you use this function
+ of chip_details constructor to define the mapping.
+ - #rows == dims.rows
+ - #cols == dims.cols
+ - #size() == dims.rows*dims.cols
+ - #rect and #angle are computed based on the given size of the output chip
+ (specified by dims) and the similarity transform between the chip and
+ image (specified by chip_points and img_points).
+ !*/
+
+ inline unsigned long size() const { return rows*cols; }
+ /*!
+ ensures
+ - returns the number of pixels in this chip. This is just rows*cols.
+ !*/
+
+ drectangle rect;
+ double angle;
+ unsigned long rows;
+ unsigned long cols;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ point_transform_affine get_mapping_to_chip (
+ const chip_details& details
+ );
+ /*!
+ ensures
+ - returns a transformation that maps from the pixels in the original image
+ to the pixels in the cropped image defined by the given details object.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ full_object_detection map_det_to_chip (
+ const full_object_detection& det,
+ const chip_details& details
+ );
+ /*!
+ ensures
+ - Maps the given detection into the pixel space of the image chip defined by
+ the given details object. That is, this function returns an object D such
+ that:
+ - D.get_rect() == a box that bounds the same thing in the image chip as
+ det.get_rect() bounds in the original image the chip is extracted from.
+ - for all valid i:
+ - D.part(i) == the location in the image chip corresponding to
+ det.part(i) in the original image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void extract_image_chips (
+ const image_type1& img,
+ const std::vector<chip_details>& chip_locations,
+ dlib::array<image_type2>& chips,
+ const interpolation_type& interp
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type1>::pixel_type>::has_alpha == false
+ - for all valid i:
+ - chip_locations[i].rect.is_empty() == false
+ - chip_locations[i].size() != 0
+ - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear,
+ interpolate_quadratic, or a type with a compatible interface.
+ ensures
+ - This function extracts "chips" from an image. That is, it takes a list of
+ rectangular sub-windows (i.e. chips) within an image and extracts those
+ sub-windows, storing each into its own image. It also scales and rotates the
+ image chips according to the instructions inside each chip_details object.
+ It uses the interpolation method supplied as a parameter.
+ - #chips == the extracted image chips
+ - #chips.size() == chip_locations.size()
+ - for all valid i:
+ - #chips[i] == The image chip extracted from the position
+ chip_locations[i].rect in img.
+ - #chips[i].nr() == chip_locations[i].rows
+ - #chips[i].nc() == chip_locations[i].cols
+ - The image will have been rotated counter-clockwise by
+ chip_locations[i].angle radians, around the center of
+ chip_locations[i].rect, before the chip was extracted.
+ - Any pixels in an image chip that go outside img are set to 0 (i.e. black).
+ !*/
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void extract_image_chips (
+ const image_type1& img,
+ const std::vector<chip_details>& chip_locations,
+ dlib::array<image_type2>& chips
+ );
+ /*!
+ ensures
+ - This function is a simple convenience / compatibility wrapper that calls the
+ above-defined extract_image_chips() function using bilinear interpolation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2,
+ typename interpolation_type
+ >
+ void extract_image_chip (
+ const image_type1& img,
+ const chip_details& chip_location,
+ image_type2& chip,
+ const interpolation_type& interp
+ );
+ /*!
+ ensures
+ - This function simply calls extract_image_chips() with a single chip location
+ and stores the single output chip into #chip. It uses the provided
+ interpolation method.
+ !*/
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void extract_image_chip (
+ const image_type1& img,
+ const chip_details& chip_location,
+ image_type2& chip
+ );
+ /*!
+ ensures
+ - This function is a simple convenience / compatibility wrapper that calls the
+ above-defined extract_image_chip() function using bilinear interpolation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ struct sub_image_proxy
+ {
+ /*!
+ REQUIREMENTS ON image_type
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a lightweight image object for referencing a subwindow of an image.
+ It implements the generic image interface and can therefore be used with
+ any function that expects a generic image, excepting that you cannot change
+ the size of a sub_image_proxy.
+
+ Note that it only stores a pointer to the image data given to its
+ constructor and therefore does not perform a copy. Moreover, this means
+ that an instance of this object becomes invalid after the underlying image
+ data it references is destroyed.
+ !*/
+ sub_image_proxy (
+ T& img,
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - This object is an image that represents the part of img contained within
+ rect. If rect is larger than img then rect is cropped so that it does
+ not go outside img.
+ !*/
+ };
+
+ template <
+ typename image_type
+ >
+ sub_image_proxy<image_type> sub_image (
+ image_type& img,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - returns sub_image_proxy<image_type>(img,rect)
+ !*/
+
+ template <typename T>
+ sub_image_proxy<some_appropriate_type> sub_image (
+ T* img,
+ long nr,
+ long nc,
+ long row_stride
+ );
+ /*!
+ requires
+ - img == a pointer to at least nr*row_stride T objects
+ - nr >= 0
+ - nc >= 0
+ - row_stride >= 0
+ ensures
+ - This function returns an image that is just a thin wrapper around the given
+ pointer. It will have the dimensions defined by the supplied longs. To be
+ precise, this function returns an image object IMG such that:
+ - image_data(IMG) == img
+ - num_rows(IMG) == nr
+ - num_columns(IMG) == nc
+ - width_step(IMG) == row_stride*sizeof(T)
+ - IMG contains pixels of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ struct const_sub_image_proxy
+ {
+ /*!
+ REQUIREMENTS ON image_type
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is just like sub_image_proxy except that it does not allow the
+ pixel data to be modified.
+ !*/
+ const_sub_image_proxy (
+ const T& img,
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - This object is an image that represents the part of img contained within
+ rect. If rect is larger than img then rect is cropped so that it does
+ not go outside img.
+ !*/
+ };
+
+ template <
+ typename image_type
+ >
+ const const_sub_image_proxy<image_type> sub_image (
+ const image_type& img,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - returns const_sub_image_proxy<image_type>(img,rect)
+ !*/
+
+ template <typename T>
+ const const_sub_image_proxy<some_appropriate_type> sub_image (
+ const T* img,
+ long nr,
+ long nc,
+ long row_stride
+ );
+ /*!
+ requires
+ - img == a pointer to at least nr*row_stride T objects
+ - nr >= 0
+ - nc >= 0
+ - row_stride >= 0
+ ensures
+ - This function returns an image that is just a thin wrapper around the given
+ pointer. It will have the dimensions defined by the supplied longs. To be
+ precise, this function returns an image object IMG such that:
+ - image_data(IMG) == img
+ - num_rows(IMG) == nr
+ - num_columns(IMG) == nc
+ - width_step(IMG) == row_stride*sizeof(T)
+ - IMG contains pixels of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ chip_details get_face_chip_details (
+ const full_object_detection& det,
+ const unsigned long size = 200,
+ const double padding = 0.2
+ );
+ /*!
+ requires
+ - det.num_parts() == 68 || det.num_parts() == 5
+ - size > 0
+ - padding >= 0
+ ensures
+ - This function assumes det contains a human face detection with face parts
+ annotated using the annotation scheme from the iBUG 300-W face landmark
+ dataset or a 5 point face annotation. Given these assumptions, it creates a
+ chip_details object that will extract a copy of the face that has been
+ rotated upright, centered, and scaled to a standard size when given to
+ extract_image_chip().
+ - This function is specifically calibrated to work with one of these models:
+ - http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2
+ - http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+ - The extracted chips will have size rows and columns in them.
+ - if padding == 0 then the chip will be closely cropped around the face.
+ Setting larger padding values will result a looser cropping. In particular,
+ a padding of 0.5 would double the width of the cropped area, a value of 1
+ would triple it, and so forth.
+ - The 5 point face annotation scheme is assumed to be:
+ - det part 0 == left eye corner, outside part of eye.
+ - det part 1 == left eye corner, inside part of eye.
+ - det part 2 == right eye corner, outside part of eye.
+ - det part 3 == right eye corner, inside part of eye.
+ - det part 4 == immediately under the nose, right at the top of the philtrum.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<chip_details> get_face_chip_details (
+ const std::vector<full_object_detection>& dets,
+ const unsigned long size = 200,
+ const double padding = 0.2
+ );
+ /*!
+ requires
+ - for all valid i:
+ - det[i].num_parts() == 68
+ - size > 0
+ - padding >= 0
+ ensures
+ - This function is identical to the version of get_face_chip_details() defined
+ above except that it creates and returns an array of chip_details objects,
+ one for each input full_object_detection.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ image_type jitter_image(
+ const image_type& img,
+ dlib::rand& rnd
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - pixel_traits<typename image_traits<image_type>::pixel_type>::has_alpha == false
+ - img.size() > 0
+ - img.nr() == img.nc()
+ ensures
+ - Randomly jitters the image a little bit and returns this new jittered image.
+ To be specific, the returned image has the same size as img and will look
+ generally similar. The difference is that the returned image will have been
+ slightly rotated, zoomed, and translated. There is also a 50% chance it will
+ be mirrored left to right.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_INTERPOlATION_ABSTRACT_
+
diff --git a/ml/dlib/dlib/image_transforms/label_connected_blobs.h b/ml/dlib/dlib/image_transforms/label_connected_blobs.h
new file mode 100644
index 000000000..c25346c76
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/label_connected_blobs.h
@@ -0,0 +1,188 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LABEL_CONNeCTED_BLOBS_H_
+#define DLIB_LABEL_CONNeCTED_BLOBS_H_
+
+#include "label_connected_blobs_abstract.h"
+#include "../geometry.h"
+#include <stack>
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct neighbors_8
+ {
+ void operator() (
+ const point& p,
+ std::vector<point>& neighbors
+ ) const
+ {
+ neighbors.push_back(point(p.x()+1,p.y()+1));
+ neighbors.push_back(point(p.x()+1,p.y() ));
+ neighbors.push_back(point(p.x()+1,p.y()-1));
+
+ neighbors.push_back(point(p.x(),p.y()+1));
+ neighbors.push_back(point(p.x(),p.y()-1));
+
+ neighbors.push_back(point(p.x()-1,p.y()+1));
+ neighbors.push_back(point(p.x()-1,p.y() ));
+ neighbors.push_back(point(p.x()-1,p.y()-1));
+ }
+ };
+
+ struct neighbors_4
+ {
+ void operator() (
+ const point& p,
+ std::vector<point>& neighbors
+ ) const
+ {
+ neighbors.push_back(point(p.x()+1,p.y()));
+ neighbors.push_back(point(p.x()-1,p.y()));
+ neighbors.push_back(point(p.x(),p.y()+1));
+ neighbors.push_back(point(p.x(),p.y()-1));
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct connected_if_both_not_zero
+ {
+ template <typename image_type>
+ bool operator() (
+ const image_type& img,
+ const point& a,
+ const point& b
+ ) const
+ {
+ return (img[a.y()][a.x()] != 0 && img[b.y()][b.x()] != 0);
+ }
+ };
+
+ struct connected_if_equal
+ {
+ template <typename image_type>
+ bool operator() (
+ const image_type& img,
+ const point& a,
+ const point& b
+ ) const
+ {
+ return (img[a.y()][a.x()] == img[b.y()][b.x()]);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct zero_pixels_are_background
+ {
+ template <typename image_type>
+ bool operator() (
+ const image_type& img,
+ const point& p
+ ) const
+ {
+ return img[p.y()][p.x()] == 0;
+ }
+
+ };
+
+ struct nothing_is_background
+ {
+ template <typename image_type>
+ bool operator() (
+ const image_type&,
+ const point&
+ ) const
+ {
+ return false;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename label_image_type,
+ typename background_functor_type,
+ typename neighbors_functor_type,
+ typename connected_functor_type
+ >
+ unsigned long label_connected_blobs (
+ const image_type& img_,
+ const background_functor_type& is_background,
+ const neighbors_functor_type& get_neighbors,
+ const connected_functor_type& is_connected,
+ label_image_type& label_img_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_same_object(img_, label_img_) == false,
+ "\t unsigned long label_connected_blobs()"
+ << "\n\t The input image and output label image can't be the same object."
+ );
+
+ const_image_view<image_type> img(img_);
+ image_view<label_image_type> label_img(label_img_);
+
+ std::stack<point> neighbors;
+ label_img.set_size(img.nr(), img.nc());
+ assign_all_pixels(label_img, 0);
+ unsigned long next = 1;
+
+ if (img.size() == 0)
+ return 0;
+
+ const rectangle area = get_rect(img);
+
+ std::vector<point> window;
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ // skip already labeled pixels or background pixels
+ if (label_img[r][c] != 0 || is_background(img,point(c,r)))
+ continue;
+
+ label_img[r][c] = next;
+
+ // label all the neighbors of this point
+ neighbors.push(point(c,r));
+ while (neighbors.size() > 0)
+ {
+ const point p = neighbors.top();
+ neighbors.pop();
+
+ window.clear();
+ get_neighbors(p, window);
+
+ for (unsigned long i = 0; i < window.size(); ++i)
+ {
+ if (area.contains(window[i]) && // point in image.
+ !is_background(img,window[i]) && // isn't background.
+ label_img[window[i].y()][window[i].x()] == 0 && // haven't already labeled it.
+ is_connected(img, p, window[i])) // it's connected.
+ {
+ label_img[window[i].y()][window[i].x()] = next;
+ neighbors.push(window[i]);
+ }
+ }
+ }
+
+ ++next;
+ }
+ }
+
+ return next;
+ }
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LABEL_CONNeCTED_BLOBS_H_
+
diff --git a/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h b/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h
new file mode 100644
index 000000000..5dc984000
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h
@@ -0,0 +1,199 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_
+#ifdef DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_
+
+#include "../geometry.h"
+#include <vector>
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct neighbors_8
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a pixel neighborhood generating functor for
+ use with the label_connected_blobs() routine defined below.
+ !*/
+
+ void operator() (
+ const point& p,
+ std::vector<point>& neighbors
+ ) const;
+ /*!
+ ensures
+ - adds the 8 neighboring pixels surrounding p into neighbors
+ !*/
+ };
+
+ struct neighbors_4
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a pixel neighborhood generating functor for
+ use with the label_connected_blobs() routine defined below.
+ !*/
+
+ void operator() (
+ const point& p,
+ std::vector<point>& neighbors
+ ) const;
+ /*!
+ ensures
+ - adds the 4 neighboring pixels of p into neighbors. These
+ are the ones immediately to the left, top, right, and bottom.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct connected_if_both_not_zero
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a pixel connection testing functor for use
+ with the label_connected_blobs() routine defined below.
+ !*/
+
+ template <typename image_view_type>
+ bool operator() (
+ const image_view_type& img,
+ const point& a,
+ const point& b
+ ) const
+ {
+ return (img[a.y()][a.x()] != 0 && img[b.y()][b.x()] != 0);
+ }
+ };
+
+ struct connected_if_equal
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a pixel connection testing functor for use
+ with the label_connected_blobs() routine defined below.
+ !*/
+
+ template <typename image_view_type>
+ bool operator() (
+ const image_view_type& img,
+ const point& a,
+ const point& b
+ ) const
+ {
+ return (img[a.y()][a.x()] == img[b.y()][b.x()]);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct zero_pixels_are_background
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a background testing functor for use
+ with the label_connected_blobs() routine defined below.
+ !*/
+
+ template <typename image_view_type>
+ bool operator() (
+ const image_view_type& img,
+ const point& p
+ ) const
+ {
+ return img[p.y()][p.x()] == 0;
+ }
+
+ };
+
+ struct nothing_is_background
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a background testing functor for use
+ with the label_connected_blobs() routine defined below.
+ !*/
+
+ template <typename image_view_type>
+ bool operator() (
+ const image_view_type&,
+ const point&
+ ) const
+ {
+ return false;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename label_image_type,
+ typename background_functor_type,
+ typename neighbors_functor_type,
+ typename connected_functor_type
+ >
+ unsigned long label_connected_blobs (
+ const image_type& img,
+ const background_functor_type& is_background,
+ const neighbors_functor_type& get_neighbors,
+ const connected_functor_type& is_connected,
+ label_image_type& label_img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - label_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain integer pixels.
+ - is_background(img, point(c,r)) is a legal expression that evaluates to a bool.
+ - is_connected(img, point(c,r), point(c2,r2)) is a legal expression that
+ evaluates to a bool.
+ - get_neighbors(point(c,r), neighbors) is a legal expression where neighbors
+ is of type std::vector<point>.
+ - is_same_object(img, label_img) == false
+ ensures
+ - This function labels each of the connected blobs in img with a unique integer
+ label.
+ - An image can be thought of as a graph where pixels A and B are connected if
+ and only if the following two statements are satisfied:
+ - is_connected(img,A,B) == true
+ - get_neighbors(A, neighbors) results in neighbors containing B or
+ get_neighbors(B, neighbors) results in neighbors containing A.
+ Then this function can be understood as labeling all the connected components
+ of this pixel graph such that all pixels in a component get the same label while
+ pixels in different components get different labels. Note that there is a
+ special "background" component determined by is_background(). Any pixels which
+ are "background" always get a blob id of 0 regardless of any other considerations.
+ - #label_img.nr() == img.nr()
+ - #label_img.nc() == img.nc()
+ - for all valid r and c:
+ - #label_img[r][c] == the blob label number for pixel img[r][c].
+ - #label_img[r][c] >= 0
+ - if (is_background(img, point(c,r))) then
+ - #label_img[r][c] == 0
+ - else
+ - #label_img[r][c] != 0
+ - if (img.size() != 0) then
+ - returns max(mat(#label_img))+1
+ (i.e. returns a number one greater than the maximum blob id number,
+ this is the number of blobs found.)
+ - else
+ - returns 0
+ - blob labels are contiguous, therefore, the number returned by this function is
+ the number of blobs in the image (including the background blob).
+ - It is guaranteed that is_connected() and is_background() will never be
+ called with points outside the image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/image_transforms/lbp.h b/ml/dlib/dlib/image_transforms/lbp.h
new file mode 100644
index 000000000..b6bbac9cf
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/lbp.h
@@ -0,0 +1,307 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LBP_Hh_
+#define DLIB_LBP_Hh_
+
+#include "lbp_abstract.h"
+#include "../image_processing/generic_image.h"
+#include "assign_image.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename image_type2
+ >
+ void make_uniform_lbp_image (
+ const image_type& img_,
+ image_type2& lbp_
+ )
+ {
+ const static unsigned char uniform_lbps[] = {
+ 0, 1, 2, 3, 4, 58, 5, 6, 7, 58, 58, 58, 8, 58, 9, 10, 11, 58, 58, 58, 58, 58,
+ 58, 58, 12, 58, 58, 58, 13, 58, 14, 15, 16, 58, 58, 58, 58, 58, 58, 58, 58, 58,
+ 58, 58, 58, 58, 58, 58, 17, 58, 58, 58, 58, 58, 58, 58, 18, 58, 58, 58, 19, 58,
+ 20, 21, 22, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58,
+ 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 23, 58, 58, 58, 58, 58,
+ 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 24, 58, 58, 58, 58, 58, 58, 58, 25, 58,
+ 58, 58, 26, 58, 27, 28, 29, 30, 58, 31, 58, 58, 58, 32, 58, 58, 58, 58, 58, 58,
+ 58, 33, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 34, 58, 58,
+ 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58,
+ 58, 58, 58, 58, 58, 58, 58, 58, 58, 35, 36, 37, 58, 38, 58, 58, 58, 39, 58, 58,
+ 58, 58, 58, 58, 58, 40, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58,
+ 58, 41, 42, 43, 58, 44, 58, 58, 58, 45, 58, 58, 58, 58, 58, 58, 58, 46, 47, 48,
+ 58, 49, 58, 58, 58, 50, 51, 52, 58, 53, 54, 55, 56, 57
+ };
+
+ COMPILE_TIME_ASSERT(sizeof(uniform_lbps) == 256);
+
+ const_image_view<image_type> img(img_);
+ image_view<image_type2> lbp(lbp_);
+
+ lbp.set_size(img.nr(), img.nc());
+
+ // set all the border pixels to the "non-uniform LBP value".
+ assign_border_pixels(lbp, 1, 1, 58);
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ typedef typename pixel_traits<pixel_type>::basic_pixel_type basic_pixel_type;
+
+ for (long r = 1; r+1 < img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < img.nc(); ++c)
+ {
+ const basic_pixel_type pix = get_pixel_intensity(img[r][c]);
+ unsigned char b1 = 0;
+ unsigned char b2 = 0;
+ unsigned char b3 = 0;
+ unsigned char b4 = 0;
+ unsigned char b5 = 0;
+ unsigned char b6 = 0;
+ unsigned char b7 = 0;
+ unsigned char b8 = 0;
+
+ unsigned char x = 0;
+ if (get_pixel_intensity(img[r-1][c-1]) > pix) b1 = 0x80;
+ if (get_pixel_intensity(img[r-1][c ]) > pix) b2 = 0x40;
+ if (get_pixel_intensity(img[r-1][c+1]) > pix) b3 = 0x20;
+ x |= b1;
+ if (get_pixel_intensity(img[r ][c-1]) > pix) b4 = 0x10;
+ x |= b2;
+ if (get_pixel_intensity(img[r ][c+1]) > pix) b5 = 0x08;
+ x |= b3;
+ if (get_pixel_intensity(img[r+1][c-1]) > pix) b6 = 0x04;
+ x |= b4;
+ if (get_pixel_intensity(img[r+1][c ]) > pix) b7 = 0x02;
+ x |= b5;
+ if (get_pixel_intensity(img[r+1][c+1]) > pix) b8 = 0x01;
+
+ x |= b6;
+ x |= b7;
+ x |= b8;
+
+ lbp[r][c] = uniform_lbps[x];
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_histogram_descriptors (
+ const image_type& img_,
+ const point& loc,
+ std::vector<T>& histograms,
+ const unsigned int cell_size = 10,
+ const unsigned int block_size = 4,
+ const unsigned int max_val = 58
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cell_size >= 1 && block_size >= 1 && max_val < 256 &&
+ (unsigned int)max(mat(img_)) <= max_val,
+ "\t void extract_histogram_descriptors()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t cell_size: " << cell_size
+ << "\n\t block_size: " << block_size
+ << "\n\t max_val: " << max_val
+ << "\n\t max(mat(img_)): " << max(mat(img_))
+ );
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+ COMPILE_TIME_ASSERT((is_same_type<pixel_type, unsigned char>::value));
+
+ const_image_view<image_type> img(img_);
+
+ const rectangle area = get_rect(img);
+ const rectangle window = centered_rect(loc, block_size*cell_size, block_size*cell_size);
+ unsigned int cell_top = window.top();
+ for (unsigned int br = 0; br < block_size; ++br)
+ {
+ unsigned int cell_left = window.left();
+ for (unsigned int bc = 0; bc < block_size; ++bc)
+ {
+ // figure out the cell boundaries
+ rectangle cell(cell_left, cell_top, cell_left+cell_size-1, cell_top+cell_size-1);
+ cell = cell.intersect(area);
+
+ // make the actual histogram for this cell
+ unsigned int hist[256] = {0};
+ for (long r = cell.top(); r <= cell.bottom(); ++r)
+ {
+ for (long c = cell.left(); c <= cell.right(); ++c)
+ {
+ hist[img[r][c]]++;
+ }
+ }
+
+ // copy histogram into the output.
+ histograms.insert(histograms.end(), hist, hist + max_val+1);
+
+ cell_left += cell_size;
+ }
+ cell_top += cell_size;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_uniform_lbp_descriptors (
+ const image_type& img,
+ std::vector<T>& feats,
+ const unsigned int cell_size = 10
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cell_size >= 1,
+ "\t void extract_uniform_lbp_descriptors()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t cell_size: " << cell_size
+ );
+
+ feats.clear();
+ array2d<unsigned char> lbp;
+ make_uniform_lbp_image(img, lbp);
+ for (long r = 0; r < lbp.nr(); r+=cell_size)
+ {
+ for (long c = 0; c < lbp.nc(); c+=cell_size)
+ {
+ const rectangle cell = rectangle(c,r,c+cell_size-1,r+cell_size-1).intersect(get_rect(lbp));
+ // make the actual histogram for this cell
+ unsigned int hist[59] = {0};
+ for (long r = cell.top(); r <= cell.bottom(); ++r)
+ {
+ for (long c = cell.left(); c <= cell.right(); ++c)
+ {
+ hist[lbp[r][c]]++;
+ }
+ }
+
+ // copy histogram into the output.
+ feats.insert(feats.end(), hist, hist + 59);
+ }
+ }
+
+ for (unsigned long i = 0; i < feats.size(); ++i)
+ feats[i] = std::sqrt(feats[i]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_highdim_face_lbp_descriptors (
+ const image_type& img,
+ const full_object_detection& det,
+ std::vector<T>& feats
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(det.num_parts() == 68,
+ "\t void extract_highdim_face_lbp_descriptors()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t det.num_parts(): " << det.num_parts()
+ );
+
+ const unsigned long num_scales = 5;
+ feats.clear();
+ dlib::vector<double,2> l, r;
+ double cnt = 0;
+ // Find the center of the left eye by averaging the points around
+ // the eye.
+ for (unsigned long i = 36; i <= 41; ++i)
+ {
+ l += det.part(i);
+ ++cnt;
+ }
+ l /= cnt;
+
+ // Find the center of the right eye by averaging the points around
+ // the eye.
+ cnt = 0;
+ for (unsigned long i = 42; i <= 47; ++i)
+ {
+ r += det.part(i);
+ ++cnt;
+ }
+ r /= cnt;
+
+ // We only do feature extraction from these face parts. These are things like the
+ // corners of the eyes and mouth and stuff like that.
+ std::vector<point> parts;
+ parts.reserve(30);
+ parts.push_back(l);
+ parts.push_back(r);
+ parts.push_back(det.part(17));
+ parts.push_back(det.part(21));
+ parts.push_back(det.part(22));
+ parts.push_back(det.part(26));
+ parts.push_back(det.part(36));
+ parts.push_back(det.part(39));
+ parts.push_back(det.part(42));
+ parts.push_back(det.part(45));
+ parts.push_back(det.part(27));
+ parts.push_back(det.part(28));
+ parts.push_back(det.part(29));
+ parts.push_back(det.part(30));
+ parts.push_back(det.part(31));
+ parts.push_back(det.part(35));
+ parts.push_back(det.part(33));
+ parts.push_back(det.part(48));
+ parts.push_back(det.part(54));
+ parts.push_back(det.part(51));
+ parts.push_back(det.part(57));
+
+ array2d<unsigned char> lbp;
+ make_uniform_lbp_image(img, lbp);
+ for (unsigned long i = 0; i < parts.size(); ++i)
+ extract_histogram_descriptors(lbp, parts[i], feats);
+
+ if (num_scales > 1)
+ {
+ pyramid_down<4> pyr;
+ image_type img_temp;
+ pyr(img, img_temp);
+ unsigned long num_pyr_calls = 1;
+
+ // now pull the features out at coarser scales
+ for (unsigned long iter = 1; iter < num_scales; ++iter)
+ {
+ // now do the feature extraction
+ make_uniform_lbp_image(img_temp, lbp);
+ for (unsigned long i = 0; i < parts.size(); ++i)
+ extract_histogram_descriptors(lbp, pyr.point_down(parts[i],num_pyr_calls), feats);
+
+ if (iter+1 < num_scales)
+ {
+ pyr(img_temp);
+ ++num_pyr_calls;
+ }
+ }
+ }
+
+ for (unsigned long i = 0; i < feats.size(); ++i)
+ feats[i] = std::sqrt(feats[i]);
+
+ DLIB_ASSERT(feats.size() == 99120, feats.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LBP_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/lbp_abstract.h b/ml/dlib/dlib/image_transforms/lbp_abstract.h
new file mode 100644
index 000000000..1a20082a2
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/lbp_abstract.h
@@ -0,0 +1,139 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LBP_ABSTRACT_Hh_
+#ifdef DLIB_LBP_ABSTRACT_Hh_
+
+#include "../image_processing/generic_image.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename image_type2
+ >
+ void make_uniform_lbp_image (
+ const image_type& img,
+ image_type2& lbp
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 should contain a grayscale pixel type such as unsigned char.
+ ensures
+ - #lbp.nr() == img.nr()
+ - #lbp.nc() == img.nc()
+ - This function extracts the uniform local-binary-pattern feature at every pixel
+ and stores it into #lbp. In particular, we have the following for all valid
+ r and c:
+ - #lbp[r][c] == the uniform LBP for the 3x3 pixel window centered on img[r][c].
+ In particular, this is a value in the range 0 to 58 inclusive.
+ - We use the idea of uniform LBPs from the paper:
+ Face Description with Local Binary Patterns: Application to Face Recognition
+ by Ahonen, Hadid, and Pietikainen.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_histogram_descriptors (
+ const image_type& img,
+ const point& loc,
+ std::vector<T>& histograms,
+ const unsigned int cell_size = 10,
+ const unsigned int block_size = 4,
+ const unsigned int max_val = 58
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type contains unsigned char valued pixels.
+ - T is some scalar type like int or double
+ - All pixel values in img are <= max_val
+ - cell_size >= 1
+ - block_size >= 1
+ - max_val < 256
+ ensures
+ - This function extracts histograms of pixel values from block_size*block_size
+ windows in the area in img immediately around img[loc.y()][loc.x()]. The
+ histograms are appended onto the end of #histograms. Each window is
+ cell_size pixels wide and tall. Moreover, the windows do not overlap.
+ - #histograms.size() == histograms.size() + block_size*block_size*(max_val+1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_uniform_lbp_descriptors (
+ const image_type& img,
+ std::vector<T>& feats,
+ const unsigned int cell_size = 10
+ );
+ /*!
+ requires
+ - cell_size >= 1
+ - T is some scalar type like int or double
+ ensures
+ - Extracts histograms of uniform local-binary-patterns from img. The
+ histograms are from densely tiled windows that are cell_size pixels wide and
+ tall. The windows do not overlap and cover all of img.
+ - #feats.size() == 59*(number of windows that fit into img)
+ (i.e. #feats contains the LBP histograms)
+ - We will have taken the square root of all the histogram elements. That is,
+ #feats[i] is the square root of the number of LBPs that appeared in its
+ corresponding window.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename T
+ >
+ void extract_highdim_face_lbp_descriptors (
+ const image_type& img,
+ const full_object_detection& det,
+ std::vector<T>& feats
+ );
+ /*!
+ requires
+ - T is some scalar type like int or double
+ - det.num_parts() == 68
+ ensures
+ - This function extracts the high-dimensional LBP feature described in the
+ paper:
+ Blessing of Dimensionality: High-dimensional Feature and Its Efficient
+ Compression for Face Verification by Dong Chen, Xudong Cao, Fang Wen, and
+ Jian Sun
+ - #feats == the high-dimensional LBP descriptor. It is the concatenation of
+ many LBP histograms, each extracted from different scales and from different
+ windows around different face landmarks. We also take the square root of
+ each histogram element before storing it into #feats.
+ - #feats.size() == 99120
+ - This function assumes img has already been aligned and normalized to a
+ standard size.
+ - This function assumes det contains a human face detection with face parts
+ annotated using the annotation scheme from the iBUG 300-W face landmark
+ dataset. This means that det.part(i) gives the locations of different face
+ landmarks according to the iBUG 300-W annotation scheme.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LBP_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/morphological_operations.h b/ml/dlib/dlib/image_transforms/morphological_operations.h
new file mode 100644
index 000000000..a659e4bdc
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/morphological_operations.h
@@ -0,0 +1,846 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MORPHOLOGICAL_OPERATIONs_
+#define DLIB_MORPHOLOGICAL_OPERATIONs_
+
+#include "../pixel.h"
+#include "thresholding.h"
+#include "morphological_operations_abstract.h"
+#include "assign_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace morphological_operations_helpers
+ {
+ template <typename image_type>
+ bool is_binary_image (
+ const image_type& img_
+ )
+ /*!
+ ensures
+ - returns true if img_ contains only on_pixel and off_pixel values.
+ - returns false otherwise
+ !*/
+ {
+ const_image_view<image_type> img(img_);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ if (img[r][c] != on_pixel && img[r][c] != off_pixel)
+ {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ template <
+ long M,
+ long N
+ >
+ bool is_binary_image (
+ const unsigned char (&structuring_element)[M][N]
+ )
+ /*!
+ ensures
+ - returns true if structuring_element contains only on_pixel and off_pixel values.
+ - returns false otherwise
+ !*/
+ {
+ for (long m = 0; m < M; ++m)
+ {
+ for (long n = 0; n < N; ++n)
+ {
+ if (structuring_element[m][n] != on_pixel &&
+ structuring_element[m][n] != off_pixel)
+ {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_dilation (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const unsigned char (&structuring_element)[M][N]
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(M%2 == 1);
+ COMPILE_TIME_ASSERT(N%2 == 1);
+ DLIB_ASSERT(is_same_object(in_img_,out_img_) == false,
+ "\tvoid binary_dilation()"
+ << "\n\tYou must give two different image objects"
+ );
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img_) ,
+ "\tvoid binary_dilation()"
+ << "\n\tin_img must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(structuring_element) ,
+ "\tvoid binary_dilation()"
+ << "\n\tthe structuring_element must be a binary image"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ // apply the filter to the image
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ unsigned char out_pixel = off_pixel;
+ for (long m = 0; m < M && out_pixel == off_pixel; ++m)
+ {
+ for (long n = 0; n < N && out_pixel == off_pixel; ++n)
+ {
+ if (structuring_element[m][n] == on_pixel)
+ {
+ // if this pixel is inside the image then get it from the image
+ // but if it isn't just pretend it was an off_pixel value
+ if (r+m >= M/2 && c+n >= N/2 &&
+ r+m-M/2 < in_img.nr() && c+n-N/2 < in_img.nc())
+ {
+ out_pixel = in_img[r+m-M/2][c+n-N/2];
+ }
+ }
+ }
+ }
+ assign_pixel(out_img[r][c], out_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_erosion (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const unsigned char (&structuring_element)[M][N]
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(M%2 == 1);
+ COMPILE_TIME_ASSERT(N%2 == 1);
+ DLIB_ASSERT(is_same_object(in_img_,out_img_) == false,
+ "\tvoid binary_erosion()"
+ << "\n\tYou must give two different image objects"
+ );
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img_) ,
+ "\tvoid binary_erosion()"
+ << "\n\tin_img must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(structuring_element) ,
+ "\tvoid binary_erosion()"
+ << "\n\tthe structuring_element must be a binary image"
+ );
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ // apply the filter to the image
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ unsigned char out_pixel = on_pixel;
+ for (long m = 0; m < M && out_pixel == on_pixel; ++m)
+ {
+ for (long n = 0; n < N && out_pixel == on_pixel; ++n)
+ {
+ if (structuring_element[m][n] == on_pixel)
+ {
+ // if this pixel is inside the image then get it from the image
+ // but if it isn't just pretend it was an off_pixel value
+ if (r+m >= M/2 && c+n >= N/2 &&
+ r+m-M/2 < in_img.nr() && c+n-N/2 < in_img.nc())
+ {
+ out_pixel = in_img[r+m-M/2][c+n-N/2];
+ }
+ else
+ {
+ out_pixel = off_pixel;
+ }
+ }
+ }
+ }
+ assign_pixel(out_img[r][c], out_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_open (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N],
+ const unsigned long iter = 1
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(M%2 == 1);
+ COMPILE_TIME_ASSERT(N%2 == 1);
+ DLIB_ASSERT(is_same_object(in_img,out_img) == false,
+ "\tvoid binary_open()"
+ << "\n\tYou must give two different image objects"
+ );
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img) ,
+ "\tvoid binary_open()"
+ << "\n\tin_img must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(structuring_element) ,
+ "\tvoid binary_open()"
+ << "\n\tthe structuring_element must be a binary image"
+ );
+
+
+ // if there isn't any input image then don't do anything
+ if (num_rows(in_img)*num_columns(in_img) == 0)
+ {
+ set_image_size(out_img, 0,0);
+ return;
+ }
+
+ set_image_size(out_img, num_rows(in_img), num_columns(in_img));
+
+ if (iter == 0)
+ {
+ // just copy the image over
+ assign_image(out_img, in_img);
+ }
+ else if (iter == 1)
+ {
+ in_image_type temp;
+ binary_erosion(in_img,temp,structuring_element);
+ binary_dilation(temp,out_img,structuring_element);
+ }
+ else
+ {
+ in_image_type temp1, temp2;
+ binary_erosion(in_img,temp1,structuring_element);
+
+ // do the extra erosions
+ for (unsigned long i = 1; i < iter; ++i)
+ {
+ swap(temp1, temp2);
+ binary_erosion(temp2,temp1,structuring_element);
+ }
+
+ // do the extra dilations
+ for (unsigned long i = 1; i < iter; ++i)
+ {
+ swap(temp1, temp2);
+ binary_dilation(temp2,temp1,structuring_element);
+ }
+
+ binary_dilation(temp1,out_img,structuring_element);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_close (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N],
+ const unsigned long iter = 1
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(M%2 == 1);
+ COMPILE_TIME_ASSERT(N%2 == 1);
+ DLIB_ASSERT(is_same_object(in_img,out_img) == false,
+ "\tvoid binary_close()"
+ << "\n\tYou must give two different image objects"
+ );
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img) ,
+ "\tvoid binary_close()"
+ << "\n\tin_img must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(structuring_element) ,
+ "\tvoid binary_close()"
+ << "\n\tthe structuring_element must be a binary image"
+ );
+
+
+ // if there isn't any input image then don't do anything
+ if (num_rows(in_img)*num_columns(in_img) == 0)
+ {
+ set_image_size(out_img, 0,0);
+ return;
+ }
+
+ set_image_size(out_img, num_rows(in_img), num_columns(in_img));
+
+ if (iter == 0)
+ {
+ // just copy the image over
+ assign_image(out_img, in_img);
+ }
+ else if (iter == 1)
+ {
+ in_image_type temp;
+ binary_dilation(in_img,temp,structuring_element);
+ binary_erosion(temp,out_img,structuring_element);
+ }
+ else
+ {
+ in_image_type temp1, temp2;
+ binary_dilation(in_img,temp1,structuring_element);
+
+ // do the extra dilations
+ for (unsigned long i = 1; i < iter; ++i)
+ {
+ swap(temp1, temp2);
+ binary_dilation(temp2,temp1,structuring_element);
+ }
+
+ // do the extra erosions
+ for (unsigned long i = 1; i < iter; ++i)
+ {
+ swap(temp1, temp2);
+ binary_erosion(temp2,temp1,structuring_element);
+ }
+
+ binary_erosion(temp1,out_img,structuring_element);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_intersection (
+ const in_image_type1& in_img1_,
+ const in_image_type2& in_img2_,
+ out_image_type& out_img_
+ )
+ {
+ typedef typename image_traits<in_image_type1>::pixel_type in_pixel_type1;
+ typedef typename image_traits<in_image_type2>::pixel_type in_pixel_type2;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type1>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type2>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type1>::grayscale);
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type2>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img1_) ,
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(in_img2_) ,
+ "\tvoid binary_intersection()"
+ << "\n\tin_img2 must be a binary image"
+ );
+
+ const_image_view<in_image_type1> in_img1(in_img1_);
+ const_image_view<in_image_type2> in_img2(in_img2_);
+ image_view<out_image_type> out_img(out_img_);
+
+ DLIB_ASSERT(in_img1.nc() == in_img2.nc(),
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 and in_img2 must have the same ncs."
+ << "\n\tin_img1.nc(): " << in_img1.nc()
+ << "\n\tin_img2.nc(): " << in_img2.nc()
+ );
+ DLIB_ASSERT(in_img1.nr() == in_img2.nr(),
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 and in_img2 must have the same nrs."
+ << "\n\tin_img1.nr(): " << in_img1.nr()
+ << "\n\tin_img2.nr(): " << in_img2.nr()
+ );
+
+
+
+ // if there isn't any input image then don't do anything
+ if (in_img1.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img1.nr(),in_img1.nc());
+
+ for (long r = 0; r < in_img1.nr(); ++r)
+ {
+ for (long c = 0; c < in_img1.nc(); ++c)
+ {
+ if (in_img1[r][c] == on_pixel && in_img2[r][c] == on_pixel)
+ assign_pixel(out_img[r][c], on_pixel);
+ else
+ assign_pixel(out_img[r][c], off_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_union (
+ const in_image_type1& in_img1_,
+ const in_image_type2& in_img2_,
+ out_image_type& out_img_
+ )
+ {
+ typedef typename image_traits<in_image_type1>::pixel_type in_pixel_type1;
+ typedef typename image_traits<in_image_type2>::pixel_type in_pixel_type2;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type1>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type2>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type1>::grayscale);
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type2>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img1_) ,
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(in_img2_) ,
+ "\tvoid binary_intersection()"
+ << "\n\tin_img2 must be a binary image"
+ );
+
+ const_image_view<in_image_type1> in_img1(in_img1_);
+ const_image_view<in_image_type2> in_img2(in_img2_);
+ image_view<out_image_type> out_img(out_img_);
+
+ DLIB_ASSERT(in_img1.nc() == in_img2.nc(),
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 and in_img2 must have the same ncs."
+ << "\n\tin_img1.nc(): " << in_img1.nc()
+ << "\n\tin_img2.nc(): " << in_img2.nc()
+ );
+ DLIB_ASSERT(in_img1.nr() == in_img2.nr(),
+ "\tvoid binary_intersection()"
+ << "\n\tin_img1 and in_img2 must have the same nrs."
+ << "\n\tin_img1.nr(): " << in_img1.nr()
+ << "\n\tin_img2.nr(): " << in_img2.nr()
+ );
+
+
+
+ // if there isn't any input image then don't do anything
+ if (in_img1.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img1.nr(),in_img1.nc());
+
+ for (long r = 0; r < in_img1.nr(); ++r)
+ {
+ for (long c = 0; c < in_img1.nc(); ++c)
+ {
+ if (in_img1[r][c] == on_pixel || in_img2[r][c] == on_pixel)
+ assign_pixel(out_img[r][c], on_pixel);
+ else
+ assign_pixel(out_img[r][c], off_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_difference (
+ const in_image_type1& in_img1_,
+ const in_image_type2& in_img2_,
+ out_image_type& out_img_
+ )
+ {
+ typedef typename image_traits<in_image_type1>::pixel_type in_pixel_type1;
+ typedef typename image_traits<in_image_type2>::pixel_type in_pixel_type2;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type1>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type2>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type1>::grayscale);
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type2>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img1_) ,
+ "\tvoid binary_difference()"
+ << "\n\tin_img1 must be a binary image"
+ );
+ DLIB_ASSERT(is_binary_image(in_img2_) ,
+ "\tvoid binary_difference()"
+ << "\n\tin_img2 must be a binary image"
+ );
+
+ const_image_view<in_image_type1> in_img1(in_img1_);
+ const_image_view<in_image_type2> in_img2(in_img2_);
+ image_view<out_image_type> out_img(out_img_);
+
+ DLIB_ASSERT(in_img1.nc() == in_img2.nc(),
+ "\tvoid binary_difference()"
+ << "\n\tin_img1 and in_img2 must have the same ncs."
+ << "\n\tin_img1.nc(): " << in_img1.nc()
+ << "\n\tin_img2.nc(): " << in_img2.nc()
+ );
+ DLIB_ASSERT(in_img1.nr() == in_img2.nr(),
+ "\tvoid binary_difference()"
+ << "\n\tin_img1 and in_img2 must have the same nrs."
+ << "\n\tin_img1.nr(): " << in_img1.nr()
+ << "\n\tin_img2.nr(): " << in_img2.nr()
+ );
+
+
+
+ // if there isn't any input image then don't do anything
+ if (in_img1.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img1.nr(),in_img1.nc());
+
+ for (long r = 0; r < in_img1.nr(); ++r)
+ {
+ for (long c = 0; c < in_img1.nc(); ++c)
+ {
+ if (in_img1[r][c] == on_pixel && in_img2[r][c] == off_pixel)
+ assign_pixel(out_img[r][c], on_pixel);
+ else
+ assign_pixel(out_img[r][c], off_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void binary_complement (
+ const in_image_type& in_img_,
+ out_image_type& out_img_
+ )
+ {
+ typedef typename image_traits<in_image_type>::pixel_type in_pixel_type;
+ typedef typename image_traits<out_image_type>::pixel_type out_pixel_type;
+ COMPILE_TIME_ASSERT( pixel_traits<in_pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<out_pixel_type>::has_alpha == false );
+
+
+ using namespace morphological_operations_helpers;
+ COMPILE_TIME_ASSERT(pixel_traits<in_pixel_type>::grayscale);
+ DLIB_ASSERT(is_binary_image(in_img_) ,
+ "\tvoid binary_complement()"
+ << "\n\tin_img must be a binary image"
+ );
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ if (in_img[r][c] == on_pixel)
+ assign_pixel(out_img[r][c], off_pixel);
+ else
+ assign_pixel(out_img[r][c], on_pixel);
+ }
+ }
+ }
+
+ template <
+ typename image_type
+ >
+ void binary_complement (
+ image_type& img
+ )
+ {
+ binary_complement(img,img);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename image_type>
+ inline bool should_remove_pixel (
+ const image_type& img,
+ long r,
+ long c,
+ int iter
+ )
+ {
+ unsigned int p2 = img[r-1][c];
+ unsigned int p3 = img[r-1][c+1];
+ unsigned int p4 = img[r][c+1];
+ unsigned int p5 = img[r+1][c+1];
+ unsigned int p6 = img[r+1][c];
+ unsigned int p7 = img[r+1][c-1];
+ unsigned int p8 = img[r][c-1];
+ unsigned int p9 = img[r-1][c-1];
+
+ int A = (p2 == 0 && p3 == 255) + (p3 == 0 && p4 == 255) +
+ (p4 == 0 && p5 == 255) + (p5 == 0 && p6 == 255) +
+ (p6 == 0 && p7 == 255) + (p7 == 0 && p8 == 255) +
+ (p8 == 0 && p9 == 255) + (p9 == 0 && p2 == 255);
+ int B = p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9;
+ int m1 = iter == 0 ? (p2 * p4 * p6) : (p2 * p4 * p8);
+ int m2 = iter == 0 ? (p4 * p6 * p8) : (p2 * p6 * p8);
+ // Decide if we should remove the pixel img[r][c].
+ return (A == 1 && (B >= 2*255 && B <= 6*255) && m1 == 0 && m2 == 0);
+ }
+
+ template <typename image_type>
+ inline void add_to_remove (
+ std::vector<point>& to_remove,
+ array2d<unsigned char>& marker,
+ const image_type& img,
+ long r,
+ long c,
+ int iter
+ )
+ {
+ if (marker[r][c]&&should_remove_pixel(img,r,c,iter))
+ {
+ to_remove.push_back(point(c,r));
+ marker[r][c] = 0;
+ }
+ }
+
+ template <typename image_type>
+ inline bool is_bw_border_pixel(
+ const image_type& img,
+ long r,
+ long c
+ )
+ {
+ unsigned int p2 = img[r-1][c];
+ unsigned int p3 = img[r-1][c+1];
+ unsigned int p4 = img[r][c+1];
+ unsigned int p5 = img[r+1][c+1];
+ unsigned int p6 = img[r+1][c];
+ unsigned int p7 = img[r+1][c-1];
+ unsigned int p8 = img[r][c-1];
+ unsigned int p9 = img[r-1][c-1];
+
+ int B = p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9;
+ // If you are on but at least one of your neighbors isn't.
+ return B<8*255 && img[r][c];
+
+ }
+
+ inline void add_if(
+ std::vector<point>& to_check2,
+ const array2d<unsigned char>& marker,
+ long c,
+ long r
+ )
+ {
+ if (marker[r][c])
+ to_check2.push_back(point(c,r));
+ }
+
+ } // end namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void skeleton(
+ image_type& img_
+ )
+ {
+ /*
+ The implementation of this function is based on the paper
+ "A fast parallel algorithm for thinning digital patterns” by T.Y. Zhang and C.Y. Suen.
+ and also the excellent discussion of it at:
+ http://opencv-code.com/quick-tips/implementation-of-thinning-algorithm-in-opencv/
+ */
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ // This function only works on grayscale images
+ COMPILE_TIME_ASSERT(pixel_traits<pixel_type>::grayscale);
+
+ using namespace impl;
+ // Note that it's important to zero the border for 2 reasons. First, it allows
+ // thinning to being at the border of the image. But more importantly, it causes
+ // the mask to have a border of 0 pixels as well which we use later to avoid
+ // indexing outside the image inside add_to_remove().
+ zero_border_pixels(img_,1,1);
+ image_view<image_type> img(img_);
+
+ // We use the marker to keep track of pixels we have committed to removing but
+ // haven't yet removed from img.
+ array2d<unsigned char> marker(img.nr(), img.nc());
+ assign_image(marker, img);
+
+
+ // Begin by making a list of the pixels on the borders of binary blobs.
+ std::vector<point> to_remove, to_check, to_check2;
+ for (int r = 1; r < img.nr()-1; r++)
+ {
+ for (int c = 1; c < img.nc()-1; c++)
+ {
+ if (is_bw_border_pixel(img, r, c))
+ {
+ to_check.push_back(point(c,r));
+ }
+ }
+ }
+
+ // Now start iteratively looking at the border pixels and removing them.
+ while(to_check.size() != 0)
+ {
+ for (int iter = 0; iter <= 1; ++iter)
+ {
+ // Check which pixels we should remove
+ to_remove.clear();
+ for (unsigned long i = 0; i < to_check.size(); ++i)
+ {
+ long r = to_check[i].y();
+ long c = to_check[i].x();
+ add_to_remove(to_remove, marker, img, r, c, iter);
+ }
+ for (unsigned long i = 0; i < to_check2.size(); ++i)
+ {
+ long r = to_check2[i].y();
+ long c = to_check2[i].x();
+ add_to_remove(to_remove, marker, img, r, c, iter);
+ }
+ // Now remove those pixels. Also add their neighbors into the "to check"
+ // pixel list for the next iteration.
+ for (unsigned long i = 0; i < to_remove.size(); ++i)
+ {
+ long r = to_remove[i].y();
+ long c = to_remove[i].x();
+ // remove the pixel
+ img[r][c] = 0;
+ add_if(to_check2, marker, c-1, r-1);
+ add_if(to_check2, marker, c, r-1);
+ add_if(to_check2, marker, c+1, r-1);
+ add_if(to_check2, marker, c-1, r);
+ add_if(to_check2, marker, c+1, r);
+ add_if(to_check2, marker, c-1, r+1);
+ add_if(to_check2, marker, c, r+1);
+ add_if(to_check2, marker, c+1, r+1);
+ }
+ }
+ to_check.clear();
+ to_check.swap(to_check2);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MORPHOLOGICAL_OPERATIONs_
+
diff --git a/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h b/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h
new file mode 100644
index 000000000..c69bdd1ca
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h
@@ -0,0 +1,316 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_
+#ifdef DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_
+
+#include "../pixel.h"
+#include "thresholding_abstract.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_dilation (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N]
+ );
+ /*!
+ requires
+ - in_image_type and out_image_type are image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ - in_img must contain a grayscale pixel type.
+ - both in_img and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - is_same_object(in_img,out_img) == false
+ - M % 2 == 1 (i.e. M must be odd)
+ - N % 2 == 1 (i.e. N must be odd)
+ - all pixels in in_img are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ - all pixels in structuring_element are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ ensures
+ - Does a binary dilation of in_img using the given structuring element and
+ stores the result in out_img.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_erosion (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N]
+ );
+ /*!
+ requires
+ - in_image_type and out_image_type are image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ - in_img must contain a grayscale pixel type.
+ - both in_img and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - is_same_object(in_img,out_img) == false
+ - M % 2 == 1 (i.e. M must be odd)
+ - N % 2 == 1 (i.e. N must be odd)
+ - all pixels in in_img are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ - all pixels in structuring_element are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ ensures
+ - Does a binary erosion of in_img using the given structuring element and
+ stores the result in out_img.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_open (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N],
+ const unsigned long iter = 1
+ );
+ /*!
+ requires
+ - in_image_type and out_image_type are image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ - in_img must contain a grayscale pixel type.
+ - both in_img and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - is_same_object(in_img,out_img) == false
+ - M % 2 == 1 (i.e. M must be odd)
+ - N % 2 == 1 (i.e. N must be odd)
+ - all pixels in in_img are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ - all pixels in structuring_element are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ ensures
+ - Does a binary open of in_img using the given structuring element and
+ stores the result in out_img. Specifically, iter iterations of binary
+ erosion are applied and then iter iterations of binary dilation.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ long M,
+ long N
+ >
+ void binary_close (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const unsigned char (&structuring_element)[M][N],
+ const unsigned long iter = 1
+ );
+ /*!
+ requires
+ - in_image_type and out_image_type are image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ - in_img must contain a grayscale pixel type.
+ - both in_img and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - is_same_object(in_img,out_img) == false
+ - M % 2 == 1 (i.e. M must be odd)
+ - N % 2 == 1 (i.e. N must be odd)
+ - all pixels in in_img are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ - all pixels in structuring_element are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ ensures
+ - Does a binary close of in_img using the given structuring element and
+ stores the result in out_img. Specifically, iter iterations of binary
+ dilation are applied and then iter iterations of binary erosion.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_intersection (
+ const in_image_type1& in_img1,
+ const in_image_type2& in_img2,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type1, in_image_type2, and out_image_type are image objects that
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - in_img1 and in_img2 must contain grayscale pixel types.
+ - in_img1, in_img2, and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel
+ (i.e. they must be binary images)
+ - in_img1.nc() == in_img2.nc()
+ - in_img1.nr() == in_img2.nr()
+ ensures
+ - #out_img == the binary intersection of in_img1 and in_img2. (i.e. All
+ the pixels that are set to on_pixel in both in_img1 and in_img2 will be set
+ to on_pixel in #out_img. All other pixels will be set to off_pixel)
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_union (
+ const in_image_type1& in_img1,
+ const in_image_type2& in_img2,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type1, in_image_type2, and out_image_type are image objects that
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - in_img1 and in_img2 must contain grayscale pixel types.
+ - in_img1, in_img2, and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel
+ (i.e. they must be binary images)
+ - in_img1.nc() == in_img2.nc()
+ - in_img1.nr() == in_img2.nr()
+ ensures
+ - #out_img == the binary union of in_img1 and in_img2. (i.e. All
+ the pixels that are set to on_pixel in in_img1 and/or in_img2 will be set
+ to on_pixel in #out_img. All other pixels will be set to off_pixel)
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type1,
+ typename in_image_type2,
+ typename out_image_type
+ >
+ void binary_difference (
+ const in_image_type1& in_img1,
+ const in_image_type2& in_img2,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type1, in_image_type2, and out_image_type are image objects that
+ implement the interface defined in dlib/image_processing/generic_image.h
+ - in_img1 and in_img2 must contain grayscale pixel types.
+ - in_img1, in_img2, and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel
+ (i.e. they must be binary images)
+ - in_img1.nc() == in_img2.nc()
+ - in_img1.nr() == in_img2.nr()
+ ensures
+ - #out_img == the binary difference of in_img1 and in_img2. (i.e. #out_img
+ will be a copy of in_img1 except that any pixels in in_img2 that are set to
+ on_pixel will be set to off_pixel)
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void binary_complement (
+ const in_image_type& in_img,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type and out_image_type are image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ - in_img must contain a grayscale pixel type.
+ - both in_img and out_img must contain pixels with no alpha channel.
+ (i.e. pixel_traits::has_alpha==false for their pixels)
+ - all pixels in in_img are set to either on_pixel or off_pixel
+ (i.e. it must be a binary image)
+ ensures
+ - #out_img == the binary complement of in_img. (i.e. For each pixel in
+ in_img, if it is on_pixel then it will be set to off_pixel in #out_img and
+ if it was off_pixel in in_img then it will be on_pixel in #out_img)
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+ template <
+ typename image_type
+ >
+ void binary_complement (
+ image_type& img
+ );
+ /*!
+ requires
+ - it must be valid to call binary_complement(img,img);
+ ensures
+ - calls binary_complement(img,img);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void skeleton(
+ image_type& img
+ );
+ /*!
+ requires
+ - image_type is an object that implement the interface defined in
+ dlib/image_processing/generic_image.h
+ - img must contain a grayscale pixel type.
+ - all pixels in img are set to either on_pixel or off_pixel.
+ (i.e. it must be a binary image)
+ ensures
+ - This function computes the skeletonization of img and stores the result in
+ #img. That is, given a binary image, we progressively thin the binary blobs
+ (composed of on_pixel values) until only a single pixel wide skeleton of the
+ original blobs remains.
+ - #img.nc() == img.nc()
+ - #img.nr() == img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/image_transforms/random_color_transform.h b/ml/dlib/dlib/image_transforms/random_color_transform.h
new file mode 100644
index 000000000..7433da1f7
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/random_color_transform.h
@@ -0,0 +1,157 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANDOM_cOLOR_TRANSFORM_Hh_
+#define DLIB_RANDOM_cOLOR_TRANSFORM_Hh_
+
+#include "random_color_transform_abstract.h"
+#include "../image_processing/generic_image.h"
+#include "../pixel.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class random_color_transform
+ {
+ public:
+
+ random_color_transform (
+ dlib::rand& rnd,
+ const double gamma_magnitude = 0.5,
+ const double color_magnitude = 0.2
+ )
+ {
+ // pick a random gamma correction factor.
+ double gamma = std::max(0.0, 1 + gamma_magnitude*(rnd.get_random_double()-0.5));
+
+ // pick a random color balancing scheme.
+ double red_scale = 1-rnd.get_random_double()*color_magnitude;
+ double green_scale = 1-rnd.get_random_double()*color_magnitude;
+ double blue_scale = 1-rnd.get_random_double()*color_magnitude;
+ const double m = 255*std::max(std::max(red_scale,green_scale),blue_scale);
+ red_scale /= m;
+ green_scale /= m;
+ blue_scale /= m;
+
+ // Now compute a lookup table for all the color channels. The table tells us
+ // what the transform does.
+ table.resize(256*3);
+ unsigned long i = 0;
+ for (int k = 0; k < 256; ++k)
+ {
+ double v = 255*std::pow(k*red_scale, gamma);
+ table[i++] = (unsigned char)(v + 0.5);
+ }
+ for (int k = 0; k < 256; ++k)
+ {
+ double v = 255*std::pow(k*green_scale, gamma);
+ table[i++] = (unsigned char)(v + 0.5);
+ }
+ for (int k = 0; k < 256; ++k)
+ {
+ double v = 255*std::pow(k*blue_scale, gamma);
+ table[i++] = (unsigned char)(v + 0.5);
+ }
+ }
+
+ rgb_pixel operator()(rgb_pixel p) const
+ {
+ p.red = table[(unsigned int)p.red];
+ p.green = table[(unsigned int)p.green+256];
+ p.blue = table[(unsigned int)p.blue+512];
+ return p;
+ }
+
+ private:
+ std::vector<unsigned char> table;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void disturb_colors (
+ image_type& img_,
+ dlib::rand& rnd,
+ const double gamma_magnitude = 0.5,
+ const double color_magnitude = 0.2
+ )
+ {
+ image_view<image_type> img(img_);
+ random_color_transform tform(rnd, gamma_magnitude, color_magnitude);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ rgb_pixel temp;
+ assign_pixel(temp, img[r][c]);
+ temp = tform(temp);
+ assign_pixel(img[r][c], temp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void apply_random_color_offset (
+ image_type& img_,
+ dlib::rand& rnd
+ )
+ {
+ // Make a random color offset. This tform matrix came from looking at the
+ // covariance matrix of RGB values in a bunch of images. In particular, if you
+ // multiply Gaussian random vectors by tform it will result in vectors with the
+ // same covariance matrix as the original RGB data. Also, this color transform is
+ // what is suggested by the paper:
+ // Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet
+ // classification with deep convolutional neural networks." Advances in neural
+ // information processing systems. 2012.
+ // Except that we used the square root of the eigenvalues (which I'm pretty sure is
+ // what the authors intended).
+ matrix<double,3,3> tform;
+ tform = -66.379, 25.094, 6.79698,
+ -68.0492, -0.302309, -13.9539,
+ -68.4907, -24.0199, 7.27653;
+ matrix<double,3,1> v;
+ v = rnd.get_random_gaussian(),rnd.get_random_gaussian(),rnd.get_random_gaussian();
+ v = round(tform*0.1*v);
+ const int roffset = v(0);
+ const int goffset = v(1);
+ const int boffset = v(2);
+
+ // Make up lookup tables that apply the color mapping so we don't have to put a
+ // bunch of complicated conditional branches in the loop below.
+ unsigned char rtable[256];
+ unsigned char gtable[256];
+ unsigned char btable[256];
+ for (int i = 0; i < 256; ++i)
+ {
+ rtable[i] = put_in_range(0, 255, i+roffset);
+ gtable[i] = put_in_range(0, 255, i+goffset);
+ btable[i] = put_in_range(0, 255, i+boffset);
+ }
+
+ // now transform the image.
+ image_view<image_type> img(img_);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ rgb_pixel temp;
+ assign_pixel(temp, img[r][c]);
+ temp.red = rtable[temp.red];
+ temp.green = gtable[temp.green];
+ temp.blue = btable[temp.blue];
+ assign_pixel(img[r][c], temp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOM_cOLOR_TRANSFORM_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h b/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h
new file mode 100644
index 000000000..5826e16a6
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h
@@ -0,0 +1,94 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_
+#ifdef DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_
+
+#include "../image_processing/generic_image.h"
+#include "../pixel.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class random_color_transform
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object generates a random color balancing and gamma correction
+ transform. It then allows you to apply that specific transform to as many
+ rgb_pixel objects as you like.
+ !*/
+
+ public:
+
+ random_color_transform (
+ dlib::rand& rnd,
+ const double gamma_magnitude = 0.5,
+ const double color_magnitude = 0.2
+ );
+ /*!
+ requires
+ - 0 <= gamma_magnitude
+ - 0 <= color_magnitude <= 1
+ ensures
+ - This constructor generates a random color transform which can be applied
+ by calling this object's operator() method.
+ - The color transform is a gamma correction and color rebalancing. If
+ gamma_magnitude == 0 and color_magnitude == 0 then the transform doesn't
+ change any colors at all. However, the larger these parameters the more
+ noticeable the resulting transform.
+ !*/
+
+ rgb_pixel operator()(
+ rgb_pixel p
+ ) const;
+ /*!
+ ensures
+ - returns the color transformed version of p.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void disturb_colors (
+ image_type& img,
+ dlib::rand& rnd,
+ const double gamma_magnitude = 0.5,
+ const double color_magnitude = 0.2
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - Applies a random color transform to the given image. This is done by
+ creating a random_color_transform with the given parameters and then
+ transforming each pixel in the image with the resulting transform.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ void apply_random_color_offset (
+ image_type& img,
+ dlib::rand& rnd
+ );
+ /*!
+ ensures
+ - Picks a random color offset vector and adds it to the given image. The offset
+ vector is selected using the method described in the paper:
+ Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet
+ classification with deep convolutional neural networks." Advances in neural
+ information processing systems. 2012.
+ In particular, we sample an RGB value from the typical distribution of RGB
+ values, assuming it has a Gaussian distribution, and then divide it by 10.
+ This sampled RGB vector is added to each pixel of img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/random_cropper.h b/ml/dlib/dlib/image_transforms/random_cropper.h
new file mode 100644
index 000000000..2c754b608
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/random_cropper.h
@@ -0,0 +1,361 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RaNDOM_CROPPER_H_
+#define DLIB_RaNDOM_CROPPER_H_
+
+#include "random_cropper_abstract.h"
+#include "../threads.h"
+#include <mutex>
+#include <vector>
+#include "interpolation.h"
+#include "../image_processing/full_object_detection.h"
+#include "../rand.h"
+
+namespace dlib
+{
+ class random_cropper
+ {
+ chip_dims dims = chip_dims(300,300);
+ bool randomly_flip = true;
+ double max_rotation_degrees = 30;
+ long min_object_length_long_dim = 75; // cropped object will be at least this many pixels along its longest edge.
+ long min_object_length_short_dim = 30; // cropped object will be at least this many pixels along its shortest edge.
+ double max_object_size = 0.7; // cropped object will be at most this fraction of the size of the image.
+ double background_crops_fraction = 0.5;
+ double translate_amount = 0.10;
+
+ std::mutex rnd_mutex;
+ dlib::rand rnd;
+ public:
+
+ void set_seed (
+ time_t seed
+ ) { rnd = dlib::rand(seed); }
+
+ double get_translate_amount (
+ ) const { return translate_amount; }
+
+ void set_translate_amount (
+ double value
+ )
+ {
+ DLIB_CASSERT(0 <= value);
+ translate_amount = value;
+ }
+
+ double get_background_crops_fraction (
+ ) const { return background_crops_fraction; }
+
+ void set_background_crops_fraction (
+ double value
+ )
+ {
+ DLIB_CASSERT(0 <= value && value <= 1);
+ background_crops_fraction = value;
+ }
+
+ const chip_dims& get_chip_dims(
+ ) const { return dims; }
+
+ void set_chip_dims (
+ const chip_dims& dims_
+ ) { dims = dims_; }
+
+ void set_chip_dims (
+ unsigned long rows,
+ unsigned long cols
+ ) { set_chip_dims(chip_dims(rows,cols)); }
+
+ bool get_randomly_flip (
+ ) const { return randomly_flip; }
+
+ void set_randomly_flip (
+ bool value
+ ) { randomly_flip = value; }
+
+ double get_max_rotation_degrees (
+ ) const { return max_rotation_degrees; }
+ void set_max_rotation_degrees (
+ double value
+ ) { max_rotation_degrees = std::abs(value); }
+
+ long get_min_object_length_long_dim (
+ ) const { return min_object_length_long_dim; }
+ long get_min_object_length_short_dim (
+ ) const { return min_object_length_short_dim; }
+
+ void set_min_object_size (
+ long long_dim,
+ long short_dim
+ )
+ {
+ DLIB_CASSERT(0 < short_dim && short_dim <= long_dim);
+ min_object_length_long_dim = long_dim;
+ min_object_length_short_dim = short_dim;
+ }
+
+ double get_max_object_size (
+ ) const { return max_object_size; }
+ void set_max_object_size (
+ double value
+ )
+ {
+ DLIB_CASSERT(0 < value);
+ max_object_size = value;
+ }
+
+ template <
+ typename array_type
+ >
+ void operator() (
+ size_t num_crops,
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ array_type& crops,
+ std::vector<std::vector<mmod_rect>>& crop_rects
+ )
+ {
+ DLIB_CASSERT(images.size() == rects.size());
+ crops.clear();
+ crop_rects.clear();
+ append(num_crops, images, rects, crops, crop_rects);
+ }
+
+ template <
+ typename array_type
+ >
+ void append (
+ size_t num_crops,
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ array_type& crops,
+ std::vector<std::vector<mmod_rect>>& crop_rects
+ )
+ {
+ DLIB_CASSERT(images.size() == rects.size());
+ DLIB_CASSERT(crops.size() == crop_rects.size());
+ auto original_size = crops.size();
+ crops.resize(crops.size()+num_crops);
+ crop_rects.resize(crop_rects.size()+num_crops);
+ parallel_for(original_size, original_size+num_crops, [&](long i) {
+ (*this)(images, rects, crops[i], crop_rects[i]);
+ });
+ }
+
+
+ template <
+ typename array_type,
+ typename image_type
+ >
+ void operator() (
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ image_type& crop,
+ std::vector<mmod_rect>& crop_rects
+ )
+ {
+ DLIB_CASSERT(images.size() == rects.size());
+ size_t idx;
+ { std::lock_guard<std::mutex> lock(rnd_mutex);
+ idx = rnd.get_integer(images.size());
+ }
+ (*this)(images[idx], rects[idx], crop, crop_rects);
+ }
+
+ template <
+ typename image_type1
+ >
+ image_type1 operator() (
+ const image_type1& img
+ )
+ {
+ image_type1 crop;
+ std::vector<mmod_rect> junk1, junk2;
+ (*this)(img, junk1, crop, junk2);
+ return crop;
+ }
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void operator() (
+ const image_type1& img,
+ const std::vector<mmod_rect>& rects,
+ image_type2& crop,
+ std::vector<mmod_rect>& crop_rects
+ )
+ {
+ DLIB_CASSERT(num_rows(img)*num_columns(img) != 0);
+ chip_details crop_plan;
+ bool should_flip_crop;
+ make_crop_plan(img, rects, crop_plan, should_flip_crop);
+
+ extract_image_chip(img, crop_plan, crop);
+ const rectangle_transform tform = get_mapping_to_chip(crop_plan);
+
+ // copy rects into crop_rects and set ones that are outside the crop to ignore or
+ // drop entirely as appropriate.
+ crop_rects.clear();
+ for (auto rect : rects)
+ {
+ // map to crop
+ rect.rect = tform(rect.rect);
+
+ // if the rect is at least partly in the crop
+ if (get_rect(crop).intersect(rect.rect).area() != 0)
+ {
+ // set to ignore if not totally in the crop or if too small.
+ if (!get_rect(crop).contains(rect.rect) ||
+ ((long)rect.rect.height() < min_object_length_long_dim && (long)rect.rect.width() < min_object_length_long_dim) ||
+ ((long)rect.rect.height() < min_object_length_short_dim || (long)rect.rect.width() < min_object_length_short_dim))
+ {
+ rect.ignore = true;
+ }
+
+ crop_rects.push_back(rect);
+ }
+ }
+
+ // Also randomly flip the image
+ if (should_flip_crop)
+ {
+ image_type2 temp;
+ flip_image_left_right(crop, temp);
+ swap(crop,temp);
+ for (auto&& rect : crop_rects)
+ rect.rect = impl::flip_rect_left_right(rect.rect, get_rect(crop));
+ }
+ }
+
+ private:
+
+ template <typename image_type1>
+ void make_crop_plan (
+ const image_type1& img,
+ const std::vector<mmod_rect>& rects,
+ chip_details& crop_plan,
+ bool& should_flip_crop
+ )
+ {
+ std::lock_guard<std::mutex> lock(rnd_mutex);
+ rectangle crop_rect;
+ if (has_non_ignored_box(rects) && rnd.get_random_double() >= background_crops_fraction)
+ {
+ auto rect = rects[randomly_pick_rect(rects)].rect;
+
+ // perturb the location of the crop by a small fraction of the object's size.
+ const point rand_translate = dpoint(rnd.get_double_in_range(-translate_amount,translate_amount)*std::max(rect.height(),rect.width()),
+ rnd.get_double_in_range(-translate_amount,translate_amount)*std::max(rect.height(),rect.width()));
+
+ // We are going to grow rect into the cropping rect. First, we grow it a
+ // little so that it has the desired minimum border around it.
+ drectangle drect = centered_drect(center(rect)+rand_translate, rect.width()/max_object_size, rect.height()/max_object_size);
+
+ // Now make rect have the same aspect ratio as dims so that there won't be
+ // any funny stretching when we crop it. We do this by growing it along
+ // whichever dimension is too short.
+ const double target_aspect = dims.cols/(double)dims.rows;
+ if (drect.width()/drect.height() < target_aspect)
+ drect = centered_drect(drect, target_aspect*drect.height(), drect.height());
+ else
+ drect = centered_drect(drect, drect.width(), drect.width()/target_aspect);
+
+ // Now perturb the scale of the crop. We do this by shrinking it, but not
+ // so much that it gets smaller than the min object sizes require.
+ double current_width = dims.cols*rect.width()/drect.width();
+ double current_height = dims.rows*rect.height()/drect.height();
+
+ // never make any dimension smaller than the short dim.
+ double min_scale1 = std::max(min_object_length_short_dim/current_width, min_object_length_short_dim/current_height);
+ // at least one dimension needs to be longer than the long dim.
+ double min_scale2 = std::min(min_object_length_long_dim/current_width, min_object_length_long_dim/current_height);
+ double min_scale = std::max(min_scale1, min_scale2);
+
+ const double rand_scale_perturb = 1.0/rnd.get_double_in_range(min_scale, 1);
+ crop_rect = centered_drect(drect, drect.width()*rand_scale_perturb, drect.height()*rand_scale_perturb);
+
+ }
+ else
+ {
+ crop_rect = make_random_cropping_rect(img);
+ }
+ should_flip_crop = randomly_flip && rnd.get_random_double() > 0.5;
+ const double angle = rnd.get_double_in_range(-max_rotation_degrees, max_rotation_degrees)*pi/180;
+ crop_plan = chip_details(crop_rect, dims, angle);
+ }
+
+ bool has_non_ignored_box (
+ const std::vector<mmod_rect>& rects
+ ) const
+ {
+ for (auto&& b : rects)
+ {
+ if (!b.ignore)
+ return true;
+ }
+ return false;
+ }
+
+ size_t randomly_pick_rect (
+ const std::vector<mmod_rect>& rects
+ )
+ {
+ DLIB_CASSERT(has_non_ignored_box(rects));
+ size_t idx = rnd.get_integer(rects.size());
+ while(rects[idx].ignore)
+ idx = rnd.get_integer(rects.size());
+ return idx;
+ }
+
+ template <typename image_type>
+ rectangle make_random_cropping_rect(
+ const image_type& img_
+ )
+ {
+ const_image_view<image_type> img(img_);
+ // Figure out what rectangle we want to crop from the image. We are going to
+ // crop out an image of size this->dims, so we pick a random scale factor that
+ // lets this random box be either as big as it can be while still fitting in
+ // the image or as small as a 3x zoomed in box randomly somewhere in the image.
+ double mins = 1.0/3.0, maxs = std::min(img.nr()/(double)dims.rows, img.nc()/(double)dims.cols);
+ mins = std::min(mins, maxs);
+ auto scale = rnd.get_double_in_range(mins, maxs);
+ rectangle rect(scale*dims.cols, scale*dims.rows);
+ // randomly shift the box around
+ point offset(rnd.get_integer(1+img.nc()-rect.width()),
+ rnd.get_integer(1+img.nr()-rect.height()));
+ return move_rect(rect, offset);
+ }
+
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<< (
+ std::ostream& out,
+ const random_cropper& item
+ )
+ {
+ using std::endl;
+ out << "random_cropper details: " << endl;
+ out << " chip_dims.rows: " << item.get_chip_dims().rows << endl;
+ out << " chip_dims.cols: " << item.get_chip_dims().cols << endl;
+ out << " randomly_flip: " << std::boolalpha << item.get_randomly_flip() << endl;
+ out << " max_rotation_degrees: " << item.get_max_rotation_degrees() << endl;
+ out << " min_object_length_long_dim: " << item.get_min_object_length_long_dim() << endl;
+ out << " min_object_length_short_dim: " << item.get_min_object_length_short_dim() << endl;
+ out << " max_object_size: " << item.get_max_object_size() << endl;
+ out << " background_crops_fraction: " << item.get_background_crops_fraction() << endl;
+ out << " translate_amount: " << item.get_translate_amount() << endl;
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RaNDOM_CROPPER_H_
+
diff --git a/ml/dlib/dlib/image_transforms/random_cropper_abstract.h b/ml/dlib/dlib/image_transforms/random_cropper_abstract.h
new file mode 100644
index 000000000..7603a1c47
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/random_cropper_abstract.h
@@ -0,0 +1,346 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RaNDOM_CROPPER_ABSTRACT_H_
+#ifdef DLIB_RaNDOM_CROPPER_ABSTRACT_H_
+
+#include "../threads.h"
+#include <mutex>
+#include <vector>
+#include "interpolation.h"
+#include "../image_processing/full_object_detection.h"
+#include "../rand.h"
+
+namespace dlib
+{
+ class random_cropper
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for extracting random crops of objects from a set of
+ images. The crops are randomly jittered in scale, translation, and
+ rotation but more or less centered on objects specified by mmod_rect
+ objects.
+
+ THREAD SAFETY
+ It is safe for multiple threads to make concurrent calls to this object's
+ operator() methods.
+ !*/
+
+ public:
+
+ random_cropper (
+ );
+ /*!
+ ensures
+ - #get_chip_dims() == chip_dims(300,300)
+ - #get_randomly_flip() == true
+ - #get_max_rotation_degrees() == 30
+ - #get_min_object_length_long_dim() == 70
+ - #get_min_object_length_short_dim() == 30
+ - #get_max_object_size() == 0.7
+ - #get_background_crops_fraction() == 0.5
+ - #get_translate_amount() == 0.1
+ !*/
+
+ void set_seed (
+ time_t seed
+ );
+ /*!
+ ensures
+ - Seeds the internal random number generator with the given seed.
+ !*/
+
+ double get_translate_amount (
+ ) const;
+ /*!
+ ensures
+ - When a box is cropped out, it will be randomly translated prior to
+ cropping by #get_translate_amount()*(the box's height) up or down and
+ #get_translate_amount()*(the box's width) left or right.
+ !*/
+
+ void set_translate_amount (
+ double value
+ );
+ /*!
+ requires
+ - value >= 0
+ ensures
+ - #get_translate_amount() == value
+ !*/
+
+ double get_background_crops_fraction (
+ ) const;
+ /*!
+ ensures
+ - When making random crops, get_background_crops_fraction() fraction of
+ them will be from random background rather than being centered on some
+ object in the dataset.
+ !*/
+
+ void set_background_crops_fraction (
+ double value
+ );
+ /*!
+ requires
+ - 0 <= value <= 1
+ ensures
+ - #get_background_crops_fraction() == value
+ !*/
+
+ const chip_dims& get_chip_dims(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensions of image chips produced by this object.
+ !*/
+
+ void set_chip_dims (
+ const chip_dims& dims
+ );
+ /*!
+ ensures
+ - #get_chip_dims() == dims
+ !*/
+
+ void set_chip_dims (
+ unsigned long rows,
+ unsigned long cols
+ );
+ /*!
+ ensures
+ - #get_chip_dims() == chip_dims(rows,cols)
+ !*/
+
+ bool get_randomly_flip (
+ ) const;
+ /*!
+ ensures
+ - if this object will randomly mirror chips left to right.
+ !*/
+
+ void set_randomly_flip (
+ bool value
+ );
+ /*!
+ ensures
+ - #get_randomly_flip() == value
+ !*/
+
+ double get_max_rotation_degrees (
+ ) const;
+ /*!
+ ensures
+ - When extracting an image chip, this object will pick a random rotation
+ in the range [-get_max_rotation_degrees(), get_max_rotation_degrees()]
+ and rotate the chip by that amount.
+ !*/
+
+ void set_max_rotation_degrees (
+ double value
+ );
+ /*!
+ ensures
+ - #get_max_rotation_degrees() == std::abs(value)
+ !*/
+
+ long get_min_object_length_long_dim (
+ ) const;
+ /*!
+ ensures
+ - When a chip is extracted around an object, the chip will be sized so that
+ the longest edge of the object (i.e. either its height or width,
+ whichever is longer) is at least #get_min_object_length_long_dim() pixels
+ in length. When we say "object" here we are referring specifically to
+ the rectangle in the mmod_rect output by the cropper.
+ !*/
+
+ long get_min_object_length_short_dim (
+ ) const;
+ /*!
+ ensures
+ - When a chip is extracted around an object, the chip will be sized so that
+ the shortest edge of the object (i.e. either its height or width,
+ whichever is shorter) is at least #get_min_object_length_short_dim()
+ pixels in length. When we say "object" here we are referring
+ specifically to the rectangle in the mmod_rect output by the cropper.
+ !*/
+
+ void set_min_object_size (
+ long long_dim,
+ long short_dim
+ );
+ /*!
+ requires
+ - 0 < short_dim <= long_dim
+ ensures
+ - #get_min_object_length_short_dim() == short_dim
+ - #get_min_object_length_long_dim() == long_dim
+ !*/
+
+ double get_max_object_size (
+ ) const;
+ /*!
+ ensures
+ - When a chip is extracted around an object, the chip will be sized so that
+ both the object's height and width are at most get_max_object_size() *
+ the chip's height and width, respectively. E.g. if the chip is 640x480
+ pixels in size then the object will be at most 480*get_max_object_size()
+ pixels tall and 640*get_max_object_size() pixels wide.
+ !*/
+
+ void set_max_object_size (
+ double value
+ );
+ /*!
+ requires
+ - 0 < value
+ ensures
+ - #get_max_object_size() == value
+ !*/
+
+ template <
+ typename array_type
+ >
+ void append (
+ size_t num_crops,
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ array_type& crops,
+ std::vector<std::vector<mmod_rect>>& crop_rects
+ );
+ /*!
+ requires
+ - images.size() == rects.size()
+ - crops.size() == crop_rects.size()
+ - for all valid i:
+ - images[i].size() != 0
+ - array_type is a type with an interface compatible with dlib::array or
+ std::vector and it must in turn contain image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ ensures
+ - Randomly extracts num_crops chips from images and appends them to the end
+ of crops. We also copy the object metadata for each extracted crop and
+ store it into #crop_rects. In particular, calling this function is the
+ same as making multiple calls to the version of operator() below that
+ outputs a single crop, except that append() will use multiple CPU cores
+ to do the processing and is therefore faster.
+ - #crops.size() == crops.size()+num_crops
+ - #crop_rects.size() == crop_rects.size()+num_crops
+ !*/
+
+ template <
+ typename array_type
+ >
+ void operator() (
+ size_t num_crops,
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ array_type& crops,
+ std::vector<std::vector<mmod_rect>>& crop_rects
+ );
+ /*!
+ requires
+ - images.size() == rects.size()
+ - for all valid i:
+ - images[i].size() != 0
+ - array_type is a type with an interface compatible with dlib::array or
+ std::vector and it must in turn contain image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ ensures
+ - Randomly extracts num_crops chips from images. We also copy the object
+ metadata for each extracted crop and store it into #crop_rects. In
+ particular, calling this function is the same as invoking the version of
+ operator() below multiple times, except that this version of operator()
+ will use multiple CPU cores to do the processing and is therefore faster.
+ - #crops.size() == num_crops
+ - #crop_rects.size() == num_crops
+ !*/
+
+ template <
+ typename array_type,
+ typename image_type
+ >
+ void operator() (
+ const array_type& images,
+ const std::vector<std::vector<mmod_rect>>& rects,
+ image_type& crop,
+ std::vector<mmod_rect>& crop_rects
+ );
+ /*!
+ requires
+ - images.size() == rects.size()
+ - for all valid i:
+ - images[i].size() != 0
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - array_type is a type with an interface compatible with dlib::array or
+ std::vector and it must in turn contain image objects that implement the
+ interface defined in dlib/image_processing/generic_image.h
+ ensures
+ - Selects a random image and creates a random crop from it. Specifically,
+ we pick a random index IDX < images.size() and then execute
+ (*this)(images[IDX],rects[IDX],crop,crop_rects)
+ !*/
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void operator() (
+ const image_type1& img,
+ const std::vector<mmod_rect>& rects,
+ image_type2& crop,
+ std::vector<mmod_rect>& crop_rects
+ );
+ /*!
+ requires
+ - img.size() != 0
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - Extracts a random crop from img and copies over the mmod_rect objects in
+ rects to #crop_rects if they are contained inside the crop. Moreover,
+ rectangles are marked as ignore if they aren't completely contained
+ inside the crop.
+ - #crop_rects.size() <= rects.size()
+ !*/
+
+ template <
+ typename image_type1
+ >
+ image_type1 operator() (
+ const image_type1& img
+ );
+ /*!
+ requires
+ - img.size() != 0
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ ensures
+ - This function simply calls (*this)(img, junk1, crop, junk2) and returns
+ crop. Therefore it is simply a convenience function for extracting a
+ random background patch.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const random_cropper& item
+ );
+ /*!
+ ensures
+ - Prints the state of all the parameters of item to out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RaNDOM_CROPPER_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/image_transforms/segment_image.h b/ml/dlib/dlib/image_transforms/segment_image.h
new file mode 100644
index 000000000..3b57e4801
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/segment_image.h
@@ -0,0 +1,730 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEGMENT_ImAGE_Hh_
+#define DLIB_SEGMENT_ImAGE_Hh_
+
+#include "segment_image_abstract.h"
+#include "../algs.h"
+#include <vector>
+#include "../geometry.h"
+#include "../disjoint_subsets.h"
+#include "../set.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ inline T edge_diff_uint(
+ const T& a,
+ const T& b
+ )
+ {
+ if (a > b)
+ return a - b;
+ else
+ return b - a;
+ }
+
+ // ----------------------------------------
+
+ template <typename T, typename enabled = void>
+ struct edge_diff_funct
+ {
+ typedef double diff_type;
+
+ template <typename pixel_type>
+ double operator()(
+ const pixel_type& a,
+ const pixel_type& b
+ ) const
+ {
+ return length(pixel_to_vector<double>(a) - pixel_to_vector<double>(b));
+ }
+ };
+
+ template <>
+ struct edge_diff_funct<uint8,void>
+ {
+ typedef uint8 diff_type;
+ uint8 operator()( const uint8& a, const uint8& b) const { return edge_diff_uint(a,b); }
+ };
+
+ template <>
+ struct edge_diff_funct<uint16,void>
+ {
+ typedef uint16 diff_type;
+ uint16 operator()( const uint16& a, const uint16& b) const { return edge_diff_uint(a,b); }
+ };
+
+ template <>
+ struct edge_diff_funct<uint32,void>
+ {
+ typedef uint32 diff_type;
+ uint32 operator()( const uint32& a, const uint32& b) const { return edge_diff_uint(a,b); }
+ };
+
+ template <>
+ struct edge_diff_funct<double,void>
+ {
+ typedef double diff_type;
+ double operator()( const double& a, const double& b) const { return std::abs(a-b); }
+ };
+
+ template <typename T>
+ struct edge_diff_funct<T, typename enable_if<is_matrix<T> >::type>
+ {
+ typedef double diff_type;
+ double operator()(
+ const T& a,
+ const T& b
+ ) const
+ {
+ return length(a-b);
+ }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct graph_image_segmentation_data_T
+ {
+ graph_image_segmentation_data_T() : component_size(1), internal_diff(0) {}
+ unsigned long component_size;
+ T internal_diff;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct segment_image_edge_data_T
+ {
+ segment_image_edge_data_T (){}
+
+ segment_image_edge_data_T (
+ const rectangle& rect,
+ const point& p1,
+ const point& p2,
+ const T& diff_
+ ) :
+ idx1(p1.y()*rect.width() + p1.x()),
+ idx2(p2.y()*rect.width() + p2.x()),
+ diff(diff_)
+ {}
+
+ bool operator<(const segment_image_edge_data_T& item) const
+ { return diff < item.diff; }
+
+ unsigned long idx1;
+ unsigned long idx2;
+ T diff;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename image_view_type>
+ struct uint8_or_uint16_pixels
+ {
+ typedef typename image_view_type::pixel_type pixel_type;
+ const static bool value = is_same_type<pixel_type,uint8>::value ||
+ is_same_type<pixel_type,uint16>::value;
+ };
+
+ // This is an overload of get_pixel_edges() that is optimized to segment images
+ // with 8bit or 16bit pixels very quickly. We do this by using a radix sort
+ // instead of quicksort.
+ template <typename in_image_type, typename T>
+ typename enable_if<uint8_or_uint16_pixels<in_image_type> >::type
+ get_pixel_edges (
+ const in_image_type& in_img,
+ std::vector<segment_image_edge_data_T<T> >& sorted_edges
+ )
+ {
+ typedef typename in_image_type::pixel_type ptype;
+ typedef T diff_type;
+ std::vector<unsigned long> counts(std::numeric_limits<ptype>::max()+1, 0);
+
+ edge_diff_funct<ptype> edge_diff;
+
+ border_enumerator be(get_rect(in_img), 1);
+ // we are going to do a radix sort on the edge weights. So the first step
+ // is to accumulate them into count.
+ const rectangle area = get_rect(in_img);
+ while (be.move_next())
+ {
+ const long r = be.element().y();
+ const long c = be.element().x();
+ const ptype pix = in_img[r][c];
+ if (area.contains(c-1,r)) counts[edge_diff(pix, in_img[r ][c-1])] += 1;
+ if (area.contains(c+1,r)) counts[edge_diff(pix, in_img[r ][c+1])] += 1;
+ if (area.contains(c ,r-1)) counts[edge_diff(pix, in_img[r-1][c ])] += 1;
+ if (area.contains(c ,r+1)) counts[edge_diff(pix, in_img[r+1][c ])] += 1;
+ }
+ for (long r = 1; r+1 < in_img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < in_img.nc(); ++c)
+ {
+ const ptype pix = in_img[r][c];
+ counts[edge_diff(pix, in_img[r-1][c+1])] += 1;
+ counts[edge_diff(pix, in_img[r ][c+1])] += 1;
+ counts[edge_diff(pix, in_img[r+1][c ])] += 1;
+ counts[edge_diff(pix, in_img[r+1][c+1])] += 1;
+ }
+ }
+
+ const unsigned long num_edges = shrink_rect(area,1).area()*4 + in_img.nr()*2*3 - 4 + (in_img.nc()-2)*2*3;
+ typedef segment_image_edge_data_T<T> segment_image_edge_data;
+ sorted_edges.resize(num_edges);
+
+ // integrate counts. The idea is to have sorted_edges[counts[i]] be the location that edges
+ // with an edge_diff of i go. So counts[0] == 0, counts[1] == number of 0 edge diff edges, etc.
+ unsigned long prev = counts[0];
+ for (unsigned long i = 1; i < counts.size(); ++i)
+ {
+ const unsigned long temp = counts[i];
+ counts[i] += counts[i-1];
+ counts[i-1] -= prev;
+ prev = temp;
+ }
+ counts[counts.size()-1] -= prev;
+
+
+ // now build a sorted list of all the edges
+ be.reset();
+ while(be.move_next())
+ {
+ const point p = be.element();
+ const long r = p.y();
+ const long c = p.x();
+ const ptype pix = in_img[r][c];
+ if (area.contains(c-1,r))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r ][c-1]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c-1,r),diff);
+ }
+
+ if (area.contains(c+1,r))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r ][c+1]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r),diff);
+ }
+
+ if (area.contains(c ,r-1))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r-1][c ]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r-1),diff);
+ }
+
+ if (area.contains(c ,r+1))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r+1][c ]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r+1),diff);
+ }
+ }
+ // same thing as the above loop but now we do it on the interior of the image and therefore
+ // don't have to include the boundary checking if statements used above.
+ for (long r = 1; r+1 < in_img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < in_img.nc(); ++c)
+ {
+ const point p(c,r);
+ const ptype pix = in_img[r][c];
+ diff_type diff;
+
+ diff = edge_diff(pix, in_img[r ][c+1]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r),diff);
+ diff = edge_diff(pix, in_img[r-1][c+1]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r-1),diff);
+ diff = edge_diff(pix, in_img[r+1][c+1]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r+1),diff);
+ diff = edge_diff(pix, in_img[r+1][c ]);
+ sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r+1),diff);
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ // This is the general purpose version of get_pixel_edges(). It handles all pixel types.
+ template <typename in_image_type, typename T>
+ typename disable_if<uint8_or_uint16_pixels<in_image_type> >::type
+ get_pixel_edges (
+ const in_image_type& in_img,
+ std::vector<segment_image_edge_data_T<T> >& sorted_edges
+ )
+ {
+ const rectangle area = get_rect(in_img);
+ sorted_edges.reserve(area.area()*4);
+
+ typedef typename in_image_type::pixel_type ptype;
+ edge_diff_funct<ptype> edge_diff;
+ typedef T diff_type;
+ typedef segment_image_edge_data_T<T> segment_image_edge_data;
+
+ border_enumerator be(get_rect(in_img), 1);
+
+ // now build a sorted list of all the edges
+ be.reset();
+ while(be.move_next())
+ {
+ const point p = be.element();
+ const long r = p.y();
+ const long c = p.x();
+ const ptype& pix = in_img[r][c];
+ if (area.contains(c-1,r))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r ][c-1]);
+ sorted_edges.push_back(segment_image_edge_data(area,p,point(c-1,r),diff));
+ }
+
+ if (area.contains(c+1,r))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r ][c+1]);
+ sorted_edges.push_back(segment_image_edge_data(area,p,point(c+1,r),diff));
+ }
+
+ if (area.contains(c ,r-1))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r-1][c ]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r-1),diff));
+ }
+ if (area.contains(c ,r+1))
+ {
+ const diff_type diff = edge_diff(pix, in_img[r+1][c ]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r+1),diff));
+ }
+ }
+ // same thing as the above loop but now we do it on the interior of the image and therefore
+ // don't have to include the boundary checking if statements used above.
+ for (long r = 1; r+1 < in_img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < in_img.nc(); ++c)
+ {
+ const point p(c,r);
+ const ptype& pix = in_img[r][c];
+ diff_type diff;
+
+ diff = edge_diff(pix, in_img[r ][c+1]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r),diff));
+ diff = edge_diff(pix, in_img[r+1][c+1]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r+1),diff));
+ diff = edge_diff(pix, in_img[r+1][c ]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r+1),diff));
+ diff = edge_diff(pix, in_img[r-1][c+1]);
+ sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r-1),diff));
+ }
+ }
+
+ std::sort(sorted_edges.begin(), sorted_edges.end());
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ } // end of namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void segment_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const double k = 200,
+ const unsigned long min_size = 10
+ )
+ {
+ using namespace dlib::impl;
+ typedef typename image_traits<in_image_type>::pixel_type ptype;
+ typedef typename edge_diff_funct<ptype>::diff_type diff_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\t void segment_image()"
+ << "\n\t The input images can't be the same object."
+ );
+
+ COMPILE_TIME_ASSERT(is_unsigned_type<typename image_traits<out_image_type>::pixel_type>::value);
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ out_img.set_size(in_img.nr(), in_img.nc());
+ // don't bother doing anything if the image is too small
+ if (in_img.nr() < 2 || in_img.nc() < 2)
+ {
+ assign_all_pixels(out_img,0);
+ return;
+ }
+
+ disjoint_subsets sets;
+ sets.set_size(in_img.size());
+
+ std::vector<segment_image_edge_data_T<diff_type> > sorted_edges;
+ get_pixel_edges(in_img, sorted_edges);
+
+ std::vector<graph_image_segmentation_data_T<diff_type> > data(in_img.size());
+
+ // now start connecting blobs together to make a minimum spanning tree.
+ for (unsigned long i = 0; i < sorted_edges.size(); ++i)
+ {
+ const unsigned long idx1 = sorted_edges[i].idx1;
+ const unsigned long idx2 = sorted_edges[i].idx2;
+
+ unsigned long set1 = sets.find_set(idx1);
+ unsigned long set2 = sets.find_set(idx2);
+ if (set1 != set2)
+ {
+ const diff_type diff = sorted_edges[i].diff;
+ const diff_type tau1 = static_cast<diff_type>(k/data[set1].component_size);
+ const diff_type tau2 = static_cast<diff_type>(k/data[set2].component_size);
+
+ const diff_type mint = std::min(data[set1].internal_diff + tau1,
+ data[set2].internal_diff + tau2);
+ if (diff <= mint)
+ {
+ const unsigned long new_set = sets.merge_sets(set1, set2);
+ data[new_set].component_size = data[set1].component_size + data[set2].component_size;
+ data[new_set].internal_diff = diff;
+ }
+ }
+ }
+
+ // now merge any really small blobs
+ if (min_size != 0)
+ {
+ for (unsigned long i = 0; i < sorted_edges.size(); ++i)
+ {
+ const unsigned long idx1 = sorted_edges[i].idx1;
+ const unsigned long idx2 = sorted_edges[i].idx2;
+
+ unsigned long set1 = sets.find_set(idx1);
+ unsigned long set2 = sets.find_set(idx2);
+ if (set1 != set2 && (data[set1].component_size < min_size || data[set2].component_size < min_size))
+ {
+ const unsigned long new_set = sets.merge_sets(set1, set2);
+ data[new_set].component_size = data[set1].component_size + data[set2].component_size;
+ //data[new_set].internal_diff = sorted_edges[i].diff;
+ }
+ }
+ }
+
+ unsigned long idx = 0;
+ for (long r = 0; r < out_img.nr(); ++r)
+ {
+ for (long c = 0; c < out_img.nc(); ++c)
+ {
+ out_img[r][c] = sets.find_set(idx++);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Candidate object location generation code.
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct edge_data
+ {
+ double edge_diff;
+ unsigned long set1;
+ unsigned long set2;
+ bool operator<(const edge_data& item) const
+ {
+ return edge_diff < item.edge_diff;
+ }
+ };
+
+ template <
+ typename in_image_type,
+ typename diff_type
+ >
+ void find_basic_candidate_object_locations (
+ const in_image_type& in_img,
+ const std::vector<dlib::impl::segment_image_edge_data_T<diff_type> >& sorted_edges,
+ std::vector<rectangle>& out_rects,
+ std::vector<edge_data>& edges,
+ const double k,
+ const unsigned long min_size
+ )
+ {
+ using namespace dlib::impl;
+
+ std::vector<dlib::impl::segment_image_edge_data_T<diff_type> > rejected_edges;
+ rejected_edges.reserve(sorted_edges.size());
+
+ out_rects.clear();
+ edges.clear();
+
+ // don't bother doing anything if the image is too small
+ if (in_img.nr() < 2 || in_img.nc() < 2)
+ {
+ return;
+ }
+
+ disjoint_subsets sets;
+ sets.set_size(in_img.size());
+
+
+ std::vector<graph_image_segmentation_data_T<diff_type> > data(in_img.size());
+
+
+
+ std::pair<unsigned long,unsigned long> last_blob_edge(std::numeric_limits<unsigned long>::max(),
+ std::numeric_limits<unsigned long>::max());;
+ // now start connecting blobs together to make a minimum spanning tree.
+ for (unsigned long i = 0; i < sorted_edges.size(); ++i)
+ {
+ const unsigned long idx1 = sorted_edges[i].idx1;
+ const unsigned long idx2 = sorted_edges[i].idx2;
+
+ unsigned long set1 = sets.find_set(idx1);
+ unsigned long set2 = sets.find_set(idx2);
+ if (set1 != set2)
+ {
+ const diff_type diff = sorted_edges[i].diff;
+ const diff_type tau1 = static_cast<diff_type>(k/data[set1].component_size);
+ const diff_type tau2 = static_cast<diff_type>(k/data[set2].component_size);
+
+ const diff_type mint = std::min(data[set1].internal_diff + tau1,
+ data[set2].internal_diff + tau2);
+ if (diff <= mint)
+ {
+ const unsigned long new_set = sets.merge_sets(set1, set2);
+ data[new_set].component_size = data[set1].component_size + data[set2].component_size;
+ data[new_set].internal_diff = diff;
+ }
+ else
+ {
+ // Don't bother keeping multiple edges from the same pair of blobs, we
+ // only need one for what we will do later.
+ if (std::make_pair(set1,set2) != last_blob_edge)
+ {
+ segment_image_edge_data_T<diff_type> temp = sorted_edges[i];
+ temp.idx1 = set1;
+ temp.idx2 = set2;
+ rejected_edges.push_back(temp);
+ last_blob_edge = std::make_pair(set1,set2);
+ }
+ }
+ }
+ }
+
+
+ // merge small blobs
+ for (unsigned long i = 0; i < rejected_edges.size(); ++i)
+ {
+ const unsigned long idx1 = rejected_edges[i].idx1;
+ const unsigned long idx2 = rejected_edges[i].idx2;
+
+ unsigned long set1 = sets.find_set(idx1);
+ unsigned long set2 = sets.find_set(idx2);
+ rejected_edges[i].idx1 = set1;
+ rejected_edges[i].idx2 = set2;
+ if (set1 != set2 && (data[set1].component_size < min_size || data[set2].component_size < min_size))
+ {
+ const unsigned long new_set = sets.merge_sets(set1, set2);
+ data[new_set].component_size = data[set1].component_size + data[set2].component_size;
+ data[new_set].internal_diff = rejected_edges[i].diff;
+ }
+ }
+
+ // find bounding boxes of each blob
+ std::map<unsigned long, rectangle> boxes;
+ std::map<unsigned long, unsigned long> box_id_map;
+ unsigned long idx = 0;
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ const unsigned long id = sets.find_set(idx++);
+ // Accumulate the current point into its box and if it is the first point
+ // in the box then also record the id number for this box.
+ if ((boxes[id] += point(c,r)).area() == 1)
+ box_id_map[id] = boxes.size()-1;
+ }
+ }
+
+ // copy boxes into out_rects
+ out_rects.resize(boxes.size());
+ for (std::map<unsigned long,rectangle>::iterator i = boxes.begin(); i != boxes.end(); ++i)
+ {
+ out_rects[box_id_map[i->first]] = i->second;
+ }
+
+ // Now find the edges between the boxes
+ typedef dlib::memory_manager<char>::kernel_2c mm_type;
+ dlib::set<std::pair<unsigned long, unsigned long>, mm_type>::kernel_1a neighbors_final;
+ for (unsigned long i = 0; i < rejected_edges.size(); ++i)
+ {
+ const unsigned long idx1 = rejected_edges[i].idx1;
+ const unsigned long idx2 = rejected_edges[i].idx2;
+
+ unsigned long set1 = sets.find_set(idx1);
+ unsigned long set2 = sets.find_set(idx2);
+ if (set1 != set2)
+ {
+ std::pair<unsigned long, unsigned long> p = std::make_pair(set1,set2);
+ if (!neighbors_final.is_member(p))
+ {
+ neighbors_final.add(p);
+
+ edge_data temp;
+ const diff_type mint = std::min(data[set1].internal_diff ,
+ data[set2].internal_diff );
+ temp.edge_diff = rejected_edges[i].diff - mint;
+ temp.set1 = box_id_map[set1];
+ temp.set2 = box_id_map[set2];
+ edges.push_back(temp);
+ }
+ }
+ }
+
+ std::sort(edges.begin(), edges.end());
+ }
+ } // end namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename alloc>
+ void remove_duplicates (
+ std::vector<rectangle,alloc>& rects
+ )
+ {
+ std::sort(rects.begin(), rects.end(), std::less<rectangle>());
+ unsigned long num_unique = 1;
+ for (unsigned long i = 1; i < rects.size(); ++i)
+ {
+ if (rects[i] != rects[i-1])
+ {
+ rects[num_unique++] = rects[i];
+ }
+ }
+ if (rects.size() != 0)
+ rects.resize(num_unique);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename EXP
+ >
+ void find_candidate_object_locations (
+ const in_image_type& in_img_,
+ std::vector<rectangle>& rects,
+ const matrix_exp<EXP>& kvals,
+ const unsigned long min_size = 20,
+ const unsigned long max_merging_iterations = 50
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(kvals) && kvals.size() > 0,
+ "\t void find_candidate_object_locations()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t is_vector(kvals): " << is_vector(kvals)
+ << "\n\t kvals.size(): " << kvals.size()
+ );
+
+ typedef dlib::memory_manager<char>::kernel_2c mm_type;
+ typedef dlib::set<rectangle, mm_type>::kernel_1a set_of_rects;
+
+ using namespace dlib::impl;
+ typedef typename image_traits<in_image_type>::pixel_type ptype;
+ typedef typename edge_diff_funct<ptype>::diff_type diff_type;
+
+ const_image_view<in_image_type> in_img(in_img_);
+
+ // don't bother doing anything if the image is too small
+ if (in_img.nr() < 2 || in_img.nc() < 2)
+ {
+ return;
+ }
+
+ std::vector<edge_data> edges;
+ std::vector<rectangle> working_rects;
+ std::vector<segment_image_edge_data_T<diff_type> > sorted_edges;
+ get_pixel_edges(in_img, sorted_edges);
+
+ disjoint_subsets sets;
+
+ for (long j = 0; j < kvals.size(); ++j)
+ {
+ const double k = kvals(j);
+
+ find_basic_candidate_object_locations(in_img, sorted_edges, working_rects, edges, k, min_size);
+ rects.insert(rects.end(), working_rects.begin(), working_rects.end());
+
+
+ // Now iteratively merge all the rectangles we have and record the results.
+ // Note that, unlike what is described in the paper
+ // Segmentation as Selective Search for Object Recognition" by Koen E. A. van de Sande, et al.
+ // we don't use any kind of histogram/SIFT like thing to order the edges
+ // between the blobs. Here we simply order by the pixel difference value.
+ // Additionally, note that we keep progressively merging boxes in the outer
+ // loop rather than performing just a single iteration as indicated in the
+ // paper.
+ set_of_rects detected_rects;
+ bool did_merge = true;
+ for (unsigned long iter = 0; did_merge && iter < max_merging_iterations; ++iter)
+ {
+ did_merge = false;
+ sets.clear();
+ sets.set_size(working_rects.size());
+
+ // recursively merge neighboring blobs until we have merged everything
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ edge_data temp = edges[i];
+
+ temp.set1 = sets.find_set(temp.set1);
+ temp.set2 = sets.find_set(temp.set2);
+ if (temp.set1 != temp.set2)
+ {
+ rectangle merged_rect = working_rects[temp.set1] + working_rects[temp.set2];
+ // Skip merging this pair of blobs if it was merged in a previous
+ // iteration. Doing this lets us consider other possible blob
+ // merges.
+ if (!detected_rects.is_member(merged_rect))
+ {
+ const unsigned long new_set = sets.merge_sets(temp.set1, temp.set2);
+ rects.push_back(merged_rect);
+ working_rects[new_set] = merged_rect;
+ did_merge = true;
+ detected_rects.add(merged_rect);
+ }
+ }
+ }
+ }
+ }
+
+ remove_duplicates(rects);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type
+ >
+ void find_candidate_object_locations (
+ const in_image_type& in_img,
+ std::vector<rectangle>& rects
+ )
+ {
+ find_candidate_object_locations(in_img, rects, linspace(50, 200, 3));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEGMENT_ImAGE_Hh_
+
diff --git a/ml/dlib/dlib/image_transforms/segment_image_abstract.h b/ml/dlib/dlib/image_transforms/segment_image_abstract.h
new file mode 100644
index 000000000..af1af46a1
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/segment_image_abstract.h
@@ -0,0 +1,126 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_
+#ifdef DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void segment_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const double k = 200,
+ const unsigned long min_size = 10
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_image_type can contain any pixel type with a pixel_traits specialization
+ or a dlib matrix object representing a row or column vector.
+ - out_image_type must contain an unsigned integer pixel type.
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - Attempts to segment in_img into regions which have some visual consistency to
+ them. In particular, this function implements the algorithm described in the
+ paper: Efficient Graph-Based Image Segmentation by Felzenszwalb and Huttenlocher.
+ - #out_img.nr() == in_img.nr()
+ - #out_img.nc() == in_img.nc()
+ - for all valid r and c:
+ - #out_img[r][c] == an integer value indicating the identity of the segment
+ containing the pixel in_img[r][c].
+ - The k parameter is a measure used to influence how large the segment regions
+ will be. Larger k generally results in larger segments being produced. For
+ a deeper discussion of the k parameter you should consult the above
+ referenced paper.
+ - min_size is a lower bound on the size of the output segments. That is, it is
+ guaranteed that all output segments will have at least min_size pixels in
+ them (unless the whole image contains fewer than min_size pixels, in this
+ case the entire image will be put into a single segment).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename EXP
+ >
+ void find_candidate_object_locations (
+ const in_image_type& in_img,
+ std::vector<rectangle>& rects,
+ const matrix_exp<EXP>& kvals = linspace(50, 200, 3),
+ const unsigned long min_size = 20,
+ const unsigned long max_merging_iterations = 50
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - is_vector(kvals) == true
+ - kvals.size() > 0
+ ensures
+ - This function takes an input image and generates a set of candidate
+ rectangles which are expected to bound any objects in the image. It does
+ this by running a version of the segment_image() routine on the image and
+ then reports rectangles containing each of the segments as well as rectangles
+ containing unions of adjacent segments. The basic idea is described in the
+ paper:
+ Segmentation as Selective Search for Object Recognition by Koen E. A. van de Sande, et al.
+ Note that this function deviates from what is described in the paper slightly.
+ See the code for details.
+ - The basic segmentation is performed kvals.size() times, each time with the k
+ parameter (see segment_image() and the Felzenszwalb paper for details on k)
+ set to a different value from kvals.
+ - When doing the basic segmentations prior to any box merging, we discard all
+ rectangles that have an area < min_size. Therefore, all outputs and
+ subsequent merged rectangles are built out of rectangles that contain at
+ least min_size pixels. Note that setting min_size to a smaller value than
+ you might otherwise be interested in using can be useful since it allows a
+ larger number of possible merged boxes to be created.
+ - There are max_merging_iterations rounds of neighboring blob merging.
+ Therefore, this parameter has some effect on the number of output rectangles
+ you get, with larger values of the parameter giving more output rectangles.
+ - This function appends the output rectangles into #rects. This means that any
+ rectangles in rects before this function was called will still be in there
+ after it terminates. Note further that #rects will not contain any duplicate
+ rectangles. That is, for all valid i and j where i != j it will be true
+ that:
+ - #rects[i] != rects[j]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc
+ >
+ void remove_duplicates (
+ std::vector<rectangle,alloc>& rects
+ );
+ /*!
+ ensures
+ - This function finds any duplicate rectangles in rects and removes the extra
+ instances. This way, the result is that rects contains only unique rectangle
+ instances.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/image_transforms/spatial_filtering.h b/ml/dlib/dlib/image_transforms/spatial_filtering.h
new file mode 100644
index 000000000..91dcae321
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/spatial_filtering.h
@@ -0,0 +1,1580 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SPATIAL_FILTERINg_H_
+#define DLIB_SPATIAL_FILTERINg_H_
+
+#include "../pixel.h"
+#include "spatial_filtering_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../array2d.h"
+#include "../matrix.h"
+#include "../geometry/border_enumerator.h"
+#include "../simd.h"
+#include <limits>
+#include "assign_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP,
+ typename T
+ >
+ rectangle grayscale_spatially_filter_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP>& filter_,
+ T scale,
+ bool use_abs,
+ bool add_to
+ )
+ {
+ const_temp_matrix<EXP> filter(filter_);
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ DLIB_ASSERT(scale != 0 && filter.size() != 0,
+ "\trectangle spatially_filter_image()"
+ << "\n\t You can't give a scale of zero or an empty filter."
+ << "\n\t scale: "<< scale
+ << "\n\t filter.nr(): "<< filter.nr()
+ << "\n\t filter.nc(): "<< filter.nc()
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+
+ // figure out the range that we should apply the filter to
+ const long first_row = filter.nr()/2;
+ const long first_col = filter.nc()/2;
+ const long last_row = in_img.nr() - ((filter.nr()-1)/2);
+ const long last_col = in_img.nc() - ((filter.nc()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ if (!add_to)
+ zero_border_pixels(out_img_, non_border);
+
+ // apply the filter to the image
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ typedef typename EXP::type ptype;
+ ptype p;
+ ptype temp = 0;
+ for (long m = 0; m < filter.nr(); ++m)
+ {
+ for (long n = 0; n < filter.nc(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = get_pixel_intensity(in_img[r-first_row+m][c-first_col+n]);
+ temp += p*filter(m,n);
+ }
+ }
+
+ temp /= scale;
+
+ if (use_abs && temp < 0)
+ {
+ temp = -temp;
+ }
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ assign_pixel(out_img[r][c], temp);
+ }
+ else
+ {
+ assign_pixel(out_img[r][c], temp + out_img[r][c]);
+ }
+ }
+ }
+
+ return non_border;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP
+ >
+ rectangle float_spatially_filter_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP>& filter_,
+ bool add_to
+ )
+ {
+
+ const_temp_matrix<EXP> filter(filter_);
+ DLIB_ASSERT(filter.size() != 0,
+ "\trectangle spatially_filter_image()"
+ << "\n\t You can't give an empty filter."
+ << "\n\t filter.nr(): "<< filter.nr()
+ << "\n\t filter.nc(): "<< filter.nc()
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+
+ // figure out the range that we should apply the filter to
+ const long first_row = filter.nr()/2;
+ const long first_col = filter.nc()/2;
+ const long last_row = in_img.nr() - ((filter.nr()-1)/2);
+ const long last_col = in_img.nc() - ((filter.nc()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ if (!add_to)
+ zero_border_pixels(out_img_, non_border);
+
+ // apply the filter to the image
+ for (long r = first_row; r < last_row; ++r)
+ {
+ long c = first_col;
+ for (; c < last_col-7; c+=8)
+ {
+ simd8f p,p2,p3;
+ simd8f temp = 0, temp2=0, temp3=0;
+ for (long m = 0; m < filter.nr(); ++m)
+ {
+ long n = 0;
+ for (; n < filter.nc()-2; n+=3)
+ {
+ // pull out the current pixel and put it into p
+ p.load(&in_img[r-first_row+m][c-first_col+n]);
+ p2.load(&in_img[r-first_row+m][c-first_col+n+1]);
+ p3.load(&in_img[r-first_row+m][c-first_col+n+2]);
+ temp += p*filter(m,n);
+ temp2 += p2*filter(m,n+1);
+ temp3 += p3*filter(m,n+2);
+ }
+ for (; n < filter.nc(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p.load(&in_img[r-first_row+m][c-first_col+n]);
+ temp += p*filter(m,n);
+ }
+ }
+ temp += temp2+temp3;
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ temp.store(&out_img[r][c]);
+ }
+ else
+ {
+ p.load(&out_img[r][c]);
+ temp += p;
+ temp.store(&out_img[r][c]);
+ }
+ }
+ for (; c < last_col; ++c)
+ {
+ float p;
+ float temp = 0;
+ for (long m = 0; m < filter.nr(); ++m)
+ {
+ for (long n = 0; n < filter.nc(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = in_img[r-first_row+m][c-first_col+n];
+ temp += p*filter(m,n);
+ }
+ }
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ out_img[r][c] = temp;
+ }
+ else
+ {
+ out_img[r][c] += temp;
+ }
+ }
+ }
+
+ return non_border;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP
+ >
+ struct is_float_filtering2
+ {
+ const static bool value = is_same_type<typename image_traits<in_image_type>::pixel_type,float>::value &&
+ is_same_type<typename image_traits<out_image_type>::pixel_type,float>::value &&
+ is_same_type<typename EXP::type,float>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP,
+ typename T
+ >
+ typename enable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale &&
+ is_float_filtering2<in_image_type,out_image_type,EXP>::value,rectangle>::type
+ spatially_filter_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP>& filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ if (use_abs == false)
+ {
+ if (scale == 1)
+ return impl::float_spatially_filter_image(in_img, out_img, filter, add_to);
+ else
+ return impl::float_spatially_filter_image(in_img, out_img, filter/scale, add_to);
+ }
+ else
+ {
+ return impl::grayscale_spatially_filter_image(in_img, out_img, filter, scale, true, add_to);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP,
+ typename T
+ >
+ typename enable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale &&
+ !is_float_filtering2<in_image_type,out_image_type,EXP>::value,rectangle>::type
+ spatially_filter_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP>& filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ return impl::grayscale_spatially_filter_image(in_img,out_img,filter,scale,use_abs,add_to);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP,
+ typename T
+ >
+ typename disable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale,rectangle>::type
+ spatially_filter_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP>& filter_,
+ T scale
+ )
+ {
+ const_temp_matrix<EXP> filter(filter_);
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ DLIB_ASSERT(scale != 0 && filter.size() != 0,
+ "\trectangle spatially_filter_image()"
+ << "\n\t You can't give a scale of zero or an empty filter."
+ << "\n\t scale: "<< scale
+ << "\n\t filter.nr(): "<< filter.nr()
+ << "\n\t filter.nc(): "<< filter.nc()
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+
+ // figure out the range that we should apply the filter to
+ const long first_row = filter.nr()/2;
+ const long first_col = filter.nc()/2;
+ const long last_row = in_img.nr() - ((filter.nr()-1)/2);
+ const long last_col = in_img.nc() - ((filter.nc()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ zero_border_pixels(out_img, non_border);
+
+ // apply the filter to the image
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ typedef typename image_traits<in_image_type>::pixel_type pixel_type;
+ typedef matrix<typename EXP::type,pixel_traits<pixel_type>::num,1> ptype;
+ ptype p;
+ ptype temp;
+ temp = 0;
+ for (long m = 0; m < filter.nr(); ++m)
+ {
+ for (long n = 0; n < filter.nc(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = pixel_to_vector<typename EXP::type>(in_img[r-first_row+m][c-first_col+n]);
+ temp += p*filter(m,n);
+ }
+ }
+
+ temp /= scale;
+
+ pixel_type pp;
+ vector_to_pixel(pp, temp);
+ assign_pixel(out_img[r][c], pp);
+ }
+ }
+
+ return non_border;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP
+ >
+ rectangle spatially_filter_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP>& filter
+ )
+ {
+ return spatially_filter_image(in_img,out_img,filter,1);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ rectangle grayscale_spatially_filter_image_separable (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP1>& _row_filter,
+ const matrix_exp<EXP2>& _col_filter,
+ T scale,
+ bool use_abs,
+ bool add_to
+ )
+ {
+ const_temp_matrix<EXP1> row_filter(_row_filter);
+ const_temp_matrix<EXP2> col_filter(_col_filter);
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ DLIB_ASSERT(scale != 0 && row_filter.size() != 0 && col_filter.size() != 0 &&
+ is_vector(row_filter) &&
+ is_vector(col_filter),
+ "\trectangle spatially_filter_image_separable()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t scale: "<< scale
+ << "\n\t row_filter.size(): "<< row_filter.size()
+ << "\n\t col_filter.size(): "<< col_filter.size()
+ << "\n\t is_vector(row_filter): "<< is_vector(row_filter)
+ << "\n\t is_vector(col_filter): "<< is_vector(col_filter)
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image_separable()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+
+ // figure out the range that we should apply the filter to
+ const long first_row = col_filter.size()/2;
+ const long first_col = row_filter.size()/2;
+ const long last_row = in_img.nr() - ((col_filter.size()-1)/2);
+ const long last_col = in_img.nc() - ((row_filter.size()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ if (!add_to)
+ zero_border_pixels(out_img, non_border);
+
+ typedef typename EXP1::type ptype;
+
+ array2d<ptype> temp_img;
+ temp_img.set_size(in_img.nr(), in_img.nc());
+
+ // apply the row filter
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ ptype p;
+ ptype temp = 0;
+ for (long n = 0; n < row_filter.size(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = get_pixel_intensity(in_img[r][c-first_col+n]);
+ temp += p*row_filter(n);
+ }
+ temp_img[r][c] = temp;
+ }
+ }
+
+ // apply the column filter
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ ptype temp = 0;
+ for (long m = 0; m < col_filter.size(); ++m)
+ {
+ temp += temp_img[r-first_row+m][c]*col_filter(m);
+ }
+
+ temp /= scale;
+
+ if (use_abs && temp < 0)
+ {
+ temp = -temp;
+ }
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ assign_pixel(out_img[r][c], temp);
+ }
+ else
+ {
+ assign_pixel(out_img[r][c], temp + out_img[r][c]);
+ }
+ }
+ }
+ return non_border;
+ }
+
+ } // namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2
+ >
+ struct is_float_filtering
+ {
+ const static bool value = is_same_type<typename image_traits<in_image_type>::pixel_type,float>::value &&
+ is_same_type<typename image_traits<out_image_type>::pixel_type,float>::value &&
+ is_same_type<typename EXP1::type,float>::value &&
+ is_same_type<typename EXP2::type,float>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ // This overload is optimized to use SIMD instructions when filtering float images with
+ // float filters.
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2
+ >
+ rectangle float_spatially_filter_image_separable (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP1>& _row_filter,
+ const matrix_exp<EXP2>& _col_filter,
+ out_image_type& scratch_,
+ bool add_to = false
+ )
+ {
+ // You can only use this function with images and filters containing float
+ // variables.
+ COMPILE_TIME_ASSERT((is_float_filtering<in_image_type,out_image_type,EXP1,EXP2>::value == true));
+
+
+ const_temp_matrix<EXP1> row_filter(_row_filter);
+ const_temp_matrix<EXP2> col_filter(_col_filter);
+ DLIB_ASSERT(row_filter.size() != 0 && col_filter.size() != 0 &&
+ is_vector(row_filter) &&
+ is_vector(col_filter),
+ "\trectangle float_spatially_filter_image_separable()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t row_filter.size(): "<< row_filter.size()
+ << "\n\t col_filter.size(): "<< col_filter.size()
+ << "\n\t is_vector(row_filter): "<< is_vector(row_filter)
+ << "\n\t is_vector(col_filter): "<< is_vector(col_filter)
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle float_spatially_filter_image_separable()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ // figure out the range that we should apply the filter to
+ const long first_row = col_filter.size()/2;
+ const long first_col = row_filter.size()/2;
+ const long last_row = in_img.nr() - ((col_filter.size()-1)/2);
+ const long last_col = in_img.nc() - ((row_filter.size()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ if (!add_to)
+ zero_border_pixels(out_img, non_border);
+
+ image_view<out_image_type> scratch(scratch_);
+ scratch.set_size(in_img.nr(), in_img.nc());
+
+ // apply the row filter
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ long c = first_col;
+ for (; c < last_col-7; c+=8)
+ {
+ simd8f p,p2,p3, temp = 0, temp2=0, temp3=0;
+ long n = 0;
+ for (; n < row_filter.size()-2; n+=3)
+ {
+ // pull out the current pixel and put it into p
+ p.load(&in_img[r][c-first_col+n]);
+ p2.load(&in_img[r][c-first_col+n+1]);
+ p3.load(&in_img[r][c-first_col+n+2]);
+ temp += p*row_filter(n);
+ temp2 += p2*row_filter(n+1);
+ temp3 += p3*row_filter(n+2);
+ }
+ for (; n < row_filter.size(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p.load(&in_img[r][c-first_col+n]);
+ temp += p*row_filter(n);
+ }
+ temp += temp2 + temp3;
+ temp.store(&scratch[r][c]);
+ }
+ for (; c < last_col; ++c)
+ {
+ float p;
+ float temp = 0;
+ for (long n = 0; n < row_filter.size(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = in_img[r][c-first_col+n];
+ temp += p*row_filter(n);
+ }
+ scratch[r][c] = temp;
+ }
+ }
+
+ // apply the column filter
+ for (long r = first_row; r < last_row; ++r)
+ {
+ long c = first_col;
+ for (; c < last_col-7; c+=8)
+ {
+ simd8f p, p2, p3, temp = 0, temp2 = 0, temp3 = 0;
+ long m = 0;
+ for (; m < col_filter.size()-2; m+=3)
+ {
+ p.load(&scratch[r-first_row+m][c]);
+ p2.load(&scratch[r-first_row+m+1][c]);
+ p3.load(&scratch[r-first_row+m+2][c]);
+ temp += p*col_filter(m);
+ temp2 += p2*col_filter(m+1);
+ temp3 += p3*col_filter(m+2);
+ }
+ for (; m < col_filter.size(); ++m)
+ {
+ p.load(&scratch[r-first_row+m][c]);
+ temp += p*col_filter(m);
+ }
+ temp += temp2+temp3;
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ temp.store(&out_img[r][c]);
+ }
+ else
+ {
+ p.load(&out_img[r][c]);
+ temp += p;
+ temp.store(&out_img[r][c]);
+ }
+ }
+ for (; c < last_col; ++c)
+ {
+ float temp = 0;
+ for (long m = 0; m < col_filter.size(); ++m)
+ {
+ temp += scratch[r-first_row+m][c]*col_filter(m);
+ }
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ out_img[r][c] = temp;
+ }
+ else
+ {
+ out_img[r][c] += temp;
+ }
+ }
+ }
+ return non_border;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ typename enable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale &&
+ is_float_filtering<in_image_type,out_image_type,EXP1,EXP2>::value,rectangle>::type
+ spatially_filter_image_separable (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ if (use_abs == false)
+ {
+ out_image_type scratch;
+ if (scale == 1)
+ return float_spatially_filter_image_separable(in_img, out_img, row_filter, col_filter, scratch, add_to);
+ else
+ return float_spatially_filter_image_separable(in_img, out_img, row_filter/scale, col_filter, scratch, add_to);
+ }
+ else
+ {
+ return impl::grayscale_spatially_filter_image_separable(in_img, out_img, row_filter, col_filter, scale, true, add_to);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ typename enable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale &&
+ !is_float_filtering<in_image_type,out_image_type,EXP1,EXP2>::value,rectangle>::type
+ spatially_filter_image_separable (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ return impl::grayscale_spatially_filter_image_separable(in_img,out_img, row_filter, col_filter, scale, use_abs, add_to);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ typename disable_if_c<pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale,rectangle>::type
+ spatially_filter_image_separable (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP1>& _row_filter,
+ const matrix_exp<EXP2>& _col_filter,
+ T scale
+ )
+ {
+ const_temp_matrix<EXP1> row_filter(_row_filter);
+ const_temp_matrix<EXP2> col_filter(_col_filter);
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ DLIB_ASSERT(scale != 0 && row_filter.size() != 0 && col_filter.size() != 0 &&
+ is_vector(row_filter) &&
+ is_vector(col_filter),
+ "\trectangle spatially_filter_image_separable()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t scale: "<< scale
+ << "\n\t row_filter.size(): "<< row_filter.size()
+ << "\n\t col_filter.size(): "<< col_filter.size()
+ << "\n\t is_vector(row_filter): "<< is_vector(row_filter)
+ << "\n\t is_vector(col_filter): "<< is_vector(col_filter)
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image_separable()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+
+ // figure out the range that we should apply the filter to
+ const long first_row = col_filter.size()/2;
+ const long first_col = row_filter.size()/2;
+ const long last_row = in_img.nr() - ((col_filter.size()-1)/2);
+ const long last_col = in_img.nc() - ((row_filter.size()-1)/2);
+
+ const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1);
+ zero_border_pixels(out_img, non_border);
+
+ typedef typename image_traits<in_image_type>::pixel_type pixel_type;
+ typedef matrix<typename EXP1::type,pixel_traits<pixel_type>::num,1> ptype;
+
+ array2d<ptype> temp_img;
+ temp_img.set_size(in_img.nr(), in_img.nc());
+
+ // apply the row filter
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ ptype p;
+ ptype temp;
+ temp = 0;
+ for (long n = 0; n < row_filter.size(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = pixel_to_vector<typename EXP1::type>(in_img[r][c-first_col+n]);
+ temp += p*row_filter(n);
+ }
+ temp_img[r][c] = temp;
+ }
+ }
+
+ // apply the column filter
+ for (long r = first_row; r < last_row; ++r)
+ {
+ for (long c = first_col; c < last_col; ++c)
+ {
+ ptype temp;
+ temp = 0;
+ for (long m = 0; m < col_filter.size(); ++m)
+ {
+ temp += temp_img[r-first_row+m][c]*col_filter(m);
+ }
+
+ temp /= scale;
+
+
+ // save this pixel to the output image
+ pixel_type p;
+ vector_to_pixel(p, temp);
+ assign_pixel(out_img[r][c], p);
+ }
+ }
+ return non_border;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2
+ >
+ rectangle spatially_filter_image_separable (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter
+ )
+ {
+ return spatially_filter_image_separable(in_img,out_img,row_filter,col_filter,1);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ rectangle spatially_filter_image_separable_down (
+ const unsigned long downsample,
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale == true );
+
+ DLIB_ASSERT(downsample > 0 &&
+ scale != 0 &&
+ row_filter.size()%2 == 1 &&
+ col_filter.size()%2 == 1 &&
+ is_vector(row_filter) &&
+ is_vector(col_filter),
+ "\trectangle spatially_filter_image_separable_down()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t downsample: "<< downsample
+ << "\n\t scale: "<< scale
+ << "\n\t row_filter.size(): "<< row_filter.size()
+ << "\n\t col_filter.size(): "<< col_filter.size()
+ << "\n\t is_vector(row_filter): "<< is_vector(row_filter)
+ << "\n\t is_vector(col_filter): "<< is_vector(col_filter)
+ );
+ DLIB_ASSERT(is_same_object(in_img_, out_img_) == false,
+ "\trectangle spatially_filter_image_separable_down()"
+ << "\n\tYou must give two different image objects"
+ );
+
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return rectangle();
+ }
+
+ out_img.set_size((long)(std::ceil((double)in_img.nr()/downsample)),
+ (long)(std::ceil((double)in_img.nc()/downsample)));
+
+ const double col_border = std::floor(col_filter.size()/2.0);
+ const double row_border = std::floor(row_filter.size()/2.0);
+
+ // figure out the range that we should apply the filter to
+ const long first_row = (long)std::ceil(col_border/downsample);
+ const long first_col = (long)std::ceil(row_border/downsample);
+ const long last_row = (long)std::ceil((in_img.nr() - col_border)/downsample) - 1;
+ const long last_col = (long)std::ceil((in_img.nc() - row_border)/downsample) - 1;
+
+ // zero border pixels
+ const rectangle non_border = rectangle(first_col, first_row, last_col, last_row);
+ zero_border_pixels(out_img,non_border);
+
+ typedef typename EXP1::type ptype;
+
+ array2d<ptype> temp_img;
+ temp_img.set_size(in_img.nr(), out_img.nc());
+
+ // apply the row filter
+ for (long r = 0; r < temp_img.nr(); ++r)
+ {
+ for (long c = non_border.left(); c <= non_border.right(); ++c)
+ {
+ ptype p;
+ ptype temp = 0;
+ for (long n = 0; n < row_filter.size(); ++n)
+ {
+ // pull out the current pixel and put it into p
+ p = get_pixel_intensity(in_img[r][c*downsample-row_filter.size()/2+n]);
+ temp += p*row_filter(n);
+ }
+ temp_img[r][c] = temp;
+ }
+ }
+
+ // apply the column filter
+ for (long r = non_border.top(); r <= non_border.bottom(); ++r)
+ {
+ for (long c = non_border.left(); c <= non_border.right(); ++c)
+ {
+ ptype temp = 0;
+ for (long m = 0; m < col_filter.size(); ++m)
+ {
+ temp += temp_img[r*downsample-col_filter.size()/2+m][c]*col_filter(m);
+ }
+
+ temp /= scale;
+
+ if (use_abs && temp < 0)
+ {
+ temp = -temp;
+ }
+
+ // save this pixel to the output image
+ if (add_to == false)
+ {
+ assign_pixel(out_img[r][c], temp);
+ }
+ else
+ {
+ assign_pixel(out_img[r][c], temp + out_img[r][c]);
+ }
+ }
+ }
+
+ return non_border;
+ }
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2
+ >
+ rectangle spatially_filter_image_separable_down (
+ const unsigned long downsample,
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter
+ )
+ {
+ return spatially_filter_image_separable_down(downsample,in_img,out_img,row_filter,col_filter,1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long NR,
+ long NC,
+ typename T,
+ typename U,
+ typename in_image_type
+ >
+ inline void separable_3x3_filter_block_grayscale (
+ T (&block)[NR][NC],
+ const in_image_type& img_,
+ const long& r,
+ const long& c,
+ const U& fe1, // separable filter end
+ const U& fm, // separable filter middle
+ const U& fe2 // separable filter end 2
+ )
+ {
+ const_image_view<in_image_type> img(img_);
+ // make sure requires clause is not broken
+ DLIB_ASSERT(shrink_rect(get_rect(img),1).contains(c,r) &&
+ shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1),
+ "\t void separable_3x3_filter_block_grayscale()"
+ << "\n\t The sub-window doesn't fit inside the given image."
+ << "\n\t get_rect(img): " << get_rect(img)
+ << "\n\t (c,r): " << point(c,r)
+ << "\n\t (c+NC-1,r+NR-1): " << point(c+NC-1,r+NR-1)
+ );
+
+
+ T row_filt[NR+2][NC];
+ for (long rr = 0; rr < NR+2; ++rr)
+ {
+ for (long cc = 0; cc < NC; ++cc)
+ {
+ row_filt[rr][cc] = get_pixel_intensity(img[r+rr-1][c+cc-1])*fe1 +
+ get_pixel_intensity(img[r+rr-1][c+cc])*fm +
+ get_pixel_intensity(img[r+rr-1][c+cc+1])*fe2;
+ }
+ }
+
+ for (long rr = 0; rr < NR; ++rr)
+ {
+ for (long cc = 0; cc < NC; ++cc)
+ {
+ block[rr][cc] = (row_filt[rr][cc]*fe1 +
+ row_filt[rr+1][cc]*fm +
+ row_filt[rr+2][cc]*fe2);
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long NR,
+ long NC,
+ typename T,
+ typename U,
+ typename in_image_type
+ >
+ inline void separable_3x3_filter_block_rgb (
+ T (&block)[NR][NC],
+ const in_image_type& img_,
+ const long& r,
+ const long& c,
+ const U& fe1, // separable filter end
+ const U& fm, // separable filter middle
+ const U& fe2 // separable filter end 2
+ )
+ {
+ const_image_view<in_image_type> img(img_);
+ // make sure requires clause is not broken
+ DLIB_ASSERT(shrink_rect(get_rect(img),1).contains(c,r) &&
+ shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1),
+ "\t void separable_3x3_filter_block_rgb()"
+ << "\n\t The sub-window doesn't fit inside the given image."
+ << "\n\t get_rect(img): " << get_rect(img)
+ << "\n\t (c,r): " << point(c,r)
+ << "\n\t (c+NC-1,r+NR-1): " << point(c+NC-1,r+NR-1)
+ );
+
+ T row_filt[NR+2][NC];
+ for (long rr = 0; rr < NR+2; ++rr)
+ {
+ for (long cc = 0; cc < NC; ++cc)
+ {
+ row_filt[rr][cc].red = img[r+rr-1][c+cc-1].red*fe1 + img[r+rr-1][c+cc].red*fm + img[r+rr-1][c+cc+1].red*fe2;
+ row_filt[rr][cc].green = img[r+rr-1][c+cc-1].green*fe1 + img[r+rr-1][c+cc].green*fm + img[r+rr-1][c+cc+1].green*fe2;
+ row_filt[rr][cc].blue = img[r+rr-1][c+cc-1].blue*fe1 + img[r+rr-1][c+cc].blue*fm + img[r+rr-1][c+cc+1].blue*fe2;
+ }
+ }
+
+ for (long rr = 0; rr < NR; ++rr)
+ {
+ for (long cc = 0; cc < NC; ++cc)
+ {
+ block[rr][cc].red = row_filt[rr][cc].red*fe1 + row_filt[rr+1][cc].red*fm + row_filt[rr+2][cc].red*fe2;
+ block[rr][cc].green = row_filt[rr][cc].green*fe1 + row_filt[rr+1][cc].green*fm + row_filt[rr+2][cc].green*fe2;
+ block[rr][cc].blue = row_filt[rr][cc].blue*fe1 + row_filt[rr+1][cc].blue*fm + row_filt[rr+2][cc].blue*fe2;
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double gaussian (
+ double x,
+ double sigma
+ )
+ {
+ DLIB_ASSERT(sigma > 0,
+ "\tdouble gaussian(x)"
+ << "\n\t sigma must be bigger than 0"
+ << "\n\t sigma: " << sigma
+ );
+ const double sqrt_2_pi = 2.5066282746310002416123552393401041626930;
+ return 1.0/(sigma*sqrt_2_pi) * std::exp( -(x*x)/(2*sigma*sigma));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<T,0,1> create_gaussian_filter (
+ double sigma,
+ int max_size
+ )
+ {
+ DLIB_ASSERT(sigma > 0 && max_size > 0 && (max_size%2)==1,
+ "\t matrix<T,0,1> create_gaussian_filter()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t sigma: " << sigma
+ << "\n\t max_size: " << max_size
+ );
+
+ // Adjust the size so that the ratio of the gaussian values isn't huge.
+ // This only matters when T is an integer type. However, we do it for
+ // all types so that the behavior of this function is always relatively
+ // the same.
+ while (gaussian(0,sigma)/gaussian(max_size/2,sigma) > 50)
+ --max_size;
+
+
+ matrix<double,0,1> f(max_size);
+ for (long i = 0; i < f.size(); ++i)
+ {
+ f(i) = gaussian(i-max_size/2, sigma);
+ }
+
+ if (is_float_type<T>::value == false)
+ {
+ f /= f(0);
+ return matrix_cast<T>(round(f));
+ }
+ else
+ {
+ return matrix_cast<T>(f);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ rectangle gaussian_blur (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ double sigma = 1,
+ int max_size = 1001
+ )
+ {
+ DLIB_ASSERT(sigma > 0 && max_size > 0 && (max_size%2)==1 &&
+ is_same_object(in_img, out_img) == false,
+ "\t void gaussian_blur()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t sigma: " << sigma
+ << "\n\t max_size: " << max_size
+ << "\n\t is_same_object(in_img,out_img): " << is_same_object(in_img,out_img)
+ );
+
+ if (sigma < 18)
+ {
+ typedef typename pixel_traits<typename image_traits<out_image_type>::pixel_type>::basic_pixel_type type;
+ typedef typename promote<type>::type ptype;
+ const matrix<ptype,0,1>& filt = create_gaussian_filter<ptype>(sigma, max_size);
+ ptype scale = sum(filt);
+ scale = scale*scale;
+ return spatially_filter_image_separable(in_img, out_img, filt, filt, scale);
+ }
+ else
+ {
+ // For large sigma we need to use a type with a lot of precision to avoid
+ // numerical problems. So we use double here.
+ typedef double ptype;
+ const matrix<ptype,0,1>& filt = create_gaussian_filter<ptype>(sigma, max_size);
+ ptype scale = sum(filt);
+ scale = scale*scale;
+ return spatially_filter_image_separable(in_img, out_img, filt, filt, scale);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ bool add_to,
+ typename image_type1,
+ typename image_type2
+ >
+ void sum_filter (
+ const image_type1& img_,
+ image_type2& out_,
+ const rectangle& rect
+ )
+ {
+ const_image_view<image_type1> img(img_);
+ image_view<image_type2> out(out_);
+ DLIB_ASSERT(img.nr() == out.nr() &&
+ img.nc() == out.nc() &&
+ is_same_object(img_,out_) == false,
+ "\t void sum_filter()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t img.nr(): " << img.nr()
+ << "\n\t img.nc(): " << img.nc()
+ << "\n\t out.nr(): " << out.nr()
+ << "\n\t out.nc(): " << out.nc()
+ << "\n\t is_same_object(img_,out_): " << is_same_object(img_,out_)
+ );
+
+ typedef typename image_traits<image_type1>::pixel_type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ std::vector<ptype> column_sum;
+ column_sum.resize(img.nc() + rect.width(),0);
+
+ const long top = -1 + rect.top();
+ const long bottom = -1 + rect.bottom();
+ long left = rect.left()-1;
+
+ // initialize column_sum at row -1
+ for (unsigned long j = 0; j < column_sum.size(); ++j)
+ {
+ rectangle strip(left,top,left,bottom);
+ strip = strip.intersect(get_rect(img));
+ if (!strip.is_empty())
+ {
+ column_sum[j] = sum(matrix_cast<ptype>(subm(mat(img),strip)));
+ }
+
+ ++left;
+ }
+
+
+ const rectangle area = get_rect(img);
+
+ // Save width to avoid computing it over and over.
+ const long width = rect.width();
+
+
+ // Now do the bulk of the filtering work.
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ // set to sum at point(-1,r). i.e. should be equal to sum(mat(img), translate_rect(rect, point(-1,r)))
+ // We compute it's value in the next loop.
+ ptype cur_sum = 0;
+
+ // Update the first part of column_sum since we only work on the c+width part of column_sum
+ // in the main loop.
+ const long top = r + rect.top() - 1;
+ const long bottom = r + rect.bottom();
+ for (long k = 0; k < width; ++k)
+ {
+ const long right = k-width + rect.right();
+
+ const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0;
+ const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0;
+ // update the sum in this column now that we are on the next row
+ column_sum[k] = column_sum[k] + br_corner - tr_corner;
+ cur_sum += column_sum[k];
+ }
+
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ const long top = r + rect.top() - 1;
+ const long bottom = r + rect.bottom();
+ const long right = c + rect.right();
+
+ const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0;
+ const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0;
+
+ // update the sum in this column now that we are on the next row
+ column_sum[c+width] = column_sum[c+width] + br_corner - tr_corner;
+
+ // add in the new right side of the rect and subtract the old right side.
+ cur_sum = cur_sum + column_sum[c+width] - column_sum[c];
+
+ if (add_to)
+ out[r][c] += static_cast<typename image_traits<image_type2>::pixel_type>(cur_sum);
+ else
+ out[r][c] = static_cast<typename image_traits<image_type2>::pixel_type>(cur_sum);
+ }
+ }
+ }
+ }
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void sum_filter (
+ const image_type1& img,
+ image_type2& out,
+ const rectangle& rect
+ )
+ {
+ impl::sum_filter<true>(img,out,rect);
+ }
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void sum_filter_assign (
+ const image_type1& img,
+ image_type2& out,
+ const rectangle& rect
+ )
+ {
+ set_image_size(out, num_rows(img), num_columns(img));
+ impl::sum_filter<false>(img,out,rect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ class fast_deque
+ {
+ /*
+ This is a fast and minimal implementation of std::deque for
+ use with the max_filter.
+
+ This object assumes that no more than max_size elements
+ will ever be pushed into it at a time.
+ */
+ public:
+
+ explicit fast_deque(unsigned long max_size)
+ {
+ // find a power of two that upper bounds max_size
+ mask = 2;
+ while (mask < max_size)
+ mask *= 2;
+
+ clear();
+
+ data.resize(mask);
+ --mask; // make into bit mask
+ }
+
+ void clear()
+ {
+ first = 1;
+ last = 0;
+ size = 0;
+ }
+
+ bool empty() const
+ {
+ return size == 0;
+ }
+
+ void pop_back()
+ {
+ last = (last-1)&mask;
+ --size;
+ }
+
+ void push_back(const T& item)
+ {
+ last = (last+1)&mask;
+ ++size;
+ data[last] = item;
+ }
+
+ void pop_front()
+ {
+ first = (first+1)&mask;
+ --size;
+ }
+
+ const T& front() const
+ {
+ return data[first];
+ }
+
+ const T& back() const
+ {
+ return data[last];
+ }
+
+ private:
+
+ std::vector<T> data;
+ unsigned long mask;
+ unsigned long first;
+ unsigned long last;
+ unsigned long size;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void max_filter (
+ image_type1& img_,
+ image_type2& out_,
+ const long width,
+ const long height,
+ const typename image_traits<image_type1>::pixel_type& thresh
+ )
+ {
+ image_view<image_type1> img(img_);
+ image_view<image_type2> out(out_);
+ DLIB_ASSERT( width > 0 &&
+ height > 0 &&
+ out.nr() == img.nr() &&
+ out.nc() == img.nc() &&
+ is_same_object(img_,out_) == false,
+ "\t void max_filter()"
+ << "\n\t Invalid arguments given to this function."
+ << "\n\t img.nr(): " << img.nr()
+ << "\n\t img.nc(): " << img.nc()
+ << "\n\t out.nr(): " << out.nr()
+ << "\n\t out.nc(): " << out.nc()
+ << "\n\t width: " << width
+ << "\n\t height: " << height
+ << "\n\t is_same_object(img_,out_): " << is_same_object(img_,out_)
+ );
+
+ typedef typename image_traits<image_type1>::pixel_type pixel_type;
+
+
+ dlib::impl::fast_deque<std::pair<long,pixel_type> > Q(std::max(width,height));
+
+ const long last_col = std::max(img.nc(), ((width-1)/2));
+ const long last_row = std::max(img.nr(), ((height-1)/2));
+
+ // run max filter along rows of img
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ Q.clear();
+ for (long c = 0; c < (width-1)/2 && c < img.nc(); ++c)
+ {
+ while (!Q.empty() && img[r][c] >= Q.back().second)
+ Q.pop_back();
+ Q.push_back(std::make_pair(c,img[r][c]));
+ }
+
+ for (long c = (width-1)/2; c < img.nc(); ++c)
+ {
+ while (!Q.empty() && img[r][c] >= Q.back().second)
+ Q.pop_back();
+ while (!Q.empty() && Q.front().first <= c-width)
+ Q.pop_front();
+ Q.push_back(std::make_pair(c,img[r][c]));
+
+ img[r][c-((width-1)/2)] = Q.front().second;
+ }
+
+ for (long c = last_col; c < img.nc() + ((width-1)/2); ++c)
+ {
+ while (!Q.empty() && Q.front().first <= c-width)
+ Q.pop_front();
+
+ img[r][c-((width-1)/2)] = Q.front().second;
+ }
+ }
+
+ // run max filter along columns of img. Store result in out.
+ for (long cc = 0; cc < img.nc(); ++cc)
+ {
+ Q.clear();
+ for (long rr = 0; rr < (height-1)/2 && rr < img.nr(); ++rr)
+ {
+ while (!Q.empty() && img[rr][cc] >= Q.back().second)
+ Q.pop_back();
+ Q.push_back(std::make_pair(rr,img[rr][cc]));
+ }
+
+ for (long rr = (height-1)/2; rr < img.nr(); ++rr)
+ {
+ while (!Q.empty() && img[rr][cc] >= Q.back().second)
+ Q.pop_back();
+ while (!Q.empty() && Q.front().first <= rr-height)
+ Q.pop_front();
+ Q.push_back(std::make_pair(rr,img[rr][cc]));
+
+ out[rr-((height-1)/2)][cc] += std::max(Q.front().second, thresh);
+ }
+
+ for (long rr = last_row; rr < img.nr() + ((height-1)/2); ++rr)
+ {
+ while (!Q.empty() && Q.front().first <= rr-height)
+ Q.pop_front();
+
+ out[rr-((height-1)/2)][cc] += std::max(Q.front().second, thresh);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SPATIAL_FILTERINg_H_
+
+
diff --git a/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h b/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h
new file mode 100644
index 000000000..5e200aa9a
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h
@@ -0,0 +1,487 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SPATIAL_FILTERINg_ABSTRACT_
+#ifdef DLIB_SPATIAL_FILTERINg_ABSTRACT_
+
+#include "../pixel.h"
+#include "../matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP,
+ typename T
+ >
+ rectangle spatially_filter_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP>& filter,
+ T scale = 1,
+ bool use_abs = false,
+ bool add_to = false
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_img and out_img do not contain pixels with an alpha channel. That is,
+ pixel_traits::has_alpha is false for the pixels in these objects.
+ - is_same_object(in_img, out_img) == false
+ - T must be some scalar type
+ - filter.size() != 0
+ - scale != 0
+ - if (in_img doesn't contain grayscale pixels) then
+ - use_abs == false && add_to == false
+ (i.e. You can only use the use_abs and add_to options with grayscale images)
+ ensures
+ - Applies the given spatial filter to in_img and stores the result in out_img (i.e.
+ cross-correlates in_img with filter). Also divides each resulting pixel by scale.
+ - The intermediate filter computations will be carried out using variables of type EXP::type.
+ This is whatever scalar type is used inside the filter matrix.
+ - Pixel values are stored into out_img using the assign_pixel() function and therefore
+ any applicable color space conversion or value saturation is performed. Note that if
+ add_to is true then the filtered output value will be added to out_img rather than
+ overwriting the original value.
+ - if (in_img doesn't contain grayscale pixels) then
+ - The filter is applied to each color channel independently.
+ - if (use_abs == true) then
+ - pixel values after filtering that are < 0 are converted to their absolute values.
+ - The filter is applied such that it's centered over the pixel it writes its
+ output into. For centering purposes, we consider the center element of the
+ filter to be filter(filter.nr()/2,filter.nc()/2). This means that the filter
+ that writes its output to a pixel at location point(c,r) and is W by H (width
+ by height) pixels in size operates on exactly the pixels in the rectangle
+ centered_rect(point(c,r),W,H) within in_img.
+ - Pixels close enough to the edge of in_img to not have the filter still fit
+ inside the image are always set to zero.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ - returns a rectangle which indicates what pixels in #out_img are considered
+ non-border pixels and therefore contain output from the filter.
+ - if (use_abs == false && all images and filers contain float types) then
+ - This function will use SIMD instructions and is particularly fast. So if
+ you can use this form of the function it can give a decent speed boost.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ rectangle spatially_filter_image_separable (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale = 1,
+ bool use_abs = false,
+ bool add_to = false
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_img and out_img do not contain pixels with an alpha channel. That is,
+ pixel_traits::has_alpha is false for the pixels in these objects.
+ - is_same_object(in_img, out_img) == false
+ - T must be some scalar type
+ - scale != 0
+ - row_filter.size() != 0
+ - col_filter.size() != 0
+ - is_vector(row_filter) == true
+ - is_vector(col_filter) == true
+ - if (in_img doesn't contain grayscale pixels) then
+ - use_abs == false && add_to == false
+ (i.e. You can only use the use_abs and add_to options with grayscale images)
+ ensures
+ - Applies the given separable spatial filter to in_img and stores the result in out_img.
+ Also divides each resulting pixel by scale. Calling this function has the same
+ effect as calling the regular spatially_filter_image() routine with a filter,
+ FILT, defined as follows:
+ - FILT(r,c) == col_filter(r)*row_filter(c)
+ - The intermediate filter computations will be carried out using variables of type EXP1::type.
+ This is whatever scalar type is used inside the row_filter matrix.
+ - Pixel values are stored into out_img using the assign_pixel() function and therefore
+ any applicable color space conversion or value saturation is performed. Note that if
+ add_to is true then the filtered output value will be added to out_img rather than
+ overwriting the original value.
+ - if (in_img doesn't contain grayscale pixels) then
+ - The filter is applied to each color channel independently.
+ - if (use_abs == true) then
+ - pixel values after filtering that are < 0 are converted to their absolute values
+ - The filter is applied such that it's centered over the pixel it writes its
+ output into. For centering purposes, we consider the center element of the
+ filter to be FILT(col_filter.size()/2,row_filter.size()/2). This means that
+ the filter that writes its output to a pixel at location point(c,r) and is W
+ by H (width by height) pixels in size operates on exactly the pixels in the
+ rectangle centered_rect(point(c,r),W,H) within in_img.
+ - Pixels close enough to the edge of in_img to not have the filter still fit
+ inside the image are always set to zero.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ - returns a rectangle which indicates what pixels in #out_img are considered
+ non-border pixels and therefore contain output from the filter.
+ - if (use_abs == false && all images and filers contain float types) then
+ - This function will use SIMD instructions and is particularly fast. So if
+ you can use this form of the function it can give a decent speed boost.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2
+ >
+ rectangle float_spatially_filter_image_separable (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ out_image_type& scratch,
+ bool add_to = false
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_img, out_img, row_filter, and col_filter must all contain float type elements.
+ - is_same_object(in_img, out_img) == false
+ - row_filter.size() != 0
+ - col_filter.size() != 0
+ - is_vector(row_filter) == true
+ - is_vector(col_filter) == true
+ ensures
+ - This function is identical to the above spatially_filter_image_separable()
+ function except that it can only be invoked on float images with float
+ filters. In fact, spatially_filter_image_separable() invokes
+ float_spatially_filter_image_separable() in those cases. So why is
+ float_spatially_filter_image_separable() in the public API? The reason is
+ because the separable filtering routines internally allocate an image each
+ time they are called. If you want to avoid this memory allocation then you
+ can call float_spatially_filter_image_separable() and provide the scratch
+ image as input. This allows you to reuse the same scratch image for many
+ calls to float_spatially_filter_image_separable() and thereby avoid having it
+ allocated and freed for each call.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ rectangle spatially_filter_image_separable_down (
+ const unsigned long downsample,
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale = 1,
+ bool use_abs = false,
+ bool add_to = false
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_img and out_img do not contain pixels with an alpha channel. That is,
+ pixel_traits::has_alpha is false for the pixels in these objects.
+ - out_img contains grayscale pixels.
+ - is_same_object(in_img, out_img) == false
+ - T must be some scalar type
+ - scale != 0
+ - is_vector(row_filter) == true
+ - is_vector(col_filter) == true
+ - row_filter.size() % 2 == 1 (i.e. must be odd)
+ - col_filter.size() % 2 == 1 (i.e. must be odd)
+ - downsample > 0
+ ensures
+ - This function is equivalent to calling
+ spatially_filter_image_separable(in_img,out_img,row_filter,col_filter,scale,use_abs,add_to)
+ and then downsampling the output image by a factor of downsample. Therefore,
+ we will have that:
+ - #out_img.nr() == ceil((double)in_img.nr()/downsample)
+ - #out_img.nc() == ceil((double)in_img.nc()/downsample)
+ - #out_img[r][c] == filtered pixel corresponding to in_img[r*downsample][c*downsample]
+ - returns a rectangle which indicates what pixels in #out_img are considered
+ non-border pixels and therefore contain output from the filter.
+ - Note that the first row and column of non-zero padded data are the following
+ - first_row == ceil(floor(col_filter.size()/2.0)/downsample)
+ - first_col == ceil(floor(row_filter.size()/2.0)/downsample)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long NR,
+ long NC,
+ typename T,
+ typename U,
+ typename in_image_type
+ >
+ inline void separable_3x3_filter_block_grayscale (
+ T (&block)[NR][NC],
+ const in_image_type& img,
+ const long& r,
+ const long& c,
+ const U& fe1,
+ const U& fm,
+ const U& fe2
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - T and U should be scalar types
+ - shrink_rect(get_rect(img),1).contains(c,r)
+ - shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1)
+ ensures
+ - Filters the image in the sub-window of img defined by a rectangle
+ with its upper left corner at (c,r) and lower right at (c+NC-1,r+NR-1).
+ - The output of the filter is stored in #block. Note that img will be
+ interpreted as a grayscale image.
+ - The filter used is defined by the separable filter [fe1 fm fe2]. So the
+ spatial filter is thus:
+ fe1*fe1 fe1*fm fe2*fe1
+ fe1*fm fm*fm fe2*fm
+ fe1*fe2 fe2*fm fe2*fe2
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long NR,
+ long NC,
+ typename T,
+ typename U,
+ typename in_image_type
+ >
+ inline void separable_3x3_filter_block_rgb (
+ T (&block)[NR][NC],
+ const in_image_type& img,
+ const long& r,
+ const long& c,
+ const U& fe1,
+ const U& fm,
+ const U& fe2
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - img must contain RGB pixels, that is pixel_traits::rgb == true for the pixels
+ in img.
+ - T should be a struct with .red .green and .blue members.
+ - U should be a scalar type
+ - shrink_rect(get_rect(img),1).contains(c,r)
+ - shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1)
+ ensures
+ - Filters the image in the sub-window of img defined by a rectangle
+ with its upper left corner at (c,r) and lower right at (c+NC-1,r+NR-1).
+ - The output of the filter is stored in #block. Note that the filter is applied
+ to each color component independently.
+ - The filter used is defined by the separable filter [fe1 fm fe2]. So the
+ spatial filter is thus:
+ fe1*fe1 fe1*fm fe2*fe1
+ fe1*fm fm*fm fe2*fm
+ fe1*fe2 fe2*fm fe2*fe2
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double gaussian (
+ double x,
+ double sigma
+ );
+ /*!
+ requires
+ - sigma > 0
+ ensures
+ - computes and returns the value of a 1D Gaussian function with mean 0
+ and standard deviation sigma at the given x value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<T,0,1> create_gaussian_filter (
+ double sigma,
+ int size
+ );
+ /*!
+ requires
+ - sigma > 0
+ - size > 0
+ - size is an odd number
+ ensures
+ - returns a separable Gaussian filter F such that:
+ - is_vector(F) == true
+ - F.size() == size
+ - F is suitable for use with the spatially_filter_image_separable() routine
+ and its use with this function corresponds to running a Gaussian filter
+ of sigma width over an image.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ rectangle gaussian_blur (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ double sigma = 1,
+ int max_size = 1001
+ );
+ /*!
+ requires
+ - in_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - out_image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h
+ - in_img and out_img do not contain pixels with an alpha channel. That is,
+ pixel_traits::has_alpha is false for the pixels in these objects.
+ - is_same_object(in_img, out_img) == false
+ - sigma > 0
+ - max_size > 0
+ - max_size is an odd number
+ ensures
+ - Filters in_img with a Gaussian filter of sigma width. The actual spatial filter will
+ be applied to pixel blocks that are at most max_size wide and max_size tall (note that
+ this function will automatically select a smaller block size as appropriate). The
+ results are stored into #out_img.
+ - Pixel values are stored into out_img using the assign_pixel() function and therefore
+ any applicable color space conversion or value saturation is performed.
+ - if (in_img doesn't contain grayscale pixels) then
+ - The filter is applied to each color channel independently.
+ - Pixels close enough to the edge of in_img to not have the filter still fit
+ inside the image are set to zero.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ - returns a rectangle which indicates what pixels in #out_img are considered
+ non-border pixels and therefore contain output from the filter.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void sum_filter (
+ const image_type1& img,
+ image_type2& out,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - out.nr() == img.nr()
+ - out.nc() == img.nc()
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - is_same_object(img,out) == false
+ ensures
+ - for all valid r and c:
+ - let SUM(r,c) == sum of pixels from img which are inside the rectangle
+ translate_rect(rect, point(c,r)).
+ - #out[r][c] == out[r][c] + SUM(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void sum_filter_assign (
+ const image_type1& img,
+ image_type2& out,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - is_same_object(img,out) == false
+ ensures
+ - #out.nr() == img.nr()
+ - #out.nc() == img.nc()
+ - for all valid r and c:
+ - let SUM(r,c) == sum of pixels from img which are inside the rectangle
+ translate_rect(rect, point(c,r)).
+ - #out[r][c] == SUM(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void max_filter (
+ image_type1& img,
+ image_type2& out,
+ const long width,
+ const long height,
+ const typename image_traits<image_type1>::pixel_type& thresh
+ );
+ /*!
+ requires
+ - out.nr() == img.nr()
+ - out.nc() == img.nc()
+ - image_type1 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - image_type2 == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h and it must contain grayscale pixels.
+ - is_same_object(img,out) == false
+ - width > 0 && height > 0
+ ensures
+ - for all valid r and c:
+ - let MAX(r,c) == maximum of pixels from img which are inside the rectangle
+ centered_rect(point(c,r), width, height)
+ - if (MAX(r,c) >= thresh)
+ - #out[r][c] == out[r][c] + MAX(r,c)
+ - else
+ - #out[r][c] == out[r][c] + thresh
+ - Does not change the size of img.
+ - Uses img as scratch space. Therefore, the pixel values in img will have
+ been modified by this function. That is, max_filter() destroys the contents
+ of img.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SPATIAL_FILTERINg_ABSTRACT_
+
diff --git a/ml/dlib/dlib/image_transforms/thresholding.h b/ml/dlib/dlib/image_transforms/thresholding.h
new file mode 100644
index 000000000..e4fb02c4a
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/thresholding.h
@@ -0,0 +1,340 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THRESHOLDINg_
+#define DLIB_THRESHOLDINg_
+
+#include "../pixel.h"
+#include "thresholding_abstract.h"
+#include "equalize_histogram.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const unsigned char on_pixel = 255;
+ const unsigned char off_pixel = 0;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void threshold_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type thresh
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale);
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ if (get_pixel_intensity(in_img[r][c]) >= thresh)
+ assign_pixel(out_img[r][c], on_pixel);
+ else
+ assign_pixel(out_img[r][c], off_pixel);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ void threshold_image (
+ image_type& img,
+ typename pixel_traits<typename image_traits<image_type>::pixel_type>::basic_pixel_type thresh
+ )
+ {
+ threshold_image(img,img,thresh);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void auto_threshold_image (
+ const in_image_type& in_img_,
+ out_image_type& out_img_
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::is_unsigned == true );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::is_unsigned == true );
+
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale);
+
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (image_size(in_img_) == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ unsigned long thresh;
+ // find the threshold we should use
+ matrix<unsigned long,1> hist;
+ get_histogram(in_img_,hist);
+
+ const_image_view<in_image_type> in_img(in_img_);
+
+ // Start our two means (a and b) out at the ends of the histogram
+ long a = 0;
+ long b = hist.size()-1;
+ bool moved_a = true;
+ bool moved_b = true;
+ while (moved_a || moved_b)
+ {
+ moved_a = false;
+ moved_b = false;
+
+ // catch the degenerate case where the histogram is empty
+ if (a >= b)
+ break;
+
+ if (hist(a) == 0)
+ {
+ ++a;
+ moved_a = true;
+ }
+
+ if (hist(b) == 0)
+ {
+ --b;
+ moved_b = true;
+ }
+ }
+
+ // now do k-means clustering with k = 2 on the histogram.
+ moved_a = true;
+ moved_b = true;
+ while (moved_a || moved_b)
+ {
+ moved_a = false;
+ moved_b = false;
+
+ int64 a_hits = 0;
+ int64 b_hits = 0;
+ int64 a_mass = 0;
+ int64 b_mass = 0;
+
+ for (long i = 0; i < hist.size(); ++i)
+ {
+ // if i is closer to a
+ if (std::abs(i-a) < std::abs(i-b))
+ {
+ a_mass += hist(i)*i;
+ a_hits += hist(i);
+ }
+ else // if i is closer to b
+ {
+ b_mass += hist(i)*i;
+ b_hits += hist(i);
+ }
+ }
+
+ long new_a = (a_mass + a_hits/2)/a_hits;
+ long new_b = (b_mass + b_hits/2)/b_hits;
+
+ if (new_a != a)
+ {
+ moved_a = true;
+ a = new_a;
+ }
+
+ if (new_b != b)
+ {
+ moved_b = true;
+ b = new_b;
+ }
+ }
+
+ // put the threshold between the two means we found
+ thresh = (a + b)/2;
+
+ // now actually apply the threshold
+ threshold_image(in_img_,out_img_,thresh);
+ }
+
+ template <
+ typename image_type
+ >
+ void auto_threshold_image (
+ image_type& img
+ )
+ {
+ auto_threshold_image(img,img);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void hysteresis_threshold (
+ const in_image_type& in_img_,
+ out_image_type& out_img_,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type lower_thresh,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type upper_thresh
+ )
+ {
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false );
+ COMPILE_TIME_ASSERT( pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false );
+
+ COMPILE_TIME_ASSERT(pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale);
+
+ DLIB_ASSERT( lower_thresh <= upper_thresh && is_same_object(in_img_, out_img_) == false,
+ "\tvoid hysteresis_threshold(in_img_, out_img_, lower_thresh, upper_thresh)"
+ << "\n\tYou can't use an upper_thresh that is less than your lower_thresh"
+ << "\n\tlower_thresh: " << lower_thresh
+ << "\n\tupper_thresh: " << upper_thresh
+ << "\n\tis_same_object(in_img_,out_img_): " << is_same_object(in_img_,out_img_)
+ );
+
+ const_image_view<in_image_type> in_img(in_img_);
+ image_view<out_image_type> out_img(out_img_);
+
+ // if there isn't any input image then don't do anything
+ if (in_img.size() == 0)
+ {
+ out_img.clear();
+ return;
+ }
+
+ out_img.set_size(in_img.nr(),in_img.nc());
+ assign_all_pixels(out_img, off_pixel);
+
+ const long size = 1000;
+ long rstack[size];
+ long cstack[size];
+
+ // now do the thresholding
+ for (long r = 0; r < in_img.nr(); ++r)
+ {
+ for (long c = 0; c < in_img.nc(); ++c)
+ {
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type p;
+ assign_pixel(p,in_img[r][c]);
+ if (p >= upper_thresh)
+ {
+ // now do line following for pixels >= lower_thresh.
+ // set the stack position to 0.
+ long pos = 1;
+ rstack[0] = r;
+ cstack[0] = c;
+
+ while (pos > 0)
+ {
+ --pos;
+ const long r = rstack[pos];
+ const long c = cstack[pos];
+
+ // This is the base case of our recursion. We want to stop if we hit a
+ // pixel we have already visited.
+ if (out_img[r][c] == on_pixel)
+ continue;
+
+ out_img[r][c] = on_pixel;
+
+ // put the neighbors of this pixel on the stack if they are bright enough
+ if (r-1 >= 0)
+ {
+ if (pos < size && get_pixel_intensity(in_img[r-1][c]) >= lower_thresh)
+ {
+ rstack[pos] = r-1;
+ cstack[pos] = c;
+ ++pos;
+ }
+ if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r-1][c-1]) >= lower_thresh)
+ {
+ rstack[pos] = r-1;
+ cstack[pos] = c-1;
+ ++pos;
+ }
+ if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r-1][c+1]) >= lower_thresh)
+ {
+ rstack[pos] = r-1;
+ cstack[pos] = c+1;
+ ++pos;
+ }
+ }
+
+ if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r][c-1]) >= lower_thresh)
+ {
+ rstack[pos] = r;
+ cstack[pos] = c-1;
+ ++pos;
+ }
+ if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r][c+1]) >= lower_thresh)
+ {
+ rstack[pos] = r;
+ cstack[pos] = c+1;
+ ++pos;
+ }
+
+ if (r+1 < in_img.nr())
+ {
+ if (pos < size && get_pixel_intensity(in_img[r+1][c]) >= lower_thresh)
+ {
+ rstack[pos] = r+1;
+ cstack[pos] = c;
+ ++pos;
+ }
+ if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r+1][c-1]) >= lower_thresh)
+ {
+ rstack[pos] = r+1;
+ cstack[pos] = c-1;
+ ++pos;
+ }
+ if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r+1][c+1]) >= lower_thresh)
+ {
+ rstack[pos] = r+1;
+ cstack[pos] = c+1;
+ ++pos;
+ }
+ }
+
+ } // end while (pos >= 0)
+
+ }
+ else
+ {
+ out_img[r][c] = off_pixel;
+ }
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THRESHOLDINg_
+
diff --git a/ml/dlib/dlib/image_transforms/thresholding_abstract.h b/ml/dlib/dlib/image_transforms/thresholding_abstract.h
new file mode 100644
index 000000000..e7c1e8826
--- /dev/null
+++ b/ml/dlib/dlib/image_transforms/thresholding_abstract.h
@@ -0,0 +1,139 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THRESHOLDINg_ABSTRACT_
+#ifdef DLIB_THRESHOLDINg_ABSTRACT_
+
+#include "../pixel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const unsigned char on_pixel = 255;
+ const unsigned char off_pixel = 0;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void threshold_image (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type thresh
+ );
+ /*!
+ requires
+ - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale == true
+ - pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false
+ ensures
+ - #out_img == the thresholded version of in_img (in_img is converted to a grayscale
+ intensity image if it is color). Pixels in in_img with grayscale values >= thresh
+ have an output value of on_pixel and all others have a value of off_pixel.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+ template <
+ typename image_type
+ >
+ void threshold_image (
+ image_type& img,
+ typename pixel_traits<typename image_traits<image_type>::pixel_type>::basic_pixel_type thresh
+ );
+ /*!
+ requires
+ - it is valid to call threshold_image(img,img,thresh);
+ ensures
+ - calls threshold_image(img,img,thresh);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void auto_threshold_image (
+ const in_image_type& in_img,
+ out_image_type& out_img
+ );
+ /*!
+ requires
+ - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_traits<in_image_type>::pixel_type>::max() <= 65535
+ - pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false
+ - pixel_traits<typename image_traits<in_image_type>::pixel_type>::is_unsigned == true
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale == true
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::is_unsigned == true
+ ensures
+ - #out_img == the thresholded version of in_img (in_img is converted to a grayscale
+ intensity image if it is color). Pixels in in_img with grayscale values >= thresh
+ have an output value of on_pixel and all others have a value of off_pixel.
+ - The thresh value used is determined by performing a k-means clustering
+ on the input image histogram with a k of 2. The point between the two
+ means found is used as the thresh value.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+ template <
+ typename image_type
+ >
+ void auto_threshold_image (
+ image_type& img
+ );
+ /*!
+ requires
+ - it is valid to call auto_threshold_image(img,img);
+ ensures
+ - calls auto_threshold_image(img,img);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void hysteresis_threshold (
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type lower_thresh,
+ typename pixel_traits<typename image_traits<in_image_type>::pixel_type>::basic_pixel_type upper_thresh
+ );
+ /*!
+ requires
+ - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::grayscale == true
+ - pixel_traits<typename image_traits<in_image_type>::pixel_type>::has_alpha == false
+ - pixel_traits<typename image_traits<out_image_type>::pixel_type>::has_alpha == false
+ - lower_thresh <= upper_thresh
+ - is_same_object(in_img, out_img) == false
+ ensures
+ - #out_img == the hysteresis thresholded version of in_img (in_img is converted to a
+ grayscale intensity image if it is color). Pixels in in_img with grayscale
+ values >= upper_thresh have an output value of on_pixel and all others have a
+ value of off_pixel unless they are >= lower_thresh and are connected to a pixel
+ with a value >= upper_thresh, in which case they have a value of on_pixel. Here
+ pixels are connected if there is a path between them composed of pixels that
+ would receive an output of on_pixel.
+ - #out_img.nc() == in_img.nc()
+ - #out_img.nr() == in_img.nr()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THRESHOLDINg_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/interfaces/cmd_line_parser_option.h b/ml/dlib/dlib/interfaces/cmd_line_parser_option.h
new file mode 100644
index 000000000..797dcd2e6
--- /dev/null
+++ b/ml/dlib/dlib/interfaces/cmd_line_parser_option.h
@@ -0,0 +1,107 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_OPTIOn_
+#define DLIB_CMD_LINE_PARSER_OPTIOn_
+
+#include <string>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT
+ >
+ class cmd_line_parser_option
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ None of the functions in cmd_line_parser_option will invalidate
+ pointers or references to internal data when called.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a command line option.
+ !*/
+
+ public:
+
+ typedef charT char_type;
+ typedef std::basic_string<charT> string_type;
+
+ virtual ~cmd_line_parser_option (
+ ) = 0;
+
+ virtual const string_type& name (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the name of this option
+ !*/
+
+ virtual const string_type& group_name (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the name of the group this option is in. If no group was set for
+ this option then this function returns "".
+ !*/
+
+ virtual const string_type& description (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the description for this option
+ !*/
+
+ virtual unsigned long number_of_arguments(
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of arguments for this option
+ !*/
+
+ virtual unsigned long count(
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of times this option appears on the command line.
+ !*/
+
+ virtual const string_type& argument (
+ unsigned long arg = 0,
+ unsigned long N = 0
+ ) const = 0;
+ /*!
+ requires
+ - arg < number_of_arguments()
+ - N < count()
+ ensures
+ - returns the arg-th argument to the Nth occurrence of this
+ option on the command line.
+ !*/
+
+ inline operator bool (
+ ) const { return count() > 0; }
+ /*!
+ ensures
+ - returns true if this option appears on the command line at all
+ !*/
+
+ protected:
+
+ // restricted functions
+ cmd_line_parser_option& operator=(const cmd_line_parser_option&){return *this;}
+
+ };
+
+ // destructor does nothing
+ template < typename charT >
+ cmd_line_parser_option<charT>::~cmd_line_parser_option() {}
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CMD_LINE_PARSER_OPTIOn_
+
diff --git a/ml/dlib/dlib/interfaces/enumerable.h b/ml/dlib/dlib/interfaces/enumerable.h
new file mode 100644
index 000000000..e8f5ae78c
--- /dev/null
+++ b/ml/dlib/dlib/interfaces/enumerable.h
@@ -0,0 +1,130 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ENUMERABLe_INTERFACE_
+#define DLIB_ENUMERABLe_INTERFACE_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class enumerable
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ - if (at_start()) then
+ - all pointers and references to data returned via element() are
+ invalid.
+ - calling move_next() or reset() invalidates pointers and references to
+ data returned via element() and only data returned via element().
+ - calling at_start(), current_element_valid(), size(), or element()
+ does NOT invalidate pointers or references to any internal data.
+
+ INITIAL VALUE
+ current_element_valid() == false
+ at_start() == true
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represent an interface for iterating through the
+ elements in a container. It starts out one before the first element
+ in the container.
+
+
+ EXAMPLE: The following loops though all elements in the container
+ and prints them to cout.
+
+ container.reset();
+ while(container.move_next()) {
+ cout << container.element();
+ }
+ !*/
+
+ public:
+ typedef T type;
+
+ inline virtual ~enumerable(
+ ) = 0;
+
+ virtual bool at_start (
+ ) const = 0;
+ /*!
+ ensures
+ - returns true if *this represents one position before the first element
+ in the container (this would also make the current element invalid)
+ else returns false
+ !*/
+
+ virtual void reset (
+ ) const = 0;
+ /*!
+ ensures
+ - #current_element_valid() == false
+ - #at_start() == true
+ !*/
+
+ virtual bool current_element_valid (
+ ) const = 0;
+ /*!
+ ensures
+ - returns true if we are currently at a valid element else
+ returns false
+ !*/
+
+ virtual const T& element (
+ ) const = 0;
+ /*!
+ requires
+ - current_element_valid() == true
+ ensures
+ - returns a const reference to the current element
+ !*/
+
+ virtual T& element (
+ ) = 0;
+ /*!
+ requires
+ - current_element_valid() == true
+ ensures
+ - returns a non-const reference to the current element
+ !*/
+
+ virtual bool move_next (
+ ) const = 0;
+ /*!
+ ensures
+ - moves to the next element. i.e. #element() will now
+ return the next element in the container
+ - the return value will be equal to #current_element_valid()
+ - #at_start() == false
+
+ - returns true if there is another element
+ - returns false if there are no more elements in the container
+ !*/
+
+ virtual size_t size (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of elements in *this
+ !*/
+
+ protected:
+
+ // restricted functions
+ enumerable& operator=(const enumerable&) {return *this;} // no assignment operator
+
+ };
+
+ // destructor does nothing
+ template <typename T>
+ enumerable<T>::~enumerable() {}
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ENUMERABLe_INTERFACE_
+
diff --git a/ml/dlib/dlib/interfaces/map_pair.h b/ml/dlib/dlib/interfaces/map_pair.h
new file mode 100644
index 000000000..64310152e
--- /dev/null
+++ b/ml/dlib/dlib/interfaces/map_pair.h
@@ -0,0 +1,74 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAP_PAIr_INTERFACE_
+#define DLIB_MAP_PAIr_INTERFACE_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T1,
+ typename T2
+ >
+ class map_pair
+ {
+ /*!
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ None of the functions in map_pair will invalidate
+ pointers or references to internal data when called.
+
+ WHAT THIS OBJECT REPRESENTS
+ this object is used to return the key/value pair used in the
+ map and hash_map containers when using the enumerable interface.
+
+ note that the enumerable interface is defined in
+ interfaces/enumerable.h
+ !*/
+
+ public:
+ typedef T1 key_type;
+ typedef T2 value_type;
+
+ virtual ~map_pair(
+ )=0;
+
+ virtual const T1& key(
+ ) const =0;
+ /*!
+ ensures
+ - returns a const reference to the key
+ !*/
+
+ virtual const T2& value(
+ ) const =0;
+ /*!
+ ensures
+ - returns a const reference to the value associated with key
+ !*/
+
+ virtual T2& value(
+ )=0;
+ /*!
+ ensures
+ - returns a non-const reference to the value associated with key
+ !*/
+
+ protected:
+
+ // restricted functions
+ map_pair<T1,T2>& operator=(const map_pair<T1,T2>&) {return *this;} // no assignment operator
+
+ };
+
+ // destructor does nothing
+ template <typename T1,typename T2>
+ map_pair<T1,T2>::~map_pair () {}
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAP_PAIr_INTERFACE_
+
diff --git a/ml/dlib/dlib/interfaces/remover.h b/ml/dlib/dlib/interfaces/remover.h
new file mode 100644
index 000000000..f2098cba6
--- /dev/null
+++ b/ml/dlib/dlib/interfaces/remover.h
@@ -0,0 +1,220 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REMOVER_KERNEl_INTERFACE_
+#define DLIB_REMOVER_KERNEl_INTERFACE_
+
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class remover
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T is swappable by a global swap() and
+ T must have a default constructor
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ The size() function does not invalidate pointers or
+ references to internal data. All other functions have no such
+ guarantee.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some generalized interface for removing
+ single items from container classes.
+ !*/
+
+
+ public:
+ typedef T type;
+
+ virtual ~remover(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released.
+ !*/
+
+ virtual void remove_any (
+ T& item
+ ) = 0;
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - removes an element from *this and swaps it into item.
+ - if (*this implements the enumerable interface) then
+ - #at_start() == true
+ !*/
+
+ virtual size_t size (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of elements in *this
+ !*/
+
+ protected:
+
+ // restricted functions
+ remover& operator=(const remover&) {return *this;} // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ class asc_remover : public remover<T>
+ {
+ /*!
+ REQUIREMENTS ON T
+ T is swappable by a global swap() and
+ T must have a default constructor and
+ T must be comparable by compare where compare is a functor compatible with std::less
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the same thing as remover except
+ that remove_any() will remove elements in ascending order
+ according to the compare functor.
+ !*/
+ public:
+ typedef compare compare_type;
+
+ protected:
+ // restricted functions
+ asc_remover& operator=(const asc_remover&) {return *this;} // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range
+ >
+ class pair_remover
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain is swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range is swappable by a global swap() and
+ range must have a default constructor
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ The size() function does not invalidate pointers or
+ references to internal data. All other functions have no such
+ guarantee.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some generalized interface for removing
+ pairs from container classes which enforce some kind of pairing on
+ the elements that they contain.
+ !*/
+
+ public:
+ typedef domain domain_type;
+ typedef range range_type;
+
+ virtual ~pair_remover(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released.
+ !*/
+
+ virtual void remove_any (
+ domain& d,
+ range& r
+ ) = 0;
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - removes an element from the domain of *this and swaps it
+ into d.
+ - removes the element in *this's range that is associated
+ with #d and swaps it into r.
+ - if (*this implements the enumerable interface) then
+ - #at_start() == true
+ !*/
+
+ virtual size_t size (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of elements in *this
+ !*/
+
+
+ protected:
+
+ // restricted functions
+ pair_remover& operator=(const pair_remover&) {return *this;} // assignment operator
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ class asc_pair_remover : public pair_remover<domain,range>
+ {
+ /*!
+ REQUIREMENTS ON domain
+ domain is swappable by a global swap() and
+ domain must have a default constructor and
+ domain must be comparable by compare where compare is a functor compatible with std::less
+
+ REQUIREMENTS ON range
+ range is swappable by a global swap() and
+ range must have a default constructor
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the same thing as pair_remover except
+ that remove_any() will remove domain elements in ascending
+ order according to the compare functor.
+ !*/
+ public:
+ typedef compare compare_type;
+
+ protected:
+ // restricted functions
+ asc_pair_remover& operator=(const asc_pair_remover&) {return *this;} // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ // destructor does nothing
+ template <typename T>
+ remover<T>::~remover() {}
+
+ // destructor does nothing
+ template <typename domain, typename range>
+ pair_remover<domain,range>::~pair_remover() {}
+
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_REMOVER_KERNEl_INTERFACE_
+
diff --git a/ml/dlib/dlib/iomanip b/ml/dlib/dlib/iomanip
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/iomanip
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/iosfwd b/ml/dlib/dlib/iosfwd
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/iosfwd
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/iosockstream.h b/ml/dlib/dlib/iosockstream.h
new file mode 100644
index 000000000..da1b9b505
--- /dev/null
+++ b/ml/dlib/dlib/iosockstream.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IOSOCkSTREAM_H_h_
+#define DLIB_IOSOCkSTREAM_H_h_
+
+#include "iosockstream/iosockstream.h"
+
+
+#endif // DLIB_IOSOCkSTREAM_H_h_
+
+
diff --git a/ml/dlib/dlib/iosockstream/iosockstream.h b/ml/dlib/dlib/iosockstream/iosockstream.h
new file mode 100644
index 000000000..e49d2e37f
--- /dev/null
+++ b/ml/dlib/dlib/iosockstream/iosockstream.h
@@ -0,0 +1,171 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IOSOCKSTrEAM_Hh_
+#define DLIB_IOSOCKSTrEAM_Hh_
+
+#include "iosockstream_abstract.h"
+
+#include <iostream>
+#include <memory>
+
+#include "../sockstreambuf.h"
+#include "../timeout.h"
+
+#ifdef _MSC_VER
+// Disable the warning about inheriting from std::iostream 'via dominance' since this warning is a warning about
+// visual studio conforming to the standard and is ignorable.
+// See http://connect.microsoft.com/VisualStudio/feedback/details/733720/inheriting-from-std-fstream-produces-c4250-warning
+// for further details if interested.
+#pragma warning(disable : 4250)
+#endif // _MSC_VER
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class iosockstream : public std::iostream
+ {
+ public:
+
+ iosockstream(
+ ) :
+ std::iostream(0)
+ {
+ }
+
+ iosockstream(
+ const network_address& addr
+ ) :
+ std::iostream(0)
+ {
+ open(addr);
+ }
+
+ iosockstream(
+ const network_address& addr,
+ unsigned long timeout
+ ) :
+ std::iostream(0)
+ {
+ open(addr, timeout);
+ }
+
+ ~iosockstream()
+ {
+ close();
+ }
+
+ void open (
+ const network_address& addr
+ )
+ {
+ auto_mutex lock(class_mutex);
+ close();
+ con.reset(connect(addr));
+ buf.reset(new sockstreambuf(con.get()));
+ // Note that we use the sockstreambuf's ability to autoflush instead of
+ // telling the iostream::tie() function to tie the stream to itself even though
+ // that should work fine. The reason we do it this way is because there is a
+ // bug in visual studio 2012 that causes a program to crash when a stream is
+ // tied to itself and then used. See
+ // http://connect.microsoft.com/VisualStudio/feedback/details/772293/tying-a-c-iostream-object-to-itself-causes-a-stack-overflow-in-visual-studio-2012
+ // for further details.
+ buf->flush_output_on_read();
+ rdbuf(buf.get());
+ clear();
+ }
+
+ void open (
+ const network_address& addr,
+ unsigned long timeout
+ )
+ {
+ auto_mutex lock(class_mutex);
+ close(timeout);
+ con.reset(connect(addr.host_address, addr.port, timeout));
+ buf.reset(new sockstreambuf(con.get()));
+ buf->flush_output_on_read();
+ rdbuf(buf.get());
+ clear();
+ }
+
+ void close(
+ unsigned long timeout = 10000
+ )
+ {
+ auto_mutex lock(class_mutex);
+ rdbuf(0);
+ try
+ {
+ if (buf)
+ {
+ dlib::timeout t(*con,&connection::shutdown,timeout);
+
+ // This will flush the sockstreambuf and also destroy it.
+ buf.reset();
+
+ if(con->shutdown_outgoing())
+ {
+ // there was an error so just close it now and return
+ con->shutdown();
+ }
+ else
+ {
+ char junk[100];
+ // wait for the other end to close their side
+ while (con->read(junk,sizeof(junk)) > 0);
+ }
+ }
+ }
+ catch (...)
+ {
+ con.reset();
+ throw;
+ }
+ con.reset();
+ }
+
+ void terminate_connection_after_timeout (
+ unsigned long timeout
+ )
+ {
+ auto_mutex lock(class_mutex);
+ if (con)
+ {
+ con_timeout.reset(new dlib::timeout(*this,&iosockstream::terminate_connection,timeout,con));
+ }
+ }
+
+ void shutdown (
+ )
+ {
+ auto_mutex lock(class_mutex);
+ if (con)
+ con->shutdown();
+ }
+
+ private:
+
+ void terminate_connection(
+ std::shared_ptr<connection> thecon
+ )
+ {
+ thecon->shutdown();
+ }
+
+ std::unique_ptr<timeout> con_timeout;
+ rmutex class_mutex;
+ std::shared_ptr<connection> con;
+ std::unique_ptr<sockstreambuf> buf;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_IOSOCKSTrEAM_Hh_
+
+
diff --git a/ml/dlib/dlib/iosockstream/iosockstream_abstract.h b/ml/dlib/dlib/iosockstream/iosockstream_abstract.h
new file mode 100644
index 000000000..2328f426e
--- /dev/null
+++ b/ml/dlib/dlib/iosockstream/iosockstream_abstract.h
@@ -0,0 +1,171 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_
+#ifdef DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_
+
+#include "../sockstreambuf/sockstreambuf_abstract.h"
+
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class iosockstream : public std::iostream
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an iostream object that reads/writes from a TCP network connection.
+
+ Note that any attempt to read from this stream will automatically flush the
+ stream's output buffers.
+
+ THREAD SAFETY
+ It is not safe for multiple threads to make concurrent accesses to the same
+ instance of this object (except for calls to shutdown() which are always
+ threadsafe). Therefore, you should mutex lock an instance of this object
+ if you need to touch it from multiple threads.
+ !*/
+
+ public:
+
+ iosockstream(
+ );
+ /*!
+ ensures
+ - #good() == false
+ !*/
+
+ iosockstream(
+ const network_address& addr
+ );
+ /*!
+ ensures
+ - Attempts to connect to the given network address.
+ - Calling this constructor is equivalent to calling the default constructor
+ and then invoking open(addr).
+ - #good() == true
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection.
+ !*/
+
+ iosockstream(
+ const network_address& addr,
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - Attempts to connect to the given network address.
+ - Calling this constructor is equivalent to calling the default constructor
+ and then invoking open(addr, timeout).
+ - #good() == true
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection or if timeout milliseconds elapses before the
+ connect is successful.
+ !*/
+
+ ~iosockstream(
+ );
+ /*!
+ ensures
+ - Invokes close() before destructing the stream. Therefore, any open
+ connection will be gracefully closed using the default timeout time.
+ This also means any data in the stream will be flushed to the connection.
+ !*/
+
+ void open (
+ const network_address& addr
+ );
+ /*!
+ ensures
+ - This object will attempt to create a TCP connection with the remote host
+ indicated by addr.
+ - Any previous connection in this iosockstream is closed by calling close()
+ before we make any new connection.
+ - #good() == true
+ (i.e. the error flags are reset by calling open())
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection.
+ !*/
+
+ void open (
+ const network_address& addr,
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - This object will attempt to create a TCP connection with the remote host
+ indicated by addr.
+ - Any previous connection in this iosockstream is closed by calling close()
+ before we make any new connection.
+ - #good() == true
+ (i.e. the error flags are reset by calling open())
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection or if timeout milliseconds elapses before the
+ connect is successful.
+ !*/
+
+ void close(
+ unsigned long timeout = 10000
+ );
+ /*!
+ ensures
+ - #good() == false
+ - if (there is an active TCP connection) then
+ - Flushes any data buffered in the output part of the stream
+ to the connection.
+ - Performs a proper graceful close (i.e. like dlib::close_gracefully()).
+ - Will only wait timeout milliseconds for the buffer flush and graceful
+ close to finish before the connection is terminated forcefully.
+ Therefore, close() will only block for at most timeout milliseconds.
+ !*/
+
+ void terminate_connection_after_timeout (
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - if (there is an active TCP connection) then
+ - Any operations on this TCP connection will return error or
+ end-of-file once timeout milliseconds have elapsed from this call to
+ terminate_connection_after_timeout(). This is true unless another
+ call to terminate_connection_after_timeout() is made which gives a
+ new time. In this case, the previous call is forgotten and the
+ timeout is reset.
+ - This timeout only applies to the current TCP connection. That is, if
+ the iosockstream is closed and a new connection is established, any
+ previous timeouts setup by terminate_connection_after_timeout() do
+ not apply.
+ - else
+ - This function has no effect on this object.
+ !*/
+
+ void shutdown (
+ );
+ /*!
+ ensures
+ - Immediately closes the TCP connection and causes all I/O operations on
+ this object to return an error.
+ - It is safe to call this function from any thread, therefore, you can use
+ it to signal when you want a connection to terminate from another thread.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/iostream b/ml/dlib/dlib/iostream
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/iostream
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/is_kind.h b/ml/dlib/dlib/is_kind.h
new file mode 100644
index 000000000..e8dcb6320
--- /dev/null
+++ b/ml/dlib/dlib/is_kind.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IS_KINd_H_
+#define DLIB_IS_KINd_H_
+
+#include <vector>
+
+namespace dlib
+{
+ /*!
+ This file contains a set of templates that enable you to determine if
+ a given type implements an abstract interface defined in one of the
+ dlib *_abstract.h files.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct default_is_kind_value { static const bool value = false; };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_graph : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of graph/graph_kernel_abstract.h) then
+ - is_graph<T>::value == true
+ - else
+ - is_graph<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_directed_graph : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of directed_graph/directed_graph_kernel_abstract.h) then
+ - is_directed_graph<T>::value == true
+ - else
+ - is_directed_graph<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename helper = void>
+ struct is_matrix : public default_is_kind_value
+ {
+ /*!
+ - if (T is some kind of matrix expression from the matrix/matrix_exp_abstract.h component) then
+ - is_matrix<T>::value == true
+ - else
+ - is_matrix<T>::value == false
+ !*/
+
+ // Don't set the helper to anything. Just let it be void.
+ ASSERT_ARE_SAME_TYPE(helper,void);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_array2d : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of array2d/array2d_kernel_abstract.h) then
+ - is_array2d<T>::value == true
+ - else
+ - is_array2d<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_array : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of array/array_kernel_abstract.h) then
+ - is_array<T>::value == true
+ - else
+ - is_array<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_std_vector : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of the standard C++ std::vector object) then
+ - is_std_vector<T>::value == true
+ - else
+ - is_std_vector<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_pair : public default_is_kind_value
+ {
+ /*!
+ - if (T is a std::pair object) then
+ - is_std_vector<T>::value == true
+ - else
+ - is_std_vector<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_rand : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of rand/rand_kernel_abstract.h) then
+ - is_rand<T>::value == true
+ - else
+ - is_rand<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_config_reader : public default_is_kind_value
+ {
+ /*!
+ - if (T is an implementation of config_reader/config_reader_kernel_abstract.h) then
+ - is_config_reader<T>::value == true
+ - else
+ - is_config_reader<T>::value == false
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Implementation details
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ struct is_std_vector<std::vector<T,alloc> > { const static bool value = true; };
+ template <typename T> struct is_std_vector<T&> { const static bool value = is_std_vector<T>::value; };
+ template <typename T> struct is_std_vector<const T&>{ const static bool value = is_std_vector<T>::value; };
+ template <typename T> struct is_std_vector<const T> { const static bool value = is_std_vector<T>::value; };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ struct is_pair<std::pair<T,U> > { const static bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IS_KINd_H_
+
diff --git a/ml/dlib/dlib/istream b/ml/dlib/dlib/istream
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/istream
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/java/CMakeLists.txt b/ml/dlib/dlib/java/CMakeLists.txt
new file mode 100644
index 000000000..4d66a513c
--- /dev/null
+++ b/ml/dlib/dlib/java/CMakeLists.txt
@@ -0,0 +1,32 @@
+
+cmake_minimum_required (VERSION 2.8.12)
+project (myproject)
+set(java_package_name net.dlib)
+set(source_files
+ )
+
+include(../cmake_utils/release_build_by_default)
+
+include_directories(
+ .
+ )
+
+# Additional dependencies
+#add_subdirectory(../../dlib dlib_build)
+#set(additional_link_libraries dlib::dlib)
+
+# Tell swig to put the output files (the shared library and .jar) into the local folder.
+set(install_target_output_folder .)
+
+# Alternatively, instead of using install_target_output_folder, you can tell
+# cmake to output the shared library, java source files, and the jar to
+# separate output folders. These commands would put them into folders thelib,
+# thesrc, and thejar, respectively.
+#set(install_shared_library_output_folder thelib)
+#set(install_java_source_output_folder thesrc)
+#set(install_jar_output_folder thejar)
+
+
+include(cmake_swig_jni)
+
+
diff --git a/ml/dlib/dlib/java/cmake_swig_jni b/ml/dlib/dlib/java/cmake_swig_jni
new file mode 100644
index 000000000..d74dd60ec
--- /dev/null
+++ b/ml/dlib/dlib/java/cmake_swig_jni
@@ -0,0 +1,265 @@
+# This file is used to create SWIG based JNI interfaces to C++ code. You use
+# it by defining some CMake variables and then include(cmake_swig_jni). You
+# would make a CMakeLists.txt file that looks like the following:
+#
+# cmake_minimum_required (VERSION 2.8.12)
+# project (example)
+# set(java_package_name "org.mycompany")
+# set(source_files
+# your_cpp_source.cpp
+# more_cpp_source.cpp
+# )
+#
+# ### We might need to link our code to some other C++ library like dlib. You
+# ### can do that by setting additional_link_libraries. Here is an example of
+# ### linking to dlib:
+# include(../../dlib/dlib/cmake)
+# set(additional_link_libraries dlib::dlib)
+#
+# ### Tell swig to put the output files into the parent folder of your CMakeLists.txt
+# ### file when you run make install.
+# set(install_target_output_folder ..)
+# include(cmake_swig_jni)
+#
+# ### Alternatively, instead of using install_target_output_folder, you can tell
+# ### cmake to output the shared library, java source files, and the jar to
+# ### separate output folders. These commands would put them into folders
+# ### thelib, thesrc, and thejar, respectively.
+# # set(install_shared_library_output_folder thelib)
+# # set(install_java_source_output_folder thesrc)
+# # set(install_jar_output_folder thejar)
+
+
+
+
+
+
+################################################################################
+################################################################################
+# IMPLEMENTATION DETAILS
+################################################################################
+################################################################################
+
+cmake_minimum_required (VERSION 2.8.12)
+
+include(${CMAKE_CURRENT_LIST_DIR}/../cmake_utils/use_cpp_11.cmake)
+
+# This block of code tries to figure out what the JAVA_HOME environment
+# variable should be by looking at the folder that contains the java
+# executable.
+if (NOT DEFINED ENV{JAVA_HOME})
+ message(STATUS "JAVA_HOME environment variable not set, trying to guess it...")
+ find_program(JAVA_EXECUTABLE java)
+ # Resolve symbolic links, hopefully this will give us a path in the proper
+ # java home directory.
+ get_filename_component(JAVA_EXECUTABLE ${JAVA_EXECUTABLE} REALPATH)
+ # Pick out the parent directories
+ get_filename_component(JAVA_PATH1 ${JAVA_EXECUTABLE} PATH)
+ get_filename_component(JAVA_PATH2 ${JAVA_PATH1} PATH)
+ get_filename_component(JAVA_PATH3 ${JAVA_PATH2} PATH)
+ # and search them for include/jni.h. If we find that then we probably have
+ # a good java home candidate.
+ find_path(AUTO_JAVA_HOME include/jni.h
+ PATHS
+ ${JAVA_PATH1}
+ ${JAVA_PATH2}
+ ${JAVA_PATH3}
+ "C:/Program Files/Java/jdk*"
+ "C:/Program Files (x86)/Java/jdk*"
+ )
+
+ if (AUTO_JAVA_HOME)
+ set(ENV{JAVA_HOME} ${AUTO_JAVA_HOME})
+ message(STATUS "Using JAVA_HOME OF " ${AUTO_JAVA_HOME})
+ else()
+ message(FATAL_ERROR "Couldn't find a folder for JAVA_HOME. You must set the JAVA_HOME environment variable before running CMake.")
+ endif()
+endif()
+
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib")
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE "${CMAKE_CURRENT_BINARY_DIR}/lib")
+
+find_package(SWIG REQUIRED)
+find_package(Java REQUIRED)
+find_package(JNI REQUIRED)
+include(UseSWIG)
+
+macro (add_global_switch def_name )
+ if (NOT CMAKE_CXX_FLAGS MATCHES "${def_name}")
+ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${def_name}"
+ CACHE STRING "Flags used by the compiler during all C++ builds."
+ FORCE)
+ endif ()
+ if (NOT CMAKE_C_FLAGS MATCHES "${def_name}")
+ set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${def_name}"
+ CACHE STRING "Flags used by the compiler during all C builds."
+ FORCE)
+ endif ()
+endmacro()
+
+# SWIG doesn't work if optimizations are enabled and strict aliasing is not
+# turned off. This is a little wonky but it's how SWIG is.
+if (CMAKE_COMPILER_IS_GNUCXX)
+ add_definitions(-fno-strict-aliasing)
+endif()
+if (UNIX)
+ # we need to make sure all the code is compiled with -fPIC. In particular,
+ # it's important that all the code for the whole project is, not just the
+ # stuff immediately compiled by us in this cmake file. So we add -fPIC to
+ # the top level cmake flags variables.
+ add_global_switch(-fPIC)
+endif()
+
+set(dlib_root_path ${CMAKE_CURRENT_LIST_DIR}/../../)
+
+string(REGEX REPLACE "\\." "/" package_path ${java_package_name})
+string(REGEX REPLACE "\\..*" "" package_root_name ${java_package_name})
+
+include_directories(${dlib_root_path})
+
+set(CMAKE_SWIG_FLAGS -package ${java_package_name} -I${dlib_root_path})
+set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src/${package_path})
+
+set(output_library_name ${PROJECT_NAME})
+
+# Create the swig.i interface file that swig will run on. We do it here in
+# the cmake script because this lets us automatically include the correct
+# output library name into the call to System.loadLibrary().
+FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/swig.i
+ "
+ // Put the global functions in our api into a java class called global.
+ %module global
+
+ %{
+ #include <exception>
+ #include <stdexcept>
+ static JavaVM *cached_jvm = 0;
+
+ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void *reserved) {
+ cached_jvm = jvm;
+ return JNI_VERSION_1_6;
+ }
+
+ static JNIEnv * JNI_GetEnv() {
+ JNIEnv *env;
+ jint rc = cached_jvm->GetEnv((void **)&env, JNI_VERSION_1_6);
+ if (rc == JNI_EDETACHED)
+ throw std::runtime_error(\"current thread not attached\");
+ if (rc == JNI_EVERSION)
+ throw std::runtime_error(\"jni version not supported\");
+ return env;
+ }
+
+ #include \"swig_api.h\"
+ %}
+
+ // Convert all C++ exceptions into java.lang.Exception
+ %exception {
+ try {
+ $action
+ } catch(std::exception& e) {
+ jclass clazz = jenv->FindClass(\"java/lang/Exception\");
+ jenv->ThrowNew(clazz, e.what());
+ return $null;
+ }
+ }
+
+ %pragma(java) jniclasscode=%{
+ static { System.loadLibrary(\"${output_library_name}\"); }
+ %}
+
+ %include \"swig_api.h\"
+ "
+)
+
+# There is a bug in CMake's Swig scripts that causes the build to fail if the
+# binary folder doesn't contain a folder with the same name as the binary dir.
+# So we make a subfolder of the same name to avoid that bug.
+get_filename_component(binary_dir_name "${CMAKE_CURRENT_BINARY_DIR}" NAME)
+FILE(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${binary_dir_name}")
+
+set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/swig.i PROPERTIES CPLUSPLUS ON)
+swig_add_module(${output_library_name} java ${CMAKE_CURRENT_BINARY_DIR}/swig.i ${source_files})
+enable_cpp11_for_target(${output_library_name})
+
+include_directories(${JNI_INCLUDE_DIRS})
+swig_link_libraries(${output_library_name} ${additional_link_libraries})
+
+# Things to delete when "make clean" is run.
+set(clean_files
+ ${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled
+ ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src
+ )
+set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${clean_files}")
+
+# Compile the java files into a jar file and stick it in the lib folder. Also, one problem
+# with this cmake setup is that it doesn't know that modifications to swig_api.h mean that
+# swig.i is invalidated and thus swig needs to be rerun. So here we also touch swig.i
+# every time we build to make it always out of date and force swig to run on each build,
+# thus avoiding the stale swig outputs problem that would otherwise irritate people who
+# modify something and attempt to rebuild.
+add_custom_command(TARGET ${output_library_name}
+ POST_BUILD
+ COMMAND cmake -E echo "compiling Java files..."
+ COMMAND cmake -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled"
+ COMMAND ${Java_JAVAC_EXECUTABLE} ${CMAKE_SWIG_OUTDIR}/*.java -d "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled"
+ COMMAND cmake -E echo "Making jar file..."
+ COMMAND ${Java_JAR_EXECUTABLE} cvf "${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar" -C "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled" ${package_root_name}
+ COMMAND cmake -E touch swig.i
+ )
+
+
+# Determine the path to our CMakeLists.txt file.
+# There is either a bug (or break in compatability maybe) between versions
+# of cmake that cause the or expression in this regular expression to be
+# necessary.
+string(REGEX REPLACE "(cmake_swig_jni|CMakeLists.txt)$" "" base_path ${CMAKE_PARENT_LIST_FILE})
+
+#if the including cmake script set the install_target_output_folder variable
+#then make it so we install the compiled library and jar into that folder
+if (install_target_output_folder)
+ # The directory we will write the output files to.
+ set(install_dir "${base_path}${install_target_output_folder}")
+ set(CMAKE_INSTALL_PREFIX "${install_dir}")
+ set(CMAKE_INSTALL_SYSTEM_RUNTIME_DESTINATION "${install_dir}")
+ install(TARGETS ${output_library_name}
+ DESTINATION "${install_dir}"
+ )
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar
+ DESTINATION "${install_dir}"
+ )
+endif()
+
+if (install_shared_library_output_folder)
+ set(install_dir "${base_path}${install_shared_library_output_folder}")
+ install(TARGETS ${output_library_name}
+ DESTINATION "${install_dir}"
+ )
+endif()
+
+if (install_java_source_output_folder)
+ set(install_dir "${base_path}${install_java_source_output_folder}")
+ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src/${package_root_name}
+ DESTINATION "${install_dir}"
+ )
+endif()
+
+if (install_jar_output_folder)
+ set(install_dir "${base_path}${install_jar_output_folder}")
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar
+ DESTINATION "${install_dir}"
+ )
+endif()
+
+
+# Copy any system libraries to the output folder. This really only matters on
+# windows where it's good to have the visual studio runtime show up in the lib
+# folder so that you don't forget to include it in your binary distribution.
+INCLUDE(InstallRequiredSystemLibraries)
+foreach (file_i ${CMAKE_INSTALL_SYSTEM_RUNTIME_LIBS})
+ add_custom_command(TARGET ${output_library_name}
+ POST_BUILD
+ COMMAND cmake -E copy ${file_i} "${CMAKE_CURRENT_BINARY_DIR}/lib/"
+ )
+endforeach()
+
diff --git a/ml/dlib/dlib/java/java_array.h b/ml/dlib/dlib/java/java_array.h
new file mode 100644
index 000000000..6c4d5f03d
--- /dev/null
+++ b/ml/dlib/dlib/java/java_array.h
@@ -0,0 +1,605 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SWIG_JAVA_ARRAY_H_
+#define DLIB_SWIG_JAVA_ARRAY_H_
+
+
+/*
+
+ This file defines three special classes: array, array_view, and array_view_crit. An
+ array is a simple opaque handle to a java array, like a double[] array. The array_view
+ and array_view_crit objects allow you to access the contents of an array. The
+ interfaces of these objects is shown below, but for an example use, suppose you had an
+ array of int in java and you wanted to pass it to C++. You could create a C++ function
+ like this:
+
+ void my_function(const array_view<int32_t>& array);
+
+ and then within java you could call it with code like this:
+
+ int[] array = new int[100];
+ my_function(array);
+
+ and it will work just like you would expect. The array_view<int32_t> will usually result in
+ the JVM doing a copy in the background. However, you can also declare your function
+ like this:
+
+ void my_function(const array_view_crit<int32_t>& array);
+
+ and still call it the same way in java, however, using array_view_crit<int32_t> will usually
+ not result in any copying, and is therefore very fast. array_view_crit uses the JNI
+ routine GetPrimitiveArrayCritical() to get a lock on the java memory underlying the
+ array. So it will probably prevent the garbage collector from running while your
+ function is executing. The JNI documentation is somewhat vague on the limitations of
+ GetPrimitiveArrayCritical(), saying only that you shouldn't hold the lock on the array
+ for "an extended period" or call back into the JVM. Deciding whether or not this
+ matters in your application is left as an exercise for the reader.
+
+
+ There are two ways you can declare your methods if they take an array_view or
+ array_view_crit. Taking a const reference or a non-const reference. E.g.
+ void my_function(const array_view<int32_t>& array);
+ void my_function(array_view<int32_t>& array);
+ You can't declare them to be by value. The non-const version allows you to modify the
+ contents of the array and the modifications will be visible to java, as you would
+ expect. You can also make functions that take array objects directly, but that's only
+ useful if you want to store the array handle somewhere, like in a member of a long
+ lived class. You can also write functions that return arrays back to java. E.g.
+ array<int32_t> make_an_array(size_t s)
+ {
+ array<int32_t> arr(s);
+ array_view<int32_t> aview(arr);
+ // Use aview to put data into the array and generally do something useful.
+ ...
+ return arr;
+ }
+ This would create an array and return it as a java int[] array.
+
+
+ You can also of course use functions taking many arguments, as is normally the case
+ with SWIG. Finally, these classes work with the following primitive types:
+ - int16_t
+ - int32_t
+ - int64_t
+ - char (corresponding to java byte)
+ - float
+ - double
+
+
+
+
+namespace java
+{
+ template <typename T>
+ class array
+ {
+ /!*
+ WHAT THIS OBJECT REPRESENTS
+ This is a handle to a java array. I.e. a reference to an array instance in
+ java like a double[] or int[]. It doesn't do anything other than tell you
+ the size of the array and allow you to hold a reference to it.
+
+ To access the array contents, you need to create an array_view or
+ array_view_crit from the array.
+ *!/
+ public:
+ array();
+ /!*
+ ensures
+ - #size() == 0
+ - this array is a null reference, i.e. it doesn't reference any array.
+ *!/
+
+ explicit array(size_t new_size);
+ /!*
+ ensures
+ - #size() == new_size
+ - Allocates a new java array.
+ - This array is a reference to the newly allocated java array object.
+ *!/
+
+ size_t size() const;
+ /!*
+ ensures
+ - returns the number of elements in this java array.
+ *!/
+
+ void swap(array& item);
+ /!*
+ ensures
+ - swaps the state of *this and item.
+ *!/
+
+ array(const array& item);
+ array& operator= (const array& item)
+ array(array&& item);
+ array& operator= (array&& item);
+ /!*
+ ensures
+ - The array is copyable, assignable, and movable. All copies will
+ reference the same underlying array. So the copies are shallow, as is
+ normally the case with java reference semantics.
+ *!/
+ };
+
+
+
+ template <typename T>
+ class array_view
+ {
+ /!*
+ WHAT THIS OBJECT REPRESENTS
+ This is a view into a java array object. It allows you to access the
+ values stored in an array and modify them if you want to.
+
+ You should only create array_view objects locally in a function since an
+ array_view is only valid as long as the array it references exists. So
+ don't store array_view objects in the member area of a class or globally.
+ *!/
+
+ public:
+ array_view();
+ /!*
+ ensures
+ - #size() == 0
+ - #data() == nullptr
+ *!/
+
+ array_view(const array<T>& arr, bool might_be_modified=true);
+ /!*
+ ensures
+ - #size() == arr.size()
+ - #data() == a pointer to the beginning of the array data referenced by arr.
+ - When you get a view on a java array, sometimes the JVM will actually
+ give you a pointer to a copy of the original array. You therefore have
+ to tell the JVM if you modified the array when you are done using it. If
+ you say you modified it then the JVM will perform another copy from your
+ memory buffer back into the JVM. The state of might_be_modified controls
+ if we do this. So if you are going to modify the array via this
+ array_view you should set might_be_modified==true.
+ *!/
+
+ size_t size() const;
+ /!*
+ ensures
+ - returns the number of elements in this java array.
+ *!/
+
+ T* data();
+ const T* data() const;
+ /!*
+ ensures
+ - returns a pointer to the beginning of the array. Or nullptr if this is a
+ handle to null, rather than an actual array instance.
+ *!/
+
+ T* begin();
+ T* end();
+ const T* begin() const;
+ const T* end() const;
+ /!*
+ ensures
+ - returns iterators to the start and one-past-the-end of the array, as is
+ the convention for iterator ranges in C++.
+ *!/
+
+ T& operator[](size_t i);
+ const T& operator[](size_t i) const;
+ /!*
+ ensures
+ - returns data()[i]
+ *!/
+
+ private:
+ // this object is non-copyable.
+ array_view(const array_view&);
+ array_view& operator=(const array_view&);
+ };
+
+
+ template <typename T>
+ class array_view_crit
+ {
+ /!*
+ WHAT THIS OBJECT REPRESENTS
+ This is just like an array_view and has an identical interface. The only
+ difference is that we use the JNI call GetPrimitiveArrayCritical() to get a
+ critical lock on the array's memory. Therefore, using array_view_crit is
+ usually faster than array_view since it avoids any unnecessary copying back
+ and forth between the JVM.
+
+ However, this critical lock can block the JVM's garbage collector from
+ running. So don't create long lived array_view_crit objects.
+ *!/
+ };
+
+}
+*/
+
+
+
+
+
+
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// IMPLEMENTATION DETAILS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+
+
+
+
+
+
+
+namespace java
+{
+
+template <typename T>
+class array_view_base
+{
+public:
+ array_view_base() = default;
+
+ size_t size() const { return sz; }
+ T* data() { return pdata; }
+ const T* data() const { return pdata; }
+
+ T* begin() { return pdata; }
+ T* end() { return pdata+sz; }
+ const T* begin() const { return pdata; }
+ const T* end() const { return pdata+sz; }
+
+ T& operator[](size_t i) { return pdata[i]; }
+ const T& operator[](size_t i) const { return pdata[i]; }
+
+protected:
+ T* pdata = nullptr;
+ size_t sz = 0;
+
+private:
+ // this object is non-copyable
+ array_view_base(const array_view_base&);
+ array_view_base& operator=(const array_view_base&);
+
+};
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+template <typename T>
+struct find_java_array_type;
+
+template <> struct find_java_array_type<int16_t> { typedef jshortArray type; };
+template <> struct find_java_array_type<int32_t> { typedef jintArray type; };
+template <> struct find_java_array_type<int64_t> { typedef jlongArray type; };
+template <> struct find_java_array_type<char> { typedef jbyteArray type; };
+template <> struct find_java_array_type<float> { typedef jfloatArray type; };
+template <> struct find_java_array_type<double> { typedef jdoubleArray type; };
+
+jshortArray create_java_array(int16_t, size_t size) { return JNI_GetEnv()->NewShortArray(size); }
+jintArray create_java_array(int32_t, size_t size) { return JNI_GetEnv()->NewIntArray(size); }
+jlongArray create_java_array(int64_t, size_t size) { return JNI_GetEnv()->NewLongArray(size); }
+jbyteArray create_java_array(char, size_t size) { return JNI_GetEnv()->NewByteArray(size); }
+jfloatArray create_java_array(float, size_t size) { return JNI_GetEnv()->NewFloatArray(size); }
+jdoubleArray create_java_array(double , size_t size) { return JNI_GetEnv()->NewDoubleArray(size); }
+
+template <typename T>
+class array
+{
+public:
+
+ typedef typename find_java_array_type<T>::type java_type;
+
+ array() {}
+
+ explicit array(size_t size)
+ {
+ ref = create_java_array(T(),size);
+ is_global_ref = false;
+ }
+
+ array(java_type ref_)
+ {
+ if (ref_)
+ {
+ ref = (java_type)JNI_GetEnv()->NewGlobalRef(ref_);
+ is_global_ref = true;
+ }
+ }
+
+#ifndef SWIG
+ array(array&& item)
+ {
+ ref = item.ref;
+ is_global_ref = item.is_global_ref;
+ item.ref = NULL;
+ item.is_global_ref = false;
+ }
+ array& operator= (array&& item)
+ {
+ array(std::move(item)).swap(*this);
+ return *this;
+ }
+#endif
+
+ ~array()
+ {
+ if (ref)
+ {
+ // Don't delete the reference if it's a local reference, since the only reason
+ // we will normally be using array object's that contain local references
+ // is because we plan on returning the newly constructed array back to the JVM,
+ // which automatically frees local references using the normal JVM garbage
+ // collection scheme.
+ if (is_global_ref)
+ JNI_GetEnv()->DeleteGlobalRef(ref);
+
+ ref = NULL;
+ is_global_ref = false;
+ }
+ }
+
+ size_t size() const
+ {
+ if (ref)
+ return JNI_GetEnv()->GetArrayLength(ref);
+ else
+ return 0;
+ }
+
+ array(const array& item)
+ {
+ array(item.ref).swap(*this);
+ }
+
+ array& operator= (const array& item)
+ {
+ array(item).swap(*this);
+ return *this;
+ }
+
+ operator java_type() const { return ref;}
+
+ void swap(array& item)
+ {
+ std::swap(ref, item.ref);
+ std::swap(is_global_ref, item.is_global_ref);
+ }
+
+private:
+ java_type ref = NULL;
+ bool is_global_ref = false;
+};
+
+#ifdef SWIG
+// Tell SWIG to not use it's SwigValueWrapper stuff on array objects since they aren't
+// needed and it causes superfluous construction and destruction of array objects.
+%feature("novaluewrapper") array<int16_t>;
+%template() array<int16_t>;
+%feature("novaluewrapper") array<int32_t>;
+%template() array<int32_t>;
+%feature("novaluewrapper") array<int64_t>;
+%template() array<int64_t>;
+%feature("novaluewrapper") array<char>;
+%template() array<char>;
+%feature("novaluewrapper") array<float>;
+%template() array<float>;
+%feature("novaluewrapper") array<double>;
+%template() array<double>;
+#endif
+
+#ifdef SWIG
+%define tostring(token)
+ #token
+%enddef
+
+%define define_javaObjectRef_converion(type, java_type)
+ // Define array conversions for non-const arrays
+ %typemap(jtype) (array<type>) "java_type[]"
+ %typemap(jstype) (array<type>) "java_type[]"
+ %typemap(jni) (array<type>) tostring(j##java_type##Array)
+ %typemap(javain) (array<type>) "$javainput"
+ %typemap(in) (array<type>) { $1 = java::array<type>($input); }
+ %typemap(javaout) (array<type>) {return $jnicall; }
+ %typemap(out) (array<type>) {jresult = result;}
+
+ %typemap(jtype) (array<type>&) "java_type[]"
+ %typemap(jstype) (array<type>&) "java_type[]"
+ %typemap(jni) (array<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (array<type>&) "$javainput"
+ %typemap(arginit) (array<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (array<type>&) (java::array<type> temp) { *($1) = java::array<type>($input); }
+
+ %typemap(jtype) (const array<type>&) "java_type[]"
+ %typemap(jstype) (const array<type>&) "java_type[]"
+ %typemap(jni) (const array<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (const array<type>&) "$javainput"
+ %typemap(arginit) (const array<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (const array<type>&) (java::array<type> temp) { *($1) = java::array<type>($input); }
+%enddef
+define_javaObjectRef_converion(int16_t,short)
+define_javaObjectRef_converion(int32_t,int)
+define_javaObjectRef_converion(int64_t,long)
+define_javaObjectRef_converion(char,byte)
+define_javaObjectRef_converion(float,float)
+define_javaObjectRef_converion(double,double)
+
+#endif
+// ----------------------------------------------------------------------------------------
+
+template <typename T> class array_view;
+
+#define JAVA_ARRAY_CLASS_SPEC(ctype, type, Type) \
+template <> class array_view<ctype> : public array_view_base<ctype> \
+{ \
+public: \
+ ~array_view() { clear(); } \
+ array_view() {} \
+ array_view(const array<ctype>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} \
+ void reset(JNIEnv* jenv_, j##type##Array arr, bool might_be_modified_) { \
+ clear(); \
+ jenv = jenv_; \
+ oldArr = arr; \
+ if (arr) { \
+ pdata = (ctype*)jenv->Get##Type##ArrayElements(arr, 0); \
+ sz = jenv->GetArrayLength(arr); \
+ } \
+ might_be_modified = might_be_modified_; \
+ } \
+private: \
+ void clear() { \
+ if (pdata) { \
+ jenv->Release##Type##ArrayElements(oldArr, (j##type*)pdata, might_be_modified?0:JNI_ABORT); \
+ pdata = nullptr; \
+ sz = 0; \
+ } \
+ } \
+ JNIEnv* jenv = nullptr; \
+ j##type##Array oldArr; \
+ bool might_be_modified; \
+};
+
+JAVA_ARRAY_CLASS_SPEC(int16_t,short, Short)
+JAVA_ARRAY_CLASS_SPEC(int32_t,int, Int)
+JAVA_ARRAY_CLASS_SPEC(int64_t,long, Long)
+JAVA_ARRAY_CLASS_SPEC(char,byte, Byte)
+JAVA_ARRAY_CLASS_SPEC(float,float, Float)
+JAVA_ARRAY_CLASS_SPEC(double,double, Double)
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+
+template <typename T, typename JARR>
+class array_view_crit_base
+{
+public:
+ array_view_crit_base() = default;
+
+ size_t size() const { return sz; }
+ T* data() { return pdata; }
+ const T* data() const { return pdata; }
+
+ T* begin() { return pdata; }
+ T* end() { return pdata+sz; }
+ const T* begin() const { return pdata; }
+ const T* end() const { return pdata+sz; }
+ T& operator[](size_t i) { return pdata[i]; }
+ const T& operator[](size_t i) const { return pdata[i]; }
+
+ ~array_view_crit_base() { clear(); }
+
+ void reset(JNIEnv* jenv_, JARR arr, bool might_be_modified_)
+ {
+ clear();
+ jenv = jenv_;
+ oldArr = arr;
+ if (arr)
+ {
+ pdata = (T*)jenv->GetPrimitiveArrayCritical(arr, 0);
+ sz = jenv->GetArrayLength(arr);
+ }
+ might_be_modified = might_be_modified_;
+ }
+
+private:
+
+ void clear()
+ {
+ if (pdata) {
+ jenv->ReleasePrimitiveArrayCritical(oldArr, pdata, might_be_modified?0:JNI_ABORT);
+ pdata = nullptr;
+ sz = 0;
+ }
+ }
+
+ // this object is non-copyable
+ array_view_crit_base(const array_view_crit_base&);
+ array_view_crit_base& operator=(const array_view_crit_base&);
+
+ T* pdata = nullptr;
+ size_t sz = 0;
+ JNIEnv* jenv = nullptr;
+ JARR oldArr;
+ bool might_be_modified;
+};
+
+template <typename T> class array_view_crit;
+
+template <> class array_view_crit<int16_t> : public array_view_crit_base<int16_t,jshortArray> { public: array_view_crit(){} array_view_crit(const array<int16_t>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+template <> class array_view_crit<int32_t> : public array_view_crit_base<int32_t,jintArray> { public: array_view_crit(){} array_view_crit(const array<int32_t>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+template <> class array_view_crit<int64_t> : public array_view_crit_base<int64_t,jlongArray> { public: array_view_crit(){} array_view_crit(const array<int64_t>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+template <> class array_view_crit<char> : public array_view_crit_base<char,jbyteArray> { public: array_view_crit(){} array_view_crit(const array<char>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+template <> class array_view_crit<float> : public array_view_crit_base<float,jfloatArray> { public: array_view_crit(){} array_view_crit(const array<float>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+template <> class array_view_crit<double> : public array_view_crit_base<double,jdoubleArray> { public: array_view_crit(){} array_view_crit(const array<double>& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} };
+
+// ----------------------------------------------------------------------------------------
+
+// Define SWIG typemaps so SWIG will know what to do with the array_view and array_view_crit
+// objects.
+#ifdef SWIG
+%define define_array_converion(type, java_type)
+ // Define array conversions for non-const arrays
+ %typemap(jtype) (array_view<type>&) "java_type[]"
+ %typemap(jstype) (array_view<type>&) "java_type[]"
+ %typemap(jni) (array_view<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (array_view<type>&) "$javainput"
+ %typemap(arginit) (array_view<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (array_view<type>&) (java::array_view<type> temp) { $1->reset(jenv, $input, true); }
+
+ %typemap(jtype) (const array_view<type>&) "java_type[]"
+ %typemap(jstype) (const array_view<type>&) "java_type[]"
+ %typemap(jni) (const array_view<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (const array_view<type>&) "$javainput"
+ %typemap(arginit) (const array_view<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (const array_view<type>&) (java::array_view<type> temp) { $1->reset(jenv, $input, false); }
+%enddef
+define_array_converion(int16_t,short)
+define_array_converion(int32_t,int)
+define_array_converion(int64_t,long)
+define_array_converion(char,byte)
+define_array_converion(float,float)
+define_array_converion(double,double)
+
+
+
+%define define_array_crit_converion(type, java_type)
+ // Define array conversions for non-const arrays
+ %typemap(jtype) (array_view_crit<type>&) "java_type[]"
+ %typemap(jstype) (array_view_crit<type>&) "java_type[]"
+ %typemap(jni) (array_view_crit<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (array_view_crit<type>&) "$javainput"
+ %typemap(arginit) (array_view_crit<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (array_view_crit<type>&) (java::array_view_crit<type> temp) { $1->reset(jenv, $input, true); }
+
+ %typemap(jtype) (const array_view_crit<type>&) "java_type[]"
+ %typemap(jstype) (const array_view_crit<type>&) "java_type[]"
+ %typemap(jni) (const array_view_crit<type>&) tostring(j##java_type##Array)
+ %typemap(javain) (const array_view_crit<type>&) "$javainput"
+ %typemap(arginit) (const array_view_crit<type>&) { $1 = &temp$argnum; }
+ %typemap(in) (const array_view_crit<type>&) (java::array_view_crit<type> temp) { $1->reset(jenv, $input, false); }
+%enddef
+define_array_crit_converion(int16_t,short)
+define_array_crit_converion(int32_t,int)
+define_array_crit_converion(int64_t,long)
+define_array_crit_converion(char,byte)
+define_array_crit_converion(float,float)
+define_array_crit_converion(double,double)
+
+#endif // SWIG
+
+}
+
+#endif // DLIB_SWIG_JAVA_ARRAY_H_
+
diff --git a/ml/dlib/dlib/java/run_test.sh b/ml/dlib/dlib/java/run_test.sh
new file mode 100755
index 000000000..192ea6c8b
--- /dev/null
+++ b/ml/dlib/dlib/java/run_test.sh
@@ -0,0 +1,17 @@
+
+# build the jar and shared library of C++ code needed by the JVM
+mkdir build
+cd build
+cmake ..
+cmake --build . --config Release --target install
+cd ..
+
+
+# setup paths so the JVM can find our jar and shared library.
+export LD_LIBRARY_PATH=.
+export DYLD_LIBRARY_PATH=.
+export CLASSPATH=myproject.jar:.
+
+# Now compile and run our java test that calls our C++ code.
+javac swig_test.java
+java swig_test
diff --git a/ml/dlib/dlib/java/swig_api.h b/ml/dlib/dlib/java/swig_api.h
new file mode 100644
index 000000000..e807c6c8f
--- /dev/null
+++ b/ml/dlib/dlib/java/swig_api.h
@@ -0,0 +1,126 @@
+#ifndef EXAMPLE_SWIG_ApI_H_
+#define EXAMPLE_SWIG_ApI_H_
+
+// This file is essentially a small unit test for the swig cmake scripts and the java array
+// classes. All it does it define a few simple functions for writing to and summing
+// arrays. The swig_test.java file then calls these C++ functions and checks if they work
+// correctly.
+
+
+
+// Let's use java_array.h, a tool for efficiently binding java native arrays to C++
+// function arguments. You do this by putting this pair of include statements in your
+// swig_api.h file. Then after that you can use the java::array, java::array_view, and
+// java::array_view_crit classes.
+#include <dlib/java/java_array.h>
+#ifdef SWIG
+%include <dlib/java/java_array.h>
+#endif
+
+
+using namespace java;
+
+
+// SWIG can't expose templated functions to java. We declare these here as helper
+// functions to make the non-templated routines swig will expose easier to write. You can
+// see these java exposed methods below (i.e. sum(), sum_crit(), assign(), and
+// assign_crit()).
+template <typename T>
+T tsum(const array_view_crit<T>& arr)
+{
+ T s = 0;
+ for (auto& v : arr)
+ s += v;
+ return s;
+}
+template <typename T>
+T tsum(const array_view<T>& arr)
+{
+ T s = 0;
+ for (auto& v : arr)
+ s += v;
+ return s;
+}
+template <typename T>
+void tassign(T& arr)
+{
+ for (size_t i = 0; i < arr.size(); ++i)
+ arr[i] = i;
+}
+
+// ----------------------------------------------------------------------------------------
+
+// Now write some functions SWIG will expose to java. SWIG will automatically expose
+// pretty much any non-template C++ code to java. So just by defining these functions here
+// we expose them to java.
+//
+// All global C++ functions will appear in java as static member functions of class called
+// "global", which is where these sum and assign routines will appear. You can see
+// examples of java code that calls them in swig_test.java.
+
+inline int sum_crit(const array_view_crit<int16_t>& arr) { return tsum(arr); }
+inline int sum(const array_view<int16_t>& arr) { return tsum(arr); }
+inline void assign_crit(array_view_crit<int16_t>& arr) { tassign(arr); }
+inline void assign(array_view<int16_t>& arr) { tassign(arr); }
+
+
+inline int sum_crit(const array_view_crit<int32_t>& arr) { return tsum(arr); }
+inline int sum(const array_view<int32_t>& arr) { return tsum(arr); }
+inline void assign_crit(array_view_crit<int32_t>& arr) { tassign(arr); }
+inline void assign(array_view<int32_t>& arr) { tassign(arr); }
+
+
+inline int sum_crit(const array_view_crit<int64_t>& arr) { return tsum(arr); }
+inline int sum(const array_view<int64_t>& arr) { return tsum(arr); }
+inline void assign_crit(array_view_crit<int64_t>& arr) { tassign(arr); }
+inline void assign(array_view<int64_t>& arr) { tassign(arr); }
+
+
+inline int sum_crit(const array_view_crit<char>& arr) { return tsum(arr); }
+inline int sum(const array_view<char>& arr) { return tsum(arr); }
+inline void assign_crit(array_view_crit<char>& arr) { tassign(arr); }
+inline void assign(array_view<char>& arr) { tassign(arr); }
+
+
+
+inline double sum_crit(const array_view_crit<double>& arr) { return tsum(arr); }
+inline double sum(const array_view<double>& arr) { return tsum(arr); }
+inline void assign_crit(array_view_crit<double>& arr) { tassign(arr); }
+inline void assign(array_view<double>& arr) { tassign(arr); }
+
+
+inline float sum_crit(array<float> arr)
+{
+ array_view_crit<float> a(arr);
+ return tsum(a);
+}
+inline float sum(const array<float>& arr)
+{
+ array_view<float> a(arr);
+ return tsum(a);
+}
+inline void assign_crit(array_view_crit<float>& arr) { tassign(arr); }
+inline void assign(array<float>& arr)
+{
+ array_view<float> a(arr);
+ tassign(a);
+}
+
+array<int32_t> make_an_array(size_t s)
+{
+ array<int32_t> arr(s);
+ array_view_crit<int32_t> a(arr);
+
+ for (size_t i = 0; i < a.size(); ++i)
+ a[i] = i;
+
+ return arr;
+}
+
+
+// ----------------------------------------------------------------------------------------
+
+
+#endif // EXAMPLE_SWIG_ApI_H_
+
+
diff --git a/ml/dlib/dlib/java/swig_test.java b/ml/dlib/dlib/java/swig_test.java
new file mode 100644
index 000000000..e75edb913
--- /dev/null
+++ b/ml/dlib/dlib/java/swig_test.java
@@ -0,0 +1,254 @@
+
+/*
+
+ This file tests all the ways of using jvector and jvector_crit.
+
+*/
+
+
+import net.dlib.*;
+
+public class swig_test
+{
+ public static int sum(long[] arr)
+ {
+ int s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(long[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+
+ public static int sum(byte[] arr)
+ {
+ int s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(byte[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+ public static int sum(short[] arr)
+ {
+ int s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(short[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+
+ public static int sum(int[] arr)
+ {
+ int s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(int[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+
+ public static void assertIs28(int val)
+ {
+ if (val != 28)
+ {
+ throw new RuntimeException("Test failed " + val);
+ }
+ }
+
+ public static void assertIsEqual(int val1, int val2)
+ {
+ if (val1 != val2)
+ {
+ throw new RuntimeException("Test failed " + val1 + " should be equal to " + val2);
+ }
+ }
+
+ public static double sum(double[] arr)
+ {
+ double s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(double[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+
+ public static void assertIs28(double val)
+ {
+ if (val != 28)
+ {
+ throw new RuntimeException("Test failed " + val);
+ }
+ }
+
+ public static float sum(float[] arr)
+ {
+ float s = 0;
+ for (int i = 0; i < arr.length; ++i)
+ s += arr[i];
+ return s;
+ }
+ public static void zero(float[] arr)
+ {
+ for (int i = 0; i < arr.length; ++i)
+ arr[i] = 0;
+ }
+
+ public static void assertIs28(float val)
+ {
+ if (val != 28)
+ {
+ throw new RuntimeException("Test failed " + val);
+ }
+ }
+
+ public static void main(String[] args)
+ {
+ {
+ float[] arr = new float[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ double[] arr = new double[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ byte[] arr = new byte[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ long[] arr = new long[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ short[] arr = new short[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ int[] arr = new int[8];
+
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ }
+ for (int round = 0; round < 100; ++round)
+ {
+ zero(arr); global.assign(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum(arr));
+ zero(arr); global.assign_crit(arr);
+ assertIs28(sum(arr));
+ assertIs28(global.sum_crit(arr));
+ }
+ }
+ {
+ int[] a = global.make_an_array(4);
+ for (int i = 0; i < a.length; ++i)
+ {
+ assertIsEqual(a[i], i);
+ }
+ }
+
+ System.out.println("\n\n ALL TESTS COMPLETED SUCCESSFULLY\n");
+ }
+}
diff --git a/ml/dlib/dlib/linker.h b/ml/dlib/dlib/linker.h
new file mode 100644
index 000000000..e4f9b5daf
--- /dev/null
+++ b/ml/dlib/dlib/linker.h
@@ -0,0 +1,9 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LINKEr_
+#define DLIB_LINKEr_
+
+#include "linker/linker_kernel_1.h"
+
+#endif // DLIB_LINKEr_
+
diff --git a/ml/dlib/dlib/linker/linker_kernel_1.cpp b/ml/dlib/dlib/linker/linker_kernel_1.cpp
new file mode 100644
index 000000000..e76009b37
--- /dev/null
+++ b/ml/dlib/dlib/linker/linker_kernel_1.cpp
@@ -0,0 +1,357 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LINKER_KERNEL_1_CPp_
+#define DLIB_LINKER_KERNEL_1_CPp_
+#include "linker_kernel_1.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ linker::
+ linker (
+ ) :
+ running(false),
+ running_signaler(running_mutex),
+ A(0),
+ B(0),
+ service_connection_running_signaler(service_connection_running_mutex)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ linker::
+ linker (
+ connection& a,
+ connection& b
+ ) :
+ running(false),
+ running_signaler(running_mutex),
+ A(0),
+ B(0),
+ service_connection_running_signaler(service_connection_running_mutex)
+ {
+ link(a,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ linker::
+ ~linker (
+ )
+ {
+ clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void linker::
+ clear (
+ )
+ {
+
+ // shutdown the connections
+ cons_mutex.lock();
+ if (A != 0 )
+ {
+ A->shutdown();
+ A = 0;
+ }
+ if (B != 0)
+ {
+ B->shutdown();
+ B = 0;
+ }
+ cons_mutex.unlock();
+
+
+ // wait for the other threads to signal that they have ended
+ running_mutex.lock();
+ while (running == true)
+ {
+ running_signaler.wait();
+ }
+ running_mutex.unlock();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool linker::
+ is_running (
+ ) const
+ {
+ running_mutex.lock();
+ bool temp = running;
+ running_mutex.unlock();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void linker::
+ link (
+ connection& a,
+ connection& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(
+ this->is_running() == false ,
+ "\tvoid linker::link"
+ << "\n\tis_running() == " << this->is_running()
+ << "\n\tthis: " << this
+ );
+
+ running_mutex.lock();
+ running = true;
+ running_mutex.unlock();
+
+ cons_mutex.lock();
+ A = &a;
+ B = &b;
+ cons_mutex.unlock();
+
+
+
+ service_connection_running_mutex.lock();
+ service_connection_running = true;
+ service_connection_running_mutex.unlock();
+
+ service_connection_error_mutex.lock();
+ service_connection_error = false;
+ service_connection_error_mutex.unlock();
+
+ // if we fail to make the thread
+ if (!create_new_thread(service_connection,this))
+ {
+ a.shutdown();
+ b.shutdown();
+
+ service_connection_running_mutex.lock();
+ service_connection_running = false;
+ service_connection_running_mutex.unlock();
+
+ cons_mutex.lock();
+ A = 0;
+ B = 0;
+ cons_mutex.unlock();
+
+ running_mutex.lock();
+ running = false;
+ running_mutex.unlock();
+
+
+
+ throw dlib::thread_error (
+ ECREATE_THREAD,
+ "failed to make new thread in linker::link()"
+ );
+ }
+
+
+
+ // forward data from a to b
+ char buf[200];
+ int status;
+ bool error = false; // becomes true if one of the connections returns an error
+ while (true)
+ {
+ status = a.read(buf,sizeof(buf));
+ // if there was an error reading from the socket
+ if (status == OTHER_ERROR)
+ {
+ error = true;
+ break;
+ }
+ else if (status == SHUTDOWN)
+ {
+ b.shutdown();
+ }
+
+ if (status <= 0)
+ {
+ // if a has closed normally
+ if (status == 0)
+ b.shutdown_outgoing();
+ break;
+ }
+
+ status = b.write(buf,status);
+ // if there was an error writing to the socket then break
+ if (status == OTHER_ERROR)
+ {
+ error = true;
+ break;
+ }
+
+ if (status <= 0)
+ break;
+ }
+
+
+ // if there was an error then shutdown both connections
+ if (error)
+ {
+ a.shutdown();
+ b.shutdown();
+ }
+
+
+
+
+ // wait for the other thread to end
+ service_connection_running_mutex.lock();
+ while(service_connection_running)
+ {
+ service_connection_running_signaler.wait();
+ }
+ service_connection_running_mutex.unlock();
+
+
+ // make sure connections are shutdown
+ a.shutdown();
+ b.shutdown();
+
+
+ // both threads have ended so the connections are no longer needed
+ cons_mutex.lock();
+ A = 0;
+ B = 0;
+ cons_mutex.unlock();
+
+
+ // if service_connection terminated due to an error then set error to true
+ service_connection_error_mutex.lock();
+ if (service_connection_error)
+ error = true;
+ service_connection_error_mutex.unlock();
+
+
+ // if we are ending because of an error
+ if (error)
+ {
+
+ // signal that the link() function is ending
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+
+ // throw the exception for this error
+ throw dlib::socket_error (
+ ECONNECTION,
+ "a connection returned an error in linker::link()"
+ );
+
+ }
+
+ // signal that the link() function is ending
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void linker::
+ service_connection (
+ void* param
+ )
+ {
+ linker& p = *static_cast<linker*>(param);
+
+ p.cons_mutex.lock();
+ // if the connections are gone for whatever reason then return
+ if (p.A == 0 || p.B == 0)
+ {
+ // signal that this function is ending
+ p.service_connection_running_mutex.lock();
+ p.service_connection_running = false;
+ p.service_connection_running_signaler.broadcast();
+ p.service_connection_running_mutex.unlock();
+ return;
+ }
+ connection& a = *p.A;
+ connection& b = *p.B;
+ p.cons_mutex.unlock();
+
+
+
+ // forward data from b to a
+ char buf[200];
+ int status;
+ bool error = false;
+ while (true)
+ {
+ status = b.read(buf,sizeof(buf));
+ // if there was an error reading from the socket
+ if (status == OTHER_ERROR)
+ {
+ error = true;
+ break;
+ }
+ else if (status == SHUTDOWN)
+ {
+ a.shutdown();
+ }
+
+
+ if (status <= 0)
+ {
+ // if b has closed normally
+ if (status == 0)
+ a.shutdown_outgoing();
+ break;
+ }
+
+
+ status = a.write(buf,status);
+ // if there was an error writing to the socket then break
+ if (status == OTHER_ERROR)
+ {
+ error = true;
+ break;
+ }
+
+ if (status <= 0)
+ break;
+ }
+
+
+ // if there was an error then shutdown both connections
+ if (error)
+ {
+ a.shutdown();
+ b.shutdown();
+ }
+
+
+ // if there was an error then signal that
+ if (error)
+ {
+ p.service_connection_error_mutex.lock();
+ p.service_connection_error = true;
+ p.service_connection_error_mutex.unlock();
+ }
+
+ // signal that this function is ending
+ p.service_connection_running_mutex.lock();
+ p.service_connection_running = false;
+ p.service_connection_running_signaler.broadcast();
+ p.service_connection_running_mutex.unlock();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_LINKER_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/linker/linker_kernel_1.h b/ml/dlib/dlib/linker/linker_kernel_1.h
new file mode 100644
index 000000000..b101026b2
--- /dev/null
+++ b/ml/dlib/dlib/linker/linker_kernel_1.h
@@ -0,0 +1,141 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LINKER_KERNEl_1_
+#define DLIB_LINKER_KERNEl_1_
+
+#include "linker_kernel_abstract.h"
+#include "../threads.h"
+#include "../sockets.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ class linker
+ {
+
+ /*!
+ INITIAL VALUE
+ running == false
+ A == 0
+ B == 0
+ running_mutex == a mutex
+ running_signaler == a signaler associated with running_mutex
+ cons_mutex == a mutex
+ service_connection_running == false
+ service_connection_running_mutex == a mutex
+ service_connection_running_signaler == a signaler associated with
+ service_connection_running_mutex
+
+ service_connection_error == false
+ service_connection_error_mutex == a mutex
+
+
+
+ CONVENTION
+ running == is_running()
+ running_mutex == a mutex for running
+ running_signaler == a signaler for signaling when
+ running becomes false and is associated with
+ running_mutex
+ cons_mutex == a mutex for A and B
+
+ service_connection_running == true when service_connection() is
+ running or is about to run else
+ false
+ service_connection_running_mutex == a mutex for service_connection_running
+ service_connection_running_signaler == a signaler associated with
+ service_connection_running_mutex
+
+ if (running) then
+ A == address of a from link()
+ B == address of b from link()
+ else
+ A == 0
+ B == 0
+
+ service_connection_error == service_connection uses this bool
+ to indicate if it terminated due to
+ an error or not
+ service_connection_error_mutex == a mutex for service_connection_error
+
+
+ !*/
+
+ public:
+
+ // These two typedefs are here for backwards compatibility with previous
+ // versions of dlib.
+ typedef linker kernel_1a;
+ typedef linker kernel_1a_c;
+
+ linker(
+ );
+
+ linker (
+ connection& a,
+ connection& b
+ );
+
+ virtual ~linker(
+ );
+
+ void clear(
+ );
+
+ bool is_running(
+ ) const;
+
+ void link (
+ connection& a,
+ connection& b
+ );
+
+
+ private:
+
+ static void service_connection (
+ void* param
+ );
+ /*!
+ requires
+ param == pointer to a linker object
+ ensures
+ waits for data from b and forwards it to a and
+ if (b closes normally or is shutdown()) service_connection ends and
+ if (b closes normally) then a.shutdown_outgoing() is called and
+ if (a or b returns an error) then a and b are shutdown()
+ !*/
+
+
+ // data members
+ bool running;
+ mutex running_mutex;
+ signaler running_signaler;
+ connection* A;
+ connection* B;
+ mutex cons_mutex;
+
+ bool service_connection_running;
+ mutex service_connection_running_mutex;
+ signaler service_connection_running_signaler;
+
+ bool service_connection_error;
+ mutex service_connection_error_mutex;
+
+ // restricted functions
+ linker(linker&); // copy constructor
+ linker& operator=(linker&); // assignment operator
+ };
+
+
+
+}
+
+#ifdef NO_MAKEFILE
+#include "linker_kernel_1.cpp"
+#endif
+
+#endif // DLIB_LINKER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/linker/linker_kernel_abstract.h b/ml/dlib/dlib/linker/linker_kernel_abstract.h
new file mode 100644
index 000000000..cef0901e4
--- /dev/null
+++ b/ml/dlib/dlib/linker/linker_kernel_abstract.h
@@ -0,0 +1,141 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LINKER_KERNEl_ABSTRACT_
+#ifdef DLIB_LINKER_KERNEl_ABSTRACT_
+
+#include "../threads/threads_kernel_abstract.h"
+#include "../sockets/sockets_kernel_abstract.h"
+
+namespace dlib
+{
+
+ class linker
+ {
+
+ /*!
+ INITIAL VALUE
+ is_running() == false
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that takes two connections and lets
+ them talk to each other. i.e. any incoming data from one connection is
+ passed unaltered to the other and vice versa.
+
+ note that linker objects are not swappable.
+
+ Also note that when one connection is closed shutdown_outgoing()
+ is called on the other to signal that no more data will be sent
+ in that direction on the connection.
+ (i.e. the FIN packet is effectively also forwarded by the linker object)
+
+ THREAD SAFETY
+ all member functions are thread-safe.
+
+ !*/
+
+ public:
+
+ linker(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ linker (
+ connection& a,
+ connection& b
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - immediately invokes link(a,b);
+ (i.e. using this constructor is the same as creating a linker with
+ the default constructor and then immediately invoking link() on it)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~linker(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - if (is_running()) then
+ - the two connections being linked will be shutdown()
+ throws
+ - std::bad_alloc
+ if this exception is thrown then the linker object is unusable
+ until clear() is called and succeeds and
+ if is_running() then the connections will STILL be shutdown()
+ even though an exception is being thrown
+ !*/
+
+ bool is_running(
+ ) const;
+ /*!
+ ensures
+ - returns true if link() is running else
+ - returns false if link() is not running or has released all its
+ resources and is about to terminate
+ throws
+ - std::bad_alloc
+ !*/
+
+
+ void link (
+ connection& a,
+ connection& b
+ );
+ /*!
+ requires
+ - is_running() == false
+ ensures
+ - all incoming data from connection a will be forwarded to b
+ - all incoming data from connection b will be forwarded to a
+ - #a and #b will have been shutdown()
+ - link() will block until both of the connections have ended
+ or an error occurs
+ throws
+ - std::bad_alloc
+ link() may throw this exception and if it does then the object
+ will be unusable until clear() is called and succeeds and
+ connections a and b will be shutdown()
+ - dlib::socket_error
+ link() will throw a this exception if one of the connections
+ returns an error value (being shutdown is not an error).
+ If this happens then the linker object will be cleared and
+ have its initial value. note that if this happens then the
+ connections being linked will be shutdown()
+ - dlib::thread_error
+ link() will throw a this exception if there is a problem
+ creating new threads. Or it may throw this exception if there
+ is a problem creating threading objects. If this happens
+ then the linker object will be cleared and have its initial value.
+ note that if this happens then the connections being linked will
+ be shutdown().
+ !*/
+
+ private:
+
+ // restricted functions
+ linker(linker&); // copy constructor
+ linker& operator=(linker&); // assignment operator
+ };
+
+}
+
+#endif // DLIB_LINKER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/locale b/ml/dlib/dlib/locale
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/locale
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/logger.h b/ml/dlib/dlib/logger.h
new file mode 100644
index 000000000..e49fad8fd
--- /dev/null
+++ b/ml/dlib/dlib/logger.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOGGEr_
+#define DLIB_LOGGEr_
+
+#include "logger/logger_kernel_1.h"
+#include "logger/extra_logger_headers.h"
+#include "logger/logger_config_file.h"
+
+#endif // DLIB_LOGGEr_
+
diff --git a/ml/dlib/dlib/logger/extra_logger_headers.cpp b/ml/dlib/dlib/logger/extra_logger_headers.cpp
new file mode 100644
index 000000000..becc1ab2b
--- /dev/null
+++ b/ml/dlib/dlib/logger/extra_logger_headers.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EXTRA_LOGGER_HEADERs_CPP_
+#define DLIB_EXTRA_LOGGER_HEADERs_CPP_
+
+#include "extra_logger_headers.h"
+#include <ctime>
+#include <cstring>
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ void print_datetime_logger_header (
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ )
+ {
+ using namespace std;
+ char* buf;
+
+ time_t t = time(0);
+ buf = ctime(&t);
+ // remove the trailing '\n'
+ size_t size = strlen(buf);
+ buf[size-1] = '\0';
+
+ out << l.name << " (" << buf << ") [" << thread_id << "] " << logger_name << ": ";
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_EXTRA_LOGGER_HEADERs_CPP_
+
+
diff --git a/ml/dlib/dlib/logger/extra_logger_headers.h b/ml/dlib/dlib/logger/extra_logger_headers.h
new file mode 100644
index 000000000..6eb24d84c
--- /dev/null
+++ b/ml/dlib/dlib/logger/extra_logger_headers.h
@@ -0,0 +1,41 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EXTRA_LOGGER_HEADERs_
+#define DLIB_EXTRA_LOGGER_HEADERs_
+
+#include "logger_kernel_abstract.h"
+#include "logger_kernel_1.h"
+#include <iostream>
+#include <string>
+#include "../uintn.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ void print_datetime_logger_header (
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ );
+ /*!
+ requires
+ - is not called more than once at a time (i.e. is not called from multiple
+ threads at the same time).
+ ensures
+ - let DATE be the current date and time (e.g. Thu Aug 31 16:41:52 2006).
+ - prints a string to out in the form: "l.name (DATE) [thread_id] logger_name:"
+ !*/
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "extra_logger_headers.cpp"
+#endif
+
+#endif // DLIB_EXTRA_LOGGER_HEADERs_
+
diff --git a/ml/dlib/dlib/logger/logger_config_file.cpp b/ml/dlib/dlib/logger/logger_config_file.cpp
new file mode 100644
index 000000000..108f66c8c
--- /dev/null
+++ b/ml/dlib/dlib/logger/logger_config_file.cpp
@@ -0,0 +1,214 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOGGER_CONFIg_FILE_CPP
+#define DLIB_LOGGER_CONFIg_FILE_CPP
+
+#include "logger_config_file.h"
+#include <string>
+#include "../config_reader.h"
+#include <fstream>
+#include <sstream>
+#include "../error.h"
+#include "../map.h"
+#include "../string.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ namespace logger_config_file_helpers
+ {
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& get_file_stream (
+ const std::string& file_name
+ )
+ {
+ using namespace std;
+ static dlib::mutex m;
+ auto_mutex M(m);
+ static dlib::map<string,ostream*>::kernel_1a_c file_map;
+
+ if (file_map.is_in_domain(file_name) == false)
+ {
+ // We won't ever delete this output stream. It should be around for the
+ // entire life of the program so just let the OS take care of it.
+ ostream* fout = new ofstream(file_name.c_str());
+ if (!(*fout))
+ {
+ delete fout;
+ throw error("logger_config: unable to open output file " + file_name);
+ }
+
+ // add this file to our file map
+ string temp(file_name);
+ file_map.add(temp,fout);
+ }
+
+ return *file_map[file_name];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ log_level string_to_log_level (
+ const std::string& level
+ )
+ {
+ using namespace std;
+ if (level == "LALL" || level == "ALL" || level == "all")
+ return LALL;
+ else if (level == "LNONE" || level == "NONE" || level == "none")
+ return LNONE;
+ else if (level == "LTRACE" || level == "TRACE" || level == "trace")
+ return LTRACE;
+ else if (level == "LDEBUG" || level == "DEBUG" || level == "debug")
+ return LDEBUG;
+ else if (level == "LINFO" || level == "INFO" || level == "info")
+ return LINFO;
+ else if (level == "LWARN" || level == "WARN" || level == "warn")
+ return LWARN;
+ else if (level == "LERROR" || level == "ERROR" || level == "error")
+ return LERROR;
+ else if (level == "LFATAL" || level == "FATAL" || level == "fatal")
+ return LFATAL;
+ else
+ {
+ const int priority = string_cast<int>(level);
+ return log_level(priority,"CONFIG_FILE_DEFINED");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void configure_sub_blocks (
+ const config_reader& cr,
+ const std::string& name
+ )
+ {
+ using namespace std;
+
+ logger dlog(name.c_str());
+
+ if (cr.is_key_defined("logging_level"))
+ {
+ dlog.set_level(string_to_log_level(cr["logging_level"]));
+ }
+
+ if (cr.is_key_defined("output"))
+ {
+ string output = cr["output"];
+ if (output == "cout")
+ dlog.set_output_stream(cout);
+ else if (output == "cerr")
+ dlog.set_output_stream(cerr);
+ else if (output == "clog")
+ dlog.set_output_stream(clog);
+ else
+ {
+ istringstream sin(output);
+ string one, two, three;
+ sin >> one;
+ sin >> two;
+ sin >> three;
+ if (one == "file" && three.size() == 0)
+ dlog.set_output_stream(get_file_stream(two));
+ else
+ throw error("logger_config: invalid argument to output option: " + output);
+ }
+
+ } // if (cr.is_key_defined("output"))
+
+ // now configure all the sub-blocks
+ std_vector_c<std::string> blocks;
+ cr.get_blocks(blocks);
+ for (unsigned long i = 0; i < blocks.size(); ++i)
+ {
+ configure_sub_blocks(cr.block(blocks[i]), name + "." + blocks[i]);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ } // namespace
+
+// ----------------------------------------------------------------------------------------
+
+ void configure_loggers_from_file (
+ const std::string& file_name
+ )
+ {
+ std::ifstream fin(file_name.c_str());
+
+ if (!fin)
+ throw logger_config_file_error("logger_config: unable to open config file " + file_name);
+
+ config_reader temp(fin);
+ configure_loggers_from_file(temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void configure_loggers_from_file (
+ const config_reader& main_cr
+ )
+ {
+ using namespace logger_config_file_helpers;
+ using namespace std;
+
+ if (main_cr.is_block_defined("logger_config"))
+ {
+ const config_reader& cr = main_cr.block("logger_config");
+
+ if (cr.is_key_defined("logging_level"))
+ {
+ set_all_logging_levels(string_to_log_level(cr["logging_level"]));
+ }
+
+ if (cr.is_key_defined("output"))
+ {
+ string output = cr["output"];
+ if (output == "cout")
+ set_all_logging_output_streams(cout);
+ else if (output == "cerr")
+ set_all_logging_output_streams(cerr);
+ else if (output == "clog")
+ set_all_logging_output_streams(clog);
+ else
+ {
+ istringstream sin(output);
+ string one, two, three;
+ sin >> one;
+ sin >> two;
+ sin >> three;
+ if (one == "file" && three.size() == 0)
+ set_all_logging_output_streams(get_file_stream(two));
+ else
+ throw logger_config_file_error("logger_config: invalid argument to output option: " + output);
+ }
+
+ } // if (cr.is_key_defined("output"))
+
+ // now configure all the sub-blocks
+ std_vector_c<std::string> blocks;
+ cr.get_blocks(blocks);
+ for (unsigned long i = 0; i < blocks.size(); ++i)
+ {
+ configure_sub_blocks(cr.block(blocks[i]), blocks[i]);
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LOGGER_CONFIg_FILE_CPP
+
+
+
diff --git a/ml/dlib/dlib/logger/logger_config_file.h b/ml/dlib/dlib/logger/logger_config_file.h
new file mode 100644
index 000000000..b0a030f80
--- /dev/null
+++ b/ml/dlib/dlib/logger/logger_config_file.h
@@ -0,0 +1,135 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOGGER_CONFIg_FILE_
+#define DLIB_LOGGER_CONFIg_FILE_
+
+#include "logger_kernel_abstract.h"
+#include "logger_kernel_1.h"
+#include <string>
+#include "../config_reader.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ class logger_config_file_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception class used by the configure_loggers_from_file()
+ function defined below.
+ !*/
+ public:
+ logger_config_file_error(const std::string& s):error(s){}
+ };
+
+ void configure_loggers_from_file (
+ const std::string& file_name
+ );
+ /*!
+ ensures
+ - configures the loggers with the contents of the file_name file
+ throws
+ - dlib::logger_config_file_error
+ this exception is thrown if there is a problem reading the config file
+ !*/
+
+ void configure_loggers_from_file (
+ const config_reader& cr
+ );
+ /*!
+ ensures
+ - configures the loggers with the contents of cr. This function is just like
+ the above version that reads from a file except that it reads from an in-memory
+ config_reader instead.
+ throws
+ - dlib::logger_config_file_error
+ this exception is thrown if there is a problem reading the config file
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ # -----------------------------------------------
+ # ------------- EXAMPLE CONFIG FILE -------------
+ # -----------------------------------------------
+
+ # The overall format of the config file is the same as the one defined by
+ # the config_reader component of this library.
+
+ # This line is a comment line
+
+ # The config file always has a block named logger_config. This is where all the
+ # config data for the loggers reside.
+ logger_config
+ {
+ # This sets all loggers to the level LINFO since it is just inside the
+ # logger_config block
+ logging_level = info
+
+ # Alternatively we could specify a user defined logging level by
+ # supplying a priority number. The following line would specify
+ # that only logging levels at or above 100 are printed. (note that
+ # you would have to comment out the logging_level statement above
+ # to avoid a conflict).
+ # logging_level = 100
+
+ parent_logger
+ {
+ # This sets all loggers named "parent_logger" or children of
+ # loggers with that name to not log at all (i.e. to logging level
+ # LNONE).
+ logging_level = none
+ }
+
+
+ parent_logger2
+ {
+ # set loggers named "parent_logger2" and its children loggers
+ # to write their output to a file named out.txt
+ output = file out.txt
+
+ child_logger
+ {
+ # Set loggers named "parent_logger2.child_logger" and children of loggers
+ # with this name to logging level LALL
+ logging_level = all
+
+ # Note that this logger will also log to out.txt because that is what
+ # its parent does and we haven't overridden it here with something else.
+ # if we wanted this logger to write to cout instead we could uncomment
+ # the following line:
+ # output = cout
+ }
+ }
+ }
+
+ # So in summary, all logger config stuff goes inside a block named logger_config. Then
+ # inside that block all blocks must be the names of loggers. There are only two keys,
+ # logging_level and output.
+ #
+ # The valid values of logging_level are:
+ # "LALL", "LNONE", "LTRACE", "LDEBUG", "LINFO", "LWARN", "LERROR", "LFATAL",
+ # "ALL", "NONE", "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL",
+ # "all", "none", "trace", "debug", "info", "warn", "error", "fatal", or
+ # any integral value
+ #
+ # The valid values of output are:
+ # "cout", "cerr", "clog", or a string of the form "file some_file_name"
+ # which causes the output to be logged to the specified file.
+ #
+ !*/
+
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "logger_config_file.cpp"
+#endif
+
+#endif // DLIB_LOGGER_CONFIg_FILE_
+
+
+
diff --git a/ml/dlib/dlib/logger/logger_kernel_1.cpp b/ml/dlib/dlib/logger/logger_kernel_1.cpp
new file mode 100644
index 000000000..093cd29a8
--- /dev/null
+++ b/ml/dlib/dlib/logger/logger_kernel_1.cpp
@@ -0,0 +1,498 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOGGER_KERNEL_1_CPp_
+#define DLIB_LOGGER_KERNEL_1_CPp_
+
+#include "logger_kernel_1.h"
+#include <iostream>
+#include <sstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void set_all_logging_output_streams (
+ std::ostream& out_
+ )
+ {
+ logger::global_data& gd = logger::get_global_data();
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ gd.loggers.element()->out.rdbuf(out_.rdbuf());
+ gd.loggers.element()->hook.clear();
+ }
+
+ gd.set_output_stream("",out_);
+
+ // set the default hook to be an empty member function pointer
+ logger::hook_mfp hook;
+ gd.set_output_hook("",hook);
+ }
+
+ void set_all_logging_levels (
+ const log_level& new_level
+ )
+ {
+ logger::global_data& gd = logger::get_global_data();
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ gd.loggers.element()->cur_level = new_level;
+ }
+
+ gd.set_level("",new_level);
+ }
+
+ void set_all_logging_headers (
+ const print_header_type& new_header
+ )
+ {
+ logger::global_data& gd = logger::get_global_data();
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ gd.loggers.element()->print_header = new_header;
+ }
+
+ gd.set_logger_header("",new_header);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace logger_helper_stuff
+ {
+ class helper
+ {
+ public:
+ helper()
+ {
+ std::ostringstream sout;
+ print_default_logger_header(sout,"some_name",LDEBUG,0);
+ }
+ };
+ // do this to make sure all the static members of print_default_logger_header get
+ // initialized when the program turns on.
+ static helper a;
+ // make a logger to make extra sure the static global_data object gets
+ // initialized before any threads start up. Also do this so that there is always
+ // at least one logger so that the global data won't be deleted until the
+ // program is terminating.
+ static logger log("dlib");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void print_default_logger_header (
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ )
+ {
+ using namespace std;
+ static timestamper ts;
+ static const uint64 first_time = ts.get_timestamp();
+
+ const uint64 cur_time = (ts.get_timestamp() - first_time)/1000;
+ streamsize old_width = out.width(); out.width(5);
+ out << cur_time << " " << l.name;
+ out.width(old_width);
+
+ out << " [" << thread_id << "] " << logger_name << ": ";
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// global_data stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ logger::global_data::
+ ~global_data (
+ )
+ {
+ unregister_thread_end_handler(*this,&global_data::thread_end_handler);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ logger::global_data::
+ global_data(
+ ) :
+ next_thread_name(1)
+ {
+ // make sure the main program thread always has id 0. Since there is
+ // a global logger object declared in this file we should expect that
+ // the global_data object will be initialized in the main program thread
+ // so if we call get_thread_id() now we should get the main thread id.
+ thread_id_type main_id = get_thread_id();
+ uint64 id_zero = 0;
+ thread_names.add(main_id,id_zero);
+
+ // set up the defaults
+ auto_flush_table.val = true;
+ streambuf_table.val = std::cout.rdbuf();
+ header_table.val = print_default_logger_header;
+
+ // also allocate an initial buffer for hook based logging
+ hookbuf.buffer.reserve(1000);
+ }
+
+ logger::global_data::level_container::
+ level_container (
+ ) : val(300,"ERROR") {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const T& search_tables (
+ const T& c,
+ const std::string& name
+ )
+ {
+ if (c.table.size() == 0 || name.size() == 0)
+ return c;
+
+ const std::string::size_type pos = name.find_first_of(".");
+ const std::string first = name.substr(0,pos);
+ std::string last;
+ if (pos != std::string::npos)
+ last = name.substr(pos+1);
+
+ if (c.table.is_in_domain(first))
+ {
+ return search_tables(*c.table[first], last);
+ }
+ else
+ {
+ return c;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void assign_tables (
+ T& c,
+ const std::string& name,
+ const U& val
+ )
+ {
+ if (name.size() == 0)
+ {
+ c.val = val;
+ c.table.clear();
+ return;
+ }
+
+ const std::string::size_type pos = name.find_first_of(".");
+ std::string first = name.substr(0,pos);
+ std::string last;
+ if (pos != std::string::npos)
+ last = name.substr(pos+1);
+
+ if (c.table.is_in_domain(first))
+ {
+ assign_tables(*c.table[first], last, val);
+ }
+ else
+ {
+ std::unique_ptr<T> temp (new T);
+ temp->val = c.val;
+ assign_tables(*temp, last, val);
+ c.table.add(first,temp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const log_level logger::global_data::
+ level (
+ const std::string& name
+ ) const
+ {
+ auto_mutex M(m);
+ return search_tables(level_table, name).val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_level (
+ const std::string& name,
+ const log_level& new_level
+ )
+ {
+ auto_mutex M(m);
+ assign_tables(level_table, name, new_level);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool logger::global_data::
+ auto_flush (
+ const std::string& name
+ ) const
+ {
+ auto_mutex M(m);
+ return search_tables(auto_flush_table, name).val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_auto_flush (
+ const std::string& name,
+ bool enabled
+ )
+ {
+ auto_mutex M(m);
+ assign_tables(auto_flush_table, name, enabled);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::streambuf* logger::global_data::
+ output_streambuf (
+ const std::string& name
+ )
+ {
+ auto_mutex M(m);
+ return search_tables(streambuf_table, name).val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_output_stream (
+ const std::string& name,
+ std::ostream& out_
+ )
+ {
+ auto_mutex M(m);
+ assign_tables( streambuf_table, name, out_.rdbuf());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_output_stream (
+ const std::string& name,
+ std::streambuf& buf
+ )
+ {
+ auto_mutex M(m);
+ assign_tables( streambuf_table, name, &buf);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ logger::hook_mfp logger::global_data::
+ output_hook (
+ const std::string& name
+ )
+ {
+ auto_mutex M(m);
+ return search_tables(hook_table, name).val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_output_hook (
+ const std::string& name,
+ const hook_mfp& hook
+ )
+ {
+ auto_mutex M(m);
+ assign_tables( hook_table, name, hook);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ print_header_type logger::global_data::
+ logger_header (
+ const std::string& name
+ )
+ {
+ auto_mutex M(m);
+ return search_tables(header_table, name).val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ set_logger_header (
+ const std::string& name,
+ print_header_type ph
+ )
+ {
+ auto_mutex M(m);
+ assign_tables(header_table, name, ph);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ logger::global_data& logger::get_global_data()
+ {
+ // Allocate the global_data on the heap rather than on the stack because
+ // we want to guard against the case where this static object would be destroyed
+ // during program termination BEFORE all logger objects are destroyed.
+ static global_data* gd = new global_data;
+ return *gd;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::global_data::
+ thread_end_handler (
+ )
+ {
+ auto_mutex M(m);
+ thread_id_type id = get_thread_id();
+ thread_id_type junkd;
+ uint64 junkr;
+ thread_names.remove(id,junkd,junkr);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 logger::global_data::
+ get_thread_name (
+ )
+ {
+ thread_id_type id = get_thread_id();
+ uint64 thread_name;
+ if (thread_names.is_in_domain(id))
+ {
+ thread_name = thread_names[id];
+ }
+ else
+ {
+ if (is_dlib_thread(id))
+ register_thread_end_handler(*this,&global_data::thread_end_handler);
+ thread_name = next_thread_name;
+ thread_names.add(id,thread_name);
+ thread_name = next_thread_name;
+ ++next_thread_name;
+ }
+ return thread_name;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// logger_stream stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void logger::logger_stream::
+ print_header_and_stuff (
+ )
+ {
+ if (!been_used)
+ {
+ log.gd.m.lock();
+
+ // Check if the output hook is setup. If it isn't then we print the logger
+ // header like normal. Otherwise we need to remember to clear out the output
+ // stringstream we always write to.
+ if (log.hook.is_set() == false)
+ {
+ log.logger_header()(log.out,log.name(),l,log.gd.get_thread_name());
+ }
+ else
+ {
+ // Make sure the hook buffer doesn't have any old data in it before we start
+ // logging a new message into it.
+ log.gd.hookbuf.buffer.resize(0);
+ }
+ been_used = true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void logger::logger_stream::
+ print_end_of_line (
+ )
+ {
+ auto_unlock M(log.gd.m);
+
+ if (log.hook.is_set() == false)
+ {
+ if (log.auto_flush_enabled)
+ log.out << std::endl;
+ else
+ log.out << "\n";
+ }
+ else
+ {
+ // Make sure the buffer is a proper C-string
+ log.gd.hookbuf.buffer.push_back('\0');
+ // call the output hook with all the info regarding this log message.
+ log.hook(log.name(), l, log.gd.get_thread_name(), &log.gd.hookbuf.buffer[0]);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// logger stuff
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ logger::
+ logger (
+ const std::string& name_
+ ) :
+ gd(get_global_data()),
+ logger_name(name_),
+ out(gd.output_streambuf(logger_name)),
+ cur_level(gd.level(logger_name))
+ {
+ DLIB_ASSERT(name_[0] != '\0',
+ "\tlogger::logger()"
+ << "\n\tYou can't make a logger with an empty name"
+ << "\n\tthis: " << this
+ );
+
+ auto_mutex M(gd.m);
+ logger* temp = this;
+ gd.loggers.add(temp);
+
+ // load the appropriate settings
+ print_header = gd.logger_header(logger_name);
+ auto_flush_enabled = gd.auto_flush(logger_name);
+ hook = gd.output_hook(logger_name);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ logger::
+ ~logger (
+ )
+ {
+ gd.m.lock();
+ gd.loggers.destroy(this);
+ // if this was the last logger then delete the global data
+ if (gd.loggers.size() == 0)
+ {
+ gd.m.unlock();
+ delete &gd;
+ }
+ else
+ {
+ gd.m.unlock();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LOGGER_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/logger/logger_kernel_1.h b/ml/dlib/dlib/logger/logger_kernel_1.h
new file mode 100644
index 000000000..528bd6f67
--- /dev/null
+++ b/ml/dlib/dlib/logger/logger_kernel_1.h
@@ -0,0 +1,687 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LOGGER_KERNEl_1_
+#define DLIB_LOGGER_KERNEl_1_
+
+#include <limits>
+#include <memory>
+#include <cstring>
+#include <streambuf>
+#include <vector>
+
+#include "../threads.h"
+#include "../misc_api.h"
+#include "../set.h"
+#include "logger_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../uintn.h"
+#include "../map.h"
+#include "../member_function_pointer.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class log_level
+ {
+ public:
+ log_level(
+ int priority_,
+ const char* name_
+ ) :
+ priority(priority_)
+ {
+ strncpy(name,name_,19);
+ name[19] = '\0';
+ }
+
+ bool operator< (const log_level& rhs) const { return priority < rhs.priority; }
+ bool operator<=(const log_level& rhs) const { return priority <= rhs.priority; }
+ bool operator> (const log_level& rhs) const { return priority > rhs.priority; }
+ bool operator>=(const log_level& rhs) const { return priority >= rhs.priority; }
+
+ int priority;
+ char name[20];
+ };
+
+ inline std::ostream& operator<< (std::ostream& out, const log_level& item)
+ {
+ out << item.name;
+ return out;
+ }
+
+ const log_level LALL (std::numeric_limits<int>::min(),"ALL");
+ const log_level LNONE (std::numeric_limits<int>::max(),"NONE");
+ const log_level LTRACE(-100,"TRACE");
+ const log_level LDEBUG(0 ,"DEBUG");
+ const log_level LINFO (100,"INFO ");
+ const log_level LWARN (200,"WARN ");
+ const log_level LERROR(300,"ERROR");
+ const log_level LFATAL(400,"FATAL");
+
+// ----------------------------------------------------------------------------------------
+
+ void set_all_logging_output_streams (
+ std::ostream& out
+ );
+
+ void set_all_logging_levels (
+ const log_level& new_level
+ );
+
+ typedef void (*print_header_type)(
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ );
+
+ void set_all_logging_headers (
+ const print_header_type& new_header
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void print_default_logger_header (
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ );
+
+ template <
+ typename T
+ >
+ void set_all_logging_output_hooks (
+ T& object,
+ void (T::*hook_)(const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id,
+ const char* message_to_log)
+ );
+
+ template <
+ typename T
+ >
+ void set_all_logging_output_hooks (
+ T& object
+ )
+ {
+ set_all_logging_output_hooks(object, &T::log);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class logger
+ {
+ /*!
+ INITIAL VALUE
+ - print_header == print_default_logger_header
+ - out.rdbuf() == std::cout.rdbuf()
+ - cur_level == LERROR
+ - auto_flush_enabled == true
+ - hook.is_set() == false
+
+ CONVENTION
+ - print_header == logger_header()
+ - if (hook.is_set() == false) then
+ - out.rdbuf() == output_streambuf()
+ - else
+ - out.rdbuf() == &gd.hookbuf
+ - output_streambuf() == 0
+
+ - cur_level == level()
+ - logger_name == name()
+ - auto_flush_enabled == auto_flush()
+
+ - logger::gd::loggers == a set containing all currently existing loggers.
+ - logger::gd::m == the mutex used to lock everything in the logger
+ - logger::gd::thread_names == a map of thread ids to thread names.
+ - logger::gd::next_thread_name == the next thread name that will be given out
+ to a thread when we find that it isn't already in thread_names.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ class logger_stream
+ {
+ /*!
+ INITIAL VALUE
+ - been_used == false
+
+ CONVENTION
+ - enabled == is_enabled()
+ - if (been_used) then
+ - logger::gd::m is locked
+ - someone has used the << operator to write something to the
+ output stream.
+ !*/
+ public:
+ logger_stream (
+ const log_level& l_,
+ logger& log_
+ ) :
+ l(l_),
+ log(log_),
+ been_used(false),
+ enabled (l.priority >= log.cur_level.priority)
+ {}
+
+ inline ~logger_stream(
+ )
+ {
+ if (!been_used)
+ {
+ return;
+ }
+ else
+ {
+ print_end_of_line();
+ }
+ }
+
+ bool is_enabled (
+ ) const { return enabled; }
+
+ template <typename T>
+ inline logger_stream& operator << (
+ const T& item
+ )
+ {
+ if (!enabled)
+ {
+ return *this;
+ }
+ else
+ {
+ print_header_and_stuff();
+ log.out << item;
+ return *this;
+ }
+ }
+
+ private:
+
+ void print_header_and_stuff (
+ );
+ /*!
+ ensures
+ - if (!been_used) then
+ - prints the logger header
+ - locks log.gd.m
+ - #been_used == true
+ !*/
+
+ void print_end_of_line (
+ );
+ /*!
+ ensures
+ - prints a newline to log.out
+ - unlocks log.gd.m
+ !*/
+
+ const log_level& l;
+ logger& log;
+ bool been_used;
+ const bool enabled;
+ }; // end of class logger_stream
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ friend class logger_stream;
+ public:
+
+ typedef member_function_pointer<const std::string&, const log_level&,
+ const uint64, const char*> hook_mfp;
+
+ logger (
+ const std::string& name_
+ );
+
+ virtual ~logger (
+ );
+
+ const std::string& name (
+ ) const { return logger_name; }
+
+ logger_stream operator << (
+ const log_level& l
+ ) const { return logger_stream(l,const_cast<logger&>(*this)); }
+
+ bool is_child_of (
+ const logger& log
+ ) const
+ {
+ return (name().find(log.name() + ".") == 0) || (log.name() == name());
+ }
+
+ const log_level level (
+ ) const
+ {
+ auto_mutex M(gd.m);
+ return log_level(cur_level);
+ };
+
+ void set_level (
+ const log_level& new_level
+ )
+ {
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ if (gd.loggers.element()->is_child_of(*this))
+ gd.loggers.element()->cur_level = new_level;
+ }
+
+ gd.set_level(logger_name, new_level);
+ }
+
+ bool auto_flush (
+ ) const
+ {
+ auto_mutex M(gd.m);
+ return auto_flush_enabled;
+ };
+
+ void set_auto_flush (
+ bool enabled
+ )
+ {
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ if (gd.loggers.element()->is_child_of(*this))
+ gd.loggers.element()->auto_flush_enabled = enabled;
+ }
+
+ gd.set_auto_flush(logger_name, enabled);
+ }
+
+ std::streambuf* output_streambuf (
+ )
+ {
+ auto_mutex M(gd.m);
+
+ // if there is an output hook set then we are supposed to return 0.
+ if (hook)
+ return 0;
+ else
+ return out.rdbuf();
+ }
+
+ template <
+ typename T
+ >
+ void set_output_hook (
+ T& object,
+ void (T::*hook_)(const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id,
+ const char* message_to_log)
+ )
+ {
+ auto_mutex M(gd.m);
+ hook.set(object, hook_);
+
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ if (gd.loggers.element()->is_child_of(*this))
+ {
+ gd.loggers.element()->out.rdbuf(&gd.hookbuf);
+ gd.loggers.element()->hook = hook;
+ }
+ }
+
+ gd.set_output_hook(logger_name, hook);
+ gd.set_output_stream(logger_name, gd.hookbuf);
+ }
+
+ void set_output_stream (
+ std::ostream& out_
+ )
+ {
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ if (gd.loggers.element()->is_child_of(*this))
+ {
+ gd.loggers.element()->out.rdbuf(out_.rdbuf());
+ gd.loggers.element()->hook.clear();
+ }
+ }
+
+ gd.set_output_stream(logger_name, out_);
+
+ hook.clear();
+ gd.set_output_hook(logger_name, hook);
+ }
+
+ print_header_type logger_header (
+ ) const { return print_header; }
+
+ void set_logger_header (
+ print_header_type ph
+ )
+ {
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ if (gd.loggers.element()->is_child_of(*this))
+ gd.loggers.element()->print_header = ph;
+ }
+
+ gd.set_logger_header(logger_name, ph);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ struct global_data
+ {
+ rmutex m;
+ set<logger*>::kernel_1b loggers;
+ map<thread_id_type,uint64>::kernel_1b thread_names;
+ uint64 next_thread_name;
+
+ // Make a very simple streambuf that writes characters into a std::vector<char>. We can
+ // use this as the output target for hooks. The reason we don't just use a std::ostringstream
+ // instead is that this way we can be guaranteed that logging doesn't perform memory allocations.
+ // This is because a std::vector never frees memory. I.e. its capacity() doesn't go down when
+ // you resize it back to 0. It just stays the same.
+ class hook_streambuf : public std::streambuf
+ {
+ public:
+ std::vector<char> buffer;
+ int_type overflow ( int_type c)
+ {
+ if (c != EOF) buffer.push_back(static_cast<char>(c));
+ return c;
+ }
+
+ std::streamsize xsputn ( const char* s, std::streamsize num)
+ {
+ buffer.insert(buffer.end(), s, s+num);
+ return num;
+ }
+ };
+
+ hook_streambuf hookbuf;
+
+ global_data (
+ );
+
+ ~global_data(
+ );
+
+ uint64 get_thread_name (
+ );
+ /*!
+ requires
+ - m is locked
+ ensures
+ - returns a unique id for the calling thread. also makes the number
+ small and nice unlike what you get from get_thread_id()
+ !*/
+
+ void thread_end_handler (
+ );
+ /*!
+ ensures
+ - removes the terminated thread from thread_names
+ !*/
+
+ struct level_container
+ {
+ level_container ();
+
+ log_level val;
+ map<std::string,std::unique_ptr<level_container> >::kernel_1b_c table;
+ } level_table;
+
+ const log_level level (
+ const std::string& name
+ ) const;
+ /*!
+ ensures
+ - returns the level loggers with the given name are supposed
+ to have
+ !*/
+
+ void set_level (
+ const std::string& name,
+ const log_level& new_level
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #level(C) == new_level
+ - if name == "" then
+ - for all loggers L:
+ - #level(L) == new_level
+ !*/
+
+ struct auto_flush_container
+ {
+ bool val;
+ map<std::string,std::unique_ptr<auto_flush_container> >::kernel_1b_c table;
+ } auto_flush_table;
+
+ bool auto_flush (
+ const std::string& name
+ ) const;
+ /*!
+ ensures
+ - returns the auto_flush value loggers with the given name are supposed
+ to have
+ !*/
+
+ void set_auto_flush (
+ const std::string& name,
+ bool enabled
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #auto_flush_enabled(C) == enabled
+ - if name == "" then
+ - for all loggers L:
+ - #auto_flush_enabled(L) == enabled
+ !*/
+
+ struct output_streambuf_container
+ {
+ std::streambuf* val;
+ map<std::string,std::unique_ptr<output_streambuf_container> >::kernel_1b_c table;
+ } streambuf_table;
+
+ std::streambuf* output_streambuf (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - returns the streambuf loggers with the given name are supposed
+ to have
+ !*/
+
+ void set_output_stream (
+ const std::string& name,
+ std::ostream& out_
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #output_streambuf(C) == out_.rdbuf()
+ - if name == "" then
+ - for all loggers L:
+ - #output_streambuf(L) == out_.rdbuf()
+ !*/
+
+ void set_output_stream (
+ const std::string& name,
+ std::streambuf& buf
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #output_streambuf(C) == &buf
+ - if name == "" then
+ - for all loggers L:
+ - #output_streambuf(L) == &buf
+ !*/
+
+ struct output_hook_container
+ {
+ hook_mfp val;
+ map<std::string,std::unique_ptr<output_hook_container> >::kernel_1b_c table;
+ } hook_table;
+
+ hook_mfp output_hook (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - returns the hook loggers with the given name are supposed
+ to have
+ !*/
+
+ void set_output_hook (
+ const std::string& name,
+ const hook_mfp& hook
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #output_hook(C) == hook
+ - if name == "" then
+ - for all loggers L:
+ - #output_hook(L) == hook
+ !*/
+
+ struct logger_header_container
+ {
+ print_header_type val;
+ map<std::string,std::unique_ptr<logger_header_container> >::kernel_1b_c table;
+ } header_table;
+
+ print_header_type logger_header (
+ const std::string& name
+ );
+ /*!
+ ensures
+ - returns the header function loggers with the given name are supposed
+ to have
+ !*/
+
+ void set_logger_header (
+ const std::string& name,
+ print_header_type ph
+ );
+ /*!
+ ensures
+ - for all children C of name:
+ - #logger_header(C) == ph
+ - if name == "" then
+ - for all loggers L:
+ - #logger_header(L) == ph
+ !*/
+
+ }; // end of struct global_data
+
+ static global_data& get_global_data();
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ friend void set_all_logging_levels (
+ const log_level& new_level
+ );
+
+ friend void set_all_logging_headers (
+ const print_header_type& new_header
+ );
+
+ friend void set_all_logging_output_streams (
+ std::ostream& out
+ );
+
+ template <
+ typename T
+ >
+ friend void set_all_logging_output_hooks (
+ T& object,
+ void (T::*hook_)(const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id,
+ const char* message_to_log)
+ )
+ {
+ logger::hook_mfp hook;
+
+ // There is a bug in one of the versions (but not all apparently) of
+ // Visual studio 2005 that causes it to error out if <T> isn't in the
+ // following line of code. However, there is also a bug in gcc-3.3
+ // that causes it to error out if <T> is present. So this works around
+ // this problem.
+#if defined(_MSC_VER) && _MSC_VER == 1400
+ hook.set<T>(object, hook_);
+#else
+ hook.set(object, hook_);
+#endif
+
+ logger::global_data& gd = logger::get_global_data();
+ auto_mutex M(gd.m);
+ gd.loggers.reset();
+ while (gd.loggers.move_next())
+ {
+ gd.loggers.element()->out.rdbuf(&gd.hookbuf);
+ gd.loggers.element()->hook = hook;
+ }
+
+ gd.set_output_stream("",gd.hookbuf);
+ gd.set_output_hook("",hook);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ global_data& gd;
+
+ const std::string logger_name;
+
+ print_header_type print_header;
+ bool auto_flush_enabled;
+ std::ostream out;
+ log_level cur_level;
+
+ hook_mfp hook;
+
+
+ // restricted functions
+ logger(const logger&); // copy constructor
+ logger& operator=(const logger&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+
+
+}
+
+#ifdef NO_MAKEFILE
+#include "logger_kernel_1.cpp"
+#endif
+
+#endif // DLIB_LOGGER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/logger/logger_kernel_abstract.h b/ml/dlib/dlib/logger/logger_kernel_abstract.h
new file mode 100644
index 000000000..b6a4367a2
--- /dev/null
+++ b/ml/dlib/dlib/logger/logger_kernel_abstract.h
@@ -0,0 +1,429 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LOGGER_KERNEl_ABSTRACT_
+#ifdef DLIB_LOGGER_KERNEl_ABSTRACT_
+
+#include "../threads.h"
+#include <limits>
+#include <string>
+#include <iostream>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class log_level
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple named level to log at. It contains a numeric
+ priority and a name to use in the logging messages.
+ !*/
+ public:
+ log_level(
+ int priority_,
+ const char* name_
+ );
+ /*!
+ ensures
+ - #priority = priority_
+ - the first 19 characters of name_ are copied into name and name
+ is null terminated.
+ !*/
+
+ bool operator< (const log_level& rhs) const { return priority < rhs.priority; }
+ bool operator<=(const log_level& rhs) const { return priority <= rhs.priority; }
+ bool operator> (const log_level& rhs) const { return priority > rhs.priority; }
+ bool operator>=(const log_level& rhs) const { return priority >= rhs.priority; }
+
+ int priority;
+ char name[20];
+ };
+
+ inline std::ostream& operator<< (std::ostream& out, const log_level& item);
+ /*!
+ ensures
+ - performs out << item.name
+ - returns out
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const log_level LALL (std::numeric_limits<int>::min(),"ALL");
+ const log_level LNONE (std::numeric_limits<int>::max(),"NONE");
+ const log_level LTRACE(-100,"TRACE");
+ const log_level LDEBUG(0 ,"DEBUG");
+ const log_level LINFO (100 ,"INFO ");
+ const log_level LWARN (200 ,"WARN ");
+ const log_level LERROR(300 ,"ERROR");
+ const log_level LFATAL(400 ,"FATAL");
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void set_all_logging_output_streams (
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - for all loggers L (even loggers not yet constructed):
+ - #L.output_streambuf() == out.rdbuf()
+ - Removes any previous output hook from L. So now the logger
+ L will write all its messages to the given output stream.
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ typedef void (*print_header_type)(
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ );
+
+ void set_all_logging_headers (
+ const print_header_type& new_header
+ );
+ /*!
+ ensures
+ - for all loggers L (even loggers not yet constructed):
+ - #L.logger_header() == new_header
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void set_all_logging_output_hooks (
+ T& object,
+ void (T::*hook)(const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id,
+ const char* message_to_log)
+ );
+ /*!
+ ensures
+ - for all loggers L (even loggers not yet constructed):
+ - #L.output_streambuf() == 0
+ - performs the equivalent to calling L.set_output_hook(object, hook);
+ (i.e. sets all loggers so that they will use the given hook function)
+ throws
+ - std::bad_alloc
+ !*/
+
+ template <
+ typename T
+ >
+ void set_all_logging_output_hooks (
+ T& object
+ );
+ /*!
+ ensures
+ - calls set_all_logging_output_hooks(object, &T::log);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void set_all_logging_levels (
+ const log_level& new_level
+ );
+ /*!
+ ensures
+ - for all loggers L (even loggers not yet constructed):
+ - #L.level() == new_level
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void print_default_logger_header (
+ std::ostream& out,
+ const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id
+ );
+ /*!
+ requires
+ - is not called more than once at a time (i.e. is not called from multiple
+ threads at the same time).
+ ensures
+ - let MS be the number of milliseconds since program start.
+ - prints a string to out in the form: "MS l.name [thread_id] logger_name:"
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class logger
+ {
+ /*!
+ INITIAL VALUE
+ - name() == a user supplied value given to the constructor
+ - The values of level(), output_streambuf(), logger_header(), and
+ auto_flush() are inherited from the parent of this logger.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a logging output stream in the style of the log4j
+ logger available for Java.
+
+ Additionally, the logger doesn't perform any memory allocations during
+ each logging action. It just writes directly into the user supplied output
+ stream. Alternatively, if you use a logging output hook no memory allocations
+ are performed either. Logging just goes straight into a memory buffer
+ which gets passed to the user supplied logging hook.
+
+ DEFAULTS
+ If the user hasn't specified values for the four inherited values level(),
+ output_streambuf(), logger_header(), or auto_flush() then the default
+ values will be used. The defaults are as follows:
+ - level() == LERROR
+ - output_streambuf() == std::cout.rdbuf() (i.e. the default is to log
+ to standard output).
+ - logger_header() == print_default_logger_header
+ - auto_flush() == true
+
+ THREAD SAFETY
+ All methods of this class are thread safe. Note that it is safe to
+ chain calls to operator << such as:
+ log << LINFO << "message " << variable << " more message";
+ The logger ensures that the entire statement executes atomically so the
+ message won't be broken up by other loggers in other threads.
+ !*/
+
+ class logger_stream
+ {
+ public:
+
+ bool is_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if this logger stream will print out items
+ given to it by the << operator. returns false otherwise.
+ !*/
+
+ template <typename T>
+ logger_stream& operator << (
+ const T& item
+ );
+ /*!
+ ensures
+ - if (is_enabled()) then
+ - writes item to this output stream
+ - returns *this
+ !*/
+ };
+
+ public:
+
+ logger (
+ const std::string& name_
+ );
+ /*!
+ requires
+ - name_ != ""
+ ensures
+ - #*this is properly initialized
+ - #name() == name_
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~logger (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ const std::string& name (
+ ) const;
+ /*!
+ ensures
+ - returns the name of this logger
+ !*/
+
+ logger_stream operator << (
+ const log_level& l
+ ) const;
+ /*!
+ ensures
+ - if (l.priority >= level().priority) then
+ - returns a logger_stream with is_enabled() == true. I.e. this
+ returned stream will write its output to the I/O destination
+ used by this logger object.
+ - else
+ - returns a logger stream with is_enabled() == false
+ throws
+ - std::bad_alloc
+ !*/
+
+ bool is_child_of (
+ const logger& log
+ ) const;
+ /*!
+ ensures
+ - if ( (name().find(log.name() + ".") == 0) || (log.name() == name()) ) then
+ - returns true
+ (i.e. if log.name() + "." is a prefix of name() or if both *this and log
+ have the same name then return true)
+ - else
+ - returns false
+ !*/
+
+ const log_level level (
+ ) const;
+ /*!
+ ensures
+ - returns the current log level of this logger.
+ !*/
+
+ void set_level (
+ const log_level& new_level
+ );
+ /*!
+ ensures
+ - for all loggers L such that L.is_child_of(*this) == true:
+ - #L.level() == new_level
+ throws
+ - std::bad_alloc
+ !*/
+
+ bool auto_flush (
+ );
+ /*!
+ ensures
+ - returns true if the output stream is flushed after every logged message.
+ returns false otherwise. (Note that flushing only does anything if
+ the logger is set to use an output stream rather than a hook)
+ !*/
+
+ void set_auto_flush (
+ bool enabled
+ );
+ /*!
+ ensures
+ - for all loggers L such that L.is_child_of(*this) == true:
+ - #L.auto_flush() == enabled
+ throws
+ - std::bad_alloc
+ !*/
+
+
+ template <
+ typename T
+ >
+ void set_output_hook (
+ T& object,
+ void (T::*hook)(const std::string& logger_name,
+ const log_level& l,
+ const uint64 thread_id,
+ const char* message_to_log)
+ );
+ /*!
+ requires
+ - hook is a valid pointer to a member function in T
+ ensures
+ - for all loggers L such that L.is_child_of(*this) == true:
+ - #L.output_streambuf() == 0
+ - #L will not send its log messages to an ostream object anymore. Instead
+ it will call the given hook member function (i.e. (object.*hook)(name,l,id,msg) )
+ for each message that needs to be logged.
+ - The arguments to the hook function have the following meanings:
+ - logger_name == The name of the logger that is printing the log message.
+ - l == The level of the logger that is printing the log message.
+ - thread_id == A number that uniquely identifies the thread trying to log
+ the message. Note that this number is unique among all threads, past and
+ present. Also note that this id is not the same one returned by
+ get_thread_id().
+ - message_to_log == the actual text of the message the user is giving to
+ the logger object to log.
+ - All hook functions will also only be called one at a time. This means
+ that hook functions don't need to be thread safe.
+ !*/
+
+ std::streambuf* output_streambuf (
+ );
+ /*!
+ ensures
+ - if (an output hook isn't set) then
+ - returns the output stream buffer that this logger writes all
+ messages to.
+ - else
+ - returns 0
+ !*/
+
+ void set_output_stream (
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - for all loggers L such that L.is_child_of(*this) == true:
+ - #L.output_streambuf() == out.rdbuf()
+ - Removes any previous output hook from L. So now the logger
+ L will write all its messages to the given output stream.
+ throws
+ - std::bad_alloc
+ !*/
+
+ print_header_type logger_header (
+ ) const;
+ /*!
+ ensures
+ - returns the function that is called to print the header information
+ onto each logged message. The arguments to the function have the following
+ meanings:
+ - out == The output stream this function writes the header to.
+ - logger_name == The name of the logger that is printing the log message.
+ - l == The level of the logger that is printing the log message.
+ - thread_id == A number that uniquely identifies the thread trying to log
+ the message. Note that this number is unique among all threads, past and
+ present. Also note that this id is not the same one returned by
+ get_thread_id().
+ - This logger_header function will also only be called once at a time. This means
+ the logger_header function doesn't need to be thread safe.
+ - the logger_header function is only used when output_streambuf() != 0
+ !*/
+
+ void set_logger_header (
+ print_header_type print_header
+ );
+ /*!
+ ensures
+ - for all loggers L such that L.is_child_of(*this) == true:
+ - #L.logger_header() == print_header
+ throws
+ - std::bad_alloc
+ !*/
+
+ private:
+
+ // restricted functions
+ logger(const logger&); // copy constructor
+ logger& operator=(const logger&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LOGGER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/lsh.h b/ml/dlib/dlib/lsh.h
new file mode 100644
index 000000000..28f4b9bc4
--- /dev/null
+++ b/ml/dlib/dlib/lsh.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LSh_
+#define DLIB_LSh_
+
+
+#include "lsh/projection_hash.h"
+#include "lsh/create_random_projection_hash.h"
+#include "lsh/hashes.h"
+
+
+#endif // DLIB_LSh_
+
+
diff --git a/ml/dlib/dlib/lsh/create_random_projection_hash.h b/ml/dlib/dlib/lsh/create_random_projection_hash.h
new file mode 100644
index 000000000..b3aecd9ec
--- /dev/null
+++ b/ml/dlib/dlib/lsh/create_random_projection_hash.h
@@ -0,0 +1,232 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_
+#define DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_
+
+#include "create_random_projection_hash_abstract.h"
+#include "projection_hash.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include "../statistics.h"
+#include "../svm.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_random_projection_hash (
+ const vector_type& v,
+ const int bits,
+ dlib::rand& rnd
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ v.size() > 1,
+ "\t projection_hash create_random_projection_hash()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t bits: " << bits
+ << "\n\t v.size(): " << v.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < v.size(); ++i)
+ {
+ DLIB_ASSERT(v[0].size() == v[i].size() && v[i].size() > 0 && is_col_vector(v[i]),
+ "\t projection_hash create_random_projection_hash()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t m(0).size(): " << v[0].size()
+ << "\n\t m("<<i<<").size(): " << v[i].size()
+ << "\n\t is_col_vector(v["<<i<<"]): " << is_col_vector(v[i])
+ );
+ }
+#endif
+
+ running_covariance<matrix<double> > rc;
+ for (unsigned long i = 0; i < v.size(); ++i)
+ rc.add(matrix_cast<double>(v[i]));
+
+ // compute a whitening matrix
+ matrix<double> whiten = trans(chol(pinv(rc.covariance())));
+
+
+ // hashes
+ std::vector<unsigned long> h(v.size(),0);
+
+ std::vector<double> vals(v.size(),0);
+
+ // number of hits for each hash value
+ std::vector<unsigned long> counts;
+
+ std::vector<double> temp;
+
+ // build a random projection matrix
+ matrix<double> proj(bits, v[0].size());
+ for (long r = 0; r < proj.nr(); ++r)
+ for (long c = 0; c < proj.nc(); ++c)
+ proj(r,c) = rnd.get_random_gaussian();
+
+ // merge whitening matrix with projection matrix
+ proj = proj*whiten;
+
+ matrix<double,0,1> offset(bits);
+
+
+ // figure out what the offset values should be
+ for (int itr = 0; itr < offset.size(); ++itr)
+ {
+ counts.assign(static_cast<unsigned long>(std::pow(2.0,bits)), 0);
+ // count the popularity of each hash value
+ for (unsigned long i = 0; i < h.size(); ++i)
+ {
+ h[i] <<= 1;
+ counts[h[i]] += 1;
+ }
+
+ const unsigned long max_h = index_of_max(mat(counts));
+
+ temp.clear();
+ for (unsigned long i = 0; i < v.size(); ++i)
+ {
+ vals[i] = dot(rowm(proj,itr), matrix_cast<double>(v[i]));
+ if (h[i] == max_h)
+ temp.push_back(vals[i]);
+ }
+
+ // split down the middle
+ std::sort(temp.begin(), temp.end());
+ const double split = temp[temp.size()/2];
+ offset(itr) = -split;
+
+ for (unsigned long i = 0; i < vals.size(); ++i)
+ {
+ if (vals[i] - split > 0)
+ h[i] |= 1;
+ }
+ }
+
+
+ return projection_hash(proj, offset);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_random_projection_hash (
+ const vector_type& v,
+ const int bits
+ )
+ {
+ dlib::rand rnd;
+ return create_random_projection_hash(v,bits,rnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_max_margin_projection_hash (
+ const vector_type& v,
+ const int bits,
+ const double C,
+ dlib::rand& rnd
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < bits && bits <= 32 &&
+ v.size() > 1,
+ "\t projection_hash create_max_margin_projection_hash()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t bits: " << bits
+ << "\n\t v.size(): " << v.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < v.size(); ++i)
+ {
+ DLIB_ASSERT(v[0].size() == v[i].size() && v[i].size() > 0 && is_col_vector(v[i]),
+ "\t projection_hash create_max_margin_projection_hash()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t m(0).size(): " << v[0].size()
+ << "\n\t m("<<i<<").size(): " << v[i].size()
+ << "\n\t is_col_vector(v["<<i<<"]): " << is_col_vector(v[i])
+ );
+ }
+#endif
+
+ running_covariance<matrix<double> > rc;
+ for (unsigned long i = 0; i < v.size(); ++i)
+ rc.add(matrix_cast<double>(v[i]));
+
+ // compute a whitening matrix
+ matrix<double> whiten = trans(chol(pinv(rc.covariance())));
+ const matrix<double,0,1> meanval = whiten*rc.mean();
+
+
+
+ typedef matrix<double,0,1> sample_type;
+ random_subset_selector<sample_type> training_samples;
+ random_subset_selector<double> training_labels;
+ // We set this up to use enough samples to cover the vector space used by elements
+ // of v.
+ training_samples.set_max_size(v[0].size()*10);
+ training_labels.set_max_size(v[0].size()*10);
+
+ matrix<double> proj(bits, v[0].size());
+ matrix<double,0,1> offset(bits);
+
+ // learn the random planes and put them into proj and offset.
+ for (int itr = 0; itr < offset.size(); ++itr)
+ {
+ training_samples.make_empty();
+ training_labels.make_empty();
+ // pick random training data and give each sample a random label.
+ for (unsigned long i = 0; i < v.size(); ++i)
+ {
+ training_samples.add(whiten*v[i]-meanval);
+ if (rnd.get_random_double() > 0.5)
+ training_labels.add(+1);
+ else
+ training_labels.add(-1);
+ }
+
+ svm_c_linear_dcd_trainer<linear_kernel<sample_type> > trainer;
+ trainer.set_c(C);
+ decision_function<linear_kernel<sample_type> > df = trainer.train(training_samples, training_labels);
+ offset(itr) = -df.b;
+ set_rowm(proj,itr) = trans(df.basis_vectors(0));
+ }
+
+
+ return projection_hash(proj*whiten, offset-proj*meanval);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_max_margin_projection_hash (
+ const vector_type& v,
+ const int bits,
+ const double C = 10
+ )
+ {
+ dlib::rand rnd;
+ return create_max_margin_projection_hash(v,bits,C,rnd);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_
+
diff --git a/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h b/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h
new file mode 100644
index 000000000..cff55b9a5
--- /dev/null
+++ b/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h
@@ -0,0 +1,148 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_
+#ifdef DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_
+
+#include "projection_hash_abstract.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_random_projection_hash (
+ const vector_type& v,
+ const int bits,
+ dlib::rand& rnd
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - v.size() > 1
+ - vector_type == a std::vector or compatible type containing dlib::matrix
+ objects, each representing a column vector of the same size.
+ - for all valid i, j:
+ - is_col_vector(v[i]) == true
+ - v[i].size() > 0
+ - v[i].size() == v[j].size()
+ - i.e. v contains only column vectors and all the column vectors
+ have the same non-zero length
+ - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
+ ensures
+ - returns a hash function H such that:
+ - H.num_hash_bins() == pow(2,bits)
+ - H will be setup so that it hashes the contents of v such that each bin
+ ends up with roughly the same number of elements in it. This is
+ accomplished by picking random hyperplanes passing though the data. In
+ particular, each plane normal vector is filled with Gaussian random
+ numbers and we also perform basic centering to ensure the plane passes
+ though the data.
+ - This function uses the supplied random number generator, rnd, to drive part
+ of it's processing. Therefore, giving different random number generators
+ will produce different outputs.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_random_projection_hash (
+ const vector_type& v,
+ const int bits
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - v.size() > 1
+ - vector_type == a std::vector or compatible type containing dlib::matrix
+ objects, each representing a column vector of the same size.
+ - for all valid i, j:
+ - is_col_vector(v[i]) == true
+ - v[i].size() > 0
+ - v[i].size() == v[j].size()
+ - i.e. v contains only column vectors and all the column vectors
+ have the same non-zero length
+ ensures
+ - returns create_random_projection_hash(v,bits,dlib::rand())
+ (i.e. calls the above function with a default initialized random number generator)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_max_margin_projection_hash (
+ const vector_type& v,
+ const int bits,
+ const double C,
+ dlib::rand& rnd
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - v.size() > 1
+ - vector_type == a std::vector or compatible type containing dlib::matrix
+ objects, each representing a column vector of the same size.
+ - for all valid i, j:
+ - is_col_vector(v[i]) == true
+ - v[i].size() > 0
+ - v[i].size() == v[j].size()
+ - i.e. v contains only column vectors and all the column vectors
+ have the same non-zero length
+ - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
+ ensures
+ - returns a hash function H such that:
+ - H.num_hash_bins() == pow(2,bits)
+ - H will be setup so that it hashes the contents of v such that
+ each bin ends up with roughly the same number of elements
+ in it. This is accomplished using a variation on the random hyperplane
+ generation technique from the paper:
+ Random Maximum Margin Hashing by Alexis Joly and Olivier Buisson
+ In particular, we use the svm_c_linear_dcd_trainer to generate planes.
+ We train it on randomly selected and randomly labeled points from v.
+ The C SVM parameter is set to the given C argument.
+ - This function uses the supplied random number generator, rnd, to drive part
+ of it's processing. Therefore, giving different random number generators
+ will produce different outputs.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ projection_hash create_max_margin_projection_hash (
+ const vector_type& v,
+ const int bits,
+ const double C = 10
+ );
+ /*!
+ requires
+ - 0 < bits <= 32
+ - v.size() > 1
+ - vector_type == a std::vector or compatible type containing dlib::matrix
+ objects, each representing a column vector of the same size.
+ - for all valid i, j:
+ - is_col_vector(v[i]) == true
+ - v[i].size() > 0
+ - v[i].size() == v[j].size()
+ - i.e. v contains only column vectors and all the column vectors
+ have the same non-zero length
+ ensures
+ - returns create_max_margin_projection_hash(v,bits,C,dlib::rand())
+ (i.e. calls the above function with a default initialized random number generator)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/lsh/hashes.h b/ml/dlib/dlib/lsh/hashes.h
new file mode 100644
index 000000000..35053ce4e
--- /dev/null
+++ b/ml/dlib/dlib/lsh/hashes.h
@@ -0,0 +1,219 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LSH_HAShES_Hh_
+#define DLIB_LSH_HAShES_Hh_
+
+#include "hashes_abstract.h"
+#include "../hash.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class hash_similar_angles_64
+ {
+ public:
+ hash_similar_angles_64 (
+ ) : seed(0) {}
+
+ hash_similar_angles_64 (
+ const uint64 seed_
+ ) : seed(seed_) {}
+
+ uint64 get_seed (
+ ) const { return seed; }
+
+
+ typedef uint64 result_type;
+
+ template <
+ typename sparse_vector_type
+ >
+ typename disable_if<is_matrix<sparse_vector_type>,uint64>::type operator() (
+ const sparse_vector_type& v
+ ) const
+ {
+ typedef typename sparse_vector_type::value_type::second_type scalar_type;
+
+ uint64 temp = 0;
+ for (int i = 0; i < 64; ++i)
+ {
+ // compute the dot product between v and a Gaussian random vector.
+ scalar_type val = 0;
+ for (typename sparse_vector_type::const_iterator j = v.begin(); j != v.end(); ++j)
+ val += j->second*gaussian_random_hash(j->first, i, seed);
+
+ if (val > 0)
+ temp |= 1;
+ temp <<= 1;
+ }
+ return temp;
+ }
+
+ template <typename EXP>
+ uint64 operator() (
+ const matrix_exp<EXP>& v
+ ) const
+ {
+ typedef typename EXP::type T;
+ uint64 temp = 0;
+ for (unsigned long i = 0; i < 64; ++i)
+ {
+ if (dot(matrix_cast<T>(gaussian_randm(v.size(),1,i+seed*64)), v) > 0)
+ temp |= 1;
+ temp <<= 1;
+ }
+ return temp;
+ }
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const
+ {
+ return hamming_distance(a,b);
+ }
+
+ private:
+ const uint64 seed;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class hash_similar_angles_128
+ {
+ public:
+ hash_similar_angles_128 (
+ ) : seed(0),hasher1(0), hasher2(1) {}
+
+ hash_similar_angles_128 (
+ const uint64 seed_
+ ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {}
+
+ uint64 get_seed (
+ ) const { return seed; }
+
+ typedef std::pair<uint64,uint64> result_type;
+
+ template <
+ typename vector_type
+ >
+ result_type operator() (
+ const vector_type& v
+ ) const
+ {
+ return std::make_pair(hasher1(v), hasher2(v));
+ }
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const
+ {
+ return hamming_distance(a.first,b.first) +
+ hamming_distance(a.second,b.second);
+ }
+
+ private:
+ const uint64 seed;
+ hash_similar_angles_64 hasher1;
+ hash_similar_angles_64 hasher2;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class hash_similar_angles_256
+ {
+ public:
+ hash_similar_angles_256 (
+ ) : seed(0), hasher1(0), hasher2(1) {}
+
+ hash_similar_angles_256 (
+ const uint64 seed_
+ ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {}
+
+ uint64 get_seed (
+ ) const { return seed; }
+
+ typedef std::pair<uint64,uint64> hash128_type;
+ typedef std::pair<hash128_type,hash128_type> result_type;
+
+ template <
+ typename vector_type
+ >
+ result_type operator() (
+ const vector_type& v
+ ) const
+ {
+ return std::make_pair(hasher1(v), hasher2(v));
+ }
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const
+ {
+ return hasher1.distance(a.first,b.first) +
+ hasher1.distance(a.second,b.second);
+ }
+
+ private:
+ const uint64 seed;
+ hash_similar_angles_128 hasher1;
+ hash_similar_angles_128 hasher2;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class hash_similar_angles_512
+ {
+ public:
+ hash_similar_angles_512 (
+ ) : seed(0), hasher1(0), hasher2(1) {}
+
+ hash_similar_angles_512 (
+ const uint64 seed_
+ ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {}
+
+ uint64 get_seed (
+ ) const { return seed; }
+
+
+ typedef hash_similar_angles_256::result_type hash256_type;
+ typedef std::pair<hash256_type,hash256_type> result_type;
+
+ template <
+ typename vector_type
+ >
+ result_type operator() (
+ const vector_type& v
+ ) const
+ {
+ return std::make_pair(hasher1(v), hasher2(v));
+ }
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const
+ {
+ return hasher1.distance(a.first,b.first) +
+ hasher1.distance(a.second,b.second);
+ }
+
+ private:
+ const uint64 seed;
+ hash_similar_angles_256 hasher1;
+ hash_similar_angles_256 hasher2;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LSH_HAShES_Hh_
+
diff --git a/ml/dlib/dlib/lsh/hashes_abstract.h b/ml/dlib/dlib/lsh/hashes_abstract.h
new file mode 100644
index 000000000..27f8ddb69
--- /dev/null
+++ b/ml/dlib/dlib/lsh/hashes_abstract.h
@@ -0,0 +1,286 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LSH_HAShES_ABSTRACT_Hh_
+#ifdef DLIB_LSH_HAShES_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class hash_similar_angles_64
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for computing locality sensitive hashes that give
+ vectors with small angles between each other similar hash values. In
+ particular, this object creates 64 random planes which pass though the
+ origin and uses them to create a 64bit hash. To compute the hash for a new
+ vector, this object checks which side of each plane the vector falls on and
+ records this information into a 64bit integer.
+ !*/
+
+ public:
+
+ hash_similar_angles_64 (
+ );
+ /*!
+ ensures
+ - #get_seed() == 0
+ !*/
+
+ hash_similar_angles_64 (
+ const uint64 seed
+ );
+ /*!
+ ensures
+ - #get_seed() == seed
+ !*/
+
+ uint64 get_seed (
+ ) const;
+ /*!
+ ensures
+ - returns the random seed used to generate the random planes used for
+ hashing.
+ !*/
+
+ typedef uint64 result_type;
+
+ template <typename vector_type>
+ result_type perator() (
+ const vector_type& v
+ ) const;
+ /*!
+ requires
+ - v is an unsorted sparse vector or a dlib matrix representing either a
+ column or row vector.
+ ensures
+ - returns a 64 bit hash of the input vector v. The bits in the hash record
+ which side of each random plane v falls on.
+
+ !*/
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const;
+ /*!
+ ensures
+ - returns the Hamming distance between the two hashes given to this
+ function. That is, we return the number of bits in a and b which differ.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct hash_similar_angles_128
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for computing locality sensitive hashes that give
+ vectors with small angles between each other similar hash values. In
+ particular, this object creates 128 random planes which pass though the
+ origin and uses them to create a 128bit hash. To compute the hash for a new
+ vector, this object checks which side of each plane the vector falls on and
+ records this information into a 128bit integer.
+ !*/
+
+ public:
+
+ hash_similar_angles_128 (
+ );
+ /*!
+ ensures
+ - #get_seed() == 0
+ !*/
+
+ hash_similar_angles_128 (
+ const uint64 seed
+ );
+ /*!
+ ensures
+ - #get_seed() == seed
+ !*/
+
+ uint64 get_seed (
+ ) const;
+ /*!
+ ensures
+ - returns the random seed used to generate the random planes used for
+ hashing.
+ !*/
+
+ typedef std::pair<uint64,uint64> result_type;
+
+ template <typename vector_type>
+ result_type perator() (
+ const vector_type& v
+ ) const;
+ /*!
+ requires
+ - v is an unsorted sparse vector or a dlib matrix representing either a
+ column or row vector.
+ ensures
+ - returns a 128 bit hash of the input vector v. The bits in the hash record
+ which side of each random plane v falls on.
+
+ !*/
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const;
+ /*!
+ ensures
+ - returns the Hamming distance between the two hashes given to this
+ function. That is, we return the number of bits in a and b which differ.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct hash_similar_angles_256
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for computing locality sensitive hashes that give
+ vectors with small angles between each other similar hash values. In
+ particular, this object creates 256 random planes which pass though the
+ origin and uses them to create a 256bit hash. To compute the hash for a new
+ vector, this object checks which side of each plane the vector falls on and
+ records this information into a 256bit integer.
+ !*/
+
+ public:
+
+ hash_similar_angles_256 (
+ );
+ /*!
+ ensures
+ - #get_seed() == 0
+ !*/
+
+ hash_similar_angles_256 (
+ const uint64 seed
+ );
+ /*!
+ ensures
+ - #get_seed() == seed
+ !*/
+
+ uint64 get_seed (
+ ) const;
+ /*!
+ ensures
+ - returns the random seed used to generate the random planes used for
+ hashing.
+ !*/
+
+ typedef std::pair<uint64,uint64> hash128_type;
+ typedef std::pair<hash128_type,hash128_type> result_type;
+
+ template <typename vector_type>
+ result_type perator() (
+ const vector_type& v
+ ) const;
+ /*!
+ requires
+ - v is an unsorted sparse vector or a dlib matrix representing either a
+ column or row vector.
+ ensures
+ - returns a 256 bit hash of the input vector v. The bits in the hash record
+ which side of each random plane v falls on.
+
+ !*/
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const;
+ /*!
+ ensures
+ - returns the Hamming distance between the two hashes given to this
+ function. That is, we return the number of bits in a and b which differ.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct hash_similar_angles_512
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for computing locality sensitive hashes that give
+ vectors with small angles between each other similar hash values. In
+ particular, this object creates 512 random planes which pass though the
+ origin and uses them to create a 512bit hash. To compute the hash for a new
+ vector, this object checks which side of each plane the vector falls on and
+ records this information into a 512bit integer.
+ !*/
+
+ public:
+
+ hash_similar_angles_512 (
+ );
+ /*!
+ ensures
+ - #get_seed() == 0
+ !*/
+
+ hash_similar_angles_512 (
+ const uint64 seed
+ );
+ /*!
+ ensures
+ - #get_seed() == seed
+ !*/
+
+ uint64 get_seed (
+ ) const;
+ /*!
+ ensures
+ - returns the random seed used to generate the random planes used for
+ hashing.
+ !*/
+
+ typedef hash_similar_angles_256::result_type hash256_type;
+ typedef std::pair<hash256_type,hash256_type> result_type;
+
+ template <typename vector_type>
+ result_type perator() (
+ const vector_type& v
+ ) const;
+ /*!
+ requires
+ - v is an unsorted sparse vector or a dlib matrix representing either a
+ column or row vector.
+ ensures
+ - returns a 512 bit hash of the input vector v. The bits in the hash record
+ which side of each random plane v falls on.
+
+ !*/
+
+ unsigned int distance (
+ const result_type& a,
+ const result_type& b
+ ) const;
+ /*!
+ ensures
+ - returns the Hamming distance between the two hashes given to this
+ function. That is, we return the number of bits in a and b which differ.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LSH_HAShES_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/lsh/projection_hash.h b/ml/dlib/dlib/lsh/projection_hash.h
new file mode 100644
index 000000000..16de0ba11
--- /dev/null
+++ b/ml/dlib/dlib/lsh/projection_hash.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PROJECTION_HASh_Hh_
+#define DLIB_PROJECTION_HASh_Hh_
+
+#include "projection_hash_abstract.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class projection_hash
+ {
+ public:
+
+ projection_hash() {}
+
+ template <typename EXP1, typename EXP2>
+ projection_hash(
+ const matrix_exp<EXP1>& proj_,
+ const matrix_exp<EXP2>& offset_
+ ) : proj(proj_), offset(offset_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(proj.nr() == offset.nr(),
+ "\t projection_hash::projection_hash()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t proj.nr(): " << proj.nr()
+ << "\n\t offset.nr(): " << offset.nr()
+ );
+
+ }
+
+ const matrix<double>& get_projection_matrix (
+ ) const { return proj; }
+
+ const matrix<double,0,1>& get_offset_matrix (
+ ) const { return offset; }
+
+ unsigned long num_hash_bins (
+ ) const
+ {
+ return static_cast<unsigned long>(std::pow(2.0, (double)offset.size()));
+ }
+
+ template <typename EXP>
+ unsigned long operator() (
+ const matrix_exp<EXP>& v
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(v) &&
+ v.size() == get_projection_matrix().nc() &&
+ v.size() > 0,
+ "\t unsigned long projection_hash::operator()(v)"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t is_col_vector(v): " << is_col_vector(v)
+ << "\n\t get_projection_matrix().nc(): " << get_projection_matrix().nc()
+ << "\n\t v.size(): " << v.size()
+ );
+
+ return do_hash(proj*matrix_cast<double>(v) + offset);
+ }
+
+ private:
+
+ template <typename EXP>
+ unsigned long do_hash (
+ const matrix_exp<EXP>& v
+ ) const
+ {
+ unsigned long h = 0;
+ for (long i = 0; i < v.size(); ++i)
+ {
+ h <<= 1;
+ if (v(i) > 0)
+ h |= 1;
+ }
+ return h;
+ }
+
+ matrix<double> proj;
+ matrix<double,0,1> offset;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const projection_hash& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.get_projection_matrix(), out);
+ serialize(item.get_offset_matrix(), out);
+ }
+
+ inline void deserialize (
+ projection_hash& item,
+ std::istream& in
+ )
+ {
+ matrix<double> proj;
+ matrix<double,0,1> offset;
+ deserialize(proj, in);
+ deserialize(offset, in);
+ item = projection_hash(proj, offset);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PROJECTION_HASh_Hh_
+
diff --git a/ml/dlib/dlib/lsh/projection_hash_abstract.h b/ml/dlib/dlib/lsh/projection_hash_abstract.h
new file mode 100644
index 000000000..abe78d10c
--- /dev/null
+++ b/ml/dlib/dlib/lsh/projection_hash_abstract.h
@@ -0,0 +1,119 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_PROJECTION_HASh_ABSTRACT_Hh_
+#ifdef DLIB_PROJECTION_HASh_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class projection_hash
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for hashing elements of a vector space into the integers.
+ It is intended to represent locality sensitive hashing functions such as
+ the popular random projection hashing method.
+
+ In particular, it represents hash functions of the form:
+ hash bit 0 = sign(rowm(P*v + O,0))
+ hash bit 1 = sign(rowm(P*v + O,1))
+ hash bit 2 = sign(rowm(P*v + O,2))
+ ...
+ Where v is the vector to be hashed. The parameters of the projection
+ hash are the P and O matrices.
+
+ THREAD SAFETY
+ The const members of this object can be called concurrently from multiple
+ threads, however, any operation that modifies the state of an instance of
+ this object must serialize access to that instance.
+ !*/
+ public:
+
+ projection_hash(
+ );
+ /*!
+ ensures
+ - #get_projection_matrix().size() == 0
+ - #get_offset_matrix().size() == 0
+ !*/
+
+ template <typename EXP1, typename EXP2>
+ projection_hash(
+ const matrix_exp<EXP1>& proj,
+ const matrix_exp<EXP2>& offset
+ );
+ /*!
+ requires
+ - proj.nr() == offset.nr()
+ ensures
+ - #get_projection_matrix() == proj
+ - #get_offset_matrix() == offset
+ !*/
+
+ const matrix<double>& get_projection_matrix (
+ ) const;
+ /*!
+ ensures
+ - returns the P matrix discussed above in the WHAT THIS OBJECT REPRESENTS
+ section.
+ !*/
+
+ const matrix<double,0,1>& get_offset_matrix (
+ ) const;
+ /*!
+ ensures
+ - returns the O matrix discussed above in the WHAT THIS OBJECT REPRESENTS
+ section.
+ !*/
+
+ unsigned long num_hash_bins (
+ ) const;
+ /*!
+ ensures
+ - returns the number of possible outputs from this hashing function.
+ - Specifically, returns: std::pow(2, get_offset_matrix().size())
+ !*/
+
+ template <typename EXP>
+ unsigned long operator() (
+ const matrix_exp<EXP>& v
+ ) const;
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - v.size() == get_projection_matrix().nc()
+ - v.size() > 0
+ ensures
+ - hashes v into the range [0, num_hash_bins()) using the method
+ discussed in the WHAT THIS OBJECT REPRESENTS section.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const projection_hash& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ projection_hash& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PROJECTION_HASh_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/lz77_buffer.h b/ml/dlib/dlib/lz77_buffer.h
new file mode 100644
index 000000000..b7364ad9a
--- /dev/null
+++ b/ml/dlib/dlib/lz77_buffer.h
@@ -0,0 +1,47 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZ77_BUFFEr_
+#define DLIB_LZ77_BUFFEr_
+
+
+#include "lz77_buffer/lz77_buffer_kernel_1.h"
+#include "lz77_buffer/lz77_buffer_kernel_2.h"
+#include "lz77_buffer/lz77_buffer_kernel_c.h"
+
+#include "sliding_buffer.h"
+
+
+namespace dlib
+{
+
+
+ class lz77_buffer
+ {
+
+ lz77_buffer() {}
+
+ typedef sliding_buffer<unsigned char>::kernel_1a sb1;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef lz77_buffer_kernel_1<sb1>
+ kernel_1a;
+ typedef lz77_buffer_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef lz77_buffer_kernel_2<sb1>
+ kernel_2a;
+ typedef lz77_buffer_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ };
+}
+
+#endif // DLIB_LZ77_BUFFEr_
+
diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h
new file mode 100644
index 000000000..b93f0628d
--- /dev/null
+++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h
@@ -0,0 +1,263 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZ77_BUFFER_KERNEl_1_
+#define DLIB_LZ77_BUFFER_KERNEl_1_
+
+#include "lz77_buffer_kernel_abstract.h"
+#include "../algs.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename sliding_buffer
+ >
+ class lz77_buffer_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON sliding_buffer
+ sliding_buffer must be an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+
+ INITIAL VALUE
+ history_limit == defined by constructor arguments
+ lookahead_limit == defined by constructor arguments
+ history_size == 0
+ lookahead_size == 0
+ buffer.size() == history_limit + lookahead_limit
+
+
+ CONVENTION
+ history_limit == get_history_buffer_limit()
+ lookahead_limit == get_lookahead_buffer_limit()
+ history_size == get_history_buffer_size()
+ lookahead_limit == get_lookahead_buffer_size()
+
+ buffer.size() == history_limit + lookahead_limit
+
+ lookahead_buffer(i) == buffer[lookahead_limit-1-i]
+ history_buffer(i) == buffer[lookahead_limit+i]
+ !*/
+
+ public:
+
+ lz77_buffer_kernel_1 (
+ unsigned long total_limit_,
+ unsigned long lookahead_limit_
+ );
+
+ virtual ~lz77_buffer_kernel_1 (
+ ) {}
+
+ void clear(
+ );
+
+ void add (
+ unsigned char symbol
+ );
+
+ void find_match (
+ unsigned long& index,
+ unsigned long& length,
+ unsigned long min_match_length
+ );
+
+ inline unsigned long get_history_buffer_limit (
+ ) const { return history_limit; }
+
+ inline unsigned long get_lookahead_buffer_limit (
+ ) const { return lookahead_limit; }
+
+ inline unsigned long get_history_buffer_size (
+ ) const { return history_size; }
+
+ inline unsigned long get_lookahead_buffer_size (
+ ) const { return lookahead_size; }
+
+ inline unsigned char lookahead_buffer (
+ unsigned long index
+ ) const { return buffer[lookahead_limit-1-index]; }
+
+ inline unsigned char history_buffer (
+ unsigned long index
+ ) const { return buffer[lookahead_limit+index]; }
+
+
+ inline void shift_buffers (
+ unsigned long N
+ ) { shift_buffer(N); }
+
+ private:
+
+
+ inline void shift_buffer (
+ unsigned long N
+ )
+ /*!
+ requires
+ - N <= lookahead_size
+ ensuers
+ - #lookahead_size == lookahead_size - N
+ - if (history_size+N < history_limit) then
+ - #history_size == history_size+N
+ - else
+ - #history_size == history_limit
+ - for all i where 0 <= i < N:
+ #history_buffer(N-1-i) == lookahead_buffer(i)
+ - for all i where 0 <= i < #history_size-N:
+ #history_buffer(N+i) == history_buffer(i)
+ - for all i where 0 <= i < #lookahead_size
+ #lookahead_buffer(i) == lookahead_buffer(N+i)
+ !*/
+ {
+ unsigned long temp = history_size+N;
+ buffer.rotate_left(N);
+ lookahead_size -= N;
+ if (temp < history_limit)
+ history_size = temp;
+ else
+ history_size = history_limit;
+ }
+
+
+ // member data
+ sliding_buffer buffer;
+ unsigned long lookahead_limit;
+ unsigned long history_limit;
+
+
+ unsigned long lookahead_size;
+ unsigned long history_size;
+
+
+ // restricted functions
+ lz77_buffer_kernel_1(lz77_buffer_kernel_1<sliding_buffer>&); // copy constructor
+ lz77_buffer_kernel_1<sliding_buffer>& operator=(lz77_buffer_kernel_1<sliding_buffer>&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ lz77_buffer_kernel_1<sliding_buffer>::
+ lz77_buffer_kernel_1 (
+ unsigned long total_limit_,
+ unsigned long lookahead_limit_
+ ) :
+ lookahead_size(0),
+ history_size(0)
+ {
+ buffer.set_size(total_limit_);
+ lookahead_limit = lookahead_limit_;
+ history_limit = buffer.size() - lookahead_limit_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_1<sliding_buffer>::
+ clear(
+ )
+ {
+ lookahead_size = 0;
+ history_size = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_1<sliding_buffer>::
+ add (
+ unsigned char symbol
+ )
+ {
+ if (lookahead_size == lookahead_limit)
+ {
+ shift_buffer(1);
+ }
+ buffer[lookahead_limit-1-lookahead_size] = symbol;
+ ++lookahead_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_1<sliding_buffer>::
+ find_match (
+ unsigned long& index,
+ unsigned long& length,
+ unsigned long min_match_length
+ )
+ {
+ unsigned long hpos = history_size; // current position in the history buffer
+ unsigned long lpos = 0; // current position in the lookahead buffer
+
+ unsigned long match_length = 0; // the length of the longest match we find
+ unsigned long match_index = 0; // the index of the longest match we find
+
+ // try to find a match
+ while (hpos != 0)
+ {
+ --hpos;
+ // if we are finding a match
+ if (history_buffer(hpos) == lookahead_buffer(lpos))
+ {
+ ++lpos;
+ // if we have found a match that is as long as the lookahead buffer
+ // then we are done
+ if (lpos == lookahead_size)
+ break;
+ }
+ // else if we found the end of a match
+ else if (lpos > 0)
+ {
+ // if this match is longer than the last match we saw
+ if (lpos > match_length)
+ {
+ match_length = lpos;
+ match_index = hpos + lpos;
+ }
+ lpos = 0;
+ }
+ } // while (hpos != 0)
+
+ // if we found a match at the end of the loop that is greater than
+ // the match in match_index
+ if (lpos > match_length)
+ {
+ match_length = lpos;
+ match_index = hpos + lpos - 1;
+ }
+
+
+ // if we found a match that was long enough then report it
+ if (match_length >= min_match_length)
+ {
+ shift_buffer(match_length);
+ index = match_index;
+ length = match_length;
+ }
+ else
+ {
+ length = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZ77_BUFFER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h
new file mode 100644
index 000000000..f8d332784
--- /dev/null
+++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h
@@ -0,0 +1,504 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZ77_BUFFER_KERNEl_2_
+#define DLIB_LZ77_BUFFER_KERNEl_2_
+
+#include "lz77_buffer_kernel_abstract.h"
+#include "../algs.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename sliding_buffer
+ >
+ class lz77_buffer_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON sliding_buffer
+ sliding_buffer must be an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+ and must be instantiated to contain unsigned char data
+
+ INITIAL VALUE
+ history_limit == defined by constructor arguments
+ lookahead_limit == defined by constructor arguments
+ history_size == 0
+ lookahead_size == 0
+ buffer.size() == history_limit + lookahead_limit
+ buffer[i] == 0 for all valid i
+
+ nodes == an array of history_limit-3 nodes
+ id_table == an array of buffer.size() pointers
+ hash_table == an array of buffer.size() pointers and all are set to 0
+ mask == buffer.size() - 1
+ next_free_node == 0
+
+
+ CONVENTION
+ history_limit == get_history_buffer_limit()
+ lookahead_limit == get_lookahead_buffer_limit()
+ history_size == get_history_buffer_size()
+ lookahead_limit == get_lookahead_buffer_size()
+
+ buffer.size() == history_limit + lookahead_limit
+
+ lookahead_buffer(i) == buffer[lookahead_limit-1-i]
+ history_buffer(i) == buffer[lookahead_limit+i]
+
+
+ hash_table[hash(a,b,c,d)] points to the head of a linked list.
+ Each node in this linked list tells the location in the buffer
+ of a string that begins with abcd or a string who's first four
+ letters have the same hash. The linked list is terminated by a
+ node with a null next pointer.
+
+ hash_table[i] == 0 if there is no linked list for this element of the hash
+ table.
+
+ each node in the hash table is allocated from the array nodes.
+ When adding a node to hash_table:
+ if (if all nodes aren't already in the hash_table) then
+ {
+ the next node to use is nodes[next_free_node].
+ }
+ else
+ {
+ recycle nodes from the hash_table itself. This works because
+ when we add new nodes we also have to remove nodes.
+ }
+
+ if (there is a node defined with an id of i) then
+ {
+ if (id_table[i] != 0) then
+ id_table[i]->next->id == i
+ else
+ hash_table[some_hash]->id == i
+ }
+ !*/
+
+ public:
+
+ lz77_buffer_kernel_2 (
+ unsigned long total_limit_,
+ unsigned long lookahead_limit_
+ );
+
+ virtual ~lz77_buffer_kernel_2 (
+ );
+
+ void clear(
+ );
+
+ void add (
+ unsigned char symbol
+ );
+
+ void find_match (
+ unsigned long& index,
+ unsigned long& length,
+ unsigned long min_match_length
+ );
+
+ inline unsigned long get_history_buffer_limit (
+ ) const { return history_limit; }
+
+ inline unsigned long get_lookahead_buffer_limit (
+ ) const { return lookahead_limit; }
+
+ inline unsigned long get_history_buffer_size (
+ ) const { return history_size; }
+
+ inline unsigned long get_lookahead_buffer_size (
+ ) const { return lookahead_size; }
+
+ inline unsigned char lookahead_buffer (
+ unsigned long index
+ ) const { return buffer[lookahead_limit-1-index]; }
+
+ inline unsigned char history_buffer (
+ unsigned long index
+ ) const { return buffer[lookahead_limit+index]; }
+
+
+ inline void shift_buffers (
+ unsigned long N
+ ) { shift_buffer(N); }
+
+ private:
+
+ inline unsigned long hash (
+ unsigned char a,
+ unsigned char b,
+ unsigned char c,
+ unsigned char d
+ ) const
+ /*!
+ ensures
+ - returns a hash of the 4 arguments and the hash is in the range
+ !*/
+ {
+ unsigned long B = b << 3;
+ unsigned long C = c << 6;
+ unsigned long D = d << 9;
+
+ unsigned long temp = a + B;
+ temp += C;
+ temp += D;
+
+ return (temp&mask); /**/
+ }
+
+ void shift_buffer (
+ unsigned long N
+ );
+ /*!
+ requires
+ - N <= lookahead_size
+ ensuers
+ - #lookahead_size == lookahead_size - N
+ - if (history_size+N < history_limit) then
+ - #history_size == history_size+N
+ - else
+ - #history_size == history_limit
+ - for all i where 0 <= i < N:
+ #history_buffer(N-1-i) == lookahead_buffer(i)
+ - for all i where 0 <= i < #history_size-N:
+ #history_buffer(N+i) == history_buffer(i)
+ - for all i where 0 <= i < #lookahead_size
+ #lookahead_buffer(i) == lookahead_buffer(N+i)
+ !*/
+
+
+
+ // member data
+ sliding_buffer buffer;
+ unsigned long lookahead_limit;
+ unsigned long history_limit;
+
+ struct node
+ {
+ unsigned long id;
+ node* next;
+ };
+
+ node** hash_table;
+ node* nodes;
+ node** id_table;
+ unsigned long next_free_node;
+ unsigned long mask;
+
+ unsigned long lookahead_size;
+ unsigned long history_size;
+
+
+ // restricted functions
+ lz77_buffer_kernel_2(lz77_buffer_kernel_2<sliding_buffer>&); // copy constructor
+ lz77_buffer_kernel_2<sliding_buffer>& operator=(lz77_buffer_kernel_2<sliding_buffer>&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ lz77_buffer_kernel_2<sliding_buffer>::
+ lz77_buffer_kernel_2 (
+ unsigned long total_limit_,
+ unsigned long lookahead_limit_
+ ) :
+ lookahead_size(0),
+ history_size(0)
+ {
+ buffer.set_size(total_limit_);
+ lookahead_limit = lookahead_limit_;
+ history_limit = buffer.size() - lookahead_limit_;
+
+ nodes = new node[history_limit-3];
+
+ try { id_table = new node*[buffer.size()]; }
+ catch (...) { delete [] nodes; throw; }
+
+ try { hash_table = new node*[buffer.size()]; }
+ catch (...) { delete [] id_table; delete [] nodes; throw; }
+
+ mask = buffer.size()-1;
+ next_free_node = 0;
+
+
+ node** start = hash_table;
+ node** end = hash_table + buffer.size();
+ while (start != end)
+ {
+ *start = 0;
+ ++start;
+ }
+
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ lz77_buffer_kernel_2<sliding_buffer>::
+ ~lz77_buffer_kernel_2 (
+ )
+ {
+ delete [] nodes;
+ delete [] hash_table;
+ delete [] id_table;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_2<sliding_buffer>::
+ clear(
+ )
+ {
+ lookahead_size = 0;
+ history_size = 0;
+ next_free_node = 0;
+
+ node** start = hash_table;
+ node** end = hash_table + buffer.size();
+ while (start != end)
+ {
+ *start = 0;
+ ++start;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_2<sliding_buffer>::
+ shift_buffer (
+ unsigned long N
+ )
+ {
+ unsigned long old_history_size = history_size;
+ unsigned long temp = history_size+N;
+ unsigned long new_nodes; // the number of nodes to pull from the nodes array
+ unsigned long recycled_nodes; // the number of nodes to pull from hash_table
+ lookahead_size -= N;
+ if (temp <= history_limit)
+ {
+ if (history_size <= 3)
+ {
+ if ((3-history_size) >= N)
+ new_nodes = 0;
+ else
+ new_nodes = N - (3-history_size);
+ }
+ else
+ {
+ new_nodes = N;
+ }
+
+ recycled_nodes = 0;
+ history_size = temp;
+ }
+ else
+ {
+ if (history_size != history_limit)
+ {
+ new_nodes = history_limit - history_size;
+ recycled_nodes = temp - history_limit;
+ history_size = history_limit;
+ }
+ else
+ {
+ new_nodes = 0;
+ recycled_nodes = N;
+ }
+ }
+
+ unsigned long i = lookahead_limit + 2;
+
+ // if there are any "new" nodes to add to the hash table
+ if (new_nodes != 0)
+ {
+ unsigned long stop = i - new_nodes;
+ for (; i > stop; --i)
+ {
+ nodes[next_free_node].next = 0;
+ nodes[next_free_node].id = buffer.get_element_id(i);
+ id_table[nodes[next_free_node].id] = 0;
+
+ unsigned long new_hash = hash(buffer[i],buffer[i-1],buffer[i-2],buffer[i-3]);
+
+ if (hash_table[new_hash] != 0)
+ id_table[hash_table[new_hash]->id] = &nodes[next_free_node];
+ nodes[next_free_node].next = hash_table[new_hash];
+ hash_table[new_hash] = &nodes[next_free_node];
+
+ ++next_free_node;
+ }
+ } // if (new_nodes != 0)
+
+
+
+ unsigned long stop = i - recycled_nodes;
+ unsigned long old = old_history_size-1+lookahead_limit;
+ for (; i > stop; --i)
+ {
+ // find the next node to recycle in hash_table
+ node* recycled_node;
+
+
+ unsigned long old_id = buffer.get_element_id(old);
+
+ // find the node with id old_id
+ if (id_table[old_id] == 0)
+ {
+ unsigned long old_hash = hash(buffer[old],buffer[old-1],buffer[old-2],buffer[old-3]);
+ recycled_node = hash_table[old_hash];
+
+ // fill the gap left by removing this node
+ hash_table[old_hash] = recycled_node->next;
+ }
+ else
+ {
+ recycled_node = id_table[old_id]->next;
+
+ // fill the gap left by removing this node
+ id_table[old_id]->next = recycled_node->next;
+ }
+
+ --old;
+
+
+
+
+
+
+ recycled_node->next = 0;
+ recycled_node->id = buffer.get_element_id(i);
+ id_table[recycled_node->id] = 0;
+
+ unsigned long new_hash = hash(buffer[i],buffer[i-1],buffer[i-2],buffer[i-3]);
+
+ if (hash_table[new_hash] != 0)
+ id_table[hash_table[new_hash]->id] = recycled_node;
+
+ recycled_node->next = hash_table[new_hash];
+ hash_table[new_hash] = recycled_node;
+
+ } // for (; i > stop; --i)
+
+
+
+
+ buffer.rotate_left(N);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_2<sliding_buffer>::
+ add (
+ unsigned char symbol
+ )
+ {
+ if (lookahead_size == lookahead_limit)
+ {
+ shift_buffer(1);
+ }
+ buffer[lookahead_limit-1-lookahead_size] = symbol;
+ ++lookahead_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sliding_buffer
+ >
+ void lz77_buffer_kernel_2<sliding_buffer>::
+ find_match (
+ unsigned long& index,
+ unsigned long& length,
+ unsigned long min_match_length
+ )
+ {
+ unsigned long match_length = 0; // the length of the longest match we find
+ unsigned long match_index = 0; // the index of the longest match we find
+
+
+ const unsigned long hash_value = hash(lookahead_buffer(0),
+ lookahead_buffer(1),
+ lookahead_buffer(2),
+ lookahead_buffer(3)
+ );
+
+
+
+ node* temp = hash_table[hash_value];
+ while (temp != 0)
+ {
+ // current position in the history buffer
+ unsigned long hpos = buffer.get_element_index(temp->id)-lookahead_limit;
+ // current position in the lookahead buffer
+ unsigned long lpos = 0;
+
+ // find length of this match
+ while (history_buffer(hpos) == lookahead_buffer(lpos))
+ {
+ ++lpos;
+ if (hpos == 0)
+ break;
+ --hpos;
+ if (lpos == lookahead_size)
+ break;
+ }
+
+ if (lpos > match_length)
+ {
+ match_length = lpos;
+ match_index = buffer.get_element_index(temp->id)-lookahead_limit;
+ // if this is the longest possible match then stop looking
+ if (lpos == lookahead_limit)
+ break;
+ }
+
+
+ temp = temp->next;
+ } // while (temp != 0)
+
+
+
+
+ // if we found a match that was long enough then report it
+ if (match_length >= min_match_length)
+ {
+ shift_buffer(match_length);
+ index = match_index;
+ length = match_length;
+ }
+ else
+ {
+ length = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZ77_BUFFER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h
new file mode 100644
index 000000000..942b4e3c2
--- /dev/null
+++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h
@@ -0,0 +1,210 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_
+#ifdef DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ class lz77_buffer
+ {
+ /*!
+ INITIAL VALUE
+ get_history_buffer_limit() == defined by constructor arguments
+ get_lookahead_buffer_limit() == defined by constructor arguments
+ get_history_buffer_size() == 0
+ get_lookahead_buffer_size() == 0
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a pair of buffers (history and lookahead buffers)
+ used during lz77 style compression.
+
+ It's main function is to search the history buffer for long strings which
+ match the contents (or a part of the contents) of the lookahead buffer.
+
+
+ HISTORY AND LOOKAHEAD BUFFERS
+ The buffers have the following structure:
+ | history buffer | lookahead buffer | <-- contents of buffers
+ | ...9876543210 | 0123456789... | <-- index numbers
+
+ So this means that history_buffer(0) == 'r', history_buffer(1) == 'e'
+ and so on. And lookahead_buffer(0) == 'l', lookahead_buffer(1) == 'o'
+ and so on.
+
+
+ What shift_buffers() does in english:
+ This function just means that the buffers have their contents shifted
+ left by N elements and that elements shifted out of the lookahead buffer
+ go into the history buffer. An example will make it clearer.
+
+ Suppose that we have the following buffers before we apply shift_buffers()
+ history_buffer() == "hey" and
+ lookahead_buffer() == "lookahead buffer"
+ And in the same format as the above diagram it would be
+ | hey | lookahead buffer | <-- contents of buffers
+ | 210 | 0123456789... | <-- index numbers
+
+ Applying shift_buffers(4) will give
+ lookahead_buffer() == "ahead buffer"
+ history_buffer() == "heylook" or "eylook" or "ylook" or "look"
+
+ You might be wondering why the history_buffer can resize itself in
+ such a nondeterministic way. It is just to allow a lot of freedom in the
+ implementations of this object.
+ !*/
+
+ public:
+
+ lz77_buffer (
+ unsigned long total_limit,
+ unsigned long lookahead_limit
+ );
+ /*!
+ requires
+ - 6 < total_limit < 32
+ - 15 < lookahead_limit <= 2^(total_limit-2)
+ ensures
+ - #*this is properly initialized
+ - #get_history_buffer_limit() == 2^total_limit - lookahead_limit
+ - #get_lookahead_buffer_limit() == lookahead_limit
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~lz77_buffer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void shift_buffers (
+ unsigned long N
+ );
+ /*!
+ requires
+ - N <= get_lookahead_buffer_size()
+ ensures
+ - #get_lookahead_buffer_size() == get_lookahead_buffer_size() - N
+ - #get_history_buffer_size() >= N
+ - #get_history_buffer_size() <= get_history_buffer_size()+N
+ - #get_history_buffer_size() <= get_history_buffer_limit()
+ - for all i where 0 <= i < N:
+ #history_buffer(N-1-i) == lookahead_buffer(i)
+ - for all i where 0 <= i < #get_history_buffer_size()-N:
+ #history_buffer(N+i) == history_buffer(i)
+ - for all i where 0 <= i < #get_lookahead_buffer_size()
+ #lookahead_buffer(i) == lookahead_buffer(N+i)
+ !*/
+
+ void add (
+ unsigned char symbol
+ );
+ /*!
+ ensures
+ - if (get_lookahead_buffer_size() == get_lookahead_buffer_limit()) then
+ - performs shift_buffers(1)
+ - #lookahead_buffer(get_lookahead_buffer_limit()-1) == symbol
+ - #get_lookahead_buffer_size() == get_lookahead_buffer_size()
+ - else
+ - #lookahead_buffer(get_lookahead_buffer_size()) == symbol
+ - #get_lookahead_buffer_size() == get_lookahead_buffer_size() + 1
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void find_match (
+ unsigned long& index,
+ unsigned long& length,
+ unsigned long min_match_length
+ );
+ /*!
+ ensures
+ - if (#length != 0) then
+ - #length >= min_match_length
+ - for all i where 0 <= i < #length:
+ history_buffer(#index-i) == lookahead_buffer(i)
+ - performs shift_buffers(#length)
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ unsigned long get_history_buffer_limit (
+ ) const;
+ /*!
+ ensures
+ - returns the max number of symbols that can fit in the history buffer
+ !*/
+
+ unsigned long get_lookahead_buffer_limit (
+ ) const;
+ /*!
+ ensures
+ - returns the max number of symbols that can fit in the lookahead buffer
+ !*/
+
+ unsigned long get_history_buffer_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of symbols currently in the history buffer
+ !*/
+
+ unsigned long get_lookahead_buffer_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of symbols currently in the lookahead buffer
+ !*/
+
+ unsigned char lookahead_buffer (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < get_lookahead_buffer_size()
+ ensures
+ - returns the symbol in the lookahead buffer at location index
+ !*/
+
+ unsigned char history_buffer (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < get_history_buffer_size()
+ ensures
+ - returns the symbol in the history buffer at location index
+ !*/
+
+
+ private:
+
+ // restricted functions
+ lz77_buffer(lz77_buffer&); // copy constructor
+ lz77_buffer& operator=(lz77_buffer&); // assignment operator
+
+ };
+}
+
+#endif // DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h
new file mode 100644
index 000000000..704763ad1
--- /dev/null
+++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h
@@ -0,0 +1,169 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZ77_BUFFER_KERNEl_C_
+#define DLIB_LZ77_BUFFER_KERNEl_C_
+
+#include "lz77_buffer_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename lz77_base
+ >
+ class lz77_buffer_kernel_c : public lz77_base
+ {
+
+ public:
+ lz77_buffer_kernel_c (
+ unsigned long total_limit,
+ unsigned long lookahead_limit
+ );
+
+ unsigned char lookahead_buffer (
+ unsigned long index
+ ) const;
+
+ unsigned char history_buffer (
+ unsigned long index
+ ) const;
+
+ void shift_buffers (
+ unsigned long N
+ );
+
+
+
+ unsigned long make_safe (
+ unsigned long total_limit,
+ unsigned long lookahead_limit
+ )
+ /*!
+ ensures
+ - if ( 6 < total_limit < 32 &&
+ 15 < lookahead_limit <= 2^(total_limit-2)
+ ) then
+ - returns total_limit
+ - else
+ - throws due to failed CASSERT
+ !*/
+ {
+ unsigned long exp_size = (total_limit!=0)?total_limit-2:0;
+ unsigned long two_pow_total_limit_minus_2 = 1;
+ while (exp_size != 0)
+ {
+ --exp_size;
+ two_pow_total_limit_minus_2 <<= 1;
+ }
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( 6 < total_limit && total_limit < 32 &&
+ 15 < lookahead_limit && lookahead_limit <= two_pow_total_limit_minus_2,
+ "\tlz77_buffer::lz77_buffer(unsigned long,unsigned long)"
+ << "\n\ttotal_limit must be in the range 7 to 31 and \n\tlookahead_limit in the range 15 to 2^(total_limit-2)"
+ << "\n\tthis: " << this
+ << "\n\ttotal_limit: " << total_limit
+ << "\n\tlookahead_limit: " << lookahead_limit
+ );
+
+ return total_limit;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lz77_base
+ >
+ void lz77_buffer_kernel_c<lz77_base>::
+ shift_buffers (
+ unsigned long N
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( N <= this->get_lookahead_buffer_size(),
+ "\tvoid lz77_buffer::shift_buffers(unsigned long)"
+ << "\n\tN must be <= the number of chars in the lookahead buffer"
+ << "\n\tthis: " << this
+ << "\n\tget_lookahead_buffer_size(): " << this->get_lookahead_buffer_size()
+ << "\n\tN: " << N
+ );
+
+ // call the real function
+ lz77_base::shift_buffers(N);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lz77_base
+ >
+ unsigned char lz77_buffer_kernel_c<lz77_base>::
+ history_buffer (
+ unsigned long index
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->get_history_buffer_size(),
+ "\tunsigned char lz77_buffer::history_buffer(unsigned long) const"
+ << "\n\tindex must be in the range 0 to get_history_buffer_size()-1"
+ << "\n\tthis: " << this
+ << "\n\tget_history_buffer_size(): " << this->get_history_buffer_size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return lz77_base::history_buffer(index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lz77_base
+ >
+ unsigned char lz77_buffer_kernel_c<lz77_base>::
+ lookahead_buffer (
+ unsigned long index
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->get_lookahead_buffer_size(),
+ "\tunsigned char lz77_buffer::lookahead_buffer(unsigned long) const"
+ << "\n\tindex must be in the range 0 to get_lookahead_buffer_size()-1"
+ << "\n\tthis: " << this
+ << "\n\tget_lookahead_buffer_size(): " << this->get_lookahead_buffer_size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return lz77_base::lookahead_buffer(index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lz77_base
+ >
+ lz77_buffer_kernel_c<lz77_base>::
+ lz77_buffer_kernel_c (
+ unsigned long total_limit,
+ unsigned long lookahead_limit
+ ) :
+ lz77_base(make_safe(total_limit,lookahead_limit),lookahead_limit)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZ77_BUFFER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/lzp_buffer.h b/ml/dlib/dlib/lzp_buffer.h
new file mode 100644
index 000000000..090a53de3
--- /dev/null
+++ b/ml/dlib/dlib/lzp_buffer.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZP_BUFFEr_
+#define DLIB_LZP_BUFFEr_
+
+
+#include "lzp_buffer/lzp_buffer_kernel_1.h"
+#include "lzp_buffer/lzp_buffer_kernel_2.h"
+#include "lzp_buffer/lzp_buffer_kernel_c.h"
+
+#include "sliding_buffer.h"
+
+
+namespace dlib
+{
+
+
+ class lzp_buffer
+ {
+
+ lzp_buffer() {}
+
+ typedef sliding_buffer<unsigned char>::kernel_1a sb1;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef lzp_buffer_kernel_1<sb1>
+ kernel_1a;
+ typedef lzp_buffer_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_2a
+ typedef lzp_buffer_kernel_2<sb1>
+ kernel_2a;
+ typedef lzp_buffer_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ };
+}
+
+#endif // DLIB_LZP_BUFFEr_
+
diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h
new file mode 100644
index 000000000..f24d74eee
--- /dev/null
+++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h
@@ -0,0 +1,236 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZP_BUFFER_KERNEl_1_
+#define DLIB_LZP_BUFFER_KERNEl_1_
+
+#include "../algs.h"
+#include "lzp_buffer_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename sbuf
+ >
+ class lzp_buffer_kernel_1
+ {
+ /*!
+ REQUIREMENTS ON sbuf
+ sbuf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+ T == unsigned char
+
+ INITIAL VALUE
+ - buffer.size() == the size as defined by the constructor
+ - table_size == the number of elements in the table array
+ - for all i: buffer[i] == 0
+ - for all i: table[i] == buffer.size()
+
+ CONVENTION
+ - table_size == the number of elements in the table array
+ - size() == buffer.size()
+ - operator[](i) == buffer[i]
+
+ - if (table[hash()] != buffer.size()) then
+ - buffer.get_element_index(table[hash()]) == the index we will
+ predict for the current context
+ - else
+ - there is no prediction for the current context
+
+ - last_element == buffer.size()-1
+
+
+ This is LZP with just an order-3 model without context confirmation.
+
+ !*/
+
+ public:
+
+ explicit lzp_buffer_kernel_1 (
+ unsigned long buffer_size
+ );
+
+ virtual ~lzp_buffer_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ inline void add (
+ unsigned char symbol
+ );
+
+ inline unsigned long predict_match (
+ unsigned long& index
+ );
+
+ inline size_t size (
+ ) const;
+
+ inline unsigned char operator[] (
+ unsigned long index
+ ) const;
+
+ private:
+
+ inline unsigned long hash (
+ ) const
+ /*!
+ ensures
+ - returns a hash computed from the current context. This hash
+ is always in the range for table.
+ !*/
+ {
+ unsigned long temp = buffer[0];
+ temp <<= 16;
+ unsigned long temp2 = buffer[1];
+ temp2 <<= 8;
+ unsigned long temp3 = buffer[2];
+ temp = temp|temp2|temp3;
+
+ temp = ((temp>>11)^temp)&0xFFFF;
+
+ return temp;
+ }
+
+ sbuf buffer;
+ const unsigned long table_size;
+ unsigned long* const table;
+ unsigned long last_element;
+
+ // restricted functions
+ lzp_buffer_kernel_1(const lzp_buffer_kernel_1<sbuf>&); // copy constructor
+ lzp_buffer_kernel_1<sbuf>& operator=(const lzp_buffer_kernel_1<sbuf>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ lzp_buffer_kernel_1<sbuf>::
+ lzp_buffer_kernel_1 (
+ unsigned long buffer_size
+ ) :
+ table_size(65536),
+ table(new unsigned long[table_size])
+ {
+ buffer.set_size(buffer_size);
+
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ for (unsigned long i = 0; i < table_size; ++i)
+ table[i] = buffer.size();
+
+ last_element = buffer.size()-1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ lzp_buffer_kernel_1<sbuf>::
+ ~lzp_buffer_kernel_1 (
+ )
+ {
+ delete [] table;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ void lzp_buffer_kernel_1<sbuf>::
+ clear(
+ )
+ {
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ for (unsigned long i = 0; i < table_size; ++i)
+ table[i] = buffer.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ void lzp_buffer_kernel_1<sbuf>::
+ add (
+ unsigned char symbol
+ )
+ {
+ buffer.rotate_left(1);
+ buffer[0] = symbol;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ unsigned long lzp_buffer_kernel_1<sbuf>::
+ predict_match (
+ unsigned long& index
+ )
+ {
+ const unsigned long i = hash();
+
+ if (table[i] != buffer.size())
+ {
+ index = buffer.get_element_index(table[i]);
+
+ if (index > 20)
+ {
+ // update the prediction for this context
+ table[i] = buffer.get_element_id(last_element);
+ }
+ return 3;
+ }
+ else
+ {
+ // update the prediction for this context
+ table[i] = buffer.get_element_id(last_element);
+ return 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ size_t lzp_buffer_kernel_1<sbuf>::
+ size (
+ ) const
+ {
+ return buffer.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ unsigned char lzp_buffer_kernel_1<sbuf>::
+ operator[] (
+ unsigned long index
+ ) const
+ {
+ return buffer[index];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZP_BUFFER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h
new file mode 100644
index 000000000..47c0443f1
--- /dev/null
+++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h
@@ -0,0 +1,319 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZP_BUFFER_KERNEl_2_
+#define DLIB_LZP_BUFFER_KERNEl_2_
+
+#include "../algs.h"
+#include "lzp_buffer_kernel_abstract.h"
+#include <new>
+
+namespace dlib
+{
+
+ template <
+ typename sbuf
+ >
+ class lzp_buffer_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON sbuf
+ sbuf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+ T == unsigned char
+
+ INITIAL VALUE
+ - buffer.size() == the size as defined by the constructor
+ - table_size == the number of elements in the table3 and table4 arrays
+ - for all i: buffer[i] == 0
+ - for all i: table3[i] == buffer.size()
+ - for all i: table4[i] == buffer.size()
+
+ CONVENTION
+ - table_size == the number of elements in the table3 and table4 arrays
+ - size() == buffer.size()
+ - operator[](i) == buffer[i]
+
+
+
+ - last_element == buffer.size()-1
+
+
+ This is LZP with an order-5-4-3 model with context confirmation.
+ To save memory the order5 and order3 predictions exist in the same
+ table, that is, table3.
+
+ !*/
+
+ public:
+
+ explicit lzp_buffer_kernel_2 (
+ unsigned long buffer_size
+ );
+
+ virtual ~lzp_buffer_kernel_2 (
+ );
+
+ void clear(
+ );
+
+ inline void add (
+ unsigned char symbol
+ );
+
+ inline unsigned long predict_match (
+ unsigned long& index
+ );
+
+ inline size_t size (
+ ) const;
+
+ inline unsigned char operator[] (
+ unsigned long index
+ ) const;
+
+ private:
+
+ inline bool verify (
+ unsigned long index
+ ) const
+ /*!
+ ensures
+ - returns true if buffer[index]'s context matches the current context
+ !*/
+ {
+ if (index+3 < buffer.size())
+ {
+ if (buffer[0] != buffer[index+1])
+ return false;
+ if (buffer[1] != buffer[index+2])
+ return false;
+ if (buffer[2] != buffer[index+3])
+ return false;
+ return true;
+ }
+ else
+ {
+ // just call this a match
+ return true;
+ }
+ }
+
+
+ sbuf buffer;
+ unsigned long* table3;
+ unsigned long* table4;
+ unsigned long last_element;
+ const unsigned long table_size;
+
+ // restricted functions
+ lzp_buffer_kernel_2(const lzp_buffer_kernel_2<sbuf>&); // copy constructor
+ lzp_buffer_kernel_2<sbuf>& operator=(const lzp_buffer_kernel_2<sbuf>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ lzp_buffer_kernel_2<sbuf>::
+ lzp_buffer_kernel_2 (
+ unsigned long buffer_size
+ ) :
+ table3(0),
+ table4(0),
+ table_size(65536)
+ {
+ buffer.set_size(buffer_size);
+
+ table3 = new (std::nothrow) unsigned long[table_size];
+ table4 = new (std::nothrow) unsigned long[table_size];
+
+ if (!table3 || !table4)
+ {
+ if (!table3)
+ delete [] table3;
+ if (!table4)
+ delete [] table4;
+
+ throw std::bad_alloc();
+ }
+
+
+
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ for (unsigned long i = 0; i < table_size; ++i)
+ {
+ table3[i] = buffer.size();
+ table4[i] = buffer.size();
+ }
+
+ last_element = buffer.size()-1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ lzp_buffer_kernel_2<sbuf>::
+ ~lzp_buffer_kernel_2 (
+ )
+ {
+ delete [] table3;
+ delete [] table4;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ void lzp_buffer_kernel_2<sbuf>::
+ clear(
+ )
+ {
+ for (unsigned long i = 0; i < buffer.size(); ++i)
+ buffer[i] = 0;
+
+ for (unsigned long i = 0; i < table_size; ++i)
+ {
+ table3[i] = buffer.size();
+ table4[i] = buffer.size();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ void lzp_buffer_kernel_2<sbuf>::
+ add (
+ unsigned char symbol
+ )
+ {
+ buffer.rotate_left(1);
+ buffer[0] = symbol;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ unsigned long lzp_buffer_kernel_2<sbuf>::
+ predict_match (
+ unsigned long& index
+ )
+ {
+ unsigned long temp1 = buffer[0];
+ unsigned long temp2 = buffer[1];
+ temp2 <<= 8;
+ unsigned long temp3 = buffer[2];
+ temp3 <<= 16;
+ unsigned long temp4 = buffer[3];
+ temp4 <<= 24;
+ unsigned long temp5 = buffer[4];
+ temp5 <<= 12;
+
+ unsigned long context1 = temp1|temp2|temp3;
+ unsigned long context2 = context1|temp4;
+
+
+ const unsigned long i5 = ((temp5|(context2>>20))^context2)&0xFFFF;
+ const unsigned long i4 = ((context2>>15)^context2)&0xFFFF;
+ const unsigned long i3 = ((context1>>11)^context1)&0xFFFF;
+
+
+
+ // check the 5-order context's prediction
+ if (table3[i5] != buffer.size() &&
+ verify(buffer.get_element_index(table3[i5])) )
+ {
+ index = buffer.get_element_index(table3[i5]);
+ if (index > 20)
+ {
+ // update the prediction for this context
+ table3[i3] = buffer.get_element_id(last_element);
+ table4[i4] = table3[i3];
+ table3[i5] = table3[i3];
+ }
+ return 5;
+ }
+ // check the 4-order context's prediction
+ else if (table4[i4] != buffer.size() &&
+ verify(buffer.get_element_index(table4[i4])) )
+ {
+ index = buffer.get_element_index(table4[i4]);
+ if (index > 20)
+ {
+ // update the prediction for this context
+ table3[i3] = buffer.get_element_id(last_element);
+ table4[i4] = table3[i3];
+ table3[i5] = table3[i3];
+ }
+ return 4;
+ }
+ // check the 3-order context's prediction
+ else if (table3[i3] != buffer.size() &&
+ verify(buffer.get_element_index(table3[i3])))
+ {
+ index = buffer.get_element_index(table3[i3]);
+
+ if (index > 20)
+ {
+ // update the prediction for this context
+ table3[i3] = buffer.get_element_id(last_element);
+ table4[i4] = table3[i3];
+ table3[i5] = table3[i3];
+ }
+ return 3;
+ }
+ else
+ {
+ // update the prediction for this context
+ table3[i3] = buffer.get_element_id(last_element);
+ table4[i4] = table3[i3];
+ table3[i5] = table3[i3];
+
+ return 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ size_t lzp_buffer_kernel_2<sbuf>::
+ size (
+ ) const
+ {
+ return buffer.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sbuf
+ >
+ unsigned char lzp_buffer_kernel_2<sbuf>::
+ operator[] (
+ unsigned long index
+ ) const
+ {
+ return buffer[index];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZP_BUFFER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h
new file mode 100644
index 000000000..df8b8c80f
--- /dev/null
+++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h
@@ -0,0 +1,130 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LZP_BUFFER_KERNEl_ABSTRACT_
+#ifdef DLIB_LZP_BUFFER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ class lzp_buffer
+ {
+ /*!
+ INITIAL VALUE
+ size() == some value defined by the constructor argument
+ Initially this object is at some predefined empty or ground state.
+ for all i: (*this)[i] == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some varation on the LZP algorithm
+ described by Charles Bloom in his paper "LZP: a new data
+ compression algorithm"
+
+ The LZP algorithm is a lot like lz77 except there is no need to pass
+ the location of matches in the history buffer to the decoder because
+ LZP uses the data it has already seen to predict the location of the
+ next match.
+
+ NOTE
+ The add() and predict_match() functions must be called in the same
+ order by the coder and decoder. If they aren't the state of the
+ lzp_buffer objects in the coder and decoder may differ and the decoder
+ won't be able to correctly decode the data stream.
+ !*/
+
+ public:
+
+ explicit lzp_buffer (
+ unsigned long buffer_size
+ );
+ /*!
+ requires
+ - 10 < buffer_size < 32
+ ensures
+ - #*this is properly initialized
+ - #size() == 2^buffer_size
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~lzp_buffer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ unsigned char symbol
+ );
+ /*!
+ ensures
+ - shifts everything in the history buffer left 1.
+ (i.e. #(*this)[i+1] == (*this)[i])
+ - #(*this)[0] == symbol
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ unsigned long predict_match (
+ unsigned long& index
+ );
+ /*!
+ ensures
+ - updates the prediction for the current context.
+ (the current context is the last few symbols seen. i.e. (*this)[0],
+ (*this)[1], etc.)
+ - if (*this can generate a prediction) then
+ - #index == the predicted location of a match in the history buffer.
+ (i.e. (*this)[#index] is the first symbol of the predicted match)
+ - returns the order this prediction came from
+ - else
+ - returns 0
+ throws
+ - std::bad_alloc
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the size of the history buffer
+ !*/
+
+ unsigned char operator[] (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns the symbol at the given index in the history buffer
+ !*/
+
+ private:
+
+ // restricted functions
+ lzp_buffer(const lzp_buffer&); // copy constructor
+ lzp_buffer& operator=(const lzp_buffer&); // assignment operator
+
+ };
+}
+
+#endif // DLIB_LZP_BUFFER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h
new file mode 100644
index 000000000..2b2de2f1d
--- /dev/null
+++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h
@@ -0,0 +1,101 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LZP_BUFFER_KERNEl_C_
+#define DLIB_LZP_BUFFER_KERNEl_C_
+
+#include "lzp_buffer_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename lzp_base
+ >
+ class lzp_buffer_kernel_c : public lzp_base
+ {
+
+ public:
+ lzp_buffer_kernel_c (
+ unsigned long buffer_size
+ );
+
+
+ unsigned char operator[] (
+ unsigned long index
+ ) const;
+
+
+ unsigned long make_safe (
+ unsigned long buffer_size
+ )
+ /*!
+ ensures
+ - if ( 10 < buffer_size < 32) then
+ - returns buffer_size
+ - else
+ - throws due to failed CASSERT
+ !*/
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT( 10 < buffer_size && buffer_size < 32,
+ "\tlzp_buffer::lzp_buffer(unsigned long)"
+ << "\n\tbuffer_size must be in the range 11 to 31."
+ << "\n\tthis: " << this
+ << "\n\tbuffer_size: " << buffer_size
+ );
+
+ return buffer_size;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lzp_base
+ >
+ unsigned char lzp_buffer_kernel_c<lzp_base>::
+ operator[] (
+ unsigned long index
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->size(),
+ "\tunsigned char lzp_buffer::operator[](unsigned long) const"
+ << "\n\tindex must be in the range 0 to size()()-1"
+ << "\n\tthis: " << this
+ << "\n\tsize(): " << this->size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return lzp_base::operator[](index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lzp_base
+ >
+ lzp_buffer_kernel_c<lzp_base>::
+ lzp_buffer_kernel_c (
+ unsigned long buffer_size
+ ) :
+ lzp_base(make_safe(buffer_size))
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LZP_BUFFER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/manifold_regularization.h b/ml/dlib/dlib/manifold_regularization.h
new file mode 100644
index 000000000..a7222fd4a
--- /dev/null
+++ b/ml/dlib/dlib/manifold_regularization.h
@@ -0,0 +1,13 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MANIFOLD_REGULARIzATION_HEADER
+#define DLIB_MANIFOLD_REGULARIzATION_HEADER
+
+#include "graph_utils/edge_list_graphs.h"
+#include "manifold_regularization/linear_manifold_regularizer.h"
+#include "graph_utils/function_objects.h"
+
+#endif // DLIB_MANIFOLD_REGULARIzATION_HEADER
+
+
+
diff --git a/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h
new file mode 100644
index 000000000..95b8b1128
--- /dev/null
+++ b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h
@@ -0,0 +1,328 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_
+#define DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_
+
+#include "linear_manifold_regularizer_abstract.h"
+#include <limits>
+#include <vector>
+#include "../serialize.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace impl
+ {
+ class undirected_adjacency_list
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is simply a tool for turning a vector of sample_pair objects
+ into an adjacency list with floating point weights on each edge.
+ !*/
+ public:
+
+ undirected_adjacency_list (
+ )
+ {
+ _size = 0;
+ sum_edge_weights = 0;
+ }
+
+ struct neighbor
+ {
+ neighbor(unsigned long idx, double w):index(idx), weight(w) {}
+ neighbor():index(0), weight(0) {}
+
+ unsigned long index;
+ double weight;
+ };
+
+ typedef std::vector<neighbor>::const_iterator const_iterator;
+
+ size_t size (
+ ) const
+ /*!
+ ensures
+ - returns the number of vertices in this graph
+ !*/
+ {
+ return _size;
+ }
+
+ const_iterator begin(
+ unsigned long idx
+ ) const
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - returns an iterator that points to the first neighbor of
+ the idx'th vertex.
+ !*/
+ {
+ return blocks[idx];
+ }
+
+ const_iterator end(
+ unsigned long idx
+ ) const
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - returns an iterator that points one past the last neighbor
+ of the idx'th vertex.
+ !*/
+ {
+ return blocks[idx+1];
+ }
+
+
+ template <typename vector_type, typename weight_function_type>
+ void build (
+ const vector_type& edges,
+ const weight_function_type& weight_funct
+ )
+ /*!
+ requires
+ - vector_type == a type with an interface compatible with std::vector and
+ it must in turn contain objects with an interface compatible with dlib::sample_pair
+ - edges.size() > 0
+ - contains_duplicate_pairs(edges) == false
+ - weight_funct(edges[i]) must be a valid expression that evaluates to a
+ floating point number >= 0
+ ensures
+ - #size() == one greater than the max index in edges.
+ - builds the adjacency list so that it contains all the given edges.
+ - The weight in each neighbor is set to the output of the weight_funct()
+ for the associated edge.
+ !*/
+ {
+
+
+ // Figure out how many neighbors each sample ultimately has. We do this so
+ // we will know how much space to allocate in the data vector.
+ std::vector<unsigned long> num_neighbors;
+ num_neighbors.reserve(edges.size());
+
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ // make sure num_neighbors is always big enough
+ const unsigned long min_size = std::max(edges[i].index1(), edges[i].index2())+1;
+ if (num_neighbors.size() < min_size)
+ num_neighbors.resize(min_size, 0);
+
+ num_neighbors[edges[i].index1()] += 1;
+ num_neighbors[edges[i].index2()] += 1;
+ }
+
+ _size = num_neighbors.size();
+
+ // Now setup the iterators in blocks. Also setup a version of blocks that holds
+ // non-const iterators so we can use it below when we populate data.
+ std::vector<std::vector<neighbor>::iterator> mutable_blocks;
+ data.resize(edges.size()*2); // each edge will show up twice
+ blocks.resize(_size + 1);
+ blocks[0] = data.begin();
+ mutable_blocks.resize(_size + 1);
+ mutable_blocks[0] = data.begin();
+ for (unsigned long i = 0; i < num_neighbors.size(); ++i)
+ {
+ blocks[i+1] = blocks[i] + num_neighbors[i];
+ mutable_blocks[i+1] = mutable_blocks[i] + num_neighbors[i];
+ }
+
+ sum_edge_weights = 0;
+ // finally, put the edges into data
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ const double weight = weight_funct(edges[i]);
+ sum_edge_weights += weight;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weight >= 0,
+ "\t void linear_manifold_regularizer::build()"
+ << "\n\t You supplied a weight_funct() that generated a negative weight."
+ << "\n\t weight: " << weight
+ );
+
+ *mutable_blocks[edges[i].index1()]++ = neighbor(edges[i].index2(), weight);
+ *mutable_blocks[edges[i].index2()]++ = neighbor(edges[i].index1(), weight);
+ }
+
+ }
+
+ double sum_of_edge_weights (
+ ) const
+ {
+ return sum_edge_weights;
+ }
+
+ private:
+
+ /*!
+ INITIAL VALUE
+ - _size == 0
+ - data.size() == 0
+ - blocks.size() == 0
+ - sum_edge_weights == 0
+
+ CONVENTION
+ - size() == _size
+ - blocks.size() == _size + 1
+ - sum_of_edge_weights() == sum_edge_weights
+ - blocks == a vector of iterators that point into data.
+ For all valid i:
+ - The iterator range [blocks[i], blocks[i+1]) contains all the edges
+ for the i'th node in the graph
+ !*/
+
+ std::vector<neighbor> data;
+ std::vector<const_iterator> blocks;
+ unsigned long _size;
+
+ double sum_edge_weights;
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class linear_manifold_regularizer
+ {
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+
+
+ template <
+ typename vector_type1,
+ typename vector_type2,
+ typename weight_function_type
+ >
+ void build (
+ const vector_type1& samples,
+ const vector_type2& edges,
+ const weight_function_type& weight_funct
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(edges.size() > 0 &&
+ contains_duplicate_pairs(edges) == false &&
+ max_index_plus_one(edges) <= samples.size(),
+ "\t void linear_manifold_regularizer::build()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t edges.size(): " << edges.size()
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t contains_duplicate_pairs(edges): " << contains_duplicate_pairs(edges)
+ << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges)
+ );
+
+
+ impl::undirected_adjacency_list graph;
+ graph.build(edges, weight_funct);
+
+ sum_edge_weights = graph.sum_of_edge_weights();
+
+ make_mr_matrix(samples, graph);
+ }
+
+ long dimensionality (
+ ) const { return reg_mat.nr(); }
+
+ general_matrix get_transformation_matrix (
+ scalar_type intrinsic_regularization_strength
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(intrinsic_regularization_strength >= 0,
+ "\t matrix linear_manifold_regularizer::get_transformation_matrix()"
+ << "\n\t This value must not be negative"
+ << "\n\t intrinsic_regularization_strength: " << intrinsic_regularization_strength
+ );
+
+ if (dimensionality() == 0)
+ return general_matrix();
+
+
+ // This isn't how it's defined in the referenced paper but normalizing these kinds of
+ // sums is typical of most machine learning algorithms. Moreover, doing this makes
+ // the argument to this function more invariant to the size of the edge set. So it
+ // should make it easier for the user.
+ intrinsic_regularization_strength /= sum_edge_weights;
+
+ return inv_lower_triangular(chol(identity_matrix<scalar_type>(reg_mat.nr()) + intrinsic_regularization_strength*reg_mat));
+ }
+
+ private:
+
+ template <typename vector_type>
+ void make_mr_matrix (
+ const vector_type& samples,
+ const impl::undirected_adjacency_list& graph
+ )
+ /*!
+ requires
+ - samples.size() == graph.size()
+ ensures
+ - computes trans(X)*lap(graph)*X where X is the data matrix
+ (i.e. the matrix that contains all the samples in its rows)
+ and lap(graph) is the laplacian matrix of the graph. The
+ resulting matrix is stored in reg_mat.
+ !*/
+ {
+ const unsigned long dims = samples[0].size();
+ reg_mat.set_size(dims,dims);
+ reg_mat = 0;
+
+
+ typename impl::undirected_adjacency_list::const_iterator beg, end;
+
+ // loop over the columns of the X matrix
+ for (unsigned long d = 0; d < dims; ++d)
+ {
+ // loop down the row of X
+ for (unsigned long i = 0; i < graph.size(); ++i)
+ {
+ beg = graph.begin(i);
+ end = graph.end(i);
+
+ // if this node in the graph has any neighbors
+ if (beg != end)
+ {
+ double weight_sum = 0;
+ double val = 0;
+ for (; beg != end; ++beg)
+ {
+ val -= beg->weight * samples[beg->index](d);
+ weight_sum += beg->weight;
+ }
+ val += weight_sum * samples[i](d);
+
+ for (unsigned long j = 0; j < dims; ++j)
+ {
+ reg_mat(d,j) += val*samples[i](j);
+ }
+ }
+ }
+ }
+
+ }
+
+ general_matrix reg_mat;
+ double sum_edge_weights;
+ };
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_
+
diff --git a/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h
new file mode 100644
index 000000000..9a9b579c9
--- /dev/null
+++ b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h
@@ -0,0 +1,137 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_
+#ifdef DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_
+
+#include <limits>
+#include <vector>
+#include "../serialize.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class linear_manifold_regularizer
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ INITIAL VALUE
+ - dimensionality() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ Many learning algorithms attempt to minimize a function that, at a high
+ level, looks like this:
+ f(w) == complexity + training_set_error
+
+ The idea is to find the set of parameters, w, that gives low error on
+ your training data but also is not "complex" according to some particular
+ measure of complexity. This strategy of penalizing complexity is
+ usually called regularization.
+
+ In the above setting, all the training data consists of labeled samples.
+ However, it would be nice to be able to benefit from unlabeled data.
+ The idea of manifold regularization is to extract useful information from
+ unlabeled data by first defining which data samples are "close" to each other
+ (perhaps by using their 3 nearest neighbors) and then adding a term to
+ the above function that penalizes any decision rule which produces
+ different outputs on data samples which we have designated as being close.
+
+ It turns out that it is possible to transform these manifold regularized
+ learning problems into the normal form shown above by applying a certain kind
+ of preprocessing to all our data samples. Once this is done we can use a
+ normal learning algorithm, such as the svm_c_linear_trainer, on just the
+ labeled data samples and obtain the same output as the manifold regularized
+ learner would have produced.
+
+ The linear_manifold_regularizer is a tool for creating this preprocessing
+ transformation. In particular, the transformation is linear. That is, it
+ is just a matrix you multiply with all your samples. For a more detailed
+ discussion of this topic you should consult the following paper. In
+ particular, see section 4.2. This object computes the inverse T matrix
+ described in that section.
+
+ Linear Manifold Regularization for Large Scale Semi-supervised Learning
+ by Vikas Sindhwani, Partha Niyogi, and Mikhail Belkin
+ !*/
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+
+ template <
+ typename vector_type1,
+ typename vector_type2,
+ typename weight_function_type
+ >
+ void build (
+ const vector_type1& samples,
+ const vector_type2& edges,
+ const weight_function_type& weight_funct
+ );
+ /*!
+ requires
+ - vector_type1 == a type with an interface compatible with std::vector and it must
+ in turn contain dlib::matrix objects.
+ - vector_type2 == a type with an interface compatible with std::vector and
+ it must in turn contain objects with an interface compatible with dlib::sample_pair
+ - edges.size() > 0
+ - contains_duplicate_pairs(edges) == false
+ - max_index_plus_one(edges) <= samples.size()
+ - weight_funct(edges[i]) must be a valid expression that evaluates to a
+ floating point number >= 0
+ ensures
+ - #dimensionality() == samples[0].size()
+ - This function sets up the transformation matrix describe above. The manifold
+ regularization is done assuming that the samples are meant to be "close"
+ according to the graph defined by the given edges. I.e:
+ - for all valid i: samples[edges[i].index1()] is close to samples[edges[i].index2()].
+ How much we care about these two samples having similar outputs according
+ to the learned rule is given by weight_funct(edges[i]). Bigger weights mean
+ we care more.
+ !*/
+
+ long dimensionality (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows and columns in the transformation matrix
+ produced by this object.
+ !*/
+
+ general_matrix get_transformation_matrix (
+ scalar_type intrinsic_regularization_strength
+ ) const;
+ /*!
+ requires
+ - intrinsic_regularization_strength >= 0
+ ensures
+ - returns a matrix that represents the preprocessing transformation described above.
+ - You must choose how important the manifold regularizer is relative to the basic
+ "don't be complex" regularizer described above. The intrinsic_regularization_strength
+ is the parameter that controls this trade-off. A large value of
+ intrinsic_regularization_strength means that more emphasis should be placed on
+ finding decision rules which produce the same output on similar samples. On
+ the other hand, a small value would mean that we don't care much about the
+ manifold regularizer. For example, using 0 will cause this function to return the
+ identity matrix.
+ - The returned matrix will have dimensionality() rows and columns.
+ !*/
+
+ };
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/map.h b/ml/dlib/dlib/map.h
new file mode 100644
index 000000000..12036380e
--- /dev/null
+++ b/ml/dlib/dlib/map.h
@@ -0,0 +1,59 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAp_
+#define DLIB_MAp_
+
+#include "map/map_kernel_1.h"
+#include "map/map_kernel_c.h"
+
+#include "binary_search_tree.h"
+
+
+#include "algs.h"
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class map
+ {
+ map() {}
+
+
+ // a typedef for the binary search tree used by kernel_2
+ typedef typename binary_search_tree<domain,range,mem_manager,compare>::kernel_1a
+ binary_search_tree_1;
+
+ // a typedef for the binary search tree used by kernel_2
+ typedef typename binary_search_tree<domain,range,mem_manager,compare>::kernel_2a
+ binary_search_tree_2;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef map_kernel_1<domain,range,binary_search_tree_1,mem_manager>
+ kernel_1a;
+ typedef map_kernel_c<kernel_1a >
+ kernel_1a_c;
+
+ // kernel_1b
+ typedef map_kernel_1<domain,range,binary_search_tree_2,mem_manager>
+ kernel_1b;
+ typedef map_kernel_c<kernel_1b >
+ kernel_1b_c;
+
+
+ };
+}
+
+#endif // DLIB_MAp_
+
diff --git a/ml/dlib/dlib/map/map_kernel_1.h b/ml/dlib/dlib/map/map_kernel_1.h
new file mode 100644
index 000000000..1c79d179f
--- /dev/null
+++ b/ml/dlib/dlib/map/map_kernel_1.h
@@ -0,0 +1,436 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAP_KERNEl_1_
+#define DLIB_MAP_KERNEl_1_
+
+#include "map_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager = default_memory_manager
+ >
+ class map_kernel_1 : public enumerable<map_pair<domain,range> >,
+ public asc_pair_remover<domain,range,typename bst_base::compare_type>
+ {
+
+ /*!
+ REQUIREMENTS ON BST_BASE
+ bst_base is instantiated with domain and range and
+ implements binary_search_tree/binary_search_tree_kernel_abstract.h
+
+ INITIAL VALUE
+ bst has its initial value
+
+ CONVENTION
+ bst.size() == the number of elements in the map and
+ the elements in map are stored in bst_base
+ !*/
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef typename bst_base::compare_type compare_type;
+ typedef mem_manager mem_manager_type;
+
+ map_kernel_1(
+ )
+ {}
+
+ virtual ~map_kernel_1(
+ )
+ {}
+
+ inline void clear(
+ );
+
+ inline void add (
+ domain& d,
+ range& r
+ );
+
+ inline bool is_in_domain (
+ const domain& d
+ ) const;
+
+ inline void remove_any (
+ domain& d,
+ range& r
+ );
+
+ inline void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ inline void destroy (
+ const domain& d
+ );
+
+ inline range& operator[] (
+ const domain& d
+ );
+
+ inline const range& operator[] (
+ const domain& d
+ ) const;
+
+ inline void swap (
+ map_kernel_1& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const map_pair<domain,range>& element (
+ ) const;
+
+ inline map_pair<domain,range>& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+
+ private:
+
+ bst_base bst;
+
+ // restricted functions
+ map_kernel_1(map_kernel_1&);
+ map_kernel_1& operator= ( map_kernel_1&);
+ };
+
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ inline void swap (
+ map_kernel_1<domain,range,bst_base,mem_manager>& a,
+ map_kernel_1<domain,range,bst_base,mem_manager>& b
+ ) { a.swap(b); }
+
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void deserialize (
+ map_kernel_1<domain,range,bst_base,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item.add(d,r);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type map_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ clear (
+ )
+ {
+ bst.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ add(
+ domain& d,
+ range& r
+ )
+ {
+ // try to add pair to bst_base
+ bst.add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool map_kernel_1<domain,range,bst_base,mem_manager>::
+ is_in_domain(
+ const domain& d
+ ) const
+ {
+ return (bst[d] != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ remove_any(
+ domain& d,
+ range& r
+ )
+ {
+ bst.remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ bst.remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ destroy (
+ const domain& d
+ )
+ {
+ bst.destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ range& map_kernel_1<domain,range,bst_base,mem_manager>::
+ operator[](
+ const domain& d
+ )
+ {
+ return *bst[d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ const range& map_kernel_1<domain,range,bst_base,mem_manager>::
+ operator[](
+ const domain& d
+ ) const
+ {
+ return *bst[d];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ size_t map_kernel_1<domain,range,bst_base,mem_manager>::
+ size (
+ ) const
+ {
+ return bst.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ swap (
+ map_kernel_1<domain,range,bst_base,mem_manager>& item
+ )
+ {
+ bst.swap(item.bst);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool map_kernel_1<domain,range,bst_base,mem_manager>::
+ at_start (
+ ) const
+ {
+ return bst.at_start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ void map_kernel_1<domain,range,bst_base,mem_manager>::
+ reset (
+ ) const
+ {
+ bst.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool map_kernel_1<domain,range,bst_base,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return bst.current_element_valid();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ const map_pair<domain,range>& map_kernel_1<domain,range,bst_base,mem_manager>::
+ element (
+ ) const
+ {
+ return bst.element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ map_pair<domain,range>& map_kernel_1<domain,range,bst_base,mem_manager>::
+ element (
+ )
+ {
+ return bst.element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool map_kernel_1<domain,range,bst_base,mem_manager>::
+ move_next (
+ ) const
+ {
+ return bst.move_next();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAP_KERNEl_1_
+
diff --git a/ml/dlib/dlib/map/map_kernel_abstract.h b/ml/dlib/dlib/map/map_kernel_abstract.h
new file mode 100644
index 000000000..1e07e5e56
--- /dev/null
+++ b/ml/dlib/dlib/map/map_kernel_abstract.h
@@ -0,0 +1,235 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MAP_KERNEl_ABSTRACT_
+#ifdef DLIB_MAP_KERNEl_ABSTRACT_
+
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<domain>
+ >
+ class map : public enumerable<map_pair<domain,range> >,
+ public asc_pair_remover<domain,range,compare>
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ domain is swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range is swappable by a global swap() and
+ range must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap(), is_in_domain(), and operator[] functions do not invalidate
+ pointers or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the domain (and each associated
+ range element) elements in ascending order according to the compare functor.
+ (i.e. the elements are enumerated in sorted order)
+
+ WHAT THIS OBJECT REPRESENTS
+ map contains items of type domain and range
+
+ This object is similar an array. It maps items of type domain on to
+ items of type range.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ map(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ !*/
+
+ virtual ~map(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ domain& d,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - is_in_domain(d) == false
+ ensures
+ - #is_in_domain(d) == true
+ - #operator[](d) == r
+ - #d and #r have initial values for their types
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ if add() throws then it has no effect
+ !*/
+
+ bool is_in_domain (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - returns whether or not an element equivalent to d is in the
+ domain of *this
+ !*/
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+ /*!
+ requires
+ - &d != &r (i.e. d and r cannot be the same variable)
+ - &d != &d_copy (i.e. d and d_copy cannot be the same variable)
+ - &r != &d_copy (i.e. r and d_copy cannot be the same variable)
+ - is_in_domain(d) == true
+ ensures
+ - #is_in_domain(d) == false
+ - #d_copy is equivalent to d
+ - the element in the range of *this associated with #d_copy has been
+ swapped into #r
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const domain& d
+ );
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - #is_in_domain(d) == false
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ range& operator[] (
+ const domain& d
+ );
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - returns a non-const reference to the element in the range of *this
+ associated with the element equivalent to d
+ !*/
+
+ const range& operator[] (
+ const domain& d
+ ) const;
+ /*!
+ requires
+ - is_in_domain(d) == true
+ ensures
+ - returns a const reference to the element in the range of *this
+ associated with the element equivalent to d
+ !*/
+
+ void swap (
+ map& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ private:
+
+ // restricted functions
+ map(map&); // copy constructor
+ map& operator=(map&); // assignment operator
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ map<domain,range,mem_manager,compare>& a,
+ map<domain,range,mem_manager,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename domain,
+ typename range,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ map<domain,range,mem_manager,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_MAP_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/map/map_kernel_c.h b/ml/dlib/dlib/map/map_kernel_c.h
new file mode 100644
index 000000000..dfdbd4632
--- /dev/null
+++ b/ml/dlib/dlib/map/map_kernel_c.h
@@ -0,0 +1,248 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAP_KERNEl_C_
+#define DLIB_MAP_KERNEl_C_
+
+#include "map_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../interfaces/map_pair.h"
+
+namespace dlib
+{
+
+ template <
+ typename map_base
+ >
+ class map_kernel_c : public map_base
+ {
+
+ typedef typename map_base::domain_type domain;
+ typedef typename map_base::range_type range;
+
+ public:
+ void add (
+ domain& d,
+ range& r
+ );
+
+ void remove_any (
+ domain& d,
+ range& r
+ );
+
+ void remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ );
+
+ void destroy (
+ const domain& d
+ );
+
+ range& operator[] (
+ const domain& d
+ );
+
+ const range& operator[] (
+ const domain& d
+ ) const;
+
+ const map_pair<domain,range>& element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst map_pair<domain,range>& map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::element();
+ }
+
+ map_pair<domain,range>& element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tmap_pair<domain,range>& map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::element();
+ }
+
+ };
+
+ template <
+ typename map_base
+ >
+ inline void swap (
+ map_kernel_c<map_base>& a,
+ map_kernel_c<map_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ void map_kernel_c<map_base>::
+ add (
+ domain& d,
+ range& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (!this->is_in_domain(d)) &&
+ (static_cast<void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid map::add"
+ << "\n\tdomain element being added must not already be in the map"
+ << "\n\tand d and r must not be the same variable"
+ << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false")
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ );
+
+ // call the real function
+ map_base::add(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ void map_kernel_c<map_base>::
+ remove_any (
+ domain& d,
+ range& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->size() > 0) &&
+ (static_cast<void*>(&d) != static_cast<void*>(&r)),
+ "\tvoid map::remove_any"
+ << "\n\tsize() must be greater than zero if something is going to be removed"
+ << "\n\tand d and r must not be the same variable."
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ );
+
+ // call the real function
+ map_base::remove_any(d,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ void map_kernel_c<map_base>::
+ remove (
+ const domain& d,
+ domain& d_copy,
+ range& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->is_in_domain(d)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&r)) &&
+ (static_cast<void*>(&r) != static_cast<void*>(&d_copy)) &&
+ (static_cast<const void*>(&d) != static_cast<void*>(&d_copy)),
+ "\tvoid map::remove"
+ << "\n\tcan't remove something that isn't in the map or if the paremeters actually"
+ << "\n\tare the same variable. Either way can't remove."
+ << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false")
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<const void*>(&d)
+ << "\n\t&r: " << static_cast<void*>(&r)
+ << "\n\t&d_copy: " << static_cast<void*>(&d_copy)
+ );
+
+ // call the real function
+ map_base::remove(d,d_copy,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ void map_kernel_c<map_base>::
+ destroy (
+ const domain& d
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->is_in_domain(d),
+ "\tvoid map::destroy"
+ << "\n\tcan't remove something that isn't in the map"
+ << "\n\tthis: " << this
+ << "\n\t&d: " << static_cast<const void*>(&d)
+ );
+
+ // call the real function
+ map_base::destroy(d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ typename map_base::range_type& map_kernel_c<map_base>::
+ operator[] (
+ const domain& d
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_in_domain(d),
+ "\trange& map::operator[]"
+ << "\n\td must be in the domain of the map"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::operator[](d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ const typename map_base::range_type& map_kernel_c<map_base>::
+ operator[] (
+ const domain& d
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_in_domain(d),
+ "\tconst range& map::operator[]"
+ << "\n\td must be in the domain of the map"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::operator[](d);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAP_KERNEl_C_
+
diff --git a/ml/dlib/dlib/matlab/CMakeLists.txt b/ml/dlib/dlib/matlab/CMakeLists.txt
new file mode 100644
index 000000000..b9a0beab9
--- /dev/null
+++ b/ml/dlib/dlib/matlab/CMakeLists.txt
@@ -0,0 +1,22 @@
+
+cmake_minimum_required(VERSION 2.8.12)
+
+PROJECT(mex_functions)
+
+include(cmake_mex_wrapper)
+
+add_subdirectory(.. dlib_build)
+
+
+# You can tell cmake where to put the mex files when you run 'make install' by
+# setting this variable. The path is relative to this CMakeLists.txt file.
+set(install_target_output_folder .)
+
+# Compile the example_mex_function.cpp file and link it to dlib. Note
+# that you can give a list of things to link to here. E.g.
+# add_mex_function(some_other_mex_function pthread dlib fftw)
+add_mex_function(example_mex_function dlib::dlib)
+add_mex_function(example_mex_callback dlib::dlib)
+add_mex_function(example_mex_struct dlib::dlib)
+add_mex_function(example_mex_class dlib::dlib)
+
diff --git a/ml/dlib/dlib/matlab/README.txt b/ml/dlib/dlib/matlab/README.txt
new file mode 100644
index 000000000..d571e2330
--- /dev/null
+++ b/ml/dlib/dlib/matlab/README.txt
@@ -0,0 +1,20 @@
+This folder contains a set of tools which make it easy to create MATLAB mex
+functions. To understand how they work, you should read the
+example_mex_function.cpp, example_mex_struct.cpp, and example_mex_callback.cpp examples.
+
+To compile them, you can use CMake. In particular, from this folder execute
+these commands:
+
+ mkdir build
+ cd build
+ cmake ..
+ cmake --build . --config release --target install
+
+That should build the mex files on any platform.
+
+Note that on windows you will probably need to tell CMake to use a 64bit
+version of visual studio. You can do this by using a command like:
+ cmake -G "Visual Studio 10 Win64" ..
+instead of
+ cmake ..
+
diff --git a/ml/dlib/dlib/matlab/call_matlab.h b/ml/dlib/dlib/matlab/call_matlab.h
new file mode 100644
index 000000000..cc06a6812
--- /dev/null
+++ b/ml/dlib/dlib/matlab/call_matlab.h
@@ -0,0 +1,852 @@
+// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory
+// License: Boost Software License See LICENSE.txt for the full license.
+// Authors: Davis E. King (davis@dlib.net)
+#ifndef MIT_LL_CALL_MATLAB_H__
+#define MIT_LL_CALL_MATLAB_H__
+
+#include <string>
+#include <sstream>
+#include <dlib/error.h>
+#include <dlib/assert.h>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+struct invalid_args_exception : error
+{
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown when the mex wrapper tries to convert a matlab
+ object into a C++ object but for whatever reason can't (usually because the
+ types don't match).
+ !*/
+ invalid_args_exception(const std::string& msg_): error(msg_) {}
+ invalid_args_exception(const std::ostringstream& msg_): error(msg_.str()) {}
+};
+
+// ----------------------------------------------------------------------------------------
+
+void check_for_matlab_ctrl_c();
+/*!
+ ensures
+ - If the user of MATLAB has pressed ctrl+c then this function will throw an
+ exception.
+!*/
+
+// ----------------------------------------------------------------------------------------
+
+class matlab_object
+{
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple wrapper around matlab's generic mxArray, which is the
+ thing that is matlab's "anything object". So a matlab_object can be used as an
+ argument to a mex_function() that can bind to any matlab object at all. It can
+ also bind to "nothing" and so is inherently also an optional argument when
+ present in a mex_funciton().
+ !*/
+public:
+ matlab_object() : handle(0),should_free(false),arg_idx(0) {}
+ matlab_object(const matlab_object&) = delete;
+
+ ~matlab_object();
+
+
+ // Check if a matlab object is bound to this object.
+ bool is_empty() const { return handle==0; }
+ operator bool() const { return handle!=0; }
+
+ // Convert from MATLAB to C++, throw invalid_args_exception if not possible.
+ template <typename T> operator T() const;
+ template <typename T> void get(T& item) const;
+
+ // Convert from a C++ object to MATLAB
+ template <typename T> matlab_object& operator= (const T& new_val);
+
+
+ template <typename T> bool try_get(T& item) const
+ {
+ try { get(item); return true; }
+ catch(invalid_args_exception&) { return false; }
+ }
+
+ const void* get_handle() const { return handle; }
+ /*!
+ ensures
+ - returns a pointer to the mxArray object. Might be NULL.
+ !*/
+
+
+ matlab_object& operator=(const matlab_object&) = delete;
+
+ // Users shouldn't call these functions
+ const void* release_object_to_matlab() { const void* temp=handle; handle = 0; return temp; }
+ void set_object_handle(int arg_idx_, const void* sh) { DLIB_CASSERT(!handle); handle = sh; arg_idx=arg_idx_; }
+private:
+
+ const void* handle;
+ bool should_free;
+ int arg_idx;
+};
+
+// ----------------------------------------------------------------------------------------
+
+class matlab_struct
+{
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object lets you interface with MATLAB structs from C++. For example,
+ given a MATLAB struct named mystruct, you could access it's fields like this:
+ MATLAB way: mystruct.field
+ C++ way: mystruct["field"]
+ MATLAB way: mystruct.field.subfield
+ C++ way: mystruct["field"]["subfield"]
+
+ To get the values as C++ types you do something like this:
+ int val = mystruct["field"];
+ or
+ int val;
+ mystruct["field"].get(val);
+
+ See also example_mex_struct.cpp for an example that uses this part of the API.
+ !*/
+
+ class sub;
+public:
+ matlab_struct() : struct_handle(0),should_free(false),arg_idx(0) {}
+ matlab_struct(const matlab_struct&) = delete;
+ ~matlab_struct();
+
+ const sub operator[] (const std::string& name) const;
+ sub operator[] (const std::string& name);
+ bool has_field(const std::string& name) const;
+
+ const void* release_struct_to_matlab() { const void* temp=struct_handle; struct_handle = 0; return temp; }
+ void set_struct_handle(int arg_idx_, const void* sh) { DLIB_CASSERT(!struct_handle); struct_handle = sh; arg_idx=arg_idx_; }
+private:
+
+ class sub
+ {
+ public:
+ sub() : struct_handle(0), field_idx(-1) {}
+
+ template <typename T> operator T() const;
+ template <typename T> void get(T& item) const;
+ template <typename T> sub& operator= (const T& new_val);
+ const sub operator[] (const std::string& name) const;
+ sub operator[] (const std::string& name);
+ bool has_field(const std::string& name) const;
+ private:
+ friend class matlab_struct;
+ const void* struct_handle;
+ int field_idx;
+ sub& operator=(const sub&);
+ };
+ const void* struct_handle;
+ bool should_free;
+ int arg_idx;
+ matlab_struct& operator=(const matlab_struct&);
+};
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+template <typename T>
+struct output_decorator
+{
+ output_decorator(T& item_):item(item_){}
+ T& item;
+};
+
+template <typename T>
+output_decorator<T> returns(T& item) { return output_decorator<T>(item); }
+/*!
+ ensures
+ - decorates item as an output type. This stuff is used by the call_matlab()
+ functions to tell if an argument is an input to the function or is supposed
+ to be bound to one of the return arguments.
+!*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+struct function_handle
+{
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This type is used to represent function handles passed from MATLAB into a
+ mex function. You can call the function referenced by the handle by
+ saying:
+ call_matlab(my_handle);
+ !*/
+
+ // These two lines are just implementation details, ignore them.
+ function_handle():h(0){}
+ void* const h;
+};
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+void call_matlab (
+ const std::string& function_name
+);
+/*!
+ ensures
+ - Calls MATLAB's function of the given name
+!*/
+
+// ----------------------------------------------------------------------------------------
+
+void call_matlab (
+ const function_handle& funct
+);
+/*!
+ ensures
+ - Calls MATLAB's function represented by the handle funct
+!*/
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1
+);
+/*!
+ ensures
+ - calls MATLAB's function of the given name.
+ - if (A1 is not decorated as an output by returns()) then
+ - A1 is passed as an argument into the MATLAB function
+ - else
+ - A1 is treated as the first return value from the MATLAB function.
+!*/
+
+template <
+ typename T1
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1
+) { call_matlab("feval", funct, A1); }
+/*!
+ ensures
+ - Calls MATLAB's function represented by the handle funct
+ - if (A1 is not decorated as an output by returns()) then
+ - A1 is passed as an argument into the MATLAB function
+ - else
+ - A1 is treated as the first return value from the MATLAB function.
+!*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+/*
+ The rest of this file is just overloads of call_matlab() for up to 10 arguments (or
+ just 9 arguments if function_handle is used). They all do the same thing as the above
+ version of call_matlab(). Generally, any argument not decorated by returns() is an
+ input to the MATLAB function. On the other hand, all arguments decorated by returns()
+ are treated as outputs.
+*/
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& A12
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17, typename T18
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17,
+ const T18& A18
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17, typename T18, typename T19
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17,
+ const T18& A18, const T19& A19
+);
+
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17, typename T18, typename T19,
+ typename T20
+ >
+void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17,
+ const T18& A18, const T19& A19, const T20& A20
+);
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+template <
+ typename T1,
+ typename T2
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2
+)
+{
+ call_matlab("feval", funct, A1, A2);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3
+)
+{
+ call_matlab("feval", funct, A1, A2, A3);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7);
+}
+
+template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17, typename T18
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17,
+ const T18& A18
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18);
+}
+
+template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename
+ T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13,
+ typename T14, typename T15, typename T16, typename T17, typename T18, typename T19
+ >
+void call_matlab (
+ const function_handle& funct,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12&
+ A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17,
+ const T18& A18, const T19& A19
+)
+{
+ call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19);
+}
+
+// ----------------------------------------------------------------------------------------
+
+// We define this function here so that, if you write some code that has check_for_matlab_ctrl_c()
+// sprinkled throughout it you can still compile that code outside the mex wrapper
+// environment and these calls will simply be no-ops.
+#ifndef MATLAB_MEX_FILE
+inline void check_for_matlab_ctrl_c() {}
+#endif
+
+}
+
+#endif // MIT_LL_CALL_MATLAB_H__
+
diff --git a/ml/dlib/dlib/matlab/cmake_mex_wrapper b/ml/dlib/dlib/matlab/cmake_mex_wrapper
new file mode 100644
index 000000000..67729ff70
--- /dev/null
+++ b/ml/dlib/dlib/matlab/cmake_mex_wrapper
@@ -0,0 +1,103 @@
+# This file figures out where MATLAB is and then defines a macro, add_mex_function(name)
+# which when called instructs CMake to build a mex file from a file called name.cpp. Note
+# that additional library dependencies can be added like this: add_mex_function(name lib1 dlib libetc).
+# That is, just add more libraries after the name and they will be build into the mex file.
+
+cmake_minimum_required(VERSION 2.8.12)
+
+set(BUILDING_MATLAB_MEX_FILE true)
+set(CMAKE_POSITION_INDEPENDENT_CODE True)
+
+# Trying to use cuda with matlab hasn't worked well, so just disable it.
+SET(DLIB_USE_CUDA OFF CACHE BOOL "" FORCE)
+
+# Find MATLAB's include directory and needed libraries
+find_program(MATLAB_EXECUTABLE matlab PATHS
+ "C:/Program Files/MATLAB/*/bin"
+ "C:/Program Files (x86)/MATLAB/*/bin"
+ )
+# Resolve symbolic links to try and get the real path to the MATLAB executable
+get_filename_component(MATLAB_EXECUTABLE ${MATLAB_EXECUTABLE} REALPATH)
+# Now get MATLAB root directory
+get_filename_component(MATLAB_HOME ${MATLAB_EXECUTABLE} PATH)
+get_filename_component(MATLAB_HOME ${MATLAB_HOME} PATH)
+set(MATLAB_LIB_FOLDERS
+ "${MATLAB_HOME}/extern/lib/win64/microsoft"
+ "${MATLAB_HOME}/bin/glnxa64"
+ )
+# If there is a MATLAB_HOME environment variable then look there as well.
+if (DEFINED ENV{MATLAB_HOME})
+ set(MATLAB_LIB_FOLDERS
+ "$ENV{MATLAB_HOME}/extern/lib/win64/microsoft"
+ "$ENV{MATLAB_HOME}/bin/glnxa64"
+ ${MATLAB_LIB_FOLDERS}
+ )
+endif()
+# Find the MATLAB libraries that need to get linked into the mex file
+if (WIN32)
+ find_library(MATLAB_MEX_LIBRARY libmex PATHS ${MATLAB_LIB_FOLDERS} )
+ find_library(MATLAB_MX_LIBRARY libmx PATHS ${MATLAB_LIB_FOLDERS} )
+ find_library(MATLAB_ENG_LIBRARY libeng PATHS ${MATLAB_LIB_FOLDERS} )
+else()
+ find_library(MATLAB_MEX_LIBRARY mex PATHS ${MATLAB_LIB_FOLDERS} )
+ find_library(MATLAB_MX_LIBRARY mx PATHS ${MATLAB_LIB_FOLDERS} )
+ find_library(MATLAB_ENG_LIBRARY eng PATHS ${MATLAB_LIB_FOLDERS} )
+endif()
+set(MATLAB_LIBRARIES ${MATLAB_MEX_LIBRARY} ${MATLAB_MX_LIBRARY} ${MATLAB_ENG_LIBRARY})
+# Figure out the path to MATLAB's mex.h so we can add it to the include search path.
+find_path(mex_header_path mex.h
+ PATHS "$ENV{MATLAB_HOME}/extern/include"
+ "${MATLAB_HOME}/extern/include"
+ )
+INCLUDE_DIRECTORIES(${mex_header_path})
+
+# Determine the path to cmake_mex_wrapper file so we can add it to the include search path..
+string(REGEX REPLACE "cmake_mex_wrapper$" "" dlib_matlab_binding_path ${CMAKE_CURRENT_LIST_FILE})
+INCLUDE_DIRECTORIES("${dlib_matlab_binding_path}")
+# Also add dlib to the include search path
+INCLUDE_DIRECTORIES(${dlib_matlab_binding_path}/../..)
+
+add_definitions(-DMATLAB_MEX_FILE)
+
+# Determine the path to our CMakeLists.txt file. This is the file that
+# includeded the one you are reading right now. So here we make it so that
+# when you run the install target it will copy the compiled mex files into the
+# same folder as the parent CMakeLists.txt file.
+string(REGEX REPLACE "CMakeLists.txt$" "" install_dir ${CMAKE_PARENT_LIST_FILE})
+set(CMAKE_INSTALL_PREFIX "${install_dir}")
+set(CMAKE_INSTALL_SYSTEM_RUNTIME_DESTINATION "${install_dir}")
+INCLUDE(InstallRequiredSystemLibraries)
+
+
+MACRO(add_mex_function name )
+ ADD_LIBRARY(${name} MODULE ${name}.cpp )
+ target_compile_definitions(${name} PRIVATE -DMEX_FILENAME=${name})
+ if (UNIX)
+ # Doing this prevents our mex function from exporting any symbols
+ # other than mexFunction(). This sometimes doesn't matter but sometimes
+ # avoids causing errors or otherwise bad behavior in MATLAB.
+ if (DEFINED ENV{MATLAB_HOME})
+ set_target_properties(${name} PROPERTIES LINK_FLAGS "-Wl,--version-script,$ENV{MATLAB_HOME}/extern/lib/glnxa64/mexFunction.map")
+ else()
+ set_target_properties(${name} PROPERTIES LINK_FLAGS "-Wl,--version-script,${MATLAB_HOME}/extern/lib/glnxa64/mexFunction.map")
+ endif()
+ endif()
+
+ # Change the output file extension to a mex extension.
+ if (WIN32)
+ set_target_properties(${name} PROPERTIES SUFFIX ".mexw64")
+ elseif(APPLE)
+ set_target_properties(${name} PROPERTIES SUFFIX ".mexmaci64")
+ else()
+ set_target_properties(${name} PROPERTIES SUFFIX ".mexa64")
+ endif()
+ set_target_properties(${name} PROPERTIES PREFIX "")
+ TARGET_LINK_LIBRARIES(${name} ${MATLAB_LIBRARIES} ${ARGN})
+ if (install_target_output_folder)
+ install(TARGETS ${name} DESTINATION "${install_target_output_folder}")
+ else()
+ install(TARGETS ${name} DESTINATION "${install_dir}")
+ endif()
+ENDMACRO()
+
+
diff --git a/ml/dlib/dlib/matlab/example.m b/ml/dlib/dlib/matlab/example.m
new file mode 100644
index 000000000..8ed47346b
--- /dev/null
+++ b/ml/dlib/dlib/matlab/example.m
@@ -0,0 +1,16 @@
+% This example calls the three mex functions defined in this folder. As you
+% can see, you call them just like you would normal MATLAB functions.
+
+x = magic(3)
+y = 2*magic(3)
+
+[out1, out2] = example_mex_function(x,y, 12345)
+
+z = example_mex_callback(x, @(a)a+a)
+
+
+input = {}
+input.val = 2
+input.stuff = 'some string'
+output = example_mex_struct(input)
+
diff --git a/ml/dlib/dlib/matlab/example_mex_callback.cpp b/ml/dlib/dlib/matlab/example_mex_callback.cpp
new file mode 100644
index 000000000..a5a25dda1
--- /dev/null
+++ b/ml/dlib/dlib/matlab/example_mex_callback.cpp
@@ -0,0 +1,52 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+
+#include "call_matlab.h"
+#include "dlib/matrix.h"
+
+using namespace dlib;
+using namespace std;
+
+/*
+ This mex function takes a MATLAB function handle, calls it, and
+ returns the results.
+
+ For example, you can call this function in MATLAB like so:
+ A = magic(3)
+ y = example_mex_callback(A, @(x)x+x)
+
+ This will result in y containing the value 2*A.
+*/
+
+void mex_function (
+ const matrix<double>& A,
+ const function_handle& f,
+ matrix<double>& result
+)
+{
+ // The f argument to this function is a function handle passed from MATLAB. To
+ // call it we use the following syntax:
+ call_matlab(f, A, returns(result));
+ // This is equivalent to result = f(A). Therefore, the returns(variable) syntax
+ // is used to indicate which variables are outputs of the function.
+
+
+
+
+ // Another thing we can do is call MATLAB functions based on their string name
+ // rather than a function_handle. Here is an example of calling eigs().
+ matrix<double> m(2,2);
+ m = 1,2,
+ 3,4;
+ matrix<double> v,d;
+
+ // This is equivalent to [v,d] = eigs(m);
+ call_matlab("eigs", m, returns(v), returns(d));
+ cout << "eigenvectors: \n" << v << endl;
+ cout << "eigenvalues: \n" << d << endl;
+}
+
+
+
+// #including this brings in all the mex boiler plate needed by MATLAB.
+#include "mex_wrapper.cpp"
+
diff --git a/ml/dlib/dlib/matlab/example_mex_class.cpp b/ml/dlib/dlib/matlab/example_mex_class.cpp
new file mode 100644
index 000000000..b4242721b
--- /dev/null
+++ b/ml/dlib/dlib/matlab/example_mex_class.cpp
@@ -0,0 +1,72 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+/*
+ This mex file will create a MATLAB function called example_mex_class. If you call it
+ with no arguments it will output the MATLAB .m code to create a MATLAB wrapper class.
+ Paste that code into a .m file. Then you will be able to work with this C++ class
+ directly in MATLAB.
+*/
+
+#include <iostream>
+#include <dlib/matrix.h>
+
+
+using namespace std;
+using namespace dlib;
+
+class example_class
+{
+public:
+
+ // The class must have a default constructor. It's also the only kind of constructor
+ // you can call from MATLAB.
+ example_class()
+ {
+ xx.set_size(3,2);
+ xx = 1;
+ }
+
+ // The rest of the member functions that you want to bind have to return void and
+ // generally have the same syntax limitations as regular mex funcitons.
+ void do_stuff(const matrix_colmajor& x)
+ {
+ cout << "in do_stuff" << endl;
+ cout << x << endl;
+ xx = x;
+ }
+
+ void do_other_stuff(int x)
+ {
+ cout << "in do_other_stuff" << endl;
+ cout << "x: " << x << endl;
+ }
+
+ void print_state()
+ {
+ cout << xx << endl;
+ }
+
+ // saveobj() and load_obj() are special functions. If you provide these then you will
+ // be able to save() and load() your objects using MATLAB's built in object
+ // serialization.
+ void saveobj(matrix_colmajor& state)
+ {
+ // save this object's state to state.
+ state = xx;
+ }
+ void load_obj(const matrix_colmajor& state)
+ {
+ xx = state;
+ }
+
+private:
+ matrix_colmajor xx;
+};
+
+// Just tell the mex wrapper the name of your class and list the methods you want to bind.
+#define MEX_CLASS_NAME example_class
+#define MEX_CLASS_METHODS do_stuff, do_other_stuff, print_state, saveobj, load_obj
+
+
+#include "mex_wrapper.cpp"
+
+
diff --git a/ml/dlib/dlib/matlab/example_mex_function.cpp b/ml/dlib/dlib/matlab/example_mex_function.cpp
new file mode 100644
index 000000000..49d2e35fa
--- /dev/null
+++ b/ml/dlib/dlib/matlab/example_mex_function.cpp
@@ -0,0 +1,84 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+
+#include "dlib/matrix.h"
+using namespace dlib;
+using namespace std;
+
+
+/*!
+ This file defines a function callable from MATLAB once you mex it.
+
+ It computes the same thing as the following MATLAB function:
+
+ function [A, B] = example_mex_function(x, y, some_number)
+ A = x+y;
+ B = sum(sum(x+y));
+ disp(['some_number: ' num2str(some_number)])
+ end
+
+
+ VALID INPUT AND OUTPUT ARGUMENTS
+ The mex wrapper can handle the following kinds of input and output arguments:
+ - Types corresponding to a MATLAB matrix
+ - a dlib::matrix containing any kind of scalar value.
+ - a dlib::array2d containing any kind of scalar value.
+ - a dlib::vector containing any kind of scalar value.
+ - a dlib::point
+ - matrix_colmajor or fmatrix_colmajor
+ These are just typedefs for matrix containing double or float and using a
+ column major memory layout. However, they have the special distinction
+ of being fast to use in mex files since they sit directly on top of
+ MATLAB's built in matrices. That is, while other types of arguments copy
+ a MATLAB object into themselves, the matrix_colmajor and fmatrix_colmajor
+ do no such copy and are effectively zero overhead methods for working on
+ MATLAB's matrices.
+
+ - RGB color images
+ - dlib::array2d<dlib::rgb_pixel> can be used to represent
+ MATLAB uint8 MxNx3 images.
+
+ - Types corresponding to a MATLAB scalar
+ - any kind of scalar value, e.g. double, int, etc.
+
+ - Types corresponding to a MATLAB string
+ - std::string
+
+ - Types corresponding to a MATLAB cell array
+ - a std::vector or dlib::array containing any of the above
+ types of objects or std::vector or dlib::array objects.
+
+ - matlab_struct and matlab_object. These are special types defined in the
+ call_matlab.h file and correspond to matlab structs and arbitrary matlab
+ objects respectively.
+!*/
+
+
+// You can also define default values for your input arguments. So
+// here we say that if the user in MATLAB doesn't provide the "some_number"
+// then it will get a value of 3.141.
+#define ARG_5_DEFAULT 3.141
+
+// Make a function named mex_function() and put your code inside it.
+// Note that the return type should be void. Use non-const reference
+// arguments to return outputs. Finally, mex_function() must have no
+// more than 20 arguments.
+void mex_function (
+ const matrix_colmajor& x,
+ const matrix_colmajor& y,
+ matrix_colmajor& out1,
+ double& out2,
+ double some_number
+)
+{
+ out1 = x + y;
+ out2 = sum(x+y);
+
+ // we can also use cout to print things as usual:
+ cout << "some_number: "<< some_number << endl;
+}
+
+
+
+// #including this brings in all the mex boiler plate needed by MATLAB.
+#include "mex_wrapper.cpp"
+
diff --git a/ml/dlib/dlib/matlab/example_mex_struct.cpp b/ml/dlib/dlib/matlab/example_mex_struct.cpp
new file mode 100644
index 000000000..a948fbff9
--- /dev/null
+++ b/ml/dlib/dlib/matlab/example_mex_struct.cpp
@@ -0,0 +1,55 @@
+// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
+
+#include "call_matlab.h"
+#include "dlib/matrix.h"
+using namespace dlib;
+using namespace std;
+
+
+/*
+ This mex function takes a MATLAB struct, prints a few of its fields,
+ and then returns a new struct.
+
+ For example, you can call this function in MATLAB like so:
+ input = {}
+ input.val = 2
+ input.stuff = 'some string'
+ output = example_mex_struct(input)
+
+ output.number
+ output.number2
+ output.sub.stuff
+ output.sub.some_matrix
+*/
+
+
+void mex_function (
+ const matlab_struct& input,
+ matlab_struct& output
+)
+{
+ int val = input["val"];
+ string stuff = input["stuff"];
+
+ if (input.has_field("val2"))
+ {
+ string val2 = input["val2"];
+ cout << "The optional val2 field was set to: " << val2 << endl;
+ }
+
+ cout << "val: "<< val << endl;
+ cout << "stuff: " << stuff << endl;
+
+ output["number"] = 999;
+
+ output["number2"] = 1000;
+ output["sub"]["stuff"] = "some other string";
+ matrix<double> m = randm(2,2);
+ output["sub"]["some_matrix"] = m;
+}
+
+
+
+// #including this brings in all the mex boiler plate needed by MATLAB.
+#include "mex_wrapper.cpp"
+
diff --git a/ml/dlib/dlib/matlab/mex_wrapper.cpp b/ml/dlib/dlib/matlab/mex_wrapper.cpp
new file mode 100644
index 000000000..30c7e12ac
--- /dev/null
+++ b/ml/dlib/dlib/matlab/mex_wrapper.cpp
@@ -0,0 +1,5144 @@
+// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory
+// License: Boost Software License See LICENSE.txt for the full license.
+// Authors: Davis E. King (davis@dlib.net)
+/*
+ READ THIS FIRST
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ \############/
+ \##########/
+ \########/
+ \######/
+ \####/
+ \##/
+ \/
+
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ See example_mex_function.cpp for a discussion of how to use the mex wrapper.
+
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ /\
+ /##\
+ /####\
+ /######\
+ /########\
+ /##########\
+ /############\
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ ######
+ READ THIS FIRST
+*/
+
+// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory
+// License: Boost Software License See LICENSE.txt for the full license.
+// Authors: Davis E. King (davis@dlib.net)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// BEGIN IMPLEMENTATION DETAILS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+#include "../matrix.h"
+#include "../array2d.h"
+#include "../array.h"
+#include "../image_transforms.h"
+#include "../is_kind.h"
+#include "../string.h"
+#include "../any.h" // for sig_traits
+#include "../hash.h"
+#include <tuple>
+#include <map>
+
+#if defined(_MSC_VER)
+#define DLL_EXPORT_SYM __declspec(dllexport)
+#endif
+#include "mex.h"
+#include <sstream>
+#include "call_matlab.h"
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef ARG_1_DEFAULT
+#define ELSE_ASSIGN_ARG_1 else A1 = ARG_1_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_1
+#endif
+
+#ifdef ARG_2_DEFAULT
+#define ELSE_ASSIGN_ARG_2 else A2 = ARG_2_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_2
+#endif
+
+#ifdef ARG_3_DEFAULT
+#define ELSE_ASSIGN_ARG_3 else A3 = ARG_3_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_3
+#endif
+
+#ifdef ARG_4_DEFAULT
+#define ELSE_ASSIGN_ARG_4 else A4 = ARG_4_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_4
+#endif
+
+#ifdef ARG_5_DEFAULT
+#define ELSE_ASSIGN_ARG_5 else A5 = ARG_5_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_5
+#endif
+
+#ifdef ARG_6_DEFAULT
+#define ELSE_ASSIGN_ARG_6 else A6 = ARG_6_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_6
+#endif
+
+#ifdef ARG_7_DEFAULT
+#define ELSE_ASSIGN_ARG_7 else A7 = ARG_7_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_7
+#endif
+
+#ifdef ARG_8_DEFAULT
+#define ELSE_ASSIGN_ARG_8 else A8 = ARG_8_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_8
+#endif
+
+#ifdef ARG_9_DEFAULT
+#define ELSE_ASSIGN_ARG_9 else A9 = ARG_9_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_9
+#endif
+
+#ifdef ARG_10_DEFAULT
+#define ELSE_ASSIGN_ARG_10 else A10 = ARG_10_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_10
+#endif
+
+#ifdef ARG_11_DEFAULT
+#define ELSE_ASSIGN_ARG_11 else A11 = ARG_11_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_11
+#endif
+
+#ifdef ARG_12_DEFAULT
+#define ELSE_ASSIGN_ARG_12 else A12 = ARG_12_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_12
+#endif
+
+#ifdef ARG_13_DEFAULT
+#define ELSE_ASSIGN_ARG_13 else A13 = ARG_13_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_13
+#endif
+
+#ifdef ARG_14_DEFAULT
+#define ELSE_ASSIGN_ARG_14 else A14 = ARG_14_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_14
+#endif
+
+#ifdef ARG_15_DEFAULT
+#define ELSE_ASSIGN_ARG_15 else A15 = ARG_15_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_15
+#endif
+
+#ifdef ARG_16_DEFAULT
+#define ELSE_ASSIGN_ARG_16 else A16 = ARG_16_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_16
+#endif
+
+#ifdef ARG_17_DEFAULT
+#define ELSE_ASSIGN_ARG_17 else A17 = ARG_17_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_17
+#endif
+
+#ifdef ARG_18_DEFAULT
+#define ELSE_ASSIGN_ARG_18 else A18 = ARG_18_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_18
+#endif
+
+#ifdef ARG_19_DEFAULT
+#define ELSE_ASSIGN_ARG_19 else A19 = ARG_19_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_19
+#endif
+
+#ifdef ARG_20_DEFAULT
+#define ELSE_ASSIGN_ARG_20 else A20 = ARG_20_DEFAULT;
+#else
+#define ELSE_ASSIGN_ARG_20
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+namespace mex_binding
+{
+ using namespace dlib;
+
+ template <typename T>
+ struct is_input_type
+ {
+ const static unsigned long value = (!is_same_type<void,T>::value && (!is_reference_type<T>::value || is_const_type<T>::value )) ? 1 : 0;
+ };
+ template <typename T>
+ struct is_output_type
+ {
+ const static unsigned long value = (!is_same_type<void,T>::value && is_reference_type<T>::value && !is_const_type<T>::value) ? 1 : 0;
+ };
+
+
+ template <typename funct>
+ struct funct_traits
+ {
+ const static unsigned long num_inputs = is_input_type<typename sig_traits<funct>::arg1_type>::value +
+ is_input_type<typename sig_traits<funct>::arg2_type>::value +
+ is_input_type<typename sig_traits<funct>::arg3_type>::value +
+ is_input_type<typename sig_traits<funct>::arg4_type>::value +
+ is_input_type<typename sig_traits<funct>::arg5_type>::value +
+ is_input_type<typename sig_traits<funct>::arg6_type>::value +
+ is_input_type<typename sig_traits<funct>::arg7_type>::value +
+ is_input_type<typename sig_traits<funct>::arg8_type>::value +
+ is_input_type<typename sig_traits<funct>::arg9_type>::value +
+ is_input_type<typename sig_traits<funct>::arg10_type>::value +
+ is_input_type<typename sig_traits<funct>::arg11_type>::value +
+ is_input_type<typename sig_traits<funct>::arg12_type>::value +
+ is_input_type<typename sig_traits<funct>::arg13_type>::value +
+ is_input_type<typename sig_traits<funct>::arg14_type>::value +
+ is_input_type<typename sig_traits<funct>::arg15_type>::value +
+ is_input_type<typename sig_traits<funct>::arg16_type>::value +
+ is_input_type<typename sig_traits<funct>::arg17_type>::value +
+ is_input_type<typename sig_traits<funct>::arg18_type>::value +
+ is_input_type<typename sig_traits<funct>::arg19_type>::value +
+ is_input_type<typename sig_traits<funct>::arg20_type>::value;
+
+ const static unsigned long num_outputs= is_output_type<typename sig_traits<funct>::arg1_type>::value +
+ is_output_type<typename sig_traits<funct>::arg2_type>::value +
+ is_output_type<typename sig_traits<funct>::arg3_type>::value +
+ is_output_type<typename sig_traits<funct>::arg4_type>::value +
+ is_output_type<typename sig_traits<funct>::arg5_type>::value +
+ is_output_type<typename sig_traits<funct>::arg6_type>::value +
+ is_output_type<typename sig_traits<funct>::arg7_type>::value +
+ is_output_type<typename sig_traits<funct>::arg8_type>::value +
+ is_output_type<typename sig_traits<funct>::arg9_type>::value +
+ is_output_type<typename sig_traits<funct>::arg10_type>::value +
+ is_output_type<typename sig_traits<funct>::arg11_type>::value +
+ is_output_type<typename sig_traits<funct>::arg12_type>::value +
+ is_output_type<typename sig_traits<funct>::arg13_type>::value +
+ is_output_type<typename sig_traits<funct>::arg14_type>::value +
+ is_output_type<typename sig_traits<funct>::arg15_type>::value +
+ is_output_type<typename sig_traits<funct>::arg16_type>::value +
+ is_output_type<typename sig_traits<funct>::arg17_type>::value +
+ is_output_type<typename sig_traits<funct>::arg18_type>::value +
+ is_output_type<typename sig_traits<funct>::arg19_type>::value +
+ is_output_type<typename sig_traits<funct>::arg20_type>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct is_array_type
+ {
+ // true if T is std::vector or dlib::array
+ const static bool value = is_std_vector<T>::value || dlib::is_array<T>::value;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename enabled = void
+ >
+ struct inner_type
+ {
+ typedef T type;
+ };
+
+ template < typename T>
+ struct inner_type<T, typename dlib::enable_if_c<is_matrix<T>::value || is_array2d<T>::value || dlib::is_array<T>::value >::type>
+ {
+ typedef typename T::type type;
+ };
+
+ template < typename T>
+ struct inner_type<T, typename dlib::enable_if<is_std_vector<T> >::type>
+ {
+ typedef typename T::value_type type;
+ };
+
+
+// -------------------------------------------------------
+
+ struct user_hit_ctrl_c {};
+
+// -------------------------------------------------------
+
+ template <typename T>
+ void validate_and_populate_arg (
+ long arg_idx,
+ const mxArray *prhs,
+ T& arg
+ );
+
+// -------------------------------------------------------
+
+ template <typename T>
+ struct is_column_major_matrix : public default_is_kind_value {};
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> >
+ { static const bool value = true; };
+
+// -------------------------------------------------------
+
+ string escape_percent(const string& str)
+ {
+ string temp;
+ for(auto c : str)
+ {
+ if (c != '%')
+ {
+ temp += c;
+ }
+ else
+ {
+ temp += c;
+ temp += c;
+ }
+ }
+ return temp;
+ }
+
+ string escape_percent(const std::ostringstream& sout)
+ {
+ return escape_percent(sout.str());
+ }
+
+// -------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ typename dlib::enable_if_c<is_matrix<matrix_type>::value || is_array2d<matrix_type>::value >::type
+ clear_mat (
+ matrix_type& m
+ )
+ {
+ m.set_size(0,0);
+ }
+
+ template <
+ typename matrix_type
+ >
+ typename dlib::disable_if_c<is_matrix<matrix_type>::value || is_array2d<matrix_type>::value >::type
+ clear_mat (
+ matrix_type&
+ )
+ {
+ }
+
+// -------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename EXP
+ >
+ typename dlib::enable_if_c<is_matrix<matrix_type>::value && is_same_type<typename inner_type<matrix_type>::type,typename EXP::type>::value >::type
+ assign_mat (
+ const long arg_idx,
+ matrix_type& m,
+ const matrix_exp<EXP>& src
+ )
+ {
+ if (matrix_type::NR != 0 && matrix_type::NR != src.nc())
+ {
+ std::ostringstream sout;
+ sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NR << " rows but got one with " << src.nc();
+ throw invalid_args_exception(sout);
+ }
+ if (matrix_type::NC != 0 && matrix_type::NC != src.nr())
+ {
+ std::ostringstream sout;
+ sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NC << " columns but got one with " << src.nr();
+ throw invalid_args_exception(sout);
+ }
+
+
+ m = trans(src);
+ }
+
+ template <
+ typename matrix_type,
+ typename EXP
+ >
+ typename dlib::enable_if_c<is_array2d<matrix_type>::value && is_same_type<typename inner_type<matrix_type>::type,typename EXP::type>::value >::type
+ assign_mat (
+ const long arg_idx,
+ matrix_type& m,
+ const matrix_exp<EXP>& src
+ )
+ {
+ assign_image(m , trans(src));
+ }
+
+ template <
+ typename matrix_type,
+ typename EXP
+ >
+ typename disable_if_c<(is_array2d<matrix_type>::value || is_matrix<matrix_type>::value) &&
+ is_same_type<typename inner_type<matrix_type>::type,typename EXP::type>::value >::type
+ assign_mat (
+ const long arg_idx,
+ matrix_type& ,
+ const matrix_exp<EXP>&
+ )
+ {
+ std::ostringstream sout;
+ sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
+ throw invalid_args_exception(sout);
+ }
+
+
+// -------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ typename dlib::enable_if_c<is_built_in_scalar_type<T>::value || is_same_type<T,bool>::value >::type
+ assign_scalar (
+ const long arg_idx,
+ T& dest,
+ const U& src
+ )
+ {
+ if (is_signed_type<U>::value && src < 0 && is_unsigned_type<T>::value)
+ {
+ std::ostringstream sout;
+ sout << "Error, input argument " << arg_idx+1 << " must be a non-negative number.";
+ throw invalid_args_exception(sout);
+ }
+ else
+ {
+ dest = src;
+ }
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ typename dlib::disable_if_c<is_built_in_scalar_type<T>::value || is_same_type<T,bool>::value >::type
+ assign_scalar (
+ const long arg_idx,
+ T& ,
+ const U&
+ )
+ {
+ std::ostringstream sout;
+ sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
+ throw invalid_args_exception(sout);
+ }
+
+
+// -------------------------------------------------------
+
+ void assign_function_handle (
+ const long arg_idx,
+ function_handle& dest,
+ const mxArray* src
+ )
+ {
+ const_cast<void*&>(dest.h) = (void*)src;
+ }
+
+ template <
+ typename T
+ >
+ void assign_function_handle (
+ const long arg_idx,
+ T& ,
+ const mxArray*
+ )
+ {
+ std::ostringstream sout;
+ sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
+ throw invalid_args_exception(sout);
+ }
+
+
+// -------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename dlib::enable_if<is_array_type<T> >::type
+ assign_std_vector (
+ const long arg_idx,
+ T& dest,
+ const mxArray* src
+ )
+ {
+ const long nr = mxGetM(src);
+ const long nc = mxGetN(src);
+
+ typedef typename inner_type<T>::type type;
+
+ if (!mxIsCell(src))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a cell array";
+ throw invalid_args_exception(sout);
+ }
+ if (nr != 1 && nc != 1)
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a cell array with exactly 1 row or 1 column (i.e. a row or column vector)";
+ throw invalid_args_exception(sout);
+ }
+
+ const long size = nr*nc;
+ dest.resize(size);
+
+ for (unsigned long i = 0; i < dest.size(); ++i)
+ {
+ try
+ {
+ validate_and_populate_arg(i, mxGetCell(src, i), dest[i]);
+ }
+ catch (invalid_args_exception& e)
+ {
+ std::ostringstream sout;
+ sout << "Error in argument " << arg_idx+1 << ": element " << i+1 << " of cell array not the expected type.\n";
+ sout << "\t" << e.what();
+ throw invalid_args_exception(sout);
+ }
+ }
+
+ }
+
+ template <
+ typename T
+ >
+ typename disable_if<is_array_type<T> >::type
+ assign_std_vector (
+ const long arg_idx,
+ T& ,
+ const mxArray*
+ )
+ {
+ std::ostringstream sout;
+ sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
+ throw invalid_args_exception(sout);
+ }
+
+// -------------------------------------------------------
+
+ template <typename T>
+ void assign_image (
+ const long arg_idx,
+ T&,
+ const dlib::uint8* data,
+ long nr,
+ long nc
+ )
+ {
+ std::ostringstream sout;
+ sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
+ throw invalid_args_exception(sout);
+ }
+
+ template <typename MM>
+ void assign_image(
+ const long ,
+ array2d<dlib::rgb_pixel,MM>& img,
+ const dlib::uint8* data,
+ long nr,
+ long nc
+ )
+ {
+ img.set_size(nr, nc);
+ for (long c = 0; c < img.nc(); ++c)
+ for (long r = 0; r < img.nr(); ++r)
+ img[r][c].red = *data++;
+ for (long c = 0; c < img.nc(); ++c)
+ for (long r = 0; r < img.nr(); ++r)
+ img[r][c].green = *data++;
+ for (long c = 0; c < img.nc(); ++c)
+ for (long r = 0; r < img.nr(); ++r)
+ img[r][c].blue = *data++;
+ }
+
+// -------------------------------------------------------
+
+ template <typename T>
+ void call_private_set_mxArray(T&, mxArray*) {}
+ void call_private_set_mxArray(matrix_colmajor& item, mxArray* m) { item._private_set_mxArray(m); }
+ void call_private_set_mxArray(fmatrix_colmajor& item, mxArray* m) { item._private_set_mxArray(m); }
+
+// -------------------------------------------------------
+
+ template <typename T>
+ void validate_and_populate_arg (
+ long arg_idx,
+ const mxArray *prhs,
+ T& arg
+ )
+ {
+ using namespace mex_binding;
+ if (is_built_in_scalar_type<T>::value || is_same_type<T,bool>::value)
+ {
+ if( !(mxIsDouble(prhs) || mxIsSingle(prhs) || mxIsLogical(prhs) ) ||
+ mxIsComplex(prhs) ||
+ mxGetNumberOfElements(prhs)!=1 )
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a scalar";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_scalar(arg_idx, arg , mxGetScalar(prhs));
+ }
+ else if (is_matrix<T>::value || is_array2d<T>::value)
+ {
+ if (prhs == NULL)
+ {
+ clear_mat(arg);
+ return;
+ }
+
+ typedef typename inner_type<T>::type type;
+
+ const int num_dims = mxGetNumberOfDimensions(prhs);
+ const long nr = mxGetM(prhs);
+ const long nc = mxGetN(prhs);
+
+ if (is_same_type<type,dlib::rgb_pixel>::value)
+ {
+ if (!(num_dims == 3 && mxGetDimensions(prhs)[2] == 3 && mxIsUint8(prhs)))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a 3-D NxMx3 image matrix of uint8";
+ throw invalid_args_exception(sout);
+ }
+
+ const long rows = mxGetDimensions(prhs)[0];
+ const long cols = mxGetDimensions(prhs)[1];
+ assign_image(arg_idx, arg , (const dlib::uint8*)mxGetData(prhs), rows, cols);
+ return;
+ }
+
+ if (num_dims != 2)
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a 2-D matrix (got a " << num_dims << "-D matrix)";
+ throw invalid_args_exception(sout);
+ }
+
+
+ if (is_same_type<type,double>::value)
+ {
+ if (!mxIsDouble(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of doubles";
+ throw invalid_args_exception(sout);
+ }
+ if (is_column_major_matrix<T>::value)
+ call_private_set_mxArray(arg, (mxArray*)prhs);
+ else
+ assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr));
+ }
+ else if (is_same_type<type, float>::value)
+ {
+ if (!mxIsSingle(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of single/float";
+ throw invalid_args_exception(sout);
+ }
+
+ if (is_column_major_matrix<T>::value)
+ call_private_set_mxArray(arg,(mxArray*)prhs);
+ else
+ assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, bool>::value)
+ {
+ if (!mxIsLogical(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of logical elements.";
+ throw invalid_args_exception(sout);
+ }
+ DLIB_CASSERT(sizeof(mxLogical) == sizeof(bool),"logical matrices are not supported by the mex wrapper when mxLogical isn't a bool.");
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const bool*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::uint8>::value)
+ {
+ if (!mxIsUint8(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of uint8";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const dlib::uint8*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::int8>::value)
+ {
+ if (!mxIsInt8(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of int8";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const dlib::int8*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::int16>::value ||
+ (is_same_type<type, short>::value && sizeof(short) == sizeof(dlib::int16)))
+ {
+ if (!mxIsInt16(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of int16";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::uint16>::value ||
+ (is_same_type<type, unsigned short>::value && sizeof(unsigned short) == sizeof(dlib::uint16)))
+ {
+ if (!mxIsUint16(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of uint16";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::int32>::value ||
+ (is_same_type<type, int>::value && sizeof(int) == sizeof(dlib::int32)) ||
+ (is_same_type<type, long>::value && sizeof(long) == sizeof(dlib::int32)))
+ {
+ if (!mxIsInt32(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of int32";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::uint32>::value ||
+ (is_same_type<type, unsigned int>::value && sizeof(unsigned int) == sizeof(dlib::uint32)) ||
+ (is_same_type<type, unsigned long>::value && sizeof(unsigned long) == sizeof(dlib::uint32)))
+ {
+ if (!mxIsUint32(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of uint32";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::uint64>::value ||
+ (is_same_type<type, unsigned int>::value && sizeof(unsigned int) == sizeof(dlib::uint64)) ||
+ (is_same_type<type, unsigned long>::value && sizeof(unsigned long) == sizeof(dlib::uint64)))
+ {
+ if (!mxIsUint64(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of uint64";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else if (is_same_type<type, dlib::int64>::value ||
+ (is_same_type<type, int>::value && sizeof(int) == sizeof(dlib::int64)) ||
+ (is_same_type<type, long>::value && sizeof(long) == sizeof(dlib::int64)))
+ {
+ if (!mxIsInt64(prhs) || mxIsComplex(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a matrix of int64";
+ throw invalid_args_exception(sout);
+ }
+
+ assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr));
+ }
+ else
+ {
+ throw invalid_args_exception("mex_function uses unsupported matrix type");
+ }
+ }
+ else if (is_array_type<T>::value)
+ {
+ assign_std_vector(arg_idx, arg, prhs);
+
+ }
+ else if (is_same_type<T,function_handle>::value)
+ {
+ if (!mxIsClass(prhs, "function_handle"))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a function handle.";
+ throw invalid_args_exception(sout);
+ }
+ assign_function_handle(arg_idx, arg, prhs);
+ }
+ else
+ {
+ throw invalid_args_exception("mex_function uses unsupported input argument type");
+ }
+ }
+
+ void validate_and_populate_arg(
+ long arg_idx,
+ const mxArray *prhs,
+ matlab_struct& arg
+ )
+ {
+ if (!mxIsStruct(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a struct";
+ throw invalid_args_exception(sout);
+ }
+
+ arg.set_struct_handle(arg_idx, prhs);
+ }
+
+
+ void validate_and_populate_arg(
+ long arg_idx,
+ const mxArray *prhs,
+ matlab_object& arg
+ )
+ {
+ arg.set_object_handle(arg_idx, prhs);
+ }
+
+
+ void validate_and_populate_arg(
+ long arg_idx,
+ const mxArray *prhs,
+ std::string& arg
+ )
+ {
+ if (!mxIsChar(prhs))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " must be a char string";
+ throw invalid_args_exception(sout);
+ }
+
+ const long nr = mxGetM(prhs);
+ const long nc = mxGetN(prhs);
+ const long size = nr*nc;
+ arg.resize(size+1);
+ if (mxGetString(prhs, &arg[0], arg.size()))
+ {
+ std::ostringstream sout;
+ sout << "Input argument " << arg_idx+1 << " encountered an error while calling mxGetString()";
+ throw invalid_args_exception(sout);
+ }
+ arg.resize(size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ typename dlib::enable_if<is_same_type<dlib::rgb_pixel,typename EXP::type> >::type assign_image_to_matlab (
+ dlib::uint8* mat,
+ const matrix_exp<EXP>& item
+ )
+ {
+ for (long c = 0; c < item.nc(); ++c)
+ for (long r = 0; r < item.nr(); ++r)
+ *mat++ = item(r,c).red;
+ for (long c = 0; c < item.nc(); ++c)
+ for (long r = 0; r < item.nr(); ++r)
+ *mat++ = item(r,c).green;
+ for (long c = 0; c < item.nc(); ++c)
+ for (long r = 0; r < item.nr(); ++r)
+ *mat++ = item(r,c).blue;
+ }
+
+ template <typename T, typename EXP>
+ typename disable_if<is_same_type<dlib::rgb_pixel,typename EXP::type> >::type assign_image_to_matlab (
+ T* mat,
+ const matrix_exp<EXP>&
+ )
+ {
+ mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
+ "mex_function uses unsupported output image argument type");
+ }
+
+ template <typename T>
+ typename dlib::enable_if<is_matrix<T> >::type assign_to_matlab(
+ mxArray*& plhs,
+ const T& item
+ )
+ {
+ typedef typename T::type type;
+
+ type* mat = 0;
+
+ if (is_same_type<double, type>::value)
+ {
+ plhs = mxCreateDoubleMatrix(item.nr(),
+ item.nc(),
+ mxREAL);
+
+ mat = (type*)mxGetPr(plhs);
+ }
+ else if (is_same_type<float, type>::value )
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxSINGLE_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<bool, type>::value )
+ {
+ plhs = mxCreateLogicalMatrix(item.nr(),
+ item.nc());
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::uint8, type>::value )
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxUINT8_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::int8, type>::value )
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxINT8_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::int16, type>::value ||
+ (is_same_type<short,type>::value && sizeof(short) == sizeof(dlib::int16)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxINT16_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::uint16, type>::value ||
+ (is_same_type<unsigned short,type>::value && sizeof(unsigned short) == sizeof(dlib::uint16)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxUINT16_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::int32, type>::value ||
+ (is_same_type<long,type>::value && sizeof(long) == sizeof(dlib::int32)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxINT32_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::uint32, type>::value ||
+ (is_same_type<unsigned long,type>::value && sizeof(unsigned long) == sizeof(dlib::uint32)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxUINT32_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::uint64, type>::value ||
+ (is_same_type<unsigned long,type>::value && sizeof(unsigned long) == sizeof(dlib::uint64)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxUINT64_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::int64, type>::value ||
+ (is_same_type<long,type>::value && sizeof(long) == sizeof(dlib::int64)))
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxINT64_CLASS,
+ mxREAL);
+
+ mat = (type*)mxGetData(plhs);
+ }
+ else if (is_same_type<dlib::rgb_pixel, type>::value)
+ {
+ mwSize dims[3] = {(mwSize)item.nr(), (mwSize)item.nc(), 3};
+ plhs = mxCreateNumericArray(3, dims, mxUINT8_CLASS, mxREAL);
+
+ assign_image_to_matlab((dlib::uint8*)mxGetData(plhs), item);
+ return;
+ }
+ else
+ {
+ mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
+ "mex_function uses unsupported output argument type");
+ }
+
+
+ const_temp_matrix<T> m(item);
+
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ for ( long r = 0; r < m.nr(); ++r)
+ {
+ *mat++ = m(r,c);
+ }
+ }
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ matrix_colmajor& item
+ )
+ {
+ if(item._private_is_owned_by_matlab())
+ {
+ // Don't need to do a copy if it's this kind of matrix since we can just
+ // pull the underlying mxArray out directly and thus avoid a copy.
+ plhs = item._private_release_mxArray();
+ // If there isn't anything there because the matrix is empty then set it to an
+ // empty matrix.
+ if (!plhs)
+ plhs = mxCreateDoubleMatrix(item.nr(),
+ item.nc(),
+ mxREAL);
+ }
+ else
+ {
+ plhs = mxCreateDoubleMatrix(item.nr(),
+ item.nc(),
+ mxREAL);
+ if (item.size() != 0)
+ memcpy(mxGetPr(plhs), &item(0,0), item.size()*sizeof(double));
+ }
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ fmatrix_colmajor& item
+ )
+ {
+ if(item._private_is_owned_by_matlab())
+ {
+ // Don't need to do a copy if it's this kind of matrix since we can just
+ // pull the underlying mxArray out directly and thus avoid a copy.
+ plhs = item._private_release_mxArray();
+ // If there isn't anything there because the matrix is empty then set it to an
+ // empty matrix.
+ if (!plhs)
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxSINGLE_CLASS,
+ mxREAL);
+ }
+ else
+ {
+ plhs = mxCreateNumericMatrix(item.nr(),
+ item.nc(),
+ mxSINGLE_CLASS,
+ mxREAL);
+ if (item.size() != 0)
+ memcpy(mxGetPr(plhs), &item(0,0), item.size()*sizeof(float));
+ }
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ matlab_struct& item
+ )
+ {
+ plhs = (mxArray*)item.release_struct_to_matlab();
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ matlab_object& item
+ )
+ {
+ plhs = (mxArray*)item.release_object_to_matlab();
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ const std::string& item
+ )
+ {
+ plhs = mxCreateString(item.c_str());
+ }
+
+ template <typename T, typename MM>
+ void assign_to_matlab(
+ mxArray*& plhs,
+ const array2d<T,MM>& item
+ )
+ {
+ assign_to_matlab(plhs,array_to_matrix(item));
+ }
+
+ template <typename T>
+ typename dlib::disable_if_c<is_matrix<T>::value || is_array_type<T>::value ||
+ is_same_type<T,function_handle>::value>::type assign_to_matlab(
+ mxArray*& plhs,
+ const T& item
+ )
+ {
+ plhs = mxCreateDoubleScalar(item);
+ }
+
+
+ void assign_to_matlab (
+ mxArray*& plhs,
+ const char* str
+ )
+ {
+ assign_to_matlab(plhs, std::string(str));
+ }
+
+ void assign_to_matlab(
+ mxArray*& plhs,
+ const function_handle& h
+ )
+ {
+ }
+
+ template <typename T>
+ typename dlib::enable_if<is_array_type<T> >::type assign_to_matlab(
+ mxArray*& plhs,
+ const T& item
+ )
+ {
+ mwSize dims[1] = {item.size()};
+ plhs = mxCreateCellArray(1,dims);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ {
+ mxArray* next = 0;
+ assign_to_matlab(next, item[i]);
+ mxSetCell(plhs, i, next);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void mark_owned_by_matlab (const T&){}
+
+ void mark_owned_by_matlab(matrix_colmajor& item) { item._private_mark_owned_by_matlab(); }
+ void mark_owned_by_matlab(fmatrix_colmajor& item) { item._private_mark_owned_by_matlab(); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long num_args
+ >
+ struct call_mex_function_helper;
+
+ template <>
+ struct call_mex_function_helper<0>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int , mxArray **,
+ int , const mxArray **
+ ) const
+ {
+ f();
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<1>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+
+ typename basic_type<arg1_type>::type A1;
+
+ mark_owned_by_matlab(A1);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+
+ f(A1);
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<2>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+
+ f(A1,A2);
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<3>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+
+ f(A1,A2,A3);
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<4>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+
+ f(A1,A2,A3,A4);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<5>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+
+ f(A1,A2,A3,A4,A5);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ }
+ };
+
+
+ template <>
+ struct call_mex_function_helper<6>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+
+ f(A1,A2,A3,A4,A5,A6);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ }
+ };
+
+
+ template <>
+ struct call_mex_function_helper<7>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+
+ f(A1,A2,A3,A4,A5,A6,A7);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ }
+ };
+
+
+ template <>
+ struct call_mex_function_helper<8>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ }
+ };
+
+
+ template <>
+ struct call_mex_function_helper<9>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ }
+ };
+
+
+
+ template <>
+ struct call_mex_function_helper<10>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<11>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<12>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<13>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<14>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<15>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<16>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+ typedef typename sig_traits<funct>::arg16_type arg16_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+ typename basic_type<arg16_type>::type A16;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+ mark_owned_by_matlab(A16);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+ if (i < nrhs && is_input_type<arg16_type>::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ if (is_output_type<arg16_type>::value) {assign_to_matlab(plhs[i],A16); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<17>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+ typedef typename sig_traits<funct>::arg16_type arg16_type;
+ typedef typename sig_traits<funct>::arg17_type arg17_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+ typename basic_type<arg16_type>::type A16;
+ typename basic_type<arg17_type>::type A17;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+ mark_owned_by_matlab(A16);
+ mark_owned_by_matlab(A17);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+ if (i < nrhs && is_input_type<arg16_type>::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16;
+ if (i < nrhs && is_input_type<arg17_type>::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ if (is_output_type<arg16_type>::value) {assign_to_matlab(plhs[i],A16); ++i;}
+ if (is_output_type<arg17_type>::value) {assign_to_matlab(plhs[i],A17); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<18>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+ typedef typename sig_traits<funct>::arg16_type arg16_type;
+ typedef typename sig_traits<funct>::arg17_type arg17_type;
+ typedef typename sig_traits<funct>::arg18_type arg18_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+ typename basic_type<arg16_type>::type A16;
+ typename basic_type<arg17_type>::type A17;
+ typename basic_type<arg18_type>::type A18;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+ mark_owned_by_matlab(A16);
+ mark_owned_by_matlab(A17);
+ mark_owned_by_matlab(A18);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+ if (i < nrhs && is_input_type<arg16_type>::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16;
+ if (i < nrhs && is_input_type<arg17_type>::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17;
+ if (i < nrhs && is_input_type<arg18_type>::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ if (is_output_type<arg16_type>::value) {assign_to_matlab(plhs[i],A16); ++i;}
+ if (is_output_type<arg17_type>::value) {assign_to_matlab(plhs[i],A17); ++i;}
+ if (is_output_type<arg18_type>::value) {assign_to_matlab(plhs[i],A18); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<19>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+ typedef typename sig_traits<funct>::arg16_type arg16_type;
+ typedef typename sig_traits<funct>::arg17_type arg17_type;
+ typedef typename sig_traits<funct>::arg18_type arg18_type;
+ typedef typename sig_traits<funct>::arg19_type arg19_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+ typename basic_type<arg16_type>::type A16;
+ typename basic_type<arg17_type>::type A17;
+ typename basic_type<arg18_type>::type A18;
+ typename basic_type<arg19_type>::type A19;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+ mark_owned_by_matlab(A16);
+ mark_owned_by_matlab(A17);
+ mark_owned_by_matlab(A18);
+ mark_owned_by_matlab(A19);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+ if (i < nrhs && is_input_type<arg16_type>::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16;
+ if (i < nrhs && is_input_type<arg17_type>::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17;
+ if (i < nrhs && is_input_type<arg18_type>::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18;
+ if (i < nrhs && is_input_type<arg19_type>::value) {validate_and_populate_arg(i,prhs[i],A19); ++i;} ELSE_ASSIGN_ARG_19;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ if (is_output_type<arg16_type>::value) {assign_to_matlab(plhs[i],A16); ++i;}
+ if (is_output_type<arg17_type>::value) {assign_to_matlab(plhs[i],A17); ++i;}
+ if (is_output_type<arg18_type>::value) {assign_to_matlab(plhs[i],A18); ++i;}
+ if (is_output_type<arg19_type>::value) {assign_to_matlab(plhs[i],A19); ++i;}
+ }
+ };
+
+ template <>
+ struct call_mex_function_helper<20>
+ {
+ template <typename funct>
+ void callit(
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ ) const
+ {
+ typedef typename sig_traits<funct>::arg1_type arg1_type;
+ typedef typename sig_traits<funct>::arg2_type arg2_type;
+ typedef typename sig_traits<funct>::arg3_type arg3_type;
+ typedef typename sig_traits<funct>::arg4_type arg4_type;
+ typedef typename sig_traits<funct>::arg5_type arg5_type;
+ typedef typename sig_traits<funct>::arg6_type arg6_type;
+ typedef typename sig_traits<funct>::arg7_type arg7_type;
+ typedef typename sig_traits<funct>::arg8_type arg8_type;
+ typedef typename sig_traits<funct>::arg9_type arg9_type;
+ typedef typename sig_traits<funct>::arg10_type arg10_type;
+ typedef typename sig_traits<funct>::arg11_type arg11_type;
+ typedef typename sig_traits<funct>::arg12_type arg12_type;
+ typedef typename sig_traits<funct>::arg13_type arg13_type;
+ typedef typename sig_traits<funct>::arg14_type arg14_type;
+ typedef typename sig_traits<funct>::arg15_type arg15_type;
+ typedef typename sig_traits<funct>::arg16_type arg16_type;
+ typedef typename sig_traits<funct>::arg17_type arg17_type;
+ typedef typename sig_traits<funct>::arg18_type arg18_type;
+ typedef typename sig_traits<funct>::arg19_type arg19_type;
+ typedef typename sig_traits<funct>::arg20_type arg20_type;
+
+ typename basic_type<arg1_type>::type A1;
+ typename basic_type<arg2_type>::type A2;
+ typename basic_type<arg3_type>::type A3;
+ typename basic_type<arg4_type>::type A4;
+ typename basic_type<arg5_type>::type A5;
+ typename basic_type<arg6_type>::type A6;
+ typename basic_type<arg7_type>::type A7;
+ typename basic_type<arg8_type>::type A8;
+ typename basic_type<arg9_type>::type A9;
+ typename basic_type<arg10_type>::type A10;
+ typename basic_type<arg11_type>::type A11;
+ typename basic_type<arg12_type>::type A12;
+ typename basic_type<arg13_type>::type A13;
+ typename basic_type<arg14_type>::type A14;
+ typename basic_type<arg15_type>::type A15;
+ typename basic_type<arg16_type>::type A16;
+ typename basic_type<arg17_type>::type A17;
+ typename basic_type<arg18_type>::type A18;
+ typename basic_type<arg19_type>::type A19;
+ typename basic_type<arg20_type>::type A20;
+
+ mark_owned_by_matlab(A1);
+ mark_owned_by_matlab(A2);
+ mark_owned_by_matlab(A3);
+ mark_owned_by_matlab(A4);
+ mark_owned_by_matlab(A5);
+ mark_owned_by_matlab(A6);
+ mark_owned_by_matlab(A7);
+ mark_owned_by_matlab(A8);
+ mark_owned_by_matlab(A9);
+ mark_owned_by_matlab(A10);
+ mark_owned_by_matlab(A11);
+ mark_owned_by_matlab(A12);
+ mark_owned_by_matlab(A13);
+ mark_owned_by_matlab(A14);
+ mark_owned_by_matlab(A15);
+ mark_owned_by_matlab(A16);
+ mark_owned_by_matlab(A17);
+ mark_owned_by_matlab(A18);
+ mark_owned_by_matlab(A19);
+ mark_owned_by_matlab(A20);
+
+ int i = 0;
+ if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
+ if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
+ if (i < nrhs && is_input_type<arg3_type>::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3;
+ if (i < nrhs && is_input_type<arg4_type>::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4;
+ if (i < nrhs && is_input_type<arg5_type>::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5;
+ if (i < nrhs && is_input_type<arg6_type>::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6;
+ if (i < nrhs && is_input_type<arg7_type>::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7;
+ if (i < nrhs && is_input_type<arg8_type>::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8;
+ if (i < nrhs && is_input_type<arg9_type>::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9;
+ if (i < nrhs && is_input_type<arg10_type>::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10;
+ if (i < nrhs && is_input_type<arg11_type>::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11;
+ if (i < nrhs && is_input_type<arg12_type>::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12;
+ if (i < nrhs && is_input_type<arg13_type>::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13;
+ if (i < nrhs && is_input_type<arg14_type>::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14;
+ if (i < nrhs && is_input_type<arg15_type>::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15;
+ if (i < nrhs && is_input_type<arg16_type>::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16;
+ if (i < nrhs && is_input_type<arg17_type>::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17;
+ if (i < nrhs && is_input_type<arg18_type>::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18;
+ if (i < nrhs && is_input_type<arg19_type>::value) {validate_and_populate_arg(i,prhs[i],A19); ++i;} ELSE_ASSIGN_ARG_19;
+ if (i < nrhs && is_input_type<arg20_type>::value) {validate_and_populate_arg(i,prhs[i],A20); ++i;} ELSE_ASSIGN_ARG_20;
+
+ f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19,A20);
+
+
+ i = 0;
+ if (is_output_type<arg1_type>::value) {assign_to_matlab(plhs[i],A1); ++i;}
+ if (is_output_type<arg2_type>::value) {assign_to_matlab(plhs[i],A2); ++i;}
+ if (is_output_type<arg3_type>::value) {assign_to_matlab(plhs[i],A3); ++i;}
+ if (is_output_type<arg4_type>::value) {assign_to_matlab(plhs[i],A4); ++i;}
+ if (is_output_type<arg5_type>::value) {assign_to_matlab(plhs[i],A5); ++i;}
+ if (is_output_type<arg6_type>::value) {assign_to_matlab(plhs[i],A6); ++i;}
+ if (is_output_type<arg7_type>::value) {assign_to_matlab(plhs[i],A7); ++i;}
+ if (is_output_type<arg8_type>::value) {assign_to_matlab(plhs[i],A8); ++i;}
+ if (is_output_type<arg9_type>::value) {assign_to_matlab(plhs[i],A9); ++i;}
+ if (is_output_type<arg10_type>::value) {assign_to_matlab(plhs[i],A10); ++i;}
+ if (is_output_type<arg11_type>::value) {assign_to_matlab(plhs[i],A11); ++i;}
+ if (is_output_type<arg12_type>::value) {assign_to_matlab(plhs[i],A12); ++i;}
+ if (is_output_type<arg13_type>::value) {assign_to_matlab(plhs[i],A13); ++i;}
+ if (is_output_type<arg14_type>::value) {assign_to_matlab(plhs[i],A14); ++i;}
+ if (is_output_type<arg15_type>::value) {assign_to_matlab(plhs[i],A15); ++i;}
+ if (is_output_type<arg16_type>::value) {assign_to_matlab(plhs[i],A16); ++i;}
+ if (is_output_type<arg17_type>::value) {assign_to_matlab(plhs[i],A17); ++i;}
+ if (is_output_type<arg18_type>::value) {assign_to_matlab(plhs[i],A18); ++i;}
+ if (is_output_type<arg19_type>::value) {assign_to_matlab(plhs[i],A19); ++i;}
+ if (is_output_type<arg20_type>::value) {assign_to_matlab(plhs[i],A20); ++i;}
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T> struct is_matlab_object { const static bool value = false; };
+ template <> struct is_matlab_object <matlab_object> { const static bool value = true; };
+ template <> struct is_matlab_object <const matlab_object> { const static bool value = true; };
+ template <> struct is_matlab_object <matlab_object&> { const static bool value = true; };
+ template <> struct is_matlab_object <const matlab_object&> { const static bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ void call_mex_function (
+ const funct& f,
+ int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[]
+ )
+ {
+ const long expected_nrhs = funct_traits<funct>::num_inputs;
+ const long expected_nlhs = funct_traits<funct>::num_outputs;
+ const long expected_args = expected_nrhs + expected_nlhs;
+
+ long defaulted_args = 0;
+
+ #ifdef ARG_1_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg1_type>::value);
+ #ifndef ARG_2_DEFAULT
+ // You can't define a default for argument 1 if you don't define one for argument 2 also.
+ COMPILE_TIME_ASSERT(expected_args < 2);
+ #endif
+ COMPILE_TIME_ASSERT(1 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_2_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg2_type>::value);
+ #ifndef ARG_3_DEFAULT
+ // You can't define a default for argument 2 if you don't define one for argument 3 also.
+ COMPILE_TIME_ASSERT(expected_args < 3);
+ #endif
+ COMPILE_TIME_ASSERT(2 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_3_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg3_type>::value);
+ #ifndef ARG_4_DEFAULT
+ // You can't define a default for argument 3 if you don't define one for argument 4 also.
+ COMPILE_TIME_ASSERT(expected_args < 4);
+ #endif
+ COMPILE_TIME_ASSERT(3 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_4_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg4_type>::value);
+ #ifndef ARG_5_DEFAULT
+ // You can't define a default for argument 4 if you don't define one for argument 5 also.
+ COMPILE_TIME_ASSERT(expected_args < 5);
+ #endif
+ COMPILE_TIME_ASSERT(4 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_5_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg5_type>::value);
+ #ifndef ARG_6_DEFAULT
+ // You can't define a default for argument 5 if you don't define one for argument 6 also.
+ COMPILE_TIME_ASSERT(expected_args < 6);
+ #endif
+ COMPILE_TIME_ASSERT(5 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_6_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg6_type>::value);
+ #ifndef ARG_7_DEFAULT
+ // You can't define a default for argument 6 if you don't define one for argument 7 also.
+ COMPILE_TIME_ASSERT(expected_args < 7);
+ #endif
+ COMPILE_TIME_ASSERT(6 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_7_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg7_type>::value);
+ #ifndef ARG_8_DEFAULT
+ // You can't define a default for argument 7 if you don't define one for argument 8 also.
+ COMPILE_TIME_ASSERT(expected_args < 8);
+ #endif
+ COMPILE_TIME_ASSERT(7 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_8_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg8_type>::value);
+ #ifndef ARG_9_DEFAULT
+ // You can't define a default for argument 8 if you don't define one for argument 9 also.
+ COMPILE_TIME_ASSERT(expected_args < 9);
+ #endif
+ COMPILE_TIME_ASSERT(8 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_9_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg9_type>::value);
+ #ifndef ARG_10_DEFAULT
+ // You can't define a default for argument 9 if you don't define one for argument 10 also.
+ COMPILE_TIME_ASSERT(expected_args < 10);
+ #endif
+ COMPILE_TIME_ASSERT(9 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_10_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg10_type>::value);
+ #ifndef ARG_11_DEFAULT
+ // You can't define a default for argument 10 if you don't define one for argument 11 also.
+ COMPILE_TIME_ASSERT(expected_args < 11);
+ #endif
+ COMPILE_TIME_ASSERT(10 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_11_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg11_type>::value);
+ #ifndef ARG_12_DEFAULT
+ // You can't define a default for argument 11 if you don't define one for argument 12 also.
+ COMPILE_TIME_ASSERT(expected_args < 12);
+ #endif
+ COMPILE_TIME_ASSERT(11 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_12_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg12_type>::value);
+ #ifndef ARG_13_DEFAULT
+ // You can't define a default for argument 12 if you don't define one for argument 13 also.
+ COMPILE_TIME_ASSERT(expected_args < 13);
+ #endif
+ COMPILE_TIME_ASSERT(12 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_13_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg13_type>::value);
+ #ifndef ARG_14_DEFAULT
+ // You can't define a default for argument 13 if you don't define one for argument 14 also.
+ COMPILE_TIME_ASSERT(expected_args < 14);
+ #endif
+ COMPILE_TIME_ASSERT(13 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_14_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg14_type>::value);
+ #ifndef ARG_15_DEFAULT
+ // You can't define a default for argument 14 if you don't define one for argument 15 also.
+ COMPILE_TIME_ASSERT(expected_args < 15);
+ #endif
+ COMPILE_TIME_ASSERT(14 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_15_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg15_type>::value);
+ #ifndef ARG_16_DEFAULT
+ // You can't define a default for argument 15 if you don't define one for argument 16 also.
+ COMPILE_TIME_ASSERT(expected_args < 16);
+ #endif
+ COMPILE_TIME_ASSERT(15 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_16_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg16_type>::value);
+ #ifndef ARG_17_DEFAULT
+ // You can't define a default for argument 16 if you don't define one for argument 17 also.
+ COMPILE_TIME_ASSERT(expected_args < 17);
+ #endif
+ COMPILE_TIME_ASSERT(16 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_17_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg17_type>::value);
+ #ifndef ARG_18_DEFAULT
+ // You can't define a default for argument 17 if you don't define one for argument 18 also.
+ COMPILE_TIME_ASSERT(expected_args < 18);
+ #endif
+ COMPILE_TIME_ASSERT(17 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_18_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg18_type>::value);
+ #ifndef ARG_19_DEFAULT
+ // You can't define a default for argument 18 if you don't define one for argument 19 also.
+ COMPILE_TIME_ASSERT(expected_args < 19);
+ #endif
+ COMPILE_TIME_ASSERT(18 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_19_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg19_type>::value);
+ #ifndef ARG_20_DEFAULT
+ // You can't define a default for argument 19 if you don't define one for argument 20 also.
+ COMPILE_TIME_ASSERT(expected_args < 20);
+ #endif
+ COMPILE_TIME_ASSERT(19 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+ #ifdef ARG_20_DEFAULT
+ ++defaulted_args;
+ // You can only set an argument's default value if it is an input argument.
+ COMPILE_TIME_ASSERT(is_input_type<typename sig_traits<funct>::arg20_type>::value);
+ COMPILE_TIME_ASSERT(20 <= expected_args); // You can't define a default for an argument that doesn't exist.
+ #endif
+
+
+ // Arguments with type matlab_object are optional in both input and output.
+ int num_optional_inputs = 0;
+ int num_optional_outputs = 0;
+ if (is_matlab_object<typename sig_traits<funct>::arg20_type>::value) if (is_input_type<typename sig_traits<funct>::arg20_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg19_type>::value) if (is_input_type<typename sig_traits<funct>::arg19_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg18_type>::value) if (is_input_type<typename sig_traits<funct>::arg18_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg17_type>::value) if (is_input_type<typename sig_traits<funct>::arg17_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg16_type>::value) if (is_input_type<typename sig_traits<funct>::arg16_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg15_type>::value) if (is_input_type<typename sig_traits<funct>::arg15_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg14_type>::value) if (is_input_type<typename sig_traits<funct>::arg14_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg13_type>::value) if (is_input_type<typename sig_traits<funct>::arg13_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg12_type>::value) if (is_input_type<typename sig_traits<funct>::arg12_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg11_type>::value) if (is_input_type<typename sig_traits<funct>::arg11_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg10_type>::value) if (is_input_type<typename sig_traits<funct>::arg10_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg9_type>::value) if (is_input_type<typename sig_traits<funct>::arg9_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg8_type>::value) if (is_input_type<typename sig_traits<funct>::arg8_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg7_type>::value) if (is_input_type<typename sig_traits<funct>::arg7_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg6_type>::value) if (is_input_type<typename sig_traits<funct>::arg6_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg5_type>::value) if (is_input_type<typename sig_traits<funct>::arg5_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg4_type>::value) if (is_input_type<typename sig_traits<funct>::arg4_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg3_type>::value) if (is_input_type<typename sig_traits<funct>::arg3_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg2_type>::value) if (is_input_type<typename sig_traits<funct>::arg2_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+ if (is_matlab_object<typename sig_traits<funct>::arg1_type>::value) if (is_input_type<typename sig_traits<funct>::arg1_type>::value) ++num_optional_inputs; else ++num_optional_outputs;
+
+
+ /* check for proper number of arguments */
+ if(nrhs > expected_nrhs || nrhs < expected_nrhs - defaulted_args - num_optional_inputs)
+ {
+ std::ostringstream sout;
+ sout << "Expected between " << expected_nrhs-defaulted_args - num_optional_inputs
+ << " and " << expected_nrhs << " input arguments, got " << nrhs << ".";
+
+ mexErrMsgIdAndTxt("mex_function:nrhs",
+ escape_percent(sout).c_str());
+ }
+
+ if (nlhs > expected_nlhs)
+ {
+ std::ostringstream sout;
+ sout << "Expected at most " << expected_nlhs << " output arguments, got " << nlhs << ".";
+
+ mexErrMsgIdAndTxt("mex_function:nlhs",
+ escape_percent(sout).c_str());
+ }
+
+ call_mex_function_helper<sig_traits<funct>::num_args> helper;
+ helper.callit(f, nlhs, plhs, nrhs, prhs);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class mex_streambuf : public std::streambuf
+ {
+
+ public:
+ mex_streambuf (
+ )
+ {
+ buf.resize(1000);
+ setp(&buf[0], &buf[0] + buf.size()-2);
+
+ // make cout send data to mex_streambuf
+ oldbuf = std::cout.rdbuf(this);
+ }
+
+ ~mex_streambuf()
+ {
+ // put cout back to the way we found it before running our mex function.
+ std::cout.rdbuf(oldbuf);
+ }
+
+
+ protected:
+
+
+ int sync (
+ )
+ {
+ int num = static_cast<int>(pptr()-pbase());
+ if (num != 0)
+ {
+ check_for_matlab_ctrl_c();
+
+ buf[num] = 0; // null terminate the string
+ mexPrintf("%s",&buf[0]);
+ mexEvalString("drawnow"); // flush print to screen
+ pbump(-num);
+ }
+ return 0;
+ }
+
+ int_type overflow (
+ int_type c
+ )
+ {
+ if (c != EOF)
+ {
+ *pptr() = c;
+ pbump(1);
+ }
+ sync();
+ return c;
+ }
+
+ private:
+ std::vector<char> buf;
+ std::streambuf* oldbuf;
+
+ };
+
+ class mex_warn_streambuf : public std::streambuf
+ {
+
+ public:
+ mex_warn_streambuf (
+ )
+ {
+ buf.resize(1000);
+ setp(&buf[0], &buf[0] + buf.size()-2);
+
+ // make cout send data to mex_warn_streambuf
+ oldbuf = std::cerr.rdbuf(this);
+ }
+
+ ~mex_warn_streambuf()
+ {
+ // put cerr back to the way we found it before running our mex function.
+ std::cerr.rdbuf(oldbuf);
+ }
+
+ protected:
+
+
+ int sync (
+ )
+ {
+ int num = static_cast<int>(pptr()-pbase());
+ if (num != 0)
+ {
+ check_for_matlab_ctrl_c();
+
+ buf[num] = 0; // null terminate the string
+ mexWarnMsgTxt(&buf[0]);
+ mexEvalString("drawnow"); // flush print to screen
+ pbump(-num);
+ }
+ return 0;
+ }
+
+ int_type overflow (
+ int_type c
+ )
+ {
+ if (c != EOF)
+ {
+ *pptr() = c;
+ pbump(1);
+ }
+ sync();
+ return c;
+ }
+
+ private:
+ std::vector<char> buf;
+ std::streambuf* oldbuf;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void setup_input_args (
+ mxArray*& array,
+ const T& item,
+ int& nrhs
+ )
+ {
+ assign_to_matlab(array, item);
+ ++nrhs;
+ }
+
+ void setup_input_args (
+ mxArray*& array,
+ const function_handle& item,
+ int& nrhs
+ )
+ {
+ array = static_cast<mxArray*>(item.h);
+ ++nrhs;
+ }
+
+ template <typename T>
+ void setup_input_args (
+ mxArray*& array,
+ const output_decorator<T>& item,
+ int& nrhs
+ )
+ {
+ }
+
+ template <typename T>
+ void setup_output_args (
+ const std::string& function_name,
+ mxArray* array,
+ const T& item,
+ int& nrhs
+ )
+ {
+ }
+
+ template <typename T>
+ void setup_output_args (
+ const std::string& function_name,
+ mxArray* array,
+ const output_decorator<T>& item,
+ int& i
+ )
+ {
+ try
+ {
+ validate_and_populate_arg(i,array,const_cast<T&>(item.item));
+ ++i;
+ }
+ catch (invalid_args_exception& e)
+ {
+ throw dlib::error("Error occurred calling MATLAB function '" + function_name + "' from mex file. \n"
+ "The MATLAB function didn't return what we expected it to. \nIn particular, return" + string(e.what()));
+ }
+ }
+
+ void call_matlab_for_real (
+ int nlhs,
+ mxArray* plhs[],
+ int nrhs,
+ mxArray* prhs[],
+ const std::string& function_name
+ )
+ {
+ int status = mexCallMATLAB(nlhs, plhs, nrhs, prhs, function_name.c_str());
+ if (status)
+ {
+ throw dlib::error("Error, an exception was thrown when we tried to call the MATLAB function '" + function_name + "'.");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ void call_matlab (
+ const std::string& function_name
+ )
+ {
+ using namespace mex_binding;
+
+ call_matlab_for_real(0,NULL,0,NULL, function_name);
+ }
+
+ template <typename T1>
+ void free_callback_resources (
+ int nlhs,
+ mxArray* plhs[],
+ int nrhs,
+ mxArray* prhs[]
+ )
+ {
+ // free resources
+ for (int i = 0; i < nlhs; ++i)
+ mxDestroyArray(plhs[i]);
+
+ for (int i = 0; i < nrhs; ++i)
+ {
+ // don't call mxDestroyArray() on function handles (which should only ever be in prhs[0])
+ if (i == 0 && dlib::is_same_type<T1,function_handle>::value)
+ continue;
+ mxDestroyArray(prhs[i]);
+ }
+ }
+
+ template <
+ typename T1
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 1;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 2;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 3;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 4;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 5;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 6;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 7;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 8;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 9;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 10;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 11;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 12;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12,
+ typename T13
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12,
+ const T13& A13
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 13;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12,
+ typename T13,
+ typename T14
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12,
+ const T13& A13,
+ const T14& A14
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 14;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12,
+ typename T13,
+ typename T14,
+ typename T15
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12,
+ const T13& A13,
+ const T14& A14,
+ const T15& A15
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 15;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12,
+ typename T13,
+ typename T14,
+ typename T15,
+ typename T16
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12,
+ const T13& A13,
+ const T14& A14,
+ const T15& A15,
+ const T16& A16
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 16;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+ setup_input_args(prhs[nrhs], A16, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+ setup_output_args(function_name, plhs[i], A16, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1,
+ typename T2,
+ typename T3,
+ typename T4,
+ typename T5,
+ typename T6,
+ typename T7,
+ typename T8,
+ typename T9,
+ typename T10,
+ typename T11,
+ typename T12,
+ typename T13,
+ typename T14,
+ typename T15,
+ typename T16,
+ typename T17
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1,
+ const T2& A2,
+ const T3& A3,
+ const T4& A4,
+ const T5& A5,
+ const T6& A6,
+ const T7& A7,
+ const T8& A8,
+ const T9& A9,
+ const T10& A10,
+ const T11& A11,
+ const T12& A12,
+ const T13& A13,
+ const T14& A14,
+ const T15& A15,
+ const T16& A16,
+ const T17& A17
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 17;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+ setup_input_args(prhs[nrhs], A16, nrhs);
+ setup_input_args(prhs[nrhs], A17, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+ setup_output_args(function_name, plhs[i], A16, i);
+ setup_output_args(function_name, plhs[i], A17, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
+ typename T7, typename T8, typename T9, typename T10, typename T11, typename T12,
+ typename T13, typename T14, typename T15, typename T16, typename T17, typename T18
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const
+ T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const
+ T17& A17, const T18& A18
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 18;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+ setup_input_args(prhs[nrhs], A16, nrhs);
+ setup_input_args(prhs[nrhs], A17, nrhs);
+ setup_input_args(prhs[nrhs], A18, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+ setup_output_args(function_name, plhs[i], A16, i);
+ setup_output_args(function_name, plhs[i], A17, i);
+ setup_output_args(function_name, plhs[i], A18, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
+ typename T7, typename T8, typename T9, typename T10, typename T11, typename T12,
+ typename T13, typename T14, typename T15, typename T16, typename T17, typename T18,
+ typename T19
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const
+ T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const
+ T17& A17, const T18& A18, const T19& A19
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 19;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+ setup_input_args(prhs[nrhs], A16, nrhs);
+ setup_input_args(prhs[nrhs], A17, nrhs);
+ setup_input_args(prhs[nrhs], A18, nrhs);
+ setup_input_args(prhs[nrhs], A19, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+ setup_output_args(function_name, plhs[i], A16, i);
+ setup_output_args(function_name, plhs[i], A17, i);
+ setup_output_args(function_name, plhs[i], A18, i);
+ setup_output_args(function_name, plhs[i], A19, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
+ typename T7, typename T8, typename T9, typename T10, typename T11, typename T12,
+ typename T13, typename T14, typename T15, typename T16, typename T17, typename T18,
+ typename T19, typename T20
+ >
+ void call_matlab (
+ const std::string& function_name,
+ const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6,
+ const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const
+ T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const
+ T17& A17, const T18& A18, const T19& A19, const T20& A20
+ )
+ {
+ using namespace mex_binding;
+ const int num_args = 20;
+ mxArray* plhs[num_args] = {0};
+ mxArray* prhs[num_args] = {0};
+
+ int nrhs = 0;
+ setup_input_args(prhs[nrhs], A1, nrhs);
+ setup_input_args(prhs[nrhs], A2, nrhs);
+ setup_input_args(prhs[nrhs], A3, nrhs);
+ setup_input_args(prhs[nrhs], A4, nrhs);
+ setup_input_args(prhs[nrhs], A5, nrhs);
+ setup_input_args(prhs[nrhs], A6, nrhs);
+ setup_input_args(prhs[nrhs], A7, nrhs);
+ setup_input_args(prhs[nrhs], A8, nrhs);
+ setup_input_args(prhs[nrhs], A9, nrhs);
+ setup_input_args(prhs[nrhs], A10, nrhs);
+ setup_input_args(prhs[nrhs], A11, nrhs);
+ setup_input_args(prhs[nrhs], A12, nrhs);
+ setup_input_args(prhs[nrhs], A13, nrhs);
+ setup_input_args(prhs[nrhs], A14, nrhs);
+ setup_input_args(prhs[nrhs], A15, nrhs);
+ setup_input_args(prhs[nrhs], A16, nrhs);
+ setup_input_args(prhs[nrhs], A17, nrhs);
+ setup_input_args(prhs[nrhs], A18, nrhs);
+ setup_input_args(prhs[nrhs], A19, nrhs);
+ setup_input_args(prhs[nrhs], A20, nrhs);
+
+ const int nlhs = num_args - nrhs;
+ call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name);
+
+ int i = 0;
+ setup_output_args(function_name, plhs[i], A1, i);
+ setup_output_args(function_name, plhs[i], A2, i);
+ setup_output_args(function_name, plhs[i], A3, i);
+ setup_output_args(function_name, plhs[i], A4, i);
+ setup_output_args(function_name, plhs[i], A5, i);
+ setup_output_args(function_name, plhs[i], A6, i);
+ setup_output_args(function_name, plhs[i], A7, i);
+ setup_output_args(function_name, plhs[i], A8, i);
+ setup_output_args(function_name, plhs[i], A9, i);
+ setup_output_args(function_name, plhs[i], A10, i);
+ setup_output_args(function_name, plhs[i], A11, i);
+ setup_output_args(function_name, plhs[i], A12, i);
+ setup_output_args(function_name, plhs[i], A13, i);
+ setup_output_args(function_name, plhs[i], A14, i);
+ setup_output_args(function_name, plhs[i], A15, i);
+ setup_output_args(function_name, plhs[i], A16, i);
+ setup_output_args(function_name, plhs[i], A17, i);
+ setup_output_args(function_name, plhs[i], A18, i);
+ setup_output_args(function_name, plhs[i], A19, i);
+ setup_output_args(function_name, plhs[i], A20, i);
+
+ free_callback_resources<T1>(nlhs,plhs,nrhs,prhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ matlab_object::~matlab_object(
+ )
+ {
+ if (handle && should_free)
+ {
+ mxDestroyArray((mxArray*)handle);
+ handle = 0;
+ }
+ }
+
+ template <typename T>
+ matlab_object::
+ operator T(
+ ) const
+ {
+ T item;
+ get(item);
+ return item;
+ }
+
+ template <typename T>
+ void matlab_object::
+ get(
+ T& item
+ ) const
+ {
+ if (handle == 0)
+ throw dlib::invalid_args_exception("An attempt was made to access an empty matlab_object.");
+
+ mex_binding::validate_and_populate_arg(arg_idx,(mxArray*)handle,item);
+ }
+
+ template <typename T>
+ matlab_object& matlab_object::
+ operator= (
+ const T& new_val
+ )
+ {
+ mxArray* item;
+ mex_binding::assign_to_matlab(item, new_val);
+ handle = item;
+ should_free = true;
+ return *this;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ matlab_struct::sub::operator T() const
+ {
+ T item;
+ get(item);
+ return item;
+ }
+
+ template <typename T>
+ void matlab_struct::sub::get(T& item) const
+ {
+ if (struct_handle == 0)
+ throw dlib::error("Attempt to access data in an empty struct.");
+
+ mxArray* temp = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx);
+ if (temp == 0)
+ throw dlib::error("Attempt to access data in an empty struct.");
+
+ try
+ {
+ mex_binding::validate_and_populate_arg(0,temp,item);
+ }
+ catch(mex_binding::invalid_args_exception& e)
+ {
+ std::ostringstream sout;
+ sout << "Struct field '" << mxGetFieldNameByNumber((const mxArray*)struct_handle, field_idx) << "' can't be interpreted as the requested type."
+ << endl << e.what();
+ throw dlib::error(sout.str());
+ }
+ }
+
+ const matlab_struct::sub matlab_struct::
+ operator[] (const std::string& name) const
+ {
+ if (struct_handle == 0)
+ throw dlib::error("Struct does not have a field named '" + name + "'.");
+
+ matlab_struct::sub temp;
+ temp.struct_handle = struct_handle;
+ temp.field_idx = mxGetFieldNumber((const mxArray*)struct_handle, name.c_str());
+ if (temp.field_idx == -1 )
+ throw dlib::error("Struct does not have a field named '" + name + "'.");
+ return temp;
+ }
+
+ matlab_struct::sub matlab_struct::
+ operator[] (const std::string& name)
+ {
+ if (struct_handle == 0)
+ {
+ // We make a struct from scratch and mark that we will free it unless it gets
+ // written back to matlab by assign_to_matlab().
+ mwSize dims[1] = {1};
+ const char* name_str = name.c_str();
+ struct_handle = mxCreateStructArray(1, dims, 1, &name_str);
+ should_free = true;
+ if (struct_handle == 0)
+ throw dlib::error("Error creating struct from within mex function.");
+ }
+
+
+ matlab_struct::sub temp;
+ temp.struct_handle = struct_handle;
+ if ((temp.field_idx=mxGetFieldNumber((mxArray*)struct_handle, name.c_str())) == -1)
+ {
+ if ((temp.field_idx=mxAddField((mxArray*)struct_handle, name.c_str())) == -1)
+ {
+ throw dlib::error("Unable to add field '"+name + "' to struct.");
+ }
+ }
+ return temp;
+ }
+
+ const matlab_struct::sub matlab_struct::sub::
+ operator[] (const std::string& name) const
+ {
+ if (struct_handle == 0)
+ throw dlib::error("Struct does not have a field named '" + name + "'.");
+
+ matlab_struct::sub temp;
+ temp.struct_handle = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx);
+ if (temp.struct_handle == 0)
+ throw dlib::error("Failure to get struct field while calling mxGetFieldByNumber()");
+
+ if (!mxIsStruct((const mxArray*)temp.struct_handle))
+ throw dlib::error("Struct sub-field element '"+name+"' is not another struct.");
+
+ temp.field_idx = mxGetFieldNumber((const mxArray*)temp.struct_handle, name.c_str());
+ if (temp.field_idx == -1 )
+ throw dlib::error("Struct does not have a field named '" + name + "'.");
+ return temp;
+ }
+
+ matlab_struct::sub matlab_struct::sub::
+ operator[] (const std::string& name)
+ {
+ if (struct_handle == 0)
+ throw dlib::error("Struct does not have a field named '" + name + "'.");
+
+ matlab_struct::sub temp;
+ temp.struct_handle = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx);
+ // We are replacing this field with a struct if it exists and isn't already a struct
+ if (temp.struct_handle != 0 && !mxIsStruct((const mxArray*)temp.struct_handle))
+ {
+ mxDestroyArray((mxArray*)temp.struct_handle);
+ temp.struct_handle = 0;
+ }
+ if (temp.struct_handle == 0)
+ {
+ mwSize dims[1] = {1};
+ temp.struct_handle = mxCreateStructArray(1, dims, 0, 0);
+ if (temp.struct_handle == 0)
+ throw dlib::error("Failure to create new sub-struct field");
+ mxSetFieldByNumber((mxArray*)struct_handle, 0, field_idx, (mxArray*)temp.struct_handle);
+ }
+
+
+ if ((temp.field_idx=mxGetFieldNumber((mxArray*)temp.struct_handle, name.c_str())) == -1)
+ {
+ if ((temp.field_idx=mxAddField((mxArray*)temp.struct_handle, name.c_str())) == -1)
+ {
+ throw dlib::error("Unable to add field '"+name + "' to struct.");
+ }
+ }
+ return temp;
+ }
+
+ bool matlab_struct::has_field (
+ const std::string& name
+ ) const
+ {
+ if (struct_handle == 0)
+ return false;
+ return mxGetFieldNumber((const mxArray*)struct_handle, name.c_str()) != -1;
+ }
+
+ bool matlab_struct::sub::has_field (
+ const std::string& name
+ ) const
+ {
+ if (struct_handle == 0)
+ return false;
+ mxArray* temp = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx);
+ if (temp == 0 || !mxIsStruct(temp))
+ return false;
+ return mxGetFieldNumber(temp, name.c_str()) != -1;
+ }
+
+ template <typename T>
+ matlab_struct::sub& matlab_struct::sub::operator= (
+ const T& new_val
+ )
+ {
+ // Delete anything in the field before we overwrite it
+ mxArray* item = mxGetFieldByNumber((mxArray*)struct_handle, 0, field_idx);
+ if (item != 0)
+ {
+ mxDestroyArray((mxArray*)item);
+ item = 0;
+ }
+
+ // Now set the field
+ mex_binding::assign_to_matlab(item, new_val);
+ mxSetFieldByNumber((mxArray*)struct_handle, 0, field_idx, item);
+
+ return *this;
+ }
+
+ matlab_struct::
+ ~matlab_struct (
+ )
+ {
+ if (struct_handle && should_free)
+ {
+ mxDestroyArray((mxArray*)struct_handle);
+ struct_handle = 0;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void call_matlab (
+ const function_handle& funct
+ )
+ {
+ call_matlab("feval", funct);
+ }
+
+ extern "C" bool utIsInterruptPending();
+ void check_for_matlab_ctrl_c(
+ )
+ {
+ if (utIsInterruptPending())
+ throw mex_binding::user_hit_ctrl_c();
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+#ifdef MEX_CLASS_NAME
+template <typename T, typename mfp_type>
+class mex_class_wrapper
+{
+public:
+ mex_class_wrapper(T& obj_, mfp_type mfp_) : obj(obj_), mfp(mfp_) {}
+
+ template <typename... Args>
+ void operator()(Args&&... args) const
+ {
+ (obj.*mfp)(std::forward<Args>(args)...);
+ }
+
+ mfp_type mfp;
+ T& obj;
+};
+
+template <typename T, typename mfp_type>
+mex_class_wrapper<T,mfp_type> wrap_mex_class(T& obj, mfp_type mfp) { return mex_class_wrapper<T,mfp_type>(obj, mfp); }
+
+namespace dlib
+{
+ template <typename T, typename mfp_type>
+ struct sig_traits<mex_class_wrapper<T,mfp_type>>
+ : public sig_traits<mfp_type>
+ {};
+
+ template <size_t i, typename T, bool is_good = i < std::tuple_size<T>::value>
+ struct tuple_element_default_void
+ {
+ typedef void type;
+ };
+
+ template <size_t i, typename T>
+ struct tuple_element_default_void<i,T,true>
+ {
+ typedef typename std::tuple_element<i,T>::type type;
+ };
+
+ template <typename class_type, typename return_type, typename... Args>
+ struct sig_traits<return_type(class_type::*)(Args...) >
+ {
+ enum { num_args = sizeof...(Args) };
+
+ typedef return_type result_type;
+
+ template <size_t i>
+ struct arg
+ {
+ typedef typename tuple_element_default_void<i-1, std::tuple<Args...>>::type type;
+ };
+
+ // These are here because that's how things are defined in sig_traits (since it is
+ // older than C++11, along with most of the other code in this file)
+ typedef typename arg<1>::type arg1_type;
+ typedef typename arg<2>::type arg2_type;
+ typedef typename arg<3>::type arg3_type;
+ typedef typename arg<4>::type arg4_type;
+ typedef typename arg<5>::type arg5_type;
+ typedef typename arg<6>::type arg6_type;
+ typedef typename arg<7>::type arg7_type;
+ typedef typename arg<8>::type arg8_type;
+ typedef typename arg<9>::type arg9_type;
+ typedef typename arg<10>::type arg10_type;
+ typedef typename arg<11>::type arg11_type;
+ typedef typename arg<12>::type arg12_type;
+ typedef typename arg<13>::type arg13_type;
+ typedef typename arg<14>::type arg14_type;
+ typedef typename arg<15>::type arg15_type;
+ typedef typename arg<16>::type arg16_type;
+ typedef typename arg<17>::type arg17_type;
+ typedef typename arg<18>::type arg18_type;
+ typedef typename arg<19>::type arg19_type;
+ typedef typename arg<20>::type arg20_type;
+ };
+
+ template <typename class_type, typename return_type, typename... Args>
+ struct sig_traits<return_type(class_type::*)(Args...) const>
+ {
+ enum { num_args = sizeof...(Args) };
+
+ typedef return_type result_type;
+
+ template <size_t i>
+ struct arg
+ {
+ typedef typename tuple_element_default_void<i-1, std::tuple<Args...>>::type type;
+ };
+
+ // These are here because that's how things are defined in sig_traits (since it is
+ // older than C++11, along with most of the other code in this file)
+ typedef typename arg<1>::type arg1_type;
+ typedef typename arg<2>::type arg2_type;
+ typedef typename arg<3>::type arg3_type;
+ typedef typename arg<4>::type arg4_type;
+ typedef typename arg<5>::type arg5_type;
+ typedef typename arg<6>::type arg6_type;
+ typedef typename arg<7>::type arg7_type;
+ typedef typename arg<8>::type arg8_type;
+ typedef typename arg<9>::type arg9_type;
+ typedef typename arg<10>::type arg10_type;
+ typedef typename arg<11>::type arg11_type;
+ typedef typename arg<12>::type arg12_type;
+ typedef typename arg<13>::type arg13_type;
+ typedef typename arg<14>::type arg14_type;
+ typedef typename arg<15>::type arg15_type;
+ typedef typename arg<16>::type arg16_type;
+ typedef typename arg<17>::type arg17_type;
+ typedef typename arg<18>::type arg18_type;
+ typedef typename arg<19>::type arg19_type;
+ typedef typename arg<20>::type arg20_type;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+template <size_t I>
+struct visit_impl
+{
+ template <typename T, typename F>
+ static void visit(T& tup, size_t idx, F fun)
+ {
+ if (idx == I - 1) fun(std::get<I - 1>(tup));
+ else visit_impl<I - 1>::visit(tup, idx, fun);
+ }
+};
+
+template <>
+struct visit_impl<0>
+{
+ template <typename T, typename F>
+ static void visit(T& tup, size_t idx, F fun) { DLIB_CASSERT(false,"this should never happen"); }
+};
+
+template <typename F, typename... Ts>
+void visit_at(std::tuple<Ts...> const& tup, size_t idx, F fun)
+{
+ visit_impl<sizeof...(Ts)>::visit(tup, idx, fun);
+}
+
+template <typename F, typename... Ts>
+void visit_at(std::tuple<Ts...>& tup, size_t idx, F fun)
+{
+ visit_impl<sizeof...(Ts)>::visit(tup, idx, fun);
+}
+
+class mex_class_dispatch
+{
+public:
+ mex_class_dispatch(
+ MEX_CLASS_NAME* ptr_,
+ int nlhs_,
+ mxArray** plhs_,
+ int nrhs_,
+ const mxArray** prhs_
+ ) :
+ ptr(ptr_),
+ nlhs(nlhs_),
+ plhs(plhs_),
+ nrhs(nrhs_),
+ prhs(prhs_)
+ {}
+
+ template <typename funct>
+ void operator() (const funct& mfp)
+ {
+ mex_binding::call_mex_function(wrap_mex_class(*ptr,mfp), nlhs, plhs, nrhs, prhs);
+ }
+
+private:
+ MEX_CLASS_NAME* ptr;
+ int nlhs;
+ mxArray** plhs;
+ int nrhs;
+ const mxArray** prhs;
+};
+
+class class_factory_type : dlib::noncopyable
+{
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a container class for all the MEX_CLASS_NAME objects we create. It allows
+ us to track what we have created and make sure the MATLAB user doesn't do any
+ double frees or use any stale pointers.
+
+ It also helps us deal with the problem that would otherwise arise when a mex file
+ is unloaded from MATLAB when there are still active pointers to MEX_CLASS_NAME objects
+ in MATLAB, since we will be able to detect stale pointers.
+ !*/
+public:
+
+ class_factory_type()
+ {
+ seed = (uint64)time(0);
+ }
+
+ ~class_factory_type()
+ {
+ for (auto i : object_table)
+ delete i.second;
+ }
+
+ template <typename ...T>
+ uint64 create(T&& ...args)
+ {
+ MEX_CLASS_NAME* item = new MEX_CLASS_NAME(std::forward<T>(args)...);
+ uint64 id = (uint64)item;
+ // Now generate a unique id that incorporates our seed value. The point of doing
+ // this is to avoid any chance that a mex file will get unloaded and then reloaded
+ // and start constructing objects with the same addresses, while old stale objects
+ // at those addresses are still stored in matlab, which would then call into the
+ // mex file and make things go crazy. So here we try to generate ID numbers that
+ // are globally unique.
+ uint64 i = 0;
+ id = murmur_hash3_128bit_3(id, seed, ++i).first;
+ // very unlikely but make sure there aren't any hash collisions.
+ while(object_table.count(id) != 0)
+ id = murmur_hash3_128bit_3(id, seed, ++i).first;
+
+ object_table[id] = item;
+ return id;
+ }
+
+ void free(uint64 item)
+ {
+ if (object_table.count(item) == 0)
+ {
+ throw dlib::error("An attempt to deallocate a mex class object with an invalid pointer was detected.");
+ }
+
+ delete object_table[item];
+ object_table.erase(item);
+ }
+
+ MEX_CLASS_NAME* access(uint64 item) // convert numeric ID to pointer to object that can be used.
+ {
+ if (object_table.count(item) == 0)
+ {
+ throw dlib::error("An attempt to access a mex class object with an invalid pointer was detected.");
+ }
+
+ return object_table[item];
+ }
+
+private:
+
+ std::map<uint64, MEX_CLASS_NAME*> object_table;
+ uint64 seed;
+} class_factory;
+
+// ----------------------------------------------------------------------------------------
+
+// Make a FOREACH macro
+#define FE_1(WHAT, X) WHAT(X)
+#define FE_2(WHAT, X, ...) WHAT(X),FE_1(WHAT, __VA_ARGS__)
+#define FE_3(WHAT, X, ...) WHAT(X),FE_2(WHAT, __VA_ARGS__)
+#define FE_4(WHAT, X, ...) WHAT(X),FE_3(WHAT, __VA_ARGS__)
+#define FE_5(WHAT, X, ...) WHAT(X),FE_4(WHAT, __VA_ARGS__)
+#define FE_6(WHAT, X, ...) WHAT(X),FE_5(WHAT, __VA_ARGS__)
+#define FE_7(WHAT, X, ...) WHAT(X),FE_6(WHAT, __VA_ARGS__)
+#define FE_8(WHAT, X, ...) WHAT(X),FE_7(WHAT, __VA_ARGS__)
+#define FE_9(WHAT, X, ...) WHAT(X),FE_8(WHAT, __VA_ARGS__)
+#define FE_10(WHAT, X, ...) WHAT(X),FE_9(WHAT, __VA_ARGS__)
+#define FE_11(WHAT, X, ...) WHAT(X),FE_10(WHAT, __VA_ARGS__)
+#define FE_12(WHAT, X, ...) WHAT(X),FE_11(WHAT, __VA_ARGS__)
+#define FE_13(WHAT, X, ...) WHAT(X),FE_12(WHAT, __VA_ARGS__)
+#define FE_14(WHAT, X, ...) WHAT(X),FE_13(WHAT, __VA_ARGS__)
+#define FE_15(WHAT, X, ...) WHAT(X),FE_14(WHAT, __VA_ARGS__)
+#define FE_16(WHAT, X, ...) WHAT(X),FE_15(WHAT, __VA_ARGS__)
+#define FE_17(WHAT, X, ...) WHAT(X),FE_16(WHAT, __VA_ARGS__)
+#define FE_18(WHAT, X, ...) WHAT(X),FE_17(WHAT, __VA_ARGS__)
+#define FE_19(WHAT, X, ...) WHAT(X),FE_18(WHAT, __VA_ARGS__)
+#define FE_20(WHAT, X, ...) WHAT(X),FE_19(WHAT, __VA_ARGS__)
+//... repeat as needed
+#define GET_MACRO(_1,_2,_3,_4,_5,_6,_7,_8,_9,_10,_11,_12,_13,_14,_15,_16,_17,_18,_19,_20,NAME,...) NAME
+#define FOR_EACH(action,...) GET_MACRO(__VA_ARGS__,FE_20,FE_19,FE_18,FE_17,FE_16,FE_15,FE_14,FE_13,FE_12,FE_11,FE_10,FE_9,FE_8,FE_7,FE_6,FE_5,FE_4,FE_3,FE_2,FE_1)(action,__VA_ARGS__)
+#define MEX_CLASS_ANNOTATE(x) &MEX_CLASS_NAME::x
+
+// Now make a tuple containing all the member function pointers to our MEX_CLASS_NAME
+auto mex_class_methods = std::make_tuple(FOR_EACH(MEX_CLASS_ANNOTATE, MEX_CLASS_METHODS));
+
+
+#endif // MEX_CLASS_NAME
+
+// ----------------------------------------------------------------------------------------
+
+bool is_string(const mxArray* arr, const char* str)
+{
+ if (mxIsChar(arr))
+ {
+ char ch[20];
+ DLIB_CASSERT(mxGetString(arr, ch, sizeof(ch))==0, "Unable to retrieve string");
+ ch[sizeof(ch)-1] = 0;// ensure NULL termination regardless of what MATLAB does.
+ return strcmp(str,ch)==0;
+ }
+ return false;
+}
+
+// ----------------------------------------------------------------------------------------
+
+/* The gateway function called by MATLAB*/
+void mexFunction( int nlhs, mxArray *plhs[],
+ int nrhs, const mxArray *prhs[])
+{
+ // Only remap cout and cerr if we aren't using octave since octave already does this.
+#if !defined(OCTAVE_IMPORT) && !defined(OCTAVE_API)
+ // make it so cout prints to mexPrintf()
+ mex_binding::mex_streambuf sb;
+ // make it so cerr prints to mexWarnMsgTxt()
+ mex_binding::mex_warn_streambuf wsb;
+#endif
+
+ try
+ {
+#ifdef MEX_CLASS_NAME
+ if (nrhs == 0)
+ {
+ #define DEF2STR(x) DEF2STR2((x))
+ #define DEF2STR2(x) #x
+
+ string classname = trim(string(DEF2STR(MEX_CLASS_NAME)), " \t()");
+ std::vector<string> methods = split(trim(string(DEF2STR(MEX_CLASS_METHODS)), " \t()"), " \t,");
+
+ string mex_filename = trim(string(DEF2STR(MEX_FILENAME))," \t()");
+ bool has_load_obj = false;
+ size_t load_obj_idx = 0;
+
+ cout << "classdef " << classname << " < handle\n"
+ << " properties (Access = private)\n"
+ << " cpp_ptr\n"
+ << " end\n"
+ << "\n"
+ << " methods\n"
+ << " function this = "<<classname<<"()\n"
+ << " this.cpp_ptr = "<<mex_filename<<"('construct');\n"
+ << " end\n"
+ << "\n"
+ << " function copied_obj = clone(this)\n"
+ << " %Returns a new independent object that is a copy of this.\n"
+ << " copied_obj = "<<classname<<"();\n"
+ << " copied_obj.cpp_ptr = "<<mex_filename<<"(this.cpp_ptr,'clone');\n"
+ << " end\n"
+ << "\n";
+ for (size_t i = 0; i < methods.size(); ++i)
+ {
+ if (methods[i] == "load_obj")
+ {
+ has_load_obj = true;
+ load_obj_idx = i;
+ }
+ else
+ {
+ cout << " function varargout = "<<methods[i]<<"(this, varargin) \n"
+ << " [varargout{1:nargout}] = "<<mex_filename<<"(this.cpp_ptr, "<<i+1<<", varargin{:}); \n"
+ << " end \n\n";
+ }
+ }
+ cout << " end\n\n";
+
+ cout << " methods(Access=private) \n"
+ << " function delete(this) \n"
+ << " "<<mex_filename<<"(this.cpp_ptr); \n"
+ << " end \n";
+ if (has_load_obj)
+ {
+ cout << " function varargout = load_obj(this, varargin) \n"
+ << " [varargout{1:nargout}] = "<<mex_filename<<"(this.cpp_ptr, "<<load_obj_idx+1<<", varargin{:}); \n"
+ << " end \n";
+ }
+ cout << " end \n\n";
+
+ if (has_load_obj)
+ {
+ cout << " methods(Static) \n"
+ << " function this = loadobj(in) \n"
+ << " this = "<<classname<<"(); \n"
+ << " this.load_obj(in); \n"
+ << " end \n"
+ << " end \n";
+ }
+ cout << "end \n";
+ }
+ else if (nrhs == 1)
+ {
+ // this is a constructor call
+ if (is_string(prhs[0],"construct"))
+ {
+ DLIB_CASSERT(nlhs == 1, "If you want to construct a new object then you must assign the pointer to something.");
+ plhs[0] = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
+ uint64* ptr_int = (uint64*)mxGetData(plhs[0]);
+ *ptr_int = class_factory.create();
+ }
+ else // destructor call
+ {
+ DLIB_CASSERT(mxIsUint64(prhs[0]) && mxGetNumberOfElements(prhs[0])==1, "When calling a class destructor the first argument must be a pointer (a UINT64 in matlab)");
+ const uint64 ptr_int = *((uint64*)mxGetData(prhs[0]));
+ class_factory.free(ptr_int);
+ }
+ }
+ else // a regular function call
+ {
+ DLIB_CASSERT(mxIsUint64(prhs[0]) && mxGetNumberOfElements(prhs[0])==1, "When calling a class member function the first argument must be a pointer (a UINT64 in matlab)");
+ if (is_string(prhs[1], "clone"))
+ {
+ DLIB_CASSERT(nlhs == 1, "If you want to construct a new object then you must assign the pointer to something.");
+ const uint64 ptr_int = *((uint64*)mxGetData(prhs[0]));
+
+ MEX_CLASS_NAME* ptr = class_factory.access(ptr_int);
+
+ plhs[0] = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
+ uint64* ptr_int2 = (uint64*)mxGetData(plhs[0]);
+ // copy construct a new object
+ *ptr_int2 = class_factory.create(*ptr);
+ }
+ else
+ {
+ DLIB_CASSERT(mxIsDouble(prhs[1]) && mxGetNumberOfElements(prhs[1])==1, "When calling a class member function the second argument must be a number indicating which member function");
+ const uint64 ptr_int = *((uint64*)mxGetData(prhs[0]));
+ const int funct_idx = *(mxGetPr(prhs[1]));
+
+ auto num_registered_functions = std::tuple_size<decltype(mex_class_methods)>::value;
+ DLIB_CASSERT(1 <= funct_idx && funct_idx <= num_registered_functions, "Invalid function index provided.");
+
+ MEX_CLASS_NAME* ptr = class_factory.access(ptr_int);
+
+ // we used the first two arguments to decide what function to call. So adjust nrhs
+ // and prhs so the member function never sees them.
+ mex_class_dispatch dispatch(ptr, nlhs, plhs, nrhs-2, prhs+2);
+ // now invoke the member function, subtract 1 to convert to 0 indexing.
+ visit_at(mex_class_methods, funct_idx-1, dispatch);
+ }
+ }
+#else
+ mex_binding::call_mex_function(mex_function, nlhs, plhs, nrhs, prhs);
+#endif
+ }
+ catch (mex_binding::invalid_args_exception& e)
+ {
+ mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
+ mex_binding::escape_percent(e.what()).c_str());
+ }
+ catch (mex_binding::user_hit_ctrl_c& )
+ {
+ // do nothing, just return to matlab
+ }
+ catch (std::exception& e)
+ {
+ mexErrMsgIdAndTxt("mex_function:error",
+ mex_binding::escape_percent(e.what()).c_str());
+ }
+
+ cout << flush;
+ cerr << flush;
+}
+
+// ----------------------------------------------------------------------------------------
+
diff --git a/ml/dlib/dlib/matlab/subprocess_stream.cpp b/ml/dlib/dlib/matlab/subprocess_stream.cpp
new file mode 100644
index 000000000..4d4d53af0
--- /dev/null
+++ b/ml/dlib/dlib/matlab/subprocess_stream.cpp
@@ -0,0 +1,537 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "subprocess_stream.h"
+
+#include <sstream>
+#include <utility>
+#include <iostream>
+#include <cstdio>
+#include <fcntl.h>
+#include <signal.h>
+#include <sys/wait.h>
+#include <sys/select.h>
+#include "call_matlab.h"
+
+using namespace std;
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void make_fd_non_blocking(int fd)
+ {
+ int flags = fcntl(fd, F_GETFL, 0);
+ fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Block until fd is ready to read, while also echoing whatever is in fd_printf to
+ // cout.
+ int read_echoing_select(int fd, int fd_printf)
+ {
+ // run until fd has data ready
+ while(fd_printf >= 0)
+ {
+ fd_set rfds;
+ int retval;
+
+ while(true)
+ {
+ FD_ZERO(&rfds);
+ FD_SET(fd, &rfds);
+ FD_SET(fd_printf, &rfds);
+
+ // select times out every second just so we can check for matlab ctrl+c.
+ struct timeval tv;
+ tv.tv_sec = 1;
+ tv.tv_usec = 0;
+
+ try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
+ retval = select(std::max(fd,fd_printf)+1, &rfds, NULL, NULL, &tv);
+ try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
+ if (retval == 0) // keep going if it was just a timeout.
+ continue;
+ else if (retval == -1 && errno == EINTR)
+ continue;
+
+ break;
+ }
+
+ if (retval == -1)
+ {
+ return 1;
+ }
+ else
+ {
+ if (FD_ISSET(fd,&rfds))
+ {
+ return 0;
+ }
+ else
+ {
+ char buf[1024];
+ int num = read(fd_printf,buf, sizeof(buf)-1);
+ if (num == -1)
+ return 1;
+ if (num > 0)
+ {
+ buf[num] = 0;
+ cout << buf << flush;
+ }
+ }
+ }
+ }
+ return 0;
+ }
+
+ int write_echoing_select(int fd, int fd_printf)
+ {
+ // run until fd has data ready
+ while(fd_printf >= 0)
+ {
+ fd_set rfds, wfds;
+ int retval;
+ while(true)
+ {
+ FD_ZERO(&rfds);
+ FD_ZERO(&wfds);
+ FD_SET(fd, &wfds);
+ FD_SET(fd_printf, &rfds);
+
+ // select times out every second just so we can check for matlab ctrl+c.
+ struct timeval tv;
+ tv.tv_sec = 1;
+ tv.tv_usec = 0;
+
+ try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
+ retval = select(std::max(fd,fd_printf)+1, &rfds, &wfds, NULL, &tv);
+ try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
+ if (retval == 0) // keep going if it was just a timeout.
+ continue;
+ else if (retval == -1 && errno == EINTR)
+ continue;
+
+ break;
+ }
+
+ if (retval == -1)
+ {
+ return 1;
+ }
+ else
+ {
+ if (FD_ISSET(fd,&wfds))
+ {
+ return 0;
+ }
+ else
+ {
+ char buf[1024];
+ int num = read(fd_printf,buf, sizeof(buf)-1);
+ if (num == -1)
+ return 1;
+ if (num > 0)
+ {
+ buf[num] = 0;
+ cout << buf << flush;
+ }
+ }
+ }
+ }
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class filestreambuf : public std::streambuf
+ {
+ /*!
+ INITIAL VALUE
+ - fd == the file descriptor we read from.
+ - in_buffer == an array of in_buffer_size bytes
+ - out_buffer == an array of out_buffer_size bytes
+
+ CONVENTION
+ - in_buffer == the input buffer used by this streambuf
+ - out_buffer == the output buffer used by this streambuf
+ - max_putback == the maximum number of chars to have in the put back buffer.
+ !*/
+
+ public:
+
+ filestreambuf (
+ int fd_,
+ int fd_printf_
+ ) :
+ fd(fd_),
+ fd_printf(fd_printf_),
+ out_buffer(0),
+ in_buffer(0)
+ {
+ init();
+ }
+
+ virtual ~filestreambuf (
+ )
+ {
+ sync();
+ delete [] out_buffer;
+ delete [] in_buffer;
+ }
+
+ int sync (
+ )
+ {
+ if (flush_out_buffer() == EOF)
+ {
+ // an error occurred
+ return -1;
+ }
+ return 0;
+ }
+ protected:
+
+ void init (
+ )
+ {
+ try
+ {
+ out_buffer = new char[out_buffer_size];
+ in_buffer = new char[in_buffer_size];
+ }
+ catch (...)
+ {
+ if (out_buffer) delete [] out_buffer;
+ throw;
+ }
+ setp(out_buffer, out_buffer + (out_buffer_size-1));
+ setg(in_buffer+max_putback,
+ in_buffer+max_putback,
+ in_buffer+max_putback);
+ }
+
+ int flush_out_buffer (
+ )
+ {
+ int num = static_cast<int>(pptr()-pbase());
+ const int num_written = num;
+ char* buf = out_buffer;
+ while(num != 0)
+ {
+ if(write_echoing_select(fd, fd_printf))
+ return EOF;
+ int status = write(fd,buf,num);
+ if (status < 0)
+ {
+ // the write was not successful so return EOF
+ return EOF;
+ }
+ num -= status;
+ buf += status;
+ }
+ pbump(-num_written);
+ return num_written;
+ }
+
+ // output functions
+ int_type overflow (
+ int_type c
+ )
+ {
+ if (c != EOF)
+ {
+ *pptr() = c;
+ pbump(1);
+ }
+ if (flush_out_buffer() == EOF)
+ {
+ // an error occurred
+ return EOF;
+ }
+ return c;
+ }
+
+
+ std::streamsize xsputn (
+ const char* s,
+ std::streamsize num
+ )
+ {
+ // Add a sanity check here
+ DLIB_ASSERT(num >= 0,
+ "\tstd::streamsize filestreambuf::xsputn"
+ << "\n\tThe number of bytes to write can't be negative"
+ << "\n\tnum: " << num
+ << "\n\tthis: " << this
+ );
+
+ std::streamsize space_left = static_cast<std::streamsize>(epptr()-pptr());
+ if (num <= space_left)
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(num));
+ pbump(static_cast<int>(num));
+ return num;
+ }
+ else
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(space_left));
+ s += space_left;
+ pbump(space_left);
+ std::streamsize num_left = num - space_left;
+
+ if (flush_out_buffer() == EOF)
+ {
+ // the write was not successful so return that 0 bytes were written
+ return 0;
+ }
+
+ if (num_left < out_buffer_size)
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(num_left));
+ pbump(num_left);
+ return num;
+ }
+ else
+ {
+ while(num_left != 0)
+ {
+ if(write_echoing_select(fd, fd_printf))
+ return EOF;
+ int status = write(fd,s,num_left);
+ if (status < 0)
+ {
+ // the write was not successful so return that 0 bytes were written
+ return 0;
+ }
+ num_left -= status;
+ s += status;
+ }
+ return num;
+ }
+ }
+ }
+
+ // input functions
+ int_type underflow(
+ )
+ {
+ if (gptr() < egptr())
+ {
+ return static_cast<unsigned char>(*gptr());
+ }
+
+ int num_put_back = static_cast<int>(gptr() - eback());
+ if (num_put_back > max_putback)
+ {
+ num_put_back = max_putback;
+ }
+
+ // copy the putback characters into the putback end of the in_buffer
+ std::memmove(in_buffer+(max_putback-num_put_back), gptr()-num_put_back, num_put_back);
+
+
+ if (read_echoing_select(fd, fd_printf))
+ return EOF;
+ int num = read(fd,in_buffer+max_putback, in_buffer_size-max_putback);
+ if (num <= 0)
+ {
+ // an error occurred or the connection is over which is EOF
+ return EOF;
+ }
+
+ // reset in_buffer pointers
+ setg (in_buffer+(max_putback-num_put_back),
+ in_buffer+max_putback,
+ in_buffer+max_putback+num);
+
+ return static_cast<unsigned char>(*gptr());
+ }
+
+ std::streamsize xsgetn (
+ char_type* s,
+ std::streamsize n
+ )
+ {
+ std::streamsize temp = n;
+ while (n > 0)
+ {
+ int num = static_cast<int>(egptr() - gptr());
+ if (num >= n)
+ {
+ // copy data from our buffer
+ std::memcpy(s, gptr(), static_cast<size_t>(n));
+ gbump(static_cast<int>(n));
+ return temp;
+ }
+
+ // read more data into our buffer
+ if (num == 0)
+ {
+ if (underflow() == EOF)
+ break;
+ continue;
+ }
+
+ // copy all the data from our buffer
+ std::memcpy(s, gptr(), num);
+ n -= num;
+ gbump(num);
+ s += num;
+ }
+ return temp-n;
+ }
+
+ private:
+
+ // member data
+ int fd;
+ int fd_printf;
+ static const std::streamsize max_putback = 4;
+ static const std::streamsize out_buffer_size = 10000;
+ static const std::streamsize in_buffer_size = 10000;
+ char* out_buffer;
+ char* in_buffer;
+
+ };
+
+ namespace impl
+ {
+ int get_data_fd()
+ {
+ char* env_fd = getenv("DLIB_SUBPROCESS_DATA_FD");
+ DLIB_CASSERT(env_fd != 0,"");
+ return atoi(env_fd);
+ }
+
+ std::iostream& get_data_iostream()
+ {
+ static filestreambuf dbuff(get_data_fd(), -1);
+ static iostream out(&dbuff);
+ return out;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ subprocess_stream::
+ subprocess_stream(const char* program_name) : stderr(NULL), iosub(NULL)
+ {
+ if (access(program_name, F_OK))
+ throw dlib::error("Error: '" + std::string(program_name) + "' file does not exist.");
+ if (access(program_name, X_OK))
+ throw dlib::error("Error: '" + std::string(program_name) + "' file is not executable.");
+
+ child_pid = fork();
+ if (child_pid == -1)
+ throw dlib::error("Failed to start child process");
+
+ if (child_pid == 0)
+ {
+ // In child process
+ dup2(stdout_pipe.child_fd(), STDOUT_FILENO);
+ dup2(stderr_pipe.child_fd(), STDERR_FILENO);
+ stdout_pipe.close();
+ stderr_pipe.close();
+
+ char* argv[] = {(char*)program_name, nullptr};
+ char* cudadevs = getenv("CUDA_VISIBLE_DEVICES");
+ if (cudadevs)
+ {
+ std::ostringstream sout;
+ sout << "DLIB_SUBPROCESS_DATA_FD="<<data_pipe.child_fd();
+ std::string extra = sout.str();
+
+ std::string extra2 = std::string("CUDA_VISIBLE_DEVICES=") + cudadevs;
+ char* envp[] = {(char*)extra.c_str(), (char*)extra2.c_str(), nullptr};
+ execve(argv[0], argv, envp);
+ }
+ else
+ {
+ std::ostringstream sout;
+ sout << "DLIB_SUBPROCESS_DATA_FD="<<data_pipe.child_fd();
+ std::string extra = sout.str();
+ char* envp[] = {(char*)extra.c_str(), nullptr};
+ execve(argv[0], argv, envp);
+ }
+
+
+ // If launching the child didn't work then bail immediately so the parent
+ // process has no chance to get tweaked out (*cough* MATLAB *cough*).
+ _Exit(1);
+ }
+ else
+ {
+ // In parent process
+ close(data_pipe.child_fd());
+ close(stdout_pipe.child_fd());
+ close(stderr_pipe.child_fd());
+ make_fd_non_blocking(data_pipe.parent_fd());
+ make_fd_non_blocking(stdout_pipe.parent_fd());
+ make_fd_non_blocking(stderr_pipe.parent_fd());
+ inout_buf = std::unique_ptr<filestreambuf>(new filestreambuf(data_pipe.parent_fd(), stdout_pipe.parent_fd()));
+ err_buf = std::unique_ptr<filestreambuf>(new filestreambuf(stderr_pipe.parent_fd(), stdout_pipe.parent_fd()));
+ iosub.rdbuf(inout_buf.get());
+ stderr.rdbuf(err_buf.get());
+ iosub.tie(&iosub);
+ stderr.tie(&iosub);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ subprocess_stream::
+ ~subprocess_stream()
+ {
+ try
+ {
+ wait();
+ }
+ catch (dlib::error& e)
+ {
+ std::cerr << e.what() << std::endl;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void subprocess_stream::
+ wait()
+ {
+ if (!wait_called)
+ {
+ wait_called = true;
+ send_eof();
+
+ std::ostringstream sout;
+ sout << stderr.rdbuf();
+
+ try{check_for_matlab_ctrl_c();} catch(...)
+ {
+ kill(child_pid, SIGTERM);
+ }
+
+ int status;
+ waitpid(child_pid, &status, 0);
+ if (status)
+ throw dlib::error("Child process terminated with an error.\n" + sout.str());
+
+ if (sout.str().size() != 0)
+ throw dlib::error("Child process terminated with an error.\n" + sout.str());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void subprocess_stream::
+ send_eof() { inout_buf->sync(); ::close(data_pipe.parent_fd()); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/matlab/subprocess_stream.h b/ml/dlib/dlib/matlab/subprocess_stream.h
new file mode 100644
index 000000000..b00904c12
--- /dev/null
+++ b/ml/dlib/dlib/matlab/subprocess_stream.h
@@ -0,0 +1,223 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SUBPROCeSS_STREAM_H_
+#define DLIB_SUBPROCeSS_STREAM_H_
+
+#include <utility>
+#include <unistd.h>
+#include <iostream>
+#include <memory>
+#include <dlib/matrix.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+
+
+namespace dlib
+{
+
+// --------------------------------------------------------------------------------------
+
+ // Call dlib's serialize and deserialize by default. The point of this version of
+ // serialize is to do something fast that normally we wouldn't do, like directly copy
+ // memory. This is safe since this is an interprocess communication happening the same
+ // machine.
+ template <typename T> void interprocess_serialize ( const T& item, std::ostream& out) { serialize(item, out); }
+ template <typename T> void interprocess_deserialize (T& item, std::istream& in) { deserialize(item, in); }
+
+ // But have overloads for direct memory copies for some types since this is faster than
+ // their default serialization.
+ template <typename T, long NR, long NC, typename MM, typename L>
+ void interprocess_serialize(const dlib::matrix<T,NR,NC,MM,L>& item, std::ostream& out)
+ {
+ dlib::serialize(item.nr(), out);
+ dlib::serialize(item.nc(), out);
+ if (item.size() != 0)
+ out.write((const char*)&item(0,0), sizeof(T)*item.size());
+ if (!out)
+ throw dlib::serialization_error("Error writing matrix to interprocess iostream.");
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ void interprocess_deserialize(dlib::matrix<T,NR,NC,MM,L>& item, std::istream& in)
+ {
+ long nr, nc;
+ dlib::deserialize(nr, in);
+ dlib::deserialize(nc, in);
+ item.set_size(nr,nc);
+ if (item.size() != 0)
+ in.read((char*)&item(0,0), sizeof(T)*item.size());
+ if (!in)
+ throw dlib::serialization_error("Error reading matrix from interprocess iostream.");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl{ std::iostream& get_data_iostream(); }
+
+ inline void send_to_parent_process() {impl::get_data_iostream().flush();}
+ template <typename U, typename ...T>
+ void send_to_parent_process(U&& arg1, T&& ...args)
+ /*!
+ ensures
+ - sends all the arguments to send_to_parent_process() to the parent process by
+ serializing them with interprocess_serialize().
+ !*/
+ {
+ interprocess_serialize(arg1, impl::get_data_iostream());
+ send_to_parent_process(std::forward<T>(args)...);
+ if (!impl::get_data_iostream())
+ throw dlib::error("Error sending object to parent process.");
+ }
+
+ inline void receive_from_parent_process() {}
+ template <typename U, typename ...T>
+ void receive_from_parent_process(U&& arg1, T&& ...args)
+ /*!
+ ensures
+ - receives all the arguments to receive_from_parent_process() from the parent
+ process by deserializing them from interprocess_serialize().
+ !*/
+ {
+ interprocess_deserialize(arg1, impl::get_data_iostream());
+ receive_from_parent_process(std::forward<T>(args)...);
+ if (!impl::get_data_iostream())
+ throw dlib::error("Error receiving object from parent process.");
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ class filestreambuf;
+
+ class subprocess_stream : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool for spawning a subprocess and communicating with it. Here
+ is an example:
+
+ subprocess_stream s("/usr/bin/some_program");
+ s.send(obj1, obj2, obj3);
+ s.receive(obj4, obj5);
+ s.wait(); // wait for sub process to terminate
+
+ Then in the sub process you would have:
+
+ receive_from_parent_process(obj1, obj2, obj3);
+ // do stuff
+ cout << "echo this text to parent cout" << endl;
+ send_to_parent_process(obj4, obj5);
+
+
+ Additionally, if the sub process writes to its standard out then that will
+ be echoed to std::cout in the parent process. Writing to std::cerr or
+ returning a non-zero value from main will also be noted by the parent
+ process and an appropriate exception will be thrown.
+ !*/
+
+ public:
+
+ explicit subprocess_stream(
+ const char* program_name
+ );
+ /*!
+ ensures
+ - spawns a sub process by executing the file with the given program_name.
+ !*/
+
+ ~subprocess_stream(
+ );
+ /*!
+ ensures
+ - calls wait(). Note that the destructor never throws even though wait() can.
+ If an exception is thrown by wait() it is just logged to std::cerr.
+ !*/
+
+ void wait(
+ );
+ /*!
+ ensures
+ - closes the input stream to the child process and then waits for the child
+ to terminate.
+ - If the child returns an error (by returning != 0 from its main) or
+ outputs to its standard error then wait() throws a dlib::error() with the
+ standard error output in it.
+ !*/
+
+ int get_child_pid() const { return child_pid; }
+ /*!
+ ensures
+ - returns the PID of the child process
+ !*/
+
+ template <typename U, typename ...T>
+ void send(U&& arg1, T&& ...args)
+ /*!
+ ensures
+ - sends all the arguments to send() to the subprocess by serializing them
+ with interprocess_serialize().
+ !*/
+ {
+ interprocess_serialize(arg1, iosub);
+ send(std::forward<T>(args)...);
+ if (!iosub)
+ {
+ std::ostringstream sout;
+ sout << stderr.rdbuf();
+ throw dlib::error("Error sending object to child process.\n" + sout.str());
+ }
+ }
+ void send() {iosub.flush();}
+
+ template <typename U, typename ...T>
+ void receive(U&& arg1, T&& ...args)
+ /*!
+ ensures
+ - receives all the arguments to receive() to the subprocess by deserializing
+ them with interprocess_deserialize().
+ !*/
+ {
+ interprocess_deserialize(arg1, iosub);
+ receive(std::forward<T>(args)...);
+ if (!iosub)
+ {
+ std::ostringstream sout;
+ sout << stderr.rdbuf();
+ throw dlib::error("Error receiving object from child process.\n" + sout.str() );
+ }
+ }
+ void receive() {}
+
+
+ private:
+
+ void send_eof();
+
+ class cpipe : noncopyable
+ {
+ private:
+ int fd[2];
+ public:
+ cpipe() { if (socketpair(AF_LOCAL, SOCK_STREAM, 0, fd)) throw dlib::error("Failed to create pipe"); }
+ ~cpipe() { close(); }
+ int parent_fd() const { return fd[0]; }
+ int child_fd() const { return fd[1]; }
+ void close() { ::close(fd[0]); ::close(fd[1]); }
+ };
+
+ cpipe data_pipe;
+ cpipe stdout_pipe;
+ cpipe stderr_pipe;
+ bool wait_called = false;
+ std::unique_ptr<filestreambuf> inout_buf;
+ std::unique_ptr<filestreambuf> err_buf;
+ int child_pid = -1;
+ std::istream stderr;
+ std::iostream iosub;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_SUBPROCeSS_STREAM_H_
+
diff --git a/ml/dlib/dlib/matrix.h b/ml/dlib/dlib/matrix.h
new file mode 100644
index 000000000..d2ae69afb
--- /dev/null
+++ b/ml/dlib/dlib/matrix.h
@@ -0,0 +1,24 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_HEADER
+#define DLIB_MATRIx_HEADER
+
+#include "matrix/matrix.h"
+#include "matrix/matrix_utilities.h"
+#include "matrix/matrix_subexp.h"
+#include "matrix/matrix_math_functions.h"
+#include "matrix/matrix_assign.h"
+#include "matrix/matrix_la.h"
+#include "matrix/symmetric_matrix_cache.h"
+#include "matrix/matrix_conv.h"
+#include "matrix/matrix_read_from_istream.h"
+#include "matrix/matrix_fft.h"
+#include "matrix/matrix_generic_image.h"
+
+#ifdef DLIB_USE_BLAS
+#include "matrix/matrix_blas_bindings.h"
+#endif
+
+#endif // DLIB_MATRIx_HEADER
+
+
diff --git a/ml/dlib/dlib/matrix/cblas_constants.h b/ml/dlib/dlib/matrix/cblas_constants.h
new file mode 100644
index 000000000..6ff89f141
--- /dev/null
+++ b/ml/dlib/dlib/matrix/cblas_constants.h
@@ -0,0 +1,22 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CBLAS_CONSTAnTS_Hh_
+#define DLIB_CBLAS_CONSTAnTS_Hh_
+
+#ifndef CBLAS_H
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
+ enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
+ enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
+ enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
+ enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
+
+ }
+}
+#endif // if not CBLAS_H
+
+#endif // DLIB_CBLAS_CONSTAnTS_Hh_
+
diff --git a/ml/dlib/dlib/matrix/lapack/fortran_id.h b/ml/dlib/dlib/matrix/lapack/fortran_id.h
new file mode 100644
index 000000000..8027ea34f
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/fortran_id.h
@@ -0,0 +1,62 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H
+#define DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// FORTRAN BINDING STUFF FROM BOOST
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// Permission to copy, use, modify, sell and
+// distribute this software is granted provided this copyright notice appears
+// in all copies. This software is provided "as is" without express or implied
+// warranty, and with no claim as to its suitability for any purpose.
+// Copyright (C) 2002, 2003 Si-Lab b.v.b.a., Toon Knapen and Kresimir Fresl
+
+
+// First we need to know what the conventions for linking
+// C with Fortran is on this platform/toolset
+#if defined(LAPACK_FORCE_UNDERSCORE)
+#define DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE
+#elif defined(LAPACK_FORCE_NOUNDERSCORE)
+#define DLIB_BIND_FORTRAN_LOWERCASE
+#elif defined(__GNUC__) || defined(__ICC) || defined(__sgi) || defined(__COMO__) || defined(__KCC)
+#define DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE
+#elif defined(__IBMCPP__) || defined(_MSC_VER) || defined(__BORLANDC__)
+#define DLIB_BIND_FORTRAN_LOWERCASE
+#else
+#error do not know how to link with fortran for the given platform
+#endif
+
+// Next we define macros to convert our symbols to
+// the current convention
+#if defined(DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE)
+#define DLIB_FORTRAN_ID( id ) id##_
+#elif defined(DLIB_BIND_FORTRAN_LOWERCASE)
+#define DLIB_FORTRAN_ID( id ) id
+#else
+#error do not know how to bind to fortran calling convention
+#endif
+
+
+
+namespace dlib
+{
+ namespace lapack
+ {
+ // stuff from f2c used to define what exactly is an integer in fortran
+#if (defined(__alpha__) || defined(__sparc64__) || defined(__x86_64__) || defined(__ia64__)) && !defined(MATLAB_MEX_FILE)
+ typedef int integer;
+ typedef unsigned int uinteger;
+#else
+ typedef long int integer;
+ typedef unsigned long int uinteger;
+#endif
+
+ }
+}
+
+#endif // DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H
+
diff --git a/ml/dlib/dlib/matrix/lapack/gees.h b/ml/dlib/dlib/matrix/lapack/gees.h
new file mode 100644
index 000000000..a8ee63ff1
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/gees.h
@@ -0,0 +1,264 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_ES_Hh_
+#define DLIB_LAPACk_ES_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+#if defined(__alpha__) || defined(__sparc64__) || defined(__x86_64__) || defined(__ia64__)
+ typedef int logical;
+#else
+ typedef long int logical;
+#endif
+ typedef logical (*L_fp)(...);
+
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgees) (char *jobvs, char *sort, L_fp select, integer *n,
+ double *a, integer *lda, integer *sdim, double *wr,
+ double *wi, double *vs, integer *ldvs, double *work,
+ integer *lwork, logical *bwork, integer *info);
+
+ void DLIB_FORTRAN_ID(sgees) (char *jobvs, char *sort, L_fp select, integer *n,
+ float *a, integer *lda, integer *sdim, float *wr,
+ float *wi, float *vs, integer *ldvs, float *work,
+ integer *lwork, logical *bwork, integer *info);
+
+ }
+
+ inline int gees (char jobvs, integer n,
+ double *a, integer lda, double *wr,
+ double *wi, double *vs, integer ldvs, double *work,
+ integer lwork)
+ {
+ // No sorting allowed
+ integer info = 0;
+ char sort = 'N';
+ L_fp fnil = 0;
+ logical bwork = 0;
+ integer sdim = 0;
+ DLIB_FORTRAN_ID(dgees)(&jobvs, &sort, fnil, &n,
+ a, &lda, &sdim, wr,
+ wi, vs, &ldvs, work,
+ &lwork, &bwork, &info);
+ return info;
+ }
+
+
+ inline int gees (char jobvs, integer n,
+ float *a, integer lda, float *wr,
+ float *wi, float *vs, integer ldvs, float *work,
+ integer lwork)
+ {
+ // No sorting allowed
+ integer info = 0;
+ char sort = 'N';
+ L_fp fnil = 0;
+ logical bwork = 0;
+ integer sdim = 0;
+ DLIB_FORTRAN_ID(sgees)(&jobvs, &sort, fnil, &n,
+ a, &lda, &sdim, wr,
+ wi, vs, &ldvs, work,
+ &lwork, &bwork, &info);
+ return info;
+ }
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK driver routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+/* .. Function Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGEES computes for an N-by-N real nonsymmetric matrix A, the */
+/* eigenvalues, the real Schur form T, and, optionally, the matrix of */
+/* Schur vectors Z. This gives the Schur factorization A = Z*T*(Z**T). */
+
+/* Optionally, it also orders the eigenvalues on the diagonal of the */
+/* real Schur form so that selected eigenvalues are at the top left. */
+/* The leading columns of Z then form an orthonormal basis for the */
+/* invariant subspace corresponding to the selected eigenvalues. */
+
+/* A matrix is in real Schur form if it is upper quasi-triangular with */
+/* 1-by-1 and 2-by-2 blocks. 2-by-2 blocks will be standardized in the */
+/* form */
+/* [ a b ] */
+/* [ c a ] */
+
+/* where b*c < 0. The eigenvalues of such a block are a +- sqrt(bc). */
+
+/* Arguments */
+/* ========= */
+
+/* JOBVS (input) CHARACTER*1 */
+/* = 'N': Schur vectors are not computed; */
+/* = 'V': Schur vectors are computed. */
+
+/* SORT (input) CHARACTER*1 */
+/* Specifies whether or not to order the eigenvalues on the */
+/* diagonal of the Schur form. */
+/* = 'N': Eigenvalues are not ordered; */
+/* = 'S': Eigenvalues are ordered (see SELECT). */
+
+/* SELECT (external procedure) LOGICAL FUNCTION of two DOUBLE PRECISION arguments */
+/* SELECT must be declared EXTERNAL in the calling subroutine. */
+/* If SORT = 'S', SELECT is used to select eigenvalues to sort */
+/* to the top left of the Schur form. */
+/* If SORT = 'N', SELECT is not referenced. */
+/* An eigenvalue WR(j)+sqrt(-1)*WI(j) is selected if */
+/* SELECT(WR(j),WI(j)) is true; i.e., if either one of a complex */
+/* conjugate pair of eigenvalues is selected, then both complex */
+/* eigenvalues are selected. */
+/* Note that a selected complex eigenvalue may no longer */
+/* satisfy SELECT(WR(j),WI(j)) = .TRUE. after ordering, since */
+/* ordering may change the value of complex eigenvalues */
+/* (especially if the eigenvalue is ill-conditioned); in this */
+/* case INFO is set to N+2 (see INFO below). */
+
+/* N (input) INTEGER */
+/* The order of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the N-by-N matrix A. */
+/* On exit, A has been overwritten by its real Schur form T. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,N). */
+
+/* SDIM (output) INTEGER */
+/* If SORT = 'N', SDIM = 0. */
+/* If SORT = 'S', SDIM = number of eigenvalues (after sorting) */
+/* for which SELECT is true. (Complex conjugate */
+/* pairs for which SELECT is true for either */
+/* eigenvalue count as 2.) */
+
+/* WR (output) DOUBLE PRECISION array, dimension (N) */
+/* WI (output) DOUBLE PRECISION array, dimension (N) */
+/* WR and WI contain the real and imaginary parts, */
+/* respectively, of the computed eigenvalues in the same order */
+/* that they appear on the diagonal of the output Schur form T. */
+/* Complex conjugate pairs of eigenvalues will appear */
+/* consecutively with the eigenvalue having the positive */
+/* imaginary part first. */
+
+/* VS (output) DOUBLE PRECISION array, dimension (LDVS,N) */
+/* If JOBVS = 'V', VS contains the orthogonal matrix Z of Schur */
+/* vectors. */
+/* If JOBVS = 'N', VS is not referenced. */
+
+/* LDVS (input) INTEGER */
+/* The leading dimension of the array VS. LDVS >= 1; if */
+/* JOBVS = 'V', LDVS >= N. */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) contains the optimal LWORK. */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. LWORK >= max(1,3*N). */
+/* For good performance, LWORK must generally be larger. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* BWORK (workspace) LOGICAL array, dimension (N) */
+/* Not referenced if SORT = 'N'. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value. */
+/* > 0: if INFO = i, and i is */
+/* <= N: the QR algorithm failed to compute all the */
+/* eigenvalues; elements 1:ILO-1 and i+1:N of WR and WI */
+/* contain those eigenvalues which have converged; if */
+/* JOBVS = 'V', VS contains the matrix which reduces A */
+/* to its partially converged Schur form. */
+/* = N+1: the eigenvalues could not be reordered because some */
+/* eigenvalues were too close to separate (the problem */
+/* is very ill-conditioned); */
+/* = N+2: after reordering, roundoff changed values of some */
+/* complex eigenvalues so that leading eigenvalues in */
+/* the Schur form no longer satisfy SELECT=.TRUE. This */
+/* could also be caused by underflow due to scaling. */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM,
+ typename layout
+ >
+ int gees (
+ const char jobz,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,layout>& wr,
+ matrix<T,NR3,NC3,MM,layout>& wi,
+ matrix<T,NR4,NC4,MM,column_major_layout>& vs
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ const long n = a.nr();
+
+ wr.set_size(n,1);
+ wi.set_size(n,1);
+
+ if (jobz == 'V')
+ vs.set_size(n,n);
+ else
+ vs.set_size(NR4?NR4:1, NC4?NC4:1);
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::gees(jobz, n,
+ &a(0,0), a.nr(), &wr(0,0),
+ &wi(0,0), &vs(0,0), vs.nr(), &work_size,
+ -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual decomposition
+ info = binding::gees(jobz, n,
+ &a(0,0), a.nr(), &wr(0,0),
+ &wi(0,0), &vs(0,0), vs.nr(), &work(0,0),
+ work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_ES_Hh_
+
diff --git a/ml/dlib/dlib/matrix/lapack/geev.h b/ml/dlib/dlib/matrix/lapack/geev.h
new file mode 100644
index 000000000..d8fdc4af5
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/geev.h
@@ -0,0 +1,234 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_GEEV_Hh_
+#define DLIB_LAPACk_GEEV_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgeev) (char *jobvl, char *jobvr, integer *n, double * a,
+ integer *lda, double *wr, double *wi, double *vl,
+ integer *ldvl, double *vr, integer *ldvr, double *work,
+ integer *lwork, integer *info);
+
+ void DLIB_FORTRAN_ID(sgeev) (char *jobvl, char *jobvr, integer *n, float * a,
+ integer *lda, float *wr, float *wi, float *vl,
+ integer *ldvl, float *vr, integer *ldvr, float *work,
+ integer *lwork, integer *info);
+
+ }
+
+ inline int geev (char jobvl, char jobvr, integer n, double *a,
+ integer lda, double *wr, double *wi, double *vl,
+ integer ldvl, double *vr, integer ldvr, double *work,
+ integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dgeev)(&jobvl, &jobvr, &n, a,
+ &lda, wr, wi, vl,
+ &ldvl, vr, &ldvr, work,
+ &lwork, &info);
+ return info;
+ }
+
+ inline int geev (char jobvl, char jobvr, integer n, float *a,
+ integer lda, float *wr, float *wi, float *vl,
+ integer ldvl, float *vr, integer ldvr, float *work,
+ integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sgeev)(&jobvl, &jobvr, &n, a,
+ &lda, wr, wi, vl,
+ &ldvl, vr, &ldvr, work,
+ &lwork, &info);
+ return info;
+ }
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK driver routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGEEV computes for an N-by-N real nonsymmetric matrix A, the */
+/* eigenvalues and, optionally, the left and/or right eigenvectors. */
+
+/* The right eigenvector v(j) of A satisfies */
+/* A * v(j) = lambda(j) * v(j) */
+/* where lambda(j) is its eigenvalue. */
+/* The left eigenvector u(j) of A satisfies */
+/* u(j)**H * A = lambda(j) * u(j)**H */
+/* where u(j)**H denotes the conjugate transpose of u(j). */
+
+/* The computed eigenvectors are normalized to have Euclidean norm */
+/* equal to 1 and largest component real. */
+
+/* Arguments */
+/* ========= */
+
+/* JOBVL (input) CHARACTER*1 */
+/* = 'N': left eigenvectors of A are not computed; */
+/* = 'V': left eigenvectors of A are computed. */
+
+/* JOBVR (input) CHARACTER*1 */
+/* = 'N': right eigenvectors of A are not computed; */
+/* = 'V': right eigenvectors of A are computed. */
+
+/* N (input) INTEGER */
+/* The order of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the N-by-N matrix A. */
+/* On exit, A has been overwritten. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,N). */
+
+/* WR (output) DOUBLE PRECISION array, dimension (N) */
+/* WI (output) DOUBLE PRECISION array, dimension (N) */
+/* WR and WI contain the real and imaginary parts, */
+/* respectively, of the computed eigenvalues. Complex */
+/* conjugate pairs of eigenvalues appear consecutively */
+/* with the eigenvalue having the positive imaginary part */
+/* first. */
+
+/* VL (output) DOUBLE PRECISION array, dimension (LDVL,N) */
+/* If JOBVL = 'V', the left eigenvectors u(j) are stored one */
+/* after another in the columns of VL, in the same order */
+/* as their eigenvalues. */
+/* If JOBVL = 'N', VL is not referenced. */
+/* If the j-th eigenvalue is real, then u(j) = VL(:,j), */
+/* the j-th column of VL. */
+/* If the j-th and (j+1)-st eigenvalues form a complex */
+/* conjugate pair, then u(j) = VL(:,j) + i*VL(:,j+1) and */
+/* u(j+1) = VL(:,j) - i*VL(:,j+1). */
+
+/* LDVL (input) INTEGER */
+/* The leading dimension of the array VL. LDVL >= 1; if */
+/* JOBVL = 'V', LDVL >= N. */
+
+/* VR (output) DOUBLE PRECISION array, dimension (LDVR,N) */
+/* If JOBVR = 'V', the right eigenvectors v(j) are stored one */
+/* after another in the columns of VR, in the same order */
+/* as their eigenvalues. */
+/* If JOBVR = 'N', VR is not referenced. */
+/* If the j-th eigenvalue is real, then v(j) = VR(:,j), */
+/* the j-th column of VR. */
+/* If the j-th and (j+1)-st eigenvalues form a complex */
+/* conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and */
+/* v(j+1) = VR(:,j) - i*VR(:,j+1). */
+
+/* LDVR (input) INTEGER */
+/* The leading dimension of the array VR. LDVR >= 1; if */
+/* JOBVR = 'V', LDVR >= N. */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. LWORK >= max(1,3*N), and */
+/* if JOBVL = 'V' or JOBVR = 'V', LWORK >= 4*N. For good */
+/* performance, LWORK must generally be larger. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value. */
+/* > 0: if INFO = i, the QR algorithm failed to compute all the */
+/* eigenvalues, and no eigenvectors have been computed; */
+/* elements i+1:N of WR and WI contain eigenvalues which */
+/* have converged. */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4, long NR5,
+ long NC1, long NC2, long NC3, long NC4, long NC5,
+ typename MM,
+ typename layout
+ >
+ int geev (
+ const char jobvl,
+ const char jobvr,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,layout>& wr,
+ matrix<T,NR3,NC3,MM,layout>& wi,
+ matrix<T,NR4,NC4,MM,column_major_layout>& vl,
+ matrix<T,NR5,NC5,MM,column_major_layout>& vr
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ const long n = a.nr();
+
+ wr.set_size(n,1);
+ wi.set_size(n,1);
+
+ if (jobvl == 'V')
+ vl.set_size(n,n);
+ else
+ vl.set_size(NR4?NR4:1, NC4?NC4:1);
+
+ if (jobvr == 'V')
+ vr.set_size(n,n);
+ else
+ vr.set_size(NR5?NR5:1, NC5?NC5:1);
+
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::geev(jobvl, jobvr, n, &a(0,0),
+ a.nr(), &wr(0,0), &wi(0,0), &vl(0,0),
+ vl.nr(), &vr(0,0), vr.nr(), &work_size,
+ -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual decomposition
+ info = binding::geev(jobvl, jobvr, n, &a(0,0),
+ a.nr(), &wr(0,0), &wi(0,0), &vl(0,0),
+ vl.nr(), &vr(0,0), vr.nr(), &work(0,0),
+ work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_GEEV_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/geqrf.h b/ml/dlib/dlib/matrix/lapack/geqrf.h
new file mode 100644
index 000000000..c1f8fc050
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/geqrf.h
@@ -0,0 +1,168 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_GEQRF_Hh_
+#define DLIB_LAPACk_GEQRF_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgeqrf) (integer *m, integer *n, double *a, integer *
+ lda, double *tau, double *work, integer *lwork,
+ integer *info);
+
+ void DLIB_FORTRAN_ID(sgeqrf) (integer *m, integer *n, float *a, integer *
+ lda, float *tau, float *work, integer *lwork,
+ integer *info);
+ }
+
+ inline int geqrf (integer m, integer n, double *a, integer lda,
+ double *tau, double *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dgeqrf)(&m, &n, a, &lda,
+ tau, work, &lwork, &info);
+ return info;
+ }
+
+ inline int geqrf (integer m, integer n, float *a, integer lda,
+ float *tau, float *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sgeqrf)(&m, &n, a, &lda,
+ tau, work, &lwork, &info);
+ return info;
+ }
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGEQRF computes a QR factorization of a real M-by-N matrix A: */
+/* A = Q * R. */
+
+/* Arguments */
+/* ========= */
+
+/* M (input) INTEGER */
+/* The number of rows of the matrix A. M >= 0. */
+
+/* N (input) INTEGER */
+/* The number of columns of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the M-by-N matrix A. */
+/* On exit, the elements on and above the diagonal of the array */
+/* contain the min(M,N)-by-N upper trapezoidal matrix R (R is */
+/* upper triangular if m >= n); the elements below the diagonal, */
+/* with the array TAU, represent the orthogonal matrix Q as a */
+/* product of min(m,n) elementary reflectors (see Further */
+/* Details). */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,M). */
+
+/* TAU (output) DOUBLE PRECISION array, dimension (min(M,N)) */
+/* The scalar factors of the elementary reflectors (see Further */
+/* Details). */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. LWORK >= max(1,N). */
+/* For optimum performance LWORK >= N*NB, where NB is */
+/* the optimal blocksize. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value */
+
+/* Further Details */
+/* =============== */
+
+/* The matrix Q is represented as a product of elementary reflectors */
+
+/* Q = H(1) H(2) . . . H(k), where k = min(m,n). */
+
+/* Each H(i) has the form */
+
+/* H(i) = I - tau * v * v' */
+
+/* where tau is a real scalar, and v is a real vector with */
+/* v(1:i-1) = 0 and v(i) = 1; v(i+1:m) is stored on exit in A(i+1:m,i), */
+/* and tau in TAU(i). */
+
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ int geqrf (
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,column_major_layout>& tau
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ tau.set_size(std::min(a.nr(), a.nc()), 1);
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::geqrf(a.nr(), a.nc(), &a(0,0), a.nr(),
+ &tau(0,0), &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual decomposition
+ info = binding::geqrf(a.nr(), a.nc(), &a(0,0), a.nr(),
+ &tau(0,0), &work(0,0), work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_GEQRF_Hh_
+
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/gesdd.h b/ml/dlib/dlib/matrix/lapack/gesdd.h
new file mode 100644
index 000000000..e6b4d26e1
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/gesdd.h
@@ -0,0 +1,364 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_SDD_Hh_
+#define DLIB_LAPACk_SDD_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgesdd) (char const* jobz,
+ const integer* m, const integer* n, double* a, const integer* lda,
+ double* s, double* u, const integer* ldu,
+ double* vt, const integer* ldvt,
+ double* work, const integer* lwork, integer* iwork, integer* info);
+
+ void DLIB_FORTRAN_ID(sgesdd) (char const* jobz,
+ const integer* m, const integer* n, float* a, const integer* lda,
+ float* s, float* u, const integer* ldu,
+ float* vt, const integer* ldvt,
+ float* work, const integer* lwork, integer* iwork, integer* info);
+
+ }
+
+ inline integer gesdd (const char jobz,
+ const integer m, const integer n, double* a, const integer lda,
+ double* s, double* u, const integer ldu,
+ double* vt, const integer ldvt,
+ double* work, const integer lwork, integer* iwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dgesdd)(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info);
+ return info;
+ }
+
+ inline integer gesdd (const char jobz,
+ const integer m, const integer n, float* a, const integer lda,
+ float* s, float* u, const integer ldu,
+ float* vt, const integer ldvt,
+ float* work, const integer lwork, integer* iwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sgesdd)(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info);
+ return info;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK driver routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGESDD computes the singular value decomposition (SVD) of a real */
+/* M-by-N matrix A, optionally computing the left and right singular */
+/* vectors. If singular vectors are desired, it uses a */
+/* divide-and-conquer algorithm. */
+
+/* The SVD is written */
+
+/* A = U * SIGMA * transpose(V) */
+
+/* where SIGMA is an M-by-N matrix which is zero except for its */
+/* min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and */
+/* V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA */
+/* are the singular values of A; they are real and non-negative, and */
+/* are returned in descending order. The first min(m,n) columns of */
+/* U and V are the left and right singular vectors of A. */
+
+/* Note that the routine returns VT = V**T, not V. */
+
+/* The divide and conquer algorithm makes very mild assumptions about */
+/* floating point arithmetic. It will work on machines with a guard */
+/* digit in add/subtract, or on those binary machines without guard */
+/* digits which subtract like the Cray X-MP, Cray Y-MP, Cray C-90, or */
+/* Cray-2. It could conceivably fail on hexadecimal or decimal machines */
+/* without guard digits, but we know of none. */
+
+/* Arguments */
+/* ========= */
+
+/* JOBZ (input) CHARACTER*1 */
+/* Specifies options for computing all or part of the matrix U: */
+/* = 'A': all M columns of U and all N rows of V**T are */
+/* returned in the arrays U and VT; */
+/* = 'S': the first min(M,N) columns of U and the first */
+/* min(M,N) rows of V**T are returned in the arrays U */
+/* and VT; */
+/* = 'O': If M >= N, the first N columns of U are overwritten */
+/* on the array A and all rows of V**T are returned in */
+/* the array VT; */
+/* otherwise, all columns of U are returned in the */
+/* array U and the first M rows of V**T are overwritten */
+/* in the array A; */
+/* = 'N': no columns of U or rows of V**T are computed. */
+
+/* M (input) INTEGER */
+/* The number of rows of the input matrix A. M >= 0. */
+
+/* N (input) INTEGER */
+/* The number of columns of the input matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the M-by-N matrix A. */
+/* On exit, */
+/* if JOBZ = 'O', A is overwritten with the first N columns */
+/* of U (the left singular vectors, stored */
+/* columnwise) if M >= N; */
+/* A is overwritten with the first M rows */
+/* of V**T (the right singular vectors, stored */
+/* rowwise) otherwise. */
+/* if JOBZ .ne. 'O', the contents of A are destroyed. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,M). */
+
+/* S (output) DOUBLE PRECISION array, dimension (min(M,N)) */
+/* The singular values of A, sorted so that S(i) >= S(i+1). */
+
+/* U (output) DOUBLE PRECISION array, dimension (LDU,UCOL) */
+/* UCOL = M if JOBZ = 'A' or JOBZ = 'O' and M < N; */
+/* UCOL = min(M,N) if JOBZ = 'S'. */
+/* If JOBZ = 'A' or JOBZ = 'O' and M < N, U contains the M-by-M */
+/* orthogonal matrix U; */
+/* if JOBZ = 'S', U contains the first min(M,N) columns of U */
+/* (the left singular vectors, stored columnwise); */
+/* if JOBZ = 'O' and M >= N, or JOBZ = 'N', U is not referenced. */
+
+/* LDU (input) INTEGER */
+/* The leading dimension of the array U. LDU >= 1; if */
+/* JOBZ = 'S' or 'A' or JOBZ = 'O' and M < N, LDU >= M. */
+
+/* VT (output) DOUBLE PRECISION array, dimension (LDVT,N) */
+/* If JOBZ = 'A' or JOBZ = 'O' and M >= N, VT contains the */
+/* N-by-N orthogonal matrix V**T; */
+/* if JOBZ = 'S', VT contains the first min(M,N) rows of */
+/* V**T (the right singular vectors, stored rowwise); */
+/* if JOBZ = 'O' and M < N, or JOBZ = 'N', VT is not referenced. */
+
+/* LDVT (input) INTEGER */
+/* The leading dimension of the array VT. LDVT >= 1; if */
+/* JOBZ = 'A' or JOBZ = 'O' and M >= N, LDVT >= N; */
+/* if JOBZ = 'S', LDVT >= min(M,N). */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK; */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. LWORK >= 1. */
+/* If JOBZ = 'N', */
+/* LWORK >= 3*min(M,N) + max(max(M,N),7*min(M,N)). */
+/* If JOBZ = 'O', */
+/* LWORK >= 3*min(M,N)*min(M,N) + */
+/* max(max(M,N),5*min(M,N)*min(M,N)+4*min(M,N)). */
+/* If JOBZ = 'S' or 'A' */
+/* LWORK >= 3*min(M,N)*min(M,N) + */
+/* max(max(M,N),4*min(M,N)*min(M,N)+4*min(M,N)). */
+/* For good performance, LWORK should generally be larger. */
+/* If LWORK = -1 but other input arguments are legal, WORK(1) */
+/* returns the optimal LWORK. */
+
+/* IWORK (workspace) INTEGER array, dimension (8*min(M,N)) */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit. */
+/* < 0: if INFO = -i, the i-th argument had an illegal value. */
+/* > 0: DBDSDC did not converge, updating process failed. */
+
+/* Further Details */
+/* =============== */
+
+/* Based on contributions by */
+/* Ming Gu and Huan Ren, Computer Science Division, University of */
+/* California at Berkeley, USA */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int gesdd (
+ const char jobz,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,column_major_layout>& s,
+ matrix<T,NR3,NC3,MM,column_major_layout>& u,
+ matrix<T,NR4,NC4,MM,column_major_layout>& vt
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+ matrix<integer,0,1,MM,column_major_layout> iwork;
+
+ const long m = a.nr();
+ const long n = a.nc();
+ s.set_size(std::min(m,n), 1);
+
+ // make sure the iwork memory block is big enough
+ if (iwork.size() < 8*std::min(m,n))
+ iwork.set_size(8*std::min(m,n), 1);
+
+ if (jobz == 'A')
+ {
+ u.set_size(m,m);
+ vt.set_size(n,n);
+ }
+ else if (jobz == 'S')
+ {
+ u.set_size(m, std::min(m,n));
+ vt.set_size(std::min(m,n), n);
+ }
+ else if (jobz == 'O')
+ {
+ DLIB_CASSERT(false, "jobz == 'O' not supported");
+ }
+ else
+ {
+ u.set_size(NR3?NR3:1, NC3?NC3:1);
+ vt.set_size(NR4?NR4:1, NC4?NC4:1);
+ }
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(),
+ &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
+ &work_size, -1, &iwork(0,0));
+
+ if (info != 0)
+ return info;
+
+ // There is a bug in an older version of LAPACK in Debian etch
+ // that causes the gesdd to return the wrong value for work_size
+ // when jobz == 'N'. So verify the value of work_size.
+ if (jobz == 'N')
+ {
+ using std::min;
+ using std::max;
+ const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n));
+ if (work_size < min_work_size)
+ work_size = min_work_size;
+ }
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual SVD
+ info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(),
+ &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
+ &work(0,0), work.size(), &iwork(0,0));
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int gesdd (
+ const char jobz,
+ matrix<T,NR1,NC1,MM,row_major_layout>& a,
+ matrix<T,NR2,NC2,MM,row_major_layout>& s,
+ matrix<T,NR3,NC3,MM,row_major_layout>& u_,
+ matrix<T,NR4,NC4,MM,row_major_layout>& vt_
+ )
+ {
+ matrix<T,0,1,MM,row_major_layout> work;
+ matrix<integer,0,1,MM,row_major_layout> iwork;
+
+ // Row major order matrices are transposed from LAPACK's point of view.
+ matrix<T,NR4,NC4,MM,row_major_layout>& u = vt_;
+ matrix<T,NR3,NC3,MM,row_major_layout>& vt = u_;
+
+
+ const long m = a.nc();
+ const long n = a.nr();
+ s.set_size(std::min(m,n), 1);
+
+ // make sure the iwork memory block is big enough
+ if (iwork.size() < 8*std::min(m,n))
+ iwork.set_size(8*std::min(m,n), 1);
+
+ if (jobz == 'A')
+ {
+ u.set_size(m,m);
+ vt.set_size(n,n);
+ }
+ else if (jobz == 'S')
+ {
+ u.set_size(std::min(m,n), m);
+ vt.set_size(n, std::min(m,n));
+ }
+ else if (jobz == 'O')
+ {
+ DLIB_CASSERT(false, "jobz == 'O' not supported");
+ }
+ else
+ {
+ u.set_size(NR4?NR4:1, NC4?NC4:1);
+ vt.set_size(NR3?NR3:1, NC3?NC3:1);
+ }
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(),
+ &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
+ &work_size, -1, &iwork(0,0));
+
+ if (info != 0)
+ return info;
+
+ // There is a bug in an older version of LAPACK in Debian etch
+ // that causes the gesdd to return the wrong value for work_size
+ // when jobz == 'N'. So verify the value of work_size.
+ if (jobz == 'N')
+ {
+ using std::min;
+ using std::max;
+ const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n));
+ if (work_size < min_work_size)
+ work_size = min_work_size;
+ }
+
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual SVD
+ info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(),
+ &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
+ &work(0,0), work.size(), &iwork(0,0));
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_SDD_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/gesvd.h b/ml/dlib/dlib/matrix/lapack/gesvd.h
new file mode 100644
index 000000000..e00654db6
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/gesvd.h
@@ -0,0 +1,323 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_SVD_Hh_
+#define DLIB_LAPACk_SVD_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgesvd) (const char* jobu, const char* jobvt,
+ const integer* m, const integer* n, double* a, const integer* lda,
+ double* s, double* u, const integer* ldu,
+ double* vt, const integer* ldvt,
+ double* work, const integer* lwork, integer* info);
+
+ void DLIB_FORTRAN_ID(sgesvd) (const char* jobu, const char* jobvt,
+ const integer* m, const integer* n, float* a, const integer* lda,
+ float* s, float* u, const integer* ldu,
+ float* vt, const integer* ldvt,
+ float* work, const integer* lwork, integer* info);
+
+ }
+
+ inline integer gesvd (const char jobu, const char jobvt,
+ const integer m, const integer n, double* a, const integer lda,
+ double* s, double* u, const integer ldu,
+ double* vt, const integer ldvt,
+ double* work, const integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dgesvd)(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info);
+ return info;
+ }
+
+ inline integer gesvd (const char jobu, const char jobvt,
+ const integer m, const integer n, float* a, const integer lda,
+ float* s, float* u, const integer ldu,
+ float* vt, const integer ldvt,
+ float* work, const integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sgesvd)(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info);
+ return info;
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK driver routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGESVD computes the singular value decomposition (SVD) of a real */
+/* M-by-N matrix A, optionally computing the left and/or right singular */
+/* vectors. The SVD is written */
+
+/* A = U * SIGMA * transpose(V) */
+
+/* where SIGMA is an M-by-N matrix which is zero except for its */
+/* min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and */
+/* V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA */
+/* are the singular values of A; they are real and non-negative, and */
+/* are returned in descending order. The first min(m,n) columns of */
+/* U and V are the left and right singular vectors of A. */
+
+/* Note that the routine returns V**T, not V. */
+
+/* Arguments */
+/* ========= */
+
+/* JOBU (input) CHARACTER*1 */
+/* Specifies options for computing all or part of the matrix U: */
+/* = 'A': all M columns of U are returned in array U: */
+/* = 'S': the first min(m,n) columns of U (the left singular */
+/* vectors) are returned in the array U; */
+/* = 'O': the first min(m,n) columns of U (the left singular */
+/* vectors) are overwritten on the array A; */
+/* = 'N': no columns of U (no left singular vectors) are */
+/* computed. */
+
+/* JOBVT (input) CHARACTER*1 */
+/* Specifies options for computing all or part of the matrix */
+/* V**T: */
+/* = 'A': all N rows of V**T are returned in the array VT; */
+/* = 'S': the first min(m,n) rows of V**T (the right singular */
+/* vectors) are returned in the array VT; */
+/* = 'O': the first min(m,n) rows of V**T (the right singular */
+/* vectors) are overwritten on the array A; */
+/* = 'N': no rows of V**T (no right singular vectors) are */
+/* computed. */
+
+/* JOBVT and JOBU cannot both be 'O'. */
+
+/* M (input) INTEGER */
+/* The number of rows of the input matrix A. M >= 0. */
+
+/* N (input) INTEGER */
+/* The number of columns of the input matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the M-by-N matrix A. */
+/* On exit, */
+/* if JOBU = 'O', A is overwritten with the first min(m,n) */
+/* columns of U (the left singular vectors, */
+/* stored columnwise); */
+/* if JOBVT = 'O', A is overwritten with the first min(m,n) */
+/* rows of V**T (the right singular vectors, */
+/* stored rowwise); */
+/* if JOBU .ne. 'O' and JOBVT .ne. 'O', the contents of A */
+/* are destroyed. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,M). */
+
+/* S (output) DOUBLE PRECISION array, dimension (min(M,N)) */
+/* The singular values of A, sorted so that S(i) >= S(i+1). */
+
+/* U (output) DOUBLE PRECISION array, dimension (LDU,UCOL) */
+/* (LDU,M) if JOBU = 'A' or (LDU,min(M,N)) if JOBU = 'S'. */
+/* If JOBU = 'A', U contains the M-by-M orthogonal matrix U; */
+/* if JOBU = 'S', U contains the first min(m,n) columns of U */
+/* (the left singular vectors, stored columnwise); */
+/* if JOBU = 'N' or 'O', U is not referenced. */
+
+/* LDU (input) INTEGER */
+/* The leading dimension of the array U. LDU >= 1; if */
+/* JOBU = 'S' or 'A', LDU >= M. */
+
+/* VT (output) DOUBLE PRECISION array, dimension (LDVT,N) */
+/* If JOBVT = 'A', VT contains the N-by-N orthogonal matrix */
+/* V**T; */
+/* if JOBVT = 'S', VT contains the first min(m,n) rows of */
+/* V**T (the right singular vectors, stored rowwise); */
+/* if JOBVT = 'N' or 'O', VT is not referenced. */
+
+/* LDVT (input) INTEGER */
+/* The leading dimension of the array VT. LDVT >= 1; if */
+/* JOBVT = 'A', LDVT >= N; if JOBVT = 'S', LDVT >= min(M,N). */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK; */
+/* if INFO > 0, WORK(2:MIN(M,N)) contains the unconverged */
+/* superdiagonal elements of an upper bidiagonal matrix B */
+/* whose diagonal is in S (not necessarily sorted). B */
+/* satisfies A = U * B * VT, so it has the same singular values */
+/* as A, and singular vectors related by U and VT. */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. */
+/* LWORK >= MAX(1,3*MIN(M,N)+MAX(M,N),5*MIN(M,N)). */
+/* For good performance, LWORK should generally be larger. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit. */
+/* < 0: if INFO = -i, the i-th argument had an illegal value. */
+/* > 0: if DBDSQR did not converge, INFO specifies how many */
+/* superdiagonals of an intermediate bidiagonal form B */
+/* did not converge to zero. See the description of WORK */
+/* above for details. */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int gesvd (
+ const char jobu,
+ const char jobvt,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,column_major_layout>& s,
+ matrix<T,NR3,NC3,MM,column_major_layout>& u,
+ matrix<T,NR4,NC4,MM,column_major_layout>& vt
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ const long m = a.nr();
+ const long n = a.nc();
+ s.set_size(std::min(m,n), 1);
+
+ if (jobu == 'A')
+ u.set_size(m,m);
+ else if (jobu == 'S')
+ u.set_size(m, std::min(m,n));
+ else
+ u.set_size(NR3?NR3:1, NC3?NC3:1);
+
+ if (jobvt == 'A')
+ vt.set_size(n,n);
+ else if (jobvt == 'S')
+ vt.set_size(std::min(m,n), n);
+ else
+ vt.set_size(NR4?NR4:1, NC4?NC4:1);
+
+
+ if (jobu == 'O' || jobvt == 'O')
+ {
+ DLIB_CASSERT(false, "job == 'O' not supported");
+ }
+
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::gesvd(jobu, jobvt, a.nr(), a.nc(), &a(0,0), a.nr(),
+ &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
+ &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual SVD
+ info = binding::gesvd(jobu, jobvt, a.nr(), a.nc(), &a(0,0), a.nr(),
+ &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
+ &work(0,0), work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int gesvd (
+ char jobu,
+ char jobvt,
+ matrix<T,NR1,NC1,MM,row_major_layout>& a,
+ matrix<T,NR2,NC2,MM,row_major_layout>& s,
+ matrix<T,NR3,NC3,MM,row_major_layout>& u_,
+ matrix<T,NR4,NC4,MM,row_major_layout>& vt_
+ )
+ {
+ matrix<T,0,1,MM,row_major_layout> work;
+
+ // Row major order matrices are transposed from LAPACK's point of view.
+ matrix<T,NR4,NC4,MM,row_major_layout>& u = vt_;
+ matrix<T,NR3,NC3,MM,row_major_layout>& vt = u_;
+ std::swap(jobu, jobvt);
+
+ const long m = a.nc();
+ const long n = a.nr();
+ s.set_size(std::min(m,n), 1);
+
+ if (jobu == 'A')
+ u.set_size(m,m);
+ else if (jobu == 'S')
+ u.set_size(std::min(m,n), m);
+ else
+ u.set_size(NR4?NR4:1, NC4?NC4:1);
+
+ if (jobvt == 'A')
+ vt.set_size(n,n);
+ else if (jobvt == 'S')
+ vt.set_size(n, std::min(m,n));
+ else
+ vt.set_size(NR3?NR3:1, NC3?NC3:1);
+
+ if (jobu == 'O' || jobvt == 'O')
+ {
+ DLIB_CASSERT(false, "job == 'O' not supported");
+ }
+
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::gesvd(jobu, jobvt, m, n, &a(0,0), a.nc(),
+ &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
+ &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual SVD
+ info = binding::gesvd(jobu, jobvt, m, n, &a(0,0), a.nc(),
+ &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
+ &work(0,0), work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_SVD_Hh_
+
diff --git a/ml/dlib/dlib/matrix/lapack/getrf.h b/ml/dlib/dlib/matrix/lapack/getrf.h
new file mode 100644
index 000000000..a1f0b139d
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/getrf.h
@@ -0,0 +1,132 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_GETRF_Hh_
+#define DLIB_LAPACk_GETRF_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dgetrf) (integer* m, integer *n, double *a,
+ integer* lda, integer *ipiv, integer *info);
+
+ void DLIB_FORTRAN_ID(sgetrf) (integer* m, integer *n, float *a,
+ integer* lda, integer *ipiv, integer *info);
+
+ }
+
+ inline int getrf (integer m, integer n, double *a,
+ integer lda, integer *ipiv)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dgetrf)(&m, &n, a, &lda, ipiv, &info);
+ return info;
+ }
+
+ inline int getrf (integer m, integer n, float *a,
+ integer lda, integer *ipiv)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sgetrf)(&m, &n, a, &lda, ipiv, &info);
+ return info;
+ }
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+/* -- LAPACK routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DGETRF computes an LU factorization of a general M-by-N matrix A */
+/* using partial pivoting with row interchanges. */
+
+/* The factorization has the form */
+/* A = P * L * U */
+/* where P is a permutation matrix, L is lower triangular with unit */
+/* diagonal elements (lower trapezoidal if m > n), and U is upper */
+/* triangular (upper trapezoidal if m < n). */
+
+/* This is the right-looking Level 3 BLAS version of the algorithm. */
+
+/* Arguments */
+/* ========= */
+
+/* M (input) INTEGER */
+/* The number of rows of the matrix A. M >= 0. */
+
+/* N (input) INTEGER */
+/* The number of columns of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the M-by-N matrix to be factored. */
+/* On exit, the factors L and U from the factorization */
+/* A = P*L*U; the unit diagonal elements of L are not stored. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,M). */
+
+/* IPIV (output) INTEGER array, dimension (min(M,N)) */
+/* The pivot indices; for 1 <= i <= min(M,N), row i of the */
+/* matrix was interchanged with row IPIV(i). */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value */
+/* > 0: if INFO = i, U(i,i) is exactly zero. The factorization */
+/* has been completed, but the factor U is exactly */
+/* singular, and division by zero will occur if it is used */
+/* to solve a system of equations. */
+
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM,
+ typename layout
+ >
+ int getrf (
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<integer,NR2,NC2,MM,layout>& ipiv
+ )
+ {
+ const long m = a.nr();
+ const long n = a.nc();
+
+ ipiv.set_size(std::min(m,n), 1);
+
+ // compute the actual decomposition
+ return binding::getrf(m, n, &a(0,0), a.nr(), &ipiv(0,0));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_GETRF_Hh_
+
diff --git a/ml/dlib/dlib/matrix/lapack/ormqr.h b/ml/dlib/dlib/matrix/lapack/ormqr.h
new file mode 100644
index 000000000..ab66ff4d2
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/ormqr.h
@@ -0,0 +1,224 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_ORMQR_Hh_
+#define DLIB_LAPACk_ORMQR_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dormqr) (char *side, char *trans, integer *m, integer *n,
+ integer *k, const double *a, integer *lda, const double *tau,
+ double * c_, integer *ldc, double *work, integer *lwork,
+ integer *info);
+
+ void DLIB_FORTRAN_ID(sormqr) (char *side, char *trans, integer *m, integer *n,
+ integer *k, const float *a, integer *lda, const float *tau,
+ float * c_, integer *ldc, float *work, integer *lwork,
+ integer *info);
+
+ }
+
+ inline int ormqr (char side, char trans, integer m, integer n,
+ integer k, const double *a, integer lda, const double *tau,
+ double *c_, integer ldc, double *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dormqr)(&side, &trans, &m, &n,
+ &k, a, &lda, tau,
+ c_, &ldc, work, &lwork, &info);
+ return info;
+ }
+
+ inline int ormqr (char side, char trans, integer m, integer n,
+ integer k, const float *a, integer lda, const float *tau,
+ float *c_, integer ldc, float *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(sormqr)(&side, &trans, &m, &n,
+ &k, a, &lda, tau,
+ c_, &ldc, work, &lwork, &info);
+ return info;
+ }
+
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DORMQR overwrites the general real M-by-N matrix C with */
+
+/* SIDE = 'L' SIDE = 'R' */
+/* TRANS = 'N': Q * C C * Q */
+/* TRANS = 'T': Q**T * C C * Q**T */
+
+/* where Q is a real orthogonal matrix defined as the product of k */
+/* elementary reflectors */
+
+/* Q = H(1) H(2) . . . H(k) */
+
+/* as returned by DGEQRF. Q is of order M if SIDE = 'L' and of order N */
+/* if SIDE = 'R'. */
+
+/* Arguments */
+/* ========= */
+
+/* SIDE (input) CHARACTER*1 */
+/* = 'L': apply Q or Q**T from the Left; */
+/* = 'R': apply Q or Q**T from the Right. */
+
+/* TRANS (input) CHARACTER*1 */
+/* = 'N': No transpose, apply Q; */
+/* = 'T': Transpose, apply Q**T. */
+
+/* M (input) INTEGER */
+/* The number of rows of the matrix C. M >= 0. */
+
+/* N (input) INTEGER */
+/* The number of columns of the matrix C. N >= 0. */
+
+/* K (input) INTEGER */
+/* The number of elementary reflectors whose product defines */
+/* the matrix Q. */
+/* If SIDE = 'L', M >= K >= 0; */
+/* if SIDE = 'R', N >= K >= 0. */
+
+/* A (input) DOUBLE PRECISION array, dimension (LDA,K) */
+/* The i-th column must contain the vector which defines the */
+/* elementary reflector H(i), for i = 1,2,...,k, as returned by */
+/* DGEQRF in the first k columns of its array argument A. */
+/* A is modified by the routine but restored on exit. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. */
+/* If SIDE = 'L', LDA >= max(1,M); */
+/* if SIDE = 'R', LDA >= max(1,N). */
+
+/* TAU (input) DOUBLE PRECISION array, dimension (K) */
+/* TAU(i) must contain the scalar factor of the elementary */
+/* reflector H(i), as returned by DGEQRF. */
+
+/* C (input/output) DOUBLE PRECISION array, dimension (LDC,N) */
+/* On entry, the M-by-N matrix C. */
+/* On exit, C is overwritten by Q*C or Q**T*C or C*Q**T or C*Q. */
+
+/* LDC (input) INTEGER */
+/* The leading dimension of the array C. LDC >= max(1,M). */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
+
+/* LWORK (input) INTEGER */
+/* The dimension of the array WORK. */
+/* If SIDE = 'L', LWORK >= max(1,N); */
+/* if SIDE = 'R', LWORK >= max(1,M). */
+/* For optimum performance LWORK >= N*NB if SIDE = 'L', and */
+/* LWORK >= M*NB if SIDE = 'R', where NB is the optimal */
+/* blocksize. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3,
+ long NC1, long NC2, long NC3,
+ typename MM,
+ typename C_LAYOUT
+ >
+ int ormqr (
+ char side,
+ char trans,
+ const matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ const matrix<T,NR2,NC2,MM,column_major_layout>& tau,
+ matrix<T,NR3,NC3,MM,C_LAYOUT>& c
+ )
+ {
+ long m = c.nr();
+ long n = c.nc();
+ const long k = a.nc();
+ long ldc;
+ if (is_same_type<C_LAYOUT,column_major_layout>::value)
+ {
+ ldc = c.nr();
+ }
+ else
+ {
+ // Since lapack expects c to be in column major layout we have to
+ // do something to make this work. Since a row major layout matrix
+ // will look just like a transposed C we can just swap a few things around.
+
+ ldc = c.nc();
+ swap(m,n);
+
+ if (side == 'L')
+ side = 'R';
+ else
+ side = 'L';
+
+ if (trans == 'T')
+ trans = 'N';
+ else
+ trans = 'T';
+ }
+
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::ormqr(side, trans, m, n,
+ k, &a(0,0), a.nr(), &tau(0,0),
+ &c(0,0), ldc, &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual result
+ info = binding::ormqr(side, trans, m, n,
+ k, &a(0,0), a.nr(), &tau(0,0),
+ &c(0,0), ldc, &work(0,0), work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_ORMQR_Hh_
+
diff --git a/ml/dlib/dlib/matrix/lapack/pbtrf.h b/ml/dlib/dlib/matrix/lapack/pbtrf.h
new file mode 100644
index 000000000..23bcc127b
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/pbtrf.h
@@ -0,0 +1,178 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_BDC_Hh_
+#define DLIB_LAPACk_BDC_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dpbtrf) (const char* uplo, const integer* n, const integer* kd,
+ double* ab, const integer* ldab, integer* info);
+
+ void DLIB_FORTRAN_ID(spbtrf) (const char* uplo, const integer* n, const integer* kd,
+ float* ab, const integer* ldab, integer* info);
+
+ }
+
+ inline integer pbtrf (const char uplo, const integer n, const integer kd,
+ double* ab, const integer ldab)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dpbtrf)(&uplo, &n, &kd, ab, &ldab, &info);
+ return info;
+ }
+
+ inline integer pbtrf (const char uplo, const integer n, const integer kd,
+ float* ab, const integer ldab)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(spbtrf)(&uplo, &n, &kd, ab, &ldab, &info);
+ return info;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+/* DPBTRF(l) LAPACK routine (version 1.1) DPBTRF(l)
+
+NAME
+ DPBTRF - compute the Cholesky factorization of a real symmetric positive
+ definite band matrix A
+
+SYNOPSIS
+
+ SUBROUTINE DPBTRF( UPLO, N, KD, AB, LDAB, INFO )
+
+ CHARACTER UPLO
+
+ INTEGER INFO, KD, LDAB, N
+
+ DOUBLE PRECISION AB( LDAB, * )
+
+PURPOSE
+ DPBTRF computes the Cholesky factorization of a real symmetric positive
+ definite band matrix A.
+
+ The factorization has the form
+ A = U**T * U, if UPLO = 'U', or
+ A = L * L**T, if UPLO = 'L',
+ where U is an upper triangular matrix and L is lower triangular.
+
+ARGUMENTS
+
+ UPLO (input) CHARACTER*1
+ = 'U': Upper triangle of A is stored;
+ = 'L': Lower triangle of A is stored.
+
+ N (input) INTEGER
+ The order of the matrix A. N >= 0.
+
+ KD (input) INTEGER
+ The number of superdiagonals of the matrix A if UPLO = 'U', or the
+ number of subdiagonals if UPLO = 'L'. KD >= 0.
+
+ AB (input/output) DOUBLE PRECISION array, dimension (LDAB,N)
+ On entry, the upper or lower triangle of the symmetric band matrix
+ A, stored in the first KD+1 rows of the array. The j-th column of
+ A is stored in the j-th column of the array AB as follows: if UPLO
+ = 'U', AB(kd+1+i-j,j) = A(i,j) for max(1,j-kd)<=i<=j; if UPLO =
+ 'L', AB(1+i-j,j) = A(i,j) for j<=i<=min(n,j+kd).
+
+ On exit, if INFO = 0, the triangular factor U or L from the Chole-
+ sky factorization A = U**T*U or A = L*L**T of the band matrix A, in
+ the same storage format as A.
+
+ LDAB (input) INTEGER
+ The leading dimension of the array AB. LDAB >= KD+1.
+
+ INFO (output) INTEGER
+ = 0: successful exit
+ < 0: if INFO = -i, the i-th argument had an illegal value
+ > 0: if INFO = i, the leading minor of order i is not positive
+ definite, and the factorization could not be completed.
+
+FURTHER DETAILS
+ The band storage scheme is illustrated by the following example, when N =
+ 6, KD = 2, and UPLO = 'U':
+
+ On entry: On exit:
+
+ * * a13 a24 a35 a46 * * u13 u24 u35 u46
+ * a12 a23 a34 a45 a56 * u12 u23 u34 u45 u56
+ a11 a22 a33 a44 a55 a66 u11 u22 u33 u44 u55 u66
+
+ Similarly, if UPLO = 'L' the format of A is as follows:
+
+ On entry: On exit:
+
+ a11 a22 a33 a44 a55 a66 l11 l22 l33 l44 l55 l66
+ a21 a32 a43 a54 a65 * l21 l32 l43 l54 l65 *
+ a31 a42 a53 a64 * * l31 l42 l53 l64 * *
+
+ Array elements marked * are not used by the routine.
+
+ Contributed by
+ Peter Mayes and Giuseppe Radicati, IBM ECSEC, Rome, March 23, 1989 */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NC1,
+ typename MM
+ >
+ int pbtrf (
+ char uplo, matrix<T,NR1,NC1,MM,column_major_layout>& ab
+ )
+ {
+ const long ldab = ab.nr();
+ const long n = ab.nc();
+ const long kd = ldab - 1; // assume fully packed
+
+ int info = binding::pbtrf(uplo, n, kd, &ab(0,0), ldab);
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+ template <
+ typename T,
+ long NR1, long NC1,
+ typename MM
+ >
+ int pbtrf (
+ char uplo, matrix<T,NR1,NC1,MM,row_major_layout>& ab
+ )
+ {
+ const long ldab = ab.nr();
+ const long n = ab.nc();
+ const long kd = ldab - 1; // assume fully packed
+
+ matrix<T,NC1,NR1,MM,row_major_layout> abt = trans(ab);
+
+ int info = binding::pbtrf(uplo, n, kd, &abt(0,0), ldab);
+
+ ab = trans(abt);
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_BDC_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/potrf.h b/ml/dlib/dlib/matrix/lapack/potrf.h
new file mode 100644
index 000000000..b9d6a7cc8
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/potrf.h
@@ -0,0 +1,174 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_POTRF_Hh_
+#define DLIB_LAPACk_POTRF_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dpotrf) (char *uplo, integer *n, double *a,
+ integer* lda, integer *info);
+
+ void DLIB_FORTRAN_ID(spotrf) (char *uplo, integer *n, float *a,
+ integer* lda, integer *info);
+
+ }
+
+ inline int potrf (char uplo, integer n, double *a, integer lda)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dpotrf)(&uplo, &n, a, &lda, &info);
+ return info;
+ }
+
+ inline int potrf (char uplo, integer n, float *a, integer lda)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(spotrf)(&uplo, &n, a, &lda, &info);
+ return info;
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DPOTRF computes the Cholesky factorization of a real symmetric */
+/* positive definite matrix A. */
+
+/* The factorization has the form */
+/* A = U**T * U, if UPLO = 'U', or */
+/* A = L * L**T, if UPLO = 'L', */
+/* where U is an upper triangular matrix and L is lower triangular. */
+
+/* This is the block version of the algorithm, calling Level 3 BLAS. */
+
+/* Arguments */
+/* ========= */
+
+/* UPLO (input) CHARACTER*1 */
+/* = 'U': Upper triangle of A is stored; */
+/* = 'L': Lower triangle of A is stored. */
+
+/* N (input) INTEGER */
+/* The order of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
+/* On entry, the symmetric matrix A. If UPLO = 'U', the leading */
+/* N-by-N upper triangular part of A contains the upper */
+/* triangular part of the matrix A, and the strictly lower */
+/* triangular part of A is not referenced. If UPLO = 'L', the */
+/* leading N-by-N lower triangular part of A contains the lower */
+/* triangular part of the matrix A, and the strictly upper */
+/* triangular part of A is not referenced. */
+
+/* On exit, if INFO = 0, the factor U or L from the Cholesky */
+/* factorization A = U**T*U or A = L*L**T. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,N). */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value */
+/* > 0: if INFO = i, the leading minor of order i is not */
+/* positive definite, and the factorization could not be */
+/* completed. */
+
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1,
+ long NC1,
+ typename MM
+ >
+ int potrf (
+ char uplo,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a
+ )
+ {
+ // compute the actual decomposition
+ int info = binding::potrf(uplo, a.nr(), &a(0,0), a.nr());
+
+ // If it fails part way though the factorization then make sure
+ // the end of the matrix gets properly initialized with zeros.
+ if (info > 0)
+ {
+ if (uplo == 'L')
+ set_colm(a, range(info-1, a.nc()-1)) = 0;
+ else
+ set_rowm(a, range(info-1, a.nr()-1)) = 0;
+ }
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1,
+ long NC1,
+ typename MM
+ >
+ int potrf (
+ char uplo,
+ matrix<T,NR1,NC1,MM,row_major_layout>& a
+ )
+ {
+ // since we are working on a row major order matrix we need to ask
+ // LAPACK for the transpose of whatever the user asked for.
+
+ if (uplo == 'L')
+ uplo = 'U';
+ else
+ uplo = 'L';
+
+ // compute the actual decomposition
+ int info = binding::potrf(uplo, a.nr(), &a(0,0), a.nr());
+
+ // If it fails part way though the factorization then make sure
+ // the end of the matrix gets properly initialized with zeros.
+ if (info > 0)
+ {
+ if (uplo == 'U')
+ set_colm(a, range(info-1, a.nc()-1)) = 0;
+ else
+ set_rowm(a, range(info-1, a.nr()-1)) = 0;
+ }
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_POTRF_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/syev.h b/ml/dlib/dlib/matrix/lapack/syev.h
new file mode 100644
index 000000000..0c9fd251a
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/syev.h
@@ -0,0 +1,218 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_EV_Hh_
+#define DLIB_LAPACk_EV_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dsyev) (char *jobz, char *uplo, integer *n, double *a,
+ integer *lda, double *w, double *work, integer *lwork,
+ integer *info);
+
+ void DLIB_FORTRAN_ID(ssyev) (char *jobz, char *uplo, integer *n, float *a,
+ integer *lda, float *w, float *work, integer *lwork,
+ integer *info);
+
+ }
+
+ inline int syev (char jobz, char uplo, integer n, double *a,
+ integer lda, double *w, double *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dsyev)(&jobz, &uplo, &n, a,
+ &lda, w, work, &lwork, &info);
+ return info;
+ }
+
+ inline int syev (char jobz, char uplo, integer n, float *a,
+ integer lda, float *w, float *work, integer lwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(ssyev)(&jobz, &uplo, &n, a,
+ &lda, w, work, &lwork, &info);
+ return info;
+ }
+
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+/* -- LAPACK driver routine (version 3.1) -- */
+/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
+/* November 2006 */
+
+/* .. Scalar Arguments .. */
+/* .. */
+/* .. Array Arguments .. */
+/* .. */
+
+/* Purpose */
+/* ======= */
+
+/* DSYEV computes all eigenvalues and, optionally, eigenvectors of a */
+/* real symmetric matrix A. */
+
+/* Arguments */
+/* ========= */
+
+/* JOBZ (input) CHARACTER*1 */
+/* = 'N': Compute eigenvalues only; */
+/* = 'V': Compute eigenvalues and eigenvectors. */
+
+/* UPLO (input) CHARACTER*1 */
+/* = 'U': Upper triangle of A is stored; */
+/* = 'L': Lower triangle of A is stored. */
+
+/* N (input) INTEGER */
+/* The order of the matrix A. N >= 0. */
+
+/* A (input/output) DOUBLE PRECISION array, dimension (LDA, N) */
+/* On entry, the symmetric matrix A. If UPLO = 'U', the */
+/* leading N-by-N upper triangular part of A contains the */
+/* upper triangular part of the matrix A. If UPLO = 'L', */
+/* the leading N-by-N lower triangular part of A contains */
+/* the lower triangular part of the matrix A. */
+/* On exit, if JOBZ = 'V', then if INFO = 0, A contains the */
+/* orthonormal eigenvectors of the matrix A. */
+/* If JOBZ = 'N', then on exit the lower triangle (if UPLO='L') */
+/* or the upper triangle (if UPLO='U') of A, including the */
+/* diagonal, is destroyed. */
+
+/* LDA (input) INTEGER */
+/* The leading dimension of the array A. LDA >= max(1,N). */
+
+/* W (output) DOUBLE PRECISION array, dimension (N) */
+/* If INFO = 0, the eigenvalues in ascending order. */
+
+/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
+/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
+
+/* LWORK (input) INTEGER */
+/* The length of the array WORK. LWORK >= max(1,3*N-1). */
+/* For optimal efficiency, LWORK >= (NB+2)*N, */
+/* where NB is the blocksize for DSYTRD returned by ILAENV. */
+
+/* If LWORK = -1, then a workspace query is assumed; the routine */
+/* only calculates the optimal size of the WORK array, returns */
+/* this value as the first entry of the WORK array, and no error */
+/* message related to LWORK is issued by XERBLA. */
+
+/* INFO (output) INTEGER */
+/* = 0: successful exit */
+/* < 0: if INFO = -i, the i-th argument had an illegal value */
+/* > 0: if INFO = i, the algorithm failed to converge; i */
+/* off-diagonal elements of an intermediate tridiagonal */
+/* form did not converge to zero. */
+
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ int syev (
+ const char jobz,
+ const char uplo,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ matrix<T,NR2,NC2,MM,column_major_layout>& w
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+
+ const long n = a.nr();
+
+ w.set_size(n,1);
+
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::syev(jobz, uplo, n, &a(0,0),
+ a.nr(), &w(0,0), &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual decomposition
+ info = binding::syev(jobz, uplo, n, &a(0,0),
+ a.nr(), &w(0,0), &work(0,0), work.size());
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ int syev (
+ char jobz,
+ char uplo,
+ matrix<T,NR1,NC1,MM,row_major_layout>& a,
+ matrix<T,NR2,NC2,MM,row_major_layout>& w
+ )
+ {
+ matrix<T,0,1,MM,row_major_layout> work;
+
+ if (uplo == 'L')
+ uplo = 'U';
+ else
+ uplo = 'L';
+
+ const long n = a.nr();
+
+ w.set_size(n,1);
+
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ int info = binding::syev(jobz, uplo, n, &a(0,0),
+ a.nc(), &w(0,0), &work_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+
+ // compute the actual decomposition
+ info = binding::syev(jobz, uplo, n, &a(0,0),
+ a.nc(), &w(0,0), &work(0,0), work.size());
+
+
+ a = trans(a);
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_EV_Hh_
+
+
+
+
diff --git a/ml/dlib/dlib/matrix/lapack/syevr.h b/ml/dlib/dlib/matrix/lapack/syevr.h
new file mode 100644
index 000000000..65190b3d8
--- /dev/null
+++ b/ml/dlib/dlib/matrix/lapack/syevr.h
@@ -0,0 +1,445 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LAPACk_EVR_Hh_
+#define DLIB_LAPACk_EVR_Hh_
+
+#include "fortran_id.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+ namespace lapack
+ {
+ namespace binding
+ {
+ extern "C"
+ {
+ void DLIB_FORTRAN_ID(dsyevr) (char *jobz, char *range, char *uplo, integer *n,
+ double *a, integer *lda, double *vl, double *vu, integer * il,
+ integer *iu, double *abstol, integer *m, double *w,
+ double *z_, integer *ldz, integer *isuppz, double *work,
+ integer *lwork, integer *iwork, integer *liwork, integer *info);
+
+ void DLIB_FORTRAN_ID(ssyevr) (char *jobz, char *range, char *uplo, integer *n,
+ float *a, integer *lda, float *vl, float *vu, integer * il,
+ integer *iu, float *abstol, integer *m, float *w,
+ float *z_, integer *ldz, integer *isuppz, float *work,
+ integer *lwork, integer *iwork, integer *liwork, integer *info);
+ }
+
+ inline int syevr (char jobz, char range, char uplo, integer n,
+ double* a, integer lda, double vl, double vu, integer il,
+ integer iu, double abstol, integer *m, double *w,
+ double *z, integer ldz, integer *isuppz, double *work,
+ integer lwork, integer *iwork, integer liwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(dsyevr)(&jobz, &range, &uplo, &n,
+ a, &lda, &vl, &vu, &il,
+ &iu, &abstol, m, w,
+ z, &ldz, isuppz, work,
+ &lwork, iwork, &liwork, &info);
+ return info;
+ }
+
+ inline int syevr (char jobz, char range, char uplo, integer n,
+ float* a, integer lda, float vl, float vu, integer il,
+ integer iu, float abstol, integer *m, float *w,
+ float *z, integer ldz, integer *isuppz, float *work,
+ integer lwork, integer *iwork, integer liwork)
+ {
+ integer info = 0;
+ DLIB_FORTRAN_ID(ssyevr)(&jobz, &range, &uplo, &n,
+ a, &lda, &vl, &vu, &il,
+ &iu, &abstol, m, w,
+ z, &ldz, isuppz, work,
+ &lwork, iwork, &liwork, &info);
+ return info;
+ }
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ /*
+
+* -- LAPACK driver routine (version 3.1) --
+* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd..
+* November 2006
+*
+* .. Scalar Arguments ..
+ CHARACTER JOBZ, RANGE, UPLO
+ INTEGER IL, INFO, IU, LDA, LDZ, LIWORK, LWORK, M, N
+ DOUBLE PRECISION ABSTOL, VL, VU
+* ..
+* .. Array Arguments ..
+ INTEGER ISUPPZ( * ), IWORK( * )
+ DOUBLE PRECISION A( LDA, * ), W( * ), WORK( * ), Z( LDZ, * )
+* ..
+*
+* Purpose
+* =======
+*
+* DSYEVR computes selected eigenvalues and, optionally, eigenvectors
+* of a real symmetric matrix A. Eigenvalues and eigenvectors can be
+* selected by specifying either a range of values or a range of
+* indices for the desired eigenvalues.
+*
+* DSYEVR first reduces the matrix A to tridiagonal form T with a call
+* to DSYTRD. Then, whenever possible, DSYEVR calls DSTEMR to compute
+* the eigenspectrum using Relatively Robust Representations. DSTEMR
+* computes eigenvalues by the dqds algorithm, while orthogonal
+* eigenvectors are computed from various "good" L D L^T representations
+* (also known as Relatively Robust Representations). Gram-Schmidt
+* orthogonalization is avoided as far as possible. More specifically,
+* the various steps of the algorithm are as follows.
+*
+* For each unreduced block (submatrix) of T,
+* (a) Compute T - sigma I = L D L^T, so that L and D
+* define all the wanted eigenvalues to high relative accuracy.
+* This means that small relative changes in the entries of D and L
+* cause only small relative changes in the eigenvalues and
+* eigenvectors. The standard (unfactored) representation of the
+* tridiagonal matrix T does not have this property in general.
+* (b) Compute the eigenvalues to suitable accuracy.
+* If the eigenvectors are desired, the algorithm attains full
+* accuracy of the computed eigenvalues only right before
+* the corresponding vectors have to be computed, see steps c) and d).
+* (c) For each cluster of close eigenvalues, select a new
+* shift close to the cluster, find a new factorization, and refine
+* the shifted eigenvalues to suitable accuracy.
+* (d) For each eigenvalue with a large enough relative separation compute
+* the corresponding eigenvector by forming a rank revealing twisted
+* factorization. Go back to (c) for any clusters that remain.
+*
+* The desired accuracy of the output can be specified by the input
+* parameter ABSTOL.
+*
+* For more details, see DSTEMR's documentation and:
+* - Inderjit S. Dhillon and Beresford N. Parlett: "Multiple representations
+* to compute orthogonal eigenvectors of symmetric tridiagonal matrices,"
+* Linear Algebra and its Applications, 387(1), pp. 1-28, August 2004.
+* - Inderjit Dhillon and Beresford Parlett: "Orthogonal Eigenvectors and
+* Relative Gaps," SIAM Journal on Matrix Analysis and Applications, Vol. 25,
+* 2004. Also LAPACK Working Note 154.
+* - Inderjit Dhillon: "A new O(n^2) algorithm for the symmetric
+* tridiagonal eigenvalue/eigenvector problem",
+* Computer Science Division Technical Report No. UCB/CSD-97-971,
+* UC Berkeley, May 1997.
+*
+*
+* Note 1 : DSYEVR calls DSTEMR when the full spectrum is requested
+* on machines which conform to the ieee-754 floating point standard.
+* DSYEVR calls DSTEBZ and SSTEIN on non-ieee machines and
+* when partial spectrum requests are made.
+*
+* Normal execution of DSTEMR may create NaNs and infinities and
+* hence may abort due to a floating point exception in environments
+* which do not handle NaNs and infinities in the ieee standard default
+* manner.
+*
+* Arguments
+* =========
+*
+* JOBZ (input) CHARACTER*1
+* = 'N': Compute eigenvalues only;
+* = 'V': Compute eigenvalues and eigenvectors.
+*
+* RANGE (input) CHARACTER*1
+* = 'A': all eigenvalues will be found.
+* = 'V': all eigenvalues in the half-open interval (VL,VU]
+* will be found.
+* = 'I': the IL-th through IU-th eigenvalues will be found.
+********** For RANGE = 'V' or 'I' and IU - IL < N - 1, DSTEBZ and
+********** DSTEIN are called
+*
+* UPLO (input) CHARACTER*1
+* = 'U': Upper triangle of A is stored;
+* = 'L': Lower triangle of A is stored.
+*
+* N (input) INTEGER
+* The order of the matrix A. N >= 0.
+*
+* A (input/output) DOUBLE PRECISION array, dimension (LDA, N)
+* On entry, the symmetric matrix A. If UPLO = 'U', the
+* leading N-by-N upper triangular part of A contains the
+* upper triangular part of the matrix A. If UPLO = 'L',
+* the leading N-by-N lower triangular part of A contains
+* the lower triangular part of the matrix A.
+* On exit, the lower triangle (if UPLO='L') or the upper
+* triangle (if UPLO='U') of A, including the diagonal, is
+* destroyed.
+*
+* LDA (input) INTEGER
+* The leading dimension of the array A. LDA >= max(1,N).
+*
+* VL (input) DOUBLE PRECISION
+* VU (input) DOUBLE PRECISION
+* If RANGE='V', the lower and upper bounds of the interval to
+* be searched for eigenvalues. VL < VU.
+* Not referenced if RANGE = 'A' or 'I'.
+*
+* IL (input) INTEGER
+* IU (input) INTEGER
+* If RANGE='I', the indices (in ascending order) of the
+* smallest and largest eigenvalues to be returned.
+* 1 <= IL <= IU <= N, if N > 0; IL = 1 and IU = 0 if N = 0.
+* Not referenced if RANGE = 'A' or 'V'.
+*
+* ABSTOL (input) DOUBLE PRECISION
+* The absolute error tolerance for the eigenvalues.
+* An approximate eigenvalue is accepted as converged
+* when it is determined to lie in an interval [a,b]
+* of width less than or equal to
+*
+* ABSTOL + EPS * max( |a|,|b| ) ,
+*
+* where EPS is the machine precision. If ABSTOL is less than
+* or equal to zero, then EPS*|T| will be used in its place,
+* where |T| is the 1-norm of the tridiagonal matrix obtained
+* by reducing A to tridiagonal form.
+*
+* See "Computing Small Singular Values of Bidiagonal Matrices
+* with Guaranteed High Relative Accuracy," by Demmel and
+* Kahan, LAPACK Working Note #3.
+*
+* If high relative accuracy is important, set ABSTOL to
+* DLAMCH( 'Safe minimum' ). Doing so will guarantee that
+* eigenvalues are computed to high relative accuracy when
+* possible in future releases. The current code does not
+* make any guarantees about high relative accuracy, but
+* future releases will. See J. Barlow and J. Demmel,
+* "Computing Accurate Eigensystems of Scaled Diagonally
+* Dominant Matrices", LAPACK Working Note #7, for a discussion
+* of which matrices define their eigenvalues to high relative
+* accuracy.
+*
+* M (output) INTEGER
+* The total number of eigenvalues found. 0 <= M <= N.
+* If RANGE = 'A', M = N, and if RANGE = 'I', M = IU-IL+1.
+*
+* W (output) DOUBLE PRECISION array, dimension (N)
+* The first M elements contain the selected eigenvalues in
+* ascending order.
+*
+* Z (output) DOUBLE PRECISION array, dimension (LDZ, max(1,M))
+* If JOBZ = 'V', then if INFO = 0, the first M columns of Z
+* contain the orthonormal eigenvectors of the matrix A
+* corresponding to the selected eigenvalues, with the i-th
+* column of Z holding the eigenvector associated with W(i).
+* If JOBZ = 'N', then Z is not referenced.
+* Note: the user must ensure that at least max(1,M) columns are
+* supplied in the array Z; if RANGE = 'V', the exact value of M
+* is not known in advance and an upper bound must be used.
+* Supplying N columns is always safe.
+*
+* LDZ (input) INTEGER
+* The leading dimension of the array Z. LDZ >= 1, and if
+* JOBZ = 'V', LDZ >= max(1,N).
+*
+* ISUPPZ (output) INTEGER array, dimension ( 2*max(1,M) )
+* The support of the eigenvectors in Z, i.e., the indices
+* indicating the nonzero elements in Z. The i-th eigenvector
+* is nonzero only in elements ISUPPZ( 2*i-1 ) through
+* ISUPPZ( 2*i ).
+********** Implemented only for RANGE = 'A' or 'I' and IU - IL = N - 1
+*
+* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK))
+* On exit, if INFO = 0, WORK(1) returns the optimal LWORK.
+*
+* LWORK (input) INTEGER
+* The dimension of the array WORK. LWORK >= max(1,26*N).
+* For optimal efficiency, LWORK >= (NB+6)*N,
+* where NB is the max of the blocksize for DSYTRD and DORMTR
+* returned by ILAENV.
+*
+* If LWORK = -1, then a workspace query is assumed; the routine
+* only calculates the optimal size of the WORK array, returns
+* this value as the first entry of the WORK array, and no error
+* message related to LWORK is issued by XERBLA.
+*
+* IWORK (workspace/output) INTEGER array, dimension (MAX(1,LIWORK))
+* On exit, if INFO = 0, IWORK(1) returns the optimal LWORK.
+*
+* LIWORK (input) INTEGER
+* The dimension of the array IWORK. LIWORK >= max(1,10*N).
+*
+* If LIWORK = -1, then a workspace query is assumed; the
+* routine only calculates the optimal size of the IWORK array,
+* returns this value as the first entry of the IWORK array, and
+* no error message related to LIWORK is issued by XERBLA.
+*
+* INFO (output) INTEGER
+* = 0: successful exit
+* < 0: if INFO = -i, the i-th argument had an illegal value
+* > 0: Internal error
+*
+* Further Details
+* ===============
+*
+* Based on contributions by
+* Inderjit Dhillon, IBM Almaden, USA
+* Osni Marques, LBNL/NERSC, USA
+* Ken Stanley, Computer Science Division, University of
+* California at Berkeley, USA
+* Jason Riedy, Computer Science Division, University of
+* California at Berkeley, USA
+*
+* =====================================================================
+
+ */
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int syevr (
+ const char jobz,
+ const char range,
+ const char uplo,
+ matrix<T,NR1,NC1,MM,column_major_layout>& a,
+ const double vl,
+ const double vu,
+ const integer il,
+ const integer iu,
+ const double abstol,
+ integer& num_eigenvalues_found,
+ matrix<T,NR2,NC2,MM,column_major_layout>& w,
+ matrix<T,NR3,NC3,MM,column_major_layout>& z,
+ matrix<integer,NR4,NC4,MM,column_major_layout>& isuppz
+ )
+ {
+ matrix<T,0,1,MM,column_major_layout> work;
+ matrix<integer,0,1,MM,column_major_layout> iwork;
+
+ const long n = a.nr();
+
+ w.set_size(n,1);
+
+ isuppz.set_size(2*n, 1);
+
+ if (jobz == 'V')
+ {
+ z.set_size(n,n);
+ }
+ else
+ {
+ z.set_size(NR3?NR3:1, NC3?NC3:1);
+ }
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ integer iwork_size = 1;
+ int info = binding::syevr(jobz, range, uplo, n, &a(0,0),
+ a.nr(), vl, vu, il, iu, abstol, &num_eigenvalues_found,
+ &w(0,0), &z(0,0), z.nr(), &isuppz(0,0), &work_size, -1,
+ &iwork_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+ if (iwork.size() < iwork_size)
+ iwork.set_size(iwork_size, 1);
+
+ // compute the actual decomposition
+ info = binding::syevr(jobz, range, uplo, n, &a(0,0),
+ a.nr(), vl, vu, il, iu, abstol, &num_eigenvalues_found,
+ &w(0,0), &z(0,0), z.nr(), &isuppz(0,0), &work(0,0), work.size(),
+ &iwork(0,0), iwork.size());
+
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2, long NR3, long NR4,
+ long NC1, long NC2, long NC3, long NC4,
+ typename MM
+ >
+ int syevr (
+ const char jobz,
+ const char range,
+ char uplo,
+ matrix<T,NR1,NC1,MM,row_major_layout>& a,
+ const double vl,
+ const double vu,
+ const integer il,
+ const integer iu,
+ const double abstol,
+ integer& num_eigenvalues_found,
+ matrix<T,NR2,NC2,MM,row_major_layout>& w,
+ matrix<T,NR3,NC3,MM,row_major_layout>& z,
+ matrix<integer,NR4,NC4,MM,row_major_layout>& isuppz
+ )
+ {
+ matrix<T,0,1,MM,row_major_layout> work;
+ matrix<integer,0,1,MM,row_major_layout> iwork;
+
+ if (uplo == 'L')
+ uplo = 'U';
+ else
+ uplo = 'L';
+
+ const long n = a.nr();
+
+ w.set_size(n,1);
+
+ isuppz.set_size(2*n, 1);
+
+ if (jobz == 'V')
+ {
+ z.set_size(n,n);
+ }
+ else
+ {
+ z.set_size(NR3?NR3:1, NC3?NC3:1);
+ }
+
+ // figure out how big the workspace needs to be.
+ T work_size = 1;
+ integer iwork_size = 1;
+ int info = binding::syevr(jobz, range, uplo, n, &a(0,0),
+ a.nc(), vl, vu, il, iu, abstol, &num_eigenvalues_found,
+ &w(0,0), &z(0,0), z.nc(), &isuppz(0,0), &work_size, -1,
+ &iwork_size, -1);
+
+ if (info != 0)
+ return info;
+
+ if (work.size() < work_size)
+ work.set_size(static_cast<long>(work_size), 1);
+ if (iwork.size() < iwork_size)
+ iwork.set_size(iwork_size, 1);
+
+ // compute the actual decomposition
+ info = binding::syevr(jobz, range, uplo, n, &a(0,0),
+ a.nc(), vl, vu, il, iu, abstol, &num_eigenvalues_found,
+ &w(0,0), &z(0,0), z.nc(), &isuppz(0,0), &work(0,0), work.size(),
+ &iwork(0,0), iwork.size());
+
+ z = trans(z);
+
+ return info;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_LAPACk_EVR_Hh_
+
+
+
diff --git a/ml/dlib/dlib/matrix/matrix.h b/ml/dlib/dlib/matrix/matrix.h
new file mode 100644
index 000000000..b16635879
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix.h
@@ -0,0 +1,2162 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_
+#define DLIB_MATRIx_
+
+#include "matrix_exp.h"
+#include "matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../enable_if.h"
+#include <sstream>
+#include <algorithm>
+#include "../memory_manager.h"
+#include "../is_kind.h"
+#include "matrix_data_layout.h"
+#include "matrix_assign_fwd.h"
+#include "matrix_op.h"
+#include <utility>
+#ifdef DLIB_HAS_INITIALIZER_LISTS
+#include <initializer_list>
+#endif
+
+#ifdef MATLAB_MEX_FILE
+#include <mex.h>
+#endif
+
+#ifdef _MSC_VER
+// Disable the following warnings for Visual Studio
+
+// This warning is:
+// "warning C4355: 'this' : used in base member initializer list"
+// Which we get from this code but it is not an error so I'm turning this
+// warning off and then turning it back on at the end of the file.
+#pragma warning(disable : 4355)
+
+// "warning C4723: potential divide by 0" - This warning is triggered in
+// matrix(const std::initializer_list<T>& l) where the compiler can see that
+// matrix<> was templated in a way making NR ending up 0, but division by 0 at runtime
+// is not possible because the division operation is inside "if (NR!=0)" block.
+#pragma warning(disable : 4723)
+
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // This template will perform the needed loop for element multiplication using whichever
+ // dimension is provided as a compile time constant (if one is at all).
+ template <
+ typename LHS,
+ typename RHS,
+ long lhs_nc = LHS::NC,
+ long rhs_nr = RHS::NR
+ >
+ struct matrix_multiply_helper
+ {
+ typedef typename LHS::type type;
+ template <typename RHS_, typename LHS_>
+ inline const static type eval (
+ const RHS_& rhs,
+ const LHS_& lhs,
+ const long r,
+ const long c
+ )
+ {
+ type temp = lhs(r,0)*rhs(0,c);
+ for (long i = 1; i < rhs.nr(); ++i)
+ {
+ temp += lhs(r,i)*rhs(i,c);
+ }
+ return temp;
+ }
+ };
+
+ template <
+ typename LHS,
+ typename RHS,
+ long lhs_nc
+ >
+ struct matrix_multiply_helper <LHS,RHS,lhs_nc,0>
+ {
+ typedef typename LHS::type type;
+ template <typename RHS_, typename LHS_>
+ inline const static type eval (
+ const RHS_& rhs,
+ const LHS_& lhs,
+ const long r,
+ const long c
+ )
+ {
+ type temp = lhs(r,0)*rhs(0,c);
+ for (long i = 1; i < lhs.nc(); ++i)
+ {
+ temp += lhs(r,i)*rhs(i,c);
+ }
+ return temp;
+ }
+ };
+
+ template <typename LHS, typename RHS>
+ class matrix_multiply_exp;
+
+ template <typename LHS, typename RHS>
+ struct matrix_traits<matrix_multiply_exp<LHS,RHS> >
+ {
+ typedef typename LHS::type type;
+ typedef typename LHS::type const_ret_type;
+ typedef typename LHS::mem_manager_type mem_manager_type;
+ typedef typename LHS::layout_type layout_type;
+ const static long NR = LHS::NR;
+ const static long NC = RHS::NC;
+
+#ifdef DLIB_USE_BLAS
+ // if there are BLAS functions to be called then we want to make sure we
+ // always evaluate any complex expressions so that the BLAS bindings can happen.
+ const static bool lhs_is_costly = (LHS::cost > 2)&&(RHS::NC != 1 || LHS::cost >= 10000);
+ const static bool rhs_is_costly = (RHS::cost > 2)&&(LHS::NR != 1 || RHS::cost >= 10000);
+#else
+ const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1);
+ const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1);
+#endif
+
+ // Note that if we decide that one of the matrices is too costly we will evaluate it
+ // into a temporary. Doing this resets its cost back to 1.
+ const static long lhs_cost = ((lhs_is_costly==true)? 1 : (LHS::cost));
+ const static long rhs_cost = ((rhs_is_costly==true)? 1 : (RHS::cost));
+
+ // The cost of evaluating an element of a matrix multiply is the cost of evaluating elements from
+ // RHS and LHS times the number of rows/columns in the RHS/LHS matrix. If we don't know the matrix
+ // dimensions then just assume it is really large.
+ const static long cost = ((tmax<LHS::NC,RHS::NR>::value!=0)? ((lhs_cost+rhs_cost)*tmax<LHS::NC,RHS::NR>::value):(10000));
+ };
+
+ template <typename T, bool is_ref> struct conditional_matrix_temp { typedef typename T::matrix_type type; };
+ template <typename T> struct conditional_matrix_temp<T,true> { typedef T& type; };
+
+ template <
+ typename LHS,
+ typename RHS
+ >
+ class matrix_multiply_exp : public matrix_exp<matrix_multiply_exp<LHS,RHS> >
+ {
+ /*!
+ REQUIREMENTS ON LHS AND RHS
+ - must be matrix_exp objects.
+ !*/
+ public:
+
+ typedef typename matrix_traits<matrix_multiply_exp>::type type;
+ typedef typename matrix_traits<matrix_multiply_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_multiply_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_multiply_exp>::NR;
+ const static long NC = matrix_traits<matrix_multiply_exp>::NC;
+ const static long cost = matrix_traits<matrix_multiply_exp>::cost;
+ typedef typename matrix_traits<matrix_multiply_exp>::layout_type layout_type;
+
+
+ const static bool lhs_is_costly = matrix_traits<matrix_multiply_exp>::lhs_is_costly;
+ const static bool rhs_is_costly = matrix_traits<matrix_multiply_exp>::rhs_is_costly;
+ const static bool either_is_costly = lhs_is_costly || rhs_is_costly;
+ const static bool both_are_costly = lhs_is_costly && rhs_is_costly;
+
+ typedef typename conditional_matrix_temp<const LHS,lhs_is_costly == false>::type LHS_ref_type;
+ typedef typename conditional_matrix_temp<const RHS,rhs_is_costly == false>::type RHS_ref_type;
+
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of objects.
+ template <typename T1, typename T2>
+ matrix_multiply_exp (T1,T2);
+
+ inline matrix_multiply_exp (
+ const LHS& lhs_,
+ const RHS& rhs_
+ ) :
+ lhs(lhs_),
+ rhs(rhs_)
+ {
+ // You are trying to multiply two incompatible matrices together. The number of columns
+ // in the matrix on the left must match the number of rows in the matrix on the right.
+ COMPILE_TIME_ASSERT(LHS::NC == RHS::NR || LHS::NC*RHS::NR == 0);
+ DLIB_ASSERT(lhs.nc() == rhs.nr() && lhs.size() > 0 && rhs.size() > 0,
+ "\tconst matrix_exp operator*(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to multiply two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+ // You can't multiply matrices together if they don't both contain the same type of elements.
+ COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
+ }
+
+ inline const type operator() (
+ const long r,
+ const long c
+ ) const
+ {
+ return matrix_multiply_helper<LHS,RHS>::eval(rhs,lhs,r,c);
+ }
+
+ inline const type operator() ( long i ) const
+ { return matrix_exp<matrix_multiply_exp>::operator()(i); }
+
+ long nr (
+ ) const { return lhs.nr(); }
+
+ long nc (
+ ) const { return rhs.nc(); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return lhs.aliases(item) || rhs.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return aliases(item); }
+
+ LHS_ref_type lhs;
+ RHS_ref_type rhs;
+ };
+
+ template < typename EXP1, typename EXP2 >
+ inline const matrix_multiply_exp<EXP1, EXP2> operator* (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ return matrix_multiply_exp<EXP1, EXP2>(m1.ref(), m2.ref());
+ }
+
+ template <typename M, bool use_reference = true>
+ class matrix_mul_scal_exp;
+
+ // -------------------------
+
+ // Now we declare some overloads that cause any scalar multiplications to percolate
+ // up and outside of any matrix multiplies. Note that we are using the non-reference containing
+ // mode of the matrix_mul_scal_exp object since we are passing in locally constructed matrix_multiply_exp
+ // objects. So the matrix_mul_scal_exp object will contain copies of matrix_multiply_exp objects
+ // rather than references to them. This could result in extra matrix copies if the matrix_multiply_exp
+ // decided it should evaluate any of its arguments. So we also try to not apply this percolating operation
+ // if the matrix_multiply_exp would contain a fully evaluated copy of the original matrix_mul_scal_exp
+ // expression.
+ //
+ // Also, the reason we want to apply this transformation in the first place is because it (1) makes
+ // the expressions going into matrix multiply expressions simpler and (2) it makes it a lot more
+ // straightforward to bind BLAS calls to matrix expressions involving scalar multiplies.
+ template < typename EXP1, typename EXP2 >
+ inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, matrix_mul_scal_exp<EXP2> >::both_are_costly ,
+ matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
+ const matrix_mul_scal_exp<EXP1>& m1,
+ const matrix_mul_scal_exp<EXP2>& m2
+ )
+ {
+ typedef matrix_multiply_exp<EXP1, EXP2> exp1;
+ typedef matrix_mul_scal_exp<exp1,false> exp2;
+ return exp2(exp1(m1.m, m2.m), m1.s*m2.s);
+ }
+
+ template < typename EXP1, typename EXP2 >
+ inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, EXP2 >::lhs_is_costly ,
+ matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
+ const matrix_mul_scal_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ typedef matrix_multiply_exp<EXP1, EXP2> exp1;
+ typedef matrix_mul_scal_exp<exp1,false> exp2;
+ return exp2(exp1(m1.m, m2.ref()), m1.s);
+ }
+
+ template < typename EXP1, typename EXP2 >
+ inline const typename disable_if_c< matrix_multiply_exp<EXP1, matrix_mul_scal_exp<EXP2> >::rhs_is_costly ,
+ matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
+ const matrix_exp<EXP1>& m1,
+ const matrix_mul_scal_exp<EXP2>& m2
+ )
+ {
+ typedef matrix_multiply_exp<EXP1, EXP2> exp1;
+ typedef matrix_mul_scal_exp<exp1,false> exp2;
+ return exp2(exp1(m1.ref(), m2.m), m2.s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename LHS, typename RHS>
+ class matrix_add_exp;
+
+ template <typename LHS, typename RHS>
+ struct matrix_traits<matrix_add_exp<LHS,RHS> >
+ {
+ typedef typename LHS::type type;
+ typedef typename LHS::type const_ret_type;
+ typedef typename LHS::mem_manager_type mem_manager_type;
+ typedef typename LHS::layout_type layout_type;
+ const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
+ const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
+ const static long cost = LHS::cost+RHS::cost+1;
+ };
+
+ template <
+ typename LHS,
+ typename RHS
+ >
+ class matrix_add_exp : public matrix_exp<matrix_add_exp<LHS,RHS> >
+ {
+ /*!
+ REQUIREMENTS ON LHS AND RHS
+ - must be matrix_exp objects.
+ !*/
+ public:
+ typedef typename matrix_traits<matrix_add_exp>::type type;
+ typedef typename matrix_traits<matrix_add_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_add_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_add_exp>::NR;
+ const static long NC = matrix_traits<matrix_add_exp>::NC;
+ const static long cost = matrix_traits<matrix_add_exp>::cost;
+ typedef typename matrix_traits<matrix_add_exp>::layout_type layout_type;
+
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of objects.
+ template <typename T1, typename T2>
+ matrix_add_exp (T1,T2);
+
+ matrix_add_exp (
+ const LHS& lhs_,
+ const RHS& rhs_
+ ) :
+ lhs(lhs_),
+ rhs(rhs_)
+ {
+ // You can only add matrices together if they both have the same number of rows and columns.
+ COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0);
+ COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0);
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+ // You can only add matrices together if they both contain the same types of elements.
+ COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
+ }
+
+ const type operator() (
+ long r,
+ long c
+ ) const { return lhs(r,c) + rhs(r,c); }
+
+ inline const type operator() ( long i ) const
+ { return matrix_exp<matrix_add_exp>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return lhs.aliases(item) || rhs.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); }
+
+ long nr (
+ ) const { return lhs.nr(); }
+
+ long nc (
+ ) const { return lhs.nc(); }
+
+ const LHS& lhs;
+ const RHS& rhs;
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_add_exp<EXP1, EXP2> operator+ (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ return matrix_add_exp<EXP1, EXP2>(m1.ref(),m2.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename LHS, typename RHS>
+ class matrix_subtract_exp;
+
+ template <typename LHS, typename RHS>
+ struct matrix_traits<matrix_subtract_exp<LHS,RHS> >
+ {
+ typedef typename LHS::type type;
+ typedef typename LHS::type const_ret_type;
+ typedef typename LHS::mem_manager_type mem_manager_type;
+ typedef typename LHS::layout_type layout_type;
+ const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
+ const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
+ const static long cost = LHS::cost+RHS::cost+1;
+ };
+
+ template <
+ typename LHS,
+ typename RHS
+ >
+ class matrix_subtract_exp : public matrix_exp<matrix_subtract_exp<LHS,RHS> >
+ {
+ /*!
+ REQUIREMENTS ON LHS AND RHS
+ - must be matrix_exp objects.
+ !*/
+ public:
+ typedef typename matrix_traits<matrix_subtract_exp>::type type;
+ typedef typename matrix_traits<matrix_subtract_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_subtract_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_subtract_exp>::NR;
+ const static long NC = matrix_traits<matrix_subtract_exp>::NC;
+ const static long cost = matrix_traits<matrix_subtract_exp>::cost;
+ typedef typename matrix_traits<matrix_subtract_exp>::layout_type layout_type;
+
+
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of objects.
+ template <typename T1, typename T2>
+ matrix_subtract_exp (T1,T2);
+
+ matrix_subtract_exp (
+ const LHS& lhs_,
+ const RHS& rhs_
+ ) :
+ lhs(lhs_),
+ rhs(rhs_)
+ {
+ // You can only subtract one matrix from another if they both have the same number of rows and columns.
+ COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0);
+ COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0);
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator-(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to subtract two incompatible matrices"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+ // You can only subtract one matrix from another if they both contain elements of the same type.
+ COMPILE_TIME_ASSERT((is_same_type<typename LHS::type, typename RHS::type>::value == true));
+ }
+
+ const type operator() (
+ long r,
+ long c
+ ) const { return lhs(r,c) - rhs(r,c); }
+
+ inline const type operator() ( long i ) const
+ { return matrix_exp<matrix_subtract_exp>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return lhs.aliases(item) || rhs.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); }
+
+ long nr (
+ ) const { return lhs.nr(); }
+
+ long nc (
+ ) const { return lhs.nc(); }
+
+ const LHS& lhs;
+ const RHS& rhs;
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_subtract_exp<EXP1, EXP2> operator- (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ return matrix_subtract_exp<EXP1, EXP2>(m1.ref(),m2.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ class matrix_div_scal_exp;
+
+ template <typename M>
+ struct matrix_traits<matrix_div_scal_exp<M> >
+ {
+ typedef typename M::type type;
+ typedef typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ const static long cost = M::cost+1;
+ };
+
+ template <
+ typename M
+ >
+ class matrix_div_scal_exp : public matrix_exp<matrix_div_scal_exp<M> >
+ {
+ /*!
+ REQUIREMENTS ON M
+ - must be a matrix_exp object.
+ !*/
+ public:
+ typedef typename matrix_traits<matrix_div_scal_exp>::type type;
+ typedef typename matrix_traits<matrix_div_scal_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_div_scal_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_div_scal_exp>::NR;
+ const static long NC = matrix_traits<matrix_div_scal_exp>::NC;
+ const static long cost = matrix_traits<matrix_div_scal_exp>::cost;
+ typedef typename matrix_traits<matrix_div_scal_exp>::layout_type layout_type;
+
+
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of objects.
+ template <typename T1>
+ matrix_div_scal_exp (T1, const type&);
+
+ matrix_div_scal_exp (
+ const M& m_,
+ const type& s_
+ ) :
+ m(m_),
+ s(s_)
+ {}
+
+ const type operator() (
+ long r,
+ long c
+ ) const { return m(r,c)/s; }
+
+ inline const type operator() ( long i ) const
+ { return matrix_exp<matrix_div_scal_exp>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return m.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return m.destructively_aliases(item); }
+
+ long nr (
+ ) const { return m.nr(); }
+
+ long nc (
+ ) const { return m.nc(); }
+
+ const M& m;
+ const type s;
+ };
+
+ template <
+ typename EXP,
+ typename S
+ >
+ inline const typename enable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_div_scal_exp<EXP> >::type operator/ (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ return matrix_div_scal_exp<EXP>(m.ref(),static_cast<typename EXP::type>(s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, bool use_reference >
+ struct matrix_traits<matrix_mul_scal_exp<M,use_reference> >
+ {
+ typedef typename M::type type;
+ typedef typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ const static long cost = M::cost+1;
+ };
+
+ template <typename T, bool is_ref> struct conditional_reference { typedef T type; };
+ template <typename T> struct conditional_reference<T,true> { typedef T& type; };
+
+
+ template <
+ typename M,
+ bool use_reference
+ >
+ class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M,use_reference> >
+ {
+ /*!
+ REQUIREMENTS ON M
+ - must be a matrix_exp object.
+
+ !*/
+ public:
+ typedef typename matrix_traits<matrix_mul_scal_exp>::type type;
+ typedef typename matrix_traits<matrix_mul_scal_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_mul_scal_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_mul_scal_exp>::NR;
+ const static long NC = matrix_traits<matrix_mul_scal_exp>::NC;
+ const static long cost = matrix_traits<matrix_mul_scal_exp>::cost;
+ typedef typename matrix_traits<matrix_mul_scal_exp>::layout_type layout_type;
+
+ // You aren't allowed to multiply a matrix of matrices by a scalar.
+ COMPILE_TIME_ASSERT(is_matrix<type>::value == false);
+
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of objects.
+ template <typename T1>
+ matrix_mul_scal_exp (T1, const type&);
+
+ matrix_mul_scal_exp (
+ const M& m_,
+ const type& s_
+ ) :
+ m(m_),
+ s(s_)
+ {}
+
+ const type operator() (
+ long r,
+ long c
+ ) const { return m(r,c)*s; }
+
+ inline const type operator() ( long i ) const
+ { return matrix_exp<matrix_mul_scal_exp>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return m.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return m.destructively_aliases(item); }
+
+ long nr (
+ ) const { return m.nr(); }
+
+ long nc (
+ ) const { return m.nc(); }
+
+ typedef typename conditional_reference<const M,use_reference>::type M_ref_type;
+
+ M_ref_type m;
+ const type s;
+ };
+
+ template <
+ typename EXP,
+ typename S
+ >
+ inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ typedef typename EXP::type type;
+ return matrix_mul_scal_exp<EXP>(m.ref(),static_cast<type>(s));
+ }
+
+ template <
+ typename EXP,
+ typename S,
+ bool B
+ >
+ inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
+ const matrix_mul_scal_exp<EXP,B>& m,
+ const S& s
+ )
+ {
+ typedef typename EXP::type type;
+ return matrix_mul_scal_exp<EXP>(m.m,static_cast<type>(s)*m.s);
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename EXP::type type;
+ return matrix_mul_scal_exp<EXP>(m.ref(),static_cast<type>(s));
+ }
+
+ template <
+ typename EXP,
+ typename S,
+ bool B
+ >
+ inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
+ const S& s,
+ const matrix_mul_scal_exp<EXP,B>& m
+ )
+ {
+ typedef typename EXP::type type;
+ return matrix_mul_scal_exp<EXP>(m.m,static_cast<type>(s)*m.s);
+ }
+
+ template <
+ typename EXP ,
+ typename S
+ >
+ inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ typedef typename EXP::type type;
+ const type one = 1;
+ return matrix_mul_scal_exp<EXP>(m.ref(),one/static_cast<type>(s));
+ }
+
+ template <
+ typename EXP,
+ bool B,
+ typename S
+ >
+ inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
+ const matrix_mul_scal_exp<EXP,B>& m,
+ const S& s
+ )
+ {
+ typedef typename EXP::type type;
+ return matrix_mul_scal_exp<EXP>(m.m,m.s/static_cast<type>(s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_s_div_m : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_s_div_m( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+1;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return s/this->m(r,c);
+ }
+ };
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename disable_if<is_matrix<S>, matrix_op<op_s_div_m<EXP> > >::type operator/ (
+ const S& val,
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename EXP::type type;
+
+ typedef op_s_div_m<EXP> op;
+ return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ inline const matrix_mul_scal_exp<EXP> operator- (
+ const matrix_exp<EXP>& m
+ )
+ {
+ return matrix_mul_scal_exp<EXP>(m.ref(),-1);
+ }
+
+ template <
+ typename EXP,
+ bool B
+ >
+ inline const matrix_mul_scal_exp<EXP> operator- (
+ const matrix_mul_scal_exp<EXP,B>& m
+ )
+ {
+ return matrix_mul_scal_exp<EXP>(m.m,-1*m.s);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_add_scalar : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_add_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+1;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return this->m(r,c) + s;
+ }
+ };
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
+ const matrix_exp<EXP>& m,
+ const T& val
+ )
+ {
+ typedef typename EXP::type type;
+
+ typedef op_add_scalar<EXP> op;
+ return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
+ }
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
+ const T& val,
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename EXP::type type;
+
+ typedef op_add_scalar<EXP> op;
+ return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_subl_scalar : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_subl_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+1;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return s - this->m(r,c) ;
+ }
+ };
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const typename disable_if<is_matrix<T>, matrix_op<op_subl_scalar<EXP> > >::type operator- (
+ const T& val,
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename EXP::type type;
+
+ typedef op_subl_scalar<EXP> op;
+ return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_subr_scalar : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_subr_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+1;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return this->m(r,c) - s;
+ }
+ };
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const typename disable_if<is_matrix<T>, matrix_op<op_subr_scalar<EXP> > >::type operator- (
+ const matrix_exp<EXP>& m,
+ const T& val
+ )
+ {
+ typedef typename EXP::type type;
+
+ typedef op_subr_scalar<EXP> op;
+ return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ bool operator== (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ if (m1.nr() == m2.nr() && m1.nc() == m2.nc())
+ {
+ for (long r = 0; r < m1.nr(); ++r)
+ {
+ for (long c = 0; c < m1.nc(); ++c)
+ {
+ if (m1(r,c) != m2(r,c))
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline bool operator!= (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ ) { return !(m1 == m2); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_pointer_to_mat;
+ template <typename T>
+ struct op_pointer_to_col_vect;
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager,
+ typename layout
+ >
+ struct matrix_traits<matrix<T,num_rows, num_cols, mem_manager, layout> >
+ {
+ typedef T type;
+ typedef const T& const_ret_type;
+ typedef mem_manager mem_manager_type;
+ typedef layout layout_type;
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+ const static long cost = 1;
+
+ };
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager,
+ typename layout
+ >
+ class matrix : public matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >
+ {
+
+ COMPILE_TIME_ASSERT(num_rows >= 0 && num_cols >= 0);
+
+ public:
+ typedef typename matrix_traits<matrix>::type type;
+ typedef typename matrix_traits<matrix>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix>::mem_manager_type mem_manager_type;
+ typedef typename matrix_traits<matrix>::layout_type layout_type;
+ const static long NR = matrix_traits<matrix>::NR;
+ const static long NC = matrix_traits<matrix>::NC;
+ const static long cost = matrix_traits<matrix>::cost;
+ typedef T* iterator;
+ typedef const T* const_iterator;
+
+ matrix ()
+ {
+ }
+
+ explicit matrix (
+ long length
+ )
+ {
+ // This object you are trying to call matrix(length) on is not a column or
+ // row vector.
+ COMPILE_TIME_ASSERT(NR == 1 || NC == 1);
+ DLIB_ASSERT( length >= 0,
+ "\tmatrix::matrix(length)"
+ << "\n\tlength must be at least 0"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ if (NR == 1)
+ {
+ DLIB_ASSERT(NC == 0 || NC == length,
+ "\tmatrix::matrix(length)"
+ << "\n\tSince this is a statically sized matrix length must equal NC"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ data.set_size(1,length);
+ }
+ else
+ {
+ DLIB_ASSERT(NR == 0 || NR == length,
+ "\tvoid matrix::set_size(length)"
+ << "\n\tSince this is a statically sized matrix length must equal NR"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ data.set_size(length,1);
+ }
+ }
+
+ matrix (
+ long rows,
+ long cols
+ )
+ {
+ DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) &&
+ rows >= 0 && cols >= 0,
+ "\tvoid matrix::matrix(rows, cols)"
+ << "\n\tYou have supplied conflicting matrix dimensions"
+ << "\n\trows: " << rows
+ << "\n\tcols: " << cols
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ );
+ data.set_size(rows,cols);
+ }
+
+ template <typename EXP>
+ matrix (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You get an error on this line if the matrix m contains a type that isn't
+ // the same as the type contained in the target matrix.
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true) ||
+ (is_matrix<typename EXP::type>::value == true));
+
+ // The matrix you are trying to assign m to is a statically sized matrix and
+ // m's dimensions don't match that of *this.
+ COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
+ COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
+ DLIB_ASSERT((NR == 0 || NR == m.nr()) && (NC == 0 || NC == m.nc()),
+ "\tmatrix& matrix::matrix(const matrix_exp& m)"
+ << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size"
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tthis: " << this
+ );
+
+ data.set_size(m.nr(),m.nc());
+
+ matrix_assign(*this, m);
+ }
+
+ matrix (
+ const matrix& m
+ ) : matrix_exp<matrix>(*this)
+ {
+ data.set_size(m.nr(),m.nc());
+ matrix_assign(*this, m);
+ }
+
+#ifdef DLIB_HAS_INITIALIZER_LISTS
+ matrix(const std::initializer_list<T>& l)
+ {
+ if (NR*NC != 0)
+ {
+ DLIB_ASSERT(l.size() == NR*NC,
+ "\t matrix::matrix(const std::initializer_list& l)"
+ << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a matching size."
+ << "\n\t l.size(): "<< l.size()
+ << "\n\t NR*NC: "<< NR*NC);
+
+ data.set_size(NR, NC);
+ }
+ else if (NR!=0)
+ {
+ DLIB_ASSERT(l.size()%NR == 0,
+ "\t matrix::matrix(const std::initializer_list& l)"
+ << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a compatible size."
+ << "\n\t l.size(): "<< l.size()
+ << "\n\t NR: "<< NR);
+
+ if (l.size() != 0)
+ data.set_size(NR, l.size()/NR);
+ }
+ else if (NC!=0)
+ {
+ DLIB_ASSERT(l.size()%NC == 0,
+ "\t matrix::matrix(const std::initializer_list& l)"
+ << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a compatible size."
+ << "\n\t l.size(): "<< l.size()
+ << "\n\t NC: "<< NC);
+
+ if (l.size() != 0)
+ data.set_size(l.size()/NC, NC);
+ }
+ else if (l.size() != 0)
+ {
+ data.set_size(l.size(),1);
+ }
+
+ if (l.size() != 0)
+ {
+ T* d = &data(0,0);
+ for (auto&& v : l)
+ *d++ = v;
+ }
+
+ }
+
+ matrix& operator=(const std::initializer_list<T>& l)
+ {
+ matrix temp(l);
+ temp.swap(*this);
+ return *this;
+ }
+#endif // DLIB_HAS_INITIALIZER_LISTS
+
+#ifdef DLIB_HAS_RVALUE_REFERENCES
+ matrix(matrix&& item)
+ {
+ #ifdef MATLAB_MEX_FILE
+ // You can't move memory around when compiled in a matlab mex file and the
+ // different locations have different ownership settings.
+ if (data._private_is_owned_by_matlab() == item.data._private_is_owned_by_matlab())
+ {
+ swap(item);
+ }
+ else
+ {
+ data.set_size(item.nr(),item.nc());
+ matrix_assign(*this, item);
+ }
+ #else
+ swap(item);
+ #endif
+ }
+
+ matrix& operator= (
+ matrix&& rhs
+ )
+ {
+ #ifdef MATLAB_MEX_FILE
+ // You can't move memory around when compiled in a matlab mex file and the
+ // different locations have different ownership settings.
+ if (data._private_is_owned_by_matlab() == rhs.data._private_is_owned_by_matlab())
+ {
+ swap(rhs);
+ }
+ else
+ {
+ data.set_size(rhs.nr(),rhs.nc());
+ matrix_assign(*this, rhs);
+ }
+ #else
+ swap(rhs);
+ #endif
+ return *this;
+ }
+#endif // DLIB_HAS_RVALUE_REFERENCES
+
+ template <typename U, size_t len>
+ explicit matrix (
+ U (&array)[len]
+ )
+ {
+ COMPILE_TIME_ASSERT(NR*NC == len && len > 0);
+ size_t idx = 0;
+ for (long r = 0; r < NR; ++r)
+ {
+ for (long c = 0; c < NC; ++c)
+ {
+ data(r,c) = static_cast<T>(array[idx]);
+ ++idx;
+ }
+ }
+ }
+
+ T& operator() (
+ long r,
+ long c
+ )
+ {
+ DLIB_ASSERT(r < nr() && c < nc() &&
+ r >= 0 && c >= 0,
+ "\tT& matrix::operator(r,c)"
+ << "\n\tYou must give a valid row and column"
+ << "\n\tr: " << r
+ << "\n\tc: " << c
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ return data(r,c);
+ }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const
+ {
+ DLIB_ASSERT(r < nr() && c < nc() &&
+ r >= 0 && c >= 0,
+ "\tconst T& matrix::operator(r,c)"
+ << "\n\tYou must give a valid row and column"
+ << "\n\tr: " << r
+ << "\n\tc: " << c
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ return data(r,c);
+ }
+
+ T& operator() (
+ long i
+ )
+ {
+ // You can only use this operator on column vectors.
+ COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
+ DLIB_ASSERT(nc() == 1 || nr() == 1,
+ "\tconst type matrix::operator(i)"
+ << "\n\tYou can only use this operator on column or row vectors"
+ << "\n\ti: " << i
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT( 0 <= i && i < size(),
+ "\tconst type matrix::operator(i)"
+ << "\n\tYou must give a valid row/column number"
+ << "\n\ti: " << i
+ << "\n\tsize(): " << size()
+ << "\n\tthis: " << this
+ );
+ return data(i);
+ }
+
+ const T& operator() (
+ long i
+ ) const
+ {
+ // You can only use this operator on column vectors.
+ COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
+ DLIB_ASSERT(nc() == 1 || nr() == 1,
+ "\tconst type matrix::operator(i)"
+ << "\n\tYou can only use this operator on column or row vectors"
+ << "\n\ti: " << i
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT( 0 <= i && i < size(),
+ "\tconst type matrix::operator(i)"
+ << "\n\tYou must give a valid row/column number"
+ << "\n\ti: " << i
+ << "\n\tsize(): " << size()
+ << "\n\tthis: " << this
+ );
+ return data(i);
+ }
+
+ inline operator const type (
+ ) const
+ {
+ COMPILE_TIME_ASSERT(NC == 1 || NC == 0);
+ COMPILE_TIME_ASSERT(NR == 1 || NR == 0);
+ DLIB_ASSERT( nr() == 1 && nc() == 1 ,
+ "\tmatrix::operator const type"
+ << "\n\tYou can only attempt to implicit convert a matrix to a scalar if"
+ << "\n\tthe matrix is a 1x1 matrix"
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ return data(0);
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray(
+ mxArray* mem
+ )
+ {
+ data._private_set_mxArray(mem);
+ }
+
+ mxArray* _private_release_mxArray(
+ )
+ {
+ return data._private_release_mxArray();
+ }
+
+ void _private_mark_owned_by_matlab()
+ {
+ data._private_mark_owned_by_matlab();
+ }
+
+ bool _private_is_owned_by_matlab()
+ {
+ return data._private_is_owned_by_matlab();
+ }
+#endif
+
+ void set_size (
+ long rows,
+ long cols
+ )
+ {
+ DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) &&
+ rows >= 0 && cols >= 0,
+ "\tvoid matrix::set_size(rows, cols)"
+ << "\n\tYou have supplied conflicting matrix dimensions"
+ << "\n\trows: " << rows
+ << "\n\tcols: " << cols
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+ if (nr() != rows || nc() != cols)
+ data.set_size(rows,cols);
+ }
+
+ void set_size (
+ long length
+ )
+ {
+ // This object you are trying to call set_size(length) on is not a column or
+ // row vector.
+ COMPILE_TIME_ASSERT(NR == 1 || NC == 1);
+ DLIB_ASSERT( length >= 0,
+ "\tvoid matrix::set_size(length)"
+ << "\n\tlength must be at least 0"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ if (NR == 1)
+ {
+ DLIB_ASSERT(NC == 0 || NC == length,
+ "\tvoid matrix::set_size(length)"
+ << "\n\tSince this is a statically sized matrix length must equal NC"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ if (nc() != length)
+ data.set_size(1,length);
+ }
+ else
+ {
+ DLIB_ASSERT(NR == 0 || NR == length,
+ "\tvoid matrix::set_size(length)"
+ << "\n\tSince this is a statically sized matrix length must equal NR"
+ << "\n\tlength: " << length
+ << "\n\tNR: " << NR
+ << "\n\tNC: " << NC
+ << "\n\tthis: " << this
+ );
+
+ if (nr() != length)
+ data.set_size(length,1);
+ }
+ }
+
+ long nr (
+ ) const { return data.nr(); }
+
+ long nc (
+ ) const { return data.nc(); }
+
+ long size (
+ ) const { return data.nr()*data.nc(); }
+
+ template <typename U, size_t len>
+ matrix& operator= (
+ U (&array)[len]
+ )
+ {
+ COMPILE_TIME_ASSERT(NR*NC == len && len > 0);
+ size_t idx = 0;
+ for (long r = 0; r < NR; ++r)
+ {
+ for (long c = 0; c < NC; ++c)
+ {
+ data(r,c) = static_cast<T>(array[idx]);
+ ++idx;
+ }
+ }
+ return *this;
+ }
+
+ template <typename EXP>
+ matrix& operator= (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You get an error on this line if the matrix you are trying to
+ // assign m to is a statically sized matrix and m's dimensions don't
+ // match that of *this.
+ COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
+ COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
+ DLIB_ASSERT((NR == 0 || nr() == m.nr()) &&
+ (NC == 0 || nc() == m.nc()),
+ "\tmatrix& matrix::operator=(const matrix_exp& m)"
+ << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size"
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tthis: " << this
+ );
+
+ // You get an error on this line if the matrix m contains a type that isn't
+ // the same as the type contained in the target matrix.
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true) ||
+ (is_matrix<typename EXP::type>::value == true));
+ if (m.destructively_aliases(*this) == false)
+ {
+ // This if statement is seemingly unnecessary since set_size() contains this
+ // exact same if statement. However, structuring the code this way causes
+ // gcc to handle the way it inlines this function in a much more favorable way.
+ if (data.nr() == m.nr() && data.nc() == m.nc())
+ {
+ matrix_assign(*this, m);
+ }
+ else
+ {
+ set_size(m.nr(),m.nc());
+ matrix_assign(*this, m);
+ }
+ }
+ else
+ {
+ // we have to use a temporary matrix object here because
+ // *this is aliased inside the matrix_exp m somewhere.
+ matrix temp;
+ temp.set_size(m.nr(),m.nc());
+ matrix_assign(temp, m);
+ temp.swap(*this);
+ }
+ return *this;
+ }
+
+ template <typename EXP>
+ matrix& operator += (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // The matrix you are trying to assign m to is a statically sized matrix and
+ // m's dimensions don't match that of *this.
+ COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
+ COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
+ if (nr() == m.nr() && nc() == m.nc())
+ {
+ if (m.destructively_aliases(*this) == false)
+ {
+ matrix_assign(*this, *this + m);
+ }
+ else
+ {
+ // we have to use a temporary matrix object here because
+ // this->data is aliased inside the matrix_exp m somewhere.
+ matrix temp;
+ temp.set_size(m.nr(),m.nc());
+ matrix_assign(temp, *this + m);
+ temp.swap(*this);
+ }
+ }
+ else
+ {
+ DLIB_ASSERT(size() == 0,
+ "\t const matrix::operator+=(m)"
+ << "\n\t You are trying to add two matrices that have incompatible dimensions.");
+ *this = m;
+ }
+ return *this;
+ }
+
+
+ template <typename EXP>
+ matrix& operator -= (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // The matrix you are trying to assign m to is a statically sized matrix and
+ // m's dimensions don't match that of *this.
+ COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0);
+ COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0);
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
+ if (nr() == m.nr() && nc() == m.nc())
+ {
+ if (m.destructively_aliases(*this) == false)
+ {
+ matrix_assign(*this, *this - m);
+ }
+ else
+ {
+ // we have to use a temporary matrix object here because
+ // this->data is aliased inside the matrix_exp m somewhere.
+ matrix temp;
+ temp.set_size(m.nr(),m.nc());
+ matrix_assign(temp, *this - m);
+ temp.swap(*this);
+ }
+ }
+ else
+ {
+ DLIB_ASSERT(size() == 0,
+ "\t const matrix::operator-=(m)"
+ << "\n\t You are trying to subtract two matrices that have incompatible dimensions.");
+ *this = -m;
+ }
+ return *this;
+ }
+
+ template <typename EXP>
+ matrix& operator *= (
+ const matrix_exp<EXP>& m
+ )
+ {
+ *this = *this * m;
+ return *this;
+ }
+
+ matrix& operator += (
+ const matrix& m
+ )
+ {
+ const long size = m.nr()*m.nc();
+ if (nr() == m.nr() && nc() == m.nc())
+ {
+ for (long i = 0; i < size; ++i)
+ data(i) += m.data(i);
+ }
+ else
+ {
+ DLIB_ASSERT(this->size() == 0,
+ "\t const matrix::operator+=(m)"
+ << "\n\t You are trying to add two matrices that have incompatible dimensions.");
+
+ set_size(m.nr(), m.nc());
+ for (long i = 0; i < size; ++i)
+ data(i) = m.data(i);
+ }
+ return *this;
+ }
+
+ matrix& operator -= (
+ const matrix& m
+ )
+ {
+ const long size = m.nr()*m.nc();
+ if (nr() == m.nr() && nc() == m.nc())
+ {
+ for (long i = 0; i < size; ++i)
+ data(i) -= m.data(i);
+ }
+ else
+ {
+ DLIB_ASSERT(this->size() == 0,
+ "\t const matrix::operator-=(m)"
+ << "\n\t You are trying to subtract two matrices that have incompatible dimensions.");
+ set_size(m.nr(), m.nc());
+ for (long i = 0; i < size; ++i)
+ data(i) = -m.data(i);
+ }
+ return *this;
+ }
+
+ matrix& operator += (
+ const T val
+ )
+ {
+ const long size = nr()*nc();
+ for (long i = 0; i < size; ++i)
+ data(i) += val;
+
+ return *this;
+ }
+
+ matrix& operator -= (
+ const T val
+ )
+ {
+ const long size = nr()*nc();
+ for (long i = 0; i < size; ++i)
+ data(i) -= val;
+
+ return *this;
+ }
+
+ matrix& operator *= (
+ const T a
+ )
+ {
+ *this = *this * a;
+ return *this;
+ }
+
+ matrix& operator /= (
+ const T a
+ )
+ {
+ *this = *this / a;
+ return *this;
+ }
+
+ matrix& operator= (
+ const matrix& m
+ )
+ {
+ if (this != &m)
+ {
+ set_size(m.nr(),m.nc());
+ const long size = m.nr()*m.nc();
+ for (long i = 0; i < size; ++i)
+ data(i) = m.data(i);
+ }
+ return *this;
+ }
+
+ void swap (
+ matrix& item
+ )
+ {
+ data.swap(item.data);
+ }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ bool aliases (
+ const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
+ ) const { return (this == &item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ // These two aliases() routines are defined in matrix_mat.h
+ bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
+ ) const;
+ bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
+ ) const;
+
+ iterator begin()
+ {
+ if (size() != 0)
+ return &data(0,0);
+ else
+ return 0;
+ }
+
+ iterator end()
+ {
+ if (size() != 0)
+ return &data(0,0)+size();
+ else
+ return 0;
+ }
+
+ const_iterator begin() const
+ {
+ if (size() != 0)
+ return &data(0,0);
+ else
+ return 0;
+ }
+
+ const_iterator end() const
+ {
+ if (size() != 0)
+ return &data(0,0)+size();
+ else
+ return 0;
+ }
+
+ private:
+ struct literal_assign_helper
+ {
+ /*
+ This struct is a helper struct returned by the operator<<() function below. It is
+ used primarily to enable us to put DLIB_CASSERT statements on the usage of the
+ operator<< form of matrix assignment.
+ */
+
+ literal_assign_helper(const literal_assign_helper& item) : m(item.m), r(item.r), c(item.c), has_been_used(false) {}
+ explicit literal_assign_helper(matrix* m_): m(m_), r(0), c(0),has_been_used(false) {next();}
+ ~literal_assign_helper() noexcept(false)
+ {
+ DLIB_CASSERT(!has_been_used || r == m->nr(),
+ "You have used the matrix comma based assignment incorrectly by failing to\n"
+ "supply a full set of values for every element of a matrix object.\n");
+ }
+
+ const literal_assign_helper& operator, (
+ const T& val
+ ) const
+ {
+ DLIB_CASSERT(r < m->nr() && c < m->nc(),
+ "You have used the matrix comma based assignment incorrectly by attempting to\n" <<
+ "supply more values than there are elements in the matrix object being assigned to.\n\n" <<
+ "Did you forget to call set_size()?"
+ << "\n\t r: " << r
+ << "\n\t c: " << c
+ << "\n\t m->nr(): " << m->nr()
+ << "\n\t m->nc(): " << m->nc());
+ (*m)(r,c) = val;
+ next();
+ has_been_used = true;
+ return *this;
+ }
+
+ private:
+
+ friend class matrix<T,num_rows,num_cols,mem_manager,layout>;
+
+ void next (
+ ) const
+ {
+ ++c;
+ if (c == m->nc())
+ {
+ c = 0;
+ ++r;
+ }
+ }
+
+ matrix* m;
+ mutable long r;
+ mutable long c;
+ mutable bool has_been_used;
+ };
+
+ public:
+
+ matrix& operator = (
+ const literal_assign_helper& val
+ )
+ {
+ *this = *val.m;
+ return *this;
+ }
+
+ const literal_assign_helper operator = (
+ const T& val
+ )
+ {
+ // assign the given value to every spot in this matrix
+ const long size = nr()*nc();
+ for (long i = 0; i < size; ++i)
+ data(i) = val;
+
+ // Now return the literal_assign_helper so that the user
+ // can use the overloaded comma notation to initialize
+ // the matrix if they want to.
+ return literal_assign_helper(this);
+ }
+
+ private:
+
+
+ typename layout::template layout<T,NR,NC,mem_manager> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void swap(
+ matrix<T,NR,NC,mm,l>& a,
+ matrix<T,NR,NC,mm,l>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void serialize (
+ const matrix<T,NR,NC,mm,l>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // The reason the serialization is a little funny is because we are trying to
+ // maintain backwards compatibility with an older serialization format used by
+ // dlib while also encoding things in a way that lets the array2d and matrix
+ // objects have compatible serialization formats.
+ serialize(-item.nr(),out);
+ serialize(-item.nc(),out);
+ for (long r = 0; r < item.nr(); ++r)
+ {
+ for (long c = 0; c < item.nc(); ++c)
+ {
+ serialize(item(r,c),out);
+ }
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing dlib::matrix");
+ }
+ }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void deserialize (
+ matrix<T,NR,NC,mm,l>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ long nr, nc;
+ deserialize(nr,in);
+ deserialize(nc,in);
+
+ // this is the newer serialization format
+ if (nr < 0 || nc < 0)
+ {
+ nr *= -1;
+ nc *= -1;
+ }
+
+ if (NR != 0 && nr != NR)
+ throw serialization_error("Error while deserializing a dlib::matrix. Invalid rows");
+ if (NC != 0 && nc != NC)
+ throw serialization_error("Error while deserializing a dlib::matrix. Invalid columns");
+
+ item.set_size(nr,nc);
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ deserialize(item(r,c),in);
+ }
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing a dlib::matrix");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void serialize (
+ const ramdump_t<matrix<T,NR,NC,mm,l>>& item_,
+ std::ostream& out
+ )
+ {
+ auto& item = item_.item;
+ serialize(item.nr(), out);
+ serialize(item.nc(), out);
+ if (item.size() != 0)
+ out.write((char*)&item(0,0), sizeof(item(0,0))*item.size());
+ }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void deserialize (
+ ramdump_t<matrix<T,NR,NC,mm,l>>&& item_,
+ std::istream& in
+ )
+ {
+ auto& item = item_.item;
+ long nr, nc;
+ deserialize(nr, in);
+ deserialize(nc, in);
+ item.set_size(nr,nc);
+ if (item.size() != 0)
+ in.read((char*)&item(0,0), sizeof(item(0,0))*item.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ std::ostream& operator<< (
+ std::ostream& out,
+ const matrix_exp<EXP>& m
+ )
+ {
+ using namespace std;
+ const streamsize old = out.width();
+
+ // first figure out how wide we should make each field
+ string::size_type w = 0;
+ ostringstream sout;
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ sout << m(r,c);
+ w = std::max(sout.str().size(),w);
+ sout.str("");
+ }
+ }
+
+ // now actually print it
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ out.width(static_cast<streamsize>(w));
+ out << m(r,c) << " ";
+ }
+ out << "\n";
+ }
+ out.width(old);
+ return out;
+ }
+
+ /*
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ std::istream& operator>> (
+ std::istream& in,
+ matrix<T,NR,NC,MM,L>& m
+ );
+
+ This function is defined inside the matrix_read_from_istream.h file.
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ class print_matrix_as_csv_helper
+ {
+ /*!
+ This object is used to define an io manipulator for matrix expressions.
+ In particular, this code allows you to write statements like:
+ cout << csv << yourmatrix;
+ and have it print the matrix with commas separating each element.
+ !*/
+ public:
+ print_matrix_as_csv_helper (std::ostream& out_) : out(out_) {}
+
+ template <typename EXP>
+ std::ostream& operator<< (
+ const matrix_exp<EXP>& m
+ )
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ if (c+1 == m.nc())
+ out << m(r,c) << "\n";
+ else
+ out << m(r,c) << ", ";
+ }
+ }
+ return out;
+ }
+
+ private:
+ std::ostream& out;
+ };
+
+ class print_matrix_as_csv {};
+ const print_matrix_as_csv csv = print_matrix_as_csv();
+ inline print_matrix_as_csv_helper operator<< (
+ std::ostream& out,
+ const print_matrix_as_csv&
+ )
+ {
+ return print_matrix_as_csv_helper(out);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ class const_temp_matrix;
+
+ template <
+ typename EXP
+ >
+ struct matrix_traits<const_temp_matrix<EXP> >
+ {
+ typedef typename EXP::type type;
+ typedef typename EXP::const_ret_type const_ret_type;
+ typedef typename EXP::mem_manager_type mem_manager_type;
+ typedef typename EXP::layout_type layout_type;
+ const static long NR = EXP::NR;
+ const static long NC = EXP::NC;
+ const static long cost = 1;
+ };
+
+ template <typename EXP>
+ class const_temp_matrix : public matrix_exp<const_temp_matrix<EXP> >, noncopyable
+ {
+ public:
+ typedef typename matrix_traits<const_temp_matrix>::type type;
+ typedef typename matrix_traits<const_temp_matrix>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<const_temp_matrix>::mem_manager_type mem_manager_type;
+ typedef typename matrix_traits<const_temp_matrix>::layout_type layout_type;
+ const static long NR = matrix_traits<const_temp_matrix>::NR;
+ const static long NC = matrix_traits<const_temp_matrix>::NC;
+ const static long cost = matrix_traits<const_temp_matrix>::cost;
+
+ const_temp_matrix (
+ const matrix_exp<EXP>& item
+ ) :
+ ref_(item.ref())
+ {}
+ const_temp_matrix (
+ const EXP& item
+ ) :
+ ref_(item)
+ {}
+
+ const_ret_type operator() (
+ long r,
+ long c
+ ) const { return ref_(r,c); }
+
+ const_ret_type operator() ( long i ) const
+ { return ref_(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return ref_.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return ref_.destructively_aliases(item); }
+
+ long nr (
+ ) const { return ref_.nr(); }
+
+ long nc (
+ ) const { return ref_.nc(); }
+
+ private:
+
+ typename conditional_matrix_temp<const EXP, (EXP::cost <= 1)>::type ref_;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
+ typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
+
+}
+
+#ifdef _MSC_VER
+// put warnings back to their default settings
+#pragma warning(default : 4355)
+#pragma warning(default : 4723)
+#endif
+
+#endif // DLIB_MATRIx_
+
diff --git a/ml/dlib/dlib/matrix/matrix_abstract.h b/ml/dlib/dlib/matrix/matrix_abstract.h
new file mode 100644
index 000000000..0d05ce981
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_abstract.h
@@ -0,0 +1,857 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_ABSTRACT_
+#ifdef DLIB_MATRIx_ABSTRACT_
+
+#include "matrix_exp_abstract.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include "matrix_data_layout_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ Note that these operator prototypes are not correct C++ (the real versions, which
+ you can see in the implementation are really complex and so probably would
+ distract/confuse people if shown here). Think of this as just a list of the
+ operators available to you and what they do.
+ */
+
+ const matrix_exp operator* (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1.nc() == m2.nr()
+ - m1.size() > 0 && m2.size() > 0
+ (you can't multiply any sort of empty matrices together)
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns the result of doing the matrix multiplication m1*m2. The resulting
+ matrix will have m1.nr() rows and m2.nc() columns.
+ !*/
+
+ const matrix_exp operator+ (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1.nr() == m2.nr()
+ - m1.nc() == m2.nc()
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that for all valid r and c:
+ R(r,c) == m1(r,c) + m2(r,c)
+ (i.e. returns the result of doing a pairwise addition of the matrices m1 and m2.)
+ The resulting matrix will have the same dimensions as the originals.
+ !*/
+
+ const matrix_exp operator- (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1.nr() == m2.nr()
+ - m1.nc() == m2.nc()
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that for all valid r and c:
+ R(r,c) == m1(r,c) - m2(r,c)
+ (i.e. returns the result of doing a pairwise subtraction of the matrices m1 and m2.)
+ The resulting matrix will have the same dimensions as the originals.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator* (
+ const matrix_exp& m,
+ const T& value
+ );
+ /*!
+ ensures
+ - returns the result of multiplying all the elements of matrix m by the given
+ scalar value. The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator* (
+ const T& value,
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the result of multiplying all the elements of matrix m by the given
+ scalar value. The resulting matrix will have the same dimensions as m.
+ !*/
+
+ const matrix_exp operator- (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns -1*m
+ !*/
+
+ template <typename T>
+ const matrix_exp operator/ (
+ const matrix_exp& m,
+ const T& value
+ );
+ /*!
+ ensures
+ - returns the result of dividing all the elements of matrix m by the given
+ scalar value. The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator/ (
+ const T& value,
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the result of dividing the given scalar value by all the elements
+ of matrix m. The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator+ (
+ const matrix_exp& m,
+ const T& value
+ );
+ /*!
+ ensures
+ - returns the result of adding value to all the elements of matrix m.
+ The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator+ (
+ const T& value,
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the result of adding value to all the elements of matrix m.
+ The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator- (
+ const matrix_exp& m,
+ const T& value
+ );
+ /*!
+ ensures
+ - returns the result of subtracting value from all the elements of matrix m.
+ The resulting matrix will have the same dimensions as m.
+ !*/
+
+ template <typename T>
+ const matrix_exp operator- (
+ const T& value,
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - Returns a matrix M such that:
+ - M has the same dimensions as m
+ - M contains the same type of element as m
+ - for all valid r and c:
+ - M(r,c) == value - m(r,c)
+ !*/
+
+ bool operator== (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ ensures
+ - if (m1.nr() == m2.nr() && m1.nc() == m2.nc() &&
+ for all valid r and c: m1(r,c) == m2(r,c) ) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator!= (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ ensures
+ - returns !(m1 == m2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows = 0,
+ long num_cols = 0,
+ typename mem_manager = default_memory_manager,
+ typename layout = row_major_layout
+ >
+ class matrix : public matrix_exp<matrix<T,num_rows,num_cols,mem_manager,layout> >
+ {
+ /*!
+ REQUIREMENTS ON num_rows and num_cols
+ both must be bigger than or equal to 0
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ REQUIREMENTS ON layout
+ Must be either row_major_layout or column_major_layout
+
+ INITIAL VALUE
+ - if (num_rows > 0) then
+ - nr() == num_rows
+ - else
+ - nr() == 0
+
+ - if (num_cols > 0) then
+ - nc() == num_cols
+ - else
+ - nc() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a matrix of nr() rows and nc() columns. This object
+ is also a matrix_exp. Thus it can be used in all of the above
+ global operators.
+
+ The number of rows and columns of this object are determined by the template
+ arguments num_rows and num_cols. If num_rows or num_cols are 0 then
+ the matrix starts out empty (i.e. nr() == 0 and nc() == 0) and you may change
+ its size via the set_size() member function.
+
+ Setting num_rows or num_cols to something other than 0 causes that dimension
+ to have a fixed size. Setting a fixed size at compile time is useful because
+ any errors related to operating on matrices with incompatible dimensions will
+ be detected at compile time. It also allows the compiler to perform loop
+ unrolling which can result in substantially faster code.
+
+ Also note that the elements of this matrix are laid out in memory by the layout
+ object supplied as a template argument to this class. The row_major_layout
+ sets elements down contiguously in memory and in row major order. Additionally,
+ all memory allocations are performed using the memory manager object supplied as
+ a template argument to this class.
+ !*/
+
+ public:
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+ typedef layout layout_type;
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+ const static long cost = 1;
+ typedef T* iterator;
+ typedef const T* const_iterator;
+
+ matrix (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ !*/
+
+ explicit matrix (
+ long length
+ );
+ /*!
+ requires
+ - NR == 1 || NC == 1 (i.e. this must be a column or row vector)
+ - length >= 0
+ - if (NR == 1 && NC > 0) then
+ - length == NC
+ - if (NC == 1 && NR > 0) then
+ - length == NR
+ ensures
+ - #*this is properly initialized
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ - if (NR == 1) then
+ - #nr() == 1
+ - #nc() == length
+ - else
+ - #nr() == length
+ - #nc() == 1
+ !*/
+
+ matrix (
+ long rows,
+ long cols
+ );
+ /*!
+ requires
+ - rows == NR || NR == 0
+ - cols == NC || NC == 0
+ - rows >= 0 && cols >= 0
+ ensures
+ - #*this is properly initialized
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ - #nr() == rows
+ - #nc() == cols
+ !*/
+
+ template <typename EXP>
+ matrix (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - matrix_exp<EXP>::type == T
+ (i.e. m contains the same type as *this does)
+ - if (NR != 0) then NR == m.nr()
+ - if (NC != 0) then NC == m.nc()
+ ensures
+ - #*this == m
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ !*/
+
+ template <typename U, size_t len>
+ explicit matrix (
+ U (&array)[len]
+ );
+ /*!
+ requires
+ - NR != 0 && NC != 0 (i.e. you can only use this constructor on statically sized matrices)
+ - len == nr()*nc() (i.e. the array you give here must be the right size)
+ ensures
+ - for all valid r and c:
+ #(*this)(r,c) == array[r*nc() + c]
+ (i.e. initializes this matrix with the contents of the given array)
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ !*/
+
+ matrix(
+ const std::initializer_list<T>& l
+ );
+ /*!
+ requires
+ - This matrix is capable of having a size() == l.size(). Therefore, if
+ NR*NC != 0 then l.size() must equal NR*NC. Alternatively, if NR or NC is
+ != 0 then l.size() must be a multiple of the non-zero NR or NC.
+ ensures
+ - #size() == l.size()
+ - The contents of l are enumerated and read into the matrix in row major order.
+ - if (NR != 0) then
+ - #nr() == NR
+ - #nc() == l.size()/NR
+ - if (NC != 0) then
+ - #nr() == l.size()/NC
+ - #nc() == NC
+ - if (NR*NC==0) then
+ - #nr() == l.size()
+ - #nc() == 1
+ - #aliases(*this) == true
+ - #ref().aliases(*this) == true
+ !*/
+
+ T& operator() (
+ long r,
+ long c
+ );
+ /*!
+ requires
+ - 0 <= r < nr()
+ - 0 <= c < nc()
+ ensures
+ - returns a reference to the value at the given row and column in
+ this matrix.
+ !*/
+
+ const T& operator() (
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= r < nr()
+ - 0 <= c < nc()
+ ensures
+ - returns a const reference to the value at the given row and column in
+ this matrix.
+ !*/
+
+ T& operator() (
+ long i
+ );
+ /*!
+ requires
+ - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector)
+ - 0 <= i < size()
+ ensures
+ - if (nc() == 1) then
+ - returns a reference to (*this)(i,0)
+ - else
+ - returns a reference to (*this)(0,i)
+ !*/
+
+ const T& operator() (
+ long i
+ ) const;
+ /*!
+ requires
+ - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector)
+ - 0 <= i < size()
+ ensures
+ - if (nc() == 1) then
+ - returns a reference to (*this)(i,0)
+ - else
+ - returns a reference to (*this)(0,i)
+ !*/
+
+ operator const type (
+ ) const;
+ /*!
+ requires
+ - nr() == 1
+ - nc() == 1
+ ensures
+ - returns (*this)(0,0)
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this matrix
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this matrix
+ !*/
+
+ long size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ !*/
+
+ void set_size (
+ long rows,
+ long cols
+ );
+ /*!
+ requires
+ - rows == NR || NR == 0
+ - cols == NC || NC == 0
+ - rows >= 0 && cols >= 0
+ ensures
+ - #nr() == rows
+ - #nc() == cols
+ !*/
+
+ void set_size (
+ long length
+ );
+ /*!
+ requires
+ - NR == 1 || NC == 1 (i.e. this must be a column or row vector)
+ - length >= 0
+ - if (NR == 1 && NC > 0) then
+ - length == NC
+ - if (NC == 1 && NR > 0) then
+ - length == NR
+ ensures
+ - if (NR == 1) then
+ - #nr() == 1
+ - #nc() == length
+ - else
+ - #nr() == length
+ - #nc() == 1
+ !*/
+
+ template <typename U, size_t len>
+ matrix& operator= (
+ U (&array)[len]
+ );
+ /*!
+ requires
+ - len == nr()*nc() (i.e. the array you give here must be the right size)
+ ensures
+ - for all valid r and c:
+ #(*this)(r,c) == array[r*nc() + c]
+ (i.e. loads this matrix with the contents of the given array)
+ - returns *this
+ !*/
+
+ matrix& operator=(
+ const std::initializer_list<T>& l
+ );
+ /*!
+ requires
+ - This matrix is capable of having a size() == l.size(). Therefore, if
+ NR*NC != 0 then l.size() must equal NR*NC. Alternatively, if NR or NC is
+ != 0 then l.size() must be a multiple of the non-zero NR or NC.
+ ensures
+ - Assigns the contents of l to *this by performing: matrix(l).swap(*this)
+ - returns *this
+ !*/
+
+ template <typename EXP>
+ matrix& operator= (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - matrix_exp<EXP>::type == T
+ (i.e. m contains the same type as *this does)
+ - if (NR != 0) then NR == m.nr()
+ - if (NC != 0) then NC == m.nc()
+ ensures
+ - copies the given matrix expression m to *this
+ - returns *this
+ !*/
+
+ template <typename EXP>
+ matrix& operator += (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - matrix_exp<EXP>::type == T
+ - One of the following is true:
+ - nr() == m.nr() && nc() == m.nc()
+ - size() == 0
+ (i.e. this matrix must have matching dimensions or it must be empty)
+ ensures
+ - if (nr() == m.nr() && nc() == m.nc()) then
+ - #(*this) == *this + m
+ - else
+ - #(*this) == m
+ (i.e. if the dimensions don't match then this function performs a
+ normal assignment)
+ - returns *this
+ !*/
+
+ template <typename EXP>
+ matrix& operator -= (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - matrix_exp<EXP>::type == T
+ - One of the following is true:
+ - nr() == m.nr() && nc() == m.nc()
+ - size() == 0
+ (i.e. this matrix must have matching dimensions or it must be empty)
+ ensures
+ - if (nr() == m.nr() && nc() == m.nc()) then
+ - #(*this) == *this - m
+ - else
+ - #(*this) == -m
+ - returns *this
+ !*/
+
+ template <typename EXP>
+ matrix& operator *= (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ requires
+ - matrix_exp<EXP>::type == T
+ (i.e. m must contain the same type of element as *this)
+ - nc() == m.nr()
+ - size() > 0 && m.size() > 0
+ (you can't multiply any sort of empty matrices together)
+ ensures
+ - #(*this) == *this * m
+ - returns *this
+ !*/
+
+ matrix& operator *= (
+ const T& a
+ );
+ /*!
+ ensures
+ - #(*this) == *this * a
+ - returns *this
+ !*/
+
+ matrix& operator /= (
+ const T& a
+ );
+ /*!
+ ensures
+ - #(*this) == *this / a
+ - returns *this
+ !*/
+
+ matrix& operator += (
+ const T& a
+ );
+ /*!
+ ensures
+ - #(*this) == *this + a
+ - returns *this
+ !*/
+
+ matrix& operator -= (
+ const T& a
+ );
+ /*!
+ ensures
+ - #(*this) == *this - a
+ - returns *this
+ !*/
+
+ const literal_assign_helper operator = (
+ const T& val
+ );
+ /*!
+ This function is somewhat different than all the others defined in this file.
+ The purpose of this function is to enable you to easily initialize a matrix object.
+ For example:
+ matrix<double> m(2,3);
+ m = 1,2,3,
+ 4,5,6;
+
+ The above code creates a matrix m with 2 rows and 3 columns and sets it so that
+ it contains the matrix | 1 2 3 |
+ | 4 5 6 |
+
+ You can also use this function to assign to all elements of a matrix. So
+ saying m = 3; would assign all elements of m equal to 3.
+
+ Note that to use this method of assignment it is required that you supply
+ exactly m.size() or 1 values so that the matrix is fully initialized. Supplying
+ fewer or more than that is an error that will cause a dlib::fatal_error to be
+ thrown.
+
+ Note also that using an expression of the form m = scalar; when m.size() == 0
+ is legal but has no effect on m.
+ !*/
+
+ void swap (
+ matrix& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ iterator begin(
+ );
+ /*!
+ ensures
+ - returns a random access iterator pointing to the first element in this
+ matrix.
+ - The iterator will iterate over the elements of the matrix in row major
+ order if layout is row_major_layout or in column major order if layout is
+ column_major_layout.
+ !*/
+
+ iterator end(
+ );
+ /*!
+ ensures
+ - returns a random access iterator pointing to one past the end of the last
+ element in this matrix.
+ !*/
+
+ const_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - returns a random access iterator pointing to the first element in this
+ matrix.
+ - The iterator will iterate over the elements of the matrix in row major
+ order if layout is row_major_layout or in column major order if layout is
+ column_major_layout.
+ !*/
+
+ const_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns a random access iterator pointing to one past the end of the last
+ element in this matrix.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A matrix_colmajor
+ This is just a typedef of the matrix object that uses column major layout.
+ !*/
+ typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
+
+ /*!A fmatrix_colmajor
+ This is just a typedef of the matrix object that uses column major layout.
+ !*/
+ typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
+
+// ----------------------------------------------------------------------------------------
+template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void swap(
+ matrix<T,NR,NC,mm,l>& a,
+ matrix<T,NR,NC,mm,l>& b
+ ) { a.swap(b); }
+ /*!
+ Provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void serialize (
+ const matrix<T,NR,NC,mm,l>& item,
+ std::ostream& out
+ );
+ /*!
+ Provides serialization support. Note that the serialization formats used by the
+ dlib::matrix and dlib::array2d objects are compatible. That means you can load the
+ serialized data from one into another and it will work properly.
+ !*/
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename mm,
+ typename l
+ >
+ void deserialize (
+ matrix<T,NR,NC,mm,l>& item,
+ std::istream& in
+ );
+ /*!
+ Provides deserialization support
+ !*/
+
+ template <
+ typename EXP
+ >
+ std::ostream& operator<< (
+ std::ostream& out,
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ ensures
+ - writes m to the given out stream in a form suitable for human consumption.
+ - returns out
+ !*/
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ std::istream& operator>> (
+ std::istream& in,
+ matrix<T,NR,NC,MM,L>& m
+ );
+ /*!
+ ensures
+ - Tries to read a matrix from the given input stream and store it into #m.
+ - The format expected is the text format output by the above operator<<().
+ That is, the format should be a grid of text such as:
+ 2 3 4
+ 5 2 6
+ - The separation between numbers can be any number of whitespace characters or
+ commas.
+ - The matrix data is assumed to end upon the first blank line or end-of-file,
+ whichever comes first. This means you can create an input stream with
+ multiple matrices in it by separating them with empty lines.
+ - returns in.
+ - If there was a formatting error or something which prevents the input data
+ from being parsed into a matrix then #in.fail() == true.
+ !*/
+
+ /*!A csv
+ This object is used to define an io manipulator for matrix expressions. In
+ particular, you can write statements like:
+ cout << csv << yourmatrix;
+ and have it print the matrix with commas separating each element.
+ !*/
+ some_undefined_iomnaip_type csv;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ class const_temp_matrix : public matrix_exp<const_temp_matrix<EXP> >, noncopyable
+ {
+ /*!
+ REQUIREMENTS ON EXP
+ - must be an object that inherits publicly from matrix_exp.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a copy of a matrix expression. The twist
+ is that it only actually makes a copy of its input matrix expression
+ if that matrix expression is costly to evaluate. If it has
+ low cost then this object just stores a reference.
+
+ This class is useful in cases where you write a function that
+ takes a matrix_exp object as input and you want to do some
+ intensive computation that looks at each element of that matrix_exp
+ many times. If the input matrix_exp has a high cost then you want
+ to store it into a temporary matrix. But if it has low cost then
+ it is faster if you just use a reference to it. The const_temp_matrix
+ makes doing this easy.
+ !*/
+ public:
+
+ const_temp_matrix (
+ const matrix_exp<EXP>& item
+ );
+ /*!
+ ensures
+ - #*this == item
+ - if (EXP::cost <= 1) then
+ - this const_temp_matrix stores a reference to the item matrix
+ - else
+ - this const_temp_matrix creates a temporary matrix and copies
+ item into it
+ !*/
+
+ const_temp_matrix (
+ const EXP& item
+ );
+ /*!
+ ensures
+ - #*this == item
+ - if (EXP::cost <= 1) then
+ - this const_temp_matrix stores a reference to the item matrix
+ - else
+ - this const_temp_matrix creates a temporary matrix and copies
+ item into it
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_ABSTRACT_
+
diff --git a/ml/dlib/dlib/matrix/matrix_assign.h b/ml/dlib/dlib/matrix/matrix_assign.h
new file mode 100644
index 000000000..da53050b1
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_assign.h
@@ -0,0 +1,978 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_ASSIGn_
+#define DLIB_MATRIx_ASSIGn_
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+#include "../enable_if.h"
+#include "matrix_assign_fwd.h"
+#include "matrix_default_mul.h"
+#include "matrix_conj_trans.h"
+#include "matrix_mat.h"
+
+namespace dlib
+{
+ /*
+ This file contains some templates that are used inside the matrix_blas_bindings.h
+ file to bind various matrix expressions to optimized code for carrying them out.
+ */
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace blas_bindings
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ void zero_matrix (
+ T& m
+ )
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = 0;
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This template struct is used to tell us if a matrix expression contains a matrix multiply.
+ template <typename T>
+ struct has_matrix_multiply
+ {
+ const static bool value = false;
+ };
+
+ template <typename T, typename U>
+ struct has_matrix_multiply<matrix_multiply_exp<T,U> >
+ { const static bool value = true; };
+
+ template <typename T, typename U>
+ struct has_matrix_multiply<matrix_add_exp<T,U> >
+ { const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
+
+ template <typename T, typename U>
+ struct has_matrix_multiply<matrix_subtract_exp<T,U> >
+ { const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
+
+ template <typename T, bool Tb>
+ struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> >
+ { const static bool value = true; };
+
+ template <typename T>
+ struct has_matrix_multiply<matrix_div_scal_exp<T> >
+ { const static bool value = has_matrix_multiply<T>::value; };
+
+ template <typename T>
+ struct has_matrix_multiply<matrix_op<T> >
+ { const static bool value = has_matrix_multiply<T>::value; };
+
+ template <typename T>
+ struct has_matrix_multiply<op_trans<T> >
+ { const static bool value = has_matrix_multiply<T>::value; };
+
+ template <typename T>
+ struct has_matrix_multiply<op_conj_trans<T> >
+ { const static bool value = has_matrix_multiply<T>::value; };
+
+ template <typename T>
+ struct has_matrix_multiply<op_conj<T> >
+ { const static bool value = has_matrix_multiply<T>::value; };
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ const int unknown_matrix = 0;
+ const int general_matrix = 1;
+ const int row_matrix = 2;
+ const int column_matrix = 3;
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct matrix_type_id
+ {
+ const static int value = unknown_matrix;
+ };
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix<T,NR,NC,MM,L> >
+ {
+ const static int value = general_matrix;
+ };
+
+ template <typename T, long NR, typename MM, typename L>
+ struct matrix_type_id<matrix<T,NR,1,MM,L> >
+ {
+ const static int value = column_matrix;
+ };
+
+ template <typename T, typename MM, typename L>
+ struct matrix_type_id<matrix<T,1,1,MM,L> >
+ {
+ const static int value = column_matrix;
+ };
+
+ template <typename T, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix<T,1,NC,MM,L> >
+ {
+ const static int value = row_matrix;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix_op<op_colm<matrix<T,NR,NC,MM,L> > > >
+ {
+ const static int value = column_matrix;
+ };
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix_op<op_rowm<matrix<T,NR,NC,MM,L> > > >
+ {
+ const static int value = row_matrix;
+ };
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix_op<op_colm2<matrix<T,NR,NC,MM,L> > > >
+ {
+ const static int value = column_matrix;
+ };
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix_op<op_rowm2<matrix<T,NR,NC,MM,L> > > >
+ {
+ const static int value = row_matrix;
+ };
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ struct matrix_type_id<matrix_op<op_subm<matrix<T,NR,NC,MM,L> > > >
+ {
+ const static int value = general_matrix;
+ };
+
+ template < typename T, typename MM >
+ struct matrix_type_id<matrix_op<op_array2d_to_mat<array2d<T,MM> > > >
+ { const static int value = general_matrix; };
+
+ template < typename T, typename MM >
+ struct matrix_type_id<matrix_op<op_array_to_mat<array<T,MM> > > >
+ { const static int value = column_matrix; };
+
+ template < typename value_type, typename alloc >
+ struct matrix_type_id<matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > >
+ { const static int value = column_matrix; };
+
+ template < typename value_type, typename alloc >
+ struct matrix_type_id<matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > >
+ { const static int value = column_matrix; };
+
+ template < typename T >
+ struct matrix_type_id<matrix_op<op_pointer_to_col_vect<T> > >
+ { const static int value = column_matrix; };
+ template < typename T >
+ struct matrix_type_id<matrix_op<op_pointer_to_mat<T> > >
+ { const static int value = general_matrix; };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ struct same_matrix
+ {
+ const static int T_id = matrix_type_id<T>::value;
+ const static int U_id = matrix_type_id<U>::value;
+ // The check for unknown_matrix is here so that we can be sure that matrix types
+ // other than the ones specifically enumerated above never get pushed into
+ // any of the BLAS bindings. So saying they are never the same as anything
+ // else prevents them from matching any of the BLAS bindings.
+ const static bool value = (T_id == U_id) && (T_id != unknown_matrix);
+ };
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ // This template struct is used to tell us if two matrix expressions both contain the same
+ // sequence of operators, expressions. It also only has a value of true if the T expression
+ // contains only matrices with the given layout.
+ template <typename T, typename U, typename layout>
+ struct same_exp
+ {
+ const static bool value = (is_same_type<typename T::exp_type, typename U::exp_type>::value ||
+ same_matrix<typename T::exp_type, typename U::exp_type>::value) &&
+ is_same_type<typename T::layout_type,layout>::value;
+
+ };
+
+ // Used only below. They help strip off the const and & qualifiers that can show up
+ // in the LHS_ref_type and RHS_ref_type typedefs.
+ template <typename T> struct noref{ typedef T type;};
+ template <typename T> struct noref<T&>{ typedef T type;};
+ template <typename T> struct noref<const T&>{ typedef T type;};
+ template <typename T> struct noref<const T>{ typedef T type;};
+
+ template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
+ struct same_exp<matrix_multiply_exp<Tlhs,Trhs>, matrix_multiply_exp<Ulhs,Urhs>,layout >
+ {
+ // The reason this case is more complex than the others is because the matrix_multiply_exp
+ // will use a temporary matrix instead of Tlhs or Trhs in the event that one of these
+ // types corresponds to an expensive expression. So we have to use the type that really
+ // gets used. The following typedefs are here to pick out that true type.
+ typedef typename matrix_multiply_exp<Tlhs,Trhs>::LHS_ref_type T_LHS_ref_type;
+ typedef typename matrix_multiply_exp<Tlhs,Trhs>::RHS_ref_type T_RHS_ref_type;
+ typedef typename noref<T_LHS_ref_type>::type T_lhs_type;
+ typedef typename noref<T_RHS_ref_type>::type T_rhs_type;
+
+ typedef typename matrix_multiply_exp<Ulhs,Urhs>::LHS_ref_type U_LHS_ref_type;
+ typedef typename matrix_multiply_exp<Ulhs,Urhs>::RHS_ref_type U_RHS_ref_type;
+ typedef typename noref<U_LHS_ref_type>::type U_lhs_type;
+ typedef typename noref<U_RHS_ref_type>::type U_rhs_type;
+
+ const static bool value = same_exp<T_lhs_type,U_lhs_type,layout>::value &&
+ same_exp<T_rhs_type,U_rhs_type,layout>::value;
+ };
+
+ template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
+ struct same_exp<matrix_add_exp<Tlhs,Trhs>, matrix_add_exp<Ulhs,Urhs>, layout >
+ { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; };
+
+ template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
+ struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs>, layout >
+ { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; };
+
+ template <typename T, typename U, bool Tb, bool Ub, typename layout>
+ struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub>, layout >
+ { const static bool value = same_exp<T,U,layout>::value; };
+
+ template <typename T, typename U, typename layout>
+ struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U>, layout >
+ { const static bool value = same_exp<T,U,layout>::value; };
+
+ template <typename T, typename U, typename layout>
+ struct same_exp<matrix_op<op_trans<T> >, matrix_op<op_trans<U> >, layout >
+ { const static bool value = same_exp<T,U,layout>::value; };
+
+ template <typename T, typename U, typename layout>
+ struct same_exp<matrix_op<op_conj<T> >, matrix_op<op_conj<U> >, layout >
+ { const static bool value = same_exp<T,U,layout>::value; };
+
+ template <typename T, typename U, typename layout>
+ struct same_exp<matrix_op<op_conj_trans<T> >, matrix_op<op_conj_trans<U> >, layout >
+ { const static bool value = same_exp<T,U,layout>::value; };
+
+ // ------------------------------------------------------------------------------------
+
+ struct yes_type
+ {
+ char ch;
+ };
+ struct no_type
+ {
+ yes_type a, b;
+ };
+
+ // This is a helper that is used below to apply the same_exp template to matrix expressions.
+ template <typename T, typename layout, typename U>
+ typename enable_if<same_exp<T,U,layout>,yes_type>::type test(U);
+ template <typename T, typename layout, typename U>
+ typename disable_if<same_exp<T,U,layout>,no_type>::type test(U);
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp,
+ typename enabled = void
+ >
+ struct matrix_assign_blas_helper
+ {
+ // We are in the default version of the blas helper so this
+ // means there wasn't any more specific overload. So just
+ // let the default matrix assignment happen.
+ template <typename EXP>
+ static void assign (
+ dest_exp& dest,
+ const EXP& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ if (transpose == false)
+ matrix_assign_default(dest,src,alpha,add_to);
+ else
+ matrix_assign_default(dest,trans(src),alpha,add_to);
+ }
+
+ // If we know this is a matrix multiply then apply the
+ // default dlib matrix multiply to speed things up a bit more
+ // than the above default function would.
+ template <typename EXP1, typename EXP2>
+ static void assign (
+ dest_exp& dest,
+ const matrix_multiply_exp<EXP1,EXP2>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ // At some point I need to improve the default (i.e. non BLAS) matrix
+ // multiplication algorithm...
+
+ if (alpha == static_cast<typename src_exp::type>(1))
+ {
+ if (add_to == false)
+ {
+ zero_matrix(dest);
+ }
+
+ if (transpose == false)
+ default_matrix_multiply(dest, src.lhs, src.rhs);
+ else
+ default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs));
+ }
+ else
+ {
+ if (add_to)
+ {
+ typename dest_exp::matrix_type temp(dest.nr(),dest.nc());
+ zero_matrix(temp);
+
+ if (transpose == false)
+ default_matrix_multiply(temp, src.lhs, src.rhs);
+ else
+ default_matrix_multiply(temp, trans(src.rhs), trans(src.lhs));
+
+ matrix_assign_default(dest,temp, alpha,true);
+ }
+ else
+ {
+ zero_matrix(dest);
+
+ if (transpose == false)
+ default_matrix_multiply(dest, src.lhs, src.rhs);
+ else
+ default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs));
+
+ matrix_assign_default(dest,dest, alpha, false);
+ }
+ }
+ }
+ };
+
+ // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
+ // Using this macro it is easy to add overloads for arbitrary matrix expressions.
+#define DLIB_ADD_BLAS_BINDING(src_expression) \
+ template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \
+ { const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \
+ \
+ template < typename dest_exp, typename src_exp > \
+ struct matrix_assign_blas_helper<dest_exp, src_exp, \
+ typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,typename dest_exp::layout_type> >::type > { \
+ static void assign ( \
+ dest_exp& dest, \
+ const src_exp& src, \
+ typename src_exp::type alpha, \
+ bool add_to, \
+ bool DLIB_NO_WARN_UNUSED transpose \
+ ) { \
+ DLIB_NO_WARN_UNUSED typedef typename dest_exp::type T;
+
+#define DLIB_END_BLAS_BINDING }};
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ // ------------------- Forward Declarations -------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const src_exp& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ );
+ /*!
+ requires
+ - src.aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ !*/
+
+ template <
+ typename dest_exp,
+ typename src_exp, typename src_exp2
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_add_exp<src_exp, src_exp2>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ );
+ /*!
+ requires
+ - src.aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ !*/
+
+ template <
+ typename dest_exp,
+ typename src_exp, bool Sb
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_mul_scal_exp<src_exp,Sb>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ );
+ /*!
+ requires
+ - src.aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ !*/
+
+ template <
+ typename dest_exp,
+ typename src_exp
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_op<op_trans<src_exp> >& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ );
+ /*!
+ requires
+ - src.aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ !*/
+
+ template <
+ typename dest_exp,
+ typename src_exp, typename src_exp2
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_subtract_exp<src_exp, src_exp2>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ );
+ /*!
+ requires
+ - src.aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ );
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
+ );
+ /*!
+ This function catches the expressions of the form:
+ M = M + exp;
+ and converts them into the appropriate matrix_assign_blas() call.
+ This is an important case to catch because it is the expression used
+ to represent the += matrix operator.
+ !*/
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src
+ );
+ /*!
+ This function catches the expressions of the form:
+ M = exp + M;
+ and converts them into the appropriate matrix_assign_blas() call.
+ This is an important case to catch because it is the expression used
+ to represent the += matrix operator.
+ !*/
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
+ );
+ /*!
+ This function catches the expressions of the form:
+ M = M - exp;
+ and converts them into the appropriate matrix_assign_blas() call.
+ This is an important case to catch because it is the expression used
+ to represent the -= matrix operator.
+ !*/
+
+
+ // End of forward declarations for overloaded matrix_assign_blas functions
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const src_exp& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to, transpose);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp, typename src_exp2
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_add_exp<src_exp, src_exp2>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
+ {
+ matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose);
+ matrix_assign_blas_proxy(dest, src.rhs, alpha, true, transpose);
+ }
+ else
+ {
+ if (transpose == false)
+ matrix_assign_default(dest, src, alpha, add_to);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp, bool Sb
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_mul_scal_exp<src_exp,Sb>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to, transpose);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_op<op_trans<src_exp> >& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+ matrix_assign_blas_proxy(dest, src.op.m, alpha, add_to, !transpose);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename dest_exp,
+ typename src_exp, typename src_exp2
+ >
+ void matrix_assign_blas_proxy (
+ dest_exp& dest,
+ const matrix_subtract_exp<src_exp, src_exp2>& src,
+ typename src_exp::type alpha,
+ bool add_to,
+ bool transpose
+ )
+ {
+
+ if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
+ {
+ matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose);
+ matrix_assign_blas_proxy(dest, src.rhs, -alpha, true, transpose);
+ }
+ else
+ {
+ if (transpose == false)
+ matrix_assign_default(dest, src, alpha, add_to);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ // Once we get into this function it means that we are dealing with a matrix of float,
+ // double, complex<float>, or complex<double> and the src_exp contains at least one
+ // matrix multiply.
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ long NR2, long NC2, bool Sb
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_mul_scal_exp<matrix<T,NR2,NC2,MM,L>,Sb>& src
+ )
+ {
+ // It's ok that we don't check for aliasing in this case because there isn't
+ // any complex unrolling of successive + or - operators in this expression.
+ matrix_assign_blas_proxy(dest,src.m,src.s,false, false);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ if (src.aliases(dest))
+ {
+ matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
+ matrix_assign_blas_proxy(temp,src,1,false, false);
+ temp.swap(dest);
+ }
+ else
+ {
+ matrix_assign_blas_proxy(dest,src,1,false, false);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ assignable_sub_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ if (src.aliases(dest.m))
+ {
+ matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
+ matrix_assign_blas_proxy(temp,src,1,false, false);
+ matrix_assign_default(dest,temp);
+ }
+ else
+ {
+ matrix_assign_blas_proxy(dest,src,1,false, false);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ assignable_ptr_matrix<T>& dest,
+ const src_exp& src
+ )
+ {
+ if (src.aliases(mat(dest.ptr,dest.height,dest.width)))
+ {
+ matrix<T> temp(dest.nr(),dest.nc());
+ matrix_assign_blas_proxy(temp,src,1,false, false);
+ matrix_assign_default(dest,temp);
+ }
+ else
+ {
+ matrix_assign_blas_proxy(dest,src,1,false, false);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ assignable_row_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ if (src.aliases(dest.m))
+ {
+ matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
+ matrix_assign_blas_proxy(temp,src,1,false, false);
+ matrix_assign_default(dest,temp);
+ }
+ else
+ {
+ matrix_assign_blas_proxy(dest,src,1,false, false);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ assignable_col_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ if (src.aliases(dest.m))
+ {
+ matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
+ matrix_assign_blas_proxy(temp,src,1,false, false);
+ matrix_assign_default(dest,temp);
+ }
+ else
+ {
+ matrix_assign_blas_proxy(dest,src,1,false, false);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
+ )
+ {
+ if (src.rhs.aliases(dest) == false)
+ {
+ if (&src.lhs != &dest)
+ {
+ dest = src.lhs;
+ }
+
+ matrix_assign_blas_proxy(dest, src.rhs, 1, true, false);
+ }
+ else
+ {
+ matrix<T,NR,NC,MM,L> temp(src.lhs);
+ matrix_assign_blas_proxy(temp, src.rhs, 1, true, false);
+ temp.swap(dest);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src
+ )
+ {
+ // Just switch around the left and right hand sides of the incoming
+ // add expression and pass it back into matrix_assign_blas() so that
+ // the above function will be called.
+ typedef matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp> swapped_add_exp;
+ matrix_assign_blas(dest, swapped_add_exp(src.rhs, src.lhs));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ void matrix_assign_blas (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
+ )
+ {
+ if (src.rhs.aliases(dest) == false)
+ {
+ if (&src.lhs != &dest)
+ {
+ dest = src.lhs;
+ }
+
+ matrix_assign_blas_proxy(dest, src.rhs, -1, true, false);
+ }
+ else
+ {
+ matrix<T,NR,NC,MM,L> temp(src.lhs);
+ matrix_assign_blas_proxy(temp, src.rhs, -1, true, false);
+ temp.swap(dest);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ } // end of namespace blas_bindings
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ inline typename enable_if_c<(is_same_type<T,float>::value ||
+ is_same_type<T,double>::value ||
+ is_same_type<T,std::complex<float> >::value ||
+ is_same_type<T,std::complex<double> >::value) &&
+ blas_bindings::has_matrix_multiply<src_exp>::value
+ >::type matrix_assign_big (
+ matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ blas_bindings::matrix_assign_blas(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ inline typename enable_if_c<(is_same_type<T,float>::value ||
+ is_same_type<T,double>::value ||
+ is_same_type<T,std::complex<float> >::value ||
+ is_same_type<T,std::complex<double> >::value) &&
+ blas_bindings::has_matrix_multiply<src_exp>::value
+ >::type matrix_assign_big (
+ assignable_sub_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ blas_bindings::matrix_assign_blas(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename src_exp
+ >
+ inline typename enable_if_c<(is_same_type<T,float>::value ||
+ is_same_type<T,double>::value ||
+ is_same_type<T,std::complex<float> >::value ||
+ is_same_type<T,std::complex<double> >::value) &&
+ blas_bindings::has_matrix_multiply<src_exp>::value
+ >::type matrix_assign_big (
+ assignable_ptr_matrix<T>& dest,
+ const src_exp& src
+ )
+ {
+ blas_bindings::matrix_assign_blas(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ inline typename enable_if_c<(is_same_type<T,float>::value ||
+ is_same_type<T,double>::value ||
+ is_same_type<T,std::complex<float> >::value ||
+ is_same_type<T,std::complex<double> >::value) &&
+ blas_bindings::has_matrix_multiply<src_exp>::value
+ >::type matrix_assign_big (
+ assignable_row_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ blas_bindings::matrix_assign_blas(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L,
+ typename src_exp
+ >
+ inline typename enable_if_c<(is_same_type<T,float>::value ||
+ is_same_type<T,double>::value ||
+ is_same_type<T,std::complex<float> >::value ||
+ is_same_type<T,std::complex<double> >::value) &&
+ blas_bindings::has_matrix_multiply<src_exp>::value
+ >::type matrix_assign_big (
+ assignable_col_matrix<T,NR,NC,MM,L>& dest,
+ const src_exp& src
+ )
+ {
+ blas_bindings::matrix_assign_blas(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_ASSIGn_
+
diff --git a/ml/dlib/dlib/matrix/matrix_assign_fwd.h b/ml/dlib/dlib/matrix/matrix_assign_fwd.h
new file mode 100644
index 000000000..7d29baf0a
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_assign_fwd.h
@@ -0,0 +1,413 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_ASSIGn_FWD_
+#define DLIB_MATRIx_ASSIGn_FWD_
+
+// GCC 4.8 gives false alarms about some variables being uninitialized. Disable these
+// false warnings.
+#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#endif
+
+#include "../enable_if.h"
+#include "matrix_data_layout.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ /*
+ The point of the matrix_assign() functions is to contain all the various
+ optimizations that help the matrix assign a matrix_exp to an actual matrix
+ object quickly.
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ namespace ma
+ {
+ // This template here controls how big a compile time sized matrix needs
+ // to be for it to get passed into the optimized versions of the
+ // matrix_assign() function. So small matrices are evaluated with a simple
+ // loop like the ones in this file and bigger matrices may get sent to BLAS
+ // routines or some other kind of optimized thing.
+ template < typename EXP, typename enable = void >
+ struct is_small_matrix { static const bool value = false; };
+ template < typename EXP >
+ struct is_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 &&
+ EXP::NR<=17 && EXP::NC<=17 && (EXP::cost <= 70)>::type> { static const bool value = true; };
+
+ // I wouldn't use this mul object to do the multiply but visual studio 7.1 wouldn't
+ // compile otherwise.
+ template <long a, long b>
+ struct mul { const static long value = a*b; };
+
+ template < typename EXP, typename enable = void >
+ struct is_very_small_matrix { static const bool value = false; };
+ template < typename EXP >
+ struct is_very_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 &&
+ (mul<EXP::NR,EXP::NC>::value <= 16) && (EXP::cost <= 70)>::type> { static const bool value = true; };
+
+
+ template < typename EXP, typename enable = void >
+ struct has_column_major_layout { static const bool value = false; };
+ template < typename EXP >
+ struct has_column_major_layout<EXP, typename enable_if<is_same_type<typename EXP::layout_type, column_major_layout> >::type >
+ { static const bool value = true; };
+
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ class matrix_exp;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP1, typename EXP2>
+ inline typename disable_if<ma::has_column_major_layout<EXP1> >::type
+ matrix_assign_default (
+ EXP1& dest,
+ const EXP2& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) = src(r,c);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP1, typename EXP2>
+ inline typename enable_if<ma::has_column_major_layout<EXP1> >::type
+ matrix_assign_default (
+ EXP1& dest,
+ const EXP2& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) = src(r,c);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP1, typename EXP2>
+ inline typename disable_if<ma::has_column_major_layout<EXP1> >::type
+ matrix_assign_default (
+ EXP1& dest,
+ const EXP2& src,
+ typename EXP2::type alpha,
+ bool add_to
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - if (add_to == false) then
+ - #dest == alpha*src
+ - else
+ - #dest == dest + alpha*src
+ !*/
+ {
+ if (add_to)
+ {
+ if (alpha == static_cast<typename EXP2::type>(1))
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) += src(r,c);
+ }
+ }
+ }
+ else if (alpha == static_cast<typename EXP2::type>(-1))
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) -= src(r,c);
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) += alpha*src(r,c);
+ }
+ }
+ }
+ }
+ else
+ {
+ if (alpha == static_cast<typename EXP2::type>(1))
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) = src(r,c);
+ }
+ }
+ }
+ else
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ dest(r,c) = alpha*src(r,c);
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP1, typename EXP2>
+ inline typename enable_if<ma::has_column_major_layout<EXP1> >::type
+ matrix_assign_default (
+ EXP1& dest,
+ const EXP2& src,
+ typename EXP2::type alpha,
+ bool add_to
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - if (add_to == false) then
+ - #dest == alpha*src
+ - else
+ - #dest == dest + alpha*src
+ !*/
+ {
+ if (add_to)
+ {
+ if (alpha == static_cast<typename EXP2::type>(1))
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) += src(r,c);
+ }
+ }
+ }
+ else if (alpha == static_cast<typename EXP2::type>(-1))
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) -= src(r,c);
+ }
+ }
+ }
+ else
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) += alpha*src(r,c);
+ }
+ }
+ }
+ }
+ else
+ {
+ if (alpha == static_cast<typename EXP2::type>(1))
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) = src(r,c);
+ }
+ }
+ }
+ else
+ {
+ for (long c = 0; c < src.nc(); ++c)
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ dest(r,c) = alpha*src(r,c);
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename src_exp
+ >
+ void matrix_assign_big (
+ matrix_dest_type& dest,
+ const matrix_exp<src_exp>& src
+ )
+ {
+ matrix_assign_default(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename src_exp
+ >
+ inline typename disable_if<ma::is_small_matrix<src_exp> >::type matrix_assign (
+ matrix_dest_type& dest,
+ const matrix_exp<src_exp>& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ // Call src.ref() here so that the derived type of the matrix_exp shows
+ // up so we can overload matrix_assign_big() based on various matrix expression
+ // types.
+ matrix_assign_big(dest,src.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// this code is here to perform an unrolled version of the matrix_assign() function
+ template < typename DEST, typename SRC, long NR, long NC,
+ long R = 0, long C = 0, bool base_case = (R==NR) >
+ struct matrix_unroll_helper
+ {
+ inline static void go ( DEST& dest, const SRC& src)
+ {
+ dest(R,C) = src(R,C);
+ matrix_unroll_helper<DEST,SRC,NR,NC, R + (C+1)/NC, (C+1)%NC>::go(dest,src);
+ }
+ };
+
+ template < typename DEST, typename SRC, long NR, long NC, long R, long C >
+ struct matrix_unroll_helper<DEST,SRC,NR,NC,R,C,true>
+ { inline static void go ( DEST& , const SRC& ) {} };
+
+ template <typename DEST, typename SRC>
+ inline void matrix_assign_unrolled (
+ DEST& dest,
+ const SRC& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ COMPILE_TIME_ASSERT(SRC::NR*SRC::NC != 0);
+ matrix_unroll_helper<DEST,SRC, SRC::NR, SRC::NC>::go(dest,src);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename src_exp
+ >
+ inline typename enable_if_c<ma::is_small_matrix<src_exp>::value && ma::is_very_small_matrix<src_exp>::value==false >::type matrix_assign (
+ matrix_dest_type& dest,
+ const matrix_exp<src_exp>& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ matrix_assign_default(dest,src.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename src_exp
+ >
+ inline typename enable_if_c<ma::is_small_matrix<src_exp>::value && ma::is_very_small_matrix<src_exp>::value==true >::type matrix_assign (
+ matrix_dest_type& dest,
+ const matrix_exp<src_exp>& src
+ )
+ /*!
+ requires
+ - src.destructively_aliases(dest) == false
+ - dest.nr() == src.nr()
+ - dest.nc() == src.nc()
+ ensures
+ - #dest == src
+ !*/
+ {
+ matrix_assign_unrolled(dest,src.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))
+#pragma GCC diagnostic pop
+#endif
+
+#endif // DLIB_MATRIx_ASSIGn_FWD_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_blas_bindings.h b/ml/dlib/dlib/matrix/matrix_blas_bindings.h
new file mode 100644
index 000000000..b65e29cdd
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_blas_bindings.h
@@ -0,0 +1,1637 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_BLAS_BINDINGS_
+#define DLIB_MATRIx_BLAS_BINDINGS_
+
+#ifndef DLIB_USE_BLAS
+#error "DLIB_USE_BLAS should be defined if you want to use the BLAS bindings"
+#endif
+
+#include "matrix_assign.h"
+#include "matrix_conj_trans.h"
+#include "cblas_constants.h"
+
+//#include <iostream>
+//using namespace std;
+
+namespace dlib
+{
+
+
+ namespace blas_bindings
+ {
+
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_gemm();
+ int& counter_gemv();
+ int& counter_ger();
+ int& counter_dot();
+ int& counter_axpy();
+ int& counter_scal();
+
+ #define DLIB_TEST_BLAS_BINDING_GEMM ++counter_gemm();
+ #define DLIB_TEST_BLAS_BINDING_GEMV ++counter_gemv();
+ #define DLIB_TEST_BLAS_BINDING_GER ++counter_ger();
+ #define DLIB_TEST_BLAS_BINDING_DOT ++counter_dot();
+ #define DLIB_TEST_BLAS_BINDING_AXPY ++counter_axpy();
+ #define DLIB_TEST_BLAS_BINDING_SCAL ++counter_scal();
+#else
+ #define DLIB_TEST_BLAS_BINDING_GEMM
+ #define DLIB_TEST_BLAS_BINDING_GEMV
+ #define DLIB_TEST_BLAS_BINDING_GER
+ #define DLIB_TEST_BLAS_BINDING_DOT
+ #define DLIB_TEST_BLAS_BINDING_AXPY
+ #define DLIB_TEST_BLAS_BINDING_SCAL
+#endif
+
+#ifndef CBLAS_H
+ extern "C"
+ {
+ // Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below
+
+ void cblas_saxpy(const int N, const float alpha, const float *X,
+ const int incX, float *Y, const int incY);
+ void cblas_daxpy(const int N, const double alpha, const double *X,
+ const int incX, double *Y, const int incY);
+ void cblas_caxpy(const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY);
+ void cblas_zaxpy(const int N, const void *alpha, const void *X,
+ const int incX, void *Y, const int incY);
+
+ void cblas_sscal(const int N, const float alpha, float *X, const int incX);
+ void cblas_dscal(const int N, const double alpha, double *X, const int incX);
+ void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
+ void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
+
+ void cblas_sgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const float alpha, const float *A,
+ const int lda, const float *B, const int ldb,
+ const float beta, float *C, const int ldc);
+ void cblas_dgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const double alpha, const double *A,
+ const int lda, const double *B, const int ldb,
+ const double beta, double *C, const int ldc);
+ void cblas_cgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc);
+ void cblas_zgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const void *alpha, const void *A,
+ const int lda, const void *B, const int ldb,
+ const void *beta, void *C, const int ldc);
+ void cblas_sgemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY);
+ void cblas_dgemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY);
+ void cblas_cgemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY);
+ void cblas_zgemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const void *alpha, const void *A, const int lda,
+ const void *X, const int incX, const void *beta,
+ void *Y, const int incY);
+ void cblas_sger(const CBLAS_ORDER order, const int M, const int N,
+ const float alpha, const float *X, const int incX,
+ const float *Y, const int incY, float *A, const int lda);
+ void cblas_dger(const CBLAS_ORDER order, const int M, const int N,
+ const double alpha, const double *X, const int incX,
+ const double *Y, const int incY, double *A, const int lda);
+ void cblas_cgerc(const CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+ void cblas_zgerc(const CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+ float cblas_sdot(const int N, const float *X, const int incX,
+ const float *Y, const int incY);
+ double cblas_ddot(const int N, const double *X, const int incX,
+ const double *Y, const int incY);
+ void cblas_cdotu_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotu);
+ void cblas_zdotu_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotu);
+ void cblas_cdotc_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotc);
+ void cblas_zdotc_sub(const int N, const void *X, const int incX,
+ const void *Y, const int incY, void *dotc);
+ void cblas_cgeru(const CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+ void cblas_zgeru(const CBLAS_ORDER order, const int M, const int N,
+ const void *alpha, const void *X, const int incX,
+ const void *Y, const int incY, void *A, const int lda);
+ }
+#endif // if not CBLAS_H
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_axpy(const int N, const float alpha, const float *X,
+ const int incX, float *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_AXPY;
+ cblas_saxpy(N, alpha, X, incX, Y, incY);
+ }
+
+ inline void cblas_axpy(const int N, const double alpha, const double *X,
+ const int incX, double *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_AXPY;
+ cblas_daxpy(N, alpha, X, incX, Y, incY);
+ }
+
+ inline void cblas_axpy(const int N, const std::complex<float>& alpha, const std::complex<float> *X,
+ const int incX, std::complex<float> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_AXPY;
+ cblas_caxpy(N, &alpha, X, incX, Y, incY);
+ }
+
+ inline void cblas_axpy(const int N, const std::complex<double>& alpha, const std::complex<double> *X,
+ const int incX, std::complex<double> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_AXPY;
+ cblas_zaxpy(N, &alpha, X, incX, Y, incY);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_scal(const int N, const float alpha, float *X)
+ {
+ DLIB_TEST_BLAS_BINDING_SCAL;
+ cblas_sscal(N, alpha, X, 1);
+ }
+
+ inline void cblas_scal(const int N, const double alpha, double *X)
+ {
+ DLIB_TEST_BLAS_BINDING_SCAL;
+ cblas_dscal(N, alpha, X, 1);
+ }
+
+ inline void cblas_scal(const int N, const std::complex<float>& alpha, std::complex<float> *X)
+ {
+ DLIB_TEST_BLAS_BINDING_SCAL;
+ cblas_cscal(N, &alpha, X, 1);
+ }
+
+ inline void cblas_scal(const int N, const std::complex<double>& alpha, std::complex<double> *X)
+ {
+ DLIB_TEST_BLAS_BINDING_SCAL;
+ cblas_zscal(N, &alpha, X, 1);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_gemm( const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const float alpha, const float *A,
+ const int lda, const float *B, const int ldb,
+ const float beta, float *C, const int ldc)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMM;
+ cblas_sgemm( Order, TransA, TransB, M, N,
+ K, alpha, A, lda, B, ldb, beta, C, ldc);
+ }
+
+ inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const double alpha, const double *A,
+ const int lda, const double *B, const int ldb,
+ const double beta, double *C, const int ldc)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMM;
+ cblas_dgemm( Order, TransA, TransB, M, N,
+ K, alpha, A, lda, B, ldb, beta, C, ldc);
+ }
+
+ inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const std::complex<float>& alpha, const std::complex<float> *A,
+ const int lda, const std::complex<float> *B, const int ldb,
+ const std::complex<float>& beta, std::complex<float> *C, const int ldc)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMM;
+ cblas_cgemm( Order, TransA, TransB, M, N,
+ K, &alpha, A, lda, B, ldb, &beta, C, ldc);
+ }
+
+ inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,
+ const int K, const std::complex<double>& alpha, const std::complex<double> *A,
+ const int lda, const std::complex<double> *B, const int ldb,
+ const std::complex<double>& beta, std::complex<double> *C, const int ldc)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMM;
+ cblas_zgemm( Order, TransA, TransB, M, N,
+ K, &alpha, A, lda, B, ldb, &beta, C, ldc);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_gemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ const float *X, const int incX, const float beta,
+ float *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMV;
+ cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ }
+
+ inline void cblas_gemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ const double *X, const int incX, const double beta,
+ double *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMV;
+ cblas_dgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ }
+
+ inline void cblas_gemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const std::complex<float>& alpha, const std::complex<float> *A, const int lda,
+ const std::complex<float> *X, const int incX, const std::complex<float>& beta,
+ std::complex<float> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMV;
+ cblas_cgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
+ }
+
+ inline void cblas_gemv(const CBLAS_ORDER order,
+ const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const std::complex<double>& alpha, const std::complex<double> *A, const int lda,
+ const std::complex<double> *X, const int incX, const std::complex<double>& beta,
+ std::complex<double> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_GEMV;
+ cblas_zgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N,
+ const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
+ const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_cgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
+ }
+
+ inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N,
+ const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
+ const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_zgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
+ }
+
+ inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N,
+ const float alpha, const float *X, const int incX,
+ const float *Y, const int incY, float *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_sger (order, M, N, alpha, X, incX, Y, incY, A, lda);
+ }
+
+ inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N,
+ const double alpha, const double *X, const int incX,
+ const double *Y, const int incY, double *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_dger (order, M, N, alpha, X, incX, Y, incY, A, lda);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline void cblas_gerc(const CBLAS_ORDER order, const int M, const int N,
+ const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
+ const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_cgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
+ }
+
+ inline void cblas_gerc(const CBLAS_ORDER order, const int M, const int N,
+ const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
+ const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
+ {
+ DLIB_TEST_BLAS_BINDING_GER;
+ cblas_zgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline float cblas_dot(const int N, const float *X, const int incX,
+ const float *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ return cblas_sdot(N, X, incX, Y, incY);
+ }
+
+ inline double cblas_dot(const int N, const double *X, const int incX,
+ const double *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ return cblas_ddot(N, X, incX, Y, incY);
+ }
+
+ inline std::complex<float> cblas_dot(const int N, const std::complex<float> *X, const int incX,
+ const std::complex<float> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ std::complex<float> result;
+ cblas_cdotu_sub(N, X, incX, Y, incY, &result);
+ return result;
+ }
+
+ inline std::complex<double> cblas_dot(const int N, const std::complex<double> *X, const int incX,
+ const std::complex<double> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ std::complex<double> result;
+ cblas_zdotu_sub(N, X, incX, Y, incY, &result);
+ return result;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ inline std::complex<float> cblas_dotc(const int N, const std::complex<float> *X, const int incX,
+ const std::complex<float> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ std::complex<float> result;
+ cblas_cdotc_sub(N, X, incX, Y, incY, &result);
+ return result;
+ }
+
+ inline std::complex<double> cblas_dotc(const int N, const std::complex<double> *X, const int incX,
+ const std::complex<double> *Y, const int incY)
+ {
+ DLIB_TEST_BLAS_BINDING_DOT;
+ std::complex<double> result;
+ cblas_zdotc_sub(N, X, incX, Y, incY, &result);
+ return result;
+ }
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // Helpers for determining the data pointer, LDA, and incX arguments to BLAS functions.
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const matrix<T,NR,NC,MM,row_major_layout>& m) { return m.nc(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const matrix<T,NR,NC,MM,column_major_layout>& m) { return m.nr(); }
+
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const matrix_op<op_subm<matrix<T,NR,NC,MM,row_major_layout> > >& m) { return m.op.m.nc(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const matrix_op<op_subm<matrix<T,NR,NC,MM,column_major_layout> > >& m) { return m.op.m.nr(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_ld (const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
+
+ template <typename T>
+ int get_ld (const assignable_ptr_matrix<T>& m) { return m.nc(); }
+
+ template <typename T, typename MM>
+ int get_ld (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return m.nc(); }
+ template <typename T, typename MM>
+ int get_ld (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return m.nc(); }
+ template < typename value_type, typename alloc >
+ int get_ld (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& m) { return m.nc(); }
+ template < typename value_type, typename alloc >
+ int get_ld (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& m) { return m.nc(); }
+ template <typename T>
+ int get_ld (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.nc(); }
+ template <typename T>
+ int get_ld (const matrix_op<op_pointer_to_mat<T> >& m) { return m.op.stride; }
+
+ // --------
+
+ // get_inc() returns the offset from one element to another. If an object has a
+ // non-uniform offset between elements then returns 0 (e.g. a subm() view could
+ // have a non-uniform offset between elements).
+
+ template <typename T, typename MM>
+ int get_inc (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& ) { return 1; }
+ template <typename T, typename MM>
+ int get_inc (const matrix_op<op_array_to_mat<array<T,MM> > >& ) { return 1; }
+ template < typename value_type, typename alloc >
+ int get_inc (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& ) { return 1; }
+ template < typename value_type, typename alloc >
+ int get_inc (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& ) { return 1; }
+ template <typename T>
+ int get_inc (const matrix_op<op_pointer_to_col_vect<T> >& ) { return 1; }
+ template <typename T>
+ int get_inc (const matrix_op<op_pointer_to_mat<T> >& m) { return m.op.stride==m.op.cols ? 1 : 0; }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc (const matrix_op<op_subm<matrix<T,NR,NC,MM,row_major_layout> > >& m)
+ {
+ // if the sub-view doesn't cover all the columns then it can't have a uniform
+ // layout.
+ if (m.nc() < m.op.m.nc())
+ return 0;
+ else
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc (const matrix_op<op_subm<matrix<T,NR,NC,MM,column_major_layout> > >& m)
+ {
+ if (m.nr() < m.op.m.nr())
+ return 0;
+ else
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m)
+ {
+ if (m.nc() < m.m.nc())
+ return 0;
+ else
+ return 1;
+ }
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m)
+ {
+ if (m.nr() < m.m.nr())
+ return 0;
+ else
+ return 1;
+ }
+
+ template <typename T>
+ int get_inc (const assignable_ptr_matrix<T>& ) { return 1; }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_colm<matrix<T,NR,NC,MM,row_major_layout> > >& m)
+ {
+ return m.op.m.nc();
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_rowm<matrix<T,NR,NC,MM,row_major_layout> > >& )
+ {
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_colm2<matrix<T,NR,NC,MM,row_major_layout> > >& m)
+ {
+ return m.op.m.nc();
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_rowm2<matrix<T,NR,NC,MM,row_major_layout> > >& )
+ {
+ return 1;
+ }
+
+
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_colm<matrix<T,NR,NC,MM,column_major_layout> > >& )
+ {
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_rowm<matrix<T,NR,NC,MM,column_major_layout> > >& m)
+ {
+ return m.op.m.nr();
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_colm2<matrix<T,NR,NC,MM,column_major_layout> > >& )
+ {
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const matrix_op<op_rowm2<matrix<T,NR,NC,MM,column_major_layout> > >& m)
+ {
+ return m.op.m.nr();
+ }
+
+
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& )
+ {
+ return 1;
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m)
+ {
+ return m.m.nr();
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m)
+ {
+ return m.m.nc();
+ }
+
+ template <typename T, long NR, long NC, typename MM>
+ int get_inc(const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& )
+ {
+ return 1;
+ }
+
+ // --------
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ T* get_ptr (matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix_op<op_subm<matrix<T,NR,NC,MM,L> > >& m) { return &m.op.m(m.op.r_,m.op.c_); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix_op<op_colm<matrix<T,NR,NC,MM,L> > >& m) { return &m.op.m(0,m.op.col); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix_op<op_rowm<matrix<T,NR,NC,MM,L> > >& m) { return &m.op.m(m.op.row,0); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix_op<op_colm2<matrix<T,NR,NC,MM,L> > >& m) { return &m.op.m(0,m.op.col); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ const T* get_ptr (const matrix_op<op_rowm2<matrix<T,NR,NC,MM,L> > >& m) { return &m.op.m(m.op.row,0); }
+
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ T* get_ptr (assignable_col_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ T* get_ptr (assignable_row_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
+
+ template <typename T>
+ T* get_ptr (assignable_ptr_matrix<T>& m) { return m.ptr; }
+
+ template <typename T, typename MM>
+ const T* get_ptr (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return &m.op.array[0][0]; }
+ template <typename T, typename MM>
+ const T* get_ptr (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return &m.op.vect[0]; }
+ template < typename T, typename alloc >
+ const T* get_ptr (const matrix_op<op_std_vect_to_mat<std::vector<T,alloc> > >& m) { return &m.op.vect[0]; }
+ template < typename T, typename alloc >
+ const T* get_ptr (const matrix_op<op_std_vect_to_mat<std_vector_c<T,alloc> > >& m) { return &m.op.vect[0]; }
+ template <typename T>
+ const T* get_ptr (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.op.ptr; }
+ template <typename T>
+ const T* get_ptr (const matrix_op<op_pointer_to_mat<T> >& m) { return m.op.ptr; }
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ // Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These
+ // extern declarations don't actually correspond to any real matrix objects. They are
+ // simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING marco.
+
+
+ // Note that the fact that these are double matrices isn't important, it is just a placeholder in this case.
+ extern matrix<double> m; // general matrix
+ extern matrix<double,1,0> rv; // general row vector
+ extern matrix<double,0,1> cv; // general column vector
+ extern const double s;
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // AXPY/SCAL overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m)
+ {
+
+ const int N = static_cast<int>(src.size());
+ if (transpose == false && N != 0)
+ {
+ if (add_to)
+ {
+ if (get_inc(src) && get_inc(dest))
+ cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ else
+ {
+ if (get_ptr(src) == get_ptr(dest))
+ cblas_scal(N, alpha, get_ptr(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ }
+ else
+ {
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ }
+
+ } DLIB_END_BLAS_BINDING
+
+ DLIB_ADD_BLAS_BINDING(rv)
+ {
+
+ const int N = static_cast<int>(src.size());
+ if (transpose == false && N != 0)
+ {
+ if (add_to)
+ {
+ if (get_inc(src) && get_inc(dest))
+ cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ else
+ {
+ if (get_ptr(src) == get_ptr(dest))
+ cblas_scal(N, alpha, get_ptr(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ }
+ else
+ {
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ }
+
+ } DLIB_END_BLAS_BINDING
+
+ DLIB_ADD_BLAS_BINDING(cv)
+ {
+
+ const int N = static_cast<int>(src.size());
+ if (transpose == false && N != 0)
+ {
+ if (add_to)
+ {
+ if (get_inc(src) && get_inc(dest))
+ cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ else
+ {
+ if (get_ptr(src) == get_ptr(dest))
+ cblas_scal(N, alpha, get_ptr(dest));
+ else
+ matrix_assign_default(dest, src, alpha, add_to);
+ }
+ }
+ else
+ {
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ }
+
+ } DLIB_END_BLAS_BINDING
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // GEMM overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m*m)
+ {
+ //cout << "BLAS GEMM: m*m" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs);
+ const int lda = get_ld(src.lhs);
+ const T* B = get_ptr(src.rhs);
+ const int ldb = get_ld(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ cblas_gemm(Order, CblasTrans, CblasTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(m)*m)
+ {
+ //cout << "BLAS GEMM: trans(m)*m" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const CBLAS_TRANSPOSE TransB = CblasNoTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs);
+ const int ldb = get_ld(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m*trans(m))
+ {
+ //cout << "BLAS GEMM: m*trans(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const CBLAS_TRANSPOSE TransB = CblasTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs);
+ const int lda = get_ld(src.lhs);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
+ {
+ //cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ cblas_gemm(Order, CblasNoTrans, CblasNoTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+ // --------------------------------------
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(m))*m)
+ {
+ //cout << "BLAS GEMM: trans(conj(m))*m" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const CBLAS_TRANSPOSE TransB = CblasNoTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs);
+ const int ldb = get_ld(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(m))
+ {
+ //cout << "BLAS GEMM: trans(conj(m))*trans(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const CBLAS_TRANSPOSE TransB = CblasTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m*trans(conj(m)))
+ {
+ //cout << "BLAS GEMM: m*trans(conj(m))" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const CBLAS_TRANSPOSE TransB = CblasConjTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs);
+ const int lda = get_ld(src.lhs);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(m)*trans(conj(m)))
+ {
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const CBLAS_TRANSPOSE TransB = CblasConjTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m)))
+ {
+ //cout << "BLAS GEMM: trans(conj(m))*trans(conj(m))" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const CBLAS_TRANSPOSE TransB = CblasConjTrans;
+ const int M = static_cast<int>(src.nr());
+ const int N = static_cast<int>(src.nc());
+ const int K = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* B = get_ptr(src.rhs.op.m);
+ const int ldb = get_ld(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* C = get_ptr(dest);
+ const int ldc = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ else
+ matrix_assign_default(dest, trans(src), alpha, add_to);
+ } DLIB_END_BLAS_BINDING
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // GEMV overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m*cv)
+ {
+ //cout << "BLAS GEMV: m*cv" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const int M = static_cast<int>(src.lhs.nr());
+ const int N = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs);
+ const int lda = get_ld(src.lhs);
+ const T* X = get_ptr(src.rhs);
+ const int incX = get_inc(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(rv*m)
+ {
+ // Note that rv*m is the same as trans(m)*trans(rv)
+
+ //cout << "BLAS GEMV: rv*m" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const int M = static_cast<int>(src.rhs.nr());
+ const int N = static_cast<int>(src.rhs.nc());
+ const T* A = get_ptr(src.rhs);
+ const int lda = get_ld(src.rhs);
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(cv)*m)
+ {
+ // Note that trans(cv)*m is the same as trans(m)*cv
+
+ //cout << "BLAS GEMV: trans(cv)*m" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const int M = static_cast<int>(src.rhs.nr());
+ const int N = static_cast<int>(src.rhs.nc());
+ const T* A = get_ptr(src.rhs);
+ const int lda = get_ld(src.rhs);
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(m*trans(rv))
+ {
+ //cout << "BLAS GEMV: m*trans(rv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const int M = static_cast<int>(src.lhs.nr());
+ const int N = static_cast<int>(src.lhs.nc());
+ const T* A = get_ptr(src.lhs);
+ const int lda = get_ld(src.lhs);
+ const T* X = get_ptr(src.rhs.op.m);
+ const int incX = get_inc(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+ // --------------------------------------
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(m)*cv)
+ {
+ //cout << "BLAS GEMV: trans(m)*cv" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const int M = static_cast<int>(src.lhs.op.m.nr());
+ const int N = static_cast<int>(src.lhs.op.m.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* X = get_ptr(src.rhs);
+ const int incX = get_inc(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(rv*trans(m))
+ {
+ // Note that rv*trans(m) is the same as m*trans(rv)
+
+ //cout << "BLAS GEMV: rv*trans(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const int M = static_cast<int>(src.rhs.op.m.nr());
+ const int N = static_cast<int>(src.rhs.op.m.nc());
+ const T* A = get_ptr(src.rhs.op.m);
+ const int lda = get_ld(src.rhs.op.m);
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(cv)*trans(m))
+ {
+ // Note that trans(cv)*trans(m) is the same as m*cv
+
+ //cout << "BLAS GEMV: trans(cv)*trans(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasNoTrans;
+ const int M = static_cast<int>(src.rhs.op.m.nr());
+ const int N = static_cast<int>(src.rhs.op.m.nc());
+ const T* A = get_ptr(src.rhs.op.m);
+ const int lda = get_ld(src.rhs.op.m);
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv))
+ {
+ //cout << "BLAS GEMV: trans(m)*trans(rv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasTrans;
+ const int M = static_cast<int>(src.lhs.op.m.nr());
+ const int N = static_cast<int>(src.lhs.op.m.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* X = get_ptr(src.rhs.op.m);
+ const int incX = get_inc(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+ // --------------------------------------
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(cv)*conj(m))
+ {
+ // Note that trans(cv)*conj(m) == conj(trans(m))*cv
+ //cout << "BLAS GEMV: trans(cv)*conj(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const int M = static_cast<int>(src.rhs.op.m.nr());
+ const int N = static_cast<int>(src.rhs.op.m.nc());
+ const T* A = get_ptr(src.rhs.op.m);
+ const int lda = get_ld(src.rhs.op.m);
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(rv*conj(m))
+ {
+ // Note that rv*conj(m) == conj(trans(m))*cv
+ //cout << "BLAS GEMV: rv*conj(m)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const int M = static_cast<int>(src.rhs.op.m.nr());
+ const int N = static_cast<int>(src.rhs.op.m.nc());
+ const T* A = get_ptr(src.rhs.op.m);
+ const int lda = get_ld(src.rhs.op.m);
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(m))*cv)
+ {
+ //cout << "BLAS GEMV: trans(conj(m))*cv" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const int M = static_cast<int>(src.lhs.op.m.nr());
+ const int N = static_cast<int>(src.lhs.op.m.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* X = get_ptr(src.rhs);
+ const int incX = get_inc(src.rhs);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(rv))
+ {
+ //cout << "BLAS GEMV: trans(conj(m))*trans(rv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const CBLAS_TRANSPOSE TransA = CblasConjTrans;
+ const int M = static_cast<int>(src.lhs.op.m.nr());
+ const int N = static_cast<int>(src.lhs.op.m.nc());
+ const T* A = get_ptr(src.lhs.op.m);
+ const int lda = get_ld(src.lhs.op.m);
+ const T* X = get_ptr(src.rhs.op.m);
+ const int incX = get_inc(src.rhs.op.m);
+
+ const T beta = static_cast<T>(add_to?1:0);
+ T* Y = get_ptr(dest);
+ const int incY = get_inc(dest);
+
+ cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+ } DLIB_END_BLAS_BINDING
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // GER overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(cv*rv)
+ {
+ //cout << "BLAS GER: cv*rv" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(rv)*rv)
+ {
+ //cout << "BLAS GER: trans(rv)*rv" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(cv*trans(cv))
+ {
+ //cout << "BLAS GER: cv*trans(cv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(rv)*trans(cv))
+ {
+ //cout << "BLAS GER: trans(rv)*trans(cv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ } DLIB_END_BLAS_BINDING
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // GERC overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ /*
+ DLIB_ADD_BLAS_BINDING(cv*conj(rv))
+ {
+ //cout << "BLAS GERC: cv*conj(rv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ } DLIB_END_BLAS_BINDING
+ */
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(cv*conj(trans(cv)))
+ {
+ //cout << "BLAS GERC: cv*conj(trans(cv))" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+
+ if (transpose == false)
+ {
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ }
+ else
+ {
+ matrix_assign_default(dest,trans(src),alpha,add_to);
+ //cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ }
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(rv)*conj(trans(cv)))
+ {
+ //cout << "BLAS GERC: trans(rv)*conj(trans(cv))" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+
+ if (transpose == false)
+ {
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ }
+ else
+ {
+ matrix_assign_default(dest,trans(src),alpha,add_to);
+ //cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ }
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ /*
+ DLIB_ADD_BLAS_BINDING(trans(rv)*conj(rv))
+ {
+ //cout << "BLAS GERC: trans(rv)*conj(rv)" << endl;
+ const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
+ const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
+ const int M = static_cast<int>(dest.nr());
+ const int N = static_cast<int>(dest.nc());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ zero_matrix(dest);
+
+ T* A = get_ptr(dest);
+ const int lda = get_ld(dest);
+
+ if (transpose == false)
+ cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
+ else
+ cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
+ } DLIB_END_BLAS_BINDING
+ */
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // DOT overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(rv*cv)
+ {
+ //cout << "BLAS DOT: rv*cv" << endl;
+ const int N = static_cast<int>(src.lhs.size());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dot(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dot(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(cv)*cv)
+ {
+ //cout << "BLAS DOT: trans(cv)*cv" << endl;
+ const int N = static_cast<int>(src.lhs.size());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dot(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dot(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(rv*trans(rv))
+ {
+ //cout << "BLAS DOT: rv*trans(rv)" << endl;
+ const int N = static_cast<int>(src.lhs.size());
+ const T* X = get_ptr(src.lhs);
+ const int incX = get_inc(src.lhs);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dot(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dot(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(cv)*trans(rv))
+ {
+ //cout << "BLAS DOT: trans(cv)*trans(rv)" << endl;
+ const int N = static_cast<int>(src.lhs.op.m.size());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dot(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dot(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // DOTC overloads
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(conj(rv)*cv)
+ {
+ //cout << "BLAS DOTC: conj(rv)*cv" << endl;
+ const int N = static_cast<int>(src.lhs.op.m.size());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(conj(trans(cv))*cv)
+ {
+ //cout << "BLAS DOTC: conj(trans(cv))*cv" << endl;
+ const int N = static_cast<int>(src.lhs.op.m.size());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs);
+ const int incY = get_inc(src.rhs);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ // --------------------------------------
+
+ DLIB_ADD_BLAS_BINDING(trans(conj(cv))*trans(rv))
+ {
+ //cout << "BLAS DOTC: trans(conj(cv))*trans(rv)" << endl;
+ const int N = static_cast<int>(src.lhs.op.m.size());
+ const T* X = get_ptr(src.lhs.op.m);
+ const int incX = get_inc(src.lhs.op.m);
+ const T* Y = get_ptr(src.rhs.op.m);
+ const int incY = get_inc(src.rhs.op.m);
+
+ if (add_to == false)
+ dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY);
+ else
+ dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY);
+
+ } DLIB_END_BLAS_BINDING
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_BLAS_BINDINGS_
+
diff --git a/ml/dlib/dlib/matrix/matrix_cholesky.h b/ml/dlib/dlib/matrix/matrix_cholesky.h
new file mode 100644
index 000000000..fc1140692
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_cholesky.h
@@ -0,0 +1,231 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+// This code was adapted from code from the JAMA part of NIST's TNT library.
+// See: http://math.nist.gov/tnt/
+#ifndef DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H
+#define DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+#include <cmath>
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/potrf.h"
+#endif
+
+#include "matrix_trsm.h"
+
+namespace dlib
+{
+
+ template <
+ typename matrix_exp_type
+ >
+ class cholesky_decomposition
+ {
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+
+ // You have supplied an invalid type of matrix_exp_type. You have
+ // to use this object with matrices that contain float or double type data.
+ COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
+ is_same_type<double, type>::value ));
+
+
+
+ template <typename EXP>
+ cholesky_decomposition(
+ const matrix_exp<EXP>& A
+ );
+
+ bool is_spd(
+ ) const;
+
+ const matrix_type& get_l(
+ ) const;
+
+ template <typename EXP>
+ const typename EXP::matrix_type solve (
+ const matrix_exp<EXP>& B
+ ) const;
+
+ private:
+
+ matrix_type L_; // lower triangular factor
+ bool isspd; // true if matrix to be factored was SPD
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool cholesky_decomposition<matrix_exp_type>::
+ is_spd(
+ ) const
+ {
+ return isspd;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename cholesky_decomposition<matrix_exp_type>::matrix_type& cholesky_decomposition<matrix_exp_type>::
+ get_l(
+ ) const
+ {
+ return L_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ cholesky_decomposition<matrix_exp_type>::
+ cholesky_decomposition(
+ const matrix_exp<EXP>& A_
+ )
+ {
+ using std::sqrt;
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A_.nr() == A_.nc() && A_.size() > 0,
+ "\tcholesky_decomposition::cholesky_decomposition(A_)"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tA_.nr(): " << A_.nr()
+ << "\n\tA_.nc(): " << A_.nc()
+ << "\n\tA_.size(): " << A_.size()
+ << "\n\tthis: " << this
+ );
+
+#ifdef DLIB_USE_LAPACK
+ L_ = A_;
+ const type eps = max(abs(diag(L_)))*std::sqrt(std::numeric_limits<type>::epsilon())/100;
+
+ // check if the matrix is actually symmetric
+ bool is_symmetric = true;
+ for (long r = 0; r < L_.nr() && is_symmetric; ++r)
+ {
+ for (long c = r+1; c < L_.nc() && is_symmetric; ++c)
+ {
+ // this is approximately doing: is_symmetric = is_symmetric && ( L_(k,j) == L_(j,k))
+ is_symmetric = is_symmetric && (std::abs(L_(r,c) - L_(c,r)) < eps );
+ }
+ }
+
+ // now compute the actual cholesky decomposition
+ int info = lapack::potrf('L', L_);
+
+ // check if it's really SPD
+ if (info == 0 && is_symmetric && min(abs(diag(L_))) > eps*100)
+ isspd = true;
+ else
+ isspd = false;
+
+ L_ = lowerm(L_);
+#else
+ const_temp_matrix<EXP> A(A_);
+
+
+ isspd = true;
+
+ const long n = A.nc();
+ L_.set_size(n,n);
+
+ const type eps = max(abs(diag(A)))*std::sqrt(std::numeric_limits<type>::epsilon())/100;
+
+ // Main loop.
+ for (long j = 0; j < n; j++)
+ {
+ type d(0.0);
+ for (long k = 0; k < j; k++)
+ {
+ type s(0.0);
+ for (long i = 0; i < k; i++)
+ {
+ s += L_(k,i)*L_(j,i);
+ }
+
+ // if L_(k,k) != 0
+ if (std::abs(L_(k,k)) > eps)
+ {
+ s = (A(j,k) - s)/L_(k,k);
+ }
+ else
+ {
+ s = (A(j,k) - s);
+ isspd = false;
+ }
+
+ L_(j,k) = s;
+
+ d = d + s*s;
+
+ // this is approximately doing: isspd = isspd && ( A(k,j) == A(j,k))
+ isspd = isspd && (std::abs(A(k,j) - A(j,k)) < eps );
+ }
+ d = A(j,j) - d;
+ isspd = isspd && (d > eps);
+ L_(j,j) = sqrt(d > 0.0 ? d : 0.0);
+ for (long k = j+1; k < n; k++)
+ {
+ L_(j,k) = 0.0;
+ }
+ }
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename EXP::matrix_type cholesky_decomposition<matrix_exp_type>::
+ solve(
+ const matrix_exp<EXP>& B
+ ) const
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(L_.nr() == B.nr(),
+ "\tconst matrix cholesky_decomposition::solve(B)"
+ << "\n\tInvalid arguments were given to this function."
+ << "\n\tL_.nr(): " << L_.nr()
+ << "\n\tB.nr(): " << B.nr()
+ << "\n\tthis: " << this
+ );
+
+ matrix<type, NR, EXP::NC, mem_manager_type, layout_type> X(B);
+
+ using namespace blas_bindings;
+ // Solve L*y = b;
+ triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, L_, X);
+ // Solve L'*X = Y;
+ triangular_solver(CblasLeft, CblasLower, CblasTrans, CblasNonUnit, L_, X);
+ return X;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+
+}
+
+#endif // DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H
+
+
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_conj_trans.h b/ml/dlib/dlib/matrix/matrix_conj_trans.h
new file mode 100644
index 000000000..3c319ccaf
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_conj_trans.h
@@ -0,0 +1,71 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
+#define DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
+
+#include "matrix_utilities.h"
+#include "matrix_math_functions.h"
+#include "matrix.h"
+#include "../algs.h"
+#include <cmath>
+#include <complex>
+#include <limits>
+
+
+namespace dlib
+{
+ /*!
+ The point of the two functions defined in this file is to make statements
+ of the form conj(trans(m)) and trans(conj(m)) look the same so that it is
+ easier to map them to BLAS functions later on.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_conj_trans
+ {
+ op_conj_trans( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost+1;
+ const static long NR = M::NC;
+ const static long NC = M::NR;
+ typedef typename M::type type;
+ typedef typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply (long r, long c) const { return std::conj(m(c,r)); }
+
+ long nr () const { return m.nc(); }
+ long nc () const { return m.nr(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <typename EXP>
+ const matrix_op<op_conj_trans<EXP> > trans (
+ const matrix_op<op_conj<EXP> >& m
+ )
+ {
+ typedef op_conj_trans<EXP> op;
+ return matrix_op<op>(op(m.op.m));
+ }
+
+ template <typename EXP>
+ const matrix_op<op_conj_trans<EXP> > conj (
+ const matrix_op<op_trans<EXP> >& m
+ )
+ {
+ typedef op_conj_trans<EXP> op;
+ return matrix_op<op>(op(m.op.m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_conv.h b/ml/dlib/dlib/matrix/matrix_conv.h
new file mode 100644
index 000000000..b90c388bc
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_conv.h
@@ -0,0 +1,358 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_CONV_Hh_
+#define DLIB_MATRIx_CONV_Hh_
+
+#include "matrix_conv_abstract.h"
+#include "matrix.h"
+#include "matrix_fft.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ const T& conj(const T& item) { return item; }
+ template <typename T>
+ std::complex<T> conj(const std::complex<T>& item) { return std::conj(item); }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, bool flip_m2 = false>
+ struct op_conv
+ {
+ op_conv( const M1& m1_, const M2& m2_) :
+ m1(m1_),
+ m2(m2_),
+ nr_(m1.nr()+m2.nr()-1),
+ nc_(m1.nc()+m2.nc()-1)
+ {
+ if (nr_ < 0 || m1.size() == 0 || m2.size() == 0)
+ nr_ = 0;
+ if (nc_ < 0 || m1.size() == 0 || m2.size() == 0)
+ nc_ = 0;
+ }
+
+ const M1& m1;
+ const M2& m2;
+ long nr_;
+ long nc_;
+
+ const static long cost = (M1::cost+M2::cost)*10;
+ const static long NR = (M1::NR*M2::NR==0) ? (0) : (M1::NR+M2::NR-1);
+ const static long NC = (M1::NC*M2::NC==0) ? (0) : (M1::NC+M2::NC-1);
+ typedef typename M1::type type;
+ typedef type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const
+ {
+ type temp = 0;
+
+ const long min_rr = std::max<long>(r-m2.nr()+1, 0);
+ const long max_rr = std::min<long>(m1.nr()-1, r);
+
+ const long min_cc = std::max<long>(c-m2.nc()+1, 0);
+ const long max_cc = std::min<long>(m1.nc()-1, c);
+
+ for (long rr = min_rr; rr <= max_rr; ++rr)
+ {
+ for (long cc = min_cc; cc <= max_cc; ++cc)
+ {
+ if (flip_m2)
+ temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1));
+ else
+ temp += m1(rr,cc)*m2(r-rr,c-cc);
+ }
+ }
+
+ return temp;
+ }
+
+ long nr () const { return nr_; }
+ long nc () const { return nc_; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+
+ };
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv<M1,M2> > conv (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv<M1,M2> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv<M1,M2,true> > xcorr (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv<M1,M2,true> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline size_t bounding_power_of_two (
+ size_t n
+ )
+ {
+ size_t s = 1;
+ for (unsigned int i = 0; i < sizeof(s)*8 && s < n; ++i)
+ s <<= 1;
+ return s;
+ }
+ }
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ typename EXP1::matrix_type xcorr_fft(
+ const matrix_exp<EXP1>& u,
+ const matrix_exp<EXP2>& v
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type, typename EXP2::type>::value == true));
+ using T = typename EXP1::type;
+ COMPILE_TIME_ASSERT((is_same_type<double,T>::value || is_same_type<float,T>::value || is_same_type<long double,T>::value ));
+
+ const long pad_nr = impl::bounding_power_of_two(u.nr() + v.nr() - 1);
+ const long pad_nc = impl::bounding_power_of_two(u.nc() + v.nc() - 1);
+
+ matrix<std::complex<T>> U(pad_nr, pad_nc), V(pad_nr,pad_nc);
+
+ U = 0;
+ V = 0;
+ set_subm(U,U.nr()-u.nr(),U.nc()-u.nc(),u.nr(),u.nc()) = u;
+ set_subm(V,get_rect(v)) = v;
+
+ fft_inplace(U);
+ fft_inplace(V);
+
+ return subm(real(ifft(pointwise_multiply(U, conj(V)))),
+ U.nr()-u.nr()-v.nr()+1,
+ U.nc()-u.nc()-v.nc()+1,
+ u.nr()+v.nr()-1,
+ u.nc()+v.nc()-1
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, bool flip_m2 = false>
+ struct op_conv_same
+ {
+ op_conv_same( const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),nr_(m1.nr()),nc_(m1.nc())
+ {
+ if (m1.size() == 0 || m2.size() == 0)
+ nr_ = 0;
+ if (m1.size() == 0 || m2.size() == 0)
+ nc_ = 0;
+ }
+
+ const M1& m1;
+ const M2& m2;
+ long nr_;
+ long nc_;
+
+ const static long cost = (M1::cost+M2::cost)*10;
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+ typedef typename M1::type type;
+ typedef type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const
+ {
+ r += m2.nr()/2;
+ c += m2.nc()/2;
+
+ type temp = 0;
+
+ const long min_rr = std::max<long>(r-m2.nr()+1, 0);
+ const long max_rr = std::min<long>(m1.nr()-1, r);
+
+ const long min_cc = std::max<long>(c-m2.nc()+1, 0);
+ const long max_cc = std::min<long>(m1.nc()-1, c);
+
+ for (long rr = min_rr; rr <= max_rr; ++rr)
+ {
+ for (long cc = min_cc; cc <= max_cc; ++cc)
+ {
+ if (flip_m2)
+ temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1));
+ else
+ temp += m1(rr,cc)*m2(r-rr,c-cc);
+ }
+ }
+
+ return temp;
+ }
+
+ long nr () const { return nr_; }
+ long nc () const { return nc_; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+
+ };
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv_same<M1,M2> > conv_same (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv_same<M1,M2> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv_same<M1,M2,true> > xcorr_same (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv_same<M1,M2,true> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, bool flip_m2 = false>
+ struct op_conv_valid
+ {
+ op_conv_valid( const M1& m1_, const M2& m2_) :
+ m1(m1_),m2(m2_),
+ nr_(m1.nr()-m2.nr()+1),
+ nc_(m1.nc()-m2.nc()+1)
+ {
+ if (nr_ < 0 || nc_ <= 0 || m1.size() == 0 || m2.size() == 0)
+ nr_ = 0;
+ if (nc_ < 0 || nr_ <= 0 || m1.size() == 0 || m2.size() == 0)
+ nc_ = 0;
+ }
+
+ const M1& m1;
+ const M2& m2;
+ long nr_;
+ long nc_;
+
+ const static long cost = (M1::cost+M2::cost)*10;
+ const static long NR = (M1::NR*M2::NR==0) ? (0) : (M1::NR-M2::NR+1);
+ const static long NC = (M1::NC*M2::NC==0) ? (0) : (M1::NC-M2::NC+1);
+ typedef typename M1::type type;
+ typedef type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const
+ {
+ r += m2.nr()-1;
+ c += m2.nc()-1;
+
+ type temp = 0;
+
+ const long min_rr = std::max<long>(r-m2.nr()+1, 0);
+ const long max_rr = std::min<long>(m1.nr()-1, r);
+
+ const long min_cc = std::max<long>(c-m2.nc()+1, 0);
+ const long max_cc = std::min<long>(m1.nc()-1, c);
+
+ for (long rr = min_rr; rr <= max_rr; ++rr)
+ {
+ for (long cc = min_cc; cc <= max_cc; ++cc)
+ {
+ if (flip_m2)
+ temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1));
+ else
+ temp += m1(rr,cc)*m2(r-rr,c-cc);
+ }
+ }
+
+ return temp;
+ }
+
+ long nr () const { return nr_; }
+ long nc () const { return nc_; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m1.aliases(item) || m2.aliases(item); }
+
+ };
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv_valid<M1,M2> > conv_valid (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv_valid<M1,M2> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+ template <
+ typename M1,
+ typename M2
+ >
+ const matrix_op<op_conv_valid<M1,M2,true> > xcorr_valid (
+ const matrix_exp<M1>& m1,
+ const matrix_exp<M2>& m2
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename M1::type,typename M2::type>::value == true));
+
+ typedef op_conv_valid<M1,M2,true> op;
+ return matrix_op<op>(op(m1.ref(),m2.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_CONV_Hh_
+
diff --git a/ml/dlib/dlib/matrix/matrix_conv_abstract.h b/ml/dlib/dlib/matrix/matrix_conv_abstract.h
new file mode 100644
index 000000000..b342f2668
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_conv_abstract.h
@@ -0,0 +1,158 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_CONV_ABSTRACT_Hh_
+#ifdef DLIB_MATRIx_CONV_ABSTRACT_Hh_
+
+#include "matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp conv (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the convolution of m1 with m2. In particular, this function is
+ equivalent to performing the following in matlab: R = conv2(m1,m2).
+ - R::type == the same type that was in m1 and m2.
+ - R.nr() == m1.nr()+m2.nr()-1
+ - R.nc() == m1.nc()+m2.nc()-1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp xcorr (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the cross-correlation of m1 with m2. In particular, this
+ function returns conv(m1,flip(m2)) if the matrices contain real
+ elements and conv(m1,flip(conj(m2))) if they are complex.
+ - R::type == the same type that was in m1 and m2.
+ - R.nr() == m1.nr()+m2.nr()-1
+ - R.nc() == m1.nc()+m2.nc()-1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp xcorr_fft (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ - m1 and m2 contain real or complex values and must be double, float, or long
+ double valued. (e.g. not integers)
+ ensures
+ - This function is identical to xcorr() except that it uses a fast Fourier
+ transform to do the convolution and is therefore much faster when both m1 and
+ m2 are large.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp conv_same (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the convolution of m1 with m2. In particular, this function is
+ equivalent to performing the following in matlab: R = conv2(m1,m2,'same').
+ In particular, this means the result will have the same dimensions as m1 and will
+ contain the central part of the full convolution. Therefore, conv_same(m1,m2) is
+ equivalent to subm(conv(m1,m2), m2.nr()/2, m2.nc()/2, m1.nr(), m1.nc()).
+ - R::type == the same type that was in m1 and m2.
+ - R.nr() == m1.nr()
+ - R.nc() == m1.nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp xcorr_same (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the cross-correlation of m1 with m2. In particular, this
+ function returns conv_same(m1,flip(m2)) if the matrices contain real
+ elements and conv_same(m1,flip(conj(m2))) if they are complex.
+ - R::type == the same type that was in m1 and m2.
+ - R.nr() == m1.nr()
+ - R.nc() == m1.nc()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp conv_valid (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the convolution of m1 with m2. In particular, this function is
+ equivalent to performing the following in matlab: R = conv2(m1,m2,'valid').
+ In particular, this means only elements of the convolution which don't require
+ zero padding are included in the result.
+ - R::type == the same type that was in m1 and m2.
+ - if (m1 has larger dimensions than m2) then
+ - R.nr() == m1.nr()-m2.nr()+1
+ - R.nc() == m1.nc()-m2.nc()+1
+ - else
+ - R.nr() == 0
+ - R.nc() == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp xcorr_valid (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - m1 and m2 both contain elements of the same type
+ ensures
+ - returns a matrix R such that:
+ - R is the cross-correlation of m1 with m2. In particular, this
+ function returns conv_valid(m1,flip(m2)) if the matrices contain real
+ elements and conv_valid(m1,flip(conj(m2))) if they are complex.
+ - R::type == the same type that was in m1 and m2.
+ - if (m1 has larger dimensions than m2) then
+ - R.nr() == m1.nr()-m2.nr()+1
+ - R.nc() == m1.nc()-m2.nc()+1
+ - else
+ - R.nr() == 0
+ - R.nc() == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_CONV_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_data_layout.h b/ml/dlib/dlib/matrix/matrix_data_layout.h
new file mode 100644
index 000000000..22891c228
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_data_layout.h
@@ -0,0 +1,1271 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_DATA_LAYOUT_
+#define DLIB_MATRIx_DATA_LAYOUT_
+
+#include "../algs.h"
+#include "matrix_fwd.h"
+#include "matrix_data_layout_abstract.h"
+#ifdef MATLAB_MEX_FILE
+#include <mex.h>
+#endif
+
+// GCC 4.8 gives false alarms about some matrix operations going out of bounds. Disable
+// these false warnings.
+#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ A matrix layout object is any object that contains a templated class called "layout"
+ with an interface identical to one below:
+ (Note that all the template arguments are just the template arguments from the dlib::matrix
+ object and the member functions are defined identically to the ones with the same
+ signatures inside the matrix object.)
+
+ struct matrix_layout
+ {
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout
+ {
+ public:
+
+ T& operator() (
+ long r,
+ long c
+ );
+
+ const T& operator() (
+ long r,
+ long c
+ );
+
+ T& operator() (
+ long i
+ );
+
+ const T& operator() (
+ long i
+ ) const;
+
+ void swap(
+ layout& item
+ );
+
+ long nr (
+ ) const;
+
+ long nc (
+ ) const;
+
+ void set_size (
+ long nr_,
+ long nc_
+ );
+ };
+ };
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct row_major_layout
+ {
+ // if a matrix is bigger than this many bytes then don't put it on the stack
+ const static size_t max_stack_based_size = 256;
+
+ // this is a hack to avoid a compile time error in visual studio 8. I would just
+ // use sizeof(T) and be done with it but that won't compile. The idea here
+ // is to avoid using the stack allocation of the layout object if it
+ // is going to contain another matrix and also avoid asking for the sizeof()
+ // the contained matrix.
+ template <typename T>
+ struct get_sizeof_helper
+ {
+ const static std::size_t val = sizeof(T);
+ };
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ struct get_sizeof_helper<matrix<T,NR,NC,mm,l> >
+ {
+ const static std::size_t val = 1000000;
+ };
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager,
+ int val = static_switch <
+ // when the sizes are all non zero and small
+ (num_rows*num_cols*get_sizeof_helper<T>::val <= max_stack_based_size) && (num_rows != 0 && num_cols != 0),
+ // when the sizes are all non zero and big
+ (num_rows*num_cols*get_sizeof_helper<T>::val > max_stack_based_size) && (num_rows != 0 && num_cols != 0),
+ num_rows == 0 && num_cols != 0,
+ num_rows != 0 && num_cols == 0,
+ num_rows == 0 && num_cols == 0
+ >::value
+ >
+ class layout ;
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the actual allocation of space for a matrix.
+ Small matrices allocate all their data on the stack and bigger ones
+ use a memory_manager to get their memory.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,1> : noncopyable // when the sizes are all non zero and small
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout() {}
+
+ T& operator() (
+ long r,
+ long c
+ ) { return *(data+r*num_cols + c); }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return *(data+r*num_cols + c); }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ for (long r = 0; r < num_rows; ++r)
+ {
+ for (long c = 0; c < num_cols; ++c)
+ {
+ exchange((*this)(r,c),item(r,c));
+ }
+ }
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long ,
+ long
+ )
+ {
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+ T data[num_rows*num_cols];
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,2> : noncopyable // when the sizes are all non zero and big
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ) { data = pool.allocate_array(num_rows*num_cols); }
+
+ ~layout ()
+ { pool.deallocate_array(data); }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[r*num_cols + c]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[r*num_cols + c]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long ,
+ long
+ )
+ {
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,3> : noncopyable // when num_rows == 0 && num_cols != 0,
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nr_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ pool.deallocate_array(data);
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[r*num_cols + c]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[r*num_cols + c]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nr_,nr_);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nr_ = nr;
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ long nr_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,4> : noncopyable // when num_rows != 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nc_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[r*nc_ + c]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[r*nc_ + c]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nc_ = nc;
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ long nc_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nr_(0), nc_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[r*nc_ + c]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[r*nc_ + c]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ std::swap(item.nr_,nr_);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nr_ = nr;
+ nc_ = nc;
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+ private:
+ T* data;
+ long nr_;
+ long nc_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct column_major_layout
+ {
+ // if a matrix is bigger than this many bytes then don't put it on the stack
+ const static size_t max_stack_based_size = 256;
+
+
+ // this is a hack to avoid a compile time error in visual studio 8. I would just
+ // use sizeof(T) and be done with it but that won't compile. The idea here
+ // is to avoid using the stack allocation of the layout object if it
+ // is going to contain another matrix and also avoid asking for the sizeof()
+ // the contained matrix.
+ template <typename T>
+ struct get_sizeof_helper
+ {
+ const static std::size_t val = sizeof(T);
+ };
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ struct get_sizeof_helper<matrix<T,NR,NC,mm,l> >
+ {
+ const static std::size_t val = 1000000;
+ };
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager,
+ int val = static_switch <
+ // when the sizes are all non zero and small
+ (num_rows*num_cols*get_sizeof_helper<T>::val <= max_stack_based_size) && (num_rows != 0 && num_cols != 0),
+ // when the sizes are all non zero and big
+ (num_rows*num_cols*get_sizeof_helper<T>::val > max_stack_based_size) && (num_rows != 0 && num_cols != 0),
+ num_rows == 0 && num_cols != 0,
+ num_rows != 0 && num_cols == 0,
+ num_rows == 0 && num_cols == 0
+ >::value
+ >
+ class layout ;
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the actual allocation of space for a matrix.
+ Small matrices allocate all their data on the stack and bigger ones
+ use a memory_manager to get their memory.
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,1> : noncopyable // when the sizes are all non zero and small
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout() {}
+
+ T& operator() (
+ long r,
+ long c
+ ) { return *(data+c*num_rows + r); }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return *(data+c*num_rows + r); }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ for (long r = 0; r < num_rows; ++r)
+ {
+ for (long c = 0; c < num_cols; ++c)
+ {
+ exchange((*this)(r,c),item(r,c));
+ }
+ }
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long,
+ long
+ )
+ {
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+ T data[num_cols*num_rows];
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,2> : noncopyable // when the sizes are all non zero and big
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ) { data = pool.allocate_array(num_rows*num_cols); }
+
+ ~layout ()
+ { pool.deallocate_array(data); }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[c*num_rows + r]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[c*num_rows + r]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long ,
+ long
+ )
+ {
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,3> : noncopyable // when num_rows == 0 && num_cols != 0,
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nr_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ pool.deallocate_array(data);
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[c*nr_ + r]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[c*nr_ + r]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nr_,nr_);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return num_cols; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nr_ = nr;
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ long nr_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,4> : noncopyable // when num_rows != 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nc_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[c*num_rows + r]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[c*num_rows + r]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ pool.swap(item.pool);
+ }
+
+ long nr (
+ ) const { return num_rows; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nc_ = nc;
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ private:
+
+ T* data;
+ long nc_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows,
+ long num_cols,
+ typename mem_manager
+ >
+ class layout<T,num_rows,num_cols,mem_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ):data(0), nr_(0), nc_(0) { }
+
+ ~layout ()
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ }
+
+ T& operator() (
+ long r,
+ long c
+ ) { return data[c*nr_ + r]; }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const { return data[c*nr_ + r]; }
+
+ T& operator() (
+ long i
+ ) { return data[i]; }
+
+ const T& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ std::swap(item.nr_,nr_);
+ pool.swap(item.pool);
+ }
+
+#ifdef MATLAB_MEX_FILE
+ void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
+ mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
+ void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); }
+ bool _private_is_owned_by_matlab() const { return false; }
+#endif
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (data)
+ {
+ pool.deallocate_array(data);
+ }
+ data = pool.allocate_array(nr*nc);
+ nr_ = nr;
+ nc_ = nc;
+ }
+
+ private:
+ T* data;
+ long nr_;
+ long nc_;
+ typename mem_manager::template rebind<T>::other pool;
+ };
+
+#ifdef MATLAB_MEX_FILE
+ template <
+ long num_rows,
+ long num_cols
+ >
+ class layout<double,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ): data(0), nr_(0), nc_(0), owned_by_matlab(false),set_by_private_set_mxArray(false),mem(0) { }
+
+ ~layout ()
+ {
+ if (owned_by_matlab)
+ {
+ if (!set_by_private_set_mxArray && mem)
+ {
+ mxDestroyArray(mem);
+ mem = 0;
+ data = 0;
+ }
+ }
+ else if (data)
+ {
+ delete [] data;
+ data = 0;
+ }
+ }
+
+ double& operator() (
+ long r,
+ long c
+ ) { return data[c*nr_ + r]; }
+
+ const double& operator() (
+ long r,
+ long c
+ ) const { return data[c*nr_ + r]; }
+
+ double& operator() (
+ long i
+ ) { return data[i]; }
+
+ const double& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void _private_set_mxArray (
+ mxArray* mem_
+ )
+ {
+ DLIB_CASSERT(mem == 0 && data == 0,"You can't call this function on an already allocated matrix.");
+ // We don't own the pointer, so make note of that so we won't try to free
+ // it.
+ set_by_private_set_mxArray = true;
+ owned_by_matlab = true;
+ mem = mem_;
+ data = mxGetPr(mem);
+ nr_ = mxGetM(mem);
+ nc_ = mxGetN(mem);
+ }
+
+ mxArray* _private_release_mxArray()
+ {
+ DLIB_CASSERT(owned_by_matlab,"");
+ mxArray* temp = mem;
+ mem = 0;
+ set_by_private_set_mxArray = false;
+ data = 0;
+ nr_ = 0;
+ nc_ = 0;
+ return temp;
+ }
+
+ void _private_mark_owned_by_matlab()
+ {
+ DLIB_CASSERT(mem == 0 && data == 0,"You can't say a matrix should be owned by matlab after it's been allocated.");
+ owned_by_matlab = true;
+ }
+ bool _private_is_owned_by_matlab() const
+ {
+ return owned_by_matlab;
+ }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.owned_by_matlab,owned_by_matlab);
+ std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
+ std::swap(item.mem,mem);
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ std::swap(item.nr_,nr_);
+ }
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (owned_by_matlab)
+ {
+ if (!set_by_private_set_mxArray && mem)
+ {
+ mxDestroyArray(mem);
+ mem = 0;
+ data = 0;
+ }
+ set_by_private_set_mxArray = false;
+
+ mem = mxCreateDoubleMatrix(nr, nc, mxREAL);
+ if (mem == 0)
+ throw std::bad_alloc();
+ data = mxGetPr(mem);
+ }
+ else
+ {
+ if (data)
+ delete [] data;
+ data = new double[nr*nc];
+ }
+ nr_ = nr;
+ nc_ = nc;
+ }
+
+ private:
+ double* data;
+ long nr_;
+ long nc_;
+ bool owned_by_matlab;
+ bool set_by_private_set_mxArray;
+ mxArray* mem;
+ };
+
+ template <
+ long num_rows,
+ long num_cols
+ >
+ class layout<float,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
+ {
+ public:
+ const static long NR = num_rows;
+ const static long NC = num_cols;
+
+ layout (
+ ): data(0), nr_(0), nc_(0), owned_by_matlab(false),set_by_private_set_mxArray(false),mem(0) { }
+
+ ~layout ()
+ {
+ if (owned_by_matlab)
+ {
+ if (!set_by_private_set_mxArray && mem)
+ {
+ mxDestroyArray(mem);
+ mem = 0;
+ data = 0;
+ }
+ }
+ else if (data)
+ {
+ delete [] data;
+ data = 0;
+ }
+ }
+
+ float& operator() (
+ long r,
+ long c
+ ) { return data[c*nr_ + r]; }
+
+ const float& operator() (
+ long r,
+ long c
+ ) const { return data[c*nr_ + r]; }
+
+ float& operator() (
+ long i
+ ) { return data[i]; }
+
+ const float& operator() (
+ long i
+ ) const { return data[i]; }
+
+ void _private_set_mxArray (
+ mxArray* mem_
+ )
+ {
+ DLIB_CASSERT(mem == 0 && data == 0,"You can't call this function on an already allocated matrix.");
+ // We don't own the pointer, so make note of that so we won't try to free
+ // it.
+ set_by_private_set_mxArray = true;
+ owned_by_matlab = true;
+ mem = mem_;
+ data = (float*)mxGetData(mem);
+ nr_ = mxGetM(mem);
+ nc_ = mxGetN(mem);
+ }
+
+ mxArray* _private_release_mxArray()
+ {
+ DLIB_CASSERT(owned_by_matlab,"");
+ mxArray* temp = mem;
+ mem = 0;
+ set_by_private_set_mxArray = false;
+ data = 0;
+ nr_ = 0;
+ nc_ = 0;
+ return temp;
+ }
+
+ void _private_mark_owned_by_matlab()
+ {
+ DLIB_CASSERT(mem == 0 && data == 0,"You can't say a matrix should be owned by matlab after it's been allocated.");
+ owned_by_matlab = true;
+ }
+ bool _private_is_owned_by_matlab() const
+ {
+ return owned_by_matlab;
+ }
+
+ void swap(
+ layout& item
+ )
+ {
+ std::swap(item.owned_by_matlab,owned_by_matlab);
+ std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
+ std::swap(item.mem,mem);
+ std::swap(item.data,data);
+ std::swap(item.nc_,nc_);
+ std::swap(item.nr_,nr_);
+ }
+
+ long nr (
+ ) const { return nr_; }
+
+ long nc (
+ ) const { return nc_; }
+
+ void set_size (
+ long nr,
+ long nc
+ )
+ {
+ if (owned_by_matlab)
+ {
+ if (!set_by_private_set_mxArray && mem)
+ {
+ mxDestroyArray(mem);
+ mem = 0;
+ data = 0;
+ }
+ set_by_private_set_mxArray = false;
+
+ mem = mxCreateNumericMatrix(nr, nc, mxSINGLE_CLASS, mxREAL);
+ if (mem == 0)
+ throw std::bad_alloc();
+ data = (float*)mxGetData(mem);
+ }
+ else
+ {
+ if (data)
+ delete [] data;
+ data = new float[nr*nc];
+ }
+ nr_ = nr;
+ nc_ = nc;
+ }
+
+ private:
+ float* data;
+ long nr_;
+ long nc_;
+ bool owned_by_matlab;
+ bool set_by_private_set_mxArray;
+ mxArray* mem;
+ };
+#endif
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))
+#pragma GCC diagnostic pop
+#endif
+
+#endif // DLIB_MATRIx_DATA_LAYOUT_
+
diff --git a/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h b/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h
new file mode 100644
index 000000000..c3fa02be2
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h
@@ -0,0 +1,40 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_
+#ifdef DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct row_major_layout
+ {
+ /*!
+ This is the default matrix layout. Any matrix object that uses this
+ layout will be laid out in memory in row major order. Additionally,
+ all elements are contiguous (e.g. there isn't any padding at the ends of
+ rows or anything like that)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct column_major_layout
+ {
+ /*!
+ Any matrix object that uses this layout will be laid out in memory in
+ column major order. Additionally, all elements are contiguous (e.g.
+ there isn't any padding at the ends of rows or anything like that)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_default_mul.h b/ml/dlib/dlib/matrix/matrix_default_mul.h
new file mode 100644
index 000000000..493c641a8
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_default_mul.h
@@ -0,0 +1,134 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_
+#define DLIB_MATRIx_DEFAULT_MULTIPLY_
+
+#include "../geometry/rectangle.h"
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "../enable_if.h"
+
+namespace dlib
+{
+
+// ------------------------------------------------------------------------------------
+
+ namespace ma
+ {
+ template < typename EXP, typename enable = void >
+ struct matrix_is_vector { static const bool value = false; };
+ template < typename EXP >
+ struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type > { static const bool value = true; };
+ }
+
+// ------------------------------------------------------------------------------------
+
+ /*! This file defines the default_matrix_multiply() function. It is a function
+ that conforms to the following definition:
+
+ template <
+ typename matrix_dest_type,
+ typename EXP1,
+ typename EXP2
+ >
+ void default_matrix_multiply (
+ matrix_dest_type& dest,
+ const EXP1& lhs,
+ const EXP2& rhs
+ );
+ requires
+ - (lhs*rhs).destructively_aliases(dest) == false
+ - dest.nr() == (lhs*rhs).nr()
+ - dest.nc() == (lhs*rhs).nc()
+ ensures
+ - #dest == dest + lhs*rhs
+ !*/
+
+// ------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename EXP1,
+ typename EXP2
+ >
+ typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type
+ default_matrix_multiply (
+ matrix_dest_type& dest,
+ const EXP1& lhs,
+ const EXP2& rhs
+ )
+ {
+ matrix_assign_default(dest, lhs*rhs, 1, true);
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename EXP1,
+ typename EXP2
+ >
+ typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type
+ default_matrix_multiply (
+ matrix_dest_type& dest,
+ const EXP1& lhs,
+ const EXP2& rhs
+ )
+ {
+ const long bs = 90;
+
+ // if the matrices are small enough then just use the simple multiply algorithm
+ if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
+ {
+ matrix_assign_default(dest, lhs*rhs, 1, true);
+ }
+ else
+ {
+ // if the lhs and rhs matrices are big enough we should use a cache friendly
+ // algorithm that computes the matrix multiply in blocks.
+
+
+ // Loop over all the blocks in the lhs matrix
+ for (long r = 0; r < lhs.nr(); r+=bs)
+ {
+ for (long c = 0; c < lhs.nc(); c+=bs)
+ {
+ // make a rect for the block from lhs
+ rectangle lhs_block(c, r, std::min(c+bs-1,lhs.nc()-1), std::min(r+bs-1,lhs.nr()-1));
+
+ // now loop over all the rhs blocks we have to multiply with the current lhs block
+ for (long i = 0; i < rhs.nc(); i += bs)
+ {
+ // make a rect for the block from rhs
+ rectangle rhs_block(i, c, std::min(i+bs-1,rhs.nc()-1), std::min(c+bs-1,rhs.nr()-1));
+
+ // make a target rect in res
+ rectangle res_block(rhs_block.left(),lhs_block.top(), rhs_block.right(), lhs_block.bottom());
+
+ // This loop is optimized assuming that the data is laid out in
+ // row major order in memory.
+ for (long r = lhs_block.top(); r <= lhs_block.bottom(); ++r)
+ {
+ for (long c = lhs_block.left(); c<= lhs_block.right(); ++c)
+ {
+ const typename EXP2::type temp = lhs(r,c);
+ for (long i = rhs_block.left(); i <= rhs_block.right(); ++i)
+ {
+ dest(r,i) += rhs(c,i)*temp;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ }
+
+// ------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_
+
diff --git a/ml/dlib/dlib/matrix/matrix_eigenvalue.h b/ml/dlib/dlib/matrix/matrix_eigenvalue.h
new file mode 100644
index 000000000..3dc47e105
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_eigenvalue.h
@@ -0,0 +1,1379 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+// This code was adapted from code from the JAMA part of NIST's TNT library.
+// See: http://math.nist.gov/tnt/
+#ifndef DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H
+#define DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+#include <algorithm>
+#include <complex>
+#include <cmath>
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/geev.h"
+#include "lapack/syev.h"
+#include "lapack/syevr.h"
+#endif
+
+#define DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH 4
+
+namespace dlib
+{
+
+ template <
+ typename matrix_exp_type
+ >
+ class eigenvalue_decomposition
+ {
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef typename matrix_exp_type::matrix_type matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+
+ typedef matrix<std::complex<type>,0,0,mem_manager_type,layout_type> complex_matrix_type;
+ typedef matrix<std::complex<type>,NR,1,mem_manager_type,layout_type> complex_column_vector_type;
+
+
+ // You have supplied an invalid type of matrix_exp_type. You have
+ // to use this object with matrices that contain float or double type data.
+ COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
+ is_same_type<double, type>::value ));
+
+
+ template <typename EXP>
+ eigenvalue_decomposition(
+ const matrix_exp<EXP>& A
+ );
+
+ template <typename EXP>
+ eigenvalue_decomposition(
+ const matrix_op<op_make_symmetric<EXP> >& A
+ );
+
+ long dim (
+ ) const;
+
+ const complex_column_vector_type get_eigenvalues (
+ ) const;
+
+ const column_vector_type& get_real_eigenvalues (
+ ) const;
+
+ const column_vector_type& get_imag_eigenvalues (
+ ) const;
+
+ const complex_matrix_type get_v (
+ ) const;
+
+ const complex_matrix_type get_d (
+ ) const;
+
+ const matrix_type& get_pseudo_v (
+ ) const;
+
+ const matrix_type get_pseudo_d (
+ ) const;
+
+ private:
+
+ /** Row and column dimension (square matrix). */
+ long n;
+
+ bool issymmetric;
+
+ /** Arrays for internal storage of eigenvalues. */
+
+ column_vector_type d; /* real part */
+ column_vector_type e; /* img part */
+
+ /** Array for internal storage of eigenvectors. */
+ matrix_type V;
+
+ /** Array for internal storage of nonsymmetric Hessenberg form.
+ @serial internal storage of nonsymmetric Hessenberg form.
+ */
+ matrix_type H;
+
+
+ /** Working storage for nonsymmetric algorithm.
+ @serial working storage for nonsymmetric algorithm.
+ */
+ column_vector_type ort;
+
+ // Symmetric Householder reduction to tridiagonal form.
+ void tred2();
+
+
+ // Symmetric tridiagonal QL algorithm.
+ void tql2 ();
+
+
+ // Nonsymmetric reduction to Hessenberg form.
+ void orthes ();
+
+
+ // Complex scalar division.
+ type cdivr, cdivi;
+ void cdiv_(type xr, type xi, type yr, type yi);
+
+
+ // Nonsymmetric reduction from Hessenberg to real Schur form.
+ void hqr2 ();
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Public member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ eigenvalue_decomposition<matrix_exp_type>::
+ eigenvalue_decomposition(
+ const matrix_exp<EXP>& A_
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+
+ const_temp_matrix<EXP> A(A_);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.nr() == A.nc() && A.size() > 0,
+ "\teigenvalue_decomposition::eigenvalue_decomposition(A)"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ << "\n\tA.size(): " << A.size()
+ << "\n\tthis: " << this
+ );
+
+
+ n = A.nc();
+ V.set_size(n,n);
+ d.set_size(n);
+ e.set_size(n);
+
+
+ issymmetric = true;
+ for (long j = 0; (j < n) && issymmetric; j++)
+ {
+ for (long i = 0; (i < n) && issymmetric; i++)
+ {
+ issymmetric = (A(i,j) == A(j,i));
+ }
+ }
+
+ if (issymmetric)
+ {
+ V = A;
+
+#ifdef DLIB_USE_LAPACK
+ if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH)
+ {
+ e = 0;
+
+ // We could compute the result using syev()
+ //lapack::syev('V', 'L', V, d);
+
+ // Instead, we use syevr because its faster and maybe more stable.
+ matrix_type tempA(A);
+ matrix<lapack::integer,0,0,mem_manager_type,layout_type> isupz;
+
+ lapack::integer temp;
+ lapack::syevr('V','A','L',tempA,0,0,0,0,-1,temp,d,V,isupz);
+ return;
+ }
+#endif
+ // Tridiagonalize.
+ tred2();
+
+ // Diagonalize.
+ tql2();
+
+ }
+ else
+ {
+
+#ifdef DLIB_USE_LAPACK
+ if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH)
+ {
+ matrix<type,0,0,mem_manager_type, column_major_layout> temp, vl, vr;
+ temp = A;
+ lapack::geev('N', 'V', temp, d, e, vl, vr);
+ V = vr;
+ return;
+ }
+#endif
+ H = A;
+
+ ort.set_size(n);
+
+ // Reduce to Hessenberg form.
+ orthes();
+
+ // Reduce Hessenberg to real Schur form.
+ hqr2();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ eigenvalue_decomposition<matrix_exp_type>::
+ eigenvalue_decomposition(
+ const matrix_op<op_make_symmetric<EXP> >& A
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.nr() == A.nc() && A.size() > 0,
+ "\teigenvalue_decomposition::eigenvalue_decomposition(A)"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ << "\n\tA.size(): " << A.size()
+ << "\n\tthis: " << this
+ );
+
+
+ n = A.nc();
+ V.set_size(n,n);
+ d.set_size(n);
+ e.set_size(n);
+
+
+ V = A;
+
+#ifdef DLIB_USE_LAPACK
+ if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH)
+ {
+ e = 0;
+
+ // We could compute the result using syev()
+ //lapack::syev('V', 'L', V, d);
+
+ // Instead, we use syevr because its faster and maybe more stable.
+ matrix_type tempA(A);
+ matrix<lapack::integer,0,0,mem_manager_type,layout_type> isupz;
+
+ lapack::integer temp;
+ lapack::syevr('V','A','L',tempA,0,0,0,0,-1,temp,d,V,isupz);
+ return;
+ }
+#endif
+ // Tridiagonalize.
+ tred2();
+
+ // Diagonalize.
+ tql2();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::matrix_type& eigenvalue_decomposition<matrix_exp_type>::
+ get_pseudo_v (
+ ) const
+ {
+ return V;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long eigenvalue_decomposition<matrix_exp_type>::
+ dim (
+ ) const
+ {
+ return V.nr();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::complex_column_vector_type eigenvalue_decomposition<matrix_exp_type>::
+ get_eigenvalues (
+ ) const
+ {
+ return complex_matrix(get_real_eigenvalues(), get_imag_eigenvalues());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::column_vector_type& eigenvalue_decomposition<matrix_exp_type>::
+ get_real_eigenvalues (
+ ) const
+ {
+ return d;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::column_vector_type& eigenvalue_decomposition<matrix_exp_type>::
+ get_imag_eigenvalues (
+ ) const
+ {
+ return e;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::complex_matrix_type eigenvalue_decomposition<matrix_exp_type>::
+ get_d (
+ ) const
+ {
+ return diagm(complex_matrix(get_real_eigenvalues(), get_imag_eigenvalues()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::complex_matrix_type eigenvalue_decomposition<matrix_exp_type>::
+ get_v (
+ ) const
+ {
+ complex_matrix_type CV(n,n);
+
+ for (long i = 0; i < n; i++)
+ {
+ if (e(i) > 0)
+ {
+ set_colm(CV,i) = complex_matrix(colm(V,i), colm(V,i+1));
+ }
+ else if (e(i) < 0)
+ {
+ set_colm(CV,i) = complex_matrix(colm(V,i), colm(V,i-1));
+ }
+ else
+ {
+ set_colm(CV,i) = complex_matrix(colm(V,i), uniform_matrix<type>(n,1,0));
+ }
+ }
+
+ return CV;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename eigenvalue_decomposition<matrix_exp_type>::matrix_type eigenvalue_decomposition<matrix_exp_type>::
+ get_pseudo_d (
+ ) const
+ {
+ matrix_type D(n,n);
+
+ for (long i = 0; i < n; i++)
+ {
+ for (long j = 0; j < n; j++)
+ {
+ D(i,j) = 0.0;
+ }
+ D(i,i) = d(i);
+ if (e(i) > 0)
+ {
+ D(i,i+1) = e(i);
+ }
+ else if (e(i) < 0)
+ {
+ D(i,i-1) = e(i);
+ }
+ }
+
+ return D;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Private member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// Symmetric Householder reduction to tridiagonal form.
+ template <typename matrix_exp_type>
+ void eigenvalue_decomposition<matrix_exp_type>::
+ tred2()
+ {
+ using std::abs;
+ using std::sqrt;
+
+ // This is derived from the Algol procedures tred2 by
+ // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
+ // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutine in EISPACK.
+
+ for (long j = 0; j < n; j++)
+ {
+ d(j) = V(n-1,j);
+ }
+
+ // Householder reduction to tridiagonal form.
+
+ for (long i = n-1; i > 0; i--)
+ {
+
+ // Scale to avoid under/overflow.
+
+ type scale = 0.0;
+ type h = 0.0;
+ for (long k = 0; k < i; k++)
+ {
+ scale = scale + abs(d(k));
+ }
+ if (scale == 0.0)
+ {
+ e(i) = d(i-1);
+ for (long j = 0; j < i; j++)
+ {
+ d(j) = V(i-1,j);
+ V(i,j) = 0.0;
+ V(j,i) = 0.0;
+ }
+ }
+ else
+ {
+
+ // Generate Householder vector.
+
+ for (long k = 0; k < i; k++)
+ {
+ d(k) /= scale;
+ h += d(k) * d(k);
+ }
+ type f = d(i-1);
+ type g = sqrt(h);
+ if (f > 0)
+ {
+ g = -g;
+ }
+ e(i) = scale * g;
+ h = h - f * g;
+ d(i-1) = f - g;
+ for (long j = 0; j < i; j++)
+ {
+ e(j) = 0.0;
+ }
+
+ // Apply similarity transformation to remaining columns.
+
+ for (long j = 0; j < i; j++)
+ {
+ f = d(j);
+ V(j,i) = f;
+ g = e(j) + V(j,j) * f;
+ for (long k = j+1; k <= i-1; k++)
+ {
+ g += V(k,j) * d(k);
+ e(k) += V(k,j) * f;
+ }
+ e(j) = g;
+ }
+ f = 0.0;
+ for (long j = 0; j < i; j++)
+ {
+ e(j) /= h;
+ f += e(j) * d(j);
+ }
+ type hh = f / (h + h);
+ for (long j = 0; j < i; j++)
+ {
+ e(j) -= hh * d(j);
+ }
+ for (long j = 0; j < i; j++)
+ {
+ f = d(j);
+ g = e(j);
+ for (long k = j; k <= i-1; k++)
+ {
+ V(k,j) -= (f * e(k) + g * d(k));
+ }
+ d(j) = V(i-1,j);
+ V(i,j) = 0.0;
+ }
+ }
+ d(i) = h;
+ }
+
+ // Accumulate transformations.
+
+ for (long i = 0; i < n-1; i++)
+ {
+ V(n-1,i) = V(i,i);
+ V(i,i) = 1.0;
+ type h = d(i+1);
+ if (h != 0.0)
+ {
+ for (long k = 0; k <= i; k++)
+ {
+ d(k) = V(k,i+1) / h;
+ }
+ for (long j = 0; j <= i; j++)
+ {
+ type g = 0.0;
+ for (long k = 0; k <= i; k++)
+ {
+ g += V(k,i+1) * V(k,j);
+ }
+ for (long k = 0; k <= i; k++)
+ {
+ V(k,j) -= g * d(k);
+ }
+ }
+ }
+ for (long k = 0; k <= i; k++)
+ {
+ V(k,i+1) = 0.0;
+ }
+ }
+ for (long j = 0; j < n; j++)
+ {
+ d(j) = V(n-1,j);
+ V(n-1,j) = 0.0;
+ }
+ V(n-1,n-1) = 1.0;
+ e(0) = 0.0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ void eigenvalue_decomposition<matrix_exp_type>::
+ tql2 ()
+ {
+ using std::pow;
+ using std::min;
+ using std::max;
+ using std::abs;
+
+ // This is derived from the Algol procedures tql2, by
+ // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
+ // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutine in EISPACK.
+
+ for (long i = 1; i < n; i++)
+ {
+ e(i-1) = e(i);
+ }
+ e(n-1) = 0.0;
+
+ type f = 0.0;
+ type tst1 = 0.0;
+ const type eps = std::numeric_limits<type>::epsilon();
+ for (long l = 0; l < n; l++)
+ {
+
+ // Find small subdiagonal element
+
+ tst1 = max(tst1,abs(d(l)) + abs(e(l)));
+ long m = l;
+
+ // Original while-loop from Java code
+ while (m < n)
+ {
+ if (abs(e(m)) <= eps*tst1)
+ {
+ break;
+ }
+ m++;
+ }
+ if (m == n)
+ --m;
+
+
+ // If m == l, d(l) is an eigenvalue,
+ // otherwise, iterate.
+
+ if (m > l)
+ {
+ long iter = 0;
+ do
+ {
+ iter = iter + 1; // (Could check iteration count here.)
+
+ // Compute implicit shift
+
+ type g = d(l);
+ type p = (d(l+1) - g) / (2.0 * e(l));
+ type r = hypot(p,(type)1.0);
+ if (p < 0)
+ {
+ r = -r;
+ }
+ d(l) = e(l) / (p + r);
+ d(l+1) = e(l) * (p + r);
+ type dl1 = d(l+1);
+ type h = g - d(l);
+ for (long i = l+2; i < n; i++)
+ {
+ d(i) -= h;
+ }
+ f = f + h;
+
+ // Implicit QL transformation.
+
+ p = d(m);
+ type c = 1.0;
+ type c2 = c;
+ type c3 = c;
+ type el1 = e(l+1);
+ type s = 0.0;
+ type s2 = 0.0;
+ for (long i = m-1; i >= l; i--)
+ {
+ c3 = c2;
+ c2 = c;
+ s2 = s;
+ g = c * e(i);
+ h = c * p;
+ r = hypot(p,e(i));
+ e(i+1) = s * r;
+ s = e(i) / r;
+ c = p / r;
+ p = c * d(i) - s * g;
+ d(i+1) = h + s * (c * g + s * d(i));
+
+ // Accumulate transformation.
+
+ for (long k = 0; k < n; k++)
+ {
+ h = V(k,i+1);
+ V(k,i+1) = s * V(k,i) + c * h;
+ V(k,i) = c * V(k,i) - s * h;
+ }
+ }
+ p = -s * s2 * c3 * el1 * e(l) / dl1;
+ e(l) = s * p;
+ d(l) = c * p;
+
+ // Check for convergence.
+
+ } while (abs(e(l)) > eps*tst1);
+ }
+ d(l) = d(l) + f;
+ e(l) = 0.0;
+ }
+
+ /*
+ The code to sort the eigenvalues and eigenvectors
+ has been removed from here since, in the non-symmetric case,
+ we can't sort the eigenvalues in a meaningful way. If we left this
+ code in here then the user might supply what they thought was a symmetric
+ matrix but was actually slightly non-symmetric due to rounding error
+ and then they would end up in the non-symmetric eigenvalue solver
+ where the eigenvalues don't end up getting sorted. So to avoid
+ any possible user confusion I'm just removing this.
+ */
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ void eigenvalue_decomposition<matrix_exp_type>::
+ orthes ()
+ {
+ using std::abs;
+ using std::sqrt;
+
+ // This is derived from the Algol procedures orthes and ortran,
+ // by Martin and Wilkinson, Handbook for Auto. Comp.,
+ // Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutines in EISPACK.
+
+ long low = 0;
+ long high = n-1;
+
+ for (long m = low+1; m <= high-1; m++)
+ {
+
+ // Scale column.
+
+ type scale = 0.0;
+ for (long i = m; i <= high; i++)
+ {
+ scale = scale + abs(H(i,m-1));
+ }
+ if (scale != 0.0)
+ {
+
+ // Compute Householder transformation.
+
+ type h = 0.0;
+ for (long i = high; i >= m; i--)
+ {
+ ort(i) = H(i,m-1)/scale;
+ h += ort(i) * ort(i);
+ }
+ type g = sqrt(h);
+ if (ort(m) > 0)
+ {
+ g = -g;
+ }
+ h = h - ort(m) * g;
+ ort(m) = ort(m) - g;
+
+ // Apply Householder similarity transformation
+ // H = (I-u*u'/h)*H*(I-u*u')/h)
+
+ for (long j = m; j < n; j++)
+ {
+ type f = 0.0;
+ for (long i = high; i >= m; i--)
+ {
+ f += ort(i)*H(i,j);
+ }
+ f = f/h;
+ for (long i = m; i <= high; i++)
+ {
+ H(i,j) -= f*ort(i);
+ }
+ }
+
+ for (long i = 0; i <= high; i++)
+ {
+ type f = 0.0;
+ for (long j = high; j >= m; j--)
+ {
+ f += ort(j)*H(i,j);
+ }
+ f = f/h;
+ for (long j = m; j <= high; j++)
+ {
+ H(i,j) -= f*ort(j);
+ }
+ }
+ ort(m) = scale*ort(m);
+ H(m,m-1) = scale*g;
+ }
+ }
+
+ // Accumulate transformations (Algol's ortran).
+
+ for (long i = 0; i < n; i++)
+ {
+ for (long j = 0; j < n; j++)
+ {
+ V(i,j) = (i == j ? 1.0 : 0.0);
+ }
+ }
+
+ for (long m = high-1; m >= low+1; m--)
+ {
+ if (H(m,m-1) != 0.0)
+ {
+ for (long i = m+1; i <= high; i++)
+ {
+ ort(i) = H(i,m-1);
+ }
+ for (long j = m; j <= high; j++)
+ {
+ type g = 0.0;
+ for (long i = m; i <= high; i++)
+ {
+ g += ort(i) * V(i,j);
+ }
+ // Double division avoids possible underflow
+ g = (g / ort(m)) / H(m,m-1);
+ for (long i = m; i <= high; i++)
+ {
+ V(i,j) += g * ort(i);
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ void eigenvalue_decomposition<matrix_exp_type>::
+ cdiv_(type xr, type xi, type yr, type yi)
+ {
+ using std::abs;
+ type r,d;
+ if (abs(yr) > abs(yi))
+ {
+ r = yi/yr;
+ d = yr + r*yi;
+ cdivr = (xr + r*xi)/d;
+ cdivi = (xi - r*xr)/d;
+ }
+ else
+ {
+ r = yr/yi;
+ d = yi + r*yr;
+ cdivr = (r*xr + xi)/d;
+ cdivi = (r*xi - xr)/d;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ void eigenvalue_decomposition<matrix_exp_type>::
+ hqr2 ()
+ {
+ using std::pow;
+ using std::min;
+ using std::max;
+ using std::abs;
+ using std::sqrt;
+
+ // This is derived from the Algol procedure hqr2,
+ // by Martin and Wilkinson, Handbook for Auto. Comp.,
+ // Vol.ii-Linear Algebra, and the corresponding
+ // Fortran subroutine in EISPACK.
+
+ // Initialize
+
+ long nn = this->n;
+ long n = nn-1;
+ long low = 0;
+ long high = nn-1;
+ const type eps = std::numeric_limits<type>::epsilon();
+ type exshift = 0.0;
+ type p=0,q=0,r=0,s=0,z=0,t,w,x,y;
+
+ // Store roots isolated by balanc and compute matrix norm
+
+ type norm = 0.0;
+ for (long i = 0; i < nn; i++)
+ {
+ if ((i < low) || (i > high))
+ {
+ d(i) = H(i,i);
+ e(i) = 0.0;
+ }
+ for (long j = max(i-1,0L); j < nn; j++)
+ {
+ norm = norm + abs(H(i,j));
+ }
+ }
+
+ // Outer loop over eigenvalue index
+
+ long iter = 0;
+ while (n >= low)
+ {
+
+ // Look for single small sub-diagonal element
+
+ long l = n;
+ while (l > low)
+ {
+ s = abs(H(l-1,l-1)) + abs(H(l,l));
+ if (s == 0.0)
+ {
+ s = norm;
+ }
+ if (abs(H(l,l-1)) < eps * s)
+ {
+ break;
+ }
+ l--;
+ }
+
+ // Check for convergence
+ // One root found
+
+ if (l == n)
+ {
+ H(n,n) = H(n,n) + exshift;
+ d(n) = H(n,n);
+ e(n) = 0.0;
+ n--;
+ iter = 0;
+
+ // Two roots found
+
+ }
+ else if (l == n-1)
+ {
+ w = H(n,n-1) * H(n-1,n);
+ p = (H(n-1,n-1) - H(n,n)) / 2.0;
+ q = p * p + w;
+ z = sqrt(abs(q));
+ H(n,n) = H(n,n) + exshift;
+ H(n-1,n-1) = H(n-1,n-1) + exshift;
+ x = H(n,n);
+
+ // type pair
+
+ if (q >= 0)
+ {
+ if (p >= 0)
+ {
+ z = p + z;
+ }
+ else
+ {
+ z = p - z;
+ }
+ d(n-1) = x + z;
+ d(n) = d(n-1);
+ if (z != 0.0)
+ {
+ d(n) = x - w / z;
+ }
+ e(n-1) = 0.0;
+ e(n) = 0.0;
+ x = H(n,n-1);
+ s = abs(x) + abs(z);
+ p = x / s;
+ q = z / s;
+ r = sqrt(p * p+q * q);
+ p = p / r;
+ q = q / r;
+
+ // Row modification
+
+ for (long j = n-1; j < nn; j++)
+ {
+ z = H(n-1,j);
+ H(n-1,j) = q * z + p * H(n,j);
+ H(n,j) = q * H(n,j) - p * z;
+ }
+
+ // Column modification
+
+ for (long i = 0; i <= n; i++)
+ {
+ z = H(i,n-1);
+ H(i,n-1) = q * z + p * H(i,n);
+ H(i,n) = q * H(i,n) - p * z;
+ }
+
+ // Accumulate transformations
+
+ for (long i = low; i <= high; i++)
+ {
+ z = V(i,n-1);
+ V(i,n-1) = q * z + p * V(i,n);
+ V(i,n) = q * V(i,n) - p * z;
+ }
+
+ // Complex pair
+
+ }
+ else
+ {
+ d(n-1) = x + p;
+ d(n) = x + p;
+ e(n-1) = z;
+ e(n) = -z;
+ }
+ n = n - 2;
+ iter = 0;
+
+ // No convergence yet
+
+ }
+ else
+ {
+
+ // Form shift
+
+ x = H(n,n);
+ y = 0.0;
+ w = 0.0;
+ if (l < n)
+ {
+ y = H(n-1,n-1);
+ w = H(n,n-1) * H(n-1,n);
+ }
+
+ // Wilkinson's original ad hoc shift
+
+ if (iter == 10)
+ {
+ exshift += x;
+ for (long i = low; i <= n; i++)
+ {
+ H(i,i) -= x;
+ }
+ s = abs(H(n,n-1)) + abs(H(n-1,n-2));
+ x = y = 0.75 * s;
+ w = -0.4375 * s * s;
+ }
+
+ // MATLAB's new ad hoc shift
+
+ if (iter == 30)
+ {
+ s = (y - x) / 2.0;
+ s = s * s + w;
+ if (s > 0)
+ {
+ s = sqrt(s);
+ if (y < x)
+ {
+ s = -s;
+ }
+ s = x - w / ((y - x) / 2.0 + s);
+ for (long i = low; i <= n; i++)
+ {
+ H(i,i) -= s;
+ }
+ exshift += s;
+ x = y = w = 0.964;
+ }
+ }
+
+ iter = iter + 1; // (Could check iteration count here.)
+
+ // Look for two consecutive small sub-diagonal elements
+
+ long m = n-2;
+ while (m >= l)
+ {
+ z = H(m,m);
+ r = x - z;
+ s = y - z;
+ p = (r * s - w) / H(m+1,m) + H(m,m+1);
+ q = H(m+1,m+1) - z - r - s;
+ r = H(m+2,m+1);
+ s = abs(p) + abs(q) + abs(r);
+ p = p / s;
+ q = q / s;
+ r = r / s;
+ if (m == l)
+ {
+ break;
+ }
+ if (abs(H(m,m-1)) * (abs(q) + abs(r)) <
+ eps * (abs(p) * (abs(H(m-1,m-1)) + abs(z) +
+ abs(H(m+1,m+1)))))
+ {
+ break;
+ }
+ m--;
+ }
+
+ for (long i = m+2; i <= n; i++)
+ {
+ H(i,i-2) = 0.0;
+ if (i > m+2)
+ {
+ H(i,i-3) = 0.0;
+ }
+ }
+
+ // Double QR step involving rows l:n and columns m:n
+
+ for (long k = m; k <= n-1; k++)
+ {
+ long notlast = (k != n-1);
+ if (k != m)
+ {
+ p = H(k,k-1);
+ q = H(k+1,k-1);
+ r = (notlast ? H(k+2,k-1) : 0.0);
+ x = abs(p) + abs(q) + abs(r);
+ if (x != 0.0)
+ {
+ p = p / x;
+ q = q / x;
+ r = r / x;
+ }
+ }
+ if (x == 0.0)
+ {
+ break;
+ }
+ s = sqrt(p * p + q * q + r * r);
+ if (p < 0)
+ {
+ s = -s;
+ }
+ if (s != 0)
+ {
+ if (k != m)
+ {
+ H(k,k-1) = -s * x;
+ }
+ else if (l != m)
+ {
+ H(k,k-1) = -H(k,k-1);
+ }
+ p = p + s;
+ x = p / s;
+ y = q / s;
+ z = r / s;
+ q = q / p;
+ r = r / p;
+
+ // Row modification
+
+ for (long j = k; j < nn; j++)
+ {
+ p = H(k,j) + q * H(k+1,j);
+ if (notlast)
+ {
+ p = p + r * H(k+2,j);
+ H(k+2,j) = H(k+2,j) - p * z;
+ }
+ H(k,j) = H(k,j) - p * x;
+ H(k+1,j) = H(k+1,j) - p * y;
+ }
+
+ // Column modification
+
+ for (long i = 0; i <= min(n,k+3); i++)
+ {
+ p = x * H(i,k) + y * H(i,k+1);
+ if (notlast)
+ {
+ p = p + z * H(i,k+2);
+ H(i,k+2) = H(i,k+2) - p * r;
+ }
+ H(i,k) = H(i,k) - p;
+ H(i,k+1) = H(i,k+1) - p * q;
+ }
+
+ // Accumulate transformations
+
+ for (long i = low; i <= high; i++)
+ {
+ p = x * V(i,k) + y * V(i,k+1);
+ if (notlast)
+ {
+ p = p + z * V(i,k+2);
+ V(i,k+2) = V(i,k+2) - p * r;
+ }
+ V(i,k) = V(i,k) - p;
+ V(i,k+1) = V(i,k+1) - p * q;
+ }
+ } // (s != 0)
+ } // k loop
+ } // check convergence
+ } // while (n >= low)
+
+ // Backsubstitute to find vectors of upper triangular form
+
+ if (norm == 0.0)
+ {
+ return;
+ }
+
+ for (n = nn-1; n >= 0; n--)
+ {
+ p = d(n);
+ q = e(n);
+
+ // Real vector
+
+ if (q == 0)
+ {
+ long l = n;
+ H(n,n) = 1.0;
+ for (long i = n-1; i >= 0; i--)
+ {
+ w = H(i,i) - p;
+ r = 0.0;
+ for (long j = l; j <= n; j++)
+ {
+ r = r + H(i,j) * H(j,n);
+ }
+ if (e(i) < 0.0)
+ {
+ z = w;
+ s = r;
+ }
+ else
+ {
+ l = i;
+ if (e(i) == 0.0)
+ {
+ if (w != 0.0)
+ {
+ H(i,n) = -r / w;
+ }
+ else
+ {
+ H(i,n) = -r / (eps * norm);
+ }
+
+ // Solve real equations
+
+ }
+ else
+ {
+ x = H(i,i+1);
+ y = H(i+1,i);
+ q = (d(i) - p) * (d(i) - p) + e(i) * e(i);
+ t = (x * s - z * r) / q;
+ H(i,n) = t;
+ if (abs(x) > abs(z))
+ {
+ H(i+1,n) = (-r - w * t) / x;
+ }
+ else
+ {
+ H(i+1,n) = (-s - y * t) / z;
+ }
+ }
+
+ // Overflow control
+
+ t = abs(H(i,n));
+ if ((eps * t) * t > 1)
+ {
+ for (long j = i; j <= n; j++)
+ {
+ H(j,n) = H(j,n) / t;
+ }
+ }
+ }
+ }
+
+ // Complex vector
+
+ }
+ else if (q < 0)
+ {
+ long l = n-1;
+
+ // Last vector component imaginary so matrix is triangular
+
+ if (abs(H(n,n-1)) > abs(H(n-1,n)))
+ {
+ H(n-1,n-1) = q / H(n,n-1);
+ H(n-1,n) = -(H(n,n) - p) / H(n,n-1);
+ }
+ else
+ {
+ cdiv_(0.0,-H(n-1,n),H(n-1,n-1)-p,q);
+ H(n-1,n-1) = cdivr;
+ H(n-1,n) = cdivi;
+ }
+ H(n,n-1) = 0.0;
+ H(n,n) = 1.0;
+ for (long i = n-2; i >= 0; i--)
+ {
+ type ra,sa,vr,vi;
+ ra = 0.0;
+ sa = 0.0;
+ for (long j = l; j <= n; j++)
+ {
+ ra = ra + H(i,j) * H(j,n-1);
+ sa = sa + H(i,j) * H(j,n);
+ }
+ w = H(i,i) - p;
+
+ if (e(i) < 0.0)
+ {
+ z = w;
+ r = ra;
+ s = sa;
+ }
+ else
+ {
+ l = i;
+ if (e(i) == 0)
+ {
+ cdiv_(-ra,-sa,w,q);
+ H(i,n-1) = cdivr;
+ H(i,n) = cdivi;
+ }
+ else
+ {
+
+ // Solve complex equations
+
+ x = H(i,i+1);
+ y = H(i+1,i);
+ vr = (d(i) - p) * (d(i) - p) + e(i) * e(i) - q * q;
+ vi = (d(i) - p) * 2.0 * q;
+ if ((vr == 0.0) && (vi == 0.0))
+ {
+ vr = eps * norm * (abs(w) + abs(q) +
+ abs(x) + abs(y) + abs(z));
+ }
+ cdiv_(x*r-z*ra+q*sa,x*s-z*sa-q*ra,vr,vi);
+ H(i,n-1) = cdivr;
+ H(i,n) = cdivi;
+ if (abs(x) > (abs(z) + abs(q)))
+ {
+ H(i+1,n-1) = (-ra - w * H(i,n-1) + q * H(i,n)) / x;
+ H(i+1,n) = (-sa - w * H(i,n) - q * H(i,n-1)) / x;
+ }
+ else
+ {
+ cdiv_(-r-y*H(i,n-1),-s-y*H(i,n),z,q);
+ H(i+1,n-1) = cdivr;
+ H(i+1,n) = cdivi;
+ }
+ }
+
+ // Overflow control
+
+ t = max(abs(H(i,n-1)),abs(H(i,n)));
+ if ((eps * t) * t > 1)
+ {
+ for (long j = i; j <= n; j++)
+ {
+ H(j,n-1) = H(j,n-1) / t;
+ H(j,n) = H(j,n) / t;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Vectors of isolated roots
+
+ for (long i = 0; i < nn; i++)
+ {
+ if (i < low || i > high)
+ {
+ for (long j = i; j < nn; j++)
+ {
+ V(i,j) = H(i,j);
+ }
+ }
+ }
+
+ // Back transformation to get eigenvectors of original matrix
+
+ for (long j = nn-1; j >= low; j--)
+ {
+ for (long i = low; i <= high; i++)
+ {
+ z = 0.0;
+ for (long k = low; k <= min(j,high); k++)
+ {
+ z = z + V(i,k) * H(k,j);
+ }
+ V(i,j) = z;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H
+
+
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_exp.h b/ml/dlib/dlib/matrix/matrix_exp.h
new file mode 100644
index 000000000..c0afb54c0
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_exp.h
@@ -0,0 +1,271 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_EXP_h_
+#define DLIB_MATRIx_EXP_h_
+
+#include "../algs.h"
+#include "../is_kind.h"
+#include "matrix_fwd.h"
+#include "matrix_exp_abstract.h"
+#include <iterator>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // We want to return the compile time constant if our NR and NC dimensions
+ // aren't zero but if they are then we want to call ref_.nx() and return
+ // the correct values.
+ template < typename exp_type, long NR >
+ struct get_nr_helper
+ {
+ static inline long get(const exp_type&) { return NR; }
+ };
+
+ template < typename exp_type >
+ struct get_nr_helper<exp_type,0>
+ {
+ static inline long get(const exp_type& m) { return m.nr(); }
+ };
+
+ template < typename exp_type, long NC >
+ struct get_nc_helper
+ {
+ static inline long get(const exp_type&) { return NC; }
+ };
+
+ template < typename exp_type >
+ struct get_nc_helper<exp_type,0>
+ {
+ static inline long get(const exp_type& m) { return m.nc(); }
+ };
+
+ template <typename EXP>
+ struct matrix_traits
+ {
+ typedef typename EXP::type type;
+ typedef typename EXP::const_ret_type const_ret_type;
+ typedef typename EXP::mem_manager_type mem_manager_type;
+ typedef typename EXP::layout_type layout_type;
+ const static long NR = EXP::NR;
+ const static long NC = EXP::NC;
+ const static long cost = EXP::cost;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP> class matrix_exp;
+ template <typename EXP>
+ class matrix_exp_iterator : public std::iterator<std::forward_iterator_tag, typename matrix_traits<EXP>::type>
+ {
+ friend class matrix_exp<EXP>;
+ matrix_exp_iterator(const EXP& m_, long r_, long c_)
+ {
+ r = r_;
+ c = c_;
+ nc = m_.nc();
+ m = &m_;
+ }
+
+ public:
+
+ matrix_exp_iterator() : r(0), c(0), nc(0), m(0) {}
+
+ typedef typename matrix_traits<EXP>::type type;
+ typedef type value_type;
+ typedef typename matrix_traits<EXP>::const_ret_type const_ret_type;
+
+
+ bool operator == ( const matrix_exp_iterator& itr) const
+ { return r == itr.r && c == itr.c; }
+
+ bool operator != ( const matrix_exp_iterator& itr) const
+ { return !(*this == itr); }
+
+ matrix_exp_iterator& operator++()
+ {
+ ++c;
+ if (c==nc)
+ {
+ c = 0;
+ ++r;
+ }
+ return *this;
+ }
+
+ matrix_exp_iterator operator++(int)
+ {
+ matrix_exp_iterator temp(*this);
+ ++(*this);
+ return temp;
+ }
+
+ const_ret_type operator* () const { return (*m)(r,c); }
+
+ private:
+ long r, c;
+ long nc;
+ const EXP* m;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ class matrix_exp
+ {
+ /*!
+ REQUIREMENTS ON EXP
+ EXP should be something convertible to a matrix_exp. That is,
+ it should inherit from matrix_exp
+ !*/
+
+ public:
+ typedef typename matrix_traits<EXP>::type type;
+ typedef type value_type;
+ typedef typename matrix_traits<EXP>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<EXP>::mem_manager_type mem_manager_type;
+ typedef typename matrix_traits<EXP>::layout_type layout_type;
+ const static long NR = matrix_traits<EXP>::NR;
+ const static long NC = matrix_traits<EXP>::NC;
+ const static long cost = matrix_traits<EXP>::cost;
+
+ typedef matrix<type,NR,NC,mem_manager_type,layout_type> matrix_type;
+ typedef EXP exp_type;
+ typedef matrix_exp_iterator<EXP> iterator;
+ typedef matrix_exp_iterator<EXP> const_iterator;
+
+ inline const_ret_type operator() (
+ long r,
+ long c
+ ) const
+ {
+ DLIB_ASSERT(r < nr() && c < nc() && r >= 0 && c >= 0,
+ "\tconst type matrix_exp::operator(r,c)"
+ << "\n\tYou must give a valid row and column"
+ << "\n\tr: " << r
+ << "\n\tc: " << c
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ return ref()(r,c);
+ }
+
+ const_ret_type operator() (
+ long i
+ ) const
+ {
+ COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0);
+ DLIB_ASSERT(nc() == 1 || nr() == 1,
+ "\tconst type matrix_exp::operator(i)"
+ << "\n\tYou can only use this operator on column or row vectors"
+ << "\n\ti: " << i
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT( ((nc() == 1 && i < nr()) || (nr() == 1 && i < nc())) && i >= 0,
+ "\tconst type matrix_exp::operator(i)"
+ << "\n\tYou must give a valid row/column number"
+ << "\n\ti: " << i
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+ if (nc() == 1)
+ return ref()(i,0);
+ else
+ return ref()(0,i);
+ }
+
+ long size (
+ ) const { return nr()*nc(); }
+
+ long nr (
+ ) const { return get_nr_helper<exp_type,NR>::get(ref()); }
+
+ long nc (
+ ) const { return get_nc_helper<exp_type,NC>::get(ref()); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return ref().aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return ref().destructively_aliases(item); }
+
+ inline const exp_type& ref (
+ ) const { return *static_cast<const exp_type*>(this); }
+
+ inline operator const type (
+ ) const
+ {
+ COMPILE_TIME_ASSERT(NC == 1 || NC == 0);
+ COMPILE_TIME_ASSERT(NR == 1 || NR == 0);
+ DLIB_ASSERT(nr() == 1 && nc() == 1,
+ "\tmatrix_exp::operator const type() const"
+ << "\n\tYou can only use this operator on a 1x1 matrix"
+ << "\n\tnr(): " << nr()
+ << "\n\tnc(): " << nc()
+ << "\n\tthis: " << this
+ );
+
+ // Put the expression contained in this matrix_exp into
+ // a temporary 1x1 matrix so that the expression will encounter
+ // all the overloads of matrix_assign() and have the chance to
+ // go through any applicable optimizations.
+ matrix<type,1,1,mem_manager_type,layout_type> temp(ref());
+ return temp(0);
+ }
+
+ const_iterator begin() const { return matrix_exp_iterator<EXP>(ref(),0,0); }
+ const_iterator end() const { return matrix_exp_iterator<EXP>(ref(),nr(),0); }
+
+ protected:
+ matrix_exp() {}
+ matrix_exp(const matrix_exp& ) {}
+
+ private:
+
+ matrix_exp& operator= (const matrix_exp&);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ // something is a matrix if it is convertible to a matrix_exp object
+ template <typename T>
+ struct is_matrix<T, typename enable_if<is_convertible<T, const matrix_exp<typename T::exp_type>& > >::type >
+ { static const bool value = true; };
+ /*
+ is_matrix<T>::value == 1 if T is a matrix type else 0
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ class matrix_diag_exp : public matrix_exp<EXP>
+ {
+ /*!
+ This is a matrix expression type used to represent diagonal matrices.
+ That is, square matrices with all off diagonal elements equal to 0.
+ !*/
+
+ protected:
+ matrix_diag_exp() {}
+ matrix_diag_exp(const matrix_diag_exp& item ):matrix_exp<EXP>(item) {}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_EXP_h_
+
diff --git a/ml/dlib/dlib/matrix/matrix_exp_abstract.h b/ml/dlib/dlib/matrix/matrix_exp_abstract.h
new file mode 100644
index 000000000..14ad143c2
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_exp_abstract.h
@@ -0,0 +1,210 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_EXP_ABSTRACT_
+#ifdef DLIB_MATRIx_EXP_ABSTRACT_
+
+#include "matrix_fwd.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ class matrix_exp
+ {
+ /*!
+ REQUIREMENTS ON EXP
+ - must be an object that inherits publicly from matrix_exp (this class).
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an expression that evaluates to a matrix
+ of nr() rows and nc() columns.
+
+ The reason for having an object that represents an expression is that it
+ allows us to use the "expression templates" technique to eliminate the
+ temporary matrix objects that would normally be returned from expressions
+ such as M = A+B+C+D; Normally each invocation of the + operator would
+ construct and return a temporary matrix object but using this technique we
+ can avoid creating all of these temporary objects and receive a large
+ speed boost.
+
+ Note that every time you invoke operator() on this object it recomputes
+ its result which may not be what you want to do. For example, if you
+ are going to be accessing the same element over and over it might
+ be faster to assign the matrix_exp to a temporary matrix and then
+ use that temporary.
+
+
+ const_ret_type typedef (defined below)
+ The purpose of the const_ret_type typedef is to allow matrix expressions
+ to return their elements by reference when appropriate. So const_ret_type
+ should be one of the following types:
+ - const type
+ - const type&
+ !*/
+
+ public:
+ typedef typename EXP::type type;
+ typedef type value_type; // Redefined for compatibility with the STL
+ typedef typename EXP::const_ret_type const_ret_type;
+ typedef typename EXP::mem_manager_type mem_manager_type;
+ typedef typename EXP::layout_type layout_type;
+ const static long cost = EXP::cost;
+ const static long NR = EXP::NR;
+ const static long NC = EXP::NC;
+ typedef matrix<type,NR,NC, mem_manager_type,layout_type> matrix_type;
+ typedef EXP exp_type;
+ typedef matrix_exp_iterator<EXP> iterator;
+ typedef matrix_exp_iterator<EXP> const_iterator;
+
+ const_ret_type operator() (
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= r < nr()
+ - 0 <= c < nc()
+ ensures
+ - returns ref()(r,c)
+ (i.e. returns the value at the given row and column that would be in
+ the matrix represented by this matrix expression)
+ !*/
+
+ const_ret_type operator() (
+ long i
+ ) const;
+ /*!
+ requires
+ - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector)
+ - if (nc() == 1) then
+ - 0 <= i < nr()
+ - else
+ - 0 <= i < nc()
+ ensures
+ - if (nc() == 1) then
+ - returns (*this)(i,0)
+ - else
+ - returns (*this)(0,i)
+ !*/
+
+ operator const type (
+ ) const;
+ /*!
+ requires
+ - nr() == 1
+ - nc() == 1
+ ensures
+ - returns (*this)(0,0)
+ !*/
+
+ long nr (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this matrix expression.
+ !*/
+
+ long nc (
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this matrix expression.
+ !*/
+
+ long size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ !*/
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const;
+ /*!
+ ensures
+ - if (A change to the state of item could cause a change to the state of *this
+ matrix_exp object. ) then
+ - returns true
+ - This happens when this matrix_exp contains item in some way.
+ - else
+ - returns false
+ !*/
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const;
+ /*!
+ ensures
+ - if (aliases(item)) then
+ - if (nr() != item.nr() || nc() != item.nc()
+ - returns true
+ (i.e. if this expression has different dimensions than item then
+ we have destructive aliasing)
+
+ - returns true if the following assignment would evaluate incorrectly:
+ for (long r = 0; r < nr(); ++r)
+ for (long c = 0; c < nc(); ++c)
+ item(r,c) = (*this)(r,c)
+ - That is, if this matrix expression aliases item in such a way that a modification
+ to element item(r,c) causes a change in the value of something other than
+ (*this)(r,c) then this function returns true.
+
+ - returns false if none of the above conditions say we should return true
+ - else
+ - returns false
+ !*/
+
+ inline const exp_type& ref (
+ ) const;
+ /*!
+ ensures
+ - returns a reference to the expression contained in *this.
+ (i.e. returns *static_cast<const exp_type*>(this) )
+ !*/
+
+ const_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - returns a forward access iterator pointing to the first element in this
+ matrix expression.
+ - Since matrix_exp objects represent immutable views of a matrix, the
+ returned iterator does not allow the user to modify the matrix
+ expression's elements.
+ - The iterator will iterate over the elements of the matrix in row major
+ order.
+ !*/
+
+ const_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns a forward access iterator pointing to one past the end of the
+ last element in this matrix expression.
+ !*/
+
+ protected:
+
+ // Only derived classes of matrix_exp may call the matrix_exp constructors.
+ matrix_exp(const matrix_exp&);
+ matrix_exp();
+
+ private:
+ // no one may ever use the assignment operator on a matrix_exp
+ matrix_exp& operator= (const matrix_exp&);
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_EXP_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_expressions.h b/ml/dlib/dlib/matrix/matrix_expressions.h
new file mode 100644
index 000000000..9f057d076
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_expressions.h
@@ -0,0 +1,280 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_EXPRESSIONS_H_
+#define DLIB_MATRIx_EXPRESSIONS_H_
+
+#include "matrix_fwd.h"
+
+#ifdef _MSC_VER
+// This #pragma directive is also located in the algs.h file but for whatever
+// reason visual studio 9 just ignores it when it is only there.
+
+// this is to disable the "'this' : used in base member initializer list"
+// warning you get from some of the GUI objects since all the objects
+// require that their parent class be passed into their constructor.
+// In this case though it is totally safe so it is ok to disable this warning.
+#pragma warning(disable : 4355)
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Helper templates for making operators used by expression objects
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class matrix_range_exp;
+
+ template <typename T>
+ struct matrix_traits<matrix_range_exp<T> >
+ {
+ typedef T type;
+ typedef const T const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ const static long NR = 1;
+ const static long NC = 0;
+ const static long cost = 1;
+ };
+
+ template <typename T>
+ class matrix_range_exp : public matrix_exp<matrix_range_exp<T> >
+ {
+ public:
+ typedef typename matrix_traits<matrix_range_exp>::type type;
+ typedef typename matrix_traits<matrix_range_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_range_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_range_exp>::NR;
+ const static long NC = matrix_traits<matrix_range_exp>::NC;
+ const static long cost = matrix_traits<matrix_range_exp>::cost;
+ typedef typename matrix_traits<matrix_range_exp>::layout_type layout_type;
+
+
+ matrix_range_exp (
+ T start_,
+ T end_
+ )
+ {
+ start = start_;
+ if (start_ <= end_)
+ inc = 1;
+ else
+ inc = -1;
+ nc_ = std::abs(end_ - start_) + 1;
+ }
+ matrix_range_exp (
+ T start_,
+ T inc_,
+ T end_
+ )
+ {
+ start = start_;
+ nc_ = std::abs(end_ - start_)/inc_ + 1;
+ if (start_ <= end_)
+ inc = inc_;
+ else
+ inc = -inc_;
+ }
+
+ matrix_range_exp (
+ T start_,
+ T end_,
+ long num,
+ bool
+ )
+ {
+ start = start_;
+ nc_ = num;
+ if (num > 1)
+ {
+ inc = (end_-start_)/(num-1);
+ }
+ else
+ {
+ inc = 0;
+ start = end_;
+ }
+
+ }
+
+ const_ret_type operator() (
+ long,
+ long c
+ ) const { return start + c*inc; }
+
+ const_ret_type operator() (
+ long c
+ ) const { return start + c*inc; }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ long nr (
+ ) const { return NR; }
+
+ long nc (
+ ) const { return nc_; }
+
+ long nc_;
+ T start;
+ T inc;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class matrix_log_range_exp;
+
+ template <typename T>
+ struct matrix_traits<matrix_log_range_exp<T> >
+ {
+ typedef T type;
+ typedef const T const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ const static long NR = 1;
+ const static long NC = 0;
+ const static long cost = 1;
+ };
+
+ template <typename T>
+ class matrix_log_range_exp : public matrix_exp<matrix_log_range_exp<T> >
+ {
+ public:
+ typedef typename matrix_traits<matrix_log_range_exp>::type type;
+ typedef typename matrix_traits<matrix_log_range_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_log_range_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_log_range_exp>::NR;
+ const static long NC = matrix_traits<matrix_log_range_exp>::NC;
+ const static long cost = matrix_traits<matrix_log_range_exp>::cost;
+ typedef typename matrix_traits<matrix_log_range_exp>::layout_type layout_type;
+
+
+ matrix_log_range_exp (
+ T start_,
+ T end_,
+ long num
+ )
+ {
+ start = start_;
+ nc_ = num;
+ if (num > 1)
+ {
+ inc = (end_-start_)/(num-1);
+ }
+ else
+ {
+ inc = 0;
+ start = end_;
+ }
+
+ }
+
+ const_ret_type operator() (
+ long,
+ long c
+ ) const { return std::pow((T)10,start + c*inc); }
+
+ const_ret_type operator() (
+ long c
+ ) const { return std::pow((T)10,start + c*inc); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ long nr (
+ ) const { return NR; }
+
+ long nc (
+ ) const { return nc_; }
+
+ long nc_;
+ T start;
+ T inc;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <long start, long inc_, long end>
+ class matrix_range_static_exp;
+
+ template <long start, long inc_, long end>
+ struct matrix_traits<matrix_range_static_exp<start,inc_,end> >
+ {
+ typedef long type;
+ typedef const long const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ const static long NR = 1;
+ const static long NC = tabs<(end - start)>::value/inc_ + 1;
+ const static long cost = 1;
+ typedef row_major_layout layout_type;
+ };
+
+ template <long start, long inc_, long end_>
+ class matrix_range_static_exp : public matrix_exp<matrix_range_static_exp<start,inc_,end_> >
+ {
+ public:
+ typedef typename matrix_traits<matrix_range_static_exp>::type type;
+ typedef typename matrix_traits<matrix_range_static_exp>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_range_static_exp>::mem_manager_type mem_manager_type;
+ const static long NR = matrix_traits<matrix_range_static_exp>::NR;
+ const static long NC = matrix_traits<matrix_range_static_exp>::NC;
+ const static long cost = matrix_traits<matrix_range_static_exp>::cost;
+ typedef typename matrix_traits<matrix_range_static_exp>::layout_type layout_type;
+
+ const static long inc = (start <= end_)?inc_:-inc_;
+
+
+ matrix_range_static_exp (
+ ) {}
+
+ const_ret_type operator() (
+ long ,
+ long c
+ ) const { return start + c*inc; }
+
+ const_ret_type operator() (
+ long c
+ ) const { return start + c*inc; }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>&
+ ) const { return false; }
+
+ long nr (
+ ) const { return NR; }
+
+ long nc (
+ ) const { return NC; }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_EXPRESSIONS_H_
+
diff --git a/ml/dlib/dlib/matrix/matrix_fft.h b/ml/dlib/dlib/matrix/matrix_fft.h
new file mode 100644
index 000000000..fbca6d344
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_fft.h
@@ -0,0 +1,846 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FFt_Hh_
+#define DLIB_FFt_Hh_
+
+#include "matrix_fft_abstract.h"
+#include "matrix_utilities.h"
+#include "../hash.h"
+#include "../algs.h"
+
+#ifdef DLIB_USE_MKL_FFT
+#include <mkl_dfti.h>
+#endif
+
+// No using FFTW until it becomes thread safe!
+#if 0
+#ifdef DLIB_USE_FFTW
+#include <fftw3.h>
+#endif // DLIB_USE_FFTW
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool is_power_of_two (
+ const unsigned long& value
+ )
+ {
+ if (value == 0)
+ return true;
+ else
+ return count_bits(value) == 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ /*
+ The next few functions related to doing FFTs are derived from Stefan
+ Gustavson's (stegu@itn.liu.se) public domain 2D Fourier transformation code.
+ The code has a long history, originally a FORTRAN implementation published in:
+ Programming for Digital Signal Processing, IEEE Press 1979, Section 1, by G. D.
+ Bergland and M. T. Dolan. In 2003 it was cleaned up and turned into modern C
+ by Steven Gustavson. Davis King then rewrote it in modern C++ in 2014 and also
+ changed the transform so that the outputs are identical to those given from FFTW.
+ */
+
+ // ------------------------------------------------------------------------------------
+
+ /* Get binary log of integer argument - exact if n is a power of 2 */
+ inline long fastlog2(long n)
+ {
+ long log = -1;
+ while(n) {
+ log++;
+ n >>= 1;
+ }
+ return log ;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ /* Radix-2 iteration subroutine */
+ template <typename T>
+ void R2TX(int nthpo, std::complex<T> *c0, std::complex<T> *c1)
+ {
+ for(int k=0; k<nthpo; k+=2)
+ {
+ std::complex<T> temp = c0[k] + c1[k];
+ c1[k] = c0[k] - c1[k];
+ c0[k] = temp;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ /* Radix-4 iteration subroutine */
+ template <typename T>
+ void R4TX(int nthpo, std::complex<T> *c0, std::complex<T> *c1,
+ std::complex<T> *c2, std::complex<T> *c3)
+ {
+ for(int k=0;k<nthpo;k+=4)
+ {
+ std::complex<T> t1, t2, t3, t4;
+ t1 = c0[k] + c2[k];
+ t2 = c0[k] - c2[k];
+ t3 = c1[k] + c3[k];
+ t4 = c1[k] - c3[k];
+
+ c0[k] = t1 + t3;
+ c1[k] = t1 - t3;
+ c2[k] = std::complex<T>(t2.real()-t4.imag(), t2.imag()+t4.real());
+ c3[k] = std::complex<T>(t2.real()+t4.imag(), t2.imag()-t4.real());
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ class twiddles
+ {
+ /*!
+ The point of this object is to cache the twiddle values so we don't
+ recompute them over and over inside R8TX().
+ !*/
+ public:
+
+ twiddles()
+ {
+ data.resize(64);
+ }
+
+ const std::complex<T>* get_twiddles (
+ int p
+ )
+ /*!
+ requires
+ - 0 <= p <= 64
+ ensures
+ - returns a pointer to the twiddle factors needed by R8TX if nxtlt == 2^p
+ !*/
+ {
+ // Compute the twiddle factors for this p value if we haven't done so
+ // already.
+ if (data[p].size() == 0)
+ {
+ const int nxtlt = 0x1 << p;
+ data[p].reserve(nxtlt*7);
+ const T twopi = 6.2831853071795865; /* 2.0 * pi */
+ const T scale = twopi/(nxtlt*8.0);
+ std::complex<T> cs[7];
+ for (int j = 0; j < nxtlt; ++j)
+ {
+ const T arg = j*scale;
+ cs[0] = std::complex<T>(std::cos(arg),std::sin(arg));
+ cs[1] = cs[0]*cs[0];
+ cs[2] = cs[1]*cs[0];
+ cs[3] = cs[1]*cs[1];
+ cs[4] = cs[2]*cs[1];
+ cs[5] = cs[2]*cs[2];
+ cs[6] = cs[3]*cs[2];
+ data[p].insert(data[p].end(), cs, cs+7);
+ }
+ }
+
+ return &data[p][0];
+ }
+
+ private:
+ std::vector<std::vector<std::complex<T> > > data;
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ /* Radix-8 iteration subroutine */
+ template <typename T>
+ void R8TX(int nxtlt, int nthpo, int length, const std::complex<T>* cs,
+ std::complex<T> *cc0, std::complex<T> *cc1, std::complex<T> *cc2, std::complex<T> *cc3,
+ std::complex<T> *cc4, std::complex<T> *cc5, std::complex<T> *cc6, std::complex<T> *cc7)
+ {
+ const T irt2 = 0.707106781186548; /* 1.0/sqrt(2.0) */
+
+ for(int j=0; j<nxtlt; j++)
+ {
+ for(int k=j;k<nthpo;k+=length)
+ {
+ std::complex<T> a0, a1, a2, a3, a4, a5, a6, a7;
+ std::complex<T> b0, b1, b2, b3, b4, b5, b6, b7;
+ a0 = cc0[k] + cc4[k];
+ a1 = cc1[k] + cc5[k];
+ a2 = cc2[k] + cc6[k];
+ a3 = cc3[k] + cc7[k];
+ a4 = cc0[k] - cc4[k];
+ a5 = cc1[k] - cc5[k];
+ a6 = cc2[k] - cc6[k];
+ a7 = cc3[k] - cc7[k];
+
+ b0 = a0 + a2;
+ b1 = a1 + a3;
+ b2 = a0 - a2;
+ b3 = a1 - a3;
+
+ b4 = std::complex<T>(a4.real()-a6.imag(), a4.imag()+a6.real());
+ b5 = std::complex<T>(a5.real()-a7.imag(), a5.imag()+a7.real());
+ b6 = std::complex<T>(a4.real()+a6.imag(), a4.imag()-a6.real());
+ b7 = std::complex<T>(a5.real()+a7.imag(), a5.imag()-a7.real());
+
+ const std::complex<T> tmp0(-b3.imag(), b3.real());
+ const std::complex<T> tmp1(irt2*(b5.real()-b5.imag()), irt2*(b5.real()+b5.imag()));
+ const std::complex<T> tmp2(-irt2*(b7.real()+b7.imag()), irt2*(b7.real()-b7.imag()));
+
+ cc0[k] = b0 + b1;
+ cc1[k] = b0 - b1;
+ cc2[k] = b2 + tmp0;
+ cc3[k] = b2 - tmp0;
+ cc4[k] = b4 + tmp1;
+ cc5[k] = b4 - tmp1;
+ cc6[k] = b6 + tmp2;
+ cc7[k] = b6 - tmp2;
+ if(j>0)
+ {
+ cc1[k] *= cs[3];
+ cc2[k] *= cs[1];
+ cc3[k] *= cs[5];
+ cc4[k] *= cs[0];
+ cc5[k] *= cs[4];
+ cc6[k] *= cs[2];
+ cc7[k] *= cs[6];
+ }
+ }
+
+ cs += 7;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename layout>
+ void fft1d_inplace(matrix<std::complex<T>,NR,NC,MM,layout>& data, bool do_backward_fft, twiddles<T>& cs)
+ /*!
+ requires
+ - is_vector(data) == true
+ - is_power_of_two(data.size()) == true
+ ensures
+ - This routine replaces the input std::complex<double> vector by its finite
+ discrete complex fourier transform if do_backward_fft==true. It replaces
+ the input std::complex<double> vector by its finite discrete complex
+ inverse fourier transform if do_backward_fft==false.
+
+ The implementation is a radix-2 FFT, but with faster shortcuts for
+ radix-4 and radix-8. It performs as many radix-8 iterations as possible,
+ and then finishes with a radix-2 or -4 iteration if needed.
+ !*/
+ {
+ COMPILE_TIME_ASSERT((is_same_type<double,T>::value || is_same_type<float,T>::value || is_same_type<long double,T>::value ));
+
+ if (data.size() == 0)
+ return;
+
+ std::complex<T>* const b = &data(0);
+ int L[16],L1,L2,L3,L4,L5,L6,L7,L8,L9,L10,L11,L12,L13,L14,L15;
+ int j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14;
+ int j, ij, ji;
+ int n2pow, n8pow, nthpo, ipass, nxtlt, length;
+
+ n2pow = fastlog2(data.size());
+ nthpo = data.size();
+
+ n8pow = n2pow/3;
+
+ if(n8pow)
+ {
+ /* Radix 8 iterations */
+ for(ipass=1;ipass<=n8pow;ipass++)
+ {
+ const int p = n2pow - 3*ipass;
+ nxtlt = 0x1 << p;
+ length = 8*nxtlt;
+ R8TX(nxtlt, nthpo, length, cs.get_twiddles(p),
+ b, b+nxtlt, b+2*nxtlt, b+3*nxtlt,
+ b+4*nxtlt, b+5*nxtlt, b+6*nxtlt, b+7*nxtlt);
+ }
+ }
+
+ if(n2pow%3 == 1)
+ {
+ /* A final radix 2 iteration is needed */
+ R2TX(nthpo, b, b+1);
+ }
+
+ if(n2pow%3 == 2)
+ {
+ /* A final radix 4 iteration is needed */
+ R4TX(nthpo, b, b+1, b+2, b+3);
+ }
+
+ for(j=1;j<=15;j++)
+ {
+ L[j] = 1;
+ if(j-n2pow <= 0) L[j] = 0x1 << (n2pow + 1 - j);
+ }
+
+ L15=L[1];L14=L[2];L13=L[3];L12=L[4];L11=L[5];L10=L[6];L9=L[7];
+ L8=L[8];L7=L[9];L6=L[10];L5=L[11];L4=L[12];L3=L[13];L2=L[14];L1=L[15];
+
+ ij = 0;
+
+ for(j1=0;j1<L1;j1++)
+ for(j2=j1;j2<L2;j2+=L1)
+ for(j3=j2;j3<L3;j3+=L2)
+ for(j4=j3;j4<L4;j4+=L3)
+ for(j5=j4;j5<L5;j5+=L4)
+ for(j6=j5;j6<L6;j6+=L5)
+ for(j7=j6;j7<L7;j7+=L6)
+ for(j8=j7;j8<L8;j8+=L7)
+ for(j9=j8;j9<L9;j9+=L8)
+ for(j10=j9;j10<L10;j10+=L9)
+ for(j11=j10;j11<L11;j11+=L10)
+ for(j12=j11;j12<L12;j12+=L11)
+ for(j13=j12;j13<L13;j13+=L12)
+ for(j14=j13;j14<L14;j14+=L13)
+ for(ji=j14;ji<L15;ji+=L14)
+ {
+ if(ij<ji)
+ swap(b[ij], b[ji]);
+ ij++;
+ }
+
+
+ // unscramble outputs
+ if(!do_backward_fft)
+ {
+ for(long i=1, j=data.size()-1; i<data.size()/2; i++,j--)
+ {
+ swap(b[j], b[i]);
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template < typename T, long NR, long NC, typename MM, typename L >
+ void fft2d_inplace(
+ matrix<std::complex<T>,NR,NC,MM,L>& data,
+ bool do_backward_fft
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ matrix<std::complex<double> > buff;
+ twiddles<double> cs;
+
+ // Compute transform row by row
+ for(long r=0; r<data.nr(); ++r)
+ {
+ buff = matrix_cast<std::complex<double> >(rowm(data,r));
+ fft1d_inplace(buff, do_backward_fft, cs);
+ set_rowm(data,r) = matrix_cast<std::complex<T> >(buff);
+ }
+
+ // Compute transform column by column
+ for(long c=0; c<data.nc(); ++c)
+ {
+ buff = matrix_cast<std::complex<double> >(colm(data,c));
+ fft1d_inplace(buff, do_backward_fft, cs);
+ set_colm(data,c) = matrix_cast<std::complex<T> >(buff);
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T
+ >
+ void fft2d(
+ const matrix_exp<EXP>& data,
+ matrix<std::complex<T> >& data_out,
+ bool do_backward_fft
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix fft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.size() == 0)
+ return;
+
+ matrix<std::complex<double> > buff;
+ data_out.set_size(data.nr(), data.nc());
+ twiddles<double> cs;
+
+ // Compute transform row by row
+ for(long r=0; r<data.nr(); ++r)
+ {
+ buff = matrix_cast<std::complex<double> >(rowm(data,r));
+ fft1d_inplace(buff, do_backward_fft, cs);
+ set_rowm(data_out,r) = matrix_cast<std::complex<T> >(buff);
+ }
+
+ // Compute transform column by column
+ for(long c=0; c<data_out.nc(); ++c)
+ {
+ buff = matrix_cast<std::complex<double> >(colm(data_out,c));
+ fft1d_inplace(buff, do_backward_fft, cs);
+ set_colm(data_out,c) = matrix_cast<std::complex<T> >(buff);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ } // end namespace impl
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ matrix<typename EXP::type> fft (const matrix_exp<EXP>& data)
+ {
+ // You have to give a complex matrix
+ COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix fft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.nr() == 1 || data.nc() == 1)
+ {
+ matrix<typename EXP::type> temp(data);
+ impl::twiddles<typename EXP::type::value_type> cs;
+ impl::fft1d_inplace(temp, false, cs);
+ return temp;
+ }
+ else
+ {
+ matrix<typename EXP::type> temp;
+ impl::fft2d(data, temp, false);
+ return temp;
+ }
+ }
+
+ template <typename EXP>
+ matrix<typename EXP::type> ifft (const matrix_exp<EXP>& data)
+ {
+ // You have to give a complex matrix
+ COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix ifft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ matrix<typename EXP::type> temp;
+ if (data.size() == 0)
+ return temp;
+
+ if (data.nr() == 1 || data.nc() == 1)
+ {
+ temp = data;
+ impl::twiddles<typename EXP::type::value_type> cs;
+ impl::fft1d_inplace(temp, true, cs);
+ }
+ else
+ {
+ impl::fft2d(data, temp, true);
+ }
+ temp /= data.size();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template < typename T, long NR, long NC, typename MM, typename L >
+ typename enable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
+ // Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t void fft_inplace(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ impl::twiddles<T> cs;
+ impl::fft1d_inplace(data, false, cs);
+ }
+
+ template < typename T, long NR, long NC, typename MM, typename L >
+ typename disable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
+ // Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t void fft_inplace(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ impl::fft2d_inplace(data, false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template < typename T, long NR, long NC, typename MM, typename L >
+ typename enable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t void ifft_inplace(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ impl::twiddles<T> cs;
+ impl::fft1d_inplace(data, true, cs);
+ }
+
+ template < typename T, long NR, long NC, typename MM, typename L >
+ typename disable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t void ifft_inplace(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ impl::fft2d_inplace(data, true);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ I'm disabling any use of the FFTW bindings because FFTW is, as of this writing, not
+ threadsafe as a library. This means that if multiple threads were to make
+ concurrent calls to these fft routines then the program could crash. If at some
+ point FFTW is fixed I'll turn these bindings back on.
+
+ See https://github.com/FFTW/fftw3/issues/16
+ */
+#if 0
+#ifdef DLIB_USE_FFTW
+
+ template <long NR, long NC, typename MM, typename L>
+ matrix<std::complex<double>,NR,NC,MM,L> call_fftw_fft(
+ const matrix<std::complex<double>,NR,NC,MM,L>& data
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix fft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.size() == 0)
+ return data;
+
+ matrix<std::complex<double>,NR,NC,MM,L> m2(data.nr(),data.nc());
+ fftw_complex *in, *out;
+ fftw_plan p;
+ in = (fftw_complex*)&data(0,0);
+ out = (fftw_complex*)&m2(0,0);
+ if (data.nr() == 1 || data.nc() == 1)
+ p = fftw_plan_dft_1d(data.size(), in, out, FFTW_FORWARD, FFTW_ESTIMATE);
+ else
+ p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_FORWARD, FFTW_ESTIMATE);
+ fftw_execute(p);
+ fftw_destroy_plan(p);
+ return m2;
+ }
+
+ template <long NR, long NC, typename MM, typename L>
+ matrix<std::complex<double>,NR,NC,MM,L> call_fftw_ifft(
+ const matrix<std::complex<double>,NR,NC,MM,L>& data
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix ifft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.size() == 0)
+ return data;
+
+ matrix<std::complex<double>,NR,NC,MM,L> m2(data.nr(),data.nc());
+ fftw_complex *in, *out;
+ fftw_plan p;
+ in = (fftw_complex*)&data(0,0);
+ out = (fftw_complex*)&m2(0,0);
+ if (data.nr() == 1 || data.nc() == 1)
+ p = fftw_plan_dft_1d(data.size(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
+ else
+ p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
+ fftw_execute(p);
+ fftw_destroy_plan(p);
+ return m2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// call FFTW for these cases:
+ inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data) {return call_fftw_fft(data);}
+ inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data) {return call_fftw_ifft(data)/data.size();}
+ inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data) {return call_fftw_fft(data);}
+ inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data) {return call_fftw_ifft(data)/data.size();}
+ inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data) {return call_fftw_fft(data);}
+ inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data) {return call_fftw_ifft(data)/data.size();}
+
+ inline void fft_inplace (matrix<std::complex<double>,0,1>& data) {data = call_fftw_fft(data);}
+ inline void ifft_inplace(matrix<std::complex<double>,0,1>& data) {data = call_fftw_ifft(data);}
+ inline void fft_inplace (matrix<std::complex<double>,1,0>& data) {data = call_fftw_fft(data);}
+ inline void ifft_inplace(matrix<std::complex<double>,1,0>& data) {data = call_fftw_ifft(data);}
+ inline void fft_inplace (matrix<std::complex<double> >& data) {data = call_fftw_fft(data);}
+ inline void ifft_inplace(matrix<std::complex<double> >& data) {data = call_fftw_ifft(data);}
+
+#endif // DLIB_USE_FFTW
+#endif // end of #if 0
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef DLIB_USE_MKL_FFT
+
+#define DLIB_DFTI_CHECK_STATUS(s) \
+ if((s) != 0 && !DftiErrorClass((s), DFTI_NO_ERROR)) \
+ { \
+ throw dlib::error(DftiErrorMessage((s))); \
+ }
+
+ template < long NR, long NC, typename MM, typename L >
+ matrix<std::complex<double>,NR,NC,MM,L> call_mkl_fft(
+ const matrix<std::complex<double>,NR,NC,MM,L>& data,
+ bool do_backward_fft)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t matrix fft(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.size() == 0)
+ return data;
+
+ DFTI_DESCRIPTOR_HANDLE h;
+ MKL_LONG status;
+
+ if (data.nr() == 1 || data.nc() == 1)
+ {
+ status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size());
+ DLIB_DFTI_CHECK_STATUS(status);
+ }
+ else
+ {
+ MKL_LONG size[2];
+ size[0] = data.nr();
+ size[1] = data.nc();
+
+ status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ MKL_LONG strides[3];
+ strides[0] = 0;
+ strides[1] = size[1];
+ strides[2] = 1;
+
+ status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
+ DLIB_DFTI_CHECK_STATUS(status);
+ status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
+ DLIB_DFTI_CHECK_STATUS(status);
+ }
+
+ status = DftiSetValue(h, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ // Unless we use sequential mode, the fft results are not correct.
+ status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ status = DftiCommitDescriptor(h);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ matrix<std::complex<double>,NR,NC,MM,L> out(data.nr(), data.nc());
+
+ if (do_backward_fft)
+ status = DftiComputeBackward(h, (void *)(&data(0, 0)), &out(0,0));
+ else
+ status = DftiComputeForward(h, (void *)(&data(0, 0)), &out(0,0));
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ status = DftiFreeDescriptor(&h);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ return out;
+ }
+
+ template < long NR, long NC, typename MM, typename L >
+ void call_mkl_fft_inplace(
+ matrix<std::complex<double>,NR,NC,MM,L>& data,
+ bool do_backward_fft
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
+ "\t void ifft_inplace(data)"
+ << "\n\t The number of rows and columns must be powers of two."
+ << "\n\t data.nr(): "<< data.nr()
+ << "\n\t data.nc(): "<< data.nc()
+ << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
+ << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
+ );
+
+ if (data.size() == 0)
+ return;
+
+ DFTI_DESCRIPTOR_HANDLE h;
+ MKL_LONG status;
+
+ if (data.nr() == 1 || data.nc() == 1)
+ {
+ status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size());
+ DLIB_DFTI_CHECK_STATUS(status);
+ }
+ else
+ {
+ MKL_LONG size[2];
+ size[0] = data.nr();
+ size[1] = data.nc();
+
+ status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ MKL_LONG strides[3];
+ strides[0] = 0;
+ strides[1] = size[1];
+ strides[2] = 1;
+
+ status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
+ DLIB_DFTI_CHECK_STATUS(status);
+ }
+
+ // Unless we use sequential mode, the fft results are not correct.
+ status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ status = DftiCommitDescriptor(h);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ if (do_backward_fft)
+ status = DftiComputeBackward(h, &data(0, 0));
+ else
+ status = DftiComputeForward(h, &data(0, 0));
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ status = DftiFreeDescriptor(&h);
+ DLIB_DFTI_CHECK_STATUS(status);
+
+ return;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Call the MKL DFTI implementation in these cases
+
+ inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data)
+ {
+ return call_mkl_fft(data, false);
+ }
+ inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data)
+ {
+ return call_mkl_fft(data, true) / data.size();
+ }
+ inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data)
+ {
+ return call_mkl_fft(data, false);
+ }
+ inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data)
+ {
+ return call_mkl_fft(data, true) / data.size();
+ }
+ inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data)
+ {
+ return call_mkl_fft(data, false);
+ }
+ inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data)
+ {
+ return call_mkl_fft(data, true) / data.size();
+ }
+
+ inline void fft_inplace (matrix<std::complex<double>,0,1>& data)
+ {
+ call_mkl_fft_inplace(data, false);
+ }
+ inline void ifft_inplace(matrix<std::complex<double>,0,1>& data)
+ {
+ call_mkl_fft_inplace(data, true);
+ }
+ inline void fft_inplace (matrix<std::complex<double>,1,0>& data)
+ {
+ call_mkl_fft_inplace(data, false);
+ }
+ inline void ifft_inplace(matrix<std::complex<double>,1,0>& data)
+ {
+ call_mkl_fft_inplace(data, true);
+ }
+
+ inline void fft_inplace (matrix<std::complex<double> >& data)
+ {
+ call_mkl_fft_inplace(data, false);
+ }
+ inline void ifft_inplace(matrix<std::complex<double> >& data)
+ {
+ call_mkl_fft_inplace(data, true);
+ }
+
+#endif // DLIB_USE_MKL_FFT
+
+// ----------------------------------------------------------------------------------------
+}
+
+#endif // DLIB_FFt_Hh_
+
diff --git a/ml/dlib/dlib/matrix/matrix_fft_abstract.h b/ml/dlib/dlib/matrix/matrix_fft_abstract.h
new file mode 100644
index 000000000..25cdfcaee
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_fft_abstract.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FFt_ABSTRACT_Hh_
+#ifdef DLIB_FFt_ABSTRACT_Hh_
+
+#include "matrix_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_power_of_two (
+ const unsigned long& value
+ );
+ /*!
+ ensures
+ - returns true if value contains a power of two and false otherwise. As a
+ special case, we also consider 0 to be a power of two.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ typename EXP::matrix_type fft (
+ const matrix_exp<EXP>& data
+ );
+ /*!
+ requires
+ - data contains elements of type std::complex<> that itself contains double, float, or long double.
+ - is_power_of_two(data.nr()) == true
+ - is_power_of_two(data.nc()) == true
+ ensures
+ - Computes the 1 or 2 dimensional discrete Fourier transform of the given data
+ matrix and returns it. In particular, we return a matrix D such that:
+ - D.nr() == data.nr()
+ - D.nc() == data.nc()
+ - D(0,0) == the DC term of the Fourier transform.
+ - starting with D(0,0), D contains progressively higher frequency components
+ of the input data.
+ - ifft(D) == D
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ typename EXP::matrix_type ifft (
+ const matrix_exp<EXP>& data
+ );
+ /*!
+ requires
+ - data contains elements of type std::complex<> that itself contains double, float, or long double.
+ - is_power_of_two(data.nr()) == true
+ - is_power_of_two(data.nc()) == true
+ ensures
+ - Computes the 1 or 2 dimensional inverse discrete Fourier transform of the
+ given data vector and returns it. In particular, we return a matrix D such
+ that:
+ - D.nr() == data.nr()
+ - D.nc() == data.nc()
+ - fft(D) == data
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void fft_inplace (
+ matrix<std::complex<T>,NR,NC,MM,L>& data
+ );
+ /*!
+ requires
+ - data contains elements of type std::complex<> that itself contains double, float, or long double.
+ - is_power_of_two(data.nr()) == true
+ - is_power_of_two(data.nc()) == true
+ ensures
+ - This function is identical to fft() except that it does the FFT in-place.
+ That is, after this function executes we will have:
+ - #data == fft(data)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void ifft_inplace (
+ matrix<std::complex<T>,NR,NC,MM,L>& data
+ );
+ /*!
+ requires
+ - data contains elements of type std::complex<> that itself contains double, float, or long double.
+ - is_power_of_two(data.nr()) == true
+ - is_power_of_two(data.nc()) == true
+ ensures
+ - This function is identical to ifft() except that it does the inverse FFT
+ in-place. That is, after this function executes we will have:
+ - #data == ifft(data)*data.size()
+ - Note that the output needs to be divided by data.size() to complete the
+ inverse transformation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FFt_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/matrix/matrix_fwd.h b/ml/dlib/dlib/matrix/matrix_fwd.h
new file mode 100644
index 000000000..1f40a17a8
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_fwd.h
@@ -0,0 +1,31 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_FWD
+#define DLIB_MATRIx_FWD
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct row_major_layout;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long num_rows = 0,
+ long num_cols = 0,
+ typename mem_manager = default_memory_manager,
+ typename layout = row_major_layout
+ >
+ class matrix;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_FWD
+
diff --git a/ml/dlib/dlib/matrix/matrix_generic_image.h b/ml/dlib/dlib/matrix/matrix_generic_image.h
new file mode 100644
index 000000000..0455af205
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_generic_image.h
@@ -0,0 +1,110 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIX_GENERIC_iMAGE_Hh_
+#define DLIB_MATRIX_GENERIC_iMAGE_Hh_
+
+#include "matrix.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ struct image_traits<matrix<T,NR,NC,MM> >
+ {
+ typedef T pixel_type;
+ };
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ struct image_traits<const matrix<T,NR,NC,MM> >
+ {
+ typedef T pixel_type;
+ };
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline long num_rows( const matrix<T,NR,NC,MM>& img) { return img.nr(); }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline long num_columns( const matrix<T,NR,NC,MM>& img) { return img.nc(); }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline void set_image_size(
+ matrix<T,NR,NC,MM>& img,
+ long rows,
+ long cols
+ ) { img.set_size(rows,cols); }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline void* image_data(
+ matrix<T,NR,NC,MM>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img(0,0);
+ else
+ return 0;
+ }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline const void* image_data(
+ const matrix<T,NR,NC,MM>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img(0,0);
+ else
+ return 0;
+ }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ inline long width_step(
+ const matrix<T,NR,NC,MM>& img
+ )
+ {
+ return img.nc()*sizeof(T);
+ }
+
+}
+
+#endif // DLIB_MATRIX_GENERIC_iMAGE_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_la.h b/ml/dlib/dlib/matrix/matrix_la.h
new file mode 100644
index 000000000..35b5b42e2
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_la.h
@@ -0,0 +1,1807 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_LA_FUNCTS_
+#define DLIB_MATRIx_LA_FUNCTS_
+
+#include "matrix_la_abstract.h"
+#include "matrix_utilities.h"
+#include "../sparse_vector.h"
+#include "../optimization/optimization_line_search.h"
+
+// The 4 decomposition objects described in the matrix_la_abstract.h file are
+// actually implemented in the following 4 files.
+#include "matrix_lu.h"
+#include "matrix_qr.h"
+#include "matrix_cholesky.h"
+#include "matrix_eigenvalue.h"
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/potrf.h"
+#include "lapack/pbtrf.h"
+#include "lapack/gesdd.h"
+#include "lapack/gesvd.h"
+#endif
+
+#include "../threads.h"
+
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ enum svd_u_mode
+ {
+ SVD_NO_U,
+ SVD_SKINNY_U,
+ SVD_FULL_U
+ };
+
+ template <
+ typename EXP,
+ long qN, long qX,
+ long uM, long uN,
+ long vM, long vN,
+ typename MM1,
+ typename MM2,
+ typename MM3,
+ typename L1
+ >
+ long svd4 (
+ svd_u_mode u_mode,
+ bool withv,
+ const matrix_exp<EXP>& a,
+ matrix<typename EXP::type,uM,uN,MM1,L1>& u,
+ matrix<typename EXP::type,qN,qX,MM2,L1>& q,
+ matrix<typename EXP::type,vM,vN,MM3,L1>& v
+ )
+ {
+ /*
+ Singular value decomposition. Translated to 'C' from the
+ original Algol code in "Handbook for Automatic Computation,
+ vol. II, Linear Algebra", Springer-Verlag. Note that this
+ published algorithm is considered to be the best and numerically
+ stable approach to computing the real-valued svd and is referenced
+ repeatedly in ieee journal papers, etc where the svd is used.
+
+ This is almost an exact translation from the original, except that
+ an iteration counter is added to prevent stalls. This corresponds
+ to similar changes in other translations.
+
+ Returns an error code = 0, if no errors and 'k' if a failure to
+ converge at the 'kth' singular value.
+
+ USAGE: given the singular value decomposition a = u * diagm(q) * trans(v) for an m*n
+ matrix a with m >= n ...
+ After the svd call u is an m x m matrix which is columnwise
+ orthogonal. q will be an n element vector consisting of singular values
+ and v an n x n orthogonal matrix. eps and tol are tolerance constants.
+ Suitable values are eps=1e-16 and tol=(1e-300)/eps if T == double.
+
+ If u_mode == SVD_NO_U then u won't be computed and similarly if withv == false
+ then v won't be computed. If u_mode == SVD_SKINNY_U then u will be m x n instead of m x m.
+ */
+
+
+ DLIB_ASSERT(a.nr() >= a.nc(),
+ "\tconst matrix_exp svd4()"
+ << "\n\tYou have given an invalidly sized matrix"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ );
+
+
+ typedef typename EXP::type T;
+
+#ifdef DLIB_USE_LAPACK
+ matrix<typename EXP::type,0,0,MM1,L1> temp(a), vtemp;
+
+ char jobu = 'A';
+ char jobvt = 'A';
+ if (u_mode == SVD_NO_U)
+ jobu = 'N';
+ else if (u_mode == SVD_SKINNY_U)
+ jobu = 'S';
+ if (withv == false)
+ jobvt = 'N';
+
+ int info;
+ if (jobu == jobvt)
+ {
+ info = lapack::gesdd(jobu, temp, q, u, vtemp);
+ }
+ else
+ {
+ info = lapack::gesvd(jobu, jobvt, temp, q, u, vtemp);
+ }
+
+ // pad q with zeros if it isn't the length we want
+ if (q.nr() < a.nc())
+ q = join_cols(q, zeros_matrix<T>(a.nc()-q.nr(),1));
+
+ if (withv)
+ v = trans(vtemp);
+
+ return info;
+#else
+ using std::abs;
+ using std::sqrt;
+
+ T eps = std::numeric_limits<T>::epsilon();
+ T tol = std::numeric_limits<T>::min()/eps;
+
+ const long m = a.nr();
+ const long n = a.nc();
+ long i, j, k, l = 0, l1, iter, retval;
+ T c, f, g, h, s, x, y, z;
+
+ matrix<T,qN,1,MM2> e(n,1);
+ q.set_size(n,1);
+ if (u_mode == SVD_FULL_U)
+ u.set_size(m,m);
+ else
+ u.set_size(m,n);
+ retval = 0;
+
+ if (withv)
+ {
+ v.set_size(n,n);
+ }
+
+ /* Copy 'a' to 'u' */
+ for (i=0; i<m; i++)
+ {
+ for (j=0; j<n; j++)
+ u(i,j) = a(i,j);
+ }
+
+ /* Householder's reduction to bidiagonal form. */
+ g = x = 0.0;
+ for (i=0; i<n; i++)
+ {
+ e(i) = g;
+ s = 0.0;
+ l = i + 1;
+
+ for (j=i; j<m; j++)
+ s += (u(j,i) * u(j,i));
+
+ if (s < tol)
+ g = 0.0;
+ else
+ {
+ f = u(i,i);
+ g = (f < 0) ? sqrt(s) : -sqrt(s);
+ h = f * g - s;
+ u(i,i) = f - g;
+
+ for (j=l; j<n; j++)
+ {
+ s = 0.0;
+
+ for (k=i; k<m; k++)
+ s += (u(k,i) * u(k,j));
+
+ f = s / h;
+
+ for (k=i; k<m; k++)
+ u(k,j) += (f * u(k,i));
+ } /* end j */
+ } /* end s */
+
+ q(i) = g;
+ s = 0.0;
+
+ for (j=l; j<n; j++)
+ s += (u(i,j) * u(i,j));
+
+ if (s < tol)
+ g = 0.0;
+ else
+ {
+ f = u(i,i+1);
+ g = (f < 0) ? sqrt(s) : -sqrt(s);
+ h = f * g - s;
+ u(i,i+1) = f - g;
+
+ for (j=l; j<n; j++)
+ e(j) = u(i,j) / h;
+
+ for (j=l; j<m; j++)
+ {
+ s = 0.0;
+
+ for (k=l; k<n; k++)
+ s += (u(j,k) * u(i,k));
+
+ for (k=l; k<n; k++)
+ u(j,k) += (s * e(k));
+ } /* end j */
+ } /* end s */
+
+ y = abs(q(i)) + abs(e(i));
+ if (y > x)
+ x = y;
+ } /* end i */
+
+ /* accumulation of right-hand transformations */
+ if (withv)
+ {
+ for (i=n-1; i>=0; i--)
+ {
+ if (g != 0.0)
+ {
+ h = u(i,i+1) * g;
+
+ for (j=l; j<n; j++)
+ v(j,i) = u(i,j)/h;
+
+ for (j=l; j<n; j++)
+ {
+ s = 0.0;
+
+ for (k=l; k<n; k++)
+ s += (u(i,k) * v(k,j));
+
+ for (k=l; k<n; k++)
+ v(k,j) += (s * v(k,i));
+ } /* end j */
+ } /* end g */
+
+ for (j=l; j<n; j++)
+ v(i,j) = v(j,i) = 0.0;
+
+ v(i,i) = 1.0;
+ g = e(i);
+ l = i;
+ } /* end i */
+ } /* end withv, parens added for clarity */
+
+ /* accumulation of left-hand transformations */
+ if (u_mode != SVD_NO_U)
+ {
+ for (i=n; i<u.nr(); i++)
+ {
+ for (j=n;j<u.nc();j++)
+ u(i,j) = 0.0;
+
+ if (i < u.nc())
+ u(i,i) = 1.0;
+ }
+ }
+
+ if (u_mode != SVD_NO_U)
+ {
+ for (i=n-1; i>=0; i--)
+ {
+ l = i + 1;
+ g = q(i);
+
+ for (j=l; j<u.nc(); j++)
+ u(i,j) = 0.0;
+
+ if (g != 0.0)
+ {
+ h = u(i,i) * g;
+
+ for (j=l; j<u.nc(); j++)
+ {
+ s = 0.0;
+
+ for (k=l; k<m; k++)
+ s += (u(k,i) * u(k,j));
+
+ f = s / h;
+
+ for (k=i; k<m; k++)
+ u(k,j) += (f * u(k,i));
+ } /* end j */
+
+ for (j=i; j<m; j++)
+ u(j,i) /= g;
+ } /* end g */
+ else
+ {
+ for (j=i; j<m; j++)
+ u(j,i) = 0.0;
+ }
+
+ u(i,i) += 1.0;
+ } /* end i*/
+ }
+
+ /* diagonalization of the bidiagonal form */
+ eps *= x;
+
+ for (k=n-1; k>=0; k--)
+ {
+ iter = 0;
+
+test_f_splitting:
+
+ for (l=k; l>=0; l--)
+ {
+ if (abs(e(l)) <= eps)
+ goto test_f_convergence;
+
+ if (abs(q(l-1)) <= eps)
+ goto cancellation;
+ } /* end l */
+
+ /* cancellation of e(l) if l > 0 */
+
+cancellation:
+
+ c = 0.0;
+ s = 1.0;
+ l1 = l - 1;
+
+ for (i=l; i<=k; i++)
+ {
+ f = s * e(i);
+ e(i) *= c;
+
+ if (abs(f) <= eps)
+ goto test_f_convergence;
+
+ g = q(i);
+ h = q(i) = sqrt(f*f + g*g);
+ c = g / h;
+ s = -f / h;
+
+ if (u_mode != SVD_NO_U)
+ {
+ for (j=0; j<m; j++)
+ {
+ y = u(j,l1);
+ z = u(j,i);
+ u(j,l1) = y * c + z * s;
+ u(j,i) = -y * s + z * c;
+ } /* end j */
+ }
+ } /* end i */
+
+test_f_convergence:
+
+ z = q(k);
+ if (l == k)
+ goto convergence;
+
+ /* shift from bottom 2x2 minor */
+ iter++;
+ if (iter > 300)
+ {
+ retval = k;
+ break;
+ }
+ x = q(l);
+ y = q(k-1);
+ g = e(k-1);
+ h = e(k);
+ f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y);
+ g = sqrt(f * f + 1.0);
+ f = ((x - z) * (x + z) + h * (y / ((f < 0)?(f - g) : (f + g)) - h)) / x;
+
+ /* next QR transformation */
+ c = s = 1.0;
+
+ for (i=l+1; i<=k; i++)
+ {
+ g = e(i);
+ y = q(i);
+ h = s * g;
+ g *= c;
+ e(i-1) = z = sqrt(f * f + h * h);
+ c = f / z;
+ s = h / z;
+ f = x * c + g * s;
+ g = -x * s + g * c;
+ h = y * s;
+ y *= c;
+
+ if (withv)
+ {
+ for (j=0;j<n;j++)
+ {
+ x = v(j,i-1);
+ z = v(j,i);
+ v(j,i-1) = x * c + z * s;
+ v(j,i) = -x * s + z * c;
+ } /* end j */
+ } /* end withv, parens added for clarity */
+
+ q(i-1) = z = sqrt(f * f + h * h);
+ if (z != 0)
+ {
+ c = f / z;
+ s = h / z;
+ }
+ f = c * g + s * y;
+ x = -s * g + c * y;
+ if (u_mode != SVD_NO_U)
+ {
+ for (j=0; j<m; j++)
+ {
+ y = u(j,i-1);
+ z = u(j,i);
+ u(j,i-1) = y * c + z * s;
+ u(j,i) = -y * s + z * c;
+ } /* end j */
+ }
+ } /* end i */
+
+ e(l) = 0.0;
+ e(k) = f;
+ q(k) = x;
+
+ goto test_f_splitting;
+
+convergence:
+
+ if (z < 0.0)
+ {
+ /* q(k) is made non-negative */
+ q(k) = -z;
+ if (withv)
+ {
+ for (j=0; j<n; j++)
+ v(j,k) = -v(j,k);
+ } /* end withv, parens added for clarity */
+ } /* end z */
+ } /* end k */
+
+ return retval;
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ long qN, long qX,
+ long uM,
+ long vN,
+ typename MM1,
+ typename MM2,
+ typename MM3,
+ typename L1
+ >
+ long svd2 (
+ bool withu,
+ bool withv,
+ const matrix_exp<EXP>& a,
+ matrix<typename EXP::type,uM,uM,MM1,L1>& u,
+ matrix<typename EXP::type,qN,qX,MM2,L1>& q,
+ matrix<typename EXP::type,vN,vN,MM3,L1>& v
+ )
+ {
+ const long NR = matrix_exp<EXP>::NR;
+ const long NC = matrix_exp<EXP>::NC;
+
+ // make sure the output matrices have valid dimensions if they are statically dimensioned
+ COMPILE_TIME_ASSERT(qX == 0 || qX == 1);
+ COMPILE_TIME_ASSERT(NR == 0 || uM == 0 || NR == uM);
+ COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
+
+ DLIB_ASSERT(a.nr() >= a.nc(),
+ "\tconst matrix_exp svd4()"
+ << "\n\tYou have given an invalidly sized matrix"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ );
+
+ if (withu)
+ return svd4(SVD_FULL_U, withv, a,u,q,v);
+ else
+ return svd4(SVD_NO_U, withv, a,u,q,v);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ void orthogonalize (
+ matrix<T,NR,NC,MM,row_major_layout>& m
+ )
+ {
+ // We don't really need to use this temporary, but doing it this way runs a lot
+ // faster.
+ matrix<T,NR,NC,MM,column_major_layout> temp;
+ qr_decomposition<matrix<T,NR,NC,MM,row_major_layout>>(m).get_q(temp);
+ m = temp;
+ }
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ void orthogonalize (
+ matrix<T,NR,NC,MM,column_major_layout>& m
+ )
+ {
+ qr_decomposition<matrix<T,NR,NC,MM,column_major_layout>>(m).get_q(m);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long Anr, long Anc,
+ typename MM,
+ typename L
+ >
+ void find_matrix_range (
+ const matrix<T,Anr,Anc,MM,L>& A,
+ unsigned long l,
+ matrix<T,Anr,0,MM,L>& Q,
+ unsigned long q
+ )
+ /*!
+ requires
+ - A.nr() >= l
+ ensures
+ - #Q.nr() == A.nr()
+ - #Q.nc() == l
+ - #Q == an orthonormal matrix whose range approximates the range of the
+ matrix A.
+ - This function implements the randomized subspace iteration defined
+ in the algorithm 4.4 box of the paper:
+ Finding Structure with Randomness: Probabilistic Algorithms for
+ Constructing Approximate Matrix Decompositions by Halko et al.
+ - q defines the number of extra subspace iterations this algorithm will
+ perform. Often q == 0 is fine, but performing more iterations can lead to a
+ more accurate approximation of the range of A if A has slowly decaying
+ singular values. In these cases, using a q of 1 or 2 is good.
+ !*/
+ {
+ DLIB_ASSERT(A.nr() >= (long)l, "Invalid inputs were given to this function.");
+ Q = A*matrix_cast<T>(gaussian_randm(A.nc(), l));
+
+ orthogonalize(Q);
+
+ // Do some extra iterations of the power method to make sure we get Q into the
+ // span of the most important singular vectors of A.
+ if (q != 0)
+ {
+ for (unsigned long itr = 0; itr < q; ++itr)
+ {
+ Q = trans(A)*Q;
+ orthogonalize(Q);
+
+ Q = A*Q;
+ orthogonalize(Q);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long Anr, long Anc,
+ long Unr, long Unc,
+ long Wnr, long Wnc,
+ long Vnr, long Vnc,
+ typename MM,
+ typename L
+ >
+ void svd_fast (
+ const matrix<T,Anr,Anc,MM,L>& A,
+ matrix<T,Unr,Unc,MM,L>& u,
+ matrix<T,Wnr,Wnc,MM,L>& w,
+ matrix<T,Vnr,Vnc,MM,L>& v,
+ unsigned long l,
+ unsigned long q = 1
+ )
+ {
+ const unsigned long k = std::min(l, std::min<unsigned long>(A.nr(),A.nc()));
+
+ DLIB_ASSERT(l > 0 && A.size() > 0,
+ "\t void svd_fast()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t l: " << l
+ << "\n\t A.size(): " << A.size()
+ );
+
+ matrix<T,Anr,0,MM,L> Q;
+ find_matrix_range(A, k, Q, q);
+
+ // Compute trans(B) = trans(Q)*A. The reason we store B transposed
+ // is so that when we take its SVD later using svd3() it doesn't consume
+ // a whole lot of RAM. That is, we make sure the square matrix coming out
+ // of svd3() has size lxl rather than the potentially much larger nxn.
+ matrix<T,0,0,MM,L> B = trans(A)*Q;
+ svd3(B, v,w,u);
+ u = Q*u;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sparse_vector_type,
+ typename T,
+ typename MM,
+ typename L
+ >
+ void find_matrix_range (
+ const std::vector<sparse_vector_type>& A,
+ unsigned long l,
+ matrix<T,0,0,MM,L>& Q,
+ unsigned long q
+ )
+ /*!
+ requires
+ - A.size() >= l
+ ensures
+ - #Q.nr() == A.size()
+ - #Q.nc() == l
+ - #Q == an orthonormal matrix whose range approximates the range of the
+ matrix A. In this case, we interpret A as a matrix of A.size() rows,
+ where each row is defined by a sparse vector.
+ - This function implements the randomized subspace iteration defined
+ in the algorithm 4.4 box of the paper:
+ Finding Structure with Randomness: Probabilistic Algorithms for
+ Constructing Approximate Matrix Decompositions by Halko et al.
+ - q defines the number of extra subspace iterations this algorithm will
+ perform. Often q == 0 is fine, but performing more iterations can lead to a
+ more accurate approximation of the range of A if A has slowly decaying
+ singular values. In these cases, using a q of 1 or 2 is good.
+ !*/
+ {
+ DLIB_ASSERT(A.size() >= l, "Invalid inputs were given to this function.");
+ Q.set_size(A.size(), l);
+
+ // Compute Q = A*gaussian_randm()
+ parallel_for(0, Q.nr(), [&](long r)
+ {
+ for (long c = 0; c < Q.nc(); ++c)
+ {
+ Q(r,c) = dot(A[r], gaussian_randm(std::numeric_limits<long>::max(), 1, c));
+ }
+ });
+
+ orthogonalize(Q);
+
+ // Do some extra iterations of the power method to make sure we get Q into the
+ // span of the most important singular vectors of A.
+ if (q != 0)
+ {
+ dlib::mutex mut;
+ const unsigned long n = max_index_plus_one(A);
+ for (unsigned long itr = 0; itr < q; ++itr)
+ {
+ matrix<T,0,0,MM> Z;
+ // Compute Z = trans(A)*Q
+ parallel_for_blocked(0, A.size(), [&](long begin, long end)
+ {
+ matrix<T,0,0,MM> Zlocal(n,l);
+ Zlocal = 0;
+ for (long m = begin; m < end; ++m)
+ {
+ for (unsigned long r = 0; r < l; ++r)
+ {
+ for (auto& i : A[m])
+ {
+ const auto c = i.first;
+ const auto val = i.second;
+
+ Zlocal(c,r) += Q(m,r)*val;
+ }
+ }
+ }
+ auto_mutex lock(mut);
+ Z += Zlocal;
+ },1);
+
+ Q.set_size(0,0); // free RAM
+ orthogonalize(Z);
+
+ // Compute Q = A*Z
+ Q.set_size(A.size(), l);
+ parallel_for(0, Q.nr(), [&](long r)
+ {
+ for (long c = 0; c < Q.nc(); ++c)
+ {
+ Q(r,c) = dot(A[r], colm(Z,c));
+ }
+ });
+
+ Z.set_size(0,0); // free RAM
+ orthogonalize(Q);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace simpl
+ {
+ template <
+ typename sparse_vector_type,
+ typename T,
+ long Unr, long Unc,
+ long Wnr, long Wnc,
+ long Vnr, long Vnc,
+ typename MM,
+ typename L
+ >
+ void svd_fast (
+ bool compute_u,
+ const std::vector<sparse_vector_type>& A,
+ matrix<T,Unr,Unc,MM,L>& u,
+ matrix<T,Wnr,Wnc,MM,L>& w,
+ matrix<T,Vnr,Vnc,MM,L>& v,
+ unsigned long l,
+ unsigned long q
+ )
+ {
+ const long n = max_index_plus_one(A);
+ const unsigned long k = std::min(l, std::min<unsigned long>(A.size(),n));
+
+ DLIB_ASSERT(l > 0 && A.size() > 0 && n > 0,
+ "\t void svd_fast()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t l: " << l
+ << "\n\t n (i.e. max_index_plus_one(A)): " << n
+ << "\n\t A.size(): " << A.size()
+ );
+
+ matrix<T,0,0,MM,L> Q;
+ find_matrix_range(A, k, Q, q);
+
+ // Compute trans(B) = trans(Q)*A. The reason we store B transposed
+ // is so that when we take its SVD later using svd3() it doesn't consume
+ // a whole lot of RAM. That is, we make sure the square matrix coming out
+ // of svd3() has size lxl rather than the potentially much larger nxn.
+ matrix<T,0,0,MM> B;
+ dlib::mutex mut;
+ parallel_for_blocked(0, A.size(), [&](long begin, long end)
+ {
+ matrix<T,0,0,MM> Blocal(n,k);
+ Blocal = 0;
+ for (long m = begin; m < end; ++m)
+ {
+ for (unsigned long r = 0; r < k; ++r)
+ {
+ for (auto& i : A[m])
+ {
+ const auto c = i.first;
+ const auto val = i.second;
+
+ Blocal(c,r) += Q(m,r)*val;
+ }
+ }
+ }
+ auto_mutex lock(mut);
+ B += Blocal;
+ },1);
+
+ svd3(B, v,w,u);
+ if (compute_u)
+ u = Q*u;
+ }
+ }
+
+ template <
+ typename sparse_vector_type,
+ typename T,
+ long Unr, long Unc,
+ long Wnr, long Wnc,
+ long Vnr, long Vnc,
+ typename MM,
+ typename L
+ >
+ void svd_fast (
+ const std::vector<sparse_vector_type>& A,
+ matrix<T,Unr,Unc,MM,L>& u,
+ matrix<T,Wnr,Wnc,MM,L>& w,
+ matrix<T,Vnr,Vnc,MM,L>& v,
+ unsigned long l,
+ unsigned long q = 1
+ )
+ {
+ simpl::svd_fast(true, A,u,w,v,l,q);
+ }
+
+ template <
+ typename sparse_vector_type,
+ typename T,
+ long Wnr, long Wnc,
+ long Vnr, long Vnc,
+ typename MM,
+ typename L
+ >
+ void svd_fast (
+ const std::vector<sparse_vector_type>& A,
+ matrix<T,Wnr,Wnc,MM,L>& w,
+ matrix<T,Vnr,Vnc,MM,L>& v,
+ unsigned long l,
+ unsigned long q = 1
+ )
+ {
+ matrix<T,0,0,MM,L> u;
+ simpl::svd_fast(false, A,u,w,v,l,q);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ long N
+ >
+ struct inv_helper
+ {
+ static const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can't invert a non-square matrix
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC ||
+ matrix_exp<EXP>::NR == 0 ||
+ matrix_exp<EXP>::NC == 0);
+ DLIB_ASSERT(m.nr() == m.nc(),
+ "\tconst matrix_exp::type inv(const matrix_exp& m)"
+ << "\n\tYou can only apply inv() to a square matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ lu_decomposition<EXP> lu(m);
+ return lu.solve(identity_matrix<type>(m.nr()));
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct inv_helper<EXP,1>
+ {
+ static const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ matrix<type, 1, 1, typename EXP::mem_manager_type> a;
+ // if m is invertible
+ if (m(0) != 0)
+ a(0) = 1/m(0);
+ else
+ a(0) = 1;
+ return a;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct inv_helper<EXP,2>
+ {
+ static const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ matrix<type, 2, 2, typename EXP::mem_manager_type> a;
+ type d = det(m);
+ if (d != 0)
+ {
+ d = static_cast<type>(1.0/d);
+ a(0,0) = m(1,1)*d;
+ a(0,1) = m(0,1)*-d;
+ a(1,0) = m(1,0)*-d;
+ a(1,1) = m(0,0)*d;
+ }
+ else
+ {
+ // Matrix isn't invertible so just return the identity matrix.
+ a = identity_matrix<type,2>();
+ }
+ return a;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct inv_helper<EXP,3>
+ {
+ static const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ matrix<type, 3, 3, typename EXP::mem_manager_type> ret;
+ type de = det(m);
+ if (de != 0)
+ {
+ de = static_cast<type>(1.0/de);
+ const type a = m(0,0);
+ const type b = m(0,1);
+ const type c = m(0,2);
+ const type d = m(1,0);
+ const type e = m(1,1);
+ const type f = m(1,2);
+ const type g = m(2,0);
+ const type h = m(2,1);
+ const type i = m(2,2);
+
+ ret(0,0) = (e*i - f*h)*de;
+ ret(1,0) = (f*g - d*i)*de;
+ ret(2,0) = (d*h - e*g)*de;
+
+ ret(0,1) = (c*h - b*i)*de;
+ ret(1,1) = (a*i - c*g)*de;
+ ret(2,1) = (b*g - a*h)*de;
+
+ ret(0,2) = (b*f - c*e)*de;
+ ret(1,2) = (c*d - a*f)*de;
+ ret(2,2) = (a*e - b*d)*de;
+ }
+ else
+ {
+ ret = identity_matrix<type,3>();
+ }
+
+ return ret;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct inv_helper<EXP,4>
+ {
+ static const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ matrix<type, 4, 4, typename EXP::mem_manager_type> ret;
+ type de = det(m);
+ if (de != 0)
+ {
+ de = static_cast<type>(1.0/de);
+ ret(0,0) = det(removerc<0,0>(m));
+ ret(0,1) = -det(removerc<0,1>(m));
+ ret(0,2) = det(removerc<0,2>(m));
+ ret(0,3) = -det(removerc<0,3>(m));
+
+ ret(1,0) = -det(removerc<1,0>(m));
+ ret(1,1) = det(removerc<1,1>(m));
+ ret(1,2) = -det(removerc<1,2>(m));
+ ret(1,3) = det(removerc<1,3>(m));
+
+ ret(2,0) = det(removerc<2,0>(m));
+ ret(2,1) = -det(removerc<2,1>(m));
+ ret(2,2) = det(removerc<2,2>(m));
+ ret(2,3) = -det(removerc<2,3>(m));
+
+ ret(3,0) = -det(removerc<3,0>(m));
+ ret(3,1) = det(removerc<3,1>(m));
+ ret(3,2) = -det(removerc<3,2>(m));
+ ret(3,3) = det(removerc<3,3>(m));
+
+ return trans(ret)*de;
+ }
+ else
+ {
+ return identity_matrix<type,4>();
+ }
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ inline const typename matrix_exp<EXP>::matrix_type inv (
+ const matrix_exp<EXP>& m
+ ) { return inv_helper<EXP,matrix_exp<EXP>::NR>::inv(m); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_diag_inv
+ {
+ template <typename EXP>
+ op_diag_inv( const matrix_exp<EXP>& m_) : m(m_){}
+
+
+ const static long cost = 1;
+ const static long NR = ((M::NC!=0)&&(M::NR!=0))? (tmax<M::NR,M::NC>::value) : (0);
+ const static long NC = NR;
+ typedef typename M::type type;
+ typedef const type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+
+ // hold the matrix by value
+ const matrix<type,NR,1,mem_manager_type,layout_type> m;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r==c)
+ return m(r);
+ else
+ return 0;
+ }
+
+ long nr () const { return m.size(); }
+ long nc () const { return m.size(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_diag_op<op_diag_inv<EXP> > inv (
+ const matrix_diag_exp<EXP>& m
+ )
+ {
+ typedef op_diag_inv<EXP> op;
+ return matrix_diag_op<op>(op(reciprocal(diag(m))));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_diag_op<op_diag_inv<EXP> > pinv (
+ const matrix_diag_exp<EXP>& m
+ )
+ {
+ typedef op_diag_inv<EXP> op;
+ return matrix_diag_op<op>(op(reciprocal(diag(m))));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_diag_op<op_diag_inv<EXP> > pinv (
+ const matrix_diag_exp<EXP>& m,
+ double tol
+ )
+ {
+ DLIB_ASSERT(tol >= 0,
+ "\tconst matrix_exp::type pinv(const matrix_exp& m)"
+ << "\n\t tol can't be negative"
+ << "\n\t tol: "<<tol
+ );
+ typedef op_diag_inv<EXP> op;
+ return matrix_diag_op<op>(op(reciprocal(round_zeros(diag(m),tol))));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ const typename matrix_exp<EXP>::matrix_type inv_lower_triangular (
+ const matrix_exp<EXP>& A
+ )
+ {
+ DLIB_ASSERT(A.nr() == A.nc(),
+ "\tconst matrix inv_lower_triangular(const matrix_exp& A)"
+ << "\n\tA must be a square matrix"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ );
+
+ typedef typename matrix_exp<EXP>::matrix_type matrix_type;
+
+ matrix_type m(A);
+
+ for(long c = 0; c < m.nc(); ++c)
+ {
+ if( m(c,c) == 0 )
+ {
+ // there isn't an inverse so just give up
+ return m;
+ }
+
+ // compute m(c,c)
+ m(c,c) = 1/m(c,c);
+
+ // compute the values in column c that are below m(c,c).
+ // We do this by just doing the same thing we do for upper triangular
+ // matrices because we take the transpose of m which turns m into an
+ // upper triangular matrix.
+ for(long r = 0; r < c; ++r)
+ {
+ const long n = c-r;
+ m(c,r) = -m(c,c)*subm(trans(m),r,r,1,n)*subm(trans(m),r,c,n,1);
+ }
+ }
+
+ return m;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ const typename matrix_exp<EXP>::matrix_type inv_upper_triangular (
+ const matrix_exp<EXP>& A
+ )
+ {
+ DLIB_ASSERT(A.nr() == A.nc(),
+ "\tconst matrix inv_upper_triangular(const matrix_exp& A)"
+ << "\n\tA must be a square matrix"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ );
+
+ typedef typename matrix_exp<EXP>::matrix_type matrix_type;
+
+ matrix_type m(A);
+
+ for(long c = 0; c < m.nc(); ++c)
+ {
+ if( m(c,c) == 0 )
+ {
+ // there isn't an inverse so just give up
+ return m;
+ }
+
+ // compute m(c,c)
+ m(c,c) = 1/m(c,c);
+
+ // compute the values in column c that are above m(c,c)
+ for(long r = 0; r < c; ++r)
+ {
+ const long n = c-r;
+ m(r,c) = -m(c,c)*subm(m,r,r,1,n)*subm(m,r,c,n,1);
+ }
+ }
+
+ return m;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ inline const typename matrix_exp<EXP>::matrix_type chol (
+ const matrix_exp<EXP>& A
+ )
+ {
+ DLIB_ASSERT(A.nr() == A.nc(),
+ "\tconst matrix chol(const matrix_exp& A)"
+ << "\n\tYou can only apply the chol to a square matrix"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ );
+ typename matrix_exp<EXP>::matrix_type L(A.nr(),A.nc());
+
+ typedef typename EXP::type T;
+
+ bool banded = false;
+ long bandwidth = 0;
+
+ if (A.nr() > 4) // Only test for banded matrix if matrix is big enough
+ {
+ // Detect if matrix is banded and, if so, matrix bandwidth
+ banded = true;
+ for (long r = 0; r < A.nr(); ++r)
+ for (long c = (r + bandwidth + 1); c < A.nc(); ++c)
+ if (A(r, c) != 0)
+ {
+ bandwidth = c - r;
+ if (bandwidth > A.nr() / 2)
+ {
+ banded = false;
+ goto escape_banded_detection;
+ }
+ }
+ }
+escape_banded_detection:
+
+ if (banded)
+ {
+ // Store in compact form - use column major for LAPACK
+ matrix<T,0,0,default_memory_manager,column_major_layout> B(bandwidth + 1, A.nc());
+ set_all_elements(B, 0);
+
+ for (long r = 0; r < A.nr(); ++r)
+ for (long c = r; c < std::min(r + bandwidth + 1, A.nc()); ++c)
+ B(c - r, r) = A(r, c);
+
+#ifdef DLIB_USE_LAPACK
+
+ lapack::pbtrf('L', B);
+
+#else
+
+ // Peform compact Cholesky
+ for (long k = 0; k < A.nr(); ++k)
+ {
+ long last = std::min(k + bandwidth, A.nr() - 1) - k;
+ for (long j = 1; j <= last; ++j)
+ {
+ long i = k + j;
+ for (long c = 0; c <= (last - j); ++c)
+ B(c, i) -= B(j, k) / B(0, k) * B(c + j, k);
+ }
+ T norm = std::sqrt(B(0, k));
+ for (long i = 0; i <= bandwidth; ++i)
+ B(i, k) /= norm;
+ }
+ for (long c = A.nc() - bandwidth + 1; c < A.nc(); ++c)
+ B(bandwidth, c) = 0;
+
+#endif
+
+ // Unpack lower triangular area
+ set_all_elements(L, 0);
+ for (long c = 0; c < A.nc(); ++c)
+ for (long i = 0; i <= bandwidth; ++i)
+ {
+ long ind = c + i;
+ if (ind < A.nc())
+ L(ind, c) = B(i, c);
+ }
+
+ return L;
+ }
+
+#ifdef DLIB_USE_LAPACK
+ // Only call LAPACK if the matrix is big enough. Otherwise,
+ // our own code is faster, especially for statically dimensioned
+ // matrices.
+ if (A.nr() > 4)
+ {
+ L = A;
+ lapack::potrf('L', L);
+ // mask out upper triangular area
+ return lowerm(L);
+ }
+#endif
+ set_all_elements(L,0);
+
+ // do nothing if the matrix is empty
+ if (A.size() == 0)
+ return L;
+
+ const T eps = std::numeric_limits<T>::epsilon();
+
+ // compute the upper left corner
+ if (A(0,0) > 0)
+ L(0,0) = std::sqrt(A(0,0));
+
+ // compute the first column
+ for (long r = 1; r < A.nr(); ++r)
+ {
+ // if (L(0,0) > 0)
+ if (L(0,0) > eps*std::abs(A(r,0)))
+ L(r,0) = A(r,0)/L(0,0);
+ else
+ return L;
+ }
+
+ // now compute all the other columns
+ for (long c = 1; c < A.nc(); ++c)
+ {
+ // compute the diagonal element
+ T temp = A(c,c);
+ for (long i = 0; i < c; ++i)
+ {
+ temp -= L(c,i)*L(c,i);
+ }
+ if (temp > 0)
+ L(c,c) = std::sqrt(temp);
+
+ // compute the non diagonal elements
+ for (long r = c+1; r < A.nr(); ++r)
+ {
+ temp = A(r,c);
+ for (long i = 0; i < c; ++i)
+ {
+ temp -= L(r,i)*L(c,i);
+ }
+
+ // if (L(c,c) > 0)
+ if (L(c,c) > eps*std::abs(temp))
+ L(r,c) = temp/L(c,c);
+ else
+ return L;
+ }
+ }
+
+ return L;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ long uNR,
+ long uNC,
+ long wN,
+ long vN,
+ long wX,
+ typename MM1,
+ typename MM2,
+ typename MM3,
+ typename L1
+ >
+ inline void svd3 (
+ const matrix_exp<EXP>& m,
+ matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u,
+ matrix<typename matrix_exp<EXP>::type, wN, wX,MM2,L1>& w,
+ matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L1>& v
+ )
+ {
+ typedef typename matrix_exp<EXP>::type T;
+ const long NR = matrix_exp<EXP>::NR;
+ const long NC = matrix_exp<EXP>::NC;
+
+ // make sure the output matrices have valid dimensions if they are statically dimensioned
+ COMPILE_TIME_ASSERT(NR == 0 || uNR == 0 || NR == uNR);
+ COMPILE_TIME_ASSERT(NC == 0 || uNC == 0 || NC == uNC);
+ COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN);
+ COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
+ COMPILE_TIME_ASSERT(wX == 0 || wX == 1);
+
+#ifdef DLIB_USE_LAPACK
+ // use LAPACK but only if it isn't a really small matrix we are taking the SVD of.
+ if (NR*NC == 0 || NR*NC > 3*3)
+ {
+ matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1> temp(m);
+ lapack::gesvd('S','A', temp, w, u, v);
+ v = trans(v);
+ // if u isn't the size we want then pad it (and v) with zeros
+ if (u.nc() < m.nc())
+ {
+ w = join_cols(w, zeros_matrix<T>(m.nc()-u.nc(),1));
+ u = join_rows(u, zeros_matrix<T>(u.nr(), m.nc()-u.nc()));
+ }
+ return;
+ }
+#endif
+ if (m.nr() >= m.nc())
+ {
+ svd4(SVD_SKINNY_U,true, m, u,w,v);
+ }
+ else
+ {
+ svd4(SVD_FULL_U,true, trans(m), v,w,u);
+
+ // if u isn't the size we want then pad it (and v) with zeros
+ if (u.nc() < m.nc())
+ {
+ w = join_cols(w, zeros_matrix<T>(m.nc()-u.nc(),1));
+ u = join_rows(u, zeros_matrix<T>(u.nr(), m.nc()-u.nc()));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv_helper (
+ const matrix_exp<EXP>& m,
+ double tol
+ )
+ /*!
+ ensures
+ - computes the results of pinv(m) but does so using a method that is fastest
+ when m.nc() <= m.nr(). So if m.nc() > m.nr() then it is best to use
+ trans(pinv_helper(trans(m))) to compute pinv(m).
+ !*/
+ {
+ typename matrix_exp<EXP>::matrix_type u;
+ typedef typename EXP::mem_manager_type MM1;
+ typedef typename EXP::layout_type layout_type;
+ matrix<typename EXP::type, EXP::NC, EXP::NC,MM1, layout_type > v;
+
+ typedef typename matrix_exp<EXP>::type T;
+
+ matrix<T,matrix_exp<EXP>::NC,1,MM1, layout_type> w;
+
+ svd3(m, u,w,v);
+
+ const double machine_eps = std::numeric_limits<typename EXP::type>::epsilon();
+ // compute a reasonable epsilon below which we round to zero before doing the
+ // reciprocal. Unless a non-zero tol is given then we just use tol*max(w).
+ const double eps = (tol!=0) ? tol*max(w) : machine_eps*std::max(m.nr(),m.nc())*max(w);
+
+ // now compute the pseudoinverse
+ return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u);
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv (
+ const matrix_exp<EXP>& m,
+ double tol = 0
+ )
+ {
+ DLIB_ASSERT(tol >= 0,
+ "\tconst matrix_exp::type pinv(const matrix_exp& m)"
+ << "\n\t tol can't be negative"
+ << "\n\t tol: "<<tol
+ );
+ // if m has more columns then rows then it is more efficient to
+ // compute the pseudo-inverse of its transpose (given the way I'm doing it below).
+ if (m.nc() > m.nr())
+ return trans(pinv_helper(trans(m),tol));
+ else
+ return pinv_helper(m,tol);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ long uNR,
+ long uNC,
+ long wN,
+ long vN,
+ typename MM1,
+ typename MM2,
+ typename MM3,
+ typename L1
+ >
+ inline void svd (
+ const matrix_exp<EXP>& m,
+ matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u,
+ matrix<typename matrix_exp<EXP>::type, wN, wN,MM2,L1>& w,
+ matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L1>& v
+ )
+ {
+ typedef typename matrix_exp<EXP>::type T;
+ const long NR = matrix_exp<EXP>::NR;
+ const long NC = matrix_exp<EXP>::NC;
+
+ // make sure the output matrices have valid dimensions if they are statically dimensioned
+ COMPILE_TIME_ASSERT(NR == 0 || uNR == 0 || NR == uNR);
+ COMPILE_TIME_ASSERT(NC == 0 || uNC == 0 || NC == uNC);
+ COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN);
+ COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
+
+ matrix<T,matrix_exp<EXP>::NC,1,MM1, L1> W;
+ svd3(m,u,W,v);
+ w = diagm(W);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type trace (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC ||
+ matrix_exp<EXP>::NR == 0 ||
+ matrix_exp<EXP>::NC == 0
+ );
+ DLIB_ASSERT(m.nr() == m.nc(),
+ "\tconst matrix_exp::type trace(const matrix_exp& m)"
+ << "\n\tYou can only apply trace() to a square matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ return sum(diag(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ long N = EXP::NR
+ >
+ struct det_helper
+ {
+ static const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC ||
+ matrix_exp<EXP>::NR == 0 ||
+ matrix_exp<EXP>::NC == 0
+ );
+ DLIB_ASSERT(m.nr() == m.nc(),
+ "\tconst matrix_exp::type det(const matrix_exp& m)"
+ << "\n\tYou can only apply det() to a square matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+
+ return lu_decomposition<EXP>(m).det();
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct det_helper<EXP,1>
+ {
+ static const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+
+ return m(0);
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct det_helper<EXP,2>
+ {
+ static const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+
+ return m(0,0)*m(1,1) - m(0,1)*m(1,0);
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ struct det_helper<EXP,3>
+ {
+ static const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ type temp = m(0,0)*(m(1,1)*m(2,2) - m(1,2)*m(2,1)) -
+ m(0,1)*(m(1,0)*m(2,2) - m(1,2)*m(2,0)) +
+ m(0,2)*(m(1,0)*m(2,1) - m(1,1)*m(2,0));
+ return temp;
+ }
+ };
+
+
+ template <
+ typename EXP
+ >
+ inline const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ ) { return det_helper<EXP>::det(m); }
+
+
+ template <
+ typename EXP
+ >
+ struct det_helper<EXP,4>
+ {
+ static const typename matrix_exp<EXP>::type det (
+ const matrix_exp<EXP>& m
+ )
+ {
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC);
+ typedef typename matrix_exp<EXP>::type type;
+
+ type temp = m(0,0)*(dlib::det(removerc<0,0>(m))) -
+ m(0,1)*(dlib::det(removerc<0,1>(m))) +
+ m(0,2)*(dlib::det(removerc<0,2>(m))) -
+ m(0,3)*(dlib::det(removerc<0,3>(m)));
+ return temp;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ const matrix<typename EXP::type, EXP::NR, 1, typename EXP::mem_manager_type, typename EXP::layout_type> real_eigenvalues (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You can only use this function with matrices that contain float or double values
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP::type, float>::value ||
+ is_same_type<typename EXP::type, double>::value));
+
+ DLIB_ASSERT(m.nr() == m.nc(),
+ "\tconst matrix real_eigenvalues()"
+ << "\n\tYou have given an invalidly sized matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+
+ if (m.nr() == 2)
+ {
+ typedef typename EXP::type T;
+ const T m00 = m(0,0);
+ const T m01 = m(0,1);
+ const T m10 = m(1,0);
+ const T m11 = m(1,1);
+
+ const T b = -(m00 + m11);
+ const T c = m00*m11 - m01*m10;
+ matrix<T,EXP::NR,1, typename EXP::mem_manager_type, typename EXP::layout_type> v(2);
+
+
+ T disc = b*b - 4*c;
+ if (disc >= 0)
+ disc = std::sqrt(disc);
+ else
+ disc = 0;
+
+ v(0) = (-b + disc)/2;
+ v(1) = (-b - disc)/2;
+ return v;
+ }
+ else
+ {
+ // Call .ref() so that the symmetric matrix overload can take effect if m
+ // has the appropriate type.
+ return eigenvalue_decomposition<EXP>(m.ref()).get_real_eigenvalues();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ dlib::vector<double,2> max_point_interpolated (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\tdlib::vector<double,2> point max_point_interpolated(const matrix_exp& m)"
+ << "\n\tm can't be empty"
+ << "\n\tm.size(): " << m.size()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ const point p = max_point(m);
+
+ // If this is a column vector then just do interpolation along a line.
+ if (m.nc()==1)
+ {
+ const long pos = p.y();
+ if (0 < pos && pos+1 < m.nr())
+ {
+ double v1 = dlib::impl::magnitude(m(pos-1));
+ double v2 = dlib::impl::magnitude(m(pos));
+ double v3 = dlib::impl::magnitude(m(pos+1));
+ double y = lagrange_poly_min_extrap(pos-1,pos,pos+1, -v1, -v2, -v3);
+ return vector<double,2>(0,y);
+ }
+ }
+ // If this is a row vector then just do interpolation along a line.
+ if (m.nr()==1)
+ {
+ const long pos = p.x();
+ if (0 < pos && pos+1 < m.nc())
+ {
+ double v1 = dlib::impl::magnitude(m(pos-1));
+ double v2 = dlib::impl::magnitude(m(pos));
+ double v3 = dlib::impl::magnitude(m(pos+1));
+ double x = lagrange_poly_min_extrap(pos-1,pos,pos+1, -v1, -v2, -v3);
+ return vector<double,2>(x,0);
+ }
+ }
+
+
+ // If it's on the border then just return the regular max point.
+ if (shrink_rect(get_rect(m),1).contains(p) == false)
+ return p;
+
+ //matrix<double> A(9,6);
+ //matrix<double,0,1> G(9);
+
+ matrix<double,9,1> pix;
+ long i = 0;
+ for (long r = -1; r <= +1; ++r)
+ {
+ for (long c = -1; c <= +1; ++c)
+ {
+ pix(i) = dlib::impl::magnitude(m(p.y()+r,p.y()+c));
+ /*
+ A(i,0) = c*c;
+ A(i,1) = c*r;
+ A(i,2) = r*r;
+ A(i,3) = c;
+ A(i,4) = r;
+ A(i,5) = 1;
+ G(i) = std::exp(-1*(r*r+c*c)/2.0); // Use a gaussian windowing function around p.
+ */
+ ++i;
+ }
+ }
+
+ // This bit of code is how we generated the derivative_filters matrix below.
+ //A = diagm(G)*A;
+ //std::cout << std::setprecision(20) << inv(trans(A)*A)*trans(A)*diagm(G) << std::endl; exit(1);
+
+ const double m10 = 0.10597077880854270659;
+ const double m21 = 0.21194155761708535768;
+ const double m28 = 0.28805844238291455905;
+ const double m57 = 0.57611688476582878504;
+ // So this derivative_filters finds the parameters of the quadratic surface that best fits
+ // the 3x3 region around p. Then we find the maximizer of that surface within that
+ // small region and return that as the maximum location.
+ const double derivative_filters[] = {
+ // xx
+ m10,-m21,m10,
+ m28,-m57,m28,
+ m10,-m21,m10,
+
+ // xy
+ 0.25 ,0,-0.25,
+ 0 ,0, 0,
+ -0.25,0,0.25,
+
+ // yy
+ m10, m28, m10,
+ -m21,-m57,-m21,
+ m10, m28, m10,
+
+ // x
+ -m10,0,m10,
+ -m28,0,m28,
+ -m10,0,m10,
+
+ // y
+ -m10,-m28,-m10,
+ 0, 0, 0,
+ m10, m28, m10
+ };
+ const matrix<double,5,9> filt(derivative_filters);
+ // Now w contains the parameters of the quadratic surface
+ const matrix<double,5,1> w = filt*pix;
+
+
+ // Now newton step to the max point on the surface
+ matrix<double,2,2> H;
+ matrix<double,2,1> g;
+ H = 2*w(0), w(1),
+ w(1), 2*w(2);
+ g = w(3),
+ w(4);
+ const dlib::vector<double,2> delta = -inv(H)*g;
+
+ // if delta isn't in an ascent direction then just use the normal max point.
+ if (dot(delta, g) < 0)
+ return p;
+ else
+ return vector<double,2>(p)+dlib::clamp(delta, -1, 1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_LA_FUNCTS_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_la_abstract.h b/ml/dlib/dlib/matrix/matrix_la_abstract.h
new file mode 100644
index 000000000..df6a5fd33
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_la_abstract.h
@@ -0,0 +1,1005 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_LA_FUNCTS_ABSTRACT_
+#ifdef DLIB_MATRIx_LA_FUNCTS_ABSTRACT_
+
+#include "matrix_abstract.h"
+#include <complex>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Global linear algebra functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::matrix_type inv (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m is a square matrix
+ ensures
+ - returns the inverse of m
+ (Note that if m is singular or so close to being singular that there
+ is a lot of numerical error then the returned matrix will be bogus.
+ You can check by seeing if m*inv(m) is an identity matrix)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix pinv (
+ const matrix_exp& m,
+ double tol = 0
+ );
+ /*!
+ requires
+ - tol >= 0
+ ensures
+ - returns the Moore-Penrose pseudoinverse of m.
+ - The returned matrix has m.nc() rows and m.nr() columns.
+ - if (tol == 0) then
+ - singular values less than max(m.nr(),m.nc()) times the machine epsilon
+ times the largest singular value are ignored.
+ - else
+ - singular values less than tol*max(singular value in m) are ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void svd (
+ const matrix_exp& m,
+ matrix<matrix_exp::type>& u,
+ matrix<matrix_exp::type>& w,
+ matrix<matrix_exp::type>& v
+ );
+ /*!
+ ensures
+ - computes the singular value decomposition of m
+ - m == #u*#w*trans(#v)
+ - trans(#u)*#u == identity matrix
+ - trans(#v)*#v == identity matrix
+ - diag(#w) == the singular values of the matrix m in no
+ particular order. All non-diagonal elements of #w are
+ set to 0.
+ - #u.nr() == m.nr()
+ - #u.nc() == m.nc()
+ - #w.nr() == m.nc()
+ - #w.nc() == m.nc()
+ - #v.nr() == m.nc()
+ - #v.nc() == m.nc()
+ - if DLIB_USE_LAPACK is #defined then the xGESVD routine
+ from LAPACK is used to compute the SVD.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ long svd2 (
+ bool withu,
+ bool withv,
+ const matrix_exp& m,
+ matrix<matrix_exp::type>& u,
+ matrix<matrix_exp::type>& w,
+ matrix<matrix_exp::type>& v
+ );
+ /*!
+ requires
+ - m.nr() >= m.nc()
+ ensures
+ - computes the singular value decomposition of matrix m
+ - m == subm(#u,get_rect(m))*diagm(#w)*trans(#v)
+ - trans(#u)*#u == identity matrix
+ - trans(#v)*#v == identity matrix
+ - #w == the singular values of the matrix m in no
+ particular order.
+ - #u.nr() == m.nr()
+ - #u.nc() == m.nr()
+ - #w.nr() == m.nc()
+ - #w.nc() == 1
+ - #v.nr() == m.nc()
+ - #v.nc() == m.nc()
+ - if (widthu == false) then
+ - ignore the above regarding #u, it isn't computed and its
+ output state is undefined.
+ - if (widthv == false) then
+ - ignore the above regarding #v, it isn't computed and its
+ output state is undefined.
+ - returns an error code of 0, if no errors and 'k' if we fail to
+ converge at the 'kth' singular value.
+ - if (DLIB_USE_LAPACK is #defined) then
+ - if (withu == withv) then
+ - the xGESDD routine from LAPACK is used to compute the SVD.
+ - else
+ - the xGESVD routine from LAPACK is used to compute the SVD.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void svd3 (
+ const matrix_exp& m,
+ matrix<matrix_exp::type>& u,
+ matrix<matrix_exp::type>& w,
+ matrix<matrix_exp::type>& v
+ );
+ /*!
+ ensures
+ - computes the singular value decomposition of m
+ - m == #u*diagm(#w)*trans(#v)
+ - trans(#u)*#u == identity matrix
+ - trans(#v)*#v == identity matrix
+ - #w == the singular values of the matrix m in no
+ particular order.
+ - #u.nr() == m.nr()
+ - #u.nc() == m.nc()
+ - #w.nr() == m.nc()
+ - #w.nc() == 1
+ - #v.nr() == m.nc()
+ - #v.nc() == m.nc()
+ - if DLIB_USE_LAPACK is #defined then the xGESVD routine
+ from LAPACK is used to compute the SVD.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void svd_fast (
+ const matrix<T>& A,
+ matrix<T>& u,
+ matrix<T>& w,
+ matrix<T>& v,
+ unsigned long l,
+ unsigned long q = 1
+ );
+ /*!
+ requires
+ - l > 0
+ - A.size() > 0
+ (i.e. A can't be an empty matrix)
+ ensures
+ - computes the singular value decomposition of A.
+ - Lets define some constants we use to document the behavior of svd_fast():
+ - Let m = A.nr()
+ - Let n = A.nc()
+ - Let k = min(l, min(m,n))
+ - Therefore, A represents an m by n matrix and svd_fast() is designed
+ to find a rank-k representation of it.
+ - if (the rank of A is <= k) then
+ - A == #u*diagm(#w)*trans(#v)
+ - else
+ - A is approximated by #u*diagm(#w)*trans(#v)
+ (i.e. In this case A can't be represented with a rank-k matrix, so the
+ matrix you get by trying to reconstruct A from the output of the SVD is
+ not exactly the same.)
+ - trans(#u)*#u == identity matrix
+ - trans(#v)*#v == identity matrix
+ - #w == the top k singular values of the matrix A (in no particular order).
+ - #u.nr() == m
+ - #u.nc() == k
+ - #w.nr() == k
+ - #w.nc() == 1
+ - #v.nr() == n
+ - #v.nc() == k
+ - This function implements the randomized subspace iteration defined in the
+ algorithm 4.4 and 5.1 boxes of the paper:
+ Finding Structure with Randomness: Probabilistic Algorithms for
+ Constructing Approximate Matrix Decompositions by Halko et al.
+ Therefore, it is very fast and suitable for use with very large matrices.
+ Moreover, q is the number of subspace iterations performed. Larger
+ values of q might increase the accuracy of the solution but the default
+ value should be good for many problems.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sparse_vector_type,
+ typename T
+ >
+ void svd_fast (
+ const std::vector<sparse_vector_type>& A,
+ matrix<T>& u,
+ matrix<T>& w,
+ matrix<T>& v,
+ unsigned long l,
+ unsigned long q = 1
+ );
+ /*!
+ requires
+ - A contains a set of sparse vectors. See dlib/svm/sparse_vector_abstract.h
+ for a definition of what constitutes a sparse vector.
+ - l > 0
+ - max_index_plus_one(A) > 0
+ (i.e. A can't be an empty matrix)
+ ensures
+ - computes the singular value decomposition of A. In this case, we interpret A
+ as a matrix of A.size() rows, where each row is defined by a sparse vector.
+ - Lets define some constants we use to document the behavior of svd_fast():
+ - Let m = A.size()
+ - Let n = max_index_plus_one(A)
+ - Let k = min(l, min(m,n))
+ - Therefore, A represents an m by n matrix and svd_fast() is designed
+ to find a rank-k representation of it.
+ - if (the rank of A is <= k) then
+ - A == #u*diagm(#w)*trans(#v)
+ - else
+ - A is approximated by #u*diagm(#w)*trans(#v)
+ (i.e. In this case A can't be represented with a rank-k matrix, so the
+ matrix you get by trying to reconstruct A from the output of the SVD is
+ not exactly the same.)
+ - trans(#u)*#u == identity matrix
+ - trans(#v)*#v == identity matrix
+ - #w == the top k singular values of the matrix A (in no particular order).
+ - #u.nr() == m
+ - #u.nc() == k
+ - #w.nr() == k
+ - #w.nc() == 1
+ - #v.nr() == n
+ - #v.nc() == k
+ - This function implements the randomized subspace iteration defined in the
+ algorithm 4.4 and 5.1 boxes of the paper:
+ Finding Structure with Randomness: Probabilistic Algorithms for
+ Constructing Approximate Matrix Decompositions by Halko et al.
+ Therefore, it is very fast and suitable for use with very large matrices.
+ Moreover, q is the number of subspace iterations performed. Larger
+ values of q might increase the accuracy of the solution but the default
+ value should be good for many problems.
+ !*/
+
+ template <
+ typename sparse_vector_type,
+ typename T
+ >
+ void svd_fast (
+ const std::vector<sparse_vector_type>& A,
+ matrix<T>& w,
+ matrix<T>& v,
+ unsigned long l,
+ unsigned long q = 1
+ );
+ /*!
+ This function is identical to the above svd_fast() except it doesn't compute u.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void orthogonalize (
+ matrix<T,NR,NC,MM,L>& m
+ );
+ /*!
+ requires
+ - m.nr() >= m.nc()
+ - m.size() > 0
+ ensures
+ - #m == an orthogonal matrix with the same dimensions as m. In particular,
+ the columns of #m have the same span as the columns of m.
+ - trans(#m)*#m == identity matrix
+ - This function is just shorthand for computing the QR decomposition of m
+ and then storing the Q factor into #m.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix real_eigenvalues (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.nr() == m.nc()
+ - matrix_exp::type == float or double
+ ensures
+ - returns a matrix E such that:
+ - E.nr() == m.nr()
+ - E.nc() == 1
+ - E contains the real part of all eigenvalues of the matrix m.
+ (note that the eigenvalues are not sorted)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type det (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m is a square matrix
+ ensures
+ - returns the determinant of m
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type trace (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m is a square matrix
+ ensures
+ - returns the trace of m
+ (i.e. returns sum(diag(m)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::matrix_type chol (
+ const matrix_exp& A
+ );
+ /*!
+ requires
+ - A is a square matrix
+ ensures
+ - if (A has a Cholesky Decomposition) then
+ - returns the decomposition of A. That is, returns a matrix L
+ such that L*trans(L) == A. L will also be lower triangular.
+ - else
+ - returns a matrix with the same dimensions as A but it
+ will have a bogus value. I.e. it won't be a decomposition.
+ In this case the algorithm returns a partial decomposition.
+ - You can tell when chol fails by looking at the lower right
+ element of the returned matrix. If it is 0 then it means
+ A does not have a cholesky decomposition.
+
+ - If DLIB_USE_LAPACK is defined then the LAPACK routine xPOTRF
+ is used to compute the cholesky decomposition.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::matrix_type inv_lower_triangular (
+ const matrix_exp& A
+ );
+ /*!
+ requires
+ - A is a square matrix
+ ensures
+ - if (A is lower triangular) then
+ - returns the inverse of A.
+ - else
+ - returns a matrix with the same dimensions as A but it
+ will have a bogus value. I.e. it won't be an inverse.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::matrix_type inv_upper_triangular (
+ const matrix_exp& A
+ );
+ /*!
+ requires
+ - A is a square matrix
+ ensures
+ - if (A is upper triangular) then
+ - returns the inverse of A.
+ - else
+ - returns a matrix with the same dimensions as A but it
+ will have a bogus value. I.e. it won't be an inverse.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Matrix decomposition classes
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_exp_type
+ >
+ class lu_decomposition
+ {
+ /*!
+ REQUIREMENTS ON matrix_exp_type
+ must be some kind of matrix expression as defined in the
+ dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object)
+ The matrix type must also contain float or double values.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute an LU
+ decomposition of a real valued matrix. That is, for any
+ matrix A it computes matrices L, U, and a pivot vector P such
+ that rowm(A,P) == L*U.
+
+ The LU decomposition with pivoting always exists, even if the matrix is
+ singular, so the constructor will never fail. The primary use of the
+ LU decomposition is in the solution of square systems of simultaneous
+ linear equations. This will fail if is_singular() returns true (or
+ if A is very nearly singular).
+
+ If DLIB_USE_LAPACK is defined then the LAPACK routine xGETRF
+ is used to compute the LU decomposition.
+ !*/
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+ typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;
+
+ template <typename EXP>
+ lu_decomposition (
+ const matrix_exp<EXP> &A
+ );
+ /*!
+ requires
+ - EXP::type == lu_decomposition::type
+ - A.size() > 0
+ ensures
+ - #nr() == A.nr()
+ - #nc() == A.nc()
+ - #is_square() == (A.nr() == A.nc())
+ - computes the LU factorization of the given A matrix.
+ !*/
+
+ bool is_square (
+ ) const;
+ /*!
+ ensures
+ - if (the input A matrix was a square matrix) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_singular (
+ ) const;
+ /*!
+ requires
+ - is_square() == true
+ ensures
+ - if (the input A matrix is singular) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the input matrix
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in the input matrix
+ !*/
+
+ const matrix_type get_l (
+ ) const;
+ /*!
+ ensures
+ - returns the lower triangular L factor of the LU factorization.
+ - L.nr() == nr()
+ - L.nc() == min(nr(),nc())
+ !*/
+
+ const matrix_type get_u (
+ ) const;
+ /*!
+ ensures
+ - returns the upper triangular U factor of the LU factorization.
+ - U.nr() == min(nr(),nc())
+ - U.nc() == nc()
+ !*/
+
+ const pivot_column_vector_type& get_pivot (
+ ) const;
+ /*!
+ ensures
+ - returns the pivot permutation vector. That is,
+ if A is the input matrix then this function
+ returns a vector P such that:
+ - rowm(A,P) == get_l()*get_u()
+ - P.nr() == A.nr()
+ !*/
+
+ type det (
+ ) const;
+ /*!
+ requires
+ - is_square() == true
+ ensures
+ - computes and returns the determinant of the input
+ matrix using LU factors.
+ !*/
+
+ template <typename EXP>
+ const matrix_type solve (
+ const matrix_exp<EXP> &B
+ ) const;
+ /*!
+ requires
+ - EXP::type == lu_decomposition::type
+ - is_square() == true
+ - B.nr() == nr()
+ ensures
+ - Let A denote the input matrix to this class's constructor.
+ Then this function solves A*X == B for X and returns X.
+ - Note that if A is singular (or very close to singular) then
+ the X returned by this function won't fit A*X == B very well (if at all).
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_exp_type
+ >
+ class cholesky_decomposition
+ {
+ /*!
+ REQUIREMENTS ON matrix_exp_type
+ must be some kind of matrix expression as defined in the
+ dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object)
+ The matrix type must also contain float or double values.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute a cholesky
+ decomposition of a real valued matrix. That is, for any
+ symmetric, positive definite matrix A, it computes a lower
+ triangular matrix L such that A == L*trans(L).
+
+ If the matrix is not symmetric or positive definite, the function
+ computes only a partial decomposition. This can be tested with
+ the is_spd() flag.
+
+ If DLIB_USE_LAPACK is defined then the LAPACK routine xPOTRF
+ is used to compute the cholesky decomposition.
+ !*/
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef typename matrix_exp_type::matrix_type matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+
+ template <typename EXP>
+ cholesky_decomposition(
+ const matrix_exp<EXP>& A
+ );
+ /*!
+ requires
+ - EXP::type == cholesky_decomposition::type
+ - A.size() > 0
+ - A.nr() == A.nc()
+ (i.e. A must be a square matrix)
+ ensures
+ - if (A is symmetric positive-definite) then
+ - #is_spd() == true
+ - Constructs a lower triangular matrix L, such that L*trans(L) == A.
+ and #get_l() == L
+ - else
+ - #is_spd() == false
+ !*/
+
+ bool is_spd(
+ ) const;
+ /*!
+ ensures
+ - if (the input matrix was symmetric positive-definite) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ const matrix_type& get_l(
+ ) const;
+ /*!
+ ensures
+ - returns the lower triangular factor, L, such that L*trans(L) == A
+ (where A is the input matrix to this class's constructor)
+ - Note that if A is not symmetric positive definite or positive semi-definite
+ then the equation L*trans(L) == A won't hold.
+ !*/
+
+ template <typename EXP>
+ const matrix solve (
+ const matrix_exp<EXP>& B
+ ) const;
+ /*!
+ requires
+ - EXP::type == cholesky_decomposition::type
+ - B.nr() == get_l().nr()
+ (i.e. the number of rows in B must match the number of rows in the
+ input matrix A)
+ ensures
+ - Let A denote the input matrix to this class's constructor. Then
+ this function solves A*X = B for X and returns X.
+ - Note that if is_spd() == false or A was really close to being
+ non-SPD then the solver will fail to find an accurate solution.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_exp_type
+ >
+ class qr_decomposition
+ {
+ /*!
+ REQUIREMENTS ON matrix_exp_type
+ must be some kind of matrix expression as defined in the
+ dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object)
+ The matrix type must also contain float or double values.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute a classical
+ QR decomposition of an m-by-n real valued matrix A with m >= n.
+
+ The QR decomposition is an m-by-n orthogonal matrix Q and an
+ n-by-n upper triangular matrix R so that A == Q*R. The QR decomposition
+ always exists, even if the matrix does not have full rank, so the
+ constructor will never fail. The primary use of the QR decomposition
+ is in the least squares solution of non-square systems of simultaneous
+ linear equations. This will fail if is_full_rank() returns false or
+ A is very nearly not full rank.
+
+ The Q and R factors can be retrieved via the get_q() and get_r()
+ methods. Furthermore, a solve() method is provided to find the
+ least squares solution of Ax=b using the QR factors.
+
+ If DLIB_USE_LAPACK is #defined then the xGEQRF routine
+ from LAPACK is used to compute the QR decomposition.
+ !*/
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+
+ template <typename EXP>
+ qr_decomposition(
+ const matrix_exp<EXP>& A
+ );
+ /*!
+ requires
+ - EXP::type == qr_decomposition::type
+ - A.nr() >= A.nc()
+ - A.size() > 0
+ ensures
+ - #nr() == A.nr()
+ - #nc() == A.nc()
+ - computes the QR decomposition of the given A matrix.
+ !*/
+
+ bool is_full_rank(
+ ) const;
+ /*!
+ ensures
+ - if (the input A matrix had full rank) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the input matrix
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in the input matrix
+ !*/
+
+ const matrix_type get_r (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R is the upper triangular factor, R, of the QR factorization
+ - get_q()*R == input matrix A
+ - R.nr() == nc()
+ - R.nc() == nc()
+ !*/
+
+ const matrix_type get_q (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix Q such that:
+ - Q is the economy-sized orthogonal factor Q from the QR
+ factorization.
+ - trans(Q)*Q == identity matrix
+ - Q*get_r() == input matrix A
+ - Q.nr() == nr()
+ - Q.nc() == nc()
+ !*/
+
+ void get_q (
+ matrix_type& Q
+ ) const;
+ /*!
+ ensures
+ - #Q == get_q()
+ - This function exists to allow a user to get the Q matrix without the
+ overhead of returning a matrix by value.
+ !*/
+
+ template <typename EXP>
+ const matrix_type solve (
+ const matrix_exp<EXP>& B
+ ) const;
+ /*!
+ requires
+ - EXP::type == qr_decomposition::type
+ - B.nr() == nr()
+ ensures
+ - Let A denote the input matrix to this class's constructor.
+ Then this function finds the least squares solution to the equation A*X = B
+ and returns X. X has the following properties:
+ - X is the matrix that minimizes the two norm of A*X-B. That is, it
+ minimizes sum(squared(A*X - B)).
+ - X.nr() == nc()
+ - X.nc() == B.nc()
+ - Note that this function will fail to output a good solution if is_full_rank() == false
+ or the A matrix is close to not being full rank.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_exp_type
+ >
+ class eigenvalue_decomposition
+ {
+ /*!
+ REQUIREMENTS ON matrix_exp_type
+ must be some kind of matrix expression as defined in the
+ dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object)
+ The matrix type must also contain float or double values.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute an eigenvalue
+ decomposition of a real valued matrix. So it gives
+ you the set of eigenvalues and eigenvectors for a matrix.
+
+ Let A denote the input matrix to this object's constructor. Then
+ what this object does is it finds two matrices, D and V, such that
+ - A*V == V*D
+ Where V is a square matrix that contains all the eigenvectors
+ of the A matrix (each column of V is an eigenvector) and
+ D is a diagonal matrix containing the eigenvalues of A.
+
+
+ It is important to note that if A is symmetric or non-symmetric you
+ get somewhat different results. If A is a symmetric matrix (i.e. A == trans(A))
+ then:
+ - All the eigenvalues and eigenvectors of A are real numbers.
+ - Because of this there isn't really any point in using the
+ part of this class's interface that returns complex matrices.
+ All you need are the get_real_eigenvalues() and
+ get_pseudo_v() functions.
+ - V*trans(V) should be equal to the identity matrix. That is, all the
+ eigenvectors in V should be orthonormal.
+ - So A == V*D*trans(V)
+ - If DLIB_USE_LAPACK is #defined then this object uses the xSYEVR LAPACK
+ routine.
+
+ On the other hand, if A is not symmetric then:
+ - Some of the eigenvalues and eigenvectors might be complex numbers.
+ - An eigenvalue is complex if and only if its corresponding eigenvector
+ is complex. So you can check for this case by just checking
+ get_imag_eigenvalues() to see if any values are non-zero. You don't
+ have to check the V matrix as well.
+ - V*trans(V) won't be equal to the identity matrix but it is usually
+ invertible. So A == V*D*inv(V) is usually a valid statement but
+ A == V*D*trans(V) won't be.
+ - If DLIB_USE_LAPACK is #defined then this object uses the xGEEV LAPACK
+ routine.
+ !*/
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef typename matrix_exp_type::matrix_type matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+
+ typedef matrix<std::complex<type>,0,0,mem_manager_type,layout_type> complex_matrix_type;
+ typedef matrix<std::complex<type>,NR,1,mem_manager_type,layout_type> complex_column_vector_type;
+
+
+ template <typename EXP>
+ eigenvalue_decomposition(
+ const matrix_exp<EXP>& A
+ );
+ /*!
+ requires
+ - A.nr() == A.nc()
+ - A.size() > 0
+ - EXP::type == eigenvalue_decomposition::type
+ ensures
+ - #dim() == A.nr()
+ - computes the eigenvalue decomposition of A.
+ - #get_eigenvalues() == the eigenvalues of A
+ - #get_v() == all the eigenvectors of A
+ !*/
+
+ template <typename EXP>
+ eigenvalue_decomposition(
+ const matrix_op<op_make_symmetric<EXP> >& A
+ );
+ /*!
+ requires
+ - A.nr() == A.nc()
+ - A.size() > 0
+ - EXP::type == eigenvalue_decomposition::type
+ ensures
+ - #dim() == A.nr()
+ - computes the eigenvalue decomposition of the symmetric matrix A. Does so
+ using a method optimized for symmetric matrices.
+ - #get_eigenvalues() == the eigenvalues of A
+ - #get_v() == all the eigenvectors of A
+ - moreover, since A is symmetric there won't be any imaginary eigenvalues. So
+ we will have:
+ - #get_imag_eigenvalues() == 0
+ - #get_real_eigenvalues() == the eigenvalues of A
+ - #get_pseudo_v() == all the eigenvectors of A
+ - diagm(#get_real_eigenvalues()) == #get_pseudo_d()
+
+ Note that the symmetric matrix operator is created by the
+ dlib::make_symmetric() function. This function simply reflects
+ the lower triangular part of a square matrix into the upper triangular
+ part to create a symmetric matrix. It can also be used to denote that a
+ matrix is already symmetric using the C++ type system.
+ !*/
+
+ long dim (
+ ) const;
+ /*!
+ ensures
+ - dim() == the number of rows/columns in the input matrix A
+ !*/
+
+ const complex_column_vector_type get_eigenvalues (
+ ) const;
+ /*!
+ ensures
+ - returns diag(get_d()). That is, returns a
+ vector that contains the eigenvalues of the input
+ matrix.
+ - the returned vector has dim() rows
+ - the eigenvalues are not sorted in any particular way
+ !*/
+
+ const column_vector_type& get_real_eigenvalues (
+ ) const;
+ /*!
+ ensures
+ - returns the real parts of the eigenvalues. That is,
+ returns real(get_eigenvalues())
+ - the returned vector has dim() rows
+ - the eigenvalues are not sorted in any particular way
+ !*/
+
+ const column_vector_type& get_imag_eigenvalues (
+ ) const;
+ /*!
+ ensures
+ - returns the imaginary parts of the eigenvalues. That is,
+ returns imag(get_eigenvalues())
+ - the returned vector has dim() rows
+ - the eigenvalues are not sorted in any particular way
+ !*/
+
+ const complex_matrix_type get_v (
+ ) const;
+ /*!
+ ensures
+ - returns the eigenvector matrix V that is
+ dim() rows by dim() columns
+ - Each column in V is one of the eigenvectors of the input
+ matrix
+ !*/
+
+ const complex_matrix_type get_d (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix D such that:
+ - D.nr() == dim()
+ - D.nc() == dim()
+ - diag(D) == get_eigenvalues()
+ (i.e. the diagonal of D contains all the eigenvalues in the input matrix)
+ - all off diagonal elements of D are set to 0
+ !*/
+
+ const matrix_type& get_pseudo_v (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix that is dim() rows by dim() columns
+ - Let A denote the input matrix given to this object's constructor.
+ - if (A has any imaginary eigenvalues) then
+ - returns the pseudo-eigenvector matrix V
+ - The matrix V returned by this function is structured such that:
+ - A*V == V*get_pseudo_d()
+ - else
+ - returns the eigenvector matrix V with A's eigenvectors as
+ the columns of V
+ - A*V == V*diagm(get_real_eigenvalues())
+ !*/
+
+ const matrix_type get_pseudo_d (
+ ) const;
+ /*!
+ ensures
+ - The returned matrix is dim() rows by dim() columns
+ - Computes and returns the block diagonal eigenvalue matrix.
+ If the original matrix A is not symmetric, then the eigenvalue
+ matrix D is block diagonal with the real eigenvalues in 1-by-1
+ blocks and any complex eigenvalues,
+ a + i*b, in 2-by-2 blocks, (a, b; -b, a). That is, if the complex
+ eigenvalues look like
+
+ u + iv . . . . .
+ . u - iv . . . .
+ . . a + ib . . .
+ . . . a - ib . .
+ . . . . x .
+ . . . . . y
+
+ Then D looks like
+
+ u v . . . .
+ -v u . . . .
+ . . a b . .
+ . . -b a . .
+ . . . . x .
+ . . . . . y
+
+ This keeps V (The V you get from get_pseudo_v()) a real matrix in both
+ symmetric and non-symmetric cases, and A*V = V*D.
+ - the eigenvalues are not sorted in any particular way
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_LA_FUNCTS_ABSTRACT_
+
diff --git a/ml/dlib/dlib/matrix/matrix_lu.h b/ml/dlib/dlib/matrix/matrix_lu.h
new file mode 100644
index 000000000..3e49cd653
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_lu.h
@@ -0,0 +1,361 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+// This code was adapted from code from the JAMA part of NIST's TNT library.
+// See: http://math.nist.gov/tnt/
+#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H
+#define DLIB_MATRIX_LU_DECOMPOSITION_H
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+#include "matrix_trsm.h"
+#include <algorithm>
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/getrf.h"
+#endif
+
+
+namespace dlib
+{
+
+ template <
+ typename matrix_exp_type
+ >
+ class lu_decomposition
+ {
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+ typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;
+
+ // You have supplied an invalid type of matrix_exp_type. You have
+ // to use this object with matrices that contain float or double type data.
+ COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
+ is_same_type<double, type>::value ));
+
+ template <typename EXP>
+ lu_decomposition (
+ const matrix_exp<EXP> &A
+ );
+
+ bool is_square (
+ ) const;
+
+ bool is_singular (
+ ) const;
+
+ long nr(
+ ) const;
+
+ long nc(
+ ) const;
+
+ const matrix_type get_l (
+ ) const;
+
+ const matrix_type get_u (
+ ) const;
+
+ const pivot_column_vector_type& get_pivot (
+ ) const;
+
+ type det (
+ ) const;
+
+ template <typename EXP>
+ const matrix_type solve (
+ const matrix_exp<EXP> &B
+ ) const;
+
+ private:
+
+ /* Array for internal storage of decomposition. */
+ matrix<type,0,0,mem_manager_type,column_major_layout> LU;
+ long m, n, pivsign;
+ pivot_column_vector_type piv;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Public member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ lu_decomposition<matrix_exp_type>::
+ lu_decomposition (
+ const matrix_exp<EXP>& A
+ ) :
+ LU(A),
+ m(A.nr()),
+ n(A.nc())
+ {
+ using namespace std;
+ using std::abs;
+
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.size() > 0,
+ "\tlu_decomposition::lu_decomposition(A)"
+ << "\n\tInvalid inputs were given to this function"
+ << "\n\tA.size(): " << A.size()
+ << "\n\tthis: " << this
+ );
+
+#ifdef DLIB_USE_LAPACK
+ matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
+ lapack::getrf(LU, piv_temp);
+
+ pivsign = 1;
+
+ // Turn the piv_temp vector into a more useful form. This way we will have the identity
+ // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat
+ // different.
+ piv = trans(range(0,m-1));
+ for (long i = 0; i < piv_temp.size(); ++i)
+ {
+ // -1 because FORTRAN is indexed starting with 1 instead of 0
+ if (piv(piv_temp(i)-1) != piv(i))
+ {
+ std::swap(piv(i), piv(piv_temp(i)-1));
+ pivsign = -pivsign;
+ }
+ }
+
+#else
+
+ // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
+
+
+ piv = trans(range(0,m-1));
+ pivsign = 1;
+
+ column_vector_type LUcolj(m);
+
+ // Outer loop.
+ for (long j = 0; j < n; j++)
+ {
+
+ // Make a copy of the j-th column to localize references.
+ LUcolj = colm(LU,j);
+
+ // Apply previous transformations.
+ for (long i = 0; i < m; i++)
+ {
+ // Most of the time is spent in the following dot product.
+ const long kmax = std::min(i,j);
+ type s;
+ if (kmax > 0)
+ s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax);
+ else
+ s = 0;
+
+ LU(i,j) = LUcolj(i) -= s;
+ }
+
+ // Find pivot and exchange if necessary.
+ long p = j;
+ for (long i = j+1; i < m; i++)
+ {
+ if (abs(LUcolj(i)) > abs(LUcolj(p)))
+ {
+ p = i;
+ }
+ }
+ if (p != j)
+ {
+ long k=0;
+ for (k = 0; k < n; k++)
+ {
+ type t = LU(p,k);
+ LU(p,k) = LU(j,k);
+ LU(j,k) = t;
+ }
+ k = piv(p);
+ piv(p) = piv(j);
+ piv(j) = k;
+ pivsign = -pivsign;
+ }
+
+ // Compute multipliers.
+ if ((j < m) && (LU(j,j) != 0.0))
+ {
+ for (long i = j+1; i < m; i++)
+ {
+ LU(i,j) /= LU(j,j);
+ }
+ }
+ }
+
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool lu_decomposition<matrix_exp_type>::
+ is_square (
+ ) const
+ {
+ return m == n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long lu_decomposition<matrix_exp_type>::
+ nr (
+ ) const
+ {
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long lu_decomposition<matrix_exp_type>::
+ nc (
+ ) const
+ {
+ return n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool lu_decomposition<matrix_exp_type>::
+ is_singular (
+ ) const
+ {
+ /* Is the matrix singular?
+ if upper triangular factor U (and hence A) is singular, false otherwise.
+ */
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true,
+ "\tbool lu_decomposition::is_singular()"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tthis: " << this
+ );
+
+ type max_val, min_val;
+ find_min_and_max (abs(diag(LU)), min_val, max_val);
+ type eps = max_val;
+ if (eps != 0)
+ eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10;
+ else
+ eps = 1; // there is no max so just use 1
+
+ return min_val < eps;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ get_l (
+ ) const
+ {
+ if (LU.nr() >= LU.nc())
+ return lowerm(LU,1.0);
+ else
+ return lowerm(subm(LU,0,0,m,m), 1.0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ get_u (
+ ) const
+ {
+ if (LU.nr() >= LU.nc())
+ return upperm(subm(LU,0,0,n,n));
+ else
+ return upperm(LU);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>::
+ get_pivot (
+ ) const
+ {
+ return piv;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>::
+ det (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true,
+ "\ttype lu_decomposition::det()"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tthis: " << this
+ );
+
+ // Check if it is singular and if it is just return 0.
+ // We want to do this because a prod() operation can easily
+ // overcome a single diagonal element that is effectively 0 when
+ // LU is a big enough matrix.
+ if (is_singular())
+ return 0;
+
+ return prod(diag(LU))*static_cast<type>(pivsign);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ solve (
+ const matrix_exp<EXP> &B
+ ) const
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true && B.nr() == nr(),
+ "\ttype lu_decomposition::solve()"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tis_square(): " << (is_square()? "true":"false" )
+ << "\n\tB.nr(): " << B.nr()
+ << "\n\tnr(): " << nr()
+ << "\n\tthis: " << this
+ );
+
+ // Copy right hand side with pivoting
+ matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv));
+
+ using namespace blas_bindings;
+ // Solve L*Y = B(piv,:)
+ triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X);
+ // Solve U*X = Y;
+ triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X);
+ return X;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIX_LU_DECOMPOSITION_H
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_mat.h b/ml/dlib/dlib/matrix/matrix_mat.h
new file mode 100644
index 000000000..803d7d999
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_mat.h
@@ -0,0 +1,733 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_MAT_Hh_
+#define DLIB_MATRIx_MAT_Hh_
+
+#include "matrix_mat_abstract.h"
+#include "../stl_checked.h"
+#include <vector>
+#include "matrix_op.h"
+#include "../array2d.h"
+#include "../array.h"
+#include "../image_processing/generic_image.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_exp<EXP>& mat (
+ const matrix_exp<EXP>& m
+ )
+ {
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type, typename pixel_type>
+ struct op_image_to_mat : does_not_alias
+ {
+ op_image_to_mat( const image_type& img) : imgview(img){}
+
+ const_image_view<image_type> imgview;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef pixel_type type;
+ typedef const pixel_type& const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return imgview[r][c]; }
+
+ long nr () const { return imgview.nr(); }
+ long nc () const { return imgview.nc(); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ > // The reason we disable this if it is a matrix is because this matrix_op claims
+ // to not alias any matrix. But obviously that would be a problem if we let it
+ // take a matrix.
+ const typename disable_if<is_matrix<image_type>,matrix_op<op_image_to_mat<image_type, typename image_traits<image_type>::pixel_type> > >::type mat (
+ const image_type& img
+ )
+ {
+ typedef op_image_to_mat<image_type, typename image_traits<image_type>::pixel_type> op;
+ return matrix_op<op>(op(img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ struct op_image_view_to_mat : does_not_alias
+ {
+ op_image_view_to_mat( const image_view<image_type>& img) : imgview(img){}
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ const image_view<image_type>& imgview;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef pixel_type type;
+ typedef const pixel_type& const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return imgview[r][c]; }
+
+ long nr () const { return imgview.nr(); }
+ long nc () const { return imgview.nc(); }
+ };
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_image_view_to_mat<image_type> > mat (
+ const image_view<image_type>& img
+ )
+ {
+ typedef op_image_view_to_mat<image_type> op;
+ return matrix_op<op>(op(img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type>
+ struct op_const_image_view_to_mat : does_not_alias
+ {
+ op_const_image_view_to_mat( const const_image_view<image_type>& img) : imgview(img){}
+
+ typedef typename image_traits<image_type>::pixel_type pixel_type;
+
+ const const_image_view<image_type>& imgview;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef pixel_type type;
+ typedef const pixel_type& const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return imgview[r][c]; }
+
+ long nr () const { return imgview.nr(); }
+ long nc () const { return imgview.nc(); }
+ };
+
+ template <
+ typename image_type
+ >
+ const matrix_op<op_const_image_view_to_mat<image_type> > mat (
+ const const_image_view<image_type>& img
+ )
+ {
+ typedef op_const_image_view_to_mat<image_type> op;
+ return matrix_op<op>(op(img));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_array_to_mat : does_not_alias
+ {
+ op_array_to_mat( const T& vect_) : vect(vect_){}
+
+ const T& vect;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 1;
+ typedef typename T::type type;
+ typedef const typename T::type& const_ret_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long ) const { return vect[r]; }
+
+ long nr () const { return vect.size(); }
+ long nc () const { return 1; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename MM
+ >
+ const matrix_op<op_array_to_mat<array<T,MM> > > mat (
+ const array<T,MM>& m
+ )
+ {
+ typedef op_array_to_mat<array<T,MM> > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename U>
+ struct not_bool { typedef U type; };
+ template <>
+ struct not_bool<const bool&> { typedef bool type; };
+ }
+
+ template <typename T>
+ struct op_std_vect_to_mat : does_not_alias
+ {
+ op_std_vect_to_mat( const T& vect_) : vect(vect_){}
+
+ const T& vect;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 1;
+ typedef typename T::value_type type;
+ // Since std::vector returns a proxy for bool types we need to make sure we don't
+ // return an element by reference if it is a bool type.
+ typedef typename impl::not_bool<const typename T::value_type&>::type const_ret_type;
+
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long ) const { return vect[r]; }
+
+ long nr () const { return vect.size(); }
+ long nc () const { return 1; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > mat (
+ const std::vector<value_type,alloc>& vector
+ )
+ {
+ typedef op_std_vect_to_mat<std::vector<value_type,alloc> > op;
+ return matrix_op<op>(op(vector));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > mat (
+ const std_vector_c<value_type,alloc>& vector
+ )
+ {
+ typedef op_std_vect_to_mat<std_vector_c<value_type,alloc> > op;
+ return matrix_op<op>(op(vector));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_pointer_to_mat;
+
+ template <typename T>
+ struct op_pointer_to_col_vect
+ {
+ op_pointer_to_col_vect(
+ const T* ptr_,
+ const long size_
+ ) : ptr(ptr_), size(size_){}
+
+ const T* ptr;
+ const long size;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 1;
+ typedef T type;
+ typedef const T& const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long ) const { return ptr[r]; }
+
+ long nr () const { return size; }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& ) const { return false; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& ) const { return false; }
+
+ template <long num_rows, long num_cols, typename mem_manager, typename layout>
+ bool aliases (
+ const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
+ ) const
+ {
+ if (item.size() == 0)
+ return false;
+ else
+ return (ptr == &item(0,0));
+ }
+
+ inline bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
+ ) const;
+
+ bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
+ ) const
+ {
+ return item.ref().op.ptr == ptr;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_pointer_to_col_vect<T> > mat (
+ const T* ptr,
+ long nr
+ )
+ {
+ DLIB_ASSERT(nr >= 0 ,
+ "\tconst matrix_exp mat(ptr, nr)"
+ << "\n\t nr must be >= 0"
+ << "\n\t nr: " << nr
+ );
+ typedef op_pointer_to_col_vect<T> op;
+ return matrix_op<op>(op(ptr, nr));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_pointer_to_mat
+ {
+ op_pointer_to_mat(
+ const T* ptr_,
+ const long nr_,
+ const long nc_
+ ) : ptr(ptr_), rows(nr_), cols(nc_), stride(nc_){}
+
+ op_pointer_to_mat(
+ const T* ptr_,
+ const long nr_,
+ const long nc_,
+ const long stride_
+ ) : ptr(ptr_), rows(nr_), cols(nc_), stride(stride_){}
+
+ const T* ptr;
+ const long rows;
+ const long cols;
+ const long stride;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef T type;
+ typedef const T& const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c) const { return ptr[r*stride + c]; }
+
+ long nr () const { return rows; }
+ long nc () const { return cols; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& ) const { return false; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& ) const { return false; }
+
+ template <long num_rows, long num_cols, typename mem_manager, typename layout>
+ bool aliases (
+ const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
+ ) const
+ {
+ if (item.size() == 0)
+ return false;
+ else
+ return (ptr == &item(0,0));
+ }
+
+ bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
+ ) const
+ {
+ return item.ref().op.ptr == ptr;
+ }
+
+ bool aliases (
+ const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
+ ) const
+ {
+ return item.ref().op.ptr == ptr;
+ }
+ };
+
+ template <typename T>
+ bool op_pointer_to_col_vect<T>::
+ aliases (
+ const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
+ ) const
+ {
+ return item.ref().op.ptr == ptr;
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ bool matrix<T,NR,NC,MM,L>::aliases (
+ const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
+ ) const
+ {
+ if (size() != 0)
+ return item.ref().op.ptr == &data(0,0);
+ else
+ return false;
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ bool matrix<T,NR,NC,MM,L>::aliases (
+ const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
+ ) const
+ {
+ if (size() != 0)
+ return item.ref().op.ptr == &data(0,0);
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_pointer_to_mat<T> > mat (
+ const T* ptr,
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0 ,
+ "\tconst matrix_exp mat(ptr, nr, nc)"
+ << "\n\t nr and nc must be >= 0"
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ );
+ typedef op_pointer_to_mat<T> op;
+ return matrix_op<op>(op(ptr,nr,nc));
+ }
+
+ template <
+ typename T
+ >
+ const matrix_op<op_pointer_to_mat<T> > mat (
+ const T* ptr,
+ long nr,
+ long nc,
+ long stride
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0 && stride > 0 ,
+ "\tconst matrix_exp mat(ptr, nr, nc, stride)"
+ << "\n\t nr and nc must be >= 0 while stride > 0"
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ << "\n\t stride: " << stride
+ );
+ typedef op_pointer_to_mat<T> op;
+ return matrix_op<op>(op(ptr,nr,nc,stride));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+namespace arma
+{
+ template <typename T> class Mat;
+}
+namespace dlib
+{
+ template <typename T>
+ struct op_arma_Mat_to_mat : does_not_alias
+ {
+ op_arma_Mat_to_mat( const T& array_) : array(array_){}
+
+ const T& array;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef typename T::elem_type type;
+ typedef typename T::elem_type const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return array(r,c); }
+
+ long nr () const { return array.n_rows; }
+ long nc () const { return array.n_cols; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_arma_Mat_to_mat< ::arma::Mat<T> > > mat (
+ const ::arma::Mat<T>& array
+ )
+ {
+ typedef op_arma_Mat_to_mat< ::arma::Mat<T> > op;
+ return matrix_op<op>(op(array));
+ }
+}
+
+namespace Eigen
+{
+ template<typename _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols>
+ class Matrix;
+}
+
+namespace dlib
+{
+ template <typename T, int _Rows, int _Cols>
+ struct op_eigen_Matrix_to_mat : does_not_alias
+ {
+ op_eigen_Matrix_to_mat( const T& array_) : m(array_){}
+
+ const T& m;
+
+ const static long cost = 1;
+ const static long NR = (_Rows > 0) ? _Rows : 0;
+ const static long NC = (_Cols > 0) ? _Cols : 0;
+ typedef typename T::Scalar type;
+ typedef typename T::Scalar const_ret_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return m(r,c); }
+
+ long nr () const { return m.rows(); }
+ long nc () const { return m.cols(); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols
+ >
+ const matrix_op<op_eigen_Matrix_to_mat< ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>,_Rows,_Cols > > mat (
+ const ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>& m
+ )
+ {
+ typedef op_eigen_Matrix_to_mat< ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>,_Rows,_Cols > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// DEPRECATED FUNCTIONS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+// vector_to_matrix(), array_to_matrix(), pointer_to_matrix(), and
+// pointer_to_column_vector() have been deprecated in favor of the more uniform mat()
+// function. But they are here for backwards compatibility.
+
+ template <
+ typename vector_type
+ >
+ const typename disable_if<is_matrix<vector_type>, matrix_op<op_array_to_mat<vector_type> > >::type
+ vector_to_matrix (
+ const vector_type& vector
+ )
+ {
+ typedef op_array_to_mat<vector_type> op;
+ return matrix_op<op>(op(vector));
+ }
+
+ template <
+ typename vector_type
+ >
+ const typename enable_if<is_matrix<vector_type>,vector_type>::type& vector_to_matrix (
+ const vector_type& vector
+ )
+ /*!
+ This overload catches the case where the argument to this function is
+ already a matrix.
+ !*/
+ {
+ return vector;
+ }
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > vector_to_matrix (
+ const std::vector<value_type,alloc>& vector
+ )
+ {
+ typedef op_std_vect_to_mat<std::vector<value_type,alloc> > op;
+ return matrix_op<op>(op(vector));
+ }
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > vector_to_matrix (
+ const std_vector_c<value_type,alloc>& vector
+ )
+ {
+ typedef op_std_vect_to_mat<std_vector_c<value_type,alloc> > op;
+ return matrix_op<op>(op(vector));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type
+ >
+ const typename enable_if<is_matrix<array_type>,array_type>::type&
+ array_to_matrix (
+ const array_type& array
+ )
+ {
+ return array;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct op_array2d_to_mat : does_not_alias
+ {
+ op_array2d_to_mat( const T& array_) : array(array_){}
+
+ const T& array;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef typename T::type type;
+ typedef const typename T::type& const_ret_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const { return array[r][c]; }
+
+ long nr () const { return array.nr(); }
+ long nc () const { return array.nc(); }
+ };
+
+ // Note that we have this version of mat() because it's slightly faster executing
+ // than the general one that handles any generic image. This is because it avoids
+ // calling image_data() which for array2d involves a single if statement but this
+ // version here has no if statement in its construction.
+ template < typename T, typename MM >
+ const matrix_op<op_array2d_to_mat<array2d<T,MM> > > mat (
+ const array2d<T,MM>& array
+ )
+ {
+ typedef op_array2d_to_mat<array2d<T,MM> > op;
+ return matrix_op<op>(op(array));
+ }
+
+ template <
+ typename array_type
+ >
+ const typename disable_if<is_matrix<array_type>,matrix_op<op_array2d_to_mat<array_type> > >::type
+ array_to_matrix (
+ const array_type& array
+ )
+ {
+ typedef op_array2d_to_mat<array_type> op;
+ return matrix_op<op>(op(array));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_pointer_to_mat<T> > pointer_to_matrix (
+ const T* ptr,
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(nr > 0 && nc > 0 ,
+ "\tconst matrix_exp pointer_to_matrix(ptr, nr, nc)"
+ << "\n\t nr and nc must be bigger than 0"
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ );
+ typedef op_pointer_to_mat<T> op;
+ return matrix_op<op>(op(ptr,nr,nc));
+ }
+
+ template <
+ typename T
+ >
+ const matrix_op<op_pointer_to_col_vect<T> > pointer_to_column_vector (
+ const T* ptr,
+ long nr
+ )
+ {
+ DLIB_ASSERT(nr > 0 ,
+ "\tconst matrix_exp pointer_to_column_vector(ptr, nr)"
+ << "\n\t nr must be bigger than 0"
+ << "\n\t nr: " << nr
+ );
+ typedef op_pointer_to_col_vect<T> op;
+ return matrix_op<op>(op(ptr, nr));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ inline matrix<double,1,1> mat (
+ double value
+ )
+ {
+ matrix<double,1,1> temp;
+ temp(0) = value;
+ return temp;
+ }
+
+ inline matrix<float,1,1> mat (
+ float value
+ )
+ {
+ matrix<float,1,1> temp;
+ temp(0) = value;
+ return temp;
+ }
+
+ inline matrix<long double,1,1> mat (
+ long double value
+ )
+ {
+ matrix<long double,1,1> temp;
+ temp(0) = value;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_MAT_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_mat_abstract.h b/ml/dlib/dlib/matrix/matrix_mat_abstract.h
new file mode 100644
index 000000000..7026f60a1
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_mat_abstract.h
@@ -0,0 +1,243 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_MAT_ABSTRACT_Hh_
+#ifdef DLIB_MATRIx_MAT_ABSTRACT_Hh_
+
+#include "matrix_abstract.h"
+#inclue <vector>
+#include "../array/array_kernel_abstract.h"
+#include "../array2d/array2d_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_exp<EXP>& mat (
+ const matrix_exp<EXP>& m
+ );
+ /*!
+ ensures
+ - returns m
+ (i.e. this function just returns the input matrix without any modifications)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ const matrix_exp mat (
+ const image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or image_type is a image_view or
+ const_image_view object.
+ ensures
+ - This function converts any kind of generic image object into a dlib::matrix
+ expression. Therefore, it is capable of converting objects like dlib::array2d
+ of dlib::cv_image.
+ - returns a matrix R such that:
+ - R.nr() == array.nr()
+ - R.nc() == array.nc()
+ - for all valid r and c:
+ R(r, c) == array[r][c]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename MM
+ >
+ const matrix_exp mat (
+ const array<T,MM>& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == m.size()
+ - for all valid r:
+ R(r) == m[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_exp mat (
+ const std::vector<value_type,alloc>& vector
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == vector.size()
+ - for all valid r:
+ R(r) == vector[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename value_type,
+ typename alloc
+ >
+ const matrix_exp mat (
+ const std_vector_c<value_type,alloc>& vector
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == vector.size()
+ - for all valid r:
+ R(r) == vector[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const T* ptr,
+ long nr
+ );
+ /*!
+ requires
+ - nr >= 0
+ - ptr == a pointer to at least nr T objects (or the NULL pointer if nr==0)
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == nr
+ - m.nc() == 1
+ - for all valid i:
+ M(i) == ptr[i]
+ - Note that the returned matrix doesn't take "ownership" of
+ the pointer and thus will not delete or free it.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const T* ptr,
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ - ptr == a pointer to at least nr*nc T objects (or the NULL pointer if nr*nc==0)
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == nr
+ - m.nc() == nc
+ - for all valid r and c:
+ M(r,c) == ptr[r*nc + c]
+ (i.e. the pointer is interpreted as a matrix laid out in memory
+ in row major order)
+ - Note that the returned matrix doesn't take "ownership" of
+ the pointer and thus will not delete or free it.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const T* ptr,
+ long nr,
+ long nc,
+ long stride
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ - stride > 0
+ - ptr == a pointer to at least (nr-1)*stride+nc T objects (or the NULL pointer if nr*nc==0)
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == nr
+ - m.nc() == nc
+ - for all valid r and c:
+ M(r,c) == ptr[r*stride + c]
+ (i.e. the pointer is interpreted as a matrix laid out in memory
+ in row major order, with a row stride of the given stride amount.)
+ - Note that the returned matrix doesn't take "ownership" of
+ the pointer and thus will not delete or free it.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const ::arma::Mat<T>& m
+ );
+ /*!
+ ensures
+ - Converts a matrix from the Armadillo library into a dlib matrix.
+ - returns a matrix R such that:
+ - R.nr() == m.n_rows
+ - R.nc() == m.n_cols
+ - for all valid r:
+ R(r,c) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename _Scalar,
+ int _Rows,
+ int _Cols,
+ int _Options,
+ int _MaxRows,
+ int _MaxCols
+ >
+ const matrix_exp mat (
+ const ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>& m
+ );
+ /*!
+ ensures
+ - Converts a matrix from the Eigen library into a dlib matrix.
+ - returns a matrix R such that:
+ - R.nr() == m.rows()
+ - R.nc() == m.cols()
+ - for all valid r:
+ R(r,c) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ matrix<double,1,1> mat (double value);
+ matrix<float,1,1> mat (float value);
+ matrix<long double,1,1> mat (long double value);
+ /*!
+ ensures
+ - Converts a scalar into a matrix containing just that scalar and returns the
+ results.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_MAT_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_math_functions.h b/ml/dlib/dlib/matrix/matrix_math_functions.h
new file mode 100644
index 000000000..d1db3ed14
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_math_functions.h
@@ -0,0 +1,448 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_MATH_FUNCTIONS
+#define DLIB_MATRIx_MATH_FUNCTIONS
+
+#include "matrix_math_functions_abstract.h"
+#include "matrix_op.h"
+#include "matrix_utilities.h"
+#include "matrix.h"
+#include "../algs.h"
+#include <cmath>
+#include <complex>
+#include <limits>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ DLIB_DEFINE_FUNCTION_M(op_sqrt, sqrt, std::sqrt ,7);
+ DLIB_DEFINE_FUNCTION_M(op_log, log, std::log ,7);
+ DLIB_DEFINE_FUNCTION_M(op_log10, log10, std::log10 ,7);
+ DLIB_DEFINE_FUNCTION_M(op_exp, exp, std::exp ,7);
+
+ DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,2);
+
+ DLIB_DEFINE_FUNCTION_M(op_ceil, ceil, std::ceil ,7);
+ DLIB_DEFINE_FUNCTION_M(op_floor, floor, std::floor ,7);
+
+ DLIB_DEFINE_FUNCTION_M(op_sin, sin, std::sin ,7);
+ DLIB_DEFINE_FUNCTION_M(op_cos, cos, std::cos ,7);
+ DLIB_DEFINE_FUNCTION_M(op_tan, tan, std::tan ,7);
+ DLIB_DEFINE_FUNCTION_M(op_sinh, sinh, std::sinh ,7);
+ DLIB_DEFINE_FUNCTION_M(op_cosh, cosh, std::cosh ,7);
+ DLIB_DEFINE_FUNCTION_M(op_tanh, tanh, std::tanh ,7);
+ DLIB_DEFINE_FUNCTION_M(op_asin, asin, std::asin ,7);
+ DLIB_DEFINE_FUNCTION_M(op_acos, acos, std::acos ,7);
+ DLIB_DEFINE_FUNCTION_M(op_atan, atan, std::atan ,7);
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type>
+ inline type sigmoid (const type& val)
+ {
+ return static_cast<type>(1/(1 + std::exp(-val)));
+ }
+
+ template <typename type, typename S>
+ inline type round_zeros_eps (const type& val, const S& eps)
+ {
+ // you can only round matrices that contain built in scalar types like double, long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<type>::value);
+
+ if (val >= eps || val <= -eps)
+ return val;
+ else
+ return 0;
+ }
+
+ template <typename type>
+ inline type round_zeros (const type& val)
+ {
+ // you can only round matrices that contain built in scalar types like double, long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<type>::value);
+
+ const type eps = 10*std::numeric_limits<type>::epsilon();
+ if (val >= eps || val <= -eps)
+ return val;
+ else
+ return 0;
+ }
+
+ template <typename type>
+ inline type squared (const type& val)
+ {
+ return val*val;
+ }
+
+ template <typename type>
+ inline type sign (const type& val)
+ {
+ if (val >= 0)
+ return +1;
+ else
+ return -1;
+ }
+
+ template <typename type>
+ type cubed (const type& val)
+ {
+ return val*val*val;
+ }
+
+ template <typename type, typename S>
+ inline type pow1 (const type& val, const S& s)
+ {
+ // you can only call pow() on matrices that contain floats, doubles or long doubles.
+ COMPILE_TIME_ASSERT((
+ is_same_type<type,float>::value == true ||
+ is_same_type<type,double>::value == true ||
+ is_same_type<type,long double>::value == true
+ ));
+
+ return std::pow(val,static_cast<type>(s));
+ }
+
+ template <typename type, typename S>
+ inline type pow2 (const S& s, const type& val)
+ {
+ // you can only call pow() on matrices that contain floats, doubles or long doubles.
+ COMPILE_TIME_ASSERT((
+ is_same_type<type,float>::value == true ||
+ is_same_type<type,double>::value == true ||
+ is_same_type<type,long double>::value == true
+ ));
+
+ return std::pow(static_cast<type>(s),val);
+ }
+
+ template <typename type>
+ inline type reciprocal (const type& val)
+ {
+ // you can only compute reciprocal matrices that contain floats, doubles or long doubles.
+ COMPILE_TIME_ASSERT((
+ is_same_type<type,float>::value == true ||
+ is_same_type<type,double>::value == true ||
+ is_same_type<type,long double>::value == true ||
+ is_same_type<type,std::complex<float> >::value == true ||
+ is_same_type<type,std::complex<double> >::value == true ||
+ is_same_type<type,std::complex<long double> >::value == true
+ ));
+
+ if (val != static_cast<type>(0))
+ return static_cast<type>((type)1.0/val);
+ else
+ return 0;
+ }
+
+ template <typename type>
+ inline type reciprocal_max (const type& val)
+ {
+ // you can only compute reciprocal_max matrices that contain floats, doubles or long doubles.
+ COMPILE_TIME_ASSERT((
+ is_same_type<type,float>::value == true ||
+ is_same_type<type,double>::value == true ||
+ is_same_type<type,long double>::value == true
+ ));
+
+ if (val != static_cast<type>(0))
+ return static_cast<type>((type)1.0/val);
+ else
+ return std::numeric_limits<type>::max();
+ }
+
+ }
+
+ DLIB_DEFINE_FUNCTION_M(op_sigmoid, sigmoid, impl::sigmoid, 7);
+ DLIB_DEFINE_FUNCTION_MS(op_round_zeros, round_zeros, impl::round_zeros_eps, 7);
+ DLIB_DEFINE_FUNCTION_M(op_round_zeros2, round_zeros, impl::round_zeros, 7);
+ DLIB_DEFINE_FUNCTION_M(op_cubed, cubed, impl::cubed, 7);
+ DLIB_DEFINE_FUNCTION_M(op_squared, squared, impl::squared, 6);
+ DLIB_DEFINE_FUNCTION_M(op_sign, sign, impl::sign, 6);
+ DLIB_DEFINE_FUNCTION_MS(op_pow1, pow, impl::pow1, 7);
+ DLIB_DEFINE_FUNCTION_SM(op_pow2, pow, impl::pow2, 7);
+ DLIB_DEFINE_FUNCTION_M(op_reciprocal, reciprocal, impl::reciprocal, 6);
+ DLIB_DEFINE_FUNCTION_M(op_reciprocal_max, reciprocal_max, impl::reciprocal_max, 6);
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename enabled = void>
+ struct op_round : basic_op_m<M>
+ {
+ op_round( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+7;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return static_cast<type>(std::floor(this->m(r,c)+0.5));
+ }
+ };
+
+ template <typename M>
+ struct op_round<M,typename enable_if_c<std::numeric_limits<typename M::type>::is_integer>::type >
+ : basic_op_m<M>
+ {
+ op_round( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return this->m(r,c);
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_round<EXP> > round (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only round matrices that contain built in scalar types like double, long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_round<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_normalize : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_normalize( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+5;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ return this->m(r,c)*s;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_normalize<EXP> > normalize (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only compute normalized matrices that contain floats, doubles or long doubles.
+ COMPILE_TIME_ASSERT((
+ is_same_type<typename EXP::type,float>::value == true ||
+ is_same_type<typename EXP::type,double>::value == true ||
+ is_same_type<typename EXP::type,long double>::value == true
+ ));
+
+
+ typedef op_normalize<EXP> op;
+ typename EXP::type temp = std::sqrt(sum(squared(m)));
+ if (temp != 0.0)
+ temp = 1.0/temp;
+
+ return matrix_op<op>(op(m.ref(),temp));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename return_type = typename M::type>
+ struct op_abs : basic_op_m<M>
+ {
+ op_abs( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+7;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ return static_cast<type>(std::abs(this->m(r,c)));
+ }
+ };
+
+ template <typename M, typename T>
+ struct op_abs<M, std::complex<T> > : basic_op_m<M>
+ {
+ op_abs( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost;
+ typedef T type;
+ typedef const T const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ return static_cast<type>(std::abs(this->m(r,c)));
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_abs<EXP> > abs (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_abs<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_complex_matrix : basic_op_m<M>
+ {
+ op_complex_matrix( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+1;
+ typedef std::complex<typename M::type> type;
+ typedef const std::complex<typename M::type> const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ return type(this->m(r,c));
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_complex_matrix<EXP> > complex_matrix (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_complex_matrix<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_complex_matrix2 : basic_op_mm<M1,M2>
+ {
+ op_complex_matrix2( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
+
+ const static long cost = M1::cost+M2::cost+1;
+ typedef std::complex<typename M1::type> type;
+ typedef const std::complex<typename M1::type> const_ret_type;
+
+ const_ret_type apply ( long r, long c) const
+ { return type(this->m1(r,c), this->m2(r,c)); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_complex_matrix2<EXP1,EXP2> > complex_matrix (
+ const matrix_exp<EXP1>& real_part,
+ const matrix_exp<EXP2>& imag_part
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
+
+ DLIB_ASSERT(real_part.nr() == imag_part.nr() &&
+ real_part.nc() == imag_part.nc(),
+ "\tconst matrix_exp::type complex_matrix(real_part, imag_part)"
+ << "\n\tYou can only make a complex matrix from two equally sized matrices"
+ << "\n\treal_part.nr(): " << real_part.nr()
+ << "\n\treal_part.nc(): " << real_part.nc()
+ << "\n\timag_part.nr(): " << imag_part.nr()
+ << "\n\timag_part.nc(): " << imag_part.nc()
+ );
+
+ typedef op_complex_matrix2<EXP1,EXP2> op;
+ return matrix_op<op>(op(real_part.ref(),imag_part.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_norm : basic_op_m<M>
+ {
+ op_norm( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+6;
+ typedef typename M::type::value_type type;
+ typedef const typename M::type::value_type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ { return std::norm(this->m(r,c)); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_norm<EXP> > norm (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_norm<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_real : basic_op_m<M>
+ {
+ op_real( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost;
+ typedef typename M::type::value_type type;
+ typedef const typename M::type::value_type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ { return std::real(this->m(r,c)); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_real<EXP> > real (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_real<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_imag : basic_op_m<M>
+ {
+ op_imag( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost;
+ typedef typename M::type::value_type type;
+ typedef const typename M::type::value_type const_ret_type;
+ const_ret_type apply (long r, long c) const
+ { return std::imag(this->m(r,c)); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_imag<EXP> > imag (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_imag<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_MATH_FUNCTIONS
+
diff --git a/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h b/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h
new file mode 100644
index 000000000..09210270d
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h
@@ -0,0 +1,595 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_
+#ifdef DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_
+
+#include "matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Exponential Functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp exp (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::exp(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp log10 (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::log10(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp log (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::log(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sqrt (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == sqrt(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const matrix_exp pow (
+ const matrix_exp& m,
+ const T& e
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == pow(m(r,c),e)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ const matrix_exp pow (
+ const T& b,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == pow(b, m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp squared (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == m(r,c)*m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp cubed (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == m(r,c)*m(r,c)*m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Miscellaneous
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sign (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix that tells the sign of each element in m. In particular:
+ returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) >= 0) then
+ - R(r,c) == +1
+ - else
+ - R(r,c) == -1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sigmoid (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == 1/(1 + exp(-m(r,c)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp abs (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - if (m contains std::complex<T> objects) then
+ - R::type == T
+ - else
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::abs(m(r,c))
+ (note that if m is complex then std::abs(val) performs sqrt(std::norm(val))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp reciprocal (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, long double, std::complex<float>,
+ std::complex<double>, or std::complex<long double>
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) != 0) then
+ - R(r,c) == 1.0/m(r,c)
+ - else
+ - R(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp reciprocal_max (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) != 0) then
+ - R(r,c) == 1.0/m(r,c)
+ - else
+ - R(r,c) == std::numeric_limits<R::type>::max()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp normalize (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - if (sqrt(sum(squared(m))) != 0) then
+ - returns m/sqrt(sum(squared(m)))
+ - else
+ - returns m
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Rounding numbers one way or another
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp round (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ (i.e. m must contain a type like int, float, double, long, etc.)
+ ensures
+ - if (m contains integers) then
+ - returns m unmodified
+ - else
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == m(r,c) rounded to the nearest integral value
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp ceil (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::ceil(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp floor (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::floor(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp round_zeros (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ (i.e. m must contain a type like int, float, double, long, etc.)
+ ensures
+ - if (m contains integers) then
+ - returns m unmodified
+ - else
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - let eps == 10*std::numeric_limits<matrix_exp::type>::epsilon()
+ - for all valid r and c:
+ - if (abs(m(r,c)) >= eps) then
+ - R(r,c) == m(r,c)
+ - else
+ - R(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp round_zeros (
+ const matrix_exp& m,
+ matrix_exp::type eps
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ (i.e. m must contain a type like int, float, double, long, etc.)
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (abs(m(r,c)) >= eps) then
+ - R(r,c) == m(r,c)
+ - else
+ - R(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Complex number utility functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp conj (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == std::complex<T>
+ ensures
+ - returns a matrix R such that:
+ - R::type == std::complex<T>
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::conj(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp norm (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == std::complex<T>
+ ensures
+ - returns a matrix R such that:
+ - R::type == T
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::norm(m(r,c))
+ (note that std::norm(val) == val.real()*val.real() + val.imag()*val.imag())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp imag (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == std::complex<T>
+ ensures
+ - returns a matrix R such that:
+ - R::type == T
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::imag(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp real (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == std::complex<T>
+ ensures
+ - returns a matrix R such that:
+ - R::type == T
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::real(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp complex_matrix (
+ const matrix_exp& real_part
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == std::complex<T> where T is whatever type real_part used.
+ - R has the same dimensions as real_part.
+ - for all valid r and c:
+ R(r,c) == std::complex(real_part(r,c), 0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp complex_matrix (
+ const matrix_exp& real_part,
+ const matrix_exp& imag_part
+ );
+ /*!
+ requires
+ - real_part.nr() == imag_part.nr()
+ - real_part.nc() == imag_part.nc()
+ - real_part and imag_part both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == std::complex<T> where T is whatever type real_part and imag_part used.
+ - R has the same dimensions as real_part and imag_part
+ - for all valid r and c:
+ R(r,c) == std::complex(real_part(r,c),imag_part(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Trigonometric Functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sin (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::sin(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp cos (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::cos(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp tan (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::tan(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp asin (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::asin(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp acos (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::acos(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp atan (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::atan(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sinh (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::sinh(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp cosh (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::cosh(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp tanh (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == float, double, or long double
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R(r,c) == std::tanh(m(r,c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_
+
diff --git a/ml/dlib/dlib/matrix/matrix_op.h b/ml/dlib/dlib/matrix/matrix_op.h
new file mode 100644
index 000000000..524a775eb
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_op.h
@@ -0,0 +1,479 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_OP_H_
+#define DLIB_MATRIx_OP_H_
+
+#include "matrix_exp.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename OP >
+ class matrix_op;
+
+ template < typename OP >
+ struct matrix_traits<matrix_op<OP> >
+ {
+ typedef typename OP::type type;
+ typedef typename OP::const_ret_type const_ret_type;
+ typedef typename OP::mem_manager_type mem_manager_type;
+ typedef typename OP::layout_type layout_type;
+ const static long NR = OP::NR;
+ const static long NC = OP::NC;
+ const static long cost = OP::cost;
+ };
+
+ template <
+ typename OP
+ >
+ class matrix_op : public matrix_exp<matrix_op<OP> >
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ The matrix_op is simply a tool for reducing the amount of boilerplate
+ you need to write when creating matrix expressions.
+ !*/
+
+ public:
+ typedef typename matrix_traits<matrix_op>::type type;
+ typedef typename matrix_traits<matrix_op>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_op>::mem_manager_type mem_manager_type;
+ typedef typename matrix_traits<matrix_op>::layout_type layout_type;
+ const static long NR = matrix_traits<matrix_op>::NR;
+ const static long NC = matrix_traits<matrix_op>::NC;
+ const static long cost = matrix_traits<matrix_op>::cost;
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1>
+ matrix_op (T1);
+ public:
+
+ matrix_op (
+ const OP& op_
+ ) :
+ op(op_)
+ {}
+
+ const_ret_type operator() (
+ long r,
+ long c
+ ) const { return op.apply(r,c); }
+
+ const_ret_type operator() ( long i ) const
+ { return matrix_exp<matrix_op>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return op.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return op.destructively_aliases(item); }
+
+ long nr (
+ ) const { return op.nr(); }
+
+ long nc (
+ ) const { return op.nc(); }
+
+
+ const OP op;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename OP >
+ class matrix_diag_op;
+
+ template < typename OP >
+ struct matrix_traits<matrix_diag_op<OP> >
+ {
+ typedef typename OP::type type;
+ typedef typename OP::const_ret_type const_ret_type;
+ typedef typename OP::mem_manager_type mem_manager_type;
+ typedef typename OP::layout_type layout_type;
+ const static long NR = OP::NR;
+ const static long NC = OP::NC;
+ const static long cost = OP::cost;
+ };
+
+ template <
+ typename OP
+ >
+ class matrix_diag_op : public matrix_diag_exp<matrix_diag_op<OP> >
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ The matrix_diag_op is simply a tool for reducing the amount of boilerplate
+ you need to write when creating matrix expressions.
+ !*/
+
+ public:
+ typedef typename matrix_traits<matrix_diag_op>::type type;
+ typedef typename matrix_traits<matrix_diag_op>::const_ret_type const_ret_type;
+ typedef typename matrix_traits<matrix_diag_op>::mem_manager_type mem_manager_type;
+ typedef typename matrix_traits<matrix_diag_op>::layout_type layout_type;
+ const static long NR = matrix_traits<matrix_diag_op>::NR;
+ const static long NC = matrix_traits<matrix_diag_op>::NC;
+ const static long cost = matrix_traits<matrix_diag_op>::cost;
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1>
+ matrix_diag_op (T1);
+ public:
+
+ matrix_diag_op (
+ const OP& op_
+ ) :
+ op(op_)
+ {}
+
+ const_ret_type operator() (
+ long r,
+ long c
+ ) const { return op.apply(r,c); }
+
+ const_ret_type operator() ( long i ) const
+ { return matrix_exp<matrix_diag_op>::operator()(i); }
+
+ template <typename U>
+ bool aliases (
+ const matrix_exp<U>& item
+ ) const { return op.aliases(item); }
+
+ template <typename U>
+ bool destructively_aliases (
+ const matrix_exp<U>& item
+ ) const { return op.destructively_aliases(item); }
+
+ long nr (
+ ) const { return op.nr(); }
+
+ long nc (
+ ) const { return op.nc(); }
+
+
+ const OP op;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct does_not_alias
+ {
+ /*!
+ This is a partial implementation of a matrix operator that never aliases
+ another expression.
+ !*/
+
+ template <typename U> bool aliases ( const U& ) const { return false; }
+ template <typename U> bool destructively_aliases ( const U& ) const { return false; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct basic_op_m
+ {
+ /*!
+ This is a partial implementation of a matrix operator that preserves
+ the dimensions of its argument and doesn't have destructive aliasing.
+ !*/
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1>
+ basic_op_m (T1);
+ public:
+
+ basic_op_m(
+ const M& m_
+ ) : m(m_){}
+
+ const M& m;
+
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m.destructively_aliases(item); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct basic_op_mm
+ {
+ /*!
+ This is a partial implementation of a matrix operator that preserves
+ the dimensions of its arguments and doesn't have destructive aliasing.
+ !*/
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1, typename T2>
+ basic_op_mm (T1, T2);
+ public:
+
+ basic_op_mm(
+ const M1& m1_,
+ const M2& m2_
+ ) : m1(m1_), m2(m2_){}
+
+ const M1& m1;
+ const M2& m2;
+
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.destructively_aliases(item); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct basic_op_mmm
+ {
+ /*!
+ This is a partial implementation of a matrix operator that preserves
+ the dimensions of its arguments and doesn't have destructive aliasing.
+ !*/
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1, typename T2, typename T3>
+ basic_op_mmm (T1, T2, T3);
+ public:
+
+ basic_op_mmm(
+ const M1& m1_,
+ const M2& m2_,
+ const M3& m3_
+ ) : m1(m1_), m2(m2_), m3(m3_){}
+
+ const M1& m1;
+ const M2& m2;
+ const M3& m3;
+
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.destructively_aliases(item) ||
+ m3.destructively_aliases(item);}
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3, typename M4>
+ struct basic_op_mmmm
+ {
+ /*!
+ This is a partial implementation of a matrix operator that preserves
+ the dimensions of its arguments and doesn't have destructive aliasing.
+ !*/
+
+ private:
+ // This constructor exists simply for the purpose of causing a compile time error if
+ // someone tries to create an instance of this object with the wrong kind of object.
+ template <typename T1, typename T2, typename T3, typename T4>
+ basic_op_mmmm (T1, T2, T3, T4);
+ public:
+
+ basic_op_mmmm(
+ const M1& m1_,
+ const M2& m2_,
+ const M3& m3_,
+ const M4& m4_
+ ) : m1(m1_), m2(m2_), m3(m3_), m4(m4_){}
+
+ const M1& m1;
+ const M2& m2;
+ const M3& m3;
+ const M4& m4;
+
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item) || m4.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.destructively_aliases(item) ||
+ m3.destructively_aliases(item) || m4.destructively_aliases(item);}
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_DEFINE_OP_M(op_name, function, extra_cost) \
+ template <typename M> \
+ struct op_name \
+ { \
+ op_name( \
+ const M& m_ \
+ ) : m(m_){} \
+ \
+ const M& m; \
+ \
+ const static long cost = M::cost+(extra_cost); \
+ const static long NR = M::NR; \
+ const static long NC = M::NC; \
+ typedef typename M::type type; \
+ typedef const typename M::type const_ret_type; \
+ typedef typename M::mem_manager_type mem_manager_type; \
+ typedef typename M::layout_type layout_type; \
+ \
+ const_ret_type apply (long r, long c) const { return function(m(r,c)); } \
+ \
+ long nr () const { return m.nr(); } \
+ long nc () const { return m.nc(); } \
+ \
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const \
+ { return m.aliases(item); } \
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const \
+ { return m.destructively_aliases(item); } \
+ \
+ }
+
+#define DLIB_DEFINE_FUNCTION_M(op_name, name, function, extra_cost) \
+ DLIB_DEFINE_OP_M(op_name, function, extra_cost); \
+ template < typename M > \
+ const matrix_op<op_name<M> > name ( const matrix_exp<M>& m) \
+ { \
+ typedef op_name<M> op; \
+ return matrix_op<op>(op(m.ref())); \
+ }
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_DEFINE_OP_MS(op_name, function, extra_cost) \
+ template <typename M, typename S> \
+ struct op_name \
+ { \
+ op_name( \
+ const M& m_, \
+ const S& s_ \
+ ) : m(m_), s(s_){} \
+ \
+ const M& m; \
+ const S s; \
+ \
+ const static long cost = M::cost+(extra_cost); \
+ const static long NR = M::NR; \
+ const static long NC = M::NC; \
+ typedef typename M::type type; \
+ typedef const typename M::type const_ret_type; \
+ typedef typename M::mem_manager_type mem_manager_type; \
+ typedef typename M::layout_type layout_type; \
+ \
+ const_ret_type apply (long r, long c) const { return function(m(r,c), s); } \
+ \
+ long nr () const { return m.nr(); } \
+ long nc () const { return m.nc(); } \
+ \
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const \
+ { return m.aliases(item); } \
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const \
+ { return m.destructively_aliases(item); } \
+ \
+ }
+
+#define DLIB_DEFINE_FUNCTION_MS(op_name, name, function, extra_cost) \
+ DLIB_DEFINE_OP_MS(op_name, function, extra_cost); \
+ template < typename M, typename S > \
+ const matrix_op<op_name<M, S> > name ( const matrix_exp<M>& m, const S& s) \
+ { \
+ typedef op_name<M, S> op; \
+ return matrix_op<op>(op(m.ref(), s)); \
+ }
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_DEFINE_OP_SM(op_name, function, extra_cost) \
+ template <typename S, typename M> \
+ struct op_name \
+ { \
+ op_name( \
+ const S& s_, \
+ const M& m_ \
+ ) : m(m_), s(s_){} \
+ \
+ const M& m; \
+ const S s; \
+ \
+ const static long cost = M::cost+(extra_cost); \
+ const static long NR = M::NR; \
+ const static long NC = M::NC; \
+ typedef typename M::type type; \
+ typedef const typename M::type const_ret_type; \
+ typedef typename M::mem_manager_type mem_manager_type; \
+ typedef typename M::layout_type layout_type; \
+ \
+ const_ret_type apply (long r, long c) const { return function(s, m(r,c)); } \
+ \
+ long nr () const { return m.nr(); } \
+ long nc () const { return m.nc(); } \
+ \
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const \
+ { return m.aliases(item); } \
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const \
+ { return m.destructively_aliases(item); } \
+ \
+ }
+
+#define DLIB_DEFINE_FUNCTION_SM(op_name, name, function, extra_cost) \
+ DLIB_DEFINE_OP_SM(op_name, function, extra_cost); \
+ template < typename S, typename M > \
+ const matrix_op<op_name<S, M> > name (const S& s, const matrix_exp<M>& m) \
+ { \
+ typedef op_name<S, M> op; \
+ return matrix_op<op>(op(s, m.ref())); \
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_OP_H_
+
diff --git a/ml/dlib/dlib/matrix/matrix_qr.h b/ml/dlib/dlib/matrix/matrix_qr.h
new file mode 100644
index 000000000..086d481f1
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_qr.h
@@ -0,0 +1,466 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+// This code was adapted from code from the JAMA part of NIST's TNT library.
+// See: http://math.nist.gov/tnt/
+#ifndef DLIB_MATRIX_QR_DECOMPOSITION_H
+#define DLIB_MATRIX_QR_DECOMPOSITION_H
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/geqrf.h"
+#include "lapack/ormqr.h"
+#endif
+
+#include "matrix_trsm.h"
+
+namespace dlib
+{
+
+ template <
+ typename matrix_exp_type
+ >
+ class qr_decomposition
+ {
+
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+
+ // You have supplied an invalid type of matrix_exp_type. You have
+ // to use this object with matrices that contain float or double type data.
+ COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
+ is_same_type<double, type>::value ));
+
+
+
+ template <typename EXP>
+ qr_decomposition(
+ const matrix_exp<EXP>& A
+ );
+
+ bool is_full_rank(
+ ) const;
+
+ long nr(
+ ) const;
+
+ long nc(
+ ) const;
+
+ const matrix_type get_r (
+ ) const;
+
+ const matrix_type get_q (
+ ) const;
+
+ template <typename T, long R, long C, typename MM, typename L>
+ void get_q (
+ matrix<T,R,C,MM,L>& Q
+ ) const;
+
+ template <typename EXP>
+ const matrix_type solve (
+ const matrix_exp<EXP>& B
+ ) const;
+
+ private:
+
+#ifndef DLIB_USE_LAPACK
+ template <typename EXP>
+ const matrix_type solve_mat (
+ const matrix_exp<EXP>& B
+ ) const;
+
+ template <typename EXP>
+ const matrix_type solve_vect (
+ const matrix_exp<EXP>& B
+ ) const;
+#endif
+
+
+ /** Array for internal storage of decomposition.
+ @serial internal array storage.
+ */
+ matrix<type,0,0,mem_manager_type,column_major_layout> QR_;
+
+ /** Row and column dimensions.
+ @serial column dimension.
+ @serial row dimension.
+ */
+ long m, n;
+
+ /** Array for internal storage of diagonal of R.
+ @serial diagonal of R.
+ */
+ typedef matrix<type,0,1,mem_manager_type,column_major_layout> column_vector_type;
+ column_vector_type tau;
+ column_vector_type Rdiag;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ qr_decomposition<matrix_exp_type>::
+ qr_decomposition(
+ const matrix_exp<EXP>& A
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.nr() >= A.nc() && A.size() > 0,
+ "\tqr_decomposition::qr_decomposition(A)"
+ << "\n\tInvalid inputs were given to this function"
+ << "\n\tA.nr(): " << A.nr()
+ << "\n\tA.nc(): " << A.nc()
+ << "\n\tA.size(): " << A.size()
+ << "\n\tthis: " << this
+ );
+
+
+ QR_ = A;
+ m = A.nr();
+ n = A.nc();
+
+#ifdef DLIB_USE_LAPACK
+
+ lapack::geqrf(QR_, tau);
+ Rdiag = diag(QR_);
+
+#else
+ Rdiag.set_size(n);
+ long i=0, j=0, k=0;
+
+ // Main loop.
+ for (k = 0; k < n; k++)
+ {
+ // Compute 2-norm of k-th column without under/overflow.
+ type nrm = 0;
+ for (i = k; i < m; i++)
+ {
+ nrm = hypot(nrm,QR_(i,k));
+ }
+
+ if (nrm != 0.0)
+ {
+ // Form k-th Householder vector.
+ if (QR_(k,k) < 0)
+ {
+ nrm = -nrm;
+ }
+ for (i = k; i < m; i++)
+ {
+ QR_(i,k) /= nrm;
+ }
+ QR_(k,k) += 1.0;
+
+ // Apply transformation to remaining columns.
+ for (j = k+1; j < n; j++)
+ {
+ type s = 0.0;
+ for (i = k; i < m; i++)
+ {
+ s += QR_(i,k)*QR_(i,j);
+ }
+ s = -s/QR_(k,k);
+ for (i = k; i < m; i++)
+ {
+ QR_(i,j) += s*QR_(i,k);
+ }
+ }
+ }
+ Rdiag(k) = -nrm;
+ }
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long qr_decomposition<matrix_exp_type>::
+ nr (
+ ) const
+ {
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long qr_decomposition<matrix_exp_type>::
+ nc (
+ ) const
+ {
+ return n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool qr_decomposition<matrix_exp_type>::
+ is_full_rank(
+ ) const
+ {
+ type eps = max(abs(Rdiag));
+ if (eps != 0)
+ eps *= std::sqrt(std::numeric_limits<type>::epsilon())/100;
+ else
+ eps = 1; // there is no max so just use 1
+
+ // check if any of the elements of Rdiag are effectively 0
+ return min(abs(Rdiag)) > eps;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
+ get_r(
+ ) const
+ {
+ matrix_type R(n,n);
+ for (long i = 0; i < n; i++)
+ {
+ for (long j = 0; j < n; j++)
+ {
+ if (i < j)
+ {
+ R(i,j) = QR_(i,j);
+ }
+ else if (i == j)
+ {
+ R(i,j) = Rdiag(i);
+ }
+ else
+ {
+ R(i,j) = 0.0;
+ }
+ }
+ }
+ return R;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
+ get_q(
+ ) const
+ {
+ matrix_type Q;
+ get_q(Q);
+ return Q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename T, long R, long C, typename MM, typename L>
+ void qr_decomposition<matrix_exp_type>::
+ get_q(
+ matrix<T,R,C,MM,L>& X
+ ) const
+ {
+#ifdef DLIB_USE_LAPACK
+ // Take only the first n columns of an identity matrix. This way
+ // X ends up being an m by n matrix.
+ X = colm(identity_matrix<type>(m), range(0,n-1));
+
+ // Compute Y = Q*X
+ lapack::ormqr('L','N', QR_, tau, X);
+
+#else
+ long i=0, j=0, k=0;
+
+ X.set_size(m,n);
+ for (k = n-1; k >= 0; k--)
+ {
+ for (i = 0; i < m; i++)
+ {
+ X(i,k) = 0.0;
+ }
+ X(k,k) = 1.0;
+ for (j = k; j < n; j++)
+ {
+ if (QR_(k,k) != 0)
+ {
+ type s = 0.0;
+ for (i = k; i < m; i++)
+ {
+ s += QR_(i,k)*X(i,j);
+ }
+ s = -s/QR_(k,k);
+ for (i = k; i < m; i++)
+ {
+ X(i,j) += s*QR_(i,k);
+ }
+ }
+ }
+ }
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
+ solve(
+ const matrix_exp<EXP>& B
+ ) const
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(B.nr() == nr(),
+ "\tconst matrix_type qr_decomposition::solve(B)"
+ << "\n\tInvalid inputs were given to this function"
+ << "\n\tB.nr(): " << B.nr()
+ << "\n\tnr(): " << nr()
+ << "\n\tthis: " << this
+ );
+
+#ifdef DLIB_USE_LAPACK
+
+ using namespace blas_bindings;
+ matrix<type,0,0,mem_manager_type,column_major_layout> X(B);
+ // Compute Y = transpose(Q)*B
+ lapack::ormqr('L','T',QR_, tau, X);
+ // Solve R*X = Y;
+ triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, QR_, X, n);
+
+ /* return n x nx portion of X */
+ return subm(X,0,0,n,B.nc());
+
+#else
+ // just call the right version of the solve function
+ if (B.nc() == 1)
+ return solve_vect(B);
+ else
+ return solve_mat(B);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Private member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+#ifndef DLIB_USE_LAPACK
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
+ solve_vect(
+ const matrix_exp<EXP>& B
+ ) const
+ {
+
+ column_vector_type x(B);
+
+ // Compute Y = transpose(Q)*B
+ for (long k = 0; k < n; k++)
+ {
+ type s = 0.0;
+ for (long i = k; i < m; i++)
+ {
+ s += QR_(i,k)*x(i);
+ }
+ s = -s/QR_(k,k);
+ for (long i = k; i < m; i++)
+ {
+ x(i) += s*QR_(i,k);
+ }
+ }
+ // Solve R*X = Y;
+ for (long k = n-1; k >= 0; k--)
+ {
+ x(k) /= Rdiag(k);
+ for (long i = 0; i < k; i++)
+ {
+ x(i) -= x(k)*QR_(i,k);
+ }
+ }
+
+
+ /* return n x 1 portion of x */
+ return colm(x,0,n);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
+ solve_mat(
+ const matrix_exp<EXP>& B
+ ) const
+ {
+ const long nx = B.nc();
+ matrix_type X(B);
+ long i=0, j=0, k=0;
+
+ // Compute Y = transpose(Q)*B
+ for (k = 0; k < n; k++)
+ {
+ for (j = 0; j < nx; j++)
+ {
+ type s = 0.0;
+ for (i = k; i < m; i++)
+ {
+ s += QR_(i,k)*X(i,j);
+ }
+ s = -s/QR_(k,k);
+ for (i = k; i < m; i++)
+ {
+ X(i,j) += s*QR_(i,k);
+ }
+ }
+ }
+ // Solve R*X = Y;
+ for (k = n-1; k >= 0; k--)
+ {
+ for (j = 0; j < nx; j++)
+ {
+ X(k,j) /= Rdiag(k);
+ }
+ for (i = 0; i < k; i++)
+ {
+ for (j = 0; j < nx; j++)
+ {
+ X(i,j) -= X(k,j)*QR_(i,k);
+ }
+ }
+ }
+
+ /* return n x nx portion of X */
+ return subm(X,0,0,n,nx);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_USE_LAPACK not defined
+
+}
+
+#endif // DLIB_MATRIX_QR_DECOMPOSITION_H
+
+
+
diff --git a/ml/dlib/dlib/matrix/matrix_read_from_istream.h b/ml/dlib/dlib/matrix/matrix_read_from_istream.h
new file mode 100644
index 000000000..3aced3584
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_read_from_istream.h
@@ -0,0 +1,108 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_READ_FROM_ISTREAM_H_h_
+#define DLIB_MATRIx_READ_FROM_ISTREAM_H_h_
+
+#include "matrix.h"
+#include <vector>
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool next_is_whitespace (
+ std::istream& in
+ )
+ {
+ return in.peek() == '\n' ||
+ in.peek() == ' ' ||
+ in.peek() == ',' ||
+ in.peek() == '\t' ||
+ in.peek() == '\r';
+ }
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ std::istream& operator>> (
+ std::istream& in,
+ matrix<T,NR,NC,MM,L>& m
+ )
+ {
+ using namespace dlib::impl;
+ long num_rows = 0;
+ std::vector<T> buf;
+ buf.reserve(100);
+
+ // eat any leading whitespace
+ while (next_is_whitespace(in))
+ in.get();
+
+ bool at_start_of_line = true;
+ bool stop = false;
+ while(!stop && in.peek() != EOF)
+ {
+ T temp;
+ in >> temp;
+ if (!in)
+ return in;
+
+ buf.push_back(temp);
+ if (at_start_of_line)
+ {
+ at_start_of_line = false;
+ ++num_rows;
+ }
+
+ // Eat next block of whitespace but also note if we hit the start of the next
+ // line.
+ while (next_is_whitespace(in))
+ {
+ if (at_start_of_line && in.peek() == '\n')
+ {
+ stop = true;
+ break;
+ }
+
+ if (in.get() == '\n')
+ at_start_of_line = true;
+ }
+ }
+
+ // It's an error for there to not be any matrix data in the input stream
+ if (num_rows == 0)
+ {
+ in.clear(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+ const long num_cols = buf.size()/num_rows;
+ // It's also an error if the sizes don't make sense.
+ if (num_rows*num_cols != (long)buf.size() ||
+ (NR != 0 && NR != num_rows) ||
+ (NC != 0 && NC != num_cols))
+ {
+ in.clear(in.rdstate() | std::ios::failbit);
+ return in;
+ }
+
+
+ m = reshape(mat(buf),num_rows, buf.size()/num_rows);
+
+ if (in.eof())
+ {
+ // Clear the eof and fail bits since this is caused by peeking at the EOF.
+ // But in the current case, we have successfully read the matrix.
+ in.clear(in.rdstate() & (~(std::ios::eofbit | std::ios::failbit)));
+ }
+ return in;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_MATRIx_READ_FROM_ISTREAM_H_h_
+
diff --git a/ml/dlib/dlib/matrix/matrix_subexp.h b/ml/dlib/dlib/matrix/matrix_subexp.h
new file mode 100644
index 000000000..668e57496
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_subexp.h
@@ -0,0 +1,1566 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_SUBEXP_
+#define DLIB_MATRIx_SUBEXP_
+
+#include "matrix_subexp_abstract.h"
+#include "matrix_op.h"
+#include "matrix.h"
+#include "../geometry/rectangle.h"
+#include "matrix_expressions.h"
+#include "matrix_mat.h"
+
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <long start, long inc, long end>
+ const matrix_range_static_exp<start,inc,end> range (
+ )
+ {
+ COMPILE_TIME_ASSERT(inc > 0);
+ return matrix_range_static_exp<start,inc,end>();
+ }
+
+ template <long start, long end>
+ const matrix_range_static_exp<start,1,end> range (
+ )
+ {
+ return matrix_range_static_exp<start,1,end>();
+ }
+
+ inline const matrix_range_exp<long> range (
+ long start,
+ long end
+ )
+ {
+ return matrix_range_exp<long>(start,end);
+ }
+
+ inline const matrix_range_exp<long> range (
+ long start,
+ long inc,
+ long end
+ )
+ {
+ DLIB_ASSERT(inc > 0,
+ "\tconst matrix_exp range(start, inc, end)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tstart: " << start
+ << "\n\tinc: " << inc
+ << "\n\tend: " << end
+ );
+
+ return matrix_range_exp<long>(start,inc,end);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_subm
+ {
+ op_subm (
+ const M& m_x,
+ const long& r_x,
+ const long& c_x,
+ const long& nr_x,
+ const long& nc_x
+ ) : m(m_x), r_(r_x), c_(c_x), nr_(nr_x), nc_(nc_x) { }
+
+ const M& m;
+ const long r_;
+ const long c_;
+ const long nr_;
+ const long nc_;
+
+ const static long cost = M::cost+1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const static long NR = 0;
+ const static long NC = 0;
+
+ const_ret_type apply ( long r, long c) const { return m(r+r_,c+c_); }
+
+ long nr () const { return nr_; }
+ long nc () const { return nc_; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_subm<EXP> > subm (
+ const matrix_exp<EXP>& m,
+ long r,
+ long c,
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(r >= 0 && c >= 0 && nr >= 0 && nc >= 0 && r+nr <= m.nr() && c+nc <= m.nc(),
+ "\tconst matrix_exp subm(const matrix_exp& m, r, c, nr, nc)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tr: " << r
+ << "\n\tc: " << c
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+
+ typedef op_subm<EXP> op;
+ return matrix_op<op>(op(m.ref(),r,c,nr,nc));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_subm<EXP> > subm_clipped (
+ const matrix_exp<EXP>& m,
+ long r,
+ long c,
+ long nr,
+ long nc
+ )
+ {
+ rectangle box(c,r,c+nc-1,r+nr-1);
+ box = box.intersect(get_rect(m));
+ typedef op_subm<EXP> op;
+ return matrix_op<op>(op(m.ref(),box.top(),box.left(),box.height(),box.width()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_subm<EXP> > subm (
+ const matrix_exp<EXP>& m,
+ const rectangle& rect
+ )
+ {
+ DLIB_ASSERT(get_rect(m).contains(rect) == true,
+ "\tconst matrix_exp subm(const matrix_exp& m, const rectangle& rect)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trect.left(): " << rect.left()
+ << "\n\trect.top(): " << rect.top()
+ << "\n\trect.right(): " << rect.right()
+ << "\n\trect.bottom(): " << rect.bottom()
+ );
+
+ typedef op_subm<EXP> op;
+ return matrix_op<op>(op(m.ref(),rect.top(),rect.left(),rect.height(),rect.width()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_subm<EXP> > subm_clipped (
+ const matrix_exp<EXP>& m,
+ rectangle rect
+ )
+ {
+ rect = rect.intersect(get_rect(m));
+
+ typedef op_subm<EXP> op;
+ return matrix_op<op>(op(m.ref(),rect.top(),rect.left(),rect.height(),rect.width()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct op_subm_range
+ {
+ op_subm_range( const M1& m1_, const M2& rows_, const M3& cols_) :
+ m1(m1_), rows(rows_), cols(cols_) {}
+ const M1& m1;
+ const M2& rows;
+ const M3& cols;
+
+ const static long cost = M1::cost+M2::cost+M3::cost;
+ typedef typename M1::type type;
+ typedef typename M1::const_ret_type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M2::NC*M2::NR;
+ const static long NC = M3::NC*M3::NR;
+
+ const_ret_type apply ( long r, long c) const { return m1(rows(r),cols(c)); }
+
+ long nr () const { return rows.size(); }
+ long nc () const { return cols.size(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || rows.aliases(item) || cols.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || rows.aliases(item) || cols.aliases(item); }
+ };
+
+ template <
+ typename EXP,
+ typename EXPr,
+ typename EXPc
+ >
+ const matrix_op<op_subm_range<EXP,EXPr,EXPc> > subm (
+ const matrix_exp<EXP>& m,
+ const matrix_exp<EXPr>& rows,
+ const matrix_exp<EXPc>& cols
+ )
+ {
+ // the rows and cols matrices must contain integer elements
+ COMPILE_TIME_ASSERT(std::numeric_limits<typename EXPr::type>::is_integer);
+ COMPILE_TIME_ASSERT(std::numeric_limits<typename EXPc::type>::is_integer);
+
+ DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && 0 <= min(cols) && max(cols) < m.nc() &&
+ (rows.nr() == 1 || rows.nc() == 1) && (cols.nr() == 1 || cols.nc() == 1),
+ "\tconst matrix_exp subm(const matrix_exp& m, const matrix_exp& rows, const matrix_exp& cols)"
+ << "\n\tYou have given invalid arguments to this function"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(rows): " << min(rows)
+ << "\n\tmax(rows): " << max(rows)
+ << "\n\tmin(cols): " << min(cols)
+ << "\n\tmax(cols): " << max(cols)
+ << "\n\trows.nr(): " << rows.nr()
+ << "\n\trows.nc(): " << rows.nc()
+ << "\n\tcols.nr(): " << cols.nr()
+ << "\n\tcols.nc(): " << cols.nc()
+ );
+
+ typedef op_subm_range<EXP,EXPr,EXPc> op;
+ return matrix_op<op>(op(m.ref(),rows.ref(),cols.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_rowm
+ {
+ op_rowm(const M& m_, const long& row_) : m(m_), row(row_) {}
+ const M& m;
+ const long row;
+
+ const static long cost = M::cost;
+ const static long NR = 1;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long, long c) const { return m(row,c); }
+
+ long nr () const { return 1; }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_rowm<EXP> > rowm (
+ const matrix_exp<EXP>& m,
+ long row
+ )
+ {
+ DLIB_ASSERT(row >= 0 && row < m.nr(),
+ "\tconst matrix_exp rowm(const matrix_exp& m, row)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trow: " << row
+ );
+
+ typedef op_rowm<EXP> op;
+ return matrix_op<op>(op(m.ref(),row));
+ }
+
+ template <typename EXP>
+ struct rowm_exp
+ {
+ typedef matrix_op<op_rowm<EXP> > type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_rowm2
+ {
+ op_rowm2(const M& m_, const long& row_, const long& len) : m(m_), row(row_), length(len) {}
+ const M& m;
+ const long row;
+ const long length;
+
+ const static long cost = M::cost;
+ const static long NR = 1;
+ const static long NC = 0;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long , long c) const { return m(row,c); }
+
+ long nr () const { return 1; }
+ long nc () const { return length; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_rowm2<EXP> > rowm (
+ const matrix_exp<EXP>& m,
+ long row,
+ long length
+ )
+ {
+ DLIB_ASSERT(row >= 0 && row < m.nr() &&
+ length >= 0 && length <= m.nc(),
+ "\tconst matrix_exp rowm(const matrix_exp& m, row, length)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trow: " << row
+ << "\n\tlength: " << length
+ );
+
+ typedef op_rowm2<EXP> op;
+ return matrix_op<op>(op(m.ref(), row, length));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_rowm_range
+ {
+ op_rowm_range( const M1& m1_, const M2& rows_) : m1(m1_), rows(rows_) {}
+ const M1& m1;
+ const M2& rows;
+
+ const static long cost = M1::cost+M2::cost;
+ typedef typename M1::type type;
+ typedef typename M1::const_ret_type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M2::NC*M2::NR;
+ const static long NC = M1::NC;
+
+ const_ret_type apply ( long r, long c) const { return m1(rows(r),c); }
+
+ long nr () const { return rows.size(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || rows.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || rows.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_rowm_range<EXP1,EXP2> > rowm (
+ const matrix_exp<EXP1>& m,
+ const matrix_exp<EXP2>& rows
+ )
+ {
+ // the rows matrix must contain integer elements
+ COMPILE_TIME_ASSERT(std::numeric_limits<typename EXP2::type>::is_integer);
+
+ DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
+ "\tconst matrix_exp rowm(const matrix_exp& m, const matrix_exp& rows)"
+ << "\n\tYou have given invalid arguments to this function"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(rows): " << min(rows)
+ << "\n\tmax(rows): " << max(rows)
+ << "\n\trows.nr(): " << rows.nr()
+ << "\n\trows.nc(): " << rows.nc()
+ );
+
+ typedef op_rowm_range<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),rows.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_colm
+ {
+ op_colm(const M& m_, const long& col_) : m(m_), col(col_) {}
+ const M& m;
+ const long col;
+
+ const static long cost = M::cost;
+ const static long NR = M::NR;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long) const { return m(r,col); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_colm<EXP> > colm (
+ const matrix_exp<EXP>& m,
+ long col
+ )
+ {
+ DLIB_ASSERT(col >= 0 && col < m.nc(),
+ "\tconst matrix_exp colm(const matrix_exp& m, row)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tcol: " << col
+ );
+
+ typedef op_colm<EXP> op;
+ return matrix_op<op>(op(m.ref(),col));
+ }
+
+ template <typename EXP>
+ struct colm_exp
+ {
+ typedef matrix_op<op_colm<EXP> > type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_colm2
+ {
+ op_colm2(const M& m_, const long& col_, const long& len) : m(m_), col(col_), length(len) {}
+ const M& m;
+ const long col;
+ const long length;
+
+ const static long cost = M::cost;
+ const static long NR = 0;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long ) const { return m(r,col); }
+
+ long nr () const { return length; }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_colm2<EXP> > colm (
+ const matrix_exp<EXP>& m,
+ long col,
+ long length
+ )
+ {
+ DLIB_ASSERT(col >= 0 && col < m.nc() &&
+ length >= 0 && length <= m.nr(),
+ "\tconst matrix_exp colm(const matrix_exp& m, col, length)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tcol: " << col
+ << "\n\tlength: " << length
+ );
+
+ typedef op_colm2<EXP> op;
+ return matrix_op<op>(op(m.ref(),col, length));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_colm_range
+ {
+ op_colm_range( const M1& m1_, const M2& cols_) : m1(m1_), cols(cols_) {}
+ const M1& m1;
+ const M2& cols;
+
+ typedef typename M1::type type;
+ typedef typename M1::const_ret_type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR;
+ const static long NC = M2::NC*M2::NR;
+ const static long cost = M1::cost+M2::cost;
+
+ const_ret_type apply (long r, long c) const { return m1(r,cols(c)); }
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return cols.size(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || cols.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || cols.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_colm_range<EXP1,EXP2> > colm (
+ const matrix_exp<EXP1>& m,
+ const matrix_exp<EXP2>& cols
+ )
+ {
+ // the rows matrix must contain integer elements
+ COMPILE_TIME_ASSERT(std::numeric_limits<typename EXP2::type>::is_integer);
+
+ DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
+ "\tconst matrix_exp colm(const matrix_exp& m, const matrix_exp& cols)"
+ << "\n\tYou have given invalid arguments to this function"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(cols): " << min(cols)
+ << "\n\tmax(cols): " << max(cols)
+ << "\n\tcols.nr(): " << cols.nr()
+ << "\n\tcols.nc(): " << cols.nc()
+ );
+
+ typedef op_colm_range<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),cols.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class assignable_ptr_matrix
+ {
+ public:
+ typedef T type;
+ typedef row_major_layout layout_type;
+ typedef matrix<T,0,0,default_memory_manager,layout_type> matrix_type;
+
+ assignable_ptr_matrix(
+ T* ptr_,
+ long nr_,
+ long nc_
+ ) : ptr(ptr_), height(nr_), width(nc_){}
+
+ T& operator() (
+ long r,
+ long c
+ )
+ {
+ return ptr[r*width + c];
+ }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const
+ {
+ return ptr[r*width + c];
+ }
+
+ long nr() const { return height; }
+ long nc() const { return width; }
+
+ template <typename EXP>
+ assignable_ptr_matrix& operator= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ // You can only assign to a set_ptrm() expression with a source matrix that
+ // contains the same type of elements as the target (i.e. you can't mix double
+ // and float types).
+ COMPILE_TIME_ASSERT((is_same_type<T, typename EXP::type>::value == true));
+
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_ptrm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(mat(ptr,height,width)) == false)
+ {
+ matrix_assign(*this, exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to ptr to
+ // avoid aliasing issues during the copy
+ this->operator=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_ptr_matrix& operator+= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ // You can only assign to a set_ptrm() expression with a source matrix that
+ // contains the same type of elements as the target (i.e. you can't mix double
+ // and float types).
+ COMPILE_TIME_ASSERT((is_same_type<T, typename EXP::type>::value == true));
+
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_ptrm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(mat(ptr,height,width)) == false)
+ {
+ matrix_assign(*this, mat(ptr,height,width)+exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to ptr to
+ // avoid aliasing issues during the copy
+ this->operator+=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_ptr_matrix& operator-= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ // You can only assign to a set_ptrm() expression with a source matrix that
+ // contains the same type of elements as the target (i.e. you can't mix double
+ // and float types).
+ COMPILE_TIME_ASSERT((is_same_type<T, typename EXP::type>::value == true));
+
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_ptrm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(mat(ptr,height,width)) == false)
+ {
+ matrix_assign(*this, mat(ptr,height,width)-exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to ptr to
+ // avoid aliasing issues during the copy
+ this->operator-=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ assignable_ptr_matrix& operator= (
+ const T& value
+ )
+ {
+ const long size = width*height;
+ for (long i = 0; i < size; ++i)
+ ptr[i] = value;
+
+ return *this;
+ }
+
+ assignable_ptr_matrix& operator+= (
+ const T& value
+ )
+ {
+ const long size = width*height;
+ for (long i = 0; i < size; ++i)
+ ptr[i] += value;
+
+ return *this;
+ }
+
+ assignable_ptr_matrix& operator-= (
+ const T& value
+ )
+ {
+ const long size = width*height;
+ for (long i = 0; i < size; ++i)
+ ptr[i] -= value;
+
+ return *this;
+ }
+
+
+ T* ptr;
+ const long height;
+ const long width;
+ };
+
+
+ template <typename T>
+ assignable_ptr_matrix<T> set_ptrm (
+ T* ptr,
+ long nr,
+ long nc = 1
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\t assignable_matrix_expression set_ptrm(T* ptr, long nr, long nc)"
+ << "\n\t The dimensions can't be negative."
+ << "\n\t nr: " << nr
+ << "\n\t nc: " << nc
+ );
+
+
+ return assignable_ptr_matrix<T>(ptr,nr,nc);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ class assignable_sub_matrix
+ {
+ public:
+ typedef T type;
+ typedef l layout_type;
+ typedef matrix<T,NR,NC,mm,l> matrix_type;
+
+ assignable_sub_matrix(
+ matrix<T,NR,NC,mm,l>& m_,
+ long top_,
+ long left_,
+ long height_,
+ long width_
+ ) : m(m_), left(left_), top(top_), width(width_), height(height_) {}
+
+ T& operator() (
+ long r,
+ long c
+ )
+ {
+ return m(r+top,c+left);
+ }
+
+ const T& operator() (
+ long r,
+ long c
+ ) const
+ {
+ return m(r+top,c+left);
+ }
+
+ long nr() const { return height; }
+ long nc() const { return width; }
+
+ template <typename EXP>
+ assignable_sub_matrix& operator= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_subm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_sub_matrix& operator+= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_subm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, subm(m,top,left,height,width)+exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator+=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_sub_matrix& operator-= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
+ "\tassignable_matrix_expression set_subm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\twidth (target matrix): " << width
+ << "\n\theight (target matrix): " << height
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, subm(m,top,left,height,width)-exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator-=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ assignable_sub_matrix& operator= (
+ const T& value
+ )
+ {
+ const long bottom = top+height-1;
+ const long right = left+width-1;
+ for (long r = top; r <= bottom; ++r)
+ {
+ for (long c = left; c <= right; ++c)
+ {
+ m(r,c) = value;
+ }
+ }
+
+ return *this;
+ }
+
+ assignable_sub_matrix& operator+= (
+ const T& value
+ )
+ {
+ const long bottom = top+height-1;
+ const long right = left+width-1;
+ for (long r = top; r <= bottom; ++r)
+ {
+ for (long c = left; c <= right; ++c)
+ {
+ m(r,c) += value;
+ }
+ }
+
+ return *this;
+ }
+
+ assignable_sub_matrix& operator-= (
+ const T& value
+ )
+ {
+ const long bottom = top+height-1;
+ const long right = left+width-1;
+ for (long r = top; r <= bottom; ++r)
+ {
+ for (long c = left; c <= right; ++c)
+ {
+ m(r,c) -= value;
+ }
+ }
+
+ return *this;
+ }
+
+
+ matrix<T,NR,NC,mm,l>& m;
+ const long left, top, width, height;
+ };
+
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ assignable_sub_matrix<T,NR,NC,mm,l> set_subm (
+ matrix<T,NR,NC,mm,l>& m,
+ const rectangle& rect
+ )
+ {
+ DLIB_ASSERT(get_rect(m).contains(rect) == true,
+ "\tassignable_matrix_expression set_subm(matrix& m, const rectangle& rect)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trect.left(): " << rect.left()
+ << "\n\trect.top(): " << rect.top()
+ << "\n\trect.right(): " << rect.right()
+ << "\n\trect.bottom(): " << rect.bottom()
+ );
+
+
+ return assignable_sub_matrix<T,NR,NC,mm,l>(m,rect.top(), rect.left(), rect.height(), rect.width());
+ }
+
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ assignable_sub_matrix<T,NR,NC,mm,l> set_subm (
+ matrix<T,NR,NC,mm,l>& m,
+ long r,
+ long c,
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(r >= 0 && c >= 0 && nr >= 0 && nc >= 0 && r+nr <= m.nr() && c+nc <= m.nc(),
+ "\tassignable_matrix_expression set_subm(matrix& m, r, c, nr, nc)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tr: " << r
+ << "\n\tc: " << c
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+
+ return assignable_sub_matrix<T,NR,NC,mm,l>(m,r,c, nr, nc);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename mm, typename l, typename EXPr, typename EXPc>
+ class assignable_sub_range_matrix
+ {
+ public:
+ typedef T type;
+ typedef l layout_type;
+ typedef matrix<T,NR,NC,mm,l> matrix_type;
+
+ assignable_sub_range_matrix(
+ matrix<T,NR,NC,mm,l>& m_,
+ const EXPr& rows_,
+ const EXPc& cols_
+ ) : m(m_), rows(rows_), cols(cols_) {}
+
+ T& operator() (
+ long r,
+ long c
+ )
+ {
+ return m(rows(r),cols(c));
+ }
+
+ long nr() const { return rows.size(); }
+ long nc() const { return cols.size(); }
+
+
+ template <typename EXP>
+ assignable_sub_range_matrix& operator= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(),
+ "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\trows.size() (target matrix): " << rows.size()
+ << "\n\tcols.size() (target matrix): " << cols.size()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_sub_range_matrix& operator+= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(),
+ "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\trows.size() (target matrix): " << rows.size()
+ << "\n\tcols.size() (target matrix): " << cols.size()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, subm(m,rows,cols)+exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator+=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_sub_range_matrix& operator-= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(),
+ "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\trows.size() (target matrix): " << rows.size()
+ << "\n\tcols.size() (target matrix): " << cols.size()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, subm(m,rows,cols)-exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator-=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ assignable_sub_range_matrix& operator= (
+ const T& value
+ )
+ {
+ for (long r = 0; r < rows.size(); ++r)
+ {
+ for (long c = 0; c < cols.size(); ++c)
+ {
+ m(rows(r),cols(c)) = value;
+ }
+ }
+
+ return *this;
+ }
+
+ assignable_sub_range_matrix& operator+= (
+ const T& value
+ )
+ {
+ for (long r = 0; r < rows.size(); ++r)
+ {
+ for (long c = 0; c < cols.size(); ++c)
+ {
+ m(rows(r),cols(c)) += value;
+ }
+ }
+
+ return *this;
+ }
+
+ assignable_sub_range_matrix& operator-= (
+ const T& value
+ )
+ {
+ for (long r = 0; r < rows.size(); ++r)
+ {
+ for (long c = 0; c < cols.size(); ++c)
+ {
+ m(rows(r),cols(c)) -= value;
+ }
+ }
+
+ return *this;
+ }
+
+ private:
+
+ matrix<T,NR,NC,mm,l>& m;
+ const EXPr rows;
+ const EXPc cols;
+ };
+
+ template <typename T, long NR, long NC, typename mm, typename l, typename EXPr, typename EXPc>
+ assignable_sub_range_matrix<T,NR,NC,mm,l,EXPr,EXPc > set_subm (
+ matrix<T,NR,NC,mm,l>& m,
+ const matrix_exp<EXPr>& rows,
+ const matrix_exp<EXPc>& cols
+ )
+ {
+ DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && 0 <= min(cols) && max(cols) < m.nc() &&
+ (rows.nr() == 1 || rows.nc() == 1) && (cols.nr() == 1 || cols.nc() == 1),
+ "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp& rows, const matrix_exp& cols)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(rows): " << min(rows)
+ << "\n\tmax(rows): " << max(rows)
+ << "\n\tmin(cols): " << min(cols)
+ << "\n\tmax(cols): " << max(cols)
+ << "\n\trows.nr(): " << rows.nr()
+ << "\n\trows.nc(): " << rows.nc()
+ << "\n\tcols.nr(): " << cols.nr()
+ << "\n\tcols.nc(): " << cols.nc()
+ );
+
+ return assignable_sub_range_matrix<T,NR,NC,mm,l,EXPr,EXPc >(m,rows.ref(),cols.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename mm, typename l, typename EXPr>
+ assignable_sub_range_matrix<T,NR,NC,mm,l,EXPr,matrix_range_exp<long> > set_rowm (
+ matrix<T,NR,NC,mm,l>& m,
+ const matrix_exp<EXPr>& rows
+ )
+ {
+ DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
+ "\tassignable_matrix_expression set_rowm(matrix& m, const matrix_exp& rows)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(rows): " << min(rows)
+ << "\n\tmax(rows): " << max(rows)
+ << "\n\trows.nr(): " << rows.nr()
+ << "\n\trows.nc(): " << rows.nc()
+ );
+
+ return assignable_sub_range_matrix<T,NR,NC,mm,l,EXPr,matrix_range_exp<long> >(m,rows.ref(),range(0,m.nc()-1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename mm, typename l, typename EXPc>
+ assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_range_exp<long>,EXPc > set_colm (
+ matrix<T,NR,NC,mm,l>& m,
+ const matrix_exp<EXPc>& cols
+ )
+ {
+ DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
+ "\tassignable_matrix_expression set_colm(matrix& m, const matrix_exp& cols)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tmin(cols): " << min(cols)
+ << "\n\tmax(cols): " << max(cols)
+ << "\n\tcols.nr(): " << cols.nr()
+ << "\n\tcols.nc(): " << cols.nc()
+ );
+
+ return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_range_exp<long>,EXPc >(m,range(0,m.nr()-1),cols.ref());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ class assignable_col_matrix
+ {
+ public:
+ typedef T type;
+ typedef l layout_type;
+ typedef matrix<T,NR,NC,mm,l> matrix_type;
+
+ assignable_col_matrix(
+ matrix<T,NR,NC,mm,l>& m_,
+ const long col_
+ ) : m(m_), col(col_) {}
+
+ T& operator() (
+ long r,
+ long
+ )
+ {
+ return m(r,col);
+ }
+
+ const T& operator() (
+ long r,
+ long
+ ) const
+ {
+ return m(r,col);
+ }
+
+ long nr() const { return m.nr(); }
+ long nc() const { return 1; }
+
+ template <typename EXP>
+ assignable_col_matrix& operator= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(),
+ "\tassignable_matrix_expression set_colm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nr() (target matrix): " << m.nr()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_col_matrix& operator+= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(),
+ "\tassignable_matrix_expression set_colm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nr() (target matrix): " << m.nr()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, colm(m,col)+exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator+=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_col_matrix& operator-= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(),
+ "\tassignable_matrix_expression set_colm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nr() (target matrix): " << m.nr()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, colm(m,col)-exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator-=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ assignable_col_matrix& operator= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nr(); ++i)
+ {
+ m(i,col) = value;
+ }
+
+ return *this;
+ }
+
+ assignable_col_matrix& operator+= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nr(); ++i)
+ {
+ m(i,col) += value;
+ }
+
+ return *this;
+ }
+
+ assignable_col_matrix& operator-= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nr(); ++i)
+ {
+ m(i,col) -= value;
+ }
+
+ return *this;
+ }
+
+
+ matrix<T,NR,NC,mm,l>& m;
+ const long col;
+ };
+
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ assignable_col_matrix<T,NR,NC,mm,l> set_colm (
+ matrix<T,NR,NC,mm,l>& m,
+ const long col
+ )
+ {
+ DLIB_ASSERT(col >= 0 && col < m.nc(),
+ "\tassignable_matrix_expression set_colm(matrix& m, col)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tcol: " << col
+ );
+
+
+ return assignable_col_matrix<T,NR,NC,mm,l>(m,col);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ class assignable_row_matrix
+ {
+ public:
+ typedef T type;
+ typedef l layout_type;
+ typedef matrix<T,NR,NC,mm,l> matrix_type;
+
+ assignable_row_matrix(
+ matrix<T,NR,NC,mm,l>& m_,
+ const long row_
+ ) : m(m_), row(row_) {}
+
+
+ T& operator() (
+ long ,
+ long c
+ )
+ {
+ return m(row,c);
+ }
+
+ const T& operator() (
+ long ,
+ long c
+ ) const
+ {
+ return m(row,c);
+ }
+
+ long nr() const { return 1; }
+ long nc() const { return m.nc(); }
+
+
+ template <typename EXP>
+ assignable_row_matrix& operator= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(),
+ "\tassignable_matrix_expression set_rowm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nc() (target matrix): " << m.nc()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_row_matrix& operator+= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(),
+ "\tassignable_matrix_expression set_rowm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nc() (target matrix): " << m.nc()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, rowm(m,row)+exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator+=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ template <typename EXP>
+ assignable_row_matrix& operator-= (
+ const matrix_exp<EXP>& exp
+ )
+ {
+ DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(),
+ "\tassignable_matrix_expression set_rowm()"
+ << "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
+ << "\n\texp.nr() (source matrix): " << exp.nr()
+ << "\n\texp.nc() (source matrix): " << exp.nc()
+ << "\n\tm.nc() (target matrix): " << m.nc()
+ );
+
+ if (exp.destructively_aliases(m) == false)
+ {
+ matrix_assign(*this, rowm(m,row)-exp);
+ }
+ else
+ {
+ // make a temporary copy of the matrix we are going to assign to m to
+ // avoid aliasing issues during the copy
+ this->operator-=(tmp(exp));
+ }
+
+ return *this;
+ }
+
+ assignable_row_matrix& operator= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nc(); ++i)
+ {
+ m(row,i) = value;
+ }
+
+ return *this;
+ }
+
+ assignable_row_matrix& operator+= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nc(); ++i)
+ {
+ m(row,i) += value;
+ }
+
+ return *this;
+ }
+
+ assignable_row_matrix& operator-= (
+ const T& value
+ )
+ {
+ for (long i = 0; i < m.nc(); ++i)
+ {
+ m(row,i) -= value;
+ }
+
+ return *this;
+ }
+
+
+ matrix<T,NR,NC,mm,l>& m;
+ const long row;
+ };
+
+
+ template <typename T, long NR, long NC, typename mm, typename l>
+ assignable_row_matrix<T,NR,NC,mm,l> set_rowm (
+ matrix<T,NR,NC,mm,l>& m,
+ const long row
+ )
+ {
+ DLIB_ASSERT(row >= 0 && row < m.nr(),
+ "\tassignable_matrix_expression set_rowm(matrix& m, row)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trow: " << row
+ );
+
+
+ return assignable_row_matrix<T,NR,NC,mm,l>(m,row);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_SUBEXP_
+
diff --git a/ml/dlib/dlib/matrix/matrix_subexp_abstract.h b/ml/dlib/dlib/matrix/matrix_subexp_abstract.h
new file mode 100644
index 000000000..2665d1b99
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_subexp_abstract.h
@@ -0,0 +1,570 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_SUBEXP_ABSTRACT_
+#ifdef DLIB_MATRIx_SUBEXP_ABSTRACT_
+
+#include "matrix_abstract.h"
+#include "../geometry/rectangle.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <long start, long inc, long end>
+ const matrix_exp range (
+ );
+ /*!
+ requires
+ - inc > 0
+ ensures
+ - returns a matrix R such that:
+ - R::type == long
+ - R.nr() == 1
+ - R.nc() == abs(end - start)/inc + 1
+ - if (start <= end) then
+ - R(i) == start + i*inc
+ - else
+ - R(i) == start - i*inc
+ !*/
+
+ template <long start, long end>
+ const matrix_exp range (
+ ) { return range<start,1,end>(); }
+
+ const matrix_exp range (
+ long start,
+ long inc,
+ long end
+ );
+ /*!
+ requires
+ - inc > 0
+ ensures
+ - returns a matrix R such that:
+ - R::type == long
+ - R.nr() == 1
+ - R.nc() == abs(end - start)/inc + 1
+ - if (start <= end) then
+ - R(i) == start + i*inc
+ - else
+ - R(i) == start - i*inc
+ !*/
+
+ const matrix_exp range (
+ long start,
+ long end
+ ) { return range(start,1,end); }
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp subm (
+ const matrix_exp& m,
+ const matrix_exp& rows,
+ const matrix_exp& cols,
+ );
+ /*!
+ requires
+ - rows and cols contain integral elements (e.g. int, long)
+ - 0 <= min(rows) && max(rows) < m.nr()
+ - 0 <= min(cols) && max(cols) < m.nc()
+ - rows.nr() == 1 || rows.nc() == 1
+ - cols.nr() == 1 || cols.nc() == 1
+ (i.e. rows and cols must be vectors)
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R.nr() == rows.size()
+ - R.nc() == cols.size()
+ - for all valid r and c:
+ R(r,c) == m(rows(r),cols(c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp subm (
+ const matrix_exp& m,
+ long row,
+ long col,
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - row >= 0
+ - col >= 0
+ - nr >= 0
+ - nc >= 0
+ - row + nr <= m.nr()
+ - col + nc <= m.nc()
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == nr
+ - R.nc() == nc
+ - for all valid r and c:
+ R(r, c) == m(r+row,c+col)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp subm (
+ const matrix_exp& m,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - get_rect(m).contains(rect) == true
+ (i.e. rect is a region inside the matrix m)
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == rect.height()
+ - R.nc() == rect.width()
+ - for all valid r and c:
+ R(r, c) == m(r+rect.top(), c+rect.left())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp subm_clipped (
+ const matrix_exp& m,
+ long row,
+ long col,
+ long nr,
+ long nc
+ );
+ /*!
+ ensures
+ - This function is just like subm() except that it will automatically clip the
+ indicated sub matrix window so that it does not extend outside m.
+ In particular:
+ - Let box = rectangle(col,row,col+nc-1,row+nr-1)
+ (i.e. the box that contains the indicated sub matrix)
+ - Let box_clipped = box.intersect(get_rect(m))
+ - Then this function returns a matrix R such that:
+ - R.nr() == box_clipped.height()
+ - R.nc() == box_clipped.width()
+ - for all valid r and c:
+ R(r, c) == m(r+box_clipped.top(),c+box_clipped.left())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp subm_clipped (
+ const matrix_exp& m,
+ const rectangle& rect
+ );
+ /*!
+ ensures
+ - Let box_clipped == rect.intersect(get_rect(m))
+ - returns a matrix R such that:
+ - R.nr() == box_clipped.height()
+ - R.nc() == box_clipped.width()
+ - for all valid r and c:
+ R(r, c) == m(r+box_clipped.top(), c+box_clipped.left())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp rowm (
+ const matrix_exp& m,
+ long row
+ );
+ /*!
+ requires
+ - 0 <= row < m.nr()
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == 1
+ - R.nc() == m.nc()
+ - for all valid i:
+ R(i) == m(row,i)
+ !*/
+
+ template <typename EXP>
+ struct rowm_exp
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This struct allows you to determine the type of matrix expression
+ object returned from the rowm(m,row) function. An example makes its
+ use clear:
+
+ template <typename EXP>
+ void do_something( const matrix_exp<EXP>& mat)
+ {
+ // r is a matrix expression that aliases mat.
+ typename rowm_exp<EXP>::type r = rowm(mat,0);
+
+ // Print the first row of mat. So we see that by using
+ // rowm_exp we can save the object returned by rowm() in
+ // a local variable.
+ cout << r << endl;
+
+ // Note that you can only save the return value of rowm() to
+ // a local variable if the argument to rowm() has a lifetime
+ // beyond the rowm() expression. The example shown above is
+ // OK but the following would result in undefined behavior:
+ typename rowm_exp<EXP>::type bad = rowm(mat + mat,0);
+ }
+ !*/
+ typedef type_of_expression_returned_by_rowm type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp rowm (
+ const matrix_exp& m,
+ long row,
+ long length
+ );
+ /*!
+ requires
+ - 0 <= row < m.nr()
+ - 0 <= length <= m.nc()
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == 1
+ - R.nc() == length
+ - for all valid i:
+ R(i) == m(row,i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp rowm (
+ const matrix_exp& m,
+ const matrix_exp& rows
+ );
+ /*!
+ requires
+ - rows contains integral elements (e.g. int, long)
+ - 0 <= min(rows) && max(rows) < m.nr()
+ - rows.nr() == 1 || rows.nc() == 1
+ (i.e. rows must be a vector)
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R.nr() == rows.size()
+ - R.nc() == m.nc()
+ - for all valid r and c:
+ R(r,c) == m(rows(r),c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp colm (
+ const matrix_exp& m,
+ long col
+ );
+ /*!
+ requires
+ - 0 <= col < m.nc()
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == m.nr()
+ - R.nc() == 1
+ - for all valid i:
+ R(i) == m(i,col)
+ !*/
+
+ template <typename EXP>
+ struct colm_exp
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This struct allows you to determine the type of matrix expression
+ object returned from the colm(m,col) function. An example makes its
+ use clear:
+
+ template <typename EXP>
+ void do_something( const matrix_exp<EXP>& mat)
+ {
+ // c is a matrix expression that aliases mat.
+ typename colm_exp<EXP>::type c = colm(mat,0);
+
+ // Print the first column of mat. So we see that by using
+ // colm_exp we can save the object returned by colm() in
+ // a local variable.
+ cout << c << endl;
+
+ // Note that you can only save the return value of colm() to
+ // a local variable if the argument to colm() has a lifetime
+ // beyond the colm() expression. The example shown above is
+ // OK but the following would result in undefined behavior:
+ typename colm_exp<EXP>::type bad = colm(mat + mat,0);
+ }
+ !*/
+ typedef type_of_expression_returned_by_colm type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp colm (
+ const matrix_exp& m,
+ long col,
+ long length
+ );
+ /*!
+ requires
+ - 0 <= col < m.nc()
+ - 0 <= length <= m.nr()
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == length
+ - R.nc() == 1
+ - for all valid i:
+ R(i) == m(i,col)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp colm (
+ const matrix_exp& m,
+ const matrix_exp& cols
+ );
+ /*!
+ requires
+ - cols contains integral elements (e.g. int, long)
+ - 0 <= min(cols) && max(cols) < m.nc()
+ - cols.nr() == 1 || cols.nc() == 1
+ (i.e. cols must be a vector)
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R.nr() == m.nr()
+ - R.nc() == cols.size()
+ - for all valid r and c:
+ R(r,c) == m(r,cols(c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ assignable_matrix_expression set_ptrm (
+ T* ptr,
+ long nr,
+ long nc = 1
+ );
+ /*!
+ requires
+ - ptr == a pointer to nr*nc elements of type T
+ - nr >= 0
+ - nc >= 0
+ ensures
+ - statements of the following form:
+ - set_ptrm(ptr,nr,nc) = some_matrix;
+ result in it being the case that:
+ - mat(ptr,nr,nc) == some_matrix.
+
+ - statements of the following form:
+ - set_ptrm(ptr,nr,nc) = scalar_value;
+ result in it being the case that:
+ - mat(ptr,nr,nc) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_subm (
+ matrix& m,
+ long row,
+ long col,
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - row >= 0
+ - col >= 0
+ - nr >= 0
+ - nc >= 0
+ - row + nr <= m.nr()
+ - col + nc <= m.nc()
+ ensures
+ - statements of the following form:
+ - set_subm(m,row,col,nr,nc) = some_matrix;
+ result in it being the case that:
+ - subm(m,row,col,nr,nc) == some_matrix.
+
+ - statements of the following form:
+ - set_subm(m,row,col,nr,nc) = scalar_value;
+ result in it being the case that:
+ - subm(m,row,col,nr,nc) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_subm (
+ matrix& m,
+ const rectangle& rect
+ );
+ /*!
+ requires
+ - get_rect(m).contains(rect) == true
+ (i.e. rect is a region inside the matrix m)
+ ensures
+ - statements of the following form:
+ - set_subm(m,rect) = some_matrix;
+ result in it being the case that:
+ - subm(m,rect) == some_matrix.
+
+ - statements of the following form:
+ - set_subm(m,rect) = scalar_value;
+ result in it being the case that:
+ - subm(m,rect) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_subm (
+ matrix& m,
+ const matrix_exp& rows,
+ const matrix_exp& cols
+ );
+ /*!
+ requires
+ - rows and cols contain integral elements (e.g. int, long)
+ - 0 <= min(rows) && max(rows) < m.nr()
+ - 0 <= min(cols) && max(cols) < m.nc()
+ - rows.nr() == 1 || rows.nc() == 1
+ - cols.nr() == 1 || cols.nc() == 1
+ (i.e. rows and cols must be vectors)
+ ensures
+ - statements of the following form:
+ - set_subm(m,rows,cols) = some_matrix;
+ result in it being the case that:
+ - subm(m,rows,cols) == some_matrix.
+
+ - statements of the following form:
+ - set_subm(m,rows,cols) = scalar_value;
+ result in it being the case that:
+ - subm(m,rows,cols) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_rowm (
+ matrix& m,
+ long row
+ );
+ /*!
+ requires
+ - 0 <= row < m.nr()
+ ensures
+ - statements of the following form:
+ - set_rowm(m,row) = some_matrix;
+ result in it being the case that:
+ - rowm(m,row) == some_matrix.
+
+ - statements of the following form:
+ - set_rowm(m,row) = scalar_value;
+ result in it being the case that:
+ - rowm(m,row) == uniform_matrix<matrix::type>(1,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_rowm (
+ matrix& m,
+ const matrix_exp& rows
+ );
+ /*!
+ requires
+ - rows contains integral elements (e.g. int, long)
+ - 0 <= min(rows) && max(rows) < m.nr()
+ - rows.nr() == 1 || rows.nc() == 1
+ (i.e. rows must be a vector)
+ ensures
+ - statements of the following form:
+ - set_rowm(m,rows) = some_matrix;
+ result in it being the case that:
+ - rowm(m,rows) == some_matrix.
+
+ - statements of the following form:
+ - set_rowm(m,rows) = scalar_value;
+ result in it being the case that:
+ - rowm(m,rows) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_colm (
+ matrix& m,
+ long col
+ );
+ /*!
+ requires
+ - 0 <= col < m.nr()
+ ensures
+ - statements of the following form:
+ - set_colm(m,col) = some_matrix;
+ result in it being the case that:
+ - colm(m,col) == some_matrix.
+
+ - statements of the following form:
+ - set_colm(m,col) = scalar_value;
+ result in it being the case that:
+ - colm(m,col) == uniform_matrix<matrix::type>(nr,1,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ assignable_matrix_expression set_colm (
+ matrix& m,
+ const matrix_exp& cols
+ );
+ /*!
+ requires
+ - cols contains integral elements (e.g. int, long)
+ - 0 <= min(cols) && max(cols) < m.nc()
+ - cols.nr() == 1 || cols.nc() == 1
+ (i.e. cols must be a vector)
+ ensures
+ - statements of the following form:
+ - set_colm(m,cols) = some_matrix;
+ result in it being the case that:
+ - colm(m,cols) == some_matrix.
+
+ - statements of the following form:
+ - set_colm(m,cols) = scalar_value;
+ result in it being the case that:
+ - colm(m,cols) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
+
+ - In addition to the normal assignment statements using the = symbol, you may
+ also use the usual += and -= versions of the assignment operator. In these
+ cases, they have their usual effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_SUBEXP_ABSTRACT_
+
diff --git a/ml/dlib/dlib/matrix/matrix_trsm.h b/ml/dlib/dlib/matrix/matrix_trsm.h
new file mode 100644
index 000000000..ef5ec5ed9
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_trsm.h
@@ -0,0 +1,654 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRiX_TRSM_Hh_
+#define DLIB_MATRiX_TRSM_Hh_
+#include "lapack/fortran_id.h"
+#include "cblas_constants.h"
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+#ifdef DLIB_USE_BLAS
+#ifndef CBLAS_H
+ extern "C"
+ {
+ void cblas_strsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb);
+
+ void cblas_dtrsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb);
+ }
+#endif // if not CBLAS_H
+#endif // if DLIB_USE_BLAS
+
+ // ------------------------------------------------------------------------------------
+
+/* Purpose */
+/* ======= */
+
+/* DTRSM solves one of the matrix equations */
+
+/* op( A )*X = alpha*B, or X*op( A ) = alpha*B, */
+
+/* where alpha is a scalar, X and B are m by n matrices, A is a unit, or */
+/* non-unit, upper or lower triangular matrix and op( A ) is one of */
+
+/* op( A ) = A or op( A ) = A'. */
+
+/* The matrix X is overwritten on B. */
+
+/* Arguments */
+/* ========== */
+
+/* SIDE - CHARACTER*1. */
+/* On entry, SIDE specifies whether op( A ) appears on the left */
+/* or right of X as follows: */
+
+/* SIDE = 'L' or 'l' op( A )*X = alpha*B. */
+
+/* SIDE = 'R' or 'r' X*op( A ) = alpha*B. */
+
+/* Unchanged on exit. */
+
+/* UPLO - CHARACTER*1. */
+/* On entry, UPLO specifies whether the matrix A is an upper or */
+/* lower triangular matrix as follows: */
+
+/* UPLO = 'U' or 'u' A is an upper triangular matrix. */
+
+/* UPLO = 'L' or 'l' A is a lower triangular matrix. */
+
+/* Unchanged on exit. */
+
+/* TRANSA - CHARACTER*1. */
+/* On entry, TRANSA specifies the form of op( A ) to be used in */
+/* the matrix multiplication as follows: */
+
+/* TRANSA = 'N' or 'n' op( A ) = A. */
+
+/* TRANSA = 'T' or 't' op( A ) = A'. */
+
+/* TRANSA = 'C' or 'c' op( A ) = A'. */
+
+/* Unchanged on exit. */
+
+/* DIAG - CHARACTER*1. */
+/* On entry, DIAG specifies whether or not A is unit triangular */
+/* as follows: */
+
+/* DIAG = 'U' or 'u' A is assumed to be unit triangular. */
+
+/* DIAG = 'N' or 'n' A is not assumed to be unit */
+/* triangular. */
+
+/* Unchanged on exit. */
+
+/* M - INTEGER. */
+/* On entry, M specifies the number of rows of B. M must be at */
+/* least zero. */
+/* Unchanged on exit. */
+
+/* N - INTEGER. */
+/* On entry, N specifies the number of columns of B. N must be */
+/* at least zero. */
+/* Unchanged on exit. */
+
+/* ALPHA - DOUBLE PRECISION. */
+/* On entry, ALPHA specifies the scalar alpha. When alpha is */
+/* zero then A is not referenced and B need not be set before */
+/* entry. */
+/* Unchanged on exit. */
+
+/* A - DOUBLE PRECISION array of DIMENSION ( LDA, k ), where k is m */
+/* when SIDE = 'L' or 'l' and is n when SIDE = 'R' or 'r'. */
+/* Before entry with UPLO = 'U' or 'u', the leading k by k */
+/* upper triangular part of the array A must contain the upper */
+/* triangular matrix and the strictly lower triangular part of */
+/* A is not referenced. */
+/* Before entry with UPLO = 'L' or 'l', the leading k by k */
+/* lower triangular part of the array A must contain the lower */
+/* triangular matrix and the strictly upper triangular part of */
+/* A is not referenced. */
+/* Note that when DIAG = 'U' or 'u', the diagonal elements of */
+/* A are not referenced either, but are assumed to be unity. */
+/* Unchanged on exit. */
+
+/* LDA - INTEGER. */
+/* On entry, LDA specifies the first dimension of A as declared */
+/* in the calling (sub) program. When SIDE = 'L' or 'l' then */
+/* LDA must be at least max( 1, m ), when SIDE = 'R' or 'r' */
+/* then LDA must be at least max( 1, n ). */
+/* Unchanged on exit. */
+
+/* B - DOUBLE PRECISION array of DIMENSION ( LDB, n ). */
+/* Before entry, the leading m by n part of the array B must */
+/* contain the right-hand side matrix B, and on exit is */
+/* overwritten by the solution matrix X. */
+
+/* LDB - INTEGER. */
+/* On entry, LDB specifies the first dimension of B as declared */
+/* in the calling (sub) program. LDB must be at least */
+/* max( 1, m ). */
+/* Unchanged on exit. */
+
+
+/* Level 3 Blas routine. */
+
+
+/* -- Written on 8-February-1989. */
+/* Jack Dongarra, Argonne National Laboratory. */
+/* Iain Duff, AERE Harwell. */
+/* Jeremy Du Croz, Numerical Algorithms Group Ltd. */
+/* Sven Hammarling, Numerical Algorithms Group Ltd. */
+
+ template <typename T>
+ void local_trsm(
+ const CBLAS_ORDER Order,
+ CBLAS_SIDE Side,
+ CBLAS_UPLO Uplo,
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag,
+ long m,
+ long n,
+ T alpha,
+ const T *a,
+ long lda,
+ T *b,
+ long ldb
+ )
+ /*!
+ This is a copy of the dtrsm routine from the netlib.org BLAS which was run though
+ f2c and converted into this form for use when a BLAS library is not available.
+ !*/
+ {
+ if (Order == CblasRowMajor)
+ {
+ // since row major ordering looks like transposition to FORTRAN we need to flip a
+ // few things.
+ if (Side == CblasLeft)
+ Side = CblasRight;
+ else
+ Side = CblasLeft;
+
+ if (Uplo == CblasUpper)
+ Uplo = CblasLower;
+ else
+ Uplo = CblasUpper;
+
+ std::swap(m,n);
+ }
+
+ /* System generated locals */
+ long a_dim1, a_offset, b_dim1, b_offset, i__1, i__2, i__3;
+
+ /* Local variables */
+ long i__, j, k, info;
+ T temp;
+ bool lside;
+ long nrowa;
+ bool upper;
+ bool nounit;
+
+ /* Parameter adjustments */
+ a_dim1 = lda;
+ a_offset = 1 + a_dim1;
+ a -= a_offset;
+ b_dim1 = ldb;
+ b_offset = 1 + b_dim1;
+ b -= b_offset;
+
+ /* Function Body */
+ lside = (Side == CblasLeft);
+ if (lside)
+ {
+ nrowa = m;
+ } else
+ {
+ nrowa = n;
+ }
+ nounit = (Diag == CblasNonUnit);
+ upper = (Uplo == CblasUpper);
+
+ info = 0;
+ if (! lside && ! (Side == CblasRight)) {
+ info = 1;
+ } else if (! upper && !(Uplo == CblasLower) ) {
+ info = 2;
+ } else if (!(TransA == CblasNoTrans) &&
+ !(TransA == CblasTrans) &&
+ !(TransA == CblasConjTrans)) {
+ info = 3;
+ } else if (!(Diag == CblasUnit) &&
+ !(Diag == CblasNonUnit) ) {
+ info = 4;
+ } else if (m < 0) {
+ info = 5;
+ } else if (n < 0) {
+ info = 6;
+ } else if (lda < std::max<long>(1,nrowa)) {
+ info = 9;
+ } else if (ldb < std::max<long>(1,m)) {
+ info = 11;
+ }
+ DLIB_CASSERT( info == 0, "Invalid inputs given to local_trsm");
+
+ /* Quick return if possible. */
+
+ if (m == 0 || n == 0) {
+ return;
+ }
+
+ /* And when alpha.eq.zero. */
+
+ if (alpha == 0.) {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] = 0.;
+ /* L10: */
+ }
+ /* L20: */
+ }
+ return;
+ }
+
+ /* Start the operations. */
+
+ if (lside) {
+ if (TransA == CblasNoTrans) {
+
+ /* Form B := alpha*inv( A )*B. */
+
+ if (upper) {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ if (alpha != 1.) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
+ ;
+ /* L30: */
+ }
+ }
+ for (k = m; k >= 1; --k) {
+ if (b[k + j * b_dim1] != 0.) {
+ if (nounit) {
+ b[k + j * b_dim1] /= a[k + k * a_dim1];
+ }
+ i__2 = k - 1;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[
+ i__ + k * a_dim1];
+ /* L40: */
+ }
+ }
+ /* L50: */
+ }
+ /* L60: */
+ }
+ } else {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ if (alpha != 1.) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
+ ;
+ /* L70: */
+ }
+ }
+ i__2 = m;
+ for (k = 1; k <= i__2; ++k) {
+ if (b[k + j * b_dim1] != 0.) {
+ if (nounit) {
+ b[k + j * b_dim1] /= a[k + k * a_dim1];
+ }
+ i__3 = m;
+ for (i__ = k + 1; i__ <= i__3; ++i__) {
+ b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[
+ i__ + k * a_dim1];
+ /* L80: */
+ }
+ }
+ /* L90: */
+ }
+ /* L100: */
+ }
+ }
+ } else {
+
+ /* Form B := alpha*inv( A' )*B. */
+
+ if (upper) {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ temp = alpha * b[i__ + j * b_dim1];
+ i__3 = i__ - 1;
+ for (k = 1; k <= i__3; ++k) {
+ temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1];
+ /* L110: */
+ }
+ if (nounit) {
+ temp /= a[i__ + i__ * a_dim1];
+ }
+ b[i__ + j * b_dim1] = temp;
+ /* L120: */
+ }
+ /* L130: */
+ }
+ } else {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ for (i__ = m; i__ >= 1; --i__) {
+ temp = alpha * b[i__ + j * b_dim1];
+ i__2 = m;
+ for (k = i__ + 1; k <= i__2; ++k) {
+ temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1];
+ /* L140: */
+ }
+ if (nounit) {
+ temp /= a[i__ + i__ * a_dim1];
+ }
+ b[i__ + j * b_dim1] = temp;
+ /* L150: */
+ }
+ /* L160: */
+ }
+ }
+ }
+ } else {
+ if (TransA == CblasNoTrans) {
+
+ /* Form B := alpha*B*inv( A ). */
+
+ if (upper) {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ if (alpha != 1.) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
+ ;
+ /* L170: */
+ }
+ }
+ i__2 = j - 1;
+ for (k = 1; k <= i__2; ++k) {
+ if (a[k + j * a_dim1] != 0.) {
+ i__3 = m;
+ for (i__ = 1; i__ <= i__3; ++i__) {
+ b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[
+ i__ + k * b_dim1];
+ /* L180: */
+ }
+ }
+ /* L190: */
+ }
+ if (nounit) {
+ temp = 1. / a[j + j * a_dim1];
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1];
+ /* L200: */
+ }
+ }
+ /* L210: */
+ }
+ } else {
+ for (j = n; j >= 1; --j) {
+ if (alpha != 1.) {
+ i__1 = m;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
+ ;
+ /* L220: */
+ }
+ }
+ i__1 = n;
+ for (k = j + 1; k <= i__1; ++k) {
+ if (a[k + j * a_dim1] != 0.) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[
+ i__ + k * b_dim1];
+ /* L230: */
+ }
+ }
+ /* L240: */
+ }
+ if (nounit) {
+ temp = 1. / a[j + j * a_dim1];
+ i__1 = m;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1];
+ /* L250: */
+ }
+ }
+ /* L260: */
+ }
+ }
+ } else {
+
+ /* Form B := alpha*B*inv( A' ). */
+
+ if (upper) {
+ for (k = n; k >= 1; --k) {
+ if (nounit) {
+ temp = 1. / a[k + k * a_dim1];
+ i__1 = m;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1];
+ /* L270: */
+ }
+ }
+ i__1 = k - 1;
+ for (j = 1; j <= i__1; ++j) {
+ if (a[j + k * a_dim1] != 0.) {
+ temp = a[j + k * a_dim1];
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + j * b_dim1] -= temp * b[i__ + k *
+ b_dim1];
+ /* L280: */
+ }
+ }
+ /* L290: */
+ }
+ if (alpha != 1.) {
+ i__1 = m;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1]
+ ;
+ /* L300: */
+ }
+ }
+ /* L310: */
+ }
+ } else {
+ i__1 = n;
+ for (k = 1; k <= i__1; ++k) {
+ if (nounit) {
+ temp = 1. / a[k + k * a_dim1];
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1];
+ /* L320: */
+ }
+ }
+ i__2 = n;
+ for (j = k + 1; j <= i__2; ++j) {
+ if (a[j + k * a_dim1] != 0.) {
+ temp = a[j + k * a_dim1];
+ i__3 = m;
+ for (i__ = 1; i__ <= i__3; ++i__) {
+ b[i__ + j * b_dim1] -= temp * b[i__ + k *
+ b_dim1];
+ /* L330: */
+ }
+ }
+ /* L340: */
+ }
+ if (alpha != 1.) {
+ i__2 = m;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1]
+ ;
+ /* L350: */
+ }
+ }
+ /* L360: */
+ }
+ }
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag, const int M, const int N,
+ const float alpha, const float *A, const int lda,
+ float *B, const int ldb)
+ {
+#ifdef DLIB_USE_BLAS
+ if (M > 4)
+ {
+ cblas_strsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
+ return;
+ }
+#endif
+ local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
+ }
+
+ inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag, const int M, const int N,
+ const double alpha, const double *A, const int lda,
+ double *B, const int ldb)
+ {
+#ifdef DLIB_USE_BLAS
+ if (M > 4)
+ {
+ cblas_dtrsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
+ return;
+ }
+#endif
+ local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
+ }
+
+ inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag, const int M, const int N,
+ const long double alpha, const long double *A, const int lda,
+ long double *B, const int ldb)
+ {
+ local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ inline void triangular_solver (
+ const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo,
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag,
+ const matrix<T,NR1,NC1,MM,row_major_layout>& A,
+ const T alpha,
+ matrix<T,NR2,NC2,MM,row_major_layout>& B
+ )
+ {
+ cblas_trsm(CblasRowMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(),
+ alpha, &A(0,0), A.nc(), &B(0,0), B.nc());
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ inline void triangular_solver (
+ const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo,
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag,
+ const matrix<T,NR1,NC1,MM,column_major_layout>& A,
+ const T alpha,
+ matrix<T,NR2,NC2,MM,column_major_layout>& B
+ )
+ {
+ cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(),
+ alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM
+ >
+ inline void triangular_solver (
+ const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo,
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag,
+ const matrix<T,NR1,NC1,MM,column_major_layout>& A,
+ matrix<T,NR2,NC2,MM,column_major_layout>& B,
+ long rows_of_B
+ )
+ {
+ const T alpha = 1;
+ cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, rows_of_B, B.nc(),
+ alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR1, long NR2,
+ long NC1, long NC2,
+ typename MM,
+ typename layout
+ >
+ inline void triangular_solver (
+ const CBLAS_SIDE Side,
+ const CBLAS_UPLO Uplo,
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_DIAG Diag,
+ const matrix<T,NR1,NC1,MM,layout>& A,
+ matrix<T,NR2,NC2,MM,layout>& B
+ )
+ {
+ const T alpha = 1;
+ triangular_solver(Side, Uplo, TransA, Diag, A, alpha, B);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_MATRiX_TRSM_Hh_
+
diff --git a/ml/dlib/dlib/matrix/matrix_utilities.h b/ml/dlib/dlib/matrix/matrix_utilities.h
new file mode 100644
index 000000000..0c5091a4b
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_utilities.h
@@ -0,0 +1,4544 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MATRIx_UTILITIES_
+#define DLIB_MATRIx_UTILITIES_
+
+#include "matrix_utilities_abstract.h"
+#include "matrix.h"
+#include <cmath>
+#include <complex>
+#include <limits>
+#include "../pixel.h"
+#include "../stl_checked.h"
+#include <vector>
+#include <algorithm>
+#include "../std_allocator.h"
+#include "matrix_expressions.h"
+#include "matrix_math_functions.h"
+#include "matrix_op.h"
+#include "../general_hash/random_hashing.h"
+#include "matrix_mat.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A is_complex
+ This is a template that can be used to determine if a type is a specialization
+ of the std::complex template class.
+
+ For example:
+ is_complex<float>::value == false
+ is_complex<std::complex<float> >::value == true
+ !*/
+
+ template <typename T>
+ struct is_complex { static const bool value = false; };
+
+ template <typename T>
+ struct is_complex<std::complex<T> > { static const bool value = true; };
+ template <typename T>
+ struct is_complex<std::complex<T>& > { static const bool value = true; };
+ template <typename T>
+ struct is_complex<const std::complex<T>& > { static const bool value = true; };
+ template <typename T>
+ struct is_complex<const std::complex<T> > { static const bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ inline bool is_row_vector (
+ const matrix_exp<EXP>& m
+ ) { return m.nr() == 1; }
+
+ template <typename EXP>
+ inline bool is_col_vector (
+ const matrix_exp<EXP>& m
+ ) { return m.nc() == 1; }
+
+ template <typename EXP>
+ inline bool is_vector (
+ const matrix_exp<EXP>& m
+ ) { return is_row_vector(m) || is_col_vector(m); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ inline bool is_finite (
+ const matrix_exp<EXP>& m
+ )
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ if (!is_finite(m(r,c)))
+ return false;
+ }
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ const T& magnitude (const T& item) { return item; }
+ template <typename T>
+ T magnitude (const std::complex<T>& item) { return std::norm(item); }
+ }
+
+ template <
+ typename EXP
+ >
+ void find_min_and_max (
+ const matrix_exp<EXP>& m,
+ typename EXP::type& min_val,
+ typename EXP::type& max_val
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\ttype find_min_and_max(const matrix_exp& m, min_val, max_val)"
+ << "\n\tYou can't ask for the min and max of an empty matrix"
+ << "\n\tm.size(): " << m.size()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ min_val = m(0,0);
+ max_val = min_val;
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ type temp = m(r,c);
+ if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(max_val))
+ max_val = temp;
+ if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(min_val))
+ min_val = temp;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ point max_point (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\tpoint max_point(const matrix_exp& m)"
+ << "\n\tm can't be empty"
+ << "\n\tm.size(): " << m.size()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ point best_point(0,0);
+ type val = m(0,0);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ type temp = m(r,c);
+ if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val))
+ {
+ val = temp;
+ best_point = point(c,r);
+ }
+ }
+ }
+ return best_point;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ point min_point (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\tpoint min_point(const matrix_exp& m)"
+ << "\n\tm can't be empty"
+ << "\n\tm.size(): " << m.size()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ point best_point(0,0);
+ type val = m(0,0);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ type temp = m(r,c);
+ if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val))
+ {
+ val = temp;
+ best_point = point(c,r);
+ }
+ }
+ }
+ return best_point;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ long index_of_max (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0 && is_vector(m) == true,
+ "\tlong index_of_max(const matrix_exp& m)"
+ << "\n\tm must be a row or column matrix"
+ << "\n\tm.size(): " << m.size()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = m(0);
+ long best_idx = 0;
+ for (long i = 1; i < m.size(); ++i)
+ {
+ type temp = m(i);
+ if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val))
+ {
+ val = temp;
+ best_idx = i;
+ }
+ }
+ return best_idx;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ long index_of_min (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0 && is_vector(m),
+ "\tlong index_of_min(const matrix_exp& m)"
+ << "\n\tm must be a row or column matrix"
+ << "\n\tm.size(): " << m.size()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = m(0);
+ long best_idx = 0;
+ for (long i = 1; i < m.size(); ++i)
+ {
+ type temp = m(i);
+ if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val))
+ {
+ val = temp;
+ best_idx = i;
+ }
+ }
+ return best_idx;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type max (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\ttype max(const matrix_exp& m)"
+ << "\n\tYou can't ask for the max() of an empty matrix"
+ << "\n\tm.size(): " << m.size()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = m(0,0);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ type temp = m(r,c);
+ if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val))
+ val = temp;
+ }
+ }
+ return val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type min (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0,
+ "\ttype min(const matrix_exp& m)"
+ << "\n\tYou can't ask for the min() of an empty matrix"
+ << "\n\tm.size(): " << m.size()
+ );
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = m(0,0);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ type temp = m(r,c);
+ if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val))
+ val = temp;
+ }
+ }
+ return val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_binary_min : basic_op_mm<M1,M2>
+ {
+ op_binary_min( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
+
+ typedef typename M1::type type;
+ typedef const type const_ret_type;
+ const static long cost = M1::cost + M2::cost + 1;
+
+ const_ret_type apply ( long r, long c) const
+ { return std::min(this->m1(r,c),this->m2(r,c)); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_binary_min<EXP1,EXP2> > min_pointwise (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc(),
+ "\t const matrix_exp min_pointwise(const matrix_exp& a, const matrix_exp& b)"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ );
+ typedef op_binary_min<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct op_min_pointwise3 : basic_op_mmm<M1,M2,M3>
+ {
+ op_min_pointwise3( const M1& m1_, const M2& m2_, const M3& m3_) :
+ basic_op_mmm<M1,M2,M3>(m1_,m2_,m3_){}
+
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ const static long cost = M1::cost + M2::cost + M3::cost + 2;
+
+ const_ret_type apply (long r, long c) const
+ { return std::min(this->m1(r,c),std::min(this->m2(r,c),this->m3(r,c))); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ inline const matrix_op<op_min_pointwise3<EXP1,EXP2,EXP3> >
+ min_pointwise (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const matrix_exp<EXP3>& c
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0);
+ COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
+ COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc() &&
+ b.nr() == c.nr() &&
+ b.nc() == c.nc(),
+ "\tconst matrix_exp min_pointwise(a,b,c)"
+ << "\n\tYou can only make a do a pointwise min between equally sized matrices"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ << "\n\tc.nr(): " << c.nr()
+ << "\n\tc.nc(): " << c.nc()
+ );
+
+ typedef op_min_pointwise3<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(a.ref(),b.ref(),c.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_binary_max : basic_op_mm<M1,M2>
+ {
+ op_binary_max( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
+
+ typedef typename M1::type type;
+ typedef const type const_ret_type;
+ const static long cost = M1::cost + M2::cost + 1;
+
+ const_ret_type apply ( long r, long c) const
+ { return std::max(this->m1(r,c),this->m2(r,c)); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_binary_max<EXP1,EXP2> > max_pointwise (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc(),
+ "\t const matrix_exp max_pointwise(const matrix_exp& a, const matrix_exp& b)"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ );
+ typedef op_binary_max<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct op_max_pointwise3 : basic_op_mmm<M1,M2,M3>
+ {
+ op_max_pointwise3( const M1& m1_, const M2& m2_, const M3& m3_) :
+ basic_op_mmm<M1,M2,M3>(m1_,m2_,m3_){}
+
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ const static long cost = M1::cost + M2::cost + M3::cost + 2;
+
+ const_ret_type apply (long r, long c) const
+ { return std::max(this->m1(r,c),std::max(this->m2(r,c),this->m3(r,c))); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ inline const matrix_op<op_max_pointwise3<EXP1,EXP2,EXP3> >
+ max_pointwise (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const matrix_exp<EXP3>& c
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0);
+ COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
+ COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc() &&
+ b.nr() == c.nr() &&
+ b.nc() == c.nc(),
+ "\tconst matrix_exp max_pointwise(a,b,c)"
+ << "\n\tYou can only make a do a pointwise max between equally sized matrices"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ << "\n\tc.nr(): " << c.nr()
+ << "\n\tc.nc(): " << c.nc()
+ );
+
+ typedef op_max_pointwise3<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(a.ref(),b.ref(),c.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ typename enable_if_c<std::numeric_limits<typename EXP::type>::is_integer, double>::type length (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(is_vector(m) == true,
+ "\ttype length(const matrix_exp& m)"
+ << "\n\tm must be a row or column vector"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+
+ return std::sqrt(static_cast<double>(sum(squared(m))));
+ }
+
+ template <
+ typename EXP
+ >
+ typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, const typename EXP::type>::type length (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(is_vector(m) == true,
+ "\ttype length(const matrix_exp& m)"
+ << "\n\tm must be a row or column vector"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ return std::sqrt(sum(squared(m)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type length_squared (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(is_vector(m) == true,
+ "\ttype length_squared(const matrix_exp& m)"
+ << "\n\tm must be a row or column vector"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ return sum(squared(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_trans
+ {
+ op_trans( const M& m_) : m(m_){}
+
+ const M& m;
+
+ const static long cost = M::cost;
+ const static long NR = M::NC;
+ const static long NC = M::NR;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const { return m(c,r); }
+
+ long nr () const { return m.nc(); }
+ long nc () const { return m.nr(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+
+ };
+
+ template <
+ typename M
+ >
+ const matrix_op<op_trans<M> > trans (
+ const matrix_exp<M>& m
+ )
+ {
+ typedef op_trans<M> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// don't to anything at all for diagonal matrices
+ template <
+ typename M
+ >
+ const matrix_diag_exp<M>& trans (
+ const matrix_diag_exp<M>& m
+ )
+ {
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// I introduced this struct because it avoids an inane compiler warning from gcc
+ template <typename EXP>
+ struct is_not_ct_vector{ static const bool value = (EXP::NR != 1 && EXP::NC != 1); };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ typename enable_if_c<(is_not_ct_vector<EXP1>::value) || (is_not_ct_vector<EXP2>::value),
+ typename EXP1::type>::type
+ dot (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ // You are getting an error on this line because you are trying to
+ // compute the dot product between two matrices that aren't both vectors (i.e.
+ // they aren't column or row matrices).
+ COMPILE_TIME_ASSERT(EXP1::NR*EXP1::NC == 0 ||
+ EXP2::NR*EXP2::NC == 0);
+
+ DLIB_ASSERT(is_vector(m1) && is_vector(m2) && m1.size() == m2.size() &&
+ m1.size() > 0,
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between non-empty vectors of equal length."
+ << "\n\t is_vector(m1): " << is_vector(m1)
+ << "\n\t is_vector(m2): " << is_vector(m2)
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ if (is_col_vector(m1) && is_col_vector(m2)) return (trans(m1)*m2)(0);
+ if (is_col_vector(m1) && is_row_vector(m2)) return (m2*m1)(0);
+ if (is_row_vector(m1) && is_col_vector(m2)) return (m1*m2)(0);
+
+ //if (is_row_vector(m1) && is_row_vector(m2))
+ return (m1*trans(m2))(0);
+ }
+
+ template < typename EXP1, typename EXP2 >
+ typename enable_if_c<EXP1::NR == 1 && EXP2::NR == 1 && EXP1::NC != 1 && EXP2::NC != 1, typename EXP1::type>::type
+ dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
+ {
+ DLIB_ASSERT(m1.size() == m2.size(),
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between vectors of equal length"
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ return m1*trans(m2);
+ }
+
+ template < typename EXP1, typename EXP2 >
+ typename enable_if_c<EXP1::NR == 1 && EXP2::NC == 1 && EXP1::NC != 1 && EXP2::NR != 1, typename EXP1::type>::type
+ dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
+ {
+ DLIB_ASSERT(m1.size() == m2.size(),
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between vectors of equal length"
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ return m1*m2;
+ }
+
+ template < typename EXP1, typename EXP2 >
+ typename enable_if_c<EXP1::NC == 1 && EXP2::NR == 1 && EXP1::NR != 1 && EXP2::NC != 1, typename EXP1::type>::type
+ dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
+ {
+ DLIB_ASSERT(m1.size() == m2.size(),
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between vectors of equal length"
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ return m2*m1;
+ }
+
+ template < typename EXP1, typename EXP2 >
+ typename enable_if_c<EXP1::NC == 1 && EXP2::NC == 1 && EXP1::NR != 1 && EXP2::NR != 1, typename EXP1::type>::type
+ dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
+ {
+ DLIB_ASSERT(m1.size() == m2.size(),
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between vectors of equal length"
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ return trans(m1)*m2;
+ }
+
+ template < typename EXP1, typename EXP2 >
+ typename enable_if_c<(EXP1::NC*EXP1::NR == 1) || (EXP2::NC*EXP2::NR == 1), typename EXP1::type>::type
+ dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
+ {
+ DLIB_ASSERT(m1.size() == m2.size(),
+ "\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
+ << "\n\t You can only compute the dot product between vectors of equal length"
+ << "\n\t m1.size(): " << m1.size()
+ << "\n\t m2.size(): " << m2.size()
+ );
+
+ return m1(0)*m2(0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, long R, long C>
+ struct op_removerc
+ {
+ op_removerc( const M& m_) : m(m_){}
+
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
+ const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply (long r, long c) const
+ {
+ if (r < R)
+ {
+ if (c < C)
+ return m(r,c);
+ else
+ return m(r,c+1);
+ }
+ else
+ {
+ if (c < C)
+ return m(r+1,c);
+ else
+ return m(r+1,c+1);
+ }
+ }
+
+ long nr () const { return m.nr() - 1; }
+ long nc () const { return m.nc() - 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <typename M>
+ struct op_removerc2
+ {
+ op_removerc2( const M& m_, const long R_, const long C_) : m(m_), R(R_), C(C_){}
+ const M& m;
+ const long R;
+ const long C;
+
+ const static long cost = M::cost+2;
+ const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
+ const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply (long r, long c) const
+ {
+ if (r < R)
+ {
+ if (c < C)
+ return m(r,c);
+ else
+ return m(r,c+1);
+ }
+ else
+ {
+ if (c < C)
+ return m(r+1,c);
+ else
+ return m(r+1,c+1);
+ }
+ }
+
+ long nr () const { return m.nr() - 1; }
+ long nc () const { return m.nc() - 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ long R,
+ long C,
+ typename EXP
+ >
+ const matrix_op<op_removerc<EXP,R,C> > removerc (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can't remove a row from a matrix with only one row
+ COMPILE_TIME_ASSERT((EXP::NR > R && R >= 0) || EXP::NR == 0);
+ // you can't remove a column from a matrix with only one column
+ COMPILE_TIME_ASSERT((EXP::NC > C && C >= 0) || EXP::NR == 0);
+ DLIB_ASSERT(m.nr() > R && R >= 0 && m.nc() > C && C >= 0,
+ "\tconst matrix_exp removerc<R,C>(const matrix_exp& m)"
+ << "\n\tYou can't remove a row/column from a matrix if it doesn't have that row/column"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tR: " << R
+ << "\n\tC: " << C
+ );
+ typedef op_removerc<EXP,R,C> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_removerc2<EXP> > removerc (
+ const matrix_exp<EXP>& m,
+ long R,
+ long C
+ )
+ {
+ DLIB_ASSERT(m.nr() > R && R >= 0 && m.nc() > C && C >= 0,
+ "\tconst matrix_exp removerc(const matrix_exp& m,R,C)"
+ << "\n\tYou can't remove a row/column from a matrix if it doesn't have that row/column"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tR: " << R
+ << "\n\tC: " << C
+ );
+ typedef op_removerc2<EXP> op;
+ return matrix_op<op>(op(m.ref(),R,C));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, long C>
+ struct op_remove_col
+ {
+ op_remove_col( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = M::NR;
+ const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (c < C)
+ {
+ return m(r,c);
+ }
+ else
+ {
+ return m(r,c+1);
+ }
+ }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc() - 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <typename M>
+ struct op_remove_col2
+ {
+ op_remove_col2( const M& m_, const long C_) : m(m_), C(C_){}
+ const M& m;
+ const long C;
+
+ const static long cost = M::cost+2;
+ const static long NR = M::NR;
+ const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (c < C)
+ {
+ return m(r,c);
+ }
+ else
+ {
+ return m(r,c+1);
+ }
+ }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc() - 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ long C,
+ typename EXP
+ >
+ const matrix_op<op_remove_col<EXP, C> > remove_col (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You can't remove the given column from the matrix because the matrix doesn't
+ // have a column with that index.
+ COMPILE_TIME_ASSERT((EXP::NC > C && C >= 0) || EXP::NC == 0);
+ DLIB_ASSERT(m.nc() > C && C >= 0 ,
+ "\tconst matrix_exp remove_col<C>(const matrix_exp& m)"
+ << "\n\tYou can't remove a col from a matrix if it doesn't have it"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tC: " << C
+ );
+ typedef op_remove_col<EXP,C> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_remove_col2<EXP> > remove_col (
+ const matrix_exp<EXP>& m,
+ long C
+ )
+ {
+ DLIB_ASSERT(m.nc() > C && C >= 0 ,
+ "\tconst matrix_exp remove_col(const matrix_exp& m,C)"
+ << "\n\tYou can't remove a col from a matrix if it doesn't have it"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tC: " << C
+ );
+ typedef op_remove_col2<EXP> op;
+ return matrix_op<op>(op(m.ref(),C));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, long R>
+ struct op_remove_row
+ {
+ op_remove_row( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r < R)
+ {
+ return m(r,c);
+ }
+ else
+ {
+ return m(r+1,c);
+ }
+ }
+
+ long nr () const { return m.nr() - 1; }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <typename M>
+ struct op_remove_row2
+ {
+ op_remove_row2( const M& m_, const long R_) : m(m_), R(R_){}
+ const M& m;
+ const long R;
+
+ const static long cost = M::cost+2;
+ const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r < R)
+ {
+ return m(r,c);
+ }
+ else
+ {
+ return m(r+1,c);
+ }
+ }
+
+ long nr () const { return m.nr() - 1; }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ long R,
+ typename EXP
+ >
+ const matrix_op<op_remove_row<EXP,R> > remove_row (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You can't remove the given row from the matrix because the matrix doesn't
+ // have a row with that index.
+ COMPILE_TIME_ASSERT((EXP::NR > R && R >= 0) || EXP::NR == 0);
+ DLIB_ASSERT(m.nr() > R && R >= 0,
+ "\tconst matrix_exp remove_row<R>(const matrix_exp& m)"
+ << "\n\tYou can't remove a row from a matrix if it doesn't have it"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tR: " << R
+ );
+ typedef op_remove_row<EXP,R> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_remove_row2<EXP> > remove_row (
+ const matrix_exp<EXP>& m,
+ long R
+ )
+ {
+ DLIB_ASSERT(m.nr() > R && R >= 0,
+ "\tconst matrix_exp remove_row(const matrix_exp& m, long R)"
+ << "\n\tYou can't remove a row from a matrix if it doesn't have it"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tR: " << R
+ );
+ typedef op_remove_row2<EXP> op;
+ return matrix_op<op>(op(m.ref(),R));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_diagm
+ {
+ op_diagm( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long N = M::NC*M::NR;
+ const static long NR = N;
+ const static long NC = N;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r==c)
+ return m(r);
+ else
+ return 0;
+ }
+
+ long nr () const { return (m.nr()>m.nc())? m.nr():m.nc(); }
+ long nc () const { return (m.nr()>m.nc())? m.nr():m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_diag_op<op_diagm<EXP> > diagm (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // You can only make a diagonal matrix out of a row or column vector
+ COMPILE_TIME_ASSERT(EXP::NR == 0 || EXP::NR == 1 || EXP::NC == 1 || EXP::NC == 0);
+ DLIB_ASSERT(is_vector(m),
+ "\tconst matrix_exp diagm(const matrix_exp& m)"
+ << "\n\tYou can only apply diagm() to a row or column matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef op_diagm<EXP> op;
+ return matrix_diag_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_diagm_mult : basic_op_mm<M1,M2>
+ {
+ op_diagm_mult( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
+
+ typedef typename M1::type type;
+ typedef const type const_ret_type;
+ const static long cost = M1::cost + M2::cost + 1;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r == c)
+ return this->m1(r,c)*this->m2(r,c);
+ else
+ return 0;
+ }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_diag_op<op_diagm_mult<EXP1,EXP2> > operator* (
+ const matrix_diag_exp<EXP1>& a,
+ const matrix_diag_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type, typename EXP2::type>::value));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc(),
+ "\tconst matrix_exp operator(const matrix_diag_exp& a, const matrix_diag_exp& b)"
+ << "\n\tYou can only multiply diagonal matrices together if they are the same size"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ );
+ typedef op_diagm_mult<EXP1,EXP2> op;
+ return matrix_diag_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_diag
+ {
+ op_diag( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost;
+ const static long NR = tmin<M::NR,M::NC>::value;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long ) const { return m(r,r); }
+
+ long nr () const { return std::min(m.nc(),m.nr()); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_diag<EXP> > diag (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_diag<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <typename EXP>
+ struct diag_exp
+ {
+ typedef matrix_op<op_diag<EXP> > type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename target_type>
+ struct op_cast
+ {
+ op_cast( const M& m_) : m(m_){}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef target_type type;
+ typedef const target_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const { return static_cast<target_type>(m(r,c)); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.destructively_aliases(item); }
+ };
+
+ template <
+ typename target_type,
+ typename EXP
+ >
+ const matrix_op<op_cast<EXP, target_type> > matrix_cast (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_cast<EXP, target_type> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type lessthan(const type& val, const S& s)
+ {
+ if (val < s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_lessthan, impl::lessthan, 1);
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_lessthan<EXP,S> > >::type operator< (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_lessthan<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_lessthan<EXP,S> > >::type operator> (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_lessthan<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type lessthan_eq(const type& val, const S& s)
+ {
+ if (val <= s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_lessthan_eq, impl::lessthan_eq, 1);
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_lessthan_eq<EXP,S> > >::type operator<= (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_lessthan_eq<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_lessthan_eq<EXP,S> > >::type operator>= (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_lessthan_eq<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type greaterthan(const type& val, const S& s)
+ {
+ if (val > s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_greaterthan, impl::greaterthan, 1);
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_greaterthan<EXP,S> > >::type operator> (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_greaterthan<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_greaterthan<EXP,S> > >::type operator< (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_greaterthan<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type greaterthan_eq(const type& val, const S& s)
+ {
+ if (val >= s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_greaterthan_eq, impl::greaterthan_eq, 1);
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_greaterthan_eq<EXP,S> > >::type operator>= (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_greaterthan_eq<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_greaterthan_eq<EXP,S> > >::type operator<= (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_greaterthan_eq<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type equal_to(const type& val, const S& s)
+ {
+ if (val == s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_equal_to, impl::equal_to, 1);
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_equal_to<EXP,S> > >::type operator== (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT( is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_equal_to<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_equal_to<EXP,S> > >::type operator== (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT( is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_equal_to<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename type, typename S>
+ inline type not_equal_to(const type& val, const S& s)
+ {
+ if (val != s)
+ return 1;
+ else
+ return 0;
+ }
+
+ }
+ DLIB_DEFINE_OP_MS(op_not_equal_to, impl::not_equal_to, 1);
+
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_not_equal_to<EXP,S> > >::type operator!= (
+ const matrix_exp<EXP>& m,
+ const S& s
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_not_equal_to<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+ template <
+ typename EXP,
+ typename S
+ >
+ const typename enable_if<is_built_in_scalar_type<S>, matrix_op<op_not_equal_to<EXP,S> > >::type operator!= (
+ const S& s,
+ const matrix_exp<EXP>& m
+ )
+ {
+ // you can only use this relational operator with the built in scalar types like
+ // long, float, etc.
+ COMPILE_TIME_ASSERT(is_built_in_scalar_type<typename EXP::type>::value);
+
+ typedef op_not_equal_to<EXP,S> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename U,
+ typename L
+ >
+ typename disable_if<is_matrix<U>,void>::type set_all_elements (
+ matrix<T,NR,NC,MM,L>& m,
+ const U& value
+ )
+ {
+ // The value you are trying to assign to each element of the m matrix
+ // doesn't have the appropriate type.
+ COMPILE_TIME_ASSERT(is_matrix<T>::value == is_matrix<U>::value);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = static_cast<T>(value);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename U,
+ typename L
+ >
+ typename enable_if<is_matrix<U>,void>::type set_all_elements (
+ matrix<T,NR,NC,MM,L>& m,
+ const U& value
+ )
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = value;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ inline const typename matrix_exp<EXP>::matrix_type tmp (
+ const matrix_exp<EXP>& m
+ )
+ {
+ return typename matrix_exp<EXP>::matrix_type (m);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ constexpr bool is_row_major (
+ const matrix_exp<EXP>&
+ )
+ {
+ return is_same_type<typename EXP::layout_type,row_major_layout>::value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename lazy_disable_if<is_matrix<typename EXP::type>, EXP>::type sum (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = 0;
+ if (is_row_major(m))
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ val += m(r,c);
+ }
+ }
+ }
+ else
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ val += m(r,c);
+ }
+ }
+ }
+ return val;
+ }
+
+ template <
+ typename EXP
+ >
+ const typename lazy_enable_if<is_matrix<typename EXP::type>, EXP>::type sum (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val;
+ if (m.size() > 0)
+ val.set_size(m(0,0).nr(),m(0,0).nc());
+ set_all_elements(val,0);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ val += m(r,c);
+ }
+ }
+ return val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_sumr
+ {
+ op_sumr(const M& m_) : m(m_) {}
+ const M& m;
+
+ const static long cost = M::cost+10;
+ const static long NR = 1;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long , long c) const
+ {
+ type temp = m(0,c);
+ for (long r = 1; r < m.nr(); ++r)
+ temp += m(r,c);
+ return temp;
+ }
+
+ long nr () const { return 1; }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_sumr<EXP> > sum_rows (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0 ,
+ "\tconst matrix_exp sum_rows(m)"
+ << "\n\t The matrix can't be empty"
+ << "\n\t m.size(): " << m.size()
+ );
+ typedef op_sumr<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_sumc
+ {
+ op_sumc(const M& m_) : m(m_) {}
+ const M& m;
+
+ const static long cost = M::cost + 10;
+ const static long NR = M::NR;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long ) const
+ {
+ type temp = m(r,0);
+ for (long c = 1; c < m.nc(); ++c)
+ temp += m(r,c);
+ return temp;
+ }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_sumc<EXP> > sum_cols (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.size() > 0 ,
+ "\tconst matrix_exp sum_cols(m)"
+ << "\n\t The matrix can't be empty"
+ << "\n\t m.size(): " << m.size()
+ );
+ typedef op_sumc<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ inline const typename disable_if<is_complex<typename EXP::type>, typename matrix_exp<EXP>::type>::type mean (
+ const matrix_exp<EXP>& m
+ )
+ {
+ return sum(m)/(m.nr()*m.nc());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ inline const typename enable_if<is_complex<typename EXP::type>, typename matrix_exp<EXP>::type>::type mean (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename EXP::type::value_type type;
+ return sum(m)/(type)(m.nr()*m.nc());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type variance (
+ const matrix_exp<EXP>& m
+ )
+ {
+ using std::pow;
+ using dlib::pow;
+ const typename matrix_exp<EXP>::type avg = mean(m);
+
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val;
+ val = 0;
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ val += pow(m(r,c) - avg,2);
+ }
+ }
+
+ if (m.nr() * m.nc() <= 1)
+ {
+ return val;
+ }
+ else
+ {
+ // Note, for some reason, in gcc 4.1 performing this division using a
+ // double instead of a long value avoids a segmentation fault. That is,
+ // using 1.0 instead of 1 does the trick.
+ return val/(m.nr()*m.nc() - 1.0);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type stddev (
+ const matrix_exp<EXP>& m
+ )
+ {
+ using std::sqrt;
+ using dlib::sqrt;
+ return sqrt(variance(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// this is a workaround for a bug in visual studio 7.1
+ template <typename EXP>
+ struct visual_studio_sucks_cov_helper
+ {
+ typedef typename EXP::type inner_type;
+ typedef matrix<typename inner_type::type, inner_type::NR, inner_type::NR, typename EXP::mem_manager_type> type;
+ };
+
+ template <
+ typename EXP
+ >
+ const typename visual_studio_sucks_cov_helper<EXP>::type covariance (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // perform static checks to make sure m is a column vector
+ COMPILE_TIME_ASSERT(EXP::NR == 0 || EXP::NR > 1);
+ COMPILE_TIME_ASSERT(EXP::NC == 1 || EXP::NC == 0);
+
+ // perform static checks to make sure the matrices contained in m are column vectors
+ COMPILE_TIME_ASSERT(EXP::type::NC == 1 || EXP::type::NC == 0 );
+
+ DLIB_ASSERT(m.size() > 1 && is_col_vector(m),
+ "\tconst matrix covariance(const matrix_exp& m)"
+ << "\n\tYou can only apply covariance() to a column matrix"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+#ifdef ENABLE_ASSERTS
+ for (long i = 0; i < m.nr(); ++i)
+ {
+ DLIB_ASSERT(m(0).size() == m(i).size() && m(i).size() > 0 && is_col_vector(m(i)),
+ "\tconst matrix covariance(const matrix_exp& m)"
+ << "\n\tYou can only apply covariance() to a column matrix of column matrices"
+ << "\n\tm(0).size(): " << m(0).size()
+ << "\n\tm(i).size(): " << m(i).size()
+ << "\n\tis_col_vector(m(i)): " << (is_col_vector(m(i)) ? "true" : "false")
+ << "\n\ti: " << i
+ );
+ }
+#endif
+
+ // now perform the actual calculation of the covariance matrix.
+ typename visual_studio_sucks_cov_helper<EXP>::type cov(m(0).nr(),m(0).nr());
+ set_all_elements(cov,0);
+
+ const typename EXP::type avg = mean(m);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ cov += (m(r) - avg)*trans(m(r) - avg);
+ }
+
+ cov *= 1.0 / (m.nr() - 1.0);
+ return cov;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const typename matrix_exp<EXP>::type prod (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef typename matrix_exp<EXP>::type type;
+
+ type val = 1;
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ val *= m(r,c);
+ }
+ }
+ return val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct op_uniform_matrix_3 : does_not_alias
+ {
+ op_uniform_matrix_3(const long& rows_, const long& cols_, const T& val_ ) :
+ rows(rows_), cols(cols_), val(val_) {}
+
+ const long rows;
+ const long cols;
+ const T val;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T& const_ret_type;
+ const_ret_type apply (long, long ) const { return val; }
+
+ long nr() const { return rows; }
+ long nc() const { return cols; }
+ };
+
+ template <
+ typename T
+ >
+ const matrix_op<op_uniform_matrix_3<T> > uniform_matrix (
+ long nr,
+ long nc,
+ const T& val
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tconst matrix_exp uniform_matrix<T>(nr, nc, val)"
+ << "\n\tnr and nc have to be bigger than 0"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+ typedef op_uniform_matrix_3<T> op;
+ return matrix_op<op>(op(nr, nc, val));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_uniform_matrix_3<T> > zeros_matrix (
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tconst matrix_exp zeros_matrix<T>(nr, nc)"
+ << "\n\tnr and nc have to be >= 0"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+ typedef op_uniform_matrix_3<T> op;
+ return matrix_op<op>(op(nr, nc, 0));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_uniform_matrix_3<typename EXP::type> > zeros_matrix (
+ const matrix_exp<EXP>& mat
+ )
+ {
+ DLIB_ASSERT(mat.nr() >= 0 && mat.nc() >= 0,
+ "\tconst matrix_exp zeros_matrix(mat)"
+ << "\n\t nr and nc have to be >= 0"
+ << "\n\t mat.nr(): " << mat.nr()
+ << "\n\t mat.nc(): " << mat.nc()
+ );
+ typedef typename EXP::type T;
+ typedef op_uniform_matrix_3<T> op;
+ return matrix_op<op>(op(mat.nr(), mat.nc(), 0));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_uniform_matrix_3<T> > ones_matrix (
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tconst matrix_exp ones_matrix<T>(nr, nc)"
+ << "\n\tnr and nc have to be >= 0"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+ typedef op_uniform_matrix_3<T> op;
+ return matrix_op<op>(op(nr, nc, 1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_uniform_matrix_3<typename EXP::type> > ones_matrix (
+ const matrix_exp<EXP>& mat
+ )
+ {
+ DLIB_ASSERT(mat.nr() >= 0 && mat.nc() >= 0,
+ "\tconst matrix_exp ones_matrix(mat)"
+ << "\n\t nr and nc have to be >= 0"
+ << "\n\t mat.nr(): " << mat.nr()
+ << "\n\t mat.nc(): " << mat.nc()
+ );
+ typedef typename EXP::type T;
+ typedef op_uniform_matrix_3<T> op;
+ return matrix_op<op>(op(mat.nr(), mat.nc(), 1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR_,
+ long NC_
+ >
+ struct op_uniform_matrix_2 : does_not_alias
+ {
+ op_uniform_matrix_2( const T& val_ ) : val(val_) {}
+ const T val;
+
+ const static long cost = 1;
+ const static long NR = NR_;
+ const static long NC = NC_;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T& const_ret_type;
+
+ const_ret_type apply (long , long ) const { return val; }
+
+ long nr() const { return NR; }
+ long nc() const { return NC; }
+ };
+
+ template <
+ typename T,
+ long NR,
+ long NC
+ >
+ const matrix_op<op_uniform_matrix_2<T,NR,NC> > uniform_matrix (
+ const T& val
+ )
+ {
+ COMPILE_TIME_ASSERT(NR > 0 && NC > 0);
+
+ typedef op_uniform_matrix_2<T,NR,NC> op;
+ return matrix_op<op>(op(val));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR_,
+ long NC_,
+ T val
+ >
+ struct op_uniform_matrix : does_not_alias
+ {
+ const static long cost = 1;
+ const static long NR = NR_;
+ const static long NC = NC_;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T const_ret_type;
+ const_ret_type apply ( long , long ) const { return val; }
+
+ long nr() const { return NR; }
+ long nc() const { return NC; }
+ };
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ T val
+ >
+ const matrix_op<op_uniform_matrix<T,NR,NC,val> > uniform_matrix (
+ )
+ {
+ COMPILE_TIME_ASSERT(NR > 0 && NC > 0);
+ typedef op_uniform_matrix<T,NR,NC,val> op;
+ return matrix_op<op>(op());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ struct op_gaussian_randm : does_not_alias
+ {
+ op_gaussian_randm (
+ long nr_,
+ long nc_,
+ unsigned long seed_
+ ) :_nr(nr_), _nc(nc_), seed(seed_){}
+
+ const long _nr;
+ const long _nc;
+ const unsigned long seed;
+
+ const static long cost = 100;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef double type;
+ typedef double const_ret_type;
+ const_ret_type apply ( long r, long c) const { return gaussian_random_hash(r,c,seed); }
+
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ };
+
+ inline const matrix_op<op_gaussian_randm> gaussian_randm (
+ long nr,
+ long nc,
+ unsigned long seed = 0
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tmatrix_exp gaussian_randm(nr, nc, seed)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+
+ typedef op_gaussian_randm op;
+ return matrix_op<op>(op(nr,nc,seed));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_add_diag
+ {
+ op_add_diag( const M& m_, const typename M::type& value_) : m(m_), value(value_){}
+ const M& m;
+ const typename M::type value;
+
+ const static long cost = M::cost+1;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r==c)
+ return m(r,c)+value;
+ else
+ return m(r,c);
+ }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.destructively_aliases(item); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct op_identity_matrix_2 : does_not_alias
+ {
+ op_identity_matrix_2(const long& size_) : size(size_) {}
+
+ const long size;
+
+ const static long cost = 1;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T const_ret_type;
+ const_ret_type apply (long r, long c) const { return static_cast<type>(r == c); }
+
+ long nr() const { return size; }
+ long nc() const { return size; }
+ };
+
+ template <
+ typename T,
+ typename U
+ >
+ const matrix_diag_op<op_identity_matrix_2<T> > identity_matrix (
+ const U& size
+ )
+ {
+ // the size argument must be some scalar value, not a matrix!
+ COMPILE_TIME_ASSERT(is_matrix<U>::value == false);
+
+ DLIB_ASSERT(size > 0,
+ "\tconst matrix_exp identity_matrix<T>(size)"
+ << "\n\tsize must be bigger than 0"
+ << "\n\tsize: " << size
+ );
+ typedef op_identity_matrix_2<T> op;
+ return matrix_diag_op<op>(op(size));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_diag_op<op_identity_matrix_2<typename EXP::type> > identity_matrix (
+ const matrix_exp<EXP>& mat
+ )
+ {
+ DLIB_ASSERT(mat.nr() == mat.nc(),
+ "\tconst matrix_exp identity_matrix(mat)"
+ << "\n\t mat must be a square matrix."
+ << "\n\t mat.nr(): " << mat.nr()
+ << "\n\t mat.nc(): " << mat.nc()
+ );
+ typedef typename EXP::type T;
+ typedef op_identity_matrix_2<T> op;
+ return matrix_diag_op<op>(op(mat.nr()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<EXP>& lhs,
+ const matrix_exp<matrix_diag_op<op_identity_matrix_2<T> > >& DLIB_IF_ASSERT(rhs)
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(lhs.ref(),1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<matrix_diag_op<op_identity_matrix_2<T> > >& DLIB_IF_ASSERT(lhs),
+ const matrix_exp<EXP>& rhs
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(rhs.ref(),1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long N
+ >
+ struct op_const_diag_matrix : does_not_alias
+ {
+ op_const_diag_matrix(const long& size_, const T& value_) : size(size_),value(value_) {}
+
+ const long size;
+ const T value;
+
+ const static long cost = 1;
+ const static long NR = N;
+ const static long NC = N;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T const_ret_type;
+ const_ret_type apply (long r, long c) const
+ {
+ if (r == c)
+ return value;
+ else
+ return 0;
+ }
+
+ long nr() const { return size; }
+ long nc() const { return size; }
+ };
+
+ template <
+ typename T,
+ typename U
+ >
+ const typename disable_if<is_matrix<U>, matrix_diag_op<op_const_diag_matrix<T,0> > >::type operator* (
+ const matrix_exp<matrix_diag_op<op_identity_matrix_2<T> > >& m,
+ const U& value
+ )
+ {
+ typedef op_const_diag_matrix<T,0> op;
+ return matrix_diag_op<op>(op(m.nr(), value));
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ const typename disable_if<is_matrix<U>, matrix_diag_op<op_const_diag_matrix<T,0> > >::type operator* (
+ const U& value,
+ const matrix_exp<matrix_diag_op<op_identity_matrix_2<T> > >& m
+ )
+ {
+ typedef op_const_diag_matrix<T,0> op;
+ return matrix_diag_op<op>(op(m.nr(), value));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T,
+ long N
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<EXP>& lhs,
+ const matrix_exp<matrix_diag_op<op_const_diag_matrix<T,N> > >& rhs
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(lhs.ref(),rhs.ref().op.value));
+ }
+
+ template <
+ typename EXP,
+ typename T,
+ long N
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<matrix_diag_op<op_const_diag_matrix<T,N> > >& lhs,
+ const matrix_exp<EXP>& rhs
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(rhs.ref(),lhs.ref().op.value));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long N
+ >
+ struct op_identity_matrix : does_not_alias
+ {
+ const static long cost = 1;
+ const static long NR = N;
+ const static long NC = N;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+ typedef T type;
+ typedef const T const_ret_type;
+ const_ret_type apply ( long r, long c) const { return static_cast<type>(r == c); }
+
+ long nr () const { return NR; }
+ long nc () const { return NC; }
+ };
+
+ template <
+ typename T,
+ long N
+ >
+ const matrix_diag_op<op_identity_matrix<T,N> > identity_matrix (
+ )
+ {
+ COMPILE_TIME_ASSERT(N > 0);
+
+ typedef op_identity_matrix<T,N> op;
+ return matrix_diag_op<op>(op());
+ }
+
+ template <
+ typename T,
+ typename U,
+ long N
+ >
+ const typename disable_if<is_matrix<U>, matrix_diag_op<op_const_diag_matrix<T,N> > >::type operator* (
+ const matrix_exp<matrix_diag_op<op_identity_matrix<T,N> > >& m,
+ const U& value
+ )
+ {
+ typedef op_const_diag_matrix<T,N> op;
+ return matrix_diag_op<op>(op(m.nr(), value));
+ }
+
+ template <
+ typename T,
+ typename U,
+ long N
+ >
+ const typename disable_if<is_matrix<U>, matrix_diag_op<op_const_diag_matrix<T,N> > >::type operator* (
+ const U& value,
+ const matrix_exp<matrix_diag_op<op_identity_matrix<T,N> > >& m
+ )
+ {
+ typedef op_const_diag_matrix<T,N> op;
+ return matrix_diag_op<op>(op(m.nr(), value));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T,
+ long N
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<matrix_diag_op<op_identity_matrix<T,N> > >& DLIB_IF_ASSERT(lhs),
+ const matrix_exp<EXP>& rhs
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(rhs.ref(),1));
+ }
+
+ template <
+ typename EXP,
+ typename T,
+ long N
+ >
+ const matrix_op<op_add_diag<EXP> > operator+ (
+ const matrix_exp<EXP>& lhs,
+ const matrix_exp<matrix_diag_op<op_identity_matrix<T,N> > >& DLIB_IF_ASSERT(rhs)
+ )
+ {
+ // both matrices must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<T,typename EXP::type>::value == true));
+
+ // You can only add matrices together if they both have the same number of rows and columns.
+ DLIB_ASSERT(lhs.nc() == rhs.nc() &&
+ lhs.nr() == rhs.nr(),
+ "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)"
+ << "\n\tYou are trying to add two incompatible matrices together"
+ << "\n\tlhs.nr(): " << lhs.nr()
+ << "\n\tlhs.nc(): " << lhs.nc()
+ << "\n\trhs.nr(): " << rhs.nr()
+ << "\n\trhs.nc(): " << rhs.nc()
+ << "\n\t&lhs: " << &lhs
+ << "\n\t&rhs: " << &rhs
+ );
+
+
+ typedef op_add_diag<EXP> op;
+ return matrix_op<op>(op(lhs.ref(),1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, long R, long C>
+ struct op_rotate
+ {
+ op_rotate(const M& m_) : m(m_) {}
+ const M& m;
+
+ const static long cost = M::cost + 2;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ const_ret_type apply ( long r, long c) const { return m((r+R)%m.nr(),(c+C)%m.nc()); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ long R,
+ long C,
+ typename EXP
+ >
+ const matrix_op<op_rotate<EXP,R,C> > rotate (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_rotate<EXP,R,C> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ // A template to tell me if two types can be multiplied together in a sensible way. Here
+ // I'm saying it is ok if they are both the same type or one is the complex version of the other.
+ template <typename T, typename U> struct compatible { static const bool value = false; typedef T type; };
+ template <typename T> struct compatible<T,T> { static const bool value = true; typedef T type; };
+ template <typename T> struct compatible<std::complex<T>,T> { static const bool value = true; typedef std::complex<T> type; };
+ template <typename T> struct compatible<T,std::complex<T> > { static const bool value = true; typedef std::complex<T> type; };
+ }
+
+
+ template <typename M1, typename M2>
+ struct op_pointwise_multiply : basic_op_mm<M1,M2>
+ {
+ op_pointwise_multiply( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
+
+ typedef typename impl::compatible<typename M1::type, typename M2::type>::type type;
+ typedef const type const_ret_type;
+ const static long cost = M1::cost + M2::cost + 1;
+
+ const_ret_type apply ( long r, long c) const
+ { return this->m1(r,c)*this->m2(r,c); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_pointwise_multiply<EXP1,EXP2> > pointwise_multiply (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((impl::compatible<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc(),
+ "\tconst matrix_exp pointwise_multiply(const matrix_exp& a, const matrix_exp& b)"
+ << "\n\tYou can only make a do a pointwise multiply with two equally sized matrices"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ );
+ typedef op_pointwise_multiply<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct op_pointwise_multiply3 : basic_op_mmm<M1,M2,M3>
+ {
+ op_pointwise_multiply3( const M1& m1_, const M2& m2_, const M3& m3_) :
+ basic_op_mmm<M1,M2,M3>(m1_,m2_,m3_){}
+
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ const static long cost = M1::cost + M2::cost + M3::cost + 2;
+
+ const_ret_type apply (long r, long c) const
+ { return this->m1(r,c)*this->m2(r,c)*this->m3(r,c); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ inline const matrix_op<op_pointwise_multiply3<EXP1,EXP2,EXP3> >
+ pointwise_multiply (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const matrix_exp<EXP3>& c
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0);
+ COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
+ COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc() &&
+ b.nr() == c.nr() &&
+ b.nc() == c.nc(),
+ "\tconst matrix_exp pointwise_multiply(a,b,c)"
+ << "\n\tYou can only make a do a pointwise multiply between equally sized matrices"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ << "\n\tc.nr(): " << c.nr()
+ << "\n\tc.nc(): " << c.nc()
+ );
+
+ typedef op_pointwise_multiply3<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(a.ref(),b.ref(),c.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3, typename M4>
+ struct op_pointwise_multiply4 : basic_op_mmmm<M1,M2,M3,M4>
+ {
+ op_pointwise_multiply4( const M1& m1_, const M2& m2_, const M3& m3_, const M4& m4_) :
+ basic_op_mmmm<M1,M2,M3,M4>(m1_,m2_,m3_,m4_){}
+
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ const static long cost = M1::cost + M2::cost + M3::cost + M4::cost + 3;
+
+ const_ret_type apply (long r, long c) const
+ { return this->m1(r,c)*this->m2(r,c)*this->m3(r,c)*this->m4(r,c); }
+ };
+
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3,
+ typename EXP4
+ >
+ inline const matrix_op<op_pointwise_multiply4<EXP1,EXP2,EXP3,EXP4> > pointwise_multiply (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const matrix_exp<EXP3>& c,
+ const matrix_exp<EXP4>& d
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP3::type,typename EXP4::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0 );
+ COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
+ COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
+ COMPILE_TIME_ASSERT(EXP3::NR == EXP4::NR || EXP3::NR == 0 || EXP4::NR == 0);
+ COMPILE_TIME_ASSERT(EXP3::NC == EXP4::NC || EXP3::NC == 0 || EXP4::NC == 0);
+ DLIB_ASSERT(a.nr() == b.nr() &&
+ a.nc() == b.nc() &&
+ b.nr() == c.nr() &&
+ b.nc() == c.nc() &&
+ c.nr() == d.nr() &&
+ c.nc() == d.nc(),
+ "\tconst matrix_exp pointwise_multiply(a,b,c,d)"
+ << "\n\tYou can only make a do a pointwise multiply between equally sized matrices"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\tb.nc(): " << b.nc()
+ << "\n\tc.nr(): " << c.nr()
+ << "\n\tc.nc(): " << c.nc()
+ << "\n\td.nr(): " << d.nr()
+ << "\n\td.nc(): " << d.nc()
+ );
+
+ typedef op_pointwise_multiply4<EXP1,EXP2,EXP3,EXP4> op;
+ return matrix_op<op>(op(a.ref(),b.ref(),c.ref(),d.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P,
+ int type = static_switch<
+ pixel_traits<P>::grayscale,
+ pixel_traits<P>::rgb,
+ pixel_traits<P>::hsi,
+ pixel_traits<P>::rgb_alpha,
+ pixel_traits<P>::lab
+ >::value
+ >
+ struct pixel_to_vector_helper;
+
+ template <typename P>
+ struct pixel_to_vector_helper<P,1>
+ {
+ template <typename M>
+ static void assign (
+ M& m,
+ const P& pixel
+ )
+ {
+ m(0) = static_cast<typename M::type>(pixel);
+ }
+ };
+
+ template <typename P>
+ struct pixel_to_vector_helper<P,2>
+ {
+ template <typename M>
+ static void assign (
+ M& m,
+ const P& pixel
+ )
+ {
+ m(0) = static_cast<typename M::type>(pixel.red);
+ m(1) = static_cast<typename M::type>(pixel.green);
+ m(2) = static_cast<typename M::type>(pixel.blue);
+ }
+ };
+
+ template <typename P>
+ struct pixel_to_vector_helper<P,3>
+ {
+ template <typename M>
+ static void assign (
+ M& m,
+ const P& pixel
+ )
+ {
+ m(0) = static_cast<typename M::type>(pixel.h);
+ m(1) = static_cast<typename M::type>(pixel.s);
+ m(2) = static_cast<typename M::type>(pixel.i);
+ }
+ };
+
+ template <typename P>
+ struct pixel_to_vector_helper<P,4>
+ {
+ template <typename M>
+ static void assign (
+ M& m,
+ const P& pixel
+ )
+ {
+ m(0) = static_cast<typename M::type>(pixel.red);
+ m(1) = static_cast<typename M::type>(pixel.green);
+ m(2) = static_cast<typename M::type>(pixel.blue);
+ m(3) = static_cast<typename M::type>(pixel.alpha);
+ }
+ };
+
+ template <typename P>
+ struct pixel_to_vector_helper<P,5>
+ {
+ template <typename M>
+ static void assign (
+ M& m,
+ const P& pixel
+ )
+ {
+ m(0) = static_cast<typename M::type>(pixel.l);
+ m(1) = static_cast<typename M::type>(pixel.a);
+ m(2) = static_cast<typename M::type>(pixel.b);
+ }
+ };
+
+
+ template <
+ typename T,
+ typename P
+ >
+ inline const matrix<T,pixel_traits<P>::num,1> pixel_to_vector (
+ const P& pixel
+ )
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<P>::num > 0);
+ matrix<T,pixel_traits<P>::num,1> m;
+ pixel_to_vector_helper<P>::assign(m,pixel);
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P,
+ int type = static_switch<
+ pixel_traits<P>::grayscale,
+ pixel_traits<P>::rgb,
+ pixel_traits<P>::hsi,
+ pixel_traits<P>::rgb_alpha,
+ pixel_traits<P>::lab
+ >::value
+ >
+ struct vector_to_pixel_helper;
+
+ template <typename P>
+ struct vector_to_pixel_helper<P,1>
+ {
+ template <typename M>
+ static void assign (
+ P& pixel,
+ const M& m
+ )
+ {
+ pixel = static_cast<unsigned char>(m(0));
+ }
+ };
+
+ template <typename P>
+ struct vector_to_pixel_helper<P,2>
+ {
+ template <typename M>
+ static void assign (
+ P& pixel,
+ const M& m
+ )
+ {
+ pixel.red = static_cast<unsigned char>(m(0));
+ pixel.green = static_cast<unsigned char>(m(1));
+ pixel.blue = static_cast<unsigned char>(m(2));
+ }
+ };
+
+ template <typename P>
+ struct vector_to_pixel_helper<P,3>
+ {
+ template <typename M>
+ static void assign (
+ P& pixel,
+ const M& m
+ )
+ {
+ pixel.h = static_cast<unsigned char>(m(0));
+ pixel.s = static_cast<unsigned char>(m(1));
+ pixel.i = static_cast<unsigned char>(m(2));
+ }
+ };
+
+ template <typename P>
+ struct vector_to_pixel_helper<P,4>
+ {
+ template <typename M>
+ static void assign (
+ P& pixel,
+ const M& m
+ )
+ {
+ pixel.red = static_cast<unsigned char>(m(0));
+ pixel.green = static_cast<unsigned char>(m(1));
+ pixel.blue = static_cast<unsigned char>(m(2));
+ pixel.alpha = static_cast<unsigned char>(m(3));
+ }
+ };
+
+ template <typename P>
+ struct vector_to_pixel_helper<P,5>
+ {
+ template <typename M>
+ static void assign (
+ P& pixel,
+ const M& m
+ )
+ {
+ pixel.l = static_cast<unsigned char>(m(0));
+ pixel.a = static_cast<unsigned char>(m(1));
+ pixel.b = static_cast<unsigned char>(m(2));
+ }
+ };
+
+ template <
+ typename P,
+ typename EXP
+ >
+ inline void vector_to_pixel (
+ P& pixel,
+ const matrix_exp<EXP>& vector
+ )
+ {
+ COMPILE_TIME_ASSERT(pixel_traits<P>::num == matrix_exp<EXP>::NR);
+ COMPILE_TIME_ASSERT(matrix_exp<EXP>::NC == 1);
+ vector_to_pixel_helper<P>::assign(pixel,vector);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, long lower, long upper>
+ struct op_clamp : basic_op_m<M>
+ {
+ op_clamp( const M& m_) : basic_op_m<M>(m_){}
+
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ const static long cost = M::cost + 2;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ const type temp = this->m(r,c);
+ if (temp > static_cast<type>(upper))
+ return static_cast<type>(upper);
+ else if (temp < static_cast<type>(lower))
+ return static_cast<type>(lower);
+ else
+ return temp;
+ }
+ };
+
+ template <
+ long l,
+ long u,
+ typename EXP
+ >
+ const matrix_op<op_clamp<EXP,l,u> > clamp (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_clamp<EXP,l,u> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_clamp2 : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_clamp2( const M& m_, const type& l, const type& u) :
+ basic_op_m<M>(m_), lower(l), upper(u){}
+
+ const type& lower;
+ const type& upper;
+
+ typedef const typename M::type const_ret_type;
+ const static long cost = M::cost + 2;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ const type temp = this->m(r,c);
+ if (temp > upper)
+ return upper;
+ else if (temp < lower)
+ return lower;
+ else
+ return temp;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_clamp2<EXP> > clamp (
+ const matrix_exp<EXP>& m,
+ const typename EXP::type& lower,
+ const typename EXP::type& upper
+ )
+ {
+ typedef op_clamp2<EXP> op;
+ return matrix_op<op>(op(m.ref(),lower, upper));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename M3>
+ struct op_clamp_m : basic_op_mmm<M1,M2,M3>
+ {
+ op_clamp_m( const M1& m1_, const M2& m2_, const M3& m3_) :
+ basic_op_mmm<M1,M2,M3>(m1_,m2_,m3_){}
+
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ const static long cost = M1::cost + M2::cost + M3::cost + 2;
+
+ const_ret_type apply (long r, long c) const
+ {
+ const type val = this->m1(r,c);
+ const type lower = this->m2(r,c);
+ const type upper = this->m3(r,c);
+ if (val <= upper)
+ {
+ if (lower <= val)
+ return val;
+ else
+ return lower;
+ }
+ else
+ {
+ return upper;
+ }
+ }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ const matrix_op<op_clamp_m<EXP1,EXP2,EXP3> >
+ clamp (
+ const matrix_exp<EXP1>& m,
+ const matrix_exp<EXP2>& lower,
+ const matrix_exp<EXP3>& upper
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0);
+ COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
+ COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
+ DLIB_ASSERT(m.nr() == lower.nr() &&
+ m.nc() == lower.nc() &&
+ m.nr() == upper.nr() &&
+ m.nc() == upper.nc(),
+ "\tconst matrix_exp clamp(m,lower,upper)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t m.nr(): " << m.nr()
+ << "\n\t m.nc(): " << m.nc()
+ << "\n\t lower.nr(): " << lower.nr()
+ << "\n\t lower.nc(): " << lower.nc()
+ << "\n\t upper.nr(): " << upper.nr()
+ << "\n\t upper.nc(): " << upper.nc()
+ );
+
+ typedef op_clamp_m<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(m.ref(),lower.ref(),upper.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_lowerbound : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_lowerbound( const M& m_, const type& thresh_) :
+ basic_op_m<M>(m_), thresh(thresh_){}
+
+ const type& thresh;
+
+ typedef const typename M::type const_ret_type;
+ const static long cost = M::cost + 2;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ const type temp = this->m(r,c);
+ if (temp >= thresh)
+ return temp;
+ else
+ return thresh;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_lowerbound<EXP> > lowerbound (
+ const matrix_exp<EXP>& m,
+ const typename EXP::type& thresh
+ )
+ {
+ typedef op_lowerbound<EXP> op;
+ return matrix_op<op>(op(m.ref(), thresh));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_upperbound : basic_op_m<M>
+ {
+ typedef typename M::type type;
+
+ op_upperbound( const M& m_, const type& thresh_) :
+ basic_op_m<M>(m_), thresh(thresh_){}
+
+ const type& thresh;
+
+ typedef const typename M::type const_ret_type;
+ const static long cost = M::cost + 2;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ const type temp = this->m(r,c);
+ if (temp <= thresh)
+ return temp;
+ else
+ return thresh;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_upperbound<EXP> > upperbound (
+ const matrix_exp<EXP>& m,
+ const typename EXP::type& thresh
+ )
+ {
+ typedef op_upperbound<EXP> op;
+ return matrix_op<op>(op(m.ref(), thresh));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_reshape
+ {
+ op_reshape(const M& m_, const long& rows_, const long& cols_) : m(m_),rows(rows_),cols(cols_) {}
+ const M& m;
+ const long rows;
+ const long cols;
+
+ const static long cost = M::cost+2;
+ const static long NR = 0;
+ const static long NC = 0;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ const long idx = r*cols + c;
+ return m(idx/m.nc(), idx%m.nc());
+ }
+
+ long nr () const { return rows; }
+ long nc () const { return cols; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_reshape<EXP> > reshape (
+ const matrix_exp<EXP>& m,
+ const long& rows,
+ const long& cols
+ )
+ {
+ DLIB_ASSERT(m.size() == rows*cols && rows > 0 && cols > 0,
+ "\tconst matrix_exp reshape(m, rows, cols)"
+ << "\n\t The size of m must match the dimensions you want to reshape it into."
+ << "\n\t m.size(): " << m.size()
+ << "\n\t rows*cols: " << rows*cols
+ << "\n\t rows: " << rows
+ << "\n\t cols: " << cols
+ );
+
+ typedef op_reshape<EXP> op;
+ return matrix_op<op>(op(m.ref(), rows, cols));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ typename disable_if<is_complex<typename EXP1::type>,bool>::type equal (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const typename EXP1::type eps = 100*std::numeric_limits<typename EXP1::type>::epsilon()
+ )
+ {
+ // check if the dimensions don't match
+ if (a.nr() != b.nr() || a.nc() != b.nc())
+ return false;
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ if (std::abs(a(r,c)-b(r,c)) > eps)
+ return false;
+ }
+ }
+
+ // no non-equal points found so we return true
+ return true;
+ }
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ typename enable_if<is_complex<typename EXP1::type>,bool>::type equal (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b,
+ const typename EXP1::type::value_type eps = 100*std::numeric_limits<typename EXP1::type::value_type>::epsilon()
+ )
+ {
+ // check if the dimensions don't match
+ if (a.nr() != b.nr() || a.nc() != b.nc())
+ return false;
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ if (std::abs(real(a(r,c)-b(r,c))) > eps ||
+ std::abs(imag(a(r,c)-b(r,c))) > eps)
+ return false;
+ }
+ }
+
+ // no non-equal points found so we return true
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_scale_columns
+ {
+ op_scale_columns(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+
+ const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(c); }
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) ; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_scale_columns<EXP1,EXP2> > scale_columns (
+ const matrix_exp<EXP1>& m,
+ const matrix_exp<EXP2>& v
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ // The v argument must be a row or column vector.
+ COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0));
+
+ // figure out the compile time known length of v
+ const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
+
+ // the length of v must match the number of columns in m
+ COMPILE_TIME_ASSERT(EXP1::NC == v_len || EXP1::NC == 0 || v_len == 0);
+
+ DLIB_ASSERT(is_vector(v) == true && v.size() == m.nc(),
+ "\tconst matrix_exp scale_columns(m, v)"
+ << "\n\tv must be a row or column vector and its length must match the number of columns in m"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tv.nr(): " << v.nr()
+ << "\n\tv.nc(): " << v.nc()
+ );
+ typedef op_scale_columns<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),v.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_scale_columns_diag
+ {
+ op_scale_columns_diag(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+
+ const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(c,c); }
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) ; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.aliases(item); }
+ };
+
+// turn expressions of the form mat*diagonal_matrix into scale_columns(mat, d)
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_scale_columns_diag<EXP1,EXP2> > operator* (
+ const matrix_exp<EXP1>& m,
+ const matrix_diag_exp<EXP2>& d
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+
+ // figure out the compile time known length of d
+ const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
+
+ // the length of d must match the number of columns in m
+ COMPILE_TIME_ASSERT(EXP1::NC == v_len || EXP1::NC == 0 || v_len == 0);
+
+ DLIB_ASSERT(m.nc() == d.nr(),
+ "\tconst matrix_exp operator*(m, d)"
+ << "\n\tmatrix dimensions don't match"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\td.nr(): " << d.nr()
+ << "\n\td.nc(): " << d.nc()
+ );
+ typedef op_scale_columns_diag<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),d.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_scale_rows
+ {
+ op_scale_rows(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+
+ const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(r); }
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) ; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_scale_rows<EXP1,EXP2> > scale_rows (
+ const matrix_exp<EXP1>& m,
+ const matrix_exp<EXP2>& v
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ // The v argument must be a row or column vector.
+ COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0));
+
+ // figure out the compile time known length of v
+ const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
+
+ // the length of v must match the number of rows in m
+ COMPILE_TIME_ASSERT(EXP1::NR == v_len || EXP1::NR == 0 || v_len == 0);
+
+ DLIB_ASSERT(is_vector(v) == true && v.size() == m.nr(),
+ "\tconst matrix_exp scale_rows(m, v)"
+ << "\n\tv must be a row or column vector and its length must match the number of rows in m"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tv.nr(): " << v.nr()
+ << "\n\tv.nc(): " << v.nc()
+ );
+ typedef op_scale_rows<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),v.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_scale_rows_diag
+ {
+ op_scale_rows_diag(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR;
+ const static long NC = M1::NC;
+
+ const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(r,r); }
+
+ long nr () const { return m1.nr(); }
+ long nc () const { return m1.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) ; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.destructively_aliases(item) || m2.aliases(item); }
+ };
+
+// turn expressions of the form diagonal_matrix*mat into scale_rows(mat, d)
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_scale_rows_diag<EXP1,EXP2> > operator* (
+ const matrix_diag_exp<EXP2>& d,
+ const matrix_exp<EXP1>& m
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+
+ // figure out the compile time known length of d
+ const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
+
+ // the length of d must match the number of rows in m
+ COMPILE_TIME_ASSERT(EXP1::NR == v_len || EXP1::NR == 0 || v_len == 0);
+
+ DLIB_ASSERT(d.nc() == m.nr(),
+ "\tconst matrix_exp operator*(d, m)"
+ << "\n\tThe dimensions of the d and m matrices don't match."
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\td.nr(): " << d.nr()
+ << "\n\td.nc(): " << d.nc()
+ );
+ typedef op_scale_rows_diag<EXP1,EXP2> op;
+ return matrix_op<op>(op(m.ref(),d.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ /*
+ The idea here is to catch expressions of the form d*M*d where d is diagonal and M
+ is some square matrix and turn them into something equivalent to
+ pointwise_multiply(diag(d)*trans(diag(d)), M).
+
+ The reason for this is that doing it this way is more numerically stable. In particular,
+ doing 2 matrix multiplies as suggested by d*M*d could result in an asymmetric matrix even
+ if M is symmetric to begin with.
+ */
+
+ template <typename M1, typename M2, typename M3>
+ struct op_diag_m_diag
+ {
+ // This operator represents M1*M2*M3 where M1 and M3 are diagonal
+
+ op_diag_m_diag(const M1& m1_, const M2& m2_, const M3& m3_) : m1(m1_), m2(m2_), m3(m3_) {}
+ const M1& m1;
+ const M2& m2;
+ const M3& m3;
+
+ const static long cost = M1::cost + M2::cost + M3::cost + 1;
+ typedef typename M2::type type;
+ typedef const typename M2::type const_ret_type;
+ typedef typename M2::mem_manager_type mem_manager_type;
+ typedef typename M2::layout_type layout_type;
+ const static long NR = M2::NR;
+ const static long NC = M2::NC;
+
+ const_ret_type apply ( long r, long c) const { return (m1(r,r)*m3(c,c))*m2(r,c); }
+
+ long nr () const { return m2.nr(); }
+ long nc () const { return m2.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item) ; }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m2.destructively_aliases(item) || m1.aliases(item) || m3.aliases(item) ; }
+ };
+
+ // catch d*(M*d) = EXP1*EXP2*EXP3
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ const matrix_op<op_diag_m_diag<EXP1,EXP2,EXP3> > operator* (
+ const matrix_diag_exp<EXP1>& d,
+ const matrix_exp<matrix_op<op_scale_columns_diag<EXP2,EXP3> > >& m
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+
+ // figure out the compile time known length of d
+ const long v_len = ((EXP1::NR)*(EXP1::NC) == 0)? 0 : (tmax<EXP1::NR,EXP1::NC>::value);
+
+ // the length of d must match the number of rows in m
+ COMPILE_TIME_ASSERT(EXP2::NR == v_len || EXP2::NR == 0 || v_len == 0);
+
+ DLIB_ASSERT(d.nc() == m.nr(),
+ "\tconst matrix_exp operator*(d, m)"
+ << "\n\tmatrix dimensions don't match"
+ << "\n\td.nr(): " << d.nr()
+ << "\n\td.nc(): " << d.nc()
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ );
+ typedef op_diag_m_diag<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(d.ref(), m.ref().op.m1, m.ref().op.m2));
+ }
+
+ // catch (d*M)*d = EXP1*EXP2*EXP3
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ const matrix_op<op_diag_m_diag<EXP1,EXP2,EXP3> > operator* (
+ const matrix_exp<matrix_op<op_scale_rows_diag<EXP2,EXP1> > >& m,
+ const matrix_diag_exp<EXP3>& d
+ )
+ {
+ // Both arguments to this function must contain the same type of element
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP3::type,typename EXP2::type>::value == true));
+
+ // figure out the compile time known length of d
+ const long v_len = ((EXP3::NR)*(EXP3::NC) == 0)? 0 : (tmax<EXP3::NR,EXP3::NC>::value);
+
+ // the length of d must match the number of columns in m
+ COMPILE_TIME_ASSERT(EXP2::NC == v_len || EXP2::NC == 0 || v_len == 0);
+
+ DLIB_ASSERT(m.nc() == d.nr(),
+ "\tconst matrix_exp operator*(m, d)"
+ << "\n\tmatrix dimensions don't match"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\td.nr(): " << d.nr()
+ << "\n\td.nc(): " << d.nc()
+ );
+ typedef op_diag_m_diag<EXP1,EXP2,EXP3> op;
+ return matrix_op<op>(op(m.ref().op.m2, m.ref().op.m1, d.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ struct sort_columns_sort_helper
+ {
+ template <typename T>
+ bool operator() (
+ const T& item1,
+ const T& item2
+ ) const
+ {
+ return item1.first < item2.first;
+ }
+ };
+
+ template <
+ typename T, long NR, long NC, typename mm, typename l1,
+ long NR2, long NC2, typename mm2, typename l2
+ >
+ void sort_columns (
+ matrix<T,NR,NC,mm,l1>& m,
+ matrix<T,NR2,NC2,mm2,l2>& v
+ )
+ {
+ COMPILE_TIME_ASSERT(NC2 == 1 || NC2 == 0);
+ COMPILE_TIME_ASSERT(NC == NR2 || NC == 0 || NR2 == 0);
+
+ DLIB_ASSERT(is_col_vector(v) == true && v.size() == m.nc(),
+ "\tconst matrix_exp sort_columns(m, v)"
+ << "\n\tv must be a column vector and its length must match the number of columns in m"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tv.nr(): " << v.nr()
+ << "\n\tv.nc(): " << v.nc()
+ );
+
+
+
+ // Now we have to sort the given vectors in the m matrix according
+ // to how big their corresponding v(column index) values are.
+ typedef std::pair<T, matrix<T,0,1,mm> > col_pair;
+ typedef std_allocator<col_pair, mm> alloc;
+ std::vector<col_pair,alloc> colvalues;
+ col_pair p;
+ for (long r = 0; r < v.nr(); ++r)
+ {
+ p.first = v(r);
+ p.second = colm(m,r);
+ colvalues.push_back(p);
+ }
+ std::sort(colvalues.begin(), colvalues.end(), sort_columns_sort_helper());
+
+ for (long i = 0; i < v.nr(); ++i)
+ {
+ v(i) = colvalues[i].first;
+ set_colm(m,i) = colvalues[i].second;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename mm, typename l1,
+ long NR2, long NC2, typename mm2, typename l2
+ >
+ void rsort_columns (
+ matrix<T,NR,NC,mm,l1>& m,
+ matrix<T,NR2,NC2,mm2,l2>& v
+ )
+ {
+ COMPILE_TIME_ASSERT(NC2 == 1 || NC2 == 0);
+ COMPILE_TIME_ASSERT(NC == NR2 || NC == 0 || NR2 == 0);
+
+ DLIB_ASSERT(is_col_vector(v) == true && v.size() == m.nc(),
+ "\tconst matrix_exp rsort_columns(m, v)"
+ << "\n\tv must be a column vector and its length must match the number of columns in m"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tv.nr(): " << v.nr()
+ << "\n\tv.nc(): " << v.nc()
+ );
+
+
+
+ // Now we have to sort the given vectors in the m matrix according
+ // to how big their corresponding v(column index) values are.
+ typedef std::pair<T, matrix<T,0,1,mm> > col_pair;
+ typedef std_allocator<col_pair, mm> alloc;
+ std::vector<col_pair,alloc> colvalues;
+ col_pair p;
+ for (long r = 0; r < v.nr(); ++r)
+ {
+ p.first = v(r);
+ p.second = colm(m,r);
+ colvalues.push_back(p);
+ }
+ std::sort(colvalues.rbegin(), colvalues.rend(), sort_columns_sort_helper());
+
+ for (long i = 0; i < v.nr(); ++i)
+ {
+ v(i) = colvalues[i].first;
+ set_colm(m,i) = colvalues[i].second;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_tensor_product
+ {
+ op_tensor_product(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ const static long NR = M1::NR*M2::NR;
+ const static long NC = M1::NC*M2::NC;
+ typedef typename M1::type type;
+ typedef const typename M1::type const_ret_type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ return m1(r/m2.nr(),c/m2.nc())*m2(r%m2.nr(),c%m2.nc());
+ }
+
+ long nr () const { return m1.nr()*m2.nr(); }
+ long nc () const { return m1.nc()*m2.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_tensor_product<EXP1,EXP2> > tensor_product (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ typedef op_tensor_product<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_make_symmetric : basic_op_m<M>
+ {
+ op_make_symmetric ( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r >= c)
+ return this->m(r,c);
+ else
+ return this->m(c,r);
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_make_symmetric<EXP> > make_symmetric (
+ const matrix_exp<EXP>& m
+ )
+ {
+ DLIB_ASSERT(m.nr() == m.nc(),
+ "\tconst matrix make_symmetric(m)"
+ << "\n\t m must be a square matrix"
+ << "\n\t m.nr(): " << m.nr()
+ << "\n\t m.nc(): " << m.nc()
+ );
+
+ typedef op_make_symmetric<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_lowerm : basic_op_m<M>
+ {
+ op_lowerm( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+2;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r >= c)
+ return this->m(r,c);
+ else
+ return 0;
+ }
+ };
+
+ template <typename M>
+ struct op_lowerm_s : basic_op_m<M>
+ {
+ typedef typename M::type type;
+ op_lowerm_s( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+2;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r > c)
+ return this->m(r,c);
+ else if (r==c)
+ return s;
+ else
+ return 0;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_lowerm<EXP> > lowerm (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_lowerm<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_lowerm_s<EXP> > lowerm (
+ const matrix_exp<EXP>& m,
+ typename EXP::type s
+ )
+ {
+ typedef op_lowerm_s<EXP> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_upperm : basic_op_m<M>
+ {
+ op_upperm( const M& m_) : basic_op_m<M>(m_){}
+
+ const static long cost = M::cost+2;
+ typedef typename M::type type;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r <= c)
+ return this->m(r,c);
+ else
+ return 0;
+ }
+ };
+
+ template <typename M>
+ struct op_upperm_s : basic_op_m<M>
+ {
+ typedef typename M::type type;
+ op_upperm_s( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
+
+ const type s;
+
+ const static long cost = M::cost+2;
+ typedef const typename M::type const_ret_type;
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r < c)
+ return this->m(r,c);
+ else if (r==c)
+ return s;
+ else
+ return 0;
+ }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_upperm<EXP> > upperm (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_upperm<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_upperm_s<EXP> > upperm (
+ const matrix_exp<EXP>& m,
+ typename EXP::type s
+ )
+ {
+ typedef op_upperm_s<EXP> op;
+ return matrix_op<op>(op(m.ref(),s));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename rand_gen>
+ inline const matrix<double> randm(
+ long nr,
+ long nc,
+ rand_gen& rnd
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tconst matrix randm(nr, nc, rnd)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+
+ matrix<double> m(nr,nc);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = rnd.get_random_double();
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix<double> randm(
+ long nr,
+ long nc
+ )
+ {
+ DLIB_ASSERT(nr >= 0 && nc >= 0,
+ "\tconst matrix randm(nr, nc)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tnr: " << nr
+ << "\n\tnc: " << nc
+ );
+
+ matrix<double> m(nr,nc);
+ // make a double that contains RAND_MAX + the smallest number that still
+ // makes the resulting double slightly bigger than static_cast<double>(RAND_MAX)
+ double max_val = RAND_MAX;
+ max_val += std::numeric_limits<double>::epsilon()*RAND_MAX;
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = std::rand()/max_val;
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix_range_exp<double> linspace (
+ double start,
+ double end,
+ long num
+ )
+ {
+ DLIB_ASSERT(num >= 0,
+ "\tconst matrix_exp linspace(start, end, num)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tstart: " << start
+ << "\n\tend: " << end
+ << "\n\tnum: " << num
+ );
+
+ return matrix_range_exp<double>(start,end,num,false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_linpiece
+ {
+ op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){}
+
+ const M& joints;
+ const double val;
+
+ const static long cost = 10;
+
+ const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1);
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef default_memory_manager mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ typedef type const_ret_type;
+ const_ret_type apply (long i, long ) const
+ {
+ if (joints(i) < val)
+ return std::min<type>(val,joints(i+1)) - joints(i);
+ else
+ return 0;
+ }
+
+ long nr () const { return joints.size()-1; }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
+ };
+
+ template < typename EXP >
+ const matrix_op<op_linpiece<EXP> > linpiece (
+ const double val,
+ const matrix_exp<EXP>& joints
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(joints) && joints.size() >= 2,
+ "\t matrix_exp linpiece()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t is_vector(joints): " << is_vector(joints)
+ << "\n\t joints.size(): " << joints.size()
+ );
+#ifdef ENABLE_ASSERTS
+ for (long i = 1; i < joints.size(); ++i)
+ {
+ DLIB_ASSERT(joints(i-1) < joints(i),
+ "\t matrix_exp linpiece()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t joints("<<i-1<<"): " << joints(i-1)
+ << "\n\t joints("<<i<<"): " << joints(i)
+ );
+ }
+#endif
+
+ typedef op_linpiece<EXP> op;
+ return matrix_op<op>(op(val,joints.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix_log_range_exp<double> logspace (
+ double start,
+ double end,
+ long num
+ )
+ {
+ DLIB_ASSERT(num >= 0,
+ "\tconst matrix_exp logspace(start, end, num)"
+ << "\n\tInvalid inputs to this function"
+ << "\n\tstart: " << start
+ << "\n\tend: " << end
+ << "\n\tnum: " << num
+ );
+
+ return matrix_log_range_exp<double>(start,end,num);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_cart_prod
+ {
+ op_cart_prod(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_) {}
+ const M1& m1;
+ const M2& m2;
+
+ const static long cost = M1::cost+M2::cost+1;
+ typedef typename M1::type type;
+ typedef const typename M1::const_ret_type const_ret_type;
+
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+ const static long NR = M1::NR+M2::NR;
+ const static long NC = M1::NC*M2::NC;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r < m1.nr())
+ return m1(r, c/m2.nc());
+ else
+ return m2(r-m1.nr(), c%m2.nc());
+ }
+
+ long nr () const { return m1.nr() + m2.nr(); }
+ long nc () const { return m1.nc() * m2.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ const matrix_op<op_cart_prod<EXP1,EXP2> > cartesian_product (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+
+ typedef op_cart_prod<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_mat_to_vect
+ {
+ op_mat_to_vect(const M& m_) : m(m_) {}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = M::NC*M::NR;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply ( long r, long ) const { return m(r/m.nc(), r%m.nc()); }
+
+ long nr () const { return m.size(); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_mat_to_vect<EXP> > reshape_to_column_vector (
+ const matrix_exp<EXP>& m
+ )
+ {
+ typedef op_mat_to_vect<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR_,
+ long NC_,
+ typename MM
+ >
+ struct op_mat_to_vect2
+ {
+ typedef matrix<T,NR_,NC_,MM,row_major_layout> M;
+ op_mat_to_vect2(const M& m_) : m(m_) {}
+ const M& m;
+
+ const static long cost = M::cost+2;
+ const static long NR = M::NC*M::NR;
+ const static long NC = 1;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply ( long r, long ) const { return (&m(0,0))[r]; }
+
+ long nr () const { return m.size(); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM
+ >
+ const matrix_op<op_mat_to_vect2<T,NR,NC,MM> > reshape_to_column_vector (
+ const matrix<T,NR,NC,MM,row_major_layout>& m
+ )
+ {
+ typedef op_mat_to_vect2<T,NR,NC,MM> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_join_rows
+ {
+ op_join_rows(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),_nr(std::max(m1.nr(),m2.nr())) {}
+ const M1& m1;
+ const M2& m2;
+ const long _nr;
+
+ template <typename T, typename U, bool selection>
+ struct type_selector;
+ template <typename T, typename U>
+ struct type_selector<T,U,true> { typedef T type; };
+ template <typename T, typename U>
+ struct type_selector<T,U,false> { typedef U type; };
+
+ // If both const_ret_types are references then we should use them as the const_ret_type type
+ // but otherwise we should use the normal type.
+ typedef typename M1::const_ret_type T1;
+ typedef typename M1::type T2;
+ typedef typename M2::const_ret_type T3;
+ typedef typename type_selector<T1, T2, is_reference_type<T1>::value && is_reference_type<T3>::value>::type const_ret_type;
+
+ const static long cost = M1::cost + M2::cost + 1;
+ const static long NR = tmax<M1::NR, M2::NR>::value;
+ const static long NC = (M1::NC*M2::NC != 0)? (M1::NC+M2::NC) : (0);
+ typedef typename M1::type type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const
+ {
+ if (c < m1.nc())
+ return m1(r,c);
+ else
+ return m2(r,c-m1.nc());
+ }
+
+ long nr () const { return _nr; }
+ long nc () const { return m1.nc()+m2.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_join_rows<EXP1,EXP2> > join_rows (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ // You are getting an error on this line because you are trying to join two matrices that
+ // don't have the same number of rows
+ COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || (EXP1::NR*EXP2::NR == 0));
+
+ DLIB_ASSERT(a.nr() == b.nr() || a.size() == 0 || b.size() == 0,
+ "\tconst matrix_exp join_rows(const matrix_exp& a, const matrix_exp& b)"
+ << "\n\tYou can only use join_rows() if both matrices have the same number of rows"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nc(): " << b.nc()
+ );
+
+ typedef op_join_rows<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ struct op_join_cols
+ {
+ op_join_cols(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),_nc(std::max(m1.nc(),m2.nc())) {}
+ const M1& m1;
+ const M2& m2;
+ const long _nc;
+
+ template <typename T, typename U, bool selection>
+ struct type_selector;
+ template <typename T, typename U>
+ struct type_selector<T,U,true> { typedef T type; };
+ template <typename T, typename U>
+ struct type_selector<T,U,false> { typedef U type; };
+
+ // If both const_ret_types are references then we should use them as the const_ret_type type
+ // but otherwise we should use the normal type.
+ typedef typename M1::const_ret_type T1;
+ typedef typename M1::type T2;
+ typedef typename M2::const_ret_type T3;
+ typedef typename type_selector<T1, T2, is_reference_type<T1>::value && is_reference_type<T3>::value>::type const_ret_type;
+
+
+
+ const static long cost = M1::cost + M2::cost + 1;
+ const static long NC = tmax<M1::NC, M2::NC>::value;
+ const static long NR = (M1::NR*M2::NR != 0)? (M1::NR+M2::NR) : (0);
+ typedef typename M1::type type;
+ typedef typename M1::mem_manager_type mem_manager_type;
+ typedef typename M1::layout_type layout_type;
+
+ const_ret_type apply ( long r, long c) const
+ {
+ if (r < m1.nr())
+ return m1(r,c);
+ else
+ return m2(r-m1.nr(),c);
+ }
+
+ long nr () const { return m1.nr()+m2.nr(); }
+ long nc () const { return _nc; }
+
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m1.aliases(item) || m2.aliases(item); }
+ };
+
+ template <
+ typename EXP1,
+ typename EXP2
+ >
+ inline const matrix_op<op_join_cols<EXP1,EXP2> > join_cols (
+ const matrix_exp<EXP1>& a,
+ const matrix_exp<EXP2>& b
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
+ // You are getting an error on this line because you are trying to join two matrices that
+ // don't have the same number of columns
+ COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || (EXP1::NC*EXP2::NC == 0));
+
+ DLIB_ASSERT(a.nc() == b.nc() || a.size() == 0 || b.size() == 0,
+ "\tconst matrix_exp join_cols(const matrix_exp& a, const matrix_exp& b)"
+ << "\n\tYou can only use join_cols() if both matrices have the same number of columns"
+ << "\n\ta.nr(): " << a.nr()
+ << "\n\tb.nr(): " << b.nr()
+ << "\n\ta.nc(): " << a.nc()
+ << "\n\tb.nc(): " << b.nc()
+ );
+
+ typedef op_join_cols<EXP1,EXP2> op;
+ return matrix_op<op>(op(a.ref(),b.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_fliplr
+ {
+ op_fliplr( const M& m_) : m(m_){}
+
+ const M& m;
+
+ const static long cost = M::cost;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const { return m(r,m.nc()-c-1); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+
+ };
+
+ template <
+ typename M
+ >
+ const matrix_op<op_fliplr<M> > fliplr (
+ const matrix_exp<M>& m
+ )
+ {
+ typedef op_fliplr<M> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_flipud
+ {
+ op_flipud( const M& m_) : m(m_){}
+
+ const M& m;
+
+ const static long cost = M::cost;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const { return m(m.nr()-r-1,c); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+
+ };
+
+ template <
+ typename M
+ >
+ const matrix_op<op_flipud<M> > flipud (
+ const matrix_exp<M>& m
+ )
+ {
+ typedef op_flipud<M> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_flip
+ {
+ op_flip( const M& m_) : m(m_){}
+
+ const M& m;
+
+ const static long cost = M::cost;
+ const static long NR = M::NR;
+ const static long NC = M::NC;
+ typedef typename M::type type;
+ typedef typename M::const_ret_type const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ const_ret_type apply (long r, long c) const { return m(m.nr()-r-1, m.nc()-c-1); }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+
+ };
+
+ template <
+ typename M
+ >
+ const matrix_op<op_flip<M> > flip (
+ const matrix_exp<M>& m
+ )
+ {
+ typedef op_flip<M> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L>
+ uint32 hash (
+ const matrix<T,NR,NC,MM,L>& item,
+ uint32 seed = 0
+ )
+ {
+ DLIB_ASSERT_HAS_STANDARD_LAYOUT(T);
+
+ if (item.size() == 0)
+ return 0;
+ else
+ return murmur_hash3(&item(0,0), sizeof(T)*item.size(), seed);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_UTILITIES_
+
diff --git a/ml/dlib/dlib/matrix/matrix_utilities_abstract.h b/ml/dlib/dlib/matrix/matrix_utilities_abstract.h
new file mode 100644
index 000000000..ad4c91167
--- /dev/null
+++ b/ml/dlib/dlib/matrix/matrix_utilities_abstract.h
@@ -0,0 +1,1874 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MATRIx_UTILITIES_ABSTRACT_
+#ifdef DLIB_MATRIx_UTILITIES_ABSTRACT_
+
+#include "matrix_abstract.h"
+#include <complex>
+#include "../pixel.h"
+#include "../geometry/rectangle.h"
+#inclue <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Simple matrix utilities
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ constexpr bool is_row_major (
+ const matrix_exp<EXP>&
+ );
+ /*!
+ ensures
+ - returns true if and only if the given matrix expression uses the row_major_layout.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp diag (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a column vector R that contains the elements from the diagonal
+ of m in the order R(0)==m(0,0), R(1)==m(1,1), R(2)==m(2,2) and so on.
+ !*/
+
+ template <typename EXP>
+ struct diag_exp
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This struct allows you to determine the type of matrix expression
+ object returned from the diag() function. An example makes its
+ use clear:
+
+ template <typename EXP>
+ void do_something( const matrix_exp<EXP>& mat)
+ {
+ // d is a matrix expression that aliases mat.
+ typename diag_exp<EXP>::type d = diag(mat);
+
+ // Print the diagonal of mat. So we see that by using
+ // diag_exp we can save the object returned by diag() in
+ // a local variable.
+ cout << d << endl;
+
+ // Note that you can only save the return value of diag() to
+ // a local variable if the argument to diag() has a lifetime
+ // beyond the diag() expression. The example shown above is
+ // OK but the following would result in undefined behavior:
+ typename diag_exp<EXP>::type bad = diag(mat + mat);
+ }
+ !*/
+ typedef type_of_expression_returned_by_diag type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp diagm (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_vector(m) == true
+ (i.e. m is a row or column matrix)
+ ensures
+ - returns a square matrix M such that:
+ - diag(M) == m
+ - non diagonal elements of M are 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp trans (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the transpose of the matrix m
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_type::type dot (
+ const matrix_exp& m1,
+ const matrix_exp& m2
+ );
+ /*!
+ requires
+ - is_vector(m1) == true
+ - is_vector(m2) == true
+ - m1.size() == m2.size()
+ - m1.size() > 0
+ ensures
+ - returns the dot product between m1 and m2. That is, this function
+ computes and returns the sum, for all i, of m1(i)*m2(i).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp lowerm (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - M is the lower triangular part of m. That is:
+ - if (r >= c) then
+ - M(r,c) == m(r,c)
+ - else
+ - M(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp lowerm (
+ const matrix_exp& m,
+ const matrix_exp::type scalar_value
+ );
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - M is the lower triangular part of m except that the diagonal has
+ been set to scalar_value. That is:
+ - if (r > c) then
+ - M(r,c) == m(r,c)
+ - else if (r == c) then
+ - M(r,c) == scalar_value
+ - else
+ - M(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp upperm (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - M is the upper triangular part of m. That is:
+ - if (r <= c) then
+ - M(r,c) == m(r,c)
+ - else
+ - M(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp upperm (
+ const matrix_exp& m,
+ const matrix_exp::type scalar_value
+ );
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - M is the upper triangular part of m except that the diagonal has
+ been set to scalar_value. That is:
+ - if (r < c) then
+ - M(r,c) == m(r,c)
+ - else if (r == c) then
+ - M(r,c) == scalar_value
+ - else
+ - M(r,c) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp make_symmetric (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.nr() == m.nc()
+ (i.e. m must be a square matrix)
+ ensures
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - M is a symmetric matrix, that is, M == trans(M) and
+ it is constructed from the lower triangular part of m. Specifically,
+ we have:
+ - lowerm(M) == lowerm(m)
+ - upperm(M) == trans(lowerm(m))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ T val
+ >
+ const matrix_exp uniform_matrix (
+ );
+ /*!
+ requires
+ - NR > 0 && NC > 0
+ ensures
+ - returns an NR by NC matrix with elements of type T and all set to val.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC
+ >
+ const matrix_exp uniform_matrix (
+ const T& val
+ );
+ /*!
+ requires
+ - NR > 0 && NC > 0
+ ensures
+ - returns an NR by NC matrix with elements of type T and all set to val.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp uniform_matrix (
+ long nr,
+ long nc,
+ const T& val
+ );
+ /*!
+ requires
+ - nr >= 0 && nc >= 0
+ ensures
+ - returns an nr by nc matrix with elements of type T and all set to val.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp ones_matrix (
+ const matrix_exp& mat
+ );
+ /*!
+ requires
+ - mat.nr() >= 0 && mat.nc() >= 0
+ ensures
+ - Let T denote the type of element in mat. Then this function
+ returns uniform_matrix<T>(mat.nr(), mat.nc(), 1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp ones_matrix (
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - nr >= 0 && nc >= 0
+ ensures
+ - returns uniform_matrix<T>(nr, nc, 1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp zeros_matrix (
+ const matrix_exp& mat
+ );
+ /*!
+ requires
+ - mat.nr() >= 0 && mat.nc() >= 0
+ ensures
+ - Let T denote the type of element in mat. Then this function
+ returns uniform_matrix<T>(mat.nr(), mat.nc(), 0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp zeros_matrix (
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - nr >= 0 && nc >= 0
+ ensures
+ - returns uniform_matrix<T>(nr, nc, 0)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp identity_matrix (
+ const matrix_exp& mat
+ );
+ /*!
+ requires
+ - mat.nr() == mat.nc()
+ ensures
+ - returns an identity matrix with the same dimensions as mat and
+ containing the same type of elements as mat.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp identity_matrix (
+ long N
+ );
+ /*!
+ requires
+ - N > 0
+ ensures
+ - returns an N by N identity matrix with elements of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long N
+ >
+ const matrix_exp identity_matrix (
+ );
+ /*!
+ requires
+ - N > 0
+ ensures
+ - returns an N by N identity matrix with elements of type T.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp linspace (
+ double start,
+ double end,
+ long num
+ );
+ /*!
+ requires
+ - num >= 0
+ ensures
+ - returns a matrix M such that:
+ - M::type == double
+ - is_row_vector(M) == true
+ - M.size() == num
+ - M == a row vector with num linearly spaced values beginning with start
+ and stopping with end.
+ - M(num-1) == end
+ - if (num > 1) then
+ - M(0) == start
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp logspace (
+ double start,
+ double end,
+ long num
+ );
+ /*!
+ requires
+ - num >= 0
+ ensures
+ - returns a matrix M such that:
+ - M::type == double
+ - is_row_vector(M) == true
+ - M.size() == num
+ - M == a row vector with num logarithmically spaced values beginning with
+ 10^start and stopping with 10^end.
+ (i.e. M == pow(10, linspace(start, end, num)))
+ - M(num-1) == 10^end
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp linpiece (
+ const double val,
+ const matrix_exp& joints
+ );
+ /*!
+ requires
+ - is_vector(joints) == true
+ - joints.size() >= 2
+ - for all valid i < j:
+ - joints(i) < joints(j)
+ ensures
+ - linpiece() is useful for creating piecewise linear functions of val. For
+ example, if w is a parameter vector then you can represent a piecewise linear
+ function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In
+ this case, f(val) is piecewise linear on the intervals [0,25], [25,50],
+ [50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the
+ i-th interval. Finally, outside the interval [0,100] f(val) has a derivative
+ of zero and f(0) == 0.
+ - To be precise, this function returns a column vector L such that:
+ - L.size() == joints.size()-1
+ - is_col_vector(L) == true
+ - L contains the same type of elements as joints.
+ - for all valid i:
+ - if (joints(i) < val)
+ - L(i) == min(val,joints(i+1)) - joints(i)
+ - else
+ - L(i) == 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long R,
+ long C
+ >
+ const matrix_exp rotate (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ R( (r+R)%m.nr() , (c+C)%m.nc() ) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp fliplr (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - flips the matrix m from left to right and returns the result.
+ I.e. reverses the order of the columns.
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - for all valid r and c:
+ M(r,c) == m(r, m.nc()-c-1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp flipud (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - flips the matrix m from up to down and returns the result.
+ I.e. reverses the order of the rows.
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - for all valid r and c:
+ M(r,c) == m(m.nr()-r-1, c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp flip (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - flips the matrix m from up to down and left to right and returns the
+ result. I.e. returns flipud(fliplr(m)).
+ - returns a matrix M such that:
+ - M::type == the same type that was in m
+ - M has the same dimensions as m
+ - for all valid r and c:
+ M(r,c) == m(m.nr()-r-1, m.nc()-c-1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp reshape (
+ const matrix_exp& m,
+ long rows,
+ long cols
+ );
+ /*!
+ requires
+ - m.size() == rows*cols
+ - rows > 0
+ - cols > 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == rows
+ - M.nc() == cols
+ - M.size() == m.size()
+ - for all valid r and c:
+ - let IDX = r*cols + c
+ - M(r,c) == m(IDX/m.nc(), IDX%m.nc())
+
+ - i.e. The matrix m is reshaped into a new matrix of rows by cols
+ dimension. Additionally, the elements of m are laid into M in row major
+ order.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp reshape_to_column_vector (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - is_col_vector(M) == true
+ - M.size() == m.size()
+ - for all valid r and c:
+ - m(r,c) == M(r*m.nc() + c)
+
+ - i.e. The matrix m is reshaped into a column vector. Note that
+ the elements are pulled out in row major order.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long R,
+ long C
+ >
+ const matrix_exp removerc (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.nr() > R >= 0
+ - m.nc() > C >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr() - 1
+ - M.nc() == m.nc() - 1
+ - M == m with its R row and C column removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp removerc (
+ const matrix_exp& m,
+ long R,
+ long C
+ );
+ /*!
+ requires
+ - m.nr() > R >= 0
+ - m.nc() > C >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr() - 1
+ - M.nc() == m.nc() - 1
+ - M == m with its R row and C column removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long R
+ >
+ const matrix_exp remove_row (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.nr() > R >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr() - 1
+ - M.nc() == m.nc()
+ - M == m with its R row removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp remove_row (
+ const matrix_exp& m,
+ long R
+ );
+ /*!
+ requires
+ - m.nr() > R >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr() - 1
+ - M.nc() == m.nc()
+ - M == m with its R row removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long C
+ >
+ const matrix_exp remove_col (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.nc() > C >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr()
+ - M.nc() == m.nc() - 1
+ - M == m with its C column removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp remove_col (
+ const matrix_exp& m,
+ long C
+ );
+ /*!
+ requires
+ - m.nc() > C >= 0
+ ensures
+ - returns a matrix M such that:
+ - M.nr() == m.nr()
+ - M.nc() == m.nc() - 1
+ - M == m with its C column removed
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename target_type
+ >
+ const matrix_exp matrix_cast (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R where for all valid r and c:
+ R(r,c) == static_cast<target_type>(m(r,c))
+ also, R has the same dimensions as m.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename U,
+ typename L
+ >
+ void set_all_elements (
+ matrix<T,NR,NC,MM,L>& m,
+ U value
+ );
+ /*!
+ ensures
+ - for all valid r and c:
+ m(r,c) == value
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::matrix_type tmp (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a temporary matrix object that is a copy of m.
+ (This allows you to easily force a matrix_exp to fully evaluate)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ uint32 hash (
+ const matrix<T,NR,NC,MM,L>& item,
+ uint32 seed = 0
+ );
+ /*!
+ requires
+ - T is a standard layout type (e.g. a POD type like int, float,
+ or a simple struct).
+ ensures
+ - returns a 32bit hash of the data stored in item.
+ - Each value of seed results in a different hash function being used.
+ (e.g. hash(item,0) should generally not be equal to hash(item,1))
+ - uses the murmur_hash3() routine to compute the actual hash.
+ - Note that if the memory layout of the elements in item change between
+ hardware platforms then hash() will give different outputs. If you want
+ hash() to always give the same output for the same input then you must
+ ensure that elements of item always have the same layout in memory.
+ Typically this means using fixed width types and performing byte swapping
+ to account for endianness before passing item to hash().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ // if matrix_exp contains non-complex types (e.g. float, double)
+ bool equal (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp::type epsilon = 100*std::numeric_limits<matrix_exp::type>::epsilon()
+ );
+ /*!
+ ensures
+ - if (a and b don't have the same dimensions) then
+ - returns false
+ - else if (there exists an r and c such that abs(a(r,c)-b(r,c)) > epsilon) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ // if matrix_exp contains std::complex types
+ bool equal (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp::type::value_type epsilon = 100*std::numeric_limits<matrix_exp::type::value_type>::epsilon()
+ );
+ /*!
+ ensures
+ - if (a and b don't have the same dimensions) then
+ - returns false
+ - else if (there exists an r and c such that abs(real(a(r,c)-b(r,c))) > epsilon
+ or abs(imag(a(r,c)-b(r,c))) > epsilon) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp pointwise_multiply (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a.nr() == b.nr()
+ - a.nc() == b.nc()
+ - a and b both contain the same type of element (one or both
+ can also be of type std::complex so long as the underlying type
+ in them is the same)
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R has the same dimensions as a and b.
+ - for all valid r and c:
+ R(r,c) == a(r,c) * b(r,c)
+ !*/
+
+ const matrix_exp pointwise_multiply (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp& c
+ );
+ /*!
+ performs pointwise_multiply(a,pointwise_multiply(b,c));
+ !*/
+
+ const matrix_exp pointwise_multiply (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp& c,
+ const matrix_exp& d
+ );
+ /*!
+ performs pointwise_multiply(pointwise_multiply(a,b),pointwise_multiply(c,d));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp join_rows (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a.nr() == b.nr() || a.size() == 0 || b.size() == 0
+ - a and b both contain the same type of element
+ ensures
+ - This function joins two matrices together by concatenating their rows.
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R.nr() == a.nr() == b.nr()
+ - R.nc() == a.nc() + b.nc()
+ - for all valid r and c:
+ - if (c < a.nc()) then
+ - R(r,c) == a(r,c)
+ - else
+ - R(r,c) == b(r, c-a.nc())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp join_cols (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a.nc() == b.nc() || a.size() == 0 || b.size() == 0
+ - a and b both contain the same type of element
+ ensures
+ - This function joins two matrices together by concatenating their columns.
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R.nr() == a.nr() + b.nr()
+ - R.nc() == a.nc() == b.nc()
+ - for all valid r and c:
+ - if (r < a.nr()) then
+ - R(r,c) == a(r,c)
+ - else
+ - R(r,c) == b(r-a.nr(), c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp tensor_product (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a and b both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R.nr() == a.nr() * b.nr()
+ - R.nc() == a.nc() * b.nc()
+ - for all valid r and c:
+ R(r,c) == a(r/b.nr(), c/b.nc()) * b(r%b.nr(), c%b.nc())
+ - I.e. R is the tensor product of matrix a with matrix b
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp cartesian_product (
+ const matrix_exp& A,
+ const matrix_exp& B
+ );
+ /*!
+ requires
+ - A and B both contain the same type of element
+ ensures
+ - Think of A and B as sets of column vectors. Then this function
+ returns a matrix that contains a set of column vectors that is
+ the Cartesian product of the sets A and B. That is, the resulting
+ matrix contains every possible combination of vectors from both A and
+ B.
+ - returns a matrix R such that:
+ - R::type == the same type that was in A and B.
+ - R.nr() == A.nr() + B.nr()
+ - R.nc() == A.nc() * B.nc()
+ - Each column of R is the concatenation of a column vector
+ from A with a column vector from B.
+ - for all valid r and c:
+ - if (r < A.nr()) then
+ - R(r,c) == A(r, c/B.nc())
+ - else
+ - R(r,c) == B(r-A.nr(), c%B.nc())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp scale_columns (
+ const matrix_exp& m,
+ const matrix_exp& v
+ );
+ /*!
+ requires
+ - is_vector(v) == true
+ - v.size() == m.nc()
+ - m and v both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m and v.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ R(r,c) == m(r,c) * v(c)
+ - i.e. R is the result of multiplying each of m's columns by
+ the corresponding scalar in v.
+
+ - Note that this function is identical to the expression m*diagm(v).
+ That is, the * operator is overloaded for this case and will invoke
+ scale_columns() automatically as appropriate.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp scale_rows (
+ const matrix_exp& m,
+ const matrix_exp& v
+ );
+ /*!
+ requires
+ - is_vector(v) == true
+ - v.size() == m.nr()
+ - m and v both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m and v.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ R(r,c) == m(r,c) * v(r)
+ - i.e. R is the result of multiplying each of m's rows by
+ the corresponding scalar in v.
+
+ - Note that this function is identical to the expression diagm(v)*m.
+ That is, the * operator is overloaded for this case and will invoke
+ scale_rows() automatically as appropriate.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void sort_columns (
+ matrix<T>& m,
+ matrix<T>& v
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - v.size() == m.nc()
+ - m and v both contain the same type of element
+ ensures
+ - the dimensions for m and v are not changed
+ - sorts the columns of m according to the values in v.
+ i.e.
+ - #v == the contents of v but in sorted order according to
+ operator<. So smaller elements come first.
+ - Let #v(new(i)) == v(i) (i.e. new(i) is the index element i moved to)
+ - colm(#m,new(i)) == colm(m,i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void rsort_columns (
+ matrix<T>& m,
+ matrix<T>& v
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - v.size() == m.nc()
+ - m and v both contain the same type of element
+ ensures
+ - the dimensions for m and v are not changed
+ - sorts the columns of m according to the values in v.
+ i.e.
+ - #v == the contents of v but in sorted order according to
+ operator>. So larger elements come first.
+ - Let #v(new(i)) == v(i) (i.e. new(i) is the index element i moved to)
+ - colm(#m,new(i)) == colm(m,i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type length_squared (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_vector(m) == true
+ ensures
+ - returns sum(squared(m))
+ (i.e. returns the square of the length of the vector m)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type length (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_vector(m) == true
+ ensures
+ - returns sqrt(sum(squared(m)))
+ (i.e. returns the length of the vector m)
+ - if (m contains integer valued elements) then
+ - The return type is a double that represents the length. Therefore, the
+ return value of length() is always represented using a floating point
+ type.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_row_vector (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - if (m.nr() == 1) then
+ - return true
+ - else
+ - returns false
+ !*/
+
+ bool is_col_vector (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - if (m.nc() == 1) then
+ - return true
+ - else
+ - returns false
+ !*/
+
+ bool is_vector (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - if (is_row_vector(m) || is_col_vector(m)) then
+ - return true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_finite (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns true if all the values in m are finite values and also not any kind
+ of NaN value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Thresholding relational operators
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator< (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) < s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator< (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s < m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator<= (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) <= s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator<= (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s <= m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator> (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) > s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator> (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s > m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator>= (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) >= s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator>= (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s >= m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator== (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) == s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator== (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s == m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator!= (
+ const matrix_exp& m,
+ const S& s
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (m(r,c) != s) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename S>
+ const matrix_exp operator!= (
+ const S& s,
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_built_in_scalar_type<S>::value == true
+ - is_built_in_scalar_type<matrix_exp::type>::value == true
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m.
+ - R has the same dimensions as m.
+ - for all valid r and c:
+ - if (s != m(r,c)) then
+ - R(r,c) == 1
+ - else
+ - R(r,c) == 0
+ - i.e. R is a binary matrix of all 1s or 0s.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Statistics
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type min (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns the value of the smallest element of m. If m contains complex
+ elements then the element returned is the one with the smallest norm
+ according to std::norm().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp min_pointwise (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a.nr() == b.nr()
+ - a.nc() == b.nc()
+ - a and b both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R has the same dimensions as a and b.
+ - for all valid r and c:
+ R(r,c) == std::min(a(r,c), b(r,c))
+ !*/
+
+ const matrix_exp min_pointwise (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp& c
+ );
+ /*!
+ performs min_pointwise(a,min_pointwise(b,c));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type max (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns the value of the biggest element of m. If m contains complex
+ elements then the element returned is the one with the largest norm
+ according to std::norm().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp max_pointwise (
+ const matrix_exp& a,
+ const matrix_exp& b
+ );
+ /*!
+ requires
+ - a.nr() == b.nr()
+ - a.nc() == b.nc()
+ - a and b both contain the same type of element
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in a and b.
+ - R has the same dimensions as a and b.
+ - for all valid r and c:
+ R(r,c) == std::max(a(r,c), b(r,c))
+ !*/
+
+ const matrix_exp max_pointwise (
+ const matrix_exp& a,
+ const matrix_exp& b,
+ const matrix_exp& c
+ );
+ /*!
+ performs max_pointwise(a,max_pointwise(b,c));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void find_min_and_max (
+ const matrix_exp& m,
+ matrix_exp::type& min_val,
+ matrix_exp::type& max_val
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - #min_val == min(m)
+ - #max_val == max(m)
+ - This function computes both the min and max in just one pass
+ over the elements of the matrix m.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ long index_of_max (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_vector(m) == true
+ - m.size() > 0
+ ensures
+ - returns the index of the largest element in m.
+ (i.e. m(index_of_max(m)) == max(m))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ long index_of_min (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - is_vector(m) == true
+ - m.size() > 0
+ ensures
+ - returns the index of the smallest element in m.
+ (i.e. m(index_of_min(m)) == min(m))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point max_point (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns the location of the maximum element of the array, that is, if the
+ returned point is P then it will be the case that: m(P.y(),P.x()) == max(m).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::vector<double,2> max_point_interpolated (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - Like max_point(), this function finds the location in m with the largest
+ value. However, we additionally use some quadratic interpolation to find the
+ location of the maximum point with sub-pixel accuracy. Therefore, the
+ returned point is equal to max_point(m) + some small sub-pixel delta.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ point min_point (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns the location of the minimum element of the array, that is, if the
+ returned point is P then it will be the case that: m(P.y(),P.x()) == min(m).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type sum (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the sum of all elements in m
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sum_rows (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns a row matrix that contains the sum of all the rows in m.
+ - returns a matrix M such that
+ - M::type == the same type that was in m
+ - M.nr() == 1
+ - M.nc() == m.nc()
+ - for all valid i:
+ - M(i) == sum(colm(m,i))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp sum_cols (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - m.size() > 0
+ ensures
+ - returns a column matrix that contains the sum of all the columns in m.
+ - returns a matrix M such that
+ - M::type == the same type that was in m
+ - M.nr() == m.nr()
+ - M.nc() == 1
+ - for all valid i:
+ - M(i) == sum(rowm(m,i))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type prod (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the results of multiplying all elements of m together.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type mean (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the mean of all elements in m.
+ (i.e. returns sum(m)/(m.nr()*m.nc()))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type variance (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns the unbiased sample variance of all elements in m
+ (i.e. 1.0/(m.nr()*m.nc() - 1)*(sum of all pow(m(i,j) - mean(m),2)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp::type stddev (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns sqrt(variance(m))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix covariance (
+ const matrix_exp& m
+ );
+ /*!
+ requires
+ - matrix_exp::type == a dlib::matrix object
+ - is_col_vector(m) == true
+ - m.size() > 1
+ - for all valid i, j:
+ - is_col_vector(m(i)) == true
+ - m(i).size() > 0
+ - m(i).size() == m(j).size()
+ - i.e. m contains only column vectors and all the column vectors
+ have the same non-zero length
+ ensures
+ - returns the unbiased sample covariance matrix for the set of samples
+ in m.
+ (i.e. 1.0/(m.nr()-1)*(sum of all (m(i) - mean(m))*trans(m(i) - mean(m))))
+ - the returned matrix will contain elements of type matrix_exp::type::type.
+ - the returned matrix will have m(0).nr() rows and columns.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename rand_gen>
+ const matrix<double> randm(
+ long nr,
+ long nc,
+ rand_gen& rnd
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ - rand_gen == an object that implements the rand/rand_float_abstract.h interface
+ ensures
+ - generates a random matrix using the given rnd random number generator
+ - returns a matrix M such that
+ - M::type == double
+ - M.nr() == nr
+ - M.nc() == nc
+ - for all valid i, j:
+ - M(i,j) == a random number such that 0 <= M(i,j) < 1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix<double> randm(
+ long nr,
+ long nc
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ ensures
+ - generates a random matrix using std::rand()
+ - returns a matrix M such that
+ - M::type == double
+ - M.nr() == nr
+ - M.nc() == nc
+ - for all valid i, j:
+ - M(i,j) == a random number such that 0 <= M(i,j) < 1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline const matrix_exp gaussian_randm (
+ long nr,
+ long nc,
+ unsigned long seed = 0
+ );
+ /*!
+ requires
+ - nr >= 0
+ - nc >= 0
+ ensures
+ - returns a matrix with its values filled with 0 mean unit variance Gaussian
+ random numbers.
+ - Each setting of the seed results in a different random matrix.
+ - The returned matrix is lazily evaluated using the expression templates
+ technique. This means that the returned matrix doesn't take up any memory
+ and is only an expression template. The values themselves are computed on
+ demand using the gaussian_random_hash() routine.
+ - returns a matrix M such that
+ - M::type == double
+ - M.nr() == nr
+ - M.nc() == nc
+ - for all valid i, j:
+ - M(i,j) == gaussian_random_hash(i,j,seed)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Pixel and Image Utilities
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename P
+ >
+ const matrix<T,pixel_traits<P>::num,1> pixel_to_vector (
+ const P& pixel
+ );
+ /*!
+ requires
+ - pixel_traits<P> must be defined
+ ensures
+ - returns a matrix M such that:
+ - M::type == T
+ - M::NC == 1
+ - M::NR == pixel_traits<P>::num
+ - if (pixel_traits<P>::grayscale) then
+ - M(0) == pixel
+ - if (pixel_traits<P>::rgb) then
+ - M(0) == pixel.red
+ - M(1) == pixel.green
+ - M(2) == pixel.blue
+ - if (pixel_traits<P>::hsi) then
+ - M(0) == pixel.h
+ - M(1) == pixel.s
+ - M(2) == pixel.i
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P
+ >
+ void vector_to_pixel (
+ P& pixel,
+ const matrix_exp& vector
+ );
+ /*!
+ requires
+ - vector::NR == pixel_traits<P>::num
+ - vector::NC == 1
+ (i.e. you have to use a statically dimensioned vector)
+ ensures
+ - if (pixel_traits<P>::grayscale) then
+ - pixel == M(0)
+ - if (pixel_traits<P>::rgb) then
+ - pixel.red == M(0)
+ - pixel.green == M(1)
+ - pixel.blue == M(2)
+ - if (pixel_traits<P>::hsi) then
+ - pixel.h == M(0)
+ - pixel.s == M(1)
+ - pixel.i == M(2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ long lower,
+ long upper
+ >
+ const matrix_exp clamp (
+ const matrix_exp& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) > upper) then
+ - R(r,c) == upper
+ - else if (m(r,c) < lower) then
+ - R(r,c) == lower
+ - else
+ - R(r,c) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp clamp (
+ const matrix_exp& m,
+ const matrix_exp::type& lower,
+ const matrix_exp::type& upper
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) > upper) then
+ - R(r,c) == upper
+ - else if (m(r,c) < lower) then
+ - R(r,c) == lower
+ - else
+ - R(r,c) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp clamp (
+ const matrix_exp& m,
+ const matrix_exp& lower,
+ const matrix_exp& upper
+ );
+ /*!
+ requires
+ - m.nr() == lower.nr()
+ - m.nc() == lower.nc()
+ - m.nr() == upper.nr()
+ - m.nc() == upper.nc()
+ - m, lower, and upper all contain the same type of elements.
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) > upper(r,c)) then
+ - R(r,c) == upper(r,c)
+ - else if (m(r,c) < lower(r,c)) then
+ - R(r,c) == lower(r,c)
+ - else
+ - R(r,c) == m(r,c)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp lowerbound (
+ const matrix_exp& m,
+ const matrix_exp::type& thresh
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) >= thresh) then
+ - R(r,c) == m(r,c)
+ - else
+ - R(r,c) == thresh
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const matrix_exp upperbound (
+ const matrix_exp& m,
+ const matrix_exp::type& thresh
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R::type == the same type that was in m
+ - R has the same dimensions as m
+ - for all valid r and c:
+ - if (m(r,c) <= thresh) then
+ - R(r,c) == m(r,c)
+ - else
+ - R(r,c) == thresh
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIx_UTILITIES_ABSTRACT_
+
diff --git a/ml/dlib/dlib/matrix/symmetric_matrix_cache.h b/ml/dlib/dlib/matrix/symmetric_matrix_cache.h
new file mode 100644
index 000000000..bff268aef
--- /dev/null
+++ b/ml/dlib/dlib/matrix/symmetric_matrix_cache.h
@@ -0,0 +1,464 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_
+#define DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_
+
+#include "symmetric_matrix_cache_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../array.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename cache_element_type>
+ struct op_symm_cache : basic_op_m<M>
+ {
+ inline op_symm_cache(
+ const M& m_,
+ long max_size_megabytes_
+ ) :
+ basic_op_m<M>(m_),
+ max_size_megabytes(max_size_megabytes_),
+ is_initialized(false)
+ {
+ lookup.assign(this->m.nr(), -1);
+
+ diag_cache = matrix_cast<cache_element_type>(dlib::diag(m_));
+ }
+
+ op_symm_cache (
+ const op_symm_cache& item
+ ) :
+ basic_op_m<M>(item.m),
+ diag_cache(item.diag_cache),
+ max_size_megabytes(item.max_size_megabytes),
+ is_initialized(false)
+ {
+ lookup.assign(this->m.nr(), -1);
+ }
+
+ typedef cache_element_type type;
+ typedef const cache_element_type& const_ret_type;
+ const static long cost = M::cost + 3;
+
+ inline const_ret_type apply ( long r, long c) const
+ {
+ if (lookup[c] != -1)
+ {
+ return cache[lookup[c]](r);
+ }
+ else if (r == c)
+ {
+ return diag_cache(r);
+ }
+ else if (lookup[r] != -1)
+ {
+ // the matrix is symmetric so this is legit
+ return cache[lookup[r]](c);
+ }
+ else
+ {
+ add_col_to_cache(c);
+ return cache[lookup[c]](r);
+ }
+ }
+
+ inline std::pair<const type*,long*> col(long i) const
+ /*!
+ requires
+ - 0 <= i < nc()
+ ensures
+ - returns a pair P such that:
+ - P.first == a pointer to the first element of the ith column
+ - P.second == a pointer to the integer used to count the number of
+ outstanding references to the ith column.
+ !*/
+ {
+ if (is_cached(i) == false)
+ add_col_to_cache(i);
+
+ // find where this column is in the cache
+ long idx = lookup[i];
+ if (idx == next)
+ {
+ // if this column was the next to be replaced
+ // then make sure that doesn't happen
+ next = (next + 1)%cache.size();
+ }
+
+ return std::make_pair(&cache[idx](0), &references[idx]);
+ }
+
+ const type* diag() const { init(); return &diag_cache(0); }
+
+ long* diag_ref_count() const
+ {
+ return &diag_reference_count;
+ }
+
+ private:
+ inline bool is_cached (
+ long r
+ ) const
+ {
+ return (lookup[r] != -1);
+ }
+
+ inline void init() const
+ {
+ if (is_initialized == false)
+ {
+ // figure out how many columns of the matrix we can have
+ // with the given amount of memory.
+ long max_size = (max_size_megabytes*1024*1024)/(this->m.nr()*sizeof(type));
+ // don't let it be 0 or 1
+ if (max_size <= 1)
+ max_size = 2;
+
+ const long size = std::min(max_size,this->m.nr());
+
+ diag_reference_count = 0;
+
+ references.set_max_size(this->m.nr());
+ references.set_size(size);
+ for (unsigned long i = 0; i < references.size(); ++i)
+ references[i] = 0;
+
+ cache.set_max_size(this->m.nr());
+ cache.set_size(size);
+
+ rlookup.assign(size,-1);
+ next = 0;
+
+ is_initialized = true;
+ }
+ }
+
+ void make_sure_next_is_unreferenced (
+ ) const
+ {
+ if (references[next] != 0)
+ {
+ // find an unreferenced element of the cache
+ unsigned long i;
+ for (i = 1; i < references.size(); ++i)
+ {
+ const unsigned long idx = (next+i)%references.size();
+ if (references[idx] == 0)
+ {
+ next = idx;
+ break;
+ }
+ }
+
+ // if all elements of the cache are referenced then make the cache bigger
+ // and use the new element.
+ if (references[next] != 0)
+ {
+ cache.resize(cache.size()+1);
+
+ next = references.size();
+ references.resize(references.size()+1);
+ references[next] = 0;
+
+ rlookup.push_back(-1);
+ }
+ }
+ }
+
+ inline void add_col_to_cache(
+ long c
+ ) const
+ {
+ init();
+ make_sure_next_is_unreferenced();
+
+ // if the lookup table is pointing to cache[next] then clear lookup[next]
+ if (rlookup[next] != -1)
+ lookup[rlookup[next]] = -1;
+
+ // make the lookup table so that it says c is now cached at the spot indicated by next
+ lookup[c] = next;
+ rlookup[next] = c;
+
+ // compute this column in the matrix and store it in the cache
+ cache[next] = matrix_cast<cache_element_type>(colm(this->m,c));
+
+ next = (next + 1)%cache.size();
+ }
+
+ /*!
+ INITIAL VALUE
+ - for all valid x:
+ - lookup(x) == -1
+
+ - diag_cache == the diagonal of the original matrix
+ - is_initialized == false
+ - max_size_megabytes == the max_size_megabytes from symmetric_matrix_cache()
+
+ CONVENTION
+ - diag_cache == the diagonal of the original matrix
+ - lookup.size() == diag_cache.size()
+
+ - if (is_initialized) then
+ - if (lookup[c] != -1) then
+ - cache[lookup[c]] == the cached column c of the matrix
+ - rlookup[lookup[c]] == c
+
+ - if (rlookup[x] != -1) then
+ - lookup[rlookup[x]] == x
+ - cache[x] == the cached column rlookup[x] of the matrix
+
+ - next == the next element in the cache table to use to cache something
+ - references[i] == the number of outstanding references to cache element cache[i]
+
+ - diag_reference_count == the number of outstanding references to diag_cache.
+ (this isn't really needed. It's just here so that we can reuse the matrix
+ expression from colm() to implement diag())
+ !*/
+
+
+ mutable array<matrix<type,0,1,typename M::mem_manager_type> > cache;
+ mutable array<long> references;
+ matrix<type,0,1,typename M::mem_manager_type> diag_cache;
+ mutable std::vector<long> lookup;
+ mutable std::vector<long> rlookup;
+ mutable long next;
+
+ const long max_size_megabytes;
+ mutable bool is_initialized;
+ mutable long diag_reference_count;
+
+ };
+
+ template <
+ typename cache_element_type,
+ typename EXP
+ >
+ const matrix_op<op_symm_cache<EXP,cache_element_type> > symmetric_matrix_cache (
+ const matrix_exp<EXP>& m,
+ long max_size_megabytes
+ )
+ {
+ // Don't check that m is symmetric since doing so would be extremely onerous for the
+ // kinds of matrices intended for use with the symmetric_matrix_cache. Check everything
+ // else though.
+ DLIB_ASSERT(m.size() > 0 && m.nr() == m.nc() && max_size_megabytes >= 0,
+ "\tconst matrix_exp symmetric_matrix_cache(const matrix_exp& m, max_size_megabytes)"
+ << "\n\t You have given invalid arguments to this function"
+ << "\n\t m.nr(): " << m.nr()
+ << "\n\t m.nc(): " << m.nc()
+ << "\n\t m.size(): " << m.size()
+ << "\n\t max_size_megabytes: " << max_size_megabytes
+ );
+
+ typedef op_symm_cache<EXP,cache_element_type> op;
+ return matrix_op<op>(op(m.ref(), max_size_megabytes));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename cache_element_type>
+ struct op_colm_symm_cache
+ {
+ typedef cache_element_type type;
+
+ op_colm_symm_cache(
+ const M& m_,
+ const type* data_,
+ long* ref_count_
+ ) :
+ m(m_),
+ data(data_),
+ ref_count(ref_count_)
+ {
+ *ref_count += 1;
+ }
+
+ op_colm_symm_cache (
+ const op_colm_symm_cache& item
+ ) :
+ m(item.m),
+ data(item.data),
+ ref_count(item.ref_count)
+ {
+ *ref_count += 1;
+ }
+
+ ~op_colm_symm_cache(
+ )
+ {
+ *ref_count -= 1;
+ }
+
+ const M& m;
+
+ const type* const data;
+ long* const ref_count;
+
+ const static long cost = M::cost;
+ const static long NR = M::NR;
+ const static long NC = 1;
+ typedef const type& const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ inline const_ret_type apply ( long r, long) const { return data[r]; }
+
+ long nr () const { return m.nr(); }
+ long nc () const { return 1; }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP,
+ typename cache_element_type
+ >
+ inline const matrix_op<op_colm_symm_cache<EXP,cache_element_type> > colm (
+ const matrix_exp<matrix_op<op_symm_cache<EXP,cache_element_type> > >& m,
+ long col
+ )
+ {
+ DLIB_ASSERT(col >= 0 && col < m.nc(),
+ "\tconst matrix_exp colm(const matrix_exp& m, row)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\tcol: " << col
+ );
+
+ std::pair<const cache_element_type*,long*> p = m.ref().op.col(col);
+
+ typedef op_colm_symm_cache<EXP,cache_element_type> op;
+ return matrix_op<op>(op(m.ref().op.m,
+ p.first,
+ p.second));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename cache_element_type
+ >
+ inline const matrix_op<op_colm_symm_cache<EXP,cache_element_type> > diag (
+ const matrix_exp<matrix_op<op_symm_cache<EXP,cache_element_type> > >& m
+ )
+ {
+ typedef op_colm_symm_cache<EXP,cache_element_type> op;
+ return matrix_op<op>(op(m.ref().op.m,
+ m.ref().op.diag(),
+ m.ref().op.diag_ref_count()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M, typename cache_element_type>
+ struct op_rowm_symm_cache
+ {
+ typedef cache_element_type type;
+
+ op_rowm_symm_cache(
+ const M& m_,
+ const type* data_,
+ long* ref_count_
+ ) :
+ m(m_),
+ data(data_),
+ ref_count(ref_count_)
+ {
+ *ref_count += 1;
+ }
+
+ op_rowm_symm_cache (
+ const op_rowm_symm_cache& item
+ ) :
+ m(item.m),
+ data(item.data),
+ ref_count(item.ref_count)
+ {
+ *ref_count += 1;
+ }
+
+ ~op_rowm_symm_cache(
+ )
+ {
+ *ref_count -= 1;
+ }
+
+ const M& m;
+
+ const type* const data;
+ long* const ref_count;
+
+ const static long cost = M::cost;
+ const static long NR = 1;
+ const static long NC = M::NC;
+ typedef const type& const_ret_type;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+ inline const_ret_type apply ( long , long c) const { return data[c]; }
+
+ long nr () const { return 1; }
+ long nc () const { return m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP,
+ typename cache_element_type
+ >
+ inline const matrix_op<op_rowm_symm_cache<EXP,cache_element_type> > rowm (
+ const matrix_exp<matrix_op<op_symm_cache<EXP,cache_element_type> > >& m,
+ long row
+ )
+ {
+ DLIB_ASSERT(row >= 0 && row < m.nr(),
+ "\tconst matrix_exp rowm(const matrix_exp& m, row)"
+ << "\n\tYou have specified invalid sub matrix dimensions"
+ << "\n\tm.nr(): " << m.nr()
+ << "\n\tm.nc(): " << m.nc()
+ << "\n\trow: " << row
+ );
+
+ std::pair<const cache_element_type*,long*> p = m.ref().op.col(row);
+
+ typedef op_rowm_symm_cache<EXP,cache_element_type> op;
+ return matrix_op<op>(op(m.ref().op.m,
+ p.first,
+ p.second));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP, typename cache_element_type>
+ struct colm_exp<matrix_op<op_symm_cache<EXP, cache_element_type> > >
+ {
+ typedef matrix_op<op_colm_symm_cache<EXP, cache_element_type> > type;
+ };
+
+ template <typename EXP, typename cache_element_type>
+ struct rowm_exp<matrix_op<op_symm_cache<EXP, cache_element_type> > >
+ {
+ typedef matrix_op<op_rowm_symm_cache<EXP, cache_element_type> > type;
+ };
+
+ template <typename EXP, typename cache_element_type>
+ struct diag_exp<matrix_op<op_symm_cache<EXP, cache_element_type> > >
+ {
+ typedef matrix_op<op_colm_symm_cache<EXP, cache_element_type> > type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_
+
diff --git a/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h b/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h
new file mode 100644
index 000000000..6a41ad282
--- /dev/null
+++ b/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#define DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_
+#ifndef DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_
+
+#include "matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename cache_element_type
+ >
+ const matrix_exp symmetric_matrix_cache (
+ const matrix_exp& m,
+ long max_size_megabytes
+ );
+ /*!
+ requires
+ - m.size() > 0
+ - m.nr() == m.nc()
+ - max_size_megabytes >= 0
+ ensures
+ - This function assumes that m is symmetric. If m is not symmetric then it won't
+ crash but you will get incorrect results.
+ - This method creates a matrix expression which internally caches the elements
+ of m so that they can be accessed quickly. It is useful if m is some kind of
+ complex matrix expression which is both very large and expensive to evaluate.
+ An example would be a kernel_matrix() expression with an expensive kernel and
+ a large number of samples. Such an expression would result in a huge matrix,
+ potentially too big to store in memory. The symmetric_matrix_cache() then makes
+ it easy to store just the parts of a matrix expression which are accessed most
+ often in memory. The specific details are defined below.
+ - returns a matrix M such that
+ - M == m
+ (i.e. M represents the same matrix as m)
+ - M will cache elements of m and hold them internally so they can be quickly
+ accessed. In particular, M will attempt to allocate no more than
+ max_size_megabytes megabytes of memory for the purposes of caching
+ elements of m. When an element of the matrix is accessed it is either
+ retrieved from the cache, or if this is not possible, then an entire
+ column of m is loaded into a part of the cache which hasn't been used
+ recently and the needed element returned.
+ - diag(m) is always loaded into the cache and is stored separately from
+ the cached columns. That means accesses to the diagonal elements of m
+ are always fast.
+ - M will store the cached elements of m as cache_element_type objects.
+ Typically, cache_element_type will be float or double.
+ - To avoid repeated cache lookups, the following operations are optimized for
+ use with the symmetric_matrix_cache():
+ - diag(M), rowm(M,row_idx), colm(M,col_idx)
+ These methods will perform only one cache lookup operation for an
+ entire row/column/diagonal worth of data.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/md5.h b/ml/dlib/dlib/md5.h
new file mode 100644
index 000000000..e62930366
--- /dev/null
+++ b/ml/dlib/dlib/md5.h
@@ -0,0 +1,3 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include "md5/md5_kernel_1.h"
diff --git a/ml/dlib/dlib/md5/md5_kernel_1.cpp b/ml/dlib/dlib/md5/md5_kernel_1.cpp
new file mode 100644
index 000000000..f073f9256
--- /dev/null
+++ b/ml/dlib/dlib/md5/md5_kernel_1.cpp
@@ -0,0 +1,617 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MD5_KERNEL_1_CPp_
+#define DLIB_MD5_KERNEL_1_CPp_
+#include "md5_kernel_1.h"
+#include "../uintn.h"
+
+#include <sstream>
+#include <cstring>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace md5_stuff
+ {
+
+ inline uint32 F (
+ uint32 x,
+ uint32 y,
+ uint32 z
+ )
+ {
+ return ( (x&y) | ((~x)&z) );
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline uint32 G (
+ uint32 x,
+ uint32 y,
+ uint32 z
+ )
+ {
+ return ( (x&z) | (y&(~z)) );
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline uint32 H (
+ uint32 x,
+ uint32 y,
+ uint32 z
+ )
+ {
+ return ( x^y^z );
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline uint32 I (
+ uint32 x,
+ uint32 y,
+ uint32 z
+ )
+ {
+ return ( y ^ (x|(~z)) );
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline uint32 rotate_left (
+ uint32 x,
+ uint32 n
+ )
+ {
+ return ( (x<<n) | (x>>(32-n)) );
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void FF (
+ uint32& a,
+ uint32 b,
+ uint32 c,
+ uint32 d,
+ uint32 x,
+ uint32 s,
+ uint32 ac
+ )
+ {
+ a += F(b, c, d) + x + ac;
+ a = rotate_left(a, s);
+ a += b;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void GG (
+ uint32& a,
+ uint32 b,
+ uint32 c,
+ uint32 d,
+ uint32 x,
+ uint32 s,
+ uint32 ac
+ )
+ {
+ a += G(b, c, d) + x + ac;
+ a = rotate_left(a, s);
+ a += b;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void HH (
+ uint32& a,
+ uint32 b,
+ uint32 c,
+ uint32 d,
+ uint32 x,
+ uint32 s,
+ uint32 ac
+ )
+ {
+ a += H(b, c, d) + x + ac;
+ a = rotate_left(a, s);
+ a += b;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline void II (
+ uint32& a,
+ uint32 b,
+ uint32 c,
+ uint32 d,
+ uint32 x,
+ uint32 s,
+ uint32 ac
+ )
+ {
+ a += I(b, c, d) + x + ac;
+ a = rotate_left(a, s);
+ a += b;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void scramble_block (
+ uint32& a,
+ uint32& b,
+ uint32& c,
+ uint32& d,
+ uint32* x
+ )
+ {
+ const uint32 S11 = 7;
+ const uint32 S12 = 12;
+ const uint32 S13 = 17;
+ const uint32 S14 = 22;
+ const uint32 S21 = 5;
+ const uint32 S22 = 9;
+ const uint32 S23 = 14;
+ const uint32 S24 = 20;
+ const uint32 S31 = 4;
+ const uint32 S32 = 11;
+ const uint32 S33 = 16;
+ const uint32 S34 = 23;
+ const uint32 S41 = 6;
+ const uint32 S42 = 10;
+ const uint32 S43 = 15;
+ const uint32 S44 = 21;
+
+
+ // round 1
+ FF (a, b, c, d, x[ 0], S11, 0xd76aa478); // 1
+ FF (d, a, b, c, x[ 1], S12, 0xe8c7b756); // 2
+ FF (c, d, a, b, x[ 2], S13, 0x242070db); // 3
+ FF (b, c, d, a, x[ 3], S14, 0xc1bdceee); // 4
+ FF (a, b, c, d, x[ 4], S11, 0xf57c0faf); // 5
+ FF (d, a, b, c, x[ 5], S12, 0x4787c62a); // 6
+ FF (c, d, a, b, x[ 6], S13, 0xa8304613); // 7
+ FF (b, c, d, a, x[ 7], S14, 0xfd469501); // 8
+ FF (a, b, c, d, x[ 8], S11, 0x698098d8); // 9
+ FF (d, a, b, c, x[ 9], S12, 0x8b44f7af); // 10
+ FF (c, d, a, b, x[10], S13, 0xffff5bb1); // 11
+ FF (b, c, d, a, x[11], S14, 0x895cd7be); // 12
+ FF (a, b, c, d, x[12], S11, 0x6b901122); // 13
+ FF (d, a, b, c, x[13], S12, 0xfd987193); // 14
+ FF (c, d, a, b, x[14], S13, 0xa679438e); // 15
+ FF (b, c, d, a, x[15], S14, 0x49b40821); // 16
+
+ // Round 2
+ GG (a, b, c, d, x[ 1], S21, 0xf61e2562); // 17
+ GG (d, a, b, c, x[ 6], S22, 0xc040b340); // 18
+ GG (c, d, a, b, x[11], S23, 0x265e5a51); // 19
+ GG (b, c, d, a, x[ 0], S24, 0xe9b6c7aa); // 20
+ GG (a, b, c, d, x[ 5], S21, 0xd62f105d); // 21
+ GG (d, a, b, c, x[10], S22, 0x2441453); // 22
+ GG (c, d, a, b, x[15], S23, 0xd8a1e681); // 23
+ GG (b, c, d, a, x[ 4], S24, 0xe7d3fbc8); // 24
+ GG (a, b, c, d, x[ 9], S21, 0x21e1cde6); // 25
+ GG (d, a, b, c, x[14], S22, 0xc33707d6); // 26
+ GG (c, d, a, b, x[ 3], S23, 0xf4d50d87); // 27
+ GG (b, c, d, a, x[ 8], S24, 0x455a14ed); // 28
+ GG (a, b, c, d, x[13], S21, 0xa9e3e905); // 29
+ GG (d, a, b, c, x[ 2], S22, 0xfcefa3f8); // 30
+ GG (c, d, a, b, x[ 7], S23, 0x676f02d9); // 31
+ GG (b, c, d, a, x[12], S24, 0x8d2a4c8a); // 32
+
+ // Round 3
+ HH (a, b, c, d, x[ 5], S31, 0xfffa3942); // 33
+ HH (d, a, b, c, x[ 8], S32, 0x8771f681); // 34
+ HH (c, d, a, b, x[11], S33, 0x6d9d6122); // 35
+ HH (b, c, d, a, x[14], S34, 0xfde5380c); // 36
+ HH (a, b, c, d, x[ 1], S31, 0xa4beea44); // 37
+ HH (d, a, b, c, x[ 4], S32, 0x4bdecfa9); // 38
+ HH (c, d, a, b, x[ 7], S33, 0xf6bb4b60); // 39
+ HH (b, c, d, a, x[10], S34, 0xbebfbc70); // 40
+ HH (a, b, c, d, x[13], S31, 0x289b7ec6); // 41
+ HH (d, a, b, c, x[ 0], S32, 0xeaa127fa); // 42
+ HH (c, d, a, b, x[ 3], S33, 0xd4ef3085); // 43
+ HH (b, c, d, a, x[ 6], S34, 0x4881d05); // 44
+ HH (a, b, c, d, x[ 9], S31, 0xd9d4d039); // 45
+ HH (d, a, b, c, x[12], S32, 0xe6db99e5); // 46
+ HH (c, d, a, b, x[15], S33, 0x1fa27cf8); // 47
+ HH (b, c, d, a, x[ 2], S34, 0xc4ac5665); // 48
+
+ // Round 4
+ II (a, b, c, d, x[ 0], S41, 0xf4292244); // 49
+ II (d, a, b, c, x[ 7], S42, 0x432aff97); // 50
+ II (c, d, a, b, x[14], S43, 0xab9423a7); // 51
+ II (b, c, d, a, x[ 5], S44, 0xfc93a039); // 52
+ II (a, b, c, d, x[12], S41, 0x655b59c3); // 53
+ II (d, a, b, c, x[ 3], S42, 0x8f0ccc92); // 54
+ II (c, d, a, b, x[10], S43, 0xffeff47d); // 55
+ II (b, c, d, a, x[ 1], S44, 0x85845dd1); // 56
+ II (a, b, c, d, x[ 8], S41, 0x6fa87e4f); // 57
+ II (d, a, b, c, x[15], S42, 0xfe2ce6e0); // 58
+ II (c, d, a, b, x[ 6], S43, 0xa3014314); // 59
+ II (b, c, d, a, x[13], S44, 0x4e0811a1); // 60
+ II (a, b, c, d, x[ 4], S41, 0xf7537e82); // 61
+ II (d, a, b, c, x[11], S42, 0xbd3af235); // 62
+ II (c, d, a, b, x[ 2], S43, 0x2ad7d2bb); // 63
+ II (b, c, d, a, x[ 9], S44, 0xeb86d391); // 64
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+
+ const std::string md5 (
+ const std::string& input
+ )
+ {
+ unsigned char output[16];
+ md5 (
+ reinterpret_cast<const unsigned char*>(input.data()),
+ static_cast<unsigned long>(input.size()),
+ output
+ );
+
+
+ std::stringstream temp;
+ for (int i = 0; i < 16; ++i)
+ {
+ temp.fill('0');
+ temp.width(2);
+ temp << std::hex << static_cast<unsigned int>(output[i]);
+ }
+
+ return temp.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ const unsigned char* input,
+ unsigned long len,
+ unsigned char* output
+ )
+ {
+ using namespace md5_stuff;
+
+
+
+
+ // make a temp version of input with enough space for padding and len appended
+ unsigned long extra_len = 64-len%64;
+ if (extra_len <= 8)
+ extra_len += 64;
+ unsigned char* temp = new unsigned char[extra_len + len];
+
+ // number of 16 word blocks
+ const unsigned long N = (extra_len + len)/64;
+
+ const unsigned char* input2 = input;
+ unsigned char* temp2 = temp;
+ unsigned char* end = temp+len;
+
+ // copy input into temp
+ while (temp2 != end)
+ {
+ *temp2 = *input2;
+ ++temp2;
+ ++input2;
+ }
+
+ // pad temp
+ end += extra_len-8;
+ *temp2 = static_cast<unsigned char>(0x80);
+ ++temp2;
+ while (temp2 != end)
+ {
+ *temp2 = 0;
+ ++temp2;
+ }
+
+ // make len the number of bits in the original message
+ // but first multiply len by 8 and since len is only 32 bits the number might
+ // overflow so we will carry out the multiplication manually and end up with
+ // the result in the base 65536 number with three digits
+ // result = low + high*65536 + upper*65536*65536
+ unsigned long low = len & 0xFFFF;
+ unsigned long high = len >> 16;
+ unsigned long upper;
+ unsigned long tmp;
+ tmp = low * 8;
+ low = tmp & 0xFFFF;
+ tmp = high * 8 + (tmp>>16);
+ high = tmp & 0xFFFF;
+ upper = tmp >> 16;
+
+
+ // append the length
+ *temp2 = static_cast<unsigned char>(low&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((low>>8)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((high)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((high>>8)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((upper)&0xFF);;
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((upper>>8)&0xFF);;
+ ++temp2;
+ *temp2 = 0;
+ ++temp2;
+ *temp2 = 0;
+
+
+ uint32 a = 0x67452301;
+ uint32 b = 0xefcdab89;
+ uint32 c = 0x98badcfe;
+ uint32 d = 0x10325476;
+
+
+ // an array of 16 words
+ uint32 x[16];
+
+ for (unsigned long i = 0; i < N; ++i)
+ {
+
+ // copy a block of 16 words from m into x
+ for (unsigned long j = 0; j < 16; ++j)
+ {
+ x[j] = (
+ (static_cast<uint32>(temp[4*(j + 16*i) + 3]) << 24) |
+ (static_cast<uint32>(temp[4*(j + 16*i) + 2]) << 16) |
+ (static_cast<uint32>(temp[4*(j + 16*i) + 1]) << 8 ) |
+ (static_cast<uint32>(temp[4*(j + 16*i) ]) )
+ );
+ }
+
+ uint32 aa = a;
+ uint32 bb = b;
+ uint32 cc = c;
+ uint32 dd = d;
+
+
+ scramble_block(a,b,c,d,x);
+
+
+ a = a + aa;
+ b = b + bb;
+ c = c + cc;
+ d = d + dd;
+
+ }
+
+
+ // put a, b, c, and d into output
+ output[0] = static_cast<unsigned char>((a) &0xFF);
+ output[1] = static_cast<unsigned char>((a>>8) &0xFF);
+ output[2] = static_cast<unsigned char>((a>>16)&0xFF);
+ output[3] = static_cast<unsigned char>((a>>24)&0xFF);
+
+ output[4] = static_cast<unsigned char>((b) &0xFF);
+ output[5] = static_cast<unsigned char>((b>>8) &0xFF);
+ output[6] = static_cast<unsigned char>((b>>16)&0xFF);
+ output[7] = static_cast<unsigned char>((b>>24)&0xFF);
+
+ output[8] = static_cast<unsigned char>((c) &0xFF);
+ output[9] = static_cast<unsigned char>((c>>8) &0xFF);
+ output[10] = static_cast<unsigned char>((c>>16)&0xFF);
+ output[11] = static_cast<unsigned char>((c>>24)&0xFF);
+
+ output[12] = static_cast<unsigned char>((d) &0xFF);
+ output[13] = static_cast<unsigned char>((d>>8) &0xFF);
+ output[14] = static_cast<unsigned char>((d>>16)&0xFF);
+ output[15] = static_cast<unsigned char>((d>>24)&0xFF);
+
+ delete [] temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string md5 (
+ std::istream& input
+ )
+ {
+ unsigned char output[16];
+ md5 (
+ input,
+ output
+ );
+
+
+ std::stringstream temp;
+ for (int i = 0; i < 16; ++i)
+ {
+ temp.fill('0');
+ temp.width(2);
+ temp << std::hex << static_cast<unsigned int>(output[i]);
+ }
+
+ return temp.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ std::istream& input,
+ unsigned char* output
+ )
+ {
+ using namespace md5_stuff;
+
+
+
+
+ uint32 a = 0x67452301;
+ uint32 b = 0xefcdab89;
+ uint32 c = 0x98badcfe;
+ uint32 d = 0x10325476;
+
+
+
+ unsigned long len = 0;
+
+ // an array of 16 words
+ uint32 x[16];
+ unsigned char temp[64];
+
+
+
+ bool write_length = false;
+ bool at_end = false;
+ std::streambuf& inputbuf = *input.rdbuf();
+ while(!at_end)
+ {
+ int num = inputbuf.sgetn(reinterpret_cast<char*>(temp),64);
+ len += num;
+
+ // if we hit the end of the stream then pad and add length
+ if (num < 64)
+ {
+ at_end = true;
+ unsigned char* temp2 = temp;
+ unsigned char* end;
+ if (num < 56)
+ end = temp+56;
+ else
+ end = temp+64;
+
+ temp2 += num;
+
+ // apply padding
+ *temp2 = 0x80;
+ ++temp2;
+ while (temp2 != end)
+ {
+ *temp2 = 0;
+ ++temp2;
+ }
+
+
+ if (num < 56)
+ {
+ write_length = true;
+ // make len the number of bits in the original message
+ // but first multiply len by 8 and since len is only 32 bits the number might
+ // overflow so we will carry out the multiplication manually and end up with
+ // the result in the base 65536 number with three digits
+ // result = low + high*65536 + upper*65536*65536
+ unsigned long low = len & 0xFFFF;
+ unsigned long high = len >> 16;
+ unsigned long upper;
+ unsigned long tmp;
+ tmp = low * 8;
+ low = tmp & 0xFFFF;
+ tmp = high * 8 + (tmp>>16);
+ high = tmp & 0xFFFF;
+ upper = tmp >> 16;
+
+
+ // append the length
+ *temp2 = static_cast<unsigned char>(low&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((low>>8)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((high)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((high>>8)&0xFF);
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((upper)&0xFF);;
+ ++temp2;
+ *temp2 = static_cast<unsigned char>((upper>>8)&0xFF);;
+ ++temp2;
+ *temp2 = 0;
+ ++temp2;
+ *temp2 = 0;
+ }
+
+
+ }
+
+
+ // copy a block of 16 words from m into x
+ for (unsigned long i = 0; i < 16; ++i)
+ {
+ x[i] = (
+ (static_cast<uint32>(temp[4*i + 3]) << 24) |
+ (static_cast<uint32>(temp[4*i + 2]) << 16) |
+ (static_cast<uint32>(temp[4*i + 1]) << 8 ) |
+ (static_cast<uint32>(temp[4*i ]) )
+ );
+ }
+
+
+ uint32 aa = a;
+ uint32 bb = b;
+ uint32 cc = c;
+ uint32 dd = d;
+
+
+ scramble_block(a,b,c,d,x);
+
+
+ a = a + aa;
+ b = b + bb;
+ c = c + cc;
+ d = d + dd;
+
+ }
+
+ if (!write_length)
+ {
+ uint64 temp = len*8;
+
+ uint32 aa = a;
+ uint32 bb = b;
+ uint32 cc = c;
+ uint32 dd = d;
+
+ std::memset(x, 0, sizeof(x));
+ x[15] = (temp>>32);
+ x[14] = (temp&0xFFFFFFFF);
+
+ scramble_block(a,b,c,d,x);
+
+
+ a = a + aa;
+ b = b + bb;
+ c = c + cc;
+ d = d + dd;
+
+ }
+
+
+ // put a, b, c, and d into output
+ output[0] = static_cast<unsigned char>((a) &0xFF);
+ output[1] = static_cast<unsigned char>((a>>8) &0xFF);
+ output[2] = static_cast<unsigned char>((a>>16)&0xFF);
+ output[3] = static_cast<unsigned char>((a>>24)&0xFF);
+
+ output[4] = static_cast<unsigned char>((b) &0xFF);
+ output[5] = static_cast<unsigned char>((b>>8) &0xFF);
+ output[6] = static_cast<unsigned char>((b>>16)&0xFF);
+ output[7] = static_cast<unsigned char>((b>>24)&0xFF);
+
+ output[8] = static_cast<unsigned char>((c) &0xFF);
+ output[9] = static_cast<unsigned char>((c>>8) &0xFF);
+ output[10] = static_cast<unsigned char>((c>>16)&0xFF);
+ output[11] = static_cast<unsigned char>((c>>24)&0xFF);
+
+ output[12] = static_cast<unsigned char>((d) &0xFF);
+ output[13] = static_cast<unsigned char>((d>>8) &0xFF);
+ output[14] = static_cast<unsigned char>((d>>16)&0xFF);
+ output[15] = static_cast<unsigned char>((d>>24)&0xFF);
+
+ input.clear(std::ios::eofbit);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_MD5_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/md5/md5_kernel_1.h b/ml/dlib/dlib/md5/md5_kernel_1.h
new file mode 100644
index 000000000..7031d21ef
--- /dev/null
+++ b/ml/dlib/dlib/md5/md5_kernel_1.h
@@ -0,0 +1,50 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MD5_KERNEl_1_
+#define DLIB_MD5_KERNEl_1_
+
+#include "md5_kernel_abstract.h"
+#include <string>
+#include <iosfwd>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string md5 (
+ const std::string& input
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ const unsigned char* input,
+ unsigned long len,
+ unsigned char* output
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string md5 (
+ std::istream& input
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ std::istream& input,
+ unsigned char* output
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "md5_kernel_1.cpp"
+#endif
+
+#endif // DLIB_MD5_KERNEl_1_
+
diff --git a/ml/dlib/dlib/md5/md5_kernel_abstract.h b/ml/dlib/dlib/md5/md5_kernel_abstract.h
new file mode 100644
index 000000000..b7d265d72
--- /dev/null
+++ b/ml/dlib/dlib/md5/md5_kernel_abstract.h
@@ -0,0 +1,83 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MD5_KERNEl_ABSTRACT_
+#ifdef DLIB_MD5_KERNEl_ABSTRACT_
+
+#include <string>
+#include <iosfwd>
+
+namespace dlib
+{
+
+ /*!
+ NOTE:
+ This is the RSA Data Security, Inc. MD5 Message-Digest Algorithm
+ as described in rfc1321
+
+ For the functions which return a unsigned char*. The array contains
+ the 16 bytes of the digest and are in the correct order.
+ i.e. output[0], output[1], output[2], ...
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string md5 (
+ const std::string& input
+ );
+ /*!
+ ensures
+ - returns the md5 digest of input as a hexadecimal string
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ const unsigned char* input,
+ unsigned long len,
+ unsigned char* output
+ );
+ /*!
+ requires
+ - input == pointer to len bytes
+ - output == pointer to 16 bytes
+ - input != output
+ ensures
+ - #output == the md5 digest of input.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string md5 (
+ std::istream& input
+ );
+ /*!
+ requires
+ - input.fail() == false
+ ensures
+ - returns the md5 digest of input as a hexadecimal string
+ - #input.eof() == true
+ - #input.fail() == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void md5 (
+ std::istream& input
+ unsigned char* output
+ );
+ /*!
+ requires
+ - input.fail() == false
+ - output == pointer to 16 bytes
+ ensures
+ - #output == the md5 digest of input
+ - #input.eof() == true
+ - #input.fail() == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MD5_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/member_function_pointer.h b/ml/dlib/dlib/member_function_pointer.h
new file mode 100644
index 000000000..3dd72f596
--- /dev/null
+++ b/ml/dlib/dlib/member_function_pointer.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMBER_FUNCTION_POINTEr_
+#define DLIB_MEMBER_FUNCTION_POINTEr_
+
+#include "member_function_pointer/member_function_pointer_kernel_1.h"
+#include "member_function_pointer/make_mfp.h"
+
+#endif // DLIB_MEMBER_FUNCTION_POINTEr_
+
diff --git a/ml/dlib/dlib/member_function_pointer/make_mfp.h b/ml/dlib/dlib/member_function_pointer/make_mfp.h
new file mode 100644
index 000000000..fff9b27ea
--- /dev/null
+++ b/ml/dlib/dlib/member_function_pointer/make_mfp.h
@@ -0,0 +1,179 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAKE_MFp_H_
+#define DLIB_MAKE_MFp_H_
+
+#include "member_function_pointer_kernel_1.h"
+#include "make_mfp_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ member_function_pointer<> make_mfp (
+ T& object,
+ void (T::*cb)()
+ )
+ {
+ member_function_pointer<> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+ template <
+ typename T
+ >
+ member_function_pointer<> make_mfp (
+ const T& object,
+ void (T::*cb)()const
+ )
+ {
+ member_function_pointer<> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1
+ >
+ member_function_pointer<A1> make_mfp (
+ T& object,
+ void (T::*cb)(A1)
+ )
+ {
+ member_function_pointer<A1> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+ template <
+ typename T,
+ typename A1
+ >
+ member_function_pointer<A1> make_mfp (
+ const T& object,
+ void (T::*cb)(A1)const
+ )
+ {
+ member_function_pointer<A1> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2
+ >
+ member_function_pointer<A1,A2> make_mfp (
+ T& object,
+ void (T::*cb)(A1,A2)
+ )
+ {
+ member_function_pointer<A1,A2> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+ template <
+ typename T,
+ typename A1,
+ typename A2
+ >
+ member_function_pointer<A1,A2> make_mfp (
+ const T& object,
+ void (T::*cb)(A1,A2)const
+ )
+ {
+ member_function_pointer<A1,A2> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3
+ >
+ member_function_pointer<A1,A2,A3> make_mfp (
+ T& object,
+ void (T::*cb)(A1,A2,A3)
+ )
+ {
+ member_function_pointer<A1,A2,A3> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3
+ >
+ member_function_pointer<A1,A2,A3> make_mfp (
+ const T& object,
+ void (T::*cb)(A1,A2,A3)const
+ )
+ {
+ member_function_pointer<A1,A2,A3> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3,
+ typename A4
+ >
+ member_function_pointer<A1,A2,A3,A4> make_mfp (
+ T& object,
+ void (T::*cb)(A1,A2,A3,A4)
+ )
+ {
+ member_function_pointer<A1,A2,A3,A4> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3,
+ typename A4
+ >
+ member_function_pointer<A1,A2,A3,A4> make_mfp (
+ const T& object,
+ void (T::*cb)(A1,A2,A3,A4)const
+ )
+ {
+ member_function_pointer<A1,A2,A3,A4> temp;
+ temp.set(object, cb);
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAKE_MFp_H_
+
+
+
diff --git a/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h b/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h
new file mode 100644
index 000000000..5074ca9a7
--- /dev/null
+++ b/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h
@@ -0,0 +1,207 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MAKE_MFp_ABSTRACT_
+#ifdef DLIB_MAKE_MFp_ABSTRACT_
+
+#include "member_function_pointer_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ member_function_pointer<> make_mfp (
+ T& object,
+ void (T::*cb)()
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP() will call (object.*cb)()
+ !*/
+
+ template <
+ typename T
+ >
+ member_function_pointer<> make_mfp (
+ const T& object,
+ void (T::*cb)()const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP() will call (object.*cb)()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1
+ >
+ member_function_pointer<A1> make_mfp (
+ T& object,
+ void (T::*cb)(A1 a1)
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1) will call (object.*cb)(a1)
+ !*/
+
+ template <
+ typename T,
+ typename A1
+ >
+ member_function_pointer<A1> make_mfp (
+ const T& object,
+ void (T::*cb)(A1 a1)const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1) will call (object.*cb)(a1)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2
+ >
+ member_function_pointer<A1,A2> make_mfp (
+ T& object,
+ void (T::*cb)(A1 a1, A2 a2)
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2) will call (object.*cb)(a1,a2)
+ !*/
+
+ template <
+ typename T,
+ typename A1,
+ typename A2
+ >
+ member_function_pointer<A1,A2> make_mfp (
+ const T& object,
+ void (T::*cb)(A1 a1, A2 a2)const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2) will call (object.*cb)(a1,a2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3
+ >
+ member_function_pointer<A1,A2,A3> make_mfp (
+ T& object,
+ void (T::*cb)(A1 a1, A2 a2, A3 a3)
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2,a3) will call (object.*cb)(a1,a2,a3)
+ !*/
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3
+ >
+ member_function_pointer<A1,A2,A3> make_mfp (
+ const T& object,
+ void (T::*cb)(A1 a1, A2 a2, A3 a3)const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2,a3) will call (object.*cb)(a1,a2,a3)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3,
+ typename A4
+ >
+ member_function_pointer<A1,A2,A3,A4> make_mfp (
+ T& object,
+ void (T::*cb)(A1 a1, A2 a2, A3 a3, A4 a4)
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2,a3,a4) will call (object.*cb)(a1,a2,a3,a4)
+ !*/
+
+ template <
+ typename T,
+ typename A1,
+ typename A2,
+ typename A3,
+ typename A4
+ >
+ member_function_pointer<A1,A2,A3,A4> make_mfp (
+ const T& object,
+ void (T::*cb)(A1 a1, A2 a2, A3 a3, A4 a4)const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - returns a member function pointer object MFP such that:
+ - MFP.is_set() == true
+ - calls to MFP(a1,a2,a3,a4) will call (object.*cb)(a1,a2,a3,a4)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAKE_MFp_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h
new file mode 100644
index 000000000..6cf5630b4
--- /dev/null
+++ b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h
@@ -0,0 +1,498 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_
+#define DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_
+
+#include "../algs.h"
+#include "member_function_pointer_kernel_abstract.h"
+#include "../enable_if.h"
+#include <new>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1 = void,
+ typename PARAM2 = void,
+ typename PARAM3 = void,
+ typename PARAM4 = void
+ >
+ class member_function_pointer;
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_MFP_SC DLIB_ASSERT(cb != 0, \
+ "\tvoid member_function_pointer::set" \
+ << "\n\tthe member function pointer can't be null" \
+ << "\n\tthis: " << this );
+
+
+#define DLIB_MFP_OC DLIB_ASSERT(this->is_set() == true , \
+ "\tvoid member_function_pointer::operator()" \
+ << "\n\tYou must call set() before you can use this function" \
+ << "\n\tthis: " << this);
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned long num_args>
+ class mfp_kernel_1_base_class
+ {
+ /*
+ All member function pointer classes inherit from this class. This
+ is where most of the things in a member function pointer are defined.
+
+ The reason for the num_args template argument to this class is to prevent
+ any sort of implicit casting between derived member function pointer classes
+ that take different numbers of arguments.
+ */
+ protected:
+ enum mfp_type { mfp_nonconst, mfp_const, mfp_null};
+
+ class mp_base_base
+ {
+ public:
+ mp_base_base(void* ptr, mfp_type type_) : o(ptr),type(type_) {}
+ virtual ~mp_base_base(){}
+ virtual void clone(void* ptr) const = 0;
+ virtual bool is_same (const mp_base_base* item) const = 0;
+ bool is_set () const { return o!=0; }
+
+ void* const o;
+ const mfp_type type;
+ };
+
+ template <typename T>
+ class mp_null : public mp_base_base
+ {
+ public:
+ typedef void (T::*mfp_pointer_type)() ;
+
+ mp_null (void* , mfp_pointer_type ) : mp_base_base(0,mfp_null), callback(0) {}
+ mp_null () : mp_base_base(0,mfp_null), callback(0) {}
+
+ const mfp_pointer_type callback;
+ };
+
+ template <typename mp_impl>
+ class mp_impl_T : public mp_impl
+ {
+ /*
+ This class supplies the implementations clone() and is_same() for any
+ classes that inherit from mp_base_base. It does this in a very
+ roundabout way...
+ */
+
+ public:
+ typedef typename mp_impl::mfp_pointer_type mfp_pointer_type;
+
+ mp_impl_T() : mp_impl(0,0) {}
+ mp_impl_T(void* ptr, mfp_pointer_type cb) : mp_impl(ptr,cb) {}
+
+ template <unsigned long mem_size>
+ void safe_clone(stack_based_memory_block<mem_size>& buf)
+ {
+ // This is here just to validate the assumption that our block of memory we have made
+ // in mp_memory is the right size to store the data for this object. If you
+ // get a compiler error on this line then email me :)
+ COMPILE_TIME_ASSERT(sizeof(mp_impl_T) <= mem_size);
+ clone(buf.get());
+ }
+
+ void clone (void* ptr) const { new(ptr) mp_impl_T(this->o,this->callback); }
+ bool is_same (const mp_base_base* item) const
+ {
+ if (item->o == 0 && this->o == 0)
+ {
+ return true;
+ }
+ else if (item->o == this->o && this->type == item->type)
+ {
+ const mp_impl* i = reinterpret_cast<const mp_impl*>(item);
+ return (i->callback == this->callback);
+ }
+ return false;
+ }
+ };
+
+ struct dummy_base { virtual void nonnull() {}; virtual ~dummy_base(){}; int a; };
+ struct dummy : virtual public dummy_base{ void nonnull() {}; };
+
+ typedef mp_impl_T<mp_null<dummy> > mp_null_impl;
+ public:
+
+ mfp_kernel_1_base_class (
+ const mfp_kernel_1_base_class& item
+ ) { item.mp()->clone(mp_memory.get()); }
+
+ mfp_kernel_1_base_class (
+ ) { mp_null_impl().safe_clone(mp_memory); }
+
+ bool operator == (
+ const mfp_kernel_1_base_class& item
+ ) const { return mp()->is_same(item.mp()); }
+
+ bool operator != (
+ const mfp_kernel_1_base_class& item
+ ) const { return !(*this == item); }
+
+ mfp_kernel_1_base_class& operator= (
+ const mfp_kernel_1_base_class& item
+ ) { mfp_kernel_1_base_class(item).swap(*this); return *this; }
+
+ ~mfp_kernel_1_base_class (
+ ) { destroy_mp_memory(); }
+
+ void clear(
+ ) { mfp_kernel_1_base_class().swap(*this); }
+
+ bool is_set (
+ ) const { return mp()->is_set(); }
+
+ private:
+ typedef void (dummy::*safe_bool)();
+
+ public:
+ operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; }
+ bool operator!() const { return !is_set(); }
+
+ void swap (
+ mfp_kernel_1_base_class& item
+ )
+ {
+ // make a temp copy of item
+ mfp_kernel_1_base_class temp(item);
+
+ // destory the stuff in item
+ item.destroy_mp_memory();
+ // copy *this into item
+ mp()->clone(item.mp_memory.get());
+
+ // destory the stuff in this
+ destroy_mp_memory();
+ // copy temp into *this
+ temp.mp()->clone(mp_memory.get());
+ }
+
+ protected:
+
+ // The reason for adding 1 here is because visual studio 2003 will sometimes
+ // try to compile this code with sizeof(mp_null_impl) == 0 (which is a bug in visual studio).
+ // Fortunately, no actual real instances of this template seem to end up with that screwed up
+ // value so everything works fine if we just add 1 so that this degenerate case doesn't cause
+ // trouble. Note that we know it all works fine because safe_clone() checks the size of this
+ // memory block whenever the member function pointer is used.
+ stack_based_memory_block<sizeof(mp_null_impl)+1> mp_memory;
+
+ void destroy_mp_memory (
+ )
+ {
+ // Honestly this probably doesn't even do anything but I'm putting
+ // it here just for good measure.
+ mp()->~mp_base_base();
+ }
+
+ mp_base_base* mp () { return static_cast<mp_base_base*>(mp_memory.get()); }
+ const mp_base_base* mp () const { return static_cast<const mp_base_base*>(mp_memory.get()); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <>
+ class member_function_pointer<void,void,void,void> : public mfp_kernel_1_base_class<0>
+ {
+ class mp_base : public mp_base_base {
+ public:
+ mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {}
+ virtual void call() const = 0;
+ };
+
+ template <typename T>
+ class mp_impl : public mp_base {
+ public:
+ typedef void (T::*mfp_pointer_type)() ;
+ void call () const { (static_cast<T*>(this->o)->*callback)(); }
+
+ mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ template <typename T>
+ class mp_impl_const : public mp_base {
+ public:
+ typedef void ((T::*mfp_pointer_type)()const);
+ void call () const { (static_cast<const T*>(this->o)->*callback)(); }
+
+ mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ public:
+ typedef void param1_type;
+ typedef void param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ // These two typedefs are here for backwards compatibility with previous versions
+ // of dlib.
+ typedef member_function_pointer kernel_1a;
+ typedef member_function_pointer kernel_1a_c;
+
+
+ void operator() () const { DLIB_MFP_OC; static_cast<const mp_base*>(mp_memory.get())->call(); }
+
+ // the reason for putting disable_if on this function is that it avoids an overload
+ // resolution bug in visual studio.
+ template <typename T> typename disable_if<is_const_type<T>,void>::type
+ set(T& object, typename mp_impl<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl<T> >(&object,cb).safe_clone(mp_memory); }
+
+ template <typename T> void set(const T& object, typename mp_impl_const<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl_const<T> >((void*)&object,cb).safe_clone(mp_memory); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1
+ >
+ class member_function_pointer<PARAM1,void,void,void> : public mfp_kernel_1_base_class<1>
+ {
+ class mp_base : public mp_base_base {
+ public:
+ mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {}
+ virtual void call(PARAM1) const = 0;
+ };
+
+ template <typename T>
+ class mp_impl : public mp_base {
+ public:
+ typedef void (T::*mfp_pointer_type)(PARAM1) ;
+ void call (PARAM1 p1) const { (static_cast<T*>(this->o)->*callback)(p1); }
+
+ mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ template <typename T>
+ class mp_impl_const : public mp_base {
+ public:
+ typedef void ((T::*mfp_pointer_type)(PARAM1)const);
+ void call (PARAM1 p1) const { (static_cast<const T*>(this->o)->*callback)(p1); }
+
+ mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ public:
+ typedef PARAM1 param1_type;
+ typedef void param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ // These two typedefs are here for backwards compatibility with previous versions
+ // of dlib.
+ typedef member_function_pointer kernel_1a;
+ typedef member_function_pointer kernel_1a_c;
+
+
+ void operator() (PARAM1 p1) const { DLIB_MFP_OC; static_cast<const mp_base*>(mp_memory.get())->call(p1); }
+
+ // the reason for putting disable_if on this function is that it avoids an overload
+ // resolution bug in visual studio.
+ template <typename T> typename disable_if<is_const_type<T>,void>::type
+ set(T& object, typename mp_impl<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl<T> >(&object,cb).safe_clone(mp_memory); }
+
+ template <typename T> void set(const T& object, typename mp_impl_const<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl_const<T> >((void*)&object,cb).safe_clone(mp_memory); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2
+ >
+ class member_function_pointer<PARAM1,PARAM2,void,void> : public mfp_kernel_1_base_class<2>
+ {
+ class mp_base : public mp_base_base {
+ public:
+ mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {}
+ virtual void call(PARAM1,PARAM2) const = 0;
+ };
+
+ template <typename T>
+ class mp_impl : public mp_base {
+ public:
+ typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2) ;
+ void call (PARAM1 p1, PARAM2 p2) const { (static_cast<T*>(this->o)->*callback)(p1,p2); }
+
+ mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ template <typename T>
+ class mp_impl_const : public mp_base {
+ public:
+ typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2)const);
+ void call (PARAM1 p1, PARAM2 p2) const { (static_cast<const T*>(this->o)->*callback)(p1,p2); }
+
+ mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ // These two typedefs are here for backwards compatibility with previous versions
+ // of dlib.
+ typedef member_function_pointer kernel_1a;
+ typedef member_function_pointer kernel_1a_c;
+
+ void operator() (PARAM1 p1, PARAM2 p2) const { DLIB_MFP_OC; static_cast<const mp_base*>(mp_memory.get())->call(p1,p2); }
+
+ // the reason for putting disable_if on this function is that it avoids an overload
+ // resolution bug in visual studio.
+ template <typename T> typename disable_if<is_const_type<T>,void>::type
+ set(T& object, typename mp_impl<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl<T> >(&object,cb).safe_clone(mp_memory); }
+
+ template <typename T> void set(const T& object, typename mp_impl_const<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl_const<T> >((void*)&object,cb).safe_clone(mp_memory); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2,
+ typename PARAM3
+ >
+ class member_function_pointer<PARAM1,PARAM2,PARAM3,void> : public mfp_kernel_1_base_class<3>
+ {
+ class mp_base : public mp_base_base {
+ public:
+ mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {}
+ virtual void call(PARAM1,PARAM2,PARAM3) const = 0;
+ };
+
+ template <typename T>
+ class mp_impl : public mp_base {
+ public:
+ typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3) ;
+ void call (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { (static_cast<T*>(this->o)->*callback)(p1,p2,p3); }
+
+ mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ template <typename T>
+ class mp_impl_const : public mp_base {
+ public:
+ typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3)const);
+ void call (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { (static_cast<const T*>(this->o)->*callback)(p1,p2,p3); }
+
+ mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef PARAM3 param3_type;
+ typedef void param4_type;
+
+ // These two typedefs are here for backwards compatibility with previous versions
+ // of dlib.
+ typedef member_function_pointer kernel_1a;
+ typedef member_function_pointer kernel_1a_c;
+
+ void operator() (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { DLIB_MFP_OC; static_cast<const mp_base*>(mp_memory.get())->call(p1,p2,p3); }
+
+ // the reason for putting disable_if on this function is that it avoids an overload
+ // resolution bug in visual studio.
+ template <typename T> typename disable_if<is_const_type<T>,void>::type
+ set(T& object, typename mp_impl<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl<T> >(&object,cb).safe_clone(mp_memory); }
+
+ template <typename T> void set(const T& object, typename mp_impl_const<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl_const<T> >((void*)&object,cb).safe_clone(mp_memory); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2,
+ typename PARAM3,
+ typename PARAM4
+ >
+ class member_function_pointer : public mfp_kernel_1_base_class<4>
+ {
+ class mp_base : public mp_base_base {
+ public:
+ mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {}
+ virtual void call(PARAM1,PARAM2,PARAM3,PARAM4) const = 0;
+ };
+
+ template <typename T>
+ class mp_impl : public mp_base {
+ public:
+ typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3, PARAM4) ;
+ void call (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const { (static_cast<T*>(this->o)->*callback)(p1,p2,p3,p4); }
+
+ mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ template <typename T>
+ class mp_impl_const : public mp_base {
+ public:
+ typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3,PARAM4)const);
+ void call (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const { (static_cast<const T*>(this->o)->*callback)(p1,p2,p3,p4); }
+
+ mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {}
+ const mfp_pointer_type callback;
+ };
+
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef PARAM3 param3_type;
+ typedef PARAM4 param4_type;
+
+ // These two typedefs are here for backwards compatibility with previous versions
+ // of dlib.
+ typedef member_function_pointer kernel_1a;
+ typedef member_function_pointer kernel_1a_c;
+
+ void operator() (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const
+ { DLIB_MFP_OC; static_cast<const mp_base*>(mp_memory.get())->call(p1,p2,p3,p4); }
+
+ // the reason for putting disable_if on this function is that it avoids an overload
+ // resolution bug in visual studio.
+ template <typename T> typename disable_if<is_const_type<T>,void>::type
+ set(T& object, typename mp_impl<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl<T> >(&object,cb).safe_clone(mp_memory); }
+
+ template <typename T> void set(const T& object, typename mp_impl_const<T>::mfp_pointer_type cb)
+ { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T<mp_impl_const<T> >((void*)&object,cb).safe_clone(mp_memory); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h
new file mode 100644
index 000000000..e152972ee
--- /dev/null
+++ b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h
@@ -0,0 +1,483 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_
+#ifdef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1 = void,
+ typename PARAM2 = void,
+ typename PARAM3 = void,
+ typename PARAM4 = void
+ >
+ class member_function_pointer;
+
+// ----------------------------------------------------------------------------------------
+
+ template <>
+ class member_function_pointer<void,void,void,void>
+ {
+ /*!
+ INITIAL VALUE
+ is_set() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a member function pointer. It is useful because
+ instances of this object can be created without needing to know the type
+ of object whose member function we will be calling.
+
+ There are five template specializations of this object. The first
+ represents a pointer to a member function taking no parameters, the
+ second represents a pointer to a member function taking one parameter,
+ the third to one taking two parameters, and so on.
+
+ You specify the parameters to your member function pointer by filling in
+ the PARAM template parameters. For example:
+
+ To use a pointer to a function with no parameters you would say:
+ member_function_pointer<> my_pointer;
+ To use a pointer to a function that takes a single int you would say:
+ member_function_pointer<int> my_pointer;
+ To use a pointer to a function that takes an int and then a reference
+ to a string you would say:
+ member_function_pointer<int,string&> my_pointer;
+
+ Also note that the formal comments are only present for the first
+ template specialization. They are all exactly the same except for the
+ number of parameters each takes in its member function pointer.
+ !*/
+
+ public:
+ typedef void param1_type;
+ typedef void param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ member_function_pointer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ !*/
+
+ member_function_pointer(
+ const member_function_pointer& item
+ );
+ /*!
+ ensures
+ - *this == item
+ !*/
+
+ ~member_function_pointer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ member_function_pointer& operator=(
+ const member_function_pointer& item
+ );
+ /*!
+ ensures
+ - *this == item
+ !*/
+
+ bool operator == (
+ const member_function_pointer& item
+ ) const;
+ /*!
+ ensures
+ - if (is_set() == false && item.is_set() == false) then
+ - returns true
+ - else if (both *this and item point to the same member function
+ in the same object instance) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool operator != (
+ const member_function_pointer& item
+ ) const;
+ /*!
+ ensures
+ - returns !(*this == item)
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ !*/
+
+ bool is_set (
+ ) const;
+ /*!
+ ensures
+ - if (this->set() has been called) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template <
+ typename T
+ >
+ void set (
+ T& object,
+ void (T::*cb)()
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*cb)()
+ !*/
+
+ template <
+ typename T
+ >
+ void set (
+ const T& object,
+ void (T::*cb)()const
+ );
+ /*!
+ requires
+ - cb == a valid member function pointer for class T
+ ensures
+ - #is_set() == true
+ - calls to this->operator() will call (object.*cb)()
+ !*/
+
+ operator some_undefined_pointer_type (
+ ) const;
+ /*!
+ ensures
+ - if (is_set()) then
+ - returns a non 0 value
+ - else
+ - returns a 0 value
+ !*/
+
+ bool operator! (
+ ) const;
+ /*!
+ ensures
+ - returns !is_set()
+ !*/
+
+ void operator () (
+ ) const;
+ /*!
+ requires
+ - is_set() == true
+ ensures
+ - calls the member function on the object specified by the last
+ call to this->set()
+ throws
+ - any exception thrown by the member function specified by
+ the previous call to this->set().
+ If any of these exceptions are thrown then the call to this
+ function will have no effect on *this.
+ !*/
+
+ void swap (
+ member_function_pointer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1
+ >
+ class member_function_pointer<PARAM1,void,void,void>
+ {
+ public:
+ typedef PARAM1 param1_type;
+ typedef void param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ member_function_pointer ();
+
+ member_function_pointer(
+ const member_function_pointer& item
+ );
+
+ ~member_function_pointer (
+ );
+
+ member_function_pointer& operator=(
+ const member_function_pointer& item
+ );
+
+ bool operator == (
+ const member_function_pointer& item
+ ) const;
+
+ bool operator != (
+ const member_function_pointer& item
+ ) const;
+
+ void clear();
+
+ bool is_set () const;
+
+ template <typename T>
+ void set (
+ T& object,
+ void (T::*cb)(PARAM1)
+ );
+
+ template <typename T>
+ void set (
+ const T& object,
+ void (T::*cb)(PARAM1)const
+ );
+
+ operator some_undefined_pointer_type (
+ ) const;
+
+ bool operator! (
+ ) const;
+
+ void operator () (
+ PARAM1 param1
+ ) const;
+
+ void swap (
+ member_function_pointer& item
+ );
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2
+ >
+ class member_function_pointer<PARAM1,PARAM2,void,void>
+ {
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef void param3_type;
+ typedef void param4_type;
+
+ member_function_pointer ();
+
+ member_function_pointer(
+ const member_function_pointer& item
+ );
+
+ ~member_function_pointer (
+ );
+
+ member_function_pointer& operator=(
+ const member_function_pointer& item
+ );
+
+ bool operator == (
+ const member_function_pointer& item
+ ) const;
+
+ bool operator != (
+ const member_function_pointer& item
+ ) const;
+
+ void clear();
+
+ bool is_set () const;
+
+ template <typename T>
+ void set (
+ T& object,
+ void (T::*cb)(PARAM1,PARAM2)
+ );
+
+ template <typename T>
+ void set (
+ const T& object,
+ void (T::*cb)(PARAM1,PARAM2)const
+ );
+
+ operator some_undefined_pointer_type (
+ ) const;
+
+ bool operator! (
+ ) const;
+
+ void operator () (
+ PARAM1 param1,
+ PARAM2 param2
+ ) const;
+
+ void swap (
+ member_function_pointer& item
+ );
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2,
+ typename PARAM3
+ >
+ class member_function_pointer<PARAM1,PARAM2,PARAM3,void>
+ {
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef PARAM3 param3_type;
+ typedef void param4_type;
+
+ member_function_pointer ();
+
+ member_function_pointer(
+ const member_function_pointer& item
+ );
+
+ ~member_function_pointer (
+ );
+
+ member_function_pointer& operator=(
+ const member_function_pointer& item
+ );
+
+ bool operator == (
+ const member_function_pointer& item
+ ) const;
+
+ bool operator != (
+ const member_function_pointer& item
+ ) const;
+
+ void clear();
+
+ bool is_set () const;
+
+ template <typename T>
+ void set (
+ T& object,
+ void (T::*cb)(PARAM1,PARAM2,PARAM3)
+ );
+
+ template <typename T>
+ void set (
+ const T& object,
+ void (T::*cb)(PARAM1,PARAM2,PARAM3)const
+ );
+
+ operator some_undefined_pointer_type (
+ ) const;
+
+ bool operator! (
+ ) const;
+
+ void operator () (
+ PARAM1 param1,
+ PARAM2 param2,
+ PARAM2 param3
+ ) const;
+
+ void swap (
+ member_function_pointer& item
+ );
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename PARAM1,
+ typename PARAM2,
+ typename PARAM3,
+ typename PARAM4
+ >
+ class member_function_pointer
+ {
+ public:
+ typedef PARAM1 param1_type;
+ typedef PARAM2 param2_type;
+ typedef PARAM3 param3_type;
+ typedef PARAM4 param4_type;
+
+ member_function_pointer ();
+
+ member_function_pointer(
+ const member_function_pointer& item
+ );
+
+ ~member_function_pointer (
+ );
+
+ member_function_pointer& operator=(
+ const member_function_pointer& item
+ );
+
+ bool operator == (
+ const member_function_pointer& item
+ ) const;
+
+ bool operator != (
+ const member_function_pointer& item
+ ) const;
+
+ void clear();
+
+ bool is_set () const;
+
+ template <typename T>
+ void set (
+ T& object,
+ void (T::*cb)(PARAM1,PARAM2,PARAM3,PARAM4)
+ );
+
+ template <typename T>
+ void set (
+ const T& object,
+ void (T::*cb)(PARAM1,PARAM2,PARAM3,PARAM4)const
+ );
+
+ operator some_undefined_pointer_type (
+ ) const;
+
+ bool operator! (
+ ) const;
+
+ void operator () (
+ PARAM1 param1,
+ PARAM2 param2,
+ PARAM2 param3,
+ PARAM2 param4
+ ) const;
+
+ void swap (
+ member_function_pointer& item
+ );
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/memory_manager.h b/ml/dlib/dlib/memory_manager.h
new file mode 100644
index 000000000..5b5283255
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager.h
@@ -0,0 +1,73 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGEr_
+#define DLIB_MEMORY_MANAGEr_
+
+#include "memory_manager/memory_manager_kernel_1.h"
+#include "memory_manager/memory_manager_kernel_2.h"
+#include "memory_manager/memory_manager_kernel_3.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class memory_manager
+ {
+ memory_manager() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1
+ typedef memory_manager_kernel_1<T,0>
+ kernel_1a;
+ typedef memory_manager_kernel_1<T,10>
+ kernel_1b;
+ typedef memory_manager_kernel_1<T,100>
+ kernel_1c;
+ typedef memory_manager_kernel_1<T,1000>
+ kernel_1d;
+ typedef memory_manager_kernel_1<T,10000>
+ kernel_1e;
+ typedef memory_manager_kernel_1<T,100000>
+ kernel_1f;
+
+ // kernel_2
+ typedef memory_manager_kernel_2<T,10>
+ kernel_2a;
+ typedef memory_manager_kernel_2<T,100>
+ kernel_2b;
+ typedef memory_manager_kernel_2<T,1000>
+ kernel_2c;
+ typedef memory_manager_kernel_2<T,10000>
+ kernel_2d;
+ typedef memory_manager_kernel_2<T,100000>
+ kernel_2e;
+
+
+ // kernel_3
+ typedef memory_manager_kernel_3<T,10>
+ kernel_3a;
+ typedef memory_manager_kernel_3<T,100>
+ kernel_3b;
+ typedef memory_manager_kernel_3<T,1000>
+ kernel_3c;
+ typedef memory_manager_kernel_3<T,10000>
+ kernel_3d;
+ typedef memory_manager_kernel_3<T,100000>
+ kernel_3e;
+
+
+
+
+ };
+}
+
+#endif // DLIB_MEMORY_MANAGEr_
+
diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h
new file mode 100644
index 000000000..557a19fca
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h
@@ -0,0 +1,305 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_KERNEl_1_
+#define DLIB_MEMORY_MANAGER_KERNEl_1_
+
+#include "../algs.h"
+#include "memory_manager_kernel_abstract.h"
+#include "../assert.h"
+#include <new>
+
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long max_pool_size
+ >
+ class memory_manager_kernel_1
+ {
+ /*!
+ INITIAL VALUE
+ allocations == 0
+ next == 0
+ pool_size == 0
+
+ REQUIREMENTS ON max_pool_size
+ max_pool_size is the maximum number of nodes we will keep in our linked list at once.
+ So you can put any value in for this argument.
+
+ CONVENTION
+ This memory manager implementation allocates T objects one at a time when there are
+ allocation requests. Then when there is a deallocate request the returning T object
+ is place into a list of free blocks if that list has less than max_pool_size
+ blocks in it. subsequent allocation requests will be serviced by drawing from the
+ free list whenever it isn't empty.
+
+
+ allocations == get_number_of_allocations()
+
+ - if (next != 0) then
+ - next == the next pointer to return from allocate()
+ and next == pointer to the first node in a linked list. each node
+ is one item in the memory pool.
+ - the last node in the linked list has next set to 0
+ - pool_size == the number of nodes in the linked list
+ - pool_size <= max_pool_size
+ - else
+ - we need to call new to get the next pointer to return from allocate()
+
+ !*/
+
+ union node
+ {
+ node* next;
+ char item[sizeof(T)];
+ };
+
+ public:
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_kernel_1<U,max_pool_size> other;
+ };
+
+
+ memory_manager_kernel_1(
+ ) :
+ allocations(0),
+ next(0),
+ pool_size(0)
+ {
+ }
+
+ virtual ~memory_manager_kernel_1(
+ )
+ {
+
+ while (next != 0)
+ {
+ node* temp = next;
+ next = next->next;
+ ::operator delete ( static_cast<void*>(temp));
+ }
+ }
+
+ unsigned long get_number_of_allocations (
+ ) const { return allocations; }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ T* temp = new T[size];
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ --allocations;
+ delete [] item;
+ }
+
+ T* allocate (
+ )
+ {
+ T* temp;
+ if (next != 0)
+ {
+ temp = reinterpret_cast<T*>(next);
+
+ node* n = next->next;
+
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ next->next = n;
+ throw;
+ }
+
+ next = n;
+
+ --pool_size;
+ }
+ else
+ {
+ temp = static_cast<T*>(::operator new(sizeof(node)));
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ // construction of the new object threw so delete the block of memory
+ ::operator delete ( static_cast<void*>(temp));
+ throw;
+ }
+ }
+
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ --allocations;
+ item->~T();
+
+ if (pool_size >= max_pool_size)
+ {
+ ::operator delete ( static_cast<void*>(item));
+ return;
+ }
+
+ // add this memory chunk into our linked list.
+ node* temp = reinterpret_cast<node*>(item);
+ temp->next = next;
+ next = temp;
+ ++pool_size;
+ }
+
+ void swap (
+ memory_manager_kernel_1& item
+ )
+ {
+ exchange(allocations,item.allocations);
+ exchange(next,item.next);
+ exchange(pool_size,item.pool_size);
+ }
+
+ private:
+
+ // data members
+ unsigned long allocations;
+ node* next;
+ unsigned long pool_size;
+
+ // restricted functions
+ memory_manager_kernel_1(memory_manager_kernel_1&); // copy constructor
+ memory_manager_kernel_1& operator=(memory_manager_kernel_1&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class memory_manager_kernel_1<T,0>
+ {
+ /*!
+ INITIAL VALUE
+ allocations == 0
+
+ CONVENTION
+ This memory manager just calls new and delete directly so it doesn't
+ really do anything.
+
+ allocations == get_number_of_allocations()
+ !*/
+
+ public:
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_kernel_1<U,0> other;
+ };
+
+
+ memory_manager_kernel_1(
+ ) :
+ allocations(0)
+ {
+ }
+
+ virtual ~memory_manager_kernel_1(
+ )
+ {
+ }
+
+ unsigned long get_number_of_allocations (
+ ) const { return allocations; }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ T* temp = new T[size];
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ --allocations;
+ delete [] item;
+ }
+
+ T* allocate (
+ )
+ {
+ T* temp = new T;
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ delete item;
+ --allocations;
+ }
+
+ void swap (
+ memory_manager_kernel_1& item
+ )
+ {
+ exchange(allocations,item.allocations);
+ }
+
+ private:
+
+ // data members
+ unsigned long allocations;
+
+ // restricted functions
+ memory_manager_kernel_1(memory_manager_kernel_1&); // copy constructor
+ memory_manager_kernel_1& operator=(memory_manager_kernel_1&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long max_pool_size
+ >
+ inline void swap (
+ memory_manager_kernel_1<T,max_pool_size>& a,
+ memory_manager_kernel_1<T,max_pool_size>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_KERNEl_1_
+
+
+
diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h
new file mode 100644
index 000000000..8f026bc49
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h
@@ -0,0 +1,253 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_KERNEl_2_
+#define DLIB_MEMORY_MANAGER_KERNEl_2_
+
+#include "../algs.h"
+#include "memory_manager_kernel_abstract.h"
+#include "../assert.h"
+#include <new>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long chunk_size
+ >
+ class memory_manager_kernel_2
+ {
+ /*!
+ INITIAL VALUE
+ allocations == 0
+ next == 0
+ first_chunk == 0
+
+ REQUIREMENTS ON chunk_size
+ chunk_size is the number of items of type T we will allocate at a time. so
+ it must be > 0.
+
+ CONVENTION
+ This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T)
+ bytes. All the sizeof(T) subblocks are kept in a linked list of free memory blocks
+ and are given out whenever an allocation request occurs. Also, memory is not freed
+ until this object is destructed.
+
+ Note that array allocations are not memory managed.
+
+
+
+ allocations == get_number_of_allocations()
+
+ - if (next != 0) then
+ - next == the next pointer to return from allocate()
+ and next == pointer to the first node in a linked list. each node
+ is one item in the memory pool.
+ - the last node in the linked list has next set to 0
+ - else
+ - we need to call new to get the next pointer to return from allocate()
+
+
+ - if (first_chunk != 0) then
+ - first_chunk == the first node in a linked list that contains pointers
+ to all the chunks we have ever allocated. The last link in the list
+ has its next pointer set to 0.
+ !*/
+
+ union node
+ {
+ node* next;
+ char item[sizeof(T)];
+ };
+
+ struct chunk_node
+ {
+ node* chunk;
+ chunk_node* next;
+ };
+
+ public:
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_kernel_2<U,chunk_size> other;
+ };
+
+
+ memory_manager_kernel_2(
+ ) :
+ allocations(0),
+ next(0),
+ first_chunk(0)
+ {
+ // You FOOL! You can't have a zero chunk_size.
+ COMPILE_TIME_ASSERT(chunk_size > 0);
+ }
+
+ virtual ~memory_manager_kernel_2(
+ )
+ {
+ if (allocations == 0)
+ {
+ while (first_chunk != 0)
+ {
+ chunk_node* temp = first_chunk;
+ first_chunk = first_chunk->next;
+ // delete the memory chunk
+ ::operator delete ( static_cast<void*>(temp->chunk));
+ // delete the chunk_node
+ delete temp;
+ }
+ }
+ }
+
+ unsigned long get_number_of_allocations (
+ ) const { return allocations; }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ T* temp = new T[size];
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ --allocations;
+ delete [] item;
+ }
+
+ T* allocate (
+ )
+ {
+ T* temp = 0;
+ if (next != 0)
+ {
+ temp = reinterpret_cast<T*>(next);
+ node* n = next->next;
+
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ next->next = n;
+ throw;
+ }
+
+ next = n;
+ }
+ else
+ {
+ // the linked list is empty so we need to allocate some more memory
+ node* block = 0;
+ block = static_cast<node*>(::operator new (sizeof(node)*chunk_size));
+
+ // the first part of this block can be our new object
+ temp = reinterpret_cast<T*>(block);
+
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ // construction of the new object threw so delete the block of memory
+ ::operator delete ( static_cast<void*>(block));
+ throw;
+ }
+
+ // allocate a new chunk_node
+ chunk_node* chunk;
+ try {chunk = new chunk_node; }
+ catch (...)
+ {
+ temp->~T();
+ ::operator delete ( static_cast<void*>(block));
+ throw;
+ }
+
+ // add this block into the chunk list
+ chunk->chunk = block;
+ chunk->next = first_chunk;
+ first_chunk = chunk;
+
+
+ ++block;
+ // now add the rest of the block into the linked list of free nodes.
+ for (unsigned long i = 0; i < chunk_size-1; ++i)
+ {
+ block->next = next;
+ next = block;
+ ++block;
+ }
+
+ }
+
+
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ --allocations;
+ item->~T();
+
+ // add this memory into our linked list.
+ node* temp = reinterpret_cast<node*>(item);
+ temp->next = next;
+ next = temp;
+ }
+
+ void swap (
+ memory_manager_kernel_2& item
+ )
+ {
+ exchange(allocations,item.allocations);
+ exchange(next,item.next);
+ exchange(first_chunk,item.first_chunk);
+ }
+
+ private:
+
+ // data members
+ unsigned long allocations;
+ node* next;
+
+ chunk_node* first_chunk;
+
+
+
+
+ // restricted functions
+ memory_manager_kernel_2(memory_manager_kernel_2&); // copy constructor
+ memory_manager_kernel_2& operator=(memory_manager_kernel_2&); // assignment operator
+ };
+
+ template <
+ typename T,
+ unsigned long chunk_size
+ >
+ inline void swap (
+ memory_manager_kernel_2<T,chunk_size>& a,
+ memory_manager_kernel_2<T,chunk_size>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_KERNEl_2_
+
diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h
new file mode 100644
index 000000000..1f9229772
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h
@@ -0,0 +1,385 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_KERNEl_3_
+#define DLIB_MEMORY_MANAGER_KERNEl_3_
+
+#include "../algs.h"
+#include "memory_manager_kernel_abstract.h"
+#include "../assert.h"
+#include <new>
+#include "memory_manager_kernel_2.h"
+#include "../binary_search_tree/binary_search_tree_kernel_2.h"
+
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long chunk_size
+ >
+ class memory_manager_kernel_3
+ {
+ /*!
+ INITIAL VALUE
+ allocations == 0
+ next == 0
+ first_chunk == 0
+ bst_of_arrays == 0
+
+ REQUIREMENTS ON chunk_size
+ chunk_size is the number of items of type T we will allocate at a time. so
+ it must be > 0.
+
+ CONVENTION
+ This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T)
+ bytes. All the sizeof(T) subblocks are kept in a linked list of free memory blocks
+ and are given out whenever an allocation request occurs. Also, memory is not freed
+ until this object is destructed.
+
+
+
+ allocations == get_number_of_allocations()
+
+ - if (next != 0) then
+ - next == the next pointer to return from allocate()
+ and next == pointer to the first node in a linked list. each node
+ is one item in the memory pool.
+ - the last node in the linked list has next set to 0
+ - else
+ - we need to call new to get the next pointer to return from allocate()
+
+ - if (arrays != 0) then
+ - someone has called allocate_array()
+ - (*arrays)[size] == an array of size bytes of memory
+
+ - if (first_chunk != 0) then
+ - first_chunk == the first node in a linked list that contains pointers
+ to all the chunks we have ever allocated. The last link in the list
+ has its next pointer set to 0.
+ !*/
+
+ union node
+ {
+ node* next;
+ char item[sizeof(T)];
+ };
+
+ struct chunk_node
+ {
+ node* chunk;
+ chunk_node* next;
+ };
+
+
+ typedef binary_search_tree_kernel_2<
+ size_t,
+ char*,
+ memory_manager_kernel_2<char,5>
+ > bst_of_arrays;
+
+ public:
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_kernel_3<U,chunk_size> other;
+ };
+
+
+ memory_manager_kernel_3(
+ ) :
+ allocations(0),
+ next(0),
+ first_chunk(0),
+ arrays(0)
+ {
+ // You FOOL! You can't have a zero chunk_size.
+ COMPILE_TIME_ASSERT(chunk_size > 0);
+ }
+
+ virtual ~memory_manager_kernel_3(
+ )
+ {
+ if (allocations == 0)
+ {
+ while (first_chunk != 0)
+ {
+ chunk_node* temp = first_chunk;
+ first_chunk = first_chunk->next;
+ // delete the memory chunk
+ ::operator delete ( static_cast<void*>(temp->chunk));
+ // delete the chunk_node
+ delete temp;
+ }
+ }
+
+ if (arrays)
+ {
+ arrays->reset();
+ while (arrays->move_next())
+ {
+ ::operator delete (arrays->element().value());
+ }
+ delete arrays;
+ }
+ }
+
+ unsigned long get_number_of_allocations (
+ ) const { return allocations; }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ size_t block_size = sizeof(T)*size + sizeof(size_t)*2;
+
+ // make sure we have initialized the arrays object.
+ if (arrays == 0)
+ {
+ arrays = new bst_of_arrays;
+ }
+
+ char* temp;
+
+ // see if we have a suitable block of memory already.
+ arrays->position_enumerator(block_size);
+ if (arrays->current_element_valid())
+ {
+ // we have a suitable block of memory already so use that one.
+ arrays->remove_current_element(block_size,temp);
+ }
+ else
+ {
+ temp = static_cast<char*>(::operator new(block_size));
+ }
+
+ reinterpret_cast<size_t*>(temp)[0] = block_size;
+ reinterpret_cast<size_t*>(temp)[1] = size;
+ temp += sizeof(size_t)*2;
+
+ try
+ {
+ initialize_array(reinterpret_cast<T*>(temp),size);
+ }
+ catch (...)
+ {
+ // something was thrown while we were initializing the array so
+ // stick our memory block into arrays and rethrow the exception
+ temp -= sizeof(size_t)*2;
+ arrays->add(block_size,temp);
+ throw;
+ }
+
+ ++allocations;
+ return reinterpret_cast<T*>(temp);
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ char* temp = reinterpret_cast<char*>(item);
+ temp -= sizeof(size_t)*2;
+ size_t block_size = reinterpret_cast<size_t*>(temp)[0];
+ size_t size = reinterpret_cast<size_t*>(temp)[1];
+
+ deinitialize_array(item,size);
+
+ arrays->add(block_size,temp);
+
+ --allocations;
+ }
+
+ T* allocate (
+ )
+ {
+ T* temp;
+ if (next != 0)
+ {
+ temp = reinterpret_cast<T*>(next);
+ node* n = next->next;
+
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ next->next = n;
+ throw;
+ }
+
+ next = n;
+ }
+ else
+ {
+ // the linked list is empty so we need to allocate some more memory
+ node* block = static_cast<node*>(::operator new (sizeof(node)*chunk_size));
+
+ // the first part of this block can be our new object
+ temp = reinterpret_cast<T*>(block);
+
+ try
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(temp))T();
+ }
+ catch (...)
+ {
+ // construction of the new object threw so delete the block of memory
+ ::operator delete ( static_cast<void*>(block));
+ throw;
+ }
+
+ // allocate a new chunk_node
+ chunk_node* chunk;
+ try {chunk = new chunk_node; }
+ catch (...)
+ {
+ temp->~T();
+ ::operator delete ( static_cast<void*>(block));
+ throw;
+ }
+
+ // add this block into the chunk list
+ chunk->chunk = block;
+ chunk->next = first_chunk;
+ first_chunk = chunk;
+
+
+ ++block;
+ // now add the rest of the block into the linked list of free nodes.
+ for (unsigned long i = 0; i < chunk_size-1; ++i)
+ {
+ block->next = next;
+ next = block;
+ ++block;
+ }
+
+ }
+
+
+ ++allocations;
+ return temp;
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ --allocations;
+ item->~T();
+
+ // add this memory into our linked list.
+ node* temp = reinterpret_cast<node*>(item);
+ temp->next = next;
+ next = temp;
+ }
+
+ void swap (
+ memory_manager_kernel_3& item
+ )
+ {
+ exchange(allocations,item.allocations);
+ exchange(next,item.next);
+ exchange(first_chunk,item.first_chunk);
+ exchange(arrays,item.arrays);
+ }
+
+ private:
+
+ // data members
+ unsigned long allocations;
+ node* next;
+
+ chunk_node* first_chunk;
+ bst_of_arrays* arrays;
+
+
+ void initialize_array (
+ T* array,
+ size_t size
+ ) const
+ {
+ size_t i;
+ try
+ {
+ for (i = 0; i < size; ++i)
+ {
+ // construct this new T object with placement new.
+ new (static_cast<void*>(array+i))T();
+ }
+ }
+ catch (...)
+ {
+ // Catch any exceptions thrown during the construction process
+ // and then destruct any T objects that actually were successfully
+ // constructed.
+ for (size_t j = 0; j < i; ++j)
+ {
+ array[i].~T();
+ }
+ throw;
+ }
+ }
+
+ void deinitialize_array (
+ T* array,
+ size_t size
+ ) const
+ {
+ for (size_t i = 0; i < size; ++i)
+ {
+ array[i].~T();
+ }
+ }
+
+ // don't do any initialization for the built in types
+ void initialize_array(unsigned char*, size_t) {}
+ void deinitialize_array(unsigned char*, size_t) {}
+ void initialize_array(signed char*, size_t) {}
+ void deinitialize_array(signed char*, size_t) {}
+ void initialize_array(char*, size_t) {}
+ void deinitialize_array(char*, size_t) {}
+ void initialize_array(int*, size_t) {}
+ void deinitialize_array(int*, size_t) {}
+ void initialize_array(unsigned int*, size_t) {}
+ void deinitialize_array(unsigned int*, size_t) {}
+ void initialize_array(unsigned long*, size_t) {}
+ void deinitialize_array(unsigned long*, size_t) {}
+ void initialize_array(long*, size_t) {}
+ void deinitialize_array(long*, size_t) {}
+ void initialize_array(float*, size_t) {}
+ void deinitialize_array(float*, size_t) {}
+ void initialize_array(double*, size_t) {}
+ void deinitialize_array(double*, size_t) {}
+ void initialize_array(short*, size_t) {}
+ void deinitialize_array(short*, size_t) {}
+ void initialize_array(unsigned short*, size_t) {}
+ void deinitialize_array(unsigned short*, size_t) {}
+
+
+
+ // restricted functions
+ memory_manager_kernel_3(memory_manager_kernel_3&); // copy constructor
+ memory_manager_kernel_3& operator=(memory_manager_kernel_3&); // assignment operator
+ };
+
+ template <
+ typename T,
+ unsigned long chunk_size
+ >
+ inline void swap (
+ memory_manager_kernel_3<T,chunk_size>& a,
+ memory_manager_kernel_3<T,chunk_size>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_KERNEl_3_
+
diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h
new file mode 100644
index 000000000..8439caf85
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_
+#ifdef DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+ template <
+ typename T
+ >
+ class memory_manager
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor.
+
+ INITIAL VALUE
+ get_number_of_allocations() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some kind of memory manager or memory pool.
+ !*/
+
+ public:
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager<U> other;
+ };
+
+ memory_manager(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~memory_manager(
+ );
+ /*!
+ ensures
+ - if (get_number_of_allocations() == 0) then
+ - all resources associated with *this have been released.
+ - else
+ - The memory still allocated will not be deleted and this
+ causes a memory leak.
+ !*/
+
+ unsigned long get_number_of_allocations (
+ ) const;
+ /*!
+ ensures
+ - returns the current number of outstanding allocations
+ !*/
+
+ T* allocate (
+ );
+ /*!
+ ensures
+ - allocates a new object of type T and returns a pointer to it.
+ - #get_number_of_allocations() == get_number_of_allocations() + 1
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate()
+ has no effect on #*this.
+ !*/
+
+ void deallocate (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ this->allocate(). (i.e. you can't deallocate a pointer you
+ got from a different memory_manager instance.)
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - deallocates the object pointed to by item
+ - #get_number_of_allocations() == get_number_of_allocations() - 1
+ !*/
+
+ T* allocate_array (
+ unsigned long size
+ );
+ /*!
+ ensures
+ - allocates a new array of size objects of type T and returns a
+ pointer to it.
+ - #get_number_of_allocations() == get_number_of_allocations() + 1
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate()
+ has no effect on #*this.
+ !*/
+
+ void deallocate_array (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ this->allocate_array(). (i.e. you can't deallocate a pointer you
+ got from a different memory_manager instance and it must be an
+ array.)
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - deallocates the array pointed to by item
+ - #get_number_of_allocations() == get_number_of_allocations() - 1
+ !*/
+
+ void swap (
+ memory_manager& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ memory_manager(memory_manager&); // copy constructor
+ memory_manager& operator=(memory_manager&); // assignment operator
+ };
+
+ template <
+ typename T
+ >
+ inline void swap (
+ memory_manager<T>& a,
+ memory_manager<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/memory_manager_global.h b/ml/dlib/dlib/memory_manager_global.h
new file mode 100644
index 000000000..05f439fdc
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_global.h
@@ -0,0 +1,38 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_GLOBAl_
+#define DLIB_MEMORY_MANAGER_GLOBAl_
+
+#include "memory_manager_global/memory_manager_global_kernel_1.h"
+#include "memory_manager.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename factory
+ >
+ class memory_manager_global
+ {
+ memory_manager_global() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1
+ typedef memory_manager_global_kernel_1<T,factory>
+ kernel_1a;
+
+
+
+
+ };
+}
+
+#endif // DLIB_MEMORY_MANAGER_GLOBAl_
+
diff --git a/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h
new file mode 100644
index 000000000..4cf418d24
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h
@@ -0,0 +1,113 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_GLOBAl_1_
+#define DLIB_MEMORY_MANAGER_GLOBAl_1_
+
+#include "../algs.h"
+#include "../memory_manager/memory_manager_kernel_abstract.h"
+#include "memory_manager_global_kernel_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename factory
+ >
+ class memory_manager_global_kernel_1
+ {
+ /*!
+ INITIAL VALUE
+ - *global_mm == get_global_memory_manager()
+
+ CONVENTION
+ - global_mm->get_number_of_allocations() == get_number_of_allocations()
+ - *global_mm == get_global_memory_manager()
+ !*/
+
+ public:
+
+ typedef typename factory::template return_type<T>::type mm_global_type;
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_global_kernel_1<U,factory> other;
+ };
+
+ memory_manager_global_kernel_1(
+ ) :
+ global_mm(factory::template get_instance<T>())
+ {}
+
+ virtual ~memory_manager_global_kernel_1(
+ ) {}
+
+ unsigned long get_number_of_allocations (
+ ) const { return global_mm->get_number_of_allocations(); }
+
+ mm_global_type& get_global_memory_manager (
+ ) { return *global_mm; }
+
+ T* allocate (
+ )
+ {
+ return global_mm->allocate();
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ global_mm->deallocate(item);
+ }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ return global_mm->allocate_array(size);
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ global_mm->deallocate_array(item);
+ }
+
+ void swap (
+ memory_manager_global_kernel_1& item
+ )
+ {
+ exchange(item.global_mm, global_mm);
+ }
+
+ private:
+
+ mm_global_type* global_mm;
+
+
+ // restricted functions
+ memory_manager_global_kernel_1(memory_manager_global_kernel_1&); // copy constructor
+ memory_manager_global_kernel_1& operator=(memory_manager_global_kernel_1&); // assignment operator
+ };
+
+ template <
+ typename T,
+ typename factory
+ >
+ inline void swap (
+ memory_manager_global_kernel_1<T,factory>& a,
+ memory_manager_global_kernel_1<T,factory>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_GLOBAl_1_
+
+
+
diff --git a/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h
new file mode 100644
index 000000000..a35e38b51
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h
@@ -0,0 +1,181 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_
+#ifdef DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_
+
+#include "../algs.h"
+#include "../memory_manager/memory_manager_kernel_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename factory
+ >
+ class memory_manager_global
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor.
+
+ REQUIREMENTS ON factory
+ factory must be defined as follows:
+ struct factory
+ {
+ template <typename U>
+ struct return_type {
+ typedef typename memory_manager_type<U> type;
+ };
+
+ template <typename U>
+ static typename return_type<U>::type* get_instance (
+ );
+ / *!
+ ensures
+ - returns a pointer to an instance of a memory_manager object
+ where memory_manager_type implements the interface defined
+ by dlib/memory_manager/memory_manager_kernel_abstract.h
+ !* /
+ };
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some kind of global memory manager or memory pool.
+ It is identical to the memory_manager object except that it gets all of
+ its allocations from a global instance of a memory_manager object which
+ is provided by the factory object's static member get_instance().
+
+ THREAD SAFETY
+ This object is, by itself, threadsafe. However, if you want to use this
+ object in multiple threads then you must ensure that your factory is
+ threadsafe. This means its factory::get_instance() method should be
+ threadsafe and the memory_manager object it returns must also be threadsafe.
+ !*/
+
+ public:
+
+ typedef typename factory::template return_type<T>::type mm_global_type;
+
+ typedef T type;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_global<U,factory> other;
+ };
+
+ memory_manager_global(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #get_global_memory_manager() == the memory manager that was
+ returned by a call to factory::get_instance<T>()
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~memory_manager_global(
+ );
+ /*!
+ ensures
+ - This destructor has no effect on the global memory_manager
+ get_global_memory_manager().
+ !*/
+
+ unsigned long get_number_of_allocations (
+ ) const;
+ /*!
+ ensures
+ - returns get_global_memory_manager().get_number_of_allocations()
+ !*/
+
+ mm_global_type& get_global_memory_manager (
+ );
+ /*!
+ ensures
+ - returns a reference to the global memory manager instance being
+ used by *this.
+ !*/
+
+ T* allocate (
+ );
+ /*!
+ ensures
+ - #get_number_of_allocations() == get_number_of_allocations() + 1
+ - returns get_global_memory_manager().allocate()
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate()
+ has no effect on #*this.
+ !*/
+
+ void deallocate (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ the get_global_memory_manager() object's allocate() method.
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - calls get_global_memory_manager().deallocate(item)
+ - #get_number_of_allocations() == get_number_of_allocations() - 1
+ !*/
+
+ T* allocate_array (
+ unsigned long size
+ );
+ /*!
+ ensures
+ - #get_number_of_allocations() == get_number_of_allocations() + 1
+ - returns get_global_memory_manager().allocate_array()
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate_array()
+ has no effect on #*this.
+ !*/
+
+ void deallocate_array (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ the get_global_memory_manager() object's allocate_array() method.
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - calls get_global_memory_manager().deallocate_array(item)
+ - #get_number_of_allocations() == get_number_of_allocations() - 1
+ !*/
+
+ void swap (
+ memory_manager_global& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ memory_manager_global(memory_manager_global&); // copy constructor
+ memory_manager_global& operator=(memory_manager_global&); // assignment operator
+ };
+
+ template <
+ typename T,
+ typename factory
+ >
+ inline void swap (
+ memory_manager_global<T,factory>& a,
+ memory_manager_global<T,factory>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/memory_manager_stateless.h b/ml/dlib/dlib/memory_manager_stateless.h
new file mode 100644
index 000000000..32bc7e9c0
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_stateless.h
@@ -0,0 +1,72 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_STATELESs_
+#define DLIB_MEMORY_MANAGER_STATELESs_
+
+#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h"
+#include "memory_manager_stateless/memory_manager_stateless_kernel_2.h"
+#include "memory_manager.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class memory_manager_stateless
+ {
+ memory_manager_stateless() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1
+ typedef memory_manager_stateless_kernel_1<T>
+ kernel_1a;
+
+ // kernel_2
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1a>
+ kernel_2_1a;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1b>
+ kernel_2_1b;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1c>
+ kernel_2_1c;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1d>
+ kernel_2_1d;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1e>
+ kernel_2_1e;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_1f>
+ kernel_2_1f;
+
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_2a>
+ kernel_2_2a;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_2b>
+ kernel_2_2b;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_2c>
+ kernel_2_2c;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_2d>
+ kernel_2_2d;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_2e>
+ kernel_2_2e;
+
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_3a>
+ kernel_2_3a;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_3b>
+ kernel_2_3b;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_3c>
+ kernel_2_3c;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_3d>
+ kernel_2_3d;
+ typedef memory_manager_stateless_kernel_2<T,memory_manager<char>::kernel_3e>
+ kernel_2_3e;
+
+
+ };
+}
+
+#endif // DLIB_MEMORY_MANAGER_STATELESs_
+
diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h
new file mode 100644
index 000000000..0d5794d54
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h
@@ -0,0 +1,86 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_STATELESs_1_
+#define DLIB_MEMORY_MANAGER_STATELESs_1_
+
+#include "memory_manager_stateless_kernel_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename T
+ >
+ class memory_manager_stateless_kernel_1
+ {
+ /*!
+ this implementation just calls new and delete directly
+ !*/
+
+ public:
+
+ typedef T type;
+ const static bool is_stateless = true;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_stateless_kernel_1<U> other;
+ };
+
+ memory_manager_stateless_kernel_1(
+ )
+ {}
+
+ virtual ~memory_manager_stateless_kernel_1(
+ ) {}
+
+ T* allocate (
+ )
+ {
+ return new T;
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ delete item;
+ }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ return new T[size];
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ delete [] item;
+ }
+
+ void swap (memory_manager_stateless_kernel_1&)
+ {}
+
+ private:
+
+ // restricted functions
+ memory_manager_stateless_kernel_1(memory_manager_stateless_kernel_1&); // copy constructor
+ memory_manager_stateless_kernel_1& operator=(memory_manager_stateless_kernel_1&); // assignment operator
+ };
+
+ template <
+ typename T
+ >
+ inline void swap (
+ memory_manager_stateless_kernel_1<T>& a,
+ memory_manager_stateless_kernel_1<T>& b
+ ) { a.swap(b); }
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_STATELESs_1_
+
+
+
diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h
new file mode 100644
index 000000000..7c4bf76f5
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h
@@ -0,0 +1,119 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MEMORY_MANAGER_STATELESs_2_
+#define DLIB_MEMORY_MANAGER_STATELESs_2_
+
+#include "../algs.h"
+#include "memory_manager_stateless_kernel_abstract.h"
+#include "../threads.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename mem_manager
+ >
+ class memory_manager_stateless_kernel_2
+ {
+ /*!
+ REQUIREMENTS ON mem_manager
+ mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h
+
+ CONVENTION
+ this object has a single global instance of mem_manager
+ !*/
+
+ public:
+
+ typedef T type;
+ const static bool is_stateless = true;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_stateless_kernel_2<U,mem_manager> other;
+ };
+
+ memory_manager_stateless_kernel_2(
+ )
+ {
+ // call this just to make sure the mutex is is initialized before
+ // multiple threads start calling the member functions.
+ global_mutex();
+ }
+
+ virtual ~memory_manager_stateless_kernel_2(
+ ) {}
+
+ T* allocate (
+ )
+ {
+ auto_mutex M(global_mutex());
+ return global_mm().allocate();
+ }
+
+ void deallocate (
+ T* item
+ )
+ {
+ auto_mutex M(global_mutex());
+ return global_mm().deallocate(item);
+ }
+
+ T* allocate_array (
+ unsigned long size
+ )
+ {
+ auto_mutex M(global_mutex());
+ return global_mm().allocate_array(size);
+ }
+
+ void deallocate_array (
+ T* item
+ )
+ {
+ auto_mutex M(global_mutex());
+ return global_mm().deallocate_array(item);
+ }
+
+ void swap (memory_manager_stateless_kernel_2&)
+ {}
+
+ private:
+
+ static mutex& global_mutex (
+ )
+ {
+ static mutex lock;
+ return lock;
+ }
+
+ typedef typename mem_manager::template rebind<T>::other rebound_mm_type;
+
+ static rebound_mm_type& global_mm (
+ )
+ {
+ static rebound_mm_type mm;
+ return mm;
+ }
+
+ // restricted functions
+ memory_manager_stateless_kernel_2(memory_manager_stateless_kernel_2&); // copy constructor
+ memory_manager_stateless_kernel_2& operator=(memory_manager_stateless_kernel_2&); // assignment operator
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ memory_manager_stateless_kernel_2<T,mem_manager>& a,
+ memory_manager_stateless_kernel_2<T,mem_manager>& b
+ ) { a.swap(b); }
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_STATELESs_2_
+
+
+
+
diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
new file mode 100644
index 000000000..2c5b1e73c
--- /dev/null
+++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
@@ -0,0 +1,142 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_
+#ifdef DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+ template <
+ typename T
+ >
+ class memory_manager_stateless
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents some kind of stateless memory manager or memory pool.
+ Stateless means that all instances (instances of the same kernel implementation that is)
+ of this object are identical and can be used interchangeably. Note that
+ implementations are allowed to have some shared global state such as a
+ global memory pool.
+
+ THREAD SAFETY
+ This object is thread safe. You may access it from any thread at any time
+ without synchronizing access.
+ !*/
+
+ public:
+
+ typedef T type;
+ const static bool is_stateless = true;
+
+ template <typename U>
+ struct rebind {
+ typedef memory_manager_stateless<U> other;
+ };
+
+ memory_manager_stateless(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~memory_manager_stateless(
+ );
+ /*!
+ ensures
+ - frees any resources used by *this but has no effect on any shared global
+ resources used by the implementation.
+ !*/
+
+ T* allocate (
+ );
+ /*!
+ ensures
+ - allocates a new object of type T and returns a pointer to it.
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate()
+ has no effect on #*this.
+ !*/
+
+ void deallocate (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ allocate(). (i.e. The pointer you are deallocating must have
+ come from the same implementation of memory_manager_stateless
+ that is trying to deallocate it.)
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - deallocates the object pointed to by item
+ !*/
+
+ T* allocate_array (
+ unsigned long size
+ );
+ /*!
+ ensures
+ - allocates a new array of size objects of type T and returns a
+ pointer to it.
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to allocate()
+ has no effect on #*this.
+ !*/
+
+ void deallocate_array (
+ T* item
+ );
+ /*!
+ requires
+ - item == is a pointer to memory that was obtained from a call to
+ allocate_array(). (i.e. The pointer you are deallocating must have
+ come from the same implementation of memory_manager_stateless
+ that is trying to deallocate it.)
+ - the memory pointed to by item hasn't already been deallocated.
+ ensures
+ - deallocates the array pointed to by item
+ !*/
+
+ void swap (
+ memory_manager_stateless& item
+ );
+ /*!
+ ensures
+ - this function has no effect on *this or item. It is just provided
+ to make this object's interface more compatable with the other
+ memory managers.
+ !*/
+
+ private:
+
+ // restricted functions
+ memory_manager_stateless(memory_manager_stateless&); // copy constructor
+ memory_manager_stateless& operator=(memory_manager_stateless&); // assignment operator
+ };
+
+ template <
+ typename T
+ >
+ inline void swap (
+ memory_manager_stateless<T>& a,
+ memory_manager_stateless<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/metaprogramming.h b/ml/dlib/dlib/metaprogramming.h
new file mode 100644
index 000000000..bc63041ff
--- /dev/null
+++ b/ml/dlib/dlib/metaprogramming.h
@@ -0,0 +1,71 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_METApROGRAMMING_Hh_
+#define DLIB_METApROGRAMMING_Hh_
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t... n>
+ struct compile_time_integer_list
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ The point of this type is to, as the name suggests, hold a compile time list of integers.
+ As an example, here is something simple you could do with it:
+
+ template <size_t... ints>
+ void print_compile_time_ints (
+ compile_time_integer_list<ints...>
+ )
+ {
+ print(ints...);
+ }
+
+ int main()
+ {
+ print_compile_time_ints(compile_time_integer_list<0,4,9>());
+ }
+
+ Which just calls: print(0,4,9);
+
+ This is a simple example, but this kind of thing is useful in larger and
+ more complex template metaprogramming constructs.
+ !*/
+
+ template <size_t m>
+ struct push_back
+ {
+ typedef compile_time_integer_list<n..., m> type;
+ };
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <size_t max>
+ struct make_compile_time_integer_range
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object makes a compile_time_integer_list containing the integers in the range [1,max] inclusive.
+ For example:
+ make_compile_time_integer_range<4>::type
+ evaluates to:
+ compile_time_integer_list<1,2,3,4>
+ !*/
+
+ typedef typename make_compile_time_integer_range<max-1>::type::template push_back<max>::type type;
+ };
+ // base case
+ template <> struct make_compile_time_integer_range<0> { typedef compile_time_integer_list<> type; };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_METApROGRAMMING_Hh_
+
+
diff --git a/ml/dlib/dlib/misc_api.h b/ml/dlib/dlib/misc_api.h
new file mode 100644
index 000000000..acf87be7c
--- /dev/null
+++ b/ml/dlib/dlib/misc_api.h
@@ -0,0 +1,20 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_MISC_APi_
+#define DLIB_MISC_APi_
+
+#include "platform.h"
+
+#ifdef WIN32
+#include "misc_api/windows.h"
+#endif
+
+#ifndef WIN32
+#include "misc_api/posix.h"
+#endif
+
+#include "misc_api/misc_api_shared.h"
+
+#endif // DLIB_MISC_APi_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp b/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp
new file mode 100644
index 000000000..f17d850e1
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp
@@ -0,0 +1,149 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEL_1_CPp_
+#define DLIB_MISC_API_KERNEL_1_CPp_
+
+#include "../platform.h"
+#include "../threads.h"
+
+#ifdef WIN32
+
+#include "misc_api_kernel_1.h"
+
+#include "../windows_magic.h"
+#include <mmsystem.h>
+#include <windows.h>
+
+// tell visual studio to link to the library needed to call timeGetTime()
+#ifdef _MSC_VER
+#pragma comment (lib, "winmm.lib")
+#endif
+
+#ifdef __BORLANDC__
+// Apparently the borland compiler doesn't define this.
+#define INVALID_FILE_ATTRIBUTES ((DWORD)-1)
+#endif
+
+namespace dlib
+{
+// ----------------------------------------------------------------------------------------
+
+ void sleep (
+ unsigned long milliseconds
+ )
+ {
+ ::Sleep(milliseconds);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace
+ {
+ mutex& cwd_mutex()
+ {
+ static mutex m;
+ return m;
+ }
+ // Make sure the above mutex gets constructed before main()
+ // starts. This way we can be pretty sure it will be constructed
+ // before any threads could possibly call set_current_dir() or
+ // get_current_dir() simultaneously.
+ struct construct_cwd_mutex
+ {
+ construct_cwd_mutex()
+ {
+ cwd_mutex();
+ }
+ } oaimvweoinvwe;
+ }
+
+ std::string get_current_dir (
+ )
+ {
+ // need to lock a mutex here because getting and setting the
+ // current working directory is not thread safe on windows.
+ auto_mutex lock(cwd_mutex());
+ char buf[1024];
+ if (GetCurrentDirectoryA(sizeof(buf),buf) == 0)
+ {
+ return std::string();
+ }
+ else
+ {
+ return std::string(buf);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void set_current_dir (
+ const std::string& new_dir
+ )
+ {
+ // need to lock a mutex here because getting and setting the
+ // current working directory is not thread safe on windows.
+ auto_mutex lock(cwd_mutex());
+ if (SetCurrentDirectoryA(new_dir.c_str()) == 0)
+ {
+ throw set_current_dir_error("Error changing current dir to '" + new_dir + "'");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 timestamper::
+ get_timestamp (
+ ) const
+ {
+ unsigned long temp = timeGetTime();
+ if (temp >= last_time)
+ {
+ last_time = temp;
+ return (offset + temp)*1000;
+ }
+ else
+ {
+ last_time = temp;
+
+ // there was overflow since the last call so we need to make the offset
+ // bigger to account for that
+ offset += dword_max;
+ return (offset + temp)*1000;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void create_directory (
+ const std::string& dir
+ )
+ {
+ if (CreateDirectoryA(dir.c_str(),0) == 0)
+ {
+ // an error has occurred
+ if (GetLastError() == ERROR_ALREADY_EXISTS)
+ {
+ // make sure this is actually a directory
+ DWORD attribs = GetFileAttributesA(dir.c_str());
+ if (attribs == INVALID_FILE_ATTRIBUTES ||
+ (attribs&FILE_ATTRIBUTE_DIRECTORY) == 0)
+ {
+ // it isn't a directory
+ throw dir_create_error(dir);
+ }
+ }
+ else
+ {
+ throw dir_create_error(dir);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // WIN32
+
+#endif // DLIB_MISC_API_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_1.h b/ml/dlib/dlib/misc_api/misc_api_kernel_1.h
new file mode 100644
index 000000000..a500e992a
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_kernel_1.h
@@ -0,0 +1,110 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEl_1_
+#define DLIB_MISC_API_KERNEl_1_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+
+#include "misc_api_kernel_abstract.h"
+#include "../algs.h"
+#include <string>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void sleep (
+ unsigned long milliseconds
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ std::string get_current_dir (
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class set_current_dir_error : public error
+ {
+ public:
+ set_current_dir_error(
+ const std::string& a
+ ): error(a) {}
+ };
+
+ void set_current_dir (
+ const std::string& new_dir
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class timestamper
+ {
+ /*!
+ INITIAL VALUE
+ - last_time == 0
+ - offset == 0
+ - dword_max == 2^32
+
+ CONVENTION
+ - last_time == the time returned by GetTickCount() the last time we
+ called it.
+ - offset == the number of microseconds we should add to the result of
+ GetTickCount() so that it is correct.
+ - dword_max == 2^32.
+ This is the number of values representable by a DWORD.
+ !*/
+
+ mutable unsigned long last_time;
+ mutable uint64 offset;
+ mutable uint64 dword_max;
+
+ public:
+ timestamper(
+ ) :
+ last_time(0),
+ offset(0)
+ {
+ dword_max = 0xFFFFFFFF;
+ ++dword_max;
+ }
+
+ uint64 get_timestamp (
+ ) const;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class dir_create_error : public error
+ {
+ public:
+ dir_create_error(
+ const std::string& dir_name
+ ) :
+ error(EDIR_CREATE,"Error creating directory '" + dir_name + "'."),
+ name(dir_name)
+ {}
+
+ ~dir_create_error() throw() {}
+ const std::string name;
+ };
+
+ void create_directory (
+ const std::string& dir
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "misc_api_kernel_1.cpp"
+#endif
+
+#endif // DLIB_MISC_API_KERNEl_1_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp b/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp
new file mode 100644
index 000000000..e6dc772da
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp
@@ -0,0 +1,123 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEL_2_CPp_
+#define DLIB_MISC_API_KERNEL_2_CPp_
+#include "../platform.h"
+
+#ifdef POSIX
+
+#include <unistd.h>
+#include "misc_api_kernel_2.h"
+#include <sys/time.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <errno.h>
+
+namespace dlib
+{
+// ----------------------------------------------------------------------------------------
+
+ void sleep (
+ unsigned long milliseconds
+ )
+ {
+ // in HP-UX you can only usleep for less than a second
+#ifdef HPUX
+ if (milliseconds >= 1000)
+ {
+ ::sleep(milliseconds/1000);
+ unsigned long remaining = milliseconds%1000;
+ if (remaining > 0)
+ ::usleep(remaining*1000);
+ }
+ else
+ {
+ ::usleep(milliseconds*1000);
+ }
+#else
+ ::usleep(milliseconds*1000);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::string get_current_dir (
+ )
+ {
+ char buf[1024];
+ if (getcwd(buf,sizeof(buf)) == 0)
+ {
+ return std::string();
+ }
+ else
+ {
+ return std::string(buf);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void set_current_dir (
+ const std::string& new_dir
+ )
+ {
+ if (chdir(new_dir.c_str()))
+ {
+ throw set_current_dir_error("Error changing current dir to '" + new_dir + "'");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 timestamper::
+ get_timestamp (
+ ) const
+ {
+ uint64 ts;
+ timeval curtime;
+ gettimeofday(&curtime,0);
+
+ ts = curtime.tv_sec;
+ ts *= 1000000;
+ ts += curtime.tv_usec;
+ return ts;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void create_directory (
+ const std::string& dir
+ )
+ {
+ if (mkdir(dir.c_str(),0777))
+ {
+ // an error has occurred
+ if (errno == EEXIST)
+ {
+ struct stat buffer;
+ // now check that this is actually a valid directory
+ if (::stat(dir.c_str(),&buffer))
+ {
+ // the directory was not found
+ throw dir_create_error(dir);
+ }
+ else if (S_ISDIR(buffer.st_mode) == 0)
+ {
+ // It is not a directory
+ throw dir_create_error(dir);
+ }
+ }
+ else
+ {
+ throw dir_create_error(dir);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+}
+
+#endif // POSIX
+
+#endif // DLIB_MISC_API_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_2.h b/ml/dlib/dlib/misc_api/misc_api_kernel_2.h
new file mode 100644
index 000000000..86e8a7f5b
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_kernel_2.h
@@ -0,0 +1,81 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEl_2_
+#define DLIB_MISC_API_KERNEl_2_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+
+#include "misc_api_kernel_abstract.h"
+#include "../algs.h"
+#include <string>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ void sleep (
+ unsigned long milliseconds
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ std::string get_current_dir (
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class set_current_dir_error : public error
+ {
+ public:
+ set_current_dir_error(
+ const std::string& a
+ ): error(a) {}
+ };
+
+ void set_current_dir (
+ const std::string& new_dir
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class timestamper
+ {
+ public:
+ uint64 get_timestamp (
+ ) const;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class dir_create_error : public error
+ {
+ public:
+ dir_create_error(
+ const std::string& dir_name
+ ) :
+ error(EDIR_CREATE,"Error creating directory '" + dir_name + "'."),
+ name(dir_name)
+ {}
+ const std::string& name;
+ };
+
+
+ void create_directory (
+ const std::string& dir
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "misc_api_kernel_2.cpp"
+#endif
+
+#endif // DLIB_MISC_API_KERNEl_2_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h b/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h
new file mode 100644
index 000000000..47749b91b
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h
@@ -0,0 +1,159 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MISC_API_KERNEl_ABSTRACT_
+#ifdef DLIB_MISC_API_KERNEl_ABSTRACT_
+
+#include <string>
+#include "../uintn.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ GENERAL COMMENTS
+ This file just contains miscellaneous api functions
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void sleep (
+ unsigned long milliseconds
+ );
+ /*!
+ ensures
+ - causes the calling thread to sleep for the given number of
+ milliseconds.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::string get_current_dir (
+ );
+ /*!
+ ensures
+ - if (no errors occur) then
+ - returns the path to the current working directory
+ - else
+ - returns ""
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class set_current_dir_error : public error;
+
+ void set_current_dir (
+ const std::string& new_dir
+ );
+ /*!
+ ensures
+ - sets the current working directory to new_dir
+ throws
+ - std::bad_alloc
+ - set_current_dir_error
+ This exception is thrown if there is an error when attempting
+ to change the current working directory.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class locally_change_current_dir : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a RAII tool for safely switching the current directory
+ to a new directory and then automatically switching back to the original
+ directory upon this object's destruction.
+ !*/
+ public:
+ explicit locally_change_current_dir (
+ const std::string& new_dir
+ );
+ /*!
+ ensures
+ - calls set_current_dir(new_dir)
+ - #old_dir() == The value of get_current_dir() prior to switching to new_dir.
+ !*/
+
+ const std::string& old_dir (
+ ) const;
+ /*!
+ ensures
+ - returns the directory we switch back to once this object is destructed.
+ !*/
+
+ ~locally_change_current_dir(
+ );
+ /*!
+ ensures
+ - if (revert() hasn't already been called) then
+ - calls set_current_dir(old_dir())
+ !*/
+
+ void revert (
+ );
+ /*!
+ ensures
+ - if (revert() hasn't already been called) then
+ - calls set_current_dir(old_dir())
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class dir_create_error : public error {
+ public:
+ const std::string name
+ };
+
+ void create_directory (
+ const std::string& dir
+ );
+ /*!
+ ensures
+ - if (dir does not already exist) then
+ - creates the given directory.
+ - else
+ - the call to create_directory() has no effect.
+ throws
+ - dir_create_error
+ This exception is thrown if we were unable to create the requested
+ directory and it didn't already exist. The type member of the exception
+ will bet set to EDIR_CREATE and the name member will be set to dir.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class timestamper
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a timer that is capable of returning
+ timestamps.
+
+ Note that the time is measured in microseconds but you are not
+ guaranteed to have that level of resolution. The actual resolution
+ is implementation dependent.
+ !*/
+
+ public:
+ uint64 get_timestamp (
+ ) const;
+ /*!
+ ensures
+ - returns a timestamp that measures the time in microseconds since an
+ arbitrary point in the past. Note that this arbitrary point remains
+ the same between all calls to get_timestamp().
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MISC_API_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/misc_api/misc_api_shared.h b/ml/dlib/dlib/misc_api/misc_api_shared.h
new file mode 100644
index 000000000..6b84dd64c
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/misc_api_shared.h
@@ -0,0 +1,57 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_ShARED_Hh_
+#define DLIB_MISC_API_ShARED_Hh_
+
+#include <string>
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class locally_change_current_dir : noncopyable
+ {
+ public:
+ explicit locally_change_current_dir (
+ const std::string& new_dir
+ )
+ {
+ reverted = false;
+ _old_dir = get_current_dir();
+ set_current_dir(new_dir);
+ }
+
+ ~locally_change_current_dir()
+ {
+ revert();
+ }
+
+ const std::string& old_dir (
+ ) const
+ {
+ return _old_dir;
+ }
+
+ void revert (
+ )
+ {
+ if (!reverted)
+ {
+ set_current_dir(_old_dir);
+ reverted = true;
+ }
+ }
+
+ private:
+ bool reverted;
+ std::string _old_dir;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MISC_API_ShARED_Hh_
+
diff --git a/ml/dlib/dlib/misc_api/posix.h b/ml/dlib/dlib/misc_api/posix.h
new file mode 100644
index 000000000..1dbb38031
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/posix.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEl_1_
+#include "misc_api_kernel_2.h"
+#endif
+
diff --git a/ml/dlib/dlib/misc_api/windows.h b/ml/dlib/dlib/misc_api/windows.h
new file mode 100644
index 000000000..0817c2aaf
--- /dev/null
+++ b/ml/dlib/dlib/misc_api/windows.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MISC_API_KERNEl_2_
+#include "misc_api_kernel_1.h"
+#endif
+
diff --git a/ml/dlib/dlib/mlp.h b/ml/dlib/dlib/mlp.h
new file mode 100644
index 000000000..d287847c8
--- /dev/null
+++ b/ml/dlib/dlib/mlp.h
@@ -0,0 +1,30 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MLp_
+#define DLIB_MLp_
+
+#include "mlp/mlp_kernel_1.h"
+#include "mlp/mlp_kernel_c.h"
+
+namespace dlib
+{
+
+ class mlp
+ {
+ mlp() {}
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef mlp_kernel_1
+ kernel_1a;
+ typedef mlp_kernel_c<kernel_1a >
+ kernel_1a_c;
+
+ };
+}
+
+#endif // DLIB_MLp_
+
diff --git a/ml/dlib/dlib/mlp/mlp_kernel_1.h b/ml/dlib/dlib/mlp/mlp_kernel_1.h
new file mode 100644
index 000000000..d420eea9c
--- /dev/null
+++ b/ml/dlib/dlib/mlp/mlp_kernel_1.h
@@ -0,0 +1,394 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MLp_KERNEL_1_
+#define DLIB_MLp_KERNEL_1_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include "mlp_kernel_abstract.h"
+#include <ctime>
+#include <sstream>
+
+namespace dlib
+{
+
+ class mlp_kernel_1 : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ The network is initially initialized with random weights
+
+ CONVENTION
+ - input_layer_nodes() == input_nodes
+ - first_hidden_layer_nodes() == first_hidden_nodes
+ - second_hidden_layer_nodes() == second_hidden_nodes
+ - output_layer_nodes() == output_nodes
+ - get_alpha == alpha
+ - get_momentum() == momentum
+
+
+ - if (second_hidden_nodes == 0) then
+ - for all i and j:
+ - w1(i,j) == the weight on the link from node i in the first hidden layer
+ to input node j
+ - w3(i,j) == the weight on the link from node i in the output layer
+ to first hidden layer node j
+ - for all i and j:
+ - w1m == the momentum terms for w1 from the previous update
+ - w3m == the momentum terms for w3 from the previous update
+ - else
+ - for all i and j:
+ - w1(i,j) == the weight on the link from node i in the first hidden layer
+ to input node j
+ - w2(i,j) == the weight on the link from node i in the second hidden layer
+ to first hidden layer node j
+ - w3(i,j) == the weight on the link from node i in the output layer
+ to second hidden layer node j
+ - for all i and j:
+ - w1m == the momentum terms for w1 from the previous update
+ - w2m == the momentum terms for w2 from the previous update
+ - w3m == the momentum terms for w3 from the previous update
+ !*/
+
+ public:
+
+ mlp_kernel_1 (
+ long nodes_in_input_layer,
+ long nodes_in_first_hidden_layer,
+ long nodes_in_second_hidden_layer = 0,
+ long nodes_in_output_layer = 1,
+ double alpha_ = 0.1,
+ double momentum_ = 0.8
+ ) :
+ input_nodes(nodes_in_input_layer),
+ first_hidden_nodes(nodes_in_first_hidden_layer),
+ second_hidden_nodes(nodes_in_second_hidden_layer),
+ output_nodes(nodes_in_output_layer),
+ alpha(alpha_),
+ momentum(momentum_)
+ {
+
+ // seed the random number generator
+ std::ostringstream sout;
+ sout << time(0);
+ rand_nums.set_seed(sout.str());
+
+ w1.set_size(first_hidden_nodes+1, input_nodes+1);
+ w1m.set_size(first_hidden_nodes+1, input_nodes+1);
+ z.set_size(input_nodes+1,1);
+
+ if (second_hidden_nodes != 0)
+ {
+ w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+ w3.set_size(output_nodes, second_hidden_nodes+1);
+
+ w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+ w3m.set_size(output_nodes, second_hidden_nodes+1);
+ }
+ else
+ {
+ w3.set_size(output_nodes, first_hidden_nodes+1);
+
+ w3m.set_size(output_nodes, first_hidden_nodes+1);
+ }
+
+ reset();
+ }
+
+ virtual ~mlp_kernel_1 (
+ ) {}
+
+ void reset (
+ )
+ {
+ // randomize the weights for the first layer
+ for (long r = 0; r < w1.nr(); ++r)
+ for (long c = 0; c < w1.nc(); ++c)
+ w1(r,c) = rand_nums.get_random_double();
+
+ // randomize the weights for the second layer
+ for (long r = 0; r < w2.nr(); ++r)
+ for (long c = 0; c < w2.nc(); ++c)
+ w2(r,c) = rand_nums.get_random_double();
+
+ // randomize the weights for the third layer
+ for (long r = 0; r < w3.nr(); ++r)
+ for (long c = 0; c < w3.nc(); ++c)
+ w3(r,c) = rand_nums.get_random_double();
+
+ // zero all the momentum terms
+ set_all_elements(w1m,0);
+ set_all_elements(w2m,0);
+ set_all_elements(w3m,0);
+ }
+
+ long input_layer_nodes (
+ ) const { return input_nodes; }
+
+ long first_hidden_layer_nodes (
+ ) const { return first_hidden_nodes; }
+
+ long second_hidden_layer_nodes (
+ ) const { return second_hidden_nodes; }
+
+ long output_layer_nodes (
+ ) const { return output_nodes; }
+
+ double get_alpha (
+ ) const { return alpha; }
+
+ double get_momentum (
+ ) const { return momentum; }
+
+ template <typename EXP>
+ const matrix<double> operator() (
+ const matrix_exp<EXP>& in
+ ) const
+ {
+ for (long i = 0; i < in.nr(); ++i)
+ z(i) = in(i);
+ // insert the bias
+ z(z.nr()-1) = -1;
+
+ tmp1 = sigmoid(w1*z);
+ // insert the bias
+ tmp1(tmp1.nr()-1) = -1;
+
+ if (second_hidden_nodes == 0)
+ {
+ return sigmoid(w3*tmp1);
+ }
+ else
+ {
+ tmp2 = sigmoid(w2*tmp1);
+ // insert the bias
+ tmp2(tmp2.nr()-1) = -1;
+
+ return sigmoid(w3*tmp2);
+ }
+ }
+
+ template <typename EXP1, typename EXP2>
+ void train (
+ const matrix_exp<EXP1>& example_in,
+ const matrix_exp<EXP2>& example_out
+ )
+ {
+ for (long i = 0; i < example_in.nr(); ++i)
+ z(i) = example_in(i);
+ // insert the bias
+ z(z.nr()-1) = -1;
+
+ tmp1 = sigmoid(w1*z);
+ // insert the bias
+ tmp1(tmp1.nr()-1) = -1;
+
+
+ if (second_hidden_nodes == 0)
+ {
+ o = sigmoid(w3*tmp1);
+
+ // now compute the errors and propagate them backwards though the network
+ e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+ e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );
+
+ // compute the new weight updates
+ w3m = alpha * e3*trans(tmp1) + w3m*momentum;
+ w1m = alpha * e1*trans(z) + w1m*momentum;
+
+ // now update the weights
+ w1 += w1m;
+ w3 += w3m;
+ }
+ else
+ {
+ tmp2 = sigmoid(w2*tmp1);
+ // insert the bias
+ tmp2(tmp2.nr()-1) = -1;
+
+ o = sigmoid(w3*tmp2);
+
+
+ // now compute the errors and propagate them backwards though the network
+ e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+ e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
+ e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );
+
+ // compute the new weight updates
+ w3m = alpha * e3*trans(tmp2) + w3m*momentum;
+ w2m = alpha * e2*trans(tmp1) + w2m*momentum;
+ w1m = alpha * e1*trans(z) + w1m*momentum;
+
+ // now update the weights
+ w1 += w1m;
+ w2 += w2m;
+ w3 += w3m;
+ }
+ }
+
+ template <typename EXP>
+ void train (
+ const matrix_exp<EXP>& example_in,
+ double example_out
+ )
+ {
+ matrix<double,1,1> e_out;
+ e_out(0) = example_out;
+ train(example_in,e_out);
+ }
+
+ double get_average_change (
+ ) const
+ {
+ // sum up all the weight changes
+ double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));
+
+ // divide by the number of weights
+ delta /= w1m.nr()*w1m.nc() +
+ w2m.nr()*w2m.nc() +
+ w3m.nr()*w3m.nc();
+
+ return delta;
+ }
+
+ void swap (
+ mlp_kernel_1& item
+ )
+ {
+ exchange(input_nodes, item.input_nodes);
+ exchange(first_hidden_nodes, item.first_hidden_nodes);
+ exchange(second_hidden_nodes, item.second_hidden_nodes);
+ exchange(output_nodes, item.output_nodes);
+ exchange(alpha, item.alpha);
+ exchange(momentum, item.momentum);
+
+ w1.swap(item.w1);
+ w2.swap(item.w2);
+ w3.swap(item.w3);
+
+ w1m.swap(item.w1m);
+ w2m.swap(item.w2m);
+ w3m.swap(item.w3m);
+
+ // even swap the temporary matrices because this may ultimately result in
+ // fewer calls to new and delete.
+ e1.swap(item.e1);
+ e2.swap(item.e2);
+ e3.swap(item.e3);
+ z.swap(item.z);
+ tmp1.swap(item.tmp1);
+ tmp2.swap(item.tmp2);
+ o.swap(item.o);
+ }
+
+
+ friend void serialize (
+ const mlp_kernel_1& item,
+ std::ostream& out
+ );
+
+ friend void deserialize (
+ mlp_kernel_1& item,
+ std::istream& in
+ );
+
+ private:
+
+ long input_nodes;
+ long first_hidden_nodes;
+ long second_hidden_nodes;
+ long output_nodes;
+ double alpha;
+ double momentum;
+
+ matrix<double> w1;
+ matrix<double> w2;
+ matrix<double> w3;
+
+ matrix<double> w1m;
+ matrix<double> w2m;
+ matrix<double> w3m;
+
+
+ rand rand_nums;
+
+ // temporary storage
+ mutable matrix<double> e1, e2, e3;
+ mutable matrix<double> z, tmp1, tmp2, o;
+ };
+
+ inline void swap (
+ mlp_kernel_1& a,
+ mlp_kernel_1& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const mlp_kernel_1& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.input_nodes, out);
+ serialize(item.first_hidden_nodes, out);
+ serialize(item.second_hidden_nodes, out);
+ serialize(item.output_nodes, out);
+ serialize(item.alpha, out);
+ serialize(item.momentum, out);
+
+ serialize(item.w1, out);
+ serialize(item.w2, out);
+ serialize(item.w3, out);
+
+ serialize(item.w1m, out);
+ serialize(item.w2m, out);
+ serialize(item.w3m, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type mlp_kernel_1");
+ }
+ }
+
+ inline void deserialize (
+ mlp_kernel_1& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.input_nodes, in);
+ deserialize(item.first_hidden_nodes, in);
+ deserialize(item.second_hidden_nodes, in);
+ deserialize(item.output_nodes, in);
+ deserialize(item.alpha, in);
+ deserialize(item.momentum, in);
+
+ deserialize(item.w1, in);
+ deserialize(item.w2, in);
+ deserialize(item.w3, in);
+
+ deserialize(item.w1m, in);
+ deserialize(item.w2m, in);
+ deserialize(item.w3m, in);
+
+ item.z.set_size(item.input_nodes+1,1);
+ }
+ catch (serialization_error& e)
+ {
+ // give item a reasonable value since the deserialization failed
+ mlp_kernel_1(1,1).swap(item);
+ throw serialization_error(e.info + "\n while deserializing object of type mlp_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MLp_KERNEL_1_
+
diff --git a/ml/dlib/dlib/mlp/mlp_kernel_abstract.h b/ml/dlib/dlib/mlp/mlp_kernel_abstract.h
new file mode 100644
index 000000000..cbb473a87
--- /dev/null
+++ b/ml/dlib/dlib/mlp/mlp_kernel_abstract.h
@@ -0,0 +1,225 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MLp_ABSTRACT_
+#ifdef DLIB_MLp_ABSTRACT_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class mlp : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ The network is initially initialized with random weights
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a multilayer layer perceptron network that is
+ trained using the back propagation algorithm. The training algorithm also
+ incorporates the momentum method. That is, each round of back propagation
+ training also adds a fraction of the previous update. This fraction
+ is controlled by the momentum term set in the constructor.
+
+ The activation function used at each node is the sigmoid function. I.e.
+ sigmoid(x) = 1/(1 + pow(e,-x)). Thus the output of the network is
+ always in the range [0,1]
+ !*/
+
+ public:
+
+ mlp (
+ long nodes_in_input_layer,
+ long nodes_in_first_hidden_layer,
+ long nodes_in_second_hidden_layer = 0,
+ long nodes_in_output_layer = 1,
+ double alpha = 0.1,
+ double momentum = 0.8
+ );
+ /*!
+ requires
+ - nodes_in_input_layer > 0
+ - nodes_in_first_hidden_layer > 0
+ - nodes_in_second_hidden_layer >= 0
+ - nodes_in_output_layer > 0
+ ensures
+ - #*this is properly initialized
+ - #input_layer_nodes() == nodes_in_input_layer
+ - #first_hidden_layer_nodes() == nodes_in_first_hidden_layer
+ - #second_hidden_layer_nodes() == nodes_in_second_hidden_layer
+ - #output_layer_nodes() == nodes_in_output_layer
+ - #get_alpha() == alpha
+ - #get_momentum() == momentum
+ throws
+ - std::bad_alloc
+ if this is thrown the mlp will be unusable but
+ will not leak memory
+ !*/
+
+ virtual ~mlp (
+ );
+ /*!
+ ensures
+ - all resources associated with #*this have been released
+ !*/
+
+ void reset (
+ ) const;
+ /*!
+ ensures
+ - reinitialize the network with random weights
+ !*/
+
+ long input_layer_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the input layer
+ !*/
+
+ long first_hidden_layer_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the first hidden layer. This is
+ the hidden layer that is directly connected to the input layer.
+ !*/
+
+ long second_hidden_layer_nodes (
+ ) const;
+ /*!
+ ensures
+ - if (this network has a second hidden layer) then
+ - returns the number of nodes in the second hidden layer. This is
+ the hidden layer that is directly connected to the output layer.
+ - else
+ - returns 0
+ !*/
+
+ long output_layer_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the output layer
+ !*/
+
+ double get_alpha (
+ ) const;
+ /*!
+ ensures
+ - returns the back propagation learning rate used by this object.
+ !*/
+
+ double get_momentum (
+ ) const;
+ /*!
+ ensures
+ - returns the momentum term used by this object during back propagation
+ training. The momentum is is the fraction of a previous update to
+ carry forward to the next call to train()
+ !*/
+
+ template <typename EXP>
+ const matrix<double> operator() (
+ const matrix_exp<EXP>& in
+ ) const;
+ /*!
+ requires
+ - in.nr() == input_layer_nodes()
+ - in.nc() == 1
+ - EXP::type == double
+ ensures
+ - returns the output of the network when it is given the
+ input in. The output's elements are always in the range
+ of 0.0 to 1.0
+ !*/
+
+ template <typename EXP1, typename EXP2>
+ void train (
+ const matrix_exp<EXP1>& example_in,
+ const matrix_exp<EXP2>& example_out
+ );
+ /*!
+ requires
+ - example_in.nr() == input_layer_nodes()
+ - example_in.nc() == 1
+ - example_out.nr() == output_layer_nodes()
+ - example_out.nc() == 1
+ - max(example_out) <= 1.0 && min(example_out) >= 0.0
+ - EXP1::type == double
+ - EXP2::type == double
+ ensures
+ - trains the network that the correct output when given example_in
+ should be example_out.
+ !*/
+
+ template <typename EXP>
+ void train (
+ const matrix_exp<EXP>& example_in,
+ double example_out
+ );
+ /*!
+ requires
+ - example_in.nr() == input_layer_nodes()
+ - example_in.nc() == 1
+ - output_layer_nodes() == 1
+ - example_out <= 1.0 && example_out >= 0.0
+ - EXP::type == double
+ ensures
+ - trains the network that the correct output when given example_in
+ should be example_out.
+ !*/
+
+ double get_average_change (
+ ) const;
+ /*!
+ ensures
+ - returns the average change in the node weights in the
+ neural network during the last call to train()
+ !*/
+
+ void swap (
+ mlp& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ mlp& a,
+ mlp& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ void serialize (
+ const mlp& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ mlp& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MLp_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/mlp/mlp_kernel_c.h b/ml/dlib/dlib/mlp/mlp_kernel_c.h
new file mode 100644
index 000000000..bd3438fd9
--- /dev/null
+++ b/ml/dlib/dlib/mlp/mlp_kernel_c.h
@@ -0,0 +1,151 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MLP_KERNEl_C_
+#define DLIB_MLP_KERNEl_C_
+
+#include "mlp_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+
+ template <
+ typename mlp_base // is an implementation of mlp_kernel_abstract.h
+ >
+ class mlp_kernel_c : public mlp_base
+ {
+ long verify_constructor_args (
+ long nodes_in_input_layer,
+ long nodes_in_first_hidden_layer,
+ long nodes_in_second_hidden_layer,
+ long nodes_in_output_layer
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(nodes_in_input_layer > 0 &&
+ nodes_in_first_hidden_layer > 0 &&
+ nodes_in_second_hidden_layer >= 0 &&
+ nodes_in_output_layer > 0,
+ "\tconst mlp::constructor()"
+ << "\n\tinvalid constructor arguments"
+ << "\n\tnodes_in_input_layer: " << nodes_in_input_layer
+ << "\n\tnodes_in_first_hidden_layer: " << nodes_in_first_hidden_layer
+ << "\n\tnodes_in_second_hidden_layer: " << nodes_in_second_hidden_layer
+ << "\n\tnodes_in_output_layer: " << nodes_in_output_layer
+ );
+
+ return nodes_in_input_layer;
+ }
+
+ public:
+
+ mlp_kernel_c (
+ long nodes_in_input_layer,
+ long nodes_in_first_hidden_layer,
+ long nodes_in_second_hidden_layer = 0,
+ long nodes_in_output_layer = 1,
+ double alpha = 0.1,
+ double momentum = 0.8
+ ) : mlp_base( verify_constructor_args(
+ nodes_in_input_layer,
+ nodes_in_input_layer,
+ nodes_in_second_hidden_layer,
+ nodes_in_output_layer),
+ nodes_in_first_hidden_layer,
+ nodes_in_second_hidden_layer,
+ nodes_in_output_layer,
+ alpha,
+ momentum)
+ {
+ }
+
+ template <typename EXP>
+ const matrix<double> operator() (
+ const matrix_exp<EXP>& in
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(in.nr() == this->input_layer_nodes() &&
+ in.nc() == 1,
+ "\tconst matrix<double> mlp::operator()(matrix_exp)"
+ << "\n\tthe input matrix dimensions are not correct"
+ << "\n\tin.nr(): " << in.nr()
+ << "\n\tin.nc(): " << in.nc()
+ << "\n\tinput_layer_nodes(): " << this->input_layer_nodes()
+ << "\n\tthis: " << this
+ );
+
+ return mlp_base::operator()(in);
+ }
+
+ template <typename EXP1, typename EXP2>
+ void train (
+ const matrix_exp<EXP1>& example_in,
+ const matrix_exp<EXP2>& example_out
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(example_in.nr() == this->input_layer_nodes() &&
+ example_in.nc() == 1 &&
+ example_out.nr() == this->output_layer_nodes() &&
+ example_out.nc() == 1 &&
+ max(example_out) <= 1.0 && min(example_out) >= 0.0,
+ "\tvoid mlp::train(matrix_exp, matrix_exp)"
+ << "\n\tthe training example dimensions are not correct"
+ << "\n\texample_in.nr(): " << example_in.nr()
+ << "\n\texample_in.nc(): " << example_in.nc()
+ << "\n\texample_out.nr(): " << example_out.nr()
+ << "\n\texample_out.nc(): " << example_out.nc()
+ << "\n\tmax(example_out): " << max(example_out)
+ << "\n\tmin(example_out): " << min(example_out)
+ << "\n\tinput_layer_nodes(): " << this->input_layer_nodes()
+ << "\n\toutput_layer_nodes(): " << this->output_layer_nodes()
+ << "\n\tthis: " << this
+ );
+
+ mlp_base::train(example_in,example_out);
+ }
+
+ template <typename EXP>
+ void train (
+ const matrix_exp<EXP>& example_in,
+ double example_out
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(example_in.nr() == this->input_layer_nodes() &&
+ example_in.nc() == 1 &&
+ this->output_layer_nodes() == 1 &&
+ example_out <= 1.0 && example_out >= 0.0,
+ "\tvoid mlp::train(matrix_exp, double)"
+ << "\n\tthe training example dimensions are not correct"
+ << "\n\texample_in.nr(): " << example_in.nr()
+ << "\n\texample_in.nc(): " << example_in.nc()
+ << "\n\texample_out: " << example_out
+ << "\n\tinput_layer_nodes(): " << this->input_layer_nodes()
+ << "\n\toutput_layer_nodes(): " << this->output_layer_nodes()
+ << "\n\tthis: " << this
+ );
+
+ mlp_base::train(example_in,example_out);
+ }
+
+ };
+
+ template <
+ typename mlp_base
+ >
+ inline void swap (
+ mlp_kernel_c<mlp_base>& a,
+ mlp_kernel_c<mlp_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MLP_KERNEl_C_
+
+
diff --git a/ml/dlib/dlib/noncopyable.h b/ml/dlib/dlib/noncopyable.h
new file mode 100644
index 000000000..20b9866e7
--- /dev/null
+++ b/ml/dlib/dlib/noncopyable.h
@@ -0,0 +1,32 @@
+// (C) Copyright Beman Dawes 1999-2003. Distributed under the Boost
+// Software License, Version 1.0. (See accompanying file
+// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+// Contributed by Dave Abrahams
+// See http://www.boost.org/libs/utility for documentation.
+
+#ifndef DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED
+#define DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED
+
+
+namespace dlib
+{
+ class noncopyable
+ {
+ /*!
+ This class makes it easier to declare a class as non-copyable.
+ If you want to make an object that can't be copied just inherit
+ from this object.
+ !*/
+
+ protected:
+ noncopyable() = default;
+ ~noncopyable() = default;
+ private: // emphasize the following members are private
+ noncopyable(const noncopyable&);
+ const noncopyable& operator=(const noncopyable&);
+
+ };
+}
+
+#endif // DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED
+
diff --git a/ml/dlib/dlib/numeric_constants.h b/ml/dlib/dlib/numeric_constants.h
new file mode 100644
index 000000000..05f26319f
--- /dev/null
+++ b/ml/dlib/dlib/numeric_constants.h
@@ -0,0 +1,53 @@
+//Copyright (C) 2013 Steve Taylor (steve98654@gmail.com), Davis E. King
+//License: Boost Software License. See LICENSE.txt for full license.
+#ifndef DLIB_NUMERIC_CONSTANTs_H_
+#define DLIB_NUMERIC_CONSTANTs_H_
+
+namespace dlib
+{
+
+ // pi -- Pi
+ const double pi = 3.1415926535897932385;
+
+ // e -- Euler's Constant
+ const double e = 2.7182818284590452354;
+
+ // sqrt_2 -- The square root of 2
+ const double sqrt_2 = 1.4142135623730950488;
+
+ // sqrt_3 -- The square root of 3
+ const double sqrt_3 = 1.7320508075688772935;
+
+ // log10_2 -- The logarithm base 10 of two
+ const double log10_2 = 0.30102999566398119521;
+
+ // light_spd -- The speed of light in vacuum in meters per second
+ const double light_spd = 2.99792458e8;
+
+ // newton_G -- Newton's gravitational constant (in metric units of m^3/(kg*s^2))
+ const double newton_G = 6.67384e-11;
+
+ // planck_cst -- Planck's constant (in units of Joules * seconds)
+ const double planck_cst = 6.62606957e-34;
+
+ // golden_ratio -- The Golden Ratio
+ const double golden_ratio = 1.6180339887498948482;
+
+ // euler_gamma -- The Euler Mascheroni Constant
+ const double euler_gamma = 0.5772156649015328606065;
+
+ // catalan -- Catalan's Constant
+ const double catalan = 0.91596559417721901505;
+
+ // glaisher -- Glaisher Kinkelin constant
+ const double glaisher = 1.2824271291006226369;
+
+ // khinchin -- Khinchin's constant
+ const double khinchin = 2.6854520010653064453;
+
+ // apery -- Apery's constant
+ const double apery = 1.2020569031595942854;
+}
+
+#endif //DLIB_NUMERIC_CONSTANTs_H_
+
diff --git a/ml/dlib/dlib/numerical_integration.h b/ml/dlib/dlib/numerical_integration.h
new file mode 100644
index 000000000..6676c9892
--- /dev/null
+++ b/ml/dlib/dlib/numerical_integration.h
@@ -0,0 +1,8 @@
+// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER
+#define DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER
+
+#include "numerical_integration/integrate_function_adapt_simpson.h"
+
+#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER
diff --git a/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h
new file mode 100644
index 000000000..c30e21c59
--- /dev/null
+++ b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h
@@ -0,0 +1,93 @@
+// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com)
+// License: Boost Software License See LICENSE.txt for full license
+#ifndef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_
+#define DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_
+
+#include "integrate_function_adapt_simpson_abstract.h"
+#include "../assert.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ template <typename T, typename funct>
+ T impl_adapt_simp_stop(const funct& f, T a, T b, T fa, T fm, T fb, T is, int cnt)
+ {
+ const int maxint = 500;
+
+ T m = (a + b)/2.0;
+ T h = (b - a)/4.0;
+ T fml = f(a + h);
+ T fmr = f(b - h);
+ T i1 = h/1.5*(fa+4.0*fm+fb);
+ T i2 = h/3.0*(fa+4.0*(fml+fmr)+2.0*fm+fb);
+ i1 = (16.0*i2 - i1)/15.0;
+ T Q = 0;
+
+ if ((std::abs(i1-i2) <= std::abs(is)) || (m <= a) || (b <= m))
+ {
+ Q = i1;
+ }
+ else
+ {
+ if(cnt < maxint)
+ {
+ cnt = cnt + 1;
+
+ Q = impl_adapt_simp_stop(f,a,m,fa,fml,fm,is,cnt)
+ + impl_adapt_simp_stop(f,m,b,fm,fmr,fb,is,cnt);
+ }
+ }
+
+ return Q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename funct>
+ T integrate_function_adapt_simp(
+ const funct& f,
+ T a,
+ T b,
+ T tol = 1e-10
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(b > a && tol > 0,
+ "\t T integrate_function_adapt_simp()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t a: " << a
+ << "\n\t b: " << b
+ << "\n\t tol: " << tol
+ );
+
+ T eps = std::numeric_limits<T>::epsilon();
+ if(tol < eps)
+ {
+ tol = eps;
+ }
+
+ const T ba = b-a;
+ const T fa = f(a);
+ const T fb = f(b);
+ const T fm = f((a+b)/2);
+
+ T is = ba/8*(fa+fb+fm+ f(a + 0.9501*ba) + f(a + 0.2311*ba) + f(a + 0.6068*ba)
+ + f(a + 0.4860*ba) + f(a + 0.8913*ba));
+
+ if(is == 0)
+ {
+ is = b-a;
+ }
+
+ is = is*tol;
+
+ int cnt = 0;
+
+ return impl_adapt_simp_stop(f, a, b, fa, fm, fb, is, cnt);
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_
diff --git a/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h
new file mode 100644
index 000000000..90badcf14
--- /dev/null
+++ b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h
@@ -0,0 +1,34 @@
+// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_
+#ifdef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_
+
+namespace dlib
+{
+
+ template <typename T, typename funct>
+ T integrate_function_adapt_simp(
+ const funct& f,
+ T a,
+ T b,
+ T tol = 1e-10
+ );
+ /*!
+ requires
+ - b > a
+ - tol > 0
+ - T should be either float, double, or long double
+ - The expression f(a) should be a valid expression that evaluates to a T.
+ I.e. f() should be a real valued function of a single variable.
+ ensures
+ - returns an approximation of the integral of f over the domain [a,b] using the
+ adaptive Simpson method outlined in Gander, W. and W. Gautshi, "Adaptive
+ Quadrature -- Revisited" BIT, Vol. 40, (2000), pp.84-101
+ - tol is a tolerance parameter that determines the overall accuracy of
+ approximated integral. We suggest a default value of 1e-10 for tol.
+ !*/
+
+}
+
+#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_
+
diff --git a/ml/dlib/dlib/opencv.h b/ml/dlib/dlib/opencv.h
new file mode 100644
index 000000000..c48a6ec5a
--- /dev/null
+++ b/ml/dlib/dlib/opencv.h
@@ -0,0 +1,17 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_OPEnCV_HEADER
+#define DLIB_OPEnCV_HEADER
+
+#include "opencv/cv_image.h"
+#include "opencv/to_open_cv.h"
+
+#endif // DLIB_OPEnCV_HEADER
+
+
+
+
diff --git a/ml/dlib/dlib/opencv/cv_image.h b/ml/dlib/dlib/opencv/cv_image.h
new file mode 100644
index 000000000..5f224d003
--- /dev/null
+++ b/ml/dlib/dlib/opencv/cv_image.h
@@ -0,0 +1,225 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CvIMAGE_H_
+#define DLIB_CvIMAGE_H_
+
+#include <opencv2/core/core.hpp>
+#include <opencv2/core/types_c.h>
+#include "cv_image_abstract.h"
+#include "../algs.h"
+#include "../pixel.h"
+#include "../matrix/matrix_mat.h"
+#include "../image_processing/generic_image.h"
+
+namespace dlib
+{
+
+ template <
+ typename pixel_type
+ >
+ class cv_image
+ {
+ public:
+ typedef pixel_type type;
+ typedef default_memory_manager mem_manager_type;
+
+ cv_image (const cv::Mat img)
+ {
+ DLIB_CASSERT(img.depth() == cv::DataType<typename pixel_traits<pixel_type>::basic_pixel_type>::depth &&
+ img.channels() == pixel_traits<pixel_type>::num,
+ "The pixel type you gave doesn't match pixel used by the open cv Mat object."
+ << "\n\t img.depth(): " << img.depth()
+ << "\n\t img.cv::DataType<typename pixel_traits<pixel_type>::basic_pixel_type>::depth: "
+ << cv::DataType<typename pixel_traits<pixel_type>::basic_pixel_type>::depth
+ << "\n\t img.channels(): " << img.channels()
+ << "\n\t img.pixel_traits<pixel_type>::num: " << pixel_traits<pixel_type>::num
+ );
+ IplImage temp = img;
+ init(&temp);
+ }
+
+ cv_image (const IplImage img)
+ {
+ init(&img);
+ }
+
+ cv_image (const IplImage* img)
+ {
+ init(img);
+ }
+
+ cv_image() : _data(0), _widthStep(0), _nr(0), _nc(0) {}
+
+ size_t size () const { return static_cast<size_t>(_nr*_nc); }
+
+ inline pixel_type* operator[](const long row )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= row && row < nr(),
+ "\tpixel_type* cv_image::operator[](row)"
+ << "\n\t you have asked for an out of bounds row "
+ << "\n\t row: " << row
+ << "\n\t nr(): " << nr()
+ << "\n\t this: " << this
+ );
+
+ return reinterpret_cast<pixel_type*>( _data + _widthStep*row);
+ }
+
+ inline const pixel_type* operator[](const long row ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= row && row < nr(),
+ "\tconst pixel_type* cv_image::operator[](row)"
+ << "\n\t you have asked for an out of bounds row "
+ << "\n\t row: " << row
+ << "\n\t nr(): " << nr()
+ << "\n\t this: " << this
+ );
+
+ return reinterpret_cast<const pixel_type*>( _data + _widthStep*row);
+ }
+
+ inline const pixel_type& operator()(const long row, const long column) const
+ {
+ DLIB_ASSERT(0<= column && column < nc(),
+ "\tcont pixel_type& cv_image::operator()(const long rown const long column)"
+ << "\n\t you have asked for an out of bounds column "
+ << "\n\t column: " << column
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ return (*this)[row][column];
+ }
+
+ inline pixel_type& operator()(const long row, const long column)
+ {
+ DLIB_ASSERT(0<= column && column < nc(),
+ "\tcont pixel_type& cv_image::operator()(const long rown const long column)"
+ << "\n\t you have asked for an out of bounds column "
+ << "\n\t column: " << column
+ << "\n\t nc(): " << nc()
+ << "\n\t this: " << this
+ );
+
+ return (*this)[row][column];
+ }
+
+ long nr() const { return _nr; }
+ long nc() const { return _nc; }
+ long width_step() const { return _widthStep; }
+
+ cv_image& operator=( const cv_image& item)
+ {
+ _data = item._data;
+ _widthStep = item._widthStep;
+ _nr = item._nr;
+ _nc = item._nc;
+ return *this;
+ }
+
+ cv_image& operator=( const IplImage* img)
+ {
+ init(img);
+ return *this;
+ }
+
+ cv_image& operator=( const IplImage img)
+ {
+ init(&img);
+ return *this;
+ }
+
+ cv_image& operator=( const cv::Mat img)
+ {
+ IplImage temp = img;
+ init(&temp);
+ return *this;
+ }
+
+ private:
+
+ void init (const IplImage* img)
+ {
+ DLIB_CASSERT( img->dataOrder == 0, "Only interleaved color channels are supported with cv_image");
+ DLIB_CASSERT((img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type),
+ "The pixel type you gave doesn't match the size of pixel used by the open cv image struct");
+
+ _data = img->imageData;
+ _widthStep = img->widthStep;
+ _nr = img->height;
+ _nc = img->width;
+
+ }
+
+ char* _data;
+ long _widthStep;
+ long _nr;
+ long _nc;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_array2d_to_mat<cv_image<T> > > mat (
+ const cv_image<T>& m
+ )
+ {
+ typedef op_array2d_to_mat<cv_image<T> > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// Define the global functions that make cv_image a proper "generic image" according to
+// ../image_processing/generic_image.h
+ template <typename T>
+ struct image_traits<cv_image<T> >
+ {
+ typedef T pixel_type;
+ };
+
+ template <typename T>
+ inline long num_rows( const cv_image<T>& img) { return img.nr(); }
+ template <typename T>
+ inline long num_columns( const cv_image<T>& img) { return img.nc(); }
+
+ template <typename T>
+ inline void* image_data(
+ cv_image<T>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img[0][0];
+ else
+ return 0;
+ }
+
+ template <typename T>
+ inline const void* image_data(
+ const cv_image<T>& img
+ )
+ {
+ if (img.size() != 0)
+ return &img[0][0];
+ else
+ return 0;
+ }
+
+ template <typename T>
+ inline long width_step(
+ const cv_image<T>& img
+ )
+ {
+ return img.width_step();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CvIMAGE_H_
+
diff --git a/ml/dlib/dlib/opencv/cv_image_abstract.h b/ml/dlib/dlib/opencv/cv_image_abstract.h
new file mode 100644
index 000000000..6fbc56b43
--- /dev/null
+++ b/ml/dlib/dlib/opencv/cv_image_abstract.h
@@ -0,0 +1,280 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPENCV_IMAGE_AbSTRACT_H_
+#ifdef DLIB_OPENCV_IMAGE_AbSTRACT_H_
+
+#include <opencv2/core/core.hpp>
+#include <opencv2/core/types_c.h>
+#include "../algs.h"
+#include "../pixel.h"
+
+namespace dlib
+{
+
+ template <
+ typename pixel_type
+ >
+ class cv_image
+ {
+ /*!
+ REQUIREMENTS ON pixel_type
+ pixel_type just needs to be something that matches the pixel memory
+ layout of whatever OpenCV image you are going to use with this object.
+ For example, you might use unsigned char or bgr_pixel depending
+ on what you needed.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is meant to be used as a simple wrapper around the OpenCV
+ IplImage struct or Mat object. Using this class template you can turn
+ an OpenCV image into something that looks like a normal dlib style
+ image object.
+
+ So you should be able to use cv_image objects with many of the image
+ processing functions in dlib as well as the GUI tools for displaying
+ images on the screen.
+
+ Note that this object does NOT take ownership of the image data you
+ give to it. This means it is up to you to make sure the OpenCV image
+ is properly freed at some point. This also means that an instance of
+ this object can only be used as long as the OpenCV image it references
+ remains valid, since a cv_image just points to the OpenCV image's
+ memory directly.
+ !*/
+
+ public:
+ typedef pixel_type type;
+ typedef default_memory_manager mem_manager_type;
+
+ cv_image (
+ const IplImage* img
+ );
+ /*!
+ requires
+ - img->dataOrder == 0
+ (i.e. Only interleaved color channels are supported with cv_image)
+ - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type)
+ (i.e. The size of the pixel_type needs to match the size of the pixels
+ inside the OpenCV image)
+ ensures
+ - #nr() == img->height
+ #nc() == img->width
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ !*/
+
+ cv_image (
+ const IplImage img
+ );
+ /*!
+ requires
+ - img.dataOrder == 0
+ (i.e. Only interleaved color channels are supported with cv_image)
+ - (img.depth&0xFF)/8*img.nChannels == sizeof(pixel_type)
+ (i.e. The size of the pixel_type needs to match the size of the pixels
+ inside the OpenCV image)
+ ensures
+ - #nr() == img.height
+ #nc() == img.width
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ !*/
+
+ cv_image (
+ const cv::Mat img
+ );
+ /*!
+ requires
+ - img.depth() == cv::DataType<pixel_traits<pixel_type>::basic_pixel_type>::depth
+ (i.e. The pixel_type template argument needs to match the type of pixel
+ used inside the OpenCV image)
+ - img.channels() == pixel_traits<pixel_type>::num
+ (i.e. the number of channels in the pixel_type needs to match the number of
+ channels in the OpenCV image)
+ ensures
+ - #nr() == img.rows
+ - #nc() == img.cols
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ !*/
+
+ cv_image(
+ );
+ /*!
+ ensures
+ - #nr() == 0
+ - #nc() == 0
+ !*/
+
+ ~cv_image (
+ );
+ /*!
+ ensures
+ - This function does nothing. e.g. It doesn't delete the OpenCV
+ image used by this cv_image object
+ !*/
+
+ long nr(
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in this image
+ !*/
+
+ long nc(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns in this image
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns nr()*nc()
+ (i.e. returns the number of pixels in this image)
+ !*/
+
+ inline pixel_type* operator[] (
+ const long row
+ );
+ /*!
+ requires
+ - 0 <= row < nr()
+ ensures
+ - returns a pointer to the first pixel in the given row
+ of this image
+ !*/
+
+ inline const pixel_type* operator[] (
+ const long row
+ ) const;
+ /*!
+ requires
+ - 0 <= row < nr()
+ ensures
+ - returns a pointer to the first pixel in the given row
+ of this image
+ !*/
+
+ inline const pixel_type& operator()(
+ const long row, const long column
+ ) const
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= column < nc()
+ ensures
+ - returns a const reference to the pixel at coordinates (row, column)
+ of this image
+ !*/
+
+ inline pixel_type& operator()(
+ const long row, const long column
+ )
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= column < nc()
+ ensures
+ - returns a reference to the pixel at coordinates (row, column)
+ of this image
+ !*/
+
+ cv_image& operator= (
+ const cv_image& item
+ );
+ /*!
+ ensures
+ - #*this is an identical copy of item
+ - returns #*this
+ !*/
+
+ cv_image& operator=(
+ const IplImage* img
+ );
+ /*!
+ requires
+ - img->dataOrder == 0
+ (i.e. Only interleaved color channels are supported with cv_image)
+ - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type)
+ (i.e. The size of the pixel_type needs to match the size of the pixels
+ inside the OpenCV image)
+ ensures
+ - #nr() == img->height
+ #nc() == img->width
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ - returns #*this
+ !*/
+
+ cv_image& operator=(
+ const IplImage img
+ );
+ /*!
+ requires
+ - img->dataOrder == 0
+ (i.e. Only interleaved color channels are supported with cv_image)
+ - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type)
+ (i.e. The size of the pixel_type needs to match the size of the pixels
+ inside the OpenCV image)
+ ensures
+ - #nr() == img->height
+ #nc() == img->width
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ - returns #*this
+ !*/
+
+ cv_image& operator=(
+ const cv::Mat img
+ );
+ /*!
+ requires
+ - img.depth() == cv::DataType<pixel_traits<pixel_type>::basic_pixel_type>::depth
+ (i.e. The pixel_type template argument needs to match the type of pixel
+ used inside the OpenCV image)
+ - img.channels() == pixel_traits<pixel_type>::num
+ (i.e. the number of channels in the pixel_type needs to match the number of
+ channels in the OpenCV image)
+ ensures
+ - #nr() == img.rows
+ - #nc() == img.cols
+ - using the operator[] on this object you will be able to access the pixels
+ inside this OpenCV image.
+ - returns #*this
+ !*/
+
+ long width_step (
+ ) const;
+ /*!
+ ensures
+ - returns the size of one row of the image, in bytes.
+ More precisely, return a number N such that:
+ (char*)&item[0][0] + N == (char*)&item[1][0].
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const cv_image<T>& img
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - R.nr() == img.nr()
+ - R.nc() == img.nc()
+ - for all valid r and c:
+ R(r, c) == img[r][c]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPENCV_IMAGE_AbSTRACT_H_
+
diff --git a/ml/dlib/dlib/opencv/to_open_cv.h b/ml/dlib/dlib/opencv/to_open_cv.h
new file mode 100644
index 000000000..02b7bf6fc
--- /dev/null
+++ b/ml/dlib/dlib/opencv/to_open_cv.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TO_OPEN_Cv_Hh_
+#define DLIB_TO_OPEN_Cv_Hh_
+
+#include <opencv2/core/core.hpp>
+#include "to_open_cv_abstract.h"
+#include "../pixel.h"
+#include "../matrix/matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type
+ >
+ cv::Mat toMat (
+ image_type& img
+ )
+ {
+ if (image_size(img) == 0)
+ return cv::Mat();
+
+ typedef typename image_traits<image_type>::pixel_type type;
+ typedef typename pixel_traits<type>::basic_pixel_type basic_pixel_type;
+ if (pixel_traits<type>::num == 1)
+ {
+ return cv::Mat(num_rows(img), num_columns(img), cv::DataType<basic_pixel_type>::type, image_data(img), width_step(img));
+ }
+ else
+ {
+ int depth = sizeof(typename pixel_traits<type>::basic_pixel_type)*8;
+ int channels = pixel_traits<type>::num;
+ int thetype = CV_MAKETYPE(depth, channels);
+ return cv::Mat(num_rows(img), num_columns(img), thetype, image_data(img), width_step(img));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TO_OPEN_Cv_Hh_
+
diff --git a/ml/dlib/dlib/opencv/to_open_cv_abstract.h b/ml/dlib/dlib/opencv/to_open_cv_abstract.h
new file mode 100644
index 000000000..43307e02f
--- /dev/null
+++ b/ml/dlib/dlib/opencv/to_open_cv_abstract.h
@@ -0,0 +1,34 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TO_OPEN_Cv_ABSTRACTh_
+#ifdef DLIB_TO_OPEN_Cv_ABSTRACTh_
+
+#include <opencv2/core/core.hpp>
+#include "../pixel.h"
+
+namespace dlib
+{
+ template <
+ typename image_type
+ >
+ cv::Mat toMat (
+ image_type& img
+ );
+ /*!
+ requires
+ - image_type == an image object that implements the interface defined in
+ dlib/image_processing/generic_image.h or a dlib::matrix object which uses a
+ row_major_layout.
+ - pixel_traits is defined for the contents of img.
+ ensures
+ - returns an OpenCV Mat object which represents the same image as img. This
+ is done by setting up the Mat object to point to the same memory as img.
+ Therefore, the returned Mat object is valid only as long as pointers
+ to the pixels in img remain valid.
+ !*/
+}
+
+#endif // DLIB_TO_OPEN_Cv_ABSTRACTh_
+
+
+
diff --git a/ml/dlib/dlib/optimization.h b/ml/dlib/dlib/optimization.h
new file mode 100644
index 000000000..260eacc1c
--- /dev/null
+++ b/ml/dlib/dlib/optimization.h
@@ -0,0 +1,24 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_HEADER
+#define DLIB_OPTIMIZATIOn_HEADER
+
+#include "optimization/optimization.h"
+#include "optimization/optimization_bobyqa.h"
+#include "optimization/optimization_solve_qp_using_smo.h"
+#include "optimization/optimization_solve_qp2_using_smo.h"
+#include "optimization/optimization_solve_qp3_using_smo.h"
+#include "optimization/optimization_oca.h"
+#include "optimization/optimization_trust_region.h"
+#include "optimization/optimization_least_squares.h"
+#include "optimization/max_cost_assignment.h"
+#include "optimization/max_sum_submatrix.h"
+#include "optimization/find_max_factor_graph_nmplp.h"
+#include "optimization/find_max_factor_graph_viterbi.h"
+#include "optimization/find_max_parse_cky.h"
+#include "optimization/isotonic_regression.h"
+
+#endif // DLIB_OPTIMIZATIOn_HEADER
+
+
+
diff --git a/ml/dlib/dlib/optimization/elastic_net.h b/ml/dlib/dlib/optimization/elastic_net.h
new file mode 100644
index 000000000..6c4b6d0b4
--- /dev/null
+++ b/ml/dlib/dlib/optimization/elastic_net.h
@@ -0,0 +1,389 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ElASTIC_NET_Hh_
+#define DLIB_ElASTIC_NET_Hh_
+
+#include "../matrix.h"
+#include "elastic_net_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class elastic_net
+ {
+ public:
+
+ template <typename EXP>
+ explicit elastic_net(
+ const matrix_exp<EXP>& XX
+ ) : eps(1e-5), max_iterations(50000), verbose(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(XX.size() > 0 &&
+ XX.nr() == XX.nc(),
+ "\t elastic_net::elastic_net(XX)"
+ << " \n\t XX must be a non-empty square matrix."
+ << " \n\t XX.nr(): " << XX.nr()
+ << " \n\t XX.nc(): " << XX.nc()
+ << " \n\t this: " << this
+ );
+
+
+ // If the number of columns in X is big and in particular bigger than the number of
+ // rows then we can get rid of them by doing some SVD magic. Doing this doesn't
+ // make the final results of anything change but makes all the matrices have
+ // dimensions that are X.nr() in size, which can be much smaller.
+ matrix<double,0,1> s;
+ svd3(XX,u,eig_vals,eig_vects);
+ s = sqrt(eig_vals);
+ X = eig_vects*diagm(s);
+ u = eig_vects*inv(diagm(s));
+
+
+
+ samples.resize(X.nr()*2);
+
+ for (size_t i = 0; i < samples.size(); ++i)
+ index.push_back(i);
+ active_size = index.size();
+
+
+ // setup the training samples used in the SVM optimizer below
+ for (size_t i = 0; i < samples.size(); ++i)
+ {
+ auto& x = samples[i];
+ const long idx = i/2;
+ if (i%2 == 0)
+ x.label = +1;
+ else
+ x.label = -1;
+
+ x.r = idx%X.nr();
+ }
+ }
+
+ template <typename EXP1, typename EXP2>
+ elastic_net(
+ const matrix_exp<EXP1>& XX,
+ const matrix_exp<EXP2>& XY
+ ) : elastic_net(XX)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(XX.size() > 0 &&
+ XX.nr() == XX.nc() &&
+ is_col_vector(XY) &&
+ XX.nc() == XY.size() ,
+ "\t elastic_net::elastic_net(XX,XY)"
+ << " \n\t Invalid inputs were given to this function."
+ << " \n\t XX.size(): " << XX.size()
+ << " \n\t is_col_vector(XY): " << is_col_vector(XY)
+ << " \n\t XX.nr(): " << XX.nr()
+ << " \n\t XX.nc(): " << XX.nc()
+ << " \n\t XY.size(): " << XY.size()
+ << " \n\t this: " << this
+ );
+
+ set_xy(XY);
+ }
+
+ long size (
+ ) const { return u.nr(); }
+
+ template <typename EXP>
+ void set_xy(
+ const matrix_exp<EXP>& XY
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(XY) &&
+ XY.size() == size(),
+ "\t void elastic_net::set_y(Y)"
+ << " \n\t Invalid inputs were given to this function."
+ << " \n\t is_col_vector(XY): " << is_col_vector(XY)
+ << " \n\t size(): " << size()
+ << " \n\t XY.size(): " << XY.size()
+ << " \n\t this: " << this
+ );
+
+ Y = trans(u)*XY;
+ // We can use the ynorm after it has been projected because the only place Y
+ // appears in the algorithm is in terms of dot products with w and x vectors.
+ // But those vectors are always in the span of X and therefore we only see the
+ // part of the norm of Y that is in the span of X (and hence u since u and X
+ // have the same span by construction)
+ ynorm = length_squared(Y);
+ xdoty = X*Y;
+ eig_vects_xdoty = trans(eig_vects)*xdoty;
+
+ w.set_size(Y.size());
+ // zero out any memory of previous solutions
+ alpha.assign(X.nr()*2, 0);
+ }
+
+ bool have_target_values (
+ ) const { return Y.size() != 0; }
+
+ void set_epsilon(
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void elastic_net::set_epsilon()"
+ << " \n\t eps_ must be greater than 0"
+ << " \n\t eps_: " << eps_
+ << " \n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ matrix<double,0,1> operator() (
+ double ridge_lambda,
+ double lasso_budget = std::numeric_limits<double>::infinity()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(have_target_values() &&
+ ridge_lambda > 0 &&
+ lasso_budget > 0 ,
+ "\t matrix<double,0,1> elastic_net::operator()()"
+ << " \n\t Invalid inputs were given to this function."
+ << " \n\t have_target_values(): " << have_target_values()
+ << " \n\t ridge_lambda: " << ridge_lambda
+ << " \n\t lasso_budget: " << lasso_budget
+ << " \n\t this: " << this
+ );
+
+
+ // First check if lasso_budget is so big that it isn't even active. We do this
+ // by doing just ridge regression and checking the result.
+ matrix<double,0,1> betas = eig_vects*tmp(inv(diagm(eig_vals + ridge_lambda))*eig_vects_xdoty);
+ if (sum(abs(betas)) <= lasso_budget)
+ return betas;
+
+
+ // Set w back to 0. We will compute the w corresponding to what is currently
+ // in alpha layer on. This way w and alpha are always in sync.
+ w = 0;
+ wy_mult = 0;
+ wdoty = 0;
+
+
+ // return dot(w,x)
+ auto dot = [&](const matrix<double,0,1>& w, const en_sample2& x)
+ {
+ const double xmul = -x.label*(1/lasso_budget);
+ // Do the base dot product but don't forget to add in the -(1/t)*y part from the svm reduction paper
+ double val = rowm(X,x.r)*w + xmul*wdoty + wy_mult*xdoty(x.r) + xmul*wy_mult*ynorm;
+
+ return val;
+ };
+
+
+ // perform w += scale*x;
+ auto add_to = [&](matrix<double,0,1>& w, double scale, const en_sample2& x)
+ {
+ const double xmul = -x.label*(1/lasso_budget);
+ wy_mult += scale*xmul;
+ wdoty += scale*xdoty(x.r);
+ w += scale*trans(rowm(X,x.r));
+
+ };
+
+ const double Dii = ridge_lambda;
+
+ // setup the training samples used in the SVM optimizer below
+ for (size_t i = 0; i < samples.size(); ++i)
+ {
+ auto& x = samples[i];
+
+ const double xmul = -x.label*(1/lasso_budget);
+ x.xdotx = xmul*xmul*ynorm;
+ for (long c = 0; c < X.nc(); ++c)
+ x.xdotx += std::pow(X(x.r,c)+xmul*Y(c), 2.0) - std::pow(xmul*Y(c),2.0);
+
+ // compute the correct w given whatever might be in alpha.
+ if (alpha[i] != 0)
+ add_to(w, x.label*alpha[i], samples[i]);
+ }
+
+
+ // Now run the optimizer
+ double PG_max_prev = std::numeric_limits<double>::infinity();
+ double PG_min_prev = -std::numeric_limits<double>::infinity();
+
+
+ unsigned int iter;
+ for (iter = 0; iter < max_iterations; ++iter)
+ {
+ // randomly shuffle the indices
+ for (unsigned long i = 0; i < active_size; ++i)
+ {
+ // pick a random index >= i
+ const long j = i + rnd.get_random_32bit_number()%(active_size-i);
+ std::swap(index[i], index[j]);
+ }
+
+ double PG_max = -std::numeric_limits<double>::infinity();
+ double PG_min = std::numeric_limits<double>::infinity();
+ for (size_t ii = 0; ii < active_size; ++ii)
+ {
+ const auto i = index[ii];
+ const auto& x = samples[i];
+ double G = x.label*dot(w, x) - 1 + Dii*alpha[i];
+
+ double PG = 0;
+ if (alpha[i] == 0)
+ {
+ if (G > PG_max_prev)
+ {
+ // shrink the active set of training examples
+ --active_size;
+ std::swap(index[ii], index[active_size]);
+ --ii;
+ continue;
+ }
+
+ if (G < 0)
+ PG = G;
+ }
+ else
+ {
+ PG = G;
+ }
+
+ if (PG > PG_max)
+ PG_max = PG;
+ if (PG < PG_min)
+ PG_min = PG;
+
+ // if PG != 0
+ if (std::abs(PG) > 1e-12)
+ {
+ const double alpha_old = alpha[i];
+ alpha[i] = std::max(alpha[i] - G/(x.xdotx+Dii), (double)0.0);
+ const double delta = (alpha[i]-alpha_old)*x.label;
+ add_to(w, delta, x);
+ }
+ }
+
+ if (verbose)
+ {
+ using namespace std;
+ cout << "gap: " << PG_max - PG_min << endl;
+ cout << "active_size: " << active_size << endl;
+ cout << "iter: " << iter << endl;
+ cout << endl;
+ }
+
+ if (PG_max - PG_min <= eps)
+ {
+ // stop if we are within eps tolerance and the last iteration
+ // was over all the samples
+ if (active_size == index.size())
+ break;
+
+ // Turn off shrinking on the next iteration. We will stop if the
+ // tolerance is still <= eps when shrinking is off.
+ active_size = index.size();
+ PG_max_prev = std::numeric_limits<double>::infinity();
+ PG_min_prev = -std::numeric_limits<double>::infinity();
+ }
+ else
+ {
+ PG_max_prev = PG_max;
+ PG_min_prev = PG_min;
+ if (PG_max_prev <= 0)
+ PG_max_prev = std::numeric_limits<double>::infinity();
+ if (PG_min_prev >= 0)
+ PG_min_prev = -std::numeric_limits<double>::infinity();
+ }
+
+
+ // recalculate wdoty every so often to avoid drift.
+ if (iter%100 == 0)
+ wdoty = dlib::dot(Y, w);
+ }
+
+
+ betas.set_size(alpha.size()/2);
+ for (long i = 0; i < betas.size(); ++i)
+ betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]);
+ betas /= sum(mat(alpha));
+ return betas;
+ }
+
+
+ private:
+
+ struct en_sample2
+ {
+ // X location
+ long r;
+
+
+ double label;
+
+ double xdotx;
+ };
+
+ std::vector<en_sample2> samples;
+ std::vector<double> alpha;
+ double ynorm;
+ matrix<double> X;
+ matrix<double,0,1> Y;
+ matrix<double,0,1> xdoty;
+ double wdoty;
+ double wy_mult; // logically, the real w is what is in the w vector + wy_mult*Y
+ matrix<double,0,1> w;
+ std::vector<long> index;
+ unsigned long active_size;
+
+ matrix<double,0,1> eig_vects_xdoty;
+ matrix<double,0,1> eig_vals;
+ matrix<double> eig_vects;
+ matrix<double> u;
+
+ dlib::rand rnd;
+
+
+ double eps;
+ unsigned long max_iterations;
+ bool verbose;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ElASTIC_NET_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/elastic_net_abstract.h b/ml/dlib/dlib/optimization/elastic_net_abstract.h
new file mode 100644
index 000000000..8ae69e37e
--- /dev/null
+++ b/ml/dlib/dlib/optimization/elastic_net_abstract.h
@@ -0,0 +1,190 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ElASTIC_NET_ABSTRACT_Hh_
+#ifdef DLIB_ElASTIC_NET_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class elastic_net
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the following optimization problem:
+
+ min_w: length_squared(X*w - Y) + ridge_lambda*length_squared(w)
+ such that: sum(abs(w)) <= lasso_budget
+
+ That is, it solves the elastic net optimization problem. This object also
+ has the special property that you can quickly obtain different solutions
+ for different settings of ridge_lambda, lasso_budget, and target Y values.
+
+ This is because a large amount of work is precomputed in the constructor.
+ The solver will also remember the previous solution and will use that to
+ warm start subsequent invocations. Therefore, you can efficiently get
+ solutions for a wide range of regularization parameters.
+
+
+ The particular algorithm used to solve it is described in the paper:
+ Zhou, Quan, et al. "A reduction of the elastic net to support vector
+ machines with an application to gpu computing." arXiv preprint
+ arXiv:1409.1976 (2014). APA
+
+ And for the SVM solver sub-component we use the algorithm from:
+ Hsieh, Cho-Jui, et al. "A dual coordinate descent method for large-scale
+ linear SVM." Proceedings of the 25th international conference on Machine
+ learning. ACM, 2008.
+ !*/
+
+ public:
+
+ template <typename EXP>
+ explicit elastic_net(
+ const matrix_exp<EXP>& XX
+ );
+ /*!
+ requires
+ - XX.size() != 0
+ - XX.nr() == XX.nc()
+ ensures
+ - #get_epsilon() == 1e-5
+ - #get_max_iterations() == 50000
+ - This object will not be verbose unless be_verbose() is called.
+ - #size() == XX.nc()
+ - #have_target_values() == false
+ - We interpret XX as trans(X)*X where X is as defined in the objective
+ function discussed above in WHAT THIS OBJECT REPRESENTS.
+ !*/
+
+ template <typename EXP1, typename EXP2>
+ elastic_net(
+ const matrix_exp<EXP1>& XX,
+ const matrix_exp<EXP2>& XY
+ );
+ /*!
+ requires
+ - XX.size() != 0
+ - XX.nr() == XX.nc()
+ - is_col_vector(XY)
+ - XX.nc() == Y.size()
+ ensures
+ - constructs this object by calling the elastic_net(XX) constructor and
+ then calling this->set_xy(XY).
+ - #have_target_values() == true
+ - We interpret XX as trans(X)*X where X is as defined in the objective
+ function discussed above in WHAT THIS OBJECT REPRESENTS. Similarly, XY
+ should be trans(X)*Y.
+ !*/
+
+ long size (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the data loaded into this object. That is,
+ how many elements are in the optimal w vector? This function returns
+ that number.
+ !*/
+
+ bool have_target_values (
+ ) const;
+ /*!
+ ensures
+ - returns true if set_xy() has been called and false otherwise.
+ !*/
+
+ template <typename EXP>
+ void set_xy(
+ const matrix_exp<EXP>& XY
+ );
+ /*!
+ requires
+ - is_col_vector(Y)
+ - Y.size() == size()
+ ensures
+ - #have_target_values() == true
+ - Sets the target values of the regression. Note that we expect the given
+ matrix, XY, to be equal to trans(X)*Y, where X and Y have the definitions
+ discussed above in WHAT THIS OBJECT REPRESENTS.
+ !*/
+
+ void set_epsilon(
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when the solver should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ execute.
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations the optimizer is allowed to run
+ before it is required to stop and return a result.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out.
+ !*/
+
+
+ matrix<double,0,1> operator() (
+ double ridge_lambda,
+ double lasso_budget = std::numeric_limits<double>::infinity()
+ );
+ /*!
+ requires
+ - have_target_values() == true
+ - ridge_lambda > 0
+ - lasso_budget > 0
+ ensures
+ - Solves the optimization problem described in the WHAT THIS OBJECT
+ REPRESENTS section above and returns the optimal w.
+ - The returned vector has size() elements.
+ - if (lasso_budget == infinity) then
+ - The lasso constraint is ignored
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ElASTIC_NET_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h
new file mode 100644
index 000000000..3dd7fd56f
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h
@@ -0,0 +1,337 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_
+#define DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_
+
+#include "find_max_factor_graph_nmplp_abstract.h"
+#include <vector>
+#include <map>
+#include "../matrix.h"
+#include "../hash.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ class simple_hash_map
+ {
+ public:
+
+ simple_hash_map(
+ ) :
+ scan_dist(6)
+ {
+ data.resize(5000);
+ }
+
+ void insert (
+ const unsigned long a,
+ const unsigned long b,
+ const unsigned long value
+ )
+ /*!
+ requires
+ - a != std::numeric_limits<unsigned long>::max()
+ ensures
+ - #(*this)(a,b) == value
+ !*/
+ {
+ const uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist);
+
+ const unsigned long empty_bucket = std::numeric_limits<unsigned long>::max();
+
+ for (uint32 i = 0; i < scan_dist; ++i)
+ {
+ if (data[i+h].key1 == empty_bucket)
+ {
+ data[i+h].key1 = a;
+ data[i+h].key2 = b;
+ data[i+h].value = value;
+ return;
+ }
+ }
+
+ // if we get this far it means the hash table is filling up. So double its size.
+ std::vector<bucket> new_data;
+ new_data.resize(data.size()*2);
+ new_data.swap(data);
+ for (uint32 i = 0; i < new_data.size(); ++i)
+ {
+ if (new_data[i].key1 != empty_bucket)
+ {
+ insert(new_data[i].key1, new_data[i].key2, new_data[i].value);
+ }
+ }
+
+ insert(a,b,value);
+ }
+
+ unsigned long operator() (
+ const unsigned long a,
+ const unsigned long b
+ ) const
+ /*!
+ requires
+ - this->insert(a,b,some_value) has been called
+ ensures
+ - returns the value stored at key (a,b)
+ !*/
+ {
+ DLIB_ASSERT(a != b, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
+ << "\nNode " << a << " is listed as being a neighbor with itself, which is illegal.");
+
+ uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist);
+
+
+ for (unsigned long i = 0; i < scan_dist; ++i)
+ {
+ if (data[h].key1 == a && data[h].key2 == b)
+ {
+ return data[h].value;
+ }
+ ++h;
+ }
+
+
+ // this should never happen (since this function requires (a,b) to be in the hash table
+ DLIB_ASSERT(false, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
+ << "\nThe nodes in the map_problem are inconsistent because node "<<a<<" is in the neighbor list"
+ << "\nof node "<<b<< " but node "<<b<<" isn't in the neighbor list of node "<<a<<". The neighbor relationship"
+ << "\nis supposed to be symmetric."
+ );
+ return 0;
+ }
+
+ private:
+
+ struct bucket
+ {
+ // having max() in key1 indicates that the bucket isn't used.
+ bucket() : key1(std::numeric_limits<unsigned long>::max()) {}
+ unsigned long key1;
+ unsigned long key2;
+ unsigned long value;
+ };
+
+ std::vector<bucket> data;
+ const unsigned int scan_dist;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_problem
+ >
+ void find_max_factor_graph_nmplp (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment,
+ unsigned long max_iter,
+ double eps
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( eps > 0,
+ "\t void find_max_factor_graph_nmplp()"
+ << "\n\t eps must be greater than zero"
+ << "\n\t eps: " << eps
+ );
+
+ /*
+ This function is an implementation of the NMPLP algorithm introduced in the
+ following papers:
+ Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
+ by Amir Globerson and Tommi Jaakkola
+
+ Introduction to dual decomposition for inference (2011)
+ by David Sontag, Amir Globerson, and Tommi Jaakkola
+
+ In particular, this function implements the star MPLP update equations shown as
+ equation 1.20 from the paper Introduction to dual decomposition for inference
+ (the method was called NMPLP in the first paper). It should also be noted that
+ the original description of the NMPLP in the first paper had an error in the
+ equations and the second paper contains corrected equations, which is what this
+ function uses.
+ */
+
+ typedef typename map_problem::node_iterator node_iterator;
+ typedef typename map_problem::neighbor_iterator neighbor_iterator;
+
+ map_assignment.resize(prob.number_of_nodes());
+
+
+ if (prob.number_of_nodes() == 0)
+ return;
+
+
+ std::vector<double> delta_elements;
+ delta_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
+
+ impl::simple_hash_map delta_idx;
+
+
+
+ // Initialize delta to zero and fill up the hash table with the appropriate values
+ // so we can index into delta later on.
+ for (node_iterator i = prob.begin(); i != prob.end(); ++i)
+ {
+ const unsigned long id_i = prob.node_id(i);
+
+ for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
+ {
+ const unsigned long id_j = prob.node_id(j);
+ delta_idx.insert(id_i, id_j, delta_elements.size());
+
+ const unsigned long num_states_xj = prob.num_states(j);
+ for (unsigned long xj = 0; xj < num_states_xj; ++xj)
+ delta_elements.push_back(0);
+ }
+ }
+
+
+ std::vector<double> gamma_i;
+ std::vector<std::vector<double> > gamma_ji;
+ std::vector<std::vector<double> > delta_to_j_no_i;
+ // These arrays will end up with a length equal to the maximum number of neighbors
+ // of any node in the graph. So reserve a bigish number of slots so that we are
+ // very unlikely to need to preform an expensive reallocation during the
+ // optimization.
+ gamma_ji.reserve(10000);
+ delta_to_j_no_i.reserve(10000);
+
+
+ double max_change = eps + 1;
+ // Now do the main body of the optimization.
+ unsigned long iter;
+ for (iter = 0; iter < max_iter && max_change > eps; ++iter)
+ {
+ max_change = -std::numeric_limits<double>::infinity();
+
+ for (node_iterator i = prob.begin(); i != prob.end(); ++i)
+ {
+ const unsigned long id_i = prob.node_id(i);
+ const unsigned long num_states_xi = prob.num_states(i);
+ gamma_i.assign(num_states_xi, 0);
+
+ double num_neighbors = 0;
+
+ unsigned int jcnt = 0;
+ // first we fill in the gamma vectors
+ for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
+ {
+ // Make sure these arrays are big enough to hold all the neighbor
+ // information.
+ if (jcnt >= gamma_ji.size())
+ {
+ gamma_ji.resize(gamma_ji.size()+1);
+ delta_to_j_no_i.resize(delta_to_j_no_i.size()+1);
+ }
+
+ ++num_neighbors;
+ const unsigned long id_j = prob.node_id(j);
+ const unsigned long num_states_xj = prob.num_states(j);
+
+ gamma_ji[jcnt].assign(num_states_xi, -std::numeric_limits<double>::infinity());
+ delta_to_j_no_i[jcnt].assign(num_states_xj, 0);
+
+ // compute delta_j^{-i} and store it in delta_to_j_no_i[jcnt]
+ for (neighbor_iterator k = prob.begin(j); k != prob.end(j); ++k)
+ {
+ const unsigned long id_k = prob.node_id(k);
+ if (id_k==id_i)
+ continue;
+ const double* const delta_kj = &delta_elements[delta_idx(id_k,id_j)];
+ for (unsigned long xj = 0; xj < num_states_xj; ++xj)
+ {
+ delta_to_j_no_i[jcnt][xj] += delta_kj[xj];
+ }
+ }
+
+ // now compute gamma values
+ for (unsigned long xi = 0; xi < num_states_xi; ++xi)
+ {
+ for (unsigned long xj = 0; xj < num_states_xj; ++xj)
+ {
+ gamma_ji[jcnt][xi] = std::max(gamma_ji[jcnt][xi], prob.factor_value(i,j,xi,xj) + delta_to_j_no_i[jcnt][xj]);
+ }
+ gamma_i[xi] += gamma_ji[jcnt][xi];
+ }
+ ++jcnt;
+ }
+
+ // now update the delta values
+ jcnt = 0;
+ for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
+ {
+ const unsigned long id_j = prob.node_id(j);
+ const unsigned long num_states_xj = prob.num_states(j);
+
+ // messages from j to i
+ double* const delta_ji = &delta_elements[delta_idx(id_j,id_i)];
+
+ // messages from i to j
+ double* const delta_ij = &delta_elements[delta_idx(id_i,id_j)];
+
+ for (unsigned long xj = 0; xj < num_states_xj; ++xj)
+ {
+ double best_val = -std::numeric_limits<double>::infinity();
+
+ for (unsigned long xi = 0; xi < num_states_xi; ++xi)
+ {
+ double val = prob.factor_value(i,j,xi,xj) + 2/(num_neighbors+1)*gamma_i[xi] -gamma_ji[jcnt][xi];
+ if (val > best_val)
+ best_val = val;
+ }
+ best_val = -0.5*delta_to_j_no_i[jcnt][xj] + 0.5*best_val;
+
+ if (std::abs(delta_ij[xj] - best_val) > max_change)
+ max_change = std::abs(delta_ij[xj] - best_val);
+
+ delta_ij[xj] = best_val;
+ }
+
+ for (unsigned long xi = 0; xi < num_states_xi; ++xi)
+ {
+ double new_val = -1/(num_neighbors+1)*gamma_i[xi] + gamma_ji[jcnt][xi];
+ if (std::abs(delta_ji[xi] - new_val) > max_change)
+ max_change = std::abs(delta_ji[xi] - new_val);
+ delta_ji[xi] = new_val;
+ }
+ ++jcnt;
+ }
+ }
+ }
+
+
+ // now decode the "beliefs"
+ std::vector<double> b;
+ for (node_iterator i = prob.begin(); i != prob.end(); ++i)
+ {
+ const unsigned long id_i = prob.node_id(i);
+ b.assign(prob.num_states(i), 0);
+
+ for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k)
+ {
+ const unsigned long id_k = prob.node_id(k);
+
+ for (unsigned long xi = 0; xi < b.size(); ++xi)
+ {
+ const double* const delta_ki = &delta_elements[delta_idx(id_k,id_i)];
+ b[xi] += delta_ki[xi];
+ }
+ }
+
+ map_assignment[id_i] = index_of_max(mat(b));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_
+
diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h
new file mode 100644
index 000000000..3dd9aead0
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h
@@ -0,0 +1,365 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_
+#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_
+
+#include <vector>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class map_problem
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a factor graph or graphical model. In
+ particular, this object defines the interface a MAP problem on
+ a factor graph must implement if it is to be solved using the
+ find_max_factor_graph_nmplp() routine defined at the bottom of this file.
+
+ Note that there is no dlib::map_problem object. What you are
+ looking at here is simply the interface definition for a map problem.
+ You must implement your own version of this object for the problem
+ you wish to solve and then pass it to the find_max_factor_graph_nmplp() routine.
+
+
+ Note also that a factor graph should not have any nodes which are
+ neighbors with themselves. Additionally, the graph is undirected. This
+ mean that if A is a neighbor of B then B must be a neighbor of A for
+ the map problem to be valid.
+
+
+ Finally, note that the "neighbor" relationship between nodes means the
+ following: Two nodes are neighbors if and only if there is a potential
+ function (implemented by the factor_value() method) which operates on
+ the nodes.
+ !*/
+
+ public:
+
+ class node_iterator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple forward iterator for iterating over
+ the nodes/variables in this factor graph.
+
+ Note that you can't dereference the iterator and
+ obtain a value. That is, the iterator is opaque to
+ the user. It is used only as an argument to the other
+ methods defined in this interface.
+ !*/
+
+ public:
+ node_iterator(
+ );
+ /*!
+ ensures
+ - constructs an iterator in an undefined state
+ !*/
+
+ node_iterator(
+ const node_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ !*/
+
+ node_iterator& operator= (
+ const node_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - returns #*this
+ !*/
+
+ bool operator== (
+ const node_iterator& item
+ ) const;
+ /*!
+ ensures
+ - returns true if *this and item both reference
+ the same node in the factor graph and false
+ otherwise.
+ !*/
+
+ bool operator!= (
+ const node_iterator& item
+ ) const;
+ /*!
+ ensures
+ - returns false if *this and item both reference
+ the same node in the factor graph and true
+ otherwise.
+ !*/
+
+ node_iterator& operator++(
+ );
+ /*!
+ ensures
+ - advances *this to the next node in the factor graph.
+ - returns a reference to the updated *this
+ (i.e. this is the ++object form of the increment operator)
+ !*/
+ };
+
+ class neighbor_iterator
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple forward iterator for iterating over
+ the nodes/variables in this factor graph. This version
+ of the iterator is used for iterating over the neighbors
+ of another node in the graph.
+ !*/
+
+ public:
+ neighbor_iterator(
+ );
+ /*!
+ ensures
+ - constructs an iterator in an undefined state
+ !*/
+
+ neighbor_iterator(
+ const neighbor_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ !*/
+
+ neighbor_iterator& operator= (
+ const neighbor_iterator& item
+ );
+ /*!
+ ensures
+ - #*this is a copy of item
+ - returns #*this
+ !*/
+
+ bool operator== (
+ const neighbor_iterator& item
+ ) const;
+ /*!
+ ensures
+ - returns true if *this and item both reference
+ the same node in the factor graph and false
+ otherwise.
+ !*/
+
+ bool operator!= (
+ const neighbor_iterator& item
+ ) const;
+ /*!
+ ensures
+ - returns false if *this and item both reference
+ the same node in the factor graph and true
+ otherwise.
+ !*/
+
+ neighbor_iterator& operator++(
+ );
+ /*!
+ ensures
+ - advances *this to the next node in the factor graph.
+ - returns a reference to the updated *this
+ (i.e. this is the ++object form of the increment operator)
+ !*/
+ };
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the factor graph. Or in other words,
+ returns the number of variables in the MAP problem.
+ !*/
+
+ node_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - returns an iterator to the first node in the graph. If no such
+ node exists then returns end().
+ !*/
+
+ node_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns an iterator to one past the last node in the graph.
+ !*/
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [begin(), end()))
+ ensures
+ - returns an iterator to the first neighboring node of the node
+ referenced by it. If no such node exists then returns end(it).
+ !*/
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator. (i.e. it must be in the range
+ [begin(i), end(i)) where i is some valid iterator. )
+ ensures
+ - returns an iterator to the first neighboring node of the node
+ referenced by it. If no such node exists then returns end(it).
+ !*/
+
+ neighbor_iterator end(
+ const node_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [begin(), end()))
+ ensures
+ - returns an iterator to one past the last neighboring node of the node
+ referenced by it.
+ !*/
+
+ neighbor_iterator end(
+ const neighbor_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator. (i.e. it must be in the range
+ [begin(i), end(i)) where i is some valid iterator. )
+ ensures
+ - returns an iterator to one past the last neighboring node of the node
+ referenced by it.
+ !*/
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [begin(), end()))
+ ensures
+ - returns a number ID such that:
+ - 0 <= ID < number_of_nodes()
+ - ID == a number which uniquely identifies the node pointed to by it.
+ !*/
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator. (i.e. it must be in the range
+ [begin(i), end(i)) where i is some valid iterator. )
+ ensures
+ - returns a number ID such that:
+ - 0 <= ID < number_of_nodes()
+ - ID == a number which uniquely identifies the node pointed to by it.
+ !*/
+
+ unsigned long num_states (
+ const node_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator (i.e. it must be in the range [begin(), end()))
+ ensures
+ - returns the number of states attainable by the node/variable referenced by it.
+ !*/
+
+ unsigned long num_states (
+ const neighbor_iterator& it
+ ) const;
+ /*!
+ requires
+ - it == a valid iterator. (i.e. it must be in the range
+ [begin(i), end(i)) where i is some valid iterator. )
+ ensures
+ - returns the number of states attainable by the node/variable referenced by it.
+ !*/
+
+ // The next four functions all have the same contract.
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const;
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const;
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const;
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const;
+ /*!
+ requires
+ - it1 == a valid iterator
+ - it2 == a valid iterator
+ - 0 <= s1 < num_states(it1)
+ - 0 <= s2 < num_states(it2)
+ - it1 and it2 reference nodes which are neighbors in the factor graph
+ ensures
+ - returns the value of the factor/potential function for the given pair of
+ nodes, defined by it1 and it2, for the case where they take on the values
+ s1 and s2 respectively.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_problem
+ >
+ void find_max_factor_graph_nmplp (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment,
+ unsigned long max_iter,
+ double eps
+ );
+ /*!
+ requires
+ - for all valid i: prob.num_states(i) >= 2
+ - map_problem == an object with an interface compatible with the map_problem
+ object defined at the top of this file.
+ - eps > 0
+ ensures
+ - This function is a tool for approximately solving the given MAP problem in a graphical
+ model or factor graph with pairwise potential functions. That is, it attempts
+ to solve a certain kind of optimization problem which can be defined as follows:
+ maximize: f(X)
+ where X is a set of integer valued variables and f(X) can be written as the
+ sum of functions which each involve only two variables from X. In reference
+ to the prob object, the nodes in prob represent the variables in X and the
+ functions which are summed are represented by prob.factor_value().
+ - #map_assignment == the result of the optimization.
+ - #map_assignment.size() == prob.number_of_nodes()
+ - for all valid i:
+ - #map_assignment[prob.node_id(i)] < prob.num_states(i)
+ - #map_assignment[prob.node_id(i)] == The approximate MAP assignment for node/variable i.
+ - eps controls the stopping condition, smaller values of eps lead to more accurate
+ solutions of the relaxed linear program but may take more iterations. Note that
+ the algorithm will never execute more than max_iter iterations regardless of
+ the setting of eps.
+ - If the graph is tree-structured then this routine always gives the exact solution
+ to the MAP problem. However, for graphs with cycles, the solution may be approximate.
+
+
+ - This function is an implementation of the NMPLP algorithm introduced in the
+ following papers:
+ Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008)
+ by Amir Globerson and Tommi Jaakkola
+
+ Introduction to dual decomposition for inference (2011)
+ by David Sontag, Amir Globerson, and Tommi Jaakkola
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h
new file mode 100644
index 000000000..f7cbcb8d9
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h
@@ -0,0 +1,232 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_
+#define DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_
+
+#include "find_max_factor_graph_viterbi_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "../array2d.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct viterbi_data
+ {
+ viterbi_data() :val(-std::numeric_limits<double>::infinity()), back_index(0) {}
+ double val;
+ unsigned long back_index;
+ };
+
+ template <long NC>
+ inline bool advance_state(
+ matrix<unsigned long,1,NC>& node_states,
+ unsigned long num_states
+ )
+ /*!
+ ensures
+ - advances node_states to the next state by adding 1
+ to node_states(node_states.size()-1) and carrying any
+ rollover (modulo num_states). Stores the result into #node_states.
+ - if (#node_states is all zeros) then
+ - returns false
+ - else
+ - returns true
+ !*/
+ {
+ for (long i = node_states.size()-1; i >= 0; --i)
+ {
+ node_states(i) += 1;
+ if (node_states(i) < num_states)
+ return true;
+
+ node_states(i) = 0;
+ }
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_problem
+ >
+ void find_max_factor_graph_viterbi (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment
+ )
+ {
+ using namespace dlib::impl;
+ const unsigned long order = prob.order();
+ const unsigned long num_states = prob.num_states();
+
+
+ DLIB_ASSERT(prob.num_states() > 0,
+ "\t void find_max_factor_graph_viterbi()"
+ << "\n\t The nodes in a factor graph have to be able to take on more than 0 states."
+ );
+ DLIB_ASSERT(std::pow(num_states,(double)order) < std::numeric_limits<unsigned long>::max(),
+ "\t void find_max_factor_graph_viterbi()"
+ << "\n\t The order is way too large for this algorithm to handle."
+ << "\n\t order: " << order
+ << "\n\t num_states: " << num_states
+ << "\n\t std::pow(num_states,order): " << std::pow(num_states,(double)order)
+ << "\n\t std::numeric_limits<unsigned long>::max(): " << std::numeric_limits<unsigned long>::max()
+ );
+
+ if (prob.number_of_nodes() == 0)
+ {
+ map_assignment.clear();
+ return;
+ }
+
+ if (order == 0)
+ {
+ map_assignment.resize(prob.number_of_nodes());
+ for (unsigned long i = 0; i < map_assignment.size(); ++i)
+ {
+ matrix<unsigned long,1,1> node_state;
+ unsigned long best_state = 0;
+ double best_val = -std::numeric_limits<double>::infinity();
+ for (unsigned long s = 0; s < num_states; ++s)
+ {
+ node_state(0) = s;
+ const double temp = prob.factor_value(i,node_state);
+ if (temp > best_val)
+ {
+ best_val = temp;
+ best_state = s;
+ }
+ }
+ map_assignment[i] = best_state;
+ }
+ return;
+ }
+
+
+ const unsigned long trellis_size = static_cast<unsigned long>(std::pow(num_states,(double)order));
+ unsigned long init_ring_size = 1;
+
+ array2d<impl::viterbi_data> trellis;
+ trellis.set_size(prob.number_of_nodes(), trellis_size);
+
+
+ for (unsigned long node = 0; node < prob.number_of_nodes(); ++node)
+ {
+
+ if (node < order)
+ {
+ matrix<unsigned long,1,0> node_states;
+ node_states.set_size(std::min<int>(node, order) + 1);
+ node_states = 0;
+
+ unsigned long idx = 0;
+ if (node == 0)
+ {
+ do
+ {
+ trellis[node][idx].val = prob.factor_value(node,node_states);
+ ++idx;
+ } while(advance_state(node_states,num_states));
+ }
+ else
+ {
+ init_ring_size *= num_states;
+ do
+ {
+ const unsigned long back_index = idx%init_ring_size;
+ trellis[node][idx].val = prob.factor_value(node,node_states) + trellis[node-1][back_index].val;
+ trellis[node][idx].back_index = back_index;
+ ++idx;
+ } while(advance_state(node_states,num_states));
+
+ }
+ }
+ else if (order == 1)
+ {
+ /*
+ WHAT'S THE DEAL WITH THIS PREPROCESSOR MACRO?
+ Well, if we can declare the dimensions of node_states as a compile
+ time constant then this function runs significantly faster. So this macro
+ is here to let us do that. It just lets us avoid replicating this code
+ block in the following if statements for different order sizes.
+ */
+#define DLIB_FMFGV_WORK \
+ node_states = 0; \
+ unsigned long count = 0; \
+ for (unsigned long i = 0; i < trellis_size; ++i) \
+ { \
+ unsigned long back_index = 0; \
+ double best_score = -std::numeric_limits<double>::infinity(); \
+ for (unsigned long s = 0; s < num_states; ++s) \
+ { \
+ const double temp = prob.factor_value(node,node_states) + trellis[node-1][count%trellis_size].val; \
+ if (temp > best_score) \
+ { \
+ best_score = temp; \
+ back_index = count%trellis_size; \
+ } \
+ advance_state(node_states,num_states); \
+ ++count; \
+ } \
+ trellis[node][i].val = best_score; \
+ trellis[node][i].back_index = back_index; \
+ }
+
+ matrix<unsigned long,1,2> node_states;
+ DLIB_FMFGV_WORK
+ }
+ else if (order == 2)
+ {
+ matrix<unsigned long,1,3> node_states;
+ DLIB_FMFGV_WORK
+ }
+ else if (order == 3)
+ {
+ matrix<unsigned long,1,4> node_states;
+ DLIB_FMFGV_WORK
+ }
+ else
+ {
+ // The general case, here we don't define the size of node_states at compile time.
+ matrix<unsigned long,1,0> node_states(order+1);
+ DLIB_FMFGV_WORK
+ }
+ }
+
+
+ map_assignment.resize(prob.number_of_nodes());
+ // Figure out which state of the last node has the biggest value.
+ unsigned long back_index = 0;
+ double best_val = -std::numeric_limits<double>::infinity();
+ for (long i = 0; i < trellis.nc(); ++i)
+ {
+ if (trellis[trellis.nr()-1][i].val > best_val)
+ {
+ best_val = trellis[trellis.nr()-1][i].val;
+ back_index = i;
+ }
+ }
+ // Follow the back links to find the decoding.
+ for (long node = map_assignment.size()-1; node >= 0; --node)
+ {
+ map_assignment[node] = back_index/init_ring_size;
+ back_index = trellis[node][back_index].back_index;
+ if (node < (long)order)
+ init_ring_size /= num_states;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h
new file mode 100644
index 000000000..c19e4c7eb
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h
@@ -0,0 +1,131 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_
+#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class map_problem
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a chain-structured factor graph or graphical
+ model. In particular, this object defines the interface a MAP problem
+ on a factor graph must implement if it is to be solved using the
+ find_max_factor_graph_viterbi() routine defined at the bottom of this file.
+
+ Note that there is no dlib::map_problem object. What you are looking
+ at here is simply the interface definition for a map problem. You must
+ implement your own version of this object for the problem you wish to
+ solve and then pass it to the find_max_factor_graph_viterbi() routine.
+ !*/
+
+ public:
+
+ unsigned long order (
+ ) const;
+ /*!
+ ensures
+ - returns the order of this model. The order has the following interpretation:
+ This model can represent a high order Markov chain. If order()==1 then map_problem
+ represents a basic chain-structured graph where nodes only depend on their immediate
+ neighbors. However, high order Markov models can also be used by setting order() > 1.
+ !*/
+
+ unsigned long num_states (
+ ) const;
+ /*!
+ ensures
+ - returns the number of states attainable by each variable/node in the graph.
+ !*/
+
+ unsigned long number_of_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nodes in the factor graph. Or in other words,
+ returns the number of variables in the MAP problem.
+ !*/
+
+ template <
+ typename EXP
+ >
+ double factor_value (
+ unsigned long node_id,
+ const matrix_exp<EXP>& node_states
+ ) const;
+ /*!
+ requires
+ - EXP::type == unsigned long
+ (i.e. node_states contains unsigned longs)
+ - node_id < number_of_nodes()
+ - node_states.size() == min(node_id, order()) + 1
+ - is_vector(node_states) == true
+ - max(node_states) < num_states()
+ ensures
+ - In a chain-structured graph, each node has a potential function associated with
+ it. The potential function operates on the variable given by the node as well
+ as the order() previous variables. Therefore, factor_value() returns the value
+ of the factor/potential function associated with node node_id where the following
+ nodes take on the values defined below:
+ - node_states(0) == the value of the node with ID node_id
+ - node_states(i) == the value of the node with ID node_id-i
+ - It is ok for this function to return a value of -std::numeric_limits<double>::infinity().
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_problem
+ >
+ void find_max_factor_graph_viterbi (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment
+ );
+ /*!
+ requires
+ - prob.num_states() > 0
+ - std::pow(prob.num_states(), prob.order()) < std::numeric_limits<unsigned long>::max()
+ (i.e. The Viterbi algorithm is exponential in the order of the map problem. So don't
+ make order too large.)
+ - map_problem == an object with an interface compatible with the map_problem
+ object defined at the top of this file.
+ ensures
+ - This function is a tool for exactly solving the MAP problem in a chain-structured
+ graphical model or factor graph. That is, it attempts to solve a certain kind of
+ optimization problem which can be defined as follows:
+ - Let X denote a set of prob.number_of_nodes() integer valued variables, each taking
+ a value in the range [0, prob.num_states()).
+ - Let X(i) = the ith variable in X.
+ - Let F(i) = factor_value_i(X(i), X(i-1), ..., X(i-prob.order()))
+ (This is the value returned by prob.factor_value(i, node_states). Note that
+ each factor's value function operates on at most prob.order()+1 variables.
+ Moreover, the variables are adjacent and hence the graph is "chain-structured".)
+
+ Then this function finds the assignments to the X variables which
+ maximizes: sum over all valid i: F(i)
+
+ - #map_assignment == the result of the optimization.
+ - #map_assignment.size() == prob.number_of_nodes()
+ - for all valid i:
+ - #map_assignment[i] < prob.num_states()
+ - #map_assignment[i] == The MAP assignment for node/variable i.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/optimization/find_max_parse_cky.h b/ml/dlib/dlib/optimization/find_max_parse_cky.h
new file mode 100644
index 000000000..79614792a
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_parse_cky.h
@@ -0,0 +1,414 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_MAX_PaRSE_CKY_Hh_
+#define DLIB_FIND_MAX_PaRSE_CKY_Hh_
+
+#include "find_max_parse_cky_abstract.h"
+#include <vector>
+#include <string>
+#include <sstream>
+#include "../serialize.h"
+#include "../array2d.h"
+
+namespace dlib
+{
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct constituent
+ {
+ unsigned long begin, end, k;
+ T left_tag;
+ T right_tag;
+ };
+
+ template <typename T>
+ void serialize(
+ const constituent<T>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.begin, out);
+ serialize(item.end, out);
+ serialize(item.k, out);
+ serialize(item.left_tag, out);
+ serialize(item.right_tag, out);
+ }
+
+ template <typename T>
+ void deserialize(
+ constituent<T>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.begin, in);
+ deserialize(item.end, in);
+ deserialize(item.k, in);
+ deserialize(item.left_tag, in);
+ deserialize(item.right_tag, in);
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ const unsigned long END_OF_TREE = 0xFFFFFFFF;
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct parse_tree_element
+ {
+ constituent<T> c;
+ T tag; // id for the constituent corresponding to this level of the tree
+
+ unsigned long left;
+ unsigned long right;
+ double score;
+ };
+
+ template <typename T>
+ void serialize (
+ const parse_tree_element<T>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.c, out);
+ serialize(item.tag, out);
+ serialize(item.left, out);
+ serialize(item.right, out);
+ serialize(item.score, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ parse_tree_element<T>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.c, in);
+ deserialize(item.tag, in);
+ deserialize(item.left, in);
+ deserialize(item.right, in);
+ deserialize(item.score, in);
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ unsigned long fill_parse_tree(
+ std::vector<parse_tree_element<T> >& parse_tree,
+ const T& tag,
+ const array2d<std::map<T, parse_tree_element<T> > >& back,
+ long r, long c
+ )
+ /*!
+ requires
+ - back[r][c].size() == 0 || back[r][c].count(tag) != 0
+ !*/
+ {
+ // base case of the recursion
+ if (back[r][c].size() == 0)
+ {
+ return END_OF_TREE;
+ }
+
+ const unsigned long idx = parse_tree.size();
+ const parse_tree_element<T>& item = back[r][c].find(tag)->second;
+ parse_tree.push_back(item);
+
+ const long k = item.c.k;
+ const unsigned long idx_left = fill_parse_tree(parse_tree, item.c.left_tag, back, r, k-1);
+ const unsigned long idx_right = fill_parse_tree(parse_tree, item.c.right_tag, back, k, c);
+ parse_tree[idx].left = idx_left;
+ parse_tree[idx].right = idx_right;
+ return idx;
+ }
+ }
+
+ template <typename T, typename production_rule_function>
+ void find_max_parse_cky (
+ const std::vector<T>& sequence,
+ const production_rule_function& production_rules,
+ std::vector<parse_tree_element<T> >& parse_tree
+ )
+ {
+ parse_tree.clear();
+ if (sequence.size() == 0)
+ return;
+
+ array2d<std::map<T,double> > table(sequence.size(), sequence.size());
+ array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size());
+ typedef typename std::map<T,double>::iterator itr;
+ typedef typename std::map<T,parse_tree_element<T> >::iterator itr_b;
+
+ for (long r = 0; r < table.nr(); ++r)
+ table[r][r][sequence[r]] = 0;
+
+ std::vector<std::pair<T,double> > possible_tags;
+
+ for (long r = table.nr()-2; r >= 0; --r)
+ {
+ for (long c = r+1; c < table.nc(); ++c)
+ {
+ for (long k = r; k < c; ++k)
+ {
+ for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i)
+ {
+ for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j)
+ {
+ constituent<T> con;
+ con.begin = r;
+ con.end = c+1;
+ con.k = k+1;
+ con.left_tag = j->first;
+ con.right_tag = i->first;
+ possible_tags.clear();
+ production_rules(sequence, con, possible_tags);
+ for (unsigned long m = 0; m < possible_tags.size(); ++m)
+ {
+ const double score = possible_tags[m].second + i->second + j->second;
+ itr match = table[r][c].find(possible_tags[m].first);
+ if (match == table[r][c].end() || score > match->second)
+ {
+ table[r][c][possible_tags[m].first] = score;
+ parse_tree_element<T> item;
+ item.c = con;
+ item.score = score;
+ item.tag = possible_tags[m].first;
+ item.left = END_OF_TREE;
+ item.right = END_OF_TREE;
+ back[r][c][possible_tags[m].first] = item;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ // now use back pointers to build the parse trees
+ const long r = 0;
+ const long c = back.nc()-1;
+ if (back[r][c].size() != 0)
+ {
+
+ // find the max scoring element in back[r][c]
+ itr_b max_i = back[r][c].begin();
+ itr_b i = max_i;
+ ++i;
+ for (; i != back[r][c].end(); ++i)
+ {
+ if (i->second.score > max_i->second.score)
+ max_i = i;
+ }
+
+ parse_tree.reserve(c);
+ impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c);
+ }
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ class parse_tree_to_string_error : public error
+ {
+ public:
+ parse_tree_to_string_error(const std::string& str): error(str) {}
+ };
+
+ namespace impl
+ {
+ template <bool enabled, typename T>
+ typename enable_if_c<enabled>::type conditional_print(
+ const T& item,
+ std::ostream& out
+ ) { out << item << " "; }
+
+ template <bool enabled, typename T>
+ typename disable_if_c<enabled>::type conditional_print(
+ const T& ,
+ std::ostream&
+ ) { }
+
+ template <bool print_tag, bool skip_tag, typename T, typename U >
+ void print_parse_tree_helper (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ unsigned long i,
+ const T& tag_to_skip,
+ std::ostream& out
+ )
+ {
+ if (!skip_tag || tree[i].tag != tag_to_skip)
+ out << "[";
+
+ bool left_recurse = false;
+
+ // Only print if we are supposed to. Doing it this funny way avoids compiler
+ // errors in parse_tree_to_string() for the case where tag isn't
+ // printable.
+ if (!skip_tag || tree[i].tag != tag_to_skip)
+ conditional_print<print_tag>(tree[i].tag, out);
+
+ if (tree[i].left < tree.size())
+ {
+ left_recurse = true;
+ print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].left, tag_to_skip, out);
+ }
+ else
+ {
+ if ((tree[i].c.begin) < words.size())
+ {
+ out << words[tree[i].c.begin] << " ";
+ }
+ else
+ {
+ std::ostringstream sout;
+ sout << "Parse tree refers to element " << tree[i].c.begin
+ << " of sequence which is only of size " << words.size() << ".";
+ throw parse_tree_to_string_error(sout.str());
+ }
+ }
+
+ if (left_recurse == true)
+ out << " ";
+
+ if (tree[i].right < tree.size())
+ {
+ print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].right, tag_to_skip, out);
+ }
+ else
+ {
+ if (tree[i].c.k < words.size())
+ {
+ out << words[tree[i].c.k];
+ }
+ else
+ {
+ std::ostringstream sout;
+ sout << "Parse tree refers to element " << tree[i].c.k
+ << " of sequence which is only of size " << words.size() << ".";
+ throw parse_tree_to_string_error(sout.str());
+ }
+ }
+
+
+ if (!skip_tag || tree[i].tag != tag_to_skip)
+ out << "]";
+ }
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ std::string parse_tree_to_string (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const unsigned long root_idx = 0
+ )
+ {
+ if (root_idx >= tree.size())
+ return "";
+
+ std::ostringstream sout;
+ impl::print_parse_tree_helper<false,false>(tree, words, root_idx, tree[root_idx].tag, sout);
+ return sout.str();
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ std::string parse_tree_to_string_tagged (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const unsigned long root_idx = 0
+ )
+ {
+ if (root_idx >= tree.size())
+ return "";
+
+ std::ostringstream sout;
+ impl::print_parse_tree_helper<true,false>(tree, words, root_idx, tree[root_idx].tag, sout);
+ return sout.str();
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ std::string parse_trees_to_string (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const T& tag_to_skip
+ )
+ {
+ if (tree.size() == 0)
+ return "";
+
+ std::ostringstream sout;
+ impl::print_parse_tree_helper<false,true>(tree, words, 0, tag_to_skip, sout);
+ return sout.str();
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ std::string parse_trees_to_string_tagged (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const T& tag_to_skip
+ )
+ {
+ if (tree.size() == 0)
+ return "";
+
+ std::ostringstream sout;
+ impl::print_parse_tree_helper<true,true>(tree, words, 0, tag_to_skip, sout);
+ return sout.str();
+ }
+
+// -----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ void helper_find_trees_without_tag (
+ const std::vector<parse_tree_element<T> >& tree,
+ const T& tag,
+ std::vector<unsigned long>& tree_roots,
+ unsigned long idx
+ )
+ {
+ if (idx < tree.size())
+ {
+ if (tree[idx].tag != tag)
+ {
+ tree_roots.push_back(idx);
+ }
+ else
+ {
+ helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].left);
+ helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].right);
+ }
+ }
+ }
+ }
+
+ template <typename T>
+ void find_trees_not_rooted_with_tag (
+ const std::vector<parse_tree_element<T> >& tree,
+ const T& tag,
+ std::vector<unsigned long>& tree_roots
+ )
+ {
+ tree_roots.clear();
+ impl::helper_find_trees_without_tag(tree, tag, tree_roots, 0);
+ }
+
+// -----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_PaRSE_CKY_Hh_
+
diff --git a/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h b/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h
new file mode 100644
index 000000000..52ffc787b
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h
@@ -0,0 +1,388 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_
+#ifdef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_
+
+#include <vector>
+#include <string>
+#include "../algs.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct constituent
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the linguistic idea of a constituent, that is, a
+ group of words that functions as a single unit. In particular, it
+ represents a combination of two constituents into a new constituent.
+
+ Additionally, a constituent object represents a range of words relative to
+ some std::vector of words. The range is from [begin, end) (i.e. including
+ begin but not including end, so using the normal C++ iterator notation).
+ Moreover, a constituent is always composed of two parts, each having a tag.
+ Therefore, the left part is composed of the words in the range [begin,k)
+ and has tag left_tag while the right part of the constituent contains the
+ words in the range [k,end) and has the tag right_tag.
+
+ The tags are user defined objects of type T. In general, they are used to
+ represent syntactic categories such as noun phrase, verb phrase, etc.
+ !*/
+
+ unsigned long begin, end, k;
+ T left_tag;
+ T right_tag;
+ };
+
+ template <
+ typename T
+ >
+ void serialize(
+ const constituent<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize(
+ constituent<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+ /*!A END_OF_TREE is used to indicate that parse_tree_element::left or
+ parse_tree_element::right doesn't point to another subtree.
+ !*/
+ const unsigned long END_OF_TREE = 0xFFFFFFFF;
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct parse_tree_element
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is used to represent a node in a binary parse tree. An entire
+ parse tree is represented by a std::vector of parse_tree_element objects.
+ We follow the convention that the first element of this vector is always
+ the root of the entire tree.
+
+ The fields of this object have the following interpretations:
+ - c == the constituent spanned by this node in the parse tree.
+ Therefore, the node spans the words in the range [c.begin, c.end).
+ - tag == the syntactic category of this node in the parse tree.
+ - score == the score or log likelihood for this parse tree. In
+ general, this is the sum of scores of all the production rules used
+ to build the tree rooted at the current node.
+ - let PT denote the vector of parse_tree_elements that defines an
+ entire parse tree. Then we have:
+ - if (left != END_OF_TREE) then
+ - PT[left] == the left sub-tree of the current node.
+ - PT[left] spans the words [c.begin, c.k)
+ - PT[left].tag == c.left_tag
+ - else
+ - there is no left sub-tree
+
+ - if (right != END_OF_TREE) then
+ - PT[right] == the right sub-tree of the current node.
+ - PT[right] spans the words [c.k, c.end)
+ - PT[right].tag == c.right_tag
+ - else
+ - there is no right sub-tree
+ !*/
+
+ constituent<T> c;
+ T tag;
+ double score;
+
+ unsigned long left;
+ unsigned long right;
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const parse_tree_element<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ parse_tree_element<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// -----------------------------------------------------------------------------------------
+// -----------------------------------------------------------------------------------------
+
+ void example_production_rule_function (
+ const std::vector<T>& words,
+ const constituent<T>& c,
+ std::vector<std::pair<T,double> >& possible_tags
+ )
+ /*!
+ requires
+ - 0 <= c.begin < c.k < c.end <= words.size()
+ - possible_tags.size() == 0
+ ensures
+ - Finds all the syntactic categories that can be used to label c and puts those
+ categories, along with their scores, into possible_tags. Or in other words,
+ this function determines which production rules can be used to turn the left
+ and right sub-constituents in c into a single constituent. The contents of c
+ have the following interpretations:
+ - The left sub-constituent has syntactic category c.left_tag
+ - for all i such that c.begin <= i < c.k:
+ - words[i] is part of the left sub-constituent.
+ - The right sub-constituent has syntactic category c.right_tag
+ - for all i such that c.k <= i < c.end:
+ - words[i] is part of the right sub-constituent.
+
+ - Note that example_production_rule_function() is not a real function. It is
+ here just to show you how to define production rule producing functions for
+ use with the find_max_parse_cky() routine defined below.
+ !*/
+
+ template <
+ typename T,
+ typename production_rule_function
+ >
+ void find_max_parse_cky (
+ const std::vector<T>& words,
+ const production_rule_function& production_rules,
+ std::vector<parse_tree_element<T> >& parse_tree
+ );
+ /*!
+ requires
+ - production_rule_function == a function or function object with the same
+ interface as example_production_rule_function defined above.
+ - It must be possible to store T objects in a std::map.
+ ensures
+ - Uses the CKY algorithm to find the most probable/highest scoring binary parse
+ tree of the given vector of words.
+ - if (#parse_tree.size() == 0) then
+ - There is no parse tree, using the given production_rules, that can cover
+ the given word sequence.
+ - else
+ - #parse_tree == the highest scoring parse tree that covers all the
+ elements of words.
+ - #parse_tree[0] == the root node of the parse tree.
+ - #parse_tree[0].score == the score of the parse tree. This is the sum of
+ the scores of all production rules used to construct the tree.
+ - #parse_tree[0].begin == 0
+ - #parse_tree[0].end == words.size()
+ - This function uses production_rules() to find out what the allowed production
+ rules are. That is, production_rules() defines all properties of the grammar
+ used by find_max_parse_cky().
+ !*/
+
+// -----------------------------------------------------------------------------------------
+// -----------------------------------------------------------------------------------------
+
+ class parse_tree_to_string_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown by parse_tree_to_string() and
+ parse_tree_to_string_tagged() if the inputs are discovered to be invalid.
+ !*/
+ };
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ std::string parse_tree_to_string (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const unsigned long root_idx = 0
+ );
+ /*!
+ requires
+ - It must be possible to print U objects to an ostream using operator<<
+ (typically, U would be something like std::string)
+ ensures
+ - Interprets tree as a parse tree defined over the given sequence of words.
+ - returns a bracketed string that represents the parse tree over the words.
+ For example, suppose the following parse tree is input:
+
+ /\
+ / \
+ /\ \
+ / \ \
+ the dog ran
+
+ Then the output would be the string "[[the dog] ran]"
+ - Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >=
+ tree.size() then the empty string is returned.
+ throws
+ - parse_tree_to_string_error
+ This exception is thrown if an invalid tree is detected. This might happen
+ if the tree refers to elements of words that don't exist because words is
+ shorted than it is supposed to be.
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ std::string parse_tree_to_string_tagged (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const unsigned long root_idx = 0
+ );
+ /*!
+ requires
+ - It must be possible to print T objects to an ostream using operator<<
+ - It must be possible to print U objects to an ostream using operator<<
+ (typically, U would be something like std::string)
+ ensures
+ - This function does the same thing as parse_tree_to_string() except that it
+ also includes the parse_tree_element::tag object in the output. Therefore,
+ the tag of each bracket will be included as the first token inside the
+ bracket. For example, suppose the following parse tree is input (where tags
+ are shown at the vertices):
+
+ S
+ /\
+ NP \
+ /\ \
+ / \ \
+ the dog ran
+
+ Then the output would be the string "[S [NP the dog] ran]"
+ - Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >=
+ tree.size() then the empty string is returned.
+ throws
+ - parse_tree_to_string_error
+ This exception is thrown if an invalid tree is detected. This might happen
+ if the tree refers to elements of words that don't exist because words is
+ shorted than it is supposed to be.
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ std::string parse_trees_to_string (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const T& tag_to_skip
+ );
+ /*!
+ requires
+ - It must be possible to print U objects to an ostream using operator<<
+ (typically, U would be something like std::string)
+ ensures
+ - This function behaves just like parse_tree_to_string() except that it will
+ not print the brackets (i.e. []) for the top most parts of the tree which
+ have tags equal to tag_to_skip. It will however print all the words.
+ Therefore, this function only includes brackets on the subtrees which begin
+ with a tag other than tag_to_skip.
+ throws
+ - parse_tree_to_string_error
+ This exception is thrown if an invalid tree is detected. This might happen
+ if the tree refers to elements of words that don't exist because words is
+ shorted than it is supposed to be.
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ std::string parse_trees_to_string_tagged (
+ const std::vector<parse_tree_element<T> >& tree,
+ const std::vector<U>& words,
+ const T& tag_to_skip
+ );
+ /*!
+ requires
+ - It must be possible to print T objects to an ostream using operator<<
+ - It must be possible to print U objects to an ostream using operator<<
+ (typically, U would be something like std::string)
+ ensures
+ - This function behaves just like parse_tree_to_string_tagged() except that it
+ will not print the brackets (i.e. []) for the top most parts of the tree
+ which have tags equal to tag_to_skip. It will however print all the words.
+ Therefore, this function only includes brackets on the subtrees which begin
+ with a tag other than tag_to_skip.
+ throws
+ - parse_tree_to_string_error
+ This exception is thrown if an invalid tree is detected. This might happen
+ if the tree refers to elements of words that don't exist because words is
+ shorted than it is supposed to be.
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void find_trees_not_rooted_with_tag (
+ const std::vector<parse_tree_element<T> >& tree,
+ const T& tag,
+ std::vector<unsigned long>& tree_roots
+ );
+ /*!
+ requires
+ - objects of type T must be comparable using operator==
+ ensures
+ - Finds all the largest non-overlapping trees in tree that are not rooted with
+ the given tag.
+ - find_trees_not_rooted_with_tag() is useful when you want to cut a parse tree
+ into a bunch of sub-trees and you know that the top level of the tree is all
+ composed of the same kind of tag. So if you want to just "slice off" the top
+ of the tree where this tag lives then this function is useful for doing that.
+ - #tree_roots.size() == the number of sub-trees found.
+ - for all valid i:
+ - tree[#tree_roots[i]].tag != tag
+ - To make the operation of this function clearer, here are a few examples of
+ what it will do:
+ - if (tree[0].tag != tag) then
+ - #tree_roots.size() == 0
+ - #tree_roots[0] == 0
+ - else if (tree[0].tag == tag but its immediate children's tags are not equal to tag) then
+ - #tree_roots.size() == 2
+ - #tree_roots[0] == tree[0].left
+ - #tree_roots[1] == tree[0].right
+ !*/
+
+// -----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/optimization/find_optimal_parameters.h b/ml/dlib/dlib/optimization/find_optimal_parameters.h
new file mode 100644
index 000000000..0884778c9
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_optimal_parameters.h
@@ -0,0 +1,117 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_fIND_OPTIMAL_PARAMETERS_Hh_
+#define DLIB_fIND_OPTIMAL_PARAMETERS_Hh_
+
+#include "../matrix.h"
+#include "find_optimal_parameters_abstract.h"
+#include "optimization_bobyqa.h"
+#include "optimization_line_search.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double find_optimal_parameters (
+ double initial_search_radius,
+ double eps,
+ const unsigned int max_f_evals,
+ matrix<double,0,1>& x,
+ const matrix<double,0,1>& x_lower,
+ const matrix<double,0,1>& x_upper,
+ const funct& f
+ )
+ {
+ DLIB_CASSERT(x.size() == x_lower.size() && x_lower.size() == x_upper.size() && x.size() > 0,
+ "\t double find_optimal_parameters()"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t x_lower.size(): " << x_lower.size()
+ << "\n\t x_upper.size(): " << x_upper.size()
+ );
+
+ // check the requirements. Also split the assert up so that the error message isn't huge.
+ DLIB_CASSERT(max_f_evals > 1 && eps > 0 && initial_search_radius > eps,
+ "\t double find_optimal_parameters()"
+ << "\n\t Invalid arguments have been given to this function"
+ << "\n\t initial_search_radius: " << initial_search_radius
+ << "\n\t eps: " << eps
+ << "\n\t max_f_evals: " << max_f_evals
+ );
+
+ DLIB_CASSERT( min(x_upper - x_lower) > 0 &&
+ min(x - x_lower) >= 0 && min(x_upper - x) >= 0,
+ "\t double find_optimal_parameters()"
+ << "\n\t The bounds constraints have to make sense and also contain the starting point."
+ << "\n\t min(x_upper - x_lower): " << min(x_upper - x_lower)
+ << "\n\t min(x - x_lower) >= 0 && min(x_upper - x) >= 0: " << (min(x - x_lower) >= 0 && min(x_upper - x) >= 0)
+ );
+
+ // if the search radius is too big then shrink it so it fits inside the bounds.
+ if (initial_search_radius*2 >= min(x_upper-x_lower))
+ initial_search_radius = 0.5*min(x_upper-x_lower)*0.99;
+
+
+ double objective_val = std::numeric_limits<double>::infinity();
+ size_t num_iter_used = 0;
+ if (x.size() == 1)
+ {
+ // BOBYQA requires x to have at least 2 variables in it. So we can't call it in
+ // this case. Instead we call find_min_single_variable().
+ matrix<double,0,1> temp(1);
+ auto ff = [&](const double& xx)
+ {
+ temp = xx;
+ double obj = f(temp);
+ ++num_iter_used;
+ // keep track of the best x.
+ if (obj < objective_val)
+ {
+ objective_val = obj;
+ x = temp;
+ }
+ return obj;
+ };
+ try
+ {
+ double dx = x(0);
+ find_min_single_variable(ff, dx, x_lower(0), x_upper(0), eps, max_f_evals, initial_search_radius);
+ } catch (optimize_single_variable_failure& )
+ {
+ }
+ }
+ else
+ {
+ auto ff = [&](const matrix<double,0,1>& xx)
+ {
+ double obj = f(xx);
+ ++num_iter_used;
+ // keep track of the best x.
+ if (obj < objective_val)
+ {
+ objective_val = obj;
+ x = xx;
+ }
+ return obj;
+ };
+ try
+ {
+ matrix<double,0,1> start_x = x;
+ find_min_bobyqa(ff, start_x, 2*x.size()+1, x_lower, x_upper, initial_search_radius, eps, max_f_evals);
+ } catch (bobyqa_failure& )
+ {
+ }
+ }
+
+ return objective_val;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_fIND_OPTIMAL_PARAMETERS_Hh_
+
diff --git a/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h b/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h
new file mode 100644
index 000000000..96dcee89b
--- /dev/null
+++ b/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_
+#ifdef DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double find_optimal_parameters (
+ double initial_search_radius,
+ double eps,
+ const unsigned int max_f_evals,
+ matrix<double,0,1>& x,
+ const matrix<double,0,1>& x_lower,
+ const matrix<double,0,1>& x_upper,
+ const funct& f
+ );
+ /*!
+ requires
+ - f(x) must be a valid expression that evaluates to a double
+ - x.size() == x_lower.size() == x_upper.size()
+ - x.size() > 0
+ - 0 < eps < initial_search_radius
+ - max_f_evals > 1
+ - min(x_upper - x_lower) > 0
+ - min(x - x_lower) >= 0 && min(x_upper - x) >= 0
+ (i.e. the given x should be within the bounds defined by x_lower and x_upper)
+ ensures
+ - Performs a constrained minimization of the function f() starting from
+ the initial point x.
+ - This function does not require derivatives of f(). Instead, it uses
+ derivative free methods to find the best setting of x. In particular, it
+ will begin by searching within a sphere of radius initial_search_radius
+ around x and will continue searching until either f() has been called
+ max_f_evals times or the search area has been shrunk to less than eps radius.
+ - #x == the value of x (within the bounds defined by x_lower and x_upper) that
+ was found to minimize f(). More precisely, it will always be true that:
+ - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0
+ - returns f(#x).
+ throws
+ - No exception is thrown for executing max_f_evals iterations. This function
+ will simply output the best x it has seen if it runs out of iterations.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/isotonic_regression.h b/ml/dlib/dlib/optimization/isotonic_regression.h
new file mode 100644
index 000000000..e89e70b48
--- /dev/null
+++ b/ml/dlib/dlib/optimization/isotonic_regression.h
@@ -0,0 +1,169 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ISOTONIC_ReGRESSION_H_
+#define DLIB_ISOTONIC_ReGRESSION_H_
+
+#include "isotonic_regression_abstract.h"
+#include <vector>
+#include <utility>
+
+namespace dlib
+{
+ class isotonic_regression
+ {
+ public:
+
+ template <
+ typename const_iterator,
+ typename iterator
+ >
+ void operator() (
+ const_iterator begin,
+ const_iterator end,
+ iterator obegin
+ )
+ {
+ do_isotonic_regression(begin, end);
+
+ // unpack blocks to output
+ for (auto& block : blocks)
+ {
+ for (size_t k = 0; k < block.num; ++k)
+ set_val(*obegin++, block.avg);
+ }
+
+ blocks.clear();
+ }
+
+ void operator() (
+ std::vector<double>& vect
+ ) { (*this)(vect.begin(), vect.end(), vect.begin()); }
+
+ template <typename T, typename U>
+ void operator() (
+ std::vector<std::pair<T,U>>& vect
+ ) { (*this)(vect.begin(), vect.end(), vect.begin()); }
+
+
+ template <
+ typename const_iterator,
+ typename iterator
+ >
+ void fit_with_linear_output_interpolation (
+ const_iterator begin,
+ const_iterator end,
+ iterator obegin
+ )
+ {
+ do_isotonic_regression(begin, end);
+
+ // Unpack blocks to output, but here instead of producing the step function
+ // output we linearly interpolate. Note that this actually fits the data less
+ // than the step-function, but in many applications might be closer to what you
+ // really when when using isotonic_regression than the step function.
+ for (size_t i = 0; i < blocks.size(); ++i)
+ {
+ auto& block = blocks[i];
+
+ double prev = (blocks.front().avg + block.avg)/2;
+ if (i > 0)
+ prev = (blocks[i-1].avg+block.avg)/2;
+
+ double next = (blocks.back().avg + block.avg)/2;
+ if (i+1 < blocks.size())
+ next = (blocks[i+1].avg+block.avg)/2;
+
+ for (size_t k = 0; k < block.num; ++k)
+ {
+ const auto mid = block.num/2.0;
+ if (k < mid)
+ {
+ const double alpha = k/mid;
+ set_val(*obegin++, (1-alpha)*prev + alpha*block.avg);
+ }
+ else
+ {
+ const double alpha = k/mid-1;
+ set_val(*obegin++, alpha*next + (1-alpha)*block.avg);
+ }
+ }
+ }
+
+ blocks.clear();
+ }
+
+ void fit_with_linear_output_interpolation (
+ std::vector<double>& vect
+ ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); }
+
+ template <typename T, typename U>
+ void fit_with_linear_output_interpolation (
+ std::vector<std::pair<T,U>>& vect
+ ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); }
+
+ private:
+
+ template <
+ typename const_iterator
+ >
+ void do_isotonic_regression (
+ const_iterator begin,
+ const_iterator end
+ )
+ {
+ blocks.clear();
+
+ // Do the actual isotonic regression. The output is a step-function and is
+ // stored in the vector of blocks.
+ for (auto i = begin; i != end; ++i)
+ {
+ blocks.emplace_back(get_val(*i));
+ while (blocks.size() > 1 && prev_block().avg > current_block().avg)
+ {
+ // merge the last two blocks.
+ prev_block() = prev_block() + current_block();
+ blocks.pop_back();
+ }
+ }
+ }
+
+
+ template <typename T>
+ static double get_val(const T& v) { return v;}
+
+ template <typename T, typename U>
+ static double get_val(const std::pair<T,U>& v) { return v.second;}
+
+ template <typename T>
+ static void set_val(T& v, double val) { v = val;}
+
+ template <typename T, typename U>
+ static void set_val(std::pair<T,U>& v, double val) { v.second = val;}
+
+
+
+ struct block_t
+ {
+ block_t(double val) : num(1), avg(val) {}
+ block_t(size_t n, double val) : num(n), avg(val) {}
+
+ size_t num;
+ double avg;
+
+ inline block_t operator+(const block_t& rhs) const
+ {
+ return block_t(num+rhs.num,
+ (num*avg + rhs.num*rhs.avg)/(num+rhs.num));
+ }
+ };
+
+ inline block_t& prev_block() { return blocks[blocks.size()-2]; }
+ inline block_t& current_block() { return blocks.back(); }
+
+ std::vector<block_t> blocks;
+ };
+}
+
+#endif // DLIB_ISOTONIC_ReGRESSION_H_
+
+
diff --git a/ml/dlib/dlib/optimization/isotonic_regression_abstract.h b/ml/dlib/dlib/optimization/isotonic_regression_abstract.h
new file mode 100644
index 000000000..b00334bef
--- /dev/null
+++ b/ml/dlib/dlib/optimization/isotonic_regression_abstract.h
@@ -0,0 +1,128 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_
+#ifdef DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_
+
+#include <vector>
+#include <utility>
+
+namespace dlib
+{
+ class isotonic_regression
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing 1-D isotonic regression. That is, it
+ finds the least squares fit of a non-parametric curve to some user supplied
+ data, subject to the constraint that the fitted curve is non-decreasing.
+
+ This is done using the fast O(n) pool adjacent violators algorithm.
+ !*/
+
+ public:
+
+ template <
+ typename const_iterator,
+ typename iterator
+ >
+ void operator() (
+ const_iterator begin,
+ const_iterator end,
+ iterator obegin
+ );
+ /*!
+ requires
+ - [begin,end) is an iterator range of float or doubles or a range of
+ std::pair<T,double> or std::pair<T,float> where T an be anything.
+ - obegin points to an iterator range at least std::distance(begin,end).
+ - obegin points to an iterator range of objects of type float, double, std::pair<T,float>, or std::pair<T,double>.
+ ensures
+ - Given the range of real values stored in [begin,end), this method performs isotonic regression
+ on this data and writes the results to obegin. To be specific:
+ - let IN refer to the input values stored in the iterator range [begin,end).
+ - let OUT refer to the output values stored in the iterator range [obegin, obegin+std::distance(begin,end)).
+ - This function populates OUT with values such that the sum_i of
+ (IN[i]-OUT[i])^2 is minimized, subject to the constraint that
+ OUT[i] <= OUT[i+1], i.e. that OUT is monotonic.
+ - It is OK for [begin,end) to overlap with the range pointed to by obegin.
+ That is, this function can run in-place.
+ - Note that when the inputs or outputs are std::pairs this algorithm only
+ looks at the .second field of the pair. It therefore still treats these
+ iterator ranges as ranges of reals since it only looks at the .second
+ field, which is a real number. The .first field is entirely ignored.
+ !*/
+
+ void operator() (
+ std::vector<double>& vect
+ ) { (*this)(vect.begin(), vect.end(), vect.begin()); }
+ /*!
+ ensures
+ - performs in-place isotonic regression. Therefore, #vect will contain the
+ isotonic regression of vect.
+ - #vect.size() == vect.size()
+ !*/
+
+ template <typename T, typename U>
+ void operator() (
+ std::vector<std::pair<T,U>>& vect
+ ) { (*this)(vect.begin(), vect.end(), vect.begin()); }
+ /*!
+ ensures
+ - performs in-place isotonic regression. Therefore, #vect will contain the
+ isotonic regression of vect.
+ - #vect.size() == vect.size()
+ !*/
+
+
+ template <
+ typename const_iterator,
+ typename iterator
+ >
+ void fit_with_linear_output_interpolation (
+ const_iterator begin,
+ const_iterator end,
+ iterator obegin
+ );
+ /*!
+ requires
+ - [begin,end) is an iterator range of float or doubles or a range of
+ std::pair<T,double> or std::pair<T,float> where T an be anything.
+ - obegin points to an iterator range at least std::distance(begin,end).
+ - obegin points to an iterator range of objects of type float, double, std::pair<T,float>, or std::pair<T,double>.
+ ensures
+ - This function behaves just like (*this)(begin,end,obegin) except that the
+ output is interpolated. To explain, not that the optimal output of
+ isotonic regression is a step function. However, in many applications
+ that isn't really what you want. You want something smoother. So
+ fit_with_linear_output_interpolation() does isotonic regression and then
+ linearly interpolates the step function into a piecewise linear function.
+ !*/
+
+ void fit_with_linear_output_interpolation (
+ std::vector<double>& vect
+ ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); }
+ /*!
+ ensures
+ - performs in-place isotonic regression. Therefore, #vect will contain the
+ isotonic regression of vect.
+ - #vect.size() == vect.size()
+ !*/
+
+ template <typename T, typename U>
+ void fit_with_linear_output_interpolation (
+ std::vector<std::pair<T,U>>& vect
+ ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); }
+ /*!
+ ensures
+ - performs in-place isotonic regression. Therefore, #vect will contain the
+ isotonic regression of vect.
+ - #vect.size() == vect.size()
+ !*/
+
+ };
+}
+
+#endif // DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/optimization/max_cost_assignment.h b/ml/dlib/dlib/optimization/max_cost_assignment.h
new file mode 100644
index 000000000..db6c6f0d7
--- /dev/null
+++ b/ml/dlib/dlib/optimization/max_cost_assignment.h
@@ -0,0 +1,288 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAX_COST_ASSIgNMENT_Hh_
+#define DLIB_MAX_COST_ASSIgNMENT_Hh_
+
+#include "max_cost_assignment_abstract.h"
+#include "../matrix.h"
+#include <vector>
+#include <deque>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ typename EXP::type assignment_cost (
+ const matrix_exp<EXP>& cost,
+ const std::vector<long>& assignment
+ )
+ {
+ DLIB_ASSERT(cost.nr() == cost.nc(),
+ "\t type assignment_cost(cost,assignment)"
+ << "\n\t cost.nr(): " << cost.nr()
+ << "\n\t cost.nc(): " << cost.nc()
+ );
+#ifdef ENABLE_ASSERTS
+ // can't call max on an empty vector. So put an if here to guard against it.
+ if (assignment.size() > 0)
+ {
+ DLIB_ASSERT(0 <= min(mat(assignment)) && max(mat(assignment)) < cost.nr(),
+ "\t type assignment_cost(cost,assignment)"
+ << "\n\t cost.nr(): " << cost.nr()
+ << "\n\t cost.nc(): " << cost.nc()
+ << "\n\t min(assignment): " << min(mat(assignment))
+ << "\n\t max(assignment): " << max(mat(assignment))
+ );
+ }
+#endif
+
+ typename EXP::type temp = 0;
+ for (unsigned long i = 0; i < assignment.size(); ++i)
+ {
+ temp += cost(i, assignment[i]);
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename EXP>
+ inline void compute_slack(
+ const long x,
+ std::vector<typename EXP::type>& slack,
+ std::vector<long>& slackx,
+ const matrix_exp<EXP>& cost,
+ const std::vector<typename EXP::type>& lx,
+ const std::vector<typename EXP::type>& ly
+ )
+ {
+ for (long y = 0; y < cost.nc(); ++y)
+ {
+ if (lx[x] + ly[y] - cost(x,y) < slack[y])
+ {
+ slack[y] = lx[x] + ly[y] - cost(x,y);
+ slackx[y] = x;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ std::vector<long> max_cost_assignment (
+ const matrix_exp<EXP>& cost_
+ )
+ {
+ const_temp_matrix<EXP> cost(cost_);
+ typedef typename EXP::type type;
+ // This algorithm only works if the elements of the cost matrix can be reliably
+ // compared using operator==. However, comparing for equality with floating point
+ // numbers is not a stable operation. So you need to use an integer cost matrix.
+ COMPILE_TIME_ASSERT(std::numeric_limits<type>::is_integer);
+ DLIB_ASSERT(cost.nr() == cost.nc(),
+ "\t std::vector<long> max_cost_assignment(cost)"
+ << "\n\t cost.nr(): " << cost.nr()
+ << "\n\t cost.nc(): " << cost.nc()
+ );
+
+ using namespace dlib::impl;
+ /*
+ I based the implementation of this algorithm on the description of the
+ Hungarian algorithm on the following websites:
+ http://www.math.uwo.ca/~mdawes/courses/344/kuhn-munkres.pdf
+ http://www.topcoder.com/tc?module=Static&d1=tutorials&d2=hungarianAlgorithm
+
+ Note that this is the fast O(n^3) version of the algorithm.
+ */
+
+ if (cost.size() == 0)
+ return std::vector<long>();
+
+ std::vector<type> lx, ly;
+ std::vector<long> xy;
+ std::vector<long> yx;
+ std::vector<char> S, T;
+ std::vector<type> slack;
+ std::vector<long> slackx;
+ std::vector<long> aug_path;
+
+
+
+
+ // Initially, nothing is matched.
+ xy.assign(cost.nc(), -1);
+ yx.assign(cost.nc(), -1);
+ /*
+ We maintain the following invariant:
+ Vertex x is matched to vertex xy[x] and
+ vertex y is matched to vertex yx[y].
+
+ A value of -1 means a vertex isn't matched to anything. Moreover,
+ x corresponds to rows of the cost matrix and y corresponds to the
+ columns of the cost matrix. So we are matching X to Y.
+ */
+
+
+ // Create an initial feasible labeling. Moreover, in the following
+ // code we will always have:
+ // for all valid x and y: lx[x] + ly[y] >= cost(x,y)
+ lx.resize(cost.nc());
+ ly.assign(cost.nc(),0);
+ for (long x = 0; x < cost.nr(); ++x)
+ lx[x] = max(rowm(cost,x));
+
+ // Now grow the match set by picking edges from the equality subgraph until
+ // we have a complete matching.
+ for (long match_size = 0; match_size < cost.nc(); ++match_size)
+ {
+ std::deque<long> q;
+
+ // Empty out the S and T sets
+ S.assign(cost.nc(), false);
+ T.assign(cost.nc(), false);
+
+ // clear out old slack values
+ slack.assign(cost.nc(), std::numeric_limits<type>::max());
+ slackx.resize(cost.nc());
+ /*
+ slack and slackx are maintained such that we always
+ have the following (once they get initialized by compute_slack() below):
+ - for all y:
+ - let x == slackx[y]
+ - slack[y] == lx[x] + ly[y] - cost(x,y)
+ */
+
+ aug_path.assign(cost.nc(), -1);
+
+ for (long x = 0; x < cost.nc(); ++x)
+ {
+ // If x is not matched to anything
+ if (xy[x] == -1)
+ {
+ q.push_back(x);
+ S[x] = true;
+
+ compute_slack(x, slack, slackx, cost, lx, ly);
+ break;
+ }
+ }
+
+
+ long x_start = 0;
+ long y_start = 0;
+
+ // Find an augmenting path.
+ bool found_augmenting_path = false;
+ while (!found_augmenting_path)
+ {
+ while (q.size() > 0 && !found_augmenting_path)
+ {
+ const long x = q.front();
+ q.pop_front();
+ for (long y = 0; y < cost.nc(); ++y)
+ {
+ if (cost(x,y) == lx[x] + ly[y] && !T[y])
+ {
+ // if vertex y isn't matched with anything
+ if (yx[y] == -1)
+ {
+ y_start = y;
+ x_start = x;
+ found_augmenting_path = true;
+ break;
+ }
+
+ T[y] = true;
+ q.push_back(yx[y]);
+
+ aug_path[yx[y]] = x;
+ S[yx[y]] = true;
+ compute_slack(yx[y], slack, slackx, cost, lx, ly);
+ }
+ }
+ }
+
+ if (found_augmenting_path)
+ break;
+
+
+ // Since we didn't find an augmenting path we need to improve the
+ // feasible labeling stored in lx and ly. We also need to keep the
+ // slack updated accordingly.
+ type delta = std::numeric_limits<type>::max();
+ for (unsigned long i = 0; i < T.size(); ++i)
+ {
+ if (!T[i])
+ delta = std::min(delta, slack[i]);
+ }
+ for (unsigned long i = 0; i < T.size(); ++i)
+ {
+ if (S[i])
+ lx[i] -= delta;
+
+ if (T[i])
+ ly[i] += delta;
+ else
+ slack[i] -= delta;
+ }
+
+
+
+ q.clear();
+ for (long y = 0; y < cost.nc(); ++y)
+ {
+ if (!T[y] && slack[y] == 0)
+ {
+ // if vertex y isn't matched with anything
+ if (yx[y] == -1)
+ {
+ x_start = slackx[y];
+ y_start = y;
+ found_augmenting_path = true;
+ break;
+ }
+ else
+ {
+ T[y] = true;
+ if (!S[yx[y]])
+ {
+ q.push_back(yx[y]);
+
+ aug_path[yx[y]] = slackx[y];
+ S[yx[y]] = true;
+ compute_slack(yx[y], slack, slackx, cost, lx, ly);
+ }
+ }
+ }
+ }
+ } // end while (!found_augmenting_path)
+
+ // Flip the edges along the augmenting path. This means we will add one more
+ // item to our matching.
+ for (long cx = x_start, cy = y_start, ty;
+ cx != -1;
+ cx = aug_path[cx], cy = ty)
+ {
+ ty = xy[cx];
+ yx[cy] = cx;
+ xy[cx] = cy;
+ }
+
+ }
+
+
+ return xy;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAX_COST_ASSIgNMENT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h b/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h
new file mode 100644
index 000000000..bbdb0abfb
--- /dev/null
+++ b/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_
+#ifdef DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ typename EXP::type assignment_cost (
+ const matrix_exp<EXP>& cost,
+ const std::vector<long>& assignment
+ );
+ /*!
+ requires
+ - cost.nr() == cost.nc()
+ - for all valid i:
+ - 0 <= assignment[i] < cost.nr()
+ ensures
+ - Interprets cost as a cost assignment matrix. That is, cost(i,j)
+ represents the cost of assigning i to j.
+ - Interprets assignment as a particular set of assignments. That is,
+ i is assigned to assignment[i].
+ - returns the cost of the given assignment. That is, returns
+ a number which is:
+ sum over i: cost(i,assignment[i])
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ std::vector<long> max_cost_assignment (
+ const matrix_exp<EXP>& cost
+ );
+ /*!
+ requires
+ - EXP::type == some integer type (e.g. int)
+ (i.e. cost must contain integers rather than floats or doubles)
+ - cost.nr() == cost.nc()
+ ensures
+ - Finds and returns the solution to the following optimization problem:
+
+ Maximize: f(A) == assignment_cost(cost, A)
+ Subject to the following constraints:
+ - The elements of A are unique. That is, there aren't any
+ elements of A which are equal.
+ - A.size() == cost.nr()
+
+ - This function implements the O(N^3) version of the Hungarian algorithm
+ where N is the number of rows in the cost matrix.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/optimization/max_sum_submatrix.h b/ml/dlib/dlib/optimization/max_sum_submatrix.h
new file mode 100644
index 000000000..1986cc26b
--- /dev/null
+++ b/ml/dlib/dlib/optimization/max_sum_submatrix.h
@@ -0,0 +1,285 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MAX_SUM_SUBMaTRIX_Hh_
+#define DLIB_MAX_SUM_SUBMaTRIX_Hh_
+
+#include "max_sum_submatrix_abstract.h"
+#include "../matrix.h"
+#include <vector>
+#include <queue>
+#include "../geometry.h"
+
+namespace dlib
+{
+ namespace impl
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct range_set
+ {
+ int top_min;
+ int top_max;
+ int bottom_min;
+ int bottom_max;
+ T weight;
+
+ bool operator<(const range_set& item) const { return weight < item.weight; }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool is_terminal_set (
+ const range_set<T>& item
+ )
+ {
+ return (item.top_min >= item.top_max &&
+ item.bottom_min >= item.bottom_max);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ void split (
+ const range_set<T>& rset,
+ range_set<T>& a,
+ range_set<T>& b
+ )
+ {
+ if (rset.top_max - rset.top_min > rset.bottom_max - rset.bottom_min)
+ {
+ // split top
+ const int middle = (rset.top_max + rset.top_min)/2;
+ a.top_min = rset.top_min;
+ a.top_max = middle;
+ b.top_min = middle+1;
+ b.top_max = rset.top_max;
+
+ a.bottom_min = rset.bottom_min;
+ a.bottom_max = rset.bottom_max;
+ b.bottom_min = rset.bottom_min;
+ b.bottom_max = rset.bottom_max;
+ }
+ else
+ {
+ // split bottom
+ const int middle = (rset.bottom_max + rset.bottom_min)/2;
+ a.bottom_min = rset.bottom_min;
+ a.bottom_max = middle;
+ b.bottom_min = middle+1;
+ b.bottom_max = rset.bottom_max;
+
+ a.top_min = rset.top_min;
+ a.top_max = rset.top_max;
+ b.top_min = rset.top_min;
+ b.top_max = rset.top_max;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename EXP, typename T>
+ void find_best_column_range (
+ const matrix_exp<EXP>& sum_pos,
+ const matrix_exp<EXP>& sum_neg,
+ const range_set<T>& row_range,
+ T& weight,
+ int& left,
+ int& right
+ )
+ {
+ left = 0;
+ right = -1;
+ weight = 0;
+ T cur_sum = 0;
+ int cur_pos = 0;
+ for (long c = 0; c < sum_pos.nc(); ++c)
+ {
+ // compute the value for the current column
+ T temp = sum_pos(row_range.bottom_max+1,c) - sum_pos(row_range.top_min,c);
+ if (row_range.top_max <= row_range.bottom_min)
+ temp += sum_neg(row_range.bottom_min+1,c) - sum_neg(row_range.top_max,c);
+
+
+ cur_sum += temp;
+ if (cur_sum > weight)
+ {
+ left = cur_pos;
+ right = c;
+ weight = cur_sum;
+ }
+
+ if (cur_sum <= 0)
+ {
+ cur_sum = 0;
+ cur_pos = c+1;
+ }
+
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ std::vector<rectangle> max_sum_submatrix(
+ const matrix_exp<EXP>& mat,
+ unsigned long max_rects,
+ double thresh_ = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(thresh_ >= 0 && mat.size() > 0,
+ "\t std::vector<rectangle> max_sum_submatrix()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t mat.size(): " << mat.size()
+ << "\n\t thresh_: " << thresh_
+ );
+
+ /*
+ This function is basically an implementation of the efficient subwindow search (I-ESS)
+ algorithm presented in the following paper:
+ Efficient Algorithms for Subwindow Search in Object Detection and Localization
+ by Senjian An, Patrick Peursum, Wanquan Liu and Svetha Venkatesh
+ In CVPR 2009
+
+ */
+
+
+ if (max_rects == 0)
+ return std::vector<rectangle>();
+
+ using namespace dlib::impl;
+ typedef typename EXP::type element_type;
+ typedef typename promote<element_type>::type scalar_type;
+
+ const scalar_type thresh = static_cast<scalar_type>(thresh_);
+
+
+ matrix<scalar_type> sum_pos;
+ matrix<scalar_type> sum_neg;
+ sum_pos.set_size(mat.nr()+1, mat.nc());
+ sum_neg.set_size(mat.nr()+1, mat.nc());
+ // integrate over the rows.
+ for (long c = 0; c < mat.nc(); ++c)
+ {
+ sum_pos(0,c) = 0;
+ sum_neg(0,c) = 0;
+ }
+ for (long r = 0; r < mat.nr(); ++r)
+ {
+ for (long c = 0; c < mat.nc(); ++c)
+ {
+ if (mat(r,c) > 0)
+ {
+ sum_pos(r+1,c) = mat(r,c) + sum_pos(r,c);
+ sum_neg(r+1,c) = sum_neg(r,c);
+ }
+ else
+ {
+ sum_pos(r+1,c) = sum_pos(r,c);
+ sum_neg(r+1,c) = mat(r,c) + sum_neg(r,c);
+ }
+ }
+ }
+
+ std::priority_queue<range_set<scalar_type> > q;
+
+ // the range_sets will represent ranges of columns
+ range_set<scalar_type> universe_set;
+ universe_set.bottom_min = 0;
+ universe_set.top_min = 0;
+ universe_set.bottom_max = mat.nr()-1;
+ universe_set.top_max = mat.nr()-1;
+ universe_set.weight = sum(rowm(dlib::mat(sum_pos),mat.nr()));
+
+ q.push(universe_set);
+
+ std::vector<rectangle> results;
+ std::vector<scalar_type> temp_pos(mat.nc());
+ std::vector<scalar_type> temp_neg(mat.nc());
+
+ while (q.size() > 0)
+ {
+ if (is_terminal_set(q.top()))
+ {
+ int left, right;
+ scalar_type weight;
+ find_best_column_range(sum_pos, sum_neg, q.top(), weight, left, right);
+
+ rectangle rect(left, q.top().top_min,
+ right, q.top().bottom_min);
+
+ if (weight <= thresh)
+ break;
+
+ results.push_back(rect);
+
+ if (results.size() >= max_rects)
+ break;
+
+ q = std::priority_queue<range_set<scalar_type> >();
+ // We are going to blank out the weights we just used. So adjust the sum images appropriately.
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ temp_pos[c] = sum_pos(rect.bottom()+1,c) - sum_pos(rect.top(),c);
+ temp_neg[c] = sum_neg(rect.bottom()+1,c) - sum_neg(rect.top(),c);
+ }
+ // blank out the area inside the rectangle
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ sum_pos(r+1,c) = sum_pos(r,c);
+ sum_neg(r+1,c) = sum_neg(r,c);
+ }
+ }
+ // account for the area below the rectangle
+ for (long r = rect.bottom()+2; r < sum_pos.nr(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ sum_pos(r,c) -= temp_pos[c];
+ sum_neg(r,c) -= temp_neg[c];
+ }
+ }
+
+
+ universe_set.weight = sum(rowm(dlib::mat(sum_pos),mat.nr()));
+ if (universe_set.weight <= thresh)
+ break;
+
+ q.push(universe_set);
+ continue;
+ }
+
+ range_set<scalar_type> a, b;
+ split(q.top(), a,b);
+ q.pop();
+
+ // these variables are not used at this point in the algorithm.
+ int a_left, a_right;
+ int b_left, b_right;
+
+ find_best_column_range(sum_pos, sum_neg, a, a.weight, a_left, a_right);
+ find_best_column_range(sum_pos, sum_neg, b, b.weight, b_left, b_right);
+
+ if (a.weight > thresh)
+ q.push(a);
+ if (b.weight > thresh)
+ q.push(b);
+
+ }
+
+
+ return results;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAX_SUM_SUBMaTRIX_Hh_
+
diff --git a/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h b/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h
new file mode 100644
index 000000000..6714dd7fe
--- /dev/null
+++ b/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h
@@ -0,0 +1,49 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_
+#ifdef DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include <vector>
+#include "../geometry.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ std::vector<rectangle> max_sum_submatrix(
+ const matrix_exp<EXP>& mat,
+ unsigned long max_rects,
+ double thresh = 0
+ );
+ /*!
+ requires
+ - thresh >= 0
+ - mat.size() != 0
+ ensures
+ - This function finds the submatrix within mat which has the largest sum. It then
+ zeros out that submatrix and repeats the process until no more maximal submatrices can
+ be found. The std::vector returned will be ordered so that the rectangles with the
+ largest sum come first.
+ - Each submatrix must have a sum greater than thresh. If no such submatrix exists then
+ the algorithm terminates and returns an empty std::vector.
+ - At most max_rects rectangles are returned.
+
+ - This function is basically an implementation of the efficient subwindow search (I-ESS)
+ algorithm presented in the following paper:
+ Efficient Algorithms for Subwindow Search in Object Detection and Localization
+ by Senjian An, Patrick Peursum, Wanquan Liu and Svetha Venkatesh
+ In CVPR 2009
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization.h b/ml/dlib/dlib/optimization/optimization.h
new file mode 100644
index 000000000..561d64376
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization.h
@@ -0,0 +1,714 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_H_
+#define DLIB_OPTIMIZATIOn_H_
+
+#include <cmath>
+#include <limits>
+#include "optimization_abstract.h"
+#include "optimization_search_strategies.h"
+#include "optimization_stop_strategies.h"
+#include "optimization_line_search.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions that transform other functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ class central_differences
+ {
+ public:
+ central_differences(const funct& f_, double eps_ = 1e-7) : f(f_), eps(eps_){}
+
+ template <typename T>
+ typename T::matrix_type operator()(const T& x) const
+ {
+ // T must be some sort of dlib matrix
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+
+ typename T::matrix_type der(x.size());
+ typename T::matrix_type e(x);
+ for (long i = 0; i < x.size(); ++i)
+ {
+ const double old_val = e(i);
+
+ e(i) += eps;
+ const double delta_plus = f(e);
+ e(i) = old_val - eps;
+ const double delta_minus = f(e);
+
+ der(i) = (delta_plus - delta_minus)/((old_val+eps)-(old_val-eps));
+
+ // and finally restore the old value of this element
+ e(i) = old_val;
+ }
+
+ return der;
+ }
+
+ template <typename T, typename U>
+ typename U::matrix_type operator()(const T& item, const U& x) const
+ {
+ // U must be some sort of dlib matrix
+ COMPILE_TIME_ASSERT(is_matrix<U>::value);
+
+ typename U::matrix_type der(x.size());
+ typename U::matrix_type e(x);
+ for (long i = 0; i < x.size(); ++i)
+ {
+ const double old_val = e(i);
+
+ e(i) += eps;
+ const double delta_plus = f(item,e);
+ e(i) = old_val - eps;
+ const double delta_minus = f(item,e);
+
+ der(i) = (delta_plus - delta_minus)/((old_val+eps)-(old_val-eps));
+
+ // and finally restore the old value of this element
+ e(i) = old_val;
+ }
+
+ return der;
+ }
+
+
+ double operator()(const double& x) const
+ {
+ return (f(x+eps)-f(x-eps))/((x+eps)-(x-eps));
+ }
+
+ private:
+ const funct& f;
+ const double eps;
+ };
+
+ template <typename funct>
+ const central_differences<funct> derivative(const funct& f) { return central_differences<funct>(f); }
+ template <typename funct>
+ const central_differences<funct> derivative(const funct& f, double eps)
+ {
+ DLIB_ASSERT (
+ eps > 0,
+ "\tcentral_differences derivative(f,eps)"
+ << "\n\tYou must give an epsilon > 0"
+ << "\n\teps: " << eps
+ );
+ return central_differences<funct>(f,eps);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct, typename EXP1, typename EXP2>
+ struct clamped_function_object
+ {
+ clamped_function_object(
+ const funct& f_,
+ const matrix_exp<EXP1>& x_lower_,
+ const matrix_exp<EXP2>& x_upper_
+ ) : f(f_), x_lower(x_lower_), x_upper(x_upper_)
+ {
+ }
+
+ template <typename T>
+ double operator() (
+ const T& x
+ ) const
+ {
+ return f(clamp(x,x_lower,x_upper));
+ }
+
+ const funct& f;
+ const matrix_exp<EXP1>& x_lower;
+ const matrix_exp<EXP2>& x_upper;
+ };
+
+ template <typename funct, typename EXP1, typename EXP2>
+ clamped_function_object<funct,EXP1,EXP2> clamp_function(
+ const funct& f,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ ) { return clamped_function_object<funct,EXP1,EXP2>(f,x_lower,x_upper); }
+
+// ----------------------------------------------------------------------------------------
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions that perform unconstrained optimization
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_min (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double min_f
+ )
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x),
+ "\tdouble find_min()"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tx.nc(): " << x.nc()
+ );
+
+
+ T g, s;
+
+ double f_value = f(x);
+ g = der(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+
+ while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f)
+ {
+ s = search_strategy.get_next_direction(x, f_value, g);
+
+ double alpha = line_search(
+ make_line_search_function(f,x,s, f_value),
+ f_value,
+ make_line_search_function(der,x,s, g),
+ dot(g,s), // compute initial gradient for the line search
+ search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f,
+ search_strategy.get_max_line_search_iterations());
+
+ // Take the search step indicated by the above line search
+ x += alpha*s;
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+ }
+
+ return f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_max (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double max_f
+ )
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x),
+ "\tdouble find_max()"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tx.nc(): " << x.nc()
+ );
+
+ T g, s;
+
+ // This function is basically just a copy of find_min() but with - put in the right places
+ // to flip things around so that it ends up looking for the max rather than the min.
+
+ double f_value = -f(x);
+ g = -der(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+
+ while(stop_strategy.should_continue_search(x, f_value, g) && f_value > -max_f)
+ {
+ s = search_strategy.get_next_direction(x, f_value, g);
+
+ double alpha = line_search(
+ negate_function(make_line_search_function(f,x,s, f_value)),
+ f_value,
+ negate_function(make_line_search_function(der,x,s, g)),
+ dot(g,s), // compute initial gradient for the line search
+ search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), -max_f,
+ search_strategy.get_max_line_search_iterations()
+ );
+
+ // Take the search step indicated by the above line search
+ x += alpha*s;
+
+ // Don't forget to negate these outputs from the line search since they are
+ // from the unnegated versions of f() and der()
+ g *= -1;
+ f_value *= -1;
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+ }
+
+ return -f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename T
+ >
+ double find_min_using_approximate_derivatives (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ T& x,
+ double min_f,
+ double derivative_eps = 1e-7
+ )
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x) && derivative_eps > 0,
+ "\tdouble find_min_using_approximate_derivatives()"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tx.nc(): " << x.nc()
+ << "\n\tderivative_eps: " << derivative_eps
+ );
+
+ T g, s;
+
+ double f_value = f(x);
+ g = derivative(f,derivative_eps)(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+
+ while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f)
+ {
+ s = search_strategy.get_next_direction(x, f_value, g);
+
+ double alpha = line_search(
+ make_line_search_function(f,x,s,f_value),
+ f_value,
+ derivative(make_line_search_function(f,x,s),derivative_eps),
+ dot(g,s), // Sometimes the following line is a better way of determining the initial gradient.
+ //derivative(make_line_search_function(f,x,s),derivative_eps)(0),
+ search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f,
+ search_strategy.get_max_line_search_iterations()
+ );
+
+ // Take the search step indicated by the above line search
+ x += alpha*s;
+
+ g = derivative(f,derivative_eps)(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+ }
+
+ return f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename T
+ >
+ double find_max_using_approximate_derivatives (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ T& x,
+ double max_f,
+ double derivative_eps = 1e-7
+ )
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x) && derivative_eps > 0,
+ "\tdouble find_max_using_approximate_derivatives()"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tx.nc(): " << x.nc()
+ << "\n\tderivative_eps: " << derivative_eps
+ );
+
+ // Just negate the necessary things and call the find_min version of this function.
+ return -find_min_using_approximate_derivatives(
+ search_strategy,
+ stop_strategy,
+ negate_function(f),
+ x,
+ -max_f,
+ derivative_eps
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions for box constrained optimization
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V>
+ T zero_bounded_variables (
+ const double eps,
+ T vect,
+ const T& x,
+ const T& gradient,
+ const U& x_lower,
+ const V& x_upper
+ )
+ {
+ for (long i = 0; i < gradient.size(); ++i)
+ {
+ const double tol = eps*std::abs(x(i));
+ // if x(i) is an active bound constraint
+ if (x_lower(i)+tol >= x(i) && gradient(i) > 0)
+ vect(i) = 0;
+ else if (x_upper(i)-tol <= x(i) && gradient(i) < 0)
+ vect(i) = 0;
+ }
+ return vect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V>
+ T gap_step_assign_bounded_variables (
+ const double eps,
+ T vect,
+ const T& x,
+ const T& gradient,
+ const U& x_lower,
+ const V& x_upper
+ )
+ {
+ for (long i = 0; i < gradient.size(); ++i)
+ {
+ const double tol = eps*std::abs(x(i));
+ // If x(i) is an active bound constraint then we should set its search
+ // direction such that a single step along the direction either does nothing or
+ // closes the gap of size tol before hitting the bound exactly.
+ if (x_lower(i)+tol >= x(i) && gradient(i) > 0)
+ vect(i) = x_lower(i)-x(i);
+ else if (x_upper(i)-tol <= x(i) && gradient(i) < 0)
+ vect(i) = x_upper(i)-x(i);
+ }
+ return vect;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T,
+ typename EXP1,
+ typename EXP2
+ >
+ double find_min_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ )
+ {
+ /*
+ The implementation of this function is more or less based on the discussion in
+ the paper Projected Newton-type Methods in Machine Learning by Mark Schmidt, et al.
+ */
+
+ // make sure the requires clause is not violated
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) &&
+ x.size() == x_lower.size() && x.size() == x_upper.size(),
+ "\tdouble find_min_box_constrained()"
+ << "\n\t The inputs to this function must be equal length column vectors."
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper)
+ << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t x_lower.size(): " << x_lower.size()
+ << "\n\t x_upper.size(): " << x_upper.size()
+ );
+ DLIB_ASSERT (
+ min(x_upper-x_lower) >= 0,
+ "\tdouble find_min_box_constrained()"
+ << "\n\t You have to supply proper box constraints to this function."
+ << "\n\r min(x_upper-x_lower): " << min(x_upper-x_lower)
+ );
+
+
+ T g, s;
+ double f_value = f(x);
+ g = der(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+
+ // gap_eps determines how close we have to get to a bound constraint before we
+ // start basically dropping it from the optimization and consider it to be an
+ // active constraint.
+ const double gap_eps = 1e-8;
+
+ double last_alpha = 1;
+ while(stop_strategy.should_continue_search(x, f_value, g))
+ {
+ s = search_strategy.get_next_direction(x, f_value, zero_bounded_variables(gap_eps, g, x, g, x_lower, x_upper));
+ s = gap_step_assign_bounded_variables(gap_eps, s, x, g, x_lower, x_upper);
+
+ double alpha = backtracking_line_search(
+ make_line_search_function(clamp_function(f,x_lower,x_upper), x, s, f_value),
+ f_value,
+ dot(g,s), // compute gradient for the line search
+ last_alpha,
+ search_strategy.get_wolfe_rho(),
+ search_strategy.get_max_line_search_iterations());
+
+ // Do a trust region style thing for alpha. The idea is that if we take a
+ // small step then we are likely to take another small step. So we reuse the
+ // alpha from the last iteration unless the line search didn't shrink alpha at
+ // all, in that case, we start with a bigger alpha next time.
+ if (alpha == last_alpha)
+ last_alpha = std::min(last_alpha*10,1.0);
+ else
+ last_alpha = alpha;
+
+ // Take the search step indicated by the above line search
+ x = dlib::clamp(x + alpha*s, x_lower, x_upper);
+ g = der(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+ }
+
+ return f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_min_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double x_lower,
+ double x_upper
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ typedef typename T::type scalar_type;
+ return find_min_box_constrained(search_strategy,
+ stop_strategy,
+ f,
+ der,
+ x,
+ uniform_matrix<scalar_type>(x.size(),1,x_lower),
+ uniform_matrix<scalar_type>(x.size(),1,x_upper) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T,
+ typename EXP1,
+ typename EXP2
+ >
+ double find_max_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ )
+ {
+ // make sure the requires clause is not violated
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ DLIB_CASSERT (
+ is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) &&
+ x.size() == x_lower.size() && x.size() == x_upper.size(),
+ "\tdouble find_max_box_constrained()"
+ << "\n\t The inputs to this function must be equal length column vectors."
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper)
+ << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t x_lower.size(): " << x_lower.size()
+ << "\n\t x_upper.size(): " << x_upper.size()
+ );
+ DLIB_ASSERT (
+ min(x_upper-x_lower) >= 0,
+ "\tdouble find_max_box_constrained()"
+ << "\n\t You have to supply proper box constraints to this function."
+ << "\n\r min(x_upper-x_lower): " << min(x_upper-x_lower)
+ );
+
+ // This function is basically just a copy of find_min_box_constrained() but with - put
+ // in the right places to flip things around so that it ends up looking for the max
+ // rather than the min.
+
+ T g, s;
+ double f_value = -f(x);
+ g = -der(x);
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+
+ // gap_eps determines how close we have to get to a bound constraint before we
+ // start basically dropping it from the optimization and consider it to be an
+ // active constraint.
+ const double gap_eps = 1e-8;
+
+ double last_alpha = 1;
+ while(stop_strategy.should_continue_search(x, f_value, g))
+ {
+ s = search_strategy.get_next_direction(x, f_value, zero_bounded_variables(gap_eps, g, x, g, x_lower, x_upper));
+ s = gap_step_assign_bounded_variables(gap_eps, s, x, g, x_lower, x_upper);
+
+ double alpha = backtracking_line_search(
+ negate_function(make_line_search_function(clamp_function(f,x_lower,x_upper), x, s, f_value)),
+ f_value,
+ dot(g,s), // compute gradient for the line search
+ last_alpha,
+ search_strategy.get_wolfe_rho(),
+ search_strategy.get_max_line_search_iterations());
+
+ // Do a trust region style thing for alpha. The idea is that if we take a
+ // small step then we are likely to take another small step. So we reuse the
+ // alpha from the last iteration unless the line search didn't shrink alpha at
+ // all, in that case, we start with a bigger alpha next time.
+ if (alpha == last_alpha)
+ last_alpha = std::min(last_alpha*10,1.0);
+ else
+ last_alpha = alpha;
+
+ // Take the search step indicated by the above line search
+ x = dlib::clamp(x + alpha*s, x_lower, x_upper);
+ g = -der(x);
+
+ // Don't forget to negate the output from the line search since it is from the
+ // unnegated version of f()
+ f_value *= -1;
+
+ if (!is_finite(f_value))
+ throw error("The objective function generated non-finite outputs");
+ if (!is_finite(g))
+ throw error("The objective function generated non-finite outputs");
+ }
+
+ return -f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_max_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double x_lower,
+ double x_upper
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ typedef typename T::type scalar_type;
+ return find_max_box_constrained(search_strategy,
+ stop_strategy,
+ f,
+ der,
+ x,
+ uniform_matrix<scalar_type>(x.size(),1,x_lower),
+ uniform_matrix<scalar_type>(x.size(),1,x_upper) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_H_
+
diff --git a/ml/dlib/dlib/optimization/optimization_abstract.h b/ml/dlib/dlib/optimization/optimization_abstract.h
new file mode 100644
index 000000000..f3c42c2b4
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_abstract.h
@@ -0,0 +1,468 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_ABSTRACT_
+#ifdef DLIB_OPTIMIZATIOn_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "optimization_search_strategies_abstract.h"
+#include "optimization_stop_strategies_abstract.h"
+#include "optimization_line_search_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions that transform other functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ class central_differences;
+ /*!
+ This is a function object that represents the derivative of some other
+ function.
+
+ Note that if funct is a function of a double then the derivative of
+ funct is just a double but if funct is a function of a dlib::matrix (i.e. a
+ function of many variables) then its derivative is a gradient vector (a column
+ vector in particular).
+ !*/
+
+ template <
+ typename funct
+ >
+ const central_differences<funct> derivative(
+ const funct& f,
+ double eps
+ );
+ /*!
+ requires
+ - f == a function that returns a scalar
+ - f must have one of the following forms:
+ - double f(double)
+ - double f(dlib::matrix) (where the matrix is a column vector)
+ - double f(T, dlib::matrix) (where the matrix is a column vector. In
+ this case the derivative of f is taken with respect to the second argument.)
+ - eps > 0
+ ensures
+ - returns a function that represents the derivative of the function f. It
+ is approximated numerically by:
+ (f(x+eps)-f(x-eps))/(2*eps)
+ !*/
+
+ template <
+ typename funct
+ >
+ const central_differences<funct> derivative(
+ const funct& f
+ );
+ /*!
+ ensures
+ - returns derivative(f, 1e-7)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename EXP1,
+ typename EXP2
+ >
+ clamped_function_object<funct,EXP1,EXP2> clamp_function (
+ const funct& f,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ );
+ /*!
+ requires
+ - f == a function that takes a matrix and returns a scalar value. Moreover, f
+ must be capable of taking in matrices with the same dimensions as x_lower and
+ x_upper. So f(x_lower) must be a valid expression that evaluates to a scalar
+ value.
+ - x_lower.nr() == x_upper.nr() && x_lower.nc() == x_upper.nc()
+ (i.e. x_lower and x_upper must have the same dimensions)
+ - x_lower and x_upper must contain the same type of elements.
+ ensures
+ - returns a function object that represents the function g(x) where
+ g(x) == f(clamp(x,x_lower,x_upper))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions that perform unconstrained optimization
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_min (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double min_f
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ ensures
+ - Performs an unconstrained minimization of the function f() using the given
+ search_strategy and starting from the initial point x.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or f(#x) < min_f.
+ - #x == the value of x that was found to minimize f()
+ - returns f(#x).
+ - When this function makes calls to f() and der() it always does so by
+ first calling f() and then calling der(). That is, these two functions
+ are always called in pairs with f() being called first and then der()
+ being called second.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_max (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ double max_f
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ ensures
+ - Performs an unconstrained maximization of the function f() using the given
+ search_strategy and starting from the initial point x.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or f(#x) > max_f.
+ - #x == the value of x that was found to maximize f()
+ - returns f(#x).
+ - When this function makes calls to f() and der() it always does so by
+ first calling f() and then calling der(). That is, these two functions
+ are always called in pairs with f() being called first and then der()
+ being called second.
+ - Note that this function solves the maximization problem by converting it
+ into a minimization problem. Therefore, the values of f and its derivative
+ reported to the stopping strategy will be negated. That is, stop_strategy
+ will see -f() and -der(). All this really means is that the status messages
+ from a stopping strategy in verbose mode will display a negated objective
+ value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename T
+ >
+ double find_min_using_approximate_derivatives (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ T& x,
+ double min_f,
+ double derivative_eps = 1e-7
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - is_col_vector(x) == true
+ - derivative_eps > 0
+ ensures
+ - Performs an unconstrained minimization of the function f() using the given
+ search_strategy and starting from the initial point x.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or f(#x) < min_f.
+ - #x == the value of x that was found to minimize f()
+ - returns f(#x).
+ - Uses the dlib::derivative(f,derivative_eps) function to compute gradient
+ information.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename T
+ >
+ double find_max_using_approximate_derivatives (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ T& x,
+ double max_f,
+ double derivative_eps = 1e-7
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - is_col_vector(x) == true
+ - derivative_eps > 0
+ ensures
+ - Performs an unconstrained maximization of the function f() using the given
+ search_strategy and starting from the initial point x.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or f(#x) > max_f.
+ - #x == the value of x that was found to maximize f()
+ - returns f(#x).
+ - Uses the dlib::derivative(f,derivative_eps) function to compute gradient
+ information.
+ - Note that this function solves the maximization problem by converting it
+ into a minimization problem. Therefore, the values of f and its derivative
+ reported to the stopping strategy will be negated. That is, stop_strategy
+ will see -f() and -der(). All this really means is that the status messages
+ from a stopping strategy in verbose mode will display a negated objective
+ value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Functions that perform box constrained optimization
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T,
+ typename EXP1,
+ typename EXP2
+ >
+ double find_min_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ - is_col_vector(x_lower) == true
+ - is_col_vector(x_upper) == true
+ - x.size() == x_lower.size() == x_upper.size()
+ (i.e. x, x_lower, and x_upper need to all be column vectors of the same dimensionality)
+ - min(x_upper-x_lower) >= 0
+ (i.e. x_upper must contain upper bounds relative to x_lower)
+ ensures
+ - Performs a box constrained minimization of the function f() using the given
+ search_strategy and starting from the initial point x. That is, we try to
+ find the x value that minimizes f(x) but is also within the box constraints
+ specified by x_lower and x_upper. That is, we ensure that #x satisfies:
+ - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0
+ - This function uses a backtracking line search along with a gradient projection
+ step to handle the box constraints.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found.
+ - #x == the value of x that was found to minimize f() within the given box
+ constraints.
+ - returns f(#x).
+ - The last call to f() will be made with f(#x).
+ - When calling f() and der(), the input passed to them will always be inside
+ the box constraints defined by x_lower and x_upper.
+ - When calling der(x), it will always be the case that the last call to f() was
+ made with the same x value. This means that you can reuse any intermediate
+ results from the previous call to f(x) inside der(x) rather than recomputing
+ them inside der(x).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_min_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const double x_lower,
+ const double x_upper
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ - x_lower < x_upper
+ ensures
+ - This function is identical to find_min_box_constrained() as defined above
+ except that it takes x_lower and x_upper as doubles rather than column
+ vectors. In this case, all variables have the same lower bound of x_lower
+ and similarly have the same upper bound of x_upper. Therefore, this is just
+ a convenience function for calling find_max_box_constrained() when all
+ variables have the same bound constraints.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T,
+ typename EXP1,
+ typename EXP2
+ >
+ double find_max_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const matrix_exp<EXP1>& x_lower,
+ const matrix_exp<EXP2>& x_upper
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ - is_col_vector(x_lower) == true
+ - is_col_vector(x_upper) == true
+ - x.size() == x_lower.size() == x_upper.size()
+ (i.e. x, x_lower, and x_upper need to all be column vectors of the same dimensionality)
+ - min(x_upper-x_lower) >= 0
+ (i.e. x_upper must contain upper bounds relative to x_lower)
+ ensures
+ - Performs a box constrained maximization of the function f() using the given
+ search_strategy and starting from the initial point x. That is, we try to
+ find the x value that maximizes f(x) but is also within the box constraints
+ specified by x_lower and x_upper. That is, we ensure that #x satisfies:
+ - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0
+ - This function uses a backtracking line search along with a gradient projection
+ step to handle the box constraints.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found.
+ - #x == the value of x that was found to maximize f() within the given box
+ constraints.
+ - returns f(#x).
+ - The last call to f() will be made with f(#x).
+ - When calling f() and der(), the input passed to them will always be inside
+ the box constraints defined by x_lower and x_upper.
+ - When calling der(x), it will always be the case that the last call to f() was
+ made with the same x value. This means that you can reuse any intermediate
+ results from the previous call to f(x) inside der(x) rather than recomputing
+ them inside der(x).
+ - Note that this function solves the maximization problem by converting it
+ into a minimization problem. Therefore, the values of f and its derivative
+ reported to the stopping strategy will be negated. That is, stop_strategy
+ will see -f() and -der(). All this really means is that the status messages
+ from a stopping strategy in verbose mode will display a negated objective
+ value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename search_strategy_type,
+ typename stop_strategy_type,
+ typename funct,
+ typename funct_der,
+ typename T
+ >
+ double find_max_box_constrained (
+ search_strategy_type search_strategy,
+ stop_strategy_type stop_strategy,
+ const funct& f,
+ const funct_der& der,
+ T& x,
+ const double x_lower,
+ const double x_upper
+ );
+ /*!
+ requires
+ - search_strategy == an object that defines a search strategy such as one
+ of the objects from dlib/optimization/optimization_search_strategies_abstract.h
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - f(x) must be a valid expression that evaluates to a double
+ - der(x) must be a valid expression that evaluates to the derivative of f() at x.
+ - is_col_vector(x) == true
+ - x_lower < x_upper
+ ensures
+ - This function is identical to find_max_box_constrained() as defined above
+ except that it takes x_lower and x_upper as doubles rather than column
+ vectors. In this case, all variables have the same lower bound of x_lower
+ and similarly have the same upper bound of x_upper. Therefore, this is just
+ a convenience function for calling find_max_box_constrained() when all
+ variables have the same bound constraints.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_bobyqa.h b/ml/dlib/dlib/optimization/optimization_bobyqa.h
new file mode 100644
index 000000000..6fbc40c06
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_bobyqa.h
@@ -0,0 +1,3423 @@
+// Copyright (C) 2009 M.J.D. Powell, Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_BOBYQA_Hh_
+#define DLIB_OPTIMIZATIOn_BOBYQA_Hh_
+
+/*
+ The code in this file is derived from Powell's BOBYQA Fortran code.
+ It was created by running f2c on the original Fortran code and then
+ massaging the resulting C code into what you can see below.
+
+
+ The following paper, published in 2009 by Powell, describes the
+ detailed workings of the BOBYQA algorithm.
+
+ The BOBYQA algorithm for bound constrained optimization
+ without derivatives by M.J.D. Powell
+*/
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "../matrix.h"
+#include "optimization_bobyqa_abstract.h"
+#include "optimization.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ class bobyqa_failure : public error {
+ public: bobyqa_failure(const std::string& s):error(s){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bobyqa_implementation
+ {
+ typedef long integer;
+ typedef double doublereal;
+
+ public:
+
+ template <
+ typename funct,
+ typename T,
+ typename U
+ >
+ double find_min (
+ const funct& f,
+ T& x,
+ long npt,
+ const U& xl_,
+ const U& xu_,
+ const double rhobeg,
+ const double rhoend,
+ const long max_f_evals
+ ) const
+ {
+ const unsigned long n = x.size();
+ const unsigned long w_size = (npt+5)*(npt+n)+3*n*(n+5)/2;
+ std::unique_ptr<doublereal[]> w(new doublereal[w_size]);
+
+ // make these temporary matrices becuse U might be some
+ // kind of matrix_exp that doesn't support taking the address
+ // of the first element.
+ matrix<double,0,1> xl(xl_);
+ matrix<double,0,1> xu(xu_);
+
+
+ return bobyqa_ (f,
+ x.size(),
+ npt,
+ &x(0),
+ &xl(0),
+ &xu(0),
+ rhobeg,
+ rhoend,
+ max_f_evals,
+ w.get() );
+ }
+
+ private:
+
+
+ template <typename funct>
+ doublereal bobyqa_(
+ const funct& calfun,
+ const integer n,
+ const integer npt,
+ doublereal *x,
+ const doublereal *xl,
+ const doublereal *xu,
+ const doublereal rhobeg,
+ const doublereal rhoend,
+ const integer maxfun,
+ doublereal *w
+ ) const
+ {
+
+ /* System generated locals */
+ integer i__1;
+ doublereal d__1, d__2;
+
+ /* Local variables */
+ integer j, id_, np, iw, igo, ihq, ixb, ixa, ifv, isl, jsl, ipq, ivl, ixn, ixo, ixp, isu, jsu, ndim;
+ doublereal temp, zero;
+ integer ibmat, izmat;
+
+
+ /* This subroutine seeks the least value of a function of many variables, */
+ /* by applying a trust region method that forms quadratic models by */
+ /* interpolation. There is usually some freedom in the interpolation */
+ /* conditions, which is taken up by minimizing the Frobenius norm of */
+ /* the change to the second derivative of the model, beginning with the */
+ /* zero matrix. The values of the variables are constrained by upper and */
+ /* lower bounds. The arguments of the subroutine are as follows. */
+
+ /* N must be set to the number of variables and must be at least two. */
+ /* NPT is the number of interpolation conditions. Its value must be in */
+ /* the interval [N+2,(N+1)(N+2)/2]. Choices that exceed 2*N+1 are not */
+ /* recommended. */
+ /* Initial values of the variables must be set in X(1),X(2),...,X(N). They */
+ /* will be changed to the values that give the least calculated F. */
+ /* For I=1,2,...,N, XL(I) and XU(I) must provide the lower and upper */
+ /* bounds, respectively, on X(I). The construction of quadratic models */
+ /* requires XL(I) to be strictly less than XU(I) for each I. Further, */
+ /* the contribution to a model from changes to the I-th variable is */
+ /* damaged severely by rounding errors if XU(I)-XL(I) is too small. */
+ /* RHOBEG and RHOEND must be set to the initial and final values of a trust */
+ /* region radius, so both must be positive with RHOEND no greater than */
+ /* RHOBEG. Typically, RHOBEG should be about one tenth of the greatest */
+ /* expected change to a variable, while RHOEND should indicate the */
+ /* accuracy that is required in the final values of the variables. An */
+ /* error return occurs if any of the differences XU(I)-XL(I), I=1,...,N, */
+ /* is less than 2*RHOBEG. */
+ /* MAXFUN must be set to an upper bound on the number of calls of CALFUN. */
+ /* The array W will be used for working space. Its length must be at least */
+ /* (NPT+5)*(NPT+N)+3*N*(N+5)/2. */
+
+ /* Parameter adjustments */
+ --w;
+ --xu;
+ --xl;
+ --x;
+
+ /* Function Body */
+ np = n + 1;
+
+ /* Return if the value of NPT is unacceptable. */
+ if (npt < n + 2 || npt > (n + 2) * np / 2) {
+ throw bobyqa_failure("Return from BOBYQA because NPT is not in the required interval");
+ //goto L40;
+ }
+
+ /* Partition the working space array, so that different parts of it can */
+ /* be treated separately during the calculation of BOBYQB. The partition */
+ /* requires the first (NPT+2)*(NPT+N)+3*N*(N+5)/2 elements of W plus the */
+ /* space that is taken by the last array in the argument list of BOBYQB. */
+
+ ndim = npt + n;
+ ixb = 1;
+ ixp = ixb + n;
+ ifv = ixp + n * npt;
+ ixo = ifv + npt;
+ igo = ixo + n;
+ ihq = igo + n;
+ ipq = ihq + n * np / 2;
+ ibmat = ipq + npt;
+ izmat = ibmat + ndim * n;
+ isl = izmat + npt * (npt - np);
+ isu = isl + n;
+ ixn = isu + n;
+ ixa = ixn + n;
+ id_ = ixa + n;
+ ivl = id_ + n;
+ iw = ivl + ndim;
+
+ /* Return if there is insufficient space between the bounds. Modify the */
+ /* initial X if necessary in order to avoid conflicts between the bounds */
+ /* and the construction of the first quadratic model. The lower and upper */
+ /* bounds on moves from the updated X are set now, in the ISL and ISU */
+ /* partitions of W, in order to provide useful and exact information about */
+ /* components of X that become within distance RHOBEG from their bounds. */
+
+ zero = 0.;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ temp = xu[j] - xl[j];
+ if (temp < rhobeg + rhobeg) {
+ throw bobyqa_failure("Return from BOBYQA because one of the differences in x_lower and x_upper is less than 2*rho_begin");
+ //goto L40;
+ }
+ jsl = isl + j - 1;
+ jsu = jsl + n;
+ w[jsl] = xl[j] - x[j];
+ w[jsu] = xu[j] - x[j];
+ if (w[jsl] >= -(rhobeg)) {
+ if (w[jsl] >= zero) {
+ x[j] = xl[j];
+ w[jsl] = zero;
+ w[jsu] = temp;
+ } else {
+ x[j] = xl[j] + rhobeg;
+ w[jsl] = -(rhobeg);
+ /* Computing MAX */
+ d__1 = xu[j] - x[j];
+ w[jsu] = std::max(d__1,rhobeg);
+ }
+ } else if (w[jsu] <= rhobeg) {
+ if (w[jsu] <= zero) {
+ x[j] = xu[j];
+ w[jsl] = -temp;
+ w[jsu] = zero;
+ } else {
+ x[j] = xu[j] - rhobeg;
+ /* Computing MIN */
+ d__1 = xl[j] - x[j], d__2 = -(rhobeg);
+ w[jsl] = std::min(d__1,d__2);
+ w[jsu] = rhobeg;
+ }
+ }
+ /* L30: */
+ }
+
+ /* Make the call of BOBYQB. */
+
+ return bobyqb_(calfun, n, npt, &x[1], &xl[1], &xu[1], rhobeg, rhoend, maxfun, &w[
+ ixb], &w[ixp], &w[ifv], &w[ixo], &w[igo], &w[ihq], &w[ipq], &w[
+ ibmat], &w[izmat], ndim, &w[isl], &w[isu], &w[ixn], &w[ixa], &w[
+ id_], &w[ivl], &w[iw]);
+ //L40:
+ ;
+ } /* bobyqa_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ doublereal bobyqb_(
+ const funct& calfun,
+ const integer n,
+ const integer npt,
+ doublereal *x,
+ const doublereal *xl,
+ const doublereal *xu,
+ const doublereal rhobeg,
+ const doublereal rhoend,
+ const integer maxfun,
+ doublereal *xbase,
+ doublereal *xpt,
+ doublereal *fval,
+ doublereal *xopt,
+ doublereal *gopt,
+ doublereal *hq,
+ doublereal *pq,
+ doublereal *bmat,
+ doublereal *zmat,
+ const integer ndim,
+ doublereal *sl,
+ doublereal *su,
+ doublereal *xnew,
+ doublereal *xalt,
+ doublereal *d__,
+ doublereal *vlag,
+ doublereal *w
+ ) const
+ {
+ /* System generated locals */
+ integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1,
+ zmat_offset, i__1, i__2, i__3;
+ doublereal d__1, d__2, d__3, d__4;
+
+ /* Local variables */
+ doublereal f = 0;
+ integer i__, j, k, ih, nf, jj, nh, ip, jp;
+ doublereal dx;
+ integer np;
+ doublereal den = 0, one = 0, ten = 0, dsq = 0, rho = 0, sum = 0, two = 0, diff = 0, half = 0, beta = 0, gisq = 0;
+ integer knew = 0;
+ doublereal temp, suma, sumb, bsum, fopt;
+ integer kopt = 0, nptm;
+ doublereal zero, curv;
+ integer ksav;
+ doublereal gqsq = 0, dist = 0, sumw = 0, sumz = 0, diffa = 0, diffb = 0, diffc = 0, hdiag = 0;
+ integer kbase;
+ doublereal alpha = 0, delta = 0, adelt = 0, denom = 0, fsave = 0, bdtol = 0, delsq = 0;
+ integer nresc, nfsav;
+ doublereal ratio = 0, dnorm = 0, vquad = 0, pqold = 0, tenth = 0;
+ integer itest;
+ doublereal sumpq, scaden;
+ doublereal errbig, cauchy, fracsq, biglsq, densav;
+ doublereal bdtest;
+ doublereal crvmin, frhosq;
+ doublereal distsq;
+ integer ntrits;
+ doublereal xoptsq;
+
+
+
+ /* The arguments N, NPT, X, XL, XU, RHOBEG, RHOEND, IPRINT and MAXFUN */
+ /* are identical to the corresponding arguments in SUBROUTINE BOBYQA. */
+ /* XBASE holds a shift of origin that should reduce the contributions */
+ /* from rounding errors to values of the model and Lagrange functions. */
+ /* XPT is a two-dimensional array that holds the coordinates of the */
+ /* interpolation points relative to XBASE. */
+ /* FVAL holds the values of F at the interpolation points. */
+ /* XOPT is set to the displacement from XBASE of the trust region centre. */
+ /* GOPT holds the gradient of the quadratic model at XBASE+XOPT. */
+ /* HQ holds the explicit second derivatives of the quadratic model. */
+ /* PQ contains the parameters of the implicit second derivatives of the */
+ /* quadratic model. */
+ /* BMAT holds the last N columns of H. */
+ /* ZMAT holds the factorization of the leading NPT by NPT submatrix of H, */
+ /* this factorization being ZMAT times ZMAT^T, which provides both the */
+ /* correct rank and positive semi-definiteness. */
+ /* NDIM is the first dimension of BMAT and has the value NPT+N. */
+ /* SL and SU hold the differences XL-XBASE and XU-XBASE, respectively. */
+ /* All the components of every XOPT are going to satisfy the bounds */
+ /* SL(I) .LEQ. XOPT(I) .LEQ. SU(I), with appropriate equalities when */
+ /* XOPT is on a constraint boundary. */
+ /* XNEW is chosen by SUBROUTINE TRSBOX or ALTMOV. Usually XBASE+XNEW is the */
+ /* vector of variables for the next call of CALFUN. XNEW also satisfies */
+ /* the SL and SU constraints in the way that has just been mentioned. */
+ /* XALT is an alternative to XNEW, chosen by ALTMOV, that may replace XNEW */
+ /* in order to increase the denominator in the updating of UPDATE. */
+ /* D is reserved for a trial step from XOPT, which is usually XNEW-XOPT. */
+ /* VLAG contains the values of the Lagrange functions at a new point X. */
+ /* They are part of a product that requires VLAG to be of length NDIM. */
+ /* W is a one-dimensional array that is used for working space. Its length */
+ /* must be at least 3*NDIM = 3*(NPT+N). */
+
+ /* Set some constants. */
+
+ /* Parameter adjustments */
+ zmat_dim1 = npt;
+ zmat_offset = 1 + zmat_dim1;
+ zmat -= zmat_offset;
+ xpt_dim1 = npt;
+ xpt_offset = 1 + xpt_dim1;
+ xpt -= xpt_offset;
+ --x;
+ --xl;
+ --xu;
+ --xbase;
+ --fval;
+ --xopt;
+ --gopt;
+ --hq;
+ --pq;
+ bmat_dim1 = ndim;
+ bmat_offset = 1 + bmat_dim1;
+ bmat -= bmat_offset;
+ --sl;
+ --su;
+ --xnew;
+ --xalt;
+ --d__;
+ --vlag;
+ --w;
+
+ /* Function Body */
+ half = .5;
+ one = 1.;
+ ten = 10.;
+ tenth = .1;
+ two = 2.;
+ zero = 0.;
+ np = n + 1;
+ nptm = npt - np;
+ nh = n * np / 2;
+
+ /* The call of PRELIM sets the elements of XBASE, XPT, FVAL, GOPT, HQ, PQ, */
+ /* BMAT and ZMAT for the first iteration, with the corresponding values of */
+ /* of NF and KOPT, which are the number of calls of CALFUN so far and the */
+ /* index of the interpolation point at the trust region centre. Then the */
+ /* initial XOPT is set too. The branch to label 720 occurs if MAXFUN is */
+ /* less than NPT. GOPT will be updated if KOPT is different from KBASE. */
+
+ prelim_(calfun, n, npt, &x[1], &xl[1], &xu[1], rhobeg, maxfun, &xbase[1],
+ &xpt[xpt_offset], &fval[1], &gopt[1], &hq[1], &pq[1], &bmat[bmat_offset],
+ &zmat[zmat_offset], ndim, &sl[1], &su[1], nf, kopt);
+ xoptsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ xopt[i__] = xpt[kopt + i__ * xpt_dim1];
+ /* L10: */
+ /* Computing 2nd power */
+ d__1 = xopt[i__];
+ xoptsq += d__1 * d__1;
+ }
+ fsave = fval[1];
+ if (nf < npt) {
+ throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times.");
+ //goto L720;
+ }
+ kbase = 1;
+
+ /* Complete the settings that are required for the iterative procedure. */
+
+ rho = rhobeg;
+ delta = rho;
+ nresc = nf;
+ ntrits = 0;
+ diffa = zero;
+ diffb = zero;
+ itest = 0;
+ nfsav = nf;
+
+ /* Update GOPT if necessary before the first iteration and after each */
+ /* call of RESCUE that makes a call of CALFUN. */
+
+L20:
+ if (kopt != kbase) {
+ ih = 0;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ i__2 = j;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ ++ih;
+ if (i__ < j) {
+ gopt[j] += hq[ih] * xopt[i__];
+ }
+ /* L30: */
+ gopt[i__] += hq[ih] * xopt[j];
+ }
+ }
+ if (nf > npt) {
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ temp = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L40: */
+ temp += xpt[k + j * xpt_dim1] * xopt[j];
+ }
+ temp = pq[k] * temp;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L50: */
+ gopt[i__] += temp * xpt[k + i__ * xpt_dim1];
+ }
+ }
+ }
+ }
+
+ /* Generate the next point in the trust region that provides a small value */
+ /* of the quadratic model subject to the constraints on the variables. */
+ /* The integer NTRITS is set to the number "trust region" iterations that */
+ /* have occurred since the last "alternative" iteration. If the length */
+ /* of XNEW-XOPT is less than HALF*RHO, however, then there is a branch to */
+ /* label 650 or 680 with NTRITS=-1, instead of calculating F at XNEW. */
+
+L60:
+ trsbox_(n, npt, &xpt[xpt_offset], &xopt[1], &gopt[1], &hq[1], &pq[1], &sl[1],
+ &su[1], delta, &xnew[1], &d__[1], &w[1], &w[np], &w[np + n],
+ &w[np + (n << 1)], &w[np + n * 3], &dsq, &crvmin);
+ /* Computing MIN */
+ d__1 = delta, d__2 = std::sqrt(dsq);
+ dnorm = std::min(d__1,d__2);
+ if (dnorm < half * rho) {
+ ntrits = -1;
+ /* Computing 2nd power */
+ d__1 = ten * rho;
+ distsq = d__1 * d__1;
+ if (nf <= nfsav + 2) {
+ goto L650;
+ }
+
+ /* The following choice between labels 650 and 680 depends on whether or */
+ /* not our work with the current RHO seems to be complete. Either RHO is */
+ /* decreased or termination occurs if the errors in the quadratic model at */
+ /* the last three interpolation points compare favourably with predictions */
+ /* of likely improvements to the model within distance HALF*RHO of XOPT. */
+
+ /* Computing MAX */
+ d__1 = std::max(diffa,diffb);
+ errbig = std::max(d__1,diffc);
+ frhosq = rho * .125 * rho;
+ if (crvmin > zero && errbig > frhosq * crvmin) {
+ goto L650;
+ }
+ bdtol = errbig / rho;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ bdtest = bdtol;
+ if (xnew[j] == sl[j]) {
+ bdtest = w[j];
+ }
+ if (xnew[j] == su[j]) {
+ bdtest = -w[j];
+ }
+ if (bdtest < bdtol) {
+ curv = hq[(j + j * j) / 2];
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L70: */
+ /* Computing 2nd power */
+ d__1 = xpt[k + j * xpt_dim1];
+ curv += pq[k] * (d__1 * d__1);
+ }
+ bdtest += half * curv * rho;
+ if (bdtest < bdtol) {
+ goto L650;
+ }
+ }
+ /* L80: */
+ }
+ goto L680;
+ }
+ ++ntrits;
+
+ /* Severe cancellation is likely to occur if XOPT is too far from XBASE. */
+ /* If the following test holds, then XBASE is shifted so that XOPT becomes */
+ /* zero. The appropriate changes are made to BMAT and to the second */
+ /* derivatives of the current model, beginning with the changes to BMAT */
+ /* that do not depend on ZMAT. VLAG is used temporarily for working space. */
+
+L90:
+ if (dsq <= xoptsq * .001) {
+ fracsq = xoptsq * .25;
+ sumpq = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ sumpq += pq[k];
+ sum = -half * xoptsq;
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L100: */
+ sum += xpt[k + i__ * xpt_dim1] * xopt[i__];
+ }
+ w[npt + k] = sum;
+ temp = fracsq - half * sum;
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ w[i__] = bmat[k + i__ * bmat_dim1];
+ vlag[i__] = sum * xpt[k + i__ * xpt_dim1] + temp * xopt[i__];
+ ip = npt + i__;
+ i__3 = i__;
+ for (j = 1; j <= i__3; ++j) {
+ /* L110: */
+ bmat[ip + j * bmat_dim1] = bmat[ip + j * bmat_dim1] + w[
+ i__] * vlag[j] + vlag[i__] * w[j];
+ }
+ }
+ }
+
+ /* Then the revisions of BMAT that depend on ZMAT are calculated. */
+
+ i__3 = nptm;
+ for (jj = 1; jj <= i__3; ++jj) {
+ sumz = zero;
+ sumw = zero;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ sumz += zmat[k + jj * zmat_dim1];
+ vlag[k] = w[npt + k] * zmat[k + jj * zmat_dim1];
+ /* L120: */
+ sumw += vlag[k];
+ }
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ sum = (fracsq * sumz - half * sumw) * xopt[j];
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L130: */
+ sum += vlag[k] * xpt[k + j * xpt_dim1];
+ }
+ w[j] = sum;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L140: */
+ bmat[k + j * bmat_dim1] += sum * zmat[k + jj * zmat_dim1];
+ }
+ }
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ ip = i__ + npt;
+ temp = w[i__];
+ i__2 = i__;
+ for (j = 1; j <= i__2; ++j) {
+ /* L150: */
+ bmat[ip + j * bmat_dim1] += temp * w[j];
+ }
+ }
+ }
+
+ /* The following instructions complete the shift, including the changes */
+ /* to the second derivative parameters of the quadratic model. */
+
+ ih = 0;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ w[j] = -half * sumpq * xopt[j];
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ w[j] += pq[k] * xpt[k + j * xpt_dim1];
+ /* L160: */
+ xpt[k + j * xpt_dim1] -= xopt[j];
+ }
+ i__1 = j;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ ++ih;
+ hq[ih] = hq[ih] + w[i__] * xopt[j] + xopt[i__] * w[j];
+ /* L170: */
+ bmat[npt + i__ + j * bmat_dim1] = bmat[npt + j + i__ *
+ bmat_dim1];
+ }
+ }
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ xbase[i__] += xopt[i__];
+ xnew[i__] -= xopt[i__];
+ sl[i__] -= xopt[i__];
+ su[i__] -= xopt[i__];
+ /* L180: */
+ xopt[i__] = zero;
+ }
+ xoptsq = zero;
+ }
+ if (ntrits == 0) {
+ goto L210;
+ }
+ goto L230;
+
+ /* XBASE is also moved to XOPT by a call of RESCUE. This calculation is */
+ /* more expensive than the previous shift, because new matrices BMAT and */
+ /* ZMAT are generated from scratch, which may include the replacement of */
+ /* interpolation points whose positions seem to be causing near linear */
+ /* dependence in the interpolation conditions. Therefore RESCUE is called */
+ /* only if rounding errors have reduced by at least a factor of two the */
+ /* denominator of the formula for updating the H matrix. It provides a */
+ /* useful safeguard, but is not invoked in most applications of BOBYQA. */
+
+L190:
+ nfsav = nf;
+ kbase = kopt;
+ rescue_(calfun, n, npt, &xl[1], &xu[1], maxfun, &xbase[1], &xpt[
+ xpt_offset], &fval[1], &xopt[1], &gopt[1], &hq[1], &pq[1], &bmat[
+ bmat_offset], &zmat[zmat_offset], ndim, &sl[1], &su[1], nf, delta,
+ kopt, &vlag[1], &w[1], &w[n + np], &w[ndim + np]);
+
+ /* XOPT is updated now in case the branch below to label 720 is taken. */
+ /* Any updating of GOPT occurs after the branch below to label 20, which */
+ /* leads to a trust region iteration as does the branch to label 60. */
+
+ xoptsq = zero;
+ if (kopt != kbase) {
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ xopt[i__] = xpt[kopt + i__ * xpt_dim1];
+ /* L200: */
+ /* Computing 2nd power */
+ d__1 = xopt[i__];
+ xoptsq += d__1 * d__1;
+ }
+ }
+ if (nf < 0) {
+ nf = maxfun;
+ throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times.");
+ //goto L720;
+ }
+ nresc = nf;
+ if (nfsav < nf) {
+ nfsav = nf;
+ goto L20;
+ }
+ if (ntrits > 0) {
+ goto L60;
+ }
+
+ /* Pick two alternative vectors of variables, relative to XBASE, that */
+ /* are suitable as new positions of the KNEW-th interpolation point. */
+ /* Firstly, XNEW is set to the point on a line through XOPT and another */
+ /* interpolation point that minimizes the predicted value of the next */
+ /* denominator, subject to ||XNEW - XOPT|| .LEQ. ADELT and to the SL */
+ /* and SU bounds. Secondly, XALT is set to the best feasible point on */
+ /* a constrained version of the Cauchy step of the KNEW-th Lagrange */
+ /* function, the corresponding value of the square of this function */
+ /* being returned in CAUCHY. The choice between these alternatives is */
+ /* going to be made when the denominator is calculated. */
+
+L210:
+ altmov_(n, npt, &xpt[xpt_offset], &xopt[1], &bmat[bmat_offset], &zmat[zmat_offset],
+ ndim, &sl[1], &su[1], kopt, knew, adelt, &xnew[1],
+ &xalt[1], alpha, cauchy, &w[1], &w[np], &w[ndim + 1]);
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L220: */
+ d__[i__] = xnew[i__] - xopt[i__];
+ }
+
+ /* Calculate VLAG and BETA for the current choice of D. The scalar */
+ /* product of D with XPT(K,.) is going to be held in W(NPT+K) for */
+ /* use when VQUAD is calculated. */
+
+L230:
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ suma = zero;
+ sumb = zero;
+ sum = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ suma += xpt[k + j * xpt_dim1] * d__[j];
+ sumb += xpt[k + j * xpt_dim1] * xopt[j];
+ /* L240: */
+ sum += bmat[k + j * bmat_dim1] * d__[j];
+ }
+ w[k] = suma * (half * suma + sumb);
+ vlag[k] = sum;
+ /* L250: */
+ w[npt + k] = suma;
+ }
+ beta = zero;
+ i__1 = nptm;
+ for (jj = 1; jj <= i__1; ++jj) {
+ sum = zero;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L260: */
+ sum += zmat[k + jj * zmat_dim1] * w[k];
+ }
+ beta -= sum * sum;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L270: */
+ vlag[k] += sum * zmat[k + jj * zmat_dim1];
+ }
+ }
+ dsq = zero;
+ bsum = zero;
+ dx = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* Computing 2nd power */
+ d__1 = d__[j];
+ dsq += d__1 * d__1;
+ sum = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L280: */
+ sum += w[k] * bmat[k + j * bmat_dim1];
+ }
+ bsum += sum * d__[j];
+ jp = npt + j;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L290: */
+ sum += bmat[jp + i__ * bmat_dim1] * d__[i__];
+ }
+ vlag[jp] = sum;
+ bsum += sum * d__[j];
+ /* L300: */
+ dx += d__[j] * xopt[j];
+ }
+ beta = dx * dx + dsq * (xoptsq + dx + dx + half * dsq) + beta - bsum;
+ vlag[kopt] += one;
+
+ /* If NTRITS is zero, the denominator may be increased by replacing */
+ /* the step D of ALTMOV by a Cauchy step. Then RESCUE may be called if */
+ /* rounding errors have damaged the chosen denominator. */
+
+ if (ntrits == 0) {
+ /* Computing 2nd power */
+ d__1 = vlag[knew];
+ denom = d__1 * d__1 + alpha * beta;
+ if (denom < cauchy && cauchy > zero) {
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ xnew[i__] = xalt[i__];
+ /* L310: */
+ d__[i__] = xnew[i__] - xopt[i__];
+ }
+ cauchy = zero;
+ goto L230;
+ }
+ /* Computing 2nd power */
+ d__1 = vlag[knew];
+ if (denom <= half * (d__1 * d__1)) {
+ if (nf > nresc) {
+ goto L190;
+ }
+ throw bobyqa_failure("Return from BOBYQA because of much cancellation in a denominator.");
+ //goto L720;
+ }
+
+ /* Alternatively, if NTRITS is positive, then set KNEW to the index of */
+ /* the next interpolation point to be deleted to make room for a trust */
+ /* region step. Again RESCUE may be called if rounding errors have damaged */
+ /* the chosen denominator, which is the reason for attempting to select */
+ /* KNEW before calculating the next value of the objective function. */
+
+ } else {
+ delsq = delta * delta;
+ scaden = zero;
+ biglsq = zero;
+ knew = 0;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ if (k == kopt) {
+ goto L350;
+ }
+ hdiag = zero;
+ i__1 = nptm;
+ for (jj = 1; jj <= i__1; ++jj) {
+ /* L330: */
+ /* Computing 2nd power */
+ d__1 = zmat[k + jj * zmat_dim1];
+ hdiag += d__1 * d__1;
+ }
+ /* Computing 2nd power */
+ d__1 = vlag[k];
+ den = beta * hdiag + d__1 * d__1;
+ distsq = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L340: */
+ /* Computing 2nd power */
+ d__1 = xpt[k + j * xpt_dim1] - xopt[j];
+ distsq += d__1 * d__1;
+ }
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = distsq / delsq;
+ d__1 = one, d__2 = d__3 * d__3;
+ temp = std::max(d__1,d__2);
+ if (temp * den > scaden) {
+ scaden = temp * den;
+ knew = k;
+ denom = den;
+ }
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = vlag[k];
+ d__1 = biglsq, d__2 = temp * (d__3 * d__3);
+ biglsq = std::max(d__1,d__2);
+L350:
+ ;
+ }
+ if (scaden <= half * biglsq) {
+ if (nf > nresc) {
+ goto L190;
+ }
+ throw bobyqa_failure("Return from BOBYQA because of much cancellation in a denominator.");
+ //goto L720;
+ }
+ }
+
+ /* Put the variables for the next calculation of the objective function */
+ /* in XNEW, with any adjustments for the bounds. */
+
+
+ /* Calculate the value of the objective function at XBASE+XNEW, unless */
+ /* the limit on the number of calculations of F has been reached. */
+
+L360:
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* Computing MIN */
+ /* Computing MAX */
+ d__3 = xl[i__], d__4 = xbase[i__] + xnew[i__];
+ d__1 = std::max(d__3,d__4), d__2 = xu[i__];
+ x[i__] = std::min(d__1,d__2);
+ if (xnew[i__] == sl[i__]) {
+ x[i__] = xl[i__];
+ }
+ if (xnew[i__] == su[i__]) {
+ x[i__] = xu[i__];
+ }
+ /* L380: */
+ }
+ if (nf >= maxfun) {
+ throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times.");
+ //goto L720;
+ }
+ ++nf;
+ f = calfun(mat(&x[1], n));
+ if (ntrits == -1) {
+ fsave = f;
+ goto L720;
+ }
+
+ /* Use the quadratic model to predict the change in F due to the step D, */
+ /* and set DIFF to the error of this prediction. */
+
+ fopt = fval[kopt];
+ vquad = zero;
+ ih = 0;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ vquad += d__[j] * gopt[j];
+ i__1 = j;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ ++ih;
+ temp = d__[i__] * d__[j];
+ if (i__ == j) {
+ temp = half * temp;
+ }
+ /* L410: */
+ vquad += hq[ih] * temp;
+ }
+ }
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L420: */
+ /* Computing 2nd power */
+ d__1 = w[npt + k];
+ vquad += half * pq[k] * (d__1 * d__1);
+ }
+ diff = f - fopt - vquad;
+ diffc = diffb;
+ diffb = diffa;
+ diffa = std::abs(diff);
+ if (dnorm > rho) {
+ nfsav = nf;
+ }
+
+ /* Pick the next value of DELTA after a trust region step. */
+
+ if (ntrits > 0) {
+ if (vquad >= zero) {
+ throw bobyqa_failure("Return from BOBYQA because a trust region step has failed to reduce Q.");
+ //goto L720;
+ }
+ ratio = (f - fopt) / vquad;
+ if (ratio <= tenth) {
+ /* Computing MIN */
+ d__1 = half * delta;
+ delta = std::min(d__1,dnorm);
+ } else if (ratio <= .7) {
+ /* Computing MAX */
+ d__1 = half * delta;
+ delta = std::max(d__1,dnorm);
+ } else {
+ /* Computing MAX */
+ d__1 = half * delta, d__2 = dnorm + dnorm;
+ delta = std::max(d__1,d__2);
+ }
+ if (delta <= rho * 1.5) {
+ delta = rho;
+ }
+
+ /* Recalculate KNEW and DENOM if the new F is less than FOPT. */
+
+ if (f < fopt) {
+ ksav = knew;
+ densav = denom;
+ delsq = delta * delta;
+ scaden = zero;
+ biglsq = zero;
+ knew = 0;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ hdiag = zero;
+ i__2 = nptm;
+ for (jj = 1; jj <= i__2; ++jj) {
+ /* L440: */
+ /* Computing 2nd power */
+ d__1 = zmat[k + jj * zmat_dim1];
+ hdiag += d__1 * d__1;
+ }
+ /* Computing 2nd power */
+ d__1 = vlag[k];
+ den = beta * hdiag + d__1 * d__1;
+ distsq = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L450: */
+ /* Computing 2nd power */
+ d__1 = xpt[k + j * xpt_dim1] - xnew[j];
+ distsq += d__1 * d__1;
+ }
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = distsq / delsq;
+ d__1 = one, d__2 = d__3 * d__3;
+ temp = std::max(d__1,d__2);
+ if (temp * den > scaden) {
+ scaden = temp * den;
+ knew = k;
+ denom = den;
+ }
+ /* L460: */
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = vlag[k];
+ d__1 = biglsq, d__2 = temp * (d__3 * d__3);
+ biglsq = std::max(d__1,d__2);
+ }
+ if (scaden <= half * biglsq) {
+ knew = ksav;
+ denom = densav;
+ }
+ }
+ }
+
+ /* Update BMAT and ZMAT, so that the KNEW-th interpolation point can be */
+ /* moved. Also update the second derivative terms of the model. */
+
+ update_(n, npt, &bmat[bmat_offset], &zmat[zmat_offset], ndim, &vlag[1],
+ beta, denom, knew, &w[1]);
+ ih = 0;
+ pqold = pq[knew];
+ pq[knew] = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ temp = pqold * xpt[knew + i__ * xpt_dim1];
+ i__2 = i__;
+ for (j = 1; j <= i__2; ++j) {
+ ++ih;
+ /* L470: */
+ hq[ih] += temp * xpt[knew + j * xpt_dim1];
+ }
+ }
+ i__2 = nptm;
+ for (jj = 1; jj <= i__2; ++jj) {
+ temp = diff * zmat[knew + jj * zmat_dim1];
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L480: */
+ pq[k] += temp * zmat[k + jj * zmat_dim1];
+ }
+ }
+
+ /* Include the new interpolation point, and make the changes to GOPT at */
+ /* the old XOPT that are caused by the updating of the quadratic model. */
+
+ fval[knew] = f;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ xpt[knew + i__ * xpt_dim1] = xnew[i__];
+ /* L490: */
+ w[i__] = bmat[knew + i__ * bmat_dim1];
+ }
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ suma = zero;
+ i__2 = nptm;
+ for (jj = 1; jj <= i__2; ++jj) {
+ /* L500: */
+ suma += zmat[knew + jj * zmat_dim1] * zmat[k + jj * zmat_dim1];
+ }
+ sumb = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L510: */
+ sumb += xpt[k + j * xpt_dim1] * xopt[j];
+ }
+ temp = suma * sumb;
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L520: */
+ w[i__] += temp * xpt[k + i__ * xpt_dim1];
+ }
+ }
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L530: */
+ gopt[i__] += diff * w[i__];
+ }
+
+ /* Update XOPT, GOPT and KOPT if the new calculated F is less than FOPT. */
+
+ if (f < fopt) {
+ kopt = knew;
+ xoptsq = zero;
+ ih = 0;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ xopt[j] = xnew[j];
+ /* Computing 2nd power */
+ d__1 = xopt[j];
+ xoptsq += d__1 * d__1;
+ i__1 = j;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ ++ih;
+ if (i__ < j) {
+ gopt[j] += hq[ih] * d__[i__];
+ }
+ /* L540: */
+ gopt[i__] += hq[ih] * d__[j];
+ }
+ }
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ temp = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L550: */
+ temp += xpt[k + j * xpt_dim1] * d__[j];
+ }
+ temp = pq[k] * temp;
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L560: */
+ gopt[i__] += temp * xpt[k + i__ * xpt_dim1];
+ }
+ }
+ }
+
+ /* Calculate the parameters of the least Frobenius norm interpolant to */
+ /* the current data, the gradient of this interpolant at XOPT being put */
+ /* into VLAG(NPT+I), I=1,2,...,N. */
+
+ if (ntrits > 0) {
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ vlag[k] = fval[k] - fval[kopt];
+ /* L570: */
+ w[k] = zero;
+ }
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ sum = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L580: */
+ sum += zmat[k + j * zmat_dim1] * vlag[k];
+ }
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L590: */
+ w[k] += sum * zmat[k + j * zmat_dim1];
+ }
+ }
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ sum = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L600: */
+ sum += xpt[k + j * xpt_dim1] * xopt[j];
+ }
+ w[k + npt] = w[k];
+ /* L610: */
+ w[k] = sum * w[k];
+ }
+ gqsq = zero;
+ gisq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ sum = zero;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L620: */
+ sum = sum + bmat[k + i__ * bmat_dim1] * vlag[k] + xpt[k + i__
+ * xpt_dim1] * w[k];
+ }
+ if (xopt[i__] == sl[i__]) {
+ /* Computing MIN */
+ d__2 = zero, d__3 = gopt[i__];
+ /* Computing 2nd power */
+ d__1 = std::min(d__2,d__3);
+ gqsq += d__1 * d__1;
+ /* Computing 2nd power */
+ d__1 = std::min(zero,sum);
+ gisq += d__1 * d__1;
+ } else if (xopt[i__] == su[i__]) {
+ /* Computing MAX */
+ d__2 = zero, d__3 = gopt[i__];
+ /* Computing 2nd power */
+ d__1 = std::max(d__2,d__3);
+ gqsq += d__1 * d__1;
+ /* Computing 2nd power */
+ d__1 = std::max(zero,sum);
+ gisq += d__1 * d__1;
+ } else {
+ /* Computing 2nd power */
+ d__1 = gopt[i__];
+ gqsq += d__1 * d__1;
+ gisq += sum * sum;
+ }
+ /* L630: */
+ vlag[npt + i__] = sum;
+ }
+
+ /* Test whether to replace the new quadratic model by the least Frobenius */
+ /* norm interpolant, making the replacement if the test is satisfied. */
+
+ ++itest;
+ if (gqsq < ten * gisq) {
+ itest = 0;
+ }
+ if (itest >= 3) {
+ i__1 = std::max(npt,nh);
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (i__ <= n) {
+ gopt[i__] = vlag[npt + i__];
+ }
+ if (i__ <= npt) {
+ pq[i__] = w[npt + i__];
+ }
+ if (i__ <= nh) {
+ hq[i__] = zero;
+ }
+ itest = 0;
+ /* L640: */
+ }
+ }
+ }
+
+ /* If a trust region step has provided a sufficient decrease in F, then */
+ /* branch for another trust region calculation. The case NTRITS=0 occurs */
+ /* when the new interpolation point was reached by an alternative step. */
+
+ if (ntrits == 0) {
+ goto L60;
+ }
+ if (f <= fopt + tenth * vquad) {
+ goto L60;
+ }
+
+ /* Alternatively, find out if the interpolation points are close enough */
+ /* to the best point so far. */
+
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = two * delta;
+ /* Computing 2nd power */
+ d__4 = ten * rho;
+ d__1 = d__3 * d__3, d__2 = d__4 * d__4;
+ distsq = std::max(d__1,d__2);
+L650:
+ knew = 0;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ sum = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L660: */
+ /* Computing 2nd power */
+ d__1 = xpt[k + j * xpt_dim1] - xopt[j];
+ sum += d__1 * d__1;
+ }
+ if (sum > distsq) {
+ knew = k;
+ distsq = sum;
+ }
+ /* L670: */
+ }
+
+ /* If KNEW is positive, then ALTMOV finds alternative new positions for */
+ /* the KNEW-th interpolation point within distance ADELT of XOPT. It is */
+ /* reached via label 90. Otherwise, there is a branch to label 60 for */
+ /* another trust region iteration, unless the calculations with the */
+ /* current RHO are complete. */
+
+ if (knew > 0) {
+ dist = std::sqrt(distsq);
+ if (ntrits == -1) {
+ /* Computing MIN */
+ d__1 = tenth * delta, d__2 = half * dist;
+ delta = std::min(d__1,d__2);
+ if (delta <= rho * 1.5) {
+ delta = rho;
+ }
+ }
+ ntrits = 0;
+ /* Computing MAX */
+ /* Computing MIN */
+ d__2 = tenth * dist;
+ d__1 = std::min(d__2,delta);
+ adelt = std::max(d__1,rho);
+ dsq = adelt * adelt;
+ goto L90;
+ }
+ if (ntrits == -1) {
+ goto L680;
+ }
+ if (ratio > zero) {
+ goto L60;
+ }
+ if (std::max(delta,dnorm) > rho) {
+ goto L60;
+ }
+
+ /* The calculations with the current value of RHO are complete. Pick the */
+ /* next values of RHO and DELTA. */
+
+L680:
+ if (rho > rhoend) {
+ delta = half * rho;
+ ratio = rho / rhoend;
+ if (ratio <= 16.) {
+ rho = rhoend;
+ } else if (ratio <= 250.) {
+ rho = std::sqrt(ratio) * rhoend;
+ } else {
+ rho = tenth * rho;
+ }
+ delta = std::max(delta,rho);
+ ntrits = 0;
+ nfsav = nf;
+ goto L60;
+ }
+
+ /* Return from the calculation, after another Newton-Raphson step, if */
+ /* it is too short to have been tried before. */
+
+ if (ntrits == -1) {
+ goto L360;
+ }
+L720:
+ if (fval[kopt] <= fsave) {
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* Computing MIN */
+ /* Computing MAX */
+ d__3 = xl[i__], d__4 = xbase[i__] + xopt[i__];
+ d__1 = std::max(d__3,d__4), d__2 = xu[i__];
+ x[i__] = std::min(d__1,d__2);
+ if (xopt[i__] == sl[i__]) {
+ x[i__] = xl[i__];
+ }
+ if (xopt[i__] == su[i__]) {
+ x[i__] = xu[i__];
+ }
+ /* L730: */
+ }
+ f = fval[kopt];
+ }
+
+ return f;
+ } /* bobyqb_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ void altmov_(
+ const integer n,
+ const integer npt,
+ const doublereal *xpt,
+ const doublereal *xopt,
+ const doublereal *bmat,
+ const doublereal *zmat,
+ const integer ndim,
+ const doublereal *sl,
+ const doublereal *su,
+ const integer kopt,
+ const integer knew,
+ const doublereal adelt,
+ doublereal *xnew,
+ doublereal *xalt,
+ doublereal& alpha,
+ doublereal& cauchy,
+ doublereal *glag,
+ doublereal *hcol,
+ doublereal *w
+ ) const
+ {
+ /* System generated locals */
+ integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1,
+ zmat_offset, i__1, i__2;
+ doublereal d__1, d__2, d__3, d__4;
+
+
+ /* Local variables */
+ integer i__, j, k;
+ doublereal ha, gw, one, diff, half;
+ integer ilbd, isbd;
+ doublereal slbd;
+ integer iubd;
+ doublereal vlag, subd, temp;
+ integer ksav = 0;
+ doublereal step = 0, zero = 0, curv = 0;
+ integer iflag;
+ doublereal scale = 0, csave = 0, tempa = 0, tempb = 0, tempd = 0, const__ = 0, sumin = 0,
+ ggfree = 0;
+ integer ibdsav = 0;
+ doublereal dderiv = 0, bigstp = 0, predsq = 0, presav = 0, distsq = 0, stpsav = 0, wfixsq = 0, wsqsav = 0;
+
+
+ /* The arguments N, NPT, XPT, XOPT, BMAT, ZMAT, NDIM, SL and SU all have */
+ /* the same meanings as the corresponding arguments of BOBYQB. */
+ /* KOPT is the index of the optimal interpolation point. */
+ /* KNEW is the index of the interpolation point that is going to be moved. */
+ /* ADELT is the current trust region bound. */
+ /* XNEW will be set to a suitable new position for the interpolation point */
+ /* XPT(KNEW,.). Specifically, it satisfies the SL, SU and trust region */
+ /* bounds and it should provide a large denominator in the next call of */
+ /* UPDATE. The step XNEW-XOPT from XOPT is restricted to moves along the */
+ /* straight lines through XOPT and another interpolation point. */
+ /* XALT also provides a large value of the modulus of the KNEW-th Lagrange */
+ /* function subject to the constraints that have been mentioned, its main */
+ /* difference from XNEW being that XALT-XOPT is a constrained version of */
+ /* the Cauchy step within the trust region. An exception is that XALT is */
+ /* not calculated if all components of GLAG (see below) are zero. */
+ /* ALPHA will be set to the KNEW-th diagonal element of the H matrix. */
+ /* CAUCHY will be set to the square of the KNEW-th Lagrange function at */
+ /* the step XALT-XOPT from XOPT for the vector XALT that is returned, */
+ /* except that CAUCHY is set to zero if XALT is not calculated. */
+ /* GLAG is a working space vector of length N for the gradient of the */
+ /* KNEW-th Lagrange function at XOPT. */
+ /* HCOL is a working space vector of length NPT for the second derivative */
+ /* coefficients of the KNEW-th Lagrange function. */
+ /* W is a working space vector of length 2N that is going to hold the */
+ /* constrained Cauchy step from XOPT of the Lagrange function, followed */
+ /* by the downhill version of XALT when the uphill step is calculated. */
+
+ /* Set the first NPT components of W to the leading elements of the */
+ /* KNEW-th column of the H matrix. */
+
+ /* Parameter adjustments */
+ zmat_dim1 = npt;
+ zmat_offset = 1 + zmat_dim1;
+ zmat -= zmat_offset;
+ xpt_dim1 = npt;
+ xpt_offset = 1 + xpt_dim1;
+ xpt -= xpt_offset;
+ --xopt;
+ bmat_dim1 = ndim;
+ bmat_offset = 1 + bmat_dim1;
+ bmat -= bmat_offset;
+ --sl;
+ --su;
+ --xnew;
+ --xalt;
+ --glag;
+ --hcol;
+ --w;
+
+ /* Function Body */
+ half = .5;
+ one = 1.;
+ zero = 0.;
+ const__ = one + std::sqrt(2.);
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L10: */
+ hcol[k] = zero;
+ }
+ i__1 = npt - n - 1;
+ for (j = 1; j <= i__1; ++j) {
+ temp = zmat[knew + j * zmat_dim1];
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L20: */
+ hcol[k] += temp * zmat[k + j * zmat_dim1];
+ }
+ }
+ alpha = hcol[knew];
+ ha = half * alpha;
+
+ /* Calculate the gradient of the KNEW-th Lagrange function at XOPT. */
+
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L30: */
+ glag[i__] = bmat[knew + i__ * bmat_dim1];
+ }
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ temp = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L40: */
+ temp += xpt[k + j * xpt_dim1] * xopt[j];
+ }
+ temp = hcol[k] * temp;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L50: */
+ glag[i__] += temp * xpt[k + i__ * xpt_dim1];
+ }
+ }
+
+ /* Search for a large denominator along the straight lines through XOPT */
+ /* and another interpolation point. SLBD and SUBD will be lower and upper */
+ /* bounds on the step along each of these lines in turn. PREDSQ will be */
+ /* set to the square of the predicted denominator for each line. PRESAV */
+ /* will be set to the largest admissible value of PREDSQ that occurs. */
+
+ presav = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ if (k == kopt) {
+ goto L80;
+ }
+ dderiv = zero;
+ distsq = zero;
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ temp = xpt[k + i__ * xpt_dim1] - xopt[i__];
+ dderiv += glag[i__] * temp;
+ /* L60: */
+ distsq += temp * temp;
+ }
+ subd = adelt / std::sqrt(distsq);
+ slbd = -subd;
+ ilbd = 0;
+ iubd = 0;
+ sumin = std::min(one,subd);
+
+ /* Revise SLBD and SUBD if necessary because of the bounds in SL and SU. */
+
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ temp = xpt[k + i__ * xpt_dim1] - xopt[i__];
+ if (temp > zero) {
+ if (slbd * temp < sl[i__] - xopt[i__]) {
+ slbd = (sl[i__] - xopt[i__]) / temp;
+ ilbd = -i__;
+ }
+ if (subd * temp > su[i__] - xopt[i__]) {
+ /* Computing MAX */
+ d__1 = sumin, d__2 = (su[i__] - xopt[i__]) / temp;
+ subd = std::max(d__1,d__2);
+ iubd = i__;
+ }
+ } else if (temp < zero) {
+ if (slbd * temp > su[i__] - xopt[i__]) {
+ slbd = (su[i__] - xopt[i__]) / temp;
+ ilbd = i__;
+ }
+ if (subd * temp < sl[i__] - xopt[i__]) {
+ /* Computing MAX */
+ d__1 = sumin, d__2 = (sl[i__] - xopt[i__]) / temp;
+ subd = std::max(d__1,d__2);
+ iubd = -i__;
+ }
+ }
+ /* L70: */
+ }
+
+ /* Seek a large modulus of the KNEW-th Lagrange function when the index */
+ /* of the other interpolation point on the line through XOPT is KNEW. */
+
+ if (k == knew) {
+ diff = dderiv - one;
+ step = slbd;
+ vlag = slbd * (dderiv - slbd * diff);
+ isbd = ilbd;
+ temp = subd * (dderiv - subd * diff);
+ if (std::abs(temp) > std::abs(vlag)) {
+ step = subd;
+ vlag = temp;
+ isbd = iubd;
+ }
+ tempd = half * dderiv;
+ tempa = tempd - diff * slbd;
+ tempb = tempd - diff * subd;
+ if (tempa * tempb < zero) {
+ temp = tempd * tempd / diff;
+ if (std::abs(temp) > std::abs(vlag)) {
+ step = tempd / diff;
+ vlag = temp;
+ isbd = 0;
+ }
+ }
+
+ /* Search along each of the other lines through XOPT and another point. */
+
+ } else {
+ step = slbd;
+ vlag = slbd * (one - slbd);
+ isbd = ilbd;
+ temp = subd * (one - subd);
+ if (std::abs(temp) > std::abs(vlag)) {
+ step = subd;
+ vlag = temp;
+ isbd = iubd;
+ }
+ if (subd > half) {
+ if (std::abs(vlag) < .25) {
+ step = half;
+ vlag = .25;
+ isbd = 0;
+ }
+ }
+ vlag *= dderiv;
+ }
+
+ /* Calculate PREDSQ for the current line search and maintain PRESAV. */
+
+ temp = step * (one - step) * distsq;
+ predsq = vlag * vlag * (vlag * vlag + ha * temp * temp);
+ if (predsq > presav) {
+ presav = predsq;
+ ksav = k;
+ stpsav = step;
+ ibdsav = isbd;
+ }
+L80:
+ ;
+ }
+
+ /* Construct XNEW in a way that satisfies the bound constraints exactly. */
+
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ temp = xopt[i__] + stpsav * (xpt[ksav + i__ * xpt_dim1] - xopt[i__]);
+ /* L90: */
+ /* Computing MAX */
+ /* Computing MIN */
+ d__3 = su[i__];
+ d__1 = sl[i__], d__2 = std::min(d__3,temp);
+ xnew[i__] = std::max(d__1,d__2);
+ }
+ if (ibdsav < 0) {
+ xnew[-ibdsav] = sl[-ibdsav];
+ }
+ if (ibdsav > 0) {
+ xnew[ibdsav] = su[ibdsav];
+ }
+
+ /* Prepare for the iterative method that assembles the constrained Cauchy */
+ /* step in W. The sum of squares of the fixed components of W is formed in */
+ /* WFIXSQ, and the free components of W are set to BIGSTP. */
+
+ bigstp = adelt + adelt;
+ iflag = 0;
+L100:
+ wfixsq = zero;
+ ggfree = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ w[i__] = zero;
+ /* Computing MIN */
+ d__1 = xopt[i__] - sl[i__], d__2 = glag[i__];
+ tempa = std::min(d__1,d__2);
+ /* Computing MAX */
+ d__1 = xopt[i__] - su[i__], d__2 = glag[i__];
+ tempb = std::max(d__1,d__2);
+ if (tempa > zero || tempb < zero) {
+ w[i__] = bigstp;
+ /* Computing 2nd power */
+ d__1 = glag[i__];
+ ggfree += d__1 * d__1;
+ }
+ /* L110: */
+ }
+ if (ggfree == zero) {
+ cauchy = zero;
+ goto L200;
+ }
+
+ /* Investigate whether more components of W can be fixed. */
+
+L120:
+ temp = adelt * adelt - wfixsq;
+ if (temp > zero) {
+ wsqsav = wfixsq;
+ step = std::sqrt(temp / ggfree);
+ ggfree = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (w[i__] == bigstp) {
+ temp = xopt[i__] - step * glag[i__];
+ if (temp <= sl[i__]) {
+ w[i__] = sl[i__] - xopt[i__];
+ /* Computing 2nd power */
+ d__1 = w[i__];
+ wfixsq += d__1 * d__1;
+ } else if (temp >= su[i__]) {
+ w[i__] = su[i__] - xopt[i__];
+ /* Computing 2nd power */
+ d__1 = w[i__];
+ wfixsq += d__1 * d__1;
+ } else {
+ /* Computing 2nd power */
+ d__1 = glag[i__];
+ ggfree += d__1 * d__1;
+ }
+ }
+ /* L130: */
+ }
+ if (wfixsq > wsqsav && ggfree > zero) {
+ goto L120;
+ }
+ }
+
+ /* Set the remaining free components of W and all components of XALT, */
+ /* except that W may be scaled later. */
+
+ gw = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (w[i__] == bigstp) {
+ w[i__] = -step * glag[i__];
+ /* Computing MAX */
+ /* Computing MIN */
+ d__3 = su[i__], d__4 = xopt[i__] + w[i__];
+ d__1 = sl[i__], d__2 = std::min(d__3,d__4);
+ xalt[i__] = std::max(d__1,d__2);
+ } else if (w[i__] == zero) {
+ xalt[i__] = xopt[i__];
+ } else if (glag[i__] > zero) {
+ xalt[i__] = sl[i__];
+ } else {
+ xalt[i__] = su[i__];
+ }
+ /* L140: */
+ gw += glag[i__] * w[i__];
+ }
+
+ /* Set CURV to the curvature of the KNEW-th Lagrange function along W. */
+ /* Scale W by a factor less than one if that can reduce the modulus of */
+ /* the Lagrange function at XOPT+W. Set CAUCHY to the final value of */
+ /* the square of this function. */
+
+ curv = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ temp = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L150: */
+ temp += xpt[k + j * xpt_dim1] * w[j];
+ }
+ /* L160: */
+ curv += hcol[k] * temp * temp;
+ }
+ if (iflag == 1) {
+ curv = -curv;
+ }
+ if (curv > -gw && curv < -const__ * gw) {
+ scale = -gw / curv;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ temp = xopt[i__] + scale * w[i__];
+ /* L170: */
+ /* Computing MAX */
+ /* Computing MIN */
+ d__3 = su[i__];
+ d__1 = sl[i__], d__2 = std::min(d__3,temp);
+ xalt[i__] = std::max(d__1,d__2);
+ }
+ /* Computing 2nd power */
+ d__1 = half * gw * scale;
+ cauchy = d__1 * d__1;
+ } else {
+ /* Computing 2nd power */
+ d__1 = gw + half * curv;
+ cauchy = d__1 * d__1;
+ }
+
+ /* If IFLAG is zero, then XALT is calculated as before after reversing */
+ /* the sign of GLAG. Thus two XALT vectors become available. The one that */
+ /* is chosen is the one that gives the larger value of CAUCHY. */
+
+ if (iflag == 0) {
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ glag[i__] = -glag[i__];
+ /* L180: */
+ w[n + i__] = xalt[i__];
+ }
+ csave = cauchy;
+ iflag = 1;
+ goto L100;
+ }
+ if (csave > cauchy) {
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L190: */
+ xalt[i__] = w[n + i__];
+ }
+ cauchy = csave;
+ }
+L200:
+ ;
+ } /* altmov_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ void prelim_(
+ const funct& calfun,
+ const integer n,
+ const integer npt,
+ doublereal *x,
+ const doublereal *xl,
+ const doublereal *xu,
+ const doublereal rhobeg,
+ const integer maxfun,
+ doublereal *xbase,
+ doublereal *xpt,
+ doublereal *fval,
+ doublereal *gopt,
+ doublereal *hq,
+ doublereal *pq,
+ doublereal *bmat,
+ doublereal *zmat,
+ const integer ndim,
+ const doublereal *sl,
+ const doublereal *su,
+ integer& nf,
+ integer& kopt
+ ) const
+ {
+ /* System generated locals */
+ integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1,
+ zmat_offset, i__1, i__2;
+ doublereal d__1, d__2, d__3, d__4;
+
+
+ /* Local variables */
+ doublereal f;
+ integer i__, j, k, ih, np, nfm;
+ doublereal one;
+ integer nfx = 0, ipt = 0, jpt = 0;
+ doublereal two = 0, fbeg = 0, diff = 0, half = 0, temp = 0, zero = 0, recip = 0, stepa = 0, stepb = 0;
+ integer itemp;
+ doublereal rhosq;
+
+
+
+ /* The arguments N, NPT, X, XL, XU, RHOBEG, IPRINT and MAXFUN are the */
+ /* same as the corresponding arguments in SUBROUTINE BOBYQA. */
+ /* The arguments XBASE, XPT, FVAL, HQ, PQ, BMAT, ZMAT, NDIM, SL and SU */
+ /* are the same as the corresponding arguments in BOBYQB, the elements */
+ /* of SL and SU being set in BOBYQA. */
+ /* GOPT is usually the gradient of the quadratic model at XOPT+XBASE, but */
+ /* it is set by PRELIM to the gradient of the quadratic model at XBASE. */
+ /* If XOPT is nonzero, BOBYQB will change it to its usual value later. */
+ /* NF is maintaned as the number of calls of CALFUN so far. */
+ /* KOPT will be such that the least calculated value of F so far is at */
+ /* the point XPT(KOPT,.)+XBASE in the space of the variables. */
+
+ /* SUBROUTINE PRELIM sets the elements of XBASE, XPT, FVAL, GOPT, HQ, PQ, */
+ /* BMAT and ZMAT for the first iteration, and it maintains the values of */
+ /* NF and KOPT. The vector X is also changed by PRELIM. */
+
+ /* Set some constants. */
+
+ /* Parameter adjustments */
+ zmat_dim1 = npt;
+ zmat_offset = 1 + zmat_dim1;
+ zmat -= zmat_offset;
+ xpt_dim1 = npt;
+ xpt_offset = 1 + xpt_dim1;
+ xpt -= xpt_offset;
+ --x;
+ --xl;
+ --xu;
+ --xbase;
+ --fval;
+ --gopt;
+ --hq;
+ --pq;
+ bmat_dim1 = ndim;
+ bmat_offset = 1 + bmat_dim1;
+ bmat -= bmat_offset;
+ --sl;
+ --su;
+
+ /* Function Body */
+ half = .5;
+ one = 1.;
+ two = 2.;
+ zero = 0.;
+ rhosq = rhobeg * rhobeg;
+ recip = one / rhosq;
+ np = n + 1;
+
+ /* Set XBASE to the initial vector of variables, and set the initial */
+ /* elements of XPT, BMAT, HQ, PQ and ZMAT to zero. */
+
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ xbase[j] = x[j];
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L10: */
+ xpt[k + j * xpt_dim1] = zero;
+ }
+ i__2 = ndim;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L20: */
+ bmat[i__ + j * bmat_dim1] = zero;
+ }
+ }
+ i__2 = n * np / 2;
+ for (ih = 1; ih <= i__2; ++ih) {
+ /* L30: */
+ hq[ih] = zero;
+ }
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ pq[k] = zero;
+ i__1 = npt - np;
+ for (j = 1; j <= i__1; ++j) {
+ /* L40: */
+ zmat[k + j * zmat_dim1] = zero;
+ }
+ }
+
+ /* Begin the initialization procedure. NF becomes one more than the number */
+ /* of function values so far. The coordinates of the displacement of the */
+ /* next initial interpolation point from XBASE are set in XPT(NF+1,.). */
+
+ nf = 0;
+L50:
+ nfm = nf;
+ nfx = nf - n;
+ ++(nf);
+ if (nfm <= n << 1) {
+ if (nfm >= 1 && nfm <= n) {
+ stepa = rhobeg;
+ if (su[nfm] == zero) {
+ stepa = -stepa;
+ }
+ xpt[nf + nfm * xpt_dim1] = stepa;
+ } else if (nfm > n) {
+ stepa = xpt[nf - n + nfx * xpt_dim1];
+ stepb = -(rhobeg);
+ if (sl[nfx] == zero) {
+ /* Computing MIN */
+ d__1 = two * rhobeg, d__2 = su[nfx];
+ stepb = std::min(d__1,d__2);
+ }
+ if (su[nfx] == zero) {
+ /* Computing MAX */
+ d__1 = -two * rhobeg, d__2 = sl[nfx];
+ stepb = std::max(d__1,d__2);
+ }
+ xpt[nf + nfx * xpt_dim1] = stepb;
+ }
+ } else {
+ itemp = (nfm - np) / n;
+ jpt = nfm - itemp * n - n;
+ ipt = jpt + itemp;
+ if (ipt > n) {
+ itemp = jpt;
+ jpt = ipt - n;
+ ipt = itemp;
+ }
+ xpt[nf + ipt * xpt_dim1] = xpt[ipt + 1 + ipt * xpt_dim1];
+ xpt[nf + jpt * xpt_dim1] = xpt[jpt + 1 + jpt * xpt_dim1];
+ }
+
+ /* Calculate the next value of F. The least function value so far and */
+ /* its index are required. */
+
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* Computing MIN */
+ /* Computing MAX */
+ d__3 = xl[j], d__4 = xbase[j] + xpt[nf + j * xpt_dim1];
+ d__1 = std::max(d__3,d__4), d__2 = xu[j];
+ x[j] = std::min(d__1,d__2);
+ if (xpt[nf + j * xpt_dim1] == sl[j]) {
+ x[j] = xl[j];
+ }
+ if (xpt[nf + j * xpt_dim1] == su[j]) {
+ x[j] = xu[j];
+ }
+ /* L60: */
+ }
+ f = calfun(mat(&x[1],n));
+ fval[nf] = f;
+ if (nf == 1) {
+ fbeg = f;
+ kopt = 1;
+ } else if (f < fval[kopt]) {
+ kopt = nf;
+ }
+
+ /* Set the nonzero initial elements of BMAT and the quadratic model in the */
+ /* cases when NF is at most 2*N+1. If NF exceeds N+1, then the positions */
+ /* of the NF-th and (NF-N)-th interpolation points may be switched, in */
+ /* order that the function value at the first of them contributes to the */
+ /* off-diagonal second derivative terms of the initial quadratic model. */
+
+ if (nf <= (n << 1) + 1) {
+ if (nf >= 2 && nf <= n + 1) {
+ gopt[nfm] = (f - fbeg) / stepa;
+ if (npt < nf + n) {
+ bmat[nfm * bmat_dim1 + 1] = -one / stepa;
+ bmat[nf + nfm * bmat_dim1] = one / stepa;
+ bmat[npt + nfm + nfm * bmat_dim1] = -half * rhosq;
+ }
+ } else if (nf >= n + 2) {
+ ih = nfx * (nfx + 1) / 2;
+ temp = (f - fbeg) / stepb;
+ diff = stepb - stepa;
+ hq[ih] = two * (temp - gopt[nfx]) / diff;
+ gopt[nfx] = (gopt[nfx] * stepb - temp * stepa) / diff;
+ if (stepa * stepb < zero) {
+ if (f < fval[nf - n]) {
+ fval[nf] = fval[nf - n];
+ fval[nf - n] = f;
+ if (kopt == nf) {
+ kopt = nf - n;
+ }
+ xpt[nf - n + nfx * xpt_dim1] = stepb;
+ xpt[nf + nfx * xpt_dim1] = stepa;
+ }
+ }
+ bmat[nfx * bmat_dim1 + 1] = -(stepa + stepb) / (stepa * stepb);
+ bmat[nf + nfx * bmat_dim1] = -half / xpt[nf - n + nfx *
+ xpt_dim1];
+ bmat[nf - n + nfx * bmat_dim1] = -bmat[nfx * bmat_dim1 + 1] -
+ bmat[nf + nfx * bmat_dim1];
+ zmat[nfx * zmat_dim1 + 1] = std::sqrt(two) / (stepa * stepb);
+ zmat[nf + nfx * zmat_dim1] = std::sqrt(half) / rhosq;
+ zmat[nf - n + nfx * zmat_dim1] = -zmat[nfx * zmat_dim1 + 1] -
+ zmat[nf + nfx * zmat_dim1];
+ }
+
+ /* Set the off-diagonal second derivatives of the Lagrange functions and */
+ /* the initial quadratic model. */
+
+ } else {
+ ih = ipt * (ipt - 1) / 2 + jpt;
+ zmat[nfx * zmat_dim1 + 1] = recip;
+ zmat[nf + nfx * zmat_dim1] = recip;
+ zmat[ipt + 1 + nfx * zmat_dim1] = -recip;
+ zmat[jpt + 1 + nfx * zmat_dim1] = -recip;
+ temp = xpt[nf + ipt * xpt_dim1] * xpt[nf + jpt * xpt_dim1];
+ hq[ih] = (fbeg - fval[ipt + 1] - fval[jpt + 1] + f) / temp;
+ }
+ if (nf < npt && nf < maxfun) {
+ goto L50;
+ }
+
+ } /* prelim_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ void rescue_ (
+ const funct& calfun,
+ const integer n,
+ const integer npt,
+ const doublereal *xl,
+ const doublereal *xu,
+ const integer maxfun,
+ doublereal *xbase,
+ doublereal *xpt,
+ doublereal *fval,
+ doublereal *xopt,
+ doublereal *gopt,
+ doublereal *hq,
+ doublereal *pq,
+ doublereal *bmat,
+ doublereal *zmat,
+ const integer ndim,
+ doublereal *sl,
+ doublereal *su,
+ integer& nf,
+ const doublereal delta,
+ integer& kopt,
+ doublereal *vlag,
+ doublereal * ptsaux,
+ doublereal *ptsid,
+ doublereal *w
+ ) const
+ {
+ /* System generated locals */
+ integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1,
+ zmat_offset, i__1, i__2, i__3;
+ doublereal d__1, d__2, d__3, d__4;
+
+
+ /* Local variables */
+ doublereal f;
+ integer i__, j, k, ih, jp, ip, iq, np, iw;
+ doublereal xp = 0, xq = 0, den = 0;
+ integer ihp = 0;
+ doublereal one;
+ integer ihq, jpn, kpt;
+ doublereal sum = 0, diff = 0, half = 0, beta = 0;
+ integer kold;
+ doublereal winc;
+ integer nrem, knew;
+ doublereal temp, bsum;
+ integer nptm;
+ doublereal zero = 0, hdiag = 0, fbase = 0, sfrac = 0, denom = 0, vquad = 0, sumpq = 0;
+ doublereal dsqmin, distsq, vlmxsq;
+
+
+
+ /* The arguments N, NPT, XL, XU, IPRINT, MAXFUN, XBASE, XPT, FVAL, XOPT, */
+ /* GOPT, HQ, PQ, BMAT, ZMAT, NDIM, SL and SU have the same meanings as */
+ /* the corresponding arguments of BOBYQB on the entry to RESCUE. */
+ /* NF is maintained as the number of calls of CALFUN so far, except that */
+ /* NF is set to -1 if the value of MAXFUN prevents further progress. */
+ /* KOPT is maintained so that FVAL(KOPT) is the least calculated function */
+ /* value. Its correct value must be given on entry. It is updated if a */
+ /* new least function value is found, but the corresponding changes to */
+ /* XOPT and GOPT have to be made later by the calling program. */
+ /* DELTA is the current trust region radius. */
+ /* VLAG is a working space vector that will be used for the values of the */
+ /* provisional Lagrange functions at each of the interpolation points. */
+ /* They are part of a product that requires VLAG to be of length NDIM. */
+ /* PTSAUX is also a working space array. For J=1,2,...,N, PTSAUX(1,J) and */
+ /* PTSAUX(2,J) specify the two positions of provisional interpolation */
+ /* points when a nonzero step is taken along e_J (the J-th coordinate */
+ /* direction) through XBASE+XOPT, as specified below. Usually these */
+ /* steps have length DELTA, but other lengths are chosen if necessary */
+ /* in order to satisfy the given bounds on the variables. */
+ /* PTSID is also a working space array. It has NPT components that denote */
+ /* provisional new positions of the original interpolation points, in */
+ /* case changes are needed to restore the linear independence of the */
+ /* interpolation conditions. The K-th point is a candidate for change */
+ /* if and only if PTSID(K) is nonzero. In this case let p and q be the */
+ /* integer parts of PTSID(K) and (PTSID(K)-p) multiplied by N+1. If p */
+ /* and q are both positive, the step from XBASE+XOPT to the new K-th */
+ /* interpolation point is PTSAUX(1,p)*e_p + PTSAUX(1,q)*e_q. Otherwise */
+ /* the step is PTSAUX(1,p)*e_p or PTSAUX(2,q)*e_q in the cases q=0 or */
+ /* p=0, respectively. */
+ /* The first NDIM+NPT elements of the array W are used for working space. */
+ /* The final elements of BMAT and ZMAT are set in a well-conditioned way */
+ /* to the values that are appropriate for the new interpolation points. */
+ /* The elements of GOPT, HQ and PQ are also revised to the values that are */
+ /* appropriate to the final quadratic model. */
+
+ /* Set some constants. */
+
+ /* Parameter adjustments */
+ zmat_dim1 = npt;
+ zmat_offset = 1 + zmat_dim1;
+ zmat -= zmat_offset;
+ xpt_dim1 = npt;
+ xpt_offset = 1 + xpt_dim1;
+ xpt -= xpt_offset;
+ --xl;
+ --xu;
+ --xbase;
+ --fval;
+ --xopt;
+ --gopt;
+ --hq;
+ --pq;
+ bmat_dim1 = ndim;
+ bmat_offset = 1 + bmat_dim1;
+ bmat -= bmat_offset;
+ --sl;
+ --su;
+ --vlag;
+ ptsaux -= 3;
+ --ptsid;
+ --w;
+
+ /* Function Body */
+ half = .5;
+ one = 1.;
+ zero = 0.;
+ np = n + 1;
+ sfrac = half / (doublereal) np;
+ nptm = npt - np;
+
+ /* Shift the interpolation points so that XOPT becomes the origin, and set */
+ /* the elements of ZMAT to zero. The value of SUMPQ is required in the */
+ /* updating of HQ below. The squares of the distances from XOPT to the */
+ /* other interpolation points are set at the end of W. Increments of WINC */
+ /* may be added later to these squares to balance the consideration of */
+ /* the choice of point that is going to become current. */
+
+ sumpq = zero;
+ winc = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ distsq = zero;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ xpt[k + j * xpt_dim1] -= xopt[j];
+ /* L10: */
+ /* Computing 2nd power */
+ d__1 = xpt[k + j * xpt_dim1];
+ distsq += d__1 * d__1;
+ }
+ sumpq += pq[k];
+ w[ndim + k] = distsq;
+ winc = std::max(winc,distsq);
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ /* L20: */
+ zmat[k + j * zmat_dim1] = zero;
+ }
+ }
+
+ /* Update HQ so that HQ and PQ define the second derivatives of the model */
+ /* after XBASE has been shifted to the trust region centre. */
+
+ ih = 0;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ w[j] = half * sumpq * xopt[j];
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L30: */
+ w[j] += pq[k] * xpt[k + j * xpt_dim1];
+ }
+ i__1 = j;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ ++ih;
+ /* L40: */
+ hq[ih] = hq[ih] + w[i__] * xopt[j] + w[j] * xopt[i__];
+ }
+ }
+
+ /* Shift XBASE, SL, SU and XOPT. Set the elements of BMAT to zero, and */
+ /* also set the elements of PTSAUX. */
+
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ xbase[j] += xopt[j];
+ sl[j] -= xopt[j];
+ su[j] -= xopt[j];
+ xopt[j] = zero;
+ /* Computing MIN */
+ d__1 = delta, d__2 = su[j];
+ ptsaux[(j << 1) + 1] = std::min(d__1,d__2);
+ /* Computing MAX */
+ d__1 = -(delta), d__2 = sl[j];
+ ptsaux[(j << 1) + 2] = std::max(d__1,d__2);
+ if (ptsaux[(j << 1) + 1] + ptsaux[(j << 1) + 2] < zero) {
+ temp = ptsaux[(j << 1) + 1];
+ ptsaux[(j << 1) + 1] = ptsaux[(j << 1) + 2];
+ ptsaux[(j << 1) + 2] = temp;
+ }
+ if ((d__2 = ptsaux[(j << 1) + 2], std::abs(d__2)) < half * (d__1 = ptsaux[(
+ j << 1) + 1], std::abs(d__1))) {
+ ptsaux[(j << 1) + 2] = half * ptsaux[(j << 1) + 1];
+ }
+ i__2 = ndim;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L50: */
+ bmat[i__ + j * bmat_dim1] = zero;
+ }
+ }
+ fbase = fval[kopt];
+
+ /* Set the identifiers of the artificial interpolation points that are */
+ /* along a coordinate direction from XOPT, and set the corresponding */
+ /* nonzero elements of BMAT and ZMAT. */
+
+ ptsid[1] = sfrac;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ jp = j + 1;
+ jpn = jp + n;
+ ptsid[jp] = (doublereal) j + sfrac;
+ if (jpn <= npt) {
+ ptsid[jpn] = (doublereal) j / (doublereal) np + sfrac;
+ temp = one / (ptsaux[(j << 1) + 1] - ptsaux[(j << 1) + 2]);
+ bmat[jp + j * bmat_dim1] = -temp + one / ptsaux[(j << 1) + 1];
+ bmat[jpn + j * bmat_dim1] = temp + one / ptsaux[(j << 1) + 2];
+ bmat[j * bmat_dim1 + 1] = -bmat[jp + j * bmat_dim1] - bmat[jpn +
+ j * bmat_dim1];
+ zmat[j * zmat_dim1 + 1] = std::sqrt(2.) / (d__1 = ptsaux[(j << 1) + 1]
+ * ptsaux[(j << 1) + 2], std::abs(d__1));
+ zmat[jp + j * zmat_dim1] = zmat[j * zmat_dim1 + 1] * ptsaux[(j <<
+ 1) + 2] * temp;
+ zmat[jpn + j * zmat_dim1] = -zmat[j * zmat_dim1 + 1] * ptsaux[(j
+ << 1) + 1] * temp;
+ } else {
+ bmat[j * bmat_dim1 + 1] = -one / ptsaux[(j << 1) + 1];
+ bmat[jp + j * bmat_dim1] = one / ptsaux[(j << 1) + 1];
+ /* Computing 2nd power */
+ d__1 = ptsaux[(j << 1) + 1];
+ bmat[j + npt + j * bmat_dim1] = -half * (d__1 * d__1);
+ }
+ /* L60: */
+ }
+
+ /* Set any remaining identifiers with their nonzero elements of ZMAT. */
+
+ if (npt >= n + np) {
+ i__2 = npt;
+ for (k = np << 1; k <= i__2; ++k) {
+ iw = (integer) (((doublereal) (k - np) - half) / (doublereal) (n)
+ );
+ ip = k - np - iw * n;
+ iq = ip + iw;
+ if (iq > n) {
+ iq -= n;
+ }
+ ptsid[k] = (doublereal) ip + (doublereal) iq / (doublereal) np +
+ sfrac;
+ temp = one / (ptsaux[(ip << 1) + 1] * ptsaux[(iq << 1) + 1]);
+ zmat[(k - np) * zmat_dim1 + 1] = temp;
+ zmat[ip + 1 + (k - np) * zmat_dim1] = -temp;
+ zmat[iq + 1 + (k - np) * zmat_dim1] = -temp;
+ /* L70: */
+ zmat[k + (k - np) * zmat_dim1] = temp;
+ }
+ }
+ nrem = npt;
+ kold = 1;
+ knew = kopt;
+
+ /* Reorder the provisional points in the way that exchanges PTSID(KOLD) */
+ /* with PTSID(KNEW). */
+
+L80:
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ temp = bmat[kold + j * bmat_dim1];
+ bmat[kold + j * bmat_dim1] = bmat[knew + j * bmat_dim1];
+ /* L90: */
+ bmat[knew + j * bmat_dim1] = temp;
+ }
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ temp = zmat[kold + j * zmat_dim1];
+ zmat[kold + j * zmat_dim1] = zmat[knew + j * zmat_dim1];
+ /* L100: */
+ zmat[knew + j * zmat_dim1] = temp;
+ }
+ ptsid[kold] = ptsid[knew];
+ ptsid[knew] = zero;
+ w[ndim + knew] = zero;
+ --nrem;
+ if (knew != kopt) {
+ temp = vlag[kold];
+ vlag[kold] = vlag[knew];
+ vlag[knew] = temp;
+
+ /* Update the BMAT and ZMAT matrices so that the status of the KNEW-th */
+ /* interpolation point can be changed from provisional to original. The */
+ /* branch to label 350 occurs if all the original points are reinstated. */
+ /* The nonnegative values of W(NDIM+K) are required in the search below. */
+
+ update_(n, npt, &bmat[bmat_offset], &zmat[zmat_offset], ndim, &vlag[1],
+ beta, denom, knew, &w[1]);
+ if (nrem == 0) {
+ goto L350;
+ }
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L110: */
+ w[ndim + k] = (d__1 = w[ndim + k], std::abs(d__1));
+ }
+ }
+
+ /* Pick the index KNEW of an original interpolation point that has not */
+ /* yet replaced one of the provisional interpolation points, giving */
+ /* attention to the closeness to XOPT and to previous tries with KNEW. */
+
+L120:
+ dsqmin = zero;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ if (w[ndim + k] > zero) {
+ if (dsqmin == zero || w[ndim + k] < dsqmin) {
+ knew = k;
+ dsqmin = w[ndim + k];
+ }
+ }
+ /* L130: */
+ }
+ if (dsqmin == zero) {
+ goto L260;
+ }
+
+ /* Form the W-vector of the chosen original interpolation point. */
+
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ /* L140: */
+ w[npt + j] = xpt[knew + j * xpt_dim1];
+ }
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ sum = zero;
+ if (k == kopt) {
+ } else if (ptsid[k] == zero) {
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L150: */
+ sum += w[npt + j] * xpt[k + j * xpt_dim1];
+ }
+ } else {
+ ip = (integer) ptsid[k];
+ if (ip > 0) {
+ sum = w[npt + ip] * ptsaux[(ip << 1) + 1];
+ }
+ iq = (integer) ((doublereal) np * ptsid[k] - (doublereal) (ip *
+ np));
+ if (iq > 0) {
+ iw = 1;
+ if (ip == 0) {
+ iw = 2;
+ }
+ sum += w[npt + iq] * ptsaux[iw + (iq << 1)];
+ }
+ }
+ /* L160: */
+ w[k] = half * sum * sum;
+ }
+
+ /* Calculate VLAG and BETA for the required updating of the H matrix if */
+ /* XPT(KNEW,.) is reinstated in the set of interpolation points. */
+
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ sum = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L170: */
+ sum += bmat[k + j * bmat_dim1] * w[npt + j];
+ }
+ /* L180: */
+ vlag[k] = sum;
+ }
+ beta = zero;
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ sum = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L190: */
+ sum += zmat[k + j * zmat_dim1] * w[k];
+ }
+ beta -= sum * sum;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ /* L200: */
+ vlag[k] += sum * zmat[k + j * zmat_dim1];
+ }
+ }
+ bsum = zero;
+ distsq = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ sum = zero;
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ /* L210: */
+ sum += bmat[k + j * bmat_dim1] * w[k];
+ }
+ jp = j + npt;
+ bsum += sum * w[jp];
+ i__2 = ndim;
+ for (ip = npt + 1; ip <= i__2; ++ip) {
+ /* L220: */
+ sum += bmat[ip + j * bmat_dim1] * w[ip];
+ }
+ bsum += sum * w[jp];
+ vlag[jp] = sum;
+ /* L230: */
+ /* Computing 2nd power */
+ d__1 = xpt[knew + j * xpt_dim1];
+ distsq += d__1 * d__1;
+ }
+ beta = half * distsq * distsq + beta - bsum;
+ vlag[kopt] += one;
+
+ /* KOLD is set to the index of the provisional interpolation point that is */
+ /* going to be deleted to make way for the KNEW-th original interpolation */
+ /* point. The choice of KOLD is governed by the avoidance of a small value */
+ /* of the denominator in the updating calculation of UPDATE. */
+
+ denom = zero;
+ vlmxsq = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ if (ptsid[k] != zero) {
+ hdiag = zero;
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ /* L240: */
+ /* Computing 2nd power */
+ d__1 = zmat[k + j * zmat_dim1];
+ hdiag += d__1 * d__1;
+ }
+ /* Computing 2nd power */
+ d__1 = vlag[k];
+ den = beta * hdiag + d__1 * d__1;
+ if (den > denom) {
+ kold = k;
+ denom = den;
+ }
+ }
+ /* L250: */
+ /* Computing MAX */
+ /* Computing 2nd power */
+ d__3 = vlag[k];
+ d__1 = vlmxsq, d__2 = d__3 * d__3;
+ vlmxsq = std::max(d__1,d__2);
+ }
+ if (denom <= vlmxsq * .01) {
+ w[ndim + knew] = -w[ndim + knew] - winc;
+ goto L120;
+ }
+ goto L80;
+
+ /* When label 260 is reached, all the final positions of the interpolation */
+ /* points have been chosen although any changes have not been included yet */
+ /* in XPT. Also the final BMAT and ZMAT matrices are complete, but, apart */
+ /* from the shift of XBASE, the updating of the quadratic model remains to */
+ /* be done. The following cycle through the new interpolation points begins */
+ /* by putting the new point in XPT(KPT,.) and by setting PQ(KPT) to zero, */
+ /* except that a RETURN occurs if MAXFUN prohibits another value of F. */
+
+L260:
+ i__1 = npt;
+ for (kpt = 1; kpt <= i__1; ++kpt) {
+ if (ptsid[kpt] == zero) {
+ goto L340;
+ }
+ if (nf >= maxfun) {
+ nf = -1;
+ goto L350;
+ }
+ ih = 0;
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ w[j] = xpt[kpt + j * xpt_dim1];
+ xpt[kpt + j * xpt_dim1] = zero;
+ temp = pq[kpt] * w[j];
+ i__3 = j;
+ for (i__ = 1; i__ <= i__3; ++i__) {
+ ++ih;
+ /* L270: */
+ hq[ih] += temp * w[i__];
+ }
+ }
+ pq[kpt] = zero;
+ ip = (integer) ptsid[kpt];
+ iq = (integer) ((doublereal) np * ptsid[kpt] - (doublereal) (ip * np))
+ ;
+ if (ip > 0) {
+ xp = ptsaux[(ip << 1) + 1];
+ xpt[kpt + ip * xpt_dim1] = xp;
+ }
+ if (iq > 0) {
+ xq = ptsaux[(iq << 1) + 1];
+ if (ip == 0) {
+ xq = ptsaux[(iq << 1) + 2];
+ }
+ xpt[kpt + iq * xpt_dim1] = xq;
+ }
+
+ /* Set VQUAD to the value of the current model at the new point. */
+
+ vquad = fbase;
+ if (ip > 0) {
+ ihp = (ip + ip * ip) / 2;
+ vquad += xp * (gopt[ip] + half * xp * hq[ihp]);
+ }
+ if (iq > 0) {
+ ihq = (iq + iq * iq) / 2;
+ vquad += xq * (gopt[iq] + half * xq * hq[ihq]);
+ if (ip > 0) {
+ iw = std::max(ihp,ihq) - (i__3 = ip - iq, std::abs(i__3));
+ vquad += xp * xq * hq[iw];
+ }
+ }
+ i__3 = npt;
+ for (k = 1; k <= i__3; ++k) {
+ temp = zero;
+ if (ip > 0) {
+ temp += xp * xpt[k + ip * xpt_dim1];
+ }
+ if (iq > 0) {
+ temp += xq * xpt[k + iq * xpt_dim1];
+ }
+ /* L280: */
+ vquad += half * pq[k] * temp * temp;
+ }
+
+ /* Calculate F at the new interpolation point, and set DIFF to the factor */
+ /* that is going to multiply the KPT-th Lagrange function when the model */
+ /* is updated to provide interpolation to the new function value. */
+
+ i__3 = n;
+ for (i__ = 1; i__ <= i__3; ++i__) {
+ /* Computing MIN */
+ /* Computing MAX */
+ d__3 = xl[i__], d__4 = xbase[i__] + xpt[kpt + i__ * xpt_dim1];
+ d__1 = std::max(d__3,d__4), d__2 = xu[i__];
+ w[i__] = std::min(d__1,d__2);
+ if (xpt[kpt + i__ * xpt_dim1] == sl[i__]) {
+ w[i__] = xl[i__];
+ }
+ if (xpt[kpt + i__ * xpt_dim1] == su[i__]) {
+ w[i__] = xu[i__];
+ }
+ /* L290: */
+ }
+ ++(nf);
+ f = calfun(mat(&w[1],n));
+ fval[kpt] = f;
+ if (f < fval[kopt]) {
+ kopt = kpt;
+ }
+ diff = f - vquad;
+
+ /* Update the quadratic model. The RETURN from the subroutine occurs when */
+ /* all the new interpolation points are included in the model. */
+
+ i__3 = n;
+ for (i__ = 1; i__ <= i__3; ++i__) {
+ /* L310: */
+ gopt[i__] += diff * bmat[kpt + i__ * bmat_dim1];
+ }
+ i__3 = npt;
+ for (k = 1; k <= i__3; ++k) {
+ sum = zero;
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ /* L320: */
+ sum += zmat[k + j * zmat_dim1] * zmat[kpt + j * zmat_dim1];
+ }
+ temp = diff * sum;
+ if (ptsid[k] == zero) {
+ pq[k] += temp;
+ } else {
+ ip = (integer) ptsid[k];
+ iq = (integer) ((doublereal) np * ptsid[k] - (doublereal) (ip
+ * np));
+ ihq = (iq * iq + iq) / 2;
+ if (ip == 0) {
+ /* Computing 2nd power */
+ d__1 = ptsaux[(iq << 1) + 2];
+ hq[ihq] += temp * (d__1 * d__1);
+ } else {
+ ihp = (ip * ip + ip) / 2;
+ /* Computing 2nd power */
+ d__1 = ptsaux[(ip << 1) + 1];
+ hq[ihp] += temp * (d__1 * d__1);
+ if (iq > 0) {
+ /* Computing 2nd power */
+ d__1 = ptsaux[(iq << 1) + 1];
+ hq[ihq] += temp * (d__1 * d__1);
+ iw = std::max(ihp,ihq) - (i__2 = iq - ip, std::abs(i__2));
+ hq[iw] += temp * ptsaux[(ip << 1) + 1] * ptsaux[(iq <<
+ 1) + 1];
+ }
+ }
+ }
+ /* L330: */
+ }
+ ptsid[kpt] = zero;
+L340:
+ ;
+ }
+L350:
+ ;
+ } /* rescue_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ void trsbox_(
+ const integer n,
+ const integer npt,
+ const doublereal *xpt,
+ const doublereal *xopt,
+ const doublereal *gopt,
+ const doublereal *hq,
+ const doublereal *pq,
+ const doublereal *sl,
+ const doublereal *su,
+ const doublereal delta,
+ doublereal *xnew,
+ doublereal *d__,
+ doublereal *gnew,
+ doublereal *xbdi,
+ doublereal *s,
+ doublereal *hs,
+ doublereal *hred,
+ doublereal *dsq,
+ doublereal *crvmin
+ ) const
+ {
+ /* System generated locals */
+ integer xpt_dim1, xpt_offset, i__1, i__2;
+ doublereal d__1, d__2, d__3, d__4;
+
+ /* Local variables */
+ integer i__, j, k, ih;
+ doublereal ds;
+ integer iu;
+ doublereal dhd, dhs, cth, one, shs, sth, ssq, half, beta, sdec, blen;
+ integer iact = 0, nact = 0;
+ doublereal angt, qred;
+ integer isav;
+ doublereal temp = 0, zero = 0, xsav = 0, xsum = 0, angbd = 0, dredg = 0, sredg = 0;
+ integer iterc;
+ doublereal resid = 0, delsq = 0, ggsav = 0, tempa = 0, tempb = 0,
+ redmax = 0, dredsq = 0, redsav = 0, onemin = 0, gredsq = 0, rednew = 0;
+ integer itcsav = 0;
+ doublereal rdprev = 0, rdnext = 0, stplen = 0, stepsq = 0;
+ integer itermax = 0;
+
+
+ /* The arguments N, NPT, XPT, XOPT, GOPT, HQ, PQ, SL and SU have the same */
+ /* meanings as the corresponding arguments of BOBYQB. */
+ /* DELTA is the trust region radius for the present calculation, which */
+ /* seeks a small value of the quadratic model within distance DELTA of */
+ /* XOPT subject to the bounds on the variables. */
+ /* XNEW will be set to a new vector of variables that is approximately */
+ /* the one that minimizes the quadratic model within the trust region */
+ /* subject to the SL and SU constraints on the variables. It satisfies */
+ /* as equations the bounds that become active during the calculation. */
+ /* D is the calculated trial step from XOPT, generated iteratively from an */
+ /* initial value of zero. Thus XNEW is XOPT+D after the final iteration. */
+ /* GNEW holds the gradient of the quadratic model at XOPT+D. It is updated */
+ /* when D is updated. */
+ /* XBDI is a working space vector. For I=1,2,...,N, the element XBDI(I) is */
+ /* set to -1.0, 0.0, or 1.0, the value being nonzero if and only if the */
+ /* I-th variable has become fixed at a bound, the bound being SL(I) or */
+ /* SU(I) in the case XBDI(I)=-1.0 or XBDI(I)=1.0, respectively. This */
+ /* information is accumulated during the construction of XNEW. */
+ /* The arrays S, HS and HRED are also used for working space. They hold the */
+ /* current search direction, and the changes in the gradient of Q along S */
+ /* and the reduced D, respectively, where the reduced D is the same as D, */
+ /* except that the components of the fixed variables are zero. */
+ /* DSQ will be set to the square of the length of XNEW-XOPT. */
+ /* CRVMIN is set to zero if D reaches the trust region boundary. Otherwise */
+ /* it is set to the least curvature of H that occurs in the conjugate */
+ /* gradient searches that are not restricted by any constraints. The */
+ /* value CRVMIN=-1.0D0 is set, however, if all of these searches are */
+ /* constrained. */
+
+ /* A version of the truncated conjugate gradient is applied. If a line */
+ /* search is restricted by a constraint, then the procedure is restarted, */
+ /* the values of the variables that are at their bounds being fixed. If */
+ /* the trust region boundary is reached, then further changes may be made */
+ /* to D, each one being in the two dimensional space that is spanned */
+ /* by the current D and the gradient of Q at XOPT+D, staying on the trust */
+ /* region boundary. Termination occurs when the reduction in Q seems to */
+ /* be close to the greatest reduction that can be achieved. */
+
+ /* Set some constants. */
+
+ /* Parameter adjustments */
+ xpt_dim1 = npt;
+ xpt_offset = 1 + xpt_dim1;
+ xpt -= xpt_offset;
+ --xopt;
+ --gopt;
+ --hq;
+ --pq;
+ --sl;
+ --su;
+ --xnew;
+ --d__;
+ --gnew;
+ --xbdi;
+ --s;
+ --hs;
+ --hred;
+
+ /* Function Body */
+ half = .5;
+ one = 1.;
+ onemin = -1.;
+ zero = 0.;
+
+ /* The sign of GOPT(I) gives the sign of the change to the I-th variable */
+ /* that will reduce Q from its value at XOPT. Thus XBDI(I) shows whether */
+ /* or not to fix the I-th variable at one of its bounds initially, with */
+ /* NACT being set to the number of fixed variables. D and GNEW are also */
+ /* set for the first iteration. DELSQ is the upper bound on the sum of */
+ /* squares of the free variables. QRED is the reduction in Q so far. */
+
+ iterc = 0;
+ nact = 0;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ xbdi[i__] = zero;
+ if (xopt[i__] <= sl[i__]) {
+ if (gopt[i__] >= zero) {
+ xbdi[i__] = onemin;
+ }
+ } else if (xopt[i__] >= su[i__]) {
+ if (gopt[i__] <= zero) {
+ xbdi[i__] = one;
+ }
+ }
+ if (xbdi[i__] != zero) {
+ ++nact;
+ }
+ d__[i__] = zero;
+ /* L10: */
+ gnew[i__] = gopt[i__];
+ }
+ delsq = delta * delta;
+ qred = zero;
+ *crvmin = onemin;
+
+ /* Set the next search direction of the conjugate gradient method. It is */
+ /* the steepest descent direction initially and when the iterations are */
+ /* restarted because a variable has just been fixed by a bound, and of */
+ /* course the components of the fixed variables are zero. ITERMAX is an */
+ /* upper bound on the indices of the conjugate gradient iterations. */
+
+L20:
+ beta = zero;
+L30:
+ stepsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] != zero) {
+ s[i__] = zero;
+ } else if (beta == zero) {
+ s[i__] = -gnew[i__];
+ } else {
+ s[i__] = beta * s[i__] - gnew[i__];
+ }
+ /* L40: */
+ /* Computing 2nd power */
+ d__1 = s[i__];
+ stepsq += d__1 * d__1;
+ }
+ if (stepsq == zero) {
+ goto L190;
+ }
+ if (beta == zero) {
+ gredsq = stepsq;
+ itermax = iterc + n - nact;
+ }
+ if (gredsq * delsq <= qred * 1e-4 * qred) {
+ goto L190;
+ }
+
+ /* Multiply the search direction by the second derivative matrix of Q and */
+ /* calculate some scalars for the choice of steplength. Then set BLEN to */
+ /* the length of the the step to the trust region boundary and STPLEN to */
+ /* the steplength, ignoring the simple bounds. */
+
+ goto L210;
+L50:
+ resid = delsq;
+ ds = zero;
+ shs = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] == zero) {
+ /* Computing 2nd power */
+ d__1 = d__[i__];
+ resid -= d__1 * d__1;
+ ds += s[i__] * d__[i__];
+ shs += s[i__] * hs[i__];
+ }
+ /* L60: */
+ }
+ if (resid <= zero) {
+ goto L90;
+ }
+ temp = std::sqrt(stepsq * resid + ds * ds);
+ if (ds < zero) {
+ blen = (temp - ds) / stepsq;
+ } else {
+ blen = resid / (temp + ds);
+ }
+ stplen = blen;
+ if (shs > zero) {
+ /* Computing MIN */
+ d__1 = blen, d__2 = gredsq / shs;
+ stplen = std::min(d__1,d__2);
+ }
+
+ /* Reduce STPLEN if necessary in order to preserve the simple bounds, */
+ /* letting IACT be the index of the new constrained variable. */
+
+ iact = 0;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (s[i__] != zero) {
+ xsum = xopt[i__] + d__[i__];
+ if (s[i__] > zero) {
+ temp = (su[i__] - xsum) / s[i__];
+ } else {
+ temp = (sl[i__] - xsum) / s[i__];
+ }
+ if (temp < stplen) {
+ stplen = temp;
+ iact = i__;
+ }
+ }
+ /* L70: */
+ }
+
+ /* Update CRVMIN, GNEW and D. Set SDEC to the decrease that occurs in Q. */
+
+ sdec = zero;
+ if (stplen > zero) {
+ ++iterc;
+ temp = shs / stepsq;
+ if (iact == 0 && temp > zero) {
+ *crvmin = std::min(*crvmin,temp);
+ if (*crvmin == onemin) {
+ *crvmin = temp;
+ }
+ }
+ ggsav = gredsq;
+ gredsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ gnew[i__] += stplen * hs[i__];
+ if (xbdi[i__] == zero) {
+ /* Computing 2nd power */
+ d__1 = gnew[i__];
+ gredsq += d__1 * d__1;
+ }
+ /* L80: */
+ d__[i__] += stplen * s[i__];
+ }
+ /* Computing MAX */
+ d__1 = stplen * (ggsav - half * stplen * shs);
+ sdec = std::max(d__1,zero);
+ qred += sdec;
+ }
+
+ /* Restart the conjugate gradient method if it has hit a new bound. */
+
+ if (iact > 0) {
+ ++nact;
+ xbdi[iact] = one;
+ if (s[iact] < zero) {
+ xbdi[iact] = onemin;
+ }
+ /* Computing 2nd power */
+ d__1 = d__[iact];
+ delsq -= d__1 * d__1;
+ if (delsq <= zero) {
+ goto L90;
+ }
+ goto L20;
+ }
+
+ /* If STPLEN is less than BLEN, then either apply another conjugate */
+ /* gradient iteration or RETURN. */
+
+ if (stplen < blen) {
+ if (iterc == itermax) {
+ goto L190;
+ }
+ if (sdec <= qred * .01) {
+ goto L190;
+ }
+ beta = gredsq / ggsav;
+ goto L30;
+ }
+L90:
+ *crvmin = zero;
+
+ /* Prepare for the alternative iteration by calculating some scalars */
+ /* and by multiplying the reduced D by the second derivative matrix of */
+ /* Q, where S holds the reduced D in the call of GGMULT. */
+
+L100:
+ if (nact >= n - 1) {
+ goto L190;
+ }
+ dredsq = zero;
+ dredg = zero;
+ gredsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] == zero) {
+ /* Computing 2nd power */
+ d__1 = d__[i__];
+ dredsq += d__1 * d__1;
+ dredg += d__[i__] * gnew[i__];
+ /* Computing 2nd power */
+ d__1 = gnew[i__];
+ gredsq += d__1 * d__1;
+ s[i__] = d__[i__];
+ } else {
+ s[i__] = zero;
+ }
+ /* L110: */
+ }
+ itcsav = iterc;
+ goto L210;
+
+ /* Let the search direction S be a linear combination of the reduced D */
+ /* and the reduced G that is orthogonal to the reduced D. */
+
+L120:
+ ++iterc;
+ temp = gredsq * dredsq - dredg * dredg;
+ if (temp <= qred * 1e-4 * qred) {
+ goto L190;
+ }
+ temp = std::sqrt(temp);
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] == zero) {
+ s[i__] = (dredg * d__[i__] - dredsq * gnew[i__]) / temp;
+ } else {
+ s[i__] = zero;
+ }
+ /* L130: */
+ }
+ sredg = -temp;
+
+ /* By considering the simple bounds on the variables, calculate an upper */
+ /* bound on the tangent of half the angle of the alternative iteration, */
+ /* namely ANGBD, except that, if already a free variable has reached a */
+ /* bound, there is a branch back to label 100 after fixing that variable. */
+
+ angbd = one;
+ iact = 0;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] == zero) {
+ tempa = xopt[i__] + d__[i__] - sl[i__];
+ tempb = su[i__] - xopt[i__] - d__[i__];
+ if (tempa <= zero) {
+ ++nact;
+ xbdi[i__] = onemin;
+ goto L100;
+ } else if (tempb <= zero) {
+ ++nact;
+ xbdi[i__] = one;
+ goto L100;
+ }
+ /* Computing 2nd power */
+ d__1 = d__[i__];
+ /* Computing 2nd power */
+ d__2 = s[i__];
+ ssq = d__1 * d__1 + d__2 * d__2;
+ /* Computing 2nd power */
+ d__1 = xopt[i__] - sl[i__];
+ temp = ssq - d__1 * d__1;
+ if (temp > zero) {
+ temp = std::sqrt(temp) - s[i__];
+ if (angbd * temp > tempa) {
+ angbd = tempa / temp;
+ iact = i__;
+ xsav = onemin;
+ }
+ }
+ /* Computing 2nd power */
+ d__1 = su[i__] - xopt[i__];
+ temp = ssq - d__1 * d__1;
+ if (temp > zero) {
+ temp = std::sqrt(temp) + s[i__];
+ if (angbd * temp > tempb) {
+ angbd = tempb / temp;
+ iact = i__;
+ xsav = one;
+ }
+ }
+ }
+ /* L140: */
+ }
+
+ /* Calculate HHD and some curvatures for the alternative iteration. */
+
+ goto L210;
+L150:
+ shs = zero;
+ dhs = zero;
+ dhd = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ if (xbdi[i__] == zero) {
+ shs += s[i__] * hs[i__];
+ dhs += d__[i__] * hs[i__];
+ dhd += d__[i__] * hred[i__];
+ }
+ /* L160: */
+ }
+
+ /* Seek the greatest reduction in Q for a range of equally spaced values */
+ /* of ANGT in [0,ANGBD], where ANGT is the tangent of half the angle of */
+ /* the alternative iteration. */
+
+ redmax = zero;
+ isav = 0;
+ redsav = zero;
+ iu = (integer) (angbd * 17. + 3.1);
+ i__1 = iu;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ angt = angbd * (doublereal) i__ / (doublereal) iu;
+ sth = (angt + angt) / (one + angt * angt);
+ temp = shs + angt * (angt * dhd - dhs - dhs);
+ rednew = sth * (angt * dredg - sredg - half * sth * temp);
+ if (rednew > redmax) {
+ redmax = rednew;
+ isav = i__;
+ rdprev = redsav;
+ } else if (i__ == isav + 1) {
+ rdnext = rednew;
+ }
+ /* L170: */
+ redsav = rednew;
+ }
+
+ /* Return if the reduction is zero. Otherwise, set the sine and cosine */
+ /* of the angle of the alternative iteration, and calculate SDEC. */
+
+ if (isav == 0) {
+ goto L190;
+ }
+ if (isav < iu) {
+ temp = (rdnext - rdprev) / (redmax + redmax - rdprev - rdnext);
+ angt = angbd * ((doublereal) isav + half * temp) / (doublereal) iu;
+ }
+ cth = (one - angt * angt) / (one + angt * angt);
+ sth = (angt + angt) / (one + angt * angt);
+ temp = shs + angt * (angt * dhd - dhs - dhs);
+ sdec = sth * (angt * dredg - sredg - half * sth * temp);
+ if (sdec <= zero) {
+ goto L190;
+ }
+
+ /* Update GNEW, D and HRED. If the angle of the alternative iteration */
+ /* is restricted by a bound on a free variable, that variable is fixed */
+ /* at the bound. */
+
+ dredg = zero;
+ gredsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ gnew[i__] = gnew[i__] + (cth - one) * hred[i__] + sth * hs[i__];
+ if (xbdi[i__] == zero) {
+ d__[i__] = cth * d__[i__] + sth * s[i__];
+ dredg += d__[i__] * gnew[i__];
+ /* Computing 2nd power */
+ d__1 = gnew[i__];
+ gredsq += d__1 * d__1;
+ }
+ /* L180: */
+ hred[i__] = cth * hred[i__] + sth * hs[i__];
+ }
+ qred += sdec;
+ if (iact > 0 && isav == iu) {
+ ++nact;
+ xbdi[iact] = xsav;
+ goto L100;
+ }
+
+ /* If SDEC is sufficiently small, then RETURN after setting XNEW to */
+ /* XOPT+D, giving careful attention to the bounds. */
+
+ if (sdec > qred * .01) {
+ goto L120;
+ }
+L190:
+ *dsq = zero;
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* Computing MAX */
+ /* Computing MIN */
+ d__3 = xopt[i__] + d__[i__], d__4 = su[i__];
+ d__1 = std::min(d__3,d__4), d__2 = sl[i__];
+ xnew[i__] = std::max(d__1,d__2);
+ if (xbdi[i__] == onemin) {
+ xnew[i__] = sl[i__];
+ }
+ if (xbdi[i__] == one) {
+ xnew[i__] = su[i__];
+ }
+ d__[i__] = xnew[i__] - xopt[i__];
+ /* L200: */
+ /* Computing 2nd power */
+ d__1 = d__[i__];
+ *dsq += d__1 * d__1;
+ }
+ return;
+ /* The following instructions multiply the current S-vector by the second */
+ /* derivative matrix of the quadratic model, putting the product in HS. */
+ /* They are reached from three different parts of the software above and */
+ /* they can be regarded as an external subroutine. */
+
+L210:
+ ih = 0;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ hs[j] = zero;
+ i__2 = j;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ ++ih;
+ if (i__ < j) {
+ hs[j] += hq[ih] * s[i__];
+ }
+ /* L220: */
+ hs[i__] += hq[ih] * s[j];
+ }
+ }
+ i__2 = npt;
+ for (k = 1; k <= i__2; ++k) {
+ if (pq[k] != zero) {
+ temp = zero;
+ i__1 = n;
+ for (j = 1; j <= i__1; ++j) {
+ /* L230: */
+ temp += xpt[k + j * xpt_dim1] * s[j];
+ }
+ temp *= pq[k];
+ i__1 = n;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ /* L240: */
+ hs[i__] += temp * xpt[k + i__ * xpt_dim1];
+ }
+ }
+ /* L250: */
+ }
+ if (*crvmin != zero) {
+ goto L50;
+ }
+ if (iterc > itcsav) {
+ goto L150;
+ }
+ i__2 = n;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L260: */
+ hred[i__] = hs[i__];
+ }
+ goto L120;
+ } /* trsbox_ */
+
+ // ----------------------------------------------------------------------------------------
+
+ void update_(
+ const integer n,
+ const integer npt,
+ doublereal *bmat,
+ doublereal *zmat,
+ const integer ndim,
+ doublereal *vlag,
+ const doublereal beta,
+ const doublereal denom,
+ const integer knew,
+ doublereal *w
+ ) const
+ {
+ /* System generated locals */
+ integer bmat_dim1, bmat_offset, zmat_dim1, zmat_offset, i__1, i__2;
+ doublereal d__1, d__2, d__3;
+
+ /* Local variables */
+ integer i__, j, k, jp;
+ doublereal one, tau, temp;
+ integer nptm;
+ doublereal zero, alpha, tempa, tempb, ztest;
+
+
+ /* The arrays BMAT and ZMAT are updated, as required by the new position */
+ /* of the interpolation point that has the index KNEW. The vector VLAG has */
+ /* N+NPT components, set on entry to the first NPT and last N components */
+ /* of the product Hw in equation (4.11) of the Powell (2006) paper on */
+ /* NEWUOA. Further, BETA is set on entry to the value of the parameter */
+ /* with that name, and DENOM is set to the denominator of the updating */
+ /* formula. Elements of ZMAT may be treated as zero if their moduli are */
+ /* at most ZTEST. The first NDIM elements of W are used for working space. */
+
+ /* Set some constants. */
+
+ /* Parameter adjustments */
+ zmat_dim1 = npt;
+ zmat_offset = 1 + zmat_dim1;
+ zmat -= zmat_offset;
+ bmat_dim1 = ndim;
+ bmat_offset = 1 + bmat_dim1;
+ bmat -= bmat_offset;
+ --vlag;
+ --w;
+
+ /* Function Body */
+ one = 1.;
+ zero = 0.;
+ nptm = npt - n - 1;
+ ztest = zero;
+ i__1 = npt;
+ for (k = 1; k <= i__1; ++k) {
+ i__2 = nptm;
+ for (j = 1; j <= i__2; ++j) {
+ /* L10: */
+ /* Computing MAX */
+ d__2 = ztest, d__3 = (d__1 = zmat[k + j * zmat_dim1], std::abs(d__1));
+ ztest = std::max(d__2,d__3);
+ }
+ }
+ ztest *= 1e-20;
+
+ /* Apply the rotations that put zeros in the KNEW-th row of ZMAT. */
+
+ i__2 = nptm;
+ for (j = 2; j <= i__2; ++j) {
+ if ((d__1 = zmat[knew + j * zmat_dim1], std::abs(d__1)) > ztest) {
+ /* Computing 2nd power */
+ d__1 = zmat[knew + zmat_dim1];
+ /* Computing 2nd power */
+ d__2 = zmat[knew + j * zmat_dim1];
+ temp = std::sqrt(d__1 * d__1 + d__2 * d__2);
+ tempa = zmat[knew + zmat_dim1] / temp;
+ tempb = zmat[knew + j * zmat_dim1] / temp;
+ i__1 = npt;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ temp = tempa * zmat[i__ + zmat_dim1] + tempb * zmat[i__ + j *
+ zmat_dim1];
+ zmat[i__ + j * zmat_dim1] = tempa * zmat[i__ + j * zmat_dim1]
+ - tempb * zmat[i__ + zmat_dim1];
+ /* L20: */
+ zmat[i__ + zmat_dim1] = temp;
+ }
+ }
+ zmat[knew + j * zmat_dim1] = zero;
+ /* L30: */
+ }
+
+ /* Put the first NPT components of the KNEW-th column of HLAG into W, */
+ /* and calculate the parameters of the updating formula. */
+
+ i__2 = npt;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ w[i__] = zmat[knew + zmat_dim1] * zmat[i__ + zmat_dim1];
+ /* L40: */
+ }
+ alpha = w[knew];
+ tau = vlag[knew];
+ vlag[knew] -= one;
+
+ /* Complete the updating of ZMAT. */
+
+ temp = std::sqrt(denom);
+ tempb = zmat[knew + zmat_dim1] / temp;
+ tempa = tau / temp;
+ i__2 = npt;
+ for (i__ = 1; i__ <= i__2; ++i__) {
+ /* L50: */
+ zmat[i__ + zmat_dim1] = tempa * zmat[i__ + zmat_dim1] - tempb * vlag[
+ i__];
+ }
+
+ /* Finally, update the matrix BMAT. */
+
+ i__2 = n;
+ for (j = 1; j <= i__2; ++j) {
+ jp = npt + j;
+ w[jp] = bmat[knew + j * bmat_dim1];
+ tempa = (alpha * vlag[jp] - tau * w[jp]) / denom;
+ tempb = (-(beta) * w[jp] - tau * vlag[jp]) / denom;
+ i__1 = jp;
+ for (i__ = 1; i__ <= i__1; ++i__) {
+ bmat[i__ + j * bmat_dim1] = bmat[i__ + j * bmat_dim1] + tempa *
+ vlag[i__] + tempb * w[i__];
+ if (i__ > npt) {
+ bmat[jp + (i__ - npt) * bmat_dim1] = bmat[i__ + j *
+ bmat_dim1];
+ }
+ /* L60: */
+ }
+ }
+ } /* update_ */
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename T,
+ typename U
+ >
+ double find_min_bobyqa (
+ const funct& f,
+ T& x,
+ long npt,
+ const U& x_lower,
+ const U& x_upper,
+ const double rho_begin,
+ const double rho_end,
+ const long max_f_evals
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ // check the requirements. Also split the assert up so that the error message isn't huge.
+ DLIB_CASSERT(is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) &&
+ x.size() == x_lower.size() && x_lower.size() == x_upper.size() &&
+ x.size() > 1 && max_f_evals > 1,
+ "\tdouble find_min_bobyqa()"
+ << "\n\t Invalid arguments have been given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t is_col_vector(x_lower): " << is_col_vector(x_lower)
+ << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t x_lower.size(): " << x_lower.size()
+ << "\n\t x_upper.size(): " << x_upper.size()
+ << "\n\t max_f_evals: " << max_f_evals
+ );
+
+ DLIB_CASSERT(x.size() + 2 <= npt && npt <= (x.size()+1)*(x.size()+2)/2 &&
+ 0 < rho_end && rho_end < rho_begin &&
+ min(x_upper - x_lower) > 2*rho_begin &&
+ min(x - x_lower) >= 0 && min(x_upper - x) >= 0,
+ "\tdouble find_min_bobyqa()"
+ << "\n\t Invalid arguments have been given to this function"
+ << "\n\t ntp in valid range: " << (x.size() + 2 <= npt && npt <= (x.size()+1)*(x.size()+2)/2)
+ << "\n\t npt: " << npt
+ << "\n\t rho_begin: " << rho_begin
+ << "\n\t rho_end: " << rho_end
+ << "\n\t min(x_upper - x_lower) > 2*rho_begin: " << (min(x_upper - x_lower) > 2*rho_begin)
+ << "\n\t min(x - x_lower) >= 0 && min(x_upper - x) >= 0: " << (min(x - x_lower) >= 0 && min(x_upper - x) >= 0)
+ );
+
+
+ bobyqa_implementation impl;
+ return impl.find_min(f, x, npt, x_lower, x_upper, rho_begin, rho_end, max_f_evals);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename T,
+ typename U
+ >
+ double find_max_bobyqa (
+ const funct& f,
+ T& x,
+ long npt,
+ const U& x_lower,
+ const U& x_upper,
+ const double rho_begin,
+ const double rho_end,
+ const long max_f_evals
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ return -find_min_bobyqa(negate_function(f), x, npt, x_lower, x_upper, rho_begin, rho_end, max_f_evals);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_BOBYQA_Hh_
+
diff --git a/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h b/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h
new file mode 100644
index 000000000..46f9436af
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h
@@ -0,0 +1,120 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_
+#ifdef DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+// ----------------------------------------------------------------------------------------
+
+/*
+ This file defines the dlib interface to the BOBYQA software developed by M.J.D Powell.
+ BOBYQA is a method for optimizing a function in the absence of derivative information.
+ Powell described it as a method that seeks the least value of a function of many
+ variables, by applying a trust region method that forms quadratic models by
+ interpolation. There is usually some freedom in the interpolation conditions,
+ which is taken up by minimizing the Frobenius norm of the change to the second
+ derivative of the model, beginning with the zero matrix. The values of the variables
+ are constrained by upper and lower bounds.
+
+
+ The following paper, published in 2009 by Powell, describes the
+ detailed working of the BOBYQA algorithm.
+
+ The BOBYQA algorithm for bound constrained optimization
+ without derivatives by M.J.D. Powell
+*/
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ class bobyqa_failure : public error;
+ /*!
+ This is the exception class used by the functions defined in this file.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename T,
+ typename U
+ >
+ double find_min_bobyqa (
+ const funct& f,
+ T& x,
+ long npt,
+ const U& x_lower,
+ const U& x_upper,
+ const double rho_begin,
+ const double rho_end,
+ const long max_f_evals
+ );
+ /*!
+ requires
+ - f(x) must be a valid expression that evaluates to a double
+ - is_col_vector(x) == true
+ - is_col_vector(x_lower) == true
+ - is_col_vector(x_upper) == true
+ - x.size() == x_lower.size() == x_upper.size()
+ - x.size() > 1
+ - x.size() + 2 <= npt <= (x.size()+1)*(x.size()+2)/2
+ - 0 < rho_end < rho_begin
+ - min(x_upper - x_lower) > 2*rho_begin
+ (i.e. the lower and upper bounds on each x element must be larger than 2*rho_begin)
+ - min(x - x_lower) >= 0 && min(x_upper - x) >= 0
+ (i.e. the given x should be within the bounds defined by x_lower and x_upper)
+ - max_f_evals > 1
+ ensures
+ - Performs a constrained minimization of the function f() starting from
+ the initial point x.
+ - The BOBYQA algorithm uses a number of interpolating points to perform its
+ work. The npt argument controls how many points get used. Typically,
+ a good value to use is 2*x.size()+1.
+ - #x == the value of x (within the bounds defined by x_lower and x_upper) that
+ was found to minimize f(). More precisely:
+ - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0
+ - returns f(#x).
+ - rho_begin and rho_end are used as the initial and final values of a trust
+ region radius. Typically, rho_begin should be about one tenth of the greatest
+ expected change to a variable, while rho_end should indicate the accuracy that
+ is required in the final values of the variables.
+ throws
+ - bobyqa_failure
+ This exception is thrown if the algorithm is unable to make progress towards
+ solving the problem. This may occur because the algorithm detects excessive
+ numerical errors or because max_f_evals of f() have occurred without reaching
+ convergence.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename T,
+ typename U
+ >
+ double find_max_bobyqa (
+ const funct& f,
+ T& x,
+ long npt,
+ const U& x_lower,
+ const U& x_upper,
+ const double rho_begin,
+ const double rho_end,
+ const long max_f_evals
+ );
+ /*!
+ This function is identical to the find_min_bobyqa() routine defined above
+ except that it negates the f() function before performing optimization.
+ Thus this function will attempt to find the maximizer of f() rather than
+ the minimizer.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/optimization/optimization_least_squares.h b/ml/dlib/dlib/optimization/optimization_least_squares.h
new file mode 100644
index 000000000..6d12a919d
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_least_squares.h
@@ -0,0 +1,345 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_
+#define DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_
+
+#include "../matrix.h"
+#include "optimization_trust_region.h"
+#include "optimization_least_squares_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename column_vector_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type
+ >
+ class least_squares_function_model
+ {
+ public:
+ least_squares_function_model (
+ const funct_type& f_,
+ const funct_der_type& der_,
+ const vector_type& list_
+ ) : f(f_), der(der_), list(list_)
+ {
+ S = 0;
+ last_f = 0;
+ last_f2 = 0;
+
+ r.set_size(list.size(),1);
+ }
+
+ const funct_type& f;
+ const funct_der_type& der;
+ const vector_type& list;
+
+ typedef typename column_vector_type::type type;
+ typedef typename column_vector_type::mem_manager_type mem_manager_type;
+ typedef typename column_vector_type::layout_type layout_type;
+ const static long NR = column_vector_type::NR;
+
+ typedef column_vector_type column_vector;
+ typedef matrix<type,NR,NR,mem_manager_type,layout_type> general_matrix;
+
+
+ type operator() (
+ const column_vector& x
+ ) const
+ {
+ type result = 0;
+ for (long i = 0; i < list.size(); ++i)
+ {
+ const type temp = f(list(i), x);
+ // save the residual for later
+ r(i) = temp;
+ result += temp*temp;
+ }
+
+ last_f = 0.5*result;
+ return 0.5*result;
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ J.set_size(list.size(), x.size());
+
+ // compute the Jacobian
+ for (long i = 0; i < list.size(); ++i)
+ {
+ set_rowm(J,i) = trans(der(list(i), x));
+ }
+
+ // Compute the Levenberg-Marquardt gradient and hessian
+ d = trans(J)*r;
+ h = trans(J)*J;
+
+ if (S.size() == 0)
+ {
+ S.set_size(x.size(), x.size());
+ S = 0;
+ }
+
+ // If this isn't the first iteration then check if using
+ // a quasi-newton update helps things.
+ if (last_r.size() != 0)
+ {
+
+ s = x - last_x;
+ y = d - last_d;
+ yy = d - trans(last_J)*r;
+
+ const type ys = trans(y)*s;
+ vtemp = yy - S*s;
+ const type temp2 = std::abs(trans(s)*S*s);
+ type scale = (temp2 != 0) ? std::min<type>(1, std::abs(dot(s,yy))/temp2) : 1;
+
+ if (ys != 0)
+ S = scale*S + (vtemp*trans(y) + y*trans(vtemp))/(ys) - dot(vtemp,s)/ys/ys*y*trans(y);
+ else
+ S *= scale;
+
+ // check how well both the models fit the last change we saw in f()
+ const type measured_delta = last_f2 - last_f;
+ s = -s;
+ const type h_predicted_delta = 0.5*trans(s)*h*s + trans(d)*s;
+ const type s_predicted_delta = 0.5*trans(s)*(h+S)*s + trans(d)*s;
+
+ const type h_error = std::abs((h_predicted_delta/measured_delta) - 1);
+ const type s_error = std::abs((s_predicted_delta/measured_delta) - 1);
+
+ if (s_error < h_error && h_error > 0.01)
+ {
+ h += make_symmetric(S);
+ }
+ else if (s_error > 10)
+ {
+ S = 0;
+ }
+
+ // put r into last_r
+ r.swap(last_r);
+ }
+ else
+ {
+ // put r into last_r
+ last_r = r;
+ }
+
+ J.swap(last_J);
+ last_x = x;
+ last_d = d;
+
+ last_f2 = last_f;
+ }
+
+ mutable type last_f; // value of function we saw in last operator()
+ mutable type last_f2; // value of last_f we saw in get_derivative_and_hessian()
+ mutable matrix<type,0,1,mem_manager_type,layout_type> r;
+ mutable column_vector vtemp;
+ mutable column_vector s, y, yy;
+
+ mutable general_matrix S;
+ mutable column_vector last_x;
+ mutable column_vector last_d;
+ mutable matrix<type,0,1,mem_manager_type,layout_type> last_r;
+ mutable matrix<type,0,NR,mem_manager_type,layout_type> last_J;
+ mutable matrix<type,0,NR,mem_manager_type,layout_type> J;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename column_vector_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type
+ >
+ least_squares_function_model<column_vector_type,funct_type,funct_der_type,vector_type> least_squares_model (
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list
+ )
+ {
+ return least_squares_function_model<column_vector_type,funct_type,funct_der_type,vector_type>(f,der,list);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type,
+ typename T
+ >
+ double solve_least_squares (
+ stop_strategy_type stop_strategy,
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list,
+ T& x,
+ double radius = 1
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(mat(list)) && list.size() > 0 &&
+ is_col_vector(x) && radius > 0,
+ "\t double solve_least_squares()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_vector(list): " << is_vector(mat(list))
+ << "\n\t list.size(): " << list.size()
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t radius: " << radius
+ );
+
+ return find_min_trust_region(stop_strategy,
+ least_squares_model<T>(f, der, mat(list)),
+ x,
+ radius);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename column_vector_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type
+ >
+ class least_squares_lm_function_model
+ {
+ public:
+ least_squares_lm_function_model (
+ const funct_type& f_,
+ const funct_der_type& der_,
+ const vector_type& list_
+ ) : f(f_), der(der_), list(list_)
+ {
+ r.set_size(list.size(),1);
+ }
+
+ const funct_type& f;
+ const funct_der_type& der;
+ const vector_type& list;
+
+ typedef typename column_vector_type::type type;
+ typedef typename column_vector_type::mem_manager_type mem_manager_type;
+ typedef typename column_vector_type::layout_type layout_type;
+ const static long NR = column_vector_type::NR;
+
+ typedef column_vector_type column_vector;
+ typedef matrix<type,NR,NR,mem_manager_type,layout_type> general_matrix;
+
+ mutable matrix<type,0,1,mem_manager_type,layout_type> r;
+ mutable column_vector vtemp;
+
+ type operator() (
+ const column_vector& x
+ ) const
+ {
+ type result = 0;
+ for (long i = 0; i < list.size(); ++i)
+ {
+ const type temp = f(list(i), x);
+ // save the residual for later
+ r(i) = temp;
+ result += temp*temp;
+ }
+
+ return 0.5*result;
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ d = 0;
+ h = 0;
+ for (long i = 0; i < list.size(); ++i)
+ {
+ vtemp = der(list(i), x);
+ d += r(i)*vtemp;
+ h += vtemp*trans(vtemp);
+ }
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename column_vector_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type
+ >
+ least_squares_lm_function_model<column_vector_type,funct_type,funct_der_type,vector_type> least_squares_lm_model (
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list
+ )
+ {
+ return least_squares_lm_function_model<column_vector_type,funct_type,funct_der_type,vector_type>(f,der,list);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type,
+ typename T
+ >
+ double solve_least_squares_lm (
+ stop_strategy_type stop_strategy,
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list,
+ T& x,
+ double radius = 1
+ )
+ {
+ // The starting point (i.e. x) must be a column vector.
+ COMPILE_TIME_ASSERT(T::NC <= 1);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(mat(list)) && list.size() > 0 &&
+ is_col_vector(x) && radius > 0,
+ "\t double solve_least_squares_lm()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_vector(list): " << is_vector(mat(list))
+ << "\n\t list.size(): " << list.size()
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t radius: " << radius
+ );
+
+ return find_min_trust_region(stop_strategy,
+ least_squares_lm_model<T>(f, der, mat(list)),
+ x,
+ radius);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h b/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h
new file mode 100644
index 000000000..aaffb2221
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h
@@ -0,0 +1,112 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_
+#ifdef DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_
+
+#include "../matrix/matrix_abstract.h"
+#include "optimization_trust_region_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type,
+ typename T
+ >
+ double solve_least_squares (
+ stop_strategy_type stop_strategy,
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list,
+ T& x,
+ double radius = 1
+ );
+ /*!
+ requires
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - list == a matrix or something convertible to a matrix via mat()
+ such as a std::vector.
+ - is_vector(list) == true
+ - list.size() > 0
+ - is_col_vector(x) == true
+ - radius > 0
+ - for all valid i:
+ - f(list(i),x) must be a valid expression that evaluates to a floating point value.
+ - der(list(i),x) must be a valid expression that evaluates to the derivative of f(list(i),x)
+ with respect to x. This derivative must take the form of a column vector.
+ ensures
+ - This function performs an unconstrained minimization of the least squares
+ function g(x) defined by:
+ - g(x) = sum over all i: 0.5*pow( f(list(i),x), 2 )
+ - This method combines the Levenberg-Marquardt method with a quasi-newton method
+ for approximating the second order terms of the hessian and is appropriate for
+ large residual problems (i.e. problems where the f() function isn't driven to 0).
+ In particular, it uses the method of Dennis, Gay, and Welsch as described in
+ Numerical Optimization by Nocedal and Wright (second edition).
+ - Since this is a trust region algorithm, the radius parameter defines the initial
+ size of the trust region.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or the trust region subproblem fails to make progress.
+ - #x == the value of x that was found to minimize g()
+ - returns g(#x).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_type,
+ typename funct_der_type,
+ typename vector_type,
+ typename T
+ >
+ double solve_least_squares_lm (
+ stop_strategy_type stop_strategy,
+ const funct_type& f,
+ const funct_der_type& der,
+ const vector_type& list,
+ T& x,
+ double radius = 1
+ );
+ /*!
+ requires
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - list == a matrix or something convertible to a matrix via mat()
+ such as a std::vector.
+ - is_vector(list) == true
+ - list.size() > 0
+ - is_col_vector(x) == true
+ - radius > 0
+ - for all valid i:
+ - f(list(i),x) must be a valid expression that evaluates to a floating point value.
+ - der(list(i),x) must be a valid expression that evaluates to the derivative of f(list(i),x)
+ with respect to x. This derivative must take the form of a column vector.
+ ensures
+ - This function performs an unconstrained minimization of the least squares
+ function g(x) defined by:
+ - g(x) = sum over all i: 0.5*pow( f(list(i),x), 2 )
+ - This method implements a plain Levenberg-Marquardt approach for approximating
+ the hessian of g(). Therefore, it is most appropriate for small residual problems
+ (i.e. problems where f() goes to 0 at the solution).
+ - Since this is a trust region algorithm, the radius parameter defines the initial
+ size of the trust region.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or the trust region subproblem fails to make progress.
+ - #x == the value of x that was found to minimize g()
+ - returns g(#x).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_line_search.h b/ml/dlib/dlib/optimization/optimization_line_search.h
new file mode 100644
index 000000000..a91e3df84
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_line_search.h
@@ -0,0 +1,888 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_LINE_SEARCH_H_
+#define DLIB_OPTIMIZATIOn_LINE_SEARCH_H_
+
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "optimization_line_search_abstract.h"
+#include <utility>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct, typename T>
+ class line_search_funct
+ {
+ public:
+ line_search_funct(const funct& f_, const T& start_, const T& direction_)
+ : f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0)
+ {}
+
+ line_search_funct(const funct& f_, const T& start_, const T& direction_, T& r)
+ : f(f_),start(start_), direction(direction_), matrix_r(&r), scalar_r(0)
+ {}
+
+ line_search_funct(const funct& f_, const T& start_, const T& direction_, double& r)
+ : f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(&r)
+ {}
+
+ double operator()(const double& x) const
+ {
+ return get_value(f(start + x*direction));
+ }
+
+ private:
+
+ double get_value (const double& r) const
+ {
+ // save a copy of this value for later
+ if (scalar_r)
+ *scalar_r = r;
+
+ return r;
+ }
+
+ template <typename U>
+ double get_value (const U& r) const
+ {
+ // U should be a matrix type
+ COMPILE_TIME_ASSERT(is_matrix<U>::value);
+
+ // save a copy of this value for later
+ if (matrix_r)
+ *matrix_r = r;
+
+ return dot(r,direction);
+ }
+
+ const funct& f;
+ const T& start;
+ const T& direction;
+ T* matrix_r;
+ double* scalar_r;
+ };
+
+ template <typename funct, typename T>
+ const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction)
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ DLIB_ASSERT (
+ is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
+ "\tline_search_funct make_line_search_function(f,start,direction)"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tstart.nc(): " << start.nc()
+ << "\n\tdirection.nc(): " << direction.nc()
+ << "\n\tstart.nr(): " << start.nr()
+ << "\n\tdirection.nr(): " << direction.nr()
+ );
+ return line_search_funct<funct,T>(f,start,direction);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct, typename T>
+ const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out)
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ DLIB_ASSERT (
+ is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
+ "\tline_search_funct make_line_search_function(f,start,direction)"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tstart.nc(): " << start.nc()
+ << "\n\tdirection.nc(): " << direction.nc()
+ << "\n\tstart.nr(): " << start.nr()
+ << "\n\tdirection.nr(): " << direction.nr()
+ );
+ return line_search_funct<funct,T>(f,start,direction, f_out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct, typename T>
+ const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out)
+ {
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ DLIB_ASSERT (
+ is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
+ "\tline_search_funct make_line_search_function(f,start,direction)"
+ << "\n\tYou have to supply column vectors to this function"
+ << "\n\tstart.nc(): " << start.nc()
+ << "\n\tdirection.nc(): " << direction.nc()
+ << "\n\tstart.nr(): " << start.nr()
+ << "\n\tdirection.nr(): " << direction.nr()
+ );
+ return line_search_funct<funct,T>(f,start,direction,grad_out);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double f1,
+ double d1,
+ double limit = 1
+ )
+ {
+ const double n = 3*(f1 - f0) - 2*d0 - d1;
+ const double e = d0 + d1 - 2*(f1 - f0);
+
+
+ // find the minimum of the derivative of the polynomial
+
+ double temp = std::max(n*n - 3*e*d0,0.0);
+
+ if (temp < 0)
+ return 0.5;
+
+ temp = std::sqrt(temp);
+
+ if (std::abs(e) <= std::numeric_limits<double>::epsilon())
+ return 0.5;
+
+ // figure out the two possible min values
+ double x1 = (temp - n)/(3*e);
+ double x2 = -(temp + n)/(3*e);
+
+ // compute the value of the interpolating polynomial at these two points
+ double y1 = f0 + d0*x1 + n*x1*x1 + e*x1*x1*x1;
+ double y2 = f0 + d0*x2 + n*x2*x2 + e*x2*x2*x2;
+
+ // pick the best point
+ double x;
+ if (y1 < y2)
+ x = x1;
+ else
+ x = x2;
+
+ // now make sure the minimum is within the allowed range of [0,limit]
+ return put_in_range(0,limit,x);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double f1
+ )
+ {
+ const double temp = 2*(f1 - f0 - d0);
+ if (std::abs(temp) <= d0*std::numeric_limits<double>::epsilon())
+ return 0.5;
+
+ const double alpha = -d0/temp;
+
+ // now make sure the minimum is within the allowed range of (0,1)
+ return put_in_range(0,1,alpha);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double x1,
+ double f_x1,
+ double x2,
+ double f_x2
+ )
+ {
+ DLIB_ASSERT(0 < x1 && x1 < x2,"Invalid inputs were given to this function");
+ // The contents of this function follow the equations described on page 58 of the
+ // book Numerical Optimization by Nocedal and Wright, second edition.
+ matrix<double,2,2> m;
+ matrix<double,2,1> v;
+
+ const double aa2 = x2*x2;
+ const double aa1 = x1*x1;
+ m = aa2, -aa1,
+ -aa2*x2, aa1*x1;
+ v = f_x1 - f0 - d0*x1,
+ f_x2 - f0 - d0*x2;
+
+
+ double temp = aa2*aa1*(x1-x2);
+
+ // just take a guess if this happens
+ if (temp == 0 || std::fpclassify(temp) == FP_SUBNORMAL)
+ {
+ return x1/2.0;
+ }
+
+ matrix<double,2,1> temp2;
+ temp2 = m*v/temp;
+ const double a = temp2(0);
+ const double b = temp2(1);
+
+ temp = b*b - 3*a*d0;
+ if (temp < 0 || a == 0)
+ {
+ // This is probably a line so just pick the lowest point
+ if (f0 < f_x2)
+ return 0;
+ else
+ return x2;
+ }
+ temp = (-b + std::sqrt(temp))/(3*a);
+ return put_in_range(0, x2, temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double lagrange_poly_min_extrap (
+ double p1,
+ double p2,
+ double p3,
+ double f1,
+ double f2,
+ double f3
+ )
+ {
+ DLIB_ASSERT(p1 < p2 && p2 < p3 && f1 >= f2 && f2 <= f3,
+ " p1: " << p1
+ << " p2: " << p2
+ << " p3: " << p3
+ << " f1: " << f1
+ << " f2: " << f2
+ << " f3: " << f3);
+
+ // This formula is out of the book Nonlinear Optimization by Andrzej Ruszczynski. See section 5.2.
+ double temp1 = f1*(p3*p3 - p2*p2) + f2*(p1*p1 - p3*p3) + f3*(p2*p2 - p1*p1);
+ double temp2 = 2*(f1*(p3 - p2) + f2*(p1 - p3) + f3*(p2 - p1) );
+
+ if (temp2 == 0)
+ {
+ return p2;
+ }
+
+ const double result = temp1/temp2;
+
+ // do a final sanity check to make sure the result is in the right range
+ if (p1 <= result && result <= p3)
+ {
+ return result;
+ }
+ else
+ {
+ return std::min(std::max(p1,result),p3);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename funct_der
+ >
+ double line_search (
+ const funct& f,
+ const double f0,
+ const funct_der& der,
+ const double d0,
+ double rho,
+ double sigma,
+ double min_f,
+ unsigned long max_iter
+ )
+ {
+ DLIB_ASSERT (
+ 0 < rho && rho < sigma && sigma < 1 && max_iter > 0,
+ "\tdouble line_search()"
+ << "\n\tYou have given invalid arguments to this function"
+ << "\n\t sigma: " << sigma
+ << "\n\t rho: " << rho
+ << "\n\t max_iter: " << max_iter
+ );
+
+ // The bracketing phase of this function is implemented according to block 2.6.2 from
+ // the book Practical Methods of Optimization by R. Fletcher. The sectioning
+ // phase is an implementation of 2.6.4 from the same book.
+
+ // 1 <= tau1a < tau1b. Controls the alpha jump size during the bracketing phase of
+ // the search.
+ const double tau1a = 1.4;
+ const double tau1b = 9;
+
+ // it must be the case that 0 < tau2 < tau3 <= 1/2 for the algorithm to function
+ // correctly but the specific values of tau2 and tau3 aren't super important.
+ const double tau2 = 1.0/10.0;
+ const double tau3 = 1.0/2.0;
+
+
+ // Stop right away and return a step size of 0 if the gradient is 0 at the starting point
+ if (std::abs(d0) <= std::abs(f0)*std::numeric_limits<double>::epsilon())
+ return 0;
+
+ // Stop right away if the current value is good enough according to min_f
+ if (f0 <= min_f)
+ return 0;
+
+ // Figure out a reasonable upper bound on how large alpha can get.
+ const double mu = (min_f-f0)/(rho*d0);
+
+
+ double alpha = 1;
+ if (mu < 0)
+ alpha = -alpha;
+ alpha = put_in_range(0, 0.65*mu, alpha);
+
+
+ double last_alpha = 0;
+ double last_val = f0;
+ double last_val_der = d0;
+
+ // The bracketing stage will find a range of points [a,b]
+ // that contains a reasonable solution to the line search
+ double a, b;
+
+ // These variables will hold the values and derivatives of f(a) and f(b)
+ double a_val, b_val, a_val_der, b_val_der;
+
+ // This thresh value represents the Wolfe curvature condition
+ const double thresh = std::abs(sigma*d0);
+
+ unsigned long itr = 0;
+ // do the bracketing stage to find the bracket range [a,b]
+ while (true)
+ {
+ ++itr;
+ const double val = f(alpha);
+ const double val_der = der(alpha);
+
+ // we are done with the line search since we found a value smaller
+ // than the minimum f value
+ if (val <= min_f)
+ return alpha;
+
+ if (val > f0 + rho*alpha*d0 || val >= last_val)
+ {
+ a_val = last_val;
+ a_val_der = last_val_der;
+ b_val = val;
+ b_val_der = val_der;
+
+ a = last_alpha;
+ b = alpha;
+ break;
+ }
+
+ if (std::abs(val_der) <= thresh)
+ return alpha;
+
+ // if we are stuck not making progress then quit with the current alpha
+ if (last_alpha == alpha || itr >= max_iter)
+ return alpha;
+
+ if (val_der >= 0)
+ {
+ a_val = val;
+ a_val_der = val_der;
+ b_val = last_val;
+ b_val_der = last_val_der;
+
+ a = alpha;
+ b = last_alpha;
+ break;
+ }
+
+
+
+ const double temp = alpha;
+ // Pick a larger range [first, last]. We will pick the next alpha in that
+ // range.
+ double first, last;
+ if (mu > 0)
+ {
+ first = std::min(mu, alpha + tau1a*(alpha - last_alpha));
+ last = std::min(mu, alpha + tau1b*(alpha - last_alpha));
+ }
+ else
+ {
+ first = std::max(mu, alpha + tau1a*(alpha - last_alpha));
+ last = std::max(mu, alpha + tau1b*(alpha - last_alpha));
+ }
+
+
+
+ // pick a point between first and last by doing some kind of interpolation
+ if (last_alpha < alpha)
+ alpha = last_alpha + (alpha-last_alpha)*poly_min_extrap(last_val, last_val_der, val, val_der, 1e10);
+ else
+ alpha = alpha + (last_alpha-alpha)*poly_min_extrap(val, val_der, last_val, last_val_der, 1e10);
+
+ alpha = put_in_range(first,last,alpha);
+
+ last_alpha = temp;
+
+ last_val = val;
+ last_val_der = val_der;
+
+ }
+
+
+ // Now do the sectioning phase from 2.6.4
+ while (true)
+ {
+ ++itr;
+ double first = a + tau2*(b-a);
+ double last = b - tau3*(b-a);
+
+ // use interpolation to pick alpha between first and last
+ alpha = a + (b-a)*poly_min_extrap(a_val, a_val_der, b_val, b_val_der);
+ alpha = put_in_range(first,last,alpha);
+
+ const double val = f(alpha);
+ const double val_der = der(alpha);
+
+ // we are done with the line search since we found a value smaller
+ // than the minimum f value or we ran out of iterations.
+ if (val <= min_f || itr >= max_iter)
+ return alpha;
+
+ // stop if the interval gets so small that it isn't shrinking any more due to rounding error
+ if (a == first || b == last)
+ {
+ return b;
+ }
+
+ // If alpha has basically become zero then just stop. Think of it like this,
+ // if we take the largest possible alpha step will the objective function
+ // change at all? If not then there isn't any point looking for a better
+ // alpha.
+ const double max_possible_alpha = std::max(std::abs(a),std::abs(b));
+ if (std::abs(max_possible_alpha*d0) <= std::abs(f0)*std::numeric_limits<double>::epsilon())
+ return alpha;
+
+
+ if (val > f0 + rho*alpha*d0 || val >= a_val)
+ {
+ b = alpha;
+ b_val = val;
+ b_val_der = val_der;
+ }
+ else
+ {
+ if (std::abs(val_der) <= thresh)
+ return alpha;
+
+ if ( (b-a)*val_der >= 0)
+ {
+ b = a;
+ b_val = a_val;
+ b_val_der = a_val_der;
+ }
+
+ a = alpha;
+ a_val = val;
+ a_val_der = val_der;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double backtracking_line_search (
+ const funct& f,
+ double f0,
+ double d0,
+ double alpha,
+ double rho,
+ unsigned long max_iter
+ )
+ {
+ DLIB_ASSERT (
+ 0 < rho && rho < 1 && max_iter > 0,
+ "\tdouble backtracking_line_search()"
+ << "\n\tYou have given invalid arguments to this function"
+ << "\n\t rho: " << rho
+ << "\n\t max_iter: " << max_iter
+ );
+
+ // make sure alpha is going in the right direction. That is, it should be opposite
+ // the direction of the gradient.
+ if ((d0 > 0 && alpha > 0) ||
+ (d0 < 0 && alpha < 0))
+ {
+ alpha *= -1;
+ }
+
+ bool have_prev_alpha = false;
+ double prev_alpha = 0;
+ double prev_val = 0;
+ unsigned long iter = 0;
+ while (true)
+ {
+ ++iter;
+ const double val = f(alpha);
+ if (val <= f0 + alpha*rho*d0 || iter >= max_iter)
+ {
+ return alpha;
+ }
+ else
+ {
+ // Interpolate a new alpha. We also make sure the step by which we
+ // reduce alpha is not super small.
+ double step;
+ if (!have_prev_alpha)
+ {
+ if (d0 < 0)
+ step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, d0, val));
+ else
+ step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, -d0, val));
+ have_prev_alpha = true;
+ }
+ else
+ {
+ if (d0 < 0)
+ step = put_in_range(0.1*alpha,0.9*alpha, poly_min_extrap(f0, d0, alpha, val, prev_alpha, prev_val));
+ else
+ step = put_in_range(0.1*alpha,0.9*alpha, -poly_min_extrap(f0, -d0, -alpha, val, -prev_alpha, prev_val));
+ }
+
+ prev_alpha = alpha;
+ prev_val = val;
+
+ alpha = step;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class optimize_single_variable_failure : public error {
+ public: optimize_single_variable_failure(const std::string& s):error(s){}
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ double find_min_single_variable (
+ const funct& f,
+ double& starting_point,
+ const double begin = -1e200,
+ const double end = 1e200,
+ const double eps = 1e-3,
+ const long max_iter = 100,
+ const double initial_search_radius = 1
+ )
+ {
+ DLIB_CASSERT( eps > 0 &&
+ max_iter > 1 &&
+ begin <= starting_point && starting_point <= end &&
+ initial_search_radius > 0,
+ "eps: " << eps
+ << "\n max_iter: "<< max_iter
+ << "\n begin: "<< begin
+ << "\n end: "<< end
+ << "\n starting_point: "<< starting_point
+ << "\n initial_search_radius: "<< initial_search_radius
+ );
+
+ double search_radius = initial_search_radius;
+
+ double p1=0, p2=0, p3=0, f1=0, f2=0, f3=0;
+ long f_evals = 1;
+
+ if (begin == end)
+ {
+ return f(starting_point);
+ }
+
+ using std::abs;
+ using std::min;
+ using std::max;
+
+ // find three bracketing points such that f1 > f2 < f3. Do this by generating a sequence
+ // of points expanding away from 0. Also note that, in the following code, it is always the
+ // case that p1 < p2 < p3.
+
+
+
+ // The first thing we do is get a starting set of 3 points that are inside the [begin,end] bounds
+ p1 = max(starting_point-search_radius, begin);
+ p3 = min(starting_point+search_radius, end);
+ f1 = f(p1);
+ f3 = f(p3);
+
+ if (starting_point == p1 || starting_point == p3)
+ {
+ p2 = (p1+p3)/2;
+ f2 = f(p2);
+ }
+ else
+ {
+ p2 = starting_point;
+ f2 = f(starting_point);
+ }
+
+ f_evals += 2;
+
+ // Now we have 3 points on the function. Start looking for a bracketing set such that
+ // f1 > f2 < f3 is the case.
+ while ( !(f1 > f2 && f2 < f3))
+ {
+ // check for hitting max_iter or if the interval is now too small
+ if (f_evals >= max_iter)
+ {
+ throw optimize_single_variable_failure(
+ "The max number of iterations of single variable optimization have been reached\n"
+ "without converging.");
+ }
+ if (p3-p1 < eps)
+ {
+ if (f1 < min(f2,f3))
+ {
+ starting_point = p1;
+ return f1;
+ }
+
+ if (f2 < min(f1,f3))
+ {
+ starting_point = p2;
+ return f2;
+ }
+
+ starting_point = p3;
+ return f3;
+ }
+
+ // If the left most points are identical in function value then expand out the
+ // left a bit, unless it's already at bound or we would drop that left most
+ // point anyway because it's bad.
+ if (f1==f2 && f1<f3 && p1!=begin)
+ {
+ p1 = max(p1 - search_radius, begin);
+ f1 = f(p1);
+ ++f_evals;
+ search_radius *= 2;
+ continue;
+ }
+ if (f2==f3 && f3<f1 && p3!=end)
+ {
+ p3 = min(p3 + search_radius, end);
+ f3 = f(p3);
+ ++f_evals;
+ search_radius *= 2;
+ continue;
+ }
+
+
+ // if f1 is small then take a step to the left
+ if (f1 <= f3)
+ {
+ // check if the minimum is butting up against the bounds and if so then pick
+ // a point between p1 and p2 in the hopes that shrinking the interval will
+ // be a good thing to do. Or if p1 and p2 aren't differentiated then try and
+ // get them to obtain different values.
+ if (p1 == begin || (f1 == f2 && (end-begin) < search_radius ))
+ {
+ p3 = p2;
+ f3 = f2;
+
+ p2 = (p1+p2)/2.0;
+ f2 = f(p2);
+ }
+ else
+ {
+ // pick a new point to the left of our current bracket
+ p3 = p2;
+ f3 = f2;
+
+ p2 = p1;
+ f2 = f1;
+
+ p1 = max(p1 - search_radius, begin);
+ f1 = f(p1);
+
+ search_radius *= 2;
+ }
+
+ }
+ // otherwise f3 is small and we should take a step to the right
+ else
+ {
+ // check if the minimum is butting up against the bounds and if so then pick
+ // a point between p2 and p3 in the hopes that shrinking the interval will
+ // be a good thing to do. Or if p2 and p3 aren't differentiated then try and
+ // get them to obtain different values.
+ if (p3 == end || (f2 == f3 && (end-begin) < search_radius))
+ {
+ p1 = p2;
+ f1 = f2;
+
+ p2 = (p3+p2)/2.0;
+ f2 = f(p2);
+ }
+ else
+ {
+ // pick a new point to the right of our current bracket
+ p1 = p2;
+ f1 = f2;
+
+ p2 = p3;
+ f2 = f3;
+
+ p3 = min(p3 + search_radius, end);
+ f3 = f(p3);
+
+ search_radius *= 2;
+ }
+ }
+
+ ++f_evals;
+ }
+
+
+ // Loop until we have done the max allowable number of iterations or
+ // the bracketing window is smaller than eps.
+ // Within this loop we maintain the invariant that: f1 > f2 < f3 and p1 < p2 < p3
+ const double tau = 0.1;
+ while( f_evals < max_iter && p3-p1 > eps)
+ {
+ double p_min = lagrange_poly_min_extrap(p1,p2,p3, f1,f2,f3);
+
+
+ // make sure p_min isn't too close to the three points we already have
+ if (p_min < p2)
+ {
+ const double min_dist = (p2-p1)*tau;
+ if (abs(p1-p_min) < min_dist)
+ {
+ p_min = p1 + min_dist;
+ }
+ else if (abs(p2-p_min) < min_dist)
+ {
+ p_min = p2 - min_dist;
+ }
+ }
+ else
+ {
+ const double min_dist = (p3-p2)*tau;
+ if (abs(p2-p_min) < min_dist)
+ {
+ p_min = p2 + min_dist;
+ }
+ else if (abs(p3-p_min) < min_dist)
+ {
+ p_min = p3 - min_dist;
+ }
+ }
+
+ // make sure one side of the bracket isn't super huge compared to the other
+ // side. If it is then contract it.
+ const double bracket_ratio = abs(p1-p2)/abs(p2-p3);
+ if ( !( bracket_ratio < 10 && bracket_ratio > 0.1) )
+ {
+ // Force p_min to be on a reasonable side. But only if lagrange_poly_min_extrap()
+ // didn't put it on a good side already.
+ if (bracket_ratio > 1 && p_min > p2)
+ p_min = (p1+p2)/2;
+ else if (p_min < p2)
+ p_min = (p2+p3)/2;
+ }
+
+
+ const double f_min = f(p_min);
+
+
+ // Remove one of the endpoints of our bracket depending on where the new point falls.
+ if (p_min < p2)
+ {
+ if (f1 > f_min && f_min < f2)
+ {
+ p3 = p2;
+ f3 = f2;
+ p2 = p_min;
+ f2 = f_min;
+ }
+ else
+ {
+ p1 = p_min;
+ f1 = f_min;
+ }
+ }
+ else
+ {
+ if (f2 > f_min && f_min < f3)
+ {
+ p1 = p2;
+ f1 = f2;
+ p2 = p_min;
+ f2 = f_min;
+ }
+ else
+ {
+ p3 = p_min;
+ f3 = f_min;
+ }
+ }
+
+
+ ++f_evals;
+ }
+
+ if (f_evals >= max_iter)
+ {
+ throw optimize_single_variable_failure(
+ "The max number of iterations of single variable optimization have been reached\n"
+ "without converging.");
+ }
+
+ starting_point = p2;
+ return f2;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ class negate_function_object
+ {
+ public:
+ negate_function_object(const funct& f_) : f(f_){}
+
+ template <typename T>
+ double operator()(const T& x) const
+ {
+ return -f(x);
+ }
+
+ private:
+ const funct& f;
+ };
+
+ template <typename funct>
+ const negate_function_object<funct> negate_function(const funct& f) { return negate_function_object<funct>(f); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ double find_max_single_variable (
+ const funct& f,
+ double& starting_point,
+ const double begin = -1e200,
+ const double end = 1e200,
+ const double eps = 1e-3,
+ const long max_iter = 100,
+ const double initial_search_radius = 1
+ )
+ {
+ return -find_min_single_variable(negate_function(f), starting_point, begin, end, eps, max_iter, initial_search_radius);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_LINE_SEARCH_H_
+
diff --git a/ml/dlib/dlib/optimization/optimization_line_search_abstract.h b/ml/dlib/dlib/optimization/optimization_line_search_abstract.h
new file mode 100644
index 000000000..2aa221da4
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_line_search_abstract.h
@@ -0,0 +1,361 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_ABSTRACT_
+#ifdef DLIB_OPTIMIZATIOn_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename T
+ >
+ class line_search_funct;
+ /*!
+ This object is a function object that represents a line search function.
+
+ Moreover, it represents a function with the signature:
+ double l(double x)
+ !*/
+
+ template <
+ typename funct,
+ typename T
+ >
+ const line_search_funct<funct,T> make_line_search_function (
+ const funct& f,
+ const T& start,
+ const T& direction
+ );
+ /*!
+ requires
+ - is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size()
+ (i.e. start and direction should be column vectors of the same size)
+ - f must return either a double or a column vector the same length as start
+ - f(start + 1.5*direction) should be a valid expression
+ ensures
+ - if (f returns a double) then
+ - returns a line search function that computes l(x) == f(start + x*direction)
+ - else
+ - returns a line search function that computes l(x) == dot(f(start + x*direction),direction).
+ That is, we assume f is the derivative of some other function and that what
+ f returns is a gradient vector.
+ So the following two expressions both create the derivative of l(x):
+ - derivative(make_line_search_function(funct,start,direction))
+ - make_line_search_function(derivative(funct),start,direction)
+ !*/
+
+ template <
+ typename funct,
+ typename T
+ >
+ const line_search_funct<funct,T> make_line_search_function (
+ const funct& f,
+ const T& start,
+ const T& direction,
+ double& f_out
+ );
+ /*!
+ This function is identical to the above three argument version of make_line_search_function()
+ except that, if f() outputs a double, every time f() is evaluated its output is also stored
+ into f_out.
+ !*/
+
+ template <
+ typename funct,
+ typename T
+ >
+ const line_search_funct<funct,T> make_line_search_function (
+ const funct& f,
+ const T& start,
+ const T& direction,
+ T& gradient_out
+ );
+ /*!
+ This function is identical to the above three argument version of make_line_search_function()
+ except that, if f() outputs a column vector, every time f() is evaluated its output is also
+ stored into gradient_out.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double f1,
+ double d1,
+ double limit = 1
+ );
+ /*!
+ ensures
+ - let c(x) be a 3rd degree polynomial such that:
+ - c(0) == f0
+ - c(1) == f1
+ - derivative of c(x) at x==0 is d0
+ - derivative of c(x) at x==1 is d1
+ - returns the point in the range [0,limit] that minimizes the polynomial c(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double f1
+ );
+ /*!
+ ensures
+ - let c(x) be a 2nd degree polynomial such that:
+ - c(0) == f0
+ - c(1) == f1
+ - derivative of c(x) at x==0 is d0
+ - returns the point in the range [0,1] that minimizes the polynomial c(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double poly_min_extrap (
+ double f0,
+ double d0,
+ double x1,
+ double f_x1,
+ double x2,
+ double f_x2
+ )
+ /*!
+ requires
+ - 0 < x1 < x2
+ ensures
+ - let f(x) be a 3rd degree polynomial such that:
+ - f(0) == f0
+ - derivative of f(x) at x==0 is d0
+ - f(x1) == f_x1
+ - f(x2) == f_x2
+ - returns the point in the range [0,x2] that minimizes the polynomial f(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline double lagrange_poly_min_extrap (
+ double p1,
+ double p2,
+ double p3,
+ double f1,
+ double f2,
+ double f3
+ );
+ /*!
+ requires
+ - f1 >= f2 <= f3
+ - p1 < p2 < p3
+ ensures
+ - let c(x) be the second order Lagrange polynomial that interpolates the
+ points p1, p2, and p3 where c(p1)==f1, c(p2)==f2, and c(p3)==f3
+ - this function returns the point in the range [p1,p3] that minimizes
+ the polynomial c(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct,
+ typename funct_der
+ >
+ double line_search (
+ const funct& f,
+ const double f0,
+ const funct_der& der,
+ const double d0,
+ double rho,
+ double sigma,
+ double min_f,
+ unsigned long max_iter
+ )
+ /*!
+ requires
+ - 0 < rho < sigma < 1
+ - f and der are scalar functions of scalars
+ (e.g. line_search_funct objects)
+ - der is the derivative of f
+ - f0 == f(0)
+ - d0 == der(0)
+ - max_iter > 0
+ ensures
+ - Performs a line search and uses the strong Wolfe conditions to decide when
+ the search can stop.
+ - rho == the parameter of the Wolfe sufficient decrease condition
+ - sigma == the parameter of the Wolfe curvature condition
+ - max_iter == the maximum number of iterations allowable. After this
+ many evaluations of f() line_search() is guaranteed to terminate.
+ - returns a value alpha such that f(alpha) is significantly closer to
+ the minimum of f than f(0).
+ - It is assumed that the minimum possible value of f(x) is min_f. So if
+ an alpha is found such that f(alpha) <= min_f then the search stops
+ immediately.
+ - This function is also optimized for the case where der(0) is negative. I.e.
+ positive values of the argument to f() should be in a descent direction.
+ - When this function makes calls to f() and der() it always does so by
+ first calling f() and then calling der(). That is, these two functions
+ are always called in pairs with f() being called first and then der()
+ being called second.
+ !*/
+
+ /*
+ A good discussion of the Wolfe conditions and line search algorithms in
+ general can be found in the book Practical Methods of Optimization by R. Fletcher
+ and also in the more recent book Numerical Optimization by Nocedal and Wright.
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double backtracking_line_search (
+ const funct& f,
+ double f0,
+ double d0,
+ double alpha,
+ double rho,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - 0 < rho < 1
+ - f is a scalar function of scalars
+ (e.g. a line_search_funct object)
+ - f0 == f(0)
+ - d0 == the derivative of f() at f(0).
+ - max_iter > 0
+ ensures
+ - Performs a backtracking line search and uses the Armijo sufficient decrease
+ rule to decide when the search can stop.
+ - rho == the parameter of the sufficient decrease condition.
+ - max_iter == the maximum number of iterations allowable. After this many
+ evaluations of f() backtracking_line_search() is guaranteed to terminate.
+ - The line search starts with the input alpha value and then backtracks until
+ it finds a good enough alpha value. Once found, it returns the alpha value
+ such that f(alpha) is significantly closer to the minimum of f than f(0).
+ - The returned value of alpha will always be the last value of alpha which was
+ passed to f(). That is, it will always be the case that the last call to f()
+ made by backtracking_line_search() was f(alpha) where alpha is the return
+ value from this function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ const negate_function_object<funct> negate_function(
+ const funct& f
+ );
+ /*!
+ requires
+ - f == a function that returns a scalar
+ ensures
+ - returns a function that represents the negation of f. That is,
+ the returned function object represents g(x) == -f(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class optimize_single_variable_failure : public error;
+ /*!
+ This is the exception class used by the functions defined below.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double find_min_single_variable (
+ const funct& f,
+ double& starting_point,
+ const double begin = -1e200,
+ const double end = 1e200,
+ const double eps = 1e-3,
+ const long max_iter = 100,
+ const double initial_search_radius = 1
+ )
+ /*!
+ requires
+ - eps > 0
+ - max_iter > 1
+ - begin <= starting_point <= end
+ - f must be a function of a double that returns a double
+ (e.g. f(starting_point) should be a valid expression that evaluates to a double)
+ - initial_search_radius > 0
+ ensures
+ - Finds a point P such that:
+ - P is a local minimum of the function f().
+ - begin <= P <= end
+ - Evaluates f() no more than max_iter times
+ - Stops searching when the window around the minimum point is smaller than eps.
+ The search will begin with the given starting_point and expand out to the
+ left and right by initial_search_radius sized steps. So if you think the
+ minimum is likely to be found within X distance from the starting_point then
+ set initial_search_radius to X.
+ - #starting_point == P
+ - returns f(P)
+ throws
+ - optimize_single_variable_failure
+ This exception is thrown if max_iter iterations are performed without
+ determining the min point to the requested accuracy of eps.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename funct
+ >
+ double find_max_single_variable (
+ const funct& f,
+ double& starting_point,
+ const double begin = -1e200,
+ const double end = 1e200,
+ const double eps = 1e-3,
+ const long max_iter = 100,
+ const double initial_search_radius = 1
+ )
+ /*!
+ requires
+ - eps > 0
+ - max_iter > 1
+ - begin <= starting_point <= end
+ - f must be a function of a double that returns a double
+ (e.g. f(starting_point) should be a valid expression that evaluates to a double)
+ - initial_search_radius > 0
+ ensures
+ - Finds a point P such that:
+ - P is a local maximum of the function f().
+ - begin <= P <= end
+ - Evaluates f() no more than max_iter times
+ - Stops searching when the window around the maximum point is smaller than eps.
+ The search will begin with the given starting_point and expand out to the
+ left and right by initial_search_radius sized steps. So if you think the
+ maximum is likely to be found within X distance from the starting_point then
+ set initial_search_radius to X.
+ - #starting_point == P
+ - returns f(P)
+ throws
+ - optimize_single_variable_failure
+ This exception is thrown if max_iter iterations are performed without
+ determining the max point to the requested accuracy of eps.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_ABSTRACT_
+
diff --git a/ml/dlib/dlib/optimization/optimization_oca.h b/ml/dlib/dlib/optimization/optimization_oca.h
new file mode 100644
index 000000000..4ca9cd7ae
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_oca.h
@@ -0,0 +1,407 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIoN_OCA_Hh_
+#define DLIB_OPTIMIZATIoN_OCA_Hh_
+
+#include "optimization_oca_abstract.h"
+
+#include "../matrix.h"
+#include "optimization_solve_qp_using_smo.h"
+#include <vector>
+#include "../sequence.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ template <typename matrix_type>
+ class oca_problem
+ {
+ public:
+ typedef typename matrix_type::type scalar_type;
+
+ virtual ~oca_problem() {}
+
+ virtual bool risk_has_lower_bound (
+ scalar_type&
+ ) const { return false; }
+
+ virtual bool optimization_status (
+ scalar_type ,
+ scalar_type ,
+ scalar_type ,
+ scalar_type ,
+ unsigned long,
+ unsigned long
+ ) const = 0;
+
+ virtual scalar_type get_c (
+ ) const = 0;
+
+ virtual long get_num_dimensions (
+ ) const = 0;
+
+ virtual void get_risk (
+ matrix_type& current_solution,
+ scalar_type& risk_value,
+ matrix_type& risk_subgradient
+ ) const = 0;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class oca
+ {
+ public:
+
+ oca ()
+ {
+ sub_eps = 1e-2;
+ sub_max_iter = 50000;
+
+ inactive_thresh = 20;
+ }
+
+ void set_subproblem_epsilon (
+ double eps_
+ ) { sub_eps = eps_; }
+
+ double get_subproblem_epsilon (
+ ) const { return sub_eps; }
+
+ void set_subproblem_max_iterations (
+ unsigned long sub_max_iter_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(sub_max_iter_ > 0,
+ "\t void oca::set_subproblem_max_iterations"
+ << "\n\t max iterations must be greater than 0"
+ << "\n\t sub_max_iter_: " << sub_max_iter_
+ << "\n\t this: " << this
+ );
+
+ sub_max_iter = sub_max_iter_;
+ }
+
+ unsigned long get_subproblem_max_iterations (
+ ) const { return sub_max_iter; }
+
+ void set_inactive_plane_threshold (
+ unsigned long inactive_thresh_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(inactive_thresh_ > 0,
+ "\t void oca::set_inactive_plane_threshold"
+ << "\n\t inactive threshold must be greater than 0"
+ << "\n\t inactive_thresh_: " << inactive_thresh_
+ << "\n\t this: " << this
+ );
+
+ inactive_thresh = inactive_thresh_;
+ }
+
+ unsigned long get_inactive_plane_threshold (
+ ) const { return inactive_thresh; }
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type operator() (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ unsigned long num_nonnegative = 0,
+ unsigned long force_weight_to_1 = std::numeric_limits<unsigned long>::max()
+ ) const
+ {
+ matrix_type empty_prior;
+ return oca_impl(problem, w, empty_prior, false, num_nonnegative, force_weight_to_1, 0);
+ }
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type solve_with_elastic_net (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ double lasso_lambda,
+ unsigned long force_weight_to_1 = std::numeric_limits<unsigned long>::max()
+ ) const
+ {
+ matrix_type empty_prior;
+ return oca_impl(problem, w, empty_prior, false, 0, force_weight_to_1, lasso_lambda);
+ }
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type operator() (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ const matrix_type& prior
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(prior) && prior.size() == problem.get_num_dimensions(),
+ "\t scalar_type oca::operator()"
+ << "\n\t The prior vector does not have the correct dimensions."
+ << "\n\t is_col_vector(prior): " << is_col_vector(prior)
+ << "\n\t prior.size(): " << prior.size()
+ << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions()
+ << "\n\t this: " << this
+ );
+ // disable the force weight to 1 option for this mode. We also disable the
+ // non-negative constraints.
+ unsigned long force_weight_to_1 = std::numeric_limits<unsigned long>::max();
+ return oca_impl(problem, w, prior, true, 0, force_weight_to_1, 0);
+ }
+
+ private:
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type oca_impl (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ const matrix_type& prior,
+ bool have_prior,
+ unsigned long num_nonnegative,
+ unsigned long force_weight_to_1,
+ const double lasso_lambda
+ ) const
+ {
+ const unsigned long num_dims = problem.get_num_dimensions();
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(problem.get_c() > 0 &&
+ problem.get_num_dimensions() > 0 &&
+ 0 <= lasso_lambda && lasso_lambda < 1,
+ "\t scalar_type oca::operator()"
+ << "\n\t The oca_problem is invalid"
+ << "\n\t problem.get_c(): " << problem.get_c()
+ << "\n\t problem.get_num_dimensions(): " << num_dims
+ << "\n\t lasso_lambda: " << lasso_lambda
+ << "\n\t this: " << this
+ );
+ if (have_prior)
+ {
+ DLIB_ASSERT(lasso_lambda == 0, "Solver doesn't support using a prior with lasso.");
+ DLIB_ASSERT(num_nonnegative == 0, "Solver doesn't support using a prior with non-negative constraints.");
+ }
+ else if (lasso_lambda != 0)
+ {
+ DLIB_ASSERT(num_nonnegative == 0, "Solver doesn't support using lasso with non-negative constraints.");
+ }
+
+ const double ridge_lambda = 1-lasso_lambda;
+
+ if (num_nonnegative > num_dims)
+ num_nonnegative = num_dims;
+
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef matrix_type vect_type;
+
+ const scalar_type C = problem.get_c();
+
+ typename sequence<vect_type>::kernel_2a planes;
+ std::vector<scalar_type> bs, miss_count;
+
+ vect_type new_plane, alpha, btemp;
+
+ w.set_size(num_dims, 1);
+ w = 0;
+
+ // The current objective value. Note also that w always contains
+ // the current solution.
+ scalar_type cur_obj = std::numeric_limits<scalar_type>::max();
+
+ // This will hold the cutting plane objective value. This value is
+ // a lower bound on the true optimal objective value.
+ scalar_type cp_obj = 0;
+
+ matrix<scalar_type,0,0,mem_manager_type, layout_type> K, Ktmp;
+ matrix<scalar_type,0,1,mem_manager_type, layout_type> lambda, d;
+ if (lasso_lambda != 0)
+ d.set_size(num_dims);
+ else
+ d.set_size(num_nonnegative);
+ d = lasso_lambda*ones_matrix(d);
+
+ scalar_type R_lower_bound;
+ if (problem.risk_has_lower_bound(R_lower_bound))
+ {
+ // The flat lower bounding plane is always good to have if we know
+ // what it is.
+ bs.push_back(R_lower_bound);
+ new_plane = zeros_matrix(w);
+ planes.add(0, new_plane);
+ alpha = uniform_matrix<scalar_type>(1,1, C);
+ miss_count.push_back(0);
+
+ K.set_size(1,1);
+ K(0,0) = 0;
+ }
+
+ const double prior_norm = have_prior ? 0.5*dot(prior,prior) : 0;
+
+ unsigned long counter = 0;
+ while (true)
+ {
+
+ // add the next cutting plane
+ scalar_type cur_risk;
+ if (force_weight_to_1 < (unsigned long)w.size())
+ w(force_weight_to_1) = 1;
+
+ problem.get_risk(w, cur_risk, new_plane);
+
+ if (force_weight_to_1 < (unsigned long)w.size())
+ {
+ // We basically arrange for the w(force_weight_to_1) element and all
+ // subsequent elements of w to not be involved in the optimization at
+ // all. An easy way to do this is to just make sure the elements of w
+ // corresponding elements in the subgradient are always set to zero
+ // while we run the cutting plane algorithm. The only time
+ // w(force_weight_to_1) is 1 is when we pass it to the oca_problem.
+ set_rowm(w, range(force_weight_to_1, w.size()-1)) = 0;
+ set_rowm(new_plane, range(force_weight_to_1, new_plane.size()-1)) = 0;
+ }
+
+ if (have_prior)
+ bs.push_back(cur_risk - dot(w,new_plane) + dot(prior,new_plane));
+ else
+ bs.push_back(cur_risk - dot(w,new_plane));
+ planes.add(planes.size(), new_plane);
+ miss_count.push_back(0);
+
+ // If alpha is empty then initialize it (we must always have sum(alpha) == C).
+ // But otherwise, just append a zero.
+ if (alpha.size() == 0)
+ alpha = uniform_matrix<scalar_type>(1,1, C);
+ else
+ alpha = join_cols(alpha,zeros_matrix<scalar_type>(1,1));
+
+ const scalar_type wnorm = 0.5*ridge_lambda*trans(w)*w + lasso_lambda*sum(abs(w));
+ const double prior_part = have_prior? dot(w,prior) : 0;
+ cur_obj = wnorm + C*cur_risk + prior_norm-prior_part;
+
+ // report current status
+ const scalar_type risk_gap = cur_risk - (cp_obj-wnorm+prior_part-prior_norm)/C;
+ if (counter > 0 && problem.optimization_status(cur_obj, cur_obj - cp_obj,
+ cur_risk, risk_gap, planes.size(), counter))
+ {
+ break;
+ }
+
+ // compute kernel matrix for all the planes
+ K.swap(Ktmp);
+ K.set_size(planes.size(), planes.size());
+ // copy over the old K matrix
+ set_subm(K, 0,0, Ktmp.nr(), Ktmp.nc()) = Ktmp;
+
+ // now add the new row and column to K
+ for (unsigned long c = 0; c < planes.size(); ++c)
+ {
+ K(c, Ktmp.nc()) = dot(planes[c], planes[planes.size()-1]);
+ K(Ktmp.nc(), c) = K(c,Ktmp.nc());
+ }
+
+
+ // solve the cutting plane subproblem for the next w. We solve it to an
+ // accuracy that is related to how big the error gap is. Also, we multiply
+ // by ridge_lambda because the objective function for the QP we solve was
+ // implicitly scaled by ridge_lambda. That is, we want to ask the QP
+ // solver to solve the problem until the duality gap is 0.1 times smaller
+ // than what it is now. So the factor of ridge_lambda is necessary to make
+ // this happen.
+ scalar_type eps = std::min<scalar_type>(sub_eps, 0.1*ridge_lambda*(cur_obj-cp_obj));
+ // just a sanity check
+ if (eps < 1e-16)
+ eps = 1e-16;
+ // Note that we warm start this optimization by using the alpha from the last
+ // iteration as the starting point.
+ if (lasso_lambda != 0)
+ {
+ // copy planes into a matrix so we can call solve_qp4_using_smo()
+ matrix<scalar_type,0,0,mem_manager_type, layout_type> planes_mat(num_dims,planes.size());
+ for (unsigned long i = 0; i < planes.size(); ++i)
+ set_colm(planes_mat,i) = planes[i];
+
+ btemp = ridge_lambda*mat(bs) - trans(planes_mat)*d;
+ solve_qp4_using_smo(planes_mat, K, btemp, d, alpha, lambda, eps, sub_max_iter, (scalar_type)(2*lasso_lambda));
+ }
+ else if (num_nonnegative != 0)
+ {
+ // copy planes into a matrix so we can call solve_qp4_using_smo()
+ matrix<scalar_type,0,0,mem_manager_type, layout_type> planes_mat(num_nonnegative,planes.size());
+ for (unsigned long i = 0; i < planes.size(); ++i)
+ set_colm(planes_mat,i) = colm(planes[i],0,num_nonnegative);
+
+ solve_qp4_using_smo(planes_mat, K, mat(bs), d, alpha, lambda, eps, sub_max_iter);
+ }
+ else
+ {
+ solve_qp_using_smo(K, mat(bs), alpha, eps, sub_max_iter);
+ }
+
+ // construct the w that minimized the subproblem.
+ w = -alpha(0)*planes[0];
+ for (unsigned long i = 1; i < planes.size(); ++i)
+ w -= alpha(i)*planes[i];
+ if (lasso_lambda != 0)
+ w = (lambda-d+w)/ridge_lambda;
+ else if (num_nonnegative != 0) // threshold the first num_nonnegative w elements if necessary.
+ set_rowm(w,range(0,num_nonnegative-1)) = lowerbound(rowm(w,range(0,num_nonnegative-1)),0);
+
+ for (long i = 0; i < alpha.size(); ++i)
+ {
+ if (alpha(i) != 0)
+ miss_count[i] = 0;
+ else
+ miss_count[i] += 1;
+ }
+
+ // Compute the lower bound on the true objective given to us by the cutting
+ // plane subproblem.
+ cp_obj = -0.5*ridge_lambda*trans(w)*w + trans(alpha)*mat(bs);
+ if (have_prior)
+ w += prior;
+
+ // If it has been a while since a cutting plane was an active constraint then
+ // we should throw it away.
+ while (max(mat(miss_count)) >= inactive_thresh)
+ {
+ const long idx = index_of_max(mat(miss_count));
+ bs.erase(bs.begin()+idx);
+ miss_count.erase(miss_count.begin()+idx);
+ K = removerc(K, idx, idx);
+ alpha = remove_row(alpha,idx);
+ planes.remove(idx, new_plane);
+ }
+
+ ++counter;
+ }
+
+ if (force_weight_to_1 < (unsigned long)w.size())
+ w(force_weight_to_1) = 1;
+
+ return cur_obj;
+ }
+
+ double sub_eps;
+
+ unsigned long sub_max_iter;
+
+ unsigned long inactive_thresh;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_OPTIMIZATIoN_OCA_Hh_
+
diff --git a/ml/dlib/dlib/optimization/optimization_oca_abstract.h b/ml/dlib/dlib/optimization/optimization_oca_abstract.h
new file mode 100644
index 000000000..859dbdcfc
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_oca_abstract.h
@@ -0,0 +1,334 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_
+#ifdef DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ template <typename matrix_type>
+ class oca_problem
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ - matrix_type == a dlib::matrix capable of storing column vectors
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is the interface used to define the optimization
+ problems solved by the oca optimizer defined later in this file.
+
+ OCA solves optimization problems with the following form:
+ Minimize: f(w) == 0.5*length_squared(w) + C*R(w)
+
+ Where R(w) is a user-supplied convex function and C > 0. Optionally,
+ there can also be non-negativity constraints on some or all of the
+ elements of w.
+
+ Or it can alternatively solve:
+ Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w)
+
+ Where prior is a user supplied vector and R(w) has the same
+ interpretation as above.
+
+ Or it can use the elastic net regularizer:
+ Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w)
+
+ Where lasso_lambda is a number in the range [0, 1) and controls
+ trade-off between doing L1 and L2 regularization. R(w) has the same
+ interpretation as above.
+
+
+ Note that the stopping condition must be provided by the user
+ in the form of the optimization_status() function.
+ !*/
+
+ public:
+
+ typedef typename matrix_type::type scalar_type;
+
+ virtual ~oca_problem() {}
+
+ virtual bool risk_has_lower_bound (
+ scalar_type& lower_bound
+ ) const { return false; }
+ /*!
+ ensures
+ - if (R(w) >= a constant for all values of w) then
+ - returns true
+ - #lower_bound == the constant that lower bounds R(w)
+ - else
+ - returns false
+ !*/
+
+ virtual bool optimization_status (
+ scalar_type current_objective_value,
+ scalar_type current_error_gap,
+ scalar_type current_risk_value,
+ scalar_type current_risk_gap,
+ unsigned long num_cutting_planes,
+ unsigned long num_iterations
+ ) const = 0;
+ /*!
+ requires
+ - This function is called by the OCA optimizer each iteration.
+ - current_objective_value == the current value of the objective function f(w)
+ - current_error_gap == The bound on how much lower the objective function
+ can drop before we reach the optimal point. At the optimal solution the
+ error gap is equal to 0.
+ - current_risk_value == the current value of the R(w) term of the objective function.
+ - current_risk_gap == the bound on how much lower the risk term can go. At the optimal
+ solution the risk gap is zero.
+ - num_cutting_planes == the number of cutting planes the algorithm is currently using.
+ - num_iterations == A count of the total number of iterations that have executed
+ since we started running the optimization.
+ ensures
+ - If it is appropriate to terminate the optimization then this function returns true
+ and false otherwise.
+ !*/
+
+ virtual scalar_type get_c (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the C parameter
+ !*/
+
+ virtual long get_num_dimensions (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of free variables in this optimization problem
+ !*/
+
+ virtual void get_risk (
+ matrix_type& current_solution,
+ scalar_type& risk_value,
+ matrix_type& risk_subgradient
+ ) const = 0;
+ /*!
+ requires
+ - is_col_vector(current_solution) == true
+ - current_solution.size() == get_num_dimensions()
+ ensures
+ - #current_solution will be set to one of the following:
+ - current_solution (i.e. it won't be modified at all)
+ - The result of a line search passing through current_solution.
+ - #risk_value == R(#current_solution)
+ - #risk_subgradient == an element of the subgradient of R() at the
+ point #current_solution
+ - Note that #risk_value and #risk_subgradient are NOT multiplied by get_c()
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class oca
+ {
+ /*!
+ INITIAL VALUE
+ - get_subproblem_epsilon() == 1e-2
+ - get_subproblem_max_iterations() == 50000
+ - get_inactive_plane_threshold() == 20
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the optimization problem defined above
+ by the oca_problem abstract class.
+
+ For reference, OCA solves optimization problems with the following form:
+ Minimize: f(w) == 0.5*length_squared(w) + C*R(w)
+
+ Where R(w) is a user-supplied convex function and C > 0. Optionally,
+ this object can also add non-negativity constraints to some or all
+ of the elements of w.
+
+ Or it can alternatively solve:
+ Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w)
+
+ Where prior is a user supplied vector and R(w) has the same
+ interpretation as above.
+
+ Or it can use the elastic net regularizer:
+ Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w)
+
+ Where lasso_lambda is a number in the range [0, 1) and controls
+ trade-off between doing L1 and L2 regularization. R(w) has the same
+ interpretation as above.
+
+
+ For a detailed discussion you should consult the following papers
+ from the Journal of Machine Learning Research:
+ Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization
+ Vojtech Franc, Soren Sonnenburg; 10(Oct):2157--2192, 2009.
+
+ Bundle Methods for Regularized Risk Minimization
+ Choon Hui Teo, S.V.N. Vishwanthan, Alex J. Smola, Quoc V. Le; 11(Jan):311-365, 2010.
+ !*/
+ public:
+
+ oca (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type operator() (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ unsigned long num_nonnegative = 0,
+ unsigned long force_weight_to_1 = std::numeric_limits<unsigned long>::max()
+ ) const;
+ /*!
+ requires
+ - problem.get_c() > 0
+ - problem.get_num_dimensions() > 0
+ ensures
+ - solves the given oca problem and stores the solution in #w. In particular,
+ this function solves:
+ Minimize: f(w) == 0.5*length_squared(w) + C*R(w)
+ - The optimization algorithm runs until problem.optimization_status()
+ indicates it is time to stop.
+ - returns the objective value at the solution #w
+ - if (num_nonnegative != 0) then
+ - Adds the constraint that #w(i) >= 0 for all i < num_nonnegative.
+ That is, the first num_nonnegative elements of #w will always be
+ non-negative. This includes the copies of w passed to get_risk()
+ in the form of the current_solution vector as well as the final
+ output of this function.
+ - if (force_weight_to_1 < problem.get_num_dimensions()) then
+ - The optimizer enforces the following constraints:
+ - #w(force_weight_to_1) == 1
+ - for all i > force_weight_to_1:
+ - #w(i) == 0
+ - That is, the element in the weight vector at the index indicated
+ by force_weight_to_1 will have a value of 1 upon completion of
+ this function, while all subsequent elements of w will have
+ values of 0.
+ !*/
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type operator() (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ const matrix_type& prior
+ ) const;
+ /*!
+ requires
+ - problem.get_c() > 0
+ - problem.get_num_dimensions() > 0
+ - is_col_vector(prior) == true
+ - prior.size() == problem.get_num_dimensions()
+ ensures
+ - solves the given oca problem and stores the solution in #w.
+ - In this mode, we solve a version of the problem with a different
+ regularizer. In particular, this function solves:
+ Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w)
+ - The optimization algorithm runs until problem.optimization_status()
+ indicates it is time to stop.
+ - returns the objective value at the solution #w
+ !*/
+
+ template <
+ typename matrix_type
+ >
+ typename matrix_type::type solve_with_elastic_net (
+ const oca_problem<matrix_type>& problem,
+ matrix_type& w,
+ scalar_type lasso_lambda,
+ unsigned long force_weight_to_1 = std::numeric_limits<unsigned long>::max()
+ ) const;
+ /*!
+ requires
+ - problem.get_c() > 0
+ - problem.get_num_dimensions() > 0
+ - 0 <= lasso_lambda < 1
+ ensures
+ - Solves the given oca problem and stores the solution in #w, but uses an
+ elastic net regularizer instead of the normal L2 regularizer. In
+ particular, this function solves:
+ Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w)
+ - The optimization algorithm runs until problem.optimization_status()
+ indicates it is time to stop.
+ - returns the objective value at the solution #w
+ - if (force_weight_to_1 < problem.get_num_dimensions()) then
+ - The optimizer enforces the following constraints:
+ - #w(force_weight_to_1) == 1
+ - for all i > force_weight_to_1:
+ - #w(i) == 0
+ - That is, the element in the weight vector at the index indicated
+ by force_weight_to_1 will have a value of 1 upon completion of
+ this function, while all subsequent elements of w will have
+ values of 0.
+ !*/
+
+ void set_subproblem_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_subproblem_epsilon() == eps
+ !*/
+
+ double get_subproblem_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the accuracy used in solving the quadratic programming
+ subproblem that is part of the overall OCA algorithm.
+ !*/
+
+ void set_subproblem_max_iterations (
+ unsigned long sub_max_iter
+ );
+ /*!
+ requires
+ - sub_max_iter > 0
+ ensures
+ - #get_subproblem_max_iterations() == sub_max_iter
+ !*/
+
+ unsigned long get_subproblem_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations this object will perform
+ while attempting to solve each quadratic programming subproblem.
+ !*/
+
+ void set_inactive_plane_threshold (
+ unsigned long inactive_thresh
+ );
+ /*!
+ requires
+ - inactive_thresh > 0
+ ensures
+ - #get_inactive_plane_threshold() == inactive_thresh
+ !*/
+
+ unsigned long get_inactive_plane_threshold (
+ ) const;
+ /*!
+ ensures
+ - As OCA runs it builds up a set of cutting planes. Typically
+ cutting planes become inactive after a certain point and can then
+ be removed. This function returns the number of iterations of
+ inactivity required before a cutting plane is removed.
+ !*/
+
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_search_strategies.h b/ml/dlib/dlib/optimization/optimization_search_strategies.h
new file mode 100644
index 000000000..eda637fa1
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_search_strategies.h
@@ -0,0 +1,324 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_
+#define DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_
+
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "optimization_search_strategies_abstract.h"
+#include "../sequence.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class cg_search_strategy
+ {
+ public:
+ cg_search_strategy() : been_used(false) {}
+
+ double get_wolfe_rho (
+ ) const { return 0.001; }
+
+ double get_wolfe_sigma (
+ ) const { return 0.01; }
+
+ unsigned long get_max_line_search_iterations (
+ ) const { return 100; }
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& ,
+ const double ,
+ const T& funct_derivative
+ )
+ {
+ if (been_used == false)
+ {
+ been_used = true;
+ prev_direction = -funct_derivative;
+ }
+ else
+ {
+ // Use the Polak-Ribiere (4.1.12) conjugate gradient described by Fletcher on page 83
+ const double temp = trans(prev_derivative)*prev_derivative;
+ // If this value hits zero then just use the direction of steepest descent.
+ if (std::abs(temp) < std::numeric_limits<double>::epsilon())
+ {
+ prev_derivative = funct_derivative;
+ prev_direction = -funct_derivative;
+ return prev_direction;
+ }
+
+ double b = trans(funct_derivative-prev_derivative)*funct_derivative/(temp);
+ prev_direction = -funct_derivative + b*prev_direction;
+
+ }
+
+ prev_derivative = funct_derivative;
+ return prev_direction;
+ }
+
+ private:
+ bool been_used;
+ matrix<double,0,1> prev_derivative;
+ matrix<double,0,1> prev_direction;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bfgs_search_strategy
+ {
+ public:
+ bfgs_search_strategy() : been_used(false), been_used_twice(false) {}
+
+ double get_wolfe_rho (
+ ) const { return 0.01; }
+
+ double get_wolfe_sigma (
+ ) const { return 0.9; }
+
+ unsigned long get_max_line_search_iterations (
+ ) const { return 100; }
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& x,
+ const double ,
+ const T& funct_derivative
+ )
+ {
+ if (been_used == false)
+ {
+ been_used = true;
+ H = identity_matrix<double>(x.size());
+ }
+ else
+ {
+ // update H with the BFGS formula from (3.2.12) on page 55 of Fletcher
+ delta = (x-prev_x);
+ gamma = funct_derivative-prev_derivative;
+
+ double dg = dot(delta,gamma);
+
+ // Try to set the initial value of the H matrix to something reasonable if we are still
+ // in the early stages of figuring out what it is. This formula below is what is suggested
+ // in the book Numerical Optimization by Nocedal and Wright in the chapter on Quasi-Newton methods.
+ if (been_used_twice == false)
+ {
+ double gg = trans(gamma)*gamma;
+ if (std::abs(gg) > std::numeric_limits<double>::epsilon())
+ {
+ const double temp = put_in_range(0.01, 100, dg/gg);
+ H = diagm(uniform_matrix<double>(x.size(),1, temp));
+ been_used_twice = true;
+ }
+ }
+
+ Hg = H*gamma;
+ gH = trans(trans(gamma)*H);
+ double gHg = trans(gamma)*H*gamma;
+ if (gHg < std::numeric_limits<double>::infinity() && dg < std::numeric_limits<double>::infinity() &&
+ dg != 0)
+ {
+ H += (1 + gHg/dg)*delta*trans(delta)/(dg) - (delta*trans(gH) + Hg*trans(delta))/(dg);
+ }
+ else
+ {
+ H = identity_matrix<double>(H.nr());
+ been_used_twice = false;
+ }
+ }
+
+ prev_x = x;
+ prev_direction = -H*funct_derivative;
+ prev_derivative = funct_derivative;
+ return prev_direction;
+ }
+
+ private:
+ bool been_used;
+ bool been_used_twice;
+ matrix<double,0,1> prev_x;
+ matrix<double,0,1> prev_derivative;
+ matrix<double,0,1> prev_direction;
+ matrix<double> H;
+ matrix<double,0,1> delta, gamma, Hg, gH;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class lbfgs_search_strategy
+ {
+ public:
+ explicit lbfgs_search_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false)
+ {
+ DLIB_ASSERT (
+ max_size > 0,
+ "\t lbfgs_search_strategy(max_size)"
+ << "\n\t max_size can't be zero"
+ );
+ }
+
+ lbfgs_search_strategy(const lbfgs_search_strategy& item)
+ {
+ max_size = item.max_size;
+ been_used = item.been_used;
+ prev_x = item.prev_x;
+ prev_derivative = item.prev_derivative;
+ prev_direction = item.prev_direction;
+ alpha = item.alpha;
+ dh_temp = item.dh_temp;
+ }
+
+ double get_wolfe_rho (
+ ) const { return 0.01; }
+
+ double get_wolfe_sigma (
+ ) const { return 0.9; }
+
+ unsigned long get_max_line_search_iterations (
+ ) const { return 100; }
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& x,
+ const double ,
+ const T& funct_derivative
+ )
+ {
+ prev_direction = -funct_derivative;
+
+ if (been_used == false)
+ {
+ been_used = true;
+ }
+ else
+ {
+ // add an element into the stored data sequence
+ dh_temp.s = x - prev_x;
+ dh_temp.y = funct_derivative - prev_derivative;
+ double temp = dot(dh_temp.s, dh_temp.y);
+ // only accept this bit of data if temp isn't zero
+ if (std::abs(temp) > std::numeric_limits<double>::epsilon())
+ {
+ dh_temp.rho = 1/temp;
+ data.add(data.size(), dh_temp);
+ }
+ else
+ {
+ data.clear();
+ }
+
+ if (data.size() > 0)
+ {
+ // This block of code is from algorithm 7.4 in the Nocedal book.
+
+ alpha.resize(data.size());
+ for (unsigned long i = data.size()-1; i < data.size(); --i)
+ {
+ alpha[i] = data[i].rho*dot(data[i].s, prev_direction);
+ prev_direction -= alpha[i]*data[i].y;
+ }
+
+ // Take a guess at what the first H matrix should be. This formula below is what is suggested
+ // in the book Numerical Optimization by Nocedal and Wright in the chapter on Large Scale
+ // Unconstrained Optimization (in the L-BFGS section).
+ double H_0 = 1.0/data[data.size()-1].rho/dot(data[data.size()-1].y, data[data.size()-1].y);
+ H_0 = put_in_range(0.001, 1000.0, H_0);
+ prev_direction *= H_0;
+
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ double beta = data[i].rho*dot(data[i].y, prev_direction);
+ prev_direction += data[i].s * (alpha[i] - beta);
+ }
+ }
+
+ }
+
+ if (data.size() > max_size)
+ {
+ // remove the oldest element in the data sequence
+ data.remove(0, dh_temp);
+ }
+
+ prev_x = x;
+ prev_derivative = funct_derivative;
+ return prev_direction;
+ }
+
+ private:
+
+ struct data_helper
+ {
+ matrix<double,0,1> s;
+ matrix<double,0,1> y;
+ double rho;
+
+ friend void swap(data_helper& a, data_helper& b)
+ {
+ a.s.swap(b.s);
+ a.y.swap(b.y);
+ std::swap(a.rho, b.rho);
+ }
+ };
+ sequence<data_helper>::kernel_2a data;
+
+ unsigned long max_size;
+ bool been_used;
+ matrix<double,0,1> prev_x;
+ matrix<double,0,1> prev_derivative;
+ matrix<double,0,1> prev_direction;
+ std::vector<double> alpha;
+
+ data_helper dh_temp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename hessian_funct>
+ class newton_search_strategy_obj
+ {
+ public:
+ explicit newton_search_strategy_obj(
+ const hessian_funct& hess
+ ) : hessian(hess) {}
+
+ double get_wolfe_rho (
+ ) const { return 0.01; }
+
+ double get_wolfe_sigma (
+ ) const { return 0.9; }
+
+ unsigned long get_max_line_search_iterations (
+ ) const { return 100; }
+
+ template <typename T>
+ const matrix<double,0,1> get_next_direction (
+ const T& x,
+ const double ,
+ const T& funct_derivative
+ )
+ {
+ return -inv(hessian(x))*funct_derivative;
+ }
+
+ private:
+ hessian_funct hessian;
+ };
+
+ template <typename hessian_funct>
+ newton_search_strategy_obj<hessian_funct> newton_search_strategy (
+ hessian_funct hessian
+ ) { return newton_search_strategy_obj<hessian_funct>(hessian); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_
+
diff --git a/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h b/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h
new file mode 100644
index 000000000..44b894bfc
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h
@@ -0,0 +1,330 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_
+#ifdef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+ /*
+ A good discussion of the search strategies in this file can be found in the
+ following book: Numerical Optimization by Nocedal and Wright.
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ class cg_search_strategy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for determining which direction
+ a line search should be carried out along. This particular object
+ is an implementation of the Polak-Ribiere conjugate gradient method
+ for determining this direction.
+
+ This method uses an amount of memory that is linear in the number
+ of variables to be optimized. So it is capable of handling problems
+ with a very large number of variables. However, it is generally
+ not as good as the L-BFGS algorithm (which is defined below in
+ the lbfgs_search_strategy class).
+ !*/
+
+ public:
+ cg_search_strategy(
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to generate
+ search directions.
+ !*/
+
+ double get_wolfe_rho (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe rho parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ double get_wolfe_sigma (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe sigma parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ unsigned long get_max_line_search_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the max iterations parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - this function is only called once per search iteration
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - Assuming that a line search is going to be conducted starting from the point x,
+ this function returns the direction in which the search should proceed.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class bfgs_search_strategy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for determining which direction
+ a line search should be carried out along. This particular object
+ is an implementation of the BFGS quasi-newton method for determining
+ this direction.
+
+ This method uses an amount of memory that is quadratic in the number
+ of variables to be optimized. It is generally very effective but
+ if your problem has a very large number of variables then it isn't
+ appropriate. Instead You should try the lbfgs_search_strategy.
+ !*/
+
+ public:
+ bfgs_search_strategy(
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to generate
+ search directions.
+ !*/
+
+ double get_wolfe_rho (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe rho parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ double get_wolfe_sigma (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe sigma parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ unsigned long get_max_line_search_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the max iterations parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - this function is only called once per search iteration
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - Assuming that a line search is going to be conducted starting from the point x,
+ this function returns the direction in which the search should proceed.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class lbfgs_search_strategy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for determining which direction
+ a line search should be carried out along. This particular object
+ is an implementation of the L-BFGS quasi-newton method for determining
+ this direction.
+
+ This method uses an amount of memory that is linear in the number
+ of variables to be optimized. This makes it an excellent method
+ to use when an optimization problem has a large number of variables.
+ !*/
+ public:
+ explicit lbfgs_search_strategy(
+ unsigned long max_size
+ );
+ /*!
+ requires
+ - max_size > 0
+ ensures
+ - This object is properly initialized and ready to generate
+ search directions.
+ - L-BFGS works by remembering a certain number of position and gradient
+ pairs. It uses this remembered information to compute search directions.
+ The max_size argument determines how many of these pairs will be remembered.
+ Typically, using between 3 and 30 pairs performs well for many problems.
+ !*/
+
+ double get_wolfe_rho (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe rho parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ double get_wolfe_sigma (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe sigma parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ unsigned long get_max_line_search_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the max iterations parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ template <typename T>
+ const matrix<double,0,1>& get_next_direction (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - this function is only called once per search iteration
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - Assuming that a line search is going to be conducted starting from the point x,
+ this function returns the direction in which the search should proceed.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename hessian_funct
+ >
+ class newton_search_strategy_obj
+ {
+ /*!
+ REQUIREMENTS ON hessian_funct
+ Objects of hessian_funct type must be function objects which
+ take a single argument and return a dlib::matrix of doubles. The
+ single argument must be a dlib::matrix capable of representing
+ column vectors of doubles. hessian_funct must also be copy
+ constructable.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for determining which direction
+ a line search should be carried out along. This particular object
+ is an implementation of the newton method for determining this
+ direction. That is, it uses the following formula to determine
+ the direction:
+ search_direction = -inv(hessian(x))*derivative
+ !*/
+ public:
+ explicit newton_search_strategy_obj(
+ const hessian_funct& hess
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to generate
+ search directions.
+ - hess will be used by this object to generate the needed hessian
+ matrices every time get_next_direction() is called.
+ !*/
+
+ double get_wolfe_rho (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe rho parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ double get_wolfe_sigma (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the Wolfe sigma parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ unsigned long get_max_line_search_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the value of the max iterations parameter that should be used when
+ this search strategy is used with the line_search() function.
+ !*/
+
+ template <typename T>
+ const matrix<double,0,1> get_next_direction (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - Assuming that a line search is going to be conducted starting from the
+ point x, this function returns the direction in which the search should
+ proceed.
+ - In particular, the search direction will be given by:
+ - search_direction = -inv(hessian(x))*funct_derivative
+ !*/
+
+ };
+
+ template <typename hessian_funct>
+ newton_search_strategy_obj<hessian_funct> newton_search_strategy (
+ hessian_funct hessian
+ ) { return newton_search_strategy_obj<hessian_funct>(hessian); }
+ /*!
+ ensures
+ - constructs and returns a newton_search_strategy_obj.
+ This function is just a helper to make the syntax for creating
+ these objects a little simpler.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h
new file mode 100644
index 000000000..88cad0cf3
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h
@@ -0,0 +1,468 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOLVE_QP2_USING_SMo_Hh_
+#define DLIB_SOLVE_QP2_USING_SMo_Hh_
+
+#include "optimization_solve_qp2_using_smo_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_nu_error : public dlib::error
+ {
+ public:
+ invalid_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {};
+ const double nu;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename T::type maximum_nu_impl (
+ const T& y
+ )
+ {
+ typedef typename T::type scalar_type;
+ // make sure requires clause is not broken
+ DLIB_ASSERT(y.size() > 1 && is_col_vector(y),
+ "\ttypedef T::type maximum_nu(y)"
+ << "\n\ty should be a column vector with more than one entry"
+ << "\n\ty.nr(): " << y.nr()
+ << "\n\ty.nc(): " << y.nc()
+ );
+
+ long pos_count = 0;
+ long neg_count = 0;
+ for (long r = 0; r < y.nr(); ++r)
+ {
+ if (y(r) == 1.0)
+ {
+ ++pos_count;
+ }
+ else if (y(r) == -1.0)
+ {
+ ++neg_count;
+ }
+ else
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(y(r) == -1.0 || y(r) == 1.0,
+ "\ttypedef T::type maximum_nu(y)"
+ << "\n\ty should contain only 1 and 0 entries"
+ << "\n\tr: " << r
+ << "\n\ty(r): " << y(r)
+ );
+ }
+ }
+ return static_cast<scalar_type>(2.0*(scalar_type)std::min(pos_count,neg_count)/(scalar_type)y.nr());
+ }
+
+ template <
+ typename T
+ >
+ typename T::type maximum_nu (
+ const T& y
+ )
+ {
+ return maximum_nu_impl(mat(y));
+ }
+
+ template <
+ typename T
+ >
+ typename T::value_type maximum_nu (
+ const T& y
+ )
+ {
+ return maximum_nu_impl(mat(y));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class solve_qp2_using_smo
+ {
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ long NR
+ >
+ unsigned long operator() (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& y,
+ const scalar_type nu,
+ matrix<scalar_type,NR,1,mem_manager_type, layout_type>& alpha,
+ scalar_type eps
+ )
+ {
+ DLIB_ASSERT(Q.nr() == Q.nc() && y.size() == Q.nr() && y.size() > 1 && is_col_vector(y) &&
+ sum((y == +1) + (y == -1)) == y.size() &&
+ 0 < nu && nu <= 1 &&
+ eps > 0,
+ "\t void solve_qp2_using_smo::operator()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t y.size(): " << y.size()
+ << "\n\t sum((y == +1) + (y == -1)): " << sum((y == +1) + (y == -1))
+ << "\n\t nu: " << nu
+ << "\n\t eps: " << eps
+ );
+
+ alpha.set_size(Q.nr(),1);
+ df.set_size(Q.nr());
+
+ // now initialize alpha
+ set_initial_alpha(y, nu, alpha);
+
+ const scalar_type tau = 1e-12;
+
+ typedef typename colm_exp<EXP1>::type col_type;
+
+ set_all_elements(df, 0);
+ // initialize df. Compute df = Q*alpha
+ for (long r = 0; r < df.nr(); ++r)
+ {
+ if (alpha(r) != 0)
+ {
+ df += alpha(r)*matrix_cast<scalar_type>(colm(Q,r));
+ }
+ }
+
+ unsigned long count = 0;
+
+ // now perform the actual optimization of alpha
+ long i=0, j=0;
+ while (find_working_group(y,alpha,Q,df,tau,eps,i,j))
+ {
+ ++count;
+ const scalar_type old_alpha_i = alpha(i);
+ const scalar_type old_alpha_j = alpha(j);
+
+ optimize_working_pair(alpha,Q,df,tau,i,j);
+
+ // update the df vector now that we have modified alpha(i) and alpha(j)
+ scalar_type delta_alpha_i = alpha(i) - old_alpha_i;
+ scalar_type delta_alpha_j = alpha(j) - old_alpha_j;
+
+ col_type Q_i = colm(Q,i);
+ col_type Q_j = colm(Q,j);
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) += Q_i(k)*delta_alpha_i + Q_j(k)*delta_alpha_j;
+ }
+
+ return count;
+ }
+
+ const column_matrix& get_gradient (
+ ) const { return df; }
+
+ private:
+
+ // -------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_type,
+ typename scalar_vector_type,
+ typename scalar_vector_type2
+ >
+ inline void set_initial_alpha (
+ const scalar_vector_type& y,
+ const scalar_type nu,
+ scalar_vector_type2& alpha
+ ) const
+ {
+ set_all_elements(alpha,0);
+ const scalar_type l = y.nr();
+ scalar_type temp = nu*l/2;
+ long num = (long)std::floor(temp);
+ long num_total = (long)std::ceil(temp);
+
+ bool has_slack = false;
+ int count = 0;
+ for (int i = 0; i < alpha.nr(); ++i)
+ {
+ if (y(i) == 1)
+ {
+ if (count < num)
+ {
+ ++count;
+ alpha(i) = 1;
+ }
+ else
+ {
+ has_slack = true;
+ if (num_total > num)
+ {
+ ++count;
+ alpha(i) = temp - std::floor(temp);
+ }
+ break;
+ }
+ }
+ }
+
+ if (count != num_total || has_slack == false)
+ {
+ std::ostringstream sout;
+ sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr();
+ throw invalid_nu_error(sout.str(),nu);
+ }
+
+ has_slack = false;
+ count = 0;
+ for (int i = 0; i < alpha.nr(); ++i)
+ {
+ if (y(i) == -1)
+ {
+ if (count < num)
+ {
+ ++count;
+ alpha(i) = 1;
+ }
+ else
+ {
+ has_slack = true;
+ if (num_total > num)
+ {
+ ++count;
+ alpha(i) = temp - std::floor(temp);
+ }
+ break;
+ }
+ }
+ }
+
+ if (count != num_total || has_slack == false)
+ {
+ std::ostringstream sout;
+ sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr();
+ throw invalid_nu_error(sout.str(),nu);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type,
+ typename scalar_type,
+ typename EXP,
+ typename U, typename V
+ >
+ inline bool find_working_group (
+ const V& y,
+ const U& alpha,
+ const matrix_exp<EXP>& Q,
+ const scalar_vector_type& df,
+ const scalar_type tau,
+ const scalar_type eps,
+ long& i_out,
+ long& j_out
+ ) const
+ {
+ using namespace std;
+ long ip = 0;
+ long jp = 0;
+ long in = 0;
+ long jn = 0;
+
+
+ typedef typename colm_exp<EXP>::type col_type;
+ typedef typename diag_exp<EXP>::type diag_type;
+
+ scalar_type ip_val = -numeric_limits<scalar_type>::infinity();
+ scalar_type jp_val = numeric_limits<scalar_type>::infinity();
+ scalar_type in_val = -numeric_limits<scalar_type>::infinity();
+ scalar_type jn_val = numeric_limits<scalar_type>::infinity();
+
+ // loop over the alphas and find the maximum ip and in indices.
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (y(i) == 1)
+ {
+ if (alpha(i) < 1.0)
+ {
+ if (-df(i) > ip_val)
+ {
+ ip_val = -df(i);
+ ip = i;
+ }
+ }
+ }
+ else
+ {
+ if (alpha(i) > 0.0)
+ {
+ if (df(i) > in_val)
+ {
+ in_val = df(i);
+ in = i;
+ }
+ }
+ }
+ }
+
+ scalar_type Mp = numeric_limits<scalar_type>::infinity();
+ scalar_type Mn = numeric_limits<scalar_type>::infinity();
+
+ // Pick out the columns and diagonal of Q we need below. Doing
+ // it this way is faster if Q is actually a symmetric_matrix_cache
+ // object.
+ col_type Q_ip = colm(Q,ip);
+ col_type Q_in = colm(Q,in);
+ diag_type Q_diag = diag(Q);
+
+
+
+ // now we need to find the minimum jp and jn indices
+ for (long j = 0; j < alpha.nr(); ++j)
+ {
+ if (y(j) == 1)
+ {
+ if (alpha(j) > 0.0)
+ {
+ scalar_type b = ip_val + df(j);
+ if (-df(j) < Mp)
+ Mp = -df(j);
+
+ if (b > 0)
+ {
+ scalar_type a = Q_ip(ip) + Q_diag(j) - 2*Q_ip(j);
+ if (a <= 0)
+ a = tau;
+ scalar_type temp = -b*b/a;
+ if (temp < jp_val)
+ {
+ jp_val = temp;
+ jp = j;
+ }
+ }
+ }
+ }
+ else
+ {
+ if (alpha(j) < 1.0)
+ {
+ scalar_type b = in_val - df(j);
+ if (df(j) < Mn)
+ Mn = df(j);
+
+ if (b > 0)
+ {
+ scalar_type a = Q_in(in) + Q_diag(j) - 2*Q_in(j);
+ if (a <= 0)
+ a = tau;
+ scalar_type temp = -b*b/a;
+ if (temp < jn_val)
+ {
+ jn_val = temp;
+ jn = j;
+ }
+ }
+ }
+ }
+ }
+
+ // if we are at the optimal point then return false so the caller knows
+ // to stop optimizing
+ if (std::max(ip_val - Mp, in_val - Mn) < eps)
+ return false;
+
+ if (jp_val < jn_val)
+ {
+ i_out = ip;
+ j_out = jp;
+ }
+ else
+ {
+ i_out = in;
+ j_out = jn;
+ }
+
+ return true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T, typename U
+ >
+ inline void optimize_working_pair (
+ T& alpha,
+ const matrix_exp<EXP>& Q,
+ const U& df,
+ const scalar_type tau,
+ const long i,
+ const long j
+ ) const
+ {
+ scalar_type quad_coef = Q(i,i)+Q(j,j)-2*Q(j,i);
+ if (quad_coef <= 0)
+ quad_coef = tau;
+ scalar_type delta = (df(i)-df(j))/quad_coef;
+ scalar_type sum = alpha(i) + alpha(j);
+ alpha(i) -= delta;
+ alpha(j) += delta;
+
+ if(sum > 1)
+ {
+ if(alpha(i) > 1)
+ {
+ alpha(i) = 1;
+ alpha(j) = sum - 1;
+ }
+ else if(alpha(j) > 1)
+ {
+ alpha(j) = 1;
+ alpha(i) = sum - 1;
+ }
+ }
+ else
+ {
+ if(alpha(j) < 0)
+ {
+ alpha(j) = 0;
+ alpha(i) = sum;
+ }
+ else if(alpha(i) < 0)
+ {
+ alpha(i) = 0;
+ alpha(j) = sum;
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ column_matrix df; // gradient of f(alpha)
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOLVE_QP2_USING_SMo_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h
new file mode 100644
index 000000000..962541c25
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h
@@ -0,0 +1,150 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_
+#ifdef DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_nu_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is an exception class used to indicate that a
+ value of nu given to the solve_qp2_using_smo object is incompatible
+ with the constraints of the quadratic program.
+
+ this->nu will be set to the invalid value of nu used.
+ !*/
+
+ public:
+ invalid_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {};
+ const double nu;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename T::type maximum_nu (
+ const T& y
+ );
+ /*!
+ requires
+ - T == a matrix object or an object convertible to a matrix via mat()
+ - is_col_vector(y) == true
+ - y.size() > 1
+ - sum((y == +1) + (y == -1)) == y.size()
+ (i.e. all elements of y must be equal to +1 or -1)
+ ensures
+ - returns the maximum valid nu that can be used with solve_qp2_using_smo and
+ the given y vector.
+ (i.e. 2.0*min(sum(y == +1), sum(y == -1))/y.size())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class solve_qp2_using_smo
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the following quadratic programming
+ problem using the sequential minimal optimization algorithm:
+
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha
+ subject to the following constraints:
+ - sum(alpha) == nu*y.size()
+ - 0 <= min(alpha) && max(alpha) <= 1
+ - trans(y)*alpha == 0
+
+ Where all elements of y must be equal to +1 or -1 and f is convex.
+ This means that Q should be symmetric and positive-semidefinite.
+
+
+ This object implements the strategy used by the LIBSVM tool. The following papers
+ can be consulted for additional details:
+ - Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+ !*/
+
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ long NR
+ >
+ unsigned long operator() (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& y,
+ const scalar_type nu,
+ matrix<scalar_type,NR,1,mem_manager_type, layout_type>& alpha,
+ scalar_type eps
+ );
+ /*!
+ requires
+ - Q.nr() == Q.nc()
+ - is_col_vector(y) == true
+ - y.size() == Q.nr()
+ - y.size() > 1
+ - sum((y == +1) + (y == -1)) == y.size()
+ (i.e. all elements of y must be equal to +1 or -1)
+ - alpha must be capable of representing a vector of size y.size() elements
+ - 0 < nu <= 1
+ - eps > 0
+ ensures
+ - This function solves the quadratic program defined in this class's main comment.
+ - The solution to the quadratic program will be stored in #alpha.
+ - #alpha.size() == y.size()
+ - This function uses an implementation of the sequential minimal optimization
+ algorithm. It runs until the KKT violation is less than eps. So eps controls
+ how accurate the solution is and smaller values result in better solutions.
+ (a reasonable eps is usually about 1e-3)
+ - #get_gradient() == Q*(#alpha)
+ (i.e. stores the gradient of f() at #alpha in get_gradient())
+ - returns the number of iterations performed.
+ throws
+ - invalid_nu_error
+ This exception is thrown if nu >= maximum_nu(y).
+ (some values of nu cause the constraints to become impossible to satisfy.
+ If this is detected then an exception is thrown).
+ !*/
+
+ const column_matrix& get_gradient (
+ ) const;
+ /*!
+ ensures
+ - returns the gradient vector at the solution of the last problem solved
+ by this object. If no problem has been solved then returns an empty
+ vector.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h
new file mode 100644
index 000000000..617ecc408
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h
@@ -0,0 +1,455 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOLVE_QP3_USING_SMo_Hh_
+#define DLIB_SOLVE_QP3_USING_SMo_Hh_
+
+#include "optimization_solve_qp3_using_smo_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_qp3_error : public dlib::error
+ {
+
+ public:
+ invalid_qp3_error(
+ const std::string& msg,
+ double B_,
+ double Cp_,
+ double Cn_
+ ) :
+ dlib::error(msg),
+ B(B_),
+ Cp(Cp_),
+ Cn(Cn_)
+ {};
+
+ const double B;
+ const double Cp;
+ const double Cn;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class solve_qp3_using_smo
+ {
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3,
+ long NR
+ >
+ unsigned long operator() (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& p,
+ const matrix_exp<EXP3>& y,
+ const scalar_type B,
+ const scalar_type Cp,
+ const scalar_type Cn,
+ matrix<scalar_type,NR,1,mem_manager_type, layout_type>& alpha,
+ scalar_type eps
+ )
+ {
+ DLIB_ASSERT(Q.nr() == Q.nc() && y.size() == Q.nr() && p.size() == y.size() &&
+ y.size() > 0 && is_col_vector(y) && is_col_vector(p) &&
+ sum((y == +1) + (y == -1)) == y.size() &&
+ Cp > 0 && Cn > 0 &&
+ eps > 0,
+ "\t void solve_qp3_using_smo::operator()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(p): " << is_col_vector(p)
+ << "\n\t p.size(): " << p.size()
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t y.size(): " << y.size()
+ << "\n\t sum((y == +1) + (y == -1)): " << sum((y == +1) + (y == -1))
+ << "\n\t Cp: " << Cp
+ << "\n\t Cn: " << Cn
+ << "\n\t eps: " << eps
+ );
+
+
+
+ set_initial_alpha(y, B, Cp, Cn, alpha);
+
+
+ const scalar_type tau = 1e-12;
+
+ typedef typename colm_exp<EXP1>::type col_type;
+
+ // initialize df. Compute df = Q*alpha + p
+ df = p;
+ for (long r = 0; r < df.nr(); ++r)
+ {
+ if (alpha(r) != 0)
+ {
+ df += alpha(r)*matrix_cast<scalar_type>(colm(Q,r));
+ }
+ }
+
+ unsigned long count = 0;
+ // now perform the actual optimization of alpha
+ long i=0, j=0;
+ while (find_working_group(y,alpha,Q,df,Cp,Cn,tau,eps,i,j))
+ {
+ ++count;
+ const scalar_type old_alpha_i = alpha(i);
+ const scalar_type old_alpha_j = alpha(j);
+
+ optimize_working_pair(alpha,Q,y,df,tau,i,j, Cp, Cn );
+
+ // update the df vector now that we have modified alpha(i) and alpha(j)
+ scalar_type delta_alpha_i = alpha(i) - old_alpha_i;
+ scalar_type delta_alpha_j = alpha(j) - old_alpha_j;
+
+ col_type Q_i = colm(Q,i);
+ col_type Q_j = colm(Q,j);
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) += Q_i(k)*delta_alpha_i + Q_j(k)*delta_alpha_j;
+ }
+
+ return count;
+ }
+
+ const column_matrix& get_gradient (
+ ) const { return df; }
+
+ private:
+
+ // -------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_type,
+ typename scalar_vector_type,
+ typename scalar_vector_type2
+ >
+ inline void set_initial_alpha (
+ const scalar_vector_type& y,
+ const scalar_type B,
+ const scalar_type Cp,
+ const scalar_type Cn,
+ scalar_vector_type2& alpha
+ ) const
+ {
+ alpha.set_size(y.size());
+
+ set_all_elements(alpha,0);
+
+ // It's easy in the B == 0 case
+ if (B == 0)
+ return;
+
+ const scalar_type C = (B > 0)? Cp : Cn;
+
+ scalar_type temp = std::abs(B)/C;
+ long num = (long)std::floor(temp);
+ long num_total = (long)std::ceil(temp);
+
+ const scalar_type B_sign = (B > 0)? 1 : -1;
+
+ long count = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (y(i) == B_sign)
+ {
+ if (count < num)
+ {
+ ++count;
+ alpha(i) = C;
+ }
+ else
+ {
+ if (count < num_total)
+ {
+ ++count;
+ alpha(i) = C*(temp - std::floor(temp));
+ }
+ break;
+ }
+ }
+ }
+
+ if (count != num_total)
+ {
+ std::ostringstream sout;
+ sout << "Invalid QP3 constraint parameters of B: " << B << ", Cp: " << Cp << ", Cn: "<< Cn;
+ throw invalid_qp3_error(sout.str(),B,Cp,Cn);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type,
+ typename scalar_type,
+ typename EXP,
+ typename U, typename V
+ >
+ inline bool find_working_group (
+ const V& y,
+ const U& alpha,
+ const matrix_exp<EXP>& Q,
+ const scalar_vector_type& df,
+ const scalar_type Cp,
+ const scalar_type Cn,
+ const scalar_type tau,
+ const scalar_type eps,
+ long& i_out,
+ long& j_out
+ ) const
+ {
+ using namespace std;
+
+ long ip = 0;
+ long jp = 0;
+
+
+ typedef typename colm_exp<EXP>::type col_type;
+ typedef typename diag_exp<EXP>::type diag_type;
+
+ scalar_type ip_val = -numeric_limits<scalar_type>::infinity();
+ scalar_type jp_val = numeric_limits<scalar_type>::infinity();
+
+ // loop over the alphas and find the maximum ip and in indices.
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (y(i) == 1)
+ {
+ if (alpha(i) < Cp)
+ {
+ if (-df(i) > ip_val)
+ {
+ ip_val = -df(i);
+ ip = i;
+ }
+ }
+ }
+ else
+ {
+ if (alpha(i) > 0.0)
+ {
+ if (df(i) > ip_val)
+ {
+ ip_val = df(i);
+ ip = i;
+ }
+ }
+ }
+ }
+
+ scalar_type Mp = -numeric_limits<scalar_type>::infinity();
+
+ // Pick out the column and diagonal of Q we need below. Doing
+ // it this way is faster if Q is actually a symmetric_matrix_cache
+ // object.
+ col_type Q_ip = colm(Q,ip);
+ diag_type Q_diag = diag(Q);
+
+
+
+ // now we need to find the minimum jp indices
+ for (long j = 0; j < alpha.nr(); ++j)
+ {
+ if (y(j) == 1)
+ {
+ if (alpha(j) > 0.0)
+ {
+ scalar_type b = ip_val + df(j);
+ if (df(j) > Mp)
+ Mp = df(j);
+
+ if (b > 0)
+ {
+ scalar_type a = Q_ip(ip) + Q_diag(j) - 2*y(ip)*Q_ip(j);
+ if (a <= 0)
+ a = tau;
+ scalar_type temp = -b*b/a;
+ if (temp < jp_val)
+ {
+ jp_val = temp;
+ jp = j;
+ }
+ }
+ }
+ }
+ else
+ {
+ if (alpha(j) < Cn)
+ {
+ scalar_type b = ip_val - df(j);
+ if (-df(j) > Mp)
+ Mp = -df(j);
+
+ if (b > 0)
+ {
+ scalar_type a = Q_ip(ip) + Q_diag(j) + 2*y(ip)*Q_ip(j);
+ if (a <= 0)
+ a = tau;
+ scalar_type temp = -b*b/a;
+ if (temp < jp_val)
+ {
+ jp_val = temp;
+ jp = j;
+ }
+ }
+ }
+ }
+ }
+
+ // if we are at the optimal point then return false so the caller knows
+ // to stop optimizing
+ if (Mp + ip_val < eps)
+ return false;
+
+
+ i_out = ip;
+ j_out = jp;
+
+ return true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename EXP2,
+ typename T, typename U
+ >
+ inline void optimize_working_pair (
+ T& alpha,
+ const matrix_exp<EXP>& Q,
+ const matrix_exp<EXP2>& y,
+ const U& df,
+ const scalar_type& tau,
+ const long i,
+ const long j,
+ const scalar_type& Cp,
+ const scalar_type& Cn
+ ) const
+ {
+ const scalar_type Ci = (y(i) > 0 )? Cp : Cn;
+ const scalar_type Cj = (y(j) > 0 )? Cp : Cn;
+
+ if (y(i) != y(j))
+ {
+ scalar_type quad_coef = Q(i,i)+Q(j,j)+2*Q(j,i);
+ if (quad_coef <= 0)
+ quad_coef = tau;
+ scalar_type delta = (-df(i)-df(j))/quad_coef;
+ scalar_type diff = alpha(i) - alpha(j);
+ alpha(i) += delta;
+ alpha(j) += delta;
+
+ if (diff > 0)
+ {
+ if (alpha(j) < 0)
+ {
+ alpha(j) = 0;
+ alpha(i) = diff;
+ }
+ }
+ else
+ {
+ if (alpha(i) < 0)
+ {
+ alpha(i) = 0;
+ alpha(j) = -diff;
+ }
+ }
+
+ if (diff > Ci - Cj)
+ {
+ if (alpha(i) > Ci)
+ {
+ alpha(i) = Ci;
+ alpha(j) = Ci - diff;
+ }
+ }
+ else
+ {
+ if (alpha(j) > Cj)
+ {
+ alpha(j) = Cj;
+ alpha(i) = Cj + diff;
+ }
+ }
+ }
+ else
+ {
+ scalar_type quad_coef = Q(i,i)+Q(j,j)-2*Q(j,i);
+ if (quad_coef <= 0)
+ quad_coef = tau;
+ scalar_type delta = (df(i)-df(j))/quad_coef;
+ scalar_type sum = alpha(i) + alpha(j);
+ alpha(i) -= delta;
+ alpha(j) += delta;
+
+ if(sum > Ci)
+ {
+ if(alpha(i) > Ci)
+ {
+ alpha(i) = Ci;
+ alpha(j) = sum - Ci;
+ }
+ }
+ else
+ {
+ if(alpha(j) < 0)
+ {
+ alpha(j) = 0;
+ alpha(i) = sum;
+ }
+ }
+
+ if(sum > Cj)
+ {
+ if(alpha(j) > Cj)
+ {
+ alpha(j) = Cj;
+ alpha(i) = sum - Cj;
+ }
+ }
+ else
+ {
+ if(alpha(i) < 0)
+ {
+ alpha(i) = 0;
+ alpha(j) = sum;
+ }
+ }
+
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ column_matrix df; // gradient of f(alpha)
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOLVE_QP3_USING_SMo_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h
new file mode 100644
index 000000000..8efd7215b
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h
@@ -0,0 +1,139 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_
+#ifdef DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_qp3_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is an exception class used to indicate that the values
+ of B, Cp, and Cn given to the solve_qp3_using_smo object are incompatible
+ with the constraints of the quadratic program.
+
+ this->B, this->Cp, and this->Cn will be set to the invalid values used.
+ !*/
+
+ public:
+ invalid_qp3_error( const std::string& msg, double B_, double Cp_, double Cn_) :
+ dlib::error(msg), B(B_), Cp(Cp_), Cn(Cn_) {};
+
+ const double B;
+ const double Cp;
+ const double Cn;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class solve_qp3_using_smo
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the following quadratic programming
+ problem using the sequential minimal optimization algorithm:
+
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(p)*alpha
+ subject to the following constraints:
+ - for all i such that y(i) == +1: 0 <= alpha(i) <= Cp
+ - for all i such that y(i) == -1: 0 <= alpha(i) <= Cn
+ - trans(y)*alpha == B
+
+ Where all elements of y must be equal to +1 or -1 and f is convex.
+ This means that Q should be symmetric and positive-semidefinite.
+
+
+ This object implements the strategy used by the LIBSVM tool. The following papers
+ can be consulted for additional details:
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+ - Working Set Selection Using Second Order Information for Training Support Vector Machines by
+ Fan, Chen, and Lin. In the Journal of Machine Learning Research 2005.
+ !*/
+
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3,
+ long NR
+ >
+ unsigned long operator() (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& p,
+ const matrix_exp<EXP3>& y,
+ const scalar_type B,
+ const scalar_type Cp,
+ const scalar_type Cn,
+ matrix<scalar_type,NR,1,mem_manager_type, layout_type>& alpha,
+ scalar_type eps
+ );
+ /*!
+ requires
+ - Q.nr() == Q.nc()
+ - is_col_vector(y) == true
+ - is_col_vector(p) == true
+ - p.size() == y.size() == Q.nr()
+ - y.size() > 0
+ - sum((y == +1) + (y == -1)) == y.size()
+ (i.e. all elements of y must be equal to +1 or -1)
+ - alpha must be capable of representing a vector of size y.size() elements
+ - Cp > 0
+ - Cn > 0
+ - eps > 0
+ ensures
+ - This function solves the quadratic program defined in this class's main comment.
+ - The solution to the quadratic program will be stored in #alpha.
+ - #alpha.size() == y.size()
+ - This function uses an implementation of the sequential minimal optimization
+ algorithm. It runs until the KKT violation is less than eps. So eps controls
+ how accurate the solution is and smaller values result in better solutions.
+ (a reasonable eps is usually about 1e-3)
+ - #get_gradient() == Q*(#alpha)
+ (i.e. stores the gradient of f() at #alpha in get_gradient())
+ - returns the number of iterations performed.
+ throws
+ - invalid_qp3_error
+ This exception is thrown if the given parameters cause the constraints
+ of the quadratic programming problem to be impossible to satisfy.
+ !*/
+
+ const column_matrix& get_gradient (
+ ) const;
+ /*!
+ ensures
+ - returns the gradient vector at the solution of the last problem solved
+ by this object. If no problem has been solved then returns an empty
+ vector.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h
new file mode 100644
index 000000000..b9cc74df3
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h
@@ -0,0 +1,937 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_
+#define DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_
+
+#include "optimization_solve_qp_using_smo_abstract.h"
+#include "../matrix.h"
+#include <map>
+#include "../unordered_pair.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ The algorithm defined in the solve_qp_using_smo() function below can be
+ derived by using an important theorem from the theory of constrained optimization.
+ This theorem tells us that any optimal point of a constrained function must
+ satisfy what are called the KKT conditions (also sometimes called just the KT
+ conditions, especially in older literature). A very good book to consult
+ regarding this topic is Practical Methods of Optimization (second edition) by
+ R. Fletcher. Below I will try to explain the general idea of how this is
+ applied.
+
+ Let e == ones_matrix(alpha.size(),1)
+
+ First, note that the function below solves the following quadratic program.
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b
+ subject to the following constraints:
+ - trans(e)*alpha == C (i.e. the sum of alpha values doesn't change)
+ - min(alpha) >= 0 (i.e. all alpha values are nonnegative)
+ Where f is convex. This means that Q should be positive-semidefinite.
+
+
+ To get from this problem formulation to the algorithm below we have to
+ consider the KKT conditions. They tell us that any solution to the above
+ problem must satisfy the following 5 conditions:
+ 1. trans(e)*alpha == C
+ 2. min(alpha) >= 0
+
+ 3. Let L(alpha, x, y) == f(alpha) - trans(x)*alpha - y*(trans(e)*alpha - C)
+ Where x is a vector of length alpha.size() and y is a single scalar.
+ Then the derivative of L with respect to alpha must == 0
+ So we get the following as our 3rd condition:
+ f'(alpha) - x - y*e == 0
+
+ 4. min(x) >= 0 (i.e. all x values are nonnegative)
+ 5. pointwise_multiply(x, alpha) == 0
+ (i.e. only one member of each x(i) and alpha(i) pair can be non-zero)
+
+
+ From 3 we can easily obtain this rule:
+ for all i: f'(alpha)(i) - x(i) == y
+
+ If we then consider 4 and 5 we see that we can infer that the following
+ must also be the case:
+ - if (alpha(i) > 0) then
+ - x(i) == 0
+ - f'(alpha)(i) == y
+ - else
+ - x(i) == some nonnegative number
+ - f'(alpha)(i) >= y
+
+
+ The important thing to take away is the final rule. It tells us that at the
+ optimal solution all elements of the gradient of f have the same value if
+ their corresponding alpha is non-zero. It also tells us that all the other
+ gradient values are bigger than y. We can use this information to help us
+ pick which alpha variables to optimize at each iteration.
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_using_smo (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& b,
+ matrix<T,NR,NC,MM,L>& alpha,
+ T eps,
+ unsigned long max_iter
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(Q.nr() == Q.nc() &&
+ is_col_vector(b) &&
+ is_col_vector(alpha) &&
+ b.size() == alpha.size() &&
+ b.size() == Q.nr() &&
+ alpha.size() > 0 &&
+ min(alpha) >= 0 &&
+ eps > 0 &&
+ max_iter > 0,
+ "\t unsigned long solve_qp_using_smo()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(b): " << is_col_vector(b)
+ << "\n\t is_col_vector(alpha): " << is_col_vector(alpha)
+ << "\n\t b.size(): " << b.size()
+ << "\n\t alpha.size(): " << alpha.size()
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t min(alpha): " << min(alpha)
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+
+ const T C = sum(alpha);
+
+ // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha.
+ matrix<T,NR,NC,MM,L> df = Q*alpha - b;
+
+ const T tau = 1000*std::numeric_limits<T>::epsilon();
+
+ T big, little;
+ unsigned long iter = 0;
+ for (; iter < max_iter; ++iter)
+ {
+ // Find the two elements of df that satisfy the following:
+ // - little_idx == index_of_min(df)
+ // - big_idx == the index of the largest element in df such that alpha(big_idx) > 0
+ // These two indices will tell us which two alpha values are most in violation of the KKT
+ // optimality conditions.
+ big = -std::numeric_limits<T>::max();
+ long big_idx = 0;
+ little = std::numeric_limits<T>::max();
+ long little_idx = 0;
+ for (long i = 0; i < df.nr(); ++i)
+ {
+ if (df(i) > big && alpha(i) > 0)
+ {
+ big = df(i);
+ big_idx = i;
+ }
+ if (df(i) < little)
+ {
+ little = df(i);
+ little_idx = i;
+ }
+ }
+
+ // Check if the KKT conditions are still violated and stop if so.
+ //if (alpha(little_idx) > 0 && (big - little) < eps)
+ // break;
+
+ // Check how big the duality gap is and stop when it goes below eps.
+ // The duality gap is the gap between the objective value of the function
+ // we are optimizing and the value of its primal form. This value is always
+ // greater than or equal to the distance to the optimum solution so it is a
+ // good way to decide if we should stop. See the book referenced above for
+ // more information. In particular, see the part about the Wolfe Dual.
+ if (trans(alpha)*df - C*little < eps)
+ break;
+
+
+ // Save these values, we will need them later.
+ const T old_alpha_big = alpha(big_idx);
+ const T old_alpha_little = alpha(little_idx);
+
+
+ // Now optimize the two variables we just picked.
+ T quad_coef = Q(big_idx,big_idx) + Q(little_idx,little_idx) - 2*Q(big_idx, little_idx);
+ if (quad_coef <= tau)
+ quad_coef = tau;
+ const T delta = (big - little)/quad_coef;
+ alpha(big_idx) -= delta;
+ alpha(little_idx) += delta;
+
+ // Make sure alpha stays feasible. That is, make sure the updated alpha doesn't
+ // violate the non-negativity constraint.
+ if (alpha(big_idx) < 0)
+ {
+ // Since an alpha can't be negative we will just set it to 0 and shift all the
+ // weight to the other alpha.
+ alpha(big_idx) = 0;
+ alpha(little_idx) = old_alpha_big + old_alpha_little;
+ }
+
+ // Every 300 iterations
+ if ((iter%300) == 299)
+ {
+ // Perform this form of the update every so often because doing so can help
+ // avoid the buildup of numerical errors you get with the alternate update
+ // below.
+ df = Q*alpha - b;
+ }
+ else
+ {
+ // Now update the gradient. We will perform the equivalent of: df = Q*alpha - b;
+ const T delta_alpha_big = alpha(big_idx) - old_alpha_big;
+ const T delta_alpha_little = alpha(little_idx) - old_alpha_little;
+
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) += Q(big_idx,k)*delta_alpha_big + Q(little_idx,k)*delta_alpha_little;;
+ }
+ }
+
+ /*
+ using namespace std;
+ cout << "SMO: " << endl;
+ cout << " duality gap: "<< trans(alpha)*df - C*min(df) << endl;
+ cout << " KKT gap: "<< big-little << endl;
+ cout << " iter: "<< iter+1 << endl;
+ cout << " eps: "<< eps << endl;
+ */
+
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3,
+ typename EXP4,
+ typename T, long NR, long NC, typename MM, typename L,
+ long NR2, long NC2
+ >
+ unsigned long solve_qp4_using_smo (
+ const matrix_exp<EXP1>& A,
+ const matrix_exp<EXP2>& Q,
+ const matrix_exp<EXP3>& b,
+ const matrix_exp<EXP4>& d,
+ matrix<T,NR,NC,MM,L>& alpha,
+ matrix<T,NR2,NC2,MM,L>& lambda,
+ T eps,
+ unsigned long max_iter,
+ T max_lambda = std::numeric_limits<T>::infinity()
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.nc() == alpha.size() &&
+ Q.nr() == Q.nc() &&
+ is_col_vector(b) &&
+ is_col_vector(alpha) &&
+ b.size() == alpha.size() &&
+ b.size() == Q.nr() &&
+ alpha.size() > 0 &&
+ min(alpha) >= 0 &&
+ eps > 0 &&
+ max_iter > 0,
+ "\t void solve_qp4_using_smo()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t A.nc(): " << A.nc()
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(b): " << is_col_vector(b)
+ << "\n\t is_col_vector(alpha): " << is_col_vector(alpha)
+ << "\n\t b.size(): " << b.size()
+ << "\n\t alpha.size(): " << alpha.size()
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t min(alpha): " << min(alpha)
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+ DLIB_ASSERT(is_col_vector(d) == true &&
+ max_lambda >= 0 &&
+ d.size() == A.nr(),
+ "\t void solve_qp4_using_smo()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t A.nr(): " << A.nr()
+ << "\n\t d.size(): " << d.size()
+ << "\n\t max_lambda: " << max_lambda
+ );
+
+ const T C = sum(alpha);
+
+ /*
+ For this optimization problem, it is the case that the optimal
+ value of lambda is given by a simple closed form expression if we
+ know the optimal alpha. So what we will do is to just optimize
+ alpha and every now and then we will update lambda with its optimal
+ value. Therefore, we use essentially the same method as the
+ solve_qp_using_smo() routine.
+ */
+
+ const bool d_is_zero = d==zeros_matrix(d);
+
+ // compute optimal lambda for current alpha
+ if (d_is_zero)
+ lambda = A*alpha;
+ else
+ lambda = A*alpha + d;
+ lambda = dlib::clamp(lambda, 0, max_lambda);
+
+ // Compute f'(alpha) (i.e. the gradient of f(alpha) with respect to alpha) for the current alpha.
+ matrix<T,NR,NC,MM,L> df = Q*alpha - b - trans(A)*lambda;
+
+ const T tau = 1000*std::numeric_limits<T>::epsilon();
+
+ T big, little;
+ unsigned long iter = 0;
+ for (; iter < max_iter; ++iter)
+ {
+ // Find the two elements of df that satisfy the following:
+ // - little_idx == index_of_min(df)
+ // - big_idx == the index of the largest element in df such that alpha(big_idx) > 0
+ // These two indices will tell us which two alpha values are most in violation of the KKT
+ // optimality conditions.
+ big = -std::numeric_limits<T>::max();
+ long big_idx = 0;
+ little = std::numeric_limits<T>::max();
+ long little_idx = 0;
+ for (long i = 0; i < df.nr(); ++i)
+ {
+ if (df(i) > big && alpha(i) > 0)
+ {
+ big = df(i);
+ big_idx = i;
+ }
+ if (df(i) < little)
+ {
+ little = df(i);
+ little_idx = i;
+ }
+ }
+
+ // Check how big the duality gap is and stop when it goes below eps.
+ // The duality gap is the gap between the objective value of the function
+ // we are optimizing and the value of its primal form. This value is always
+ // greater than or equal to the distance to the optimum solution so it is a
+ // good way to decide if we should stop.
+ if (trans(alpha)*df - C*little < eps)
+ {
+ // compute optimal lambda and recheck the duality gap to make
+ // sure we have really converged.
+ if (d_is_zero)
+ lambda = A*alpha;
+ else
+ lambda = A*alpha + d;
+ lambda = dlib::clamp(lambda, 0, max_lambda);
+ df = Q*alpha - b - trans(A)*lambda;
+
+ if (trans(alpha)*df - C*min(df) < eps)
+ break;
+ else
+ continue;
+ }
+
+
+ // Save these values, we will need them later.
+ const T old_alpha_big = alpha(big_idx);
+ const T old_alpha_little = alpha(little_idx);
+
+
+ // Now optimize the two variables we just picked.
+ T quad_coef = Q(big_idx,big_idx) + Q(little_idx,little_idx) - 2*Q(big_idx, little_idx);
+ if (quad_coef <= tau)
+ quad_coef = tau;
+ const T delta = (big - little)/quad_coef;
+ alpha(big_idx) -= delta;
+ alpha(little_idx) += delta;
+
+ // Make sure alpha stays feasible. That is, make sure the updated alpha doesn't
+ // violate the non-negativity constraint.
+ if (alpha(big_idx) < 0)
+ {
+ // Since an alpha can't be negative we will just set it to 0 and shift all the
+ // weight to the other alpha.
+ alpha(big_idx) = 0;
+ alpha(little_idx) = old_alpha_big + old_alpha_little;
+ }
+
+
+ // Every 300 iterations
+ if ((iter%300) == 299)
+ {
+ // compute the optimal lambda for the current alpha
+ if (d_is_zero)
+ lambda = A*alpha;
+ else
+ lambda = A*alpha + d;
+ lambda = dlib::clamp(lambda, 0, max_lambda);
+
+ // Perform this form of the update every so often because doing so can help
+ // avoid the buildup of numerical errors you get with the alternate update
+ // below.
+ df = Q*alpha - b - trans(A)*lambda;
+ }
+ else
+ {
+ // Now update the gradient. We will perform the equivalent of: df = Q*alpha - b;
+ const T delta_alpha_big = alpha(big_idx) - old_alpha_big;
+ const T delta_alpha_little = alpha(little_idx) - old_alpha_little;
+
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) += Q(big_idx,k)*delta_alpha_big + Q(little_idx,k)*delta_alpha_little;;
+ }
+ }
+
+ /*
+ using namespace std;
+ cout << "SMO: " << endl;
+ cout << " duality gap: "<< trans(alpha)*df - C*min(df) << endl;
+ cout << " KKT gap: "<< big-little << endl;
+ cout << " iter: "<< iter+1 << endl;
+ cout << " eps: "<< eps << endl;
+ */
+
+
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_box_constrained (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& b,
+ matrix<T,NR,NC,MM,L>& alpha,
+ const matrix<T,NR,NC,MM,L>& lower,
+ const matrix<T,NR,NC,MM,L>& upper,
+ T eps,
+ unsigned long max_iter
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(Q.nr() == Q.nc() &&
+ alpha.size() == lower.size() &&
+ alpha.size() == upper.size() &&
+ is_col_vector(b) &&
+ is_col_vector(alpha) &&
+ is_col_vector(lower) &&
+ is_col_vector(upper) &&
+ b.size() == alpha.size() &&
+ b.size() == Q.nr() &&
+ alpha.size() > 0 &&
+ 0 <= min(alpha-lower) &&
+ 0 <= max(upper-alpha) &&
+ eps > 0 &&
+ max_iter > 0,
+ "\t unsigned long solve_qp_box_constrained()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(b): " << is_col_vector(b)
+ << "\n\t is_col_vector(alpha): " << is_col_vector(alpha)
+ << "\n\t is_col_vector(lower): " << is_col_vector(lower)
+ << "\n\t is_col_vector(upper): " << is_col_vector(upper)
+ << "\n\t b.size(): " << b.size()
+ << "\n\t alpha.size(): " << alpha.size()
+ << "\n\t lower.size(): " << lower.size()
+ << "\n\t upper.size(): " << upper.size()
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t min(alpha-lower): " << min(alpha-lower)
+ << "\n\t max(upper-alpha): " << max(upper-alpha)
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+
+
+ // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha.
+ matrix<T,NR,NC,MM,L> df = Q*alpha + b;
+ matrix<T,NR,NC,MM,L> QQ = reciprocal_max(diag(Q));
+
+ // First we use a coordinate descent method to initialize alpha.
+ double max_df = 0;
+ for (long iter = 0; iter < alpha.size()*2; ++iter)
+ {
+ max_df = 0;
+ long best_r =0;
+ // find the best alpha to optimize.
+ for (long r = 0; r < Q.nr(); ++r)
+ {
+ if (alpha(r) <= lower(r) && df(r) > 0)
+ ;//alpha(r) = lower(r);
+ else if (alpha(r) >= upper(r) && df(r) < 0)
+ ;//alpha(r) = upper(r);
+ else if (std::abs(df(r)) > max_df)
+ {
+ best_r = r;
+ max_df = std::abs(df(r));
+ }
+ }
+
+ // now optimize alpha(best_r)
+ const long r = best_r;
+ const T old_alpha = alpha(r);
+ alpha(r) = -(df(r)-Q(r,r)*alpha(r))*QQ(r);
+ if (alpha(r) < lower(r))
+ alpha(r) = lower(r);
+ else if (alpha(r) > upper(r))
+ alpha(r) = upper(r);
+
+ const T delta = old_alpha-alpha(r);
+
+ // Now update the gradient. We will perform the equivalent of: df = Q*alpha + b;
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) -= Q(r,k)*delta;
+ }
+ //cout << "max_df: " << max_df << endl;
+ //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl;
+
+
+
+ // Now do the main iteration block of this solver. The coordinate descent method
+ // we used above can improve the objective rapidly in the beginning. However,
+ // Nesterov's method has more rapid convergence once it gets going so this is what
+ // we use for the main iteration.
+ matrix<T,NR,NC,MM,L> v, v_old;
+ v = alpha;
+ // We need to get an upper bound on the Lipschitz constant for this QP. Since that
+ // is just the max eigenvalue of Q we can do it using Gershgorin disks.
+ const T lipschitz_bound = max(diag(Q) + (sum_cols(abs(Q)) - abs(diag(Q))));
+ double lambda = 0;
+ unsigned long iter;
+ for (iter = 0; iter < max_iter; ++iter)
+ {
+ const double next_lambda = (1 + std::sqrt(1+4*lambda*lambda))/2;
+ const double gamma = (1-lambda)/next_lambda;
+ lambda = next_lambda;
+
+ v_old = v;
+
+ df = Q*alpha + b;
+ // now take a projected gradient step using Nesterov's method.
+ v = clamp(alpha - 1.0/lipschitz_bound * df, lower, upper);
+ alpha = dlib::clamp((1-gamma)*v + gamma*v_old, lower, upper);
+
+
+ // check for convergence every 10 iterations
+ if (iter%10 == 0)
+ {
+ max_df = 0;
+ for (long r = 0; r < Q.nr(); ++r)
+ {
+ if (alpha(r) <= lower(r) && df(r) > 0)
+ ;//alpha(r) = lower(r);
+ else if (alpha(r) >= upper(r) && df(r) < 0)
+ ;//alpha(r) = upper(r);
+ else if (std::abs(df(r)) > max_df)
+ max_df = std::abs(df(r));
+ }
+ if (max_df < eps)
+ break;
+ }
+ }
+
+ //cout << "max_df: " << max_df << endl;
+ //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl;
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ // Check if each vector in Q_offdiag is actually a constant times the 1s vector.
+ template <
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ bool has_uniform_offdiag_vectors(
+ const std::map<unordered_pair<size_t>, matrix<T,NR,NC,MM,L>>& Q_offdiag
+ )
+ {
+ for (auto& x : Q_offdiag)
+ {
+ auto ref = x.second(0);
+ for (auto& y : x.second)
+ if (ref != y)
+ return false;
+ }
+ return true;
+ }
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ matrix<T,0,0,MM,L> compact_offdiag(
+ const size_t& num_blocks,
+ const std::map<unordered_pair<size_t>, matrix<T,NR,NC,MM,L>>& Q_offdiag
+ )
+ {
+ matrix<T,0,0,MM,L> temp;
+ // we can only compact the offdiag information if they are uniform vectors
+ if (!has_uniform_offdiag_vectors(Q_offdiag))
+ return temp;
+
+ temp.set_size(num_blocks, num_blocks);
+ temp = 0;
+
+ for (auto& x : Q_offdiag)
+ {
+ long r = x.first.first;
+ long c = x.first.second;
+ temp(r,c) = x.second(0);
+ temp(c,r) = x.second(0);
+ }
+
+ return temp;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_box_constrained_blockdiag (
+ const std::vector<matrix<T,NR,NR,MM,L>>& Q_blocks,
+ const std::vector<matrix<T,NR,NC,MM,L>>& bs,
+ const std::map<unordered_pair<size_t>, matrix<T,NR,NC,MM,L>>& Q_offdiag,
+ std::vector<matrix<T,NR,NC,MM,L>>& alphas,
+ const std::vector<matrix<T,NR,NC,MM,L>>& lowers,
+ const std::vector<matrix<T,NR,NC,MM,L>>& uppers,
+ T eps,
+ unsigned long max_iter
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(Q_blocks.size() > 0);
+ DLIB_CASSERT(Q_blocks.size() == bs.size() &&
+ Q_blocks.size() == alphas.size() &&
+ Q_blocks.size() == lowers.size() &&
+ Q_blocks.size() == uppers.size(),
+ "Q_blocks.size(): "<< Q_blocks.size() << "\n" <<
+ "bs.size(): "<< bs.size() << "\n" <<
+ "alphas.size(): "<< alphas.size() << "\n" <<
+ "lowers.size(): "<< lowers.size() << "\n" <<
+ "uppers.size(): "<< uppers.size() << "\n"
+ );
+ for (auto& Q : Q_blocks)
+ {
+ DLIB_CASSERT(Q.nr() == Q.nc(), "All the matrices in Q_blocks have the same dimensions.");
+ DLIB_CASSERT(Q.size() > 0, "All the matrices in Q_blocks must be non-empty and have the same dimensions.");
+ DLIB_CASSERT(Q.nr() == Q_blocks[0].nr() && Q.nc() == Q_blocks[0].nc(), "All the matrices in Q_blocks have the same dimensions.");
+ }
+#ifdef ENABLE_ASSERTS
+ for (size_t i = 0; i < alphas.size(); ++i)
+ {
+ DLIB_CASSERT(is_col_vector(bs[i]) && bs[i].size() == Q_blocks[0].nr(),
+ "is_col_vector(bs["<<i<<"]): " << is_col_vector(bs[i]) << "\n" <<
+ "bs["<<i<<"].size(): " << bs[i].size() << "\n" <<
+ "Q_blocks[0].nr(): " << Q_blocks[0].nr());
+
+ for (auto& Qoffdiag : Q_offdiag)
+ {
+ auto& Q_offdiag_element = Qoffdiag.second;
+ long r = Qoffdiag.first.first;
+ long c = Qoffdiag.first.second;
+ DLIB_CASSERT(is_col_vector(Q_offdiag_element) && Q_offdiag_element.size() == Q_blocks[0].nr(),
+ "is_col_vector(Q_offdiag["<<r<<","<<c<<"]): " << is_col_vector(Q_offdiag_element) << "\n" <<
+ "Q_offdiag["<<r<<","<<c<<"].size(): " << Q_offdiag_element.size() << "\n" <<
+ "Q_blocks[0].nr(): " << Q_blocks[0].nr());
+ }
+
+ DLIB_CASSERT(is_col_vector(alphas[i]) && alphas[i].size() == Q_blocks[0].nr(),
+ "is_col_vector(alphas["<<i<<"]): " << is_col_vector(alphas[i]) << "\n" <<
+ "alphas["<<i<<"].size(): " << alphas[i].size() << "\n" <<
+ "Q_blocks[0].nr(): " << Q_blocks[0].nr());
+
+ DLIB_CASSERT(is_col_vector(lowers[i]) && lowers[i].size() == Q_blocks[0].nr(),
+ "is_col_vector(lowers["<<i<<"]): " << is_col_vector(lowers[i]) << "\n" <<
+ "lowers["<<i<<"].size(): " << lowers[i].size() << "\n" <<
+ "Q_blocks[0].nr(): " << Q_blocks[0].nr());
+
+ DLIB_CASSERT(is_col_vector(uppers[i]) && uppers[i].size() == Q_blocks[0].nr(),
+ "is_col_vector(uppers["<<i<<"]): " << is_col_vector(uppers[i]) << "\n" <<
+ "uppers["<<i<<"].size(): " << uppers[i].size() << "\n" <<
+ "Q_blocks[0].nr(): " << Q_blocks[0].nr());
+
+ DLIB_CASSERT(0 <= min(alphas[i]-lowers[i]), "min(alphas["<<i<<"]-lowers["<<i<<"]): " << min(alphas[i]-lowers[i]));
+ DLIB_CASSERT(0 <= max(uppers[i]-alphas[i]), "max(uppers["<<i<<"]-alphas["<<i<<"]): " << max(uppers[i]-alphas[i]));
+ }
+ DLIB_CASSERT(eps > 0 && max_iter > 0, "eps: " << eps << "\nmax_iter: "<< max_iter);
+#endif // ENABLE_ASSERTS
+
+
+ const auto offdiag_compact = impl::compact_offdiag(Q_blocks.size(), Q_offdiag);
+ matrix<T,0,0,MM,L> temp, alphas_compact;
+
+ // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha.
+ std::vector<matrix<T,NR,NC,MM,L>> df;// = Q*alpha + b;
+ auto compute_df = [&]()
+ {
+ df.resize(Q_blocks.size());
+ for (size_t i = 0; i < df.size(); ++i)
+ df[i] = Q_blocks[i]*alphas[i] + bs[i];
+
+
+ // Don't forget to include the Q_offdiag terms in the computation. Note that
+ // we have two options for how we can compute this part. If Q_offdiag is
+ // uniform and can be compacted into a simple matrix and there are a lot of off
+ // diagonal entries then it's faster to do it as a matrix multiply. Otherwise
+ // we do the more general computation.
+ if (offdiag_compact.size() != 0 && Q_offdiag.size() > Q_blocks.size()*5)
+ {
+ // Do it as a matrix multiply (with a bit of data shuffling)
+ alphas_compact.set_size(alphas[0].size(), offdiag_compact.nr());
+ for (long c = 0; c < alphas_compact.nc(); ++c)
+ set_colm(alphas_compact,c) = alphas[c];
+ temp = alphas_compact*offdiag_compact;
+ for (size_t i = 0; i < df.size(); ++i)
+ df[i] += colm(temp,i);
+ }
+ else
+ {
+ // Do the fully general computation that allows for non-uniform values in
+ // the off diagonal vectors.
+ for (auto& p : Q_offdiag)
+ {
+ long r = p.first.first;
+ long c = p.first.second;
+ df[r] += pointwise_multiply(p.second, alphas[c]);
+ if (r != c)
+ df[c] += pointwise_multiply(p.second, alphas[r]);
+ }
+ }
+ };
+ compute_df();
+
+
+
+ std::vector<matrix<T,NR,NC,MM,L>> Q_diag, Q_ggd;
+ std::vector<matrix<T,NR,NC,MM,L>> QQ;// = reciprocal_max(diag(Q));
+ QQ.resize(Q_blocks.size());
+ Q_diag.resize(Q_blocks.size());
+ Q_ggd.resize(Q_blocks.size());
+
+ // We need to get an upper bound on the Lipschitz constant for this QP. Since that
+ // is just the max eigenvalue of Q we can do it using Gershgorin disks.
+ //const T lipschitz_bound = max(diag(Q) + (sum_cols(abs(Q)) - abs(diag(Q))));
+ for (size_t i = 0; i < QQ.size(); ++i)
+ {
+ auto f = Q_offdiag.find(make_unordered_pair(i,i));
+ if (f != Q_offdiag.end())
+ Q_diag[i] = diag(Q_blocks[i]) + f->second;
+ else
+ Q_diag[i] = diag(Q_blocks[i]);
+ QQ[i] = reciprocal_max(Q_diag[i]);
+
+ Q_ggd[i] = Q_diag[i] + (sum_cols(abs(Q_blocks[i]))-abs(diag(Q_blocks[i])));
+ }
+ for (auto& p : Q_offdiag)
+ {
+ long r = p.first.first;
+ long c = p.first.second;
+ if (r != c)
+ {
+ Q_ggd[r] += abs(p.second);
+ Q_ggd[c] += abs(p.second);
+ }
+ }
+ T lipschitz_bound = -std::numeric_limits<T>::infinity();
+ for (auto& x : Q_ggd)
+ lipschitz_bound = std::max(lipschitz_bound, max(x));
+
+
+ const long num_variables = alphas.size()*alphas[0].size();
+
+ // First we use a coordinate descent method to initialize alpha.
+ double max_df = 0;
+ for (long iter = 0; iter < num_variables*2; ++iter)
+ {
+ max_df = 0;
+ long best_r =0;
+ size_t best_r2 =0;
+ // find the best alpha to optimize.
+ for (size_t r2 = 0; r2 < alphas.size(); ++r2)
+ {
+ auto& alpha = alphas[r2];
+ auto& df_ = df[r2];
+ auto& lower = lowers[r2];
+ auto& upper = uppers[r2];
+ for (long r = 0; r < alpha.nr(); ++r)
+ {
+ if (alpha(r) <= lower(r) && df_(r) > 0)
+ ;//alpha(r) = lower(r);
+ else if (alpha(r) >= upper(r) && df_(r) < 0)
+ ;//alpha(r) = upper(r);
+ else if (std::abs(df_(r)) > max_df)
+ {
+ best_r = r;
+ best_r2 = r2;
+ max_df = std::abs(df_(r));
+ }
+ }
+ }
+
+ // now optimize alphas[best_r2](best_r)
+ const long r = best_r;
+ auto& alpha = alphas[best_r2];
+ auto& lower = lowers[best_r2];
+ auto& upper = uppers[best_r2];
+ auto& df_ = df[best_r2];
+ const T old_alpha = alpha(r);
+ alpha(r) = -(df_(r)-Q_diag[best_r2](r)*alpha(r))*QQ[best_r2](r);
+ if (alpha(r) < lower(r))
+ alpha(r) = lower(r);
+ else if (alpha(r) > upper(r))
+ alpha(r) = upper(r);
+
+ const T delta = old_alpha-alpha(r);
+
+ // Now update the gradient. We will perform the equivalent of: df = Q*alpha +
+ // b; except we only need to compute one column of the matrix multiply because
+ // only one element of alpha changed.
+ auto& Q = Q_blocks[best_r2];
+ for(long k = 0; k < df_.nr(); ++k)
+ df_(k) -= Q(r,k)*delta;
+ for(size_t j = 0; j < Q_blocks.size(); ++j)
+ {
+ auto f = Q_offdiag.find(make_unordered_pair(best_r2, j));
+ if (f != Q_offdiag.end())
+ df[j](r) -= f->second(r)*delta;
+ }
+ }
+
+
+
+
+ std::vector<matrix<T,NR,NC,MM,L>> v(alphas), v_old(alphas.size());
+ double lambda = 0;
+ unsigned long iter;
+ // Now do the main iteration block of this solver. The coordinate descent method
+ // we used above can improve the objective rapidly in the beginning. However,
+ // Nesterov's method has more rapid convergence once it gets going so this is what
+ // we use for the main iteration.
+ for (iter = 0; iter < max_iter; ++iter)
+ {
+ const double next_lambda = (1 + std::sqrt(1+4*lambda*lambda))/2;
+ const double gamma = (1-lambda)/next_lambda;
+ lambda = next_lambda;
+
+ v_old.swap(v);
+
+ //df = Q*alpha + b;
+ compute_df();
+
+ // now take a projected gradient step using Nesterov's method.
+ for (size_t j = 0; j < alphas.size(); ++j)
+ {
+ v[j] = clamp(alphas[j] - 1.0/lipschitz_bound * df[j], lowers[j], uppers[j]);
+ alphas[j] = clamp((1-gamma)*v[j] + gamma*v_old[j], lowers[j], uppers[j]);
+ }
+
+
+ // check for convergence every 10 iterations
+ if (iter%10 == 0)
+ {
+ max_df = 0;
+ for (size_t r2 = 0; r2 < alphas.size(); ++r2)
+ {
+ auto& alpha = alphas[r2];
+ auto& df_ = df[r2];
+ auto& lower = lowers[r2];
+ auto& upper = uppers[r2];
+ for (long r = 0; r < alpha.nr(); ++r)
+ {
+ if (alpha(r) <= lower(r) && df_(r) > 0)
+ ;//alpha(r) = lower(r);
+ else if (alpha(r) >= upper(r) && df_(r) < 0)
+ ;//alpha(r) = upper(r);
+ else if (std::abs(df_(r)) > max_df)
+ max_df = std::abs(df_(r));
+ }
+ }
+ if (max_df < eps)
+ break;
+ }
+ }
+
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NRa, long NRb
+ >
+ unsigned long find_gap_between_convex_hulls (
+ const matrix_exp<EXP1>& A,
+ const matrix_exp<EXP2>& B,
+ matrix<T,NRa,1>& cA,
+ matrix<T,NRb,1>& cB,
+ const double eps,
+ const unsigned long max_iter = 1000
+ )
+ {
+ DLIB_CASSERT(A.size() != 0);
+ DLIB_CASSERT(B.size() != 0);
+ DLIB_CASSERT(A.nr() == B.nr(), "The dimensionality of the points in both convex hull sets must match");
+ DLIB_CASSERT(eps > 0);
+ DLIB_CASSERT(max_iter > 0);
+
+ cA.set_size(A.nc());
+ cB.set_size(B.nc());
+
+ // initialize to the centroids of A and B respectively.
+ cA = 1.0/cA.size();
+ cB = 1.0/cB.size();
+
+
+ matrix<T> AA, BB, AB, ABb, ABa;
+
+ AA = trans(A)*A;
+ BB = trans(B)*B;
+ AB = trans(A)*B;
+
+ unsigned long iter = 0;
+ for (iter = 0; iter < max_iter; ++iter)
+ {
+ // find the convex combination of A that is nearest to B*cB
+ ABb = AB*cB;
+ const auto smo_iter1 = solve_qp_using_smo(AA, ABb, cA, eps, cA.size());
+
+ // now find the convex combination of B that is nearest to A*cA
+ ABa = trans(AB)*cA;
+ const auto smo_iter2 = solve_qp_using_smo(BB, ABa, cB, eps, cB.size());
+
+ // stop if the QP solvers failed to improve
+ if (smo_iter1 == 1 && smo_iter2 == 1)
+ break;
+ }
+
+
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_
+
diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h
new file mode 100644
index 000000000..5e7d5ec3f
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h
@@ -0,0 +1,282 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_
+#ifdef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include <map>
+#include "../unordered_pair.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_using_smo (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& b,
+ matrix<T,NR,NC,MM,L>& alpha,
+ T eps,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - Q.nr() == Q.nc()
+ - is_col_vector(b) == true
+ - is_col_vector(alpha) == true
+ - b.size() == alpha.size() == Q.nr()
+ - alpha.size() > 0
+ - min(alpha) >= 0
+ - eps > 0
+ - max_iter > 0
+ ensures
+ - Let C == sum(alpha) (i.e. C is the sum of the alpha values you
+ supply to this function)
+ - This function solves the following quadratic program:
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b
+ subject to the following constraints:
+ - sum(alpha) == C (i.e. the sum of alpha values doesn't change)
+ - min(alpha) >= 0 (i.e. all alpha values are nonnegative)
+ Where f is convex. This means that Q should be positive-semidefinite.
+ - The solution to the above QP will be stored in #alpha.
+ - This function uses a simple implementation of the sequential minimal
+ optimization algorithm. It starts the algorithm with the given alpha
+ and it works on the problem until the duality gap (i.e. how far away
+ we are from the optimum solution) is less than eps. So eps controls
+ how accurate the solution is and smaller values result in better solutions.
+ - At most max_iter iterations of optimization will be performed.
+ - returns the number of iterations performed. If this method fails to
+ converge to eps accuracy then the number returned will be max_iter+1.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3,
+ typename T, long NR, long NC, typename MM, typename L,
+ long NR2, long NC2
+ >
+ unsigned long solve_qp4_using_smo (
+ const matrix_exp<EXP1>& A,
+ const matrix_exp<EXP2>& Q,
+ const matrix_exp<EXP3>& b,
+ const matrix_exp<EXP4>& d,
+ matrix<T,NR,NC,MM,L>& alpha,
+ matrix<T,NR2,NC2,MM,L>& lambda,
+ T eps,
+ unsigned long max_iter,
+ T max_lambda = std::numeric_limits<T>::infinity()
+ );
+ /*!
+ requires
+ - A.nc() == alpha.size()
+ - Q.nr() == Q.nc()
+ - is_col_vector(b) == true
+ - is_col_vector(d) == true
+ - is_col_vector(alpha) == true
+ - b.size() == alpha.size() == Q.nr()
+ - d.size() == A.nr()
+ - alpha.size() > 0
+ - min(alpha) >= 0
+ - eps > 0
+ - max_iter > 0
+ - max_lambda >= 0
+ ensures
+ - Let C == sum(alpha) (i.e. C is the sum of the alpha values you
+ supply to this function)
+ - This function solves the following quadratic program:
+ Minimize: f(alpha,lambda) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b +
+ 0.5*trans(lambda)*lambda - trans(lambda)*A*alpha - trans(lambda)*d
+ subject to the following constraints:
+ - sum(alpha) == C (i.e. the sum of alpha values doesn't change)
+ - min(alpha) >= 0 (i.e. all alpha values are nonnegative)
+ - min(lambda) >= 0 (i.e. all lambda values are nonnegative)
+ - max(lambda) <= max_lambda (i.e. all lambda values are less than max_lambda)
+ Where f is convex. This means that Q should be positive-semidefinite.
+ - If you don't want an upper limit on lambda then max_lambda can be set to
+ infinity.
+ - The solution to the above QP will be stored in #alpha and #lambda.
+ - This function uses a simple implementation of the sequential minimal
+ optimization algorithm. It starts the algorithm with the given alpha
+ and it works on the problem until the duality gap (i.e. how far away
+ we are from the optimum solution) is less than eps. So eps controls
+ how accurate the solution is and smaller values result in better solutions.
+ The initial value of lambda is ignored since the optimal lambda can be
+ obtained via a simple closed form expression given alpha.
+ - At most max_iter iterations of optimization will be performed.
+ - returns the number of iterations performed. If this method fails to
+ converge to eps accuracy then the number returned will be max_iter+1.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_box_constrained (
+ const matrix_exp<EXP1>& Q,
+ const matrix_exp<EXP2>& b,
+ matrix<T,NR,NC,MM,L>& alpha,
+ const matrix<T,NR,NC,MM,L>& lower,
+ const matrix<T,NR,NC,MM,L>& upper,
+ T eps,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - Q.nr() == Q.nc()
+ - alpha.size() == lower.size() == upper.size()
+ - is_col_vector(b) == true
+ - is_col_vector(alpha) == true
+ - is_col_vector(lower) == true
+ - is_col_vector(upper) == true
+ - b.size() == alpha.size() == Q.nr()
+ - alpha.size() > 0
+ - 0 <= min(alpha-lower)
+ - 0 <= max(upper-alpha)
+ - eps > 0
+ - max_iter > 0
+ ensures
+ - This function solves the following quadratic program:
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha
+ subject to the following box constraints on alpha:
+ - 0 <= min(alpha-lower)
+ - 0 <= max(upper-alpha)
+ Where f is convex. This means that Q should be positive-semidefinite.
+ - The solution to the above QP will be stored in #alpha.
+ - This function uses a combination of a SMO algorithm along with Nesterov's
+ method as the main iteration of the solver. It starts the algorithm with the
+ given alpha and it works on the problem until the derivative of f(alpha) is
+ smaller than eps for each element of alpha or the alpha value is at a box
+ constraint. So eps controls how accurate the solution is and smaller values
+ result in better solutions.
+ - At most max_iter iterations of optimization will be performed.
+ - returns the number of iterations performed. If this method fails to
+ converge to eps accuracy then the number returned will be max_iter+1.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_box_constrained_blockdiag (
+ const std::vector<matrix<T,NR,NR,MM,L>>& Q_blocks,
+ const std::vector<matrix<T,NR,NC,MM,L>>& bs,
+ const std::map<unordered_pair<size_t>, matrix<T,NR,NC,MM,L>>& Q_offdiag,
+ std::vector<matrix<T,NR,NC,MM,L>>& alphas,
+ const std::vector<matrix<T,NR,NC,MM,L>>& lowers,
+ const std::vector<matrix<T,NR,NC,MM,L>>& uppers,
+ T eps,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - Q_blocks.size() > 0
+ - Q_blocks.size() == bs.size() == alphas.size() == lowers.size() == uppers.size()
+ - All the matrices in Q_blocks have the same dimensions. Moreover, they are
+ non-empty square matrices.
+ - All the matrices in bs, Q_offdiag, alphas, lowers, and uppers have the same
+ dimensions. Moreover, they are all column vectors.
+ - Q_blocks[0].nr() == alphas[0].size()
+ (i.e. the dimensionality of the column vectors in alphas must match the
+ dimensionality of the square matrices in Q_blocks.)
+ - for all valid i:
+ - 0 <= min(alphas[i]-lowers[i])
+ - 0 <= max(uppers[i]-alphas[i])
+ - eps > 0
+ - max_iter > 0
+ ensures
+ - This function solves the same QP as solve_qp_box_constrained(), except it is
+ optimized for the case where the Q matrix has a certain sparsity structure.
+ To be precise:
+ - Let Q1 be a block diagonal matrix with the elements of Q_blocks placed
+ along its diagonal, and in the order contained in Q_blocks.
+ - Let Q2 be a matrix with the same size as Q1, except instead of being block diagonal, it
+ is block structured into Q_blocks.nr() by Q_blocks.nc() blocks. If we let (r,c) be the
+ coordinate of each block then each block contains the matrix
+ diagm(Q_offdiag[make_unordered_pair(r,c)]) or the zero matrix if Q_offdiag has no entry
+ for the coordinate (r,c).
+ - Let Q == Q1+Q2
+ - Let b == the concatenation of all the vectors in bs into one big vector.
+ - Let alpha == the concatenation of all the vectors in alphas into one big vector.
+ - Let lower == the concatenation of all the vectors in lowers into one big vector.
+ - Let upper == the concatenation of all the vectors in uppers into one big vector.
+ - Then this function solves the following quadratic program:
+ Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha
+ subject to the following box constraints on alpha:
+ - 0 <= min(alpha-lower)
+ - 0 <= max(upper-alpha)
+ Where f is convex. This means that Q should be positive-semidefinite.
+ - More specifically, this function is identical to
+ solve_qp_box_constrained(Q, b, alpha, lower, upper, eps, max_iter),
+ except that it runs faster since it avoids unnecessary computation by
+ taking advantage of the sparsity structure in the QP.
+ - The solution to the above QP will be stored in #alphas.
+ - This function uses a combination of a SMO algorithm along with Nesterov's
+ method as the main iteration of the solver. It starts the algorithm with the
+ given alpha and it works on the problem until the derivative of f(alpha) is
+ smaller than eps for each element of alpha or the alpha value is at a box
+ constraint. So eps controls how accurate the solution is and smaller values
+ result in better solutions.
+ - At most max_iter iterations of optimization will be performed.
+ - returns the number of iterations performed. If this method fails to
+ converge to eps accuracy then the number returned will be max_iter+1.
+ !*/
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NRa, long NRb
+ >
+ unsigned long find_gap_between_convex_hulls (
+ const matrix_exp<EXP1>& A,
+ const matrix_exp<EXP2>& B,
+ matrix<T,NRa,1>& cA,
+ matrix<T,NRb,1>& cB,
+ const double eps,
+ const unsigned long max_iter = 1000
+ );
+ /*!
+ requires
+ - A.nr() == B.nr()
+ - A.size() != 0
+ - B.size() != 0
+ - eps > 0
+ - max_iter > 0
+ ensures
+ - If you think of A and B as sets of column vectors, then we can identify the
+ convex sets hullA and hullB, which are the convex hulls of A and B
+ respectively. This function finds the pair of points in hullA and hullB that
+ are nearest to each other. To be precise, this function solves the following
+ quadratic program:
+ Minimize: f(cA,cB) == length_squared(A*cA - B*cB)
+ subject to the following constraints on cA and cB:
+ - is_col_vector(cA) == true && cA.size() == A.nc()
+ - is_col_vector(cB) == true && cB.size() == B.nc()
+ - sum(cA) == 1 && min(cA) >= 0
+ - sum(cB) == 1 && min(cB) >= 0
+ - This function uses an iterative block coordinate descent algorithm to solve
+ the QP. It runs until either max_iter iterations have been performed or the
+ QP is solved to at least eps accuracy.
+ - returns the number of iterations performed. If this method fails to
+ converge to eps accuracy then the number returned will be max_iter+1.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/optimization/optimization_stop_strategies.h b/ml/dlib/dlib/optimization/optimization_stop_strategies.h
new file mode 100644
index 000000000..a0243cacf
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_stop_strategies.h
@@ -0,0 +1,173 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_
+#define DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_
+
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "optimization_stop_strategies_abstract.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class objective_delta_stop_strategy
+ {
+ public:
+ explicit objective_delta_stop_strategy (
+ double min_delta = 1e-7
+ ) : _verbose(false), _been_used(false), _min_delta(min_delta), _max_iter(0), _cur_iter(0), _prev_funct_value(0)
+ {
+ DLIB_ASSERT (
+ min_delta >= 0,
+ "\t objective_delta_stop_strategy(min_delta)"
+ << "\n\t min_delta can't be negative"
+ << "\n\t min_delta: " << min_delta
+ );
+ }
+
+ objective_delta_stop_strategy (
+ double min_delta,
+ unsigned long max_iter
+ ) : _verbose(false), _been_used(false), _min_delta(min_delta), _max_iter(max_iter), _cur_iter(0), _prev_funct_value(0)
+ {
+ DLIB_ASSERT (
+ min_delta >= 0 && max_iter > 0,
+ "\t objective_delta_stop_strategy(min_delta, max_iter)"
+ << "\n\t min_delta can't be negative and max_iter can't be 0"
+ << "\n\t min_delta: " << min_delta
+ << "\n\t max_iter: " << max_iter
+ );
+ }
+
+ objective_delta_stop_strategy& be_verbose(
+ )
+ {
+ _verbose = true;
+ return *this;
+ }
+
+ template <typename T>
+ bool should_continue_search (
+ const T& ,
+ const double funct_value,
+ const T&
+ )
+ {
+ if (_verbose)
+ {
+ using namespace std;
+ cout << "iteration: " << _cur_iter << " objective: " << funct_value << endl;
+ }
+
+ ++_cur_iter;
+ if (_been_used)
+ {
+ // Check if we have hit the max allowable number of iterations. (but only
+ // check if _max_iter is enabled (i.e. not 0)).
+ if (_max_iter != 0 && _cur_iter > _max_iter)
+ return false;
+
+ // check if the function change was too small
+ if (std::abs(funct_value - _prev_funct_value) < _min_delta)
+ return false;
+ }
+
+ _been_used = true;
+ _prev_funct_value = funct_value;
+ return true;
+ }
+
+ private:
+ bool _verbose;
+
+ bool _been_used;
+ double _min_delta;
+ unsigned long _max_iter;
+ unsigned long _cur_iter;
+ double _prev_funct_value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class gradient_norm_stop_strategy
+ {
+ public:
+ explicit gradient_norm_stop_strategy (
+ double min_norm = 1e-7
+ ) : _verbose(false), _min_norm(min_norm), _max_iter(0), _cur_iter(0)
+ {
+ DLIB_ASSERT (
+ min_norm >= 0,
+ "\t gradient_norm_stop_strategy(min_norm)"
+ << "\n\t min_norm can't be negative"
+ << "\n\t min_norm: " << min_norm
+ );
+ }
+
+ gradient_norm_stop_strategy (
+ double min_norm,
+ unsigned long max_iter
+ ) : _verbose(false), _min_norm(min_norm), _max_iter(max_iter), _cur_iter(0)
+ {
+ DLIB_ASSERT (
+ min_norm >= 0 && max_iter > 0,
+ "\t gradient_norm_stop_strategy(min_norm, max_iter)"
+ << "\n\t min_norm can't be negative and max_iter can't be 0"
+ << "\n\t min_norm: " << min_norm
+ << "\n\t max_iter: " << max_iter
+ );
+ }
+
+ gradient_norm_stop_strategy& be_verbose(
+ )
+ {
+ _verbose = true;
+ return *this;
+ }
+
+ template <typename T>
+ bool should_continue_search (
+ const T& ,
+ const double funct_value,
+ const T& funct_derivative
+ )
+ {
+ if (_verbose)
+ {
+ using namespace std;
+ cout << "iteration: " << _cur_iter << " objective: " << funct_value << " gradient norm: " << length(funct_derivative) << endl;
+ }
+
+ ++_cur_iter;
+
+ // Check if we have hit the max allowable number of iterations. (but only
+ // check if _max_iter is enabled (i.e. not 0)).
+ if (_max_iter != 0 && _cur_iter > _max_iter)
+ return false;
+
+ // check if the gradient norm is too small
+ if (length(funct_derivative) < _min_norm)
+ return false;
+
+ return true;
+ }
+
+ private:
+ bool _verbose;
+
+ double _min_norm;
+ unsigned long _max_iter;
+ unsigned long _cur_iter;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_
+
diff --git a/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h b/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h
new file mode 100644
index 000000000..6a999f8d9
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h
@@ -0,0 +1,157 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_
+#ifdef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class objective_delta_stop_strategy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for deciding if an optimization
+ algorithm should terminate. This particular object looks at the
+ change in the objective function from one iteration to the next and
+ bases its decision on how large this change is. If the change
+ is below a user given threshold then the search stops.
+ !*/
+
+ public:
+ explicit objective_delta_stop_strategy (
+ double min_delta = 1e-7
+ );
+ /*!
+ requires
+ - min_delta >= 0
+ ensures
+ - This stop strategy object will only consider a search to be complete
+ if a change in an objective function from one iteration to the next
+ is less than min_delta.
+ !*/
+
+ objective_delta_stop_strategy (
+ double min_delta,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - min_delta >= 0
+ - max_iter > 0
+ ensures
+ - This stop strategy object will only consider a search to be complete
+ if a change in an objective function from one iteration to the next
+ is less than min_delta or more than max_iter iterations has been
+ executed.
+ !*/
+
+ objective_delta_stop_strategy& be_verbose(
+ );
+ /*!
+ ensures
+ - causes this object to print status messages to standard out
+ every time should_continue_search() is called.
+ - returns *this
+ !*/
+
+ template <typename T>
+ bool should_continue_search (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - this function is only called once per search iteration
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - returns true if the point x doest not satisfy the stopping condition and
+ false otherwise.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class gradient_norm_stop_strategy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a strategy for deciding if an optimization
+ algorithm should terminate. This particular object looks at the
+ norm (i.e. the length) of the current gradient vector and stops
+ if it is smaller than a user given threshold.
+ !*/
+
+ public:
+ explicit gradient_norm_stop_strategy (
+ double min_norm = 1e-7
+ );
+ /*!
+ requires
+ - min_norm >= 0
+ ensures
+ - This stop strategy object will only consider a search to be complete
+ if the current gradient norm is less than min_norm
+ !*/
+
+ gradient_norm_stop_strategy (
+ double min_norm,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - min_norm >= 0
+ - max_iter > 0
+ ensures
+ - This stop strategy object will only consider a search to be complete
+ if the current gradient norm is less than min_norm or more than
+ max_iter iterations has been executed.
+ !*/
+
+ gradient_norm_stop_strategy& be_verbose(
+ );
+ /*!
+ ensures
+ - causes this object to print status messages to standard out
+ every time should_continue_search() is called.
+ - returns *this
+ !*/
+
+ template <typename T>
+ bool should_continue_search (
+ const T& x,
+ const double funct_value,
+ const T& funct_derivative
+ );
+ /*!
+ requires
+ - this function is only called once per search iteration
+ - for some objective function f():
+ - x == the search point for the current iteration
+ - funct_value == f(x)
+ - funct_derivative == derivative(f)(x)
+ ensures
+ - returns true if the point x doest not satisfy the stopping condition and
+ false otherwise.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_
+
diff --git a/ml/dlib/dlib/optimization/optimization_trust_region.h b/ml/dlib/dlib/optimization/optimization_trust_region.h
new file mode 100644
index 000000000..5f0ad897f
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_trust_region.h
@@ -0,0 +1,564 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_
+#define DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_
+
+#include "../matrix.h"
+#include "optimization_trust_region_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_trust_region_subproblem (
+ const matrix_exp<EXP1>& B,
+ const matrix_exp<EXP2>& g,
+ const typename EXP1::type radius,
+ matrix<T,NR,NC,MM,L>& p,
+ double eps,
+ unsigned long max_iter
+ )
+ {
+ /*
+ This is an implementation of algorithm 4.3(Trust Region Subproblem)
+ from the book Numerical Optimization by Nocedal and Wright. Some of
+ the details are also from Practical Methods of Optimization by Fletcher.
+ */
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(B.nr() == B.nc() && is_col_vector(g) && g.size() == B.nr(),
+ "\t unsigned long solve_trust_region_subproblem()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t B.nr(): " << B.nr()
+ << "\n\t B.nc(): " << B.nc()
+ << "\n\t is_col_vector(g): " << is_col_vector(g)
+ << "\n\t g.size(): " << g.size()
+ );
+ DLIB_ASSERT(radius > 0 && eps > 0 && max_iter > 0,
+ "\t unsigned long solve_trust_region_subproblem()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t radius: " << radius
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+
+
+ const_temp_matrix<EXP1> BB(B);
+ const_temp_matrix<EXP2> gg(g);
+
+ p.set_size(g.nr(),g.nc());
+ p = 0;
+
+
+ const T numeric_eps = max(diag(abs(BB)))*std::numeric_limits<T>::epsilon();
+
+ matrix<T,EXP1::NR,EXP2::NR,MM,L> R;
+
+ T lambda = 0;
+
+ // We need to put a bracket around lambda. It can't go below 0. We
+ // can get an upper bound using Gershgorin disks.
+ // This number is a lower bound on the eigenvalues in BB
+ const T BB_min_eigenvalue = min(diag(BB) - (sum_cols(abs(BB)) - abs(diag(BB))));
+
+ const T g_norm = length(gg);
+
+ T lambda_min = 0;
+ T lambda_max = put_in_range(0,
+ std::numeric_limits<T>::max(),
+ g_norm/radius - BB_min_eigenvalue);
+
+
+ // If we can tell that the minimum is at 0 then don't do anything. Just return the answer.
+ if (g_norm < numeric_eps && BB_min_eigenvalue > numeric_eps)
+ {
+ return 0;
+ }
+
+
+ // how much lambda has changed recently
+ T lambda_delta = 0;
+
+ for (unsigned long i = 0; i < max_iter; ++i)
+ {
+ R = chol(BB + lambda*identity_matrix<T>(BB.nr()));
+
+ // if the cholesky decomposition doesn't exist.
+ if (R(R.nr()-1, R.nc()-1) <= 0)
+ {
+ // If B is indefinite and g is equal to 0 then we should
+ // quit this loop and go right to the eigenvalue decomposition method.
+ if (g_norm <= numeric_eps)
+ break;
+
+ // narrow the bracket on lambda. Obviously the current lambda is
+ // too small.
+ lambda_min = lambda;
+
+ // jump towards the max value. Eventually there will
+ // be a lambda that results in a cholesky decomposition.
+ const T alpha = 0.10;
+ lambda = (1-alpha)*lambda + alpha*lambda_max;
+ continue;
+ }
+
+ using namespace blas_bindings;
+
+ p = -gg;
+ // Solve RR'*p = -g for p.
+ // Solve R*q = -g for q where q = R'*p.
+ if (R.nr() == 2)
+ {
+ p(0) = p(0)/R(0,0);
+ p(1) = (p(1)-R(1,0)*p(0))/R(1,1);
+ }
+ else
+ {
+ triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, R, p);
+ }
+ const T q_norm = length(p);
+
+ // Solve R'*p = q for p.
+ if (R.nr() == 2)
+ {
+ p(1) = p(1)/R(1,1);
+ p(0) = (p(0)-R(1,0)*p(1))/R(0,0);
+ }
+ else
+ {
+ triangular_solver(CblasLeft, CblasLower, CblasTrans, CblasNonUnit, R, p);
+ }
+ const T p_norm = length(p);
+
+ // check if we are done.
+ if (lambda == 0)
+ {
+ if (p_norm < radius)
+ {
+ // i will always be 0 in this case. So we return 1.
+ return i+1;
+ }
+ }
+ else
+ {
+ // if we are close enough to the solution then terminate
+ if (std::abs(p_norm - radius)/radius < eps)
+ return i+1;
+ }
+
+ // shrink our bracket on lambda
+ if (p_norm < radius)
+ lambda_max = lambda;
+ else
+ lambda_min = lambda;
+
+
+ if (p_norm <= radius*std::numeric_limits<T>::epsilon())
+ {
+ const T alpha = 0.01;
+ lambda = (1-alpha)*lambda_min + alpha*lambda_max;
+ continue;
+ }
+
+ const T old_lambda = lambda;
+
+ // figure out which lambda to try next
+ lambda = lambda + std::pow(q_norm/p_norm,2)*(p_norm - radius)/radius;
+
+ // make sure the chosen lambda is within our bracket (but not exactly at either end).
+ const T gap = (lambda_max-lambda_min)*0.01;
+ lambda = put_in_range(lambda_min+gap, lambda_max-gap, lambda);
+
+ // Keep track of how much lambda is thrashing around inside the search bracket. If it
+ // keeps moving around a whole lot then cut the search bracket in half.
+ lambda_delta += std::abs(lambda - old_lambda);
+ if (lambda_delta > 3*(lambda_max-lambda_min))
+ {
+ lambda = (lambda_min+lambda_max)/2;
+ lambda_delta = 0;
+ }
+ } // end for loop
+
+
+ // We are probably in the "hard case". Use an eigenvalue decomposition to sort things out.
+ // Either that or the eps was just set too tight and really we are already done.
+ eigenvalue_decomposition<EXP1> ed(make_symmetric(BB));
+
+ matrix<T,NR,NC,MM,L> ev = ed.get_real_eigenvalues();
+ const long min_eig_idx = index_of_min(ev);
+
+
+ ev -= min(ev);
+ // zero out any values which are basically zero
+ ev = pointwise_multiply(ev, ev > max(abs(ev))*std::numeric_limits<T>::epsilon());
+ ev = reciprocal(ev);
+
+
+ // figure out part of what p should be assuming we are in the hard case.
+ matrix<T,NR,NC,MM,L> p_hard;
+ p_hard = trans(ed.get_pseudo_v())*gg;
+ p_hard = diagm(ev)*p_hard;
+ p_hard = ed.get_pseudo_v()*p_hard;
+
+
+ // If we really are in the hard case then this if will be true. Otherwise, the p
+ // we found in the "easy case" loop up top is the best answer.
+ if (length(p_hard) < radius && length(p_hard) >= length(p))
+ {
+ // adjust the length of p_hard by adding a component along the eigenvector associated with
+ // the smallest eigenvalue. We want to make it the case that length(p) == radius.
+ const T tau = std::sqrt(radius*radius - length_squared(p_hard));
+ p = p_hard + tau*colm(ed.get_pseudo_v(),min_eig_idx);
+
+
+ // if we have to do an eigenvalue decomposition then say we did all the iterations
+ return max_iter;
+ }
+
+ // if we get this far it means we didn't converge to eps accuracy.
+ return max_iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename EXP3
+ >
+ bool bounds_violated (
+ const matrix_exp<EXP1>& v,
+ const matrix_exp<EXP2>& l,
+ const matrix_exp<EXP3>& u
+ )
+ {
+ DLIB_ASSERT(v.nr() == l.nr() && v.nr() == u.nr());
+ DLIB_ASSERT(v.nc() == l.nc() && v.nc() == u.nc());
+ for (long r = 0; r < v.nr(); ++r)
+ {
+ for (long c = 0; c < v.nc(); c++)
+ {
+ if (!(l(r,c) <= v(r,c) && v(r,c) <= u(r,c)))
+ return true;
+ }
+ }
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L,
+ typename EXP3
+ >
+ void solve_trust_region_subproblem_bounded (
+ const matrix_exp<EXP1>& B_,
+ const matrix_exp<EXP2>& g_,
+ const typename EXP1::type radius_,
+ matrix<T,NR,NC,MM,L>& p_,
+ double eps,
+ unsigned long max_iter,
+ const matrix_exp<EXP3>& lower_,
+ const matrix_exp<EXP3>& upper_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(B_.nr() == B_.nc() && is_col_vector(g_) && g_.size() == B_.nr(),
+ "\t unsigned long solve_trust_region_subproblem_bounded()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t B_.nr(): " << B_.nr()
+ << "\n\t B_.nc(): " << B_.nc()
+ << "\n\t is_col_vector(g_): " << is_col_vector(g_)
+ << "\n\t g_.size(): " << g_.size()
+ );
+ DLIB_ASSERT(radius_ > 0 && eps > 0 && max_iter > 0,
+ "\t unsigned long solve_trust_region_subproblem_bounded()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t radius_: " << radius_
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+ DLIB_ASSERT(is_col_vector(lower_) && lower_.size() == g_.size());
+ DLIB_ASSERT(is_col_vector(upper_) && upper_.size() == g_.size());
+ DLIB_ASSERT(max(upper_-lower_) >= 0);
+ // make sure the problem is feasible. That is, there should be a point inside the
+ // lower and upper bounds that has a norm <= radius_
+ DLIB_ASSERT(length(clamp(zeros_matrix(lower_),lower_,upper_)) <= radius_,
+ "The lower and upper bounds are incompatible with the radius since there is no point within the bounds with a norm less than the radius.");
+
+ // We are going to solve this by greedily finding the most violated bound constraint,
+ // locking that variable to its constrained value, removing it from the problem,
+ // and then resolving. We do that until no more constraint violations are present.
+
+ solve_trust_region_subproblem(B_,g_,radius_,p_,eps,max_iter);
+
+
+ // just stop here if all the bounds are satisfied.
+ if (!impl::bounds_violated(p_, lower_, upper_))
+ return;
+
+ matrix<double> B = matrix_cast<double>(B_);
+ matrix<double,0,1> g = matrix_cast<double>(g_);
+ double radius = radius_;
+ matrix<double,0,1> p = matrix_cast<double>(p_);
+ matrix<double,0,1> lower = matrix_cast<double>(lower_);
+ matrix<double,0,1> upper = matrix_cast<double>(upper_);
+
+ // keep a table that tells us how to map any reduced QP back to the original QP
+ std::vector<long> idxs(g.size());
+ for (size_t i = 0; i < idxs.size(); ++i)
+ idxs[i] = i;
+
+
+ // while we haven't found a p that satisfies the bounds constraints
+ while(impl::bounds_violated(p, lower, upper) )
+ {
+ // Find the most violated variable and fix its value to a constant (the bound
+ // value).
+ long most_violated_idx = 0;
+ double max_violation = 0;
+ double bounded_value = 0;
+ for (long i = 0; i < lower.size(); ++i)
+ {
+ if (!(lower(i) <= p(i) && p(i) <= upper(i)))
+ {
+ if (lower(i)-p(i) > max_violation)
+ {
+ max_violation = lower(i)-p(i);
+ most_violated_idx = i;
+ bounded_value = lower(i);
+ }
+ else if (p(i)-upper(i) > max_violation)
+ {
+ max_violation = p(i)-upper(i);
+ most_violated_idx = i;
+ bounded_value = upper(i);
+ }
+ }
+ }
+
+ // assign this variable to its final value.
+ p_(idxs[most_violated_idx]) = bounded_value;
+
+
+ // now reduce the QP by removing the variable p_(idxs[most_violated_idx]).
+ idxs.erase(idxs.begin()+most_violated_idx);
+ // we are out of variables to remove since everything is at bounds.
+ if (idxs.size() == 0)
+ break;
+
+ lower = remove_row(lower,most_violated_idx);
+ upper = remove_row(upper,most_violated_idx);
+ g += colm(B,most_violated_idx)*bounded_value;
+ g = remove_row(g,most_violated_idx);
+ p = remove_row(p,most_violated_idx);
+ B = removerc(B,most_violated_idx, most_violated_idx);
+
+ // Removing a variable changes the radius, so we have to subtract the bounded
+ // value from the radius so as to not change the effective radius for the whole
+ // problem.
+ double squared_radius = radius*radius - bounded_value*bounded_value;
+ if (squared_radius <= 0)
+ {
+ p = 0;
+ break;
+ }
+ radius = std::sqrt(squared_radius);
+
+
+ solve_trust_region_subproblem(B,g,radius,p,eps,max_iter);
+ }
+
+ // assign the non-bound-constrained variables to their final values
+ for (size_t i = 0; i < idxs.size(); ++i)
+ p_(idxs[i]) = p(i);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_model
+ >
+ double find_min_trust_region (
+ stop_strategy_type stop_strategy,
+ const funct_model& model,
+ typename funct_model::column_vector& x,
+ double radius = 1
+ )
+ {
+ /*
+ This is an implementation of algorithm 4.1(Trust Region)
+ from the book Numerical Optimization by Nocedal and Wright.
+ */
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && radius > 0,
+ "\t double find_min_trust_region()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t radius: " << radius
+ );
+
+ const double initial_radius = radius;
+
+ typedef typename funct_model::column_vector T;
+ typedef typename T::type type;
+
+ typename funct_model::general_matrix h;
+ typename funct_model::column_vector g, p, d;
+ type f_value = model(x);
+
+ model.get_derivative_and_hessian(x,g,h);
+
+ DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs");
+ DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs");
+ DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs");
+
+ // Sometimes the loop below won't modify x because the trust region step failed.
+ // This bool tells us when we are in that case.
+ bool stale_x = false;
+
+ while(stale_x || stop_strategy.should_continue_search(x, f_value, g))
+ {
+ const unsigned long iter = solve_trust_region_subproblem(h,
+ g,
+ radius,
+ p,
+ 0.1,
+ 20);
+
+
+ const type new_f_value = model(x+p);
+ const type predicted_improvement = -0.5*trans(p)*h*p - trans(g)*p;
+ const type measured_improvement = (f_value - new_f_value);
+
+ // If the sub-problem can't find a way to improve then stop. This only happens when p is essentially 0.
+ if (std::abs(predicted_improvement) <= std::abs(measured_improvement)*std::numeric_limits<type>::epsilon())
+ break;
+
+ // predicted_improvement shouldn't be negative but it might be if something went
+ // wrong in the trust region solver. So put abs() here to guard against that. This
+ // way the sign of rho is determined only by the sign of measured_improvement.
+ const type rho = measured_improvement/std::abs(predicted_improvement);
+
+
+ if (!is_finite(rho))
+ break;
+
+ if (rho < 0.25)
+ {
+ radius *= 0.25;
+
+ // something has gone horribly wrong if the radius has shrunk to zero. So just
+ // give up if that happens.
+ if (radius <= initial_radius*std::numeric_limits<double>::epsilon())
+ break;
+ }
+ else
+ {
+ // if rho > 0.75 and we are being checked by the radius
+ if (rho > 0.75 && iter > 1)
+ {
+ radius = std::min<type>(1000, 2*radius);
+ }
+ }
+
+ if (rho > 0)
+ {
+ x = x + p;
+ f_value = new_f_value;
+ model.get_derivative_and_hessian(x,g,h);
+ stale_x = false;
+ }
+ else
+ {
+ stale_x = true;
+ }
+
+ DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs");
+ DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs");
+ DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs");
+ }
+
+ return f_value;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct_model>
+ struct negate_tr_model
+ {
+ negate_tr_model( const funct_model& m) : model(m) {}
+
+ const funct_model& model;
+
+ typedef typename funct_model::column_vector column_vector;
+ typedef typename funct_model::general_matrix general_matrix;
+
+ template <typename T>
+ typename T::type operator() (const T& x) const
+ {
+ return -model(x);
+ }
+
+ template <typename T, typename U>
+ void get_derivative_and_hessian (
+ const T& x,
+ T& d,
+ U& h
+ ) const
+ {
+ model.get_derivative_and_hessian(x,d,h);
+ d = -d;
+ h = -h;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_model
+ >
+ double find_max_trust_region (
+ stop_strategy_type stop_strategy,
+ const funct_model& model,
+ typename funct_model::column_vector& x,
+ double radius = 1
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && radius > 0,
+ "\t double find_max_trust_region()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t radius: " << radius
+ );
+
+ return -find_min_trust_region(stop_strategy,
+ negate_tr_model<funct_model>(model),
+ x,
+ radius);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_
+
diff --git a/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h b/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h
new file mode 100644
index 000000000..303ee746d
--- /dev/null
+++ b/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h
@@ -0,0 +1,233 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_
+#ifdef DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_
+
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_trust_region_subproblem (
+ const matrix_exp<EXP1>& B,
+ const matrix_exp<EXP2>& g,
+ const typename EXP1::type radius,
+ matrix<T,NR,NC,MM,L>& p,
+ double eps,
+ unsigned long max_iter
+ );
+ /*!
+ requires
+ - B == trans(B)
+ (i.e. B should be a symmetric matrix)
+ - B.nr() == B.nc()
+ - is_col_vector(g) == true
+ - g.size() == B.nr()
+ - p is capable of containing a column vector the size of g
+ (i.e. p = g; should be a legal expression)
+ - radius > 0
+ - eps > 0
+ - max_iter > 0
+ ensures
+ - This function solves the following optimization problem:
+ Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p
+ subject to the following constraint:
+ - length(p) <= radius
+ - returns the number of iterations performed. If this method fails to converge
+ to eps accuracy then the number returned will be max_iter+1.
+ - if (this function didn't terminate due to hitting the max_iter iteration limit) then
+ - if this function returns 0 or 1 then we are not hitting the radius bound Otherwise,
+ the radius constraint is active and std::abs(length(#p)-radius)/radius <= eps.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L,
+ typename EXP3
+ >
+ void solve_trust_region_subproblem_bounded (
+ const matrix_exp<EXP1>& B,
+ const matrix_exp<EXP2>& g,
+ const typename EXP1::type radius,
+ matrix<T,NR,NC,MM,L>& p,
+ double eps,
+ unsigned long max_iter,
+ const matrix_exp<EXP3>& lower,
+ const matrix_exp<EXP3>& upper
+ );
+ /*!
+ requires
+ - B == trans(B)
+ (i.e. B should be a symmetric matrix)
+ - B.nr() == B.nc()
+ - is_col_vector(g) == true
+ - is_col_vector(lower) == true
+ - is_col_vector(upper) == true
+ - g.size() == B.nr()
+ - lower.size() == B.nr()
+ - upper.size() == B.nr()
+ - p is capable of containing a column vector the size of g
+ (i.e. p = g; should be a legal expression)
+ - radius > 0
+ - eps > 0
+ - max_iter > 0
+ - min(upper-lower) >= 0
+ - length(clamp(zeros_matrix(lower),lower,upper)) <= radius
+ (i.e. the lower and upper bounds can't exclude all points with the desired radius.)
+ ensures
+ - This function solves the following optimization problem:
+ Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p
+ subject to the following constraints:
+ - length(p) <= radius
+ - lower(i) <= p(i) <= upper(i), for all i
+ - Solves the problem to eps accuracy. We do this by greedily finding the most
+ violated bound constraint, locking that variable to its constrained value, removing
+ it from the problem, and then resolving. We do that until no more constraint
+ violations are present. Each time we just call solve_trust_region_subproblem()
+ to get the solution and pass eps and max_iter directly to these calls to
+ solve_trust_region_subproblem().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class function_model
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface for a function model
+ used by the trust-region optimizers defined below.
+
+ In particular, this object represents a function f() and
+ its associated derivative and hessian.
+
+ !*/
+
+ public:
+
+ // Define the type used to represent column vectors
+ typedef matrix<double,0,1> column_vector;
+ // Define the type used to represent the hessian matrix
+ typedef matrix<double> general_matrix;
+
+ double operator() (
+ const column_vector& x
+ ) const;
+ /*!
+ ensures
+ - returns f(x)
+ (i.e. evaluates this model at the given point and returns the value)
+ !*/
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const;
+ /*!
+ ensures
+ - #d == the derivative of f() at x
+ - #h == the hessian matrix of f() at x
+ - is_col_vector(#d) == true
+ - #d.size() == x.size()
+ - #h.nr() == #h.nc() == x.size()
+ - #h == trans(#h)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_model
+ >
+ double find_min_trust_region (
+ stop_strategy_type stop_strategy,
+ const funct_model& model,
+ typename funct_model::column_vector& x,
+ double radius = 1
+ );
+ /*!
+ requires
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - is_col_vector(x) == true
+ - radius > 0
+ - model must be an object with an interface as defined by the function_model
+ example object shown above.
+ ensures
+ - Performs an unconstrained minimization of the function defined by model
+ starting from the initial point x. This function uses a trust region
+ algorithm to perform the minimization. The radius parameter defines
+ the initial size of the trust region.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or the trust region subproblem fails to make progress.
+ - #x == the value of x that was found to minimize model()
+ - returns model(#x).
+ - When this function makes calls to model.get_derivative_and_hessian() it always
+ does so by first calling model() and then calling model.get_derivative_and_hessian().
+ That is, any call to model.get_derivative_and_hessian(val) will always be
+ preceded by a call to model(val) with the same value. This way you can reuse
+ any redundant computations performed by model() and model.get_derivative_and_hessian()
+ as appropriate.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stop_strategy_type,
+ typename funct_model
+ >
+ double find_max_trust_region (
+ stop_strategy_type stop_strategy,
+ const funct_model& model,
+ typename funct_model::column_vector& x,
+ double radius = 1
+ );
+ /*!
+ requires
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - is_col_vector(x) == true
+ - radius > 0
+ - model must be an object with an interface as defined by the function_model
+ example object shown above.
+ ensures
+ - Performs an unconstrained maximization of the function defined by model
+ starting from the initial point x. This function uses a trust region
+ algorithm to perform the maximization. The radius parameter defines
+ the initial size of the trust region.
+ - The function is optimized until stop_strategy decides that an acceptable
+ point has been found or the trust region subproblem fails to make progress.
+ - #x == the value of x that was found to maximize model()
+ - returns model(#x).
+ - When this function makes calls to model.get_derivative_and_hessian() it always
+ does so by first calling model() and then calling model.get_derivative_and_hessian().
+ That is, any call to model.get_derivative_and_hessian(val) will always be
+ preceded by a call to model(val) with the same value. This way you can reuse
+ any redundant computations performed by model() and model.get_derivative_and_hessian()
+ as appropriate.
+ - Note that this function solves the maximization problem by converting it
+ into a minimization problem. Therefore, the values of model() and its derivative
+ reported to the stopping strategy will be negated. That is, stop_strategy
+ will see -model() and -derivative. All this really means is that the status
+ messages from a stopping strategy in verbose mode will display a negated objective
+ value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_
+
+
diff --git a/ml/dlib/dlib/ostream b/ml/dlib/dlib/ostream
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/ostream
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/pipe.h b/ml/dlib/dlib/pipe.h
new file mode 100644
index 000000000..023b985bf
--- /dev/null
+++ b/ml/dlib/dlib/pipe.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PIPe_
+#define DLIB_PIPe_
+
+#include "pipe/pipe_kernel_1.h"
+
+
+#endif // DLIB_PIPe_
+
diff --git a/ml/dlib/dlib/pipe/pipe_kernel_1.h b/ml/dlib/dlib/pipe/pipe_kernel_1.h
new file mode 100644
index 000000000..543754121
--- /dev/null
+++ b/ml/dlib/dlib/pipe/pipe_kernel_1.h
@@ -0,0 +1,756 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PIPE_KERNEl_1_
+#define DLIB_PIPE_KERNEl_1_
+
+#include "../algs.h"
+#include "../threads.h"
+#include "pipe_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class pipe
+ {
+ /*!
+ INITIAL VALUE
+ - pipe_size == 0
+ - pipe_max_size == defined by constructor
+ - enabled == true
+ - data == a pointer to an array of ((pipe_max_size>0)?pipe_max_size:1) T objects.
+ - dequeue_waiters == 0
+ - enqueue_waiters == 0
+ - first == 1
+ - last == 1
+ - unblock_sig_waiters == 0
+
+ CONVENTION
+ - size() == pipe_size
+ - max_size() == pipe_max_size
+ - is_enabled() == enabled
+
+ - m == the mutex used to lock access to all the members of this class
+
+ - dequeue_waiters == the number of threads blocked on calls to dequeue()
+ - enqueue_waiters == the number of threads blocked on calls to enqueue() and
+ wait_until_empty()
+ - unblock_sig_waiters == the number of threads blocked on calls to
+ wait_for_num_blocked_dequeues() and the destructor. (i.e. the number of
+ blocking calls to unblock_sig.wait())
+
+ - dequeue_sig == the signaler that threads blocked on calls to dequeue() wait on
+ - enqueue_sig == the signaler that threads blocked on calls to enqueue()
+ or wait_until_empty() wait on.
+ - unblock_sig == the signaler that is signaled when a thread stops blocking on a call
+ to enqueue() or dequeue(). It is also signaled when a dequeue that will probably
+ block is called. The destructor and wait_for_num_blocked_dequeues are the only
+ things that will wait on this signaler.
+
+ - if (pipe_size > 0) then
+ - data[first] == the next item to dequeue
+ - data[last] == the item most recently added via enqueue, so the last to dequeue.
+ - else if (pipe_max_size == 0)
+ - if (first == 0 && last == 0) then
+ - data[0] == the next item to dequeue
+ - else if (first == 0 && last == 1) then
+ - data[0] has been taken out already by a dequeue
+ !*/
+
+ public:
+ // this is here for backwards compatibility with older versions of dlib.
+ typedef pipe kernel_1a;
+
+ typedef T type;
+
+ explicit pipe (
+ size_t maximum_size
+ );
+
+ virtual ~pipe (
+ );
+
+ void empty (
+ );
+
+ void wait_until_empty (
+ ) const;
+
+ void wait_for_num_blocked_dequeues (
+ unsigned long num
+ )const;
+
+ void enable (
+ );
+
+ void disable (
+ );
+
+ bool is_enqueue_enabled (
+ ) const;
+
+ void disable_enqueue (
+ );
+
+ void enable_enqueue (
+ );
+
+ bool is_dequeue_enabled (
+ ) const;
+
+ void disable_dequeue (
+ );
+
+ void enable_dequeue (
+ );
+
+ bool is_enabled (
+ ) const;
+
+ size_t max_size (
+ ) const;
+
+ size_t size (
+ ) const;
+
+ bool enqueue (
+ T& item
+ );
+
+ bool enqueue (
+ T&& item
+ ) { return enqueue(item); }
+
+ bool dequeue (
+ T& item
+ );
+
+ bool enqueue_or_timeout (
+ T& item,
+ unsigned long timeout
+ );
+
+ bool enqueue_or_timeout (
+ T&& item,
+ unsigned long timeout
+ ) { return enqueue_or_timeout(item,timeout); }
+
+ bool dequeue_or_timeout (
+ T& item,
+ unsigned long timeout
+ );
+
+ private:
+
+ size_t pipe_size;
+ const size_t pipe_max_size;
+ bool enabled;
+
+ T* const data;
+
+ size_t first;
+ size_t last;
+
+ mutex m;
+ signaler dequeue_sig;
+ signaler enqueue_sig;
+ signaler unblock_sig;
+
+ unsigned long dequeue_waiters;
+ mutable unsigned long enqueue_waiters;
+ mutable unsigned long unblock_sig_waiters;
+ bool enqueue_enabled;
+ bool dequeue_enabled;
+
+ // restricted functions
+ pipe(const pipe&); // copy constructor
+ pipe& operator=(const pipe&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ pipe<T>::
+ pipe (
+ size_t maximum_size
+ ) :
+ pipe_size(0),
+ pipe_max_size(maximum_size),
+ enabled(true),
+ data(new T[(maximum_size>0) ? maximum_size : 1]),
+ first(1),
+ last(1),
+ dequeue_sig(m),
+ enqueue_sig(m),
+ unblock_sig(m),
+ dequeue_waiters(0),
+ enqueue_waiters(0),
+ unblock_sig_waiters(0),
+ enqueue_enabled(true),
+ dequeue_enabled(true)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ pipe<T>::
+ ~pipe (
+ )
+ {
+ auto_mutex M(m);
+ ++unblock_sig_waiters;
+
+ // first make sure no one is blocked on any calls to enqueue() or dequeue()
+ enabled = false;
+ dequeue_sig.broadcast();
+ enqueue_sig.broadcast();
+ unblock_sig.broadcast();
+
+ // wait for all threads to unblock
+ while (dequeue_waiters > 0 || enqueue_waiters > 0 || unblock_sig_waiters > 1)
+ unblock_sig.wait();
+
+ delete [] data;
+ --unblock_sig_waiters;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ empty (
+ )
+ {
+ auto_mutex M(m);
+ pipe_size = 0;
+
+ // let any calls to enqueue() know that the pipe is now empty
+ if (enqueue_waiters > 0)
+ enqueue_sig.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ wait_until_empty (
+ ) const
+ {
+ auto_mutex M(m);
+ // this function is sort of like a call to enqueue so treat it like that
+ ++enqueue_waiters;
+
+ while (pipe_size > 0 && enabled && dequeue_enabled )
+ enqueue_sig.wait();
+
+ // let the destructor know we are ending if it is blocked waiting
+ if (enabled == false)
+ unblock_sig.broadcast();
+
+ --enqueue_waiters;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ enable (
+ )
+ {
+ auto_mutex M(m);
+ enabled = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ disable (
+ )
+ {
+ auto_mutex M(m);
+ enabled = false;
+ dequeue_sig.broadcast();
+ enqueue_sig.broadcast();
+ unblock_sig.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ is_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return enabled;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t pipe<T>::
+ max_size (
+ ) const
+ {
+ auto_mutex M(m);
+ return pipe_max_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t pipe<T>::
+ size (
+ ) const
+ {
+ auto_mutex M(m);
+ return pipe_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ enqueue (
+ T& item
+ )
+ {
+ auto_mutex M(m);
+ ++enqueue_waiters;
+
+ // wait until there is room or we are disabled
+ while (pipe_size == pipe_max_size && enabled && enqueue_enabled &&
+ !(pipe_max_size == 0 && first == 1) )
+ enqueue_sig.wait();
+
+ if (enabled == false || enqueue_enabled == false)
+ {
+ --enqueue_waiters;
+ // let the destructor know we are unblocking
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ // set the appropriate values for first and last
+ if (pipe_size == 0)
+ {
+ first = 0;
+ last = 0;
+ }
+ else
+ {
+ last = (last+1)%pipe_max_size;
+ }
+
+
+ exchange(item,data[last]);
+
+ // wake up a call to dequeue() if there are any currently blocked
+ if (dequeue_waiters > 0)
+ dequeue_sig.signal();
+
+ if (pipe_max_size > 0)
+ {
+ ++pipe_size;
+ }
+ else
+ {
+ // wait for a dequeue to take the item out
+ while (last == 0 && enabled && enqueue_enabled)
+ enqueue_sig.wait();
+
+ if (last == 0 && (enabled == false || enqueue_enabled == false))
+ {
+ last = 1;
+ first = 1;
+
+ // no one dequeued this object to put it back into item
+ exchange(item,data[0]);
+
+ --enqueue_waiters;
+ // let the destructor know we are unblocking
+ if (unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ last = 1;
+ first = 1;
+
+ // tell any waiting calls to enqueue() that one of them can proceed
+ if (enqueue_waiters > 1)
+ enqueue_sig.broadcast();
+
+ // let the destructor know we are unblocking
+ if (enabled == false && unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+ }
+
+ --enqueue_waiters;
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ dequeue (
+ T& item
+ )
+ {
+ auto_mutex M(m);
+ ++dequeue_waiters;
+
+ if (pipe_size == 0)
+ {
+ // notify wait_for_num_blocked_dequeues()
+ if (unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+
+ // notify any blocked enqueue_or_timeout() calls
+ if (enqueue_waiters > 0)
+ enqueue_sig.broadcast();
+ }
+
+ // wait until there is something in the pipe or we are disabled
+ while (pipe_size == 0 && enabled && dequeue_enabled &&
+ !(pipe_max_size == 0 && first == 0 && last == 0) )
+ dequeue_sig.wait();
+
+ if (enabled == false || dequeue_enabled == false)
+ {
+ --dequeue_waiters;
+ // let the destructor know we are unblocking
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ exchange(item,data[first]);
+
+ if (pipe_max_size > 0)
+ {
+ // set the appropriate values for first
+ first = (first+1)%pipe_max_size;
+
+ --pipe_size;
+ }
+ else
+ {
+ // let the enqueue waiting on us know that we took the
+ // item out already.
+ last = 1;
+ }
+
+ // wake up a call to enqueue() if there are any currently blocked
+ if (enqueue_waiters > 0)
+ enqueue_sig.broadcast();
+
+ --dequeue_waiters;
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ enqueue_or_timeout (
+ T& item,
+ unsigned long timeout
+ )
+ {
+ auto_mutex M(m);
+ ++enqueue_waiters;
+
+ // wait until there is room or we are disabled or
+ // we run out of time.
+ bool timed_out = false;
+ while (pipe_size == pipe_max_size && enabled && enqueue_enabled &&
+ !(pipe_max_size == 0 && dequeue_waiters > 0 && first == 1) )
+ {
+ if (timeout == 0 || enqueue_sig.wait_or_timeout(timeout) == false)
+ {
+ timed_out = true;
+ break;
+ }
+ }
+
+ if (enabled == false || timed_out || enqueue_enabled == false)
+ {
+ --enqueue_waiters;
+ // let the destructor know we are unblocking
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ // set the appropriate values for first and last
+ if (pipe_size == 0)
+ {
+ first = 0;
+ last = 0;
+ }
+ else
+ {
+ last = (last+1)%pipe_max_size;
+ }
+
+
+ exchange(item,data[last]);
+
+ // wake up a call to dequeue() if there are any currently blocked
+ if (dequeue_waiters > 0)
+ dequeue_sig.signal();
+
+ if (pipe_max_size > 0)
+ {
+ ++pipe_size;
+ }
+ else
+ {
+ // wait for a dequeue to take the item out
+ while (last == 0 && enabled && enqueue_enabled)
+ enqueue_sig.wait();
+
+ if (last == 0 && (enabled == false || enqueue_enabled == false))
+ {
+ last = 1;
+ first = 1;
+
+ // no one dequeued this object to put it back into item
+ exchange(item,data[0]);
+
+ --enqueue_waiters;
+ // let the destructor know we are unblocking
+ if (unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ last = 1;
+ first = 1;
+
+ // tell any waiting calls to enqueue() that one of them can proceed
+ if (enqueue_waiters > 1)
+ enqueue_sig.broadcast();
+
+ // let the destructor know we are unblocking
+ if (enabled == false && unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+ }
+
+ --enqueue_waiters;
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ dequeue_or_timeout (
+ T& item,
+ unsigned long timeout
+ )
+ {
+ auto_mutex M(m);
+ ++dequeue_waiters;
+
+ if (pipe_size == 0)
+ {
+ // notify wait_for_num_blocked_dequeues()
+ if (unblock_sig_waiters > 0)
+ unblock_sig.broadcast();
+
+ // notify any blocked enqueue_or_timeout() calls
+ if (enqueue_waiters > 0)
+ enqueue_sig.broadcast();
+ }
+
+ bool timed_out = false;
+ // wait until there is something in the pipe or we are disabled or we timeout.
+ while (pipe_size == 0 && enabled && dequeue_enabled &&
+ !(pipe_max_size == 0 && first == 0 && last == 0) )
+ {
+ if (timeout == 0 || dequeue_sig.wait_or_timeout(timeout) == false)
+ {
+ timed_out = true;
+ break;
+ }
+ }
+
+ if (enabled == false || timed_out || dequeue_enabled == false)
+ {
+ --dequeue_waiters;
+ // let the destructor know we are unblocking
+ unblock_sig.broadcast();
+ return false;
+ }
+
+ exchange(item,data[first]);
+
+ if (pipe_max_size > 0)
+ {
+ // set the appropriate values for first
+ first = (first+1)%pipe_max_size;
+
+ --pipe_size;
+ }
+ else
+ {
+ // let the enqueue waiting on us know that we took the
+ // item out already.
+ last = 1;
+ }
+
+ // wake up a call to enqueue() if there are any currently blocked
+ if (enqueue_waiters > 0)
+ enqueue_sig.broadcast();
+
+ --dequeue_waiters;
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ wait_for_num_blocked_dequeues (
+ unsigned long num
+ )const
+ {
+ auto_mutex M(m);
+ ++unblock_sig_waiters;
+
+ while ( (dequeue_waiters < num || pipe_size != 0) && enabled && dequeue_enabled)
+ unblock_sig.wait();
+
+ // let the destructor know we are ending if it is blocked waiting
+ if (enabled == false)
+ unblock_sig.broadcast();
+
+ --unblock_sig_waiters;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ is_enqueue_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return enqueue_enabled;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ disable_enqueue (
+ )
+ {
+ auto_mutex M(m);
+ enqueue_enabled = false;
+ enqueue_sig.broadcast();
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ enable_enqueue (
+ )
+ {
+ auto_mutex M(m);
+ enqueue_enabled = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool pipe<T>::
+ is_dequeue_enabled (
+ ) const
+ {
+ auto_mutex M(m);
+ return dequeue_enabled;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ disable_dequeue (
+ )
+ {
+ auto_mutex M(m);
+ dequeue_enabled = false;
+ dequeue_sig.broadcast();
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void pipe<T>::
+ enable_dequeue (
+ )
+ {
+ auto_mutex M(m);
+ dequeue_enabled = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PIPE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/pipe/pipe_kernel_abstract.h b/ml/dlib/dlib/pipe/pipe_kernel_abstract.h
new file mode 100644
index 000000000..91b2205e7
--- /dev/null
+++ b/ml/dlib/dlib/pipe/pipe_kernel_abstract.h
@@ -0,0 +1,323 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_PIPE_KERNEl_ABSTRACT_
+#ifdef DLIB_PIPE_KERNEl_ABSTRACT_
+
+#include "../threads.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class pipe
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap()
+ T must have a default constructor
+
+ INITIAL VALUE
+ size() == 0
+ is_enabled() == true
+ is_enqueue_enabled() == true
+ is_dequeue_enabled() == true
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a first in first out queue with a fixed maximum size containing
+ items of type T. It is suitable for passing objects between threads.
+
+ THREAD SAFETY
+ All methods of this class are thread safe. You may call them from any
+ thread and any number of threads my call them at once.
+ !*/
+
+ public:
+
+ typedef T type;
+
+ explicit pipe (
+ size_t maximum_size
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #max_size() == maximum_size
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~pipe (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ - disables (i.e. sets is_enabled() == false) this object so that
+ all calls currently blocking on it will return immediately.
+ !*/
+
+ void enable (
+ );
+ /*!
+ ensures
+ - #is_enabled() == true
+ !*/
+
+ void disable (
+ );
+ /*!
+ ensures
+ - #is_enabled() == false
+ - causes all current and future calls to enqueue(), dequeue(),
+ enqueue_or_timeout() and dequeue_or_timeout() to not block but
+ to return false immediately until enable() is called.
+ - causes all current and future calls to wait_until_empty() and
+ wait_for_num_blocked_dequeues() to not block but return
+ immediately until enable() is called.
+ !*/
+
+ bool is_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if this pipe is currently enabled, false otherwise.
+ !*/
+
+ void empty (
+ );
+ /*!
+ ensures
+ - #size() == 0
+ !*/
+
+ void wait_until_empty (
+ ) const;
+ /*!
+ ensures
+ - blocks until one of the following is the case:
+ - size() == 0
+ - is_enabled() == false
+ - is_dequeue_enabled() == false
+ !*/
+
+ void wait_for_num_blocked_dequeues (
+ unsigned long num
+ ) const;
+ /*!
+ ensures
+ - blocks until one of the following is the case:
+ - size() == 0 and the number of threads blocked on calls
+ to dequeue() and dequeue_or_timeout() is greater than
+ or equal to num.
+ - is_enabled() == false
+ - is_dequeue_enabled() == false
+ !*/
+
+ bool is_enqueue_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if the enqueue() and enqueue_or_timeout() functions are
+ currently enabled, returns false otherwise. (note that the higher
+ level is_enabled() function can overrule this one. So if
+ is_enabled() == false then enqueue functions are still disabled even
+ if is_enqueue_enabled() returns true. But if is_enqueue_enabled() == false
+ then enqueue functions are always disabled no matter the state of
+ is_enabled())
+ !*/
+
+ void disable_enqueue (
+ );
+ /*!
+ ensures
+ - #is_enqueue_enabled() == false
+ - causes all current and future calls to enqueue() and
+ enqueue_or_timeout() to not block but to return false
+ immediately until enable_enqueue() is called.
+ !*/
+
+ void enable_enqueue (
+ );
+ /*!
+ ensures
+ - #is_enqueue_enabled() == true
+ !*/
+
+ bool is_dequeue_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if the dequeue() and dequeue_or_timeout() functions are
+ currently enabled, returns false otherwise. (note that the higher
+ level is_enabled() function can overrule this one. So if
+ is_enabled() == false then dequeue functions are still disabled even
+ if is_dequeue_enabled() returns true. But if is_dequeue_enabled() == false
+ then dequeue functions are always disabled no matter the state of
+ is_enabled())
+ !*/
+
+ void disable_dequeue (
+ );
+ /*!
+ ensures
+ - #is_dequeue_enabled() == false
+ - causes all current and future calls to dequeue() and
+ dequeue_or_timeout() to not block but to return false
+ immediately until enable_dequeue() is called.
+ !*/
+
+ void enable_dequeue (
+ );
+ /*!
+ ensures
+ - #is_dequeue_enabled() == true
+ !*/
+
+ size_t max_size (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of objects of type T that this
+ pipe can contain.
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of objects of type T that this
+ object currently contains.
+ !*/
+
+ bool enqueue (
+ T& item
+ );
+ /*!
+ ensures
+ - if (size() == max_size()) then
+ - this call to enqueue() blocks until one of the following is the case:
+ - there is room in the pipe for another item
+ - max_size() == 0 and another thread is trying to dequeue from this
+ pipe and we can pass our item object directly to that thread.
+ - someone calls disable()
+ - someone calls disable_enqueue()
+ - else
+ - this call does not block.
+ - if (this call to enqueue() returns true) then
+ - #is_enabled() == true
+ - #is_enqueue_enabled() == true
+ - if (max_size() == 0) then
+ - using global swap, item was passed directly to a
+ thread attempting to dequeue from this pipe
+ - else
+ - using global swap, item was added into this pipe.
+ - #item is in an undefined but valid state for its type
+ - else
+ - item was NOT added into the pipe
+ - #item == item (i.e. the value of item is unchanged)
+ !*/
+
+ bool enqueue (T&& item) { return enqueue(item); }
+ /*!
+ enable enqueueing from rvalues
+ !*/
+
+ bool enqueue_or_timeout (
+ T& item,
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - if (size() == max_size() && timeout > 0) then
+ - this call to enqueue_or_timeout() blocks until one of the following is the case:
+ - there is room in the pipe to add another item
+ - max_size() == 0 and another thread is trying to dequeue from this pipe
+ and we can pass our item object directly to that thread.
+ - someone calls disable()
+ - someone calls disable_enqueue()
+ - timeout milliseconds passes
+ - else
+ - this call does not block.
+ - if (this call to enqueue() returns true) then
+ - #is_enabled() == true
+ - #is_enqueue_enabled() == true
+ - if (max_size() == 0) then
+ - using global swap, item was passed directly to a
+ thread attempting to dequeue from this pipe
+ - else
+ - using global swap, item was added into this pipe.
+ - #item is in an undefined but valid state for its type
+ - else
+ - item was NOT added into the pipe
+ - #item == item (i.e. the value of item is unchanged)
+ !*/
+
+ bool enqueue_or_timeout (T&& item, unsigned long timeout) { return enqueue_or_timeout(item,timeout); }
+ /*!
+ enable enqueueing from rvalues
+ !*/
+
+ bool dequeue (
+ T& item
+ );
+ /*!
+ ensures
+ - if (size() == 0) then
+ - this call to dequeue() blocks until one of the following is the case:
+ - there is something in the pipe we can dequeue
+ - max_size() == 0 and another thread is trying to enqueue an item
+ onto this pipe and we can receive our item directly from that thread.
+ - someone calls disable()
+ - someone calls disable_dequeue()
+ - else
+ - this call does not block.
+ - if (this call to dequeue() returns true) then
+ - #is_enabled() == true
+ - #is_dequeue_enabled() == true
+ - the oldest item that was enqueued into this pipe has been
+ swapped into #item.
+ - else
+ - nothing was dequeued from this pipe.
+ - #item == item (i.e. the value of item is unchanged)
+ !*/
+
+ bool dequeue_or_timeout (
+ T& item,
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - if (size() == 0 && timeout > 0) then
+ - this call to dequeue_or_timeout() blocks until one of the following is the case:
+ - there is something in the pipe we can dequeue
+ - max_size() == 0 and another thread is trying to enqueue an item onto this
+ pipe and we can receive our item directly from that thread.
+ - someone calls disable()
+ - someone calls disable_dequeue()
+ - timeout milliseconds passes
+ - else
+ - this call does not block.
+ - if (this call to dequeue_or_timeout() returns true) then
+ - #is_enabled() == true
+ - #is_dequeue_enabled() == true
+ - the oldest item that was enqueued into this pipe has been
+ swapped into #item.
+ - else
+ - nothing was dequeued from this pipe.
+ - #item == item (i.e. the value of item is unchanged)
+ !*/
+
+ private:
+
+ // restricted functions
+ pipe(const pipe&); // copy constructor
+ pipe& operator=(const pipe&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_PIPE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/pixel.h b/ml/dlib/dlib/pixel.h
new file mode 100644
index 000000000..50ead2c34
--- /dev/null
+++ b/ml/dlib/dlib/pixel.h
@@ -0,0 +1,1649 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PIXEl_
+#define DLIB_PIXEl_
+
+#include <iostream>
+#include "serialize.h"
+#include <cmath>
+#include "algs.h"
+#include "uintn.h"
+#include <limits>
+#include <complex>
+#include "enable_if.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ This file contains definitions of pixel objects and related classes and
+ functionality.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct pixel_traits;
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ As the name implies, this is a traits class for pixel types.
+ It defines the properties of a pixel.
+
+ This traits class will define the following public static members:
+ - bool grayscale
+ - bool rgb
+ - bool rgb_alpha
+ - bool hsi
+ - bool lab
+
+ - bool has_alpha
+
+ - long num
+
+ - basic_pixel_type
+ - basic_pixel_type min()
+ - basic_pixel_type max()
+ - is_unsigned
+
+ The above public constants are subject to the following constraints:
+ - only one of grayscale, rgb, rgb_alpha, hsi or lab is true
+ - if (rgb == true) then
+ - The type T will be a struct with 3 public members of type
+ unsigned char named "red" "green" and "blue".
+ - This type of pixel represents the RGB color space.
+ - num == 3
+ - has_alpha == false
+ - basic_pixel_type == unsigned char
+ - min() == 0
+ - max() == 255
+ - is_unsigned == true
+ - if (rgb_alpha == true) then
+ - The type T will be a struct with 4 public members of type
+ unsigned char named "red" "green" "blue" and "alpha".
+ - This type of pixel represents the RGB color space with
+ an alpha channel where an alpha of 0 represents a pixel
+ that is totally transparent and 255 represents a pixel
+ with maximum opacity.
+ - num == 4
+ - has_alpha == true
+ - basic_pixel_type == unsigned char
+ - min() == 0
+ - max() == 255
+ - is_unsigned == true
+ - else if (hsi == true) then
+ - The type T will be a struct with 3 public members of type
+ unsigned char named "h" "s" and "i".
+ - This type of pixel represents the HSI color space.
+ - num == 3
+ - has_alpha == false
+ - basic_pixel_type == unsigned char
+ - min() == 0
+ - max() == 255
+ - is_unsigned == true
+ - else if (lab == true) then
+ - The type T will be a struct with 3 public members of type
+ unsigned char named "l" "a" and "b".
+ - This type of pixel represents the Lab color space.
+ - num == 3
+ - has_alpha == false
+ - basic_pixel_type == unsigned char
+ - min() == 0
+ - max() == 255
+ - is_unsigned == true
+ - else
+ - grayscale == true
+ - This type of pixel represents a grayscale color space. T
+ will be some kind of basic scalar type such as unsigned int.
+ - num == 1
+ - has_alpha == false
+ - basic_pixel_type == T
+ - min() == the minimum obtainable value of objects of type T
+ - max() == the maximum obtainable value of objects of type T
+ - is_unsigned is true if min() == 0 and false otherwise
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct rgb_pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that represents an RGB colored graphical pixel.
+ !*/
+
+ rgb_pixel (
+ ) {}
+
+ rgb_pixel (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) : red(red_), green(green_), blue(blue_) {}
+
+ unsigned char red;
+ unsigned char green;
+ unsigned char blue;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct bgr_pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that represents an BGR colored graphical pixel.
+ (the reason it exists in addition to the rgb_pixel is so you can lay
+ it down on top of a memory region that organizes its color data in the
+ BGR format and still be able to read it)
+ !*/
+
+ bgr_pixel (
+ ) {}
+
+ bgr_pixel (
+ unsigned char blue_,
+ unsigned char green_,
+ unsigned char red_
+ ) : blue(blue_), green(green_), red(red_) {}
+
+ unsigned char blue;
+ unsigned char green;
+ unsigned char red;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct rgb_alpha_pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that represents an RGB colored graphical pixel
+ with an alpha channel.
+ !*/
+
+ rgb_alpha_pixel (
+ ) {}
+
+ rgb_alpha_pixel (
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_,
+ unsigned char alpha_
+ ) : red(red_), green(green_), blue(blue_), alpha(alpha_) {}
+
+ unsigned char red;
+ unsigned char green;
+ unsigned char blue;
+ unsigned char alpha;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct hsi_pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that represents an HSI colored graphical pixel.
+ !*/
+
+ hsi_pixel (
+ ) {}
+
+ hsi_pixel (
+ unsigned char h_,
+ unsigned char s_,
+ unsigned char i_
+ ) : h(h_), s(s_), i(i_) {}
+
+ unsigned char h;
+ unsigned char s;
+ unsigned char i;
+ };
+ // ----------------------------------------------------------------------------------------
+
+ struct lab_pixel
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple struct that represents an Lab colored graphical pixel.
+ !*/
+
+ lab_pixel (
+ ) {}
+
+ lab_pixel (
+ unsigned char l_,
+ unsigned char a_,
+ unsigned char b_
+ ) : l(l_), a(a_), b(b_) {}
+
+ unsigned char l;
+ unsigned char a;
+ unsigned char b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P1,
+ typename P2
+ >
+ inline void assign_pixel (
+ P1& dest,
+ const P2& src
+ );
+ /*!
+ requires
+ - pixel_traits<P1> must be defined
+ - pixel_traits<P2> must be defined
+ ensures
+ - if (P1 and P2 are the same type of pixel) then
+ - simply copies the value of src into dest. In other words,
+ dest will be identical to src after this function returns.
+ - else if (P1 and P2 are not the same type of pixel) then
+ - assigns pixel src to pixel dest and does any necessary color space
+ conversions.
+ - When converting from a grayscale color space with more than 255 values the
+ pixel intensity is saturated at pixel_traits<P1>::max() or pixel_traits<P1>::min()
+ as appropriate.
+ - if (the dest pixel has an alpha channel and the src pixel doesn't) then
+ - #dest.alpha == 255
+ - else if (the src pixel has an alpha channel but the dest pixel doesn't) then
+ - #dest == the original dest value blended with the src value according
+ to the alpha channel in src.
+ (i.e. #dest == src*(alpha/255) + dest*(1-alpha/255))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P
+ >
+ inline typename pixel_traits<P>::basic_pixel_type get_pixel_intensity (
+ const P& src
+ );
+ /*!
+ requires
+ - pixel_traits<P> must be defined
+ ensures
+ - if (pixel_traits<P>::grayscale == true) then
+ - returns src
+ - else
+ - converts src to grayscale and returns the resulting value.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P,
+ typename T
+ >
+ inline void assign_pixel_intensity (
+ P& dest,
+ const T& new_intensity
+ );
+ /*!
+ requires
+ - pixel_traits<P> must be defined
+ - pixel_traits<T> must be defined
+ ensures
+ - This function changes the intensity of the dest pixel. So if the pixel in
+ question is a grayscale pixel then it simply assigns that pixel with the
+ value of get_pixel_intensity(new_intensity). However, if the pixel is not
+ a grayscale pixel then it converts the pixel to the HSI color space and sets
+ the I channel to the given intensity and then converts this HSI value back to
+ the original pixel's color space.
+ - Note that we don't necessarily have #get_pixel_intensity(dest) == get_pixel_intensity(new_intensity)
+ due to vagaries of how converting to and from HSI works out.
+ - if (the dest pixel has an alpha channel) then
+ - #dest.alpha == dest.alpha
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const rgb_pixel& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for the rgb_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ rgb_pixel& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for the rgb_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const bgr_pixel& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for the bgr_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ bgr_pixel& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for the bgr_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const rgb_alpha_pixel& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for the rgb_alpha_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ rgb_alpha_pixel& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for the rgb_alpha_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const hsi_pixel& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for the hsi_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const lab_pixel& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for the lab_pixel struct
+ !*/
+
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ hsi_pixel& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for the hsi_pixel struct
+ !*/
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ lab_pixel& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for the lab_pixel struct
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <>
+ struct pixel_traits<rgb_pixel>
+ {
+ constexpr static bool rgb = true;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = false;
+ enum { num = 3};
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <>
+ struct pixel_traits<bgr_pixel>
+ {
+ constexpr static bool rgb = true;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = false;
+ constexpr static long num = 3;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <>
+ struct pixel_traits<rgb_alpha_pixel>
+ {
+ constexpr static bool rgb = false;
+ constexpr static bool rgb_alpha = true;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = false;
+ constexpr static long num = 4;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = true;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <>
+ struct pixel_traits<hsi_pixel>
+ {
+ constexpr static bool rgb = false;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = true;
+ constexpr static bool lab = false;
+ constexpr static long num = 3;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <>
+ struct pixel_traits<lab_pixel>
+ {
+ constexpr static bool rgb = false;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = false;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = true;
+ constexpr static long num = 3;
+ typedef unsigned char basic_pixel_type;
+ static basic_pixel_type min() { return 0;}
+ static basic_pixel_type max() { return 255;}
+ constexpr static bool is_unsigned = true;
+ constexpr static bool has_alpha = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct grayscale_pixel_traits
+ {
+ constexpr static bool rgb = false;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = true;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = false;
+ constexpr static long num = 1;
+ constexpr static bool has_alpha = false;
+ typedef T basic_pixel_type;
+ static basic_pixel_type min() { return std::numeric_limits<T>::min();}
+ static basic_pixel_type max() { return std::numeric_limits<T>::max();}
+ constexpr static bool is_unsigned = is_unsigned_type<T>::value;
+ };
+
+ template <> struct pixel_traits<unsigned char> : public grayscale_pixel_traits<unsigned char> {};
+ template <> struct pixel_traits<unsigned short> : public grayscale_pixel_traits<unsigned short> {};
+ template <> struct pixel_traits<unsigned int> : public grayscale_pixel_traits<unsigned int> {};
+ template <> struct pixel_traits<unsigned long> : public grayscale_pixel_traits<unsigned long> {};
+
+ template <> struct pixel_traits<char> : public grayscale_pixel_traits<char> {};
+ template <> struct pixel_traits<signed char> : public grayscale_pixel_traits<signed char> {};
+ template <> struct pixel_traits<short> : public grayscale_pixel_traits<short> {};
+ template <> struct pixel_traits<int> : public grayscale_pixel_traits<int> {};
+ template <> struct pixel_traits<long> : public grayscale_pixel_traits<long> {};
+
+ template <> struct pixel_traits<int64> : public grayscale_pixel_traits<int64> {};
+ template <> struct pixel_traits<uint64> : public grayscale_pixel_traits<uint64> {};
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct float_grayscale_pixel_traits
+ {
+ constexpr static bool rgb = false;
+ constexpr static bool rgb_alpha = false;
+ constexpr static bool grayscale = true;
+ constexpr static bool hsi = false;
+ constexpr static bool lab = false;
+ constexpr static long num = 1;
+ constexpr static bool has_alpha = false;
+ typedef T basic_pixel_type;
+ static basic_pixel_type min() { return -std::numeric_limits<T>::max();}
+ static basic_pixel_type max() { return std::numeric_limits<T>::max();}
+ constexpr static bool is_unsigned = false;
+ };
+
+ template <> struct pixel_traits<float> : public float_grayscale_pixel_traits<float> {};
+ template <> struct pixel_traits<double> : public float_grayscale_pixel_traits<double> {};
+ template <> struct pixel_traits<long double> : public float_grayscale_pixel_traits<long double> {};
+
+ // These are here mainly so you can easily copy images into complex arrays. This is
+ // useful when you want to do a FFT on an image or some similar operation.
+ template <> struct pixel_traits<std::complex<float> > : public float_grayscale_pixel_traits<float> {};
+ template <> struct pixel_traits<std::complex<double> > : public float_grayscale_pixel_traits<double> {};
+ template <> struct pixel_traits<std::complex<long double> > : public float_grayscale_pixel_traits<long double> {};
+
+// ----------------------------------------------------------------------------------------
+
+ // The following is a bunch of conversion stuff for the assign_pixel function.
+
+ namespace assign_pixel_helpers
+ {
+
+ // -----------------------------
+ // all the same kind
+
+ template < typename P >
+ typename enable_if_c<pixel_traits<P>::grayscale>::type
+ assign(P& dest, const P& src)
+ {
+ dest = src;
+ }
+
+ // -----------------------------
+
+ template <typename T>
+ typename unsigned_type<T>::type make_unsigned (
+ const T& val
+ ) { return static_cast<typename unsigned_type<T>::type>(val); }
+
+ inline float make_unsigned(const float& val) { return val; }
+ inline double make_unsigned(const double& val) { return val; }
+ inline long double make_unsigned(const long double& val) { return val; }
+
+
+ template <typename T, typename P>
+ typename enable_if_c<pixel_traits<T>::is_unsigned == pixel_traits<P>::is_unsigned, bool>::type less_or_equal_to_max (
+ const P& p
+ )
+ /*!
+ ensures
+ - returns true if p is <= max value of T
+ !*/
+ {
+ return p <= pixel_traits<T>::max();
+ }
+
+ template <typename T, typename P>
+ typename enable_if_c<pixel_traits<T>::is_unsigned && !pixel_traits<P>::is_unsigned, bool>::type less_or_equal_to_max (
+ const P& p
+ )
+ {
+ if (p <= 0)
+ return true;
+ else if (make_unsigned(p) <= pixel_traits<T>::max())
+ return true;
+ else
+ return false;
+ }
+
+ template <typename T, typename P>
+ typename enable_if_c<!pixel_traits<T>::is_unsigned && pixel_traits<P>::is_unsigned, bool>::type less_or_equal_to_max (
+ const P& p
+ )
+ {
+ return p <= make_unsigned(pixel_traits<T>::max());
+ }
+
+ // -----------------------------
+
+ template <typename T, typename P>
+ typename enable_if_c<pixel_traits<P>::is_unsigned, bool >::type greater_or_equal_to_min (
+ const P&
+ ) { return true; }
+ /*!
+ ensures
+ - returns true if p is >= min value of T
+ !*/
+
+ template <typename T, typename P>
+ typename enable_if_c<!pixel_traits<P>::is_unsigned && pixel_traits<T>::is_unsigned, bool >::type greater_or_equal_to_min (
+ const P& p
+ )
+ {
+ return p >= 0;
+ }
+
+ template <typename T, typename P>
+ typename enable_if_c<!pixel_traits<P>::is_unsigned && !pixel_traits<T>::is_unsigned, bool >::type greater_or_equal_to_min (
+ const P& p
+ )
+ {
+ return p >= pixel_traits<T>::min();
+ }
+ // -----------------------------
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::grayscale && pixel_traits<P2>::grayscale>::type
+ assign(P1& dest, const P2& src)
+ {
+ /*
+ The reason for these weird comparison functions is to avoid getting compiler
+ warnings about comparing signed types to unsigned and stuff like that.
+ */
+
+ if (less_or_equal_to_max<P1>(src))
+ if (greater_or_equal_to_min<P1>(src))
+ dest = static_cast<P1>(src);
+ else
+ dest = pixel_traits<P1>::min();
+ else
+ dest = pixel_traits<P1>::max();
+ }
+
+ // -----------------------------
+ // -----------------------------
+ // -----------------------------
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb && pixel_traits<P2>::rgb>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.red = src.red;
+ dest.green = src.green;
+ dest.blue = src.blue;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha && pixel_traits<P2>::rgb_alpha>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.red = src.red;
+ dest.green = src.green;
+ dest.blue = src.blue;
+ dest.alpha = src.alpha;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::hsi && pixel_traits<P2>::hsi>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.h = src.h;
+ dest.s = src.s;
+ dest.i = src.i;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::lab && pixel_traits<P2>::lab>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.l = src.l;
+ dest.a = src.a;
+ dest.b = src.b;
+ }
+
+ // -----------------------------
+ // dest is a grayscale
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::grayscale && pixel_traits<P2>::rgb>::type
+ assign(P1& dest, const P2& src)
+ {
+ const unsigned int temp = ((static_cast<unsigned int>(src.red) +
+ static_cast<unsigned int>(src.green) +
+ static_cast<unsigned int>(src.blue))/3);
+ assign_pixel(dest, temp);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::grayscale && pixel_traits<P2>::rgb_alpha>::type
+ assign(P1& dest, const P2& src)
+ {
+
+ const unsigned char avg = static_cast<unsigned char>((static_cast<unsigned int>(src.red) +
+ static_cast<unsigned int>(src.green) +
+ static_cast<unsigned int>(src.blue))/3);
+
+ if (src.alpha == 255)
+ {
+ assign_pixel(dest, avg);
+ }
+ else
+ {
+ // perform this assignment using fixed point arithmetic:
+ // dest = src*(alpha/255) + dest*(1 - alpha/255);
+ // dest = src*(alpha/255) + dest*1 - dest*(alpha/255);
+ // dest = dest*1 + src*(alpha/255) - dest*(alpha/255);
+ // dest = dest*1 + (src - dest)*(alpha/255);
+ // dest += (src - dest)*(alpha/255);
+
+ int temp = avg;
+ // copy dest into dest_copy using assign_pixel to avoid potential
+ // warnings about implicit float to int warnings.
+ int dest_copy;
+ assign_pixel(dest_copy, dest);
+
+ temp -= dest_copy;
+
+ temp *= src.alpha;
+
+ temp /= 255;
+
+ assign_pixel(dest, temp+dest_copy);
+ }
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::grayscale && pixel_traits<P2>::hsi>::type
+ assign(P1& dest, const P2& src)
+ {
+ assign_pixel(dest, src.i);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::grayscale && pixel_traits<P2>::lab>::type
+ assign(P1& dest, const P2& src)
+ {
+ assign_pixel(dest, src.l);
+ }
+
+
+ // -----------------------------
+
+ struct HSL
+ {
+ double h;
+ double s;
+ double l;
+ };
+
+ struct COLOUR
+ {
+ double r;
+ double g;
+ double b;
+ };
+
+ /*
+ I found this excellent bit of code for dealing with HSL spaces at
+ http://local.wasp.uwa.edu.au/~pbourke/colour/hsl/
+ */
+ /*
+ Calculate HSL from RGB
+ Hue is in degrees
+ Lightness is between 0 and 1
+ Saturation is between 0 and 1
+ */
+ inline HSL RGB2HSL(COLOUR c1)
+ {
+ double themin,themax,delta;
+ HSL c2;
+ using namespace std;
+
+ themin = std::min(c1.r,std::min(c1.g,c1.b));
+ themax = std::max(c1.r,std::max(c1.g,c1.b));
+ delta = themax - themin;
+ c2.l = (themin + themax) / 2;
+ c2.s = 0;
+ if (c2.l > 0 && c2.l < 1)
+ c2.s = delta / (c2.l < 0.5 ? (2*c2.l) : (2-2*c2.l));
+ c2.h = 0;
+ if (delta > 0) {
+ if (themax == c1.r && themax != c1.g)
+ c2.h += (c1.g - c1.b) / delta;
+ if (themax == c1.g && themax != c1.b)
+ c2.h += (2 + (c1.b - c1.r) / delta);
+ if (themax == c1.b && themax != c1.r)
+ c2.h += (4 + (c1.r - c1.g) / delta);
+ c2.h *= 60;
+ }
+ return(c2);
+ }
+
+ /*
+ Calculate RGB from HSL, reverse of RGB2HSL()
+ Hue is in degrees
+ Lightness is between 0 and 1
+ Saturation is between 0 and 1
+ */
+ inline COLOUR HSL2RGB(HSL c1)
+ {
+ COLOUR c2,sat,ctmp;
+ using namespace std;
+
+ if (c1.h < 120) {
+ sat.r = (120 - c1.h) / 60.0;
+ sat.g = c1.h / 60.0;
+ sat.b = 0;
+ } else if (c1.h < 240) {
+ sat.r = 0;
+ sat.g = (240 - c1.h) / 60.0;
+ sat.b = (c1.h - 120) / 60.0;
+ } else {
+ sat.r = (c1.h - 240) / 60.0;
+ sat.g = 0;
+ sat.b = (360 - c1.h) / 60.0;
+ }
+ sat.r = std::min(sat.r,1.0);
+ sat.g = std::min(sat.g,1.0);
+ sat.b = std::min(sat.b,1.0);
+
+ ctmp.r = 2 * c1.s * sat.r + (1 - c1.s);
+ ctmp.g = 2 * c1.s * sat.g + (1 - c1.s);
+ ctmp.b = 2 * c1.s * sat.b + (1 - c1.s);
+
+ if (c1.l < 0.5) {
+ c2.r = c1.l * ctmp.r;
+ c2.g = c1.l * ctmp.g;
+ c2.b = c1.l * ctmp.b;
+ } else {
+ c2.r = (1 - c1.l) * ctmp.r + 2 * c1.l - 1;
+ c2.g = (1 - c1.l) * ctmp.g + 2 * c1.l - 1;
+ c2.b = (1 - c1.l) * ctmp.b + 2 * c1.l - 1;
+ }
+
+ return(c2);
+ }
+
+ // -----------------------------
+
+ struct Lab
+ {
+ double l;
+ double a;
+ double b;
+ };
+ /*
+ Calculate Lab from RGB
+ L is between 0 and 100
+ a is between -128 and 127
+ b is between -128 and 127
+ RGB is between 0.0 and 1.0
+ */
+ inline Lab RGB2Lab(COLOUR c1)
+ {
+ Lab c2;
+ using namespace std;
+
+ double var_R = c1.r;
+ double var_G = c1.g;
+ double var_B = c1.b;
+
+ if (var_R > 0.04045) {
+ var_R = pow(((var_R + 0.055) / 1.055), 2.4);
+ } else {
+ var_R = var_R / 12.92;
+ }
+
+ if (var_G > 0.04045) {
+ var_G = pow(((var_G + 0.055) / 1.055), 2.4);
+ } else {
+ var_G = var_G / 12.92;
+ }
+
+ if (var_B > 0.04045) {
+ var_B = pow(((var_B + 0.055) / 1.055), 2.4);
+ } else {
+ var_B = var_B / 12.92;
+ }
+
+ var_R = var_R * 100;
+ var_G = var_G * 100;
+ var_B = var_B * 100;
+
+//Observer. = 2°, Illuminant = D65
+ double X = var_R * 0.4124 + var_G * 0.3576 + var_B * 0.1805;
+ double Y = var_R * 0.2126 + var_G * 0.7152 + var_B * 0.0722;
+ double Z = var_R * 0.0193 + var_G * 0.1192 + var_B * 0.9505;
+
+ double var_X = X / 95.047;
+ double var_Y = Y / 100.000;
+ double var_Z = Z / 108.883;
+
+ if (var_X > 0.008856) {
+ var_X = pow(var_X, (1.0 / 3));
+ }
+ else {
+ var_X = (7.787 * var_X) + (16.0 / 116);
+ }
+
+ if (var_Y > 0.008856) {
+ var_Y = pow(var_Y, (1.0 / 3));
+ }
+ else {
+ var_Y = (7.787 * var_Y) + (16.0 / 116);
+ }
+
+ if (var_Z > 0.008856) {
+ var_Z = pow(var_Z, (1.0 / 3));
+ }
+ else {
+ var_Z = (7.787 * var_Z) + (16.0 / 116);
+ }
+
+ //clamping
+ c2.l = max(0.0, (116.0 * var_Y) - 16);
+ c2.a = max(-128.0, min(127.0, 500.0 * (var_X - var_Y)));
+ c2.b = max(-128.0, min(127.0, 200.0 * (var_Y - var_Z)));
+
+ return c2;
+ }
+
+ /*
+ Calculate RGB from Lab, reverse of RGB2LAb()
+ L is between 0 and 100
+ a is between -128 and 127
+ b is between -128 and 127
+ RGB is between 0.0 and 1.0
+ */
+ inline COLOUR Lab2RGB(Lab c1) {
+ COLOUR c2;
+ using namespace std;
+
+ double var_Y = (c1.l + 16) / 116.0;
+ double var_X = (c1.a / 500.0) + var_Y;
+ double var_Z = var_Y - (c1.b / 200);
+
+ if (pow(var_Y, 3) > 0.008856) {
+ var_Y = pow(var_Y, 3);
+ } else {
+ var_Y = (var_Y - 16.0 / 116) / 7.787;
+ }
+
+ if (pow(var_X, 3) > 0.008856) {
+ var_X = pow(var_X, 3);
+ } else {
+ var_X = (var_X - 16.0 / 116) / 7.787;
+ }
+
+ if (pow(var_Z, 3) > 0.008856) {
+ var_Z = pow(var_Z, 3);
+ } else {
+ var_Z = (var_Z - 16.0 / 116) / 7.787;
+ }
+
+ double X = var_X * 95.047;
+ double Y = var_Y * 100.000;
+ double Z = var_Z * 108.883;
+
+ var_X = X / 100.0;
+ var_Y = Y / 100.0;
+ var_Z = Z / 100.0;
+
+ double var_R = var_X * 3.2406 + var_Y * -1.5372 + var_Z * -0.4986;
+ double var_G = var_X * -0.9689 + var_Y * 1.8758 + var_Z * 0.0415;
+ double var_B = var_X * 0.0557 + var_Y * -0.2040 + var_Z * 1.0570;
+
+ if (var_R > 0.0031308) {
+ var_R = 1.055 * pow(var_R, (1 / 2.4)) - 0.055;
+ } else {
+ var_R = 12.92 * var_R;
+ }
+
+ if (var_G > 0.0031308) {
+ var_G = 1.055 * pow(var_G, (1 / 2.4)) - 0.055;
+ } else {
+ var_G = 12.92 * var_G;
+ }
+
+ if (var_B > 0.0031308) {
+ var_B = 1.055 * pow(var_B, (1 / 2.4)) - 0.055;
+ } else {
+ var_B = 12.92 * var_B;
+ }
+
+ // clamping
+ c2.r = max(0.0, min(1.0, var_R));
+ c2.g = max(0.0, min(1.0, var_G));
+ c2.b = max(0.0, min(1.0, var_B));
+
+ return (c2);
+ }
+
+
+ // -----------------------------
+ // dest is a color rgb_pixel
+
+ template < typename P1 >
+ typename enable_if_c<pixel_traits<P1>::rgb>::type
+ assign(P1& dest, const unsigned char& src)
+ {
+ dest.red = src;
+ dest.green = src;
+ dest.blue = src;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb && pixel_traits<P2>::grayscale>::type
+ assign(P1& dest, const P2& src)
+ {
+ unsigned char p;
+ assign_pixel(p, src);
+ dest.red = p;
+ dest.green = p;
+ dest.blue = p;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb && pixel_traits<P2>::rgb_alpha>::type
+ assign(P1& dest, const P2& src)
+ {
+ if (src.alpha == 255)
+ {
+ dest.red = src.red;
+ dest.green = src.green;
+ dest.blue = src.blue;
+ }
+ else
+ {
+ // perform this assignment using fixed point arithmetic:
+ // dest = src*(alpha/255) + dest*(1 - alpha/255);
+ // dest = src*(alpha/255) + dest*1 - dest*(alpha/255);
+ // dest = dest*1 + src*(alpha/255) - dest*(alpha/255);
+ // dest = dest*1 + (src - dest)*(alpha/255);
+ // dest += (src - dest)*(alpha/255);
+
+ unsigned int temp_r = src.red;
+ unsigned int temp_g = src.green;
+ unsigned int temp_b = src.blue;
+
+ temp_r -= dest.red;
+ temp_g -= dest.green;
+ temp_b -= dest.blue;
+
+ temp_r *= src.alpha;
+ temp_g *= src.alpha;
+ temp_b *= src.alpha;
+
+ temp_r >>= 8;
+ temp_g >>= 8;
+ temp_b >>= 8;
+
+ dest.red += static_cast<unsigned char>(temp_r&0xFF);
+ dest.green += static_cast<unsigned char>(temp_g&0xFF);
+ dest.blue += static_cast<unsigned char>(temp_b&0xFF);
+ }
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb && pixel_traits<P2>::hsi>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c;
+ HSL h;
+ h.h = src.h;
+ h.h = h.h/255.0*360;
+ h.s = src.s/255.0;
+ h.l = src.i/255.0;
+ c = HSL2RGB(h);
+
+ dest.red = static_cast<unsigned char>(c.r*255.0 + 0.5);
+ dest.green = static_cast<unsigned char>(c.g*255.0 + 0.5);
+ dest.blue = static_cast<unsigned char>(c.b*255.0 + 0.5);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb && pixel_traits<P2>::lab>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c;
+ Lab l;
+ l.l = (src.l/255.0)*100;
+ l.a = (src.a-128.0);
+ l.b = (src.b-128.0);
+ c = Lab2RGB(l);
+
+ dest.red = static_cast<unsigned char>(c.r*255.0 + 0.5);
+ dest.green = static_cast<unsigned char>(c.g*255.0 + 0.5);
+ dest.blue = static_cast<unsigned char>(c.b*255.0 + 0.5);
+ }
+
+
+ // -----------------------------
+ // dest is a color rgb_alpha_pixel
+
+ template < typename P1 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha>::type
+ assign(P1& dest, const unsigned char& src)
+ {
+ dest.red = src;
+ dest.green = src;
+ dest.blue = src;
+ dest.alpha = 255;
+ }
+
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha && pixel_traits<P2>::grayscale>::type
+ assign(P1& dest, const P2& src)
+ {
+ unsigned char p;
+ assign_pixel(p, src);
+
+ dest.red = p;
+ dest.green = p;
+ dest.blue = p;
+ dest.alpha = 255;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha && pixel_traits<P2>::rgb>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.red = src.red;
+ dest.green = src.green;
+ dest.blue = src.blue;
+ dest.alpha = 255;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha && pixel_traits<P2>::hsi>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c;
+ HSL h;
+ h.h = src.h;
+ h.h = h.h/255.0*360;
+ h.s = src.s/255.0;
+ h.l = src.i/255.0;
+ c = HSL2RGB(h);
+
+ dest.red = static_cast<unsigned char>(c.r*255.0 + 0.5);
+ dest.green = static_cast<unsigned char>(c.g*255.0 + 0.5);
+ dest.blue = static_cast<unsigned char>(c.b*255.0 + 0.5);
+ dest.alpha = 255;
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::rgb_alpha && pixel_traits<P2>::lab>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c;
+ Lab l;
+ l.l = (src.l/255.0)*100;
+ l.a = (src.a-128.0);
+ l.b = (src.b-128.0);
+ c = Lab2RGB(l);
+
+ dest.red = static_cast<unsigned char>(c.r * 255 + 0.5);
+ dest.green = static_cast<unsigned char>(c.g * 255 + 0.5);
+ dest.blue = static_cast<unsigned char>(c.b * 255 + 0.5);
+ dest.alpha = 255;
+ }
+ // -----------------------------
+ // dest is an hsi pixel
+
+ template < typename P1>
+ typename enable_if_c<pixel_traits<P1>::hsi>::type
+ assign(P1& dest, const unsigned char& src)
+ {
+ dest.h = 0;
+ dest.s = 0;
+ dest.i = src;
+ }
+
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::hsi && pixel_traits<P2>::grayscale>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.h = 0;
+ dest.s = 0;
+ assign_pixel(dest.i, src);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::hsi && pixel_traits<P2>::rgb>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c1;
+ HSL c2;
+ c1.r = src.red/255.0;
+ c1.g = src.green/255.0;
+ c1.b = src.blue/255.0;
+ c2 = RGB2HSL(c1);
+
+ dest.h = static_cast<unsigned char>(c2.h/360.0*255.0 + 0.5);
+ dest.s = static_cast<unsigned char>(c2.s*255.0 + 0.5);
+ dest.i = static_cast<unsigned char>(c2.l*255.0 + 0.5);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::hsi && pixel_traits<P2>::rgb_alpha>::type
+ assign(P1& dest, const P2& src)
+ {
+ rgb_pixel temp;
+ // convert target hsi pixel to rgb
+ assign_pixel_helpers::assign(temp,dest);
+
+ // now assign the rgb_alpha value to our temp rgb pixel
+ assign_pixel_helpers::assign(temp,src);
+
+ // now we can just go assign the new rgb value to the
+ // hsi pixel
+ assign_pixel_helpers::assign(dest,temp);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::hsi && pixel_traits<P2>::lab>::type
+ assign(P1& dest, const P2& src)
+ {
+ rgb_pixel temp;
+ // convert lab value to our temp rgb pixel
+ assign_pixel_helpers::assign(temp,src);
+ // now we can just go assign the new rgb value to the
+ // hsi pixel
+ assign_pixel_helpers::assign(dest,temp);
+ }
+
+ // -----------------------------
+ // dest is an lab pixel
+ template < typename P1>
+ typename enable_if_c<pixel_traits<P1>::lab>::type
+ assign(P1& dest, const unsigned char& src)
+ {
+ dest.a = 128;
+ dest.b = 128;
+ dest.l = src;
+ }
+
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::lab && pixel_traits<P2>::grayscale>::type
+ assign(P1& dest, const P2& src)
+ {
+ dest.a = 128;
+ dest.b = 128;
+ assign_pixel(dest.l, src);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::lab && pixel_traits<P2>::rgb>::type
+ assign(P1& dest, const P2& src)
+ {
+ COLOUR c1;
+ Lab c2;
+ c1.r = src.red / 255.0;
+ c1.g = src.green / 255.0;
+ c1.b = src.blue / 255.0;
+ c2 = RGB2Lab(c1);
+
+ dest.l = static_cast<unsigned char>((c2.l / 100) * 255 + 0.5);
+ dest.a = static_cast<unsigned char>(c2.a + 128 + 0.5);
+ dest.b = static_cast<unsigned char>(c2.b + 128 + 0.5);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::lab && pixel_traits<P2>::rgb_alpha>::type
+ assign(P1& dest, const P2& src)
+ {
+ rgb_pixel temp;
+ // convert target lab pixel to rgb
+ assign_pixel_helpers::assign(temp,dest);
+
+ // now assign the rgb_alpha value to our temp rgb pixel
+ assign_pixel_helpers::assign(temp,src);
+
+ // now we can just go assign the new rgb value to the
+ // lab pixel
+ assign_pixel_helpers::assign(dest,temp);
+ }
+
+ template < typename P1, typename P2 >
+ typename enable_if_c<pixel_traits<P1>::lab && pixel_traits<P2>::hsi>::type
+ assign(P1& dest, const P2& src)
+ {
+ rgb_pixel temp;
+
+ // convert hsi value to our temp rgb pixel
+ assign_pixel_helpers::assign(temp,src);
+
+ // now we can just go assign the new rgb value to the
+ // lab pixel
+ assign_pixel_helpers::assign(dest,temp);
+ }
+ }
+
+ // -----------------------------
+
+ template < typename P1, typename P2 >
+ inline void assign_pixel (
+ P1& dest,
+ const P2& src
+ ) { assign_pixel_helpers::assign(dest,src); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P,
+ typename T
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale>::type assign_pixel_intensity_helper (
+ P& dest,
+ const T& new_intensity
+ )
+ {
+ assign_pixel(dest, new_intensity);
+ }
+
+ template <
+ typename P,
+ typename T
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale == false &&
+ pixel_traits<P>::has_alpha>::type assign_pixel_intensity_helper (
+ P& dest,
+ const T& new_intensity
+ )
+ {
+ hsi_pixel p;
+ const unsigned long old_alpha = dest.alpha;
+ dest.alpha = 255;
+ rgb_pixel temp;
+ assign_pixel(temp, dest); // put dest into an rgb_pixel to avoid the somewhat complicated assign_pixel(hsi,rgb_alpha).
+ assign_pixel(p,temp);
+ assign_pixel(p.i, new_intensity);
+ assign_pixel(dest,p);
+ dest.alpha = old_alpha;
+ }
+
+ template <
+ typename P,
+ typename T
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale == false &&
+ pixel_traits<P>::has_alpha == false>::type assign_pixel_intensity_helper (
+ P& dest,
+ const T& new_intensity
+ )
+ {
+ hsi_pixel p;
+ assign_pixel(p,dest);
+ assign_pixel(p.i, new_intensity);
+ assign_pixel(dest,p);
+ }
+
+ template <
+ typename P,
+ typename T
+ >
+ inline void assign_pixel_intensity (
+ P& dest,
+ const T& new_intensity
+ )
+ {
+ assign_pixel_intensity_helper(dest, new_intensity);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename P
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale, P>::type get_pixel_intensity_helper (
+ const P& src
+ )
+ {
+ return src;
+ }
+
+ template <
+ typename P
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale == false&&
+ pixel_traits<P>::has_alpha,
+ typename pixel_traits<P>::basic_pixel_type>::type get_pixel_intensity_helper (
+ const P& src
+ )
+ {
+ P temp = src;
+ temp.alpha = 255;
+ typename pixel_traits<P>::basic_pixel_type p;
+ assign_pixel(p,temp);
+ return p;
+ }
+
+ template <
+ typename P
+ >
+ inline typename enable_if_c<pixel_traits<P>::grayscale == false&&
+ pixel_traits<P>::has_alpha == false,
+ typename pixel_traits<P>::basic_pixel_type>::type get_pixel_intensity_helper (
+ const P& src
+ )
+ {
+ typename pixel_traits<P>::basic_pixel_type p;
+ assign_pixel(p,src);
+ return p;
+ }
+
+ template <
+ typename P
+ >
+ inline typename pixel_traits<P>::basic_pixel_type get_pixel_intensity (
+ const P& src
+ )
+ {
+ return get_pixel_intensity_helper(src);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const rgb_alpha_pixel& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.red,out);
+ serialize(item.green,out);
+ serialize(item.blue,out);
+ serialize(item.alpha,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type rgb_alpha_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ rgb_alpha_pixel& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.red,in);
+ deserialize(item.green,in);
+ deserialize(item.blue,in);
+ deserialize(item.alpha,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type rgb_alpha_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const rgb_pixel& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.red,out);
+ serialize(item.green,out);
+ serialize(item.blue,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type rgb_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ rgb_pixel& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.red,in);
+ deserialize(item.green,in);
+ deserialize(item.blue,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type rgb_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const bgr_pixel& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.blue,out);
+ serialize(item.green,out);
+ serialize(item.red,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type bgr_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ bgr_pixel& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.blue,in);
+ deserialize(item.green,in);
+ deserialize(item.red,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type bgr_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const hsi_pixel& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.h,out);
+ serialize(item.s,out);
+ serialize(item.i,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type hsi_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ hsi_pixel& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.h,in);
+ deserialize(item.s,in);
+ deserialize(item.i,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type hsi_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const lab_pixel& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.l,out);
+ serialize(item.a,out);
+ serialize(item.b,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type lab_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void deserialize (
+ lab_pixel& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.l,in);
+ deserialize(item.a,in);
+ deserialize(item.b,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type lab_pixel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PIXEl_
+
diff --git a/ml/dlib/dlib/platform.h b/ml/dlib/dlib/platform.h
new file mode 100644
index 000000000..f3000a6cd
--- /dev/null
+++ b/ml/dlib/dlib/platform.h
@@ -0,0 +1,65 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_PLATFORm_
+#define DLIB_PLATFORm_
+
+
+/*!
+ This file ensures that:
+ - if (we are compiling under a posix platform) then
+ - POSIX will be defined
+ - if (this is also Mac OS X) then
+ - MACOSX will be defined
+ - if (this is also HP-UX) then
+ - HPUX will be defined
+ - if (we are compiling under an MS Windows platform) then
+ - WIN32 will be defined
+!*/
+
+
+/*
+ A good reference for this sort of information is
+ http://predef.sourceforge.net/
+*/
+
+// Define WIN32 if this is MS Windows
+#ifndef WIN32
+ #if defined( _MSC_VER) || defined(__BORLANDC__) || defined(_WIN32) || defined(__WIN32__) || defined(__TOS_WIN__)
+ #define WIN32
+ #endif
+#endif
+
+#ifndef WIN32
+ // since this is the only other platform the library currently supports
+ // just assume it is POSIX if it isn't WIN32
+ #ifndef POSIX
+ #define POSIX
+ #endif
+
+ #ifndef HPUX
+ #if defined(__hpux ) || defined(hpux) || defined (_hpux)
+ #define HPUX
+ #endif
+ #endif
+
+ #ifndef MACOSX
+ #ifdef __MACOSX__
+ #define MACOSX
+ #endif
+ #ifdef __APPLE__
+ #define MACOSX
+ #endif
+ #endif
+
+#endif
+
+
+
+
+#endif // DLIB_PLATFORm_
+
diff --git a/ml/dlib/dlib/python.h b/ml/dlib/dlib/python.h
new file mode 100644
index 000000000..07b9c0707
--- /dev/null
+++ b/ml/dlib/dlib/python.h
@@ -0,0 +1,14 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYTHoN_TOP_
+#define DLIB_PYTHoN_TOP_
+
+#include "python/pybind_utils.h"
+#include "python/pyassert.h"
+#include "python/serialize_pickle.h"
+#include "python/numpy.h"
+#include "python/numpy_image.h"
+
+#endif // DLIB_PYTHoN_TOP_
+
+
diff --git a/ml/dlib/dlib/python/numpy.h b/ml/dlib/dlib/python/numpy.h
new file mode 100644
index 000000000..9b2c1a01c
--- /dev/null
+++ b/ml/dlib/dlib/python/numpy.h
@@ -0,0 +1,214 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYTHON_NuMPY_Hh_
+#define DLIB_PYTHON_NuMPY_Hh_
+
+#include <pybind11/pybind11.h>
+#include <dlib/error.h>
+#include <dlib/algs.h>
+#include <dlib/string.h>
+#include <dlib/array.h>
+#include <dlib/pixel.h>
+
+namespace py = pybind11;
+
+// ----------------------------------------------------------------------------------------
+
+template <typename TT>
+void validate_numpy_array_type (
+ const py::object& obj
+)
+{
+ const char ch = obj.attr("dtype").attr("char").cast<char>();
+
+ using T = typename dlib::pixel_traits<TT>::basic_pixel_type;
+
+ if (dlib::is_same_type<T,double>::value)
+ {
+ if (ch != 'd')
+ throw dlib::error("Expected numpy.ndarray of float64");
+ }
+ else if (dlib::is_same_type<T,float>::value)
+ {
+ if (ch != 'f')
+ throw dlib::error("Expected numpy.ndarray of float32");
+ }
+ else if (dlib::is_same_type<T,dlib::int16>::value)
+ {
+ if (ch != 'h')
+ throw dlib::error("Expected numpy.ndarray of int16");
+ }
+ else if (dlib::is_same_type<T,dlib::uint16>::value)
+ {
+ if (ch != 'H')
+ throw dlib::error("Expected numpy.ndarray of uint16");
+ }
+ else if (dlib::is_same_type<T,dlib::int32>::value)
+ {
+ if (ch != 'i')
+ throw dlib::error("Expected numpy.ndarray of int32");
+ }
+ else if (dlib::is_same_type<T,dlib::uint32>::value)
+ {
+ if (ch != 'I')
+ throw dlib::error("Expected numpy.ndarray of uint32");
+ }
+ else if (dlib::is_same_type<T,unsigned char>::value)
+ {
+ if (ch != 'B')
+ throw dlib::error("Expected numpy.ndarray of uint8");
+ }
+ else if (dlib::is_same_type<T,signed char>::value)
+ {
+ if (ch != 'b')
+ throw dlib::error("Expected numpy.ndarray of int8");
+ }
+ else
+ {
+ throw dlib::error("validate_numpy_array_type() called with unsupported type.");
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <int dims>
+void get_numpy_ndarray_shape (
+ const py::object& obj,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - stores the shape of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES ))
+ throw dlib::error("Expected numpy.ndarray with shape set.");
+
+ try
+ {
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+
+ for (int i = 0; i < dims; ++i)
+ {
+ if (i < pybuf.ndim)
+ shape[i] = pybuf.shape[i];
+ else
+ shape[i] = 1;
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T, int dims>
+void get_numpy_ndarray_parts (
+ py::object& obj,
+ T*& data,
+ dlib::array<T>& contig_buf,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - extracts the pointer to the data from the given numpy ndarray. Stores the shape
+ of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+ - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES | PyBUF_WRITABLE ))
+ throw dlib::error("Expected writable numpy.ndarray with shape set.");
+
+ try
+ {
+ validate_numpy_array_type<T>(obj);
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+ get_numpy_ndarray_shape(obj, shape);
+
+ if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1])
+ throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels.");
+
+ if (PyBuffer_IsContiguous(&pybuf, 'C'))
+ data = (T*)pybuf.buf;
+ else
+ {
+ contig_buf.resize(pybuf.len);
+ if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C'))
+ throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer.");
+ data = &contig_buf[0];
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T, int dims>
+void get_numpy_ndarray_parts (
+ const py::object& obj,
+ const T*& data,
+ dlib::array<T>& contig_buf,
+ long (&shape)[dims]
+)
+/*!
+ ensures
+ - extracts the pointer to the data from the given numpy ndarray. Stores the shape
+ of the array into #shape.
+ - the dimension of the given numpy array is not greater than #dims.
+ - #shape[#dims-1] == pixel_traits<T>::num when #dims is greater than 2
+!*/
+{
+ Py_buffer pybuf;
+ if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES ))
+ throw dlib::error("Expected numpy.ndarray with shape set.");
+
+ try
+ {
+ validate_numpy_array_type<T>(obj);
+
+ if (pybuf.ndim > dims)
+ throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions.");
+ get_numpy_ndarray_shape(obj, shape);
+
+ if (dlib::pixel_traits<T>::num > 1 && dlib::pixel_traits<T>::num != shape[dims-1])
+ throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits<T>::num) + " channels.");
+
+ if (PyBuffer_IsContiguous(&pybuf, 'C'))
+ data = (const T*)pybuf.buf;
+ else
+ {
+ contig_buf.resize(pybuf.len);
+ if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C'))
+ throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer.");
+ data = &contig_buf[0];
+ }
+ }
+ catch(...)
+ {
+ PyBuffer_Release(&pybuf);
+ throw;
+ }
+ PyBuffer_Release(&pybuf);
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_PYTHON_NuMPY_Hh_
+
diff --git a/ml/dlib/dlib/python/numpy_image.h b/ml/dlib/dlib/python/numpy_image.h
new file mode 100644
index 000000000..49ea80317
--- /dev/null
+++ b/ml/dlib/dlib/python/numpy_image.h
@@ -0,0 +1,129 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYTHON_NuMPY_IMAGE_Hh_
+#define DLIB_PYTHON_NuMPY_IMAGE_Hh_
+
+#include "numpy.h"
+#include <dlib/pixel.h>
+#include <dlib/matrix.h>
+#include <dlib/array.h>
+
+
+// ----------------------------------------------------------------------------------------
+
+class numpy_gray_image
+{
+public:
+
+ numpy_gray_image() : _data(0), _nr(0), _nc(0) {}
+ numpy_gray_image (py::object& img)
+ {
+ long shape[2];
+ get_numpy_ndarray_parts(img, _data, _contig_buf, shape);
+ _nr = shape[0];
+ _nc = shape[1];
+ }
+
+ friend inline long num_rows(const numpy_gray_image& img) { return img._nr; }
+ friend inline long num_columns(const numpy_gray_image& img) { return img._nc; }
+ friend inline void* image_data(numpy_gray_image& img) { return img._data; }
+ friend inline const void* image_data(const numpy_gray_image& img) { return img._data; }
+ friend inline long width_step(const numpy_gray_image& img) { return img._nc*sizeof(unsigned char); }
+
+private:
+
+ unsigned char* _data;
+ dlib::array<unsigned char> _contig_buf;
+ long _nr;
+ long _nc;
+};
+
+namespace dlib
+{
+ template <>
+ struct image_traits<numpy_gray_image >
+ {
+ typedef unsigned char pixel_type;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+inline bool is_gray_python_image (py::object& img)
+{
+ try
+ {
+ long shape[2];
+ get_numpy_ndarray_shape(img, shape);
+ return true;
+ }
+ catch (dlib::error&)
+ {
+ return false;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+class numpy_rgb_image
+{
+public:
+
+ numpy_rgb_image() : _data(0), _nr(0), _nc(0) {}
+ numpy_rgb_image (py::object& img)
+ {
+ long shape[3];
+ get_numpy_ndarray_parts(img, _data, _contig_buf, shape);
+ _nr = shape[0];
+ _nc = shape[1];
+ if (shape[2] != 3)
+ throw dlib::error("Error, python object is not a three band image and therefore can't be a RGB image.");
+ }
+
+ friend inline long num_rows(const numpy_rgb_image& img) { return img._nr; }
+ friend inline long num_columns(const numpy_rgb_image& img) { return img._nc; }
+ friend inline void* image_data(numpy_rgb_image& img) { return img._data; }
+ friend inline const void* image_data(const numpy_rgb_image& img) { return img._data; }
+ friend inline long width_step(const numpy_rgb_image& img) { return img._nc*sizeof(dlib::rgb_pixel); }
+
+
+private:
+
+ dlib::rgb_pixel* _data;
+ dlib::array<dlib::rgb_pixel> _contig_buf;
+ long _nr;
+ long _nc;
+};
+
+namespace dlib
+{
+ template <>
+ struct image_traits<numpy_rgb_image >
+ {
+ typedef rgb_pixel pixel_type;
+ };
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+inline bool is_rgb_python_image (py::object& img)
+{
+ try
+ {
+ long shape[3];
+ get_numpy_ndarray_shape(img, shape);
+ if (shape[2] == 3)
+ return true;
+ return false;
+ }
+ catch (dlib::error&)
+ {
+ return false;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_PYTHON_NuMPY_IMAGE_Hh_
+
diff --git a/ml/dlib/dlib/python/pyassert.h b/ml/dlib/dlib/python/pyassert.h
new file mode 100644
index 000000000..80939f501
--- /dev/null
+++ b/ml/dlib/dlib/python/pyassert.h
@@ -0,0 +1,17 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYaSSERT_Hh_
+#define DLIB_PYaSSERT_Hh_
+
+#include <pybind11/pybind11.h>
+
+#define pyassert(_exp,_message) \
+ {if ( !(_exp) ) \
+ { \
+ namespace py = pybind11; \
+ PyErr_SetString( PyExc_ValueError, _message ); \
+ throw py::error_already_set(); \
+ }}
+
+#endif // DLIB_PYaSSERT_Hh_
+
diff --git a/ml/dlib/dlib/python/pybind_utils.h b/ml/dlib/dlib/python/pybind_utils.h
new file mode 100644
index 000000000..7f94cf32d
--- /dev/null
+++ b/ml/dlib/dlib/python/pybind_utils.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PYBIND_UtILS_Hh_
+#define DLIB_PYBIND_UtILS_Hh_
+
+#include <pybind11/pybind11.h>
+#include <vector>
+#include <string>
+#include <dlib/serialize.h>
+
+namespace py = pybind11;
+
+template <typename T>
+std::vector<T> python_list_to_vector (
+ const py::list& obj
+)
+/*!
+ ensures
+ - converts a python object into a std::vector<T> and returns it.
+!*/
+{
+ std::vector<T> vect(len(obj));
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ {
+ vect[i] = obj[i].cast<T>();
+ }
+ return vect;
+}
+
+template <typename T>
+py::list vector_to_python_list (
+ const std::vector<T>& vect
+)
+/*!
+ ensures
+ - converts a std::vector<T> into a python list object.
+!*/
+{
+ py::list obj;
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ obj.append(vect[i]);
+ return obj;
+}
+
+template <typename T>
+void extend_vector_with_python_list (
+ std::vector<T> &v,
+ const py::list &l
+)
+/*!
+ ensures
+ - appends items from a python list to the end of std::vector<T>.
+!*/
+{
+ for (const auto &item : l)
+ v.push_back(item.cast<T>());
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename T>
+std::shared_ptr<T> load_object_from_file (
+ const std::string& filename
+)
+/*!
+ ensures
+ - deserializes an object of type T from the given file and returns it.
+!*/
+{
+ std::ifstream fin(filename.c_str(), std::ios::binary);
+ if (!fin)
+ throw dlib::error("Unable to open " + filename);
+ auto obj = std::make_shared<T>();
+ deserialize(*obj, fin);
+ return obj;
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+#endif // DLIB_PYBIND_UtILS_Hh_
+
diff --git a/ml/dlib/dlib/python/serialize_pickle.h b/ml/dlib/dlib/python/serialize_pickle.h
new file mode 100644
index 000000000..2dc44c322
--- /dev/null
+++ b/ml/dlib/dlib/python/serialize_pickle.h
@@ -0,0 +1,66 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERIALIZE_PiCKLE_Hh_
+#define DLIB_SERIALIZE_PiCKLE_Hh_
+
+#include <dlib/serialize.h>
+#include <pybind11/pybind11.h>
+#include <sstream>
+#include <dlib/vectorstream.h>
+
+template<typename T>
+py::tuple getstate(const T& item)
+{
+ using namespace dlib;
+ std::vector<char> buf;
+ buf.reserve(5000);
+ vectorstream sout(buf);
+ serialize(item, sout);
+ return py::make_tuple(py::handle(
+ PyBytes_FromStringAndSize(buf.size()?&buf[0]:0, buf.size())));
+}
+
+template<typename T>
+T setstate(py::tuple state)
+{
+ using namespace dlib;
+ if (len(state) != 1)
+ {
+ PyErr_SetObject(PyExc_ValueError,
+ py::str("expected 1-item tuple in call to __setstate__; got {}").format(state).ptr()
+ );
+ throw py::error_already_set();
+ }
+
+ // We used to serialize by converting to a str but the boost.python routines for
+ // doing this don't work in Python 3. You end up getting an error about invalid
+ // UTF-8 encodings. So instead we access the python C interface directly and use
+ // bytes objects. However, we keep the deserialization code that worked with str
+ // for backwards compatibility with previously pickled files.
+ T item;
+ py::object obj = state[0];
+ if (py::isinstance<py::str>(obj))
+ {
+ py::str data = state[0].cast<py::str>();
+ std::string temp = data;
+ std::istringstream sin(temp);
+ deserialize(item, sin);
+ }
+ else if(PyBytes_Check(py::object(state[0]).ptr()))
+ {
+ py::object obj = state[0];
+ char* data = PyBytes_AsString(obj.ptr());
+ unsigned long num = PyBytes_Size(obj.ptr());
+ std::istringstream sin(std::string(data, num));
+ deserialize(item, sin);
+ }
+ else
+ {
+ throw error("Unable to unpickle, error in input file.");
+ }
+
+ return item;
+}
+
+#endif // DLIB_SERIALIZE_PiCKLE_Hh_
+
diff --git a/ml/dlib/dlib/quantum_computing.h b/ml/dlib/dlib/quantum_computing.h
new file mode 100644
index 000000000..11af76197
--- /dev/null
+++ b/ml/dlib/dlib/quantum_computing.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUANTUM_COMPUTINg_H_
+#define DLIB_QUANTUM_COMPUTINg_H_
+
+#include "quantum_computing/quantum_computing.h"
+
+#endif // DLIB_QUANTUM_COMPUTINg_H_
+
+
+
+
diff --git a/ml/dlib/dlib/quantum_computing/quantum_computing.h b/ml/dlib/dlib/quantum_computing/quantum_computing.h
new file mode 100644
index 000000000..afa2e40e7
--- /dev/null
+++ b/ml/dlib/dlib/quantum_computing/quantum_computing.h
@@ -0,0 +1,863 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUANTUM_COMPUTINg_1_
+#define DLIB_QUANTUM_COMPUTINg_1_
+
+#include <complex>
+#include <cmath>
+#include "../matrix.h"
+#include "../rand.h"
+#include "../enable_if.h"
+#include "../algs.h"
+#include "quantum_computing_abstract.h"
+
+namespace dlib
+{
+
+ template <typename T>
+ struct gate_traits {};
+
+ namespace qc_helpers
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ // This is a template to compute the value of 2^n at compile time
+ template <long n>
+ struct exp_2_n
+ {
+ COMPILE_TIME_ASSERT(0 <= n && n <= 30);
+ static const long value = exp_2_n<n-1>::value*2;
+ };
+
+ template <>
+ struct exp_2_n<0>
+ {
+ static const long value = 1;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+ typedef std::complex<double> qc_scalar_type;
+
+// ----------------------------------------------------------------------------------------
+
+ class quantum_register
+ {
+ public:
+
+ quantum_register()
+ {
+ set_num_bits(1);
+ }
+
+ int num_bits (
+ ) const
+ {
+ return num_bits_in_register;
+ }
+
+ void set_num_bits (
+ int num_bits
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(1 <= num_bits && num_bits <= 30,
+ "\tvoid quantum_register::set_num_bits()"
+ << "\n\tinvalid arguments to this function"
+ << "\n\tnum_bits: " << num_bits
+ << "\n\tthis: " << this
+ );
+
+ num_bits_in_register = num_bits;
+
+ unsigned long size = 1;
+ for (int i = 0; i < num_bits; ++i)
+ size *= 2;
+
+ state.set_size(size);
+
+ zero_all_bits();
+ }
+
+ void zero_all_bits()
+ {
+ set_all_elements(state,0);
+ state(0) = 1;
+ }
+
+ void append (
+ const quantum_register& reg
+ )
+ {
+ num_bits_in_register += reg.num_bits_in_register;
+ state = tensor_product(state, reg.state);
+ }
+
+ template <typename rand_type>
+ bool measure_bit (
+ int bit,
+ rand_type& rnd
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(0 <= bit && bit < num_bits(),
+ "\tbool quantum_register::measure_bit()"
+ << "\n\tinvalid arguments to this function"
+ << "\n\tbit: " << bit
+ << "\n\tnum_bits(): " << num_bits()
+ << "\n\tthis: " << this
+ );
+
+ const bool value = (rnd.get_random_double() < probability_of_bit(bit));
+
+ // Next we set all the states where this bit doesn't have the given value to 0
+
+ // But first make a mask that selects our bit
+ unsigned long mask = 1;
+ for (int i = 0; i < bit; ++i)
+ mask <<= 1;
+
+ // loop over all the elements in the state vector and zero out those that
+ // conflict with the measurement we just made.
+ for (long r = 0; r < state.nr(); ++r)
+ {
+ const unsigned long field = r;
+ // if this state indicates that the bit should be set and it isn't
+ if ((field & mask) && !value)
+ {
+ state(r) = 0;
+ }
+ // else if this state indicates that the bit should not be set and it is
+ else if (!(field & mask) && value)
+ {
+ state(r) = 0;
+ }
+ }
+
+ // normalize the state
+ state = state/(std::sqrt(sum(norm(state))));
+
+ return value;
+ }
+
+ template <typename rand_type>
+ bool measure_and_remove_bit (
+ int bit,
+ rand_type& rnd
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(0 <= bit && bit < num_bits() && num_bits() > 0,
+ "\tbool quantum_register::measure_and_remove_bit()"
+ << "\n\tinvalid arguments to this function"
+ << "\n\tbit: " << bit
+ << "\n\tnum_bits(): " << num_bits()
+ << "\n\tthis: " << this
+ );
+
+
+ const bool value = (rnd.get_random_double() < probability_of_bit(bit));
+ quantum_register temp;
+ temp.set_num_bits(num_bits()-1);
+
+
+ // Next we set all the states where this bit doesn't have the given value to 0
+
+ // But first make a mask that selects our bit
+ unsigned long mask = 1;
+ for (int i = 0; i < bit; ++i)
+ mask <<= 1;
+
+ long count = 0;
+ for (long r = 0; r < state.nr(); ++r)
+ {
+ const unsigned long field = r;
+ // if this basis vector is one that matches the measured state then keep it
+ if (((field & mask) != 0) == value)
+ {
+ temp.state(count) = state(r);
+ ++count;
+ }
+ }
+
+ // normalize the state
+ temp.state = temp.state/std::sqrt(sum(norm(temp.state)));
+
+ temp.swap(*this);
+
+ return value;
+ }
+
+ double probability_of_bit (
+ int bit
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(0 <= bit && bit < num_bits(),
+ "\tdouble quantum_register::probability_of_bit()"
+ << "\n\tinvalid arguments to this function"
+ << "\n\tbit: " << bit
+ << "\n\tnum_bits(): " << num_bits()
+ << "\n\tthis: " << this
+ );
+
+
+ // make a mask that selects our bit
+ unsigned long mask = 1;
+ for (int i = 0; i < bit; ++i)
+ mask <<= 1;
+
+ // now find the total probability of all the states that have the given bit set
+ double prob = 0;
+ for (long r = 0; r < state.nr(); ++r)
+ {
+ const unsigned long field = r;
+ if (field & mask)
+ {
+ prob += std::norm(state(r));
+ }
+ }
+
+
+ return prob;
+ }
+
+ const matrix<qc_scalar_type,0,1>& state_vector() const { return state; }
+ matrix<qc_scalar_type,0,1>& state_vector() { return state; }
+
+ void swap (
+ quantum_register& item
+ )
+ {
+ exchange(num_bits_in_register, item.num_bits_in_register);
+ state.swap(item.state);
+ }
+
+ private:
+
+ int num_bits_in_register;
+ matrix<qc_scalar_type,0,1> state;
+ };
+
+ inline void swap (
+ quantum_register& a,
+ quantum_register& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class gate_exp
+ {
+ public:
+ static const long num_bits = gate_traits<T>::num_bits;
+ static const long dims = gate_traits<T>::dims;
+
+ gate_exp(T& exp_) : exp(exp_) {}
+
+ const qc_scalar_type operator() (long r, long c) const { return exp(r,c); }
+
+ const matrix<qc_scalar_type> mat (
+ ) const
+ {
+ matrix<qc_scalar_type,dims,dims> m;
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = exp(r,c);
+ }
+ }
+ return m;
+ }
+
+ void apply_gate_to (quantum_register& reg) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(reg.num_bits() == num_bits,
+ "\tvoid gate_exp::apply_gate_to()"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.num_bits(): " << reg.num_bits()
+ << "\n\tnum_bits: " << num_bits
+ << "\n\tthis: " << this
+ );
+
+
+ quantum_register temp(reg);
+
+
+ // check if any of the elements of the register are 1 and if so then
+ // we don't have to do the full matrix multiply. Or check if only a small number are non-zero.
+ long non_zero_elements = 0;
+ for (long r = 0; r < dims; ++r)
+ {
+ if (reg.state_vector()(r) != qc_scalar_type(0))
+ ++non_zero_elements;
+
+ reg.state_vector()(r) = 0;
+ }
+
+
+ if (non_zero_elements > 3)
+ {
+ // do a full matrix multiply to compute the output state
+ for (long r = 0; r < dims; ++r)
+ {
+ reg.state_vector()(r) = compute_state_element(temp.state_vector(),r);
+ }
+ }
+ else
+ {
+ // do a matrix multiply but only use the columns in the gate matrix
+ // that correspond to the non-zero register elements
+ for (long r = 0; r < dims; ++r)
+ {
+ if (temp.state_vector()(r) != qc_scalar_type(0))
+ {
+ for (long i = 0; i < dims; ++i)
+ {
+ reg.state_vector()(i) += temp.state_vector()(r)*exp(i,r);
+ }
+ }
+ }
+ }
+ }
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 &&
+ 0 <= row_idx && row_idx < dims,
+ "\tqc_scalar_type gate_exp::compute_state_element(reg,row_idx)"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.nr(): " << reg.nr()
+ << "\n\treg.nc(): " << reg.nc()
+ << "\n\tdims: " << dims
+ << "\n\trow_idx: " << row_idx
+ << "\n\tthis: " << this
+ );
+
+
+ return this->exp.compute_state_element(reg,row_idx);
+ }
+
+ const T& ref() const { return exp; }
+
+ private:
+ T& exp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <typename T, typename U>
+ class composite_gate;
+
+ template <typename T, typename U>
+ struct gate_traits<composite_gate<T,U> >
+ {
+ static const long num_bits = T::num_bits + U::num_bits;
+ static const long dims = qc_helpers::exp_2_n<num_bits>::value;
+ };
+
+ template <typename T, typename U>
+ class composite_gate : public gate_exp<composite_gate<T,U> >
+ {
+ public:
+
+ typedef T lhs_type;
+ typedef U rhs_type;
+
+ composite_gate(const composite_gate& g) : gate_exp<composite_gate>(*this), lhs(g.lhs), rhs(g.rhs) {}
+
+ composite_gate(
+ const gate_exp<T>& lhs_,
+ const gate_exp<U>& rhs_
+ ) : gate_exp<composite_gate>(*this), lhs(lhs_.ref()), rhs(rhs_.ref()) {}
+
+
+
+ static const long num_bits = gate_traits<composite_gate>::num_bits;
+ static const long dims = gate_traits<composite_gate>::dims;
+
+ const qc_scalar_type operator() (long r, long c) const { return lhs(r/U::dims,c/U::dims)*rhs(r%U::dims, c%U::dims); }
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 &&
+ 0 <= row_idx && row_idx < dims,
+ "\tqc_scalar_type composite_gate::compute_state_element(reg,row_idx)"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.nr(): " << reg.nr()
+ << "\n\treg.nc(): " << reg.nc()
+ << "\n\tdims: " << dims
+ << "\n\trow_idx: " << row_idx
+ << "\n\tthis: " << this
+ );
+
+
+ qc_scalar_type result = 0;
+ for (long c = 0; c < T::dims; ++c)
+ {
+ if (lhs(row_idx/U::dims,c) != qc_scalar_type(0))
+ {
+ result += lhs(row_idx/U::dims,c) * rhs.compute_state_element(subm(reg,c*U::dims,0,U::dims,1), row_idx%U::dims);
+ }
+ }
+
+ return result;
+ }
+
+
+ const T lhs;
+ const U rhs;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <long bits>
+ class gate;
+ template <long bits>
+ struct gate_traits<gate<bits> >
+ {
+ static const long num_bits = bits;
+ static const long dims = qc_helpers::exp_2_n<num_bits>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <long bits>
+ class gate : public gate_exp<gate<bits> >
+ {
+ public:
+ gate() : gate_exp<gate>(*this) { set_all_elements(data,0); }
+ gate(const gate& g) :gate_exp<gate>(*this), data(g.data) {}
+
+ template <typename T>
+ explicit gate(const gate_exp<T>& g) : gate_exp<gate>(*this)
+ {
+ COMPILE_TIME_ASSERT(T::num_bits == num_bits);
+ for (long r = 0; r < dims; ++r)
+ {
+ for (long c = 0; c < dims; ++c)
+ {
+ data(r,c) = g(r,c);
+ }
+ }
+ }
+
+ static const long num_bits = gate_traits<gate>::num_bits;
+ static const long dims = gate_traits<gate>::dims;
+
+ const qc_scalar_type& operator() (long r, long c) const { return data(r,c); }
+ qc_scalar_type& operator() (long r, long c) { return data(r,c); }
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 &&
+ 0 <= row_idx && row_idx < dims,
+ "\tqc_scalar_type gate::compute_state_element(reg,row_idx)"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.nr(): " << reg.nr()
+ << "\n\treg.nc(): " << reg.nc()
+ << "\n\tdims: " << dims
+ << "\n\trow_idx: " << row_idx
+ << "\n\tthis: " << this
+ );
+
+
+ return (data*reg)(row_idx);
+ }
+
+ private:
+
+ matrix<qc_scalar_type,dims,dims> data;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace qc_helpers
+ {
+ // This is the maximum number of bits used for cached sub-matrices in composite_gate expressions
+ const int qc_block_chunking_size = 8;
+
+ template <typename T>
+ struct is_composite_gate { const static bool value = false; };
+ template <typename T, typename U>
+ struct is_composite_gate<composite_gate<T,U> > { const static bool value = true; };
+
+
+ // These overloads all deal with intelligently composing chains of composite_gate expressions
+ // such that the resulting expression has the form:
+ // (gate_exp,(gate_exp,(gate_exp,(gate_exp()))))
+ // and each gate_exp contains a cached gate matrix for a gate of at most qc_block_chunking_size bits.
+ // This facilitates the optimizations in the compute_state_element() function.
+ template <typename T, typename U, typename V, typename enabled = void>
+ struct combine_gates;
+
+ // This is a base case of this recursive template. It takes care of converting small composite_gates into
+ // cached gate objects.
+ template <typename T, typename U, typename V>
+ struct combine_gates<T,U,V,typename enable_if_c<(T::num_bits + U::num_bits <= qc_block_chunking_size)>::type >
+ {
+ typedef composite_gate<gate<T::num_bits + U::num_bits>,V> result_type;
+
+ static const result_type eval (
+ const composite_gate<T,U>& lhs,
+ const gate_exp<V>& rhs
+ )
+ {
+ typedef gate<T::num_bits + U::num_bits> gate_type;
+ return composite_gate<gate_type,V>(gate_type(lhs), rhs);
+ }
+ };
+
+ // this is the recursive step of this template
+ template <typename T, typename U, typename V>
+ struct combine_gates<T,U,V,typename enable_if_c<(is_composite_gate<U>::value == true)>::type >
+ {
+ typedef typename combine_gates<typename U::lhs_type, typename U::rhs_type, V>::result_type inner_type;
+ typedef composite_gate<T,inner_type> result_type;
+
+ static const result_type eval (
+ const composite_gate<T,U>& lhs,
+ const gate_exp<V>& rhs
+ )
+ {
+ return composite_gate<T,inner_type>(lhs.lhs, combine_gates<typename U::lhs_type, typename U::rhs_type, V>::eval(lhs.rhs,rhs));
+ }
+
+ };
+
+ // This is a base case of this recursive template. It takes care of adding new gates when the left
+ // hand side is too big to just turn it into a cached gate object.
+ template <typename T, typename U, typename V>
+ struct combine_gates<T,U,V,typename enable_if_c<(T::num_bits + U::num_bits > qc_block_chunking_size &&
+ is_composite_gate<U>::value == false)>::type >
+ {
+ typedef composite_gate<T,composite_gate<U, V> > result_type;
+
+ static const result_type eval (
+ const composite_gate<T,U>& lhs,
+ const gate_exp<V>& rhs
+ )
+ {
+ return result_type(lhs.lhs, composite_gate<U,V>(lhs.rhs, rhs));
+ }
+
+ };
+
+ }
+
+ template <typename T, typename U>
+ const composite_gate<T,U> operator, (
+ const gate_exp<T>& lhs,
+ const gate_exp<U>& rhs
+ )
+ {
+ return composite_gate<T,U>(lhs,rhs);
+ }
+
+ template <typename T, typename U, typename V>
+ const typename qc_helpers::combine_gates<T,U,V>::result_type operator, (
+ const composite_gate<T,U>& lhs,
+ const gate_exp<V>& rhs
+ )
+ {
+ return qc_helpers::combine_gates<T,U,V>::eval(lhs,rhs);
+ }
+
+ // If you are getting an error here then it means that you are trying to combine a gate expression
+ // with an integer somewhere (and that is an error).
+ template <typename T> void operator, ( const gate_exp<T>&, int) { COMPILE_TIME_ASSERT(sizeof(T) > 100000000); }
+ template <typename T> void operator, ( int, const gate_exp<T>&) { COMPILE_TIME_ASSERT(sizeof(T) > 100000000); }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace quantum_gates
+ {
+ template <int control_bit, int target_bit>
+ class cnot;
+
+ template <int control_bit1, int control_bit2, int target_bit>
+ class toffoli;
+ }
+
+ template <int control_bit, int target_bit>
+ struct gate_traits<quantum_gates::cnot<control_bit, target_bit> >
+ {
+ static const long num_bits = tabs<control_bit-target_bit>::value+1;
+ static const long dims = qc_helpers::exp_2_n<num_bits>::value;
+ };
+
+ template <int control_bit1, int control_bit2, int target_bit>
+ struct gate_traits<quantum_gates::toffoli<control_bit1, control_bit2, target_bit> >
+ {
+ static const long num_bits = tmax<tabs<control_bit1-target_bit>::value,
+ tabs<control_bit2-target_bit>::value>::value+1;
+ static const long dims = qc_helpers::exp_2_n<num_bits>::value;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ namespace quantum_gates
+ {
+
+ inline const gate<1> hadamard(
+ )
+ {
+ gate<1> h;
+ h(0,0) = std::sqrt(1/2.0);
+ h(0,1) = std::sqrt(1/2.0);
+ h(1,0) = std::sqrt(1/2.0);
+ h(1,1) = -std::sqrt(1/2.0);
+ return h;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const gate<1> x(
+ )
+ {
+ gate<1> x;
+ x(0,1) = 1;
+ x(1,0) = 1;
+ return x;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const gate<1> y(
+ )
+ {
+ gate<1> x;
+ qc_scalar_type i(0,1);
+ x(0,1) = -i;
+ x(1,0) = i;
+ return x;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const gate<1> z(
+ )
+ {
+ gate<1> z;
+ z(0,0) = 1;
+ z(1,1) = -1;
+ return z;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ inline const gate<1> noop(
+ )
+ {
+ gate<1> i;
+ i(0,0) = 1;
+ i(1,1) = 1;
+ return i;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <int control_bit, int target_bit>
+ class cnot : public gate_exp<cnot<control_bit, target_bit> >
+ {
+ public:
+ COMPILE_TIME_ASSERT(control_bit != target_bit);
+
+ cnot() : gate_exp<cnot>(*this)
+ {
+ const int min_bit = std::min(control_bit, target_bit);
+
+ control_mask = 1;
+ target_mask = 1;
+
+ // make the masks so that their only on bit corresponds to the given control_bit and target_bit bits
+ for (int i = 0; i < control_bit-min_bit; ++i)
+ control_mask <<= 1;
+ for (int i = 0; i < target_bit-min_bit; ++i)
+ target_mask <<= 1;
+ }
+
+ static const long num_bits = gate_traits<cnot>::num_bits;
+ static const long dims = gate_traits<cnot>::dims;
+
+ const qc_scalar_type operator() (long r, long c) const
+ {
+ unsigned long output;
+ // if the input control bit is set
+ if (control_mask&c)
+ {
+ output = c^target_mask;
+ }
+ else
+ {
+ output = c;
+ }
+
+ if ((unsigned long)r == output)
+ return 1;
+ else
+ return 0;
+ }
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 &&
+ 0 <= row_idx && row_idx < dims,
+ "\tqc_scalar_type cnot::compute_state_element(reg,row_idx)"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.nr(): " << reg.nr()
+ << "\n\treg.nc(): " << reg.nc()
+ << "\n\tdims: " << dims
+ << "\n\trow_idx: " << row_idx
+ << "\n\tthis: " << this
+ );
+
+
+ unsigned long output = row_idx;
+ // if the input control bit is set
+ if (control_mask&output)
+ {
+ output = output^target_mask;
+ }
+
+ return reg(output);
+ }
+
+ private:
+
+ unsigned long control_mask;
+ unsigned long target_mask;
+
+
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <int control_bit1, int control_bit2, int target_bit>
+ class toffoli : public gate_exp<toffoli<control_bit1, control_bit2, target_bit> >
+ {
+ public:
+ COMPILE_TIME_ASSERT(control_bit1 != target_bit && control_bit2 != target_bit && control_bit1 != control_bit2);
+ COMPILE_TIME_ASSERT((control_bit1 < target_bit && control_bit2 < target_bit) ||(control_bit1 > target_bit && control_bit2 > target_bit) );
+
+ toffoli() : gate_exp<toffoli>(*this)
+ {
+ const int min_bit = std::min(std::min(control_bit1, control_bit2), target_bit);
+
+ control1_mask = 1;
+ control2_mask = 1;
+ target_mask = 1;
+
+ // make the masks so that their only on bit corresponds to the given control_bit1 and target_bit bits
+ for (int i = 0; i < control_bit1-min_bit; ++i)
+ control1_mask <<= 1;
+ for (int i = 0; i < control_bit2-min_bit; ++i)
+ control2_mask <<= 1;
+ for (int i = 0; i < target_bit-min_bit; ++i)
+ target_mask <<= 1;
+ }
+
+ static const long num_bits = gate_traits<toffoli>::num_bits;
+ static const long dims = gate_traits<toffoli>::dims;
+
+ const qc_scalar_type operator() (long r, long c) const
+ {
+ unsigned long output;
+ // if the input control bits are set
+ if ((control1_mask&c) && (control2_mask&c))
+ {
+ output = c^target_mask;
+ }
+ else
+ {
+ output = c;
+ }
+
+ if ((unsigned long)r == output)
+ return 1;
+ else
+ return 0;
+ }
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 &&
+ 0 <= row_idx && row_idx < dims,
+ "\tqc_scalar_type toffoli::compute_state_element(reg,row_idx)"
+ << "\n\tinvalid arguments to this function"
+ << "\n\treg.nr(): " << reg.nr()
+ << "\n\treg.nc(): " << reg.nc()
+ << "\n\tdims: " << dims
+ << "\n\trow_idx: " << row_idx
+ << "\n\tthis: " << this
+ );
+
+
+ unsigned long output;
+ // if the input control bits are set
+ if ((control1_mask&row_idx) && (control2_mask&row_idx))
+ {
+ output = row_idx^target_mask;
+ }
+ else
+ {
+ output = row_idx;
+ }
+
+ return reg(output);
+
+ }
+
+ private:
+
+ unsigned long control1_mask;
+ unsigned long control2_mask;
+ unsigned long target_mask;
+
+
+ };
+
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUANTUM_COMPUTINg_1_
+
+
diff --git a/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h b/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h
new file mode 100644
index 000000000..bcc65af23
--- /dev/null
+++ b/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h
@@ -0,0 +1,590 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_QUANTUM_COMPUTINg_ABSTRACT_
+#ifdef DLIB_QUANTUM_COMPUTINg_ABSTRACT_
+
+#include <complex>
+#include "../matrix.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ typedef std::complex<double> qc_scalar_type;
+
+// ----------------------------------------------------------------------------------------
+
+ class quantum_register
+ {
+ /*!
+ INITIAL VALUE
+ - num_bits() == 1
+ - state_vector().nr() == 2
+ - state_vector().nc() == 1
+ - state_vector()(0) == 1
+ - state_vector()(1) == 0
+ - probability_of_bit(0) == 0
+
+ - i.e. This register represents a single quantum bit and it is
+ completely in the 0 state.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a set of quantum bits.
+ !*/
+
+ public:
+
+ quantum_register(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ int num_bits (
+ ) const;
+ /*!
+ ensures
+ - returns the number of quantum bits in this register
+ !*/
+
+ void set_num_bits (
+ int new_num_bits
+ );
+ /*!
+ requires
+ - 1 <= new_num_bits <= 30
+ ensures
+ - #num_bits() == new_num_bits
+ - #state_vector().nr() == 2^new_num_bits
+ (i.e. the size of the state_vector is exponential in the number of bits in a register)
+ - for all valid i:
+ - probability_of_bit(i) == 0
+ !*/
+
+ void zero_all_bits(
+ );
+ /*!
+ ensures
+ - for all valid i:
+ - probability_of_bit(i) == 0
+ !*/
+
+ void append (
+ const quantum_register& reg
+ );
+ /*!
+ ensures
+ - #num_bits() == num_bits() + reg.num_bits()
+ - #this->state_vector() == tensor_product(this->state_vector(), reg.state_vector())
+ - The original bits in *this become the high order bits of the resulting
+ register and all the bits in reg end up as the low order bits in the
+ resulting register.
+ !*/
+
+ double probability_of_bit (
+ int bit
+ ) const;
+ /*!
+ requires
+ - 0 <= bit < num_bits()
+ ensures
+ - returns the probability of measuring the given bit and it being in the 1 state.
+ - The returned value is also equal to the sum of norm(state_vector()(i)) for all
+ i where the bit'th bit in i is set to 1. (note that the lowest order bit is bit 0)
+ !*/
+
+ template <typename rand_type>
+ bool measure_bit (
+ int bit,
+ rand_type& rnd
+ );
+ /*!
+ requires
+ - 0 <= bit < num_bits()
+ - rand_type == an implementation of dlib/rand/rand_float_abstract.h
+ ensures
+ - measures the given bit in this register. Let R denote the boolean
+ result of the measurement, where true means the bit was measured to
+ have value 1 and false means it had a value of 0.
+ - if (R == true) then
+ - returns true
+ - #probability_of_bit(bit) == 1
+ - else
+ - returns false
+ - #probability_of_bit(bit) == 0
+ !*/
+
+ template <typename rand_type>
+ bool measure_and_remove_bit (
+ int bit,
+ rand_type& rnd
+ );
+ /*!
+ requires
+ - num_bits() > 1
+ - 0 <= bit < num_bits()
+ - rand_type == an implementation of dlib/rand/rand_float_abstract.h
+ ensures
+ - measures the given bit in this register. Let R denote the boolean
+ result of the measurement, where true means the bit was measured to
+ have value 1 and false means it had a value of 0.
+ - #num_bits() == num_bits() - 1
+ - removes the bit that was measured from this register.
+ - if (R == true) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ const matrix<qc_scalar_type,0,1>& state_vector(
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the state vector that describes the state of
+ the quantum bits in this register.
+ !*/
+
+ matrix<qc_scalar_type,0,1>& state_vector(
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the state vector that describes the state of
+ the quantum bits in this register.
+ !*/
+
+ void swap (
+ quantum_register& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ quantum_register& a,
+ quantum_register& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class gate_exp
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be some object that inherits from gate_exp and implements its own
+ version of operator() and compute_state_element().
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an expression that evaluates to a quantum gate
+ that operates on T::num_bits qubits.
+
+ This object makes it easy to create new types of gate objects. All
+ you need to do is inherit from gate_exp in the proper way and
+ then you can use your new gate objects in conjunction with all the
+ others.
+ !*/
+
+ public:
+
+ static const long num_bits = T::num_bits;
+ static const long dims = T::dims;
+
+ gate_exp(
+ T& exp
+ );
+ /*!
+ ensures
+ - #&ref() == &exp
+ !*/
+
+ const qc_scalar_type operator() (
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= r < dims
+ - 0 <= c < dims
+ ensures
+ - returns ref()(r,c)
+ !*/
+
+ void apply_gate_to (
+ quantum_register& reg
+ ) const;
+ /*!
+ requires
+ - reg.num_bits() == num_bits
+ ensures
+ - applies this quantum gate to the given quantum register
+ - Let M represent the matrix for this quantum gate, then
+ #reg().state_vector() = M*reg().state_vector()
+ !*/
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const;
+ /*!
+ requires
+ - reg.nr() == dims
+ - reg.nc() == 1
+ - 0 <= row_idx < dims
+ ensures
+ - Let M represent the matrix for this gate, then
+ this function returns rowm(M*reg, row_idx)
+ (i.e. returns the row_idx row of what you get when you apply this
+ gate to the given column vector in reg)
+ - This function works by calling ref().compute_state_element(reg,row_idx)
+ !*/
+
+ const T& ref(
+ );
+ /*!
+ ensures
+ - returns a reference to the subexpression contained in this object
+ !*/
+
+ const matrix<qc_scalar_type> mat (
+ ) const;
+ /*!
+ ensures
+ - returns a dense matrix object that contains the matrix for this gate
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ class composite_gate : public gate_exp<composite_gate<T,U> >
+ {
+ /*!
+ REQUIREMENTS ON T AND U
+ Both must be gate expressions that inherit from gate_exp
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a quantum gate that is the tensor product of
+ two other quantum gates.
+
+
+ As an example, suppose you have 3 registers, reg_high, reg_low, and reg_all. Also
+ suppose that reg_all is what you get when you append reg_high and reg_low,
+ so reg_all.state_vector() == tensor_product(reg_high.state_vector(),reg_low.state_vector()).
+
+ Then applying a composite gate to reg_all would give you the same thing as
+ applying the lhs gate to reg_high and the rhs gate to reg_low and then appending
+ the two resulting registers. So the lhs gate of a composite_gate applies to
+ the high order bits of a regitser and the rhs gate applies to the lower order bits.
+ !*/
+ public:
+
+ composite_gate (
+ const composite_gate& g
+ );
+ /*!
+ ensures
+ - *this is a copy of g
+ !*/
+
+ composite_gate(
+ const gate_exp<T>& lhs_,
+ const gate_exp<U>& rhs_
+ ):
+ /*!
+ ensures
+ - #lhs == lhs_.ref()
+ - #rhs == rhs_.ref()
+ - #num_bits == T::num_bits + U::num_bits
+ - #dims == 2^num_bits
+ - #&ref() == this
+ !*/
+
+ const qc_scalar_type operator() (
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= r < dims
+ - 0 <= c < dims
+ ensures
+ - Let M denote the tensor product of lhs with rhs, then this function
+ returns M(r,c)
+ (i.e. returns lhs(r/U::dims,c/U::dims)*rhs(r%U::dims, c%U::dims))
+ !*/
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const;
+ /*!
+ requires
+ - reg.nr() == dims
+ - reg.nc() == 1
+ - 0 <= row_idx < dims
+ ensures
+ - Let M represent the matrix for this gate, then this function
+ returns rowm(M*reg, row_idx)
+ (i.e. returns the row_idx row of what you get when you apply this
+ gate to the given column vector in reg)
+ - This function works by calling rhs.compute_state_element() and using elements
+ of the matrix in lhs.
+ !*/
+
+ static const long num_bits;
+ static const long dims;
+
+ const T lhs;
+ const U rhs;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <long bits>
+ class gate : public gate_exp<gate<bits> >
+ {
+ /*!
+ REQUIREMENTS ON bits
+ 0 < bits <= 30
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a quantum gate that operates on bits qubits.
+ It stores its gate matrix explicitly in a dense in-memory matrix.
+ !*/
+
+ public:
+ gate(
+ );
+ /*!
+ ensures
+ - num_bits == bits
+ - dims == 2^bits
+ - #&ref() == this
+ - for all valid r and c:
+ #(*this)(r,c) == 0
+ !*/
+
+ gate (
+ const gate& g
+ );
+ /*!
+ ensures
+ - *this is a copy of g
+ !*/
+
+ template <typename T>
+ explicit gate(
+ const gate_exp<T>& g
+ );
+ /*!
+ requires
+ - T::num_bits == num_bits
+ ensures
+ - num_bits == bits
+ - dims == 2^bits
+ - #&ref() == this
+ - for all valid r and c:
+ #(*this)(r,c) == g(r,c)
+ !*/
+
+ const qc_scalar_type& operator() (
+ long r,
+ long c
+ ) const;
+ /*!
+ requires
+ - 0 <= r < dims
+ - 0 <= c < dims
+ ensures
+ - Let M denote the matrix for this gate, then this function
+ returns a const reference to M(r,c)
+ !*/
+
+ qc_scalar_type& operator() (
+ long r,
+ long c
+ );
+ /*!
+ requires
+ - 0 <= r < dims
+ - 0 <= c < dims
+ ensures
+ - Let M denote the matrix for this gate, then this function
+ returns a non-const reference to M(r,c)
+ !*/
+
+ template <typename exp>
+ qc_scalar_type compute_state_element (
+ const matrix_exp<exp>& reg,
+ long row_idx
+ ) const;
+ /*!
+ requires
+ - reg.nr() == dims
+ - reg.nc() == 1
+ - 0 <= row_idx < dims
+ ensures
+ - Let M represent the matrix for this gate, then this function
+ returns rowm(M*reg, row_idx)
+ (i.e. returns the row_idx row of what you get when you apply this
+ gate to the given column vector in reg)
+ !*/
+
+ static const long num_bits;
+ static const long dims;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ const composite_gate<T,U> operator, (
+ const gate_exp<T>& lhs,
+ const gate_exp<U>& rhs
+ ) { return composite_gate<T,U>(lhs,rhs); }
+ /*!
+ ensures
+ - returns a composite_gate that represents the tensor product of the lhs
+ gate with the rhs gate.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ namespace quantum_gates
+ {
+
+ inline const gate<1> hadamard(
+ );
+ /*!
+ ensures
+ - returns the Hadamard gate.
+ (i.e. A gate with a matrix of
+ |1, 1|
+ 1/sqrt(2) * |1,-1| )
+ !*/
+
+ inline const gate<1> x(
+ );
+ /*!
+ ensures
+ - returns the not gate.
+ (i.e. A gate with a matrix of
+ |0, 1|
+ |1, 0| )
+ !*/
+
+ inline const gate<1> y(
+ );
+ /*!
+ ensures
+ - returns the y gate.
+ (i.e. A gate with a matrix of
+ |0,-i|
+ |i, 0| )
+ !*/
+
+ inline const gate<1> z(
+ );
+ /*!
+ ensures
+ - returns the z gate.
+ (i.e. A gate with a matrix of
+ |1, 0|
+ |0,-1| )
+ !*/
+
+ inline const gate<1> noop(
+ );
+ /*!
+ ensures
+ - returns the no-op or identity gate.
+ (i.e. A gate with a matrix of
+ |1, 0|
+ |0, 1| )
+ !*/
+
+ template <
+ int control_bit,
+ int target_bit
+ >
+ class cnot : public gate_exp<cnot<control_bit, target_bit> >
+ {
+ /*!
+ REQUIREMENTS ON control_bit AND target_bit
+ - control_bit != target_bit
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the controlled-not quantum gate. It is a gate that
+ operates on abs(control_bit-target_bit)+1 qubits.
+
+ In terms of the computational basis vectors, this gate maps input
+ vectors to output vectors in the following way:
+ - if (the input vector corresponds to a state where the control_bit
+ qubit is 1) then
+ - this gate outputs the computational basis vector that
+ corresponds to the state where the target_bit has been flipped
+ with respect to the input vector
+ - else
+ - this gate outputs the input vector unmodified
+
+ !*/
+ };
+
+ template <
+ int control_bit1,
+ int control_bit2,
+ int target_bit
+ >
+ class toffoli : public gate_exp<toffoli<control_bit1, control_bit2, target_bit> >
+ {
+ /*!
+ REQUIREMENTS ON control_bit1, control_bit2, AND target_bit
+ - all the arguments denote different bits, i.e.:
+ - control_bit1 != target_bit
+ - control_bit2 != target_bit
+ - control_bit1 != control_bit2
+ - The target bit can't be in-between the control bits, i.e.:
+ - (control_bit1 < target_bit && control_bit2 < target_bit) ||
+ (control_bit1 > target_bit && control_bit2 > target_bit)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents the toffoli variant of a controlled-not quantum gate.
+ It is a gate that operates on max(abs(control_bit2-target_bit),abs(control_bit1-target_bit))+1
+ qubits.
+
+ In terms of the computational basis vectors, this gate maps input
+ vectors to output vectors in the following way:
+ - if (the input vector corresponds to a state where the control_bit1 and
+ control_bit2 qubits are 1) then
+ - this gate outputs the computational basis vector that
+ corresponds to the state where the target_bit has been flipped
+ with respect to the input vector
+ - else
+ - this gate outputs the input vector unmodified
+
+ !*/
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUANTUM_COMPUTINg_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/queue.h b/ml/dlib/dlib/queue.h
new file mode 100644
index 000000000..b0f331dc9
--- /dev/null
+++ b/ml/dlib/dlib/queue.h
@@ -0,0 +1,84 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUEUe_
+#define DLIB_QUEUe_
+
+#include "queue/queue_kernel_1.h"
+#include "queue/queue_kernel_2.h"
+#include "queue/queue_kernel_c.h"
+
+#include "queue/queue_sort_1.h"
+
+
+#include "algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class queue
+ {
+ queue() {}
+ public:
+
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef queue_kernel_1<T,mem_manager>
+ kernel_1a;
+ typedef queue_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ // kernel_2a
+ typedef queue_kernel_2<T,20,mem_manager>
+ kernel_2a;
+ typedef queue_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ // kernel_2b
+ typedef queue_kernel_2<T,100,mem_manager>
+ kernel_2b;
+ typedef queue_kernel_c<kernel_2b>
+ kernel_2b_c;
+
+
+
+
+ //---------- extensions ------------
+
+ // sort_1 extend kernel_1a
+ typedef queue_sort_1<kernel_1a>
+ sort_1a;
+ typedef queue_sort_1<kernel_1a_c>
+ sort_1a_c;
+
+
+ // sort_1 extend kernel_2a
+ typedef queue_sort_1<kernel_2a>
+ sort_1b;
+ typedef queue_sort_1<kernel_2a_c>
+ sort_1b_c;
+
+
+
+ // sort_1 extend kernel_2b
+ typedef queue_sort_1<kernel_2b>
+ sort_1c;
+ typedef queue_sort_1<kernel_2b_c>
+ sort_1c_c;
+
+
+
+
+
+ };
+}
+
+#endif // DLIB_QUEUe_
+
diff --git a/ml/dlib/dlib/queue/queue_kernel_1.h b/ml/dlib/dlib/queue/queue_kernel_1.h
new file mode 100644
index 000000000..e59bf2659
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_kernel_1.h
@@ -0,0 +1,554 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUEUE_KERNEl_1_
+#define DLIB_QUEUE_KERNEl_1_
+
+#include "queue_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class queue_kernel_1 : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ INITIAL VALUE
+ queue_size == 0
+ current_element == 0
+ at_start_ == true
+
+ CONVENTION
+ queue_size == the number of elements in the queue
+ at_start() == at_start_
+ current_element_valid() == (current_element != 0)
+ element() == current_element->item
+
+ if (queue_size > 0)
+ {
+ in points to the last element to be inserted into the queue
+ out points to the next element to be dequeued
+
+ each node points to the node inserted after it except for the most
+ recently inserted node
+
+ current_element == 0
+ }
+
+ !*/
+
+
+ struct node
+ {
+ node* last;
+
+ T item;
+ };
+
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ queue_kernel_1 (
+ ) :
+ in(0),
+ out(0),
+ queue_size(0),
+ current_element(0),
+ at_start_(true)
+ {
+ }
+
+ virtual ~queue_kernel_1 (
+ );
+
+ inline void clear(
+ );
+
+ void enqueue (
+ T& item
+ );
+
+ void dequeue (
+ T& item
+ );
+
+ void cat (
+ queue_kernel_1& item
+ );
+
+ T& current (
+ );
+
+ const T& current (
+ ) const;
+
+ void swap (
+ queue_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ void delete_nodes (
+ node* start,
+ unsigned long length
+ );
+ /*!
+ requires
+ - start points to a node in a singly linked list
+ - start->last points to the next node in the list
+ - there are at least length nodes in the list begining with start
+ ensures
+ - length nodes have been deleted starting with the node pointed
+ to by start
+ !*/
+
+ // data members
+
+ node* in;
+ node* out;
+ unsigned long queue_size;
+ mutable node* current_element;
+ mutable bool at_start_;
+
+ // restricted functions
+ queue_kernel_1(queue_kernel_1&); // copy constructor
+ queue_kernel_1& operator=(queue_kernel_1&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ queue_kernel_1<T,mem_manager>& a,
+ queue_kernel_1<T,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ queue_kernel_1<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.enqueue(temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type queue_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ queue_kernel_1<T,mem_manager>::
+ ~queue_kernel_1 (
+ )
+ {
+ delete_nodes(out,queue_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ clear (
+ )
+ {
+ delete_nodes(out,queue_size);
+ queue_size = 0;
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ enqueue (
+ T& item
+ )
+ {
+ // make new node
+ node* temp = new node;
+
+ // swap item into new node
+ exchange(item,temp->item);
+
+ if (queue_size == 0)
+ out = temp;
+ else
+ in->last = temp;
+
+ // make in point to the new node
+ in = temp;
+
+ ++queue_size;
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ dequeue (
+ T& item
+ )
+ {
+ // swap out into item
+ exchange(item,out->item);
+
+ --queue_size;
+
+ if (queue_size == 0)
+ {
+ delete out;
+ }
+ else
+ {
+ node* temp = out;
+
+ // move out pointer to the next element in the queue
+ out = out->last;
+
+ // delete old node
+ delete temp;
+ }
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ cat (
+ queue_kernel_1<T,mem_manager>& item
+ )
+ {
+ if (item.queue_size > 0)
+ {
+ if (queue_size > 0)
+ {
+ in->last = item.out;
+ }
+ else
+ {
+ out = item.out;
+ }
+
+
+ in = item.in;
+ queue_size += item.queue_size;
+ item.queue_size = 0;
+ }
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& queue_kernel_1<T,mem_manager>::
+ current (
+ )
+ {
+ return out->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& queue_kernel_1<T,mem_manager>::
+ current (
+ ) const
+ {
+ return out->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ swap (
+ queue_kernel_1<T,mem_manager>& item
+ )
+ {
+ node* in_temp = in;
+ node* out_temp = out;
+ unsigned long queue_size_temp = queue_size;
+ node* current_element_temp = current_element;
+ bool at_start_temp = at_start_;
+
+ in = item.in;
+ out = item.out;
+ queue_size = item.queue_size;
+ current_element = item.current_element;
+ at_start_ = item.at_start_;
+
+ item.in = in_temp;
+ item.out = out_temp;
+ item.queue_size = queue_size_temp;
+ item.current_element = current_element_temp;
+ item.at_start_ = at_start_temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool queue_kernel_1<T,mem_manager>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t queue_kernel_1<T,mem_manager>::
+ size (
+ ) const
+ {
+ return queue_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool queue_kernel_1<T,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& queue_kernel_1<T,mem_manager>::
+ element (
+ ) const
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& queue_kernel_1<T,mem_manager>::
+ element (
+ )
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool queue_kernel_1<T,mem_manager>::
+ move_next (
+ ) const
+ {
+ if (at_start_)
+ {
+ at_start_ = false;
+ // if the queue is empty then there is nothing to do
+ if (queue_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ current_element = out;
+ return true;
+ }
+ }
+ else
+ {
+ // if we are at the last element then the enumeration has finished
+ if (current_element == in || current_element == 0)
+ {
+ current_element = 0;
+ return false;
+ }
+ else
+ {
+ current_element = current_element->last;
+ return true;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // remover function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ dequeue(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void queue_kernel_1<T,mem_manager>::
+ delete_nodes (
+ node* start,
+ unsigned long length
+ )
+ {
+ node* temp;
+ while (length)
+ {
+ temp = start->last;
+ delete start;
+ start = temp;
+ --length;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUEUE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/queue/queue_kernel_2.h b/ml/dlib/dlib/queue/queue_kernel_2.h
new file mode 100644
index 000000000..8e4536ae9
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_kernel_2.h
@@ -0,0 +1,600 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUEUE_KERNEl_2_
+#define DLIB_QUEUE_KERNEl_2_
+
+#include "queue_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager = default_memory_manager
+ >
+ class queue_kernel_2 : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON block_size
+ 0 < block_size < 2000000000
+
+ INITIAL VALUE
+ queue_size == 0
+ current_element == 0
+ at_start_ == true
+
+ CONVENTION
+ queue_size == the number of elements in the queue
+ at_start() == at_start_
+ current_element_valid() == (current_element != 0)
+ if (current_element_valid()) then
+ element() == current_element->item[current_element_pos]
+
+ if (queue_size > 0)
+ {
+ in->item[in_pos] == the spot where we will put the next item added
+ into the queue
+ out->item[out_pos] == current()
+
+ when enqueuing elements inside each node item[0] is filled first, then
+ item[1], then item[2], etc.
+
+
+ each node points to the node inserted after it except for the most
+ recently inserted node.
+ }
+
+ !*/
+
+
+ struct node
+ {
+ node* next;
+
+ T item[block_size];
+ };
+
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ queue_kernel_2 (
+ ) :
+ in(0),
+ out(0),
+ queue_size(0),
+ current_element(0),
+ at_start_(true)
+ {
+ }
+
+ virtual ~queue_kernel_2 (
+ );
+
+ inline void clear(
+ );
+
+ void enqueue (
+ T& item
+ );
+
+ void dequeue (
+ T& item
+ );
+
+ void cat (
+ queue_kernel_2& item
+ );
+
+ T& current (
+ );
+
+ const T& current (
+ ) const;
+
+ void swap (
+ queue_kernel_2& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ void delete_nodes (
+ node* start,
+ node* end
+ );
+ /*!
+ requires
+ - start points to a node in a singly linked list
+ - start->next points to the next node in the list
+ - by following the next pointers you eventually hit the node pointed
+ to by end
+ ensures
+ - calls delete on the start node, the end node, and all nodes in between
+ !*/
+
+ // data members
+
+ typename mem_manager::template rebind<node>::other pool;
+
+ node* in;
+ node* out;
+ size_t queue_size;
+ size_t in_pos;
+ size_t out_pos;
+
+
+ mutable node* current_element;
+ mutable size_t current_element_pos;
+ mutable bool at_start_;
+
+ // restricted functions
+ queue_kernel_2(queue_kernel_2&); // copy constructor
+ queue_kernel_2& operator=(queue_kernel_2&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ inline void swap (
+ queue_kernel_2<T,block_size,mem_manager>& a,
+ queue_kernel_2<T,block_size,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void deserialize (
+ queue_kernel_2<T,block_size,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.enqueue(temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type queue_kernel_2");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ queue_kernel_2<T,block_size,mem_manager>::
+ ~queue_kernel_2 (
+ )
+ {
+ COMPILE_TIME_ASSERT(0 < block_size && block_size < (unsigned long)(2000000000));
+
+ if (queue_size > 0)
+ delete_nodes(out,in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ clear (
+ )
+ {
+ if (queue_size > 0)
+ {
+ delete_nodes(out,in);
+ queue_size = 0;
+ }
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ enqueue (
+ T& item
+ )
+ {
+ if (queue_size == 0)
+ {
+ out = in = pool.allocate();
+ in_pos = 0;
+ out_pos = 0;
+ }
+ else if (in_pos >= block_size)
+ {
+ in->next = pool.allocate();
+ in_pos = 0;
+ in = in->next;
+ }
+
+ exchange(item,in->item[in_pos]);
+ ++in_pos;
+
+ ++queue_size;
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ dequeue (
+ T& item
+ )
+ {
+ // swap out into item
+ exchange(item,out->item[out_pos]);
+
+ ++out_pos;
+ --queue_size;
+
+ // if this was the last element in this node then remove this node
+ if (out_pos == block_size)
+ {
+ out_pos = 0;
+ node* temp = out;
+ out = out->next;
+ pool.deallocate(temp);
+ }
+ else if (queue_size == 0)
+ {
+ pool.deallocate(out);
+ }
+
+ // put the enumerator at the start
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ cat (
+ queue_kernel_2<T,block_size,mem_manager>& item
+ )
+ {
+ if (queue_size > 0)
+ {
+ T temp;
+ assign_zero_if_built_in_scalar_type(temp);
+ while (item.size() > 0)
+ {
+ item.dequeue(temp);
+ enqueue(temp);
+ }
+ }
+ else
+ {
+ in = item.in;
+ out = item.out;
+ out_pos = item.out_pos;
+ in_pos = item.in_pos;
+
+ queue_size = item.queue_size;
+ item.queue_size = 0;
+
+ // put the enumerator at the start
+ reset();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ T& queue_kernel_2<T,block_size,mem_manager>::
+ current (
+ )
+ {
+ return out->item[out_pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ const T& queue_kernel_2<T,block_size,mem_manager>::
+ current (
+ ) const
+ {
+ return out->item[out_pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ swap (
+ queue_kernel_2<T,block_size,mem_manager>& item
+ )
+ {
+ exchange(in,item.in);
+ exchange(out,item.out);
+ exchange(queue_size,item.queue_size);
+ exchange(in_pos,item.in_pos);
+ exchange(out_pos,item.out_pos);
+ exchange(current_element,item.current_element);
+ exchange(current_element_pos,item.current_element_pos);
+ exchange(at_start_,item.at_start_);
+ pool.swap(item.pool);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ size_t queue_kernel_2<T,block_size,mem_manager>::
+ size (
+ ) const
+ {
+ return queue_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ bool queue_kernel_2<T,block_size,mem_manager>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ bool queue_kernel_2<T,block_size,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ const T& queue_kernel_2<T,block_size,mem_manager>::
+ element (
+ ) const
+ {
+ return current_element->item[current_element_pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ T& queue_kernel_2<T,block_size,mem_manager>::
+ element (
+ )
+ {
+ return current_element->item[current_element_pos];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ bool queue_kernel_2<T,block_size,mem_manager>::
+ move_next (
+ ) const
+ {
+ if (at_start_)
+ {
+ at_start_ = false;
+ // if the queue is empty then there is nothing to do
+ if (queue_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ current_element = out;
+ current_element_pos = out_pos;
+ return true;
+ }
+ }
+ else if (current_element == 0)
+ {
+ return false;
+ }
+ else
+ {
+ ++current_element_pos;
+ // if we are at the last element then the enumeration has finished
+ if (current_element == in && current_element_pos == in_pos )
+ {
+ current_element = 0;
+ return false;
+ }
+ else if (current_element_pos == block_size)
+ {
+ current_element_pos = 0;
+ current_element = current_element->next;
+ }
+
+ return true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // remover function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ dequeue(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ unsigned long block_size,
+ typename mem_manager
+ >
+ void queue_kernel_2<T,block_size,mem_manager>::
+ delete_nodes (
+ node* start,
+ node* end
+ )
+ {
+ node* temp;
+ while (start != end)
+ {
+ temp = start;
+ start = start->next;
+ pool.deallocate(temp);
+ }
+ pool.deallocate(start);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUEUE_KERNEl_2_
+
diff --git a/ml/dlib/dlib/queue/queue_kernel_abstract.h b/ml/dlib/dlib/queue/queue_kernel_abstract.h
new file mode 100644
index 000000000..4fd4c7dd1
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_kernel_abstract.h
@@ -0,0 +1,196 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_QUEUE_KERNEl_ABSTRACT_
+#ifdef DLIB_QUEUE_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class queue : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap()
+ T must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and current() functions do not invalidate pointers or
+ references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the queue in the
+ same order they would be removed by repeated calls to dequeue().
+ (e.g. current() would be the first element enumerated)
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a first in first out queue containing items of type T
+
+ e.g. if the queue is {b,c,d,e} and then 'a' is enqueued
+ the queue becomes {a,b,c,d,e} and then calling dequeue takes e out
+ making the queue {a,b,c,d}
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ queue (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ virtual ~queue (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void enqueue (
+ T& item
+ );
+ /*!
+ ensures
+ - item is now at the left end of #*this
+ - #item has an initial value for its type
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if enqueue() throws then it has no effect
+ !*/
+
+ void dequeue (
+ T& item
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - the far right element of *this has been removed and swapped
+ into #item
+ - #at_start() == true
+ !*/
+
+ void cat (
+ queue& item
+ );
+ /*!
+ ensures
+ - item has been concatenated onto the left end of *this.
+ i.e. item.current() is attached onto the left end of *this and
+ the left most element in item will also be the left most item
+ in #*this
+ - #size() == size() + item.size()
+ - #item has its initial value
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if cat() throws then the state of #item and *this is undefined
+ until clear() is successfully called on them.
+ !*/
+
+ T& current (
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a const reference to the next element to be dequeued.
+ i.e. the right most element.
+ !*/
+
+ const T& current (
+ ) const;
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a non-const reference to the next element to be dequeued.
+ i.e. the right most element.
+ !*/
+
+ void swap (
+ queue& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ queue(queue&); // copy constructor
+ queue& operator=(queue&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ queue<T,mem_manager>& a,
+ queue<T,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ queue<T,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_QUEUE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/queue/queue_kernel_c.h b/ml/dlib/dlib/queue/queue_kernel_c.h
new file mode 100644
index 000000000..554c9aecd
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_kernel_c.h
@@ -0,0 +1,187 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUEUE_KERNEl_C_
+#define DLIB_QUEUE_KERNEl_C_
+
+#include "queue_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+
+ template <
+ typename queue_base // is an implementation of queue_kernel_abstract.h
+ >
+ class queue_kernel_c : public queue_base
+ {
+ typedef typename queue_base::type T;
+
+ public:
+
+ void dequeue (
+ T& item
+ );
+
+ T& current (
+ );
+
+ const T& current (
+ ) const;
+
+ const T& element (
+ ) const;
+
+ T& element (
+ );
+
+ void remove_any (
+ T& item
+ );
+
+ };
+
+ template <
+ typename queue_base
+ >
+ inline void swap (
+ queue_kernel_c<queue_base>& a,
+ queue_kernel_c<queue_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ void queue_kernel_c<queue_base>::
+ dequeue (
+ T& item
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tvoid queue::dequeue"
+ << "\n\tsize of queue should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ queue_base::dequeue(item);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ const typename queue_base::type& queue_kernel_c<queue_base>::
+ current (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tconst T& queue::current"
+ << "\n\tsize of queue should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return queue_base::current();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ typename queue_base::type& queue_kernel_c<queue_base>::
+ current (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tT& queue::current"
+ << "\n\tsize of queue should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return queue_base::current();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ const typename queue_base::type& queue_kernel_c<queue_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& queue::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return queue_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ typename queue_base::type& queue_kernel_c<queue_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tT& queue::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return queue_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ void queue_kernel_c<queue_base>::
+ remove_any (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->size() > 0),
+ "\tvoid queue::remove_any"
+ << "\n\tsize() must be greater than zero if something is going to be removed"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ queue_base::remove_any(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUEUE_KERNEl_C_
+
diff --git a/ml/dlib/dlib/queue/queue_sort_1.h b/ml/dlib/dlib/queue/queue_sort_1.h
new file mode 100644
index 000000000..bc20dcb82
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_sort_1.h
@@ -0,0 +1,165 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_QUEUE_SORt_1_
+#define DLIB_QUEUE_SORt_1_
+
+#include "queue_sort_abstract.h"
+#include "../algs.h"
+#include <vector>
+#include "../sort.h"
+
+namespace dlib
+{
+
+ template <
+ typename queue_base
+ >
+ class queue_sort_1 : public queue_base
+ {
+ typedef typename queue_base::type T;
+
+ public:
+
+ /*!
+ This implementation uses the QuickSort algorithm and
+ when the quicksort depth goes too high it uses the dlib::qsort_array()
+ function on the data.
+ !*/
+
+ void sort (
+ );
+
+ template <typename compare_type>
+ void sort (
+ const compare_type& compare
+ )
+ {
+ if (this->size() > 1)
+ {
+ sort_this_queue(*this,0,compare);
+ }
+ }
+
+ private:
+
+ template <typename compare_type>
+ void sort_this_queue (
+ queue_base& queue,
+ long depth,
+ const compare_type& compare
+ )
+ /*!
+ ensures
+ each element in the queue is < the element behind it according
+ to compare
+ !*/
+ {
+ if (queue.size() <= 1)
+ {
+ // already sorted
+ }
+ else if (queue.size() <= 29)
+ {
+ T vect[29];
+ const unsigned long size = queue.size();
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ queue.dequeue(vect[i]);
+ }
+ isort_array(vect,0,size-1,compare);
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ queue.enqueue(vect[i]);
+ }
+ }
+ else if (depth > 50)
+ {
+ std::vector<T> vect(queue.size());
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ {
+ queue.dequeue(vect[i]);
+ }
+ hsort_array(vect,0,vect.size()-1,compare);
+ for (unsigned long i = 0; i < vect.size(); ++i)
+ {
+ queue.enqueue(vect[i]);
+ }
+ }
+ else
+ {
+ queue_base left, right;
+ T partition_element;
+ T temp;
+ // do this just to avoid a compiler warning
+ assign_zero_if_built_in_scalar_type(temp);
+ assign_zero_if_built_in_scalar_type(partition_element);
+
+ queue.dequeue(partition_element);
+
+ // partition queue into left and right
+ while (queue.size() > 0)
+ {
+ queue.dequeue(temp);
+ if (compare(temp , partition_element))
+ {
+ left.enqueue(temp);
+ }
+ else
+ {
+ right.enqueue(temp);
+ }
+ }
+
+
+ long ratio;
+ if (left.size() > right.size())
+ ratio = left.size()/(right.size()+1); // add 1 so we can't divide by zero
+ else
+ ratio = right.size()/(left.size()+1);
+
+ sort_this_queue(left,ratio+depth,compare);
+ sort_this_queue(right,ratio+depth,compare);
+
+ // combine the two queues
+ left.swap(queue);
+ queue.enqueue(partition_element);
+ queue.cat(right);
+ }
+ }
+
+
+ };
+
+ template <
+ typename queue_base
+ >
+ inline void swap (
+ queue_sort_1<queue_base>& a,
+ queue_sort_1<queue_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename queue_base
+ >
+ void queue_sort_1<queue_base>::
+ sort (
+ )
+ {
+ if (this->size() > 1)
+ {
+ sort_this_queue(*this,0,std::less<typename queue_base::type>());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_QUEUE_SORt_1_
+
diff --git a/ml/dlib/dlib/queue/queue_sort_abstract.h b/ml/dlib/dlib/queue/queue_sort_abstract.h
new file mode 100644
index 000000000..54c06f430
--- /dev/null
+++ b/ml/dlib/dlib/queue/queue_sort_abstract.h
@@ -0,0 +1,74 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_QUEUE_SORt_ABSTRACT_
+#ifdef DLIB_QUEUE_SORt_ABSTRACT_
+
+
+#include "queue_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename queue_base
+ >
+ class queue_sort : public queue_base
+ {
+
+ /*!
+ REQUIREMENTS ON QUEUE_BASE
+ - is an implementation of queue/queue_kernel_abstract.h
+ - queue_base::type must be a type with that is comparable via operator<
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ sort() may invalidate pointers and references to internal data.
+
+ WHAT THIS EXTENSION DOES FOR QUEUE
+ This gives a queue the ability to sort its contents by calling sort().
+ !*/
+
+
+ public:
+
+ void sort (
+ );
+ /*!
+ ensures
+ - for all elements in #*this the ith element is <= the i+1 element
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ data may be lost if sort() throws
+ !*/
+
+ template <typename compare_type>
+ void sort (
+ const compare_type& compare
+ );
+ /*!
+ ensures
+ - for all elements in #*this the ith element is <= the i+1 element
+ - uses compare(a,b) as the < operator. So if compare(a,b) == true
+ then a comes before b in the resulting ordering.
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ data may be lost if sort() throws
+ !*/
+ };
+
+ template <
+ template queue_base
+ >
+ inline void swap (
+ queue_sort<queue_base>& a,
+ queue_sort<queue_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_QUEUE_SORt_ABSTRACT_
+
diff --git a/ml/dlib/dlib/rand.h b/ml/dlib/dlib/rand.h
new file mode 100644
index 000000000..5e0146f0d
--- /dev/null
+++ b/ml/dlib/dlib/rand.h
@@ -0,0 +1,9 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANd_
+#define DLIB_RANd_
+
+#include "rand/rand_kernel_1.h"
+
+#endif // DLIB_RANd_
+
diff --git a/ml/dlib/dlib/rand/mersenne_twister.h b/ml/dlib/dlib/rand/mersenne_twister.h
new file mode 100644
index 000000000..66ee87f54
--- /dev/null
+++ b/ml/dlib/dlib/rand/mersenne_twister.h
@@ -0,0 +1,210 @@
+/* boost random/mersenne_twister.hpp header file
+ *
+ * Copyright Jens Maurer 2000-2001
+ * Distributed under the Boost Software License, Version 1.0. (See
+ * accompanying file LICENSE_1_0.txt or copy at
+ * http://www.boost.org/LICENSE_1_0.txt)
+ *
+ * See http://www.boost.org for most recent version including documentation.
+ *
+ * $Id: mersenne_twister.hpp,v 1.20 2005/07/21 22:04:31 jmaurer Exp $
+ *
+ * Revision history
+ * 2001-02-18 moved to individual header files
+*/
+
+#ifndef DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP
+#define DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP
+
+#include <iostream>
+#include <algorithm> // std::copy
+#include <stdexcept>
+#include "../uintn.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+ namespace random_helpers
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ // http://www.math.keio.ac.jp/matumoto/emt.html
+ template<
+ class UIntType,
+ int w,
+ int n,
+ int m,
+ int r,
+ UIntType a,
+ int u,
+ int s,
+ UIntType b,
+ int t,
+ UIntType c,
+ int l,
+ UIntType val
+ >
+ class mersenne_twister
+ {
+ public:
+ typedef UIntType result_type;
+ const static int word_size = w;
+ const static int state_size = n;
+ const static int shift_size = m;
+ const static int mask_bits = r;
+ const static UIntType parameter_a = a;
+ const static int output_u = u;
+ const static int output_s = s;
+ const static UIntType output_b = b;
+ const static int output_t = t;
+ const static UIntType output_c = c;
+ const static int output_l = l;
+
+ const static bool has_fixed_range = false;
+
+ mersenne_twister() { seed(); }
+
+ explicit mersenne_twister(UIntType value) { seed(value); }
+
+ void seed () { seed(UIntType(5489)); }
+
+ // compiler-generated copy ctor and assignment operator are fine
+
+ void seed(UIntType value)
+ {
+ // New seeding algorithm from
+ // http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/emt19937ar.html
+ // In the previous versions, MSBs of the seed affected only MSBs of the
+ // state x[].
+ const UIntType mask = ~0u;
+ x[0] = value & mask;
+ for (i = 1; i < n; i++) {
+ // See Knuth "The Art of Computer Programming" Vol. 2, 3rd ed., page 106
+ x[i] = (1812433253UL * (x[i-1] ^ (x[i-1] >> (w-2))) + i) & mask;
+ }
+ }
+
+ result_type min() const { return 0; }
+ result_type max() const
+ {
+ // avoid "left shift count >= with of type" warning
+ result_type res = 0;
+ for(int i = 0; i < w; ++i)
+ res |= (1u << i);
+ return res;
+ }
+
+ result_type operator()();
+
+ friend void serialize(
+ const mersenne_twister& item,
+ std::ostream& out
+ )
+ {
+ dlib::serialize(item.x, out);
+ dlib::serialize(item.i, out);
+ }
+
+ friend void deserialize(
+ mersenne_twister& item,
+ std::istream& in
+ )
+ {
+ dlib::deserialize(item.x, in);
+ dlib::deserialize(item.i, in);
+ }
+
+ private:
+
+ void twist(int block);
+
+ // state representation: next output is o(x(i))
+ // x[0] ... x[k] x[k+1] ... x[n-1] x[n] ... x[2*n-1] represents
+ // x(i-k) ... x(i) x(i+1) ... x(i-k+n-1) x(i-k-n) ... x[i(i-k-1)]
+ // The goal is to always have x(i-n) ... x(i-1) available for
+ // operator== and save/restore.
+
+ UIntType x[2*n];
+ int i;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template<
+ class UIntType, int w, int n, int m, int r, UIntType a, int u,
+ int s, UIntType b, int t, UIntType c, int l, UIntType val
+ >
+ void mersenne_twister<UIntType,w,n,m,r,a,u,s,b,t,c,l,val>::twist(
+ int block
+ )
+ {
+ const UIntType upper_mask = (~0u) << r;
+ const UIntType lower_mask = ~upper_mask;
+
+ if(block == 0) {
+ for(int j = n; j < 2*n; j++) {
+ UIntType y = (x[j-n] & upper_mask) | (x[j-(n-1)] & lower_mask);
+ x[j] = x[j-(n-m)] ^ (y >> 1) ^ (y&1 ? a : 0);
+ }
+ } else if (block == 1) {
+ // split loop to avoid costly modulo operations
+ { // extra scope for MSVC brokenness w.r.t. for scope
+ for(int j = 0; j < n-m; j++) {
+ UIntType y = (x[j+n] & upper_mask) | (x[j+n+1] & lower_mask);
+ x[j] = x[j+n+m] ^ (y >> 1) ^ (y&1 ? a : 0);
+ }
+ }
+
+ for(int j = n-m; j < n-1; j++) {
+ UIntType y = (x[j+n] & upper_mask) | (x[j+n+1] & lower_mask);
+ x[j] = x[j-(n-m)] ^ (y >> 1) ^ (y&1 ? a : 0);
+ }
+ // last iteration
+ UIntType y = (x[2*n-1] & upper_mask) | (x[0] & lower_mask);
+ x[n-1] = x[m-1] ^ (y >> 1) ^ (y&1 ? a : 0);
+ i = 0;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template<
+ class UIntType, int w, int n, int m, int r, UIntType a, int u,
+ int s, UIntType b, int t, UIntType c, int l, UIntType val
+ >
+ inline typename mersenne_twister<UIntType,w,n,m,r,a,u,s,b,t,c,l,val>::result_type
+ mersenne_twister<UIntType,w,n,m,r,a,u,s,b,t,c,l,val>::operator()(
+ )
+ {
+ if(i == n)
+ twist(0);
+ else if(i >= 2*n)
+ twist(1);
+ // Step 4
+ UIntType z = x[i];
+ ++i;
+ z ^= (z >> u);
+ z ^= ((z << s) & b);
+ z ^= ((z << t) & c);
+ z ^= (z >> l);
+ return z;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ } // namespace random
+
+
+ typedef random_helpers::mersenne_twister<uint32,32,351,175,19,0xccab8ee7,11,
+ 7,0x31b6ab00,15,0xffe50000,17, 0xa37d3c92> mt11213b;
+
+ // validation by experiment from mt19937.c
+ typedef random_helpers::mersenne_twister<uint32,32,624,397,31,0x9908b0df,11,
+ 7,0x9d2c5680,15,0xefc60000,18, 3346425566U> mt19937;
+
+} // namespace dlib
+
+
+#endif // DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP
+
diff --git a/ml/dlib/dlib/rand/rand_kernel_1.h b/ml/dlib/dlib/rand/rand_kernel_1.h
new file mode 100644
index 000000000..a1847be24
--- /dev/null
+++ b/ml/dlib/dlib/rand/rand_kernel_1.h
@@ -0,0 +1,354 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RAND_KERNEl_1_
+#define DLIB_RAND_KERNEl_1_
+
+#include <string>
+#include "../algs.h"
+#include "rand_kernel_abstract.h"
+#include "mersenne_twister.h"
+#include "../is_kind.h"
+#include <iostream>
+#include "../serialize.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+
+ class rand
+ {
+
+ /*!
+ INITIAL VALUE
+ - seed == ""
+
+ CONVENTION
+ - the random numbers come from the boost mersenne_twister code
+ - get_seed() == seed
+ !*/
+
+ public:
+
+ // These typedefs are here for backwards compatibility with older versions of dlib.
+ typedef rand kernel_1a;
+ typedef rand float_1a;
+
+ rand(
+ )
+ {
+ init();
+ }
+
+ rand (
+ time_t seed_value
+ )
+ {
+ init();
+ set_seed(cast_to_string(seed_value));
+ }
+
+ rand (
+ const std::string& seed_value
+ )
+ {
+ init();
+ set_seed(seed_value);
+ }
+
+ virtual ~rand(
+ )
+ {}
+
+ void clear(
+ )
+ {
+ mt.seed();
+ seed.clear();
+
+ has_gaussian = false;
+ next_gaussian = 0;
+
+ // prime the generator a bit
+ for (int i = 0; i < 10000; ++i)
+ mt();
+ }
+
+ const std::string& get_seed (
+ )
+ {
+ return seed;
+ }
+
+ void set_seed (
+ const std::string& value
+ )
+ {
+ seed = value;
+
+ // make sure we do the seeding so that using a seed of "" gives the same
+ // state as calling this->clear()
+ if (value.size() != 0)
+ {
+ uint32 s = 0;
+ for (std::string::size_type i = 0; i < seed.size(); ++i)
+ {
+ s = (s*37) + static_cast<uint32>(seed[i]);
+ }
+ mt.seed(s);
+ }
+ else
+ {
+ mt.seed();
+ }
+
+ // prime the generator a bit
+ for (int i = 0; i < 10000; ++i)
+ mt();
+
+
+ has_gaussian = false;
+ next_gaussian = 0;
+ }
+
+ unsigned char get_random_8bit_number (
+ )
+ {
+ return static_cast<unsigned char>(mt());
+ }
+
+ uint16 get_random_16bit_number (
+ )
+ {
+ return static_cast<uint16>(mt());
+ }
+
+ inline uint32 get_random_32bit_number (
+ )
+ {
+ return mt();
+ }
+
+ inline uint64 get_random_64bit_number (
+ )
+ {
+ const uint64 a = get_random_32bit_number();
+ const uint64 b = get_random_32bit_number();
+ return (a<<32)|b;
+ }
+
+ double get_double_in_range (
+ double begin,
+ double end
+ )
+ {
+ DLIB_ASSERT(begin <= end);
+ return begin + get_random_double()*(end-begin);
+ }
+
+ long long get_integer_in_range(
+ long long begin,
+ long long end
+ )
+ {
+ DLIB_ASSERT(begin <= end);
+ if (begin == end)
+ return begin;
+
+ auto r = get_random_64bit_number();
+ const auto limit = std::numeric_limits<decltype(r)>::max();
+ const auto range = end-begin;
+ // Use rejection sampling to remove the biased sampling you would get with
+ // the naive get_random_64bit_number()%range sampling.
+ while(r >= (limit/range)*range)
+ r = get_random_64bit_number();
+
+ return begin + static_cast<long long>(r%range);
+ }
+
+ long long get_integer(
+ long long end
+ )
+ {
+ DLIB_ASSERT(end >= 0);
+
+ return get_integer_in_range(0,end);
+ }
+
+ double get_random_double (
+ )
+ {
+ uint32 temp;
+
+ temp = rand::get_random_32bit_number();
+ temp &= 0xFFFFFF;
+
+ double val = static_cast<double>(temp);
+
+ val *= 0x1000000;
+
+ temp = rand::get_random_32bit_number();
+ temp &= 0xFFFFFF;
+
+ val += temp;
+
+ val /= max_val;
+
+ if (val < 1.0)
+ {
+ return val;
+ }
+ else
+ {
+ // return a value slightly less than 1.0
+ return 1.0 - std::numeric_limits<double>::epsilon();
+ }
+ }
+
+ float get_random_float (
+ )
+ {
+ uint32 temp;
+
+ temp = rand::get_random_32bit_number();
+ temp &= 0xFFFFFF;
+
+ const float scale = 1.0/0x1000000;
+
+ const float val = static_cast<float>(temp)*scale;
+ if (val < 1.0f)
+ {
+ return val;
+ }
+ else
+ {
+ // return a value slightly less than 1.0
+ return 1.0f - std::numeric_limits<float>::epsilon();
+ }
+ }
+
+ double get_random_gaussian (
+ )
+ {
+ if (has_gaussian)
+ {
+ has_gaussian = false;
+ return next_gaussian;
+ }
+
+ double x1, x2, w;
+
+ const double rndmax = std::numeric_limits<dlib::uint32>::max();
+
+ // Generate a pair of Gaussian random numbers using the Box-Muller transformation.
+ do
+ {
+ const double rnd1 = get_random_32bit_number()/rndmax;
+ const double rnd2 = get_random_32bit_number()/rndmax;
+
+ x1 = 2.0 * rnd1 - 1.0;
+ x2 = 2.0 * rnd2 - 1.0;
+ w = x1 * x1 + x2 * x2;
+ } while ( w >= 1.0 );
+
+ w = std::sqrt( (-2.0 * std::log( w ) ) / w );
+ next_gaussian = x2 * w;
+ has_gaussian = true;
+ return x1 * w;
+ }
+
+ void swap (
+ rand& item
+ )
+ {
+ exchange(mt,item.mt);
+ exchange(seed, item.seed);
+ exchange(has_gaussian, item.has_gaussian);
+ exchange(next_gaussian, item.next_gaussian);
+ }
+
+ friend void serialize(
+ const rand& item,
+ std::ostream& out
+ );
+
+ friend void deserialize(
+ rand& item,
+ std::istream& in
+ );
+
+ private:
+
+ void init()
+ {
+ // prime the generator a bit
+ for (int i = 0; i < 10000; ++i)
+ mt();
+
+ max_val = 0xFFFFFF;
+ max_val *= 0x1000000;
+ max_val += 0xFFFFFF;
+ max_val += 0.05;
+
+
+ has_gaussian = false;
+ next_gaussian = 0;
+ }
+
+ mt19937 mt;
+
+ std::string seed;
+
+
+ double max_val;
+ bool has_gaussian;
+ double next_gaussian;
+ };
+
+
+ inline void swap (
+ rand& a,
+ rand& b
+ ) { a.swap(b); }
+
+
+ template <>
+ struct is_rand<rand>
+ {
+ static const bool value = true;
+ };
+
+ inline void serialize(
+ const rand& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+
+ serialize(item.mt, out);
+ serialize(item.seed, out);
+ serialize(item.has_gaussian, out);
+ serialize(item.next_gaussian, out);
+ }
+
+ inline void deserialize(
+ rand& item,
+ std::istream& in
+ )
+ {
+ int version;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Error deserializing object of type rand: unexpected version.");
+
+ deserialize(item.mt, in);
+ deserialize(item.seed, in);
+ deserialize(item.has_gaussian, in);
+ deserialize(item.next_gaussian, in);
+ }
+}
+
+#endif // DLIB_RAND_KERNEl_1_
+
+
diff --git a/ml/dlib/dlib/rand/rand_kernel_abstract.h b/ml/dlib/dlib/rand/rand_kernel_abstract.h
new file mode 100644
index 000000000..67af27a9e
--- /dev/null
+++ b/ml/dlib/dlib/rand/rand_kernel_abstract.h
@@ -0,0 +1,218 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RAND_KERNEl_ABSTRACT_
+#ifdef DLIB_RAND_KERNEl_ABSTRACT_
+
+#include <string>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+
+ class rand
+ {
+
+ /*!
+ INITIAL VALUE
+ get_seed() == ""
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a pseudorandom number generator.
+ !*/
+
+ public:
+
+
+ rand(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ rand (
+ time_t seed_value
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #get_seed() == cast_to_string(seed_value)
+ - This version of the constructor is equivalent to using
+ the default constructor and then calling set_seed(cast_to_string(seed_value))
+ throws
+ - std::bad_alloc
+ !*/
+
+ rand (
+ const std::string& seed_value
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #get_seed() == seed_value
+ - This version of the constructor is equivalent to using
+ the default constructor and then calling set_seed(seed_value)
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~rand(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ const std::string& get_seed (
+ );
+ /*!
+ ensures
+ - returns the string currently being used as the random seed.
+ !*/
+
+ void set_seed (
+ const std::string& value
+ );
+ /*!
+ ensures
+ - #get_seed() == value
+ !*/
+
+ unsigned char get_random_8bit_number (
+ );
+ /*!
+ ensures
+ - returns a pseudorandom number in the range 0 to 255
+ !*/
+
+ uint16 get_random_16bit_number (
+ );
+ /*!
+ ensures
+ - returns a pseudorandom number in the range 0 to 2^16-1
+ !*/
+
+ uint32 get_random_32bit_number (
+ );
+ /*!
+ ensures
+ - returns a pseudorandom number in the range 0 to 2^32-1
+ !*/
+
+ uint64 get_random_64bit_number (
+ );
+ /*!
+ ensures
+ - returns a pseudorandom number in the range 0 to 2^64-1
+ !*/
+
+ float get_random_float (
+ );
+ /*!
+ ensures
+ - returns a random float number N where: 0.0 <= N < 1.0.
+ !*/
+
+ double get_random_double (
+ );
+ /*!
+ ensures
+ - returns a random double number N where: 0.0 <= N < 1.0.
+ !*/
+
+ double get_double_in_range (
+ double begin,
+ double end
+ );
+ /*!
+ requires
+ - begin <= end
+ ensures
+ - if (begin < end) then
+ - returns a random double number N where: begin <= N < end.
+ - else
+ - returns begin
+ !*/
+
+ long long get_integer_in_range(
+ long long begin,
+ long long end
+ );
+ /*!
+ requires
+ - begin <= end
+ ensures
+ - returns a random integer selected from the range: begin <= N < end
+ The integer is selected uniformly at random.
+ !*/
+
+ long long get_integer(
+ long long end
+ );
+ /*!
+ requires
+ - 0 <= end
+ ensures
+ - returns get_integer_in_range(0,end)
+ !*/
+
+ double get_random_gaussian (
+ );
+ /*!
+ ensures
+ - returns a random number sampled from a Gaussian distribution
+ with mean 0 and standard deviation 1.
+ !*/
+
+ void swap (
+ rand& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ inline void swap (
+ rand& a,
+ rand& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ void serialize (
+ const rand& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ rand& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_RAND_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/random_forest.h b/ml/dlib/dlib/random_forest.h
new file mode 100644
index 000000000..082f36703
--- /dev/null
+++ b/ml/dlib/dlib/random_forest.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANDOM_FOReST_H_
+#define DLIB_RANDOM_FOReST_H_
+
+#include "random_forest/random_forest_regression.h"
+
+#endif // DLIB_RANDOM_FOReST_H_
+
+
diff --git a/ml/dlib/dlib/random_forest/random_forest_regression.h b/ml/dlib/dlib/random_forest/random_forest_regression.h
new file mode 100644
index 000000000..a61f7a1a2
--- /dev/null
+++ b/ml/dlib/dlib/random_forest/random_forest_regression.h
@@ -0,0 +1,738 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANdOM_FOREST_REGRESSION_H_
+#define DLIB_RANdOM_FOREST_REGRESSION_H_
+
+#include "random_forest_regression_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include <algorithm>
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class dense_feature_extractor
+ {
+
+ public:
+ typedef uint32_t feature;
+ typedef matrix<double,0,1> sample_type;
+
+ dense_feature_extractor(
+ ) = default;
+
+ void setup (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y
+ )
+ {
+ DLIB_CASSERT(x.size() > 0);
+ DLIB_CASSERT(x.size() == y.size());
+ for (auto& el : x)
+ DLIB_CASSERT(el.size() == x[0].size(), "All the vectors in a training set have to have the same dimensionality.");
+
+ DLIB_CASSERT(x[0].size() != 0, "The vectors can't be empty.");
+
+ num_feats = x[0].size();
+ }
+
+
+ void get_random_features (
+ dlib::rand& rnd,
+ size_t num,
+ std::vector<feature>& feats
+ ) const
+ {
+ DLIB_ASSERT(max_num_feats() != 0);
+ num = std::min(num, num_feats);
+
+ feats.clear();
+ for (size_t i = 0; i < num_feats; ++i)
+ feats.push_back(i);
+
+ // now pick num features at random
+ for (size_t i = 0; i < num; ++i)
+ {
+ auto idx = rnd.get_integer_in_range(i,num_feats);
+ std::swap(feats[i], feats[idx]);
+ }
+ feats.resize(num);
+ }
+
+ double extract_feature_value (
+ const sample_type& item,
+ const feature& f
+ ) const
+ {
+ DLIB_ASSERT(max_num_feats() != 0);
+ return item(f);
+ }
+
+ size_t max_num_feats (
+ ) const
+ {
+ return num_feats;
+ }
+
+ friend void serialize(const dense_feature_extractor& item, std::ostream& out)
+ {
+ serialize("dense_feature_extractor", out);
+ serialize(item.num_feats, out);
+ }
+
+ friend void deserialize(dense_feature_extractor& item, std::istream& in)
+ {
+ check_serialized_version("dense_feature_extractor", in);
+ deserialize(item.num_feats, in);
+ }
+
+ private:
+ size_t num_feats = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ struct internal_tree_node
+ {
+ uint32_t left;
+ uint32_t right;
+ float split_threshold;
+ typename feature_extractor::feature split_feature;
+ };
+
+ template <typename feature_extractor>
+ void serialize(const internal_tree_node<feature_extractor>& item, std::ostream& out)
+ {
+ serialize(item.left, out);
+ serialize(item.right, out);
+ serialize(item.split_threshold, out);
+ serialize(item.split_feature, out);
+ }
+
+ template <typename feature_extractor>
+ void deserialize(internal_tree_node<feature_extractor>& item, std::istream& in)
+ {
+ deserialize(item.left, in);
+ deserialize(item.right, in);
+ deserialize(item.split_threshold, in);
+ deserialize(item.split_feature, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor = dense_feature_extractor
+ >
+ class random_forest_regression_function
+ {
+
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::sample_type sample_type;
+
+ random_forest_regression_function(
+ ) = default;
+
+ random_forest_regression_function (
+ feature_extractor_type&& fe_,
+ std::vector<std::vector<internal_tree_node<feature_extractor>>>&& trees_,
+ std::vector<std::vector<float>>&& leaves_
+ ) :
+ fe(std::move(fe_)),
+ trees(std::move(trees_)),
+ leaves(std::move(leaves_))
+ {
+ DLIB_ASSERT(trees.size() > 0);
+ DLIB_ASSERT(trees.size() == leaves.size(), "Every set of tree nodes has to have leaves");
+#ifdef ENABLE_ASSERTS
+ for (size_t i = 0; i < trees.size(); ++i)
+ {
+ DLIB_ASSERT(trees[i].size() > 0, "A tree can't have 0 leaves.");
+ for (auto& node : trees[i])
+ {
+ DLIB_ASSERT(trees[i].size()+leaves[i].size() > node.left, "left node index in tree is too big. There is no associated tree node or leaf.");
+ DLIB_ASSERT(trees[i].size()+leaves[i].size() > node.right, "right node index in tree is too big. There is no associated tree node or leaf.");
+ }
+ }
+#endif
+ }
+
+ size_t get_num_trees(
+ ) const
+ {
+ return trees.size();
+ }
+
+ const std::vector<std::vector<internal_tree_node<feature_extractor>>>& get_internal_tree_nodes (
+ ) const { return trees; }
+
+ const std::vector<std::vector<float>>& get_tree_leaves (
+ ) const { return leaves; }
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const { return fe; }
+
+ double operator() (
+ const sample_type& x
+ ) const
+ {
+ DLIB_ASSERT(get_num_trees() > 0);
+
+ double accum = 0;
+
+ for (size_t i = 0; i < trees.size(); ++i)
+ {
+ auto& tree = trees[i];
+ // walk the tree to the leaf
+ uint32_t idx = 0;
+ while(idx < tree.size())
+ {
+ auto feature_value = fe.extract_feature_value(x, tree[idx].split_feature);
+ if (feature_value < tree[idx].split_threshold)
+ idx = tree[idx].left;
+ else
+ idx = tree[idx].right;
+ }
+ // compute leaf index
+ accum += leaves[i][idx-tree.size()];
+ }
+
+ return accum/trees.size();
+ }
+
+ friend void serialize(const random_forest_regression_function& item, std::ostream& out)
+ {
+ serialize("random_forest_regression_function", out);
+ serialize(item.fe, out);
+ serialize(item.trees, out);
+ serialize(item.leaves, out);
+ }
+
+ friend void deserialize(random_forest_regression_function& item, std::istream& in)
+ {
+ check_serialized_version("random_forest_regression_function", in);
+ deserialize(item.fe, in);
+ deserialize(item.trees, in);
+ deserialize(item.leaves, in);
+ }
+
+ private:
+
+ /*!
+ CONVENTION
+ - trees.size() == leaves.size()
+ - Any .left or .right index in trees that is larger than the number of
+ nodes in the tree references a leaf. Moreover, the index of the leaf is
+ computed by subtracting the number of nodes in the tree.
+ !*/
+
+ feature_extractor_type fe;
+
+ // internal nodes of trees
+ std::vector<std::vector<internal_tree_node<feature_extractor>>> trees;
+ // leaves of trees
+ std::vector<std::vector<float>> leaves;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor = dense_feature_extractor
+ >
+ class random_forest_regression_trainer
+ {
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef random_forest_regression_function<feature_extractor> trained_function_type;
+ typedef typename feature_extractor::sample_type sample_type;
+
+
+ random_forest_regression_trainer (
+ ) = default;
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const
+ {
+ return fe_;
+ }
+
+ void set_feature_extractor (
+ const feature_extractor_type& feat_extractor
+ )
+ {
+ fe_ = feat_extractor;
+ }
+
+ void set_seed (
+ const std::string& seed
+ )
+ {
+ random_seed = seed;
+ }
+
+ const std::string& get_random_seed (
+ ) const
+ {
+ return random_seed;
+ }
+
+ size_t get_num_trees (
+ ) const
+ {
+ return num_trees;
+ }
+
+ void set_num_trees (
+ size_t num
+ )
+ {
+ DLIB_CASSERT(num > 0);
+ num_trees = num;
+ }
+
+ void set_feature_subsampling_fraction (
+ double frac
+ )
+ {
+ DLIB_CASSERT(0 < frac && frac <= 1);
+ feature_subsampling_frac = frac;
+ }
+
+ double get_feature_subsampling_frac(
+ ) const
+ {
+ return feature_subsampling_frac;
+ }
+
+ void set_min_samples_per_leaf (
+ size_t num
+ )
+ {
+ DLIB_ASSERT(num > 0);
+ min_samples_per_leaf = num;
+ }
+
+ size_t get_min_samples_per_leaf(
+ ) const
+ {
+ return min_samples_per_leaf;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y
+ ) const
+ {
+ std::vector<double> junk;
+ return do_train(x,y,junk,false);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y,
+ std::vector<double>& oob_values
+ ) const
+ {
+ return do_train(x,y,oob_values,true);
+ }
+
+ private:
+
+ trained_function_type do_train (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y,
+ std::vector<double>& oob_values,
+ bool compute_oob_values
+ ) const
+ {
+ DLIB_CASSERT(x.size() == y.size());
+ DLIB_CASSERT(x.size() > 0);
+
+ feature_extractor_type fe = fe_;
+ fe.setup(x,y);
+
+ DLIB_CASSERT(fe.max_num_feats() != 0);
+
+ std::vector<std::vector<internal_tree_node<feature_extractor>>> all_trees(num_trees);
+ std::vector<std::vector<float>> all_leaves(num_trees);
+
+ const double sumy = sum(mat(y));
+
+ const size_t feats_per_node = std::max(1.0,std::round(fe.max_num_feats()*feature_subsampling_frac));
+
+ // Each tree couldn't have more than this many interior nodes. It might
+ // end up having less though. We need to know this value because the way
+ // we mark a left or right pointer in a tree as pointing to a leaf is by
+ // making its index larger than the number of interior nodes in the tree.
+ // But we don't know the tree's size before we finish building it. So we
+ // will use max_num_nodes as a proxy during tree construction and then go
+ // back and fix it once a tree's size is known.
+ const uint32_t max_num_nodes = y.size();
+
+ std::vector<uint32_t> oob_hits;
+
+ if (compute_oob_values)
+ {
+ oob_values.resize(y.size());
+ oob_hits.resize(y.size());
+ }
+
+
+ std::mutex m;
+
+ // Calling build_tree(i) creates the ith tree and stores the results in
+ // all_trees and all_leaves.
+ auto build_tree = [&](long i)
+ {
+ dlib::rand rnd(random_seed + std::to_string(i));
+ auto& tree = all_trees[i];
+ auto& leaves = all_leaves[i];
+
+ // Check if there are fewer than min_samples_per_leaf and if so then
+ // don't make any tree. Just average the things and be done.
+ if (y.size() <= min_samples_per_leaf)
+ {
+ leaves.push_back(sumy/y.size());
+ return;
+ }
+
+
+ // pick a random bootstrap of the data.
+ std::vector<std::pair<float,uint32_t>> idxs(y.size());
+ for (auto& idx : idxs)
+ idx = std::make_pair(0,rnd.get_integer(y.size()));
+
+ // We are going to use ranges_to_process as a stack that tracks which
+ // range of samples we are going to split next.
+ std::vector<range_t> ranges_to_process;
+ // start with the root of the tree, i.e. the entire range of training
+ // samples.
+ ranges_to_process.emplace_back(sumy,0,y.size());
+ // push an unpopulated root node into the tree. We will populate it
+ // when we process its corresponding range.
+ tree.emplace_back();
+
+ std::vector<typename feature_extractor::feature> feats;
+
+ while(ranges_to_process.size() > 0)
+ {
+ // Grab the next range/node to process.
+ const auto range = ranges_to_process.back();
+ ranges_to_process.pop_back();
+
+
+ // Get the split features we will consider at this node.
+ fe.get_random_features(rnd, feats_per_node, feats);
+ // Then find the best split
+ auto best_split = find_best_split_among_feats(fe, range, feats, x, y, idxs);
+
+ range_t left_split(best_split.left_sum, range.begin, best_split.split_idx);
+ range_t right_split(best_split.right_sum, best_split.split_idx, range.end);
+
+ DLIB_ASSERT(left_split.begin < left_split.end);
+ DLIB_ASSERT(right_split.begin < right_split.end);
+
+ // Now that we know the split we can populate the parent node we popped
+ // from ranges_to_process.
+ tree[range.tree_idx].split_threshold = best_split.split_threshold;
+ tree[range.tree_idx].split_feature = best_split.split_feature;
+
+ // If the left split is big enough to make a new interior leaf
+ // node. We also stop splitting if all the samples went into this node.
+ // This could happen if the features are all uniform so there just
+ // isn't any way to split them anymore.
+ if (left_split.size() > min_samples_per_leaf && right_split.size() != 0)
+ {
+ // allocate an interior leaf node for it.
+ left_split.tree_idx = tree.size();
+ tree.emplace_back();
+ // set the pointer in the parent node to the newly allocated
+ // node.
+ tree[range.tree_idx].left = left_split.tree_idx;
+
+ ranges_to_process.emplace_back(left_split);
+ }
+ else
+ {
+ // Add to leaves. Don't forget to set the pointer in the
+ // parent node to the newly allocated leaf node.
+ tree[range.tree_idx].left = leaves.size() + max_num_nodes;
+ leaves.emplace_back(left_split.avg());
+ }
+
+
+ // If the right split is big enough to make a new interior leaf
+ // node. We also stop splitting if all the samples went into this node.
+ // This could happen if the features are all uniform so there just
+ // isn't any way to split them anymore.
+ if (right_split.size() > min_samples_per_leaf && left_split.size() != 0)
+ {
+ // allocate an interior leaf node for it.
+ right_split.tree_idx = tree.size();
+ tree.emplace_back();
+ // set the pointer in the parent node to the newly allocated
+ // node.
+ tree[range.tree_idx].right = right_split.tree_idx;
+
+ ranges_to_process.emplace_back(right_split);
+ }
+ else
+ {
+ // Add to leaves. Don't forget to set the pointer in the
+ // parent node to the newly allocated leaf node.
+ tree[range.tree_idx].right = leaves.size() + max_num_nodes;
+ leaves.emplace_back(right_split.avg());
+ }
+ } // end while (still building tree)
+
+ // Fix the leaf pointers in the tree now that we know the correct
+ // tree.size() value.
+ DLIB_CASSERT(max_num_nodes >= tree.size());
+ const auto offset = max_num_nodes - tree.size();
+ for (auto& n : tree)
+ {
+ if (n.left >= max_num_nodes)
+ n.left -= offset;
+ if (n.right >= max_num_nodes)
+ n.right -= offset;
+ }
+
+
+ if (compute_oob_values)
+ {
+ std::sort(idxs.begin(), idxs.end(),
+ [](const std::pair<float,uint32_t>& a, const std::pair<float,uint32_t>& b) {return a.second<b.second; });
+
+ std::lock_guard<std::mutex> lock(m);
+
+ size_t j = 0;
+ for (size_t i = 0; i < oob_values.size(); ++i)
+ {
+ // check if i is in idxs
+ while(j < idxs.size() && i > idxs[j].second)
+ ++j;
+
+ // i isn't in idxs so it's an oob sample and we should process it.
+ if (j == idxs.size() || idxs[j].second != i)
+ {
+ oob_hits[i]++;
+
+ // walk the tree to find the leaf value for this oob sample
+ uint32_t idx = 0;
+ while(idx < tree.size())
+ {
+ auto feature_value = fe.extract_feature_value(x[i], tree[idx].split_feature);
+ if (feature_value < tree[idx].split_threshold)
+ idx = tree[idx].left;
+ else
+ idx = tree[idx].right;
+ }
+ oob_values[i] += leaves[idx-tree.size()];
+ }
+ }
+ }
+ };
+
+ if (verbose)
+ parallel_for_verbose(0, num_trees, build_tree);
+ else
+ parallel_for(0, num_trees, build_tree);
+
+
+ if (compute_oob_values)
+ {
+ double meanval = 0;
+ double cnt = 0;
+ for (size_t i = 0; i < oob_values.size(); ++i)
+ {
+ if (oob_hits[i] != 0)
+ {
+ oob_values[i] /= oob_hits[i];
+ meanval += oob_values[i];
+ ++cnt;
+ }
+ }
+
+ // If there are some elements that didn't get hits, we set their oob values
+ // to the mean oob value.
+ if (cnt != 0)
+ {
+ const double typical_value = meanval/cnt;
+ for (size_t i = 0; i < oob_values.size(); ++i)
+ {
+ if (oob_hits[i] == 0)
+ oob_values[i] = typical_value;
+ }
+ }
+ }
+
+ return trained_function_type(std::move(fe), std::move(all_trees), std::move(all_leaves));
+ }
+
+ struct range_t
+ {
+ range_t(
+ double sumy,
+ uint32_t begin,
+ uint32_t end
+ ) : sumy(sumy), begin(begin), end(end), tree_idx(0) {}
+
+ double sumy;
+ uint32_t begin;
+ uint32_t end;
+
+ // Every range object corresponds to an entry in a tree. This tells you the
+ // tree node that owns the range.
+ uint32_t tree_idx;
+
+ uint32_t size() const { return end-begin; }
+ double avg() const { return sumy/size(); }
+ };
+
+ struct best_split_details
+ {
+ double score = -std::numeric_limits<double>::infinity();
+ double left_sum;
+ double right_sum;
+ uint32_t split_idx;
+ double split_threshold;
+ typename feature_extractor::feature split_feature;
+
+ bool operator < (const best_split_details& rhs) const
+ {
+ return score < rhs.score;
+ }
+ };
+
+ static best_split_details find_best_split (
+ const range_t& range,
+ const std::vector<double>& y,
+ const std::vector<std::pair<float,uint32_t>>& idxs
+ )
+ /*!
+ requires
+ - max(mat(idxs)) < y.size()
+ - range.sumy == sum of y[idxs[j].second] for all valid j in range [range.begin, range.end).
+ ensures
+ - finds a threshold T such that there exists an i satisfying the following:
+ - y[idxs[j].second] < T for all j <= i
+ - y[idxs[j].second] > T for all j > i
+ Therefore, the threshold T partitions the contents of y into two groups,
+ relative to the ordering established by idxs. Moreover the partitioning
+ of y values into two groups has the additional requirement that it is
+ optimal in the sense that the sum of the squared deviations from each
+ partition's mean is minimized.
+ !*/
+ {
+
+ size_t best_i = range.begin;
+ double best_score = -1;
+ double left_sum = 0;
+ double best_left_sum = y[idxs[range.begin].second];
+ const auto size = range.size();
+ size_t left_size = 0;
+ for (size_t i = range.begin; i+1 < range.end; ++i)
+ {
+ ++left_size;
+ left_sum += y[idxs[i].second];
+
+ // Don't split here because the next element has the same feature value so
+ // we can't *really* split here.
+ if (idxs[i].first==idxs[i+1].first)
+ continue;
+
+ const double right_sum = range.sumy-left_sum;
+
+ const double score = left_sum*left_sum/left_size + right_sum*right_sum/(size-left_size);
+
+ if (score > best_score)
+ {
+ best_score = score;
+ best_i = i;
+ best_left_sum = left_sum;
+ }
+ }
+
+ best_split_details result;
+ result.score = best_score;
+ result.left_sum = best_left_sum;
+ result.right_sum = range.sumy-best_left_sum;
+ result.split_idx = best_i+1; // one past the end of the left range
+ result.split_threshold = (idxs[best_i].first+idxs[best_i+1].first)/2;
+
+ return result;
+ }
+
+
+ static best_split_details find_best_split_among_feats(
+ const feature_extractor& fe,
+ const range_t& range,
+ const std::vector<typename feature_extractor::feature>& feats,
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y,
+ std::vector<std::pair<float,uint32_t>>& idxs
+ )
+ {
+ auto compare_first = [](const std::pair<float,uint32_t>& a, const std::pair<float,uint32_t>& b) { return a.first<b.first; };
+ best_split_details best;
+ for (auto& feat : feats)
+ {
+ // Extract feature values for this feature and sort the indexes based on
+ // that feature so we can then find the best split.
+ for (auto i = range.begin; i < range.end; ++i)
+ idxs[i].first = fe.extract_feature_value(x[idxs[i].second], feat);
+
+ std::sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first);
+
+ auto split = find_best_split(range, y, idxs);
+
+ if (best < split)
+ {
+ best = split;
+ best.split_feature = feat;
+ }
+ }
+
+ // resort idxs based on winning feat
+ for (auto i = range.begin; i < range.end; ++i)
+ idxs[i].first = fe.extract_feature_value(x[idxs[i].second], best.split_feature);
+ std::sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first);
+
+ return best;
+ }
+
+ std::string random_seed;
+ size_t num_trees = 1000;
+ double feature_subsampling_frac = 1.0/3.0;
+ size_t min_samples_per_leaf = 5;
+ feature_extractor_type fe_;
+ bool verbose = false;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANdOM_FOREST_REGRESSION_H_
+
+
diff --git a/ml/dlib/dlib/random_forest/random_forest_regression_abstract.h b/ml/dlib/dlib/random_forest/random_forest_regression_abstract.h
new file mode 100644
index 000000000..8ece1f04b
--- /dev/null
+++ b/ml/dlib/dlib/random_forest/random_forest_regression_abstract.h
@@ -0,0 +1,460 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANdOM_FOREST_REGRESION_ABSTRACT_H_
+#ifdef DLIB_RANdOM_FOREST_REGRESION_ABSTRACT_H_
+
+#include <vector>
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class dense_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for extracting features from objects. In particular,
+ it is designed to be used with the random forest regression tools discussed
+ below.
+
+ This particular feature extract does almost nothing since it works on
+ vectors in R^n and simply selects elements from each vector. However, the
+ tools below are templated and allow you to design your own feature extractors
+ that operate on whatever object types you create. So for example, maybe
+ you want to perform regression on images rather than vectors. Moreover,
+ your feature extraction could be more complex. Maybe you are selecting
+ differences between pairs of pixels in an image or doing something
+ involving geometric transforms during feature extraction. Any of these
+ kinds of more complex feature extraction patterns can be realized with the
+ random forest tools by implementing your own feature extractor object and
+ using it with the random forest objects.
+
+ Therefore, you should consider this dense_feature_extractor as an example
+ that documents the interface as well as the simple default extractor for
+ use with dense vectors.
+
+
+ THREAD SAFETY
+ It is safe to call const members of this object from multiple threads. ANY
+ USER DEFINED FEATURE EXTRACTORS MUST ALSO MEET THIS GUARONTEE AS WELL SINCE
+ IT IS ASSUMED BY THE RANDOM FOREST TRAINING ROUTINES.
+ !*/
+
+ public:
+ typedef uint32_t feature;
+ typedef matrix<double,0,1> sample_type;
+
+ dense_feature_extractor(
+ );
+ /*!
+ ensures
+ - #max_num_feats() == 0
+ !*/
+
+ void setup (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y
+ );
+ /*!
+ requires
+ - x.size() == y.size()
+ - x.size() > 0
+ - x[0].size() > 0
+ - all the vectors in x have the same dimensionality.
+ ensures
+ - Configures this feature extractor to work on the given training data.
+ For dense feature extractors all we do is record the dimensionality of
+ the training vectors.
+ - #max_num_feats() == x[0].size()
+ (In general, setup() sets max_num_feats() to some non-zero value so that
+ the other methods of this object can then be called. The point of setup()
+ is to allow a feature extractor to gather whatever statistics it needs from
+ training data. That is, more complex feature extraction strategies my
+ themselves be trained from data.)
+ !*/
+
+ void get_random_features (
+ dlib::rand& rnd,
+ size_t num,
+ std::vector<feature>& feats
+ ) const;
+ /*!
+ requires
+ - max_num_feats() != 0
+ ensures
+ - #feats.size() == min(num, max_num_feats())
+ - This function randomly identifies num features and stores them into feats.
+ These feature objects can then be used with extract_feature_value() to
+ obtain a value from any particular sample_type object. This value is the
+ "feature value" used by a decision tree algorithm to deice how to split
+ and traverse trees.
+ - The above two conditions define the behavior of get_random_features() in
+ general. For this specific implementation of the feature extraction interface
+ this function selects num integer values from the range [0, max_num_feats()),
+ without replacement. These values are stored into feats.
+ !*/
+
+ double extract_feature_value (
+ const sample_type& item,
+ const feature& f
+ ) const;
+ /*!
+ requires
+ - #max_num_feats() != 0
+ - f was produced from a call to get_random_features().
+ ensures
+ - Extracts the feature value corresponding to f. For this simple dense
+ feature extractor this simply means returning item(f). But in general
+ you can design feature extractors that do something more complex.
+ !*/
+
+ size_t max_num_feats (
+ ) const;
+ /*!
+ ensures
+ - returns the number of distinct features this object might extract. That is,
+ a feature extractor essentially defines a mapping from sample_type objects to
+ vectors in R^max_num_feats().
+ !*/
+ };
+
+ void serialize(const dense_feature_extractor& item, std::ostream& out);
+ void deserialize(dense_feature_extractor& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ struct internal_tree_node
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is an internal node in a regression tree. See the code of
+ random_forest_regression_function to see how it is used to create a tree.
+ !*/
+
+ uint32_t left;
+ uint32_t right;
+ float split_threshold;
+ typename feature_extractor::feature split_feature;
+ };
+
+ template <typename feature_extractor>
+ void serialize(const internal_tree_node<feature_extractor>& item, std::ostream& out);
+ template <typename feature_extractor>
+ void deserialize(internal_tree_node<feature_extractor>& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor = dense_feature_extractor
+ >
+ class random_forest_regression_function
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ feature_extractor must be dense_feature_extractor or a type with a
+ compatible interface.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a regression forest. This is a collection of
+ decision trees that take an object as input and each vote on a real value
+ to associate with the object. The final real value output is the average
+ of all the votes from each of the trees.
+ !*/
+
+ public:
+
+ typedef feature_extractor feature_extractor_type;
+ typedef typename feature_extractor::sample_type sample_type;
+
+ random_forest_regression_function(
+ );
+ /*!
+ ensures
+ - #num_trees() == 0
+ !*/
+
+ random_forest_regression_function (
+ feature_extractor_type&& fe_,
+ std::vector<std::vector<internal_tree_node<feature_extractor>>>&& trees_,
+ std::vector<std::vector<float>>&& leaves_
+ );
+ /*!
+ requires
+ - trees.size() > 0
+ - trees.size() = leaves.size()
+ - for all valid i:
+ - leaves[i].size() > 0
+ - trees[i].size()+leaves[i].size() > the maximal left or right index values in trees[i].
+ (i.e. each left or right value must index to some existing internal tree node or leaf node).
+ ensures
+ - #get_internal_tree_nodes() == trees_
+ - #get_tree_leaves() == leaves_
+ - #get_feature_extractor() == fe_
+ !*/
+
+ size_t get_num_trees(
+ ) const;
+ /*!
+ ensures
+ - returns the number of trees in this regression forest.
+ !*/
+
+ const std::vector<std::vector<internal_tree_node<feature_extractor>>>& get_internal_tree_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the internal tree nodes that define the regression trees.
+ - get_internal_tree_nodes().size() == get_num_trees()
+ !*/
+
+ const std::vector<std::vector<float>>& get_tree_leaves (
+ ) const;
+ /*!
+ ensures
+ - returns the tree leaves that define the regression trees.
+ - get_tree_leaves().size() == get_num_trees()
+ !*/
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by the trees.
+ !*/
+
+ double operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ requires
+ - get_num_trees() > 0
+ ensures
+ - Maps x to a real value and returns the value. To do this, we find the
+ get_num_trees() leaf values associated with x and then return the average
+ of these leaf values.
+ !*/
+ };
+
+ void serialize(const random_forest_regression_function& item, std::ostream& out);
+ void deserialize(random_forest_regression_function& item, std::istream& in);
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor = dense_feature_extractor
+ >
+ class random_forest_regression_trainer
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ feature_extractor must be dense_feature_extractor or a type with a
+ compatible interface.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements Breiman's classic random forest regression
+ algorithm. The algorithm learns to map objects, nominally vectors in R^n,
+ into the reals. It essentially optimizes the mean squared error by fitting
+ a bunch of decision trees, each of which vote on the output value of the
+ regressor. The final prediction is obtained by averaging all the
+ predictions.
+
+ For more information on the algorithm see:
+ Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32.
+ !*/
+
+ public:
+ typedef feature_extractor feature_extractor_type;
+ typedef random_forest_regression_function<feature_extractor> trained_function_type;
+ typedef typename feature_extractor::sample_type sample_type;
+
+
+ random_forest_regression_trainer (
+ );
+ /*!
+ ensures
+ - #get_min_samples_per_leaf() == 5
+ - #get_num_trees() == 1000
+ - #get_feature_subsampling_frac() == 1.0/3.0
+ - #get_feature_extractor() == a default initialized feature extractor.
+ - #get_random_seed() == ""
+ - this object is not verbose.
+ !*/
+
+ const feature_extractor_type& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used when train() is invoked.
+ !*/
+
+ void set_feature_extractor (
+ const feature_extractor_type& feat_extractor
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feat_extractor
+ !*/
+
+ void set_seed (
+ const std::string& seed
+ );
+ /*!
+ ensures
+ - #get_random_seed() == seed
+ !*/
+
+ const std::string& get_random_seed (
+ ) const;
+ /*!
+ ensures
+ - A central part of this algorithm is random selection of both training
+ samples and features. This function returns the seed used to initialized
+ the random number generator used for these random selections.
+ !*/
+
+ size_t get_num_trees (
+ ) const;
+ /*!
+ ensures
+ - Random forests built by this object will contain get_num_trees() trees.
+ !*/
+
+ void set_num_trees (
+ size_t num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_num_trees() == num
+ !*/
+
+ void set_feature_subsampling_fraction (
+ double frac
+ );
+ /*!
+ requires
+ - 0 < frac <= 1
+ ensures
+ - #get_feature_subsampling_frac() == frac
+ !*/
+
+ double get_feature_subsampling_frac(
+ ) const;
+ /*!
+ ensures
+ - When we build trees, at each node we don't look at all the available
+ features. We consider only get_feature_subsampling_frac() fraction of
+ them, selected at random.
+ !*/
+
+ void set_min_samples_per_leaf (
+ size_t num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #get_min_samples_per_leaf() == num
+ !*/
+
+ size_t get_min_samples_per_leaf(
+ ) const;
+ /*!
+ ensures
+ - When building trees, each leaf node in a tree will contain at least
+ get_min_samples_per_leaf() samples. This means that the output votes of
+ each tree are averages of at least get_min_samples_per_leaf() y values.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that the
+ progress of training can be tracked..
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ random_forest_regression_function<feature_extractor> train (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y,
+ std::vector<double>& oob_values
+ ) const;
+ /*!
+ requires
+ - x.size() == y.size()
+ - x.size() > 0
+ - Running following code:
+ auto fe = get_feature_extractor()
+ fe.setup(x,y);
+ Must be valid and result in fe.max_num_feats() != 0
+ ensures
+ - This function fits a regression forest to the given training data. The
+ goal being to regress x to y in the mean squared sense. It therefore
+ fits regression trees and returns the resulting random_forest_regression_function
+ RF, which will have the following properties:
+ - RF.get_num_trees() == get_num_trees()
+ - for all valid i:
+ - RF(x[i]) should output a value close to y[i]
+ - RF.get_feature_extractor() will be a copy of this->get_feature_extractor()
+ that has been configured by a call the feature extractor's setup() routine.
+ To run the algorithm we need to use a feature extractor. We obtain a
+ valid feature extractor by making a copy of get_feature_extractor(), then
+ invoking setup(x,y) on it. This feature extractor is what is used to fit
+ the trees and is also the feature extractor stored in the returned random
+ forest.
+ - #oob_values.size() == y.size()
+ - for all valid i:
+ - #oob_values[i] == the "out of bag" prediction for y[i]. It is
+ calculated by computing the average output from trees not trained on
+ y[i]. This is similar to a leave-one-out cross-validation prediction
+ of y[i] and can be used to estimate the generalization error of the
+ regression forest.
+ - Training uses all the available CPU cores.
+ !*/
+
+ random_forest_regression_function<feature_extractor> train (
+ const std::vector<sample_type>& x,
+ const std::vector<double>& y
+ ) const;
+ /*!
+ requires
+ - x.size() == y.size()
+ - x.size() > 0
+ - Running following code:
+ auto fe = get_feature_extractor()
+ fe.setup(x,y);
+ Must be valid and result in fe.max_num_feats() != 0
+ ensures
+ - This function is identical to train(x,y,oob_values) except that the
+ oob_values are not calculated.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANdOM_FOREST_REGRESION_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/ref.h b/ml/dlib/dlib/ref.h
new file mode 100644
index 000000000..53a37fbf8
--- /dev/null
+++ b/ml/dlib/dlib/ref.h
@@ -0,0 +1,84 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REFERENCE_WRAPpER_H_
+#define DLIB_REFERENCE_WRAPpER_H_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template<
+ typename T
+ >
+ class reference_wrapper
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple object that just holds a reference to another object.
+ It is useful because it can serve as a kind of "copyable reference".
+ !*/
+
+ public:
+ typedef T type;
+
+ explicit reference_wrapper(T& o) : obj(&o) {}
+
+ operator T&() const { return *obj; }
+ T& get() const { return *obj; }
+
+ private:
+ T* obj;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ reference_wrapper<T> ref(
+ T& obj
+ ) { return reference_wrapper<T>(obj); }
+ /*!
+ ensures
+ - returns a reference_wrapper that contains a reference to obj.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ reference_wrapper<T> ref(
+ reference_wrapper<T> obj
+ ) { return obj; }
+ /*!
+ ensures
+ - returns the given reference_wrapper object without modification
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ reference_wrapper<const T> cref(
+ const T& obj
+ ) { return reference_wrapper<const T>(obj); }
+ /*!
+ ensures
+ - returns a reference_wrapper that contains a constant reference to obj.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ reference_wrapper<const T> cref(
+ reference_wrapper<T> obj
+ ) { return cref(obj.get()); }
+ /*!
+ ensures
+ - converts the given reference_wrapper into a reference_wrapper that contains a
+ constant reference.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REFERENCE_WRAPpER_H_
+
diff --git a/ml/dlib/dlib/reference_counter.h b/ml/dlib/dlib/reference_counter.h
new file mode 100644
index 000000000..3e4b1ee20
--- /dev/null
+++ b/ml/dlib/dlib/reference_counter.h
@@ -0,0 +1,31 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REFERENCE_COUNTEr_
+#define DLIB_REFERENCE_COUNTEr_
+
+#include "reference_counter/reference_counter_kernel_1.h"
+#include "algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename copy = copy_functor<T>
+ >
+ class reference_counter
+ {
+ reference_counter() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef reference_counter_kernel_1<T,copy>
+ kernel_1a;
+
+ };
+}
+
+#endif // DLIB_REFERENCE_COUNTEr_
+
diff --git a/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h b/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h
new file mode 100644
index 000000000..64e465550
--- /dev/null
+++ b/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h
@@ -0,0 +1,298 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REFERENCE_COUNTER_KERNEl_1_
+#define DLIB_REFERENCE_COUNTER_KERNEl_1_
+
+#include "reference_counter_kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename copy = copy_functor<T>
+ >
+ class reference_counter_kernel_1
+ {
+
+ /*!
+ INITIAL VALUE
+ *data = item of type T with its initial value
+ *count = 1
+
+ CONVENTION
+ *data = pointer to item of type T
+ *count = number of references to *data
+
+ if clear() threw an exception then count = 0 and data is not a
+ valid pointer
+ !*/
+
+ public:
+
+ typedef T type;
+
+
+ reference_counter_kernel_1 (
+ );
+
+ inline reference_counter_kernel_1 (
+ const reference_counter_kernel_1& item
+ );
+
+ virtual ~reference_counter_kernel_1 (
+ );
+
+ void clear (
+ );
+
+ T& modify (
+ );
+
+ inline const T& access (
+ ) const;
+
+ inline reference_counter_kernel_1& operator= (
+ const reference_counter_kernel_1& rhs
+ );
+
+ inline void swap (
+ reference_counter_kernel_1& item
+ );
+
+
+ private:
+
+ T* data;
+ unsigned long* count;
+ mutable copy copy_item;
+ };
+
+ template <
+ typename T,
+ typename copy
+ >
+ inline void swap (
+ reference_counter_kernel_1<T,copy>& a,
+ reference_counter_kernel_1<T,copy>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ reference_counter_kernel_1<T,copy>::
+ reference_counter_kernel_1 (
+ )
+ {
+ data = new T;
+ try { count = new unsigned long; }
+ catch (...) { delete data; throw; }
+
+ *count = 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ reference_counter_kernel_1<T,copy>::
+ reference_counter_kernel_1 (
+ const reference_counter_kernel_1<T,copy>& item
+ ) :
+ data(item.data),
+ count(item.count)
+ {
+ ++(*count);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ reference_counter_kernel_1<T,copy>::
+ ~reference_counter_kernel_1 (
+ )
+ {
+ if (*count > 1)
+ {
+ // if there are other references to this data
+ --(*count);
+ }
+ else
+ {
+ // if there are no other references to this data
+ delete count;
+ delete data;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ void reference_counter_kernel_1<T,copy>::
+ clear (
+ )
+ {
+ // if an exception was thrown last time clear() was called then do this
+ if (count == 0)
+ {
+ data = new T;
+ try { count = new unsigned long; }
+ catch (...) { delete data; throw; }
+
+ *count = 1;
+ }
+ // if there are other references to the data then do this
+ else if (*count > 1)
+ {
+ --(*count);
+
+ try { data = new T; }
+ catch (...) { count = 0; throw; }
+
+ try { count = new unsigned long; }
+ catch (...) { delete data; count = 0; throw; }
+
+ *count = 1;
+
+ }
+ else
+ {
+ // if there are no other references to this data
+ *count = 1;
+ delete data;
+ try { data = new T; } catch (...) { delete count; count = 0; throw; }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ T& reference_counter_kernel_1<T,copy>::
+ modify (
+ )
+ {
+ // if this is not the only reference then make a new copy
+ if ( *count > 1 )
+ {
+ T& old_data = *data;
+ unsigned long& old_count = *count;
+
+
+ // get memory for the new copy
+ try { data = new T; }
+ catch (...) { data = &old_data; throw; }
+
+ try { count = new unsigned long; }
+ catch (...) {delete data; data = &old_data; count = &old_count; throw;}
+
+ // decrement the number of references to old_data
+ --(old_count);
+
+ *count = 1;
+
+ // make a copy of the old data
+ try { copy_item(old_data,*data); }
+ catch (...)
+ { delete data; delete count; data = &old_data; count = &old_count; }
+
+ }
+
+ return *data;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ const T& reference_counter_kernel_1<T,copy>::
+ access (
+ ) const
+ {
+ return *data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ reference_counter_kernel_1<T,copy>& reference_counter_kernel_1<T,copy>::
+ operator= (
+ const reference_counter_kernel_1<T,copy>& rhs
+ )
+ {
+ if (this == &rhs)
+ return *this;
+
+ // delete the current data if this is the last reference to it
+ if (*count > 1)
+ {
+ // if there are other references to this data
+ --(*count);
+ }
+ else
+ {
+ // if there are no other references to this data
+ delete count;
+ delete data;
+ }
+
+ // copy the pointers
+ count = (rhs.count);
+ data = (rhs.data);
+ ++(*count);
+
+ return *this;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename copy
+ >
+ void reference_counter_kernel_1<T,copy>::
+ swap (
+ reference_counter_kernel_1<T,copy>& item
+ )
+ {
+ T* data_temp = data;
+ unsigned long* count_temp = count;
+
+ data = item.data;
+ count = item.count;
+
+ item.data = data_temp;
+ item.count = count_temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REFERENCE_COUNTER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h b/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h
new file mode 100644
index 000000000..e46127387
--- /dev/null
+++ b/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h
@@ -0,0 +1,141 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_
+#ifdef DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename copy = copy_functor<T>
+ >
+ class reference_counter
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor
+
+ REQUIREMENTS ON copy
+ it should be a function object that copies an object of type T. and
+ it must have a default constructor and
+ operator() should be overloaded as
+ void operator()(const T& source, T& destination);
+ copy may throw any exception
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and access() functions do not invalidate pointers or
+ references to internal data.
+ All other functions have no such guarantee
+
+
+ INITIAL VALUE
+ reference_counter contains one object of type T and
+ this object of type T has its initial value
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a container for an object of type T and
+ provides reference counting capabilities for the object it contains
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ !*/
+
+ public:
+
+ typedef T type;
+
+ reference_counter (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ reference_counter (
+ const reference_counter& item
+ );
+ /*!
+ ensures
+ - #access() == item.access()
+ !*/
+
+ virtual ~reference_counter (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ T& modify (
+ );
+ /*!
+ ensures
+ - returns a non-const reference to the item contained in *this
+ - the item is ok to modify. i.e. there are no other references to it
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ modify() may throw this exception if there are other references
+ to the item and there is not enough memory to copy it. If modify()
+ throws then it has no effect.
+ !*/
+
+ const T& access (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the item contained in *this
+ - there may be other references to to the item
+ !*/
+
+ reference_counter& operator= (
+ const reference_counter& rhs
+ );
+ /*!
+ ensures
+ - #access() == rhs.access()
+ !*/
+
+ void swap (
+ reference_counter& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ template <
+ typename T,
+ typename copy
+ >
+ inline void swap (
+ reference_counter<T,copy>& a,
+ reference_counter<T,copy>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/revision.h.in b/ml/dlib/dlib/revision.h.in
new file mode 100644
index 000000000..e6922c17c
--- /dev/null
+++ b/ml/dlib/dlib/revision.h.in
@@ -0,0 +1,6 @@
+#ifndef DLIB_REVISION_H
+#define DLIB_MAJOR_VERSION @CPACK_PACKAGE_VERSION_MAJOR@
+#define DLIB_MINOR_VERSION @CPACK_PACKAGE_VERSION_MINOR@
+#define DLIB_PATCH_VERSION @CPACK_PACKAGE_VERSION_PATCH@
+#endif
+
diff --git a/ml/dlib/dlib/sequence.h b/ml/dlib/dlib/sequence.h
new file mode 100644
index 000000000..1223c30bb
--- /dev/null
+++ b/ml/dlib/dlib/sequence.h
@@ -0,0 +1,83 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCe_
+#define DLIB_SEQUENCe_
+
+#include "sequence/sequence_kernel_1.h"
+#include "sequence/sequence_kernel_2.h"
+#include "sequence/sequence_kernel_c.h"
+
+#include "sequence/sequence_compare_1.h"
+#include "sequence/sequence_sort_1.h"
+#include "sequence/sequence_sort_2.h"
+#include "algs.h"
+
+
+
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class sequence
+ {
+
+ sequence() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef sequence_kernel_1<T,mem_manager>
+ kernel_1a;
+ typedef sequence_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_2a
+ typedef sequence_kernel_2<T,mem_manager>
+ kernel_2a;
+ typedef sequence_kernel_c<kernel_2a>
+ kernel_2a_c;
+
+
+ //---------- extensions ------------
+
+ // compare_1 extend kernel_1a
+ typedef sequence_compare_1<kernel_1a >
+ compare_1a;
+ typedef sequence_compare_1<kernel_1a_c>
+ compare_1a_c;
+
+ // compare_1 extend kernel_2a
+ typedef sequence_compare_1<kernel_2a >
+ compare_1b;
+ typedef sequence_compare_1<kernel_2a_c>
+ compare_1b_c;
+
+
+
+ // sort_1 extend kernel_2a
+ typedef sequence_sort_1<kernel_2a>
+ sort_1a;
+ typedef sequence_sort_1<kernel_2a_c>
+ sort_1a_c;
+
+ // sort_2 extend kernel_1a
+ typedef sequence_sort_2<kernel_1a>
+ sort_2a;
+ typedef sequence_sort_2<kernel_1a_c>
+ sort_2a_c;
+
+
+
+
+
+
+ };
+}
+
+#endif // DLIB_SEQUENCe_
+
diff --git a/ml/dlib/dlib/sequence/sequence_compare_1.h b/ml/dlib/dlib/sequence/sequence_compare_1.h
new file mode 100644
index 000000000..9bc3c773d
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_compare_1.h
@@ -0,0 +1,102 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_COMPARe_1_
+#define DLIB_SEQUENCE_COMPARe_1_
+
+#include "sequence_compare_abstract.h"
+
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_compare_1 : public seq_base
+ {
+ typedef typename seq_base::type T;
+
+ public:
+
+ bool operator< (
+ const sequence_compare_1& rhs
+ ) const;
+
+ bool operator== (
+ const sequence_compare_1& rhs
+ ) const;
+
+ };
+
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_compare_1<seq_base>& a,
+ sequence_compare_1<seq_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ bool sequence_compare_1<seq_base>::
+ operator< (
+ const sequence_compare_1<seq_base>& rhs
+ ) const
+ {
+ unsigned int length;
+ if (this->size() < rhs.size())
+ length = this->size();
+ else
+ length = rhs.size();
+
+ for (unsigned long i = 0; i < length; ++i)
+ {
+ if ((*this)[i] < rhs[i])
+ return true;
+ else if ( !((*this)[i] == rhs[i]) )
+ return false;
+ }
+ // they are equal so far
+ if (this->size() < rhs.size())
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ bool sequence_compare_1<seq_base>::
+ operator== (
+ const sequence_compare_1<seq_base>& rhs
+ ) const
+ {
+ if (this->size() != rhs.size())
+ return false;
+
+ for (unsigned long i = 0; i < this->size(); ++i)
+ {
+ if (!((*this)[i] == rhs[i]))
+ return false;
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_COMPARe_1_
+
diff --git a/ml/dlib/dlib/sequence/sequence_compare_abstract.h b/ml/dlib/dlib/sequence/sequence_compare_abstract.h
new file mode 100644
index 000000000..261703f45
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_compare_abstract.h
@@ -0,0 +1,75 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEQUENCE_COMPARe_ABSTRACT_
+#ifdef DLIB_SEQUENCE_COMPARe_ABSTRACT_
+
+#include "sequence_kernel_abstract.h"
+
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_compare : public seq_base
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must implement operator< for its type and
+ T must implement operator== for its type
+
+ REQUIREMENTS ON SEQUENCE_BASE
+ must be an implementation of sequence/sequence_kernel_abstract.h
+
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ operator== and operator< do not invalidate pointers or references to
+ data members
+
+ WHAT THIS EXTENSION DOES FOR sequence
+ This gives a sequence the ability to compare itself to other
+ sequences using the < and == operators.
+ !*/
+
+ public:
+
+ bool operator< (
+ const sequence_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if there exists an integer j such that 0 <= j < size()
+ and for all integers i such that 0 <= i < j where it is true that
+ (*this)[i] <= rhs[i] and (*this)[j] < rhs[j]
+ - returns false if there is no j that will satisfy the above conditions
+ !*/
+
+ bool operator== (
+ const sequence_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - returns true if for all i: (*this)[i] == rhs[i] else returns false
+ !*/
+
+ };
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_compare<seq_base>& a,
+ sequence_compare<seq_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_SEQUENCE_COMPARe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/sequence/sequence_kernel_1.h b/ml/dlib/dlib/sequence/sequence_kernel_1.h
new file mode 100644
index 000000000..9e1e26f1b
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_kernel_1.h
@@ -0,0 +1,1340 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_KERNEl_1_
+#define DLIB_SEQUENCE_KERNEl_1_
+
+#include "sequence_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class sequence_kernel_1 : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ INITIAL VALUE
+ - tree_root == 0
+ - tree_size == 0
+ - at_start_ == true
+ - current_element == 0
+ - stack == array of 50 node pointers
+ - stack_pos == 0
+
+ CONVENTION
+
+ - if (tree_size > 0)
+ - tree_root == pointer to the root node of the binary search tree
+ - else
+ - tree_root == 0
+
+
+
+ - stack[stack_pos-1] == pop()
+
+ - current_element_valid() == (current_element != 0)
+
+ - at_start_ == at_start()
+ - if (current_element != 0 && current_element != tree_root) then
+ - stack[stack_pos-1] == the parent of the node pointed to by current_element
+
+ - if (current_element_valid()) then
+ - element() == current_element->item
+
+
+
+ - tree_size == size()
+ - (*this)[i] == return_reference(i)
+
+
+ - for all nodes:
+ - left_size == the number of elements in the left subtree.
+ - left points to the left subtree or 0 if there is no left subtree.
+ - right points to the right subtree or 0 if there is no right subtree.
+
+ - all elements in a left subtree have a position in the sequence < that
+ of the root of the current tree.
+
+ - all elements in a right subtree have a position in the
+ sequence > that of the root of the current tree.
+
+ - item is the sequence element for that node.
+ - balance:
+ - balance == 0 if both subtrees have the same height
+ - balance == -1 if the left subtree has a height that is
+ greater than the height of the right subtree by 1
+ - balance == 1 if the right subtree has a height that is
+ greater than the height of the left subtree by 1
+ - for all subtrees:
+ - the height of the left and right subtrees differ by at most one
+
+ !*/
+
+
+ class node
+ {
+ public:
+ node* left;
+ node* right;
+ unsigned long left_size;
+ T item;
+ signed char balance;
+ };
+
+
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ sequence_kernel_1 (
+ ) :
+ tree_root(0),
+ tree_size(0),
+ stack(ppool.allocate_array(50)),
+ current_element(0),
+ at_start_(true),
+ stack_pos(0)
+ {}
+
+ virtual ~sequence_kernel_1 (
+ );
+
+ inline void clear (
+ );
+
+ void add (
+ unsigned long pos,
+ T& item
+ );
+
+ void remove (
+ unsigned long pos,
+ T& item
+ );
+
+ void cat (
+ sequence_kernel_1& item
+ );
+
+ const T& operator[] (
+ unsigned long pos
+ ) const;
+
+ T& operator[] (
+ unsigned long pos
+ );
+
+ inline void swap (
+ sequence_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const T& element (
+ ) const;
+
+ T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+
+ private:
+
+ void delete_nodes (
+ node* t
+ );
+ /*!
+ requires
+ - t == a pointer to a valid node
+ ensures
+ - deletes t and all its sub nodes.
+ !*/
+
+ inline void rotate_left (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == 2
+ - t->right->balance == 0 or 1
+ ensures
+ - #t is still a binary search tree
+ - #t->balance is between 1 and -1
+ - #t now has a height smaller by 1 if #t->balance == 0
+ !*/
+
+ inline void rotate_right (
+ node*& t
+ );
+ /*!
+ requires
+ - t->balance == -2
+ - t->left->balance == 0 or -1
+ ensures
+ - #t is still a binary search tree
+ - #t->balance is between 1 and -1
+ - #t now has a height smaller by 1 if #t->balance == 0
+
+ !*/
+
+ inline void double_rotate_right (
+ node*& t
+ );
+ /*!
+ requires
+ - #t->balance == -2
+ - #t->left->balance == 1
+ ensures
+ - #t is still a binary search tree
+ - #t now has a balance of 0
+ - #t now has a height smaller by 1
+ !*/
+
+ inline void double_rotate_left (
+ node*& t
+ );
+ /*!
+ requires
+ - #t->balance == 2
+ - #t->right->balance == -1
+ ensures
+ - #t is still a binary search tree
+ - #t now has a balance of 0 and
+ - #t now has a height smaller by 1
+ !*/
+
+ bool remove_least_element_in_tree (
+ node*& t,
+ T& item
+ );
+ /*!
+ requires
+ - t != 0 (i.e. there must be something in the tree to remove)
+ ensures
+ - the least node in t has been removed
+ - the least node element in t has been put into #item
+ - #t is still a binary search tree
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ bool add_to_tree (
+ node*& t,
+ unsigned long pos,
+ T& item
+ );
+ /*!
+ requires
+ - pos <= the number of items in the tree
+ ensures
+ - item has been added to #t
+ - #return_reference(pos) == item
+ - the convention is still satisfied
+ - #item has an initial value for its type
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has grown by one
+ !*/
+
+ bool remove_from_tree (
+ node*& t,
+ unsigned long pos,
+ T& item
+ );
+ /*!
+ requires
+ - there is an item in the tree associated with pos
+ ensures
+ - the element in the tree associated with pos has been removed
+ and put into #item
+ - the convention is still satisfied
+ - returns false if the height of the tree has not changed
+ - returns true if the height of the tree has shrunk by one
+ !*/
+
+ const T& return_reference (
+ const node* t,
+ unsigned long pos
+ ) const;
+ /*!
+ requires
+ - there is an item in the tree associated with pos
+ ensures
+ - returns a const reference to the item in the tree associated with pos
+ !*/
+
+ T& return_reference (
+ node* t,
+ unsigned long pos
+ );
+ /*!
+ requires
+ - there is an item in the tree associated with pos
+ ensures
+ - returns a non-const reference to the item in the tree associated
+ with pos
+ !*/
+
+ inline bool keep_node_balanced (
+ node*& t
+ );
+ /*!
+ requires
+ - t != 0
+ ensures
+ - if (t->balance is < 1 or > 1) then
+ - keep_node_balanced() will ensure that t->balance == 0, -1, or 1
+ - returns true if it made the tree one height shorter
+ - returns false if it didn't change the height
+ !*/
+
+ void push (
+ node* n
+ ) const { stack[stack_pos] = n; ++stack_pos; }
+ /*!
+ ensures
+ - pushes n onto the stack
+ !*/
+
+
+ node* pop (
+ ) const { --stack_pos; return stack[stack_pos]; }
+ /*!
+ ensures
+ - pops the top of the stack and returns it
+ !*/
+
+ // data members
+ typename mem_manager::template rebind<node>::other pool;
+ typename mem_manager::template rebind<node*>::other ppool;
+
+ node* tree_root;
+ unsigned long tree_size;
+
+ mutable node** stack;
+ mutable node* current_element;
+ mutable bool at_start_;
+ mutable unsigned char stack_pos;
+
+ // restricted functions
+ sequence_kernel_1(sequence_kernel_1&); // copy constructor
+ sequence_kernel_1& operator=(sequence_kernel_1&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ sequence_kernel_1<T,mem_manager>& a,
+ sequence_kernel_1<T,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ sequence_kernel_1<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.add(i,temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type sequence_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ sequence_kernel_1<T,mem_manager>::
+ ~sequence_kernel_1 (
+ )
+ {
+ ppool.deallocate_array(stack);
+ if (tree_size > 0)
+ {
+ delete_nodes(tree_root);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ swap (
+ sequence_kernel_1<T,mem_manager>& item
+ )
+ {
+ exchange(stack,item.stack);
+ exchange(stack_pos,item.stack_pos);
+
+ pool.swap(item.pool);
+ ppool.swap(item.ppool);
+
+ node* tree_root_temp = item.tree_root;
+ unsigned long tree_size_temp = item.tree_size;
+ node* current_element_temp = item.current_element;
+ bool at_start_temp = item.at_start_;
+
+ item.tree_root = tree_root;
+ item.tree_size = tree_size;
+ item.current_element = current_element;
+ item.at_start_ = at_start_;
+
+ tree_root = tree_root_temp;
+ tree_size = tree_size_temp;
+ current_element = current_element_temp;
+ at_start_ = at_start_temp;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t sequence_kernel_1<T,mem_manager>::
+ size (
+ ) const
+ {
+ return tree_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& sequence_kernel_1<T,mem_manager>::
+ operator[] (
+ unsigned long pos
+ ) const
+ {
+ return return_reference(tree_root,pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& sequence_kernel_1<T,mem_manager>::
+ operator[] (
+ unsigned long pos
+ )
+ {
+ return return_reference(tree_root,pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ add (
+ unsigned long pos,
+ T& item
+ )
+ {
+ add_to_tree(tree_root,pos,item);
+ ++tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ remove (
+ unsigned long pos,
+ T& item
+ )
+ {
+ remove_from_tree(tree_root,pos,item);
+ --tree_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ cat (
+ sequence_kernel_1<T,mem_manager>& item
+ )
+ {
+ for (unsigned long i = 0; i < item.tree_size; ++i)
+ {
+ add_to_tree(
+ tree_root,
+ tree_size,
+ return_reference(item.tree_root,i)
+ );
+
+ ++tree_size;
+ }
+
+ item.clear();
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ clear (
+ )
+ {
+ if (tree_size > 0)
+ {
+ delete_nodes(tree_root);
+ tree_root = 0;
+ tree_size = 0;
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_element = 0;
+ stack_pos = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return (current_element != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& sequence_kernel_1<T,mem_manager>::
+ element (
+ ) const
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& sequence_kernel_1<T,mem_manager>::
+ element (
+ )
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ move_next (
+ ) const
+ {
+ // if we haven't started iterating yet
+ if (at_start_)
+ {
+ at_start_ = false;
+ if (tree_size == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // find the first element in the tree
+ current_element = tree_root;
+ node* temp = current_element->left;
+ while (temp != 0)
+ {
+ push(current_element);
+ current_element = temp;
+ temp = current_element->left;
+ }
+ return true;
+ }
+ }
+ else
+ {
+ if (current_element == 0)
+ {
+ return false;
+ }
+ else
+ {
+ node* temp;
+ bool went_up; // true if we went up the tree from a child node to parent
+ bool from_left = false; // true if we went up and were coming from a left child node
+ // find the next element in the tree
+ if (current_element->right != 0)
+ {
+ // go right and down
+ temp = current_element;
+ push(current_element);
+ current_element = temp->right;
+ went_up = false;
+ }
+ else
+ {
+ // go up to the parent if we can
+ if (current_element == tree_root)
+ {
+ // in this case we have iterated over all the element of the tree
+ current_element = 0;
+ return false;
+ }
+ went_up = true;
+ node* parent = pop();
+
+
+ from_left = (parent->left == current_element);
+ // go up to parent
+ current_element = parent;
+ }
+
+
+ while (true)
+ {
+ if (went_up)
+ {
+ if (from_left)
+ {
+ // in this case we have found the next node
+ break;
+ }
+ else
+ {
+ if (current_element == tree_root)
+ {
+ // in this case we have iterated over all the elements
+ // in the tree
+ current_element = 0;
+ return false;
+ }
+ // we should go up
+ node* parent = pop();
+ from_left = (parent->left == current_element);
+ current_element = parent;
+ }
+ }
+ else
+ {
+ // we just went down to a child node
+ if (current_element->left != 0)
+ {
+ // go left
+ went_up = false;
+ temp = current_element;
+ push(current_element);
+ current_element = temp->left;
+ }
+ else
+ {
+ // if there is no left child then we have found the next node
+ break;
+ }
+ }
+ }
+
+ return true;
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // remover function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ remove(0,item);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ rotate_left (
+ node*& t
+ )
+ {
+
+ // set the new balance numbers
+ if (t->right->balance == 1)
+ {
+ t->balance = 0;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->balance = 1;
+ t->right->balance = -1;
+ }
+
+ // perform the rotation
+ node* temp = t->right;
+ t->right = temp->left;
+ temp->left = t;
+ t = temp;
+
+
+ // set left_size to its correct value
+ t->left_size += t->left->left_size + 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ rotate_right (
+ node*& t
+ )
+ {
+ // set the new balance numbers
+ if (t->left->balance == -1)
+ {
+ t->balance = 0;
+ t->left->balance = 0;
+ }
+ else
+ {
+ t->balance = -1;
+ t->left->balance = 1;
+ }
+
+ // preform the rotation
+ node* temp = t->left;
+ t->left = temp->right;
+ temp->right = t;
+ t = temp;
+
+
+ // set left_size to its correct value
+ t->right->left_size -= t->left_size + 1;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ double_rotate_right (
+ node*& t
+ )
+ {
+
+ node* temp = t;
+ t = t->left->right;
+
+ temp->left->right = t->left;
+ t->left = temp->left;
+
+ temp->left = t->right;
+ t->right = temp;
+
+ if (t->balance < 0)
+ {
+ t->left->balance = 0;
+ t->right->balance = 1;
+ }
+ else if (t->balance > 0)
+ {
+ t->left->balance = -1;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->left->balance = 0;
+ t->right->balance = 0;
+ }
+ t->balance = 0;
+
+
+ // set left_size to its correct value
+ t->left_size += t->left->left_size + 1;
+ t->right->left_size -= t->left_size + 1;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ double_rotate_left (
+ node*& t
+ )
+ {
+ node* temp = t;
+ t = t->right->left;
+
+ temp->right->left = t->right;
+ t->right = temp->right;
+
+ temp->right = t->left;
+ t->left = temp;
+
+ if (t->balance < 0)
+ {
+ t->left->balance = 0;
+ t->right->balance = 1;
+ }
+ else if (t->balance > 0)
+ {
+ t->left->balance = -1;
+ t->right->balance = 0;
+ }
+ else
+ {
+ t->left->balance = 0;
+ t->right->balance = 0;
+ }
+
+ t->balance = 0;
+
+ // set left_size to its correct value
+ t->right->left_size -= t->left_size + 1;
+ t->left_size += t->left->left_size + 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ remove_least_element_in_tree (
+ node*& t,
+ T& item
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if the left tree is an empty tree
+ if ( tree.left == 0)
+ {
+ // swap nodes element into item
+ exchange(tree.item,item);
+
+ // plug hole left by removing this node
+ t = tree.right;
+
+ // delete the node that was just removed
+ tree.right = 0;
+ delete_nodes(&tree);
+
+ // return that the height of this part of the tree has decreased
+ return true;
+ }
+ else
+ {
+ // subtract one from the left size
+ --tree.left_size;
+
+ // keep going left
+
+ // if remove made the tree one height shorter
+ if ( remove_least_element_in_tree(tree.left,item) )
+ {
+ // if this caused the current tree to strink then report that
+ if ( tree.balance == -1)
+ {
+ ++tree.balance;
+ return true;
+ }
+ else
+ {
+ ++tree.balance;
+ return keep_node_balanced(t);
+ }
+ }
+
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ add_to_tree (
+ node*& t,
+ unsigned long pos,
+ T& item
+ )
+ {
+ // if found place to add
+ if (t == 0)
+ {
+ // create a node to add new item into
+ t = pool.allocate();
+
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+
+ // set left and right pointers to 0 to indicate that there are no
+ // left or right subtrees
+ tree.left = 0;
+ tree.right = 0;
+ tree.balance = 0;
+ tree.left_size = 0;
+
+ // put item into t
+ exchange(item,tree.item);
+
+ // indicate that the height of this tree has increased
+ return true;
+ }
+ else // keep looking for a place to add the new item
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+ signed char old_balance = tree.balance;
+
+ // add the new item to whatever subtree it should go into
+ if ( pos < tree.left_size + 1 )
+ {
+ tree.balance -= add_to_tree(tree.left,pos,item);
+ ++tree.left_size;
+ }
+ else
+ tree.balance += add_to_tree(tree.right,pos - tree.left_size - 1,item);
+
+
+ // if the tree was balanced to start with
+ if (old_balance == 0)
+ {
+ // if its not balanced anymore then it grew in height
+ if (tree.balance != 0)
+ return true;
+ else
+ return false;
+ }
+ else
+ {
+ // if the tree is now balanced then it didn't grow
+ if (tree.balance == 0)
+ {
+ return false;
+ }
+ else
+ {
+ // if the tree needs to be balanced
+ if (tree.balance != old_balance)
+ {
+ return !keep_node_balanced(t);
+ }
+ // if there has been no change in the heights
+ else
+ {
+ return false;
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ remove_from_tree (
+ node*& t,
+ unsigned long pos,
+ T& item
+ )
+ {
+
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if item is on the left
+ if (pos < tree.left_size)
+ {
+ // adjust the left size
+ --tree.left_size;
+
+ // if the left side of the tree has the greatest height
+ if (tree.balance == -1)
+ {
+ tree.balance += remove_from_tree(tree.left,pos,item);
+ return !tree.balance;
+ }
+ else
+ {
+ tree.balance += remove_from_tree(tree.left,pos,item);
+ return keep_node_balanced(t);
+ }
+
+ }
+ // if item is found
+ else if (pos == tree.left_size)
+ {
+ // if there is no left node
+ if (tree.left == 0)
+ {
+ // swap nodes element into item
+ exchange(tree.item,item);
+
+ // plug hole left by removing this node and free memory
+ t = tree.right; // plug hole with right subtree
+
+ // delete old node
+ tree.right = 0;
+ delete_nodes(&tree);
+
+ // indicate that the height has changed
+ return true;
+ }
+ // if there is no right node
+ else if (tree.right == 0)
+ {
+ // swap nodes element into item
+ exchange(tree.item,item);
+
+ // plug hole left by removing this node and free memory
+ t = tree.left; // plug hole with left subtree
+
+ // delete old node
+ tree.left = 0;
+ delete_nodes(&tree);
+
+ // indicate that the height of this tree has changed
+ return true;
+ }
+ // if there are both a left and right sub node
+ else
+ {
+ // get an element that can replace the one being removed and do this
+ // if it made the right subtree shrink by one
+ if (remove_least_element_in_tree(tree.right,item))
+ {
+ // adjust the tree height
+ --tree.balance;
+
+ // put the element into item copy and also plug the
+ // hole with the smallest element from the right.
+ exchange(item,tree.item);
+
+ // if the height of the current tree has dropped by one
+ if (tree.balance == 0)
+ {
+ return true;
+ }
+ else
+ {
+ return keep_node_balanced(t);
+ }
+ }
+ // else this remove did not effect the height of this tree
+ else
+ {
+ // put the element into item copy and also plug the
+ // hole with the smallest element from the right.
+ exchange(item,tree.item);
+
+ return false;
+ }
+
+ }
+ }
+ // if item is on the right
+ else
+ {
+
+ // if the right side of the tree has the greatest height
+ if (tree.balance == 1)
+ {
+ tree.balance -= remove_from_tree(tree.right,pos - tree.left_size - 1,item);
+ return !tree.balance;
+ }
+ else
+ {
+ tree.balance -= remove_from_tree(tree.right,pos - tree.left_size - 1,item);
+ return keep_node_balanced(t);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& sequence_kernel_1<T,mem_manager>::
+ return_reference (
+ node* t,
+ unsigned long pos
+ )
+ {
+ while (true)
+ {
+ // if we have found the node
+ if (pos == t->left_size)
+ return t->item;
+
+ if (pos < t->left_size)
+ {
+ // go left
+ t = t->left;
+ }
+ else
+ {
+ // go right
+ pos -= t->left_size+1;
+ t = t->right;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& sequence_kernel_1<T,mem_manager>::
+ return_reference (
+ const node* t,
+ unsigned long pos
+ ) const
+ {
+ while (true)
+ {
+ // if we have found the node
+ if (pos == t->left_size)
+ return t->item;
+
+ if (pos < t->left_size)
+ {
+ // go left
+ t = t->left;
+ }
+ else
+ {
+ // go right
+ pos -= t->left_size+1;
+ t = t->right;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_1<T,mem_manager>::
+ keep_node_balanced (
+ node*& t
+ )
+ {
+ // make a reference to the current node so we don't have to dereference
+ // a pointer a bunch of times
+ node& tree = *t;
+
+ // if tree does not need to be balanced then return false
+ if (tree.balance == 0)
+ return false;
+
+
+ // if tree needs to be rotated left
+ if (tree.balance == 2)
+ {
+ if (tree.right->balance >= 0)
+ rotate_left(t);
+ else
+ double_rotate_left(t);
+ }
+ // else if the tree needs to be rotated right
+ else if (tree.balance == -2)
+ {
+ if (tree.left->balance <= 0)
+ rotate_right(t);
+ else
+ double_rotate_right(t);
+ }
+
+
+ if (t->balance == 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_1<T,mem_manager>::
+ delete_nodes (
+ node* t
+ )
+ {
+ if (t->left)
+ delete_nodes(t->left);
+ if (t->right)
+ delete_nodes(t->right);
+ pool.deallocate(t);
+ }
+
+// ----------------------------------------------------------------------------------------
+}
+
+#endif // DLIB_SEQUENCE_KERNEl_1_
+
diff --git a/ml/dlib/dlib/sequence/sequence_kernel_2.h b/ml/dlib/dlib/sequence/sequence_kernel_2.h
new file mode 100644
index 000000000..f15c50e93
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_kernel_2.h
@@ -0,0 +1,682 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_KERNEl_2_
+#define DLIB_SEQUENCE_KERNEl_2_
+
+#include "sequence_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class sequence_kernel_2 : public enumerable<T>,
+ public remover<T>
+ {
+ /*!
+ INITIAL VALUE
+ sequence_size == 0
+ at_start_ == true
+ current_enumeration_node == 0
+
+ CONVENTION
+ sequence_size == the number of elements in the sequence
+
+ at_start_ == at_start()
+ (current_enumeration_node!=0) == current_element_valid()
+ if (current_enumeration_node!=0) then
+ current_enumeration_node->item == element()
+ current_enumeration_pos == the position of the node pointed to by
+ current_enumeration_node
+
+ if ( sequence_size > 0 )
+ {
+ current_node == pointer to a node in the linked list and
+ current_node->right->right->... eventually == current_node and
+ current_node->left->left->... eventually == current_node and
+ current_pos == the position in the sequence of
+ current_node->item
+ }
+
+ !*/
+
+ struct node {
+ T item;
+ node* right;
+ node* left;
+ };
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ sequence_kernel_2 (
+ ) :
+ sequence_size(0),
+ at_start_(true),
+ current_enumeration_node(0)
+ {}
+
+ virtual ~sequence_kernel_2 (
+ );
+
+ inline void clear (
+ );
+
+ void add (
+ unsigned long pos,
+ T& item
+ );
+
+ void remove (
+ unsigned long pos,
+ T& item
+ );
+
+ void cat (
+ sequence_kernel_2& item
+ );
+
+ const T& operator[] (
+ unsigned long pos
+ ) const;
+
+ T& operator[] (
+ unsigned long pos
+ );
+
+ void swap (
+ sequence_kernel_2& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ const T& element (
+ ) const;
+
+ T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ void delete_nodes (
+ node* current_node,
+ unsigned long sequence_size
+ );
+ /*!
+ requires
+ CONVENTION IS CORRECT
+ ensures
+ all memory associated with the ring of nodes has been freed
+ !*/
+
+ void move_to_pos (
+ node*& current_node,
+ unsigned long& current_pos,
+ unsigned long pos,
+ unsigned long size
+ ) const;
+ /*!
+ requires
+ everything in the CONVENTION is correct and
+ there is a node corresponding to pos in the CONVENTION and
+ 0 <= pos < size
+ ensures
+ current_pos == pos and
+ current_node->item is the item in the sequence associated with
+ position pos
+ !*/
+
+ // data members
+ unsigned long sequence_size;
+ mutable node* current_node;
+ mutable unsigned long current_pos;
+ mutable bool at_start_;
+ mutable node* current_enumeration_node;
+ mutable unsigned long current_enumeration_pos;
+
+ // restricted functions
+ sequence_kernel_2(sequence_kernel_2&); // copy constructor
+ sequence_kernel_2& operator=(sequence_kernel_2&); // assignment operator
+
+ };
+
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ sequence_kernel_2<T,mem_manager>& a,
+ sequence_kernel_2<T,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ sequence_kernel_2<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.add(i,temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type sequence_kernel_2");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ sequence_kernel_2<T,mem_manager>::
+ ~sequence_kernel_2 (
+ )
+ {
+ delete_nodes(current_node,sequence_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ clear (
+ )
+ {
+ if (sequence_size != 0)
+ {
+ delete_nodes(current_node,sequence_size);
+ sequence_size = 0;
+ }
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ add (
+ unsigned long pos,
+ T& item
+ )
+ {
+ // make new node and swap item into it
+ node* new_node = new node;
+ exchange(item,new_node->item);
+
+ if (sequence_size > 0)
+ {
+ if (pos == sequence_size)
+ {
+ move_to_pos(current_node,current_pos,pos-1,sequence_size);
+
+ node& n_node = *new_node;
+ node& c_node = *current_node;
+
+ // make new node point to the nodes to its left and right
+ n_node.right = c_node.right;
+ n_node.left = current_node;
+
+ // make the left node point back to new_node
+ c_node.right->left = new_node;
+
+ // make the right node point back to new_node
+ c_node.right = new_node;
+ current_pos = pos;
+
+ }
+ else
+ {
+ move_to_pos(current_node,current_pos,pos,sequence_size);
+
+ node& n_node = *new_node;
+ node& c_node = *current_node;
+
+ // make new node point to the nodes to its left and right
+ n_node.right = current_node;
+ n_node.left = c_node.left;
+
+ // make the left node point back to new_node
+ c_node.left->right = new_node;
+
+ // make the right node point back to new_node
+ c_node.left = new_node;
+ }
+
+ }
+ else
+ {
+ current_pos = 0;
+ new_node->left = new_node;
+ new_node->right = new_node;
+ }
+
+ // make the new node the current node
+ current_node = new_node;
+
+ ++sequence_size;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ remove (
+ unsigned long pos,
+ T& item
+ )
+ {
+ move_to_pos(current_node,current_pos,pos,sequence_size);
+ node& c_node = *current_node;
+ exchange(c_node.item,item);
+
+ node* temp = current_node;
+
+ // close up gap left by remove
+ c_node.left->right = c_node.right;
+ c_node.right->left = c_node.left;
+
+ current_node = c_node.right;
+
+ --sequence_size;
+
+ delete temp;
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& sequence_kernel_2<T,mem_manager>::
+ operator[] (
+ unsigned long pos
+ ) const
+ {
+ move_to_pos(current_node,current_pos,pos,sequence_size);
+ return current_node->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ cat (
+ sequence_kernel_2<T,mem_manager>& item
+ )
+ {
+ if (item.sequence_size > 0)
+ {
+ if (sequence_size > 0)
+ {
+ // move both sequences to a convenient location
+ move_to_pos(current_node,current_pos,0,sequence_size);
+ item.move_to_pos (
+ item.current_node,
+ item.current_pos,
+ item.sequence_size-1,
+ item.sequence_size
+ );
+
+ // make copies of poitners
+ node& item_right = *item.current_node->right;
+ node& left = *current_node->left;
+
+
+ item.current_node->right = current_node;
+ current_node->left = item.current_node;
+
+ left.right = &item_right;
+ item_right.left = &left;
+
+ // set sizes
+ sequence_size += item.sequence_size;
+ item.sequence_size = 0;
+ }
+ else
+ {
+ // *this is empty so just swap
+ item.swap(*this);
+ }
+ }
+ item.clear();
+ // reset the enumerator
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& sequence_kernel_2<T,mem_manager>::
+ operator[] (
+ unsigned long pos
+ )
+ {
+ move_to_pos(current_node,current_pos,pos,sequence_size);
+ return current_node->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t sequence_kernel_2<T,mem_manager>::
+ size (
+ ) const
+ {
+ return sequence_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ swap (
+ sequence_kernel_2<T,mem_manager>& item
+ )
+ {
+ unsigned long sequence_size_temp = item.sequence_size;
+ node* current_node_temp = item.current_node;
+ unsigned long current_pos_temp = item.current_pos;
+ bool at_start_temp = item.at_start_;
+ node* current_enumeration_node_temp = item.current_enumeration_node;
+ unsigned long current_enumeration_pos_temp = item.current_enumeration_pos;
+
+ item.sequence_size = sequence_size;
+ item.current_node = current_node;
+ item.current_pos = current_pos;
+ item.at_start_ = at_start_;
+ item.current_enumeration_node = current_enumeration_node;
+ item.current_enumeration_pos = current_enumeration_pos;
+
+ sequence_size = sequence_size_temp;
+ current_node = current_node_temp;
+ current_pos = current_pos_temp;
+ at_start_ = at_start_temp;
+ current_enumeration_node = current_enumeration_node_temp;
+ current_enumeration_pos = current_enumeration_pos_temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_2<T,mem_manager>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ current_enumeration_node = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_2<T,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return (current_enumeration_node!=0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& sequence_kernel_2<T,mem_manager>::
+ element (
+ ) const
+ {
+ return current_enumeration_node->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& sequence_kernel_2<T,mem_manager>::
+ element (
+ )
+ {
+ return current_enumeration_node->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool sequence_kernel_2<T,mem_manager>::
+ move_next (
+ ) const
+ {
+ if (at_start_ && sequence_size>0)
+ {
+ move_to_pos(current_node,current_pos,0,sequence_size);
+ current_enumeration_node = current_node;
+ current_enumeration_pos = 0;
+ }
+ else if (current_enumeration_node!=0)
+ {
+ ++current_enumeration_pos;
+ if (current_enumeration_pos<sequence_size)
+ {
+ current_enumeration_node = current_enumeration_node->right;
+ }
+ else
+ {
+ // we have reached the end of the sequence
+ current_enumeration_node = 0;
+ }
+ }
+
+ at_start_ = false;
+ return (current_enumeration_node!=0);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // remover function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ remove(0,item);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ delete_nodes (
+ node* current_node,
+ unsigned long sequence_size
+ )
+ {
+ node* temp;
+ while (sequence_size)
+ {
+ temp = current_node->right;
+ delete current_node;
+ current_node = temp;
+ --sequence_size;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void sequence_kernel_2<T,mem_manager>::
+ move_to_pos (
+ node*& current_node,
+ unsigned long& current_pos,
+ unsigned long pos,
+ unsigned long size
+ ) const
+ {
+ if ( current_pos > pos)
+ {
+ // number of hops in each direction needed to reach pos
+ unsigned long right = size + pos - current_pos;
+ unsigned long left = current_pos - pos;
+ current_pos = pos;
+
+ if (left < right)
+ {
+ // move left to position pos
+ for (; left > 0; --left)
+ current_node = current_node->left;
+ }
+ else
+ {
+ // move left to position pos
+ for (; right > 0; --right)
+ current_node = current_node->right;
+ }
+ }
+ else if (current_pos != pos)
+ {
+ // number of hops in each direction needed to reach pos
+ unsigned long right = pos - current_pos;
+ unsigned long left = size - pos + current_pos;
+ current_pos = pos;
+
+ if (left < right)
+ {
+ // move left to position pos
+ for (; left > 0; --left)
+ current_node = current_node->left;
+ }
+ else
+ {
+ // move left to position pos
+ for (; right > 0; --right)
+ current_node = current_node->right;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_KERNEl_2_
+
diff --git a/ml/dlib/dlib/sequence/sequence_kernel_abstract.h b/ml/dlib/dlib/sequence/sequence_kernel_abstract.h
new file mode 100644
index 000000000..8a0bdc5b3
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_kernel_abstract.h
@@ -0,0 +1,199 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEQUENCE_KERNEl_ABSTRACT_
+#ifdef DLIB_SEQUENCE_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class sequence : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and operator[] functions do not invalidate pointers or
+ references to internal data.
+ All other functions have no such guarantees.
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the sequence from
+ the 0th element to the (size()-1)th element.
+
+ INITIAL VALUE
+ size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ sequence contains items of type T
+
+ This object represents an ordered sequence of items, each item is
+ associated with an integer value.
+ The items are numbered from 0 to size()-1
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ sequence (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ virtual ~sequence (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ unsigned long pos,
+ T& item
+ );
+ /*!
+ requires
+ - pos <= size()
+ ensures
+ - #size() == size() + 1
+ - #item has an initial value for its type
+ - #operator[](pos) == item
+ i.e. item has been inserted into *this between the elements which
+ were previously at position pos-1 and pos
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if add() throws then it has no effect
+ !*/
+
+ void remove (
+ unsigned long pos,
+ T& item
+ );
+ /*!
+ requires
+ - pos < size()
+ ensures
+ - #size() == size() - 1
+ - the element at the position pos in *this has been removed and
+ swapped into #item
+ - #at_start() == true
+ !*/
+
+ void cat (
+ sequence& item
+ );
+ /*!
+ requires
+ - &item != this (i.e. you can't concatenate a sequence onto itself)
+ ensures
+ - item has been concatenated onto the end of *this
+ i.e. item[0] becomes (#*this)[size()], item[1]
+ becomes (#*this)[size()+1], etc.
+ - #size() == size() + item.size()
+ - #item has its initial value
+ - #at_start() == true
+ !*/
+
+ const T& operator[] (
+ unsigned long pos
+ ) const;
+ /*!
+ requires
+ - pos < size()
+ ensures
+ - returns a const reference to the element at position pos
+ !*/
+
+ T& operator[] (
+ unsigned long pos
+ );
+ /*!
+ requires
+ - pos < size()
+ ensures
+ - returns a non-const reference to the element at position pos
+ !*/
+
+ void swap (
+ sequence& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ private:
+
+ // restricted functions
+ sequence(sequence&); // copy constructor
+ sequence& operator=(sequence&); // assignment operator
+
+ };
+
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ sequence<T,mem_manager>& a,
+ sequence<T,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ sequence<T,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_SEQUENCE_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/sequence/sequence_kernel_c.h b/ml/dlib/dlib/sequence/sequence_kernel_c.h
new file mode 100644
index 000000000..c565b54f0
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_kernel_c.h
@@ -0,0 +1,253 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_KERNEl_C_
+#define DLIB_SEQUENCE_KERNEl_C_
+
+#include "sequence_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_kernel_c : public seq_base
+ {
+ typedef typename seq_base::type T;
+ public:
+
+
+ void add (
+ unsigned long pos,
+ T& item
+ );
+
+ void remove (
+ unsigned long pos,
+ T& item
+ );
+
+ const T& operator[] (
+ unsigned long pos
+ ) const;
+
+ T& operator[] (
+ unsigned long pos
+ );
+
+ void cat (
+ sequence_kernel_c& item
+ );
+
+ const T& element (
+ ) const;
+
+ T& element (
+ );
+
+ void remove_any (
+ T& item
+ );
+
+ };
+
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_kernel_c<seq_base>& a,
+ sequence_kernel_c<seq_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_kernel_c<seq_base>::
+ add(
+ unsigned long pos,
+ T& item
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( pos <= this->size() ),
+ "\tvoid sequence::add"
+ << "\n\tpos must be >= 0 and <= size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ seq_base::add(pos,item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_kernel_c<seq_base>::
+ cat (
+ sequence_kernel_c<seq_base>& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(&item != this,
+ "\tvoid sequence::cat"
+ << "\n\tyou can't concatenate a sequence onto itself"
+ << "\n\t&item: " << &item
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ seq_base::cat(item);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_kernel_c<seq_base>::
+ remove (
+ unsigned long pos,
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( pos < this->size() ),
+ "\tvoid sequence::remove"
+ << "\n\tpos must be >= 0 and < size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ seq_base::remove(pos,item);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ const typename seq_base::type& sequence_kernel_c<seq_base>::
+ operator[] (
+ unsigned long pos
+ ) const
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( pos < this->size() ),
+ "\tconst T& sequence::operator[]"
+ << "\n\tpos must be >= 0 and < size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return seq_base::operator[](pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ typename seq_base::type& sequence_kernel_c<seq_base>::
+ operator[] (
+ unsigned long pos
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(( pos < this->size() ),
+ "\tT& sequence::operator[]"
+ << "\n\tpos must be >= 0 and < size()"
+ << "\n\tpos: " << pos
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return seq_base::operator[](pos);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ const typename seq_base::type& sequence_kernel_c<seq_base>::
+ element (
+ ) const
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& sequence::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return seq_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ typename seq_base::type& sequence_kernel_c<seq_base>::
+ element (
+ )
+ {
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tT& sequence::element()"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ return seq_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_kernel_c<seq_base>::
+ remove_any (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->size() > 0),
+ "\tvoid sequence::remove_any"
+ << "\n\tsize() must be greater than zero if something is going to be removed"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ seq_base::remove_any(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_KERNEl_C_
+
diff --git a/ml/dlib/dlib/sequence/sequence_sort_1.h b/ml/dlib/dlib/sequence/sequence_sort_1.h
new file mode 100644
index 000000000..2dd258e50
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_sort_1.h
@@ -0,0 +1,182 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_SORt_1_
+#define DLIB_SEQUENCE_SORt_1_
+
+#include "sequence_sort_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_sort_1 : public seq_base
+ {
+ typedef typename seq_base::type T;
+
+ public:
+
+ /*!
+ this is a median of three version of the QuickSort algorithm and
+ it sorts sequences of less than 30 elements with a selection sort
+ !*/
+
+ void sort (
+ );
+
+ private:
+
+ void sort_this_sequence (
+ seq_base& sequence
+ );
+ /*!
+ ensures
+ - each element in the sequence is < the element behind it
+ !*/
+
+ void selection_sort (
+ seq_base& sequence
+ );
+ /*!
+ ensures
+ - sequence is sorted with a selection_sort
+ !*/
+
+
+ };
+
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_sort_1<seq_base>& a,
+ sequence_sort_1<seq_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_sort_1<seq_base>::
+ sort (
+ )
+ {
+ if (this->size() > 1)
+ {
+ sort_this_sequence(*this);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_sort_1<seq_base>::
+ sort_this_sequence (
+ seq_base& sequence
+ )
+ {
+ if (sequence.size() < 30)
+ {
+ selection_sort(sequence);
+ }
+ else
+ {
+ seq_base left, right;
+ T partition_element;
+
+ sequence.remove(0,partition_element);
+
+ dlib::median (
+ partition_element,
+ sequence[sequence.size()-1],
+ sequence[(sequence.size()-1)/2]
+ );
+
+ // partition sequence into left and right
+ T temp;
+ while (sequence.size() > 0)
+ {
+ sequence.remove(0,temp);
+ if (temp < partition_element)
+ {
+ left.add(0,temp);
+ }
+ else
+ {
+ right.add(0,temp);
+ }
+ }
+
+ sort_this_sequence(left);
+ sort_this_sequence(right);
+
+ // combine left and right into sequence
+ left.swap(sequence);
+ sequence.add(sequence.size(),partition_element);
+ sequence.cat(right);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_sort_1<seq_base>::
+ selection_sort (
+ seq_base& sequence
+ )
+ {
+ if (sequence.size() > 2)
+ {
+ T temp[29];
+ unsigned long ssize = sequence.size();
+
+ for (unsigned long i = 0; i < ssize; ++i)
+ sequence.remove(0,temp[i]);
+
+ unsigned long smallest;
+ for (unsigned long i = 0; i < ssize - 1; ++i)
+ {
+ // find smallest element and swap into i
+ smallest = i;
+ for (unsigned long j = i+1; j < ssize; ++j)
+ {
+ if (temp[j] < temp[smallest])
+ smallest = j;
+ }
+ exchange(temp[smallest],temp[i]);
+ }
+
+ for (unsigned long i = 0; i < ssize; ++i)
+ sequence.add(i,temp[i]);
+ }
+ else if (sequence.size() == 2)
+ {
+ if (sequence[1] < sequence[0])
+ {
+ exchange(sequence[0],sequence[1]);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_SORt_1_
+
diff --git a/ml/dlib/dlib/sequence/sequence_sort_2.h b/ml/dlib/dlib/sequence/sequence_sort_2.h
new file mode 100644
index 000000000..558d4e0d5
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_sort_2.h
@@ -0,0 +1,65 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_SORt_2_
+#define DLIB_SEQUENCE_SORt_2_
+
+#include "sequence_sort_abstract.h"
+#include "../algs.h"
+#include "../sort.h"
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_sort_2 : public seq_base
+ {
+ typedef typename seq_base::type T;
+
+ public:
+
+ /*!
+ this is a version of the QuickSort algorithm
+ this uses the dlib::qsort_array function
+ !*/
+
+ void sort (
+ );
+
+
+ };
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_sort_2<seq_base>& a,
+ sequence_sort_2<seq_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename seq_base
+ >
+ void sequence_sort_2<seq_base>::
+ sort (
+ )
+ {
+ if (this->size() > 1)
+ {
+ dlib::qsort_array(*this,0,this->size()-1);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_SORt_2_
+
diff --git a/ml/dlib/dlib/sequence/sequence_sort_abstract.h b/ml/dlib/dlib/sequence/sequence_sort_abstract.h
new file mode 100644
index 000000000..fe0a91220
--- /dev/null
+++ b/ml/dlib/dlib/sequence/sequence_sort_abstract.h
@@ -0,0 +1,65 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEQUENCE_SORt_ABSTRACT_
+#ifdef DLIB_SEQUENCE_SORt_ABSTRACT_
+
+#include "sequence_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename seq_base
+ >
+ class sequence_sort : public seq_base
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must implement operator< for its type
+
+ REQUIREMENTS ON seq_base
+ must be an implementation of sequence/sequence_kernel_abstract.h
+
+
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ sort() may invalidate pointers and references to data members.
+
+ WHAT THIS EXTENSION DOES FOR sequence
+ this gives a sequence the ability to sort its contents by calling sort()
+ !*/
+
+
+ public:
+
+ void sort (
+ );
+ /*!
+ ensures
+ - for all elements in #*this the ith element is <= the i+1 element
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ data may be lost if sort() throws
+ !*/
+
+
+ };
+
+
+ template <
+ typename seq_base
+ >
+ inline void swap (
+ sequence_sort<seq_base>& a,
+ sequence_sort<seq_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_SEQUENCE_SORt_ABSTRACT_
+
diff --git a/ml/dlib/dlib/serialize.h b/ml/dlib/dlib/serialize.h
new file mode 100644
index 000000000..f21bdaaff
--- /dev/null
+++ b/ml/dlib/dlib/serialize.h
@@ -0,0 +1,1779 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERIALIZe_
+#define DLIB_SERIALIZe_
+
+/*!
+ There are two global functions in the dlib namespace that provide serialization and
+ deserialization support. Their signatures and specifications are as follows:
+
+ void serialize (
+ const serializable_type& item,
+ std::ostream& out
+ );
+ /!*
+ ensures
+ - writes the state of item to the output stream out
+ - if (serializable_type implements the enumerable interface) then
+ - item.at_start() == true
+ throws
+ - serialization_error
+ This exception is thrown if there is some problem which prevents
+ us from successfully writing item to the output stream.
+ - any other exception
+ *!/
+
+ void deserialize (
+ serializable_type& item,
+ std::istream& in
+ );
+ /!*
+ ensures
+ - #item == a deserialized copy of the serializable_type that was
+ in the input stream in.
+ - Reads all the bytes associated with the serialized serializable_type
+ contained inside the input stream and no more. This means you
+ can serialize multiple objects to an output stream and then read
+ them all back in, one after another, using deserialize().
+ - if (serializable_type implements the enumerable interface) then
+ - item.at_start() == true
+ throws
+ - serialization_error
+ This exception is thrown if there is some problem which prevents
+ us from successfully deserializing item from the input stream.
+ If this exception is thrown then item will have an initial value
+ for its type.
+ - any other exception
+ *!/
+
+ For convenience, you can also serialize to a file using this syntax:
+ serialize("your_file.dat") << some_object << another_object;
+
+ That overwrites the contents of your_file.dat with the serialized data from some_object
+ and another_object. Then to recall the objects from the file you can do:
+ deserialize("your_file.dat") >> some_object >> another_object;
+
+ Finally, you can chain as many objects together using the << and >> operators as you
+ like.
+
+
+ This file provides serialization support to the following object types:
+ - The C++ base types (NOT including pointer types)
+ - std::string
+ - std::wstring
+ - std::vector
+ - std::array
+ - std::deque
+ - std::map
+ - std::set
+ - std::pair
+ - std::complex
+ - dlib::uint64
+ - dlib::int64
+ - float_details
+ - enumerable<T> where T is a serializable type
+ - map_pair<D,R> where D and R are both serializable types.
+ - C style arrays of serializable types
+ - Google protocol buffer objects.
+
+ This file provides deserialization support to the following object types:
+ - The C++ base types (NOT including pointer types)
+ - std::string
+ - std::wstring
+ - std::vector
+ - std::array
+ - std::deque
+ - std::map
+ - std::set
+ - std::pair
+ - std::complex
+ - dlib::uint64
+ - dlib::int64
+ - float_details
+ - C style arrays of serializable types
+ - Google protocol buffer objects.
+
+ Support for deserialization of objects which implement the enumerable or
+ map_pair interfaces is the responsibility of those objects.
+
+ Note that you can deserialize an integer value to any integral type (except for a
+ char type) if its value will fit into the target integer type. I.e. the types
+ short, int, long, unsigned short, unsigned int, unsigned long, and dlib::uint64
+ can all receive serialized data from each other so long as the actual serialized
+ value fits within the receiving integral type's range.
+
+ Also note that for any container to be serializable the type of object it contains
+ must be serializable.
+
+ FILE STREAMS
+ If you are serializing to and from file streams it is important to
+ remember to set the file streams to binary mode using the std::ios::binary
+ flag.
+
+
+ INTEGRAL SERIALIZATION FORMAT
+ All C++ integral types (except the char types) are serialized to the following
+ format:
+ The first byte is a control byte. It tells you if the serialized number is
+ positive or negative and also tells you how many of the following bytes are
+ part of the number. The absolute value of the number is stored in little
+ endian byte order and follows the control byte.
+
+ The control byte:
+ The high order bit of the control byte is a flag that tells you if the
+ encoded number is negative or not. It is set to 1 when the number is
+ negative and 0 otherwise.
+ The 4 low order bits of the control byte represent an unsigned number
+ and tells you how many of the following bytes are part of the encoded
+ number.
+
+ bool SERIALIZATION FORMAT
+ A bool value is serialized as the single byte character '1' or '0' in ASCII.
+ Where '1' indicates true and '0' indicates false.
+
+ FLOATING POINT SERIALIZATION FORMAT
+ To serialize a floating point value we convert it into a float_details object and
+ then serialize the exponent and mantissa values using dlib's integral serialization
+ format. Therefore, the output is first the exponent and then the mantissa. Note that
+ the mantissa is a signed integer (i.e. there is not a separate sign bit).
+!*/
+
+
+#include "algs.h"
+#include "assert.h"
+#include <iomanip>
+#include <cstddef>
+#include <iostream>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <array>
+#include <deque>
+#include <complex>
+#include <map>
+#include <memory>
+#include <set>
+#include <limits>
+#include <type_traits>
+#include <utility>
+#include "uintn.h"
+#include "interfaces/enumerable.h"
+#include "interfaces/map_pair.h"
+#include "enable_if.h"
+#include "unicode.h"
+#include "byte_orderer.h"
+#include "float_details.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class serialization_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception object. It is thrown if serialization or
+ deserialization fails.
+ !*/
+
+ public:
+ serialization_error(const std::string& e):error(e) {}
+ };
+
+
+ void check_serialized_version(
+ const std::string& expected_version,
+ std::istream& in
+ );
+ /*!
+ ensures
+ - Deserializes a string from in and if it doesn't match expected_version we
+ throw serialization_error.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A ramdump information !*/
+ template <typename T>
+ struct ramdump_t
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a type decoration used to indicate that serialization should be
+ done by simply dumping the memory of some object to disk as fast as
+ possible without any sort of conversions. This means that the data written
+ will be "non-portable" in the sense that the format output by a RAM dump
+ may depend on things like the endianness of your CPU or settings of certain
+ compiler switches.
+
+ You use this object like this:
+ serialize("yourfile.dat") << ramdump(yourobject);
+ deserialize("yourfile.dat") >> ramdump(yourobject);
+ or
+ serialize(ramdump(yourobject), out);
+ deserialize(ramdump(yourobject), in);
+
+ Also, not all objects have a ramdump mode. If you try to use ramdump on an
+ object that does not define a serialization dump for ramdump you will get a
+ compiler error.
+ !*/
+ ramdump_t(T& item_) : item(item_) {}
+ T& item;
+ };
+
+ // This function just makes a ramdump that wraps an object.
+ template <typename T>
+ ramdump_t<typename std::remove_reference<T>::type> ramdump(T&& item)
+ {
+ return ramdump_t<typename std::remove_reference<T>::type>(item);
+ }
+
+
+ template <
+ typename T
+ >
+ void serialize (
+ const ramdump_t<const T>& item_,
+ std::ostream& out
+ )
+ {
+ // Move the const from inside the ramdump_t template to outside so we can bind
+ // against a serialize() call that takes just a const ramdump_t<T>. Doing this
+ // saves you from needing to write multiple overloads of serialize() to handle
+ // these different const placement cases.
+ const auto temp = ramdump(const_cast<T&>(item_.item));
+ serialize(temp, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace ser_helper
+ {
+
+ template <
+ typename T
+ >
+ typename enable_if_c<std::numeric_limits<T>::is_signed,bool>::type pack_int (
+ T item,
+ std::ostream& out
+ )
+ /*!
+ requires
+ - T is a signed integral type
+ ensures
+ - if (no problems occur serializing item) then
+ - writes item to out
+ - returns false
+ - else
+ - returns true
+ !*/
+ {
+ COMPILE_TIME_ASSERT(sizeof(T) <= 8);
+ unsigned char buf[9];
+ unsigned char size = sizeof(T);
+ unsigned char neg;
+ if (item < 0)
+ {
+ neg = 0x80;
+ item *= -1;
+ }
+ else
+ {
+ neg = 0;
+ }
+
+ for (unsigned char i = 1; i <= sizeof(T); ++i)
+ {
+ buf[i] = static_cast<unsigned char>(item&0xFF);
+ item >>= 8;
+ if (item == 0) { size = i; break; }
+ }
+
+ std::streambuf* sbuf = out.rdbuf();
+ buf[0] = size|neg;
+ if (sbuf->sputn(reinterpret_cast<char*>(buf),size+1) != size+1)
+ {
+ out.setstate(std::ios::eofbit | std::ios::badbit);
+ return true;
+ }
+
+ return false;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename enable_if_c<std::numeric_limits<T>::is_signed,bool>::type unpack_int (
+ T& item,
+ std::istream& in
+ )
+ /*!
+ requires
+ - T is a signed integral type
+ ensures
+ - if (there are no problems deserializing item) then
+ - returns false
+ - #item == the value stored in in
+ - else
+ - returns true
+
+ !*/
+ {
+ COMPILE_TIME_ASSERT(sizeof(T) <= 8);
+
+
+ unsigned char buf[8];
+ unsigned char size;
+ bool is_negative;
+
+ std::streambuf* sbuf = in.rdbuf();
+
+ item = 0;
+ int ch = sbuf->sbumpc();
+ if (ch != EOF)
+ {
+ size = static_cast<unsigned char>(ch);
+ }
+ else
+ {
+ in.setstate(std::ios::badbit);
+ return true;
+ }
+
+ if (size&0x80)
+ is_negative = true;
+ else
+ is_negative = false;
+ size &= 0x0F;
+
+ // check if the serialized object is too big
+ if (size > (unsigned long)tmin<sizeof(T),8>::value || size == 0)
+ {
+ return true;
+ }
+
+ if (sbuf->sgetn(reinterpret_cast<char*>(&buf),size) != size)
+ {
+ in.setstate(std::ios::badbit);
+ return true;
+ }
+
+
+ for (unsigned char i = size-1; true; --i)
+ {
+ item <<= 8;
+ item |= buf[i];
+ if (i == 0)
+ break;
+ }
+
+ if (is_negative)
+ item *= -1;
+
+
+ return false;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename disable_if_c<std::numeric_limits<T>::is_signed,bool>::type pack_int (
+ T item,
+ std::ostream& out
+ )
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - if (no problems occur serializing item) then
+ - writes item to out
+ - returns false
+ - else
+ - returns true
+ !*/
+ {
+ COMPILE_TIME_ASSERT(sizeof(T) <= 8);
+ unsigned char buf[9];
+ unsigned char size = sizeof(T);
+
+ for (unsigned char i = 1; i <= sizeof(T); ++i)
+ {
+ buf[i] = static_cast<unsigned char>(item&0xFF);
+ item >>= 8;
+ if (item == 0) { size = i; break; }
+ }
+
+ std::streambuf* sbuf = out.rdbuf();
+ buf[0] = size;
+ if (sbuf->sputn(reinterpret_cast<char*>(buf),size+1) != size+1)
+ {
+ out.setstate(std::ios::eofbit | std::ios::badbit);
+ return true;
+ }
+
+ return false;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename disable_if_c<std::numeric_limits<T>::is_signed,bool>::type unpack_int (
+ T& item,
+ std::istream& in
+ )
+ /*!
+ requires
+ - T is an unsigned integral type
+ ensures
+ - if (there are no problems deserializing item) then
+ - returns false
+ - #item == the value stored in in
+ - else
+ - returns true
+
+ !*/
+ {
+ COMPILE_TIME_ASSERT(sizeof(T) <= 8);
+
+ unsigned char buf[8];
+ unsigned char size;
+
+ item = 0;
+
+ std::streambuf* sbuf = in.rdbuf();
+ int ch = sbuf->sbumpc();
+ if (ch != EOF)
+ {
+ size = static_cast<unsigned char>(ch);
+ }
+ else
+ {
+ in.setstate(std::ios::badbit);
+ return true;
+ }
+
+
+ // mask out the 3 reserved bits
+ size &= 0x8F;
+
+ // check if an error occurred
+ if (size > (unsigned long)tmin<sizeof(T),8>::value || size == 0)
+ return true;
+
+
+ if (sbuf->sgetn(reinterpret_cast<char*>(&buf),size) != size)
+ {
+ in.setstate(std::ios::badbit);
+ return true;
+ }
+
+ for (unsigned char i = size-1; true; --i)
+ {
+ item <<= 8;
+ item |= buf[i];
+ if (i == 0)
+ break;
+ }
+
+ return false;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ #define USE_DEFAULT_INT_SERIALIZATION_FOR(T) \
+ inline void serialize (const T& item, std::ostream& out) \
+ { if (ser_helper::pack_int(item,out)) throw serialization_error("Error serializing object of type " + std::string(#T)); } \
+ inline void deserialize (T& item, std::istream& in) \
+ { if (ser_helper::unpack_int(item,in)) throw serialization_error("Error deserializing object of type " + std::string(#T)); }
+
+ template <typename T>
+ inline bool pack_byte (
+ const T& ch,
+ std::ostream& out
+ )
+ {
+ std::streambuf* sbuf = out.rdbuf();
+ return (sbuf->sputc((char)ch) == EOF);
+ }
+
+ template <typename T>
+ inline bool unpack_byte (
+ T& ch,
+ std::istream& in
+ )
+ {
+ std::streambuf* sbuf = in.rdbuf();
+ int temp = sbuf->sbumpc();
+ if (temp != EOF)
+ {
+ ch = static_cast<T>(temp);
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ #define USE_DEFAULT_BYTE_SERIALIZATION_FOR(T) \
+ inline void serialize (const T& item,std::ostream& out) \
+ { if (pack_byte(item,out)) throw serialization_error("Error serializing object of type " + std::string(#T)); } \
+ inline void deserialize (T& item, std::istream& in) \
+ { if (unpack_byte(item,in)) throw serialization_error("Error deserializing object of type " + std::string(#T)); }
+
+// ----------------------------------------------------------------------------------------
+
+ USE_DEFAULT_INT_SERIALIZATION_FOR(short)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(int)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(long)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned short)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned int)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned long)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(uint64)
+ USE_DEFAULT_INT_SERIALIZATION_FOR(int64)
+
+ USE_DEFAULT_BYTE_SERIALIZATION_FOR(char)
+ USE_DEFAULT_BYTE_SERIALIZATION_FOR(signed char)
+ USE_DEFAULT_BYTE_SERIALIZATION_FOR(unsigned char)
+
+ // Don't define serialization for wchar_t when using visual studio and
+ // _NATIVE_WCHAR_T_DEFINED isn't defined since if it isn't they improperly set
+ // wchar_t to be a typedef rather than its own type as required by the C++
+ // standard.
+#if !defined(_MSC_VER) || _NATIVE_WCHAR_T_DEFINED
+ USE_DEFAULT_INT_SERIALIZATION_FOR(wchar_t)
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize(
+ const float_details& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.mantissa, out);
+ serialize(item.exponent, out);
+ }
+
+ inline void deserialize(
+ float_details& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.mantissa, in);
+ deserialize(item.exponent, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ inline void serialize_floating_point (
+ const T& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ float_details temp = item;
+ serialize(temp, out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing a floating point number."); }
+ }
+
+ template <typename T>
+ inline bool old_deserialize_floating_point (
+ T& item,
+ std::istream& in
+ )
+ {
+ std::ios::fmtflags oldflags = in.flags();
+ in.flags(static_cast<std::ios_base::fmtflags>(0));
+ std::streamsize ss = in.precision(35);
+ if (in.peek() == 'i')
+ {
+ item = std::numeric_limits<T>::infinity();
+ in.get();
+ in.get();
+ in.get();
+ }
+ else if (in.peek() == 'n')
+ {
+ item = -std::numeric_limits<T>::infinity();
+ in.get();
+ in.get();
+ in.get();
+ in.get();
+ }
+ else if (in.peek() == 'N')
+ {
+ item = std::numeric_limits<T>::quiet_NaN();
+ in.get();
+ in.get();
+ in.get();
+ }
+ else
+ {
+ in >> item;
+ }
+ in.flags(oldflags);
+ in.precision(ss);
+ return (in.get() != ' ');
+ }
+
+ template <typename T>
+ inline void deserialize_floating_point (
+ T& item,
+ std::istream& in
+ )
+ {
+ // check if the serialized data uses the older ASCII based format. We can check
+ // this easily because the new format starts with the integer control byte which
+ // always has 0 bits in the positions corresponding to the bitmask 0x70. Moreover,
+ // since the previous format used ASCII numbers we know that no valid bytes can
+ // have bit values of one in the positions indicated 0x70. So this test looks at
+ // the first byte and checks if the serialized data uses the old format or the new
+ // format.
+ if ((in.rdbuf()->sgetc()&0x70) == 0)
+ {
+ try
+ {
+ // Use the fast and compact binary serialization format.
+ float_details temp;
+ deserialize(temp, in);
+ item = temp;
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing a floating point number."); }
+ }
+ else
+ {
+ if (old_deserialize_floating_point(item, in))
+ throw serialization_error("Error deserializing a floating point number.");
+ }
+ }
+
+ inline void serialize ( const float& item, std::ostream& out)
+ {
+ serialize_floating_point(item,out);
+ }
+
+ inline void deserialize (float& item, std::istream& in)
+ {
+ deserialize_floating_point(item,in);
+ }
+
+ inline void serialize ( const double& item, std::ostream& out)
+ {
+ serialize_floating_point(item,out);
+ }
+
+ inline void deserialize (double& item, std::istream& in)
+ {
+ deserialize_floating_point(item,in);
+ }
+
+ inline void serialize ( const long double& item, std::ostream& out)
+ {
+ serialize_floating_point(item,out);
+ }
+
+ inline void deserialize ( long double& item, std::istream& in)
+ {
+ deserialize_floating_point(item,in);
+ }
+
+// ----------------------------------------------------------------------------------------
+// prototypes
+
+ template <typename domain, typename range, typename compare, typename alloc>
+ void serialize (
+ const std::map<domain,range, compare, alloc>& item,
+ std::ostream& out
+ );
+
+ template <typename domain, typename range, typename compare, typename alloc>
+ void deserialize (
+ std::map<domain, range, compare, alloc>& item,
+ std::istream& in
+ );
+
+ template <typename domain, typename compare, typename alloc>
+ void serialize (
+ const std::set<domain, compare, alloc>& item,
+ std::ostream& out
+ );
+
+ template <typename domain, typename compare, typename alloc>
+ void deserialize (
+ std::set<domain, compare, alloc>& item,
+ std::istream& in
+ );
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std::vector<T,alloc>& item,
+ std::ostream& out
+ );
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std::vector<T,alloc>& item,
+ std::istream& in
+ );
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std::deque<T,alloc>& item,
+ std::ostream& out
+ );
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std::deque<T,alloc>& item,
+ std::istream& in
+ );
+
+ inline void serialize (
+ const std::string& item,
+ std::ostream& out
+ );
+
+ inline void deserialize (
+ std::string& item,
+ std::istream& in
+ );
+
+ inline void serialize (
+ const std::wstring& item,
+ std::ostream& out
+ );
+
+ inline void deserialize (
+ std::wstring& item,
+ std::istream& in
+ );
+
+ inline void serialize (
+ const ustring& item,
+ std::ostream& out
+ );
+
+ inline void deserialize (
+ ustring& item,
+ std::istream& in
+ );
+
+ template <
+ typename T
+ >
+ inline void serialize (
+ const enumerable<T>& item,
+ std::ostream& out
+ );
+
+ template <
+ typename domain,
+ typename range
+ >
+ inline void serialize (
+ const map_pair<domain,range>& item,
+ std::ostream& out
+ );
+
+ template <
+ typename T,
+ size_t length
+ >
+ inline void serialize (
+ const T (&array)[length],
+ std::ostream& out
+ );
+
+ template <
+ typename T,
+ size_t length
+ >
+ inline void deserialize (
+ T (&array)[length],
+ std::istream& in
+ );
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ bool item,
+ std::ostream& out
+ )
+ {
+ if (item)
+ out << '1';
+ else
+ out << '0';
+
+ if (!out)
+ throw serialization_error("Error serializing object of type bool");
+ }
+
+ inline void deserialize (
+ bool& item,
+ std::istream& in
+ )
+ {
+ int ch = in.get();
+ if (ch != EOF)
+ {
+ if (ch == '1')
+ item = true;
+ else if (ch == '0')
+ item = false;
+ else
+ throw serialization_error("Error deserializing object of type bool");
+ }
+ else
+ {
+ throw serialization_error("Error deserializing object of type bool");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename first_type, typename second_type>
+ void serialize (
+ const std::pair<first_type, second_type>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.first,out);
+ serialize(item.second,out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::pair"); }
+ }
+
+ template <typename first_type, typename second_type>
+ void deserialize (
+ std::pair<first_type, second_type>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.first,in);
+ deserialize(item.second,in);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::pair"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename domain, typename range, typename compare, typename alloc>
+ void serialize (
+ const std::map<domain,range, compare, alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+
+ serialize(size,out);
+ typename std::map<domain,range,compare,alloc>::const_iterator i;
+ for (i = item.begin(); i != item.end(); ++i)
+ {
+ serialize(i->first,out);
+ serialize(i->second,out);
+ }
+
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::map"); }
+ }
+
+ template <typename domain, typename range, typename compare, typename alloc>
+ void deserialize (
+ std::map<domain, range, compare, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ range r;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ deserialize(r,in);
+ item[d] = r;
+ }
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::map"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename domain, typename compare, typename alloc>
+ void serialize (
+ const std::set<domain, compare, alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+
+ serialize(size,out);
+ typename std::set<domain,compare,alloc>::const_iterator i;
+ for (i = item.begin(); i != item.end(); ++i)
+ {
+ serialize(*i,out);
+ }
+
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::set"); }
+ }
+
+ template <typename domain, typename compare, typename alloc>
+ void deserialize (
+ std::set<domain, compare, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+
+ unsigned long size;
+ deserialize(size,in);
+ domain d;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(d,in);
+ item.insert(d);
+ }
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::set"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename alloc>
+ void serialize (
+ const std::vector<bool,alloc>& item,
+ std::ostream& out
+ )
+ {
+ std::vector<unsigned char> temp(item.size());
+ for (unsigned long i = 0; i < item.size(); ++i)
+ {
+ if (item[i])
+ temp[i] = '1';
+ else
+ temp[i] = '0';
+ }
+ serialize(temp, out);
+ }
+
+ template <typename alloc>
+ void deserialize (
+ std::vector<bool,alloc>& item,
+ std::istream& in
+ )
+ {
+ std::vector<unsigned char> temp;
+ deserialize(temp, in);
+ item.resize(temp.size());
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ if (temp[i] == '1')
+ item[i] = true;
+ else
+ item[i] = false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std::vector<T,alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+
+ serialize(size,out);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i],out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::vector"); }
+ }
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std::vector<T, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ for (unsigned long i = 0; i < size; ++i)
+ deserialize(item[i],in);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename alloc>
+ void serialize (
+ const std::vector<char,alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+ serialize(size,out);
+ if (item.size() != 0)
+ out.write(&item[0], item.size());
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::vector"); }
+ }
+
+ template <typename alloc>
+ void deserialize (
+ std::vector<char, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ if (item.size() != 0)
+ in.read(&item[0], item.size());
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename alloc>
+ void serialize (
+ const std::vector<unsigned char,alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+ serialize(size,out);
+ if (item.size() != 0)
+ out.write((char*)&item[0], item.size());
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::vector"); }
+ }
+
+ template <typename alloc>
+ void deserialize (
+ std::vector<unsigned char, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ if (item.size() != 0)
+ in.read((char*)&item[0], item.size());
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std::deque<T,alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+
+ serialize(size,out);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i],out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::deque"); }
+ }
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std::deque<T, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ for (unsigned long i = 0; i < size; ++i)
+ deserialize(item[i],in);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::deque"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const std::string& item,
+ std::ostream& out
+ )
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+ try{ serialize(size,out); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::string"); }
+
+ out.write(item.c_str(),size);
+ if (!out) throw serialization_error("Error serializing object of type std::string");
+ }
+
+ inline void deserialize (
+ std::string& item,
+ std::istream& in
+ )
+ {
+ unsigned long size;
+ try { deserialize(size,in); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::string"); }
+
+ item.resize(size);
+ if (size != 0)
+ {
+ in.read(&item[0],size);
+ if (!in) throw serialization_error("Error deserializing object of type std::string");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const std::wstring& item,
+ std::ostream& out
+ )
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+ try{ serialize(size,out); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std::wstring"); }
+
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i], out);
+ if (!out) throw serialization_error("Error serializing object of type std::wstring");
+ }
+
+ inline void deserialize (
+ std::wstring& item,
+ std::istream& in
+ )
+ {
+ unsigned long size;
+ try { deserialize(size,in); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std::wstring"); }
+
+ item.resize(size);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ deserialize(item[i],in);
+
+ if (!in) throw serialization_error("Error deserializing object of type std::wstring");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void serialize (
+ const ustring& item,
+ std::ostream& out
+ )
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+ try{ serialize(size,out); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type ustring"); }
+
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i], out);
+ if (!out) throw serialization_error("Error serializing object of type ustring");
+ }
+
+ inline void deserialize (
+ ustring& item,
+ std::istream& in
+ )
+ {
+ unsigned long size;
+ try { deserialize(size,in); }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type ustring"); }
+
+ item.resize(size);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ deserialize(item[i],in);
+
+ if (!in) throw serialization_error("Error deserializing object of type ustring");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void serialize (
+ const enumerable<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ item.reset();
+ serialize(item.size(),out);
+ while (item.move_next())
+ serialize(item.element(),out);
+ item.reset();
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type enumerable");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range
+ >
+ inline void serialize (
+ const map_pair<domain,range>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.key(),out);
+ serialize(item.value(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type map_pair");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ size_t length
+ >
+ inline void serialize (
+ const T (&array)[length],
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(length,out);
+ for (size_t i = 0; i < length; ++i)
+ serialize(array[i],out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing a C style array");
+ }
+ }
+
+ template <
+ size_t length
+ >
+ inline void serialize (
+ const char (&array)[length],
+ std::ostream& out
+ )
+ {
+ if (length != 0 && array[length-1] == '\0')
+ {
+ // If this is a null terminated string then don't serialize the trailing null.
+ // We do this so that the serialization format for C-strings is the same as
+ // std::string.
+ serialize(length-1, out);
+ out.write(array, length-1);
+ if (!out)
+ throw serialization_error("Error serializing a C-style string");
+ }
+ else
+ {
+ try
+ {
+ serialize(length,out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing a C style array");
+ }
+ if (length != 0)
+ out.write(array, length);
+ if (!out)
+ throw serialization_error("Error serializing a C-style string");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ size_t length
+ >
+ inline void deserialize (
+ T (&array)[length],
+ std::istream& in
+ )
+ {
+ size_t size;
+ try
+ {
+ deserialize(size,in);
+ if (size == length)
+ {
+ for (size_t i = 0; i < length; ++i)
+ deserialize(array[i],in);
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing a C style array");
+ }
+
+ if (size != length)
+ throw serialization_error("Error deserializing a C style array, lengths do not match");
+ }
+
+ template <
+ size_t length
+ >
+ inline void deserialize (
+ char (&array)[length],
+ std::istream& in
+ )
+ {
+ size_t size;
+ try
+ {
+ deserialize(size,in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing a C style array");
+ }
+
+ if (size == length)
+ {
+ in.read(array, size);
+ if (!in)
+ throw serialization_error("Error deserializing a C-style array");
+ }
+ else if (size+1 == length)
+ {
+ // In this case we are deserializing a C-style array so we need to add the null
+ // terminator.
+ in.read(array, size);
+ array[size] = '\0';
+ if (!in)
+ throw serialization_error("Error deserializing a C-style string");
+ }
+ else
+ {
+ throw serialization_error("Error deserializing a C style array, lengths do not match");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ size_t N
+ >
+ inline void serialize (
+ const std::array<T,N>& array,
+ std::ostream& out
+ )
+ {
+ typedef T c_array_type[N];
+ serialize(*(const c_array_type*)array.data(), out);
+ }
+
+ template <
+ typename T,
+ size_t N
+ >
+ inline void deserialize (
+ std::array<T,N>& array,
+ std::istream& in
+ )
+ {
+ typedef T c_array_type[N];
+ deserialize(*(c_array_type*)array.data(), in);
+ }
+
+ template <
+ typename T
+ >
+ inline void serialize (
+ const std::array<T,0>& /*array*/,
+ std::ostream& out
+ )
+ {
+ size_t N = 0;
+ serialize(N, out);
+ }
+
+ template <
+ typename T
+ >
+ inline void deserialize (
+ std::array<T,0>& /*array*/,
+ std::istream& in
+ )
+ {
+ size_t N;
+ deserialize(N, in);
+ if (N != 0)
+ {
+ std::ostringstream sout;
+ sout << "Expected std::array of size 0 but found a size of " << N;
+ throw serialization_error(sout.str());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void serialize (
+ const std::complex<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.real(),out);
+ serialize(item.imag(),out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type std::complex");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void deserialize (
+ std::complex<T>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ T real, imag;
+ deserialize(real,in);
+ deserialize(imag,in);
+ item = std::complex<T>(real,imag);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type std::complex");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class proxy_serialize
+ {
+ public:
+ explicit proxy_serialize (
+ const std::string& filename
+ )
+ {
+ fout.reset(new std::ofstream(filename.c_str(), std::ios::binary));
+ if (!(*fout))
+ throw serialization_error("Unable to open " + filename + " for writing.");
+ }
+
+ template <typename T>
+ inline proxy_serialize& operator<<(const T& item)
+ {
+ serialize(item, *fout);
+ return *this;
+ }
+
+ private:
+ std::shared_ptr<std::ofstream> fout;
+ };
+
+ class proxy_deserialize
+ {
+ public:
+ explicit proxy_deserialize (
+ const std::string& filename
+ ) : filename(filename)
+ {
+ fin.reset(new std::ifstream(filename.c_str(), std::ios::binary));
+ if (!(*fin))
+ throw serialization_error("Unable to open " + filename + " for reading.");
+
+ // read the file header into a buffer and then seek back to the start of the
+ // file.
+ fin->read(file_header,4);
+ fin->clear();
+ fin->seekg(0);
+ }
+
+ template <typename T>
+ inline proxy_deserialize& operator>>(T& item)
+ {
+ return doit(item);
+ }
+
+ template <typename T>
+ inline proxy_deserialize& operator>>(ramdump_t<T>&& item)
+ {
+ return doit(std::move(item));
+ }
+
+ private:
+ template <typename T>
+ inline proxy_deserialize& doit(T&& item)
+ {
+ try
+ {
+ if (fin->peek() == EOF)
+ throw serialization_error("No more objects were in the file!");
+ deserialize(std::forward<T>(item), *fin);
+ }
+ catch (serialization_error& e)
+ {
+ std::string suffix;
+ if (looks_like_a_compressed_file())
+ suffix = "\n *** THIS LOOKS LIKE A COMPRESSED FILE. DID YOU FORGET TO DECOMPRESS IT? *** \n";
+
+ if (objects_read == 0)
+ {
+ throw serialization_error("An error occurred while trying to read the first"
+ " object from the file " + filename + ".\nERROR: " + e.info + "\n" + suffix);
+ }
+ else if (objects_read == 1)
+ {
+ throw serialization_error("An error occurred while trying to read the second"
+ " object from the file " + filename +
+ ".\nERROR: " + e.info + "\n" + suffix);
+ }
+ else if (objects_read == 2)
+ {
+ throw serialization_error("An error occurred while trying to read the third"
+ " object from the file " + filename +
+ ".\nERROR: " + e.info + "\n" + suffix);
+ }
+ else
+ {
+ throw serialization_error("An error occurred while trying to read the " +
+ std::to_string(objects_read+1) + "th object from the file " + filename +
+ ".\nERROR: " + e.info + "\n" + suffix);
+ }
+ }
+ ++objects_read;
+ return *this;
+ }
+
+ int objects_read = 0;
+ std::string filename;
+ std::shared_ptr<std::ifstream> fin;
+
+ // We don't need to look at the file header. However, it's here because people
+ // keep posting questions to the dlib forums asking why they get file load errors.
+ // Then it turns out that the problem is they have a compressed file that NEEDS TO
+ // BE DECOMPRESSED by bzip2 or whatever and the reason they are getting
+ // deserialization errors is because they didn't decompress the file. So we are
+ // going to check if this file looks like a compressed file and if so then emit an
+ // error message telling them to unzip the file. :(
+ char file_header[4] = {0,0,0,0};
+
+ bool looks_like_a_compressed_file(
+ ) const
+ {
+ if (file_header[0] == 'B' && file_header[1] == 'Z' && file_header[2] == 'h' &&
+ ('0' <= file_header[3] && file_header[3] <= '9') )
+ {
+ return true;
+ }
+
+ return false;
+ }
+ };
+
+ inline proxy_serialize serialize(const std::string& filename)
+ { return proxy_serialize(filename); }
+ inline proxy_deserialize deserialize(const std::string& filename)
+ { return proxy_deserialize(filename); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+// forward declare the MessageLite object so we can reference it below.
+namespace google
+{
+ namespace protobuf
+ {
+ class MessageLite;
+ }
+}
+
+namespace dlib
+{
+
+ /*!A is_protocol_buffer
+ This is a template that tells you if a type is a Google protocol buffer object.
+ !*/
+
+ template <typename T, typename U = void >
+ struct is_protocol_buffer
+ {
+ static const bool value = false;
+ };
+
+ template <typename T>
+ struct is_protocol_buffer <T,typename enable_if<is_convertible<T*,::google::protobuf::MessageLite*> >::type >
+ {
+ static const bool value = true;
+ };
+
+ template <typename T>
+ typename enable_if<is_protocol_buffer<T> >::type serialize(const T& item, std::ostream& out)
+ {
+ // Note that Google protocol buffer messages are not self delimiting
+ // (see https://developers.google.com/protocol-buffers/docs/techniques)
+ // This means they don't record their length or where they end, so we have
+ // to record this information ourselves. So we save the size as a little endian 32bit
+ // integer prefixed onto the front of the message.
+
+ byte_orderer bo;
+
+ // serialize into temp string
+ std::string temp;
+ if (!item.SerializeToString(&temp))
+ throw dlib::serialization_error("Error while serializing a Google Protocol Buffer object.");
+ if (temp.size() > std::numeric_limits<uint32>::max())
+ throw dlib::serialization_error("Error while serializing a Google Protocol Buffer object, message too large.");
+
+ // write temp to the output stream
+ uint32 size = temp.size();
+ bo.host_to_little(size);
+ out.write((char*)&size, sizeof(size));
+ out.write(temp.c_str(), temp.size());
+ }
+
+ template <typename T>
+ typename enable_if<is_protocol_buffer<T> >::type deserialize(T& item, std::istream& in)
+ {
+ // Note that Google protocol buffer messages are not self delimiting
+ // (see https://developers.google.com/protocol-buffers/docs/techniques)
+ // This means they don't record their length or where they end, so we have
+ // to record this information ourselves. So we save the size as a little endian 32bit
+ // integer prefixed onto the front of the message.
+
+ byte_orderer bo;
+
+ uint32 size = 0;
+ // read the size
+ in.read((char*)&size, sizeof(size));
+ bo.little_to_host(size);
+ if (!in || size == 0)
+ throw dlib::serialization_error("Error while deserializing a Google Protocol Buffer object.");
+
+ // read the bytes into temp
+ std::string temp;
+ temp.resize(size);
+ in.read(&temp[0], size);
+
+ // parse temp into item
+ if (!in || !item.ParseFromString(temp))
+ {
+ throw dlib::serialization_error("Error while deserializing a Google Protocol Buffer object.");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void check_serialized_version(const std::string& expected_version, std::istream& in)
+ {
+ std::string version;
+ deserialize(version, in);
+ if (version != expected_version)
+ {
+ throw serialization_error("Unexpected version '"+version+
+ "' found while deserializing object. Expected version to be '"+expected_version+"'.");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SERIALIZe_
+
diff --git a/ml/dlib/dlib/server.h b/ml/dlib/dlib/server.h
new file mode 100644
index 000000000..d346941ab
--- /dev/null
+++ b/ml/dlib/dlib/server.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVEr_
+#define DLIB_SERVEr_
+
+#include "server/server_kernel.h"
+#include "server/server_iostream.h"
+#include "server/server_http.h"
+
+
+#endif // DLIB_SERVEr_
+
diff --git a/ml/dlib/dlib/server/server_http.cpp b/ml/dlib/dlib/server/server_http.cpp
new file mode 100644
index 000000000..9e3051a43
--- /dev/null
+++ b/ml/dlib/dlib/server/server_http.cpp
@@ -0,0 +1,409 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_HTTP_CPp_
+#define DLIB_SERVER_HTTP_CPp_
+
+#include "server_http.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace http_impl
+ {
+ inline unsigned char to_hex( unsigned char x )
+ {
+ return x + (x > 9 ? ('A'-10) : '0');
+ }
+
+ const std::string urlencode( const std::string& s )
+ {
+ std::ostringstream os;
+
+ for ( std::string::const_iterator ci = s.begin(); ci != s.end(); ++ci )
+ {
+ if ( (*ci >= 'a' && *ci <= 'z') ||
+ (*ci >= 'A' && *ci <= 'Z') ||
+ (*ci >= '0' && *ci <= '9') )
+ { // allowed
+ os << *ci;
+ }
+ else if ( *ci == ' ')
+ {
+ os << '+';
+ }
+ else
+ {
+ os << '%' << to_hex(*ci >> 4) << to_hex(*ci % 16);
+ }
+ }
+
+ return os.str();
+ }
+
+ inline unsigned char from_hex (
+ unsigned char ch
+ )
+ {
+ if (ch <= '9' && ch >= '0')
+ ch -= '0';
+ else if (ch <= 'f' && ch >= 'a')
+ ch -= 'a' - 10;
+ else if (ch <= 'F' && ch >= 'A')
+ ch -= 'A' - 10;
+ else
+ ch = 0;
+ return ch;
+ }
+
+ const std::string urldecode (
+ const std::string& str
+ )
+ {
+ using namespace std;
+ string result;
+ string::size_type i;
+ for (i = 0; i < str.size(); ++i)
+ {
+ if (str[i] == '+')
+ {
+ result += ' ';
+ }
+ else if (str[i] == '%' && str.size() > i+2)
+ {
+ const unsigned char ch1 = from_hex(str[i+1]);
+ const unsigned char ch2 = from_hex(str[i+2]);
+ const unsigned char ch = (ch1 << 4) | ch2;
+ result += ch;
+ i += 2;
+ }
+ else
+ {
+ result += str[i];
+ }
+ }
+ return result;
+ }
+
+ void parse_url(
+ std::string word,
+ key_value_map& queries
+ )
+ /*!
+ Parses the query string of a URL. word should be the stuff that comes
+ after the ? in the query URL.
+ !*/
+ {
+ std::string::size_type pos;
+
+ for (pos = 0; pos < word.size(); ++pos)
+ {
+ if (word[pos] == '&')
+ word[pos] = ' ';
+ }
+
+ std::istringstream sin(word);
+ sin >> word;
+ while (sin)
+ {
+ pos = word.find_first_of("=");
+ if (pos != std::string::npos)
+ {
+ std::string key = urldecode(word.substr(0,pos));
+ std::string value = urldecode(word.substr(pos+1));
+
+ queries[key] = value;
+ }
+ sin >> word;
+ }
+ }
+
+ void read_with_limit(
+ std::istream& in,
+ std::string& buffer,
+ int delim = '\n'
+ )
+ {
+ using namespace std;
+ const size_t max = 64*1024;
+ buffer.clear();
+ buffer.reserve(300);
+
+ while (in.peek() != delim && in.peek() != '\n' && in.peek() != EOF && buffer.size() < max)
+ {
+ buffer += (char)in.get();
+ }
+
+ // if we quit the loop because the data is longer than expected or we hit EOF
+ if (in.peek() == EOF)
+ throw http_parse_error("HTTP field from client terminated incorrectly", 414);
+ if (buffer.size() == max)
+ throw http_parse_error("HTTP field from client is too long", 414);
+
+ in.get();
+ // eat any remaining whitespace
+ if (delim == ' ')
+ {
+ while (in.peek() == ' ')
+ in.get();
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long parse_http_request (
+ std::istream& in,
+ incoming_things& incoming,
+ unsigned long max_content_length
+ )
+ {
+ using namespace std;
+ using namespace http_impl;
+ read_with_limit(in, incoming.request_type, ' ');
+
+ // get the path
+ read_with_limit(in, incoming.path, ' ');
+
+ // Get the HTTP/1.1 - Ignore for now...
+ read_with_limit(in, incoming.protocol);
+
+ key_value_map_ci& incoming_headers = incoming.headers;
+ key_value_map& cookies = incoming.cookies;
+ std::string& path = incoming.path;
+ std::string& content_type = incoming.content_type;
+ unsigned long content_length = 0;
+
+ string line;
+ read_with_limit(in, line);
+ string first_part_of_header;
+ string::size_type position_of_double_point;
+ // now loop over all the incoming_headers
+ while (line != "\r")
+ {
+ position_of_double_point = line.find_first_of(':');
+ if ( position_of_double_point != string::npos )
+ {
+ first_part_of_header = dlib::trim(line.substr(0, position_of_double_point));
+
+ if ( !incoming_headers[first_part_of_header].empty() )
+ incoming_headers[ first_part_of_header ] += " ";
+ incoming_headers[first_part_of_header] += dlib::trim(line.substr(position_of_double_point+1));
+
+ // look for Content-Type:
+ if (line.size() > 14 && strings_equal_ignore_case(line, "Content-Type:", 13))
+ {
+ content_type = line.substr(14);
+ if (content_type[content_type.size()-1] == '\r')
+ content_type.erase(content_type.size()-1);
+ }
+ // look for Content-Length:
+ else if (line.size() > 16 && strings_equal_ignore_case(line, "Content-Length:", 15))
+ {
+ istringstream sin(line.substr(16));
+ sin >> content_length;
+ if (!sin)
+ {
+ throw http_parse_error("Invalid Content-Length of '" + line.substr(16) + "'", 411);
+ }
+
+ if (content_length > max_content_length)
+ {
+ std::ostringstream sout;
+ sout << "Content-Length of post back is too large. It must be less than " << max_content_length;
+ throw http_parse_error(sout.str(), 413);
+ }
+ }
+ // look for any cookies
+ else if (line.size() > 6 && strings_equal_ignore_case(line, "Cookie:", 7))
+ {
+ string::size_type pos = 6;
+ string key, value;
+ bool seen_key_start = false;
+ bool seen_equal_sign = false;
+ while (pos + 1 < line.size())
+ {
+ ++pos;
+ // ignore whitespace between cookies
+ if (!seen_key_start && line[pos] == ' ')
+ continue;
+
+ seen_key_start = true;
+ if (!seen_equal_sign)
+ {
+ if (line[pos] == '=')
+ {
+ seen_equal_sign = true;
+ }
+ else
+ {
+ key += line[pos];
+ }
+ }
+ else
+ {
+ if (line[pos] == ';')
+ {
+ cookies[urldecode(key)] = urldecode(value);
+ seen_equal_sign = false;
+ seen_key_start = false;
+ key.clear();
+ value.clear();
+ }
+ else
+ {
+ value += line[pos];
+ }
+ }
+ }
+ if (key.size() > 0)
+ {
+ cookies[urldecode(key)] = urldecode(value);
+ key.clear();
+ value.clear();
+ }
+ }
+ } // no ':' in it!
+ read_with_limit(in, line);
+ } // while (line != "\r")
+
+
+ // If there is data being posted back to us as a query string then
+ // pick out the queries using parse_url.
+ if ((strings_equal_ignore_case(incoming.request_type, "POST") ||
+ strings_equal_ignore_case(incoming.request_type, "PUT")) &&
+ strings_equal_ignore_case(left_substr(content_type,";"), "application/x-www-form-urlencoded"))
+ {
+ if (content_length > 0)
+ {
+ incoming.body.resize(content_length);
+ in.read(&incoming.body[0],content_length);
+ }
+ parse_url(incoming.body, incoming.queries);
+ }
+
+ string::size_type pos = path.find_first_of("?");
+ if (pos != string::npos)
+ {
+ parse_url(path.substr(pos+1), incoming.queries);
+ }
+
+
+ if (!in)
+ throw http_parse_error("Error parsing HTTP request", 500);
+
+ return content_length;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void read_body (
+ std::istream& in,
+ incoming_things& incoming
+ )
+ {
+ // if the body hasn't already been loaded and there is data to load
+ if (incoming.body.size() == 0 &&
+ incoming.headers.count("Content-Length") != 0)
+ {
+ const unsigned long content_length = string_cast<unsigned long>(incoming.headers["Content-Length"]);
+
+ incoming.body.resize(content_length);
+ if (content_length > 0)
+ {
+ in.read(&incoming.body[0],content_length);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void write_http_response (
+ std::ostream& out,
+ outgoing_things outgoing,
+ const std::string& result
+ )
+ {
+ using namespace http_impl;
+ key_value_map& new_cookies = outgoing.cookies;
+ key_value_map_ci& response_headers = outgoing.headers;
+
+ // only send this header if the user hasn't told us to send another kind
+ bool has_content_type = false, has_location = false;
+ for(key_value_map_ci::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci )
+ {
+ if ( !has_content_type && strings_equal_ignore_case(ci->first , "content-type") )
+ {
+ has_content_type = true;
+ }
+ else if ( !has_location && strings_equal_ignore_case(ci->first , "location") )
+ {
+ has_location = true;
+ }
+ }
+
+ if ( has_location )
+ {
+ outgoing.http_return = 302;
+ }
+
+ if ( !has_content_type )
+ {
+ response_headers["Content-Type"] = "text/html";
+ }
+
+ response_headers["Content-Length"] = cast_to_string(result.size());
+
+ out << "HTTP/1.0 " << outgoing.http_return << " " << outgoing.http_return_status << "\r\n";
+
+ // Set any new headers
+ for(key_value_map_ci::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci )
+ {
+ out << ci->first << ": " << ci->second << "\r\n";
+ }
+
+ // set any cookies
+ for(key_value_map::const_iterator ci = new_cookies.begin(); ci != new_cookies.end(); ++ci )
+ {
+ out << "Set-Cookie: " << urlencode(ci->first) << '=' << urlencode(ci->second) << "\r\n";
+ }
+ out << "\r\n" << result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void write_http_response (
+ std::ostream& out,
+ const http_parse_error& e
+ )
+ {
+ outgoing_things outgoing;
+ outgoing.http_return = e.http_error_code;
+ outgoing.http_return_status = e.what();
+ write_http_response(out, outgoing, std::string("Error processing request: ") + e.what());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void write_http_response (
+ std::ostream& out,
+ const std::exception& e
+ )
+ {
+ outgoing_things outgoing;
+ outgoing.http_return = 500;
+ outgoing.http_return_status = e.what();
+ write_http_response(out, outgoing, std::string("Error processing request: ") + e.what());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const logger server_http::dlog("dlib.server_http");
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SERVER_HTTP_CPp_
+
diff --git a/ml/dlib/dlib/server/server_http.h b/ml/dlib/dlib/server/server_http.h
new file mode 100644
index 000000000..4e95f679f
--- /dev/null
+++ b/ml/dlib/dlib/server/server_http.h
@@ -0,0 +1,242 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net), Steven Van Ingelgem
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_HTTp_1_
+#define DLIB_SERVER_HTTp_1_
+
+
+#include "server_http_abstract.h"
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <cctype>
+#include <map>
+#include "../logger.h"
+#include "../string.h"
+#include "server_iostream.h"
+
+#ifdef __INTEL_COMPILER
+// ignore the bogus warning about hiding on_connect()
+#pragma warning (disable: 1125)
+#endif
+
+#if _MSC_VER
+# pragma warning( disable: 4503 )
+#endif // _MSC_VER
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class http_parse_error : public error
+ {
+ public:
+ http_parse_error(const std::string& str, int http_error_code_):
+ error(str),http_error_code(http_error_code_) {}
+
+ const int http_error_code;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename Key, typename Value, typename Comparer = std::less<Key> >
+ class constmap : public std::map<Key, Value, Comparer>
+ {
+ public:
+ const Value& operator[](const Key& k) const
+ {
+ static const Value dummy = Value();
+
+ typename std::map<Key, Value, Comparer>::const_iterator ci = std::map<Key, Value, Comparer>::find(k);
+
+ if ( ci == this->end() )
+ return dummy;
+ else
+ return ci->second;
+ }
+
+ Value& operator[](const Key& k)
+ {
+ return std::map<Key, Value, Comparer>::operator [](k);
+ }
+ };
+
+
+ class less_case_insensitive
+ {
+ public:
+ bool operator()(const std::string& a, const std::string& b) const
+ {
+ unsigned long i = 0;
+ while (i < a.size() && i < b.size())
+ {
+ const int cha = std::tolower(a[i]);
+ const int chb = std::tolower(b[i]);
+ if (cha < chb)
+ return true;
+ else if (cha > chb)
+ return false;
+ ++i;
+ }
+ if (a.size() < b.size())
+ return true;
+ else
+ return false;
+ }
+ };
+ typedef constmap< std::string, std::string, less_case_insensitive > key_value_map_ci;
+ typedef constmap< std::string, std::string > key_value_map;
+
+ struct incoming_things
+ {
+ incoming_things (
+ const std::string& foreign_ip_,
+ const std::string& local_ip_,
+ unsigned short foreign_port_,
+ unsigned short local_port_
+ ):
+ foreign_ip(foreign_ip_),
+ foreign_port(foreign_port_),
+ local_ip(local_ip_),
+ local_port(local_port_)
+ {}
+
+
+ std::string path;
+ std::string request_type;
+ std::string content_type;
+ std::string protocol;
+ std::string body;
+
+ key_value_map queries;
+ key_value_map cookies;
+ key_value_map_ci headers;
+
+ std::string foreign_ip;
+ unsigned short foreign_port;
+ std::string local_ip;
+ unsigned short local_port;
+ };
+
+ struct outgoing_things
+ {
+ outgoing_things() : http_return(200), http_return_status("OK") { }
+
+ key_value_map cookies;
+ key_value_map_ci headers;
+ unsigned short http_return;
+ std::string http_return_status;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long parse_http_request (
+ std::istream& in,
+ incoming_things& incoming,
+ unsigned long max_content_length
+ );
+
+ void read_body (
+ std::istream& in,
+ incoming_things& incoming
+ );
+
+ void write_http_response (
+ std::ostream& out,
+ outgoing_things outgoing,
+ const std::string& result
+ );
+
+ void write_http_response (
+ std::ostream& out,
+ const http_parse_error& e
+ );
+
+ void write_http_response (
+ std::ostream& out,
+ const std::exception& e
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ class server_http : public server_iostream
+ {
+
+ public:
+
+ server_http()
+ {
+ max_content_length = 10*1024*1024; // 10MB
+ }
+
+ unsigned long get_max_content_length (
+ ) const
+ {
+ auto_mutex lock(http_class_mutex);
+ return max_content_length;
+ }
+
+ void set_max_content_length (
+ unsigned long max_length
+ )
+ {
+ auto_mutex lock(http_class_mutex);
+ max_content_length = max_length;
+ }
+
+
+ private:
+ virtual const std::string on_request (
+ const incoming_things& incoming,
+ outgoing_things& outgoing
+ ) = 0;
+
+
+ virtual void on_connect (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& foreign_ip,
+ const std::string& local_ip,
+ unsigned short foreign_port,
+ unsigned short local_port,
+ uint64
+ )
+ {
+ try
+ {
+ incoming_things incoming(foreign_ip, local_ip, foreign_port, local_port);
+ outgoing_things outgoing;
+
+ parse_http_request(in, incoming, get_max_content_length());
+ read_body(in, incoming);
+ const std::string& result = on_request(incoming, outgoing);
+ write_http_response(out, outgoing, result);
+ }
+ catch (http_parse_error& e)
+ {
+ dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what();
+ write_http_response(out, e);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what();
+ write_http_response(out, e);
+ }
+ }
+
+ mutex http_class_mutex;
+ unsigned long max_content_length;
+ const static logger dlog;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "server_http.cpp"
+#endif
+
+#endif // DLIB_SERVER_HTTp_1_
+
diff --git a/ml/dlib/dlib/server/server_http_abstract.h b/ml/dlib/dlib/server/server_http_abstract.h
new file mode 100644
index 000000000..0bebfb61d
--- /dev/null
+++ b/ml/dlib/dlib/server/server_http_abstract.h
@@ -0,0 +1,390 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net), Steven Van Ingelgem
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SERVER_HTTp_ABSTRACT_
+#ifdef DLIB_SERVER_HTTp_ABSTRACT_
+
+#include "server_iostream_abstract.h"
+#include <iostream>
+#include <string>
+#include <map>
+
+namespace dlib
+{
+
+// -----------------------------------------------------------------------------------------
+
+ template <
+ typename Key,
+ typename Value,
+ typename Comparer = std::less<Key>
+ >
+ class constmap : public std::map<Key, Value, Comparer>
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is simply an extension to the std::map that allows you
+ to use the operator[] accessor with a constant map.
+ !*/
+ public:
+
+ const Value& operator[](
+ const Key& k
+ ) const;
+ /*!
+ ensures
+ - if (this->find(k) != this->end()) then
+ - This map contains the given key
+ - return the value associated with the given key
+ - else
+ - return a default initialized Value object
+ !*/
+
+ Value& operator[](
+ const Key& k
+ ) { return std::map<Key, Value>::operator [](k); }
+ /*!
+ ensures
+ - This function does the same thing as the normal std::map operator[]
+ function.
+ - if (this->find(k) != this->end()) then
+ - This map contains the given key
+ - return the value associated with the given key
+ - else
+ - Adds a new entry into the map that is associated with the
+ given key. The new entry will be default initialized and
+ this function returns a reference to it.
+ !*/
+ };
+
+ typedef constmap<std::string, std::string> key_value_map;
+ // This version of key_value_map treats the key string as being case-insensitive.
+ // For example, a key string of "Content-Type" would access the same element as a key
+ // of "content-type".
+ typedef constmap<std::string, std::string, less_case_insensitive> key_value_map_ci;
+
+// -----------------------------------------------------------------------------------------
+
+ struct incoming_things
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object contains all the various bits of information that describe a
+ HTTP request that comes into a web server.
+
+ For a detailed discussion of the fields of this object, see the
+ server_http::on_request() method defined later in this file.
+ !*/
+
+ incoming_things (
+ const std::string& foreign_ip_,
+ const std::string& local_ip_,
+ unsigned short foreign_port_,
+ unsigned short local_port_
+ );
+ /*!
+ ensures
+ - #foreign_ip = foreign_ip_
+ - #foreign_port = foreign_port_
+ - #local_ip = local_ip_
+ - #local_port = local_port_
+ !*/
+
+ std::string path;
+ std::string request_type;
+ std::string content_type;
+ std::string protocol;
+ std::string body;
+
+ key_value_map queries;
+ key_value_map cookies;
+ key_value_map_ci headers;
+
+ std::string foreign_ip;
+ unsigned short foreign_port;
+ std::string local_ip;
+ unsigned short local_port;
+ };
+
+ struct outgoing_things
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object contains all the various bits of information that describe a
+ HTTP response from a web server.
+
+ For a detailed discussion of the fields of this object, see the
+ server_http::on_request() method defined later in this file.
+ !*/
+
+ outgoing_things(
+ );
+ /*!
+ ensures
+ - #http_return == 200
+ - #http_return_status == "OK"
+ !*/
+
+ key_value_map cookies;
+ key_value_map_ci headers;
+ unsigned short http_return;
+ std::string http_return_status;
+ };
+
+// -----------------------------------------------------------------------------------------
+
+ class http_parse_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an exception thrown by the parse_http_request() routine if
+ there is a problem.
+ !*/
+ };
+
+// -----------------------------------------------------------------------------------------
+
+ unsigned long parse_http_request (
+ std::istream& in,
+ incoming_things& incoming,
+ unsigned long max_content_length
+ );
+ /*!
+ ensures
+ - Attempts to read a HTTP GET, POST, or PUT request from the given input
+ stream.
+ - Reads all headers of the request and puts them into #incoming. In particular,
+ this function populates the following fields:
+ - #incoming.path
+ - #incoming.request_type
+ - #incoming.content_type
+ - #incoming.protocol
+ - #incoming.queries
+ - #incoming.cookies
+ - #incoming.headers
+ - This function also populates the #incoming.body field if and only if the
+ Content-Type field is equal to "application/x-www-form-urlencoded".
+ Otherwise, the content is not read from the stream.
+ throws
+ - http_parse_error
+ This exception is thrown if the Content-Length coming from the web
+ browser is greater than max_content_length or if any other problem
+ is detected with the request.
+ !*/
+
+ void read_body (
+ std::istream& in,
+ incoming_things& incoming
+ );
+ /*!
+ requires
+ - parse_http_request(in,incoming,max_content_length) has already been called
+ and therefore populated the fields of incoming.
+ ensures
+ - if (incoming.body has already been populated with the content of an HTTP
+ request) then
+ - this function does nothing
+ - else
+ - reads the body of the HTTP request into #incoming.body.
+ !*/
+
+ void write_http_response (
+ std::ostream& out,
+ outgoing_things outgoing,
+ const std::string& result
+ );
+ /*!
+ ensures
+ - Writes an HTTP response, defined by the data in outgoing, to the given output
+ stream.
+ - The result variable is written out as the content of the response.
+ !*/
+
+ void write_http_response (
+ std::ostream& out,
+ const http_parse_error& e
+ );
+ /*!
+ ensures
+ - Writes an HTTP error response based on the information in the exception
+ object e.
+ !*/
+
+ void write_http_response (
+ std::ostream& out,
+ const std::exception& e
+ );
+ /*!
+ ensures
+ - Writes an HTTP error response based on the information in the exception
+ object e.
+ !*/
+
+// -----------------------------------------------------------------------------------------
+// -----------------------------------------------------------------------------------------
+
+ class server_http : public server_iostream
+ {
+
+ /*!
+ WHAT THIS EXTENSION DOES FOR server_iostream
+ This extension turns the server object into a simple HTTP server. It only
+ handles HTTP GET, PUT and POST requests and each incoming request triggers
+ the on_request() callback.
+
+ COOKIE STRINGS
+ The strings returned in the cookies key_value_map should be of the following form:
+ key: cookie_name
+ value: cookie contents; expires=Fri, 31-Dec-2010 23:59:59 GMT; path=/; domain=.example.net
+
+ You don't have to supply all the extra cookie arguments. So if you just want to
+ set a cookie that will expire when the client's browser is closed you can
+ just say something like incoming.cookies["cookie_name"] = "cookie contents";
+
+ HTTP HEADERS
+ The HTTP headers in the incoming.headers and outgoing.headers are the name/value pairs
+ of HTTP headers. For example, the HTTP header "Content-Type: text/html" would be
+ encoded such that outgoing.headers["Content-Type"] == "text/html".
+
+ Also note that if you wish to change the content type of your response to the
+ client you may do so by setting the "Content-Type" header to whatever you like.
+ However, setting this field manually is not necessary as it will default to
+ "text/html" if you don't explicitly set it to something.
+ !*/
+
+ public:
+
+ server_http (
+ );
+ /*!
+ ensures
+ - #get_max_content_length() == 10*1024*1024
+ !*/
+
+ unsigned long get_max_content_length (
+ ) const;
+ /*!
+ ensures
+ - returns the max allowable content length, in bytes, of the post back to
+ the web server. If a client attempts to send more data than this then an
+ error number 413 is returned back to the client and the request is not
+ processed by the web server.
+ !*/
+
+ void set_max_content_length (
+ unsigned long max_length
+ );
+ /*!
+ ensures
+ - #get_max_content_length() == max_length
+ !*/
+
+ private:
+
+ virtual const std::string on_request (
+ const incoming_things& incoming,
+ outgoing_things& outgoing
+ ) = 0;
+ /*!
+ requires
+ - on_request() is called when there is an HTTP GET or POST request to be serviced
+ - on_request() is run in its own thread
+ - is_running() == true
+ - the number of current on_request() functions running < get_max_connection()
+ - in incoming:
+ - incoming.path == the path being requested by this request
+ - incoming.request_type == the type of request, GET or POST
+ - incoming.content_type == the content type associated with this request
+ - incoming.protocol == The protocol being used by the web browser (e.g. HTTP/1.1)
+ - incoming.body == a string that contains the data that was posted back to the
+ web server by the client (e.g. The string has the length specified by the
+ Content-Length header).
+ - incoming.body.size() < get_max_content_length()
+ - incoming.queries == a map that contains all the key/value pairs in the query
+ string of this request. The key and value strings of the query string will
+ have been decoded back into their original form before being sent to this
+ function (i.e. '+' decoded back to ' ' and "%hh" into its corresponding
+ ascii value. So the URL-encoding is decoded automatically)
+ - incoming.cookies == The set of cookies that came from the client along with
+ this request. The cookies will have been decoded back to normal form
+ from the URL-encoding.
+ - incoming.headers == a map that contains all the incoming HTTP headers
+ from the client web browser.
+ - incoming.foreign_ip == the foreign ip address for this request
+ - incoming.foreign_port == the foreign port number for this request
+ - incoming.local_ip == the IP of the local interface this request is coming in on
+ - incoming.local_port == the local port number this request is coming in on
+ - in outgoing:
+ - outgoing.cookies.size() == 0
+ - outgoing.headers.size() == 0
+ - outgoing.http_return == 200
+ - outgoing.http_return_status == "OK"
+ ensures
+ - This function returns the HTML page to be displayed as the response to this request.
+ - this function will not call clear()
+ - #outgoing.cookies == a set of new cookies to pass back to the client along
+ with the result of this request. (Note that URL-encoding is automatically applied
+ so you don't have to do it yourself)
+ - #outgoing.headers == a set of additional headers you wish to appear in the
+ HTTP response to this request. (This may be empty, the minimum needed headers
+ will be added automatically if you don't set them)
+ - outgoing.http_return and outgoing.http_return_status may be set to override the
+ default HTTP return code of 200 OK
+ throws
+ - throws only exceptions derived from std::exception. If an exception is thrown
+ then the error string from the exception is returned to the web browser.
+ !*/
+
+
+ // -----------------------------------------------------------------------
+ // Implementation Notes
+ // -----------------------------------------------------------------------
+
+ virtual void on_connect (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& foreign_ip,
+ const std::string& local_ip,
+ unsigned short foreign_port,
+ unsigned short local_port,
+ uint64
+ )
+ /*!
+ on_connect() is the function defined by server_iostream which is overloaded by
+ server_http. In particular, the server_http's implementation is shown below.
+ In it you can see how the server_http parses the incoming http request, gets a
+ response by calling on_request(), and sends it back using the helper routines
+ defined at the top of this file.
+
+ Therefore, if you want to modify the behavior of the HTTP server, for example,
+ to do some more complex data streaming requiring direct access to the
+ iostreams, then you can do so by defining your own on_connect() routine. In
+ particular, the default implementation shown below is a good starting point.
+ !*/
+ {
+ try
+ {
+ incoming_things incoming(foreign_ip, local_ip, foreign_port, local_port);
+ outgoing_things outgoing;
+
+ parse_http_request(in, incoming, get_max_content_length());
+ read_body(in, incoming);
+ const std::string& result = on_request(incoming, outgoing);
+ write_http_response(out, outgoing, result);
+ }
+ catch (http_parse_error& e)
+ {
+ write_http_response(out, e);
+ }
+ catch (std::exception& e)
+ {
+ write_http_response(out, e);
+ }
+ }
+ };
+
+}
+
+#endif // DLIB_SERVER_HTTp_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/server/server_iostream.cpp b/ml/dlib/dlib/server/server_iostream.cpp
new file mode 100644
index 000000000..0fd49b67c
--- /dev/null
+++ b/ml/dlib/dlib/server/server_iostream.cpp
@@ -0,0 +1,14 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_IOSTREAM_CPp_
+#define DLIB_SERVER_IOSTREAM_CPp_
+
+#include "server_iostream.h"
+
+namespace dlib
+{
+ const logger server_iostream::_dLog("dlib.server_iostream");
+}
+
+#endif // DLIB_SERVER_IOSTREAM_CPp_
+
diff --git a/ml/dlib/dlib/server/server_iostream.h b/ml/dlib/dlib/server/server_iostream.h
new file mode 100644
index 000000000..eed349015
--- /dev/null
+++ b/ml/dlib/dlib/server/server_iostream.h
@@ -0,0 +1,155 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_IOSTREAm_1_
+#define DLIB_SERVER_IOSTREAm_1_
+
+#include <iostream>
+#include "server_iostream_abstract.h"
+#include "../logger.h"
+#include "../uintn.h"
+#include "server_kernel.h"
+#include "../sockstreambuf.h"
+#include "../map.h"
+
+
+namespace dlib
+{
+
+ class server_iostream : public server
+ {
+
+ /*!
+ INITIAL VALUE
+ - next_id == 0
+ - con_map.size() == 0
+
+ CONVENTION
+ - next_id == the id of the next connection
+ - for all current connections
+ - con_map[id] == the connection object with the given id
+ - m == the mutex that protects the members of this object
+ !*/
+
+ typedef map<uint64,connection*,memory_manager<char>::kernel_2a>::kernel_1b id_map;
+
+ public:
+ server_iostream(
+ ) :
+ next_id(0)
+ {}
+
+ ~server_iostream(
+ )
+ {
+ server::clear();
+ }
+
+ protected:
+
+ void shutdown_connection (
+ uint64 id
+ )
+ {
+ auto_mutex M(m);
+ if (con_map.is_in_domain(id))
+ {
+ con_map[id]->shutdown();
+ }
+ }
+
+ private:
+
+ virtual void on_connect (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& foreign_ip,
+ const std::string& local_ip,
+ unsigned short foreign_port,
+ unsigned short local_port,
+ uint64 connection_id
+ )=0;
+
+ void on_connect (
+ connection& con
+ )
+ {
+ bool my_fault = true;
+ uint64 this_con_id=0;
+ try
+ {
+ sockstreambuf buf(&con);
+ std::istream in(&buf);
+ std::ostream out(&buf);
+ in.tie(&out);
+
+ // add this connection to the con_map
+ {
+ auto_mutex M(m);
+ this_con_id = next_id;
+ connection* this_con = &con;
+ con_map.add(this_con_id,this_con);
+ this_con_id = next_id;
+ ++next_id;
+ }
+
+ my_fault = false;
+ on_connect(
+ in,
+ out,
+ con.get_foreign_ip(),
+ con.get_local_ip(),
+ con.get_foreign_port(),
+ con.get_local_port(),
+ this_con_id
+ );
+
+ // remove this connection from the con_map
+ {
+ auto_mutex M(m);
+ connection* this_con;
+ uint64 junk;
+ con_map.remove(this_con_id,junk,this_con);
+ }
+
+ }
+ catch (std::bad_alloc&)
+ {
+ // make sure we remove this connection from the con_map
+ {
+ auto_mutex M(m);
+ if (con_map.is_in_domain(this_con_id))
+ {
+ connection* this_con;
+ uint64 junk;
+ con_map.remove(this_con_id,junk,this_con);
+ }
+ }
+
+ _dLog << LERROR << "We ran out of memory in server_iostream::on_connect()";
+ // if this is an escaped exception from on_connect then let it fly!
+ // Seriously though, this way it is obvious to the user that something bad happened
+ // since they probably won't have the dlib logger enabled.
+ if (!my_fault)
+ throw;
+ }
+ }
+
+ uint64 next_id;
+ id_map con_map;
+ const static logger _dLog;
+ mutex m;
+
+
+ };
+
+
+}
+
+#ifdef NO_MAKEFILE
+#include "server_iostream.cpp"
+#endif
+
+#endif // DLIB_SERVER_IOSTREAm_1_
+
+
+
diff --git a/ml/dlib/dlib/server/server_iostream_abstract.h b/ml/dlib/dlib/server/server_iostream_abstract.h
new file mode 100644
index 000000000..86cd460bf
--- /dev/null
+++ b/ml/dlib/dlib/server/server_iostream_abstract.h
@@ -0,0 +1,84 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SERVER_IOSTREAm_ABSTRACT_
+#ifdef DLIB_SERVER_IOSTREAm_ABSTRACT_
+
+
+#include "server_kernel_abstract.h"
+#include <iostream>
+#include <string>
+#include "../uintn.h"
+
+namespace dlib
+{
+
+ class server_iostream : public server
+ {
+
+ /*!
+ WHAT THIS EXTENSION DOES FOR server
+ This extension redefines the on_connect() function so that
+ instead of giving you a connection object you get an istream
+ and ostream object.
+
+ THREAD SAFETY
+ Note that in on_connect() the input stream in is tied to the output stream
+ out. This means that when you read from in it will modify out and thus
+ it is not safe to touch in and out concurrently from different threads
+ unless you untie them (which you do by saying in.tie(0);)
+ !*/
+
+ protected:
+
+ void shutdown_connection (
+ uint64 id
+ );
+ /*!
+ ensures
+ - if (there is a connection currently being serviced with the given id) then
+ - the specified connection is shutdown. (i.e. connection::shutdown() is
+ called on it so the iostreams operating on it will return EOF)
+ !*/
+
+ private:
+
+ virtual void on_connect (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& foreign_ip,
+ const std::string& local_ip,
+ unsigned short foreign_port,
+ unsigned short local_port,
+ uint64 connection_id
+ )=0;
+ /*!
+ requires
+ - on_connect() is called when there is a new TCP connection that needs
+ to be serviced.
+ - in == the input stream that reads data from the new connection
+ - out == the output stream that writes data to the new connection
+ - in.tie() == &out (i.e. when you read from in it automatically calls out.flush())
+ - foreign_ip == the foreign ip address for this connection
+ - foreign_port == the foreign port number for this connection
+ - local_ip == the IP of the local interface this connection is using
+ - local_port == the local port number for this connection
+ - on_connect() is run in its own thread
+ - is_running() == true
+ - the number of current connections < get_max_connection()
+ - connection_id == an integer that uniquely identifies this connection.
+ It can be used by shutdown_connection() to terminate this connection.
+ ensures
+ - when the iostreams hit EOF on_connect() will terminate.
+ (because this is how clear() signals you the server is shutting down)
+ - this function will not call clear()
+ throws
+ - does not throw any exceptions
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_SERVER_IOSTREAm_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/server/server_kernel.cpp b/ml/dlib/dlib/server/server_kernel.cpp
new file mode 100644
index 000000000..9e8130e78
--- /dev/null
+++ b/ml/dlib/dlib/server/server_kernel.cpp
@@ -0,0 +1,595 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_KERNEL_CPp_
+#define DLIB_SERVER_KERNEL_CPp_
+
+#include "server_kernel.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ server::
+ server (
+ ) :
+ listening_port(0),
+ running(false),
+ shutting_down(false),
+ running_signaler(running_mutex),
+ thread_count(0),
+ thread_count_signaler(thread_count_mutex),
+ max_connections(1000),
+ thread_count_zero(thread_count_mutex),
+ graceful_close_timeout(500)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ server::
+ ~server (
+ )
+ {
+ clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long server::
+ get_graceful_close_timeout (
+ ) const
+ {
+ auto_mutex lock(max_connections_mutex);
+ return graceful_close_timeout;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ set_graceful_close_timeout (
+ unsigned long timeout
+ )
+ {
+ auto_mutex lock(max_connections_mutex);
+ graceful_close_timeout = timeout;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ int server::
+ get_max_connections (
+ ) const
+ {
+ max_connections_mutex.lock();
+ int temp = max_connections;
+ max_connections_mutex.unlock();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ set_max_connections (
+ int max
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(
+ max >= 0 ,
+ "\tvoid server::set_max_connections"
+ << "\n\tmax == " << max
+ << "\n\tthis: " << this
+ );
+
+ max_connections_mutex.lock();
+ max_connections = max;
+ max_connections_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ clear (
+ )
+ {
+ // signal that we are shutting down
+ shutting_down_mutex.lock();
+ shutting_down = true;
+ shutting_down_mutex.unlock();
+
+
+
+ max_connections_mutex.lock();
+ listening_port_mutex.lock();
+ listening_ip_mutex.lock();
+ listening_ip = "";
+ listening_port = 0;
+ max_connections = 1000;
+ graceful_close_timeout = 500;
+ listening_port_mutex.unlock();
+ listening_ip_mutex.unlock();
+ max_connections_mutex.unlock();
+
+
+ // tell all the connections to shut down
+ cons_mutex.lock();
+ connection* temp;
+ while (cons.size() > 0)
+ {
+ cons.remove_any(temp);
+ temp->shutdown();
+ }
+ cons_mutex.unlock();
+
+
+ // wait for all the connections to shut down
+ thread_count_mutex.lock();
+ while (thread_count > 0)
+ {
+ thread_count_zero.wait();
+ }
+ thread_count_mutex.unlock();
+
+
+
+
+ // wait for the listener to close
+ running_mutex.lock();
+ while (running == true)
+ {
+ running_signaler.wait();
+ }
+ running_mutex.unlock();
+
+
+
+ // signal that the shutdown is complete
+ shutting_down_mutex.lock();
+ shutting_down = false;
+ shutting_down_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ start_async_helper (
+ )
+ {
+ try
+ {
+ start_accepting_connections();
+ }
+ catch (std::exception& e)
+ {
+ sdlog << LERROR << e.what();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ start_async (
+ )
+ {
+ auto_mutex lock(running_mutex);
+ if (running)
+ return;
+
+ // Any exceptions likely to be thrown by the server are going to be
+ // thrown when trying to bind the port. So calling this here rather
+ // than in the thread we are about to make will cause start_async()
+ // to report errors back to the user in a very straight forward way.
+ open_listening_socket();
+
+ async_start_thread.reset(new thread_function(make_mfp(*this,&server::start_async_helper)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ open_listening_socket (
+ )
+ {
+ if (!sock)
+ {
+ int status = create_listener(sock,listening_port,listening_ip);
+ const int port_used = listening_port;
+
+ // if there was an error then clear this object
+ if (status < 0)
+ {
+ max_connections_mutex.lock();
+ listening_port_mutex.lock();
+ listening_ip_mutex.lock();
+ listening_ip = "";
+ listening_port = 0;
+ max_connections = 1000;
+ graceful_close_timeout = 500;
+ listening_port_mutex.unlock();
+ listening_ip_mutex.unlock();
+ max_connections_mutex.unlock();
+ }
+
+
+
+ // throw an exception for the error
+ if (status == PORTINUSE)
+ {
+ throw dlib::socket_error(
+ EPORT_IN_USE,
+ "error occurred in server::start()\nport " + cast_to_string(port_used) + " already in use"
+ );
+ }
+ else if (status == OTHER_ERROR)
+ {
+ throw dlib::socket_error(
+ "error occurred in server::start()\nunable to create listener"
+ );
+ }
+ }
+
+ running_mutex.lock();
+ running = true;
+ running_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ start (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(
+ this->is_running() == false,
+ "\tvoid server::start"
+ << "\n\tis_running() == " << this->is_running()
+ << "\n\tthis: " << this
+ );
+
+ start_accepting_connections();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ start_accepting_connections (
+ )
+ {
+ open_listening_socket();
+
+ // determine the listening port
+ bool port_assigned = false;
+ listening_port_mutex.lock();
+ if (listening_port == 0)
+ {
+ port_assigned = true;
+ listening_port = sock->get_listening_port();
+ }
+ listening_port_mutex.unlock();
+ if (port_assigned)
+ on_listening_port_assigned();
+
+
+
+ int status = 0;
+
+ connection* client;
+ bool exit = false;
+ while ( true )
+ {
+
+
+ // accept the next connection
+ status = sock->accept(client,1000);
+
+
+ // if there was an error then quit the loop
+ if (status == OTHER_ERROR)
+ {
+ break;
+ }
+
+ shutting_down_mutex.lock();
+ // if we are shutting down then signal that we should quit the loop
+ exit = shutting_down;
+ shutting_down_mutex.unlock();
+
+
+ // if we should be shutting down
+ if (exit)
+ {
+ // if a connection was opened then close it
+ if (status == 0)
+ delete client;
+ break;
+ }
+
+
+
+ // if the accept timed out
+ if (status == TIMEOUT)
+ {
+ continue;
+ }
+
+
+
+
+
+ // add this new connection to cons
+ cons_mutex.lock();
+ connection* client_temp = client;
+ try{cons.add(client_temp);}
+ catch(...)
+ {
+ sock.reset();
+ delete client;
+ cons_mutex.unlock();
+
+ // signal that we are not running start() anymore
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+
+
+ clear();
+ throw;
+ }
+ cons_mutex.unlock();
+
+
+ // make a param structure
+ param* temp = 0;
+ try{
+ temp = new param (
+ *this,
+ *client,
+ get_graceful_close_timeout()
+ );
+ } catch (...)
+ {
+ sock.reset();
+ delete client;
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+ clear();
+ throw;
+ }
+
+
+ // if create_new_thread failed
+ if (!create_new_thread(service_connection,temp))
+ {
+ delete temp;
+ // close the listening socket
+ sock.reset();
+
+ // close the new connection and remove it from cons
+ cons_mutex.lock();
+ connection* ctemp;
+ if (cons.is_member(client))
+ {
+ cons.remove(client,ctemp);
+ }
+ delete client;
+ cons_mutex.unlock();
+
+
+ // signal that the listener has closed
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+
+ // make sure the object is cleared
+ clear();
+
+ // throw the exception
+ throw dlib::thread_error(
+ ECREATE_THREAD,
+ "error occurred in server::start()\nunable to start thread"
+ );
+ }
+ // if we made the new thread then update thread_count
+ else
+ {
+ // increment the thread count
+ thread_count_mutex.lock();
+ ++thread_count;
+ if (thread_count == 0)
+ thread_count_zero.broadcast();
+ thread_count_mutex.unlock();
+ }
+
+
+
+
+ // check if we have hit the maximum allowed number of connections
+ max_connections_mutex.lock();
+ // if max_connections is zero or the loop is ending then skip this
+ if (max_connections != 0)
+ {
+ // wait for thread_count to be less than max_connections
+ thread_count_mutex.lock();
+ while (thread_count >= max_connections)
+ {
+ max_connections_mutex.unlock();
+ thread_count_signaler.wait();
+ max_connections_mutex.lock();
+
+ // if we are shutting down the quit the loop
+ shutting_down_mutex.lock();
+ exit = shutting_down;
+ shutting_down_mutex.unlock();
+ if (exit)
+ break;
+ }
+ thread_count_mutex.unlock();
+ }
+ max_connections_mutex.unlock();
+
+ if (exit)
+ {
+ break;
+ }
+ } //while ( true )
+
+
+ // close the socket
+ sock.reset();
+
+ // signal that the listener has closed
+ running_mutex.lock();
+ running = false;
+ running_signaler.broadcast();
+ running_mutex.unlock();
+
+ // if there was an error with accept then throw an exception
+ if (status == OTHER_ERROR)
+ {
+ // make sure the object is cleared
+ clear();
+
+ // throw the exception
+ throw dlib::socket_error(
+ "error occurred in server::start()\nlistening socket returned error"
+ );
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool server::
+ is_running (
+ ) const
+ {
+ running_mutex.lock();
+ bool temp = running;
+ running_mutex.unlock();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string server::
+ get_listening_ip (
+ ) const
+ {
+ listening_ip_mutex.lock();
+ std::string ip(listening_ip);
+ listening_ip_mutex.unlock();
+ return ip;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int server::
+ get_listening_port (
+ ) const
+ {
+ listening_port_mutex.lock();
+ int port = listening_port;
+ listening_port_mutex.unlock();
+ return port;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ set_listening_port (
+ int port
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(
+ ( port >= 0 &&
+ this->is_running() == false ),
+ "\tvoid server::set_listening_port"
+ << "\n\tport == " << port
+ << "\n\tis_running() == " << this->is_running()
+ << "\n\tthis: " << this
+ );
+
+ listening_port_mutex.lock();
+ listening_port = port;
+ listening_port_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void server::
+ set_listening_ip (
+ const std::string& ip
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(
+ ( ( is_ip_address(ip) || ip == "" ) &&
+ this->is_running() == false ),
+ "\tvoid server::set_listening_ip"
+ << "\n\tip == " << ip
+ << "\n\tis_running() == " << this->is_running()
+ << "\n\tthis: " << this
+ );
+
+ listening_ip_mutex.lock();
+ listening_ip = ip;
+ listening_ip_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // static member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ const logger server::sdlog("dlib.server");
+
+ void server::
+ service_connection(
+ void* item
+ )
+ {
+ param& p = *static_cast<param*>(item);
+
+
+ p.the_server.on_connect(p.new_connection);
+
+
+ // remove this connection from cons and close it
+ p.the_server.cons_mutex.lock();
+ connection* temp;
+ if (p.the_server.cons.is_member(&p.new_connection))
+ p.the_server.cons.remove(&p.new_connection,temp);
+ p.the_server.cons_mutex.unlock();
+
+ try{ close_gracefully(&p.new_connection, p.graceful_close_timeout); }
+ catch (...) { sdlog << LERROR << "close_gracefully() threw"; }
+
+ // decrement the thread count and signal if it is now zero
+ p.the_server.thread_count_mutex.lock();
+ --p.the_server.thread_count;
+ p.the_server.thread_count_signaler.broadcast();
+ if (p.the_server.thread_count == 0)
+ p.the_server.thread_count_zero.broadcast();
+ p.the_server.thread_count_mutex.unlock();
+
+ delete &p;
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SERVER_KERNEL_CPp_
+
diff --git a/ml/dlib/dlib/server/server_kernel.h b/ml/dlib/dlib/server/server_kernel.h
new file mode 100644
index 000000000..4232ff343
--- /dev/null
+++ b/ml/dlib/dlib/server/server_kernel.h
@@ -0,0 +1,234 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SERVER_KERNEL_1_
+#define DLIB_SERVER_KERNEL_1_
+
+#include "server_kernel_abstract.h"
+
+#include <memory>
+#include <string>
+
+#include "../threads.h"
+#include "../sockets.h"
+#include "../algs.h"
+#include "../set.h"
+#include "../logger.h"
+
+
+namespace dlib
+{
+
+ // These forward declarations are here so we can use them in the typedefs in the server
+ // class. The reason for this is for backwards compatibility with previous versions of
+ // dlib.
+ class server_http;
+ class server_iostream;
+
+ class server
+ {
+
+
+ /*!
+ INITIAL VALUE
+ listening_port == 0
+ listening_ip == ""
+ running == false
+ shutting_down == false
+ cons.size() == 0
+ listening_port_mutex == a mutex
+ listening_ip_mutex == a mutex
+ running_mutex == a mutex
+ running_signaler == a signaler associated with running_mutex
+ shutting_down_mutex == a mutex
+ cons_mutex == a mutex
+ thread_count == 0
+ thread_count_mutex == a mutex
+ thread_count_signaler == a signaler associated with thread_count_mutex
+ thread_count_zero == a signaler associated with thread_count_mutex
+ max_connections == 1000
+ max_connections_mutex == a mutex for max_connections and graceful_close_timeout
+ graceful_close_timeout == 500
+
+ CONVENTION
+ listening_port == get_listening_port()
+ listening_ip == get_listening_ip()
+ running == is_running()
+ shutting_down == true while clear() is running. this
+ bool is used to tell the thread blocked on
+ accept that it should terminate
+ cons == a set containing all open connections
+ listening_port_mutex == a mutex for listening_port
+ listening_ip_mutex == a mutex for listening_ip
+ running_mutex == a mutex for running
+ running_signaler == a signaler for running and
+ is associated with running_mutex. it is
+ used to signal when running is false
+ shutting_down_mutex == a mutex for shutting_down
+ cons_mutex == a mutex for cons
+ thread_count == the number of threads currently running
+ thread_count_mutex == a mutex for thread_count
+ thread_count_signaler == a signaler for thread_count and
+ is associated with thread_count_mutex. it
+ is used to signal when thread_count is
+ decremented
+ thread_count_zero == a signaler for thread_count and
+ is associated with thread_count_mutex. it
+ is used to signal when thread_count becomes
+ zero
+ max_connections == get_max_connections()
+ max_connections_mutex == a mutex for max_connections
+ !*/
+
+
+ typedef set<connection*>::kernel_1a set_of_connections;
+
+ // this structure is used to pass parameters to new threads
+ struct param
+ {
+ param (
+ server& server_,
+ connection& new_connection_,
+ unsigned long graceful_close_timeout_
+ ) :
+ the_server(server_),
+ new_connection(new_connection_),
+ graceful_close_timeout(graceful_close_timeout_)
+ {}
+
+ server& the_server;
+ connection& new_connection;
+ unsigned long graceful_close_timeout;
+ };
+
+
+
+ public:
+
+ // These typedefs are here for backward compatibility with previous versions of dlib
+ typedef server kernel_1a;
+ typedef server kernel_1a_c;
+ typedef server_iostream iostream_1a;
+ typedef server_iostream iostream_1a_c;
+ typedef server_http http_1a;
+ typedef server_http http_1a_c;
+
+ server(
+ );
+
+ virtual ~server(
+ );
+
+ void clear(
+ );
+
+ void start (
+ );
+
+ bool is_running (
+ ) const;
+
+ const std::string get_listening_ip (
+ ) const;
+
+ int get_listening_port (
+ ) const;
+
+ void set_listening_port (
+ int port
+ );
+
+ void set_listening_ip (
+ const std::string& ip
+ );
+
+ void set_max_connections (
+ int max
+ );
+
+ int get_max_connections (
+ ) const;
+
+ void start_async (
+ );
+
+ void set_graceful_close_timeout (
+ unsigned long timeout
+ );
+
+ unsigned long get_graceful_close_timeout (
+ ) const;
+
+ private:
+
+ void start_async_helper (
+ );
+
+ void start_accepting_connections (
+ );
+
+ void open_listening_socket (
+ );
+
+ virtual void on_connect (
+ connection& new_connection
+ )=0;
+
+ virtual void on_listening_port_assigned (
+ ) {}
+
+ const static logger sdlog;
+
+ static void service_connection(
+ void* item
+ );
+ /*!
+ requires
+ item is a pointer to a param struct
+ ensures
+ services the new connection
+ will take care of closing the connection and
+ adding the connection to cons when it first starts and
+ remove the connection from cons and signal that it has
+ done so when it ends
+ !*/
+
+ // data members
+ int listening_port;
+ std::string listening_ip;
+ bool running;
+ bool shutting_down;
+ set_of_connections cons;
+ mutex listening_port_mutex;
+ mutex listening_ip_mutex;
+ rmutex running_mutex;
+ rsignaler running_signaler;
+ mutex shutting_down_mutex;
+ mutex cons_mutex;
+ int thread_count;
+ mutex thread_count_mutex;
+ signaler thread_count_signaler;
+ int max_connections;
+ mutex max_connections_mutex;
+ signaler thread_count_zero;
+ std::unique_ptr<thread_function> async_start_thread;
+ std::unique_ptr<listener> sock;
+ unsigned long graceful_close_timeout;
+
+
+ // restricted functions
+ server(server&);
+ server& operator= (
+ server&
+ );
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "server_kernel.cpp"
+#endif
+
+#endif // DLIB_SERVER_KERNEL_1_
+
diff --git a/ml/dlib/dlib/server/server_kernel_abstract.h b/ml/dlib/dlib/server/server_kernel_abstract.h
new file mode 100644
index 000000000..f7860d26c
--- /dev/null
+++ b/ml/dlib/dlib/server/server_kernel_abstract.h
@@ -0,0 +1,310 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SERVER_KERNEL_ABSTRACT_
+#ifdef DLIB_SERVER_KERNEL_ABSTRACT_
+
+#include "../threads/threads_kernel_abstract.h"
+#include "../sockets/sockets_kernel_abstract.h"
+#include <string>
+
+
+namespace dlib
+{
+ class server
+ {
+
+ /*!
+ INITIAL VALUE
+ get_listening_ip() == ""
+ get_listening_port() == 0
+ is_running() == false
+ get_max_connections() == 1000
+ get_graceful_close_timeout() == 500
+
+
+ CALLBACK FUNCTIONS
+ on_connect():
+ To use this object inherit from it and define the pure virtual function
+ on_connect. Inside this function is where you will handle each new
+ connection. Note that the connection object passed to on_connect() should
+ NOT be closed, just let the function end and it will be gracefully closed
+ for you. Also note that each call to on_connect() is run in its own
+ thread. Also note that on_connect() should NOT throw any exceptions,
+ all exceptions must be dealt with inside on_connect() and cannot be
+ allowed to leave.
+
+ on_listening_port_assigned():
+ This function is called to let the client know that the operating
+ system has assigned a port number to the listening port. This
+ happens if a port number of zero was given. Note that this
+ function does not need to be defined. If you don't care then
+ don't define it and it will do nothing. Note also that this function
+ is NOT called in its own thread. Thus, making it block might hang the
+ server.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a server that listens on a port and spawns new
+ threads to handle each new connection.
+
+ Note that the clear() function does not return until all calls to
+ on_connect() have finished and the start() function has been shutdown.
+ Also note that when clear() is called all open connection objects
+ will be shutdown().
+
+ A note about get_max_connections(): when the maximum number of connections
+ has been reached accept() will simply not be called until the number of
+ open connections drops below get_max_connections(). This means connections
+ will just wait to be serviced, rather than being outright refused.
+
+ THREAD SAFETY
+ All member functions are thread-safe.
+ !*/
+
+ public:
+
+ server(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~server(
+ );
+ /*!
+ requires
+ - is not called from any of server's callbacks
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ requires
+ - is not called from any of server's callbacks
+ ensures
+ - #*this has its initial value
+ - all open connection objects passed to on_connect() are shutdown()
+ - blocks until all calls to on_connect() have finished
+ - blocks until the start() function has released all its resources
+ throws
+ - std::bad_alloc
+ if this exception is thrown then the server object is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void start (
+ );
+ /*!
+ requires
+ - is_running() == false
+ ensures
+ - starts listening on the port and ip specified by get_listening_ip()
+ and #get_listening_port() for new connections.
+ - if (get_listening_port() == 0) then
+ - a port to listen on will be automatically selected
+ - #get_listening_port() == the selected port being used
+ - if (get_listening_ip() == "" ) then
+ - all local IPs will be listened on
+ - blocks until clear() is called or an error occurs
+ throws
+ - dlib::socket_error
+ start() will throw this exception if there is some problem binding
+ ports and/or starting the server or if there is a problem
+ accepting new connections while it's running.
+ If this happens then
+ - All open connection objects passed to on_connect() are shutdown()
+ and the exception will not be thrown until all on_connect() calls
+ have terminated.
+ - The server will be cleared and returned to its initial value.
+ - dlib::thread_error
+ start() will throw this exception if there is a problem
+ creating new threads. Or it may throw this exception if there
+ is a problem creating threading objects.
+ If this happens then
+ - All open connection objects passed to on_connect() are shutdown()
+ and the exception will not be thrown until all on_connect() calls
+ have terminated.
+ - The server will be cleared and returned to its initial value.
+ - std::bad_alloc
+ start() may throw this exception and if it does then the object
+ will be unusable until clear() is called and succeeds
+ !*/
+
+ void start_async (
+ );
+ /*!
+ ensures
+ - starts listening on the port and ip specified by get_listening_ip()
+ and #get_listening_port() for new connections.
+ - if (get_listening_port() == 0) then
+ - a port to listen on will be automatically selected
+ - #get_listening_port() == the selected port being used
+ - if (get_listening_ip() == "" ) then
+ - all local IPs will be listened on
+ - does NOT block. That is, this function will return right away and
+ the server will run on a background thread until clear() or this
+ object's destructor is called (or until some kind of fatal error
+ occurs).
+ - if an error occurs in the background thread while the server is
+ running then it will shut itself down, set is_running() to false, and
+ log the error to a dlib::logger object.
+ - calling start_async() on a running server has no effect.
+ throws
+ - dlib::socket_error
+ start_async() will throw this exception if there is some problem binding
+ ports and/or starting the server.
+ If this happens then
+ - The server will be cleared and returned to its initial value.
+ !*/
+
+ bool is_running (
+ ) const;
+ /*!
+ ensures
+ - returns true if start() is running
+ - returns false if start() is not running or has released all
+ its resources and is about to terminate
+ throws
+ - std::bad_alloc
+ !*/
+
+ int get_max_connections (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of connections the server will accept
+ at a time.
+ - returns 0 if the server will accept any number of connections
+ throws
+ - std::bad_alloc
+ !*/
+
+
+ const std::string get_listening_ip (
+ ) const;
+ /*!
+ ensures
+ - returns the local ip to listen for new connections on
+ - returns "" if ALL local ips are to be listened on
+ throws
+ - std::bad_alloc
+ !*/
+
+ int get_listening_port (
+ ) const;
+ /*!
+ ensures
+ - returns the local port number to listen for new connections on
+ - returns 0 if the local port number has not yet been set
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_listening_port (
+ int port
+ );
+ /*!
+ requires
+ - port >= 0
+ - is_running() == false
+ ensures
+ - #get_listening_port() == port
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_listening_ip (
+ const std::string& ip
+ );
+ /*!
+ requires
+ - is_ip_address(ip) == true or ip == ""
+ - is_running() == false
+ ensures
+ - #get_listening_ip() == ip
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_max_connections (
+ int max
+ );
+ /*!
+ requires
+ - max >= 0
+ ensures
+ - #get_max_connections() == max
+ throws
+ - std::bad_alloc
+ !*/
+
+ void set_graceful_close_timeout (
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - #get_graceful_close_timeout() == timeout
+ !*/
+
+ unsigned long get_graceful_close_timeout (
+ ) const;
+ /*!
+ ensures
+ - When on_connect() terminates, it will close the connection using
+ close_gracefully(). This is done so that any data still in the
+ operating system's output buffers gets a chance to be properly
+ transmitted to the remote host. Part of this involves waiting for
+ the remote host to close their end of the connection. Therefore,
+ get_graceful_close_timeout() returns the timeout, in milliseconds,
+ that we wait for the remote host to close their end of the
+ connection. This is the timeout value given to close_gracefully().
+ !*/
+
+ private:
+
+ virtual void on_connect (
+ connection& new_connection
+ )=0;
+ /*!
+ requires
+ - on_connect() is run in its own thread
+ - is_running() == true
+ - the number of current connections < get_max_connection()
+ - new_connection == the new connection to the server which is
+ to be serviced by this call to on_connect()
+ ensures
+ - when new_connection is shutdown() on_connect() will terminate
+ - this function will not call clear()
+ throws
+ - does not throw any exceptions
+ !*/
+
+ // do nothing by default
+ virtual void on_listening_port_assigned (
+ ) {}
+ /*!
+ requires
+ - is called if a listening port of zero was specified and
+ an actual port number has just been assigned to the server
+ ensures
+ - this function will not block
+ - this function will not call clear()
+ throws
+ - does not throw any exceptions
+ !*/
+
+
+ // restricted functions
+ server(server&); // copy constructor
+ server& operator=(server&); // assignment operator
+ };
+
+}
+
+#endif // DLIB_SERVER_KERNEL_ABSTRACT_
+
diff --git a/ml/dlib/dlib/set.h b/ml/dlib/dlib/set.h
new file mode 100644
index 000000000..bec6932a2
--- /dev/null
+++ b/ml/dlib/dlib/set.h
@@ -0,0 +1,74 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEt_
+#define DLIB_SEt_
+
+#include "set/set_kernel_1.h"
+#include "set/set_kernel_c.h"
+
+
+
+#include "binary_search_tree.h"
+
+#include "set/set_compare_1.h"
+
+#include "algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<T>
+ >
+ class set
+ {
+ set() {}
+
+
+
+
+
+ typedef typename binary_search_tree<T,char,mem_manager,compare>::kernel_1a
+ binary_search_tree_1;
+
+ typedef typename binary_search_tree<T,char,mem_manager,compare>::kernel_2a
+ binary_search_tree_2;
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef set_kernel_1<T,binary_search_tree_1,mem_manager>
+ kernel_1a;
+ typedef set_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ // kernel_1b
+ typedef set_kernel_1<T,binary_search_tree_2,mem_manager>
+ kernel_1b;
+ typedef set_kernel_c<kernel_1b>
+ kernel_1b_c;
+
+
+ //---------- extensions ------------
+
+ // compare extensions
+ typedef set_compare_1<kernel_1a>
+ compare_1a;
+ typedef set_compare_1<kernel_1a_c>
+ compare_1a_c;
+
+ typedef set_compare_1<kernel_1b>
+ compare_1b;
+ typedef set_compare_1<kernel_1b_c>
+ compare_1b_c;
+
+ };
+}
+
+#endif // DLIB_SEt_
+
diff --git a/ml/dlib/dlib/set/set_compare_1.h b/ml/dlib/dlib/set/set_compare_1.h
new file mode 100644
index 000000000..d4b5e76ca
--- /dev/null
+++ b/ml/dlib/dlib/set/set_compare_1.h
@@ -0,0 +1,122 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SET_COMPARe_1_
+#define DLIB_SET_COMPARe_1_
+
+#include "set_compare_abstract.h"
+
+#include "../algs.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename set_base
+ >
+ class set_compare_1 : public set_base
+ {
+
+ public:
+
+ bool operator< (
+ const set_compare_1& rhs
+ ) const;
+
+ bool operator== (
+ const set_compare_1& rhs
+ ) const;
+
+ };
+
+
+ template <
+ typename set_base
+ >
+ inline void swap (
+ set_compare_1<set_base>& a,
+ set_compare_1<set_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ bool set_compare_1<set_base>::
+ operator< (
+ const set_compare_1<set_base>& rhs
+ ) const
+ {
+ bool result = false;
+ if (set_base::size() < rhs.size())
+ result = true;
+
+ if (set_base::size() == rhs.size())
+ {
+ rhs.reset();
+ set_base::reset();
+ while (rhs.move_next())
+ {
+ set_base::move_next();
+ if (set_base::element() < rhs.element())
+ {
+ result = true;
+ break;
+ }
+ else if (rhs.element() < set_base::element())
+ {
+ break;
+ }
+ }
+ }
+
+ set_base::reset();
+ rhs.reset();
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ bool set_compare_1<set_base>::
+ operator== (
+ const set_compare_1<set_base>& rhs
+ ) const
+ {
+ bool result = true;
+ if (set_base::size() != rhs.size())
+ result = false;
+
+
+ rhs.reset();
+ set_base::reset();
+ while (rhs.move_next() && set_base::move_next())
+ {
+ if (!(rhs.element() == set_base::element()))
+ {
+ result = false;
+ break;
+ }
+ }
+
+ set_base::reset();
+ rhs.reset();
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SET_COMPARe_1_
+
diff --git a/ml/dlib/dlib/set/set_compare_abstract.h b/ml/dlib/dlib/set/set_compare_abstract.h
new file mode 100644
index 000000000..ff06218f8
--- /dev/null
+++ b/ml/dlib/dlib/set/set_compare_abstract.h
@@ -0,0 +1,96 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SET_COMPARe_ABSTRACT_
+#ifdef DLIB_SET_COMPARe_ABSTRACT_
+
+#include "set_kernel_abstract.h"
+
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ template <
+ typename set_base
+ >
+ class set_compare : public set_base
+ {
+
+ /*!
+ REQUIREMENTS ON set_base
+ must be an implementation of set/set_kernel_abstract.h
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ operator== and operator< invalidate pointers or references to
+ data members.
+
+ WHAT THIS EXTENSION DOES FOR set
+ This gives a set the ability to compare itself to other
+ sets using the < and == operators.
+
+ The < operator is conceptually weird for sets. It is useful
+ though because it allows you to make sets of sets since
+ sets require that their containing type implement operator<.
+
+ Also note that it is the case that for any two sets a and b
+ if (a<b) == false and (b<a) == false then a == b.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+
+ NOTATION
+ For the purposes of defining what these operators do I will
+ use the operator[] to reference the elements of the sets.
+ operator[] is defined to access the elements of the set in
+ the same order they would be enumerated by the enumerable
+ interface.
+ !*/
+
+ public:
+
+ bool operator< (
+ const set_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - #at_start() == true
+ - if (size() < rhs.size()) then
+ - returns true
+ - else if (size() > rhs.size()) then
+ - returns false
+ - else
+ - returns true if there exists an integer j such that 0 <= j < size()
+ and for all integers i such that 0 <= i < j where it is true that
+ (*this)[i] == rhs[i] and (*this)[j] < rhs[j]
+ - returns false if there is no j that will satisfy the above conditions.
+ !*/
+
+ bool operator== (
+ const set_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - #at_start() == true
+ - returns true if *this and rhs contain the same elements.
+ returns false otherwise.
+ !*/
+ };
+
+
+ template <
+ typename set_base
+ >
+ inline void swap (
+ set_compare<set_base>& a,
+ set_compare<set_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_SET_COMPARe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/set/set_kernel_1.h b/ml/dlib/dlib/set/set_kernel_1.h
new file mode 100644
index 000000000..9df96e671
--- /dev/null
+++ b/ml/dlib/dlib/set/set_kernel_1.h
@@ -0,0 +1,372 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SET_KERNEl_1_
+#define DLIB_SET_KERNEl_1_
+
+#include "set_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager = default_memory_manager
+ >
+ class set_kernel_1 : public enumerable<const T>,
+ public asc_remover<T,typename bst_base::compare_type>
+ {
+
+ /*!
+ REQUIREMENTS ON bst_base
+ bst_base is instantiated with <domain=T,range=char> and
+ implements binray_search_tree/binary_search_tree_kernel_abstract.h
+
+ INITIAL VALUE
+ bst has its initial value
+
+ CONVENTION
+ bst.size() == the number of elements in the set and
+ the elements in the set are stored in bst
+ !*/
+
+
+ public:
+
+ typedef T type;
+ typedef typename bst_base::compare_type compare_type;
+ typedef mem_manager mem_manager_type;
+
+ set_kernel_1(
+ )
+ {
+ }
+
+ virtual ~set_kernel_1(
+ )
+ {}
+
+ inline void clear(
+ );
+
+ inline void add (
+ T& item
+ );
+
+ inline bool is_member (
+ const T& item
+ ) const;
+
+ inline void remove (
+ const T& item,
+ T& item_copy
+ );
+
+ inline void destroy (
+ const T& item
+ );
+
+ inline void swap (
+ set_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+
+ inline const T& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+
+ private:
+
+ bst_base bst;
+ char junk;
+
+ // restricted functions
+ set_kernel_1(set_kernel_1&);
+ set_kernel_1& operator=(set_kernel_1&);
+
+ };
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ inline void swap (
+ set_kernel_1<T,bst_base,mem_manager>& a,
+ set_kernel_1<T,bst_base,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void deserialize (
+ set_kernel_1<T,bst_base,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ item.add(temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type set_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ clear (
+ )
+ {
+ bst.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ add (
+ T& item
+ )
+ {
+ bst.add(item,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool set_kernel_1<T,bst_base,mem_manager>::
+ is_member(
+ const T& item
+ ) const
+ {
+ return (bst[item] != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ bst.remove_any(item,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ remove(
+ const T& item,
+ T& item_copy
+ )
+ {
+ bst.remove(item,item_copy,junk);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ destroy(
+ const T& item
+ )
+ {
+ bst.destroy(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ size_t set_kernel_1<T,bst_base,mem_manager>::
+ size (
+ ) const
+ {
+ return bst.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ swap (
+ set_kernel_1<T,bst_base,mem_manager>& item
+ )
+ {
+ bst.swap(item.bst);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool set_kernel_1<T,bst_base,mem_manager>::
+ at_start (
+ ) const
+ {
+ return bst.at_start();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ void set_kernel_1<T,bst_base,mem_manager>::
+ reset (
+ ) const
+ {
+ bst.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool set_kernel_1<T,bst_base,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return bst.current_element_valid();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ const T& set_kernel_1<T,bst_base,mem_manager>::
+ element (
+ ) const
+ {
+ return bst.element().key();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ const T& set_kernel_1<T,bst_base,mem_manager>::
+ element (
+ )
+ {
+ return bst.element().key();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename bst_base,
+ typename mem_manager
+ >
+ bool set_kernel_1<T,bst_base,mem_manager>::
+ move_next (
+ ) const
+ {
+ return bst.move_next();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SET_KERNEl_1_
+
diff --git a/ml/dlib/dlib/set/set_kernel_abstract.h b/ml/dlib/dlib/set/set_kernel_abstract.h
new file mode 100644
index 000000000..72a0a98b3
--- /dev/null
+++ b/ml/dlib/dlib/set/set_kernel_abstract.h
@@ -0,0 +1,192 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SET_KERNEl_ABSTRACT_
+#ifdef DLIB_SET_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager,
+ typename compare = std::less<T>
+ >
+ class set : public enumerable<const T>,
+ public asc_remover<T,compare>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be comparable by compare where compare is a functor compatible with std::less and
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and is_member() functions do not invalidate pointers
+ or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the set in
+ ascending order according to the compare functor.
+ (i.e. the elements are enumerated in sorted order)
+
+ WHAT THIS OBJECT REPRESENTS
+ set contains items of type T
+
+ This object represents an unaddressed collection of items.
+ Every element in a set is unique.
+
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef compare compare_type;
+ typedef mem_manager mem_manager_type;
+
+ set(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ virtual ~set(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void add (
+ T& item
+ );
+ /*!
+ requires
+ - is_member(item) == false
+ ensures
+ - #is_member(item) == true
+ - #item has an initial value for its type
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if add() throws then it has no effect
+ !*/
+
+ bool is_member (
+ const T& item
+ ) const;
+ /*!
+ ensures
+ - returns whether or not there is an element in *this equivalent to
+ item
+ !*/
+
+ void remove (
+ const T& item,
+ T& item_copy
+ );
+ /*!
+ requires
+ - is_member(item) == true
+ - &item != &item_copy (i.e. item and item_copy cannot be the same
+ variable)
+ ensures
+ - #is_member(item) == false
+ - the element in *this equivalent to item has been removed and
+ swapped into #item_copy
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void destroy (
+ const T& item
+ );
+ /*!
+ requires
+ - is_member(item) == true
+ ensures
+ - #is_member(item) == false
+ - #size() == size() - 1
+ - #at_start() == true
+ !*/
+
+ void swap (
+ set& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ set(set&); // copy constructor
+ set& operator=(set&); // assignment operator
+
+ };
+
+ template <
+ typename T,
+ typename mem_manager,
+ typename compare
+ >
+ inline void swap (
+ set<T,mem_manager,compare>& a,
+ set<T,mem_manager,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager,
+ typename compare
+ >
+ void deserialize (
+ set<T,mem_manager,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_SET_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/set/set_kernel_c.h b/ml/dlib/dlib/set/set_kernel_c.h
new file mode 100644
index 000000000..679f9fae4
--- /dev/null
+++ b/ml/dlib/dlib/set/set_kernel_c.h
@@ -0,0 +1,194 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SET_KERNEl_C_
+#define DLIB_SET_KERNEl_C_
+
+#include "set_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename set_base
+ >
+ class set_kernel_c : public set_base
+ {
+ typedef typename set_base::type T;
+ public:
+
+ void add (
+ T& item
+ );
+
+ void remove_any (
+ T& item
+ );
+
+ void remove (
+ const T& item,
+ T& item_copy
+ );
+
+ void destroy (
+ const T& item
+ );
+
+ const T& element (
+ );
+
+ const T& element (
+ ) const;
+ };
+
+
+ template <
+ typename set_base
+ >
+ inline void swap (
+ set_kernel_c<set_base>& a,
+ set_kernel_c<set_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ void set_kernel_c<set_base>::
+ add(
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( !this->is_member(item),
+ "\tvoid set::add"
+ << "\n\titem being added must not already be in the set"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ set_base::add(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ void set_kernel_c<set_base>::
+ remove (
+ const T& item,
+ T& item_copy
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_member(item) &&
+ (static_cast<const void*>(&item) != static_cast<void*>(&item_copy)),
+ "\tvoid set::remove"
+ << "\n\titem should be in the set if it's going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&item: " << &item
+ << "\n\t&item_copy: " << &item_copy
+ << "\n\tis_member(item): " << (this->is_member(item)?"true":"false")
+ );
+
+ // call the real function
+ set_base::remove(item,item_copy);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ void set_kernel_c<set_base>::
+ destroy (
+ const T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->is_member(item),
+ "\tvoid set::destroy"
+ << "\n\titem should be in the set if it's going to be removed"
+ << "\n\tthis: " << this
+ << "\n\t&item: " << &item
+ );
+
+ // call the real function
+ set_base::destroy(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ void set_kernel_c<set_base>::
+ remove_any (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->size() != 0,
+ "\tvoid set::remove_any"
+ << "\n\tsize must be greater than zero if an item is to be removed"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ set_base::remove_any(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ const typename set_base::type& set_kernel_c<set_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& set::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ const typename set_base::type& set_kernel_c<set_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& set::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_SET_KERNEl_C_
+
diff --git a/ml/dlib/dlib/set_utils.h b/ml/dlib/dlib/set_utils.h
new file mode 100644
index 000000000..a4ee5d3da
--- /dev/null
+++ b/ml/dlib/dlib/set_utils.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SET_UTILs_H_
+#define DLIB_SET_UTILs_H_
+
+#include "set_utils/set_utils.h"
+
+#endif // DLIB_SET_UTILs_H_
+
+
+
diff --git a/ml/dlib/dlib/set_utils/set_utils.h b/ml/dlib/dlib/set_utils/set_utils.h
new file mode 100644
index 000000000..c47b5ccd7
--- /dev/null
+++ b/ml/dlib/dlib/set_utils/set_utils.h
@@ -0,0 +1,246 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SET_UTILs_
+#define DLIB_SET_UTILs_
+
+#include "../algs.h"
+#include "set_utils_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ unsigned long set_intersection_size (
+ const T& a,
+ const U& b
+ )
+ {
+ if (is_same_object(a,b))
+ return a.size();
+
+ unsigned long num = 0;
+
+ if (a.size() < b.size())
+ {
+ a.reset();
+ while (a.move_next())
+ {
+ if (b.is_member(a.element()))
+ ++num;
+ }
+ }
+ else
+ {
+ b.reset();
+ while (b.move_next())
+ {
+ if (a.is_member(b.element()))
+ ++num;
+ }
+ }
+
+ return num;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_union (
+ const T& a,
+ const U& b,
+ V& u
+ )
+ {
+ typedef typename T::type type;
+ if (is_same_object(a,u) || is_same_object(b,u))
+ {
+ V local_u;
+ type temp;
+ a.reset();
+ while (a.move_next())
+ {
+ temp = a.element();
+ local_u.add(temp);
+ }
+
+ b.reset();
+ while (b.move_next())
+ {
+ if (a.is_member(b.element()) == false)
+ {
+ temp = b.element();
+ local_u.add(temp);
+ }
+ }
+
+ local_u.swap(u);
+ }
+ else
+ {
+ u.clear();
+
+ type temp;
+ a.reset();
+ while (a.move_next())
+ {
+ temp = a.element();
+ u.add(temp);
+ }
+
+ b.reset();
+ while (b.move_next())
+ {
+ if (a.is_member(b.element()) == false)
+ {
+ temp = b.element();
+ u.add(temp);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_intersection (
+ const T& a,
+ const U& b,
+ V& i
+ )
+ {
+ typedef typename T::type type;
+ if (is_same_object(a,i) || is_same_object(b,i))
+ {
+ V local_i;
+
+ type temp;
+
+ if (a.size() < b.size())
+ {
+ a.reset();
+ while (a.move_next())
+ {
+ if (b.is_member(a.element()))
+ {
+ temp = a.element();
+ local_i.add(temp);
+ }
+ }
+ }
+ else
+ {
+ b.reset();
+ while (b.move_next())
+ {
+ if (a.is_member(b.element()))
+ {
+ temp = b.element();
+ local_i.add(temp);
+ }
+ }
+ }
+
+ local_i.swap(i);
+ }
+ else
+ {
+ i.clear();
+ type temp;
+
+ if (a.size() < b.size())
+ {
+ a.reset();
+ while (a.move_next())
+ {
+ if (b.is_member(a.element()))
+ {
+ temp = a.element();
+ i.add(temp);
+ }
+ }
+ }
+ else
+ {
+ b.reset();
+ while (b.move_next())
+ {
+ if (a.is_member(b.element()))
+ {
+ temp = b.element();
+ i.add(temp);
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_difference (
+ const T& a,
+ const U& b,
+ V& d
+ )
+ {
+ typedef typename T::type type;
+ if (is_same_object(a,d) || is_same_object(b,d))
+ {
+ V local_d;
+
+ type temp;
+
+ a.reset();
+ while (a.move_next())
+ {
+ if (b.is_member(a.element()) == false)
+ {
+ temp = a.element();
+ local_d.add(temp);
+ }
+ }
+
+ local_d.swap(d);
+ }
+ else
+ {
+ d.clear();
+ type temp;
+
+ a.reset();
+ while (a.move_next())
+ {
+ if (b.is_member(a.element()) == false)
+ {
+ temp = a.element();
+ d.add(temp);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SET_UTILs_
+
+
+
diff --git a/ml/dlib/dlib/set_utils/set_utils_abstract.h b/ml/dlib/dlib/set_utils/set_utils_abstract.h
new file mode 100644
index 000000000..e1eba6afc
--- /dev/null
+++ b/ml/dlib/dlib/set_utils/set_utils_abstract.h
@@ -0,0 +1,98 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SET_UTILs_ABSTRACT_
+#ifdef DLIB_SET_UTILs_ABSTRACT_
+
+#include "../set.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ unsigned long set_intersection_size (
+ const T& a,
+ const U& b
+ );
+ /*!
+ requires
+ - T and U must both be implementations of set/set_kernel_abstract.h
+ ensures
+ - returns the number of elements that are in both set a and b
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_union (
+ const T& a,
+ const U& b,
+ V& u
+ );
+ /*!
+ requires
+ - T, U, and V must all be implementations of set/set_kernel_abstract.h
+ - the types of objects contained in these sets must be copyable
+ ensures
+ - #u == the union of a and b. That is, u contains all elements
+ of a and all the elements of b.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_intersection (
+ const T& a,
+ const U& b,
+ V& i
+ );
+ /*!
+ requires
+ - T, U, and V must all be implementations of set/set_kernel_abstract.h
+ - the types of objects contained in these sets must be copyable
+ ensures
+ - #i == the intersection of a and b. That is, i contains all elements
+ of a that are also in b.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void set_difference (
+ const T& a,
+ const U& b,
+ V& d
+ );
+ /*!
+ requires
+ - T, U, and V must all be implementations of set/set_kernel_abstract.h
+ - the types of objects contained in these sets must be copyable
+ ensures
+ - #d == the difference of a and b. That is, d contains all elements
+ of a that are NOT in b.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SET_UTILs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/simd.h b/ml/dlib/dlib/simd.h
new file mode 100644
index 000000000..93df6702d
--- /dev/null
+++ b/ml/dlib/dlib/simd.h
@@ -0,0 +1,12 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SIMd_Hh_
+#define DLIB_SIMd_Hh_
+
+#include "simd/simd4f.h"
+#include "simd/simd4i.h"
+#include "simd/simd8f.h"
+#include "simd/simd8i.h"
+
+#endif // DLIB_SIMd_Hh_
+
diff --git a/ml/dlib/dlib/simd/simd4f.h b/ml/dlib/dlib/simd/simd4f.h
new file mode 100644
index 000000000..2bfadd23f
--- /dev/null
+++ b/ml/dlib/dlib/simd/simd4f.h
@@ -0,0 +1,685 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_sIMD4F_Hh_
+#define DLIB_sIMD4F_Hh_
+
+#include "simd_check.h"
+#include "simd4i.h"
+#include <cmath>
+#include <iostream>
+
+namespace dlib
+{
+
+#ifdef DLIB_HAVE_SSE2
+ class simd4f
+ {
+ public:
+ typedef float type;
+
+ inline simd4f() {}
+ inline simd4f(float f) { x = _mm_set1_ps(f); }
+ inline simd4f(float r0, float r1, float r2, float r3) { x = _mm_setr_ps(r0,r1,r2,r3); }
+ inline simd4f(const __m128& val):x(val) {}
+ inline simd4f(const simd4i& val):x(_mm_cvtepi32_ps(val)) {}
+
+ inline simd4f& operator=(const simd4i& val)
+ {
+ x = simd4f(val);
+ return *this;
+ }
+
+ inline simd4f& operator=(const float& val)
+ {
+ x = simd4f(val);
+ return *this;
+ }
+
+ inline simd4f& operator=(const __m128& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator __m128() const { return x; }
+
+ // truncate to 32bit integers
+ inline operator __m128i() const { return _mm_cvttps_epi32(x); }
+
+ inline void load_aligned(const type* ptr) { x = _mm_load_ps(ptr); }
+ inline void store_aligned(type* ptr) const { _mm_store_ps(ptr, x); }
+ inline void load(const type* ptr) { x = _mm_loadu_ps(ptr); }
+ inline void store(type* ptr) const { _mm_storeu_ps(ptr, x); }
+
+ inline unsigned int size() const { return 4; }
+ inline float operator[](unsigned int idx) const
+ {
+ float temp[4];
+ store(temp);
+ return temp[idx];
+ }
+
+ private:
+ __m128 x;
+ };
+
+ class simd4f_bool
+ {
+ public:
+ typedef float type;
+
+ inline simd4f_bool() {}
+ inline simd4f_bool(const __m128& val):x(val) {}
+
+ inline simd4f_bool& operator=(const __m128& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator __m128() const { return x; }
+
+
+ private:
+ __m128 x;
+ };
+
+#elif defined(DLIB_HAVE_VSX)
+
+ class simd4f
+ {
+ typedef union {
+ vector float v;
+ float x[4];
+ } v4f;
+
+ v4f x;
+
+ public:
+ inline simd4f() : x{0,0,0,0} {}
+ inline simd4f(const simd4f& v) : x(v.x) { }
+ inline simd4f(const vector float& v) : x{v} { }
+
+ inline simd4f(const simd4i& v) {
+ x.x[0]=v[0]; x.x[1]=v[1]; x.x[2]=v[2]; x.x[3]=v[3];
+ }
+
+
+ inline simd4f(float f) : x{f,f,f,f} { }
+ inline simd4f(float r0, float r1, float r2, float r3)
+ : x{r0,r1,r2,r3} { }
+
+ inline simd4f& operator=(const simd4f& v) { x = v.x; return *this; }
+ inline simd4f& operator=(const float& v) { *this = simd4f(v); return *this; }
+
+ inline vector float operator() () const { return x.v; }
+ inline float operator[](unsigned int idx) const { return x.x[idx]; }
+
+ inline void load_aligned(const float* ptr) { x.v = vec_ld(0, ptr); }
+ inline void store_aligned(float* ptr) const { vec_st(x.v, 0, ptr); }
+ inline void load(const float* ptr) { x.v = vec_vsx_ld(0, ptr); }
+ inline void store(float* ptr) const { vec_vsx_st(x.v, 0, ptr); }
+
+
+ // truncate to 32bit integers
+ inline operator simd4i::rawarray() const
+ {
+ simd4i::rawarray temp;
+ temp.v.x[0] = x.x[0];
+ temp.v.x[1] = x.x[1];
+ temp.v.x[2] = x.x[2];
+ temp.v.x[3] = x.x[3];
+ return temp;
+ }
+ };
+
+ typedef simd4i simd4f_bool;
+
+#elif defined(DLIB_HAVE_NEON)
+
+ class simd4f
+ {
+ public:
+ typedef float type;
+
+ inline simd4f() {}
+ inline simd4f(float f) { x = vdupq_n_f32(f); }
+ inline simd4f(float r0, float r1, float r2, float r3)
+ {
+ float __attribute__ ((aligned (16))) data[4] = { r0, r1, r2, r3 };
+ x = vld1q_f32(data);
+ }
+ inline simd4f(const float32x4_t& val):x(val) {}
+ inline simd4f(const simd4i& val):x(vcvtq_f32_s32(val)) {}
+
+ inline simd4f& operator=(const simd4i& val)
+ {
+ x = simd4f(val);
+ return *this;
+ }
+
+ inline simd4f& operator=(const float& val)
+ {
+ x = simd4f(val);
+ return *this;
+ }
+
+ inline simd4f& operator=(const float32x4_t& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator float32x4_t() const { return x; }
+
+ // truncate to 32bit integers
+ inline operator int32x4_t() const { return vcvtq_s32_f32(x); }
+
+ inline void load_aligned(const type* ptr) { x = vld1q_f32(ptr); }
+ inline void store_aligned(type* ptr) const { vst1q_f32(ptr, x); }
+ inline void load(const type* ptr) { x = vld1q_f32(ptr); }
+ inline void store(type* ptr) const { vst1q_f32(ptr, x); }
+
+ inline unsigned int size() const { return 4; }
+ inline float operator[](unsigned int idx) const
+ {
+ float temp[4];
+ store(temp);
+ return temp[idx];
+ }
+
+ private:
+ float32x4_t x;
+ };
+
+
+ typedef simd4i simd4f_bool;
+#else
+ class simd4f
+ {
+ public:
+ typedef float type;
+
+ inline simd4f() {}
+ inline simd4f(float f) { x[0]=f; x[1]=f; x[2]=f; x[3]=f; }
+ inline simd4f(float r0, float r1, float r2, float r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;}
+ inline simd4f(const simd4i& val) { x[0]=val[0]; x[1]=val[1]; x[2]=val[2]; x[3]=val[3];}
+
+ // truncate to 32bit integers
+ inline operator simd4i::rawarray() const
+ {
+ simd4i::rawarray temp;
+ temp.a[0] = (int32)x[0];
+ temp.a[1] = (int32)x[1];
+ temp.a[2] = (int32)x[2];
+ temp.a[3] = (int32)x[3];
+ return temp;
+ }
+
+ inline simd4f& operator=(const float& val)
+ {
+ *this = simd4f(val);
+ return *this;
+ }
+
+ inline simd4f& operator=(const simd4i& val)
+ {
+ x[0] = val[0];
+ x[1] = val[1];
+ x[2] = val[2];
+ x[3] = val[3];
+ return *this;
+ }
+
+
+ inline void load_aligned(const type* ptr)
+ {
+ x[0] = ptr[0];
+ x[1] = ptr[1];
+ x[2] = ptr[2];
+ x[3] = ptr[3];
+ }
+
+ inline void store_aligned(type* ptr) const
+ {
+ ptr[0] = x[0];
+ ptr[1] = x[1];
+ ptr[2] = x[2];
+ ptr[3] = x[3];
+ }
+
+ inline void load(const type* ptr)
+ {
+ x[0] = ptr[0];
+ x[1] = ptr[1];
+ x[2] = ptr[2];
+ x[3] = ptr[3];
+ }
+
+ inline void store(type* ptr) const
+ {
+ ptr[0] = x[0];
+ ptr[1] = x[1];
+ ptr[2] = x[2];
+ ptr[3] = x[3];
+ }
+
+ inline unsigned int size() const { return 4; }
+ inline float operator[](unsigned int idx) const { return x[idx]; }
+
+ private:
+ float x[4];
+ };
+
+
+ class simd4f_bool
+ {
+ public:
+ typedef float type;
+
+ inline simd4f_bool() {}
+ inline simd4f_bool(bool r0, bool r1, bool r2, bool r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;}
+
+ inline bool operator[](unsigned int idx) const { return x[idx]; }
+ private:
+ bool x[4];
+ };
+
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<<(std::ostream& out, const simd4f& item)
+ {
+ float temp[4];
+ item.store(temp);
+ out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ")";
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f operator+ (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_add_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_add(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vaddq_f32(lhs, rhs);
+#else
+ return simd4f(lhs[0]+rhs[0],
+ lhs[1]+rhs[1],
+ lhs[2]+rhs[2],
+ lhs[3]+rhs[3]);
+#endif
+ }
+ inline simd4f& operator+= (simd4f& lhs, const simd4f& rhs)
+ { lhs = lhs + rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f operator- (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_sub_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sub(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vsubq_f32(lhs, rhs);
+#else
+ return simd4f(lhs[0]-rhs[0],
+ lhs[1]-rhs[1],
+ lhs[2]-rhs[2],
+ lhs[3]-rhs[3]);
+#endif
+ }
+ inline simd4f& operator-= (simd4f& lhs, const simd4f& rhs)
+ { lhs = lhs - rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f operator* (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_mul_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_mul(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vmulq_f32(lhs, rhs);
+#else
+ return simd4f(lhs[0]*rhs[0],
+ lhs[1]*rhs[1],
+ lhs[2]*rhs[2],
+ lhs[3]*rhs[3]);
+#endif
+ }
+ inline simd4f& operator*= (simd4f& lhs, const simd4f& rhs)
+ { lhs = lhs * rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f operator/ (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_div_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_div(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ float32x4_t reciprocal = vrecpeq_f32(rhs);
+ reciprocal = vmulq_f32(vrecpsq_f32(rhs, reciprocal), reciprocal);
+ reciprocal = vmulq_f32(vrecpsq_f32(rhs, reciprocal), reciprocal);
+ float32x4_t result = vmulq_f32(lhs,reciprocal);
+ return result;
+#else
+ return simd4f(lhs[0]/rhs[0],
+ lhs[1]/rhs[1],
+ lhs[2]/rhs[2],
+ lhs[3]/rhs[3]);
+#endif
+ }
+ inline simd4f& operator/= (simd4f& lhs, const simd4f& rhs)
+ { lhs = lhs / rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator== (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmpeq_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_cmpeq(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vceqq_f32(lhs, rhs);
+#else
+ return simd4f_bool(lhs[0]==rhs[0],
+ lhs[1]==rhs[1],
+ lhs[2]==rhs[2],
+ lhs[3]==rhs[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator!= (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmpneq_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX) || defined(DLIB_HAVE_NEON)
+ return ~(lhs==rhs); // simd4f_bool is simd4i typedef, can use ~
+#else
+ return simd4f_bool(lhs[0]!=rhs[0],
+ lhs[1]!=rhs[1],
+ lhs[2]!=rhs[2],
+ lhs[3]!=rhs[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator< (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmplt_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_cmplt(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vcltq_f32(lhs, rhs);
+#else
+ return simd4f_bool(lhs[0]<rhs[0],
+ lhs[1]<rhs[1],
+ lhs[2]<rhs[2],
+ lhs[3]<rhs[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator> (const simd4f& lhs, const simd4f& rhs)
+ {
+ return rhs < lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator<= (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmple_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_cmple(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vcleq_f32(lhs, rhs);
+#else
+ return simd4f_bool(lhs[0]<=rhs[0],
+ lhs[1]<=rhs[1],
+ lhs[2]<=rhs[2],
+ lhs[3]<=rhs[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f_bool operator>= (const simd4f& lhs, const simd4f& rhs)
+ {
+ return rhs <= lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f min (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_min_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_min(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vminq_f32(lhs, rhs);
+#else
+ return simd4f(std::min(lhs[0],rhs[0]),
+ std::min(lhs[1],rhs[1]),
+ std::min(lhs[2],rhs[2]),
+ std::min(lhs[3],rhs[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f max (const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_max_ps(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_max(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vmaxq_f32(lhs, rhs);
+#else
+ return simd4f(std::max(lhs[0],rhs[0]),
+ std::max(lhs[1],rhs[1]),
+ std::max(lhs[2],rhs[2]),
+ std::max(lhs[3],rhs[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f reciprocal (const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_rcp_ps(item);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_re(item());
+#elif defined(DLIB_HAVE_NEON)
+ float32x4_t estimate = vrecpeq_f32(item);
+ estimate = vmulq_f32(vrecpsq_f32(estimate , item), estimate );
+ estimate = vmulq_f32(vrecpsq_f32(estimate , item), estimate );
+ return estimate ;
+#else
+ return simd4f(1.0f/item[0],
+ 1.0f/item[1],
+ 1.0f/item[2],
+ 1.0f/item[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f reciprocal_sqrt (const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_rsqrt_ps(item);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_rsqrt(item());
+#elif defined(DLIB_HAVE_NEON)
+ float32x4_t estimate = vrsqrteq_f32(item);
+ simd4f estimate2 = vmulq_f32(estimate, item);
+ estimate = vmulq_f32(estimate, vrsqrtsq_f32(estimate2, estimate));
+ return estimate;
+#else
+ return simd4f(1.0f/std::sqrt(item[0]),
+ 1.0f/std::sqrt(item[1]),
+ 1.0f/std::sqrt(item[2]),
+ 1.0f/std::sqrt(item[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline float dot(const simd4f& lhs, const simd4f& rhs);
+ inline float sum(const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return dot(simd4f(1), item);
+#elif defined(DLIB_HAVE_SSE3)
+ simd4f temp = _mm_hadd_ps(item,item);
+ return _mm_cvtss_f32(_mm_hadd_ps(temp,temp));
+#elif defined(DLIB_HAVE_SSE2) && (!defined(_MSC_VER) || _MSC_VER!=1400)
+ simd4f temp = _mm_add_ps(item,_mm_movehl_ps(item,item));
+ simd4f temp2 = _mm_shuffle_ps(temp,temp,1);
+ return _mm_cvtss_f32(_mm_add_ss(temp,temp2));
+#elif defined(DLIB_HAVE_NEON)
+ float32x2_t r = vadd_f32(vget_high_f32(item), vget_low_f32(item));
+ return vget_lane_f32(vpadd_f32(r, r), 0);
+#else
+ return item[0]+item[1]+item[2]+item[3];
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline float dot(const simd4f& lhs, const simd4f& rhs)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_cvtss_f32(_mm_dp_ps(lhs, rhs, 0xff));
+#else
+ return sum(lhs*rhs);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f sqrt(const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_sqrt_ps(item);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sqrt(item());
+#elif defined(DLIB_HAVE_NEON)
+ float32x4_t q_step_0 = vrsqrteq_f32(item);
+ float32x4_t q_step_parm0 = vmulq_f32(item, q_step_0);
+ float32x4_t q_step_result0 = vrsqrtsq_f32(q_step_parm0, q_step_0);
+ float32x4_t q_step_1 = vmulq_f32(q_step_0, q_step_result0);
+ float32x4_t q_step_parm1 = vmulq_f32(item, q_step_1);
+ float32x4_t q_step_result1 = vrsqrtsq_f32(q_step_parm1, q_step_1);
+ float32x4_t q_step_2 = vmulq_f32(q_step_1, q_step_result1);
+ float32x4_t res3 = vmulq_f32(item, q_step_2);
+
+ // normalize sqrt(0)=0
+ uint32x4_t zcomp = vceqq_f32(vdupq_n_f32(0), item);
+ float32x4_t rcorr = vbslq_f32(zcomp, item, res3);
+ return rcorr;
+#else
+ return simd4f(std::sqrt(item[0]),
+ std::sqrt(item[1]),
+ std::sqrt(item[2]),
+ std::sqrt(item[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f ceil(const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_ceil_ps(item);
+#elif defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_NEON)
+ float temp[4];
+ item.store(temp);
+ temp[0] = std::ceil(temp[0]);
+ temp[1] = std::ceil(temp[1]);
+ temp[2] = std::ceil(temp[2]);
+ temp[3] = std::ceil(temp[3]);
+ simd4f temp2;
+ temp2.load(temp);
+ return temp2;
+#elif defined(DLIB_HAVE_VSX)
+ return vec_ceil(item());
+#else
+ return simd4f(std::ceil(item[0]),
+ std::ceil(item[1]),
+ std::ceil(item[2]),
+ std::ceil(item[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4f floor(const simd4f& item)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_floor_ps(item);
+#elif defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_NEON)
+ float temp[4];
+ item.store(temp);
+ temp[0] = std::floor(temp[0]);
+ temp[1] = std::floor(temp[1]);
+ temp[2] = std::floor(temp[2]);
+ temp[3] = std::floor(temp[3]);
+ simd4f temp2;
+ temp2.load(temp);
+ return temp2;
+#elif defined(DLIB_HAVE_VSX)
+ return vec_floor(item());
+#else
+ return simd4f(std::floor(item[0]),
+ std::floor(item[1]),
+ std::floor(item[2]),
+ std::floor(item[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // perform cmp ? a : b
+ inline simd4f select(const simd4f_bool& cmp, const simd4f& a, const simd4f& b)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_blendv_ps(b,a,cmp);
+#elif defined(DLIB_HAVE_SSE2)
+ return _mm_or_ps(_mm_and_ps(cmp,a) , _mm_andnot_ps(cmp,b));
+#elif defined(DLIB_HAVE_NEON)
+ return vbslq_f32(cmp, a, b);
+#else
+ return simd4f(cmp[0]?a[0]:b[0],
+ cmp[1]?a[1]:b[1],
+ cmp[2]?a[2]:b[2],
+ cmp[3]?a[3]:b[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_sIMD4F_Hh_
+
diff --git a/ml/dlib/dlib/simd/simd4i.h b/ml/dlib/dlib/simd/simd4i.h
new file mode 100644
index 000000000..ea33f14a8
--- /dev/null
+++ b/ml/dlib/dlib/simd/simd4i.h
@@ -0,0 +1,566 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_sIMD4I_Hh_
+#define DLIB_sIMD4I_Hh_
+
+#include "simd_check.h"
+#include "../uintn.h"
+
+namespace dlib
+{
+
+#ifdef DLIB_HAVE_SSE2
+ class simd4i
+ {
+ public:
+ typedef int32 type;
+
+ inline simd4i() {}
+ inline simd4i(int32 f) { x = _mm_set1_epi32(f); }
+ inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) { x = _mm_setr_epi32(r0,r1,r2,r3); }
+ inline simd4i(const __m128i& val):x(val) {}
+
+ inline simd4i& operator=(const __m128i& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator __m128i() const { return x; }
+
+ inline void load_aligned(const type* ptr) { x = _mm_load_si128((const __m128i*)ptr); }
+ inline void store_aligned(type* ptr) const { _mm_store_si128((__m128i*)ptr, x); }
+ inline void load(const type* ptr) { x = _mm_loadu_si128((const __m128i*)ptr); }
+ inline void store(type* ptr) const { _mm_storeu_si128((__m128i*)ptr, x); }
+
+ inline unsigned int size() const { return 4; }
+ inline int32 operator[](unsigned int idx) const
+ {
+ int32 temp[4];
+ store(temp);
+ return temp[idx];
+ }
+
+ private:
+ __m128i x;
+ };
+
+#elif defined(DLIB_HAVE_VSX)
+
+ class simd4i
+ {
+ typedef union {
+ vector signed int v;
+ vector bool int b;
+ signed int x[4];
+ } v4i;
+
+ v4i x;
+
+ public:
+ inline simd4i() : x{0,0,0,0} { }
+ inline simd4i(const simd4i& v) : x(v.x) { }
+ inline simd4i(const vector int& v) : x{v} { }
+ inline simd4i(const vector bool int& b) { x.b=b; }
+
+ inline simd4i(int32 f) : x{f,f,f,f} { }
+ inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3)
+ : x{r0,r1,r2,r3} { }
+
+ inline simd4i& operator=(const simd4i& v) { x = v.x; return *this; }
+ inline simd4i& operator=(const int32& v) { *this = simd4i(v); return *this; }
+
+ inline vector signed int operator() () const { return x.v; }
+ inline int32 operator[](unsigned int idx) const { return x.x[idx]; }
+
+ inline vector bool int to_bool() const { return x.b; }
+
+ // intrinsics now seem to use xxpermdi automatically now
+ inline void load_aligned(const int32* ptr) { x.v = vec_ld(0, ptr); }
+ inline void store_aligned(int32* ptr) const { vec_st(x.v, 0, ptr); }
+ inline void load(const int32* ptr) { x.v = vec_vsx_ld(0, ptr); }
+ inline void store(int32* ptr) const { vec_vsx_st(x.v, 0, ptr); }
+
+
+ struct rawarray
+ {
+ v4i v;
+ };
+ inline simd4i(const rawarray& a) : x{a.v} { }
+
+ };
+
+#elif defined(DLIB_HAVE_NEON)
+
+ class simd4i
+ {
+ public:
+ typedef int32 type;
+
+ inline simd4i() {}
+ inline simd4i(int32 f) { x = vdupq_n_s32(f); }
+ inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3)
+ {
+ int32 __attribute__((aligned(16))) data[4] = { r0, r1, r2, r3 };
+ x = vld1q_s32(data);
+ }
+ inline simd4i(const int32x4_t& val):x(val) {}
+
+ inline simd4i& operator=(const int32x4_t& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator int32x4_t() const { return x; }
+ inline operator uint32x4_t() const { return (uint32x4_t)x; }
+
+ inline void load_aligned(const type* ptr) { x = vld1q_s32(ptr); }
+ inline void store_aligned(type* ptr) const { vst1q_s32(ptr, x); }
+ inline void load(const type* ptr) { x = vld1q_s32(ptr); }
+ inline void store(type* ptr) const { vst1q_s32(ptr, x); }
+
+ inline unsigned int size() const { return 4; }
+ inline int32 operator[](unsigned int idx) const
+ {
+ int32 temp[4];
+ store(temp);
+ return temp[idx];
+ }
+
+ private:
+ int32x4_t x;
+ };
+
+#else
+
+ class simd4i
+ {
+ public:
+ typedef int32 type;
+
+ inline simd4i() {}
+ inline simd4i(int32 f) { x[0]=f; x[1]=f; x[2]=f; x[3]=f; }
+ inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;}
+
+ struct rawarray
+ {
+ int32 a[4];
+ };
+ inline simd4i(const rawarray& a) { x[0]=a.a[0]; x[1]=a.a[1]; x[2]=a.a[2]; x[3]=a.a[3]; }
+
+ inline void load_aligned(const type* ptr)
+ {
+ x[0] = ptr[0];
+ x[1] = ptr[1];
+ x[2] = ptr[2];
+ x[3] = ptr[3];
+ }
+
+ inline void store_aligned(type* ptr) const
+ {
+ ptr[0] = x[0];
+ ptr[1] = x[1];
+ ptr[2] = x[2];
+ ptr[3] = x[3];
+ }
+
+ inline void load(const type* ptr)
+ {
+ x[0] = ptr[0];
+ x[1] = ptr[1];
+ x[2] = ptr[2];
+ x[3] = ptr[3];
+ }
+
+ inline void store(type* ptr) const
+ {
+ ptr[0] = x[0];
+ ptr[1] = x[1];
+ ptr[2] = x[2];
+ ptr[3] = x[3];
+ }
+
+ inline unsigned int size() const { return 4; }
+ inline int32 operator[](unsigned int idx) const { return x[idx]; }
+
+ private:
+ int32 x[4];
+ };
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<<(std::ostream& out, const simd4i& item)
+ {
+ int32 temp[4];
+ item.store(temp);
+ out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ")";
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator+ (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_add_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_add(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vaddq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]+rhs[0],
+ lhs[1]+rhs[1],
+ lhs[2]+rhs[2],
+ lhs[3]+rhs[3]);
+#endif
+ }
+ inline simd4i& operator+= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs + rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator- (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_sub_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sub(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vsubq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]-rhs[0],
+ lhs[1]-rhs[1],
+ lhs[2]-rhs[2],
+ lhs[3]-rhs[3]);
+#endif
+ }
+ inline simd4i& operator-= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs - rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator* (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_mullo_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_SSE2)
+ int32 _lhs[4]; lhs.store(_lhs);
+ int32 _rhs[4]; rhs.store(_rhs);
+ return simd4i(_lhs[0]*_rhs[0],
+ _lhs[1]*_rhs[1],
+ _lhs[2]*_rhs[2],
+ _lhs[3]*_rhs[3]);
+#elif defined(DLIB_HAVE_VSX)
+ vector int a = lhs(), b = rhs();
+ asm("vmuluwm %0, %0, %1\n\t" : "+&v" (a) : "v" (b) );
+ return simd4i(a);
+#elif defined(DLIB_HAVE_NEON)
+ return vmulq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]*rhs[0],
+ lhs[1]*rhs[1],
+ lhs[2]*rhs[2],
+ lhs[3]*rhs[3]);
+#endif
+ }
+ inline simd4i& operator*= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs * rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator& (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_and_si128(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_and(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vandq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]&rhs[0],
+ lhs[1]&rhs[1],
+ lhs[2]&rhs[2],
+ lhs[3]&rhs[3]);
+#endif
+ }
+ inline simd4i& operator&= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs & rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator| (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_or_si128(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_or(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vorrq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]|rhs[0],
+ lhs[1]|rhs[1],
+ lhs[2]|rhs[2],
+ lhs[3]|rhs[3]);
+#endif
+ }
+ inline simd4i& operator|= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs | rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator^ (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_xor_si128(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_xor(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return veorq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]^rhs[0],
+ lhs[1]^rhs[1],
+ lhs[2]^rhs[2],
+ lhs[3]^rhs[3]);
+#endif
+ }
+ inline simd4i& operator^= (simd4i& lhs, const simd4i& rhs)
+ { return lhs = lhs ^ rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator~ (const simd4i& lhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_xor_si128(lhs, _mm_set1_epi32(0xFFFFFFFF));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_xor(lhs(), vec_splats(~0));
+#elif defined(DLIB_HAVE_NEON)
+ return vmvnq_s32(lhs);
+#else
+ return simd4i(~lhs[0],
+ ~lhs[1],
+ ~lhs[2],
+ ~lhs[3]);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator<< (const simd4i& lhs, const int& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_sll_epi32(lhs,_mm_cvtsi32_si128(rhs));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sl(lhs(), vec_splats((uint32_t)rhs));
+#elif defined(DLIB_HAVE_NEON)
+ return vshlq_s32(lhs, simd4i(rhs));
+#else
+ return simd4i(lhs[0]<<rhs,
+ lhs[1]<<rhs,
+ lhs[2]<<rhs,
+ lhs[3]<<rhs);
+#endif
+ }
+ inline simd4i& operator<<= (simd4i& lhs, const int& rhs)
+ { return lhs = lhs << rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator>> (const simd4i& lhs, const int& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_sra_epi32(lhs,_mm_cvtsi32_si128(rhs));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sr(lhs(), vec_splats((uint32_t)rhs));
+#elif defined(DLIB_HAVE_NEON)
+ int32 _lhs[4]; lhs.store(_lhs);
+ return simd4i(_lhs[0]>>rhs,
+ _lhs[1]>>rhs,
+ _lhs[2]>>rhs,
+ _lhs[3]>>rhs);
+#else
+ return simd4i(lhs[0]>>rhs,
+ lhs[1]>>rhs,
+ lhs[2]>>rhs,
+ lhs[3]>>rhs);
+#endif
+ }
+ inline simd4i& operator>>= (simd4i& lhs, const int& rhs)
+ { return lhs = lhs >> rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator== (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmpeq_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_cmpeq(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vceqq_s32(lhs,rhs);
+#else
+ return simd4i(lhs[0]==rhs[0] ? 0xFFFFFFFF : 0,
+ lhs[1]==rhs[1] ? 0xFFFFFFFF : 0,
+ lhs[2]==rhs[2] ? 0xFFFFFFFF : 0,
+ lhs[3]==rhs[3] ? 0xFFFFFFFF : 0);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator!= (const simd4i& lhs, const simd4i& rhs)
+ {
+#if defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_VSX) || defined(DLIB_HAVE_NEON)
+ return ~(lhs==rhs);
+#else
+ return simd4i(lhs[0]!=rhs[0] ? 0xFFFFFFFF : 0,
+ lhs[1]!=rhs[1] ? 0xFFFFFFFF : 0,
+ lhs[2]!=rhs[2] ? 0xFFFFFFFF : 0,
+ lhs[3]!=rhs[3] ? 0xFFFFFFFF : 0);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator< (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return _mm_cmplt_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_VSX)
+ return vec_cmplt(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vcltq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]<rhs[0] ? 0xFFFFFFFF : 0,
+ lhs[1]<rhs[1] ? 0xFFFFFFFF : 0,
+ lhs[2]<rhs[2] ? 0xFFFFFFFF : 0,
+ lhs[3]<rhs[3] ? 0xFFFFFFFF : 0);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator> (const simd4i& lhs, const simd4i& rhs)
+ {
+ return rhs < lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator<= (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE2
+ return ~(lhs > rhs);
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vcleq_s32(lhs, rhs);
+#else
+ return simd4i(lhs[0]<=rhs[0] ? 0xFFFFFFFF : 0,
+ lhs[1]<=rhs[1] ? 0xFFFFFFFF : 0,
+ lhs[2]<=rhs[2] ? 0xFFFFFFFF : 0,
+ lhs[3]<=rhs[3] ? 0xFFFFFFFF : 0);
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i operator>= (const simd4i& lhs, const simd4i& rhs)
+ {
+ return rhs <= lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i min (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_min_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_SSE2)
+ int32 _lhs[4]; lhs.store(_lhs);
+ int32 _rhs[4]; rhs.store(_rhs);
+ return simd4i(std::min(_lhs[0],_rhs[0]),
+ std::min(_lhs[1],_rhs[1]),
+ std::min(_lhs[2],_rhs[2]),
+ std::min(_lhs[3],_rhs[3]));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_min(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return (int32x4_t)vminq_s32(lhs, rhs);
+#else
+ return simd4i(std::min(lhs[0],rhs[0]),
+ std::min(lhs[1],rhs[1]),
+ std::min(lhs[2],rhs[2]),
+ std::min(lhs[3],rhs[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd4i max (const simd4i& lhs, const simd4i& rhs)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_max_epi32(lhs, rhs);
+#elif defined(DLIB_HAVE_SSE2)
+ int32 _lhs[4]; lhs.store(_lhs);
+ int32 _rhs[4]; rhs.store(_rhs);
+ return simd4i(std::max(_lhs[0],_rhs[0]),
+ std::max(_lhs[1],_rhs[1]),
+ std::max(_lhs[2],_rhs[2]),
+ std::max(_lhs[3],_rhs[3]));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_max(lhs(), rhs());
+#elif defined(DLIB_HAVE_NEON)
+ return vmaxq_s32(lhs, rhs);
+#else
+ return simd4i(std::max(lhs[0],rhs[0]),
+ std::max(lhs[1],rhs[1]),
+ std::max(lhs[2],rhs[2]),
+ std::max(lhs[3],rhs[3]));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline int32 sum(const simd4i& item)
+ {
+#ifdef DLIB_HAVE_SSE3
+ simd4i temp = _mm_hadd_epi32(item,item);
+ temp = _mm_hadd_epi32(temp,temp);
+ return _mm_cvtsi128_si32(temp);
+#elif defined(DLIB_HAVE_SSE2)
+ int32 temp[4];
+ item.store(temp);
+ return temp[0]+temp[1]+temp[2]+temp[3];
+#elif defined(DLIB_HAVE_NEON)
+ int32x2_t r = vadd_s32(vget_high_s32(item), vget_low_s32(item));
+ return vget_lane_s32(vpadd_s32(r, r), 0);
+#else
+ return item[0]+item[1]+item[2]+item[3];
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // perform cmp ? a : b
+ inline simd4i select(const simd4i& cmp, const simd4i& a, const simd4i& b)
+ {
+#ifdef DLIB_HAVE_SSE41
+ return _mm_blendv_epi8(b,a,cmp);
+#elif defined(DLIB_HAVE_SSE2)
+ return ((cmp&a) | _mm_andnot_si128(cmp,b));
+#elif defined(DLIB_HAVE_VSX)
+ return vec_sel(b(), a(), cmp.to_bool());
+#elif defined(DLIB_HAVE_NEON)
+ return vbslq_s32(cmp, a, b);
+#else
+ return ((cmp&a) | (~cmp&b));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_sIMD4I_Hh_
+
diff --git a/ml/dlib/dlib/simd/simd8f.h b/ml/dlib/dlib/simd/simd8f.h
new file mode 100644
index 000000000..628ba74ee
--- /dev/null
+++ b/ml/dlib/dlib/simd/simd8f.h
@@ -0,0 +1,402 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_sIMD8F_Hh_
+#define DLIB_sIMD8F_Hh_
+
+#include "simd_check.h"
+#include "simd4f.h"
+#include "simd8i.h"
+
+namespace dlib
+{
+#ifdef DLIB_HAVE_AVX
+ class simd8f
+ {
+ public:
+ typedef float type;
+
+ inline simd8f() {}
+ inline simd8f(const simd4f& low, const simd4f& high)
+ {
+ x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1);
+ }
+ inline simd8f(float f) { x = _mm256_set1_ps(f); }
+ inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7)
+ { x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); }
+
+ inline simd8f(const simd8i& val):x(_mm256_cvtepi32_ps(val)) {}
+ inline simd8f(const __m256& val):x(val) {}
+ inline simd8f& operator=(const __m256& val)
+ {
+ x = val;
+ return *this;
+ }
+ inline operator __m256() const { return x; }
+
+ // truncate to 32bit integers
+ inline operator __m256i() const { return _mm256_cvttps_epi32(x); }
+
+ inline void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); }
+ inline void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); }
+ inline void load(const type* ptr) { x = _mm256_loadu_ps(ptr); }
+ inline void store(type* ptr) const { _mm256_storeu_ps(ptr, x); }
+
+ inline simd8f& operator=(const simd8i& rhs) { *this = simd8f(rhs); return *this; }
+ inline simd8f& operator=(const float& val)
+ {
+ x = simd8f(val);
+ return *this;
+ }
+
+ inline unsigned int size() const { return 8; }
+ inline float operator[](unsigned int idx) const
+ {
+ float temp[8];
+ store(temp);
+ return temp[idx];
+ }
+
+ inline simd4f low() const { return _mm256_castps256_ps128(x); }
+ inline simd4f high() const { return _mm256_extractf128_ps(x,1); }
+
+ private:
+ __m256 x;
+ };
+
+
+ class simd8f_bool
+ {
+ public:
+ typedef float type;
+
+ inline simd8f_bool() {}
+ inline simd8f_bool(const __m256& val):x(val) {}
+ inline simd8f_bool(const simd4f_bool& low, const simd4f_bool& high)
+ {
+ x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1);
+ }
+
+ inline simd8f_bool& operator=(const __m256& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator __m256() const { return x; }
+
+
+ private:
+ __m256 x;
+ };
+
+#else
+ class simd8f
+ {
+ public:
+ typedef float type;
+
+ inline simd8f() {}
+ inline simd8f(const simd4f& low_, const simd4f& high_): _low(low_),_high(high_){}
+ inline simd8f(float f) :_low(f),_high(f) {}
+ inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) :
+ _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {}
+ inline simd8f(const simd8i& val) : _low(val.low()), _high(val.high()) { }
+
+ // truncate to 32bit integers
+ inline operator simd8i::rawarray() const
+ {
+ simd8i::rawarray temp;
+ temp.low = simd4i(_low);
+ temp.high = simd4i(_high);
+ return temp;
+ }
+
+ inline void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); }
+ inline void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); }
+ inline void load(const type* ptr) { _low.load(ptr); _high.load(ptr+4); }
+ inline void store(type* ptr) const { _low.store(ptr); _high.store(ptr+4); }
+
+ inline unsigned int size() const { return 8; }
+ inline float operator[](unsigned int idx) const
+ {
+ if (idx < 4)
+ return _low[idx];
+ else
+ return _high[idx-4];
+ }
+
+ inline const simd4f& low() const { return _low; }
+ inline const simd4f& high() const { return _high; }
+
+ private:
+ simd4f _low, _high;
+ };
+
+ class simd8f_bool
+ {
+ public:
+ typedef float type;
+
+ inline simd8f_bool() {}
+ inline simd8f_bool(const simd4f_bool& low_, const simd4f_bool& high_): _low(low_),_high(high_){}
+
+
+ inline const simd4f_bool& low() const { return _low; }
+ inline const simd4f_bool& high() const { return _high; }
+ private:
+ simd4f_bool _low,_high;
+ };
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<<(std::ostream& out, const simd8f& item)
+ {
+ float temp[8];
+ item.store(temp);
+ out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ", "
+ << temp[4] << ", " << temp[5] << ", " << temp[6] << ", " << temp[7] << ")";
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f operator+ (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_add_ps(lhs, rhs);
+#else
+ return simd8f(lhs.low()+rhs.low(),
+ lhs.high()+rhs.high());
+#endif
+ }
+ inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs)
+ { lhs = lhs + rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f operator- (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_sub_ps(lhs, rhs);
+#else
+ return simd8f(lhs.low()-rhs.low(),
+ lhs.high()-rhs.high());
+#endif
+ }
+ inline simd8f& operator-= (simd8f& lhs, const simd8f& rhs)
+ { lhs = lhs - rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f operator* (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_mul_ps(lhs, rhs);
+#else
+ return simd8f(lhs.low()*rhs.low(),
+ lhs.high()*rhs.high());
+#endif
+ }
+ inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs)
+ { lhs = lhs * rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f operator/ (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_div_ps(lhs, rhs);
+#else
+ return simd8f(lhs.low()/rhs.low(),
+ lhs.high()/rhs.high());
+#endif
+ }
+ inline simd8f& operator/= (simd8f& lhs, const simd8f& rhs)
+ { lhs = lhs / rhs; return lhs; }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator== (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_cmp_ps(lhs, rhs, 0);
+#else
+ return simd8f_bool(lhs.low() ==rhs.low(),
+ lhs.high()==rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator!= (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_cmp_ps(lhs, rhs, 4);
+#else
+ return simd8f_bool(lhs.low() !=rhs.low(),
+ lhs.high()!=rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator< (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_cmp_ps(lhs, rhs, 1);
+#else
+ return simd8f_bool(lhs.low() <rhs.low(),
+ lhs.high()<rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator> (const simd8f& lhs, const simd8f& rhs)
+ {
+ return rhs < lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator<= (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_cmp_ps(lhs, rhs, 2);
+#else
+ return simd8f_bool(lhs.low() <=rhs.low(),
+ lhs.high()<=rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f_bool operator>= (const simd8f& lhs, const simd8f& rhs)
+ {
+ return rhs <= lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f min (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_min_ps(lhs, rhs);
+#else
+ return simd8f(min(lhs.low(), rhs.low()),
+ min(lhs.high(),rhs.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f max (const simd8f& lhs, const simd8f& rhs)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_max_ps(lhs, rhs);
+#else
+ return simd8f(max(lhs.low(), rhs.low()),
+ max(lhs.high(),rhs.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f reciprocal (const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_rcp_ps(item);
+#else
+ return simd8f(reciprocal(item.low()),
+ reciprocal(item.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f reciprocal_sqrt (const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_rsqrt_ps(item);
+#else
+ return simd8f(reciprocal_sqrt(item.low()),
+ reciprocal_sqrt(item.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline float sum(const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ simd8f temp = _mm256_hadd_ps(item,item);
+ simd8f temp2 = _mm256_hadd_ps(temp,temp);
+ return _mm_cvtss_f32(_mm_add_ss(_mm256_castps256_ps128(temp2),_mm256_extractf128_ps(temp2,1)));
+#else
+ return sum(item.low()+item.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline float dot(const simd8f& lhs, const simd8f& rhs)
+ {
+ return sum(lhs*rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f sqrt(const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_sqrt_ps(item);
+#else
+ return simd8f(sqrt(item.low()),
+ sqrt(item.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f ceil(const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_ceil_ps(item);
+#else
+ return simd8f(ceil(item.low()),
+ ceil(item.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8f floor(const simd8f& item)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_floor_ps(item);
+#else
+ return simd8f(floor(item.low()),
+ floor(item.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // perform cmp ? a : b
+ inline simd8f select(const simd8f_bool& cmp, const simd8f& a, const simd8f& b)
+ {
+#ifdef DLIB_HAVE_AVX
+ return _mm256_blendv_ps(b,a,cmp);
+#else
+ return simd8f(select(cmp.low(), a.low(), b.low()),
+ select(cmp.high(), a.high(), b.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_sIMD8F_Hh_
+
diff --git a/ml/dlib/dlib/simd/simd8i.h b/ml/dlib/dlib/simd/simd8i.h
new file mode 100644
index 000000000..18c06ec7e
--- /dev/null
+++ b/ml/dlib/dlib/simd/simd8i.h
@@ -0,0 +1,339 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_sIMD8I_Hh_
+#define DLIB_sIMD8I_Hh_
+
+#include "simd_check.h"
+#include "../uintn.h"
+
+namespace dlib
+{
+
+#ifdef DLIB_HAVE_AVX
+ class simd8i
+ {
+ public:
+ typedef int32 type;
+
+ inline simd8i() {}
+ inline simd8i(int32 f) { x = _mm256_set1_epi32(f); }
+ inline simd8i(int32 r0, int32 r1, int32 r2, int32 r3,
+ int32 r4, int32 r5, int32 r6, int32 r7 )
+ { x = _mm256_setr_epi32(r0,r1,r2,r3,r4,r5,r6,r7); }
+
+ inline simd8i(const __m256i& val):x(val) {}
+
+ inline simd8i(const simd4i& low, const simd4i& high)
+ {
+ x = _mm256_insertf128_si256(_mm256_castsi128_si256(low),high,1);
+ }
+
+ inline simd8i& operator=(const __m256i& val)
+ {
+ x = val;
+ return *this;
+ }
+
+ inline operator __m256i() const { return x; }
+
+ inline void load_aligned(const type* ptr) { x = _mm256_load_si256((const __m256i*)ptr); }
+ inline void store_aligned(type* ptr) const { _mm256_store_si256((__m256i*)ptr, x); }
+ inline void load(const type* ptr) { x = _mm256_loadu_si256((const __m256i*)ptr); }
+ inline void store(type* ptr) const { _mm256_storeu_si256((__m256i*)ptr, x); }
+
+ inline simd4i low() const { return _mm256_castsi256_si128(x); }
+ inline simd4i high() const { return _mm256_extractf128_si256(x,1); }
+
+ inline unsigned int size() const { return 8; }
+ inline int32 operator[](unsigned int idx) const
+ {
+ int32 temp[8];
+ store(temp);
+ return temp[idx];
+ }
+
+ private:
+ __m256i x;
+ };
+#else
+ class simd8i
+ {
+ public:
+ typedef int32 type;
+
+ inline simd8i() {}
+ inline simd8i(const simd4i& low_, const simd4i& high_): _low(low_),_high(high_){}
+ inline simd8i(int32 f) :_low(f),_high(f) {}
+ inline simd8i(int32 r0, int32 r1, int32 r2, int32 r3, int32 r4, int32 r5, int32 r6, int32 r7) :
+ _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {}
+
+ struct rawarray
+ {
+ simd4i low, high;
+ };
+ inline simd8i(const rawarray& a)
+ {
+ _low = a.low;
+ _high = a.high;
+ }
+
+ inline void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); }
+ inline void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); }
+ inline void load(const type* ptr) { _low.load(ptr); _high.load(ptr+4); }
+ inline void store(type* ptr) const { _low.store(ptr); _high.store(ptr+4); }
+
+ inline unsigned int size() const { return 8; }
+ inline int32 operator[](unsigned int idx) const
+ {
+ if (idx < 4)
+ return _low[idx];
+ else
+ return _high[idx-4];
+ }
+
+ inline const simd4i& low() const { return _low; }
+ inline const simd4i& high() const { return _high; }
+
+ private:
+ simd4i _low, _high;
+ };
+
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::ostream& operator<<(std::ostream& out, const simd8i& item)
+ {
+ int32 temp[8];
+ item.store(temp);
+ out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ", "
+ << temp[4] << ", " << temp[5] << ", " << temp[6] << ", " << temp[7] << ")";
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator+ (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_add_epi32(lhs, rhs);
+#else
+ return simd8i(lhs.low()+rhs.low(),
+ lhs.high()+rhs.high());
+#endif
+ }
+ inline simd8i& operator+= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs + rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator- (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_sub_epi32(lhs, rhs);
+#else
+ return simd8i(lhs.low()-rhs.low(),
+ lhs.high()-rhs.high());
+#endif
+ }
+ inline simd8i& operator-= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs - rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator* (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_mullo_epi32(lhs, rhs);
+#else
+ return simd8i(lhs.low()*rhs.low(),
+ lhs.high()*rhs.high());
+#endif
+ }
+ inline simd8i& operator*= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs * rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator& (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_and_si256(lhs, rhs);
+#else
+ return simd8i(lhs.low()&rhs.low(),
+ lhs.high()&rhs.high());
+#endif
+ }
+ inline simd8i& operator&= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs & rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator| (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_or_si256(lhs, rhs);
+#else
+ return simd8i(lhs.low()|rhs.low(),
+ lhs.high()|rhs.high());
+#endif
+ }
+ inline simd8i& operator|= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs | rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator^ (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_xor_si256(lhs, rhs);
+#else
+ return simd8i(lhs.low()^rhs.low(),
+ lhs.high()^rhs.high());
+#endif
+ }
+ inline simd8i& operator^= (simd8i& lhs, const simd8i& rhs)
+ { return lhs = lhs ^ rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator~ (const simd8i& lhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_xor_si256(lhs, _mm256_set1_epi32(0xFFFFFFFF));
+#else
+ return simd8i(~lhs.low(), ~lhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator<< (const simd8i& lhs, const int& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_sll_epi32(lhs,_mm_cvtsi32_si128(rhs));
+#else
+ return simd8i(lhs.low()<<rhs,
+ lhs.high()<<rhs);
+#endif
+ }
+ inline simd8i& operator<<= (simd8i& lhs, const int& rhs)
+ { return lhs = lhs << rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator>> (const simd8i& lhs, const int& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_sra_epi32(lhs,_mm_cvtsi32_si128(rhs));
+#else
+ return simd8i(lhs.low()>>rhs,
+ lhs.high()>>rhs);
+#endif
+ }
+ inline simd8i& operator>>= (simd8i& lhs, const int& rhs)
+ { return lhs = lhs >> rhs; return lhs;}
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator== (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_cmpeq_epi32(lhs, rhs);
+#else
+ return simd8i(lhs.low()==rhs.low(),
+ lhs.high()==rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator!= (const simd8i& lhs, const simd8i& rhs)
+ {
+ return ~(lhs==rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator> (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_cmpgt_epi32(lhs, rhs);
+#else
+ return simd8i(lhs.low()>rhs.low(),
+ lhs.high()>rhs.high());
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator< (const simd8i& lhs, const simd8i& rhs)
+ {
+ return rhs > lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator<= (const simd8i& lhs, const simd8i& rhs)
+ {
+ return ~(lhs > rhs);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i operator>= (const simd8i& lhs, const simd8i& rhs)
+ {
+ return rhs <= lhs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i min (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_min_epi32(lhs, rhs);
+#else
+ return simd8i(min(lhs.low(),rhs.low()),
+ min(lhs.high(),rhs.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline simd8i max (const simd8i& lhs, const simd8i& rhs)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_max_epi32(lhs, rhs);
+#else
+ return simd8i(max(lhs.low(),rhs.low()),
+ max(lhs.high(),rhs.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline int32 sum(const simd8i& item)
+ {
+ return sum(item.low()+item.high());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // perform cmp ? a : b
+ inline simd8i select(const simd8i& cmp, const simd8i& a, const simd8i& b)
+ {
+#ifdef DLIB_HAVE_AVX2
+ return _mm256_blendv_epi8(b,a,cmp);
+#else
+ return simd8i(select(cmp.low(), a.low(), b.low()),
+ select(cmp.high(), a.high(), b.high()));
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_sIMD8I_Hh_
+
+
diff --git a/ml/dlib/dlib/simd/simd_check.h b/ml/dlib/dlib/simd/simd_check.h
new file mode 100644
index 000000000..c4ca0c3b8
--- /dev/null
+++ b/ml/dlib/dlib/simd/simd_check.h
@@ -0,0 +1,177 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SIMd_CHECK_Hh_
+#define DLIB_SIMd_CHECK_Hh_
+
+#include <array>
+#include <iostream>
+
+//#define DLIB_DO_NOT_USE_SIMD
+
+// figure out which SIMD instructions we can use.
+#ifndef DLIB_DO_NOT_USE_SIMD
+ #if defined(_MSC_VER)
+ #ifdef __AVX__
+ #ifndef DLIB_HAVE_SSE2
+ #define DLIB_HAVE_SSE2
+ #endif
+ #ifndef DLIB_HAVE_SSE3
+ #define DLIB_HAVE_SSE3
+ #endif
+ #ifndef DLIB_HAVE_SSE41
+ #define DLIB_HAVE_SSE41
+ #endif
+ #ifndef DLIB_HAVE_AVX
+ #define DLIB_HAVE_AVX
+ #endif
+ #endif
+ #if (defined( _M_X64) || defined(_M_IX86_FP) && _M_IX86_FP >= 2) && !defined(DLIB_HAVE_SSE2)
+ #define DLIB_HAVE_SSE2
+ #endif
+ #else
+ #ifdef __SSE2__
+ #ifndef DLIB_HAVE_SSE2
+ #define DLIB_HAVE_SSE2
+ #endif
+ #endif
+ #ifdef __SSSE3__
+ #ifndef DLIB_HAVE_SSE3
+ #define DLIB_HAVE_SSE3
+ #endif
+ #endif
+ #ifdef __SSE4_1__
+ #ifndef DLIB_HAVE_SSE41
+ #define DLIB_HAVE_SSE41
+ #endif
+ #endif
+ #ifdef __AVX__
+ #ifndef DLIB_HAVE_AVX
+ #define DLIB_HAVE_AVX
+ #endif
+ #endif
+ #ifdef __AVX2__
+ #ifndef DLIB_HAVE_AVX2
+ #define DLIB_HAVE_AVX2
+ #endif
+ #endif
+ #ifdef __ALTIVEC__
+ #ifndef DLIB_HAVE_ALTIVEC
+ #define DLIB_HAVE_ALTIVEC
+ #endif
+ #endif
+ #ifdef __VSX__
+ #ifndef DLIB_HAVE_VSX
+ #define DLIB_HAVE_VSX
+ #endif
+ #endif
+ #ifdef __VEC__ // __VEC__ = 10206
+ #ifndef DLIB_HAVE_POWER_VEC // vector and vec_ intrinsics
+ #define DLIB_HAVE_POWER_VEC
+ #endif
+ #endif
+ #ifdef __ARM_NEON
+ #ifndef DLIB_HAVE_NEON
+ #define DLIB_HAVE_NEON
+ #endif
+ #endif
+ #endif
+#endif
+
+
+// ----------------------------------------------------------------------------------------
+
+
+#ifdef DLIB_HAVE_ALTIVEC
+#include <altivec.h>
+#endif
+
+#ifdef DLIB_HAVE_SSE2
+ #include <xmmintrin.h>
+ #include <emmintrin.h>
+ #include <mmintrin.h>
+#endif
+#ifdef DLIB_HAVE_SSE3
+ #include <pmmintrin.h> // SSE3
+ #include <tmmintrin.h>
+#endif
+#ifdef DLIB_HAVE_SSE41
+ #include <smmintrin.h> // SSE4
+#endif
+#ifdef DLIB_HAVE_AVX
+ #include <immintrin.h> // AVX
+#endif
+#ifdef DLIB_HAVE_AVX2
+ #include <immintrin.h> // AVX
+// #include <avx2intrin.h>
+#endif
+#ifdef DLIB_HAVE_NEON
+ #include <arm_neon.h> // ARM NEON
+#endif
+
+// ----------------------------------------------------------------------------------------
+// Define functions to check, at runtime, what instructions are available
+
+#if defined(_MSC_VER) && (defined(_M_I86) || defined(_M_IX86) || defined(_M_X64) || defined(_M_AMD64) )
+ #include <intrin.h>
+
+ inline std::array<unsigned int,4> cpuid(int function_id)
+ {
+ std::array<unsigned int,4> info;
+ // Load EAX, EBX, ECX, EDX into info
+ __cpuid((int*)info.data(), function_id);
+ return info;
+ }
+
+#elif (defined(__GNUC__) || defined(__clang__)) && (defined(__i386__) || defined(__i686__) || defined(__amd64__) || defined(__x86_64__))
+ #include <cpuid.h>
+
+ inline std::array<unsigned int,4> cpuid(int function_id)
+ {
+ std::array<unsigned int,4> info;
+ // Load EAX, EBX, ECX, EDX into info
+ __cpuid(function_id, info[0], info[1], info[2], info[3]);
+ return info;
+ }
+
+#else
+
+ inline std::array<unsigned int,4> cpuid(int)
+ {
+ return std::array<unsigned int,4>{};
+ }
+
+#endif
+
+ inline bool cpu_has_sse2_instructions() { return 0!=(cpuid(1)[3]&(1<<26)); }
+ inline bool cpu_has_sse3_instructions() { return 0!=(cpuid(1)[2]&(1<<0)); }
+ inline bool cpu_has_sse41_instructions() { return 0!=(cpuid(1)[2]&(1<<19)); }
+ inline bool cpu_has_sse42_instructions() { return 0!=(cpuid(1)[2]&(1<<20)); }
+ inline bool cpu_has_avx_instructions() { return 0!=(cpuid(1)[2]&(1<<28)); }
+ inline bool cpu_has_avx2_instructions() { return 0!=(cpuid(7)[1]&(1<<5)); }
+ inline bool cpu_has_avx512_instructions() { return 0!=(cpuid(7)[1]&(1<<16)); }
+
+ inline void warn_about_unavailable_but_used_cpu_instructions()
+ {
+#if defined(DLIB_HAVE_AVX2)
+ if (!cpu_has_avx2_instructions())
+ std::cerr << "Dlib was compiled to use AVX2 instructions, but these aren't available on your machine." << std::endl;
+#elif defined(DLIB_HAVE_AVX)
+ if (!cpu_has_avx_instructions())
+ std::cerr << "Dlib was compiled to use AVX instructions, but these aren't available on your machine." << std::endl;
+#elif defined(DLIB_HAVE_SSE41)
+ if (!cpu_has_sse41_instructions())
+ std::cerr << "Dlib was compiled to use SSE41 instructions, but these aren't available on your machine." << std::endl;
+#elif defined(DLIB_HAVE_SSE3)
+ if (!cpu_has_sse3_instructions())
+ std::cerr << "Dlib was compiled to use SSE3 instructions, but these aren't available on your machine." << std::endl;
+#elif defined(DLIB_HAVE_SSE2)
+ if (!cpu_has_sse2_instructions())
+ std::cerr << "Dlib was compiled to use SSE2 instructions, but these aren't available on your machine." << std::endl;
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_SIMd_CHECK_Hh_
+
+
diff --git a/ml/dlib/dlib/sliding_buffer.h b/ml/dlib/dlib/sliding_buffer.h
new file mode 100644
index 000000000..fb89e1b00
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer.h
@@ -0,0 +1,38 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SLIDING_BUFFEr_
+#define DLIB_SLIDING_BUFFEr_
+
+
+#include "sliding_buffer/sliding_buffer_kernel_1.h"
+#include "sliding_buffer/sliding_buffer_kernel_c.h"
+#include "sliding_buffer/circular_buffer.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class sliding_buffer
+ {
+
+ sliding_buffer() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef sliding_buffer_kernel_1<T>
+ kernel_1a;
+ typedef sliding_buffer_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ };
+}
+
+#endif // DLIB_SLIDING_BUFFEr_
+
diff --git a/ml/dlib/dlib/sliding_buffer/circular_buffer.h b/ml/dlib/dlib/sliding_buffer/circular_buffer.h
new file mode 100644
index 000000000..4fcc922d6
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer/circular_buffer.h
@@ -0,0 +1,235 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CIRCULAR_BuFFER_Hh_
+#define DLIB_CIRCULAR_BuFFER_Hh_
+
+#include "circular_buffer_abstract.h"
+#include <vector>
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix/matrix_mat.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class circular_buffer
+ {
+ public:
+ typedef default_memory_manager mem_manager_type;
+ typedef T value_type;
+ typedef T type;
+
+ circular_buffer()
+ {
+ }
+
+ explicit circular_buffer(unsigned long s)
+ {
+ resize(s);
+ }
+
+ void clear (
+ )
+ {
+ offset = 0;
+ data.clear();
+ }
+
+ T& operator[] ( unsigned long i)
+ {
+ DLIB_ASSERT(i < size(),
+ "\t T& circular_buffer::operator[](i)"
+ << "\n\t You have supplied an invalid index"
+ << "\n\t this: " << this
+ << "\n\t i: " << i
+ << "\n\t size(): " << size()
+ );
+ return data[(i+offset)%data.size()];
+ }
+
+ const T& operator[] ( unsigned long i) const
+ {
+ DLIB_ASSERT(i < size(),
+ "\t const T& circular_buffer::operator[](i)"
+ << "\n\t You have supplied an invalid index"
+ << "\n\t this: " << this
+ << "\n\t i: " << i
+ << "\n\t size(): " << size()
+ );
+ return data[(i+offset)%data.size()];
+ }
+
+ void resize(unsigned long size)
+ {
+ offset = 0;
+ data.resize(size);
+ }
+
+ void assign(
+ unsigned long size,
+ const T& value
+ )
+ {
+ offset = 0;
+ data.assign(size,value);
+ }
+
+ unsigned long size() const { return data.size(); }
+
+ void push_front(const T& value)
+ {
+ if (data.size() != 0)
+ {
+ offset = (offset - 1 + data.size())%data.size();
+ data[offset] = value;
+ }
+ }
+
+ void push_back(const T& value)
+ {
+ if (data.size() != 0)
+ {
+ data[offset] = value;
+ offset = (offset + 1 + data.size())%data.size();
+ }
+ }
+
+ T& front(
+ )
+ {
+ DLIB_CASSERT(size() > 0,
+ "\t T& circular_buffer::front()"
+ << "\n\t You can't call front() on an empty circular_buffer"
+ << "\n\t this: " << this
+ );
+ return (*this)[0];
+ }
+
+ const T& front(
+ ) const
+ {
+ DLIB_CASSERT(size() > 0,
+ "\t const T& circular_buffer::front()"
+ << "\n\t You can't call front() on an empty circular_buffer"
+ << "\n\t this: " << this
+ );
+ return (*this)[0];
+ }
+
+ T& back(
+ )
+ {
+ DLIB_CASSERT(size() > 0,
+ "\t T& circular_buffer::back()"
+ << "\n\t You can't call back() on an empty circular_buffer"
+ << "\n\t this: " << this
+ );
+ return (*this)[size()-1];
+ }
+
+ const T& back(
+ ) const
+ {
+ DLIB_CASSERT(size() > 0,
+ "\t const T& circular_buffer::back()"
+ << "\n\t You can't call back() on an empty circular_buffer"
+ << "\n\t this: " << this
+ );
+ return (*this)[size()-1];
+ }
+
+ void swap( circular_buffer& item)
+ {
+ std::swap(item.offset, offset);
+ data.swap(item.data);
+ }
+
+
+ private:
+ std::vector<T> data;
+
+ unsigned long offset = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void swap (
+ circular_buffer<T>& a,
+ circular_buffer<T>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void serialize (
+ const circular_buffer<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.size(),out);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i],out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type circular_buffer");
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void deserialize (
+ circular_buffer<T>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ for (unsigned long i = 0; i < size; ++i)
+ deserialize(item[i],in);
+ }
+ catch (serialization_error& e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type circular_buffer");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_array_to_mat<circular_buffer<T> > > mat (
+ const circular_buffer<T>& m
+ )
+ {
+ typedef op_array_to_mat<circular_buffer<T> > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CIRCULAR_BuFFER_Hh_
+
diff --git a/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h b/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h
new file mode 100644
index 000000000..dc9f35c7a
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h
@@ -0,0 +1,257 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_
+#ifdef DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class circular_buffer
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor and be copyable.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap(), size(), front(), back(), and operator[] functions do
+ not invalidate pointers or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a circular buffer of objects of type T. This means
+ that when objects are pushed onto one of its ends it does not grow
+ in size. Instead, it shifts all elements over one to make room for
+ the new element and the element at the opposing end falls off the
+ buffer and is lost.
+ !*/
+
+ public:
+ typedef default_memory_manager mem_manager_type;
+ typedef T value_type;
+ typedef T type;
+
+ circular_buffer(
+ );
+ /*!
+ ensures
+ - #size() == 0
+ - this object is properly initialized
+ !*/
+
+ explicit circular_buffer(
+ unsigned long s
+ );
+ /*!
+ ensures
+ - #size() == s
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ - #size() == 0
+ !*/
+
+ T& operator[] (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - i < size()
+ ensures
+ - returns a non-const reference to the i-th element of this circular buffer
+ !*/
+
+ const T& operator[] (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - i < size()
+ ensures
+ - returns a const reference to the i-th element of this circular buffer
+ !*/
+
+ void resize(
+ unsigned long new_size
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ !*/
+
+ void assign(
+ unsigned long new_size,
+ const T& value
+ );
+ /*!
+ ensures
+ - #size() == new_size
+ - for all valid i:
+ - (*this)[i] == value
+ !*/
+
+ unsigned long size(
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements in this circular buffer
+ !*/
+
+ T& front(
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a reference to (*this)[0]
+ !*/
+
+ const T& front(
+ ) const;
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a const reference to (*this)[0]
+ !*/
+
+ T& back(
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a reference to (*this)[size()-1]
+ !*/
+
+ const T& back(
+ ) const;
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a const reference to (*this)[size()-1]
+ !*/
+
+ void push_front(
+ const T& value
+ );
+ /*!
+ ensures
+ - #size() == size()
+ (i.e. the size of this object does not change)
+ - if (size() != 0) then
+ - #front() == value
+ - all items are shifted over such that,
+ - #(*this)[1] == (*this)[0]
+ - #(*this)[2] == (*this)[1]
+ - #(*this)[3] == (*this)[2]
+ - etc.
+ - back() is shifted out of the circular buffer
+ - else
+ - This function has no effect on this object
+ !*/
+
+ void push_back(
+ const T& value
+ );
+ /*!
+ ensures
+ - #size() == size()
+ (i.e. the size of this object does not change)
+ - if (size() != 0) then
+ - #back() == value
+ - all items are shifted over such that,
+ - front() is shifted out of the circular buffer
+ - #(*this)[0] == (*this)[1]
+ - #(*this)[1] == (*this)[2]
+ - #(*this)[2] == (*this)[3]
+ - etc.
+ - else
+ - This function has no effect on this object
+ !*/
+
+ void swap (
+ circular_buffer& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void swap (
+ circular_buffer<T>& a,
+ circular_buffer<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T
+ >
+ void serialize (
+ const circular_buffer<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ circular_buffer<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const circular_buffer<T>& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == m.size()
+ - for all valid r:
+ R(r) == m[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h
new file mode 100644
index 000000000..d3e6cc4b4
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h
@@ -0,0 +1,227 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SLIDING_BUFFER_KERNEl_1_
+#define DLIB_SLIDING_BUFFER_KERNEl_1_
+
+#include "sliding_buffer_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class sliding_buffer_kernel_1 : public enumerable<T>
+ {
+ /*!
+ INITIAL VALUE
+ - buffer_size == 0
+ - buffer == 0
+ - buffer_start == 0
+ - current == 0
+ - at_start_ == true
+
+ CONVENTION
+ - buffer_size == size()
+
+ - element() == (*this)[current]
+ - current_element_valid() == (current < buffer_size) && at_start_ == false
+ - at_start() == at_start_
+
+ - if (buffer_size != 0) then
+ - buffer[(buffer_start+i)&(mask)] == operator[](i)
+ - mask == buffer_size-1
+ - else
+ - buffer == 0
+ - buffer_size == 0
+ !*/
+
+ public:
+
+ typedef T type;
+
+ sliding_buffer_kernel_1 (
+ ) :
+ buffer_start(0),
+ buffer_size(0),
+ buffer(0),
+ current(0),
+ at_start_(true)
+ {}
+
+ virtual ~sliding_buffer_kernel_1 (
+ ) { if (buffer) delete [] buffer; }
+
+ void clear(
+ )
+ {
+ buffer_size = 0;
+ if (buffer) delete [] buffer;
+ buffer = 0;
+ at_start_ = true;
+ current = 0;
+ }
+
+ void set_size (
+ unsigned long exp_size
+ )
+ {
+ at_start_ = true;
+ if (buffer) delete [] buffer;
+ buffer_size = 1;
+ while (exp_size != 0)
+ {
+ --exp_size;
+ buffer_size <<= 1;
+ }
+ mask = buffer_size-1;
+ try { buffer = new T[buffer_size]; }
+ catch (...) { buffer = 0; buffer_size = 0; throw; }
+ }
+
+ size_t size (
+ ) const { return buffer_size; }
+
+ void rotate_left (
+ unsigned long amount
+ ) { buffer_start = ((buffer_start-amount)&mask); at_start_ = true; }
+
+ void rotate_right (
+ unsigned long amount
+ ) { buffer_start = ((buffer_start+amount)&mask); at_start_ = true;}
+
+ const T& operator[] (
+ unsigned long index
+ ) const { return buffer[(buffer_start+index)&mask]; }
+
+ T& operator[] (
+ unsigned long index
+ ) { return buffer[(buffer_start+index)&mask]; }
+
+ unsigned long get_element_id(
+ unsigned long index
+ ) const { return ((buffer_start+index)&mask); }
+
+ unsigned long get_element_index (
+ unsigned long element_id
+ ) const { return ((element_id-buffer_start)&mask);}
+
+ void swap (
+ sliding_buffer_kernel_1<T>& item
+ )
+ {
+ exchange(buffer_start,item.buffer_start);
+ exchange(buffer_size,item.buffer_size);
+ exchange(buffer,item.buffer);
+ exchange(mask,item.mask);
+ exchange(current,item.current);
+ exchange(at_start_,item.at_start_);
+ }
+
+
+ bool at_start (
+ ) const { return at_start_; }
+
+ void reset (
+ ) const { at_start_ = true; }
+
+ bool current_element_valid (
+ ) const { return (current < buffer_size) && (at_start_ == false); }
+
+ const T& element (
+ ) const { return (*this)[current]; }
+
+ T& element (
+ ) { return (*this)[current]; }
+
+ bool move_next (
+ ) const
+ {
+ if (at_start_ == false)
+ {
+ if (current+1 < buffer_size)
+ {
+ ++current;
+ return true;
+ }
+ else
+ {
+ current = buffer_size;
+ return false;
+ }
+ }
+ else
+ {
+ at_start_ = false;
+ current = 0;
+ return (buffer_size != 0);
+ }
+ }
+
+
+ private:
+
+ // data members
+ unsigned long buffer_start;
+ unsigned long buffer_size;
+ T* buffer;
+ unsigned long mask;
+
+
+ mutable unsigned long current;
+ mutable bool at_start_;
+
+ // restricted functions
+ sliding_buffer_kernel_1(sliding_buffer_kernel_1<T>&); // copy constructor
+ sliding_buffer_kernel_1<T>& operator=(sliding_buffer_kernel_1<T>&); // assignment operator
+
+ };
+
+ template <
+ typename T
+ >
+ inline void swap (
+ sliding_buffer_kernel_1<T>& a,
+ sliding_buffer_kernel_1<T>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sliding_buffer_kernel_1<T>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ if (size > 0)
+ {
+ int count = 0;
+ while (size != 1)
+ {
+ size /= 2;
+ ++count;
+ }
+ item.set_size(count);
+
+ for (unsigned long i = 0; i < item.size(); ++i)
+ deserialize(item[i],in);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type sliding_buffer_kernel_1");
+ }
+ }
+}
+
+#endif // DLIB_SLIDING_BUFFER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h
new file mode 100644
index 000000000..687a9e878
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h
@@ -0,0 +1,205 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_
+#ifdef DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class sliding_buffer : public enumerable<T>
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must have a default constructor
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements of the sliding_buffer in the
+ order (*this)[0], (*this)[1], (*this)[2], ...
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an array of T objects. The main
+ feature of this object is its ability to rotate its contents
+ left or right. An example will make it clear.
+
+ suppose we have the following buffer (assuming T is a char):
+ "some data!" <-- the data in the buffer
+ 9876543210 <-- the index numbers associated with each character
+
+ applying rotate_left(2) to this buffer would give us
+ "me data!so"
+ 9876543210
+
+ if instead of calling rotate_left we call rotate_right(3) instead we would have
+ "ta!some da"
+ 9876543210
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+
+ sliding_buffer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ !*/
+
+ virtual ~sliding_buffer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ if this exception is thrown then #*this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void set_size (
+ unsigned long exp_size
+ );
+ /*!
+ requires
+ - 0 < exp_size < 32
+ ensures
+ - #size() == 2^exp_size
+ - the value of all elements in the buffer are undefined
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ if this exception is thrown then #size() == 0
+ !*/
+
+ void rotate_left (
+ unsigned long amount
+ );
+ /*!
+ ensures
+ - for all i where 0 <= i < size():
+ (#*this)[i] == (*this)[(i-amount)&(size()-1)]
+ i.e. rotates the contents of *this left by amount spaces
+ - #at_start() == true
+ !*/
+
+ void rotate_right (
+ unsigned long amount
+ );
+ /*!
+ ensures
+ - for all i where 0 <= i < size():
+ (#*this)[i] == (*this)[(i+amount)&(size()-1)]
+ i.e. rotates the contents of *this right by amount spaces
+ - #at_start() == true
+ !*/
+
+ unsigned long get_element_id (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns an element id number that uniquely references the element at
+ the given index. (you can use this id to locate the new position of
+ an element after the buffer has been rotated)
+ - returned value is < size()
+ !*/
+
+ unsigned long get_element_index (
+ unsigned long element_id
+ ) const;
+ /*!
+ require
+ - element_id < size()
+ ensures
+ - returns the index of the element with the given element_id.
+ ( (*this)[get_element_index(element_id)] will always refer to the same element
+ no matter where it has been rotated to)
+ - returned value is < size()
+ !*/
+
+ const T& operator[] (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns a const reference to the element in *this at position index
+ !*/
+
+ T& operator[] (
+ unsigned long index
+ );
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns a reference to the element in *this at position index
+ !*/
+
+ void swap (
+ sliding_buffer<T>& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ sliding_buffer(sliding_buffer<T>&); // copy constructor
+ sliding_buffer<T>& operator=(sliding_buffer<T>&); // assignment operator
+
+ };
+
+ template <
+ typename T
+ >
+ void swap (
+ sliding_buffer<T>& a,
+ sliding_buffer<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sliding_buffer<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h
new file mode 100644
index 000000000..a7330e4b5
--- /dev/null
+++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h
@@ -0,0 +1,222 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SLIDING_BUFFER_KERNEl_C_
+#define DLIB_SLIDING_BUFFER_KERNEl_C_
+
+#include "sliding_buffer_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename sb_base
+ >
+ class sliding_buffer_kernel_c : public sb_base
+ {
+ typedef typename sb_base::type T;
+
+ public:
+ void set_size (
+ unsigned long exp_size
+ );
+
+ const T& operator[] (
+ unsigned long index
+ ) const;
+
+ T& operator[] (
+ unsigned long index
+ );
+
+ unsigned long get_element_id (
+ unsigned long index
+ ) const;
+
+ unsigned long get_element_index (
+ unsigned long element_id
+ ) const;
+
+ const T& element (
+ ) const;
+
+ T& element (
+ );
+
+
+ };
+
+ template <
+ typename sb_base
+ >
+ inline void swap (
+ sliding_buffer_kernel_c<sb_base>& a,
+ sliding_buffer_kernel_c<sb_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ void sliding_buffer_kernel_c<sb_base>::
+ set_size (
+ unsigned long exp_size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( 0 < exp_size && exp_size < 32,
+ "\tvoid sliding_buffer::set_size(unsigned long)"
+ << "\n\texp_size must be some number between 1 and 31"
+ << "\n\tthis: " << this
+ << "\n\texp_size: " << exp_size
+ );
+
+ // call the real function
+ sb_base::set_size(exp_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ unsigned long sliding_buffer_kernel_c<sb_base>::
+ get_element_id (
+ unsigned long index
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->size(),
+ "\tunsigned long sliding_buffer::get_element_id(unsigned long) const"
+ << "\n\tindex must be in the range 0 to size()-1"
+ << "\n\tthis: " << this
+ << "\n\tsize(): " << this->size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return sb_base::get_element_id(index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ unsigned long sliding_buffer_kernel_c<sb_base>::
+ get_element_index (
+ unsigned long element_id
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( element_id < this->size(),
+ "\tunsigned long sliding_buffer::get_element_index(unsigned long) const"
+ << "\n\tid must be in the range 0 to size()-1"
+ << "\n\tthis: " << this
+ << "\n\tsize(): " << this->size()
+ << "\n\tid: " << element_id
+ );
+
+ // call the real function
+ return sb_base::get_element_index(element_id);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ const typename sb_base::type& sliding_buffer_kernel_c<sb_base>::
+ operator[] (
+ unsigned long index
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->size(),
+ "\tconst T& sliding_buffer::operator[](unsigned long) const"
+ << "\n\tindex must be in the range 0 to size()-1"
+ << "\n\tthis: " << this
+ << "\n\tsize(): " << this->size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return sb_base::operator[](index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ typename sb_base::type& sliding_buffer_kernel_c<sb_base>::
+ operator[] (
+ unsigned long index
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( index < this->size(),
+ "\tT& sliding_buffer::operator[](unsigned long)"
+ << "\n\tindex must be in the range 0 to size()-1"
+ << "\n\tthis: " << this
+ << "\n\tsize(): " << this->size()
+ << "\n\tindex: " << index
+ );
+
+ // call the real function
+ return sb_base::operator[](index);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ const typename sb_base::type& sliding_buffer_kernel_c<sb_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& sliding_buffer::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return sb_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sb_base
+ >
+ typename sb_base::type& sliding_buffer_kernel_c<sb_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tT& sliding_buffer::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return sb_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SLIDING_BUFFER_KERNEl_C_
+
diff --git a/ml/dlib/dlib/smart_pointers.h b/ml/dlib/dlib/smart_pointers.h
new file mode 100644
index 000000000..905e88b15
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers.h
@@ -0,0 +1,22 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SMART_POINTERs_H_
+#define DLIB_SMART_POINTERs_H_
+
+// This is legacy smart pointer code that will likely stop working under default compiler
+// flags when C++17 becomes the default standard in compilers. Please consider migrating
+// your code to new smart pointers from C++ standard library.
+#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \
+ (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4)))
+#pragma GCC warning "smart_pointers.h is included. This code will fail to compile under C++17"
+#endif
+
+#include <memory>
+
+#include "smart_pointers/shared_ptr.h"
+#include "smart_pointers/weak_ptr.h"
+#include "smart_pointers/scoped_ptr.h"
+
+#endif // DLIB_SMART_POINTERs_H_
+
+
diff --git a/ml/dlib/dlib/smart_pointers/scoped_ptr.h b/ml/dlib/dlib/smart_pointers/scoped_ptr.h
new file mode 100644
index 000000000..dd890f330
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/scoped_ptr.h
@@ -0,0 +1,16 @@
+#ifndef DLIB_SCOPED_PTr_H_
+#define DLIB_SCOPED_PTr_H_
+
+#include <memory>
+
+namespace dlib {
+ // Template alias for compatibility with clients using old dlib::scoped_ptr
+ // Old scoped_ptr implementation is removed completely
+ // This alias may fail in some reference deduction cases
+
+ template <class T, class Deleter = std::default_delete<T> >
+ using scoped_ptr = std::unique_ptr<T, Deleter>;
+
+}
+
+#endif
diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr.h b/ml/dlib/dlib/smart_pointers/shared_ptr.h
new file mode 100644
index 000000000..15f7a4919
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/shared_ptr.h
@@ -0,0 +1,492 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SHARED_PTr_
+#define DLIB_SHARED_PTr_
+
+#include <algorithm>
+#include <memory>
+#include <typeinfo>
+#include <string> // for the exceptions
+#include "../algs.h"
+#include "shared_ptr_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_weak_ptr: public std::exception {};
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T> class weak_ptr;
+
+// ----------------------------------------------------------------------------------------
+
+ struct shared_ptr_deleter
+ {
+ virtual void del(const void* p) = 0;
+ virtual ~shared_ptr_deleter() {}
+
+ virtual void* get_deleter_void(const std::type_info& t) const = 0;
+ /*!
+ ensures
+ - if (the deleter in this object has typeid() == t) then
+ - returns a pointer to the deleter
+ - else
+ - return 0
+ !*/
+ };
+
+ struct shared_ptr_node;
+ struct weak_ptr_node
+ {
+ weak_ptr_node (
+ shared_ptr_node* sn
+ ) :
+ ref_count(1),
+ shared_node(sn)
+ {
+ DLIB_ASSERT(sn != 0,"");
+ }
+
+ long ref_count;
+ shared_ptr_node* shared_node;
+ };
+
+ struct shared_ptr_node
+ {
+ shared_ptr_node(
+ ) :
+ ref_count(1),
+ del(0),
+ weak_node(0)
+ {}
+
+ long ref_count;
+ shared_ptr_deleter* del;
+ weak_ptr_node* weak_node;
+ };
+
+ struct shared_ptr_static_cast {};
+ struct shared_ptr_const_cast {};
+ struct shared_ptr_dynamic_cast {};
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T>
+ class shared_ptr
+ {
+ /*!
+ CONVENTION
+ - get() == data
+ - unique() == (shared_node != 0) && (shared_node->ref_count == 1)
+ - if (shared_node != 0) then
+ - use_count() == shared_node->ref_count
+ - get() == a valid pointer
+ - if (we are supposed to use the deleter) then
+ - shared_node->del == the deleter to use
+ - else
+ - shared_node->del == 0
+ - else
+ - use_count() == 0
+ - get() == 0
+
+
+ - if (there are any weak_ptrs that reference this->data) then
+ - shared_node->weak_node->ref_count == the number of referencing weak_ptrs
+ - else
+ - shared_node->weak_node == 0
+ !*/
+
+ template <typename D>
+ struct deleter_template : public shared_ptr_deleter
+ {
+ deleter_template(const D& d_) : d(d_) {}
+ void del(const void* p) { d((T*)p); }
+ D d;
+
+ void* get_deleter_void(const std::type_info& t) const
+ {
+ if (typeid(D) == t)
+ return (void*)&d;
+ else
+ return 0;
+ }
+ };
+
+ struct default_deleter : public shared_ptr_deleter
+ {
+ void del(const void* p) { delete ((T*)p); }
+
+ void* get_deleter_void(const std::type_info&) const
+ {
+ return 0;
+ }
+ };
+
+ public:
+
+ typedef T element_type;
+
+ shared_ptr(
+ ) : data(0), shared_node(0) {}
+
+ template<typename Y>
+ explicit shared_ptr(
+ Y* p
+ ) : data(p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::shared_ptr(p)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+ try
+ {
+ shared_node = new shared_ptr_node;
+ shared_node->del = new default_deleter;
+ }
+ catch (...)
+ {
+ delete p;
+ throw;
+ }
+ }
+
+ template<typename Y, typename D>
+ shared_ptr(
+ Y* p,
+ const D& d
+ ) :
+ data(p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::shared_ptr(p,d)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+ try
+ {
+ shared_node = 0;
+ shared_node = new shared_ptr_node;
+ shared_node->del = new deleter_template<D>(d);
+ }
+ catch (...)
+ {
+ if (shared_node) delete shared_node;
+ d(p);
+ throw;
+ }
+ }
+
+ ~shared_ptr()
+ {
+ if ( shared_node != 0)
+ {
+ if (shared_node->ref_count == 1)
+ {
+ // delete the data in the appropriate way
+ shared_node->del->del(data);
+ delete shared_node->del;
+
+ // notify any weak_ptrs that the data has now expired
+ if (shared_node->weak_node)
+ shared_node->weak_node->shared_node = 0;
+
+ // finally delete the shared_node
+ delete shared_node;
+ }
+ else
+ {
+ shared_node->ref_count -= 1;
+ }
+ }
+ }
+
+ shared_ptr(
+ const shared_ptr& r
+ )
+ {
+ data = r.data;
+ shared_node = r.shared_node;
+ if (shared_node)
+ shared_node->ref_count += 1;
+ }
+
+ template<typename Y>
+ shared_ptr(
+ const shared_ptr<Y>& r,
+ const shared_ptr_static_cast&
+ )
+ {
+ data = static_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr(
+ const shared_ptr<Y>& r,
+ const shared_ptr_const_cast&
+ )
+ {
+ data = const_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr(
+ const shared_ptr<Y>& r,
+ const shared_ptr_dynamic_cast&
+ )
+ {
+ data = dynamic_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr(
+ const shared_ptr<Y>& r
+ )
+ {
+ data = r.data;
+ shared_node = r.shared_node;
+ if (shared_node)
+ shared_node->ref_count += 1;
+ }
+
+
+ template<typename Y>
+ explicit shared_ptr(
+ const weak_ptr<Y>& r
+ )
+ {
+ if (r.expired())
+ throw bad_weak_ptr();
+
+ data = r.data;
+ shared_node = r.weak_node->shared_node;
+ shared_node->ref_count += 1;
+ }
+
+ shared_ptr& operator= (
+ const shared_ptr& r
+ )
+ {
+ shared_ptr(r).swap(*this);
+ return *this;
+ }
+
+ template<typename Y>
+ shared_ptr& operator= (
+ const shared_ptr<Y>& r
+ )
+ {
+ shared_ptr(r).swap(*this);
+ return *this;
+ }
+
+ void reset()
+ {
+ shared_ptr().swap(*this);
+ }
+
+ template<typename Y>
+ void reset(Y* p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::reset(p)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+
+ shared_ptr(p).swap(*this);
+ }
+
+ template<typename Y, typename D>
+ void reset(
+ Y* p,
+ const D& d
+ )
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::reset(p,d)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+
+ shared_ptr(p,d).swap(*this);
+ }
+
+ T& operator*(
+ ) const
+ {
+ DLIB_ASSERT(get() != 0,
+ "\tshared_ptr::operator*()"
+ << "\n\tget() can't be null if you are going to dereference it"
+ << "\n\tthis: " << this
+ );
+
+ return *data;
+ }
+
+ T* operator->(
+ ) const
+ {
+ DLIB_ASSERT(get() != 0,
+ "\tshared_ptr::operator->()"
+ << "\n\tget() can't be null"
+ << "\n\tthis: " << this
+ );
+
+ return data;
+ }
+
+ T* get() const { return data; }
+
+ bool unique() const
+ {
+ return use_count() == 1;
+ }
+
+ long use_count() const
+ {
+ if (shared_node != 0)
+ return shared_node->ref_count;
+ else
+ return 0;
+ }
+
+ operator bool(
+ ) const { return get() != 0; }
+
+ void swap(shared_ptr& b)
+ {
+ std::swap(data, b.data);
+ std::swap(shared_node, b.shared_node);
+ }
+
+ template <typename D>
+ D* _get_deleter(
+ ) const
+ {
+ if (shared_node && shared_node->del)
+ return static_cast<D*>(shared_node->del->get_deleter_void(typeid(D)));
+ else
+ return 0;
+ }
+
+ template <typename Y>
+ bool _private_less (
+ const shared_ptr<Y>& rhs
+ ) const
+ {
+ return shared_node < rhs.shared_node;
+ }
+
+ private:
+
+ template <typename Y> friend class shared_ptr;
+ template <typename Y> friend class weak_ptr;
+
+ T* data;
+ shared_ptr_node* shared_node;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, typename U>
+ bool operator== (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ ) { return a.get() == b.get(); }
+
+ template<typename T, typename U>
+ bool operator!= (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ ) { return a.get() != b.get(); }
+
+ template<typename T, typename U>
+ bool operator< (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ )
+ {
+ return a._private_less(b);
+ }
+
+ template<typename T>
+ void swap(
+ shared_ptr<T>& a,
+ shared_ptr<T>& b
+ ) { a.swap(b); }
+
+ template<typename T, typename U>
+ shared_ptr<T> static_pointer_cast(
+ const shared_ptr<U>& r
+ )
+ {
+ return shared_ptr<T>(r, shared_ptr_static_cast());
+ }
+
+ template<typename T, typename U>
+ shared_ptr<T> const_pointer_cast(
+ shared_ptr<U> const & r
+ )
+ {
+ return shared_ptr<T>(r, shared_ptr_const_cast());
+ }
+
+ template<typename T, typename U>
+ shared_ptr<T> dynamic_pointer_cast(
+ const shared_ptr<U>& r
+ )
+ {
+ return shared_ptr<T>(r, shared_ptr_dynamic_cast());
+ }
+
+ template<typename E, typename T, typename Y>
+ std::basic_ostream<E, T> & operator<< (std::basic_ostream<E, T> & os, shared_ptr<Y> const & p)
+ {
+ os << p.get();
+ return os;
+ }
+
+ template<typename D, typename T>
+ D* get_deleter(const shared_ptr<T>& p)
+ {
+ return p.template _get_deleter<D>();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_SHARED_PTr_
+
diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h b/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h
new file mode 100644
index 000000000..9fc12c8e6
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h
@@ -0,0 +1,374 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SHARED_PTr_ABSTRACT_
+#ifdef DLIB_SHARED_PTr_ABSTRACT_
+
+#include "weak_ptr_abstract.h"
+#include <exception>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_weak_ptr: public std::exception {}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class shared_ptr
+ {
+ /*!
+ INITIAL VALUE
+ defined by constructors
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a reference counted smart pointer. Each shared_ptr
+ contains a pointer to some object and when the last shared_ptr that points
+ to the object is destructed or reset() then the object is guaranteed to be
+ deleted.
+
+ This is an implementation of the std::tr1::shared_ptr template from the
+ document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++
+ Library Extensions. The only deviation from that document is that this
+ shared_ptr is declared inside the dlib namespace rather than std::tr1.
+
+ THREAD SAFETY
+ This object is not thread safe. Especially so since it is
+ reference counted. So you should take care to not have two shared_ptr
+ objects in different threads that point to the same object.
+
+ If you want a thread safe version of this object you should use the
+ dlib::shared_ptr_thread_safe object instead.
+ !*/
+
+ public:
+
+ typedef T element_type;
+
+ shared_ptr(
+ );
+ /*!
+ ensures
+ - #get() == 0
+ - #use_count() == 0
+ !*/
+
+ template<typename Y>
+ explicit shared_ptr(
+ Y* p
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - p can be deleted by calling "delete p;" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - #get() == p
+ - #use_count() == 1
+ - #*this object owns the pointer p
+ throws
+ - std::bad_alloc
+ if this exception is thrown then "delete p;" is called
+ !*/
+
+ template<typename Y, typename D>
+ shared_ptr(
+ Y* p,
+ const D& d
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - D is copy constructable (and the copy constructor of D doesn't throw)
+ - p can be deleted by calling "d(p);" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - #get() == p
+ - #use_count() == 1
+ - #*this object owns the pointer p
+ throws
+ - std::bad_alloc
+ if this exception is thrown then "d(p);" is called
+ !*/
+
+ shared_ptr(
+ const shared_ptr& r
+ );
+ /*!
+ ensures
+ - #get() == #r.get()
+ - #use_count() == #r.use_count()
+ - If r is empty, constructs an empty shared_ptr object; otherwise, constructs
+ a shared_ptr object that shares ownership with r.
+ !*/
+
+ template<typename Y>
+ shared_ptr(
+ const shared_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* is convertible to T*
+ ensures
+ - #get() == #r.get()
+ - #use_count() == #r.use_count()
+ - If r is empty, constructs an empty shared_ptr object; otherwise, constructs
+ a shared_ptr object that shares ownership with r.
+ !*/
+
+ template<typename Y>
+ explicit shared_ptr(
+ const weak_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* is convertible to T*
+ ensures
+ - #get() == #r.get()
+ - #use_count() == #r.use_count()
+ - If r is empty, constructs an empty shared_ptr object; otherwise, constructs
+ a shared_ptr object that shares ownership with r.
+ throws
+ - bad_weak_ptr
+ this exception is thrown if r.expired() == true
+ !*/
+
+ ~shared_ptr(
+ );
+ /*!
+ ensures
+ - if (use_count() > 1)
+ - this object destroys itself but otherwise has no effect (i.e.
+ the pointer get() is still valid and shared between the remaining
+ shared_ptr objects)
+ - else if (use_count() == 1)
+ - deletes the pointer get() by calling delete (or using the deleter passed
+ to the constructor if one was passed in)
+ - else
+ - in this case get() == 0 so there is nothing to do so nothing occurs
+ !*/
+
+ shared_ptr& operator= (
+ const shared_ptr& r
+ );
+ /*!
+ ensures
+ - equivalent to shared_ptr(r).swap(*this).
+ - returns #*this
+ !*/
+
+ template<typename Y>
+ shared_ptr& operator= (
+ const shared_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* is convertible to T*
+ ensures
+ - equivalent to shared_ptr(r).swap(*this).
+ - returns #*this
+ !*/
+
+ void reset(
+ );
+ /*!
+ ensures
+ - equivalent to shared_ptr().swap(*this)
+ !*/
+
+ template<typename Y>
+ void reset(
+ Y* p
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - p can be deleted by calling "delete p;" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - equivalent to shared_ptr(p).swap(*this)
+ !*/
+
+ template<typename Y, typename D>
+ void reset(
+ Y* p,
+ const D& d
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - D is copy constructable (and the copy constructor of D doesn't throw)
+ - p can be deleted by calling "d(p);" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - equivalent to shared_ptr(p,d).swap(*this)
+ !*/
+
+ T* get(
+ ) const;
+ /*!
+ ensures
+ - returns the stored pointer
+ !*/
+
+ T& operator*(
+ ) const;
+ /*!
+ requires
+ - get() != 0
+ ensures
+ - returns a reference to *get()
+ !*/
+
+ T* operator->(
+ ) const;
+ /*!
+ requires
+ - get() != 0
+ ensures
+ - returns get()
+ !*/
+
+ bool unique(
+ ) const;
+ /*!
+ ensures
+ - returns (use_count() == 1)
+ !*/
+
+ long use_count(
+ ) const;
+ /*!
+ ensures
+ - The number of shared_ptr objects, *this included, that share ownership with *this, or 0 when *this
+ is empty.
+ !*/
+
+ operator bool(
+ ) const;
+ /*!
+ ensures
+ - returns (get() != 0)
+ !*/
+
+ void swap(
+ shared_ptr& b
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, typename U>
+ bool operator== (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ );
+ /*!
+ ensures
+ - returns a.get() == b.get()
+ !*/
+
+ template<typename T, typename U>
+ bool operator!= (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ ) { return a.get() != b.get(); }
+ /*!
+ ensures
+ - returns a.get() != b.get()
+ !*/
+
+ template<typename T, typename U>
+ bool operator< (
+ const shared_ptr<T>& a,
+ const shared_ptr<U>& b
+ );
+ /*!
+ ensures
+ - Defines an operator< on shared_ptr types appropriate for use in the associative
+ containers.
+ !*/
+
+ template<typename T>
+ void swap(
+ shared_ptr<T>& a,
+ shared_ptr<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr<T> static_pointer_cast(
+ const shared_ptr<U>& r
+ );
+ /*!
+ - if (r.get() == 0) then
+ - returns shared_ptr<T>()
+ - else
+ - returns a shared_ptr<T> object that stores static_cast<T*>(r.get()) and shares
+ ownership with r.
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr<T> const_pointer_cast(
+ const shared_ptr<U>& r
+ );
+ /*!
+ - if (r.get() == 0) then
+ - returns shared_ptr<T>()
+ - else
+ - returns a shared_ptr<T> object that stores const_cast<T*>(r.get()) and shares
+ ownership with r.
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr<T> dynamic_pointer_cast(
+ const shared_ptr<U>& r
+ );
+ /*!
+ ensures
+ - if (dynamic_cast<T*>(r.get()) returns a nonzero value) then
+ - returns a shared_ptr<T> object that stores a copy of
+ dynamic_cast<T*>(r.get()) and shares ownership with r
+ - else
+ - returns an empty shared_ptr<T> object.
+ !*/
+
+ template<typename E, typename T, typename Y>
+ std::basic_ostream<E, T> & operator<< (
+ std::basic_ostream<E, T> & os,
+ const shared_ptr<Y>& p
+ );
+ /*!
+ ensures
+ - performs os << p.get()
+ - returns os
+ !*/
+
+ template<typename D, typename T>
+ D* get_deleter(
+ const shared_ptr<T>& p
+ );
+ /*!
+ ensures
+ - if (*this owns a deleter d of type cv-unqualified D) then
+ - returns &d
+ - else
+ - returns 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHARED_PTr_ABSTRACT_
+
diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h
new file mode 100644
index 000000000..31bda5651
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h
@@ -0,0 +1,462 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SHARED_THREAD_SAFE_PTr_
+#define DLIB_SHARED_THREAD_SAFE_PTr_
+
+#include <algorithm>
+#include <memory>
+#include <typeinfo>
+#include <string> // for the exceptions
+#include "../algs.h"
+#include "shared_ptr_thread_safe_abstract.h"
+#include "../threads/threads_kernel.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct shared_ptr_thread_safe_deleter
+ {
+ virtual void del(const void* p) = 0;
+ virtual ~shared_ptr_thread_safe_deleter() {}
+
+ virtual void* get_deleter_void(const std::type_info& t) const = 0;
+ /*!
+ ensures
+ - if (the deleter in this object has typeid() == t) then
+ - returns a pointer to the deleter
+ - else
+ - return 0
+ !*/
+ };
+
+ struct shared_ptr_thread_safe_node
+ {
+ shared_ptr_thread_safe_node(
+ ) :
+ ref_count(1),
+ del(0)
+ {}
+
+ dlib::mutex m;
+ long ref_count;
+ shared_ptr_thread_safe_deleter* del;
+ };
+
+ struct shared_ptr_ts_static_cast {};
+ struct shared_ptr_ts_const_cast {};
+ struct shared_ptr_ts_dynamic_cast {};
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T>
+ class shared_ptr_thread_safe
+ {
+ /*!
+ CONVENTION
+ - get() == data
+ - unique() == (shared_node != 0) && (shared_node->ref_count == 1)
+ - if (shared_node != 0) then
+ - use_count() == shared_node->ref_count
+ - get() == a valid pointer
+ - if (we are supposed to use the deleter) then
+ - shared_node->del == the deleter to use
+ - else
+ - shared_node->del == 0
+ - else
+ - use_count() == 0
+ - get() == 0
+
+ !*/
+
+ template <typename D>
+ struct deleter_template : public shared_ptr_thread_safe_deleter
+ {
+ deleter_template(const D& d_) : d(d_) {}
+ void del(const void* p) { d((T*)p); }
+ D d;
+
+ void* get_deleter_void(const std::type_info& t) const
+ {
+ if (typeid(D) == t)
+ return (void*)&d;
+ else
+ return 0;
+ }
+ };
+
+ public:
+
+ typedef T element_type;
+
+ shared_ptr_thread_safe(
+ ) : data(0), shared_node(0) {}
+
+ template<typename Y>
+ explicit shared_ptr_thread_safe(
+ Y* p
+ ) : data(p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::shared_ptr_thread_safe(p)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+ try
+ {
+ shared_node = new shared_ptr_thread_safe_node;
+ }
+ catch (...)
+ {
+ delete p;
+ throw;
+ }
+ }
+
+ template<typename Y, typename D>
+ shared_ptr_thread_safe(
+ Y* p,
+ const D& d
+ ) :
+ data(p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::shared_ptr_thread_safe(p,d)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+ try
+ {
+ shared_node = 0;
+ shared_node = new shared_ptr_thread_safe_node;
+ shared_node->del = new deleter_template<D>(d);
+ }
+ catch (...)
+ {
+ if (shared_node) delete shared_node;
+ d(p);
+ throw;
+ }
+ }
+
+ ~shared_ptr_thread_safe()
+ {
+ if ( shared_node != 0)
+ {
+ shared_node->m.lock();
+ if (shared_node->ref_count == 1)
+ {
+ // delete the data in the appropriate way
+ if (shared_node->del)
+ {
+ shared_node->del->del(data);
+
+ shared_node->m.unlock();
+ delete shared_node->del;
+ }
+ else
+ {
+ shared_node->m.unlock();
+ delete data;
+ }
+
+ // finally delete the shared_node
+ delete shared_node;
+ }
+ else
+ {
+ shared_node->ref_count -= 1;
+ shared_node->m.unlock();
+ }
+ }
+ }
+
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe& r
+ )
+ {
+ data = r.data;
+ shared_node = r.shared_node;
+ if (shared_node)
+ {
+ auto_mutex M(shared_node->m);
+ shared_node->ref_count += 1;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe<Y>& r,
+ const shared_ptr_ts_static_cast&
+ )
+ {
+ data = static_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ auto_mutex M(shared_node->m);
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe<Y>& r,
+ const shared_ptr_ts_const_cast&
+ )
+ {
+ data = const_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ auto_mutex M(shared_node->m);
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe<Y>& r,
+ const shared_ptr_ts_dynamic_cast&
+ )
+ {
+ data = dynamic_cast<T*>(r.data);
+ if (data != 0)
+ {
+ shared_node = r.shared_node;
+ auto_mutex M(shared_node->m);
+ shared_node->ref_count += 1;
+ }
+ else
+ {
+ shared_node = 0;
+ }
+ }
+
+ template<typename Y>
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe<Y>& r
+ )
+ {
+ data = r.data;
+ shared_node = r.shared_node;
+ if (shared_node)
+ {
+ auto_mutex M(shared_node->m);
+ shared_node->ref_count += 1;
+ }
+ }
+
+ shared_ptr_thread_safe& operator= (
+ const shared_ptr_thread_safe& r
+ )
+ {
+ shared_ptr_thread_safe(r).swap(*this);
+ return *this;
+ }
+
+ template<typename Y>
+ shared_ptr_thread_safe& operator= (
+ const shared_ptr_thread_safe<Y>& r
+ )
+ {
+ shared_ptr_thread_safe(r).swap(*this);
+ return *this;
+ }
+
+ void reset()
+ {
+ shared_ptr_thread_safe().swap(*this);
+ }
+
+ template<typename Y>
+ void reset(Y* p)
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::reset(p)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+
+ shared_ptr_thread_safe(p).swap(*this);
+ }
+
+ template<typename Y, typename D>
+ void reset(
+ Y* p,
+ const D& d
+ )
+ {
+ DLIB_ASSERT(p != 0,
+ "\tshared_ptr::reset(p,d)"
+ << "\n\tp can't be null"
+ << "\n\tthis: " << this
+ );
+
+ shared_ptr_thread_safe(p,d).swap(*this);
+ }
+
+ T& operator*(
+ ) const
+ {
+ DLIB_ASSERT(get() != 0,
+ "\tshared_ptr::operator*()"
+ << "\n\tget() can't be null if you are going to dereference it"
+ << "\n\tthis: " << this
+ );
+
+ return *data;
+ }
+
+ T* operator->(
+ ) const
+ {
+ DLIB_ASSERT(get() != 0,
+ "\tshared_ptr::operator->()"
+ << "\n\tget() can't be null"
+ << "\n\tthis: " << this
+ );
+
+ return data;
+ }
+
+ T* get() const { return data; }
+
+ bool unique() const
+ {
+ return use_count() == 1;
+ }
+
+ long use_count() const
+ {
+ if (shared_node != 0)
+ {
+ auto_mutex M(shared_node->m);
+ return shared_node->ref_count;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ operator bool(
+ ) const { return get() != 0; }
+
+ void swap(shared_ptr_thread_safe& b)
+ {
+ std::swap(data, b.data);
+ std::swap(shared_node, b.shared_node);
+ }
+
+ template <typename D>
+ D* _get_deleter(
+ ) const
+ {
+ if (shared_node)
+ {
+ auto_mutex M(shared_node->m);
+ if (shared_node->del)
+ return static_cast<D*>(shared_node->del->get_deleter_void(typeid(D)));
+ }
+ return 0;
+ }
+
+ template <typename Y>
+ bool _private_less (
+ const shared_ptr_thread_safe<Y>& rhs
+ ) const
+ {
+ return shared_node < rhs.shared_node;
+ }
+
+ private:
+
+ template <typename Y> friend class shared_ptr_thread_safe;
+
+ T* data;
+ shared_ptr_thread_safe_node* shared_node;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, typename U>
+ bool operator== (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ ) { return a.get() == b.get(); }
+
+ template<typename T, typename U>
+ bool operator!= (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ ) { return a.get() != b.get(); }
+
+ template<typename T, typename U>
+ bool operator< (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ )
+ {
+ return a._private_less(b);
+ }
+
+ template<typename T>
+ void swap(
+ shared_ptr_thread_safe<T>& a,
+ shared_ptr_thread_safe<T>& b
+ ) { a.swap(b); }
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> static_pointer_cast(
+ const shared_ptr_thread_safe<U>& r
+ )
+ {
+ return shared_ptr_thread_safe<T>(r, shared_ptr_ts_static_cast());
+ }
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> const_pointer_cast(
+ shared_ptr_thread_safe<U> const & r
+ )
+ {
+ return shared_ptr_thread_safe<T>(r, shared_ptr_ts_const_cast());
+ }
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> dynamic_pointer_cast(
+ const shared_ptr_thread_safe<U>& r
+ )
+ {
+ return shared_ptr_thread_safe<T>(r, shared_ptr_ts_dynamic_cast());
+ }
+
+ template<typename E, typename T, typename Y>
+ std::basic_ostream<E, T> & operator<< (std::basic_ostream<E, T> & os, shared_ptr_thread_safe<Y> const & p)
+ {
+ os << p.get();
+ return os;
+ }
+
+ template<typename D, typename T>
+ D* get_deleter(const shared_ptr_thread_safe<T>& p)
+ {
+ return p.template _get_deleter<D>();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHARED_THREAD_SAFE_PTr_
+
diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h
new file mode 100644
index 000000000..472a00464
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h
@@ -0,0 +1,352 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_
+#ifdef DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_
+
+#include <exception>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class shared_ptr_thread_safe
+ {
+ /*!
+ INITIAL VALUE
+ defined by constructors
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a reference counted smart pointer. Each shared_ptr_thread_safe
+ contains a pointer to some object and when the last shared_ptr_thread_safe that points
+ to the object is destructed or reset() then the object is guaranteed to be
+ deleted.
+
+ This is an implementation of the std::tr1::shared_ptr template from the
+ document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++
+ Library Extensions. The only deviation from that document is that this
+ shared_ptr_thread_safe is declared inside the dlib namespace rather than std::tr1,
+ this one is explicitly thread safe, and there isn't a corresponding weak_ptr.
+
+ THREAD SAFETY
+ This is a version of the shared_ptr object that can be used to share pointers
+ across more than one thread. Note however, that individual instances of this object
+ must still have access to them serialized by a mutex lock if they are to be modified
+ by more than one thread. But if you have two different shared_ptr_thread_safe objects
+ that both point to the same thing from different threads then you are safe.
+ !*/
+
+ public:
+
+ typedef T element_type;
+
+ shared_ptr_thread_safe(
+ );
+ /*!
+ ensures
+ - #get() == 0
+ - #use_count() == 0
+ !*/
+
+ template<typename Y>
+ explicit shared_ptr_thread_safe(
+ Y* p
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - p can be deleted by calling "delete p;" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - #get() == p
+ - #use_count() == 1
+ - #*this object owns the pointer p
+ throws
+ - std::bad_alloc
+ if this exception is thrown then "delete p;" is called
+ !*/
+
+ template<typename Y, typename D>
+ shared_ptr_thread_safe(
+ Y* p,
+ const D& d
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - D is copy constructable (and the copy constructor of D doesn't throw)
+ - p can be deleted by calling "d(p);" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - #get() == p
+ - #use_count() == 1
+ - #*this object owns the pointer p
+ throws
+ - std::bad_alloc
+ if this exception is thrown then "d(p);" is called
+ !*/
+
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe& r
+ );
+ /*!
+ ensures
+ - #get() == #r.get()
+ - #use_count() == #r.use_count()
+ - If r is empty, constructs an empty shared_ptr_thread_safe object; otherwise, constructs
+ a shared_ptr_thread_safe object that shares ownership with r.
+ !*/
+
+ template<typename Y>
+ shared_ptr_thread_safe(
+ const shared_ptr_thread_safe<Y>& r
+ );
+ /*!
+ requires
+ - Y* is convertible to T*
+ ensures
+ - #get() == #r.get()
+ - #use_count() == #r.use_count()
+ - If r is empty, constructs an empty shared_ptr_thread_safe object; otherwise, constructs
+ a shared_ptr_thread_safe object that shares ownership with r.
+ !*/
+
+ ~shared_ptr_thread_safe(
+ );
+ /*!
+ ensures
+ - if (use_count() > 1)
+ - this object destroys itself but otherwise has no effect (i.e.
+ the pointer get() is still valid and shared between the remaining
+ shared_ptr_thread_safe objects)
+ - else if (use_count() == 1)
+ - deletes the pointer get() by calling delete (or using the deleter passed
+ to the constructor if one was passed in)
+ - else
+ - in this case get() == 0 so there is nothing to do so nothing occurs
+ !*/
+
+ shared_ptr_thread_safe& operator= (
+ const shared_ptr_thread_safe& r
+ );
+ /*!
+ ensures
+ - equivalent to shared_ptr_thread_safe(r).swap(*this).
+ - returns #*this
+ !*/
+
+ template<typename Y>
+ shared_ptr_thread_safe& operator= (
+ const shared_ptr_thread_safe<Y>& r
+ );
+ /*!
+ requires
+ - Y* is convertible to T*
+ ensures
+ - equivalent to shared_ptr_thread_safe(r).swap(*this).
+ - returns #*this
+ !*/
+
+ void reset(
+ );
+ /*!
+ ensures
+ - equivalent to shared_ptr_thread_safe().swap(*this)
+ !*/
+
+ template<typename Y>
+ void reset(
+ Y* p
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - p can be deleted by calling "delete p;" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - equivalent to shared_ptr_thread_safe(p).swap(*this)
+ !*/
+
+ template<typename Y, typename D>
+ void reset(
+ Y* p,
+ const D& d
+ );
+ /*!
+ requires
+ - p is convertible to a T* type pointer
+ - D is copy constructable (and the copy constructor of D doesn't throw)
+ - p can be deleted by calling "d(p);" and doing so will not throw exceptions
+ - p != 0
+ ensures
+ - equivalent to shared_ptr_thread_safe(p,d).swap(*this)
+ !*/
+
+ T* get(
+ ) const;
+ /*!
+ ensures
+ - returns the stored pointer
+ !*/
+
+ T& operator*(
+ ) const;
+ /*!
+ requires
+ - get() != 0
+ ensures
+ - returns a reference to *get()
+ !*/
+
+ T* operator->(
+ ) const;
+ /*!
+ requires
+ - get() != 0
+ ensures
+ - returns get()
+ !*/
+
+ bool unique(
+ ) const;
+ /*!
+ ensures
+ - returns (use_count() == 1)
+ !*/
+
+ long use_count(
+ ) const;
+ /*!
+ ensures
+ - The number of shared_ptr_thread_safe objects, *this included, that share ownership with *this, or 0 when *this
+ is empty.
+ !*/
+
+ operator bool(
+ ) const;
+ /*!
+ ensures
+ - returns (get() != 0)
+ !*/
+
+ void swap(
+ shared_ptr_thread_safe& b
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template<typename T, typename U>
+ bool operator== (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ );
+ /*!
+ ensures
+ - returns a.get() == b.get()
+ !*/
+
+ template<typename T, typename U>
+ bool operator!= (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ ) { return a.get() != b.get(); }
+ /*!
+ ensures
+ - returns a.get() != b.get()
+ !*/
+
+ template<typename T, typename U>
+ bool operator< (
+ const shared_ptr_thread_safe<T>& a,
+ const shared_ptr_thread_safe<U>& b
+ );
+ /*!
+ ensures
+ - Defines an operator< on shared_ptr_thread_safe types appropriate for use in the associative
+ containers.
+ !*/
+
+ template<typename T>
+ void swap(
+ shared_ptr_thread_safe<T>& a,
+ shared_ptr_thread_safe<T>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> static_pointer_cast(
+ const shared_ptr_thread_safe<U>& r
+ );
+ /*!
+ - if (r.get() == 0) then
+ - returns shared_ptr_thread_safe<T>()
+ - else
+ - returns a shared_ptr_thread_safe<T> object that stores static_cast<T*>(r.get()) and shares
+ ownership with r.
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> const_pointer_cast(
+ const shared_ptr_thread_safe<U>& r
+ );
+ /*!
+ - if (r.get() == 0) then
+ - returns shared_ptr_thread_safe<T>()
+ - else
+ - returns a shared_ptr_thread_safe<T> object that stores const_cast<T*>(r.get()) and shares
+ ownership with r.
+ !*/
+
+ template<typename T, typename U>
+ shared_ptr_thread_safe<T> dynamic_pointer_cast(
+ const shared_ptr_thread_safe<U>& r
+ );
+ /*!
+ ensures
+ - if (dynamic_cast<T*>(r.get()) returns a nonzero value) then
+ - returns a shared_ptr_thread_safe<T> object that stores a copy of
+ dynamic_cast<T*>(r.get()) and shares ownership with r
+ - else
+ - returns an empty shared_ptr_thread_safe<T> object.
+ !*/
+
+ template<typename E, typename T, typename Y>
+ std::basic_ostream<E, T> & operator<< (
+ std::basic_ostream<E, T> & os,
+ const shared_ptr_thread_safe<Y>& p
+ );
+ /*!
+ ensures
+ - performs os << p.get()
+ - returns os
+ !*/
+
+ template<typename D, typename T>
+ D* get_deleter(
+ const shared_ptr_thread_safe<T>& p
+ );
+ /*!
+ ensures
+ - if (*this owns a deleter d of type cv-unqualified D) then
+ - returns &d
+ - else
+ - returns 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_
+
diff --git a/ml/dlib/dlib/smart_pointers/weak_ptr.h b/ml/dlib/dlib/smart_pointers/weak_ptr.h
new file mode 100644
index 000000000..7e3405678
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/weak_ptr.h
@@ -0,0 +1,225 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_WEAK_PTr_
+#define DLIB_WEAK_PTr_
+
+#include <algorithm>
+#include <memory>
+#include "shared_ptr.h"
+#include "../algs.h"
+#include "weak_ptr_abstract.h"
+
+namespace dlib {
+
+ template <
+ typename T
+ >
+ class weak_ptr
+ {
+
+ /*!
+ CONVENTION
+ - if (weak_node != 0) then
+ - data == valid pointer to shared data
+ - weak_node->ref_count == the number of weak_ptrs that reference this->data
+ - else
+ - data == 0
+
+ - expired() == ((weak_node == 0) || (weak_node->shared_node == 0))
+ - if (expired() == false) then
+ - use_count() == weak_node->shared_node->ref_count
+ - else
+ - use_count() == 0
+ !*/
+
+ public:
+ typedef T element_type;
+
+ weak_ptr(
+ ) : data(0), weak_node(0)
+ {
+ }
+
+ template<typename Y>
+ weak_ptr(
+ const shared_ptr<Y>& r
+ )
+ {
+ data = r.data;
+ if (r.shared_node)
+ {
+ if (r.shared_node->weak_node)
+ {
+ weak_node = r.shared_node->weak_node;
+ weak_node->ref_count += 1;
+ }
+ else
+ {
+ weak_node = new weak_ptr_node(r.shared_node);
+ r.shared_node->weak_node = weak_node;
+ }
+ }
+ else
+ {
+ weak_node = 0;
+ }
+ }
+
+ weak_ptr(
+ const weak_ptr& r
+ )
+ {
+ data = r.data;
+ weak_node = r.weak_node;
+ if (weak_node)
+ weak_node->ref_count += 1;
+ }
+
+ template<typename Y>
+ weak_ptr(
+ const weak_ptr<Y>& r
+ )
+ {
+ data = r.data;
+ weak_node = r.weak_node;
+ if (weak_node)
+ weak_node->ref_count += 1;
+ }
+
+ ~weak_ptr(
+ )
+ {
+ if (weak_node)
+ {
+ // make note that this weak_ptr is being destroyed
+ weak_node->ref_count -= 1;
+
+ // if this is the last weak_ptr then we should clean up our stuff
+ if (weak_node->ref_count == 0)
+ {
+ if (expired() == false)
+ weak_node->shared_node->weak_node = 0;
+ delete weak_node;
+ }
+ }
+ }
+
+ weak_ptr& operator= (
+ const weak_ptr& r
+ )
+ {
+ weak_ptr(r).swap(*this);
+ return *this;
+ }
+
+ template<typename Y>
+ weak_ptr& operator= (
+ const weak_ptr<Y>& r
+ )
+ {
+ weak_ptr(r).swap(*this);
+ return *this;
+ }
+
+ template<typename Y>
+ weak_ptr& operator=(
+ const shared_ptr<Y>& r
+ )
+ {
+ weak_ptr(r).swap(*this);
+ return *this;
+ }
+
+ long use_count(
+ ) const
+ {
+ if (expired())
+ return 0;
+ else
+ return weak_node->shared_node->ref_count;
+ }
+
+ bool expired() const { return weak_node == 0 || weak_node->shared_node == 0; }
+
+ shared_ptr<T> lock(
+ ) const
+ {
+ if (expired())
+ return shared_ptr<T>();
+ else
+ return shared_ptr<T>(*this);
+ }
+
+ void reset(
+ )
+ {
+ weak_ptr().swap(*this);
+ }
+
+ void swap(
+ weak_ptr<T>& b
+ )
+ {
+ std::swap(data, b.data);
+ std::swap(weak_node, b.weak_node);
+ }
+
+ template <typename Y>
+ bool _private_less (
+ const weak_ptr<Y>& rhs
+ ) const
+ {
+ if (expired())
+ {
+ if (rhs.expired())
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+ else
+ {
+ if (rhs.expired())
+ {
+ return false;
+ }
+ else
+ {
+ // in this case they have both not expired so lets
+ // compare the shared_node pointers
+ return (weak_node->shared_node) < (rhs.weak_node->shared_node);
+ }
+ }
+ }
+
+ private:
+
+ template <typename Y> friend class shared_ptr;
+ template <typename Y> friend class weak_ptr;
+
+ T* data;
+ weak_ptr_node* weak_node;
+ };
+
+ template<typename T, typename U>
+ bool operator< (
+ const weak_ptr<T>& a,
+ const weak_ptr<U>& b
+ )
+ {
+ return a._private_less(b);
+ }
+
+ template<typename T>
+ void swap(
+ weak_ptr<T>& a,
+ weak_ptr<T> & b
+ ) { a.swap(b); }
+}
+
+#endif // DLIB_WEAK_PTr_
+
+
diff --git a/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h b/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h
new file mode 100644
index 000000000..549684e3a
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h
@@ -0,0 +1,193 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_WEAK_PTr_ABSTRACT_
+#ifdef DLIB_WEAK_PTr_ABSTRACT_
+
+#include "shared_ptr_abstract.h"
+
+namespace dlib {
+
+ template <
+ typename T
+ >
+ class weak_ptr
+ {
+
+ /*!
+ INITIAL VALUE
+ defined by constructor
+
+ WHAT THIS OBJECT REPRESENTS
+ The weak_ptr class template stores a weak reference to an object that is
+ already managed by a shared_ptr. To access the object, a weak_ptr can
+ be converted to a shared_ptr using the member function lock().
+
+ This is an implementation of the std::tr1::weak_ptr template from the
+ document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++
+ Library Extensions. The only deviation from that document is that this
+ shared_ptr is declared inside the dlib namespace rather than std::tr1.
+ !*/
+
+ public:
+ typedef T element_type;
+
+ weak_ptr(
+ );
+ /*!
+ ensures
+ - #use_count() == 0
+ - creates an empty weak_ptr
+ !*/
+
+ template<typename Y>
+ weak_ptr(
+ const shared_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* must be convertible to T*
+ ensures
+ - if (r is empty) then
+ - constructs an empty weak_ptr object
+ - else
+ - constructs a weak_ptr object that shares ownership with r and
+ stores a copy of the pointer stored in r.
+ - #use_count() == #r.use_count()
+ !*/
+
+ weak_ptr(
+ const weak_ptr& r
+ );
+ /*!
+ ensures
+ - if (r is empty) then
+ - constructs an empty weak_ptr object
+ - else
+ - constructs a weak_ptr object that shares ownership with r and
+ stores a copy of the pointer stored in r.
+ - #use_count() == #r.use_count()
+ !*/
+
+ template<typename Y>
+ weak_ptr(
+ const weak_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* must be convertible to T*
+ ensures
+ - if (r is empty) then
+ - constructs an empty weak_ptr object
+ - else
+ - constructs a weak_ptr object that shares ownership with r and
+ stores a copy of the pointer stored in r.
+ - #use_count() == #r.use_count()
+ !*/
+
+ ~weak_ptr(
+ );
+ /*!
+ ensures
+ - destroys this weak_ptr object but has no effect on the object its
+ stored pointer points to.
+ !*/
+
+ weak_ptr& operator= (
+ const weak_ptr& r
+ );
+ /*!
+ ensures
+ - equivalent to weak_ptr(r).swap(*this)
+ !*/
+
+ template<typename Y>
+ weak_ptr& operator= (
+ const weak_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* must be convertible to T*
+ ensures
+ - equivalent to weak_ptr(r).swap(*this)
+ !*/
+
+ template<typename Y>
+ weak_ptr& operator=(
+ const shared_ptr<Y>& r
+ );
+ /*!
+ requires
+ - Y* must be convertible to T*
+ ensures
+ - equivalent to weak_ptr(r).swap(*this)
+ !*/
+
+ long use_count(
+ ) const;
+ /*!
+ ensures
+ - if (*this is empty) then
+ - returns 0
+ - else
+ - returns the number of shared_ptr instances that share ownership
+ with *this
+ !*/
+
+ bool expired(
+ ) const;
+ /*!
+ ensures
+ - returns (use_count() == 0)
+ !*/
+
+ shared_ptr<T> lock(
+ ) const;
+ /*!
+ ensures
+ - if (expired()) then
+ - returns shared_ptr<T>()
+ - else
+ - returns shared_ptr<T>(*this)
+ !*/
+
+ void reset(
+ );
+ /*!
+ ensures
+ - equivalent to weak_ptr().swap(*this)
+ !*/
+
+ void swap(
+ weak_ptr<T>& b
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ template<typename T, typename U>
+ bool operator< (
+ const weak_ptr<T>& a,
+ const weak_ptr<U>& b
+ );
+ /*!
+ ensures
+ - Defines an operator< on shared_ptr types appropriate for use in the associative
+ containers.
+ !*/
+
+ template<typename T>
+ void swap(
+ weak_ptr<T>& a,
+ weak_ptr<T> & b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+}
+
+#endif // DLIB_WEAK_PTr_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/smart_pointers_thread_safe.h b/ml/dlib/dlib/smart_pointers_thread_safe.h
new file mode 100644
index 000000000..e00141f08
--- /dev/null
+++ b/ml/dlib/dlib/smart_pointers_thread_safe.h
@@ -0,0 +1,21 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SMART_POINTERs_THREAD_SAFE_H_
+#define DLIB_SMART_POINTERs_THREAD_SAFE_H_
+
+// This is legacy smart pointer code that will likely to stop working under default
+// compiler flags when C++17 becomes the default standard in the compilers.
+// Please consider migrating your code to contemporary smart pointers from C++
+// standard library. The warning below will help to detect if the deprecated code
+// was included from library's clients.
+#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \
+ (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4)))
+#pragma GCC warning "smart_pointers_thread_safe.h is included which will fail to compile under C++17"
+#endif
+
+#include "smart_pointers/shared_ptr_thread_safe.h"
+
+#endif // DLIB_SMART_POINTERs_THREAD_SAFE_H_
+
+
+
diff --git a/ml/dlib/dlib/sockets.h b/ml/dlib/dlib/sockets.h
new file mode 100644
index 000000000..e253587ed
--- /dev/null
+++ b/ml/dlib/dlib/sockets.h
@@ -0,0 +1,20 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETs_
+#define DLIB_SOCKETs_
+
+#include "platform.h"
+
+
+#ifdef WIN32
+#include "sockets/windows.h"
+#endif
+
+#ifndef WIN32
+#include "sockets/posix.h"
+#endif
+
+#include "sockets/sockets_extensions.h"
+
+#endif // DLIB_SOCKETs_
+
diff --git a/ml/dlib/dlib/sockets/posix.h b/ml/dlib/dlib/sockets/posix.h
new file mode 100644
index 000000000..d736a20d4
--- /dev/null
+++ b/ml/dlib/dlib/sockets/posix.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEl_1_
+#include "sockets_kernel_2.h"
+#endif
+
diff --git a/ml/dlib/dlib/sockets/sockets_extensions.cpp b/ml/dlib/dlib/sockets/sockets_extensions.cpp
new file mode 100644
index 000000000..be08c1998
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_extensions.cpp
@@ -0,0 +1,341 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_EXTENSIONs_CPP
+#define DLIB_SOCKETS_EXTENSIONs_CPP
+
+#include <string>
+#include <sstream>
+#include "../sockets.h"
+#include "../error.h"
+#include "sockets_extensions.h"
+#include "../timer.h"
+#include "../algs.h"
+#include "../timeout.h"
+#include "../misc_api.h"
+#include "../serialize.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ network_address::
+ network_address(
+ const std::string& full_address
+ )
+ {
+ std::istringstream sin(full_address);
+ sin >> *this;
+ if (!sin || sin.peek() != EOF)
+ throw invalid_network_address("invalid network address: " + full_address);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const network_address& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.host_address, out);
+ serialize(item.port, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void deserialize(
+ network_address& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.host_address, in);
+ deserialize(item.port, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const network_address& item
+ )
+ {
+ out << item.host_address << ":" << item.port;
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& operator>> (
+ std::istream& in,
+ network_address& item
+ )
+ {
+ std::string temp;
+ in >> temp;
+
+ std::string::size_type pos = temp.find_last_of(":");
+ if (pos == std::string::npos)
+ {
+ in.setstate(std::ios::badbit);
+ return in;
+ }
+
+ item.host_address = temp.substr(0, pos);
+ try
+ {
+ item.port = sa = temp.substr(pos+1);
+ } catch (std::exception& )
+ {
+ in.setstate(std::ios::badbit);
+ return in;
+ }
+
+
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port
+ )
+ {
+ std::string ip;
+ connection* con;
+ if (is_ip_address(host_or_ip))
+ {
+ ip = host_or_ip;
+ }
+ else
+ {
+ if( hostname_to_ip(host_or_ip,ip))
+ throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()");
+ }
+
+ if(create_connection(con,port,ip))
+ {
+ std::ostringstream sout;
+ sout << "unable to connect to '" << host_or_ip << ":" << port << "'";
+ throw socket_error(sout.str());
+ }
+
+ return con;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const network_address& addr
+ )
+ {
+ return connect(addr.host_address, addr.port);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace connect_timeout_helpers
+ {
+ mutex connect_mutex;
+ signaler connect_signaler(connect_mutex);
+ timestamper ts;
+ long outstanding_connects = 0;
+
+ struct thread_data
+ {
+ std::string host_or_ip;
+ unsigned short port;
+ connection* con;
+ bool connect_ended;
+ bool error_occurred;
+ };
+
+ void thread(void* param)
+ {
+ thread_data p = *static_cast<thread_data*>(param);
+ try
+ {
+ p.con = connect(p.host_or_ip, p.port);
+ }
+ catch (...)
+ {
+ p.error_occurred = true;
+ }
+
+ auto_mutex M(connect_mutex);
+ // report the results back to the connect() call that spawned this
+ // thread.
+ static_cast<thread_data*>(param)->con = p.con;
+ static_cast<thread_data*>(param)->error_occurred = p.error_occurred;
+ connect_signaler.broadcast();
+
+ // wait for the call to connect() that spawned this thread to terminate
+ // before we delete the thread_data struct.
+ while (static_cast<thread_data*>(param)->connect_ended == false)
+ connect_signaler.wait();
+
+ connect_signaler.broadcast();
+ --outstanding_connects;
+ delete static_cast<thread_data*>(param);
+ }
+ }
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port,
+ unsigned long timeout
+ )
+ {
+ using namespace connect_timeout_helpers;
+
+ auto_mutex M(connect_mutex);
+
+ const uint64 end_time = ts.get_timestamp() + timeout*1000;
+
+
+ // wait until there are less than 100 outstanding connections
+ while (outstanding_connects > 100)
+ {
+ uint64 cur_time = ts.get_timestamp();
+ if (end_time > cur_time)
+ {
+ timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
+ }
+ else
+ {
+ throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
+ }
+
+ connect_signaler.wait_or_timeout(timeout);
+ }
+
+
+ thread_data* data = new thread_data;
+ data->host_or_ip = host_or_ip.c_str();
+ data->port = port;
+ data->con = 0;
+ data->connect_ended = false;
+ data->error_occurred = false;
+
+
+ if (create_new_thread(thread, data) == false)
+ {
+ delete data;
+ throw socket_error("unable to connect to '" + host_or_ip);
+ }
+
+ ++outstanding_connects;
+
+ // wait until we have a connection object
+ while (data->con == 0)
+ {
+ uint64 cur_time = ts.get_timestamp();
+ if (end_time > cur_time && data->error_occurred == false)
+ {
+ timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
+ }
+ else
+ {
+ // let the thread know that it should terminate
+ data->connect_ended = true;
+ connect_signaler.broadcast();
+ if (data->error_occurred)
+ throw socket_error("unable to connect to '" + host_or_ip);
+ else
+ throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
+ }
+
+ connect_signaler.wait_or_timeout(timeout);
+ }
+
+ // let the thread know that it should terminate
+ data->connect_ended = true;
+ connect_signaler.broadcast();
+ return data->con;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_ip_address (
+ std::string ip
+ )
+ {
+ for (std::string::size_type i = 0; i < ip.size(); ++i)
+ {
+ if (ip[i] == '.')
+ ip[i] = ' ';
+ }
+ std::istringstream sin(ip);
+
+ bool bad = false;
+ int num;
+ for (int i = 0; i < 4; ++i)
+ {
+ sin >> num;
+ if (!sin || num < 0 || num > 255)
+ {
+ bad = true;
+ break;
+ }
+ }
+
+ if (sin.get() != EOF)
+ bad = true;
+
+ return !bad;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ connection* con,
+ unsigned long timeout
+ )
+ {
+ std::unique_ptr<connection> ptr(con);
+ close_gracefully(ptr,timeout);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ std::unique_ptr<connection>& con,
+ unsigned long timeout
+ )
+ {
+ if (!con)
+ return;
+
+ if(con->shutdown_outgoing())
+ {
+ // there was an error so just close it now and return
+ con.reset();
+ return;
+ }
+
+ try
+ {
+ dlib::timeout t(*con,&connection::shutdown,timeout);
+
+ char junk[100];
+ // wait for the other end to close their side
+ while (con->read(junk,sizeof(junk)) > 0) ;
+ }
+ catch (...)
+ {
+ con.reset();
+ throw;
+ }
+
+ con.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOCKETS_EXTENSIONs_CPP
+
+
diff --git a/ml/dlib/dlib/sockets/sockets_extensions.h b/ml/dlib/dlib/sockets/sockets_extensions.h
new file mode 100644
index 000000000..9faa34e01
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_extensions.h
@@ -0,0 +1,151 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_EXTENSIONs_
+#define DLIB_SOCKETS_EXTENSIONs_
+
+#include <iosfwd>
+#include <memory>
+#include <string>
+
+#include "../sockets.h"
+#include "../smart_pointers/scoped_ptr.h"
+#include "sockets_extensions_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_network_address : public dlib::error
+ {
+ public:
+ invalid_network_address(const std::string& msg) : dlib::error(msg) {};
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct network_address
+ {
+ network_address() : port(0){}
+
+ network_address(
+ const std::string& full_address
+ );
+
+ network_address (
+ const char* full_address
+ )
+ {
+ *this = network_address(std::string(full_address));
+ }
+
+ network_address(
+ const std::string& host_address_,
+ const unsigned short port_
+ ) : host_address(host_address_), port(port_) {}
+
+ std::string host_address;
+ unsigned short port;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator < (
+ const network_address& a,
+ const network_address& b
+ )
+ {
+ if (a.host_address < b.host_address)
+ return true;
+ else if (a.host_address > b.host_address)
+ return false;
+ else if (a.port < b.port)
+ return true;
+ else
+ return false;
+ }
+
+ inline bool operator== (
+ const network_address& a,
+ const network_address& b
+ ) { return a.host_address == b.host_address && a.port == b.port; }
+
+ inline bool operator != (
+ const network_address& a,
+ const network_address& b
+ ) { return !(a == b); }
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const network_address& item,
+ std::ostream& out
+ );
+
+ void deserialize(
+ network_address& item,
+ std::istream& in
+ );
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const network_address& item
+ );
+
+ std::istream& operator>> (
+ std::istream& in,
+ network_address& item
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const network_address& addr
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port,
+ unsigned long timeout
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_ip_address (
+ std::string ip
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ connection* con,
+ unsigned long timeout = 500
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ std::unique_ptr<connection>& con,
+ unsigned long timeout = 500
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "sockets_extensions.cpp"
+#endif
+
+#endif // DLIB_SOCKETS_EXTENSIONs_
+
diff --git a/ml/dlib/dlib/sockets/sockets_extensions_abstract.h b/ml/dlib/dlib/sockets/sockets_extensions_abstract.h
new file mode 100644
index 000000000..194c22ab2
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_extensions_abstract.h
@@ -0,0 +1,300 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SOCKETS_EXTENSIONs_ABSTRACT_
+#ifdef DLIB_SOCKETS_EXTENSIONs_ABSTRACT_
+
+#include <memory>
+#include <string>
+
+#include "sockets_kernel_abstract.h"
+#include "../error.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_network_address : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception thrown by network_address's constructor if the
+ input is invalid.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct network_address
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is simply a container for two things:
+ - A host machine address which is either an IP address or DNS name
+ for a machine.
+ - A port number.
+
+ Together, these things define a machine and port on that machine.
+ !*/
+
+ network_address(
+ );
+ /*!
+ ensures
+ - host_address == ""
+ - #port == 0
+ !*/
+
+ network_address(
+ const std::string& full_address
+ );
+ /*!
+ ensures
+ - interprets full_address as a network address of the form:
+ host_address:port
+ and assigns each part into #host_address and #port. For example,
+ network_address("localhost:80") would result in a network_address
+ object where host_address was "localhost" and port was 80.
+ throws
+ - invalid_network_address
+ This exception is thrown if the full_address string can't be
+ interpreted as a valid network address.
+ !*/
+
+ network_address (
+ const char* full_address
+ );
+ /*!
+ requires
+ - full_address == a valid pointer to a null terminated string
+ ensures
+ - Invoking this constructor is equivalent to performing
+ network_address(std::string(full_address))
+ !*/
+
+ network_address(
+ const std::string& host_address_,
+ const unsigned short port_
+ );
+ /*!
+ ensures
+ - #host_address == host_address_
+ - #port == port_
+ !*/
+
+
+ std::string host_address;
+ unsigned short port;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool operator < (
+ const network_address& a,
+ const network_address& b
+ );
+ /*!
+ ensures
+ - provides a total ordering over network_address objects so you can use them in
+ the standard associative containers. The ordering is defined such that if
+ you sorted network addresses they would sort first on the host_address string
+ and then, for network_address objects with equal host_address, they would
+ sort on the port number
+ !*/
+
+ inline bool operator== (
+ const network_address& a,
+ const network_address& b
+ );
+ /*!
+ ensures
+ - returns true if a and b contain exactly the same address and false otherwise.
+ That is, the following must be true for this function to return true:
+ - a.host_address == b.host_address
+ - a.port == b.port
+ Note that this means that two addresses which are logically equivalent but
+ written differently will not compare equal. For example, suppose example.com
+ has the IP address 10.1.1.1. Then network_address("10.1.1.1:80") and
+ network_address("example.com:80") really refer to the same network resource
+ but will nevertheless not compare equal since.
+ !*/
+
+ inline bool operator != (
+ const network_address& a,
+ const network_address& b
+ );
+ /*!
+ ensures
+ - returns !(a == b)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const network_address& item,
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - provides serialization support
+ !*/
+
+ void deserialize(
+ network_address& item,
+ std::istream& in
+ );
+ /*!
+ ensures
+ - provides deserialization support
+ !*/
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const network_address& item
+ );
+ /*!
+ ensures
+ - writes the given network_address to the output stream. The format is the
+ host_address, then a colon, then the port number. So for example:
+ cout << network_address("localhost", 80);
+ would print:
+ localhost:80
+ - returns #out
+ !*/
+
+ std::istream& operator>> (
+ std::istream& in,
+ network_address& item
+ );
+ /*!
+ ensures
+ - reads a network_address from the given input stream. The expected format is
+ the same as the one used to print them by the above operator<<() routine.
+ - returns #in
+ - if (there is an error reading the network_address) then
+ - #in.good() == false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port
+ );
+ /*!
+ ensures
+ - returns a connection object that is connected to the given host at the
+ given port
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const network_address& addr
+ );
+ /*!
+ ensures
+ - returns connect(addr.host_address, addr_port);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port,
+ unsigned long timeout
+ );
+ /*!
+ ensures
+ - returns a connection object that is connected to the given host at the
+ given port.
+ - blocks for at most timeout milliseconds
+ throws
+ - dlib::socket_error
+ This exception is thrown if there is some problem that prevents us from
+ creating the connection or if timeout milliseconds elapses before the
+ connect is successful.
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+ bool is_ip_address (
+ std::string ip
+ );
+ /*!
+ ensures
+ - if (ip is a valid ip address) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ connection* con,
+ unsigned long timeout = 500
+ );
+ /*!
+ requires
+ - con == a valid pointer to a connection object or 0
+ ensures
+ - This function does nothing if con == 0, otherwise it performs the following:
+ - performs a graceful close of the given connection and if it takes longer
+ than timeout milliseconds to complete then forces the connection closed.
+ - Specifically, a graceful close means that the outgoing part of con is
+ closed (a FIN is sent) and then we wait for the other end to to close
+ their end of the connection. This way any data still on its way to
+ the other end of the connection will be received properly.
+ - This function will block until the graceful close is completed or we
+ timeout.
+ - calls "delete con;". Thus con is no longer a valid pointer after this
+ function has finished.
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown con will still be closed via
+ "delete con;"
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ std::unique_ptr<connection>& con,
+ unsigned long timeout = 500
+ );
+ /*!
+ requires
+ - con == a valid pointer to a connection object or con.get() == 0
+ ensures
+ - This function does nothing if con.get() == 0, otherwise it performs the
+ following:
+ - performs a graceful close of the given connection and if it takes longer
+ than timeout milliseconds to complete then forces the connection closed.
+ - Specifically, a graceful close means that the outgoing part of con is
+ closed (a FIN is sent) and then we wait for the other end to to close
+ their end of the connection. This way any data still on its way to
+ the other end of the connection will be received properly.
+ - This function will block until the graceful close is completed or we
+ timeout.
+ - #con.get() == 0. Thus con is no longer a valid pointer after this
+ function has finished.
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown con will still be closed and
+ deleted (i.e. #con.get() == 0).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOCKETS_EXTENSIONs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/sockets/sockets_kernel_1.cpp b/ml/dlib/dlib/sockets/sockets_kernel_1.cpp
new file mode 100644
index 000000000..55e39569f
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_kernel_1.cpp
@@ -0,0 +1,979 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEL_1_CPp_
+#define DLIB_SOCKETS_KERNEL_1_CPp_
+#include "../platform.h"
+
+#ifdef WIN32
+
+#include <winsock2.h>
+
+#ifndef _WINSOCKAPI_
+#define _WINSOCKAPI_ /* Prevent inclusion of winsock.h in windows.h */
+#endif
+
+#include "../windows_magic.h"
+
+#include "sockets_kernel_1.h"
+
+#include <windows.h>
+
+#ifndef NI_MAXHOST
+#define NI_MAXHOST 1025
+#endif
+
+
+// tell visual studio to link to the libraries we need if we are
+// in fact using visual studio
+#ifdef _MSC_VER
+#pragma comment (lib, "ws2_32.lib")
+#endif
+
+#include "../assert.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class SOCKET_container
+ {
+ /*!
+ This object is just a wrapper around the SOCKET type. It exists
+ so that we can #include the windows.h and Winsock2.h header files
+ in this cpp file and not at all in the header file.
+ !*/
+ public:
+ SOCKET_container (
+ SOCKET s = INVALID_SOCKET
+ ) : val(s) {}
+
+ SOCKET val;
+ operator SOCKET&() { return val; }
+
+ SOCKET_container& operator= (
+ const SOCKET& s
+ ) { val = s; return *this; }
+
+ bool operator== (
+ const SOCKET& s
+ ) const { return s == val; }
+ };
+
+// ----------------------------------------------------------------------------------------
+// stuff to ensure that WSAStartup() is always called before any sockets stuff is needed
+
+ namespace sockets_kernel_1_mutex
+ {
+ mutex startup_lock;
+ }
+
+ class sockets_startupdown
+ {
+ public:
+ sockets_startupdown();
+ ~sockets_startupdown() { WSACleanup( ); }
+
+ };
+ sockets_startupdown::sockets_startupdown (
+ )
+ {
+ WSADATA wsaData;
+ WSAStartup (MAKEWORD(2,0), &wsaData);
+ }
+
+ void sockets_startup()
+ {
+ // mutex crap to make this function thread-safe
+ sockets_kernel_1_mutex::startup_lock.lock();
+ static sockets_startupdown a;
+ sockets_kernel_1_mutex::startup_lock.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // lookup functions
+
+ int
+ get_local_hostname (
+ std::string& hostname
+ )
+ {
+ // ensure that WSAStartup has been called and WSACleanup will eventually
+ // be called when program ends
+ sockets_startup();
+
+ try
+ {
+
+ char temp[NI_MAXHOST];
+ if (gethostname(temp,NI_MAXHOST) == SOCKET_ERROR )
+ {
+ return OTHER_ERROR;
+ }
+
+ hostname = temp;
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// -----------------
+
+ int
+ hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n
+ )
+ {
+ // ensure that WSAStartup has been called and WSACleanup will eventually
+ // be called when program ends
+ sockets_startup();
+
+ try
+ {
+ // lock this mutex since gethostbyname isn't really thread safe
+ auto_mutex M(sockets_kernel_1_mutex::startup_lock);
+
+ // if no hostname was given then return error
+ if ( hostname.empty())
+ return OTHER_ERROR;
+
+ hostent* address;
+ address = gethostbyname(hostname.c_str());
+
+ if (address == 0)
+ {
+ return OTHER_ERROR;
+ }
+
+ // find the nth address
+ in_addr* addr = reinterpret_cast<in_addr*>(address->h_addr_list[0]);
+ for (int i = 1; i <= n; ++i)
+ {
+ addr = reinterpret_cast<in_addr*>(address->h_addr_list[i]);
+
+ // if there is no nth address then return error
+ if (addr == 0)
+ return OTHER_ERROR;
+ }
+
+ char* resolved_ip = inet_ntoa(*addr);
+
+ // check if inet_ntoa returned an error
+ if (resolved_ip == NULL)
+ {
+ return OTHER_ERROR;
+ }
+
+ ip.assign(resolved_ip);
+
+ }
+ catch(...)
+ {
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// -----------------
+
+ int
+ ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ )
+ {
+ // ensure that WSAStartup has been called and WSACleanup will eventually
+ // be called when program ends
+ sockets_startup();
+
+ try
+ {
+ // lock this mutex since gethostbyaddr isn't really thread safe
+ auto_mutex M(sockets_kernel_1_mutex::startup_lock);
+
+ // if no ip was given then return error
+ if (ip.empty())
+ return OTHER_ERROR;
+
+ hostent* address;
+ unsigned long ipnum = inet_addr(ip.c_str());
+
+ // if inet_addr couldn't convert ip then return an error
+ if (ipnum == INADDR_NONE)
+ {
+ return OTHER_ERROR;
+ }
+ address = gethostbyaddr(reinterpret_cast<char*>(&ipnum),4,AF_INET);
+
+ // check if gethostbyaddr returned an error
+ if (address == 0)
+ {
+ return OTHER_ERROR;
+ }
+ hostname.assign(address->h_name);
+
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+ return 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // connection object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ connection::
+ connection(
+ SOCKET_container sock,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ ) :
+ user_data(0),
+ connection_socket(*(new SOCKET_container())),
+ connection_foreign_port(foreign_port),
+ connection_foreign_ip(foreign_ip),
+ connection_local_port(local_port),
+ connection_local_ip(local_ip),
+ sd(false),
+ sdo(false),
+ sdr(0)
+ {
+ connection_socket = sock;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ connection::
+ ~connection (
+ )
+ {
+ if (connection_socket != INVALID_SOCKET)
+ closesocket(connection_socket);
+ delete &connection_socket;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int connection::
+ disable_nagle()
+ {
+ int flag = 1;
+ int status = setsockopt( connection_socket, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag) );
+
+ if (status == SOCKET_ERROR)
+ return OTHER_ERROR;
+ else
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ write (
+ const char* buf,
+ long num
+ )
+ {
+ const long old_num = num;
+ long status;
+ const long max_send_length = 1024*1024*100;
+ while (num > 0)
+ {
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_send_length, num);
+ if ( (status = send(connection_socket,buf,length,0)) == SOCKET_ERROR)
+ {
+ if (sdo_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ num -= status;
+ buf += status;
+ }
+ return old_num;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ read (
+ char* buf,
+ long num
+ )
+ {
+ const long max_recv_length = 1024*1024*100;
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_recv_length, num);
+ long status = recv(connection_socket,buf,length,0);
+ if (status == SOCKET_ERROR)
+ {
+ // if this error is the result of a shutdown call then return SHUTDOWN
+ if (sd_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ else if (status == 0 && sd_called())
+ {
+ return SHUTDOWN;
+ }
+ return status;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ read (
+ char* buf,
+ long num,
+ unsigned long timeout
+ )
+ {
+ if (readable(timeout) == false)
+ return TIMEOUT;
+
+ const long max_recv_length = 1024*1024*100;
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_recv_length, num);
+ long status = recv(connection_socket,buf,length,0);
+ if (status == SOCKET_ERROR)
+ {
+ // if this error is the result of a shutdown call then return SHUTDOWN
+ if (sd_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ else if (status == 0 && sd_called())
+ {
+ return SHUTDOWN;
+ }
+ return status;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool connection::
+ readable (
+ unsigned long timeout
+ ) const
+ {
+ fd_set read_set;
+ // initialize read_set
+ FD_ZERO(&read_set);
+
+ // add the listening socket to read_set
+ FD_SET(connection_socket, &read_set);
+
+ // setup a timeval structure
+ timeval time_to_wait;
+ time_to_wait.tv_sec = static_cast<long>(timeout/1000);
+ time_to_wait.tv_usec = static_cast<long>((timeout%1000)*1000);
+
+ // wait on select
+ int status = select(0,&read_set,0,0,&time_to_wait);
+
+ // if select timed out or there was an error
+ if (status <= 0)
+ return false;
+
+ // data is ready to be read
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int connection::
+ shutdown_outgoing (
+ )
+ {
+ sd_mutex.lock();
+ if (sdo || sd)
+ {
+ sd_mutex.unlock();
+ return sdr;
+ }
+ sdo = true;
+ sdr = ::shutdown(connection_socket,SD_SEND);
+
+ // convert -1 error code into the OTHER_ERROR error code
+ if (sdr == -1)
+ sdr = OTHER_ERROR;
+
+ int temp = sdr;
+
+ sd_mutex.unlock();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int connection::
+ shutdown (
+ )
+ {
+ sd_mutex.lock();
+ if (sd)
+ {
+ sd_mutex.unlock();
+ return sdr;
+ }
+ sd = true;
+ SOCKET stemp = connection_socket;
+ connection_socket = INVALID_SOCKET;
+ sdr = closesocket(stemp);
+
+ // convert SOCKET_ERROR error code into the OTHER_ERROR error code
+ if (sdr == SOCKET_ERROR)
+ sdr = OTHER_ERROR;
+
+ int temp = sdr;
+
+ sd_mutex.unlock();
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ connection::socket_descriptor_type connection::
+ get_socket_descriptor (
+ ) const
+ {
+ return connection_socket.val;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // listener object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ listener::
+ listener(
+ SOCKET_container sock,
+ unsigned short port,
+ const std::string& ip
+ ) :
+ listening_socket(*(new SOCKET_container)),
+ listening_port(port),
+ listening_ip(ip),
+ inaddr_any(listening_ip.empty())
+ {
+ listening_socket = sock;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ listener::
+ ~listener (
+ )
+ {
+ closesocket(listening_socket);
+ delete &listening_socket;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int listener::
+ accept (
+ std::unique_ptr<connection>& new_connection,
+ unsigned long timeout
+ )
+ {
+ new_connection.reset(0);
+ connection* con;
+ int status = this->accept(con, timeout);
+
+ if (status == 0)
+ new_connection.reset(con);
+
+ return status;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int listener::
+ accept (
+ connection*& new_connection,
+ unsigned long timeout
+ )
+ {
+ SOCKET incoming;
+ sockaddr_in incomingAddr;
+ int length = sizeof(sockaddr_in);
+
+ // implement timeout with select if timeout is > 0
+ if (timeout > 0)
+ {
+ fd_set read_set;
+ // initialize read_set
+ FD_ZERO(&read_set);
+
+ // add the listening socket to read_set
+ FD_SET(listening_socket, &read_set);
+
+ // setup a timeval structure
+ timeval time_to_wait;
+ time_to_wait.tv_sec = static_cast<long>(timeout/1000);
+ time_to_wait.tv_usec = static_cast<long>((timeout%1000)*1000);
+
+
+ // wait on select
+ int status = select(0,&read_set,0,0,&time_to_wait);
+
+ // if select timed out
+ if (status == 0)
+ return TIMEOUT;
+
+ // if select returned an error
+ if (status == SOCKET_ERROR)
+ return OTHER_ERROR;
+
+ }
+
+
+ // call accept to get a new connection
+ incoming=::accept(listening_socket,reinterpret_cast<sockaddr*>(&incomingAddr),&length);
+
+ // if there was an error return OTHER_ERROR
+ if ( incoming == INVALID_SOCKET )
+ return OTHER_ERROR;
+
+
+ // get the port of the foreign host into foreign_port
+ int foreign_port = ntohs(incomingAddr.sin_port);
+
+ // get the IP of the foreign host into foreign_ip
+ std::string foreign_ip;
+ {
+ char* foreign_ip_temp = inet_ntoa(incomingAddr.sin_addr);
+
+ // check if inet_ntoa() returned an error
+ if (foreign_ip_temp == NULL)
+ {
+ closesocket(incoming);
+ return OTHER_ERROR;
+ }
+
+ foreign_ip.assign(foreign_ip_temp);
+ }
+
+
+ // get the local ip
+ std::string local_ip;
+ if (inaddr_any == true)
+ {
+ sockaddr_in local_info;
+ length = sizeof(sockaddr_in);
+ // get the local sockaddr_in structure associated with this new connection
+ if ( getsockname (
+ incoming,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == SOCKET_ERROR
+ )
+ { // an error occurred
+ closesocket(incoming);
+ return OTHER_ERROR;
+ }
+ char* temp = inet_ntoa(local_info.sin_addr);
+
+ // check if inet_ntoa() returned an error
+ if (temp == NULL)
+ {
+ closesocket(incoming);
+ return OTHER_ERROR;
+ }
+ local_ip.assign(temp);
+ }
+ else
+ {
+ local_ip = listening_ip;
+ }
+
+
+ // set the SO_OOBINLINE option
+ int flag_value = 1;
+ if (setsockopt(incoming,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const char*>(&flag_value),sizeof(int)) == SOCKET_ERROR )
+ {
+ closesocket(incoming);
+ return OTHER_ERROR;
+ }
+
+
+ // make a new connection object for this new connection
+ try
+ {
+ new_connection = new connection (
+ incoming,
+ foreign_port,
+ foreign_ip,
+ listening_port,
+ local_ip
+ );
+ }
+ catch (...) { closesocket(incoming); return OTHER_ERROR; }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // socket creation functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ int create_listener (
+ std::unique_ptr<listener>& new_listener,
+ unsigned short port,
+ const std::string& ip
+ )
+ {
+ new_listener.reset();
+ listener* temp;
+ int status = create_listener(temp,port,ip);
+
+ if (status == 0)
+ new_listener.reset(temp);
+
+ return status;
+ }
+
+ int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip
+ )
+ {
+ // ensure that WSAStartup has been called and WSACleanup will eventually
+ // be called when program ends
+ sockets_startup();
+
+ sockaddr_in sa; // local socket structure
+ ZeroMemory(&sa,sizeof(sockaddr_in)); // initialize sa
+
+ SOCKET sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket
+
+ // if socket() returned an error then return OTHER_ERROR
+ if (sock == INVALID_SOCKET )
+ {
+ return OTHER_ERROR;
+ }
+
+ // set the local socket structure
+ sa.sin_family = AF_INET;
+ sa.sin_port = htons(port);
+ if (ip.empty())
+ {
+ // if the listener should listen on any IP
+ sa.sin_addr.S_un.S_addr = htons(INADDR_ANY);
+ }
+ else
+ {
+ // if there is a specific ip to listen on
+ sa.sin_addr.S_un.S_addr = inet_addr(ip.c_str());
+ // if inet_addr couldn't convert the ip then return an error
+ if ( sa.sin_addr.S_un.S_addr == INADDR_NONE )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ }
+
+ // set the SO_REUSEADDR option
+ int flag_value = 1;
+ setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,reinterpret_cast<const char*>(&flag_value),sizeof(int));
+
+ // bind the new socket to the requested port and ip
+ if (bind(sock,reinterpret_cast<sockaddr*>(&sa),sizeof(sockaddr_in))==SOCKET_ERROR)
+ {
+ const int err = WSAGetLastError();
+ // if there was an error
+ closesocket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (err == WSAEADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+
+ // tell the new socket to listen
+ if ( listen(sock,SOMAXCONN) == SOCKET_ERROR)
+ {
+ const int err = WSAGetLastError();
+ // if there was an error return OTHER_ERROR
+ closesocket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (err == WSAEADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+ // determine the port used if necessary
+ if (port == 0)
+ {
+ sockaddr_in local_info;
+ int length = sizeof(sockaddr_in);
+ if ( getsockname (
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == SOCKET_ERROR
+ )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ port = ntohs(local_info.sin_port);
+ }
+
+
+ // initialize a listener object on the heap with the new socket
+ try { new_listener = new listener(sock,port,ip); }
+ catch(...) { closesocket(sock); return OTHER_ERROR; }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int create_connection (
+ std::unique_ptr<connection>& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ )
+ {
+ new_connection.reset();
+ connection* temp;
+ int status = create_connection(temp,foreign_port, foreign_ip, local_port, local_ip);
+
+ if (status == 0)
+ new_connection.reset(temp);
+
+ return status;
+ }
+
+ int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ )
+ {
+ // ensure that WSAStartup has been called and WSACleanup
+ // will eventually be called when program ends
+ sockets_startup();
+
+
+ sockaddr_in local_sa; // local socket structure
+ sockaddr_in foreign_sa; // foreign socket structure
+ ZeroMemory(&local_sa,sizeof(sockaddr_in)); // initialize local_sa
+ ZeroMemory(&foreign_sa,sizeof(sockaddr_in)); // initialize foreign_sa
+
+ int length;
+
+ SOCKET sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket
+
+ // if socket() returned an error then return OTHER_ERROR
+ if (sock == INVALID_SOCKET )
+ {
+ return OTHER_ERROR;
+ }
+
+ // set the foreign socket structure
+ foreign_sa.sin_family = AF_INET;
+ foreign_sa.sin_port = htons(foreign_port);
+ foreign_sa.sin_addr.S_un.S_addr = inet_addr(foreign_ip.c_str());
+
+ // if inet_addr couldn't convert the ip then return an error
+ if ( foreign_sa.sin_addr.S_un.S_addr == INADDR_NONE )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+
+
+ // set up the local socket structure
+ local_sa.sin_family = AF_INET;
+
+ // set the local ip
+ if (local_ip.empty())
+ {
+ // if the listener should listen on any IP
+ local_sa.sin_addr.S_un.S_addr = htons(INADDR_ANY);
+ }
+ else
+ {
+ // if there is a specific ip to listen on
+ local_sa.sin_addr.S_un.S_addr = inet_addr(local_ip.c_str());
+
+ // if inet_addr couldn't convert the ip then return an error
+ if (local_sa.sin_addr.S_un.S_addr == INADDR_NONE)
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ }
+
+ // set the local port
+ local_sa.sin_port = htons(local_port);
+
+
+
+ // bind the new socket to the requested local port and local ip
+ if ( bind (
+ sock,
+ reinterpret_cast<sockaddr*>(&local_sa),
+ sizeof(sockaddr_in)
+ ) == SOCKET_ERROR
+ )
+ {
+ const int err = WSAGetLastError();
+ // if there was an error
+ closesocket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (err == WSAEADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+ // connect the socket
+ if (connect (
+ sock,
+ reinterpret_cast<sockaddr*>(&foreign_sa),
+ sizeof(sockaddr_in)
+ ) == SOCKET_ERROR
+ )
+ {
+ const int err = WSAGetLastError();
+ closesocket(sock);
+ // if the port is already bound then return PORTINUSE
+ if (err == WSAEADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+
+
+ // determine the local port and IP and store them in used_local_ip
+ // and used_local_port
+ int used_local_port;
+ std::string used_local_ip;
+ sockaddr_in local_info;
+ if (local_port == 0)
+ {
+ length = sizeof(sockaddr_in);
+ if (getsockname (
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == SOCKET_ERROR
+ )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ used_local_port = ntohs(local_info.sin_port);
+ }
+ else
+ {
+ used_local_port = local_port;
+ }
+
+ // determine real local ip
+ if (local_ip.empty())
+ {
+ // if local_port is not 0 then we must fill the local_info structure
+ if (local_port != 0)
+ {
+ length = sizeof(sockaddr_in);
+ if ( getsockname (
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == SOCKET_ERROR
+ )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ }
+ char* temp = inet_ntoa(local_info.sin_addr);
+
+ // check if inet_ntoa returned an error
+ if (temp == NULL)
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+ used_local_ip.assign(temp);
+ }
+ else
+ {
+ used_local_ip = local_ip;
+ }
+
+ // set the SO_OOBINLINE option
+ int flag_value = 1;
+ if (setsockopt(sock,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const char*>(&flag_value),sizeof(int)) == SOCKET_ERROR )
+ {
+ closesocket(sock);
+ return OTHER_ERROR;
+ }
+
+ // initialize a connection object on the heap with the new socket
+ try
+ {
+ new_connection = new connection (
+ sock,
+ foreign_port,
+ foreign_ip,
+ used_local_port,
+ used_local_ip
+ );
+ }
+ catch(...) {closesocket(sock); return OTHER_ERROR; }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // WIN32
+
+#endif // DLIB_SOCKETS_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/sockets/sockets_kernel_1.h b/ml/dlib/dlib/sockets/sockets_kernel_1.h
new file mode 100644
index 000000000..5fb73ecd6
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_kernel_1.h
@@ -0,0 +1,351 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEl_1_
+#define DLIB_SOCKETS_KERNEl_1_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#include "sockets_kernel_abstract.h"
+
+#include <memory>
+#include <string>
+
+#include "../algs.h"
+#include "../threads.h"
+#include "../uintn.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // forward declarations
+ class socket_factory;
+ class listener;
+ class SOCKET_container;
+
+// ----------------------------------------------------------------------------------------
+
+ // lookup functions
+
+ int
+ get_local_hostname (
+ std::string& hostname
+ );
+
+// -----------------
+
+ int
+ hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n = 0
+ );
+
+// -----------------
+
+ int
+ ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ );
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // connection object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class connection
+ {
+ /*!
+ INITIAL_VALUE
+ - sd == false
+ - sdo == false
+ - sdr == 0
+
+ CONVENTION
+ - connection_socket == the socket handle for this connection.
+ - connection_foreign_port == the port that foreign host is using for
+ this connection.
+ - connection_foreign_ip == a string containing the IP address of the
+ foreign host.
+ - connection_local_port == the port that the local host is using for
+ this connection.
+ - connection_local_ip == a string containing the IP address of the
+ local interface being used by this connection.
+
+ - sd == if shutdown() has been called then true else false.
+ - sdo == if shutdown_outgoing() has been called then true else false.
+ - sdr == the return value of shutdown() if it has been called. if it
+ hasn't been called then 0.
+
+ !*/
+
+ friend class listener; // make listener a friend of connection
+ // make create_connection a friend of connection
+ friend int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ );
+
+ public:
+
+ ~connection (
+ );
+
+ void* user_data;
+
+ long write (
+ const char* buf,
+ long num
+ );
+
+ long read (
+ char* buf,
+ long num
+ );
+
+ long read (
+ char* buf,
+ long num,
+ unsigned long timeout
+ );
+
+ unsigned short get_local_port (
+ ) const { return connection_local_port; }
+
+ unsigned short get_foreign_port (
+ ) const { return connection_foreign_port; }
+
+ const std::string& get_local_ip (
+ ) const { return connection_local_ip; }
+
+ const std::string& get_foreign_ip (
+ ) const { return connection_foreign_ip; }
+
+ int shutdown_outgoing (
+ );
+
+ int shutdown (
+ );
+
+ // I would use SOCKET here but I don't want to include the windows
+ // header files since they bring in a bunch of unpleasantness. So
+ // I'm doing this instead which should ultimately be the same type
+ // as the SOCKET win the windows API.
+ typedef unsigned_type<void*>::type socket_descriptor_type;
+
+ int disable_nagle(
+ );
+
+ socket_descriptor_type get_socket_descriptor (
+ ) const;
+
+ private:
+
+ bool readable (
+ unsigned long timeout
+ ) const;
+ /*!
+ requires
+ - timeout < 2000000
+ ensures
+ - returns true if a read call on this connection will not block.
+ - returns false if a read call on this connection will block or if
+ there was an error.
+ !*/
+
+ bool sd_called (
+ )const
+ /*!
+ ensures
+ - returns true if shutdown() has been called else
+ returns false
+ !*/
+ {
+ sd_mutex.lock();
+ bool temp = sd;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+ bool sdo_called (
+ )const
+ /*!
+ ensures
+ - returns true if shutdown_outgoing() or shutdown() has been called
+ else returns false
+ !*/
+ {
+ sd_mutex.lock();
+ bool temp = false;
+ if (sdo || sd)
+ temp = true;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+
+ // data members
+ SOCKET_container& connection_socket;
+ const unsigned short connection_foreign_port;
+ const std::string connection_foreign_ip;
+ const unsigned short connection_local_port;
+ const std::string connection_local_ip;
+
+ bool sd; // called shutdown
+ bool sdo; // called shutdown_outgoing
+ int sdr; // return value for shutdown
+ mutex sd_mutex; // a lock for the three above vars
+
+
+ connection(
+ SOCKET_container sock,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ );
+ /*!
+ requires
+ sock is a socket handle and
+ sock is the handle for the connection between foreign_ip:foreign_port
+ and local_ip:local_port
+ ensures
+ *this is initialized correctly with the above parameters
+ !*/
+
+
+ // restricted functions
+ connection(connection&); // copy constructor
+ connection& operator=(connection&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // listener object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class listener
+ {
+ /*!
+ CONVENTION
+ if (inaddr_any == false)
+ {
+ listening_ip == a string containing the address the listener is
+ listening on
+ }
+ else
+ {
+ the listener is listening on all interfaces
+ }
+
+ listening_port == the port the listener is listening on
+ listening_socket == the listening socket handle for this object
+ !*/
+
+ // make the create_listener a friend of listener
+ friend int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip
+ );
+
+ public:
+
+ ~listener (
+ );
+
+ int accept (
+ connection*& new_connection,
+ unsigned long timeout = 0
+ );
+
+ int accept (
+ std::unique_ptr<connection>& new_connection,
+ unsigned long timeout = 0
+ );
+
+ unsigned short get_listening_port (
+ ) { return listening_port; }
+
+ const std::string& get_listening_ip (
+ ) { return listening_ip; }
+
+ private:
+
+ // data members
+ SOCKET_container& listening_socket;
+ const unsigned short listening_port;
+ const std::string listening_ip;
+ const bool inaddr_any;
+
+ listener(
+ SOCKET_container sock,
+ unsigned short port,
+ const std::string& ip
+ );
+ /*!
+ requires
+ sock is a socket handle and
+ sock is listening on the port and ip(may be "") indicated in the
+ above parameters
+ ensures
+ *this is initialized correctly with the above parameters
+ !*/
+
+
+ // restricted functions
+ listener(listener&); // copy constructor
+ listener& operator=(listener&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+
+ int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+
+ int create_listener (
+ std::unique_ptr<listener>& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+
+ int create_connection (
+ std::unique_ptr<connection>& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#ifdef NO_MAKEFILE
+#include "sockets_kernel_1.cpp"
+#endif
+
+#endif // DLIB_SOCKETS_KERNEl_1_
+
diff --git a/ml/dlib/dlib/sockets/sockets_kernel_2.cpp b/ml/dlib/dlib/sockets/sockets_kernel_2.cpp
new file mode 100644
index 000000000..ac7408ff3
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_kernel_2.cpp
@@ -0,0 +1,1109 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEL_2_CPp_
+#define DLIB_SOCKETS_KERNEL_2_CPp_
+
+#include "../platform.h"
+
+#ifdef POSIX
+
+
+#include "sockets_kernel_2.h"
+#include <fcntl.h>
+#include "../set.h"
+#include <netinet/tcp.h>
+
+
+
+namespace dlib
+{
+// ----------------------------------------------------------------------------------------
+
+#ifdef HPUX
+ typedef int dsocklen_t;
+#else
+ typedef socklen_t dsocklen_t;
+#endif
+
+// ----------------------------------------------------------------------------------------
+// stuff to ensure that the signal SIGPIPE is ignored before any connections are made
+// so that when a connection object is shutdown the program won't end on a broken pipe
+
+ namespace sockets_kernel_2_mutex
+ {
+ mutex startup_lock;
+ }
+
+
+ void sockets_startup()
+ {
+ // mutex crap to make this function thread safe
+ sockets_kernel_2_mutex::startup_lock.lock();
+ static bool init = false;
+ if (init == false)
+ {
+ init = true;
+ signal( SIGPIPE, SIG_IGN);
+ }
+ sockets_kernel_2_mutex::startup_lock.unlock();
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ // lookup functions
+
+ int
+ get_local_hostname (
+ std::string& hostname
+ )
+ {
+ try
+ {
+ char temp[MAXHOSTNAMELEN];
+
+ if (gethostname(temp,MAXHOSTNAMELEN) == -1)
+ {
+ return OTHER_ERROR;
+ }
+ // ensure that NUL is at the end of the string
+ temp[MAXHOSTNAMELEN-1] = '\0';
+
+ hostname = temp;
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// -----------------
+
+// cygwin currently doesn't support the getaddrinfo stuff
+#ifndef __CYGWIN__
+
+ int
+ hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n
+ )
+ {
+ try
+ {
+ set<std::string>::kernel_1a sos;
+
+ if (hostname.empty())
+ return OTHER_ERROR;
+
+ addrinfo* result = 0;
+ if (getaddrinfo(hostname.c_str(),0,0,&result))
+ {
+ return OTHER_ERROR;
+ }
+ addrinfo* result_orig = result;
+
+ // loop over all the addrinfo structures and add them to the set. the reason for doing
+ // this dumb crap is because different platforms return all kinds of weird garbage. many
+ // return the same ip multiple times, etc.
+ while (result != 0)
+ {
+ char temp[16];
+ inet_ntop (
+ AF_INET,
+ &((reinterpret_cast<sockaddr_in*>(result->ai_addr))->sin_addr),
+ temp,16
+ );
+
+ result = result->ai_next;
+
+ ip.assign(temp);
+ if (sos.is_member(ip) == false && ip != "0.0.0.0")
+ sos.add(ip);
+ }
+
+ freeaddrinfo(result_orig);
+
+ // now return the nth unique ip address
+ int i = 0;
+ while (sos.move_next())
+ {
+ if (i == n)
+ {
+ ip = sos.element();
+ return 0;
+ }
+ ++i;
+ }
+
+ return OTHER_ERROR;
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+ return 0;
+ }
+
+
+// -----------------
+
+ int
+ ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ )
+ {
+
+ try
+ {
+
+ if (ip.empty())
+ return OTHER_ERROR;
+
+ sockaddr_in sa;
+ sa.sin_family = AF_INET;
+ inet_pton(AF_INET,ip.c_str(),&sa.sin_addr);
+
+ char temp[NI_MAXHOST];
+ if ( getnameinfo (
+ reinterpret_cast<sockaddr*>(&sa),sizeof(sockaddr_in),
+ temp,
+ NI_MAXHOST,
+ 0,
+ 0,
+ NI_NAMEREQD
+ )
+ )
+ {
+ return OTHER_ERROR;
+ }
+
+ hostname.assign(temp);
+
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+ return 0;
+ }
+#else
+ int
+ hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n
+ )
+ {
+ try
+ {
+ // lock this mutex since gethostbyname isn't really thread safe
+ auto_mutex M(sockets_kernel_2_mutex::startup_lock);
+
+ // if no hostname was given then return error
+ if ( hostname.empty())
+ return OTHER_ERROR;
+
+ hostent* address;
+ address = gethostbyname(hostname.c_str());
+
+ if (address == 0)
+ {
+ return OTHER_ERROR;
+ }
+
+ // find the nth address
+ in_addr* addr = reinterpret_cast<in_addr*>(address->h_addr_list[0]);
+ for (int i = 1; i <= n; ++i)
+ {
+ addr = reinterpret_cast<in_addr*>(address->h_addr_list[i]);
+
+ // if there is no nth address then return error
+ if (addr == 0)
+ return OTHER_ERROR;
+ }
+
+ char* resolved_ip = inet_ntoa(*addr);
+
+ // check if inet_ntoa returned an error
+ if (resolved_ip == NULL)
+ {
+ return OTHER_ERROR;
+ }
+
+ ip.assign(resolved_ip);
+
+ }
+ catch(...)
+ {
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// -----------------
+
+ int
+ ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ )
+ {
+ try
+ {
+ // lock this mutex since gethostbyaddr isn't really thread safe
+ auto_mutex M(sockets_kernel_2_mutex::startup_lock);
+
+ // if no ip was given then return error
+ if (ip.empty())
+ return OTHER_ERROR;
+
+ hostent* address;
+ unsigned long ipnum = inet_addr(ip.c_str());
+
+ // if inet_addr couldn't convert ip then return an error
+ if (ipnum == INADDR_NONE)
+ {
+ return OTHER_ERROR;
+ }
+ address = gethostbyaddr(reinterpret_cast<char*>(&ipnum),4,AF_INET);
+
+ // check if gethostbyaddr returned an error
+ if (address == 0)
+ {
+ return OTHER_ERROR;
+ }
+ hostname.assign(address->h_name);
+
+ }
+ catch (...)
+ {
+ return OTHER_ERROR;
+ }
+ return 0;
+
+ }
+
+#endif // __CYGWIN__
+
+// ----------------------------------------------------------------------------------------
+
+ connection::
+ connection(
+ int sock,
+ int foreign_port,
+ const std::string& foreign_ip,
+ int local_port,
+ const std::string& local_ip
+ ) :
+ connection_socket(sock),
+ connection_foreign_port(foreign_port),
+ connection_foreign_ip(foreign_ip),
+ connection_local_port(local_port),
+ connection_local_ip(local_ip),
+ sd(false),
+ sdo(false),
+ sdr(0)
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ int connection::
+ disable_nagle()
+ {
+ int flag = 1;
+ if(setsockopt( connection_socket, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag) ))
+ {
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ write (
+ const char* buf,
+ long num
+ )
+ {
+ const long old_num = num;
+ long status;
+ const long max_send_length = 1024*1024*100;
+ while (num > 0)
+ {
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_send_length, num);
+ if ( (status = ::send(connection_socket,buf,length,0)) <=0)
+ {
+ // if send was interupted by a signal then restart it
+ if (errno == EINTR)
+ {
+ continue;
+ }
+ else
+ {
+ // check if shutdown or shutdown_outgoing have been called
+ if (sdo_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ }
+ num -= status;
+ buf += status;
+ }
+ return old_num;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ read (
+ char* buf,
+ long num
+ )
+ {
+ long status;
+ const long max_recv_length = 1024*1024*100;
+ while (true)
+ {
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_recv_length, num);
+ status = recv(connection_socket,buf,length,0);
+ if (status == -1)
+ {
+ // if recv was interupted then try again
+ if (errno == EINTR)
+ continue;
+ else
+ {
+ if (sd_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ }
+ else if (status == 0 && sd_called())
+ {
+ return SHUTDOWN;
+ }
+
+ return status;
+ } // while (true)
+ }
+// ----------------------------------------------------------------------------------------
+
+ long connection::
+ read (
+ char* buf,
+ long num,
+ unsigned long timeout
+ )
+ {
+ long status;
+ const long max_recv_length = 1024*1024*100;
+
+ if (readable(timeout) == false)
+ return TIMEOUT;
+
+ // Make sure to cap the max value num can take on so that if it is
+ // really large (it might be big on 64bit platforms) so that the OS
+ // can't possibly get upset about it being large.
+ const long length = std::min(max_recv_length, num);
+ status = recv(connection_socket,buf,length,0);
+ if (status == -1)
+ {
+ // if recv was interupted then call this a timeout
+ if (errno == EINTR)
+ {
+ return TIMEOUT;
+ }
+ else
+ {
+ if (sd_called())
+ return SHUTDOWN;
+ else
+ return OTHER_ERROR;
+ }
+ }
+ else if (status == 0 && sd_called())
+ {
+ return SHUTDOWN;
+ }
+
+ return status;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool connection::
+ readable (
+ unsigned long timeout
+ ) const
+ {
+ fd_set read_set;
+ // initialize read_set
+ FD_ZERO(&read_set);
+
+ // add the listening socket to read_set
+ FD_SET(connection_socket, &read_set);
+
+ // setup a timeval structure
+ timeval time_to_wait;
+ time_to_wait.tv_sec = static_cast<long>(timeout/1000);
+ time_to_wait.tv_usec = static_cast<long>((timeout%1000)*1000);
+
+ // wait on select
+ int status = select(connection_socket+1,&read_set,0,0,&time_to_wait);
+
+ // if select timed out or there was an error
+ if (status <= 0)
+ return false;
+
+ // socket is ready to be read
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ connection::
+ ~connection (
+ )
+ {
+ while (true)
+ {
+ int status = ::close(connection_socket);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ }
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // listener object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ listener::
+ listener(
+ int sock,
+ int port,
+ const std::string& ip
+ ) :
+ listening_socket(sock),
+ listening_port(port),
+ listening_ip(ip),
+ inaddr_any(listening_ip.empty())
+ {}
+
+// ----------------------------------------------------------------------------------------
+
+ listener::
+ ~listener (
+ )
+ {
+ while (true)
+ {
+ int status = ::close(listening_socket);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int listener::
+ accept (
+ std::unique_ptr<connection>& new_connection,
+ unsigned long timeout
+ )
+ {
+ new_connection.reset(0);
+ connection* con;
+ int status = this->accept(con, timeout);
+
+ if (status == 0)
+ new_connection.reset(con);
+
+ return status;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int listener::
+ accept (
+ connection*& new_connection,
+ unsigned long timeout
+ )
+ {
+ int incoming;
+ sockaddr_in incomingAddr;
+ dsocklen_t length = sizeof(sockaddr_in);
+
+ // implement timeout with select if timeout is > 0
+ if (timeout > 0)
+ {
+
+ fd_set read_set;
+ // initialize read_set
+ FD_ZERO(&read_set);
+
+ // add the listening socket to read_set
+ FD_SET(listening_socket, &read_set);
+
+ timeval time_to_wait;
+
+
+ // loop on select so if its interupted then we can start it again
+ while (true)
+ {
+
+ // setup a timeval structure
+ time_to_wait.tv_sec = static_cast<long>(timeout/1000);
+ time_to_wait.tv_usec = static_cast<long>((timeout%1000)*1000);
+
+ // wait on select
+ int status = select(listening_socket+1,&read_set,0,0,&time_to_wait);
+
+ // if select timed out
+ if (status == 0)
+ return TIMEOUT;
+
+ // if select returned an error
+ if (status == -1)
+ {
+ // if select was interupted or the connection was aborted
+ // then go back to select
+ if (errno == EINTR ||
+ errno == ECONNABORTED ||
+#ifdef EPROTO
+ errno == EPROTO ||
+#endif
+ errno == ECONNRESET
+ )
+ {
+ continue;
+ }
+ else
+ {
+ return OTHER_ERROR;
+ }
+ }
+
+ // accept the new connection
+ incoming=::accept (
+ listening_socket,
+ reinterpret_cast<sockaddr*>(&incomingAddr),
+ &length
+ );
+
+ // if there was an error return OTHER_ERROR
+ if ( incoming == -1 )
+ {
+ // if accept was interupted then go back to accept
+ if (errno == EINTR ||
+ errno == ECONNABORTED ||
+#ifdef EPROTO
+ errno == EPROTO ||
+#endif
+ errno == ECONNRESET
+ )
+ {
+ continue;
+ }
+ else
+ {
+ return OTHER_ERROR;
+ }
+ }
+
+ // if there were no errors then quit loop
+ break;
+
+ }
+
+ }
+ // else if there is no time out then just go into accept
+ else
+ {
+ while (true)
+ {
+ // call accept to get a new connection
+ incoming=::accept (
+ listening_socket,
+ reinterpret_cast<sockaddr*>(&incomingAddr),
+ &length
+ );
+
+ // if there was an error return OTHER_ERROR
+ if ( incoming == -1 )
+ {
+ // if accept was interupted then go back to accept
+ if (errno == EINTR ||
+ errno == ECONNABORTED ||
+#ifdef EPROTO
+ errno == EPROTO ||
+#endif
+ errno == ECONNRESET
+ )
+ {
+ continue;
+ }
+ else
+ {
+ return OTHER_ERROR;
+ }
+ }
+ break;
+ }
+
+ }
+
+
+ // get the port of the foreign host into foreign_port
+ int foreign_port = ntohs(incomingAddr.sin_port);
+
+ // get the IP of the foreign host into foreign_ip
+ char foreign_ip[16];
+ inet_ntop(AF_INET,&incomingAddr.sin_addr,foreign_ip,16);
+
+
+
+ // get the local ip for this connection into local_ip
+ char temp_local_ip[16];
+ std::string local_ip;
+ if (inaddr_any == true)
+ {
+ sockaddr_in local_info;
+ length = sizeof(sockaddr_in);
+ // get the local sockaddr_in structure associated with this new connection
+ if ( getsockname (
+ incoming,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == -1
+ )
+ { // an error occurred
+ while (true)
+ {
+ int status = ::close(incoming);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ return OTHER_ERROR;
+ }
+ local_ip = const_cast<char*> (
+ inet_ntop(AF_INET,&local_info.sin_addr,temp_local_ip,16)
+ );
+ }
+ else
+ {
+ local_ip = listening_ip;
+ }
+
+
+
+ // set the SO_OOBINLINE option
+ int flag_value = 1;
+ if (setsockopt(incoming,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const void*>(&flag_value),sizeof(int)))
+ {
+ while (true)
+ {
+ int status = ::close(incoming);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ return OTHER_ERROR;
+ }
+
+
+
+ // make a new connection object for this new connection
+ try
+ {
+ new_connection = new connection (
+ incoming,
+ foreign_port,
+ foreign_ip,
+ listening_port,
+ local_ip
+ );
+ }
+ catch (...)
+ {
+ while (true)
+ {
+ int status = ::close(incoming);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ return OTHER_ERROR;
+ }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // socket creation functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ static void
+ close_socket (
+ int sock
+ )
+ /*!
+ requires
+ - sock == a socket
+ ensures
+ - sock has been closed
+ !*/
+ {
+ while (true)
+ {
+ int status = ::close(sock);
+ if (status == -1 && errno == EINTR)
+ continue;
+ break;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int create_listener (
+ std::unique_ptr<listener>& new_listener,
+ unsigned short port,
+ const std::string& ip
+ )
+ {
+ new_listener.reset();
+ listener* temp;
+ int status = create_listener(temp,port,ip);
+
+ if (status == 0)
+ new_listener.reset(temp);
+
+ return status;
+ }
+
+ int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip
+ )
+ {
+ sockets_startup();
+
+
+ sockaddr_in sa; // local socket structure
+ memset(&sa,'\0',sizeof(sockaddr_in)); // initialize sa
+
+
+ int sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket
+
+ // if socket() returned an error then return OTHER_ERROR
+ if (sock == -1)
+ {
+ return OTHER_ERROR;
+ }
+
+ // set the local socket structure
+ sa.sin_family = AF_INET;
+ sa.sin_port = htons(port);
+ if (ip.empty())
+ {
+ // if the listener should listen on any IP
+ sa.sin_addr.s_addr = htons(INADDR_ANY);
+ }
+ else
+ {
+ // if there is a specific ip to listen on
+ sa.sin_addr.s_addr = inet_addr(ip.c_str());
+
+ // if inet_addr couldn't convert the ip then return an error
+ if ( sa.sin_addr.s_addr == ( in_addr_t)(-1))
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+ }
+
+ // set the SO_REUSEADDR option
+ int flag_value = 1;
+ if (setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,reinterpret_cast<const void*>(&flag_value),sizeof(int)))
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+
+
+ // bind the new socket to the requested port and ip
+ if (bind(sock,reinterpret_cast<sockaddr*>(&sa),sizeof(sockaddr_in)) == -1)
+ { // if there was an error
+ close_socket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (errno == EADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+
+ // tell the new socket to listen
+ if ( listen(sock,SOMAXCONN) == -1)
+ {
+ // if there was an error return OTHER_ERROR
+ close_socket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (errno == EADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+ // determine the used local port if necessary
+ if (port == 0)
+ {
+ sockaddr_in local_info;
+ dsocklen_t length = sizeof(sockaddr_in);
+ if ( getsockname(
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == -1)
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+ port = ntohs(local_info.sin_port);
+ }
+
+ // initialize a listener object on the heap with the new socket
+ try { new_listener = new listener(sock,port,ip); }
+ catch(...) { close_socket(sock); return OTHER_ERROR; }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int create_connection (
+ std::unique_ptr<connection>& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ )
+ {
+ new_connection.reset();
+ connection* temp;
+ int status = create_connection(temp,foreign_port, foreign_ip, local_port, local_ip);
+
+ if (status == 0)
+ new_connection.reset(temp);
+
+ return status;
+ }
+
+ int
+ create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ )
+ {
+ sockets_startup();
+
+ sockaddr_in local_sa; // local socket structure
+ sockaddr_in foreign_sa; // foreign socket structure
+ memset(&local_sa,'\0',sizeof(sockaddr_in)); // initialize local_sa
+ memset(&foreign_sa,'\0',sizeof(sockaddr_in)); // initialize foreign_sa
+
+ dsocklen_t length;
+
+ int sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket
+
+ // if socket() returned an error then return OTHER_ERROR
+ if (sock == -1 )
+ {
+ return OTHER_ERROR;
+ }
+
+ // set the foreign socket structure
+ foreign_sa.sin_family = AF_INET;
+ foreign_sa.sin_port = htons(foreign_port);
+ foreign_sa.sin_addr.s_addr = inet_addr(foreign_ip.c_str());
+
+ // if inet_addr couldn't convert the ip then return an error
+ if ( foreign_sa.sin_addr.s_addr == ( in_addr_t)(-1))
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+
+
+ // set up the local socket structure
+ local_sa.sin_family = AF_INET;
+
+ // set the local port
+ local_sa.sin_port = htons(local_port);
+
+ // set the local ip
+ if (local_ip.empty())
+ {
+ // if the listener should listen on any IP
+ local_sa.sin_addr.s_addr = htons(INADDR_ANY);
+ }
+ else
+ {
+ // if there is a specific ip to listen on
+ local_sa.sin_addr.s_addr = inet_addr(local_ip.c_str());
+
+ // if inet_addr couldn't convert the ip then return an error
+ if ( local_sa.sin_addr.s_addr == ( in_addr_t)(-1))
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+ }
+
+
+
+
+
+ // bind the new socket to the requested local port and local ip
+ if ( bind(sock,reinterpret_cast<sockaddr*>(&local_sa),sizeof(sockaddr_in)) == -1)
+ { // if there was an error
+ close_socket(sock);
+
+ // if the port is already bound then return PORTINUSE
+ if (errno == EADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+ // connect the socket
+ if ( connect (
+ sock,
+ reinterpret_cast<sockaddr*>(&foreign_sa),
+ sizeof(sockaddr_in)
+ ) == -1
+ )
+ {
+ close_socket(sock);
+ // if the port is already bound then return PORTINUSE
+ if (errno == EADDRINUSE)
+ return PORTINUSE;
+ else
+ return OTHER_ERROR;
+ }
+
+
+ // determine the local port and IP and store them in used_local_ip
+ // and used_local_port
+ int used_local_port;
+ char temp_used_local_ip[16];
+ std::string used_local_ip;
+ sockaddr_in local_info;
+
+ // determine the port
+ if (local_port == 0)
+ {
+ length = sizeof(sockaddr_in);
+ if ( getsockname(
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == -1)
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+ used_local_port = ntohs(local_info.sin_port);
+ }
+ else
+ {
+ used_local_port = local_port;
+ }
+
+ // determine the ip
+ if (local_ip.empty())
+ {
+ // if local_port is not 0 then we must fill the local_info structure
+ if (local_port != 0)
+ {
+ length = sizeof(sockaddr_in);
+ if ( getsockname (
+ sock,
+ reinterpret_cast<sockaddr*>(&local_info),
+ &length
+ ) == -1
+ )
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+ }
+ used_local_ip = inet_ntop(AF_INET,&local_info.sin_addr,temp_used_local_ip,16);
+ }
+ else
+ {
+ used_local_ip = local_ip;
+ }
+
+
+ // set the SO_OOBINLINE option
+ int flag_value = 1;
+ if (setsockopt(sock,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const void*>(&flag_value),sizeof(int)))
+ {
+ close_socket(sock);
+ return OTHER_ERROR;
+ }
+
+
+ // initialize a connection object on the heap with the new socket
+ try
+ {
+ new_connection = new connection (
+ sock,
+ foreign_port,
+ foreign_ip,
+ used_local_port,
+ used_local_ip
+ );
+ }
+ catch(...) {close_socket(sock); return OTHER_ERROR; }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // POSIX
+
+#endif // DLIB_SOCKETS_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/sockets/sockets_kernel_2.h b/ml/dlib/dlib/sockets/sockets_kernel_2.h
new file mode 100644
index 000000000..f3bc94ec0
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_kernel_2.h
@@ -0,0 +1,396 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEl_2_
+#define DLIB_SOCKETS_KERNEl_2_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#include "../platform.h"
+
+#include "sockets_kernel_abstract.h"
+
+#define _BSD_SOCKLEN_T_
+
+#include <ctime>
+#include <memory>
+#include <string>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <errno.h>
+
+#ifndef HPUX
+#include <sys/select.h>
+#endif
+#include <arpa/inet.h>
+#include <signal.h>
+#include <inttypes.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <sys/param.h>
+
+#include <netinet/in.h>
+
+#include "../threads.h"
+#include "../algs.h"
+
+
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // forward declarations
+ class socket_factory;
+ class listener;
+
+// ----------------------------------------------------------------------------------------
+
+ // lookup functions
+
+ int
+ get_local_hostname (
+ std::string& hostname
+ );
+
+// -----------------
+
+ int
+ hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n = 0
+ );
+
+// -----------------
+
+ int
+ ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ );
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // connection object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class connection
+ {
+ /*!
+ INITIAL_VALUE
+ sd == false
+ sdo == false
+ sdr == 0
+
+
+ CONVENTION
+ connection_socket == the socket handle for this connection.
+ connection_foreign_port == the port that foreign host is using for
+ this connection
+ connection_foreign_ip == a string containing the IP address of the
+ foreign host
+ connection_local_port == the port that the local host is using for
+ this connection
+ connection_local_ip == a string containing the IP address of the
+ local interface being used by this connection
+
+ sd == if shutdown() has been called then true
+ else false
+ sdo == if shutdown_outgoing() has been called then true
+ else false
+ sdr == the return value of shutdown() if it has been
+ called. if it hasn't been called then 0
+
+
+ !*/
+
+ friend class listener; // make listener a friend of connection
+ // make create_connection a friend of connection
+ friend int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port,
+ const std::string& local_ip
+ );
+
+ public:
+
+ ~connection();
+
+ void* user_data;
+
+ long write (
+ const char* buf,
+ long num
+ );
+
+ long read (
+ char* buf,
+ long num
+ );
+
+ long read (
+ char* buf,
+ long num,
+ unsigned long timeout
+ );
+
+ int get_local_port (
+ ) const { return connection_local_port; }
+
+ int get_foreign_port (
+ ) const { return connection_foreign_port; }
+
+ const std::string& get_local_ip (
+ ) const { return connection_local_ip; }
+
+ const std::string& get_foreign_ip (
+ ) const { return connection_foreign_ip; }
+
+ int shutdown_outgoing (
+ )
+ {
+ sd_mutex.lock();
+ if (sdo || sd)
+ {
+ sd_mutex.unlock();
+ return sdr;
+ }
+ sdo = true;
+ sdr = ::shutdown(connection_socket,SHUT_WR);
+ int temp = sdr;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+ int shutdown (
+ )
+ {
+ sd_mutex.lock();
+ if (sd)
+ {
+ sd_mutex.unlock();
+ return sdr;
+ }
+ sd = true;
+ sdr = ::shutdown(connection_socket,SHUT_RDWR);
+ int temp = sdr;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+ int disable_nagle(
+ );
+
+ typedef int socket_descriptor_type;
+
+ socket_descriptor_type get_socket_descriptor (
+ ) const { return connection_socket; }
+
+ private:
+
+ bool readable (
+ unsigned long timeout
+ ) const;
+ /*!
+ requires
+ - timeout < 2000000
+ ensures
+ - returns true if a read call on this connection will not block.
+ - returns false if a read call on this connection will block or if
+ there was an error.
+ !*/
+
+ bool sd_called (
+ )const
+ /*!
+ ensures
+ - returns true if shutdown() has been called else
+ - returns false
+ !*/
+ {
+ sd_mutex.lock();
+ bool temp = sd;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+ bool sdo_called (
+ )const
+ /*!
+ ensures
+ - returns true if shutdown_outgoing() or shutdown() has been called
+ else returns false
+ !*/
+ {
+ sd_mutex.lock();
+ bool temp = false;
+ if (sdo || sd)
+ temp = true;
+ sd_mutex.unlock();
+ return temp;
+ }
+
+
+ // data members
+ int connection_socket;
+ const int connection_foreign_port;
+ const std::string connection_foreign_ip;
+ const int connection_local_port;
+ const std::string connection_local_ip;
+
+ bool sd; // called shutdown
+ bool sdo; // called shutdown_outgoing
+ int sdr; // return value for shutdown
+ mutex sd_mutex; // a lock for the three above vars
+
+ connection(
+ int sock,
+ int foreign_port,
+ const std::string& foreign_ip,
+ int local_port,
+ const std::string& local_ip
+ );
+ /*!
+ requires
+ - sock is a socket handle
+ - sock is the handle for the connection between foreign_ip:foreign_port
+ and local_ip:local_port
+ ensures
+ - *this is initialized correctly with the above parameters
+ !*/
+
+
+ // restricted functions
+ connection();
+ connection(connection&); // copy constructor
+ connection& operator=(connection&); // assignement opertor
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // listener object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class listener
+ {
+ /*!
+ CONVENTION
+ if (inaddr_any == false)
+ {
+ listening_ip == a string containing the address the listener is
+ listening on
+ }
+ else
+ {
+ the listener is listening on all interfaces
+ }
+
+ listening_port == the port the listener is listening on
+ listening_socket == the listening socket handle for this object
+ !*/
+
+ // make the create_listener a friend of listener
+ friend int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip
+ );
+
+ public:
+
+ ~listener();
+
+ int accept (
+ connection*& new_connection,
+ unsigned long timeout = 0
+ );
+
+ int accept (
+ std::unique_ptr<connection>& new_connection,
+ unsigned long timeout = 0
+ );
+
+ int get_listening_port (
+ ) const { return listening_port; }
+
+ const std::string& get_listening_ip (
+ ) const { return listening_ip; }
+
+ private:
+
+ // data members
+ int listening_socket;
+ const int listening_port;
+ const std::string listening_ip;
+ const bool inaddr_any;
+
+ listener(
+ int sock,
+ int port,
+ const std::string& ip
+ );
+ /*!
+ requires
+ - sock is a socket handle
+ - sock is listening on the port and ip(may be "") indicated in the above
+ parameters
+ ensures
+ - *this is initialized correctly with the above parameters
+ !*/
+
+
+ // restricted functions
+ listener();
+ listener(listener&); // copy constructor
+ listener& operator=(listener&); // assignement opertor
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+
+ int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+
+ int create_listener (
+ std::unique_ptr<listener>& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+
+ int create_connection (
+ std::unique_ptr<connection>& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "sockets_kernel_2.cpp"
+#endif
+
+#endif // DLIB_SOCKETS_KERNEl_2_
+
diff --git a/ml/dlib/dlib/sockets/sockets_kernel_abstract.h b/ml/dlib/dlib/sockets/sockets_kernel_abstract.h
new file mode 100644
index 000000000..d4571acad
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_kernel_abstract.h
@@ -0,0 +1,495 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SOCKETS_KERNEl_ABSTRACT_
+#ifdef DLIB_SOCKETS_KERNEl_ABSTRACT_
+
+#include <string>
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ GENERAL COMMENTS:
+ Nothing in here will throw exceptions.
+
+ All ip address strings in this file refer to IPv4 addresses. For
+ example "192.168.1.1"
+
+ Timeouts:
+ All timeout values are measured in milliseconds but you are not
+ guaranteed to have that level of resolution. The actual resolution
+ is implementation defined.
+
+ GENERAL WARNING
+ Don't call any of these functions or make any of these objects
+ before main() has been entered.
+
+ EXCEPTIONS
+ Unless specified otherwise, nothing in this file throws exceptions.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ // LOOKUP FUNCTIONS
+
+ // all lookup functions are thread-safe
+
+ int get_local_hostname (
+ std::string& hostname
+ );
+ /*!
+ ensures
+ - if (#get_local_hostname() == 0) then
+ - #hostname == a string containing the hostname of the local computer
+
+ - returns 0 upon success
+ - returns OTHER_ERROR upon failure and in this case #hostname's value
+ is undefined
+ !*/
+
+// -----------------
+
+ int hostname_to_ip (
+ const std::string& hostname,
+ std::string& ip,
+ int n = 0
+ );
+ /*!
+ requires
+ - n >= 0
+ ensures
+ - if (#hostname_to_ip() == 0) then
+ - #ip == string containing the nth ip address associated with the hostname
+
+ - returns 0 upon success
+ - returns OTHER_ERROR upon failure
+ !*/
+
+// -----------------
+
+ int ip_to_hostname (
+ const std::string& ip,
+ std::string& hostname
+ );
+ /*!
+ ensures
+ - if (#ip_to_hostname() == 0) then
+ - #hostname == string containing the hostname associated with ip
+
+ - returns 0 upon success
+ - returns OTHER_ERROR upon failure
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ //
+ // socket creation functions
+ //
+ // The following functions are guaranteed to be thread-safe
+ //
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ int create_listener (
+ listener*& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+ /*!
+ requires
+ - 0 <= port <= 65535
+ ensures
+ - if (#create_listener() == 0) then
+ - #new_listener == a pointer to a listener object that is listening on
+ the specified port and ip for an incoming connection
+ - if (ip == "") then
+ - the new listener will be listening on all interfaces
+ - if (port == 0) then
+ - the operating system will assign a free port to listen on
+
+
+ - returns 0 if create_listener was successful
+ - returns PORTINUSE if the specified local port was already in use
+ - returns OTHER_ERROR if some other error occurred
+ !*/
+
+ int create_listener (
+ std::unique_ptr<listener>& new_listener,
+ unsigned short port,
+ const std::string& ip = ""
+ );
+ /*!
+ This function is just an overload of the above function but it gives you a
+ std::unique_ptr smart pointer instead of a C pointer.
+ !*/
+
+ int create_connection (
+ connection*& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+ /*!
+ requires
+ - 0 < foreign_port <= 65535
+ - 0 <= local_port <= 65535
+ ensures
+ - if (#create_connection() == 0) then
+ - #new_connection == a pointer to a connection object that is connected
+ to foreign_ip on port foreign_port and is using the local interface
+ local_ip and local port local_port
+ - #new_connection->user_data == 0
+ - if (local_ip == "") then
+ - the operating system will chose this for you
+ - if (local_port == 0) then
+ - the operating system will chose this for you
+
+ - returns 0 if create_connection was successful
+ - returns PORTINUSE if the specified local port was already in use
+ - returns OTHER_ERROR if some other error occurred
+ !*/
+
+ int create_connection (
+ std::unique_ptr<connection>& new_connection,
+ unsigned short foreign_port,
+ const std::string& foreign_ip,
+ unsigned short local_port = 0,
+ const std::string& local_ip = ""
+ );
+ /*!
+ This function is just an overload of the above function but it gives you a
+ std::unique_ptr smart pointer instead of a C pointer.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // connection object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class connection
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a TCP connection.
+
+ Instances of this class can only be created by using the
+ create_connection function or listener class defined below.
+
+ NOTE:
+ A connection object must ALWAYS be closed (delete the pointer to the
+ connection) or it will cause a resource leak.
+
+ Note also that all errors indicated by a return code of OTHER_ERROR
+ are fatal so if one occurs the connection should just be closed.
+
+ CLOSING A CONNECTION
+ Note that if ~connection() or shutdown() is called before the remote client
+ has received all sent data it is possible that the data will be lost. To
+ avoid this you should call the close_gracefully() function to close your
+ connections (unless you actually do want to immediately dispose of a
+ connection and don't care about the data).
+ (example: close_gracefully(con); // close con gracefully but force it closed
+ // if it takes more than 500 milliseconds.)
+
+ THREAD SAFETY
+ - It is always safe to call shutdown() or shutdown_outgoing().
+ - you may NOT call any function more than once at a time (except the
+ shutdown functions).
+ - do not call read() more than once at a time
+ - do not call write() more than once at a time
+ - You can safely call shutdown or shutdown_outgoing in conjunction with
+ the read/write functions.
+ This is helpful if you want to unblock another thread that is
+ blocking on a read/write operation. Shutting down the connection
+ will cause the read/write functions to return a value of SHUTDOWN.
+
+ OUT-OF-BAND DATA:
+ All out-of-band data will be put inline into the normal data stream.
+ This means that you can read any out-of-band data via calls to read().
+ (i.e. the SO_OOBINLINE socket option will be set)
+ !*/
+
+ public:
+
+ ~connection (
+ );
+ /*!
+ requires
+ - no other threads are using this connection object
+ ensures
+ - closes the connection (this is an abrupt non-graceful close)
+ - frees the resources used by this object
+ !*/
+
+ void* user_data;
+ /*!
+ This pointer is provided so that the client programmer may easily associate
+ some data with a connection object. You can really do whatever you want
+ with it. Initially user_data is 0.
+ !*/
+
+ long write (
+ const char* buf,
+ long num
+ );
+ /*!
+ requires
+ - num > 0
+ - buf points to an array of at least num bytes
+ ensures
+ - will block until ONE of the following occurs:
+ - num bytes from buf have been written to the connection
+ - an error has occurred
+ - the outgoing channel of the connection has been shutdown locally
+
+ - returns num if write succeeded
+ - returns OTHER_ERROR if there was an error (this could be due to a
+ connection close)
+ - returns SHUTDOWN if the outgoing channel of the connection has been
+ shutdown locally
+ !*/
+
+ long read (
+ char* buf,
+ long num
+ );
+ /*!
+ requires
+ - num > 0
+ - buf points to an array of at least num bytes
+ ensures
+ - read() will not read more than num bytes of data into #buf
+ - read blocks until ONE of the following happens:
+ - there is some data available and it has been written into #buf
+ - the remote end of the connection is closed
+ - an error has occurred
+ - the connection has been shutdown locally
+
+ - returns the number of bytes read into #buf if there was any data.
+ - returns 0 if the connection has ended/terminated and there is no more data.
+ - returns OTHER_ERROR if there was an error.
+ - returns SHUTDOWN if the connection has been shutdown locally
+ !*/
+
+ long read (
+ char* buf,
+ long num,
+ unsigned long timeout
+ );
+ /*!
+ requires
+ - num > 0
+ - buf points to an array of at least num bytes
+ - timeout < 2000000
+ ensures
+ - read() will not read more than num bytes of data into #buf
+ - if (timeout > 0) then read() blocks until ONE of the following happens:
+ - there is some data available and it has been written into #buf
+ - the remote end of the connection is closed
+ - an error has occurred
+ - the connection has been shutdown locally
+ - timeout milliseconds has elapsed
+ - else
+ - read() does not block
+
+ - returns the number of bytes read into #buf if there was any data.
+ - returns 0 if the connection has ended/terminated and there is no more data.
+ - returns TIMEOUT if timeout milliseconds elapsed before we got any data.
+ - returns OTHER_ERROR if there was an error.
+ - returns SHUTDOWN if the connection has been shutdown locally
+ !*/
+
+ unsigned short get_local_port (
+ ) const;
+ /*!
+ ensures
+ - returns the local port number for this connection
+ !*/
+
+ unsigned short get_foreign_port (
+ ) const;
+ /*!
+ ensures
+ - returns the foreign port number for this connection
+ !*/
+
+ const std::string& get_local_ip (
+ ) const;
+ /*!
+ ensures
+ - returns the IP of the local interface this connection is using
+ !*/
+
+ const std::string& get_foreign_ip (
+ ) const;
+ /*!
+ ensures
+ - returns the IP of the foreign host for this connection
+ !*/
+
+ int shutdown (
+ );
+ /*!
+ ensures
+ - if (#shutdown() == 0 && connection was still open) then
+ - terminates the connection but does not free the resources for the
+ connection object
+
+ - any read() or write() calls on this connection will return immediately
+ with the code SHUTDOWN.
+
+ - returns 0 upon success
+ - returns OTHER_ERROR if there was an error
+ !*/
+
+ int shutdown_outgoing (
+ );
+ /*!
+ ensures
+ - if (#shutdown_outgoing() == 0 && outgoing channel was still open) then
+ - sends a FIN to indicate that no more data will be sent on this
+ connection but leaves the receive half of the connection open to
+ receive more data from the other host
+
+ - any calls to write() will return immediately with the code SHUTDOWN.
+
+ - returns 0 upon success
+ - returns OTHER_ERROR if there was an error
+ !*/
+
+ int disable_nagle(
+ );
+ /*!
+ ensures
+ - Sets the TCP_NODELAY socket option to disable Nagle's algorithm.
+ This can sometimes reduce transmission latency, however, in almost
+ all normal cases you don't want to mess with this as the default
+ setting is usually appropriate.
+
+ - returns 0 upon success
+ - returns OTHER_ERROR if there was an error
+ !*/
+
+ typedef platform_specific_type socket_descriptor_type;
+ socket_descriptor_type get_socket_descriptor (
+ ) const;
+ /*!
+ ensures
+ - returns the underlying socket descriptor for this connection
+ object. The reason you might want access to this is to
+ pass it to some other library that requires a socket file
+ descriptor. However, if you do this then you probably shouldn't
+ use the dlib::connection read() and write() anymore since
+ whatever you are doing with the socket descriptor is probably
+ doing I/O with the socket.
+ !*/
+
+ private:
+ // restricted functions
+ connection();
+ connection(connection&); // copy constructor
+ connection& operator=(connection&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // listener object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class listener
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a TCP socket waiting for incoming connections.
+ Calling accept returns a pointer to any new incoming connections on its
+ port.
+
+ Instances of this class can only be created by using the
+ create_listener function defined below.
+
+ NOTE:
+ A listener object must ALWAYS be closed (delete the pointer to it) or
+ it will cause a resource leak.
+
+ Note also that all errors indicated by a return code of OTHER_ERROR
+ are fatal so if one occurs the listener should be closed.
+
+ THREAD SAFETY
+ None of the functions in this object are guaranteed to be thread-safe.
+ This means that you must serialize all access to this object.
+ !*/
+
+ public:
+
+ ~listener (
+ );
+ /*!
+ requires
+ - no other threads are using this listener object
+ ensures
+ - closes the listener
+ - frees the resources used by this object
+ !*/
+
+ int accept (
+ connection*& new_connection,
+ unsigned long timeout = 0
+ );
+ /*!
+ requires
+ - timeout < 2000000
+ ensures
+ - blocks until a new connection is ready or timeout milliseconds have
+ elapsed.
+ - #new_connection == a pointer to the new connection object
+ - #new_connection->user_data == 0
+ - if (timeout == 0) then
+ - the timeout argument is ignored
+
+ - returns 0 if accept() was successful
+ - returns TIMEOUT if timeout milliseconds have elapsed
+ - returns OTHER_ERROR if an error has occurred
+ !*/
+
+ int accept (
+ std::unique_ptr<connection>& new_connection,
+ unsigned long timeout = 0
+ );
+ /*!
+ This function is just an overload of the above function but it gives you a
+ std::unique_ptr smart pointer instead of a C pointer.
+ !*/
+
+ unsigned short get_listening_port (
+ ) const;
+ /*!
+ ensures
+ - returns the port number that this object is listening on
+ !*/
+
+ const std::string& get_listening_ip (
+ ) const;
+ /*!
+ ensures
+ - returns a string containing the IP (e.g. "127.0.0.1") of the
+ interface this object is listening on
+ - returns "" if it is accepting connections on all interfaces
+ !*/
+
+ private:
+ // restricted functions
+ listener();
+ listener(listener&); // copy constructor
+ listener& operator=(listener&); // assignment operator
+ };
+}
+
+#endif // DLIB_SOCKETS_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/sockets/windows.h b/ml/dlib/dlib/sockets/windows.h
new file mode 100644
index 000000000..85b7fd8d8
--- /dev/null
+++ b/ml/dlib/dlib/sockets/windows.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_KERNEl_2_
+#include "sockets_kernel_1.h"
+#endif
+
diff --git a/ml/dlib/dlib/sockstreambuf.h b/ml/dlib/dlib/sockstreambuf.h
new file mode 100644
index 000000000..41b0e7a9e
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKSTREAMBUf_H_h_
+#define DLIB_SOCKSTREAMBUf_H_h_
+
+#include "sockstreambuf/sockstreambuf.h"
+#include "sockstreambuf/sockstreambuf_unbuffered.h"
+
+
+#endif // DLIB_SOCKSTREAMBUf_H_h_
+
diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp b/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp
new file mode 100644
index 000000000..e328e4259
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp
@@ -0,0 +1,177 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKStREAMBUF_CPp_
+#define DLIB_SOCKStREAMBUF_CPp_
+
+#include "sockstreambuf.h"
+#include "../assert.h"
+
+#include <cstring>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+ // output functions
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf::int_type sockstreambuf::
+ overflow (
+ int_type c
+ )
+ {
+ if (c != EOF)
+ {
+ *pptr() = c;
+ pbump(1);
+ }
+ if (flush_out_buffer() == EOF)
+ {
+ // an error occurred
+ return EOF;
+ }
+ return c;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::streamsize sockstreambuf::
+ xsputn (
+ const char* s,
+ std::streamsize num
+ )
+ {
+ // Add a sanity check here
+ DLIB_ASSERT(num >= 0,
+ "\tstd::streamsize sockstreambuf::xsputn"
+ << "\n\tThe number of bytes to write can't be negative"
+ << "\n\tnum: " << num
+ << "\n\tthis: " << this
+ );
+
+ std::streamsize space_left = static_cast<std::streamsize>(epptr()-pptr());
+ if (num <= space_left)
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(num));
+ pbump(static_cast<int>(num));
+ return num;
+ }
+ else
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(space_left));
+ s += space_left;
+ pbump(space_left);
+ std::streamsize num_left = num - space_left;
+
+ if (flush_out_buffer() == EOF)
+ {
+ // the write was not successful so return that 0 bytes were written
+ return 0;
+ }
+
+ if (num_left < out_buffer_size)
+ {
+ std::memcpy(pptr(),s,static_cast<size_t>(num_left));
+ pbump(num_left);
+ return num;
+ }
+ else
+ {
+ if (con.write(s,num_left) != num_left)
+ {
+ // the write was not successful so return that 0 bytes were written
+ return 0;
+ }
+ return num;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+ // input functions
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf::int_type sockstreambuf::
+ underflow(
+ )
+ {
+ if (gptr() < egptr())
+ {
+ return static_cast<unsigned char>(*gptr());
+ }
+
+ int num_put_back = static_cast<int>(gptr() - eback());
+ if (num_put_back > max_putback)
+ {
+ num_put_back = max_putback;
+ }
+
+ // copy the putback characters into the putback end of the in_buffer
+ std::memmove(in_buffer+(max_putback-num_put_back), gptr()-num_put_back, num_put_back);
+
+ if (flushes_output_on_read())
+ {
+ if (flush_out_buffer() == EOF)
+ {
+ // an error occurred
+ return EOF;
+ }
+ }
+
+ int num = con.read(in_buffer+max_putback, in_buffer_size-max_putback);
+ if (num <= 0)
+ {
+ // an error occurred or the connection is over which is EOF
+ return EOF;
+ }
+
+ // reset in_buffer pointers
+ setg (in_buffer+(max_putback-num_put_back),
+ in_buffer+max_putback,
+ in_buffer+max_putback+num);
+
+ return static_cast<unsigned char>(*gptr());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::streamsize sockstreambuf::
+ xsgetn (
+ char_type* s,
+ std::streamsize n
+ )
+ {
+ std::streamsize temp = n;
+ while (n > 0)
+ {
+ int num = static_cast<int>(egptr() - gptr());
+ if (num >= n)
+ {
+ // copy data from our buffer
+ std::memcpy(s, gptr(), static_cast<size_t>(n));
+ gbump(static_cast<int>(n));
+ return temp;
+ }
+
+ // read more data into our buffer
+ if (num == 0)
+ {
+ if (underflow() == EOF)
+ break;
+ continue;
+ }
+
+ // copy all the data from our buffer
+ std::memcpy(s, gptr(), num);
+ n -= num;
+ gbump(num);
+ s += num;
+ }
+ return temp-n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_SOCKStREAMBUF_CPp_
+
diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf.h
new file mode 100644
index 000000000..f5b450e78
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf.h
@@ -0,0 +1,172 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKStREAMBUF_Hh_
+#define DLIB_SOCKStREAMBUF_Hh_
+
+#include <iosfwd>
+#include <streambuf>
+#include "../sockets.h"
+#include "sockstreambuf_abstract.h"
+#include "sockstreambuf_unbuffered.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class sockstreambuf : public std::streambuf
+ {
+ /*!
+ INITIAL VALUE
+ - con == a connection
+ - in_buffer == an array of in_buffer_size bytes
+ - out_buffer == an array of out_buffer_size bytes
+
+ CONVENTION
+ - in_buffer == the input buffer used by this streambuf
+ - out_buffer == the output buffer used by this streambuf
+ - max_putback == the maximum number of chars to have in the put back buffer.
+ !*/
+
+ public:
+
+ // These typedefs are here for backwards compatibility with previous versions of
+ // dlib.
+ typedef sockstreambuf_unbuffered kernel_1a;
+ typedef sockstreambuf kernel_2a;
+
+ sockstreambuf (
+ connection* con_
+ ) :
+ con(*con_),
+ out_buffer(0),
+ in_buffer(0),
+ autoflush(false)
+ {
+ init();
+ }
+
+ sockstreambuf (
+ const std::unique_ptr<connection>& con_
+ ) :
+ con(*con_),
+ out_buffer(0),
+ in_buffer(0),
+ autoflush(false)
+ {
+ init();
+ }
+
+ virtual ~sockstreambuf (
+ )
+ {
+ sync();
+ delete [] out_buffer;
+ delete [] in_buffer;
+ }
+
+ connection* get_connection (
+ ) { return &con; }
+
+ void flush_output_on_read()
+ {
+ autoflush = true;
+ }
+
+ bool flushes_output_on_read() const
+ {
+ return autoflush;
+ }
+
+ void do_not_flush_output_on_read()
+ {
+ autoflush = false;
+ }
+
+ protected:
+
+ void init (
+ )
+ {
+ try
+ {
+ out_buffer = new char[out_buffer_size];
+ in_buffer = new char[in_buffer_size];
+ }
+ catch (...)
+ {
+ if (out_buffer) delete [] out_buffer;
+ throw;
+ }
+ setp(out_buffer, out_buffer + (out_buffer_size-1));
+ setg(in_buffer+max_putback,
+ in_buffer+max_putback,
+ in_buffer+max_putback);
+ }
+
+ int flush_out_buffer (
+ )
+ {
+ int num = static_cast<int>(pptr()-pbase());
+ if (con.write(out_buffer,num) != num)
+ {
+ // the write was not successful so return EOF
+ return EOF;
+ }
+ pbump(-num);
+ return num;
+ }
+
+ // output functions
+ int_type overflow (
+ int_type c
+ );
+
+ int sync (
+ )
+ {
+ if (flush_out_buffer() == EOF)
+ {
+ // an error occurred
+ return -1;
+ }
+ return 0;
+ }
+
+ std::streamsize xsputn (
+ const char* s,
+ std::streamsize num
+ );
+
+ // input functions
+ int_type underflow(
+ );
+
+ std::streamsize xsgetn (
+ char_type* s,
+ std::streamsize n
+ );
+
+ private:
+
+ // member data
+ connection& con;
+ static const std::streamsize max_putback = 4;
+ static const std::streamsize out_buffer_size = 10000;
+ static const std::streamsize in_buffer_size = 10000;
+ char* out_buffer;
+ char* in_buffer;
+ bool autoflush;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "sockstreambuf.cpp"
+#endif
+
+#endif // DLIB_SOCKStREAMBUF_Hh_
+
diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h
new file mode 100644
index 000000000..12be84193
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h
@@ -0,0 +1,127 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SOCKSTREAMBUF_ABSTRACT_
+#ifdef DLIB_SOCKSTREAMBUF_ABSTRACT_
+
+#include <iosfwd>
+#include <memory>
+#include <streambuf>
+
+#include "../sockets/sockets_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class sockstreambuf : public std::streambuf
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a stream buffer capable of writing to and
+ reading from TCP connections.
+
+ NOTE:
+ For a sockstreambuf EOF is when the connection has closed or otherwise
+ returned some kind of error.
+
+ Also note that any data written to the streambuf may be buffered
+ internally. So if you need to ensure that data is actually sent then you
+ should flush the stream.
+
+ A read operation is guaranteed to block until the number of bytes
+ requested has arrived on the connection. It will never keep blocking
+ once enough data has arrived.
+
+ THREADING
+ Generally speaking, this object has the same kind of threading
+ restrictions as a connection object. those being:
+ - Do not try to write to a sockstreambuf from more than one thread.
+ - Do not try to read from a sockstreambuf from more than one thread.
+ - You may call shutdown() on the connection object and this will
+ cause any reading or writing calls to end. To the sockstreambuf it
+ will appear the same as hitting EOF. (note that EOF for a sockstreambuf
+ means that the connection has closed)
+ - It is safe to read from and write to the sockstreambuf at the same time
+ from different threads so long as flushes_output_on_read()==false.
+ - It is not safe to try to putback a char and read from the stream from
+ different threads
+ !*/
+ public:
+ sockstreambuf (
+ connection* con
+ );
+ /*!
+ requires
+ - con == a valid connection object
+ ensures
+ - *this will read from and write to con
+ - #flushes_output_on_read() == false
+ throws
+ - std::bad_alloc
+ !*/
+
+ sockstreambuf (
+ const std::unique_ptr<connection>& con
+ );
+ /*!
+ requires
+ - con == a valid connection object
+ ensures
+ - *this will read from and write to con
+ - #flushes_output_on_read() == false
+ throws
+ - std::bad_alloc
+ !*/
+
+ ~sockstreambuf (
+ );
+ /*!
+ requires
+ - get_connection() object has not been deleted
+ ensures
+ - sockstream buffer is destructed but the connection object will
+ NOT be closed.
+ - Any buffered data is flushed to the connection.
+ !*/
+
+ connection* get_connection (
+ );
+ /*!
+ ensures
+ - returns a pointer to the connection object which this buffer
+ reads from and writes to
+ !*/
+
+ void flush_output_on_read (
+ );
+ /*!
+ ensures
+ - #flushes_output_on_read() == true
+ !*/
+
+ bool flushes_output_on_read (
+ ) const;
+ /*!
+ ensures
+ - This function returns true if this object will flush its output buffer
+ to the network socket before performing any network read.
+ - if (flushes_output_on_read() == true)
+ - It is not safe to make concurrent read and write calls to this object.
+ !*/
+
+ void do_not_flush_output_on_read (
+ );
+ /*!
+ ensures
+ - #flushes_output_on_read() == false
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOCKSTREAMBUF_ABSTRACT_
+
diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp
new file mode 100644
index 000000000..4dcefc17b
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp
@@ -0,0 +1,168 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_
+#define DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_
+
+#include "sockstreambuf_unbuffered.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+ // output functions
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered::
+ overflow (
+ int_type c
+ )
+ {
+ if (c != EOF)
+ {
+ char temp = static_cast<char>(c);
+ if (con.write(&temp,1) != 1)
+ {
+ // if the write was not successful
+ return EOF;
+ }
+ }
+ return c;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::streamsize sockstreambuf_unbuffered::
+ xsputn (
+ const char* s,
+ std::streamsize num
+ )
+ {
+ if (con.write(s,static_cast<int>(num)) != num)
+ {
+ // the write was not successful so return that 0 bytes were written
+ return 0;
+ }
+ return num;
+ }
+
+// ----------------------------------------------------------------------------------------
+ // input functions
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered::
+ underflow(
+ )
+ {
+ if (lastread_next)
+ {
+ return lastread;
+ }
+ else if (peek != EOF)
+ {
+ return peek;
+ }
+ else
+ {
+ char temp;
+ if (con.read(&temp,1) != 1)
+ {
+ // some error occurred
+ return EOF;
+ }
+ peek = static_cast<unsigned char>(temp);
+ return peek;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered::
+ uflow(
+ )
+ {
+ if (lastread_next)
+ {
+ lastread_next = false;
+ return lastread;
+ }
+ else if (peek != EOF)
+ {
+ lastread = peek;
+ peek = EOF;
+ return lastread;
+ }
+ else
+ {
+ char temp;
+ if (con.read(&temp,1) != 1)
+ {
+ // some error occurred
+ return EOF;
+ }
+ lastread = static_cast<unsigned char>(temp);
+ return lastread;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered::
+ pbackfail(
+ int_type c
+ )
+ {
+ // if they are trying to push back a character that they didn't read last
+ // that is an error
+ if (c != EOF && c != lastread)
+ return EOF;
+
+ // if they are trying to push back a second character then thats an error
+ if (lastread_next)
+ return EOF;
+
+ lastread_next = true;
+ return 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::streamsize sockstreambuf_unbuffered::
+ xsgetn (
+ char_type* s,
+ std::streamsize n
+ )
+ {
+ std::streamsize temp = n;
+ if (lastread_next && n > 0)
+ {
+ *s = lastread;
+ lastread_next = false;
+ ++s;
+ --n;
+ }
+ if (peek != EOF && n > 0)
+ {
+ *s = peek;
+ peek = EOF;
+ ++s;
+ --n;
+ }
+
+ while (n>0)
+ {
+ int status = con.read(s,static_cast<int>(n));
+ if (status < 1)
+ break;
+ n -= status;
+ s += status;
+ }
+
+ return temp-n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_
+
diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h
new file mode 100644
index 000000000..8aa5992db
--- /dev/null
+++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_
+#define DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_
+
+#include <iosfwd>
+#include <streambuf>
+#include "../sockets.h"
+#include "sockstreambuf_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class sockstreambuf_unbuffered : public std::streambuf
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the interface defined in
+ sockstreambuf_abstract.h except that it doesn't do any kind of buffering at
+ all. It just writes data directly to a connection. However, note that we
+ don't implement the flushes_output_on_read() routine as this object always
+ flushes immediately (since it isn't buffers. Moreover, it should be
+ pointed out that this object is deprecated and only present for backwards
+ compatibility with previous versions of dlib. So you really should use the
+ sockstreambuf object instead.
+
+ INITIAL VALUE
+ con == a connection
+ lastread_next == false
+ peek == EOF
+
+ CONVENTION
+ if (peek != EOF) then
+ peek == the last character read from the connection object and
+ is used to store the char in the event the user peeks by
+ calling sgetc()
+ if (lastread != EOF) then
+ lastread == the last character read and consumed by the user
+
+ if (lastread_next) then
+ the next character to be returned to the user is lastread because
+ the user put it back into the buffer
+
+ !*/
+
+ public:
+
+
+ sockstreambuf_unbuffered (
+ connection* con_
+ ) :
+ con(*con_),
+ peek(EOF),
+ lastread_next(false)
+ {}
+
+ sockstreambuf_unbuffered (
+ const std::unique_ptr<connection>& con_
+ ) :
+ con(*con_),
+ peek(EOF),
+ lastread_next(false)
+ {}
+
+ connection* get_connection (
+ ) { return &con; }
+
+
+ protected:
+
+ // output functions
+ int_type overflow (
+ int_type c
+ );
+
+ std::streamsize xsputn (
+ const char* s,
+ std::streamsize num
+ );
+
+ // input functions
+ int_type underflow(
+ );
+
+ int_type uflow(
+ );
+
+ int_type pbackfail(
+ int_type c
+ );
+
+ std::streamsize xsgetn (
+ char_type* s,
+ std::streamsize n
+ );
+
+ private:
+
+ // member data
+ connection& con;
+ int_type peek;
+ int_type lastread;
+ bool lastread_next;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "sockstreambuf_unbuffered.cpp"
+#endif
+
+#endif // DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_
+
diff --git a/ml/dlib/dlib/sort.h b/ml/dlib/dlib/sort.h
new file mode 100644
index 000000000..c798372f3
--- /dev/null
+++ b/ml/dlib/dlib/sort.h
@@ -0,0 +1,490 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SORt_
+#define DLIB_SORt_
+
+#include "algs.h"
+#include <functional>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ inline void qsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right,
+ const compare& comp
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by comp where comp is a function
+ object with the same syntax as std::less<>
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using a quick sort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void hsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right,
+ const compare& comp
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by comp where comp is a function
+ object with the same syntax as std::less<>
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using a heapsort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void isort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right,
+ const compare& comp
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by comp where comp is a function
+ object with the same syntax as std::less<>
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using an insertion sort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void qsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by std::less
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using a quick sort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void hsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by std::less
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using a heapsort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void isort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ );
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by std::less
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ - left <= right
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - sorts using an insertion sort algorithm
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// IMPLEMENTATION DETAILS
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace sort_helpers
+ {
+ template <typename T>
+ inline const std::less<T> comp (const T&)
+ {
+ return std::less<T>();
+ }
+
+ template <
+ typename T,
+ typename Y,
+ typename compare
+ >
+ inline unsigned long qsort_partition (
+ T& array,
+ Y& pivot,
+ const unsigned long left,
+ const unsigned long right,
+ const compare& comp
+ )
+ /*!
+ requires
+ - &pivot == &array[right]
+ - T implements operator[]
+ - the items in array must be comparable by comp
+ - left and right are within the bounts of the array
+ - left < right
+ ensures
+ - returns a number called partition_element such that:
+ - left <= partition_element <= right
+ - all elements in #array < #array[partition_element] have
+ indices >= left and < partition_element
+ - all elements in #array > #array[partition_element] have
+ indices > partition_element and <= right
+ !*/
+ {
+ DLIB_ASSERT (&pivot == &array[right] && left < right,
+ "\tunsigned long qsort_partition()"
+ << "\n\t&pivot: " << &pivot
+ << "\n\t&array[right]: " << &array[right]
+ << "\n\tleft: " << left
+ << "\n\tright: " << right );
+
+ exchange(array[(right-left)/2 +left],pivot);
+
+ unsigned long i = left;
+ for (unsigned long j = left; j < right; ++j)
+ {
+ if (comp(array[j] , pivot))
+ {
+ swap(array[i],array[j]);
+ ++i;
+ }
+ }
+ exchange(array[i],pivot);
+
+ return i;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void qsort_array_main (
+ T& array,
+ const unsigned long left,
+ const unsigned long right,
+ unsigned long depth_check,
+ const compare& comp
+ )
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by comp
+ - the items in array must be swappable by a global swap()
+ - left and right are within the bounds of array
+ i.e. array[left] and array[right] are valid elements
+ ensures
+ - for all elements in #array between and including left and right the
+ ith element is < the i+1 element
+ - will only recurse about as deep as log(depth_check) calls
+ - sorts using a quick sort algorithm
+ !*/
+ {
+ if ( left < right)
+ {
+ if (right-left < 30 || depth_check == 0)
+ {
+ hsort_array(array,left,right,comp);
+ }
+ else
+ {
+ // The idea here is to only let quick sort go about log(N)
+ // calls deep before it kicks into something else.
+ depth_check >>= 1;
+ depth_check += (depth_check>>4);
+
+ unsigned long partition_element =
+ qsort_partition(array,array[right],left,right,comp);
+
+ if (partition_element > 0)
+ qsort_array_main(array,left,partition_element-1,depth_check,comp);
+ qsort_array_main(array,partition_element+1,right,depth_check,comp);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void heapify (
+ T& array,
+ const unsigned long start,
+ const unsigned long end,
+ unsigned long i,
+ const compare& comp
+ )
+ /*!
+ requires
+ - T implements operator[]
+ - the items in array must be comparable by comp
+ - the items in array must be swappable by a global swap()
+ - start, end, and i are within the bounds of array
+ i.e. array[start], array[end], and array[i] are valid elements
+ - start <= i <= end
+ - array[i/2] is a max heap
+ - array[i/2+1] is a max heap
+ - start and end specify the range of the array we are working with.
+ ensures
+ - array[i] is now a max heap
+ !*/
+ {
+ DLIB_ASSERT (start <= i && i <= end,
+ "\tvoid heapify()"
+ << "\n\tstart: " << start
+ << "\n\tend: " << end
+ << "\n\ti: " << i );
+
+ bool keep_going = true;
+ unsigned long left;
+ unsigned long right;
+ unsigned long largest;
+ while (keep_going)
+ {
+ keep_going = false;
+ left = (i<<1)+1-start;
+ right = left+1;
+
+ if (left <= end && comp(array[i] , array[left]))
+ largest = left;
+ else
+ largest = i;
+
+ if (right <= end && comp(array[largest] , array[right]))
+ largest = right;
+
+ if (largest != i)
+ {
+ exchange(array[i],array[largest]);
+ i = largest;
+ keep_going = true;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+ }
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename T
+ >
+ inline void qsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ )
+ {
+ using namespace sort_helpers;
+ qsort_array(array,left,right,comp(array[left]));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void hsort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ )
+ {
+ using namespace sort_helpers;
+ hsort_array(array,left,right,comp(array[left]));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void isort_array (
+ T& array,
+ unsigned long left,
+ unsigned long right
+ )
+ {
+ using namespace sort_helpers;
+ isort_array(array,left,right,comp(array[left]));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void isort_array (
+ T& array,
+ const unsigned long left,
+ const unsigned long right,
+ const compare& comp
+ )
+ {
+ DLIB_ASSERT (left <= right,
+ "\tvoid isort_array()"
+ << "\n\tleft: " << left
+ << "\n\tright: " << right );
+ using namespace sort_helpers;
+
+ unsigned long pos;
+ for (unsigned long i = left+1; i <= right; ++i)
+ {
+ // everything from left to i-1 is sorted.
+ pos = i;
+ for (unsigned long j = i-1; comp(array[pos] , array[j]); --j)
+ {
+ exchange(array[pos],array[j]);
+ pos = j;
+
+ if (j == left)
+ break;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void qsort_array (
+ T& array,
+ const unsigned long left,
+ const unsigned long right,
+ const compare& comp
+ )
+ {
+ DLIB_ASSERT (left <= right,
+ "\tvoid qsort_array()"
+ << "\n\tleft: " << left
+ << "\n\tright: " << right );
+
+ sort_helpers::qsort_array_main(array,left,right,right-left,comp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void hsort_array (
+ T& array,
+ const unsigned long left,
+ const unsigned long right,
+ const compare& comp
+ )
+ {
+ DLIB_ASSERT (left <= right,
+ "\tvoid hsort_array()"
+ << "\n\tleft: " << left
+ << "\n\tright: " << right );
+
+ if (right-left < 30)
+ {
+ isort_array(array,left,right,comp);
+ return;
+ }
+
+ // turn array into a max heap
+ for (unsigned long i = left+((right-left)>>1);; --i)
+ {
+ sort_helpers::heapify(array,left,right,i,comp);
+ if (i == left)
+ break;
+ }
+
+ // now sort the array
+ for (unsigned long i = right; i > left;)
+ {
+ exchange(array[i],array[left]);
+ sort_helpers::heapify(array,left,--i,left,comp);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SORt_
+
diff --git a/ml/dlib/dlib/sparse_vector.h b/ml/dlib/dlib/sparse_vector.h
new file mode 100644
index 000000000..4be6a3adc
--- /dev/null
+++ b/ml/dlib/dlib/sparse_vector.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SPaRSE_VECTOR_Hh_
+#define DLIB_SPaRSE_VECTOR_Hh_
+
+#include "svm/sparse_vector.h"
+
+#endif // DLIB_SPaRSE_VECTOR_Hh_
+
+
diff --git a/ml/dlib/dlib/sqlite.h b/ml/dlib/dlib/sqlite.h
new file mode 100644
index 000000000..b5aadbed2
--- /dev/null
+++ b/ml/dlib/dlib/sqlite.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SQLiTE_HEADER
+#define DLIB_SQLiTE_HEADER
+
+#include "sqlite/sqlite_tools.h"
+
+#endif // DLIB_SVm_HEADER
+
+
+
diff --git a/ml/dlib/dlib/sqlite/sqlite.h b/ml/dlib/dlib/sqlite/sqlite.h
new file mode 100644
index 000000000..7eefbb21a
--- /dev/null
+++ b/ml/dlib/dlib/sqlite/sqlite.h
@@ -0,0 +1,625 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SQLiTE_H_
+#define DLIB_SQLiTE_H_
+
+#include "sqlite_abstract.h"
+
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "../algs.h"
+#include <sqlite3.h>
+#include "../serialize.h"
+
+
+// --------------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+// --------------------------------------------------------------------------------------------
+
+ struct sqlite_error : public error
+ {
+ sqlite_error(const std::string& message): error(message) {}
+ };
+
+// --------------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct db_deleter
+ {
+ void operator()(
+ sqlite3* db
+ )const
+ {
+ sqlite3_close(db);
+ }
+ };
+ }
+
+// --------------------------------------------------------------------------------------------
+
+ class database : noncopyable
+ {
+ public:
+ database(
+ )
+ {
+ }
+
+ database (
+ const std::string& file
+ )
+ {
+ open(file);
+ }
+
+ bool is_open (
+ ) const
+ {
+ return db.get() != 0;
+ }
+
+ void open (
+ const std::string& file
+ )
+ {
+ filename = file;
+ sqlite3* ptr = 0;
+ int status = sqlite3_open(file.c_str(), &ptr);
+ db.reset(ptr, impl::db_deleter());
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ const std::string& get_database_filename (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_open() == true,
+ "\t std::string database::get_database_filename()"
+ << "\n\t The database must be opened before calling this routine."
+ << "\n\t this: " << this
+ );
+
+ return filename;
+ }
+
+ inline void exec (
+ const std::string& sql_statement
+ );
+
+ int64 last_insert_rowid (
+ ) const
+ {
+ return sqlite3_last_insert_rowid(db.get());
+ }
+
+ private:
+
+ friend class statement;
+
+ std::string filename;
+ std::shared_ptr<sqlite3> db;
+ };
+
+// --------------------------------------------------------------------------------------------
+
+ class statement : noncopyable
+ {
+ public:
+ statement (
+ database& db_,
+ const std::string sql_statement
+ ) :
+ needs_reset(false),
+ step_status(SQLITE_DONE),
+ at_first_step(true),
+ db(db_.db),
+ stmt(0),
+ sql_string(sql_statement)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(db_.is_open() == true,
+ "\t statement::statement()"
+ << "\n\t The database must be opened before calling this routine."
+ << "\n\t this: " << this
+ );
+
+ int status = sqlite3_prepare_v2(db.get(),
+ sql_string.c_str(),
+ sql_string.size()+1,
+ &stmt,
+ NULL);
+
+ if (status != SQLITE_OK)
+ {
+ sqlite3_finalize(stmt);
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ if (stmt == 0)
+ {
+ throw sqlite_error("Invalid SQL statement");
+ }
+ }
+
+ ~statement(
+ )
+ {
+ sqlite3_finalize(stmt);
+ }
+
+ void exec(
+ )
+ {
+ reset();
+
+ step_status = sqlite3_step(stmt);
+ needs_reset = true;
+ if (step_status != SQLITE_DONE && step_status != SQLITE_ROW)
+ {
+ if (step_status == SQLITE_ERROR)
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ else if (step_status == SQLITE_BUSY)
+ throw sqlite_error("statement::exec() failed. SQLITE_BUSY returned");
+ else
+ throw sqlite_error("statement::exec() failed.");
+ }
+ }
+
+ bool move_next (
+ )
+ {
+ if (step_status == SQLITE_ROW)
+ {
+ if (at_first_step)
+ {
+ at_first_step = false;
+ return true;
+ }
+ else
+ {
+ step_status = sqlite3_step(stmt);
+ if (step_status == SQLITE_DONE)
+ {
+ return false;
+ }
+ else if (step_status == SQLITE_ROW)
+ {
+ return true;
+ }
+ else
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ unsigned long get_num_columns(
+ ) const
+ {
+ if( (at_first_step==false) && (step_status==SQLITE_ROW))
+ {
+ return sqlite3_column_count(stmt);
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ const std::string& get_sql_string (
+ ) const
+ {
+ return sql_string;
+ }
+
+ template <typename T>
+ typename enable_if_c<std::numeric_limits<T>::is_integer>::type get_column (
+ unsigned long idx,
+ T& item
+ ) const
+ {
+ // unsigned ints won't fit into int all the time so put those into 64bit ints.
+ if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value))
+ item = get_column_as_int(idx);
+ else
+ item = get_column_as_int64(idx);
+ }
+
+ void get_column(unsigned long idx, std::string& item) const { item = get_column_as_text(idx); }
+ void get_column(unsigned long idx, float& item ) const { item = get_column_as_double(idx); }
+ void get_column(unsigned long idx, double& item ) const { item = get_column_as_double(idx); }
+ void get_column(unsigned long idx, long double& item) const { item = get_column_as_double(idx); }
+
+ template <typename T>
+ typename disable_if_c<std::numeric_limits<T>::is_integer>::type get_column (
+ unsigned long idx,
+ T& item
+ ) const
+ {
+ get_column_as_object(idx, item);
+ }
+
+ const std::vector<char> get_column_as_blob (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t std::vector<char> statement::get_column_as_blob()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx));
+ const int size = sqlite3_column_bytes(stmt, idx);
+
+ return std::vector<char>(data, data+size);
+ }
+
+ template <typename T>
+ void get_column_as_object (
+ unsigned long idx,
+ T& item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t void statement::get_column_as_object()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx));
+ const int size = sqlite3_column_bytes(stmt, idx);
+ std::istringstream sin(std::string(data,size));
+ deserialize(item, sin);
+ }
+
+ const std::string get_column_as_text (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t std::string statement::get_column_as_text()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ const char* data = reinterpret_cast<const char*>(sqlite3_column_text(stmt, idx));
+ if (data != 0)
+ return std::string(data);
+ else
+ return std::string();
+ }
+
+ double get_column_as_double (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t double statement::get_column_as_double()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ return sqlite3_column_double(stmt, idx);
+ }
+
+ int get_column_as_int (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t int statement::get_column_as_int()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ return sqlite3_column_int(stmt, idx);
+ }
+
+ int64 get_column_as_int64 (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t int64 statement::get_column_as_int64()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ return sqlite3_column_int64(stmt, idx);
+ }
+
+ const std::string get_column_name (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < get_num_columns(),
+ "\t std::string statement::get_column_name()"
+ << "\n\t Invalid column index."
+ << "\n\t idx: " << idx
+ << "\n\t this: " << this
+ );
+
+ return std::string(sqlite3_column_name(stmt,idx));
+ }
+
+ unsigned long get_max_parameter_id (
+ ) const
+ {
+ return sqlite3_limit(db.get(), SQLITE_LIMIT_VARIABLE_NUMBER, -1);
+ }
+
+ unsigned long get_parameter_id (
+ const std::string& name
+ ) const
+ {
+ return sqlite3_bind_parameter_index(stmt, name.c_str());
+ }
+
+ template <typename T>
+ typename enable_if_c<std::numeric_limits<T>::is_integer>::type bind (
+ unsigned long idx,
+ const T& item
+ )
+ {
+ // unsigned ints won't fit into int all the time so put those into 64bit ints.
+ if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value))
+ bind_int(idx, item);
+ else
+ bind_int64(idx, item);
+ }
+
+ void bind(unsigned long idx, const std::string& item) { bind_text(idx, item); }
+ void bind(unsigned long idx, const float& item ) { bind_double(idx, item); }
+ void bind(unsigned long idx, const double& item ) { bind_double(idx, item); }
+ void bind(unsigned long idx, const long double& item) { bind_double(idx, item); }
+
+ template <typename T>
+ typename disable_if_c<std::numeric_limits<T>::is_integer>::type bind (
+ unsigned long idx,
+ const T& item
+ )
+ {
+ bind_object(idx, item);
+ }
+
+ void bind_blob (
+ unsigned long parameter_id,
+ const std::vector<char>& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_blob()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_blob(stmt, parameter_id, &item[0], item.size(), SQLITE_TRANSIENT);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ template <typename T>
+ void bind_object (
+ unsigned long parameter_id,
+ const T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_object()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ std::ostringstream sout;
+ serialize(item, sout);
+ const std::string& str = sout.str();
+ int status = sqlite3_bind_blob(stmt, parameter_id, str.data(), str.size(), SQLITE_TRANSIENT);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ void bind_double (
+ unsigned long parameter_id,
+ const double& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_double()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_double(stmt, parameter_id, item);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ void bind_int (
+ unsigned long parameter_id,
+ const int& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_int()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_int(stmt, parameter_id, item);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ void bind_int64 (
+ unsigned long parameter_id,
+ const int64& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_int64()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_int64(stmt, parameter_id, item);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ void bind_null (
+ unsigned long parameter_id
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_null()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_null(stmt, parameter_id);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ void bind_text (
+ unsigned long parameter_id,
+ const std::string& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
+ "\t void statement::bind_text()"
+ << "\n\t Invalid parameter id."
+ << "\n\t parameter_id: " << parameter_id
+ << "\n\t get_max_parameter_id(): " << get_max_parameter_id()
+ << "\n\t this: " << this
+ );
+
+ reset();
+ int status = sqlite3_bind_text(stmt, parameter_id, item.c_str(), -1, SQLITE_TRANSIENT);
+
+ if (status != SQLITE_OK)
+ {
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ }
+
+ private:
+
+ void reset()
+ {
+ if (needs_reset)
+ {
+ if (sqlite3_reset(stmt) != SQLITE_OK)
+ {
+ step_status = SQLITE_DONE;
+ throw sqlite_error(sqlite3_errmsg(db.get()));
+ }
+ needs_reset = false;
+ step_status = SQLITE_DONE;
+ at_first_step = true;
+ }
+ }
+
+ bool needs_reset; // true if sqlite3_step() has been called more recently than sqlite3_reset()
+ int step_status;
+ bool at_first_step;
+
+ std::shared_ptr<sqlite3> db;
+ sqlite3_stmt* stmt;
+ std::string sql_string;
+ };
+
+// --------------------------------------------------------------------------------------------
+
+ void database::
+ exec (
+ const std::string& sql_statement
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_open() == true,
+ "\t void database::exec()"
+ << "\n\t The database must be opened before calling this routine."
+ << "\n\t this: " << this
+ );
+
+ statement(*this, sql_statement).exec();
+ }
+
+// --------------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SQLiTE_H_
+
diff --git a/ml/dlib/dlib/sqlite/sqlite_abstract.h b/ml/dlib/dlib/sqlite/sqlite_abstract.h
new file mode 100644
index 000000000..7372162d8
--- /dev/null
+++ b/ml/dlib/dlib/sqlite/sqlite_abstract.h
@@ -0,0 +1,506 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SQLiTE_ABSTRACT_H_
+#ifdef DLIB_SQLiTE_ABSTRACT_H_
+
+
+#include <iostream>
+#include <vector>
+#include "../algs.h"
+#include <sqlite3.h>
+#include "../smart_pointers.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct sqlite_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception object used by the SQLite tools to indicate
+ that an error has occurred. Any of the functions defined in this
+ file might throw this exception.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class database : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a C++ wrapper around a SQLite database connection
+ handle and therefore represents a SQLite database file.
+
+ Note that this wrapper is targeted at SQLite Version 3.
+
+ Note also that whenever SQLite indicates an error has occurred
+ this object will throw the sqlite_error exception.
+ !*/
+
+ public:
+ database(
+ );
+ /*!
+ ensures
+ - #is_open() == false
+ !*/
+
+ database (
+ const std::string& file
+ );
+ /*!
+ ensures
+ - opens the indicated database file or creates a new
+ database with the given name if one doesn't already exist.
+ - #get_database_filename() == file
+ - #is_open() == true
+ !*/
+
+ ~database (
+ );
+ /*!
+ ensures
+ - safely disposes of any SQLite database connection. If
+ any statement objects still exist which reference this database
+ then the SQLite database connection won't be fully closed
+ until those statement objects are also destroyed. This allows
+ for any destruction order between database and statement objects.
+ !*/
+
+ void open (
+ const std::string& file
+ );
+ /*!
+ ensures
+ - opens the indicated database file or creates a new
+ database with the given name if one doesn't already exist.
+ - #get_database_filename() == file
+ - #is_open() == true
+ - safely disposes of any previous SQLite database connection. If
+ any statement objects still exist which reference this database
+ then the SQLite database connection won't be fully closed
+ until those statement objects are also destroyed.
+ !*/
+
+ bool is_open (
+ ) const;
+ /*!
+ ensures
+ - if (this object has an open connection to a SQLite database) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ const std::string& get_database_filename (
+ ) const;
+ /*!
+ requires
+ - is_open() == true
+ ensures
+ - returns the name of the SQLite database file this object
+ currently has open.
+ !*/
+
+ void exec (
+ const std::string& sql_statement
+ );
+ /*!
+ requires
+ - is_open() == true
+ ensures
+ - executes the supplied SQL statement against this database
+ !*/
+
+ int64 last_insert_rowid (
+ ) const;
+ /*!
+ requires
+ - is_open() == true
+ ensures
+ - Each element in a database table has a rowid which uniquely identifies
+ it. Therefore, this routine returns the rowid of the most recent
+ successful INSERT into the database via this database instance.
+ - If an INSERT has not been performed on the current database instance then
+ the return value is 0. This is true even if the database is not empty.
+ - See the sqlite documentation for the full details on how this function
+ behaves: http://www.sqlite.org/c3ref/last_insert_rowid.html
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class statement : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a SQL statement which can be executed
+ against a database object. In particular, this object is a
+ C++ wrapper around a SQLite prepared statement.
+
+
+ Note that whenever SQLite indicates an error has occurred this
+ object will throw the sqlite_error exception.
+
+ BINDABLE SQL PARAMETERS
+ Sometimes you want to execute a bunch of very similar SQL statements.
+ For example, you might need to execute many insert statements where each
+ statement changes only the value of a field. Since it is somewhat
+ costly to construct a statement object for each SQL operation, SQLite
+ supports defining bindable parameters for a statement object. This allows
+ you to reuse the same statement object.
+
+ Therefore, in SQL statements used with SQLite, wherever it is valid to
+ include a string literal, one can use a parameter in one of the following
+ forms:
+
+ ?
+ ?NNN
+ :AAA
+ $AAA
+ @AAA
+
+ In the examples above, NNN is an integer value and AAA is an identifier. A
+ parameter initially has a value of NULL. You can use the bind_*() routines
+ to attach values to the parameters. Each call to a bind_*() routine overrides
+ prior bindings on the same parameter.
+
+ Each SQL parameter has a numeric ID which is used to reference it when invoking
+ a bind_*() routine. The leftmost SQL parameter in a statement has an index of 1,
+ the next parameter has an index of 2, and so on, except when the following rules
+ apply. When the same named SQL parameter is used more than once, second and
+ subsequent occurrences have the same index as the first occurrence. The index
+ for named parameters can be looked up using the get_parameter_id() method if desired.
+ The index for "?NNN" parameters is the value of NNN. The NNN value must be between
+ 1 and get_max_parameter_id().
+ !*/
+
+ public:
+ statement (
+ database& db,
+ const std::string sql_statement
+ );
+ /*!
+ requires
+ - db.is_open() == true
+ ensures
+ - The given SQL statement can be executed against the given
+ database by calling exec().
+ - #get_sql_string() == sql_statement
+ !*/
+
+ ~statement(
+ );
+ /*!
+ ensures
+ - any resources associated with this object have been freed.
+ !*/
+
+ const std::string& get_sql_string (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the SQL statement used to create this statement object.
+ !*/
+
+ void exec(
+ );
+ /*!
+ ensures
+ - #get_num_columns() == 0
+ - executes the SQL statement get_sql_string() against the database
+ given to this object's constructor.
+ - If this was a select statement then you can obtain the resulting
+ rows by calling move_next() and using the get_column_as_*() member
+ functions.
+ !*/
+
+ // ----------------------------
+
+ bool move_next (
+ );
+ /*!
+ ensures
+ - if (there is a result row for this query) then
+ - #get_num_columns() == the number of columns in the result row.
+ - The get_column_as_*() routines can be used to access the elements
+ of the row data.
+ - returns true
+ - else
+ - returns false
+ - #get_num_columns() == 0
+ !*/
+
+ unsigned long get_num_columns(
+ ) const;
+ /*!
+ ensures
+ - returns the number of columns of data available via the get_column_as_*()
+ routines.
+ !*/
+
+ template <
+ typename T
+ >
+ void get_column (
+ unsigned long idx,
+ T& item
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - This function automatically selects how to extract the column data based
+ on the type of item given. In particular:
+ - if (T is a 32bit or smaller built in integer type) then
+ - #item == get_column_as_int(idx)
+ - else if (T is a 64bit built in integer type) then
+ - #item == get_column_as_int64(idx)
+ - else if (T is float, double, or long double) then
+ - #item == get_column_as_double(idx)
+ - else if (T is std::string) then
+ - #item == get_column_as_text(idx)
+ - else
+ - invokes: get_column_as_object(idx, item)
+ !*/
+
+ const std::vector<char> get_column_as_blob (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the contents of the idx-th column as a binary BLOB.
+ !*/
+
+ template <
+ typename T
+ >
+ void get_column_as_object (
+ unsigned long idx,
+ T& item
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ - item is deserializable
+ (i.e. Calling deserialize(item, some_input_stream) reads an item
+ of type T from the some_input_stream stream)
+ ensures
+ - gets the contents of the idx-th column as a binary BLOB and then
+ deserializes it into item.
+ !*/
+
+ const std::string get_column_as_text (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the contents of the idx-th column as a text string.
+ !*/
+
+ double get_column_as_double (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the contents of the idx-th column as a double.
+ !*/
+
+ int get_column_as_int (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the contents of the idx-th column as an int.
+ !*/
+
+ int64 get_column_as_int64 (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the contents of the idx-th column as a 64bit int.
+ !*/
+
+ const std::string get_column_name (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < get_num_columns()
+ ensures
+ - returns the name of the idx-th column. In particular:
+ The name of a result column is the value of the "AS" clause for
+ that column, if there is an AS clause. If there is no AS clause
+ then the name of the column is unspecified and may change from
+ one release of SQLite to the next.
+ !*/
+
+ // ----------------------------
+
+ unsigned long get_max_parameter_id (
+ ) const;
+ /*!
+ ensures
+ - returns the max parameter ID value which can be used with the
+ bind_() member functions defined below.
+ - In SQLite, the default value of this limit is usually 999.
+ !*/
+
+ unsigned long get_parameter_id (
+ const std::string& name
+ ) const;
+ /*!
+ ensures
+ - if (This SQL statement contains a SQL parameter with the given name) then
+ - returns the parameter_id number which can be used in the bind_*()
+ member functions defined below.
+ - else
+ - returns 0
+ !*/
+
+ template <
+ typename T
+ >
+ void bind (
+ unsigned long parameter_id,
+ const T& item
+ ) const;
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - This function automatically selects how to bind item to a statement based
+ on the type of item given. In particular:
+ - if (T is a 32bit or smaller built in integer type) then
+ - invokes: bind_int(parameter_id, item)
+ - else if (T is a 64bit built in integer type) then
+ - invokes: bind_int64(parameter_id, item)
+ - else if (T is float, double, or long double) then
+ - invokes: bind_double(parameter_id, item)
+ - else if (T is std::string) then
+ - invokes: bind_text(parameter_id, item)
+ - else
+ - invokes: bind_object(parameter_id, item)
+ !*/
+
+ void bind_blob (
+ unsigned long parameter_id,
+ const std::vector<char>& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id.
+ !*/
+
+ template <
+ typename T
+ >
+ void bind_object (
+ unsigned long parameter_id,
+ const T& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ - item is serializable
+ (i.e. Calling serialize(item, some_output_stream) writes an item
+ of type T to the some_output_stream stream)
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id. This is performed by serializing item and then
+ binding it as a binary BLOB.
+ !*/
+
+ void bind_double (
+ unsigned long parameter_id,
+ const double& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id.
+ !*/
+
+ void bind_int (
+ unsigned long parameter_id,
+ const int& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id.
+ !*/
+
+ void bind_int64 (
+ unsigned long parameter_id,
+ const int64& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id.
+ !*/
+
+ void bind_null (
+ unsigned long parameter_id
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds a NULL to the SQL parameter indicated by parameter_id.
+ !*/
+
+ void bind_text (
+ unsigned long parameter_id,
+ const std::string& item
+ );
+ /*!
+ requires
+ - 1 <= parameter_id <= get_max_parameter_id()
+ ensures
+ - #get_num_columns() == 0
+ - binds the value of item into the SQL parameter indicated by
+ parameter_id.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SQLiTE_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/sqlite/sqlite_tools.h b/ml/dlib/dlib/sqlite/sqlite_tools.h
new file mode 100644
index 000000000..062a6b2c8
--- /dev/null
+++ b/ml/dlib/dlib/sqlite/sqlite_tools.h
@@ -0,0 +1,189 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SQLiTE_TOOLS_H_
+#define DLIB_SQLiTE_TOOLS_H_
+
+
+#include "sqlite_tools_abstract.h"
+#include "sqlite.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ class transaction : noncopyable
+ {
+ public:
+ transaction (
+ database& db_
+ ) :
+ db(db_),
+ committed(false)
+ {
+ db.exec("begin transaction");
+ }
+
+ void commit ()
+ {
+ if (!committed)
+ {
+ committed = true;
+ db.exec("commit");
+ }
+ }
+
+ ~transaction() throw (std::exception)
+ {
+ if (!committed)
+ db.exec("rollback");
+ }
+
+ private:
+ database& db;
+ bool committed;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename T
+ >
+ void query_object (
+ database& db,
+ const std::string& query,
+ T& item
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ st.get_column_as_object(0,item);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::string query_text (
+ database& db,
+ const std::string& query
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ const std::string& temp = st.get_column_as_text(0);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ return temp;
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double query_double (
+ database& db,
+ const std::string& query
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ double temp = st.get_column_as_double(0);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ return temp;
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline int query_int (
+ database& db,
+ const std::string& query
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ int temp = st.get_column_as_int(0);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ return temp;
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline int64 query_int64 (
+ database& db,
+ const std::string& query
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ int64 temp = st.get_column_as_int64(0);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ return temp;
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline const std::vector<char> query_blob (
+ database& db,
+ const std::string& query
+ )
+ {
+ statement st(db, query);
+ st.exec();
+ if (st.move_next() && st.get_num_columns() == 1)
+ {
+ const std::vector<char>& temp = st.get_column_as_blob(0);
+ if (st.move_next())
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ return temp;
+ }
+ else
+ {
+ throw sqlite_error("query doesn't result in exactly 1 element");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SQLiTE_TOOLS_H_
+
diff --git a/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h b/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h
new file mode 100644
index 000000000..c13a09265
--- /dev/null
+++ b/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h
@@ -0,0 +1,164 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SQLiTE_TOOLS_ABSTRACT_H_
+#ifdef DLIB_SQLiTE_TOOLS_ABSTRACT_H_
+
+
+#include "sqlite_abstract.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ class transaction : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for creating exception safe
+ database transactions.
+ !*/
+
+ public:
+ transaction (
+ database& db
+ );
+ /*!
+ ensures
+ - Begins a database transaction which will be rolled back
+ if commit() isn't called eventually.
+ - In particular, performs: db.exec("begin transaction");
+ !*/
+
+ void commit (
+ );
+ /*!
+ ensures
+ - if (commit() hasn't already been called) then
+ - Commits all changes made during this database transaction.
+ - In particular, performs: db.exec("commit");
+ - else
+ - does nothing
+ !*/
+
+ ~transaction(
+ );
+ /*!
+ ensures
+ - if (commit() was never called) then
+ - rolls back any changes made to the database during this transaction.
+ - In particular, performs: db.exec("rollback");
+ - else
+ - does nothing
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void query_object (
+ database& db,
+ const std::string& query,
+ T& item
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ interpreted as a binary BLOB and deserialized into item.
+ throws
+ - sqlite_error or serialization_error if an error occurs which prevents
+ this operation from succeeding.
+
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::string query_text (
+ database& db,
+ const std::string& query
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ converted to text and returned.
+ throws
+ - sqlite_error if an error occurs which prevents this operation from
+ succeeding.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double query_double (
+ database& db,
+ const std::string& query
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ converted to a double and returned.
+ throws
+ - sqlite_error if an error occurs which prevents this operation from
+ succeeding.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ int query_int (
+ database& db,
+ const std::string& query
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ converted to an int and returned.
+ throws
+ - sqlite_error if an error occurs which prevents this operation from
+ succeeding.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ int64 query_int64 (
+ database& db,
+ const std::string& query
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ converted to an int64 and returned.
+ throws
+ - sqlite_error if an error occurs which prevents this operation from
+ succeeding.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::vector<char> query_blob (
+ database& db,
+ const std::string& query
+ );
+ /*!
+ ensures
+ - executes the given SQL query against db. If the query results in a
+ single row and column being returned then the data in the column is
+ returned as a binary BLOB.
+ throws
+ - sqlite_error if an error occurs which prevents this operation from
+ succeeding.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SQLiTE_TOOLS_H_
+
+
diff --git a/ml/dlib/dlib/sstream b/ml/dlib/dlib/sstream
new file mode 100644
index 000000000..eb0e59e41
--- /dev/null
+++ b/ml/dlib/dlib/sstream
@@ -0,0 +1 @@
+#include "dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/stack.h b/ml/dlib/dlib/stack.h
new file mode 100644
index 000000000..58f2c5a6c
--- /dev/null
+++ b/ml/dlib/dlib/stack.h
@@ -0,0 +1,34 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STACk_
+#define DLIB_STACk_
+
+#include "stack/stack_kernel_1.h"
+#include "stack/stack_kernel_c.h"
+#include "algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class stack
+ {
+ stack() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef stack_kernel_1<T,mem_manager>
+ kernel_1a;
+ typedef stack_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+ };
+}
+
+#endif // DLIB_STACk_
+
diff --git a/ml/dlib/dlib/stack/stack_kernel_1.h b/ml/dlib/dlib/stack/stack_kernel_1.h
new file mode 100644
index 000000000..427d65183
--- /dev/null
+++ b/ml/dlib/dlib/stack/stack_kernel_1.h
@@ -0,0 +1,504 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STACK_KERNEl_1_
+#define DLIB_STACK_KERNEl_1_
+
+#include "stack_kernel_abstract.h"
+#include "../algs.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class stack_kernel_1 : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ INITIAL VALUE
+ stack_size == 0
+ top == 0
+ current_element == 0
+ _at_start == true
+
+
+ CONVENTION
+ at_start() == _at_start
+ current_element_valid() == (current_element != 0)
+ if (current_element != 0) then
+ element() == current_element->item
+
+ stack_size == the number of elements in the stack.
+ Each node points to the next node to be poped off the stack.
+ The last node in the list has its next pointer is set to 0.
+
+ if (size == 0)
+ {
+ top == 0
+ }
+ else
+ {
+ top == pointer to the last element added to the stack
+ }
+ !*/
+
+ struct node
+ {
+ node* next;
+ T item;
+ };
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ stack_kernel_1(
+ ):
+ top(0),
+ stack_size(0),
+ current_element(0),
+ _at_start(true)
+ {}
+
+ virtual ~stack_kernel_1(
+ );
+
+ inline void clear(
+ );
+
+ inline void push(
+ T& item
+ );
+
+ void pop(
+ T& item
+ );
+
+ T& current(
+ );
+
+ const T& current(
+ ) const;
+
+ inline void swap (
+ stack_kernel_1& item
+ );
+
+ // functions from the remover interface
+ inline void remove_any (
+ T& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline T& element (
+ );
+
+ bool move_next (
+ ) const;
+
+ private:
+
+ void delete_elements_in_stack(
+ node*& top
+ );
+ /*!
+ requires
+ - top points to the top of the stack
+ ensures
+ - all memory has been freed
+ - #top = 0
+ !*/
+
+
+ // data members
+ typename mem_manager::template rebind<node>::other pool;
+ node* top;
+ unsigned long stack_size;
+ mutable node* current_element;
+ mutable bool _at_start;
+
+
+ // restricted functions
+ stack_kernel_1(stack_kernel_1&); // copy constructor
+ stack_kernel_1& operator=(stack_kernel_1&); // assignment operator
+
+ };
+
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ stack_kernel_1<T,mem_manager>& a,
+ stack_kernel_1<T,mem_manager>& b
+ ) { a.swap(b); }
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ stack_kernel_1<T,mem_manager>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ T temp = T();
+ stack_kernel_1<T> temp_stack;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(temp,in);
+ temp_stack.push(temp);
+ }
+ while (temp_stack.size() > 0)
+ {
+ temp_stack.pop(temp);
+ item.push(temp);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.clear();
+ throw serialization_error(e.info + "\n while deserializing object of type stack_kernel_1");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ stack_kernel_1<T,mem_manager>::
+ ~stack_kernel_1(
+ )
+ {
+ delete_elements_in_stack(top);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ clear(
+ )
+ {
+ if (stack_size != 0)
+ {
+ delete_elements_in_stack(top);
+ stack_size = 0;
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& stack_kernel_1<T,mem_manager>::
+ current(
+ )
+ {
+ return top->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& stack_kernel_1<T,mem_manager>::
+ current(
+ ) const
+ {
+ return top->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ swap(
+ stack_kernel_1<T,mem_manager>& item
+ )
+ {
+ pool.swap(item.pool);
+
+ // declare temp variables
+ node* top_temp;
+ unsigned long stack_size_temp;
+
+ // swap stack_size variables
+ stack_size_temp = item.stack_size;
+ item.stack_size = stack_size;
+ stack_size = stack_size_temp;
+
+ // swap top pointers
+ top_temp = item.top;
+ item.top = top;
+ top = top_temp;
+
+ exchange(current_element,item.current_element);
+ exchange(_at_start,item._at_start);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ push(
+ T& item
+ )
+ {
+ // allocate memory for new node
+ node* new_node = pool.allocate();
+
+ // swap item into new_node
+ exchange(new_node->item,item);
+
+ // put new_node into stack
+ new_node->next = top;
+ top = new_node;
+ ++stack_size;
+
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ pop(
+ T& item
+ )
+ {
+ node* old_node = top;
+ top = top->next;
+
+ // swap the item from the stack into item
+ exchange(old_node->item,item);
+
+ // free the memory
+ pool.deallocate(old_node);
+ --stack_size;
+
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ delete_elements_in_stack(
+ node*& top
+ )
+ {
+ node* temp;
+ while (top != 0)
+ {
+ temp = top->next;
+ pool.deallocate(top);
+ top = temp;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ size_t stack_kernel_1<T,mem_manager>::
+ size (
+ ) const
+ {
+ return stack_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool stack_kernel_1<T,mem_manager>::
+ at_start (
+ ) const
+ {
+ return _at_start;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ reset (
+ ) const
+ {
+ _at_start = true;
+ current_element = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool stack_kernel_1<T,mem_manager>::
+ current_element_valid (
+ ) const
+ {
+ return current_element != 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ const T& stack_kernel_1<T,mem_manager>::
+ element (
+ ) const
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ T& stack_kernel_1<T,mem_manager>::
+ element (
+ )
+ {
+ return current_element->item;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ bool stack_kernel_1<T,mem_manager>::
+ move_next (
+ ) const
+ {
+ if (!_at_start)
+ {
+ if (current_element)
+ {
+ current_element = current_element->next;
+ if (current_element)
+ return true;
+ else
+ return false;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else
+ {
+ _at_start = false;
+ if (stack_size)
+ {
+ current_element = top;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // remover function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void stack_kernel_1<T,mem_manager>::
+ remove_any (
+ T& item
+ )
+ {
+ pop(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STACK_KERNEl_1_
+
diff --git a/ml/dlib/dlib/stack/stack_kernel_abstract.h b/ml/dlib/dlib/stack/stack_kernel_abstract.h
new file mode 100644
index 000000000..d86cc0629
--- /dev/null
+++ b/ml/dlib/dlib/stack/stack_kernel_abstract.h
@@ -0,0 +1,180 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STACK_KERNEl_ABSTRACT_
+#ifdef DLIB_STACK_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename mem_manager = default_memory_manager
+ >
+ class stack : public enumerable<T>,
+ public remover<T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be swappable by a global swap() and
+ T must have a default constructor
+
+ REQUIREMENTS ON mem_manager
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ mem_manager::type can be set to anything.
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ swap() and current() functions do not invalidate pointers
+ or references to internal data.
+ All other functions have no such guarantee.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the stack in the
+ same order they would be removed in by repeated calls to pop().
+ (e.g. current() would be the first element enumerated)
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a last in first out stack containing items of type T.
+
+ e.g. if the stack is {b,c,d,e} then a is put in
+ the stack becomes {a,b,c,d,e} and then pop takes a back out
+ returning the stack to {b,c,d,e}
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef mem_manager mem_manager_type;
+
+ stack (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ !*/
+
+ virtual ~stack (
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void push (
+ T& item
+ );
+ /*!
+ ensures
+ - item has been swapped onto the top of the stack
+ - #current() == item
+ - #item has an initial value for its type
+ - #size() == size() + 1
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ if push() throws then it has no effect
+ !*/
+
+ void pop (
+ T& item
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - #size() == size() - 1
+ - #item == current()
+ i.e. the top element of *this has been removed and swapped
+ into #item
+ - #at_start() == true
+ !*/
+
+ T& current (
+ );
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a const reference to the element at the top of *this
+ !*/
+
+ const T& current (
+ ) const;
+ /*!
+ requires
+ - size() != 0
+ ensures
+ - returns a non-const reference to the element at the top of *this
+ !*/
+
+ void swap (
+ stack& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ private:
+
+ // restricted functions
+ stack(stack&); // copy constructor
+ stack& operator=(stack&); // assignment operator
+
+ };
+
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ inline void swap (
+ stack<T,mem_manager>& a,
+ stack<T,mem_manager>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename mem_manager
+ >
+ void deserialize (
+ stack<T,mem_manager>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_STACK_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/stack/stack_kernel_c.h b/ml/dlib/dlib/stack/stack_kernel_c.h
new file mode 100644
index 000000000..ec8642a40
--- /dev/null
+++ b/ml/dlib/dlib/stack/stack_kernel_c.h
@@ -0,0 +1,189 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STACK_KERNEl_C_
+#define DLIB_STACK_KERNEl_C_
+
+#include "stack_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename stack_base
+ >
+ class stack_kernel_c : public stack_base
+ {
+ typedef typename stack_base::type T;
+ public:
+ void pop(
+ T& item
+ );
+
+ T& current(
+ );
+
+ const T& current(
+ ) const;
+
+ const T& element(
+ ) const;
+
+ T& element(
+ );
+
+ void remove_any (
+ T& item
+ );
+
+ };
+
+
+ template <
+ typename stack_base
+ >
+ inline void swap (
+ stack_kernel_c<stack_base>& a,
+ stack_kernel_c<stack_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ void stack_kernel_c<stack_base>::
+ pop(
+ T& item
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tvoid stack::pop"
+ << "\n\tsize of stack should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ stack_base::pop(item);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ const typename stack_base::type& stack_kernel_c<stack_base>::
+ current(
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tconst T& stack::current"
+ << "\n\tsize of stack should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return stack_base::current();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ typename stack_base::type& stack_kernel_c<stack_base>::
+ current(
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->size() != 0,
+ "\tT& stack::current"
+ << "\n\tsize of stack should not be zero"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return stack_base::current();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ typename stack_base::type& stack_kernel_c<stack_base>::
+ element(
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid(),
+ "\tT& stack::element"
+ << "\n\tThe current element must be valid if you are to access it."
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return stack_base::element();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ const typename stack_base::type& stack_kernel_c<stack_base>::
+ element(
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid(),
+ "\tconst T& stack::element"
+ << "\n\tThe current element must be valid if you are to access it."
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return stack_base::element();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename stack_base
+ >
+ void stack_kernel_c<stack_base>::
+ remove_any (
+ T& item
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( (this->size() > 0),
+ "\tvoid stack::remove_any"
+ << "\n\tsize() must be greater than zero if something is going to be removed"
+ << "\n\tsize(): " << this->size()
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ stack_base::remove_any(item);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STACK_KERNEl_C_
+
diff --git a/ml/dlib/dlib/stack_trace.cpp b/ml/dlib/dlib/stack_trace.cpp
new file mode 100644
index 000000000..0a6ff8ee6
--- /dev/null
+++ b/ml/dlib/dlib/stack_trace.cpp
@@ -0,0 +1,91 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STACK_TRACE_CPp_
+#define DLIB_STACK_TRACE_CPp_
+
+#if defined(DLIB_ENABLE_STACK_TRACE) && !defined(NO_MAKEFILE)
+
+#include <sstream>
+#include <cstring>
+#include "stack_trace.h"
+#include "stack.h"
+#include "memory_manager.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace
+ {
+ struct stack_tracer_data
+ {
+ stack_tracer_data(
+ ) : funct_name(0),
+ file_name(0),
+ line_number(0)
+ {}
+ const char* funct_name;
+ const char* file_name;
+ int line_number;
+ };
+
+ using stack_tracer_stack_type = stack<stack_tracer_data,memory_manager<char>::kernel_2a>::kernel_1a;
+
+ stack_tracer_stack_type& get_dlib_stack_trace_stack()
+ {
+ thread_local stack_tracer_stack_type a;
+ return a;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ stack_tracer::
+ stack_tracer (
+ const char* funct_name,
+ const char* file_name,
+ const int line_number
+ )
+ {
+ stack_tracer_data data;
+ data.funct_name = funct_name;
+ data.file_name = file_name;
+ data.line_number = line_number;
+
+ // pop the info onto the function stack trace
+ get_dlib_stack_trace_stack().push(data);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ stack_tracer::
+ ~stack_tracer()
+ {
+ stack_tracer_data temp;
+ get_dlib_stack_trace_stack().pop(temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string get_stack_trace()
+ {
+ std::ostringstream sout;
+ auto& stack = get_dlib_stack_trace_stack();
+ stack.reset();
+ while (stack.move_next())
+ {
+ stack_tracer_data data = stack.element();
+ sout << data.file_name << ":" << data.line_number << "\n " << data.funct_name << "\n";
+ }
+ return sout.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif
+
+#endif // DLIB_STACK_TRACE_CPp_
+
+
diff --git a/ml/dlib/dlib/stack_trace.h b/ml/dlib/dlib/stack_trace.h
new file mode 100644
index 000000000..aacbeb782
--- /dev/null
+++ b/ml/dlib/dlib/stack_trace.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STACK_TRACe_
+#define DLIB_STACK_TRACe_
+
+/*!
+ This file defines 3 things. Two of them are preprocessor macros that
+ enable you to tag functions with the dlib stack trace watcher. The
+ third thing is a function named get_stack_trace() which returns the
+ current stack trace in std::string form.
+
+ To enable the stack trace you must #define DLIB_ENABLE_STACK_TRACE.
+ When this #define isn't set then the 3 things described above
+ still exist but they don't do anything.
+
+ Also note that when the stack trace is enabled it changes the DLIB_ASSERT
+ and DLIB_CASSERT macros so that they print stack traces when
+ an assert fails.
+
+ See the following example program for details:
+
+ #include <iostream>
+ #include <dlib/stack_trace.h>
+
+ void funct2()
+ {
+ // put this macro at the top of each function you would
+ // like to appear in stack traces
+ DLIB_STACK_TRACE;
+
+ // you may print the current stack trace as follows.
+ std::cout << dlib::get_stack_trace() << endl;
+ }
+
+ void funct()
+ {
+ // This alternate form of DLIB_STACK_TRACE allows you to specify
+ // the string used to name the current function. The other form
+ // will usually output an appropriate function name automatically
+ // so this may not be needed.
+ DLIB_STACK_TRACE_NAMED("funct");
+ funct2();
+ }
+
+ int main()
+ {
+ funct();
+ }
+!*/
+
+
+#include <string>
+#include "assert.h"
+
+// only setup the stack trace stuff if the asserts are enabled (which happens in debug mode
+// basically). Also, this stuff doesn't work if you use NO_MAKEFILE
+#if defined(DLIB_ENABLE_STACK_TRACE)
+#ifdef NO_MAKEFILE
+#error "You can't use the dlib stack trace stuff and NO_MAKEFILE at the same time"
+#endif
+
+namespace dlib
+{
+ const std::string get_stack_trace();
+}
+
+// redefine the DLIB_CASSERT macro to include the stack trace
+#undef DLIBM_CASSERT
+#define DLIBM_CASSERT(_exp,_message) \
+ {if ( !(_exp) ) \
+ { \
+ std::ostringstream dlib_o_out; \
+ dlib_o_out << "\n\nError occurred at line " << __LINE__ << ".\n"; \
+ dlib_o_out << "Error occurred in file " << __FILE__ << ".\n"; \
+ dlib_o_out << "Error occurred in function " << DLIB_FUNCTION_NAME << ".\n\n"; \
+ dlib_o_out << "Failing expression was " << #_exp << ".\n"; \
+ dlib_o_out << _message << "\n\n"; \
+ dlib_o_out << "Stack Trace: \n" << dlib::get_stack_trace() << "\n"; \
+ dlib_assert_breakpoint(); \
+ throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \
+ }}
+
+
+
+namespace dlib
+{
+
+ class stack_tracer
+ {
+ public:
+ stack_tracer (
+ const char* funct_name,
+ const char* file_name,
+ const int line_number
+ );
+
+ ~stack_tracer();
+
+ };
+}
+
+#define DLIB_STACK_TRACE_NAMED(x) dlib::stack_tracer dlib_stack_tracer_object(x,__FILE__,__LINE__)
+#define DLIB_STACK_TRACE dlib::stack_tracer dlib_stack_tracer_object(DLIB_FUNCTION_NAME,__FILE__,__LINE__)
+
+#else // don't do anything if ENABLE_ASSERTS isn't defined
+#define DLIB_STACK_TRACE_NAMED(x)
+#define DLIB_STACK_TRACE
+
+namespace dlib
+{
+ inline const std::string get_stack_trace() { return std::string("stack trace not enabled");}
+}
+
+#endif
+
+
+#endif // DLIB_STACK_TRACe_
+
diff --git a/ml/dlib/dlib/static_map.h b/ml/dlib/dlib/static_map.h
new file mode 100644
index 000000000..f1fcadab9
--- /dev/null
+++ b/ml/dlib/dlib/static_map.h
@@ -0,0 +1,43 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_MAp_
+#define DLIB_STATIC_MAp_
+
+#include "static_map/static_map_kernel_1.h"
+#include "static_map/static_map_kernel_c.h"
+
+#include <functional>
+
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename compare = std::less<domain>
+ >
+ class static_map
+ {
+ static_map() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef static_map_kernel_1<domain,range,compare>
+ kernel_1a;
+ typedef static_map_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+
+
+
+ };
+}
+
+#endif // DLIB_STATIC_MAp_
+
diff --git a/ml/dlib/dlib/static_map/static_map_kernel_1.h b/ml/dlib/dlib/static_map/static_map_kernel_1.h
new file mode 100644
index 000000000..a7b627ae6
--- /dev/null
+++ b/ml/dlib/dlib/static_map/static_map_kernel_1.h
@@ -0,0 +1,756 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_MAP_KERNEl_1_
+#define DLIB_STATIC_MAP_KERNEl_1_
+
+#include "static_map_kernel_abstract.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename compare = std::less<domain>
+ >
+ class static_map_kernel_1 : public enumerable<map_pair<domain,range> >
+ {
+
+ /*!
+ INITIAL VALUE
+ - map_size == 0
+ - d == 0
+ - r == 0
+ - mp.d = 0;
+ - at_start_ == true
+
+
+ CONVENTION
+ - size() == map_size
+ - if (size() > 0) then
+ - d == pointer to an array containing all the domain elements
+ - r == pointer to an array containing all the range elements
+ - for every i: operator[](d[i]) == r[i]
+ - d is sorted according to operator<
+ - else
+ - d == 0
+ - r == 0
+
+ - current_element_valid() == (mp.d != 0)
+ - at_start() == (at_start_)
+ - if (current_element_valid()) then
+ - element() == mp
+ !*/
+
+ class mpair : public map_pair<domain,range>
+ {
+ public:
+ const domain* d;
+ range* r;
+
+ const domain& key(
+ ) const { return *d; }
+
+ const range& value(
+ ) const { return *r; }
+
+ range& value(
+ ) { return *r; }
+ };
+
+
+ // I would define this outside the class but Borland 5.5 has some problems
+ // with non-inline templated friend functions.
+ friend void deserialize (
+ static_map_kernel_1& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ unsigned long size;
+ deserialize(size,in);
+ item.map_size = size;
+ item.d = new domain[size];
+ item.r = new range[size];
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(item.d[i],in);
+ deserialize(item.r[i],in);
+ }
+ }
+ catch (serialization_error& e)
+ {
+ item.map_size = 0;
+ if (item.d)
+ {
+ delete [] item.d;
+ item.d = 0;
+ }
+ if (item.r)
+ {
+ delete [] item.r;
+ item.r = 0;
+ }
+
+ throw serialization_error(e.info + "\n while deserializing object of type static_map_kernel_1");
+ }
+ catch (...)
+ {
+ item.map_size = 0;
+ if (item.d)
+ {
+ delete [] item.d;
+ item.d = 0;
+ }
+ if (item.r)
+ {
+ delete [] item.r;
+ item.r = 0;
+ }
+
+ throw;
+ }
+ }
+
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+
+ static_map_kernel_1(
+ );
+
+ virtual ~static_map_kernel_1(
+ );
+
+ void clear (
+ );
+
+ void load (
+ pair_remover<domain,range>& source
+ );
+
+ void load (
+ asc_pair_remover<domain,range,compare>& source
+ );
+
+ inline const range* operator[] (
+ const domain& d
+ ) const;
+
+ inline range* operator[] (
+ const domain& d
+ );
+
+ inline void swap (
+ static_map_kernel_1& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const map_pair<domain,range>& element (
+ ) const;
+
+ inline map_pair<domain,range>& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+
+ private:
+
+ bool binary_search (
+ const domain& item,
+ unsigned long& pos
+ ) const;
+ /*!
+ ensures
+ - if (there is an item in d equivalent to item) then
+ - returns true
+ - d[#pos] is equivalent item
+ - else
+ - returns false
+ !*/
+
+ void sort_arrays (
+ unsigned long left,
+ unsigned long right
+ );
+ /*!
+ requires
+ - left and right are within the bounts of the array
+ ensures
+ - everything in the convention is still true and d[left] though
+ d[right] is sorted according to operator<
+ !*/
+
+ void qsort_partition (
+ unsigned long& partition_element,
+ const unsigned long left,
+ const unsigned long right
+ );
+ /*!
+ requires
+ - left < right
+ - left and right are within the bounts of the array
+ ensures
+ - the convention is still true
+ - left <= #partition_element <= right
+ - all elements in #d < #d[#partition_element] have
+ indices >= left and < #partition_element
+ - all elements in #d >= #d[#partition_element] have
+ indices >= #partition_element and <= right
+ !*/
+
+ unsigned long median (
+ unsigned long one,
+ unsigned long two,
+ unsigned long three
+ );
+ /*!
+ requires
+ - one, two, and three are valid indexes into d
+ ensures
+ - returns the median of d[one], d[two], and d[three]
+ !*/
+
+
+
+
+ // data members
+ unsigned long map_size;
+ domain* d;
+ range* r;
+ mutable mpair mp;
+ mutable bool at_start_;
+ compare comp;
+
+ // restricted functions
+ static_map_kernel_1(static_map_kernel_1&); // copy constructor
+ static_map_kernel_1& operator=(static_map_kernel_1&); // assignment operator
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ inline void swap (
+ static_map_kernel_1<domain,range,compare>& a,
+ static_map_kernel_1<domain,range,compare>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ static_map_kernel_1<domain,range,compare>::
+ static_map_kernel_1(
+ ) :
+ map_size(0),
+ d(0),
+ r(0),
+ at_start_(true)
+ {
+ mp.d = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ static_map_kernel_1<domain,range,compare>::
+ ~static_map_kernel_1(
+ )
+ {
+ if (map_size > 0)
+ {
+ delete [] d;
+ delete [] r;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ clear (
+ )
+ {
+ if (map_size > 0)
+ {
+ map_size = 0;
+ delete [] d;
+ delete [] r;
+ d = 0;
+ r = 0;
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ load (
+ pair_remover<domain,range>& source
+ )
+ {
+ if (source.size() > 0)
+ {
+ domain* old_d = d;
+ d = new domain[source.size()];
+ try { r = new range[source.size()]; }
+ catch (...) { delete [] d; d = old_d; throw; }
+
+ map_size = source.size();
+
+ for (unsigned long i = 0; source.size() > 0; ++i)
+ source.remove_any(d[i],r[i]);
+
+ sort_arrays(0,map_size-1);
+ }
+ else
+ {
+ clear();
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ load (
+ asc_pair_remover<domain,range,compare>& source
+ )
+ {
+ if (source.size() > 0)
+ {
+ domain* old_d = d;
+ d = new domain[source.size()];
+ try { r = new range[source.size()]; }
+ catch (...) { delete [] d; d = old_d; throw; }
+
+ map_size = source.size();
+
+ for (unsigned long i = 0; source.size() > 0; ++i)
+ source.remove_any(d[i],r[i]);
+ }
+ else
+ {
+ clear();
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ const range* static_map_kernel_1<domain,range,compare>::
+ operator[] (
+ const domain& d_item
+ ) const
+ {
+ unsigned long pos;
+ if (binary_search(d_item,pos))
+ return r+pos;
+ else
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ range* static_map_kernel_1<domain,range,compare>::
+ operator[] (
+ const domain& d_item
+ )
+ {
+ unsigned long pos;
+ if (binary_search(d_item,pos))
+ return r+pos;
+ else
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ size_t static_map_kernel_1<domain,range,compare>::
+ size (
+ ) const
+ {
+ return map_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ swap (
+ static_map_kernel_1<domain,range,compare>& item
+ )
+ {
+ exchange(map_size,item.map_size);
+ exchange(d,item.d);
+ exchange(r,item.r);
+ exchange(mp,item.mp);
+ exchange(at_start_,item.at_start_);
+ exchange(comp,item.comp);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ bool static_map_kernel_1<domain,range,compare>::
+ at_start (
+ ) const
+ {
+ return (at_start_);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ reset (
+ ) const
+ {
+ mp.d = 0;
+ at_start_ = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ bool static_map_kernel_1<domain,range,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (mp.d != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ const map_pair<domain,range>& static_map_kernel_1<domain,range,compare>::
+ element (
+ ) const
+ {
+ return mp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ map_pair<domain,range>& static_map_kernel_1<domain,range,compare>::
+ element (
+ )
+ {
+ return mp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ bool static_map_kernel_1<domain,range,compare>::
+ move_next (
+ ) const
+ {
+ // if at_start() && size() > 0
+ if (at_start_ && map_size > 0)
+ {
+ at_start_ = false;
+ mp.r = r;
+ mp.d = d;
+ return true;
+ }
+ // else if current_element_valid()
+ else if (mp.d != 0)
+ {
+ ++mp.d;
+ ++mp.r;
+ if (static_cast<unsigned long>(mp.d - d) < map_size)
+ {
+ return true;
+ }
+ else
+ {
+ mp.d = 0;
+ return false;
+ }
+ }
+ else
+ {
+ at_start_ = false;
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ bool static_map_kernel_1<domain,range,compare>::
+ binary_search (
+ const domain& item,
+ unsigned long& pos
+ ) const
+ {
+ unsigned long high = map_size;
+ unsigned long low = 0;
+ unsigned long p = map_size;
+ unsigned long idx;
+ while (p > 0)
+ {
+ p = (high-low)>>1;
+ idx = p+low;
+ if (comp(item , d[idx]))
+ {
+ high = idx;
+ }
+ else if (comp(d[idx] , item))
+ {
+ low = idx;
+ }
+ else
+ {
+ pos = idx;
+ return true;
+ }
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ sort_arrays (
+ unsigned long left,
+ unsigned long right
+ )
+ {
+ if ( left < right)
+ {
+ unsigned long partition_element;
+ qsort_partition(partition_element,left,right);
+
+ if (partition_element > 0)
+ sort_arrays(left,partition_element-1);
+ sort_arrays(partition_element+1,right);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void static_map_kernel_1<domain,range,compare>::
+ qsort_partition (
+ unsigned long& partition_element,
+ const unsigned long left,
+ const unsigned long right
+ )
+ {
+ partition_element = right;
+
+ unsigned long med = median(partition_element,left,((right-left)>>1) +left);
+ exchange(d[partition_element],d[med]);
+ exchange(r[partition_element],r[med]);
+
+ unsigned long right_scan = right-1;
+ unsigned long left_scan = left;
+
+ while (true)
+ {
+ // find an element to the left of partition_element that needs to be moved
+ while ( comp( d[left_scan] , d[partition_element]) )
+ {
+ ++left_scan;
+ }
+
+ // find an element to the right of partition_element that needs to be moved
+ while (
+ !(comp (d[right_scan] , d[partition_element])) &&
+ (right_scan > left_scan)
+ )
+ {
+ --right_scan;
+ }
+ if (left_scan >= right_scan)
+ break;
+
+ exchange(d[left_scan],d[right_scan]);
+ exchange(r[left_scan],r[right_scan]);
+
+ }
+ exchange(d[left_scan],d[partition_element]);
+ exchange(r[left_scan],r[partition_element]);
+ partition_element = left_scan;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ unsigned long static_map_kernel_1<domain,range,compare>::
+ median (
+ unsigned long one,
+ unsigned long two,
+ unsigned long three
+ )
+ {
+ if ( comp( d[one] , d[two]) )
+ {
+ // one < two
+ if ( comp( d[two] , d[three]) )
+ {
+ // one < two < three : two
+ return two;
+ }
+ else
+ {
+ // one < two >= three
+ if (comp( d[one] , d[three]))
+ {
+ // three
+ return three;
+ }
+ }
+
+ }
+ else
+ {
+ // one >= two
+ if ( comp(d[three] , d[one] ))
+ {
+ // three <= one >= two
+ if ( comp(d[three] , d[two]) )
+ {
+ // two
+ return two;
+ }
+ else
+ {
+ // three
+ return three;
+ }
+ }
+ }
+ return one;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATIC_MAP_KERNEl_1_
+
diff --git a/ml/dlib/dlib/static_map/static_map_kernel_abstract.h b/ml/dlib/dlib/static_map/static_map_kernel_abstract.h
new file mode 100644
index 000000000..0821367b5
--- /dev/null
+++ b/ml/dlib/dlib/static_map/static_map_kernel_abstract.h
@@ -0,0 +1,181 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STATIC_MAP_KERNEl_ABSTRACT_
+#ifdef DLIB_STATIC_MAP_KERNEl_ABSTRACT_
+
+#include "../interfaces/map_pair.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename domain,
+ typename range,
+ typename compare = std::less<domain>
+ >
+ class static_map : public enumerable<map_pair<domain,range> >
+ {
+
+ /*!
+ REQUIREMENTS ON domain
+ domain must be comparable by compare where compare is a functor compatible with std::less and
+ domain is swappable by a global swap() and
+ domain must have a default constructor
+
+ REQUIREMENTS ON range
+ range is swappable by a global swap() and
+ range must have a default constructor
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ Only the destructor and load_from() will invalidate pointers or
+ references to internal data.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the domain (and each associated
+ range element) elements in ascending order according to the compare functor.
+ (i.e. the elements are enumerated in sorted order)
+
+ WHAT THIS OBJECT REPRESENTS
+ static_map contains items of type domain and range
+
+ This object is similar an array. It maps items of type domain on to
+ items of type range.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ NOTE
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef domain domain_type;
+ typedef range range_type;
+ typedef compare compare_type;
+
+ static_map (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ !*/
+
+ virtual ~static_map(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ If this exception is thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ void load (
+ pair_remover<domain,range>& source
+ );
+ /*!
+ ensures
+ - #size() == source.size()
+ - #source.size() == 0
+ - all the pairs in source are removed and placed into #*this
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by domain's or range's
+ constructor.
+ If this exception is thrown then the call to load() will have
+ no effect on #*this.
+ !*/
+
+ const range* operator[] (
+ const domain& d
+ ) const;
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ range* operator[] (
+ const domain& d
+ );
+ /*!
+ ensures
+ - if (there is an element in the domain equivalent to d) then
+ - returns a pointer to an element in the range of *this that
+ is associated with an element in the domain of *this
+ equivalent to d.
+ - else
+ - returns 0
+ !*/
+
+ void swap (
+ static_map& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ static_map(static_map&); // copy constructor
+ static_map& operator=(static_map&); // assignment operator
+ };
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ inline void swap (
+ static_map<domain,range,compare>& a,
+ static_map<domain,range,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename domain,
+ typename range,
+ typename compare
+ >
+ void deserialize (
+ static_map<domain,range,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+}
+
+#endif // DLIB_STATIC_MAP_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/static_map/static_map_kernel_c.h b/ml/dlib/dlib/static_map/static_map_kernel_c.h
new file mode 100644
index 000000000..576d79374
--- /dev/null
+++ b/ml/dlib/dlib/static_map/static_map_kernel_c.h
@@ -0,0 +1,89 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_MAP_KERNEl_C_
+#define DLIB_STATIC_MAP_KERNEl_C_
+
+#include "static_map_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../interfaces/map_pair.h"
+#include "../interfaces/remover.h"
+
+namespace dlib
+{
+
+ template <
+ typename map_base
+ >
+ class static_map_kernel_c : public map_base
+ {
+ typedef typename map_base::domain_type domain;
+ typedef typename map_base::range_type range;
+
+ public:
+ const map_pair<domain,range>& element (
+ ) const;
+
+ map_pair<domain,range>& element (
+ );
+
+ };
+
+ template <
+ typename map_base
+ >
+ inline void swap (
+ static_map_kernel_c<map_base>& a,
+ static_map_kernel_c<map_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ const map_pair<typename map_base::domain_type,typename map_base::range_type>& static_map_kernel_c<map_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst map_pair<domain,range>& static_map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_base
+ >
+ map_pair<typename map_base::domain_type,typename map_base::range_type>& static_map_kernel_c<map_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tmap_pair<domain,range>& static_map::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return map_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATIC_MAP_KERNEl_C_
+
diff --git a/ml/dlib/dlib/static_set.h b/ml/dlib/dlib/static_set.h
new file mode 100644
index 000000000..47ecbafe4
--- /dev/null
+++ b/ml/dlib/dlib/static_set.h
@@ -0,0 +1,49 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_SEt_
+#define DLIB_STATIC_SEt_
+
+#include "static_set/static_set_kernel_1.h"
+#include "static_set/static_set_kernel_c.h"
+#include "static_set/static_set_compare_1.h"
+
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename compare = std::less<T>
+ >
+ class static_set
+ {
+ static_set() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef static_set_kernel_1<T,compare>
+ kernel_1a;
+ typedef static_set_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ //----------- extensions -------------
+
+ typedef static_set_compare_1<kernel_1a>
+ compare_1a;
+ typedef static_set_compare_1<kernel_1a_c>
+ compare_1a_c;
+
+
+
+
+ };
+}
+
+#endif // DLIB_STATIC_SEt_
+
diff --git a/ml/dlib/dlib/static_set/static_set_compare_1.h b/ml/dlib/dlib/static_set/static_set_compare_1.h
new file mode 100644
index 000000000..b5271e1d4
--- /dev/null
+++ b/ml/dlib/dlib/static_set/static_set_compare_1.h
@@ -0,0 +1,122 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_SET_COMPARe_1_
+#define DLIB_STATIC_SET_COMPARe_1_
+
+#include "static_set_compare_abstract.h"
+
+#include "../algs.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename static_set_base
+ >
+ class static_set_compare_1 : public static_set_base
+ {
+
+ public:
+
+ bool operator< (
+ const static_set_compare_1& rhs
+ ) const;
+
+ bool operator== (
+ const static_set_compare_1& rhs
+ ) const;
+
+ };
+
+
+ template <
+ typename static_set_base
+ >
+ inline void swap (
+ static_set_compare_1<static_set_base>& a,
+ static_set_compare_1<static_set_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename static_set_base
+ >
+ bool static_set_compare_1<static_set_base>::
+ operator< (
+ const static_set_compare_1<static_set_base>& rhs
+ ) const
+ {
+ bool result = false;
+ if (static_set_base::size() < rhs.size())
+ result = true;
+
+ if (static_set_base::size() == rhs.size())
+ {
+ rhs.reset();
+ static_set_base::reset();
+ while (rhs.move_next())
+ {
+ static_set_base::move_next();
+ if (static_set_base::element() < rhs.element())
+ {
+ result = true;
+ break;
+ }
+ else if (rhs.element() < static_set_base::element())
+ {
+ break;
+ }
+ }
+ }
+
+ static_set_base::reset();
+ rhs.reset();
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename static_set_base
+ >
+ bool static_set_compare_1<static_set_base>::
+ operator== (
+ const static_set_compare_1<static_set_base>& rhs
+ ) const
+ {
+ bool result = true;
+ if (static_set_base::size() != rhs.size())
+ result = false;
+
+
+ rhs.reset();
+ static_set_base::reset();
+ while (rhs.move_next() && static_set_base::move_next())
+ {
+ if (!(rhs.element() == static_set_base::element()))
+ {
+ result = false;
+ break;
+ }
+ }
+
+ static_set_base::reset();
+ rhs.reset();
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATIC_SET_COMPARe_1_
+
diff --git a/ml/dlib/dlib/static_set/static_set_compare_abstract.h b/ml/dlib/dlib/static_set/static_set_compare_abstract.h
new file mode 100644
index 000000000..356354e69
--- /dev/null
+++ b/ml/dlib/dlib/static_set/static_set_compare_abstract.h
@@ -0,0 +1,93 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STATIC_SET_COMPARe_ABSTRACT_
+#ifdef DLIB_STATIC_SET_COMPARe_ABSTRACT_
+
+#include "static_set_kernel_abstract.h"
+
+#include "../algs.h"
+
+
+namespace dlib
+{
+
+ template <
+ typename static_set_base
+ >
+ class static_set_compare : public static_set_base
+ {
+
+ /*!
+ REQUIREMENTS ON static_set_base
+ must an implementation of static_set/static_set_kernel_abstract.h
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ operator== and operator< invalidate pointers or references to
+ data members.
+
+ WHAT THIS EXTENSION DOES FOR static_set
+ This gives a static_set the ability to compare itself to other
+ static_sets using the < and == operators.
+
+ The < operator is conceptually weird for sets. It is useful
+ though because it allows you to make sets of sets since
+ sets require that their containing type implement operator<.
+
+ Also note that it is the case that for any two sets a and b
+ if (a<b) == false and (b<a) == false then a == b.
+
+
+ NOTATION
+ For the purposes of defining what these operators do I will
+ use the operator[] to reference the elements of the static_sets.
+ operator[] is defined to access the elements of the static_set in
+ the same order they would be enumerated by the enumerable
+ interface.
+ !*/
+
+ public:
+
+ bool operator< (
+ const static_set_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - #at_start() == true
+ - if (size() < rhs.size()) then
+ - returns true
+ - else if (size() > rhs.size()) then
+ - returns false
+ - else
+ - returns true if there exists an integer j such that 0 <= j < size()
+ and for all integers i such that 0 <= i < j where it is true that
+ (*this)[i] == rhs[i] and (*this)[j] < rhs[j]
+ - returns false if there is no j that will satisfy the above conditions.
+ !*/
+
+ bool operator== (
+ const static_set_compare& rhs
+ ) const;
+ /*!
+ ensures
+ - #at_start() == true
+ - returns true if *this and rhs contain the same elements.
+ returns false otherwise.
+ !*/
+ };
+
+
+ template <
+ typename static_set_base
+ >
+ inline void swap (
+ static_set_compare<static_set_base>& a,
+ static_set_compare<static_set_base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_STATIC_SET_COMPARe_ABSTRACT_
+
diff --git a/ml/dlib/dlib/static_set/static_set_kernel_1.h b/ml/dlib/dlib/static_set/static_set_kernel_1.h
new file mode 100644
index 000000000..7a1f166fc
--- /dev/null
+++ b/ml/dlib/dlib/static_set/static_set_kernel_1.h
@@ -0,0 +1,446 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_SET_KERNEl_1_
+#define DLIB_STATIC_SET_KERNEl_1_
+
+#include "static_set_kernel_abstract.h"
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../algs.h"
+#include "../sort.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename compare = std::less<T>
+ >
+ class static_set_kernel_1 : public enumerable<const T>
+ {
+
+ /*!
+ INITIAL VALUE
+ - set_size == 0
+ - d == 0
+ - at_start_ == true
+ - cur == 0
+
+ CONVENTION
+ - size() == set_size
+ - if (set_size > 0) then
+ - d == pointer to an array containing all the elements of the set
+ - d is sorted according to operator<
+ - else
+ - d == 0
+
+ - current_element_valid() == (cur != 0)
+ - at_start() == (at_start_)
+ - if (current_element_valid()) then
+ - element() == *cur
+ !*/
+
+ // I would define this outside the class but Borland 5.5 has some problems
+ // with non-inline templated friend functions.
+ friend void deserialize (
+ static_set_kernel_1& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.clear();
+ size_t size;
+ deserialize(size,in);
+ item.set_size = size;
+ item.d = new T[size];
+ for (size_t i = 0; i < size; ++i)
+ {
+ deserialize(item.d[i],in);
+ }
+ }
+ catch (serialization_error e)
+ {
+ item.set_size = 0;
+ if (item.d)
+ {
+ delete [] item.d;
+ item.d = 0;
+ }
+
+ throw serialization_error(e.info + "\n while deserializing object of type static_set_kernel_1");
+ }
+ catch (...)
+ {
+ item.set_size = 0;
+ if (item.d)
+ {
+ delete [] item.d;
+ item.d = 0;
+ }
+
+ throw;
+ }
+ }
+
+ public:
+
+ typedef T type;
+ typedef compare compare_type;
+
+ static_set_kernel_1(
+ );
+
+ virtual ~static_set_kernel_1(
+ );
+
+ void clear (
+ );
+
+ void load (
+ remover<T>& source
+ );
+
+ void load (
+ asc_remover<T,compare>& source
+ );
+
+ bool is_member (
+ const T& item
+ ) const;
+
+ inline void swap (
+ static_set_kernel_1& item
+ );
+
+ // functions from the enumerable interface
+ inline size_t size (
+ ) const;
+
+ inline bool at_start (
+ ) const;
+
+ inline void reset (
+ ) const;
+
+ inline bool current_element_valid (
+ ) const;
+
+ inline const T& element (
+ ) const;
+
+ inline const T& element (
+ );
+
+ inline bool move_next (
+ ) const;
+
+
+ private:
+
+
+ // data members
+ size_t set_size;
+ T* d;
+ mutable T* cur;
+ mutable bool at_start_;
+
+ // restricted functions
+ static_set_kernel_1(static_set_kernel_1&); // copy constructor
+ static_set_kernel_1& operator=(static_set_kernel_1&); // assignment operator
+ };
+
+ template <
+ typename T,
+ typename compare
+ >
+ inline void swap (
+ static_set_kernel_1<T,compare>& a,
+ static_set_kernel_1<T,compare>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ static_set_kernel_1<T,compare>::
+ static_set_kernel_1(
+ ) :
+ set_size(0),
+ d(0),
+ cur(0),
+ at_start_(true)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ static_set_kernel_1<T,compare>::
+ ~static_set_kernel_1(
+ )
+ {
+ if (set_size > 0)
+ delete [] d;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void static_set_kernel_1<T,compare>::
+ clear(
+ )
+ {
+ if (set_size > 0)
+ {
+ set_size = 0;
+ delete [] d;
+ d = 0;
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void static_set_kernel_1<T,compare>::
+ load (
+ remover<T>& source
+ )
+ {
+ if (source.size() > 0)
+ {
+ d = new T[source.size()];
+
+ set_size = source.size();
+
+ for (size_t i = 0; source.size() > 0; ++i)
+ source.remove_any(d[i]);
+
+ compare comp;
+ qsort_array(d,0,set_size-1,comp);
+ }
+ else
+ {
+ clear();
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void static_set_kernel_1<T,compare>::
+ load (
+ asc_remover<T,compare>& source
+ )
+ {
+ if (source.size() > 0)
+ {
+ d = new T[source.size()];
+
+ set_size = source.size();
+
+ for (size_t i = 0; source.size() > 0; ++i)
+ source.remove_any(d[i]);
+ }
+ else
+ {
+ clear();
+ }
+ reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ bool static_set_kernel_1<T,compare>::
+ is_member (
+ const T& item
+ ) const
+ {
+ size_t high = set_size;
+ size_t low = 0;
+ size_t p = set_size;
+ size_t idx;
+ while (p > 0)
+ {
+ p = (high-low)>>1;
+ idx = p+low;
+ if (item < d[idx])
+ high = idx;
+ else if (d[idx] < item)
+ low = idx;
+ else
+ return true;
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ size_t static_set_kernel_1<T,compare>::
+ size (
+ ) const
+ {
+ return set_size;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void static_set_kernel_1<T,compare>::
+ swap (
+ static_set_kernel_1<T,compare>& item
+ )
+ {
+ exchange(set_size,item.set_size);
+ exchange(d,item.d);
+ exchange(cur,item.cur);
+ exchange(at_start_,item.at_start_);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // enumerable function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ bool static_set_kernel_1<T,compare>::
+ at_start (
+ ) const
+ {
+ return at_start_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ void static_set_kernel_1<T,compare>::
+ reset (
+ ) const
+ {
+ at_start_ = true;
+ cur = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ bool static_set_kernel_1<T,compare>::
+ current_element_valid (
+ ) const
+ {
+ return (cur != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ const T& static_set_kernel_1<T,compare>::
+ element (
+ ) const
+ {
+ return *cur;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ const T& static_set_kernel_1<T,compare>::
+ element (
+ )
+ {
+ return *cur;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename compare
+ >
+ bool static_set_kernel_1<T,compare>::
+ move_next (
+ ) const
+ {
+ // if at_start() && size() > 0
+ if (at_start_ && set_size > 0)
+ {
+ at_start_ = false;
+ cur = d;
+ return true;
+ }
+ // else if current_element_valid()
+ else if (cur != 0)
+ {
+ ++cur;
+ if (static_cast<size_t>(cur - d) < set_size)
+ {
+ return true;
+ }
+ else
+ {
+ cur = 0;
+ return false;
+ }
+ }
+ else
+ {
+ at_start_ = false;
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATIC_SET_KERNEl_1_
+
diff --git a/ml/dlib/dlib/static_set/static_set_kernel_abstract.h b/ml/dlib/dlib/static_set/static_set_kernel_abstract.h
new file mode 100644
index 000000000..a2efd1aaa
--- /dev/null
+++ b/ml/dlib/dlib/static_set/static_set_kernel_abstract.h
@@ -0,0 +1,154 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STATIC_SET_KERNEl_ABSTRACT_
+#ifdef DLIB_STATIC_SET_KERNEl_ABSTRACT_
+
+#include "../interfaces/enumerable.h"
+#include "../interfaces/remover.h"
+#include "../serialize.h"
+#include <functional>
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename compare = std::less<T>
+ >
+ class static_set : public enumerable<const T>
+ {
+
+ /*!
+ REQUIREMENTS ON T
+ T must be comparable by compare where compare is a functor compatible with std::less and
+ T is swappable by a global swap() and
+ T must have a default constructor
+
+ POINTERS AND REFERENCES TO INTERNAL DATA
+ Only the destructor will invalidate pointers or references
+ to internal data.
+
+ INITIAL VALUE
+ size() == 0
+
+ ENUMERATION ORDER
+ The enumerator will iterate over the elements in the set in
+ ascending order according to the compare functor.
+ (i.e. the elements are enumerated in sorted order)
+
+ WHAT THIS OBJECT REPRESENTS
+ static_set contains items of type T
+
+ This object represents an unaddressed collection of items.
+
+ Also note that unless specified otherwise, no member functions
+ of this object throw exceptions.
+
+ NOTE
+ definition of equivalent:
+ a is equivalent to b if
+ a < b == false and
+ b < a == false
+ !*/
+
+ public:
+
+ typedef T type;
+ typedef compare compare_type;
+
+ static_set (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ !*/
+
+ virtual ~static_set(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ void load (
+ remover<T>& source
+ );
+ /*!
+ ensures
+ - #size() == source.size()
+ - #source.size() == 0
+ - all the elements in source are removed and placed into #*this
+ - #at_start() == true
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor.
+ If this exception is thrown then the call to load() will have
+ no effect on #*this.
+ !*/
+
+ bool is_member (
+ const T& item
+ ) const;
+ /*!
+ ensures
+ - if (there is an item in *this equivalent to item) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void swap (
+ static_set& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ static_set(static_set&); // copy constructor
+ static_set& operator=(static_set&); // assignment operator
+ };
+
+ template <
+ typename T,
+ typename compare
+ >
+ inline void swap (
+ static_set<T,compare>& a,
+ static_set<T,compare>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename T,
+ typename compare
+ >
+ void deserialize (
+ static_set<T,compare>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+}
+
+#endif // DLIB_STATIC_SET_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/static_set/static_set_kernel_c.h b/ml/dlib/dlib/static_set/static_set_kernel_c.h
new file mode 100644
index 000000000..1280c9c89
--- /dev/null
+++ b/ml/dlib/dlib/static_set/static_set_kernel_c.h
@@ -0,0 +1,88 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATIC_SET_KERNEl_C_
+#define DLIB_STATIC_SET_KERNEl_C_
+
+#include "static_set_kernel_abstract.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../interfaces/remover.h"
+
+namespace dlib
+{
+
+ template <
+ typename set_base
+ >
+ class static_set_kernel_c : public set_base
+ {
+ typedef typename set_base::type T;
+ public:
+
+ const T& element (
+ );
+
+ const T& element (
+ ) const;
+ };
+
+
+ template <
+ typename set_base
+ >
+ inline void swap (
+ static_set_kernel_c<set_base>& a,
+ static_set_kernel_c<set_base>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ const typename set_base::type& static_set_kernel_c<set_base>::
+ element (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& static_set::element() const"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename set_base
+ >
+ const typename set_base::type& static_set_kernel_c<set_base>::
+ element (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(this->current_element_valid() == true,
+ "\tconst T& static_set::element"
+ << "\n\tyou can't access the current element if it doesn't exist"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return set_base::element();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_STATIC_SET_KERNEl_C_
+
diff --git a/ml/dlib/dlib/statistics.h b/ml/dlib/dlib/statistics.h
new file mode 100644
index 000000000..45785c635
--- /dev/null
+++ b/ml/dlib/dlib/statistics.h
@@ -0,0 +1,19 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATISTICs_H_
+#define DLIB_STATISTICs_H_
+
+#include "statistics/statistics.h"
+#include "statistics/dpca.h"
+#include "statistics/random_subset_selector.h"
+#include "statistics/image_feature_sampling.h"
+#include "statistics/sammon.h"
+#include "statistics/cca.h"
+#include "statistics/average_precision.h"
+#include "statistics/vector_normalizer_frobmetric.h"
+#include "statistics/lda.h"
+
+#endif // DLIB_STATISTICs_H_
+
+
+
diff --git a/ml/dlib/dlib/statistics/average_precision.h b/ml/dlib/dlib/statistics/average_precision.h
new file mode 100644
index 000000000..6c2e7e0b1
--- /dev/null
+++ b/ml/dlib/dlib/statistics/average_precision.h
@@ -0,0 +1,66 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AVERAGE_PREcISION_Hh_
+#define DLIB_AVERAGE_PREcISION_Hh_
+
+#include "average_precision_abstract.h"
+#include <vector>
+
+
+namespace dlib
+{
+ namespace impl
+ {
+ inline bool get_bool_part (
+ const bool& b
+ ) { return b; }
+
+ template <typename T>
+ bool get_bool_part(const std::pair<T,bool>& item) { return item.second; }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ double average_precision (
+ const std::vector<T,alloc>& items,
+ unsigned long missing_relevant_items = 0
+ )
+ {
+ using namespace dlib::impl;
+ double relevant_count = 0;
+ // find the precision values
+ std::vector<double> precision;
+ for (unsigned long i = 0; i < items.size(); ++i)
+ {
+ if (get_bool_part(items[i]))
+ {
+ ++relevant_count;
+ precision.push_back(relevant_count / (i+1));
+ }
+ }
+
+ double precision_sum = 0;
+ double max_val = 0;
+ // now sum over the interpolated precision values
+ for (std::vector<double>::reverse_iterator i = precision.rbegin(); i != precision.rend(); ++i)
+ {
+ max_val = std::max(max_val, *i);
+ precision_sum += max_val;
+ }
+
+
+ relevant_count += missing_relevant_items;
+
+ if (relevant_count != 0)
+ return precision_sum/relevant_count;
+ else
+ return 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AVERAGE_PREcISION_Hh_
+
diff --git a/ml/dlib/dlib/statistics/average_precision_abstract.h b/ml/dlib/dlib/statistics/average_precision_abstract.h
new file mode 100644
index 000000000..76c2c702a
--- /dev/null
+++ b/ml/dlib/dlib/statistics/average_precision_abstract.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_
+#ifdef DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_
+
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename alloc
+ >
+ double average_precision (
+ const std::vector<bool,alloc>& items,
+ unsigned long missing_relevant_items = 0
+ );
+ /*!
+ ensures
+ - Interprets items as a list of relevant and non-relevant items in a response
+ from an information retrieval system. In particular, items with a true value
+ are relevant and false items are non-relevant. This function then returns
+ the average precision of the ranking of the given items. For, example, the
+ ranking [true, true, true, true, false] would have an average precision of 1.
+ On the other hand, the ranking of [true false false true] would have an
+ average precision of 0.75 (because the first true has a precision of 1 and
+ the second true has a precision of 0.5, giving an average of 0.75).
+ - As a special case, if item contains no true elements then the average
+ precision is considered to be 1.
+ - Note that we use the interpolated precision. That is, the interpolated
+ precision at a recall value r is set to the maximum precision obtained at any
+ higher recall value. Or in other words, we interpolate the precision/recall
+ curve so that precision is monotonically decreasing. Therefore, the average
+ precision value returned by this function is the area under this interpolated
+ precision/recall curve.
+ - This function will add in missing_relevant_items number of items with a
+ precision of zero into the average value returned. For example, the average
+ precision of the ranking [true, true] if there are 2 missing relevant items
+ is 0.5.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double average_precision (
+ const std::vector<std::pair<T,bool>,alloc>& items,
+ unsigned long missing_relevant_items = 0
+ );
+ /*!
+ ensures
+ - this function is equivalent to copying the bool values from items into a
+ std::vector<bool> and then calling the above average_precision() routine on
+ it and returning the result.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/cca.h b/ml/dlib/dlib/statistics/cca.h
new file mode 100644
index 000000000..21300ea8f
--- /dev/null
+++ b/ml/dlib/dlib/statistics/cca.h
@@ -0,0 +1,186 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CCA_hh_
+#define DLIB_CCA_hh_
+
+#include "cca_abstract.h"
+#include "../algs.h"
+#include "../matrix.h"
+#include "../sparse_vector.h"
+#include "random_subset_selector.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<typename T::type,0,1> compute_correlations (
+ const matrix_exp<T>& L,
+ const matrix_exp<T>& R
+ )
+ {
+ DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
+ "\t matrix compute_correlations()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t L.size(): " << L.size()
+ << "\n\t R.size(): " << R.size()
+ << "\n\t L.nr(): " << L.nr()
+ << "\n\t R.nr(): " << R.nr()
+ );
+
+ typedef typename T::type type;
+ matrix<type> A, B, C;
+ A = diag(trans(R)*L);
+ B = sqrt(diag(trans(L)*L));
+ C = sqrt(diag(trans(R)*R));
+ A = pointwise_multiply(A , reciprocal(pointwise_multiply(B,C)));
+ return A;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename T
+ >
+ matrix<T,0,1> impl_cca (
+ const matrix_type& L,
+ const matrix_type& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank,
+ unsigned long q,
+ unsigned long num_output_correlations,
+ double regularization
+ )
+ {
+ matrix<T> Ul, Vl;
+ matrix<T> Ur, Vr;
+ matrix<T> U, V;
+ matrix<T,0,1> Dr, Dl, D;
+
+
+ // Note that we add a few more singular vectors in because it helps improve the
+ // final results if we run this part with a little higher rank than the final SVD.
+ svd_fast(L, Ul, Dl, Vl, num_correlations+extra_rank, q);
+ svd_fast(R, Ur, Dr, Vr, num_correlations+extra_rank, q);
+
+ // Zero out singular values that are essentially zero so they don't cause numerical
+ // difficulties in the code below.
+ const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100;
+ Dl = round_zeros(Dl+regularization,eps);
+ Dr = round_zeros(Dr+regularization,eps);
+
+ // This matrix is really small so we can do a normal full SVD on it. Note that we
+ // also throw away the columns of Ul and Ur corresponding to zero singular values.
+ svd3(diagm(Dl>0)*tmp(trans(Ul)*Ur)*diagm(Dr>0), U, D, V);
+
+ // now throw away extra columns of the transformations. We do this in a way
+ // that keeps the directions that have the highest correlations.
+ matrix<T,0,1> temp = D;
+ rsort_columns(U, temp);
+ rsort_columns(V, D);
+ U = colm(U, range(0, num_output_correlations-1));
+ V = colm(V, range(0, num_output_correlations-1));
+ D = rowm(D, range(0, num_output_correlations-1));
+
+ Ltrans = Vl*inv(diagm(Dl))*U;
+ Rtrans = Vr*inv(diagm(Dr))*V;
+
+ // Note that the D matrix contains the correlation values for the transformed
+ // vectors. However, when the L and R matrices have rank higher than
+ // num_correlations+extra_rank then the values in D become only approximate.
+ return D;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ matrix<T,0,1> cca (
+ const matrix<T>& L,
+ const matrix<T>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2,
+ double regularization = 0
+ )
+ {
+ DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() &&
+ regularization >= 0,
+ "\t matrix cca()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t num_correlations: " << num_correlations
+ << "\n\t regularization: " << regularization
+ << "\n\t L.size(): " << L.size()
+ << "\n\t R.size(): " << R.size()
+ << "\n\t L.nr(): " << L.nr()
+ << "\n\t R.nr(): " << R.nr()
+ );
+
+ using std::min;
+ const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc())));
+ return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sparse_vector_type, typename T>
+ matrix<T,0,1> cca (
+ const std::vector<sparse_vector_type>& L,
+ const std::vector<sparse_vector_type>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2,
+ double regularization = 0
+ )
+ {
+ DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() &&
+ max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 &&
+ regularization >= 0,
+ "\t matrix cca()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t num_correlations: " << num_correlations
+ << "\n\t regularization: " << regularization
+ << "\n\t L.size(): " << L.size()
+ << "\n\t R.size(): " << R.size()
+ << "\n\t max_index_plus_one(L): " << max_index_plus_one(L)
+ << "\n\t max_index_plus_one(R): " << max_index_plus_one(R)
+ );
+
+ using std::min;
+ const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R));
+ const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n));
+ return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sparse_vector_type, typename Rand_type, typename T>
+ matrix<T,0,1> cca (
+ const random_subset_selector<sparse_vector_type,Rand_type>& L,
+ const random_subset_selector<sparse_vector_type,Rand_type>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2
+ )
+ {
+ return cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CCA_hh_
+
+
diff --git a/ml/dlib/dlib/statistics/cca_abstract.h b/ml/dlib/dlib/statistics/cca_abstract.h
new file mode 100644
index 000000000..8e0b4e742
--- /dev/null
+++ b/ml/dlib/dlib/statistics/cca_abstract.h
@@ -0,0 +1,191 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CCA_AbSTRACT_Hh_
+#ifdef DLIB_CCA_AbSTRACT_Hh_
+
+#include "../matrix/matrix_la_abstract.h"
+#include "random_subset_selector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<typename T::type,0,1> compute_correlations (
+ const matrix_exp<T>& L,
+ const matrix_exp<T>& R
+ );
+ /*!
+ requires
+ - L.size() > 0
+ - R.size() > 0
+ - L.nr() == R.nr()
+ ensures
+ - This function treats L and R as sequences of paired row vectors. It
+ then computes the correlation values between the elements of these
+ row vectors. In particular, we return a vector COR such that:
+ - COR.size() == L.nc()
+ - for all valid i:
+ - COR(i) == the correlation coefficient between the following sequence
+ of paired numbers: (L(k,i), R(k,i)) for k: 0 <= k < L.nr().
+ Therefore, COR(i) is a value between -1 and 1 inclusive where 1
+ indicates perfect correlation and -1 perfect anti-correlation. Note
+ that this function assumes the input data vectors have been centered
+ (i.e. made to have zero mean). If this is not the case then it will
+ report inaccurate results.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<T,0,1> cca (
+ const matrix<T>& L,
+ const matrix<T>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2,
+ double regularization = 0
+ );
+ /*!
+ requires
+ - num_correlations > 0
+ - L.size() > 0
+ - R.size() > 0
+ - L.nr() == R.nr()
+ - regularization >= 0
+ ensures
+ - This function performs a canonical correlation analysis between the row
+ vectors in L and R. That is, it finds two transformation matrices, Ltrans
+ and Rtrans, such that row vectors in the transformed matrices L*Ltrans and
+ R*Rtrans are as correlated as possible. That is, we try to find two transforms
+ such that the correlation values returned by compute_correlations(L*Ltrans, R*Rtrans)
+ would be maximized.
+ - Let N == min(num_correlations, min(R.nr(),min(L.nc(),R.nc())))
+ (This is the actual number of elements in the transformed vectors.
+ Therefore, note that you can't get more outputs than there are rows or
+ columns in the input matrices.)
+ - #Ltrans.nr() == L.nc()
+ - #Ltrans.nc() == N
+ - #Rtrans.nr() == R.nc()
+ - #Rtrans.nc() == N
+ - This function assumes the data vectors in L and R have already been centered
+ (i.e. we assume the vectors have zero means). However, in many cases it is
+ fine to use uncentered data with cca(). But if it is important for your
+ problem then you should center your data before passing it to cca().
+ - This function works with reduced rank approximations of the L and R matrices.
+ This makes it fast when working with large matrices. In particular, we use
+ the svd_fast() routine to find reduced rank representations of the input
+ matrices by calling it as follows: svd_fast(L, U,D,V, num_correlations+extra_rank, q)
+ and similarly for R. This means that you can use the extra_rank and q
+ arguments to cca() to influence the accuracy of the reduced rank
+ approximation. However, the default values should work fine for most
+ problems.
+ - returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The
+ returned vector should exactly match the output of compute_correlations()
+ when the reduced rank approximation to L and R is accurate and regularization
+ is set to 0. However, if this is not the case then the return value of this
+ function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This
+ deviation can be used to check if the reduced rank approximation is working
+ or you need to increase extra_rank.
+ - The dimensions of the output vectors produced by L*#Ltrans or R*#Rtrans are
+ ordered such that the dimensions with the highest correlations come first.
+ That is, after applying the transforms produced by cca() to a set of vectors
+ you will find that dimension 0 has the highest correlation, then dimension 1
+ has the next highest, and so on. This also means that the list of numbers
+ returned from cca() will always be listed in decreasing order.
+ - This function performs the ridge regression version of Canonical Correlation
+ Analysis when regularization is set to a value > 0. In particular, larger
+ values indicate the solution should be more heavily regularized. This can be
+ useful when the dimensionality of the data is larger than the number of
+ samples.
+ - A good discussion of CCA can be found in the paper "Canonical Correlation
+ Analysis" by David Weenink. In particular, this function is implemented
+ using equations 29 and 30 from his paper. We also use the idea of doing CCA
+ on a reduced rank approximation of L and R as suggested by Paramveer S.
+ Dhillon in his paper "Two Step CCA: A new spectral method for estimating
+ vector models of words".
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sparse_vector_type,
+ typename T
+ >
+ matrix<T,0,1> cca (
+ const std::vector<sparse_vector_type>& L,
+ const std::vector<sparse_vector_type>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2,
+ double regularization = 0
+ );
+ /*!
+ requires
+ - num_correlations > 0
+ - L.size() == R.size()
+ - max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0
+ (i.e. L and R can't represent empty matrices)
+ - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
+ for a definition of sparse vector)
+ - regularization >= 0
+ ensures
+ - This is just an overload of the cca() function defined above. Except in this
+ case we take a sparse representation of the input L and R matrices rather than
+ dense matrices. Therefore, in this case, we interpret L and R as matrices
+ with L.size() rows, where each row is defined by a sparse vector. So this
+ function does exactly the same thing as the above cca().
+ - Note that you can apply the output transforms to a sparse vector with the
+ following code:
+ sparse_matrix_vector_multiply(trans(Ltrans), your_sparse_vector)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sparse_vector_type,
+ typename Rand_type,
+ typename T
+ >
+ matrix<T,0,1> cca (
+ const random_subset_selector<sparse_vector_type,Rand_type>& L,
+ const random_subset_selector<sparse_vector_type,Rand_type>& R,
+ matrix<T>& Ltrans,
+ matrix<T>& Rtrans,
+ unsigned long num_correlations,
+ unsigned long extra_rank = 5,
+ unsigned long q = 2,
+ double regularization = 0
+ );
+ /*!
+ requires
+ - num_correlations > 0
+ - L.size() == R.size()
+ - max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0
+ (i.e. L and R can't represent empty matrices)
+ - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
+ for a definition of sparse vector)
+ - regularization >= 0
+ ensures
+ - returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q)
+ (i.e. this is just a convenience function for calling the cca() routine when
+ your sparse vectors are contained inside a random_subset_selector rather than
+ a std::vector)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CCA_AbSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/dpca.h b/ml/dlib/dlib/statistics/dpca.h
new file mode 100644
index 000000000..cae784682
--- /dev/null
+++ b/ml/dlib/dlib/statistics/dpca.h
@@ -0,0 +1,541 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DPCA_h_
+#define DLIB_DPCA_h_
+
+#include "dpca_abstract.h"
+#include <limits>
+#include <cmath>
+#include "../algs.h"
+#include "../matrix.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class discriminant_pca
+ {
+ /*!
+ INITIAL VALUE
+ - vect_size == 0
+ - total_count == 0
+ - between_count == 0
+ - within_count == 0
+ - between_weight == 1
+ - within_weight == 1
+
+ CONVENTION
+ - vect_size == in_vector_size()
+ - total_count == the number of times add_to_total_variance() has been called.
+ - within_count == the number of times add_to_within_class_variance() has been called.
+ - between_count == the number of times add_to_between_class_variance() has been called.
+ - between_weight == between_class_weight()
+ - within_weight == within_class_weight()
+
+ - if (total_count != 0)
+ - total_sum == the sum of all vectors given to add_to_total_variance()
+ - the covariance of all the elements given to add_to_total_variance() is given
+ by:
+ - let avg == total_sum/total_count
+ - covariance == total_cov/total_count - avg*trans(avg)
+ - if (within_count != 0)
+ - within_cov/within_count == the normalized within class scatter matrix
+ - if (between_count != 0)
+ - between_cov/between_count == the normalized between class scatter matrix
+ !*/
+
+ public:
+
+ struct discriminant_pca_error : public error
+ {
+ discriminant_pca_error(const std::string& message): error(message) {}
+ };
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ discriminant_pca (
+ )
+ {
+ clear();
+ }
+
+ void clear(
+ )
+ {
+ total_count = 0;
+ between_count = 0;
+ within_count = 0;
+
+ vect_size = 0;
+
+
+ between_weight = 1;
+ within_weight = 1;
+
+
+ total_sum.set_size(0);
+ between_cov.set_size(0,0);
+ total_cov.set_size(0,0);
+ within_cov.set_size(0,0);
+ }
+
+ long in_vector_size (
+ ) const
+ {
+ return vect_size;
+ }
+
+ void set_within_class_weight (
+ scalar_type weight
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weight >= 0,
+ "\t void discriminant_pca::set_within_class_weight()"
+ << "\n\t You can't use negative weight values"
+ << "\n\t weight: " << weight
+ << "\n\t this: " << this
+ );
+
+ within_weight = weight;
+ }
+
+ scalar_type within_class_weight (
+ ) const
+ {
+ return within_weight;
+ }
+
+ void set_between_class_weight (
+ scalar_type weight
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weight >= 0,
+ "\t void discriminant_pca::set_between_class_weight()"
+ << "\n\t You can't use negative weight values"
+ << "\n\t weight: " << weight
+ << "\n\t this: " << this
+ );
+
+ between_weight = weight;
+ }
+
+ scalar_type between_class_weight (
+ ) const
+ {
+ return between_weight;
+ }
+
+ template <typename EXP1, typename EXP2>
+ void add_to_within_class_variance(
+ const matrix_exp<EXP1>& x,
+ const matrix_exp<EXP2>& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && is_col_vector(y) &&
+ x.size() == y.size() &&
+ (in_vector_size() == 0 || x.size() == in_vector_size()),
+ "\t void discriminant_pca::add_to_within_class_variance()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ vect_size = x.size();
+ if (within_count == 0)
+ {
+ within_cov = (x-y)*trans(x-y);
+ }
+ else
+ {
+ within_cov += (x-y)*trans(x-y);
+ }
+ ++within_count;
+ }
+
+ template <typename EXP1, typename EXP2>
+ void add_to_between_class_variance(
+ const matrix_exp<EXP1>& x,
+ const matrix_exp<EXP2>& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && is_col_vector(y) &&
+ x.size() == y.size() &&
+ (in_vector_size() == 0 || x.size() == in_vector_size()),
+ "\t void discriminant_pca::add_to_between_class_variance()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ vect_size = x.size();
+ if (between_count == 0)
+ {
+ between_cov = (x-y)*trans(x-y);
+ }
+ else
+ {
+ between_cov += (x-y)*trans(x-y);
+ }
+ ++between_count;
+ }
+
+ template <typename EXP>
+ void add_to_total_variance(
+ const matrix_exp<EXP>& x
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && (in_vector_size() == 0 || x.size() == in_vector_size()),
+ "\t void discriminant_pca::add_to_total_variance()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t x.size(): " << x.size()
+ << "\n\t this: " << this
+ );
+
+ vect_size = x.size();
+ if (total_count == 0)
+ {
+ total_cov = x*trans(x);
+ total_sum = x;
+ }
+ else
+ {
+ total_cov += x*trans(x);
+ total_sum += x;
+ }
+ ++total_count;
+ }
+
+ const general_matrix dpca_matrix (
+ const double eps = 0.99
+ ) const
+ {
+ general_matrix dpca_mat;
+ general_matrix eigenvalues;
+ dpca_matrix(dpca_mat, eigenvalues, eps);
+ return dpca_mat;
+ }
+
+ const general_matrix dpca_matrix_of_size (
+ const long num_rows
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < num_rows && num_rows <= in_vector_size(),
+ "\t general_matrix discriminant_pca::dpca_matrix_of_size()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t num_rows: " << num_rows
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ general_matrix dpca_mat;
+ general_matrix eigenvalues;
+ dpca_matrix_of_size(dpca_mat, eigenvalues, num_rows);
+ return dpca_mat;
+ }
+
+ void dpca_matrix (
+ general_matrix& dpca_mat,
+ general_matrix& eigenvalues,
+ const double eps = 0.99
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < eps && eps <= 1 && in_vector_size() != 0,
+ "\t void discriminant_pca::dpca_matrix()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t eps: " << eps
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ compute_dpca_matrix(dpca_mat, eigenvalues, eps, 0);
+ }
+
+ void dpca_matrix_of_size (
+ general_matrix& dpca_mat,
+ general_matrix& eigenvalues,
+ const long num_rows
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < num_rows && num_rows <= in_vector_size(),
+ "\t general_matrix discriminant_pca::dpca_matrix_of_size()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t num_rows: " << num_rows
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ compute_dpca_matrix(dpca_mat, eigenvalues, 1, num_rows);
+ }
+
+ void swap (
+ discriminant_pca& item
+ )
+ {
+ using std::swap;
+ swap(total_cov, item.total_cov);
+ swap(total_sum, item.total_sum);
+ swap(total_count, item.total_count);
+ swap(vect_size, item.vect_size);
+ swap(between_cov, item.between_cov);
+
+ swap(between_count, item.between_count);
+ swap(between_weight, item.between_weight);
+ swap(within_cov, item.within_cov);
+ swap(within_count, item.within_count);
+ swap(within_weight, item.within_weight);
+ }
+
+ friend void deserialize (
+ discriminant_pca& item,
+ std::istream& in
+ )
+ {
+ deserialize( item.total_cov, in);
+ deserialize( item.total_sum, in);
+ deserialize( item.total_count, in);
+ deserialize( item.vect_size, in);
+ deserialize( item.between_cov, in);
+ deserialize( item.between_count, in);
+ deserialize( item.between_weight, in);
+ deserialize( item.within_cov, in);
+ deserialize( item.within_count, in);
+ deserialize( item.within_weight, in);
+ }
+
+ friend void serialize (
+ const discriminant_pca& item,
+ std::ostream& out
+ )
+ {
+ serialize( item.total_cov, out);
+ serialize( item.total_sum, out);
+ serialize( item.total_count, out);
+ serialize( item.vect_size, out);
+ serialize( item.between_cov, out);
+ serialize( item.between_count, out);
+ serialize( item.between_weight, out);
+ serialize( item.within_cov, out);
+ serialize( item.within_count, out);
+ serialize( item.within_weight, out);
+ }
+
+ discriminant_pca operator+ (
+ const discriminant_pca& item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT((in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()) &&
+ between_class_weight() == item.between_class_weight() &&
+ within_class_weight() == item.within_class_weight(),
+ "\t discriminant_pca discriminant_pca::operator+()"
+ << "\n\t The two discriminant_pca objects being added must have compatible parameters"
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t item.in_vector_size(): " << item.in_vector_size()
+ << "\n\t between_class_weight(): " << between_class_weight()
+ << "\n\t item.between_class_weight(): " << item.between_class_weight()
+ << "\n\t within_class_weight(): " << within_class_weight()
+ << "\n\t item.within_class_weight(): " << item.within_class_weight()
+ << "\n\t this: " << this
+ );
+
+ discriminant_pca temp(item);
+
+ // We need to make sure to ignore empty matrices. That's what these if statements
+ // are for.
+
+ if (total_count != 0 && temp.total_count != 0)
+ {
+ temp.total_cov += total_cov;
+ temp.total_sum += total_sum;
+ temp.total_count += total_count;
+ }
+ else if (total_count != 0)
+ {
+ temp.total_cov = total_cov;
+ temp.total_sum = total_sum;
+ temp.total_count = total_count;
+ }
+
+ if (between_count != 0 && temp.between_count != 0)
+ {
+ temp.between_cov += between_cov;
+ temp.between_count += between_count;
+ }
+ else if (between_count != 0)
+ {
+ temp.between_cov = between_cov;
+ temp.between_count = between_count;
+ }
+
+ if (within_count != 0 && temp.within_count != 0)
+ {
+ temp.within_cov += within_cov;
+ temp.within_count += within_count;
+ }
+ else if (within_count != 0)
+ {
+ temp.within_cov = within_cov;
+ temp.within_count = within_count;
+ }
+
+ return temp;
+ }
+
+ discriminant_pca& operator+= (
+ const discriminant_pca& rhs
+ )
+ {
+ (*this + rhs).swap(*this);
+ return *this;
+ }
+
+ private:
+
+ void compute_dpca_matrix (
+ general_matrix& dpca_mat,
+ general_matrix& eigenvalues,
+ const double eps,
+ long num_rows
+ ) const
+ {
+ general_matrix cov;
+
+ // now combine the three measures of variance into a single matrix by using the
+ // within_weight and between_weight weights.
+ cov = get_total_covariance_matrix();
+ if (within_count != 0)
+ cov -= within_weight*within_cov/within_count;
+ if (between_count != 0)
+ cov += between_weight*between_cov/between_count;
+
+
+ eigenvalue_decomposition<general_matrix> eig(make_symmetric(cov));
+
+ eigenvalues = eig.get_real_eigenvalues();
+ dpca_mat = eig.get_pseudo_v();
+
+ // sort the eigenvalues and eigenvectors so that the biggest eigenvalues come first
+ rsort_columns(dpca_mat, eigenvalues);
+
+ long num_vectors = 0;
+ if (num_rows == 0)
+ {
+ // Some of the eigenvalues might be negative. So first lets zero those out
+ // so they won't get considered.
+ eigenvalues = pointwise_multiply(eigenvalues > 0, eigenvalues);
+ // figure out how many eigenvectors we want in our dpca matrix
+ const double thresh = sum(eigenvalues)*eps;
+ double total = 0;
+ for (long r = 0; r < eigenvalues.size() && total < thresh; ++r)
+ {
+ // Don't even think about looking at eigenvalues that are 0. If we go this
+ // far then we have all we need.
+ if (eigenvalues(r) == 0)
+ break;
+
+ ++num_vectors;
+ total += eigenvalues(r);
+ }
+
+ if (num_vectors == 0)
+ throw discriminant_pca_error("While performing discriminant_pca, all eigenvalues were negative or 0");
+ }
+ else
+ {
+ num_vectors = num_rows;
+ }
+
+
+ // So now we know we want to use num_vectors of the first eigenvectors. So
+ // pull those out and discard the rest.
+ dpca_mat = trans(colm(dpca_mat,range(0,num_vectors-1)));
+
+ // also clip off the eigenvalues we aren't using
+ eigenvalues = rowm(eigenvalues, range(0,num_vectors-1));
+
+ }
+
+ general_matrix get_total_covariance_matrix (
+ ) const
+ /*!
+ ensures
+ - returns the covariance matrix of all the data given to the add_to_total_variance()
+ !*/
+ {
+ // if we don't even know the dimensionality of the vectors we are dealing
+ // with then just return an empty matrix
+ if (vect_size == 0)
+ return general_matrix();
+
+ // we know the vector size but we have zero total covariance.
+ if (total_count == 0)
+ {
+ general_matrix temp(vect_size,vect_size);
+ temp = 0;
+ return temp;
+ }
+
+ // In this case we actually have something to make a total covariance matrix out of.
+ // So do that.
+ column_matrix avg = total_sum/total_count;
+
+ return total_cov/total_count - avg*trans(avg);
+ }
+
+ general_matrix total_cov;
+ column_matrix total_sum;
+ scalar_type total_count;
+
+ long vect_size;
+
+ general_matrix between_cov;
+ scalar_type between_count;
+ scalar_type between_weight;
+
+ general_matrix within_cov;
+ scalar_type within_count;
+ scalar_type within_weight;
+ };
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ discriminant_pca<matrix_type>& a,
+ discriminant_pca<matrix_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DPCA_h_
+
+
diff --git a/ml/dlib/dlib/statistics/dpca_abstract.h b/ml/dlib/dlib/statistics/dpca_abstract.h
new file mode 100644
index 000000000..d9eef635b
--- /dev/null
+++ b/ml/dlib/dlib/statistics/dpca_abstract.h
@@ -0,0 +1,365 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_DPCA_ABSTRaCT_
+#ifdef DLIB_DPCA_ABSTRaCT_
+
+#include <limits>
+#include <cmath>
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class discriminant_pca
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ INITIAL VALUE
+ - in_vector_size() == 0
+ - between_class_weight() == 1
+ - within_class_weight() == 1
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements the Discriminant PCA technique described in the paper:
+ A New Discriminant Principal Component Analysis Method with Partial Supervision (2009)
+ by Dan Sun and Daoqiang Zhang
+
+ This algorithm is basically a straightforward generalization of the classical PCA
+ technique to handle partially labeled data. It is useful if you want to learn a linear
+ dimensionality reduction rule using a bunch of data that is partially labeled.
+
+ It functions by estimating three different scatter matrices. The first is the total scatter
+ matrix St (i.e. the total data covariance matrix), the second is the between class scatter
+ matrix Sb (basically a measure of the variance between data of different classes) and the
+ third is the within class scatter matrix Sw (a measure of the variance of data within the
+ same classes).
+
+ Once these three matrices are estimated they are combined according to the following equation:
+ S = St + a*Sb - b*Sw
+ Where a and b are user supplied weights. Then the largest eigenvalues of the S matrix are
+ computed and their associated eigenvectors are returned as the output of this algorithm.
+ That is, the desired linear dimensionality reduction is given by the matrix with these
+ eigenvectors stored in its rows.
+
+ Note that if a and b are set to 0 (or no labeled data is provided) then the output transformation
+ matrix is the same as the one produced by the classical PCA algorithm.
+ !*/
+
+ public:
+
+ struct discriminant_pca_error : public error;
+ /*!
+ This exception is thrown if there is some error that prevents us from creating
+ a DPCA matrix.
+ !*/
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ discriminant_pca (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ !*/
+
+ long in_vector_size (
+ ) const;
+ /*!
+ ensures
+ - if (this object has been presented with any input vectors) then
+ - returns the dimension of the column vectors used with this object
+ - else
+ - returns 0
+ !*/
+
+ void set_within_class_weight (
+ scalar_type weight
+ );
+ /*!
+ requires
+ - weight >= 0
+ ensures
+ - #within_class_weight() == weight
+ !*/
+
+ scalar_type within_class_weight (
+ ) const;
+ /*!
+ ensures
+ - returns the weight used when combining the within class scatter matrix with
+ the other scatter matrices.
+ !*/
+
+ void set_between_class_weight (
+ scalar_type weight
+ );
+ /*!
+ requires
+ - weight >= 0
+ ensures
+ - #between_class_weight() == weight
+ !*/
+
+ scalar_type between_class_weight (
+ ) const;
+ /*!
+ ensures
+ - returns the weight used when combining the between class scatter matrix with
+ the other scatter matrices.
+ !*/
+
+ void add_to_within_class_variance(
+ const matrix_exp& x,
+ const matrix_exp& y
+ );
+ /*!
+ requires
+ - is_col_vector(x) == true
+ - is_col_vector(y) == true
+ - x.size() == y.size()
+ - if (in_vector_size() != 0) then
+ - x.size() == y.size() == in_vector_size()
+ ensures
+ - #in_vector_size() == x.size()
+ - Adds (x-y)*trans(x-y) to the within class scatter matrix.
+ (i.e. the direction given by (x-y) is recorded as being a direction associated
+ with within class variance and is therefore unimportant and will be weighted
+ less in the final dimensionality reduction)
+ !*/
+
+ void add_to_between_class_variance(
+ const matrix_exp& x,
+ const matrix_exp& y
+ );
+ /*!
+ requires
+ - is_col_vector(x) == true
+ - is_col_vector(y) == true
+ - x.size() == y.size()
+ - if (in_vector_size() != 0) then
+ - x.size() == y.size() == in_vector_size()
+ ensures
+ - #in_vector_size() == x.size()
+ - Adds (x-y)*trans(x-y) to the between class scatter matrix.
+ (i.e. the direction given by (x-y) is recorded as being a direction associated
+ with between class variance and is therefore important and will be weighted
+ higher in the final dimensionality reduction)
+ !*/
+
+ void add_to_total_variance(
+ const matrix_exp& x
+ );
+ /*!
+ requires
+ - is_col_vector(x) == true
+ - if (in_vector_size() != 0) then
+ - x.size() == in_vector_size()
+ ensures
+ - #in_vector_size() == x.size()
+ - let M denote the centroid (or mean) of all the data. Then this function
+ Adds (x-M)*trans(x-M) to the total scatter matrix.
+ (i.e. the direction given by (x-M) is recorded as being a direction associated
+ with unlabeled variance and is therefore of default importance and will be weighted
+ as described in the discriminant_pca class description.)
+ !*/
+
+ const general_matrix dpca_matrix (
+ const double eps = 0.99
+ ) const;
+ /*!
+ requires
+ - 0 < eps <= 1
+ - in_vector_size() != 0
+ (i.e. you have to have given this object some data)
+ ensures
+ - computes and returns the matrix MAT given by dpca_matrix(MAT,eigen,eps).
+ That is, this function returns the dpca_matrix computed by the function
+ defined below.
+ - Note that MAT is the desired linear transformation matrix. That is,
+ multiplying a vector by MAT performs the desired linear dimensionality reduction.
+ throws
+ - discriminant_pca_error
+ This exception is thrown if we are unable to create the dpca_matrix for some
+ reason. For example, if only within class examples have been given or
+ within_class_weight() is very large then all eigenvalues will be negative and
+ that prevents this algorithm from working properly.
+ !*/
+
+ void dpca_matrix (
+ general_matrix& dpca_mat,
+ general_matrix& eigenvalues,
+ const double eps = 0.99
+ ) const;
+ /*!
+ requires
+ - 0 < eps <= 1
+ - in_vector_size() != 0
+ (i.e. you have to have given this object some data)
+ ensures
+ - is_col_vector(#eigenvalues) == true
+ - #dpca_mat.nr() == eigenvalues.size()
+ - #dpca_mat.nc() == in_vector_size()
+ - rowm(#dpca_mat,i) represents the ith eigenvector of the S matrix described
+ in the class description and its eigenvalue is given by eigenvalues(i).
+ - all values in #eigenvalues are > 0. Moreover, the eigenvalues are in
+ sorted order with the largest eigenvalue stored at eigenvalues(0).
+ - (#dpca_mat)*trans(#dpca_mat) == identity_matrix.
+ (i.e. the rows of the dpca_matrix are all unit length vectors and are mutually
+ orthogonal)
+ - Note that #dpca_mat is the desired linear transformation matrix. That is,
+ multiplying a vector by #dpca_mat performs the desired linear dimensionality
+ reduction.
+ - sum(#eigenvalues) will be equal to about eps times the total sum of all
+ positive eigenvalues in the S matrix described in this class's description.
+ This means that eps is a number that controls how "lossy" the dimensionality
+ reduction will be. Large values of eps result in more output dimensions
+ while smaller values result in fewer.
+ throws
+ - discriminant_pca_error
+ This exception is thrown if we are unable to create the dpca_matrix for some
+ reason. For example, if only within class examples have been given or
+ within_class_weight() is very large then all eigenvalues will be negative and
+ that prevents this algorithm from working properly.
+ !*/
+
+ const general_matrix dpca_matrix_of_size (
+ const long num_rows
+ );
+ /*!
+ requires
+ - 0 < num_rows <= in_vector_size()
+ ensures
+ - computes and returns the matrix MAT given by dpca_matrix_of_size(MAT,eigen,num_rows).
+ That is, this function returns the dpca_matrix computed by the function
+ defined below.
+ - Note that MAT is the desired linear transformation matrix. That is,
+ multiplying a vector by MAT performs the desired linear dimensionality
+ reduction to num_rows dimensions.
+ !*/
+
+ void dpca_matrix_of_size (
+ general_matrix& dpca_mat,
+ general_matrix& eigenvalues,
+ const long num_rows
+ );
+ /*!
+ requires
+ - 0 < num_rows <= in_vector_size()
+ ensures
+ - is_col_vector(#eigenvalues) == true
+ - #dpca_mat.nr() == eigenvalues.size()
+ - #dpca_mat.nr() == num_rows
+ - #dpca_mat.nc() == in_vector_size()
+ - rowm(#dpca_mat,i) represents the ith eigenvector of the S matrix described
+ in the class description and its eigenvalue is given by eigenvalues(i).
+ - The values in #eigenvalues might be positive or negative. Additionally, the
+ eigenvalues are in sorted order with the largest eigenvalue stored at
+ eigenvalues(0).
+ - (#dpca_mat)*trans(#dpca_mat) == identity_matrix.
+ (i.e. the rows of the dpca_matrix are all unit length vectors and are mutually
+ orthogonal)
+ - Note that #dpca_mat is the desired linear transformation matrix. That is,
+ multiplying a vector by #dpca_mat performs the desired linear dimensionality
+ reduction to num_rows dimensions.
+ !*/
+
+ discriminant_pca operator+ (
+ const discriminant_pca& item
+ ) const;
+ /*!
+ requires
+ - in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()
+ (i.e. the in_vector_size() of *this and item must match or one must be zero)
+ - between_class_weight() == item.between_class_weight()
+ - within_class_weight() == item.within_class_weight()
+ ensures
+ - returns a new discriminant_pca object that represents the combination of all
+ the measurements given to *this and item. That is, this function returns a
+ discriminant_pca object, R, that is equivalent to what you would obtain if all
+ modifying calls (e.g. the add_to_*() functions) to *this and item had instead
+ been done to R.
+ !*/
+
+ discriminant_pca& operator+= (
+ const discriminant_pca& rhs
+ );
+ /*!
+ requires
+ - in_vector_size() == 0 || rhs.in_vector_size() == 0 || in_vector_size() == rhs.in_vector_size()
+ (i.e. the in_vector_size() of *this and rhs must match or one must be zero)
+ - between_class_weight() == rhs.between_class_weight()
+ - within_class_weight() == rhs.within_class_weight()
+ ensures
+ - #*this == *item + rhs
+ - returns #*this
+ !*/
+
+ void swap (
+ discriminant_pca& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ discriminant_pca<matrix_type>& a,
+ discriminant_pca<matrix_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void deserialize (
+ discriminant_pca<matrix_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void serialize (
+ const discriminant_pca<matrix_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_DPCA_ABSTRaCT_
+
diff --git a/ml/dlib/dlib/statistics/image_feature_sampling.h b/ml/dlib/dlib/statistics/image_feature_sampling.h
new file mode 100644
index 000000000..f04f9926e
--- /dev/null
+++ b/ml/dlib/dlib/statistics/image_feature_sampling.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_IMAGE_FEATURE_SaMPLING_Hh_
+#define DLIB_IMAGE_FEATURE_SaMPLING_Hh_
+
+#include "image_feature_sampling_abstract.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename feature_extractor_type,
+ typename pyramid_type
+ >
+ random_subset_selector<typename feature_extractor_type::descriptor_type> randomly_sample_image_features (
+ const image_array_type& images,
+ const pyramid_type& pyr,
+ const feature_extractor_type& fe_,
+ unsigned long num
+ )
+ {
+ feature_extractor_type fe;
+ fe.copy_configuration(fe_);
+ random_subset_selector<typename feature_extractor_type::descriptor_type> basis;
+ basis.set_max_size(num);
+
+ typedef typename image_array_type::type image_type;
+ image_type temp_img, temp_img2;
+
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ bool at_pyramid_top = true;
+ while (true)
+ {
+ if (at_pyramid_top)
+ fe.load(images[i]);
+ else
+ fe.load(temp_img);
+
+ if (fe.size() == 0)
+ break;
+
+ for (long r = 0; r < fe.nr(); ++r)
+ {
+ for (long c = 0; c < fe.nc(); ++c)
+ {
+ if (basis.next_add_accepts())
+ {
+ basis.add(fe(r,c));
+ }
+ else
+ {
+ basis.add();
+ }
+ }
+ }
+
+ if (at_pyramid_top)
+ {
+ at_pyramid_top = false;
+ pyr(images[i], temp_img);
+ }
+ else
+ {
+ pyr(temp_img, temp_img2);
+ swap(temp_img2,temp_img);
+ }
+ }
+ }
+ return basis;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_FEATURE_SaMPLING_Hh_
+
diff --git a/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h b/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h
new file mode 100644
index 000000000..b51ef5423
--- /dev/null
+++ b/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_
+#ifdef DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_
+
+#include "random_subset_selector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename feature_extractor_type,
+ typename pyramid_type
+ >
+ random_subset_selector<typename feature_extractor_type::descriptor_type> randomly_sample_image_features (
+ const image_array_type& images,
+ const pyramid_type& pyr,
+ const feature_extractor_type& fe,
+ unsigned long num
+ );
+ /*!
+ requires
+ - pyramid_type == a type compatible with the image pyramid objects defined
+ in dlib/image_transforms/image_pyramid_abstract.h
+ - feature_extractor_type == a local image feature extractor type such as the
+ dlib::hog_image
+ - image_array_type == an implementation of dlib/array/array_kernel_abstract.h
+ and it must contain image objects which can be passed to pyr() and fe.load()
+ and are swappable by global swap().
+ ensures
+ - creates an image pyramid for each image in images and performs feature
+ extraction on each pyramid level. Then selects a random subsample of at
+ most num local feature vectors and returns it.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/lda.h b/ml/dlib/dlib/statistics/lda.h
new file mode 100644
index 000000000..38de3fd1e
--- /dev/null
+++ b/ml/dlib/dlib/statistics/lda.h
@@ -0,0 +1,237 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LDA_Hh_
+#define DLIB_LDA_Hh_
+
+#include "lda_abstract.h"
+#include "../algs.h"
+#include <map>
+#include "../matrix.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ inline std::map<unsigned long,unsigned long> make_class_labels(
+ const std::vector<unsigned long>& row_labels
+ )
+ {
+ std::map<unsigned long,unsigned long> class_labels;
+ for (unsigned long i = 0; i < row_labels.size(); ++i)
+ {
+ const unsigned long next = class_labels.size();
+ if (class_labels.count(row_labels[i]) == 0)
+ class_labels[row_labels[i]] = next;
+ }
+ return class_labels;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ matrix<T,0,1> center_matrix (
+ matrix<T>& X
+ )
+ {
+ matrix<T,1> mean;
+ for (long r = 0; r < X.nr(); ++r)
+ mean += rowm(X,r);
+ mean /= X.nr();
+
+ for (long r = 0; r < X.nr(); ++r)
+ set_rowm(X,r) -= mean;
+
+ return trans(mean);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void compute_lda_transform (
+ matrix<T>& X,
+ matrix<T,0,1>& mean,
+ const std::vector<unsigned long>& row_labels,
+ unsigned long lda_dims = 500,
+ unsigned long extra_pca_dims = 200
+ )
+ {
+ std::map<unsigned long,unsigned long> class_labels = impl::make_class_labels(row_labels);
+ // LDA can only give out at most class_labels.size()-1 dimensions so don't try to
+ // compute more than that.
+ lda_dims = std::min<unsigned long>(lda_dims, class_labels.size()-1);
+
+ // make sure requires clause is not broken
+ DLIB_CASSERT(class_labels.size() > 1,
+ "\t void compute_lda_transform()"
+ << "\n\t You can't call this function if the number of distinct class labels is less than 2."
+ );
+ DLIB_CASSERT(X.size() != 0 && (long)row_labels.size() == X.nr() && lda_dims != 0,
+ "\t void compute_lda_transform()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t X.size(): " << X.size()
+ << "\n\t row_labels.size(): " << row_labels.size()
+ << "\n\t lda_dims: " << lda_dims
+ );
+
+
+ mean = impl::center_matrix(X);
+ // Do PCA to reduce dims
+ matrix<T> pu,pw,pv;
+ svd_fast(X, pu, pw, pv, lda_dims+extra_pca_dims, 4);
+ pu.set_size(0,0); // free RAM, we don't need pu.
+ X = X*pv;
+
+
+ matrix<T> class_means(class_labels.size(), X.nc());
+ class_means = 0;
+ matrix<T,0,1> class_counts(class_labels.size());
+ class_counts = 0;
+
+ // First compute the means of each class
+ for (unsigned long i = 0; i < row_labels.size(); ++i)
+ {
+ const unsigned long class_idx = class_labels[row_labels[i]];
+ set_rowm(class_means,class_idx) += rowm(X,i);
+ class_counts(class_idx)++;
+ }
+ class_means = inv(diagm(class_counts))*class_means;
+ // subtract means from the data
+ for (unsigned long i = 0; i < row_labels.size(); ++i)
+ {
+ const unsigned long class_idx = class_labels[row_labels[i]];
+ set_rowm(X,i) -= rowm(class_means,class_idx);
+ }
+
+ // Note that we are using the formulas from the paper Using Discriminant
+ // Eigenfeatures for Image Retrieval by Swets and Weng.
+ matrix<T> Sw = trans(X)*X;
+ matrix<T> Sb = trans(class_means)*class_means;
+ matrix<T> A, H;
+ matrix<T,0,1> W;
+ svd3(Sw, A, W, H);
+ W = sqrt(W);
+ W = reciprocal(lowerbound(W,max(W)*1e-5));
+ A = trans(H*diagm(W))*Sb*H*diagm(W);
+ matrix<T> v,s,u;
+ svd3(A, v, s, u);
+ matrix<T> tform = H*diagm(W)*u;
+ // pick out only the number of dimensions we are supposed to for the output, unless
+ // we should just keep them all, then don't do anything.
+ if ((long)lda_dims <= tform.nc())
+ {
+ rsort_columns(tform, s);
+ tform = colm(tform, range(0, lda_dims-1));
+ }
+
+ X = trans(pv*tform);
+ mean = X*mean;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::pair<double,double> equal_error_rate (
+ const std::vector<double>& low_vals,
+ const std::vector<double>& high_vals
+ )
+ {
+ std::vector<std::pair<double,int> > temp;
+ temp.reserve(low_vals.size()+high_vals.size());
+ for (unsigned long i = 0; i < low_vals.size(); ++i)
+ temp.push_back(std::make_pair(low_vals[i], -1));
+ for (unsigned long i = 0; i < high_vals.size(); ++i)
+ temp.push_back(std::make_pair(high_vals[i], +1));
+
+ std::sort(temp.begin(), temp.end());
+
+ if (temp.size() == 0)
+ return std::make_pair(0,0);
+
+ double thresh = temp[0].first;
+
+ unsigned long num_low_wrong = low_vals.size();
+ unsigned long num_high_wrong = 0;
+ double low_error = num_low_wrong/(double)low_vals.size();
+ double high_error = num_high_wrong/(double)high_vals.size();
+ for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i)
+ {
+ thresh = temp[i].first;
+ if (temp[i].second > 0)
+ {
+ num_high_wrong++;
+ high_error = num_high_wrong/(double)high_vals.size();
+ }
+ else
+ {
+ num_low_wrong--;
+ low_error = num_low_wrong/(double)low_vals.size();
+ }
+ }
+
+ return std::make_pair((low_error+high_error)/2, thresh);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ struct roc_point
+ {
+ double true_positive_rate;
+ double false_positive_rate;
+ double detection_threshold;
+ };
+
+ inline std::vector<roc_point> compute_roc_curve (
+ const std::vector<double>& true_detections,
+ const std::vector<double>& false_detections
+ )
+ {
+ DLIB_CASSERT(true_detections.size() != 0);
+ DLIB_CASSERT(false_detections.size() != 0);
+
+ std::vector<std::pair<double,int> > temp;
+ temp.reserve(true_detections.size()+false_detections.size());
+ for (unsigned long i = 0; i < true_detections.size(); ++i)
+ temp.push_back(std::make_pair(true_detections[i], +1));
+ for (unsigned long i = 0; i < false_detections.size(); ++i)
+ temp.push_back(std::make_pair(false_detections[i], -1));
+
+ std::sort(temp.rbegin(), temp.rend());
+
+
+ std::vector<roc_point> roc_curve;
+ roc_curve.reserve(temp.size());
+
+ double num_false_included = 0;
+ double num_true_included = 0;
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ if (temp[i].second > 0)
+ num_true_included++;
+ else
+ num_false_included++;
+
+ roc_point p;
+ p.true_positive_rate = num_true_included/true_detections.size();
+ p.false_positive_rate = num_false_included/false_detections.size();
+ p.detection_threshold = temp[i].first;
+ roc_curve.push_back(p);
+ }
+
+ return roc_curve;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LDA_Hh_
+
diff --git a/ml/dlib/dlib/statistics/lda_abstract.h b/ml/dlib/dlib/statistics/lda_abstract.h
new file mode 100644
index 000000000..ab9fd7a32
--- /dev/null
+++ b/ml/dlib/dlib/statistics/lda_abstract.h
@@ -0,0 +1,118 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LDA_ABSTRACT_Hh_
+#ifdef DLIB_LDA_ABSTRACT_Hh_
+
+#include <map>
+#include "../matrix.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void compute_lda_transform (
+ matrix<T>& X,
+ matrix<T,0,1>& M,
+ const std::vector<unsigned long>& row_labels,
+ unsigned long lda_dims = 500,
+ unsigned long extra_pca_dims = 200
+ );
+ /*!
+ requires
+ - X.size() != 0
+ - row_labels.size() == X.nr()
+ - The number of distinct values in row_labels > 1
+ - lda_dims != 0
+ ensures
+ - We interpret X as a collection X.nr() of input vectors, where each row of X
+ is one of the vectors.
+ - We interpret row_labels[i] as the label of the vector rowm(X,i).
+ - This function performs the dimensionality reducing version of linear
+ discriminant analysis. That is, you give it a set of labeled vectors and it
+ returns a linear transform that maps the input vectors into a new space that
+ is good for distinguishing between the different classes. In particular,
+ this function finds matrices Z and M such that:
+ - Given an input vector x, Z*x-M, is the transformed version of x. That is,
+ Z*x-M maps x into a space where x vectors that share the same class label
+ are near each other.
+ - Z*x-M results in the transformed vectors having zero expected mean.
+ - Z.nr() <= lda_dims
+ (it might be less than lda_dims if there are not enough distinct class
+ labels to support lda_dims dimensions).
+ - Z.nc() == X.nc()
+ - We overwrite the input matrix X and store Z in it. Therefore, the
+ outputs of this function are in X and M.
+ - In order to deal with very high dimensional inputs, we perform PCA internally
+ to map the input vectors into a space of at most lda_dims+extra_pca_dims
+ prior to performing LDA.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::pair<double,double> equal_error_rate (
+ const std::vector<double>& low_vals,
+ const std::vector<double>& high_vals
+ );
+ /*!
+ ensures
+ - This function finds a threshold T that best separates the elements of
+ low_vals from high_vals by selecting the threshold with equal error rate. In
+ particular, we try to pick a threshold T such that:
+ - for all valid i:
+ - high_vals[i] >= T
+ - for all valid i:
+ - low_vals[i] < T
+ Where the best T is determined such that the fraction of low_vals >= T is the
+ same as the fraction of high_vals < T.
+ - Let ERR == the equal error rate. I.e. the fraction of times low_vals >= T
+ and high_vals < T. Note that 0 <= ERR <= 1.
+ - returns make_pair(ERR,T)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ struct roc_point
+ {
+ double true_positive_rate;
+ double false_positive_rate;
+ double detection_threshold;
+ };
+
+ std::vector<roc_point> compute_roc_curve (
+ const std::vector<double>& true_detections,
+ const std::vector<double>& false_detections
+ );
+ /*!
+ requires
+ - true_detections.size() != 0
+ - false_detections.size() != 0
+ ensures
+ - This function computes the ROC curve (receiver operating characteristic)
+ curve of the given data. Therefore, we interpret true_detections as
+ containing detection scores for a bunch of true detections and
+ false_detections as detection scores from a bunch of false detections. A
+ perfect detector would always give higher scores to true detections than to
+ false detections, resulting in a true positive rate of 1 and a false positive
+ rate of 0, for some appropriate detection threshold.
+ - Returns an array, ROC, such that:
+ - ROC.size() == true_detections.size()+false_detections.size()
+ - for all valid i:
+ - If you were to accept all detections with a score >= ROC[i].detection_threshold
+ then you would obtain a true positive rate of ROC[i].true_positive_rate and a
+ false positive rate of ROC[i].false_positive_rate.
+ - ROC is ordered such that low detection rates come first. That is, the
+ curve is swept from a high detection threshold to a low threshold.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LDA_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/random_subset_selector.h b/ml/dlib/dlib/statistics/random_subset_selector.h
new file mode 100644
index 000000000..17492363d
--- /dev/null
+++ b/ml/dlib/dlib/statistics/random_subset_selector.h
@@ -0,0 +1,372 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANDOM_SUBSeT_SELECTOR_H_
+#define DLIB_RANDOM_SUBSeT_SELECTOR_H_
+
+#include "random_subset_selector_abstract.h"
+#include "../rand.h"
+#include <vector>
+#include "../algs.h"
+#include "../string.h"
+#include "../serialize.h"
+#include "../matrix/matrix_mat.h"
+#include <iostream>
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename Rand_type = dlib::rand
+ >
+ class random_subset_selector
+ {
+ /*!
+ INITIAL VALUE
+ - _max_size == 0
+ - items.size() == 0
+ - count == 0
+ - _next_add_accepts == false
+
+ CONVENTION
+ - count == the number of times add() has been called since the last
+ time this object was empty.
+ - items.size() == size()
+ - max_size() == _max_size
+ - next_add_accepts() == _next_add_accepts
+ !*/
+ public:
+ typedef T type;
+ typedef T value_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef Rand_type rand_type;
+
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+
+
+ random_subset_selector (
+ )
+ {
+ _max_size = 0;
+ make_empty();
+ }
+
+ void set_seed(const std::string& value)
+ {
+ rnd.set_seed(value);
+ }
+
+ void make_empty (
+ )
+ {
+ items.resize(0);
+ count = 0;
+ update_next_add_accepts();
+ }
+
+ const std::vector<T>& to_std_vector(
+ ) const { return items; }
+
+ size_t size (
+ ) const
+ {
+ return items.size();
+ }
+
+ void set_max_size (
+ unsigned long new_max_size
+ )
+ {
+ items.reserve(new_max_size);
+ make_empty();
+ _max_size = new_max_size;
+ update_next_add_accepts();
+ }
+
+ unsigned long max_size (
+ ) const
+ {
+ return _max_size;
+ }
+
+ T& operator[] (
+ unsigned long idx
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < size(),
+ "\tvoid random_subset_selector::operator[]()"
+ << "\n\t idx is out of range"
+ << "\n\t idx: " << idx
+ << "\n\t size(): " << size()
+ << "\n\t this: " << this
+ );
+
+ return items[idx];
+ }
+
+ const T& operator[] (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(idx < size(),
+ "\tvoid random_subset_selector::operator[]()"
+ << "\n\t idx is out of range"
+ << "\n\t idx: " << idx
+ << "\n\t size(): " << size()
+ << "\n\t this: " << this
+ );
+
+ return items[idx];
+ }
+
+ iterator begin() { return items.begin(); }
+ const_iterator begin() const { return items.begin(); }
+ iterator end() { return items.end(); }
+ const_iterator end() const { return items.end(); }
+
+ bool next_add_accepts (
+ ) const
+ {
+ return _next_add_accepts;
+ }
+
+ void add (
+ const T& new_item
+ )
+ {
+ if (items.size() < _max_size)
+ {
+ items.push_back(new_item);
+ // swap into a random place
+ exchange(items[rnd.get_random_32bit_number()%items.size()], items.back());
+ }
+ else if (_next_add_accepts)
+ {
+ // pick a random element of items and replace it.
+ items[rnd.get_random_32bit_number()%items.size()] = new_item;
+ }
+
+ update_next_add_accepts();
+ ++count;
+ }
+
+ void add (
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(next_add_accepts() == false,
+ "\tvoid random_subset_selector::add()"
+ << "\n\t You should be calling the version of add() that takes an argument"
+ << "\n\t this: " << this
+ );
+
+ update_next_add_accepts();
+ ++count;
+ }
+
+ void swap (
+ random_subset_selector& a
+ )
+ {
+ items.swap(a.items);
+ std::swap(_max_size, a._max_size);
+ std::swap(count, a.count);
+ rnd.swap(a.rnd);
+ std::swap(_next_add_accepts, a._next_add_accepts);
+ }
+
+ template <typename T1, typename T2>
+ friend void serialize (
+ const random_subset_selector<T1,T2>& item,
+ std::ostream& out
+ );
+
+ template <typename T1, typename T2>
+ friend void deserialize (
+ random_subset_selector<T1,T2>& item,
+ std::istream& in
+ );
+
+ private:
+
+ void update_next_add_accepts (
+ )
+ {
+ if (items.size() < _max_size)
+ {
+ _next_add_accepts = true;
+ }
+ else if (_max_size == 0)
+ {
+ _next_add_accepts = false;
+ }
+ else
+ {
+ // At this point each element of items has had an equal chance of being in this object.
+ // In particular, the probability that each arrived here is currently items.size()/count.
+ // We need to be able to say that, after this function ends, the probability of any
+ // particular object ending up in items is items.size()/(count+1). So this means that
+ // we should decide to add a new item into items with this probability. Also, if we do
+ // so then we pick one of the current items and replace it at random with the new item.
+
+ // Make me a random 64 bit number. This might seem excessive but I want this object
+ // to be able to handle an effectively infinite number of calls to add(). So count
+ // might get very large and we need to deal with that properly.
+ const unsigned long num1 = rnd.get_random_32bit_number();
+ const unsigned long num2 = rnd.get_random_32bit_number();
+ uint64 num = num1;
+ num <<= 32;
+ num |= num2;
+
+ num %= (count+1);
+
+ _next_add_accepts = (num < items.size());
+ }
+
+ }
+
+ std::vector<T> items;
+ unsigned long _max_size;
+ uint64 count;
+
+ rand_type rnd;
+
+ bool _next_add_accepts;
+
+ };
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ void swap (
+ random_subset_selector<T,rand_type>& a,
+ random_subset_selector<T,rand_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T1, typename T2>
+ void serialize (
+ const random_subset_selector<T1,T2>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.items, out);
+ serialize(item._max_size, out);
+ serialize(item.count, out);
+ serialize(item.rnd, out);
+ serialize(item._next_add_accepts, out);
+ }
+
+ template <typename T1, typename T2>
+ void deserialize (
+ random_subset_selector<T1,T2>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.items, in);
+ deserialize(item._max_size, in);
+ deserialize(item.count, in);
+ deserialize(item.rnd, in);
+ deserialize(item._next_add_accepts, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ random_subset_selector<T> randomly_subsample (
+ const std::vector<T,alloc>& samples,
+ unsigned long num
+ )
+ {
+ random_subset_selector<T> subset;
+ subset.set_max_size(num);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ subset.add(samples[i]);
+ return subset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc,
+ typename U
+ >
+ random_subset_selector<T> randomly_subsample (
+ const std::vector<T,alloc>& samples,
+ unsigned long num,
+ const U& random_seed
+ )
+ {
+ random_subset_selector<T> subset;
+ subset.set_seed(cast_to_string(random_seed));
+ subset.set_max_size(num);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ subset.add(samples[i]);
+ return subset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ random_subset_selector<T> randomly_subsample (
+ const random_subset_selector<T>& samples,
+ unsigned long num
+ )
+ {
+ random_subset_selector<T> subset;
+ subset.set_max_size(num);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ subset.add(samples[i]);
+ return subset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ random_subset_selector<T> randomly_subsample (
+ const random_subset_selector<T>& samples,
+ unsigned long num,
+ const U& random_seed
+ )
+ {
+ random_subset_selector<T> subset;
+ subset.set_seed(cast_to_string(random_seed));
+ subset.set_max_size(num);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ subset.add(samples[i]);
+ return subset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_array_to_mat<random_subset_selector<T> > > mat (
+ const random_subset_selector<T>& m
+ )
+ {
+ typedef op_array_to_mat<random_subset_selector<T> > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOM_SUBSeT_SELECTOR_H_
+
+
diff --git a/ml/dlib/dlib/statistics/random_subset_selector_abstract.h b/ml/dlib/dlib/statistics/random_subset_selector_abstract.h
new file mode 100644
index 000000000..96f8b545d
--- /dev/null
+++ b/ml/dlib/dlib/statistics/random_subset_selector_abstract.h
@@ -0,0 +1,388 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_
+#ifdef DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_
+
+#include <vector>
+#include "../rand/rand_kernel_abstract.h"
+#include "../algs.h"
+#include "../string.h"
+
+namespace dlib
+{
+ template <
+ typename T,
+ typename Rand_type = dlib::rand
+ >
+ class random_subset_selector
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a copyable type
+
+ REQUIREMENTS ON Rand_type
+ must be an implementation of dlib/rand/rand_kernel_abstract.h
+
+ INITIAL VALUE
+ - size() == 0
+ - max_size() == 0
+ - next_add_accepts() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool to help you select a random subset of a large body of data.
+ In particular, it is useful when the body of data is too large to fit into memory.
+
+ So for example, suppose you have 1000000 data samples and you want to select a
+ random subset of size 1000. Then you could do that as follows:
+
+ random_subset_selector<sample_type> rand_subset;
+ rand_subset.set_max_size(1000)
+ for (int i = 0; i < 1000000; ++i)
+ rand_subset.add( get_next_data_sample());
+
+
+ At the end of the for loop you will have your random subset of 1000 samples. And by
+ random I mean that each of the 1000000 data samples has an equal chance of ending
+ up in the rand_subset object.
+
+
+ Note that the above example calls get_next_data_sample() for each data sample. This
+ may be inefficient since most of the data samples are just ignored. An alternative
+ method that doesn't require you to load each sample can also be used. Consider the
+ following:
+
+ random_subset_selector<sample_type> rand_subset;
+ rand_subset.set_max_size(1000)
+ for (int i = 0; i < 1000000; ++i)
+ if (rand_subset.next_add_accepts())
+ rand_subset.add(get_data_sample(i));
+ else
+ rand_subset.add()
+
+ In the above example we only actually fetch the data sample into memory if we
+ know that the rand_subset would include it into the random subset. Otherwise,
+ we can just call the empty add().
+
+
+ Finally, note that the random_subset_selector uses a deterministic pseudo-random
+ number generator under the hood. Moreover, the default constructor always seeds
+ the random number generator in the same way. So unless you call set_seed()
+ each instance of the random_subset_selector will function identically.
+ !*/
+ public:
+ typedef T type;
+ typedef T value_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef Rand_type rand_type;
+
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+
+ random_subset_selector (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_seed(
+ const std::string& value
+ );
+ /*!
+ ensures
+ - sets the seed of the random number generator that is embedded in
+ this object to the given value.
+ !*/
+
+ void make_empty (
+ );
+ /*!
+ ensures
+ - #size() == 0
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of items of type T currently contained in this object
+ !*/
+
+ void set_max_size (
+ unsigned long new_max_size
+ );
+ /*!
+ ensures
+ - #max_size() == new_max_size
+ - #size() == 0
+ !*/
+
+ unsigned long max_size (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum allowable size for this object
+ !*/
+
+ T& operator[] (
+ unsigned long idx
+ );
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - returns a non-const reference to the idx'th element of this object
+ !*/
+
+ const T& operator[] (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < size()
+ ensures
+ - returns a const reference to the idx'th element of this object
+ !*/
+
+ bool next_add_accepts (
+ ) const;
+ /*!
+ ensures
+ - if (the next call to add(item) will result in item being included
+ into *this) then
+ - returns true
+ - Note that the next item will always be accepted if size() < max_size().
+ - else
+ - returns false
+ - Note that the next item will never be accepted if max_size() == 0.
+ !*/
+
+ void add (
+ const T& new_item
+ );
+ /*!
+ ensures
+ - if (next_add_accepts()) then
+ - places new_item into *this object at a random location
+ - if (size() < max_size()) then
+ - #size() == size() + 1
+ - #next_add_accepts() == The updated information about the acceptance
+ of the next call to add()
+ !*/
+
+ void add (
+ );
+ /*!
+ requires
+ - next_add_accepts() == false
+ ensures
+ - This function does nothing but update the value of #next_add_accepts()
+ !*/
+
+ iterator begin(
+ );
+ /*!
+ ensures
+ - if (size() > 0) then
+ - returns an iterator referring to the first element in
+ this container.
+ - else
+ - returns end()
+ !*/
+
+ const_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - if (size() > 0) then
+ - returns a const_iterator referring to the first element in
+ this container.
+ - else
+ - returns end()
+ !*/
+
+ iterator end(
+ );
+ /*!
+ ensures
+ - returns an iterator that represents one past the end of
+ this container
+ !*/
+
+ const_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns an iterator that represents one past the end of
+ this container
+ !*/
+
+ const std::vector<T>& to_std_vector(
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the underlying std::vector<T> that contains
+ all elements in this object. That is, this function returns a vector, V,
+ which has the following properties:
+ - V.size() == this->size()
+ - V.begin() == this->begin()
+ - V.end() == this->end()
+ !*/
+
+ void swap (
+ random_subset_selector& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ void swap (
+ random_subset_selector<T,rand_type>& a,
+ random_subset_selector<T,rand_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides global swap support
+ !*/
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ void serialize (
+ const random_subset_selector<T,rand_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ void deserialize (
+ random_subset_selector<T,rand_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ random_subset_selector<T> randomly_subsample (
+ const std::vector<T,alloc>& samples,
+ unsigned long num
+ );
+ /*!
+ ensures
+ - returns a random subset R such that:
+ - R contains a random subset of the given samples
+ - R.size() == min(num, samples.size())
+ - R.max_size() == num
+ - The random number generator used by this function will always be
+ initialized in the same way. I.e. this function will always pick
+ the same random subsample if called multiple times.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc,
+ typename U
+ >
+ random_subset_selector<T> randomly_subsample (
+ const std::vector<T,alloc>& samples,
+ unsigned long num,
+ const U& random_seed
+ );
+ /*!
+ requires
+ - random_seed must be convertible to a string by dlib::cast_to_string()
+ ensures
+ - returns a random subset R such that:
+ - R contains a random subset of the given samples
+ - R.size() == min(num, samples.size())
+ - R.max_size() == num
+ - The given random_seed will be used to initialize the random number
+ generator used by this function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ random_subset_selector<T> randomly_subsample (
+ const random_subset_selector<T>& samples,
+ unsigned long num
+ );
+ /*!
+ ensures
+ - returns a random subset R such that:
+ - R contains a random subset of the given samples
+ - R.size() == min(num, samples.size())
+ - R.max_size() == num
+ - The random number generator used by this function will always be
+ initialized in the same way. I.e. this function will always pick
+ the same random subsample if called multiple times.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ random_subset_selector<T> randomly_subsample (
+ const random_subset_selector<T>& samples,
+ unsigned long num,
+ const U& random_seed
+ );
+ /*!
+ requires
+ - random_seed must be convertible to a string by dlib::cast_to_string()
+ ensures
+ - returns a random subset R such that:
+ - R contains a random subset of the given samples
+ - R.size() == min(num, samples.size())
+ - R.max_size() == num
+ - The given random_seed will be used to initialize the random number
+ generator used by this function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const random_subset_selector<T>& m
+ );
+ /*!
+ ensures
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == m.size()
+ - for all valid r:
+ R(r) == m[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/statistics/running_gradient.h b/ml/dlib/dlib/statistics/running_gradient.h
new file mode 100644
index 000000000..d3f1b3ddf
--- /dev/null
+++ b/ml/dlib/dlib/statistics/running_gradient.h
@@ -0,0 +1,370 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RuNNING_GRADIENT_Hh_
+#define DLIB_RuNNING_GRADIENT_Hh_
+
+#include "running_gradient_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include <cmath>
+#include "../matrix.h"
+#include <algorithm>
+
+
+namespace dlib
+{
+ class running_gradient
+ {
+ public:
+
+ running_gradient (
+ )
+ {
+ clear();
+ }
+
+ void clear(
+ )
+ {
+ n = 0;
+ R = identity_matrix<double>(2)*1e6;
+ w = 0;
+ residual_squared = 0;
+ }
+
+ double current_n (
+ ) const
+ {
+ return n;
+ }
+
+ void add(
+ double y
+ )
+ {
+ matrix<double,2,1> x;
+ x = n, 1;
+
+ // Do recursive least squares computations
+ const double temp = 1 + trans(x)*R*x;
+ matrix<double,2,1> tmp = R*x;
+ R = R - (tmp*trans(tmp))/temp;
+ // R should always be symmetric. This line improves numeric stability of this algorithm.
+ R = 0.5*(R + trans(R));
+ w = w + R*x*(y - trans(x)*w);
+
+ // Also, recursively keep track of the residual error between the given value
+ // and what our linear predictor outputs.
+ residual_squared = residual_squared + std::pow((y - trans(x)*w),2.0)*temp;
+
+ ++n;
+ }
+
+ double gradient (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\t double running_gradient::gradient()"
+ << "\n\t You must add more values into this object before calling this function."
+ << "\n\t this: " << this
+ );
+
+ return w(0);
+ }
+
+ double intercept (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 0,
+ "\t double running_gradient::intercept()"
+ << "\n\t You must add more values into this object before calling this function."
+ << "\n\t this: " << this
+ );
+
+ return w(1);
+ }
+ double standard_error (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 2,
+ "\t double running_gradient::standard_error()"
+ << "\n\t You must add more values into this object before calling this function."
+ << "\n\t this: " << this
+ );
+
+
+ const double s = residual_squared/(n-2);
+ const double adjust = 12.0/(std::pow(current_n(),3.0) - current_n());
+ return std::sqrt(s*adjust);
+ }
+
+ double probability_gradient_less_than (
+ double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 2,
+ "\t double running_gradient::probability_gradient_less_than()"
+ << "\n\t You must add more values into this object before calling this function."
+ << "\n\t this: " << this
+ );
+
+ return normal_cdf(thresh, gradient(), standard_error());
+ }
+
+ double probability_gradient_greater_than (
+ double thresh
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 2,
+ "\t double running_gradient::probability_gradient_greater_than()"
+ << "\n\t You must add more values into this object before calling this function."
+ << "\n\t this: " << this
+ );
+
+ return 1-probability_gradient_less_than(thresh);
+ }
+
+ friend void serialize (const running_gradient& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.n, out);
+ serialize(item.R, out);
+ serialize(item.w, out);
+ serialize(item.residual_squared, out);
+ }
+
+ friend void deserialize (running_gradient& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::running_gradient.");
+ deserialize(item.n, in);
+ deserialize(item.R, in);
+ deserialize(item.w, in);
+ deserialize(item.residual_squared, in);
+ }
+
+ private:
+
+ static double normal_cdf(double value, double mean, double stddev)
+ {
+ if (stddev == 0)
+ {
+ if (value < mean)
+ return 0;
+ else if (value > mean)
+ return 1;
+ else
+ return 0.5;
+ }
+ value = (value-mean)/stddev;
+ return 0.5 * std::erfc(-value / std::sqrt(2.0));
+ }
+
+ double n;
+ matrix<double,2,2> R;
+ matrix<double,2,1> w;
+ double residual_squared;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ double probability_gradient_less_than (
+ const T& container,
+ double thresh
+ )
+ {
+ running_gradient g;
+ for(auto&& v : container)
+ g.add(v);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.current_n() > 2,
+ "\t double probability_gradient_less_than()"
+ << "\n\t You need more than 2 elements in the given container to call this function."
+ );
+ return g.probability_gradient_less_than(thresh);
+ }
+
+ template <
+ typename T
+ >
+ double probability_gradient_greater_than (
+ const T& container,
+ double thresh
+ )
+ {
+ running_gradient g;
+ for(auto&& v : container)
+ g.add(v);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(g.current_n() > 2,
+ "\t double probability_gradient_greater_than()"
+ << "\n\t You need more than 2 elements in the given container to call this function."
+ );
+ return g.probability_gradient_greater_than(thresh);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ double find_upper_quantile (
+ const T& container_,
+ double quantile
+ )
+ {
+ DLIB_CASSERT(0 <= quantile && quantile <= 1.0);
+
+ // copy container into a std::vector
+ std::vector<double> container(container_.begin(), container_.end());
+
+ DLIB_CASSERT(container.size() > 0);
+
+ size_t idx_upper = std::round((container.size()-1)*(1-quantile));
+
+ std::nth_element(container.begin(), container.begin()+idx_upper, container.end());
+ auto upper_q = *(container.begin()+idx_upper);
+ return upper_q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_decrease (
+ const T& container,
+ double probability_of_decrease = 0.51
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1,
+ "\t size_t count_steps_without_decrease()"
+ << "\n\t probability_of_decrease: "<< probability_of_decrease
+ );
+
+ running_gradient g;
+ size_t count = 0;
+ size_t j = 0;
+ for (auto i = container.rbegin(); i != container.rend(); ++i)
+ {
+ ++j;
+ g.add(*i);
+ if (g.current_n() > 2)
+ {
+ // Note that this only looks backwards because we are looping over the
+ // container backwards. So here we are really checking if the gradient isn't
+ // decreasing.
+ double prob_decreasing = g.probability_gradient_greater_than(0);
+ // If we aren't confident things are decreasing.
+ if (prob_decreasing < probability_of_decrease)
+ count = j;
+ }
+ }
+ return count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_decrease_robust (
+ const T& container,
+ double probability_of_decrease = 0.51,
+ double quantile_discard = 0.10
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= quantile_discard && quantile_discard <= 1);
+ DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1,
+ "\t size_t count_steps_without_decrease_robust()"
+ << "\n\t probability_of_decrease: "<< probability_of_decrease
+ );
+
+ if (container.size() == 0)
+ return 0;
+
+ const auto quantile_thresh = find_upper_quantile(container, quantile_discard);
+
+ running_gradient g;
+ size_t count = 0;
+ size_t j = 0;
+ for (auto i = container.rbegin(); i != container.rend(); ++i)
+ {
+ ++j;
+ // ignore values that are too large
+ if (*i <= quantile_thresh)
+ g.add(*i);
+
+ if (g.current_n() > 2)
+ {
+ // Note that this only looks backwards because we are looping over the
+ // container backwards. So here we are really checking if the gradient isn't
+ // decreasing.
+ double prob_decreasing = g.probability_gradient_greater_than(0);
+ // If we aren't confident things are decreasing.
+ if (prob_decreasing < probability_of_decrease)
+ count = j;
+ }
+ }
+ return count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_increase (
+ const T& container,
+ double probability_of_increase = 0.51
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0.5 < probability_of_increase && probability_of_increase < 1,
+ "\t size_t count_steps_without_increase()"
+ << "\n\t probability_of_increase: "<< probability_of_increase
+ );
+
+ running_gradient g;
+ size_t count = 0;
+ size_t j = 0;
+ for (auto i = container.rbegin(); i != container.rend(); ++i)
+ {
+ ++j;
+ g.add(*i);
+ if (g.current_n() > 2)
+ {
+ // Note that this only looks backwards because we are looping over the
+ // container backwards. So here we are really checking if the gradient isn't
+ // increasing.
+ double prob_increasing = g.probability_gradient_less_than(0);
+ // If we aren't confident things are increasing.
+ if (prob_increasing < probability_of_increase)
+ count = j;
+ }
+ }
+ return count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RuNNING_GRADIENT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/running_gradient_abstract.h b/ml/dlib/dlib/statistics/running_gradient_abstract.h
new file mode 100644
index 000000000..a42e1c152
--- /dev/null
+++ b/ml/dlib/dlib/statistics/running_gradient_abstract.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_
+#ifdef DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_
+
+
+namespace dlib
+{
+ class running_gradient
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for estimating if a noisy sequence of numbers is
+ trending up or down and by how much. It does this by finding the least
+ squares fit of a line to the data and then allows you to perform a
+ statistical test on the slope of that line.
+ !*/
+
+ public:
+
+ running_gradient (
+ );
+ /*!
+ ensures
+ - #current_n() == 0
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #current_n() == 0
+ - this object has its initial value
+ - clears all memory of any previous data points
+ !*/
+
+ double current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the number of values given to this object by add().
+ !*/
+
+ void add(
+ double y
+ );
+ /*!
+ ensures
+ - Updates the gradient() and standard_error() estimates in this object
+ based on the new y value.
+ - #current_n() == current_n() + 1
+ !*/
+
+ double gradient (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - If we consider the values given to add() as time series data, we can
+ estimate the rate-of-change of those values. That is, how much,
+ typically, do those values change from sample to sample? The gradient()
+ function returns the current estimate. It does this by finding the least
+ squares fit of a line to the data given to add() and returning the slope
+ of this line.
+ !*/
+
+ double intercept (
+ ) const;
+ /*!
+ requires
+ - current_n() > 0
+ ensures
+ - This class fits a line to the time series data given to add(). This
+ function returns the intercept of that line while gradient() returns the
+ slope of that line. This means that, for example, the next point that
+ add() will see, as predicted by this best fit line, is the value
+ intercept() + current_n()*gradient().
+ !*/
+
+ double standard_error (
+ ) const;
+ /*!
+ requires
+ - current_n() > 2
+ ensures
+ - returns the standard deviation of the estimate of gradient().
+ !*/
+
+ double probability_gradient_less_than (
+ double thresh
+ ) const;
+ /*!
+ requires
+ - current_n() > 2
+ ensures
+ - If we can assume the values given to add() are linearly related to each
+ other and corrupted by Gaussian additive noise then our estimate of
+ gradient() is a random variable with a mean value of gradient() and a
+ standard deviation of standard_error(). This lets us compute the
+ probability that the true gradient of the data is less than thresh, which
+ is what this function returns.
+ !*/
+
+ double probability_gradient_greater_than (
+ double thresh
+ ) const;
+ /*!
+ requires
+ - current_n() > 2
+ ensures
+ - returns 1-probability_gradient_less_than(thresh)
+ !*/
+
+ };
+
+ void serialize (
+ const running_gradient& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ running_gradient& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ double probability_gradient_less_than (
+ const T& container,
+ double thresh
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with a
+ range based for loop.
+ - The container must contain more than 2 elements.
+ ensures
+ - Puts all the elements of container into a running_gradient object, R, and
+ then returns R.probability_gradient_less_than(thresh).
+ !*/
+
+ template <
+ typename T
+ >
+ double probability_gradient_greater_than (
+ const T& container,
+ double thresh
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with a
+ range based for loop.
+ - The container must contain more than 2 elements.
+ ensures
+ - Puts all the elements of container into a running_gradient object, R, and
+ then returns R.probability_gradient_greater_than(thresh).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_decrease (
+ const T& container,
+ double probability_of_decrease = 0.51
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with
+ .rbegin() and .rend().
+ - 0.5 < probability_of_decrease < 1
+ ensures
+ - If you think of the contents of container as a potentially noisy time series,
+ then this function returns a count of how long the time series has gone
+ without noticeably decreasing in value. It does this by adding the
+ elements into a running_gradient object and counting how many elements,
+ starting with container.back(), that you need to examine before you are
+ confident that the series has been decreasing in value. Here, "confident of
+ decrease" means that the probability of decrease is >= probability_of_decrease.
+ - Setting probability_of_decrease to 0.51 means we count until we see even a
+ small hint of decrease, whereas a larger value of 0.99 would return a larger
+ count since it keeps going until it is nearly certain the time series is
+ decreasing.
+ - The max possible output from this function is container.size().
+ !*/
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_decrease_robust (
+ const T& container,
+ double probability_of_decrease = 0.51,
+ double quantile_discard = 0.10
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with
+ .begin() and .end() as well as .rbegin() and .rend().
+ - 0.5 < probability_of_decrease < 1
+ - 0 <= quantile_discard <= 1
+ ensures
+ - This function behaves just like
+ count_steps_without_decrease(container,probability_of_decrease) except that
+ it ignores values in container that are in the upper quantile_discard
+ quantile. So for example, if the quantile discard is 0.1 then the 10%
+ largest values in container are ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ size_t count_steps_without_increase (
+ const T& container,
+ double probability_of_increase = 0.51
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with
+ .rbegin() and .rend().
+ - 0.5 < probability_of_increase < 1
+ ensures
+ - If you think of the contents of container as a potentially noisy time series,
+ then this function returns a count of how long the time series has gone
+ without noticeably increasing in value. It does this by adding the
+ elements into a running_gradient object and counting how many elements,
+ starting with container.back(), that you need to examine before you are
+ confident that the series has been increasing in value. Here, "confident of
+ increase" means that the probability of increase is >= probability_of_increase.
+ - Setting probability_of_increase to 0.51 means we count until we see even a
+ small hint of increase, whereas a larger value of 0.99 would return a larger
+ count since it keeps going until it is nearly certain the time series is
+ increasing.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ double find_upper_quantile (
+ const T& container,
+ double quantile
+ );
+ /*!
+ requires
+ - container must be a container of double values that can be enumerated with
+ .begin() and .end().
+ - 0 <= quantile <= 1
+ - container.size() > 0
+ ensures
+ - Finds and returns the value such that quantile percent of the values in
+ container are greater than it. For example, 0.5 would find the median value
+ in container while 0.1 would find the value that lower bounded the 10%
+ largest values in container.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/sammon.h b/ml/dlib/dlib/statistics/sammon.h
new file mode 100644
index 000000000..1a3eb72a1
--- /dev/null
+++ b/ml/dlib/dlib/statistics/sammon.h
@@ -0,0 +1,269 @@
+// Copyright (C) 2012 Emanuele Cesena (emanuele.cesena@gmail.com), Davis E. King
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SAMMoN_Hh_
+#define DLIB_SAMMoN_Hh_
+
+#include "sammon_abstract.h"
+#include "../matrix.h"
+#include "../algs.h"
+#include "dpca.h"
+#include <vector>
+
+namespace dlib
+{
+
+ class sammon_projection
+ {
+
+ public:
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ std::vector<matrix<double,0,1> > operator() (
+ const std::vector<matrix_type>& data,
+ const long num_dims
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_dims > 0,
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t num_dims: " << num_dims
+ );
+ std::vector<matrix<double,0,1> > result; // projections
+ if (data.size() == 0)
+ {
+ return result;
+ }
+
+#ifdef ENABLE_ASSERTS
+ DLIB_ASSERT(0 < num_dims && num_dims <= data[0].size(),
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t data.size(): " << data.size()
+ << "\n\t num_dims: " << num_dims
+ << "\n\t data[0].size(): " << data[0].size()
+ );
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ DLIB_ASSERT(is_col_vector(data[i]) && data[i].size() == data[0].size(),
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t data["<<i<<"].size(): " << data[i].size()
+ << "\n\t data[0].size(): " << data[0].size()
+ << "\n\t is_col_vector(data["<<i<<"]): " << is_col_vector(data[i])
+ );
+ }
+#endif
+
+ double err; // error (discarded)
+ do_sammon_projection(data, num_dims, result, err);
+ return result;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void operator() (
+ const std::vector<matrix_type>& data,
+ const long num_dims,
+ std::vector<matrix<double,0,1> >& result,
+ double &err,
+ const unsigned long num_iters = 1000,
+ const double err_delta = 1.0e-9
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_dims > 0 && num_iters > 0 && err_delta > 0.0,
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t data.size(): " << data.size()
+ << "\n\t num_dims: " << num_dims
+ << "\n\t num_iters: " << num_iters
+ << "\n\t err_delta: " << err_delta
+ );
+ if (data.size() == 0)
+ {
+ result.clear();
+ err = 0;
+ return;
+ }
+
+#ifdef ENABLE_ASSERTS
+ DLIB_ASSERT(0 < num_dims && num_dims <= data[0].size(),
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t data.size(): " << data.size()
+ << "\n\t num_dims: " << num_dims
+ << "\n\t data[0].size(): " << data[0].size()
+ );
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ DLIB_ASSERT(is_col_vector(data[i]) && data[i].size() == data[0].size(),
+ "\t std::vector<matrix<double,0,1> > sammon_projection::operator()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t data["<<i<<"].size(): " << data[i].size()
+ << "\n\t data[0].size(): " << data[0].size()
+ << "\n\t is_col_vector(data["<<i<<"]): " << is_col_vector(data[i])
+ );
+ }
+#endif
+
+ do_sammon_projection(data, num_dims, result, err, num_iters, err_delta);
+ }
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ private:
+
+ void compute_relative_distances(
+ matrix<double,0,1>& dist, // relative distances (output)
+ matrix<double,0,0>& data, // input data (matrix whose columns are the input vectors)
+ double eps_ratio = 1.0e-7 // to compute the minimum distance eps
+ )
+ /*!
+ requires
+ - dist.nc() == comb( data.nc(), 2 ), preallocated
+ - eps_ratio > 0
+ ensures
+ - dist[k] == lenght(data[i] - data[j]) for k = j(j-1)/2 + i
+ !*/
+ {
+ const long N = data.nc(); // num of points
+ double eps; // minimum distance, forced to avoid vectors collision
+ // computed at runtime as eps_ration * mean(vectors distances)
+ for (int k = 0, i = 1; i < N; ++i)
+ for (int j = 0; j < i; ++j)
+ dist(k++) = length(colm(data, i) - colm(data, j));
+
+ eps = eps_ratio * mean(dist);
+ dist = lowerbound(dist, eps);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void do_sammon_projection(
+ const std::vector<matrix_type>& data, // input data
+ unsigned long num_dims, // dimension of the reduced space
+ std::vector<matrix<double,0,1> >& result, // projections (output)
+ double &err, // error (output)
+ unsigned long num_iters = 1000, // max num of iterations: stop condition
+ const double err_delta = 1.0e-9 // delta error: stop condition
+ )
+ /*!
+ requires
+ - matrix_type should be a kind of dlib::matrix<double,N,1>
+ - num_dims > 0
+ - num_iters > 0
+ - err_delta > 0
+ ensures
+ - result == a set of matrix<double,num_dims,1> objects that represent
+ the Sammon's projections of data vectors.
+ - err == the estimated error done in the projection, with the extra
+ property that err(at previous iteration) - err < err_delta
+ !*/
+ {
+ // other params
+ const double mf = 0.3; // magic factor
+
+ matrix<double> mdata; // input data as matrix
+ matrix<double> projs; // projected vectors, i.e. output data as matrix
+
+ // std::vector<matrix> -> matrix
+ mdata.set_size(data[0].size(), data.size());
+ for (unsigned int i = 0; i < data.size(); i++)
+ set_colm(mdata, i) = data[i];
+
+ const long N = mdata.nc(); // num of points
+ const long d = num_dims; // size of the reduced space
+ const long nd = N * (N - 1) / 2; // num of pairs of points = size of the distances vectors
+
+ matrix<double, 0, 1> dsij, inv_dsij; // d*_ij: pair-wise distances in the input space (and inverses)
+ dsij.set_size(nd, 1);
+ inv_dsij.set_size(nd, 1);
+ double ic; // 1.0 / sum of dsij
+
+ matrix<double, 0, 1> dij; // d_ij: pair-wise distances in the reduced space
+ dij.set_size(nd, 1);
+
+ matrix<double, 0, 0> dE, dE2, dtemp; // matrices representing error partial derivatives
+ dE.set_size(d, N);
+ dE2.set_size(d, N);
+ dtemp.set_size(d, N);
+
+ matrix<double, 0, 1> inv_dij, alpha; // utility vectors used to compute the partial derivatives
+ inv_dij.set_size(N, 1); // inv_dij is 1.0/dij, but we only need it column-wise
+ alpha.set_size(N, 1); // (slightly wasting a bit of computation)
+ // alpha = 1.0/dij - 1.0/dsij, again column-wise
+
+
+ // initialize projs with PCA
+ discriminant_pca<matrix<double> > dpca;
+ for (int i = 0; i < mdata.nc(); ++i)
+ {
+ dpca.add_to_total_variance(colm(mdata, i));
+ }
+ matrix<double> mat = dpca.dpca_matrix_of_size(num_dims);
+ projs = mat * mdata;
+
+ // compute dsij, inv_dsij and ic
+ compute_relative_distances(dsij, mdata);
+ inv_dsij = 1.0 / dsij;
+ ic = 1.0 / sum(dsij);
+
+ // compute dij and err
+ compute_relative_distances(dij, projs);
+ err = ic * sum(pointwise_multiply(squared(dij - dsij), inv_dsij));
+
+ // start iterating
+ while (num_iters--)
+ {
+ // compute dE, dE2 progressively column by column
+ for (int p = 0; p < N; ++p)
+ {
+ // compute
+ // - alpha_p, the column vector with 1/d_pj - 1/d*_pj
+ // - dtemp, the matrix with the p-th column repeated all along
+ //TODO: optimize constructions
+ for (int i = 0; i < N; ++i)
+ {
+ int pos = (i < p) ? p * (p - 1) / 2 + i : i * (i - 1) / 2 + p;
+ inv_dij(i) = (i == p) ? 0.0 : 1.0 / dij(pos);
+ alpha(i) = (i == p) ? 0.0 : inv_dij(i) - inv_dsij(pos);
+ set_colm(dtemp, i) = colm(projs, p);
+ }
+
+ dtemp -= projs;
+ set_colm(dE, p) = dtemp * alpha;
+
+ double sum_alpha = sum(alpha);
+ set_colm(dE2, p) = abs( sum_alpha + squared(dtemp) * cubed(inv_dij) );
+ }
+
+
+ // compute the update projections
+ projs += pointwise_multiply(dE, mf * reciprocal(dE2));
+
+ // compute new dij and error
+ compute_relative_distances(dij, projs);
+ double err_new = ic * sum( pointwise_multiply(squared(dij - dsij), inv_dsij) );
+ if (err - err_new < err_delta)
+ break;
+ err = err_new;
+ }
+
+ // matrix -> std::vector<matrix>
+ result.clear();
+ for (int i = 0; i < projs.nc(); ++i)
+ result.push_back(colm(projs, i));
+ }
+
+ };
+
+} // namespace dlib
+
+#endif // DLIB_SAMMoN_Hh_
+
diff --git a/ml/dlib/dlib/statistics/sammon_abstract.h b/ml/dlib/dlib/statistics/sammon_abstract.h
new file mode 100644
index 000000000..0e009729c
--- /dev/null
+++ b/ml/dlib/dlib/statistics/sammon_abstract.h
@@ -0,0 +1,117 @@
+// Copyright (C) 2012 Emanuele Cesena (emanuele.cesena@gmail.com), Davis E. King
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SAMMoN_ABSTRACT_Hh_
+#ifdef DLIB_SAMMoN_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+ class sammon_projection
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object that computes the Sammon projection of a set
+ of N points in a L-dimensional vector space onto a d-dimensional space
+ (d < L), according to the paper:
+ A Nonlinear Mapping for Data Structure Analysis (1969) by J.W. Sammon
+
+ The current implementation is a vectorized version of the original algorithm.
+ !*/
+
+ public:
+
+ sammon_projection(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <typename matrix_type>
+ std::vector<matrix<double,0,1> > operator() (
+ const std::vector<matrix_type>& data,
+ const long num_dims
+ );
+ /*!
+ requires
+ - num_dims > 0
+ - matrix_type should be a kind of dlib::matrix of doubles capable
+ of representing column vectors.
+ - for all valid i:
+ - is_col_vector(data[i]) == true
+ - data[0].size() == data[i].size()
+ (i.e. all the vectors in data must have the same dimensionality)
+ - if (data.size() != 0) then
+ - 0 < num_dims <= data[0].size()
+ (i.e. you can't project into a higher dimension than the input data,
+ only to a lower dimension.)
+ ensures
+ - This routine computes Sammon's dimensionality reduction method based on the
+ given input data. It will attempt to project the contents of data into a
+ num_dims dimensional space that preserves relative distances between the
+ input data points.
+ - This function returns a std::vector, OUT, such that:
+ - OUT == a set of column vectors that represent the Sammon projection of
+ the input data vectors.
+ - OUT.size() == data.size()
+ - for all valid i:
+ - OUT[i].size() == num_dims
+ - OUT[i] == the Sammon projection of the input vector data[i]
+ !*/
+
+ template <typename matrix_type>
+ void operator() (
+ const std::vector<matrix_type>& data,
+ const long num_dims,
+ std::vector<matrix<double,0,1> >& result,
+ double &err,
+ const unsigned long num_iters = 1000,
+ const double err_delta = 1.0e-9
+ );
+ /*!
+ requires
+ - num_iters > 0
+ - err_delta > 0
+ - num_dims > 0
+ - matrix_type should be a kind of dlib::matrix of doubles capable
+ of representing column vectors.
+ - for all valid i:
+ - is_col_vector(data[i]) == true
+ - data[0].size() == data[i].size()
+ (i.e. all the vectors in data must have the same dimensionality)
+ - if (data.size() != 0) then
+ - 0 < num_dims <= data[0].size()
+ (i.e. you can't project into a higher dimension than the input data,
+ only to a lower dimension.)
+ ensures
+ - This routine computes Sammon's dimensionality reduction method based on the
+ given input data. It will attempt to project the contents of data into a
+ num_dims dimensional space that preserves relative distances between the
+ input data points.
+ - #err == the final error value at the end of the algorithm. The goal of Sammon's
+ algorithm is to find a lower dimensional projection of the input data that
+ preserves the relative distances between points. The value in #err is a measure
+ of the total error at the end of the algorithm. So smaller values indicate
+ a better projection was found than if a large value is returned via #err.
+ - Sammon's algorithm will run until either num_iters iterations has executed
+ or the change in error from one iteration to the next is less than err_delta.
+ - Upon completion, the output of Sammon's projection is stored into #result, in
+ particular, we will have:
+ - #result == a set of column vectors that represent the Sammon projection of
+ the input data vectors.
+ - #result.size() == data.size()
+ - for all valid i:
+ - #result[i].size() == num_dims
+ - #result[i] == the Sammon projection of the input vector data[i]
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_SAMMoN_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/statistics/statistics.h b/ml/dlib/dlib/statistics/statistics.h
new file mode 100644
index 000000000..9dee7006b
--- /dev/null
+++ b/ml/dlib/dlib/statistics/statistics.h
@@ -0,0 +1,1890 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), Steve Taylor
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STATISTICs_
+#define DLIB_STATISTICs_
+
+#include "statistics_abstract.h"
+#include <limits>
+#include <cmath>
+#include "../algs.h"
+#include "../matrix.h"
+#include "../sparse_vector.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_stats
+ {
+ public:
+
+ running_stats()
+ {
+ clear();
+
+ COMPILE_TIME_ASSERT ((
+ is_same_type<float,T>::value ||
+ is_same_type<double,T>::value ||
+ is_same_type<long double,T>::value
+ ));
+ }
+
+ void clear()
+ {
+ sum = 0;
+ sum_sqr = 0;
+ sum_cub = 0;
+ sum_four = 0;
+
+ n = 0;
+ min_value = std::numeric_limits<T>::infinity();
+ max_value = -std::numeric_limits<T>::infinity();
+ }
+
+ void add (
+ const T& val
+ )
+ {
+ sum += val;
+ sum_sqr += val*val;
+ sum_cub += cubed(val);
+ sum_four += quaded(val);
+
+ if (val < min_value)
+ min_value = val;
+ if (val > max_value)
+ max_value = val;
+
+ ++n;
+ }
+
+ T current_n (
+ ) const
+ {
+ return n;
+ }
+
+ T mean (
+ ) const
+ {
+ if (n != 0)
+ return sum/n;
+ else
+ return 0;
+ }
+
+ T max (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 0,
+ "\tT running_stats::max"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return max_value;
+ }
+
+ T min (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 0,
+ "\tT running_stats::min"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return min_value;
+ }
+
+ T variance (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_stats::variance"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1);
+ temp = temp*(sum_sqr - sum*sum/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T stddev (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_stats::stddev"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance());
+ }
+
+ T skewness (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 2,
+ "\tT running_stats::skewness"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/n;
+ T temp1 = std::sqrt(n*(n-1))/(n-2);
+ temp = temp1*temp*(sum_cub - 3*sum_sqr*sum*temp + 2*cubed(sum)*temp*temp)/
+ (std::sqrt(std::pow(temp*(sum_sqr-sum*sum*temp),3)));
+
+ return temp;
+ }
+
+ T ex_kurtosis (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 3,
+ "\tT running_stats::kurtosis"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/n;
+ T m4 = temp*(sum_four - 4*sum_cub*sum*temp+6*sum_sqr*sum*sum*temp*temp
+ -3*quaded(sum)*cubed(temp));
+ T m2 = temp*(sum_sqr-sum*sum*temp);
+ temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3));
+
+ return temp;
+ }
+
+ T scale (
+ const T& val
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_stats::variance"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+ return (val-mean())/std::sqrt(variance());
+ }
+
+ running_stats operator+ (
+ const running_stats& rhs
+ ) const
+ {
+ running_stats temp(*this);
+
+ temp.sum += rhs.sum;
+ temp.sum_sqr += rhs.sum_sqr;
+ temp.sum_cub += rhs.sum_cub;
+ temp.sum_four += rhs.sum_four;
+ temp.n += rhs.n;
+ temp.min_value = std::min(rhs.min_value, min_value);
+ temp.max_value = std::max(rhs.max_value, max_value);
+ return temp;
+ }
+
+ template <typename U>
+ friend void serialize (
+ const running_stats<U>& item,
+ std::ostream& out
+ );
+
+ template <typename U>
+ friend void deserialize (
+ running_stats<U>& item,
+ std::istream& in
+ );
+
+ private:
+ T sum;
+ T sum_sqr;
+ T sum_cub;
+ T sum_four;
+ T n;
+ T min_value;
+ T max_value;
+
+ T cubed (const T& val) const {return val*val*val; }
+ T quaded (const T& val) const {return val*val*val*val; }
+ };
+
+ template <typename T>
+ void serialize (
+ const running_stats<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 2;
+ serialize(version, out);
+
+ serialize(item.sum, out);
+ serialize(item.sum_sqr, out);
+ serialize(item.sum_cub, out);
+ serialize(item.sum_four, out);
+ serialize(item.n, out);
+ serialize(item.min_value, out);
+ serialize(item.max_value, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ running_stats<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 2)
+ throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object.");
+
+ deserialize(item.sum, in);
+ deserialize(item.sum_sqr, in);
+ deserialize(item.sum_cub, in);
+ deserialize(item.sum_four, in);
+ deserialize(item.n, in);
+ deserialize(item.min_value, in);
+ deserialize(item.max_value, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_scalar_covariance
+ {
+ public:
+
+ running_scalar_covariance()
+ {
+ clear();
+
+ COMPILE_TIME_ASSERT ((
+ is_same_type<float,T>::value ||
+ is_same_type<double,T>::value ||
+ is_same_type<long double,T>::value
+ ));
+ }
+
+ void clear()
+ {
+ sum_xy = 0;
+ sum_x = 0;
+ sum_y = 0;
+ sum_xx = 0;
+ sum_yy = 0;
+ n = 0;
+ }
+
+ void add (
+ const T& x,
+ const T& y
+ )
+ {
+ sum_xy += x*y;
+
+ sum_xx += x*x;
+ sum_yy += y*y;
+
+ sum_x += x;
+ sum_y += y;
+
+ n += 1;
+ }
+
+ T current_n (
+ ) const
+ {
+ return n;
+ }
+
+ T mean_x (
+ ) const
+ {
+ if (n != 0)
+ return sum_x/n;
+ else
+ return 0;
+ }
+
+ T mean_y (
+ ) const
+ {
+ if (n != 0)
+ return sum_y/n;
+ else
+ return 0;
+ }
+
+ T covariance (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::covariance()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return 1/(n-1) * (sum_xy - sum_y*sum_x/n);
+ }
+
+ T correlation (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::correlation()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return covariance() / std::sqrt(variance_x()*variance_y());
+ }
+
+ T variance_x (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::variance_x()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T variance_y (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::variance_y()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T stddev_x (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::stddev_x()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance_x());
+ }
+
+ T stddev_y (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance::stddev_y()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance_y());
+ }
+
+ running_scalar_covariance operator+ (
+ const running_scalar_covariance& rhs
+ ) const
+ {
+ running_scalar_covariance temp(rhs);
+
+ temp.sum_xy += sum_xy;
+ temp.sum_x += sum_x;
+ temp.sum_y += sum_y;
+ temp.sum_xx += sum_xx;
+ temp.sum_yy += sum_yy;
+ temp.n += n;
+ return temp;
+ }
+
+ private:
+
+ T sum_xy;
+ T sum_x;
+ T sum_y;
+ T sum_xx;
+ T sum_yy;
+ T n;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_scalar_covariance_decayed
+ {
+ public:
+
+ explicit running_scalar_covariance_decayed(
+ T decay_halflife = 1000
+ )
+ {
+ DLIB_ASSERT(decay_halflife > 0);
+
+ sum_xy = 0;
+ sum_x = 0;
+ sum_y = 0;
+ sum_xx = 0;
+ sum_yy = 0;
+ forget = std::pow(0.5, 1/decay_halflife);
+ n = 0;
+
+ COMPILE_TIME_ASSERT ((
+ is_same_type<float,T>::value ||
+ is_same_type<double,T>::value ||
+ is_same_type<long double,T>::value
+ ));
+ }
+
+ T forget_factor (
+ ) const
+ {
+ return forget;
+ }
+
+ void add (
+ const T& x,
+ const T& y
+ )
+ {
+ sum_xy = sum_xy*forget + x*y;
+
+ sum_xx = sum_xx*forget + x*x;
+ sum_yy = sum_yy*forget + y*y;
+
+ sum_x = sum_x*forget + x;
+ sum_y = sum_y*forget + y;
+
+ n = n*forget + 1;
+ }
+
+ T current_n (
+ ) const
+ {
+ return n;
+ }
+
+ T mean_x (
+ ) const
+ {
+ if (n != 0)
+ return sum_x/n;
+ else
+ return 0;
+ }
+
+ T mean_y (
+ ) const
+ {
+ if (n != 0)
+ return sum_y/n;
+ else
+ return 0;
+ }
+
+ T covariance (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::covariance()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return 1/(n-1) * (sum_xy - sum_y*sum_x/n);
+ }
+
+ T correlation (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::correlation()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = std::sqrt(variance_x()*variance_y());
+ if (temp != 0)
+ return covariance() / temp;
+ else
+ return 0; // just say it's zero if there isn't any variance in x or y.
+ }
+
+ T variance_x (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::variance_x()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T variance_y (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::variance_y()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T stddev_x (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::stddev_x()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance_x());
+ }
+
+ T stddev_y (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_scalar_covariance_decayed::stddev_y()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance_y());
+ }
+
+ private:
+
+ T sum_xy;
+ T sum_x;
+ T sum_y;
+ T sum_xx;
+ T sum_yy;
+ T n;
+ T forget;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_stats_decayed
+ {
+ public:
+
+ explicit running_stats_decayed(
+ T decay_halflife = 1000
+ )
+ {
+ DLIB_ASSERT(decay_halflife > 0);
+
+ sum_x = 0;
+ sum_xx = 0;
+ forget = std::pow(0.5, 1/decay_halflife);
+ n = 0;
+
+ COMPILE_TIME_ASSERT ((
+ is_same_type<float,T>::value ||
+ is_same_type<double,T>::value ||
+ is_same_type<long double,T>::value
+ ));
+ }
+
+ T forget_factor (
+ ) const
+ {
+ return forget;
+ }
+
+ void add (
+ const T& x
+ )
+ {
+
+ sum_xx = sum_xx*forget + x*x;
+
+ sum_x = sum_x*forget + x;
+
+ n = n*forget + 1;
+ }
+
+ T current_n (
+ ) const
+ {
+ return n;
+ }
+
+ T mean (
+ ) const
+ {
+ if (n != 0)
+ return sum_x/n;
+ else
+ return 0;
+ }
+
+ T variance (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_stats_decayed::variance()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
+ // make sure the variance is never negative. This might
+ // happen due to numerical errors.
+ if (temp >= 0)
+ return temp;
+ else
+ return 0;
+ }
+
+ T stddev (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(current_n() > 1,
+ "\tT running_stats_decayed::stddev()"
+ << "\n\tyou have to add some numbers to this object first"
+ << "\n\tthis: " << this
+ );
+
+ return std::sqrt(variance());
+ }
+
+ template <typename U>
+ friend void serialize (
+ const running_stats_decayed<U>& item,
+ std::ostream& out
+ );
+
+ template <typename U>
+ friend void deserialize (
+ running_stats_decayed<U>& item,
+ std::istream& in
+ );
+
+ private:
+
+ T sum_x;
+ T sum_xx;
+ T n;
+ T forget;
+ };
+
+ template <typename T>
+ void serialize (
+ const running_stats_decayed<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+
+ serialize(item.sum_x, out);
+ serialize(item.sum_xx, out);
+ serialize(item.n, out);
+ serialize(item.forget, out);
+ }
+
+ template <typename T>
+ void deserialize (
+ running_stats_decayed<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats_decayed object.");
+
+ deserialize(item.sum_x, in);
+ deserialize(item.sum_xx, in);
+ deserialize(item.n, in);
+ deserialize(item.forget, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double mean_sign_agreement (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a.size() == b.size(),
+ "\t double mean_sign_agreement(a,b)"
+ << "\n\t a and b must be the same length."
+ << "\n\t a.size(): " << a.size()
+ << "\n\t b.size(): " << b.size()
+ );
+
+
+ double temp = 0;
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ if ((a[i] >= 0 && b[i] >= 0) ||
+ (a[i] < 0 && b[i] < 0))
+ {
+ temp += 1;
+ }
+ }
+
+ return temp/a.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double correlation (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a.size() == b.size() && a.size() > 1,
+ "\t double correlation(a,b)"
+ << "\n\t a and b must be the same length and have more than one element."
+ << "\n\t a.size(): " << a.size()
+ << "\n\t b.size(): " << b.size()
+ );
+
+ running_scalar_covariance<double> rs;
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ rs.add(a[i], b[i]);
+ }
+ return rs.correlation();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double covariance (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a.size() == b.size() && a.size() > 1,
+ "\t double covariance(a,b)"
+ << "\n\t a and b must be the same length and have more than one element."
+ << "\n\t a.size(): " << a.size()
+ << "\n\t b.size(): " << b.size()
+ );
+
+ running_scalar_covariance<double> rs;
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ rs.add(a[i], b[i]);
+ }
+ return rs.covariance();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double r_squared (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a.size() == b.size() && a.size() > 1,
+ "\t double r_squared(a,b)"
+ << "\n\t a and b must be the same length and have more than one element."
+ << "\n\t a.size(): " << a.size()
+ << "\n\t b.size(): " << b.size()
+ );
+
+ return std::pow(correlation(a,b),2.0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double mean_squared_error (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(a.size() == b.size(),
+ "\t double mean_squared_error(a,b)"
+ << "\n\t a and b must be the same length."
+ << "\n\t a.size(): " << a.size()
+ << "\n\t b.size(): " << b.size()
+ );
+
+ return mean(squared(matrix_cast<double>(mat(a))-matrix_cast<double>(mat(b))));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class running_covariance
+ {
+ /*!
+ INITIAL VALUE
+ - vect_size == 0
+ - total_count == 0
+
+ CONVENTION
+ - vect_size == in_vector_size()
+ - total_count == current_n()
+
+ - if (total_count != 0)
+ - total_sum == the sum of all vectors given to add()
+ - the covariance of all the elements given to add() is given
+ by:
+ - let avg == total_sum/total_count
+ - covariance == total_cov/total_count - avg*trans(avg)
+ !*/
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ running_covariance(
+ )
+ {
+ clear();
+ }
+
+ void clear(
+ )
+ {
+ total_count = 0;
+
+ vect_size = 0;
+
+ total_sum.set_size(0);
+ total_cov.set_size(0,0);
+ }
+
+ long in_vector_size (
+ ) const
+ {
+ return vect_size;
+ }
+
+ long current_n (
+ ) const
+ {
+ return static_cast<long>(total_count);
+ }
+
+ void set_dimension (
+ long size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( size > 0,
+ "\t void running_covariance::set_dimension()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t size: " << size
+ << "\n\t this: " << this
+ );
+
+ clear();
+ vect_size = size;
+ total_sum.set_size(size);
+ total_cov.set_size(size,size);
+ total_sum = 0;
+ total_cov = 0;
+ }
+
+ template <typename T>
+ typename disable_if<is_matrix<T> >::type add (
+ const T& val
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(((long)max_index_plus_one(val) <= in_vector_size() && in_vector_size() > 0),
+ "\t void running_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(val): " << max_index_plus_one(val)
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ for (typename T::const_iterator i = val.begin(); i != val.end(); ++i)
+ {
+ total_sum(i->first) += i->second;
+ for (typename T::const_iterator j = val.begin(); j != val.end(); ++j)
+ {
+ total_cov(i->first, j->first) += i->second*j->second;
+ }
+ }
+
+ ++total_count;
+ }
+
+ template <typename T>
+ typename enable_if<is_matrix<T> >::type add (
+ const T& val
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(val) && (in_vector_size() == 0 || val.size() == in_vector_size()),
+ "\t void running_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(val): " << is_col_vector(val)
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t val.size(): " << val.size()
+ << "\n\t this: " << this
+ );
+
+ vect_size = val.size();
+ if (total_count == 0)
+ {
+ total_cov = val*trans(val);
+ total_sum = val;
+ }
+ else
+ {
+ total_cov += val*trans(val);
+ total_sum += val;
+ }
+ ++total_count;
+ }
+
+ const column_matrix mean (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( in_vector_size() != 0,
+ "\t running_covariance::mean()"
+ << "\n\t This object can not execute this function in its current state."
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t current_n(): " << current_n()
+ << "\n\t this: " << this
+ );
+
+ return total_sum/total_count;
+ }
+
+ const general_matrix covariance (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( in_vector_size() != 0 && current_n() > 1,
+ "\t running_covariance::covariance()"
+ << "\n\t This object can not execute this function in its current state."
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t current_n(): " << current_n()
+ << "\n\t this: " << this
+ );
+
+ return (total_cov - total_sum*trans(total_sum)/total_count)/(total_count-1);
+ }
+
+ const running_covariance operator+ (
+ const running_covariance& item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT((in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()),
+ "\t running_covariance running_covariance::operator+()"
+ << "\n\t The two running_covariance objects being added must have compatible parameters"
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t item.in_vector_size(): " << item.in_vector_size()
+ << "\n\t this: " << this
+ );
+
+ running_covariance temp(item);
+
+ // make sure we ignore empty matrices
+ if (total_count != 0 && temp.total_count != 0)
+ {
+ temp.total_cov += total_cov;
+ temp.total_sum += total_sum;
+ temp.total_count += total_count;
+ }
+ else if (total_count != 0)
+ {
+ temp.total_cov = total_cov;
+ temp.total_sum = total_sum;
+ temp.total_count = total_count;
+ }
+
+ return temp;
+ }
+
+
+ private:
+
+ general_matrix total_cov;
+ column_matrix total_sum;
+ scalar_type total_count;
+
+ long vect_size;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class running_cross_covariance
+ {
+ /*!
+ INITIAL VALUE
+ - x_vect_size == 0
+ - y_vect_size == 0
+ - total_count == 0
+
+ CONVENTION
+ - x_vect_size == x_vector_size()
+ - y_vect_size == y_vector_size()
+ - total_count == current_n()
+
+ - if (total_count != 0)
+ - sum_x == the sum of all x vectors given to add()
+ - sum_y == the sum of all y vectors given to add()
+ - total_cov == sum of all x*trans(y) given to add()
+ !*/
+
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ running_cross_covariance(
+ )
+ {
+ clear();
+ }
+
+ void clear(
+ )
+ {
+ total_count = 0;
+
+ x_vect_size = 0;
+ y_vect_size = 0;
+
+ sum_x.set_size(0);
+ sum_y.set_size(0);
+ total_cov.set_size(0,0);
+ }
+
+ long x_vector_size (
+ ) const
+ {
+ return x_vect_size;
+ }
+
+ long y_vector_size (
+ ) const
+ {
+ return y_vect_size;
+ }
+
+ void set_dimensions (
+ long x_size,
+ long y_size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( x_size > 0 && y_size > 0,
+ "\t void running_cross_covariance::set_dimensions()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t x_size: " << x_size
+ << "\n\t y_size: " << y_size
+ << "\n\t this: " << this
+ );
+
+ clear();
+ x_vect_size = x_size;
+ y_vect_size = y_size;
+ sum_x.set_size(x_size);
+ sum_y.set_size(y_size);
+ total_cov.set_size(x_size,y_size);
+
+ sum_x = 0;
+ sum_y = 0;
+ total_cov = 0;
+ }
+
+ long current_n (
+ ) const
+ {
+ return static_cast<long>(total_count);
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<!is_matrix<T>::value && !is_matrix<U>::value>::type add (
+ const T& x,
+ const U& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( ((long)max_index_plus_one(x) <= x_vector_size() && x_vector_size() > 0) &&
+ ((long)max_index_plus_one(y) <= y_vector_size() && y_vector_size() > 0) ,
+ "\t void running_cross_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(x): " << max_index_plus_one(x)
+ << "\n\t max_index_plus_one(y): " << max_index_plus_one(y)
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t this: " << this
+ );
+
+ for (typename T::const_iterator i = x.begin(); i != x.end(); ++i)
+ {
+ sum_x(i->first) += i->second;
+ for (typename U::const_iterator j = y.begin(); j != y.end(); ++j)
+ {
+ total_cov(i->first, j->first) += i->second*j->second;
+ }
+ }
+
+ // do sum_y += y
+ for (typename U::const_iterator j = y.begin(); j != y.end(); ++j)
+ {
+ sum_y(j->first) += j->second;
+ }
+
+ ++total_count;
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<is_matrix<T>::value && !is_matrix<U>::value>::type add (
+ const T& x,
+ const U& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( (is_col_vector(x) && x.size() == x_vector_size() && x_vector_size() > 0) &&
+ ((long)max_index_plus_one(y) <= y_vector_size() && y_vector_size() > 0) ,
+ "\t void running_cross_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t max_index_plus_one(y): " << max_index_plus_one(y)
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t this: " << this
+ );
+
+ sum_x += x;
+
+ for (long i = 0; i < x.size(); ++i)
+ {
+ for (typename U::const_iterator j = y.begin(); j != y.end(); ++j)
+ {
+ total_cov(i, j->first) += x(i)*j->second;
+ }
+ }
+
+ // do sum_y += y
+ for (typename U::const_iterator j = y.begin(); j != y.end(); ++j)
+ {
+ sum_y(j->first) += j->second;
+ }
+
+ ++total_count;
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<!is_matrix<T>::value && is_matrix<U>::value>::type add (
+ const T& x,
+ const U& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( ((long)max_index_plus_one(x) <= x_vector_size() && x_vector_size() > 0) &&
+ (is_col_vector(y) && y.size() == (long)y_vector_size() && y_vector_size() > 0) ,
+ "\t void running_cross_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(x): " << max_index_plus_one(x)
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t y.size(): " << y.size()
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t this: " << this
+ );
+
+ for (typename T::const_iterator i = x.begin(); i != x.end(); ++i)
+ {
+ sum_x(i->first) += i->second;
+ for (long j = 0; j < y.size(); ++j)
+ {
+ total_cov(i->first, j) += i->second*y(j);
+ }
+ }
+
+ sum_y += y;
+
+ ++total_count;
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<is_matrix<T>::value && is_matrix<U>::value>::type add (
+ const T& x,
+ const U& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && (x_vector_size() == 0 || x.size() == x_vector_size()) &&
+ is_col_vector(y) && (y_vector_size() == 0 || y.size() == y_vector_size()) &&
+ x.size() != 0 &&
+ y.size() != 0,
+ "\t void running_cross_covariance::add()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t x.size(): " << x.size()
+ << "\n\t is_col_vector(y): " << is_col_vector(y)
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t y.size(): " << y.size()
+ << "\n\t this: " << this
+ );
+
+ x_vect_size = x.size();
+ y_vect_size = y.size();
+ if (total_count == 0)
+ {
+ total_cov = x*trans(y);
+ sum_x = x;
+ sum_y = y;
+ }
+ else
+ {
+ total_cov += x*trans(y);
+ sum_x += x;
+ sum_y += y;
+ }
+ ++total_count;
+ }
+
+ const column_matrix mean_x (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( current_n() != 0,
+ "\t running_cross_covariance::mean()"
+ << "\n\t This object can not execute this function in its current state."
+ << "\n\t current_n(): " << current_n()
+ << "\n\t this: " << this
+ );
+
+ return sum_x/total_count;
+ }
+
+ const column_matrix mean_y (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( current_n() != 0,
+ "\t running_cross_covariance::mean()"
+ << "\n\t This object can not execute this function in its current state."
+ << "\n\t current_n(): " << current_n()
+ << "\n\t this: " << this
+ );
+
+ return sum_y/total_count;
+ }
+
+ const general_matrix covariance_xy (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( current_n() > 1,
+ "\t running_cross_covariance::covariance()"
+ << "\n\t This object can not execute this function in its current state."
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t current_n(): " << current_n()
+ << "\n\t this: " << this
+ );
+
+ return (total_cov - sum_x*trans(sum_y)/total_count)/(total_count-1);
+ }
+
+ const running_cross_covariance operator+ (
+ const running_cross_covariance& item
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT((x_vector_size() == 0 || item.x_vector_size() == 0 || x_vector_size() == item.x_vector_size()) &&
+ (y_vector_size() == 0 || item.y_vector_size() == 0 || y_vector_size() == item.y_vector_size()),
+ "\t running_cross_covariance running_cross_covariance::operator+()"
+ << "\n\t The two running_cross_covariance objects being added must have compatible parameters"
+ << "\n\t x_vector_size(): " << x_vector_size()
+ << "\n\t item.x_vector_size(): " << item.x_vector_size()
+ << "\n\t y_vector_size(): " << y_vector_size()
+ << "\n\t item.y_vector_size(): " << item.y_vector_size()
+ << "\n\t this: " << this
+ );
+
+ running_cross_covariance temp(item);
+
+ // make sure we ignore empty matrices
+ if (total_count != 0 && temp.total_count != 0)
+ {
+ temp.total_cov += total_cov;
+ temp.sum_x += sum_x;
+ temp.sum_y += sum_y;
+ temp.total_count += total_count;
+ }
+ else if (total_count != 0)
+ {
+ temp.total_cov = total_cov;
+ temp.sum_x = sum_x;
+ temp.sum_y = sum_y;
+ temp.total_count = total_count;
+ }
+
+ return temp;
+ }
+
+
+ private:
+
+ general_matrix total_cov;
+ column_matrix sum_x;
+ column_matrix sum_y;
+ scalar_type total_count;
+
+ long x_vect_size;
+ long y_vect_size;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer
+ {
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix_type result_type;
+
+ template <typename vector_type>
+ void train (
+ const vector_type& samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0,
+ "\tvoid vector_normalizer::train()"
+ << "\n\tyou have to give a nonempty set of samples to this function"
+ << "\n\tthis: " << this
+ );
+
+ m = mean(mat(samples));
+ sd = reciprocal(sqrt(variance(mat(samples))));
+
+ DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer::train() have infinite or NaN values");
+ }
+
+ long in_vector_size (
+ ) const
+ {
+ return m.nr();
+ }
+
+ long out_vector_size (
+ ) const
+ {
+ return m.nr();
+ }
+
+ const matrix_type& means (
+ ) const
+ {
+ return m;
+ }
+
+ const matrix_type& std_devs (
+ ) const
+ {
+ return sd;
+ }
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.nr() == in_vector_size() && x.nc() == 1,
+ "\tmatrix vector_normalizer::operator()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t this: " << this
+ );
+
+ temp_out = pointwise_multiply(x-m, sd);
+ return temp_out;
+ }
+
+ void swap (
+ vector_normalizer& item
+ )
+ {
+ m.swap(item.m);
+ sd.swap(item.sd);
+ temp_out.swap(item.temp_out);
+ }
+
+ template <typename mt>
+ friend void deserialize (
+ vector_normalizer<mt>& item,
+ std::istream& in
+ );
+
+ template <typename mt>
+ friend void serialize (
+ const vector_normalizer<mt>& item,
+ std::ostream& out
+ );
+
+ private:
+
+ // ------------------- private data members -------------------
+
+ matrix_type m, sd;
+
+ // This is just a temporary variable that doesn't contribute to the
+ // state of this object.
+ mutable matrix_type temp_out;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ vector_normalizer<matrix_type>& a,
+ vector_normalizer<matrix_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void deserialize (
+ vector_normalizer<matrix_type>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.m, in);
+ deserialize(item.sd, in);
+ // Keep deserializing the pca matrix for backwards compatibility.
+ matrix<double> pca;
+ deserialize(pca, in);
+
+ if (pca.size() != 0)
+ throw serialization_error("Error deserializing object of type vector_normalizer\n"
+ "It looks like a serialized vector_normalizer_pca was accidentally deserialized into \n"
+ "a vector_normalizer object.");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void serialize (
+ const vector_normalizer<matrix_type>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.m, out);
+ serialize(item.sd, out);
+ // Keep serializing the pca matrix for backwards compatibility.
+ matrix<double> pca;
+ serialize(pca, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer_pca
+ {
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix<scalar_type,0,1,mem_manager_type> result_type;
+
+ template <typename vector_type>
+ void train (
+ const vector_type& samples,
+ const double eps = 0.99
+ )
+ {
+ // You are getting an error here because you are trying to apply PCA
+ // to a vector of fixed length. But PCA is going to try and do
+ // dimensionality reduction so you can't use a vector with a fixed dimension.
+ COMPILE_TIME_ASSERT(matrix_type::NR == 0);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0,
+ "\tvoid vector_normalizer_pca::train_pca()"
+ << "\n\tyou have to give a nonempty set of samples to this function"
+ << "\n\tthis: " << this
+ );
+ DLIB_ASSERT(0 < eps && eps <= 1,
+ "\tvoid vector_normalizer_pca::train_pca()"
+ << "\n\tyou have to give a nonempty set of samples to this function"
+ << "\n\tthis: " << this
+ );
+ train_pca_impl(mat(samples),eps);
+
+ DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer_pca::train() have infinite or NaN values");
+ }
+
+ long in_vector_size (
+ ) const
+ {
+ return m.nr();
+ }
+
+ long out_vector_size (
+ ) const
+ {
+ return pca.nr();
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& means (
+ ) const
+ {
+ return m;
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& std_devs (
+ ) const
+ {
+ return sd;
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& pca_matrix (
+ ) const
+ {
+ return pca;
+ }
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.nr() == in_vector_size() && x.nc() == 1,
+ "\tmatrix vector_normalizer_pca::operator()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t this: " << this
+ );
+
+ // If we have a pca transform matrix on hand then
+ // also apply that.
+ temp_out = pca*pointwise_multiply(x-m, sd);
+
+ return temp_out;
+ }
+
+ void swap (
+ vector_normalizer_pca& item
+ )
+ {
+ m.swap(item.m);
+ sd.swap(item.sd);
+ pca.swap(item.pca);
+ temp_out.swap(item.temp_out);
+ }
+
+ template <typename T>
+ friend void deserialize (
+ vector_normalizer_pca<T>& item,
+ std::istream& in
+ );
+
+ template <typename T>
+ friend void serialize (
+ const vector_normalizer_pca<T>& item,
+ std::ostream& out
+ );
+
+ private:
+
+ template <typename mat_type>
+ void train_pca_impl (
+ const mat_type& samples,
+ const double eps
+ )
+ {
+ m = mean(samples);
+ sd = reciprocal(sqrt(variance(samples)));
+
+ // fill x with the normalized version of the input samples
+ matrix<typename mat_type::type,0,1,mem_manager_type> x(samples);
+ for (long r = 0; r < x.size(); ++r)
+ x(r) = pointwise_multiply(x(r)-m, sd);
+
+ matrix<scalar_type,0,0,mem_manager_type> temp, eigen;
+ matrix<scalar_type,0,1,mem_manager_type> eigenvalues;
+
+ // Compute the svd of the covariance matrix of the normalized inputs
+ svd(covariance(x), temp, eigen, pca);
+ eigenvalues = diag(eigen);
+
+ rsort_columns(pca, eigenvalues);
+
+ // figure out how many eigenvectors we want in our pca matrix
+ const double thresh = sum(eigenvalues)*eps;
+ long num_vectors = 0;
+ double total = 0;
+ for (long r = 0; r < eigenvalues.size() && total < thresh; ++r)
+ {
+ ++num_vectors;
+ total += eigenvalues(r);
+ }
+
+ // So now we know we want to use num_vectors of the first eigenvectors. So
+ // pull those out and discard the rest.
+ pca = trans(colm(pca,range(0,num_vectors-1)));
+
+ // Apply the pca transform to the data in x. Then we will normalize the
+ // pca matrix below.
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ x(r) = pca*x(r);
+ }
+
+ // Here we just scale the output features from the pca transform so
+ // that the variance of each feature is 1. So this doesn't really change
+ // what the pca is doing, it just makes sure the output features are
+ // normalized.
+ pca = trans(scale_columns(trans(pca), reciprocal(sqrt(variance(x)))));
+ }
+
+
+ // ------------------- private data members -------------------
+
+ matrix<scalar_type,0,1,mem_manager_type> m, sd;
+ matrix<scalar_type,0,0,mem_manager_type> pca;
+
+ // This is just a temporary variable that doesn't contribute to the
+ // state of this object.
+ mutable result_type temp_out;
+ };
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ vector_normalizer_pca<matrix_type>& a,
+ vector_normalizer_pca<matrix_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void deserialize (
+ vector_normalizer_pca<matrix_type>& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.m, in);
+ deserialize(item.sd, in);
+ deserialize(item.pca, in);
+ if (item.pca.nc() != item.m.nr())
+ throw serialization_error("Error deserializing object of type vector_normalizer_pca\n"
+ "It looks like a serialized vector_normalizer was accidentally deserialized into \n"
+ "a vector_normalizer_pca object.");
+ }
+
+ template <
+ typename matrix_type
+ >
+ void serialize (
+ const vector_normalizer_pca<matrix_type>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.m, out);
+ serialize(item.sd, out);
+ serialize(item.pca, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double binomial_random_vars_are_different (
+ uint64_t k1,
+ uint64_t n1,
+ uint64_t k2,
+ uint64_t n2
+ )
+ {
+ DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1);
+ DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2);
+
+ const double p1 = k1/(double)n1;
+ const double p2 = k2/(double)n2;
+ const double p = (k1+k2)/(double)(n1+n2);
+
+ auto ll = [](double p, uint64_t k, uint64_t n) {
+ if (p == 0 || p == 1)
+ return 0.0;
+ return k*std::log(p) + (n-k)*std::log(1-p);
+ };
+
+ auto logll = ll(p1,k1,n1) + ll(p2,k2,n2) - ll(p,k1,n1) - ll(p,k2,n2);
+
+ // The basic statistic only tells you if the random variables are different. But
+ // it's nice to know which way they are different, i.e., which one is bigger. So
+ // stuff that information into the sign bit of the return value.
+ if (p1>=p2)
+ return logll;
+ else
+ return -logll;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double event_correlation (
+ uint64_t A_count,
+ uint64_t B_count,
+ uint64_t AB_count,
+ uint64_t total_num_observations
+ )
+ {
+ DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations,
+ "AB_count: " << AB_count << ", A_count: "<< A_count << ", total_num_observations: " << total_num_observations);
+ DLIB_ASSERT(AB_count <= B_count && B_count <= total_num_observations,
+ "AB_count: " << AB_count << ", B_count: "<< B_count << ", total_num_observations: " << total_num_observations);
+
+ if (total_num_observations == 0)
+ return 0;
+
+ DLIB_ASSERT(A_count + B_count - AB_count <= total_num_observations,
+ "AB_count: " << AB_count << " A_count: " << A_count << ", B_count: "<< B_count << ", total_num_observations: " << total_num_observations);
+
+
+ const auto AnotB_count = A_count - AB_count;
+ const auto notB_count = total_num_observations - B_count;
+ // How likely is it that the odds of A happening is different when conditioned on
+ // whether or not B happened?
+ return binomial_random_vars_are_different(
+ AB_count, B_count, // A conditional on the presence of B
+ AnotB_count, notB_count // A conditional on the absence of B
+ );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATISTICs_
+
diff --git a/ml/dlib/dlib/statistics/statistics_abstract.h b/ml/dlib/dlib/statistics/statistics_abstract.h
new file mode 100644
index 000000000..ef8f13802
--- /dev/null
+++ b/ml/dlib/dlib/statistics/statistics_abstract.h
@@ -0,0 +1,1387 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net), Steve Taylor
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STATISTICs_ABSTRACT_
+#ifdef DLIB_STATISTICs_ABSTRACT_
+
+#include <limits>
+#include <cmath>
+#include "../matrix/matrix_abstract.h"
+#include "../svm/sparse_vector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double mean_sign_agreement (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ ensures
+ - returns the number of times a[i] has the same sign as b[i] divided by
+ a.size(). So we return the probability that elements of a and b have
+ the same sign.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double correlation (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ - a.size() > 1
+ ensures
+ - returns the correlation coefficient between all the elements of a and b.
+ (i.e. how correlated is a(i) with b(i))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double covariance (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ - a.size() > 1
+ ensures
+ - returns the covariance between all the elements of a and b.
+ (i.e. how does a(i) vary with b(i))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double r_squared (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ - a.size() > 1
+ ensures
+ - returns the R^2 coefficient of determination between all the elements of a and b.
+ This value is just the square of correlation(a,b).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename alloc
+ >
+ double mean_squared_error (
+ const std::vector<T,alloc>& a,
+ const std::vector<T,alloc>& b
+ );
+ /*!
+ requires
+ - a.size() == b.size()
+ ensures
+ - returns the mean squared error between all the elements of a and b.
+ (i.e. mean(squared(mat(a)-mat(b))))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double binomial_random_vars_are_different (
+ uint64_t k1,
+ uint64_t n1,
+ uint64_t k2,
+ uint64_t n2
+ );
+ /*!
+ requires
+ - k1 <= n1
+ - k2 <= n2
+ ensures
+ - Given two binomially distributed random variables, X1 and X2, we want to know
+ if these variables have the same parameter (i.e. the chance of "success").
+ So assume that:
+ - You observed X1 to give k1 successes out of n1 trials.
+ - You observed X2 to give k2 successes out of n2 trials.
+ - This function performs a simple likelihood ratio test to determine if X1 and
+ X2 have the same parameter. The return value of this function will be:
+ - Close to 0 if they are probably the same.
+ - Larger than 0 if X1 probably has a higher "success" rate than X2.
+ - Smaller than 0 if X2 probably has a higher "success" rate than X1.
+ Moreover, the larger the absolute magnitude of the return value the more
+ likely it is that X1 and X2 have different distributions.
+ - For a discussion of the technique and applications see:
+ Dunning, Ted. "Accurate methods for the statistics of surprise and
+ coincidence." Computational linguistics 19.1 (1993): 61-74.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ double event_correlation (
+ uint64_t A_count,
+ uint64_t B_count,
+ uint64_t AB_count,
+ uint64_t total_num_observations
+ );
+ /*!
+ requires
+ - AB_count <= A_count <= total_num_observations
+ - AB_count <= B_count <= total_num_observations
+ - A_count + B_count - AB_count <= total_num_observations
+ ensures
+ - This function does a statistical test to determine if two events co-occur in
+ a statistically significant way. In particular, we assume you performed
+ total_num_observations measurements and during those measurements you:
+ - Observed event A to happen A_count times.
+ - Observed event B to happen B_count times.
+ - Observed AB_count co-occurrences of the events. That is, AB_count is the
+ number of times the events happened together during the same measurement.
+ - This function returns a number, COR, which can take any real value. It has
+ the following interpretations:
+ - COR == 0: there is no evidence of correlation between the two events.
+ They appear to be unrelated.
+ - COR > 0: There is evidence that A and B co-occur together. That is,
+ they happen at the same times more often than you would expect if they
+ were independent events. The larger the magnitude of COR the more
+ evidence we have for the correlation.
+ - COR < 0: There is evidence that A and B are anti-correlated. That is,
+ when A happens B is unlikely to happen and vise versa. The larger the
+ magnitude of COR the more evidence we have for the anti-correlation.
+ - This function implements the simple likelihood ratio test discussed in the
+ following paper:
+ Dunning, Ted. "Accurate methods for the statistics of surprise and
+ coincidence." Computational linguistics 19.1 (1993): 61-74.
+ So for an extended discussion of the method see the above paper.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_stats
+ {
+ /*!
+ REQUIREMENTS ON T
+ - T must be a float, double, or long double type
+
+ INITIAL VALUE
+ - mean() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute the running mean,
+ variance, skewness, and excess kurtosis of a stream of real numbers.
+ !*/
+ public:
+
+ running_stats(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ - clears all memory of any previous data points
+ !*/
+
+ T current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the number of points given to this object so far.
+ !*/
+
+ void add (
+ const T& val
+ );
+ /*!
+ ensures
+ - updates the mean, variance, skewness, and kurtosis stored in this object
+ so that the new value is factored into them.
+ - #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1).
+ (i.e. the updated mean value that takes the new value into account)
+ - #variance() == the updated variance that takes this new value into account.
+ - #skewness() == the updated skewness that takes this new value into account.
+ - #ex_kurtosis() == the updated kurtosis that takes this new value into account.
+ - #current_n() == current_n() + 1
+ !*/
+
+ T mean (
+ ) const;
+ /*!
+ ensures
+ - returns the mean of all the values presented to this object
+ so far.
+ !*/
+
+ T variance (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample variance of all the values presented to this
+ object so far.
+ !*/
+
+ T stddev (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sampled standard deviation of all the values
+ presented to this object so far.
+ !*/
+
+ T skewness (
+ ) const;
+ /*!
+ requires
+ - current_n() > 2
+ ensures
+ - returns the unbiased sample skewness of all the values presented
+ to this object so far.
+ !*/
+
+ T ex_kurtosis(
+ ) const;
+ /*!
+ requires
+ - current_n() > 3
+ ensures
+ - returns the unbiased sample kurtosis of all the values presented
+ to this object so far.
+ !*/
+
+ T max (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the largest value presented to this object so far.
+ !*/
+
+ T min (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the smallest value presented to this object so far.
+ !*/
+
+ T scale (
+ const T& val
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - return (val-mean())/stddev();
+ !*/
+
+ running_stats operator+ (
+ const running_stats& rhs
+ ) const;
+ /*!
+ ensures
+ - returns a new running_stats object that represents the combination of all
+ the values given to *this and rhs. That is, this function returns a
+ running_stats object, R, that is equivalent to what you would obtain if
+ all calls to this->add() and rhs.add() had instead been done to R.
+ !*/
+ };
+
+ template <typename T>
+ void serialize (
+ const running_stats<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename T>
+ void deserialize (
+ running_stats<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_scalar_covariance
+ {
+ /*!
+ REQUIREMENTS ON T
+ - T must be a float, double, or long double type
+
+ INITIAL VALUE
+ - mean_x() == 0
+ - mean_y() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute the running covariance
+ of a stream of real number pairs.
+ !*/
+
+ public:
+
+ running_scalar_covariance(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ - clears all memory of any previous data points
+ !*/
+
+ void add (
+ const T& x,
+ const T& y
+ );
+ /*!
+ ensures
+ - updates the statistics stored in this object so that
+ the new pair (x,y) is factored into them.
+ - #current_n() == current_n() + 1
+ !*/
+
+ T current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the number of points given to this object so far.
+ !*/
+
+ T mean_x (
+ ) const;
+ /*!
+ ensures
+ - returns the mean value of all x samples presented to this object
+ via add().
+ !*/
+
+ T mean_y (
+ ) const;
+ /*!
+ ensures
+ - returns the mean value of all y samples presented to this object
+ via add().
+ !*/
+
+ T covariance (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the covariance between all the x and y samples presented
+ to this object via add()
+ !*/
+
+ T correlation (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the correlation coefficient between all the x and y samples
+ presented to this object via add()
+ !*/
+
+ T variance_x (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample variance value of all x samples presented
+ to this object via add().
+ !*/
+
+ T variance_y (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample variance value of all y samples presented
+ to this object via add().
+ !*/
+
+ T stddev_x (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample standard deviation of all x samples
+ presented to this object via add().
+ !*/
+
+ T stddev_y (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample standard deviation of all y samples
+ presented to this object via add().
+ !*/
+
+ running_scalar_covariance operator+ (
+ const running_covariance& rhs
+ ) const;
+ /*!
+ ensures
+ - returns a new running_scalar_covariance object that represents the
+ combination of all the values given to *this and rhs. That is, this
+ function returns a running_scalar_covariance object, R, that is
+ equivalent to what you would obtain if all calls to this->add() and
+ rhs.add() had instead been done to R.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_scalar_covariance_decayed
+ {
+ /*!
+ REQUIREMENTS ON T
+ - T must be a float, double, or long double type
+
+ INITIAL VALUE
+ - mean_x() == 0
+ - mean_y() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute the running covariance of
+ a stream of real number pairs. It is essentially the same as
+ running_scalar_covariance except that it forgets about data it has seen
+ after a certain period of time. It does this by exponentially decaying old
+ statistics.
+ !*/
+
+ public:
+
+ running_scalar_covariance_decayed(
+ T decay_halflife = 1000
+ );
+ /*!
+ requires
+ - decay_halflife > 0
+ ensures
+ - #forget_factor() == std::pow(0.5, 1/decay_halflife);
+ (i.e. after decay_halflife calls to add() the data given to the first add
+ will be down weighted by 0.5 in the statistics stored in this object).
+ !*/
+
+ T forget_factor (
+ ) const;
+ /*!
+ ensures
+ - returns the exponential forget factor used to forget old statistics when
+ add() is called.
+ !*/
+
+ void add (
+ const T& x,
+ const T& y
+ );
+ /*!
+ ensures
+ - updates the statistics stored in this object so that
+ the new pair (x,y) is factored into them.
+ - #current_n() == current_n()*forget_factor() + forget_factor()
+ - Down weights old statistics by a factor of forget_factor().
+ !*/
+
+ T current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the effective number of points given to this object. As add()
+ is called this value will converge to a constant, the value of which is
+ based on the decay_halflife supplied to the constructor.
+ !*/
+
+ T mean_x (
+ ) const;
+ /*!
+ ensures
+ - returns the mean value of all x samples presented to this object
+ via add().
+ !*/
+
+ T mean_y (
+ ) const;
+ /*!
+ ensures
+ - returns the mean value of all y samples presented to this object
+ via add().
+ !*/
+
+ T covariance (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the covariance between all the x and y samples presented
+ to this object via add()
+ !*/
+
+ T correlation (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the correlation coefficient between all the x and y samples
+ presented to this object via add()
+ !*/
+
+ T variance_x (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample variance value of all x samples presented
+ to this object via add().
+ !*/
+
+ T variance_y (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample variance value of all y samples presented
+ to this object via add().
+ !*/
+
+ T stddev_x (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample standard deviation of all x samples
+ presented to this object via add().
+ !*/
+
+ T stddev_y (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample standard deviation of all y samples
+ presented to this object via add().
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class running_stats_decayed
+ {
+ /*!
+ REQUIREMENTS ON T
+ - T must be a float, double, or long double type
+
+ INITIAL VALUE
+ - mean() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can compute the running mean and
+ variance of a stream of real numbers. It is similar to running_stats
+ except that it forgets about data it has seen after a certain period of
+ time. It does this by exponentially decaying old statistics.
+ !*/
+
+ public:
+
+ running_stats_decayed(
+ T decay_halflife = 1000
+ );
+ /*!
+ requires
+ - decay_halflife > 0
+ ensures
+ - #forget_factor() == std::pow(0.5, 1/decay_halflife);
+ (i.e. after decay_halflife calls to add() the data given to the first add
+ will be down weighted by 0.5 in the statistics stored in this object).
+ !*/
+
+ T forget_factor (
+ ) const;
+ /*!
+ ensures
+ - returns the exponential forget factor used to forget old statistics when
+ add() is called.
+ !*/
+
+ void add (
+ const T& x
+ );
+ /*!
+ ensures
+ - updates the statistics stored in this object so that x is factored into
+ them.
+ - #current_n() == current_n()*forget_factor() + forget_factor()
+ - Down weights old statistics by a factor of forget_factor().
+ !*/
+
+ T current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the effective number of points given to this object. As add()
+ is called this value will converge to a constant, the value of which is
+ based on the decay_halflife supplied to the constructor.
+ !*/
+
+ T mean (
+ ) const;
+ /*!
+ ensures
+ - returns the mean value of all x samples presented to this object
+ via add().
+ !*/
+
+ T variance (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample variance value of all x samples presented to this
+ object via add().
+ !*/
+
+ T stddev (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the sample standard deviation of all x samples presented to this
+ object via add().
+ !*/
+
+ };
+
+ template <typename T>
+ void serialize (
+ const running_stats_decayed<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <typename T>
+ void deserialize (
+ running_stats_decayed<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class running_covariance
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ INITIAL VALUE
+ - in_vector_size() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple tool for computing the mean and
+ covariance of a sequence of vectors.
+ !*/
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ running_covariance(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ - clears all memory of any previous data points
+ !*/
+
+ long current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the number of samples that have been presented to this object
+ !*/
+
+ long in_vector_size (
+ ) const;
+ /*!
+ ensures
+ - if (this object has been presented with any input vectors or
+ set_dimension() has been called) then
+ - returns the dimension of the column vectors used with this object
+ - else
+ - returns 0
+ !*/
+
+ void set_dimension (
+ long size
+ );
+ /*!
+ requires
+ - size > 0
+ ensures
+ - #in_vector_size() == size
+ - #current_n() == 0
+ !*/
+
+ template <typename T>
+ void add (
+ const T& val
+ );
+ /*!
+ requires
+ - val must represent a column vector. It can either be a dlib::matrix
+ object or some kind of unsorted sparse vector type. See the top of
+ dlib/svm/sparse_vector_abstract.h for a definition of unsorted sparse vector.
+ - val must have a number of dimensions which is compatible with the current
+ setting of in_vector_size(). In particular, this means that the
+ following must hold:
+ - if (val is a dlib::matrix) then
+ - in_vector_size() == 0 || val.size() == val_vector_size()
+ - else
+ - max_index_plus_one(val) <= in_vector_size()
+ - in_vector_size() > 0
+ (i.e. you must call set_dimension() prior to calling add() if
+ you want to use sparse vectors.)
+ ensures
+ - updates the mean and covariance stored in this object so that
+ the new value is factored into them.
+ - if (val is a dlib::matrix) then
+ - #in_vector_size() == val.size()
+ !*/
+
+ const column_matrix mean (
+ ) const;
+ /*!
+ requires
+ - in_vector_size() != 0
+ ensures
+ - returns the mean of all the vectors presented to this object
+ so far.
+ !*/
+
+ const general_matrix covariance (
+ ) const;
+ /*!
+ requires
+ - in_vector_size() != 0
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample covariance matrix for all the vectors
+ presented to this object so far.
+ !*/
+
+ const running_covariance operator+ (
+ const running_covariance& item
+ ) const;
+ /*!
+ requires
+ - in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()
+ (i.e. the in_vector_size() of *this and item must match or one must be zero)
+ ensures
+ - returns a new running_covariance object that represents the combination of all
+ the vectors given to *this and item. That is, this function returns a
+ running_covariance object, R, that is equivalent to what you would obtain if all
+ calls to this->add() and item.add() had instead been done to R.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class running_cross_covariance
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ Must be some type of dlib::matrix.
+
+ INITIAL VALUE
+ - x_vector_size() == 0
+ - y_vector_size() == 0
+ - current_n() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple tool for computing the mean and cross-covariance
+ matrices of a sequence of pairs of vectors.
+ !*/
+
+ public:
+
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef typename matrix_type::layout_type layout_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type,layout_type> general_matrix;
+ typedef matrix<scalar_type,0,1,mem_manager_type,layout_type> column_matrix;
+
+ running_cross_covariance(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - This object has its initial value.
+ - Clears all memory of any previous data points.
+ !*/
+
+ long x_vector_size (
+ ) const;
+ /*!
+ ensures
+ - if (this object has been presented with any input vectors or
+ set_dimensions() has been called) then
+ - returns the dimension of the x vectors given to this object via add().
+ - else
+ - returns 0
+ !*/
+
+ long y_vector_size (
+ ) const;
+ /*!
+ ensures
+ - if (this object has been presented with any input vectors or
+ set_dimensions() has been called) then
+ - returns the dimension of the y vectors given to this object via add().
+ - else
+ - returns 0
+ !*/
+
+ void set_dimensions (
+ long x_size,
+ long y_size
+ );
+ /*!
+ requires
+ - x_size > 0
+ - y_size > 0
+ ensures
+ - #x_vector_size() == x_size
+ - #y_vector_size() == y_size
+ - #current_n() == 0
+ !*/
+
+ long current_n (
+ ) const;
+ /*!
+ ensures
+ - returns the number of samples that have been presented to this object.
+ !*/
+
+ template <typename T, typename U>
+ void add (
+ const T& x,
+ const U& y
+ );
+ /*!
+ requires
+ - x and y must represent column vectors. They can either be dlib::matrix
+ objects or some kind of unsorted sparse vector type. See the top of
+ dlib/svm/sparse_vector_abstract.h for a definition of unsorted sparse vector.
+ - x and y must have a number of dimensions which is compatible with the
+ current setting of x_vector_size() and y_vector_size(). In particular,
+ this means that the following must hold:
+ - if (x or y is a sparse vector type) then
+ - x_vector_size() > 0 && y_vector_size() > 0
+ (i.e. you must call set_dimensions() prior to calling add() if
+ you want to use sparse vectors.)
+ - if (x is a dlib::matrix) then
+ - x_vector_size() == 0 || x.size() == x_vector_size()
+ - else
+ - max_index_plus_one(x) <= x_vector_size()
+ - if (y is a dlib::matrix) then
+ - y_vector_size() == 0 || y.size() == y_vector_size()
+ - else
+ - max_index_plus_one(y) <= y_vector_size()
+ ensures
+ - updates the mean and cross-covariance matrices stored in this object so
+ that the new (x,y) vector pair is factored into them.
+ - if (x is a dlib::matrix) then
+ - #x_vector_size() == x.size()
+ - if (y is a dlib::matrix) then
+ - #y_vector_size() == y.size()
+ !*/
+
+ const column_matrix mean_x (
+ ) const;
+ /*!
+ requires
+ - current_n() != 0
+ ensures
+ - returns the mean of all the x vectors presented to this object so far.
+ - The returned vector will have x_vector_size() dimensions.
+ !*/
+
+ const column_matrix mean_y (
+ ) const;
+ /*!
+ requires
+ - current_n() != 0
+ ensures
+ - returns the mean of all the y vectors presented to this object so far.
+ - The returned vector will have y_vector_size() dimensions.
+ !*/
+
+ const general_matrix covariance_xy (
+ ) const;
+ /*!
+ requires
+ - current_n() > 1
+ ensures
+ - returns the unbiased sample cross-covariance matrix for all the vector
+ pairs presented to this object so far. In particular, returns a matrix
+ M such that:
+ - M.nr() == x_vector_size()
+ - M.nc() == y_vector_size()
+ - M == the cross-covariance matrix of the data given to add().
+ !*/
+
+ const running_cross_covariance operator+ (
+ const running_cross_covariance& item
+ ) const;
+ /*!
+ requires
+ - x_vector_size() == 0 || item.x_vector_size() == 0 || x_vector_size() == item.x_vector_size()
+ (i.e. the x_vector_size() of *this and item must match or one must be zero)
+ - y_vector_size() == 0 || item.y_vector_size() == 0 || y_vector_size() == item.y_vector_size()
+ (i.e. the y_vector_size() of *this and item must match or one must be zero)
+ ensures
+ - returns a new running_cross_covariance object that represents the
+ combination of all the vectors given to *this and item. That is, this
+ function returns a running_cross_covariance object, R, that is equivalent
+ to what you would obtain if all calls to this->add() and item.add() had
+ instead been done to R.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ - must be a dlib::matrix object capable of representing column
+ vectors
+
+ INITIAL VALUE
+ - in_vector_size() == 0
+ - out_vector_size() == 0
+ - means().size() == 0
+ - std_devs().size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can learn to normalize a set
+ of column vectors. In particular, normalized column vectors should
+ have zero mean and a variance of one.
+
+ Also, if desired, this object can use principal component
+ analysis for the purposes of reducing the number of elements in a
+ vector.
+
+ THREAD SAFETY
+ Note that this object contains a cached matrix object it uses
+ to store intermediate results for normalization. This avoids
+ needing to reallocate it every time this object performs normalization
+ but also makes it non-thread safe. So make sure you don't share
+ instances of this object between threads.
+ !*/
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix_type result_type;
+
+ template <typename vector_type>
+ void train (
+ const vector_type& samples
+ );
+ /*!
+ requires
+ - samples.size() > 0
+ - samples == a column matrix or something convertible to a column
+ matrix via mat(). Also, x should contain
+ matrix_type objects that represent nonempty column vectors.
+ - samples does not contain any infinite or NaN values
+ ensures
+ - #in_vector_size() == samples(0).nr()
+ - #out_vector_size() == samples(0).nr()
+ - This object has learned how to normalize vectors that look like
+ vectors in the given set of samples.
+ - #means() == mean(samples)
+ - #std_devs() == reciprocal(sqrt(variance(samples)));
+ !*/
+
+ long in_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows that input vectors are
+ required to contain if they are to be normalized by
+ this object.
+ !*/
+
+ long out_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the normalized vectors
+ that come out of this object.
+ !*/
+
+ const matrix_type& means (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M.nc() == 1
+ - M.nr() == in_vector_size()
+ - M(i) == the mean of the ith input feature shown to train()
+ !*/
+
+ const matrix_type& std_devs (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix SD such that:
+ - SD.nc() == 1
+ - SD.nr() == in_vector_size()
+ - SD(i) == the reciprocal of the standard deviation of the ith
+ input feature shown to train()
+ !*/
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const;
+ /*!
+ requires
+ - x.nr() == in_vector_size()
+ - x.nc() == 1
+ ensures
+ - returns a normalized version of x, call it Z, that has the
+ following properties:
+ - Z.nr() == out_vector_size()
+ - Z.nc() == 1
+ - the mean of each element of Z is 0
+ - the variance of each element of Z is 1
+ - Z == pointwise_multiply(x-means(), std_devs());
+ !*/
+
+ void swap (
+ vector_normalizer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ vector_normalizer<matrix_type>& a,
+ vector_normalizer<matrix_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void deserialize (
+ vector_normalizer<matrix_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void serialize (
+ const vector_normalizer<matrix_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer_pca
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ - must be a dlib::matrix object capable of representing column
+ vectors
+
+ INITIAL VALUE
+ - in_vector_size() == 0
+ - out_vector_size() == 0
+ - means().size() == 0
+ - std_devs().size() == 0
+ - pca_matrix().size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents something that can learn to normalize a set
+ of column vectors. In particular, normalized column vectors should
+ have zero mean and a variance of one.
+
+ Also, this object uses principal component analysis for the purposes
+ of reducing the number of elements in a vector.
+
+ THREAD SAFETY
+ Note that this object contains a cached matrix object it uses
+ to store intermediate results for normalization. This avoids
+ needing to reallocate it every time this object performs normalization
+ but also makes it non-thread safe. So make sure you don't share
+ instances of this object between threads.
+ !*/
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix<scalar_type,0,1,mem_manager_type> result_type;
+
+ template <typename vector_type>
+ void train (
+ const vector_type& samples,
+ const double eps = 0.99
+ );
+ /*!
+ requires
+ - 0 < eps <= 1
+ - samples.size() > 0
+ - samples == a column matrix or something convertible to a column
+ matrix via mat(). Also, x should contain
+ matrix_type objects that represent nonempty column vectors.
+ - samples does not contain any infinite or NaN values
+ ensures
+ - This object has learned how to normalize vectors that look like
+ vectors in the given set of samples.
+ - Principal component analysis is performed to find a transform
+ that might reduce the number of output features.
+ - #in_vector_size() == samples(0).nr()
+ - 0 < #out_vector_size() <= samples(0).nr()
+ - eps is a number that controls how "lossy" the pca transform will be.
+ Large values of eps result in #out_vector_size() being larger and
+ smaller values of eps result in #out_vector_size() being smaller.
+ - #means() == mean(samples)
+ - #std_devs() == reciprocal(sqrt(variance(samples)));
+ - #pca_matrix() == the PCA transform matrix that is out_vector_size()
+ rows by in_vector_size() columns.
+ !*/
+
+ long in_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows that input vectors are
+ required to contain if they are to be normalized by
+ this object.
+ !*/
+
+ long out_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the normalized vectors
+ that come out of this object.
+ !*/
+
+ const matrix<scalar_type,0,1,mem_manager_type>& means (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - M.nc() == 1
+ - M.nr() == in_vector_size()
+ - M(i) == the mean of the ith input feature shown to train()
+ !*/
+
+ const matrix<scalar_type,0,1,mem_manager_type>& std_devs (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix SD such that:
+ - SD.nc() == 1
+ - SD.nr() == in_vector_size()
+ - SD(i) == the reciprocal of the standard deviation of the ith
+ input feature shown to train()
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& pca_matrix (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix PCA such that:
+ - PCA.nr() == out_vector_size()
+ - PCA.nc() == in_vector_size()
+ - PCA == the principal component analysis transformation
+ matrix
+ !*/
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const;
+ /*!
+ requires
+ - x.nr() == in_vector_size()
+ - x.nc() == 1
+ ensures
+ - returns a normalized version of x, call it Z, that has the
+ following properties:
+ - Z.nr() == out_vector_size()
+ - Z.nc() == 1
+ - the mean of each element of Z is 0
+ - the variance of each element of Z is 1
+ - Z == pca_matrix()*pointwise_multiply(x-means(), std_devs());
+ !*/
+
+ void swap (
+ vector_normalizer_pca& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <
+ typename matrix_type
+ >
+ inline void swap (
+ vector_normalizer_pca<matrix_type>& a,
+ vector_normalizer_pca<matrix_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void deserialize (
+ vector_normalizer_pca<matrix_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ template <
+ typename matrix_type,
+ >
+ void serialize (
+ const vector_normalizer_pca<matrix_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STATISTICs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h
new file mode 100644
index 000000000..690370f80
--- /dev/null
+++ b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h
@@ -0,0 +1,618 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_
+#define DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_
+
+#include "vector_normalizer_frobmetric_abstract.h"
+#include "../matrix.h"
+#include "../optimization.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ struct frobmetric_training_sample
+ {
+ matrix_type anchor_vect;
+ std::vector<matrix_type> near_vects;
+ std::vector<matrix_type> far_vects;
+
+ unsigned long num_triples (
+ ) const { return near_vects.size() * far_vects.size(); }
+
+ void clear()
+ {
+ near_vects.clear();
+ far_vects.clear();
+ }
+ };
+
+ template <
+ typename matrix_type
+ >
+ void serialize(const frobmetric_training_sample<matrix_type>& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.anchor_vect, out);
+ serialize(item.near_vects, out);
+ serialize(item.far_vects, out);
+ }
+
+ template <
+ typename matrix_type
+ >
+ void deserialize(frobmetric_training_sample<matrix_type>& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::frobmetric_training_sample.");
+ deserialize(item.anchor_vect, in);
+ deserialize(item.near_vects, in);
+ deserialize(item.far_vects, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer_frobmetric
+ {
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix_type result_type;
+
+ private:
+ struct compact_frobmetric_training_sample
+ {
+ std::vector<matrix_type> near_vects;
+ std::vector<matrix_type> far_vects;
+ };
+
+ struct objective
+ {
+ objective (
+ const std::vector<compact_frobmetric_training_sample>& samples_,
+ matrix<double,0,0,mem_manager_type>& Aminus_,
+ const matrix<double,0,1,mem_manager_type>& bias_
+ ) : samples(samples_), Aminus(Aminus_), bias(bias_) {}
+
+ double operator()(const matrix<double,0,1,mem_manager_type>& u) const
+ {
+ long idx = 0;
+ const long dims = samples[0].far_vects[0].size();
+ // Here we compute \hat A from the paper, which we refer to as just A in
+ // the code.
+ matrix<double,0,0,mem_manager_type> A(dims,dims);
+ A = 0;
+ std::vector<double> ufar, unear;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ ufar.assign(samples[i].far_vects.size(),0);
+ unear.assign(samples[i].near_vects.size(),0);
+ for (unsigned long j = 0; j < unear.size(); ++j)
+ {
+ for (unsigned long k = 0; k < ufar.size(); ++k)
+ {
+ const double val = u(idx++);
+ ufar[k] -= val;
+ unear[j] += val;
+ }
+ }
+ for (unsigned long j = 0; j < unear.size(); ++j)
+ A += unear[j]*samples[i].near_vects[j]*trans(samples[i].near_vects[j]);
+ for (unsigned long j = 0; j < ufar.size(); ++j)
+ A += ufar[j]*samples[i].far_vects[j]*trans(samples[i].far_vects[j]);
+ }
+
+ eigenvalue_decomposition<matrix<double,0,0,mem_manager_type> > ed(make_symmetric(A));
+ Aminus = ed.get_pseudo_v()*diagm(upperbound(ed.get_real_eigenvalues(),0))*trans(ed.get_pseudo_v());
+ // Do this to avoid numeric instability later on since the above
+ // computation can make Aminus slightly non-symmetric.
+ Aminus = make_symmetric(Aminus);
+
+ return dot(u,bias) - 0.5*sum(squared(Aminus));
+ }
+
+ private:
+ const std::vector<compact_frobmetric_training_sample>& samples;
+ matrix<double,0,0,mem_manager_type>& Aminus;
+ const matrix<double,0,1,mem_manager_type>& bias;
+ };
+
+ struct derivative
+ {
+ derivative (
+ unsigned long num_triples_,
+ const std::vector<compact_frobmetric_training_sample>& samples_,
+ matrix<double,0,0,mem_manager_type>& Aminus_,
+ const matrix<double,0,1,mem_manager_type>& bias_
+ ) : num_triples(num_triples_), samples(samples_), Aminus(Aminus_), bias(bias_) {}
+
+ matrix<double,0,1,mem_manager_type> operator()(const matrix<double,0,1,mem_manager_type>& ) const
+ {
+ // Note that Aminus is a function of u (the argument to this function), but
+ // since Aminus will have been computed already by the most recent call to
+ // the objective function we don't need to do anything with u. We can just
+ // use Aminus right away.
+ matrix<double,0,1,mem_manager_type> grad(num_triples);
+
+ long idx = 0;
+ std::vector<double> ufar, unear;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ ufar.resize(samples[i].far_vects.size());
+ unear.resize(samples[i].near_vects.size());
+
+ for (unsigned long j = 0; j < unear.size(); ++j)
+ unear[j] = sum(pointwise_multiply(Aminus, samples[i].near_vects[j]*trans(samples[i].near_vects[j])));
+ for (unsigned long j = 0; j < ufar.size(); ++j)
+ ufar[j] = sum(pointwise_multiply(Aminus, samples[i].far_vects[j]*trans(samples[i].far_vects[j])));
+
+ for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+ {
+ for (unsigned long k = 0; k < samples[i].far_vects.size(); ++k)
+ {
+ grad(idx) = bias(idx) + ufar[k]-unear[j];
+ idx++;
+ }
+ }
+ }
+
+ return grad;
+ }
+
+ private:
+ const unsigned long num_triples;
+ const std::vector<compact_frobmetric_training_sample>& samples;
+ matrix<double,0,0,mem_manager_type>& Aminus;
+ const matrix<double,0,1,mem_manager_type>& bias;
+ };
+
+
+ class custom_stop_strategy
+ {
+ public:
+ custom_stop_strategy(
+ double C_,
+ double eps_,
+ bool be_verbose_,
+ unsigned long max_iter_
+ )
+ {
+ _c = C_;
+
+ _cur_iter = 0;
+ _gradient_thresh = eps_;
+ _max_iter = max_iter_;
+ _verbose = be_verbose_;
+ }
+
+ template <typename T>
+ bool should_continue_search (
+ const T& u,
+ const double ,
+ const T& grad
+ )
+ {
+ ++_cur_iter;
+
+ double max_gradient = 0;
+ for (long i = 0; i < grad.size(); ++i)
+ {
+ const bool at_lower_bound = (0 >= u(i) && grad(i) > 0);
+ const bool at_upper_bound = (_c/grad.size() <= u(i) && grad(i) < 0);
+ if (!at_lower_bound && !at_upper_bound)
+ max_gradient = std::max(std::abs(grad(i)), max_gradient);
+ }
+
+ if (_verbose)
+ {
+ std::cout << "iteration: " << _cur_iter << " max_gradient: "<< max_gradient << std::endl;
+ }
+
+ // Only stop when the largest non-bound-constrained element of the gradient
+ // is lower than the threshold.
+ if (max_gradient < _gradient_thresh)
+ return false;
+
+ // Check if we have hit the max allowable number of iterations.
+ if (_cur_iter > _max_iter)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private:
+ bool _verbose;
+
+ unsigned long _max_iter;
+ unsigned long _cur_iter;
+ double _c;
+ double _gradient_thresh;
+ };
+
+ public:
+ vector_normalizer_frobmetric (
+ )
+ {
+ verbose = false;
+ eps = 0.1;
+ C = 1;
+ max_iter = 5000;
+ _use_identity_matrix_prior = false;
+ }
+
+ bool uses_identity_matrix_prior (
+ ) const
+ {
+ return _use_identity_matrix_prior;
+ }
+
+ void set_uses_identity_matrix_prior (
+ bool use_prior
+ )
+ {
+ _use_identity_matrix_prior = use_prior;
+ }
+
+ void be_verbose(
+ )
+ {
+ verbose = true;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void vector_normalizer_frobmetric::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps: " << eps_
+ );
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void vector_normalizer_frobmetric::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ void set_max_iterations (
+ unsigned long max_iterations
+ )
+ {
+ max_iter = max_iterations;
+ }
+
+ unsigned long get_max_iterations (
+ ) const
+ {
+ return max_iter;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void train (
+ const std::vector<frobmetric_training_sample<matrix_type> >& samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t you have to give a nonempty set of samples to this function"
+ );
+#ifdef ENABLE_ASSERTS
+ {
+ const long dims = samples[0].anchor_vect.size();
+ DLIB_ASSERT(dims != 0,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t The dimension of the input vectors can't be zero."
+ );
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ DLIB_ASSERT(is_col_vector(samples[i].anchor_vect),
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ );
+ DLIB_ASSERT(samples[i].anchor_vect.size() == dims,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t dims: " << dims
+ << "\n\t samples[i].anchor_vect.size(): " << samples[i].anchor_vect.size()
+ );
+
+ DLIB_ASSERT(samples[i].num_triples() != 0,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t It is illegal for a training sample to have no data in it"
+ << "\n\t i: " << i
+ );
+ for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+ {
+ DLIB_ASSERT(is_col_vector(samples[i].near_vects[j]),
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ );
+ DLIB_ASSERT(samples[i].near_vects[j].size() == dims,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t dims: " << dims
+ << "\n\t samples[i].near_vects[j].size(): " << samples[i].near_vects[j].size()
+ );
+ }
+ for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j)
+ {
+ DLIB_ASSERT(is_col_vector(samples[i].far_vects[j]),
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ );
+ DLIB_ASSERT(samples[i].far_vects[j].size() == dims,
+ "\tvoid vector_normalizer_frobmetric::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t dims: " << dims
+ << "\n\t samples[i].far_vects[j].size(): " << samples[i].far_vects[j].size()
+ );
+ }
+ }
+ }
+#endif // end ENABLE_ASSERTS
+
+
+ // compute the mean sample
+ m = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ m += samples[i].anchor_vect;
+ m /= samples.size();
+
+ DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer_frobmetric::train() have infinite or NaN values");
+
+ // Now we need to find tform. So we setup the optimization problem and run it
+ // over the next few lines of code.
+ unsigned long num_triples = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ num_triples += samples[i].near_vects.size()*samples[i].far_vects.size();
+
+ matrix<double,0,1,mem_manager_type> u(num_triples);
+ matrix<double,0,1,mem_manager_type> bias(num_triples);
+ u = 0;
+ bias = 1;
+
+
+ // precompute all the anchor_vect to far_vects/near_vects pairs
+ std::vector<compact_frobmetric_training_sample> data(samples.size());
+ unsigned long cnt = 0;
+ std::vector<double> far_norm, near_norm;
+ for (unsigned long i = 0; i < data.size(); ++i)
+ {
+ far_norm.clear();
+ near_norm.clear();
+ data[i].far_vects.reserve(samples[i].far_vects.size());
+ data[i].near_vects.reserve(samples[i].near_vects.size());
+ for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j)
+ {
+ data[i].far_vects.push_back(samples[i].anchor_vect - samples[i].far_vects[j]);
+ if (_use_identity_matrix_prior)
+ far_norm.push_back(length_squared(data[i].far_vects.back()));
+ }
+ for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+ {
+ data[i].near_vects.push_back(samples[i].anchor_vect - samples[i].near_vects[j]);
+ if (_use_identity_matrix_prior)
+ near_norm.push_back(length_squared(data[i].near_vects.back()));
+ }
+
+ // Note that this loop only executes if _use_identity_matrix_prior == true.
+ for (unsigned long j = 0; j < near_norm.size(); ++j)
+ {
+ for (unsigned long k = 0; k < far_norm.size(); ++k)
+ {
+ bias(cnt++) = 1 - (far_norm[k] - near_norm[j]);
+ }
+ }
+ }
+
+ // Now run the main part of the algorithm
+ matrix<double,0,0,mem_manager_type> Aminus;
+ find_max_box_constrained(lbfgs_search_strategy(10),
+ custom_stop_strategy(C, eps, verbose, max_iter),
+ objective(data, Aminus, bias),
+ derivative(num_triples, data, Aminus, bias),
+ u, 0, C/num_triples);
+
+
+ // What we need is the optimal Aminus which is a function of u. So we already
+ // have what we need and just need to put it into tform.
+ eigenvalue_decomposition<matrix<double,0,0,mem_manager_type> > ed(make_symmetric(-Aminus));
+ matrix<double,0,1,mem_manager_type> eigs = ed.get_real_eigenvalues();
+ // But first, discard the components that are zero to within the machine epsilon.
+ const double tol = max(eigs)*std::numeric_limits<double>::epsilon();
+ for (long i = 0; i < eigs.size(); ++i)
+ {
+ if (eigs(i) < tol)
+ eigs(i) = 0;
+ }
+ if (_use_identity_matrix_prior)
+ tform = matrix_cast<scalar_type>(identity_matrix(Aminus) + diagm(sqrt(eigs))*trans(ed.get_pseudo_v()));
+ else
+ tform = matrix_cast<scalar_type>(diagm(sqrt(eigs))*trans(ed.get_pseudo_v()));
+
+ // Pre-apply the transform to m so we don't have to do it inside operator()
+ // every time it's called.
+ m = tform*m;
+ }
+
+ long in_vector_size (
+ ) const
+ {
+ return m.nr();
+ }
+
+ long out_vector_size (
+ ) const
+ {
+ return m.nr();
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& transformed_means (
+ ) const
+ {
+ return m;
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& transform (
+ ) const
+ {
+ return tform;
+ }
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(in_vector_size() != 0 && in_vector_size() == x.size() &&
+ is_col_vector(x) == true,
+ "\tmatrix vector_normalizer_frobmetric::operator()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t in_vector_size(): " << in_vector_size()
+ << "\n\t x.size(): " << x.size()
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t this: " << this
+ );
+
+ temp_out = tform*x-m;
+ return temp_out;
+ }
+
+ template <typename mt>
+ friend void deserialize (
+ vector_normalizer_frobmetric<mt>& item,
+ std::istream& in
+ );
+
+ template <typename mt>
+ friend void serialize (
+ const vector_normalizer_frobmetric<mt>& item,
+ std::ostream& out
+ );
+
+ private:
+
+ // ------------------- private data members -------------------
+
+ matrix_type m;
+ matrix<scalar_type,0,0,mem_manager_type> tform;
+ bool verbose;
+ double eps;
+ double C;
+ unsigned long max_iter;
+ bool _use_identity_matrix_prior;
+
+ // This is just a temporary variable that doesn't contribute to the
+ // state of this object.
+ mutable matrix_type temp_out;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void serialize (
+ const vector_normalizer_frobmetric<matrix_type>& item,
+ std::ostream& out
+ )
+ {
+ const int version = 2;
+ serialize(version, out);
+
+ serialize(item.m, out);
+ serialize(item.tform, out);
+ serialize(item.verbose, out);
+ serialize(item.eps, out);
+ serialize(item.C, out);
+ serialize(item.max_iter, out);
+ serialize(item._use_identity_matrix_prior, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void deserialize (
+ vector_normalizer_frobmetric<matrix_type>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1 && version != 2)
+ throw serialization_error("Unsupported version found while deserializing dlib::vector_normalizer_frobmetric.");
+
+ deserialize(item.m, in);
+ deserialize(item.tform, in);
+ deserialize(item.verbose, in);
+ deserialize(item.eps, in);
+ deserialize(item.C, in);
+ deserialize(item.max_iter, in);
+ if (version == 2)
+ deserialize(item._use_identity_matrix_prior, in);
+ else
+ item._use_identity_matrix_prior = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_
+
diff --git a/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h
new file mode 100644
index 000000000..302628330
--- /dev/null
+++ b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h
@@ -0,0 +1,328 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_
+#ifdef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_
+
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ struct frobmetric_training_sample
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a training data sample for the
+ vector_normalizer_frobmetric object. It defines a set of training triplets
+ relative to a single anchor_vect vector. That is, it specifies that the
+ learned distance metric should satisfy num_triples() constraints which are,
+ for all valid i and j:
+ length(T*anchor_vect-T*near_vects[i]) + 1 < length(T*anchor_vect - T*far_vects[j])
+ for some appropriate linear transformation T which will be learned by
+ vector_normalizer_frobmetric.
+ !*/
+
+ matrix_type anchor_vect;
+ std::vector<matrix_type> near_vects;
+ std::vector<matrix_type> far_vects;
+
+ unsigned long num_triples (
+ ) const { return near_vects.size() * far_vects.size(); }
+ /*!
+ ensures
+ - returns the number of training triplets defined by this object.
+ !*/
+
+ void clear()
+ /*!
+ ensures
+ - #near_vects.size() == 0
+ - #far_vects.size() == 0
+ !*/
+ };
+
+ template < typename matrix_type >
+ void serialize(const frobmetric_training_sample<matrix_type>& item, std::ostream& out)
+ template < typename matrix_type >
+ void deserialize(frobmetric_training_sample<matrix_type>& item, std::istream& in)
+ /*!
+ provides serialisation support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ class vector_normalizer_frobmetric
+ {
+ /*!
+ REQUIREMENTS ON matrix_type
+ - must be a dlib::matrix object capable of representing column
+ vectors
+
+ INITIAL VALUE
+ - in_vector_size() == 0
+ - out_vector_size() == 0
+ - get_epsilon() == 0.1
+ - get_c() == 1
+ - get_max_iterations() == 5000
+ - This object is not verbose
+ - uses_identity_matrix_prior() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for performing the FrobMetric distance metric
+ learning algorithm described in the following paper:
+ A Scalable Dual Approach to Semidefinite Metric Learning
+ By Chunhua Shen, Junae Kim, Lei Wang, in CVPR 2011
+
+ Therefore, this object is a tool that takes as input training triplets
+ (anchor_vect, near, far) of vectors and attempts to learn a linear
+ transformation T such that:
+ length(T*anchor_vect-T*near) + 1 < length(T*anchor_vect - T*far)
+ That is, you give a bunch of anchor_vect vectors and for each anchor_vect
+ you specify some vectors which should be near to it and some that should be
+ far form it. This object then tries to find a transformation matrix that
+ makes the "near" vectors close to their anchors while the "far" vectors are
+ farther away.
+
+ THREAD SAFETY
+ Note that this object contains a cached matrix object it uses
+ to store intermediate results for normalization. This avoids
+ needing to reallocate it every time this object performs normalization
+ but also makes it non-thread safe. So make sure you don't share
+ instances of this object between threads.
+ !*/
+
+ public:
+ typedef typename matrix_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef matrix_type result_type;
+
+ vector_normalizer_frobmetric (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ bool uses_identity_matrix_prior (
+ ) const;
+ /*!
+ ensures
+ - Normally this object will try and find a matrix transform() that
+ minimizes sum(squared(transform())) but also fits the training data.
+ However, if #uses_identity_matrix_prior() == true then it will instead
+ try to find the transformation matrix that minimizes
+ sum(squared(identity_matrix()-transform())). That is, it will try to
+ find the matrix most similar to the identity matrix that best fits the
+ training data.
+ !*/
+
+ void set_uses_identity_matrix_prior (
+ bool use_prior
+ );
+ /*!
+ ensures
+ - #uses_identity_matrix_prior() == use_prior
+ !*/
+
+ void be_verbose(
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so the user can
+ observe the progress of the train() routine.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out.
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ execute.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #set_c() == C
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter that
+ determines the trade-off between trying to fit the training data exactly
+ or allowing more errors but hopefully improving the generalization of the
+ resulting distance metric. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iterations
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iterations
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - The train() routine uses an iterative numerical solver to find the best
+ distance metric. This function returns the maximum allowable number of
+ iterations it will use before terminating. Note that typically the
+ solver terminates prior to the max iteration count limit due to the error
+ dropping below get_epsilon().
+ !*/
+
+ void train (
+ const std::vector<frobmetric_training_sample<matrix_type> >& samples
+ );
+ /*!
+ requires
+ - samples.size() != 0
+ - All matrices inside samples (i.e. anchors and elements of near_vects and far_vects)
+ are column vectors with the same non-zero dimension.
+ - All the vectors in samples contain finite values.
+ - All elements of samples contain data, specifically, for all valid i:
+ - samples[i].num_triples() != 0
+ ensures
+ - learns a distance metric from the given training samples. After train
+ finishes you can use this object's operator() to transform vectors
+ according to the learned distance metric. In particular, we will have:
+ - #transform() == The linear transformation learned by the FrobMetric
+ learning procedure.
+ - #in_vector_size() == samples[0].anchor_vect.size()
+ - You can call (*this)(x) to transform a vector according to the learned
+ distance metric. That is, it should generally be the case that:
+ - length((*this)(anchor_vect) - (*this)(near)) + 1 < length((*this)(anchor_vect) - (*this)(far))
+ for the anchor_vect, near, and far vectors in the training data.
+ - #transformed_means() == the mean of the input anchor_vect vectors
+ after being transformed by #transform()
+ !*/
+
+ long in_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows that input vectors are required to contain if
+ they are to be normalized by this object.
+ !*/
+
+ long out_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of rows in the normalized vectors that come out of
+ this object.
+ - The value returned is always in_vector_size(). So out_vector_size() is
+ just provided to maintain interface consistency with other vector
+ normalizer objects. That is, the transformations applied by this object
+ do not change the dimensionality of the vectors.
+ !*/
+
+ const matrix<scalar_type,0,1,mem_manager_type>& transformed_means (
+ ) const;
+ /*!
+ ensures
+ - returns a column vector V such that:
+ - V.size() == in_vector_size()
+ - V is a vector such that subtracting it from transformed vectors
+ results in them having an expected value of 0. Therefore, it is
+ equal to transform() times the mean of the input anchor_vect vectors
+ given to train().
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& transform (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the transformation matrix we learned during the last
+ call to train().
+ - The returned matrix is square and has in_vector_size() by in_vector_size()
+ dimensions.
+ !*/
+
+ const result_type& operator() (
+ const matrix_type& x
+ ) const;
+ /*!
+ requires
+ - in_vector_size() != 0
+ - in_vector_size() == x.size()
+ - is_col_vector(x) == true
+ ensures
+ - returns a normalized version of x, call it Z, that has the following
+ properties:
+ - Z == The result of applying the linear transform we learned during
+ train() to the input vector x.
+ - Z == transform()*x-transformed_means()
+ - is_col_vector(Z) == true
+ - Z.size() == x.size()
+ - The expected value of each element of Z is 0.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void serialize (
+ const vector_normalizer_frobmetric<matrix_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type
+ >
+ void deserialize (
+ vector_normalizer_frobmetric<matrix_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/std_allocator.h b/ml/dlib/dlib/std_allocator.h
new file mode 100644
index 000000000..b6e411c12
--- /dev/null
+++ b/ml/dlib/dlib/std_allocator.h
@@ -0,0 +1,199 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STD_ALLOc_H_
+#define DLIB_STD_ALLOc_H_
+
+#include <limits>
+#include <memory>
+#include "enable_if.h"
+#include "algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename M
+ >
+ class std_allocator
+ {
+ /*!
+ REQUIREMENTS ON M
+ must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
+ must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
+ must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ M::type can be set to anything.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is an implementation of an allocator that conforms to the C++ standard
+ requirements for allocator objects. The M template argument is one of the dlib
+ memory manager objects and this allocator implementation will do all of its memory allocations
+ using whatever dlib memory manager you supply.
+
+ Thus, using this allocator object you can use any of the dlib memory manager objects with
+ the containers in the STL or with any other object that requires a C++ allocator object.
+
+ It is important to note that many STL implementations make the assumption that the memory
+ allocated by one allocator can be freed by another. This effectively means that you should
+ only use a global or stateless memory manager with the std_allocator. Either that or you
+ have to verify that your version of the STL isn't going to try and allocate and deallocate
+ memory with different allocators.
+ !*/
+
+ public:
+ //type definitions
+ typedef std::size_t size_type;
+ typedef std::ptrdiff_t difference_type;
+ typedef T* pointer;
+ typedef const T* const_pointer;
+ typedef T& reference;
+ typedef const T& const_reference;
+ typedef T value_type;
+
+ //rebind std_allocator to type U
+ template <typename U>
+ struct rebind {
+ typedef std_allocator<U,M> other;
+ };
+
+ //return address of values
+ pointer address (reference value) const { return &value; }
+
+ const_pointer address (const_reference value) const { return &value; }
+
+ /*constructors and destructor
+ *-nothing to do because the std_allocator has no state
+ */
+ std_allocator() throw() { }
+
+ std_allocator(const std_allocator&) throw() { }
+
+ template <typename U>
+ std_allocator (const std_allocator<U,M>&) throw() { }
+
+ ~std_allocator() throw() { }
+
+ //return maximum number of elements that can be allocated
+ size_type max_size () const throw()
+ {
+ //for numeric_limits see Section 4.3, page 59
+ return std::numeric_limits<size_t>::max() / sizeof(T);
+ }
+
+ //allocate but don't initialize num elements of type T
+ pointer allocate (
+ size_type num,
+ typename std_allocator<void,M>::const_pointer = 0
+ )
+ {
+ return (pointer) pool.allocate_array(num*sizeof(T));
+ }
+
+ // This function is not required by the C++ standard but some versions of the STL
+ // distributed with gcc erroneously require it. See the bug report for further
+ // details: http://gcc.gnu.org/bugzilla/show_bug.cgi?id=51626
+ void construct(pointer p) { return construct(p, value_type()); }
+
+ //initialize elements of allocated storage p with value value
+ void construct (pointer p, const T& value)
+ {
+ //initialize memory with placement new
+ new((void*)p)T(value);
+ }
+
+
+ //destroy elements of initialized storage p
+ void destroy (pointer p)
+ {
+ // destroy objects by calling their destructor
+ p->~T();
+ }
+
+ //deallocate storage p of deleted elements
+ void deallocate (pointer p, size_type )
+ {
+ pool.deallocate_array((char*)p);
+ }
+
+ void swap (
+ std_allocator& item
+ )
+ {
+ pool.swap(item.pool);
+ }
+
+ std_allocator& operator= (const std_allocator&) { return *this;}
+
+ private:
+ typename M::template rebind<char>::other pool;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename M
+ >
+ class std_allocator<void,M>
+ {
+ public:
+ //type definitions
+ typedef std::size_t size_type;
+ typedef std::ptrdiff_t difference_type;
+ typedef void* pointer;
+ typedef const void* const_pointer;
+ typedef void value_type;
+
+ //rebind std_allocator to type U
+ template <typename U>
+ struct rebind {
+ typedef std_allocator<U,M> other;
+ };
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2, typename enabled = void>
+ struct std_alloc_compare
+ { const static bool are_interchangeable = false; };
+
+ template <typename M1, typename M2>
+ struct std_alloc_compare<M1,M2,typename enable_if<is_same_type<typename M1::mm_global_type, typename M2::mm_global_type> >::type>
+ { const static bool are_interchangeable = true; };
+
+ template <typename M>
+ struct std_alloc_compare<M,M,typename enable_if_c<M::is_stateless>::type>
+ { const static bool are_interchangeable = true; };
+
+ //return that all specializations of this std_allocator are interchangeable if they use memory_manager_global
+ // instances with the same mm_global_type
+ template <typename T1, typename M1, typename T2, typename M2>
+ bool operator== (
+ const std_allocator<T1,M1>&,
+ const std_allocator<T2,M2>&
+ ) throw()
+ { return std_alloc_compare<M1,M2>::are_interchangeable; }
+
+ template <typename T1, typename M1, typename T2, typename M2>
+ bool operator!= (
+ const std_allocator<T1,M1>&,
+ const std_allocator<T2,M2>&
+ ) throw()
+ { return !std_alloc_compare<M1,M2>::are_interchangeable; }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename M>
+ void swap (
+ std_allocator<T,M>& a,
+ std_allocator<T,M>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STD_ALLOc_H_
+
diff --git a/ml/dlib/dlib/stl_checked.h b/ml/dlib/dlib/stl_checked.h
new file mode 100644
index 000000000..ec15aa84e
--- /dev/null
+++ b/ml/dlib/dlib/stl_checked.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STL_CHECKEd_HEADER
+#define DLIB_STL_CHECKEd_HEADER
+
+#include "stl_checked/std_vector_c.h"
+
+#endif // DLIB_STL_CHECKEd_HEADER
+
+
diff --git a/ml/dlib/dlib/stl_checked/std_vector_c.h b/ml/dlib/dlib/stl_checked/std_vector_c.h
new file mode 100644
index 000000000..d46c9850c
--- /dev/null
+++ b/ml/dlib/dlib/stl_checked/std_vector_c.h
@@ -0,0 +1,333 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STD_VECTOr_C_H_
+#define DLIB_STD_VECTOr_C_H_
+
+#include <vector>
+#include <algorithm>
+#include "../assert.h"
+#include "std_vector_c_abstract.h"
+#include "../serialize.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename Allocator = std::allocator<T>
+ >
+ class std_vector_c : public std::vector<T,Allocator>
+ {
+ typedef typename std::vector<T,Allocator> base_type;
+ public:
+ // types:
+ typedef typename Allocator::reference reference;
+ typedef typename Allocator::const_reference const_reference;
+ typedef typename base_type::iterator iterator; // See 23.1
+ typedef typename base_type::const_iterator const_iterator; // See 23.1
+ typedef typename base_type::size_type size_type; // See 23.1
+ typedef typename base_type::difference_type difference_type;// See 23.1
+ typedef T value_type;
+ typedef Allocator allocator_type;
+ typedef typename Allocator::pointer pointer;
+ typedef typename Allocator::const_pointer const_pointer;
+ typedef std::reverse_iterator<iterator> reverse_iterator;
+ typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+
+
+ // 23.2.4.1 construct/copy/destroy:
+ explicit std_vector_c(const Allocator& alloc= Allocator()) : base_type(alloc) {}
+
+ explicit std_vector_c(size_type n, const T& value = T(),
+ const Allocator& alloc= Allocator()) : base_type(n, value, alloc) {}
+
+ template <typename InputIterator>
+ std_vector_c(InputIterator first, InputIterator last,
+ const Allocator& alloc= Allocator()) : base_type(first,last,alloc) {}
+
+ std_vector_c(const std::vector<T,Allocator>& x) : base_type(x) {}
+
+ std_vector_c<T,Allocator>& operator=(const std::vector<T,Allocator>& x)
+ {
+ static_cast<base_type&>(*this) = x;
+ return *this;
+ }
+
+ template <typename InputIterator>
+ void assign(InputIterator first, InputIterator last) { base_type::assign(first,last); }
+ void assign(size_type n, const T& u) { base_type::assign(n,u); }
+ allocator_type get_allocator() const { return base_type::get_allocator(); }
+ // iterators:
+ iterator begin() { return base_type::begin(); }
+ const_iterator begin() const { return base_type::begin(); }
+ iterator end() { return base_type::end(); }
+ const_iterator end() const { return base_type::end(); }
+ reverse_iterator rbegin() { return base_type::rbegin(); }
+ const_reverse_iterator rbegin() const { return base_type::rbegin(); }
+ reverse_iterator rend() { return base_type::rend(); }
+ const_reverse_iterator rend() const { return base_type::rend(); }
+ // 23.2.4.2 capacity:
+ size_type size() const { return base_type::size(); }
+ size_type max_size() const { return base_type::max_size(); }
+ void resize(size_type sz, T c = T()) { base_type::resize(sz,c); }
+ size_type capacity() const { return base_type::capacity(); }
+ bool empty() const { return base_type::empty(); }
+ void reserve(size_type n) { base_type::reserve(n); }
+
+ // element access:
+ const_reference at(size_type n) const { return base_type::at(n); }
+ reference at(size_type n) { return base_type::at(n); }
+
+
+ // 23.2.4.3 modifiers:
+ void push_back(const T& x) { base_type::push_back(x); }
+ void swap(std_vector_c<T,Allocator>& x) { base_type::swap(x); }
+ void clear() { base_type::clear(); }
+
+
+ // ------------------------------------------------------
+ // Things that have preconditions that should be checked.
+ // ------------------------------------------------------
+
+ reference operator[](
+ size_type n
+ )
+ {
+ DLIB_CASSERT(n < size(),
+ "\treference std_vector_c::operator[](n)"
+ << "\n\tYou have supplied an invalid index"
+ << "\n\tthis: " << this
+ << "\n\tn: " << n
+ << "\n\tsize(): " << size()
+ );
+ return static_cast<base_type&>(*this)[n];
+ }
+
+ // ------------------------------------------------------
+
+ const_reference operator[](
+ size_type n
+ ) const
+ {
+ DLIB_CASSERT(n < size(),
+ "\tconst_reference std_vector_c::operator[](n)"
+ << "\n\tYou have supplied an invalid index"
+ << "\n\tthis: " << this
+ << "\n\tn: " << n
+ << "\n\tsize(): " << size()
+ );
+ return static_cast<const base_type&>(*this)[n];
+ }
+
+ // ------------------------------------------------------
+
+ reference front(
+ )
+ {
+ DLIB_CASSERT(size() > 0,
+ "\treference std_vector_c::front()"
+ << "\n\tYou can't call front() on an empty vector"
+ << "\n\tthis: " << this
+ );
+ return base_type::front();
+ }
+
+ // ------------------------------------------------------
+
+ const_reference front(
+ ) const
+ {
+ DLIB_CASSERT(size() > 0,
+ "\tconst_reference std_vector_c::front()"
+ << "\n\tYou can't call front() on an empty vector"
+ << "\n\tthis: " << this
+ );
+ return base_type::front();
+ }
+
+ // ------------------------------------------------------
+
+ reference back(
+ )
+ {
+ DLIB_CASSERT(size() > 0,
+ "\treference std_vector_c::back()"
+ << "\n\tYou can't call back() on an empty vector"
+ << "\n\tthis: " << this
+ );
+ return base_type::back();
+ }
+
+ // ------------------------------------------------------
+
+ const_reference back(
+ ) const
+ {
+ DLIB_CASSERT(size() > 0,
+ "\tconst_reference std_vector_c::back()"
+ << "\n\tYou can't call back() on an empty vector"
+ << "\n\tthis: " << this
+ );
+ return base_type::back();
+ }
+
+ // ------------------------------------------------------
+
+ void pop_back(
+ )
+ {
+ DLIB_CASSERT(size() > 0,
+ "\tconst_reference std_vector_c::pop_back()"
+ << "\n\tYou can't call pop_back() on an empty vector"
+ << "\n\tthis: " << this
+ );
+ base_type::pop_back();
+ }
+
+ // ------------------------------------------------------
+
+ iterator insert(
+ iterator position,
+ const T& x
+ )
+ {
+ DLIB_CASSERT( begin() <= position && position <= end(),
+ "\titerator std_vector_c::insert(position,x)"
+ << "\n\tYou have called insert() with an invalid position"
+ << "\n\tthis: " << this
+ );
+ return base_type::insert(position, x);
+ }
+
+ // ------------------------------------------------------
+
+ void insert(
+ iterator position,
+ size_type n,
+ const T& x
+ )
+ {
+ DLIB_CASSERT( begin() <= position && position <= end(),
+ "\tvoid std_vector_c::insert(position,n,x)"
+ << "\n\tYou have called insert() with an invalid position"
+ << "\n\tthis: " << this
+ );
+ base_type::insert(position, n, x);
+ }
+
+ // ------------------------------------------------------
+
+ template <typename InputIterator>
+ void insert(
+ iterator position,
+ InputIterator first,
+ InputIterator last
+ )
+ {
+ DLIB_CASSERT( begin() <= position && position <= end(),
+ "\tvoid std_vector_c::insert(position,first,last)"
+ << "\n\tYou have called insert() with an invalid position"
+ << "\n\tthis: " << this
+ );
+ base_type::insert(position, first, last);
+ }
+
+ // ------------------------------------------------------
+
+ iterator erase(
+ iterator position
+ )
+ {
+ DLIB_CASSERT( begin() <= position && position < end(),
+ "\titerator std_vector_c::erase(position)"
+ << "\n\tYou have called erase() with an invalid position"
+ << "\n\tthis: " << this
+ );
+ return base_type::erase(position);
+ }
+
+ // ------------------------------------------------------
+
+ iterator erase(
+ iterator first,
+ iterator last
+ )
+ {
+ DLIB_CASSERT( begin() <= first && first <= last && last <= end(),
+ "\titerator std_vector_c::erase(first,last)"
+ << "\n\tYou have called erase() with an invalid range of iterators"
+ << "\n\tthis: " << this
+ );
+ return base_type::erase(first,last);
+ }
+
+ // ------------------------------------------------------
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+// Add these swaps just to make absolutely sure the specialized swap always gets called even
+// if the compiler is crappy and would otherwise mess it up.
+ template <typename T, typename Allocator>
+ void swap(std_vector_c<T,Allocator>& x, std_vector_c<T,Allocator>& y) { x.swap(y); }
+
+ template <typename T, typename Allocator>
+ void swap(std::vector<T,Allocator>& x, std_vector_c<T,Allocator>& y) { x.swap(y); }
+
+ template <typename T, typename Allocator>
+ void swap(std_vector_c<T,Allocator>& x, std::vector<T,Allocator>& y) { y.swap(x); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std_vector_c<T,alloc>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ const unsigned long size = static_cast<unsigned long>(item.size());
+
+ serialize(size,out);
+ for (unsigned long i = 0; i < item.size(); ++i)
+ serialize(item[i],out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type std_vector_c"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std_vector_c<T, alloc>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ unsigned long size;
+ deserialize(size,in);
+ item.resize(size);
+ for (unsigned long i = 0; i < size; ++i)
+ deserialize(item[i],in);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type std_vector_c"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ struct is_std_vector<std_vector_c<T,alloc> > { const static bool value = true; };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STD_VECTOr_C_H_
+
diff --git a/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h b/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h
new file mode 100644
index 000000000..2a8045c72
--- /dev/null
+++ b/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h
@@ -0,0 +1,470 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STD_VECTOr_C_ABSTRACT_H_
+#ifdef DLIB_STD_VECTOr_C_ABSTRACT_H_
+
+#include <vector>
+#include <algorithm>
+#include "../assert.h"
+
+namespace dlib
+{
+
+ template <
+ typename T,
+ typename Allocator = std::allocator<T>
+ >
+ class std_vector_c : public std::vector<T,Allocator>
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple wrapper around the std::vector object. It
+ provides an identical interface but also checks the preconditions of
+ each member function. That is, if you violate a requires
+ clause the dlib::fatal_error exception is thrown.
+ !*/
+
+ typedef typename std::vector<T,Allocator> base_type;
+ public:
+ typedef typename Allocator::reference reference;
+ typedef typename Allocator::const_reference const_reference;
+ typedef typename base_type::iterator iterator;
+ typedef typename base_type::const_iterator const_iterator;
+ typedef typename base_type::size_type size_type;
+ typedef typename base_type::difference_type difference_type;
+ typedef T value_type;
+ typedef Allocator allocator_type;
+ typedef typename Allocator::pointer pointer;
+ typedef typename Allocator::const_pointer const_pointer;
+ typedef std::reverse_iterator<iterator> reverse_iterator;
+ typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+
+
+ explicit std_vector_c(
+ const Allocator& alloc = Allocator()
+ );
+ /*!
+ ensures
+ - #get_allocator() == alloc
+ - #size() == 0
+ !*/
+
+ explicit std_vector_c (
+ size_type n,
+ const T& value = T(),
+ const Allocator& alloc = Allocator()
+ );
+ /*!
+ ensures
+ - #size() == n
+ - #get_allocator() == alloc
+ - for all valid i:
+ - (*this)[i] == value
+ !*/
+
+ template <typename InputIterator>
+ std_vector_c (
+ InputIterator first,
+ InputIterator last,
+ const Allocator& alloc = Allocator()
+ );
+ /*!
+ ensures
+ - #size() == std::distance(first,last)
+ - #get_allocator() == alloc
+ - std::equal(first, last, begin()) == true
+ !*/
+
+ std_vector_c(
+ const std::vector<T,Allocator>& x
+ );
+ /*!
+ ensures
+ - #*this == x
+ !*/
+
+ std_vector_c<T,Allocator>& operator= (
+ const std::vector<T,Allocator>& x
+ );
+ /*!
+ ensures
+ - #*this == x
+ - returns #*this
+ !*/
+
+ template <typename InputIterator>
+ void assign(
+ InputIterator first,
+ InputIterator last
+ );
+ /*!
+ ensures
+ - #size() == std::distance(first,last)
+ - std::equal(first, last, begin()) == true
+ !*/
+
+ void assign(
+ size_type n,
+ const T& value
+ );
+ /*!
+ ensures
+ - #size() == n
+ - for all valid i:
+ - (*this)[i] == value
+ !*/
+
+ allocator_type get_allocator(
+ ) const;
+ /*!
+ ensures
+ - returns the allocator used by this vector
+ !*/
+
+ iterator begin(
+ );
+ /*!
+ ensures
+ - if (size() > 0) then
+ - returns an iterator referring to the first element in
+ this container.
+ - else
+ - returns end()
+ !*/
+
+ const_iterator begin(
+ ) const;
+ /*!
+ ensures
+ - if (size() > 0) then
+ - returns a const_iterator referring to the first element in
+ this container.
+ - else
+ - returns end()
+ !*/
+
+ iterator end(
+ );
+ /*!
+ ensures
+ - returns an iterator that represents one past the end of
+ this container
+ !*/
+
+ const_iterator end(
+ ) const;
+ /*!
+ ensures
+ - returns an iterator that represents one past the end of
+ this container
+ !*/
+
+ reverse_iterator rbegin(
+ );
+ /*!
+ ensures
+ - returns std::reverse_iterator(end())
+ !*/
+
+ const_reverse_iterator rbegin(
+ ) const;
+ /*!
+ ensures
+ - returns std::reverse_iterator(end())
+ !*/
+
+ reverse_iterator rend(
+ );
+ /*!
+ ensures
+ - returns std::reverse_iterator(begin())
+ !*/
+
+ const_reverse_iterator rend(
+ ) const;
+ /*!
+ ensures
+ - returns std::reverse_iterator(begin())
+ !*/
+
+ size_type size(
+ ) const;
+ /*!
+ ensures
+ - returns end()-begin()
+ (i.e. returns the number of elements in this container)
+ !*/
+
+ size_type max_size(
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of elements this vector can contain
+ !*/
+
+ void resize(
+ size_type sz,
+ T c = T()
+ );
+ /*!
+ ensures
+ - #size() == sz
+ - any element with index between 0 and sz - 1 which was in the
+ vector before the call to resize() retains its value and index.
+ All other elements have a value given by c.
+ !*/
+
+ size_type capacity(
+ ) const;
+ /*!
+ ensures
+ - returns the total number of elements that the vector can hold without
+ requiring reallocation.
+ !*/
+
+ bool empty(
+ ) const;
+ /*!
+ ensures
+ - if (size() == 0) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void reserve(
+ size_type n
+ );
+ /*!
+ ensures
+ - #capacity() >= n
+ !*/
+
+ const_reference at(
+ size_type n
+ ) const;
+ /*!
+ ensures
+ - if (n < size()) then
+ - returns a const reference to (*this)[n]
+ - else
+ - throws std::out_of_range
+ !*/
+
+ reference at(
+ size_type n
+ );
+ /*!
+ ensures
+ - if (n < size()) then
+ - returns a reference to (*this)[n]
+ - else
+ - throws std::out_of_range
+ !*/
+
+ void push_back(
+ const T& x
+ );
+ /*!
+ ensures
+ - #size() == size() + 1
+ - #back() == x
+ !*/
+
+ void swap(
+ std_vector_c<T,Allocator>& x
+ );
+ /*!
+ ensures
+ - swaps the state of *this and x
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #size() == 0
+ !*/
+
+ reference operator[](
+ size_type n
+ );
+ /*!
+ requires
+ - n < size()
+ ensures
+ - returns a reference to the nth element of this container
+ !*/
+
+ const_reference operator[](
+ size_type n
+ ) const;
+ /*!
+ requires
+ - n < size()
+ ensures
+ - returns a const reference to the nth element of this container
+ !*/
+
+ reference front(
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a reference to (*this)[0]
+ !*/
+
+ const_reference front(
+ ) const;
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a const reference to (*this)[0]
+ !*/
+
+ reference back(
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a reference to (*this)[size()-1]
+ !*/
+
+ const_reference back(
+ ) const;
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - returns a const reference to (*this)[size()-1]
+ !*/
+
+ void pop_back(
+ );
+ /*!
+ requires
+ - size() > 0
+ ensures
+ - #size() == size() - 1
+ - removes the last element in the vector but leaves the others
+ unmodified.
+ !*/
+
+ iterator insert(
+ iterator position,
+ const T& x
+ );
+ /*!
+ requires
+ - begin() <= position && position <= end()
+ (i.e. position references an element in this vector object)
+ ensures
+ - #size() == size() + 1
+ - inserts a copy of x into *this before the given position
+ - returns an iterator that points to the copy of x inserted
+ into *this
+ !*/
+
+ void insert(
+ iterator position,
+ size_type n,
+ const T& x
+ );
+ /*!
+ requires
+ - begin() <= position && position <= end()
+ (i.e. position references an element in this vector object)
+ ensures
+ - #size() == size() + n
+ - inserts n copies of x into *this before the given position
+ !*/
+
+ template <typename InputIterator>
+ void insert(
+ iterator position,
+ InputIterator first,
+ InputIterator last
+ );
+ /*!
+ requires
+ - begin() <= position && position <= end()
+ (i.e. position references an element in this vector object)
+ - first and last are not iterators into *this
+ ensures
+ - #size() == size() + std::distance(last,first)
+ - inserts copies of the range of elements [first,last) into *this
+ before the given position
+ !*/
+
+ iterator erase(
+ iterator position
+ );
+ /*!
+ requires
+ - begin() <= position && position < end()
+ (i.e. position references an element in this vector object)
+ ensures
+ - #size() == size() - 1
+ - removes the element in this vector referenced by position but
+ leaves all other elements in this vector unmodified.
+ - if (position < end()-1) then
+ - returns an iterator referencing the element immediately
+ following *position prior to the erase.
+ - else
+ - returns end()
+ !*/
+
+ iterator erase(
+ iterator first,
+ iterator last
+ );
+ /*!
+ requires
+ - begin() <= first && first <= last && last <= end()
+ (i.e. the range [first,last) must be inside this container )
+ ensures
+ - #size() == size() - (last-first)
+ - removes the elements in this vector referenced by the
+ iterator range [first,last) but leaves all other elements
+ in this vector unmodified.
+ - if (last < end()-1) then
+ - returns an iterator referencing the element immediately
+ following *last prior to the erase.
+ - else
+ - returns end()
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void serialize (
+ const std_vector_c<T,alloc>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ void deserialize (
+ std_vector_c<T, alloc>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STD_VECTOr_C_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/string.h b/ml/dlib/dlib/string.h
new file mode 100644
index 000000000..671fee404
--- /dev/null
+++ b/ml/dlib/dlib/string.h
@@ -0,0 +1,9 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRINg_TOP_
+#define DLIB_STRINg_TOP_
+
+#include "string/string.h"
+
+#endif // DLIB_STRINg_TOP_
+
diff --git a/ml/dlib/dlib/string/cassert b/ml/dlib/dlib/string/cassert
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/string/cassert
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/string/iomanip b/ml/dlib/dlib/string/iomanip
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/string/iomanip
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/string/iosfwd b/ml/dlib/dlib/string/iosfwd
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/string/iosfwd
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/string/iostream b/ml/dlib/dlib/string/iostream
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/string/iostream
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/string/locale b/ml/dlib/dlib/string/locale
new file mode 100644
index 000000000..6139ba823
--- /dev/null
+++ b/ml/dlib/dlib/string/locale
@@ -0,0 +1 @@
+#include "../dlib_include_path_tutorial.txt"
diff --git a/ml/dlib/dlib/string/string.h b/ml/dlib/dlib/string/string.h
new file mode 100644
index 000000000..2c2602198
--- /dev/null
+++ b/ml/dlib/dlib/string/string.h
@@ -0,0 +1,1004 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRINg_
+#define DLIB_STRINg_
+
+#include "string_abstract.h"
+#include <sstream>
+#include "../algs.h"
+#include <string>
+#include <iostream>
+#include <iomanip>
+#include "../error.h"
+#include "../assert.h"
+#include "../uintn.h"
+#include <cctype>
+#include <algorithm>
+#include <vector>
+#include "../enable_if.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ inline const typename disable_if<is_same_type<charT,char>,std::string>::type narrow (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ std::string temp;
+ temp.reserve(str.size());
+ std::string::size_type i;
+ for (i = 0; i < str.size(); ++i)
+ {
+ if (zero_extend_cast<unsigned long>(str[i]) > 255)
+ temp += ' ';
+ else
+ temp += zero_extend_cast<char>(str[i]);
+ }
+ return temp;
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ inline const typename enable_if<is_same_type<charT,char>,std::string>::type narrow (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ return str;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<char,traits,alloc> tolower (
+ const std::basic_string<char,traits,alloc>& str
+ )
+ {
+ std::basic_string<char,traits,alloc> temp;
+
+ temp.resize(str.size());
+
+ for (typename std::basic_string<char,traits,alloc>::size_type i = 0; i < str.size(); ++i)
+ temp[i] = (char)std::tolower(str[i]);
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<char,traits,alloc> toupper (
+ const std::basic_string<char,traits,alloc>& str
+ )
+ {
+ std::basic_string<char,traits,alloc> temp;
+
+ temp.resize(str.size());
+
+ for (typename std::basic_string<char,traits,alloc>::size_type i = 0; i < str.size(); ++i)
+ temp[i] = (char)std::toupper(str[i]);
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const std::basic_string<char,traits,alloc>& str2
+ )
+ {
+ if (str1.size() != str2.size())
+ return false;
+
+ for (typename std::basic_string<char,traits,alloc>::size_type i = 0; i < str1.size(); ++i)
+ {
+ if (std::tolower(str1[i]) != std::tolower(str2[i]))
+ return false;
+ }
+
+ return true;
+ }
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const char* str2
+ )
+ {
+ typename std::basic_string<char,traits,alloc>::size_type i;
+ for (i = 0; i < str1.size(); ++i)
+ {
+ // if we hit the end of str2 then the strings aren't the same length
+ if (str2[i] == '\0')
+ return false;
+
+ if (std::tolower(str1[i]) != std::tolower(str2[i]))
+ return false;
+ }
+
+ // This happens when str2 is longer than str1
+ if (str2[i] != '\0')
+ return false;
+
+ return true;
+ }
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const char* str1,
+ const std::basic_string<char,traits,alloc>& str2
+ )
+ {
+ return strings_equal_ignore_case(str2, str1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const std::basic_string<char,traits,alloc>& str2,
+ unsigned long num
+ )
+ {
+ if (str1.size() != str2.size() && (str1.size() < num || str2.size() < num))
+ return false;
+
+ for (typename std::basic_string<char,traits,alloc>::size_type i = 0; i < str1.size() && i < num; ++i)
+ {
+ if (std::tolower(str1[i]) != std::tolower(str2[i]))
+ return false;
+ }
+
+ return true;
+ }
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const char* str2,
+ unsigned long num
+ )
+ {
+ typename std::basic_string<char,traits,alloc>::size_type i;
+ for (i = 0; i < str1.size() && i < num; ++i)
+ {
+ // if we hit the end of str2 then the strings aren't the same length
+ if (str2[i] == '\0')
+ return false;
+
+ if (std::tolower(str1[i]) != std::tolower(str2[i]))
+ return false;
+ }
+
+ return true;
+ }
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const char* str1,
+ const std::basic_string<char,traits,alloc>& str2,
+ unsigned long num
+ )
+ {
+ return strings_equal_ignore_case(str2, str1, num);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class cast_to_string_error : public error
+ {
+ public:
+ cast_to_string_error():error(ECAST_TO_STRING) {}
+ };
+
+ template <
+ typename T
+ >
+ const std::string cast_to_string (
+ const T& item
+ )
+ {
+ std::ostringstream sout;
+ sout << item;
+ if (!sout)
+ throw cast_to_string_error();
+ return sout.str();
+ }
+
+ // don't declare this if we are using mingw because it apparently doesn't
+ // support iostreams with wchar_t?
+#if !(defined(__MINGW32__) && (__GNUC__ < 4))
+ template <
+ typename T
+ >
+ const std::wstring cast_to_wstring (
+ const T& item
+ )
+ {
+ std::basic_ostringstream<wchar_t> sout;
+ sout << item;
+ if (!sout)
+ throw cast_to_string_error();
+ return sout.str();
+ }
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ inline std::string pad_int_with_zeros (
+ int i,
+ unsigned long width = 6
+ )
+ {
+ std::ostringstream sout;
+ sout << std::setw(width) << std::setfill('0') << i;
+ return sout.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class string_cast_error : public error
+ {
+ public:
+ string_cast_error(const std::string& str):
+ error(ESTRING_CAST,"string cast error: invalid string = '" + str + "'") {}
+ };
+
+ template <
+ typename T
+ >
+ struct string_cast_helper
+ {
+ template < typename charT, typename traits, typename alloc >
+ static const T cast (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ using namespace std;
+ basic_istringstream<charT,traits,alloc> sin(str);
+ T temp;
+ sin >> temp;
+ if (!sin) throw string_cast_error(narrow(str));
+ if (sin.get() != std::char_traits<charT>::eof()) throw string_cast_error(narrow(str));
+ return temp;
+ }
+ };
+
+ template <typename C, typename T, typename A>
+ struct string_cast_helper<std::basic_string<C,T,A> >
+ {
+ template < typename charT, typename traits, typename alloc >
+ static const std::basic_string<C,T,A> cast (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ std::basic_string<C,T,A> temp;
+ temp.resize(str.size());
+ for (unsigned long i = 0; i < str.size(); ++i)
+ temp[i] = zero_extend_cast<C>(str[i]);
+ return temp;
+ }
+ };
+
+ template <>
+ struct string_cast_helper<bool>
+ {
+ template < typename charT, typename traits, typename alloc >
+ static bool cast (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ using namespace std;
+ if (str.size() == 1 && str[0] == '1')
+ return true;
+ if (str.size() == 1 && str[0] == '0')
+ return false;
+ if (tolower(narrow(str)) == "true")
+ return true;
+ if (tolower(narrow(str)) == "false")
+ return false;
+
+ throw string_cast_error(narrow(str));
+ }
+ };
+
+#define DLIB_STRING_CAST_INTEGRAL(type) \
+ template <> \
+ struct string_cast_helper<type> \
+ { \
+ template < typename charT, typename traits, typename alloc> \
+ static type cast ( \
+ const std::basic_string<charT,traits,alloc>& str \
+ ) \
+ { \
+ using namespace std; \
+ basic_istringstream<charT,traits,alloc> sin(str); \
+ type temp; \
+ if (str.size() > 2 && str[0] == _dT(charT,'0') && str[1] == _dT(charT,'x')) \
+ sin >> hex >> temp; \
+ else \
+ sin >> temp; \
+ if (!sin) throw string_cast_error(narrow(str)); \
+ if (sin.get() != std::char_traits<charT>::eof()) throw string_cast_error(narrow(str)); \
+ return temp; \
+ } \
+ };
+
+ DLIB_STRING_CAST_INTEGRAL(unsigned short)
+ DLIB_STRING_CAST_INTEGRAL(short)
+ DLIB_STRING_CAST_INTEGRAL(unsigned int)
+ DLIB_STRING_CAST_INTEGRAL(int)
+ DLIB_STRING_CAST_INTEGRAL(unsigned long)
+ DLIB_STRING_CAST_INTEGRAL(long)
+ DLIB_STRING_CAST_INTEGRAL(uint64)
+
+ template <
+ typename T,
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ inline const T string_cast (
+ const std::basic_string<charT,traits,alloc>& str
+ )
+ {
+ COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);
+ return string_cast_helper<T>::cast(str);
+ }
+
+ template <typename T>
+ inline const T string_cast (const char* str){ return string_cast<T>(std::string(str)); }
+ template <typename T>
+ inline const T string_cast (const wchar_t* str){ return string_cast<T>(std::wstring(str)); }
+
+// ----------------------------------------------------------------------------------------
+
+ class string_assign
+ {
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ class string_assign_helper
+ {
+ public:
+ string_assign_helper (
+ const std::basic_string<charT,traits,alloc>& str_
+ ) : str(str_) {}
+
+ template <typename T>
+ operator T () const
+ {
+ return string_cast<T>(str);
+ }
+
+ private:
+
+ const std::basic_string<charT,traits,alloc>& str;
+ };
+
+ // -------------
+
+ class char_assign_helper
+ {
+ public:
+ char_assign_helper (
+ const char* str_
+ ) : str(str_) {}
+
+ template <typename T>
+ operator T () const
+ {
+ return string_cast<T>(str);
+ }
+
+ private:
+
+ const char* str;
+ };
+
+ // -------------
+
+ class wchar_t_assign_helper
+ {
+ public:
+ wchar_t_assign_helper (
+ const wchar_t* str_
+ ) : str(str_) {}
+
+ template <typename T>
+ operator T () const
+ {
+ return string_cast<T>(str);
+ }
+
+ private:
+
+ const wchar_t* str;
+ };
+
+ // -------------
+
+ public:
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ string_assign_helper<charT,traits,alloc> operator=(
+ const std::basic_string<charT,traits,alloc>& str
+ ) const
+ {
+ return string_assign_helper<charT,traits,alloc>(str);
+ }
+
+ char_assign_helper operator= (
+ const char* str
+ ) const
+ {
+ return char_assign_helper(str);
+ }
+
+ wchar_t_assign_helper operator= (
+ const wchar_t* str
+ ) const
+ {
+ return wchar_t_assign_helper(str);
+ }
+ };
+
+ const string_assign sa = string_assign();
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> wrap_string (
+ const std::basic_string<charT,traits,alloc>& str,
+ const unsigned long first_pad = 0,
+ const unsigned long rest_pad = 0,
+ const unsigned long max_per_line = 79
+ )
+ {
+ DLIB_ASSERT ( first_pad < max_per_line && rest_pad < max_per_line &&
+ rest_pad >= first_pad,
+ "\tconst std::basic_string<charT,traits,alloc> wrap_string()"
+ << "\n\tfirst_pad: " << first_pad
+ << "\n\trest_pad: " << rest_pad
+ << "\n\tmax_per_line: " << max_per_line );
+
+ using namespace std;
+
+ basic_ostringstream<charT,traits,alloc> sout;
+ basic_istringstream<charT,traits,alloc> sin(str);
+
+ for (unsigned long i = 0; i < rest_pad; ++i)
+ sout << _dT(charT," ");
+ const basic_string<charT,traits,alloc> pad(sout.str());
+ sout.str(_dT(charT,""));
+
+ for (unsigned long i = 0; i < first_pad; ++i)
+ sout << _dT(charT," ");
+
+
+ typename basic_string<charT,traits,alloc>::size_type remaining = max_per_line - rest_pad;
+
+ basic_string<charT,traits,alloc> temp;
+
+ sin >> temp;
+ while (sin)
+ {
+ if (temp.size() > remaining)
+ {
+ if (temp.size() + rest_pad >= max_per_line)
+ {
+ string::size_type i = 0;
+ for (; i < temp.size(); ++i)
+ {
+ sout << temp[i];
+ --remaining;
+ if (remaining == 0)
+ {
+ sout << _dT(charT,"\n") << pad;
+ remaining = max_per_line - rest_pad;
+ }
+ }
+ }
+ else
+ {
+ sout << _dT(charT,"\n") << pad << temp;
+ remaining = max_per_line - rest_pad - temp.size();
+ }
+ }
+ else if (temp.size() == remaining)
+ {
+ sout << temp;
+ remaining = 0;
+ }
+ else
+ {
+ sout << temp;
+ remaining -= temp.size();
+ }
+
+ sin >> temp;
+ if (remaining == 0 && sin)
+ {
+ sout << _dT(charT,"\n") << pad;
+ remaining = max_per_line - rest_pad;
+ }
+ else
+ {
+ sout << _dT(charT," ");
+ --remaining;
+ }
+ }
+
+ return sout.str();
+ }
+
+ template <
+ typename charT
+ >
+ const std::basic_string<charT> wrap_string (
+ const charT* str,
+ const unsigned long first_pad = 0,
+ const unsigned long rest_pad = 0,
+ const unsigned long max_per_line = 79
+ ) { return wrap_string(std::basic_string<charT>(str),first_pad,rest_pad,max_per_line); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> ltrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ )
+ {
+ typedef std::basic_string<charT,traits,alloc> string;
+ typename string::size_type pos = str.find_first_not_of(trim_chars);
+ if (pos != string::npos)
+ return str.substr(pos);
+ else
+ return std::basic_string<charT,traits,alloc>();
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> ltrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ ) { return ltrim(str,std::basic_string<charT,traits,alloc>(trim_chars)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rtrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ )
+ {
+ typedef std::basic_string<charT,traits,alloc> string;
+
+ typename string::size_type pos = str.find_last_not_of(trim_chars);
+ if (pos != string::npos)
+ return str.substr(0,pos+1);
+ else
+ return std::basic_string<charT,traits,alloc>();
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rtrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ ) { return rtrim(str,std::basic_string<charT,traits,alloc>(trim_chars)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> trim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ )
+ {
+ typedef std::basic_string<charT,traits,alloc> string;
+ typename string::size_type lpos = str.find_first_not_of(trim_chars);
+ if (lpos != string::npos)
+ {
+ typename string::size_type rpos = str.find_last_not_of(trim_chars);
+ return str.substr(lpos,rpos-lpos+1);
+ }
+ else
+ {
+ return std::basic_string<charT,traits,alloc>();
+ }
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> trim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ ) { return trim(str,std::basic_string<charT,traits,alloc>(trim_chars)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ )
+ {
+ typedef std::basic_string<charT,traits,alloc> string;
+ // if str is too big then just return str
+ if (pad_length <= static_cast<long>(str.size()))
+ return str;
+
+ // make the string we will padd onto the string
+ string P;
+ while (P.size() < pad_length - str.size())
+ P += pad_string;
+ P = P.substr(0,pad_length - str.size());
+
+ // return the padded string
+ return str + P;
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ ) { return rpad(str,pad_length,std::basic_string<charT,traits,alloc>(pad_string)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> lpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ )
+ {
+ typedef std::basic_string<charT,traits,alloc> string;
+ // if str is too big then just return str
+ if (pad_length <= static_cast<long>(str.size()))
+ return str;
+
+ // make the string we will padd onto the string
+ string P;
+ while (P.size() < pad_length - str.size())
+ P += pad_string;
+ P = P.substr(0,pad_length - str.size());
+
+ // return the padded string
+ return P + str;
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> lpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ ) { return lpad(str,pad_length,std::basic_string<charT,traits,alloc>(pad_string)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> pad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ )
+ {
+ const long str_size = static_cast<long>(str.size());
+ return rpad(lpad(str,(pad_length-str_size)/2 + str_size,pad_string),
+ pad_length,
+ pad_string);
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> pad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ ) { return pad(str,pad_length,std::basic_string<charT,traits,alloc>(pad_string)); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> left_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ )
+ {
+ return str.substr(0,str.find_first_of(delim));
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> left_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ )
+ {
+ return str.substr(0,str.find_first_of(delim));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> right_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ )
+ {
+ typename std::basic_string<charT,traits,alloc>::size_type delim_pos = str.find_last_of(delim);
+ if (delim_pos != std::basic_string<charT,traits,alloc>::npos)
+ return str.substr(delim_pos+1);
+ else
+ return _dT(charT,"");
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> right_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ )
+ {
+ typename std::basic_string<charT,traits,alloc>::size_type delim_pos = str.find_last_of(delim);
+ if (delim_pos != std::basic_string<charT,traits,alloc>::npos)
+ return str.substr(delim_pos+1);
+ else
+ return _dT(charT,"");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_first (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ )
+ {
+ typename std::basic_string<charT,traits,alloc>::size_type delim_pos = str.find_first_of(delim);
+ if (delim_pos != std::basic_string<charT,traits,alloc>::npos)
+ return std::make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1));
+ else
+ return std::make_pair(str, _dT(charT,""));
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ inline std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_first (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ )
+ {
+ return split_on_first(str, delim.c_str());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_last (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ )
+ {
+ typename std::basic_string<charT,traits,alloc>::size_type delim_pos = str.find_last_of(delim);
+ if (delim_pos != std::basic_string<charT,traits,alloc>::npos)
+ return std::make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1));
+ else
+ return std::make_pair(str, _dT(charT,""));
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ inline std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_last (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ )
+ {
+ return split_on_last(str, delim.c_str());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::vector<std::basic_string<charT,traits,alloc> > split (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ )
+ {
+ std::basic_string<charT,traits,alloc> temp;
+
+ std::vector<std::basic_string<charT,traits,alloc> > res;
+
+ for (unsigned long i = 0; i < str.size(); ++i)
+ {
+ // check if delim contains the character str[i]
+ bool hit = false;
+ const charT* d = delim;
+ while (*d != '\0')
+ {
+ if (str[i] == *d)
+ {
+ hit = true;
+ break;
+ }
+ ++d;
+ }
+
+ if (hit)
+ {
+ if (temp.size() != 0)
+ {
+ res.push_back(temp);
+ temp.clear();
+ }
+ }
+ else
+ {
+ temp.push_back(str[i]);
+ }
+ }
+
+ if (temp.size() != 0)
+ res.push_back(temp);
+
+ return res;
+ }
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::vector<std::basic_string<charT,traits,alloc> > split (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ )
+ {
+ return split(str,delim.c_str());
+ }
+
+ inline const std::vector<std::string> split (
+ const char* str,
+ const char* delim = " \n\r\t"
+ )
+ {
+ return split(std::string(str),delim);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRINg_
+
diff --git a/ml/dlib/dlib/string/string_abstract.h b/ml/dlib/dlib/string/string_abstract.h
new file mode 100644
index 000000000..0a1ef0c79
--- /dev/null
+++ b/ml/dlib/dlib/string/string_abstract.h
@@ -0,0 +1,652 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRINg_ABSTRACT_
+#ifdef DLIB_STRINg_ABSTRACT_
+
+#include <string>
+#include <iostream>
+#include <vector>
+#include "../error.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class string_cast_error : public error
+ {
+ public:
+ string_cast_error():error(ECAST_TO_STRING) {}
+ };
+
+ template <
+ typename T,
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const T string_cast (
+ const std::basic_string<charT,traits,alloc>& str
+ );
+ /*!
+ requires
+ - T is not a pointer type
+ ensures
+ - returns str converted to T
+ throws
+ - string_cast_error
+ This exception is thrown if string_cast() is unable to convert
+ str into a T. Also, string_cast_error::info == str
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class string_assign
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple tool which provides an alternative syntax for using
+ the string_cast() function. It can be understood by considering
+ the following example:
+
+ string_assign sa;
+ int val;
+ double dval;
+
+ val = sa = "1234"; // executes: val = string_cast<int>("1234");
+ dval = sa = "3.141"; // executes: val = string_cast<double>("3.141");
+
+ After executing, val will be equal to 1234 and dval will be 3.141.
+ Note that you can use string_assign to assign to any type which you could
+ use with string_cast(), except for std::basic_string, assigning to this
+ type is ambiguous for boring technical reasons. But there isn't much
+ point in using this tool to assign from one string to another so it doesn't
+ matter.
+
+ Additionally, note that there is a global instance of this object, dlib::sa.
+ So you never have to create a string_assign object yourself. Finally, this
+ object is totally stateless and threadsafe.
+ !*/
+ };
+
+ const string_assign sa = string_assign();
+
+// ----------------------------------------------------------------------------------------
+
+ class cast_to_string_error : public error
+ {
+ public:
+ cast_to_string_error():error(ECAST_TO_STRING) {}
+ };
+
+ template <
+ typename T
+ >
+ const std::string cast_to_string (
+ const T& item
+ );
+ /*!
+ requires
+ - T is not a pointer type
+ ensures
+ - returns item converted to std::string
+ throws
+ - cast_to_string_error
+ This exception is thrown if cast_to_string() is unable to convert
+ item into a std::string.
+ !*/
+
+ template <
+ typename T
+ >
+ const std::wstring cast_to_wstring (
+ const T& item
+ );
+ /*!
+ requires
+ - T is not a pointer type
+ ensures
+ - returns item converted to std::wstring
+ throws
+ - cast_to_string_error
+ This exception is thrown if cast_to_string() is unable to convert
+ item into a std::string.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ std::string pad_int_with_zeros (
+ int i,
+ unsigned long width = 6
+ );
+ /*!
+ ensures
+ - converts i into a string of at least width characters in length. If
+ necessary, the string will be padded with leading zeros to get
+ to width characters.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::string narrow (
+ const std::basic_string<charT,traits,alloc>& str
+ );
+ /*!
+ ensures
+ - returns str as a std::string by converting every character in it to a char.
+ Note that any characters that do not have a mapping to type char will be
+ converted to a space.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> wrap_string (
+ const std::basic_string<charT,traits,alloc>& str,
+ const unsigned long first_pad = 0,
+ const unsigned long rest_pad = 0,
+ const unsigned long max_per_line = 79
+ );
+ /*!
+ requires
+ - first_pad < max_per_line
+ - rest_pad < max_per_line
+ - rest_pad >= first_pad
+ ensures
+ - returns a copy of str S such that:
+ - S is broken up into lines separated by the \n character.
+ - The first line starts with first_pad space characters.
+ - The second and all subsequent lines start with rest_pad space characters.
+ - The first line is no longer than max_per_line - (rest_pad-first_pad) characters.
+ - The second and all subsequent lines are no longer than max_per_line characters.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits
+ typename alloc
+ >
+ const std::basic_string<char,traits,alloc> tolower (
+ const std::basic_string<char,traits,alloc>& str
+ );
+ /*!
+ ensures
+ - returns a copy of str S such that:
+ - #S.size() == str.size()
+ - #S[i] == std::tolower(str[i])
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<char,traits,alloc> toupper (
+ const std::basic_string<char,traits,alloc>& str
+ );
+ /*!
+ ensures
+ - returns a copy of str S such that:
+ - #S.size() == str.size()
+ - #S[i] == std::toupper(str[i])
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const std::basic_string<char,traits,alloc>& str2
+ );
+ /*!
+ ensures
+ - returns tolower(str1) == tolower(str2)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename traits,
+ typename alloc
+ >
+ bool strings_equal_ignore_case (
+ const std::basic_string<char,traits,alloc>& str1,
+ const std::basic_string<char,traits,alloc>& str2,
+ unsigned long num
+ );
+ /*!
+ ensures
+ - returns tolower(str1.substr(0,num)) == tolower(str2.substr(0,num))
+ (i.e. only compares the first num characters)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> ltrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ );
+ /*!
+ ensures
+ - returns a copy of str with any leading trim_chars
+ from the left side of the string removed.
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> ltrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ );
+ /*!
+ requires
+ - trim_chars == a valid null-terminated C string
+ ensures
+ - returns ltrim(str, std::basic_string<charT,traits,alloc>(trim_chars))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rtrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ );
+ /*!
+ ensures
+ - returns a copy of str with any trailing trim_chars
+ from the right side of the string removed.
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rtrim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ );
+ /*!
+ requires
+ - trim_chars == a valid null-terminated C string
+ ensures
+ - returns rtrim(str, std::basic_string<charT,traits,alloc>(trim_chars))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> trim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& trim_chars
+ );
+ /*!
+ ensures
+ - returns a copy of str with any leading or trailing trim_chars
+ from the ends of the string removed.
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> trim (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* trim_chars = _dT(charT," \t\r\n")
+ );
+ /*!
+ requires
+ - trim_chars == a valid null-terminated C string
+ ensures
+ - returns trim(str, std::basic_string<charT,traits,alloc>(trim_chars))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ );
+ /*!
+ ensures
+ - if (pad_length <= str.size()) then
+ - returns str
+ - else
+ - let P be a string defined as follows:
+ - P.size() == pad_length - str.size()
+ - P == (pad_string + pad_string + ... + pad_string).substr(0,pad_length - str.size())
+ (i.e. P == a string with the above specified size that contains just
+ repitions of the pad_string)
+ - returns the string str + P
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> rpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ );
+ /*!
+ requires
+ - pad_string == a valid null-terminated C string
+ ensures
+ - returns rpad(str, pad_length, std::basic_string<charT,traits,alloc>(pad_string))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> lpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ );
+ /*!
+ ensures
+ - if (pad_length <= str.size()) then
+ - returns str
+ - else
+ - let P be a string defined as follows:
+ - P.size() == pad_length - str.size()
+ - P == (pad_string + pad_string + ... + pad_string).substr(0,pad_length - str.size())
+ (i.e. P == a string with the above specified size that contains just
+ repitions of the pad_string)
+ - returns the string P + str
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> lpad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ );
+ /*!
+ requires
+ - pad_string == a valid null-terminated C string
+ ensures
+ - returns lpad(str, pad_length, std::basic_string<charT,traits,alloc>(pad_string))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> pad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const std::basic_string<charT,traits,alloc>& pad_string
+ );
+ /*!
+ ensures
+ - let str_size == static_cast<long>(str.size())
+ - returns rpad( lpad(str, (pad_length-str_size)/2 + str_size, pad_string),
+ pad_length,
+ pad_string);
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> pad (
+ const std::basic_string<charT,traits,alloc>& str,
+ long pad_length,
+ const charT* pad_string = _dT(charT," ")
+ );
+ /*!
+ requires
+ - pad_string == a valid null-terminated C string
+ ensures
+ - returns pad(str, pad_length, std::basic_string<charT,traits,alloc>(pad_string))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> left_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ );
+ /*!
+ ensures
+ - let delim_pos = str.find_first_of(delim)
+ - returns str.substr(0,delim_pos)
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> left_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ );
+ /*!
+ requires
+ - delim == a valid null-terminated C string
+ ensures
+ - returns left_substr(str, std::basic_string<charT,traits,alloc>(delim))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> right_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ );
+ /*!
+ ensures
+ - let delim_pos = str.find_last_of(delim)
+ - if (delim_pos == std::string::npos) then
+ - returns ""
+ - else
+ - returns str.substr(delim_pos+1)
+ !*/
+
+ template <
+ typename charT,
+ typename traits
+ typename alloc
+ >
+ const std::basic_string<charT,traits,alloc> right_substr (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ );
+ /*!
+ requires
+ - delim == a valid null-terminated C string
+ ensures
+ - returns right_substr(str, std::basic_string<charT,traits,alloc>(delim))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_first (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT, " \n\r\t")
+ );
+ /*!
+ ensures
+ - This function splits string into two parts, the split is based on the first
+ occurrence of any character from delim.
+ - let delim_pos = str.find_first_of(delim)
+ - if (delim_pos == std::string::npos) then
+ - returns make_pair(str,"")
+ - else
+ - return make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1));
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_first (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ );
+ /*!
+ requires
+ - delim == a valid null-terminated C string
+ ensures
+ - returns split_on_first(str, delim.c_str())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_last (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT, " \n\r\t")
+ );
+ /*!
+ ensures
+ - This function splits string into two parts, the split is based on the last
+ occurrence of any character from delim.
+ - let delim_pos = str.find_last_of(delim)
+ - if (delim_pos == std::string::npos) then
+ - returns make_pair(str,"")
+ - else
+ - return make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1));
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ std::pair<std::basic_string<charT,traits,alloc>, std::basic_string<charT,traits,alloc> >
+ split_on_last (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ );
+ /*!
+ requires
+ - delim == a valid null-terminated C string
+ ensures
+ - returns split_on_last(str, delim.c_str())
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::vector<std::basic_string<charT,traits,alloc> > split (
+ const std::basic_string<charT,traits,alloc>& str,
+ const std::basic_string<charT,traits,alloc>& delim
+ );
+ /*!
+ ensures
+ - Breaks the given string str into a sequence of substrings delimited
+ by characters in delim and returns the results.
+ - returns a vector V such that:
+ - V.size() == the number of substrings found in str.
+ - for all i: V[i] == The ith substring. Note that it will not contain
+ any delimiter characters (i.e. characters in delim). It will also
+ never be an empty string.
+ - V contains the substrings in the order in which they appear in str.
+ That is, V[0] contains the first substring, V[1] the second, and
+ so on.
+ !*/
+
+ template <
+ typename charT,
+ typename traits,
+ typename alloc
+ >
+ const std::vector<std::basic_string<charT,traits,alloc> > split (
+ const std::basic_string<charT,traits,alloc>& str,
+ const charT* delim = _dT(charT," \n\r\t")
+ );
+ /*!
+ requires
+ - trim_chars == a valid null-terminated C string
+ ensures
+ - returns split(str, std::basic_string<charT,traits,alloc>(delim))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRINg_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm.h b/ml/dlib/dlib/svm.h
new file mode 100644
index 000000000..4dc7382c8
--- /dev/null
+++ b/ml/dlib/dlib/svm.h
@@ -0,0 +1,60 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_SVm_HEADER
+#define DLIB_SVm_HEADER
+
+#include "svm/svm_rank_trainer.h"
+#include "svm/svm.h"
+#include "svm/krls.h"
+#include "svm/rls.h"
+#include "svm/kcentroid.h"
+#include "svm/kcentroid_overloads.h"
+#include "svm/kkmeans.h"
+#include "svm/feature_ranking.h"
+#include "svm/rbf_network.h"
+#include "svm/linearly_independent_subset_finder.h"
+#include "svm/reduced.h"
+#include "svm/rvm.h"
+#include "svm/pegasos.h"
+#include "svm/sparse_kernel.h"
+#include "svm/null_trainer.h"
+#include "svm/roc_trainer.h"
+#include "svm/kernel_matrix.h"
+#include "svm/empirical_kernel_map.h"
+#include "svm/svm_c_linear_trainer.h"
+#include "svm/svm_c_linear_dcd_trainer.h"
+#include "svm/svm_c_ekm_trainer.h"
+#include "svm/simplify_linear_decision_function.h"
+#include "svm/krr_trainer.h"
+#include "svm/sort_basis_vectors.h"
+#include "svm/svm_c_trainer.h"
+#include "svm/svm_one_class_trainer.h"
+#include "svm/svr_trainer.h"
+
+#include "svm/one_vs_one_decision_function.h"
+#include "svm/multiclass_tools.h"
+#include "svm/cross_validate_multiclass_trainer.h"
+#include "svm/cross_validate_regression_trainer.h"
+#include "svm/cross_validate_object_detection_trainer.h"
+#include "svm/cross_validate_sequence_labeler.h"
+#include "svm/cross_validate_sequence_segmenter.h"
+#include "svm/cross_validate_assignment_trainer.h"
+
+#include "svm/one_vs_all_decision_function.h"
+
+#include "svm/structural_svm_problem.h"
+#include "svm/sequence_labeler.h"
+#include "svm/assignment_function.h"
+#include "svm/track_association_function.h"
+#include "svm/active_learning.h"
+#include "svm/svr_linear_trainer.h"
+#include "svm/sequence_segmenter.h"
+
+#endif // DLIB_SVm_HEADER
+
+
diff --git a/ml/dlib/dlib/svm/active_learning.h b/ml/dlib/dlib/svm/active_learning.h
new file mode 100644
index 000000000..581540e67
--- /dev/null
+++ b/ml/dlib/dlib/svm/active_learning.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ACTIVE_LEARnING_Hh_
+#define DLIB_ACTIVE_LEARnING_Hh_
+
+#include "active_learning_abstract.h"
+
+#include "svm_c_linear_dcd_trainer.h"
+#include <vector>
+
+namespace dlib
+{
+
+ enum active_learning_mode
+ {
+ max_min_margin,
+ ratio_margin
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type,
+ typename in_sample_vector_type2
+ >
+ std::vector<unsigned long> impl_rank_unlabeled_training_samples (
+ const svm_c_linear_dcd_trainer<kernel_type>& trainer,
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels,
+ const in_sample_vector_type2& unlabeled_samples,
+ const active_learning_mode mode
+ )
+ {
+ DLIB_ASSERT(is_vector(unlabeled_samples) &&
+ (samples.size() == 0 || is_learning_problem(samples, labels)) ,
+ "\t std::vector<unsigned long> rank_unlabeled_training_samples()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t is_vector(unlabeled_samples): " << is_vector(unlabeled_samples)
+ << "\n\t is_learning_problem(samples, labels): " << is_learning_problem(samples, labels)
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t labels.size(): " << labels.size()
+ );
+
+ // If there aren't any training samples then all unlabeled_samples are equally good.
+ // So just report an arbitrary ordering.
+ if (samples.size() == 0 || unlabeled_samples.size() == 0)
+ {
+ std::vector<unsigned long> ret(unlabeled_samples.size());
+ for (unsigned long i = 0; i < ret.size(); ++i)
+ ret[i] = i;
+
+ return ret;
+ }
+
+ // We are going to score each unlabeled sample and put the score and index into
+ // results. Then at the end of this function we just sort it and return the indices.
+ std::vector<std::pair<double, unsigned long> > results;
+ results.resize(unlabeled_samples.size());
+
+ // make sure we use this trainer's ability to warm start itself since that will make
+ // this whole function run a lot faster. But first, we need to find out what the state
+ // we will be warm starting from is.
+ typedef typename svm_c_linear_dcd_trainer<kernel_type>::optimizer_state optimizer_state;
+ optimizer_state state;
+ trainer.train(samples, labels, state); // call train() just to get state
+
+ decision_function<kernel_type> df;
+
+ std::vector<typename kernel_type::sample_type> temp_samples;
+ std::vector<typename kernel_type::scalar_type> temp_labels;
+ temp_samples.reserve(samples.size()+1);
+ temp_labels.reserve(labels.size()+1);
+ temp_samples.assign(samples.begin(), samples.end());
+ temp_labels.assign(labels.begin(), labels.end());
+ temp_samples.resize(temp_samples.size()+1);
+ temp_labels.resize(temp_labels.size()+1);
+
+
+ for (long i = 0; i < unlabeled_samples.size(); ++i)
+ {
+ temp_samples.back() = unlabeled_samples(i);
+ // figure out the margin for each possible labeling of this sample.
+
+ optimizer_state temp(state);
+ temp_labels.back() = +1;
+ df = trainer.train(temp_samples, temp_labels, temp);
+ const double margin_p = temp_labels.back()*df(temp_samples.back());
+
+ temp = state;
+ temp_labels.back() = -1;
+ df = trainer.train(temp_samples, temp_labels, temp);
+ const double margin_n = temp_labels.back()*df(temp_samples.back());
+
+ if (mode == max_min_margin)
+ {
+ // The score for this sample is its min possible margin over possible labels.
+ // Therefore, this score measures how much flexibility we have to label this
+ // sample however we want. The intuition being that the most useful points to
+ // label are the ones that are still free to obtain either label.
+ results[i] = std::make_pair(std::min(margin_p, margin_n), i);
+ }
+ else
+ {
+ // In this case, the score for the sample is a ratio that tells how close the
+ // two margin values are to each other. The closer they are the better. So in
+ // this case we are saying we are looking for samples that have the same
+ // preference for either class label.
+ if (std::abs(margin_p) >= std::abs(margin_n))
+ {
+ if (margin_p != 0)
+ results[i] = std::make_pair(margin_n/margin_p, i);
+ else // if both are == 0 then say 0/0 == 1
+ results[i] = std::make_pair(1, i);
+ }
+ else
+ {
+ results[i] = std::make_pair(margin_p/margin_n, i);
+ }
+ }
+ }
+
+ // sort the results so the highest scoring samples come first.
+ std::sort(results.rbegin(), results.rend());
+
+ // transfer results into a vector with just sample indices so we can return it.
+ std::vector<unsigned long> ret(results.size());
+ for (unsigned long i = 0; i < ret.size(); ++i)
+ ret[i] = results[i].second;
+ return ret;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type,
+ typename in_sample_vector_type2
+ >
+ std::vector<unsigned long> rank_unlabeled_training_samples (
+ const svm_c_linear_dcd_trainer<kernel_type>& trainer,
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels,
+ const in_sample_vector_type2& unlabeled_samples,
+ const active_learning_mode mode = max_min_margin
+ )
+ {
+ return impl_rank_unlabeled_training_samples(trainer,
+ mat(samples),
+ mat(labels),
+ mat(unlabeled_samples),
+ mode);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ACTIVE_LEARnING_Hh_
+
diff --git a/ml/dlib/dlib/svm/active_learning_abstract.h b/ml/dlib/dlib/svm/active_learning_abstract.h
new file mode 100644
index 000000000..76a5120e3
--- /dev/null
+++ b/ml/dlib/dlib/svm/active_learning_abstract.h
@@ -0,0 +1,75 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_
+#ifdef DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_
+
+#include "svm_c_linear_dcd_trainer_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ enum active_learning_mode
+ {
+ max_min_margin,
+ ratio_margin
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type,
+ typename in_sample_vector_type2
+ >
+ std::vector<unsigned long> rank_unlabeled_training_samples (
+ const svm_c_linear_dcd_trainer<kernel_type>& trainer,
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels,
+ const in_sample_vector_type2& unlabeled_samples,
+ const active_learning_mode mode = max_min_margin
+ );
+ /*!
+ requires
+ - if (samples.size() != 0) then
+ - it must be legal to call trainer.train(samples, labels)
+ - is_learning_problem(samples, labels) == true
+ - unlabeled_samples must contain the same kind of vectors as samples.
+ - unlabeled_samples, samples, and labels must be matrices or types of
+ objects convertible to a matrix via mat().
+ - is_vector(unlabeled_samples) == true
+ ensures
+ - Suppose that we wish to learn a binary classifier by calling
+ trainer.train(samples, labels) but we are also interested in selecting one of
+ the elements of unlabeled_samples to add to our training data. Since doing
+ this requires us to find out the label of the sample, a potentially tedious
+ or expensive process, we would like to select the "best" element from
+ unlabeled_samples for labeling. The rank_unlabeled_training_samples()
+ attempts to find this "best" element. In particular, this function returns a
+ ranked list of all the elements in unlabeled_samples such that that the
+ "best" elements come first.
+ - The method used by this function is described in the paper:
+ Support Vector Machine Active Learning with Applications to Text Classification
+ by Simon Tong and Daphne Koller
+ In particular, this function implements the MaxMin Margin and Ratio Margin
+ selection strategies described in the paper. Moreover, the mode argument
+ to this function selects which of these strategies is used.
+ - returns a std::vector V such that:
+ - V contains a list of all the indices from unlabeled_samples. Moreover,
+ they are ordered so that the most useful samples come first.
+ - V.size() == unlabeled_samples.size()
+ - unlabeled_samples[V[0]] == The best sample to add into the training set.
+ - unlabeled_samples[V[1]] == The second best sample to add into the training set.
+ - unlabeled_samples[V[i]] == The i-th best sample to add into the training set.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/assignment_function.h b/ml/dlib/dlib/svm/assignment_function.h
new file mode 100644
index 000000000..fdacb2c17
--- /dev/null
+++ b/ml/dlib/dlib/svm/assignment_function.h
@@ -0,0 +1,255 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ASSIGNMENT_FuNCTION_Hh_
+#define DLIB_ASSIGNMENT_FuNCTION_Hh_
+
+#include "assignment_function_abstract.h"
+#include "../matrix.h"
+#include <vector>
+#include "../optimization/max_cost_assignment.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class assignment_function
+ {
+ public:
+
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+
+
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+
+ typedef std::vector<long> label_type;
+ typedef label_type result_type;
+
+ assignment_function()
+ {
+ weights.set_size(fe.num_features());
+ weights = 0;
+ bias = 0;
+ force_assignment = false;
+ }
+
+ explicit assignment_function(
+ const matrix<double,0,1>& weights_,
+ double bias_
+ ) :
+ weights(weights_),
+ bias(bias_),
+ force_assignment(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
+ "\t assignment_function::assignment_function(weights_)"
+ << "\n\t These sizes should match"
+ << "\n\t fe.num_features(): " << fe.num_features()
+ << "\n\t weights_.size(): " << weights_.size()
+ << "\n\t this: " << this
+ );
+
+ }
+
+ assignment_function(
+ const matrix<double,0,1>& weights_,
+ double bias_,
+ const feature_extractor& fe_
+ ) :
+ fe(fe_),
+ weights(weights_),
+ bias(bias_),
+ force_assignment(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
+ "\t assignment_function::assignment_function(weights_,fe_)"
+ << "\n\t These sizes should match"
+ << "\n\t fe_.num_features(): " << fe_.num_features()
+ << "\n\t weights_.size(): " << weights_.size()
+ << "\n\t this: " << this
+ );
+ }
+
+ assignment_function(
+ const matrix<double,0,1>& weights_,
+ double bias_,
+ const feature_extractor& fe_,
+ bool force_assignment_
+ ) :
+ fe(fe_),
+ weights(weights_),
+ bias(bias_),
+ force_assignment(force_assignment_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
+ "\t assignment_function::assignment_function(weights_,fe_,force_assignment_)"
+ << "\n\t These sizes should match"
+ << "\n\t fe_.num_features(): " << fe_.num_features()
+ << "\n\t weights_.size(): " << weights_.size()
+ << "\n\t this: " << this
+ );
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ const matrix<double,0,1>& get_weights (
+ ) const { return weights; }
+
+ double get_bias (
+ ) const { return bias; }
+
+ bool forces_assignment (
+ ) const { return force_assignment; }
+
+ void predict_assignments (
+ const std::vector<lhs_element>& lhs,
+ const std::vector<rhs_element>& rhs,
+ result_type& assignment
+ ) const
+ {
+ assignment.clear();
+
+ matrix<double> cost;
+ unsigned long size;
+ if (force_assignment)
+ {
+ size = std::max(lhs.size(), rhs.size());
+ }
+ else
+ {
+ size = rhs.size() + lhs.size();
+ }
+ cost.set_size(size, size);
+
+ typedef typename feature_extractor::feature_vector_type feature_vector_type;
+ feature_vector_type feats;
+
+ // now fill out the cost assignment matrix
+ for (long r = 0; r < cost.nr(); ++r)
+ {
+ for (long c = 0; c < cost.nc(); ++c)
+ {
+ if (r < (long)lhs.size() && c < (long)rhs.size())
+ {
+ fe.get_features(lhs[r], rhs[c], feats);
+ cost(r,c) = dot(weights, feats) + bias;
+ }
+ else
+ {
+ cost(r,c) = 0;
+ }
+ }
+ }
+
+
+ if (cost.size() != 0)
+ {
+ // max_cost_assignment() only works with integer matrices, so convert from
+ // double to integer.
+ const double scale = (std::numeric_limits<dlib::int64>::max()/1000)/max(abs(cost));
+ matrix<dlib::int64> int_cost = matrix_cast<dlib::int64>(round(cost*scale));
+ assignment = max_cost_assignment(int_cost);
+ assignment.resize(lhs.size());
+ }
+
+ // adjust assignment so that non-assignments have a value of -1
+ for (unsigned long i = 0; i < assignment.size(); ++i)
+ {
+ if (assignment[i] >= (long)rhs.size())
+ assignment[i] = -1;
+ }
+ }
+
+ void predict_assignments (
+ const sample_type& item,
+ result_type& assignment
+ ) const
+ {
+ predict_assignments(item.first, item.second, assignment);
+ }
+
+ result_type operator()(
+ const std::vector<lhs_element>& lhs,
+ const std::vector<rhs_element>& rhs
+ ) const
+ {
+ result_type temp;
+ predict_assignments(lhs,rhs,temp);
+ return temp;
+ }
+
+ result_type operator() (
+ const sample_type& item
+ ) const
+ {
+ return (*this)(item.first, item.second);
+ }
+
+ private:
+
+
+ feature_extractor fe;
+ matrix<double,0,1> weights;
+ double bias;
+ bool force_assignment;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void serialize (
+ const assignment_function<feature_extractor>& item,
+ std::ostream& out
+ )
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.get_feature_extractor(), out);
+ serialize(item.get_weights(), out);
+ serialize(item.get_bias(), out);
+ serialize(item.forces_assignment(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void deserialize (
+ assignment_function<feature_extractor>& item,
+ std::istream& in
+ )
+ {
+ feature_extractor fe;
+ matrix<double,0,1> weights;
+ double bias;
+ bool force_assignment;
+ int version = 0;
+ deserialize(version, in);
+ if (version != 2)
+ throw serialization_error("Unexpected version found while deserializing dlib::assignment_function.");
+
+ deserialize(fe, in);
+ deserialize(weights, in);
+ deserialize(bias, in);
+ deserialize(force_assignment, in);
+
+ item = assignment_function<feature_extractor>(weights, bias, fe, force_assignment);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ASSIGNMENT_FuNCTION_Hh_
+
diff --git a/ml/dlib/dlib/svm/assignment_function_abstract.h b/ml/dlib/dlib/svm/assignment_function_abstract.h
new file mode 100644
index 000000000..927731856
--- /dev/null
+++ b/ml/dlib/dlib/svm/assignment_function_abstract.h
@@ -0,0 +1,342 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_
+#ifdef DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_
+
+#include <vector>
+#include "../optimization/max_cost_assignment_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class example_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a feature extractor must implement
+ if it is to be used with the assignment_function defined at the bottom
+ of this file.
+
+ The model used by assignment_function objects is the following.
+ Given two sets of objects, the Left Hand Set (LHS) and Right Hand Set (RHS),
+ find a one-to-one mapping M from LHS into RHS such that:
+ M == argmax_m sum_{l in LHS} match_score(l,m(l))
+ Where match_score() returns a scalar value indicating how good it is
+ to say l maps to the RHS element m(l). Additionally, in this model,
+ m() is allowed to indicate that l doesn't map to anything, and in this
+ case it is excluded from the sum.
+
+ Finally, match_score() is defined as:
+ match_score(l,r) == dot(w, PSI(l,r)) + bias
+ where l is an element of LHS, r is an element of RHS, w is a parameter
+ vector and bias is a scalar valued parameter.
+
+ Therefore, a feature extractor defines how the PSI() feature vector
+ is calculated. In particular, PSI() is defined by the get_features()
+ method of this class.
+
+ THREAD SAFETY
+ Instances of this object are required to be threadsafe, that is, it should
+ be safe for multiple threads to make concurrent calls to the member
+ functions of this object.
+
+ !*/
+
+ public:
+
+ // This type should be a dlib::matrix capable of storing column vectors
+ // or an unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+ typedef matrix_or_sparse_vector_type feature_vector_type;
+
+ // These two typedefs define the types used to represent an element in
+ // the left hand and right hand sets. You can use any copyable types here.
+ typedef user_defined_type_1 lhs_element;
+ typedef user_defined_type_2 rhs_element;
+
+ unsigned long num_features(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the PSI() feature vector.
+ !*/
+
+ void get_features (
+ const lhs_element& left,
+ const rhs_element& right,
+ feature_vector_type& feats
+ ) const;
+ /*!
+ ensures
+ - #feats == PSI(left,right)
+ (i.e. This function computes a feature vector which, in some sense,
+ captures information useful for deciding if matching left to right
+ is "good").
+ !*/
+
+ unsigned long num_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements of the w parameter vector which should be
+ non-negative. That is, this feature extractor is intended to be used
+ with w vectors where the first num_nonnegative_weights() elements of w
+ are >= 0. That is, it should be the case that w(i) >= 0 for all i <
+ num_nonnegative_weights().
+ - Note that num_nonnegative_weights() is just an optional method to allow
+ you to tell a tool like the structural_assignment_trainer that the
+ learned w should have a certain number of non-negative elements.
+ Therefore, if you do not provide a num_nonnegative_weights() method in
+ your feature extractor then it will default to a value of 0, indicating
+ that all elements of the w parameter vector may be any value.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const example_feature_extractor& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize(
+ example_feature_extractor& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class assignment_function
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor discussed above.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the optimal assignment problem given a
+ user defined method for computing the quality of any particular assignment.
+
+ To define this precisely, suppose you have two sets of objects, a
+ Left Hand Set (LHS) and a Right Hand Set (RHS) and you want to
+ find a one-to-one mapping M from LHS into RHS such that:
+ M == argmax_m sum_{l in LHS} match_score(l,m(l))
+ Where match_score() returns a scalar value indicating how good it is
+ to say l maps to the RHS element m(l). Additionally, in this model,
+ m() is allowed to indicate that l doesn't map to anything, and in this
+ case it is excluded from the sum.
+
+ Finally, this object supports match_score() functions of the form:
+ match_score(l,r) == dot(w, PSI(l,r)) + bias
+ where l is an element of LHS, r is an element of RHS, w is a parameter
+ vector, bias is a scalar valued parameter, and PSI() is defined by the
+ feature_extractor template argument.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as the feature_extractor is also threadsafe. This is
+ because the const members are purely read-only operations. However,
+ any operation that modifies an assignment_function is not threadsafe.
+ !*/
+
+ public:
+
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+ typedef std::vector<long> label_type;
+ typedef label_type result_type;
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+
+ assignment_function(
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights().size() == #get_feature_extractor().num_features()
+ - #get_weights() == 0
+ - #get_bias() == 0
+ - #forces_assignment() == false
+ !*/
+
+ explicit assignment_function(
+ const matrix<double,0,1>& weights,
+ double bias
+ );
+ /*!
+ requires
+ - feature_extractor().num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights() == weights
+ - #get_bias() == bias
+ - #forces_assignment() == false
+ !*/
+
+ assignment_function(
+ const matrix<double,0,1>& weights,
+ double bias,
+ const feature_extractor& fe
+ );
+ /*!
+ requires
+ - fe.num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == fe
+ - #get_weights() == weights
+ - #get_bias() == bias
+ - #forces_assignment() == false
+ !*/
+
+ assignment_function(
+ const matrix<double,0,1>& weights,
+ double bias,
+ const feature_extractor& fe,
+ bool force_assignment
+ );
+ /*!
+ requires
+ - fe.num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == fe
+ - #get_weights() == weights
+ - #get_bias() == bias
+ - #forces_assignment() == force_assignment
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ const matrix<double,0,1>& get_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the parameter vector (w) associated with this assignment function.
+ The length of the vector is get_feature_extractor().num_features().
+ !*/
+
+ double get_bias (
+ ) const;
+ /*!
+ ensures
+ - returns the bias parameter associated with this assignment function.
+ !*/
+
+ bool forces_assignment (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object is in the "forced assignment mode" and false
+ otherwise.
+ - When deciding how to match LHS to RHS, this object can operate in one of
+ two modes. In the default mode, this object will indicate that there is
+ no match for an element of LHS if the best matching element of RHS would
+ result in a negative match_score(). However, in the "forced assignment mode",
+ this object will always make the assignment if there is an available
+ element in RHS, regardless of the match_score().
+
+ Another way to understand this distinction is to consider an example.
+ Suppose LHS and RHS both contain 10 elements. Then in the default mode,
+ it is possible for this object to indicate that there are anywhere between
+ 0 to 10 matches between LHS and RHS. However, in forced assignment mode
+ it will always indicate exactly 10 matches.
+ !*/
+
+ result_type operator()(
+ const std::vector<lhs_element>& lhs,
+ const std::vector<rhs_element>& rhs
+ ) const
+ /*!
+ ensures
+ - returns a vector ASSIGN such that:
+ - ASSIGN.size() == lhs.size()
+ - if (ASSIGN[i] != -1) then
+ - lhs[i] is predicted to associate to rhs[ASSIGN[i]].
+ - else
+ - lhs[i] doesn't associate with anything in rhs.
+ - All values in ASSIGN which are not equal to -1 are unique.
+ That is, ASSIGN will never indicate that more than one element
+ of lhs is assigned to a particular element of rhs.
+ !*/
+
+ result_type operator() (
+ const sample_type& item
+ ) const;
+ /*!
+ ensures
+ - returns (*this)(item.first, item.second);
+ !*/
+
+ void predict_assignments (
+ const sample_type& item,
+ result_type& assignment
+ ) const;
+ /*!
+ ensures
+ - #assignment == (*this)(item)
+ !*/
+
+ void predict_assignments (
+ const std::vector<lhs_element>& lhs,
+ const std::vector<rhs_element>& rhs
+ result_type& assignment
+ ) const;
+ /*!
+ ensures
+ - #assignment == (*this)(lhs,rhs)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void serialize (
+ const assignment_function<feature_extractor>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void deserialize (
+ assignment_function<feature_extractor>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h
new file mode 100644
index 000000000..8166e1c82
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h
@@ -0,0 +1,181 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+
+#include "cross_validate_assignment_trainer_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename assignment_function
+ >
+ double test_assignment_function (
+ const assignment_function& assigner,
+ const std::vector<typename assignment_function::sample_type>& samples,
+ const std::vector<typename assignment_function::label_type>& labels
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ if (assigner.forces_assignment())
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
+ "\t double test_assignment_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels),
+ "\t double test_assignment_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+#endif
+ double total_right = 0;
+ double total = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ const std::vector<long>& out = assigner(samples[i]);
+ for (unsigned long j = 0; j < out.size(); ++j)
+ {
+ if (out[j] == labels[i][j])
+ ++total_right;
+
+ ++total;
+ }
+ }
+
+ if (total != 0)
+ return total_right/total;
+ else
+ return 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ double cross_validate_assignment_trainer (
+ const trainer_type& trainer,
+ const std::vector<typename trainer_type::sample_type>& samples,
+ const std::vector<typename trainer_type::label_type>& labels,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ if (trainer.forces_assignment())
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_assignment_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_assignment_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+#endif
+
+
+
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::label_type label_type;
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+
+ std::vector<sample_type> samples_test, samples_train;
+ std::vector<label_type> labels_test, labels_train;
+
+
+ long next_test_idx = 0;
+ double total_right = 0;
+ double total = 0;
+
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_test.clear();
+ labels_test.clear();
+ samples_train.clear();
+ labels_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ samples_test.push_back(samples[next_test_idx]);
+ labels_test.push_back(labels[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ samples_train.push_back(samples[next]);
+ labels_train.push_back(labels[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train);
+
+ // check how good df is on the test data
+ for (unsigned long i = 0; i < samples_test.size(); ++i)
+ {
+ const std::vector<long>& out = df(samples_test[i]);
+ for (unsigned long j = 0; j < out.size(); ++j)
+ {
+ if (out[j] == labels_test[i][j])
+ ++total_right;
+
+ ++total;
+ }
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ if (total != 0)
+ return total_right/total;
+ else
+ return 1;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h
new file mode 100644
index 000000000..05dd4758e
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h
@@ -0,0 +1,69 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename assignment_function
+ >
+ double test_assignment_function (
+ const assignment_function& assigner,
+ const std::vector<typename assignment_function::sample_type>& samples,
+ const std::vector<typename assignment_function::label_type>& labels
+ );
+ /*!
+ requires
+ - is_assignment_problem(samples, labels)
+ - if (assigner.forces_assignment()) then
+ - is_forced_assignment_problem(samples, labels)
+ - assignment_function == an instantiation of the dlib::assignment_function
+ template or an object with a compatible interface.
+ ensures
+ - Tests assigner against the given samples and labels and returns the fraction
+ of assignments predicted correctly.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ double cross_validate_assignment_trainer (
+ const trainer_type& trainer,
+ const std::vector<typename trainer_type::sample_type>& samples,
+ const std::vector<typename trainer_type::label_type>& labels,
+ const long folds
+ );
+ /*!
+ requires
+ - is_assignment_problem(samples, labels)
+ - if (trainer.forces_assignment()) then
+ - is_forced_assignment_problem(samples, labels)
+ - 1 < folds <= samples.size()
+ - trainer_type == dlib::structural_assignment_trainer or an object
+ with a compatible interface.
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given assignment learning problem for the given number of folds. Each fold
+ is tested using the output of the trainer and the fraction of assignments
+ predicted correctly is returned.
+ - The number of folds used is given by the folds argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h
new file mode 100644
index 000000000..83e4e4048
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h
@@ -0,0 +1,258 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_
+#define DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_
+
+#include "../array.h"
+#include "../graph_cuts/min_cut.h"
+#include "svm.h"
+#include "cross_validate_graph_labeling_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_labeler,
+ typename graph_type
+ >
+ matrix<double,1,2> test_graph_labeling_function (
+ const graph_labeler& labeler,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const std::vector<std::vector<double> >& losses
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ std::string reason_for_failure;
+ DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) ,
+ "\t matrix test_graph_labeling_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t reason_for_failure: " << reason_for_failure
+ );
+ DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
+ all_values_are_nonnegative(losses) == true,
+ "\t matrix test_graph_labeling_function()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t losses.size(): " << losses.size()
+ << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
+ << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
+ );
+#endif
+
+ std::vector<bool> temp;
+ double num_pos_correct = 0;
+ double num_pos = 0;
+ double num_neg_correct = 0;
+ double num_neg = 0;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ labeler(samples[i], temp);
+
+ for (unsigned long j = 0; j < labels[i].size(); ++j)
+ {
+ // What is the loss for this example? It's just 1 unless we have a
+ // per example loss vector.
+ const double loss = (losses.size() == 0) ? 1.0 : losses[i][j];
+
+ if (labels[i][j])
+ {
+ num_pos += loss;
+ if (temp[j])
+ num_pos_correct += loss;
+ }
+ else
+ {
+ num_neg += loss;
+ if (!temp[j])
+ num_neg_correct += loss;
+ }
+ }
+ }
+
+ matrix<double, 1, 2> res;
+ if (num_pos != 0)
+ res(0) = num_pos_correct/num_pos;
+ else
+ res(0) = 1;
+ if (num_neg != 0)
+ res(1) = num_neg_correct/num_neg;
+ else
+ res(1) = 1;
+ return res;
+ }
+
+ template <
+ typename graph_labeler,
+ typename graph_type
+ >
+ matrix<double,1,2> test_graph_labeling_function (
+ const graph_labeler& labeler,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels
+ )
+ {
+ std::vector<std::vector<double> > losses;
+ return test_graph_labeling_function(labeler, samples, labels, losses);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename graph_type
+ >
+ matrix<double,1,2> cross_validate_graph_labeling_trainer (
+ const trainer_type& trainer,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const std::vector<std::vector<double> >& losses,
+ const long folds
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ std::string reason_for_failure;
+ DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure),
+ "\t matrix cross_validate_graph_labeling_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t reason_for_failure: " << reason_for_failure
+ );
+ DLIB_ASSERT( 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t matrix cross_validate_graph_labeling_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t folds: " << folds
+ );
+ DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
+ all_values_are_nonnegative(losses) == true,
+ "\t matrix cross_validate_graph_labeling_trainer()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t losses.size(): " << losses.size()
+ << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
+ << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
+ );
+#endif
+
+ typedef std::vector<bool> label_type;
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+
+ dlib::array<graph_type> samples_test, samples_train;
+ std::vector<label_type> labels_test, labels_train;
+ std::vector<std::vector<double> > losses_test, losses_train;
+
+
+ long next_test_idx = 0;
+
+ std::vector<bool> temp;
+ double num_pos_correct = 0;
+ double num_pos = 0;
+ double num_neg_correct = 0;
+ double num_neg = 0;
+
+ graph_type gtemp;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_test.clear();
+ labels_test.clear();
+ losses_test.clear();
+ samples_train.clear();
+ labels_train.clear();
+ losses_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ copy_graph(samples[next_test_idx], gtemp);
+ samples_test.push_back(gtemp);
+ labels_test.push_back(labels[next_test_idx]);
+ if (losses.size() != 0)
+ losses_test.push_back(losses[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ copy_graph(samples[next], gtemp);
+ samples_train.push_back(gtemp);
+ labels_train.push_back(labels[next]);
+ if (losses.size() != 0)
+ losses_train.push_back(losses[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train,losses_train);
+
+ // check how good labeler is on the test data
+ for (unsigned long i = 0; i < samples_test.size(); ++i)
+ {
+ labeler(samples_test[i], temp);
+ for (unsigned long j = 0; j < labels_test[i].size(); ++j)
+ {
+ // What is the loss for this example? It's just 1 unless we have a
+ // per example loss vector.
+ const double loss = (losses_test.size() == 0) ? 1.0 : losses_test[i][j];
+
+ if (labels_test[i][j])
+ {
+ num_pos += loss;
+ if (temp[j])
+ num_pos_correct += loss;
+ }
+ else
+ {
+ num_neg += loss;
+ if (!temp[j])
+ num_neg_correct += loss;
+ }
+ }
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+
+ matrix<double, 1, 2> res;
+ if (num_pos != 0)
+ res(0) = num_pos_correct/num_pos;
+ else
+ res(0) = 1;
+ if (num_neg != 0)
+ res(1) = num_neg_correct/num_neg;
+ else
+ res(1) = 1;
+ return res;
+ }
+
+ template <
+ typename trainer_type,
+ typename graph_type
+ >
+ matrix<double,1,2> cross_validate_graph_labeling_trainer (
+ const trainer_type& trainer,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const long folds
+ )
+ {
+ std::vector<std::vector<double> > losses;
+ return cross_validate_graph_labeling_trainer(trainer, samples, labels, losses, folds);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h
new file mode 100644
index 000000000..cda4af91e
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h
@@ -0,0 +1,147 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_
+
+#include "../array/array_kernel_abstract.h"
+#include <vector>
+#include "../matrix/matrix_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_labeler,
+ typename graph_type
+ >
+ matrix<double,1,2> test_graph_labeling_function (
+ const graph_labeler& labeler,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels
+ );
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - graph_labeler == an object with an interface compatible with the
+ dlib::graph_labeler object.
+ - the following must be a valid expression: labeler(samples[0]);
+ ensures
+ - This function tests the accuracy of the given graph labeler against
+ the sample graphs and their associated labels. In particular, this
+ function returns a matrix R such that:
+ - R(0) == The fraction of nodes which are supposed to have a label of
+ true that are labeled as such by the labeler.
+ - R(1) == The fraction of nodes which are supposed to have a label of
+ false that are labeled as such by the labeler.
+ Therefore, if R is [1,1] then the labeler makes perfect predictions while
+ an R of [0,0] indicates that it gets everything wrong.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_labeler,
+ typename graph_type
+ >
+ matrix<double,1,2> test_graph_labeling_function (
+ const graph_labeler& labeler,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const std::vector<std::vector<double> >& losses
+ );
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - graph_labeler == an object with an interface compatible with the
+ dlib::graph_labeler object.
+ - the following must be a valid expression: labeler(samples[0]);
+ - if (losses.size() != 0) then
+ - sizes_match(labels, losses) == true
+ - all_values_are_nonnegative(losses) == true
+ ensures
+ - This overload of test_graph_labeling_function() does the same thing as the
+ one defined above, except that instead of counting 1 for each labeling
+ mistake, it weights each mistake according to the corresponding value in
+ losses. That is, instead of counting a value of 1 for making a mistake on
+ samples[i].node(j), this routine counts a value of losses[i][j]. Under this
+ interpretation, the loss values represent how useful it is to correctly label
+ each node. Therefore, the values returned represent fractions of overall
+ labeling utility rather than raw labeling accuracy.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename graph_type
+ >
+ matrix<double,1,2> cross_validate_graph_labeling_trainer (
+ const trainer_type& trainer,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const long folds
+ );
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - 1 < folds <= samples.size()
+ - trainer_type == an object which trains some kind of graph labeler object
+ (e.g. structural_graph_labeling_trainer)
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given graph labeling problem for the given number of folds. Each fold
+ is tested using the output of the trainer and the average classification
+ accuracy from all folds is returned. In particular, this function returns
+ a matrix R such that:
+ - R(0) == The fraction of nodes which are supposed to have a label of
+ true that are labeled as such by the learned labeler.
+ - R(1) == The fraction of nodes which are supposed to have a label of
+ false that are labeled as such by the learned labeler.
+ Therefore, if R is [1,1] then the labeler makes perfect predictions while
+ an R of [0,0] indicates that it gets everything wrong.
+ - The number of folds used is given by the folds argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename graph_type
+ >
+ matrix<double,1,2> cross_validate_graph_labeling_trainer (
+ const trainer_type& trainer,
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ const std::vector<std::vector<double> >& losses,
+ const long folds
+ );
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - 1 < folds <= samples.size()
+ - trainer_type == an object which trains some kind of graph labeler object
+ (e.g. structural_graph_labeling_trainer)
+ - if (losses.size() != 0) then
+ - sizes_match(labels, losses) == true
+ - all_values_are_nonnegative(losses) == true
+ ensures
+ - This overload of cross_validate_graph_labeling_trainer() does the same thing
+ as the one defined above, except that instead of counting 1 for each labeling
+ mistake, it weights each mistake according to the corresponding value in
+ losses. That is, instead of counting a value of 1 for making a mistake on
+ samples[i].node(j), this routine counts a value of losses[i][j]. Under this
+ interpretation, the loss values represent how useful it is to correctly label
+ each node. Therefore, the values returned represent fractions of overall
+ labeling utility rather than raw labeling accuracy.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h
new file mode 100644
index 000000000..be8fa3f3f
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h
@@ -0,0 +1,208 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "cross_validate_multiclass_trainer_abstract.h"
+#include <sstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> test_multiclass_decision_function (
+ const dec_funct_type& dec_funct,
+ const std::vector<sample_type>& x_test,
+ const std::vector<label_type>& y_test
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
+ "\tmatrix test_multiclass_decision_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(x_test,y_test): "
+ << is_learning_problem(x_test,y_test));
+
+
+ const std::vector<label_type> all_labels = dec_funct.get_labels();
+
+ // make a lookup table that maps from labels to their index in all_labels
+ std::map<label_type,unsigned long> label_to_int;
+ for (unsigned long i = 0; i < all_labels.size(); ++i)
+ label_to_int[all_labels[i]] = i;
+
+ matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
+ res.set_size(all_labels.size(), all_labels.size());
+
+ res = 0;
+
+ typename std::map<label_type,unsigned long>::const_iterator iter;
+
+ // now test this trained object
+ for (unsigned long i = 0; i < x_test.size(); ++i)
+ {
+ iter = label_to_int.find(y_test[i]);
+ // ignore samples with labels that the decision function doesn't know about.
+ if (iter == label_to_int.end())
+ continue;
+
+ const unsigned long truth = iter->second;
+ const unsigned long pred = label_to_int[dec_funct(x_test[i])];
+
+ res(truth,pred) += 1;
+ }
+
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class cross_validation_error : public dlib::error
+ {
+ public:
+ cross_validation_error(const std::string& msg) : dlib::error(msg){};
+ };
+
+ template <
+ typename trainer_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> cross_validate_multiclass_trainer (
+ const trainer_type& trainer,
+ const std::vector<sample_type>& x,
+ const std::vector<label_type>& y,
+ const long folds
+ )
+ {
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true &&
+ 1 < folds && folds <= static_cast<long>(x.size()),
+ "\tmatrix cross_validate_multiclass_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
+ );
+
+ const std::vector<label_type> all_labels = select_all_distinct_labels(y);
+
+ // count the number of times each label shows up
+ std::map<label_type,long> label_counts;
+ for (unsigned long i = 0; i < y.size(); ++i)
+ label_counts[y[i]] += 1;
+
+
+ // figure out how many samples from each class will be in the test and train splits
+ std::map<label_type,long> num_in_test, num_in_train;
+ for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
+ {
+ const long in_test = i->second/folds;
+ if (in_test == 0)
+ {
+ std::ostringstream sout;
+ sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
+ sout << "than the number of elements of one of the training classes." << std::endl;
+ sout << " folds: "<< folds << std::endl;
+ sout << " size of class " << i->first << ": "<< i->second << std::endl;
+ throw cross_validation_error(sout.str());
+ }
+ num_in_test[i->first] = in_test;
+ num_in_train[i->first] = i->second - in_test;
+ }
+
+
+
+ std::vector<sample_type> x_test, x_train;
+ std::vector<label_type> y_test, y_train;
+
+ matrix<double, 0, 0, mem_manager_type> res;
+
+ std::map<label_type,long> next_test_idx;
+ for (unsigned long i = 0; i < all_labels.size(); ++i)
+ next_test_idx[all_labels[i]] = 0;
+
+ label_type label;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ x_test.clear();
+ y_test.clear();
+ x_train.clear();
+ y_train.clear();
+
+ // load up the test samples
+ for (unsigned long j = 0; j < all_labels.size(); ++j)
+ {
+ label = all_labels[j];
+ long next = next_test_idx[label];
+
+ long cur = 0;
+ const long num_needed = num_in_test[label];
+ while (cur < num_needed)
+ {
+ if (y[next] == label)
+ {
+ x_test.push_back(x[next]);
+ y_test.push_back(label);
+ ++cur;
+ }
+ next = (next + 1)%x.size();
+ }
+
+ next_test_idx[label] = next;
+ }
+
+ // load up the training samples
+ for (unsigned long j = 0; j < all_labels.size(); ++j)
+ {
+ label = all_labels[j];
+ long next = next_test_idx[label];
+
+ long cur = 0;
+ const long num_needed = num_in_train[label];
+ while (cur < num_needed)
+ {
+ if (y[next] == label)
+ {
+ x_train.push_back(x[next]);
+ y_train.push_back(label);
+ ++cur;
+ }
+ next = (next + 1)%x.size();
+ }
+ }
+
+
+ try
+ {
+ // do the training and testing
+ res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test);
+ }
+ catch (invalid_nu_error&)
+ {
+ // just ignore cases which result in an invalid nu
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ return res;
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h
new file mode 100644
index 000000000..f84503cdc
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h
@@ -0,0 +1,99 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> test_multiclass_decision_function (
+ const dec_funct_type& dec_funct,
+ const std::vector<sample_type>& x_test,
+ const std::vector<label_type>& y_test
+ );
+ /*!
+ requires
+ - is_learning_problem(x_test, y_test)
+ - dec_funct_type == some kind of multiclass decision function object
+ (e.g. one_vs_one_decision_function)
+ ensures
+ - Tests dec_funct against the given samples in x_test and labels in y_test
+ and returns a confusion matrix summarizing the results.
+ - let L = dec_funct.get_labels(). Then the confusion matrix C returned
+ by this function has the following properties.
+ - C.nr() == C.nc() == L.size()
+ - C(r,c) == the number of times a sample with label L(r) was predicted
+ to have a label of L(c)
+ - Any samples with a y_test value not in L are ignored. That is, samples
+ with labels the decision function hasn't ever seen before are ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class cross_validation_error : public dlib::error
+ {
+ /*!
+ This is the exception class used by the cross_validate_multiclass_trainer()
+ routine.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> cross_validate_multiclass_trainer (
+ const trainer_type& trainer,
+ const std::vector<sample_type>& x,
+ const std::vector<label_type>& y,
+ const long folds
+ );
+ /*!
+ requires
+ - is_learning_problem(x,y)
+ - 1 < folds <= x.size()
+ - trainer_type == some kind of multiclass classification trainer object (e.g. one_vs_one_trainer)
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given multiclass classification problem for the given number of folds.
+ Each fold is tested using the output of the trainer and the confusion
+ matrix from all folds is summed and returned.
+ - The total confusion matrix is computed by running test_binary_decision_function()
+ on each fold and summing its output.
+ - The number of folds used is given by the folds argument.
+ - let L = select_all_distinct_labels(y). Then the confusion matrix C returned
+ by this function has the following properties.
+ - C.nr() == C.nc() == L.size()
+ - C(r,c) == the number of times a sample with label L(r) was predicted
+ to have a label of L(c)
+
+ Note that sum(C) might be slightly less than x.size(). This happens if the number of
+ samples in a class is not an even multiple of folds. This is because each fold has the
+ same number of test samples in it and so if the number of samples in a class isn't a
+ multiple of folds then a few are not tested.
+ throws
+ - cross_validation_error
+ This exception is thrown if one of the classes has fewer samples than
+ the number of requested folds.
+ !*/
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h
new file mode 100644
index 000000000..7cb38f0b7
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h
@@ -0,0 +1,430 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_
+#define DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_
+
+#include "cross_validate_object_detection_trainer_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+#include "../geometry.h"
+#include "../image_processing/full_object_detection.h"
+#include "../image_processing/box_overlap_testing.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline unsigned long number_of_truth_hits (
+ const std::vector<full_object_detection>& truth_boxes,
+ const std::vector<rectangle>& ignore,
+ const std::vector<std::pair<double,rectangle> >& boxes,
+ const test_box_overlap& overlap_tester,
+ std::vector<std::pair<double,bool> >& all_dets,
+ unsigned long& missing_detections,
+ const test_box_overlap& overlaps_ignore_tester
+ )
+ /*!
+ ensures
+ - returns the number of elements in truth_boxes which are overlapped by an
+ element of boxes. In this context, two boxes, A and B, overlap if and only if
+ overlap_tester(A,B) == true.
+ - No element of boxes is allowed to account for more than one element of truth_boxes.
+ - The returned number is in the range [0,truth_boxes.size()]
+ - Adds the score for each box from boxes into all_dets and labels each with
+ a bool indicating if it hit a truth box. Note that we skip boxes that
+ don't hit any truth boxes and match an ignore box.
+ - Adds the number of truth boxes which didn't have any hits into
+ missing_detections.
+ !*/
+ {
+ if (boxes.size() == 0)
+ {
+ missing_detections += truth_boxes.size();
+ return 0;
+ }
+
+ unsigned long count = 0;
+ std::vector<bool> used(boxes.size(),false);
+ for (unsigned long i = 0; i < truth_boxes.size(); ++i)
+ {
+ bool found_match = false;
+ // Find the first box that hits truth_boxes[i]
+ for (unsigned long j = 0; j < boxes.size(); ++j)
+ {
+ if (used[j])
+ continue;
+
+ if (overlap_tester(truth_boxes[i].get_rect(), boxes[j].second))
+ {
+ used[j] = true;
+ ++count;
+ found_match = true;
+ break;
+ }
+ }
+
+ if (!found_match)
+ ++missing_detections;
+ }
+
+ for (unsigned long i = 0; i < boxes.size(); ++i)
+ {
+ // only out put boxes if they match a truth box or are not ignored.
+ if (used[i] || !overlaps_any_box(overlaps_ignore_tester, ignore, boxes[i].second))
+ {
+ all_dets.push_back(std::make_pair(boxes[i].first, used[i]));
+ }
+ }
+
+ return count;
+ }
+
+ inline unsigned long number_of_truth_hits (
+ const std::vector<full_object_detection>& truth_boxes,
+ const std::vector<rectangle>& ignore,
+ const std::vector<std::pair<double,rectangle> >& boxes,
+ const test_box_overlap& overlap_tester,
+ std::vector<std::pair<double,bool> >& all_dets,
+ unsigned long& missing_detections
+ )
+ {
+ return number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlap_tester);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( is_learning_problem(images,truth_dets) == true &&
+ ignore.size() == images.size(),
+ "\t matrix test_object_detection_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
+ << "\n\t ignore.size(): " << ignore.size()
+ << "\n\t images.size(): " << images.size()
+ );
+
+
+
+ double correct_hits = 0;
+ double total_true_targets = 0;
+
+ std::vector<std::pair<double,bool> > all_dets;
+ unsigned long missing_detections = 0;
+
+
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ std::vector<std::pair<double,rectangle> > hits;
+ detector(images[i], hits, adjust_threshold);
+
+ correct_hits += impl::number_of_truth_hits(truth_dets[i], ignore[i], hits, overlap_tester, all_dets, missing_detections);
+ total_true_targets += truth_dets[i].size();
+ }
+
+ std::sort(all_dets.rbegin(), all_dets.rend());
+
+ double precision, recall;
+
+ double total_hits = all_dets.size();
+
+ if (total_hits == 0)
+ precision = 1;
+ else
+ precision = correct_hits / total_hits;
+
+ if (total_true_targets == 0)
+ recall = 1;
+ else
+ recall = correct_hits / total_true_targets;
+
+ matrix<double, 1, 3> res;
+ res = precision, recall, average_precision(all_dets, missing_detections);
+ return res;
+ }
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ // convert into a list of regular rectangles.
+ std::vector<std::vector<full_object_detection> > rects(truth_dets.size());
+ for (unsigned long i = 0; i < truth_dets.size(); ++i)
+ {
+ for (unsigned long j = 0; j < truth_dets[i].size(); ++j)
+ {
+ rects[i].push_back(full_object_detection(truth_dets[i][j]));
+ }
+ }
+
+ return test_object_detection_function(detector, images, rects, ignore, overlap_tester, adjust_threshold);
+ }
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ std::vector<std::vector<rectangle> > ignore(images.size());
+ return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold);
+ }
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ std::vector<std::vector<rectangle> > ignore(images.size());
+ return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename array_type
+ >
+ struct array_subset_helper
+ {
+ typedef typename array_type::mem_manager_type mem_manager_type;
+
+ array_subset_helper (
+ const array_type& array_,
+ const std::vector<unsigned long>& idx_set_
+ ) :
+ array(array_),
+ idx_set(idx_set_)
+ {
+ }
+
+ unsigned long size() const { return idx_set.size(); }
+
+ typedef typename array_type::type type;
+ const type& operator[] (
+ unsigned long idx
+ ) const { return array[idx_set[idx]]; }
+
+ private:
+ const array_type& array;
+ const std::vector<unsigned long>& idx_set;
+ };
+
+ template <
+ typename T
+ >
+ const matrix_op<op_array_to_mat<array_subset_helper<T> > > mat (
+ const array_subset_helper<T>& m
+ )
+ {
+ typedef op_array_to_mat<array_subset_helper<T> > op;
+ return matrix_op<op>(op(m));
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( is_learning_problem(images,truth_dets) == true &&
+ ignore.size() == images.size() &&
+ 1 < folds && folds <= static_cast<long>(images.size()),
+ "\t matrix cross_validate_object_detection_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
+ << "\n\t folds: "<< folds
+ << "\n\t ignore.size(): " << ignore.size()
+ << "\n\t images.size(): " << images.size()
+ );
+
+ double correct_hits = 0;
+ double total_true_targets = 0;
+
+ const long test_size = images.size()/folds;
+
+ std::vector<std::pair<double,bool> > all_dets;
+ unsigned long missing_detections = 0;
+ unsigned long test_idx = 0;
+ for (long iter = 0; iter < folds; ++iter)
+ {
+ std::vector<unsigned long> train_idx_set;
+ std::vector<unsigned long> test_idx_set;
+
+ for (long i = 0; i < test_size; ++i)
+ test_idx_set.push_back(test_idx++);
+
+ unsigned long train_idx = test_idx%images.size();
+ std::vector<std::vector<full_object_detection> > training_rects;
+ std::vector<std::vector<rectangle> > training_ignores;
+ for (unsigned long i = 0; i < images.size()-test_size; ++i)
+ {
+ training_rects.push_back(truth_dets[train_idx]);
+ training_ignores.push_back(ignore[train_idx]);
+ train_idx_set.push_back(train_idx);
+ train_idx = (train_idx+1)%images.size();
+ }
+
+
+ impl::array_subset_helper<image_array_type> array_subset(images, train_idx_set);
+ typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects, training_ignores, overlap_tester);
+ for (unsigned long i = 0; i < test_idx_set.size(); ++i)
+ {
+ std::vector<std::pair<double,rectangle> > hits;
+ detector(images[test_idx_set[i]], hits, adjust_threshold);
+
+ correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], ignore[i], hits, overlap_tester, all_dets, missing_detections);
+ total_true_targets += truth_dets[test_idx_set[i]].size();
+ }
+
+ }
+
+ std::sort(all_dets.rbegin(), all_dets.rend());
+
+
+ double precision, recall;
+
+ double total_hits = all_dets.size();
+
+ if (total_hits == 0)
+ precision = 1;
+ else
+ precision = correct_hits / total_hits;
+
+ if (total_true_targets == 0)
+ recall = 1;
+ else
+ recall = correct_hits / total_true_targets;
+
+ matrix<double, 1, 3> res;
+ res = precision, recall, average_precision(all_dets, missing_detections);
+ return res;
+ }
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ // convert into a list of regular rectangles.
+ std::vector<std::vector<full_object_detection> > dets(truth_dets.size());
+ for (unsigned long i = 0; i < truth_dets.size(); ++i)
+ {
+ for (unsigned long j = 0; j < truth_dets[i].size(); ++j)
+ {
+ dets[i].push_back(full_object_detection(truth_dets[i][j]));
+ }
+ }
+
+ return cross_validate_object_detection_trainer(trainer, images, dets, ignore, folds, overlap_tester, adjust_threshold);
+ }
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ const std::vector<std::vector<rectangle> > ignore(images.size());
+ return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold);
+ }
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ )
+ {
+ const std::vector<std::vector<rectangle> > ignore(images.size());
+ return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h
new file mode 100644
index 000000000..575ed77fb
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h
@@ -0,0 +1,297 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "../geometry.h"
+#include "../image_processing/full_object_detection_abstract.h"
+#include "../dnn/layers_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - is_learning_problem(images,truth_dets)
+ - images.size() == ignore.size()
+ - object_detector_type == some kind of object detector function object
+ (e.g. object_detector)
+ - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
+ and it must contain objects which can be accepted by detector().
+ ensures
+ - Tests the given detector against the supplied object detection problem and
+ returns the precision, recall, and average precision. Note that the task is
+ to predict, for each images[i], the set of object locations given by
+ truth_dets[i]. Additionally, any detections on image[i] that match a box in
+ ignore[i] are ignored. That is, detections matching a box in ignore[i] do
+ not count as a false alarm and similarly if any element of ignore[i] goes
+ undetected it does not count as a missed detection. So we say that ignore[i]
+ contains a set of boxes that we "don't care" if they are detected or not.
+ - In particular, returns a matrix M such that:
+ - M(0) == the precision of the detector object. This is a number
+ in the range [0,1] which measures the fraction of detector outputs
+ which correspond to a real target. A value of 1 means the detector
+ never produces any false alarms while a value of 0 means it only
+ produces false alarms.
+ - M(1) == the recall of the detector object. This is a number in the
+ range [0,1] which measures the fraction of targets found by the
+ detector. A value of 1 means the detector found all the targets
+ in truth_dets while a value of 0 means the detector didn't locate
+ any of the targets.
+ - M(2) == the average precision of the detector object. This is a number
+ in the range [0,1] which measures the overall quality of the detector.
+ We compute this by taking all the detections output by the detector and
+ ordering them in descending order of their detection scores. Then we use
+ the average_precision() routine to score the ranked listing and store the
+ output into M(2).
+ - This function considers a detector output D to match a rectangle T if and
+ only if overlap_tester(T,D) returns true.
+ - Note that you can use the adjust_threshold argument to raise or lower the
+ detection threshold. This value is passed into the identically named
+ argument to the detector object and therefore influences the number of
+ output detections. It can be useful, for example, to lower the detection
+ threshold because it results in more detections being output by the
+ detector, and therefore provides more information in the ranking,
+ possibly raising the average precision.
+ !*/
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - All the requirements of the above test_object_detection_function() routine.
+ ensures
+ - converts all the rectangles in truth_dets into full_object_detection objects
+ via full_object_detection's rectangle constructor. Then invokes
+ test_object_detection_function() on the full_object_detections and returns
+ the results.
+ !*/
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - All the requirements of the above test_object_detection_function() routine.
+ ensures
+ - This function simply invokes test_object_detection_function() with all the
+ given arguments and an empty set of ignore rectangles and returns the results.
+ !*/
+
+ template <
+ typename object_detector_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ object_detector_type& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - All the requirements of the above test_object_detection_function() routine.
+ ensures
+ - This function simply invokes test_object_detection_function() with all the
+ given arguments and an empty set of ignore rectangles and returns the results.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename SUBNET,
+ typename image_array_type
+ >
+ const matrix<double,1,3> test_object_detection_function (
+ loss_mmod<SUBNET>& detector,
+ const image_array_type& images,
+ const std::vector<std::vector<mmod_rect>>& truth_dets,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0,
+ const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
+ );
+ /*!
+ requires
+ - is_learning_problem(images,truth_dets)
+ - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
+ and it must contain objects which can be accepted by detector().
+ ensures
+ - This function is just like the test_object_detection_function() for
+ object_detector's except it runs on CNNs that use loss_mmod.
+ - Tests the given detector against the supplied object detection problem and
+ returns the precision, recall, and average precision. Note that the task is
+ to predict, for each images[i], the set of object locations, and their
+ corresponding labels, given by truth_dets[i]. Additionally, any detections
+ on image[i] that match a box in truth_dets[i] that are marked ignore are
+ ignored. That is, detections matching an ignore box, regardless of the
+ ignore box's label, do not count as a false alarm and similarly if any
+ ignored box in truth_dets goes undetected it does not count as a missed
+ detection. To test if a box overlaps an ignore box, we use overlaps_ignore_tester.
+ - In particular, returns a matrix M such that:
+ - M(0) == the precision of the detector object. This is a number
+ in the range [0,1] which measures the fraction of detector outputs
+ which correspond to a real target. A value of 1 means the detector
+ never produces any false alarms while a value of 0 means it only
+ produces false alarms.
+ - M(1) == the recall of the detector object. This is a number in the
+ range [0,1] which measures the fraction of targets found by the detector.
+ A value of 1 means the detector found all the non-ignore targets in
+ truth_dets while a value of 0 means the detector didn't locate any of the
+ targets.
+ - M(2) == the average precision of the detector object. This is a number
+ in the range [0,1] which measures the overall quality of the detector.
+ We compute this by taking all the detections output by the detector and
+ ordering them in descending order of their detection scores. Then we use
+ the average_precision() routine to score the ranked listing and store the
+ output into M(2).
+ - This function considers a detector output D to match a truth rectangle T if
+ and only if overlap_tester(T,D) returns true and the labels are identical strings.
+ - Note that you can use the adjust_threshold argument to raise or lower the
+ detection threshold. This value is passed into the identically named
+ argument to the detector object and therefore influences the number of
+ output detections. It can be useful, for example, to lower the detection
+ threshold because it results in more detections being output by the
+ detector, and therefore provides more information in the ranking,
+ possibly raising the average precision.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - is_learning_problem(images,truth_dets)
+ - images.size() == ignore.size()
+ - 1 < folds <= images.size()
+ - trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
+ - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
+ and it must contain objects which can be accepted by detector().
+ - it is legal to call trainer.train(images, truth_dets)
+ ensures
+ - Performs k-fold cross-validation by using the given trainer to solve an
+ object detection problem for the given number of folds. Each fold is tested
+ using the output of the trainer and a matrix summarizing the results is
+ returned. The matrix contains the precision, recall, and average
+ precision of the trained detectors and is defined identically to the
+ test_object_detection_function() routine defined at the top of this file.
+ !*/
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - all the requirements of the above cross_validate_object_detection_trainer() routine.
+ ensures
+ - converts all the rectangles in truth_dets into full_object_detection objects
+ via full_object_detection's rectangle constructor. Then invokes
+ cross_validate_object_detection_trainer() on the full_object_detections and
+ returns the results.
+ !*/
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_dets,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - All the requirements of the above cross_validate_object_detection_trainer() routine.
+ ensures
+ - This function simply invokes cross_validate_object_detection_trainer() with all
+ the given arguments and an empty set of ignore rectangles and returns the results.
+ !*/
+
+ template <
+ typename trainer_type,
+ typename image_array_type
+ >
+ const matrix<double,1,3> cross_validate_object_detection_trainer (
+ const trainer_type& trainer,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_dets,
+ const long folds,
+ const test_box_overlap& overlap_tester = test_box_overlap(),
+ const double adjust_threshold = 0
+ );
+ /*!
+ requires
+ - All the requirements of the above cross_validate_object_detection_trainer() routine.
+ ensures
+ - This function simply invokes cross_validate_object_detection_trainer() with all
+ the given arguments and an empty set of ignore rectangles and returns the results.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_regression_trainer.h b/ml/dlib/dlib/svm/cross_validate_regression_trainer.h
new file mode 100644
index 000000000..a4c6077c9
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_regression_trainer.h
@@ -0,0 +1,155 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
+#define DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "../statistics.h"
+#include "cross_validate_regression_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename reg_funct_type,
+ typename sample_type,
+ typename label_type
+ >
+ matrix<double,1,4>
+ test_regression_function (
+ reg_funct_type& reg_funct,
+ const std::vector<sample_type>& x_test,
+ const std::vector<label_type>& y_test
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
+ "\tmatrix test_regression_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(x_test,y_test): "
+ << is_learning_problem(x_test,y_test));
+
+ running_stats<double> rs, rs_mae;
+ running_scalar_covariance<double> rc;
+
+ for (unsigned long i = 0; i < x_test.size(); ++i)
+ {
+ // compute error
+ const double output = reg_funct(x_test[i]);
+ const double temp = output - y_test[i];
+
+ rs_mae.add(std::abs(temp));
+ rs.add(temp*temp);
+ rc.add(output, y_test[i]);
+ }
+
+ matrix<double,1,4> result;
+ result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev();
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sample_type,
+ typename label_type
+ >
+ matrix<double,1,4>
+ cross_validate_regression_trainer (
+ const trainer_type& trainer,
+ const std::vector<sample_type>& x,
+ const std::vector<label_type>& y,
+ const long folds
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true &&
+ 1 < folds && folds <= static_cast<long>(x.size()),
+ "\tmatrix cross_validate_regression_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
+ );
+
+
+
+ const long num_in_test = x.size()/folds;
+ const long num_in_train = x.size() - num_in_test;
+
+ running_stats<double> rs, rs_mae;
+ running_scalar_covariance<double> rc;
+
+ std::vector<sample_type> x_test, x_train;
+ std::vector<label_type> y_test, y_train;
+
+
+ long next_test_idx = 0;
+
+
+ for (long i = 0; i < folds; ++i)
+ {
+ x_test.clear();
+ y_test.clear();
+ x_train.clear();
+ y_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ x_test.push_back(x[next_test_idx]);
+ y_test.push_back(y[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%x.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ x_train.push_back(x[next]);
+ y_train.push_back(y[next]);
+ next = (next + 1)%x.size();
+ }
+
+
+ try
+ {
+ const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train);
+
+ // do the training and testing
+ for (unsigned long j = 0; j < x_test.size(); ++j)
+ {
+ // compute error
+ const double output = df(x_test[j]);
+ const double temp = output - y_test[j];
+
+ rs_mae.add(std::abs(temp));
+ rs.add(temp*temp);
+ rc.add(output, y_test[j]);
+ }
+ }
+ catch (invalid_nu_error&)
+ {
+ // just ignore cases which result in an invalid nu
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ matrix<double,1,4> result;
+ result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev();
+ return result;
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h
new file mode 100644
index 000000000..d6298aa74
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h
@@ -0,0 +1,82 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename reg_funct_type,
+ typename sample_type,
+ typename label_type
+ >
+ matrix<double,1,4>
+ test_regression_function (
+ reg_funct_type& reg_funct,
+ const std::vector<sample_type>& x_test,
+ const std::vector<label_type>& y_test
+ );
+ /*!
+ requires
+ - is_learning_problem(x_test, y_test)
+ - reg_funct_type == some kind of regression function object
+ (e.g. a decision_function created by the svr_trainer )
+ ensures
+ - Tests reg_funct against the given samples in x_test and target values in
+ y_test and returns a matrix M summarizing the results. Specifically:
+ - M(0) == the mean squared error.
+ The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
+ - M(1) == the correlation between reg_funct(x_test[i]) and y_test[i].
+ This is a number between -1 and 1.
+ - M(2) == the mean absolute error.
+ This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
+ - M(3) == the standard deviation of the absolute error.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sample_type,
+ typename label_type
+ >
+ matrix<double,1,4>
+ cross_validate_regression_trainer (
+ const trainer_type& trainer,
+ const std::vector<sample_type>& x,
+ const std::vector<label_type>& y,
+ const long folds
+ );
+ /*!
+ requires
+ - is_learning_problem(x,y)
+ - 1 < folds <= x.size()
+ - trainer_type == some kind of regression trainer object (e.g. svr_trainer)
+ ensures
+ - Performs k-fold cross validation by using the given trainer to solve a
+ regression problem for the given number of folds. Each fold is tested using
+ the output of the trainer. A matrix M summarizing the results is returned.
+ Specifically:
+ - M(0) == the mean squared error.
+ The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0)
+ - M(1) == the correlation between a predicted y value and its true value.
+ This is a number between -1 and 1.
+ - M(2) == the mean absolute error.
+ This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
+ - M(3) == the standard deviation of the absolute error.
+ !*/
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h b/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h
new file mode 100644
index 000000000..75c4e363a
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h
@@ -0,0 +1,152 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
+#define DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
+
+#include "cross_validate_sequence_labeler_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_labeler_type,
+ typename sequence_type
+ >
+ const matrix<double> test_sequence_labeler (
+ const sequence_labeler_type& labeler,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true,
+ "\tmatrix test_sequence_labeler()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_sequence_labeling_problem(samples, labels): "
+ << is_sequence_labeling_problem(samples, labels));
+
+ matrix<double> res(labeler.num_labels(), labeler.num_labels());
+ res = 0;
+
+ std::vector<unsigned long> pred;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ labeler.label_sequence(samples[i], pred);
+
+ for (unsigned long j = 0; j < pred.size(); ++j)
+ {
+ const unsigned long truth = labels[i][j];
+ if (truth >= static_cast<unsigned long>(res.nr()))
+ {
+ // ignore labels the labeler doesn't know about.
+ continue;
+ }
+
+ res(truth, pred[j]) += 1;
+ }
+ }
+
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sequence_type
+ >
+ const matrix<double> cross_validate_sequence_labeler (
+ const trainer_type& trainer,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\tmatrix cross_validate_sequence_labeler()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ for (unsigned long j = 0; j < labels[i].size(); ++j)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(labels[i][j] < trainer.num_labels(),
+ "\t matrix cross_validate_sequence_labeler()"
+ << "\n\t The labels are invalid."
+ << "\n\t labels[i][j]: " << labels[i][j]
+ << "\n\t trainer.num_labels(): " << trainer.num_labels()
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ );
+ }
+ }
+#endif
+
+
+
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+ std::vector<sequence_type> x_test, x_train;
+ std::vector<std::vector<unsigned long> > y_test, y_train;
+
+
+ long next_test_idx = 0;
+
+ matrix<double> res;
+
+
+ for (long i = 0; i < folds; ++i)
+ {
+ x_test.clear();
+ y_test.clear();
+ x_train.clear();
+ y_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ x_test.push_back(samples[next_test_idx]);
+ y_test.push_back(labels[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ x_train.push_back(samples[next]);
+ y_train.push_back(labels[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test);
+
+ } // for (long i = 0; i < folds; ++i)
+
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h
new file mode 100644
index 000000000..3d2409b28
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h
@@ -0,0 +1,83 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_labeler_type,
+ typename sequence_type
+ >
+ const matrix<double> test_sequence_labeler (
+ const sequence_labeler_type& labeler,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels
+ );
+ /*!
+ requires
+ - is_sequence_labeling_problem(samples, labels)
+ - sequence_labeler_type == dlib::sequence_labeler or an object with a
+ compatible interface.
+ ensures
+ - Tests labeler against the given samples and labels and returns a confusion
+ matrix summarizing the results.
+ - The confusion matrix C returned by this function has the following properties.
+ - C.nc() == labeler.num_labels()
+ - C.nr() == labeler.num_labels()
+ - C(T,P) == the number of times a sequence element with label T was predicted
+ to have a label of P.
+ - Any samples with a label value >= labeler.num_labels() are ignored. That
+ is, samples with labels the labeler hasn't ever seen before are ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sequence_type
+ >
+ const matrix<double> cross_validate_sequence_labeler (
+ const trainer_type& trainer,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels,
+ const long folds
+ );
+ /*!
+ requires
+ - is_sequence_labeling_problem(samples, labels)
+ - 1 < folds <= samples.size()
+ - for all valid i and j: labels[i][j] < trainer.num_labels()
+ - trainer_type == dlib::structural_sequence_labeling_trainer or an object
+ with a compatible interface.
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given sequence labeling problem for the given number of folds. Each fold
+ is tested using the output of the trainer and the confusion matrix from all
+ folds is summed and returned.
+ - The total confusion matrix is computed by running test_sequence_labeler()
+ on each fold and summing its output.
+ - The number of folds used is given by the folds argument.
+ - The confusion matrix C returned by this function has the following properties.
+ - C.nc() == trainer.num_labels()
+ - C.nr() == trainer.num_labels()
+ - C(T,P) == the number of times a sequence element with label T was predicted
+ to have a label of P.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h
new file mode 100644
index 000000000..8413f9165
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h
@@ -0,0 +1,187 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
+#define DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
+
+#include "cross_validate_sequence_segmenter_abstract.h"
+#include "sequence_segmenter.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename sequence_segmenter_type,
+ typename sequence_type
+ >
+ const matrix<double,1,3> raw_metrics_test_sequence_segmenter (
+ const sequence_segmenter_type& segmenter,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
+ )
+ {
+ std::vector<std::pair<unsigned long,unsigned long> > truth;
+ std::vector<std::pair<unsigned long,unsigned long> > pred;
+
+ double true_hits = 0;
+ double total_detections = 0;
+ double total_true_segments = 0;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ segmenter.segment_sequence(samples[i], pred);
+ truth = segments[i];
+ // sort the segments so they will be in the same orders
+ std::sort(truth.begin(), truth.end());
+ std::sort(pred.begin(), pred.end());
+
+ total_true_segments += truth.size();
+ total_detections += pred.size();
+
+ unsigned long j=0,k=0;
+ while (j < pred.size() && k < truth.size())
+ {
+ if (pred[j].first == truth[k].first &&
+ pred[j].second == truth[k].second)
+ {
+ ++true_hits;
+ ++j;
+ ++k;
+ }
+ else if (pred[j].first < truth[k].first)
+ {
+ ++j;
+ }
+ else
+ {
+ ++k;
+ }
+ }
+ }
+
+ matrix<double,1,3> res;
+ res = total_detections, total_true_segments, true_hits;
+ return res;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_segmenter_type,
+ typename sequence_type
+ >
+ const matrix<double,1,3> test_sequence_segmenter (
+ const sequence_segmenter_type& segmenter,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true,
+ "\tmatrix test_sequence_segmenter()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_sequence_segmentation_problem(samples, segments): "
+ << is_sequence_segmentation_problem(samples, segments));
+
+ const matrix<double,1,3> metrics = impl::raw_metrics_test_sequence_segmenter(segmenter, samples, segments);
+
+ const double total_detections = metrics(0);
+ const double total_true_segments = metrics(1);
+ const double true_hits = metrics(2);
+
+ const double precision = (total_detections ==0) ? 1 : true_hits/total_detections;
+ const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments;
+ const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall);
+
+ matrix<double,1,3> res;
+ res = precision, recall, f1;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sequence_type
+ >
+ const matrix<double,1,3> cross_validate_sequence_segmenter (
+ const trainer_type& trainer,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\tmatrix cross_validate_sequence_segmenter()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t folds: " << folds
+ << "\n\t is_sequence_segmentation_problem(samples, segments): "
+ << is_sequence_segmentation_problem(samples, segments));
+
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+ std::vector<sequence_type> x_test, x_train;
+ std::vector<std::vector<std::pair<unsigned long,unsigned long> > > y_test, y_train;
+
+ long next_test_idx = 0;
+
+ matrix<double,1,3> metrics;
+ metrics = 0;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ x_test.clear();
+ y_test.clear();
+ x_train.clear();
+ y_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ x_test.push_back(samples[next_test_idx]);
+ y_test.push_back(segments[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ x_train.push_back(samples[next]);
+ y_train.push_back(segments[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ metrics += impl::raw_metrics_test_sequence_segmenter(trainer.train(x_train,y_train), x_test, y_test);
+ } // for (long i = 0; i < folds; ++i)
+
+
+ const double total_detections = metrics(0);
+ const double total_true_segments = metrics(1);
+ const double true_hits = metrics(2);
+
+ const double precision = (total_detections ==0) ? 1 : true_hits/total_detections;
+ const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments;
+ const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall);
+
+ matrix<double,1,3> res;
+ res = precision, recall, f1;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h
new file mode 100644
index 000000000..87e21d592
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h
@@ -0,0 +1,80 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_
+
+#include "sequence_segmenter_abstract.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_segmenter_type,
+ typename sequence_type
+ >
+ const matrix<double,1,3> test_sequence_segmenter (
+ const sequence_segmenter_type& segmenter,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
+ );
+ /*!
+ requires
+ - is_sequence_segmentation_problem(samples, segments) == true
+ - sequence_segmenter_type == dlib::sequence_segmenter or an object with a
+ compatible interface.
+ ensures
+ - Tests segmenter against the given samples and truth segments and returns the
+ precision, recall, and F1-score obtained by the segmenter. That is, the goal
+ of the segmenter should be to predict segments[i] given samples[i] as input.
+ The test_sequence_segmenter() routine therefore measures how well the
+ segmenter is able to perform this task.
+ - Returns a row matrix M with the following properties:
+ - M(0) == The precision of the segmenter measured against the task of
+ detecting the segments of each sample. This is a number in the range 0
+ to 1 and represents the fraction of segments output by the segmenter
+ which correspond to true segments for each sample.
+ - M(1) == The recall of the segmenter measured against the task of
+ detecting the segments of each sample. This is a number in the range 0
+ to 1 and represents the fraction of the true segments found by the
+ segmenter.
+ - M(2) == The F1-score for the segmenter. This is the harmonic mean of
+ M(0) and M(1).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sequence_type
+ >
+ const matrix<double,1,3> cross_validate_sequence_segmenter (
+ const trainer_type& trainer,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments,
+ const long folds
+ );
+ /*!
+ requires
+ - is_sequence_segmentation_problem(samples, segments) == true
+ - 1 < folds <= samples.size()
+ - trainer_type == dlib::structural_sequence_segmentation_trainer or an object
+ with a compatible interface.
+ ensures
+ - Performs k-fold cross validation by using the given trainer to solve the
+ given sequence segmentation problem for the given number of folds. Each fold
+ is tested using the output of the trainer and the results from all folds are
+ summarized and returned.
+ - This function returns the precision, recall, and F1-score for the trainer.
+ In particular, the output is the same as the output from the
+ test_sequence_segmenter() routine defined above.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h b/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h
new file mode 100644
index 000000000..dac519b7a
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h
@@ -0,0 +1,163 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
+#define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
+
+#include "cross_validate_track_association_trainer_abstract.h"
+#include "structural_track_association_trainer.h"
+
+namespace dlib
+{
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename track_association_function,
+ typename detection_type,
+ typename label_type
+ >
+ void test_track_association_function (
+ const track_association_function& assoc,
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples,
+ unsigned long& total_dets,
+ unsigned long& correctly_associated_dets
+ )
+ {
+ const typename track_association_function::association_function_type& f = assoc.get_assignment_function();
+
+ typedef typename detection_type::track_type track_type;
+ using namespace impl;
+
+ dlib::rand rnd;
+ std::vector<track_type> tracks;
+ std::map<label_type,long> track_idx; // tracks[track_idx[id]] == track with ID id.
+
+ for (unsigned long j = 0; j < samples.size(); ++j)
+ {
+ std::vector<labeled_detection<detection_type,label_type> > dets = samples[j];
+ // Shuffle the order of the detections so we can be sure that there isn't
+ // anything funny going on like the detections always coming in the same
+ // order relative to their labels and the association function just gets
+ // lucky by picking the same assignment ordering every time. So this way
+ // we know the assignment function really is doing something rather than
+ // just being lucky.
+ randomize_samples(dets, rnd);
+
+ total_dets += dets.size();
+ std::vector<long> assignments = f(get_unlabeled_dets(dets), tracks);
+ std::vector<bool> updated_track(tracks.size(), false);
+ // now update all the tracks with the detections that associated to them.
+ for (unsigned long k = 0; k < assignments.size(); ++k)
+ {
+ // If the detection is associated to tracks[assignments[k]]
+ if (assignments[k] != -1)
+ {
+ tracks[assignments[k]].update_track(dets[k].det);
+ updated_track[assignments[k]] = true;
+
+ // if this detection was supposed to go to this track
+ if (track_idx.count(dets[k].label) && track_idx[dets[k].label]==assignments[k])
+ ++correctly_associated_dets;
+
+ track_idx[dets[k].label] = assignments[k];
+ }
+ else
+ {
+ track_type new_track;
+ new_track.update_track(dets[k].det);
+ tracks.push_back(new_track);
+
+ // if this detection was supposed to go to a new track
+ if (track_idx.count(dets[k].label) == 0)
+ ++correctly_associated_dets;
+
+ track_idx[dets[k].label] = tracks.size()-1;
+ }
+ }
+
+ // Now propagate all the tracks that didn't get any detections.
+ for (unsigned long k = 0; k < updated_track.size(); ++k)
+ {
+ if (!updated_track[k])
+ tracks[k].propagate_track();
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename track_association_function,
+ typename detection_type,
+ typename label_type
+ >
+ double test_track_association_function (
+ const track_association_function& assoc,
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ )
+ {
+ unsigned long total_dets = 0;
+ unsigned long correctly_associated_dets = 0;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets);
+ }
+
+ return (double)correctly_associated_dets/(double)total_dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename detection_type,
+ typename label_type
+ >
+ double cross_validate_track_association_trainer (
+ const trainer_type& trainer,
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples,
+ const long folds
+ )
+ {
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+ std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples_train;
+
+ long next_test_idx = 0;
+ unsigned long total_dets = 0;
+ unsigned long correctly_associated_dets = 0;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_train.clear();
+
+ // load up the training samples
+ long next = (next_test_idx + num_in_test)%samples.size();
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ samples_train.push_back(samples[next]);
+ next = (next + 1)%samples.size();
+ }
+
+ const track_association_function<detection_type>& df = trainer.train(samples_train);
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+ }
+
+ return (double)correctly_associated_dets/(double)total_dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h
new file mode 100644
index 000000000..76b985600
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h
@@ -0,0 +1,69 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_
+
+#include "structural_track_association_trainer_abstract.h"
+#include "svm_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename track_association_function,
+ typename detection_type,
+ typename label_type
+ >
+ double test_track_association_function (
+ const track_association_function& assoc,
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ );
+ /*!
+ requires
+ - is_track_association_problem(samples)
+ - track_association_function == an instantiation of the dlib::track_association_function
+ template or an object with a compatible interface.
+ ensures
+ - Tests assoc against the given samples and returns the fraction of detections
+ which were correctly associated to their tracks. That is, if assoc produces
+ perfect tracks when used then this function returns a value of 1. Similarly,
+ if 5% of the detections were associated to the incorrect track then the
+ return value is 0.05.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename detection_type,
+ typename label_type
+ >
+ double cross_validate_track_association_trainer (
+ const trainer_type& trainer,
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples,
+ const long folds
+ );
+ /*!
+ requires
+ - is_track_association_problem(samples)
+ - 1 < folds <= samples.size()
+ - trainer_type == dlib::structural_track_association_trainer or an object with
+ a compatible interface.
+ ensures
+ - Performs k-fold cross validation by using the given trainer to solve the
+ given track association learning problem for the given number of folds. Each
+ fold is tested using the output of the trainer and the fraction of
+ mis-associated detections is returned (i.e. this function returns the same
+ measure of track association quality as test_track_association_function()).
+ - The number of folds used is given by the folds argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/empirical_kernel_map.h b/ml/dlib/dlib/svm/empirical_kernel_map.h
new file mode 100644
index 000000000..7a91e591a
--- /dev/null
+++ b/ml/dlib/dlib/svm/empirical_kernel_map.h
@@ -0,0 +1,429 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_EMPIRICAL_KERNEl_MAP_H_
+#define DLIB_EMPIRICAL_KERNEl_MAP_H_
+
+#include "../matrix.h"
+#include "empirical_kernel_map_abstract.h"
+#include "linearly_independent_subset_finder.h"
+#include <vector>
+#include "../algs.h"
+#include "kernel_matrix.h"
+#include "function.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type, typename EXP>
+ const decision_function<kernel_type> convert_to_decision_function (
+ const projection_function<kernel_type>& project_funct,
+ const matrix_exp<EXP>& vect
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(project_funct.out_vector_size() > 0 && is_vector(vect) &&
+ project_funct.out_vector_size() == vect.size() && project_funct.weights.nc() == project_funct.basis_vectors.size(),
+ "\t const decision_function convert_to_decision_function()"
+ << "\n\t Invalid inputs to this function."
+ << "\n\t project_funct.out_vector_size(): " << project_funct.out_vector_size()
+ << "\n\t project_funct.weights.nc(): " << project_funct.weights.nc()
+ << "\n\t project_funct.basis_vectors.size(): " << project_funct.basis_vectors.size()
+ << "\n\t is_vector(vect): " << is_vector(vect)
+ << "\n\t vect.size(): " << vect.size()
+ );
+
+ return decision_function<kernel_type>(trans(project_funct.weights)*vect,
+ 0,
+ project_funct.kernel_function,
+ project_funct.basis_vectors);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kern_type>
+ class empirical_kernel_map
+ {
+ public:
+
+ struct empirical_kernel_map_error : public error
+ {
+ empirical_kernel_map_error(const std::string& message): error(message) {}
+ };
+
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ void clear (
+ )
+ {
+ empirical_kernel_map().swap(*this);
+ }
+
+ template <typename T>
+ void load(
+ const kernel_type& kernel_,
+ const T& basis_samples
+ )
+ {
+ load_impl(kernel_, mat(basis_samples));
+ }
+
+ void load(
+ const linearly_independent_subset_finder<kernel_type>& lisf
+ )
+ {
+ if (lisf.size() == 0)
+ {
+ std::ostringstream sout;
+ sout << "An empty linearly_independent_subset_finder was supplied to the\n"
+ << "empirical_kernel_map::load() function. One reason this might occur\n"
+ << "is if your dataset contains only zero vectors (or vectors \n"
+ << "approximately zero).\n";
+ clear();
+ throw empirical_kernel_map_error(sout.str());
+ }
+
+ kernel = lisf.get_kernel();
+ weights = trans(chol(lisf.get_inv_kernel_marix()));
+ basis.resize(lisf.size());
+ for (unsigned long i = 0; i < basis.size(); ++i)
+ basis[i] = lisf[i];
+
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() > 0,
+ "\tconst kernel_type empirical_kernel_map::get_kernel()"
+ << "\n\t You have to load this object with a kernel before you can call this function"
+ << "\n\t this: " << this
+ );
+
+ return kernel;
+ }
+
+ long out_vector_size (
+ ) const
+ {
+ return weights.nr();
+ }
+
+ unsigned long basis_size (
+ ) const
+ {
+ return basis.size();
+ }
+
+ const sample_type& operator[] (
+ unsigned long idx
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( idx < basis_size(),
+ "\t const sample_type& empirical_kernel_map::operator[](idx)"
+ << "\n\t Invalid inputs to this function."
+ << "\n\t basis_size(): " << basis_size()
+ << "\n\t this: " << this
+ );
+
+ return basis[idx];
+ }
+
+ template <typename EXP>
+ const decision_function<kernel_type> convert_to_decision_function (
+ const matrix_exp<EXP>& vect
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0 && is_vector(vect) && out_vector_size() == vect.size(),
+ "\t const decision_function empirical_kernel_map::convert_to_decision_function()"
+ << "\n\t Invalid inputs to this function."
+ << "\n\t out_vector_size(): " << out_vector_size()
+ << "\n\t is_vector(vect): " << is_vector(vect)
+ << "\n\t vect.size(): " << vect.size()
+ << "\n\t this: " << this
+ );
+
+ return decision_function<kernel_type>(trans(weights)*vect, 0, kernel, mat(basis));
+ }
+
+ template <typename EXP>
+ const distance_function<kernel_type> convert_to_distance_function (
+ const matrix_exp<EXP>& vect
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0 && is_vector(vect) && out_vector_size() == vect.size(),
+ "\t const distance_function empirical_kernel_map::convert_to_distance_function()"
+ << "\n\t Invalid inputs to this function."
+ << "\n\t out_vector_size(): " << out_vector_size()
+ << "\n\t is_vector(vect): " << is_vector(vect)
+ << "\n\t vect.size(): " << vect.size()
+ << "\n\t this: " << this
+ );
+
+ return distance_function<kernel_type>(trans(weights)*vect, dot(vect,vect), kernel, mat(basis));
+ }
+
+ const projection_function<kernel_type> get_projection_function (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0,
+ "\tconst projection_function empirical_kernel_map::get_projection_function()"
+ << "\n\t You have to load this object with data before you can call this function"
+ << "\n\t this: " << this
+ );
+
+ return projection_function<kernel_type>(weights, kernel, mat(basis));
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type> get_transformation_to (
+ const empirical_kernel_map& target
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0 &&
+ target.out_vector_size() != 0 &&
+ get_kernel() == target.get_kernel(),
+ "\t const matrix empirical_kernel_map::get_transformation_to(target)"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t out_vector_size(): " << out_vector_size()
+ << "\n\t target.out_vector_size(): " << target.out_vector_size()
+ << "\n\t get_kernel()==target.get_kernel(): " << (get_kernel()==target.get_kernel())
+ << "\n\t this: " << this
+ );
+
+ return target.weights * kernel_matrix(target.get_kernel(),target.basis, basis)*trans(weights);
+ }
+
+ void get_transformation_to (
+ const empirical_kernel_map& target,
+ matrix<scalar_type, 0, 0, mem_manager_type>& tmat,
+ projection_function<kernel_type>& partial_projection
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0 &&
+ target.out_vector_size() != 0 &&
+ get_kernel() == target.get_kernel() &&
+ basis_size() < target.basis_size(),
+ "\t void empirical_kernel_map::get_transformation_to(target, tmat, partial_projection)"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t out_vector_size(): " << out_vector_size()
+ << "\n\t target.out_vector_size(): " << target.out_vector_size()
+ << "\n\t basis_size(): " << basis_size()
+ << "\n\t target.basis_size(): " << target.basis_size()
+ << "\n\t get_kernel()==target.get_kernel(): " << (get_kernel()==target.get_kernel())
+ << "\n\t this: " << this
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < basis_size(); ++i)
+ {
+ DLIB_ASSERT(dlib::equal((*this)[i], target[i]),
+ "\t const matrix empirical_kernel_map::get_transformation_to(target, tmat, partial_projection)"
+ << "\n\t target must contain a superset of the basis vectors in *this"
+ << "\n\t i: " << i
+ << "\n\t this: " << this
+ );
+ }
+#endif
+
+ const unsigned long num1 = basis.size();
+ const unsigned long num2 = target.basis.size();
+
+ tmat = colm(target.weights, range(0,num1-1))*kernel_matrix(kernel, basis)*trans(weights);
+
+ empirical_kernel_map temp_ekm;
+ temp_ekm.load(kernel, rowm(mat(target.basis), range(num1,num2-1)));
+
+ partial_projection = temp_ekm.get_projection_function();
+
+ partial_projection.weights = colm(target.weights,range(num1,num2-1))*
+ kernel_matrix(kernel, temp_ekm.basis)*
+ trans(temp_ekm.weights)*
+ partial_projection.weights;
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& project (
+ const sample_type& samp
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0,
+ "\tconst matrix empirical_kernel_map::project()"
+ << "\n\t You have to load this object with data before you can call this function"
+ << "\n\t this: " << this
+ );
+
+ temp1 = kernel_matrix(kernel, basis, samp);
+ temp2 = weights*temp1;
+ return temp2;
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& project (
+ const sample_type& samp,
+ scalar_type& projection_error
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(out_vector_size() != 0,
+ "\tconst matrix empirical_kernel_map::project()"
+ << "\n\t You have to load this object with data before you can call this function"
+ << "\n\t this: " << this
+ );
+
+ temp1 = kernel_matrix(kernel, basis, samp);
+ temp2 = weights*temp1;
+ // This value should never be negative (it measures squared distance) but I'm putting the abs()
+ // here just for good measure since rounding error might push it slightly negative.
+ projection_error = std::abs( kernel(samp,samp) - dot(temp2,temp2));
+
+ return temp2;
+ }
+
+ void swap (
+ empirical_kernel_map& item
+ )
+ {
+ basis.swap(item.basis);
+ weights.swap(item.weights);
+ std::swap(kernel, item.kernel);
+
+ temp1.swap(item.temp1);
+ temp2.swap(item.temp2);
+ }
+
+ friend void serialize (
+ const empirical_kernel_map& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.basis, out);
+ serialize(item.weights, out);
+ serialize(item.kernel, out);
+ }
+
+ friend void deserialize (
+ empirical_kernel_map& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.basis, in);
+ deserialize(item.weights, in);
+ deserialize(item.kernel, in);
+ }
+
+ private:
+
+ template <typename T>
+ void load_impl(
+ const kernel_type& kernel_,
+ const T& basis_samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(basis_samples.size() > 0 && is_vector(basis_samples),
+ "\tvoid empirical_kernel_map::load(kernel,basis_samples)"
+ << "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
+ << "\n\t basis_samples.size(): " << basis_samples.size()
+ << "\n\t is_vector(basis_samples): " << is_vector(basis_samples)
+ << "\n\t this: " << this
+ );
+
+ // clear out the weights before we begin. This way if an exception throws
+ // this object will already be in the right state.
+ weights.set_size(0,0);
+ kernel = kernel_;
+ basis.clear();
+ basis.reserve(basis_samples.size());
+
+ // find out the value of the largest norm of the elements in basis_samples.
+ const scalar_type max_norm = max(diag(kernel_matrix(kernel, basis_samples)));
+ // we will consider anything less than or equal to this number to be 0
+ const scalar_type eps = max_norm*100*std::numeric_limits<scalar_type>::epsilon();
+
+ // Copy all the basis_samples into basis but make sure we don't copy any samples
+ // that have length 0
+ for (long i = 0; i < basis_samples.size(); ++i)
+ {
+ const scalar_type norm = kernel(basis_samples(i), basis_samples(i));
+ if (norm > eps)
+ {
+ basis.push_back(basis_samples(i));
+ }
+ }
+
+ if (basis.size() == 0)
+ {
+ clear();
+ throw empirical_kernel_map_error("All basis_samples given to empirical_kernel_map::load() were zero vectors");
+ }
+
+ matrix<scalar_type,0,0,mem_manager_type> K(kernel_matrix(kernel, basis)), U,W,V;
+
+ if (svd2(false,true,K,U,W,V))
+ {
+ clear();
+ throw empirical_kernel_map_error("While loading empirical_kernel_map with data, SVD failed to converge.");
+ }
+
+
+ // now count how many elements of W are non-zero
+ const long num_not_zero = static_cast<long>(sum(W>eps));
+
+ // Really, this should never happen. But I'm checking for good measure.
+ if (num_not_zero == 0)
+ {
+ clear();
+ throw empirical_kernel_map_error("While loading empirical_kernel_map with data, SVD failed");
+ }
+
+ weights.set_size(num_not_zero, basis.size());
+
+ // now fill the weights matrix with the output of the SVD
+ long counter = 0;
+ for (long i =0; i < W.size(); ++i)
+ {
+ double val = W(i);
+ if (val > eps)
+ {
+ val = std::sqrt(val);
+ set_rowm(weights,counter) = rowm(trans(V),i)/val;
+ ++counter;
+ }
+ }
+
+ }
+
+
+ std::vector<sample_type> basis;
+ matrix<scalar_type,0,0,mem_manager_type> weights;
+ kernel_type kernel;
+
+ // These members don't contribute to the logical state of this object. They are
+ // just here so that they don't have to be reallocated every time the project() function
+ // is called.
+ mutable matrix<scalar_type,0,1,mem_manager_type> temp1, temp2;
+
+ };
+
+ template <typename kernel_type>
+ void swap (
+ empirical_kernel_map<kernel_type>& a,
+ empirical_kernel_map<kernel_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EMPIRICAL_KERNEl_MAP_H_
+
diff --git a/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h b/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h
new file mode 100644
index 000000000..8fc413447
--- /dev/null
+++ b/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h
@@ -0,0 +1,430 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_
+#ifdef DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_
+
+#include <vector>
+#include "../matrix.h"
+#include "kernel_abstract.h"
+#include "function_abstract.h"
+#include "linearly_independent_subset_finder_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename EXP
+ >
+ const decision_function<kernel_type> convert_to_decision_function (
+ const projection_function<kernel_type>& project_funct,
+ const matrix_exp<EXP>& vect
+ );
+ /*!
+ requires
+ - is_vector(vect) == true
+ - vect.size() == project_funct.out_vector_size()
+ - project_funct.out_vector_size() > 0
+ - project_funct.weights.nc() == project_funct.basis_vectors.size()
+ ensures
+ - This function interprets the given vector as a point in the kernel feature space defined
+ by the given projection function. The return value of this function is a decision
+ function, DF, that represents the given vector in the following sense:
+ - for all possible sample_type objects, S, it is the case that DF(S) == dot(project_funct(S), vect)
+ (i.e. the returned decision function computes dot products, in kernel feature space,
+ between vect and any argument you give it. Note also that this equality is exact, even
+ for sample_type objects not in the span of the basis_vectors.)
+ - DF.kernel_function == project_funct.kernel_function
+ - DF.b == 0
+ - DF.basis_vectors == project_funct.basis_vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ class empirical_kernel_map
+ {
+ /*!
+ REQUIREMENTS ON kern_type
+ - must be a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - out_vector_size() == 0
+ - basis_size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a map from objects of sample_type (the kind of object
+ a kernel function operates on) to finite dimensional column vectors which
+ represent points in the kernel feature space defined by whatever kernel
+ is used with this object.
+
+ To use the empirical_kernel_map you supply it with a particular kernel and a set of
+ basis samples. After that you can present it with new samples and it will project
+ them into the part of kernel feature space spanned by your basis samples.
+
+ This means the empirical_kernel_map is a tool you can use to very easily kernelize
+ any algorithm that operates on column vectors. All you have to do is select a
+ set of basis samples and then use the empirical_kernel_map to project all your
+ data points into the part of kernel feature space spanned by those basis samples.
+ Then just run your normal algorithm on the output vectors and it will be effectively
+ kernelized.
+
+ Regarding methods to select a set of basis samples, if you are working with only a
+ few thousand samples then you can just use all of them as basis samples.
+ Alternatively, the linearly_independent_subset_finder often works well for
+ selecting a basis set. I also find that picking a random subset typically works
+ well.
+
+
+ The empirical kernel map is something that has been around in the kernel methods
+ literature for a long time but is seemingly not well known. Anyway, one of the
+ best books on the subject is the following:
+ Learning with Kernels: Support Vector Machines, Regularization, Optimization,
+ and Beyond by Bernhard Schlkopf, Alexander J. Smola
+ The authors discuss the empirical kernel map as well as many other interesting
+ topics.
+ !*/
+
+ public:
+
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ struct empirical_kernel_map_error : public error;
+ /*!
+ This is an exception class used to indicate a failure to create a
+ kernel map from data given by the user.
+ !*/
+
+ empirical_kernel_map (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - this object has its initial value
+ !*/
+
+ template <typename T>
+ void load(
+ const kernel_type& kernel,
+ const T& basis_samples
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix type or something convertible to a matrix via mat()
+ (e.g. a std::vector)
+ - is_vector(basis_samples) == true
+ - basis_samples.size() > 0
+ - kernel must be capable of operating on the elements of basis_samples. That is,
+ expressions such as kernel(basis_samples(0), basis_samples(0)) should make sense.
+ ensures
+ - 0 < #out_vector_size() <= basis_samples.size()
+ - #basis_size() == basis_samples.size()
+ - #get_kernel() == kernel
+ - This function constructs a map between normal sample_type objects and the
+ subspace of the kernel feature space defined by the given kernel and the
+ given set of basis samples. So after this function has been called you
+ will be able to project sample_type objects into kernel feature space
+ and obtain the resulting vector as a regular column matrix.
+ - The basis samples are loaded into this object in the order in which they
+ are stored in basis_samples. That is:
+ - for all valid i: (*this)[i] == basis_samples(i)
+ throws
+ - empirical_kernel_map_error
+ This exception is thrown if we are unable to create a kernel map.
+ If this happens then this object will revert back to its initial value.
+ !*/
+
+ void load(
+ const linearly_independent_subset_finder<kernel_type>& lisf
+ );
+ /*!
+ ensures
+ - #out_vector_size() == lisf.dictionary_size()
+ - #basis_size() == lisf.dictionary_size()
+ - #get_kernel() == lisf.get_kernel()
+ - Uses the dictionary vectors from lisf as a basis set. Thus, this function
+ constructs a map between normal sample_type objects and the subspace of
+ the kernel feature space defined by the given kernel and the given set
+ of basis samples. So after this function has been called you will be
+ able to project sample_type objects into kernel feature space and obtain
+ the resulting vector as a regular column matrix.
+ - The basis samples are loaded into this object in the order in which they
+ are stored in lisf. That is:
+ - for all valid i: (*this)[i] == lisf[i]
+ throws
+ - empirical_kernel_map_error
+ This exception is thrown if we are unable to create a kernel map.
+ E.g. if the lisf.size() == 0.
+ If this happens then this object will revert back to its initial value.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ requires
+ - out_vector_size() != 0
+ ensures
+ - returns a copy of the kernel used by this object
+ !*/
+
+ long out_vector_size (
+ ) const;
+ /*!
+ ensures
+ - if (this object has been loaded with basis samples) then
+ - returns the dimensionality of the vectors output by the project() function.
+ - else
+ - returns 0
+ !*/
+
+ unsigned long basis_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of basis vectors in projection_functions created
+ by this object. This is also equal to the number of basis vectors
+ given to the load() function.
+ !*/
+
+ const sample_type& operator[] (
+ unsigned long idx
+ ) const;
+ /*!
+ requires
+ - idx < basis_size()
+ ensures
+ - returns a const reference to the idx'th basis vector contained inside
+ this object.
+ !*/
+
+ const matrix<scalar_type,0,1,mem_manager_type>& project (
+ const sample_type& sample
+ ) const;
+ /*!
+ requires
+ - out_vector_size() != 0
+ ensures
+ - takes the given sample and projects it into the kernel feature space
+ of out_vector_size() dimensions defined by this kernel map and
+ returns the resulting vector.
+ - in more precise terms, this function returns a vector such that:
+ - The returned vector will contain out_vector_size() elements.
+ - for any sample_type object S, the following equality is approximately true:
+ - get_kernel()(sample,S) == dot(project(sample), project(S)).
+ - The approximation error in the above equality will be zero (within rounding error)
+ if both sample_type objects involved are within the span of the set of basis
+ samples given to the load() function. If they are not then there will be some
+ approximation error. Note that all the basis samples are always within their
+ own span. So the equality is always exact for the samples given to the load()
+ function.
+ !*/
+
+ const matrix<scalar_type,0,1,mem_manager_type>& project (
+ const sample_type& samp,
+ scalar_type& projection_error
+ ) const;
+ /*!
+ requires
+ - out_vector_size() != 0
+ ensures
+ - This function returns project(samp)
+ (i.e. it returns the same thing as the above project() function)
+ - #projection_error == the square of the distance between the point samp
+ gets projected onto and samp's true image in kernel feature space.
+ That is, this value is equal to:
+ pow(convert_to_distance_function(project(samp))(samp),2)
+ !*/
+
+ template <typename EXP>
+ const decision_function<kernel_type> convert_to_decision_function (
+ const matrix_exp<EXP>& vect
+ ) const;
+ /*!
+ requires
+ - is_vector(vect) == true
+ - vect.size() == out_vector_size()
+ - out_vector_size() != 0
+ ensures
+ - This function interprets the given vector as a point in the kernel feature space defined
+ by this empirical_kernel_map. The return value of this function is a decision
+ function, DF, that represents the given vector in the following sense:
+ - for all possible sample_type objects, S, it is the case that DF(S) == dot(project(S), vect)
+ (i.e. the returned decision function computes dot products, in kernel feature space,
+ between vect and any argument you give it. Note also that this equality is exact, even
+ for sample_type objects not in the span of the basis samples.)
+ - DF.kernel_function == get_kernel()
+ - DF.b == 0
+ - DF.basis_vectors == these will be the basis samples given to the previous call to load(). Note
+ that it is possible for there to be fewer basis_vectors than basis samples given to load().
+ - DF.basis_vectors.size() == basis_size()
+ !*/
+
+ template <typename EXP>
+ const distance_function<kernel_type> convert_to_distance_function (
+ const matrix_exp<EXP>& vect
+ ) const
+ /*!
+ requires
+ - is_vector(vect) == true
+ - vect.size() == out_vector_size()
+ - out_vector_size() != 0
+ ensures
+ - This function interprets the given vector as a point in the kernel feature space defined
+ by this empirical_kernel_map. The return value of this function is a distance
+ function, DF, that represents the given vector in the following sense:
+ - for any sample_type object S, the following equality is approximately true:
+ - DF(S) == length(project(S) - vect)
+ (i.e. the returned distance function computes distances, in kernel feature space,
+ between vect and any argument you give it. )
+ - The approximation error in the above equality will be zero (within rounding error)
+ if S is within the span of the set of basis samples given to the load() function.
+ If it is not then there will be some approximation error. Note that all the basis
+ samples are always within their own span. So the equality is always exact for the
+ samples given to the load() function. Note further that the distance computed
+ by DF(S) is always the correct distance in kernel feature space between vect and
+ the true projection of S. That is, the above equality is approximate only because
+ of potential error in the project() function, not in DF(S).
+ - DF.kernel_function == get_kernel()
+ - DF.b == dot(vect,vect)
+ - DF.basis_vectors == these will be the basis samples given to the previous call to load(). Note
+ that it is possible for there to be fewer basis_vectors than basis samples given to load().
+ - DF.basis_vectors.size() == basis_size()
+ !*/
+
+ const projection_function<kernel_type> get_projection_function (
+ ) const;
+ /*!
+ requires
+ - out_vector_size() != 0
+ ensures
+ - returns a projection_function, PF, that computes the same projection as project().
+ That is, calling PF() on any sample will produce the same output vector as calling
+ this->project() on that sample.
+ - PF.basis_vectors.size() == basis_size()
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type> get_transformation_to (
+ const empirical_kernel_map& target
+ ) const;
+ /*!
+ requires
+ - get_kernel() == target.get_kernel()
+ - out_vector_size() != 0
+ - target.out_vector_size() != 0
+ ensures
+ - A point in the kernel feature space defined by the kernel get_kernel() typically
+ has different representations with respect to different empirical_kernel_maps.
+ This function lets you obtain a transformation matrix that will allow you
+ to project between these different representations. That is, this function returns
+ a matrix M with the following properties:
+ - M maps vectors represented according to *this into the representation used by target.
+ - M.nr() == target.out_vector_size()
+ - M.nc() == this->out_vector_size()
+ - Let V be a vector of this->out_vector_size() length. Then define two distance_functions
+ DF1 = this->convert_to_distance_function(V)
+ DF2 = target.convert_to_distance_function(M*V)
+
+ Then DF1(DF2) == 0 // i.e. the distance between these two points should be 0
+
+ That is, DF1 and DF2 both represent the same point in kernel feature space. Note
+ that the above equality is only approximate. If the vector V represents a point in
+ kernel space that isn't in the span of the basis samples used by target then the
+ equality is approximate. However, if it is in their span then the equality will
+ be exact. For example, if target's basis samples are a superset of the basis samples
+ used by *this then the equality will always be exact (within rounding error).
+ !*/
+
+ void get_transformation_to (
+ const empirical_kernel_map& target,
+ matrix<scalar_type, 0, 0, mem_manager_type>& tmat,
+ projection_function<kernel_type>& partial_projection
+ ) const;
+ /*!
+ requires
+ - get_kernel() == target.get_kernel()
+ - out_vector_size() != 0
+ - target.out_vector_size() != 0
+ - basis_size() < target.basis_size()
+ - for all i < basis_size(): (*this)[i] == target[i]
+ i.e. target must contain a superset of the basis vectors contained in *this. Moreover,
+ it must contain them in the same order.
+ ensures
+ - The single argument version of get_transformation_to() allows you to project
+ vectors from one empirical_kernel_map representation to another. This version
+ provides a somewhat different capability. Assuming target's basis vectors form a
+ superset of *this's basis vectors then this form of get_transformation_to() allows
+ you to reuse a vector from *this ekm to speed up the projection performed by target.
+ The defining relation is given below.
+ - for any sample S:
+ - target.project(S) == #tmat * this->project(S) + #partial_projection(S)
+ (this is always true to within rounding error for any S)
+ - #partial_projection.basis_vectors.size() == target.basis_vectors.size() - this->basis_vectors.size()
+ - #tmat.nr() == target.out_vector_size()
+ - #tmat.nc() == this->out_vector_size()
+ !*/
+
+ void swap (
+ empirical_kernel_map& item
+ );
+ /*!
+ ensures
+ - swaps the state of *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ void swap (
+ empirical_kernel_map<kernel_type>& a,
+ empirical_kernel_map<kernel_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void serialize (
+ const empirical_kernel_map<kernel_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for empirical_kernel_map objects
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void deserialize (
+ empirical_kernel_map<kernel_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for empirical_kernel_map objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_
+
diff --git a/ml/dlib/dlib/svm/feature_ranking.h b/ml/dlib/dlib/svm/feature_ranking.h
new file mode 100644
index 000000000..f6324fe3d
--- /dev/null
+++ b/ml/dlib/dlib/svm/feature_ranking.h
@@ -0,0 +1,477 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KERNEL_FEATURE_RANKINg_H_
+#define DLIB_KERNEL_FEATURE_RANKINg_H_
+
+#include <vector>
+#include <limits>
+
+#include "feature_ranking_abstract.h"
+#include "kcentroid.h"
+#include "../optimization.h"
+#include "../statistics.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features_impl (
+ const kcentroid<kernel_type>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels
+ )
+ {
+ /*
+ This function ranks features by doing recursive feature elimination
+
+ */
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mm;
+
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(samples, labels) == true,
+ "\tmatrix rank_features()"
+ << "\n\t you have given invalid arguments to this function"
+ );
+
+ matrix<scalar_type,0,2,mm> results(samples(0).nr(), 2);
+ matrix<scalar_type,sample_matrix_type::type::NR,1,mm> mask(samples(0).nr());
+ set_all_elements(mask,1);
+
+ // figure out what the separation is between the two centroids when all the features are
+ // present.
+ scalar_type first_separation;
+ {
+ kcentroid<kernel_type> c1(kc);
+ kcentroid<kernel_type> c2(kc);
+ // find the centers of each class
+ for (long s = 0; s < samples.size(); ++s)
+ {
+ if (labels(s) < 0)
+ {
+ c1.train(samples(s));
+ }
+ else
+ {
+ c2.train(samples(s));
+ }
+
+ }
+ first_separation = c1(c2);
+ }
+
+
+ using namespace std;
+
+ for (long i = results.nr()-1; i >= 0; --i)
+ {
+ long worst_feature_idx = 0;
+ scalar_type worst_feature_score = -std::numeric_limits<scalar_type>::infinity();
+
+ // figure out which feature to remove next
+ for (long j = 0; j < mask.size(); ++j)
+ {
+ // skip features we have already removed
+ if (mask(j) == 0)
+ continue;
+
+ kcentroid<kernel_type> c1(kc);
+ kcentroid<kernel_type> c2(kc);
+
+ // temporarily remove this feature from the working set of features
+ mask(j) = 0;
+
+ // find the centers of each class
+ for (long s = 0; s < samples.size(); ++s)
+ {
+ if (labels(s) < 0)
+ {
+ c1.train(pointwise_multiply(samples(s),mask));
+ }
+ else
+ {
+ c2.train(pointwise_multiply(samples(s),mask));
+ }
+
+ }
+
+ // find the distance between the two centroids and use that
+ // as the score
+ const double score = c1(c2);
+
+ if (score > worst_feature_score)
+ {
+ worst_feature_score = score;
+ worst_feature_idx = j;
+ }
+
+ // add this feature back to the working set of features
+ mask(j) = 1;
+
+ }
+
+ // now that we know what the next worst feature is record it
+ mask(worst_feature_idx) = 0;
+ results(i,0) = worst_feature_idx;
+ results(i,1) = worst_feature_score;
+ }
+
+ // now normalize the results
+ const scalar_type max_separation = std::max(max(colm(results,1)), first_separation);
+ set_colm(results,1) = colm(results,1)/max_separation;
+ for (long r = 0; r < results.nr()-1; ++r)
+ {
+ results(r,1) = results(r+1,1);
+ }
+ results(results.nr()-1,1) = first_separation/max_separation;
+
+ return results;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features (
+ const kcentroid<kernel_type>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels
+ )
+ {
+ return rank_features_impl(kc, mat(samples), mat(labels));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features_impl (
+ const kcentroid<kernel_type>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ const long num_features
+ )
+ {
+ /*
+ This function ranks features by doing recursive feature addition
+
+ */
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mm;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(samples, labels) == true,
+ "\tmatrix rank_features()"
+ << "\n\t you have given invalid arguments to this function"
+ );
+ DLIB_ASSERT(0 < num_features && num_features <= samples(0).nr(),
+ "\tmatrix rank_features()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t num_features: " << num_features
+ << "\n\t samples(0).nr(): " << samples(0).nr()
+ );
+
+ matrix<scalar_type,0,2,mm> results(num_features, 2);
+ matrix<scalar_type,sample_matrix_type::type::NR,1,mm> mask(samples(0).nr());
+ set_all_elements(mask,0);
+
+ using namespace std;
+
+ for (long i = 0; i < results.nr(); ++i)
+ {
+ long best_feature_idx = 0;
+ scalar_type best_feature_score = -std::numeric_limits<scalar_type>::infinity();
+
+ // figure out which feature to add next
+ for (long j = 0; j < mask.size(); ++j)
+ {
+ // skip features we have already added
+ if (mask(j) == 1)
+ continue;
+
+ kcentroid<kernel_type> c1(kc);
+ kcentroid<kernel_type> c2(kc);
+
+ // temporarily add this feature to the working set of features
+ mask(j) = 1;
+
+ // find the centers of each class
+ for (long s = 0; s < samples.size(); ++s)
+ {
+ if (labels(s) < 0)
+ {
+ c1.train(pointwise_multiply(samples(s),mask));
+ }
+ else
+ {
+ c2.train(pointwise_multiply(samples(s),mask));
+ }
+
+ }
+
+ // find the distance between the two centroids and use that
+ // as the score
+ const double score = c1(c2);
+
+ if (score > best_feature_score)
+ {
+ best_feature_score = score;
+ best_feature_idx = j;
+ }
+
+ // take this feature back out of the working set of features
+ mask(j) = 0;
+
+ }
+
+ // now that we know what the next best feature is record it
+ mask(best_feature_idx) = 1;
+ results(i,0) = best_feature_idx;
+ results(i,1) = best_feature_score;
+ }
+
+ // now normalize the results
+ set_colm(results,1) = colm(results,1)/max(colm(results,1));
+
+ return results;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features (
+ const kcentroid<kernel_type>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ const long num_features
+ )
+ {
+ if (mat(samples).nr() > 0 && num_features == mat(samples)(0).nr())
+ {
+ // if we are going to rank them all then might as well do the recursive feature elimination version
+ return rank_features_impl(kc, mat(samples), mat(labels));
+ }
+ else
+ {
+ return rank_features_impl(kc, mat(samples), mat(labels), num_features);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace rank_features_helpers
+ {
+ template <
+ typename K,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ typename K::scalar_type centroid_gap (
+ const kcentroid<K>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels
+ )
+ {
+ kcentroid<K> kc1(kc);
+ kcentroid<K> kc2(kc);
+
+ // toss all the samples into our kcentroids
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ if (labels(i) > 0)
+ kc1.train(samples(i));
+ else
+ kc2.train(samples(i));
+ }
+
+ // now return the separation between the mean of these two centroids
+ return kc1(kc2);
+ }
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ class test
+ {
+ typedef typename sample_matrix_type::type sample_type;
+ typedef typename sample_type::type scalar_type;
+ typedef typename sample_type::mem_manager_type mem_manager_type;
+
+ public:
+ test (
+ const sample_matrix_type& samples_,
+ const label_matrix_type& labels_,
+ unsigned long num_sv_,
+ bool verbose_
+ ) : samples(samples_), labels(labels_), num_sv(num_sv_), verbose(verbose_)
+ {
+ }
+
+ double operator() (
+ double gamma
+ ) const
+ {
+ using namespace std;
+
+ // we are doing the optimization in log space so don't forget to convert back to normal space
+ gamma = std::exp(gamma);
+
+ typedef radial_basis_kernel<sample_type> kernel_type;
+ // Make a kcentroid and find out what the gap is at the current gamma. Try to pick a reasonable
+ // tolerance.
+ const double tolerance = std::min(gamma*0.01, 0.01);
+ const kernel_type kern(gamma);
+ kcentroid<kernel_type> kc(kern, tolerance, num_sv);
+ scalar_type temp = centroid_gap(kc, samples, labels);
+
+ if (verbose)
+ {
+ cout << "\rChecking goodness of gamma = " << gamma << ". Goodness = "
+ << temp << " " << flush;
+ }
+ return temp;
+ }
+
+ const sample_matrix_type& samples;
+ const label_matrix_type& labels;
+ unsigned long num_sv;
+ bool verbose;
+
+ };
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ double find_gamma_with_big_centroid_gap_impl (
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ double initial_gamma,
+ unsigned long num_sv,
+ bool verbose
+ )
+ {
+ using namespace std;
+
+ if (verbose)
+ {
+ cout << endl;
+ }
+
+ test<sample_matrix_type, label_matrix_type> funct(samples, labels, num_sv, verbose);
+ double best_gamma = std::log(initial_gamma);
+ double goodness = find_max_single_variable(funct, best_gamma, -15, 15, 1e-3, 100);
+
+ if (verbose)
+ {
+ cout << "\rBest gamma = " << std::exp(best_gamma) << ". Goodness = "
+ << goodness << " " << endl;
+ }
+
+ return std::exp(best_gamma);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ double find_gamma_with_big_centroid_gap (
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ double initial_gamma = 0.1,
+ unsigned long num_sv = 40
+ )
+ {
+ DLIB_ASSERT(initial_gamma > 0 && num_sv > 0 && is_binary_classification_problem(samples, labels),
+ "\t double find_gamma_with_big_centroid_gap()"
+ << "\n\t initial_gamma: " << initial_gamma
+ << "\n\t num_sv: " << num_sv
+ << "\n\t is_binary_classification_problem(): " << is_binary_classification_problem(samples, labels)
+ );
+
+ return rank_features_helpers::find_gamma_with_big_centroid_gap_impl(mat(samples),
+ mat(labels),
+ initial_gamma,
+ num_sv,
+ false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ double verbose_find_gamma_with_big_centroid_gap (
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ double initial_gamma = 0.1,
+ unsigned long num_sv = 40
+ )
+ {
+ DLIB_ASSERT(initial_gamma > 0 && num_sv > 0 && is_binary_classification_problem(samples, labels),
+ "\t double verbose_find_gamma_with_big_centroid_gap()"
+ << "\n\t initial_gamma: " << initial_gamma
+ << "\n\t num_sv: " << num_sv
+ << "\n\t is_binary_classification_problem(): " << is_binary_classification_problem(samples, labels)
+ );
+
+ return rank_features_helpers::find_gamma_with_big_centroid_gap_impl(mat(samples),
+ mat(labels),
+ initial_gamma,
+ num_sv,
+ true);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ double compute_mean_squared_distance (
+ const vector_type& samples
+ )
+ {
+ running_stats<double> rs;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (unsigned long j = i+1; j < samples.size(); ++j)
+ {
+ rs.add(length_squared(samples[i] - samples[j]));
+ }
+ }
+
+ return rs.mean();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KERNEL_FEATURE_RANKINg_H_
+
+
diff --git a/ml/dlib/dlib/svm/feature_ranking_abstract.h b/ml/dlib/dlib/svm/feature_ranking_abstract.h
new file mode 100644
index 000000000..5a6fd3bb9
--- /dev/null
+++ b/ml/dlib/dlib/svm/feature_ranking_abstract.h
@@ -0,0 +1,136 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_
+#ifdef DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_
+
+#include <vector>
+#include <limits>
+
+#include "svm_abstract.h"
+#include "kcentroid_abstract.h"
+#include "../is_kind.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ matrix<typename kernel_type::scalar_type> rank_features (
+ const kcentroid<kernel_type>& kc,
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ const long num_features = samples(0).nr()
+ );
+ /*!
+ requires
+ - sample_matrix_type == a matrix or something convertible to a matrix via mat()
+ - label_matrix_type == a matrix or something convertible to a matrix via mat()
+ - is_binary_classification_problem(samples, labels) == true
+ - kc.train(samples(0)) must be a valid expression. This means that
+ kc must use a kernel type that is capable of operating on the
+ contents of the samples matrix
+ - 0 < num_features <= samples(0).nr()
+ ensures
+ - Let Class1 denote the centroid of all the samples with labels that are < 0
+ - Let Class2 denote the centroid of all the samples with labels that are > 0
+ - finds a ranking of the features where the best features come first. This
+ function does this by computing the distance between the centroid of the Class1
+ samples and the Class2 samples in kernel defined feature space.
+ Good features are then ones that result in the biggest separation between
+ the two centroids of Class1 and Class2.
+ - Uses the kc object to compute the centroids of the two classes
+ - returns a ranking matrix R where:
+ - R.nr() == num_features
+ - r.nc() == 2
+ - R(i,0) == the index of the ith best feature according to our ranking.
+ (e.g. samples(n)(R(0,0)) is the best feature from sample(n) and
+ samples(n)(R(1,0)) is the second best, samples(n)(R(2,0)) the
+ third best and so on)
+ - R(i,1) == a number that indicates how much separation exists between
+ the two centroids when features 0 through i are used.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ double find_gamma_with_big_centroid_gap (
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ double initial_gamma = 0.1,
+ unsigned long num_sv = 40
+ );
+ /*!
+ requires
+ - initial_gamma > 0
+ - num_sv > 0
+ - is_binary_classification_problem(samples, labels) == true
+ ensures
+ - This is a function that tries to pick a reasonable default value for the gamma
+ parameter of the radial_basis_kernel. It picks the parameter that gives the
+ largest separation between the centroids, in kernel feature space, of two classes
+ of data. It does this using the kcentroid object and it sets the kcentroid up
+ to use num_sv dictionary vectors.
+ - This function does a search for the best gamma and the search starts with
+ the value given by initial_gamma. Better initial guesses will give
+ better results since the routine may get stuck in a local minima.
+ - returns the value of gamma that results in the largest separation.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_matrix_type,
+ typename label_matrix_type
+ >
+ double verbose_find_gamma_with_big_centroid_gap (
+ const sample_matrix_type& samples,
+ const label_matrix_type& labels,
+ double initial_gamma = 0.1,
+ unsigned long num_sv = 40
+ );
+ /*!
+ requires
+ - initial_gamma > 0
+ - num_sv > 0
+ - is_binary_classification_problem(samples, labels) == true
+ ensures
+ - This function does the same exact thing as the above find_gamma_with_big_centroid_gap()
+ except that it is also verbose in the sense that it will print status messages to
+ standard out during its processing.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ double compute_mean_squared_distance (
+ const vector_type& samples
+ );
+ /*!
+ requires
+ - vector_type is something with an interface compatible with std::vector.
+ Additionally, it must in turn contain dlib::matrix types which contain
+ scalars such as float or double values.
+ - for all valid i: is_vector(samples[i]) == true
+ ensures
+ - computes the average value of the squares of all the pairwise
+ distances between every element of samples.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_
+
+
+
diff --git a/ml/dlib/dlib/svm/function.h b/ml/dlib/dlib/svm/function.h
new file mode 100644
index 000000000..f5a62a9f7
--- /dev/null
+++ b/ml/dlib/dlib/svm/function.h
@@ -0,0 +1,882 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_FUNCTION
+#define DLIB_SVm_FUNCTION
+
+#include "function_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../rand.h"
+#include "../statistics.h"
+#include "kernel_matrix.h"
+#include "kernel.h"
+#include "sparse_kernel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct decision_function
+ {
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+
+ scalar_vector_type alpha;
+ scalar_type b;
+ K kernel_function;
+ sample_vector_type basis_vectors;
+
+ decision_function (
+ ) : b(0), kernel_function(K()) {}
+
+ decision_function (
+ const decision_function& d
+ ) :
+ alpha(d.alpha),
+ b(d.b),
+ kernel_function(d.kernel_function),
+ basis_vectors(d.basis_vectors)
+ {}
+
+ decision_function (
+ const scalar_vector_type& alpha_,
+ const scalar_type& b_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) :
+ alpha(alpha_),
+ b(b_),
+ kernel_function(kernel_function_),
+ basis_vectors(basis_vectors_)
+ {}
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ {
+ result_type temp = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ temp += alpha(i) * kernel_function(x,basis_vectors(i));
+
+ return temp - b;
+ }
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const decision_function<K>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.alpha, out);
+ serialize(item.b, out);
+ serialize(item.kernel_function, out);
+ serialize(item.basis_vectors, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type decision_function");
+ }
+ }
+
+ template <
+ typename K
+ >
+ void deserialize (
+ decision_function<K>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.alpha, in);
+ deserialize(item.b, in);
+ deserialize(item.kernel_function, in);
+ deserialize(item.basis_vectors, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type decision_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type
+ >
+ struct probabilistic_function
+ {
+ typedef typename function_type::scalar_type scalar_type;
+ typedef typename function_type::result_type result_type;
+ typedef typename function_type::sample_type sample_type;
+ typedef typename function_type::mem_manager_type mem_manager_type;
+
+ scalar_type alpha;
+ scalar_type beta;
+ function_type decision_funct;
+
+ probabilistic_function (
+ ) : alpha(0), beta(0), decision_funct(function_type()) {}
+
+ probabilistic_function (
+ const probabilistic_function& d
+ ) :
+ alpha(d.alpha),
+ beta(d.beta),
+ decision_funct(d.decision_funct)
+ {}
+
+ probabilistic_function (
+ const scalar_type a_,
+ const scalar_type b_,
+ const function_type& decision_funct_
+ ) :
+ alpha(a_),
+ beta(b_),
+ decision_funct(decision_funct_)
+ {}
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ {
+ result_type f = decision_funct(x);
+ return 1/(1 + std::exp(alpha*f + beta));
+ }
+ };
+
+ template <
+ typename function_type
+ >
+ void serialize (
+ const probabilistic_function<function_type>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.alpha, out);
+ serialize(item.beta, out);
+ serialize(item.decision_funct, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type probabilistic_function");
+ }
+ }
+
+ template <
+ typename function_type
+ >
+ void deserialize (
+ probabilistic_function<function_type>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.alpha, in);
+ deserialize(item.beta, in);
+ deserialize(item.decision_funct, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type probabilistic_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct probabilistic_decision_function
+ {
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ scalar_type alpha;
+ scalar_type beta;
+ decision_function<K> decision_funct;
+
+ probabilistic_decision_function (
+ ) : alpha(0), beta(0), decision_funct(decision_function<K>()) {}
+
+ probabilistic_decision_function (
+ const probabilistic_function<decision_function<K> >& d
+ ) :
+ alpha(d.alpha),
+ beta(d.beta),
+ decision_funct(d.decision_funct)
+ {}
+
+ probabilistic_decision_function (
+ const probabilistic_decision_function& d
+ ) :
+ alpha(d.alpha),
+ beta(d.beta),
+ decision_funct(d.decision_funct)
+ {}
+
+ probabilistic_decision_function (
+ const scalar_type a_,
+ const scalar_type b_,
+ const decision_function<K>& decision_funct_
+ ) :
+ alpha(a_),
+ beta(b_),
+ decision_funct(decision_funct_)
+ {}
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ {
+ result_type f = decision_funct(x);
+ return 1/(1 + std::exp(alpha*f + beta));
+ }
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const probabilistic_decision_function<K>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.alpha, out);
+ serialize(item.beta, out);
+ serialize(item.decision_funct, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type probabilistic_decision_function");
+ }
+ }
+
+ template <
+ typename K
+ >
+ void deserialize (
+ probabilistic_decision_function<K>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.alpha, in);
+ deserialize(item.beta, in);
+ deserialize(item.decision_funct, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type probabilistic_decision_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class distance_function
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+
+
+ distance_function (
+ ) : b(0), kernel_function(K()) {}
+
+ explicit distance_function (
+ const kernel_type& kern
+ ) : b(0), kernel_function(kern) {}
+
+ distance_function (
+ const kernel_type& kern,
+ const sample_type& samp
+ ) :
+ alpha(ones_matrix<scalar_type>(1,1)),
+ b(kern(samp,samp)),
+ kernel_function(kern)
+ {
+ basis_vectors.set_size(1,1);
+ basis_vectors(0) = samp;
+ }
+
+ distance_function (
+ const decision_function<K>& f
+ ) :
+ alpha(f.alpha),
+ b(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha),
+ kernel_function(f.kernel_function),
+ basis_vectors(f.basis_vectors)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(),
+ "\t distance_function(f)"
+ << "\n\t The supplied decision_function is invalid."
+ << "\n\t f.alpha.size(): " << f.alpha.size()
+ << "\n\t f.basis_vectors.size(): " << f.basis_vectors.size()
+ );
+ }
+
+ distance_function (
+ const distance_function& d
+ ) :
+ alpha(d.alpha),
+ b(d.b),
+ kernel_function(d.kernel_function),
+ basis_vectors(d.basis_vectors)
+ {
+ }
+
+ distance_function (
+ const scalar_vector_type& alpha_,
+ const scalar_type& b_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) :
+ alpha(alpha_),
+ b(b_),
+ kernel_function(kernel_function_),
+ basis_vectors(basis_vectors_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
+ "\t distance_function()"
+ << "\n\t The supplied arguments are invalid."
+ << "\n\t alpha_.size(): " << alpha_.size()
+ << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
+ );
+ }
+
+ distance_function (
+ const scalar_vector_type& alpha_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) :
+ alpha(alpha_),
+ b(trans(alpha)*kernel_matrix(kernel_function_,basis_vectors_)*alpha),
+ kernel_function(kernel_function_),
+ basis_vectors(basis_vectors_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
+ "\t distance_function()"
+ << "\n\t The supplied arguments are invalid."
+ << "\n\t alpha_.size(): " << alpha_.size()
+ << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
+ );
+ }
+
+ const scalar_vector_type& get_alpha (
+ ) const { return alpha; }
+
+ const scalar_type& get_squared_norm (
+ ) const { return b; }
+
+ const K& get_kernel(
+ ) const { return kernel_function; }
+
+ const sample_vector_type& get_basis_vectors (
+ ) const { return basis_vectors; }
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ {
+ result_type temp = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ temp += alpha(i) * kernel_function(x,basis_vectors(i));
+
+ temp = b + kernel_function(x,x) - 2*temp;
+ if (temp > 0)
+ return std::sqrt(temp);
+ else
+ return 0;
+ }
+
+ result_type operator() (
+ const distance_function& x
+ ) const
+ {
+ result_type temp = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ for (long j = 0; j < x.alpha.nr(); ++j)
+ temp += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j));
+
+ temp = b + x.b - 2*temp;
+ if (temp > 0)
+ return std::sqrt(temp);
+ else
+ return 0;
+ }
+
+ distance_function operator* (
+ const scalar_type& val
+ ) const
+ {
+ return distance_function(val*alpha,
+ val*val*b,
+ kernel_function,
+ basis_vectors);
+ }
+
+ distance_function operator/ (
+ const scalar_type& val
+ ) const
+ {
+ return distance_function(alpha/val,
+ b/val/val,
+ kernel_function,
+ basis_vectors);
+ }
+
+ distance_function operator+ (
+ const distance_function& rhs
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
+ "\t distance_function distance_function::operator+()"
+ << "\n\t You can only add two distance_functions together if they use the same kernel."
+ );
+
+ if (alpha.size() == 0)
+ return rhs;
+ else if (rhs.alpha.size() == 0)
+ return *this;
+ else
+ return distance_function(join_cols(alpha, rhs.alpha),
+ b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
+ kernel_function,
+ join_cols(basis_vectors, rhs.basis_vectors));
+ }
+
+ distance_function operator- (
+ const distance_function& rhs
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
+ "\t distance_function distance_function::operator-()"
+ << "\n\t You can only subtract two distance_functions if they use the same kernel."
+ );
+
+ if (alpha.size() == 0 && rhs.alpha.size() == 0)
+ return distance_function(kernel_function);
+ else if (alpha.size() != 0 && rhs.alpha.size() == 0)
+ return *this;
+ else if (alpha.size() == 0 && rhs.alpha.size() != 0)
+ return -1*rhs;
+ else
+ return distance_function(join_cols(alpha, -rhs.alpha),
+ b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
+ kernel_function,
+ join_cols(basis_vectors, rhs.basis_vectors));
+ }
+
+ private:
+
+ scalar_vector_type alpha;
+ scalar_type b;
+ K kernel_function;
+ sample_vector_type basis_vectors;
+
+ };
+
+ template <
+ typename K
+ >
+ distance_function<K> operator* (
+ const typename K::scalar_type& val,
+ const distance_function<K>& df
+ ) { return df*val; }
+
+ template <
+ typename K
+ >
+ void serialize (
+ const distance_function<K>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.alpha, out);
+ serialize(item.b, out);
+ serialize(item.kernel_function, out);
+ serialize(item.basis_vectors, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type distance_function");
+ }
+ }
+
+ template <
+ typename K
+ >
+ void deserialize (
+ distance_function<K>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.alpha, in);
+ deserialize(item.b, in);
+ deserialize(item.kernel_function, in);
+ deserialize(item.basis_vectors, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type distance_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type,
+ typename normalizer_type = vector_normalizer<typename function_type::sample_type>
+ >
+ struct normalized_function
+ {
+ typedef typename function_type::result_type result_type;
+ typedef typename function_type::sample_type sample_type;
+ typedef typename function_type::mem_manager_type mem_manager_type;
+
+ normalizer_type normalizer;
+ function_type function;
+
+ normalized_function (
+ ){}
+
+ normalized_function (
+ const normalized_function& f
+ ) :
+ normalizer(f.normalizer),
+ function(f.function)
+ {}
+
+ const std::vector<result_type> get_labels(
+ ) const { return function.get_labels(); }
+
+ unsigned long number_of_classes (
+ ) const { return function.number_of_classes(); }
+
+ normalized_function (
+ const vector_normalizer<sample_type>& normalizer_,
+ const function_type& funct
+ ) : normalizer(normalizer_), function(funct) {}
+
+ result_type operator() (
+ const sample_type& x
+ ) const { return function(normalizer(x)); }
+ };
+
+ template <
+ typename function_type,
+ typename normalizer_type
+ >
+ void serialize (
+ const normalized_function<function_type,normalizer_type>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.normalizer, out);
+ serialize(item.function, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type normalized_function");
+ }
+ }
+
+ template <
+ typename function_type,
+ typename normalizer_type
+ >
+ void deserialize (
+ normalized_function<function_type,normalizer_type>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.normalizer, in);
+ deserialize(item.function, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type normalized_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct projection_function
+ {
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+ typedef scalar_vector_type result_type;
+
+ scalar_matrix_type weights;
+ K kernel_function;
+ sample_vector_type basis_vectors;
+
+ projection_function (
+ ) {}
+
+ projection_function (
+ const projection_function& f
+ ) : weights(f.weights), kernel_function(f.kernel_function), basis_vectors(f.basis_vectors) {}
+
+ projection_function (
+ const scalar_matrix_type& weights_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) : weights(weights_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {}
+
+ long out_vector_size (
+ ) const { return weights.nr(); }
+
+ const result_type& operator() (
+ const sample_type& x
+ ) const
+ {
+ // Run the x sample through all the basis functions we have and then
+ // multiply it by the weights matrix and return the result. Note that
+ // the temp vectors are here to avoid reallocating their memory every
+ // time this function is called.
+ temp1 = kernel_matrix(kernel_function, basis_vectors, x);
+ temp2 = weights*temp1;
+ return temp2;
+ }
+
+ private:
+ mutable result_type temp1, temp2;
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const projection_function<K>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.weights, out);
+ serialize(item.kernel_function, out);
+ serialize(item.basis_vectors, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type projection_function");
+ }
+ }
+
+ template <
+ typename K
+ >
+ void deserialize (
+ projection_function<K>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.weights, in);
+ deserialize(item.kernel_function, in);
+ deserialize(item.basis_vectors, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type projection_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename result_type_ = typename K::scalar_type
+ >
+ struct multiclass_linear_decision_function
+ {
+ typedef result_type_ result_type;
+
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear kernel
+ // to the multiclass_linear_decision_function object. You have to use one of the linear
+ // kernels with this object.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+
+ scalar_matrix_type weights;
+ scalar_vector_type b;
+ std::vector<result_type> labels;
+
+ const std::vector<result_type>& get_labels(
+ ) const { return labels; }
+
+ unsigned long number_of_classes (
+ ) const { return labels.size(); }
+
+ std::pair<result_type, scalar_type> predict (
+ const sample_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weights.size() > 0 &&
+ weights.nr() == (long)number_of_classes() &&
+ weights.nr() == b.size(),
+ "\t pair<result_type,scalar_type> multiclass_linear_decision_function::predict(x)"
+ << "\n\t This object must be properly initialized before you can use it."
+ << "\n\t weights.size(): " << weights.size()
+ << "\n\t weights.nr(): " << weights.nr()
+ << "\n\t number_of_classes(): " << number_of_classes()
+ );
+
+ // Rather than doing something like, best_idx = index_of_max(weights*x-b)
+ // we do the following somewhat more complex thing because this supports
+ // both sparse and dense samples.
+ scalar_type best_val = dot(rowm(weights,0),x) - b(0);
+ unsigned long best_idx = 0;
+
+ for (unsigned long i = 1; i < labels.size(); ++i)
+ {
+ scalar_type temp = dot(rowm(weights,i),x) - b(i);
+ if (temp > best_val)
+ {
+ best_val = temp;
+ best_idx = i;
+ }
+ }
+
+ return std::make_pair(labels[best_idx], best_val);
+ }
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(weights.size() > 0 &&
+ weights.nr() == (long)number_of_classes() &&
+ weights.nr() == b.size(),
+ "\t result_type multiclass_linear_decision_function::operator()(x)"
+ << "\n\t This object must be properly initialized before you can use it."
+ << "\n\t weights.size(): " << weights.size()
+ << "\n\t weights.nr(): " << weights.nr()
+ << "\n\t number_of_classes(): " << number_of_classes()
+ );
+
+ return predict(x).first;
+ }
+ };
+
+ template <
+ typename K,
+ typename result_type_
+ >
+ void serialize (
+ const multiclass_linear_decision_function<K,result_type_>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.weights, out);
+ serialize(item.b, out);
+ serialize(item.labels, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type multiclass_linear_decision_function");
+ }
+ }
+
+ template <
+ typename K,
+ typename result_type_
+ >
+ void deserialize (
+ multiclass_linear_decision_function<K,result_type_>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ deserialize(item.weights, in);
+ deserialize(item.b, in);
+ deserialize(item.labels, in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type multiclass_linear_decision_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_FUNCTION
+
+
diff --git a/ml/dlib/dlib/svm/function_abstract.h b/ml/dlib/dlib/svm/function_abstract.h
new file mode 100644
index 000000000..783a68c50
--- /dev/null
+++ b/ml/dlib/dlib/svm/function_abstract.h
@@ -0,0 +1,997 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_FUNCTION_ABSTRACT_
+#ifdef DLIB_SVm_FUNCTION_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../statistics/statistics_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct decision_function
+ {
+ /*!
+ REQUIREMENTS ON K
+ K must be a kernel function object type as defined at the
+ top of dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a classification or regression function that was
+ learned by a kernel based learning algorithm. Therefore, it is a function
+ object that takes a sample object and returns a scalar value.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call operator() on this object from multiple threads so
+ long as the kernel, K, is also threadsafe. This is because operator()
+ is a read-only operation. However, any operation that modifies a
+ decision_function is not threadsafe.
+ !*/
+
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+
+ scalar_vector_type alpha;
+ scalar_type b;
+ K kernel_function;
+ sample_vector_type basis_vectors;
+
+ decision_function (
+ );
+ /*!
+ ensures
+ - #b == 0
+ - #alpha.nr() == 0
+ - #basis_vectors.nr() == 0
+ !*/
+
+ decision_function (
+ const decision_function& f
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ decision_function (
+ const scalar_vector_type& alpha_,
+ const scalar_type& b_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) : alpha(alpha_), b(b_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {}
+ /*!
+ ensures
+ - populates the decision function with the given basis vectors, weights(i.e. alphas),
+ b term, and kernel function.
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ /*!
+ ensures
+ - evaluates this sample according to the decision
+ function contained in this object.
+ !*/
+ {
+ result_type temp = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ temp += alpha(i) * kernel_function(x,basis_vectors(i));
+
+ return temp - b;
+ }
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const decision_function<K>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for decision_function
+ !*/
+
+ template <
+ typename K
+ >
+ void deserialize (
+ decision_function<K>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for decision_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type
+ >
+ struct probabilistic_function
+ {
+ /*!
+ REQUIREMENTS ON function_type
+ - function_type must be a function object with an overloaded
+ operator() similar to the other function objects defined in
+ this file. The operator() should return a scalar type such as
+ double or float.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a binary decision function that returns an
+ estimate of the probability that a given sample is in the +1 class.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call operator() on this object from multiple threads so
+ long as decision_funct is also threadsafe. This is because operator()
+ is a read-only operation. However, any operation that modifies a
+ probabilistic_function is not threadsafe.
+ !*/
+
+ typedef typename function_type::scalar_type scalar_type;
+ typedef typename function_type::result_type result_type;
+ typedef typename function_type::sample_type sample_type;
+ typedef typename function_type::mem_manager_type mem_manager_type;
+
+ scalar_type alpha;
+ scalar_type beta;
+ function_type decision_funct;
+
+ probabilistic_function (
+ );
+ /*!
+ ensures
+ - #alpha == 0
+ - #beta == 0
+ - #decision_funct has its initial value
+ !*/
+
+ probabilistic_function (
+ const probabilistic_function& f
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ probabilistic_function (
+ const scalar_type a,
+ const scalar_type b,
+ const function_type& decision_funct_
+ ) : alpha(a), beta(b), decision_funct(decision_funct_) {}
+ /*!
+ ensures
+ - populates the probabilistic decision function with the given alpha, beta,
+ and decision function.
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ /*!
+ ensures
+ - returns a number P such that:
+ - 0 <= P <= 1
+ - P represents the probability that sample x is from
+ the class +1
+ !*/
+ {
+ // Evaluate the normal decision function
+ result_type f = decision_funct(x);
+ // Now basically normalize the output so that it is a properly
+ // conditioned probability of x being in the +1 class given
+ // the output of the decision function.
+ return 1/(1 + std::exp(alpha*f + beta));
+ }
+ };
+
+ template <
+ typename function_type
+ >
+ void serialize (
+ const probabilistic_function<function_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for probabilistic_function
+ !*/
+
+ template <
+ typename function_type
+ >
+ void deserialize (
+ probabilistic_function<function_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for probabilistic_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct probabilistic_decision_function
+ {
+ /*!
+ REQUIREMENTS ON K
+ K must be a kernel function object type as defined at the
+ top of dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a binary decision function that returns an
+ estimate of the probability that a given sample is in the +1 class.
+
+ Note that this object is essentially just a copy of
+ probabilistic_function but with the template argument
+ changed from being a function type to a kernel type. Therefore, this
+ type is just a convenient version of probabilistic_function
+ for the case where the decision function is a dlib::decision_function<K>.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call operator() on this object from multiple threads so
+ long as the kernel, K, is also threadsafe. This is because operator()
+ is a read-only operation. However, any operation that modifies a
+ probabilistic_decision_function is not threadsafe.
+ !*/
+
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ scalar_type alpha;
+ scalar_type beta;
+ decision_function<K> decision_funct;
+
+ probabilistic_decision_function (
+ );
+ /*!
+ ensures
+ - #alpha == 0
+ - #beta == 0
+ - #decision_funct has its initial value
+ !*/
+
+ probabilistic_decision_function (
+ const probabilistic_decision_function& f
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ probabilistic_decision_function (
+ const probabilistic_function<decision_function<K> >& d
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ probabilistic_decision_function (
+ const scalar_type a,
+ const scalar_type b,
+ const decision_function<K>& decision_funct_
+ ) : alpha(a), beta(b), decision_funct(decision_funct_) {}
+ /*!
+ ensures
+ - populates the probabilistic decision function with the given alpha, beta,
+ and decision_function.
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ /*!
+ ensures
+ - returns a number P such that:
+ - 0 <= P <= 1
+ - P represents the probability that sample x is from
+ the class +1
+ !*/
+ {
+ // Evaluate the normal decision function
+ result_type f = decision_funct(x);
+ // Now basically normalize the output so that it is a properly
+ // conditioned probability of x being in the +1 class given
+ // the output of the decision function.
+ return 1/(1 + std::exp(alpha*f + beta));
+ }
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const probabilistic_decision_function<K>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for probabilistic_decision_function
+ !*/
+
+ template <
+ typename K
+ >
+ void deserialize (
+ probabilistic_decision_function<K>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for probabilistic_decision_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class distance_function
+ {
+ /*!
+ REQUIREMENTS ON K
+ K must be a kernel function object type as defined at the
+ top of dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a point in kernel induced feature space.
+ You may use this object to find the distance from the point it
+ represents to points in input space as well as other points
+ represented by distance_functions.
+
+ Specifically, if O() is the feature mapping associated with
+ the kernel used by this object. Then this object represents
+ the point:
+ sum alpha(i)*O(basis_vectors(i))
+
+ I.e. It represents a linear combination of the basis vectors where
+ the weights of the linear combination are stored in the alpha vector.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as the kernel, K, is also threadsafe. This is because
+ the const members are purely read-only operations. However, any
+ operation that modifies a distance_function is not threadsafe.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::scalar_type result_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+
+ distance_function (
+ );
+ /*!
+ ensures
+ - #get_squared_norm() == 0
+ - #get_alpha().size() == 0
+ - #get_basis_vectors().size() == 0
+ - #get_kernel() == K() (i.e. the default value of the kernel)
+ !*/
+
+ explicit distance_function (
+ const kernel_type& kern
+ );
+ /*!
+ ensures
+ - #get_squared_norm() == 0
+ - #get_alpha().size() == 0
+ - #get_basis_vectors().size() == 0
+ - #get_kernel() == kern
+ !*/
+
+ distance_function (
+ const kernel_type& kern,
+ const sample_type& samp
+ );
+ /*!
+ ensures
+ - This object represents the point in kernel feature space which
+ corresponds directly to the given sample. In particular this means
+ that:
+ - #get_kernel() == kern
+ - #get_alpha() == a vector of length 1 which contains the value 1
+ - #get_basis_vectors() == a vector of length 1 which contains samp
+ !*/
+
+ distance_function (
+ const decision_function<K>& f
+ );
+ /*!
+ ensures
+ - Every decision_function represents a point in kernel feature space along
+ with a bias value. This constructor discards the bias value and creates
+ a distance_function which represents the point associated with the given
+ decision_function f. In particular, this means:
+ - #get_alpha() == f.alpha
+ - #get_kernel() == f.kernel_function
+ - #get_basis_vectors() == f.basis_vectors
+ !*/
+
+ distance_function (
+ const distance_function& f
+ );
+ /*!
+ requires
+ - f is a valid distance_function. In particular, this means that
+ f.alpha.size() == f.basis_vectors.size()
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ distance_function (
+ const scalar_vector_type& alpha,
+ const scalar_type& squared_norm,
+ const K& kernel_function,
+ const sample_vector_type& basis_vectors
+ );
+ /*!
+ requires
+ - alpha.size() == basis_vectors.size()
+ - squared_norm == trans(alpha)*kernel_matrix(kernel_function,basis_vectors)*alpha
+ (Basically, squared_norm needs to be set properly for this object to make sense.
+ You should prefer to use the following constructor which computes squared_norm for
+ you. This version is provided just in case you already know squared_norm and
+ don't want to spend CPU cycles to recompute it.)
+ ensures
+ - populates the distance function with the given basis vectors, weights(i.e. alphas),
+ squared_norm value, and kernel function. I.e.
+ - #get_alpha() == alpha
+ - #get_squared_norm() == squared_norm
+ - #get_kernel() == kernel_function
+ - #get_basis_vectors() == basis_vectors
+ !*/
+
+ distance_function (
+ const scalar_vector_type& alpha,
+ const K& kernel_function,
+ const sample_vector_type& basis_vectors
+ );
+ /*!
+ requires
+ - alpha.size() == basis_vectors.size()
+ ensures
+ - populates the distance function with the given basis vectors, weights(i.e. alphas),
+ and kernel function. The correct b value is computed automatically. I.e.
+ - #get_alpha() == alpha
+ - #get_squared_norm() == trans(alpha)*kernel_matrix(kernel_function,basis_vectors)*alpha
+ (i.e. get_squared_norm() will be automatically set to the correct value)
+ - #get_kernel() == kernel_function
+ - #get_basis_vectors() == basis_vectors
+ !*/
+
+ const scalar_vector_type& get_alpha (
+ ) const;
+ /*!
+ ensures
+ - returns the set of weights on each basis vector in this object
+ !*/
+
+ const scalar_type& get_squared_norm (
+ ) const;
+ /*!
+ ensures
+ - returns the squared norm of the point represented by this object. This value is
+ equal to the following expression:
+ trans(get_alpha()) * kernel_matrix(get_kernel(),get_basis_vectors()) * get_alpha()
+ !*/
+
+ const K& get_kernel(
+ ) const;
+ /*!
+ ensures
+ - returns the kernel used by this object.
+ !*/
+
+ const sample_vector_type& get_basis_vectors (
+ ) const;
+ /*!
+ ensures
+ - returns the set of basis vectors contained in this object
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - Let O(x) represent the point x projected into kernel induced feature space.
+ - let c == sum_over_i get_alpha()(i)*O(get_basis_vectors()(i)) == the point in kernel space that
+ this object represents. That is, c is the weighted sum of basis vectors.
+ - Then this object returns the distance between the point O(x) and c in kernel
+ space.
+ !*/
+
+ result_type operator() (
+ const distance_function& x
+ ) const;
+ /*!
+ requires
+ - kernel_function == x.kernel_function
+ ensures
+ - returns the distance between the points in kernel space represented by *this and x.
+ !*/
+
+ distance_function operator* (
+ const scalar_type& val
+ ) const;
+ /*!
+ ensures
+ - multiplies the point represented by *this by val and returns the result. In
+ particular, this function returns a decision_function DF such that:
+ - DF.get_basis_vectors() == get_basis_vectors()
+ - DF.get_kernel() == get_kernel()
+ - DF.get_alpha() == get_alpha() * val
+ !*/
+
+ distance_function operator/ (
+ const scalar_type& val
+ ) const;
+ /*!
+ ensures
+ - divides the point represented by *this by val and returns the result. In
+ particular, this function returns a decision_function DF such that:
+ - DF.get_basis_vectors() == get_basis_vectors()
+ - DF.get_kernel() == get_kernel()
+ - DF.get_alpha() == get_alpha() / val
+ !*/
+
+ distance_function operator+ (
+ const distance_function& rhs
+ ) const;
+ /*!
+ requires
+ - get_kernel() == rhs.get_kernel()
+ ensures
+ - returns a distance function DF such that:
+ - DF represents the sum of the point represented by *this and rhs
+ - DF.get_basis_vectors().size() == get_basis_vectors().size() + rhs.get_basis_vectors().size()
+ - DF.get_basis_vectors() contains all the basis vectors in both *this and rhs.
+ - DF.get_kernel() == get_kernel()
+ - DF.alpha == join_cols(get_alpha(), rhs.get_alpha())
+ !*/
+
+ distance_function operator- (
+ const distance_function& rhs
+ ) const;
+ /*!
+ requires
+ - get_kernel() == rhs.get_kernel()
+ ensures
+ - returns a distance function DF such that:
+ - DF represents the difference of the point represented by *this and rhs (i.e. *this - rhs)
+ - DF.get_basis_vectors().size() == get_basis_vectors().size() + rhs.get_basis_vectors().size()
+ - DF.get_basis_vectors() contains all the basis vectors in both *this and rhs.
+ - DF.get_kernel() == get_kernel()
+ - DF.alpha == join_cols(get_alpha(), -1 * rhs.get_alpha())
+ !*/
+ };
+
+ template <
+ typename K
+ >
+ distance_function<K> operator* (
+ const typename K::scalar_type& val,
+ const distance_function<K>& df
+ ) { return df*val; }
+ /*!
+ ensures
+ - multiplies the point represented by *this by val and returns the result. This
+ function just allows multiplication syntax of the form val*df.
+ !*/
+
+ template <
+ typename K
+ >
+ void serialize (
+ const distance_function<K>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for distance_function
+ !*/
+
+ template <
+ typename K
+ >
+ void deserialize (
+ distance_function<K>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for distance_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename function_type,
+ typename normalizer_type = vector_normalizer<typename function_type::sample_type>
+ >
+ struct normalized_function
+ {
+ /*!
+ REQUIREMENTS ON function_type
+ - function_type must be a function object with an overloaded
+ operator() similar to the other function objects defined in
+ this file.
+
+ REQUIREMENTS ON normalizer_type
+ - normalizer_type must be a function object with an overloaded
+ operator() that takes a sample_type and returns a sample_type.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a container for another function
+ object and an instance of a normalizer function.
+
+ It automatically normalizes all inputs before passing them
+ off to the contained function object.
+ !*/
+
+ typedef typename function_type::result_type result_type;
+ typedef typename function_type::sample_type sample_type;
+ typedef typename function_type::mem_manager_type mem_manager_type;
+
+ normalizer_type normalizer;
+ function_type function;
+
+ normalized_function (
+ );
+ /*!
+ ensures
+ - the members of this object have their default values
+ !*/
+
+ normalized_function (
+ const normalized_function& f
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ normalized_function (
+ const vector_normalizer<sample_type>& normalizer_,
+ const function_type& funct
+ ) : normalizer(normalizer_), function(funct) {}
+ /*!
+ ensures
+ - populates this object with the vector_normalizer and function object
+ !*/
+
+ const std::vector<result_type> get_labels(
+ ) const;
+ /*!
+ ensures
+ - returns function.get_labels()
+ !*/
+
+ unsigned long number_of_classes (
+ ) const;
+ /*!
+ ensures
+ - returns function.number_of_classes()
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const
+ /*!
+ ensures
+ - returns function(normalizer(x))
+ !*/
+ };
+
+ template <
+ typename function_type,
+ typename normalizer_type
+ >
+ void serialize (
+ const normalized_function<function_type, normalizer_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for normalized_function
+ !*/
+
+ template <
+ typename function_type,
+ typename normalizer_type
+ >
+ void deserialize (
+ normalized_function<function_type, normalizer_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for normalized_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ struct projection_function
+ {
+ /*!
+ REQUIREMENTS ON K
+ K must be a kernel function object type as defined at the
+ top of dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a function that takes a data sample and projects
+ it into kernel feature space. The result is a real valued column vector that
+ represents a point in a kernel feature space.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ Instances of this object have a mutable cache which is used by const
+ member functions. Therefore, it is not safe to use one instance of
+ this object from multiple threads (unless protected by a mutex).
+ !*/
+
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+ typedef scalar_vector_type result_type;
+
+ scalar_matrix_type weights;
+ K kernel_function;
+ sample_vector_type basis_vectors;
+
+ projection_function (
+ );
+ /*!
+ ensures
+ - #weights.size() == 0
+ - #basis_vectors.size() == 0
+ !*/
+
+ projection_function (
+ const projection_function& f
+ );
+ /*!
+ ensures
+ - #*this is a copy of f
+ !*/
+
+ projection_function (
+ const scalar_matrix_type& weights_,
+ const K& kernel_function_,
+ const sample_vector_type& basis_vectors_
+ ) : weights(weights_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {}
+ /*!
+ ensures
+ - populates the projection function with the given basis vectors, weights,
+ and kernel function.
+ !*/
+
+ long out_vector_size (
+ ) const;
+ /*!
+ ensures
+ - returns weights.nr()
+ (i.e. returns the dimensionality of the vectors output by this projection_function.)
+ !*/
+
+ const result_type& operator() (
+ const sample_type& x
+ ) const
+ /*!
+ requires
+ - weights.nc() == basis_vectors.size()
+ - out_vector_size() > 0
+ ensures
+ - Takes the given x sample and projects it onto part of the kernel feature
+ space spanned by the basis_vectors. The exact projection arithmetic is
+ defined below.
+ !*/
+ {
+ // Run the x sample through all the basis functions we have and then
+ // multiply it by the weights matrix and return the result. Note that
+ // the temp vectors are here to avoid reallocating their memory every
+ // time this function is called.
+ temp1 = kernel_matrix(kernel_function, basis_vectors, x);
+ temp2 = weights*temp1;
+ return temp2;
+ }
+
+ private:
+ mutable result_type temp1, temp2;
+ };
+
+ template <
+ typename K
+ >
+ void serialize (
+ const projection_function<K>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for projection_function
+ !*/
+
+ template <
+ typename K
+ >
+ void deserialize (
+ projection_function<K>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for projection_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename result_type_ = typename K::scalar_type
+ >
+ struct multiclass_linear_decision_function
+ {
+ /*!
+ REQUIREMENTS ON K
+ K must be either linear_kernel or sparse_linear_kernel.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a multiclass classifier built out of a set of
+ binary classifiers. Each binary classifier is used to vote for the
+ correct multiclass label using a one vs. all strategy. Therefore,
+ if you have N classes then there will be N binary classifiers inside
+ this object. Additionally, this object is linear in the sense that
+ each of these binary classifiers is a simple linear plane.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const member functions of this object from
+ multiple threads. This is because the const members are purely
+ read-only operations. However, any operation that modifies a
+ multiclass_linear_decision_function is not threadsafe.
+ !*/
+
+ typedef result_type_ result_type;
+
+ typedef K kernel_type;
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+
+ scalar_matrix_type weights;
+ scalar_vector_type b;
+ std::vector<result_type> labels;
+
+ const std::vector<result_type>& get_labels(
+ ) const { return labels; }
+ /*!
+ ensures
+ - returns a vector containing all the labels which can be
+ predicted by this object.
+ !*/
+
+ unsigned long number_of_classes (
+ ) const;
+ /*!
+ ensures
+ - returns get_labels().size()
+ (i.e. returns the number of different labels/classes predicted by
+ this object)
+ !*/
+
+ std::pair<result_type, scalar_type> predict (
+ const sample_type& x
+ ) const;
+ /*!
+ requires
+ - weights.size() > 0
+ - weights.nr() == number_of_classes() == b.size()
+ - if (x is a dense vector, i.e. a dlib::matrix) then
+ - is_vector(x) == true
+ - x.size() == weights.nc()
+ (i.e. it must be legal to multiply weights with x)
+ ensures
+ - Returns the predicted label for the x sample and also it's score.
+ In particular, it returns the following:
+ std::make_pair(labels[index_of_max(weights*x-b)], max(weights*x-b))
+ !*/
+
+ result_type operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ requires
+ - weights.size() > 0
+ - weights.nr() == number_of_classes() == b.size()
+ - if (x is a dense vector, i.e. a dlib::matrix) then
+ - is_vector(x) == true
+ - x.size() == weights.nc()
+ (i.e. it must be legal to multiply weights with x)
+ ensures
+ - Returns the predicted label for the x sample. In particular, it returns
+ the following:
+ labels[index_of_max(weights*x-b)]
+ Or in other words, this function returns predict(x).first
+ !*/
+ };
+
+ template <
+ typename K,
+ typename result_type_
+ >
+ void serialize (
+ const multiclass_linear_decision_function<K,result_type_>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for multiclass_linear_decision_function
+ !*/
+
+ template <
+ typename K,
+ typename result_type_
+ >
+ void deserialize (
+ multiclass_linear_decision_function<K,result_type_>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for multiclass_linear_decision_function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_FUNCTION_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/kcentroid.h b/ml/dlib/dlib/svm/kcentroid.h
new file mode 100644
index 000000000..5f380486a
--- /dev/null
+++ b/ml/dlib/dlib/svm/kcentroid.h
@@ -0,0 +1,614 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KCENTROId_
+#define DLIB_KCENTROId_
+
+#include <vector>
+
+#include "kcentroid_abstract.h"
+#include "../matrix.h"
+#include "function.h"
+#include "../std_allocator.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ class kcentroid
+ {
+ /*!
+ This object represents a weighted sum of sample points in a kernel induced
+ feature space. It can be used to kernelize any algorithm that requires only
+ the ability to perform vector addition, subtraction, scalar multiplication,
+ and inner products. It uses the sparsification technique described in the
+ paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel.
+
+ To understand the code it would also be useful to consult page 114 of the book
+ Kernel Methods for Pattern Analysis by Taylor and Cristianini as well as page 554
+ (particularly equation 18.31) of the book Learning with Kernels by Scholkopf and
+ Smola. Everything you really need to know is in the Engel paper. But the other
+ books help give more perspective on the issues involved.
+
+
+ INITIAL VALUE
+ - min_strength == 0
+ - min_vect_idx == 0
+ - K_inv.size() == 0
+ - K.size() == 0
+ - dictionary.size() == 0
+ - bias == 0
+ - bias_is_stale == false
+
+ CONVENTION
+ - max_dictionary_size() == my_max_dictionary_size
+ - get_kernel() == kernel
+
+ - K.nr() == dictionary.size()
+ - K.nc() == dictionary.size()
+ - for all valid r,c:
+ - K(r,c) == kernel(dictionary[r], dictionary[c])
+ - K_inv == inv(K)
+
+ - if (dictionary.size() == my_max_dictionary_size && my_remove_oldest_first == false) then
+ - for all valid 0 < i < dictionary.size():
+ - Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately
+ Linearly Dependent value) if you removed dictionary[i] from this object and then
+ tried to add it back in.
+ - min_strength == the minimum value from STRENGTHS
+ - min_vect_idx == the index of the element in STRENGTHS with the smallest value
+
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ kcentroid (
+ ) :
+ my_remove_oldest_first(false),
+ my_tolerance(0.001),
+ my_max_dictionary_size(1000000),
+ bias(0),
+ bias_is_stale(false)
+ {
+ clear_dictionary();
+ }
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ ) :
+ my_remove_oldest_first(remove_oldest_first_),
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_),
+ bias(0),
+ bias_is_stale(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ > 0 && max_dictionary_size_ > 1,
+ "\tkcentroid::kcentroid()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance_: " << tolerance_
+ << "\n\t max_dictionary_size_: " << max_dictionary_size_
+ );
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const
+ {
+ return my_tolerance;
+ }
+
+ unsigned long max_dictionary_size() const
+ {
+ return my_max_dictionary_size;
+ }
+
+ bool remove_oldest_first (
+ ) const
+ {
+ return my_remove_oldest_first;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ void clear_dictionary ()
+ {
+ dictionary.clear();
+ alpha.clear();
+
+ min_strength = 0;
+ min_vect_idx = 0;
+ K_inv.set_size(0,0);
+ K.set_size(0,0);
+ samples_seen = 0;
+ bias = 0;
+ bias_is_stale = false;
+ }
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::operator()(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ // make sure the bias terms are up to date
+ refresh_bias();
+ x.refresh_bias();
+
+ scalar_type temp = x.bias + bias - 2*inner_product(x);
+
+ if (temp > 0)
+ return std::sqrt(temp);
+ else
+ return 0;
+ }
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const
+ {
+ scalar_type temp = 0;
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ temp += alpha[i]*kernel(dictionary[i], x);
+ return temp;
+ }
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ scalar_type temp = 0;
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ for (unsigned long j = 0; j < x.alpha.size(); ++j)
+ {
+ temp += alpha[i]*x.alpha[j]*kernel(dictionary[i], x.dictionary[j]);
+ }
+ }
+ return temp;
+ }
+
+ scalar_type squared_norm (
+ ) const
+ {
+ refresh_bias();
+ return bias;
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ // make sure the bias terms are up to date
+ refresh_bias();
+
+ const scalar_type kxx = kernel(x,x);
+
+ scalar_type temp = kxx + bias - 2*inner_product(x);
+ if (temp > 0)
+ return std::sqrt(temp);
+ else
+ return 0;
+ }
+
+ scalar_type samples_trained (
+ ) const
+ {
+ return samples_seen;
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+ return train_and_maybe_test(x,cscale,xscale,true);
+ }
+
+ void train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+ train_and_maybe_test(x,cscale,xscale,false);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ return train_and_maybe_test(x,cscale,xscale,true);
+ }
+
+ void scale_by (
+ scalar_type cscale
+ )
+ {
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ alpha[i] = cscale*alpha[i];
+ }
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ train_and_maybe_test(x,cscale,xscale,false);
+ }
+
+ void swap (
+ kcentroid& item
+ )
+ {
+ exchange(min_strength, item.min_strength);
+ exchange(min_vect_idx, item.min_vect_idx);
+ exchange(my_remove_oldest_first, item.my_remove_oldest_first);
+
+ exchange(kernel, item.kernel);
+ dictionary.swap(item.dictionary);
+ alpha.swap(item.alpha);
+ K_inv.swap(item.K_inv);
+ K.swap(item.K);
+ exchange(my_tolerance, item.my_tolerance);
+ exchange(samples_seen, item.samples_seen);
+ exchange(bias, item.bias);
+ a.swap(item.a);
+ k.swap(item.k);
+ exchange(bias_is_stale, item.bias_is_stale);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ }
+
+ unsigned long dictionary_size (
+ ) const { return dictionary.size(); }
+
+ friend void serialize(const kcentroid& item, std::ostream& out)
+ {
+ serialize(item.min_strength, out);
+ serialize(item.min_vect_idx, out);
+ serialize(item.my_remove_oldest_first, out);
+
+ serialize(item.kernel, out);
+ serialize(item.dictionary, out);
+ serialize(item.alpha, out);
+ serialize(item.K_inv, out);
+ serialize(item.K, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.samples_seen, out);
+ serialize(item.bias, out);
+ serialize(item.bias_is_stale, out);
+ serialize(item.my_max_dictionary_size, out);
+ }
+
+ friend void deserialize(kcentroid& item, std::istream& in)
+ {
+ deserialize(item.min_strength, in);
+ deserialize(item.min_vect_idx, in);
+ deserialize(item.my_remove_oldest_first, in);
+
+ deserialize(item.kernel, in);
+ deserialize(item.dictionary, in);
+ deserialize(item.alpha, in);
+ deserialize(item.K_inv, in);
+ deserialize(item.K, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.samples_seen, in);
+ deserialize(item.bias, in);
+ deserialize(item.bias_is_stale, in);
+ deserialize(item.my_max_dictionary_size, in);
+ }
+
+ distance_function<kernel_type> get_distance_function (
+ ) const
+ {
+ refresh_bias();
+ return distance_function<kernel_type>(mat(alpha),
+ bias,
+ kernel,
+ mat(dictionary));
+ }
+
+ private:
+
+ void refresh_bias (
+ ) const
+ {
+ if (bias_is_stale)
+ {
+ bias_is_stale = false;
+ // recompute the bias term
+ bias = sum(pointwise_multiply(K, mat(alpha)*trans(mat(alpha))));
+ }
+ }
+
+ scalar_type train_and_maybe_test (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale,
+ bool do_test
+ )
+ {
+ scalar_type test_result = 0;
+ const scalar_type kx = kernel(x,x);
+ if (alpha.size() == 0)
+ {
+ // just ignore this sample if it is the zero vector (or really close to being zero)
+ if (std::abs(kx) > std::numeric_limits<scalar_type>::epsilon())
+ {
+ // set initial state since this is the first training example we have seen
+
+ K_inv.set_size(1,1);
+ K_inv(0,0) = 1/kx;
+ K.set_size(1,1);
+ K(0,0) = kx;
+
+ alpha.push_back(xscale);
+ dictionary.push_back(x);
+ }
+ else
+ {
+ // the distance from an empty kcentroid and the zero vector is zero by definition.
+ return 0;
+ }
+ }
+ else
+ {
+ // fill in k
+ k.set_size(alpha.size());
+ for (long r = 0; r < k.nr(); ++r)
+ k(r) = kernel(x,dictionary[r]);
+
+ if (do_test)
+ {
+ refresh_bias();
+ test_result = std::sqrt(kx + bias - 2*trans(mat(alpha))*k);
+ }
+
+ // compute the error we would have if we approximated the new x sample
+ // with the dictionary. That is, do the ALD test from the KRLS paper.
+ a = K_inv*k;
+ scalar_type delta = kx - trans(k)*a;
+
+ // if this new vector isn't approximately linearly dependent on the vectors
+ // in our dictionary.
+ if (delta > min_strength && delta > my_tolerance)
+ {
+ bool need_to_update_min_strength = false;
+ if (dictionary.size() >= my_max_dictionary_size)
+ {
+ // We need to remove one of the old members of the dictionary before
+ // we proceed with adding a new one.
+ long idx_to_remove;
+ if (my_remove_oldest_first)
+ {
+ // remove the oldest one
+ idx_to_remove = 0;
+ }
+ else
+ {
+ // if we have never computed the min_strength then we should compute it
+ if (min_strength == 0)
+ recompute_min_strength();
+
+ // select the dictionary vector that is most linearly dependent for removal
+ idx_to_remove = min_vect_idx;
+ need_to_update_min_strength = true;
+ }
+
+ remove_dictionary_vector(idx_to_remove);
+
+ // recompute these guys since they were computed with the old
+ // kernel matrix
+ k = remove_row(k,idx_to_remove);
+ a = K_inv*k;
+ delta = kx - trans(k)*a;
+ }
+
+ // add x to the dictionary
+ dictionary.push_back(x);
+
+
+ // update K_inv by computing the new one in the temp matrix (equation 3.14)
+ matrix<scalar_type,0,0,mem_manager_type> temp(K_inv.nr()+1, K_inv.nc()+1);
+ // update the middle part of the matrix
+ set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta;
+ // update the right column of the matrix
+ set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta;
+ // update the bottom row of the matrix
+ set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta);
+ // update the bottom right corner of the matrix
+ temp(K_inv.nr(), K_inv.nc()) = 1/delta;
+ // put temp into K_inv
+ temp.swap(K_inv);
+
+
+
+ // update K (the kernel matrix)
+ temp.set_size(K.nr()+1, K.nc()+1);
+ set_subm(temp, get_rect(K)) = K;
+ // update the right column of the matrix
+ set_subm(temp, 0, K.nr(),K.nr(),1) = k;
+ // update the bottom row of the matrix
+ set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k);
+ temp(K.nr(), K.nc()) = kx;
+ // put temp into K
+ temp.swap(K);
+
+
+ // now update the alpha vector
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ alpha[i] *= cscale;
+ }
+ alpha.push_back(xscale);
+
+
+ if (need_to_update_min_strength)
+ {
+ // now we have to recompute the min_strength in this case
+ recompute_min_strength();
+ }
+ }
+ else
+ {
+ // update the alpha vector so that this new sample has been added into
+ // the mean vector we are accumulating
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ alpha[i] = cscale*alpha[i] + xscale*a(i);
+ }
+ }
+ }
+
+ bias_is_stale = true;
+
+ return test_result;
+ }
+
+ void remove_dictionary_vector (
+ long i
+ )
+ /*!
+ requires
+ - 0 <= i < dictionary.size()
+ ensures
+ - #dictionary.size() == dictionary.size() - 1
+ - #alpha.size() == alpha.size() - 1
+ - updates the K_inv matrix so that it is still a proper inverse of the
+ kernel matrix
+ - also removes the necessary row and column from the K matrix
+ - uses the this->a variable so after this function runs that variable
+ will contain a different value.
+ !*/
+ {
+ // remove the dictionary vector
+ dictionary.erase(dictionary.begin()+i);
+
+ // remove the i'th vector from the inverse kernel matrix. This formula is basically
+ // just the reverse of the way K_inv is updated by equation 3.14 during normal training.
+ K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i);
+
+ // now compute the updated alpha values to take account that we just removed one of
+ // our dictionary vectors
+ a = (K_inv*remove_row(K,i)*mat(alpha));
+
+ // now copy over the new alpha values
+ alpha.resize(alpha.size()-1);
+ for (unsigned long k = 0; k < alpha.size(); ++k)
+ {
+ alpha[k] = a(k);
+ }
+
+ // update the K matrix as well
+ K = removerc(K,i,i);
+ }
+
+ void recompute_min_strength (
+ )
+ /*!
+ ensures
+ - recomputes the min_strength and min_vect_idx values
+ so that they are correct with respect to the CONVENTION
+ - uses the this->a variable so after this function runs that variable
+ will contain a different value.
+ !*/
+ {
+ min_strength = std::numeric_limits<scalar_type>::max();
+
+ // here we loop over each dictionary vector and compute what its delta would be if
+ // we were to remove it from the dictionary and then try to add it back in.
+ for (unsigned long i = 0; i < dictionary.size(); ++i)
+ {
+ // compute a = K_inv*k but where dictionary vector i has been removed
+ a = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) *
+ (remove_row(colm(K,i),i));
+ scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a;
+
+ if (delta < min_strength)
+ {
+ min_strength = delta;
+ min_vect_idx = i;
+ }
+ }
+ }
+
+
+
+ typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
+ typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
+ typedef std::vector<sample_type,alloc_sample_type> dictionary_vector_type;
+ typedef std::vector<scalar_type,alloc_scalar_type> alpha_vector_type;
+
+
+ scalar_type min_strength;
+ unsigned long min_vect_idx;
+ bool my_remove_oldest_first;
+
+ kernel_type kernel;
+ dictionary_vector_type dictionary;
+ alpha_vector_type alpha;
+
+ matrix<scalar_type,0,0,mem_manager_type> K_inv;
+ matrix<scalar_type,0,0,mem_manager_type> K;
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+ scalar_type samples_seen;
+ mutable scalar_type bias;
+ mutable bool bias_is_stale;
+
+
+ // temp variables here just so we don't have to reconstruct them over and over. Thus,
+ // they aren't really part of the state of this object.
+ matrix<scalar_type,0,1,mem_manager_type> a;
+ matrix<scalar_type,0,1,mem_manager_type> k;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void swap(kcentroid<kernel_type>& a, kcentroid<kernel_type>& b)
+ { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KCENTROId_
+
diff --git a/ml/dlib/dlib/svm/kcentroid_abstract.h b/ml/dlib/dlib/svm/kcentroid_abstract.h
new file mode 100644
index 000000000..44b94c813
--- /dev/null
+++ b/ml/dlib/dlib/svm/kcentroid_abstract.h
@@ -0,0 +1,339 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KCENTROId_ABSTRACT_
+#ifdef DLIB_KCENTROId_ABSTRACT_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename kernel_type
+ >
+ class kcentroid
+ {
+ /*!
+ REQUIREMENTS ON kernel_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - dictionary_size() == 0
+ - samples_trained() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a weighted sum of sample points in a kernel induced
+ feature space. It can be used to kernelize any algorithm that requires only
+ the ability to perform vector addition, subtraction, scalar multiplication,
+ and inner products.
+
+ An example use of this object is as an online algorithm for recursively estimating
+ the centroid of a sequence of training points. This object then allows you to
+ compute the distance between the centroid and any test points. So you can use
+ this object to predict how similar a test point is to the data this object has
+ been trained on (larger distances from the centroid indicate dissimilarity/anomalous
+ points).
+
+ Also note that the algorithm internally keeps a set of "dictionary vectors"
+ that are used to represent the centroid. You can force the algorithm to use
+ no more than a set number of vectors by setting the 3rd constructor argument
+ to whatever you want.
+
+ This object uses the sparsification technique described in the paper The
+ Kernel Recursive Least Squares Algorithm by Yaakov Engel. This technique
+ allows us to keep the number of dictionary vectors down to a minimum. In fact,
+ the object has a user selectable tolerance parameter that controls the trade off
+ between accuracy and number of stored dictionary vectors.
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ kcentroid (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #tolerance() == 0.001
+ - #get_kernel() == kernel_type() (i.e. whatever the kernel's default value is)
+ - #max_dictionary_size() == 1000000
+ - #remove_oldest_first() == false
+ !*/
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ );
+ /*!
+ requires
+ - tolerance > 0
+ - max_dictionary_size_ > 1
+ ensures
+ - this object is properly initialized
+ - #tolerance() == tolerance_
+ - #get_kernel() == kernel_
+ - #max_dictionary_size() == max_dictionary_size_
+ - #remove_oldest_first() == remove_oldest_first_
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the kernel used by this object
+ !*/
+
+ unsigned long max_dictionary_size(
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of dictionary vectors this object will
+ use at a time. That is, dictionary_size() will never be greater
+ than max_dictionary_size().
+ !*/
+
+ bool remove_oldest_first (
+ ) const;
+ /*!
+ ensures
+ - When the maximum dictionary size is reached this object sometimes
+ needs to discard dictionary vectors when new samples are added via
+ one of the train functions. When this happens this object chooses
+ the dictionary vector to discard based on the setting of the
+ remove_oldest_first() parameter.
+ - if (remove_oldest_first() == true) then
+ - This object discards the oldest dictionary vectors when necessary.
+ This is an appropriate mode when using this object in an online
+ setting and the input training samples come from a slowly
+ varying distribution.
+ - else (remove_oldest_first() == false) then
+ - This object discards the most linearly dependent dictionary vectors
+ when necessary. This it the default behavior and should be used
+ in most cases.
+ !*/
+
+ unsigned long dictionary_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of basis vectors in the dictionary. These are
+ the basis vectors used by this object to represent a point in kernel
+ feature space.
+ !*/
+
+ scalar_type samples_trained (
+ ) const;
+ /*!
+ ensures
+ - returns the number of samples this object has been trained on so far
+ !*/
+
+ scalar_type tolerance(
+ ) const;
+ /*!
+ ensures
+ - returns the tolerance to use for the approximately linearly dependent
+ test used for sparsification (see the KRLS paper for details). This is
+ a number which governs how accurately this object will approximate the
+ centroid it is learning. Smaller values generally result in a more
+ accurate estimate while also resulting in a bigger set of vectors in
+ the dictionary. Bigger tolerances values result in a less accurate
+ estimate but also in less dictionary vectors. (Note that in any case,
+ the max_dictionary_size() limits the number of dictionary vectors no
+ matter the setting of the tolerance)
+ - The exact meaning of the tolerance parameter is the following:
+ Imagine that we have an empirical_kernel_map that contains all
+ the current dictionary vectors. Then the tolerance is the minimum
+ projection error (as given by empirical_kernel_map::project()) required
+ to cause us to include a new vector in the dictionary. So each time
+ you call train() the kcentroid basically just computes the projection
+ error for that new sample and if it is larger than the tolerance
+ then that new sample becomes part of the dictionary.
+ !*/
+
+ void clear_dictionary (
+ );
+ /*!
+ ensures
+ - clears out all learned data (e.g. #dictionary_size() == 0)
+ - #samples_seen() == 0
+ !*/
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const;
+ /*!
+ requires
+ - x.get_kernel() == get_kernel()
+ ensures
+ - returns the distance in kernel feature space between this centroid and the
+ centroid represented by x.
+ !*/
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - returns the distance in kernel feature space between the sample x and the
+ current estimate of the centroid of the training samples given
+ to this object so far.
+ !*/
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - returns the inner product of the given x point and the current
+ estimate of the centroid of the training samples given to this object
+ so far.
+ !*/
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const;
+ /*!
+ requires
+ - x.get_kernel() == get_kernel()
+ ensures
+ - returns the inner product between x and this centroid object.
+ !*/
+
+ scalar_type squared_norm (
+ ) const;
+ /*!
+ ensures
+ - returns the squared norm of the centroid vector represented by this
+ object. I.e. returns this->inner_product(*this)
+ !*/
+
+ void train (
+ const sample_type& x
+ );
+ /*!
+ ensures
+ - adds the sample x into the current estimate of the centroid
+ - also note that calling this function is equivalent to calling
+ train(x, samples_trained()/(samples_trained()+1.0, 1.0/(samples_trained()+1.0).
+ That is, this function finds the normal unweighted centroid of all training points.
+ !*/
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ );
+ /*!
+ ensures
+ - adds the sample x into the current estimate of the centroid but
+ uses a user given scale. That is, this function performs:
+ - new_centroid = cscale*old_centroid + xscale*x
+ - This function allows you to weight different samples however
+ you want.
+ !*/
+
+ void scale_by (
+ scalar_type cscale
+ );
+ /*!
+ ensures
+ - multiplies the current centroid vector by the given scale value.
+ This function is equivalent to calling train(some_x_value, cscale, 0).
+ So it performs:
+ - new_centroid == cscale*old_centroid
+ !*/
+
+ scalar_type test_and_train (
+ const sample_type& x
+ );
+ /*!
+ ensures
+ - calls train(x)
+ - returns (*this)(x)
+ - The reason this function exists is because train() and operator()
+ both compute some of the same things. So this function is more efficient
+ than calling both individually.
+ !*/
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ );
+ /*!
+ ensures
+ - calls train(x,cscale,xscale)
+ - returns (*this)(x)
+ - The reason this function exists is because train() and operator()
+ both compute some of the same things. So this function is more efficient
+ than calling both individually.
+ !*/
+
+ void swap (
+ kcentroid& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ distance_function<kernel_type> get_distance_function (
+ ) const;
+ /*!
+ ensures
+ - returns a distance function F that represents the point learned
+ by this object so far. I.e. it is the case that:
+ - for all x: F(x) == (*this)(x)
+ !*/
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ void swap(
+ kcentroid<kernel_type>& a,
+ kcentroid<kernel_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void serialize (
+ const kcentroid<kernel_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for kcentroid objects
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void deserialize (
+ kcentroid<kernel_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for kcentroid objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KCENTROId_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/kcentroid_overloads.h b/ml/dlib/dlib/svm/kcentroid_overloads.h
new file mode 100644
index 000000000..9c39f3d78
--- /dev/null
+++ b/ml/dlib/dlib/svm/kcentroid_overloads.h
@@ -0,0 +1,1324 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KCENTROId_OVERLOADS_
+#define DLIB_KCENTROId_OVERLOADS_
+
+#include "kcentroid_abstract.h"
+#include "sparse_kernel.h"
+#include "sparse_vector.h"
+#include <map>
+
+namespace dlib
+{
+ /*
+ This file contains optimized overloads of the kcentroid object for the following
+ linear cases:
+ kcentroid<linear_kernel<T>>
+ kcentroid<sparse_linear_kernel<T>>
+ kcentroid<offset_kernel<linear_kernel<T>>>
+ kcentroid<offset_kernel<sparse_linear_kernel<T>>>
+ */
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Overloads for when kernel_type == linear_kernel
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class kcentroid<linear_kernel<T> >
+ {
+
+
+ typedef linear_kernel<T> kernel_type;
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ ) :
+ my_remove_oldest_first(remove_oldest_first_),
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0,
+ "\tkcentroid::kcentroid()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance_: " << tolerance_
+ << "\n\t max_dictionary_size_: " << max_dictionary_size_
+ );
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const { return my_tolerance; }
+ unsigned long max_dictionary_size() const { return my_max_dictionary_size; }
+ bool remove_oldest_first () const { return my_remove_oldest_first; }
+ const kernel_type& get_kernel () const { return kernel; }
+ scalar_type samples_trained () const { return samples_seen; }
+
+ void clear_dictionary ()
+ {
+ samples_seen = 0;
+ set_all_elements(w, 0);
+ alpha = 0;
+ }
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::operator()(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (w.size() > 0)
+ {
+ if (x.w.size() > 0)
+ return length(alpha*w - x.alpha*x.w);
+ else
+ return alpha*length(w);
+ }
+ else
+ {
+ if (x.w.size() > 0)
+ return x.alpha*length(x.w);
+ else
+ return 0;
+ }
+ }
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const
+ {
+ if (w.size() > 0)
+ return alpha*trans(w)*x;
+ else
+ return 0;
+ }
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (w.size() > 0 && x.w.size() > 0)
+ return alpha*x.alpha*trans(w)*x.w;
+ else
+ return 0;
+ }
+
+ scalar_type squared_norm (
+ ) const
+ {
+ if (w.size() > 0)
+ return alpha*alpha*trans(w)*w;
+ else
+ return 0;
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ if (w.size() > 0)
+ return length(x-alpha*w);
+ else
+ return length(x);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void scale_by (
+ scalar_type cscale
+ )
+ {
+ alpha *= cscale;
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ do_train(x, cscale, xscale);
+ }
+
+ void swap (
+ kcentroid& item
+ )
+ {
+ exchange(my_remove_oldest_first, item.my_remove_oldest_first);
+ exchange(kernel, item.kernel);
+ exchange(w, item.w);
+ exchange(alpha, item.alpha);
+ exchange(my_tolerance, item.my_tolerance);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ exchange(samples_seen, item.samples_seen);
+ }
+
+ unsigned long dictionary_size (
+ ) const
+ {
+ if (samples_seen > 0)
+ return 1;
+ else
+ return 0;
+ }
+
+ friend void serialize(const kcentroid& item, std::ostream& out)
+ {
+ serialize(item.my_remove_oldest_first, out);
+ serialize(item.kernel, out);
+ serialize(item.w, out);
+ serialize(item.alpha, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.my_max_dictionary_size, out);
+ serialize(item.samples_seen, out);
+ }
+
+ friend void deserialize(kcentroid& item, std::istream& in)
+ {
+ deserialize(item.my_remove_oldest_first, in);
+ deserialize(item.kernel, in);
+ deserialize(item.w, in);
+ deserialize(item.alpha, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.my_max_dictionary_size, in);
+ deserialize(item.samples_seen, in);
+ }
+
+ distance_function<kernel_type> get_distance_function (
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
+ typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
+
+ temp_basis_vectors.set_size(1);
+ temp_basis_vectors(0) = w;
+ temp_alpha.set_size(1);
+ temp_alpha(0) = alpha;
+
+ return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
+ }
+ else
+ {
+ return distance_function<kernel_type>(kernel);
+ }
+ }
+
+ private:
+
+ void do_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ set_size_of_w(x);
+
+ const scalar_type temp = cscale*alpha;
+
+ if (temp != 0)
+ {
+ w = w + xscale*x/temp;
+ alpha = temp;
+ }
+ else
+ {
+ w = cscale*alpha*w + xscale*x;
+ alpha = 1;
+ }
+ }
+
+ void set_size_of_w (
+ const sample_type& x
+ )
+ {
+ if (x.size() != w.size())
+ {
+ w.set_size(x.nr(), x.nc());
+ set_all_elements(w, 0);
+ alpha = 0;
+ }
+ }
+
+ bool my_remove_oldest_first;
+
+ kernel_type kernel;
+
+ sample_type w;
+ scalar_type alpha;
+
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+ scalar_type samples_seen;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Overloads for when kernel_type == offset_kernel<linear_kernel>
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class kcentroid<offset_kernel<linear_kernel<T> > >
+ {
+
+ /*!
+ INITIAL VALUE
+ - x_extra == sqrt(kernel.offset)
+
+ CONVENTION
+ - x_extra == sqrt(kernel.offset)
+ - w_extra == the value of the extra dimension tacked onto the
+ end of the w vector
+ !*/
+
+ typedef offset_kernel<linear_kernel<T> > kernel_type;
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ ) :
+ my_remove_oldest_first(remove_oldest_first_),
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0,
+ "\tkcentroid::kcentroid()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance_: " << tolerance_
+ << "\n\t max_dictionary_size_: " << max_dictionary_size_
+ );
+
+ x_extra = std::sqrt(kernel.offset);
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const { return my_tolerance; }
+ unsigned long max_dictionary_size() const { return my_max_dictionary_size; }
+ bool remove_oldest_first () const { return my_remove_oldest_first; }
+ const kernel_type& get_kernel () const { return kernel; }
+ scalar_type samples_trained () const { return samples_seen; }
+
+ void clear_dictionary ()
+ {
+ samples_seen = 0;
+ set_all_elements(w, 0);
+ alpha = 0;
+ w_extra = x_extra;
+ }
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::operator()(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (w.size() > 0)
+ {
+ if (x.w.size() > 0)
+ {
+ scalar_type temp1 = length_squared(alpha*w - x.alpha*x.w);
+ scalar_type temp2 = alpha*w_extra - x.alpha*x.w_extra;
+ return std::sqrt(temp1 + temp2*temp2);
+ }
+ else
+ {
+ return alpha*std::sqrt(length_squared(w) + w_extra*w_extra);
+ }
+ }
+ else
+ {
+ if (x.w.size() > 0)
+ return x.alpha*std::sqrt(length_squared(x.w) + x.w_extra*x.w_extra);
+ else
+ return 0;
+ }
+ }
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const
+ {
+ if (w.size() > 0)
+ return alpha*(trans(w)*x + w_extra*x_extra);
+ else
+ return 0;
+ }
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (w.size() > 0 && x.w.size() > 0)
+ return alpha*x.alpha*(trans(w)*x.w + w_extra*x.w_extra);
+ else
+ return 0;
+ }
+
+ scalar_type squared_norm (
+ ) const
+ {
+ if (w.size() > 0)
+ return alpha*alpha*(trans(w)*w + w_extra*w_extra);
+ else
+ return 0;
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ if (w.size() > 0)
+ {
+ scalar_type temp1 = length_squared(x-alpha*w);
+ scalar_type temp2 = x_extra - alpha*w_extra;
+ return std::sqrt(temp1 + temp2*temp2);
+ }
+ else
+ {
+ return std::sqrt(length_squared(x) + x_extra*x_extra);
+ }
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void scale_by (
+ scalar_type cscale
+ )
+ {
+ alpha *= cscale;
+ w_extra *= cscale;
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ do_train(x, cscale, xscale);
+ }
+
+ void swap (
+ kcentroid& item
+ )
+ {
+ exchange(my_remove_oldest_first, item.my_remove_oldest_first);
+ exchange(kernel, item.kernel);
+ exchange(w, item.w);
+ exchange(alpha, item.alpha);
+ exchange(w_extra, item.w_extra);
+ exchange(x_extra, item.x_extra);
+ exchange(my_tolerance, item.my_tolerance);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ exchange(samples_seen, item.samples_seen);
+ }
+
+ unsigned long dictionary_size (
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
+ return 1;
+ else
+ return 2;
+ }
+ else
+ return 0;
+ }
+
+ friend void serialize(const kcentroid& item, std::ostream& out)
+ {
+ serialize(item.my_remove_oldest_first, out);
+ serialize(item.kernel, out);
+ serialize(item.w, out);
+ serialize(item.alpha, out);
+ serialize(item.w_extra, out);
+ serialize(item.x_extra, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.my_max_dictionary_size, out);
+ serialize(item.samples_seen, out);
+ }
+
+ friend void deserialize(kcentroid& item, std::istream& in)
+ {
+ deserialize(item.my_remove_oldest_first, in);
+ deserialize(item.kernel, in);
+ deserialize(item.w, in);
+ deserialize(item.alpha, in);
+ deserialize(item.w_extra, in);
+ deserialize(item.x_extra, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.my_max_dictionary_size, in);
+ deserialize(item.samples_seen, in);
+ }
+
+ distance_function<kernel_type> get_distance_function (
+ ) const
+ {
+
+ if (samples_seen > 0)
+ {
+ typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
+ typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
+
+ // What we are doing here needs a bit of explanation. The w vector
+ // has an implicit extra dimension tacked on to it with the value of w_extra.
+ // The kernel we are using takes normal vectors and implicitly tacks the value
+ // x_extra onto their end. So what we are doing here is scaling w so that
+ // the value it should have tacked onto it is x_scale. Note that we also
+ // adjust alpha so that the combination of alpha*w stays the same.
+ scalar_type scale;
+
+ // if w_extra is basically greater than 0
+ if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
+ {
+ scale = (x_extra/w_extra);
+ temp_basis_vectors.set_size(1);
+ temp_alpha.set_size(1);
+ temp_basis_vectors(0) = w*scale;
+ temp_alpha(0) = alpha/scale;
+ }
+ else
+ {
+ // In this case w_extra is zero. So the only way we can get the same
+ // thing in the output basis vector set is by using two vectors
+ temp_basis_vectors.set_size(2);
+ temp_alpha.set_size(2);
+ temp_basis_vectors(0) = 2*w;
+ temp_alpha(0) = alpha;
+ temp_basis_vectors(1) = w;
+ temp_alpha(1) = -alpha;
+ }
+
+
+ return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
+ }
+ else
+ {
+ return distance_function<kernel_type>(kernel);
+ }
+ }
+
+ private:
+
+ void do_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ set_size_of_w(x);
+
+ const scalar_type temp = cscale*alpha;
+
+ if (temp != 0)
+ {
+ w = w + xscale*x/temp;
+ w_extra = w_extra + xscale*x_extra/temp;
+ alpha = temp;
+ }
+ else
+ {
+ w = cscale*alpha*w + xscale*x;
+ w_extra = cscale*alpha*w_extra + xscale*x_extra;
+ alpha = 1;
+ }
+ }
+
+ void set_size_of_w (
+ const sample_type& x
+ )
+ {
+ if (x.size() != w.size())
+ {
+ w.set_size(x.nr(), x.nc());
+ set_all_elements(w, 0);
+ alpha = 0;
+ w_extra = x_extra;
+ }
+ }
+
+ bool my_remove_oldest_first;
+
+ kernel_type kernel;
+
+ sample_type w;
+ scalar_type alpha;
+
+ scalar_type w_extra;
+ scalar_type x_extra;
+
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+ scalar_type samples_seen;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Overloads for when kernel_type == sparse_linear_kernel
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class kcentroid<sparse_linear_kernel<T> >
+ {
+
+
+ typedef sparse_linear_kernel<T> kernel_type;
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ ) :
+ my_remove_oldest_first(remove_oldest_first_),
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0,
+ "\tkcentroid::kcentroid()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance_: " << tolerance_
+ << "\n\t max_dictionary_size_: " << max_dictionary_size_
+ );
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const { return my_tolerance; }
+ unsigned long max_dictionary_size() const { return my_max_dictionary_size; }
+ bool remove_oldest_first () const { return my_remove_oldest_first; }
+ const kernel_type& get_kernel () const { return kernel; }
+ scalar_type samples_trained () const { return samples_seen; }
+
+ void clear_dictionary ()
+ {
+ samples_seen = 0;
+ w.clear();
+ alpha = 0;
+ }
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::operator()(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ return distance(alpha,w , x.alpha,x.w);
+ }
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const
+ {
+ return alpha*dot(w,x);
+ }
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ return alpha*x.alpha*dot(w,x.w);
+ }
+
+ scalar_type squared_norm (
+ ) const
+ {
+ return alpha*alpha*length_squared(w);
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ return distance(static_cast<scalar_type>(1), x, alpha, w);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void scale_by (
+ scalar_type cscale
+ )
+ {
+ alpha *= cscale;
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ do_train(x, cscale, xscale);
+ }
+
+ void swap (
+ kcentroid& item
+ )
+ {
+ exchange(my_remove_oldest_first, item.my_remove_oldest_first);
+ exchange(kernel, item.kernel);
+ exchange(w, item.w);
+ exchange(alpha, item.alpha);
+ exchange(my_tolerance, item.my_tolerance);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ exchange(samples_seen, item.samples_seen);
+ }
+
+ unsigned long dictionary_size (
+ ) const
+ {
+ if (samples_seen > 0)
+ return 1;
+ else
+ return 0;
+ }
+
+ friend void serialize(const kcentroid& item, std::ostream& out)
+ {
+ serialize(item.my_remove_oldest_first, out);
+ serialize(item.kernel, out);
+ serialize(item.w, out);
+ serialize(item.alpha, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.my_max_dictionary_size, out);
+ serialize(item.samples_seen, out);
+ }
+
+ friend void deserialize(kcentroid& item, std::istream& in)
+ {
+ deserialize(item.my_remove_oldest_first, in);
+ deserialize(item.kernel, in);
+ deserialize(item.w, in);
+ deserialize(item.alpha, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.my_max_dictionary_size, in);
+ deserialize(item.samples_seen, in);
+ }
+
+ distance_function<kernel_type> get_distance_function (
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
+ typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
+
+ temp_basis_vectors.set_size(1);
+ temp_basis_vectors(0) = sample_type(w.begin(), w.end());
+ temp_alpha.set_size(1);
+ temp_alpha(0) = alpha;
+
+ return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
+ }
+ else
+ {
+ return distance_function<kernel_type>(kernel);
+ }
+ }
+
+ private:
+
+ void do_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ const scalar_type temp = cscale*alpha;
+
+ if (temp != 0)
+ {
+ // compute w += xscale*x/temp
+ typename sample_type::const_iterator i;
+ for (i = x.begin(); i != x.end(); ++i)
+ {
+ w[i->first] += xscale*(i->second)/temp;
+ }
+
+ alpha = temp;
+ }
+ else
+ {
+ // first compute w = cscale*alpha*w
+ for (typename std::map<unsigned long,scalar_type>::iterator i = w.begin(); i != w.end(); ++i)
+ {
+ i->second *= cscale*alpha;
+ }
+
+ // now compute w += xscale*x
+ for (typename sample_type::const_iterator i = x.begin(); i != x.end(); ++i)
+ {
+ w[i->first] += xscale*(i->second);
+ }
+
+ alpha = 1;
+ }
+ }
+
+ bool my_remove_oldest_first;
+
+ kernel_type kernel;
+
+ std::map<unsigned long,scalar_type> w;
+ scalar_type alpha;
+
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+ scalar_type samples_seen;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Overloads for when kernel_type == offset_kernel<sparse_linear_kernel>
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ class kcentroid<offset_kernel<sparse_linear_kernel<T> > >
+ {
+
+ /*!
+ INITIAL VALUE
+ - x_extra == sqrt(kernel.offset)
+
+ CONVENTION
+ - x_extra == sqrt(kernel.offset)
+ - w_extra == the value of the extra dimension tacked onto the
+ end of the w vector
+ !*/
+
+ typedef offset_kernel<sparse_linear_kernel<T> > kernel_type;
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit kcentroid (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000,
+ bool remove_oldest_first_ = false
+ ) :
+ my_remove_oldest_first(remove_oldest_first_),
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0,
+ "\tkcentroid::kcentroid()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance_: " << tolerance_
+ << "\n\t max_dictionary_size_: " << max_dictionary_size_
+ );
+
+ x_extra = std::sqrt(kernel.offset);
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const { return my_tolerance; }
+ unsigned long max_dictionary_size() const { return my_max_dictionary_size; }
+ bool remove_oldest_first () const { return my_remove_oldest_first; }
+ const kernel_type& get_kernel () const { return kernel; }
+ scalar_type samples_trained () const { return samples_seen; }
+
+ void clear_dictionary ()
+ {
+ samples_seen = 0;
+ w.clear();
+ alpha = 0;
+ w_extra = x_extra;
+ }
+
+ scalar_type operator() (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::operator()(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (samples_seen > 0)
+ {
+ scalar_type temp1 = distance_squared(alpha,w , x.alpha,x.w);
+ scalar_type temp2 = alpha*w_extra - x.alpha*x.w_extra;
+ return std::sqrt(temp1 + temp2*temp2);
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ scalar_type inner_product (
+ const sample_type& x
+ ) const
+ {
+ if (samples_seen > 0)
+ return alpha*(dot(w,x) + w_extra*x_extra);
+ else
+ return 0;
+ }
+
+ scalar_type inner_product (
+ const kcentroid& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(x.get_kernel() == get_kernel(),
+ "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
+ << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
+ << "\n\tthis: " << this
+ );
+
+ if (samples_seen > 0 && x.samples_seen > 0)
+ return alpha*x.alpha*(dot(w,x.w) + w_extra*x.w_extra);
+ else
+ return 0;
+ }
+
+ scalar_type squared_norm (
+ ) const
+ {
+ if (samples_seen > 0)
+ return alpha*alpha*(length_squared(w) + w_extra*w_extra);
+ else
+ return 0;
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ scalar_type temp1 = distance_squared(1,x,alpha,w);
+ scalar_type temp2 = x_extra - alpha*w_extra;
+ return std::sqrt(temp1 + temp2*temp2);
+ }
+ else
+ {
+ return std::sqrt(length_squared(x) + x_extra*x_extra);
+ }
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void train (
+ const sample_type& x
+ )
+ {
+ ++samples_seen;
+ const scalar_type xscale = 1/samples_seen;
+ const scalar_type cscale = 1-xscale;
+
+ do_train(x, cscale, xscale);
+ }
+
+ scalar_type test_and_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+
+ do_train(x, cscale, xscale);
+
+ return (*this)(x);
+ }
+
+ void scale_by (
+ scalar_type cscale
+ )
+ {
+ alpha *= cscale;
+ w_extra *= cscale;
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+ ++samples_seen;
+ do_train(x, cscale, xscale);
+ }
+
+ void swap (
+ kcentroid& item
+ )
+ {
+ exchange(my_remove_oldest_first, item.my_remove_oldest_first);
+ exchange(kernel, item.kernel);
+ exchange(w, item.w);
+ exchange(alpha, item.alpha);
+ exchange(w_extra, item.w_extra);
+ exchange(x_extra, item.x_extra);
+ exchange(my_tolerance, item.my_tolerance);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ exchange(samples_seen, item.samples_seen);
+ }
+
+ unsigned long dictionary_size (
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
+ return 1;
+ else
+ return 2;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ friend void serialize(const kcentroid& item, std::ostream& out)
+ {
+ serialize(item.my_remove_oldest_first, out);
+ serialize(item.kernel, out);
+ serialize(item.w, out);
+ serialize(item.alpha, out);
+ serialize(item.w_extra, out);
+ serialize(item.x_extra, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.my_max_dictionary_size, out);
+ serialize(item.samples_seen, out);
+ }
+
+ friend void deserialize(kcentroid& item, std::istream& in)
+ {
+ deserialize(item.my_remove_oldest_first, in);
+ deserialize(item.kernel, in);
+ deserialize(item.w, in);
+ deserialize(item.alpha, in);
+ deserialize(item.w_extra, in);
+ deserialize(item.x_extra, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.my_max_dictionary_size, in);
+ deserialize(item.samples_seen, in);
+ }
+
+ distance_function<kernel_type> get_distance_function (
+ ) const
+ {
+ if (samples_seen > 0)
+ {
+ typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
+ typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
+
+ // What we are doing here needs a bit of explanation. The w vector
+ // has an implicit extra dimension tacked on to it with the value of w_extra.
+ // The kernel we are using takes normal vectors and implicitly tacks the value
+ // x_extra onto their end. So what we are doing here is scaling w so that
+ // the value it should have tacked onto it is x_scale. Note that we also
+ // adjust alpha so that the combination of alpha*w stays the same.
+ scalar_type scale;
+
+ // if w_extra is basically greater than 0
+ if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
+ {
+ scale = (x_extra/w_extra);
+ temp_basis_vectors.set_size(1);
+ temp_alpha.set_size(1);
+ temp_basis_vectors(0) = sample_type(w.begin(), w.end());
+ dlib::scale_by(temp_basis_vectors(0), scale);
+ temp_alpha(0) = alpha/scale;
+ }
+ else
+ {
+ // In this case w_extra is zero. So the only way we can get the same
+ // thing in the output basis vector set is by using two vectors
+ temp_basis_vectors.set_size(2);
+ temp_alpha.set_size(2);
+ temp_basis_vectors(0) = sample_type(w.begin(), w.end());
+ dlib::scale_by(temp_basis_vectors(0), 2);
+ temp_alpha(0) = alpha;
+ temp_basis_vectors(1) = sample_type(w.begin(), w.end());
+ temp_alpha(1) = -alpha;
+ }
+
+ return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
+
+ }
+ else
+ {
+ return distance_function<kernel_type>(kernel);
+ }
+
+ }
+
+ private:
+
+ void do_train (
+ const sample_type& x,
+ scalar_type cscale,
+ scalar_type xscale
+ )
+ {
+
+ const scalar_type temp = cscale*alpha;
+
+ if (temp != 0)
+ {
+ // compute w += xscale*x/temp
+ typename sample_type::const_iterator i;
+ for (i = x.begin(); i != x.end(); ++i)
+ {
+ w[i->first] += xscale*(i->second)/temp;
+ }
+
+ w_extra = w_extra + xscale*x_extra/temp;
+ alpha = temp;
+ }
+ else
+ {
+ // first compute w = cscale*alpha*w
+ for (typename std::map<unsigned long,scalar_type>::iterator i = w.begin(); i != w.end(); ++i)
+ {
+ i->second *= cscale*alpha;
+ }
+
+ // now compute w += xscale*x
+ for (typename sample_type::const_iterator i = x.begin(); i != x.end(); ++i)
+ {
+ w[i->first] += xscale*(i->second);
+ }
+
+
+ w_extra = cscale*alpha*w_extra + xscale*x_extra;
+ alpha = 1;
+ }
+ }
+
+ bool my_remove_oldest_first;
+
+ kernel_type kernel;
+
+ std::map<unsigned long,scalar_type> w;
+ scalar_type alpha;
+
+ scalar_type w_extra;
+ scalar_type x_extra;
+
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+ scalar_type samples_seen;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KCENTROId_OVERLOADS_
+
+
diff --git a/ml/dlib/dlib/svm/kernel.h b/ml/dlib/dlib/svm/kernel.h
new file mode 100644
index 000000000..907420986
--- /dev/null
+++ b/ml/dlib/dlib/svm/kernel.h
@@ -0,0 +1,569 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_KERNEL
+#define DLIB_SVm_KERNEL
+
+#include "kernel_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template < typename kernel_type > struct kernel_derivative;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct radial_basis_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ // T must be capable of representing a column vector.
+ COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
+
+ radial_basis_kernel(const scalar_type g) : gamma(g) {}
+ radial_basis_kernel() : gamma(0.1) {}
+ radial_basis_kernel(
+ const radial_basis_kernel& k
+ ) : gamma(k.gamma) {}
+
+
+ const scalar_type gamma;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ const scalar_type d = trans(a-b)*(a-b);
+ return std::exp(-gamma*d);
+ }
+
+ radial_basis_kernel& operator= (
+ const radial_basis_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ return *this;
+ }
+
+ bool operator== (
+ const radial_basis_kernel& k
+ ) const
+ {
+ return gamma == k.gamma;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const radial_basis_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type radial_basis_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ radial_basis_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type radial_basis_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ struct kernel_derivative<radial_basis_kernel<T> >
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ kernel_derivative(const radial_basis_kernel<T>& k_) : k(k_){}
+
+ const sample_type& operator() (const sample_type& x, const sample_type& y) const
+ {
+ // return the derivative of the rbf kernel
+ temp = 2*k.gamma*(x-y)*k(x,y);
+ return temp;
+ }
+
+ const radial_basis_kernel<T>& k;
+ mutable sample_type temp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct polynomial_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ // T must be capable of representing a column vector.
+ COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
+
+ polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {}
+ polynomial_kernel() : gamma(1), coef(0), degree(1) {}
+ polynomial_kernel(
+ const polynomial_kernel& k
+ ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {}
+
+ typedef T type;
+ const scalar_type gamma;
+ const scalar_type coef;
+ const scalar_type degree;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return std::pow(gamma*(trans(a)*b) + coef, degree);
+ }
+
+ polynomial_kernel& operator= (
+ const polynomial_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ const_cast<scalar_type&>(coef) = k.coef;
+ const_cast<scalar_type&>(degree) = k.degree;
+ return *this;
+ }
+
+ bool operator== (
+ const polynomial_kernel& k
+ ) const
+ {
+ return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree);
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const polynomial_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ serialize(item.coef, out);
+ serialize(item.degree, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type polynomial_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ polynomial_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ deserialize(const_cast<scalar_type&>(item.coef), in);
+ deserialize(const_cast<scalar_type&>(item.degree), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type polynomial_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ struct kernel_derivative<polynomial_kernel<T> >
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ kernel_derivative(const polynomial_kernel<T>& k_) : k(k_){}
+
+ const sample_type& operator() (const sample_type& x, const sample_type& y) const
+ {
+ // return the derivative of the rbf kernel
+ temp = k.degree*k.gamma*x*std::pow(k.gamma*(trans(x)*y) + k.coef, k.degree-1);
+ return temp;
+ }
+
+ const polynomial_kernel<T>& k;
+ mutable sample_type temp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sigmoid_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ // T must be capable of representing a column vector.
+ COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
+
+ sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {}
+ sigmoid_kernel() : gamma(0.1), coef(-1.0) {}
+ sigmoid_kernel(
+ const sigmoid_kernel& k
+ ) : gamma(k.gamma), coef(k.coef) {}
+
+ typedef T type;
+ const scalar_type gamma;
+ const scalar_type coef;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return std::tanh(gamma*(trans(a)*b) + coef);
+ }
+
+ sigmoid_kernel& operator= (
+ const sigmoid_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ const_cast<scalar_type&>(coef) = k.coef;
+ return *this;
+ }
+
+ bool operator== (
+ const sigmoid_kernel& k
+ ) const
+ {
+ return (gamma == k.gamma) && (coef == k.coef);
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sigmoid_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ serialize(item.coef, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type sigmoid_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sigmoid_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ deserialize(const_cast<scalar_type&>(item.coef), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type sigmoid_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ struct kernel_derivative<sigmoid_kernel<T> >
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ kernel_derivative(const sigmoid_kernel<T>& k_) : k(k_){}
+
+ const sample_type& operator() (const sample_type& x, const sample_type& y) const
+ {
+ // return the derivative of the rbf kernel
+ temp = k.gamma*x*(1-std::pow(k(x,y),2));
+ return temp;
+ }
+
+ const sigmoid_kernel<T>& k;
+ mutable sample_type temp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct linear_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ // T must be capable of representing a column vector.
+ COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return trans(a)*b;
+ }
+
+ bool operator== (
+ const linear_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const linear_kernel<T>& ,
+ std::ostream&
+ ){}
+
+ template <
+ typename T
+ >
+ void deserialize (
+ linear_kernel<T>& ,
+ std::istream&
+ ){}
+
+ template <
+ typename T
+ >
+ struct kernel_derivative<linear_kernel<T> >
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ kernel_derivative(const linear_kernel<T>& k_) : k(k_){}
+
+ const sample_type& operator() (const sample_type& x, const sample_type& ) const
+ {
+ return x;
+ }
+
+ const linear_kernel<T>& k;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct histogram_intersection_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ scalar_type temp = 0;
+ for (long i = 0; i < a.size(); ++i)
+ {
+ temp += std::min(a(i), b(i));
+ }
+ return temp;
+ }
+
+ bool operator== (
+ const histogram_intersection_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const histogram_intersection_kernel<T>& ,
+ std::ostream&
+ ){}
+
+ template <
+ typename T
+ >
+ void deserialize (
+ histogram_intersection_kernel<T>& ,
+ std::istream&
+ ){}
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct offset_kernel
+ {
+ typedef typename T::scalar_type scalar_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ offset_kernel(const T& k, const scalar_type& offset_
+ ) : kernel(k), offset(offset_) {}
+ offset_kernel() : kernel(T()), offset(0.01) {}
+ offset_kernel(
+ const offset_kernel& k
+ ) : kernel(k.kernel), offset(k.offset) {}
+
+ const T kernel;
+ const scalar_type offset;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return kernel(a,b) + offset;
+ }
+
+ offset_kernel& operator= (
+ const offset_kernel& k
+ )
+ {
+ const_cast<T&>(kernel) = k.kernel;
+ const_cast<scalar_type&>(offset) = k.offset;
+ return *this;
+ }
+
+ bool operator== (
+ const offset_kernel& k
+ ) const
+ {
+ return k.kernel == kernel && offset == k.offset;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const offset_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.offset, out);
+ serialize(item.kernel, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type offset_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ offset_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename offset_kernel<T>::scalar_type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.offset), in);
+ deserialize(const_cast<T&>(item.kernel), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type offset_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ struct kernel_derivative<offset_kernel<T> >
+ {
+ typedef typename T::scalar_type scalar_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ kernel_derivative(const offset_kernel<T>& k) : der(k.kernel){}
+
+ const sample_type operator() (const sample_type& x, const sample_type& y) const
+ {
+ return der(x,y);
+ }
+
+ kernel_derivative<T> der;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_KERNEL
+
+
diff --git a/ml/dlib/dlib/svm/kernel_abstract.h b/ml/dlib/dlib/svm/kernel_abstract.h
new file mode 100644
index 000000000..f72430eb8
--- /dev/null
+++ b/ml/dlib/dlib/svm/kernel_abstract.h
@@ -0,0 +1,681 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_KERNEL_ABSTRACT_
+#ifdef DLIB_SVm_KERNEL_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+/*!A Kernel_Function_Objects */
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ WHAT IS A KERNEL FUNCTION OBJECT?
+ In the context of the dlib library documentation a kernel function object
+ is an object with an interface with the following properties:
+ - a public typedef named sample_type
+ - a public typedef named scalar_type which should be a float, double, or
+ long double type.
+ - an overloaded operator() that operates on two items of sample_type
+ and returns a scalar_type.
+ (e.g. scalar_type val = kernel_function(sample1,sample2);
+ would be a valid expression)
+ - a public typedef named mem_manager_type that is an implementation of
+ dlib/memory_manager/memory_manager_kernel_abstract.h or
+ dlib/memory_manager_global/memory_manager_global_kernel_abstract.h or
+ dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
+ - an overloaded == operator that tells you if two kernels are
+ identical or not.
+
+ THREAD SAFETY
+ For a kernel function to be threadsafe it means that it must be safe to
+ evaluate an expression like val = kernel_function(sample1,sample2)
+ simultaneously from multiple threads, even when the threads operate on the same
+ object instances (i.e. kernel_function, sample1, and sample2). The most common
+ way to make this safe is to ensure that the kernel function does not mutate any
+ data, either in itself or in its arguments.
+
+ For examples of kernel functions see the following objects
+ (e.g. the radial_basis_kernel).
+ !*/
+
+ template <
+ typename T
+ >
+ struct radial_basis_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a dlib::matrix object
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a radial basis function kernel
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ const scalar_type gamma;
+
+ radial_basis_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 0.1
+ !*/
+
+ radial_basis_kernel(
+ const radial_basis_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ !*/
+
+ radial_basis_kernel(
+ const scalar_type g
+ );
+ /*!
+ ensures
+ - #gamma == g
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a.nc() == 1
+ - b.nc() == 1
+ - a.nr() == b.nr()
+ ensures
+ - returns exp(-gamma * ||a-b||^2)
+ !*/
+
+ radial_basis_kernel& operator= (
+ const radial_basis_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - returns *this
+ !*/
+
+ bool operator== (
+ const radial_basis_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const radial_basis_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for radial_basis_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ radial_basis_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for radial_basis_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sigmoid_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a dlib::matrix object
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a sigmoid kernel
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ const scalar_type gamma;
+ const scalar_type coef;
+
+ sigmoid_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 0.1
+ - #coef == -1.0
+ !*/
+
+ sigmoid_kernel(
+ const sigmoid_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ - #coef == k.coef
+ !*/
+
+ sigmoid_kernel(
+ const scalar_type g,
+ const scalar_type c
+ );
+ /*!
+ ensures
+ - #gamma == g
+ - #coef == c
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a.nc() == 1
+ - b.nc() == 1
+ - a.nr() == b.nr()
+ ensures
+ - returns tanh(gamma*trans(a)*b + coef)
+ !*/
+
+ sigmoid_kernel& operator= (
+ const sigmoid_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - #coef = k.coef
+ - returns *this
+ !*/
+
+ bool operator== (
+ const sigmoid_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sigmoid_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sigmoid_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sigmoid_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sigmoid_kernel
+ !*/
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct polynomial_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a dlib::matrix object
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a polynomial kernel
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ const scalar_type gamma;
+ const scalar_type coef;
+ const scalar_type degree;
+
+ polynomial_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 1
+ - #coef == 0
+ - #degree == 1
+ !*/
+
+ polynomial_kernel(
+ const polynomial_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ - #coef == k.coef
+ - #degree == k.degree
+ !*/
+
+ polynomial_kernel(
+ const scalar_type g,
+ const scalar_type c,
+ const scalar_type d
+ );
+ /*!
+ ensures
+ - #gamma == g
+ - #coef == c
+ - #degree == d
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a.nc() == 1
+ - b.nc() == 1
+ - a.nr() == b.nr()
+ ensures
+ - returns pow(gamma*trans(a)*b + coef, degree)
+ !*/
+
+ polynomial_kernel& operator= (
+ const polynomial_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - #coef = k.coef
+ - #degree = k.degree
+ - returns *this
+ !*/
+
+ bool operator== (
+ const polynomial_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const polynomial_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for polynomial_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ polynomial_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for polynomial_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct linear_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a dlib::matrix object
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a linear function kernel
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a.nc() == 1
+ - b.nc() == 1
+ - a.nr() == b.nr()
+ ensures
+ - returns trans(a)*b
+ !*/
+
+ bool operator== (
+ const linear_kernel& k
+ ) const;
+ /*!
+ ensures
+ - returns true
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const linear_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for linear_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ linear_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for linear_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct histogram_intersection_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a dlib::matrix object
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a histogram intersection kernel kernel
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - is_vector(a)
+ - is_vector(b)
+ - a.size() == b.size()
+ - min(a) >= 0
+ - min(b) >= 0
+ ensures
+ - returns sum over all i: std::min(a(i), b(i))
+ !*/
+
+ bool operator== (
+ const histogram_intersection_kernel& k
+ ) const;
+ /*!
+ ensures
+ - returns true
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const histogram_intersection_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for histogram_intersection_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ histogram_intersection_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for histogram_intersection_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct offset_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be a kernel object (e.g. radial_basis_kernel, polynomial_kernel, etc.)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a kernel with a fixed value offset
+ added to it.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::scalar_type scalar_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ const T kernel;
+ const scalar_type offset;
+
+ offset_kernel(
+ );
+ /*!
+ ensures
+ - #offset == 0.01
+ !*/
+
+ offset_kernel(
+ const offset_kernel& k
+ );
+ /*!
+ ensures
+ - #offset == k.offset
+ - #kernel == k.kernel
+ !*/
+
+ offset_kernel(
+ const T& k,
+ const scalar_type& off
+ );
+ /*!
+ ensures
+ - #kernel == k
+ - #offset == off
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ ensures
+ - returns kernel(a,b) + offset
+ !*/
+
+ offset_kernel& operator= (
+ const offset_kernel& k
+ );
+ /*!
+ ensures
+ - #offset == k.offset
+ - #kernel == k.kernel
+ !*/
+
+ bool operator== (
+ const offset_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const offset_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for offset_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ offset_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for offset_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ struct kernel_derivative
+ {
+ /*!
+ REQUIREMENTS ON kernel_type
+ kernel_type must be one of the following kernel types:
+ - radial_basis_kernel
+ - polynomial_kernel
+ - sigmoid_kernel
+ - linear_kernel
+ - offset_kernel
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a function object that computes the derivative of a kernel
+ function object.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ Instances of this object are allowed to have a mutable cache which is
+ used by const member functions. Therefore, it is not safe to use one
+ instance of this object from multiple threads (unless protected by a
+ mutex).
+ !*/
+
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ kernel_derivative(
+ const kernel_type& k_
+ );
+ /*!
+ ensures
+ - this object will return derivatives of the kernel object k_
+ - #k == k_
+ !*/
+
+ const sample_type operator() (
+ const sample_type& x,
+ const sample_type& y
+ ) const;
+ /*!
+ ensures
+ - returns the derivative of k with respect to y.
+ !*/
+
+ const kernel_type& k;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_KERNEL_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/kernel_matrix.h b/ml/dlib/dlib/svm/kernel_matrix.h
new file mode 100644
index 000000000..f6e1e0b90
--- /dev/null
+++ b/ml/dlib/dlib/svm/kernel_matrix.h
@@ -0,0 +1,268 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_KERNEL_MATRIX_
+#define DLIB_SVm_KERNEL_MATRIX_
+
+#include <vector>
+#include "kernel_matrix_abstract.h"
+#include "../matrix.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename kernel_type, typename T>
+ inline const typename T::type& access ( const matrix_exp<T>& m, long i)
+ {
+ return m(i);
+ }
+
+ // bind to anything that looks like an array and isn't a matrix
+ template <typename kernel_type, typename T>
+ inline const typename disable_if<is_matrix<T>,typename T::type>::type& access ( const T& m, long i)
+ {
+ return m[i];
+ }
+
+ // Only use this function if T isn't a std::pair because in that case the entire vector is
+ // probably itself a sparse sample.
+ template <typename kernel_type, typename T, typename alloc>
+ inline typename disable_if<is_pair<T>,const T&>::type access ( const std::vector<T,alloc>& m, long i)
+ {
+ return m[i];
+ }
+
+ // Only use this function if T isn't a std::pair because in that case the entire vector is
+ // probably a sparse sample.
+ template <typename kernel_type, typename T, typename alloc>
+ inline typename disable_if<is_pair<T>,const T&>::type access ( const std_vector_c<T,alloc>& m, long i)
+ {
+ return m[i];
+ }
+
+ template <typename kernel_type>
+ inline const typename kernel_type::sample_type& access (
+ const typename kernel_type::sample_type& samp,
+ long
+ )
+ {
+ return samp;
+ }
+
+ // --------------------------------------------
+
+ template <typename kernel_type, typename T>
+ inline typename disable_if<is_same_type<T,typename kernel_type::sample_type>,unsigned long>::type
+ size ( const T& m)
+ {
+ return m.size();
+ }
+
+ template <typename kernel_type>
+ inline size_t size (
+ const typename kernel_type::sample_type&
+ )
+ {
+ return 1;
+ }
+
+ // --------------------------------------------
+
+ template <typename T>
+ typename disable_if<is_matrix<T> >::type assert_is_vector(const T&)
+ {}
+
+ template <typename T>
+ // This funny #ifdef thing is here because gcc sometimes gives a warning
+ // about v being unused otherwise.
+#ifdef ENABLE_ASSERTS
+ void assert_is_vector(const matrix_exp<T>& v)
+#else
+ void assert_is_vector(const matrix_exp<T>& )
+#endif
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(v) == true,
+ "\tconst matrix_exp kernel_matrix()"
+ << "\n\t You have to supply this function with row or column vectors"
+ << "\n\t v.nr(): " << v.nr()
+ << "\n\t v.nc(): " << v.nc()
+ );
+ }
+
+ }
+
+ template <typename K, typename vect_type1, typename vect_type2>
+ struct op_kern_mat
+ {
+ op_kern_mat(
+ const K& kern_,
+ const vect_type1& vect1_,
+ const vect_type2& vect2_
+ ) :
+ kern(kern_),
+ vect1(vect1_),
+ vect2(vect2_)
+ {
+ // make sure the requires clauses get checked eventually
+ impl::assert_is_vector(vect1);
+ impl::assert_is_vector(vect2);
+ }
+
+ const K& kern;
+ const vect_type1& vect1;
+ const vect_type2& vect2;
+
+ typedef typename K::scalar_type type;
+
+ const static long cost = 100;
+ const static long NR = (is_same_type<vect_type1,typename K::sample_type>::value) ? 1 : 0;
+ const static long NC = (is_same_type<vect_type2,typename K::sample_type>::value) ? 1 : 0;
+
+ typedef const type const_ret_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const
+ {
+ return kern(impl::access<K>(vect1,r), impl::access<K>(vect2,c));
+ }
+
+ long nr () const { return impl::size<K>(vect1); }
+ long nc () const { return impl::size<K>(vect2); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item ) const { return alias_helper(item.ref()); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item ) const { return alias_helper(item.ref()); }
+
+ template <typename U> bool alias_helper ( const U& ) const { return false; }
+
+ typedef typename K::sample_type samp_type;
+
+ // Say we destructively alias if one of the vect* objects is actually item.
+ bool alias_helper (const samp_type& item ) const { return are_same(item, vect1) || are_same(item, vect2); }
+ template <typename U> bool are_same (const samp_type& , const U& ) const { return false; }
+ bool are_same (const samp_type& a, const samp_type& b) const { return (&a == &b); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename V1,
+ typename V2
+ >
+ const matrix_op<op_kern_mat<K,V1,V2> > kernel_matrix (
+ const K& kern,
+ const V1& v1,
+ const V2& v2
+ )
+ {
+ typedef op_kern_mat<K,V1,V2> op;
+ return matrix_op<op>(op(kern,v1,v2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ It is possible to implement the kernel_matrix() operator with just one operator
+ class but treating the version that takes only a single vector separately
+ leads to more efficient output by gcc in certain instances.
+ */
+ template <typename K, typename vect_type1>
+ struct op_kern_mat_single
+ {
+ op_kern_mat_single(
+ const K& kern_,
+ const vect_type1& vect1_
+ ) :
+ kern(kern_),
+ vect1(vect1_)
+ {
+ // make sure the requires clauses get checked eventually
+ impl::assert_is_vector(vect1);
+ }
+
+ const K& kern;
+ const vect_type1& vect1;
+
+ typedef typename K::scalar_type type;
+
+ const static long cost = 100;
+ const static long NR = (is_same_type<vect_type1,typename K::sample_type>::value) ? 1 : 0;
+ const static long NC = (is_same_type<vect_type1,typename K::sample_type>::value) ? 1 : 0;
+
+ typedef const type const_ret_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+ typedef row_major_layout layout_type;
+
+ const_ret_type apply (long r, long c ) const
+ {
+ return kern(impl::access<K>(vect1,r), impl::access<K>(vect1,c));
+ }
+
+ long nr () const { return impl::size<K>(vect1); }
+ long nc () const { return impl::size<K>(vect1); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item ) const { return alias_helper(item.ref()); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item ) const { return alias_helper(item.ref()); }
+
+ template <typename U> bool alias_helper ( const U& ) const { return false; }
+
+ typedef typename K::sample_type samp_type;
+
+ // Say we destructively alias if vect1 is actually item.
+ bool alias_helper (const samp_type& item ) const { return are_same(item, vect1); }
+ template <typename U> bool are_same (const samp_type& , const U& ) const { return false; }
+ bool are_same (const samp_type& a, const samp_type& b) const { return (&a == &b); }
+ };
+
+ template <
+ typename K,
+ typename V
+ >
+ const matrix_op<op_kern_mat_single<K,V> > kernel_matrix (
+ const K& kern,
+ const V& v
+ )
+ {
+ typedef op_kern_mat_single<K,V> op;
+ return matrix_op<op>(op(kern,v));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_dest_type,
+ typename K,
+ typename V
+ >
+ inline void matrix_assign (
+ matrix_dest_type& dest,
+ const matrix_exp<matrix_op<op_kern_mat_single<K,V> > >& src
+ )
+ /*!
+ Overload matrix assignment so that when a kernel_matrix expression
+ gets assigned it only evaluates half the kernel matrix (since it is symmetric)
+ !*/
+ {
+ for (long r = 0; r < src.nr(); ++r)
+ {
+ for (long c = r; c < src.nc(); ++c)
+ {
+ dest(r,c) = dest(c,r) = src(r,c);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_KERNEL_MATRIX_
+
diff --git a/ml/dlib/dlib/svm/kernel_matrix_abstract.h b/ml/dlib/dlib/svm/kernel_matrix_abstract.h
new file mode 100644
index 000000000..4aa4b1ce2
--- /dev/null
+++ b/ml/dlib/dlib/svm/kernel_matrix_abstract.h
@@ -0,0 +1,115 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_KERNEL_MATRIX_ABSTRACT_
+#ifdef DLIB_SVm_KERNEL_MATRIX_ABSTRACT_
+
+#include <vector>
+#include "kernel_abstract.h"
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename V
+ >
+ const matrix_exp kernel_matrix (
+ const kernel_type& kernel,
+ const V& v
+ );
+ /*!
+ requires
+ - kernel == a kernel function object as defined by the file dlib/svm/kernel_abstract.h.
+ This kernel must also be capable of operating on the contents of v.
+ - V == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector,
+ dlib::linearly_independent_subset_finder, or kernel_type::sample_type.
+ - if (V is a dlib::matrix) then
+ - is_vector(v) == true
+ ensures
+ - if (V is of type kernel_type::sample_type) then
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R.size() == 1
+ - R(0,0) == kernel(v,v)
+ - else
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R is a square matrix of v.size() rows by v.size() columns
+ - for all valid r and c:
+ - R(r,c) == kernel(v(r), v(c))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename V1,
+ typename V2
+ >
+ const matrix_exp kernel_matrix (
+ const kernel_type& kernel,
+ const V1& v1,
+ const V2& v2
+ );
+ /*!
+ requires
+ - kernel == a kernel function object as defined by the file dlib/svm/kernel_abstract.h
+ This kernel must also be capable of operating on the contents of v1 and v2.
+ - V1 == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector,
+ dlib::linearly_independent_subset_finder, or kernel_type::sample_type.
+ - V2 == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector,
+ dlib::linearly_independent_subset_finder, or kernel_type::sample_type.
+ - if (V1 is a dlib::matrix) then
+ - is_vector(v1) == true
+ - if (V2 is a dlib::matrix) then
+ - is_vector(v2) == true
+ ensures
+ - if (V1 and V2 are of type kernel_type::sample_type) then
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R.size() == 1
+ - R(0,0) == kernel(v1,v2)
+ - else if (V1 is of type kernel_type::sample_type) then
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R.nr() == 1
+ - R.nc() == v2.size()
+ - for all valid c:
+ - R(0,c) == kernel(v1, v2(c))
+ - else if (V2 is of type kernel_type::sample_type) then
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R.nr() == v1.size()
+ - R.nc() == 1
+ - for all valid r:
+ - R(r,0) == kernel(v1(r), v2)
+ - else
+ - returns a matrix R such that:
+ - R::type == kernel_type::scalar_type
+ - R.nr() == v1.size()
+ - R.nc() == v2.size()
+ - for all valid r and c:
+ - R(r,c) == kernel(v1(r), v2(c))
+
+
+ A note about aliasing (see the examples/matrix_expressions_ex.cpp example program
+ for a discussion of what aliasing is in the context of the dlib::matrix):
+ kernel_matrix() expressions can detect aliasing of an argument if that
+ argument is of type kernel_type::sample_type. However, it can't detect
+ aliasing though std::vectors or other "list of sample type" container class
+ arguments. This means that it is safe to assign a kernel_matrix() expression
+ to a sample_type if V1 or V2 are of sample_type but not safe otherwise. However,
+ since the latter case results in a general n by m matrix rather than a column
+ or row vector you shouldn't ever be doing it anyway.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_KERNEL_MATRIX_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/kkmeans.h b/ml/dlib/dlib/svm/kkmeans.h
new file mode 100644
index 000000000..4c72106d8
--- /dev/null
+++ b/ml/dlib/dlib/svm/kkmeans.h
@@ -0,0 +1,654 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KKMEANs_
+#define DLIB_KKMEANs_
+
+#include <cmath>
+#include <vector>
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel.h"
+#include "../array.h"
+#include "kcentroid.h"
+#include "kkmeans_abstract.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+ template <
+ typename kernel_type
+ >
+ class kkmeans : public noncopyable
+ {
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ kkmeans (
+ const kcentroid<kernel_type>& kc_
+ ):
+ kc(kc_),
+ min_change(0.01)
+ {
+ set_number_of_centers(1);
+ }
+
+ ~kkmeans()
+ {
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kc.get_kernel();
+ }
+
+ void set_kcentroid (
+ const kcentroid<kernel_type>& kc_
+ )
+ {
+ kc = kc_;
+ set_number_of_centers(number_of_centers());
+ }
+
+ const kcentroid<kernel_type>& get_kcentroid (
+ unsigned long i
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(i < number_of_centers(),
+ "\tkcentroid kkmeans::get_kcentroid(i)"
+ << "\n\tYou have given an invalid value for i"
+ << "\n\ti: " << i
+ << "\n\tnumber_of_centers(): " << number_of_centers()
+ << "\n\tthis: " << this
+ );
+
+ return *centers[i];
+ }
+
+ void set_number_of_centers (
+ unsigned long num
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num > 0,
+ "\tvoid kkmeans::set_number_of_centers()"
+ << "\n\tYou can't set the number of centers to zero"
+ << "\n\tthis: " << this
+ );
+
+ centers.set_max_size(num);
+ centers.set_size(num);
+
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ {
+ centers[i].reset(new kcentroid<kernel_type>(kc));
+ }
+ }
+
+ unsigned long number_of_centers (
+ ) const
+ {
+ return centers.size();
+ }
+
+ template <typename T, typename U>
+ void train (
+ const T& samples,
+ const U& initial_centers,
+ long max_iter = 1000
+ )
+ {
+ do_train(mat(samples),mat(initial_centers),max_iter);
+ }
+
+ unsigned long operator() (
+ const sample_type& sample
+ ) const
+ {
+ unsigned long label = 0;
+ scalar_type best_score = (*centers[0])(sample);
+
+ // figure out which center the given sample is closest too
+ for (unsigned long i = 1; i < centers.size(); ++i)
+ {
+ scalar_type temp = (*centers[i])(sample);
+ if (temp < best_score)
+ {
+ label = i;
+ best_score = temp;
+ }
+ }
+
+ return label;
+ }
+
+ void set_min_change (
+ scalar_type min_change_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT( 0 <= min_change_ < 1,
+ "\tvoid kkmeans::set_min_change()"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tthis: " << this
+ << "\n\tmin_change_: " << min_change_
+ );
+ min_change = min_change_;
+ }
+
+ const scalar_type get_min_change (
+ ) const
+ {
+ return min_change;
+ }
+
+ void swap (
+ kkmeans& item
+ )
+ {
+ centers.swap(item.centers);
+ kc.swap(item.kc);
+ assignments.swap(item.assignments);
+ exchange(min_change, item.min_change);
+ }
+
+ friend void serialize(const kkmeans& item, std::ostream& out)
+ {
+ serialize(item.centers.size(),out);
+ for (unsigned long i = 0; i < item.centers.size(); ++i)
+ {
+ serialize(*item.centers[i], out);
+ }
+ serialize(item.kc, out);
+ serialize(item.min_change, out);
+ }
+
+ friend void deserialize(kkmeans& item, std::istream& in)
+ {
+ unsigned long num;
+ deserialize(num, in);
+ item.centers.resize(num);
+ for (unsigned long i = 0; i < item.centers.size(); ++i)
+ {
+ std::unique_ptr<kcentroid<kernel_type> > temp(new kcentroid<kernel_type>(kernel_type()));
+ deserialize(*temp, in);
+ item.centers[i].swap(temp);
+ }
+
+ deserialize(item.kc, in);
+ deserialize(item.min_change, in);
+ }
+
+ private:
+
+ template <typename matrix_type, typename matrix_type2>
+ void do_train (
+ const matrix_type& samples,
+ const matrix_type2& initial_centers,
+ long max_iter = 1000
+ )
+ {
+ COMPILE_TIME_ASSERT((is_same_type<typename matrix_type::type, sample_type>::value));
+ COMPILE_TIME_ASSERT((is_same_type<typename matrix_type2::type, sample_type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.nc() == 1 && initial_centers.nc() == 1 &&
+ initial_centers.nr() == static_cast<long>(number_of_centers()),
+ "\tvoid kkmeans::train()"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tthis: " << this
+ << "\n\tsamples.nc(): " << samples.nc()
+ << "\n\tinitial_centers.nc(): " << initial_centers.nc()
+ << "\n\tinitial_centers.nr(): " << initial_centers.nr()
+ );
+
+ // clear out the old data and initialize the centers
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ {
+ centers[i]->clear_dictionary();
+ centers[i]->train(initial_centers(i));
+ }
+
+ assignments.resize(samples.size());
+
+ bool assignment_changed = true;
+
+ // loop until the centers stabilize
+ long count = 0;
+ const unsigned long min_num_change = static_cast<unsigned long>(min_change*samples.size());
+ unsigned long num_changed = min_num_change;
+ while (assignment_changed && count < max_iter && num_changed >= min_num_change)
+ {
+ ++count;
+ assignment_changed = false;
+ num_changed = 0;
+
+ // loop over all the samples and assign them to their closest centers
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ // find the best center
+ unsigned long best_center = 0;
+ scalar_type best_score = (*centers[0])(samples(i));
+ for (unsigned long c = 1; c < centers.size(); ++c)
+ {
+ scalar_type temp = (*centers[c])(samples(i));
+ if (temp < best_score)
+ {
+ best_score = temp;
+ best_center = c;
+ }
+ }
+
+ // if the current sample changed centers then make note of that
+ if (assignments[i] != best_center)
+ {
+ assignments[i] = best_center;
+ assignment_changed = true;
+ ++num_changed;
+ }
+ }
+
+ if (assignment_changed)
+ {
+ // now clear out the old data
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ centers[i]->clear_dictionary();
+
+ // recalculate the cluster centers
+ for (unsigned long i = 0; i < assignments.size(); ++i)
+ centers[assignments[i]]->train(samples(i));
+ }
+
+ }
+
+
+ }
+
+ array<std::unique_ptr<kcentroid<kernel_type> > > centers;
+ kcentroid<kernel_type> kc;
+ scalar_type min_change;
+
+ // temp variables
+ array<unsigned long> assignments;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void swap(kkmeans<kernel_type>& a, kkmeans<kernel_type>& b)
+ { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ struct dlib_pick_initial_centers_data
+ {
+ dlib_pick_initial_centers_data():idx(0), dist(std::numeric_limits<double>::infinity()){}
+ long idx;
+ double dist;
+ bool operator< (const dlib_pick_initial_centers_data& d) const { return dist < d.dist; }
+ };
+
+ template <
+ typename vector_type1,
+ typename vector_type2,
+ typename kernel_type
+ >
+ void pick_initial_centers(
+ long num_centers,
+ vector_type1& centers,
+ const vector_type2& samples,
+ const kernel_type& k,
+ double percentile = 0.01
+ )
+ {
+ /*
+ This function is basically just a non-randomized version of the kmeans++ algorithm
+ described in the paper:
+ kmeans++: The Advantages of Careful Seeding by Arthur and Vassilvitskii
+
+ */
+
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1,
+ "\tvoid pick_initial_centers()"
+ << "\n\tYou passed invalid arguments to this function"
+ << "\n\tnum_centers: " << num_centers
+ << "\n\tpercentile: " << percentile
+ << "\n\tsamples.size(): " << samples.size()
+ );
+
+ std::vector<dlib_pick_initial_centers_data> scores(samples.size());
+ std::vector<dlib_pick_initial_centers_data> scores_sorted(samples.size());
+ centers.clear();
+
+ // pick the first sample as one of the centers
+ centers.push_back(samples[0]);
+
+ const long best_idx = static_cast<long>(std::max(0.0,samples.size() - samples.size()*percentile - 1));
+
+ // pick the next center
+ for (long i = 0; i < num_centers-1; ++i)
+ {
+ // Loop over the samples and compare them to the most recent center. Store
+ // the distance from each sample to its closest center in scores.
+ const double k_cc = k(centers[i], centers[i]);
+ for (unsigned long s = 0; s < samples.size(); ++s)
+ {
+ // compute the distance between this sample and the current center
+ const double dist = k_cc + k(samples[s],samples[s]) - 2*k(samples[s], centers[i]);
+
+ if (dist < scores[s].dist)
+ {
+ scores[s].dist = dist;
+ scores[s].idx = s;
+ }
+ }
+
+ scores_sorted = scores;
+
+ // now find the winning center and add it to centers. It is the one that is
+ // far away from all the other centers.
+ sort(scores_sorted.begin(), scores_sorted.end());
+ centers.push_back(samples[scores_sorted[best_idx].idx]);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type1,
+ typename vector_type2
+ >
+ void pick_initial_centers(
+ long num_centers,
+ vector_type1& centers,
+ const vector_type2& samples,
+ double percentile = 0.01
+ )
+ {
+ typedef typename vector_type1::value_type sample_type;
+ linear_kernel<sample_type> kern;
+ pick_initial_centers(num_centers, centers, samples, kern, percentile);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename sample_type,
+ typename alloc
+ >
+ void find_clusters_using_kmeans (
+ const array_type& samples,
+ std::vector<sample_type, alloc>& centers,
+ unsigned long max_iter = 1000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0 && centers.size() > 0,
+ "\tvoid find_clusters_using_kmeans()"
+ << "\n\tYou passed invalid arguments to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t centers.size(): " << centers.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ {
+ const long nr = samples[0].nr();
+ const long nc = samples[0].nc();
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ DLIB_ASSERT(is_vector(samples[i]) && samples[i].nr() == nr && samples[i].nc() == nc,
+ "\tvoid find_clusters_using_kmeans()"
+ << "\n\t You passed invalid arguments to this function"
+ << "\n\t is_vector(samples[i]): " << is_vector(samples[i])
+ << "\n\t samples[i].nr(): " << samples[i].nr()
+ << "\n\t nr: " << nr
+ << "\n\t samples[i].nc(): " << samples[i].nc()
+ << "\n\t nc: " << nc
+ << "\n\t i: " << i
+ );
+ }
+ }
+#endif
+
+ typedef typename sample_type::type scalar_type;
+
+ sample_type zero(centers[0]);
+ set_all_elements(zero, 0);
+
+ std::vector<unsigned long> center_element_count;
+
+ // tells which center a sample belongs to
+ std::vector<unsigned long> assignments(samples.size(), samples.size());
+
+
+ unsigned long iter = 0;
+ bool centers_changed = true;
+ while (centers_changed && iter < max_iter)
+ {
+ ++iter;
+ centers_changed = false;
+ center_element_count.assign(centers.size(), 0);
+
+ // loop over each sample and see which center it is closest to
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // find the best center for sample[i]
+ scalar_type best_dist = std::numeric_limits<scalar_type>::max();
+ unsigned long best_center = 0;
+ for (unsigned long j = 0; j < centers.size(); ++j)
+ {
+ scalar_type dist = length(centers[j] - samples[i]);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_center = j;
+ }
+ }
+
+ if (assignments[i] != best_center)
+ {
+ centers_changed = true;
+ assignments[i] = best_center;
+ }
+
+ center_element_count[best_center] += 1;
+ }
+
+ // now update all the centers
+ centers.assign(centers.size(), zero);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ centers[assignments[i]] += samples[i];
+ }
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ {
+ if (center_element_count[i] != 0)
+ centers[i] /= center_element_count[i];
+ }
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename sample_type,
+ typename alloc
+ >
+ void find_clusters_using_angular_kmeans (
+ const array_type& samples,
+ std::vector<sample_type, alloc>& centers,
+ unsigned long max_iter = 1000
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(samples.size() > 0 && centers.size() > 0,
+ "\tvoid find_clusters_using_angular_kmeans()"
+ << "\n\tYou passed invalid arguments to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t centers.size(): " << centers.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ {
+ const long nr = samples[0].nr();
+ const long nc = samples[0].nc();
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ DLIB_ASSERT(is_vector(samples[i]) && samples[i].nr() == nr && samples[i].nc() == nc,
+ "\tvoid find_clusters_using_angular_kmeans()"
+ << "\n\t You passed invalid arguments to this function"
+ << "\n\t is_vector(samples[i]): " << is_vector(samples[i])
+ << "\n\t samples[i].nr(): " << samples[i].nr()
+ << "\n\t nr: " << nr
+ << "\n\t samples[i].nc(): " << samples[i].nc()
+ << "\n\t nc: " << nc
+ << "\n\t i: " << i
+ );
+ }
+ }
+#endif
+
+ typedef typename sample_type::type scalar_type;
+
+ sample_type zero(centers[0]);
+ set_all_elements(zero, 0);
+
+ unsigned long seed = 0;
+
+ // tells which center a sample belongs to
+ std::vector<unsigned long> assignments(samples.size(), samples.size());
+ std::vector<double> lengths;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ lengths.push_back(length(samples[i]));
+ // If there are zero vectors in samples then just say their length is 1 so we
+ // can avoid a division by zero check later on. Also, this doesn't matter
+ // since zero vectors can be assigned to any cluster randomly as there is no
+ // basis for picking one based on angle.
+ if (lengths.back() == 0)
+ lengths.back() = 1;
+ }
+
+ // We will keep the centers as unit vectors at all times throughout the processing.
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ {
+ double len = length(centers[i]);
+ // Avoid having length 0 centers. If that is the case then pick another center
+ // at random.
+ while(len == 0)
+ {
+ centers[i] = matrix_cast<scalar_type>(gaussian_randm(centers[i].nr(), centers[i].nc(), seed++));
+ len = length(centers[i]);
+ }
+ centers[i] /= len;
+ }
+
+
+ unsigned long iter = 0;
+ bool centers_changed = true;
+ while (centers_changed && iter < max_iter)
+ {
+ ++iter;
+ centers_changed = false;
+
+ // loop over each sample and see which center it is closest to
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // find the best center for sample[i]
+ scalar_type best_angle = std::numeric_limits<scalar_type>::max();
+ unsigned long best_center = 0;
+ for (unsigned long j = 0; j < centers.size(); ++j)
+ {
+ scalar_type angle = -dot(centers[j],samples[i])/lengths[i];
+
+ if (angle < best_angle)
+ {
+ best_angle = angle;
+ best_center = j;
+ }
+ }
+
+ if (assignments[i] != best_center)
+ {
+ centers_changed = true;
+ assignments[i] = best_center;
+ }
+ }
+
+ // now update all the centers
+ centers.assign(centers.size(), zero);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ centers[assignments[i]] += samples[i];
+ }
+ // Now length normalize all the centers.
+ for (unsigned long i = 0; i < centers.size(); ++i)
+ {
+ double len = length(centers[i]);
+ // Avoid having length 0 centers. If that is the case then pick another center
+ // at random.
+ while(len == 0)
+ {
+ centers[i] = matrix_cast<scalar_type>(gaussian_randm(centers[i].nr(), centers[i].nc(), seed++));
+ len = length(centers[i]);
+ centers_changed = true;
+ }
+ centers[i] /= len;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename EXP
+ >
+ unsigned long nearest_center (
+ const array_type& centers,
+ const matrix_exp<EXP>& sample
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(centers.size() > 0 && sample.size() > 0 && is_vector(sample),
+ "\t unsigned long nearest_center()"
+ << "\n\t You have given invalid inputs to this function."
+ << "\n\t centers.size(): " << centers.size()
+ << "\n\t sample.size(): " << sample.size()
+ << "\n\t is_vector(sample): " << is_vector(sample)
+ );
+
+ double best_dist = length_squared(centers[0] - sample);
+ unsigned long best_idx = 0;
+ for (unsigned long i = 1; i < centers.size(); ++i)
+ {
+ const double dist = length_squared(centers[i] - sample);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ best_idx = i;
+ }
+ }
+ return best_idx;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KKMEANs_
+
+
diff --git a/ml/dlib/dlib/svm/kkmeans_abstract.h b/ml/dlib/dlib/svm/kkmeans_abstract.h
new file mode 100644
index 000000000..9f9d7ccce
--- /dev/null
+++ b/ml/dlib/dlib/svm/kkmeans_abstract.h
@@ -0,0 +1,365 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KKMEANs_ABSTRACT_
+#ifdef DLIB_KKMEANs_ABSTRACT_
+
+#include <cmath>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel_abstract.h"
+#include "kcentroid_abstract.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+ template <
+ typename kernel_type
+ >
+ class kkmeans : public noncopyable
+ {
+ /*!
+ REQUIREMENTS ON kernel_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - number_of_centers() == 1
+ - get_min_change() == 0.01
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of a kernelized k-means clustering algorithm.
+ It performs k-means clustering by using the kcentroid object.
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ kkmeans (
+ const kcentroid<kernel_type>& kc_
+ );
+ /*!
+ ensures
+ - #number_of_centers() == 1
+ - #get_min_change() == 0.01
+ - #get_kcentroid(0) == a copy of kc_
+ !*/
+
+ ~kkmeans(
+ );
+ /*!
+ ensures
+ - all resources associated with *this have been released
+ !*/
+
+ void set_kcentroid (
+ const kcentroid<kernel_type>& kc_
+ );
+ /*!
+ ensures
+ - for all idx:
+ - #get_kcentroid(idx) == a copy of kc_
+ !*/
+
+ const kcentroid<kernel_type>& get_kcentroid (
+ unsigned long i
+ ) const;
+ /*!
+ requires
+ - i < number_of_centers()
+ ensures
+ - returns a const reference to the ith kcentroid object contained in
+ this object. Each kcentroid represents one of the centers found
+ by the k-means clustering algorithm.
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the kernel used by this object
+ !*/
+
+ void set_number_of_centers (
+ unsigned long num
+ );
+ /*!
+ requires
+ - num > 0
+ ensures
+ - #number_of_centers() == num
+ !*/
+
+ unsigned long number_of_centers (
+ ) const;
+ /*!
+ ensures
+ - returns the number of centers used in this instance of the k-means clustering
+ algorithm.
+ !*/
+
+ template <
+ typename matrix_type,
+ typename matrix_type2
+ >
+ void train (
+ const matrix_type& samples,
+ const matrix_type2& initial_centers,
+ long max_iter = 1000
+ );
+ /*!
+ requires
+ - matrix_type and matrix_type2 must either be dlib::matrix objects or convertible to dlib::matrix
+ via mat()
+ - matrix_type::type == sample_type (i.e. matrix_type should contain sample_type objects)
+ - matrix_type2::type == sample_type (i.e. matrix_type2 should contain sample_type objects)
+ - initial_centers.nc() == 1 (i.e. must be a column vector)
+ - samples.nc() == 1 (i.e. must be a column vector)
+ - initial_centers.nr() == number_of_centers()
+ ensures
+ - performs k-means clustering of the given set of samples. The initial center points
+ are taken from the initial_centers argument.
+ - loops over the data and continues to refine the clustering until either less than
+ get_min_change() fraction of the data points change clusters or we have done max_iter
+ iterations over the data.
+ - After this function finishes you can call the operator() function below
+ to determine which centroid a given sample is closest to.
+ !*/
+
+ unsigned long operator() (
+ const sample_type& sample
+ ) const;
+ /*!
+ ensures
+ - returns a number idx such that:
+ - idx < number_of_centers()
+ - get_kcentroid(idx) == the centroid that is closest to the given
+ sample.
+ !*/
+
+ void set_min_change (
+ scalar_type min_change
+ );
+ /*!
+ requires
+ - 0 <= min_change < 1
+ ensures
+ - #get_min_change() == min_change
+ !*/
+
+ const scalar_type get_min_change (
+ ) const;
+ /*!
+ ensures
+ - returns the minimum fraction of data points that need to change
+ centers in an iteration of kmeans for the algorithm to keep going.
+ !*/
+
+ void swap (
+ kkmeans& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ void swap(
+ kkmeans<kernel_type>& a,
+ kkmeans<kernel_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void serialize (
+ const kkmeans<kernel_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for kkmeans objects
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void deserialize (
+ kkmeans<kernel_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for kkmeans objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type1,
+ typename vector_type2,
+ typename kernel_type
+ >
+ void pick_initial_centers(
+ long num_centers,
+ vector_type1& centers,
+ const vector_type2& samples,
+ const kernel_type& k,
+ double percentile = 0.01
+ );
+ /*!
+ requires
+ - num_centers > 1
+ - 0 <= percentile < 1
+ - samples.size() > 1
+ - vector_type1 == something with an interface compatible with std::vector
+ - vector_type2 == something with an interface compatible with std::vector
+ - k(samples[0],samples[0]) must be a valid expression that returns a double
+ - both centers and samples must be able to contain kernel_type::sample_type
+ objects
+ ensures
+ - finds num_centers candidate cluster centers in the data in the samples
+ vector. Assumes that k is the kernel that will be used during clustering
+ to define the space in which clustering occurs.
+ - The centers are found by looking for points that are far away from other
+ candidate centers. However, if the data is noisy you probably want to
+ ignore the farthest way points since they will be outliers. To do this
+ set percentile to the fraction of outliers you expect the data to contain.
+ - #centers.size() == num_centers
+ - #centers == a vector containing the candidate centers found
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type1,
+ typename vector_type2
+ >
+ void pick_initial_centers(
+ long num_centers,
+ vector_type1& centers,
+ const vector_type2& samples,
+ double percentile = 0.01
+ );
+ /*!
+ requires
+ - num_centers > 1
+ - 0 <= percentile < 1
+ - samples.size() > 1
+ - vector_type1 == something with an interface compatible with std::vector
+ - vector_type2 == something with an interface compatible with std::vector
+ - Both centers and samples must be able to contain dlib::matrix based row or
+ column vectors.
+ ensures
+ - performs: pick_initial_centers(num_centers, centers, samples, linear_kernel<sample_type>(), percentile)
+ (i.e. this function is simply an overload that uses the linear kernel.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename sample_type,
+ typename alloc
+ >
+ void find_clusters_using_kmeans (
+ const array_type& samples,
+ std::vector<sample_type, alloc>& centers,
+ unsigned long max_iter = 1000
+ );
+ /*!
+ requires
+ - samples.size() > 0
+ - samples == a bunch of row or column vectors and they all must be of the
+ same length.
+ - centers.size() > 0
+ - array_type == something with an interface compatible with std::vector
+ and it must contain row or column vectors capable of being stored in
+ sample_type objects.
+ - sample_type == a dlib::matrix capable of representing vectors
+ ensures
+ - performs regular old linear kmeans clustering on the samples. The clustering
+ begins with the initial set of centers given as an argument to this function.
+ When it finishes #centers will contain the resulting centers.
+ - no more than max_iter iterations will be performed before this function
+ terminates.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename sample_type,
+ typename alloc
+ >
+ void find_clusters_using_angular_kmeans (
+ const array_type& samples,
+ std::vector<sample_type, alloc>& centers,
+ unsigned long max_iter = 1000
+ );
+ /*!
+ requires
+ - samples.size() > 0
+ - samples == a bunch of row or column vectors and they all must be of the
+ same length.
+ - centers.size() > 0
+ - array_type == something with an interface compatible with std::vector
+ and it must contain row or column vectors capable of being stored in
+ sample_type objects.
+ - sample_type == a dlib::matrix capable of representing vectors
+ ensures
+ - performs linear kmeans clustering on the samples, except instead of using
+ Euclidean distance to compare samples to the centers it uses the angle
+ between a sample and a center (with respect to the origin). So we try to
+ cluster samples together if they have small angles with respect to each
+ other. The clustering begins with the initial set of centers given as an
+ argument to this function. When it finishes #centers will contain the
+ resulting centers.
+ - for all valid i:
+ - length(#centers[i]) == 1
+ (i.e. the output centers are scaled to be unit vectors since their
+ magnitude is irrelevant. Moreover, this makes it so you can use
+ functions like nearest_center() with #centers to find the cluster
+ assignments.)
+ - No more than max_iter iterations will be performed before this function
+ terminates.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename array_type,
+ typename EXP
+ >
+ unsigned long nearest_center (
+ const array_type& centers,
+ const matrix_exp<EXP>& sample
+ );
+ /*!
+ requires
+ - centers.size() > 0
+ - sample.size() > 0
+ - is_vector(sample) == true
+ - centers must be an array of vectors such that the following expression is
+ valid: length_squared(sample - centers[0]). (e.g. centers could be a
+ std::vector of matrix objects holding column vectors).
+ ensures
+ - returns the index that identifies the element of centers that is nearest to
+ sample. That is, returns a number IDX such that centers[IDX] is the element
+ of centers that minimizes length(centers[IDX]-sample).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KKMEANs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/krls.h b/ml/dlib/dlib/svm/krls.h
new file mode 100644
index 000000000..6c72e45e8
--- /dev/null
+++ b/ml/dlib/dlib/svm/krls.h
@@ -0,0 +1,358 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KRLs_
+#define DLIB_KRLs_
+
+#include <vector>
+
+#include "krls_abstract.h"
+#include "../matrix.h"
+#include "function.h"
+#include "../std_allocator.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ class krls
+ {
+ /*!
+ This is an implementation of the kernel recursive least squares algorithm described in the paper:
+ The Kernel Recursive Least Squares Algorithm by Yaakov Engel.
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit krls (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000
+ ) :
+ kernel(kernel_),
+ my_tolerance(tolerance_),
+ my_max_dictionary_size(max_dictionary_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(tolerance_ >= 0,
+ "\tkrls::krls()"
+ << "\n\t You have to give a positive tolerance"
+ << "\n\t this: " << this
+ << "\n\t tolerance: " << tolerance_
+ );
+
+ clear_dictionary();
+ }
+
+ scalar_type tolerance() const
+ {
+ return my_tolerance;
+ }
+
+ unsigned long max_dictionary_size() const
+ {
+ return my_max_dictionary_size;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ void clear_dictionary ()
+ {
+ dictionary.clear();
+ alpha.clear();
+
+ K_inv.set_size(0,0);
+ K.set_size(0,0);
+ P.set_size(0,0);
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ scalar_type temp = 0;
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ temp += alpha[i]*kern(dictionary[i], x);
+
+ return temp;
+ }
+
+ void train (
+ const sample_type& x,
+ scalar_type y
+ )
+ {
+ const scalar_type kx = kern(x,x);
+ if (alpha.size() == 0)
+ {
+ // just ignore this sample if it is the zero vector (or really close to being zero)
+ if (std::abs(kx) > std::numeric_limits<scalar_type>::epsilon())
+ {
+ // set initial state since this is the first training example we have seen
+
+ K_inv.set_size(1,1);
+ K_inv(0,0) = 1/kx;
+ K.set_size(1,1);
+ K(0,0) = kx;
+
+ alpha.push_back(y/kx);
+ dictionary.push_back(x);
+ P.set_size(1,1);
+ P(0,0) = 1;
+ }
+ }
+ else
+ {
+ // fill in k
+ k.set_size(alpha.size());
+ for (long r = 0; r < k.nr(); ++r)
+ k(r) = kern(x,dictionary[r]);
+
+ // compute the error we would have if we approximated the new x sample
+ // with the dictionary. That is, do the ALD test from the KRLS paper.
+ a = K_inv*k;
+ scalar_type delta = kx - trans(k)*a;
+
+ // if this new vector isn't approximately linearly dependent on the vectors
+ // in our dictionary.
+ if (delta > my_tolerance)
+ {
+ if (dictionary.size() >= my_max_dictionary_size)
+ {
+ // We need to remove one of the old members of the dictionary before
+ // we proceed with adding a new one. So remove the oldest one.
+ remove_dictionary_vector(0);
+
+ // recompute these guys since they were computed with the old
+ // kernel matrix
+ k = remove_row(k,0);
+ a = K_inv*k;
+ delta = kx - trans(k)*a;
+ }
+
+ // add x to the dictionary
+ dictionary.push_back(x);
+
+ // update K_inv by computing the new one in the temp matrix (equation 3.14)
+ matrix<scalar_type,0,0,mem_manager_type> temp(K_inv.nr()+1, K_inv.nc()+1);
+ // update the middle part of the matrix
+ set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta;
+ // update the right column of the matrix
+ set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta;
+ // update the bottom row of the matrix
+ set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta);
+ // update the bottom right corner of the matrix
+ temp(K_inv.nr(), K_inv.nc()) = 1/delta;
+ // put temp into K_inv
+ temp.swap(K_inv);
+
+
+
+
+ // update K (the kernel matrix)
+ temp.set_size(K.nr()+1, K.nc()+1);
+ set_subm(temp, get_rect(K)) = K;
+ // update the right column of the matrix
+ set_subm(temp, 0, K.nr(),K.nr(),1) = k;
+ // update the bottom row of the matrix
+ set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k);
+ temp(K.nr(), K.nc()) = kx;
+ // put temp into K
+ temp.swap(K);
+
+
+
+
+ // Now update the P matrix (equation 3.15)
+ temp.set_size(P.nr()+1, P.nc()+1);
+ set_subm(temp, get_rect(P)) = P;
+ // initialize the new sides of P
+ set_rowm(temp,P.nr()) = 0;
+ set_colm(temp,P.nr()) = 0;
+ temp(P.nr(), P.nc()) = 1;
+ temp.swap(P);
+
+ // now update the alpha vector (equation 3.16)
+ const scalar_type k_a = (y-trans(k)*mat(alpha))/delta;
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ alpha[i] -= a(i)*k_a;
+ }
+ alpha.push_back(k_a);
+ }
+ else
+ {
+ q = P*a/(1+trans(a)*P*a);
+
+ // update P (equation 3.12)
+ temp_matrix = trans(a)*P;
+ P -= q*temp_matrix;
+
+ // update the alpha vector (equation 3.13)
+ const scalar_type k_a = y-trans(k)*mat(alpha);
+ for (unsigned long i = 0; i < alpha.size(); ++i)
+ {
+ alpha[i] += (K_inv*q*k_a)(i);
+ }
+ }
+ }
+ }
+
+ void swap (
+ krls& item
+ )
+ {
+ exchange(kernel, item.kernel);
+ dictionary.swap(item.dictionary);
+ alpha.swap(item.alpha);
+ K_inv.swap(item.K_inv);
+ K.swap(item.K);
+ P.swap(item.P);
+ exchange(my_tolerance, item.my_tolerance);
+ q.swap(item.q);
+ a.swap(item.a);
+ k.swap(item.k);
+ temp_matrix.swap(item.temp_matrix);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ }
+
+ unsigned long dictionary_size (
+ ) const { return dictionary.size(); }
+
+ decision_function<kernel_type> get_decision_function (
+ ) const
+ {
+ return decision_function<kernel_type>(
+ mat(alpha),
+ -sum(mat(alpha))*tau,
+ kernel,
+ mat(dictionary)
+ );
+ }
+
+ friend void serialize(const krls& item, std::ostream& out)
+ {
+ serialize(item.kernel, out);
+ serialize(item.dictionary, out);
+ serialize(item.alpha, out);
+ serialize(item.K_inv, out);
+ serialize(item.K, out);
+ serialize(item.P, out);
+ serialize(item.my_tolerance, out);
+ serialize(item.my_max_dictionary_size, out);
+ }
+
+ friend void deserialize(krls& item, std::istream& in)
+ {
+ deserialize(item.kernel, in);
+ deserialize(item.dictionary, in);
+ deserialize(item.alpha, in);
+ deserialize(item.K_inv, in);
+ deserialize(item.K, in);
+ deserialize(item.P, in);
+ deserialize(item.my_tolerance, in);
+ deserialize(item.my_max_dictionary_size, in);
+ }
+
+ private:
+
+ inline scalar_type kern (const sample_type& m1, const sample_type& m2) const
+ {
+ return kernel(m1,m2) + tau;
+ }
+
+ void remove_dictionary_vector (
+ long i
+ )
+ /*!
+ requires
+ - 0 <= i < dictionary.size()
+ ensures
+ - #dictionary.size() == dictionary.size() - 1
+ - #alpha.size() == alpha.size() - 1
+ - updates the K_inv matrix so that it is still a proper inverse of the
+ kernel matrix
+ - also removes the necessary row and column from the K matrix
+ - uses the this->a variable so after this function runs that variable
+ will contain a different value.
+ !*/
+ {
+ // remove the dictionary vector
+ dictionary.erase(dictionary.begin()+i);
+
+ // remove the i'th vector from the inverse kernel matrix. This formula is basically
+ // just the reverse of the way K_inv is updated by equation 3.14 during normal training.
+ K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i);
+
+ // now compute the updated alpha values to take account that we just removed one of
+ // our dictionary vectors
+ a = (K_inv*remove_row(K,i)*mat(alpha));
+
+ // now copy over the new alpha values
+ alpha.resize(alpha.size()-1);
+ for (unsigned long k = 0; k < alpha.size(); ++k)
+ {
+ alpha[k] = a(k);
+ }
+
+ // update the P matrix as well
+ P = removerc(P,i,i);
+
+ // update the K matrix as well
+ K = removerc(K,i,i);
+ }
+
+
+ kernel_type kernel;
+
+ typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
+ typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
+ typedef std::vector<sample_type,alloc_sample_type> dictionary_vector_type;
+ typedef std::vector<scalar_type,alloc_scalar_type> alpha_vector_type;
+
+ dictionary_vector_type dictionary;
+ alpha_vector_type alpha;
+
+ matrix<scalar_type,0,0,mem_manager_type> K_inv;
+ matrix<scalar_type,0,0,mem_manager_type> K;
+ matrix<scalar_type,0,0,mem_manager_type> P;
+
+ scalar_type my_tolerance;
+ unsigned long my_max_dictionary_size;
+
+
+ // temp variables here just so we don't have to reconstruct them over and over. Thus,
+ // they aren't really part of the state of this object.
+ matrix<scalar_type,0,1,mem_manager_type> q;
+ matrix<scalar_type,0,1,mem_manager_type> a;
+ matrix<scalar_type,0,1,mem_manager_type> k;
+ matrix<scalar_type,1,0,mem_manager_type> temp_matrix;
+
+ const static scalar_type tau;
+
+ };
+
+ template <typename kernel_type>
+ const typename kernel_type::scalar_type krls<kernel_type>::tau = static_cast<typename kernel_type::scalar_type>(0.01);
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void swap(krls<kernel_type>& a, krls<kernel_type>& b)
+ { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KRLs_
+
diff --git a/ml/dlib/dlib/svm/krls_abstract.h b/ml/dlib/dlib/svm/krls_abstract.h
new file mode 100644
index 000000000..7ea2d9872
--- /dev/null
+++ b/ml/dlib/dlib/svm/krls_abstract.h
@@ -0,0 +1,202 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KRLs_ABSTRACT_
+#ifdef DLIB_KRLs_ABSTRACT_
+
+#include <cmath>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename kernel_type
+ >
+ class krls
+ {
+ /*!
+ REQUIREMENTS ON kernel_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - dictionary_size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the kernel recursive least squares algorithm
+ described in the paper:
+ The Kernel Recursive Least Squares Algorithm by Yaakov Engel.
+
+ The long and short of this algorithm is that it is an online kernel based
+ regression algorithm. You give it samples (x,y) and it learns the function
+ f(x) == y. For a detailed description of the algorithm read the above paper.
+
+ Also note that the algorithm internally keeps a set of "dictionary vectors"
+ that are used to represent the regression function. You can force the
+ algorithm to use no more than a set number of vectors by setting
+ the 3rd constructor argument to whatever you want. However, note that
+ doing this causes the algorithm to bias it's results towards more
+ recent training examples.
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+
+ explicit krls (
+ const kernel_type& kernel_,
+ scalar_type tolerance_ = 0.001,
+ unsigned long max_dictionary_size_ = 1000000
+ );
+ /*!
+ requires
+ - tolerance >= 0
+ ensures
+ - this object is properly initialized
+ - #tolerance() == tolerance_
+ - #get_decision_function().kernel_function == kernel_
+ (i.e. this object will use the given kernel function)
+ - #get_kernel() == kernel_
+ - #max_dictionary_size() == max_dictionary_size_
+ !*/
+
+ scalar_type tolerance(
+ ) const;
+ /*!
+ ensures
+ - returns the tolerance to use for the approximately linearly dependent
+ test in the KRLS algorithm. This is a number which governs how
+ accurately this object will approximate the decision function it is
+ learning. Smaller values generally result in a more accurate
+ estimate while also resulting in a bigger set of dictionary vectors in
+ the learned decision function. Bigger tolerances values result in a
+ less accurate decision function but also in less dictionary vectors.
+ - The exact meaning of the tolerance parameter is the following:
+ Imagine that we have an empirical_kernel_map that contains all
+ the current dictionary vectors. Then the tolerance is the minimum
+ projection error (as given by empirical_kernel_map::project()) required
+ to cause us to include a new vector in the dictionary. So each time
+ you call train() the krls object basically just computes the projection
+ error for that new sample and if it is larger than the tolerance
+ then that new sample becomes part of the dictionary.
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the kernel used by this object
+ !*/
+
+ unsigned long max_dictionary_size(
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of dictionary vectors this object
+ will use at a time. That is, dictionary_size() will never be
+ greater than max_dictionary_size().
+ !*/
+
+ void clear_dictionary (
+ );
+ /*!
+ ensures
+ - clears out all learned data
+ (e.g. #get_decision_function().basis_vectors.size() == 0)
+ !*/
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - returns the current y estimate for the given x
+ !*/
+
+ void train (
+ const sample_type& x,
+ scalar_type y
+ );
+ /*!
+ ensures
+ - trains this object that the given x should be mapped to the given y
+ - if (dictionary_size() == max_dictionary_size() and training
+ would add another dictionary vector to this object) then
+ - discards the oldest dictionary vector so that we can still
+ add a new one and remain below the max number of dictionary
+ vectors.
+ !*/
+
+ void swap (
+ krls& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ unsigned long dictionary_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of vectors in the dictionary. That is,
+ returns a number equal to get_decision_function().basis_vectors.size()
+ !*/
+
+ decision_function<kernel_type> get_decision_function (
+ ) const;
+ /*!
+ ensures
+ - returns a decision function F that represents the function learned
+ by this object so far. I.e. it is the case that:
+ - for all x: F(x) == (*this)(x)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ void swap(
+ krls<kernel_type>& a,
+ krls<kernel_type>& b
+ )
+ { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void serialize (
+ const krls<kernel_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for krls objects
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void deserialize (
+ krls<kernel_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for krls objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_KRLs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/krr_trainer.h b/ml/dlib/dlib/svm/krr_trainer.h
new file mode 100644
index 000000000..a43431169
--- /dev/null
+++ b/ml/dlib/dlib/svm/krr_trainer.h
@@ -0,0 +1,368 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KRR_TRAInER_Hh_
+#define DLIB_KRR_TRAInER_Hh_
+
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "empirical_kernel_map.h"
+#include "linearly_independent_subset_finder.h"
+#include "../statistics.h"
+#include "rr_trainer.h"
+#include "krr_trainer_abstract.h"
+#include <vector>
+#include <iostream>
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class krr_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ krr_trainer (
+ ) :
+ verbose(false),
+ max_basis_size(400),
+ ekm_stale(true)
+ {
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ trainer.be_verbose();
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ trainer.be_quiet();
+ }
+
+ void use_regression_loss_for_loo_cv (
+ )
+ {
+ trainer.use_regression_loss_for_loo_cv();
+ }
+
+ void use_classification_loss_for_loo_cv (
+ )
+ {
+ trainer.use_classification_loss_for_loo_cv();
+ }
+
+ bool will_use_regression_loss_for_loo_cv (
+ ) const
+ {
+ return trainer.will_use_regression_loss_for_loo_cv();
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kern;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kern = k;
+ }
+
+ template <typename T>
+ void set_basis (
+ const T& basis_samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)),
+ "\tvoid krr_trainer::set_basis(basis_samples)"
+ << "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
+ << "\n\t basis_samples.size(): " << basis_samples.size()
+ << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples))
+ << "\n\t this: " << this
+ );
+
+ basis = mat(basis_samples);
+ ekm_stale = true;
+ }
+
+ bool basis_loaded (
+ ) const
+ {
+ return (basis.size() != 0);
+ }
+
+ void clear_basis (
+ )
+ {
+ basis.set_size(0);
+ ekm.clear();
+ ekm_stale = true;
+ }
+
+ unsigned long get_max_basis_size (
+ ) const
+ {
+ return max_basis_size;
+ }
+
+ void set_max_basis_size (
+ unsigned long max_basis_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_basis_size_ > 0,
+ "\t void krr_trainer::set_max_basis_size()"
+ << "\n\t max_basis_size_ must be greater than 0"
+ << "\n\t max_basis_size_: " << max_basis_size_
+ << "\n\t this: " << this
+ );
+
+ max_basis_size = max_basis_size_;
+ }
+
+ void set_lambda (
+ scalar_type lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(lambda_ >= 0,
+ "\t void krr_trainer::set_lambda()"
+ << "\n\t lambda must be greater than or equal to 0"
+ << "\n\t lambda_: " << lambda_
+ << "\n\t this: " << this
+ );
+
+ trainer.set_lambda(lambda_);
+ }
+
+ const scalar_type get_lambda (
+ ) const
+ {
+ return trainer.get_lambda();
+ }
+
+ template <typename EXP>
+ void set_search_lambdas (
+ const matrix_exp<EXP>& lambdas
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(lambdas) && lambdas.size() > 0 && min(lambdas) > 0,
+ "\t void krr_trainer::set_search_lambdas()"
+ << "\n\t lambdas must be a non-empty vector of values"
+ << "\n\t is_vector(lambdas): " << is_vector(lambdas)
+ << "\n\t lambdas.size(): " << lambdas.size()
+ << "\n\t min(lambdas): " << min(lambdas)
+ << "\n\t this: " << this
+ );
+
+ trainer.set_search_lambdas(lambdas);
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_search_lambdas (
+ ) const
+ {
+ return trainer.get_search_lambdas();
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ std::vector<scalar_type> temp;
+ scalar_type temp2;
+ return do_train(mat(x), mat(y), false, temp, temp2);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values
+ ) const
+ {
+ scalar_type temp;
+ return do_train(mat(x), mat(y), true, loo_values, temp);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& lambda_used
+ ) const
+ {
+ return do_train(mat(x), mat(y), true, loo_values, lambda_used);
+ }
+
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const bool output_loo_values,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& the_lambda
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y),
+ "\t decision_function krr_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_vector(x): " << is_vector(x)
+ << "\n\t is_vector(y): " << is_vector(y)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y),
+ "\t decision_function krr_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ );
+ }
+#endif
+
+ // The first thing we do is make sure we have an appropriate ekm ready for use below.
+ if (basis_loaded())
+ {
+ if (ekm_stale)
+ {
+ ekm.load(kern, basis);
+ ekm_stale = false;
+ }
+ }
+ else
+ {
+ linearly_independent_subset_finder<kernel_type> lisf(kern, max_basis_size);
+ fill_lisf(lisf, x);
+ ekm.load(lisf);
+ }
+
+ if (verbose)
+ {
+ std::cout << "\nNumber of basis vectors used: " << ekm.out_vector_size() << std::endl;
+ }
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> column_matrix_type;
+
+ running_stats<scalar_type> rs;
+
+ // Now we project all the x samples into kernel space using our EKM
+ matrix<column_matrix_type,0,1,mem_manager_type > proj_x;
+ proj_x.set_size(x.size());
+ for (long i = 0; i < proj_x.size(); ++i)
+ {
+ scalar_type err;
+ // Note that we also append a 1 to the end of the vectors because this is
+ // a convenient way of dealing with the bias term later on.
+ if (verbose == false)
+ {
+ proj_x(i) = ekm.project(x(i));
+ }
+ else
+ {
+ proj_x(i) = ekm.project(x(i),err);
+ rs.add(err);
+ }
+ }
+
+ if (verbose)
+ {
+ std::cout << "Mean EKM projection error: " << rs.mean() << std::endl;
+ std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
+ }
+
+
+ decision_function<linear_kernel<matrix<scalar_type,0,0,mem_manager_type> > > lin_df;
+
+ if (output_loo_values)
+ lin_df = trainer.train(proj_x,y, loo_values, the_lambda);
+ else
+ lin_df = trainer.train(proj_x,y);
+
+ // convert the linear decision function into a kernelized one.
+ decision_function<kernel_type> df;
+ df = ekm.convert_to_decision_function(lin_df.basis_vectors(0));
+ df.b = lin_df.b;
+
+ // If we used an automatically derived basis then there isn't any point in
+ // keeping the ekm around. So free its memory.
+ if (basis_loaded() == false)
+ {
+ ekm.clear();
+ }
+
+ return df;
+ }
+
+
+ /*!
+ CONVENTION
+ - if (ekm_stale) then
+ - kern or basis have changed since the last time
+ they were loaded into the ekm
+
+ - get_lambda() == trainer.get_lambda()
+ - get_kernel() == kern
+ - get_max_basis_size() == max_basis_size
+ - will_use_regression_loss_for_loo_cv() == trainer.will_use_regression_loss_for_loo_cv()
+ - get_search_lambdas() == trainer.get_search_lambdas()
+
+ - basis_loaded() == (basis.size() != 0)
+ !*/
+
+ rr_trainer<linear_kernel<matrix<scalar_type,0,0,mem_manager_type> > > trainer;
+
+ bool verbose;
+
+
+ kernel_type kern;
+ unsigned long max_basis_size;
+
+ matrix<sample_type,0,1,mem_manager_type> basis;
+ mutable empirical_kernel_map<kernel_type> ekm;
+ mutable bool ekm_stale;
+
+ };
+
+}
+
+#endif // DLIB_KRR_TRAInER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/krr_trainer_abstract.h b/ml/dlib/dlib/svm/krr_trainer_abstract.h
new file mode 100644
index 000000000..399802f6b
--- /dev/null
+++ b/ml/dlib/dlib/svm/krr_trainer_abstract.h
@@ -0,0 +1,322 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_KRR_TRAInER_ABSTRACT_Hh_
+#ifdef DLIB_KRR_TRAInER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "empirical_kernel_map_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class krr_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - get_lambda() == 0
+ - basis_loaded() == false
+ - get_max_basis_size() == 400
+ - will_use_regression_loss_for_loo_cv() == true
+ - get_search_lambdas() == logspace(-9, 2, 50)
+ - this object will not be verbose unless be_verbose() is called
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for performing kernel ridge regression
+ (This basic algorithm is also known my many other names, e.g. regularized
+ least squares or least squares SVM).
+
+ The exact definition of what this algorithm does is this:
+ Find w and b that minimizes the following (x_i are input samples and y_i are target values):
+ lambda*dot(w,w) + sum_over_i( (f(x_i) - y_i)^2 )
+ where f(x) == dot(x,w) - b
+
+ Except the dot products are replaced by kernel functions. So this
+ algorithm is just regular old least squares regression but with the
+ addition of a regularization term which encourages small w and the
+ application of the kernel trick.
+
+
+ It is implemented using the empirical_kernel_map and thus allows you
+ to run the algorithm on large datasets and obtain sparse outputs. It is also
+ capable of estimating the lambda parameter using leave-one-out cross-validation.
+
+
+ The leave-one-out cross-validation implementation is based on the techniques
+ discussed in this paper:
+ Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ krr_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ template <typename T>
+ void set_basis (
+ const T& basis_samples
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix type or something convertible to a matrix via mat()
+ (e.g. a std::vector)
+ - is_vector(basis_samples) == true
+ - basis_samples.size() > 0
+ - get_kernel() must be capable of operating on the elements of basis_samples. That is,
+ expressions such as get_kernel()(basis_samples(0), basis_samples(0)) should make sense.
+ ensures
+ - #basis_loaded() == true
+ - training will be carried out in the span of the given basis_samples
+ !*/
+
+ bool basis_loaded (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with user supplied basis vectors and false otherwise.
+ !*/
+
+ void clear_basis (
+ );
+ /*!
+ ensures
+ - #basis_loaded() == false
+ !*/
+
+ unsigned long get_max_basis_size (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of basis vectors this object is allowed
+ to use. This parameter only matters when the user has not supplied
+ a basis via set_basis().
+ !*/
+
+ void set_max_basis_size (
+ unsigned long max_basis_size
+ );
+ /*!
+ requires
+ - max_basis_size > 0
+ ensures
+ - #get_max_basis_size() == max_basis_size
+ !*/
+
+ void set_lambda (
+ scalar_type lambda
+ );
+ /*!
+ requires
+ - lambda >= 0
+ ensures
+ - #get_lambda() == lambda
+ !*/
+
+ const scalar_type get_lambda (
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization ability of the resulting function. Smaller values
+ encourage exact fitting while larger values of lambda may encourage
+ better generalization.
+
+ Note that a lambda of 0 has a special meaning. It indicates to this
+ object that it should automatically determine an appropriate lambda
+ value. This is done using leave-one-out cross-validation.
+ !*/
+
+ void use_regression_loss_for_loo_cv (
+ );
+ /*!
+ ensures
+ - #will_use_regression_loss_for_loo_cv() == true
+ !*/
+
+ void use_classification_loss_for_loo_cv (
+ );
+ /*!
+ ensures
+ - #will_use_regression_loss_for_loo_cv() == false
+ !*/
+
+ bool will_use_regression_loss_for_loo_cv (
+ ) const;
+ /*!
+ ensures
+ - returns true if the automatic lambda estimation will attempt to estimate a lambda
+ appropriate for a regression task. Otherwise it will try and find one which
+ minimizes the number of classification errors.
+ !*/
+
+ template <typename EXP>
+ void set_search_lambdas (
+ const matrix_exp<EXP>& lambdas
+ );
+ /*!
+ requires
+ - is_vector(lambdas) == true
+ - lambdas.size() > 0
+ - min(lambdas) > 0
+ - lambdas must contain floating point numbers
+ ensures
+ - #get_search_lambdas() == lambdas
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_search_lambdas (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - is_vector(M) == true
+ - M == a list of all the lambda values which will be tried when performing
+ LOO cross-validation for determining the best lambda.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - is_learning_problem(x,y) == true
+ - if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) then
+ - is_binary_classification_problem(x,y) == true
+ (i.e. if you want this algorithm to estimate a lambda appropriate for
+ classification functions then you had better give a valid classification
+ problem)
+ ensures
+ - performs kernel ridge regression given the training samples in x and target values in y.
+ - returns a decision_function F with the following properties:
+ - F(new_x) == predicted y value
+
+ - if (basis_loaded()) then
+ - training will be carried out in the span of the user supplied basis vectors
+ - else
+ - this object will attempt to automatically select an appropriate basis
+
+ - if (get_lambda() == 0) then
+ - This object will perform internal leave-one-out cross-validation to determine an
+ appropriate lambda automatically. It will compute the LOO error for each lambda
+ in get_search_lambdas() and select the best one.
+ - if (will_use_regression_loss_for_loo_cv()) then
+ - the lambda selected will be the one that minimizes the mean squared error.
+ - else
+ - the lambda selected will be the one that minimizes the number classification
+ mistakes. We say a point is classified correctly if the output of the
+ decision_function has the same sign as its label.
+ - #get_lambda() == 0
+ (i.e. we don't change the get_lambda() value. If you want to know what the
+ automatically selected lambda value was then call the version of train()
+ defined below)
+ - else
+ - The user supplied value of get_lambda() will be used to perform the kernel
+ ridge regression.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values
+ ) const;
+ /*!
+ requires
+ - all the requirements for train(x,y) must be satisfied
+ ensures
+ - returns train(x,y)
+ (i.e. executes train(x,y) and returns its result)
+ - #loo_values.size() == y.size()
+ - for all valid i:
+ - #loo_values[i] == leave-one-out prediction for the value of y(i) based
+ on all the training samples other than (x(i),y(i)).
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& lambda_used
+ ) const;
+ /*!
+ requires
+ - all the requirements for train(x,y) must be satisfied
+ ensures
+ - returns train(x,y)
+ (i.e. executes train(x,y) and returns its result)
+ - #loo_values.size() == y.size()
+ - for all valid i:
+ - #loo_values[i] == leave-one-out prediction for the value of y(i) based
+ on all the training samples other than (x(i),y(i)).
+ - #lambda_used == the value of lambda used to generate the
+ decision_function. Note that this lambda value is always
+ equal to get_lambda() if get_lambda() isn't 0.
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_KRR_TRAInER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/linearly_independent_subset_finder.h b/ml/dlib/dlib/svm/linearly_independent_subset_finder.h
new file mode 100644
index 000000000..3bac0df2c
--- /dev/null
+++ b/ml/dlib/dlib/svm/linearly_independent_subset_finder.h
@@ -0,0 +1,540 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_LISfh_
+#define DLIB_LISfh_
+
+#include <vector>
+
+#include "linearly_independent_subset_finder_abstract.h"
+#include "../matrix.h"
+#include "function.h"
+#include "../std_allocator.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../is_kind.h"
+#include "../string.h"
+#include "../rand.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ class linearly_independent_subset_finder
+ {
+ /*!
+ INITIAL VALUE
+ - min_strength == 0
+ - min_vect_idx == 0
+ - K_inv.size() == 0
+ - K.size() == 0
+ - dictionary.size() == 0
+
+ CONVENTION
+ - max_dictionary_size() == my_max_dictionary_size
+ - get_kernel() == kernel
+ - minimum_tolerance() == min_tolerance
+ - size() == dictionary.size()
+ - get_dictionary() == mat(dictionary)
+ - K.nr() == dictionary.size()
+ - K.nc() == dictionary.size()
+ - for all valid r,c:
+ - K(r,c) == kernel(dictionary[r], dictionary[c])
+ - K_inv == inv(K)
+
+ - if (dictionary.size() == my_max_dictionary_size) then
+ - for all valid 0 < i < dictionary.size():
+ - Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately
+ Linearly Dependent value) if you removed dictionary[i] from this object and then
+ tried to add it back in.
+ - min_strength == the minimum value from STRENGTHS
+ - min_vect_idx == the index of the element in STRENGTHS with the smallest value
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::sample_type type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ linearly_independent_subset_finder (
+ ) :
+ my_max_dictionary_size(100),
+ min_tolerance(0.001)
+ {
+ clear_dictionary();
+ }
+
+ linearly_independent_subset_finder (
+ const kernel_type& kernel_,
+ unsigned long max_dictionary_size_,
+ scalar_type min_tolerance_ = 0.001
+ ) :
+ kernel(kernel_),
+ my_max_dictionary_size(max_dictionary_size_),
+ min_tolerance(min_tolerance_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(min_tolerance_ > 0 && max_dictionary_size_ > 1,
+ "\tlinearly_independent_subset_finder()"
+ << "\n\tinvalid argument to constructor"
+ << "\n\tmin_tolerance_: " << min_tolerance_
+ << "\n\tmax_dictionary_size_: " << max_dictionary_size_
+ << "\n\tthis: " << this
+ );
+ clear_dictionary();
+ }
+
+ unsigned long max_dictionary_size() const
+ {
+ return my_max_dictionary_size;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ scalar_type minimum_tolerance(
+ ) const
+ {
+ return min_tolerance;
+ }
+
+ void set_minimum_tolerance (
+ scalar_type min_tol
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(min_tol > 0,
+ "\tlinearly_independent_subset_finder::set_minimum_tolerance()"
+ << "\n\tinvalid argument to this function"
+ << "\n\tmin_tol: " << min_tol
+ << "\n\tthis: " << this
+ );
+ min_tolerance = min_tol;
+ }
+
+ void clear_dictionary ()
+ {
+ dictionary.clear();
+ min_strength = 0;
+ min_vect_idx = 0;
+
+ K_inv.set_size(0,0);
+ K.set_size(0,0);
+ }
+
+ scalar_type projection_error (
+ const sample_type& x
+ ) const
+ {
+ const scalar_type kx = kernel(x,x);
+ if (dictionary.size() == 0)
+ {
+ return kx;
+ }
+ else
+ {
+ // fill in k
+ k.set_size(dictionary.size());
+ for (long r = 0; r < k.nr(); ++r)
+ k(r) = kernel(x,dictionary[r]);
+
+ // compute the error we would have if we approximated the new x sample
+ // with the dictionary. That is, do the ALD test from the KRLS paper.
+ a = K_inv*k;
+ scalar_type delta = kx - trans(k)*a;
+
+ return delta;
+ }
+ }
+
+ bool add (
+ const sample_type& x
+ )
+ {
+ const scalar_type kx = kernel(x,x);
+ if (dictionary.size() == 0)
+ {
+ // just ignore this sample if it is the zero vector (or really close to being zero)
+ if (std::abs(kx) > std::numeric_limits<scalar_type>::epsilon())
+ {
+ // set initial state since this is the first sample we have seen
+ K_inv.set_size(1,1);
+ K_inv(0,0) = 1/kx;
+
+ K.set_size(1,1);
+ K(0,0) = kx;
+
+ dictionary.push_back(x);
+ return true;
+ }
+ return false;
+ }
+ else
+ {
+ // fill in k
+ k.set_size(dictionary.size());
+ for (long r = 0; r < k.nr(); ++r)
+ k(r) = kernel(x,dictionary[r]);
+
+ // compute the error we would have if we approximated the new x sample
+ // with the dictionary. That is, do the ALD test from the KRLS paper.
+ a = K_inv*k;
+ scalar_type delta = kx - trans(k)*a;
+
+ // if this new vector is approximately linearly independent of the vectors
+ // in our dictionary.
+ if (delta > min_strength && delta > min_tolerance)
+ {
+ if (dictionary.size() == my_max_dictionary_size)
+ {
+ // if we have never computed the min_strength then we should compute it
+ if (min_strength == 0)
+ recompute_min_strength();
+
+ const long i = min_vect_idx;
+
+ // replace the min strength vector with x. Put the new vector onto the end of
+ // dictionary and remove the vector at position i.
+ dictionary.erase(dictionary.begin()+i);
+ dictionary.push_back(x);
+
+ // compute reduced K_inv.
+ // Remove the i'th vector from the inverse kernel matrix. This formula is basically
+ // just the reverse of the way K_inv is updated by equation 3.14 below.
+ temp = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i);
+
+ // recompute these guys since they were computed with the old
+ // kernel matrix
+ k2 = remove_row(k,i);
+ a2 = temp*k2;
+ delta = kx - trans(k2)*a2;
+
+ // now update temp with the new dictionary vector
+ // update the middle part of the matrix
+ set_subm(K_inv, get_rect(temp)) = temp + a2*trans(a2)/delta;
+ // update the right column of the matrix
+ set_subm(K_inv, 0, temp.nr(),temp.nr(),1) = -a2/delta;
+ // update the bottom row of the matrix
+ set_subm(K_inv, temp.nr(), 0, 1, temp.nr()) = trans(-a2/delta);
+ // update the bottom right corner of the matrix
+ K_inv(temp.nr(), temp.nc()) = 1/delta;
+
+ // now update the kernel matrix K
+ set_subm(K,get_rect(temp)) = removerc(K, i,i);
+ set_subm(K, 0, K.nr()-1,K.nr()-1,1) = k2;
+ // update the bottom row of the matrix
+ set_subm(K, K.nr()-1, 0, 1, K.nr()-1) = trans(k2);
+ K(K.nr()-1, K.nc()-1) = kx;
+
+ // now we have to recompute the min_strength in this case
+ recompute_min_strength();
+ }
+ else
+ {
+ // update K_inv by computing the new one in the temp matrix (equation 3.14 from Engel)
+ temp.set_size(K_inv.nr()+1, K_inv.nc()+1);
+ // update the middle part of the matrix
+ set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta;
+ // update the right column of the matrix
+ set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta;
+ // update the bottom row of the matrix
+ set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta);
+ // update the bottom right corner of the matrix
+ temp(K_inv.nr(), K_inv.nc()) = 1/delta;
+ // put temp into K_inv
+ temp.swap(K_inv);
+
+
+ // update K (the kernel matrix)
+ temp.set_size(K.nr()+1, K.nc()+1);
+ set_subm(temp, get_rect(K)) = K;
+ // update the right column of the matrix
+ set_subm(temp, 0, K.nr(),K.nr(),1) = k;
+ // update the bottom row of the matrix
+ set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k);
+ temp(K.nr(), K.nc()) = kx;
+ // put temp into K
+ temp.swap(K);
+
+
+ // add x to the dictionary
+ dictionary.push_back(x);
+
+ }
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+
+ void swap (
+ linearly_independent_subset_finder& item
+ )
+ {
+ exchange(kernel, item.kernel);
+ dictionary.swap(item.dictionary);
+ exchange(min_strength, item.min_strength);
+ exchange(min_vect_idx, item.min_vect_idx);
+ K_inv.swap(item.K_inv);
+ K.swap(item.K);
+ exchange(my_max_dictionary_size, item.my_max_dictionary_size);
+ exchange(min_tolerance, item.min_tolerance);
+
+ // non-state temp members
+ a.swap(item.a);
+ k.swap(item.k);
+ a2.swap(item.a2);
+ k2.swap(item.k2);
+ temp.swap(item.temp);
+ }
+
+ size_t size (
+ ) const { return dictionary.size(); }
+
+ const matrix<sample_type,0,1,mem_manager_type> get_dictionary (
+ ) const
+ {
+ return mat(dictionary);
+ }
+
+ friend void serialize(const linearly_independent_subset_finder& item, std::ostream& out)
+ {
+ serialize(item.kernel, out);
+ serialize(item.dictionary, out);
+ serialize(item.min_strength, out);
+ serialize(item.min_vect_idx, out);
+ serialize(item.K_inv, out);
+ serialize(item.K, out);
+ serialize(item.my_max_dictionary_size, out);
+ serialize(item.min_tolerance, out);
+ }
+
+ friend void deserialize(linearly_independent_subset_finder& item, std::istream& in)
+ {
+ deserialize(item.kernel, in);
+ deserialize(item.dictionary, in);
+ deserialize(item.min_strength, in);
+ deserialize(item.min_vect_idx, in);
+ deserialize(item.K_inv, in);
+ deserialize(item.K, in);
+ deserialize(item.my_max_dictionary_size, in);
+ deserialize(item.min_tolerance, in);
+ }
+
+ const sample_type& operator[] (
+ unsigned long index
+ ) const
+ {
+ return dictionary[index];
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_kernel_matrix (
+ ) const
+ {
+ return K;
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_inv_kernel_marix (
+ ) const
+ {
+ return K_inv;
+ }
+
+ private:
+
+ typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
+ typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
+ typedef std::vector<sample_type,alloc_sample_type> dictionary_vector_type;
+ typedef std::vector<scalar_type,alloc_scalar_type> scalar_vector_type;
+
+ void recompute_min_strength (
+ )
+ /*!
+ ensures
+ - recomputes the min_strength and min_vect_idx values
+ so that they are correct with respect to the CONVENTION
+ !*/
+ {
+ min_strength = std::numeric_limits<scalar_type>::max();
+
+ // here we loop over each dictionary vector and compute what its delta would be if
+ // we were to remove it from the dictionary and then try to add it back in.
+ for (unsigned long i = 0; i < dictionary.size(); ++i)
+ {
+ // compute a2 = K_inv*k but where dictionary vector i has been removed
+ a2 = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) *
+ (remove_row(colm(K,i),i));
+ scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a2;
+
+ if (delta < min_strength)
+ {
+ min_strength = delta;
+ min_vect_idx = i;
+ }
+ }
+ }
+
+
+ kernel_type kernel;
+ dictionary_vector_type dictionary;
+ scalar_type min_strength;
+ unsigned long min_vect_idx;
+
+ matrix<scalar_type,0,0,mem_manager_type> K_inv;
+ matrix<scalar_type,0,0,mem_manager_type> K;
+
+ unsigned long my_max_dictionary_size;
+ scalar_type min_tolerance;
+
+ // temp variables here just so we don't have to reconstruct them over and over. Thus,
+ // they aren't really part of the state of this object.
+ mutable matrix<scalar_type,0,1,mem_manager_type> a, a2;
+ mutable matrix<scalar_type,0,1,mem_manager_type> k, k2;
+ mutable matrix<scalar_type,0,0,mem_manager_type> temp;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void swap(linearly_independent_subset_finder<kernel_type>& a, linearly_independent_subset_finder<kernel_type>& b)
+ { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const matrix_op<op_array_to_mat<linearly_independent_subset_finder<T> > > mat (
+ const linearly_independent_subset_finder<T>& m
+ )
+ {
+ typedef op_array_to_mat<linearly_independent_subset_finder<T> > op;
+ return matrix_op<op>(op(m));
+ }
+
+// ----------------------------------------------------------------------------------------
+ namespace impl
+ {
+ template <
+ typename kernel_type,
+ typename vector_type,
+ typename rand_type
+ >
+ void fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples,
+ rand_type& rnd,
+ int sampling_size
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(samples) && sampling_size > 0,
+ "\t void fill_lisf()"
+ << "\n\t invalid arguments to this function"
+ << "\n\t is_vector(samples): " << is_vector(samples)
+ << "\n\t sampling_size: " << sampling_size
+ );
+
+ // no need to do anything if there aren't any samples
+ if (samples.size() == 0)
+ return;
+
+ typedef typename kernel_type::scalar_type scalar_type;
+
+ // Start out by guessing what a reasonable projection error tolerance is. We will use
+ // the biggest projection error we see in a small sample.
+ scalar_type tol = 0;
+ for (int i = 0; i < sampling_size; ++i)
+ {
+ const unsigned long idx = rnd.get_random_32bit_number()%samples.size();
+ const scalar_type temp = lisf.projection_error(samples(idx));
+ if (temp > tol)
+ tol = temp;
+ }
+
+ const scalar_type min_tol = lisf.minimum_tolerance();
+
+ // run many rounds of random sampling. In each round we drop the tolerance lower.
+ while (tol >= min_tol && lisf.size() < lisf.max_dictionary_size())
+ {
+ tol *= 0.5;
+ lisf.set_minimum_tolerance(std::max(tol, min_tol));
+ int add_failures = 0;
+
+ // Keep picking random samples and adding them into the lisf. Stop when we either
+ // fill it up or can't find any more samples with projection error larger than the
+ // current tolerance.
+ while (lisf.size() < lisf.max_dictionary_size() && add_failures < sampling_size)
+ {
+ if (lisf.add(samples(rnd.get_random_32bit_number()%samples.size())) == false)
+ {
+ ++add_failures;
+ }
+ }
+ }
+
+ // set this back to its original value
+ lisf.set_minimum_tolerance(min_tol);
+ }
+ }
+
+ template <
+ typename kernel_type,
+ typename vector_type
+ >
+ void fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples
+ )
+ {
+ dlib::rand rnd;
+ impl::fill_lisf(lisf, mat(samples),rnd, 2000);
+ }
+
+ template <
+ typename kernel_type,
+ typename vector_type,
+ typename rand_type
+ >
+ typename enable_if<is_rand<rand_type> >::type fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples,
+ rand_type& rnd,
+ const int sampling_size = 2000
+ )
+ {
+ impl::fill_lisf(lisf, mat(samples),rnd, sampling_size);
+ }
+
+ template <
+ typename kernel_type,
+ typename vector_type,
+ typename rand_type
+ >
+ typename disable_if<is_rand<rand_type> >::type fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples,
+ rand_type random_seed,
+ const int sampling_size = 2000
+ )
+ {
+ dlib::rand rnd;
+ rnd.set_seed(cast_to_string(random_seed));
+ impl::fill_lisf(lisf, mat(samples), rnd, sampling_size);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LISfh_
+
diff --git a/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h b/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h
new file mode 100644
index 000000000..3224f9a0a
--- /dev/null
+++ b/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h
@@ -0,0 +1,327 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_LISf_ABSTRACT_
+#ifdef DLIB_LISf_ABSTRACT_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename kernel_type
+ >
+ class linearly_independent_subset_finder
+ {
+ /*!
+ REQUIREMENTS ON kernel_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ INITIAL VALUE
+ - size() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of an online algorithm for recursively finding a
+ set (aka dictionary) of linearly independent vectors in a kernel induced
+ feature space. To use it you decide how large you would like the dictionary
+ to be and then you feed it sample points.
+
+ The implementation uses the Approximately Linearly Dependent metric described
+ in the paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel to
+ decide which points are more linearly independent than others. The metric is
+ simply the squared distance between a test point and the subspace spanned by
+ the set of dictionary vectors.
+
+ Each time you present this object with a new sample point (via this->add())
+ it calculates the projection distance and if it is sufficiently large then this
+ new point is included into the dictionary. Note that this object can be configured
+ to have a maximum size. Once the max dictionary size is reached each new point
+ kicks out a previous point. This is done by removing the dictionary vector that
+ has the smallest projection distance onto the others. That is, the "least linearly
+ independent" vector is removed to make room for the new one.
+ !*/
+
+ public:
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::sample_type type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ linearly_independent_subset_finder (
+ );
+ /*!
+ ensures
+ - #minimum_tolerance() == 0.001
+ - this object is properly initialized
+ - #get_kernel() == kernel_type() (i.e. whatever the default is for the supplied kernel)
+ - #max_dictionary_size() == 100
+ !*/
+
+ linearly_independent_subset_finder (
+ const kernel_type& kernel_,
+ unsigned long max_dictionary_size_,
+ scalar_type min_tolerance = 0.001
+ );
+ /*!
+ requires
+ - min_tolerance > 0
+ - max_dictionary_size > 1
+ ensures
+ - #minimum_tolerance() == min_tolerance
+ - this object is properly initialized
+ - #get_kernel() == kernel_
+ - #max_dictionary_size() == max_dictionary_size_
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the kernel used by this object
+ !*/
+
+ unsigned long max_dictionary_size(
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of dictionary vectors this object
+ will accumulate. That is, size() will never be
+ greater than max_dictionary_size().
+ !*/
+
+ scalar_type minimum_tolerance(
+ ) const;
+ /*!
+ ensures
+ - returns the minimum projection error necessary to include a sample point
+ into the dictionary.
+ !*/
+
+ void set_minimum_tolerance (
+ scalar_type min_tolerance
+ );
+ /*!
+ requires
+ - min_tolerance > 0
+ ensures
+ - #minimum_tolerance() == min_tolerance
+ !*/
+
+ void clear_dictionary (
+ );
+ /*!
+ ensures
+ - clears out all the data (e.g. #size() == 0)
+ !*/
+
+ bool add (
+ const sample_type& x
+ );
+ /*!
+ ensures
+ - if (size() < max_dictionary_size() then
+ - if (projection_error(x) > minimum_tolerance()) then
+ - adds x into the dictionary
+ - (*this)[#size()-1] == x
+ - #size() == size() + 1
+ - returns true
+ - else
+ - the dictionary is not changed
+ - returns false
+ - else
+ - #size() == size()
+ (i.e. the number of vectors in this object doesn't change)
+ - since the dictionary is full adding a new element means we have to
+ remove one of the current ones. So let proj_error[i] be equal to the
+ projection error obtained when projecting dictionary vector (*this)[i]
+ onto the other elements of the dictionary. Then let min_proj_error
+ be equal to the minimum value in proj_error. The dictionary element
+ with the minimum projection error is the "least linearly independent"
+ vector in the dictionary and is the one which will be removed to make
+ room for a new element.
+ - if (projection_error(x) > minimum_tolerance() && projection_error(x) > min_proj_error)
+ - the least linearly independent vector in this object is removed
+ - adds x into the dictionary
+ - (*this)[#size()-1] == x
+ - returns true
+ - else
+ - the dictionary is not changed
+ - returns false
+ !*/
+
+ scalar_type projection_error (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - returns the squared distance between x and the subspace spanned by
+ the set of dictionary vectors. (e.g. this is the same number that
+ gets returned by the empirical_kernel_map::project() function's
+ projection_error argument when the ekm is loaded with the dictionary
+ vectors.)
+ - Note that if the dictionary is empty then the return value is
+ equal to get_kernel()(x,x).
+ !*/
+
+ void swap (
+ linearly_independent_subset_finder& item
+ );
+ /*!
+ ensures
+ - swaps *this with item
+ !*/
+
+ size_t size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of vectors in the dictionary.
+ !*/
+
+ const sample_type& operator[] (
+ unsigned long index
+ ) const;
+ /*!
+ requires
+ - index < size()
+ ensures
+ - returns the index'th element in the set of linearly independent
+ vectors contained in this object.
+ !*/
+
+ const matrix<sample_type,0,1,mem_manager_type> get_dictionary (
+ ) const;
+ /*!
+ ensures
+ - returns a column vector that contains all the dictionary
+ vectors in this object.
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_kernel_matrix (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix K such that:
+ - K.nr() == K.nc() == size()
+ - K == kernel_matrix(get_kernel(), get_dictionary())
+ i.e. K == the kernel matrix for the dictionary vectors
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_inv_kernel_marix (
+ ) const;
+ /*!
+ ensures
+ - if (size() != 0)
+ - returns inv(get_kernel_matrix())
+ - else
+ - returns an empty matrix
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type
+ >
+ void swap(
+ linearly_independent_subset_finder<kernel_type>& a,
+ linearly_independent_subset_finder<kernel_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void serialize (
+ const linearly_independent_subset_finder<kernel_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for linearly_independent_subset_finder objects
+ !*/
+
+ template <
+ typename kernel_type
+ >
+ void deserialize (
+ linearly_independent_subset_finder<kernel_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for linearly_independent_subset_finder objects
+ !*/
+
+ template <
+ typename T
+ >
+ const matrix_exp mat (
+ const linearly_independent_subset_finder<T>& m
+ );
+ /*!
+ ensures
+ - converts m into a matrix
+ - returns a matrix R such that:
+ - is_col_vector(R) == true
+ - R.size() == m.size()
+ - for all valid r:
+ R(r) == m[r]
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename vector_type,
+ typename rand_type
+ >
+ void fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples,
+ rand_type& rnd,
+ int sampling_size = 2000
+ );
+ /*!
+ requires
+ - vector_type == a dlib::matrix or something convertible to one via
+ mat()
+ - is_vector(mat(samples)) == true
+ - rand_type == an implementation of rand/rand_kernel_abstract.h or a type
+ convertible to a string via cast_to_string()
+ - sampling_size > 0
+ ensures
+ - The purpose of this function is to fill lisf with points from samples. It does
+ this by randomly sampling elements of samples until no more can be added. The
+ precise stopping condition is when sampling_size additions to lisf have failed
+ or the max dictionary size has been reached.
+ - This function employs a random number generator. If rand_type is a random
+ number generator then it uses the instance given. Otherwise it uses cast_to_string(rnd)
+ to seed a new random number generator.
+ !*/
+
+ template <
+ typename kernel_type,
+ typename vector_type
+ >
+ void fill_lisf (
+ linearly_independent_subset_finder<kernel_type>& lisf,
+ const vector_type& samples
+ );
+ /*!
+ requires
+ - vector_type == a dlib::matrix or something convertible to one via
+ mat()
+ - is_vector(mat(samples)) == true
+ ensures
+ - performs fill_lisf(lisf, samples, default_rand_generator, 2000)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_LISf_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/multiclass_tools.h b/ml/dlib/dlib/svm/multiclass_tools.h
new file mode 100644
index 000000000..d97e8aa04
--- /dev/null
+++ b/ml/dlib/dlib/svm/multiclass_tools.h
@@ -0,0 +1,68 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MULTICLASS_TOoLS_Hh_
+#define DLIB_MULTICLASS_TOoLS_Hh_
+
+#include "multiclass_tools_abstract.h"
+
+#include <vector>
+#include <set>
+#include "../unordered_pair.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename label_type>
+ std::vector<label_type> select_all_distinct_labels (
+ const std::vector<label_type>& labels
+ )
+ {
+ std::set<label_type> temp;
+ temp.insert(labels.begin(), labels.end());
+ return std::vector<label_type>(temp.begin(), temp.end());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename label_type, typename U>
+ std::vector<unordered_pair<label_type> > find_missing_pairs (
+ const std::map<unordered_pair<label_type>,U>& bdfs
+ )
+ {
+ typedef std::map<unordered_pair<label_type>,U> map_type;
+
+ // find all the labels
+ std::set<label_type> temp;
+ for (typename map_type::const_iterator i = bdfs.begin(); i != bdfs.end(); ++i)
+ {
+ temp.insert(i->first.first);
+ temp.insert(i->first.second);
+ }
+
+ std::vector<unordered_pair<label_type> > missing_pairs;
+
+ // now make sure all label pairs are present
+ typename std::set<label_type>::const_iterator i, j;
+ for (i = temp.begin(); i != temp.end(); ++i)
+ {
+ for (j = i, ++j; j != temp.end(); ++j)
+ {
+ const unordered_pair<label_type> p(*i, *j);
+
+ if (bdfs.count(p) == 0)
+ missing_pairs.push_back(p);
+ }
+ }
+
+ return missing_pairs;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MULTICLASS_TOoLS_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/multiclass_tools_abstract.h b/ml/dlib/dlib/svm/multiclass_tools_abstract.h
new file mode 100644
index 000000000..9e7774d3f
--- /dev/null
+++ b/ml/dlib/dlib/svm/multiclass_tools_abstract.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_
+#ifdef DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_
+
+#include <vector>
+#include <map>
+#include "../unordered_pair.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename label_type>
+ std::vector<label_type> select_all_distinct_labels (
+ const std::vector<label_type>& labels
+ );
+ /*!
+ ensures
+ - Determines all distinct values present in labels and stores them
+ into a sorted vector and returns it. They are sorted in ascending
+ order.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename label_type, typename U>
+ std::vector<unordered_pair<label_type> > find_missing_pairs (
+ const std::map<unordered_pair<label_type>,U>& binary_decision_functions
+ );
+ /*!
+ ensures
+ - Let L denote the set of all label_type values present in binary_decision_functions.
+ - This function finds all the label pairs with both elements distinct and in L but
+ not also in binary_decision_functions. All these missing pairs are stored
+ in a sorted vector and returned. They are sorted in ascending order.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/null_df.h b/ml/dlib/dlib/svm/null_df.h
new file mode 100644
index 000000000..2cbbf04a7
--- /dev/null
+++ b/ml/dlib/dlib/svm/null_df.h
@@ -0,0 +1,33 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_NULL_DECISION_FUnCTION_Hh_
+#define DLIB_NULL_DECISION_FUnCTION_Hh_
+
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct null_df
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a type used to represent an unused field in the list of template
+ arguments of the one_vs_one_decision_function and one_vs_all_decision_function
+ templates. As such, null_df doesn't actually do anything.
+ !*/
+ template <typename T>
+ double operator() ( const T&) const { return 0; }
+ };
+
+ inline void serialize(const null_df&, std::ostream&) {}
+ inline void deserialize(null_df&, std::istream&) {}
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_NULL_DECISION_FUnCTION_Hh_
+
diff --git a/ml/dlib/dlib/svm/null_trainer.h b/ml/dlib/dlib/svm/null_trainer.h
new file mode 100644
index 000000000..015b00c15
--- /dev/null
+++ b/ml/dlib/dlib/svm/null_trainer.h
@@ -0,0 +1,61 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_NULL_TRAINERs_H_
+#define DLIB_NULL_TRAINERs_H_
+
+#include "null_trainer_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type
+ >
+ class null_trainer_type
+ {
+ public:
+ typedef typename dec_funct_type::kernel_type kernel_type;
+ typedef typename dec_funct_type::scalar_type scalar_type;
+ typedef typename dec_funct_type::sample_type sample_type;
+ typedef typename dec_funct_type::mem_manager_type mem_manager_type;
+ typedef dec_funct_type trained_function_type;
+
+ null_trainer_type (
+ ){}
+
+ null_trainer_type (
+ const dec_funct_type& dec_funct_
+ ) : dec_funct(dec_funct_) {}
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const dec_funct_type& train (
+ const in_sample_vector_type& ,
+ const in_scalar_vector_type&
+ ) const { return dec_funct; }
+
+ private:
+ dec_funct_type dec_funct;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type
+ >
+ const null_trainer_type<dec_funct_type> null_trainer (
+ const dec_funct_type& dec_funct
+ ) { return null_trainer_type<dec_funct_type>(dec_funct); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_NULL_TRAINERs_H_
+
diff --git a/ml/dlib/dlib/svm/null_trainer_abstract.h b/ml/dlib/dlib/svm/null_trainer_abstract.h
new file mode 100644
index 000000000..25f6a5443
--- /dev/null
+++ b/ml/dlib/dlib/svm/null_trainer_abstract.h
@@ -0,0 +1,101 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_NULL_TRAINERs_ABSTRACT_
+#ifdef DLIB_NULL_TRAINERs_ABSTRACT_
+
+#include "../algs.h"
+#include "function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type
+ >
+ class null_trainer_type
+ {
+ /*!
+ REQUIREMENTS ON dec_funct_type
+ dec_funct_type can be any copyable type that provides the needed
+ typedefs used below (e.g. kernel_type, scalar_type, etc.).
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple tool for turning a decision function
+ into a trainer object that always returns the original decision
+ function when you try to train with it.
+
+ dlib contains a few "training post processing" algorithms (e.g.
+ reduced() and reduced2()). These tools take in a trainer object,
+ tell it to perform training, and then they take the output decision
+ function and do some kind of post processing to it. The null_trainer_type
+ object is useful because you can use it to run an already
+ learned decision function through the training post processing
+ algorithms by turning a decision function into a null_trainer_type
+ and then giving it to a post processor.
+ !*/
+
+ public:
+ typedef typename dec_funct_type::kernel_type kernel_type;
+ typedef typename dec_funct_type::scalar_type scalar_type;
+ typedef typename dec_funct_type::sample_type sample_type;
+ typedef typename dec_funct_type::mem_manager_type mem_manager_type;
+ typedef dec_funct_type trained_function_type;
+
+ null_trainer_type (
+ );
+ /*!
+ ensures
+ - any call to this->train(x,y) will return a default initialized
+ dec_funct_type object.
+ !*/
+
+ null_trainer_type (
+ const dec_funct_type& dec_funct
+ );
+ /*!
+ ensures
+ - any call to this->train(x,y) will always return a copy of
+ the given dec_funct object.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const dec_funct_type& train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the decision function object given to
+ this object's constructor.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type
+ >
+ const null_trainer_type<dec_funct_type> null_trainer (
+ const dec_funct_type& dec_funct
+ ) { return null_trainer_type<dec_funct_type>(dec_funct); }
+ /*!
+ ensures
+ - returns a null_trainer_type object that has been instantiated with
+ the given arguments. That is, this function returns a null_trainer_type
+ trainer that will return a copy of the given dec_funct object every time
+ someone calls its train() function.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_NULL_TRAINERs_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/svm/num_nonnegative_weights.h b/ml/dlib/dlib/svm/num_nonnegative_weights.h
new file mode 100644
index 000000000..4f21f9b69
--- /dev/null
+++ b/ml/dlib/dlib/svm/num_nonnegative_weights.h
@@ -0,0 +1,76 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_
+#define DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_
+
+#include "../enable_if.h"
+
+namespace dlib
+{
+
+ namespace impl2
+ {
+ template <
+ typename T,
+ unsigned long (T::*funct)()const
+ >
+ struct hnnf_helper
+ {
+ typedef char type;
+ };
+
+ template <typename T>
+ char has_num_nonnegative_weights_helper( typename hnnf_helper<T,&T::num_nonnegative_weights>::type = 0 ) { return 0;}
+
+ struct two_bytes
+ {
+ char a[2];
+ };
+
+ template <typename T>
+ two_bytes has_num_nonnegative_weights_helper(int) { return two_bytes();}
+
+ template <typename T>
+ struct work_around_visual_studio_bug
+ {
+ const static unsigned long U = sizeof(has_num_nonnegative_weights_helper<T>('a'));
+ };
+
+
+ // This is a template to tell you if a feature_extractor has a num_nonnegative_weights function or not.
+ template <typename T, unsigned long U = work_around_visual_studio_bug<T>::U >
+ struct has_num_nonnegative_weights
+ {
+ static const bool value = false;
+ };
+
+ template <typename T>
+ struct has_num_nonnegative_weights <T,1>
+ {
+ static const bool value = true;
+ };
+
+
+ }
+
+ // call fe.num_nonnegative_weights() if it exists, otherwise return 0.
+ template <typename feature_extractor>
+ typename enable_if<impl2::has_num_nonnegative_weights<feature_extractor>,unsigned long>::type num_nonnegative_weights (
+ const feature_extractor& fe
+ )
+ {
+ return fe.num_nonnegative_weights();
+ }
+
+ template <typename feature_extractor>
+ typename disable_if<impl2::has_num_nonnegative_weights<feature_extractor>,unsigned long>::type num_nonnegative_weights (
+ const feature_extractor& /*fe*/
+ )
+ {
+ return 0;
+ }
+
+}
+
+#endif // DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_
+
diff --git a/ml/dlib/dlib/svm/one_vs_all_decision_function.h b/ml/dlib/dlib/svm/one_vs_all_decision_function.h
new file mode 100644
index 000000000..8afa52344
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_all_decision_function.h
@@ -0,0 +1,265 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_
+#define DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_
+
+#include "one_vs_all_decision_function_abstract.h"
+
+#include "../serialize.h"
+#include "../type_safe_union.h"
+#include <sstream>
+#include <map>
+#include "../any.h"
+#include "null_df.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename one_vs_all_trainer,
+ typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
+ typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
+ typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
+ typename DF10 = null_df
+ >
+ class one_vs_all_decision_function
+ {
+ public:
+
+ typedef typename one_vs_all_trainer::label_type result_type;
+ typedef typename one_vs_all_trainer::sample_type sample_type;
+ typedef typename one_vs_all_trainer::scalar_type scalar_type;
+ typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
+
+ typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ one_vs_all_decision_function() :num_classes(0) {}
+
+ explicit one_vs_all_decision_function(
+ const binary_function_table& dfs_
+ ) : dfs(dfs_)
+ {
+ num_classes = dfs.size();
+ }
+
+ const binary_function_table& get_binary_decision_functions (
+ ) const
+ {
+ return dfs;
+ }
+
+ const std::vector<result_type> get_labels (
+ ) const
+ {
+ std::vector<result_type> temp;
+ temp.reserve(dfs.size());
+ for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
+ {
+ temp.push_back(i->first);
+ }
+ return temp;
+ }
+
+
+ template <
+ typename df1, typename df2, typename df3, typename df4, typename df5,
+ typename df6, typename df7, typename df8, typename df9, typename df10
+ >
+ one_vs_all_decision_function (
+ const one_vs_all_decision_function<one_vs_all_trainer,
+ df1, df2, df3, df4, df5,
+ df6, df7, df8, df9, df10>& item
+ ) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {}
+
+ unsigned long number_of_classes (
+ ) const
+ {
+ return num_classes;
+ }
+
+ std::pair<result_type, scalar_type> predict (
+ const sample_type& sample
+ ) const
+ {
+ DLIB_ASSERT(number_of_classes() != 0,
+ "\t pair<result_type,scalar_type> one_vs_all_decision_function::predict()"
+ << "\n\t You can't make predictions with an empty decision function."
+ << "\n\t this: " << this
+ );
+
+ result_type best_label = result_type();
+ scalar_type best_score = -std::numeric_limits<scalar_type>::infinity();
+
+ // run all the classifiers over the sample and find the best one
+ for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
+ {
+ const scalar_type score = i->second(sample);
+
+ if (score > best_score)
+ {
+ best_score = score;
+ best_label = i->first;
+ }
+ }
+
+ return std::make_pair(best_label, best_score);
+ }
+
+ result_type operator() (
+ const sample_type& sample
+ ) const
+ {
+ DLIB_ASSERT(number_of_classes() != 0,
+ "\t result_type one_vs_all_decision_function::operator()"
+ << "\n\t You can't make predictions with an empty decision function."
+ << "\n\t this: " << this
+ );
+
+ return predict(sample).first;
+ }
+
+
+
+ private:
+ binary_function_table dfs;
+ unsigned long num_classes;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void serialize(
+ const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
+ typedef typename T::label_type result_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::scalar_type scalar_type;
+ typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ const unsigned long version = 1;
+ serialize(version, out);
+
+ const unsigned long size = item.get_binary_decision_functions().size();
+ serialize(size, out);
+
+ for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin();
+ i != item.get_binary_decision_functions().end(); ++i)
+ {
+ serialize(i->first, out);
+
+ if (i->second.template contains<DF1>()) temp.template get<DF1>() = any_cast<DF1>(i->second);
+ else if (i->second.template contains<DF2>()) temp.template get<DF2>() = any_cast<DF2>(i->second);
+ else if (i->second.template contains<DF3>()) temp.template get<DF3>() = any_cast<DF3>(i->second);
+ else if (i->second.template contains<DF4>()) temp.template get<DF4>() = any_cast<DF4>(i->second);
+ else if (i->second.template contains<DF5>()) temp.template get<DF5>() = any_cast<DF5>(i->second);
+ else if (i->second.template contains<DF6>()) temp.template get<DF6>() = any_cast<DF6>(i->second);
+ else if (i->second.template contains<DF7>()) temp.template get<DF7>() = any_cast<DF7>(i->second);
+ else if (i->second.template contains<DF8>()) temp.template get<DF8>() = any_cast<DF8>(i->second);
+ else if (i->second.template contains<DF9>()) temp.template get<DF9>() = any_cast<DF9>(i->second);
+ else if (i->second.template contains<DF10>()) temp.template get<DF10>() = any_cast<DF10>(i->second);
+ else throw serialization_error("Can't serialize one_vs_all_decision_function. Not all decision functions defined.");
+
+ serialize(temp,out);
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type one_vs_all_decision_function");
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_ova
+ {
+ template <typename sample_type, typename scalar_type>
+ struct copy_to_df_helper
+ {
+ copy_to_df_helper(any_decision_function<sample_type, scalar_type>& target_) : target(target_) {}
+
+ any_decision_function<sample_type, scalar_type>& target;
+
+ template <typename T>
+ void operator() (
+ const T& item
+ ) const
+ {
+ target = item;
+ }
+ };
+ }
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void deserialize(
+ one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
+ typedef typename T::label_type result_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::scalar_type scalar_type;
+ typedef impl_ova::copy_to_df_helper<sample_type, scalar_type> copy_to;
+
+ unsigned long version;
+ deserialize(version, in);
+
+ if (version != 1)
+ throw serialization_error("Can't deserialize one_vs_all_decision_function. Wrong version.");
+
+ unsigned long size;
+ deserialize(size, in);
+
+ typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
+ binary_function_table dfs;
+
+ result_type l;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(l, in);
+ deserialize(temp, in);
+ if (temp.template contains<null_df>())
+ throw serialization_error("A sub decision function of unknown type was encountered.");
+
+ temp.apply_to_contents(copy_to(dfs[l]));
+ }
+
+ item = one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>(dfs);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type one_vs_all_decision_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h b/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h
new file mode 100644
index 000000000..8daacb8d6
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h
@@ -0,0 +1,214 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_
+#ifdef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_
+
+
+#include "../serialize.h"
+#include <map>
+#include "../any/any_decision_function_abstract.h"
+#include "one_vs_all_trainer_abstract.h"
+#include "null_df.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename one_vs_all_trainer,
+ typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
+ typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
+ typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
+ typename DF10 = null_df
+ >
+ class one_vs_all_decision_function
+ {
+ /*!
+ REQUIREMENTS ON one_vs_all_trainer
+ This should be an instantiation of the one_vs_all_trainer template.
+ It is used to infer which types are used for various things, such as
+ representing labels.
+
+ REQUIREMENTS ON DF*
+ These types can either be left at their default values or set
+ to any kind of decision function object capable of being
+ stored in an any_decision_function<sample_type,scalar_type>
+ object. These types should also be serializable.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a multiclass classifier built out of a set of
+ binary classifiers. Each binary classifier is used to vote for the
+ correct multiclass label using a one vs. all strategy. Therefore,
+ if you have N classes then there will be N binary classifiers inside
+ this object.
+
+ Note that the DF* template arguments are only used if you want
+ to serialize and deserialize one_vs_all_decision_function objects.
+ Specifically, all the types of binary decision function contained
+ within a one_vs_all_decision_function must be listed in the
+ template arguments if serialization and deserialization is to
+ be used.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as all the decision functions contained in this object
+ are also threadsafe. This is because the const members are purely
+ read-only operations. However, any operation that modifies a
+ one_vs_all_decision_function is not threadsafe.
+ !*/
+ public:
+
+ typedef typename one_vs_all_trainer::label_type result_type;
+ typedef typename one_vs_all_trainer::sample_type sample_type;
+ typedef typename one_vs_all_trainer::scalar_type scalar_type;
+ typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
+
+ typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ one_vs_all_decision_function(
+ );
+ /*!
+ ensures
+ - #number_of_classes() == 0
+ - #get_binary_decision_functions().size() == 0
+ - #get_labels().size() == 0
+ !*/
+
+ explicit one_vs_all_decision_function(
+ const binary_function_table& decision_functions
+ );
+ /*!
+ ensures
+ - #get_binary_decision_functions() == decision_functions
+ - #get_labels() == a list of all the labels which appear in the
+ given set of decision functions
+ - #number_of_classes() == #get_labels().size()
+ !*/
+
+ template <
+ typename df1, typename df2, typename df3, typename df4, typename df5,
+ typename df6, typename df7, typename df8, typename df9, typename df10
+ >
+ one_vs_all_decision_function (
+ const one_vs_all_decision_function<one_vs_all_trainer,
+ df1, df2, df3, df4, df5,
+ df6, df7, df8, df9, df10>& item
+ );
+ /*!
+ ensures
+ - #*this will be a copy of item
+ - #number_of_classes() == item.number_of_classes()
+ - #get_labels() == item.get_labels()
+ - #get_binary_decision_functions() == item.get_binary_decision_functions()
+ !*/
+
+ const binary_function_table& get_binary_decision_functions (
+ ) const;
+ /*!
+ ensures
+ - returns the table of binary decision functions used by this
+ object. The label given to a test sample is computed by
+ determining which binary decision function has the largest
+ (i.e. most positive) output and returning the label associated
+ with that decision function.
+ !*/
+
+ const std::vector<result_type> get_labels (
+ ) const;
+ /*!
+ ensures
+ - returns a vector containing all the labels which can be
+ predicted by this object.
+ !*/
+
+ unsigned long number_of_classes (
+ ) const;
+ /*!
+ ensures
+ - returns get_labels().size()
+ (i.e. returns the number of different labels/classes predicted by
+ this object)
+ !*/
+
+ std::pair<result_type, scalar_type> predict (
+ const sample_type& sample
+ ) const;
+ /*!
+ requires
+ - number_of_classes() != 0
+ ensures
+ - Evaluates all the decision functions in get_binary_decision_functions()
+ and returns the predicted label and score for the input sample. That is,
+ returns std::make_pair(label,score)
+ - The label is determined by whichever classifier outputs the largest
+ score.
+ !*/
+
+ result_type operator() (
+ const sample_type& sample
+ ) const
+ /*!
+ requires
+ - number_of_classes() != 0
+ ensures
+ - Evaluates all the decision functions in get_binary_decision_functions()
+ and returns the predicted label. That is, returns predict(sample).first.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void serialize(
+ const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - writes the given item to the output stream out.
+ throws
+ - serialization_error.
+ This is thrown if there is a problem writing to the ostream or if item
+ contains a type of decision function not listed among the DF* template
+ arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void deserialize(
+ one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::istream& in
+ );
+ /*!
+ ensures
+ - deserializes a one_vs_all_decision_function from in and stores it in item.
+ throws
+ - serialization_error.
+ This is thrown if there is a problem reading from the istream or if the
+ serialized data contains decision functions not listed among the DF*
+ template arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/one_vs_all_trainer.h b/ml/dlib/dlib/svm/one_vs_all_trainer.h
new file mode 100644
index 000000000..bcb006a41
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_all_trainer.h
@@ -0,0 +1,234 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ONE_VS_ALL_TRAiNER_Hh_
+#define DLIB_ONE_VS_ALL_TRAiNER_Hh_
+
+#include "one_vs_all_trainer_abstract.h"
+
+#include "one_vs_all_decision_function.h"
+#include <vector>
+
+#include "multiclass_tools.h"
+
+#include <sstream>
+#include <iostream>
+
+#include "../any.h"
+#include <map>
+#include <set>
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename any_trainer,
+ typename label_type_ = double
+ >
+ class one_vs_all_trainer
+ {
+ public:
+ typedef label_type_ label_type;
+
+ typedef typename any_trainer::sample_type sample_type;
+ typedef typename any_trainer::scalar_type scalar_type;
+ typedef typename any_trainer::mem_manager_type mem_manager_type;
+
+ typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
+
+ one_vs_all_trainer (
+ ) :
+ verbose(false),
+ num_threads(4)
+ {}
+
+ void set_trainer (
+ const any_trainer& trainer
+ )
+ {
+ default_trainer = trainer;
+ trainers.clear();
+ }
+
+ void set_trainer (
+ const any_trainer& trainer,
+ const label_type& l
+ )
+ {
+ trainers[l] = trainer;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ struct invalid_label : public dlib::error
+ {
+ invalid_label(const std::string& msg, const label_type& l_
+ ) : dlib::error(msg), l(l_) {};
+
+ virtual ~invalid_label(
+ ) throw() {}
+
+ label_type l;
+ };
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type one_vs_all_trainer::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
+
+ // make sure we have a trainer object for each of the label types.
+ for (unsigned long i = 0; i < distinct_labels.size(); ++i)
+ {
+ const label_type l = distinct_labels[i];
+ const typename binary_function_table::const_iterator itr = trainers.find(l);
+
+ if (itr == trainers.end() && default_trainer.is_empty())
+ {
+ std::ostringstream sout;
+ sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label.";
+ throw invalid_label(sout.str(), l);
+ }
+ }
+
+
+ // now do the training
+ parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,distinct_labels);
+ parallel_for(num_threads, 0, distinct_labels.size(), helper, 500);
+
+ if (helper.error_message.size() != 0)
+ {
+ throw dlib::error("binary trainer threw while training one vs. all classifier. Error was: " + helper.error_message);
+ }
+ return trained_function_type(helper.dfs);
+ }
+
+ private:
+
+ typedef std::map<label_type, any_trainer> binary_function_table;
+ struct parallel_for_helper
+ {
+ parallel_for_helper(
+ const std::vector<sample_type>& all_samples_,
+ const std::vector<label_type>& all_labels_,
+ const any_trainer& default_trainer_,
+ const binary_function_table& trainers_,
+ const bool verbose_,
+ const std::vector<label_type>& distinct_labels_
+ ) :
+ all_samples(all_samples_),
+ all_labels(all_labels_),
+ default_trainer(default_trainer_),
+ trainers(trainers_),
+ verbose(verbose_),
+ distinct_labels(distinct_labels_)
+ {}
+
+ void operator()(long i) const
+ {
+ try
+ {
+ std::vector<scalar_type> labels;
+
+ const label_type l = distinct_labels[i];
+
+ // setup one of the one vs all training sets
+ for (unsigned long k = 0; k < all_samples.size(); ++k)
+ {
+ if (all_labels[k] == l)
+ labels.push_back(+1);
+ else
+ labels.push_back(-1);
+ }
+
+
+ if (verbose)
+ {
+ auto_mutex lock(class_mutex);
+ std::cout << "Training classifier for " << l << " vs. all" << std::endl;
+ }
+
+ any_trainer trainer;
+ // now train a binary classifier using the samples we selected
+ { auto_mutex lock(class_mutex);
+ const typename binary_function_table::const_iterator itr = trainers.find(l);
+ if (itr != trainers.end())
+ trainer = itr->second;
+ else
+ trainer = default_trainer;
+ }
+
+ any_decision_function<sample_type,scalar_type> binary_df = trainer.train(all_samples, labels);
+
+ auto_mutex lock(class_mutex);
+ dfs[l] = binary_df;
+ }
+ catch (std::exception& e)
+ {
+ auto_mutex lock(class_mutex);
+ error_message = e.what();
+ }
+ }
+
+ mutable typename trained_function_type::binary_function_table dfs;
+ mutex class_mutex;
+ mutable std::string error_message;
+
+ const std::vector<sample_type>& all_samples;
+ const std::vector<label_type>& all_labels;
+ const any_trainer& default_trainer;
+ const binary_function_table& trainers;
+ const bool verbose;
+ const std::vector<label_type>& distinct_labels;
+ };
+
+ any_trainer default_trainer;
+
+ binary_function_table trainers;
+
+ bool verbose;
+ unsigned long num_threads;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ALL_TRAiNER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h b/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h
new file mode 100644
index 000000000..fb719a7e4
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h
@@ -0,0 +1,163 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_
+#ifdef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_
+
+
+#include "one_vs_all_decision_function_abstract.h"
+#include <vector>
+
+#include "../any/any_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename any_trainer,
+ typename label_type_ = double
+ >
+ class one_vs_all_trainer
+ {
+ /*!
+ REQUIREMENTS ON any_trainer
+ must be an instantiation of the dlib::any_trainer template.
+
+ REQUIREMENTS ON label_type_
+ label_type_ must be default constructable, copyable, and comparable using
+ operator < and ==. It must also be possible to write it to an std::ostream
+ using operator<<.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for turning a bunch of binary classifiers into a
+ multiclass classifier. It does this by training the binary classifiers
+ in a one vs. all fashion. That is, if you have N possible classes then
+ it trains N binary classifiers which are then used to vote on the identity
+ of a test sample.
+
+ This object works with any kind of binary classification trainer object
+ capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer)
+ !*/
+
+ public:
+
+
+ typedef label_type_ label_type;
+
+ typedef typename any_trainer::sample_type sample_type;
+ typedef typename any_trainer::scalar_type scalar_type;
+ typedef typename any_trainer::mem_manager_type mem_manager_type;
+
+ typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
+
+ one_vs_all_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized.
+ - This object will not be verbose unless be_verbose() is called.
+ - No binary trainers are associated with *this. I.e. you have to
+ call set_trainer() before calling train().
+ - #get_num_threads() == 4
+ !*/
+
+ void set_trainer (
+ const any_trainer& trainer
+ );
+ /*!
+ ensures
+ - sets the trainer used for all binary subproblems. Any previous
+ calls to set_trainer() are overridden by this function. Even the
+ more specific set_trainer(trainer, l) form.
+ !*/
+
+ void set_trainer (
+ const any_trainer& trainer,
+ const label_type& l
+ );
+ /*!
+ ensures
+ - Sets the trainer object used to create a binary classifier to
+ distinguish l labeled samples from all other samples.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ struct invalid_label : public dlib::error
+ {
+ /*!
+ This is the exception thrown by the train() function below.
+ !*/
+ label_type l;
+ };
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(all_samples, all_labels)
+ ensures
+ - trains a bunch of binary classifiers in a one vs all fashion to solve the given
+ multiclass classification problem.
+ - returns a one_vs_all_decision_function F with the following properties:
+ - F contains all the learned binary classifiers and can be used to predict
+ the labels of new samples.
+ - if (new_x is a sample predicted to have a label of L) then
+ - F(new_x) == L
+ - F.get_labels() == select_all_distinct_labels(all_labels)
+ - F.number_of_classes() == select_all_distinct_labels(all_labels).size()
+ throws
+ - invalid_label
+ This exception is thrown if there are labels in all_labels which don't have
+ any corresponding trainer object. This will never happen if set_trainer(trainer)
+ has been called. However, if only the set_trainer(trainer,l) form has been
+ used then this exception is thrown if not all labels have been given a trainer.
+
+ invalid_label::l will contain the label which is missing a trainer object.
+ Additionally, the exception will contain an informative error message available
+ via invalid_label::what().
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/one_vs_one_decision_function.h b/ml/dlib/dlib/svm/one_vs_one_decision_function.h
new file mode 100644
index 000000000..02a5fa51e
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_one_decision_function.h
@@ -0,0 +1,291 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_
+#define DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_
+
+#include "one_vs_one_decision_function_abstract.h"
+
+#include "../serialize.h"
+#include "../type_safe_union.h"
+#include <iostream>
+#include <sstream>
+#include <set>
+#include <map>
+#include "../any.h"
+#include "../unordered_pair.h"
+#include "null_df.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename one_vs_one_trainer,
+ typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
+ typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
+ typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
+ typename DF10 = null_df
+ >
+ class one_vs_one_decision_function
+ {
+ public:
+
+ typedef typename one_vs_one_trainer::label_type result_type;
+ typedef typename one_vs_one_trainer::sample_type sample_type;
+ typedef typename one_vs_one_trainer::scalar_type scalar_type;
+ typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type;
+
+ typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ one_vs_one_decision_function() :num_classes(0) {}
+
+ explicit one_vs_one_decision_function(
+ const binary_function_table& dfs_
+ ) : dfs(dfs_)
+ {
+#ifdef ENABLE_ASSERTS
+ {
+ const std::vector<unordered_pair<result_type> > missing_pairs = find_missing_pairs(dfs_);
+ if (missing_pairs.size() != 0)
+ {
+ std::ostringstream sout;
+ for (unsigned long i = 0; i < missing_pairs.size(); ++i)
+ {
+ sout << "\t (" << missing_pairs[i].first << ", " << missing_pairs[i].second << ")\n";
+ }
+ DLIB_ASSERT(missing_pairs.size() == 0,
+ "\t void one_vs_one_decision_function::one_vs_one_decision_function()"
+ << "\n\t The supplied set of binary decision functions is incomplete."
+ << "\n\t this: " << this
+ << "\n\t Classifiers are missing for the following label pairs: \n" << sout.str()
+ );
+ }
+ }
+#endif
+
+ // figure out how many labels are covered by this set of binary decision functions
+ std::set<result_type> labels;
+ for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
+ {
+ labels.insert(i->first.first);
+ labels.insert(i->first.second);
+ }
+ num_classes = labels.size();
+ }
+
+ const binary_function_table& get_binary_decision_functions (
+ ) const
+ {
+ return dfs;
+ }
+
+ const std::vector<result_type> get_labels (
+ ) const
+ {
+ std::set<result_type> labels;
+ for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
+ {
+ labels.insert(i->first.first);
+ labels.insert(i->first.second);
+ }
+ return std::vector<result_type>(labels.begin(), labels.end());
+ }
+
+
+ template <
+ typename df1, typename df2, typename df3, typename df4, typename df5,
+ typename df6, typename df7, typename df8, typename df9, typename df10
+ >
+ one_vs_one_decision_function (
+ const one_vs_one_decision_function<one_vs_one_trainer,
+ df1, df2, df3, df4, df5,
+ df6, df7, df8, df9, df10>& item
+ ) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {}
+
+ unsigned long number_of_classes (
+ ) const
+ {
+ return num_classes;
+ }
+
+ result_type operator() (
+ const sample_type& sample
+ ) const
+ {
+ DLIB_ASSERT(number_of_classes() != 0,
+ "\t void one_vs_one_decision_function::operator()"
+ << "\n\t You can't make predictions with an empty decision function."
+ << "\n\t this: " << this
+ );
+
+ std::map<result_type,int> votes;
+
+ // run all the classifiers over the sample
+ for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
+ {
+ const scalar_type score = i->second(sample);
+
+ if (score > 0)
+ votes[i->first.first] += 1;
+ else
+ votes[i->first.second] += 1;
+ }
+
+ // now figure out who had the most votes
+ result_type best_label = result_type();
+ int best_votes = 0;
+ for (typename std::map<result_type,int>::iterator i = votes.begin(); i != votes.end(); ++i)
+ {
+ if (i->second > best_votes)
+ {
+ best_votes = i->second;
+ best_label = i->first;
+ }
+ }
+
+ return best_label;
+ }
+
+
+
+ private:
+ binary_function_table dfs;
+ unsigned long num_classes;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void serialize(
+ const one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
+ typedef typename T::label_type result_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::scalar_type scalar_type;
+ typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ const unsigned long version = 1;
+ serialize(version, out);
+
+ const unsigned long size = item.get_binary_decision_functions().size();
+ serialize(size, out);
+
+ for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin();
+ i != item.get_binary_decision_functions().end(); ++i)
+ {
+ serialize(i->first, out);
+
+ if (i->second.template contains<DF1>()) temp.template get<DF1>() = any_cast<DF1>(i->second);
+ else if (i->second.template contains<DF2>()) temp.template get<DF2>() = any_cast<DF2>(i->second);
+ else if (i->second.template contains<DF3>()) temp.template get<DF3>() = any_cast<DF3>(i->second);
+ else if (i->second.template contains<DF4>()) temp.template get<DF4>() = any_cast<DF4>(i->second);
+ else if (i->second.template contains<DF5>()) temp.template get<DF5>() = any_cast<DF5>(i->second);
+ else if (i->second.template contains<DF6>()) temp.template get<DF6>() = any_cast<DF6>(i->second);
+ else if (i->second.template contains<DF7>()) temp.template get<DF7>() = any_cast<DF7>(i->second);
+ else if (i->second.template contains<DF8>()) temp.template get<DF8>() = any_cast<DF8>(i->second);
+ else if (i->second.template contains<DF9>()) temp.template get<DF9>() = any_cast<DF9>(i->second);
+ else if (i->second.template contains<DF10>()) temp.template get<DF10>() = any_cast<DF10>(i->second);
+ else throw serialization_error("Can't serialize one_vs_one_decision_function. Not all decision functions defined.");
+
+ serialize(temp,out);
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type one_vs_one_decision_function");
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename sample_type, typename scalar_type>
+ struct copy_to_df_helper
+ {
+ copy_to_df_helper(any_decision_function<sample_type, scalar_type>& target_) : target(target_) {}
+
+ any_decision_function<sample_type, scalar_type>& target;
+
+ template <typename T>
+ void operator() (
+ const T& item
+ ) const
+ {
+ target = item;
+ }
+ };
+ }
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void deserialize(
+ one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
+ typedef typename T::label_type result_type;
+ typedef typename T::sample_type sample_type;
+ typedef typename T::scalar_type scalar_type;
+ typedef impl::copy_to_df_helper<sample_type, scalar_type> copy_to;
+
+ unsigned long version;
+ deserialize(version, in);
+
+ if (version != 1)
+ throw serialization_error("Can't deserialize one_vs_one_decision_function. Wrong version.");
+
+ unsigned long size;
+ deserialize(size, in);
+
+ typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
+ binary_function_table dfs;
+
+ unordered_pair<result_type> p;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ deserialize(p, in);
+ deserialize(temp, in);
+ if (temp.template contains<null_df>())
+ throw serialization_error("A sub decision function of unknown type was encountered.");
+
+ temp.apply_to_contents(copy_to(dfs[p]));
+ }
+
+ item = one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>(dfs);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type one_vs_one_decision_function");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h b/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h
new file mode 100644
index 000000000..cf22e0ba7
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h
@@ -0,0 +1,213 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_
+#ifdef DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_
+
+
+#include "../serialize.h"
+#include <map>
+#include "../any/any_decision_function_abstract.h"
+#include "../unordered_pair.h"
+#include "one_vs_one_trainer_abstract.h"
+#include "null_df.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename one_vs_one_trainer,
+ typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
+ typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
+ typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
+ typename DF10 = null_df
+ >
+ class one_vs_one_decision_function
+ {
+ /*!
+ REQUIREMENTS ON one_vs_one_trainer
+ This should be an instantiation of the one_vs_one_trainer template.
+ It is used to infer which types are used for various things, such as
+ representing labels.
+
+ REQUIREMENTS ON DF*
+ These types can either be left at their default values or set
+ to any kind of decision function object capable of being
+ stored in an any_decision_function<sample_type,scalar_type>
+ object. These types should also be serializable.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a multiclass classifier built out
+ of a set of binary classifiers. Each binary classifier
+ is used to vote for the correct multiclass label using a
+ one vs. one strategy. Therefore, if you have N classes then
+ there will be N*(N-1)/2 binary classifiers inside this object.
+
+ Note that the DF* template arguments are only used if you want
+ to serialize and deserialize one_vs_one_decision_function objects.
+ Specifically, all the types of binary decision function contained
+ within a one_vs_one_decision_function must be listed in the
+ template arguments if serialization and deserialization is to
+ be used.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as all the decision functions contained in this object
+ are also threadsafe. This is because the const members are purely
+ read-only operations. However, any operation that modifies a
+ one_vs_one_decision_function is not threadsafe.
+ !*/
+ public:
+
+ typedef typename one_vs_one_trainer::label_type result_type;
+ typedef typename one_vs_one_trainer::sample_type sample_type;
+ typedef typename one_vs_one_trainer::scalar_type scalar_type;
+ typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type;
+
+ typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
+
+ one_vs_one_decision_function(
+ );
+ /*!
+ ensures
+ - #number_of_classes() == 0
+ - #get_binary_decision_functions().size() == 0
+ - #get_labels().size() == 0
+ !*/
+
+ explicit one_vs_one_decision_function(
+ const binary_function_table& decision_functions
+ );
+ /*!
+ requires
+ - find_missing_pairs(decision_functions).size() == 0
+ (i.e. all pairs of labels have an associated decision function)
+ ensures
+ - #get_binary_decision_functions() == decision_functions
+ - #get_labels() == a list of all the labels which appear in the
+ given set of decision functions
+ - #number_of_classes() == #get_labels().size()
+ !*/
+
+ template <
+ typename df1, typename df2, typename df3, typename df4, typename df5,
+ typename df6, typename df7, typename df8, typename df9, typename df10
+ >
+ one_vs_one_decision_function (
+ const one_vs_one_decision_function<one_vs_one_trainer,
+ df1, df2, df3, df4, df5,
+ df6, df7, df8, df9, df10>& item
+ );
+ /*!
+ ensures
+ - #*this will be a copy of item
+ - #number_of_classes() == item.number_of_classes()
+ - #get_labels() == item.get_labels()
+ - #get_binary_decision_functions() == item.get_binary_decision_functions()
+ !*/
+
+ const binary_function_table& get_binary_decision_functions (
+ ) const;
+ /*!
+ ensures
+ - returns the table of binary decision functions used by this
+ object. The correspondence between binary decision functions
+ and multiclass labels is the following:
+ - for each element i of get_binary_decision_functions()
+ - i->first == the label pair associated with binary decision
+ function i->second.
+ - if (decision function i->second outputs a value > 0) then
+ - i->second is indicating that a test sample should
+ receive a label of i->first.first
+ - else
+ - i->second is indicating that a test sample should
+ receive a label of i->first.second
+ !*/
+
+ const std::vector<result_type> get_labels (
+ ) const;
+ /*!
+ ensures
+ - returns a vector containing all the labels which can be
+ predicted by this object.
+ !*/
+
+ unsigned long number_of_classes (
+ ) const;
+ /*!
+ ensures
+ - returns get_labels().size()
+ (i.e. returns the number of different labels/classes predicted by
+ this object)
+ !*/
+
+ result_type operator() (
+ const sample_type& sample
+ ) const
+ /*!
+ requires
+ - number_of_classes() != 0
+ ensures
+ - evaluates all the decision functions in get_binary_decision_functions()
+ and returns the label which received the most votes.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void serialize(
+ const one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::ostream& out
+ );
+ /*!
+ ensures
+ - writes the given item to the output stream out.
+ throws
+ - serialization_error.
+ This is thrown if there is a problem writing to the ostream or if item
+ contains a type of decision function not listed among the DF* template
+ arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename DF1, typename DF2, typename DF3,
+ typename DF4, typename DF5, typename DF6,
+ typename DF7, typename DF8, typename DF9,
+ typename DF10
+ >
+ void deserialize(
+ one_vs_one_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
+ std::istream& in
+ );
+ /*!
+ ensures
+ - deserializes a one_vs_one_decision_function from in and stores it in item.
+ throws
+ - serialization_error.
+ This is thrown if there is a problem reading from the istream or if the
+ serialized data contains decision functions not listed among the DF*
+ template arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/one_vs_one_trainer.h b/ml/dlib/dlib/svm/one_vs_one_trainer.h
new file mode 100644
index 000000000..2beec8f67
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_one_trainer.h
@@ -0,0 +1,249 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ONE_VS_ONE_TRAiNER_Hh_
+#define DLIB_ONE_VS_ONE_TRAiNER_Hh_
+
+#include "one_vs_one_trainer_abstract.h"
+
+#include "one_vs_one_decision_function.h"
+#include <vector>
+
+#include "../unordered_pair.h"
+#include "multiclass_tools.h"
+
+#include <sstream>
+#include <iostream>
+
+#include "../any.h"
+#include <map>
+#include <set>
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename any_trainer,
+ typename label_type_ = double
+ >
+ class one_vs_one_trainer
+ {
+ public:
+ typedef label_type_ label_type;
+
+ typedef typename any_trainer::sample_type sample_type;
+ typedef typename any_trainer::scalar_type scalar_type;
+ typedef typename any_trainer::mem_manager_type mem_manager_type;
+
+ typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type;
+
+ one_vs_one_trainer (
+ ) :
+ verbose(false),
+ num_threads(4)
+ {}
+
+ void set_trainer (
+ const any_trainer& trainer
+ )
+ {
+ default_trainer = trainer;
+ trainers.clear();
+ }
+
+ void set_trainer (
+ const any_trainer& trainer,
+ const label_type& l1,
+ const label_type& l2
+ )
+ {
+ trainers[make_unordered_pair(l1,l2)] = trainer;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ struct invalid_label : public dlib::error
+ {
+ invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_
+ ) : dlib::error(msg), l1(l1_), l2(l2_) {};
+
+ virtual ~invalid_label(
+ ) throw() {}
+
+ label_type l1, l2;
+ };
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type one_vs_one_trainer::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
+
+
+ // fill pairs with all the pairs of labels.
+ std::vector<unordered_pair<label_type> > pairs;
+ for (unsigned long i = 0; i < distinct_labels.size(); ++i)
+ {
+ for (unsigned long j = i+1; j < distinct_labels.size(); ++j)
+ {
+ pairs.push_back(unordered_pair<label_type>(distinct_labels[i], distinct_labels[j]));
+
+ // make sure we have a trainer for this pair
+ const typename binary_function_table::const_iterator itr = trainers.find(pairs.back());
+ if (itr == trainers.end() && default_trainer.is_empty())
+ {
+ std::ostringstream sout;
+ sout << "In one_vs_one_trainer, no trainer registered for the ("
+ << pairs.back().first << ", " << pairs.back().second << ") label pair.";
+ throw invalid_label(sout.str(), pairs.back().first, pairs.back().second);
+ }
+ }
+ }
+
+
+
+ // Now train on all the label pairs.
+ parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,pairs);
+ parallel_for(num_threads, 0, pairs.size(), helper, 500);
+
+ if (helper.error_message.size() != 0)
+ {
+ throw dlib::error("binary trainer threw while training one vs. one classifier. Error was: " + helper.error_message);
+ }
+ return trained_function_type(helper.dfs);
+ }
+
+ private:
+
+ typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table;
+
+ struct parallel_for_helper
+ {
+ parallel_for_helper(
+ const std::vector<sample_type>& all_samples_,
+ const std::vector<label_type>& all_labels_,
+ const any_trainer& default_trainer_,
+ const binary_function_table& trainers_,
+ const bool verbose_,
+ const std::vector<unordered_pair<label_type> >& pairs_
+ ) :
+ all_samples(all_samples_),
+ all_labels(all_labels_),
+ default_trainer(default_trainer_),
+ trainers(trainers_),
+ verbose(verbose_),
+ pairs(pairs_)
+ {}
+
+ void operator()(long i) const
+ {
+ try
+ {
+ std::vector<sample_type> samples;
+ std::vector<scalar_type> labels;
+
+ const unordered_pair<label_type> p = pairs[i];
+
+ // pick out the samples corresponding to these two classes
+ for (unsigned long k = 0; k < all_samples.size(); ++k)
+ {
+ if (all_labels[k] == p.first)
+ {
+ samples.push_back(all_samples[k]);
+ labels.push_back(+1);
+ }
+ else if (all_labels[k] == p.second)
+ {
+ samples.push_back(all_samples[k]);
+ labels.push_back(-1);
+ }
+ }
+
+ if (verbose)
+ {
+ auto_mutex lock(class_mutex);
+ std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl;
+ }
+
+ any_trainer trainer;
+ // now train a binary classifier using the samples we selected
+ { auto_mutex lock(class_mutex);
+ const typename binary_function_table::const_iterator itr = trainers.find(p);
+ if (itr != trainers.end())
+ trainer = itr->second;
+ else
+ trainer = default_trainer;
+ }
+
+ any_decision_function<sample_type,scalar_type> binary_df = trainer.train(samples, labels);
+
+ auto_mutex lock(class_mutex);
+ dfs[p] = binary_df;
+ }
+ catch (std::exception& e)
+ {
+ auto_mutex lock(class_mutex);
+ error_message = e.what();
+ }
+ }
+
+ mutable typename trained_function_type::binary_function_table dfs;
+ mutex class_mutex;
+ mutable std::string error_message;
+
+ const std::vector<sample_type>& all_samples;
+ const std::vector<label_type>& all_labels;
+ const any_trainer& default_trainer;
+ const binary_function_table& trainers;
+ const bool verbose;
+ const std::vector<unordered_pair<label_type> >& pairs;
+ };
+
+
+ any_trainer default_trainer;
+ binary_function_table trainers;
+ bool verbose;
+ unsigned long num_threads;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ONE_TRAiNER_Hh_
+
diff --git a/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h b/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h
new file mode 100644
index 000000000..42ba35815
--- /dev/null
+++ b/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h
@@ -0,0 +1,166 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
+#ifdef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
+
+
+#include "one_vs_one_decision_function_abstract.h"
+#include <vector>
+
+#include "../any/any_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename any_trainer,
+ typename label_type_ = double
+ >
+ class one_vs_one_trainer
+ {
+ /*!
+ REQUIREMENTS ON any_trainer
+ must be an instantiation of the dlib::any_trainer template.
+
+ REQUIREMENTS ON label_type_
+ label_type_ must be default constructable, copyable, and comparable using
+ operator < and ==. It must also be possible to write it to an std::ostream
+ using operator<<.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for turning a bunch of binary classifiers
+ into a multiclass classifier. It does this by training the binary
+ classifiers in a one vs. one fashion. That is, if you have N possible
+ classes then it trains N*(N-1)/2 binary classifiers which are then used
+ to vote on the identity of a test sample.
+
+ This object works with any kind of binary classification trainer object
+ capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer)
+ !*/
+
+ public:
+
+
+ typedef label_type_ label_type;
+
+ typedef typename any_trainer::sample_type sample_type;
+ typedef typename any_trainer::scalar_type scalar_type;
+ typedef typename any_trainer::mem_manager_type mem_manager_type;
+
+ typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type;
+
+ one_vs_one_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized
+ - This object will not be verbose unless be_verbose() is called.
+ - No binary trainers are associated with *this. I.e. you have to
+ call set_trainer() before calling train().
+ - #get_num_threads() == 4
+ !*/
+
+ void set_trainer (
+ const any_trainer& trainer
+ );
+ /*!
+ ensures
+ - sets the trainer used for all pairs of training. Any previous
+ calls to set_trainer() are overridden by this function. Even the
+ more specific set_trainer(trainer, l1, l2) form.
+ !*/
+
+ void set_trainer (
+ const any_trainer& trainer,
+ const label_type& l1,
+ const label_type& l2
+ );
+ /*!
+ requires
+ - l1 != l2
+ ensures
+ - Sets the trainer object used to create a binary classifier to
+ distinguish l1 labeled samples from l2 labeled samples.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ struct invalid_label : public dlib::error
+ {
+ /*!
+ This is the exception thrown by the train() function below.
+ !*/
+ label_type l1, l2;
+ };
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(all_samples, all_labels)
+ ensures
+ - trains a bunch of binary classifiers in a one vs one fashion to solve the given
+ multiclass classification problem.
+ - returns a one_vs_one_decision_function F with the following properties:
+ - F contains all the learned binary classifiers and can be used to predict
+ the labels of new samples.
+ - if (new_x is a sample predicted to have a label of L) then
+ - F(new_x) == L
+ - F.get_labels() == select_all_distinct_labels(all_labels)
+ - F.number_of_classes() == select_all_distinct_labels(all_labels).size()
+ throws
+ - invalid_label
+ This exception is thrown if there are labels in all_labels which don't have
+ any corresponding trainer object. This will never happen if set_trainer(trainer)
+ has been called. However, if only the set_trainer(trainer,l1,l2) form has been
+ used then this exception is thrown if not all necessary label pairs have been
+ given a trainer.
+
+ invalid_label::l1 and invalid_label::l2 will contain the label pair which is
+ missing a trainer object. Additionally, the exception will contain an
+ informative error message available via invalid_label::what().
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/pegasos.h b/ml/dlib/dlib/svm/pegasos.h
new file mode 100644
index 000000000..c28093fe0
--- /dev/null
+++ b/ml/dlib/dlib/svm/pegasos.h
@@ -0,0 +1,710 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PEGASoS_
+#define DLIB_PEGASoS_
+
+#include "pegasos_abstract.h"
+#include <cmath>
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "kcentroid.h"
+#include <iostream>
+#include <memory>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_pegasos
+ {
+ typedef kcentroid<offset_kernel<K> > kc_type;
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ template <typename K_>
+ struct rebind {
+ typedef svm_pegasos<K_> other;
+ };
+
+ svm_pegasos (
+ ) :
+ max_sv(40),
+ lambda_c1(0.0001),
+ lambda_c2(0.0001),
+ tau(0.01),
+ tolerance(0.01),
+ train_count(0),
+ w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
+ {
+ }
+
+ svm_pegasos (
+ const kernel_type& kernel_,
+ const scalar_type& lambda_,
+ const scalar_type& tolerance_,
+ unsigned long max_num_sv
+ ) :
+ max_sv(max_num_sv),
+ kernel(kernel_),
+ lambda_c1(lambda_),
+ lambda_c2(lambda_),
+ tau(0.01),
+ tolerance(tolerance_),
+ train_count(0),
+ w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(lambda_ > 0 && tolerance > 0 && max_num_sv > 0,
+ "\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t lambda_: " << lambda_
+ << "\n\t max_num_sv: " << max_num_sv
+ );
+ }
+
+ void clear (
+ )
+ {
+ // reset the w vector back to its initial state
+ w = kc_type(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false);
+ train_count = 0;
+ }
+
+ void set_kernel (
+ kernel_type k
+ )
+ {
+ kernel = k;
+ clear();
+ }
+
+ void set_max_num_sv (
+ unsigned long max_num_sv
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_num_sv > 0,
+ "\tvoid svm_pegasos::set_max_num_sv(max_num_sv)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t max_num_sv: " << max_num_sv
+ );
+ max_sv = max_num_sv;
+ clear();
+ }
+
+ unsigned long get_max_num_sv (
+ ) const
+ {
+ return max_sv;
+ }
+
+ void set_tolerance (
+ double tol
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < tol,
+ "\tvoid svm_pegasos::set_tolerance(tol)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t tol: " << tol
+ );
+ tolerance = tol;
+ clear();
+ }
+
+ void set_lambda (
+ scalar_type lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < lambda_,
+ "\tvoid svm_pegasos::set_lambda(lambda_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t lambda_: " << lambda_
+ );
+ lambda_c1 = lambda_;
+ lambda_c2 = lambda_;
+
+ max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
+ clear();
+ }
+
+ void set_lambda_class1 (
+ scalar_type lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < lambda_,
+ "\tvoid svm_pegasos::set_lambda_class1(lambda_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t lambda_: " << lambda_
+ );
+ lambda_c1 = lambda_;
+ max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
+ clear();
+ }
+
+ void set_lambda_class2 (
+ scalar_type lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < lambda_,
+ "\tvoid svm_pegasos::set_lambda_class2(lambda_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t lambda_: " << lambda_
+ );
+ lambda_c2 = lambda_;
+ max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
+ clear();
+ }
+
+ const scalar_type get_lambda_class1 (
+ ) const
+ {
+ return lambda_c1;
+ }
+
+ const scalar_type get_lambda_class2 (
+ ) const
+ {
+ return lambda_c2;
+ }
+
+ const scalar_type get_tolerance (
+ ) const
+ {
+ return tolerance;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ unsigned long get_train_count (
+ ) const
+ {
+ return static_cast<unsigned long>(train_count);
+ }
+
+ scalar_type train (
+ const sample_type& x,
+ const scalar_type& y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(y == -1 || y == 1,
+ "\tscalar_type svm_pegasos::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t y: " << y
+ );
+
+ const double lambda = (y==+1)? lambda_c1 : lambda_c2;
+
+ ++train_count;
+ const scalar_type learning_rate = 1/(lambda*train_count);
+
+ // if this sample point is within the margin of the current hyperplane
+ if (y*w.inner_product(x) < 1)
+ {
+
+ // compute: w = (1-learning_rate*lambda)*w + y*learning_rate*x
+ w.train(x, 1 - learning_rate*lambda, y*learning_rate);
+
+ scalar_type wnorm = std::sqrt(w.squared_norm());
+ scalar_type temp = max_wnorm/wnorm;
+ if (temp < 1)
+ w.scale_by(temp);
+ }
+ else
+ {
+ w.scale_by(1 - learning_rate*lambda);
+ }
+
+ // return the current learning rate
+ return 1/(std::min(lambda_c1,lambda_c2)*train_count);
+ }
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const
+ {
+ return w.inner_product(x);
+ }
+
+ const decision_function<kernel_type> get_decision_function (
+ ) const
+ {
+ distance_function<offset_kernel<kernel_type> > df = w.get_distance_function();
+ return decision_function<kernel_type>(df.get_alpha(), -tau*sum(df.get_alpha()), kernel, df.get_basis_vectors());
+ }
+
+ void swap (
+ svm_pegasos& item
+ )
+ {
+ exchange(max_sv, item.max_sv);
+ exchange(kernel, item.kernel);
+ exchange(lambda_c1, item.lambda_c1);
+ exchange(lambda_c2, item.lambda_c2);
+ exchange(max_wnorm, item.max_wnorm);
+ exchange(tau, item.tau);
+ exchange(tolerance, item.tolerance);
+ exchange(train_count, item.train_count);
+ exchange(w, item.w);
+ }
+
+ friend void serialize(const svm_pegasos& item, std::ostream& out)
+ {
+ serialize(item.max_sv, out);
+ serialize(item.kernel, out);
+ serialize(item.lambda_c1, out);
+ serialize(item.lambda_c2, out);
+ serialize(item.max_wnorm, out);
+ serialize(item.tau, out);
+ serialize(item.tolerance, out);
+ serialize(item.train_count, out);
+ serialize(item.w, out);
+ }
+
+ friend void deserialize(svm_pegasos& item, std::istream& in)
+ {
+ deserialize(item.max_sv, in);
+ deserialize(item.kernel, in);
+ deserialize(item.lambda_c1, in);
+ deserialize(item.lambda_c2, in);
+ deserialize(item.max_wnorm, in);
+ deserialize(item.tau, in);
+ deserialize(item.tolerance, in);
+ deserialize(item.train_count, in);
+ deserialize(item.w, in);
+ }
+
+ private:
+
+ unsigned long max_sv;
+ kernel_type kernel;
+ scalar_type lambda_c1;
+ scalar_type lambda_c2;
+ scalar_type max_wnorm;
+ scalar_type tau;
+ scalar_type tolerance;
+ scalar_type train_count;
+ kc_type w;
+
+ }; // end of class svm_pegasos
+
+ template <
+ typename K
+ >
+ void swap (
+ svm_pegasos<K>& a,
+ svm_pegasos<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void replicate_settings (
+ const svm_pegasos<T>& source,
+ svm_pegasos<U>& dest
+ )
+ {
+ dest.set_tolerance(source.get_tolerance());
+ dest.set_lambda_class1(source.get_lambda_class1());
+ dest.set_lambda_class2(source.get_lambda_class2());
+ dest.set_max_num_sv(source.get_max_num_sv());
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class batch_trainer
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename sample_vector_type
+ >
+ class caching_kernel
+ {
+ public:
+ typedef typename K::scalar_type scalar_type;
+ typedef long sample_type;
+ //typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ caching_kernel () {}
+
+ caching_kernel (
+ const K& kern,
+ const sample_vector_type& samps,
+ long cache_size_
+ ) : real_kernel(kern), samples(&samps), counter(0)
+ {
+ cache_size = std::min<long>(cache_size_, samps.size());
+
+ cache.reset(new cache_type);
+ cache->frequency_of_use.resize(samps.size());
+ for (long i = 0; i < samps.size(); ++i)
+ cache->frequency_of_use[i] = std::make_pair(0, i);
+
+ // Set the cache build/rebuild threshold so that we have to have
+ // as many cache misses as there are entries in the cache before
+ // we build/rebuild.
+ counter_threshold = samps.size()*cache_size;
+ cache->sample_location.assign(samples->size(), -1);
+ }
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ // rebuild the cache every so often
+ if (counter > counter_threshold )
+ {
+ build_cache();
+ }
+
+ const long a_loc = cache->sample_location[a];
+ const long b_loc = cache->sample_location[b];
+
+ cache->frequency_of_use[a].first += 1;
+ cache->frequency_of_use[b].first += 1;
+
+ if (a_loc != -1)
+ {
+ return cache->kernel(a_loc, b);
+ }
+ else if (b_loc != -1)
+ {
+ return cache->kernel(b_loc, a);
+ }
+ else
+ {
+ ++counter;
+ return real_kernel((*samples)(a), (*samples)(b));
+ }
+ }
+
+ bool operator== (
+ const caching_kernel& item
+ ) const
+ {
+ return item.real_kernel == real_kernel &&
+ item.samples == samples;
+ }
+
+ private:
+ K real_kernel;
+
+ void build_cache (
+ ) const
+ {
+ std::sort(cache->frequency_of_use.rbegin(), cache->frequency_of_use.rend());
+ counter = 0;
+
+
+ cache->kernel.set_size(cache_size, samples->size());
+ cache->sample_location.assign(samples->size(), -1);
+
+ // loop over all the samples in the cache
+ for (long i = 0; i < cache_size; ++i)
+ {
+ const long cur = cache->frequency_of_use[i].second;
+ cache->sample_location[cur] = i;
+
+ // now populate all possible kernel products with the current sample
+ for (long j = 0; j < samples->size(); ++j)
+ {
+ cache->kernel(i, j) = real_kernel((*samples)(cur), (*samples)(j));
+ }
+
+ }
+
+ // reset the frequency of use metrics
+ for (long i = 0; i < samples->size(); ++i)
+ cache->frequency_of_use[i] = std::make_pair(0, i);
+ }
+
+
+ struct cache_type
+ {
+ matrix<scalar_type> kernel;
+
+ std::vector<long> sample_location; // where in the cache a sample is. -1 means not in cache
+ std::vector<std::pair<long,long> > frequency_of_use;
+ };
+
+ const sample_vector_type* samples = 0;
+
+ std::shared_ptr<cache_type> cache;
+ mutable unsigned long counter = 0;
+ unsigned long counter_threshold = 0;
+ long cache_size = 0;
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+
+ batch_trainer (
+ ) :
+ min_learning_rate(0.1),
+ use_cache(false),
+ cache_size(100)
+ {
+ }
+
+ batch_trainer (
+ const trainer_type& trainer_,
+ const scalar_type min_learning_rate_,
+ bool verbose_,
+ bool use_cache_,
+ long cache_size_ = 100
+ ) :
+ trainer(trainer_),
+ min_learning_rate(min_learning_rate_),
+ verbose(verbose_),
+ use_cache(use_cache_),
+ cache_size(cache_size_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < min_learning_rate_ &&
+ cache_size_ > 0,
+ "\tbatch_trainer::batch_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t min_learning_rate_: " << min_learning_rate_
+ << "\n\t cache_size_: " << cache_size_
+ );
+
+ trainer.clear();
+ }
+
+ const scalar_type get_min_learning_rate (
+ ) const
+ {
+ return min_learning_rate;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ if (use_cache)
+ return do_train_cached(mat(x), mat(y));
+ else
+ return do_train(mat(x), mat(y));
+ }
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+
+ dlib::rand rnd;
+
+ trainer_type my_trainer(trainer);
+
+ scalar_type cur_learning_rate = min_learning_rate + 10;
+ unsigned long count = 0;
+
+ while (cur_learning_rate > min_learning_rate)
+ {
+ const long i = rnd.get_random_32bit_number()%x.size();
+ // keep feeding the trainer data until its learning rate goes below our threshold
+ cur_learning_rate = my_trainer.train(x(i), y(i));
+
+ if (verbose)
+ {
+ if ( (count&0x7FF) == 0)
+ {
+ std::cout << "\rbatch_trainer(): Percent complete: "
+ << 100*min_learning_rate/cur_learning_rate << " " << std::flush;
+ }
+ ++count;
+ }
+ }
+
+ if (verbose)
+ {
+ decision_function<kernel_type> df = my_trainer.get_decision_function();
+ std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl;
+ std::cout << " Num sv: " << df.basis_vectors.size() << std::endl;
+ std::cout << " bias: " << df.b << std::endl;
+ return df;
+ }
+ else
+ {
+ return my_trainer.get_decision_function();
+ }
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train_cached (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+
+ dlib::rand rnd;
+
+ // make a caching kernel
+ typedef caching_kernel<kernel_type, in_sample_vector_type> ckernel_type;
+ ckernel_type ck(trainer.get_kernel(), x, cache_size);
+
+ // now rebind the trainer to use the caching kernel
+ typedef typename trainer_type::template rebind<ckernel_type>::other rebound_trainer_type;
+ rebound_trainer_type my_trainer;
+ my_trainer.set_kernel(ck);
+ replicate_settings(trainer, my_trainer);
+
+ scalar_type cur_learning_rate = min_learning_rate + 10;
+ unsigned long count = 0;
+
+ while (cur_learning_rate > min_learning_rate)
+ {
+ const long i = rnd.get_random_32bit_number()%x.size();
+ // keep feeding the trainer data until its learning rate goes below our threshold
+ cur_learning_rate = my_trainer.train(i, y(i));
+
+ if (verbose)
+ {
+ if ( (count&0x7FF) == 0)
+ {
+ std::cout << "\rbatch_trainer(): Percent complete: "
+ << 100*min_learning_rate/cur_learning_rate << " " << std::flush;
+ }
+ ++count;
+ }
+ }
+
+ if (verbose)
+ {
+ decision_function<ckernel_type> cached_df;
+ cached_df = my_trainer.get_decision_function();
+
+ std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl;
+ std::cout << " Num sv: " << cached_df.basis_vectors.size() << std::endl;
+ std::cout << " bias: " << cached_df.b << std::endl;
+
+ return decision_function<kernel_type> (
+ cached_df.alpha,
+ cached_df.b,
+ trainer.get_kernel(),
+ rowm(x, cached_df.basis_vectors)
+ );
+ }
+ else
+ {
+ decision_function<ckernel_type> cached_df;
+ cached_df = my_trainer.get_decision_function();
+
+ return decision_function<kernel_type> (
+ cached_df.alpha,
+ cached_df.b,
+ trainer.get_kernel(),
+ rowm(x, cached_df.basis_vectors)
+ );
+ }
+ }
+
+ trainer_type trainer;
+ scalar_type min_learning_rate;
+ bool verbose;
+ bool use_cache;
+ long cache_size;
+
+ }; // end of class batch_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> batch (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, false); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> verbose_batch (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, false); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> batch_cached (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1,
+ long cache_size = 100
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, true, cache_size); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> verbose_batch_cached (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1,
+ long cache_size = 100
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, true, cache_size); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PEGASoS_
+
diff --git a/ml/dlib/dlib/svm/pegasos_abstract.h b/ml/dlib/dlib/svm/pegasos_abstract.h
new file mode 100644
index 000000000..008b1cb94
--- /dev/null
+++ b/ml/dlib/dlib/svm/pegasos_abstract.h
@@ -0,0 +1,514 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_PEGASoS_ABSTRACT_
+#ifdef DLIB_PEGASoS_ABSTRACT_
+
+#include <cmath>
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ class svm_pegasos
+ {
+ /*!
+ REQUIREMENTS ON kern_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements an online algorithm for training a support
+ vector machine for solving binary classification problems.
+
+ The implementation of the Pegasos algorithm used by this object is based
+ on the following excellent paper:
+ Pegasos: Primal estimated sub-gradient solver for SVM (2007)
+ by Shai Shalev-Shwartz, Yoram Singer, Nathan Srebro
+ In ICML
+
+ This SVM training algorithm has two interesting properties. First, the
+ pegasos algorithm itself converges to the solution in an amount of time
+ unrelated to the size of the training set (in addition to being quite fast
+ to begin with). This makes it an appropriate algorithm for learning from
+ very large datasets. Second, this object uses the dlib::kcentroid object
+ to maintain a sparse approximation of the learned decision function.
+ This means that the number of support vectors in the resulting decision
+ function is also unrelated to the size of the dataset (in normal SVM
+ training algorithms, the number of support vectors grows approximately
+ linearly with the size of the training set).
+ !*/
+
+ public:
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ template <typename K_>
+ struct rebind {
+ typedef svm_pegasos<K_> other;
+ };
+
+ svm_pegasos (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_lambda_class1() == 0.0001
+ - #get_lambda_class2() == 0.0001
+ - #get_tolerance() == 0.01
+ - #get_train_count() == 0
+ - #get_max_num_sv() == 40
+ !*/
+
+ svm_pegasos (
+ const kernel_type& kernel_,
+ const scalar_type& lambda_,
+ const scalar_type& tolerance_,
+ unsigned long max_num_sv
+ );
+ /*!
+ requires
+ - lambda_ > 0
+ - tolerance_ > 0
+ - max_num_sv > 0
+ ensures
+ - this object is properly initialized
+ - #get_lambda_class1() == lambda_
+ - #get_lambda_class2() == lambda_
+ - #get_tolerance() == tolerance_
+ - #get_kernel() == kernel_
+ - #get_train_count() == 0
+ - #get_max_num_sv() == max_num_sv
+ !*/
+
+ void clear (
+ );
+ /*!
+ ensures
+ - #get_train_count() == 0
+ - clears out any memory of previous calls to train()
+ - doesn't change any of the algorithm parameters. I.e.
+ - #get_lambda_class1() == get_lambda_class1()
+ - #get_lambda_class2() == get_lambda_class2()
+ - #get_tolerance() == get_tolerance()
+ - #get_kernel() == get_kernel()
+ - #get_max_num_sv() == get_max_num_sv()
+ !*/
+
+ const scalar_type get_lambda_class1 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization term for the +1 class. It is the
+ parameter that determines the trade off between trying to fit the
+ +1 training data exactly or allowing more errors but hopefully
+ improving the generalization ability of the resulting classifier.
+ Smaller values encourage exact fitting while larger values may
+ encourage better generalization. It is also worth noting that the
+ number of iterations it takes for this algorithm to converge is
+ proportional to 1/lambda. So smaller values of this term cause
+ the running time of this algorithm to increase. For more
+ information you should consult the paper referenced above.
+ !*/
+
+ const scalar_type get_lambda_class2 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization term for the -1 class. It has
+ the same properties as the get_lambda_class1() parameter except that
+ it applies to the -1 class.
+ !*/
+
+ const scalar_type get_tolerance (
+ ) const;
+ /*!
+ ensures
+ - returns the tolerance used by the internal kcentroid object to
+ represent the learned decision function. Smaller values of this
+ tolerance will result in a more accurate representation of the
+ decision function but will use more support vectors (up to
+ a max of get_max_num_sv()).
+ !*/
+
+ unsigned long get_max_num_sv (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of support vectors this object is
+ allowed to use.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns the kernel used by this object
+ !*/
+
+ void set_kernel (
+ kernel_type k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ - #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ void set_tolerance (
+ double tol
+ );
+ /*!
+ requires
+ - tol > 0
+ ensures
+ - #get_tolerance() == tol
+ - #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ void set_max_num_sv (
+ unsigned long max_num_sv
+ );
+ /*!
+ requires
+ - max_num_sv > 0
+ ensures
+ - #get_max_num_sv() == max_num_sv
+ - #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ void set_lambda (
+ scalar_type lambda_
+ );
+ /*!
+ requires
+ - lambda_ > 0
+ ensures
+ - #get_lambda_class1() == lambda_
+ - #get_lambda_class2() == lambda_
+ - #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ void set_lambda_class1 (
+ scalar_type lambda_
+ );
+ /*!
+ requires
+ - lambda_ > 0
+ ensures
+ - #get_lambda_class1() == lambda_
+ #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ void set_lambda_class2 (
+ scalar_type lambda_
+ );
+ /*!
+ requires
+ - lambda_ > 0
+ ensures
+ - #get_lambda_class2() == lambda_
+ #get_train_count() == 0
+ (i.e. clears any memory of previous training)
+ !*/
+
+ unsigned long get_train_count (
+ ) const;
+ /*!
+ ensures
+ - returns how many times this->train() has been called
+ since this object was constructed or last cleared.
+ !*/
+
+ scalar_type train (
+ const sample_type& x,
+ const scalar_type& y
+ );
+ /*!
+ requires
+ - y == 1 || y == -1
+ ensures
+ - trains this svm using the given sample x and label y
+ - #get_train_count() == get_train_count() + 1
+ - returns the current learning rate
+ (i.e. 1/(get_train_count()*min(get_lambda_class1(),get_lambda_class2())) )
+ !*/
+
+ scalar_type operator() (
+ const sample_type& x
+ ) const;
+ /*!
+ ensures
+ - classifies the given x sample using the decision function
+ this object has learned so far.
+ - if (x is a sample predicted have +1 label) then
+ - returns a number >= 0
+ - else
+ - returns a number < 0
+ !*/
+
+ const decision_function<kernel_type> get_decision_function (
+ ) const;
+ /*!
+ ensures
+ - returns a decision function F that represents the function learned
+ by this object so far. I.e. it is the case that:
+ - for all x: F(x) == (*this)(x)
+ !*/
+
+ void swap (
+ svm_pegasos& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ void swap(
+ svm_pegasos<kern_type>& a,
+ svm_pegasos<kern_type>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ template <
+ typename kern_type
+ >
+ void serialize (
+ const svm_pegasos<kern_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for svm_pegasos objects
+ !*/
+
+ template <
+ typename kern_type
+ >
+ void deserialize (
+ svm_pegasos<kern_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides serialization support for svm_pegasos objects
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void replicate_settings (
+ const svm_pegasos<T>& source,
+ svm_pegasos<U>& dest
+ );
+ /*!
+ ensures
+ - copies all the parameters from the source trainer to the dest trainer.
+ - #dest.get_tolerance() == source.get_tolerance()
+ - #dest.get_lambda_class1() == source.get_lambda_class1()
+ - #dest.get_lambda_class2() == source.get_lambda_class2()
+ - #dest.get_max_num_sv() == source.get_max_num_sv()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class batch_trainer
+ {
+ /*!
+ REQUIREMENTS ON trainer_type
+ - trainer_type == some kind of online trainer object (e.g. svm_pegasos)
+ replicate_settings() must also be defined for the type.
+
+ WHAT THIS OBJECT REPRESENTS
+ This is a trainer object that is meant to wrap online trainer objects
+ that create decision_functions. It turns an online learning algorithm
+ such as svm_pegasos into a batch learning object. This allows you to
+ use objects like svm_pegasos with functions (e.g. cross_validate_trainer)
+ that expect batch mode training objects.
+ !*/
+
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+
+ batch_trainer (
+ );
+ /*!
+ ensures
+ - This object is in an uninitialized state. You must
+ construct a real one with the other constructor and assign it
+ to this instance before you use this object.
+ !*/
+
+ batch_trainer (
+ const trainer_type& online_trainer,
+ const scalar_type min_learning_rate_,
+ bool verbose_,
+ bool use_cache_,
+ long cache_size_ = 100
+ );
+ /*!
+ requires
+ - min_learning_rate_ > 0
+ - cache_size_ > 0
+ ensures
+ - returns a batch trainer object that uses the given online_trainer object
+ to train a decision function.
+ - #get_min_learning_rate() == min_learning_rate_
+ - if (verbose_ == true) then
+ - this object will output status messages to standard out while
+ training is under way.
+ - if (use_cache_ == true) then
+ - this object will cache up to cache_size_ columns of the kernel
+ matrix during the training process.
+ !*/
+
+ const scalar_type get_min_learning_rate (
+ ) const;
+ /*!
+ ensures
+ - returns the min learning rate that the online trainer must reach
+ before this object considers training to be complete.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ ensures
+ - trains and returns a decision_function using the trainer that was
+ supplied to this object's constructor.
+ - training continues until the online training object indicates that
+ its learning rate has dropped below get_min_learning_rate().
+ throws
+ - std::bad_alloc
+ - any exceptions thrown by the trainer_type object
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> batch (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, false); }
+ /*!
+ requires
+ - min_learning_rate > 0
+ - trainer_type == some kind of online trainer object that creates decision_function
+ objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
+ ensures
+ - returns a batch_trainer object that has been instantiated with the
+ given arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> verbose_batch (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, false); }
+ /*!
+ requires
+ - min_learning_rate > 0
+ - trainer_type == some kind of online trainer object that creates decision_function
+ objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
+ ensures
+ - returns a batch_trainer object that has been instantiated with the
+ given arguments (and is verbose).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> batch_cached (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1,
+ long cache_size = 100
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, true, cache_size); }
+ /*!
+ requires
+ - min_learning_rate > 0
+ - cache_size > 0
+ - trainer_type == some kind of online trainer object that creates decision_function
+ objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
+ ensures
+ - returns a batch_trainer object that has been instantiated with the
+ given arguments (uses a kernel cache).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const batch_trainer<trainer_type> verbose_batch_cached (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type min_learning_rate = 0.1,
+ long cache_size = 100
+ ) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, true, cache_size); }
+ /*!
+ requires
+ - min_learning_rate > 0
+ - cache_size > 0
+ - trainer_type == some kind of online trainer object that creates decision_function
+ objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
+ ensures
+ - returns a batch_trainer object that has been instantiated with the
+ given arguments (is verbose and uses a kernel cache).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_PEGASoS_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/svm/ranking_tools.h b/ml/dlib/dlib/svm/ranking_tools.h
new file mode 100644
index 000000000..3c77b41ae
--- /dev/null
+++ b/ml/dlib/dlib/svm/ranking_tools.h
@@ -0,0 +1,448 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANKING_ToOLS_Hh_
+#define DLIB_RANKING_ToOLS_Hh_
+
+#include "ranking_tools_abstract.h"
+
+#include "../algs.h"
+#include "../matrix.h"
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "sparse_vector.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct ranking_pair
+ {
+ ranking_pair() {}
+
+ ranking_pair(
+ const std::vector<T>& r,
+ const std::vector<T>& nr
+ ) :
+ relevant(r), nonrelevant(nr)
+ {}
+
+ std::vector<T> relevant;
+ std::vector<T> nonrelevant;
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const ranking_pair<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.relevant, out);
+ serialize(item.nonrelevant, out);
+ }
+
+
+ template <
+ typename T
+ >
+ void deserialize (
+ ranking_pair<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Wrong version found while deserializing dlib::ranking_pair");
+
+ deserialize(item.relevant, in);
+ deserialize(item.nonrelevant, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename disable_if<is_matrix<T>,bool>::type is_ranking_problem (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ if (samples.size() == 0)
+ return false;
+
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].relevant.size() == 0)
+ return false;
+ if (samples[i].nonrelevant.size() == 0)
+ return false;
+ }
+
+ return true;
+ }
+
+ template <
+ typename T
+ >
+ typename enable_if<is_matrix<T>,bool>::type is_ranking_problem (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ if (samples.size() == 0)
+ return false;
+
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].relevant.size() == 0)
+ return false;
+ if (samples[i].nonrelevant.size() == 0)
+ return false;
+ }
+
+ // If these are dense vectors then they must all have the same dimensionality.
+ const long dims = max_index_plus_one(samples[0].relevant);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (unsigned long j = 0; j < samples[i].relevant.size(); ++j)
+ {
+ if (is_vector(samples[i].relevant[j]) == false)
+ return false;
+
+ if (samples[i].relevant[j].size() != dims)
+ return false;
+ }
+ for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j)
+ {
+ if (is_vector(samples[i].nonrelevant[j]) == false)
+ return false;
+
+ if (samples[i].nonrelevant[j].size() != dims)
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const ranking_pair<T>& item
+ )
+ {
+ return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant));
+ }
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ unsigned long dims = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ dims = std::max(dims, max_index_plus_one(samples[i]));
+ }
+ return dims;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void count_ranking_inversions (
+ const std::vector<T>& x,
+ const std::vector<T>& y,
+ std::vector<unsigned long>& x_count,
+ std::vector<unsigned long>& y_count
+ )
+ {
+ x_count.assign(x.size(),0);
+ y_count.assign(y.size(),0);
+
+ if (x.size() == 0 || y.size() == 0)
+ return;
+
+ std::vector<std::pair<T,unsigned long> > xsort(x.size());
+ std::vector<std::pair<T,unsigned long> > ysort(y.size());
+ for (unsigned long i = 0; i < x.size(); ++i)
+ xsort[i] = std::make_pair(x[i], i);
+ for (unsigned long j = 0; j < y.size(); ++j)
+ ysort[j] = std::make_pair(y[j], j);
+
+ std::sort(xsort.begin(), xsort.end());
+ std::sort(ysort.begin(), ysort.end());
+
+
+ unsigned long i, j;
+
+ // Do the counting for the x values.
+ for (i = 0, j = 0; i < x_count.size(); ++i)
+ {
+ // Skip past y values that are in the correct order with respect to xsort[i].
+ while (j < ysort.size() && ysort[j].first < xsort[i].first)
+ ++j;
+
+ x_count[xsort[i].second] = ysort.size() - j;
+ }
+
+
+ // Now do the counting for the y values.
+ for (i = 0, j = 0; j < y_count.size(); ++j)
+ {
+ // Skip past x values that are in the incorrect order with respect to ysort[j].
+ while (i < xsort.size() && !(ysort[j].first < xsort[i].first))
+ ++i;
+
+ y_count[ysort[j].second] = i;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool compare_first_reverse_second (
+ const std::pair<double,bool>& a,
+ const std::pair<double,bool>& b
+ )
+ {
+ if (a.first < b.first)
+ return true;
+ else if (a.first > b.first)
+ return false;
+ else if (a.second && !b.second)
+ return true;
+ else
+ return false;
+ }
+ }
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ranking_problem(samples),
+ "\t double test_ranking_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
+ );
+
+ unsigned long total_pairs = 0;
+ unsigned long total_wrong = 0;
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+
+ running_stats<double> rs;
+ std::vector<std::pair<double,bool> > total_scores;
+ std::vector<bool> total_ranking;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ rel_scores.resize(samples[i].relevant.size());
+ nonrel_scores.resize(samples[i].nonrelevant.size());
+ total_scores.clear();
+
+ for (unsigned long k = 0; k < rel_scores.size(); ++k)
+ {
+ rel_scores[k] = funct(samples[i].relevant[k]);
+ total_scores.push_back(std::make_pair(rel_scores[k], true));
+ }
+
+ for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+ {
+ nonrel_scores[k] = funct(samples[i].nonrelevant[k]);
+ total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+ }
+
+ // Now compute the average precision for this sample. We need to sort the
+ // results and the back them into total_ranking. Note that we sort them so
+ // that, if you get a block of ranking values that are all equal, the elements
+ // marked as true will come last. This prevents a ranking from outputting a
+ // constant value for everything and still getting a good MAP score.
+ std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
+ total_ranking.clear();
+ for (unsigned long i = 0; i < total_scores.size(); ++i)
+ total_ranking.push_back(total_scores[i].second);
+ rs.add(average_precision(total_ranking));
+
+
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+
+ total_pairs += rel_scores.size()*nonrel_scores.size();
+
+ // Note that we don't need to look at nonrel_counts since it is redundant with
+ // the information in rel_counts in this case.
+ total_wrong += sum(mat(rel_counts));
+ }
+
+ const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+ const double mean_average_precision = rs.mean();
+ matrix<double,1,2> res;
+ res = rank_swaps, mean_average_precision;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const ranking_pair<T>& sample
+ )
+ {
+ return test_ranking_function(funct, std::vector<ranking_pair<T> >(1,sample));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename T
+ >
+ matrix<double,1,2> cross_validate_ranking_trainer (
+ const trainer_type& trainer,
+ const std::vector<ranking_pair<T> >& samples,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ranking_problem(samples) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_ranking_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
+ );
+
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+
+ std::vector<ranking_pair<T> > samples_test, samples_train;
+
+
+ long next_test_idx = 0;
+
+ unsigned long total_pairs = 0;
+ unsigned long total_wrong = 0;
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+
+ running_stats<double> rs;
+ std::vector<std::pair<double,bool> > total_scores;
+ std::vector<bool> total_ranking;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_test.clear();
+ samples_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ samples_test.push_back(samples[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ samples_train.push_back(samples[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ const typename trainer_type::trained_function_type& df = trainer.train(samples_train);
+
+ // check how good df is on the test data
+ for (unsigned long i = 0; i < samples_test.size(); ++i)
+ {
+ rel_scores.resize(samples_test[i].relevant.size());
+ nonrel_scores.resize(samples_test[i].nonrelevant.size());
+
+ total_scores.clear();
+
+ for (unsigned long k = 0; k < rel_scores.size(); ++k)
+ {
+ rel_scores[k] = df(samples_test[i].relevant[k]);
+ total_scores.push_back(std::make_pair(rel_scores[k], true));
+ }
+
+ for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+ {
+ nonrel_scores[k] = df(samples_test[i].nonrelevant[k]);
+ total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+ }
+
+ // Now compute the average precision for this sample. We need to sort the
+ // results and the back them into total_ranking. Note that we sort them so
+ // that, if you get a block of ranking values that are all equal, the elements
+ // marked as true will come last. This prevents a ranking from outputting a
+ // constant value for everything and still getting a good MAP score.
+ std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
+ total_ranking.clear();
+ for (unsigned long i = 0; i < total_scores.size(); ++i)
+ total_ranking.push_back(total_scores[i].second);
+ rs.add(average_precision(total_ranking));
+
+
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+
+ total_pairs += rel_scores.size()*nonrel_scores.size();
+
+ // Note that we don't need to look at nonrel_counts since it is redundant with
+ // the information in rel_counts in this case.
+ total_wrong += sum(mat(rel_counts));
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+ const double mean_average_precision = rs.mean();
+ matrix<double,1,2> res;
+ res = rank_swaps, mean_average_precision;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANKING_ToOLS_Hh_
+
diff --git a/ml/dlib/dlib/svm/ranking_tools_abstract.h b/ml/dlib/dlib/svm/ranking_tools_abstract.h
new file mode 100644
index 000000000..af6c7a2e3
--- /dev/null
+++ b/ml/dlib/dlib/svm/ranking_tools_abstract.h
@@ -0,0 +1,247 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RANKING_ToOLS_ABSTRACT_Hh_
+#ifdef DLIB_RANKING_ToOLS_ABSTRACT_Hh_
+
+
+#include "../algs.h"
+#include "../matrix.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct ranking_pair
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is used to contain a ranking example. In particular, we say
+ that a good ranking of T objects is one in which all the elements in
+ this->relevant are ranked higher than the elements of this->nonrelevant.
+ Therefore, ranking_pair objects are used to represent training examples for
+ learning-to-rank tasks.
+ !*/
+
+ ranking_pair() {}
+ /*!
+ ensures
+ - #relevant.size() == 0
+ - #nonrelevant.size() == 0
+ !*/
+
+ ranking_pair(
+ const std::vector<T>& r,
+ const std::vector<T>& nr
+ ) : relevant(r), nonrelevant(nr) {}
+ /*!
+ ensures
+ - #relevant == r
+ - #nonrelevant == nr
+ !*/
+
+ std::vector<T> relevant;
+ std::vector<T> nonrelevant;
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const ranking_pair<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ ranking_pair<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool is_ranking_problem (
+ const std::vector<ranking_pair<T> >& samples
+ );
+ /*!
+ ensures
+ - returns true if the data in samples represents a valid learning-to-rank
+ learning problem. That is, this function returns true if all of the
+ following are true and false otherwise:
+ - samples.size() > 0
+ - for all valid i:
+ - samples[i].relevant.size() > 0
+ - samples[i].nonrelevant.size() > 0
+ - if (is_matrix<T>::value == true) then
+ - All the elements of samples::nonrelevant and samples::relevant must
+ represent row or column vectors and they must be the same dimension.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const ranking_pair<T>& item
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix capable of storing column vectors or T must be a
+ sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+ ensures
+ - returns std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant)).
+ Therefore, this function can be used to find the dimensionality of the
+ vectors stored in item.
+ !*/
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const std::vector<ranking_pair<T> >& samples
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix capable of storing column vectors or T must be a
+ sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+ ensures
+ - returns the maximum of max_index_plus_one(samples[i]) over all valid values
+ of i. Therefore, this function can be used to find the dimensionality of the
+ vectors stored in samples
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void count_ranking_inversions (
+ const std::vector<T>& x,
+ const std::vector<T>& y,
+ std::vector<unsigned long>& x_count,
+ std::vector<unsigned long>& y_count
+ );
+ /*!
+ requires
+ - T objects must be copyable
+ - T objects must be comparable via operator<
+ ensures
+ - This function counts how many times we see a y value greater than or equal to
+ an x value. This is done efficiently in O(n*log(n)) time via the use of
+ quick sort.
+ - #x_count.size() == x.size()
+ - #y_count.size() == y.size()
+ - for all valid i:
+ - #x_count[i] == how many times a value in y was >= x[i].
+ - for all valid j:
+ - #y_count[j] == how many times a value in x was <= y[j].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const std::vector<ranking_pair<T> >& samples
+ );
+ /*!
+ requires
+ - is_ranking_problem(samples) == true
+ - ranking_function == some kind of decision function object (e.g. decision_function)
+ ensures
+ - Tests the given ranking function on the supplied example ranking data and
+ returns the fraction of ranking pair orderings predicted correctly. This is
+ a number in the range [0,1] where 0 means everything was incorrectly
+ predicted while 1 means everything was correctly predicted. This function
+ also returns the mean average precision.
+ - In particular, this function returns a matrix M summarizing the results.
+ Specifically, it returns an M such that:
+ - M(0) == the fraction of times that the following is true:
+ - funct(samples[k].relevant[i]) > funct(samples[k].nonrelevant[j])
+ (for all valid i,j,k)
+ - M(1) == the mean average precision of the rankings induced by funct.
+ (Mean average precision is a number in the range 0 to 1. Moreover, a
+ mean average precision of 1 means everything was correctly predicted
+ while smaller values indicate worse rankings. See the documentation
+ for average_precision() for details of its computation.)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const ranking_pair<T>& sample
+ );
+ /*!
+ requires
+ - is_ranking_problem(std::vector<ranking_pair<T> >(1, sample)) == true
+ - ranking_function == some kind of decision function object (e.g. decision_function)
+ ensures
+ - This is just a convenience routine for calling the above
+ test_ranking_function() routine. That is, it just copies sample into a
+ std::vector object and invokes the above test_ranking_function() routine.
+ This means that calling this function is equivalent to invoking:
+ return test_ranking_function(funct, std::vector<ranking_pair<T> >(1, sample));
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename T
+ >
+ matrix<double,1,2> cross_validate_ranking_trainer (
+ const trainer_type& trainer,
+ const std::vector<ranking_pair<T> >& samples,
+ const long folds
+ );
+ /*!
+ requires
+ - is_ranking_problem(samples) == true
+ - 1 < folds <= samples.size()
+ - trainer_type == some kind of ranking trainer object (e.g. svm_rank_trainer)
+ ensures
+ - Performs k-fold cross validation by using the given trainer to solve the
+ given ranking problem for the given number of folds. Each fold is tested
+ using the output of the trainer and the average ranking accuracy as well as
+ the mean average precision over the number of folds is returned.
+ - The accuracy is computed the same way test_ranking_function() computes its
+ accuracy. Therefore, it is a number in the range [0,1] that represents the
+ fraction of times a ranking pair's ordering was predicted correctly. Similarly,
+ the mean average precision is computed identically to test_ranking_function().
+ In particular, this means that this function returns a matrix M such that:
+ - M(0) == the ranking accuracy
+ - M(1) == the mean average precision
+ - The number of folds used is given by the folds argument.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANKING_ToOLS_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/rbf_network.h b/ml/dlib/dlib/svm/rbf_network.h
new file mode 100644
index 000000000..23a2c7424
--- /dev/null
+++ b/ml/dlib/dlib/svm/rbf_network.h
@@ -0,0 +1,162 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RBf_NETWORK_
+#define DLIB_RBf_NETWORK_
+
+#include "../matrix.h"
+#include "rbf_network_abstract.h"
+#include "kernel.h"
+#include "linearly_independent_subset_finder.h"
+#include "function.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ------------------------------------------------------------------------------
+
+ template <
+ typename Kern
+ >
+ class rbf_network_trainer
+ {
+ /*!
+ This is an implementation of an RBF network trainer that follows
+ the directions right off Wikipedia basically. So nothing
+ particularly fancy. Although the way the centers are selected
+ is somewhat unique.
+ !*/
+
+ public:
+ typedef Kern kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rbf_network_trainer (
+ ) :
+ num_centers(10)
+ {
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ void set_num_centers (
+ const unsigned long num
+ )
+ {
+ num_centers = num;
+ }
+
+ unsigned long get_num_centers (
+ ) const
+ {
+ return num_centers;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ return do_train(mat(x), mat(y));
+ }
+
+ void swap (
+ rbf_network_trainer& item
+ )
+ {
+ exchange(kernel, item.kernel);
+ exchange(num_centers, item.num_centers);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y),
+ "\tdecision_function rbf_network_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ );
+
+ // use the linearly_independent_subset_finder object to select the centers. So here
+ // we show it all the data samples so it can find the best centers.
+ linearly_independent_subset_finder<kernel_type> lisf(kernel, num_centers);
+ fill_lisf(lisf, x);
+
+ const long num_centers = lisf.size();
+
+ // fill the K matrix with the output of the kernel for all the center and sample point pairs
+ matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ for (long c = 0; c < num_centers; ++c)
+ {
+ K(r,c) = kernel(x(r), lisf[c]);
+ }
+ // This last column of the K matrix takes care of the bias term
+ K(r,num_centers) = 1;
+ }
+
+ // compute the best weights by using the pseudo-inverse
+ scalar_vector_type weights(pinv(K)*y);
+
+ // now put everything into a decision_function object and return it
+ return decision_function<kernel_type> (remove_row(weights,num_centers),
+ -weights(num_centers),
+ kernel,
+ lisf.get_dictionary());
+
+ }
+
+ kernel_type kernel;
+ unsigned long num_centers;
+
+ }; // end of class rbf_network_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type>
+ void swap (
+ rbf_network_trainer<sample_type>& a,
+ rbf_network_trainer<sample_type>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RBf_NETWORK_
+
diff --git a/ml/dlib/dlib/svm/rbf_network_abstract.h b/ml/dlib/dlib/svm/rbf_network_abstract.h
new file mode 100644
index 000000000..782a4bdbd
--- /dev/null
+++ b/ml/dlib/dlib/svm/rbf_network_abstract.h
@@ -0,0 +1,132 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RBf_NETWORK_ABSTRACT_
+#ifdef DLIB_RBf_NETWORK_ABSTRACT_
+
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class rbf_network_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+ (since this is supposed to be a RBF network it is probably reasonable
+ to use some sort of radial basis kernel)
+
+ INITIAL VALUE
+ - get_num_centers() == 10
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a radial basis function network.
+
+ The implementation of this algorithm follows the normal RBF training
+ process. For more details see the code or the Wikipedia article
+ about RBF networks.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rbf_network_trainer (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_num_centers (
+ const unsigned long num_centers
+ );
+ /*!
+ ensures
+ - #get_num_centers() == num_centers
+ !*/
+
+ const unsigned long get_num_centers (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of centers (a.k.a. basis_vectors in the
+ trained decision_function) you will get when you train this object on data.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ /*!
+ requires
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - is_learning_problem(x,y) == true
+ ensures
+ - trains a RBF network given the training samples in x and
+ labels in y and returns the resulting decision_function
+ throws
+ - std::bad_alloc
+ !*/
+
+ void swap (
+ rbf_network_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rbf_network_trainer<K>& a,
+ rbf_network_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RBf_NETWORK_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/reduced.h b/ml/dlib/dlib/svm/reduced.h
new file mode 100644
index 000000000..b4c5b63ca
--- /dev/null
+++ b/ml/dlib/dlib/svm/reduced.h
@@ -0,0 +1,613 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_REDUCEd_TRAINERS_
+#define DLIB_REDUCEd_TRAINERS_
+
+#include "reduced_abstract.h"
+#include "../matrix.h"
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "kcentroid.h"
+#include "linearly_independent_subset_finder.h"
+#include "../optimization.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class reduced_decision_function_trainer
+ {
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ reduced_decision_function_trainer (
+ ) :num_bv(0) {}
+
+ reduced_decision_function_trainer (
+ const trainer_type& trainer_,
+ const unsigned long num_sb_
+ ) :
+ trainer(trainer_),
+ num_bv(num_sb_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0,
+ "\t reduced_decision_function_trainer()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t num_bv: " << num_bv
+ );
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0,
+ "\t reduced_decision_function_trainer::train(x,y)"
+ << "\n\t You have tried to use an uninitialized version of this object"
+ << "\n\t num_bv: " << num_bv );
+ return do_train(mat(x), mat(y));
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ // get the decision function object we are going to try and approximate
+ const decision_function<kernel_type>& dec_funct = trainer.train(x,y);
+
+ // now find a linearly independent subset of the training points of num_bv points.
+ linearly_independent_subset_finder<kernel_type> lisf(dec_funct.kernel_function, num_bv);
+ fill_lisf(lisf, x);
+
+ // The next few statements just find the best weights with which to approximate
+ // the dec_funct object with the smaller set of vectors in the lisf dictionary. This
+ // is really just a simple application of some linear algebra. For the details
+ // see page 554 of Learning with kernels by Scholkopf and Smola where they talk
+ // about "Optimal Expansion Coefficients."
+
+ const kernel_type kern(dec_funct.kernel_function);
+
+ matrix<scalar_type,0,1,mem_manager_type> alpha;
+
+ alpha = lisf.get_inv_kernel_marix()*(kernel_matrix(kern,lisf,dec_funct.basis_vectors)*dec_funct.alpha);
+
+ decision_function<kernel_type> new_df(alpha,
+ 0,
+ kern,
+ lisf.get_dictionary());
+
+ // now we have to figure out what the new bias should be. It might be a little
+ // different since we just messed with all the weights and vectors.
+ double bias = 0;
+ for (long i = 0; i < x.nr(); ++i)
+ {
+ bias += new_df(x(i)) - dec_funct(x(i));
+ }
+
+ new_df.b = bias/x.nr();
+
+ return new_df;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ trainer_type trainer;
+ unsigned long num_bv;
+
+
+ }; // end of class reduced_decision_function_trainer
+
+ template <typename trainer_type>
+ const reduced_decision_function_trainer<trainer_type> reduced (
+ const trainer_type& trainer,
+ const unsigned long num_bv
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0,
+ "\tconst reduced_decision_function_trainer reduced()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t num_bv: " << num_bv
+ );
+
+ return reduced_decision_function_trainer<trainer_type>(trainer, num_bv);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace red_impl
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ class objective
+ {
+ /*
+ This object represents the objective function we will try to
+ minimize in approximate_distance_function().
+
+ The objective is the distance, in kernel induced feature space, between
+ the original distance function and the approximated version.
+
+ */
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ public:
+ objective(
+ const distance_function<kernel_type>& dist_funct_,
+ matrix<scalar_type,0,1,mem_manager_type>& b_,
+ matrix<sample_type,0,1,mem_manager_type>& out_vectors_
+ ) :
+ dist_funct(dist_funct_),
+ b(b_),
+ out_vectors(out_vectors_)
+ {
+ }
+
+ const matrix<scalar_type, 0, 1, mem_manager_type> state_to_vector (
+ ) const
+ /*!
+ ensures
+ - returns a vector that contains all the information necessary to
+ reproduce the current state of the approximated distance function
+ !*/
+ {
+ matrix<scalar_type, 0, 1, mem_manager_type> z(b.nr() + out_vectors.size()*out_vectors(0).nr());
+ long i = 0;
+ for (long j = 0; j < b.nr(); ++j)
+ {
+ z(i) = b(j);
+ ++i;
+ }
+
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ for (long k = 0; k < out_vectors(j).size(); ++k)
+ {
+ z(i) = out_vectors(j)(k);
+ ++i;
+ }
+ }
+ return z;
+ }
+
+
+ void vector_to_state (
+ const matrix<scalar_type, 0, 1, mem_manager_type>& z
+ ) const
+ /*!
+ requires
+ - z came from the state_to_vector() function or has a compatible format
+ ensures
+ - loads the vector z into the state variables of the approximate
+ distance function (i.e. b and out_vectors)
+ !*/
+ {
+ long i = 0;
+ for (long j = 0; j < b.nr(); ++j)
+ {
+ b(j) = z(i);
+ ++i;
+ }
+
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ for (long k = 0; k < out_vectors(j).size(); ++k)
+ {
+ out_vectors(j)(k) = z(i);
+ ++i;
+ }
+ }
+ }
+
+ double operator() (
+ const matrix<scalar_type, 0, 1, mem_manager_type>& z
+ ) const
+ /*!
+ ensures
+ - loads the current approximate distance function with z
+ - returns the distance between the original distance function
+ and the approximate one.
+ !*/
+ {
+ vector_to_state(z);
+ const kernel_type k(dist_funct.get_kernel());
+
+ double temp = 0;
+ for (long i = 0; i < out_vectors.size(); ++i)
+ {
+ for (long j = 0; j < dist_funct.get_basis_vectors().nr(); ++j)
+ {
+ temp -= b(i)*dist_funct.get_alpha()(j)*k(out_vectors(i), dist_funct.get_basis_vectors()(j));
+ }
+ }
+
+ temp *= 2;
+
+ for (long i = 0; i < out_vectors.size(); ++i)
+ {
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ temp += b(i)*b(j)*k(out_vectors(i), out_vectors(j));
+ }
+ }
+
+ return temp + dist_funct.get_squared_norm();
+ }
+
+ private:
+
+ const distance_function<kernel_type>& dist_funct;
+ matrix<scalar_type,0,1,mem_manager_type>& b;
+ matrix<sample_type,0,1,mem_manager_type>& out_vectors;
+
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ class objective_derivative
+ {
+ /*!
+ This object represents the derivative of the objective object
+ !*/
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ public:
+
+
+ objective_derivative(
+ const distance_function<kernel_type>& dist_funct_,
+ matrix<scalar_type,0,1,mem_manager_type>& b_,
+ matrix<sample_type,0,1,mem_manager_type>& out_vectors_
+ ) :
+ dist_funct(dist_funct_),
+ b(b_),
+ out_vectors(out_vectors_)
+ {
+ }
+
+ void vector_to_state (
+ const matrix<scalar_type, 0, 1, mem_manager_type>& z
+ ) const
+ /*!
+ requires
+ - z came from the state_to_vector() function or has a compatible format
+ ensures
+ - loads the vector z into the state variables of the approximate
+ distance function (i.e. b and out_vectors)
+ !*/
+ {
+ long i = 0;
+ for (long j = 0; j < b.nr(); ++j)
+ {
+ b(j) = z(i);
+ ++i;
+ }
+
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ for (long k = 0; k < out_vectors(j).size(); ++k)
+ {
+ out_vectors(j)(k) = z(i);
+ ++i;
+ }
+ }
+ }
+
+ const matrix<scalar_type,0,1,mem_manager_type>& operator() (
+ const matrix<scalar_type, 0, 1, mem_manager_type>& z
+ ) const
+ /*!
+ ensures
+ - loads the current approximate distance function with z
+ - returns the derivative of the distance between the original
+ distance function and the approximate one.
+ !*/
+ {
+ vector_to_state(z);
+ res.set_size(z.nr());
+ set_all_elements(res,0);
+ const kernel_type k(dist_funct.get_kernel());
+ const kernel_derivative<kernel_type> K_der(k);
+
+ // first compute the gradient for the beta weights
+ for (long i = 0; i < out_vectors.size(); ++i)
+ {
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ res(i) += b(j)*k(out_vectors(i), out_vectors(j));
+ }
+ }
+ for (long i = 0; i < out_vectors.size(); ++i)
+ {
+ for (long j = 0; j < dist_funct.get_basis_vectors().size(); ++j)
+ {
+ res(i) -= dist_funct.get_alpha()(j)*k(out_vectors(i), dist_funct.get_basis_vectors()(j));
+ }
+ }
+
+
+ // now compute the gradient of the actual vectors that go into the kernel functions
+ long pos = out_vectors.size();
+ const long num = out_vectors(0).nr();
+ temp.set_size(num,1);
+ for (long i = 0; i < out_vectors.size(); ++i)
+ {
+ set_all_elements(temp,0);
+ for (long j = 0; j < out_vectors.size(); ++j)
+ {
+ temp += b(j)*K_der(out_vectors(j), out_vectors(i));
+ }
+ for (long j = 0; j < dist_funct.get_basis_vectors().nr(); ++j)
+ {
+ temp -= dist_funct.get_alpha()(j)*K_der(dist_funct.get_basis_vectors()(j), out_vectors(i) );
+ }
+
+ // store the gradient for out_vectors(i) into result in the proper spot
+ set_subm(res,pos,0,num,1) = b(i)*temp;
+ pos += num;
+ }
+
+
+ res *= 2;
+ return res;
+ }
+
+ private:
+
+ mutable matrix<scalar_type, 0, 1, mem_manager_type> res;
+ mutable sample_type temp;
+
+ const distance_function<kernel_type>& dist_funct;
+ matrix<scalar_type,0,1,mem_manager_type>& b;
+ matrix<sample_type,0,1,mem_manager_type>& out_vectors;
+
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+ template <
+ typename K,
+ typename stop_strategy_type,
+ typename T
+ >
+ distance_function<K> approximate_distance_function (
+ stop_strategy_type stop_strategy,
+ const distance_function<K>& target,
+ const T& starting_basis
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(target.get_basis_vectors().size() > 0 &&
+ starting_basis.size() > 0,
+ "\t distance_function approximate_distance_function()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t target.get_basis_vectors().size(): " << target.get_basis_vectors().size()
+ << "\n\t starting_basis.size(): " << starting_basis.size()
+ );
+
+ using namespace red_impl;
+ // The next few statements just find the best weights with which to approximate
+ // the target object with the set of basis vectors in starting_basis. This
+ // is really just a simple application of some linear algebra. For the details
+ // see page 554 of Learning with kernels by Scholkopf and Smola where they talk
+ // about "Optimal Expansion Coefficients."
+
+ const K kern(target.get_kernel());
+ typedef typename K::scalar_type scalar_type;
+ typedef typename K::sample_type sample_type;
+ typedef typename K::mem_manager_type mem_manager_type;
+
+ matrix<scalar_type,0,1,mem_manager_type> beta;
+
+ // Now we compute the fist approximate distance function.
+ beta = pinv(kernel_matrix(kern,starting_basis)) *
+ (kernel_matrix(kern,starting_basis,target.get_basis_vectors())*target.get_alpha());
+ matrix<sample_type,0,1,mem_manager_type> out_vectors(mat(starting_basis));
+
+
+ // Now setup to do a global optimization of all the parameters in the approximate
+ // distance function.
+ const objective<K> obj(target, beta, out_vectors);
+ const objective_derivative<K> obj_der(target, beta, out_vectors);
+ matrix<scalar_type,0,1,mem_manager_type> opt_starting_point(obj.state_to_vector());
+
+
+ // perform a full optimization of all the parameters (i.e. both beta and the basis vectors together)
+ find_min(lbfgs_search_strategy(20),
+ stop_strategy,
+ obj, obj_der, opt_starting_point, 0);
+
+ // now make sure that the final optimized value is loaded into the beta and
+ // out_vectors matrices
+ obj.vector_to_state(opt_starting_point);
+
+ // Do a final reoptimization of beta just to make sure it is optimal given the new
+ // set of basis vectors.
+ beta = pinv(kernel_matrix(kern,out_vectors))*(kernel_matrix(kern,out_vectors,target.get_basis_vectors())*target.get_alpha());
+
+ // It is possible that some of the beta weights will be very close to zero. Lets remove
+ // the basis vectors with these essentially zero weights.
+ const scalar_type eps = max(abs(beta))*std::numeric_limits<scalar_type>::epsilon();
+ for (long i = 0; i < beta.size(); ++i)
+ {
+ // if beta(i) is zero (but leave at least one beta no matter what)
+ if (std::abs(beta(i)) < eps && beta.size() > 1)
+ {
+ beta = remove_row(beta, i);
+ out_vectors = remove_row(out_vectors, i);
+ --i;
+ }
+ }
+
+ return distance_function<K>(beta, kern, out_vectors);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class reduced_decision_function_trainer2
+ {
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ reduced_decision_function_trainer2 () : num_bv(0) {}
+ reduced_decision_function_trainer2 (
+ const trainer_type& trainer_,
+ const long num_sb_,
+ const double eps_ = 1e-3
+ ) :
+ trainer(trainer_),
+ num_bv(num_sb_),
+ eps(eps_)
+ {
+ COMPILE_TIME_ASSERT(is_matrix<sample_type>::value);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0 && eps > 0,
+ "\t reduced_decision_function_trainer2()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t num_bv: " << num_bv
+ << "\n\t eps: " << eps
+ );
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0,
+ "\t reduced_decision_function_trainer2::train(x,y)"
+ << "\n\t You have tried to use an uninitialized version of this object"
+ << "\n\t num_bv: " << num_bv );
+ return do_train(mat(x), mat(y));
+ }
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ // get the decision function object we are going to try and approximate
+ const decision_function<kernel_type>& dec_funct = trainer.train(x,y);
+ const kernel_type kern(dec_funct.kernel_function);
+
+ // now find a linearly independent subset of the training points of num_bv points.
+ linearly_independent_subset_finder<kernel_type> lisf(kern, num_bv);
+ fill_lisf(lisf,x);
+
+ distance_function<kernel_type> approx, target;
+ target = dec_funct;
+ approx = approximate_distance_function(objective_delta_stop_strategy(eps), target, lisf);
+
+ decision_function<kernel_type> new_df(approx.get_alpha(),
+ 0,
+ kern,
+ approx.get_basis_vectors());
+
+ // now we have to figure out what the new bias should be. It might be a little
+ // different since we just messed with all the weights and vectors.
+ double bias = 0;
+ for (long i = 0; i < x.nr(); ++i)
+ {
+ bias += new_df(x(i)) - dec_funct(x(i));
+ }
+
+ new_df.b = bias/x.nr();
+
+ return new_df;
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ trainer_type trainer;
+ long num_bv;
+ double eps;
+
+
+ }; // end of class reduced_decision_function_trainer2
+
+ template <typename trainer_type>
+ const reduced_decision_function_trainer2<trainer_type> reduced2 (
+ const trainer_type& trainer,
+ const long num_bv,
+ double eps = 1e-3
+ )
+ {
+ COMPILE_TIME_ASSERT(is_matrix<typename trainer_type::sample_type>::value);
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_bv > 0 && eps > 0,
+ "\tconst reduced_decision_function_trainer2 reduced2()"
+ << "\n\t you have given invalid arguments to this function"
+ << "\n\t num_bv: " << num_bv
+ << "\n\t eps: " << eps
+ );
+
+ return reduced_decision_function_trainer2<trainer_type>(trainer, num_bv, eps);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REDUCEd_TRAINERS_
+
diff --git a/ml/dlib/dlib/svm/reduced_abstract.h b/ml/dlib/dlib/svm/reduced_abstract.h
new file mode 100644
index 000000000..8b186c033
--- /dev/null
+++ b/ml/dlib/dlib/svm/reduced_abstract.h
@@ -0,0 +1,267 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_REDUCED_TRAINERs_ABSTRACT_
+#ifdef DLIB_REDUCED_TRAINERs_ABSTRACT_
+
+#include "../matrix.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../optimization.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class reduced_decision_function_trainer
+ {
+ /*!
+ REQUIREMENTS ON trainer_type
+ - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an implementation of a reduced set algorithm.
+ This object acts as a post processor for anything that creates
+ decision_function objects. It wraps another trainer object and
+ performs this reduced set post processing with the goal of
+ representing the original decision function in a form that
+ involves fewer basis vectors.
+ !*/
+
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ reduced_decision_function_trainer (
+ );
+ /*!
+ ensures
+ - This object is in an uninitialized state. You must
+ construct a real one with the other constructor and assign it
+ to this instance before you use this object.
+ !*/
+
+ reduced_decision_function_trainer (
+ const trainer_type& trainer,
+ const unsigned long num_bv
+ );
+ /*!
+ requires
+ - num_bv > 0
+ ensures
+ - returns a trainer object that applies post processing to the decision_function
+ objects created by the given trainer object with the goal of creating
+ decision_function objects with fewer basis vectors.
+ - The reduced decision functions that are output will have at most
+ num_bv basis vectors.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ ensures
+ - trains a decision_function using the trainer that was supplied to
+ this object's constructor and then finds a reduced representation
+ for it and returns the reduced version.
+ throws
+ - std::bad_alloc
+ - any exceptions thrown by the trainer_type object
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const reduced_decision_function_trainer<trainer_type> reduced (
+ const trainer_type& trainer,
+ const unsigned long num_bv
+ ) { return reduced_decision_function_trainer<trainer_type>(trainer, num_bv); }
+ /*!
+ requires
+ - num_bv > 0
+ - trainer_type == some kind of batch trainer object that creates decision_function
+ objects (e.g. svm_nu_trainer)
+ ensures
+ - returns a reduced_decision_function_trainer object that has been
+ instantiated with the given arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename stop_strategy_type,
+ typename T
+ >
+ distance_function<K> approximate_distance_function (
+ stop_strategy_type stop_strategy,
+ const distance_function<K>& target,
+ const T& starting_basis
+ );
+ /*!
+ requires
+ - stop_strategy == an object that defines a stop strategy such as one of
+ the objects from dlib/optimization/optimization_stop_strategies_abstract.h
+ - requirements on starting_basis
+ - T must be a dlib::matrix type or something convertible to a matrix via mat()
+ (e.g. a std::vector). Additionally, starting_basis must contain K::sample_type
+ objects which can be supplied to the kernel function used by target.
+ - is_vector(starting_basis) == true
+ - starting_basis.size() > 0
+ - target.get_basis_vectors().size() > 0
+ - kernel_derivative<K> is defined
+ (i.e. The analytic derivative for the given kernel must be defined)
+ - K::sample_type must be a dlib::matrix object and the basis_vectors inside target
+ and starting_basis must be column vectors.
+ ensures
+ - This routine attempts to find a distance_function object which is close
+ to the given target. That is, it searches for an X such that target(X) is
+ minimized. The optimization begins with an X in the span of the elements
+ of starting_basis and searches for an X which locally minimizes target(X).
+ Since this problem can have many local minima, the quality of the starting
+ basis can significantly influence the results.
+ - The optimization is over all variables in a distance_function, however,
+ the size of the basis set is constrained to no more than starting_basis.size().
+ That is, in the returned distance_function DF, we will have:
+ - DF.get_basis_vectors().size() <= starting_basis.size()
+ - The optimization is carried out until the stop_strategy indicates it
+ should stop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class reduced_decision_function_trainer2
+ {
+ /*!
+ REQUIREMENTS ON trainer_type
+ - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
+ - trainer_type::sample_type must be a dlib::matrix object
+ - kernel_derivative<trainer_type::kernel_type> must be defined
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an implementation of a reduced set algorithm.
+ This object acts as a post processor for anything that creates
+ decision_function objects. It wraps another trainer object and
+ performs this reduced set post processing with the goal of
+ representing the original decision function in a form that
+ involves fewer basis vectors.
+
+ This object's implementation is the same as that in the above
+ reduced_decision_function_trainer object except it also performs
+ a global gradient based optimization at the end to further
+ improve the approximation to the original decision function
+ object.
+ !*/
+
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ reduced_decision_function_trainer2 (
+ );
+ /*!
+ ensures
+ - This object is in an uninitialized state. You must
+ construct a real one with the other constructor and assign it
+ to this instance before you use this object.
+ !*/
+
+ reduced_decision_function_trainer2 (
+ const trainer_type& trainer,
+ const unsigned long num_bv,
+ double eps = 1e-3
+ );
+ /*!
+ requires
+ - num_bv > 0
+ - eps > 0
+ ensures
+ - returns a trainer object that applies post processing to the decision_function
+ objects created by the given trainer object with the goal of creating
+ decision_function objects with fewer basis vectors.
+ - The reduced decision functions that are output will have at most
+ num_bv basis vectors.
+ - the gradient based optimization will continue until the change in the
+ objective function is less than eps. So smaller values of eps will
+ give better results but take longer to compute.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - x must be a list of objects which are each some kind of dlib::matrix
+ which represents column or row vectors.
+ ensures
+ - trains a decision_function using the trainer that was supplied to
+ this object's constructor and then finds a reduced representation
+ for it and returns the reduced version.
+ throws
+ - std::bad_alloc
+ - any exceptions thrown by the trainer_type object
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const reduced_decision_function_trainer2<trainer_type> reduced2 (
+ const trainer_type& trainer,
+ const unsigned long num_bv,
+ double eps = 1e-3
+ ) { return reduced_decision_function_trainer2<trainer_type>(trainer, num_bv, eps); }
+ /*!
+ requires
+ - num_bv > 0
+ - trainer_type == some kind of batch trainer object that creates decision_function
+ objects (e.g. svm_nu_trainer)
+ - kernel_derivative<trainer_type::kernel_type> is defined
+ - eps > 0
+ ensures
+ - returns a reduced_decision_function_trainer2 object that has been
+ instantiated with the given arguments.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_REDUCED_TRAINERs_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/rls.h b/ml/dlib/dlib/svm/rls.h
new file mode 100644
index 000000000..edee6b062
--- /dev/null
+++ b/ml/dlib/dlib/svm/rls.h
@@ -0,0 +1,232 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RLs_Hh_
+#define DLIB_RLs_Hh_
+
+#include "rls_abstract.h"
+#include "../matrix.h"
+#include "function.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rls
+ {
+
+ public:
+
+
+ explicit rls(
+ double forget_factor_,
+ double C_ = 1000,
+ bool apply_forget_factor_to_C_ = false
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 &&
+ 0 < C_,
+ "\t rls::rls()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t forget_factor_: " << forget_factor_
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+
+ C = C_;
+ forget_factor = forget_factor_;
+ apply_forget_factor_to_C = apply_forget_factor_to_C_;
+ }
+
+ rls(
+ )
+ {
+ C = 1000;
+ forget_factor = 1;
+ apply_forget_factor_to_C = false;
+ }
+
+ double get_c(
+ ) const
+ {
+ return C;
+ }
+
+ double get_forget_factor(
+ ) const
+ {
+ return forget_factor;
+ }
+
+ bool should_apply_forget_factor_to_C (
+ ) const
+ {
+ return apply_forget_factor_to_C;
+ }
+
+ template <typename EXP>
+ void train (
+ const matrix_exp<EXP>& x,
+ double y
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) &&
+ (get_w().size() == 0 || get_w().size() == x.size()),
+ "\t void rls::train()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t get_w().size(): " << get_w().size()
+ << "\n\t this: " << this
+ );
+
+ if (R.size() == 0)
+ {
+ R = identity_matrix<double>(x.size())*C;
+ w.set_size(x.size());
+ w = 0;
+ }
+
+ // multiply by forget factor and incorporate x*trans(x) into R.
+ const double l = 1.0/forget_factor;
+ const double temp = 1 + l*trans(x)*R*x;
+ tmp = R*x;
+ R = l*R - l*l*(tmp*trans(tmp))/temp;
+
+ // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
+ // identity matrix back in to keep the regularization alive.
+ if (forget_factor != 1 && !apply_forget_factor_to_C)
+ add_eye_to_inv(R, (1-forget_factor)/C);
+
+ // R should always be symmetric. This line improves numeric stability of this algorithm.
+ if (cnt%10 == 0)
+ R = 0.5*(R + trans(R));
+ ++cnt;
+
+ w = w + R*x*(y - trans(x)*w);
+
+ }
+
+
+
+ const matrix<double,0,1>& get_w(
+ ) const
+ {
+ return w;
+ }
+
+ template <typename EXP>
+ double operator() (
+ const matrix_exp<EXP>& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(),
+ "\t double rls::operator()()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t is_col_vector(x): " << is_col_vector(x)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t get_w().size(): " << get_w().size()
+ << "\n\t this: " << this
+ );
+
+ return dot(x,w);
+ }
+
+ decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_w().size() != 0,
+ "\t decision_function rls::get_decision_function()"
+ << "\n\t invalid arguments were given to this function"
+ << "\n\t get_w().size(): " << get_w().size()
+ << "\n\t this: " << this
+ );
+
+ decision_function<linear_kernel<matrix<double,0,1> > > df;
+ df.alpha.set_size(1);
+ df.basis_vectors.set_size(1);
+ df.b = 0;
+ df.alpha = 1;
+ df.basis_vectors(0) = w;
+
+ return df;
+ }
+
+ friend inline void serialize(const rls& item, std::ostream& out)
+ {
+ int version = 2;
+ serialize(version, out);
+ serialize(item.w, out);
+ serialize(item.R, out);
+ serialize(item.C, out);
+ serialize(item.forget_factor, out);
+ serialize(item.cnt, out);
+ serialize(item.apply_forget_factor_to_C, out);
+ }
+
+ friend inline void deserialize(rls& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (!(1 <= version && version <= 2))
+ throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
+
+ if (version >= 1)
+ {
+ deserialize(item.w, in);
+ deserialize(item.R, in);
+ deserialize(item.C, in);
+ deserialize(item.forget_factor, in);
+ }
+ item.cnt = 0;
+ item.apply_forget_factor_to_C = false;
+ if (version >= 2)
+ {
+ deserialize(item.cnt, in);
+ deserialize(item.apply_forget_factor_to_C, in);
+ }
+ }
+
+ private:
+
+ void add_eye_to_inv(
+ matrix<double>& m,
+ double C
+ )
+ /*!
+ ensures
+ - Let m == inv(M)
+ - this function returns inv(M + C*identity_matrix<double>(m.nr()))
+ !*/
+ {
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r));
+ }
+ }
+
+
+ matrix<double,0,1> w;
+ matrix<double> R;
+ double C;
+ double forget_factor;
+ int cnt = 0;
+ bool apply_forget_factor_to_C;
+
+
+ // This object is here only to avoid reallocation during training. It don't
+ // logically contribute to the state of this object.
+ matrix<double,0,1> tmp;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RLs_Hh_
+
diff --git a/ml/dlib/dlib/svm/rls_abstract.h b/ml/dlib/dlib/svm/rls_abstract.h
new file mode 100644
index 000000000..c593e4330
--- /dev/null
+++ b/ml/dlib/dlib/svm/rls_abstract.h
@@ -0,0 +1,175 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RLs_ABSTRACT_Hh_
+#ifdef DLIB_RLs_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include "function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rls
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the linear version of the recursive least
+ squares algorithm. It accepts training points incrementally and, at
+ each step, maintains the solution to the following optimization problem:
+ find w minimizing: 0.5*dot(w,w) + C*sum_i(y_i - trans(x_i)*w)^2
+ Where (x_i,y_i) are training pairs. x_i is some vector and y_i is a target
+ scalar value.
+
+ This object can also be configured to use exponential forgetting. This is
+ where each training example is weighted by pow(forget_factor, i), where i
+ indicates the sample's age. So older samples are weighted less in the
+ least squares solution and therefore become forgotten after some time.
+ Therefore, with forgetting, this object solves the following optimization
+ problem at each step:
+ find w minimizing: 0.5*dot(w,w) + C*sum_i pow(forget_factor, i)*(y_i - trans(x_i)*w)^2
+ Where i starts at 0 and i==0 corresponds to the most recent training point.
+ !*/
+
+ public:
+
+
+ explicit rls(
+ double forget_factor,
+ double C = 1000,
+ bool apply_forget_factor_to_C = false
+ );
+ /*!
+ requires
+ - 0 < forget_factor <= 1
+ - 0 < C
+ ensures
+ - #get_w().size() == 0
+ - #get_c() == C
+ - #get_forget_factor() == forget_factor
+ - #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
+ !*/
+
+ rls(
+ );
+ /*!
+ ensures
+ - #get_w().size() == 0
+ - #get_c() == 1000
+ - #get_forget_factor() == 1
+ - #should_apply_forget_factor_to_C() == false
+ !*/
+
+ double get_c(
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter
+ that determines the trade-off between trying to fit the training
+ data or allowing more errors but hopefully improving the generalization
+ of the resulting regression. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ double get_forget_factor(
+ ) const;
+ /*!
+ ensures
+ - returns the exponential forgetting factor. A value of 1 disables forgetting
+ and results in normal least squares regression. On the other hand, a smaller
+ value causes the regression to forget about old training examples and prefer
+ instead to fit more recent examples. The closer the forget factor is to
+ zero the faster old examples are forgotten.
+ !*/
+
+ bool should_apply_forget_factor_to_C (
+ ) const;
+ /*!
+ ensures
+ - If this function returns false then it means we are optimizing the
+ objective function discussed in the WHAT THIS OBJECT REPRESENTS section
+ above. However, if it returns true then we will allow the forget factor
+ (get_forget_factor()) to be applied to the C value which causes the
+ algorithm to slowly increase C and convert into a textbook version of RLS
+ without regularization. The main reason you might want to do this is
+ because it can make the algorithm run significantly faster.
+ !*/
+
+ template <typename EXP>
+ void train (
+ const matrix_exp<EXP>& x,
+ double y
+ )
+ /*!
+ requires
+ - is_col_vector(x) == true
+ - if (get_w().size() != 0) then
+ - x.size() == get_w().size()
+ (i.e. all training examples must have the same
+ dimensionality)
+ ensures
+ - #get_w().size() == x.size()
+ - updates #get_w() such that it contains the solution to the least
+ squares problem of regressing the given x onto the given y as well
+ as all the previous training examples supplied to train().
+ !*/
+
+ const matrix<double,0,1>& get_w(
+ ) const;
+ /*!
+ ensures
+ - returns the regression weights. These are the values learned by the
+ least squares procedure. If train() has not been called then this
+ function returns an empty vector.
+ !*/
+
+ template <typename EXP>
+ double operator() (
+ const matrix_exp<EXP>& x
+ ) const;
+ /*!
+ requires
+ - is_col_vector(x) == true
+ - get_w().size() == x.size()
+ ensures
+ - returns dot(x, get_w())
+ !*/
+
+ decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function (
+ ) const;
+ /*!
+ requires
+ - get_w().size() != 0
+ ensures
+ - returns a decision function DF such that:
+ - DF(x) == dot(x, get_w())
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const rls& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize (
+ rls& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RLs_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/roc_trainer.h b/ml/dlib/dlib/svm/roc_trainer.h
new file mode 100644
index 000000000..fa2c0ef9b
--- /dev/null
+++ b/ml/dlib/dlib/svm/roc_trainer.h
@@ -0,0 +1,149 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ROC_TRAINEr_H_
+#define DLIB_ROC_TRAINEr_H_
+
+#include "roc_trainer_abstract.h"
+#include "../algs.h"
+#include <limits>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class roc_trainer_type
+ {
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ roc_trainer_type (
+ ) : desired_accuracy(0), class_selection(0){}
+
+ roc_trainer_type (
+ const trainer_type& trainer_,
+ const scalar_type& desired_accuracy_,
+ const scalar_type& class_selection_
+ ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 &&
+ (class_selection == -1 || class_selection == +1),
+ "\t roc_trainer_type::roc_trainer_type()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t desired_accuracy: " << desired_accuracy
+ << "\n\t class_selection: " << class_selection
+ );
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const trained_function_type train (
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels
+ ) const
+ /*!
+ requires
+ - is_binary_classification_problem(samples, labels) == true
+ !*/
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(samples, labels),
+ "\t roc_trainer_type::train()"
+ << "\n\t invalid inputs were given to this function"
+ );
+
+
+ return do_train(mat(samples), mat(labels));
+ }
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const trained_function_type do_train (
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels
+ ) const
+ {
+ trained_function_type df = trainer.train(samples, labels);
+
+ // clear out the old bias
+ df.b = 0;
+
+ // obtain all the scores from the df using all the class_selection labeled samples
+ std::vector<double> scores;
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ if (labels(i) == class_selection)
+ scores.push_back(df(samples(i)));
+ }
+
+ if (class_selection == +1)
+ std::sort(scores.rbegin(), scores.rend());
+ else
+ std::sort(scores.begin(), scores.end());
+
+ // now pick out the index that gives us the desired accuracy with regards to selected class
+ unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5);
+ if (idx >= scores.size())
+ idx = scores.size()-1;
+
+ df.b = scores[idx];
+
+ // In this case add a very small extra amount to the bias so that all the samples
+ // with the class_selection label are classified correctly.
+ if (desired_accuracy == 1)
+ {
+ if (class_selection == +1)
+ df.b -= std::numeric_limits<scalar_type>::epsilon()*df.b;
+ else
+ df.b += std::numeric_limits<scalar_type>::epsilon()*df.b;
+ }
+
+ return df;
+ }
+
+ trainer_type trainer;
+ scalar_type desired_accuracy;
+ scalar_type class_selection;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const roc_trainer_type<trainer_type> roc_c1_trainer (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type& desired_accuracy
+ ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const roc_trainer_type<trainer_type> roc_c2_trainer (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type& desired_accuracy
+ ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ROC_TRAINEr_H_
+
+
diff --git a/ml/dlib/dlib/svm/roc_trainer_abstract.h b/ml/dlib/dlib/svm/roc_trainer_abstract.h
new file mode 100644
index 000000000..74e6f9b65
--- /dev/null
+++ b/ml/dlib/dlib/svm/roc_trainer_abstract.h
@@ -0,0 +1,135 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_ROC_TRAINEr_ABSTRACT_
+#ifdef DLIB_ROC_TRAINEr_ABSTRACT_
+
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ class roc_trainer_type
+ {
+ /*!
+ REQUIREMENTS ON trainer_type
+ - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a simple trainer post processor that allows you to
+ easily adjust the bias term in a trained decision_function object.
+ That is, this object lets you pick a point on the ROC curve and
+ it will adjust the bias term appropriately.
+
+ So for example, suppose you wanted to set the bias term so that
+ the accuracy of your decision function on +1 labeled samples was 99%.
+ To do this you would use an instance of this object declared as follows:
+ roc_trainer_type<trainer_type>(your_trainer, 0.99, +1);
+ !*/
+
+ public:
+ typedef typename trainer_type::kernel_type kernel_type;
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef typename trainer_type::trained_function_type trained_function_type;
+
+ roc_trainer_type (
+ );
+ /*!
+ ensures
+ - This object is in an uninitialized state. You must
+ construct a real one with the other constructor and assign it
+ to this instance before you use this object.
+ !*/
+
+ roc_trainer_type (
+ const trainer_type& trainer_,
+ const scalar_type& desired_accuracy_,
+ const scalar_type& class_selection_
+ );
+ /*!
+ requires
+ - 0 <= desired_accuracy_ <= 1
+ - class_selection_ == +1 or -1
+ ensures
+ - when training is performed using this object it will automatically
+ adjust the bias term in the returned decision function so that it
+ achieves the desired accuracy on the selected class type.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const trained_function_type train (
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels
+ ) const
+ /*!
+ requires
+ - is_binary_classification_problem(samples, labels) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - performs training using the trainer object given to this object's
+ constructor, then modifies the bias term in the returned decision function
+ as discussed above, and finally returns the decision function.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const roc_trainer_type<trainer_type> roc_c1_trainer (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type& desired_accuracy
+ ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); }
+ /*!
+ requires
+ - 0 <= desired_accuracy <= 1
+ - trainer_type == some kind of batch trainer object that creates decision_function
+ objects (e.g. svm_nu_trainer)
+ ensures
+ - returns a roc_trainer_type object that has been instantiated with the given
+ arguments. The returned roc trainer will select the decision function
+ bias that gives the desired accuracy with respect to the +1 class.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ const roc_trainer_type<trainer_type> roc_c2_trainer (
+ const trainer_type& trainer,
+ const typename trainer_type::scalar_type& desired_accuracy
+ ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); }
+ /*!
+ requires
+ - 0 <= desired_accuracy <= 1
+ - trainer_type == some kind of batch trainer object that creates decision_function
+ objects (e.g. svm_nu_trainer)
+ ensures
+ - returns a roc_trainer_type object that has been instantiated with the given
+ arguments. The returned roc trainer will select the decision function
+ bias that gives the desired accuracy with respect to the -1 class.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ROC_TRAINEr_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/rr_trainer.h b/ml/dlib/dlib/svm/rr_trainer.h
new file mode 100644
index 000000000..09177217e
--- /dev/null
+++ b/ml/dlib/dlib/svm/rr_trainer.h
@@ -0,0 +1,456 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RR_TRAInER_Hh_
+#define DLIB_RR_TRAInER_Hh_
+
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "empirical_kernel_map.h"
+#include "linearly_independent_subset_finder.h"
+#include "../statistics.h"
+#include "rr_trainer_abstract.h"
+#include <vector>
+#include <iostream>
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class rr_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear or
+ // sparse kernel to the rr_trainer object. You have to use dlib::linear_kernel with this trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value));
+
+ rr_trainer (
+ ) :
+ verbose(false),
+ use_regression_loss(true),
+ lambda(0)
+ {
+ // default lambda search list
+ lams = matrix_cast<scalar_type>(logspace(-9, 2, 50));
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void use_regression_loss_for_loo_cv (
+ )
+ {
+ use_regression_loss = true;
+ }
+
+ void use_classification_loss_for_loo_cv (
+ )
+ {
+ use_regression_loss = false;
+ }
+
+ bool will_use_regression_loss_for_loo_cv (
+ ) const
+ {
+ return use_regression_loss;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ void set_lambda (
+ scalar_type lambda_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(lambda_ >= 0,
+ "\t void rr_trainer::set_lambda()"
+ << "\n\t lambda must be greater than or equal to 0"
+ << "\n\t lambda: " << lambda
+ << "\n\t this: " << this
+ );
+
+ lambda = lambda_;
+ }
+
+ const scalar_type get_lambda (
+ ) const
+ {
+ return lambda;
+ }
+
+ template <typename EXP>
+ void set_search_lambdas (
+ const matrix_exp<EXP>& lambdas
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(lambdas) && lambdas.size() > 0 && min(lambdas) > 0,
+ "\t void rr_trainer::set_search_lambdas()"
+ << "\n\t lambdas must be a non-empty vector of values"
+ << "\n\t is_vector(lambdas): " << is_vector(lambdas)
+ << "\n\t lambdas.size(): " << lambdas.size()
+ << "\n\t min(lambdas): " << min(lambdas)
+ << "\n\t this: " << this
+ );
+
+
+ lams = matrix_cast<scalar_type>(lambdas);
+ }
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_search_lambdas (
+ ) const
+ {
+ return lams;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ std::vector<scalar_type> temp;
+ scalar_type temp2;
+ return do_train(mat(x), mat(y), false, temp, temp2);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values
+ ) const
+ {
+ scalar_type temp;
+ return do_train(mat(x), mat(y), true, loo_values, temp);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& lambda_used
+ ) const
+ {
+ return do_train(mat(x), mat(y), true, loo_values, lambda_used);
+ }
+
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const bool output_loo_values,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& the_lambda
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y),
+ "\t decision_function rr_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_vector(x): " << is_vector(x)
+ << "\n\t is_vector(y): " << is_vector(y)
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ );
+
+#ifdef ENABLE_ASSERTS
+ if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y),
+ "\t decision_function rr_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ );
+ }
+#endif
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> column_matrix_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> general_matrix_type;
+
+ const long dims = x(0).size();
+
+ /*
+ Notes on the solution of ridge regression
+
+ Let A = an x.size() by dims matrix which contains all the data samples.
+
+ Let I = an identity matrix
+
+ Let C = trans(A)*A
+ Let L = trans(A)*y
+
+ Then the optimal w is given by:
+ w = inv(C + lambda*I) * L
+
+
+ There is a trick to compute leave one out cross validation results for many different
+ lambda values quickly. The following paper has a detailed discussion of various
+ approaches:
+
+ Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert.
+
+ In the implementation of the rr_trainer I'm only using two simple equations
+ from the above paper.
+
+
+ First note that inv(C + lambda*I) can be computed for many different lambda
+ values in an efficient way by using an eigen decomposition of C. So we use
+ the fact that:
+ inv(C + lambda*I) == V*inv(D + lambda*I)*trans(V)
+ where V*D*trans(V) == C
+
+ Also, via some simple linear algebra the above paper works out that the leave one out
+ value for a sample x(i) is equal to the following:
+ Let G = inv(C + lambda*I)
+ let val = trans(x(i))*G*x(i);
+
+ leave one out value for sample x(i):
+ LOOV = (trans(w)*x(i) - y(i)*val) / (1 - val)
+
+ leave one out error for sample x(i):
+ LOOE = loss(y(i), LOOV)
+
+
+ Finally, note that we will pretend there was a 1 appended to the end of each
+ vector in x. We won't actually do that though because we don't want to
+ have to make a copy of all the samples. So throughout the following code
+ I have explicitly dealt with this.
+ */
+
+ general_matrix_type C, tempm, G;
+ column_matrix_type L, tempv, w;
+
+ // compute C and L
+ for (long i = 0; i < x.size(); ++i)
+ {
+ C += x(i)*trans(x(i));
+ L += y(i)*x(i);
+ tempv += x(i);
+ }
+
+ // Account for the extra 1 that we pretend is appended to x
+ // Make C = [C tempv
+ // tempv' x.size()]
+ C = join_cols(join_rows(C, tempv),
+ join_rows(trans(tempv), uniform_matrix<scalar_type>(1,1, x.size())));
+ L = join_cols(L, uniform_matrix<scalar_type>(1,1, sum(y)));
+
+ eigenvalue_decomposition<general_matrix_type> eig(make_symmetric(C));
+ const general_matrix_type V = eig.get_pseudo_v();
+ const column_matrix_type D = eig.get_real_eigenvalues();
+
+ // We can save some work by pre-multiplying the x vectors by trans(V)
+ // and saving the result so we don't have to recompute it over and over later.
+ matrix<column_matrix_type,0,1,mem_manager_type > Vx;
+ if (lambda == 0 || output_loo_values)
+ {
+ // Save the transpose of V into a temporary because the subsequent matrix
+ // vector multiplies will be faster (because of better cache locality).
+ const general_matrix_type transV( colm(trans(V),range(0,dims-1)) );
+ // Remember the pretend 1 at the end of x(*). We want to multiply trans(V)*x(*)
+ // so to do this we pull the last column off trans(V) and store it separately.
+ const column_matrix_type lastV = colm(trans(V), dims);
+ Vx.set_size(x.size());
+ for (long i = 0; i < x.size(); ++i)
+ {
+ Vx(i) = transV*x(i);
+ Vx(i) = squared(Vx(i) + lastV);
+ }
+ }
+
+ the_lambda = lambda;
+
+ // If we need to automatically select a lambda then do so using the LOOE trick described
+ // above.
+ bool did_loov = false;
+ scalar_type best_looe = std::numeric_limits<scalar_type>::max();
+ if (lambda == 0)
+ {
+ did_loov = true;
+
+ // Compute leave one out errors for a bunch of different lambdas and pick the best one.
+ for (long idx = 0; idx < lams.size(); ++idx)
+ {
+ // first compute G
+ tempv = 1.0/(D + lams(idx));
+ tempm = scale_columns(V,tempv);
+ G = tempm*trans(V);
+
+ // compute the solution w for the current lambda
+ w = G*L;
+
+ // make w have the same length as the x vectors.
+ const scalar_type b = w(dims);
+ w = colm(w,0,dims);
+
+ scalar_type looe = 0;
+ for (long i = 0; i < x.size(); ++i)
+ {
+ // perform equivalent of: val = trans(x(i))*G*x(i);
+ const scalar_type val = dot(tempv, Vx(i));
+ const scalar_type temp = (1 - val);
+ scalar_type loov;
+ if (temp != 0)
+ loov = (trans(w)*x(i) + b - y(i)*val) / temp;
+ else
+ loov = 0;
+
+ looe += loss(loov, y(i));
+ }
+
+ // Keep track of the lambda which gave the lowest looe. If two lambdas
+ // have the same looe then pick the biggest lambda.
+ if (looe < best_looe || (looe == best_looe && lams(idx) > the_lambda))
+ {
+ best_looe = looe;
+ the_lambda = lams(idx);
+ }
+ }
+
+ best_looe /= x.size();
+ }
+
+
+
+ // Now perform the main training. That is, find w.
+ // first, compute G = inv(C + the_lambda*I)
+ tempv = 1.0/(D + the_lambda);
+ tempm = scale_columns(V,tempv);
+ G = tempm*trans(V);
+ w = G*L;
+
+ // make w have the same length as the x vectors.
+ const scalar_type b = w(dims);
+ w = colm(w,0,dims);
+
+
+ // If we haven't done this already and we are supposed to then compute the LOO error rate for
+ // the current lambda and store the result in best_looe.
+ if (output_loo_values)
+ {
+ loo_values.resize(x.size());
+ did_loov = true;
+ best_looe = 0;
+ for (long i = 0; i < x.size(); ++i)
+ {
+ // perform equivalent of: val = trans(x(i))*G*x(i);
+ const scalar_type val = dot(tempv, Vx(i));
+ const scalar_type temp = (1 - val);
+ scalar_type loov;
+ if (temp != 0)
+ loov = (trans(w)*x(i) + b - y(i)*val) / temp;
+ else
+ loov = 0;
+
+ best_looe += loss(loov, y(i));
+ loo_values[i] = loov;
+ }
+
+ best_looe /= x.size();
+
+ }
+ else
+ {
+ loo_values.clear();
+ }
+
+ if (verbose && did_loov)
+ {
+ using namespace std;
+ cout << "Using lambda: " << the_lambda << endl;
+ if (use_regression_loss)
+ cout << "LOO Mean Squared Error: " << best_looe << endl;
+ else
+ cout << "LOO Classification Error: " << best_looe << endl;
+ }
+
+ // convert w into a proper decision function
+ decision_function<kernel_type> df;
+ df.alpha.set_size(1);
+ df.alpha = 1;
+ df.basis_vectors.set_size(1);
+ df.basis_vectors(0) = w;
+ df.b = -b; // don't forget about the bias we stuck onto all the vectors
+
+ return df;
+ }
+
+ inline scalar_type loss (
+ const scalar_type& a,
+ const scalar_type& b
+ ) const
+ {
+ if (use_regression_loss)
+ {
+ return (a-b)*(a-b);
+ }
+ else
+ {
+ // if a and b have the same sign then no loss
+ if (a*b >= 0)
+ return 0;
+ else
+ return 1;
+ }
+ }
+
+
+ /*!
+ CONVENTION
+ - get_lambda() == lambda
+ - get_kernel() == kernel_type()
+ - will_use_regression_loss_for_loo_cv() == use_regression_loss
+ - get_search_lambdas() == lams
+ !*/
+
+ bool verbose;
+ bool use_regression_loss;
+
+ scalar_type lambda;
+
+ matrix<scalar_type,0,0,mem_manager_type> lams;
+ };
+
+}
+
+#endif // DLIB_RR_TRAInER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/rr_trainer_abstract.h b/ml/dlib/dlib/svm/rr_trainer_abstract.h
new file mode 100644
index 000000000..f2fe21068
--- /dev/null
+++ b/ml/dlib/dlib/svm/rr_trainer_abstract.h
@@ -0,0 +1,255 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RR_TRAInER_ABSTRACT_Hh_
+#ifdef DLIB_RR_TRAInER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "function_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class rr_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is the dlib::linear_kernel instantiated with some kind of column vector.
+
+ INITIAL VALUE
+ - get_lambda() == 0
+ - will_use_regression_loss_for_loo_cv() == true
+ - get_search_lambdas() == logspace(-9, 2, 50)
+ - this object will not be verbose unless be_verbose() is called
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for performing linear ridge regression
+ (This basic algorithm is also known my many other names, e.g. regularized
+ least squares or least squares SVM).
+
+ The exact definition of what this algorithm does is this:
+ Find w and b that minimizes the following (x_i are input samples and y_i are target values):
+ lambda*dot(w,w) + sum_over_i( (f(x_i) - y_i)^2 )
+ where f(x) == dot(x,w) - b
+
+ So this algorithm is just regular old least squares regression but
+ with the addition of a regularization term which encourages small w.
+
+
+ It is capable of estimating the lambda parameter using leave-one-out cross-validation.
+
+
+ The leave-one-out cross-validation implementation is based on the techniques
+ discussed in this paper:
+ Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rr_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since
+ the linear kernels don't have any parameters this function just
+ returns kernel_type()
+ !*/
+
+ void set_lambda (
+ scalar_type lambda
+ );
+ /*!
+ requires
+ - lambda >= 0
+ ensures
+ - #get_lambda() == lambda
+ !*/
+
+ const scalar_type get_lambda (
+ ) const;
+ /*!
+ ensures
+ - returns the regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization ability of the resulting function. Smaller values
+ encourage exact fitting while larger values of lambda may encourage
+ better generalization.
+
+ Note that a lambda of 0 has a special meaning. It indicates to this
+ object that it should automatically determine an appropriate lambda
+ value. This is done using leave-one-out cross-validation.
+ !*/
+
+ void use_regression_loss_for_loo_cv (
+ );
+ /*!
+ ensures
+ - #will_use_regression_loss_for_loo_cv() == true
+ !*/
+
+ void use_classification_loss_for_loo_cv (
+ );
+ /*!
+ ensures
+ - #will_use_regression_loss_for_loo_cv() == false
+ !*/
+
+ bool will_use_regression_loss_for_loo_cv (
+ ) const;
+ /*!
+ ensures
+ - returns true if the automatic lambda estimation will attempt to estimate a lambda
+ appropriate for a regression task. Otherwise it will try and find one which
+ minimizes the number of classification errors.
+ !*/
+
+ template <typename EXP>
+ void set_search_lambdas (
+ const matrix_exp<EXP>& lambdas
+ );
+ /*!
+ requires
+ - is_vector(lambdas) == true
+ - lambdas.size() > 0
+ - min(lambdas) > 0
+ - lambdas must contain floating point numbers
+ ensures
+ - #get_search_lambdas() == lambdas
+ !*/
+
+ const matrix<scalar_type,0,0,mem_manager_type>& get_search_lambdas (
+ ) const;
+ /*!
+ ensures
+ - returns a matrix M such that:
+ - is_vector(M) == true
+ - M == a list of all the lambda values which will be tried when performing
+ LOO cross-validation for determining the best lambda.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - is_learning_problem(x,y) == true
+ - if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) then
+ - is_binary_classification_problem(x,y) == true
+ (i.e. if you want this algorithm to estimate a lambda appropriate for
+ classification functions then you had better give a valid classification
+ problem)
+ ensures
+ - performs linear ridge regression given the training samples in x and target values in y.
+ - returns a decision_function F with the following properties:
+ - F(new_x) == predicted y value
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+
+ - if (get_lambda() == 0) then
+ - This object will perform internal leave-one-out cross-validation to determine an
+ appropriate lambda automatically. It will compute the LOO error for each lambda
+ in get_search_lambdas() and select the best one.
+ - if (will_use_regression_loss_for_loo_cv()) then
+ - the lambda selected will be the one that minimizes the mean squared error.
+ - else
+ - the lambda selected will be the one that minimizes the number classification
+ mistakes. We say a point is classified correctly if the output of the
+ decision_function has the same sign as its label.
+ - #get_lambda() == 0
+ (i.e. we don't change the get_lambda() value. If you want to know what the
+ automatically selected lambda value was then call the version of train()
+ defined below)
+ - else
+ - The user supplied value of get_lambda() will be used to perform the ridge regression.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values
+ ) const;
+ /*!
+ requires
+ - all the requirements for train(x,y) must be satisfied
+ ensures
+ - returns train(x,y)
+ (i.e. executes train(x,y) and returns its result)
+ - #loo_values.size() == y.size()
+ - for all valid i:
+ - #loo_values[i] == leave-one-out prediction for the value of y(i) based
+ on all the training samples other than (x(i),y(i)).
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ std::vector<scalar_type>& loo_values,
+ scalar_type& lambda_used
+ ) const;
+ /*!
+ requires
+ - all the requirements for train(x,y) must be satisfied
+ ensures
+ - returns train(x,y)
+ (i.e. executes train(x,y) and returns its result)
+ - #loo_values.size() == y.size()
+ - for all valid i:
+ - #loo_values[i] == leave-one-out prediction for the value of y(i) based
+ on all the training samples other than (x(i),y(i)).
+ - #lambda_used == the value of lambda used to generate the
+ decision_function. Note that this lambda value is always
+ equal to get_lambda() if get_lambda() isn't 0.
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_RR_TRAInER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/rvm.h b/ml/dlib/dlib/svm/rvm.h
new file mode 100644
index 000000000..e7ad495a2
--- /dev/null
+++ b/ml/dlib/dlib/svm/rvm.h
@@ -0,0 +1,1018 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RVm_
+#define DLIB_RVm_
+
+#include "rvm_abstract.h"
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace rvm_helpers
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename scalar_vector_type, typename mem_manager_type>
+ long find_next_best_alpha_to_update (
+ const scalar_vector_type& S,
+ const scalar_vector_type& Q,
+ const scalar_vector_type& alpha,
+ const matrix<long,0,1,mem_manager_type>& active_bases,
+ const bool search_all_alphas,
+ typename scalar_vector_type::type eps
+ )
+ /*!
+ ensures
+ - if (we can find another alpha to update) then
+ - returns the index of said alpha
+ - else
+ - returns -1
+ !*/
+ {
+ typedef typename scalar_vector_type::type scalar_type;
+ // now use S and Q to find next alpha to update. What
+ // we want to do here is select the alpha to update that gives us
+ // the greatest improvement in marginal likelihood.
+ long selected_idx = -1;
+ scalar_type greatest_improvement = -1;
+ for (long i = 0; i < S.nr(); ++i)
+ {
+ scalar_type value = -1;
+
+ // if i is currently in the active set
+ if (active_bases(i) >= 0)
+ {
+ const long idx = active_bases(i);
+ const scalar_type s = alpha(idx)*S(i)/(alpha(idx) - S(i));
+ const scalar_type q = alpha(idx)*Q(i)/(alpha(idx) - S(i));
+
+ if (q*q-s > 0)
+ {
+ // only update an existing alpha if this is a narrow search
+ if (search_all_alphas == false)
+ {
+ // choosing this sample would mean doing an update of an
+ // existing alpha value.
+ scalar_type new_alpha = s*s/(q*q-s);
+ scalar_type cur_alpha = alpha(idx);
+ new_alpha = 1/new_alpha;
+ cur_alpha = 1/cur_alpha;
+
+ // from equation 32 in the Tipping paper
+ value = Q(i)*Q(i)/(S(i) + 1/(new_alpha - cur_alpha) ) -
+ std::log(1 + S(i)*(new_alpha - cur_alpha));
+ }
+
+ }
+ // we only pick an alpha to remove if this is a wide search and it wasn't one of the recently added ones
+ else if (search_all_alphas && idx+2 < alpha.size() )
+ {
+ // choosing this sample would mean the alpha value is infinite
+ // so we would remove the selected sample from our model.
+
+ // from equation 37 in the Tipping paper
+ value = Q(i)*Q(i)/(S(i) - alpha(idx)) -
+ std::log(1-S(i)/alpha(idx));
+
+ }
+ }
+ else if (search_all_alphas)
+ {
+ const scalar_type s = S(i);
+ const scalar_type q = Q(i);
+
+ if (q*q-s > 0)
+ {
+ // choosing this sample would mean we would add the selected
+ // sample to our model.
+
+ // from equation 27 in the Tipping paper
+ value = (Q(i)*Q(i)-S(i))/S(i) + std::log(S(i)/(Q(i)*Q(i)));
+ }
+ }
+
+ if (value > greatest_improvement)
+ {
+ greatest_improvement = value;
+ selected_idx = i;
+ }
+ }
+
+ // If the greatest_improvement in marginal likelihood we would get is less
+ // than our epsilon then report that there isn't anything else to do. But
+ // if it is big enough then return the selected_idx.
+ if (greatest_improvement > eps)
+ return selected_idx;
+ else
+ return -1;
+ }
+
+ } // end namespace rvm_helpers
+
+ // ------------------------------------------------------------------------------------
+
+
+ template <
+ typename kern_type
+ >
+ class rvm_trainer
+ {
+ /*!
+ This is an implementation of the binary classifier version of the
+ relevance vector machine algorithm described in the paper:
+ Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation
+ for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings
+ of the Ninth International Workshop on Artificial Intelligence and Statistics,
+ Key West, FL, Jan 3-6.
+
+ This code mostly does what is described in the above paper with the exception
+ that here we use a different stopping condition as well as a modified alpha
+ selection rule. See the code for the exact details.
+ !*/
+
+ public:
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rvm_trainer (
+ ) : eps(0.001), max_iterations(2000)
+ {
+ }
+
+ void set_max_iterations (
+ unsigned long max_iterations_
+ )
+ {
+ max_iterations = max_iterations_;
+ }
+
+ unsigned long get_max_iterations (
+ ) const
+ {
+ return max_iterations;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid rvm_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ return do_train(mat(x), mat(y));
+ }
+
+ void swap (
+ rvm_trainer& item
+ )
+ {
+ exchange(kernel, item.kernel);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+ "\tdecision_function rvm_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
+ );
+
+ // make a target vector where +1 examples have value 1 and -1 examples
+ // have a value of 0.
+ scalar_vector_type t(y.size());
+ for (long i = 0; i < y.size(); ++i)
+ {
+ if (y(i) == 1)
+ t(i) = 1;
+ else
+ t(i) = 0;
+ }
+
+ /*! This is the convention for the active_bases variable in the function:
+ - if (active_bases(i) >= 0) then
+ - alpha(active_bases(i)) == the alpha value associated with sample x(i)
+ - weights(active_bases(i)) == the weight value associated with sample x(i)
+ - colm(phi, active_bases(i)) == the column of phi associated with sample x(i)
+ - colm(phi, active_bases(i)) == kernel column i (from get_kernel_colum())
+ - else
+ - the i'th sample isn't in the model and notionally has an alpha of infinity and
+ a weight of 0.
+ !*/
+ matrix<long,0,1,mem_manager_type> active_bases(x.nr());
+ scalar_matrix_type phi(x.nr(),1);
+ scalar_vector_type alpha(1), prev_alpha;
+ scalar_vector_type weights(1), prev_weights;
+
+ scalar_vector_type tempv, K_col;
+
+ // set the initial values of these guys
+ set_all_elements(active_bases, -1);
+ long first_basis = pick_initial_vector(x,t);
+ get_kernel_colum(first_basis, x, K_col);
+ active_bases(first_basis) = 0;
+ set_colm(phi,0) = K_col;
+ alpha(0) = compute_initial_alpha(phi, t);
+ weights(0) = 1;
+
+
+ // now declare a bunch of other variables we will be using below
+ scalar_vector_type mu, t_hat, Q, S;
+ scalar_matrix_type sigma;
+
+ matrix<scalar_type,1,0,mem_manager_type> tempv2, tempv3;
+ scalar_matrix_type tempm;
+
+ scalar_vector_type t_estimate;
+ scalar_vector_type beta;
+
+
+ Q.set_size(x.nr());
+ S.set_size(x.nr());
+
+ bool recompute_beta = true;
+
+ bool search_all_alphas = false;
+ unsigned long ticker = 0;
+ const unsigned long rounds_of_narrow_search = 100;
+ unsigned long iterations = 0;
+
+ while (iterations != max_iterations)
+ {
+ iterations++;
+ if (recompute_beta)
+ {
+ // calculate the current t_estimate. (this is the predicted t value for each sample according to the
+ // current state of the classifier)
+ t_estimate = phi*weights;
+
+ // calculate the current beta
+ beta = sigmoid(t_estimate);
+ beta = pointwise_multiply(beta,(uniform_matrix<scalar_type>(beta.nr(),beta.nc(),1)-beta));
+ recompute_beta = false;
+ }
+
+ // Compute optimal weights and sigma for current alpha using IRLS. This is the same
+ // technique documented in the paper by equations 12-14.
+ scalar_type weight_delta = std::numeric_limits<scalar_type>::max();
+ int count = 0;
+ while (weight_delta > 0.0001)
+ {
+ // This is a sanity check to make sure we never get stuck in this
+ // loop to do some degenerate numerical condition
+ ++count;
+ if (count > 100)
+ {
+ // jump us to where search_all_alphas will be set to true
+ ticker = rounds_of_narrow_search;
+ break;
+ }
+
+ // compute the updated sigma matrix
+ sigma = scale_columns(trans(phi),beta)*phi;
+ for (long r = 0; r < alpha.nr(); ++r)
+ sigma(r,r) += alpha(r);
+ sigma = inv(sigma);
+
+
+ // compute the updated weights vector (t_hat = phi*mu_mp + inv(B)*(t-y))
+ t_hat = t_estimate + trans(scale_columns(trans(t-sigmoid(t_estimate)),reciprocal(beta)));
+
+ // mu = sigma*trans(phi)*b*t_hat
+ mu = sigma*tmp(trans(phi)* trans(scale_columns(trans(t_hat), beta)));
+
+ // now compute how much the weights vector changed during this iteration
+ // through this loop.
+ weight_delta = max(abs(mu-weights));
+
+ // put mu into the weights vector
+ mu.swap(weights);
+
+ // calculate the current t_estimate
+ t_estimate = phi*weights;
+
+ // calculate the current beta
+ beta = sigmoid(t_estimate);
+ beta = pointwise_multiply(beta, uniform_matrix<scalar_type>(beta.nr(),beta.nc(),1)-beta);
+
+ }
+
+ // check if we should do a full search for the best alpha to optimize
+ if (ticker >= rounds_of_narrow_search)
+ {
+ // if the current alpha and weights are equal to what they were
+ // at the last time we were about to start a wide search then
+ // we are done.
+ if (equal(prev_alpha, alpha, eps) && equal(prev_weights, weights, eps))
+ break;
+
+
+ prev_alpha = alpha;
+ prev_weights = weights;
+ search_all_alphas = true;
+ ticker = 0;
+ }
+ else
+ {
+ search_all_alphas = false;
+ }
+ ++ticker;
+
+ // compute S and Q using equations 24 and 25 (tempv = phi*sigma*trans(phi)*B*t_hat)
+ tempv = phi*tmp(sigma*tmp(trans(phi)*trans(scale_columns(trans(t_hat),beta))));
+ for (long i = 0; i < S.size(); ++i)
+ {
+ // if we are currently limiting the search for the next alpha to update
+ // to the set in the active set then skip a non-active vector.
+ if (search_all_alphas == false && active_bases(i) == -1)
+ continue;
+
+ // get the column for this sample out of the kernel matrix. If it is
+ // something in the active set then just get it right out of phi, otherwise
+ // we have to calculate it.
+ if (active_bases(i) != -1)
+ K_col = colm(phi,active_bases(i));
+ else
+ get_kernel_colum(i, x, K_col);
+
+ // tempv2 = trans(phi_m)*B
+ tempv2 = scale_columns(trans(K_col), beta);
+ tempv3 = tempv2*phi;
+ S(i) = tempv2*K_col - tempv3*sigma*trans(tempv3);
+ Q(i) = tempv2*t_hat - tempv2*tempv;
+ }
+
+ const long selected_idx = rvm_helpers::find_next_best_alpha_to_update(S,Q,alpha,active_bases, search_all_alphas, eps);
+
+
+ // if find_next_best_alpha_to_update didn't find any good alpha to update
+ if (selected_idx == -1)
+ {
+ if (search_all_alphas == false)
+ {
+ // jump us to where search_all_alphas will be set to true and try again
+ ticker = rounds_of_narrow_search;
+ continue;
+ }
+ else
+ {
+ // we are really done so quit the main loop
+ break;
+ }
+ }
+
+
+ // next we update the selected alpha.
+
+ // if the selected alpha is in the active set
+ if (active_bases(selected_idx) >= 0)
+ {
+ const long idx = active_bases(selected_idx);
+ const scalar_type s = alpha(idx)*S(selected_idx)/(alpha(idx) - S(selected_idx));
+ const scalar_type q = alpha(idx)*Q(selected_idx)/(alpha(idx) - S(selected_idx));
+
+ if (q*q-s > 0)
+ {
+ // reestimate the value of alpha
+ alpha(idx) = s*s/(q*q-s);
+
+ }
+ else
+ {
+ // the new alpha value is infinite so remove the selected alpha from our model
+ active_bases(selected_idx) = -1;
+ phi = remove_col(phi, idx);
+ weights = remove_row(weights, idx);
+ alpha = remove_row(alpha, idx);
+
+ // fix the index values in active_bases
+ for (long i = 0; i < active_bases.size(); ++i)
+ {
+ if (active_bases(i) > idx)
+ {
+ active_bases(i) -= 1;
+ }
+ }
+
+ // we changed the number of weights so we need to remember to
+ // recompute the beta vector next time around the main loop.
+ recompute_beta = true;
+ }
+ }
+ else
+ {
+ const scalar_type s = S(selected_idx);
+ const scalar_type q = Q(selected_idx);
+
+ if (q*q-s > 0)
+ {
+ // add the selected alpha to our model
+
+ active_bases(selected_idx) = phi.nc();
+
+ // update alpha
+ tempv.set_size(alpha.size()+1);
+ set_subm(tempv, get_rect(alpha)) = alpha;
+ tempv(phi.nc()) = s*s/(q*q-s);
+ tempv.swap(alpha);
+
+ // update weights
+ tempv.set_size(weights.size()+1);
+ set_subm(tempv, get_rect(weights)) = weights;
+ tempv(phi.nc()) = 0;
+ tempv.swap(weights);
+
+ // update phi by adding the new sample's kernel matrix column in as one of phi's columns
+ tempm.set_size(phi.nr(), phi.nc()+1);
+ set_subm(tempm, get_rect(phi)) = phi;
+ get_kernel_colum(selected_idx, x, K_col);
+ set_colm(tempm, phi.nc()) = K_col;
+ tempm.swap(phi);
+
+
+ // we changed the number of weights so we need to remember to
+ // recompute the beta vector next time around the main loop.
+ recompute_beta = true;
+ }
+ }
+
+ } // end while(true). So we have converged on the final answer.
+
+
+ // now put everything into a decision_function object and return it
+ std_vector_c<sample_type> dictionary;
+ std_vector_c<scalar_type> final_weights;
+ for (long i = 0; i < active_bases.size(); ++i)
+ {
+ if (active_bases(i) >= 0)
+ {
+ dictionary.push_back(x(i));
+ final_weights.push_back(weights(active_bases(i)));
+ }
+ }
+
+ return decision_function<kernel_type> ( mat(final_weights),
+ -sum(mat(final_weights))*tau,
+ kernel,
+ mat(dictionary));
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ long pick_initial_vector (
+ const M1& x,
+ const M2& t
+ ) const
+ {
+ scalar_vector_type K_col;
+ double max_projection = -std::numeric_limits<scalar_type>::infinity();
+ long max_idx = 0;
+ // find the row in the kernel matrix that has the biggest normalized projection onto the t vector
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ get_kernel_colum(r,x,K_col);
+ double temp = trans(K_col)*t;
+ temp = temp*temp/length_squared(K_col);
+
+ if (temp > max_projection)
+ {
+ max_projection = temp;
+ max_idx = r;
+ }
+ }
+
+ return max_idx;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ void get_kernel_colum (
+ long idx,
+ const T& x,
+ scalar_vector_type& col
+ ) const
+ {
+ col.set_size(x.nr());
+ for (long i = 0; i < col.size(); ++i)
+ {
+ col(i) = kernel(x(idx), x(i)) + tau;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ scalar_type compute_initial_alpha (
+ const M1& phi,
+ const M2& t
+ ) const
+ {
+ const double temp = length_squared(phi);
+ const double temp2 = trans(phi)*t;
+
+ return temp/( temp2*temp2/temp + variance(t)*0.1);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // private member variables
+ kernel_type kernel;
+ scalar_type eps;
+ unsigned long max_iterations;
+
+ const static scalar_type tau;
+
+ }; // end of class rvm_trainer
+
+ template <typename kernel_type>
+ const typename kernel_type::scalar_type rvm_trainer<kernel_type>::tau = static_cast<typename kernel_type::scalar_type>(0.001);
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rvm_trainer<K>& a,
+ rvm_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ class rvm_regression_trainer
+ {
+ /*!
+ This is an implementation of the regression version of the
+ relevance vector machine algorithm described in the paper:
+ Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation
+ for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings
+ of the Ninth International Workshop on Artificial Intelligence and Statistics,
+ Key West, FL, Jan 3-6.
+
+ This code mostly does what is described in the above paper with the exception
+ that here we use a different stopping condition as well as a modified alpha
+ selection rule. See the code for the exact details.
+ !*/
+
+ public:
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rvm_regression_trainer (
+ ) : eps(0.001)
+ {
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid rvm_regression_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& t
+ ) const
+ {
+ return do_train(mat(x), mat(t));
+ }
+
+ void swap (
+ rvm_regression_trainer& item
+ )
+ {
+ exchange(kernel, item.kernel);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+ typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& t
+ ) const
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,t) && x.size() > 0,
+ "\tdecision_function rvm_regression_trainer::train(x,t)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t t.nr(): " << t.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t t.nc(): " << t.nc()
+ );
+
+
+ /*! This is the convention for the active_bases variable in the function:
+ - if (active_bases(i) >= 0) then
+ - alpha(active_bases(i)) == the alpha value associated with sample x(i)
+ - weights(active_bases(i)) == the weight value associated with sample x(i)
+ - colm(phi, active_bases(i)) == the column of phi associated with sample x(i)
+ - colm(phi, active_bases(i)) == kernel column i (from get_kernel_colum())
+ - else
+ - the i'th sample isn't in the model and notionally has an alpha of infinity and
+ a weight of 0.
+ !*/
+ matrix<long,0,1,mem_manager_type> active_bases(x.nr());
+ scalar_matrix_type phi(x.nr(),1);
+ scalar_vector_type alpha(1), prev_alpha;
+ scalar_vector_type weights(1), prev_weights;
+
+ scalar_vector_type tempv, K_col;
+ scalar_type var = variance(t)*0.1;
+
+ // set the initial values of these guys
+ set_all_elements(active_bases, -1);
+ long first_basis = pick_initial_vector(x,t);
+ get_kernel_colum(first_basis, x, K_col);
+ active_bases(first_basis) = 0;
+ set_colm(phi,0) = K_col;
+ alpha(0) = compute_initial_alpha(phi, t, var);
+ weights(0) = 1;
+
+
+ // now declare a bunch of other variables we will be using below
+ scalar_vector_type Q, S;
+ scalar_matrix_type sigma;
+
+ matrix<scalar_type,1,0,mem_manager_type> tempv2, tempv3;
+ scalar_matrix_type tempm;
+
+
+ Q.set_size(x.nr());
+ S.set_size(x.nr());
+
+
+ bool search_all_alphas = false;
+ unsigned long ticker = 0;
+ const unsigned long rounds_of_narrow_search = 100;
+
+ while (true)
+ {
+ // Compute optimal weights and sigma for current alpha using equation 6.
+ sigma = trans(phi)*phi/var;
+ for (long r = 0; r < alpha.nr(); ++r)
+ sigma(r,r) += alpha(r);
+ sigma = inv(sigma);
+ weights = sigma*trans(phi)*t/var;
+
+
+
+ // check if we should do a full search for the best alpha to optimize
+ if (ticker == rounds_of_narrow_search)
+ {
+ // if the current alpha and weights are equal to what they were
+ // at the last time we were about to start a wide search then
+ // we are done.
+ if (equal(prev_alpha, alpha, eps) && equal(prev_weights, weights, eps))
+ break;
+
+ prev_alpha = alpha;
+ prev_weights = weights;
+ search_all_alphas = true;
+ ticker = 0;
+ }
+ else
+ {
+ search_all_alphas = false;
+ }
+ ++ticker;
+
+ // compute S and Q using equations 24 and 25 (tempv = phi*sigma*trans(phi)*B*t)
+ tempv = phi*tmp(sigma*tmp(trans(phi)*t/var));
+ for (long i = 0; i < S.size(); ++i)
+ {
+ // if we are currently limiting the search for the next alpha to update
+ // to the set in the active set then skip a non-active vector.
+ if (search_all_alphas == false && active_bases(i) == -1)
+ continue;
+
+ // get the column for this sample out of the kernel matrix. If it is
+ // something in the active set then just get it right out of phi, otherwise
+ // we have to calculate it.
+ if (active_bases(i) != -1)
+ K_col = colm(phi,active_bases(i));
+ else
+ get_kernel_colum(i, x, K_col);
+
+ // tempv2 = trans(phi_m)*B
+ tempv2 = trans(K_col)/var;
+ tempv3 = tempv2*phi;
+ S(i) = tempv2*K_col - tempv3*sigma*trans(tempv3);
+ Q(i) = tempv2*t - tempv2*tempv;
+ }
+
+ const long selected_idx = rvm_helpers::find_next_best_alpha_to_update(S,Q,alpha,active_bases, search_all_alphas, eps);
+
+ // if find_next_best_alpha_to_update didn't find any good alpha to update
+ if (selected_idx == -1)
+ {
+ if (search_all_alphas == false)
+ {
+ // jump us to where search_all_alphas will be set to true and try again
+ ticker = rounds_of_narrow_search;
+ continue;
+ }
+ else
+ {
+ // we are really done so quit the main loop
+ break;
+ }
+ }
+
+ // recompute the variance
+ var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma));
+
+ // next we update the selected alpha.
+
+ // if the selected alpha is in the active set
+ if (active_bases(selected_idx) >= 0)
+ {
+ const long idx = active_bases(selected_idx);
+ const scalar_type s = alpha(idx)*S(selected_idx)/(alpha(idx) - S(selected_idx));
+ const scalar_type q = alpha(idx)*Q(selected_idx)/(alpha(idx) - S(selected_idx));
+
+ if (q*q-s > 0)
+ {
+ // reestimate the value of alpha
+ alpha(idx) = s*s/(q*q-s);
+
+ }
+ else
+ {
+ // the new alpha value is infinite so remove the selected alpha from our model
+ active_bases(selected_idx) = -1;
+ phi = remove_col(phi, idx);
+ weights = remove_row(weights, idx);
+ alpha = remove_row(alpha, idx);
+
+ // fix the index values in active_bases
+ for (long i = 0; i < active_bases.size(); ++i)
+ {
+ if (active_bases(i) > idx)
+ {
+ active_bases(i) -= 1;
+ }
+ }
+ }
+ }
+ else
+ {
+ const scalar_type s = S(selected_idx);
+ const scalar_type q = Q(selected_idx);
+
+ if (q*q-s > 0)
+ {
+ // add the selected alpha to our model
+
+ active_bases(selected_idx) = phi.nc();
+
+ // update alpha
+ tempv.set_size(alpha.size()+1);
+ set_subm(tempv, get_rect(alpha)) = alpha;
+ tempv(phi.nc()) = s*s/(q*q-s);
+ tempv.swap(alpha);
+
+ // update weights
+ tempv.set_size(weights.size()+1);
+ set_subm(tempv, get_rect(weights)) = weights;
+ tempv(phi.nc()) = 0;
+ tempv.swap(weights);
+
+ // update phi by adding the new sample's kernel matrix column in as one of phi's columns
+ tempm.set_size(phi.nr(), phi.nc()+1);
+ set_subm(tempm, get_rect(phi)) = phi;
+ get_kernel_colum(selected_idx, x, K_col);
+ set_colm(tempm, phi.nc()) = K_col;
+ tempm.swap(phi);
+
+ }
+ }
+
+
+
+ } // end while(true). So we have converged on the final answer.
+
+
+ // now put everything into a decision_function object and return it
+ std_vector_c<sample_type> dictionary;
+ std_vector_c<scalar_type> final_weights;
+ for (long i = 0; i < active_bases.size(); ++i)
+ {
+ if (active_bases(i) >= 0)
+ {
+ dictionary.push_back(x(i));
+ final_weights.push_back(weights(active_bases(i)));
+ }
+ }
+
+ return decision_function<kernel_type> ( mat(final_weights),
+ -sum(mat(final_weights))*tau,
+ kernel,
+ mat(dictionary));
+
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ void get_kernel_colum (
+ long idx,
+ const T& x,
+ scalar_vector_type& col
+ ) const
+ {
+ col.set_size(x.nr());
+ for (long i = 0; i < col.size(); ++i)
+ {
+ col(i) = kernel(x(idx), x(i)) + tau;
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ scalar_type compute_initial_alpha (
+ const M1& phi,
+ const M2& t,
+ const scalar_type& var
+ ) const
+ {
+ const double temp = length_squared(phi);
+ const double temp2 = trans(phi)*t;
+
+ return temp/( temp2*temp2/temp + var);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename M1, typename M2>
+ long pick_initial_vector (
+ const M1& x,
+ const M2& t
+ ) const
+ {
+ scalar_vector_type K_col;
+ double max_projection = -std::numeric_limits<scalar_type>::infinity();
+ long max_idx = 0;
+ // find the row in the kernel matrix that has the biggest normalized projection onto the t vector
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ get_kernel_colum(r,x,K_col);
+ double temp = trans(K_col)*t;
+ temp = temp*temp/length_squared(K_col);
+
+ if (temp > max_projection)
+ {
+ max_projection = temp;
+ max_idx = r;
+ }
+ }
+
+ return max_idx;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // private member variables
+ kernel_type kernel;
+ scalar_type eps;
+
+ const static scalar_type tau;
+
+ }; // end of class rvm_regression_trainer
+
+ template <typename kernel_type>
+ const typename kernel_type::scalar_type rvm_regression_trainer<kernel_type>::tau = static_cast<typename kernel_type::scalar_type>(0.001);
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rvm_regression_trainer<K>& a,
+ rvm_regression_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RVm_
+
+
diff --git a/ml/dlib/dlib/svm/rvm_abstract.h b/ml/dlib/dlib/svm/rvm_abstract.h
new file mode 100644
index 000000000..236d2ad3c
--- /dev/null
+++ b/ml/dlib/dlib/svm/rvm_abstract.h
@@ -0,0 +1,278 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RVm_ABSTRACT_
+#ifdef DLIB_RVm_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ class rvm_trainer
+ {
+ /*!
+ REQUIREMENTS ON kern_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a relevance vector machine for
+ solving binary classification problems.
+
+ The implementation of the RVM training algorithm used by this object is based
+ on the following excellent paper:
+ Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation
+ for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings
+ of the Ninth International Workshop on Artificial Intelligence and Statistics,
+ Key West, FL, Jan 3-6.
+ !*/
+
+ public:
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rvm_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a relevance vector machine.
+ - #get_epsilon() == 0.001
+ - #get_max_iterations() == 2000
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations the RVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - trains a relevance vector classifier given the training samples in x and
+ labels in y.
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ throws
+ - std::bad_alloc
+ !*/
+
+ void swap (
+ rvm_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rvm_trainer<K>& a,
+ rvm_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kern_type
+ >
+ class rvm_regression_trainer
+ {
+ /*!
+ REQUIREMENTS ON kern_type
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a relevance vector machine for
+ solving regression problems.
+
+ The implementation of the RVM training algorithm used by this object is based
+ on the following excellent paper:
+ Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation
+ for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings
+ of the Ninth International Workshop on Artificial Intelligence and Statistics,
+ Key West, FL, Jan 3-6.
+ !*/
+
+ public:
+ typedef kern_type kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rvm_regression_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a relevance vector machine.
+ - #get_epsilon() == 0.001
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - is_learning_problem(x,y) == true
+ - x.size() > 0
+ ensures
+ - trains a RVM given the training samples in x and
+ labels in y and returns the resulting decision_function.
+ throws
+ - std::bad_alloc
+ !*/
+
+ void swap (
+ rvm_regression_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rvm_regression_trainer<K>& a,
+ rvm_regression_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RVm_ABSTRACT_
+
diff --git a/ml/dlib/dlib/svm/sequence_labeler.h b/ml/dlib/dlib/svm/sequence_labeler.h
new file mode 100644
index 000000000..882cdb881
--- /dev/null
+++ b/ml/dlib/dlib/svm/sequence_labeler.h
@@ -0,0 +1,339 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_LAbELER_H_h_
+#define DLIB_SEQUENCE_LAbELER_H_h_
+
+#include "sequence_labeler_abstract.h"
+#include "../matrix.h"
+#include <vector>
+#include "../optimization/find_max_factor_graph_viterbi.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace fe_helpers
+ {
+ template <typename EXP>
+ struct dot_functor
+ {
+ dot_functor(const matrix_exp<EXP>& lambda_) : lambda(lambda_), value(0) {}
+
+ inline void operator() (
+ unsigned long feat_index
+ )
+ {
+ value += lambda(feat_index);
+ }
+
+ inline void operator() (
+ unsigned long feat_index,
+ double feat_value
+ )
+ {
+ value += feat_value*lambda(feat_index);
+ }
+
+ const matrix_exp<EXP>& lambda;
+ double value;
+ };
+
+ template <typename feature_extractor, typename EXP, typename sequence_type, typename EXP2>
+ double dot(
+ const matrix_exp<EXP>& lambda,
+ const feature_extractor& fe,
+ const sequence_type& sequence,
+ const matrix_exp<EXP2>& candidate_labeling,
+ unsigned long position
+ )
+ {
+ dot_functor<EXP> dot(lambda);
+ fe.get_features(dot, sequence, candidate_labeling, position);
+ return dot.value;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(
+ has_reject_labeling,
+ bool,
+ template reject_labeling<matrix<unsigned long> >,
+ (const typename T::sequence_type&, const matrix_exp<matrix<unsigned long> >&, unsigned long)const
+ );
+
+ template <typename feature_extractor, typename EXP, typename sequence_type>
+ typename enable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
+ const feature_extractor& fe,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ )
+ {
+ return fe.reject_labeling(x, y, position);
+ }
+
+ template <typename feature_extractor, typename EXP, typename sequence_type>
+ typename disable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
+ const feature_extractor& ,
+ const sequence_type& ,
+ const matrix_exp<EXP>& ,
+ unsigned long
+ )
+ {
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ typename enable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
+ const feature_extractor& fe,
+ const typename feature_extractor::sequence_type& x,
+ const std::vector<unsigned long>& y
+ )
+ {
+ if (x.size() != y.size())
+ return true;
+
+ matrix<unsigned long,0,1> node_states;
+
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ node_states.set_size(std::min(fe.order(),i) + 1);
+ for (unsigned long j = 0; j < (unsigned long)node_states.size(); ++j)
+ node_states(j) = y[i-j];
+
+ if (fe.reject_labeling(x, node_states, i))
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ typename disable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
+ const feature_extractor& ,
+ const typename feature_extractor::sequence_type& x,
+ const std::vector<unsigned long>& y
+ )
+ {
+ if (x.size() != y.size())
+ return true;
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ bool contains_invalid_labeling (
+ const feature_extractor& fe,
+ const std::vector<typename feature_extractor::sequence_type>& x,
+ const std::vector<std::vector<unsigned long> >& y
+ )
+ {
+ if (x.size() != y.size())
+ return true;
+
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ if (contains_invalid_labeling(fe,x[i],y[i]))
+ return true;
+ }
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class sequence_labeler
+ {
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<unsigned long> labeled_sequence_type;
+
+ private:
+ class map_prob
+ {
+ public:
+ unsigned long order() const { return fe.order(); }
+ unsigned long num_states() const { return fe.num_labels(); }
+
+ map_prob(
+ const sample_sequence_type& x_,
+ const feature_extractor& fe_,
+ const matrix<double,0,1>& weights_
+ ) :
+ sequence(x_),
+ fe(fe_),
+ weights(weights_)
+ {
+ }
+
+ unsigned long number_of_nodes(
+ ) const
+ {
+ return sequence.size();
+ }
+
+ template <
+ typename EXP
+ >
+ double factor_value (
+ unsigned long node_id,
+ const matrix_exp<EXP>& node_states
+ ) const
+ {
+ if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
+ return -std::numeric_limits<double>::infinity();
+
+ return fe_helpers::dot(weights, fe, sequence, node_states, node_id);
+ }
+
+ const sample_sequence_type& sequence;
+ const feature_extractor& fe;
+ const matrix<double,0,1>& weights;
+ };
+ public:
+
+ sequence_labeler()
+ {
+ weights.set_size(fe.num_features());
+ weights = 0;
+ }
+
+ explicit sequence_labeler(
+ const matrix<double,0,1>& weights_
+ ) :
+ weights(weights_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
+ "\t sequence_labeler::sequence_labeler(weights_)"
+ << "\n\t These sizes should match"
+ << "\n\t fe.num_features(): " << fe.num_features()
+ << "\n\t weights_.size(): " << weights_.size()
+ << "\n\t this: " << this
+ );
+ }
+
+ sequence_labeler(
+ const matrix<double,0,1>& weights_,
+ const feature_extractor& fe_
+ ) :
+ fe(fe_),
+ weights(weights_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
+ "\t sequence_labeler::sequence_labeler(weights_,fe_)"
+ << "\n\t These sizes should match"
+ << "\n\t fe_.num_features(): " << fe_.num_features()
+ << "\n\t weights_.size(): " << weights_.size()
+ << "\n\t this: " << this
+ );
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ const matrix<double,0,1>& get_weights (
+ ) const { return weights; }
+
+ unsigned long num_labels (
+ ) const { return fe.num_labels(); }
+
+ labeled_sequence_type operator() (
+ const sample_sequence_type& x
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_labels() > 0,
+ "\t labeled_sequence_type sequence_labeler::operator()(x)"
+ << "\n\t You can't have no labels."
+ << "\n\t this: " << this
+ );
+
+ labeled_sequence_type y;
+ find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
+ return y;
+ }
+
+ void label_sequence (
+ const sample_sequence_type& x,
+ labeled_sequence_type& y
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(num_labels() > 0,
+ "\t void sequence_labeler::label_sequence(x,y)"
+ << "\n\t You can't have no labels."
+ << "\n\t this: " << this
+ );
+
+ find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
+ }
+
+ private:
+
+ feature_extractor fe;
+ matrix<double,0,1> weights;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void serialize (
+ const sequence_labeler<feature_extractor>& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.get_feature_extractor(), out);
+ serialize(item.get_weights(), out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void deserialize (
+ sequence_labeler<feature_extractor>& item,
+ std::istream& in
+ )
+ {
+ feature_extractor fe;
+ matrix<double,0,1> weights;
+
+ deserialize(fe, in);
+ deserialize(weights, in);
+
+ item = sequence_labeler<feature_extractor>(weights, fe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_LAbELER_H_h_
+
diff --git a/ml/dlib/dlib/svm/sequence_labeler_abstract.h b/ml/dlib/dlib/svm/sequence_labeler_abstract.h
new file mode 100644
index 000000000..3970b723a
--- /dev/null
+++ b/ml/dlib/dlib/svm/sequence_labeler_abstract.h
@@ -0,0 +1,396 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_
+#ifdef DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_
+
+#include "../matrix.h"
+#include <vector>
+#include "../optimization/find_max_factor_graph_viterbi_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class example_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a feature extractor must implement
+ if it is to be used with the sequence_labeler defined at the bottom
+ of this file.
+
+ The model used by sequence_labeler objects is the following.
+ Given an input sequence x, predict an output label sequence y
+ such that:
+ y == argmax_Y dot(w, PSI(x,Y))
+ Where w is a parameter vector.
+
+ Therefore, a feature extractor defines how the PSI(x,y) feature vector
+ is calculated. It also defines how many output labels there are as
+ well as the order of the model.
+
+ Finally, note that PSI(x,y) is a sum of feature vectors, each derived
+ from the entire input sequence x but only part of the label sequence y.
+ Each of these constituent feature vectors is defined by the get_features()
+ method of this class.
+
+ THREAD SAFETY
+ Instances of this object are required to be threadsafe, that is, it should
+ be safe for multiple threads to make concurrent calls to the member
+ functions of this object.
+ !*/
+
+ public:
+ // This should be the type used to represent an input sequence. It can be
+ // anything so long as it has a .size() which returns the length of the sequence.
+ typedef the_type_used_to_represent_a_sequence sequence_type;
+
+ example_feature_extractor (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ unsigned long num_features (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the PSI() feature vector.
+ !*/
+
+ unsigned long order(
+ ) const;
+ /*!
+ ensures
+ - This object represents a Markov model on the output labels.
+ This parameter defines the order of the model. That is, this
+ value controls how many previous label values get to be taken
+ into consideration when performing feature extraction for a
+ particular element of the input sequence. Note that the runtime
+ of the algorithm is exponential in the order. So don't make order
+ very large.
+ !*/
+
+ unsigned long num_labels(
+ ) const;
+ /*!
+ ensures
+ - returns the number of possible output labels.
+ !*/
+
+ template <typename EXP>
+ bool reject_labeling (
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const;
+ /*!
+ requires
+ - EXP::type == unsigned long
+ (i.e. y contains unsigned longs)
+ - position < x.size()
+ - y.size() == min(position, order()) + 1
+ - is_vector(y) == true
+ - max(y) < num_labels()
+ ensures
+ - for all valid i:
+ - interprets y(i) as the label corresponding to x[position-i]
+ - if (the labeling in y for x[position] is always the wrong labeling) then
+ - returns true
+ (note that reject_labeling() is just an optional tool to allow you
+ to overrule the normal labeling algorithm. You don't have to use
+ it. So if you don't include a reject_labeling() method in your
+ feature extractor it is the same as including one that always
+ returns false.)
+ - else
+ - returns false
+ !*/
+
+ template <typename feature_setter, typename EXP>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const;
+ /*!
+ requires
+ - EXP::type == unsigned long
+ (i.e. y contains unsigned longs)
+ - reject_labeling(x,y,position) == false
+ - position < x.size()
+ - y.size() == min(position, order()) + 1
+ - is_vector(y) == true
+ - max(y) < num_labels()
+ - set_feature is a function object which allows expressions of the form:
+ - set_features((unsigned long)feature_index, (double)feature_value);
+ - set_features((unsigned long)feature_index);
+ ensures
+ - for all valid i:
+ - interprets y(i) as the label corresponding to x[position-i]
+ - This function computes the part of PSI() corresponding to the x[position]
+ element of the input sequence. Moreover, this part of PSI() is returned as
+ a sparse vector by invoking set_feature(). For example, to set the feature
+ with an index of 55 to the value of 1 this method would call:
+ set_feature(55);
+ Or equivalently:
+ set_feature(55,1);
+ Therefore, the first argument to set_feature is the index of the feature
+ to be set while the second argument is the value the feature should take.
+ Additionally, note that calling set_feature() multiple times with the same
+ feature index does NOT overwrite the old value, it adds to the previous
+ value. For example, if you call set_feature(55) 3 times then it will
+ result in feature 55 having a value of 3.
+ - This function only calls set_feature() with feature_index values < num_features()
+ !*/
+
+ unsigned long num_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the number of elements of the w parameter vector which should be
+ non-negative. That is, this feature extractor is intended to be used
+ with w vectors where the first num_nonnegative_weights() elements of w
+ are >= 0. That is, it should be the case that w(i) >= 0 for all i <
+ num_nonnegative_weights().
+ - Note that num_nonnegative_weights() is just an optional method to allow
+ you to tell a tool like the structural_sequence_labeling_trainer that the
+ learned w should have a certain number of non-negative elements.
+ Therefore, if you do not provide a num_nonnegative_weights() method in
+ your feature extractor then it will default to a value of 0, indicating
+ that all elements of the w parameter vector may be any value.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const example_feature_extractor& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize(
+ example_feature_extractor& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ bool contains_invalid_labeling (
+ const feature_extractor& fe,
+ const typename feature_extractor::sequence_type& x,
+ const std::vector<unsigned long>& y
+ );
+ /*!
+ requires
+ - feature_extractor must be an object that implements an interface compatible
+ with the example_feature_extractor discussed above.
+ ensures
+ - if (x.size() != y.size() ||
+ fe.reject_labeling() rejects any of the labels in y) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ bool contains_invalid_labeling (
+ const feature_extractor& fe,
+ const std::vector<typename feature_extractor::sequence_type>& x,
+ const std::vector<std::vector<unsigned long> >& y
+ );
+ /*!
+ requires
+ - feature_extractor must be an object that implements an interface compatible
+ with the example_feature_extractor discussed above.
+ ensures
+ - if (x.size() != y.size() ||
+ contains_invalid_labeling(fe,x[i],y[i]) == true for some i ) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class sequence_labeler
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor discussed above.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for doing sequence labeling. In particular, it is
+ capable of representing sequence labeling models such as those produced by
+ Hidden Markov SVMs or Chain Structured Conditional Random fields. See the
+ following papers for an introduction to these techniques:
+ - Hidden Markov Support Vector Machines by
+ Y. Altun, I. Tsochantaridis, T. Hofmann
+ - Shallow Parsing with Conditional Random Fields by
+ Fei Sha and Fernando Pereira
+
+
+ The model used by this object is the following. Given
+ an input sequence x, predict an output label sequence y
+ such that:
+ y == argmax_Y dot(get_weights(), PSI(x,Y))
+ Where PSI() is defined by the feature_extractor template
+ argument.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as the feature_extractor is also threadsafe. This is
+ because the const members are purely read-only operations. However,
+ any operation that modifies a sequence_labeler is not threadsafe.
+ !*/
+
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<unsigned long> labeled_sequence_type;
+
+ sequence_labeler(
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights().size() == #get_feature_extractor().num_features()
+ - #get_weights() == 0
+ !*/
+
+ explicit sequence_labeler(
+ const matrix<double,0,1>& weights
+ );
+ /*!
+ requires
+ - feature_extractor().num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights() == weights
+ !*/
+
+ sequence_labeler(
+ const matrix<double,0,1>& weights,
+ const feature_extractor& fe
+ );
+ /*!
+ requires
+ - fe.num_features() == weights.size()
+ ensures
+ - #get_feature_extractor() == fe
+ - #get_weights() == weights
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ const matrix<double,0,1>& get_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the parameter vector associated with this sequence labeler.
+ The length of the vector is get_feature_extractor().num_features().
+ !*/
+
+ unsigned long num_labels (
+ ) const;
+ /*!
+ ensures
+ - returns get_feature_extractor().num_labels()
+ (i.e. returns the number of possible output labels for each
+ element of a sequence)
+ !*/
+
+ labeled_sequence_type operator() (
+ const sample_sequence_type& x
+ ) const;
+ /*!
+ requires
+ - num_labels() > 0
+ ensures
+ - returns a vector Y of label values such that:
+ - Y.size() == x.size()
+ - for all valid i:
+ - Y[i] == the predicted label for x[i]
+ - 0 <= Y[i] < num_labels()
+ !*/
+
+ void label_sequence (
+ const sample_sequence_type& x,
+ labeled_sequence_type& y
+ ) const;
+ /*!
+ requires
+ - num_labels() > 0
+ ensures
+ - #y == (*this)(x)
+ (i.e. This is just another interface to the operator() routine
+ above. This one avoids returning the results by value and therefore
+ might be a little faster in some cases)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void serialize (
+ const sequence_labeler<feature_extractor>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void deserialize (
+ sequence_labeler<feature_extractor>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_
+
+
diff --git a/ml/dlib/dlib/svm/sequence_segmenter.h b/ml/dlib/dlib/svm/sequence_segmenter.h
new file mode 100644
index 000000000..237023efa
--- /dev/null
+++ b/ml/dlib/dlib/svm/sequence_segmenter.h
@@ -0,0 +1,468 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SEQUENCE_SeGMENTER_H_h_
+#define DLIB_SEQUENCE_SeGMENTER_H_h_
+
+#include "sequence_segmenter_abstract.h"
+#include "../matrix.h"
+#include "sequence_labeler.h"
+#include <vector>
+
+namespace dlib
+{
+ // This namespace contains implementation details for the sequence_segmenter.
+ namespace impl_ss
+ {
+
+ // ------------------------------------------------------------------------------------
+
+ // BIO/BILOU labels
+ const unsigned int BEGIN = 0;
+ const unsigned int INSIDE = 1;
+ const unsigned int OUTSIDE = 2;
+ const unsigned int LAST = 3;
+ const unsigned int UNIT = 4;
+
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename ss_feature_extractor>
+ class feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a feature extractor for a sequence_labeler. It serves to map
+ the interface defined by a sequence_labeler into the kind of interface
+ defined for a sequence_segmenter.
+ !*/
+
+ public:
+ typedef typename ss_feature_extractor::sequence_type sequence_type;
+
+ ss_feature_extractor fe;
+
+ feature_extractor() {}
+ feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {}
+
+ unsigned long num_nonnegative_weights (
+ ) const
+ {
+ const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5;
+ if (ss_feature_extractor::allow_negative_weights)
+ {
+ return 0;
+ }
+ else
+ {
+ // We make everything non-negative except for the label transition
+ // and bias features.
+ return num_features() - NL*NL - NL;
+ }
+ }
+
+ friend void serialize(const feature_extractor& item, std::ostream& out)
+ {
+ serialize(item.fe, out);
+ }
+
+ friend void deserialize(feature_extractor& item, std::istream& in)
+ {
+ deserialize(item.fe, in);
+ }
+
+ unsigned long num_features() const
+ {
+ const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5;
+ if (ss_feature_extractor::use_high_order_features)
+ return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
+ else
+ return NL + NL*NL + NL*fe.num_features()*fe.window_size();
+ }
+
+ unsigned long order() const
+ {
+ return 1;
+ }
+
+ unsigned long num_labels() const
+ {
+ if (ss_feature_extractor::use_BIO_model)
+ return 3;
+ else
+ return 5;
+ }
+
+ private:
+
+ template <typename feature_setter>
+ struct dot_functor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This class wraps the feature_setter used by a sequence_labeler
+ and turns it into the kind needed by a sequence_segmenter.
+ !*/
+
+ dot_functor(feature_setter& set_feature_, unsigned long offset_) :
+ set_feature(set_feature_), offset(offset_) {}
+
+ feature_setter& set_feature;
+ unsigned long offset;
+
+ inline void operator() (
+ unsigned long feat_index
+ )
+ {
+ set_feature(offset+feat_index);
+ }
+
+ inline void operator() (
+ unsigned long feat_index,
+ double feat_value
+ )
+ {
+ set_feature(offset+feat_index, feat_value);
+ }
+ };
+
+ public:
+
+ template <typename EXP>
+ bool reject_labeling (
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long pos
+ ) const
+ {
+ if (ss_feature_extractor::use_BIO_model)
+ {
+ // Don't allow BIO label patterns that don't correspond to a sensical
+ // segmentation.
+ if (y.size() > 1 && y(0) == INSIDE && y(1) == OUTSIDE)
+ return true;
+ if (y.size() == 1 && y(0) == INSIDE)
+ return true;
+ }
+ else
+ {
+ // Don't allow BILOU label patterns that don't correspond to a sensical
+ // segmentation.
+ if (y.size() > 1)
+ {
+ if (y(1) == BEGIN && y(0) == OUTSIDE)
+ return true;
+ if (y(1) == BEGIN && y(0) == UNIT)
+ return true;
+ if (y(1) == BEGIN && y(0) == BEGIN)
+ return true;
+
+ if (y(1) == INSIDE && y(0) == BEGIN)
+ return true;
+ if (y(1) == INSIDE && y(0) == OUTSIDE)
+ return true;
+ if (y(1) == INSIDE && y(0) == UNIT)
+ return true;
+
+ if (y(1) == OUTSIDE && y(0) == INSIDE)
+ return true;
+ if (y(1) == OUTSIDE && y(0) == LAST)
+ return true;
+
+ if (y(1) == LAST && y(0) == INSIDE)
+ return true;
+ if (y(1) == LAST && y(0) == LAST)
+ return true;
+
+ if (y(1) == UNIT && y(0) == INSIDE)
+ return true;
+ if (y(1) == UNIT && y(0) == LAST)
+ return true;
+
+ // if at the end of the sequence
+ if (pos == x.size()-1)
+ {
+ if (y(0) == BEGIN)
+ return true;
+ if (y(0) == INSIDE)
+ return true;
+ }
+ }
+ else
+ {
+ if (y(0) == INSIDE)
+ return true;
+ if (y(0) == LAST)
+ return true;
+
+ // if at the end of the sequence
+ if (pos == x.size()-1)
+ {
+ if (y(0) == BEGIN)
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ template <typename feature_setter, typename EXP>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const
+ {
+ unsigned long offset = 0;
+
+ const int window_size = fe.window_size();
+
+ const int base_dims = fe.num_features();
+ for (int i = 0; i < window_size; ++i)
+ {
+ const long pos = i-window_size/2 + static_cast<long>(position);
+ if (0 <= pos && pos < (long)x.size())
+ {
+ const unsigned long off1 = y(0)*base_dims;
+ dot_functor<feature_setter> fs1(set_feature, offset+off1);
+ fe.get_features(fs1, x, pos);
+
+ if (ss_feature_extractor::use_high_order_features && y.size() > 1)
+ {
+ const unsigned long off2 = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims;
+ dot_functor<feature_setter> fs2(set_feature, offset+off2);
+ fe.get_features(fs2, x, pos);
+ }
+ }
+
+ if (ss_feature_extractor::use_high_order_features)
+ offset += num_labels()*base_dims + num_labels()*num_labels()*base_dims;
+ else
+ offset += num_labels()*base_dims;
+ }
+
+ // Pull out an indicator feature for the type of transition between the
+ // previous label and the current label.
+ if (y.size() > 1)
+ set_feature(offset + y(1)*num_labels() + y(0));
+
+ offset += num_labels()*num_labels();
+ // pull out an indicator feature for the current label. This is the per
+ // label bias.
+ set_feature(offset + y(0));
+ }
+ };
+
+ } // end namespace impl_ss
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ unsigned long total_feature_vector_size (
+ const feature_extractor& fe
+ )
+ {
+ const unsigned long NL = feature_extractor::use_BIO_model ? 3 : 5;
+ if (feature_extractor::use_high_order_features)
+ return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
+ else
+ return NL + NL*NL + NL*fe.num_features()*fe.window_size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class sequence_segmenter
+ {
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
+
+
+ sequence_segmenter()
+ {
+#ifdef ENABLE_ASSERTS
+ const feature_extractor& fe = labeler.get_feature_extractor().fe;
+ DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
+ "\t sequence_segmenter::sequence_segmenter()"
+ << "\n\t An invalid feature extractor was supplied."
+ << "\n\t fe.window_size(): " << fe.window_size()
+ << "\n\t fe.num_features(): " << fe.num_features()
+ << "\n\t this: " << this
+ );
+#endif
+ }
+
+ explicit sequence_segmenter(
+ const matrix<double,0,1>& weights
+ ) :
+ labeler(weights)
+ {
+#ifdef ENABLE_ASSERTS
+ const feature_extractor& fe = labeler.get_feature_extractor().fe;
+ // make sure requires clause is not broken
+ DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
+ "\t sequence_segmenter::sequence_segmenter(weights)"
+ << "\n\t These sizes should match"
+ << "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe)
+ << "\n\t weights.size(): " << weights.size()
+ << "\n\t this: " << this
+ );
+ DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
+ "\t sequence_segmenter::sequence_segmenter()"
+ << "\n\t An invalid feature extractor was supplied."
+ << "\n\t fe.window_size(): " << fe.window_size()
+ << "\n\t fe.num_features(): " << fe.num_features()
+ << "\n\t this: " << this
+ );
+#endif
+ }
+
+ sequence_segmenter(
+ const matrix<double,0,1>& weights,
+ const feature_extractor& fe
+ ) :
+ labeler(weights, impl_ss::feature_extractor<feature_extractor>(fe))
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
+ "\t sequence_segmenter::sequence_segmenter(weights,fe)"
+ << "\n\t These sizes should match"
+ << "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe)
+ << "\n\t weights.size(): " << weights.size()
+ << "\n\t this: " << this
+ );
+ DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
+ "\t sequence_segmenter::sequence_segmenter()"
+ << "\n\t An invalid feature extractor was supplied."
+ << "\n\t fe.window_size(): " << fe.window_size()
+ << "\n\t fe.num_features(): " << fe.num_features()
+ << "\n\t this: " << this
+ );
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return labeler.get_feature_extractor().fe; }
+
+ const matrix<double,0,1>& get_weights (
+ ) const { return labeler.get_weights(); }
+
+ segmented_sequence_type operator() (
+ const sample_sequence_type& x
+ ) const
+ {
+ segmented_sequence_type y;
+ segment_sequence(x,y);
+ return y;
+ }
+
+ void segment_sequence (
+ const sample_sequence_type& x,
+ segmented_sequence_type& y
+ ) const
+ {
+ y.clear();
+ std::vector<unsigned long> labels;
+ labeler.label_sequence(x, labels);
+
+ if (feature_extractor::use_BIO_model)
+ {
+ // Convert from BIO tagging to the explicit segments representation.
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ if (labels[i] == impl_ss::BEGIN)
+ {
+ const unsigned long begin = i;
+ ++i;
+ while (i < labels.size() && labels[i] == impl_ss::INSIDE)
+ ++i;
+
+ y.push_back(std::make_pair(begin, i));
+ --i;
+ }
+ }
+ }
+ else
+ {
+ // Convert from BILOU tagging to the explicit segments representation.
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ if (labels[i] == impl_ss::BEGIN)
+ {
+ const unsigned long begin = i;
+ ++i;
+ while (i < labels.size() && labels[i] == impl_ss::INSIDE)
+ ++i;
+
+ y.push_back(std::make_pair(begin, i+1));
+ }
+ else if (labels[i] == impl_ss::UNIT)
+ {
+ y.push_back(std::make_pair(i, i+1));
+ }
+ }
+ }
+ }
+
+ friend void serialize(const sequence_segmenter& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+
+ // Save these just so we can compare them when we deserialize and make
+ // sure the feature_extractor being used is compatible with the model being
+ // loaded.
+ serialize(feature_extractor::use_BIO_model, out);
+ serialize(feature_extractor::use_high_order_features, out);
+ serialize(total_feature_vector_size(item.get_feature_extractor()), out);
+
+ serialize(item.labeler, out);
+ }
+
+ friend void deserialize(sequence_segmenter& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::sequence_segmenter.");
+
+ // Try to check if the saved model is compatible with the current feature
+ // extractor.
+ bool use_BIO_model, use_high_order_features;
+ unsigned long dims;
+ deserialize(use_BIO_model, in);
+ deserialize(use_high_order_features, in);
+ deserialize(dims, in);
+ deserialize(item.labeler, in);
+ if (use_BIO_model != feature_extractor::use_BIO_model)
+ {
+ throw serialization_error("Incompatible feature extractor found while deserializing "
+ "dlib::sequence_segmenter. Wrong value of use_BIO_model.");
+ }
+ if (use_high_order_features != feature_extractor::use_high_order_features)
+ {
+ throw serialization_error("Incompatible feature extractor found while deserializing "
+ "dlib::sequence_segmenter. Wrong value of use_high_order_features.");
+ }
+ if (dims != total_feature_vector_size(item.get_feature_extractor()))
+ {
+ throw serialization_error("Incompatible feature extractor found while deserializing "
+ "dlib::sequence_segmenter. Wrong value of total_feature_vector_size().");
+ }
+ }
+
+ private:
+ sequence_labeler<impl_ss::feature_extractor<feature_extractor> > labeler;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_SeGMENTER_H_h_
+
+
diff --git a/ml/dlib/dlib/svm/sequence_segmenter_abstract.h b/ml/dlib/dlib/svm/sequence_segmenter_abstract.h
new file mode 100644
index 000000000..7229fee22
--- /dev/null
+++ b/ml/dlib/dlib/svm/sequence_segmenter_abstract.h
@@ -0,0 +1,452 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_
+#ifdef DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_
+
+#include "../matrix.h"
+#include <vector>
+#include "sequence_labeler_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class example_feature_extractor
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a feature extractor must implement if it
+ is to be used with the sequence_segmenter defined at the bottom of this
+ file.
+
+ The model used by sequence_segmenter objects is the following. Given an
+ input sequence x, predict an output label sequence y such that:
+ y == argmax_Y dot(w, PSI(x,Y))
+ Where w is a parameter vector and the label sequence defines a segmentation
+ of x.
+
+ Recall that a sequence_segmenter uses the BIO or BILOU tagging model and is
+ also an instantiation of the dlib::sequence_labeler. Selecting to use the
+ BIO model means that each element of the label sequence y takes on one of
+ three possible values (B, I, or O) and together these labels define a
+ segmentation of the sequence. For example, to represent a segmentation of
+ the sequence of words "The dog ran to Bob Jones" where only "Bob Jones" was
+ segmented out we would use the label sequence OOOOBI. The BILOU model is
+ similar except that it uses five different labels and each segment is
+ labeled as U, BL, BIL, BIIL, BIIIL, and so on depending on its length.
+ Therefore, the BILOU model is able to more explicitly model the ends of the
+ segments than the BIO model, but has more parameters to estimate.
+
+ Keeping all this in mind, the purpose of a sequence_segmenter is to take
+ care of the bookkeeping associated with creating BIO/BILOU tagging models
+ for segmentation tasks. In particular, it presents the user with a
+ simplified version of the interface used by the dlib::sequence_labeler. It
+ does this by completely hiding the BIO/BILOU tags from the user and instead
+ exposes an explicit sub-segment based labeling representation. It also
+ simplifies the construction of the PSI() feature vector.
+
+ Like in the dlib::sequence_labeler, PSI() is a sum of feature vectors, each
+ derived from the entire input sequence x but only part of the label
+ sequence y. In the case of the sequence_segmenter, we use an order one
+ Markov model. This means that
+ PSI(x,y) == sum_i XI(x, y_{i-1}, y_{i}, i)
+ where the sum is taken over all the elements in the sequence. At each
+ element we extract a feature vector, XI(), that is expected to encode
+ important details describing what the i-th position of the sequence looks
+ like in the context of the current and previous labels. To do this, XI()
+ is allowed to look at any part of the input sequence x, the current and
+ previous labels, and of course it must also know the position in question, i.
+
+ The sequence_segmenter simplifies this further by decomposing XI() into
+ components which model the current window around each position as well as
+ the conjunction of the current window around each position and the previous
+ label. In particular, the sequence_segmenter only asks a user to provide a
+ single feature vector which characterizes a position of the sequence
+ independent of any labeling. We denote this feature vector by ZI(x,i), where
+ x is the sequence and i is the position in question.
+
+ For example, suppose we use a window size of 3 and BIO tags, then we can
+ put all this together and define XI() in terms of ZI(). To do this, we can
+ think of XI() as containing 12*3 slots which contain either a zero vector
+ or a ZI() vector. Each combination of window position and labeling has a
+ different slot. To explain further, consider the following examples where
+ we have annotated which parts of XI() correspond to each slot.
+
+ If the previous and current label are both B and we use a window size of 3
+ then XI() would be instantiated as:
+ XI(x, B, B, i) = [ZI(x,i-1) \
+ ZI(x,i) > If current label is B
+ ZI(x,i+1) /
+ 0 \
+ 0 > If current label is I
+ 0 /
+ 0 \
+ 0 > If current label is O
+ 0 /
+
+ ZI(x,i-1) \
+ ZI(x,i) > If previous label is B and current label is B
+ ZI(x,i+1) /
+ 0 \
+ 0 > If previous label is B and current label is I
+ 0 /
+ 0 \
+ 0 > If previous label is B and current label is O
+ 0 /
+
+ 0 \
+ 0 > If previous label is I and current label is B
+ 0 /
+ 0 \
+ 0 > If previous label is I and current label is I
+ 0 /
+ 0 \
+ 0 > If previous label is I and current label is O
+ 0 /
+
+ 0 \
+ 0 > If previous label is O and current label is B
+ 0 /
+ 0 \
+ 0 > If previous label is O and current label is I
+ 0 /
+ 0 \
+ 0 > If previous label is O and current label is O
+ 0] /
+
+
+ If the previous label is I and the current label is O and we use a window
+ size of 3 then XI() would be instantiated as:
+ XI(x, I, O, i) = [0 \
+ 0 > If current label is B
+ 0 /
+ 0 \
+ 0 > If current label is I
+ 0 /
+ ZI(x,i-1) \
+ ZI(x,i) > If current label is O
+ ZI(x,i+1) /
+
+ 0 \
+ 0 > If previous label is B and current label is B
+ 0 /
+ 0 \
+ 0 > If previous label is B and current label is I
+ 0 /
+ 0 \
+ 0 > If previous label is B and current label is O
+ 0 /
+
+ 0 \
+ 0 > If previous label is I and current label is B
+ 0 /
+ 0 \
+ 0 > If previous label is I and current label is I
+ 0 /
+ ZI(x,i-1) \
+ ZI(x,i) > If previous label is I and current label is O
+ ZI(x,i+1) /
+
+ 0 \
+ 0 > If previous label is O and current label is B
+ 0 /
+ 0 \
+ 0 > If previous label is O and current label is I
+ 0 /
+ 0 \
+ 0 > If previous label is O and current label is O
+ 0] /
+
+ If we had instead used the BILOU tagging model the XI() vector would
+ have been similarly defined except that there would be 30*3 slots for
+ the various label combination instead of 12*3.
+
+ Finally, while not shown here, we also include indicator features in
+ XI() to model label transitions and individual label biases. These are
+ 12 extra features in the case of the BIO tagging model and 30 extra in
+ the case of the BILOU tagging model.
+
+ THREAD SAFETY
+ Instances of this object are required to be threadsafe, that is, it should
+ be safe for multiple threads to make concurrent calls to the member
+ functions of this object.
+ !*/
+
+ public:
+ // This should be the type used to represent an input sequence. It can be
+ // anything so long as it has a .size() which returns the length of the sequence.
+ typedef the_type_used_to_represent_a_sequence sequence_type;
+
+ // If you want to use the BIO tagging model then set this bool to true. Set it to
+ // false to use the BILOU tagging model.
+ const static bool use_BIO_model = true;
+
+ // In the WHAT THIS OBJECT REPRESENTS section above we discussed how we model the
+ // conjunction of the previous label and the window around each position. Doing
+ // this greatly expands the size of the parameter vector w. You can optionally
+ // disable these higher order features by setting the use_high_order_features bool
+ // to false. This will cause XI() to include only slots which are independent of
+ // the previous label.
+ const static bool use_high_order_features = true;
+
+ // You use a tool like the structural_sequence_segmentation_trainer to learn the weight
+ // vector needed by a sequence_segmenter. You can tell the trainer to force all the
+ // elements of the weight vector corresponding to ZI() to be non-negative. This is all
+ // the elements of w except for the elements corresponding to the label transition and
+ // bias indicator features. To do this, just set allow_negative_weights to false.
+ const static bool allow_negative_weights = true;
+
+
+ example_feature_extractor (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ unsigned long num_features(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the ZI() feature vector. This number is
+ always >= 1
+ !*/
+
+ unsigned long window_size(
+ ) const;
+ /*!
+ ensures
+ - returns the size of the window ZI() vectors are extracted from. This
+ number is always >= 1.
+ !*/
+
+ template <typename feature_setter>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ unsigned long position
+ ) const;
+ /*!
+ requires
+ - position < x.size()
+ - set_feature is a function object which allows expressions of the form:
+ - set_features((unsigned long)feature_index, (double)feature_value);
+ - set_features((unsigned long)feature_index);
+ ensures
+ - This function computes the ZI(x,position) feature vector. This is a
+ feature vector which should capture the properties of x[position] that
+ are informative relative to the sequence segmentation task you are trying
+ to perform.
+ - ZI(x,position) is returned as a sparse vector by invoking set_feature().
+ For example, to set the feature with an index of 55 to the value of 1
+ this method would call:
+ set_feature(55);
+ Or equivalently:
+ set_feature(55,1);
+ Therefore, the first argument to set_feature is the index of the feature
+ to be set while the second argument is the value the feature should take.
+ Additionally, note that calling set_feature() multiple times with the
+ same feature index does NOT overwrite the old value, it adds to the
+ previous value. For example, if you call set_feature(55) 3 times then it
+ will result in feature 55 having a value of 3.
+ - This function only calls set_feature() with feature_index values < num_features()
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const example_feature_extractor& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ void deserialize(
+ example_feature_extractor& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ unsigned long total_feature_vector_size (
+ const feature_extractor& fe
+ );
+ /*!
+ requires
+ - fe must be an object that implements an interface compatible with the
+ example_feature_extractor discussed above.
+ ensures
+ - returns the dimensionality of the PSI() vector defined by the given feature
+ extractor.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class sequence_segmenter
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor discussed above.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for segmenting a sequence of objects into a set of
+ non-overlapping chunks. An example sequence segmentation task is to take
+ English sentences and identify all the named entities. In this example,
+ you would be using a sequence_segmenter to find all the chunks of
+ contiguous words which refer to proper names.
+
+ Internally, the sequence_segmenter uses the BIO (Begin, Inside, Outside) or
+ BILOU (Begin, Inside, Last, Outside, Unit) sequence tagging model.
+ Moreover, it is implemented using a dlib::sequence_labeler object and
+ therefore sequence_segmenter objects are examples of chain structured
+ conditional random field style sequence taggers.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads so long as the feature_extractor is also threadsafe. This is
+ because the const members are purely read-only operations. However,
+ any operation that modifies a sequence_segmenter is not threadsafe.
+ !*/
+
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
+
+ sequence_segmenter(
+ );
+ /*!
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights().size() == total_feature_vector_size(#get_feature_extractor())
+ - #get_weights() == 0
+ !*/
+
+ explicit sequence_segmenter(
+ const matrix<double,0,1>& weights
+ );
+ /*!
+ requires
+ - total_feature_vector_size(feature_extractor()) == weights.size()
+ ensures
+ - #get_feature_extractor() == feature_extractor()
+ (i.e. it will have its default value)
+ - #get_weights() == weights
+ !*/
+
+ sequence_segmenter(
+ const matrix<double,0,1>& weights,
+ const feature_extractor& fe
+ );
+ /*!
+ requires
+ - total_feature_vector_size(fe) == weights.size()
+ ensures
+ - #get_feature_extractor() == fe
+ - #get_weights() == weights
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object.
+ !*/
+
+ const matrix<double,0,1>& get_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the parameter vector associated with this sequence segmenter.
+ The length of the vector is total_feature_vector_size(get_feature_extractor()).
+ !*/
+
+ segmented_sequence_type operator() (
+ const sample_sequence_type& x
+ ) const;
+ /*!
+ ensures
+ - Takes an input sequence and returns a list of detected segments within
+ that sequence.
+ - None of the returned segments will overlap.
+ - The returned segments are listed in the order they appeared in the input sequence.
+ - To be precise, this function returns a std::vector Y of segments such that:
+ - Y.size() == the number of segments detected in the input sequence x.
+ - for all valid i:
+ - Y[i].first == the start of the i-th segment.
+ - Y[i].second == one past the end of the i-th segment.
+ - Therefore, the i-th detected segment in x is composed of the elements
+ x[Y[i].first], x[Y[i].first+1], ..., x[Y[i].second-1]
+ - Y[i].first < x.size()
+ - Y[i].second <= x.size()
+ - Y[i].first < Y[i].second
+ (i.e. This function never outputs empty segments)
+ - Y[i].second <= Y[i+1].first
+ (i.e. the segments are listed in order of appearance and do not overlap)
+ !*/
+
+ void segment_sequence (
+ const sample_sequence_type& x,
+ segmented_sequence_type& y
+ ) const;
+ /*!
+ ensures
+ - #y == (*this)(x)
+ (i.e. This is just another interface to the operator() routine
+ above. This one avoids returning the results by value and therefore
+ might be a little faster in some cases)
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void serialize (
+ const sequence_segmenter<feature_extractor>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ void deserialize (
+ sequence_segmenter<feature_extractor>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_
+
diff --git a/ml/dlib/dlib/svm/simplify_linear_decision_function.h b/ml/dlib/dlib/svm/simplify_linear_decision_function.h
new file mode 100644
index 000000000..4f5bef6f3
--- /dev/null
+++ b/ml/dlib/dlib/svm/simplify_linear_decision_function.h
@@ -0,0 +1,110 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_
+#define DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_
+
+#include "simplify_linear_decision_function_abstract.h"
+#include "../algs.h"
+#include "function.h"
+#include "sparse_kernel.h"
+#include "kernel.h"
+#include <map>
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<sparse_linear_kernel<T> > simplify_linear_decision_function (
+ const decision_function<sparse_linear_kernel<T> >& df
+ )
+ {
+ // don't do anything if we don't have to
+ if (df.basis_vectors.size() <= 1)
+ return df;
+
+ decision_function<sparse_linear_kernel<T> > new_df;
+
+ new_df.b = df.b;
+ new_df.basis_vectors.set_size(1);
+ new_df.alpha.set_size(1);
+ new_df.alpha(0) = 1;
+
+ // now compute the weighted sum of all the sparse basis_vectors in df
+ typedef typename T::value_type pair_type;
+ typedef typename pair_type::first_type key_type;
+ typedef typename pair_type::second_type value_type;
+ std::map<key_type, value_type> accum;
+ for (long i = 0; i < df.basis_vectors.size(); ++i)
+ {
+ typename T::const_iterator j = df.basis_vectors(i).begin();
+ const typename T::const_iterator end = df.basis_vectors(i).end();
+ for (; j != end; ++j)
+ {
+ accum[j->first] += df.alpha(i) * (j->second);
+ }
+ }
+
+ new_df.basis_vectors(0) = T(accum.begin(), accum.end());
+
+ return new_df;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<linear_kernel<T> > simplify_linear_decision_function (
+ const decision_function<linear_kernel<T> >& df
+ )
+ {
+ // don't do anything if we don't have to
+ if (df.basis_vectors.size() <= 1)
+ return df;
+
+ decision_function<linear_kernel<T> > new_df;
+
+ new_df.b = df.b;
+ new_df.basis_vectors.set_size(1);
+ new_df.alpha.set_size(1);
+ new_df.alpha(0) = 1;
+
+ // now compute the weighted sum of all the basis_vectors in df
+ new_df.basis_vectors(0) = 0;
+ for (long i = 0; i < df.basis_vectors.size(); ++i)
+ {
+ new_df.basis_vectors(0) += df.alpha(i) * df.basis_vectors(i);
+ }
+
+ return new_df;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<linear_kernel<T> > simplify_linear_decision_function (
+ const normalized_function<decision_function<linear_kernel<T> >, vector_normalizer<T> >& df
+ )
+ {
+ decision_function<linear_kernel<T> > new_df = simplify_linear_decision_function(df.function);
+
+ // now incorporate the normalization stuff into new_df
+ new_df.basis_vectors(0) = pointwise_multiply(new_df.basis_vectors(0), df.normalizer.std_devs());
+ new_df.b += dot(new_df.basis_vectors(0), df.normalizer.means());
+
+ return new_df;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_
+
diff --git a/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h b/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h
new file mode 100644
index 000000000..cff8ae11f
--- /dev/null
+++ b/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h
@@ -0,0 +1,74 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_
+#ifdef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "function_abstract.h"
+#include "sparse_kernel_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<sparse_linear_kernel<T> > simplify_linear_decision_function (
+ const decision_function<sparse_linear_kernel<T> >& df
+ );
+ /*!
+ requires
+ - T must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+ ensures
+ - returns a simplified version of df that only has one basis vector. That
+ is, returns a decision function D such that:
+ - D.basis_vectors.size() == 1 (or 0 if df is empty)
+ - for all possible x: D(x) == df(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<linear_kernel<T> > simplify_linear_decision_function (
+ const decision_function<linear_kernel<T> >& df
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix object
+ ensures
+ - returns a simplified version of df that only has one basis vector. That
+ is, returns a decision function D such that:
+ - D.basis_vectors.size() == 1 (or 0 if df is empty)
+ - for all possible x: D(x) == df(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ decision_function<linear_kernel<T> > simplify_linear_decision_function (
+ const normalized_function<decision_function<linear_kernel<T> >, vector_normalizer<T> >& df
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix object
+ ensures
+ - returns a simplified version of df that only has one basis vector and
+ doesn't involve an explicit vector_normalizer. That is, returns a
+ decision function D such that:
+ - D.basis_vectors.size() == 1 (or 0 if df is empty)
+ - for all possible x: D(x) == df(x)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/sort_basis_vectors.h b/ml/dlib/dlib/svm/sort_basis_vectors.h
new file mode 100644
index 000000000..1d4605b41
--- /dev/null
+++ b/ml/dlib/dlib/svm/sort_basis_vectors.h
@@ -0,0 +1,224 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SORT_BASIS_VECTORs_Hh_
+#define DLIB_SORT_BASIS_VECTORs_Hh_
+
+#include <vector>
+
+#include "sort_basis_vectors_abstract.h"
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace bs_impl
+ {
+ template <typename EXP>
+ typename EXP::matrix_type invert (
+ const matrix_exp<EXP>& m
+ )
+ {
+ eigenvalue_decomposition<EXP> eig(make_symmetric(m));
+
+ typedef typename EXP::type scalar_type;
+ typedef typename EXP::mem_manager_type mm_type;
+
+ matrix<scalar_type,0,1,mm_type> vals = eig.get_real_eigenvalues();
+
+ const scalar_type max_eig = max(abs(vals));
+ const scalar_type thresh = max_eig*std::sqrt(std::numeric_limits<scalar_type>::epsilon());
+
+ // Since m might be singular or almost singular we need to do something about
+ // any very small eigenvalues. So here we set the smallest eigenvalues to
+ // be equal to a large value to make the inversion stable. We can't just set
+ // them to zero like in a normal pseudo-inverse since we want the resulting
+ // inverse matrix to be full rank.
+ for (long i = 0; i < vals.size(); ++i)
+ {
+ if (std::abs(vals(i)) < thresh)
+ vals(i) = max_eig;
+ }
+
+ // Build the inverse matrix. This is basically a pseudo-inverse.
+ return make_symmetric(eig.get_pseudo_v()*diagm(reciprocal(vals))*trans(eig.get_pseudo_v()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename vect1_type,
+ typename vect2_type,
+ typename vect3_type
+ >
+ const std::vector<typename kernel_type::sample_type> sort_basis_vectors_impl (
+ const kernel_type& kern,
+ const vect1_type& samples,
+ const vect2_type& labels,
+ const vect3_type& basis,
+ double eps
+ )
+ {
+ DLIB_ASSERT(is_binary_classification_problem(samples, labels) &&
+ 0 < eps && eps <= 1 &&
+ basis.size() > 0,
+ "\t void sort_basis_vectors()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t is_binary_classification_problem(samples, labels): " << is_binary_classification_problem(samples, labels)
+ << "\n\t basis.size(): " << basis.size()
+ << "\n\t eps: " << eps
+ );
+
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mm_type;
+
+ typedef matrix<scalar_type,0,1,mm_type> col_matrix;
+ typedef matrix<scalar_type,0,0,mm_type> gen_matrix;
+
+ col_matrix c1_mean, c2_mean, temp, delta;
+
+
+ col_matrix weights;
+
+ running_covariance<gen_matrix> cov;
+
+ // compute the covariance matrix and the means of the two classes.
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ temp = kernel_matrix(kern, basis, samples(i));
+ cov.add(temp);
+ if (labels(i) > 0)
+ c1_mean += temp;
+ else
+ c2_mean += temp;
+ }
+
+ c1_mean /= sum(labels > 0);
+ c2_mean /= sum(labels < 0);
+
+ delta = c1_mean - c2_mean;
+
+ gen_matrix cov_inv = bs_impl::invert(cov.covariance());
+
+
+ matrix<long,0,1,mm_type> total_perm = trans(range(0, delta.size()-1));
+ matrix<long,0,1,mm_type> perm = total_perm;
+
+ std::vector<std::pair<scalar_type,long> > sorted_feats(delta.size());
+
+ long best_size = delta.size();
+ long misses = 0;
+ matrix<long,0,1,mm_type> best_total_perm = perm;
+
+ // Now we basically find fisher's linear discriminant over and over. Each
+ // time sorting the features so that the most important ones pile up together.
+ weights = trans(chol(cov_inv))*delta;
+ while (true)
+ {
+
+ for (unsigned long i = 0; i < sorted_feats.size(); ++i)
+ sorted_feats[i] = make_pair(std::abs(weights(i)), i);
+
+ std::sort(sorted_feats.begin(), sorted_feats.end());
+
+ // make a permutation vector according to the sorting
+ for (long i = 0; i < perm.size(); ++i)
+ perm(i) = sorted_feats[i].second;
+
+
+ // Apply the permutation. Doing this gives the same result as permuting all the
+ // features and then recomputing the delta and cov_inv from scratch.
+ cov_inv = subm(cov_inv,perm,perm);
+ delta = rowm(delta,perm);
+
+ // Record all the permutations we have done so we will know how the final
+ // weights match up with the original basis vectors when we are done.
+ total_perm = rowm(total_perm, perm);
+
+ // compute new Fisher weights for sorted features.
+ weights = trans(chol(cov_inv))*delta;
+
+ // Measure how many features it takes to account for eps% of the weights vector.
+ const scalar_type total_weight = length_squared(weights);
+ scalar_type weight_accum = 0;
+ long size = 0;
+ // figure out how to get eps% of the weights
+ for (long i = weights.size()-1; i >= 0; --i)
+ {
+ ++size;
+ weight_accum += weights(i)*weights(i);
+ if (weight_accum/total_weight > eps)
+ break;
+ }
+
+ // loop until the best_size stops dropping
+ if (size < best_size)
+ {
+ misses = 0;
+ best_size = size;
+ best_total_perm = total_perm;
+ }
+ else
+ {
+ ++misses;
+
+ // Give up once we have had 10 rounds where we didn't find a weights vector with
+ // a smaller concentration of good features.
+ if (misses >= 10)
+ break;
+ }
+
+ }
+
+ // make sure best_size isn't zero
+ if (best_size == 0)
+ best_size = 1;
+
+ std::vector<typename kernel_type::sample_type> sorted_basis;
+
+ // permute the basis so that it matches up with the contents of the best weights
+ sorted_basis.resize(best_size);
+ for (unsigned long i = 0; i < sorted_basis.size(); ++i)
+ {
+ // Note that we load sorted_basis backwards so that the most important
+ // basis elements come first.
+ sorted_basis[i] = basis(best_total_perm(basis.size()-i-1));
+ }
+
+ return sorted_basis;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename vect1_type,
+ typename vect2_type,
+ typename vect3_type
+ >
+ const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
+ const kernel_type& kern,
+ const vect1_type& samples,
+ const vect2_type& labels,
+ const vect3_type& basis,
+ double eps = 0.99
+ )
+ {
+ return bs_impl::sort_basis_vectors_impl(kern,
+ mat(samples),
+ mat(labels),
+ mat(basis),
+ eps);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SORT_BASIS_VECTORs_Hh_
+
diff --git a/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h b/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h
new file mode 100644
index 000000000..b43dca170
--- /dev/null
+++ b/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h
@@ -0,0 +1,59 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_
+#ifdef DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_
+
+#include <vector>
+
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename kernel_type,
+ typename vect1_type,
+ typename vect2_type,
+ typename vect3_type
+ >
+ const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
+ const kernel_type& kern,
+ const vect1_type& samples,
+ const vect2_type& labels,
+ const vect3_type& basis,
+ double eps = 0.99
+ );
+ /*!
+ requires
+ - is_binary_classification_problem(samples, labels)
+ - 0 < eps <= 1
+ - basis.size() > 0
+ - kernel_type is a kernel function object as defined in dlib/svm/kernel_abstract.h
+ It must be capable of operating on the elements of samples and basis.
+ - vect1_type == a matrix or something convertible to a matrix via mat()
+ - vect2_type == a matrix or something convertible to a matrix via mat()
+ - vect3_type == a matrix or something convertible to a matrix via mat()
+ ensures
+ - A kernel based learning method ultimately needs to select a set of basis functions
+ represented by a particular choice of kernel and a set of basis vectors.
+ sort_basis_vectors() attempts to order the elements of basis so that elements which are
+ most useful in solving the binary classification problem defined by samples and
+ labels come first.
+ - In particular, this function returns a std::vector, SB, of sorted basis vectors such that:
+ - 0 < SB.size() <= basis.size()
+ - SB will contain elements from basis but they will have been sorted so that
+ the most useful elements come first (i.e. SB[0] is the most important).
+ - eps notionally controls how big SB will be. Bigger eps corresponds to a
+ bigger basis. You can think of it like asking for eps percent of the
+ discriminating power from the input basis.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/sparse_kernel.h b/ml/dlib/dlib/svm/sparse_kernel.h
new file mode 100644
index 000000000..f571135ec
--- /dev/null
+++ b/ml/dlib/dlib/svm/sparse_kernel.h
@@ -0,0 +1,384 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_SPARSE_KERNEL
+#define DLIB_SVm_SPARSE_KERNEL
+
+#include "sparse_kernel_abstract.h"
+#include <cmath>
+#include <limits>
+#include "../algs.h"
+#include "../serialize.h"
+#include "sparse_vector.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_radial_basis_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ sparse_radial_basis_kernel(const scalar_type g) : gamma(g) {}
+ sparse_radial_basis_kernel() : gamma(0.1) {}
+ sparse_radial_basis_kernel(
+ const sparse_radial_basis_kernel& k
+ ) : gamma(k.gamma) {}
+
+
+ const scalar_type gamma;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ const scalar_type d = distance_squared(a,b);
+ return std::exp(-gamma*d);
+ }
+
+ sparse_radial_basis_kernel& operator= (
+ const sparse_radial_basis_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ return *this;
+ }
+
+ bool operator== (
+ const sparse_radial_basis_kernel& k
+ ) const
+ {
+ return gamma == k.gamma;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_radial_basis_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type sparse_radial_basis_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_radial_basis_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type sparse_radial_basis_kernel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_polynomial_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ sparse_polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {}
+ sparse_polynomial_kernel() : gamma(1), coef(0), degree(1) {}
+ sparse_polynomial_kernel(
+ const sparse_polynomial_kernel& k
+ ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {}
+
+ typedef T type;
+ const scalar_type gamma;
+ const scalar_type coef;
+ const scalar_type degree;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return std::pow(gamma*(dot(a,b)) + coef, degree);
+ }
+
+ sparse_polynomial_kernel& operator= (
+ const sparse_polynomial_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ const_cast<scalar_type&>(coef) = k.coef;
+ const_cast<scalar_type&>(degree) = k.degree;
+ return *this;
+ }
+
+ bool operator== (
+ const sparse_polynomial_kernel& k
+ ) const
+ {
+ return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree);
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_polynomial_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ serialize(item.coef, out);
+ serialize(item.degree, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type sparse_polynomial_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_polynomial_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ deserialize(const_cast<scalar_type&>(item.coef), in);
+ deserialize(const_cast<scalar_type&>(item.degree), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type sparse_polynomial_kernel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_sigmoid_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ sparse_sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {}
+ sparse_sigmoid_kernel() : gamma(0.1), coef(-1.0) {}
+ sparse_sigmoid_kernel(
+ const sparse_sigmoid_kernel& k
+ ) : gamma(k.gamma), coef(k.coef) {}
+
+ typedef T type;
+ const scalar_type gamma;
+ const scalar_type coef;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return std::tanh(gamma*(dot(a,b)) + coef);
+ }
+
+ sparse_sigmoid_kernel& operator= (
+ const sparse_sigmoid_kernel& k
+ )
+ {
+ const_cast<scalar_type&>(gamma) = k.gamma;
+ const_cast<scalar_type&>(coef) = k.coef;
+ return *this;
+ }
+
+ bool operator== (
+ const sparse_sigmoid_kernel& k
+ ) const
+ {
+ return (gamma == k.gamma) && (coef == k.coef);
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_sigmoid_kernel<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.gamma, out);
+ serialize(item.coef, out);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing object of type sparse_sigmoid_kernel");
+ }
+ }
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_sigmoid_kernel<T>& item,
+ std::istream& in
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ try
+ {
+ deserialize(const_cast<scalar_type&>(item.gamma), in);
+ deserialize(const_cast<scalar_type&>(item.coef), in);
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing object of type sparse_sigmoid_kernel");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct sparse_linear_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return dot(a,b);
+ }
+
+ bool operator== (
+ const sparse_linear_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_linear_kernel<T>& ,
+ std::ostream&
+ ){}
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_linear_kernel<T>& ,
+ std::istream&
+ ){}
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct sparse_histogram_intersection_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ typename sample_type::const_iterator ai = a.begin();
+ typename sample_type::const_iterator bi = b.begin();
+
+ scalar_type sum = 0;
+ while (ai != a.end() && bi != b.end())
+ {
+ if (ai->first == bi->first)
+ {
+ sum += std::min(ai->second , bi->second);
+ ++ai;
+ ++bi;
+ }
+ else if (ai->first < bi->first)
+ {
+ ++ai;
+ }
+ else
+ {
+ ++bi;
+ }
+ }
+
+ return sum;
+ }
+
+ bool operator== (
+ const sparse_histogram_intersection_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_histogram_intersection_kernel<T>& ,
+ std::ostream&
+ ){}
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_histogram_intersection_kernel<T>& ,
+ std::istream&
+ ){}
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_SPARSE_KERNEL
+
+
+
diff --git a/ml/dlib/dlib/svm/sparse_kernel_abstract.h b/ml/dlib/dlib/svm/sparse_kernel_abstract.h
new file mode 100644
index 000000000..55f9d7caa
--- /dev/null
+++ b/ml/dlib/dlib/svm/sparse_kernel_abstract.h
@@ -0,0 +1,486 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_SPARSE_KERNEL_ABSTRACT_
+#ifdef DLIB_SVm_SPARSE_KERNEL_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../algs.h"
+#include "../serialize.h"
+#include "kernel_abstract.h"
+#include "sparse_vector_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_radial_basis_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a radial basis function kernel
+ that works with sparse vectors.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ const scalar_type gamma;
+
+ sparse_radial_basis_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 0.1
+ !*/
+
+ sparse_radial_basis_kernel(
+ const sparse_radial_basis_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ !*/
+
+ sparse_radial_basis_kernel(
+ const scalar_type g
+ );
+ /*!
+ ensures
+ - #gamma == g
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a is a sparse vector
+ - b is a sparse vector
+ ensures
+ - returns exp(-gamma * distance_squared(a,b))
+ !*/
+
+ sparse_radial_basis_kernel& operator= (
+ const sparse_radial_basis_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - returns *this
+ !*/
+
+ bool operator== (
+ const sparse_radial_basis_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_radial_basis_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sparse_radial_basis_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_radial_basis_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sparse_radial_basis_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_sigmoid_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a sigmoid kernel
+ that works with sparse vectors.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ const scalar_type gamma;
+ const scalar_type coef;
+
+ sparse_sigmoid_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 0.1
+ - #coef == -1.0
+ !*/
+
+ sparse_sigmoid_kernel(
+ const sparse_sigmoid_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ - #coef == k.coef
+ !*/
+
+ sparse_sigmoid_kernel(
+ const scalar_type g,
+ const scalar_type c
+ );
+ /*!
+ ensures
+ - #gamma == g
+ - #coef == c
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a is a sparse vector
+ - b is a sparse vector
+ ensures
+ - returns tanh(gamma * dot(a,b) + coef)
+ !*/
+
+ sparse_sigmoid_kernel& operator= (
+ const sparse_sigmoid_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - #coef = k.coef
+ - returns *this
+ !*/
+
+ bool operator== (
+ const sparse_sigmoid_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_sigmoid_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sparse_sigmoid_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_sigmoid_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sparse_sigmoid_kernel
+ !*/
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_polynomial_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a polynomial kernel
+ that works with sparse vectors.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ const scalar_type gamma;
+ const scalar_type coef;
+ const scalar_type degree;
+
+ sparse_polynomial_kernel(
+ );
+ /*!
+ ensures
+ - #gamma == 1
+ - #coef == 0
+ - #degree == 1
+ !*/
+
+ sparse_polynomial_kernel(
+ const sparse_polynomial_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma == k.gamma
+ - #coef == k.coef
+ - #degree == k.degree
+ !*/
+
+ sparse_polynomial_kernel(
+ const scalar_type g,
+ const scalar_type c,
+ const scalar_type d
+ );
+ /*!
+ ensures
+ - #gamma == g
+ - #coef == c
+ - #degree == d
+ !*/
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a is a sparse vector
+ - b is a sparse vector
+ ensures
+ - returns pow(gamma * dot(a,b) + coef, degree)
+ !*/
+
+ sparse_polynomial_kernel& operator= (
+ const sparse_polynomial_kernel& k
+ );
+ /*!
+ ensures
+ - #gamma = k.gamma
+ - #coef = k.coef
+ - #degree = k.degree
+ - returns *this
+ !*/
+
+ bool operator== (
+ const sparse_polynomial_kernel& k
+ ) const;
+ /*!
+ ensures
+ - if (k and *this are identical) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_polynomial_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sparse_polynomial_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_polynomial_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sparse_polynomial_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_linear_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a linear function kernel
+ that works with sparse vectors.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a is a sparse vector
+ - b is a sparse vector
+ ensures
+ - returns dot(a,b)
+ !*/
+
+ bool operator== (
+ const sparse_linear_kernel& k
+ ) const;
+ /*!
+ ensures
+ - returns true
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_linear_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sparse_linear_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_linear_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sparse_linear_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct sparse_histogram_intersection_kernel
+ {
+ /*!
+ REQUIREMENTS ON T
+ Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a histogram intersection kernel
+ that works with sparse vectors.
+
+ THREAD SAFETY
+ This kernel is threadsafe.
+ !*/
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const;
+ /*!
+ requires
+ - a is a sparse vector
+ - b is a sparse vector
+ - all the values in a and b are >= 0
+ ensures
+ - Let A(i) denote the value of the ith dimension of the a vector.
+ - Let B(i) denote the value of the ith dimension of the b vector.
+ - returns sum over all i: std::min(A(i), B(i))
+ !*/
+
+ bool operator== (
+ const sparse_histogram_intersection_kernel& k
+ ) const;
+ /*!
+ ensures
+ - returns true
+ !*/
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const sparse_histogram_intersection_kernel<T>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support for sparse_histogram_intersection_kernel
+ !*/
+
+ template <
+ typename T
+ >
+ void deserialize (
+ sparse_histogram_intersection_kernel<T>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support for sparse_histogram_intersection_kernel
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_SPARSE_KERNEL_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/svm/sparse_vector.h b/ml/dlib/dlib/svm/sparse_vector.h
new file mode 100644
index 000000000..c42723f89
--- /dev/null
+++ b/ml/dlib/dlib/svm/sparse_vector.h
@@ -0,0 +1,1170 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_SPARSE_VECTOR
+#define DLIB_SVm_SPARSE_VECTOR
+
+#include "sparse_vector_abstract.h"
+#include <cmath>
+#include <limits>
+#include "../algs.h"
+#include <vector>
+#include <map>
+#include "../graph_utils/edge_list_graphs.h"
+#include "../matrix.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ typename T::value_type::second_type distance_squared (
+ const T& a,
+ const U& b
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef typename U::value_type::second_type scalar_typeU;
+ // Both T and U must contain the same kinds of elements
+ COMPILE_TIME_ASSERT((is_same_type<scalar_type, scalar_typeU>::value));
+
+ typename T::const_iterator ai = a.begin();
+ typename U::const_iterator bi = b.begin();
+
+ scalar_type sum = 0, temp = 0;
+ while (ai != a.end() && bi != b.end())
+ {
+ if (ai->first == bi->first)
+ {
+ temp = ai->second - bi->second;
+ ++ai;
+ ++bi;
+ }
+ else if (ai->first < bi->first)
+ {
+ temp = ai->second;
+ ++ai;
+ }
+ else
+ {
+ temp = bi->second;
+ ++bi;
+ }
+
+ sum += temp*temp;
+ }
+
+ while (ai != a.end())
+ {
+ sum += ai->second*ai->second;
+ ++ai;
+ }
+ while (bi != b.end())
+ {
+ sum += bi->second*bi->second;
+ ++bi;
+ }
+
+ return sum;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V, typename W>
+ typename T::value_type::second_type distance_squared (
+ const V& a_scale,
+ const T& a,
+ const W& b_scale,
+ const U& b
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef typename U::value_type::second_type scalar_typeU;
+ // Both T and U must contain the same kinds of elements
+ COMPILE_TIME_ASSERT((is_same_type<scalar_type, scalar_typeU>::value));
+
+ typename T::const_iterator ai = a.begin();
+ typename U::const_iterator bi = b.begin();
+
+ scalar_type sum = 0, temp = 0;
+ while (ai != a.end() && bi != b.end())
+ {
+ if (ai->first == bi->first)
+ {
+ temp = a_scale*ai->second - b_scale*bi->second;
+ ++ai;
+ ++bi;
+ }
+ else if (ai->first < bi->first)
+ {
+ temp = a_scale*ai->second;
+ ++ai;
+ }
+ else
+ {
+ temp = b_scale*bi->second;
+ ++bi;
+ }
+
+ sum += temp*temp;
+ }
+
+ while (ai != a.end())
+ {
+ sum += a_scale*a_scale*ai->second*ai->second;
+ ++ai;
+ }
+ while (bi != b.end())
+ {
+ sum += b_scale*b_scale*bi->second*bi->second;
+ ++bi;
+ }
+
+ return sum;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ typename T::value_type::second_type distance (
+ const T& a,
+ const U& b
+ )
+ {
+ return std::sqrt(distance_squared(a,b));
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V, typename W>
+ typename T::value_type::second_type distance (
+ const V& a_scale,
+ const T& a,
+ const W& b_scale,
+ const U& b
+ )
+ {
+ return std::sqrt(distance_squared(a_scale,a,b_scale,b));
+ }
+
+// ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ template <typename T, typename EXP>
+ typename enable_if<is_matrix<T> >::type assign (
+ T& dest,
+ const matrix_exp<EXP>& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(src),
+ "\t void assign(dest,src)"
+ << "\n\t the src matrix must be a row or column vector"
+ );
+
+ dest = src;
+ }
+
+ template <typename T, typename EXP>
+ typename disable_if<is_matrix<T> >::type assign (
+ T& dest,
+ const matrix_exp<EXP>& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(src),
+ "\t void assign(dest,src)"
+ << "\n\t the src matrix must be a row or column vector"
+ );
+
+ dest.clear();
+ typedef typename T::value_type item_type;
+ for (long i = 0; i < src.size(); ++i)
+ {
+ dest.insert(dest.end(),item_type(i, src(i)));
+ }
+ }
+
+ template <typename T, typename U>
+ typename disable_if_c<is_matrix<T>::value || is_matrix<U>::value>::type assign (
+ T& dest, // sparse
+ const U& src // sparse
+ )
+ {
+ dest.assign(src.begin(), src.end());
+ }
+
+ template <typename T, typename U, typename Comp, typename Alloc, typename S>
+ typename disable_if<is_matrix<S> >::type assign (
+ std::map<T,U,Comp,Alloc>& dest, // sparse
+ const S& src // sparse
+ )
+ {
+ dest.clear();
+ dest.insert(src.begin(), src.end());
+ }
+
+// ------------------------------------------------------------------------------------
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct has_unsigned_keys
+ {
+ static const bool value = is_unsigned_type<typename T::value_type::first_type>::value;
+ };
+
+// ------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T, typename U>
+ typename T::value_type::second_type general_dot (
+ const T& a,
+ const U& b
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+
+ typename T::const_iterator ai = a.begin();
+ typename U::const_iterator bi = b.begin();
+
+ scalar_type sum = 0;
+ while (ai != a.end() && bi != b.end())
+ {
+ if (ai->first == bi->first)
+ {
+ sum += ai->second * bi->second;
+ ++ai;
+ ++bi;
+ }
+ else if (ai->first < bi->first)
+ {
+ ++ai;
+ }
+ else
+ {
+ ++bi;
+ }
+ }
+
+ return sum;
+ }
+
+ template <typename T, typename U>
+ inline typename T::value_type::second_type dot (
+ const T& a,
+ const U& b
+ )
+ {
+ return general_dot(a,b);
+ }
+
+ template <typename T, typename U, typename alloc>
+ U dot (
+ const std::vector<std::pair<T,U>,alloc>& a,
+ const std::vector<std::pair<T,U>,alloc>& b
+ )
+ {
+ // You are getting this error because you are attempting to use sparse sample vectors
+ // but you aren't using an unsigned integer as your key type in the sparse vectors.
+ COMPILE_TIME_ASSERT(is_unsigned_type<T>::value);
+
+ if (a.size() == 0 || b.size() == 0)
+ return 0;
+
+ // if a is really a dense vector but just represented in a sparse container
+ if (a.back().first == a.size()-1)
+ {
+ double sum = 0;
+ for (unsigned long i = 0; i < b.size(); ++i)
+ {
+ if (b[i].first >= a.size())
+ break;
+ sum += a[b[i].first].second * b[i].second;
+ }
+ return sum;
+ }
+ // if b is really a dense vector but just represented in a sparse container
+ else if (b.back().first == b.size()-1)
+ {
+ double sum = 0;
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ if (a[i].first >= b.size())
+ break;
+ sum += b[a[i].first].second * a[i].second;
+ }
+ return sum;
+ }
+ else
+ {
+ return general_dot(a,b);
+ }
+ }
+ }
+
+ template <typename T>
+ inline typename T::value_type::second_type dot (
+ const T& a,
+ const T& b
+ )
+ {
+ return impl::dot(a,b);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6>
+ inline T4 dot (
+ const std::vector<T1,T2>& a,
+ const std::map<T3,T4,T5,T6>& b
+ )
+ {
+ return impl::dot(a,b);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6>
+ inline T4 dot (
+ const std::map<T3,T4,T5,T6>& a,
+ const std::vector<T1,T2>& b
+ )
+ {
+ return impl::dot(a,b);
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename EXP>
+ typename T::value_type::second_type dot (
+ const T& a,
+ const matrix_exp<EXP>& b
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(b),
+ "\t scalar_type dot(sparse_vector a, dense_vector b)"
+ << "\n\t 'b' must be a vector to be used in a dot product."
+ );
+
+ typedef typename T::value_type::second_type scalar_type;
+ typedef typename T::value_type::first_type first_type;
+
+ scalar_type sum = 0;
+ for (typename T::const_iterator ai = a.begin();
+ (ai != a.end()) && (ai->first < static_cast<first_type>(b.size()));
+ ++ai)
+ {
+ sum += ai->second * b(ai->first);
+ }
+
+ return sum;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename EXP>
+ typename T::value_type::second_type dot (
+ const matrix_exp<EXP>& b,
+ const T& a
+ )
+ {
+ return dot(a,b);
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type length_squared (
+ const T& a
+ )
+ {
+ typedef typename T::value_type::second_type scalar_type;
+
+ typename T::const_iterator i;
+
+ scalar_type sum = 0;
+
+ for (i = a.begin(); i != a.end(); ++i)
+ {
+ sum += i->second * i->second;
+ }
+
+ return sum;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type length (
+ const T& a
+ )
+ {
+ return std::sqrt(length_squared(a));
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ typename disable_if<is_matrix<T>,void>::type scale_by (
+ T& a,
+ const U& value
+ )
+ {
+ for (typename T::iterator i = a.begin(); i != a.end(); ++i)
+ {
+ i->second *= value;
+ }
+ }
+
+ template <typename T, typename U>
+ typename enable_if<is_matrix<T>,void>::type scale_by (
+ T& a,
+ const U& value
+ )
+ {
+ a *= value;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename disable_if<is_matrix<T>,T>::type add (
+ const T& a,
+ const T& b
+ )
+ {
+ T temp;
+
+ typename T::const_iterator i = a.begin();
+ typename T::const_iterator j = b.begin();
+ while (i != a.end() && j != b.end())
+ {
+ if (i->first == j->first)
+ {
+ temp.insert(temp.end(), std::make_pair(i->first, i->second + j->second));
+ ++i;
+ ++j;
+ }
+ else if (i->first < j->first)
+ {
+ temp.insert(temp.end(), *i);
+ ++i;
+ }
+ else
+ {
+ temp.insert(temp.end(), *j);
+ ++j;
+ }
+ }
+
+ while (i != a.end())
+ {
+ temp.insert(temp.end(), *i);
+ ++i;
+ }
+ while (j != b.end())
+ {
+ temp.insert(temp.end(), *j);
+ ++j;
+ }
+
+ return temp;
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<is_matrix<T>::value && is_matrix<U>::value, matrix_add_exp<T,U> >::type add (
+ const T& a,
+ const U& b
+ )
+ {
+ return matrix_add_exp<T,U>(a.ref(),b.ref());
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename disable_if<is_matrix<T>,T>::type subtract (
+ const T& a,
+ const T& b
+ )
+ {
+ T temp;
+
+ typename T::const_iterator i = a.begin();
+ typename T::const_iterator j = b.begin();
+ while (i != a.end() && j != b.end())
+ {
+ if (i->first == j->first)
+ {
+ temp.insert(temp.end(), std::make_pair(i->first, i->second - j->second));
+ ++i;
+ ++j;
+ }
+ else if (i->first < j->first)
+ {
+ temp.insert(temp.end(), *i);
+ ++i;
+ }
+ else
+ {
+ temp.insert(temp.end(), std::make_pair(j->first, -j->second));
+ ++j;
+ }
+ }
+
+ while (i != a.end())
+ {
+ temp.insert(temp.end(), *i);
+ ++i;
+ }
+ while (j != b.end())
+ {
+ temp.insert(temp.end(), std::make_pair(j->first, -j->second));
+ ++j;
+ }
+
+ return temp;
+ }
+
+ template <typename T, typename U>
+ typename enable_if_c<is_matrix<T>::value && is_matrix<U>::value, matrix_subtract_exp<T,U> >::type subtract (
+ const T& a,
+ const U& b
+ )
+ {
+ return matrix_subtract_exp<T,U>(a.ref(),b.ref());
+ }
+
+// ------------------------------------------------------------------------------------
+// ------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ typename enable_if<is_matrix<typename T::type>,unsigned long>::type max_index_plus_one (
+ const T& samples
+ )
+ {
+ if (samples.size() > 0)
+ return samples(0).size();
+ else
+ return 0;
+ }
+
+ template <typename T>
+ typename enable_if<is_built_in_scalar_type<typename T::type>,unsigned long>::type max_index_plus_one (
+ const T& sample
+ )
+ {
+ return sample.size();
+ }
+
+ // This !is_built_in_scalar_type<typename T::type>::value is here to avoid an inexplicable bug in Vistual Studio 2005
+ template <typename T>
+ typename enable_if_c<(!is_built_in_scalar_type<typename T::type>::value) && (is_pair<typename T::type::value_type>::value) ,unsigned long>::type
+ max_index_plus_one (
+ const T& samples
+ )
+ {
+ typedef typename T::type sample_type;
+ // You are getting this error because you are attempting to use sparse sample vectors
+ // but you aren't using an unsigned integer as your key type in the sparse vectors.
+ COMPILE_TIME_ASSERT(has_unsigned_keys<sample_type>::value);
+
+
+ // these should be sparse samples so look over all them to find the max index.
+ unsigned long max_dim = 0;
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ if (samples(i).size() > 0)
+ max_dim = std::max<unsigned long>(max_dim, (--samples(i).end())->first + 1);
+ }
+
+ return max_dim;
+ }
+ }
+
+ template <typename T>
+ typename enable_if<is_pair<typename T::value_type>,unsigned long>::type max_index_plus_one (
+ const T& sample
+ )
+ {
+ if (sample.size() > 0)
+ return (--sample.end())->first + 1;
+ return 0;
+ }
+
+ template <typename T>
+ typename disable_if_c<is_pair<typename T::value_type>::value ||
+ is_same_type<typename T::value_type,sample_pair>::value ||
+ is_same_type<typename T::value_type,ordered_sample_pair>::value , unsigned long>::type
+ max_index_plus_one (
+ const T& samples
+ )
+ {
+ return impl::max_index_plus_one(mat(samples));
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP>
+ inline void add_to (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_exp<EXP>& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void add_to(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (long r = 0; r < src.size(); ++r)
+ dest(r) += src(r);
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP>
+ inline typename disable_if<is_matrix<EXP> >::type add_to (
+ matrix<T,NR,NC,MM,L>& dest,
+ const EXP& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void add_to(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i)
+ dest(i->first) += i->second;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP, typename U>
+ inline void add_to (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_exp<EXP>& src,
+ const U& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void add_to(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (long r = 0; r < src.size(); ++r)
+ dest(r) += C*src(r);
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP, typename U>
+ inline typename disable_if<is_matrix<EXP> >::type add_to (
+ matrix<T,NR,NC,MM,L>& dest,
+ const EXP& src,
+ const U& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void add_to(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i)
+ dest(i->first) += C*i->second;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP>
+ inline void subtract_from (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_exp<EXP>& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void subtract_from(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (long r = 0; r < src.size(); ++r)
+ dest(r) -= src(r);
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP>
+ inline typename disable_if<is_matrix<EXP> >::type subtract_from (
+ matrix<T,NR,NC,MM,L>& dest,
+ const EXP& src
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void subtract_from(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i)
+ dest(i->first) -= i->second;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP, typename U>
+ inline void subtract_from (
+ matrix<T,NR,NC,MM,L>& dest,
+ const matrix_exp<EXP>& src,
+ const U& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void subtract_from(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (long r = 0; r < src.size(); ++r)
+ dest(r) -= C*src(r);
+ }
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename EXP, typename U>
+ inline typename disable_if<is_matrix<EXP> >::type subtract_from (
+ matrix<T,NR,NC,MM,L>& dest,
+ const EXP& src,
+ const U& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast<unsigned long>(dest.size()),
+ "\t void subtract_from(dest,src)"
+ << "\n\t dest must be a vector large enough to hold the src vector."
+ << "\n\t is_vector(dest): " << is_vector(dest)
+ << "\n\t max_index_plus_one(src): " << max_index_plus_one(src)
+ << "\n\t dest.size(): " << dest.size()
+ );
+
+ for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i)
+ dest(i->first) -= C*i->second;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type min (
+ const T& a
+ )
+ {
+ typedef typename T::value_type::second_type type;
+
+ type temp = 0;
+ for (typename T::const_iterator i = a.begin(); i != a.end(); ++i)
+ {
+ if (temp > i->second)
+ temp = i->second;
+ }
+ return temp;
+ }
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type max (
+ const T& a
+ )
+ {
+ typedef typename T::value_type::second_type type;
+
+ type temp = 0;
+ for (typename T::const_iterator i = a.begin(); i != a.end(); ++i)
+ {
+ if (temp < i->second)
+ temp = i->second;
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename sparse_vector_type>
+ inline matrix<typename sparse_vector_type::value_type::second_type,0,1> sparse_to_dense (
+ const sparse_vector_type& vect,
+ unsigned long num_dimensions
+ )
+ {
+ // You must use unsigned integral key types in your sparse vectors
+ typedef typename sparse_vector_type::value_type::first_type idx_type;
+ typedef typename sparse_vector_type::value_type::second_type value_type;
+ COMPILE_TIME_ASSERT(is_unsigned_type<idx_type>::value);
+
+ matrix<value_type,0,1> result;
+
+ if (vect.size() == 0)
+ return result;
+
+ result.set_size(num_dimensions);
+ result = 0;
+
+ for (typename sparse_vector_type::const_iterator j = vect.begin(); j != vect.end(); ++j)
+ {
+ if ((long)(j->first) < result.size())
+ {
+ result(j->first) += j->second;
+ }
+ }
+
+ return result;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename idx_type, typename value_type, typename alloc>
+ matrix<value_type,0,1> sparse_to_dense (
+ const std::vector<std::pair<idx_type,value_type>,alloc>& vect,
+ unsigned long num_dimensions
+ )
+ {
+ return impl::sparse_to_dense(vect,num_dimensions);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename idx_type, typename value_type, typename alloc>
+ matrix<value_type,0,1> sparse_to_dense (
+ const std::vector<std::pair<idx_type,value_type>,alloc>& vect
+ )
+ {
+ return impl::sparse_to_dense(vect, max_index_plus_one(vect));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ matrix<T2,0,1> sparse_to_dense (
+ const std::map<T1,T2,T3,T4>& vect,
+ unsigned long num_dimensions
+ )
+ {
+ return impl::sparse_to_dense(vect,num_dimensions);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ matrix<T2,0,1> sparse_to_dense (
+ const std::map<T1,T2,T3,T4>& vect
+ )
+ {
+ return impl::sparse_to_dense(vect, max_index_plus_one(vect));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename enable_if<is_matrix<T>,T&>::type sparse_to_dense(
+ T& item
+ ) { return item; }
+
+ template <typename EXP>
+ matrix<typename EXP::type,0,1> sparse_to_dense(
+ const matrix_exp<EXP>& item,
+ unsigned long num
+ )
+ {
+ typedef typename EXP::type type;
+ if (item.size() == (long)num)
+ return item;
+ else if (item.size() < (long)num)
+ return join_cols(item, zeros_matrix<type>((long)num-item.size(),1));
+ else
+ return colm(item,0,(long)num);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type, typename alloc>
+ std::vector<matrix<typename sample_type::value_type::second_type,0,1> > sparse_to_dense (
+ const std::vector<sample_type, alloc>& samples,
+ unsigned long num_dimensions
+ )
+ {
+ typedef typename sample_type::value_type pair_type;
+ typedef typename pair_type::second_type value_type;
+
+ std::vector< matrix<value_type,0,1> > result;
+
+ // now turn all the samples into dense samples
+ result.resize(samples.size());
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ result[i] = sparse_to_dense(samples[i],num_dimensions);
+ }
+
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename sample_type, typename alloc>
+ std::vector<matrix<typename sample_type::value_type::second_type,0,1> > sparse_to_dense (
+ const std::vector<sample_type, alloc>& samples
+ )
+ {
+ return sparse_to_dense(samples, max_index_plus_one(samples));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T make_sparse_vector (
+ const T& v
+ )
+ {
+ // You must use unsigned integral key types in your sparse vectors
+ typedef typename T::value_type::first_type idx_type;
+ typedef typename T::value_type::second_type value_type;
+ COMPILE_TIME_ASSERT(is_unsigned_type<idx_type>::value);
+ std::map<idx_type,value_type> temp;
+ for (typename T::const_iterator i = v.begin(); i != v.end(); ++i)
+ {
+ temp[i->first] += i->second;
+ }
+
+ return T(temp.begin(), temp.end());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void make_sparse_vector_inplace(
+ T& vect
+ )
+ {
+ vect = make_sparse_vector(vect);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename alloc
+ >
+ void make_sparse_vector_inplace (
+ std::vector<std::pair<T,U>,alloc>& vect
+ )
+ {
+ if (vect.size() > 0)
+ {
+ std::sort(vect.begin(), vect.end());
+
+ // merge duplicates
+ for (unsigned long i = 1; i < vect.size(); ++i)
+ {
+ // if we found a duplicate
+ if (vect[i-1].first == vect[i].first)
+ {
+ // now start collapsing and merging the vector
+ unsigned long j = i-1;
+ for (unsigned long k = i; k < vect.size(); ++k)
+ {
+ if (vect[j].first == vect[k].first)
+ {
+ vect[j].second += vect[k].second;
+ }
+ else
+ {
+ ++j;
+ vect[j] = vect[k];
+ }
+ }
+
+
+ // we removed elements when we merged so we need to adjust the size.
+ vect.resize(j+1);
+ return;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP, typename T, long NR, long NC, typename MM, typename L>
+ void sparse_matrix_vector_multiply (
+ const std::vector<sample_pair>& edges,
+ const matrix_exp<EXP>& v,
+ matrix<T,NR,NC,MM,L>& result
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_index_plus_one(edges) <= (unsigned long)v.size() &&
+ is_col_vector(v),
+ "\t void sparse_matrix_vector_multiply()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges)
+ << "\n\t v.size(): " << v.size()
+ << "\n\t is_col_vector(v): " << is_col_vector(v)
+ );
+
+ result.set_size(v.nr(),v.nc());
+ result = 0;
+
+ for (unsigned long k = 0; k < edges.size(); ++k)
+ {
+ const long i = edges[k].index1();
+ const long j = edges[k].index2();
+ const double d = edges[k].distance();
+
+ result(i) += v(j)*d;
+ if (i != j)
+ result(j) += v(i)*d;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const std::vector<sample_pair>& edges,
+ const matrix_exp<EXP>& v
+ )
+ {
+ matrix<typename EXP::type,0,1> result;
+ sparse_matrix_vector_multiply(edges,v,result);
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP, typename T, long NR, long NC, typename MM, typename L>
+ void sparse_matrix_vector_multiply (
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix_exp<EXP>& v,
+ matrix<T,NR,NC,MM,L>& result
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_index_plus_one(edges) <= (unsigned long)v.size() &&
+ is_col_vector(v),
+ "\t void sparse_matrix_vector_multiply()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges)
+ << "\n\t v.size(): " << v.size()
+ << "\n\t is_col_vector(v): " << is_col_vector(v)
+ );
+
+
+ result.set_size(v.nr(),v.nc());
+ result = 0;
+
+ for (unsigned long k = 0; k < edges.size(); ++k)
+ {
+ const long i = edges[k].index1();
+ const long j = edges[k].index2();
+ const double d = edges[k].distance();
+
+ result(i) += v(j)*d;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix_exp<EXP>& v
+ )
+ {
+ matrix<typename EXP::type,0,1> result;
+ sparse_matrix_vector_multiply(edges,v,result);
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename sparse_vector_type,
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void sparse_matrix_vector_multiply (
+ const matrix_exp<EXP>& m,
+ const sparse_vector_type& v,
+ matrix<T,NR,NC,MM,L>& result
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_index_plus_one(v) <= (unsigned long)m.nc(),
+ "\t void sparse_matrix_vector_multiply()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t max_index_plus_one(v): " << max_index_plus_one(v)
+ << "\n\t m.size(): " << m.size()
+ );
+
+ result.set_size(m.nr(),1);
+ result = 0;
+
+ for (typename sparse_vector_type::const_iterator i = v.begin(); i != v.end(); ++i)
+ {
+ for (long r = 0; r < result.nr(); ++r)
+ {
+ result(r) += m(r, i->first)*i->second;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename sparse_vector_type
+ >
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const matrix_exp<EXP>& m,
+ const sparse_vector_type& v
+ )
+ {
+ matrix<typename EXP::type,0,1> result;
+ sparse_matrix_vector_multiply(m,v,result);
+ return result;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_SPARSE_VECTOR
+
diff --git a/ml/dlib/dlib/svm/sparse_vector_abstract.h b/ml/dlib/dlib/svm/sparse_vector_abstract.h
new file mode 100644
index 000000000..e0c8d1f8c
--- /dev/null
+++ b/ml/dlib/dlib/svm/sparse_vector_abstract.h
@@ -0,0 +1,688 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_SPARSE_VECTOR_ABSTRACT_
+#ifdef DLIB_SVm_SPARSE_VECTOR_ABSTRACT_
+
+#include <cmath>
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix.h"
+#include <map>
+#include <vector>
+#include "../graph_utils/sample_pair_abstract.h"
+#include "../graph_utils/ordered_sample_pair_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A sparse_vectors
+
+ In dlib, sparse vectors are represented using the container objects
+ in the C++ STL. In particular, a sparse vector is any container that
+ contains a range of std::pair<key, scalar_value> objects where:
+ - key is an unsigned integral type
+ - scalar_value is float, double, or long double
+ - the std::pair objects have unique key values
+ - the std::pair objects are sorted such that small keys come first
+
+ Therefore, if an object satisfies the above requirements we call it a
+ "sparse vector". Additionally, we define the concept of an "unsorted sparse vector"
+ to be a sparse vector that doesn't necessarily have sorted or unique key values.
+ Therefore, all sparse vectors are valid unsorted sparse vectors but not the other
+ way around.
+
+ An unsorted sparse vector with duplicate keys is always interpreted as
+ a vector where each dimension contains the sum of all corresponding elements
+ of the unsorted sparse vector. For example, an unsorted sparse vector
+ with the elements { (3,1), (0, 4), (3,5) } represents the 4D vector:
+ [4, 0, 0, 1+5]
+
+
+
+ Examples of valid sparse vectors are:
+ - std::map<unsigned long, double>
+ - std::vector<std::pair<unsigned long, float> > where the vector is sorted.
+ (you could make sure it was sorted by applying std::sort to it)
+
+
+ Finally, by "dense vector" we mean a dlib::matrix object which represents
+ either a row or column vector.
+
+ The rest of this file defines a number of helper functions for doing normal
+ vector arithmetic things with sparse vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ /*!A has_unsigned_keys
+
+ This is a template where has_unsigned_keys<T>::value == true when T is a
+ sparse vector that contains unsigned integral keys and false otherwise.
+ !*/
+
+ template <typename T>
+ struct has_unsigned_keys
+ {
+ static const bool value = is_unsigned_type<typename T::value_type::first_type>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ typename T::value_type::second_type distance_squared (
+ const T& a,
+ const U& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the squared distance between the vectors
+ a and b
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V, typename W>
+ typename T::value_type::second_type distance_squared (
+ const V& a_scale,
+ const T& a,
+ const W& b_scale,
+ const U& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the squared distance between the vectors
+ a_scale*a and b_scale*b
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ typename T::value_type::second_type distance (
+ const T& a,
+ const U& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the distance between the vectors
+ a and b. (i.e. std::sqrt(distance_squared(a,b)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U, typename V, typename W>
+ typename T::value_type::second_type distance (
+ const V& a_scale,
+ const T& a,
+ const W& b_scale,
+ const U& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the distance between the vectors
+ a_scale*a and b_scale*b. (i.e. std::sqrt(distance_squared(a_scale,a,b_scale,b)))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void assign (
+ T& dest,
+ const U& src
+ );
+ /*!
+ requires
+ - dest == a sparse vector or a dense vector
+ - src == a sparse vector or a dense vector
+ - dest is not dense when src is sparse
+ (i.e. you can't assign a sparse vector to a dense vector. This is
+ because we don't know what the proper dimensionality should be for the
+ dense vector)
+ ensures
+ - #src represents the same vector as dest.
+ (conversion between sparse/dense formats is done automatically)
+ !*/
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type dot (
+ const T& a,
+ const T& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the dot product between the vectors a and b
+ !*/
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6>
+ T4 dot (
+ const std::vector<T1,T2>& a,
+ const std::map<T3,T4,T5,T6>& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the dot product between the vectors a and b
+ !*/
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6>
+ T4 dot (
+ const std::map<T3,T4,T5,T6>& a,
+ const std::vector<T1,T2>& b
+ );
+ /*!
+ requires
+ - a and b are sparse vectors
+ ensures
+ - returns the dot product between the vectors a and b
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename EXP>
+ typename T::value_type::second_type dot (
+ const T& a,
+ const matrix_exp<EXP>& b
+ );
+ /*!
+ requires
+ - a is an unsorted sparse vector
+ - is_vector(b) == true
+ ensures
+ - returns the dot product between the vectors a and b.
+ - if (max_index_plus_one(a) >= b.size()) then
+ - a's dimensionality is greater than b's dimensionality. In this case we
+ pretend b is padded by as many zeros as is needed to make the dot product
+ work. So this means that any elements in a that go beyond the length of
+ b are simply ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename EXP>
+ typename T::value_type::second_type dot (
+ const matrix_exp<EXP>& a,
+ const T& b
+ );
+ /*!
+ requires
+ - b is an unsorted sparse vector
+ - is_vector(a) == true
+ ensures
+ - returns the dot product between the vectors a and b
+ - if (max_index_plus_one(b) >= a.size()) then
+ - b's dimensionality is greater than a's dimensionality. In this case we
+ pretend a is padded by as many zeros as is needed to make the dot product
+ work. So this means that any elements in b that go beyond the length of
+ a are simply ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type length_squared (
+ const T& a
+ );
+ /*!
+ requires
+ - a is a sparse vector
+ ensures
+ - returns dot(a,a)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type length (
+ const T& a
+ );
+ /*!
+ requires
+ - a is a sparse vector
+ ensures
+ - returns std::sqrt(length_squared(a,a))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename U>
+ void scale_by (
+ T& a,
+ const U& value
+ );
+ /*!
+ requires
+ - a is an unsorted sparse vector or a dlib::matrix
+ ensures
+ - #a == a*value
+ (i.e. multiplies every element of the vector a by value)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ T add (
+ const T& a,
+ const T& b
+ );
+ /*!
+ requires
+ - a is a sparse vector or dlib::matrix
+ - b is a sparse vector or dlib::matrix
+ ensures
+ - returns a vector or matrix which represents a+b. If the inputs are
+ sparse vectors then the result is a sparse vector.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ T subtract (
+ const T& a,
+ const T& b
+ );
+ /*!
+ requires
+ - a is a sparse vector or dlib::matrix
+ - b is a sparse vector or dlib::matrix
+ ensures
+ - returns a vector or matrix which represents a-b. If the inputs are
+ sparse vectors then the result is a sparse vector.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ unsigned long max_index_plus_one (
+ const T& samples
+ );
+ /*!
+ requires
+ - samples == a single vector (either sparse or dense), or a container
+ of vectors which is either a dlib::matrix of vectors or something
+ convertible to a dlib::matrix via mat() (e.g. a std::vector)
+ Valid types of samples include (but are not limited to):
+ - dlib::matrix<double,0,1> // A single dense vector
+ - std::map<unsigned int, double> // A single sparse vector
+ - std::vector<dlib::matrix<double,0,1> > // An array of dense vectors
+ - std::vector<std::map<unsigned int, double> > // An array of sparse vectors
+ ensures
+ - This function tells you the dimensionality of a set of vectors. The vectors
+ can be either sparse or dense.
+ - if (samples.size() == 0) then
+ - returns 0
+ - else if (samples contains dense vectors or is a dense vector) then
+ - returns the number of elements in the first sample vector. This means
+ we implicitly assume all dense vectors have the same length)
+ - else
+ - In this case samples contains sparse vectors or is a sparse vector.
+ - returns the largest element index in any sample + 1. Note that the element index values
+ are the values stored in std::pair::first. So this number tells you the dimensionality
+ of a set of sparse vectors.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename SRC, typename U>
+ inline void add_to (
+ matrix<T,NR,NC,MM,L>& dest,
+ const SRC& src,
+ const U& C = 1
+ );
+ /*!
+ requires
+ - SRC == a matrix expression or an unsorted sparse vector
+ - is_vector(dest) == true
+ - Let MAX denote the largest element index in src.
+ Then we require that:
+ - MAX < dest.size()
+ - (i.e. dest needs to be big enough to contain all the elements of src)
+ ensures
+ - #dest == dest + C*src
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, long NR, long NC, typename MM, typename L, typename SRC, typename U>
+ inline void subtract_from (
+ matrix<T,NR,NC,MM,L>& dest,
+ const SRC& src,
+ const U& C = 1
+ );
+ /*!
+ requires
+ - SRC == a matrix expression or an unsorted sparse vector
+ - is_vector(dest) == true
+ - Let MAX denote the largest element index in src.
+ Then we require that:
+ - MAX < dest.size()
+ - (i.e. dest needs to be big enough to contain all the elements of src)
+ ensures
+ - #dest == dest - C*src
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type min (
+ const T& vect
+ );
+ /*!
+ requires
+ - T == an unsorted sparse vector
+ ensures
+ - returns the minimum value in the sparse vector vect. Note that
+ this value is always <= 0 since a sparse vector has an unlimited number
+ of 0 elements.
+ !*/
+
+// ------------------------------------------------------------------------------------
+
+ template <typename T>
+ typename T::value_type::second_type max (
+ const T& vect
+ );
+ /*!
+ requires
+ - T == an unsorted sparse vector
+ ensures
+ - returns the maximum value in the sparse vector vect. Note that
+ this value is always >= 0 since a sparse vector has an unlimited number
+ of 0 elements.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type
+ >
+ matrix<typename sample_type::value_type::second_type,0,1> sparse_to_dense (
+ const sample_type& vect
+ );
+ /*!
+ requires
+ - vect must be a sparse vector or a dense column vector.
+ ensures
+ - converts the single sparse or dense vector vect to a dense (column matrix form)
+ representation. That is, this function returns a vector V such that:
+ - V.size() == max_index_plus_one(vect)
+ - for all valid j:
+ - V(j) == The value of the j'th dimension of the vector vect. Note
+ that V(j) is zero if it is a sparse vector that doesn't contain an
+ entry for the j'th dimension.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type
+ >
+ matrix<typename sample_type::value_type::second_type,0,1> sparse_to_dense (
+ const sample_type& vect,
+ unsigned long num_dimensions
+ );
+ /*!
+ requires
+ - vect must be a sparse vector or a dense column vector.
+ ensures
+ - converts the single sparse or dense vector vect to a dense (column matrix form)
+ representation. That is, this function returns a vector V such that:
+ - V.size() == num_dimensions
+ - for all valid j:
+ - V(j) == The value of the j'th dimension of the vector vect. Note
+ that V(j) is zero if it is a sparse vector that doesn't contain an
+ entry for the j'th dimension.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename alloc
+ >
+ std::vector<matrix<typename sample_type::value_type::second_type,0,1> > sparse_to_dense (
+ const std::vector<sample_type, alloc>& samples
+ );
+ /*!
+ requires
+ - all elements of samples must be sparse vectors or dense column vectors.
+ ensures
+ - converts from sparse sample vectors to dense (column matrix form)
+ - That is, this function returns a std::vector R such that:
+ - R contains column matrices
+ - R.size() == samples.size()
+ - for all valid i:
+ - R[i] == sparse_to_dense(samples[i], max_index_plus_one(samples))
+ (i.e. the dense (i.e. dlib::matrix) version of the sparse sample
+ given by samples[i].)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sample_type,
+ typename alloc
+ >
+ std::vector<matrix<typename sample_type::value_type::second_type,0,1> > sparse_to_dense (
+ const std::vector<sample_type, alloc>& samples,
+ unsigned long num_dimensions
+ );
+ /*!
+ requires
+ - all elements of samples must be sparse vectors or dense column vectors.
+ ensures
+ - converts from sparse sample vectors to dense (column matrix form)
+ - That is, this function returns a std::vector R such that:
+ - R contains column matrices
+ - R.size() == samples.size()
+ - for all valid i:
+ - R[i] == sparse_to_dense(samples[i], num_dimensions)
+ (i.e. the dense (i.e. dlib::matrix) version of the sparse sample
+ given by samples[i].)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T make_sparse_vector (
+ const T& v
+ );
+ /*!
+ requires
+ - v is an unsorted sparse vector
+ ensures
+ - returns a copy of v which is a sparse vector.
+ (i.e. it will be properly sorted and not have any duplicate key values but
+ will still logically represent the same vector).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void make_sparse_vector_inplace(
+ T& vect
+ );
+ /*!
+ requires
+ - v is an unsorted sparse vector
+ ensures
+ - vect == make_sparse_vector(vect)
+ - This function is just an optimized version of make_sparse_vector(), in
+ particular, when T is a std::vector<std::pair<>> type it is much more
+ efficient.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void sparse_matrix_vector_multiply (
+ const std::vector<sample_pair>& edges,
+ const matrix_exp<EXP>& v,
+ matrix<T,NR,NC,MM,L>& result
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - max_index_plus_one(edges) <= v.size()
+ ensures
+ - Interprets edges as representing a symmetric sparse matrix M. The elements
+ of M are defined such that, for all valid i,j:
+ - M(i,j) == sum of edges[k].distance() for all k where edges[k]==sample_pair(i,j)
+ - This means that any element of M that doesn't have any edges associated
+ with it will have a value of 0.
+ - #result == M*v
+ (i.e. this function multiplies the vector v with the sparse matrix
+ represented by edges and stores the output into result)
+ - get_rect(#result) == get_rect(v)
+ (i.e. result will have the same dimensions as v)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void sparse_matrix_vector_multiply (
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix_exp<EXP>& v,
+ matrix<T,NR,NC,MM,L>& result
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - max_index_plus_one(edges) <= v.size()
+ ensures
+ - Interprets edges as representing a square sparse matrix M. The elements of M
+ are defined such that, for all valid i,j:
+ - M(i,j) == sum of edges[k].distance() for all k where edges[k]==ordered_sample_pair(i,j)
+ - This means that any element of M that doesn't have any edges associated
+ with it will have a value of 0.
+ - #result == M*v
+ (i.e. this function multiplies the vector v with the sparse matrix
+ represented by edges and stores the output into result)
+ - get_rect(#result) == get_rect(v)
+ (i.e. result will have the same dimensions as v)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const std::vector<sample_pair>& edges,
+ const matrix_exp<EXP>& v
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - max_index_plus_one(edges) <= v.size()
+ ensures
+ - This is just a convenience routine for invoking the above
+ sparse_matrix_vector_multiply() routine. In particular, it just calls
+ sparse_matrix_vector_multiply() with a temporary result matrix and then
+ returns the result.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP
+ >
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const std::vector<ordered_sample_pair>& edges,
+ const matrix_exp<EXP>& v
+ );
+ /*!
+ requires
+ - is_col_vector(v) == true
+ - max_index_plus_one(edges) <= v.size()
+ ensures
+ - This is just a convenience routine for invoking the above
+ sparse_matrix_vector_multiply() routine. In particular, it just calls
+ sparse_matrix_vector_multiply() with a temporary result matrix and then
+ returns the result.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename sparse_vector_type,
+ typename T,
+ long NR,
+ long NC,
+ typename MM,
+ typename L
+ >
+ void sparse_matrix_vector_multiply (
+ const matrix_exp<EXP>& m,
+ const sparse_vector_type& v,
+ matrix<T,NR,NC,MM,L>& result
+ );
+ /*!
+ requires
+ - max_index_plus_one(v) <= m.nc()
+ - v == an unsorted sparse vector
+ ensures
+ - #result == m*v
+ (i.e. multiply m by the vector v and store the output in result)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename EXP,
+ typename sparse_vector_type
+ >
+ matrix<typename EXP::type,0,1> sparse_matrix_vector_multiply (
+ const matrix_exp<EXP>& m,
+ const sparse_vector_type& v
+ );
+ /*!
+ requires
+ - max_index_plus_one(v) <= m.nc()
+ - v == an unsorted sparse vector
+ ensures
+ - returns m*v
+ (i.e. multiply m by the vector v and return the resulting vector)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_SPARSE_VECTOR_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_assignment_trainer.h b/ml/dlib/dlib/svm/structural_assignment_trainer.h
new file mode 100644
index 000000000..d55b74ff0
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_assignment_trainer.h
@@ -0,0 +1,294 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
+#define DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
+
+#include "structural_assignment_trainer_abstract.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_assignment_problem.h"
+#include "num_nonnegative_weights.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_assignment_trainer
+ {
+ public:
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+ typedef std::vector<long> label_type;
+ typedef assignment_function<feature_extractor> trained_function_type;
+
+ structural_assignment_trainer (
+ )
+ {
+ set_defaults();
+ }
+
+ explicit structural_assignment_trainer (
+ const feature_extractor& fe_
+ ) : fe(fe_)
+ {
+ set_defaults();
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_assignment_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_assignment_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+ bool forces_assignment(
+ ) const { return force_assignment; }
+
+ void set_forces_assignment (
+ bool new_value
+ )
+ {
+ force_assignment = new_value;
+ }
+
+ void set_loss_per_false_association (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_assignment_trainer::set_loss_per_false_association(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_association = loss;
+ }
+
+ double get_loss_per_false_association (
+ ) const
+ {
+ return loss_per_false_association;
+ }
+
+ void set_loss_per_missed_association (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_assignment_trainer::set_loss_per_missed_association(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_missed_association = loss;
+ }
+
+ double get_loss_per_missed_association (
+ ) const
+ {
+ return loss_per_missed_association;
+ }
+
+ bool forces_last_weight_to_1 (
+ ) const
+ {
+ return last_weight_1;
+ }
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ )
+ {
+ last_weight_1 = should_last_weight_be_1;
+ }
+
+ const assignment_function<feature_extractor> train (
+ const std::vector<sample_type>& samples,
+ const std::vector<label_type>& labels
+ ) const
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ if (force_assignment)
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
+ "\t assignment_function structural_assignment_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels),
+ "\t assignment_function structural_assignment_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+#endif
+
+
+
+ structural_svm_assignment_problem<feature_extractor> prob(samples,labels, fe, force_assignment, num_threads,
+ loss_per_false_association, loss_per_missed_association);
+
+ if (verbose)
+ prob.be_verbose();
+
+ prob.set_c(C);
+ prob.set_epsilon(eps);
+ prob.set_max_cache_size(max_cache_size);
+
+ matrix<double,0,1> weights;
+
+ // Take the min here because we want to prevent the user from accidentally
+ // forcing the bias term to be non-negative.
+ const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe));
+ if (last_weight_1)
+ solver(prob, weights, num_nonneg, fe.num_features()-1);
+ else
+ solver(prob, weights, num_nonneg);
+
+ const double bias = weights(weights.size()-1);
+ return assignment_function<feature_extractor>(colm(weights,0,weights.size()-1), bias,fe,force_assignment);
+
+ }
+
+
+ private:
+
+ bool force_assignment;
+ double C;
+ oca solver;
+ double eps;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ double loss_per_false_association;
+ double loss_per_missed_association;
+ bool last_weight_1;
+
+ void set_defaults ()
+ {
+ force_assignment = false;
+ C = 100;
+ verbose = false;
+ eps = 0.01;
+ num_threads = 2;
+ max_cache_size = 5;
+ loss_per_false_association = 1;
+ loss_per_missed_association = 1;
+ last_weight_1 = false;
+ }
+
+ feature_extractor fe;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
+
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h b/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h
new file mode 100644
index 000000000..ebd402d42
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h
@@ -0,0 +1,299 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "structural_svm_assignment_problem.h"
+#include "assignment_function_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_assignment_trainer
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to solve an assignment problem based
+ on a training dataset of example assignments. The training procedure produces an
+ assignment_function object which can be used to predict the assignments of
+ new data.
+
+ Note that this is just a convenience wrapper around the
+ structural_svm_assignment_problem to make it look
+ similar to all the other trainers in dlib.
+ !*/
+
+ public:
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+ typedef std::vector<long> label_type;
+ typedef assignment_function<feature_extractor> trained_function_type;
+
+ structural_assignment_trainer (
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.01
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #get_feature_extractor() == a default initialized feature_extractor
+ - #forces_assignment() == false
+ - #get_loss_per_false_association() == 1
+ - #get_loss_per_missed_association() == 1
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ explicit structural_assignment_trainer (
+ const feature_extractor& fe
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.01
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 40
+ - #get_feature_extractor() == fe
+ - #forces_assignment() == false
+ - #get_loss_per_false_association() == 1
+ - #get_loss_per_missed_association() == 1
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average number of assignment mistakes per
+ training sample is within epsilon of its optimal value".
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the assignment_function on
+ each training sample, over and over. To speed this up, it is possible to
+ cache the results of these invocations. This function returns the number
+ of cache elements per training sample kept in the cache. Note that a value
+ of 0 means caching is not used at all.
+ !*/
+
+ void set_loss_per_false_association (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_false_association() == loss
+ !*/
+
+ double get_loss_per_false_association (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for associating two objects
+ together that shouldn't be associated. If you care more about avoiding
+ accidental associations than ensuring all possible associations are
+ identified then then you can increase this value.
+ !*/
+
+ void set_loss_per_missed_association (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_missed_association() == loss
+ !*/
+
+ double get_loss_per_missed_association (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for failing to associate two
+ objects that are supposed to be associated. If you care more about
+ getting all the associations than avoiding accidentally associating
+ objects that shouldn't be associated then you can increase this value.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the structural SVM problem.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter
+ that determines the trade-off between trying to fit the training
+ data (i.e. minimize the loss) or allowing more errors but hopefully
+ improving the generalization of the resulting assignment_function.
+ Larger values encourage exact fitting while smaller values of C may
+ encourage better generalization.
+ !*/
+
+ void set_forces_assignment (
+ bool new_value
+ );
+ /*!
+ ensures
+ - #forces_assignment() == new_value
+ !*/
+
+ bool forces_assignment(
+ ) const;
+ /*!
+ ensures
+ - returns the value of the forces_assignment() parameter for the
+ assignment_functions generated by this object.
+ !*/
+
+ bool forces_last_weight_to_1 (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer has the constraint that the last weight in
+ the learned parameter vector must be 1. This is the weight corresponding
+ to the feature in the training vectors with the highest dimension.
+ - Forcing the last weight to 1 also disables the bias and therefore the
+ get_bias() field of the learned assignment_function will be 0 when
+ forces_last_weight_to_1() == true.
+ !*/
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ );
+ /*!
+ ensures
+ - #forces_last_weight_to_1() == should_last_weight_be_1
+ !*/
+
+ const assignment_function<feature_extractor> train (
+ const std::vector<sample_type>& samples,
+ const std::vector<label_type>& labels
+ ) const;
+ /*!
+ requires
+ - is_assignment_problem(samples,labels) == true
+ - if (forces_assignment()) then
+ - is_forced_assignment_problem(samples,labels) == true
+ ensures
+ - Uses the structural_svm_assignment_problem to train an
+ assignment_function on the given samples/labels training pairs.
+ The idea is to learn to predict a label given an input sample.
+ - returns a function F with the following properties:
+ - F(new_sample) == A set of assignments indicating how the elements of
+ new_sample.first match up with the elements of new_sample.second.
+ - F.forces_assignment() == forces_assignment()
+ - F.get_feature_extractor() == get_feature_extractor()
+ - if (forces_last_weight_to_1()) then
+ - F.get_bias() == 0
+ - F.get_weights()(F.get_weights().size()-1) == 1
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h b/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h
new file mode 100644
index 000000000..4d55c772b
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h
@@ -0,0 +1,282 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
+#define DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
+
+#include "structural_graph_labeling_trainer_abstract.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_graph_labeling_problem.h"
+#include "../graph_cuts/graph_labeler.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ class structural_graph_labeling_trainer
+ {
+ public:
+ typedef std::vector<bool> label_type;
+ typedef graph_labeler<vector_type> trained_function_type;
+
+ structural_graph_labeling_trainer (
+ )
+ {
+ C = 10;
+ verbose = false;
+ eps = 0.1;
+ num_threads = 2;
+ max_cache_size = 5;
+ loss_pos = 1.0;
+ loss_neg = 1.0;
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_graph_labeling_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_graph_labeling_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+
+ void set_loss_on_positive_class (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0,
+ "\t structural_graph_labeling_trainer::set_loss_on_positive_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this );
+
+ loss_pos = loss;
+ }
+
+ void set_loss_on_negative_class (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0,
+ "\t structural_graph_labeling_trainer::set_loss_on_negative_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this );
+
+ loss_neg = loss;
+ }
+
+ double get_loss_on_negative_class (
+ ) const { return loss_neg; }
+
+ double get_loss_on_positive_class (
+ ) const { return loss_pos; }
+
+
+ template <
+ typename graph_type
+ >
+ const graph_labeler<vector_type> train (
+ const dlib::array<graph_type>& samples,
+ const std::vector<label_type>& labels,
+ const std::vector<std::vector<double> >& losses
+ ) const
+ {
+#ifdef ENABLE_ASSERTS
+ std::string reason_for_failure;
+ DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
+ "\t void structural_graph_labeling_trainer::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t reason_for_failure: " << reason_for_failure
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t this: " << this );
+ DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
+ all_values_are_nonnegative(losses) == true,
+ "\t void structural_graph_labeling_trainer::train()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t losses.size(): " << losses.size()
+ << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
+ << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
+ << "\n\t this: " << this );
+#endif
+
+
+ structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
+
+ if (verbose)
+ prob.be_verbose();
+
+ prob.set_c(C);
+ prob.set_epsilon(eps);
+ prob.set_max_cache_size(max_cache_size);
+ if (prob.get_losses().size() == 0)
+ {
+ prob.set_loss_on_positive_class(loss_pos);
+ prob.set_loss_on_negative_class(loss_neg);
+ }
+
+ matrix<double,0,1> w;
+ solver(prob, w, prob.get_num_edge_weights());
+
+ vector_type edge_weights;
+ vector_type node_weights;
+ populate_weights(w, edge_weights, node_weights, prob.get_num_edge_weights());
+ return graph_labeler<vector_type>(edge_weights, node_weights);
+ }
+
+ template <
+ typename graph_type
+ >
+ const graph_labeler<vector_type> train (
+ const dlib::array<graph_type>& samples,
+ const std::vector<label_type>& labels
+ ) const
+ {
+ std::vector<std::vector<double> > losses;
+ return train(samples, labels, losses);
+ }
+
+ private:
+
+ template <typename T>
+ typename enable_if<is_matrix<T> >::type populate_weights (
+ const matrix<double,0,1>& w,
+ T& edge_weights,
+ T& node_weights,
+ long split_idx
+ ) const
+ {
+ edge_weights = rowm(w,range(0, split_idx-1));
+ node_weights = rowm(w,range(split_idx,w.size()-1));
+ }
+
+ template <typename T>
+ typename disable_if<is_matrix<T> >::type populate_weights (
+ const matrix<double,0,1>& w,
+ T& edge_weights,
+ T& node_weights,
+ long split_idx
+ ) const
+ {
+ edge_weights.clear();
+ node_weights.clear();
+ for (long i = 0; i < split_idx; ++i)
+ {
+ if (w(i) != 0)
+ edge_weights.insert(edge_weights.end(), std::make_pair(i,w(i)));
+ }
+ for (long i = split_idx; i < w.size(); ++i)
+ {
+ if (w(i) != 0)
+ node_weights.insert(node_weights.end(), std::make_pair(i-split_idx,w(i)));
+ }
+ }
+
+
+ double C;
+ oca solver;
+ double eps;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ double loss_pos;
+ double loss_neg;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h
new file mode 100644
index 000000000..df88096a0
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h
@@ -0,0 +1,265 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_graph_labeling_problem_abstract.h"
+#include "../graph_cuts/graph_labeler_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ class structural_graph_labeling_trainer
+ {
+ /*!
+ REQUIREMENTS ON vector_type
+ - vector_type is a dlib::matrix capable of representing column
+ vectors or it is a sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to solve a graph labeling problem based
+ on a training dataset of example labeled graphs. The training procedure
+ produces a graph_labeler object which can be used to predict the labelings
+ of new graphs.
+
+ Note that this is just a convenience wrapper around the
+ structural_svm_graph_labeling_problem to make it look
+ similar to all the other trainers in dlib.
+ !*/
+
+ public:
+ typedef std::vector<bool> label_type;
+ typedef graph_labeler<vector_type> trained_function_type;
+
+ structural_graph_labeling_trainer (
+ );
+ /*!
+ ensures
+ - #get_c() == 10
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #get_loss_on_positive_class() == 1.0
+ - #get_loss_on_negative_class() == 1.0
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average number of labeling mistakes per
+ example graph is within epsilon of its optimal value".
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the graph_labeler on each
+ training sample, over and over. To speed this up, it is possible to
+ cache the results of these invocations. This function returns the number
+ of cache elements per training sample kept in the cache. Note that a value
+ of 0 means caching is not used at all.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the structural SVM problem.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter
+ that determines the trade-off between trying to fit the training
+ data (i.e. minimize the loss) or allowing more errors but hopefully
+ improving the generalization of the resulting graph_labeler. Larger
+ values encourage exact fitting while smaller values of C may encourage
+ better generalization.
+ !*/
+
+ void set_loss_on_positive_class (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ ensures
+ - #get_loss_on_positive_class() == loss
+ !*/
+
+ void set_loss_on_negative_class (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ ensures
+ - #get_loss_on_negative_class() == loss
+ !*/
+
+ double get_loss_on_positive_class (
+ ) const;
+ /*!
+ ensures
+ - returns the loss incurred when a graph node which is supposed to have
+ a label of true gets misclassified. This value controls how much we care
+ about correctly classifying nodes which should be labeled as true. Larger
+ loss values indicate that we care more strongly than smaller values.
+ !*/
+
+ double get_loss_on_negative_class (
+ ) const;
+ /*!
+ ensures
+ - returns the loss incurred when a graph node which is supposed to have
+ a label of false gets misclassified. This value controls how much we care
+ about correctly classifying nodes which should be labeled as false. Larger
+ loss values indicate that we care more strongly than smaller values.
+ !*/
+
+ template <
+ typename graph_type
+ >
+ const graph_labeler<vector_type> train (
+ const dlib::array<graph_type>& samples,
+ const std::vector<label_type>& labels
+ ) const;
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ ensures
+ - Uses the structural_svm_graph_labeling_problem to train a graph_labeler
+ on the given samples/labels training pairs. The idea is to learn to
+ predict a label given an input sample.
+ - The values of get_loss_on_positive_class() and get_loss_on_negative_class()
+ are used to determine how to value mistakes on each node during training.
+ - returns a function F with the following properties:
+ - F(new_sample) == The predicted labels for the nodes in the graph
+ new_sample.
+ !*/
+
+ template <
+ typename graph_type
+ >
+ const graph_labeler<vector_type> train (
+ const dlib::array<graph_type>& samples,
+ const std::vector<label_type>& labels,
+ const std::vector<std::vector<double> >& losses
+ ) const;
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - if (losses.size() != 0) then
+ - sizes_match(labels, losses) == true
+ - all_values_are_nonnegative(losses) == true
+ ensures
+ - Uses the structural_svm_graph_labeling_problem to train a graph_labeler
+ on the given samples/labels training pairs. The idea is to learn to
+ predict a label given an input sample.
+ - returns a function F with the following properties:
+ - F(new_sample) == The predicted labels for the nodes in the graph
+ new_sample.
+ - if (losses.size() == 0) then
+ - The values of get_loss_on_positive_class() and get_loss_on_negative_class()
+ are used to determine how to value mistakes on each node during training.
+ - The losses argument is effectively ignored if its size is zero.
+ - else
+ - Each node in the training data has its own loss value defined by the
+ corresponding entry of losses. In particular, this means that the
+ node with label labels[i][j] incurs a loss of losses[i][j] if it is
+ incorrectly labeled.
+ - The get_loss_on_positive_class() and get_loss_on_negative_class()
+ parameters are ignored. Only losses is used in this case.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_object_detection_trainer.h b/ml/dlib/dlib/svm/structural_object_detection_trainer.h
new file mode 100644
index 000000000..bdf8c5b5c
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_object_detection_trainer.h
@@ -0,0 +1,402 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
+#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
+
+#include "structural_object_detection_trainer_abstract.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_object_detection_problem.h"
+#include "../image_processing/object_detector.h"
+#include "../image_processing/box_overlap_testing.h"
+#include "../image_processing/full_object_detection.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type,
+ typename svm_struct_prob_type
+ >
+ void configure_nuclear_norm_regularizer (
+ const image_scanner_type&,
+ svm_struct_prob_type&
+ )
+ {
+ // does nothing by default. Specific scanner types overload this function to do
+ // whatever is appropriate.
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ class structural_object_detection_trainer : noncopyable
+ {
+
+ public:
+ typedef double scalar_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef object_detector<image_scanner_type> trained_function_type;
+
+
+ explicit structural_object_detection_trainer (
+ const image_scanner_type& scanner_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
+ "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
+ << "\n\t You can't have zero detection templates"
+ << "\n\t this: " << this
+ );
+
+ C = 1;
+ verbose = false;
+ eps = 0.1;
+ num_threads = 2;
+ max_cache_size = 5;
+ match_eps = 0.5;
+ loss_per_missed_target = 1;
+ loss_per_false_alarm = 1;
+
+ scanner.copy_configuration(scanner_);
+
+ auto_overlap_tester = true;
+ }
+
+ const image_scanner_type& get_scanner (
+ ) const
+ {
+ return scanner;
+ }
+
+ bool auto_set_overlap_tester (
+ ) const
+ {
+ return auto_overlap_tester;
+ }
+
+ void set_overlap_tester (
+ const test_box_overlap& tester
+ )
+ {
+ overlap_tester = tester;
+ auto_overlap_tester = false;
+ }
+
+ test_box_overlap get_overlap_tester (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(auto_set_overlap_tester() == false,
+ "\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()"
+ << "\n\t You can't call this function if the overlap tester is generated dynamically."
+ << "\n\t this: " << this
+ );
+
+ return overlap_tester;
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_object_detection_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ scalar_type get_epsilon (
+ ) const { return eps; }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_object_detection_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ void set_match_eps (
+ double eps
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < eps && eps < 1,
+ "\t void structural_object_detection_trainer::set_match_eps(eps)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t eps: " << eps
+ << "\n\t this: " << this
+ );
+
+ match_eps = eps;
+ }
+
+ double get_match_eps (
+ ) const
+ {
+ return match_eps;
+ }
+
+ double get_loss_per_missed_target (
+ ) const
+ {
+ return loss_per_missed_target;
+ }
+
+ void set_loss_per_missed_target (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_missed_target = loss;
+ }
+
+ double get_loss_per_false_alarm (
+ ) const
+ {
+ return loss_per_false_alarm;
+ }
+
+ void set_loss_per_false_alarm (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_alarm = loss;
+ }
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections
+ ) const
+ {
+ std::vector<std::vector<rectangle> > empty_ignore(images.size());
+ return train_impl(images, truth_object_detections, empty_ignore, test_box_overlap());
+ }
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester = test_box_overlap()
+ ) const
+ {
+ return train_impl(images, truth_object_detections, ignore, ignore_overlap_tester);
+ }
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_object_detections
+ ) const
+ {
+ std::vector<std::vector<rectangle> > empty_ignore(images.size());
+ return train(images, truth_object_detections, empty_ignore, test_box_overlap());
+ }
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester = test_box_overlap()
+ ) const
+ {
+ std::vector<std::vector<full_object_detection> > truth_dets(truth_object_detections.size());
+ for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
+ {
+ for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
+ {
+ truth_dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
+ }
+ }
+
+ return train_impl(images, truth_dets, ignore, ignore_overlap_tester);
+ }
+
+ private:
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train_impl (
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester
+ ) const
+ {
+#ifdef ENABLE_ASSERTS
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true && images.size() == ignore.size(),
+ "\t trained_function_type structural_object_detection_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t images.size(): " << images.size()
+ << "\n\t ignore.size(): " << ignore.size()
+ << "\n\t truth_object_detections.size(): " << truth_object_detections.size()
+ << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
+ );
+ for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
+ {
+ for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
+ {
+ DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
+ all_parts_in_rect(truth_object_detections[i][j]) == true,
+ "\t trained_function_type structural_object_detection_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
+ truth_object_detections[i][j].num_parts()
+ << "\n\t get_scanner().get_num_movable_components_per_detection_template(): " <<
+ get_scanner().get_num_movable_components_per_detection_template()
+ << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
+ );
+ }
+ }
+#endif
+
+ structural_svm_object_detection_problem<image_scanner_type,image_array_type >
+ svm_prob(scanner, overlap_tester, auto_overlap_tester, images,
+ truth_object_detections, ignore, ignore_overlap_tester, num_threads);
+
+ if (verbose)
+ svm_prob.be_verbose();
+
+ svm_prob.set_c(C);
+ svm_prob.set_epsilon(eps);
+ svm_prob.set_max_cache_size(max_cache_size);
+ svm_prob.set_match_eps(match_eps);
+ svm_prob.set_loss_per_missed_target(loss_per_missed_target);
+ svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
+ configure_nuclear_norm_regularizer(scanner, svm_prob);
+ matrix<double,0,1> w;
+
+ // Run the optimizer to find the optimal w.
+ solver(svm_prob,w);
+
+ // report the results of the training.
+ return object_detector<image_scanner_type>(scanner, svm_prob.get_overlap_tester(), w);
+ }
+
+ image_scanner_type scanner;
+ test_box_overlap overlap_tester;
+
+ double C;
+ oca solver;
+ double eps;
+ double match_eps;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ double loss_per_missed_target;
+ double loss_per_false_alarm;
+ bool auto_overlap_tester;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h b/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h
new file mode 100644
index 000000000..2dd799874
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h
@@ -0,0 +1,390 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_
+#ifdef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_
+
+#include "structural_svm_object_detection_problem_abstract.h"
+#include "../image_processing/object_detector_abstract.h"
+#include "../image_processing/box_overlap_testing_abstract.h"
+#include "../image_processing/full_object_detection_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type
+ >
+ class structural_object_detection_trainer : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON image_scanner_type
+ image_scanner_type must be an implementation of
+ dlib/image_processing/scan_fhog_pyramid_abstract.h or
+ dlib/image_processing/scan_image_custom_abstract.h or
+ dlib/image_processing/scan_image_pyramid_abstract.h or
+ dlib/image_processing/scan_image_boxes_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to detect objects in images based on a
+ set of labeled images. The training procedure produces an object_detector
+ which can be used to predict the locations of objects in new images.
+
+ Note that this is just a convenience wrapper around the structural_svm_object_detection_problem
+ to make it look similar to all the other trainers in dlib.
+ !*/
+
+ public:
+ typedef double scalar_type;
+ typedef default_memory_manager mem_manager_type;
+ typedef object_detector<image_scanner_type> trained_function_type;
+
+
+ explicit structural_object_detection_trainer (
+ const image_scanner_type& scanner
+ );
+ /*!
+ requires
+ - scanner.get_num_detection_templates() > 0
+ ensures
+ - #get_c() == 1
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #get_match_eps() == 0.5
+ - #get_loss_per_missed_target() == 1
+ - #get_loss_per_false_alarm() == 1
+ - This object will attempt to learn a model for the given
+ scanner object when train() is called.
+ - #get_scanner() == scanner
+ (note that only the "configuration" of scanner is copied.
+ I.e. the copy is done using copy_configuration())
+ - #auto_set_overlap_tester() == true
+ !*/
+
+ const image_scanner_type& get_scanner (
+ ) const;
+ /*!
+ ensures
+ - returns the image scanner used by this object.
+ !*/
+
+ bool auto_set_overlap_tester (
+ ) const;
+ /*!
+ ensures
+ - if (this object will automatically determine an appropriate
+ state for the overlap tester used for non-max suppression.) then
+ - returns true
+ - In this case, it is determined using the find_tight_overlap_tester()
+ routine based on the truth_object_detections given to the
+ structural_object_detection_trainer::train() method.
+ - else
+ - returns false
+ !*/
+
+ void set_overlap_tester (
+ const test_box_overlap& tester
+ );
+ /*!
+ ensures
+ - #get_overlap_tester() == tester
+ - #auto_set_overlap_tester() == false
+ !*/
+
+ test_box_overlap get_overlap_tester (
+ ) const;
+ /*!
+ requires
+ - auto_set_overlap_tester() == false
+ ensures
+ - returns the overlap tester object which will be used to perform non-max suppression.
+ In particular, this function returns the overlap tester which will populate the
+ object_detector returned by train().
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average loss per sample is within epsilon
+ of its optimal value".
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the object detector on
+ each image, over and over. To speed this up, it is possible to cache
+ the results of these detector invocations. This function returns the
+ number of cache elements per training sample kept in the cache. Note
+ that a value of 0 means caching is not used at all. Note also that
+ each cache element takes up about sizeof(double)*scanner.get_num_dimensions()
+ memory (where scanner is the scanner given to this object's constructor).
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the structural SVM problem.
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ const scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter
+ that determines the trade-off between trying to fit the training
+ data (i.e. minimize the loss) or allowing more errors but hopefully
+ improving the generalization of the resulting detector. Larger
+ values encourage exact fitting while smaller values of C may encourage
+ better generalization.
+ !*/
+
+ void set_match_eps (
+ double eps
+ );
+ /*!
+ requires
+ - 0 < eps < 1
+ ensures
+ - #get_match_eps() == eps
+ !*/
+
+ double get_match_eps (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of alignment necessary for a detection to be considered
+ as matching with a ground truth rectangle. If it doesn't match then
+ it is considered to be a false alarm. To define this precisely, let
+ A and B be two rectangles, then A and B match if and only if:
+ A.intersect(B).area()/(A+B).area() > get_match_eps()
+ !*/
+
+ double get_loss_per_missed_target (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for failing to detect one of the
+ targets. If you care more about finding targets than having a low false
+ alarm rate then you can increase this value.
+ !*/
+
+ void set_loss_per_missed_target (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_missed_target() == loss
+ !*/
+
+ double get_loss_per_false_alarm (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for emitting a false alarm detection.
+ Or in other words, the loss for generating a detection that doesn't correspond
+ to one of the truth rectangles. If you care more about having a low false
+ alarm rate than finding all the targets then you can increase this value.
+ !*/
+
+ void set_loss_per_false_alarm (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_false_alarm() == loss
+ !*/
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(images, truth_object_detections) == true
+ - it must be valid to pass images[0] into the image_scanner_type::load() method.
+ (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h)
+ - for all valid i, j:
+ - truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template()
+ - all_parts_in_rect(truth_object_detections[i][j]) == true
+ ensures
+ - Uses the structural_svm_object_detection_problem to train an object_detector
+ on the given images and truth_object_detections.
+ - returns a function F with the following properties:
+ - F(new_image) == A prediction of what objects are present in new_image. This
+ is a set of rectangles indicating their positions.
+ !*/
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_object_detections
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(images, truth_object_detections) == true
+ - it must be valid to pass images[0] into the image_scanner_type::load() method.
+ (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h)
+ - get_scanner().get_num_movable_components_per_detection_template() == 0
+ ensures
+ - This function is identical to the above train(), except that it converts
+ each element of truth_object_detections into a full_object_detection by
+ passing it to full_object_detection's constructor taking only a rectangle.
+ Therefore, this version of train() is a convenience function for for the
+ case where you don't have any movable components of the detection templates.
+ !*/
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester = test_box_overlap()
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(images, truth_object_detections) == true
+ - it must be valid to pass images[0] into the image_scanner_type::load() method.
+ (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h)
+ - ignore.size() == images.size()
+ - for all valid i, j:
+ - truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template()
+ - all_parts_in_rect(truth_object_detections[i][j]) == true
+ ensures
+ - Uses the structural_svm_object_detection_problem to train an object_detector
+ on the given images and truth_object_detections.
+ - for all valid i:
+ - Within images[i] any detections that match against a rectangle in
+ ignore[i], according to ignore_overlap_tester, are ignored. That is,
+ the optimizer doesn't care if the detector outputs a detection that
+ matches any of the ignore rectangles or if it fails to output a
+ detection for an ignore rectangle. Therefore, if there are objects
+ in your dataset that you are unsure if you want to detect or otherwise
+ don't care if the detector gets or doesn't then you can mark them
+ with ignore rectangles and the optimizer will simply ignore them.
+ - returns a function F with the following properties:
+ - F(new_image) == A prediction of what objects are present in new_image. This
+ is a set of rectangles indicating their positions.
+ !*/
+
+ template <
+ typename image_array_type
+ >
+ const trained_function_type train (
+ const image_array_type& images,
+ const std::vector<std::vector<rectangle> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester = test_box_overlap()
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(images, truth_object_detections) == true
+ - ignore.size() == images.size()
+ - it must be valid to pass images[0] into the image_scanner_type::load() method.
+ (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h)
+ - get_scanner().get_num_movable_components_per_detection_template() == 0
+ ensures
+ - This function is identical to the above train(), except that it converts
+ each element of truth_object_detections into a full_object_detection by
+ passing it to full_object_detection's constructor taking only a rectangle.
+ Therefore, this version of train() is a convenience function for for the
+ case where you don't have any movable components of the detection templates.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h
new file mode 100644
index 000000000..9b61fd6c2
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h
@@ -0,0 +1,271 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_
+#define DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_
+
+#include "structural_sequence_labeling_trainer_abstract.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_sequence_labeling_problem.h"
+#include "num_nonnegative_weights.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_sequence_labeling_trainer
+ {
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<unsigned long> labeled_sequence_type;
+
+ typedef sequence_labeler<feature_extractor> trained_function_type;
+
+ explicit structural_sequence_labeling_trainer (
+ const feature_extractor& fe_
+ ) : fe(fe_)
+ {
+ set_defaults();
+ }
+
+ structural_sequence_labeling_trainer (
+ )
+ {
+ set_defaults();
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return fe; }
+
+ unsigned long num_labels (
+ ) const { return fe.num_labels(); }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_sequence_labeling_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_sequence_labeling_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+ double get_loss (
+ unsigned long label
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(label < num_labels(),
+ "\t void structural_sequence_labeling_trainer::get_loss()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t label: " << label
+ << "\n\t num_labels(): " << num_labels()
+ << "\n\t this: " << this
+ );
+
+ return loss_values[label];
+ }
+
+ void set_loss (
+ unsigned long label,
+ double value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(label < num_labels() && value >= 0,
+ "\t void structural_sequence_labeling_trainer::set_loss()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t label: " << label
+ << "\n\t num_labels(): " << num_labels()
+ << "\n\t value: " << value
+ << "\n\t this: " << this
+ );
+
+ loss_values[label] = value;
+ }
+
+
+ const sequence_labeler<feature_extractor> train(
+ const std::vector<sample_sequence_type>& x,
+ const std::vector<labeled_sequence_type>& y
+ ) const
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true &&
+ contains_invalid_labeling(get_feature_extractor(), x, y) == false,
+ "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
+ << "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y)
+ << "\n\t this: " << this
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < y.size(); ++i)
+ {
+ for (unsigned long j = 0; j < y[i].size(); ++j)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(y[i][j] < num_labels(),
+ "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
+ << "\n\t The given labels in y are invalid."
+ << "\n\t y[i][j]: " << y[i][j]
+ << "\n\t num_labels(): " << num_labels()
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t this: " << this
+ );
+ }
+ }
+#endif
+
+
+
+
+ structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe, num_threads);
+ matrix<double,0,1> weights;
+ if (verbose)
+ prob.be_verbose();
+
+ prob.set_epsilon(eps);
+ prob.set_max_iterations(max_iterations);
+ prob.set_c(C);
+ prob.set_max_cache_size(max_cache_size);
+ for (unsigned long i = 0; i < loss_values.size(); ++i)
+ prob.set_loss(i,loss_values[i]);
+
+ solver(prob, weights, num_nonnegative_weights(fe));
+
+ return sequence_labeler<feature_extractor>(weights,fe);
+ }
+
+ private:
+
+ double C;
+ oca solver;
+ double eps;
+ unsigned long max_iterations;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ std::vector<double> loss_values;
+
+ void set_defaults ()
+ {
+ C = 100;
+ verbose = false;
+ eps = 0.1;
+ max_iterations = 10000;
+ num_threads = 2;
+ max_cache_size = 5;
+ loss_values.assign(num_labels(), 1);
+ }
+
+ feature_extractor fe;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h
new file mode 100644
index 000000000..43e5f5131
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h
@@ -0,0 +1,266 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "../optimization.h"
+#include "structural_svm_sequence_labeling_problem_abstract.h"
+#include "sequence_labeler_abstract.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_sequence_labeling_trainer
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor defined in dlib/svm/sequence_labeler_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to do sequence labeling based
+ on a set of training data. The training procedure produces a
+ sequence_labeler object which can be used to predict the labels of
+ new data sequences.
+
+ Note that this is just a convenience wrapper around the
+ structural_svm_sequence_labeling_problem to make it look
+ similar to all the other trainers in dlib.
+ !*/
+
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<unsigned long> labeled_sequence_type;
+ typedef sequence_labeler<feature_extractor> trained_function_type;
+
+ structural_sequence_labeling_trainer (
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #get_feature_extractor() == a default initialized feature_extractor
+ !*/
+
+ explicit structural_sequence_labeling_trainer (
+ const feature_extractor& fe
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #get_feature_extractor() == fe
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ unsigned long num_labels (
+ ) const;
+ /*!
+ ensures
+ - returns get_feature_extractor().num_labels()
+ (i.e. returns the number of possible output labels for each
+ element of a sequence)
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average number of labeling mistakes per
+ training sample is within epsilon of its optimal value".
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the sequence_labeler on
+ each training sample, over and over. To speed this up, it is possible to
+ cache the results of these labeler invocations. This function returns the
+ number of cache elements per training sample kept in the cache. Note
+ that a value of 0 means caching is not used at all.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the structural SVM problem.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter
+ that determines the trade-off between trying to fit the training
+ data (i.e. minimize the loss) or allowing more errors but hopefully
+ improving the generalization of the resulting sequence labeler. Larger
+ values encourage exact fitting while smaller values of C may encourage
+ better generalization.
+ !*/
+
+ double get_loss (
+ unsigned long label
+ ) const;
+ /*!
+ requires
+ - label < num_labels()
+ ensures
+ - returns the loss incurred when a sequence element with the given
+ label is misclassified. This value controls how much we care about
+ correctly classifying this type of label. Larger loss values indicate
+ that we care more strongly than smaller values.
+ !*/
+
+ void set_loss (
+ unsigned long label,
+ double value
+ );
+ /*!
+ requires
+ - label < num_labels()
+ - value >= 0
+ ensures
+ - #get_loss(label) == value
+ !*/
+
+ const sequence_labeler<feature_extractor> train(
+ const std::vector<sample_sequence_type>& x,
+ const std::vector<labeled_sequence_type>& y
+ ) const;
+ /*!
+ requires
+ - is_sequence_labeling_problem(x, y) == true
+ - contains_invalid_labeling(get_feature_extractor(), x, y) == false
+ - for all valid i and j: y[i][j] < num_labels()
+ ensures
+ - Uses the structural_svm_sequence_labeling_problem to train a
+ sequence_labeler on the given x/y training pairs. The idea is
+ to learn to predict a y given an input x.
+ - returns a function F with the following properties:
+ - F(new_x) == A sequence of predicted labels for the elements of new_x.
+ - F(new_x).size() == new_x.size()
+ - for all valid i:
+ - F(new_x)[i] == the predicted label of new_x[i]
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_
+
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h
new file mode 100644
index 000000000..2e0214008
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h
@@ -0,0 +1,281 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
+#define DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
+
+#include "structural_sequence_segmentation_trainer_abstract.h"
+#include "structural_sequence_labeling_trainer.h"
+#include "sequence_segmenter.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_sequence_segmentation_trainer
+ {
+ public:
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
+
+ typedef sequence_segmenter<feature_extractor> trained_function_type;
+
+ explicit structural_sequence_segmentation_trainer (
+ const feature_extractor& fe_
+ ) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_))
+ {
+ loss_per_missed_segment = 1;
+ loss_per_false_alarm = 1;
+ }
+
+ structural_sequence_segmentation_trainer (
+ )
+ {
+ loss_per_missed_segment = 1;
+ loss_per_false_alarm = 1;
+ }
+
+ const feature_extractor& get_feature_extractor (
+ ) const { return trainer.get_feature_extractor().fe; }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ trainer.set_num_threads(num);
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return trainer.get_num_threads();
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_sequence_segmentation_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ trainer.set_epsilon(eps_);
+ }
+
+ double get_epsilon (
+ ) const { return trainer.get_epsilon(); }
+
+ unsigned long get_max_iterations (
+ ) const { return trainer.get_max_iterations(); }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ trainer.set_max_iterations(max_iter);
+ }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ trainer.set_max_cache_size(max_size);
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return trainer.get_max_cache_size();
+ }
+
+ void be_verbose (
+ )
+ {
+ trainer.be_verbose();
+ }
+
+ void be_quiet (
+ )
+ {
+ trainer.be_quiet();
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ trainer.set_oca(item);
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return trainer.get_oca();
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_sequence_segmentation_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ trainer.set_c(C_);
+ }
+
+ double get_c (
+ ) const
+ {
+ return trainer.get_c();
+ }
+
+ void set_loss_per_missed_segment (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0,
+ "\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_missed_segment = loss;
+
+ if (feature_extractor::use_BIO_model)
+ {
+ trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
+ trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
+ }
+ else
+ {
+ trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
+ trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
+ trainer.set_loss(impl_ss::LAST, loss_per_missed_segment);
+ trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment);
+ }
+ }
+
+ double get_loss_per_missed_segment (
+ ) const
+ {
+ return loss_per_missed_segment;
+ }
+
+ void set_loss_per_false_alarm (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0,
+ "\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_alarm = loss;
+
+ trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm);
+ }
+
+ double get_loss_per_false_alarm (
+ ) const
+ {
+ return loss_per_false_alarm;
+ }
+
+ const sequence_segmenter<feature_extractor> train(
+ const std::vector<sample_sequence_type>& x,
+ const std::vector<segmented_sequence_type>& y
+ ) const
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_sequence_segmentation_problem(x,y) == true,
+ "\t sequence_segmenter structural_sequence_segmentation_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t is_sequence_segmentation_problem(x,y): " << is_sequence_segmentation_problem(x,y)
+ << "\n\t this: " << this
+ );
+
+ std::vector<std::vector<unsigned long> > labels(y.size());
+ if (feature_extractor::use_BIO_model)
+ {
+ // convert y into tagged BIO labels
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
+ for (unsigned long j = 0; j < y[i].size(); ++j)
+ {
+ const unsigned long begin = y[i][j].first;
+ const unsigned long end = y[i][j].second;
+ if (begin != end)
+ {
+ labels[i][begin] = impl_ss::BEGIN;
+ for (unsigned long k = begin+1; k < end; ++k)
+ labels[i][k] = impl_ss::INSIDE;
+ }
+ }
+ }
+ }
+ else
+ {
+ // convert y into tagged BILOU labels
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
+ for (unsigned long j = 0; j < y[i].size(); ++j)
+ {
+ const unsigned long begin = y[i][j].first;
+ const unsigned long end = y[i][j].second;
+ if (begin != end)
+ {
+ if (begin+1==end)
+ {
+ labels[i][begin] = impl_ss::UNIT;
+ }
+ else
+ {
+ labels[i][begin] = impl_ss::BEGIN;
+ for (unsigned long k = begin+1; k+1 < end; ++k)
+ labels[i][k] = impl_ss::INSIDE;
+ labels[i][end-1] = impl_ss::LAST;
+ }
+ }
+ }
+ }
+ }
+
+ sequence_labeler<impl_ss::feature_extractor<feature_extractor> > temp;
+ temp = trainer.train(x, labels);
+ return sequence_segmenter<feature_extractor>(temp.get_weights(), trainer.get_feature_extractor().fe);
+ }
+
+ private:
+
+ structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer;
+ double loss_per_missed_segment;
+ double loss_per_false_alarm;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h
new file mode 100644
index 000000000..bcd927ca6
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h
@@ -0,0 +1,264 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_
+
+#include "sequence_segmenter_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_sequence_segmentation_trainer
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor defined in dlib/svm/sequence_segmenter_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to do sequence segmentation based on a
+ set of training data. The training procedure produces a sequence_segmenter
+ object which can be used to identify the sub-segments of new data
+ sequences.
+
+ This object internally uses the structural_sequence_labeling_trainer to
+ solve the learning problem.
+ !*/
+
+ public:
+
+ typedef typename feature_extractor::sequence_type sample_sequence_type;
+ typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
+
+ typedef sequence_segmenter<feature_extractor> trained_function_type;
+
+ structural_sequence_segmentation_trainer (
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 40
+ - #get_feature_extractor() == a default initialized feature_extractor
+ - #get_loss_per_missed_segment() == 1
+ - #get_loss_per_false_alarm() == 1
+ !*/
+
+ explicit structural_sequence_segmentation_trainer (
+ const feature_extractor& fe
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 40
+ - #get_feature_extractor() == fe
+ - #get_loss_per_missed_segment() == 1
+ - #get_loss_per_false_alarm() == 1
+ !*/
+
+ const feature_extractor& get_feature_extractor (
+ ) const;
+ /*!
+ ensures
+ - returns the feature extractor used by this object
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ double eps_
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average number of segmentation mistakes
+ per training sample is within epsilon of its optimal value".
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the sequence_segmenter on
+ each training sample, over and over. To speed this up, it is possible to
+ cache the results of these segmenter invocations. This function returns
+ the number of cache elements per training sample kept in the cache. Note
+ that a value of 0 means caching is not used at all.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a user can
+ observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the structural SVM problem.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade-off between trying to fit the training data (i.e.
+ minimize the loss) or allowing more errors but hopefully improving the
+ generalization of the resulting sequence labeler. Larger values
+ encourage exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ void set_loss_per_missed_segment (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ ensures
+ - #get_loss_per_missed_segment() == loss
+ !*/
+
+ double get_loss_per_missed_segment (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss incurred for failing to detect a segment. The
+ larger the loss the more important it is to detect all the segments.
+ !*/
+
+
+ void set_loss_per_false_alarm (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ ensures
+ - #get_loss_per_false_alarm() == loss
+ !*/
+
+ double get_loss_per_false_alarm (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss incurred for outputting a false detection. The
+ larger the loss the more important it is to avoid outputting false
+ detections.
+ !*/
+
+ const sequence_segmenter<feature_extractor> train(
+ const std::vector<sample_sequence_type>& x,
+ const std::vector<segmented_sequence_type>& y
+ ) const;
+ /*!
+ requires
+ - is_sequence_segmentation_problem(x, y) == true
+ ensures
+ - Uses the given training data to learn to do sequence segmentation. That
+ is, this function will try to find a sequence_segmenter capable of
+ predicting y[i] when given x[i] as input. Moreover, it should also be
+ capable of predicting the segmentation of new input sequences. Or in
+ other words, the learned sequence_segmenter should also generalize to new
+ data outside the training dataset.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_svm_assignment_problem.h b/ml/dlib/dlib/svm/structural_svm_assignment_problem.h
new file mode 100644
index 000000000..963af1631
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_assignment_problem.h
@@ -0,0 +1,288 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_
+#define DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_
+
+
+#include "structural_svm_assignment_problem_abstract.h"
+#include "../matrix.h"
+#include <vector>
+#include <iterator>
+#include "structural_svm_problem_threaded.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ template <long n, typename T>
+ struct column_matrix_static_resize
+ {
+ typedef T type;
+ };
+
+ template <long n, typename T, long NR, long NC, typename MM, typename L>
+ struct column_matrix_static_resize<n, matrix<T,NR,NC,MM,L> >
+ {
+ typedef matrix<T,NR+n,NC,MM,L> type;
+ };
+
+ template <long n, typename T, long NC, typename MM, typename L>
+ struct column_matrix_static_resize<n, matrix<T,0,NC,MM,L> >
+ {
+ typedef matrix<T,0,NC,MM,L> type;
+ };
+
+ template <typename T>
+ struct add_one_to_static_feat_size
+ {
+ typedef typename column_matrix_static_resize<1,typename T::feature_vector_type>::type type;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_svm_assignment_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>, typename add_one_to_static_feat_size<feature_extractor>::type >
+ {
+ public:
+ typedef matrix<double,0,1> matrix_type;
+ typedef typename add_one_to_static_feat_size<feature_extractor>::type feature_vector_type;
+
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+
+
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+
+ typedef std::vector<long> label_type;
+
+ structural_svm_assignment_problem(
+ const std::vector<sample_type>& samples_,
+ const std::vector<label_type>& labels_,
+ const feature_extractor& fe_,
+ bool force_assignment_,
+ unsigned long num_threads,
+ const double loss_per_false_association_,
+ const double loss_per_missed_association_
+ ) :
+ structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads),
+ samples(samples_),
+ labels(labels_),
+ fe(fe_),
+ force_assignment(force_assignment_),
+ loss_per_false_association(loss_per_false_association_),
+ loss_per_missed_association(loss_per_missed_association_)
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ DLIB_ASSERT(loss_per_false_association > 0 && loss_per_missed_association > 0,
+ "\t structural_svm_assignment_problem::structural_svm_assignment_problem()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t loss_per_false_association: " << loss_per_false_association
+ << "\n\t loss_per_missed_association: " << loss_per_missed_association
+ << "\n\t this: " << this
+ );
+ if (force_assignment)
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
+ "\t structural_svm_assignment_problem::structural_svm_assignment_problem()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ << "\n\t this: " << this
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels),
+ "\t structural_svm_assignment_problem::structural_svm_assignment_problem()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ << "\n\t this: " << this
+ );
+ }
+#endif
+
+ }
+
+ private:
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return fe.num_features()+1; // +1 for the bias term
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return samples.size();
+ }
+
+ template <typename psi_type>
+ typename enable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
+ const sample_type& sample,
+ const label_type& label,
+ psi_type& psi
+ ) const
+ {
+ typename feature_extractor::feature_vector_type feats;
+ psi.set_size(get_num_dimensions());
+ psi = 0;
+ for (unsigned long i = 0; i < sample.first.size(); ++i)
+ {
+ if (label[i] != -1)
+ {
+ fe.get_features(sample.first[i], sample.second[label[i]], feats);
+ set_rowm(psi,range(0,feats.size()-1)) += feats;
+ psi(get_num_dimensions()-1) += 1;
+ }
+ }
+ }
+
+ template <typename T>
+ void append_to_sparse_vect (
+ T& psi,
+ const T& vect
+ ) const
+ {
+ std::copy(vect.begin(), vect.end(), std::back_inserter(psi));
+ }
+
+ template <typename psi_type>
+ typename disable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
+ const sample_type& sample,
+ const label_type& label,
+ psi_type& psi
+ ) const
+ {
+ psi.clear();
+ feature_vector_type feats;
+ int num_assignments = 0;
+ for (unsigned long i = 0; i < sample.first.size(); ++i)
+ {
+ if (label[i] != -1)
+ {
+ fe.get_features(sample.first[i], sample.second[label[i]], feats);
+ append_to_sparse_vect(psi, feats);
+ ++num_assignments;
+ }
+ }
+ psi.push_back(std::make_pair(get_num_dimensions()-1,num_assignments));
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ get_joint_feature_vector(samples[idx], labels[idx], psi);
+ }
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ double& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ matrix<double> cost;
+ unsigned long size;
+ if (force_assignment)
+ {
+ unsigned long lhs_size = samples[idx].first.size();
+ unsigned long rhs_size = samples[idx].second.size();
+ size = std::max(lhs_size, rhs_size);
+ }
+ else
+ {
+ unsigned long rhs_size = samples[idx].second.size() + samples[idx].first.size();
+ size = rhs_size;
+ }
+ cost.set_size(size, size);
+
+ typename feature_extractor::feature_vector_type feats;
+
+ // now fill out the cost assignment matrix
+ for (long r = 0; r < cost.nr(); ++r)
+ {
+ for (long c = 0; c < cost.nc(); ++c)
+ {
+ if (r < (long)samples[idx].first.size())
+ {
+ if (c < (long)samples[idx].second.size())
+ {
+ fe.get_features(samples[idx].first[r], samples[idx].second[c], feats);
+ const double bias = current_solution(current_solution.size()-1);
+ cost(r,c) = dot(colm(current_solution,0,current_solution.size()-1), feats) + bias;
+
+ // add in the loss since this corresponds to an incorrect prediction.
+ if (c != labels[idx][r])
+ {
+ cost(r,c) += loss_per_false_association;
+ }
+ }
+ else
+ {
+ if (labels[idx][r] == -1)
+ cost(r,c) = 0;
+ else
+ cost(r,c) = loss_per_missed_association;
+ }
+
+ }
+ else
+ {
+ cost(r,c) = 0;
+ }
+ }
+ }
+
+ std::vector<long> assignment;
+
+ if (cost.size() != 0)
+ {
+ // max_cost_assignment() only works with integer matrices, so convert from
+ // double to integer.
+ const double scale = (std::numeric_limits<dlib::int64>::max()/1000)/max(abs(cost));
+ matrix<dlib::int64> int_cost = matrix_cast<dlib::int64>(round(cost*scale));
+ assignment = max_cost_assignment(int_cost);
+ assignment.resize(samples[idx].first.size());
+ }
+
+ loss = 0;
+ // adjust assignment so that non-assignments have a value of -1. Also compute loss.
+ for (unsigned long i = 0; i < assignment.size(); ++i)
+ {
+ if (assignment[i] >= (long)samples[idx].second.size())
+ assignment[i] = -1;
+
+ if (assignment[i] != labels[idx][i])
+ {
+ if (assignment[i] == -1)
+ loss += loss_per_missed_association;
+ else
+ loss += loss_per_false_association;
+ }
+ }
+
+ get_joint_feature_vector(samples[idx], assignment, psi);
+ }
+
+ const std::vector<sample_type>& samples;
+ const std::vector<label_type>& labels;
+ const feature_extractor& fe;
+ bool force_assignment;
+ const double loss_per_false_association;
+ const double loss_per_missed_association;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h
new file mode 100644
index 000000000..c06190726
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h
@@ -0,0 +1,87 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_
+
+
+#include "../matrix.h"
+#include <vector>
+#include "structural_svm_problem_threaded_abstract.h"
+#include "assignment_function_abstract.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ template <
+ typename feature_extractor
+ >
+ class structural_svm_assignment_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>,
+ typename feature_extractor::feature_vector_type >
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning the parameters needed to use an
+ assignment_function object. It learns the parameters by formulating the
+ problem as a structural SVM problem.
+ !*/
+
+ public:
+ typedef matrix<double,0,1> matrix_type;
+ typedef typename feature_extractor::feature_vector_type feature_vector_type;
+ typedef typename feature_extractor::lhs_element lhs_element;
+ typedef typename feature_extractor::rhs_element rhs_element;
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+ typedef std::vector<long> label_type;
+
+ structural_svm_assignment_problem(
+ const std::vector<sample_type>& samples,
+ const std::vector<label_type>& labels,
+ const feature_extractor& fe,
+ bool force_assignment,
+ unsigned long num_threads,
+ const double loss_per_false_association,
+ const double loss_per_missed_association
+ );
+ /*!
+ requires
+ - loss_per_false_association > 0
+ - loss_per_missed_association > 0
+ - is_assignment_problem(samples,labels) == true
+ - if (force_assignment) then
+ - is_forced_assignment_problem(samples,labels) == true
+ ensures
+ - This object attempts to learn a mapping from the given samples to the
+ given labels. In particular, it attempts to learn to predict labels[i]
+ based on samples[i]. Or in other words, this object can be used to learn
+ a parameter vector and bias, w and b, such that an assignment_function declared as:
+ assignment_function<feature_extractor> assigner(w,b,fe,force_assignment)
+ results in an assigner object which attempts to compute the following mapping:
+ labels[i] == labeler(samples[i])
+ - This object will use num_threads threads during the optimization
+ procedure. You should set this parameter equal to the number of
+ available processing cores on your machine.
+ - When solving the structural SVM problem, we will use
+ loss_per_false_association as the loss for incorrectly associating
+ objects that shouldn't be associated.
+ - When solving the structural SVM problem, we will use
+ loss_per_missed_association as the loss for failing to associate to
+ objects that are supposed to be associated with each other.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_distributed.h b/ml/dlib/dlib/svm/structural_svm_distributed.h
new file mode 100644
index 000000000..a9542c70f
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_distributed.h
@@ -0,0 +1,700 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
+#define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
+
+#include <memory>
+#include <iostream>
+#include <vector>
+
+#include "structural_svm_distributed_abstract.h"
+#include "structural_svm_problem.h"
+#include "../bridge.h"
+#include "../misc_api.h"
+#include "../statistics.h"
+#include "../threads.h"
+#include "../pipe.h"
+#include "../type_safe_union.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <typename matrix_type>
+ struct oracle_response
+ {
+ typedef typename matrix_type::type scalar_type;
+
+ matrix_type subgradient;
+ scalar_type loss;
+ long num;
+
+ friend void swap (oracle_response& a, oracle_response& b)
+ {
+ a.subgradient.swap(b.subgradient);
+ std::swap(a.loss, b.loss);
+ std::swap(a.num, b.num);
+ }
+
+ friend void serialize (const oracle_response& item, std::ostream& out)
+ {
+ serialize(item.subgradient, out);
+ dlib::serialize(item.loss, out);
+ dlib::serialize(item.num, out);
+ }
+
+ friend void deserialize (oracle_response& item, std::istream& in)
+ {
+ deserialize(item.subgradient, in);
+ dlib::deserialize(item.loss, in);
+ dlib::deserialize(item.num, in);
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ struct oracle_request
+ {
+ typedef typename matrix_type::type scalar_type;
+
+ matrix_type current_solution;
+ scalar_type saved_current_risk_gap;
+ bool skip_cache;
+ bool converged;
+
+ friend void swap (oracle_request& a, oracle_request& b)
+ {
+ a.current_solution.swap(b.current_solution);
+ std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap);
+ std::swap(a.skip_cache, b.skip_cache);
+ std::swap(a.converged, b.converged);
+ }
+
+ friend void serialize (const oracle_request& item, std::ostream& out)
+ {
+ serialize(item.current_solution, out);
+ dlib::serialize(item.saved_current_risk_gap, out);
+ dlib::serialize(item.skip_cache, out);
+ dlib::serialize(item.converged, out);
+ }
+
+ friend void deserialize (oracle_request& item, std::istream& in)
+ {
+ deserialize(item.current_solution, in);
+ dlib::deserialize(item.saved_current_risk_gap, in);
+ dlib::deserialize(item.skip_cache, in);
+ dlib::deserialize(item.converged, in);
+ }
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class svm_struct_processing_node : noncopyable
+ {
+ public:
+
+ template <
+ typename T,
+ typename U
+ >
+ svm_struct_processing_node (
+ const structural_svm_problem<T,U>& problem,
+ unsigned short port,
+ unsigned short num_threads
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 &&
+ problem.get_num_dimensions() != 0,
+ "\t svm_struct_processing_node()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t port: " << port
+ << "\n\t problem.get_num_samples(): " << problem.get_num_samples()
+ << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions()
+ << "\n\t this: " << this
+ );
+
+ the_problem.reset(new node_type<T,U>(problem, port, num_threads));
+ }
+
+ private:
+
+ struct base
+ {
+ virtual ~base(){}
+ };
+
+ template <
+ typename matrix_type,
+ typename feature_vector_type
+ >
+ class node_type : public base, threaded_object
+ {
+ public:
+ typedef typename matrix_type::type scalar_type;
+
+ node_type(
+ const structural_svm_problem<matrix_type,feature_vector_type>& prob,
+ unsigned short port,
+ unsigned long num_threads
+ ) : in(3),out(3), problem(prob), tp(num_threads)
+ {
+ b.reconfigure(listen_on_port(port), receive(in), transmit(out));
+
+ start();
+ }
+
+ ~node_type()
+ {
+ in.disable();
+ out.disable();
+ wait();
+ }
+
+ private:
+
+ void thread()
+ {
+ using namespace impl;
+ tsu_in msg;
+ tsu_out temp;
+
+ timestamper ts;
+ running_stats<double> with_buffer_time;
+ running_stats<double> without_buffer_time;
+ unsigned long num_iterations_executed = 0;
+
+ while (in.dequeue(msg))
+ {
+ // initialize the cache and compute psi_true.
+ if (cache.size() == 0)
+ {
+ cache.resize(problem.get_num_samples());
+ for (unsigned long i = 0; i < cache.size(); ++i)
+ cache[i].init(&problem,i);
+
+ psi_true.set_size(problem.get_num_dimensions(),1);
+ psi_true = 0;
+
+ const unsigned long num = problem.get_num_samples();
+ feature_vector_type ftemp;
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ cache[i].get_truth_joint_feature_vector_cached(ftemp);
+
+ subtract_from(psi_true, ftemp);
+ }
+ }
+
+
+ if (msg.template contains<bridge_status>() &&
+ msg.template get<bridge_status>().is_connected)
+ {
+ temp = problem.get_num_dimensions();
+ out.enqueue(temp);
+
+ }
+ else if (msg.template contains<oracle_request<matrix_type> >())
+ {
+ ++num_iterations_executed;
+
+ const oracle_request<matrix_type>& req = msg.template get<oracle_request<matrix_type> >();
+
+ oracle_response<matrix_type>& data = temp.template get<oracle_response<matrix_type> >();
+
+ data.subgradient = psi_true;
+ data.loss = 0;
+
+ data.num = problem.get_num_samples();
+
+ const uint64 start_time = ts.get_timestamp();
+
+ // pick fastest buffering strategy
+ bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();
+
+ // every 50 iterations we should try to flip the buffering scheme to see if
+ // doing it the other way might be better.
+ if ((num_iterations_executed%50) == 0)
+ {
+ buffer_subgradients_locally = !buffer_subgradients_locally;
+ }
+
+ binder b(*this, req, data, buffer_subgradients_locally);
+ parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle);
+
+ const uint64 stop_time = ts.get_timestamp();
+ if (buffer_subgradients_locally)
+ with_buffer_time.add(stop_time-start_time);
+ else
+ without_buffer_time.add(stop_time-start_time);
+
+ out.enqueue(temp);
+ }
+ }
+ }
+
+ struct binder
+ {
+ binder (
+ const node_type& self_,
+ const impl::oracle_request<matrix_type>& req_,
+ impl::oracle_response<matrix_type>& data_,
+ bool buffer_subgradients_locally_
+ ) : self(self_), req(req_), data(data_),
+ buffer_subgradients_locally(buffer_subgradients_locally_) {}
+
+ void call_oracle (
+ long begin,
+ long end
+ )
+ {
+ // If we are only going to call the separation oracle once then don't
+ // run the slightly more complex for loop version of this code. Or if
+ // we just don't want to run the complex buffering one. The code later
+ // on decides if we should do the buffering based on how long it takes
+ // to execute. We do this because, when the subgradient is really high
+ // dimensional it can take a lot of time to add them together. So we
+ // might want to avoid doing that.
+ if (end-begin <= 1 || !buffer_subgradients_locally)
+ {
+ scalar_type loss;
+ feature_vector_type ftemp;
+ for (long i = begin; i < end; ++i)
+ {
+ self.cache[i].separation_oracle_cached(req.converged,
+ req.skip_cache,
+ req.saved_current_risk_gap,
+ req.current_solution,
+ loss,
+ ftemp);
+
+ auto_mutex lock(self.accum_mutex);
+ data.loss += loss;
+ add_to(data.subgradient, ftemp);
+ }
+ }
+ else
+ {
+ scalar_type loss = 0;
+ matrix_type faccum(data.subgradient.size(),1);
+ faccum = 0;
+
+ feature_vector_type ftemp;
+
+ for (long i = begin; i < end; ++i)
+ {
+ scalar_type loss_temp;
+ self.cache[i].separation_oracle_cached(req.converged,
+ req.skip_cache,
+ req.saved_current_risk_gap,
+ req.current_solution,
+ loss_temp,
+ ftemp);
+ loss += loss_temp;
+ add_to(faccum, ftemp);
+ }
+
+ auto_mutex lock(self.accum_mutex);
+ data.loss += loss;
+ add_to(data.subgradient, faccum);
+ }
+ }
+
+ const node_type& self;
+ const impl::oracle_request<matrix_type>& req;
+ impl::oracle_response<matrix_type>& data;
+ bool buffer_subgradients_locally;
+ };
+
+
+
+ typedef type_safe_union<impl::oracle_request<matrix_type>, bridge_status> tsu_in;
+ typedef type_safe_union<impl::oracle_response<matrix_type> , long> tsu_out;
+
+ pipe<tsu_in> in;
+ pipe<tsu_out> out;
+ bridge b;
+
+ mutable matrix_type psi_true;
+ const structural_svm_problem<matrix_type,feature_vector_type>& problem;
+ mutable std::vector<cache_element_structural_svm<structural_svm_problem<matrix_type,feature_vector_type> > > cache;
+
+ mutable thread_pool tp;
+ mutex accum_mutex;
+ };
+
+
+ std::unique_ptr<base> the_problem;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class svm_struct_controller_node : noncopyable
+ {
+ public:
+
+ svm_struct_controller_node (
+ ) :
+ eps(0.001),
+ max_iterations(10000),
+ cache_based_eps(std::numeric_limits<double>::infinity()),
+ verbose(false),
+ C(1)
+ {}
+
+ double get_cache_based_epsilon (
+ ) const
+ {
+ return cache_based_eps;
+ }
+
+ void set_cache_based_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svm_struct_controller_node::set_cache_based_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ cache_based_eps = eps_;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svm_struct_controller_node::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet(
+ )
+ {
+ verbose = false;
+ }
+
+ void add_nuclear_norm_regularizer (
+ long first_dimension,
+ long rows,
+ long cols,
+ double regularization_strength
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= first_dimension &&
+ 0 <= rows && 0 <= cols &&
+ 0 < regularization_strength,
+ "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t first_dimension: " << first_dimension
+ << "\n\t rows: " << rows
+ << "\n\t cols: " << cols
+ << "\n\t regularization_strength: " << regularization_strength
+ << "\n\t this: " << this
+ );
+
+ impl::nuclear_norm_regularizer temp;
+ temp.first_dimension = first_dimension;
+ temp.nr = rows;
+ temp.nc = cols;
+ temp.regularization_strength = regularization_strength;
+ nuclear_norm_regularizers.push_back(temp);
+ }
+
+ unsigned long num_nuclear_norm_regularizers (
+ ) const { return nuclear_norm_regularizers.size(); }
+
+ void clear_nuclear_norm_regularizers (
+ ) { nuclear_norm_regularizers.clear(); }
+
+
+ double get_c (
+ ) const { return C; }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void svm_struct_controller_node::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ void add_processing_node (
+ const network_address& addr
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(addr.port != 0,
+ "\t void svm_struct_controller_node::add_processing_node()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t addr.host_address: " << addr.host_address
+ << "\n\t addr.port: " << addr.port
+ << "\n\t this: " << this
+ );
+
+ // check if this address is already registered
+ for (unsigned long i = 0; i < nodes.size(); ++i)
+ {
+ if (nodes[i] == addr)
+ {
+ return;
+ }
+ }
+
+ nodes.push_back(addr);
+ }
+
+ void add_processing_node (
+ const std::string& ip_or_hostname,
+ unsigned short port
+ )
+ {
+ add_processing_node(network_address(ip_or_hostname,port));
+ }
+
+ unsigned long get_num_processing_nodes (
+ ) const
+ {
+ return nodes.size();
+ }
+
+ void remove_processing_nodes (
+ )
+ {
+ nodes.clear();
+ }
+
+ template <typename matrix_type>
+ double operator() (
+ const oca& solver,
+ matrix_type& w
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_num_processing_nodes() != 0,
+ "\t double svm_struct_controller_node::operator()"
+ << "\n\t You must add some processing nodes before calling this function."
+ << "\n\t this: " << this
+ );
+
+ problem_type<matrix_type> problem(nodes);
+ problem.set_cache_based_epsilon(cache_based_eps);
+ problem.set_epsilon(eps);
+ problem.set_max_iterations(max_iterations);
+ if (verbose)
+ problem.be_verbose();
+ problem.set_c(C);
+ for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
+ {
+ problem.add_nuclear_norm_regularizer(
+ nuclear_norm_regularizers[i].first_dimension,
+ nuclear_norm_regularizers[i].nr,
+ nuclear_norm_regularizers[i].nc,
+ nuclear_norm_regularizers[i].regularization_strength);
+ }
+
+ return solver(problem, w);
+ }
+
+ class invalid_problem : public error
+ {
+ public:
+ invalid_problem(
+ const std::string& a
+ ): error(a) {}
+ };
+
+
+ private:
+
+ template <typename matrix_type_>
+ class problem_type : public structural_svm_problem<matrix_type_>
+ {
+ public:
+ typedef typename matrix_type_::type scalar_type;
+ typedef matrix_type_ matrix_type;
+
+ problem_type (
+ const std::vector<network_address>& nodes_
+ ) :
+ nodes(nodes_),
+ in(3),
+ num_dims(0)
+ {
+
+ // initialize all the transmit pipes
+ out_pipes.resize(nodes.size());
+ for (unsigned long i = 0; i < out_pipes.size(); ++i)
+ {
+ out_pipes[i].reset(new pipe<tsu_out>(3));
+ }
+
+ // make bridges that connect to all our remote processing nodes
+ bridges.resize(nodes.size());
+ for (unsigned long i = 0; i< bridges.size(); ++i)
+ {
+ bridges[i].reset(new bridge(connect_to(nodes[i]),
+ receive(in), transmit(*out_pipes[i])));
+ }
+
+
+
+ // The remote processing nodes are supposed to all send the problem dimensionality
+ // upon connection. So get that and make sure everyone agrees on what it's supposed to be.
+ tsu_in temp;
+ unsigned long responses = 0;
+ bool seen_dim = false;
+ while (responses < nodes.size())
+ {
+ in.dequeue(temp);
+ if (temp.template contains<long>())
+ {
+ ++responses;
+ // if this new dimension doesn't match what we have seen previously
+ if (seen_dim && num_dims != temp.template get<long>())
+ {
+ throw invalid_problem("remote hosts disagree on the number of dimensions!");
+ }
+ seen_dim = true;
+ num_dims = temp.template get<long>();
+ }
+ }
+ }
+
+ // These functions are just here because the structural_svm_problem requires
+ // them, but since we are overloading get_risk() they are never called so they
+ // don't matter.
+ virtual long get_num_samples () const {return 0;}
+ virtual void get_truth_joint_feature_vector ( long , matrix_type& ) const {}
+ virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {}
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return num_dims;
+ }
+
+ virtual void get_risk (
+ matrix_type& w,
+ scalar_type& risk,
+ matrix_type& subgradient
+ ) const
+ {
+ using namespace impl;
+ subgradient.set_size(w.size(),1);
+ subgradient = 0;
+
+ // send out all the oracle requests
+ tsu_out temp_out;
+ for (unsigned long i = 0; i < out_pipes.size(); ++i)
+ {
+ temp_out.template get<oracle_request<matrix_type> >().current_solution = w;
+ temp_out.template get<oracle_request<matrix_type> >().saved_current_risk_gap = this->saved_current_risk_gap;
+ temp_out.template get<oracle_request<matrix_type> >().skip_cache = this->skip_cache;
+ temp_out.template get<oracle_request<matrix_type> >().converged = this->converged;
+ out_pipes[i]->enqueue(temp_out);
+ }
+
+ // collect all the oracle responses
+ long num = 0;
+ scalar_type total_loss = 0;
+ tsu_in temp_in;
+ unsigned long responses = 0;
+ while (responses < out_pipes.size())
+ {
+ in.dequeue(temp_in);
+ if (temp_in.template contains<oracle_response<matrix_type> >())
+ {
+ ++responses;
+ const oracle_response<matrix_type>& data = temp_in.template get<oracle_response<matrix_type> >();
+ subgradient += data.subgradient;
+ total_loss += data.loss;
+ num += data.num;
+ }
+ }
+
+ subgradient /= num;
+ total_loss /= num;
+ risk = total_loss + dot(subgradient,w);
+
+ if (this->nuclear_norm_regularizers.size() != 0)
+ {
+ matrix_type grad;
+ double obj;
+ this->compute_nuclear_norm_parts(w, grad, obj);
+ risk += obj;
+ subgradient += grad;
+ }
+ }
+
+ std::vector<network_address> nodes;
+
+ typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out;
+ typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in;
+
+ std::vector<std::shared_ptr<pipe<tsu_out> > > out_pipes;
+ mutable pipe<tsu_in> in;
+ std::vector<std::shared_ptr<bridge> > bridges;
+ long num_dims;
+ };
+
+ std::vector<network_address> nodes;
+ double eps;
+ unsigned long max_iterations;
+ double cache_based_eps;
+ bool verbose;
+ double C;
+ std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h b/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h
new file mode 100644
index 000000000..175a643c8
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h
@@ -0,0 +1,357 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_
+
+
+#include "structural_svm_problem_abstract.h"
+#include "../optimization/optimization_oca_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class svm_struct_processing_node : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for distributing the work involved in solving
+ a dlib::structural_svm_problem across many computers. It is used in
+ conjunction with the svm_struct_controller_node defined below.
+ !*/
+
+ public:
+
+ template <
+ typename T,
+ typename U
+ >
+ svm_struct_processing_node (
+ const structural_svm_problem<T,U>& problem,
+ unsigned short port,
+ unsigned short num_threads
+ );
+ /*!
+ requires
+ - port != 0
+ - problem.get_num_samples() != 0
+ - problem.get_num_dimensions() != 0
+ ensures
+ - This object will listen on the given port for a TCP connection from a
+ svm_struct_controller_node. Once connected, the controller node will
+ be able to access the given problem.
+ - Will use num_threads threads at a time to make concurrent calls to the
+ problem.separation_oracle() routine. You should set this parameter equal
+ to the number of available processing cores.
+ - Note that the following parameters within the given problem are ignored:
+ - problem.get_c()
+ - problem.get_epsilon()
+ - problem.get_cache_based_epsilon()
+ - problem.num_nuclear_norm_regularizers()
+ - weather the problem is verbose or not
+ Instead, they are defined by the svm_struct_controller_node. Note, however,
+ that the problem.get_max_cache_size() parameter is meaningful and controls
+ the size of the separation oracle cache within a svm_struct_processing_node.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class svm_struct_controller_node : noncopyable
+ {
+ /*!
+ INITIAL VALUE
+ - get_num_processing_nodes() == 0
+ - get_epsilon() == 0.001
+ - get_max_iterations() == 10000
+ - get_c() == 1
+ - This object will not be verbose
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for distributing the work involved in solving a
+ dlib::structural_svm_problem across many computers. The best way to understand
+ its use is via example:
+
+ First, suppose you have defined a structural_svm_problem object by inheriting from
+ it and defining the appropriate virtual functions. You could solve it by passing
+ an instance to the oca optimizer. However, if your separation oracle takes a long
+ time to evaluate then the optimization will take a long time to solve. To speed
+ this up we can distribute the calls to the separation oracle across many computers.
+
+ To make this concrete, lets imagine you want to distribute the work across three
+ computers. You can accomplish this by creating four programs. One containing a
+ svm_struct_controller_node and three containing svm_struct_processing_nodes.
+
+ The programs might look like this:
+
+ Controller program:
+ int main()
+ {
+ svm_struct_controller_node cont;
+ cont.set_c(100);
+ // Tell cont where the processing nodes are on your network.
+ cont.add_processing_node("192.168.1.10:12345");
+ cont.add_processing_node("192.168.1.11:12345");
+ cont.add_processing_node("192.168.1.12:12345");
+ matrix<double> w;
+ oca solver;
+ cont(solver, w); // Run the optimization.
+ // After this finishes w will contain the solution vector.
+ }
+
+ Processing programs (they are all the same, except that each loads a different subset
+ of the training data):
+ int main()
+ {
+ // Put one third of your data into this problem object. How you do this depends on your problem.
+ your_structural_svm_problem problem;
+ svm_struct_processing_node node(problem, 12345, number_of_cores_on_this_computer);
+ cout << "hit enter to terminate this program" << endl;
+ cin.get();
+ }
+
+ !*/
+
+ public:
+
+ svm_struct_controller_node (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to execute. Specifically, the algorithm stops when the average sample
+ risk (i.e. R(w) as defined by the dlib::structural_svm_problem object) is
+ within epsilon of its optimal value.
+
+ Also note that sample risk is an upper bound on a sample's loss. So
+ you can think of this epsilon value as saying "solve the optimization
+ problem until the average loss per sample is within epsilon of it's
+ optimal value".
+ !*/
+
+ double get_cache_based_epsilon (
+ ) const;
+ /*!
+ ensures
+ - if (get_max_cache_size() != 0) then
+ - The solver will not stop when the average sample risk is within
+ get_epsilon() of its optimal value. Instead, it will keep running
+ but will run the optimizer completely on the cache until the average
+ sample risk is within #get_cache_based_epsilon() of its optimal
+ value. This means that it will perform this additional refinement in
+ the solution accuracy without making any additional calls to the
+ separation_oracle(). This is useful when using a nuclear norm
+ regularization term because it allows you to quickly solve the
+ optimization problem to a high precision, which in the case of a
+ nuclear norm regularized problem means that many of the learned
+ matrices will be low rank or very close to low rank due to the
+ nuclear norm regularizer. This may not happen without solving the
+ problem to a high accuracy or their ranks may be difficult to
+ determine, so the extra accuracy given by the cache based refinement
+ is very useful. Finally, note that we include the nuclear norm term
+ as part of the "risk" for the purposes of determining when to stop.
+ - else
+ - The value of #get_cache_based_epsilon() has no effect.
+ !*/
+
+ void set_cache_based_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_cache_based_epsilon() == eps
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void add_nuclear_norm_regularizer (
+ long first_dimension,
+ long rows,
+ long cols,
+ double regularization_strength
+ );
+ /*!
+ requires
+ - 0 <= first_dimension < number of dimensions in problem
+ - 0 <= rows
+ - 0 <= cols
+ - first_dimension+rows*cols <= number of dimensions in problem
+ - 0 < regularization_strength
+ ensures
+ - Adds a nuclear norm regularization term to the optimization problem
+ solved by this object. That is, instead of solving:
+ Minimize: h(w) == 0.5*dot(w,w) + C*R(w)
+ this object will solve:
+ Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + regularization_strength*nuclear_norm_of(part of w)
+ where "part of w" is the part of w indicated by the arguments to this
+ function. In particular, the part of w included in the nuclear norm is
+ exactly the matrix reshape(rowm(w, range(first_dimension, first_dimension+rows*cols-1)), rows, cols).
+ Therefore, if you think of the w vector as being the concatenation of a
+ bunch of matrices then you can use multiple calls to add_nuclear_norm_regularizer()
+ to add nuclear norm regularization terms to any of the matrices packed into w.
+ - #num_nuclear_norm_regularizers() == num_nuclear_norm_regularizers() + 1
+ !*/
+
+ unsigned long num_nuclear_norm_regularizers (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nuclear norm regularizers that are currently a part
+ of this optimization problem. That is, returns the number of times
+ add_nuclear_norm_regularizer() has been called since the last call to
+ clear_nuclear_norm_regularizers() or object construction, whichever is
+ most recent.
+ !*/
+
+ void clear_nuclear_norm_regularizers (
+ );
+ /*!
+ ensures
+ - #num_nuclear_norm_regularizers() == 0
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet(
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization of the resulting classifier. Larger values encourage
+ exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ void add_processing_node (
+ const network_address& addr
+ );
+ /*!
+ requires
+ - addr.port != 0
+ ensures
+ - if (this address hasn't already been added) then
+ - #get_num_processing_nodes() == get_num_processing_nodes() + 1
+ - When operator() is invoked to solve the structural svm problem this
+ object will connect to the svm_struct_processing_node located at the
+ given network address and will include it in the distributed
+ optimization.
+ !*/
+
+ void add_processing_node (
+ const std::string& ip_or_hostname,
+ unsigned short port
+ );
+ /*!
+ requires
+ - port != 0
+ ensures
+ - invokes: add_processing_node(network_address(ip_or_hostname, port))
+ !*/
+
+ unsigned long get_num_processing_nodes (
+ ) const;
+ /*!
+ ensures
+ - returns the number of remote processing nodes that have been
+ registered with this object.
+ !*/
+
+ void remove_processing_nodes (
+ );
+ /*!
+ ensures
+ - #get_num_processing_nodes() == 0
+ !*/
+
+ class invalid_problem : public error {};
+
+ template <typename matrix_type>
+ double operator() (
+ const oca& solver,
+ matrix_type& w
+ ) const;
+ /*!
+ requires
+ - get_num_processing_nodes() != 0
+ - matrix_type == a dlib::matrix capable of storing column vectors
+ ensures
+ - connects to the processing nodes and begins optimizing the structural
+ svm problem using the given oca solver.
+ - stores the solution in #w
+ - returns the objective value at the solution #w
+ throws
+ - invalid_problem
+ This exception is thrown if the svm_struct_processing_nodes disagree
+ on the dimensionality of the problem. That is, if they disagree on
+ the value of structural_svm_problem::get_num_dimensions().
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h
new file mode 100644
index 000000000..c677861c9
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h
@@ -0,0 +1,542 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_
+#define DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_
+
+
+#include "structural_svm_graph_labeling_problem_abstract.h"
+#include "../graph_cuts.h"
+#include "../matrix.h"
+#include "../array.h"
+#include <vector>
+#include <iterator>
+#include "structural_svm_problem_threaded.h"
+#include "../graph.h"
+#include "sparse_vector.h"
+#include <sstream>
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ bool is_graph_labeling_problem (
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ std::string& reason_for_failure
+ )
+ {
+ typedef typename graph_type::type node_vector_type;
+ typedef typename graph_type::edge_type edge_vector_type;
+ // The graph must use all dense vectors or all sparse vectors. It can't mix the two types together.
+ COMPILE_TIME_ASSERT( (is_matrix<node_vector_type>::value && is_matrix<edge_vector_type>::value) ||
+ (!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value));
+
+
+ std::ostringstream sout;
+ reason_for_failure.clear();
+
+ if (!is_learning_problem(samples, labels))
+ {
+ reason_for_failure = "is_learning_problem(samples, labels) returned false.";
+ return false;
+ }
+
+ const bool ismat = is_matrix<typename graph_type::type>::value;
+
+ // these are -1 until assigned with a value
+ long node_dims = -1;
+ long edge_dims = -1;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].number_of_nodes() != labels[i].size())
+ {
+ sout << "samples["<<i<<"].number_of_nodes() doesn't match labels["<<i<<"].size().";
+ reason_for_failure = sout.str();
+ return false;
+ }
+ if (graph_contains_length_one_cycle(samples[i]))
+ {
+ sout << "graph_contains_length_one_cycle(samples["<<i<<"]) returned true.";
+ reason_for_failure = sout.str();
+ return false;
+ }
+
+ for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
+ {
+ if (ismat && samples[i].node(j).data.size() == 0)
+ {
+ sout << "A graph contains an empty vector at node: samples["<<i<<"].node("<<j<<").data.";
+ reason_for_failure = sout.str();
+ return false;
+ }
+
+ if (ismat && node_dims == -1)
+ node_dims = samples[i].node(j).data.size();
+ // all nodes must have vectors of the same size.
+ if (ismat && (long)samples[i].node(j).data.size() != node_dims)
+ {
+ sout << "Not all node vectors in samples["<<i<<"] are the same dimension.";
+ reason_for_failure = sout.str();
+ return false;
+ }
+
+ for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
+ {
+ if (ismat && samples[i].node(j).edge(n).size() == 0)
+ {
+ sout << "A graph contains an empty vector at edge: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
+ reason_for_failure = sout.str();
+ return false;
+ }
+ if (min(samples[i].node(j).edge(n)) < 0)
+ {
+ sout << "A graph contains negative values on an edge vector at: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
+ reason_for_failure = sout.str();
+ return false;
+ }
+
+ if (ismat && edge_dims == -1)
+ edge_dims = samples[i].node(j).edge(n).size();
+ // all edges must have vectors of the same size.
+ if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims)
+ {
+ sout << "Not all edge vectors in samples["<<i<<"] are the same dimension.";
+ reason_for_failure = sout.str();
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ template <
+ typename graph_type
+ >
+ bool is_graph_labeling_problem (
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels
+ )
+ {
+ std::string reason_for_failure;
+ return is_graph_labeling_problem(samples, labels, reason_for_failure);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ bool sizes_match (
+ const std::vector<std::vector<T> >& lhs,
+ const std::vector<std::vector<U> >& rhs
+ )
+ {
+ if (lhs.size() != rhs.size())
+ return false;
+
+ for (unsigned long i = 0; i < lhs.size(); ++i)
+ {
+ if (lhs[i].size() != rhs[i].size())
+ return false;
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool all_values_are_nonnegative (
+ const std::vector<std::vector<double> >& x
+ )
+ {
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ for (unsigned long j = 0; j < x[i].size(); ++j)
+ {
+ if (x[i][j] < 0)
+ return false;
+ }
+ }
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename T,
+ typename enable = void
+ >
+ struct fvect
+ {
+ // In this case type should be some sparse vector type
+ typedef typename T::type type;
+ };
+
+ template < typename T >
+ struct fvect<T, typename enable_if<is_matrix<typename T::type> >::type>
+ {
+ // The point of this stuff is to create the proper matrix
+ // type to represent the concatenation of an edge vector
+ // with an node vector.
+ typedef typename T::type node_mat;
+ typedef typename T::edge_type edge_mat;
+ const static long NRd = node_mat::NR;
+ const static long NRe = edge_mat::NR;
+ const static long NR = ((NRd!=0) && (NRe!=0)) ? (NRd+NRe) : 0;
+ typedef typename node_mat::value_type value_type;
+
+ typedef matrix<value_type,NR,1, typename node_mat::mem_manager_type, typename node_mat::layout_type> type;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ class structural_svm_graph_labeling_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>,
+ typename dlib::impl::fvect<graph_type>::type >
+ {
+ public:
+ typedef matrix<double,0,1> matrix_type;
+ typedef typename dlib::impl::fvect<graph_type>::type feature_vector_type;
+
+ typedef graph_type sample_type;
+
+ typedef std::vector<bool> label_type;
+
+ structural_svm_graph_labeling_problem(
+ const dlib::array<sample_type>& samples_,
+ const std::vector<label_type>& labels_,
+ const std::vector<std::vector<double> >& losses_,
+ unsigned long num_threads = 2
+ ) :
+ structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads),
+ samples(samples_),
+ labels(labels_),
+ losses(losses_)
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ std::string reason_for_failure;
+ DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
+ "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t reason_for_failure: " << reason_for_failure
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t this: " << this );
+ DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
+ all_values_are_nonnegative(losses) == true,
+ "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t losses.size(): " << losses.size()
+ << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
+ << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
+ << "\n\t this: " << this );
+#endif
+
+ loss_pos = 1.0;
+ loss_neg = 1.0;
+
+ // figure out how many dimensions are in the node and edge vectors.
+ node_dims = 0;
+ edge_dims = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
+ {
+ node_dims = std::max(node_dims,(long)max_index_plus_one(samples[i].node(j).data));
+ for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
+ {
+ edge_dims = std::max(edge_dims, (long)max_index_plus_one(samples[i].node(j).edge(n)));
+ }
+ }
+ }
+ }
+
+ const std::vector<std::vector<double> >& get_losses (
+ ) const { return losses; }
+
+ long get_num_edge_weights (
+ ) const
+ {
+ return edge_dims;
+ }
+
+ void set_loss_on_positive_class (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0 && get_losses().size() == 0,
+ "\t void structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this );
+
+ loss_pos = loss;
+ }
+
+ void set_loss_on_negative_class (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss >= 0 && get_losses().size() == 0,
+ "\t void structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this );
+
+ loss_neg = loss;
+ }
+
+ double get_loss_on_negative_class (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_losses().size() == 0,
+ "\t double structural_svm_graph_labeling_problem::get_loss_on_negative_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t this: " << this );
+
+ return loss_neg;
+ }
+
+ double get_loss_on_positive_class (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(get_losses().size() == 0,
+ "\t double structural_svm_graph_labeling_problem::get_loss_on_positive_class()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t this: " << this );
+
+ return loss_pos;
+ }
+
+
+ private:
+ virtual long get_num_dimensions (
+ ) const
+ {
+ // The psi/w vector will begin with all the edge dims and then follow with the node dims.
+ return edge_dims + node_dims;
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return samples.size();
+ }
+
+ template <typename psi_type>
+ typename enable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
+ const sample_type& sample,
+ const label_type& label,
+ psi_type& psi
+ ) const
+ {
+ psi.set_size(get_num_dimensions());
+ psi = 0;
+ for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
+ {
+ // accumulate the node vectors
+ if (label[i] == true)
+ set_rowm(psi, range(edge_dims, psi.size()-1)) += sample.node(i).data;
+
+ for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long j = sample.node(i).neighbor(n).index();
+
+ // Don't double count edges. Also only include the vector if
+ // the labels disagree.
+ if (i < j && label[i] != label[j])
+ {
+ set_rowm(psi, range(0, edge_dims-1)) -= sample.node(i).edge(n);
+ }
+ }
+ }
+ }
+
+ template <typename T>
+ void add_to_sparse_vect (
+ T& psi,
+ const T& vect,
+ unsigned long offset
+ ) const
+ {
+ for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
+ {
+ psi.insert(psi.end(), std::make_pair(i->first+offset, i->second));
+ }
+ }
+
+ template <typename T>
+ void subtract_from_sparse_vect (
+ T& psi,
+ const T& vect
+ ) const
+ {
+ for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
+ {
+ psi.insert(psi.end(), std::make_pair(i->first, -i->second));
+ }
+ }
+
+ template <typename psi_type>
+ typename disable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
+ const sample_type& sample,
+ const label_type& label,
+ psi_type& psi
+ ) const
+ {
+ psi.clear();
+ for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
+ {
+ // accumulate the node vectors
+ if (label[i] == true)
+ add_to_sparse_vect(psi, sample.node(i).data, edge_dims);
+
+ for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long j = sample.node(i).neighbor(n).index();
+
+ // Don't double count edges. Also only include the vector if
+ // the labels disagree.
+ if (i < j && label[i] != label[j])
+ {
+ subtract_from_sparse_vect(psi, sample.node(i).edge(n));
+ }
+ }
+ }
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ get_joint_feature_vector(samples[idx], labels[idx], psi);
+ }
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ double& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ const sample_type& samp = samples[idx];
+
+ // setup the potts graph based on samples[idx] and current_solution.
+ graph<double,double>::kernel_1a g;
+ copy_graph_structure(samp, g);
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ g.node(i).data = dot(rowm(current_solution,range(edge_dims,current_solution.size()-1)),
+ samp.node(i).data);
+
+ // Include a loss augmentation so that we will get the proper loss augmented
+ // max when we use find_max_factor_graph_potts() below.
+ if (labels[idx][i])
+ g.node(i).data -= get_loss_for_sample(idx,i,!labels[idx][i]);
+ else
+ g.node(i).data += get_loss_for_sample(idx,i,!labels[idx][i]);
+
+ for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long j = g.node(i).neighbor(n).index();
+ // Don't compute an edge weight more than once.
+ if (i < j)
+ {
+ g.node(i).edge(n) = dot(rowm(current_solution,range(0,edge_dims-1)),
+ samp.node(i).edge(n));
+ }
+ }
+
+ }
+
+ std::vector<node_label> labeling;
+ find_max_factor_graph_potts(g, labeling);
+
+
+ std::vector<bool> bool_labeling;
+ bool_labeling.reserve(labeling.size());
+ // figure out the loss
+ loss = 0;
+ for (unsigned long i = 0; i < labeling.size(); ++i)
+ {
+ const bool predicted_label = (labeling[i]!= 0);
+ bool_labeling.push_back(predicted_label);
+ loss += get_loss_for_sample(idx, i, predicted_label);
+ }
+
+ // compute psi
+ get_joint_feature_vector(samp, bool_labeling, psi);
+ }
+
+ double get_loss_for_sample (
+ long sample_idx,
+ long node_idx,
+ bool predicted_label
+ ) const
+ /*!
+ requires
+ - 0 <= sample_idx < labels.size()
+ - 0 <= node_idx < labels[sample_idx].size()
+ ensures
+ - returns the loss incurred for predicting that the node
+ samples[sample_idx].node(node_idx) has a label of predicted_label.
+ !*/
+ {
+ const bool true_label = labels[sample_idx][node_idx];
+ if (true_label != predicted_label)
+ {
+ if (losses.size() != 0)
+ return losses[sample_idx][node_idx];
+ else if (true_label == true)
+ return loss_pos;
+ else
+ return loss_neg;
+ }
+ else
+ {
+ // no loss for making the correct prediction.
+ return 0;
+ }
+ }
+
+ const dlib::array<sample_type>& samples;
+ const std::vector<label_type>& labels;
+ const std::vector<std::vector<double> >& losses;
+
+ long node_dims;
+ long edge_dims;
+ double loss_pos;
+ double loss_neg;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h
new file mode 100644
index 000000000..ab99ed8f4
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h
@@ -0,0 +1,249 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_
+
+#include "../array/array_kernel_abstract.h"
+#include "../graph/graph_kernel_abstract.h"
+#include "../matrix/matrix_abstract.h"
+#include "sparse_vector_abstract.h"
+#include "structural_svm_problem_threaded_abstract.h"
+#include <vector>
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ bool is_graph_labeling_problem (
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels
+ );
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::type and graph_type::edge_type are either both dlib::matrix types
+ capable of containing column vectors or both some kind of sparse vector type
+ as defined in dlib/svm/sparse_vector_abstract.h.
+ ensures
+ - Note that a graph labeling problem is a task to learn a binary classifier which
+ predicts the correct label for each node in the provided graphs. Additionally,
+ we have information in the form of edges between nodes where edges are present
+ when we believe the linked nodes are likely to have the same label. Therefore,
+ part of a graph labeling problem is to learn to score each edge in terms of how
+ strongly the edge should enforce labeling consistency between its two nodes.
+ Thus, to be a valid graph labeling problem, samples should contain example graphs
+ of connected nodes while labels should indicate the desired label of each node.
+ The precise requirements for a valid graph labeling problem are listed below.
+ - This function returns true if all of the following are true and false otherwise:
+ - is_learning_problem(samples, labels) == true
+ - All the vectors stored on the edges of each graph in samples
+ contain only values which are >= 0.
+ - for all valid i:
+ - graph_contains_length_one_cycle(samples[i]) == false
+ - samples[i].number_of_nodes() == labels[i].size()
+ (i.e. Every graph node gets its own label)
+ - if (graph_type::edge_type is a dlib::matrix) then
+ - All the nodes must contain vectors with the same number of dimensions.
+ - All the edges must contain vectors with the same number of dimensions.
+ (However, edge vectors may differ in dimension from node vectors.)
+ - All vectors have non-zero size. That is, they have more than 0 dimensions.
+ !*/
+
+ template <
+ typename graph_type
+ >
+ bool is_graph_labeling_problem (
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels,
+ std::string& reason_for_failure
+ );
+ /*!
+ This function is identical to the above version of is_graph_labeling_problem()
+ except that if it returns false it will populate reason_for_failure with a message
+ describing why the graph is not a valid learning problem.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ bool sizes_match (
+ const std::vector<std::vector<T> >& lhs,
+ const std::vector<std::vector<U> >& rhs
+ );
+ /*!
+ ensures
+ - returns true if the sizes of lhs and rhs, as well as their constituent vectors
+ all match. In particular, we return true if all of the following conditions are
+ met and false otherwise:
+ - lhs.size() == rhs.size()
+ - for all valid i:
+ - lhs[i].size() == rhs[i].size()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool all_values_are_nonnegative (
+ const std::vector<std::vector<double> >& x
+ );
+ /*!
+ ensures
+ - returns true if all the double values contained in x are >= 0.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ class structural_svm_graph_labeling_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>,
+ typename graph_type::type >
+ {
+ /*!
+ REQUIREMENTS ON graph_type
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::type and graph_type::edge_type must be either matrix objects
+ capable of representing column vectors or some kind of sparse vector
+ type as defined in dlib/svm/sparse_vector_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning the weight vectors needed to use
+ a graph_labeler object. It learns the parameter vectors by formulating
+ the problem as a structural SVM problem.
+ !*/
+
+ public:
+ typedef matrix<double,0,1> matrix_type;
+ typedef typename graph_type::type feature_vector_type;
+ typedef graph_type sample_type;
+ typedef std::vector<bool> label_type;
+
+ structural_svm_graph_labeling_problem(
+ const dlib::array<sample_type>& samples,
+ const std::vector<label_type>& labels,
+ const std::vector<std::vector<double> >& losses,
+ unsigned long num_threads
+ );
+ /*!
+ requires
+ - is_graph_labeling_problem(samples,labels) == true
+ - if (losses.size() != 0) then
+ - sizes_match(labels, losses) == true
+ - all_values_are_nonnegative(losses) == true
+ ensures
+ - This object attempts to learn a mapping from the given samples to the
+ given labels. In particular, it attempts to learn to predict labels[i]
+ based on samples[i]. Or in other words, this object can be used to learn
+ parameter vectors, E and W, such that a graph_labeler declared as:
+ graph_labeler<feature_vector_type> labeler(E,W)
+ results in a labeler object which attempts to compute the following mapping:
+ labels[i] == labeler(samples[i])
+ - When you use this object with the oca optimizer you get back just one
+ big parameter vector as the solution. Therefore, note that this single
+ big vector is the concatenation of E and W. The first get_num_edge_weights()
+ elements of this vector correspond to E and the rest is W.
+ - This object will use num_threads threads during the optimization
+ procedure. You should set this parameter equal to the number of
+ available processing cores on your machine.
+ - if (losses.size() == 0) then
+ - #get_loss_on_positive_class() == 1.0
+ - #get_loss_on_negative_class() == 1.0
+ - #get_losses().size() == 0
+ - The losses argument is effectively ignored if its size is zero.
+ - else
+ - #get_losses() == losses
+ - Each node in the training data has its own loss value defined by
+ the corresponding entry of losses. In particular, this means that
+ the node with label labels[i][j] incurs a loss of losses[i][j] if
+ it is incorrectly labeled.
+ - The get_loss_on_positive_class() and get_loss_on_negative_class()
+ parameters are ignored. Only get_losses() is used in this case.
+ !*/
+
+ const std::vector<std::vector<double> >& get_losses (
+ ) const;
+ /*!
+ ensures
+ - returns the losses vector given to this object's constructor.
+ This vector defines the per sample loss values used. If the vector
+ is empty then the loss values defined by get_loss_on_positive_class() and
+ get_loss_on_positive_class() are used instead.
+ !*/
+
+ long get_num_edge_weights (
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the edge weight vector. It is also
+ important to know that when using the oca solver with this object,
+ you must set it to generate non-negative weights for the edge weight
+ part of the total weight vector. You can do this by passing get_num_edge_weights()
+ to the third argument to oca::operator().
+ !*/
+
+ void set_loss_on_positive_class (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ - get_losses().size() == 0
+ ensures
+ - #get_loss_on_positive_class() == loss
+ !*/
+
+ void set_loss_on_negative_class (
+ double loss
+ );
+ /*!
+ requires
+ - loss >= 0
+ - get_losses().size() == 0
+ ensures
+ - #get_loss_on_negative_class() == loss
+ !*/
+
+ double get_loss_on_positive_class (
+ ) const;
+ /*!
+ requires
+ - get_losses().size() == 0
+ ensures
+ - returns the loss incurred when a graph node which is supposed to have
+ a label of true gets misclassified. This value controls how much we care
+ about correctly classifying nodes which should be labeled as true. Larger
+ loss values indicate that we care more strongly than smaller values.
+ !*/
+
+ double get_loss_on_negative_class (
+ ) const;
+ /*!
+ requires
+ - get_losses().size() == 0
+ ensures
+ - returns the loss incurred when a graph node which is supposed to have
+ a label of false gets misclassified. This value controls how much we care
+ about correctly classifying nodes which should be labeled as false. Larger
+ loss values indicate that we care more strongly than smaller values.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_
+
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h b/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h
new file mode 100644
index 000000000..1c54a42b1
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h
@@ -0,0 +1,531 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_
+#define DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_
+
+#include "structural_svm_object_detection_problem_abstract.h"
+#include "../matrix.h"
+#include "structural_svm_problem_threaded.h"
+#include <sstream>
+#include "../string.h"
+#include "../array.h"
+#include "../image_processing/full_object_detection.h"
+#include "../image_processing/box_overlap_testing.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type,
+ typename image_array_type
+ >
+ class structural_svm_object_detection_problem : public structural_svm_problem_threaded<matrix<double,0,1> >,
+ noncopyable
+ {
+ public:
+
+ structural_svm_object_detection_problem(
+ const image_scanner_type& scanner,
+ const test_box_overlap& overlap_tester,
+ const bool auto_overlap_tester,
+ const image_array_type& images_,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections_,
+ const std::vector<std::vector<rectangle> >& ignore_,
+ const test_box_overlap& ignore_overlap_tester_,
+ unsigned long num_threads = 2
+ ) :
+ structural_svm_problem_threaded<matrix<double,0,1> >(num_threads),
+ boxes_overlap(overlap_tester),
+ images(images_),
+ truth_object_detections(truth_object_detections_),
+ ignore(ignore_),
+ ignore_overlap_tester(ignore_overlap_tester_),
+ match_eps(0.5),
+ loss_per_false_alarm(1),
+ loss_per_missed_target(1)
+ {
+#ifdef ENABLE_ASSERTS
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(images_, truth_object_detections_) &&
+ ignore_.size() == images_.size() &&
+ scanner.get_num_detection_templates() > 0,
+ "\t structural_svm_object_detection_problem::structural_svm_object_detection_problem()"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t scanner.get_num_detection_templates(): " << scanner.get_num_detection_templates()
+ << "\n\t is_learning_problem(images_,truth_object_detections_): " << is_learning_problem(images_,truth_object_detections_)
+ << "\n\t ignore.size(): " << ignore.size()
+ << "\n\t images.size(): " << images.size()
+ << "\n\t this: " << this
+ );
+ for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
+ {
+ for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
+ {
+ DLIB_ASSERT(truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template(),
+ "\t trained_function_type structural_object_detection_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
+ truth_object_detections[i][j].num_parts()
+ << "\n\t scanner.get_num_movable_components_per_detection_template(): " <<
+ scanner.get_num_movable_components_per_detection_template()
+ << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
+ );
+ }
+ }
+#endif
+ // The purpose of the max_num_dets member variable is to give us a reasonable
+ // upper limit on the number of detections we can expect from a single image.
+ // This is used in the separation_oracle to put a hard limit on the number of
+ // detections we will consider. We do this purely for computational reasons
+ // since otherwise we can end up wasting large amounts of time on certain
+ // pathological cases during optimization which ultimately do not influence the
+ // result. Therefore, we force the separation oracle to only consider the
+ // max_num_dets strongest detections.
+ max_num_dets = 0;
+ for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
+ {
+ if (truth_object_detections[i].size() > max_num_dets)
+ max_num_dets = truth_object_detections[i].size();
+ }
+ max_num_dets = max_num_dets*3 + 10;
+
+ initialize_scanners(scanner, num_threads);
+
+ if (auto_overlap_tester)
+ {
+ auto_configure_overlap_tester();
+ }
+ }
+
+ test_box_overlap get_overlap_tester (
+ ) const
+ {
+ return boxes_overlap;
+ }
+
+ void set_match_eps (
+ double eps
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < eps && eps < 1,
+ "\t void structural_svm_object_detection_problem::set_match_eps(eps)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t eps: " << eps
+ << "\n\t this: " << this
+ );
+
+ match_eps = eps;
+ }
+
+ double get_match_eps (
+ ) const
+ {
+ return match_eps;
+ }
+
+ double get_loss_per_missed_target (
+ ) const
+ {
+ return loss_per_missed_target;
+ }
+
+ void set_loss_per_missed_target (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_svm_object_detection_problem::set_loss_per_missed_target(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_missed_target = loss;
+ }
+
+ double get_loss_per_false_alarm (
+ ) const
+ {
+ return loss_per_false_alarm;
+ }
+
+ void set_loss_per_false_alarm (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_svm_object_detection_problem::set_loss_per_false_alarm(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_alarm = loss;
+ }
+
+ private:
+
+ void auto_configure_overlap_tester(
+ )
+ {
+ std::vector<std::vector<rectangle> > mapped_rects(truth_object_detections.size());
+ for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
+ {
+ mapped_rects[i].resize(truth_object_detections[i].size());
+ for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
+ {
+ mapped_rects[i][j] = scanners[i].get_best_matching_rect(truth_object_detections[i][j].get_rect());
+ }
+ }
+
+ boxes_overlap = find_tight_overlap_tester(mapped_rects);
+ }
+
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return scanners[0].get_num_dimensions() +
+ 1;// for threshold
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return images.size();
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ const image_scanner_type& scanner = scanners[idx];
+
+ psi.set_size(get_num_dimensions());
+ std::vector<rectangle> mapped_rects;
+
+ psi = 0;
+ for (unsigned long i = 0; i < truth_object_detections[idx].size(); ++i)
+ {
+ mapped_rects.push_back(scanner.get_best_matching_rect(truth_object_detections[idx][i].get_rect()));
+ scanner.get_feature_vector(truth_object_detections[idx][i], psi);
+ }
+ psi(scanner.get_num_dimensions()) = -1.0*truth_object_detections[idx].size();
+
+ // check if any of the boxes overlap. If they do then it is impossible for
+ // us to learn to correctly classify this sample
+ for (unsigned long i = 0; i < mapped_rects.size(); ++i)
+ {
+ for (unsigned long j = i+1; j < mapped_rects.size(); ++j)
+ {
+ if (boxes_overlap(mapped_rects[i], mapped_rects[j]))
+ {
+ const double area_overlap = mapped_rects[i].intersect(mapped_rects[j]).area();
+ const double match_amount = area_overlap/(double)( mapped_rects[i]+mapped_rects[j]).area();
+ const double overlap_amount = area_overlap/std::min(mapped_rects[i].area(),mapped_rects[j].area());
+
+ using namespace std;
+ ostringstream sout;
+ sout << "An impossible set of object labels was detected. This is happening because ";
+ sout << "the truth labels for an image contain rectangles which overlap according to the ";
+ sout << "test_box_overlap object supplied for non-max suppression. To resolve this, you ";
+ sout << "either need to relax the test_box_overlap object so it doesn't mark these rectangles as ";
+ sout << "overlapping or adjust the truth rectangles in your training dataset. ";
+
+ // make sure the above string fits nicely into a command prompt window.
+ string temp = sout.str();
+ sout.str(""); sout << wrap_string(temp,0,0) << endl << endl;
+
+
+ sout << "image index: "<< idx << endl;
+ sout << "The offending rectangles are:\n";
+ sout << "rect1: "<< mapped_rects[i] << endl;
+ sout << "rect2: "<< mapped_rects[j] << endl;
+ sout << "match amount: " << match_amount << endl;
+ sout << "overlap amount: " << overlap_amount << endl;
+ throw dlib::impossible_labeling_error(sout.str());
+ }
+ }
+ }
+
+ // make sure the mapped rectangles are within match_eps of the
+ // truth rectangles.
+ for (unsigned long i = 0; i < mapped_rects.size(); ++i)
+ {
+ const double area = (truth_object_detections[idx][i].get_rect().intersect(mapped_rects[i])).area();
+ const double total_area = (truth_object_detections[idx][i].get_rect() + mapped_rects[i]).area();
+ if (area/total_area <= match_eps)
+ {
+ using namespace std;
+ ostringstream sout;
+ sout << "An impossible set of object labels was detected. This is happening because ";
+ sout << "none of the object locations checked by the supplied image scanner is a close ";
+ sout << "enough match to one of the truth boxes in your training dataset. To resolve this ";
+ sout << "you need to either lower the match_eps, adjust the settings of the image scanner ";
+ sout << "so that it is capable of hitting this truth box, or adjust the offending truth rectangle so it ";
+ sout << "can be matched by the current image scanner. Also, if you ";
+ sout << "are using the scan_fhog_pyramid object then you could try using a finer image pyramid. ";
+ sout << "Additionally, the scan_fhog_pyramid scans a fixed aspect ratio box across the image when it ";
+ sout << "searches for objects. So if you are getting this error and you are using the scan_fhog_pyramid, ";
+ sout << "it's very likely the problem is that your training dataset contains truth rectangles of widely ";
+ sout << "varying aspect ratios. The solution is to make sure your training boxes all have about the same aspect ratio. ";
+
+
+ // make sure the above string fits nicely into a command prompt window.
+ string temp = sout.str();
+ sout.str(""); sout << wrap_string(temp,0,0) << endl << endl;
+
+ sout << "image index "<< idx << endl;
+ sout << "match_eps: "<< match_eps << endl;
+ sout << "best possible match: "<< area/total_area << endl;
+ sout << "truth rect: "<< truth_object_detections[idx][i].get_rect() << endl;
+ sout << "truth rect width/height: "<< truth_object_detections[idx][i].get_rect().width()/(double)truth_object_detections[idx][i].get_rect().height() << endl;
+ sout << "truth rect area: "<< truth_object_detections[idx][i].get_rect().area() << endl;
+ sout << "nearest detection template rect: "<< mapped_rects[i] << endl;
+ sout << "nearest detection template rect width/height: "<< mapped_rects[i].width()/(double)mapped_rects[i].height() << endl;
+ sout << "nearest detection template rect area: "<< mapped_rects[i].area() << endl;
+ throw dlib::impossible_labeling_error(sout.str());
+ }
+
+ }
+ }
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ const image_scanner_type& scanner = scanners[idx];
+
+ std::vector<std::pair<double, rectangle> > dets;
+ const double thresh = current_solution(scanner.get_num_dimensions());
+
+
+ scanner.detect(current_solution, dets, thresh-loss_per_false_alarm);
+
+
+ // The loss will measure the number of incorrect detections. A detection is
+ // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection
+ // on a truth rectangle.
+ loss = truth_object_detections[idx].size()*loss_per_missed_target;
+
+ // Measure the loss augmented score for the detections which hit a truth rect.
+ std::vector<double> truth_score_hits(truth_object_detections[idx].size(), 0);
+
+ // keep track of which truth boxes we have hit so far.
+ std::vector<bool> hit_truth_table(truth_object_detections[idx].size(), false);
+
+ std::vector<rectangle> final_dets;
+ // The point of this loop is to fill out the truth_score_hits array.
+ for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i)
+ {
+ if (overlaps_any_box(boxes_overlap, final_dets, dets[i].second))
+ continue;
+
+ const std::pair<double,unsigned int> truth = find_best_match(truth_object_detections[idx], dets[i].second);
+
+ final_dets.push_back(dets[i].second);
+
+ const double truth_match = truth.first;
+ // if hit truth rect
+ if (truth_match > match_eps)
+ {
+ // if this is the first time we have seen a detect which hit truth_object_detections[idx][truth.second]
+ const double score = dets[i].first - thresh;
+ if (hit_truth_table[truth.second] == false)
+ {
+ hit_truth_table[truth.second] = true;
+ truth_score_hits[truth.second] += score;
+ }
+ else
+ {
+ truth_score_hits[truth.second] += score + loss_per_false_alarm;
+ }
+ }
+ }
+
+ hit_truth_table.assign(hit_truth_table.size(), false);
+
+ final_dets.clear();
+#ifdef ENABLE_ASSERTS
+ double total_score = 0;
+#endif
+ // Now figure out which detections jointly maximize the loss and detection score sum. We
+ // need to take into account the fact that allowing a true detection in the output, while
+ // initially reducing the loss, may allow us to increase the loss later with many duplicate
+ // detections.
+ for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i)
+ {
+ if (overlaps_any_box(boxes_overlap, final_dets, dets[i].second))
+ continue;
+
+ const std::pair<double,unsigned int> truth = find_best_match(truth_object_detections[idx], dets[i].second);
+
+ const double truth_match = truth.first;
+ if (truth_match > match_eps)
+ {
+ if (truth_score_hits[truth.second] > loss_per_missed_target)
+ {
+ if (!hit_truth_table[truth.second])
+ {
+ hit_truth_table[truth.second] = true;
+ final_dets.push_back(dets[i].second);
+#ifdef ENABLE_ASSERTS
+ total_score += dets[i].first;
+#endif
+ loss -= loss_per_missed_target;
+ }
+ else
+ {
+ final_dets.push_back(dets[i].second);
+#ifdef ENABLE_ASSERTS
+ total_score += dets[i].first;
+#endif
+ loss += loss_per_false_alarm;
+ }
+ }
+ }
+ else if (!overlaps_ignore_box(idx,dets[i].second))
+ {
+ // didn't hit anything
+ final_dets.push_back(dets[i].second);
+#ifdef ENABLE_ASSERTS
+ total_score += dets[i].first;
+#endif
+ loss += loss_per_false_alarm;
+ }
+ }
+
+ psi.set_size(get_num_dimensions());
+ psi = 0;
+ for (unsigned long i = 0; i < final_dets.size(); ++i)
+ scanner.get_feature_vector(scanner.get_full_object_detection(final_dets[i], current_solution), psi);
+
+#ifdef ENABLE_ASSERTS
+ const double psi_score = dot(psi, current_solution);
+ DLIB_CASSERT(std::abs(psi_score-total_score) <= 1e-4 * std::max(1.0,std::max(std::abs(psi_score),std::abs(total_score))),
+ "\t The get_feature_vector() and detect() methods of image_scanner_type are not in sync."
+ << "\n\t The relative error is too large to be attributed to rounding error."
+ << "\n\t error: " << std::abs(psi_score-total_score)
+ << "\n\t psi_score: " << psi_score
+ << "\n\t total_score: " << total_score
+ );
+#endif
+
+ psi(scanner.get_num_dimensions()) = -1.0*final_dets.size();
+ }
+
+
+ bool overlaps_ignore_box (
+ const long idx,
+ const dlib::rectangle& rect
+ ) const
+ {
+ for (unsigned long i = 0; i < ignore[idx].size(); ++i)
+ {
+ if (ignore_overlap_tester(ignore[idx][i], rect))
+ return true;
+ }
+ return false;
+ }
+
+ std::pair<double,unsigned int> find_best_match(
+ const std::vector<full_object_detection>& boxes,
+ const rectangle rect
+ ) const
+ /*!
+ ensures
+ - determines which rectangle in boxes matches rect the most and
+ returns the amount of this match. Specifically, the match is
+ a number O with the following properties:
+ - 0 <= O <= 1
+ - Let R be the maximum matching rectangle in boxes, then
+ O == (R.intersect(rect)).area() / (R + rect).area()
+ - O == 0 if there is no match with any rectangle.
+ !*/
+ {
+ double match = 0;
+ unsigned int best_idx = 0;
+ for (unsigned long i = 0; i < boxes.size(); ++i)
+ {
+
+ const unsigned long area = rect.intersect(boxes[i].get_rect()).area();
+ if (area != 0)
+ {
+ const double new_match = area / static_cast<double>((rect + boxes[i].get_rect()).area());
+ if (new_match > match)
+ {
+ match = new_match;
+ best_idx = i;
+ }
+ }
+ }
+
+ return std::make_pair(match,best_idx);
+ }
+
+ struct init_scanners_helper
+ {
+ init_scanners_helper (
+ array<image_scanner_type>& scanners_,
+ const image_array_type& images_
+ ) :
+ scanners(scanners_),
+ images(images_)
+ {}
+
+ array<image_scanner_type>& scanners;
+ const image_array_type& images;
+
+ void operator() (long i ) const
+ {
+ scanners[i].load(images[i]);
+ }
+ };
+
+ void initialize_scanners (
+ const image_scanner_type& scanner,
+ unsigned long num_threads
+ )
+ {
+ scanners.set_max_size(images.size());
+ scanners.set_size(images.size());
+
+ for (unsigned long i = 0; i < scanners.size(); ++i)
+ scanners[i].copy_configuration(scanner);
+
+ // now load the images into all the scanners
+ parallel_for(num_threads, 0, scanners.size(), init_scanners_helper(scanners, images));
+ }
+
+
+ test_box_overlap boxes_overlap;
+
+ mutable array<image_scanner_type> scanners;
+
+ const image_array_type& images;
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections;
+ const std::vector<std::vector<rectangle> >& ignore;
+ const test_box_overlap ignore_overlap_tester;
+
+ unsigned long max_num_dets;
+ double match_eps;
+ double loss_per_false_alarm;
+ double loss_per_missed_target;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h
new file mode 100644
index 000000000..d73c5920d
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h
@@ -0,0 +1,178 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_
+
+#include "../matrix.h"
+#include "structural_svm_problem_threaded_abstract.h"
+#include <sstream>
+#include "../image_processing/full_object_detection_abstract.h"
+#include "../image_processing/box_overlap_testing.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_scanner_type,
+ typename image_array_type
+ >
+ class structural_svm_object_detection_problem : public structural_svm_problem_threaded<matrix<double,0,1> >,
+ noncopyable
+ {
+ /*!
+ REQUIREMENTS ON image_scanner_type
+ image_scanner_type must be an implementation of
+ dlib/image_processing/scan_fhog_pyramid_abstract.h or
+ dlib/image_processing/scan_image_custom_abstract.h or
+ dlib/image_processing/scan_image_pyramid_abstract.h or
+ dlib/image_processing/scan_image_boxes_abstract.h
+
+ REQUIREMENTS ON image_array_type
+ image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
+ and it must contain objects which can be accepted by image_scanner_type::load().
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning the parameter vector needed to use a
+ scan_image_pyramid, scan_fhog_pyramid, scan_image_custom, or
+ scan_image_boxes object.
+
+ It learns the parameter vector by formulating the problem as a structural
+ SVM problem. The exact details of the method are described in the paper
+ Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046).
+
+
+ !*/
+
+ public:
+
+ structural_svm_object_detection_problem(
+ const image_scanner_type& scanner,
+ const test_box_overlap& overlap_tester,
+ const bool auto_overlap_tester,
+ const image_array_type& images,
+ const std::vector<std::vector<full_object_detection> >& truth_object_detections,
+ const std::vector<std::vector<rectangle> >& ignore,
+ const test_box_overlap& ignore_overlap_tester,
+ unsigned long num_threads = 2
+ );
+ /*!
+ requires
+ - is_learning_problem(images, truth_object_detections)
+ - ignore.size() == images.size()
+ - scanner.get_num_detection_templates() > 0
+ - scanner.load(images[0]) must be a valid expression.
+ - for all valid i, j:
+ - truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template()
+ - all_parts_in_rect(truth_object_detections[i][j]) == true
+ ensures
+ - This object attempts to learn a mapping from the given images to the
+ object locations given in truth_object_detections. In particular, it
+ attempts to learn to predict truth_object_detections[i] based on
+ images[i]. Or in other words, this object can be used to learn a
+ parameter vector, w, such that an object_detector declared as:
+ object_detector<image_scanner_type> detector(scanner,get_overlap_tester(),w)
+ results in a detector object which attempts to compute the locations of
+ all the objects in truth_object_detections. So if you called
+ detector(images[i]) you would hopefully get a list of rectangles back
+ that had truth_object_detections[i].size() elements and contained exactly
+ the rectangles indicated by truth_object_detections[i].
+ - if (auto_overlap_tester == true) then
+ - #get_overlap_tester() == a test_box_overlap object that is configured
+ using the find_tight_overlap_tester() routine and the contents of
+ truth_object_detections.
+ - else
+ - #get_overlap_tester() == overlap_tester
+ - #get_match_eps() == 0.5
+ - This object will use num_threads threads during the optimization
+ procedure. You should set this parameter equal to the number of
+ available processing cores on your machine.
+ - #get_loss_per_missed_target() == 1
+ - #get_loss_per_false_alarm() == 1
+ - for all valid i:
+ - Within images[i] any detections that match against a rectangle in
+ ignore[i], according to ignore_overlap_tester, are ignored. That is,
+ the optimizer doesn't care if the detector outputs a detection that
+ matches any of the ignore rectangles or if it fails to output a
+ detection for an ignore rectangle. Therefore, if there are objects
+ in your dataset that you are unsure you want to detect or otherwise
+ don't care if the detector gets or doesn't then you can mark them
+ with ignore rectangles and the optimizer will simply ignore them.
+ !*/
+
+ test_box_overlap get_overlap_tester (
+ ) const;
+ /*!
+ ensures
+ - returns the overlap tester used by this object.
+ !*/
+
+ void set_match_eps (
+ double eps
+ );
+ /*!
+ requires
+ - 0 < eps < 1
+ ensures
+ - #get_match_eps() == eps
+ !*/
+
+ double get_match_eps (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of alignment necessary for a detection to be considered
+ as matching with a ground truth rectangle. The precise formula for determining
+ if two rectangles match each other is the following, rectangles A and B match
+ if and only if:
+ A.intersect(B).area()/(A+B).area() > get_match_eps()
+ !*/
+
+ double get_loss_per_missed_target (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for failing to detect one of the
+ targets.
+ !*/
+
+ void set_loss_per_missed_target (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_missed_target() == loss
+ !*/
+
+ double get_loss_per_false_alarm (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for emitting a false alarm detection.
+ Or in other words, the loss for generating a detection that doesn't correspond
+ to one of the truth rectangles.
+ !*/
+
+ void set_loss_per_false_alarm (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_false_alarm() == loss
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_problem.h b/ml/dlib/dlib/svm/structural_svm_problem.h
new file mode 100644
index 000000000..3a73457b9
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_problem.h
@@ -0,0 +1,649 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_PRObLEM_Hh_
+#define DLIB_STRUCTURAL_SVM_PRObLEM_Hh_
+
+#include "structural_svm_problem_abstract.h"
+#include "../algs.h"
+#include <vector>
+#include "../optimization/optimization_oca.h"
+#include "../matrix.h"
+#include "sparse_vector.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ struct nuclear_norm_regularizer
+ {
+ long first_dimension;
+ long nr;
+ long nc;
+ double regularization_strength;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename structural_svm_problem
+ >
+ class cache_element_structural_svm
+ {
+ public:
+
+ cache_element_structural_svm (
+ ) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits<double>::infinity()) {}
+
+ typedef typename structural_svm_problem::scalar_type scalar_type;
+ typedef typename structural_svm_problem::matrix_type matrix_type;
+ typedef typename structural_svm_problem::feature_vector_type feature_vector_type;
+
+ void init (
+ const structural_svm_problem* prob_,
+ const long idx
+ )
+ /*!
+ ensures
+ - This object will be a cache for the idx-th sample in the given
+ structural_svm_problem.
+ !*/
+ {
+ prob = prob_;
+ sample_idx = idx;
+
+ loss.clear();
+ psi.clear();
+ lru_count.clear();
+
+ if (prob->get_max_cache_size() != 0)
+ {
+ prob->get_truth_joint_feature_vector(idx, true_psi);
+ compact_sparse_vector(true_psi);
+ }
+ }
+
+ void get_truth_joint_feature_vector_cached (
+ feature_vector_type& psi
+ ) const
+ {
+ if (prob->get_max_cache_size() != 0)
+ psi = true_psi;
+ else
+ prob->get_truth_joint_feature_vector(sample_idx, psi);
+
+ if (is_matrix<feature_vector_type>::value)
+ {
+ DLIB_CASSERT((long)psi.size() == prob->get_num_dimensions(),
+ "The dimensionality of your PSI vector doesn't match get_num_dimensions()");
+ }
+ }
+
+ void separation_oracle_cached (
+ const bool use_only_cache,
+ const bool skip_cache,
+ const scalar_type& saved_current_risk_gap,
+ const matrix_type& current_solution,
+ scalar_type& out_loss,
+ feature_vector_type& out_psi
+ ) const
+ {
+ const bool cache_enabled = prob->get_max_cache_size() != 0;
+
+ // Don't waste time computing this if the cache isn't going to be used.
+ const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0;
+
+ scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
+ unsigned long best_idx = 0;
+ long max_lru_count = 0;
+ if (cache_enabled)
+ {
+ // figure out which element in the cache is the best (i.e. has the biggest risk)
+ for (unsigned long i = 0; i < loss.size(); ++i)
+ {
+ const scalar_type risk = loss[i] + dot(psi[i], current_solution) - dot_true_psi;
+ if (risk > best_risk)
+ {
+ best_risk = risk;
+ out_loss = loss[i];
+ best_idx = i;
+ }
+ if (lru_count[i] > max_lru_count)
+ max_lru_count = lru_count[i];
+ }
+
+ if (!skip_cache)
+ {
+ // Check if the best psi vector in the cache is still good enough to use as
+ // a proxy for the true separation oracle. If the risk value has dropped
+ // by enough to get into the stopping condition then the best psi isn't
+ // good enough.
+ if ((best_risk + saved_current_risk_gap > last_true_risk_computed &&
+ best_risk >= 0) || use_only_cache)
+ {
+ out_psi = psi[best_idx];
+ lru_count[best_idx] = max_lru_count + 1;
+ return;
+ }
+ }
+ }
+
+
+ prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);
+ if (is_matrix<feature_vector_type>::value)
+ {
+ DLIB_CASSERT((long)out_psi.size() == prob->get_num_dimensions(),
+ "The dimensionality of your PSI vector doesn't match get_num_dimensions()");
+ }
+
+ if (!cache_enabled)
+ return;
+
+ compact_sparse_vector(out_psi);
+
+ last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi;
+
+ // If the separation oracle is only solved approximately then the result might
+ // not be as good as just selecting true_psi as the output. So here we check
+ // if that is the case.
+ if (last_true_risk_computed < 0 && best_risk < 0)
+ {
+ out_psi = true_psi;
+ out_loss = 0;
+ }
+ // Alternatively, an approximate separation oracle might not do as well as just
+ // selecting from the cache. So if that is the case when just take the best
+ // element from the cache.
+ else if (last_true_risk_computed < best_risk)
+ {
+ out_psi = psi[best_idx];
+ out_loss = loss[best_idx];
+ lru_count[best_idx] = max_lru_count + 1;
+ }
+ // if the cache is full
+ else if (loss.size() >= prob->get_max_cache_size())
+ {
+ // find least recently used cache entry for idx-th sample
+ const long i = index_of_min(mat(lru_count));
+
+ // save our new data in the cache
+ loss[i] = out_loss;
+ psi[i] = out_psi;
+
+ const long max_use = max(mat(lru_count));
+ // Make sure this new cache entry has the best lru count since we have used
+ // it most recently.
+ lru_count[i] = max_use + 1;
+ }
+ else
+ {
+ // In this case we just append the new psi into the cache.
+
+ loss.push_back(out_loss);
+ psi.push_back(out_psi);
+ long max_use = 1;
+ if (lru_count.size() != 0)
+ max_use = max(mat(lru_count)) + 1;
+ lru_count.push_back(max_use);
+ }
+ }
+
+ private:
+ // Do nothing if T isn't actually a sparse vector
+ template <typename T> void compact_sparse_vector( T& ) const { }
+
+ template <
+ typename T,
+ typename U,
+ typename alloc
+ >
+ void compact_sparse_vector (
+ std::vector<std::pair<T,U>,alloc>& vect
+ ) const
+ {
+ // If the sparse vector has more entires than dimensions then it must have some
+ // duplicate elements. So compact them using make_sparse_vector_inplace().
+ if (vect.size() > (unsigned long)prob->get_num_dimensions())
+ {
+ make_sparse_vector_inplace(vect);
+ // make sure the vector doesn't use more RAM than is necessary
+ std::vector<std::pair<T,U>,alloc>(vect).swap(vect);
+ }
+ }
+
+ const structural_svm_problem* prob;
+
+ long sample_idx;
+
+ mutable feature_vector_type true_psi;
+ mutable std::vector<scalar_type> loss;
+ mutable std::vector<feature_vector_type> psi;
+ mutable std::vector<long> lru_count;
+ mutable double last_true_risk_computed;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type_,
+ typename feature_vector_type_ = matrix_type_
+ >
+ class structural_svm_problem : public oca_problem<matrix_type_>
+ {
+ public:
+ /*!
+ CONVENTION
+ - C == get_c()
+ - eps == get_epsilon()
+ - max_iterations == get_max_iterations()
+ - if (skip_cache) then
+ - we won't use the oracle cache when we need to evaluate the separation
+ oracle. Instead, we will directly call the user supplied separation_oracle().
+
+ - get_max_cache_size() == max_cache_size
+
+ - if (cache.size() != 0) then
+ - cache.size() == get_num_samples()
+ - for all i: cache[i] == the cached results of calls to separation_oracle()
+ for the i-th sample.
+ !*/
+
+ typedef matrix_type_ matrix_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef feature_vector_type_ feature_vector_type;
+
+ structural_svm_problem (
+ ) :
+ saved_current_risk_gap(0),
+ eps(0.001),
+ max_iterations(10000),
+ verbose(false),
+ skip_cache(true),
+ count_below_eps(0),
+ max_cache_size(5),
+ converged(false),
+ nuclear_norm_part(0),
+ cache_based_eps(std::numeric_limits<scalar_type>::infinity()),
+ C(1)
+ {}
+
+ scalar_type get_cache_based_epsilon (
+ ) const
+ {
+ return cache_based_eps;
+ }
+
+ void set_cache_based_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_svm_problem::set_cache_based_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ cache_based_eps = eps_;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_svm_problem::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const { return max_cache_size; }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet(
+ )
+ {
+ verbose = false;
+ }
+
+ scalar_type get_c (
+ ) const { return C; }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_svm_problem::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ void add_nuclear_norm_regularizer (
+ long first_dimension,
+ long rows,
+ long cols,
+ double regularization_strength
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 <= first_dimension && first_dimension < get_num_dimensions() &&
+ 0 <= rows && 0 <= cols && rows*cols+first_dimension <= get_num_dimensions() &&
+ 0 < regularization_strength,
+ "\t void structural_svm_problem::add_nuclear_norm_regularizer()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t first_dimension: " << first_dimension
+ << "\n\t rows: " << rows
+ << "\n\t cols: " << cols
+ << "\n\t get_num_dimensions(): " << get_num_dimensions()
+ << "\n\t regularization_strength: " << regularization_strength
+ << "\n\t this: " << this
+ );
+
+ impl::nuclear_norm_regularizer temp;
+ temp.first_dimension = first_dimension;
+ temp.nr = rows;
+ temp.nc = cols;
+ temp.regularization_strength = regularization_strength;
+ nuclear_norm_regularizers.push_back(temp);
+ }
+
+ unsigned long num_nuclear_norm_regularizers (
+ ) const { return nuclear_norm_regularizers.size(); }
+
+ void clear_nuclear_norm_regularizers (
+ ) { nuclear_norm_regularizers.clear(); }
+
+ virtual long get_num_dimensions (
+ ) const = 0;
+
+ virtual long get_num_samples (
+ ) const = 0;
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const = 0;
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const = 0;
+
+ private:
+
+ virtual bool risk_has_lower_bound (
+ scalar_type& lower_bound
+ ) const
+ {
+ lower_bound = 0;
+ return true;
+ }
+
+ virtual bool optimization_status (
+ scalar_type current_objective_value,
+ scalar_type current_error_gap,
+ scalar_type current_risk_value,
+ scalar_type current_risk_gap,
+ unsigned long num_cutting_planes,
+ unsigned long num_iterations
+ ) const
+ {
+ if (verbose)
+ {
+ using namespace std;
+ if (nuclear_norm_regularizers.size() != 0)
+ {
+ cout << "objective: " << current_objective_value << endl;
+ cout << "objective gap: " << current_error_gap << endl;
+ cout << "risk: " << current_risk_value-nuclear_norm_part << endl;
+ cout << "risk+nuclear norm: " << current_risk_value << endl;
+ cout << "risk+nuclear norm gap: " << current_risk_gap << endl;
+ cout << "num planes: " << num_cutting_planes << endl;
+ cout << "iter: " << num_iterations << endl;
+ }
+ else
+ {
+ cout << "objective: " << current_objective_value << endl;
+ cout << "objective gap: " << current_error_gap << endl;
+ cout << "risk: " << current_risk_value << endl;
+ cout << "risk gap: " << current_risk_gap << endl;
+ cout << "num planes: " << num_cutting_planes << endl;
+ cout << "iter: " << num_iterations << endl;
+ }
+ cout << endl;
+ }
+
+ if (num_iterations >= max_iterations)
+ return true;
+
+ saved_current_risk_gap = current_risk_gap;
+
+ if (converged)
+ {
+ return (current_risk_gap < std::max(cache_based_eps,cache_based_eps*current_risk_value)) ||
+ (current_risk_gap == 0);
+ }
+
+ if (current_risk_gap < eps)
+ {
+ // Only stop when we see that the risk gap is small enough on a non-cached
+ // iteration. But even then, if we are supposed to do the cache based
+ // refinement then we just mark that we have "converged" to avoid further
+ // calls to the separation oracle and run all subsequent iterations off the
+ // cache.
+ if (skip_cache || max_cache_size == 0)
+ {
+ converged = true;
+ skip_cache = false;
+ return (current_risk_gap < std::max(cache_based_eps,cache_based_eps*current_risk_value)) ||
+ (current_risk_gap == 0);
+ }
+
+ ++count_below_eps;
+
+ // Only disable the cache if we have seen a few consecutive iterations that
+ // look to have converged.
+ if (count_below_eps > 1)
+ {
+ // Instead of stopping we shouldn't use the cache on the next iteration. This way
+ // we can be sure to have the best solution rather than assuming the cache is up-to-date
+ // enough.
+ skip_cache = true;
+ count_below_eps = 0;
+ }
+ }
+ else
+ {
+ count_below_eps = 0;
+ skip_cache = false;
+ }
+
+ return false;
+ }
+
+ virtual void get_risk (
+ matrix_type& w,
+ scalar_type& risk,
+ matrix_type& subgradient
+ ) const
+ {
+ feature_vector_type ftemp;
+ const unsigned long num = get_num_samples();
+
+ // initialize the cache and compute psi_true.
+ if (cache.size() == 0)
+ {
+ cache.resize(get_num_samples());
+ for (unsigned long i = 0; i < cache.size(); ++i)
+ cache[i].init(this,i);
+
+ psi_true.set_size(w.size(),1);
+ psi_true = 0;
+
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ cache[i].get_truth_joint_feature_vector_cached(ftemp);
+
+ subtract_from(psi_true, ftemp);
+ }
+ }
+
+ subgradient = psi_true;
+ scalar_type total_loss = 0;
+ call_separation_oracle_on_all_samples(w,subgradient,total_loss);
+
+ subgradient /= num;
+ total_loss /= num;
+ risk = total_loss + dot(subgradient,w);
+
+ if (nuclear_norm_regularizers.size() != 0)
+ {
+ matrix_type grad;
+ scalar_type obj;
+ compute_nuclear_norm_parts(w, grad, obj);
+ risk += obj;
+ subgradient += grad;
+ }
+ }
+
+ virtual void call_separation_oracle_on_all_samples (
+ const matrix_type& w,
+ matrix_type& subgradient,
+ scalar_type& total_loss
+ ) const
+ {
+ feature_vector_type ftemp;
+ const unsigned long num = get_num_samples();
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ scalar_type loss;
+ separation_oracle_cached(i, w, loss, ftemp);
+ total_loss += loss;
+ add_to(subgradient, ftemp);
+ }
+ }
+
+ protected:
+
+ void compute_nuclear_norm_parts(
+ const matrix_type& m,
+ matrix_type& grad,
+ scalar_type& obj
+ ) const
+ {
+ obj = 0;
+ grad.set_size(m.size(), 1);
+ grad = 0;
+
+ matrix<double> u,v,w,f;
+ nuclear_norm_part = 0;
+ for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
+ {
+ const long nr = nuclear_norm_regularizers[i].nr;
+ const long nc = nuclear_norm_regularizers[i].nc;
+ const long size = nr*nc;
+ const long idx = nuclear_norm_regularizers[i].first_dimension;
+ const double strength = nuclear_norm_regularizers[i].regularization_strength;
+
+ f = matrix_cast<double>(reshape(rowm(m, range(idx, idx+size-1)), nr, nc));
+ svd3(f, u,w,v);
+
+
+ const double norm = sum(w);
+ obj += strength*norm;
+ nuclear_norm_part += strength*norm/C;
+
+ f = u*trans(v);
+
+ set_rowm(grad, range(idx, idx+size-1)) = matrix_cast<double>(strength*reshape_to_column_vector(f));
+ }
+
+ obj /= C;
+ grad /= C;
+ }
+
+ void separation_oracle_cached (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ cache[idx].separation_oracle_cached(converged,
+ skip_cache,
+ saved_current_risk_gap,
+ current_solution,
+ loss,
+ psi);
+ }
+
+ std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
+
+ mutable scalar_type saved_current_risk_gap;
+ mutable matrix_type psi_true;
+ scalar_type eps;
+ unsigned long max_iterations;
+ mutable bool verbose;
+
+
+ mutable std::vector<cache_element_structural_svm<structural_svm_problem> > cache;
+ mutable bool skip_cache;
+ mutable int count_below_eps;
+ unsigned long max_cache_size;
+ mutable bool converged;
+ mutable double nuclear_norm_part;
+ scalar_type cache_based_eps;
+
+ scalar_type C;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_PRObLEM_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_svm_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_problem_abstract.h
new file mode 100644
index 000000000..20b3d73a7
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_problem_abstract.h
@@ -0,0 +1,348 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_
+
+#include "../optimization/optimization_oca_abstract.h"
+#include "sparse_vector_abstract.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type_,
+ typename feature_vector_type_ = matrix_type_
+ >
+ class structural_svm_problem : public oca_problem<matrix_type_>
+ {
+ public:
+ /*!
+ REQUIREMENTS ON matrix_type_
+ - matrix_type_ == a dlib::matrix capable of storing column vectors
+
+ REQUIREMENTS ON feature_vector_type_
+ - feature_vector_type_ == a dlib::matrix capable of storing column vectors
+ or an unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+
+ INITIAL VALUE
+ - get_epsilon() == 0.001
+ - get_max_iterations() == 10000
+ - get_max_cache_size() == 5
+ - get_c() == 1
+ - get_cache_based_epsilon() == std::numeric_limits<scalar_type>::infinity()
+ (I.e. the cache based epsilon feature is disabled)
+ - num_nuclear_norm_regularizers() == 0
+ - This object will not be verbose
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for solving the optimization problem associated with
+ a structural support vector machine. A structural SVM is a supervised
+ machine learning method for learning to predict complex outputs. This is
+ contrasted with a binary classifier which makes only simple yes/no
+ predictions. A structural SVM, on the other hand, can learn to predict
+ complex outputs such as entire parse trees or DNA sequence alignments. To
+ do this, it learns a function F(x,y) which measures how well a particular
+ data sample x matches a label y. When used for prediction, the best label
+ for a new x is given by the y which maximizes F(x,y).
+
+ To use this object you inherit from it, provide implementations of its four
+ pure virtual functions, and then pass your object to the oca optimizer.
+ Also, you should only pass an instance of this object to the oca optimizer
+ once. That is, the act of using a structural_svm_problem instance with the
+ oca solver "uses" the structural_svm_problem instance. If you want to
+ solve the same problem multiple times then you must use a fresh instance of
+ your structural_svm_problem.
+
+
+ To define the optimization problem precisely, we first introduce some notation:
+ - let PSI(x,y) == the joint feature vector for input x and a label y.
+ - let F(x,y|w) == dot(w,PSI(x,y)).
+ - let LOSS(idx,y) == the loss incurred for predicting that the idx-th training
+ sample has a label of y. Note that LOSS() should always be >= 0 and should
+ become exactly 0 when y is the correct label for the idx-th sample.
+ - let x_i == the i-th training sample.
+ - let y_i == the correct label for the i-th training sample.
+ - The number of data samples is N.
+
+ Then the optimization problem solved using this object is the following:
+ Minimize: h(w) == 0.5*dot(w,w) + C*R(w)
+
+ Where R(w) == sum from i=1 to N: 1/N * sample_risk(i,w)
+ and sample_risk(i,w) == max over all Y: LOSS(i,Y) + F(x_i,Y|w) - F(x_i,y_i|w)
+ and C > 0
+
+
+
+ For an introduction to structured support vector machines you should consult
+ the following paper:
+ Predicting Structured Objects with Support Vector Machines by
+ Thorsten Joachims, Thomas Hofmann, Yisong Yue, and Chun-nam Yu
+
+ For a more detailed discussion of the particular algorithm implemented by this
+ object see the following paper:
+ T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of Structural SVMs,
+ Machine Learning, 77(1):27-59, 2009.
+
+ Note that this object is essentially a tool for solving the 1-Slack structural
+ SVM with margin-rescaling. Specifically, see Algorithm 3 in the above referenced
+ paper.
+ !*/
+
+ typedef matrix_type_ matrix_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef feature_vector_type_ feature_vector_type;
+
+ structural_svm_problem (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to execute. Specifically, the algorithm stops when the average sample
+ risk (i.e. R(w) as defined above) is within epsilon of its optimal value.
+
+ Also note that sample risk is an upper bound on a sample's loss. So
+ you can think of this epsilon value as saying "solve the optimization
+ problem until the average loss per sample is within epsilon of it's
+ optimal value".
+ !*/
+
+ scalar_type get_cache_based_epsilon (
+ ) const;
+ /*!
+ ensures
+ - if (get_max_cache_size() != 0) then
+ - The solver will not stop when the average sample risk is within
+ get_epsilon() of its optimal value. Instead, it will keep running
+ but will run the optimizer completely on the cache until the average
+ sample risk is within #get_cache_based_epsilon() of its optimal
+ value. This means that it will perform this additional refinement in
+ the solution accuracy without making any additional calls to the
+ separation_oracle(). This is useful when using a nuclear norm
+ regularization term because it allows you to quickly solve the
+ optimization problem to a high precision, which in the case of a
+ nuclear norm regularized problem means that many of the learned
+ matrices will be low rank or very close to low rank due to the
+ nuclear norm regularizer. This may not happen without solving the
+ problem to a high accuracy or their ranks may be difficult to
+ determine, so the extra accuracy given by the cache based refinement
+ is very useful. Finally, note that we include the nuclear norm term
+ as part of the "risk" for the purposes of determining when to stop.
+ - else
+ - The value of #get_cache_based_epsilon() has no effect.
+ !*/
+
+ void set_cache_based_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_cache_based_epsilon() == eps
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - Returns the number of joint feature vectors per training sample kept in
+ the separation oracle cache. This cache is used to avoid unnecessary
+ calls to the user supplied separation_oracle() function. Note that a
+ value of 0 means that caching is not used at all. This is appropriate
+ if the separation oracle is cheap to evaluate.
+ !*/
+
+ void add_nuclear_norm_regularizer (
+ long first_dimension,
+ long rows,
+ long cols,
+ double regularization_strength
+ );
+ /*!
+ requires
+ - 0 <= first_dimension < get_num_dimensions()
+ - 0 <= rows
+ - 0 <= cols
+ - first_dimension+rows*cols <= get_num_dimensions()
+ - 0 < regularization_strength
+ ensures
+ - Adds a nuclear norm regularization term to the optimization problem
+ solved by this object. That is, instead of solving:
+ Minimize: h(w) == 0.5*dot(w,w) + C*R(w)
+ this object will solve:
+ Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + regularization_strength*nuclear_norm_of(part of w)
+ where "part of w" is the part of w indicated by the arguments to this
+ function. In particular, the part of w included in the nuclear norm is
+ exactly the matrix reshape(rowm(w, range(first_dimension, first_dimension+rows*cols-1)), rows, cols).
+ Therefore, if you think of the w vector as being the concatenation of a
+ bunch of matrices then you can use multiple calls to add_nuclear_norm_regularizer()
+ to add nuclear norm regularization terms to any of the matrices packed into w.
+ - #num_nuclear_norm_regularizers() == num_nuclear_norm_regularizers() + 1
+ !*/
+
+ unsigned long num_nuclear_norm_regularizers (
+ ) const;
+ /*!
+ ensures
+ - returns the number of nuclear norm regularizers that are currently a part
+ of this optimization problem. That is, returns the number of times
+ add_nuclear_norm_regularizer() has been called since the last call to
+ clear_nuclear_norm_regularizers() or object construction, whichever is
+ most recent.
+ !*/
+
+ void clear_nuclear_norm_regularizers (
+ );
+ /*!
+ ensures
+ - #num_nuclear_norm_regularizers() == 0
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet(
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization of the resulting classifier. Larger values encourage
+ exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ // --------------------------------
+ // User supplied routines
+ // --------------------------------
+
+ virtual long get_num_dimensions (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the dimensionality of a joint feature vector
+ !*/
+
+ virtual long get_num_samples (
+ ) const = 0;
+ /*!
+ ensures
+ - returns the number of training samples in this problem.
+ !*/
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const = 0;
+ /*!
+ requires
+ - 0 <= idx < get_num_samples()
+ ensures
+ - #psi == PSI(x_idx, y_idx)
+ (i.e. the joint feature vector for the idx-th training sample its true label.)
+ !*/
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const = 0;
+ /*!
+ requires
+ - 0 <= idx < get_num_samples()
+ - current_solution.size() == get_num_dimensions()
+ ensures
+ - runs the separation oracle on the idx-th sample. We define this as follows:
+ - let X == the idx-th training sample.
+ - let PSI(X,y) == the joint feature vector for input X and an arbitrary label y.
+ - let F(X,y) == dot(current_solution,PSI(X,y)).
+ - let LOSS(idx,y) == the loss incurred for predicting that the idx-th sample
+ has a label of y. Note that LOSS() should always be >= 0 and should
+ become exactly 0 when y is the correct label for the idx-th sample.
+
+ Then the separation oracle finds a Y such that:
+ Y = argmax over all y: LOSS(idx,y) + F(X,y)
+ (i.e. It finds the label which maximizes the above expression.)
+
+ Finally, we can define the outputs of this function as:
+ - #loss == LOSS(idx,Y)
+ - #psi == PSI(X,Y)
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_problem_threaded.h b/ml/dlib/dlib/svm/structural_svm_problem_threaded.h
new file mode 100644
index 000000000..e981ba8d9
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_problem_threaded.h
@@ -0,0 +1,157 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
+#define DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
+
+#include "structural_svm_problem_threaded_abstract.h"
+#include "../algs.h"
+#include <vector>
+#include "structural_svm_problem.h"
+#include "../matrix.h"
+#include "sparse_vector.h"
+#include <iostream>
+#include "../threads.h"
+#include "../misc_api.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type_,
+ typename feature_vector_type_ = matrix_type_
+ >
+ class structural_svm_problem_threaded : public structural_svm_problem<matrix_type_,feature_vector_type_>
+ {
+ public:
+
+ typedef matrix_type_ matrix_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef feature_vector_type_ feature_vector_type;
+
+ explicit structural_svm_problem_threaded (
+ unsigned long num_threads
+ ) :
+ tp(num_threads),
+ num_iterations_executed(0)
+ {}
+
+ unsigned long get_num_threads (
+ ) const { return tp.num_threads_in_pool(); }
+
+ private:
+
+ struct binder
+ {
+ binder (
+ const structural_svm_problem_threaded& self_,
+ const matrix_type& w_,
+ matrix_type& subgradient_,
+ scalar_type& total_loss_,
+ bool buffer_subgradients_locally_
+ ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_),
+ buffer_subgradients_locally(buffer_subgradients_locally_){}
+
+ void call_oracle (
+ long begin,
+ long end
+ )
+ {
+ // If we are only going to call the separation oracle once then don't run
+ // the slightly more complex for loop version of this code. Or if we just
+ // don't want to run the complex buffering one. The code later on decides
+ // if we should do the buffering based on how long it takes to execute. We
+ // do this because, when the subgradient is really high dimensional it can
+ // take a lot of time to add them together. So we might want to avoid
+ // doing that.
+ if (end-begin <= 1 || !buffer_subgradients_locally)
+ {
+ scalar_type loss;
+ feature_vector_type ftemp;
+ for (long i = begin; i < end; ++i)
+ {
+ self.separation_oracle_cached(i, w, loss, ftemp);
+
+ auto_mutex lock(self.accum_mutex);
+ total_loss += loss;
+ add_to(subgradient, ftemp);
+ }
+ }
+ else
+ {
+ scalar_type loss = 0;
+ matrix_type faccum(subgradient.size(),1);
+ faccum = 0;
+
+ feature_vector_type ftemp;
+
+ for (long i = begin; i < end; ++i)
+ {
+ scalar_type loss_temp;
+ self.separation_oracle_cached(i, w, loss_temp, ftemp);
+ loss += loss_temp;
+ add_to(faccum, ftemp);
+ }
+
+ auto_mutex lock(self.accum_mutex);
+ total_loss += loss;
+ add_to(subgradient, faccum);
+ }
+ }
+
+ const structural_svm_problem_threaded& self;
+ const matrix_type& w;
+ matrix_type& subgradient;
+ scalar_type& total_loss;
+ bool buffer_subgradients_locally;
+ };
+
+
+ virtual void call_separation_oracle_on_all_samples (
+ const matrix_type& w,
+ matrix_type& subgradient,
+ scalar_type& total_loss
+ ) const
+ {
+ ++num_iterations_executed;
+
+ const uint64 start_time = ts.get_timestamp();
+
+ bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();
+
+ // every 50 iterations we should try to flip the buffering scheme to see if
+ // doing it the other way might be better.
+ if ((num_iterations_executed%50) == 0)
+ {
+ buffer_subgradients_locally = !buffer_subgradients_locally;
+ }
+
+ binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally);
+ parallel_for_blocked(tp, 0, this->get_num_samples(), b, &binder::call_oracle);
+
+ const uint64 stop_time = ts.get_timestamp();
+
+ if (buffer_subgradients_locally)
+ with_buffer_time.add(stop_time-start_time);
+ else
+ without_buffer_time.add(stop_time-start_time);
+
+ }
+
+ mutable thread_pool tp;
+ mutable mutex accum_mutex;
+ mutable timestamper ts;
+ mutable running_stats<double> with_buffer_time;
+ mutable running_stats<double> without_buffer_time;
+ mutable unsigned long num_iterations_executed;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h b/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h
new file mode 100644
index 000000000..3cfc6a6eb
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h
@@ -0,0 +1,68 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_
+
+#include "structural_svm_problem_abstract.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type_,
+ typename feature_vector_type_ = matrix_type_
+ >
+ class structural_svm_problem_threaded : public structural_svm_problem<matrix_type_,feature_vector_type_>
+ {
+ public:
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is identical to the structural_svm_problem object defined in
+ dlib/svm/structural_svm_problem_abstract.h except that its constructor
+ takes a number which defines how many threads will be used to make concurrent
+ calls to the separation_oracle() routine.
+
+ So this object lets you take advantage of a multi-core system. You should
+ set the num_threads parameter equal to the number of available cores. Note
+ that the separation_oracle() function which you provide must be thread safe
+ if you are to use this version of the structural_svm_problem. In
+ particular, it must be safe to call separation_oracle() concurrently from
+ different threads. However, it is guaranteed that different threads will
+ never make concurrent calls to separation_oracle() using the same idx value
+ (i.e. the first argument).
+ !*/
+
+ typedef matrix_type_ matrix_type;
+ typedef typename matrix_type::type scalar_type;
+ typedef feature_vector_type_ feature_vector_type;
+
+ structural_svm_problem (
+ unsigned long num_threads
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_num_threads() == num_threads
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - Returns the number of threads which will be used to make concurrent
+ calls to the separation_oracle() function.
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h
new file mode 100644
index 000000000..68dff66f5
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h
@@ -0,0 +1,281 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_
+#define DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_
+
+
+#include "structural_svm_sequence_labeling_problem_abstract.h"
+#include "../matrix.h"
+#include "sequence_labeler.h"
+#include <vector>
+#include "structural_svm_problem_threaded.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ namespace fe_helpers
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ struct get_feats_functor
+ {
+ get_feats_functor(std::vector<std::pair<unsigned long, double> >& feats_) : feats(feats_) {}
+
+ inline void operator() (
+ unsigned long feat_index,
+ double feat_value
+ )
+ {
+ feats.push_back(std::make_pair(feat_index, feat_value));
+ }
+
+ inline void operator() (
+ unsigned long feat_index
+ )
+ {
+ feats.push_back(std::make_pair(feat_index, 1));
+ }
+
+ std::vector<std::pair<unsigned long, double> >& feats;
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ template <typename feature_extractor, typename sequence_type, typename EXP2>
+ void get_feature_vector(
+ std::vector<std::pair<unsigned long, double> >& feats,
+ const feature_extractor& fe,
+ const sequence_type& sequence,
+ const matrix_exp<EXP2>& candidate_labeling,
+ unsigned long position
+ )
+ {
+ get_feats_functor funct(feats);
+ fe.get_features(funct, sequence,candidate_labeling, position);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename feature_extractor
+ >
+ class structural_svm_sequence_labeling_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>, std::vector<std::pair<unsigned long,double> > >
+ {
+ public:
+ typedef matrix<double,0,1> matrix_type;
+ typedef std::vector<std::pair<unsigned long, double> > feature_vector_type;
+
+ typedef typename feature_extractor::sequence_type sequence_type;
+
+ structural_svm_sequence_labeling_problem(
+ const std::vector<sequence_type>& samples_,
+ const std::vector<std::vector<unsigned long> >& labels_,
+ const feature_extractor& fe_,
+ unsigned long num_threads = 2
+ ) :
+ structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads),
+ samples(samples_),
+ labels(labels_),
+ fe(fe_)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
+ contains_invalid_labeling(fe, samples, labels) == false,
+ "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
+ << "\n\t contains_invalid_labeling(fe,samples,labels): " << contains_invalid_labeling(fe,samples,labels)
+ << "\n\t this: " << this
+ );
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ for (unsigned long j = 0; j < labels[i].size(); ++j)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(labels[i][j] < fe.num_labels(),
+ "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()"
+ << "\n\t The given labels in labels are invalid."
+ << "\n\t labels[i][j]: " << labels[i][j]
+ << "\n\t fe.num_labels(): " << fe.num_labels()
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t this: " << this
+ );
+ }
+ }
+#endif
+
+ loss_values.assign(num_labels(), 1);
+
+ }
+
+ unsigned long num_labels (
+ ) const { return fe.num_labels(); }
+
+ double get_loss (
+ unsigned long label
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(label < num_labels(),
+ "\t void structural_svm_sequence_labeling_problem::get_loss()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t label: " << label
+ << "\n\t num_labels(): " << num_labels()
+ << "\n\t this: " << this
+ );
+
+ return loss_values[label];
+ }
+
+ void set_loss (
+ unsigned long label,
+ double value
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(label < num_labels() && value >= 0,
+ "\t void structural_svm_sequence_labeling_problem::set_loss()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t label: " << label
+ << "\n\t num_labels(): " << num_labels()
+ << "\n\t value: " << value
+ << "\n\t this: " << this
+ );
+
+ loss_values[label] = value;
+ }
+
+ private:
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return fe.num_features();
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return samples.size();
+ }
+
+ void get_joint_feature_vector (
+ const sequence_type& sample,
+ const std::vector<unsigned long>& label,
+ feature_vector_type& psi
+ ) const
+ {
+ psi.clear();
+
+ const int order = fe.order();
+
+ matrix<unsigned long,0,1> candidate_labeling;
+ for (unsigned long i = 0; i < sample.size(); ++i)
+ {
+ candidate_labeling = rowm(mat(label), range(i, std::max((int)i-order,0)));
+
+ fe_helpers::get_feature_vector(psi,fe,sample,candidate_labeling, i);
+ }
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ get_joint_feature_vector(samples[idx], labels[idx], psi);
+ }
+
+ class map_prob
+ {
+ public:
+ unsigned long order() const { return fe.order(); }
+ unsigned long num_states() const { return fe.num_labels(); }
+
+ map_prob(
+ const sequence_type& sequence_,
+ const std::vector<unsigned long>& label_,
+ const feature_extractor& fe_,
+ const matrix<double,0,1>& weights_,
+ const std::vector<double>& loss_values_
+ ) :
+ sequence(sequence_),
+ label(label_),
+ fe(fe_),
+ weights(weights_),
+ loss_values(loss_values_)
+ {
+ }
+
+ unsigned long number_of_nodes(
+ ) const
+ {
+ return sequence.size();
+ }
+
+ template <
+ typename EXP
+ >
+ double factor_value (
+ unsigned long node_id,
+ const matrix_exp<EXP>& node_states
+ ) const
+ {
+ if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
+ return -std::numeric_limits<double>::infinity();
+
+ double loss = 0;
+ if (node_states(0) != label[node_id])
+ loss = loss_values[label[node_id]];
+
+ return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss;
+ }
+
+ const sequence_type& sequence;
+ const std::vector<unsigned long>& label;
+ const feature_extractor& fe;
+ const matrix<double,0,1>& weights;
+ const std::vector<double>& loss_values;
+ };
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ std::vector<unsigned long> y;
+ find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution,loss_values), y);
+
+ loss = 0;
+ for (unsigned long i = 0; i < y.size(); ++i)
+ {
+ if (y[i] != labels[idx][i])
+ loss += loss_values[labels[idx][i]];
+ }
+
+ get_joint_feature_vector(samples[idx], y, psi);
+ }
+
+ const std::vector<sequence_type>& samples;
+ const std::vector<std::vector<unsigned long> >& labels;
+ const feature_extractor& fe;
+ std::vector<double> loss_values;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h
new file mode 100644
index 000000000..b46a55350
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h
@@ -0,0 +1,110 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_
+
+
+#include "../matrix.h"
+#include <vector>
+#include "structural_svm_problem_threaded_abstract.h"
+#include "sequence_labeler_abstract.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ template <
+ typename feature_extractor
+ >
+ class structural_svm_sequence_labeling_problem : noncopyable,
+ public structural_svm_problem_threaded<matrix<double,0,1>,
+ std::vector<std::pair<unsigned long,double> > >
+ {
+ /*!
+ REQUIREMENTS ON feature_extractor
+ It must be an object that implements an interface compatible with
+ the example_feature_extractor defined in dlib/svm/sequence_labeler_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning the weight vector needed to use
+ a sequence_labeler object.
+
+ It learns the parameter vector by formulating the problem as a structural
+ SVM problem. The general approach is discussed in the paper:
+ Hidden Markov Support Vector Machines by
+ Y. Altun, I. Tsochantaridis, T. Hofmann
+ While the particular optimization strategy used is the method from:
+ T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of
+ Structural SVMs, Machine Learning, 77(1):27-59, 2009.
+ !*/
+
+ public:
+ typedef typename feature_extractor::sequence_type sequence_type;
+
+ structural_svm_sequence_labeling_problem(
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels,
+ const feature_extractor& fe,
+ unsigned long num_threads = 2
+ );
+ /*!
+ requires
+ - is_sequence_labeling_problem(samples, labels) == true
+ - contains_invalid_labeling(fe, samples, labels) == false
+ - for all valid i and j: labels[i][j] < fe.num_labels()
+ ensures
+ - This object attempts to learn a mapping from the given samples to the
+ given labels. In particular, it attempts to learn to predict labels[i]
+ based on samples[i]. Or in other words, this object can be used to learn
+ a parameter vector, w, such that a sequence_labeler declared as:
+ sequence_labeler<feature_extractor> labeler(w,fe)
+ results in a labeler object which attempts to compute the following mapping:
+ labels[i] == labeler(samples[i])
+ - This object will use num_threads threads during the optimization
+ procedure. You should set this parameter equal to the number of
+ available processing cores on your machine.
+ - #num_labels() == fe.num_labels()
+ - for all valid i: #get_loss(i) == 1
+ !*/
+
+ unsigned long num_labels (
+ ) const;
+ /*!
+ ensures
+ - returns the number of possible labels in this learning problem
+ !*/
+
+ double get_loss (
+ unsigned long label
+ ) const;
+ /*!
+ requires
+ - label < num_labels()
+ ensures
+ - returns the loss incurred when a sequence element with the given
+ label is misclassified. This value controls how much we care about
+ correctly classifying this type of label. Larger loss values indicate
+ that we care more strongly than smaller values.
+ !*/
+
+ void set_loss (
+ unsigned long label,
+ double value
+ );
+ /*!
+ requires
+ - label < num_labels()
+ - value >= 0
+ ensures
+ - #get_loss(label) == value
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer.h b/ml/dlib/dlib/svm/structural_track_association_trainer.h
new file mode 100644
index 000000000..87fb829b2
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_track_association_trainer.h
@@ -0,0 +1,404 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+
+#include "structural_track_association_trainer_abstract.h"
+#include "../algs.h"
+#include "svm.h"
+#include <utility>
+#include "track_association_function.h"
+#include "structural_assignment_trainer.h"
+#include <map>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ std::vector<detection_type> get_unlabeled_dets (
+ const std::vector<labeled_detection<detection_type,label_type> >& dets
+ )
+ {
+ std::vector<detection_type> temp;
+ temp.reserve(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ temp.push_back(dets[i].det);
+ return temp;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class structural_track_association_trainer
+ {
+ public:
+
+ structural_track_association_trainer (
+ )
+ {
+ set_defaults();
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_track_association_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void set_loss_per_false_association (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_track_association_trainer::set_loss_per_false_association(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_association = loss;
+ }
+
+ double get_loss_per_false_association (
+ ) const
+ {
+ return loss_per_false_association;
+ }
+
+ void set_loss_per_track_break (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_track_association_trainer::set_loss_per_track_break(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_track_break = loss;
+ }
+
+ double get_loss_per_track_break (
+ ) const
+ {
+ return loss_per_track_break;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_track_association_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ }
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_track_association_problem(samples),
+ "\t track_association_function structural_track_association_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples)
+ );
+
+ typedef typename detection_type::track_type track_type;
+
+ const unsigned long num_dims = find_num_dims(samples);
+
+ feature_extractor_track_association<detection_type> fe(num_dims, learn_nonnegative_weights?num_dims:0);
+ structural_assignment_trainer<feature_extractor_track_association<detection_type> > trainer(fe);
+
+
+ if (verbose)
+ trainer.be_verbose();
+
+ trainer.set_c(C);
+ trainer.set_epsilon(eps);
+ trainer.set_max_cache_size(max_cache_size);
+ trainer.set_num_threads(num_threads);
+ trainer.set_oca(solver);
+ trainer.set_loss_per_missed_association(loss_per_track_break);
+ trainer.set_loss_per_false_association(loss_per_false_association);
+
+ std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples;
+ std::vector<std::vector<long> > labels;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ convert_dets_to_association_sets(samples[i], assignment_samples, labels);
+
+
+ return track_association_function<detection_type>(trainer.train(assignment_samples, labels));
+ }
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample
+ ) const
+ {
+ std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples;
+ samples.push_back(sample);
+ return train(samples);
+ }
+
+ private:
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ static unsigned long find_num_dims (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ )
+ {
+ typedef typename detection_type::track_type track_type;
+ // find a detection_type object so we can call get_similarity_features() and
+ // find out how big the feature vectors are.
+
+ // for all detection histories
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // for all time instances in the detection history
+ for (unsigned j = 0; j < samples[i].size(); ++j)
+ {
+ if (samples[i][j].size() > 0)
+ {
+ track_type new_track;
+ new_track.update_track(samples[i][j][0].det);
+ typename track_type::feature_vector_type feats;
+ new_track.get_similarity_features(samples[i][j][0].det, feats);
+ return feats.size();
+ }
+ }
+ }
+
+ DLIB_CASSERT(false,
+ "No detection objects were given in the call to dlib::structural_track_association_trainer::train()");
+ }
+
+ template <
+ typename detections_at_single_time_step,
+ typename detection_type,
+ typename track_type
+ >
+ static void convert_dets_to_association_sets (
+ const std::vector<detections_at_single_time_step>& det_history,
+ std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > >& data,
+ std::vector<std::vector<long> >& labels
+ )
+ {
+ if (det_history.size() < 1)
+ return;
+
+ typedef typename detections_at_single_time_step::value_type::label_type label_type;
+ std::vector<track_type> tracks;
+ // track_labels maps from detection labels to the index in tracks. So track
+ // with detection label X is at tracks[track_labels[X]].
+ std::map<label_type,unsigned long> track_labels;
+ add_dets_to_tracks(tracks, track_labels, det_history[0]);
+
+ using namespace impl;
+ for (unsigned long i = 1; i < det_history.size(); ++i)
+ {
+ data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks));
+ labels.push_back(get_association_labels(det_history[i], track_labels));
+ add_dets_to_tracks(tracks, track_labels, det_history[i]);
+ }
+ }
+
+ template <
+ typename labeled_detection,
+ typename label_type
+ >
+ static std::vector<long> get_association_labels(
+ const std::vector<labeled_detection>& dets,
+ const std::map<label_type,unsigned long>& track_labels
+ )
+ {
+ std::vector<long> assoc(dets.size(),-1);
+ // find out which detections associate to what tracks
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ typename std::map<label_type,unsigned long>::const_iterator j;
+ j = track_labels.find(dets[i].label);
+ // If this detection matches one of the tracks then record which track it
+ // matched with.
+ if (j != track_labels.end())
+ assoc[i] = j->second;
+ }
+ return assoc;
+ }
+
+ template <
+ typename track_type,
+ typename label_type,
+ typename labeled_detection
+ >
+ static void add_dets_to_tracks (
+ std::vector<track_type>& tracks,
+ std::map<label_type,unsigned long>& track_labels,
+ const std::vector<labeled_detection>& dets
+ )
+ {
+ std::vector<bool> updated_track(tracks.size(), false);
+
+ // first assign the dets to the tracks
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ const label_type& label = dets[i].label;
+ if (track_labels.count(label))
+ {
+ const unsigned long track_idx = track_labels[label];
+ tracks[track_idx].update_track(dets[i].det);
+ updated_track[track_idx] = true;
+ }
+ else
+ {
+ // this detection creates a new track
+ track_type new_track;
+ new_track.update_track(dets[i].det);
+ tracks.push_back(new_track);
+ track_labels[label] = tracks.size()-1;
+ }
+
+ }
+
+ // Now propagate all the tracks that didn't get any detections.
+ for (unsigned long i = 0; i < updated_track.size(); ++i)
+ {
+ if (!updated_track[i])
+ tracks[i].propagate_track();
+ }
+ }
+
+ double C;
+ oca solver;
+ double eps;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ bool learn_nonnegative_weights;
+ double loss_per_track_break;
+ double loss_per_false_association;
+
+ void set_defaults ()
+ {
+ C = 100;
+ verbose = false;
+ eps = 0.001;
+ num_threads = 2;
+ max_cache_size = 5;
+ learn_nonnegative_weights = false;
+ loss_per_track_break = 1;
+ loss_per_false_association = 1;
+ }
+ };
+
+}
+
+#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+
diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h b/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h
new file mode 100644
index 000000000..e78fadef7
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h
@@ -0,0 +1,268 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_
+#ifdef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_
+
+#include "track_association_function_abstract.h"
+#include "structural_assignment_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class structural_track_association_trainer
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for learning to solve a track association problem. That
+ is, it takes in a set of training data and outputs a track_association_function
+ you can use to do detection to track association. The training data takes the
+ form of a set or sets of "track histories". Each track history is a
+ std::vector where each element contains all the detections from a single time
+ step. Moreover, each detection has a label that uniquely identifies which
+ object (e.g. person or whatever) the detection really corresponds to. That is,
+ the labels indicate the correct detection to track associations. The goal of
+ this object is then to produce a track_association_function that can perform a
+ correct detection to track association at each time step.
+ !*/
+
+ public:
+
+ structural_track_association_trainer (
+ );
+ /*!
+ ensures
+ - #get_c() == 100
+ - this object isn't verbose
+ - #get_epsilon() == 0.001
+ - #get_num_threads() == 2
+ - #get_max_cache_size() == 5
+ - #learns_nonnegative_weights() == false
+ - #get_loss_per_track_break() == 1
+ - #get_loss_per_false_association() == 1
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ void set_epsilon (
+ double eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ double get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average number of association mistakes per
+ time step is within epsilon of its optimal value".
+ !*/
+
+ void set_max_cache_size (
+ unsigned long max_size
+ );
+ /*!
+ ensures
+ - #get_max_cache_size() == max_size
+ !*/
+
+ unsigned long get_max_cache_size (
+ ) const;
+ /*!
+ ensures
+ - During training, this object basically runs the track_association_function on
+ each training sample, over and over. To speed this up, it is possible to
+ cache the results of these invocations. This function returns the number
+ of cache elements per training sample kept in the cache. Note that a value
+ of 0 means caching is not used at all.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a user can
+ observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_loss_per_false_association (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_false_association() == loss
+ !*/
+
+ double get_loss_per_false_association (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for assigning a detection to the
+ wrong track. If you care more about avoiding false associations than
+ avoiding track breaks then you can increase this value.
+ !*/
+
+ void set_loss_per_track_break (
+ double loss
+ );
+ /*!
+ requires
+ - loss > 0
+ ensures
+ - #get_loss_per_track_break() == loss
+ !*/
+
+ double get_loss_per_track_break (
+ ) const;
+ /*!
+ ensures
+ - returns the amount of loss experienced for incorrectly assigning a
+ detection to a new track instead of assigning it to its existing track.
+ If you care more about avoiding track breaks than avoiding things like
+ track swaps then you can increase this value.
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - Internally this object treats track association learning as a structural
+ SVM problem. This routine returns a copy of the optimizer used to solve
+ the structural SVM problem.
+ !*/
+
+ void set_c (
+ double C
+ );
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade-off between trying to fit the training data (i.e.
+ minimize the loss) or allowing more errors but hopefully improving the
+ generalization of the resulting track_association_function. Larger
+ values encourage exact fitting while smaller values of C may encourage
+ better generalization.
+ !*/
+
+ double get_c (
+ ) const;
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() = C
+ !*/
+
+ bool learns_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - Ultimately, the output of training is a parameter vector that defines the
+ behavior of the track_association_function. If
+ learns_nonnegative_weights() == true then the resulting learned parameter
+ vector will always have non-negative entries.
+ !*/
+
+ void set_learns_nonnegative_weights (
+ bool value
+ );
+ /*!
+ ensures
+ - #learns_nonnegative_weights() == value
+ !*/
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample
+ ) const;
+ /*!
+ requires
+ - is_track_association_problem(sample) == true
+ ensures
+ - This function attempts to learn to do track association from the given
+ training data. Note that we interpret sample as a single track history such
+ that sample[0] are all detections from the first time step, then sample[1]
+ are detections from the second time step, and so on.
+ - returns a function F such that:
+ - Executing F(tracks, detections) will try to correctly associate the
+ contents of detections to the contents of tracks and perform track
+ updating and creation.
+ - if (learns_nonnegative_weights() == true) then
+ - min(F.get_assignment_function().get_weights()) >= 0
+ !*/
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& sample
+ ) const;
+ /*!
+ requires
+ - is_track_association_problem(samples) == true
+ ensures
+ - This function attempts to learn to do track association from the given
+ training data. In this case, we take a set of track histories as
+ training data instead of just one track history as with the above train()
+ method.
+ - returns a function F such that:
+ - Executing F(tracks, detections) will try to correctly associate the
+ contents of detections to the contents of tracks and perform track
+ updating and creation.
+ - if (learns_nonnegative_weights() == true) then
+ - min(F.get_assignment_function().get_weights()) >= 0
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/svm.h b/ml/dlib/dlib/svm/svm.h
new file mode 100644
index 000000000..e0587ef4a
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm.h
@@ -0,0 +1,1205 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_
+#define DLIB_SVm_
+
+#include "svm_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "../rand.h"
+#include "../std_allocator.h"
+#include "function.h"
+#include "kernel.h"
+#include "../enable_if.h"
+#include "../optimization.h"
+#include "svm_nu_trainer.h"
+#include <vector>
+#include <set>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ inline bool is_learning_problem_impl (
+ const T& x,
+ const U& x_labels
+ )
+ {
+ return is_col_vector(x) &&
+ is_col_vector(x_labels) &&
+ x.size() == x_labels.size() &&
+ x.size() > 0;
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ inline bool is_learning_problem (
+ const T& x,
+ const U& x_labels
+ )
+ {
+ return is_learning_problem_impl(mat(x), mat(x_labels));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ bool is_binary_classification_problem_impl (
+ const T& x,
+ const U& x_labels
+ )
+ {
+ bool seen_neg_class = false;
+ bool seen_pos_class = false;
+
+ if (is_learning_problem_impl(x,x_labels) == false)
+ return false;
+
+ if (x.size() <= 1) return false;
+
+ for (long r = 0; r < x_labels.nr(); ++r)
+ {
+ if (x_labels(r) != -1 && x_labels(r) != 1)
+ return false;
+
+ if (x_labels(r) == 1)
+ seen_pos_class = true;
+ if (x_labels(r) == -1)
+ seen_neg_class = true;
+ }
+
+ return seen_pos_class && seen_neg_class;
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ bool is_binary_classification_problem (
+ const T& x,
+ const U& x_labels
+ )
+ {
+ return is_binary_classification_problem_impl(mat(x), mat(x_labels));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double,1,2> test_binary_decision_function_impl (
+ const dec_funct_type& dec_funct,
+ const in_sample_vector_type& x_test,
+ const in_scalar_vector_type& y_test
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_binary_classification_problem(x_test,y_test) == true,
+ "\tmatrix test_binary_decision_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_binary_classification_problem(x_test,y_test): "
+ << ((is_binary_classification_problem(x_test,y_test))? "true":"false"));
+
+
+ // count the number of positive and negative examples
+ long num_pos = 0;
+ long num_neg = 0;
+
+
+ long num_pos_correct = 0;
+ long num_neg_correct = 0;
+
+
+ // now test this trained object
+ for (long i = 0; i < x_test.nr(); ++i)
+ {
+ // if this is a positive example
+ if (y_test(i) == +1.0)
+ {
+ ++num_pos;
+ if (dec_funct(x_test(i)) >= 0)
+ ++num_pos_correct;
+ }
+ else if (y_test(i) == -1.0)
+ {
+ ++num_neg;
+ if (dec_funct(x_test(i)) < 0)
+ ++num_neg_correct;
+ }
+ else
+ {
+ throw dlib::error("invalid input labels to the test_binary_decision_function() function");
+ }
+ }
+
+
+ matrix<double, 1, 2> res;
+ res(0) = (double)num_pos_correct/(double)(num_pos);
+ res(1) = (double)num_neg_correct/(double)(num_neg);
+ return res;
+ }
+
+ template <
+ typename dec_funct_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double,1,2> test_binary_decision_function (
+ const dec_funct_type& dec_funct,
+ const in_sample_vector_type& x_test,
+ const in_scalar_vector_type& y_test
+ )
+ {
+ return test_binary_decision_function_impl(dec_funct,
+ mat(x_test),
+ mat(y_test));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_type
+ >
+ bool is_sequence_labeling_problem (
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels
+ )
+ {
+ if (is_learning_problem(samples, labels))
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].size() != labels[i].size())
+ return false;
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_type
+ >
+ bool is_sequence_segmentation_problem (
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
+ )
+ {
+ if (is_learning_problem(samples, segments))
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // Make sure the segments are inside samples[i] and don't overlap with each
+ // other.
+ std::vector<bool> hits(samples[i].size(), false);
+ for (unsigned long j = 0; j < segments[i].size(); ++j)
+ {
+ const unsigned long begin = segments[i][j].first;
+ const unsigned long end = segments[i][j].second;
+ // if the segment is outside the sequence
+ if (end > samples[i].size())
+ return false;
+
+ if (begin >= end)
+ return false;
+
+ // check for overlap
+ for (unsigned long k = begin; k < end; ++k)
+ {
+ if (hits[k])
+ return false;
+ hits[k] = true;
+ }
+ }
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lhs_type,
+ typename rhs_type
+ >
+ bool is_assignment_problem (
+ const std::vector<std::pair<std::vector<lhs_type>, std::vector<rhs_type> > >& samples,
+ const std::vector<std::vector<long> >& labels
+ )
+ {
+ std::vector<bool> seen_label;
+
+ if (is_learning_problem(samples, labels))
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].first.size() != labels[i].size())
+ return false;
+
+ seen_label.assign(samples[i].second.size(), false);
+
+ for (unsigned long j = 0; j < labels[i].size(); ++j)
+ {
+ if (!(-1 <= labels[i][j] && labels[i][j] < (long)samples[i].second.size()))
+ return false;
+
+ if (labels[i][j] != -1)
+ {
+ // check label uniqueness
+ if (seen_label[labels[i][j]])
+ return false;
+
+ seen_label[labels[i][j]] = true;
+ }
+ }
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lhs_type,
+ typename rhs_type
+ >
+ bool is_forced_assignment_problem (
+ const std::vector<std::pair<std::vector<lhs_type>, std::vector<rhs_type> > >& samples,
+ const std::vector<std::vector<long> >& labels
+ )
+ {
+ if (is_assignment_problem(samples, labels))
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ const unsigned long N = sum(mat(labels[i]) != -1);
+ if (std::min(samples[i].first.size(), samples[i].second.size()) != N)
+ return false;
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type_,
+ typename label_type_ = long
+ >
+ struct labeled_detection
+ {
+ typedef detection_type_ detection_type;
+ typedef label_type_ label_type;
+ detection_type det;
+ label_type label;
+ };
+
+ template <
+ typename detection_type_,
+ typename label_type_
+ >
+ inline void serialize ( const labeled_detection<detection_type_,label_type_>& item, std::ostream& out)
+ {
+ serialize(item.det, out);
+ serialize(item.label, out);
+ }
+
+ template <
+ typename detection_type_,
+ typename label_type_
+ >
+ inline void deserialize (labeled_detection<detection_type_,label_type_>& item, std::istream& in)
+ {
+ deserialize(item.det, in);
+ deserialize(item.label, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ bool is_track_association_problem (
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples
+ )
+ {
+ if (samples.size() == 0)
+ return false;
+
+ unsigned long num_nonzero_elements = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples.size() > 0)
+ ++num_nonzero_elements;
+ }
+ if (num_nonzero_elements < 2)
+ return false;
+
+ // now make sure the label_type values are unique within each time step.
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ std::set<label_type> vals;
+ for (unsigned long j = 0; j < samples[i].size(); ++j)
+ vals.insert(samples[i][j].label);
+ if (vals.size() != samples[i].size())
+ return false;
+ }
+
+ // passed all tests so it's good
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ bool is_track_association_problem (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ )
+ {
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (!is_track_association_problem(samples[i]))
+ return false;
+ }
+
+ // passed all tests so it's good
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer_impl (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds
+ )
+ {
+ typedef typename in_scalar_vector_type::value_type scalar_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true &&
+ 1 < folds && folds <= std::min(sum(y>0),sum(y<0)),
+ "\tmatrix cross_validate_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t std::min(sum(y>0),sum(y<0)): " << std::min(sum(y>0),sum(y<0))
+ << "\n\t folds: " << folds
+ << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
+ );
+
+
+ // count the number of positive and negative examples
+ long num_pos = 0;
+ long num_neg = 0;
+ for (long r = 0; r < y.nr(); ++r)
+ {
+ if (y(r) == +1.0)
+ ++num_pos;
+ else
+ ++num_neg;
+ }
+
+ // figure out how many positive and negative examples we will have in each fold
+ const long num_pos_test_samples = num_pos/folds;
+ const long num_pos_train_samples = num_pos - num_pos_test_samples;
+ const long num_neg_test_samples = num_neg/folds;
+ const long num_neg_train_samples = num_neg - num_neg_test_samples;
+
+
+ matrix<long,0,1> x_test, x_train;
+ scalar_vector_type y_test, y_train;
+ x_test.set_size (num_pos_test_samples + num_neg_test_samples);
+ y_test.set_size (num_pos_test_samples + num_neg_test_samples);
+ x_train.set_size(num_pos_train_samples + num_neg_train_samples);
+ y_train.set_size(num_pos_train_samples + num_neg_train_samples);
+
+ long pos_idx = 0;
+ long neg_idx = 0;
+
+ matrix<double, 1, 2, mem_manager_type> res;
+ set_all_elements(res,0);
+
+ for (long i = 0; i < folds; ++i)
+ {
+ long cur = 0;
+
+ // load up our positive test samples
+ while (cur < num_pos_test_samples)
+ {
+ if (y(pos_idx) == +1.0)
+ {
+ x_test(cur) = pos_idx;
+ y_test(cur) = +1.0;
+ ++cur;
+ }
+ pos_idx = (pos_idx+1)%x.nr();
+ }
+
+ // load up our negative test samples
+ while (cur < x_test.nr())
+ {
+ if (y(neg_idx) == -1.0)
+ {
+ x_test(cur) = neg_idx;
+ y_test(cur) = -1.0;
+ ++cur;
+ }
+ neg_idx = (neg_idx+1)%x.nr();
+ }
+
+ // load the training data from the data following whatever we loaded
+ // as the testing data
+ long train_pos_idx = pos_idx;
+ long train_neg_idx = neg_idx;
+ cur = 0;
+
+ // load up our positive train samples
+ while (cur < num_pos_train_samples)
+ {
+ if (y(train_pos_idx) == +1.0)
+ {
+ x_train(cur) = train_pos_idx;
+ y_train(cur) = +1.0;
+ ++cur;
+ }
+ train_pos_idx = (train_pos_idx+1)%x.nr();
+ }
+
+ // load up our negative train samples
+ while (cur < x_train.nr())
+ {
+ if (y(train_neg_idx) == -1.0)
+ {
+ x_train(cur) = train_neg_idx;
+ y_train(cur) = -1.0;
+ ++cur;
+ }
+ train_neg_idx = (train_neg_idx+1)%x.nr();
+ }
+
+ try
+ {
+ // do the training and testing
+ res += test_binary_decision_function(trainer.train(rowm(x,x_train),y_train),rowm(x,x_test),y_test);
+ }
+ catch (invalid_nu_error&)
+ {
+ // Just ignore the error in this case since we are going to
+ // interpret an invalid nu value the same as generating a decision
+ // function that miss-classifies everything.
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ return res/(double)folds;
+ }
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds
+ )
+ {
+ return cross_validate_trainer_impl(trainer,
+ mat(x),
+ mat(y),
+ folds);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace prob_impl
+ {
+ template <typename vect_type>
+ struct objective
+ {
+ objective (
+ const vect_type& f_,
+ const vect_type& t_
+ ) : f(f_), t(t_) {}
+
+ double operator() (
+ const matrix<double,2,1>& x
+ ) const
+ {
+ const double A = x(0);
+ const double B = x(1);
+
+ double res = 0;
+ for (unsigned long i = 0; i < f.size(); ++i)
+ {
+ const double val = A*f[i]+B;
+ // See the paper "A Note on Platt's Probabilistic Outputs for Support Vector Machines"
+ // for an explanation of why this code looks the way it does (rather than being the
+ // obvious formula).
+ if (val < 0)
+ res += (t[i] - 1)*val + std::log(1 + std::exp(val));
+ else
+ res += t[i]*val + std::log(1 + std::exp(-val));
+ }
+
+ return res;
+ }
+
+ const vect_type& f;
+ const vect_type& t;
+ };
+
+ template <typename vect_type>
+ struct der
+ {
+ der (
+ const vect_type& f_,
+ const vect_type& t_
+ ) : f(f_), t(t_) {}
+
+ matrix<double,2,1> operator() (
+ const matrix<double,2,1>& x
+ ) const
+ {
+ const double A = x(0);
+ const double B = x(1);
+
+ double derA = 0;
+ double derB = 0;
+
+ for (unsigned long i = 0; i < f.size(); ++i)
+ {
+ const double val = A*f[i]+B;
+ double p;
+ // compute p = 1/(1+exp(val))
+ // but do so in a way that avoids numerical overflow.
+ if (val < 0)
+ p = 1.0/(1 + std::exp(val));
+ else
+ p = std::exp(-val)/(1 + std::exp(-val));
+
+ derA += f[i]*(t[i] - p);
+ derB += (t[i] - p);
+ }
+
+ matrix<double,2,1> res;
+ res = derA, derB;
+ return res;
+ }
+
+ const vect_type& f;
+ const vect_type& t;
+ };
+
+ template <typename vect_type>
+ struct hessian
+ {
+ hessian (
+ const vect_type& f_,
+ const vect_type& t_
+ ) : f(f_), t(t_) {}
+
+ matrix<double,2,2> operator() (
+ const matrix<double,2,1>& x
+ ) const
+ {
+ const double A = x(0);
+ const double B = x(1);
+
+ matrix<double,2,2> h;
+ h = 0;
+
+ for (unsigned long i = 0; i < f.size(); ++i)
+ {
+ const double val = A*f[i]+B;
+ // compute pp = 1/(1+exp(val)) and
+ // compute pn = 1 - pp
+ // but do so in a way that avoids numerical overflow and catastrophic cancellation.
+ double pp, pn;
+ if (val < 0)
+ {
+ const double temp = std::exp(val);
+ pp = 1.0/(1 + temp);
+ pn = temp*pp;
+ }
+ else
+ {
+ const double temp = std::exp(-val);
+ pn = 1.0/(1 + temp);
+ pp = temp*pn;
+ }
+
+ h(0,0) += f[i]*f[i]*pp*pn;
+ const double temp2 = f[i]*pp*pn;
+ h(0,1) += temp2;
+ h(1,0) += temp2;
+ h(1,1) += pp*pn;
+ }
+
+ return h;
+ }
+
+ const vect_type& f;
+ const vect_type& t;
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline double platt_scale (
+ const std::pair<double,double>& params,
+ const double score
+ )
+ {
+ return 1/(1 + std::exp(params.first*score + params.second));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ std::pair<double,double> learn_platt_scaling (
+ const std::vector<T,alloc>& scores,
+ const std::vector<T,alloc>& labels
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(scores,labels) == true,
+ "\t std::pair<T,T> learn_platt_scaling()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t scores.size(): " << scores.size()
+ << "\n\t labels.size(): " << labels.size()
+ << "\n\t is_binary_classification_problem(scores,labels): " << is_binary_classification_problem(scores,labels)
+ );
+
+ const T num_pos = sum(mat(labels)>0);
+ const T num_neg = sum(mat(labels)<0);
+ const T hi_target = (num_pos+1)/(num_pos+2);
+ const T lo_target = 1.0/(num_neg+2);
+
+ std::vector<T,alloc> target;
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ // if this was a positive example
+ if (labels[i] == +1.0)
+ {
+ target.push_back(hi_target);
+ }
+ else if (labels[i] == -1.0)
+ {
+ target.push_back(lo_target);
+ }
+ else
+ {
+ throw dlib::error("invalid input labels to the learn_platt_scaling() function.");
+ }
+ }
+
+ // Now find the maximum likelihood parameters of the sigmoid.
+
+ prob_impl::objective<std::vector<T,alloc> > obj(scores, target);
+ prob_impl::der<std::vector<T,alloc> > obj_der(scores, target);
+ prob_impl::hessian<std::vector<T,alloc> > obj_hessian(scores, target);
+
+ matrix<double,2,1> val;
+ val = 0;
+ find_min(newton_search_strategy(obj_hessian),
+ objective_delta_stop_strategy(),
+ obj,
+ obj_der,
+ val,
+ 0);
+
+ const double A = val(0);
+ const double B = val(1);
+
+ return std::make_pair(A,B);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sample_vector_type,
+ typename label_vector_type
+ >
+ const probabilistic_function<typename trainer_type::trained_function_type>
+ train_probabilistic_decision_function (
+ const trainer_type& trainer,
+ const sample_vector_type& x,
+ const label_vector_type& y,
+ const long folds
+ )
+ {
+ typedef typename sample_vector_type::value_type sample_type;
+ typedef typename label_vector_type::value_type scalar_type;
+
+ /*
+ This function fits a sigmoid function to the output of the
+ svm trained by svm_nu_trainer or a similar trainer. The
+ technique used is the one described in the papers:
+
+ Probabilistic Outputs for Support Vector Machines and
+ Comparisons to Regularized Likelihood Methods by
+ John C. Platt. March 26, 1999
+
+ A Note on Platt's Probabilistic Outputs for Support Vector Machines
+ by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng
+ */
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true &&
+ 1 < folds && folds <= (long)x.size(),
+ "\tprobabilistic_decision_function train_probabilistic_decision_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
+ );
+
+ // count the number of positive and negative examples
+ const long num_pos = (long)sum(mat(y) > 0);
+ const long num_neg = (long)sum(mat(y) < 0);
+
+ // figure out how many positive and negative examples we will have in each fold
+ const long num_pos_test_samples = num_pos/folds;
+ const long num_pos_train_samples = num_pos - num_pos_test_samples;
+ const long num_neg_test_samples = num_neg/folds;
+ const long num_neg_train_samples = num_neg - num_neg_test_samples;
+
+ typename trainer_type::trained_function_type d;
+ std::vector<sample_type> x_test, x_train;
+ std::vector<scalar_type> y_test, y_train;
+ x_test.resize (num_pos_test_samples + num_neg_test_samples);
+ y_test.resize (num_pos_test_samples + num_neg_test_samples);
+ x_train.resize(num_pos_train_samples + num_neg_train_samples);
+ y_train.resize(num_pos_train_samples + num_neg_train_samples);
+
+ std::vector<scalar_type> out, out_label;
+
+ long pos_idx = 0;
+ long neg_idx = 0;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ long cur = 0;
+
+ // load up our positive test samples
+ while (cur < num_pos_test_samples)
+ {
+ if (y[pos_idx] == +1.0)
+ {
+ x_test[cur] = x[pos_idx];
+ y_test[cur] = +1.0;
+ ++cur;
+ }
+ pos_idx = (pos_idx+1)%x.size();
+ }
+
+ // load up our negative test samples
+ while (cur < (long)x_test.size())
+ {
+ if (y[neg_idx] == -1.0)
+ {
+ x_test[cur] = x[neg_idx];
+ y_test[cur] = -1.0;
+ ++cur;
+ }
+ neg_idx = (neg_idx+1)%x.size();
+ }
+
+ // load the training data from the data following whatever we loaded
+ // as the testing data
+ long train_pos_idx = pos_idx;
+ long train_neg_idx = neg_idx;
+ cur = 0;
+
+ // load up our positive train samples
+ while (cur < num_pos_train_samples)
+ {
+ if (y[train_pos_idx] == +1.0)
+ {
+ x_train[cur] = x[train_pos_idx];
+ y_train[cur] = +1.0;
+ ++cur;
+ }
+ train_pos_idx = (train_pos_idx+1)%x.size();
+ }
+
+ // load up our negative train samples
+ while (cur < (long)x_train.size())
+ {
+ if (y[train_neg_idx] == -1.0)
+ {
+ x_train[cur] = x[train_neg_idx];
+ y_train[cur] = -1.0;
+ ++cur;
+ }
+ train_neg_idx = (train_neg_idx+1)%x.size();
+ }
+
+ // do the training
+ d = trainer.train (x_train,y_train);
+
+ // now test this fold
+ for (unsigned long i = 0; i < x_test.size(); ++i)
+ {
+ out.push_back(d(x_test[i]));
+ out_label.push_back(y_test[i]);
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ std::pair<double,double> params = learn_platt_scaling(out, out_label);
+
+ const double A = params.first;
+ const double B = params.second;
+
+ return probabilistic_function<typename trainer_type::trained_function_type>( A, B, trainer.train(x,y) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename trainer_type>
+ struct trainer_adapter_probabilistic
+ {
+ typedef probabilistic_function<typename trainer_type::trained_function_type> trained_function_type;
+
+ const trainer_type trainer;
+ const long folds;
+
+ trainer_adapter_probabilistic (
+ const trainer_type& trainer_,
+ const long folds_
+ ) : trainer(trainer_),folds(folds_) {}
+
+ template <
+ typename T,
+ typename U
+ >
+ const trained_function_type train (
+ const T& samples,
+ const U& labels
+ ) const
+ {
+ return train_probabilistic_decision_function(trainer, samples, labels, folds);
+ }
+
+ };
+
+ template <
+ typename trainer_type
+ >
+ trainer_adapter_probabilistic<trainer_type> probabilistic (
+ const trainer_type& trainer,
+ const long folds
+ )
+ {
+ return trainer_adapter_probabilistic<trainer_type>(trainer,folds);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V,
+ typename rand_type
+ >
+ typename enable_if<is_matrix<T>,void>::type randomize_samples (
+ T& t,
+ U& u,
+ V& v,
+ rand_type& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(t) && is_vector(u) && is_vector(v) && u.size() == t.size() &&
+ u.size() == v.size(),
+ "\t randomize_samples(t,u,v)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t t.size(): " << t.size()
+ << "\n\t u.size(): " << u.size()
+ << "\n\t v.size(): " << v.size()
+ << "\n\t is_vector(t): " << is_vector(t)
+ << "\n\t is_vector(u): " << is_vector(u)
+ << "\n\t is_vector(v): " << is_vector(v)
+ );
+
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t(idx), t(n));
+ exchange(u(idx), u(n));
+ exchange(v(idx), v(n));
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V,
+ typename rand_type
+ >
+ typename disable_if<is_matrix<T>,void>::type randomize_samples (
+ T& t,
+ U& u,
+ V& v,
+ rand_type& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(u.size() == t.size() && u.size() == v.size(),
+ "\t randomize_samples(t,u,v)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t t.size(): " << t.size()
+ << "\n\t u.size(): " << u.size()
+ << "\n\t v.size(): " << v.size()
+ );
+
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t[idx], t[n]);
+ exchange(u[idx], u[n]);
+ exchange(v[idx], v[n]);
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ typename disable_if<is_rand<V>,void>::type randomize_samples (
+ T& t,
+ U& u,
+ V& v
+ )
+ {
+ rand r;
+ randomize_samples(t,u,v,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename rand_type
+ >
+ typename enable_if_c<is_matrix<T>::value && is_rand<rand_type>::value,void>::type randomize_samples (
+ T& t,
+ U& u,
+ rand_type& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(t) && is_vector(u) && u.size() == t.size(),
+ "\t randomize_samples(t,u)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t t.size(): " << t.size()
+ << "\n\t u.size(): " << u.size()
+ << "\n\t is_vector(t): " << (is_vector(t)? "true" : "false")
+ << "\n\t is_vector(u): " << (is_vector(u)? "true" : "false")
+ );
+
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t(idx), t(n));
+ exchange(u(idx), u(n));
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename rand_type
+ >
+ typename disable_if_c<is_matrix<T>::value || !is_rand<rand_type>::value,void>::type randomize_samples (
+ T& t,
+ U& u,
+ rand_type& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(u.size() == t.size(),
+ "\t randomize_samples(t,u)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t t.size(): " << t.size()
+ << "\n\t u.size(): " << u.size()
+ );
+
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t[idx], t[n]);
+ exchange(u[idx], u[n]);
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ typename disable_if<is_rand<U>,void>::type randomize_samples (
+ T& t,
+ U& u
+ )
+ {
+ rand r;
+ randomize_samples(t,u,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ typename enable_if_c<is_matrix<T>::value && is_rand<rand_type>::value,void>::type randomize_samples (
+ T& t,
+ rand_type& r
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_vector(t),
+ "\t randomize_samples(t)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_vector(t): " << (is_vector(t)? "true" : "false")
+ );
+
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t(idx), t(n));
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ typename disable_if_c<(is_matrix<T>::value==true)||(is_rand<rand_type>::value==false),void>::type randomize_samples (
+ T& t,
+ rand_type& r
+ )
+ {
+ long n = t.size()-1;
+ while (n > 0)
+ {
+ // pick a random index to swap into t[n]
+ const unsigned long idx = r.get_random_32bit_number()%(n+1);
+
+ // swap our randomly selected index into the n position
+ exchange(t[idx], t[n]);
+
+ --n;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void randomize_samples (
+ T& t
+ )
+ {
+ rand r;
+ randomize_samples(t,r);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_
+
diff --git a/ml/dlib/dlib/svm/svm_abstract.h b/ml/dlib/dlib/svm/svm_abstract.h
new file mode 100644
index 000000000..ec92cf55b
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_abstract.h
@@ -0,0 +1,604 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_ABSTRACT_
+#ifdef DLIB_SVm_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "svm_nu_trainer_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ bool is_learning_problem (
+ const T& x,
+ const U& x_labels
+ );
+ /*!
+ requires
+ - T == a matrix or something convertible to a matrix via mat()
+ - U == a matrix or something convertible to a matrix via mat()
+ ensures
+ - returns true if all of the following are true and false otherwise:
+ - is_col_vector(x) == true
+ - is_col_vector(x_labels) == true
+ - x.size() == x_labels.size()
+ - x.size() > 0
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ bool is_binary_classification_problem (
+ const T& x,
+ const U& x_labels
+ );
+ /*!
+ requires
+ - T == a matrix or something convertible to a matrix via mat()
+ - U == a matrix or something convertible to a matrix via mat()
+ ensures
+ - returns true if all of the following are true and false otherwise:
+ - is_learning_problem(x, x_labels) == true
+ - x.size() > 1
+ - there exists at least one sample from both the +1 and -1 classes.
+ (i.e. all samples can't have the same label)
+ - for all valid i:
+ - x_labels(i) == -1 or +1
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_type
+ >
+ bool is_sequence_labeling_problem (
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels
+ );
+ /*!
+ ensures
+ - returns true if all of the following are true and false otherwise:
+ - is_learning_problem(samples, labels) == true
+ - for all valid i:
+ - samples[i].size() == labels[i].size()
+ (i.e. The size of a label sequence need to match the size of
+ its corresponding sample sequence)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_type
+ >
+ bool is_sequence_segmentation_problem (
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
+ );
+ /*!
+ ensures
+ - Note that a sequence segmentation problem is a task where you are given a
+ sequence of objects (e.g. words in a sentence) and your task is to find
+ certain types of sub-sequences (e.g. proper names).
+ - returns true if all of the following are true and false otherwise:
+ - is_learning_problem(samples, segments) == true
+ - for all valid i and j:
+ - We interpret segments[i][j] as defining a half open range starting
+ with segments[i][j].first and ending just before segments[i][j].second.
+ - segments[i][j].first < segments[i][j].second
+ - segments[i][j].second <= samples[i].size()
+ (i.e. Each segment must be contained within its associated sequence)
+ - segments[i][j] does not overlap with any of the other ranges in
+ segments[i].
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lhs_type,
+ typename rhs_type
+ >
+ bool is_assignment_problem (
+ const std::vector<std::pair<std::vector<lhs_type>, std::vector<rhs_type> > >& samples,
+ const std::vector<std::vector<long> >& labels
+ );
+ /*!
+ ensures
+ - Note that an assignment problem is a task to associate each element of samples[i].first
+ to an element of samples[i].second, or to indicate that the element doesn't associate
+ with anything. Therefore, labels[i] should contain the association information for
+ samples[i].
+ - This function returns true if all of the following are true and false otherwise:
+ - is_learning_problem(samples, labels) == true
+ - for all valid i:
+ - samples[i].first.size() == labels[i].size()
+ - for all valid j:
+ -1 <= labels[i][j] < samples[i].second.size()
+ (A value of -1 indicates that samples[i].first[j] isn't associated with anything.
+ All other values indicate the associating element of samples[i].second)
+ - All elements of labels[i] which are not equal to -1 are unique. That is,
+ multiple elements of samples[i].first can't associate to the same element
+ in samples[i].second.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename lhs_type,
+ typename rhs_type
+ >
+ bool is_forced_assignment_problem (
+ const std::vector<std::pair<std::vector<lhs_type>, std::vector<rhs_type> > >& samples,
+ const std::vector<std::vector<long> >& labels
+ );
+ /*!
+ ensures
+ - A regular assignment problem is allowed to indicate that all elements of
+ samples[i].first don't associate to anything. However, a forced assignment
+ problem is required to always associate an element of samples[i].first to
+ something in samples[i].second if there is an element of samples[i].second
+ that hasn't already been associated to something.
+ - This function returns true if all of the following are true and false otherwise:
+ - is_assignment_problem(samples, labels) == true
+ - for all valid i:
+ - let N denote the number of elements in labels[i] that are not equal to -1.
+ - min(samples[i].first.size(), samples[i].second.size()) == N
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type_,
+ typename label_type_ = long
+ >
+ struct labeled_detection
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a simple object, like std::pair, it just holds two objects. It
+ serves the same purpose as std::pair except that it has informative names
+ describing its two members and is intended for use with track association
+ problems.
+ !*/
+
+ typedef detection_type_ detection_type;
+ typedef label_type_ label_type;
+
+ detection_type det;
+ label_type label;
+ };
+
+ template <
+ typename detection_type_,
+ typename label_type_
+ >
+ void serialize (const labeled_detection<detection_type_,label_type_>& item, std::ostream& out);
+ /*!
+ provides serialization support
+ !*/
+
+ template <
+ typename detection_type_,
+ typename label_type_
+ >
+ void deserialize (labeled_detection<detection_type_,label_type_>& item, std::istream& in);
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ bool is_track_association_problem (
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples
+ );
+ /*!
+ ensures
+ - In this tracking model you get a set of detections at each time step and are
+ expected to associate each detection with a track or have it spawn a new
+ track. Therefore, a track association problem is a machine learning problem
+ where you are given a dataset of example input detections and are expected to
+ learn to perform the proper detection to track association.
+ - This function checks if samples can form a valid dataset for this machine
+ learning problem and returns true if this is the case. This means we should
+ interpret samples in the following way:
+ - samples is a track history and for each valid i:
+ - samples[i] is a set of labeled detections from the i-th time step.
+ Each detection has been labeled with its "true object identity".
+ That is, all the detection throughout the history with the same
+ label_type value are detections from the same object and therefore
+ should be associated to the same track.
+ Putting this all together, samples is a valid track association learning
+ problem if and only if the following are all true:
+ - samples.size() > 0
+ - There are at least two values, i and j such that:
+ - i != j
+ - samples[i].size() > 0
+ - samples[j].size() > 0
+ Or in other words, there needs to be some detections in samples somewhere
+ or it is impossible to learn anything.
+ - for all valid i:
+ - for all valid j and k where j!=k:
+ - samples[i][j].label != samples[i][k].label
+ (i.e. the label_type values must be unique within each time step.
+ Or in other words, you can't have two detections on the same
+ object in a single time step.)
+ !*/
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ bool is_track_association_problem (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ );
+ /*!
+ ensures
+ - returns true if is_track_association_problem(samples[i]) == true for all
+ valid i and false otherwise.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ double platt_scale (
+ const std::pair<double,double>& params,
+ const double score
+ );
+ /*!
+ ensures
+ - returns 1/(1 + std::exp(params.first*score + params.second))
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T, typename alloc>
+ std::pair<double,double> learn_platt_scaling (
+ const std::vector<T,alloc>& scores,
+ const std::vector<T,alloc>& labels
+ );
+ /*!
+ requires
+ - T should be either float, double, or long double
+ - is_binary_classification_problem(scores,labels) == true
+ ensures
+ - This function learns to map scalar values into well calibrated probabilities
+ using Platt scaling. In particular, it returns a params object such that,
+ for all valid i:
+ - platt_scale(params,scores[i]) == the scaled version of the scalar value
+ scores[i]. That is, the output is a number between 0 and 1. In
+ particular, platt_scale(params,scores[i]) is meant to represent the
+ probability that labels[i] == +1.
+ - This function is an implementation of the algorithm described in the following
+ papers:
+ Probabilistic Outputs for Support Vector Machines and Comparisons to
+ Regularized Likelihood Methods by John C. Platt. March 26, 1999
+
+ A Note on Platt's Probabilistic Outputs for Support Vector Machines
+ by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sample_vector_type,
+ typename label_vector_type
+ >
+ const probabilistic_function<typename trainer_type::trained_function_type>
+ train_probabilistic_decision_function (
+ const trainer_type& trainer,
+ const sample_vector_type& x,
+ const label_vector_type& y,
+ const long folds
+ );
+ /*!
+ requires
+ - 1 < folds <= x.size()
+ - is_binary_classification_problem(x,y) == true
+ - x and y must be std::vector objects or types with a compatible interface.
+ - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
+ ensures
+ - trains a classifier given the training samples in x and labels in y.
+ - returns a probabilistic_decision_function that represents the trained classifier.
+ - The parameters of the probability model are estimated by performing k-fold
+ cross validation.
+ - The number of folds used is given by the folds argument.
+ - This function is implemented using learn_platt_scaling()
+ throws
+ - any exceptions thrown by trainer.train()
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ trainer_adapter_probabilistic<trainer_type> probabilistic (
+ const trainer_type& trainer,
+ const long folds
+ );
+ /*!
+ requires
+ - 1 < folds <= x.size()
+ - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
+ ensures
+ - returns a trainer adapter TA such that calling TA.train(samples, labels)
+ returns the same object as calling train_probabilistic_decision_function(trainer,samples,labels,folds).
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Miscellaneous functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double,1,2> cross_validate_trainer (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds
+ );
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - 1 < folds <= std::min(sum(y>0),sum(y<0))
+ (e.g. There must be at least as many examples of each class as there are folds)
+ - trainer_type == some kind of binary classification trainer object (e.g. svm_nu_trainer)
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given binary classification problem for the given number of folds.
+ Each fold is tested using the output of the trainer and the average
+ classification accuracy from all folds is returned.
+ - The average accuracy is computed by running test_binary_decision_function()
+ on each fold and its output is averaged and returned.
+ - The number of folds used is given by the folds argument.
+ throws
+ - any exceptions thrown by trainer.train()
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double,1,2> test_binary_decision_function (
+ const dec_funct_type& dec_funct,
+ const in_sample_vector_type& x_test,
+ const in_scalar_vector_type& y_test
+ );
+ /*!
+ requires
+ - is_binary_classification_problem(x_test,y_test) == true
+ - dec_funct_type == some kind of decision function object (e.g. decision_function)
+ ensures
+ - Tests the given decision function by calling it on the x_test and y_test samples.
+ The output of dec_funct is interpreted as a prediction for the +1 class
+ if its output is >= 0 and as a prediction for the -1 class otherwise.
+ - The test accuracy is returned in a row vector, let us call it R. Both
+ quantities in R are numbers between 0 and 1 which represent the fraction
+ of examples correctly classified. R(0) is the fraction of +1 examples
+ correctly classified and R(1) is the fraction of -1 examples correctly
+ classified.
+ throws
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U
+ >
+ void randomize_samples (
+ T& samples,
+ U& labels
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - U == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - if samples or labels are matrix objects then is_vector(samples) == true and
+ is_vector(labels) == true
+ - samples.size() == labels.size()
+ ensures
+ - randomizes the order of the samples and labels but preserves
+ the pairing between each sample and its label
+ - A default initialized random number generator is used to perform the randomizing.
+ Note that this means that each call this this function does the same thing.
+ That is, the random number generator always uses the same seed.
+ - for all valid i:
+ - let r == the random index samples(i) was moved to. then:
+ - #labels(r) == labels(i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename rand_type
+ >
+ void randomize_samples (
+ T& samples,
+ U& labels,
+ rand_type& rnd
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - U == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - if samples or labels are matrix objects then is_vector(samples) == true and
+ is_vector(labels) == true
+ - samples.size() == labels.size()
+ - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
+ ensures
+ - randomizes the order of the samples and labels but preserves
+ the pairing between each sample and its label
+ - the given rnd random number generator object is used to do the randomizing
+ - for all valid i:
+ - let r == the random index samples(i) was moved to. then:
+ - #labels(r) == labels(i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void randomize_samples (
+ T& samples
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - if (samples is a matrix) then
+ - is_vector(samples) == true
+ ensures
+ - randomizes the order of the elements inside samples
+ - A default initialized random number generator is used to perform the randomizing.
+ Note that this means that each call this this function does the same thing.
+ That is, the random number generator always uses the same seed.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename rand_type
+ >
+ void randomize_samples (
+ T& samples,
+ rand_type& rnd
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
+ - if (samples is a matrix) then
+ - is_vector(samples) == true
+ ensures
+ - randomizes the order of the elements inside samples
+ - the given rnd random number generator object is used to do the randomizing
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V
+ >
+ void randomize_samples (
+ T& samples,
+ U& labels,
+ V& auxiliary
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - U == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - V == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - if (samples, labels, or auxiliary are matrix objects) then
+ - is_vector(samples) == true
+ - is_vector(labels) == true
+ - is_vector(auxiliary) == true
+ - samples.size() == labels.size() == auxiliary.size()
+ ensures
+ - randomizes the order of the samples, labels, and auxiliary but preserves the
+ pairing between each sample, its label, and its auxiliary value.
+ - A default initialized random number generator is used to perform the
+ randomizing. Note that this means that each call this this function does the
+ same thing. That is, the random number generator always uses the same seed.
+ - for all valid i:
+ - let r == the random index samples(i) was moved to. then:
+ - #labels(r) == labels(i)
+ - #auxiliary(r) == auxiliary(i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ typename U,
+ typename V,
+ typename rand_type
+ >
+ void randomize_samples (
+ T& samples,
+ U& labels,
+ V& auxiliary,
+ rand_type& rnd
+ );
+ /*!
+ requires
+ - T == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - U == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - V == a matrix object or an object compatible with std::vector that contains
+ a swappable type.
+ - if (samples, labels, or auxiliary are matrix objects) then
+ - is_vector(samples) == true
+ - is_vector(labels) == true
+ - is_vector(auxiliary) == true
+ - samples.size() == labels.size() == auxiliary.size()
+ - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
+ ensures
+ - randomizes the order of the samples, labels, and auxiliary but preserves the
+ pairing between each sample, its label, and its auxiliary value.
+ - the given rnd random number generator object is used to do the randomizing
+ - for all valid i:
+ - let r == the random index samples(i) was moved to. then:
+ - #labels(r) == labels(i)
+ - #auxiliary(r) == auxiliary(i)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/svm/svm_c_ekm_trainer.h b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h
new file mode 100644
index 000000000..735e0f22e
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h
@@ -0,0 +1,636 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVM_C_EKm_TRAINER_Hh_
+#define DLIB_SVM_C_EKm_TRAINER_Hh_
+
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "empirical_kernel_map.h"
+#include "svm_c_linear_trainer.h"
+#include "svm_c_ekm_trainer_abstract.h"
+#include "../statistics.h"
+#include "../rand.h"
+#include <vector>
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class svm_c_ekm_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_c_ekm_trainer (
+ )
+ {
+ verbose = false;
+ ekm_stale = true;
+
+ initial_basis_size = 10;
+ basis_size_increment = 50;
+ max_basis_size = 300;
+ }
+
+ explicit svm_c_ekm_trainer (
+ const scalar_type& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t svm_c_ekm_trainer::svm_c_ekm_trainer()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+
+ ocas.set_c(C);
+ verbose = false;
+ ekm_stale = true;
+
+ initial_basis_size = 10;
+ basis_size_increment = 50;
+ max_basis_size = 300;
+ }
+
+ void set_epsilon (
+ scalar_type eps
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps > 0,
+ "\t void svm_c_ekm_trainer::set_epsilon()"
+ << "\n\t eps must be greater than 0"
+ << "\n\t eps: " << eps
+ << "\n\t this: " << this
+ );
+
+ ocas.set_epsilon(eps);
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return ocas.get_epsilon();
+ }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ ocas.set_max_iterations(max_iter);
+ }
+
+ unsigned long get_max_iterations (
+ )
+ {
+ return ocas.get_max_iterations();
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ ocas.be_quiet();
+ }
+
+ void be_very_verbose (
+ )
+ {
+ verbose = true;
+ ocas.be_verbose();
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ ocas.be_quiet();
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ ocas.set_oca(item);
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return ocas.get_oca();
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kern;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kern = k;
+ ekm_stale = true;
+ }
+
+ template <typename T>
+ void set_basis (
+ const T& basis_samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)),
+ "\tvoid svm_c_ekm_trainer::set_basis(basis_samples)"
+ << "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
+ << "\n\t basis_samples.size(): " << basis_samples.size()
+ << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples))
+ << "\n\t this: " << this
+ );
+
+ basis = mat(basis_samples);
+ ekm_stale = true;
+ }
+
+ bool basis_loaded(
+ ) const
+ {
+ return (basis.size() != 0);
+ }
+
+ void clear_basis (
+ )
+ {
+ basis.set_size(0);
+ ekm.clear();
+ ekm_stale = true;
+ }
+
+ unsigned long get_max_basis_size (
+ ) const
+ {
+ return max_basis_size;
+ }
+
+ void set_max_basis_size (
+ unsigned long max_basis_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_basis_size_ > 0,
+ "\t void svm_c_ekm_trainer::set_max_basis_size()"
+ << "\n\t max_basis_size_ must be greater than 0"
+ << "\n\t max_basis_size_: " << max_basis_size_
+ << "\n\t this: " << this
+ );
+
+ max_basis_size = max_basis_size_;
+ if (initial_basis_size > max_basis_size)
+ initial_basis_size = max_basis_size;
+ }
+
+ unsigned long get_initial_basis_size (
+ ) const
+ {
+ return initial_basis_size;
+ }
+
+ void set_initial_basis_size (
+ unsigned long initial_basis_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(initial_basis_size_ > 0,
+ "\t void svm_c_ekm_trainer::set_initial_basis_size()"
+ << "\n\t initial_basis_size_ must be greater than 0"
+ << "\n\t initial_basis_size_: " << initial_basis_size_
+ << "\n\t this: " << this
+ );
+
+ initial_basis_size = initial_basis_size_;
+
+ if (initial_basis_size > max_basis_size)
+ max_basis_size = initial_basis_size;
+ }
+
+ unsigned long get_basis_size_increment (
+ ) const
+ {
+ return basis_size_increment;
+ }
+
+ void set_basis_size_increment (
+ unsigned long basis_size_increment_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(basis_size_increment_ > 0,
+ "\t void svm_c_ekm_trainer::set_basis_size_increment()"
+ << "\n\t basis_size_increment_ must be greater than 0"
+ << "\n\t basis_size_increment_: " << basis_size_increment_
+ << "\n\t this: " << this
+ );
+
+ basis_size_increment = basis_size_increment_;
+ }
+
+ void set_c (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_ekm_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ ocas.set_c(C);
+ }
+
+ const scalar_type get_c_class1 (
+ ) const
+ {
+ return ocas.get_c_class1();
+ }
+
+ const scalar_type get_c_class2 (
+ ) const
+ {
+ return ocas.get_c_class2();
+ }
+
+ void set_c_class1 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_ekm_trainer::set_c_class1()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ ocas.set_c_class1(C);
+ }
+
+ void set_c_class2 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_ekm_trainer::set_c_class2()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ ocas.set_c_class2(C);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ scalar_type obj;
+ if (basis_loaded())
+ return do_train_user_basis(mat(x),mat(y),obj);
+ else
+ return do_train_auto_basis(mat(x),mat(y),obj);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const
+ {
+ if (basis_loaded())
+ return do_train_user_basis(mat(x),mat(y),svm_objective);
+ else
+ return do_train_auto_basis(mat(x),mat(y),svm_objective);
+ }
+
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train_user_basis (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const
+ /*!
+ requires
+ - basis_loaded() == true
+ ensures
+ - trains an SVM with the user supplied basis
+ !*/
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+ "\t decision_function svm_c_ekm_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
+ );
+
+ if (ekm_stale)
+ {
+ ekm.load(kern, basis);
+ ekm_stale = false;
+ }
+
+ // project all the samples with the ekm
+ running_stats<scalar_type> rs;
+ std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples;
+ proj_samples.reserve(x.size());
+ for (long i = 0; i < x.size(); ++i)
+ {
+ if (verbose)
+ {
+ scalar_type err;
+ proj_samples.push_back(ekm.project(x(i), err));
+ rs.add(err);
+ }
+ else
+ {
+ proj_samples.push_back(ekm.project(x(i)));
+ }
+ }
+
+ if (verbose)
+ {
+ std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
+ std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
+ }
+
+ // now do the training
+ decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
+ df = ocas.train(proj_samples, y, svm_objective);
+
+ if (verbose)
+ {
+ std::cout << "Final svm objective: " << svm_objective << std::endl;
+ }
+
+ decision_function<kernel_type> final_df;
+ final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
+ final_df.b = df.b;
+ return final_df;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train_auto_basis (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+ "\t decision_function svm_c_ekm_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
+ );
+
+
+ std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size());
+ decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
+
+ // we will use a linearly_independent_subset_finder to store our basis set.
+ linearly_independent_subset_finder<kernel_type> lisf(get_kernel(), max_basis_size);
+
+ dlib::rand rnd;
+
+ // first pick the initial basis set randomly
+ for (unsigned long i = 0; i < 10*initial_basis_size && lisf.size() < initial_basis_size; ++i)
+ {
+ lisf.add(x(rnd.get_random_32bit_number()%x.size()));
+ }
+
+ ekm.load(lisf);
+
+ // first project all samples into the span of the current basis
+ for (long i = 0; i < x.size(); ++i)
+ {
+ proj_samples[i] = ekm.project(x(i));
+ }
+
+
+ svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas);
+
+ const scalar_type min_epsilon = trainer.get_epsilon();
+ // while we are determining what the basis set will be we are going to use a very
+ // lose stopping condition. We will tighten it back up before producing the
+ // final decision_function.
+ trainer.set_epsilon(0.2);
+
+ scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max();
+
+ empirical_kernel_map<kernel_type> prev_ekm;
+
+ // This loop is where we try to generate a basis for SVM training. We will
+ // do this by repeatedly training the SVM and adding a few points which violate the
+ // margin to the basis in each iteration.
+ while (true)
+ {
+ // if the basis is already as big as it's going to get then just do the most
+ // accurate training right now.
+ if (lisf.size() == max_basis_size)
+ trainer.set_epsilon(min_epsilon);
+
+ while (true)
+ {
+ // now do the training.
+ df = trainer.train(proj_samples, y, svm_objective);
+
+ if (svm_objective < prev_svm_objective)
+ break;
+
+ // If the training didn't reduce the objective more than last time then
+ // try lowering the epsilon and doing it again.
+ if (trainer.get_epsilon() > min_epsilon)
+ {
+ trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon));
+ if (verbose)
+ std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl;
+ }
+ else
+ break;
+ }
+
+ if (verbose)
+ {
+ std::cout << "svm objective: " << svm_objective << std::endl;
+ std::cout << "basis size: " << lisf.size() << std::endl;
+ }
+
+ // if we failed to make progress on this iteration then we are done
+ if (svm_objective >= prev_svm_objective)
+ break;
+
+ prev_svm_objective = svm_objective;
+
+ // now add more elements to the basis
+ unsigned long count = 0;
+ for (unsigned long j = 0;
+ (j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.size() < max_basis_size);
+ ++j)
+ {
+ // pick a random sample
+ const unsigned long idx = rnd.get_random_32bit_number()%x.size();
+ // If it is a margin violator then it is useful to add it into the basis set.
+ if (df(proj_samples[idx])*y(idx) < 1)
+ {
+ // Add the sample into the basis set if it is linearly independent of all the
+ // vectors already in the basis set.
+ if (lisf.add(x(idx)))
+ {
+ ++count;
+ }
+ }
+ }
+ // if we couldn't add any more basis vectors then stop
+ if (count == 0)
+ {
+ if (verbose)
+ std::cout << "Stopping, couldn't add more basis vectors." << std::endl;
+ break;
+ }
+
+
+ // Project all the samples into the span of our newly enlarged basis. We will do this
+ // using the special transformation in the EKM that lets us project from a smaller
+ // basis set to a larger without needing to reevaluate kernel functions we have already
+ // computed.
+ ekm.swap(prev_ekm);
+ ekm.load(lisf);
+ projection_function<kernel_type> proj_part;
+ matrix<double> prev_to_new;
+ prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);
+
+
+ matrix<scalar_type,0,1, mem_manager_type> temp;
+ for (long i = 0; i < x.size(); ++i)
+ {
+ // assign to temporary to avoid memory allocation that would result if we
+ // assigned this expression straight into proj_samples[i]
+ temp = prev_to_new*proj_samples[i] + proj_part(x(i));
+ proj_samples[i] = temp;
+
+ }
+ }
+
+ // Reproject all the data samples using the final basis. We could just use what we
+ // already have but the recursive thing done above to compute the proj_samples
+ // might have accumulated a little numerical error. So lets just be safe.
+ running_stats<scalar_type> rs, rs_margin;
+ for (long i = 0; i < x.size(); ++i)
+ {
+ if (verbose)
+ {
+ scalar_type err;
+ proj_samples[i] = ekm.project(x(i),err);
+ rs.add(err);
+ // if this point is within the margin
+ if (df(proj_samples[i])*y(i) < 1)
+ rs_margin.add(err);
+ }
+ else
+ {
+ proj_samples[i] = ekm.project(x(i));
+ }
+ }
+
+ // do the final training
+ trainer.set_epsilon(min_epsilon);
+ df = trainer.train(proj_samples, y, svm_objective);
+
+
+ if (verbose)
+ {
+ std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
+ std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
+ std::cout << "Mean EKM projection error for margin violators: " << rs_margin.mean() << std::endl;
+ std::cout << "Standard deviation of EKM projection error for margin violators: " << ((rs_margin.current_n()>1)?rs_margin.stddev():0) << std::endl;
+
+ std::cout << "Final svm objective: " << svm_objective << std::endl;
+ }
+
+
+ decision_function<kernel_type> final_df;
+ final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
+ final_df.b = df.b;
+
+ // we don't need the ekm anymore so clear it out
+ ekm.clear();
+
+ return final_df;
+ }
+
+
+
+
+ /*!
+ CONVENTION
+ - if (ekm_stale) then
+ - kern or basis have changed since the last time
+ they were loaded into the ekm
+ !*/
+
+ svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > ocas;
+ bool verbose;
+
+ kernel_type kern;
+ unsigned long max_basis_size;
+ unsigned long basis_size_increment;
+ unsigned long initial_basis_size;
+
+
+ matrix<sample_type,0,1,mem_manager_type> basis;
+ mutable empirical_kernel_map<kernel_type> ekm;
+ mutable bool ekm_stale;
+
+ };
+
+}
+
+#endif // DLIB_SVM_C_EKm_TRAINER_Hh_
+
+
+
diff --git a/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h
new file mode 100644
index 000000000..d1ba2bf5f
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h
@@ -0,0 +1,384 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_
+
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "empirical_kernel_map_abstract.h"
+#include "svm_c_linear_trainer_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class svm_c_ekm_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for training the C formulation of
+ a support vector machine. It is implemented using the empirical_kernel_map
+ to kernelize the svm_c_linear_trainer. This makes it a very fast algorithm
+ capable of learning from very large datasets.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_c_ekm_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c_class1() == 1
+ - #get_c_class2() == 1
+ - #get_epsilon() == 0.001
+ - #basis_loaded() == false
+ - #get_initial_basis_size() == 10
+ - #get_basis_size_increment() == 50
+ - #get_max_basis_size() == 300
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ !*/
+
+ explicit svm_c_ekm_trainer (
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ - #get_epsilon() == 0.001
+ - #basis_loaded() == false
+ - #get_initial_basis_size() == 10
+ - #get_basis_size_increment() == 50
+ - #get_max_basis_size() == 300
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to execute.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_very_verbose (
+ );
+ /*!
+ ensures
+ - This object will print a lot of status messages to standard out so that a
+ user can observe the progress of the algorithm. In addition to the
+ few status messages normal verbosity produces this setting also causes
+ the underlying svm_c_linear_trainer to be verbose.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the SVM problem.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ template <typename T>
+ void set_basis (
+ const T& basis_samples
+ );
+ /*!
+ requires
+ - T must be a dlib::matrix type or something convertible to a matrix via mat()
+ (e.g. a std::vector)
+ - is_vector(basis_samples) == true
+ - basis_samples.size() > 0
+ - get_kernel() must be capable of operating on the elements of basis_samples. That is,
+ expressions such as get_kernel()(basis_samples(0), basis_samples(0)) should make sense.
+ ensures
+ - #basis_loaded() == true
+ - training will be carried out in the span of the given basis_samples
+ !*/
+
+ bool basis_loaded (
+ ) const;
+ /*!
+ ensures
+ - returns true if this object has been loaded with user supplied basis vectors and false otherwise.
+ !*/
+
+ void clear_basis (
+ );
+ /*!
+ ensures
+ - #basis_loaded() == false
+ !*/
+
+ unsigned long get_max_basis_size (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of basis vectors this object is allowed
+ to use. This parameter only matters when the user has not supplied
+ a basis via set_basis().
+ !*/
+
+ void set_max_basis_size (
+ unsigned long max_basis_size
+ );
+ /*!
+ requires
+ - max_basis_size > 0
+ ensures
+ - #get_max_basis_size() == max_basis_size
+ - if (get_initial_basis_size() > max_basis_size) then
+ - #get_initial_basis_size() == max_basis_size
+ !*/
+
+ unsigned long get_initial_basis_size (
+ ) const;
+ /*!
+ ensures
+ - If the user does not supply a basis via set_basis() then this object
+ will generate one automatically. It does this by starting with
+ a small basis of size N and repeatedly adds basis vectors to it
+ until a stopping condition is reached. This function returns that
+ initial size N.
+ !*/
+
+ void set_initial_basis_size (
+ unsigned long initial_basis_size
+ );
+ /*!
+ requires
+ - initial_basis_size > 0
+ ensures
+ - #get_initial_basis_size() == initial_basis_size
+ - if (initial_basis_size > get_max_basis_size()) then
+ - #get_max_basis_size() == initial_basis_size
+ !*/
+
+ unsigned long get_basis_size_increment (
+ ) const;
+ /*!
+ ensures
+ - If the user does not supply a basis via set_basis() then this object
+ will generate one automatically. It does this by starting with a small
+ basis and repeatedly adds sets of N basis vectors to it until a stopping
+ condition is reached. This function returns that increment size N.
+ !*/
+
+ void set_basis_size_increment (
+ unsigned long basis_size_increment
+ );
+ /*!
+ requires
+ - basis_size_increment > 0
+ ensures
+ - #get_basis_size_increment() == basis_size_increment
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ !*/
+
+ const scalar_type get_c_class1 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the +1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the +1 training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Larger values encourage exact fitting
+ while smaller values of C may encourage better generalization.
+ !*/
+
+ const scalar_type get_c_class2 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the -1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the -1 training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Larger values encourage exact fitting
+ while smaller values of C may encourage better generalization.
+ !*/
+
+ void set_c_class1 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ !*/
+
+ void set_c_class2 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class2() == C
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - if (basis_loaded()) then
+ - training will be carried out in the span of the user supplied basis vectors
+ - else
+ - this object will attempt to automatically select an appropriate basis
+
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const;
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - if (basis_loaded()) then
+ - training will be carried out in the span of the user supplied basis vectors
+ - else
+ - this object will attempt to automatically select an appropriate basis
+
+ - #svm_objective == the final value of the SVM objective function
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h
new file mode 100644
index 000000000..039b70993
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h
@@ -0,0 +1,712 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_
+#define DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_
+
+#include "svm_c_linear_dcd_trainer_abstract.h"
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../rand.h"
+#include "svm.h"
+
+#include "function.h"
+#include "kernel.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_c_linear_dcd_trainer
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear
+ // kernel to the svm_c_linear_dcd_trainer object. You have to use one of the
+ // linear kernels with this trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+ svm_c_linear_dcd_trainer (
+ ) :
+ Cpos(1),
+ Cneg(1),
+ eps(0.1),
+ max_iterations(10000),
+ verbose(false),
+ have_bias(true),
+ last_weight_1(false),
+ do_shrinking(true),
+ do_svm_l2(false)
+ {
+ }
+
+ explicit svm_c_linear_dcd_trainer (
+ const scalar_type& C_
+ ) :
+ Cpos(C_),
+ Cneg(C_),
+ eps(0.1),
+ max_iterations(10000),
+ verbose(false),
+ have_bias(true),
+ last_weight_1(false),
+ do_shrinking(true),
+ do_svm_l2(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < C_,
+ "\tsvm_c_trainer::svm_c_linear_dcd_trainer(kernel,C)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t C_: " << C_
+ );
+ }
+
+ bool includes_bias (
+ ) const
+ {
+ return have_bias;
+ }
+
+ void include_bias (
+ bool should_have_bias
+ )
+ {
+ have_bias = should_have_bias;
+ }
+
+ bool forces_last_weight_to_1 (
+ ) const
+ {
+ return last_weight_1;
+ }
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ )
+ {
+ last_weight_1 = should_last_weight_be_1;
+ }
+
+ bool shrinking_enabled (
+ ) const { return do_shrinking; }
+
+ void enable_shrinking (
+ bool enabled
+ ) { do_shrinking = enabled; }
+
+ bool solving_svm_l2_problem (
+ ) const { return do_svm_l2; }
+
+ void solve_svm_l2_problem (
+ bool enabled
+ ) { do_svm_l2 = enabled; }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svm_c_linear_dcd_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void set_c (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_dcd_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ Cneg = C;
+ }
+
+ const scalar_type get_c_class1 (
+ ) const
+ {
+ return Cpos;
+ }
+
+ const scalar_type get_c_class2 (
+ ) const
+ {
+ return Cneg;
+ }
+
+ void set_c_class1 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_dcd_trainer::set_c_class1()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ }
+
+ void set_c_class2 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_dcd_trainer::set_c_class2()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cneg = C;
+ }
+
+ class optimizer_state
+ {
+ friend class svm_c_linear_dcd_trainer;
+
+ public:
+ optimizer_state() : did_init(false) {}
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ void init(
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ bool have_bias_,
+ bool last_weight_1_,
+ bool do_svm_l2_,
+ scalar_type Cpos,
+ scalar_type Cneg
+ )
+ {
+ const long new_dims = max_index_plus_one(x);
+ long new_idx = 0;
+
+ if (did_init)
+ {
+ DLIB_CASSERT(have_bias_ == have_bias &&
+ last_weight_1_ == last_weight_1,
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)"
+ << "\n\t The given state object is invalid because the previous trainer was configured differently."
+ << "\n\t have_bias_: " << have_bias_
+ << "\n\t have_bias: " << have_bias
+ << "\n\t last_weight_1_: " << last_weight_1_
+ << "\n\t last_weight_1: " << last_weight_1
+ );
+
+ DLIB_CASSERT( new_dims >= dims,
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)"
+ << "\n\t The given state object is invalid because the training data dimensions have shrunk."
+ << "\n\t new_dims: " << new_dims
+ << "\n\t dims: " << dims
+ );
+
+ DLIB_CASSERT( x.size() >= static_cast<long>(alpha.size()),
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)"
+ << "\n\t The given state object is invalid because the training data has fewer samples than previously."
+ << "\n\t x.size(): " << x.size()
+ << "\n\t alpha.size(): " << alpha.size()
+ );
+
+ // make sure we amortize the cost of growing the alpha vector.
+ if (alpha.capacity() < static_cast<unsigned long>(x.size()))
+ alpha.reserve(x.size()*2);
+
+ new_idx = alpha.size();
+
+ // Make sure alpha has the same length as x. So pad with extra zeros if
+ // necessary to make this happen.
+ alpha.resize(x.size(),0);
+
+
+ if (new_dims != dims)
+ {
+ // The only valid way the dimensions can be different here is if
+ // you are using a sparse vector type. This is because we might
+ // have had training samples which just happened to not include all
+ // the features previously. Therefore, max_index_plus_one() would
+ // have given too low of a result. But for dense vectors it is
+ // definitely a user error if the dimensions don't match.
+
+ DLIB_CASSERT(is_matrix<sample_type>::value == false,
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)"
+ << "\n\t The given state object is invalid because the training data dimensions have changed."
+ << "\n\t new_dims: " << new_dims
+ << "\n\t dims: " << dims
+ );
+
+ // extend w by the right number of elements
+ if (have_bias && !last_weight_1)
+ {
+ // Splice some zeros into the w vector so it will have the
+ // right length. Here we are being careful to move the bias
+ // weight to the end of the resulting vector.
+ w = join_cols(join_cols(
+ colm(w,0,dims),
+ zeros_matrix<scalar_type>(new_dims-dims,1)),
+ uniform_matrix<scalar_type>(1,1,w(dims))
+ );
+ }
+ else
+ {
+ // Just concatenate the right number of zeros.
+ w = join_cols(w, zeros_matrix<scalar_type>(new_dims-dims,1));
+ }
+ dims = new_dims;
+ }
+
+ }
+ else
+ {
+ did_init = true;
+ have_bias = have_bias_;
+ last_weight_1 = last_weight_1_;
+ dims = new_dims;
+
+ alpha.resize(x.size());
+
+ index.reserve(x.size());
+ Q.reserve(x.size());
+
+ if (have_bias && !last_weight_1)
+ w.set_size(dims+1);
+ else
+ w.set_size(dims);
+
+ w = 0;
+ }
+
+ for (long i = new_idx; i < x.size(); ++i)
+ {
+ Q.push_back(length_squared(x(i)));
+
+ if (have_bias && !last_weight_1)
+ {
+ index.push_back(i);
+ Q.back() += 1;
+ }
+ else if (Q.back() != 0)
+ {
+ index.push_back(i);
+ }
+
+ if (do_svm_l2_)
+ {
+ if (y(i) > 0)
+ Q.back() += 1/(2*Cpos);
+ else
+ Q.back() += 1/(2*Cneg);
+ }
+ }
+
+ if (last_weight_1)
+ w(dims-1) = 1;
+ }
+
+ template <typename T>
+ typename enable_if<is_matrix<T>,scalar_type>::type length_squared (const T& x) const
+ {
+ if (!last_weight_1)
+ {
+ return dlib::dot(x,x);
+ }
+ else
+ {
+ // skip the last dimension
+ return dlib::dot(colm(x,0,x.size()-1),
+ colm(x,0,x.size()-1));
+ }
+
+ }
+
+ template <typename T>
+ typename disable_if<is_matrix<T>,scalar_type>::type length_squared (const T& x) const
+ {
+ if (!last_weight_1)
+ {
+ return dlib::dot(x,x);
+ }
+ else
+ {
+ scalar_type temp = 0;
+ typename T::const_iterator i;
+ for (i = x.begin(); i != x.end(); ++i)
+ {
+ // skip the last dimension
+ if (static_cast<long>(i->first) < dims-1)
+ temp += i->second*i->second;
+ }
+ return temp;
+ }
+ }
+
+
+ bool did_init;
+ bool have_bias;
+ bool last_weight_1;
+ std::vector<scalar_type> alpha;
+ scalar_vector_type w;
+ std::vector<scalar_type> Q;
+ std::vector<long> index;
+ long dims;
+ dlib::rand rnd;
+
+ public:
+
+ const std::vector<scalar_type>& get_alpha () const { return alpha; }
+
+ friend void serialize(const optimizer_state& item, std::ostream& out)
+ {
+ const int version = 1;
+ dlib::serialize(version, out);
+ dlib::serialize(item.did_init, out);
+ dlib::serialize(item.have_bias, out);
+ dlib::serialize(item.last_weight_1, out);
+ dlib::serialize(item.alpha, out);
+ dlib::serialize(item.w, out);
+ dlib::serialize(item.Q, out);
+ dlib::serialize(item.index, out);
+ dlib::serialize(item.dims, out);
+ dlib::serialize(item.rnd, out);
+ }
+
+ friend void deserialize(optimizer_state& item, std::istream& in)
+ {
+ int version = 0;
+ dlib::deserialize(version, in);
+ if (version != 1)
+ {
+ throw dlib::serialization_error(
+ "Error while deserializing dlib::svm_c_linear_dcd_trainer::optimizer_state, unexpected version."
+ );
+ }
+
+ dlib::deserialize(item.did_init, in);
+ dlib::deserialize(item.have_bias, in);
+ dlib::deserialize(item.last_weight_1, in);
+ dlib::deserialize(item.alpha, in);
+ dlib::deserialize(item.w, in);
+ dlib::deserialize(item.Q, in);
+ dlib::deserialize(item.index, in);
+ dlib::deserialize(item.dims, in);
+ dlib::deserialize(item.rnd, in);
+ }
+
+ };
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ optimizer_state state;
+ return do_train(mat(x), mat(y), state);
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ optimizer_state& state
+ ) const
+ {
+ return do_train(mat(x), mat(y), state);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ optimizer_state& state
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true,
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t y.size(): " << y.size()
+ << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
+ );
+#ifdef ENABLE_ASSERTS
+ for (long i = 0; i < x.size(); ++i)
+ {
+ DLIB_ASSERT(y(i) == +1 || y(i) == -1,
+ "\t decision_function svm_c_linear_dcd_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t y("<<i<<"): " << y(i)
+ );
+ }
+#endif
+
+ state.init(x,y,have_bias,last_weight_1,do_svm_l2,Cpos,Cneg);
+
+ std::vector<scalar_type>& alpha = state.alpha;
+ scalar_vector_type& w = state.w;
+ std::vector<long>& index = state.index;
+ const long dims = state.dims;
+
+
+ unsigned long active_size = index.size();
+
+ scalar_type PG_max_prev = std::numeric_limits<scalar_type>::infinity();
+ scalar_type PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
+
+ const scalar_type Dii_pos = 1/(2*Cpos);
+ const scalar_type Dii_neg = 1/(2*Cneg);
+
+ // main loop
+ for (unsigned long iter = 0; iter < max_iterations; ++iter)
+ {
+ scalar_type PG_max = -std::numeric_limits<scalar_type>::infinity();
+ scalar_type PG_min = std::numeric_limits<scalar_type>::infinity();
+
+ // randomly shuffle the indices
+ for (unsigned long i = 0; i < active_size; ++i)
+ {
+ // pick a random index >= i
+ const long j = i + state.rnd.get_random_32bit_number()%(active_size-i);
+ std::swap(index[i], index[j]);
+ }
+
+ // for all the active training samples
+ for (unsigned long ii = 0; ii < active_size; ++ii)
+ {
+ const long i = index[ii];
+
+ scalar_type G = y(i)*dot(w, x(i)) - 1;
+ if (do_svm_l2)
+ {
+ if (y(i) > 0)
+ G += Dii_pos*alpha[i];
+ else
+ G += Dii_neg*alpha[i];
+ }
+ const scalar_type C = (y(i) > 0) ? Cpos : Cneg;
+ const scalar_type U = do_svm_l2 ? std::numeric_limits<scalar_type>::infinity() : C;
+
+ scalar_type PG = 0;
+ if (alpha[i] == 0)
+ {
+ if (G > PG_max_prev)
+ {
+ // shrink the active set of training examples
+ --active_size;
+ std::swap(index[ii], index[active_size]);
+ --ii;
+ continue;
+ }
+
+ if (G < 0)
+ PG = G;
+ }
+ else if (alpha[i] == U)
+ {
+ if (G < PG_min_prev)
+ {
+ // shrink the active set of training examples
+ --active_size;
+ std::swap(index[ii], index[active_size]);
+ --ii;
+ continue;
+ }
+
+ if (G > 0)
+ PG = G;
+ }
+ else
+ {
+ PG = G;
+ }
+
+ if (PG > PG_max)
+ PG_max = PG;
+ if (PG < PG_min)
+ PG_min = PG;
+
+ // if PG != 0
+ if (std::abs(PG) > 1e-12)
+ {
+ const scalar_type alpha_old = alpha[i];
+ alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), U);
+ const scalar_type delta = (alpha[i]-alpha_old)*y(i);
+ add_to(w, x(i), delta);
+ if (have_bias && !last_weight_1)
+ w(w.size()-1) -= delta;
+
+ if (last_weight_1)
+ w(dims-1) = 1;
+ }
+
+ }
+
+ if (verbose)
+ {
+ using namespace std;
+ cout << "gap: " << PG_max - PG_min << endl;
+ cout << "active_size: " << active_size << endl;
+ cout << "iter: " << iter << endl;
+ cout << endl;
+ }
+
+ if (PG_max - PG_min <= eps)
+ {
+ // stop if we are within eps tolerance and the last iteration
+ // was over all the samples
+ if (active_size == index.size())
+ break;
+
+ // Turn off shrinking on the next iteration. We will stop if the
+ // tolerance is still <= eps when shrinking is off.
+ active_size = index.size();
+ PG_max_prev = std::numeric_limits<scalar_type>::infinity();
+ PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
+ }
+ else if (do_shrinking)
+ {
+ PG_max_prev = PG_max;
+ PG_min_prev = PG_min;
+ if (PG_max_prev <= 0)
+ PG_max_prev = std::numeric_limits<scalar_type>::infinity();
+ if (PG_min_prev >= 0)
+ PG_min_prev = -std::numeric_limits<scalar_type>::infinity();
+ }
+
+ } // end of main optimization loop
+
+
+
+
+ // put the solution into a decision function and then return it
+ decision_function<kernel_type> df;
+ if (have_bias && !last_weight_1)
+ df.b = w(w.size()-1);
+ else
+ df.b = 0;
+
+ df.basis_vectors.set_size(1);
+ // Copy the plane normal into the output basis vector. The output vector might
+ // be a sparse vector container so we need to use this special kind of copy to
+ // handle that case.
+ assign(df.basis_vectors(0), colm(w, 0, dims));
+ df.alpha.set_size(1);
+ df.alpha(0) = 1;
+
+ return df;
+ }
+
+ scalar_type dot (
+ const scalar_vector_type& w,
+ const sample_type& sample
+ ) const
+ {
+ if (have_bias && !last_weight_1)
+ {
+ const long w_size_m1 = w.size()-1;
+ return dlib::dot(colm(w,0,w_size_m1), sample) - w(w_size_m1);
+ }
+ else
+ {
+ return dlib::dot(w, sample);
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ scalar_type Cpos;
+ scalar_type Cneg;
+ scalar_type eps;
+ unsigned long max_iterations;
+ bool verbose;
+ bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1.
+ bool last_weight_1;
+ bool do_shrinking;
+ bool do_svm_l2;
+
+ }; // end of class svm_c_linear_dcd_trainer
+
+// ----------------------------------------------------------------------------------------
+
+
+}
+
+#endif // DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h
new file mode 100644
index 000000000..b57c54260
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h
@@ -0,0 +1,382 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_
+
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_c_linear_dcd_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ Is either linear_kernel or sparse_linear_kernel.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for training the C formulation of a support
+ vector machine. It is optimized for the case where linear kernels are
+ used.
+
+
+ In particular, it is implemented using the algorithm described in the
+ following paper:
+ A Dual Coordinate Descent Method for Large-scale Linear SVM
+ by Cho-Jui Hsieh, Kai-Wei Chang, and Chih-Jen Lin
+
+ It solves the optimization problem of:
+ min_w: 0.5||w||^2 + C*sum_i (hinge loss for sample i)
+ where w is the learned SVM parameter vector.
+
+ Note that this object is very similar to the svm_c_linear_trainer, however,
+ it interprets the C parameter slightly differently. In particular, C for
+ the DCD trainer is not automatically divided by the number of samples like
+ it is with the svm_c_linear_trainer. For example, a C value of 10 when
+ given to the svm_c_linear_trainer is equivalent to a C value of 10/N for
+ the svm_c_linear_dcd_trainer, where N is the number of training samples.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+
+ svm_c_linear_dcd_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ support vector machine.
+ - #get_c_class1() == 1
+ - #get_c_class2() == 1
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - This object will not be verbose unless be_verbose() is called
+ - #forces_last_weight_to_1() == false
+ - #includes_bias() == true
+ - #shrinking_enabled() == true
+ - #solving_svm_l2_problem() == false
+ !*/
+
+ explicit svm_c_linear_dcd_trainer (
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ support vector machine.
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ - #get_epsilon() == 0.1
+ - #get_max_iterations() == 10000
+ - This object will not be verbose unless be_verbose() is called
+ - #forces_last_weight_to_1() == false
+ - #includes_bias() == true
+ - #shrinking_enabled() == true
+ - #solving_svm_l2_problem() == false
+ !*/
+
+ bool includes_bias (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer will produce decision_functions with
+ non-zero bias values.
+ !*/
+
+ void include_bias (
+ bool should_have_bias
+ );
+ /*!
+ ensures
+ - #includes_bias() == should_have_bias
+ !*/
+
+ bool forces_last_weight_to_1 (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer has the constraint that the last weight in
+ the learned parameter vector must be 1. This is the weight corresponding
+ to the feature in the training vectors with the highest dimension.
+ - Forcing the last weight to 1 also disables the bias and therefore the b
+ field of the learned decision_function will be 0 when forces_last_weight_to_1() == true.
+ This is true regardless of the setting of #include_bias().
+ !*/
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ );
+ /*!
+ ensures
+ - #forces_last_weight_to_1() == should_last_weight_be_1
+ !*/
+
+ bool shrinking_enabled (
+ ) const;
+ /*!
+ ensures
+ - returns true if the shrinking heuristic is enabled. Typically this makes
+ the algorithm run a lot faster so it should be enabled.
+ !*/
+
+ void enable_shrinking (
+ bool enabled
+ );
+ /*!
+ ensures
+ - #shrinking_enabled() == enabled
+ !*/
+
+ bool solving_svm_l2_problem (
+ ) const;
+ /*!
+ ensures
+ - returns true if this solver will solve the L2 version of the SVM
+ objective function. That is, if solving_svm_l2_problem()==true then this
+ object, rather than using the hinge loss, uses the squared hinge loss.
+ !*/
+
+ void solve_svm_l2_problem (
+ bool enabled
+ );
+ /*!
+ ensures
+ - #solving_svm_l2_problem() == enabled
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_epsilon (
+ scalar_type eps_
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train.
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since the
+ linear kernels don't have any parameters this function just returns
+ kernel_type()
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ !*/
+
+ const scalar_type get_c_class1 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the +1 class. It is the
+ parameter that determines the trade off between trying to fit the +1
+ training data exactly or allowing more errors but hopefully improving the
+ generalization of the resulting classifier. Larger values encourage
+ exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ const scalar_type get_c_class2 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the -1 class. It is the
+ parameter that determines the trade off between trying to fit the -1
+ training data exactly or allowing more errors but hopefully improving the
+ generalization of the resulting classifier. Larger values encourage
+ exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ void set_c_class1 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ !*/
+
+ void set_c_class2 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class2() == C
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(x,y) == true
+ (Note that it is ok for x.size() == 1)
+ - All elements of y must be equal to +1 or -1
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - Trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - returns a decision function F with the following properties:
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ // optimizer_state is used to record the internal state of the SVM optimizer. It
+ // can be used with the following train() routine to warm-start the optimizer or
+ // access the optimal alpha values (see the Hsieh paper mentioned above). The
+ // optimizer_state objects are serializable and allow you to get the alphas, but
+ // are otherwise completely opaque to the user.
+ class optimizer_state
+ {
+ public:
+ const std::vector<scalar_type>& get_alpha (
+ ) const;
+ };
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ optimizer_state& state
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(x,y) == true
+ (Note that it is ok for x.size() == 1)
+ - All elements of y must be equal to +1 or -1
+ - state must be either a default initialized optimizer_state object or all the
+ following conditions must be satisfied:
+ - Let LAST denote the previous trainer used with the state object, then
+ we must have:
+ - LAST.includes_bias() == includes_bias()
+ - LAST.forces_last_weight_to_1() == forces_last_weight_to_1()
+ - Let X denote the previous training samples used with state, then the
+ following must be satisfied:
+ - x.size() >= X.size()
+ - for all valid i:
+ - x(i) == X(i)
+ (i.e. the samples x and X have in common must be identical.
+ That is, the only allowed difference between x and X is that
+ x might have new training samples appended onto its end)
+ - if (x contains dense vectors) then
+ - max_index_plus_one(x) == max_index_plus_one(X)
+ - else
+ - max_index_plus_one(x) >= max_index_plus_one(X)
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - Trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - The point of the state object is to allow you to warm start the SVM
+ optimizer from the solution to a previous call to train(). Doing this
+ might make the training run faster. This is useful when you are trying
+ different C values or have grown the training set and want to retrain.
+ - #state == the internal state of the optimizer at the solution to the SVM
+ problem. Therefore, passing #state to a new call to train() will start
+ the optimizer from the current solution.
+ - #state.get_alpha().size() == x.size()
+ - #state.get_alpha() == the optimal alpha/dual values learned by the optimizer.
+ - returns a decision function F with the following properties:
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_c_linear_trainer.h b/ml/dlib/dlib/svm/svm_c_linear_trainer.h
new file mode 100644
index 000000000..8d136d711
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_linear_trainer.h
@@ -0,0 +1,706 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVM_C_LiNEAR_TRAINER_Hh_
+#define DLIB_SVM_C_LiNEAR_TRAINER_Hh_
+
+#include "svm_c_linear_trainer_abstract.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "../matrix.h"
+#include "function.h"
+#include "kernel.h"
+#include <iostream>
+#include <vector>
+#include "sparse_vector.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ class oca_problem_c_svm : public oca_problem<matrix_type >
+ {
+ public:
+ /*
+ This class is used as part of the implementation of the svm_c_linear_trainer
+ defined towards the end of this file.
+
+
+ The bias parameter is dealt with by imagining that each sample vector has -1
+ as its last element.
+ */
+
+ typedef typename matrix_type::type scalar_type;
+
+ oca_problem_c_svm(
+ const scalar_type C_pos,
+ const scalar_type C_neg,
+ const in_sample_vector_type& samples_,
+ const in_scalar_vector_type& labels_,
+ const bool be_verbose_,
+ const scalar_type eps_,
+ const unsigned long max_iter,
+ const unsigned long dims_
+ ) :
+ samples(samples_),
+ labels(labels_),
+ C(std::min(C_pos,C_neg)),
+ Cpos(C_pos/C),
+ Cneg(C_neg/C),
+ be_verbose(be_verbose_),
+ eps(eps_),
+ max_iterations(max_iter),
+ dims(dims_)
+ {
+ dot_prods.resize(samples.size());
+ is_first_call = true;
+ }
+
+ virtual scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ // plus 1 for the bias term
+ return dims + 1;
+ }
+
+ virtual bool optimization_status (
+ scalar_type current_objective_value,
+ scalar_type current_error_gap,
+ scalar_type current_risk_value,
+ scalar_type current_risk_gap,
+ unsigned long num_cutting_planes,
+ unsigned long num_iterations
+ ) const
+ {
+ if (be_verbose)
+ {
+ using namespace std;
+ cout << "objective: " << current_objective_value << endl;
+ cout << "objective gap: " << current_error_gap << endl;
+ cout << "risk: " << current_risk_value << endl;
+ cout << "risk gap: " << current_risk_gap << endl;
+ cout << "num planes: " << num_cutting_planes << endl;
+ cout << "iter: " << num_iterations << endl;
+ cout << endl;
+ }
+
+ if (num_iterations >= max_iterations)
+ return true;
+
+ if (current_risk_gap < eps)
+ return true;
+
+ return false;
+ }
+
+ virtual bool risk_has_lower_bound (
+ scalar_type& lower_bound
+ ) const
+ {
+ lower_bound = 0;
+ return true;
+ }
+
+ virtual void get_risk (
+ matrix_type& w,
+ scalar_type& risk,
+ matrix_type& subgradient
+ ) const
+ {
+ line_search(w);
+
+ subgradient.set_size(w.size(),1);
+ subgradient = 0;
+ risk = 0;
+
+
+ // loop over all the samples and compute the risk and its subgradient at the current solution point w
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ // multiply current SVM output for the ith sample by its label
+ const scalar_type df_val = labels(i)*dot_prods[i];
+
+ if (labels(i) > 0)
+ risk += Cpos*std::max<scalar_type>(0.0,1 - df_val);
+ else
+ risk += Cneg*std::max<scalar_type>(0.0,1 - df_val);
+
+ if (df_val < 1)
+ {
+ if (labels(i) > 0)
+ {
+ subtract_from(subgradient, samples(i), Cpos);
+
+ subgradient(subgradient.size()-1) += Cpos;
+ }
+ else
+ {
+ add_to(subgradient, samples(i), Cneg);
+
+ subgradient(subgradient.size()-1) -= Cneg;
+ }
+ }
+ }
+
+ scalar_type scale = 1.0/samples.size();
+
+ risk *= scale;
+ subgradient = scale*subgradient;
+ }
+
+ private:
+
+ // -----------------------------------------------------
+ // -----------------------------------------------------
+
+ void line_search (
+ matrix_type& w
+ ) const
+ /*!
+ ensures
+ - does a line search to find a better w
+ - for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1)
+ !*/
+ {
+ // The reason for using w_size_m1 and not just w.size()-1 is because
+ // doing it this way avoids an inane warning from gcc that can occur in some cases.
+ const long w_size_m1 = w.size()-1;
+ for (long i = 0; i < samples.size(); ++i)
+ dot_prods[i] = dot(colm(w,0,w_size_m1), samples(i)) - w(w_size_m1);
+
+ if (is_first_call)
+ {
+ is_first_call = false;
+ best_so_far = w;
+ dot_prods_best = dot_prods;
+ }
+ else
+ {
+ // do line search going from best_so_far to w. Store results in w.
+ // Here we use the line search algorithm presented in section 3.1.1 of Franc and Sonnenburg.
+
+ const scalar_type A0 = length_squared(best_so_far - w);
+ const scalar_type BB0 = dot(best_so_far, w - best_so_far);
+
+ const scalar_type scale_pos = (get_c()*Cpos)/samples.size();
+ const scalar_type scale_neg = (get_c()*Cneg)/samples.size();
+
+ ks.clear();
+ ks.reserve(samples.size());
+
+ scalar_type f0 = BB0;
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ const scalar_type& scale = (labels(i)>0) ? scale_pos : scale_neg;
+
+ const scalar_type B = scale*labels(i) * ( dot_prods_best[i] - dot_prods[i]);
+ const scalar_type C = scale*(1 - labels(i)* dot_prods_best[i]);
+ // Note that if B is 0 then it doesn't matter what k is set to. So 0 is fine.
+ scalar_type k = 0;
+ if (B != 0)
+ k = -C/B;
+
+ if (k > 0)
+ ks.push_back(helper(k, std::abs(B)));
+
+ if ( (B < 0 && k > 0) || (B > 0 && k <= 0) )
+ f0 += B;
+ }
+
+ scalar_type opt_k = 1;
+ // ks.size() == 0 shouldn't happen but check anyway
+ if (f0 >= 0 || ks.size() == 0)
+ {
+ // Getting here means that we aren't searching in a descent direction.
+ // We could take a zero step but instead lets just assign w to the new best
+ // so far point just to make sure we don't get stuck coming back to this
+ // case over and over. This might happen if we never move the best point
+ // seen so far.
+
+ // So we let opt_k be 1
+ }
+ else
+ {
+ std::sort(ks.begin(), ks.end());
+
+ // figure out where f0 goes positive.
+ for (unsigned long i = 0; i < ks.size(); ++i)
+ {
+ f0 += ks[i].B;
+ if (f0 + A0*ks[i].k >= 0)
+ {
+ opt_k = ks[i].k;
+ break;
+ }
+ }
+
+ }
+
+ // Don't let the step size get too big. Otherwise we might pick huge steps
+ // over and over that don't improve the cutting plane approximation.
+ if (opt_k > 1.0)
+ {
+ opt_k = 1.0;
+ }
+
+ // take the step suggested by the line search
+ best_so_far = (1-opt_k)*best_so_far + opt_k*w;
+
+ // update best_so_far dot products
+ for (unsigned long i = 0; i < dot_prods_best.size(); ++i)
+ dot_prods_best[i] = (1-opt_k)*dot_prods_best[i] + opt_k*dot_prods[i];
+
+
+ const scalar_type mu = 0.1;
+ // Make sure we always take a little bit of a step towards w regardless of what the
+ // line search says to do. We do this since it is possible that some steps won't
+ // advance the best_so_far point. So this ensures we always make some progress each
+ // iteration.
+ w = (1-mu)*best_so_far + mu*w;
+
+ // update dot products
+ for (unsigned long i = 0; i < dot_prods.size(); ++i)
+ dot_prods[i] = (1-mu)*dot_prods_best[i] + mu*dot_prods[i];
+ }
+ }
+
+ struct helper
+ {
+ helper(scalar_type k_, scalar_type B_) : k(k_), B(B_) {}
+ scalar_type k;
+ scalar_type B;
+
+ bool operator< (const helper& item) const { return k < item.k; }
+ };
+
+ mutable std::vector<helper> ks;
+
+ mutable bool is_first_call;
+ mutable std::vector<scalar_type> dot_prods;
+
+ mutable matrix_type best_so_far; // best w seen so far
+ mutable std::vector<scalar_type> dot_prods_best; // dot products between best_so_far and samples
+
+
+ const in_sample_vector_type& samples;
+ const in_scalar_vector_type& labels;
+ const scalar_type C;
+ const scalar_type Cpos;
+ const scalar_type Cneg;
+
+ const bool be_verbose;
+ const scalar_type eps;
+ const unsigned long max_iterations;
+ const unsigned long dims;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type,
+ typename scalar_type
+ >
+ oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type> make_oca_problem_c_svm (
+ const scalar_type C_pos,
+ const scalar_type C_neg,
+ const in_sample_vector_type& samples,
+ const in_scalar_vector_type& labels,
+ const bool be_verbose,
+ const scalar_type eps,
+ const unsigned long max_iterations,
+ const unsigned long dims
+ )
+ {
+ return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(
+ C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations, dims);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_c_linear_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear kernel
+ // to the svm_c_linear_trainer object. You have to use one of the linear kernels with this
+ // trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+ svm_c_linear_trainer (
+ )
+ {
+ Cpos = 1;
+ Cneg = 1;
+ verbose = false;
+ eps = 0.001;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ explicit svm_c_linear_trainer (
+ const scalar_type& C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t svm_c_linear_trainer::svm_c_linear_trainer()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ Cneg = C;
+ verbose = false;
+ eps = 0.001;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svm_c_linear_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ if (learn_nonnegative_weights)
+ prior.set_size(0);
+ }
+
+ bool forces_last_weight_to_1 (
+ ) const
+ {
+ return last_weight_1;
+ }
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ )
+ {
+ last_weight_1 = should_last_weight_be_1;
+ if (last_weight_1)
+ prior.set_size(0);
+ }
+
+ void set_prior (
+ const trained_function_type& prior_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(prior_.basis_vectors.size() == 1 &&
+ prior_.alpha(0) == 1,
+ "\t void svm_c_linear_trainer::set_prior()"
+ << "\n\t The supplied prior could not have been created by this object's train() method."
+ << "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size()
+ << "\n\t prior_.alpha(0): " << prior_.alpha(0)
+ << "\n\t this: " << this
+ );
+
+ prior = sparse_to_dense(prior_.basis_vectors(0));
+ prior_b = prior_.b;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ bool has_prior (
+ ) const
+ {
+ return prior.size() != 0;
+ }
+
+ void set_c (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ Cneg = C;
+ }
+
+ const scalar_type get_c_class1 (
+ ) const
+ {
+ return Cpos;
+ }
+
+ const scalar_type get_c_class2 (
+ ) const
+ {
+ return Cneg;
+ }
+
+ void set_c_class1 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_trainer::set_c_class1()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ }
+
+ void set_c_class2 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_linear_trainer::set_c_class2()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cneg = C;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ scalar_type obj;
+ return do_train(mat(x),mat(y),obj);
+ }
+
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const
+ {
+ return do_train(mat(x),mat(y),svm_objective);
+ }
+
+ private:
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true,
+ "\t decision_function svm_c_linear_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
+ );
+#ifdef ENABLE_ASSERTS
+ for (long i = 0; i < x.size(); ++i)
+ {
+ DLIB_ASSERT(y(i) == +1 || y(i) == -1,
+ "\t decision_function svm_c_linear_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t y("<<i<<"): " << y(i)
+ );
+ }
+#endif
+
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type w;
+
+ const unsigned long num_dims = max_index_plus_one(x);
+
+ unsigned long num_nonnegative = 0;
+ if (learn_nonnegative_weights)
+ {
+ num_nonnegative = num_dims;
+ }
+
+ unsigned long force_weight_1_idx = std::numeric_limits<unsigned long>::max();
+ if (last_weight_1)
+ {
+ force_weight_1_idx = num_dims-1;
+ }
+
+
+ if (has_prior())
+ {
+ if (is_matrix<sample_type>::value)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(num_dims == (unsigned long)prior.size(),
+ "\t decision_function svm_c_linear_trainer::train(x,y)"
+ << "\n\t The dimension of the training vectors must match the dimension of\n"
+ << "\n\t those used to create the prior."
+ << "\n\t num_dims: " << num_dims
+ << "\n\t prior.size(): " << prior.size()
+ );
+ }
+ const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
+ // In the case of sparse sample vectors, it is possible that the input
+ // vector dimensionality is larger than the prior vector dimensionality.
+ // We need to check for this case and pad prior with zeros if it is the
+ // case.
+ matrix<scalar_type,0,1> prior_temp = join_cols(join_cols(prior,
+ zeros_matrix<scalar_type>(dims-prior.size(),1)),
+ mat(prior_b));
+
+ svm_objective = solver(
+ make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, dims),
+ w,
+ prior_temp);
+ }
+ else
+ {
+ svm_objective = solver(
+ make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, num_dims),
+ w,
+ num_nonnegative,
+ force_weight_1_idx);
+ }
+
+ // put the solution into a decision function and then return it
+ decision_function<kernel_type> df;
+ df.b = static_cast<scalar_type>(w(w.size()-1));
+ df.basis_vectors.set_size(1);
+ // Copy the plane normal into the output basis vector. The output vector might be a
+ // sparse vector container so we need to use this special kind of copy to handle that case.
+ // As an aside, the reason for using max_index_plus_one() and not just w.size()-1 is because
+ // doing it this way avoids an inane warning from gcc that can occur in some cases.
+ const long out_size = max_index_plus_one(x);
+ assign(df.basis_vectors(0), matrix_cast<scalar_type>(colm(w, 0, out_size)));
+ df.alpha.set_size(1);
+ df.alpha(0) = 1;
+
+ return df;
+ }
+
+ scalar_type Cpos;
+ scalar_type Cneg;
+ oca solver;
+ scalar_type eps;
+ bool verbose;
+ unsigned long max_iterations;
+ bool learn_nonnegative_weights;
+ bool last_weight_1;
+ matrix<scalar_type,0,1> prior;
+ scalar_type prior_b = 0;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+#endif // DLIB_SVM_C_LiNEAR_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h
new file mode 100644
index 000000000..1b7a128f0
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h
@@ -0,0 +1,359 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "sparse_kernel_abstract.h"
+
+namespace dlib
+{
+ template <
+ typename K
+ >
+ class svm_c_linear_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ Is either linear_kernel or sparse_linear_kernel.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for training the C formulation of
+ a support vector machine. It is optimized for the case where
+ linear kernels are used.
+
+
+ In particular, it is implemented using the OCAS algorithm
+ described in the following paper:
+ Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization
+ Vojtech Franc, Soren Sonnenburg; Journal of Machine Learning
+ Research, 10(Oct):2157--2192, 2009.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_c_linear_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c_class1() == 1
+ - #get_c_class2() == 1
+ - #get_epsilon() == 0.001
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #force_last_weight_to_1() == false
+ - #has_prior() == false
+ !*/
+
+ explicit svm_c_linear_trainer (
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ - #get_epsilon() == 0.001
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #force_last_weight_to_1() == false
+ - #has_prior() == false
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train. You can think of this epsilon value as saying "solve the
+ optimization problem until the probability of misclassification is within
+ epsilon of its optimal value".
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the SVM problem.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since
+ the linear kernels don't have any parameters this function just
+ returns kernel_type()
+ !*/
+
+ bool learns_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - The output of training is a weight vector and a bias value. These
+ two things define the resulting decision function. That is, the
+ decision function simply takes the dot product between the learned
+ weight vector and a test sample, then subtracts the bias value.
+ Therefore, if learns_nonnegative_weights() == true then the resulting
+ learned weight vector will always have non-negative entries. The
+ bias value may still be negative though.
+ !*/
+
+ void set_learns_nonnegative_weights (
+ bool value
+ );
+ /*!
+ ensures
+ - #learns_nonnegative_weights() == value
+ - if (value == true) then
+ - #has_prior() == false
+ !*/
+
+ void set_prior (
+ const trained_function_type& prior
+ );
+ /*!
+ requires
+ - prior == a function produced by a call to this class's train() function.
+ Therefore, it must be the case that:
+ - prior.basis_vectors.size() == 1
+ - prior.alpha(0) == 1
+ ensures
+ - Subsequent calls to train() will try to learn a function similar to the
+ given prior.
+ - #has_prior() == true
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ bool has_prior (
+ ) const
+ /*!
+ ensures
+ - returns true if a prior has been set and false otherwise. Having a prior
+ set means that you have called set_prior() and supplied a previously
+ trained function as a reference. In this case, any call to train() will
+ try to learn a function that matches the behavior of the prior as close
+ as possible but also fits the supplied training data. In more technical
+ detail, having a prior means we replace the ||w||^2 regularizer with one
+ of the form ||w-prior||^2 where w is the set of parameters for a learned
+ function.
+ !*/
+
+ bool forces_last_weight_to_1 (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer has the constraint that the last weight in
+ the learned parameter vector must be 1. This is the weight corresponding
+ to the feature in the training vectors with the highest dimension.
+ - Forcing the last weight to 1 also disables the bias and therefore the b
+ field of the learned decision_function will be 0 when forces_last_weight_to_1() == true.
+ !*/
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ );
+ /*!
+ ensures
+ - #forces_last_weight_to_1() == should_last_weight_be_1
+ - if (should_last_weight_be_1 == true) then
+ - #has_prior() == false
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ !*/
+
+ const scalar_type get_c_class1 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the +1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the +1 training data exactly or allowing more errors
+ but hopefully improving the generalization of the resulting
+ classifier. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ const scalar_type get_c_class2 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the -1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the -1 training data exactly or allowing more errors
+ but hopefully improving the generalization of the resulting
+ classifier. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ void set_c_class1 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ !*/
+
+ void set_c_class2 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class2() == C
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(x,y) == true
+ (Note that it is ok for x.size() == 1)
+ - All elements of y must be equal to +1 or -1
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - if (has_prior()) then
+ - The vectors in x must have the same dimensionality as the vectors
+ used to train the prior given to set_prior().
+ ensures
+ - trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - returns a decision function F with the following properties:
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ scalar_type& svm_objective
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(x,y) == true
+ (Note that it is ok for x.size() == 1)
+ - All elements of y must be equal to +1 or -1
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - if (has_prior()) then
+ - The vectors in x must have the same dimensionality as the vectors
+ used to train the prior given to set_prior().
+ ensures
+ - trains a C support vector classifier given the training samples in x and
+ labels in y.
+ - #svm_objective == the final value of the SVM objective function
+ - returns a decision function F with the following properties:
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ };
+
+}
+
+#endif // DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_c_trainer.h b/ml/dlib/dlib/svm/svm_c_trainer.h
new file mode 100644
index 000000000..14dcf3482
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_trainer.h
@@ -0,0 +1,359 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_C_TRAINER_Hh_
+#define DLIB_SVm_C_TRAINER_Hh_
+
+//#include "local/make_label_kernel_matrix.h"
+
+#include "svm_c_trainer_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+
+#include "function.h"
+#include "kernel.h"
+#include "../optimization/optimization_solve_qp3_using_smo.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_c_trainer
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_c_trainer (
+ ) :
+ Cpos(1),
+ Cneg(1),
+ cache_size(200),
+ eps(0.001)
+ {
+ }
+
+ svm_c_trainer (
+ const kernel_type& kernel_,
+ const scalar_type& C_
+ ) :
+ kernel_function(kernel_),
+ Cpos(C_),
+ Cneg(C_),
+ cache_size(200),
+ eps(0.001)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < C_,
+ "\tsvm_c_trainer::svm_c_trainer(kernel,C)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t C_: " << C_
+ );
+ }
+
+ void set_cache_size (
+ long cache_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cache_size_ > 0,
+ "\tvoid svm_c_trainer::set_cache_size(cache_size_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t cache_size: " << cache_size_
+ );
+ cache_size = cache_size_;
+ }
+
+ long get_cache_size (
+ ) const
+ {
+ return cache_size;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svm_c_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel_function = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel_function;
+ }
+
+ void set_c (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ Cneg = C;
+ }
+
+ const scalar_type get_c_class1 (
+ ) const
+ {
+ return Cpos;
+ }
+
+ const scalar_type get_c_class2 (
+ ) const
+ {
+ return Cneg;
+ }
+
+ void set_c_class1 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_trainer::set_c_class1()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cpos = C;
+ }
+
+ void set_c_class2 (
+ scalar_type C
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C > 0,
+ "\t void svm_c_trainer::set_c_class2()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C: " << C
+ << "\n\t this: " << this
+ );
+
+ Cneg = C;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ return do_train(mat(x), mat(y));
+ }
+
+ void swap (
+ svm_c_trainer& item
+ )
+ {
+ exchange(kernel_function, item.kernel_function);
+ exchange(Cpos, item.Cpos);
+ exchange(Cneg, item.Cneg);
+ exchange(cache_size, item.cache_size);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ typedef typename K::scalar_type scalar_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+ "\tdecision_function svm_c_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
+ );
+
+
+ scalar_vector_type alpha;
+
+ solve_qp3_using_smo<scalar_vector_type> solver;
+
+ solver(symmetric_matrix_cache<float>((diagm(y)*kernel_matrix(kernel_function,x)*diagm(y)), cache_size),
+ //solver(symmetric_matrix_cache<float>(make_label_kernel_matrix(kernel_matrix(kernel_function,x),y), cache_size),
+ uniform_matrix<scalar_type>(y.size(),1,-1),
+ y,
+ 0,
+ Cpos,
+ Cneg,
+ alpha,
+ eps);
+
+ scalar_type b;
+ calculate_b(y,alpha,solver.get_gradient(),Cpos,Cneg,b);
+ alpha = pointwise_multiply(alpha,y);
+
+ // count the number of support vectors
+ const long sv_count = (long)sum(alpha != 0);
+
+ scalar_vector_type sv_alpha;
+ sample_vector_type support_vectors;
+
+ // size these column vectors so that they have an entry for each support vector
+ sv_alpha.set_size(sv_count);
+ support_vectors.set_size(sv_count);
+
+ // load the support vectors and their alpha values into these new column matrices
+ long idx = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (alpha(i) != 0)
+ {
+ sv_alpha(idx) = alpha(i);
+ support_vectors(idx) = x(i);
+ ++idx;
+ }
+ }
+
+ // now return the decision function
+ return decision_function<K> (sv_alpha, b, kernel_function, support_vectors);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type,
+ typename scalar_vector_type2
+ >
+ void calculate_b(
+ const scalar_vector_type2& y,
+ const scalar_vector_type& alpha,
+ const scalar_vector_type& df,
+ const scalar_type& Cpos,
+ const scalar_type& Cneg,
+ scalar_type& b
+ ) const
+ {
+ using namespace std;
+ long num_free = 0;
+ scalar_type sum_free = 0;
+
+ scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
+ scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
+
+ for(long i = 0; i < alpha.nr(); ++i)
+ {
+ if(y(i) == 1)
+ {
+ if(alpha(i) == Cpos)
+ {
+ if (df(i) > upper_bound)
+ upper_bound = df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (df(i) < lower_bound)
+ lower_bound = df(i);
+ }
+ else
+ {
+ ++num_free;
+ sum_free += df(i);
+ }
+ }
+ else
+ {
+ if(alpha(i) == Cneg)
+ {
+ if (-df(i) < lower_bound)
+ lower_bound = -df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (-df(i) > upper_bound)
+ upper_bound = -df(i);
+ }
+ else
+ {
+ ++num_free;
+ sum_free -= df(i);
+ }
+ }
+ }
+
+ if(num_free > 0)
+ b = sum_free/num_free;
+ else
+ b = (upper_bound+lower_bound)/2;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+ kernel_type kernel_function;
+ scalar_type Cpos;
+ scalar_type Cneg;
+ long cache_size;
+ scalar_type eps;
+ }; // end of class svm_c_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ svm_c_trainer<K>& a,
+ svm_c_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_C_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_c_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_trainer_abstract.h
new file mode 100644
index 000000000..696cccdb7
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_c_trainer_abstract.h
@@ -0,0 +1,237 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_C_TRAINER_ABSTRACT_
+#ifdef DLIB_SVm_C_TRAINER_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../optimization/optimization_solve_qp3_using_smo_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_c_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a C support vector machine for
+ solving binary classification problems. It is implemented using the SMO
+ algorithm.
+
+ The implementation of the C-SVM training algorithm used by this object is based
+ on the following paper:
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_c_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_c_class1() == 1
+ - #get_c_class2() == 1
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ svm_c_trainer (
+ const kernel_type& kernel,
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - 0 < C
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_kernel() == kernel
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ void set_cache_size (
+ long cache_size
+ );
+ /*!
+ requires
+ - cache_size > 0
+ ensures
+ - #get_cache_size() == cache_size
+ !*/
+
+ const long get_cache_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of megabytes of cache this object will use
+ when it performs training via the this->train() function.
+ (bigger values of this may make training go faster but won't affect
+ the result. However, too big a value will cause you to run out of
+ memory, obviously.)
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ - #get_c_class2() == C
+ !*/
+
+ const scalar_type get_c_class1 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the +1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the +1 training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Larger values encourage exact fitting
+ while smaller values of C may encourage better generalization.
+ !*/
+
+ const scalar_type get_c_class2 (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter for the -1 class.
+ It is the parameter that determines the trade off between
+ trying to fit the -1 training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Larger values encourage exact fitting
+ while smaller values of C may encourage better generalization.
+ !*/
+
+ void set_c_class1 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class1() == C
+ !*/
+
+ void set_c_class2 (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c_class2() == C
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - trains a C support vector classifier given the training samples in x and
+ labels in y. Training is done when the error is less than get_epsilon().
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ void swap (
+ svm_c_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <typename K>
+ void swap (
+ svm_c_trainer<K>& a,
+ svm_c_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_C_TRAINER_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h
new file mode 100644
index 000000000..4727f7226
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h
@@ -0,0 +1,432 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_
+#define DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_
+
+#include "svm_multiclass_linear_trainer_abstract.h"
+#include "structural_svm_problem_threaded.h"
+#include <vector>
+#include "../optimization/optimization_oca.h"
+#include "../matrix.h"
+#include "sparse_vector.h"
+#include "function.h"
+#include <algorithm>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename sample_type,
+ typename label_type
+ >
+ class multiclass_svm_problem : public structural_svm_problem_threaded<matrix_type,
+ std::vector<std::pair<unsigned long,typename matrix_type::type> > >
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the optimization problem for the multiclass SVM trainer
+ object at the bottom of this file.
+
+ The joint feature vectors used by this object, the PSI(x,y) vectors, are
+ defined as follows:
+ PSI(x,0) = [x,0,0,0,0, ...,0]
+ PSI(x,1) = [0,x,0,0,0, ...,0]
+ PSI(x,2) = [0,0,x,0,0, ...,0]
+ That is, if there are N labels then the joint feature vector has a
+ dimension that is N times the dimension of a single x sample. Also,
+ note that we append a -1 value onto each x to account for the bias term.
+ !*/
+
+ public:
+ typedef typename matrix_type::type scalar_type;
+ typedef std::vector<std::pair<unsigned long,scalar_type> > feature_vector_type;
+
+ multiclass_svm_problem (
+ const std::vector<sample_type>& samples_,
+ const std::vector<label_type>& labels_,
+ const std::vector<label_type>& distinct_labels_,
+ const unsigned long dims_,
+ const unsigned long num_threads
+ ) :
+ structural_svm_problem_threaded<matrix_type, std::vector<std::pair<unsigned long,typename matrix_type::type> > >(num_threads),
+ samples(samples_),
+ labels(labels_),
+ distinct_labels(distinct_labels_),
+ dims(dims_+1) // +1 for the bias
+ {}
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return dims*distinct_labels.size();
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return static_cast<long>(samples.size());
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ assign(psi, samples[idx]);
+ // Add a constant -1 to account for the bias term.
+ psi.push_back(std::make_pair(dims-1,static_cast<scalar_type>(-1)));
+
+ // Find which distinct label goes with this psi.
+ long label_idx = 0;
+ for (unsigned long i = 0; i < distinct_labels.size(); ++i)
+ {
+ if (distinct_labels[i] == labels[idx])
+ {
+ label_idx = i;
+ break;
+ }
+ }
+
+ offset_feature_vector(psi, dims*label_idx);
+ }
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ scalar_type best_val = -std::numeric_limits<scalar_type>::infinity();
+ unsigned long best_idx = 0;
+
+ // Figure out which label is the best. That is, what label maximizes
+ // LOSS(idx,y) + F(x,y). Note that y in this case is given by distinct_labels[i].
+ for (unsigned long i = 0; i < distinct_labels.size(); ++i)
+ {
+ // Compute the F(x,y) part:
+ // perform: temp == dot(relevant part of current solution, samples[idx]) - current_bias
+ scalar_type temp = dot(mat(&current_solution(i*dims),dims-1), samples[idx]) - current_solution((i+1)*dims-1);
+
+ // Add the LOSS(idx,y) part:
+ if (labels[idx] != distinct_labels[i])
+ temp += 1;
+
+ // Now temp == LOSS(idx,y) + F(x,y). Check if it is the biggest we have seen.
+ if (temp > best_val)
+ {
+ best_val = temp;
+ best_idx = i;
+ }
+ }
+
+ assign(psi, samples[idx]);
+ // add a constant -1 to account for the bias term
+ psi.push_back(std::make_pair(dims-1,static_cast<scalar_type>(-1)));
+
+ offset_feature_vector(psi, dims*best_idx);
+
+ if (distinct_labels[best_idx] == labels[idx])
+ loss = 0;
+ else
+ loss = 1;
+ }
+
+ private:
+
+ void offset_feature_vector (
+ feature_vector_type& sample,
+ const unsigned long val
+ ) const
+ {
+ if (val != 0)
+ {
+ for (typename feature_vector_type::iterator i = sample.begin(); i != sample.end(); ++i)
+ {
+ i->first += val;
+ }
+ }
+ }
+
+
+ const std::vector<sample_type>& samples;
+ const std::vector<label_type>& labels;
+ const std::vector<label_type>& distinct_labels;
+ const long dims;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class svm_multiclass_linear_trainer
+ {
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+
+ // You are getting a compiler error on this line because you supplied a non-linear kernel
+ // to the svm_c_linear_trainer object. You have to use one of the linear kernels with this
+ // trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+ svm_multiclass_linear_trainer (
+ ) :
+ num_threads(4),
+ C(1),
+ eps(0.001),
+ max_iterations(10000),
+ verbose(false),
+ learn_nonnegative_weights(false)
+ {
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svm_multiclass_linear_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ if (learn_nonnegative_weights)
+ prior = trained_function_type();
+ }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void svm_multiclass_linear_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ const scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ void set_prior (
+ const trained_function_type& prior_
+ )
+ {
+ prior = prior_;
+ learn_nonnegative_weights = false;
+ }
+
+ bool has_prior (
+ ) const
+ {
+ return prior.labels.size() != 0;
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ scalar_type svm_objective = 0;
+ return train(all_samples, all_labels, svm_objective);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ trained_function_type df;
+ df.labels = select_all_distinct_labels(all_labels);
+ if (has_prior())
+ {
+ df.labels.insert(df.labels.end(), prior.labels.begin(), prior.labels.end());
+ df.labels = select_all_distinct_labels(df.labels);
+ }
+ const long input_sample_dimensionality = max_index_plus_one(all_samples);
+ // If the samples are sparse then the right thing to do is to take the max
+ // dimensionality between the prior and the new samples. But if the samples
+ // are dense vectors then they definitely all have to have exactly the same
+ // dimensionality.
+ const long dims = std::max(df.weights.nc(),input_sample_dimensionality);
+ if (is_matrix<sample_type>::value && has_prior())
+ {
+ DLIB_ASSERT(input_sample_dimensionality == prior.weights.nc(),
+ "\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)"
+ << "\n\t The training samples given to this function are not the same kind of training "
+ << "\n\t samples used to create the prior."
+ << "\n\t input_sample_dimensionality: " << input_sample_dimensionality
+ << "\n\t prior.weights.nc(): " << prior.weights.nc()
+ );
+ }
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type weights;
+ multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels, df.labels, dims, num_threads);
+ if (verbose)
+ problem.be_verbose();
+
+ problem.set_max_cache_size(0);
+ problem.set_c(C);
+ problem.set_epsilon(eps);
+ problem.set_max_iterations(max_iterations);
+
+ unsigned long num_nonnegative = 0;
+ if (learn_nonnegative_weights)
+ {
+ num_nonnegative = problem.get_num_dimensions();
+ }
+
+ if (!has_prior())
+ {
+ svm_objective = solver(problem, weights, num_nonnegative);
+ }
+ else
+ {
+ matrix<scalar_type> temp(df.labels.size(),dims);
+ w_type b(df.labels.size());
+ temp = 0;
+ b = 0;
+
+ const long pad_size = dims-prior.weights.nc();
+ // Copy the prior into the temp and b matrices. We have to do this row
+ // by row copy because the new training data might have new labels we
+ // haven't seen before and therefore the sizes of these matrices could be
+ // different.
+ for (unsigned long i = 0; i < prior.labels.size(); ++i)
+ {
+ const long r = std::find(df.labels.begin(), df.labels.end(), prior.labels[i])-df.labels.begin();
+ set_rowm(temp,r) = join_rows(rowm(prior.weights,i), zeros_matrix<scalar_type>(1,pad_size));
+ b(r) = prior.b(i);
+ }
+
+ const w_type prior_vect = reshape_to_column_vector(join_rows(temp,b));
+ svm_objective = solver(problem, weights, prior_vect);
+ }
+
+
+ df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
+ df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
+ return df;
+ }
+
+ private:
+
+ unsigned long num_threads;
+ scalar_type C;
+ scalar_type eps;
+ unsigned long max_iterations;
+ bool verbose;
+ oca solver;
+ bool learn_nonnegative_weights;
+
+ trained_function_type prior;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h
new file mode 100644
index 000000000..6561ce7b2
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h
@@ -0,0 +1,275 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "sparse_kernel_abstract.h"
+#include "../optimization/optimization_oca_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class svm_multiclass_linear_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ Is either linear_kernel or sparse_linear_kernel.
+
+ REQUIREMENTS ON label_type_
+ label_type_ must be default constructable, copyable, and comparable using
+ operator < and ==. It must also be possible to write it to an std::ostream
+ using operator<<.
+
+ INITIAL VALUE
+ - get_num_threads() == 4
+ - learns_nonnegative_weights() == false
+ - get_epsilon() == 0.001
+ - get_max_iterations() == 10000
+ - get_c() == 1
+ - this object will not be verbose unless be_verbose() is called
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - has_prior() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for training a multiclass support
+ vector machine. It is optimized for the case where linear kernels
+ are used.
+ !*/
+
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+ svm_multiclass_linear_trainer (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer
+ to execute.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ unsigned long get_max_iterations (
+ );
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a
+ user can observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the SVM problem.
+ !*/
+
+ void set_num_threads (
+ unsigned long num
+ );
+ /*!
+ ensures
+ - #get_num_threads() == num
+ !*/
+
+ unsigned long get_num_threads (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads used during training. You should
+ usually set this equal to the number of processing cores on your
+ machine.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since
+ the linear kernels don't have any parameters this function just
+ returns kernel_type()
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ const scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data
+ exactly or allowing more errors but hopefully improving the
+ generalization of the resulting classifier. Larger values encourage
+ exact fitting while smaller values of C may encourage better
+ generalization.
+ !*/
+
+ bool learns_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - The output of training is a set of weights and bias values that together
+ define the behavior of a multiclass_linear_decision_function object. If
+ learns_nonnegative_weights() == true then the resulting weights and bias
+ values will always have non-negative values. That is, if this function
+ returns true then all the numbers in the multiclass_linear_decision_function
+ objects output by train() will be non-negative.
+ !*/
+
+ void set_learns_nonnegative_weights (
+ bool value
+ );
+ /*!
+ ensures
+ - #learns_nonnegative_weights() == value
+ - if (value == true) then
+ - #has_prior() == false
+ !*/
+
+ void set_prior (
+ const trained_function_type& prior
+ );
+ /*!
+ ensures
+ - Subsequent calls to train() will try to learn a function similar to the
+ given prior.
+ - #has_prior() == true
+ - #learns_nonnegative_weights() == false
+ !*/
+
+ bool has_prior (
+ ) const
+ /*!
+ ensures
+ - returns true if a prior has been set and false otherwise. Having a prior
+ set means that you have called set_prior() and supplied a previously
+ trained function as a reference. In this case, any call to train() will
+ try to learn a function that matches the behavior of the prior as close
+ as possible but also fits the supplied training data. In more technical
+ detail, having a prior means we replace the ||w||^2 regularizer with one
+ of the form ||w-prior||^2 where w is the set of parameters for a learned
+ function.
+ !*/
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(all_samples, all_labels)
+ - All the vectors in all_samples must have the same dimensionality.
+ - if (has_prior()) then
+ - The vectors in all_samples must have the same dimensionality as the
+ vectors used to train the prior given to set_prior().
+ ensures
+ - trains a multiclass SVM to solve the given multiclass classification problem.
+ - returns a multiclass_linear_decision_function F with the following properties:
+ - if (new_x is a sample predicted to have a label of L) then
+ - F(new_x) == L
+ - F.get_labels() == select_all_distinct_labels(all_labels)
+ - F.number_of_classes() == select_all_distinct_labels(all_labels).size()
+ !*/
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(all_samples, all_labels)
+ - All the vectors in all_samples must have the same dimensionality.
+ - if (has_prior()) then
+ - The vectors in all_samples must have the same dimensionality as the
+ vectors used to train the prior given to set_prior().
+ ensures
+ - trains a multiclass SVM to solve the given multiclass classification problem.
+ - returns a multiclass_linear_decision_function F with the following properties:
+ - if (new_x is a sample predicted to have a label of L) then
+ - F(new_x) == L
+ - F.get_labels() == select_all_distinct_labels(all_labels)
+ - F.number_of_classes() == select_all_distinct_labels(all_labels).size()
+ - #svm_objective == the final value of the SVM objective function
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/svm_nu_trainer.h b/ml/dlib/dlib/svm/svm_nu_trainer.h
new file mode 100644
index 000000000..1e89d6efa
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_nu_trainer.h
@@ -0,0 +1,326 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_NU_TRAINER_Hh_
+#define DLIB_SVm_NU_TRAINER_Hh_
+
+//#include "local/make_label_kernel_matrix.h"
+
+#include "svm_nu_trainer_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+#include "../serialize.h"
+
+#include "function.h"
+#include "kernel.h"
+#include "../optimization/optimization_solve_qp2_using_smo.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_nu_trainer
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_nu_trainer (
+ ) :
+ nu(0.1),
+ cache_size(200),
+ eps(0.001)
+ {
+ }
+
+ svm_nu_trainer (
+ const kernel_type& kernel_,
+ const scalar_type& nu_
+ ) :
+ kernel_function(kernel_),
+ nu(nu_),
+ cache_size(200),
+ eps(0.001)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < nu && nu <= 1,
+ "\tsvm_nu_trainer::svm_nu_trainer(kernel,nu)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t nu: " << nu
+ );
+ }
+
+ void set_cache_size (
+ long cache_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cache_size_ > 0,
+ "\tvoid svm_nu_trainer::set_cache_size(cache_size_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t cache_size: " << cache_size_
+ );
+ cache_size = cache_size_;
+ }
+
+ long get_cache_size (
+ ) const
+ {
+ return cache_size;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svm_nu_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel_function = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel_function;
+ }
+
+ void set_nu (
+ scalar_type nu_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < nu_ && nu_ <= 1,
+ "\tvoid svm_nu_trainer::set_nu(nu_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t nu: " << nu_
+ );
+ nu = nu_;
+ }
+
+ const scalar_type get_nu (
+ ) const
+ {
+ return nu;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ return do_train(mat(x), mat(y));
+ }
+
+ void swap (
+ svm_nu_trainer& item
+ )
+ {
+ exchange(kernel_function, item.kernel_function);
+ exchange(nu, item.nu);
+ exchange(cache_size, item.cache_size);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ typedef typename K::scalar_type scalar_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
+ "\tdecision_function svm_nu_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
+ );
+
+
+ scalar_vector_type alpha;
+
+ solve_qp2_using_smo<scalar_vector_type> solver;
+
+ solver(symmetric_matrix_cache<float>((diagm(y)*kernel_matrix(kernel_function,x)*diagm(y)), cache_size),
+ //solver(symmetric_matrix_cache<float>(make_label_kernel_matrix(kernel_matrix(kernel_function,x),y), cache_size),
+ y,
+ nu,
+ alpha,
+ eps);
+
+ scalar_type rho, b;
+ calculate_rho_and_b(y,alpha,solver.get_gradient(),rho,b);
+ alpha = pointwise_multiply(alpha,y)/rho;
+
+ // count the number of support vectors
+ const long sv_count = (long)sum(alpha != 0);
+
+ scalar_vector_type sv_alpha;
+ sample_vector_type support_vectors;
+
+ // size these column vectors so that they have an entry for each support vector
+ sv_alpha.set_size(sv_count);
+ support_vectors.set_size(sv_count);
+
+ // load the support vectors and their alpha values into these new column matrices
+ long idx = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (alpha(i) != 0)
+ {
+ sv_alpha(idx) = alpha(i);
+ support_vectors(idx) = x(i);
+ ++idx;
+ }
+ }
+
+ // now return the decision function
+ return decision_function<K> (sv_alpha, b, kernel_function, support_vectors);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type,
+ typename scalar_vector_type2,
+ typename scalar_type
+ >
+ void calculate_rho_and_b(
+ const scalar_vector_type2& y,
+ const scalar_vector_type& alpha,
+ const scalar_vector_type& df,
+ scalar_type& rho,
+ scalar_type& b
+ ) const
+ {
+ using namespace std;
+ long num_p_free = 0;
+ long num_n_free = 0;
+ scalar_type sum_p_free = 0;
+ scalar_type sum_n_free = 0;
+
+ scalar_type upper_bound_p = -numeric_limits<scalar_type>::infinity();
+ scalar_type upper_bound_n = -numeric_limits<scalar_type>::infinity();
+ scalar_type lower_bound_p = numeric_limits<scalar_type>::infinity();
+ scalar_type lower_bound_n = numeric_limits<scalar_type>::infinity();
+
+ for(long i = 0; i < alpha.nr(); ++i)
+ {
+ if(y(i) == 1)
+ {
+ if(alpha(i) == 1)
+ {
+ if (df(i) > upper_bound_p)
+ upper_bound_p = df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (df(i) < lower_bound_p)
+ lower_bound_p = df(i);
+ }
+ else
+ {
+ ++num_p_free;
+ sum_p_free += df(i);
+ }
+ }
+ else
+ {
+ if(alpha(i) == 1)
+ {
+ if (df(i) > upper_bound_n)
+ upper_bound_n = df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (df(i) < lower_bound_n)
+ lower_bound_n = df(i);
+ }
+ else
+ {
+ ++num_n_free;
+ sum_n_free += df(i);
+ }
+ }
+ }
+
+ scalar_type r1,r2;
+ if(num_p_free > 0)
+ r1 = sum_p_free/num_p_free;
+ else
+ r1 = (upper_bound_p+lower_bound_p)/2;
+
+ if(num_n_free > 0)
+ r2 = sum_n_free/num_n_free;
+ else
+ r2 = (upper_bound_n+lower_bound_n)/2;
+
+ rho = (r1+r2)/2;
+ b = (r1-r2)/2/rho;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ kernel_type kernel_function;
+ scalar_type nu;
+ long cache_size;
+ scalar_type eps;
+ }; // end of class svm_nu_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ svm_nu_trainer<K>& a,
+ svm_nu_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_NU_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h b/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h
new file mode 100644
index 000000000..5ae0fba4a
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h
@@ -0,0 +1,210 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_NU_TRAINER_ABSTRACT_
+#ifdef DLIB_SVm_NU_TRAINER_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../optimization/optimization_solve_qp2_using_smo_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_nu_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a nu support vector machine for
+ solving binary classification problems. It is implemented using the SMO
+ algorithm.
+
+ The implementation of the nu-svm training algorithm used by this object is based
+ on the following excellent papers:
+ - Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_nu_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_nu() == 0.1
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ svm_nu_trainer (
+ const kernel_type& kernel,
+ const scalar_type& nu
+ );
+ /*!
+ requires
+ - 0 < nu <= 1
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_kernel() == kernel
+ - #get_nu() == nu
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ void set_cache_size (
+ long cache_size
+ );
+ /*!
+ requires
+ - cache_size > 0
+ ensures
+ - #get_cache_size() == cache_size
+ !*/
+
+ const long get_cache_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of megabytes of cache this object will use
+ when it performs training via the this->train() function.
+ (bigger values of this may make training go faster but won't affect
+ the result. However, too big a value will cause you to run out of
+ memory, obviously.)
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_nu (
+ scalar_type nu
+ );
+ /*!
+ requires
+ - 0 < nu <= 1
+ ensures
+ - #get_nu() == nu
+ !*/
+
+ const scalar_type get_nu (
+ ) const;
+ /*!
+ ensures
+ - returns the nu svm parameter. This is a value between 0 and
+ 1. It is the parameter that determines the trade off between
+ trying to fit the training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Smaller values encourage exact fitting
+ while larger values of nu may encourage better generalization.
+ For more information you should consult the papers referenced
+ above.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - trains a nu support vector classifier given the training samples in x and
+ labels in y. Training is done when the error is less than get_epsilon().
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted have +1 label) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ throws
+ - invalid_nu_error
+ This exception is thrown if get_nu() >= maximum_nu(y)
+ - std::bad_alloc
+ !*/
+
+ void swap (
+ svm_nu_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <typename K>
+ void swap (
+ svm_nu_trainer<K>& a,
+ svm_nu_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_NU_TRAINER_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/svm_one_class_trainer.h b/ml/dlib/dlib/svm/svm_one_class_trainer.h
new file mode 100644
index 000000000..be3cc8caf
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_one_class_trainer.h
@@ -0,0 +1,284 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_ONE_CLASS_TRAINER_Hh_
+#define DLIB_SVm_ONE_CLASS_TRAINER_Hh_
+
+#include "svm_one_class_trainer_abstract.h"
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include "../matrix.h"
+#include "../algs.h"
+
+#include "function.h"
+#include "kernel.h"
+#include "../optimization/optimization_solve_qp3_using_smo.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_one_class_trainer
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_one_class_trainer (
+ ) :
+ nu(0.1),
+ cache_size(200),
+ eps(0.001)
+ {
+ }
+
+ svm_one_class_trainer (
+ const kernel_type& kernel_,
+ const scalar_type& nu_
+ ) :
+ kernel_function(kernel_),
+ nu(nu_),
+ cache_size(200),
+ eps(0.001)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < nu && nu <= 1,
+ "\tsvm_one_class_trainer::svm_one_class_trainer(kernel,nu)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t nu: " << nu
+ );
+ }
+
+ void set_cache_size (
+ long cache_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cache_size_ > 0,
+ "\tvoid svm_one_class_trainer::set_cache_size(cache_size_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t cache_size: " << cache_size_
+ );
+ cache_size = cache_size_;
+ }
+
+ long get_cache_size (
+ ) const
+ {
+ return cache_size;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svm_one_class_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel_function = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel_function;
+ }
+
+ void set_nu (
+ scalar_type nu_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(0 < nu_ && nu_ <= 1,
+ "\tvoid svm_one_class_trainer::set_nu(nu_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t nu: " << nu_
+ );
+ nu = nu_;
+ }
+
+ const scalar_type get_nu (
+ ) const
+ {
+ return nu;
+ }
+
+ template <
+ typename in_sample_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x
+ ) const
+ {
+ return do_train(mat(x));
+ }
+
+ void swap (
+ svm_one_class_trainer& item
+ )
+ {
+ exchange(kernel_function, item.kernel_function);
+ exchange(nu, item.nu);
+ exchange(cache_size, item.cache_size);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x
+ ) const
+ {
+ typedef typename K::scalar_type scalar_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_col_vector(x) && x.size() > 0,
+ "\tdecision_function svm_one_class_trainer::train(x)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t x.nc(): " << x.nc()
+ );
+
+
+ scalar_vector_type alpha;
+
+ solve_qp3_using_smo<scalar_vector_type> solver;
+
+ solver(symmetric_matrix_cache<float>(kernel_matrix(kernel_function,x), cache_size),
+ zeros_matrix<scalar_type>(x.size(),1),
+ ones_matrix<scalar_type>(x.size(),1),
+ nu*x.size(),
+ 1,
+ 1,
+ alpha,
+ eps);
+
+ scalar_type rho;
+ calculate_rho(alpha,solver.get_gradient(),rho);
+
+
+ // count the number of support vectors
+ const long sv_count = (long)sum(alpha != 0);
+
+ scalar_vector_type sv_alpha;
+ sample_vector_type support_vectors;
+
+ // size these column vectors so that they have an entry for each support vector
+ sv_alpha.set_size(sv_count);
+ support_vectors.set_size(sv_count);
+
+ // load the support vectors and their alpha values into these new column matrices
+ long idx = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (alpha(i) != 0)
+ {
+ sv_alpha(idx) = alpha(i);
+ support_vectors(idx) = x(i);
+ ++idx;
+ }
+ }
+
+ // now return the decision function
+ return decision_function<K> (sv_alpha, rho, kernel_function, support_vectors);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type
+ >
+ void calculate_rho(
+ const scalar_vector_type& alpha,
+ const scalar_vector_type& df,
+ scalar_type& rho
+ ) const
+ {
+ using namespace std;
+ long num_p_free = 0;
+ scalar_type sum_p_free = 0;
+
+
+ scalar_type upper_bound_p;
+ scalar_type lower_bound_p;
+
+ find_min_and_max(df, upper_bound_p, lower_bound_p);
+
+ for(long i = 0; i < alpha.nr(); ++i)
+ {
+ if(alpha(i) == 1)
+ {
+ if (df(i) > upper_bound_p)
+ upper_bound_p = df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (df(i) < lower_bound_p)
+ lower_bound_p = df(i);
+ }
+ else
+ {
+ ++num_p_free;
+ sum_p_free += df(i);
+ }
+ }
+
+ scalar_type r1;
+ if(num_p_free > 0)
+ r1 = sum_p_free/num_p_free;
+ else
+ r1 = (upper_bound_p+lower_bound_p)/2;
+
+ rho = r1;
+ }
+
+ kernel_type kernel_function;
+ scalar_type nu;
+ long cache_size;
+ scalar_type eps;
+ }; // end of class svm_one_class_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ svm_one_class_trainer<K>& a,
+ svm_one_class_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_ONE_CLASS_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h b/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h
new file mode 100644
index 000000000..6b55919ad
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h
@@ -0,0 +1,201 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_
+#ifdef DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../optimization/optimization_solve_qp3_using_smo_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_one_class_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a support vector machine for
+ solving one-class classification problems. It is implemented using the SMO
+ algorithm.
+
+ The implementation of the training algorithm used by this object is based
+ on the following excellent paper:
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_one_class_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_nu() == 0.1
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ svm_one_class_trainer (
+ const kernel_type& kernel,
+ const scalar_type& nu
+ );
+ /*!
+ requires
+ - 0 < nu <= 1
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_kernel() == kernel
+ - #get_nu() == nu
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ void set_cache_size (
+ long cache_size
+ );
+ /*!
+ requires
+ - cache_size > 0
+ ensures
+ - #get_cache_size() == cache_size
+ !*/
+
+ const long get_cache_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of megabytes of cache this object will use
+ when it performs training via the this->train() function.
+ (bigger values of this may make training go faster but won't affect
+ the result. However, too big a value will cause you to run out of
+ memory, obviously.)
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_nu (
+ scalar_type nu
+ );
+ /*!
+ requires
+ - 0 < nu <= 1
+ ensures
+ - #get_nu() == nu
+ !*/
+
+ const scalar_type get_nu (
+ ) const;
+ /*!
+ ensures
+ - returns the nu svm parameter. This is a value between 0 and
+ 1. It is the parameter that determines the trade off between
+ trying to fit the training data exactly or allowing more errors
+ but hopefully improving the generalization ability of the
+ resulting classifier. Smaller values encourage exact fitting
+ while larger values of nu may encourage better generalization.
+ For more information you should consult the papers referenced
+ above.
+ !*/
+
+ template <
+ typename in_sample_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x
+ ) const;
+ /*!
+ requires
+ - x.size() > 0
+ - is_col_vector(x) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ ensures
+ - trains a one-class support vector classifier given the training samples in x.
+ Training is done when the error is less than get_epsilon().
+ - returns a decision function F with the following properties:
+ - if (new_x is a sample predicted to arise from the distribution
+ which generated the training samples) then
+ - F(new_x) >= 0
+ - else
+ - F(new_x) < 0
+ !*/
+
+ void swap (
+ svm_one_class_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <typename K>
+ void swap (
+ svm_one_class_trainer<K>& a,
+ svm_one_class_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/svm_rank_trainer.h b/ml/dlib/dlib/svm/svm_rank_trainer.h
new file mode 100644
index 000000000..0be737f48
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_rank_trainer.h
@@ -0,0 +1,495 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVM_RANK_TrAINER_Hh_
+#define DLIB_SVM_RANK_TrAINER_Hh_
+
+#include "svm_rank_trainer_abstract.h"
+
+#include "ranking_tools.h"
+#include "../algs.h"
+#include "../optimization.h"
+#include "function.h"
+#include "kernel.h"
+#include "sparse_vector.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename sample_type
+ >
+ class oca_problem_ranking_svm : public oca_problem<matrix_type >
+ {
+ public:
+ /*
+ This class is used as part of the implementation of the svm_rank_trainer
+ defined towards the end of this file.
+ */
+
+ typedef typename matrix_type::type scalar_type;
+
+ oca_problem_ranking_svm(
+ const scalar_type C_,
+ const std::vector<ranking_pair<sample_type> >& samples_,
+ const bool be_verbose_,
+ const scalar_type eps_,
+ const unsigned long max_iter,
+ const unsigned long dims_
+ ) :
+ samples(samples_),
+ C(C_),
+ be_verbose(be_verbose_),
+ eps(eps_),
+ max_iterations(max_iter),
+ dims(dims_)
+ {
+ }
+
+ virtual scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return dims;
+ }
+
+ virtual bool optimization_status (
+ scalar_type current_objective_value,
+ scalar_type current_error_gap,
+ scalar_type current_risk_value,
+ scalar_type current_risk_gap,
+ unsigned long num_cutting_planes,
+ unsigned long num_iterations
+ ) const
+ {
+ if (be_verbose)
+ {
+ using namespace std;
+ cout << "objective: " << current_objective_value << endl;
+ cout << "objective gap: " << current_error_gap << endl;
+ cout << "risk: " << current_risk_value << endl;
+ cout << "risk gap: " << current_risk_gap << endl;
+ cout << "num planes: " << num_cutting_planes << endl;
+ cout << "iter: " << num_iterations << endl;
+ cout << endl;
+ }
+
+ if (num_iterations >= max_iterations)
+ return true;
+
+ if (current_risk_gap < eps)
+ return true;
+
+ return false;
+ }
+
+ virtual bool risk_has_lower_bound (
+ scalar_type& lower_bound
+ ) const
+ {
+ lower_bound = 0;
+ return true;
+ }
+
+ virtual void get_risk (
+ matrix_type& w,
+ scalar_type& risk,
+ matrix_type& subgradient
+ ) const
+ {
+ subgradient.set_size(w.size(),1);
+ subgradient = 0;
+ risk = 0;
+
+ // Note that we want the risk value to be in terms of the fraction of overall
+ // rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the
+ // time.
+
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+
+ unsigned long total_pairs = 0;
+
+ // loop over all the samples and compute the risk and its subgradient at the current solution point w
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ rel_scores.resize(samples[i].relevant.size());
+ nonrel_scores.resize(samples[i].nonrelevant.size());
+
+ for (unsigned long k = 0; k < rel_scores.size(); ++k)
+ rel_scores[k] = dot(samples[i].relevant[k], w);
+
+ for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+ nonrel_scores[k] = dot(samples[i].nonrelevant[k], w) + 1;
+
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+
+ total_pairs += rel_scores.size()*nonrel_scores.size();
+
+ for (unsigned long k = 0; k < rel_counts.size(); ++k)
+ {
+ if (rel_counts[k] != 0)
+ {
+ risk -= rel_counts[k]*rel_scores[k];
+ subtract_from(subgradient, samples[i].relevant[k], rel_counts[k]);
+ }
+ }
+
+ for (unsigned long k = 0; k < nonrel_counts.size(); ++k)
+ {
+ if (nonrel_counts[k] != 0)
+ {
+ risk += nonrel_counts[k]*nonrel_scores[k];
+ add_to(subgradient, samples[i].nonrelevant[k], nonrel_counts[k]);
+ }
+ }
+
+ }
+
+ const scalar_type scale = 1.0/total_pairs;
+
+ risk *= scale;
+ subgradient = scale*subgradient;
+ }
+
+ private:
+
+ // -----------------------------------------------------
+ // -----------------------------------------------------
+
+
+ const std::vector<ranking_pair<sample_type> >& samples;
+ const scalar_type C;
+
+ const bool be_verbose;
+ const scalar_type eps;
+ const unsigned long max_iterations;
+ const unsigned long dims;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename sample_type,
+ typename scalar_type
+ >
+ oca_problem_ranking_svm<matrix_type, sample_type> make_oca_problem_ranking_svm (
+ const scalar_type C,
+ const std::vector<ranking_pair<sample_type> >& samples,
+ const bool be_verbose,
+ const scalar_type eps,
+ const unsigned long max_iterations,
+ const unsigned long dims
+ )
+ {
+ return oca_problem_ranking_svm<matrix_type, sample_type>(
+ C, samples, be_verbose, eps, max_iterations, dims);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_rank_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear kernel
+ // to the svm_rank_trainer object. You have to use one of the linear kernels with this
+ // trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+ svm_rank_trainer (
+ )
+ {
+ C = 1;
+ verbose = false;
+ eps = 0.001;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ explicit svm_rank_trainer (
+ const scalar_type& C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t svm_rank_trainer::svm_rank_trainer()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ verbose = false;
+ eps = 0.001;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svm_rank_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const { return eps; }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ bool forces_last_weight_to_1 (
+ ) const
+ {
+ return last_weight_1;
+ }
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ )
+ {
+ last_weight_1 = should_last_weight_be_1;
+ if (last_weight_1)
+ prior.set_size(0);
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ if (learn_nonnegative_weights)
+ prior.set_size(0);
+ }
+
+ void set_prior (
+ const trained_function_type& prior_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(prior_.basis_vectors.size() == 1 &&
+ prior_.alpha(0) == 1,
+ "\t void svm_rank_trainer::set_prior()"
+ << "\n\t The supplied prior could not have been created by this object's train() method."
+ << "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size()
+ << "\n\t prior_.alpha(0): " << prior_.alpha(0)
+ << "\n\t this: " << this
+ );
+
+ prior = sparse_to_dense(prior_.basis_vectors(0));
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ }
+
+ bool has_prior (
+ ) const
+ {
+ return prior.size() != 0;
+ }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void svm_rank_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ const scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ const decision_function<kernel_type> train (
+ const std::vector<ranking_pair<sample_type> >& samples
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_ranking_problem(samples) == true,
+ "\t decision_function svm_rank_trainer::train(samples)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
+ );
+
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type w;
+
+ const unsigned long num_dims = max_index_plus_one(samples);
+
+ unsigned long num_nonnegative = 0;
+ if (learn_nonnegative_weights)
+ {
+ num_nonnegative = num_dims;
+ }
+
+ unsigned long force_weight_1_idx = std::numeric_limits<unsigned long>::max();
+ if (last_weight_1)
+ {
+ force_weight_1_idx = num_dims-1;
+ }
+
+ if (has_prior())
+ {
+ if (is_matrix<sample_type>::value)
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(num_dims == (unsigned long)prior.size(),
+ "\t decision_function svm_rank_trainer::train(samples)"
+ << "\n\t The dimension of the training vectors must match the dimension of\n"
+ << "\n\t those used to create the prior."
+ << "\n\t num_dims: " << num_dims
+ << "\n\t prior.size(): " << prior.size()
+ );
+ }
+ const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
+ // In the case of sparse sample vectors, it is possible that the input
+ // vector dimensionality is larger than the prior vector dimensionality.
+ // We need to check for this case and pad prior with zeros if it is the
+ // case.
+ if ((unsigned long)prior.size() < dims)
+ {
+ matrix<scalar_type,0,1> prior_temp = join_cols(prior, zeros_matrix<scalar_type>(dims-prior.size(),1));
+ solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
+ w,
+ prior_temp);
+ }
+ else
+ {
+ solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
+ w,
+ prior);
+ }
+
+ }
+ else
+ {
+ solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, num_dims),
+ w,
+ num_nonnegative,
+ force_weight_1_idx);
+ }
+
+
+ // put the solution into a decision function and then return it
+ decision_function<kernel_type> df;
+ df.b = 0;
+ df.basis_vectors.set_size(1);
+ // Copy the results into the output basis vector. The output vector might be a
+ // sparse vector container so we need to use this special kind of copy to
+ // handle that case.
+ assign(df.basis_vectors(0), matrix_cast<scalar_type>(w));
+ df.alpha.set_size(1);
+ df.alpha(0) = 1;
+
+ return df;
+ }
+
+ const decision_function<kernel_type> train (
+ const ranking_pair<sample_type>& sample
+ ) const
+ {
+ return train(std::vector<ranking_pair<sample_type> >(1, sample));
+ }
+
+ private:
+
+ scalar_type C;
+ oca solver;
+ scalar_type eps;
+ bool verbose;
+ unsigned long max_iterations;
+ bool learn_nonnegative_weights;
+ bool last_weight_1;
+ matrix<scalar_type,0,1> prior;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVM_RANK_TrAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h b/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h
new file mode 100644
index 000000000..4658d950f
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h
@@ -0,0 +1,298 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_
+
+#include "ranking_tools_abstract.h"
+#include "sparse_vector_abstract.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svm_rank_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ Is either linear_kernel or sparse_linear_kernel.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a tool for training a ranking support vector machine
+ using linear kernels. In particular, this object is a tool for training
+ the Ranking SVM described in the paper:
+ Optimizing Search Engines using Clickthrough Data by Thorsten Joachims
+
+ Note that we normalize the C parameter by multiplying it by 1/(number of ranking pairs).
+ Therefore, to make an exact comparison between this object and Equation 12
+ in the paper you must multiply C by the appropriate normalizing quantity.
+
+ Finally, note that the implementation of this object is done using the oca
+ optimizer and count_ranking_inversions() method. This means that it runs
+ in O(n*log(n)) time, making it suitable for use with large datasets.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svm_rank_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ ranking support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c() == 1
+ - #get_epsilon() == 0.001
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ - #has_prior() == false
+ !*/
+
+ explicit svm_rank_trainer (
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ ranking support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c() == C
+ - #get_epsilon() == 0.001
+ - this object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ - #has_prior() == false
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ );
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average ranking accuracy is within epsilon
+ of its optimal value". Here we mean "ranking accuracy" in the same sense
+ used by test_ranking_function() and cross_validate_ranking_trainer().
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a user can
+ observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ bool forces_last_weight_to_1 (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer has the constraint that the last weight in
+ the learned parameter vector must be 1. This is the weight corresponding
+ to the feature in the training vectors with the highest dimension.
+ !*/
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ );
+ /*!
+ ensures
+ - #forces_last_weight_to_1() == should_last_weight_be_1
+ - if (should_last_weight_be_1 == true) then
+ - #has_prior() == false
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the SVM problem.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since the
+ linear kernels don't have any parameters this function just returns
+ kernel_type()
+ !*/
+
+ bool learns_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - The output of training is a weight vector that defines the behavior of
+ the resulting decision function. That is, the decision function simply
+ takes the dot product between the learned weight vector and a test sample
+ and returns the result. Therefore, if learns_nonnegative_weights() == true
+ then the resulting learned weight vector will always have non-negative
+ entries.
+ !*/
+
+ void set_learns_nonnegative_weights (
+ bool value
+ );
+ /*!
+ ensures
+ - #learns_nonnegative_weights() == value
+ - if (value == true) then
+ - #has_prior() == false
+ !*/
+
+ void set_prior (
+ const trained_function_type& prior
+ );
+ /*!
+ requires
+ - prior == a function produced by a call to this class's train() function.
+ Therefore, it must be the case that:
+ - prior.basis_vectors.size() == 1
+ - prior.alpha(0) == 1
+ ensures
+ - Subsequent calls to train() will try to learn a function similar to the
+ given prior.
+ - #has_prior() == true
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ bool has_prior (
+ ) const
+ /*!
+ ensures
+ - returns true if a prior has been set and false otherwise. Having a prior
+ set means that you have called set_prior() and supplied a previously
+ trained function as a reference. In this case, any call to train() will
+ try to learn a function that matches the behavior of the prior as close
+ as possible but also fits the supplied training data. In more technical
+ detail, having a prior means we replace the ||w||^2 regularizer with one
+ of the form ||w-prior||^2 where w is the set of parameters for a learned
+ function.
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ const scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data exactly
+ or allowing more errors but hopefully improving the generalization of the
+ resulting classifier. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ const decision_function<kernel_type> train (
+ const std::vector<ranking_pair<sample_type> >& samples
+ ) const;
+ /*!
+ requires
+ - is_ranking_problem(samples) == true
+ - if (has_prior()) then
+ - The vectors in samples must have the same dimensionality as the
+ vectors used to train the prior given to set_prior().
+ ensures
+ - trains a ranking support vector classifier given the training samples.
+ - returns a decision function F with the following properties:
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ - Given two vectors, A and B, then A is predicted to come before B
+ in the learned ranking if and only if F(A) > F(B).
+ - Based on the contents of samples, F will attempt to give relevant
+ vectors higher scores than non-relevant vectors.
+ !*/
+
+ const decision_function<kernel_type> train (
+ const ranking_pair<sample_type>& sample
+ ) const;
+ /*!
+ requires
+ - is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
+ - if (has_prior()) then
+ - The vectors in samples must have the same dimensionality as the
+ vectors used to train the prior given to set_prior().
+ ensures
+ - This is just a convenience routine for calling the above train()
+ function. That is, it just copies sample into a std::vector object and
+ invokes the above train() method. This means that calling this function
+ is equivalent to invoking:
+ return train(std::vector<ranking_pair<sample_type> >(1, sample));
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/svm/svm_threaded.h b/ml/dlib/dlib/svm/svm_threaded.h
new file mode 100644
index 000000000..37927456b
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_threaded.h
@@ -0,0 +1,253 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_THREADED_
+#define DLIB_SVm_THREADED_
+
+#include <cmath>
+#include <iostream>
+#include <limits>
+#include <sstream>
+#include <vector>
+
+#include "svm_threaded_abstract.h"
+#include "svm.h"
+#include "../matrix.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "function.h"
+#include "kernel.h"
+#include "../threads.h"
+#include "../pipe.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace cvtti_helpers
+ {
+ template <typename trainer_type, typename in_sample_vector_type>
+ struct job
+ {
+ typedef typename trainer_type::scalar_type scalar_type;
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+ typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
+ typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+
+ job() : x(0) {}
+
+ trainer_type trainer;
+ matrix<long,0,1> x_test, x_train;
+ scalar_vector_type y_test, y_train;
+ const in_sample_vector_type* x;
+ };
+
+ struct task
+ {
+ template <
+ typename trainer_type,
+ typename mem_manager_type,
+ typename in_sample_vector_type
+ >
+ void operator()(
+ job<trainer_type,in_sample_vector_type>& j,
+ matrix<double,1,2,mem_manager_type>& result
+ )
+ {
+ try
+ {
+ result = test_binary_decision_function(j.trainer.train(rowm(*j.x,j.x_train), j.y_train), rowm(*j.x,j.x_test), j.y_test);
+
+ // Do this just to make j release it's memory since people might run threaded cross validation
+ // on very large datasets. Every bit of freed memory helps out.
+ j = job<trainer_type,in_sample_vector_type>();
+ }
+ catch (invalid_nu_error&)
+ {
+ // If this is a svm_nu_trainer then we might get this exception if the nu is
+ // invalid. In this case just return a cross validation score of 0.
+ result = 0;
+ }
+ catch (std::bad_alloc&)
+ {
+ std::cerr << "\nstd::bad_alloc thrown while running cross_validate_trainer_threaded(). Not enough memory.\n" << std::endl;
+ throw;
+ }
+ }
+ };
+ }
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer_threaded_impl (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds,
+ const long num_threads
+ )
+ {
+ using namespace dlib::cvtti_helpers;
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_binary_classification_problem(x,y) == true &&
+ 1 < folds && folds <= std::min(sum(y>0),sum(y<0)) &&
+ num_threads > 0,
+ "\tmatrix cross_validate_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t std::min(sum(y>0),sum(y<0)): " << std::min(sum(y>0),sum(y<0))
+ << "\n\t folds: " << folds
+ << "\n\t num_threads: " << num_threads
+ << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
+ );
+
+
+ task mytask;
+ thread_pool tp(num_threads);
+
+
+ // count the number of positive and negative examples
+ long num_pos = 0;
+ long num_neg = 0;
+ for (long r = 0; r < y.nr(); ++r)
+ {
+ if (y(r) == +1.0)
+ ++num_pos;
+ else
+ ++num_neg;
+ }
+
+ // figure out how many positive and negative examples we will have in each fold
+ const long num_pos_test_samples = num_pos/folds;
+ const long num_pos_train_samples = num_pos - num_pos_test_samples;
+ const long num_neg_test_samples = num_neg/folds;
+ const long num_neg_train_samples = num_neg - num_neg_test_samples;
+
+
+ long pos_idx = 0;
+ long neg_idx = 0;
+
+
+
+ std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
+ std::vector<future<matrix<double, 1, 2, mem_manager_type> > > results(folds);
+
+
+ for (long i = 0; i < folds; ++i)
+ {
+ job<trainer_type,in_sample_vector_type>& j = jobs[i].get();
+
+ j.x = &x;
+ j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
+ j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
+ j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
+ j.y_train.set_size(num_pos_train_samples + num_neg_train_samples);
+ j.trainer = trainer;
+
+ long cur = 0;
+
+ // load up our positive test samples
+ while (cur < num_pos_test_samples)
+ {
+ if (y(pos_idx) == +1.0)
+ {
+ j.x_test(cur) = pos_idx;
+ j.y_test(cur) = +1.0;
+ ++cur;
+ }
+ pos_idx = (pos_idx+1)%x.nr();
+ }
+
+ // load up our negative test samples
+ while (cur < j.x_test.nr())
+ {
+ if (y(neg_idx) == -1.0)
+ {
+ j.x_test(cur) = neg_idx;
+ j.y_test(cur) = -1.0;
+ ++cur;
+ }
+ neg_idx = (neg_idx+1)%x.nr();
+ }
+
+ // load the training data from the data following whatever we loaded
+ // as the testing data
+ long train_pos_idx = pos_idx;
+ long train_neg_idx = neg_idx;
+ cur = 0;
+
+ // load up our positive train samples
+ while (cur < num_pos_train_samples)
+ {
+ if (y(train_pos_idx) == +1.0)
+ {
+ j.x_train(cur) = train_pos_idx;
+ j.y_train(cur) = +1.0;
+ ++cur;
+ }
+ train_pos_idx = (train_pos_idx+1)%x.nr();
+ }
+
+ // load up our negative train samples
+ while (cur < j.x_train.nr())
+ {
+ if (y(train_neg_idx) == -1.0)
+ {
+ j.x_train(cur) = train_neg_idx;
+ j.y_train(cur) = -1.0;
+ ++cur;
+ }
+ train_neg_idx = (train_neg_idx+1)%x.nr();
+ }
+
+ // finally spawn a task to process this job
+ tp.add_task(mytask, jobs[i], results[i]);
+
+ } // for (long i = 0; i < folds; ++i)
+
+ matrix<double, 1, 2, mem_manager_type> res;
+ set_all_elements(res,0);
+
+ // now compute the total results
+ for (long i = 0; i < folds; ++i)
+ {
+ res += results[i].get();
+ }
+
+ return res/(double)folds;
+ }
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer_threaded (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds,
+ const long num_threads
+ )
+ {
+ return cross_validate_trainer_threaded_impl(trainer,
+ mat(x),
+ mat(y),
+ folds,
+ num_threads);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_THREADED_
+
+
diff --git a/ml/dlib/dlib/svm/svm_threaded_abstract.h b/ml/dlib/dlib/svm/svm_threaded_abstract.h
new file mode 100644
index 000000000..f9973fb5c
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_threaded_abstract.h
@@ -0,0 +1,62 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_THREADED_ABSTRACT_
+#ifdef DLIB_SVm_THREADED_ABSTRACT_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../svm.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer_threaded (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds,
+ const long num_threads
+ );
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - 1 < folds <= std::min(sum(y>0),sum(y<0))
+ (e.g. There must be at least as many examples of each class as there are folds)
+ - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
+ - num_threads > 0
+ - It must be safe for multiple trainer objects to access the elements of x from
+ multiple threads at the same time. Note that all trainers and kernels in
+ dlib are thread safe in this regard since they do not mutate the elements of x.
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given binary classification problem for the given number of folds.
+ Each fold is tested using the output of the trainer and the average
+ classification accuracy from all folds is returned.
+ - uses num_threads threads of execution in doing the cross validation.
+ - The accuracy is returned in a row vector, let us call it R. Both
+ quantities in R are numbers between 0 and 1 which represent the fraction
+ of examples correctly classified. R(0) is the fraction of +1 examples
+ correctly classified and R(1) is the fraction of -1 examples correctly
+ classified.
+ - The number of folds used is given by the folds argument.
+ throws
+ - any exceptions thrown by trainer.train()
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_THREADED_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/svr_linear_trainer.h b/ml/dlib/dlib/svm/svr_linear_trainer.h
new file mode 100644
index 000000000..27ce5b52a
--- /dev/null
+++ b/ml/dlib/dlib/svm/svr_linear_trainer.h
@@ -0,0 +1,424 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVR_LINEAR_TrAINER_Hh_
+#define DLIB_SVR_LINEAR_TrAINER_Hh_
+
+#include "svr_linear_trainer_abstract.h"
+
+#include "../algs.h"
+#include "../optimization.h"
+#include "function.h"
+#include "kernel.h"
+#include "sparse_vector.h"
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename sample_type
+ >
+ class oca_problem_linear_svr : public oca_problem<matrix_type >
+ {
+ public:
+ /*
+ This class is used as part of the implementation of the svr_linear_trainer
+ defined towards the end of this file.
+ */
+
+ typedef typename matrix_type::type scalar_type;
+
+ oca_problem_linear_svr(
+ const scalar_type C_,
+ const std::vector<sample_type>& samples_,
+ const std::vector<scalar_type>& targets_,
+ const bool be_verbose_,
+ const scalar_type eps_,
+ const scalar_type eps_insensitivity_,
+ const unsigned long max_iter
+ ) :
+ samples(samples_),
+ targets(targets_),
+ C(C_),
+ be_verbose(be_verbose_),
+ eps(eps_),
+ eps_insensitivity(eps_insensitivity_),
+ max_iterations(max_iter)
+ {
+ }
+
+ virtual scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ // plus one for the bias term
+ return max_index_plus_one(samples) + 1;
+ }
+
+ virtual bool optimization_status (
+ scalar_type current_objective_value,
+ scalar_type current_error_gap,
+ scalar_type current_risk_value,
+ scalar_type current_risk_gap,
+ unsigned long num_cutting_planes,
+ unsigned long num_iterations
+ ) const
+ {
+ current_risk_value /= samples.size();
+ current_risk_gap /= samples.size();
+ if (be_verbose)
+ {
+ using namespace std;
+ cout << "objective: " << current_objective_value << endl;
+ cout << "objective gap: " << current_error_gap << endl;
+ cout << "risk: " << current_risk_value << endl;
+ cout << "risk gap: " << current_risk_gap << endl;
+ cout << "num planes: " << num_cutting_planes << endl;
+ cout << "iter: " << num_iterations << endl;
+ cout << endl;
+ }
+
+ if (num_iterations >= max_iterations)
+ return true;
+
+ if (current_risk_gap < eps*eps_insensitivity)
+ return true;
+
+ return false;
+ }
+
+ virtual bool risk_has_lower_bound (
+ scalar_type& lower_bound
+ ) const
+ {
+ lower_bound = 0;
+ return true;
+ }
+
+ virtual void get_risk (
+ matrix_type& w,
+ scalar_type& risk,
+ matrix_type& subgradient
+ ) const
+ {
+ subgradient.set_size(w.size(),1);
+ subgradient = 0;
+ risk = 0;
+
+ // loop over all the samples and compute the risk and its subgradient at the current solution point w
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ const long w_size_m1 = w.size()-1;
+ const scalar_type prediction = dot(colm(w,0,w_size_m1), samples[i]) - w(w_size_m1);
+
+ if (std::abs(prediction - targets[i]) > eps_insensitivity)
+ {
+ if (prediction < targets[i])
+ {
+ subtract_from(subgradient, samples[i]);
+ subgradient(w_size_m1) += 1;
+ }
+ else
+ {
+ add_to(subgradient, samples[i]);
+ subgradient(w_size_m1) -= 1;
+ }
+
+ risk += std::abs(prediction - targets[i]) - eps_insensitivity;
+ }
+ }
+ }
+
+ private:
+
+ // -----------------------------------------------------
+ // -----------------------------------------------------
+
+
+ const std::vector<sample_type>& samples;
+ const std::vector<scalar_type>& targets;
+ const scalar_type C;
+
+ const bool be_verbose;
+ const scalar_type eps;
+ const scalar_type eps_insensitivity;
+ const unsigned long max_iterations;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename matrix_type,
+ typename sample_type,
+ typename scalar_type
+ >
+ oca_problem_linear_svr<matrix_type, sample_type> make_oca_problem_linear_svr (
+ const scalar_type C,
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& targets,
+ const bool be_verbose,
+ const scalar_type eps,
+ const scalar_type eps_insensitivity,
+ const unsigned long max_iterations
+ )
+ {
+ return oca_problem_linear_svr<matrix_type, sample_type>(
+ C, samples, targets, be_verbose, eps, eps_insensitivity, max_iterations);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svr_linear_trainer
+ {
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ // You are getting a compiler error on this line because you supplied a non-linear kernel
+ // to the svr_linear_trainer object. You have to use one of the linear kernels with this
+ // trainer.
+ COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
+ is_same_type<K, sparse_linear_kernel<sample_type> >::value ));
+
+ svr_linear_trainer (
+ )
+ {
+ C = 1;
+ verbose = false;
+ eps = 0.01;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ eps_insensitivity = 0.1;
+ }
+
+ explicit svr_linear_trainer (
+ const scalar_type& C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t svr_linear_trainer::svr_linear_trainer()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ verbose = false;
+ eps = 0.01;
+ max_iterations = 10000;
+ learn_nonnegative_weights = false;
+ last_weight_1 = false;
+ eps_insensitivity = 0.1;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void svr_linear_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const { return eps; }
+
+ void set_epsilon_insensitivity (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svr_linear_trainer::set_epsilon_insensitivity(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps_insensitivity = eps_;
+ }
+
+ const scalar_type get_epsilon_insensitivity (
+ ) const
+ {
+ return eps_insensitivity;
+ }
+
+ unsigned long get_max_iterations (
+ ) const { return max_iterations; }
+
+ void set_max_iterations (
+ unsigned long max_iter
+ )
+ {
+ max_iterations = max_iter;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ bool forces_last_weight_to_1 (
+ ) const
+ {
+ return last_weight_1;
+ }
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ )
+ {
+ last_weight_1 = should_last_weight_be_1;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ const kernel_type get_kernel (
+ ) const
+ {
+ return kernel_type();
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void svr_linear_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ const scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ const decision_function<kernel_type> train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& targets
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT(is_learning_problem(samples, targets) == true,
+ "\t decision_function svr_linear_trainer::train(samples, targets)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t targets.size(): " << targets.size()
+ << "\n\t is_learning_problem(samples,targets): " << is_learning_problem(samples,targets)
+ );
+
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type w;
+
+ const unsigned long num_dims = max_index_plus_one(samples);
+
+ unsigned long num_nonnegative = 0;
+ if (learn_nonnegative_weights)
+ {
+ num_nonnegative = num_dims;
+ }
+
+ unsigned long force_weight_1_idx = std::numeric_limits<unsigned long>::max();
+ if (last_weight_1)
+ {
+ force_weight_1_idx = num_dims-1;
+ }
+
+ solver( make_oca_problem_linear_svr<w_type>(C, samples, targets, verbose, eps, eps_insensitivity, max_iterations),
+ w,
+ num_nonnegative,
+ force_weight_1_idx);
+
+
+ // put the solution into a decision function and then return it
+ decision_function<kernel_type> df;
+ df.b = static_cast<scalar_type>(w(w.size()-1));
+ df.basis_vectors.set_size(1);
+ // Copy the plane normal into the output basis vector. The output vector might be a
+ // sparse vector container so we need to use this special kind of copy to handle that case.
+ // As an aside, the reason for using max_index_plus_one() and not just w.size()-1 is because
+ // doing it this way avoids an inane warning from gcc that can occur in some cases.
+ const long out_size = max_index_plus_one(samples);
+ assign(df.basis_vectors(0), matrix_cast<scalar_type>(colm(w, 0, out_size)));
+ df.alpha.set_size(1);
+ df.alpha(0) = 1;
+
+ return df;
+ }
+
+ private:
+
+ scalar_type C;
+ oca solver;
+ scalar_type eps;
+ bool verbose;
+ unsigned long max_iterations;
+ bool learn_nonnegative_weights;
+ bool last_weight_1;
+ scalar_type eps_insensitivity;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVR_LINEAR_TrAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h
new file mode 100644
index 000000000..c74310f06
--- /dev/null
+++ b/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h
@@ -0,0 +1,269 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_
+#ifdef DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_
+
+#include "sparse_vector_abstract.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svr_linear_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ Is either linear_kernel or sparse_linear_kernel.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for performing epsilon-insensitive support
+ vector regression. It uses the oca optimizer so it is very efficient at
+ solving this problem when linear kernels are used, making it suitable for
+ use with large datasets.
+
+ For an introduction to support vector regression see the following paper:
+ A Tutorial on Support Vector Regression by Alex J. Smola and Bernhard Scholkopf.
+ Note that this object solves the version of support vector regression
+ defined by equation (3) in the paper, except that we incorporate the bias
+ term into the w vector by appending a 1 to the end of each sample.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svr_linear_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ ranking support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c() == 1
+ - #get_epsilon() == 0.01
+ - #get_epsilon_insensitivity() = 0.1
+ - This object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ explicit svr_linear_trainer (
+ const scalar_type& C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - This object is properly initialized and ready to be used to train a
+ ranking support vector machine.
+ - #get_oca() == oca() (i.e. an instance of oca with default parameters)
+ - #get_c() == C
+ - #get_epsilon() == 0.01
+ - #get_epsilon_insensitivity() = 0.1
+ - This object will not be verbose unless be_verbose() is called
+ - #get_max_iterations() == 10000
+ - #learns_nonnegative_weights() == false
+ - #forces_last_weight_to_1() == false
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Smaller values may result in a more accurate solution but take longer to
+ train. You can think of this epsilon value as saying "solve the
+ optimization problem until the average regression error is within epsilon
+ of its optimal value". See get_epsilon_insensitivity() below for a
+ definition of "regression error".
+ !*/
+
+ void set_epsilon_insensitivity (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon_insensitivity() == eps
+ !*/
+
+ const scalar_type get_epsilon_insensitivity (
+ ) const;
+ /*!
+ ensures
+ - This object tries to find a function which minimizes the regression error
+ on a training set. This error is measured in the following way:
+ - if (abs(predicted_value - true_labeled_value) < eps) then
+ - The error is 0. That is, any function which gets within eps of
+ the correct output is good enough.
+ - else
+ - The error grows linearly once it gets bigger than eps.
+
+ So epsilon-insensitive regression means we do regression but stop trying
+ to fit a data point once it is "close enough". This function returns
+ that eps value which controls what we mean by "close enough".
+ !*/
+
+ unsigned long get_max_iterations (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of iterations the SVM optimizer is allowed to
+ run before it is required to stop and return a result.
+ !*/
+
+ void set_max_iterations (
+ unsigned long max_iter
+ );
+ /*!
+ ensures
+ - #get_max_iterations() == max_iter
+ !*/
+
+ void be_verbose (
+ );
+ /*!
+ ensures
+ - This object will print status messages to standard out so that a user can
+ observe the progress of the algorithm.
+ !*/
+
+ void be_quiet (
+ );
+ /*!
+ ensures
+ - this object will not print anything to standard out
+ !*/
+
+ bool forces_last_weight_to_1 (
+ ) const;
+ /*!
+ ensures
+ - returns true if this trainer has the constraint that the last weight in
+ the learned parameter vector must be 1. This is the weight corresponding
+ to the feature in the training vectors with the highest dimension.
+ - Forcing the last weight to 1 also disables the bias and therefore the b
+ field of the learned decision_function will be 0 when forces_last_weight_to_1() == true.
+ !*/
+
+ void force_last_weight_to_1 (
+ bool should_last_weight_be_1
+ );
+ /*!
+ ensures
+ - #forces_last_weight_to_1() == should_last_weight_be_1
+ !*/
+
+ void set_oca (
+ const oca& item
+ );
+ /*!
+ ensures
+ - #get_oca() == item
+ !*/
+
+ const oca get_oca (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the optimizer used to solve the SVM problem.
+ !*/
+
+ const kernel_type get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object. Since the
+ linear kernels don't have any parameters this function just returns
+ kernel_type()
+ !*/
+
+ bool learns_nonnegative_weights (
+ ) const;
+ /*!
+ ensures
+ - The output of training is a weight vector and a bias value. These two
+ things define the resulting decision function. That is, the decision
+ function simply takes the dot product between the learned weight vector
+ and a test sample, then subtracts the bias value. Therefore, if
+ learns_nonnegative_weights() == true then the resulting learned weight
+ vector will always have non-negative entries. The bias value may still
+ be negative though.
+ !*/
+
+ void set_learns_nonnegative_weights (
+ bool value
+ );
+ /*!
+ ensures
+ - #learns_nonnegative_weights() == value
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ const scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVM regularization parameter. It is the parameter that
+ determines the trade off between trying to fit the training data exactly
+ or allowing more errors but hopefully improving the generalization of the
+ resulting classifier. Larger values encourage exact fitting while
+ smaller values of C may encourage better generalization.
+ !*/
+
+ const decision_function<kernel_type> train (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& targets
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(samples,targets) == true
+ ensures
+ - performs support vector regression given the training samples and targets.
+ - returns a decision_function F with the following properties:
+ - F(new_sample) == predicted target value for new_sample
+ - F.alpha.size() == 1
+ - F.basis_vectors.size() == 1
+ - F.alpha(0) == 1
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm/svr_trainer.h b/ml/dlib/dlib/svm/svr_trainer.h
new file mode 100644
index 000000000..bc6378a20
--- /dev/null
+++ b/ml/dlib/dlib/svm/svr_trainer.h
@@ -0,0 +1,393 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_
+#define DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_
+
+
+#include "svr_trainer_abstract.h"
+#include <cmath>
+#include <limits>
+#include "../matrix.h"
+#include "../algs.h"
+
+#include "function.h"
+#include "kernel.h"
+#include "../optimization/optimization_solve_qp3_using_smo.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svr_trainer
+ {
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svr_trainer (
+ ) :
+ C(1),
+ eps_insensitivity(0.1),
+ cache_size(200),
+ eps(0.001)
+ {
+ }
+
+ void set_cache_size (
+ long cache_size_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(cache_size_ > 0,
+ "\tvoid svr_trainer::set_cache_size(cache_size_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t cache_size: " << cache_size_
+ );
+ cache_size = cache_size_;
+ }
+
+ long get_cache_size (
+ ) const
+ {
+ return cache_size;
+ }
+
+ void set_epsilon (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svr_trainer::set_epsilon(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps = eps_;
+ }
+
+ const scalar_type get_epsilon (
+ ) const
+ {
+ return eps;
+ }
+
+ void set_epsilon_insensitivity (
+ scalar_type eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\tvoid svr_trainer::set_epsilon_insensitivity(eps_)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t eps_: " << eps_
+ );
+ eps_insensitivity = eps_;
+ }
+
+ const scalar_type get_epsilon_insensitivity (
+ ) const
+ {
+ return eps_insensitivity;
+ }
+
+ void set_kernel (
+ const kernel_type& k
+ )
+ {
+ kernel_function = k;
+ }
+
+ const kernel_type& get_kernel (
+ ) const
+ {
+ return kernel_function;
+ }
+
+ void set_c (
+ scalar_type C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void svr_trainer::set_c()"
+ << "\n\t C must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ const scalar_type get_c (
+ ) const
+ {
+ return C;
+ }
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ return do_train(mat(x), mat(y));
+ }
+
+ void swap (
+ svr_trainer& item
+ )
+ {
+ exchange(kernel_function, item.kernel_function);
+ exchange(C, item.C);
+ exchange(eps_insensitivity, item.eps_insensitivity);
+ exchange(cache_size, item.cache_size);
+ exchange(eps, item.eps);
+ }
+
+ private:
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename M>
+ struct op_quad
+ {
+ explicit op_quad(
+ const M& m_
+ ) : m(m_) {}
+
+ const M& m;
+
+ typedef typename M::type type;
+ typedef type const_ret_type;
+ const static long cost = M::cost + 2;
+
+ inline const_ret_type apply ( long r, long c) const
+ {
+ if (r < m.nr())
+ {
+ if (c < m.nc())
+ {
+ return m(r,c);
+ }
+ else
+ {
+ return -m(r,c-m.nc());
+ }
+ }
+ else
+ {
+ if (c < m.nc())
+ {
+ return -m(r-m.nr(),c);
+ }
+ else
+ {
+ return m(r-m.nr(),c-m.nc());
+ }
+ }
+ }
+
+ const static long NR = 2*M::NR;
+ const static long NC = 2*M::NC;
+ typedef typename M::mem_manager_type mem_manager_type;
+ typedef typename M::layout_type layout_type;
+
+ long nr () const { return 2*m.nr(); }
+ long nc () const { return 2*m.nc(); }
+
+ template <typename U> bool aliases ( const matrix_exp<U>& item) const
+ { return m.aliases(item); }
+ template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
+ { return m.aliases(item); }
+ };
+
+ template <
+ typename EXP
+ >
+ const matrix_op<op_quad<EXP> > make_quad (
+ const matrix_exp<EXP>& m
+ ) const
+ /*!
+ ensures
+ - returns the following matrix:
+ m -m
+ -m m
+ - I.e. returns a matrix that is twice the size of m and just
+ contains copies of m and -m
+ !*/
+ {
+ typedef op_quad<EXP> op;
+ return matrix_op<op>(op(m.ref()));
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> do_train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ {
+ typedef typename K::scalar_type scalar_type;
+ typedef typename decision_function<K>::sample_vector_type sample_vector_type;
+ typedef typename decision_function<K>::scalar_vector_type scalar_vector_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true,
+ "\tdecision_function svr_trainer::train(x,y)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.nr(): " << x.nr()
+ << "\n\t y.nr(): " << y.nr()
+ << "\n\t x.nc(): " << x.nc()
+ << "\n\t y.nc(): " << y.nc()
+ );
+
+
+ scalar_vector_type alpha;
+
+ solve_qp3_using_smo<scalar_vector_type> solver;
+
+ solver(symmetric_matrix_cache<float>(make_quad(kernel_matrix(kernel_function,x)), cache_size),
+ uniform_matrix<scalar_type>(2*x.size(),1, eps_insensitivity) + join_cols(y,-y),
+ join_cols(uniform_matrix<scalar_type>(x.size(),1,1), uniform_matrix<scalar_type>(x.size(),1,-1)),
+ 0,
+ C,
+ C,
+ alpha,
+ eps);
+
+ scalar_type b;
+ calculate_b(alpha,solver.get_gradient(),C,b);
+
+ alpha = -rowm(alpha,range(0,x.size()-1)) + rowm(alpha,range(x.size(), alpha.size()-1));
+
+ // count the number of support vectors
+ const long sv_count = (long)sum(alpha != 0);
+
+ scalar_vector_type sv_alpha;
+ sample_vector_type support_vectors;
+
+ // size these column vectors so that they have an entry for each support vector
+ sv_alpha.set_size(sv_count);
+ support_vectors.set_size(sv_count);
+
+ // load the support vectors and their alpha values into these new column matrices
+ long idx = 0;
+ for (long i = 0; i < alpha.nr(); ++i)
+ {
+ if (alpha(i) != 0)
+ {
+ sv_alpha(idx) = alpha(i);
+ support_vectors(idx) = x(i);
+ ++idx;
+ }
+ }
+
+ // now return the decision function
+ return decision_function<K> (sv_alpha, -b, kernel_function, support_vectors);
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <
+ typename scalar_vector_type
+ >
+ void calculate_b(
+ const scalar_vector_type& alpha,
+ const scalar_vector_type& df,
+ const scalar_type& C,
+ scalar_type& b
+ ) const
+ {
+ using namespace std;
+ long num_free = 0;
+ scalar_type sum_free = 0;
+
+ scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
+ scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
+
+ find_min_and_max(df, upper_bound, lower_bound);
+
+ for(long i = 0; i < alpha.nr(); ++i)
+ {
+ if(i < alpha.nr()/2)
+ {
+ if(alpha(i) == C)
+ {
+ if (df(i) > upper_bound)
+ upper_bound = df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (df(i) < lower_bound)
+ lower_bound = df(i);
+ }
+ else
+ {
+ ++num_free;
+ sum_free += df(i);
+ }
+ }
+ else
+ {
+ if(alpha(i) == C)
+ {
+ if (-df(i) < lower_bound)
+ lower_bound = -df(i);
+ }
+ else if(alpha(i) == 0)
+ {
+ if (-df(i) > upper_bound)
+ upper_bound = -df(i);
+ }
+ else
+ {
+ ++num_free;
+ sum_free -= df(i);
+ }
+ }
+ }
+
+ if(num_free > 0)
+ b = sum_free/num_free;
+ else
+ b = (upper_bound+lower_bound)/2;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+
+ kernel_type kernel_function;
+ scalar_type C;
+ scalar_type eps_insensitivity;
+ long cache_size;
+ scalar_type eps;
+ }; // end of class svr_trainer
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ svr_trainer<K>& a,
+ svr_trainer<K>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_
+
diff --git a/ml/dlib/dlib/svm/svr_trainer_abstract.h b/ml/dlib/dlib/svm/svr_trainer_abstract.h
new file mode 100644
index 000000000..c1dd5f1f3
--- /dev/null
+++ b/ml/dlib/dlib/svm/svr_trainer_abstract.h
@@ -0,0 +1,209 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_
+#ifdef DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_
+
+#include <cmath>
+#include <limits>
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+#include "../optimization/optimization_solve_qp3_using_smo_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class svr_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for performing epsilon-insensitive support
+ vector regression. It is implemented using the SMO algorithm.
+
+ The implementation of the eps-SVR training algorithm used by this object is based
+ on the following paper:
+ - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
+ machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ svr_trainer (
+ );
+ /*!
+ ensures
+ - This object is properly initialized and ready to be used
+ to train a support vector machine.
+ - #get_c() == 1
+ - #get_epsilon_insensitivity() == 0.1
+ - #get_cache_size() == 200
+ - #get_epsilon() == 0.001
+ !*/
+
+ void set_cache_size (
+ long cache_size
+ );
+ /*!
+ requires
+ - cache_size > 0
+ ensures
+ - #get_cache_size() == cache_size
+ !*/
+
+ const long get_cache_size (
+ ) const;
+ /*!
+ ensures
+ - returns the number of megabytes of cache this object will use
+ when it performs training via the this->train() function.
+ (bigger values of this may make training go faster but won't affect
+ the result. However, too big a value will cause you to run out of
+ memory, obviously.)
+ !*/
+
+ void set_epsilon (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon() == eps
+ !*/
+
+ const scalar_type get_epsilon (
+ ) const;
+ /*!
+ ensures
+ - returns the error epsilon that determines when training should stop.
+ Generally a good value for this is 0.001. Smaller values may result
+ in a more accurate solution but take longer to execute.
+ !*/
+
+ void set_epsilon_insensitivity (
+ scalar_type eps
+ );
+ /*!
+ requires
+ - eps > 0
+ ensures
+ - #get_epsilon_insensitivity() == eps
+ !*/
+
+ const scalar_type get_epsilon_insensitivity (
+ ) const;
+ /*!
+ ensures
+ - This object tries to find a function which minimizes the
+ regression error on a training set. This error is measured
+ in the following way:
+ - if (abs(predicted_value - true_labeled_value) < eps) then
+ - The error is 0. That is, any function which gets within
+ eps of the correct output is good enough.
+ - else
+ - The error grows linearly once it gets bigger than eps
+
+ So epsilon-insensitive regression means we do regression but
+ stop trying to fit a data point once it is "close enough".
+ This function returns that eps value which controls what we
+ mean by "close enough".
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_c (
+ scalar_type C
+ );
+ /*!
+ requires
+ - C > 0
+ ensures
+ - #get_c() == C
+ !*/
+
+ const scalar_type get_c (
+ ) const;
+ /*!
+ ensures
+ - returns the SVR regularization parameter. It is the parameter that
+ determines the trade-off between trying to reduce the training error
+ or allowing more errors but hopefully improving the generalization
+ of the resulting decision_function. Larger values encourage exact
+ fitting while smaller values of C may encourage better generalization.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const;
+ /*!
+ requires
+ - is_learning_problem(x,y) == true
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ ensures
+ - performs support vector regression given the training samples in x and
+ target values in y.
+ - returns a decision_function F with the following properties:
+ - F(new_x) == predicted y value
+ !*/
+
+ void swap (
+ svr_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+ };
+
+ template <typename K>
+ void swap (
+ svr_trainer<K>& a,
+ svr_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/svm/track_association_function.h b/ml/dlib/dlib/svm/track_association_function.h
new file mode 100644
index 000000000..bf5ef36c7
--- /dev/null
+++ b/ml/dlib/dlib/svm/track_association_function.h
@@ -0,0 +1,154 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_
+#define DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_
+
+
+#include "track_association_function_abstract.h"
+#include <vector>
+#include <iostream>
+#include "../algs.h"
+#include "../serialize.h"
+#include "assignment_function.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type
+ >
+ class feature_extractor_track_association
+ {
+ public:
+ typedef typename detection_type::track_type track_type;
+ typedef typename track_type::feature_vector_type feature_vector_type;
+
+ typedef detection_type lhs_element;
+ typedef track_type rhs_element;
+
+ feature_extractor_track_association() : num_dims(0), num_nonnegative(0) {}
+
+ explicit feature_extractor_track_association (
+ unsigned long num_dims_,
+ unsigned long num_nonnegative_
+ ) : num_dims(num_dims_), num_nonnegative(num_nonnegative_) {}
+
+ unsigned long num_features(
+ ) const { return num_dims; }
+
+ unsigned long num_nonnegative_weights (
+ ) const { return num_nonnegative; }
+
+ void get_features (
+ const detection_type& det,
+ const track_type& track,
+ feature_vector_type& feats
+ ) const
+ {
+ track.get_similarity_features(det, feats);
+ }
+
+ friend void serialize (const feature_extractor_track_association& item, std::ostream& out)
+ {
+ serialize(item.num_dims, out);
+ serialize(item.num_nonnegative, out);
+ }
+
+ friend void deserialize (feature_extractor_track_association& item, std::istream& in)
+ {
+ deserialize(item.num_dims, in);
+ deserialize(item.num_nonnegative, in);
+ }
+
+ private:
+ unsigned long num_dims;
+ unsigned long num_nonnegative;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type_
+ >
+ class track_association_function
+ {
+ public:
+
+ typedef detection_type_ detection_type;
+ typedef typename detection_type::track_type track_type;
+ typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;
+
+ track_association_function() {}
+
+ track_association_function (
+ const association_function_type& assoc_
+ ) : assoc(assoc_)
+ {
+ }
+
+ const association_function_type& get_assignment_function (
+ ) const
+ {
+ return assoc;
+ }
+
+ void operator() (
+ std::vector<track_type>& tracks,
+ const std::vector<detection_type>& dets
+ ) const
+ {
+ std::vector<long> assignments = assoc(dets, tracks);
+ std::vector<bool> updated_track(tracks.size(), false);
+ // now update all the tracks with the detections that associated to them.
+ for (unsigned long i = 0; i < assignments.size(); ++i)
+ {
+ if (assignments[i] != -1)
+ {
+ tracks[assignments[i]].update_track(dets[i]);
+ updated_track[assignments[i]] = true;
+ }
+ else
+ {
+ track_type new_track;
+ new_track.update_track(dets[i]);
+ tracks.push_back(new_track);
+ }
+ }
+
+ // Now propagate all the tracks that didn't get any detections.
+ for (unsigned long i = 0; i < updated_track.size(); ++i)
+ {
+ if (!updated_track[i])
+ tracks[i].propagate_track();
+ }
+ }
+
+ friend void serialize (const track_association_function& item, std::ostream& out)
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.assoc, out);
+ }
+ friend void deserialize (track_association_function& item, std::istream& in)
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw serialization_error("Unexpected version found while deserializing dlib::track_association_function.");
+
+ deserialize(item.assoc, in);
+ }
+
+ private:
+
+ assignment_function<feature_extractor_track_association<detection_type> > assoc;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_
+
diff --git a/ml/dlib/dlib/svm/track_association_function_abstract.h b/ml/dlib/dlib/svm/track_association_function_abstract.h
new file mode 100644
index 000000000..8a6fe153c
--- /dev/null
+++ b/ml/dlib/dlib/svm/track_association_function_abstract.h
@@ -0,0 +1,271 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_
+#ifdef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_
+
+#include <vector>
+#include "assignment_function_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class example_detection
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a detection must implement if it is to be
+ used with the track_association_function defined at the bottom of this
+ file. In this case, the interface is very simple. A detection object is
+ only required to define the track_type typedef and it must also be possible
+ to store detection objects in a std::vector.
+ !*/
+
+ public:
+ // Each detection object should be designed to work with a specific track object.
+ // This typedef lets us determine which track type is meant for use with this
+ // detection object.
+ typedef class example_track track_type;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class example_track
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object defines the interface a track must implement if it is to be
+ used with the track_association_function defined at the bottom of this
+ file.
+ !*/
+
+ public:
+ // This type should be a dlib::matrix capable of storing column vectors or an
+ // unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+ typedef matrix_or_sparse_vector_type feature_vector_type;
+
+ example_track(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void get_similarity_features (
+ const example_detection& det,
+ feature_vector_type& feats
+ ) const;
+ /*!
+ requires
+ - update_track() has been called on this track at least once.
+ ensures
+ - #feats == A feature vector that contains information describing how
+ likely it is that det is a detection from the object corresponding to
+ this track. That is, the feature vector should contain information that
+ lets someone decide if det should be associated to this track.
+ - #feats.size() must be a constant. That is, every time we call
+ get_similarity_features() it must output a feature vector of the same
+ dimensionality.
+ !*/
+
+ void update_track (
+ const example_detection& det
+ );
+ /*!
+ ensures
+ - Updates this track with the given detection assuming that det is the most
+ current observation of the object under track.
+ !*/
+
+ void propagate_track (
+ );
+ /*!
+ ensures
+ - propagates this track forward in time one time step.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type
+ >
+ class feature_extractor_track_association
+ {
+ /*!
+ REQUIREMENTS ON detection_type
+ It must be an object that implements an interface compatible with the
+ example_detection discussed above. This also means that detection_type::track_type
+ must be an object that implements an interface compatible with example_track
+ defined above.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is an adapter that converts from the detection/track style
+ interface defined above to the feature extraction interface required by the
+ association rule learning tools in dlib. Specifically, it converts the
+ detection/track interface into a form usable by the assignment_function and
+ its trainer object structural_assignment_trainer.
+ !*/
+
+ public:
+ typedef typename detection_type::track_type track_type;
+ typedef typename track_type::feature_vector_type feature_vector_type;
+ typedef detection_type lhs_element;
+ typedef track_type rhs_element;
+
+ unsigned long num_features(
+ ) const;
+ /*!
+ ensures
+ - returns the dimensionality of the feature vectors produced by get_features().
+ !*/
+
+ void get_features (
+ const detection_type& det,
+ const track_type& track,
+ feature_vector_type& feats
+ ) const;
+ /*!
+ ensures
+ - performs: track.get_similarity_features(det, feats);
+ !*/
+ };
+
+ template <
+ typename detection_type
+ >
+ void serialize (
+ const feature_extractor_track_association<detection_type>& item,
+ std::ostream& out
+ );
+ /*!
+ Provides serialization support.
+ !*/
+
+ template <
+ typename detection_type
+ >
+ void deserialize (
+ feature_extractor_track_association<detection_type>& item,
+ std::istream& in
+ );
+ /*!
+ Provides deserialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename detection_type_
+ >
+ class track_association_function
+ {
+ /*!
+ REQUIREMENTS ON detection_type
+ It must be an object that implements an interface compatible with the
+ example_detection discussed above. This also means that detection_type::track_type
+ must be an object that implements an interface compatible with example_track
+ defined above.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool that helps you implement an object tracker. So for
+ example, if you wanted to track people moving around in a video then this
+ object can help. In particular, imagine you have a tool for detecting the
+ positions of each person in an image. Then you can run this person
+ detector on the video and at each time step, i.e. at each frame, you get a
+ set of person detections. However, that by itself doesn't tell you how
+ many people there are in the video and where they are moving to and from.
+ To get that information you need to figure out which detections match each
+ other from frame to frame. This is where the track_association_function
+ comes in. It performs the detection to track association. It will also do
+ some of the track management tasks like creating a new track when a
+ detection doesn't match any of the existing tracks.
+
+ Internally, this object is implemented using the assignment_function object.
+ In fact, it's really just a thin wrapper around assignment_function and
+ exists just to provide a more convenient interface to users doing detection
+ to track association.
+ !*/
+ public:
+
+ typedef detection_type_ detection_type;
+ typedef typename detection_type::track_type track_type;
+ typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;
+
+ track_association_function(
+ );
+ /*!
+ ensures
+ - #get_assignment_function() will be default initialized.
+ !*/
+
+ track_association_function (
+ const association_function_type& assoc
+ );
+ /*!
+ ensures
+ - #get_assignment_function() == assoc
+ !*/
+
+ const association_function_type& get_assignment_function (
+ ) const;
+ /*!
+ ensures
+ - returns the assignment_function used by this object to assign detections
+ to tracks.
+ !*/
+
+ void operator() (
+ std::vector<track_type>& tracks,
+ const std::vector<detection_type>& dets
+ ) const;
+ /*!
+ ensures
+ - This function uses get_assignment_function() to assign each detection
+ in dets to its appropriate track in tracks. Then each track which
+ associates to a detection is updated by calling update_track() with the
+ associated detection.
+ - Detections that don't associate with any of the elements of tracks will
+ spawn new tracks. For each unassociated detection, this is done by
+ creating a new track_type object, calling update_track() on it with the
+ new detection, and then adding the new track into tracks.
+ - Tracks that don't have a detection associate to them are propagated
+ forward in time by calling propagate_track() on them. That is, we call
+ propagate_track() only on tracks that do not get associated with a
+ detection.
+ !*/
+ };
+
+ template <
+ typename detection_type
+ >
+ void serialize (
+ const track_association_function<detection_type>& item,
+ std::ostream& out
+ );
+ /*!
+ Provides serialization support.
+ !*/
+
+ template <
+ typename detection_type
+ >
+ void deserialize (
+ track_association_function<detection_type>& item,
+ std::istream& in
+ );
+ /*!
+ Provides deserialization support.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/svm_threaded.h b/ml/dlib/dlib/svm_threaded.h
new file mode 100644
index 000000000..f77fab705
--- /dev/null
+++ b/ml/dlib/dlib/svm_threaded.h
@@ -0,0 +1,36 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_SVm_THREADED_HEADER
+#define DLIB_SVm_THREADED_HEADER
+
+#include "svm.h"
+#include "svm/svm_threaded.h"
+#include "svm/structural_svm_problem_threaded.h"
+#include "svm/structural_svm_distributed.h"
+#include "svm/structural_svm_object_detection_problem.h"
+#include "svm/structural_object_detection_trainer.h"
+#include "svm/structural_svm_sequence_labeling_problem.h"
+#include "svm/structural_sequence_labeling_trainer.h"
+
+#include "svm/structural_svm_assignment_problem.h"
+#include "svm/structural_assignment_trainer.h"
+#include "svm/cross_validate_track_association_trainer.h"
+#include "svm/structural_track_association_trainer.h"
+
+#include "svm/structural_svm_graph_labeling_problem.h"
+#include "svm/structural_graph_labeling_trainer.h"
+#include "svm/cross_validate_graph_labeling_trainer.h"
+#include "svm/svm_multiclass_linear_trainer.h"
+#include "svm/one_vs_one_trainer.h"
+#include "svm/one_vs_all_trainer.h"
+#include "svm/structural_sequence_segmentation_trainer.h"
+
+#endif // DLIB_SVm_THREADED_HEADER
+
+
+
diff --git a/ml/dlib/dlib/sync_extension.h b/ml/dlib/dlib/sync_extension.h
new file mode 100644
index 000000000..8a50bce01
--- /dev/null
+++ b/ml/dlib/dlib/sync_extension.h
@@ -0,0 +1,31 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SYNC_EXTENSIOn_
+#define DLIB_SYNC_EXTENSIOn_
+
+#include "sync_extension/sync_extension_kernel_1.h"
+
+
+
+namespace dlib
+{
+
+ template <
+ typename base
+ >
+ class sync_extension
+ {
+ sync_extension() {}
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef sync_extension_kernel_1<base>
+ kernel_1a;
+
+ };
+}
+
+#endif // DLIB_SYNC_EXTENSIOn_
+
diff --git a/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h b/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h
new file mode 100644
index 000000000..71fe7c391
--- /dev/null
+++ b/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SYNC_EXTENSION_KERNEl_1_
+#define DLIB_SYNC_EXTENSION_KERNEl_1_
+
+#include "../threads.h"
+#include "../algs.h"
+#include "sync_extension_kernel_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename base
+ >
+ class sync_extension_kernel_1 : public base
+ {
+
+ rmutex m;
+ rsignaler s;
+
+ public:
+
+ sync_extension_kernel_1 () : s(m) {}
+
+ template < typename T >
+ sync_extension_kernel_1 (const T& one) : base(one),s(m) {}
+ template < typename T, typename U >
+ sync_extension_kernel_1 (const T& one, const U& two) : base(one,two),s(m) {}
+
+
+ const rmutex& get_mutex(
+ ) const { return m; }
+
+ void lock (
+ ) const { m.lock(); }
+
+ void unlock (
+ ) const { m.unlock(); }
+
+ void wait (
+ ) const { s.wait(); }
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const { return s.wait_or_timeout(milliseconds); }
+
+ void broadcast (
+ ) const { s.broadcast(); }
+
+ void signal (
+ ) const { s.signal(); }
+
+ };
+
+ template <
+ typename base
+ >
+ inline void swap (
+ sync_extension_kernel_1<base>& a,
+ sync_extension_kernel_1<base>& b
+ ) { a.swap(b); }
+
+}
+
+#endif // DLIB_SYNC_EXTENSION_KERNEl_1_
+
diff --git a/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h b/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h
new file mode 100644
index 000000000..35665d430
--- /dev/null
+++ b/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h
@@ -0,0 +1,190 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_
+#ifdef DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_
+
+#include "../threads/threads_kernel_abstract.h"
+#include "../threads/rmutex_extension_abstract.h"
+#include "../threads/rsignaler_extension_abstract.h"
+#include "../algs.h"
+
+namespace dlib
+{
+
+ template <
+ typename base
+ >
+ class sync_extension : public base
+ {
+
+ /*!
+ REQUIREMENTS ON base
+ base must have a default constructor
+ base must implement swap(base&)
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a general extension to any object (given the
+ restrictions on base). This object gives any object which it extends
+ an integrated rmutex and rsignaler object. The extended object will
+ then be able to be treated as if it was also a rmutex and rsignaler.
+
+ NOTE that just like the threading api, this object does not check
+ its requires clauses so be careful with it.
+
+ Also note that swap() does not swap the rmutex and rsignaler objects.
+ the rmutex and rsignaler are associated with the object instance itself,
+ not with whatever the object represents.
+ !*/
+
+
+ public:
+
+ sync_extension (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ this is thrown if there is a problem gathering memory
+ - dlib::thread_error
+ this is thrown if there is a problem creating threading objects
+ - any exception thrown by the constructor for the parent class base
+ !*/
+
+ template <
+ typename T
+ >
+ sync_extension (
+ const T& one
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - the argument one will be passed on to the constructor for the parent
+ class base.
+ throws
+ - std::bad_alloc
+ this is thrown if there is a problem gathering memory
+ - dlib::thread_error
+ this is thrown if there is a problem creating threading objects
+ - any exception thrown by the constructor for the parent class base
+ !*/
+
+ template <
+ typename T,
+ typename U
+ >
+ sync_extension (
+ const T& one,
+ const T& two
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - the argument one will be passed on to the constructor for the parent
+ class base as its first argument.
+ - the argument two will be passed on to the constructor for the parent
+ class base as its second argument.
+ throws
+ - std::bad_alloc
+ this is thrown if there is a problem gathering memory
+ - dlib::thread_error
+ this is thrown if there is a problem creating threading objects
+ - any exception thrown by the constructor for the parent class base
+ !*/
+
+
+ const rmutex& get_mutex (
+ ) const;
+ /*!
+ ensures
+ - returns the rmutex embedded in this object
+ !*/
+
+ void lock (
+ ) const;
+ /*!
+ requires
+ - the thread calling lock() does not already have a lock on *this
+ ensures
+ - if (*this is currently locked by another thread) then
+ - the thread that called lock() on *this is put to sleep until
+ it becomes available
+ - if (*this is currently unlocked) then
+ - #*this becomes locked and the current thread is NOT put to sleep
+ but now "owns" #*this
+ !*/
+
+ void unlock (
+ ) const;
+ /*!
+ ensures
+ - #*this is unlocked (i.e. other threads may now lock this object)
+ !*/
+
+
+ void wait (
+ ) const;
+ /*!
+ requires
+ - *this is locked and owned by the calling thread
+ ensures
+ - atomically unlocks *this and blocks the calling thread
+ - calling thread will wake if another thread calls signal() or broadcast()
+ on *this
+ - when wait returns the calling thread again has a lock on #*this
+ !*/
+
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const;
+ /*!
+ requires
+ - *this is locked and owned by the calling thread
+ ensures
+ - atomically unlocks *this and blocks the calling thread
+ - calling thread will wake if another thread calls signal() or broadcast()
+ on *this
+ - after the specified number of milliseconds has elapsed the calling thread
+ will wake once *this is free to be locked
+ - when wait returns the calling thread again has a lock on #*this
+
+ - returns false if the call to wait_or_timeout timed out
+ - returns true if the call did not time out
+ !*/
+
+ void signal (
+ ) const;
+ /*!
+ ensures
+ - if (at least one thread is waiting on *this) then
+ - at least one of the waiting threads will wake
+ !*/
+
+ void broadcast (
+ ) const;
+ /*!
+ ensures
+ - any and all threads waiting on *this will wake
+ !*/
+
+ };
+
+ template <
+ typename base
+ >
+ inline void swap (
+ sync_extension<base>& a,
+ sync_extension<base>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/test/CMakeLists.txt b/ml/dlib/dlib/test/CMakeLists.txt
new file mode 100644
index 000000000..d6147cb0e
--- /dev/null
+++ b/ml/dlib/dlib/test/CMakeLists.txt
@@ -0,0 +1,181 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+
+cmake_minimum_required(VERSION 2.8.12)
+
+# create a variable called target_name and set it to the string "dtest"
+set (target_name dtest)
+PROJECT(${target_name})
+
+# compile the dlib/all/source.cpp file into its own object just to make sure it compiles
+set(DLIB_TEST_COMPILE_ALL_SOURCE_CPP ON)
+
+add_subdirectory(.. dlib_build)
+
+# This variable contains a list of all the tests we are building
+# into the regression test suite.
+set (tests
+ example.cpp
+ active_learning.cpp
+ any.cpp
+ any_function.cpp
+ array2d.cpp
+ array.cpp
+ assignment_learning.cpp
+ base64.cpp
+ bayes_nets.cpp
+ bigint.cpp
+ binary_search_tree_kernel_1a.cpp
+ binary_search_tree_kernel_2a.cpp
+ binary_search_tree_mm1.cpp
+ binary_search_tree_mm2.cpp
+ bridge.cpp
+ bsp.cpp
+ byte_orderer.cpp
+ cca.cpp
+ clustering.cpp
+ cmd_line_parser.cpp
+ cmd_line_parser_wchar_t.cpp
+ compress_stream.cpp
+ conditioning_class_c.cpp
+ conditioning_class.cpp
+ config_reader.cpp
+ correlation_tracker.cpp
+ crc32.cpp
+ create_iris_datafile.cpp
+ data_io.cpp
+ directed_graph.cpp
+ discriminant_pca.cpp
+ disjoint_subsets.cpp
+ disjoint_subsets_sized.cpp
+ ekm_and_lisf.cpp
+ empirical_kernel_map.cpp
+ entropy_coder.cpp
+ entropy_encoder_model.cpp
+ example_args.cpp
+ face.cpp
+ fft.cpp
+ fhog.cpp
+ filtering.cpp
+ find_max_factor_graph_nmplp.cpp
+ find_max_factor_graph_viterbi.cpp
+ geometry.cpp
+ graph.cpp
+ graph_cuts.cpp
+ graph_labeler.cpp
+ hash.cpp
+ hash_map.cpp
+ hash_set.cpp
+ hash_table.cpp
+ hog_image.cpp
+ image.cpp
+ iosockstream.cpp
+ is_same_object.cpp
+ isotonic_regression.cpp
+ kcentroid.cpp
+ kernel_matrix.cpp
+ kmeans.cpp
+ learning_to_track.cpp
+ least_squares.cpp
+ linear_manifold_regularizer.cpp
+ lspi.cpp
+ lz77_buffer.cpp
+ map.cpp
+ matrix2.cpp
+ matrix3.cpp
+ matrix4.cpp
+ matrix_chol.cpp
+ matrix.cpp
+ matrix_eig.cpp
+ matrix_lu.cpp
+ matrix_qr.cpp
+ max_cost_assignment.cpp
+ max_sum_submatrix.cpp
+ md5.cpp
+ member_function_pointer.cpp
+ metaprogramming.cpp
+ mpc.cpp
+ multithreaded_object.cpp
+ numerical_integration.cpp
+ object_detector.cpp
+ oca.cpp
+ one_vs_all_trainer.cpp
+ one_vs_one_trainer.cpp
+ optimization.cpp
+ optimization_test_functions.cpp
+ global_optimization.cpp
+ opt_qp_solver.cpp
+ parallel_for.cpp
+ parse.cpp
+ pipe.cpp
+ pixel.cpp
+ probabilistic.cpp
+ pyramid_down.cpp
+ queue.cpp
+ rand.cpp
+ ranking.cpp
+ read_write_mutex.cpp
+ reference_counter.cpp
+ rls.cpp
+ random_forest.cpp
+ sammon.cpp
+ scan_image.cpp
+ sequence.cpp
+ sequence_labeler.cpp
+ sequence_segmenter.cpp
+ serialize.cpp
+ set.cpp
+ sldf.cpp
+ sliding_buffer.cpp
+ sockets2.cpp
+ sockets.cpp
+ sockstreambuf.cpp
+ sparse_vector.cpp
+ stack.cpp
+ static_map.cpp
+ static_set.cpp
+ statistics.cpp
+ std_vector_c.cpp
+ string.cpp
+ svm_c_linear.cpp
+ svm_c_linear_dcd.cpp
+ svm.cpp
+ svm_multiclass_linear.cpp
+ svm_struct.cpp
+ svr_linear_trainer.cpp
+ symmetric_matrix_cache.cpp
+ thread_pool.cpp
+ threads.cpp
+ timer.cpp
+ tokenizer.cpp
+ trust_region.cpp
+ tuple.cpp
+ type_safe_union.cpp
+ vectorstream.cpp
+ dnn.cpp
+ cublas.cpp
+ find_optimal_parameters.cpp
+ elastic_net.cpp
+ )
+
+
+# add all the cpp files we want to compile to this list. This tells
+# cmake that they are part of our target (which is the executable named dtest)
+ADD_EXECUTABLE(${target_name} main.cpp tester.cpp ${tests})
+
+# Turn on all warnings when using gcc.
+if (CMAKE_COMPILER_IS_GNUCXX)
+ add_definitions("-W -Wall")
+endif()
+
+
+TARGET_LINK_LIBRARIES(${target_name} dlib::dlib )
+
+
+if (NOT DLIB_NO_GUI_SUPPORT)
+ add_subdirectory(gui)
+ add_subdirectory(examples)
+ add_subdirectory(tools)
+endif()
diff --git a/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat b/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat
new file mode 100644
index 000000000..0cacfa631
--- /dev/null
+++ b/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat
@@ -0,0 +1,42 @@
+date /T > test_log.txt
+time /T >> test_log.txt
+
+rem the pings are to wait between builds so visual studio doesn't get in a funk.
+
+
+
+echo testing python >> test_log.txt
+rmdir /S /Q build_python
+mkdir build_python
+cd build_python
+cmake -G "Visual Studio 14 2015 Win64" ../../../tools/python -DPYTHON3=ON
+cmake --build . --config Release --target install || exit /B
+ping 127.0.0.1 -n 5 -w 1000 > null
+cd ..
+
+
+
+echo testing vc2015 >> test_log.txt
+rmdir /S /Q build_vc2015_64
+mkdir build_vc2015_64
+cd build_vc2015_64
+cmake -G "Visual Studio 14 2015 Win64" ..
+cmake --build . --config Release || exit /B
+ping 127.0.0.1 -n 5 -w 1000 > null
+cmake --build . --config Debug || exit /B
+ping 127.0.0.1 -n 5 -w 1000 > null
+cd Release
+dtest --runall -d || exit /B
+cd ..
+cd ..
+
+
+
+
+
+del null
+type test_log.txt
+
+date /T
+time /T
+
diff --git a/ml/dlib/dlib/test/active_learning.cpp b/ml/dlib/dlib/test/active_learning.cpp
new file mode 100644
index 000000000..9dc0013a5
--- /dev/null
+++ b/ml/dlib/dlib/test/active_learning.cpp
@@ -0,0 +1,165 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/svm.h>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.active_learning");
+
+// ----------------------------------------------------------------------------------------
+
+ typedef matrix<double, 0, 1> sample_type;
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+// ----------------------------------------------------------------------------------------
+
+ void make_dataset (
+ std::vector<sample_type>& samples,
+ std::vector<double>& labels
+ )
+ {
+ for (int r = -10; r <= 10; ++r)
+ {
+ for (int c = -10; c <= 10; ++c)
+ {
+ sample_type samp(2);
+ samp(0) = r;
+ samp(1) = c;
+ samples.push_back(samp);
+
+ // if this point is less than 10 from the origin
+ if (sqrt((double)r*r + c*c) <= 8)
+ labels.push_back(+1);
+ else
+ labels.push_back(-1);
+
+ }
+ }
+
+
+ vector_normalizer<sample_type> normalizer;
+ normalizer.train(samples);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ samples[i] = normalizer(samples[i]);
+
+ randomize_samples(samples, labels);
+
+ /*
+ cout << "samples.size(): " << samples.size() << endl;
+ cout << "num +1 samples: "<< sum(mat(labels) > 0) << endl;
+ cout << "num -1 samples: "<< sum(mat(labels) < 0) << endl;
+ */
+
+ empirical_kernel_map<kernel_type> ekm;
+ ekm.load(kernel_type(0.15), samples);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ samples[i] = ekm.project(samples[i]);
+
+ //cout << "dims: "<< ekm.out_vector_size() << endl;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double test_rank_unlabeled_training_samples (
+ const std::vector<sample_type>& samples,
+ const std::vector<double>& labels,
+ active_learning_mode mode,
+ int iterations,
+ bool pick_front
+ )
+ {
+ matrix<double,2,1> s;
+ s = sum(mat(labels) > 0), sum(mat(labels) < 0);
+ s /= labels.size();
+
+
+ svm_c_linear_dcd_trainer<linear_kernel<sample_type> > trainer;
+ trainer.set_c(25);
+
+ const unsigned long initial_size = 1;
+ std::vector<sample_type> tsamples(samples.begin(), samples.begin()+initial_size);
+ std::vector<double> tlabels(labels.begin(), labels.begin()+initial_size);
+
+ decision_function<linear_kernel<sample_type> > df;
+
+ double random_score = 0;
+ double active_learning_score = 0;
+ for (int i = 0; i < iterations; ++i)
+ {
+ print_spinner();
+ random_subset_selector<sample_type> sss = randomly_subsample(samples,50,i);
+ random_subset_selector<double> ssl = randomly_subsample(labels,50,i);
+ std::vector<unsigned long> results;
+
+ results = rank_unlabeled_training_samples(trainer, tsamples, tlabels, sss, mode);
+
+ const unsigned long idx = pick_front ? results.front() : results.back();
+ tsamples.push_back(sss[idx]);
+ tlabels.push_back(ssl[idx]);
+
+ df = trainer.train(tsamples, tlabels);
+ //cout << "tsamples.size(): " << tsamples.size() << endl;
+ const unsigned long num = tsamples.size();
+ const double active = test_binary_decision_function(df, samples, labels)*s;
+ //cout << "test: "<< active;
+ df = trainer.train(randomly_subsample(samples,num,i), randomly_subsample(labels,num,i));
+ const double random = test_binary_decision_function(df, samples, labels)*s;
+ //cout << "test: "<< random << endl;
+
+ active_learning_score += active;
+ random_score += random;
+
+ //cout << "\n\n***********\n\n" << flush;
+ }
+
+ dlog << LINFO << "pick_front: " << pick_front << " mode: "<< mode;
+ dlog << LINFO << "active_learning_score: "<< active_learning_score;
+ dlog << LINFO << "random_score: "<< random_score;
+ return active_learning_score / random_score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_active_learning : public tester
+ {
+ public:
+ test_active_learning (
+ ) :
+ tester ("test_active_learning",
+ "Runs tests on the active learning components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+ print_spinner();
+ make_dataset(samples, labels);
+ dlog << LINFO << "samples.size(): "<< samples.size();
+
+ // When we pick the best/front ranked element then the active learning method
+ // shouldn't do much worse than random selection (and often much better).
+ DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, max_min_margin, 35, true) >= 0.97);
+ DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, ratio_margin, 25, true) >= 0.96);
+ // However, picking the worst ranked element should do way worse than random
+ // selection.
+ DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, max_min_margin, 25, false) < 0.8);
+ DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, ratio_margin, 25, false) < 0.8);
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/any.cpp b/ml/dlib/dlib/test/any.cpp
new file mode 100644
index 000000000..355d00b31
--- /dev/null
+++ b/ml/dlib/dlib/test/any.cpp
@@ -0,0 +1,139 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/any.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.any");
+
+// ----------------------------------------------------------------------------------------
+
+ void test_contains_4(
+ const any a
+ )
+ {
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<int>() == true);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(any_cast<int>(a) == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test()
+ {
+ any a, b, c;
+
+ DLIB_TEST(a.is_empty());
+ DLIB_TEST(a.contains<int>() == false);
+ DLIB_TEST(a.contains<string>() == false);
+ DLIB_TEST(a.is_empty());
+
+ a = b;
+
+ swap(a,b);
+ a.swap(b);
+
+ a = 4;
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<int>() == true);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(any_cast<int>(a) == 4);
+
+ test_contains_4(a);
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<int>() == true);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(any_cast<int>(a) == 4);
+
+ bool error = false;
+ try
+ {
+ any_cast<double>(a);
+ }
+ catch (bad_any_cast&)
+ {
+ error = true;
+ }
+ DLIB_TEST(error);
+
+ swap(a,b);
+
+ test_contains_4(b);
+
+ DLIB_TEST(a.is_empty());
+
+ a = b;
+
+ test_contains_4(a);
+
+ c.get<string>() = "test string";
+ DLIB_TEST(c.get<string>() == "test string");
+
+ a = c;
+ DLIB_TEST(a.cast_to<string>() == "test string");
+
+
+ a.clear();
+ DLIB_TEST(a.is_empty());
+ error = false;
+ try
+ {
+ any_cast<string>(a);
+ }
+ catch (bad_any_cast&)
+ {
+ error = true;
+ }
+ DLIB_TEST(error);
+
+
+ a = 1;
+ b = 2;
+
+ int* a_ptr = &a.get<int>();
+ int* b_ptr = &b.get<int>();
+
+ swap(a,b);
+ DLIB_TEST(a_ptr == &b.get<int>());
+ DLIB_TEST(b_ptr == &a.get<int>());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class any_tester : public tester
+ {
+ public:
+ any_tester (
+ ) :
+ tester ("test_any",
+ "Runs tests on the any component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ run_test();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/any_function.cpp b/ml/dlib/dlib/test/any_function.cpp
new file mode 100644
index 000000000..8defb5988
--- /dev/null
+++ b/ml/dlib/dlib/test/any_function.cpp
@@ -0,0 +1,253 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/any.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.any_function");
+
+// ----------------------------------------------------------------------------------------
+
+ int add ( int a, int b) { return a + b; }
+ string cat ( string a, string b) { return a + b; }
+
+// ----------------------------------------------------------------------------------------
+
+ void set_vals1( int& a) { a = 1; }
+ void set_vals2( int& a, int& b) { a = 1; b = 2; }
+ void set_vals3( int& a, int& b, int& c) { a = 1; b = 2; c = 3; }
+ void set_vals4( int& a, int& b, int& c, int& d) { a = 1; b = 2; c = 3; d = 4; }
+ void set_vals5( int& a, int& b, int& c, int& d, int& e) { a = 1; b = 2; c = 3; d = 4; e = 5; }
+ void set_vals6( int& a, int& b, int& c, int& d, int& e, int& f) { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; }
+ void set_vals7( int& a, int& b, int& c, int& d, int& e, int& f, int& g) { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; }
+
+ void set_vals8( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h)
+ { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; }
+
+ void set_vals9( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i)
+ { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; i = 9;}
+
+ void set_vals10( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i, int& j)
+ { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; i = 9; j = 10;}
+
+ void zero_vals( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i, int& j)
+ { a = 0; b = 0; c = 0; d = 0; e = 0; f = 0; g = 0; h = 0; i = 0; j = 0;}
+
+// ----------------------------------------------------------------------------------------
+
+ struct test
+ {
+ int operator()() const { return 4; }
+ };
+
+ struct test2
+ {
+ int v;
+
+ test2() : v(0) {}
+ test2(int val) : v(val) {}
+ int operator()() const { return v; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void test_contains_4(
+ const any_function<int()> a
+ )
+ {
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.is_set() == true);
+ DLIB_TEST(a.contains<test>() == true);
+ DLIB_TEST(a.contains<int(*)()>() == false);
+ DLIB_TEST(any_cast<test>(a)() == 4);
+ DLIB_TEST(a() == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test()
+ {
+ any_function<int()> a, b, c;
+
+ DLIB_TEST(a.is_empty());
+ DLIB_TEST(a.is_set()==false);
+ DLIB_TEST(a.contains<int(*)()>() == false);
+ DLIB_TEST(a.contains<test>() == false);
+ DLIB_TEST(a.is_empty());
+
+ a = b;
+
+ swap(a,b);
+ a.swap(b);
+
+ a = test();
+ test_contains_4(a);
+
+
+ bool error = false;
+ try
+ {
+ any_cast<int(*)()>(a);
+ }
+ catch (bad_any_cast&)
+ {
+ error = true;
+ }
+ DLIB_TEST(error);
+
+ swap(a,b);
+
+ test_contains_4(b);
+
+ DLIB_TEST(a.is_empty());
+
+ a = b;
+
+ test_contains_4(a);
+
+ c.get<test2>() = test2(10);
+ DLIB_TEST(c.get<test2>().v == 10);
+
+ a = c;
+ DLIB_TEST(a.cast_to<test2>().v == 10);
+
+
+ a.clear();
+ DLIB_TEST(a.is_empty());
+ error = false;
+ try
+ {
+ any_cast<test>(a);
+ }
+ catch (bad_any_cast&)
+ {
+ error = true;
+ }
+ DLIB_TEST(error);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test2()
+ {
+ any_function<int(int,int)> f = &add;
+
+ DLIB_TEST(f(1,3) == 4);
+
+ any_function<string(string,string)> g(&cat);
+ DLIB_TEST(g("one", "two") == "onetwo");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test3()
+ {
+ any_function<void(int&)> f1;
+ any_function<void(int&,int&)> f2;
+ any_function<void(int&,int&,int&)> f3;
+ any_function<void(int&,int&,int&,int&)> f4;
+ any_function<void(int&,int&,int&,int&,int&)> f5;
+ any_function<void(int&,int&,int&,int&,int&,int&)> f6;
+ any_function<void(int&,int&,int&,int&,int&,int&,int&)> f7;
+ any_function<void(int&,int&,int&,int&,int&,int&,int&,int&)> f8;
+ any_function<void(int&,int&,int&,int&,int&,int&,int&,int&,int&)> f9;
+ any_function<void(int&,int&,int&,int&,int&,int&,int&,int&,int&,int&)> f10;
+
+ f1 = set_vals1;
+ f2 = set_vals2;
+ f3 = set_vals3;
+ f4 = set_vals4;
+ f5 = set_vals5;
+ f6 = set_vals6;
+ f7 = set_vals7;
+ f8 = set_vals8;
+ f9 = set_vals9;
+ f10 = set_vals10;
+
+ int a,b,c,d,e,f,g,h,i,j;
+
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f1(a);
+ DLIB_TEST(a==1);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f2(a,b);
+ DLIB_TEST(a==1 && b==2);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f3(a,b,c);
+ DLIB_TEST(a==1 && b==2 && c==3);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f4(a,b,c,d);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f5(a,b,c,d,e);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f6(a,b,c,d,e,f);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f7(a,b,c,d,e,f,g);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f8(a,b,c,d,e,f,g,h);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f9(a,b,c,d,e,f,g,h,i);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8 && i==9);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+
+ f10(a,b,c,d,e,f,g,h,i,j);
+ DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8 && i==9 && j==10);
+ zero_vals(a,b,c,d,e,f,g,h,i,j);
+ }
+// ----------------------------------------------------------------------------------------
+
+ class test_any_function : public tester
+ {
+ public:
+ test_any_function (
+ ) :
+ tester ("test_any_function",
+ "Runs tests on the any_function component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ run_test();
+ print_spinner();
+ run_test2();
+ print_spinner();
+ run_test3();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/array.cpp b/ml/dlib/dlib/test/array.cpp
new file mode 100644
index 000000000..55ba1a724
--- /dev/null
+++ b/ml/dlib/dlib/test/array.cpp
@@ -0,0 +1,669 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/interfaces/enumerable.h>
+#include <dlib/array.h>
+#include <dlib/rand.h>
+
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ using dlib::array;
+
+ logger dlog("test.array");
+
+ template <
+ typename array
+ >
+ void array_expand_test (
+ )
+ /*!
+ requires
+ - array is an implementation of array/array_sort_abstract.h
+ array is instantiated with unsigned long
+ ensures
+ - runs tests on array for compliance with the specs
+ !*/
+ {
+ dlib::rand rnd;
+
+ DLIB_TEST(dlib::is_array<array>::value == true);
+
+ array a1, a2;
+
+ {
+ array a4(4);
+ DLIB_TEST(a4.size() == 4);
+ }
+
+ {
+ array a1, a2;
+
+ for (int k = 1; k < 100000; k += 1000)
+ {
+ for (int i = 0; i < 10; ++i)
+ {
+ a1.clear();
+ a1.set_max_size(500+k);
+ a1.set_size(500+k);
+ for (unsigned long j = 0; j < a1.size(); ++j)
+ {
+ a1[j] = j;
+ DLIB_TEST(a1[j] == j);
+ }
+ }
+ }
+ }
+
+ DLIB_TEST(a1.max_size() == 0);
+ DLIB_TEST(a2.max_size() == 0);
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.reset();
+ a2.reset();
+
+ for (unsigned long k = 0; k < 4; ++k)
+ {
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.clear();
+ a2.clear();
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.clear();
+ a2.clear();
+
+
+
+
+ a1.set_max_size(100000);
+ a2.set_max_size(100000);
+ a1.set_size(10000);
+ DLIB_TEST(a1.size() == 10000);
+ a2.set_size(10000);
+ DLIB_TEST(a2.size() == 10000);
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ unsigned long a = static_cast<unsigned long>(rnd.get_random_32bit_number());
+ a1[i] = a;
+ a2[i] = i;
+ DLIB_TEST(a1[i] == a);
+ DLIB_TEST(a2[i] == i);
+ }
+
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next());
+ DLIB_TEST(a1.current_element_valid());
+
+ DLIB_TEST(a1.at_start() == false);
+ a1.sort();
+ DLIB_TEST(a1.at_start());
+ a2.sort();
+ DLIB_TEST(a1.size() == 10000);
+ DLIB_TEST(a2.size() == 10000);
+
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ if (i+1 < a1.size())
+ {
+ DLIB_TEST_MSG(a1[i] <= a1[i+1],
+ "a1[i]: " << a1[i] << " a1[i+1]: " << a1[i+1]
+ << " i: " << i);
+ }
+ DLIB_TEST_MSG(a2[i] == i,"i: " << i << " a2[i]: " << a2[i]);
+ }
+
+ unsigned long last = 0;
+ unsigned long count = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ ++count;
+ }
+ DLIB_TEST(count == a1.size());
+
+ last = 0;
+ count = 0;
+ while (a2.move_next())
+ {
+ DLIB_TEST(last <= a2.element());
+ last = a2.element();
+ ++count;
+ }
+ DLIB_TEST(count == a2.size());
+
+ a2.set_size(15000);
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ if (i+1 < a1.size())
+ {
+ DLIB_TEST(a1[i] <= a1[i+1]);
+ }
+ DLIB_TEST(a2[i] == i);
+ }
+
+ for (unsigned long i = 10000; i < a2.size(); ++i)
+ {
+ a2[i] = i;
+ DLIB_TEST(a2[i] == i);
+ }
+
+ for (unsigned long i = 0; i < a2.size(); ++i)
+ {
+ DLIB_TEST(a2[i] == i);
+ }
+
+ a2.reset();
+ last = 0;
+ while (a2.move_next())
+ {
+ DLIB_TEST(last <= a2.element());
+ last = a2.element();
+ }
+
+ a1.reset();
+ last = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ }
+
+ a1.sort();
+ last = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ }
+
+ swap(a2,a1);
+
+ for (unsigned long i = 0; i < 15000; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+
+
+ a1.clear();
+ DLIB_TEST(a1.max_size() == 0);
+
+
+
+
+ a1.clear();
+ a2.clear();
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a2.size() == 0);
+ a1.set_max_size(100000);
+ a2.set_max_size(100000);
+
+ a1.set_size(10000);
+ DLIB_TEST(a1.size() == 10000);
+ a2.set_size(10000);
+ DLIB_TEST(a2.size() == 10000);
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ unsigned long a = static_cast<unsigned long>(rnd.get_random_32bit_number());
+ a1[i] = a;
+ a2[i] = i;
+ DLIB_TEST(a1[i] == a);
+ DLIB_TEST(a2[i] == i);
+ }
+
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next());
+ DLIB_TEST(a1.current_element_valid());
+
+ DLIB_TEST(a1.at_start() == false);
+ a1.sort();
+ DLIB_TEST(a1.at_start());
+ a2.sort();
+ DLIB_TEST(a1.size() == 10000);
+ DLIB_TEST(a2.size() == 10000);
+
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ if (i+1 < a1.size())
+ {
+ DLIB_TEST(a1[i] <= a1[i+1]);
+ }
+ DLIB_TEST(a2[i] == i);
+ }
+
+ last = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ }
+
+ last = 0;
+ while (a2.move_next())
+ {
+ DLIB_TEST(last <= a2.element());
+ last = a2.element();
+ }
+
+ a2.set_size(15000);
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ if (i+1 < a1.size())
+ {
+ DLIB_TEST(a1[i] <= a1[i+1]);
+ }
+ DLIB_TEST(a2[i] == i);
+ }
+
+ for (unsigned long i = 10000; i < a2.size(); ++i)
+ {
+ a2[i] = i;
+ DLIB_TEST(a2[i] == i);
+ }
+
+ for (unsigned long i = 0; i < a2.size(); ++i)
+ {
+ DLIB_TEST(a2[i] == i);
+ }
+
+ a2.reset();
+ last = 0;
+ while (a2.move_next())
+ {
+ DLIB_TEST(last <= a2.element());
+ last = a2.element();
+ }
+
+ a1.reset();
+ last = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ }
+
+ a1.sort();
+ last = 0;
+ while (a1.move_next())
+ {
+ DLIB_TEST(last <= a1.element());
+ last = a1.element();
+ }
+
+ swap(a2,a1);
+
+ for (unsigned long i = 0; i < 15000; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+
+
+ a1.clear();
+ DLIB_TEST(a1.max_size() == 0);
+
+ a2.clear();
+ print_spinner();
+ }
+
+
+
+ a1.set_max_size(2000000);
+ DLIB_TEST(a1.max_size() == 2000000);
+ DLIB_TEST(a1.size() == 0);
+ a1.set_size(2000000);
+ DLIB_TEST(a1.max_size() == 2000000);
+ DLIB_TEST(a1.size() == 2000000);
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ a1[i] = rnd.get_random_32bit_number();
+ }
+
+ print_spinner();
+ a1.sort();
+
+ print_spinner();
+ // serialize the state of a1, then clear a1, then
+ // load the state back into a1.
+ ostringstream sout;
+ serialize(a1,sout);
+ DLIB_TEST(a1.at_start() == true);
+ istringstream sin(sout.str());
+ a1.clear();
+ DLIB_TEST(a1.max_size() == 0);
+ deserialize(a1,sin);
+
+ DLIB_TEST(a1.size() == 2000000);
+
+ for (unsigned long i = 0; i < a1.size()-1; ++i)
+ {
+ DLIB_TEST(a1[i] <= a1[i+1]);
+ }
+
+ DLIB_TEST(a1.max_size() == 2000000);
+ DLIB_TEST(a1.size() == 2000000);
+
+
+ swap(a1,a2);
+
+ print_spinner();
+
+ DLIB_TEST(a2.size() == 2000000);
+
+ for (unsigned long i = 0; i < a2.size()-1; ++i)
+ {
+ DLIB_TEST(a2[i] <= a2[i+1]);
+ }
+
+ DLIB_TEST(a2.max_size() == 2000000);
+ DLIB_TEST(a2.size() == 2000000);
+
+ swap(a1,a2);
+
+
+ a1.clear();
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.max_size() == 0);
+
+ a1.resize(10);
+ DLIB_TEST(a1.size() == 10);
+ DLIB_TEST(a1.max_size() == 10);
+
+ for (unsigned long i = 0; i < a1.size(); ++i)
+ {
+ a1[i] = i;
+ }
+
+ print_spinner();
+ a1.resize(100);
+ DLIB_TEST(a1.size() == 100);
+ DLIB_TEST(a1.max_size() == 100);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+ a1.resize(50);
+ DLIB_TEST(a1.size() == 50);
+ DLIB_TEST(a1.max_size() == 100);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+ a1.resize(10);
+ DLIB_TEST(a1.size() == 10);
+ DLIB_TEST(a1.max_size() == 100);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+ a1.resize(20);
+ DLIB_TEST(a1.size() == 20);
+ DLIB_TEST(a1.max_size() == 100);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+
+ a1.resize(100);
+ DLIB_TEST(a1.size() == 100);
+ DLIB_TEST(a1.max_size() == 100);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+
+ {
+ a1.clear();
+ DLIB_TEST(a1.size() == 0);
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ unsigned long a = i;
+ a1.push_back(a);
+ DLIB_TEST(a1.size() == i+1);
+ DLIB_TEST(a1.back() == i);
+ }
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ unsigned long a = 0;
+ a1.pop_back(a);
+ DLIB_TEST(a == 99-i);
+ }
+ }
+
+ {
+ a1.clear();
+ DLIB_TEST(a1.size() == 0);
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ unsigned long a = i;
+ a1.push_back(a);
+ DLIB_TEST(a1.size() == i+1);
+ DLIB_TEST(a1.back() == i);
+ }
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ DLIB_TEST(a1[i] == i);
+ }
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ a1.pop_back();
+ }
+ DLIB_TEST(a1.size() == 0);
+ }
+
+ }
+
+ struct stuff
+ {
+ int whatever;
+ };
+ void another_array_test()
+ {
+ array<stuff> a;
+ a.resize(5);
+ a[0].whatever = 0;
+ stuff temp;
+ temp.whatever = 99;
+ a.push_back(temp);
+ DLIB_TEST(a.size() == 6);
+ DLIB_TEST(a[5].whatever == 99);
+
+ DLIB_TEST(dlib::is_array<array<stuff> >::value == true);
+ }
+
+ void test_array_split()
+ {
+ array<int> temp(5);
+
+ for (unsigned int i = 0; i < temp.size(); ++i)
+ temp[i] = i;
+
+ array<int> b;
+
+ split_array(temp, b, 0.5);
+ DLIB_TEST(temp.size() == 2);
+ DLIB_TEST(b.size() == 3);
+
+ DLIB_TEST(temp[0] == 0);
+ DLIB_TEST(temp[1] == 1);
+ DLIB_TEST(b[0] == 2);
+ DLIB_TEST(b[1] == 3);
+ DLIB_TEST(b[2] == 4);
+ }
+
+ class array_tester : public tester
+ {
+ public:
+ array_tester (
+ ) :
+ tester ("test_array",
+ "Runs tests on the array component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ another_array_test();
+
+ // test a checking version first for good measure
+ print_spinner();
+ array_expand_test<array<unsigned long> >();
+
+ DLIB_TEST(dlib::is_array<int>::value == false);
+ test_array_split();
+ }
+ } a;
+
+
+
+
+}
+
diff --git a/ml/dlib/dlib/test/array2d.cpp b/ml/dlib/dlib/test/array2d.cpp
new file mode 100644
index 000000000..12b8b586a
--- /dev/null
+++ b/ml/dlib/dlib/test/array2d.cpp
@@ -0,0 +1,580 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/interfaces/enumerable.h>
+#include <dlib/array2d.h>
+#include "tester.h"
+#include <dlib/pixel.h>
+#include <dlib/image_transforms.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.array2d");
+
+ template <
+ typename array2d
+ >
+ void array2d_kernel_test (
+ )
+ /*!
+ requires
+ - array2d is an implementation of array2d/array2d_kernel_abstract.h
+ is instantiated with unsigned long
+ ensures
+ - runs tests on array2d for compliance with the specs
+ !*/
+ {
+ srand(static_cast<unsigned int>(time(0)));
+
+ array2d test,test2;
+
+ long nc, nr;
+
+
+ DLIB_TEST(get_rect(test).is_empty());
+
+ enumerable<unsigned long>& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+
+ DLIB_TEST(e.size() == 0);
+ DLIB_TEST(e.at_start() == true);
+ DLIB_TEST(e.current_element_valid() == false);
+
+ DLIB_TEST (e.move_next() == false);
+ DLIB_TEST (e.move_next() == false);
+ DLIB_TEST (e.move_next() == false);
+ DLIB_TEST (e.move_next() == false);
+ DLIB_TEST (e.move_next() == false);
+ DLIB_TEST (e.move_next() == false);
+
+
+ DLIB_TEST(e.size() == 0);
+ DLIB_TEST(e.at_start() == false);
+ DLIB_TEST(e.current_element_valid() == false);
+
+
+ e.reset();
+
+ DLIB_TEST(e.size() == 0);
+ DLIB_TEST(e.at_start() == true);
+ DLIB_TEST(e.current_element_valid() == false);
+
+
+ DLIB_TEST(get_rect(test).is_empty());
+
+
+
+ DLIB_TEST(test.at_start() == true);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ test.reset();
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ test.clear();
+
+
+ DLIB_TEST(test.at_start() == true);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+
+
+ test.set_size(0,0);
+
+ DLIB_TEST(get_rect(test).is_empty());
+
+ DLIB_TEST(test.at_start() == true);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+ DLIB_TEST (test.move_next() == false);
+
+ swap(test,test2);
+ DLIB_TEST (test2.at_start() == false);
+ DLIB_TEST (test2.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ swap(test,test2);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ test.reset();
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+
+
+ for (int j = 0; j < 30; ++j)
+ {
+ test2.clear();
+ switch (j)
+ {
+ case 0:
+ nc = 10;
+ nr = 11;
+ break;
+ case 1:
+ nc = 1;
+ nr = 1;
+ break;
+ case 2:
+ nc = 100;
+ nr = 1;
+ break;
+ case 3:
+ nc = 1;
+ nr = 100;
+ break;
+ default:
+ nc = ::rand()%100 + 1;
+ nr = ::rand()%100 + 1;
+ break;
+ }
+
+ test.set_size(nr,nc);
+
+ DLIB_TEST(get_rect(test).left() == 0);
+ DLIB_TEST(get_rect(test).top() == 0);
+ DLIB_TEST(get_rect(test).right() == nc-1);
+ DLIB_TEST(get_rect(test).bottom() == nr-1);
+
+ DLIB_TEST(test.size() == static_cast<unsigned long>(nc*nr));
+ DLIB_TEST(test.nr() == nr);
+ DLIB_TEST(test.nc() == nc);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ unsigned long i = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.current_element_valid() == true);
+ DLIB_TEST(test.at_start() == false);
+ test.element() = i;
+ DLIB_TEST(const_cast<const array2d&>(test).element() == i);
+ ++i;
+ }
+ DLIB_TEST(i == test.size());
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST(test.nr() == nr);
+ DLIB_TEST(test.nc() == nc);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.size() == static_cast<unsigned long>(nc*nr));
+
+ i = 0;
+ for (long row = 0; row < test.nr(); ++row)
+ {
+ for (long col = 0; col < test.nc(); ++col)
+ {
+ DLIB_TEST_MSG(test[row][col] == i,
+ "\n\trow: " << row <<
+ "\n\tcol: " << col <<
+ "\n\ti: " << i <<
+ "\n\ttest[row][col]: " << test[row][col]);
+ DLIB_TEST(test[row].nc() == test.nc());
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST(test.nr() == nr);
+ DLIB_TEST(test.nc() == nc);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.size() == static_cast<unsigned long>(nc*nr));
+ ++i;
+ }
+ }
+
+ test.reset();
+
+ i = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == i);
+ ++i;
+ DLIB_TEST(test.current_element_valid() == true);
+ DLIB_TEST(test.at_start() == false);
+ }
+ DLIB_TEST(i == test.size());
+
+ test.reset();
+
+
+
+
+ swap(test,test2);
+
+ DLIB_TEST(test2.size() == static_cast<unsigned long>(nc*nr));
+ DLIB_TEST(test2.nr() == nr);
+ DLIB_TEST(test2.nc() == nc);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ i = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.at_start() == false);
+ test2.element() = i;
+ ++i;
+ }
+ DLIB_TEST(i == test2.size());
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ DLIB_TEST(test2.nr() == nr);
+ DLIB_TEST(test2.nr() == test2.nr());
+ DLIB_TEST(test2.nc() == nc);
+ DLIB_TEST(test2.nc() == test2.nc());
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.size() == static_cast<unsigned long>(nc*nr));
+
+ i = 0;
+ for (long row = 0; row < test2.nr(); ++row)
+ {
+ for (long col = 0; col < test2.nc(); ++col)
+ {
+ DLIB_TEST(test2[row][col] == i);
+ DLIB_TEST(const_cast<const array2d&>(test2)[row][col] == i);
+ DLIB_TEST(test2[row].nc() == test2.nc());
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ DLIB_TEST(test2.nr() == nr);
+ DLIB_TEST(test2.nr() == test2.nr());
+ DLIB_TEST(test2.nc() == nc);
+ DLIB_TEST(test2.nc() == test2.nc());
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.size() == static_cast<unsigned long>(nc*nr));
+ ++i;
+ }
+ }
+
+ test2.reset();
+
+ i = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element() == i);
+ DLIB_TEST(const_cast<const array2d&>(test2).element() == i);
+ ++i;
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.at_start() == false);
+ }
+ DLIB_TEST(i == test2.size());
+
+
+ test2.clear();
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.nr() == 0);
+ DLIB_TEST(test2.nc() == 0);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == true);
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.nc() == 0);
+ DLIB_TEST(test.nr() == 0);
+
+ test.set_size(nr,nc);
+ DLIB_TEST(test.size() == static_cast<unsigned long>(nc*nr));
+ DLIB_TEST(test.nc() == nc);
+ DLIB_TEST(test.nr() == nr);
+
+
+
+ }
+
+
+
+
+
+ // test the serialization
+ istringstream sin;
+ ostringstream sout;
+ test.clear();
+ test2.clear();
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.nc() == 0);
+ DLIB_TEST(test.nr() == 0);
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.nc() == 0);
+ DLIB_TEST(test2.nr() == 0);
+
+ test.set_size(10,10);
+
+ for (long row = 0; row < test.nr(); ++row)
+ {
+ for (long col = 0; col < test.nc(); ++col)
+ {
+ test[row][col] = row*col;
+ }
+ }
+
+ serialize(test,sout);
+ sin.str(sout.str());
+ deserialize(test2,sin);
+
+ DLIB_TEST(test2.size() == test.size());
+ DLIB_TEST(test2.nc() == test.nc());
+ DLIB_TEST(test2.nr() == test.nr());
+ DLIB_TEST(test2.size() == 100);
+ DLIB_TEST(test2.nc() == 10);
+ DLIB_TEST(test2.nr() == 10);
+
+
+ for (long row = 0; row < test.nr(); ++row)
+ {
+ for (long col = 0; col < test.nc(); ++col)
+ {
+ DLIB_TEST(test[row][col] == static_cast<unsigned long>(row*col));
+ DLIB_TEST(test2[row][col] == static_cast<unsigned long>(row*col));
+ }
+ }
+
+
+
+
+
+
+ test.set_size(10,11);
+ DLIB_TEST(test.nr() == 10);
+ DLIB_TEST(test.nc() == 11);
+ test.set_size(0,0);
+ DLIB_TEST(test.nr() == 0);
+ DLIB_TEST(test.nc() == 0);
+
+ }
+
+ void test_serialization()
+ {
+ // Do these tests because there are overloads of the serialize routines
+ // specifically for these types of pixel (except for unsigned short,
+ // we do that because you can never have too many tests).
+ {
+ array2d<rgb_alpha_pixel> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1].red = 9;
+ img[1][1].green = 8;
+ img[1][1].blue = 7;
+ img[1][1].alpha = 3;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c].red == img2[r][c].red);
+ DLIB_TEST(img[r][c].green == img2[r][c].green);
+ DLIB_TEST(img[r][c].blue == img2[r][c].blue);
+ DLIB_TEST(img[r][c].alpha == img2[r][c].alpha);
+ }
+ }
+ }
+ {
+ array2d<hsi_pixel> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1].h = 9;
+ img[1][1].s = 2;
+ img[1][1].i = 3;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c].h == img2[r][c].h);
+ DLIB_TEST(img[r][c].s == img2[r][c].s);
+ DLIB_TEST(img[r][c].i == img2[r][c].i);
+ }
+ }
+ }
+ {
+ array2d<bgr_pixel> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1].red = 1;
+ img[1][1].green = 2;
+ img[1][1].blue = 3;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c].red == img2[r][c].red);
+ DLIB_TEST(img[r][c].green == img2[r][c].green);
+ DLIB_TEST(img[r][c].blue == img2[r][c].blue);
+ }
+ }
+ }
+ {
+ array2d<rgb_pixel> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1].red = 1;
+ img[1][1].green = 2;
+ img[1][1].blue = 3;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c].red == img2[r][c].red);
+ DLIB_TEST(img[r][c].green == img2[r][c].green);
+ DLIB_TEST(img[r][c].blue == img2[r][c].blue);
+ }
+ }
+ }
+ {
+ array2d<unsigned short> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1] = 9;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c] == img2[r][c]);
+ }
+ }
+ }
+ {
+ array2d<unsigned char> img, img2;
+ img.set_size(3,2);
+ assign_all_pixels(img, 5);
+ img[1][1] = 9;
+ ostringstream sout;
+ serialize(img, sout);
+ istringstream sin(sout.str());
+ deserialize(img2, sin);
+
+ DLIB_TEST(img2.nr() == 3);
+ DLIB_TEST(img2.nc() == 2);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ DLIB_TEST(img[r][c] == img2[r][c]);
+ }
+ }
+
+ DLIB_TEST((char*)&img[0][0] + img.width_step() == (char*)&img[1][0]);
+ }
+
+ COMPILE_TIME_ASSERT(is_array2d<array2d<unsigned char> >::value == true);
+ COMPILE_TIME_ASSERT(is_array2d<array2d<float> >::value == true);
+ COMPILE_TIME_ASSERT(is_array2d<float>::value == false);
+ }
+
+
+ class array2d_tester : public tester
+ {
+ public:
+ array2d_tester (
+ ) :
+ tester ("test_array2d",
+ "Runs tests on the array2d component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ array2d_kernel_test<array2d<unsigned long> >();
+ print_spinner();
+ test_serialization();
+ print_spinner();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/assignment_learning.cpp b/ml/dlib/dlib/test/assignment_learning.cpp
new file mode 100644
index 000000000..bba47db3d
--- /dev/null
+++ b/ml/dlib/dlib/test/assignment_learning.cpp
@@ -0,0 +1,379 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/rand.h>
+
+
+typedef dlib::matrix<double,3,1> lhs_element;
+typedef dlib::matrix<double,3,1> rhs_element;
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.assignment_learning");
+
+// ----------------------------------------------------------------------------------------
+
+
+// ----------------------------------------------------------------------------------------
+
+ struct feature_extractor_dense
+ {
+ typedef matrix<double,3,1> feature_vector_type;
+
+ typedef ::lhs_element lhs_element;
+ typedef ::rhs_element rhs_element;
+
+ unsigned long num_features() const
+ {
+ return 3;
+ }
+
+ void get_features (
+ const lhs_element& left,
+ const rhs_element& right,
+ feature_vector_type& feats
+ ) const
+ {
+ feats = squared(left - right);
+ }
+
+ };
+
+ void serialize (const feature_extractor_dense& , std::ostream& ) {}
+ void deserialize (feature_extractor_dense& , std::istream& ) {}
+
+// ----------------------------------------------------------------------------------------
+
+ struct feature_extractor_sparse
+ {
+ typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;
+
+ typedef ::lhs_element lhs_element;
+ typedef ::rhs_element rhs_element;
+
+ unsigned long num_features() const
+ {
+ return 3;
+ }
+
+ void get_features (
+ const lhs_element& left,
+ const rhs_element& right,
+ feature_vector_type& feats
+ ) const
+ {
+ feats.clear();
+ feats.push_back(make_pair(0,squared(left-right)(0)));
+ feats.push_back(make_pair(1,squared(left-right)(1)));
+ feats.push_back(make_pair(2,squared(left-right)(2)));
+ }
+
+ };
+
+ void serialize (const feature_extractor_sparse& , std::ostream& ) {}
+ void deserialize (feature_extractor_sparse& , std::istream& ) {}
+
+// ----------------------------------------------------------------------------------------
+
+ typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
+ typedef std::vector<long> label_type;
+
+// ----------------------------------------------------------------------------------------
+
+ void make_data (
+ std::vector<sample_type>& samples,
+ std::vector<label_type>& labels
+ )
+ {
+ lhs_element a, b, c, d;
+ a = 1,0,0;
+ b = 0,1,0;
+ c = 0,0,1;
+ d = 0,1,1;
+
+ std::vector<lhs_element> lhs;
+ std::vector<rhs_element> rhs;
+ label_type label;
+
+ lhs.push_back(a);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ rhs.push_back(b);
+ rhs.push_back(a);
+ rhs.push_back(c);
+
+ label.push_back(1);
+ label.push_back(0);
+ label.push_back(2);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ lhs.push_back(a);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ rhs.push_back(c);
+ rhs.push_back(b);
+ rhs.push_back(a);
+ rhs.push_back(d);
+
+ label.push_back(2);
+ label.push_back(1);
+ label.push_back(0);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ lhs.push_back(a);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ rhs.push_back(c);
+ rhs.push_back(a);
+ rhs.push_back(d);
+
+ label.push_back(1);
+ label.push_back(-1);
+ label.push_back(0);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ lhs.push_back(d);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ label.push_back(-1);
+ label.push_back(-1);
+ label.push_back(-1);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_data_force (
+ std::vector<sample_type>& samples,
+ std::vector<label_type>& labels
+ )
+ {
+ lhs_element a, b, c, d;
+ a = 1,0,0;
+ b = 0,1,0;
+ c = 0,0,1;
+ d = 0,1,1;
+
+ std::vector<lhs_element> lhs;
+ std::vector<rhs_element> rhs;
+ label_type label;
+
+ lhs.push_back(a);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ rhs.push_back(b);
+ rhs.push_back(a);
+ rhs.push_back(c);
+
+ label.push_back(1);
+ label.push_back(0);
+ label.push_back(2);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ lhs.push_back(a);
+ lhs.push_back(b);
+ lhs.push_back(c);
+
+ rhs.push_back(c);
+ rhs.push_back(b);
+ rhs.push_back(a);
+ rhs.push_back(d);
+
+ label.push_back(2);
+ label.push_back(1);
+ label.push_back(0);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ lhs.push_back(a);
+ lhs.push_back(c);
+
+ rhs.push_back(c);
+ rhs.push_back(a);
+
+ label.push_back(1);
+ label.push_back(0);
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+
+
+
+
+ lhs.clear();
+ rhs.clear();
+ label.clear();
+
+ samples.push_back(make_pair(lhs,rhs));
+ labels.push_back(label);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename fe_type, typename F>
+ void test1(F make_data, bool force_assignment)
+ {
+ print_spinner();
+
+ std::vector<sample_type> samples;
+ std::vector<label_type> labels;
+
+ make_data(samples, labels);
+ make_data(samples, labels);
+ make_data(samples, labels);
+
+ randomize_samples(samples, labels);
+
+ structural_assignment_trainer<fe_type> trainer;
+
+ DLIB_TEST(trainer.forces_assignment() == false);
+ DLIB_TEST(trainer.get_c() == 100);
+ DLIB_TEST(trainer.get_num_threads() == 2);
+ DLIB_TEST(trainer.get_max_cache_size() == 5);
+
+
+ trainer.set_forces_assignment(force_assignment);
+ trainer.set_num_threads(3);
+ trainer.set_c(50);
+
+ DLIB_TEST(trainer.get_c() == 50);
+ DLIB_TEST(trainer.get_num_threads() == 3);
+ DLIB_TEST(trainer.forces_assignment() == force_assignment);
+
+ assignment_function<fe_type> ass = trainer.train(samples, labels);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ std::vector<long> out = ass(samples[i]);
+ dlog << LINFO << "true labels: " << trans(mat(labels[i]));
+ dlog << LINFO << "pred labels: " << trans(mat(out));
+ DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
+ }
+
+ double accuracy;
+
+ dlog << LINFO << "samples.size(): "<< samples.size();
+ accuracy = test_assignment_function(ass, samples, labels);
+ dlog << LINFO << "accuracy: "<< accuracy;
+ DLIB_TEST(accuracy == 1);
+
+ accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3);
+ dlog << LINFO << "cv accuracy: "<< accuracy;
+ DLIB_TEST(accuracy == 1);
+
+ ostringstream sout;
+ serialize(ass, sout);
+ istringstream sin(sout.str());
+ assignment_function<fe_type> ass2;
+ deserialize(ass2, sin);
+
+ DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment());
+ DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ std::vector<long> out = ass2(samples[i]);
+ dlog << LINFO << "true labels: " << trans(mat(labels[i]));
+ dlog << LINFO << "pred labels: " << trans(mat(out));
+ DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_assignment_learning : public tester
+ {
+ public:
+ test_assignment_learning (
+ ) :
+ tester ("test_assignment_learning",
+ "Runs tests on the assignment learning code.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test1<feature_extractor_dense>(make_data, false);
+ test1<feature_extractor_sparse>(make_data, false);
+
+ test1<feature_extractor_dense>(make_data_force, false);
+ test1<feature_extractor_sparse>(make_data_force, false);
+ test1<feature_extractor_dense>(make_data_force, true);
+ test1<feature_extractor_sparse>(make_data_force, true);
+ }
+ } a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/test/base64.cpp b/ml/dlib/dlib/test/base64.cpp
new file mode 100644
index 000000000..f4d478018
--- /dev/null
+++ b/ml/dlib/dlib/test/base64.cpp
@@ -0,0 +1,208 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/base64.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.base64");
+
+ template <
+ typename base64
+ >
+ void base64_kernel_test (
+ )
+ /*!
+ requires
+ - base64 is an implementation of base64/base64_kernel_abstract.h
+ ensures
+ - runs tests on base64 for compliance with the specs
+ !*/
+ {
+
+ const unsigned int seed = static_cast<unsigned int>(time(0));
+ try
+ {
+
+ srand(seed);
+
+ base64 test;
+
+ const string wiki_normal = "\
+Man is distinguished, not only by his reason, but by this singular passion from other \
+animals, which is a lust of the mind, that by a perseverance of delight in the continued \
+and indefatigable generation of knowledge, exceeds the short vehemence of any carnal pleasure.";
+
+ const string wiki_encoded = "\
+TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0\n\
+aGlzIHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1\n\
+c3Qgb2YgdGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0\n\
+aGUgY29udGludWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdl\n\
+LCBleGNlZWRzIHRoZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=";
+
+
+
+ string str;
+
+ istringstream sin;
+ ostringstream sout;
+
+ sin.str(wiki_encoded);
+ test.decode(sin,sout);
+ DLIB_TEST_MSG(sout.str() == wiki_normal,
+ "sout.str(): " << sout.str() <<
+ "\nwiki_normal: " << wiki_normal);
+
+
+ sout.str("");
+ sin.str(wiki_normal);
+ sin.clear();
+ test.encode(sin,sout);
+
+ string a(sout.str()), b(wiki_encoded);
+ // we want to strip all the whitespace from a and b now
+ sin.str(a);
+ a.clear();
+ sin >> str;
+ while (sin)
+ {
+ a += str;
+ sin >> str;
+ }
+
+ sin.clear();
+ sin.str(b);
+ b.clear();
+ sin >> str;
+ while (sin)
+ {
+ b += str;
+ sin >> str;
+ }
+ sin.clear();
+
+ DLIB_TEST_MSG(a == b,
+ "a: \n" << a <<
+ "\n\nb: \n" << b);
+
+
+
+ sin.clear();
+ sin.str("");
+ sout.str("");
+ test.encode(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+ DLIB_TEST(sout.str() == "");
+
+ sin.clear();
+ sin.str("a");
+ sout.str("");
+ test.encode(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+ DLIB_TEST(sout.str() == "a");
+
+ sin.clear();
+ sin.str("da");
+ sout.str("");
+ test.encode(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+ DLIB_TEST(sout.str() == "da");
+
+ sin.clear();
+ sin.str("dav");
+ sout.str("");
+ test.encode(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+ DLIB_TEST(sout.str() == "dav");
+
+ sin.clear();
+ sin.str("davi");
+ sout.str("");
+ test.encode(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+ DLIB_TEST(sout.str() == "davi");
+
+
+ for (int i = 0; i < 1000; ++i)
+ {
+ str.clear();
+ sin.clear();
+ sout.str("");
+ sin.str("");
+
+ // fill str with random garbage
+ const int size = rand()%2000;
+ for (int j = 0; j < size; ++j)
+ {
+ unsigned char ch = rand()&0xFF;
+ str += ch;
+ }
+
+ sin.str(str);
+ test.encode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+ test.decode(sin,sout);
+
+ DLIB_TEST(str == sout.str());
+
+
+ }
+
+
+
+
+ }
+ catch (typename base64::decode_error& e)
+ {
+ DLIB_TEST_MSG(false,
+ "decode_error thrown when it shouldn't have been (" << seed << "):\n "
+ << e.info);
+ }
+ }
+
+
+ class base64_tester : public tester
+ {
+ public:
+ base64_tester (
+ ) :
+ tester ("test_base64",
+ "Runs tests on the base64 component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ base64_kernel_test<base64>();
+ }
+ } a;
+
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/bayes_nets.cpp b/ml/dlib/dlib/test/bayes_nets.cpp
new file mode 100644
index 000000000..1a3035762
--- /dev/null
+++ b/ml/dlib/dlib/test/bayes_nets.cpp
@@ -0,0 +1,411 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include "dlib/graph_utils.h"
+#include "dlib/graph.h"
+#include "dlib/directed_graph.h"
+#include "dlib/bayes_utils.h"
+#include "dlib/set.h"
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.bayes_nets");
+ enum nodes
+ {
+ A, T, S, L, O, B, D, X
+ };
+
+ template <typename gtype>
+ void setup_simple_network (
+ gtype& bn
+ )
+ {
+ /*
+ A
+ / \
+ T S
+ */
+
+ using namespace bayes_node_utils;
+
+ bn.set_number_of_nodes(3);
+ bn.add_edge(A, T);
+ bn.add_edge(A, S);
+
+
+ set_node_num_values(bn, A, 2);
+ set_node_num_values(bn, T, 2);
+ set_node_num_values(bn, S, 2);
+
+ assignment parents;
+
+ // set probabilities for node A
+ set_node_probability(bn, A, 1, parents, 0.1);
+ set_node_probability(bn, A, 0, parents, 1-0.1);
+
+ // set probabilities for node T
+ parents.add(A, 1);
+ set_node_probability(bn, T, 1, parents, 0.5);
+ set_node_probability(bn, T, 0, parents, 1-0.5);
+ parents[A] = 0;
+ set_node_probability(bn, T, 1, parents, 0.5);
+ set_node_probability(bn, T, 0, parents, 1-0.5);
+
+ // set probabilities for node S
+ parents[A] = 1;
+ set_node_probability(bn, S, 1, parents, 0.5);
+ set_node_probability(bn, S, 0, parents, 1-0.5);
+ parents[A] = 0;
+ set_node_probability(bn, S, 1, parents, 0.5);
+ set_node_probability(bn, S, 0, parents, 1-0.5);
+
+
+ // test the serialization code here by pushing this network though it
+ ostringstream sout;
+ serialize(bn, sout);
+ bn.clear();
+ DLIB_TEST(bn.number_of_nodes() == 0);
+ istringstream sin(sout.str());
+ deserialize(bn, sin);
+ DLIB_TEST(bn.number_of_nodes() == 3);
+ }
+
+
+ template <typename gtype>
+ void setup_dyspnea_network (
+ gtype& bn,
+ bool deterministic_o_node = true
+ )
+ {
+ /*
+ This is the example network used by David Zaret in his
+ reasoning under uncertainty class at Johns Hopkins
+ */
+
+ using namespace bayes_node_utils;
+
+ bn.set_number_of_nodes(8);
+ bn.add_edge(A, T);
+ bn.add_edge(T, O);
+
+ bn.add_edge(O, D);
+ bn.add_edge(O, X);
+
+ bn.add_edge(S, B);
+ bn.add_edge(S, L);
+
+ bn.add_edge(L, O);
+ bn.add_edge(B, D);
+
+
+ set_node_num_values(bn, A, 2);
+ set_node_num_values(bn, T, 2);
+ set_node_num_values(bn, O, 2);
+ set_node_num_values(bn, X, 2);
+ set_node_num_values(bn, L, 2);
+ set_node_num_values(bn, S, 2);
+ set_node_num_values(bn, B, 2);
+ set_node_num_values(bn, D, 2);
+
+ assignment parents;
+
+ // set probabilities for node A
+ set_node_probability(bn, A, 1, parents, 0.01);
+ set_node_probability(bn, A, 0, parents, 1-0.01);
+
+ // set probabilities for node S
+ set_node_probability(bn, S, 1, parents, 0.5);
+ set_node_probability(bn, S, 0, parents, 1-0.5);
+
+ // set probabilities for node T
+ parents.add(A, 1);
+ set_node_probability(bn, T, 1, parents, 0.05);
+ set_node_probability(bn, T, 0, parents, 1-0.05);
+ parents[A] = 0;
+ set_node_probability(bn, T, 1, parents, 0.01);
+ set_node_probability(bn, T, 0, parents, 1-0.01);
+
+ // set probabilities for node L
+ parents.clear();
+ parents.add(S,1);
+ set_node_probability(bn, L, 1, parents, 0.1);
+ set_node_probability(bn, L, 0, parents, 1-0.1);
+ parents[S] = 0;
+ set_node_probability(bn, L, 1, parents, 0.01);
+ set_node_probability(bn, L, 0, parents, 1-0.01);
+
+
+ // set probabilities for node B
+ parents[S] = 1;
+ set_node_probability(bn, B, 1, parents, 0.6);
+ set_node_probability(bn, B, 0, parents, 1-0.6);
+ parents[S] = 0;
+ set_node_probability(bn, B, 1, parents, 0.3);
+ set_node_probability(bn, B, 0, parents, 1-0.3);
+
+
+ // set probabilities for node O
+ double v;
+ if (deterministic_o_node)
+ v = 1;
+ else
+ v = 0.99;
+
+ parents.clear();
+ parents.add(T,1);
+ parents.add(L,1);
+ set_node_probability(bn, O, 1, parents, v);
+ set_node_probability(bn, O, 0, parents, 1-v);
+ parents[T] = 0; parents[L] = 1;
+ set_node_probability(bn, O, 1, parents, v);
+ set_node_probability(bn, O, 0, parents, 1-v);
+ parents[T] = 1; parents[L] = 0;
+ set_node_probability(bn, O, 1, parents, v);
+ set_node_probability(bn, O, 0, parents, 1-v);
+ parents[T] = 0; parents[L] = 0;
+ set_node_probability(bn, O, 1, parents, 1-v);
+ set_node_probability(bn, O, 0, parents, v);
+
+
+ // set probabilities for node D
+ parents.clear();
+ parents.add(O,1);
+ parents.add(B,1);
+ set_node_probability(bn, D, 1, parents, 0.9);
+ set_node_probability(bn, D, 0, parents, 1-0.9);
+ parents[O] = 1; parents[B] = 0;
+ set_node_probability(bn, D, 1, parents, 0.7);
+ set_node_probability(bn, D, 0, parents, 1-0.7);
+ parents[O] = 0; parents[B] = 1;
+ set_node_probability(bn, D, 1, parents, 0.8);
+ set_node_probability(bn, D, 0, parents, 1-0.8);
+ parents[O] = 0; parents[B] = 0;
+ set_node_probability(bn, D, 1, parents, 0.1);
+ set_node_probability(bn, D, 0, parents, 1-0.1);
+
+
+ // set probabilities for node X
+ parents.clear();
+ parents.add(O,1);
+ set_node_probability(bn, X, 1, parents, 0.98);
+ set_node_probability(bn, X, 0, parents, 1-0.98);
+ parents[O] = 0;
+ set_node_probability(bn, X, 1, parents, 0.05);
+ set_node_probability(bn, X, 0, parents, 1-0.05);
+
+
+ // test the serialization code here by pushing this network though it
+ ostringstream sout;
+ serialize(bn, sout);
+ bn.clear();
+ DLIB_TEST(bn.number_of_nodes() == 0);
+ istringstream sin(sout.str());
+ deserialize(bn, sin);
+ DLIB_TEST(bn.number_of_nodes() == 8);
+ }
+
+
+ void bayes_nets_test (
+ )
+ /*!
+ ensures
+ - runs tests on the bayesian network objects and functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ directed_graph<bayes_node>::kernel_1a_c bn;
+ setup_dyspnea_network(bn);
+
+ using namespace bayes_node_utils;
+
+
+ graph<dlib::set<unsigned long>::compare_1b_c, dlib::set<unsigned long>::compare_1b_c>::kernel_1a_c join_tree;
+
+ create_moral_graph(bn, join_tree);
+ create_join_tree(join_tree, join_tree);
+
+ bayesian_network_join_tree solution(bn, join_tree);
+
+ matrix<double,1,2> dist;
+
+ dist = solution.probability(A);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.01 ) < 1e-5);
+
+ dist = solution.probability(T);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.0104) < 1e-5);
+
+ dist = solution.probability(O);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.064828) < 1e-5);
+
+ dist = solution.probability(X);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.11029004) < 1e-5);
+
+ dist = solution.probability(L);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.055) < 1e-5);
+
+ dist = solution.probability(S);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.5) < 1e-5);
+
+ dist = solution.probability(B);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.4499999) < 1e-5);
+
+ dist = solution.probability(D);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.4359706 ) < 1e-5);
+
+ // now lets modify the probabilities of the bayesian network by making O
+ // not a deterministic node anymore but otherwise leave the network alone
+ setup_dyspnea_network(bn, false);
+
+ set_node_value(bn, A, 1);
+ set_node_value(bn, X, 1);
+ set_node_value(bn, S, 1);
+ // lets also make some of these nodes evidence nodes
+ set_node_as_evidence(bn, A);
+ set_node_as_evidence(bn, X);
+ set_node_as_evidence(bn, S);
+
+ // reload the solution now that we have changed the probabilities of node O
+ bayesian_network_join_tree(bn, join_tree).swap(solution);
+ DLIB_TEST(solution.number_of_nodes() == bn.number_of_nodes());
+
+ dist = solution.probability(A);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5);
+
+ dist = solution.probability(T);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.253508694039 ) < 1e-5);
+
+ dist = solution.probability(O);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.77856184024 ) < 1e-5);
+
+ dist = solution.probability(X);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5);
+
+ dist = solution.probability(L);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.5070173880 ) < 1e-5);
+
+ dist = solution.probability(S);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5);
+
+ dist = solution.probability(B);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.6 ) < 1e-5);
+
+ dist = solution.probability(D);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.7535685520 ) < 1e-5);
+
+
+ // now lets test the bayesian_network_gibbs_sampler
+ set_node_value(bn, A, 1);
+ set_node_value(bn, T, 1);
+ set_node_value(bn, O, 1);
+ set_node_value(bn, X, 1);
+ set_node_value(bn, S, 1);
+ set_node_value(bn, L, 1);
+ set_node_value(bn, B, 1);
+ set_node_value(bn, D, 1);
+
+ bayesian_network_gibbs_sampler sampler;
+ matrix<double,1,8> counts;
+ set_all_elements(counts, 0);
+ const unsigned long rounds = 500000;
+ for (unsigned long i = 0; i < rounds; ++i)
+ {
+ sampler.sample_graph(bn);
+
+ for (long c = 0; c < counts.nc(); ++c)
+ {
+ if (node_value(bn, c) == 1)
+ counts(c) += 1;
+ }
+
+ if ((i&0x3FF) == 0)
+ {
+ print_spinner();
+ }
+ }
+
+ counts /= rounds;
+
+ DLIB_TEST(abs(counts(A) - 1.0 ) < 1e-2);
+ DLIB_TEST(abs(counts(T) - 0.253508694039 ) < 1e-2);
+ DLIB_TEST_MSG(abs(counts(O) - 0.77856184024 ) < 1e-2,abs(counts(O) - 0.77856184024 ) );
+ DLIB_TEST(abs(counts(X) - 1.0 ) < 1e-2);
+ DLIB_TEST(abs(counts(L) - 0.5070173880 ) < 1e-2);
+ DLIB_TEST(abs(counts(S) - 1.0 ) < 1e-2);
+ DLIB_TEST(abs(counts(B) - 0.6 ) < 1e-2);
+ DLIB_TEST(abs(counts(D) - 0.7535685520 ) < 1e-2);
+
+
+ setup_simple_network(bn);
+ create_moral_graph(bn, join_tree);
+ create_join_tree(join_tree, join_tree);
+ bayesian_network_join_tree(bn, join_tree).swap(solution);
+ DLIB_TEST(solution.number_of_nodes() == bn.number_of_nodes());
+
+ dist = solution.probability(A);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.1 ) < 1e-5);
+
+ dist = solution.probability(T);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.5 ) < 1e-5);
+
+ dist = solution.probability(S);
+ DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5);
+ DLIB_TEST(abs(dist(1) - 0.5 ) < 1e-5);
+
+
+ }
+
+
+
+
+ class bayes_nets_tester : public tester
+ {
+ public:
+ bayes_nets_tester (
+ ) :
+ tester ("test_bayes_nets",
+ "Runs tests on the bayes_nets objects and functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ bayes_nets_test();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/bigint.cpp b/ml/dlib/dlib/test/bigint.cpp
new file mode 100644
index 000000000..3ddc631b4
--- /dev/null
+++ b/ml/dlib/dlib/test/bigint.cpp
@@ -0,0 +1,522 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+
+#include <dlib/bigint.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.bigint");
+
+ namespace bigint_kernel_test_helpers
+ {
+ template <
+ typename bint
+ >
+ bint short_fact (unsigned short value)
+ /*!
+ ensures
+ - returns the factorial of value
+ !*/
+ {
+ using namespace relational_operators;
+
+ bint a = 1;
+ for (unsigned short i = 2; i <= value; ++i)
+ a *= i;
+
+ return a;
+ }
+
+ template <
+ typename bint
+ >
+ bint short_fact_squared (unsigned short value)
+ /*!
+ ensures
+ - returns the square of the factorial of value
+ !*/
+ {
+ using namespace relational_operators;
+
+ bint a = 1;
+ for (unsigned short i = 2; i <= value; ++i)
+ {
+ a *= i;
+ a *= i;
+ }
+
+ return a;
+ }
+
+
+ template <
+ typename bint
+ >
+ bint big_fact (unsigned short value)
+ /*!
+ ensures
+ - returns the factorial of value
+ !*/
+ {
+ using namespace relational_operators;
+
+ bint a = 1;
+ int k = 0;
+ for (bint i = 2; i <= value; ++i)
+ {
+ ++k;
+ if (k%10 == 0)
+ print_spinner();
+ a *= i;
+ }
+
+ return a;
+ }
+ }
+
+ template <
+ typename bint
+ >
+ void bigint_kernel_test (
+ )
+ /*!
+ requires
+ - bint is an implementation of bigint/bigint_kernel_abstract.h
+ ensures
+ - runs tests on bint for compliance with the specs
+ !*/
+ {
+ using namespace bigint_kernel_test_helpers;
+ using namespace relational_operators;
+ istringstream sin;
+ ostringstream sout;
+
+ bint i = 0;
+ bint a(5), b, c(0);
+
+ DLIB_TEST(5 - a == 0);
+ DLIB_TEST(a - 5 == 0);
+
+ DLIB_TEST(0 - c == 0);
+ DLIB_TEST(c - 0 == 0);
+
+ DLIB_TEST(0 + c == 0);
+ DLIB_TEST(c + 0 == 0);
+
+ DLIB_TEST(0 + a == 5);
+ DLIB_TEST(a + 0 == 5);
+
+ DLIB_TEST(0 - b == 0);
+ DLIB_TEST(b - 0 == 0);
+
+ DLIB_TEST(0 + b == 0);
+ DLIB_TEST(b + 0 == 0);
+
+ DLIB_TEST(i == 0);
+ DLIB_TEST(a == 5);
+ DLIB_TEST(b == 0);
+ DLIB_TEST(c == 0);
+
+
+
+ a -= 5;
+ DLIB_TEST(a == 0);
+
+
+
+ for (int k = 0; k < 100; ++k)
+ {
+ // compute the factorial of k using the O(n) multiplication algorithm
+ a = short_fact<bint>(k);
+ // compute the factorial of k using the full blown big int
+ // multiplication algorithm.
+ b = big_fact<bint>(k);
+ // compute the square of the factorial of k using the full blown
+ // big int multiplication algorithm.
+ c = a*b;
+ // make sure a and b ended up being the same number
+ DLIB_TEST_MSG(a == b,
+ "k: " << k << "\n"
+ "short_fact: " << a << "\n"
+ "big_fact: " << b
+ );
+ // make sure c really is the square of the factorial of k
+ DLIB_TEST_MSG(short_fact_squared<bint>(k) == c,"k: " << k);
+ print_spinner();
+ }
+
+ // do the same thing as the last loop but do it with way bigger numbers
+ for (int k = 1000; k < 10000; k += 2000)
+ {
+ bint a = short_fact<bint>(k);
+ bint b = big_fact<bint>(k);
+ bint c = a*b;
+ DLIB_TEST_MSG(a == b,
+ "k: " << k << "\n"
+ "short_fact: " << a << "\n"
+ "big_fact: " << b
+ );
+ DLIB_TEST_MSG(short_fact_squared<bint>(k) == c,"k: " << k);
+ print_spinner();
+ }
+
+
+
+ // test the << and >> operators a little
+ a = big_fact<bint>(20);
+ sout << a;
+ DLIB_TEST_MSG( sout.str() == "2432902008176640000","was: " << a);
+
+ sin.str("684626312793279327952039475203945");
+ sin >> a;
+ sout.str("");
+ sout << a;
+ DLIB_TEST(sout.str() == "684626312793279327952039475203945");
+
+ print_spinner();
+
+ DLIB_TEST(a > 0);
+
+
+ // make sure that when you try to read something that isn't a number
+ // into a bigint you get an error
+ DLIB_TEST(sin.fail() == false);
+ sin.str("the cat ate some cheese");
+ sin >> a;
+ DLIB_TEST(sin.fail() == true);
+ sin.clear();
+ sin.str("");
+
+
+
+ sin.str("3628913");
+ sin >> i;
+ DLIB_TEST(short_fact<bint>(10) + short_fact<bint>(5) - 7 == i);
+
+ sin.str("2432902008173011193");
+ sin >> i;
+ DLIB_TEST(short_fact<bint>(20) - short_fact<bint>(10) - 7 == i);
+
+ // test the serialization stuff
+ sout.str("");
+ serialize(i,sout);
+ i = 0;
+ sin.str(sout.str());
+ deserialize(i,sin);
+
+ DLIB_TEST(short_fact<bint>(20) - short_fact<bint>(10) - 7 == i);
+
+
+
+
+ print_spinner();
+
+
+
+
+ sin.str("100000");
+ sin >> b;
+ a = b;
+ ++b;
+ DLIB_TEST_MSG ( a + 1 == b,"a==" << a << endl << "b==" << b << endl);
+
+
+
+
+
+ // compute some stuff and see if you get the right value
+ a = 0;
+ b = 0;
+ sin.str("1000000");
+ sin >> b;
+ int mel = 0;
+ for (i = a; i <= b; ++i)
+ {
+ // switch it up on em
+ if (i%2 == 0)
+ a = a + i;
+ else
+ a += i;
+ ++mel;
+ if ((mel&0xFFF) == 0)
+ print_spinner();
+ }
+ DLIB_TEST_MSG(a == b*(b+1)/2, "a==" << a << endl << "b*(b+1)/2==" << b*(b+1)/2 << endl);
+
+
+
+
+
+
+ print_spinner();
+
+
+ // compute some stuff and see if you get the right value
+ // this time going the other way using operator--
+ a = 0;
+ b = 0;
+ sin.str("100000");
+ sin >> b;
+ i = b;
+ DLIB_TEST(i == b);
+ DLIB_TEST_MSG(i > 0,"i==" << i);
+ mel = 0;
+ for (i = b; i > 0; --i)
+ {
+ // switch it up on em
+ if (i%2 == 0)
+ a = a + i;
+ else
+ a += i;
+ ++mel;
+ if ((mel&0xFF) == 0)
+ print_spinner();
+ }
+ DLIB_TEST_MSG(a == b*(b+1)/2, "a==" << a << endl << "b*(b+1)/2==" << b*(b+1)/2 << endl);
+
+
+
+
+
+
+
+
+
+
+
+ DLIB_TEST(short_fact<bint>(10)/short_fact<bint>(5) == 30240);
+ DLIB_TEST(short_fact<bint>(10)/(short_fact<bint>(5)+1) == 29990);
+
+ sin.str("221172909834240000");
+ sin >> a;
+ DLIB_TEST(short_fact<bint>(20)/(short_fact<bint>(5)+1) == a/11);
+
+ sin.str("670442388044");
+ sin >> b;
+ DLIB_TEST(short_fact<bint>(20)/(short_fact<bint>(10)+1) == b);
+
+ print_spinner();
+
+ sin.str("1860479");
+ sin >> i;
+ DLIB_TEST_MSG(short_fact<bint>(20)/(short_fact<bint>(15)+1) == i,short_fact<bint>(20)/(short_fact<bint>(15)+1));
+
+ // test the serialization stuff
+ sout.str("");
+ serialize(i,sout);
+ i = 0;
+ sin.str(sout.str());
+ deserialize(i,sin);
+
+ DLIB_TEST_MSG(short_fact<bint>(20)/(short_fact<bint>(15)+1) == i,short_fact<bint>(20)/(short_fact<bint>(15)+1));
+
+
+ print_spinner();
+
+ // test the serialization stuff
+ sout.str("");
+ i = 0;
+ serialize(i,sout);
+ i = 1234;
+ sin.str(sout.str());
+ deserialize(i,sin);
+ DLIB_TEST(i == 0);
+
+
+ DLIB_TEST(short_fact<bint>(10000)/short_fact<bint>(9999) == 10000);
+
+
+ DLIB_TEST(bint(5)%bint(1) == 0);
+ DLIB_TEST(bint(5)%bint(6) == 5);
+ DLIB_TEST(bint(25)%bint(6) == 1);
+ print_spinner();
+ DLIB_TEST(bint(354)%bint(123) == 108);
+ DLIB_TEST(bint(20)%(bint(10)) == 0);
+ DLIB_TEST(bint(20)%(bint(10)+1) == 9);
+
+ DLIB_TEST(bint(20)%(bint(15)+1) == 4);
+
+
+ DLIB_TEST(short_fact<bint>(10)%(short_fact<bint>(5)+2) == 32);
+
+ sin.str("2908082");
+ sin >> i;
+ DLIB_TEST(short_fact<bint>(15)%(short_fact<bint>(10)+2) == i);
+
+
+
+
+
+
+ // same as some of the above stuff but using big_fact
+
+ DLIB_TEST(big_fact<bint>(10)%(big_fact<bint>(5)+2) == 32);
+
+ sin.str("2908082");
+ sin >> i;
+ DLIB_TEST(big_fact<bint>(15)%(big_fact<bint>(10)+2) == i);
+
+
+ print_spinner();
+
+
+ DLIB_TEST(big_fact<bint>(10)/big_fact<bint>(5) == 30240);
+ DLIB_TEST(big_fact<bint>(10)/(big_fact<bint>(5)+1) == 29990);
+
+ sin.str("221172909834240000");
+ sin >> a;
+ DLIB_TEST(big_fact<bint>(20)/(big_fact<bint>(5)+1) == a/11);
+
+
+ sin.str("670442388044");
+ sin >> b;
+ DLIB_TEST(big_fact<bint>(20)/(big_fact<bint>(10)+1) == b);
+
+
+ sin.str("1860479");
+ sin >> i;
+ DLIB_TEST_MSG(big_fact<bint>(20)/(big_fact<bint>(15)+1) == i,big_fact<bint>(20)/(big_fact<bint>(15)+1));
+
+ DLIB_TEST(big_fact<bint>(100)/big_fact<bint>(99) == 100);
+
+
+
+
+ sout.str("");
+ sout << "148571596448176149730952273362082573788556996128468876694221686370498539309";
+ sout << "4065876545992131370884059645617234469978112000000000000000000000";
+ sin.str(sout.str());
+ sin >> a;
+
+ sout.str("");
+ sout << "933262154439441526816992388562667004907159682643816214685929638952175999932";
+ sout << "299156089414639761565182862536979208272237582511852109168640000000000000000";
+ sout << "000000";
+ sin.str(sout.str());
+ sin >> b;
+
+
+ sout.str("");
+ sout << "138656248189732152054159609718432247180282092567575172939636909224427929240";
+ sout << "834642263988043338170905744175653189424779336521852536242160190545537133916";
+ sout << "649622615351174407746524657461692702500613722228638559932561661493048332720";
+ sout << "6050692647868232055316807680000000000000000000000000000000000000000000";
+ sin.str(sout.str());
+ sin >> c;
+
+ DLIB_TEST_MSG(a*b == c,
+ "a*b: " << a*b <<
+ "\nc: " << c);
+
+
+ print_spinner();
+
+ i = 0;
+ mel = 0;
+ unsigned long j;
+ for (j = 0; i < bint(100000); ++j)
+ {
+ DLIB_TEST(i++ == bint(j));
+ ++mel;
+ if((mel&0xFF) == 0)
+ print_spinner();
+ }
+ DLIB_TEST(j == 100000);
+
+ i = 1234;
+
+ DLIB_TEST(i == 1234);
+ DLIB_TEST(i < 2345 );
+ DLIB_TEST(i > 0 );
+ DLIB_TEST(i > 123 );
+
+ DLIB_TEST(i != 1334);
+ DLIB_TEST(i <= 2345);
+ DLIB_TEST(i >= 0 );
+ DLIB_TEST(i >= 123 );
+ DLIB_TEST(i >= 1234);
+ DLIB_TEST(i <= 1234);
+
+
+ DLIB_TEST(1234 == i);
+ DLIB_TEST(2345 > i);
+ DLIB_TEST(0 < i);
+ DLIB_TEST(123 < i);
+
+ DLIB_TEST(1334 != i);
+ DLIB_TEST(2345 >= i);
+ DLIB_TEST(0 <= i);
+ DLIB_TEST(123 <= i);
+ DLIB_TEST(1234 <= i);
+ DLIB_TEST(1234 >= i);
+
+
+ a = big_fact<bint>(200);
+ b = big_fact<bint>(100);
+
+ DLIB_TEST(a > b);
+ DLIB_TEST(a != b);
+ DLIB_TEST(b < a);
+ DLIB_TEST(b != a);
+ DLIB_TEST(b <= a);
+ DLIB_TEST(a >= b);
+
+
+
+ a = 10000;
+ a = a*a*a*a; a = a*a; a = a*a;
+ b = 2;
+ DLIB_TEST((a/b)*b == a);
+
+ a = 10000*5;
+ a = a*a*a*a; a = a*a; a = a*a;
+ b = 5;
+ DLIB_TEST((a/b)*b == a);
+ }
+
+
+
+
+ class bigint_tester : public tester
+ {
+ public:
+ bigint_tester (
+ ) :
+ tester ("test_bigint",
+ "Runs tests on the bigint component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ bigint_kernel_test<bigint::kernel_1a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_1a_c";
+ bigint_kernel_test<bigint::kernel_1a_c>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_2a";
+ bigint_kernel_test<bigint::kernel_2a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_2a_c";
+ bigint_kernel_test<bigint::kernel_2a_c>();
+ print_spinner();
+
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/binary_search_tree.h b/ml/dlib/dlib/test/binary_search_tree.h
new file mode 100644
index 000000000..18bdff70d
--- /dev/null
+++ b/ml/dlib/dlib/test/binary_search_tree.h
@@ -0,0 +1,889 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/memory_manager_global.h>
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/binary_search_tree.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.binary_search_tree");
+
+ template <
+ typename bst
+ >
+ void binary_search_tree_kernel_test (
+ )
+ /*!
+ requires
+ - bst is an implementation of
+ binary_search_tree/binary_search_tree_kernel_abstract.h is instantiated
+ to map int to int
+ ensures
+ - runs tests on bst for compliance with the specs
+ !*/
+ {
+
+ bst test, test2;
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+ DLIB_TEST(test.count(3) == 0);
+
+ enumerable<map_pair<int,int> >& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ DLIB_TEST(test.count(3) == 0);
+
+ for (int i = 0; i < 4; ++i)
+ {
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.count(3) == 0);
+ DLIB_TEST(test.height() == 0);
+ DLIB_TEST(test[5] == 0);
+ DLIB_TEST(test[0] == 0);
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.count(3) == 0);
+
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.count(3) == 0);
+
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ test.clear();
+ test.position_enumerator(5);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.position_enumerator(5);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.position_enumerator(9);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.clear();
+ test.position_enumerator(5);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.position_enumerator(5);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.position_enumerator(9);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ test.clear();
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.count(3) == 0);
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.height() == 0);
+ DLIB_TEST(test[5] == 0);
+ DLIB_TEST(test[0] == 0);
+ DLIB_TEST(const_cast<const bst&>(test)[5] == 0);
+ DLIB_TEST(const_cast<const bst&>(test)[0] == 0);
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.count(3) == 0);
+ test.reset();
+ DLIB_TEST(test.count(3) == 0);
+
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+
+
+
+
+ int a = 0, b = 0;
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()%1000;
+ int temp = a;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(temp) == count+1);
+ }
+
+
+ {
+ unsigned long count = test.count(3);
+
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ }
+
+
+ test.clear();
+
+
+
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ b = 0;
+ int temp = a;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(temp) == count+1);
+ }
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+
+ DLIB_TEST(test.size() == 10000);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST_MSG(test.height() > 13 && test.height() <= 26,"this is somewhat of an implementation dependent "
+ << "but really it should be in this range or the implementation is just crap");
+
+ a = 0;
+ unsigned long count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST_MSG(a <= test.element().key(),"the numers are out of order but they should be in order");
+ a = test.element().key();
+ ++count;
+
+
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(count == 10000);
+
+
+
+
+ DLIB_TEST_MSG(test.height() > 13 && test.height() <= 26,"this is somewhat of an implementation dependent "
+ << "but really it should be in this range or the implementation is just crap");
+
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.size() == 10000);
+
+
+ swap(test,test2);
+
+
+ test2.reset();
+ count = 0;
+ a = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST_MSG(a <= test2.element().key(),"the numers are out of order but they should be in order");
+ a = test2.element().key();
+ ++count;
+
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+
+ if (count == 5000)
+ {
+ break;
+ }
+ }
+
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.move_next() == true);
+
+
+ test2.reset();
+
+ count = 0;
+ a = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST_MSG(a <= test2.element().key(),"the numers are out of order but they should be in order");
+ a = test2.element().key();
+ ++count;
+
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+ }
+
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+
+
+
+
+
+ int last = 0;
+ asc_pair_remover<int,int,typename bst::compare_type>& asdf = test2;
+ DLIB_TEST(asdf.size() > 0);
+ while (asdf.size() > 0)
+ {
+ asdf.remove_any(a,b);
+ DLIB_TEST(last <= a);
+ last = a;
+ --count;
+ DLIB_TEST(asdf.size() == count);
+ }
+
+
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.height() ==0);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = i;
+ b = i;
+ test2.add(a,b);
+ DLIB_TEST(test2.size() == (unsigned int)(i +1));
+ DLIB_TEST(test2.count(i) == 1);
+ }
+
+ a = 0;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.element().key() == a);
+ DLIB_TEST(test2.element().value() == a);
+ a = 0;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.element().key() == a);
+ DLIB_TEST(test2.element().value() == a);
+ a = 8;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.element().key() == a);
+ DLIB_TEST(test2.element().value() == a);
+ a = 1;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.element().key() == a);
+ DLIB_TEST(test2.element().value() == a);
+ a = -29;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.element().key() == 0);
+ DLIB_TEST(test2.element().value() == 0);
+ a = 10000;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ a = -29;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.element().key() == 0);
+ DLIB_TEST(test2.element().value() == 0);
+ a = 8;
+ test2.position_enumerator(a);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.element().key() == a);
+ DLIB_TEST(test2.element().value() == a);
+ test2.reset();
+
+
+ DLIB_TEST_MSG(test2.height() > 13 && test2.height() <= 26,"this is somewhat of an implementation dependent "
+ << "but really it should be in this range or the implementation is just crap");
+
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.size() == 10000);
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.element().key() == i);
+ }
+
+
+
+ DLIB_TEST_MSG(test2.height() > 13 && test2.height() <= 26,"this is somewhat of an implementation dependent "
+ << "but really it should be in this range or the implementation is just crap");
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.size() == 10000);
+
+
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ a = 3;
+ test2.add(a,b);
+ DLIB_TEST(test2.count(3) == 2);
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ test2.remove(i,a,b);
+ DLIB_TEST(i == a);
+ }
+ test2.remove(3,a,b);
+
+
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.height() == 0);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+
+
+ test2.clear();
+
+
+ int m = 0;
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ m = max(a,m);
+ test2.add(a,b);
+ }
+
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.move_next() == true);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.at_start() == false);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()&0xFFFF;
+ test2.position_enumerator(a);
+ if (test2[a])
+ {
+ DLIB_TEST(test2.element().key() == a);
+ }
+ else if (a <= m)
+ {
+ DLIB_TEST(test2.element().key() > a);
+ }
+ }
+
+ test2.clear();
+
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+
+
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.height() == 0);
+
+
+ for (int i = 0; i < 20000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ b = a;
+ test2.add(a,b);
+ }
+
+
+ DLIB_TEST(test2.size() == 20000);
+
+
+
+ // remove a bunch of elements randomly
+ int c;
+ for (int i = 0; i < 50000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ if (test2[a] != 0)
+ {
+ test2.remove(a,b,c);
+ DLIB_TEST(a == b);
+ }
+ }
+
+
+ // now add a bunch more
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ b = a;
+ test2.add(a,b);
+ }
+
+
+ // now iterate over it all and then remove all elements
+ {
+ int* array = new int[test2.size()];
+ int* tmp = array;
+ DLIB_TEST(test2.at_start() == true);
+ while (test2.move_next())
+ {
+ *tmp = test2.element().key();
+ ++tmp;
+ }
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+ tmp = array;
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*const_cast<const bst&>(test2)[*tmp] == *tmp);
+ ++tmp;
+ }
+
+ tmp = array;
+ while (test2.size() > 0)
+ {
+ unsigned long count = test2.count(*tmp);
+ test2.destroy(*tmp);
+ DLIB_TEST(test2.count(*tmp)+1 == count);
+ ++tmp;
+ }
+
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ test.swap(test2);
+ test.reset();
+
+ delete [] array;
+ }
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.height() == 0);
+
+ for (unsigned long i = 1; i < 100; ++i)
+ {
+ a = 1234;
+ test.add(a,b);
+ DLIB_TEST(test.count(1234) == i);
+ }
+
+ test.clear();
+
+
+
+
+
+
+ for (int m = 0; m < 3; ++m)
+ {
+
+ test2.clear();
+
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+
+
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.height() == 0);
+
+
+ int counter = 0;
+ while (counter < 10000)
+ {
+ a = ::rand()&0x7FFF;
+ b = ::rand()&0x7FFF;
+ if (test2[a] == 0)
+ {
+ test2.add(a,b);
+ ++counter;
+ }
+
+ }
+
+
+
+ DLIB_TEST(test2.size() == 10000);
+
+
+
+ // remove a bunch of elements randomly
+ for (int i = 0; i < 20000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ if (test2[a] != 0)
+ {
+ test2.remove(a,b,c);
+ DLIB_TEST(a == b);
+ }
+ }
+
+
+ // now add a bunch more
+ for (int i = 0; i < 20000; ++i)
+ {
+ a = ::rand()&0x7FFF;
+ b = ::rand()&0x7FFF;
+ if (test2[a] == 0)
+ test2.add(a,b);
+ }
+
+
+ // now iterate over it all and then remove all elements
+ {
+ int* array = new int[test2.size()];
+ int* array_val = new int[test2.size()];
+ int* tmp = array;
+ int* tmp_val = array_val;
+ DLIB_TEST(test2.at_start() == true);
+ int count = 0;
+ while (test2.move_next())
+ {
+ *tmp = test2.element().key();
+ ++tmp;
+ *tmp_val = test2.element().value();
+ ++tmp_val;
+
+ DLIB_TEST(*test2[*(tmp-1)] == *(tmp_val-1));
+ ++count;
+ }
+
+ DLIB_TEST(count == (int)test2.size());
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+ tmp = array;
+ tmp_val = array_val;
+ for (unsigned long i = 0; i < test2.size(); ++i)
+ {
+ DLIB_TEST_MSG(*test2[*tmp] == *tmp_val,i);
+ DLIB_TEST(*test2[*tmp] == *tmp_val);
+ DLIB_TEST(*test2[*tmp] == *tmp_val);
+ DLIB_TEST(*const_cast<const bst&>(test2)[*tmp] == *tmp_val);
+ ++tmp;
+ ++tmp_val;
+ }
+
+ // out << "\nsize: " << test2.size() << endl;
+ // out << "height: " << test2.height() << endl;
+
+ tmp = array;
+ while (test2.size() > 0)
+ {
+ unsigned long count = test2.count(*tmp);
+ test2.destroy(*tmp);
+ DLIB_TEST(test2.count(*tmp)+1 == count);
+ ++tmp;
+ }
+
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ test.swap(test2);
+ test.reset();
+
+ delete [] array;
+ delete [] array_val;
+ }
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.height() == 0);
+
+ for (unsigned long i = 1; i < 100; ++i)
+ {
+ a = 1234;
+ test.add(a,b);
+ DLIB_TEST(test.count(1234) == i);
+ }
+
+ test.clear();
+
+ }
+
+
+
+ a = 1;
+ b = 2;
+
+ test.add(a,b);
+
+ test.position_enumerator(0);
+ a = 0;
+ b = 0;
+ DLIB_TEST(test.height() == 1);
+ test.remove_current_element(a,b);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.height() == 0);
+ DLIB_TEST(test.size() == 0);
+
+
+ a = 1;
+ b = 2;
+ test.add(a,b);
+ a = 1;
+ b = 2;
+ test.add(a,b);
+
+ test.position_enumerator(0);
+ a = 0;
+ b = 0;
+ DLIB_TEST(test.height() == 2);
+ test.remove_current_element(a,b);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ DLIB_TEST(test.height() == 1);
+ DLIB_TEST(test.size() == 1);
+
+ test.remove_current_element(a,b);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.height() == 0);
+ DLIB_TEST(test.size() == 0);
+
+ for (int i = 0; i < 100; ++i)
+ {
+ a = i;
+ b = i;
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.size() == 100);
+ test.remove_last_in_order(a,b);
+ DLIB_TEST(a == 99);
+ DLIB_TEST(b == 99);
+ DLIB_TEST(test.size() == 99);
+ test.remove_last_in_order(a,b);
+ DLIB_TEST(a == 98);
+ DLIB_TEST(b == 98);
+ DLIB_TEST(test.size() == 98);
+
+ test.position_enumerator(-10);
+ for (int i = 0; i < 97; ++i)
+ {
+ DLIB_TEST(test.element().key() == i);
+ DLIB_TEST(test.element().value() == i);
+ DLIB_TEST(test.move_next());
+ }
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ test.position_enumerator(10);
+ for (int i = 10; i < 97; ++i)
+ {
+ DLIB_TEST(test.element().key() == i);
+ DLIB_TEST(test.element().value() == i);
+ DLIB_TEST(test.move_next());
+ }
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ test.reset();
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.current_element_valid() == false);
+ for (int i = 0; i < 98; ++i)
+ {
+ DLIB_TEST(test.move_next());
+ DLIB_TEST(test.element().key() == i);
+ DLIB_TEST(test.element().value() == i);
+ }
+ DLIB_TEST_MSG(test.size() == 98, test.size());
+ DLIB_TEST(test.move_next() == false);
+
+ test.position_enumerator(98);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+
+
+ test.position_enumerator(50);
+ DLIB_TEST(test.element().key() == 50);
+ DLIB_TEST(test.element().value() == 50);
+ DLIB_TEST(test[50] != 0);
+ test.remove_current_element(a,b);
+ DLIB_TEST(test[50] == 0);
+ DLIB_TEST_MSG(test.size() == 97, test.size());
+ DLIB_TEST(a == 50);
+ DLIB_TEST(b == 50);
+ DLIB_TEST(test.element().key() == 51);
+ DLIB_TEST(test.element().value() == 51);
+ DLIB_TEST(test.current_element_valid());
+ test.remove_current_element(a,b);
+ DLIB_TEST_MSG(test.size() == 96, test.size());
+ DLIB_TEST(a == 51);
+ DLIB_TEST(b == 51);
+ DLIB_TEST_MSG(test.element().key() == 52,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 52,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+ test.remove_current_element(a,b);
+ DLIB_TEST_MSG(test.size() == 95, test.size());
+ DLIB_TEST(a == 52);
+ DLIB_TEST(b == 52);
+ DLIB_TEST_MSG(test.element().key() == 53,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 53,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+ test.position_enumerator(50);
+ DLIB_TEST_MSG(test.element().key() == 53,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 53,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+ test.position_enumerator(51);
+ DLIB_TEST_MSG(test.element().key() == 53,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 53,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+ test.position_enumerator(52);
+ DLIB_TEST_MSG(test.element().key() == 53,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 53,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+ test.position_enumerator(53);
+ DLIB_TEST_MSG(test.element().key() == 53,test.element().key());
+ DLIB_TEST_MSG(test.element().value() == 53,test.element().value());
+ DLIB_TEST(test.current_element_valid());
+
+ test.reset();
+ test.move_next();
+ int lasta = -1, lastb = -1;
+ count = 0;
+ while (test.current_element_valid() )
+ {
+ ++count;
+ int c = test.element().key();
+ int d = test.element().value();
+ test.remove_current_element(a,b);
+ DLIB_TEST(c == a);
+ DLIB_TEST(d == a);
+ DLIB_TEST(lasta < a);
+ DLIB_TEST(lastb < b);
+ lasta = a;
+ lastb = b;
+ }
+ DLIB_TEST_MSG(count == 95, count);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.height() == 0);
+
+ test.clear();
+
+ for (int i = 0; i < 1000; ++i)
+ {
+ a = 1;
+ b = 1;
+ test.add(a,b);
+ }
+
+ for (int i = 0; i < 40; ++i)
+ {
+ int num = ::rand()%800 + 1;
+ test.reset();
+ for (int j = 0; j < num; ++j)
+ {
+ DLIB_TEST(test.move_next());
+ }
+ DLIB_TEST_MSG(test.current_element_valid(),"size: " << test.size() << " num: " << num);
+ test.remove_current_element(a,b);
+ DLIB_TEST_MSG(test.current_element_valid(),"size: " << test.size() << " num: " << num);
+ test.remove_current_element(a,b);
+ test.position_enumerator(1);
+ if (test.current_element_valid())
+ test.remove_current_element(a,b);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 1);
+ }
+
+ test.clear();
+
+ }
+
+
+ test.clear();
+ test2.clear();
+
+ }
+
+}
+
+#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_
+
diff --git a/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp b/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp
new file mode 100644
index 000000000..b7e0b3a1a
--- /dev/null
+++ b/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp
@@ -0,0 +1,47 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/memory_manager_global.h>
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/binary_search_tree.h>
+#include "tester.h"
+#include "binary_search_tree.h"
+
+namespace
+{
+
+
+ class binary_search_tree_tester : public tester
+ {
+
+ public:
+ binary_search_tree_tester (
+ ) :
+ tester ("test_binary_search_tree_kernel_1a",
+ "Runs tests on the binary_search_tree_kernel_1a component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ binary_search_tree_kernel_test<binary_search_tree<int,int>::kernel_1a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_1a_c";
+ binary_search_tree_kernel_test<binary_search_tree<int,int>::kernel_1a_c>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+#endif
diff --git a/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp b/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp
new file mode 100644
index 000000000..e2be4b143
--- /dev/null
+++ b/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp
@@ -0,0 +1,45 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/memory_manager_global.h>
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/binary_search_tree.h>
+#include "tester.h"
+#include "binary_search_tree.h"
+
+namespace
+{
+
+ class binary_search_tree_tester : public tester
+ {
+ public:
+ binary_search_tree_tester (
+ ) :
+ tester ("test_binary_search_tree_kernel_2a",
+ "Runs tests on the binary_search_tree_kernel_2a component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_2a";
+ binary_search_tree_kernel_test<binary_search_tree<int,int>::kernel_2a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_2a_c";
+ binary_search_tree_kernel_test<binary_search_tree<int,int>::kernel_2a_c>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+#endif
diff --git a/ml/dlib/dlib/test/binary_search_tree_mm1.cpp b/ml/dlib/dlib/test/binary_search_tree_mm1.cpp
new file mode 100644
index 000000000..a9693bd15
--- /dev/null
+++ b/ml/dlib/dlib/test/binary_search_tree_mm1.cpp
@@ -0,0 +1,66 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/memory_manager_global.h>
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/binary_search_tree.h>
+#include "tester.h"
+#include "binary_search_tree.h"
+
+namespace
+{
+
+ class binary_search_tree_tester : public tester
+ {
+ struct factory
+ {
+ template <typename U>
+ struct return_type {
+ typedef typename memory_manager<U>::kernel_1c type;
+ };
+
+ template <typename U>
+ static typename return_type<U>::type* get_instance (
+ )
+ {
+ static typename return_type<U>::type instance;
+ return &instance;
+ }
+
+ };
+
+
+ public:
+ binary_search_tree_tester (
+ ) :
+ tester ("test_binary_search_tree_mm1",
+ "Runs tests on the binary_search_tree component with memory managers.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a /w memory_manager_global";
+ binary_search_tree_kernel_test<binary_search_tree<int,int,
+ memory_manager_global<char,factory>::kernel_1a>::kernel_1a>();
+ print_spinner();
+
+
+ dlog << LINFO << "testing kernel_1a /w memory_manager_stateless";
+ binary_search_tree_kernel_test<binary_search_tree<int,int,
+ memory_manager_stateless<char>::kernel_1a>::kernel_1a>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+#endif
diff --git a/ml/dlib/dlib/test/binary_search_tree_mm2.cpp b/ml/dlib/dlib/test/binary_search_tree_mm2.cpp
new file mode 100644
index 000000000..354b1f91c
--- /dev/null
+++ b/ml/dlib/dlib/test/binary_search_tree_mm2.cpp
@@ -0,0 +1,48 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/memory_manager_global.h>
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/binary_search_tree.h>
+#include "tester.h"
+#include "binary_search_tree.h"
+
+namespace
+{
+
+ class binary_search_tree_tester : public tester
+ {
+
+ public:
+ binary_search_tree_tester (
+ ) :
+ tester ("test_binary_search_tree_mm2",
+ "Runs tests on the binary_search_tree component with memory managers.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a /w memory_manager_stateless_2";
+ binary_search_tree_kernel_test<binary_search_tree<int,int,
+ memory_manager_stateless<char>::kernel_2_2c>::kernel_1a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_1a /w memory_manager_3";
+ binary_search_tree_kernel_test<binary_search_tree<int,int,
+ memory_manager<char>::kernel_3b>::kernel_1a>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+#endif
diff --git a/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt b/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt
new file mode 100644
index 000000000..5deddee04
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt
@@ -0,0 +1,33 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+
+cmake_minimum_required(VERSION 2.8.12)
+
+# This variable contains a list of all the tests we are building
+# into the regression test suite.
+set (tests
+ blas_bindings_gemm.cpp
+ blas_bindings_gemv.cpp
+ blas_bindings_ger.cpp
+ blas_bindings_dot.cpp
+ blas_bindings_scal_axpy.cpp
+ vector.cpp
+ )
+
+# create a variable called target_name and set it to the string "test"
+set (target_name dtest)
+
+PROJECT(${target_name})
+
+# add all the cpp files we want to compile to this list. This tells
+# cmake that they are part of our target (which is the executable named test)
+ADD_EXECUTABLE(${target_name} ../main.cpp ../tester.cpp ${tests})
+
+ADD_DEFINITIONS(-DDLIB_TEST_BLAS_BINDINGS)
+
+# Tell cmake to link our target executable to dlib
+include(../../cmake)
+TARGET_LINK_LIBRARIES(${target_name} dlib )
+
diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp
new file mode 100644
index 000000000..0571b0685
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp
@@ -0,0 +1,314 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ // This is a little screwy. This function is used inside the BLAS
+ // bindings to count how many times each of the BLAS functions get called.
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_dot() { static int counter = 0; return counter; }
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.dot");
+
+
+ class blas_bindings_dot_tester : public tester
+ {
+ public:
+ blas_bindings_dot_tester (
+ ) :
+ tester (
+ "test_dot", // the command line argument name for this test
+ "Run tests for DOT routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ void test_mat_bindings()
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+ matrix<double,1,0> rv(10);
+ matrix<double,0,1> cv(10);
+ double val;
+
+ rv = 1; cv = 1;
+ counter_dot() = 0;
+ val = rv*cv;
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+ rv = 1; cv = 1;
+ counter_dot() = 0;
+ val = rv*mat(&cv(0),cv.size());
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+ rv = 1; cv = 1;
+ counter_dot() = 0;
+ val = trans(mat(&rv(0),rv.size()))*mat(&cv(0),cv.size());
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+ std::vector<double> sv(10,1);
+ rv = 1;
+ counter_dot() = 0;
+ val = trans(mat(&rv(0),rv.size()))*mat(sv);
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+
+ counter_dot() = 0;
+ val = trans(mat(sv))*mat(sv);
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+ std_vector_c<double> svc(10,1);
+ counter_dot() = 0;
+ val = trans(mat(svc))*mat(svc);
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+
+
+ dlib::array<double> arr(10);
+ for (unsigned int i = 0; i < arr.size(); ++i)
+ arr[i] = 1;
+ counter_dot() = 0;
+ val = trans(mat(arr))*mat(arr);
+ DLIB_TEST(val == 10);
+ DLIB_TEST(counter_dot() == 1);
+ }
+
+ template <typename matrix_type, typename cv_type, typename rv_type>
+ void test_dot_stuff(
+ matrix_type& m,
+ rv_type& rv,
+ cv_type& cv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ rv_type rv2;
+ cv_type cv2;
+ matrix_type m2;
+ typedef typename matrix_type::type scalar_type;
+ scalar_type val;
+
+ counter_dot() = 0;
+ m2 = rv*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = rv*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = rv*3*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = rv*trans(rv)*3;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = trans(rv*trans(rv)*3 + trans(cv)*cv);
+ DLIB_TEST(counter_dot() == 2);
+
+
+ counter_dot() = 0;
+ val = trans(cv)*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = trans(cv)*trans(rv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(rv,cv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(rv,colm(cv,0));
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(cv,cv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(colm(cv,0,cv.size()),colm(cv,0));
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(rv,rv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(rv,trans(rv));
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(trans(cv),cv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = dot(trans(cv),trans(rv));
+ DLIB_TEST(counter_dot() == 1);
+
+
+ // This does one dot and one gemv
+ counter_dot() = 0;
+ val = trans(cv)*m*trans(rv);
+ DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
+
+ // This does one dot and two gemv
+ counter_dot() = 0;
+ val = (trans(cv)*m)*(m*trans(rv));
+ DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
+
+ // This does one dot and two gemv
+ counter_dot() = 0;
+ val = trans(cv)*m*trans(m)*trans(rv);
+ DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
+ }
+
+
+ template <typename matrix_type, typename cv_type, typename rv_type>
+ void test_dot_stuff_conj(
+ matrix_type& ,
+ rv_type& rv,
+ cv_type& cv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ rv_type rv2;
+ cv_type cv2;
+ matrix_type m2;
+ typedef typename matrix_type::type scalar_type;
+ scalar_type val;
+
+ counter_dot() = 0;
+ val = conj(rv)*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = trans(conj(cv))*cv;
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = trans(conj(cv))*trans(rv);
+ DLIB_TEST(counter_dot() == 1);
+
+ counter_dot() = 0;
+ val = trans(conj(cv))*3*trans(rv);
+ DLIB_TEST(counter_dot() == 1);
+ }
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+ typedef dlib::memory_manager<char>::kernel_1a mm;
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double> m = randm(4,4);
+ matrix<double,1,0> rv = randm(1,4);
+ matrix<double,0,1> cv = randm(4,1);
+ test_dot_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
+ test_dot_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_dot_stuff(m,rv,cv);
+ test_dot_stuff_conj(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_dot_stuff(m,rv,cv);
+ test_dot_stuff_conj(m,rv,cv);
+ }
+
+
+ dlog << dlib::LINFO << "test double, column major";
+ {
+ matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
+ matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
+ matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
+ test_dot_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float, column major";
+ {
+ matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
+ test_dot_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>, column major";
+ {
+ matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_dot_stuff(m,rv,cv);
+ test_dot_stuff_conj(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>, column major";
+ {
+ matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_dot_stuff(m,rv,cv);
+ test_dot_stuff_conj(m,rv,cv);
+ }
+
+
+ test_mat_bindings();
+
+ print_spinner();
+ }
+ };
+
+ blas_bindings_dot_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp
new file mode 100644
index 000000000..83d41edd1
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp
@@ -0,0 +1,311 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ // This is a little screwy. This function is used inside the BLAS
+ // bindings to count how many times each of the BLAS functions get called.
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_gemm() { static int counter = 0; return counter; }
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.gemm");
+
+
+ class blas_bindings_gemm_tester : public tester
+ {
+ public:
+ blas_bindings_gemm_tester (
+ ) :
+ tester (
+ "test_gemm", // the command line argument name for this test
+ "Run tests for GEMM routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ template <typename matrix_type>
+ void test_gemm_stuff(
+ const matrix_type& c
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ matrix_type b, a;
+ a = c;
+
+ counter_gemm() = 0;
+ b = a*a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a/2*a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a*trans(a) + a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = (a+a)*(a+a);
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a*(a-a);
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = trans(a)*trans(a) + a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = trans(trans(trans(a)*a + a));
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a*a*a*a;
+ DLIB_TEST(counter_gemm() == 3);
+ b = c;
+
+ counter_gemm() = 0;
+ a = a*a*a*a;
+ DLIB_TEST(counter_gemm() == 3);
+ a = c;
+
+ counter_gemm() = 0;
+ a = (b + a*trans(a)*a*3*a)*trans(b);
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b));
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b));
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b));
+ DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b));
+ DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
+ a = c;
+ }
+
+ template <typename matrix_type>
+ void test_gemm_stuff_conj(
+ const matrix_type& c
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ matrix_type b, a;
+ a = c;
+
+ counter_gemm() = 0;
+ b = a*conj(a);
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a*trans(conj(a)) + a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = conj(trans(a))*trans(a) + a;
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = trans(trans(trans(a)*conj(a) + conj(a)));
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ b = a*a*conj(a)*a;
+ DLIB_TEST(counter_gemm() == 3);
+ b = c;
+
+ counter_gemm() = 0;
+ a = a*trans(conj(a))*a*a;
+ DLIB_TEST(counter_gemm() == 3);
+ a = c;
+
+ counter_gemm() = 0;
+ a = (b + a*trans(conj(a))*a*3*a)*trans(b);
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b)));
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b)));
+ DLIB_TEST(counter_gemm() == 4);
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b));
+ DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
+ a = c;
+
+ counter_gemm() = 0;
+ a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b));
+ DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
+ a = c;
+ }
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+ typedef dlib::memory_manager<char>::kernel_1a mm;
+
+ print_spinner();
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double> a = randm(4,4);
+ test_gemm_stuff(a);
+ }
+
+ print_spinner();
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float> a = matrix_cast<float>(randm(4,4));
+ test_gemm_stuff(a);
+ }
+
+ print_spinner();
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<float> a = matrix_cast<float>(randm(4,4));
+ matrix<float> b = matrix_cast<float>(randm(4,4));
+ matrix<complex<float> > c = complex_matrix(a,b);
+ test_gemm_stuff(c);
+ test_gemm_stuff_conj(c);
+ }
+
+ print_spinner();
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<double> a = matrix_cast<double>(randm(4,4));
+ matrix<double> b = matrix_cast<double>(randm(4,4));
+ matrix<complex<double> > c = complex_matrix(a,b);
+ test_gemm_stuff(c);
+ test_gemm_stuff_conj(c);
+ }
+
+
+ print_spinner();
+
+ dlog << dlib::LINFO << "test double, column major";
+ {
+ matrix<double,100,100,mm,column_major_layout> a = randm(100,100);
+ test_gemm_stuff(a);
+ }
+
+ print_spinner();
+ dlog << dlib::LINFO << "test float, column major";
+ {
+ matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100));
+ test_gemm_stuff(a);
+ }
+
+ print_spinner();
+ dlog << dlib::LINFO << "test complex<double>, column major";
+ {
+ matrix<double,100,100,mm,column_major_layout> a = matrix_cast<double>(randm(100,100));
+ matrix<double,100,100,mm,column_major_layout> b = matrix_cast<double>(randm(100,100));
+ matrix<complex<double>,100,100,mm,column_major_layout > c = complex_matrix(a,b);
+ test_gemm_stuff(c);
+ test_gemm_stuff_conj(c);
+ }
+
+ print_spinner();
+
+ dlog << dlib::LINFO << "test complex<float>, column major";
+ {
+ matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100));
+ matrix<float,100,100,mm,column_major_layout> b = matrix_cast<float>(randm(100,100));
+ matrix<complex<float>,100,100,mm,column_major_layout > c = complex_matrix(a,b);
+ test_gemm_stuff(c);
+ test_gemm_stuff_conj(c);
+ }
+
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+ array2d<double> a(100,100);
+ array2d<double> b(100,100);
+ matrix<double> c;
+
+ counter_gemm() = 0;
+ c = mat(a)*mat(b);
+ DLIB_TEST(counter_gemm() == 1);
+
+ counter_gemm() = 0;
+ c = trans(2*mat(a)*mat(b));
+ DLIB_TEST(counter_gemm() == 1);
+ }
+
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+ array2d<double> a(100,100);
+ array2d<double> b(100,100);
+ matrix<double> aa(100,100);
+ matrix<double> bb(100,100);
+ matrix<double> c;
+
+ counter_gemm() = 0;
+ c = mat(&a[0][0],100,100)*mat(&b[0][0],100,100);
+ DLIB_TEST(counter_gemm() == 1);
+ set_ptrm(&c(0,0),100,100) = mat(&a[0][0],100,100)*mat(&b[0][0],100,100);
+ DLIB_TEST(counter_gemm() == 2);
+ set_ptrm(&c(0,0),100,100) = aa*bb;
+ DLIB_TEST(counter_gemm() == 3);
+
+ counter_gemm() = 0;
+ c = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100));
+ DLIB_TEST(counter_gemm() == 1);
+ set_ptrm(&c(0,0),100,100) = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100));
+ DLIB_TEST(counter_gemm() == 2);
+ set_ptrm(&c(0,0),100,100) = trans(2*mat(a)*mat(b));
+ DLIB_TEST(counter_gemm() == 3);
+ }
+
+ print_spinner();
+ }
+ };
+
+ blas_bindings_gemm_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp
new file mode 100644
index 000000000..322438313
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp
@@ -0,0 +1,226 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ // This is a little screwy. This function is used inside the BLAS
+ // bindings to count how many times each of the BLAS functions get called.
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_gemv() { static int counter = 0; return counter; }
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.gemv");
+
+
+ class blas_bindings_gemv_tester : public tester
+ {
+ public:
+ blas_bindings_gemv_tester (
+ ) :
+ tester (
+ "test_gemv", // the command line argument name for this test
+ "Run tests for GEMV routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ template <typename matrix_type, typename rv_type, typename cv_type>
+ void test_gemv_stuff(
+ matrix_type& m,
+ cv_type& cv,
+ rv_type& rv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ cv_type cv2;
+ rv_type rv2;
+ typedef typename matrix_type::type scalar_type;
+ scalar_type val;
+
+ counter_gemv() = 0;
+ cv2 = m*cv;
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ cv2 = m*2*cv;
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ cv2 = m*2*trans(rv);
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ rv2 = trans(m*2*cv);
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ rv2 = rv*m;
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ rv2 = (rv + rv)*m;
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ rv2 = trans(cv)*m;
+ DLIB_TEST(counter_gemv() == 1);
+ dlog << dlib::LTRACE << 1;
+
+ counter_gemv() = 0;
+ rv2 = trans(cv)*trans(m) + rv*trans(m);
+ DLIB_TEST(counter_gemv() == 2);
+ dlog << dlib::LTRACE << 2;
+
+ counter_gemv() = 0;
+ cv2 = m*trans(trans(cv)*trans(m) + 3*rv*trans(m));
+ DLIB_TEST(counter_gemv() == 3);
+
+ // This does one dot and one gemv
+ counter_gemv() = 0;
+ val = trans(cv)*m*trans(rv);
+ DLIB_TEST_MSG(counter_gemv() == 1, counter_gemv());
+
+ // This does one dot and two gemv
+ counter_gemv() = 0;
+ val = (trans(cv)*m)*(m*trans(rv));
+ DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv());
+
+ // This does one dot and two gemv
+ counter_gemv() = 0;
+ val = trans(cv)*m*trans(m)*trans(rv);
+ DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv());
+ }
+
+
+ template <typename matrix_type, typename rv_type, typename cv_type>
+ void test_gemv_stuff_conj(
+ matrix_type& m,
+ cv_type& cv,
+ rv_type& rv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ cv_type cv2;
+ rv_type rv2;
+
+ counter_gemv() = 0;
+ cv2 = trans(cv)*conj(m);
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ cv2 = conj(trans(m))*rv;
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ cv2 = conj(trans(m))*trans(cv);
+ DLIB_TEST(counter_gemv() == 1);
+
+ counter_gemv() = 0;
+ cv2 = trans(trans(cv)*conj(2*m) + conj(3*trans(m))*rv + conj(trans(m)*3)*trans(cv));
+ DLIB_TEST(counter_gemv() == 3);
+
+ }
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+ typedef dlib::memory_manager<char>::kernel_1a mm;
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double> m = randm(4,4);
+ matrix<double,0,1> cv = randm(4,1);
+ matrix<double,1,0> rv = randm(1,4);
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float> m = matrix_cast<float>(randm(4,4));
+ matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
+ matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
+ matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
+ matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
+ matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
+ matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
+ matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
+ matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ test_gemv_stuff(m,cv,rv);
+ }
+
+
+ print_spinner();
+ }
+ };
+
+ blas_bindings_gemv_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp
new file mode 100644
index 000000000..2aac834d2
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp
@@ -0,0 +1,200 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ // This is a little screwy. This function is used inside the BLAS
+ // bindings to count how many times each of the BLAS functions get called.
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_ger() { static int counter = 0; return counter; }
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.ger");
+
+
+ class blas_bindings_ger_tester : public tester
+ {
+ public:
+ blas_bindings_ger_tester (
+ ) :
+ tester (
+ "test_ger", // the command line argument name for this test
+ "Run tests for GER routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ template <typename matrix_type, typename cv_type, typename rv_type>
+ void test_ger_stuff(
+ matrix_type& m,
+ rv_type& rv,
+ cv_type& cv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ rv_type rv2;
+ cv_type cv2;
+ matrix_type m2;
+
+ counter_ger() = 0;
+ m2 = m + cv*rv;
+ DLIB_TEST_MSG(counter_ger() == 1, counter_ger());
+
+ counter_ger() = 0;
+ m += trans(rv)*rv;
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += trans(rv)*trans(cv);
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += cv*trans(cv);
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += trans(rv)*rv + trans(cv*3*rv);
+ DLIB_TEST(counter_ger() == 2);
+ }
+
+
+ template <typename matrix_type, typename cv_type, typename rv_type>
+ void test_ger_stuff_conj(
+ matrix_type& m,
+ rv_type& rv,
+ cv_type& cv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ rv_type rv2;
+ cv_type cv2;
+ matrix_type m2;
+
+ counter_ger() = 0;
+ m += cv*conj(rv);
+ DLIB_TEST_MSG(counter_ger() == 1, counter_ger());
+
+ counter_ger() = 0;
+ m += trans(rv)*conj(rv);
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += trans(rv)*conj(trans(cv));
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += cv*trans(conj(cv));
+ DLIB_TEST(counter_ger() == 1);
+
+ counter_ger() = 0;
+ m += trans(rv)*rv + trans(cv*3*conj(rv));
+ DLIB_TEST(counter_ger() == 2);
+ }
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+ typedef dlib::memory_manager<char>::kernel_1a mm;
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double> m = randm(4,4);
+ matrix<double,1,0> rv = randm(1,4);
+ matrix<double,0,1> cv = randm(4,1);
+ test_ger_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
+ test_ger_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_ger_stuff(m,rv,cv);
+ test_ger_stuff_conj(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_ger_stuff(m,rv,cv);
+ test_ger_stuff_conj(m,rv,cv);
+ }
+
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
+ matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
+ matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
+ test_ger_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
+ test_ger_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_ger_stuff(m,rv,cv);
+ test_ger_stuff_conj(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_ger_stuff(m,rv,cv);
+ test_ger_stuff_conj(m,rv,cv);
+ }
+
+
+ print_spinner();
+ }
+ };
+
+ blas_bindings_ger_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp
new file mode 100644
index 000000000..d1a7b99e4
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp
@@ -0,0 +1,261 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+ // This is a little screwy. This function is used inside the BLAS
+ // bindings to count how many times each of the BLAS functions get called.
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ int& counter_axpy() { static int counter = 0; return counter; }
+ int& counter_scal() { static int counter = 0; return counter; }
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.scal_axpy");
+
+
+ class blas_bindings_scal_axpy_tester : public tester
+ {
+ public:
+ blas_bindings_scal_axpy_tester (
+ ) :
+ tester (
+ "test_scal_axpy", // the command line argument name for this test
+ "Run tests for DOT routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ template <typename matrix_type, typename cv_type, typename rv_type>
+ void test_scal_axpy_stuff(
+ matrix_type& m,
+ rv_type& rv,
+ cv_type& cv
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ rv_type rv2 = rv;
+ cv_type cv2 = cv;
+ matrix_type m2 = m;
+ typedef typename matrix_type::type scalar_type;
+ scalar_type val;
+
+ counter_scal() = 0;
+ m = 5*m;
+ DLIB_TEST(counter_scal() == 1);
+
+ counter_scal() = 0;
+ rv = 5*rv;
+ DLIB_TEST(counter_scal() == 1);
+
+ counter_scal() = 0;
+ rv = 5*rv;
+ DLIB_TEST(counter_scal() == 1);
+
+
+ counter_axpy() = 0;
+ m2 += 5*m;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ rv2 += 5*rv;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ rv2 += 5*rv;
+ DLIB_TEST(counter_axpy() == 1);
+
+
+
+ counter_scal() = 0;
+ m = m*5;
+ DLIB_TEST(counter_scal() == 1);
+
+ counter_scal() = 0;
+ rv = rv*5;
+ DLIB_TEST(counter_scal() == 1);
+
+ counter_scal() = 0;
+ cv = cv*5;
+ DLIB_TEST(counter_scal() == 1);
+
+
+ counter_axpy() = 0;
+ m2 += m*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ rv2 += rv*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ cv2 += cv*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+
+
+
+ counter_axpy() = 0;
+ m2 = m2 + m*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ rv2 = rv2 + rv*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ cv2 = cv2 + cv*5;
+ DLIB_TEST(counter_axpy() == 1);
+
+
+ counter_axpy() = 0;
+ cv2 = 1;
+ cv = 1;
+ cv2 = 2*cv2 + cv*5;
+ DLIB_TEST(counter_axpy() == 1);
+ DLIB_TEST(max(abs(cv2 - 7)) == 0);
+
+
+ counter_axpy() = 0;
+ rv2 = 1;
+ rv = 1;
+ rv2 = 2*rv2 + rv*5;
+ DLIB_TEST(counter_axpy() == 1);
+ DLIB_TEST(max(abs(rv2 - 7)) == 0);
+
+ counter_axpy() = 0;
+ m2 = 1;
+ m = 1;
+ m2 = 2*m2 + m*5;
+ DLIB_TEST(counter_axpy() == 1);
+ DLIB_TEST(max(abs(m2 - 7)) == 0);
+
+
+ if (is_same_type<typename matrix_type::layout_type, row_major_layout>::value)
+ {
+ counter_axpy() = 0;
+ m2 = 1;
+ m = 1;
+ set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*m2 + m*5;
+ DLIB_TEST(max(abs(m2 - 7)) == 0);
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ m2 = 1;
+ m = 1;
+ set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5;
+ DLIB_TEST(max(abs(m2 - 7)) == 0);
+ DLIB_TEST(counter_axpy() == 1);
+
+ counter_axpy() = 0;
+ m2 = 1;
+ m = 1;
+ m2 = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5;
+ DLIB_TEST(max(abs(m2 - 7)) == 0);
+ DLIB_TEST(counter_axpy() == 1);
+ }
+
+ }
+
+
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+ typedef dlib::memory_manager<char>::kernel_1a mm;
+
+ dlog << dlib::LINFO << "test double";
+ {
+ matrix<double> m = randm(4,4);
+ matrix<double,1,0> rv = randm(1,4);
+ matrix<double,0,1> cv = randm(4,1);
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float";
+ {
+ matrix<float> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>";
+ {
+ matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>";
+ {
+ matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+
+ dlog << dlib::LINFO << "test double, column major";
+ {
+ matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
+ matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
+ matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test float, column major";
+ {
+ matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
+ matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
+ matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<double>, column major";
+ {
+ matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
+ matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
+ matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+ dlog << dlib::LINFO << "test complex<float>, column major";
+ {
+ matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
+ matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
+ matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
+ test_scal_axpy_stuff(m,rv,cv);
+ }
+
+
+ print_spinner();
+ }
+ };
+
+ blas_bindings_scal_axpy_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/blas_bindings/vector.cpp b/ml/dlib/dlib/test/blas_bindings/vector.cpp
new file mode 100644
index 000000000..0a6f5f301
--- /dev/null
+++ b/ml/dlib/dlib/test/blas_bindings/vector.cpp
@@ -0,0 +1,115 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "../tester.h"
+#include <dlib/geometry.h>
+#include <dlib/matrix.h>
+
+#ifndef DLIB_USE_BLAS
+#error "BLAS bindings must be used for this test to make any sense"
+#endif
+
+namespace dlib
+{
+ namespace blas_bindings
+ {
+
+#ifdef DLIB_TEST_BLAS_BINDINGS
+ extern int& counter_gemm();
+ extern int& counter_gemv();
+ extern int& counter_ger();
+ extern int& counter_dot();
+#endif
+
+ }
+}
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.vector");
+
+
+ class vector_tester : public tester
+ {
+ public:
+ vector_tester (
+ ) :
+ tester (
+ "test_vector", // the command line argument name for this test
+ "Run tests on dlib::vector.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ template <typename type>
+ void test_vector(
+ ) const
+ {
+ using namespace dlib;
+ using namespace dlib::blas_bindings;
+
+ dlib::vector<type,2> a2, b2, c2;
+ dlib::vector<type,3> a3, b3, c3;
+
+ matrix<type> mat2(2,2), mat3(3,3);
+ mat2 = 0;
+ mat3 = 0;
+
+ type var = 0;
+
+ // We want to make sure that the BLAS bindings are being called for the 2D and 3D vectors. That would
+ // be very slow.
+ counter_gemm() = 0;
+ counter_gemv() = 0;
+ counter_ger() = 0;
+ counter_dot() = 0;
+
+ var = trans(a2)*(a2);
+ var = dot(a2,a2);
+
+ a2 = mat2*b2;
+ var = trans(b2)*mat2*b2;
+
+ var = trans(a3)*(a3);
+ var = dot(a3,a3);
+
+ a3 = mat3*b3;
+ var = trans(b3)*mat3*b3;
+
+ mat3 = c3*trans(a3);
+ mat2 = c2*trans(a2);
+
+ DLIB_TEST(counter_gemm() == 0 && counter_gemv() == 0 && counter_ger() == 0 && counter_dot() == 0);
+
+ }
+
+ void perform_test (
+ )
+ {
+ using namespace dlib;
+
+ dlog << dlib::LINFO << "test double";
+ test_vector<double>();
+
+ dlog << dlib::LINFO << "test float";
+ test_vector<float>();
+
+ dlog << dlib::LINFO << "test int";
+ test_vector<int>();
+
+ dlog << dlib::LINFO << "test short";
+ test_vector<short>();
+
+ print_spinner();
+ }
+ };
+
+ vector_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/bridge.cpp b/ml/dlib/dlib/test/bridge.cpp
new file mode 100644
index 000000000..36d270c74
--- /dev/null
+++ b/ml/dlib/dlib/test/bridge.cpp
@@ -0,0 +1,259 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/bridge.h>
+#include <dlib/type_safe_union.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.bridge");
+
+ const unsigned short testing_port = 41238;
+
+ void do_test1()
+ {
+ dlib::pipe<int> in(0), out(0);
+
+ bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in));
+ bridge b2(listen_on_port(testing_port), transmit(out));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int val = i;
+ out.enqueue(val);
+ val = 0;
+ in.dequeue(val);
+ DLIB_TEST(val == i);
+ }
+ }
+
+ void do_test2()
+ {
+ dlib::pipe<int> in(0), out(0), echo_pipe(0);
+
+ bridge b2(listen_on_port(testing_port), transmit(out), receive(in));
+ bridge echo(connect_to_ip_and_port("127.0.0.1",testing_port), receive(echo_pipe), transmit(echo_pipe));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int val = i;
+ out.enqueue(val);
+ val = 0;
+ in.dequeue(val);
+ DLIB_TEST(val == i);
+ }
+ }
+
+ void do_test3()
+ {
+ dlib::pipe<int> in(10), out(10), echo_pipe(10);
+
+ bridge b2(listen_on_port(testing_port), transmit(out), receive(in));
+ bridge echo(connect_to_ip_and_port("127.0.0.1",testing_port), receive(echo_pipe), transmit(echo_pipe));
+
+ b2.reconfigure(listen_on_port(testing_port), transmit(out), receive(in));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int val = i;
+ out.enqueue(val);
+ val = 0;
+ in.dequeue(val);
+ DLIB_TEST(val == i);
+ }
+ }
+
+ void do_test4()
+ {
+ dlib::pipe<int> in(0), out(0), echo_pipe(0);
+
+ bridge b2, echo;
+ b2.reconfigure(listen_on_port(testing_port), receive(in), transmit(out));
+ echo.reconfigure(connect_to_ip_and_port("127.0.0.1",testing_port), transmit(echo_pipe), receive(echo_pipe));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int val = i;
+ out.enqueue(val);
+ val = 0;
+ in.dequeue(val);
+ DLIB_TEST(val == i);
+ }
+ }
+
+ void do_test5(int pipe_size)
+ {
+ typedef type_safe_union<int, bridge_status> tsu_type;
+
+ dlib::pipe<tsu_type> out(pipe_size);
+ dlib::pipe<tsu_type> in(pipe_size);
+ dlib::pipe<bridge_status> out_status(pipe_size);
+
+ bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in));
+ tsu_type msg;
+
+ msg = b1.get_bridge_status();
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == false);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == 0);
+
+ {
+ bridge b2(listen_on_port(testing_port), transmit(out), receive(out_status));
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == true);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "127.0.0.1");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == testing_port);
+ msg = b1.get_bridge_status();
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == true);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "127.0.0.1");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == testing_port);
+
+ bridge_status temp;
+ out_status.dequeue(temp);
+ DLIB_TEST(temp.is_connected == true);
+ DLIB_TEST(temp.foreign_ip == "127.0.0.1");
+
+ for (int i = 0; i < 100; ++i)
+ {
+ msg = i;
+ out.enqueue(msg);
+
+ msg.get<int>() = 0;
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<int>() == true);
+ DLIB_TEST(msg.get<int>() == i);
+ }
+
+ }
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == false);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "127.0.0.1");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == testing_port);
+ }
+
+ void do_test5_5(int pipe_size)
+ {
+ typedef type_safe_union<int, bridge_status> tsu_type;
+
+ dlib::pipe<tsu_type> out(pipe_size);
+ dlib::pipe<tsu_type> in(pipe_size);
+ dlib::pipe<bridge_status> out_status(pipe_size);
+
+ bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in));
+ tsu_type msg;
+
+ bridge b2(listen_on_port(testing_port), transmit(out), receive(out_status));
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == true);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "127.0.0.1");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == testing_port);
+
+ bridge_status temp;
+ out_status.dequeue(temp);
+ DLIB_TEST(temp.is_connected == true);
+ DLIB_TEST(temp.foreign_ip == "127.0.0.1");
+
+ for (int i = 0; i < 100; ++i)
+ {
+ msg = i;
+ out.enqueue(msg);
+
+ msg.get<int>() = 0;
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<int>() == true);
+ DLIB_TEST(msg.get<int>() == i);
+ }
+
+ b2.clear();
+ msg = b2.get_bridge_status();
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == false);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == 0);
+
+ in.dequeue(msg);
+ DLIB_TEST(msg.contains<bridge_status>() == true);
+ DLIB_TEST(msg.get<bridge_status>().is_connected == false);
+ DLIB_TEST(msg.get<bridge_status>().foreign_ip == "127.0.0.1");
+ DLIB_TEST(msg.get<bridge_status>().foreign_port == testing_port);
+ }
+
+ void do_test6()
+ {
+ dlib::pipe<int> in(0), out(300);
+
+ bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in));
+ bridge b2(listen_on_port(testing_port), transmit(out));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int val = i;
+ out.enqueue(val);
+ }
+
+ int val = 10;
+ in.dequeue(val);
+ DLIB_TEST(val == 0);
+ dlib::sleep(100);
+ in.dequeue(val);
+ DLIB_TEST(val == 1);
+ dlib::sleep(100);
+ }
+
+ class test_bridge : public tester
+ {
+ public:
+ test_bridge (
+ ) :
+ tester ("test_bridge",
+ "Runs tests on the bridge component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing bridge, using local port number of " << testing_port;
+
+ print_spinner();
+ do_test1();
+ print_spinner();
+ do_test2();
+ print_spinner();
+ do_test3();
+ print_spinner();
+ do_test4();
+ print_spinner();
+ for (int i = 0; i < 5; ++i)
+ do_test5(i);
+ do_test5_5(1);
+ print_spinner();
+ do_test6();
+ }
+ } a;
+
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/bsp.cpp b/ml/dlib/dlib/test/bsp.cpp
new file mode 100644
index 000000000..04367c1d3
--- /dev/null
+++ b/ml/dlib/dlib/test/bsp.cpp
@@ -0,0 +1,566 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/bsp.h>
+#include <dlib/threads.h>
+#include <dlib/pipe.h>
+#include <dlib/matrix.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.bsp");
+
+
+ template <typename funct>
+ struct callfunct_helper
+ {
+ callfunct_helper (
+ funct f_,
+ int port_,
+ bool& error_occurred_
+ ) :f(f_), port(port_), error_occurred(error_occurred_) {}
+
+ funct f;
+ int port;
+ bool& error_occurred;
+
+ void operator() (
+ ) const
+ {
+ try
+ {
+ bsp_listen(port, f);
+ }
+ catch (exception& e)
+ {
+ dlog << LERROR << "error calling bsp_listen(): " << e.what();
+ error_occurred = true;
+ }
+ }
+ };
+
+ template <typename funct>
+ callfunct_helper<funct> callfunct(funct f, int port, bool& error_occurred)
+ {
+ return callfunct_helper<funct>(f,port,error_occurred);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename funct>
+ struct callfunct_helper_pn
+ {
+ callfunct_helper_pn (
+ funct f_,
+ int port_,
+ bool& error_occurred_,
+ dlib::pipe<unsigned short>& port_pipe_
+ ) :f(f_), port(port_), error_occurred(error_occurred_), port_pipe(port_pipe_) {}
+
+ funct f;
+ int port;
+ bool& error_occurred;
+ dlib::pipe<unsigned short>& port_pipe;
+
+ struct helper
+ {
+ helper (
+ dlib::pipe<unsigned short>& port_pipe_
+ ) : port_pipe(port_pipe_) {}
+
+ dlib::pipe<unsigned short>& port_pipe;
+
+ void operator() (unsigned short p) { port_pipe.enqueue(p); }
+ };
+
+ void operator() (
+ ) const
+ {
+ try
+ {
+ bsp_listen_dynamic_port(port, helper(port_pipe), f);
+ }
+ catch (exception& e)
+ {
+ dlog << LERROR << "error calling bsp_listen_dynamic_port(): " << e.what();
+ error_occurred = true;
+ }
+ }
+ };
+
+ template <typename funct>
+ callfunct_helper_pn<funct> callfunct(funct f, int port, bool& error_occurred, dlib::pipe<unsigned short>& port_pipe)
+ {
+ return callfunct_helper_pn<funct>(f,port,error_occurred,port_pipe);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void sum_array_driver (
+ bsp_context& obj,
+ const std::vector<int>& v,
+ int& result
+ )
+ {
+ obj.broadcast(v);
+
+ result = 0;
+ int val;
+ while(obj.try_receive(val))
+ result += val;
+ }
+
+ void sum_array_other (
+ bsp_context& obj
+ )
+ {
+ std::vector<int> v;
+ obj.receive(v);
+
+ int sum = 0;
+ for (unsigned long i = 0; i < v.size(); ++i)
+ sum += v[i];
+
+ obj.send(sum, 0);
+
+
+ }
+
+
+ void dotest1()
+ {
+ dlog << LINFO << "start dotest1()";
+ print_spinner();
+ bool error_occurred = false;
+ {
+ thread_function t1(callfunct(sum_array_other, 12345, error_occurred));
+ thread_function t2(callfunct(sum_array_other, 12346, error_occurred));
+ thread_function t3(callfunct(sum_array_other, 12347, error_occurred));
+ std::vector<int> v;
+ int true_value = 0;
+ for (int i = 0; i < 10; ++i)
+ {
+ v.push_back(i);
+ true_value += i;
+ }
+
+ // wait a little bit for the threads to start up
+ dlib::sleep(200);
+
+ try
+ {
+ int result;
+ std::vector<network_address> hosts;
+ hosts.push_back("127.0.0.1:12345");
+ hosts.push_back("localhost:12346");
+ hosts.push_back("127.0.0.1:12347");
+ bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result));
+
+ dlog << LINFO << "result: "<< result;
+ dlog << LINFO << "should be: "<< 3*true_value;
+ DLIB_TEST(result == 3*true_value);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "error during bsp_context: " << e.what();
+ DLIB_TEST(false);
+ }
+ }
+ DLIB_TEST(error_occurred == false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned long id>
+ void test2_job(bsp_context& obj)
+ {
+ if (obj.node_id() == id)
+ dlib::sleep(100);
+ }
+
+ template <unsigned long id>
+ void dotest2()
+ {
+ dlog << LINFO << "start dotest2()";
+ print_spinner();
+ bool error_occurred = false;
+ {
+ thread_function t1(callfunct(test2_job<id>, 12345, error_occurred));
+ thread_function t2(callfunct(test2_job<id>, 12346, error_occurred));
+ thread_function t3(callfunct(test2_job<id>, 12347, error_occurred));
+
+ // wait a little bit for the threads to start up
+ dlib::sleep(200);
+
+ try
+ {
+ std::vector<network_address> hosts;
+ hosts.push_back("127.0.0.1:12345");
+ hosts.push_back("127.0.0.1:12346");
+ hosts.push_back("127.0.0.1:12347");
+ bsp_connect(hosts, test2_job<id>);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "error during bsp_context: " << e.what();
+ DLIB_TEST(false);
+ }
+
+ }
+ DLIB_TEST(error_occurred == false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test3_job_driver(bsp_context& obj, int& result)
+ {
+
+ obj.broadcast(obj.node_id());
+
+ int accum = 0;
+ int temp = 0;
+ while(obj.try_receive(temp))
+ accum += temp;
+
+ // send to node 1 so it can sum everything
+ if (obj.node_id() != 1)
+ obj.send(accum, 1);
+
+ while(obj.try_receive(temp))
+ accum += temp;
+
+ // Now hop the accum values along the nodes until the value from node 1 gets to
+ // node 0.
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+
+ // this whole block is a noop since it doesn't end up doing anything.
+ for (int k = 0; k < 100; ++k)
+ {
+ dlog << LINFO << "k: " << k;
+ for (int i = 0; i < 4; ++i)
+ {
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+ }
+ }
+
+
+ dlog << LINFO << "TERMINATE";
+ if (obj.node_id() == 0)
+ result = accum;
+ }
+
+
+ void test3_job(bsp_context& obj)
+ {
+ int junk;
+ test3_job_driver(obj, junk);
+ }
+
+
+ void dotest3()
+ {
+ dlog << LINFO << "start dotest3()";
+ print_spinner();
+ bool error_occurred = false;
+ {
+ dlib::pipe<unsigned short> ports(5);
+ thread_function t1(callfunct(test3_job, 12345, error_occurred, ports));
+ thread_function t2(callfunct(test3_job, 0, error_occurred, ports));
+ thread_function t3(callfunct(test3_job, 12347, error_occurred, ports));
+
+
+ try
+ {
+ std::vector<network_address> hosts;
+ unsigned short port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ int result = 0;
+ const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
+ bsp_connect(hosts, test3_job_driver, dlib::ref(result));
+
+ dlog << LINFO << "result: " << result;
+ dlog << LINFO << "should be: " << expected;
+ DLIB_TEST(result == expected);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "error during bsp_context: " << e.what();
+ DLIB_TEST(false);
+ }
+
+ }
+ DLIB_TEST(error_occurred == false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test4_job_driver(bsp_context& obj, int& result)
+ {
+
+ obj.broadcast(obj.node_id());
+
+ int accum = 0;
+ int temp = 0;
+ while(obj.try_receive(temp))
+ accum += temp;
+
+ // send to node 1 so it can sum everything
+ if (obj.node_id() != 1)
+ obj.send(accum, 1);
+
+ while(obj.try_receive(temp))
+ accum += temp;
+
+ // Now hop the accum values along the nodes until the value from node 1 gets to
+ // node 0.
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+
+ // this whole block is a noop since it doesn't end up doing anything.
+ for (int k = 0; k < 40; ++k)
+ {
+ dlog << LINFO << "k: " << k;
+ for (int i = 0; i < 4; ++i)
+ {
+ obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
+ obj.receive(accum);
+
+ obj.receive();
+ }
+ }
+
+
+ dlog << LINFO << "TERMINATE";
+ if (obj.node_id() == 0)
+ result = accum;
+ }
+
+
+ void test4_job(bsp_context& obj)
+ {
+ int junk;
+ test4_job_driver(obj, junk);
+ }
+
+
+ void dotest4()
+ {
+ dlog << LINFO << "start dotest4()";
+ print_spinner();
+ bool error_occurred = false;
+ {
+ dlib::pipe<unsigned short> ports(5);
+ thread_function t1(callfunct(test4_job, 0, error_occurred, ports));
+ thread_function t2(callfunct(test4_job, 0, error_occurred, ports));
+ thread_function t3(callfunct(test4_job, 0, error_occurred, ports));
+
+
+ try
+ {
+ std::vector<network_address> hosts;
+ unsigned short port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ int result = 0;
+ const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
+ bsp_connect(hosts, test4_job_driver, dlib::ref(result));
+
+ dlog << LINFO << "result: " << result;
+ dlog << LINFO << "should be: " << expected;
+ DLIB_TEST(result == expected);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "error during bsp_context: " << e.what();
+ DLIB_TEST(false);
+ }
+
+ }
+ DLIB_TEST(error_occurred == false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test5_job(
+ bsp_context& ,
+ int& val
+ )
+ {
+ val = 25;
+ }
+
+ void dotest5()
+ {
+ dlog << LINFO << "start dotest5()";
+ print_spinner();
+ std::vector<network_address> hosts;
+ int val = 0;
+ bsp_connect(hosts, test5_job, dlib::ref(val));
+ DLIB_TEST(val == 25);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double f ( double x)
+ {
+ return std::pow(x-2.0, 2.0);
+ }
+
+
+ void bsp_job_node_0 (
+ bsp_context& context,
+ double& min_value,
+ double& optimal_x
+ )
+ {
+ double left = -100;
+ double right = 100;
+
+ min_value = std::numeric_limits<double>::infinity();
+ double interval_width = std::abs(right-left);
+
+ // This is doing a BSP based grid search for the minimum of f(). Here we
+ // do 100 iterations where we keep shrinking the grid size.
+ for (int i = 0; i < 100; ++i)
+ {
+ context.broadcast(left);
+ context.broadcast(right);
+
+ for (unsigned int k = 1; k < context.number_of_nodes(); ++k)
+ {
+ std::pair<double,double> val;
+ context.receive(val);
+ if (val.second < min_value)
+ {
+ min_value = val.second;
+ optimal_x = val.first;
+ }
+ }
+
+ interval_width *= 0.5;
+ left = optimal_x - interval_width/2;
+ right = optimal_x + interval_width/2;
+ }
+ }
+
+
+ void bsp_job_other_nodes (
+ bsp_context& context
+ )
+ {
+ double left, right;
+ while (context.try_receive(left))
+ {
+ context.receive(right);
+
+ const double l = (context.node_id()-1)/(context.number_of_nodes()-1.0);
+ const double r = context.node_id() /(context.number_of_nodes()-1.0);
+
+ const double width = right-left;
+ matrix<double> values_to_check = linspace(left +l*width, left + r*width, 100);
+
+ double best_x = 0;
+ double best_val = std::numeric_limits<double>::infinity();
+ for (long j = 0; j < values_to_check.size(); ++j)
+ {
+ double temp = f(values_to_check(j));
+ if (temp < best_val)
+ {
+ best_val = temp;
+ best_x = values_to_check(j);
+ }
+ }
+
+ context.send(make_pair(best_x, best_val), 0);
+ }
+ }
+
+ void dotest6()
+ {
+ dlog << LINFO << "start dotest6()";
+ print_spinner();
+ bool error_occurred = false;
+ {
+ dlib::pipe<unsigned short> ports(5);
+ thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
+ thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
+ thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
+
+
+ try
+ {
+ std::vector<network_address> hosts;
+ unsigned short port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
+ double min_value = 10, optimal_x = 0;
+ bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
+
+ dlog << LINFO << "min_value: " << min_value;
+ dlog << LINFO << "optimal_x: " << optimal_x;
+ DLIB_TEST(std::abs(min_value - 0) < 1e-14);
+ DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
+ }
+ catch (std::exception& e)
+ {
+ dlog << LERROR << "error during bsp_context: " << e.what();
+ DLIB_TEST(false);
+ }
+
+ }
+ DLIB_TEST(error_occurred == false);
+ }
+// ----------------------------------------------------------------------------------------
+
+ class bsp_tester : public tester
+ {
+
+ public:
+ bsp_tester (
+ ) :
+ tester ("test_bsp",
+ "Runs tests on the BSP components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ for (int i = 0; i < 3; ++i)
+ {
+ dotest1();
+ dotest2<0>();
+ dotest2<1>();
+ dotest2<2>();
+ dotest3();
+ dotest4();
+ dotest5();
+ dotest6();
+ }
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/byte_orderer.cpp b/ml/dlib/dlib/test/byte_orderer.cpp
new file mode 100644
index 000000000..7200c1b4a
--- /dev/null
+++ b/ml/dlib/dlib/test/byte_orderer.cpp
@@ -0,0 +1,111 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <cmath>
+#include <dlib/byte_orderer.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.byte_orderer");
+
+
+ class byte_orderer_tester : public tester
+ {
+ public:
+ byte_orderer_tester (
+ ) :
+ tester ("test_byte_orderer",
+ "Runs tests on the byte_orderer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ byte_orderer bo;
+
+ union data
+ {
+ unsigned char b[4];
+ dlib::uint32 val;
+ };
+
+ data a;
+
+ a.val = 1;
+
+ if (bo.host_is_little_endian())
+ {
+ DLIB_TEST(a.b[0] == 1);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 0);
+
+ bo.host_to_big(a.val);
+
+ DLIB_TEST(a.b[0] == 0);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 1);
+
+ bo.big_to_host(a.val);
+
+ DLIB_TEST(a.b[0] == 1);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 0);
+
+ DLIB_TEST(a.val == 1);
+ bo.host_to_network(a.val);
+ DLIB_TEST(a.val == 0x01000000);
+ bo.network_to_host(a.val);
+ DLIB_TEST(a.val == 1);
+ }
+ else
+ {
+ DLIB_TEST(a.b[0] == 0);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 1);
+
+ bo.host_to_little(a.val);
+
+ DLIB_TEST(a.b[0] == 1);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 0);
+
+ bo.little_to_host(a.val);
+
+ DLIB_TEST(a.b[0] == 0);
+ DLIB_TEST(a.b[1] == 0);
+ DLIB_TEST(a.b[2] == 0);
+ DLIB_TEST(a.b[3] == 1);
+
+
+ DLIB_TEST(a.val == 1);
+ bo.network_to_host(a.val);
+ DLIB_TEST(a.val == 1);
+ bo.host_to_network(a.val);
+ DLIB_TEST(a.val == 1);
+
+ }
+
+
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/cca.cpp b/ml/dlib/dlib/test/cca.cpp
new file mode 100644
index 000000000..a2014121d
--- /dev/null
+++ b/ml/dlib/dlib/test/cca.cpp
@@ -0,0 +1,460 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <dlib/statistics.h>
+#include <dlib/sparse_vector.h>
+#include <dlib/timing.h>
+#include <map>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.cca");
+
+ dlib::rand rnd;
+// ----------------------------------------------------------------------------------------
+
+ /*
+ std::vector<std::map<unsigned long, double> > make_really_big_test_matrix (
+ )
+ {
+ std::vector<std::map<unsigned long,double> > temp(30000);
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ for (int k = 0; k < 30; ++k)
+ temp[i][rnd.get_random_32bit_number()%10000] = 1;
+ }
+ return temp;
+ }
+ */
+
+ template <typename T>
+ std::vector<std::map<unsigned long, T> > mat_to_sparse (
+ const matrix<T>& A
+ )
+ {
+ std::vector<std::map<unsigned long,T> > temp(A.nr());
+ for (long r = 0; r < A.nr(); ++r)
+ {
+ for (long c = 0; c < A.nc(); ++c)
+ {
+ temp[r][c] = A(r,c);
+ }
+ }
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename EXP>
+ matrix<typename EXP::type> rm_zeros (
+ const matrix_exp<EXP>& m
+ )
+ {
+ // Do this to avoid trying to correlate super small numbers that are really just
+ // zero. Doing this avoids some potential false alarms in the unit tests below.
+ return round_zeros(m, max(abs(m))*1e-14);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ void check_correlation (
+ matrix<double> L,
+ matrix<double> R,
+ const matrix<double>& Ltrans,
+ const matrix<double>& Rtrans,
+ const matrix<double,0,1>& correlations
+ )
+ {
+ // apply the transforms
+ L = L*Ltrans;
+ R = R*Rtrans;
+
+ // compute the real correlation values. Store them in A.
+ matrix<double> A = compute_correlations(L, R);
+
+ for (long i = 0; i < correlations.size(); ++i)
+ {
+ // compare what the measured correlation values are (in A) to the
+ // predicted values.
+ cout << "error: "<< A(i) - correlations(i);
+ }
+ }
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ void test_cca3()
+ {
+ print_spinner();
+ const unsigned long rank = rnd.get_random_32bit_number()%10 + 1;
+ const unsigned long m = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long n = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long n2 = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long rank2 = rank + rnd.get_random_32bit_number()%5;
+
+ dlog << LINFO << "m: " << m;
+ dlog << LINFO << "n: " << n;
+ dlog << LINFO << "n2: " << n2;
+ dlog << LINFO << "rank: " << rank;
+ dlog << LINFO << "rank2: " << rank2;
+
+
+ matrix<double> L = randm(m,rank, rnd)*randm(rank,n, rnd);
+ //matrix<double> R = randm(m,rank, rnd)*randm(rank,n2, rnd);
+ matrix<double> R = L*randm(n,n2, rnd);
+ //matrix<double> L = randm(m,n, rnd);
+ //matrix<double> R = randm(m,n2, rnd);
+
+ matrix<double> Ltrans, Rtrans;
+ matrix<double,0,1> correlations;
+
+ {
+ correlations = cca(L, R, Ltrans, Rtrans, min(m,n), max(n,n2));
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ dlog << LINFO << "correlations: "<< trans(correlations);
+
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
+
+ const double trans_error = max(abs(L*Ltrans - R*Rtrans));
+ dlog << LINFO << "trans_error: "<< trans_error;
+ DLIB_TEST_MSG(trans_error < 1e-9, trans_error);
+ }
+ {
+ correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2)+6, 4);
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ dlog << LINFO << "correlations: "<< trans(correlations);
+ dlog << LINFO << "computed cors: " << trans(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)));
+
+ const double trans_error = max(abs(L*Ltrans - R*Rtrans));
+ dlog << LINFO << "trans_error: "<< trans_error;
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
+
+ DLIB_TEST(trans_error < 2e-9);
+ }
+
+ dlog << LINFO << "*****************************************************";
+ }
+
+ void test_cca2()
+ {
+ print_spinner();
+ const unsigned long rank = rnd.get_random_32bit_number()%10 + 1;
+ const unsigned long m = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long n = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long n2 = rank + rnd.get_random_32bit_number()%15;
+
+ dlog << LINFO << "m: " << m;
+ dlog << LINFO << "n: " << n;
+ dlog << LINFO << "n2: " << n2;
+ dlog << LINFO << "rank: " << rank;
+
+
+ matrix<double> L = randm(m,n, rnd);
+ matrix<double> R = randm(m,n2, rnd);
+
+ matrix<double> Ltrans, Rtrans;
+ matrix<double,0,1> correlations;
+
+ {
+ correlations = cca(L, R, Ltrans, Rtrans, min(n,n2), max(n,n2)-min(n,n2));
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ dlog << LINFO << "correlations: "<< trans(correlations);
+
+ if (Ltrans.nc() > 1)
+ {
+ // The CCA projection directions are supposed to be uncorrelated for
+ // non-matching pairs of projections.
+ const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans))));
+ dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error;
+ DLIB_TEST(std::abs(corr_rot1_error) < 1e-10);
+ }
+ // Matching projection directions should be correlated with the amount of
+ // correlation indicated by the return value of cca().
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST(corr_error < 1e-13);
+ }
+ {
+ correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(n,n2), max(n,n2)-min(n,n2));
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ dlog << LINFO << "correlations: "<< trans(correlations);
+
+ if (Ltrans.nc() > 1)
+ {
+ // The CCA projection directions are supposed to be uncorrelated for
+ // non-matching pairs of projections.
+ const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans))));
+ dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error;
+ DLIB_TEST(std::abs(corr_rot1_error) < 1e-10);
+ }
+ // Matching projection directions should be correlated with the amount of
+ // correlation indicated by the return value of cca().
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST(corr_error < 1e-13);
+ }
+
+ dlog << LINFO << "*****************************************************";
+ }
+
+ void test_cca1()
+ {
+ print_spinner();
+ const unsigned long rank = rnd.get_random_32bit_number()%10 + 1;
+ const unsigned long m = rank + rnd.get_random_32bit_number()%15;
+ const unsigned long n = rank + rnd.get_random_32bit_number()%15;
+
+ dlog << LINFO << "m: " << m;
+ dlog << LINFO << "n: " << n;
+ dlog << LINFO << "rank: " << rank;
+
+ matrix<double> T = randm(n,n, rnd);
+
+ matrix<double> L = randm(m,rank, rnd)*randm(rank,n, rnd);
+ //matrix<double> L = randm(m,n, rnd);
+ matrix<double> R = L*T;
+
+ matrix<double> Ltrans, Rtrans;
+ matrix<double,0,1> correlations;
+
+ {
+ correlations = cca(L, R, Ltrans, Rtrans, rank);
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ if (Ltrans.nc() > 1)
+ {
+ // The CCA projection directions are supposed to be uncorrelated for
+ // non-matching pairs of projections.
+ const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans))));
+ dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error;
+ DLIB_TEST(std::abs(corr_rot1_error) < 2e-9);
+ }
+ // Matching projection directions should be correlated with the amount of
+ // correlation indicated by the return value of cca().
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST(corr_error < 1e-13);
+
+ const double trans_error = max(abs(L*Ltrans - R*Rtrans));
+ dlog << LINFO << "trans_error: "<< trans_error;
+ DLIB_TEST(trans_error < 2e-9);
+
+ dlog << LINFO << "correlations: "<< trans(correlations);
+ }
+ {
+ correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, rank);
+ DLIB_TEST(Ltrans.nc() == Rtrans.nc());
+ if (Ltrans.nc() > 1)
+ {
+ // The CCA projection directions are supposed to be uncorrelated for
+ // non-matching pairs of projections.
+ const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans))));
+ dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error;
+ DLIB_TEST(std::abs(corr_rot1_error) < 2e-9);
+ }
+ // Matching projection directions should be correlated with the amount of
+ // correlation indicated by the return value of cca().
+ const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+ dlog << LINFO << "correlation error: "<< corr_error;
+ DLIB_TEST(corr_error < 1e-13);
+
+ const double trans_error = max(abs(L*Ltrans - R*Rtrans));
+ dlog << LINFO << "trans_error: "<< trans_error;
+ DLIB_TEST(trans_error < 2e-9);
+
+ dlog << LINFO << "correlations: "<< trans(correlations);
+ }
+
+ dlog << LINFO << "*****************************************************";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_svd_fast(
+ long rank,
+ long m,
+ long n
+ )
+ {
+ print_spinner();
+ matrix<double> A = randm(m,rank,rnd)*randm(rank,n,rnd);
+ matrix<double> u,v;
+ matrix<double,0,1> w;
+
+ dlog << LINFO << "rank: "<< rank;
+ dlog << LINFO << "m: "<< m;
+ dlog << LINFO << "n: "<< n;
+
+ svd_fast(A, u, w, v, rank, 2);
+ DLIB_TEST(u.nr() == m);
+ DLIB_TEST(u.nc() == rank);
+ DLIB_TEST(w.nr() == rank);
+ DLIB_TEST(w.nc() == 1);
+ DLIB_TEST(v.nr() == n);
+ DLIB_TEST(v.nc() == rank);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-13);
+ svd_fast(mat_to_sparse(A), u, w, v, rank, 2);
+ DLIB_TEST(u.nr() == m);
+ DLIB_TEST(u.nc() == rank);
+ DLIB_TEST(w.nr() == rank);
+ DLIB_TEST(w.nc() == 1);
+ DLIB_TEST(v.nr() == n);
+ DLIB_TEST(v.nc() == rank);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-13);
+
+ svd_fast(A, u, w, v, rank, 0);
+ DLIB_TEST(u.nr() == m);
+ DLIB_TEST(u.nc() == rank);
+ DLIB_TEST(w.nr() == rank);
+ DLIB_TEST(w.nc() == 1);
+ DLIB_TEST(v.nr() == n);
+ DLIB_TEST(v.nc() == rank);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST_MSG(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-9,max(abs(tmp(A - u*diagm(w)*trans(v)))));
+ svd_fast(mat_to_sparse(A), u, w, v, rank, 0);
+ DLIB_TEST(u.nr() == m);
+ DLIB_TEST(u.nc() == rank);
+ DLIB_TEST(w.nr() == rank);
+ DLIB_TEST(w.nc() == 1);
+ DLIB_TEST(v.nr() == n);
+ DLIB_TEST(v.nc() == rank);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-10);
+
+ svd_fast(A, u, w, v, rank+5, 0);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11);
+ svd_fast(mat_to_sparse(A), u, w, v, rank+5, 0);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11);
+ svd_fast(A, u, w, v, rank+5, 1);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-12);
+ svd_fast(mat_to_sparse(A), u, w, v, rank+5, 1);
+ DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
+ DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-12);
+ }
+
+ void test_svd_fast()
+ {
+ for (int iter = 0; iter < 1000; ++iter)
+ {
+ const unsigned long rank = rnd.get_random_32bit_number()%10 + 1;
+ const unsigned long m = rank + rnd.get_random_32bit_number()%10;
+ const unsigned long n = rank + rnd.get_random_32bit_number()%10;
+
+ test_svd_fast(rank, m, n);
+
+ }
+ test_svd_fast(1, 1, 1);
+ test_svd_fast(1, 2, 2);
+ test_svd_fast(1, 1, 2);
+ test_svd_fast(1, 2, 1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*
+ typedef std::vector<std::pair<unsigned int, float>> sv;
+ sv rand_sparse_vector()
+ {
+ static dlib::rand rnd;
+ sv v;
+ for (int i = 0; i < 50; ++i)
+ v.push_back(make_pair(rnd.get_integer(400000), rnd.get_random_gaussian()*100));
+
+ make_sparse_vector_inplace(v);
+ return v;
+ }
+
+ sv rand_basis_combo(const std::vector<sv>& basis)
+ {
+ static dlib::rand rnd;
+ sv result;
+
+ for (int i = 0; i < 5; ++i)
+ {
+ sv temp = basis[rnd.get_integer(basis.size())];
+ scale_by(temp, rnd.get_random_gaussian());
+ result = add(result,temp);
+ }
+ return result;
+ }
+
+ void big_sparse_speed_test()
+ {
+ cout << "making A" << endl;
+ std::vector<sv> basis;
+ for (int i = 0; i < 100; ++i)
+ basis.emplace_back(rand_sparse_vector());
+
+ std::vector<sv> A;
+ for (int i = 0; i < 500000; ++i)
+ A.emplace_back(rand_basis_combo(basis));
+
+ cout << "done making A" << endl;
+
+ matrix<float> u,v;
+ matrix<float,0,1> w;
+ {
+ timing::block aosijdf(0,"call it");
+ svd_fast(A, u,w,v, 100, 5);
+ }
+
+ timing::print();
+ }
+ */
+
+// ----------------------------------------------------------------------------------------
+
+ class test_cca : public tester
+ {
+ public:
+ test_cca (
+ ) :
+ tester ("test_cca",
+ "Runs tests on the cca() and svd_fast() routines.")
+ {}
+
+ void perform_test (
+ )
+ {
+ //big_sparse_speed_test();
+ for (int i = 0; i < 200; ++i)
+ {
+ test_cca1();
+ test_cca2();
+ test_cca3();
+ }
+ test_svd_fast();
+ }
+ } a;
+
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/checkerboard.h b/ml/dlib/dlib/test/checkerboard.h
new file mode 100644
index 000000000..90c8f0cb8
--- /dev/null
+++ b/ml/dlib/dlib/test/checkerboard.h
@@ -0,0 +1,55 @@
+#ifndef DLIB_CHECKERBOARD_TeST_H_
+#define DLIB_CHECKERBOARD_TeST_H_
+
+#include <dlib/matrix.h>
+#include <vector>
+#include <dlib/rand.h>
+
+namespace dlib
+{
+
+ template <typename scalar_type>
+ void get_checkerboard_problem (
+ std::vector<matrix<scalar_type,2,1> >& x,
+ std::vector<scalar_type>& y,
+ const long num_samples,
+ const long board_dimension = 8
+ )
+ /*!
+ requires
+ - num_samples > 0
+ - board_dimension > 0
+ ensures
+ - #x.size() == y.size() == num_samples
+ - is_binary_classification_problem(#x,#y) == true
+ - #x will contain points and #y labels that were
+ sampled randomly from a checkers board that has
+ board_dimension squares on each side.
+ !*/
+ {
+ static dlib::rand rnd;
+
+ x.clear();
+ y.clear();
+
+ matrix<scalar_type,2,1> sample;
+ for (long i = 0; i < num_samples; ++i)
+ {
+ sample(0) = rnd.get_random_double();
+ sample(1) = rnd.get_random_double();
+ sample *= board_dimension;
+
+ x.push_back(sample);
+ if (((int)sum(floor(sample)) %2) == 0)
+ y.push_back(+1);
+ else
+ y.push_back(-1);
+
+ }
+ }
+
+
+}
+
+#endif // DLIB_CHECKERBOARD_TeST_H_
+
diff --git a/ml/dlib/dlib/test/clustering.cpp b/ml/dlib/dlib/test/clustering.cpp
new file mode 100644
index 000000000..a784c57c0
--- /dev/null
+++ b/ml/dlib/dlib/test/clustering.cpp
@@ -0,0 +1,410 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <dlib/clustering.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.clustering");
+
+// ----------------------------------------------------------------------------------------
+
+ void make_test_graph(
+ dlib::rand& rnd,
+ std::vector<sample_pair>& edges,
+ std::vector<unsigned long>& labels,
+ const int groups,
+ const int group_size,
+ const int noise_level,
+ const double missed_edges
+ )
+ {
+ labels.resize(groups*group_size);
+
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ labels[i] = i/group_size;
+ }
+
+ edges.clear();
+ for (int i = 0; i < groups; ++i)
+ {
+ for (int j = 0; j < group_size; ++j)
+ {
+ for (int k = 0; k < group_size; ++k)
+ {
+ if (j == k)
+ continue;
+
+ if (rnd.get_random_double() < missed_edges)
+ continue;
+
+ edges.push_back(sample_pair(j+group_size*i, k+group_size*i, 1));
+ }
+ }
+ }
+
+ for (int k = 0; k < groups*noise_level; ++k)
+ {
+ const int i = rnd.get_random_32bit_number()%labels.size();
+ const int j = rnd.get_random_32bit_number()%labels.size();
+ edges.push_back(sample_pair(i,j,1));
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_modularity_matrices (
+ const std::vector<sample_pair>& edges,
+ matrix<double>& A,
+ matrix<double>& P,
+ double& m
+ )
+ {
+ const unsigned long num_nodes = max_index_plus_one(edges);
+ A.set_size(num_nodes, num_nodes);
+ P.set_size(num_nodes, num_nodes);
+ A = 0;
+ P = 0;
+ std::vector<double> k(num_nodes,0);
+
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ const unsigned long n1 = edges[i].index1();
+ const unsigned long n2 = edges[i].index2();
+ k[n1] += edges[i].distance();
+ if (n1 != n2)
+ {
+ k[n2] += edges[i].distance();
+ A(n2,n1) += edges[i].distance();
+ }
+
+ A(n1,n2) += edges[i].distance();
+ }
+
+ m = sum(A)/2;
+
+ for (long r = 0; r < P.nr(); ++r)
+ {
+ for (long c = 0; c < P.nc(); ++c)
+ {
+ P(r,c) = k[r]*k[c]/(2*m);
+ }
+ }
+
+ }
+
+ double compute_modularity_simple (
+ const std::vector<sample_pair>& edges,
+ std::vector<unsigned long> labels
+ )
+ {
+ double m;
+ matrix<double> A,P;
+ make_modularity_matrices(edges, A, P, m);
+ matrix<double> B = A - P;
+
+ double Q = 0;
+ for (long r = 0; r < B.nr(); ++r)
+ {
+ for (long c = 0; c < B.nc(); ++c)
+ {
+ if (labels[r] == labels[c])
+ {
+ Q += B(r,c);
+ }
+ }
+ }
+ return 1.0/(2*m) * Q;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_modularity(dlib::rand& rnd)
+ {
+ print_spinner();
+ std::vector<sample_pair> edges;
+ std::vector<ordered_sample_pair> oedges;
+ std::vector<unsigned long> labels;
+
+ make_test_graph(rnd, edges, labels, 10, 30, 3, 0.10);
+ if (rnd.get_random_double() < 0.5)
+ remove_duplicate_edges(edges);
+ convert_unordered_to_ordered(edges, oedges);
+
+
+ const double m1 = modularity(edges, labels);
+ const double m2 = compute_modularity_simple(edges, labels);
+ const double m3 = modularity(oedges, labels);
+
+ DLIB_TEST(std::abs(m1-m2) < 1e-12);
+ DLIB_TEST(std::abs(m2-m3) < 1e-12);
+ DLIB_TEST(std::abs(m3-m1) < 1e-12);
+ }
+
+ void test_newman_clustering(dlib::rand& rnd)
+ {
+ print_spinner();
+ std::vector<sample_pair> edges;
+ std::vector<unsigned long> labels;
+
+ make_test_graph(rnd, edges, labels, 5, 30, 3, 0.10);
+ if (rnd.get_random_double() < 0.5)
+ remove_duplicate_edges(edges);
+
+
+ std::vector<unsigned long> labels2;
+
+ unsigned long num_clusters = newman_cluster(edges, labels2);
+ DLIB_TEST(labels.size() == labels2.size());
+ DLIB_TEST(num_clusters == 5);
+
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ for (unsigned long j = 0; j < labels.size(); ++j)
+ {
+ if (labels[i] == labels[j])
+ {
+ DLIB_TEST(labels2[i] == labels2[j]);
+ }
+ else
+ {
+ DLIB_TEST(labels2[i] != labels2[j]);
+ }
+ }
+ }
+ }
+
+ void test_chinese_whispers(dlib::rand& rnd)
+ {
+ print_spinner();
+ std::vector<sample_pair> edges;
+ std::vector<unsigned long> labels;
+
+ make_test_graph(rnd, edges, labels, 5, 30, 3, 0.10);
+ if (rnd.get_random_double() < 0.5)
+ remove_duplicate_edges(edges);
+
+
+ std::vector<unsigned long> labels2;
+
+ unsigned long num_clusters;
+ if (rnd.get_random_double() < 0.5)
+ num_clusters = chinese_whispers(edges, labels2, 200, rnd);
+ else
+ num_clusters = chinese_whispers(edges, labels2);
+
+ DLIB_TEST(labels.size() == labels2.size());
+ DLIB_TEST(num_clusters == 5);
+
+ for (unsigned long i = 0; i < labels.size(); ++i)
+ {
+ for (unsigned long j = 0; j < labels.size(); ++j)
+ {
+ if (labels[i] == labels[j])
+ {
+ DLIB_TEST(labels2[i] == labels2[j]);
+ }
+ else
+ {
+ DLIB_TEST(labels2[i] != labels2[j]);
+ }
+ }
+ }
+ }
+
+ void test_bottom_up_clustering()
+ {
+ std::vector<dpoint> pts;
+ pts.push_back(dpoint(0.0,0.0));
+ pts.push_back(dpoint(0.5,0.0));
+ pts.push_back(dpoint(0.5,0.5));
+ pts.push_back(dpoint(0.0,0.5));
+
+ pts.push_back(dpoint(3.0,3.0));
+ pts.push_back(dpoint(3.5,3.0));
+ pts.push_back(dpoint(3.5,3.5));
+ pts.push_back(dpoint(3.0,3.5));
+
+ pts.push_back(dpoint(7.0,7.0));
+ pts.push_back(dpoint(7.5,7.0));
+ pts.push_back(dpoint(7.5,7.5));
+ pts.push_back(dpoint(7.0,7.5));
+
+ matrix<double> dists(pts.size(), pts.size());
+ for (long r = 0; r < dists.nr(); ++r)
+ for (long c = 0; c < dists.nc(); ++c)
+ dists(r,c) = length(pts[r]-pts[c]);
+
+
+ matrix<unsigned long,0,1> truth(12);
+ truth = 0, 0, 0, 0,
+ 1, 1, 1, 1,
+ 2, 2, 2, 2;
+
+ std::vector<unsigned long> labels;
+ DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 3);
+ DLIB_TEST(mat(labels) == truth);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1, 4.0) == 3);
+ DLIB_TEST(mat(labels) == truth);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1, 4.95) == 2);
+ truth = 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 1, 1, 1, 1;
+ DLIB_TEST(mat(labels) == truth);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1);
+ truth = 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0;
+ DLIB_TEST(mat(labels) == truth);
+
+ dists.set_size(0,0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 0);
+ DLIB_TEST(labels.size() == 0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 0);
+ DLIB_TEST(labels.size() == 0);
+
+ dists.set_size(1,1);
+ dists = 1;
+ DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 1);
+ DLIB_TEST(labels.size() == 1);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1);
+ DLIB_TEST(labels.size() == 1);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1, 0) == 1);
+ DLIB_TEST(labels.size() == 1);
+ DLIB_TEST(labels[0] == 0);
+
+ dists.set_size(2,2);
+ dists = 1;
+ DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 2);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 1);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1, 1) == 1);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(bottom_up_cluster(dists, labels, 1, 0.999) == 2);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 1);
+ }
+
+ void test_segment_number_line()
+ {
+ dlib::rand rnd;
+
+
+ std::vector<double> x;
+ for (int i = 0; i < 5000; ++i)
+ {
+ x.push_back(rnd.get_double_in_range(-1.5, -1.01));
+ x.push_back(rnd.get_double_in_range(-0.99, -0.01));
+ x.push_back(rnd.get_double_in_range(0.01, 1));
+ }
+
+ auto r = segment_number_line(x,1);
+ std::sort(r.begin(), r.end());
+ DLIB_TEST(r.size() == 3);
+ DLIB_TEST(-1.5 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= -1.01);
+ DLIB_TEST(-0.99 <= r[1].lower && r[1].lower < r[1].upper && r[1].upper <= -0.01);
+ DLIB_TEST(0.01 <= r[2].lower && r[2].lower < r[2].upper && r[2].upper <= 1);
+
+ x.clear();
+ for (int i = 0; i < 5000; ++i)
+ {
+ x.push_back(rnd.get_double_in_range(-2, 1));
+ x.push_back(rnd.get_double_in_range(-2, 1));
+ x.push_back(rnd.get_double_in_range(-2, 1));
+ }
+
+ r = segment_number_line(x,1);
+ DLIB_TEST(r.size() == 3);
+ r = segment_number_line(x,1.5);
+ DLIB_TEST(r.size() == 2);
+ r = segment_number_line(x,10.5);
+ DLIB_TEST(r.size() == 1);
+ DLIB_TEST(-2 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= 1);
+ }
+
+ class test_clustering : public tester
+ {
+ public:
+ test_clustering (
+ ) :
+ tester ("test_clustering",
+ "Runs tests on the clustering routines.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_bottom_up_clustering();
+ test_segment_number_line();
+
+ dlib::rand rnd;
+
+ std::vector<sample_pair> edges;
+ std::vector<unsigned long> labels;
+ DLIB_TEST(newman_cluster(edges, labels) == 0);
+ DLIB_TEST(chinese_whispers(edges, labels) == 0);
+
+ edges.push_back(sample_pair(0,1,1));
+ DLIB_TEST(newman_cluster(edges, labels) == 1);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(chinese_whispers(edges, labels) == 1);
+ DLIB_TEST(labels.size() == 2);
+
+ edges.clear();
+ edges.push_back(sample_pair(0,0,1));
+ DLIB_TEST(newman_cluster(edges, labels) == 1);
+ DLIB_TEST(labels.size() == 1);
+ DLIB_TEST(chinese_whispers(edges, labels) == 1);
+ DLIB_TEST(labels.size() == 1);
+
+ edges.clear();
+ edges.push_back(sample_pair(1,1,1));
+ DLIB_TEST(newman_cluster(edges, labels) == 1);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(chinese_whispers(edges, labels) == 2);
+ DLIB_TEST(labels.size() == 2);
+
+ edges.push_back(sample_pair(0,0,1));
+ DLIB_TEST(newman_cluster(edges, labels) == 2);
+ DLIB_TEST(labels.size() == 2);
+ DLIB_TEST(chinese_whispers(edges, labels) == 2);
+ DLIB_TEST(labels.size() == 2);
+
+
+ for (int i = 0; i < 10; ++i)
+ test_modularity(rnd);
+
+ for (int i = 0; i < 10; ++i)
+ test_newman_clustering(rnd);
+
+ for (int i = 0; i < 10; ++i)
+ test_chinese_whispers(rnd);
+
+
+ }
+ } a;
+
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/cmd_line_parser.cpp b/ml/dlib/dlib/test/cmd_line_parser.cpp
new file mode 100644
index 000000000..9216a76cc
--- /dev/null
+++ b/ml/dlib/dlib/test/cmd_line_parser.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <string>
+#include <dlib/string.h>
+
+#include <dlib/cmd_line_parser.h>
+
+#include "tester.h"
+
+#include "cmd_line_parser.h"
+namespace
+{
+
+ class cmd_line_parser_tester : public tester
+ {
+ public:
+ cmd_line_parser_tester (
+ ) :
+ tester ("test_cmd_line_parser_char",
+ "Runs tests on the cmd_line_parser component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a with char";
+ cmd_line_parser_kernel_test<cmd_line_parser<char>::kernel_1a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_1a_c with char";
+ cmd_line_parser_kernel_test<cmd_line_parser<char>::kernel_1a_c>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/cmd_line_parser.h b/ml/dlib/dlib/test/cmd_line_parser.h
new file mode 100644
index 000000000..6f8e411a4
--- /dev/null
+++ b/ml/dlib/dlib/test/cmd_line_parser.h
@@ -0,0 +1,901 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_
+#define DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_
+
+
+#include <string>
+#include <dlib/string.h>
+
+#include <dlib/cmd_line_parser.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.cmd_line_parser");
+
+ template <
+ typename clp
+ >
+ void cmd_line_parser_kernel_test (
+ )
+ /*!
+ requires
+ - clp is an implementation of cmd_line_parser_kernel_abstract.h
+ ensures
+ - runs tests on clp for compliance with the specs
+ !*/
+ {
+ typedef typename clp::char_type ct;
+
+
+
+
+ int argc;
+ const ct* argv[100];
+ bool ok;
+
+ for (int j = 0; j < 3; ++j)
+ {
+ clp test, test2;
+
+
+
+
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+
+ DLIB_TEST(test.parsed_line() == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false);
+
+ DLIB_TEST(test.parsed_line() == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"b")) == false);
+ DLIB_TEST(test.option_is_defined(_dT(ct,"\0")) == false);
+
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+
+ // program arg1 --davis arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"arg2");
+ argv[4] = _dT(ct,"-cZzarg");
+ argv[5] = _dT(ct,"asdf");
+ argc = 6;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"));
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST_MSG(test.option(_dT(ct,"c")).count()==1,test.option(_dT(ct,"c")).count());
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+
+ }
+
+
+
+ swap(test,test2);
+
+
+
+
+
+ // program arg1 --davis arg2 -cZ zarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"arg2");
+ argv[4] = _dT(ct,"-cZ");
+ argv[5] = _dT(ct,"zarg");
+ argv[6] = _dT(ct,"asdf");
+ argc = 7;
+
+
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test2.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test2.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test2.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test2.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test2.option(_dT(ct,"davis")).number_of_arguments() == 0);
+ DLIB_TEST(test2.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test2.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test2.number_of_arguments() == 2);
+ DLIB_TEST(test2[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test2[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test2.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test2.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test2.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test2.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test2.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+ DLIB_TEST_MSG(test2.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"),
+ narrow(_dT(ct,"*") + test2.option(_dT(ct,"Z")).argument(0,0) + _dT(ct,"*")));
+
+
+ }
+
+
+
+
+
+ // program arg1 --davis= darg darg2 arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis=");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"darg2");
+ argv[5] = _dT(ct,"arg2");
+ argv[6] = _dT(ct,"-cZzarg");
+ argv[7] = _dT(ct,"asdf");
+ argc = 8;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.parsed_line());
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (test.element().name() == _dT(ct,"d"))
+ {
+ DLIB_TEST(test.element().count() == 0);
+ }
+ else
+ {
+ DLIB_TEST(test.element().count() == 1);
+ }
+
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 2);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg"));
+ DLIB_TEST_MSG(test.option(_dT(ct,"davis")).argument(1,0) == _dT(ct,"darg2"),
+ narrow(test.option(_dT(ct,"davis")).argument(1,0)));
+ }
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+
+
+
+
+
+
+ // program arg1 --dav-is=darg darg2 arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--dav-is=darg");
+ argv[3] = _dT(ct,"darg2");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"-cZzarg");
+ argv[6] = _dT(ct,"asdf");
+ argc = 7;
+
+
+ test.add_option(_dT(ct,"dav-is"),_dT(ct,"davis option"), 2);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.parsed_line());
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (test.element().name() == _dT(ct,"d"))
+ {
+ DLIB_TEST(test.element().count() == 0);
+ }
+ else
+ {
+ DLIB_TEST(test.element().count() == 1);
+ }
+
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+ DLIB_TEST(test.option(_dT(ct,"dav-is")).name() == _dT(ct,"dav-is"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"dav-is")).number_of_arguments() == 2);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"dav-is")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+ DLIB_TEST(test.option(_dT(ct,"dav-is")).argument(0,0) == _dT(ct,"darg"));
+ DLIB_TEST_MSG(test.option(_dT(ct,"dav-is")).argument(1,0) == _dT(ct,"darg2"),
+ narrow(test.option(_dT(ct,"dav-is")).argument(1,0)));
+ }
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+
+
+
+
+
+
+ // program arg1 --davis=darg darg2 arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis=darg");
+ argv[3] = _dT(ct,"darg2");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"-cZzarg");
+ argv[6] = _dT(ct,"asdf");
+ argc = 7;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.parsed_line());
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (test.element().name() == _dT(ct,"d"))
+ {
+ DLIB_TEST(test.element().count() == 0);
+ }
+ else
+ {
+ DLIB_TEST(test.element().count() == 1);
+ }
+
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 2);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg"));
+ DLIB_TEST_MSG(test.option(_dT(ct,"davis")).argument(1,0) == _dT(ct,"darg2"),
+ narrow(test.option(_dT(ct,"davis")).argument(1,0)));
+ }
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+
+
+
+
+
+
+ // program arg1 --davis=darg arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis=darg");
+ argv[3] = _dT(ct,"arg2");
+ argv[4] = _dT(ct,"-cZzarg");
+ argv[5] = _dT(ct,"asdf");
+ argc = 6;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.parsed_line());
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (test.element().name() == _dT(ct,"d"))
+ {
+ DLIB_TEST(test.element().count() == 0);
+ }
+ else
+ {
+ DLIB_TEST(test.element().count() == 1);
+ }
+
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 1);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg"));
+ }
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+
+
+
+
+
+ // program arg1 --davis darg arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"-cZzarg");
+ argv[6] = _dT(ct,"asdf");
+ argc = 7;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ try { test.parse(argc,argv); }
+ catch (error& e)
+ {
+ DLIB_TEST_MSG(false,e.info);
+ }
+
+ DLIB_TEST(test.parsed_line());
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (test.element().name() == _dT(ct,"d"))
+ {
+ DLIB_TEST(test.element().count() == 0);
+ }
+ else
+ {
+ DLIB_TEST(test.element().count() == 1);
+ }
+
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis"));
+ DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 1);
+ DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2);
+ DLIB_TEST(test.number_of_arguments() == 2);
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(1) == _dT(ct,"asdf"));
+ DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg"));
+ }
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+ // this string is incorrect because there is no avis option
+ // program arg1 --avis darg arg2 -cZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--avis");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"-cZzarg");
+ argv[6] = _dT(ct,"asdf");
+ argc = 7;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ ok = false;
+ try { test.parse(argc,argv); }
+ catch (typename clp::cmd_line_parse_error& e)
+ {
+ DLIB_TEST(e.type == EINVALID_OPTION);
+ DLIB_TEST(e.item == _dT(ct,"avis"));
+ ok = true;
+ }
+ DLIB_TEST(ok);
+
+
+ }
+
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+ // the c argument appears twice. make sure its count is correct
+ // program arg1 --davis darg arg2 -ccZzarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"-ccZ");
+ argv[6] = _dT(ct,"zarg");
+ argv[7] = _dT(ct,"asdf");
+ argc = 8;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ ok = false;
+ test.parse(argc,argv);
+
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==2);
+
+ }
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+ // this is a bad line because the davis argument requires 2 arguments but
+ // only gets one.
+ // program arg1 --davis darg darg2 --davis zarg
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"darg2");
+ argv[5] = _dT(ct,"--davis");
+ argv[6] = _dT(ct,"zarg");
+ argc = 7;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2);
+ test.add_option(_dT(ct,"b"),_dT(ct,"b option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option"));
+ DLIB_TEST(test.option(_dT(ct,"b")).description() == _dT(ct,"b option"));
+ DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option"));
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ ok = false;
+ try { test.parse(argc,argv); }
+ catch (typename clp::cmd_line_parse_error& e)
+ {
+ DLIB_TEST(e.type == ETOO_FEW_ARGS);
+ DLIB_TEST(e.num == 2);
+ DLIB_TEST(e.item == _dT(ct,"davis"));
+ ok = true;
+ }
+ DLIB_TEST(ok);
+
+
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ DLIB_TEST(test.element().count() == 0);
+ DLIB_TEST(test.option_is_defined(test.element().name()));
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+
+ }
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+ // this is a bad line because the davis argument is not defined
+ // program arg1 --davis darg arg2 -davis zarg asdf
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"darg");
+ argv[4] = _dT(ct,"arg2");
+ argv[5] = _dT(ct,"--davis");
+ argv[6] = _dT(ct,"zarg");
+ argv[7] = _dT(ct,"asdf");
+ argc = 8;
+
+
+ DLIB_TEST(std::basic_string<ct>(argv[0]) == _dT(ct,"program"));
+
+ test.add_option(_dT(ct,"mavis"),_dT(ct,"mavis option"), 1);
+ test.add_option(_dT(ct,"b"),_dT(ct,"b option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ DLIB_TEST(test.option(_dT(ct,"mavis")).description() == _dT(ct,"mavis option"));
+ DLIB_TEST(test.option(_dT(ct,"b")).description() == _dT(ct,"b option"));
+ DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option"));
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ ok = false;
+ try { test.parse(argc,argv); }
+ catch (typename clp::cmd_line_parse_error& e)
+ {
+ DLIB_TEST(e.type == EINVALID_OPTION);
+ DLIB_TEST(e.item == _dT(ct,"davis"));
+ ok = true;
+ }
+ DLIB_TEST(ok);
+
+
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ DLIB_TEST(test.element().count() == 0);
+ DLIB_TEST(test.option_is_defined(test.element().name()));
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+
+ }
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+
+ argv[0] = _dT(ct,"program");
+ argc = 1;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2);
+
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option"));
+ DLIB_TEST(test.option(_dT(ct,"c")).description() == _dT(ct,"c option"));
+ DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option"));
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ test.parse(argc,argv);
+
+
+ DLIB_TEST(test.number_of_arguments() == 0);
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ DLIB_TEST(test.element().count() == 0);
+ DLIB_TEST(test.option_is_defined(test.element().name()));
+ }
+ DLIB_TEST_MSG(count == 4,count);
+
+
+ }
+
+
+
+
+
+
+
+
+
+
+
+
+ test.clear();
+
+ // this is to make sure the -- command works right
+ // program arg1 --davis -darg -- arg2 -c asdf -Zeat -Zat -Zjoe's
+ argv[0] = _dT(ct,"program");
+ argv[1] = _dT(ct,"arg1");
+ argv[2] = _dT(ct,"--davis");
+ argv[3] = _dT(ct,"-darg");
+ argv[4] = _dT(ct,"-Zeat");
+ argv[5] = _dT(ct,"-Zat");
+ argv[6] = _dT(ct,"-Zjoe's");
+ argv[7] = _dT(ct,"--");
+ argv[8] = _dT(ct,"arg2");
+ argv[9] = _dT(ct,"-c");
+ argv[10] = _dT(ct,"asdf");
+
+ argc = 11;
+
+
+ test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1);
+ test.add_option(_dT(ct,"c"),_dT(ct,"c option"));
+ test.add_option(_dT(ct,"d"),_dT(ct,"d option"));
+ test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),1);
+
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option"));
+ DLIB_TEST(test.option(_dT(ct,"c")).description() == _dT(ct,"c option"));
+ DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option"));
+
+ for (int k = 0; k < 5; ++k)
+ {
+
+ test.parse(argc,argv);
+
+ DLIB_TEST_MSG(test.number_of_arguments() == 4,test.number_of_arguments());
+ DLIB_TEST(test[0] == _dT(ct,"arg1"));
+ DLIB_TEST(test[1] == _dT(ct,"arg2"));
+ DLIB_TEST(test[2] == _dT(ct,"-c"));
+ DLIB_TEST(test[3] == _dT(ct,"asdf"));
+
+ DLIB_TEST(test.option(_dT(ct,"davis")).count()==1);
+ DLIB_TEST(test.option(_dT(ct,"davis")).argument() == _dT(ct,"-darg"));
+ DLIB_TEST(test.option(_dT(ct,"c")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"d")).count()==0);
+ DLIB_TEST(test.option(_dT(ct,"Z")).count()==3);
+
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"eat"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,1) == _dT(ct,"at"));
+ DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,2) == _dT(ct,"joe's"));
+
+
+ }
+
+
+ }
+ }
+
+}
+
+
+#endif // DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_
+
diff --git a/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp b/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp
new file mode 100644
index 000000000..f771eeedb
--- /dev/null
+++ b/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <string>
+#include <dlib/string.h>
+
+#include <dlib/cmd_line_parser.h>
+
+#include "tester.h"
+
+#include "cmd_line_parser.h"
+namespace
+{
+
+ class cmd_line_parser_tester : public tester
+ {
+ public:
+ cmd_line_parser_tester (
+ ) :
+ tester ("test_cmd_line_parser_wchar_t",
+ "Runs tests on the cmd_line_parser<wchar_t> component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a with wchar_t";
+ cmd_line_parser_kernel_test<cmd_line_parser<wchar_t>::kernel_1a>();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_1a_c with wchar_t";
+ cmd_line_parser_kernel_test<cmd_line_parser<wchar_t>::kernel_1a_c>();
+ print_spinner();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/compress_stream.cpp b/ml/dlib/dlib/test/compress_stream.cpp
new file mode 100644
index 000000000..fbc57dd4c
--- /dev/null
+++ b/ml/dlib/dlib/test/compress_stream.cpp
@@ -0,0 +1,306 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/compress_stream.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.compress_stream");
+
+ template <
+ typename cs
+ >
+ void compress_stream_kernel_test (
+ unsigned long seed
+ )
+ /*!
+ requires
+ - cs is an implementation of compress_stream/compress_stream_kernel_abstract.h
+ the alphabet_size for cc is 256
+ ensures
+ - runs tests on cs for compliance with the specs
+ !*/
+ {
+
+
+ srand(seed);
+
+ cs test;
+
+
+ dlog << LTRACE << 1;
+
+ int count = 0;
+ while (count < 2)
+ {
+ print_spinner();
+ istringstream sin;
+ ostringstream sout;
+ string buffer;
+ buffer.reserve(10000);
+ // fill sin with a bunch of random data in the range 0 to 63
+ for (int i = 0; i < 10000; ++i)
+ {
+ char temp = static_cast<char>(::rand()&0x3f);
+ buffer.push_back(temp);
+ }
+
+ print_spinner();
+ sin.str(buffer);
+ string old_buffer = buffer;
+
+ test.compress(sin,sout);
+ buffer = sout.str();
+
+ print_spinner();
+ // corrput the data in buffer
+ buffer[buffer.size()/2]++;
+
+ sin.str(buffer);
+ sout.str("");
+
+ bool detected_error = false;
+ try {
+ test.decompress(sin,sout);
+ } catch ( typename cs::decompression_error e )
+ {
+ detected_error = true;
+ ++count;
+ }
+
+
+ DLIB_TEST_MSG(detected_error || sout.str() == old_buffer,(unsigned int)sout.str().size());
+
+
+
+ } /**/
+
+
+ dlog << LTRACE << 2;
+
+ for (int j = 0; j < 2; ++j)
+ {
+
+ print_spinner();
+ istringstream sin;
+ ostringstream sout;
+
+ string buffer;
+
+ buffer.reserve(10);
+
+ // make sure a single char can be compressed and decompressed
+ for (int i = 0; i < 256; ++i)
+ {
+ sin.str("");
+ sout.str("");
+ char ch = static_cast<char>(i);
+ buffer = ch;
+ sin.str(buffer);
+
+ test.compress(sin,sout);
+
+ sin.str(sout.str());
+ sout.str("");
+ test.decompress(sin,sout);
+ DLIB_TEST(sout.str() == buffer);
+ }
+
+ print_spinner();
+
+ // make sure you can compress a single char, then append a new
+ // compressed single char. and make sure you can decode the
+ // two streams. Just to make sure the decoder doesn't leave
+ // extra bytes behind or eat more than it should.
+ for (int i = 0; i < 500; ++i)
+ {
+ sin.str("");
+ sin.clear();
+ sout.str("");
+ sout.clear();
+ char ch = static_cast<char>(::rand()%256);
+ char ch2 = static_cast<char>(::rand()%256);
+
+ buffer = ch;
+ sin.str(buffer);
+
+
+
+ test.compress(sin,sout);
+
+
+
+
+ buffer = ch2;
+ sin.str(buffer);
+ test.compress(sin,sout);
+
+ sin.str(sout.str());
+
+ sout.str("");
+ test.decompress(sin,sout);
+ buffer = ch;
+ DLIB_TEST(sout.str() == buffer);
+
+
+
+
+ sout.str("");
+ test.decompress(sin,sout);
+ buffer = ch2;
+ DLIB_TEST(sout.str() == buffer);
+
+
+ }
+ print_spinner();
+
+
+ // make sure you can compress and decompress the empty string
+ sout.str("");
+ sin.str("");
+ test.compress(sin,sout);
+ sin.str(sout.str());
+ sout.str("");
+ test.decompress(sin,sout);
+ DLIB_TEST_MSG(sout.str() == "",sout.str());
+
+
+
+
+
+ print_spinner();
+
+ sin.str("");
+ sout.str("");
+ buffer = "";
+
+ buffer.reserve(20000);
+ // fill buffer with a bunch of random data in the range 0 to 63
+ for (int i = 0; i < 20000; ++i)
+ {
+ char temp = static_cast<char>(::rand()&0x3f);
+ buffer.push_back(temp);
+ }
+
+ sin.str(buffer);
+
+ print_spinner();
+ test.compress(sin,sout);
+
+ sin.str(sout.str());
+ sout.str("");
+
+ print_spinner();
+ test.decompress(sin,sout);
+
+ DLIB_TEST(sout.str() == buffer);
+
+ print_spinner();
+ }
+
+ dlog << LTRACE << 3;
+
+ // this block will try to compress a bunch of 'a' chars
+ {
+
+ istringstream sin;
+ ostringstream sout;
+
+ string buffer;
+
+
+ print_spinner();
+
+ sin.str("");
+ sout.str("");
+ buffer = "";
+
+ buffer.reserve(50000);
+ // fill buffer with a bunch of 'a' chars
+ for (int i = 0; i < 50000; ++i)
+ {
+ char temp = 'a';
+ buffer.push_back(temp);
+ }
+
+ sin.str(buffer);
+
+ print_spinner();
+ test.compress(sin,sout);
+
+ sin.str(sout.str());
+ sout.str("");
+
+ print_spinner();
+ test.decompress(sin,sout);
+
+ DLIB_TEST(sout.str() == buffer);
+
+ print_spinner();
+
+ }
+
+ dlog << LTRACE << 4;
+
+ }
+
+
+
+
+
+
+ class compress_stream_tester : public tester
+ {
+ public:
+ compress_stream_tester (
+ ) :
+ tester ("test_compress_stream",
+ "Runs tests on the compress_stream component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ const unsigned int seed = static_cast<unsigned int>(time(0));
+ dlog << LINFO << "using seed: " << seed;
+
+ dlog << LINFO << "testing kernel_1a";
+ compress_stream_kernel_test<compress_stream::kernel_1a>(seed);
+ dlog << LINFO << "testing kernel_1b";
+ compress_stream_kernel_test<compress_stream::kernel_1b>(seed);
+ dlog << LINFO << "testing kernel_1c";
+ compress_stream_kernel_test<compress_stream::kernel_1c>(seed);
+ dlog << LINFO << "testing kernel_1da";
+ compress_stream_kernel_test<compress_stream::kernel_1da>(seed);
+ dlog << LINFO << "testing kernel_1db";
+ compress_stream_kernel_test<compress_stream::kernel_1db>(seed);
+ dlog << LINFO << "testing kernel_1ea";
+ compress_stream_kernel_test<compress_stream::kernel_1ea>(seed);
+ dlog << LINFO << "testing kernel_1eb";
+ compress_stream_kernel_test<compress_stream::kernel_1eb>(seed);
+ dlog << LINFO << "testing kernel_1ec";
+ compress_stream_kernel_test<compress_stream::kernel_1ec>(seed);
+ dlog << LINFO << "testing kernel_2a";
+ compress_stream_kernel_test<compress_stream::kernel_2a>(seed);
+ dlog << LINFO << "testing kernel_3a";
+ compress_stream_kernel_test<compress_stream::kernel_3a>(seed);
+ dlog << LINFO << "testing kernel_3b";
+ compress_stream_kernel_test<compress_stream::kernel_3b>(seed);
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/conditioning_class.cpp b/ml/dlib/dlib/test/conditioning_class.cpp
new file mode 100644
index 000000000..b4415eafd
--- /dev/null
+++ b/ml/dlib/dlib/test/conditioning_class.cpp
@@ -0,0 +1,86 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/conditioning_class.h>
+
+#include "tester.h"
+#include "conditioning_class.h"
+
+namespace
+{
+
+
+ class conditioning_class_tester : public tester
+ {
+ public:
+ conditioning_class_tester (
+ ) :
+ tester ("test_conditioning_class",
+ "Runs tests on the conditioning_class component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_1a,
+ conditioning_class<2>::kernel_1a
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_2a";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_2a,
+ conditioning_class<2>::kernel_2a
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_3a";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_3a,
+ conditioning_class<2>::kernel_3a
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4a";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4a,
+ conditioning_class<2>::kernel_4a
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4b";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4b,
+ conditioning_class<2>::kernel_4b
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4c,
+ conditioning_class<2>::kernel_4c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4d";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4d,
+ conditioning_class<2>::kernel_4d
+ >();
+ print_spinner();
+
+
+ }
+ } a;
+
+
+}
+
diff --git a/ml/dlib/dlib/test/conditioning_class.h b/ml/dlib/dlib/test/conditioning_class.h
new file mode 100644
index 000000000..3c6c88b8d
--- /dev/null
+++ b/ml/dlib/dlib/test/conditioning_class.h
@@ -0,0 +1,841 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TEST_CONDITIONING_CLASs_H_
+#define DLIB_TEST_CONDITIONING_CLASs_H_
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/conditioning_class.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.conditioning_class");
+
+ template <
+ typename cc,
+ typename cc2
+ >
+ void conditioning_class_kernel_test (
+ )
+ /*!
+ requires
+ - cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ the alphabet_size for cc is 256
+ - cc2 is an implementation of conditioning_class/conditioning_class_kernel_abstract.h
+ the alphabet_size for cc2 is 2
+ ensures
+ - runs tests on cc for compliance with the specs
+ !*/
+ {
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+
+ typename cc::global_state_type gs;
+ typename cc2::global_state_type gs2;
+
+
+
+
+ for (int g = 0; g < 2; ++g)
+ {
+ print_spinner();
+ unsigned long amount=g+1;
+ cc2 test(gs2);
+ cc2 test2(gs2);
+
+
+ DLIB_TEST(test.get_memory_usage() != 0);
+
+ const unsigned long alphabet_size = 2;
+
+
+ DLIB_TEST(test.get_total() == 1);
+
+ DLIB_TEST(test.get_count(alphabet_size-1)==1);
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ {
+ unsigned long low_count, high_count, total_count;
+ DLIB_TEST_MSG(test.get_range(i,low_count,high_count,total_count) == 0,i);
+ DLIB_TEST(test.get_count(i) == 0);
+ DLIB_TEST(test.get_total() == 1);
+ }
+
+
+
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ test.increment_count(i,static_cast<unsigned short>(amount));
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+
+ if (i ==alphabet_size-1)
+ {
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1+amount);
+
+ DLIB_TEST(high_count == low_count+1+amount);
+ DLIB_TEST(total_count == test.get_total());
+
+
+ DLIB_TEST(test.get_count(i) == 1+amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == amount);
+
+ DLIB_TEST(high_count == low_count+amount);
+ DLIB_TEST(total_count == test.get_total());
+
+
+ DLIB_TEST(test.get_count(i) == amount);
+ }
+ DLIB_TEST(test.get_total() == (i+1)*amount + 1);
+ }
+
+
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ unsigned long temp = static_cast<unsigned long>(::rand()%40);
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ test.increment_count(i,static_cast<unsigned short>(amount));
+ if (i == alphabet_size-1)
+ {
+ DLIB_TEST(test.get_count(i) == (j+1)*amount + 1 + amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_count(i) == (j+1)*amount + amount);
+ }
+ }
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+ if (i == alphabet_size-1)
+ {
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount+1+amount);
+ DLIB_TEST(high_count-low_count == temp*amount+1+amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount + amount);
+ DLIB_TEST(high_count-low_count == temp*amount + amount);
+ }
+ DLIB_TEST(total_count == test.get_total());
+
+ test.get_symbol(target,symbol,low_count,high_count);
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(low_count <= target);
+ DLIB_TEST(target < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+
+ }
+
+ test.clear();
+
+
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ {
+ test.increment_count(i);
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1);
+
+ DLIB_TEST(high_count == low_count+1);
+ DLIB_TEST(total_count == test.get_total());
+
+ DLIB_TEST(test.get_count(i) == 1);
+ DLIB_TEST(test.get_total() == i+2);
+ }
+
+
+
+
+ unsigned long counts[alphabet_size];
+
+
+ print_spinner();
+ for (int k = 0; k < 10; ++k)
+ {
+ unsigned long range = ::rand()%50000 + 2;
+
+ test.clear();
+
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ counts[i] = 0;
+ unsigned long total = 1;
+ counts[alphabet_size-1] = 1;
+
+
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ unsigned long temp = static_cast<unsigned long>(::rand()%range);
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ test.increment_count(i);
+
+
+ if (total >= 65535)
+ {
+ total = 0;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ counts[i] >>= 1;
+ total += counts[i];
+ }
+ if (counts[alphabet_size-1]==0)
+ {
+ counts[alphabet_size-1] = 1;
+ ++total;
+ }
+ }
+ counts[i] = counts[i] + 1;
+ ++total;
+
+
+ }
+
+
+ unsigned long temp_total = 0;
+ for (unsigned long a = 0; a < alphabet_size; ++a)
+ {
+ temp_total += test.get_count(a);
+ }
+ DLIB_TEST_MSG(temp_total == test.get_total(),
+ "temp_total == " << temp_total << endl <<
+ "test.get_total() == " << test.get_total()
+ );
+
+ DLIB_TEST(test.get_count(alphabet_size-1) == counts[alphabet_size-1]);
+ DLIB_TEST_MSG(test.get_total() == total,
+ "test.get_total() == " << test.get_total() << endl <<
+ "total == " << total
+ );
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]);
+
+ if (counts[symbol] != 0)
+ {
+ DLIB_TEST(total_count == total);
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ }
+
+
+ if (target < total)
+ {
+ test.get_symbol(target,symbol,low_count,high_count);
+
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(test.get_count(symbol) == counts[symbol]);
+ }
+
+
+
+
+ }
+
+ }
+
+ print_spinner();
+
+ for (unsigned long h = 0; h < 10; ++h)
+ {
+ test.clear();
+ DLIB_TEST(test.get_total() == 1);
+
+ // fill out test with some numbers
+ unsigned long temp = ::rand()%30000 + 50000;
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ unsigned long symbol = (unsigned long)::rand()%alphabet_size;
+ test.increment_count(symbol);
+ }
+
+ // make sure all symbols have a count of at least one
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ if (test.get_count(j) == 0)
+ test.increment_count(j);
+ }
+
+ unsigned long temp_total = 0;
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ temp_total += test.get_count(j);
+ }
+ DLIB_TEST(temp_total == test.get_total());
+
+
+ unsigned long low_counts[alphabet_size];
+ unsigned long high_counts[alphabet_size];
+ // iterate over all the symbols
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ unsigned long total;
+ unsigned long count = test.get_range(j,low_counts[j],high_counts[j],total);
+ DLIB_TEST(count == test.get_count(j));
+ DLIB_TEST(count == high_counts[j] - low_counts[j]);
+
+ }
+
+
+ // make sure get_symbol() matches what get_range() told us
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ for (unsigned long k = low_counts[j]; k < high_counts[j]; ++k)
+ {
+ unsigned long symbol, low_count, high_count;
+ test.get_symbol(k,symbol,low_count,high_count);
+ DLIB_TEST(high_count - low_count == test.get_count(symbol));
+ DLIB_TEST_MSG(j == symbol,
+ "j == " << j << endl <<
+ "k == " << k << endl <<
+ "symbol == " << symbol << endl <<
+ "low_counts[j] == " << low_counts[j] << endl <<
+ "high_counts[j] == " << high_counts[j] << endl <<
+ "low_counts[symbol] == " << low_counts[symbol] << endl <<
+ "high_counts[symbol] == " << high_counts[symbol] << endl <<
+ "low_count == " << low_count << endl <<
+ "high_count == " << high_count << endl <<
+ "temp.count(j) == " << test.get_count(j)
+ );
+ DLIB_TEST_MSG(low_count == low_counts[j],
+ "symbol: " << j << "\n" <<
+ "target: " << k << "\n" <<
+ "low_count: " << low_count << "\n" <<
+ "low_counts[j]: " << low_counts[j]);
+ DLIB_TEST(high_count == high_counts[j]);
+ }
+
+ }
+
+ }
+
+
+
+ print_spinner();
+
+ for (int h = 0; h < 10; ++h)
+ {
+
+
+ test.clear();
+
+ for (unsigned long k = 0; k < alphabet_size-1; ++k)
+ {
+ counts[k] = 0;
+ }
+ counts[alphabet_size-1] = 1;
+ unsigned long total = 1;
+ unsigned long i = ::rand()%alphabet_size;
+
+ unsigned long temp = 65536;
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ test.increment_count(i);
+
+
+ if (total >= 65535)
+ {
+ total = 0;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ counts[i] >>= 1;
+ total += counts[i];
+ }
+ if (counts[alphabet_size-1] == 0)
+ {
+ ++total;
+ counts[alphabet_size-1] = 1;
+ }
+ }
+ counts[i] = counts[i] + 1;
+ ++total;
+
+ }
+
+
+ DLIB_TEST(test.get_total() == total);
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]);
+
+ if (counts[symbol] != 0)
+ {
+ DLIB_TEST(total_count == total);
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ }
+
+
+
+ test.get_symbol(target,symbol,low_count,high_count);
+
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(test.get_count(symbol) == counts[symbol]);
+
+
+
+
+
+
+
+ }
+
+ } // for (int g = 0; g < 2; ++g)
+
+
+
+
+
+
+
+
+
+
+
+
+
+ for (int g = 0; g < 2; ++g)
+ {
+ print_spinner();
+ unsigned long amount=g+1;
+ cc test(gs);
+ cc test2(gs);
+
+ DLIB_TEST(test.get_memory_usage() != 0);
+
+ const unsigned long alphabet_size = 256;
+
+
+ DLIB_TEST(test.get_total() == 1);
+
+ DLIB_TEST(test.get_count(alphabet_size-1)==1);
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ {
+ unsigned long low_count, high_count, total_count;
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 0);
+ DLIB_TEST(test.get_count(i) == 0);
+ DLIB_TEST(test.get_total() == 1);
+ }
+
+
+ bool oom = false;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ bool status = test.increment_count(i,static_cast<unsigned short>(amount));
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+ if (!status)
+ oom = true;
+
+ if (status)
+ {
+ if (i ==alphabet_size-1)
+ {
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1+amount);
+
+ DLIB_TEST(high_count == low_count+1+amount);
+ DLIB_TEST(total_count == test.get_total());
+
+
+ DLIB_TEST(test.get_count(i) == 1+amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == amount);
+
+ DLIB_TEST(high_count == low_count+amount);
+ DLIB_TEST(total_count == test.get_total());
+
+
+ DLIB_TEST(test.get_count(i) == amount);
+ }
+ if (!oom)
+ DLIB_TEST(test.get_total() == (i+1)*amount + 1);
+ }
+ }
+
+
+ oom = false;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ unsigned long temp = static_cast<unsigned long>(::rand()%40);
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ bool status = test.increment_count(i,static_cast<unsigned short>(amount));
+ if (!status)
+ oom = true;
+ if (status)
+ {
+ if (i == alphabet_size-1)
+ {
+ DLIB_TEST(test.get_count(i) == (j+1)*amount + 1 + amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_count(i) == (j+1)*amount + amount);
+ }
+ }
+ }
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+ if (!oom)
+ {
+ if (i == alphabet_size-1)
+ {
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount+1+amount);
+ DLIB_TEST(high_count-low_count == temp*amount+1+amount);
+ }
+ else
+ {
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount + amount);
+ DLIB_TEST(high_count-low_count == temp*amount + amount);
+ }
+ DLIB_TEST(total_count == test.get_total());
+
+
+ test.get_symbol(target,symbol,low_count,high_count);
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(low_count <= target);
+ DLIB_TEST(target < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ }
+
+ }
+
+ test.clear();
+
+
+ oom = false;
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ {
+ if(!test.increment_count(i))
+ oom = true;
+ unsigned long low_count = 0, high_count = 0, total_count = 0;
+
+ if (!oom)
+ {
+ DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1);
+
+ DLIB_TEST(high_count == low_count+1);
+ DLIB_TEST(total_count == test.get_total());
+
+ DLIB_TEST(test.get_count(i) == 1);
+ DLIB_TEST(test.get_total() == i+2);
+ }
+ }
+
+
+
+ unsigned long counts[alphabet_size];
+
+
+ for (int k = 0; k < 10; ++k)
+ {
+ unsigned long range = ::rand()%50000 + 2;
+
+ test.clear();
+
+ for (unsigned long i = 0; i < alphabet_size-1; ++i)
+ counts[i] = 0;
+ unsigned long total = 1;
+ counts[alphabet_size-1] = 1;
+
+
+ oom = false;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ unsigned long temp = static_cast<unsigned long>(::rand()%range);
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ if (!test.increment_count(i))
+ oom = true;
+
+
+ if (total >= 65535)
+ {
+
+ total = 0;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ counts[i] >>= 1;
+ total += counts[i];
+ }
+ if (counts[alphabet_size-1]==0)
+ {
+ counts[alphabet_size-1] = 1;
+ ++total;
+ }
+ }
+ counts[i] = counts[i] + 1;
+ ++total;
+
+
+ }
+
+
+ unsigned long temp_total = 0;
+ for (unsigned long a = 0; a < alphabet_size; ++a)
+ {
+ temp_total += test.get_count(a);
+ }
+
+ if (!oom)
+ {
+ DLIB_TEST_MSG(temp_total == test.get_total(),
+ "temp_total == " << temp_total << endl <<
+ "test.get_total() == " << test.get_total()
+ );
+
+ DLIB_TEST(test.get_count(alphabet_size-1) == counts[alphabet_size-1]);
+ DLIB_TEST_MSG(test.get_total() == total,
+ "test.get_total() == " << test.get_total() << endl <<
+ "total == " << total
+ );
+ }
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+ if (!oom)
+ {
+
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]);
+
+ if (counts[symbol] != 0)
+ {
+ DLIB_TEST(total_count == total);
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ }
+
+
+ if (target < total)
+ {
+ test.get_symbol(target,symbol,low_count,high_count);
+
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(test.get_count(symbol) == counts[symbol]);
+ }
+ }
+
+
+
+ }
+
+ }
+
+ oom = false;
+ for (unsigned long h = 0; h < 10; ++h)
+ {
+ test.clear();
+ DLIB_TEST(test.get_total() == 1);
+
+ // fill out test with some numbers
+ unsigned long temp = ::rand()%30000 + 50000;
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ unsigned long symbol = (unsigned long)::rand()%alphabet_size;
+ if (!test.increment_count(symbol))
+ oom = true;
+ }
+
+ // make sure all symbols have a count of at least one
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ if (test.get_count(j) == 0)
+ test.increment_count(j);
+ }
+
+ unsigned long temp_total = 0;
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ temp_total += test.get_count(j);
+ }
+ if (!oom)
+ DLIB_TEST(temp_total == test.get_total());
+
+
+ unsigned long low_counts[alphabet_size];
+ unsigned long high_counts[alphabet_size];
+
+ if (!oom)
+ {
+
+ // iterate over all the symbols
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ unsigned long total;
+ unsigned long count = test.get_range(j,low_counts[j],high_counts[j],total);
+ DLIB_TEST(count == test.get_count(j));
+ DLIB_TEST(count == high_counts[j] - low_counts[j]);
+
+ }
+
+
+
+
+ // make sure get_symbol() matches what get_range() told us
+ for (unsigned long j = 0; j < alphabet_size; ++j)
+ {
+ for (unsigned long k = low_counts[j]; k < high_counts[j]; ++k)
+ {
+ unsigned long symbol, low_count, high_count;
+ test.get_symbol(k,symbol,low_count,high_count);
+ DLIB_TEST(high_count - low_count == test.get_count(symbol));
+ DLIB_TEST_MSG(j == symbol,
+ "j == " << j << endl <<
+ "k == " << k << endl <<
+ "symbol == " << symbol << endl <<
+ "low_counts[j] == " << low_counts[j] << endl <<
+ "high_counts[j] == " << high_counts[j] << endl <<
+ "low_counts[symbol] == " << low_counts[symbol] << endl <<
+ "high_counts[symbol] == " << high_counts[symbol] << endl <<
+ "low_count == " << low_count << endl <<
+ "high_count == " << high_count << endl <<
+ "temp.count(j) == " << test.get_count(j)
+ );
+ DLIB_TEST_MSG(low_count == low_counts[j],
+ "symbol: " << j << "\n" <<
+ "target: " << k << "\n" <<
+ "low_count: " << low_count << "\n" <<
+ "low_counts[j]: " << low_counts[j]);
+ DLIB_TEST(high_count == high_counts[j]);
+ }
+
+ }
+ }
+
+ }
+
+
+
+
+ for (int h = 0; h < 10; ++h)
+ {
+
+
+ test.clear();
+
+ for (unsigned long k = 0; k < alphabet_size-1; ++k)
+ {
+ counts[k] = 0;
+ }
+ counts[alphabet_size-1] = 1;
+ unsigned long total = 1;
+ unsigned long i = ::rand()%alphabet_size;
+
+ unsigned long temp = 65536;
+ for (unsigned long j = 0; j < temp; ++j)
+ {
+ test.increment_count(i);
+
+
+ if (total >= 65535)
+ {
+ total = 0;
+ for (unsigned long i = 0; i < alphabet_size; ++i)
+ {
+ counts[i] >>= 1;
+ total += counts[i];
+ }
+ if (counts[alphabet_size-1] == 0)
+ {
+ ++total;
+ counts[alphabet_size-1] = 1;
+ }
+ }
+ counts[i] = counts[i] + 1;
+ ++total;
+
+ }
+
+
+ DLIB_TEST(test.get_total() == total);
+
+ unsigned long target = test.get_total()/2;
+ unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0;
+
+
+ DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]);
+
+ if (counts[symbol] != 0)
+ {
+ DLIB_TEST(total_count == total);
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ }
+
+
+
+ test.get_symbol(target,symbol,low_count,high_count);
+
+
+ DLIB_TEST(high_count <= total);
+ DLIB_TEST(low_count < high_count);
+ DLIB_TEST(high_count <= test.get_total());
+ DLIB_TEST(test.get_count(symbol) == high_count-low_count);
+ DLIB_TEST(test.get_count(symbol) == counts[symbol]);
+
+
+
+
+
+
+
+ }
+
+ } // for (int g = 0; g < 2; ++g)
+
+
+ }
+
+}
+
+#endif // DLIB_TEST_CONDITIONING_CLASs_H_
+
diff --git a/ml/dlib/dlib/test/conditioning_class_c.cpp b/ml/dlib/dlib/test/conditioning_class_c.cpp
new file mode 100644
index 000000000..4bfd9f32a
--- /dev/null
+++ b/ml/dlib/dlib/test/conditioning_class_c.cpp
@@ -0,0 +1,87 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/conditioning_class.h>
+
+#include "tester.h"
+#include "conditioning_class.h"
+
+namespace
+{
+
+
+ class conditioning_class_tester : public tester
+ {
+ public:
+ conditioning_class_tester (
+ ) :
+ tester ("test_conditioning_class_c",
+ "Runs tests on the conditioning_class checked components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_1a_c,
+ conditioning_class<2>::kernel_1a_c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_2a_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_2a_c,
+ conditioning_class<2>::kernel_2a_c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_3a_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_3a_c,
+ conditioning_class<2>::kernel_3a_c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4a_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4a_c,
+ conditioning_class<2>::kernel_4a_c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4b_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4b_c,
+ conditioning_class<2>::kernel_4b_c
+ >();
+ print_spinner();
+
+
+ dlog << LINFO << "testing kernel_4c_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4c_c,
+ conditioning_class<2>::kernel_4c_c
+ >();
+ print_spinner();
+
+ dlog << LINFO << "testing kernel_4d_c";
+ conditioning_class_kernel_test<
+ conditioning_class<256>::kernel_4d_c,
+ conditioning_class<2>::kernel_4d_c
+ >();
+ print_spinner();
+
+
+ }
+ } a;
+
+
+}
+
diff --git a/ml/dlib/dlib/test/config_reader.cpp b/ml/dlib/dlib/test/config_reader.cpp
new file mode 100644
index 000000000..20b5215f3
--- /dev/null
+++ b/ml/dlib/dlib/test/config_reader.cpp
@@ -0,0 +1,509 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/config_reader.h>
+#include <dlib/cmd_line_parser.h>
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything inside this file "private"
+// so that everything you declare will have static linkage. Thus we won't have any multiply
+// defined symbol errors coming out of the linker when we try to compile the test suite.
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.config_reader");
+
+ template <
+ typename config_reader
+ >
+ void do_the_tests (
+ config_reader& cr
+ )
+ {
+ DLIB_TEST(cr.is_key_defined("global"));
+ DLIB_TEST(cr.is_block_defined("all"));
+ DLIB_TEST(cr.is_key_defined("globalasfd") == false);
+ DLIB_TEST(cr.is_block_defined("all!") == false);
+ DLIB_TEST(cr["global"] == "hmm");
+ DLIB_TEST(cr["global2"] == "hmm2");
+
+ std_vector_c<string> blocks;
+ cr.block("all").get_blocks(blocks);
+ DLIB_TEST(blocks.size() == 4);
+ cr.block("all").block("block1").get_blocks(blocks); DLIB_TEST(blocks.size() == 0);
+ cr.block("all").block("block2").get_blocks(blocks); DLIB_TEST(blocks.size() == 0);
+ cr.block("all").block("block3").get_blocks(blocks); DLIB_TEST(blocks.size() == 0);
+ cr.block("all").block("block4").get_blocks(blocks); DLIB_TEST(blocks.size() == 0);
+
+ DLIB_TEST(cr.block("all").block("block1").is_key_defined("name"));
+ DLIB_TEST(cr.block("all").block("block2").is_key_defined("name"));
+ DLIB_TEST(cr.block("all").block("block3").is_key_defined("name"));
+ DLIB_TEST(cr.block("all").block("block4").is_key_defined("name"));
+ DLIB_TEST(cr.block("all").block("block1").is_key_defined("age"));
+ DLIB_TEST(cr.block("all").block("block2").is_key_defined("age"));
+ DLIB_TEST(cr.block("all").block("block3").is_key_defined("age"));
+ DLIB_TEST(cr.block("all").block("block4").is_key_defined("age"));
+
+ DLIB_TEST(cr.block("all").block("block1")["name"] == "davis king");
+ DLIB_TEST(cr.block("all").block("block2")["name"] == "joel");
+ DLIB_TEST(cr.block("all").block("block3")["name"] == "john");
+ DLIB_TEST(cr.block("all").block("block4")["name"] == "dude");
+ DLIB_TEST(cr.block("all").block("block1")["age"] == "24");
+ DLIB_TEST(cr.block("all").block("block2")["age"] == "24");
+ DLIB_TEST(cr.block("all").block("block3")["age"] == "24");
+ DLIB_TEST(cr.block("all").block("block4")["age"] == "53");
+
+
+ int count2 = 0;
+ cr.get_blocks(blocks);
+ DLIB_TEST(blocks.size() == 1);
+ DLIB_TEST(blocks[0] == "all");
+
+
+ DLIB_TEST(cr.block("all").is_key_defined("global") == false);
+ DLIB_TEST(cr.block("all").is_key_defined("global2") == false);
+ DLIB_TEST(cr.block("all").is_key_defined("name") == false);
+ DLIB_TEST(cr.block("all").is_key_defined("age") == false);
+
+ cr.block("all").get_blocks(blocks);
+ DLIB_TEST(blocks.size() == 4);
+ std::vector<string> temp_blocks;
+ for (unsigned long i = 0; i < blocks.size(); ++i)
+ {
+ ++count2;
+ ostringstream sout;
+ sout << "block" << count2;
+ DLIB_TEST(blocks[i] == sout.str());
+
+ cr.block("all").block(blocks[i]).get_blocks(temp_blocks);
+ DLIB_TEST(temp_blocks.size() == 0);
+
+ DLIB_TEST(cr.block("all").block(blocks[i]).is_key_defined("name"));
+ DLIB_TEST(cr.block("all").block(blocks[i]).is_key_defined("age"));
+ }
+
+
+
+ bool found_error = false;
+ try
+ {
+ cr.block("bogus_block");
+ }
+ catch (typename config_reader::config_reader_access_error& e)
+ {
+ DLIB_TEST(e.block_name == "bogus_block");
+ DLIB_TEST(e.key_name == "");
+ found_error = true;
+ }
+ DLIB_TEST(found_error);
+
+ found_error = false;
+ try
+ {
+ cr["bogus_key"];
+ }
+ catch (typename config_reader::config_reader_access_error& e)
+ {
+ DLIB_TEST(e.block_name == "");
+ DLIB_TEST(e.key_name == "bogus_key");
+ found_error = true;
+ }
+ DLIB_TEST(found_error);
+
+
+ found_error = false;
+ try
+ {
+ cr.block("all").block("block10");
+ }
+ catch (typename config_reader::config_reader_access_error& e)
+ {
+ DLIB_TEST(e.block_name == "block10");
+ DLIB_TEST(e.key_name == "");
+ found_error = true;
+ }
+ DLIB_TEST(found_error);
+
+ found_error = false;
+ try
+ {
+ cr.block("all")["msdofg"];
+ }
+ catch (typename config_reader::config_reader_access_error& e)
+ {
+ DLIB_TEST(e.block_name == "");
+ DLIB_TEST(e.key_name == "msdofg");
+ found_error = true;
+ }
+ DLIB_TEST(found_error);
+
+ }
+
+
+
+ template <
+ typename config_reader
+ >
+ void config_reader_test (
+ )
+ /*!
+ requires
+ - config_reader is an implementation of config_reader/config_reader_kernel_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on config_reader for compliance with the specs
+ !*/
+ {
+
+
+
+ ostringstream sout;
+
+ sout << "all#comment { { } \n";
+ sout << "{ \n";
+ sout << " block1 \n";
+ sout << " { \n";
+ sout << " name = davis king \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block2 \n";
+ sout << " { \n";
+ sout << " name= joel \n";
+ sout << " age =24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block3 \n";
+ sout << " { \n";
+ sout << " name = john \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " #comment \n";
+ sout << "#comment \n";
+ sout << " block4{ # comment";
+ sout << " \n";
+ sout << " name = dude \n";
+ sout << " age = 53}\n";
+ sout << " \n";
+ sout << "} \n";
+ sout << " \n";
+ sout << " \n";
+ sout << "global=hmm#comment \n";
+ sout << "global2=hmm2 \n";
+ sout << " # comment \n";
+
+ string data = sout.str();
+
+ config_reader cr2;
+ for (int i = 0; i < 3; ++i)
+ {
+ istringstream sin;
+
+ sin.clear();
+ sin.str(data);
+
+ config_reader cr(sin);
+ sin.clear();
+ sin.str(data);
+
+ cr2.load_from(sin);
+
+ do_the_tests(cr);
+ do_the_tests(cr2);
+
+ cr.clear();
+ DLIB_TEST(cr.is_key_defined("global") == false);
+ }
+
+
+ sout.clear();
+ sout.str("");
+
+ {
+ sout << "all#comment { { } \n";
+ sout << "{ \n";
+ sout << " block1 \n";
+ sout << " { \n";
+ sout << " name = davis king \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block2 \n";
+ sout << " { \n";
+ sout << " name= joel \n";
+ sout << " age =24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block3 \n";
+ sout << " {{ \n"; // error on this line
+ sout << " name = john \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " #comment \n";
+ sout << "#comment \n";
+ sout << " block4{ # comment";
+ sout << " \n";
+ sout << " name = dude \n";
+ sout << " age = 53}\n";
+ sout << " \n";
+ sout << "} \n";
+ sout << " \n";
+ sout << " \n";
+ sout << "global=hmm#comment \n";
+ sout << "global2=hmm2 \n";
+ sout << " # comment \n";
+
+ istringstream sin(sout.str());
+
+ bool error_found = false;
+ try
+ {
+ cr2.load_from(sin);
+ }
+ catch (typename config_reader::config_reader_error& e)
+ {
+ error_found = true;
+ DLIB_TEST(e.line_number == 16);
+ DLIB_TEST(e.redefinition == false);
+ }
+ DLIB_TEST(error_found);
+ }
+
+ {
+ sout.str("");
+ sout.clear();
+ sout << "all#comment { { } \n";
+ sout << "{ \n";
+ sout << " block1 \n";
+ sout << " { \n";
+ sout << " name = davis king \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block2 \n";
+ sout << " { \n";
+ sout << " name= joel \n";
+ sout << " age =24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block3 \n";
+ sout << " { \n";
+ sout << " name = john \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " #comment \n";
+ sout << "#comment \n";
+ sout << " block4{ # comment";
+ sout << " \n";
+ sout << " name = dude \n";
+ sout << " age = 53}\n";
+ sout << " \n";
+ sout << "} \n";
+ sout << " \n";
+ sout << " \n";
+ sout << "global=hmm#comment \n";
+ sout << " \n";
+ sout << "global=hmm2 \n"; // error on this line
+ sout << " # comment \n";
+
+ istringstream sin(sout.str());
+
+ bool error_found = false;
+ try
+ {
+ cr2.load_from(sin);
+ }
+ catch (typename config_reader::config_reader_error& e)
+ {
+ error_found = true;
+ DLIB_TEST_MSG(e.line_number == 31,e.line_number);
+ DLIB_TEST(e.redefinition == true);
+ }
+ DLIB_TEST(error_found);
+ }
+
+
+ {
+ sout.str("");
+ sout.clear();
+ sout << "all#comment { { } \n";
+ sout << "{ \n";
+ sout << " block1 \n";
+ sout << " { \n";
+ sout << " name = davis king \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " \n";
+ sout << " block2 \n";
+ sout << " { \n";
+ sout << " name= joel \n";
+ sout << " age =24 \n";
+ sout << " } block2{} \n"; // error on this line
+ sout << " \n";
+ sout << " block3 \n";
+ sout << " { \n";
+ sout << " name = john \n";
+ sout << " age = 24 \n";
+ sout << " } \n";
+ sout << " #comment \n";
+ sout << "#comment \n";
+ sout << " block4{ # comment";
+ sout << " \n";
+ sout << " name = dude \n";
+ sout << " age = 53}\n";
+ sout << " \n";
+ sout << "} \n";
+ sout << " \n";
+ sout << " \n";
+ sout << "global=hmm#comment \n";
+ sout << " \n";
+ sout << " # comment \n";
+
+ istringstream sin(sout.str());
+
+ bool error_found = false;
+ try
+ {
+ cr2.load_from(sin);
+ }
+ catch (typename config_reader::config_reader_error& e)
+ {
+ error_found = true;
+ DLIB_TEST_MSG(e.line_number == 13,e.line_number);
+ DLIB_TEST(e.redefinition == true);
+ }
+ DLIB_TEST(error_found);
+ }
+
+
+
+ }
+
+
+ void test_get_option()
+ {
+ const char* argv[100];
+ int argc;
+
+ // program --opt 4 -d dude
+ argv[0] = "program";
+ argv[1] = "--opt";
+ argv[2] = "4";
+ argv[3] = "-d";
+ argv[4] = "dude";
+ argc = 5;
+
+ std::ostringstream sout;
+ sout << "block#comment { { } \n";
+ sout << "{ \n";
+ sout << " opt = 5 \n";
+ sout << " a = 6 \n";
+ sout << " d = joel \n";
+ sout << " subblock {} \n";
+ sout << "} \n";
+ sout << " \n";
+ sout << " \n";
+ sout << "opt = 8 \n";
+ sout << "d = davis \n";
+ sout << "a = 50 \n";
+ sout << " # comment \n";
+
+ std::istringstream sin(sout.str());
+
+ config_reader cr(sin);
+
+ dlib::cmd_line_parser<char>::kernel_1a_c parser;
+
+ parser.add_option("opt","",1);
+ parser.add_option("d","",1);
+ parser.add_option("a","",1);
+ parser.add_option("b","",1);
+ parser.parse(argc, argv);
+
+ DLIB_TEST(get_option(cr, "d", "default") == "davis");
+ DLIB_TEST(get_option(cr, "opt", "default") == "8");
+ DLIB_TEST(get_option(cr, "opt", 1) == 8);
+ DLIB_TEST(get_option(cr, "optasdf", 1) == 1);
+ DLIB_TEST(get_option(cr, "optasdf", 1.1) == 1.1);
+ DLIB_TEST(get_option(cr.block("block"), "d", "default") == "joel");
+ DLIB_TEST(get_option(cr.block("block"), "opt", "default") == "5");
+ DLIB_TEST(get_option(cr.block("block"), "opt", 1) == 5);
+ DLIB_TEST(get_option(cr.block("block").block("subblock"), "d", "default") == "default");
+ DLIB_TEST(get_option(cr.block("block").block("subblock"), "opt", "default") == "default");
+ DLIB_TEST(get_option(cr.block("block").block("subblock"), "opt", 1) == 1);
+ DLIB_TEST(get_option(cr, "block.d", "default") == "joel");
+ DLIB_TEST(get_option(cr, "block.opt", "default") == "5");
+ DLIB_TEST(get_option(cr, "block.opt", 1) == 5);
+ DLIB_TEST(get_option(cr, "block.asdf.d", "default") == "default");
+ DLIB_TEST(get_option(cr, "block.asdf.opt", "default") == "default");
+ DLIB_TEST(get_option(cr, "block.asdf.opt", 2) == 2);
+ DLIB_TEST(get_option(cr, "block.subblock.d", "default") == "default");
+ DLIB_TEST(get_option(cr, "block.subblock.opt", "default") == "default");
+ DLIB_TEST(get_option(cr, "block.subblock.opt", 2) == 2);
+
+ DLIB_TEST(get_option(parser, "opt", 99) == 4);
+ DLIB_TEST(get_option(parser, "d", "stuff") == "dude");
+ DLIB_TEST(get_option(parser, "a", "stuff") == "stuff");
+ DLIB_TEST(get_option(parser, "a", 99) == 99);
+
+ DLIB_TEST(get_option(parser, cr, "d", "default") == "dude");
+ DLIB_TEST(get_option(cr, parser, "d", "default") == "dude");
+ DLIB_TEST(get_option(parser, cr, "a", 2) == 50);
+ DLIB_TEST(get_option(cr, parser, "a", 2) == 50);
+ DLIB_TEST(get_option(parser, cr, "opt", 2) == 4);
+ DLIB_TEST(get_option(cr, parser, "opt", 2) == 4);
+ DLIB_TEST(get_option(parser, cr, "b", 2) == 2);
+ DLIB_TEST(get_option(cr, parser, "b", 2) == 2);
+
+ DLIB_TEST(get_option(parser, cr.block("block"), "a", 2) == 6);
+ DLIB_TEST(get_option(cr.block("block"), parser, "a", 2) == 6);
+ }
+
+
+ class config_reader_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the config_reader object. When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_config_reader by passing that string to the tester constructor.
+ !*/
+ public:
+ config_reader_tester (
+ ) :
+ tester ("test_config_reader",
+ "Runs tests on the config_reader component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing config_reader";
+ print_spinner();
+ config_reader_test<config_reader>();
+
+ dlog << LINFO << "testing config_reader_thread_safe";
+ print_spinner();
+ config_reader_test<config_reader_thread_safe>();
+
+ dlog << LINFO << "testing get_option()";
+ print_spinner();
+ test_get_option();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/correlation_tracker.cpp b/ml/dlib/dlib/test/correlation_tracker.cpp
new file mode 100644
index 000000000..5ccf61c57
--- /dev/null
+++ b/ml/dlib/dlib/test/correlation_tracker.cpp
@@ -0,0 +1,955 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/image_processing.h>
+#include <vector>
+#include <sstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+#include <dlib/image_io.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.correlation_tracker");
+
+
+ class correlation_tracker_tester : public tester
+ {
+ public:
+ correlation_tracker_tester(
+ ) :
+ tester (
+ "test_correlation_tracker", // the command line argument name for this test
+ "Run tests on the correlation_tracker functions.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "perform_test()";
+
+ typedef const std::string(*frame_fn_type)();
+ // frames from examples folder
+ frame_fn_type frames[] = { &get_decoded_string_frame_000100,
+ &get_decoded_string_frame_000101,
+ &get_decoded_string_frame_000102,
+ &get_decoded_string_frame_000103
+ };
+ // correct tracking rectangles - recorded by successful runs
+ drectangle correct_rects[] = {drectangle(74, 67, 111, 152),
+ drectangle(76.025, 72.634, 112.799, 157.114),
+ drectangle(78.6849, 78.504, 115.413, 162.88),
+ drectangle(82.7572, 83.6035, 120.319, 169.895)
+ };
+ // correct update results - recorded by successful runs
+ double correct_update_results[] = { 0, 18.3077, 16.8406, 13.1716 };
+
+ correlation_tracker tracker;
+ std::istringstream sin(frames[0]());
+ array2d<unsigned char> img;
+ load_bmp(img, sin);
+ tracker.start_track(img, centered_rect(point(93, 110), 38, 86));
+ for (unsigned i = 1; i < sizeof(frames) / sizeof(frames[0]); ++i)
+ {
+ std::istringstream sin(frames[i]());
+ load_bmp(img, sin);
+
+ double res = tracker.update(img);
+ double correct_res = correct_update_results[i];
+ double res_diff = abs(correct_res - res);
+
+ drectangle pos = tracker.get_position();
+ drectangle correct_pos = correct_rects[i];
+ drectangle pos_intresect = pos.intersect(correct_pos);
+ double pos_area = pos.area();
+ double intersect_area = pos_intresect.area();
+ double rect_confidence = intersect_area / pos_area;
+
+ dlog << LINFO << "Frame #" << i << " res: " << res << " correct res: " << correct_res << " pos: " << pos
+ << " correct pos: " << correct_pos << " rect confidence: " << rect_confidence;
+
+ // small error possible due to rounding and different optimization options
+ DLIB_TEST(res_diff <= 1);
+ DLIB_TEST(rect_confidence >= 0.97);
+ print_spinner();
+ }
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'frame_000100.bmp'
+ static const std::string get_decoded_string_frame_000100()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file '..\..\examples\video_frames\frame_000100.bmp' we want to decode and return.
+ sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ";
+ sout << "fX3y/9RxoGwtXxs/KaBv3IrfvBK5WFRWzOM7H11xVin5AjxdzyzgU28fRgEasu0Jvk3SaMBNM1cI";
+ sout << "ZAK+MA+qEKwBn+rkuLbptS4sUNNC8PWaqLqNQ577TiMbsEgoa1FGkZX4feP8EfvV7w8H4FvbS+Cl";
+ sout << "yJCj0Q0Bx2fqFCDaAVtlm41VIJUFS7AePKmSnVnAlYYOQ35XxWS2sjAlWZHXiEfKmWmSOTt3x9/u";
+ sout << "5l78aPB0FvrUgF5RrvyVIOSgsC250JwINl8uulOOsykmhUUjRCw81dAITDr/d2EAYmPyiYNZWbvl";
+ sout << "0Iy4f9sOJr0HnaPdCSvojI9/xdQVT3MVjHhu82UtpJ2cuc1MjUooCZdQmTn1mksdAdZtsn78aRMw";
+ sout << "P97MKVgFmYMmG+yekaN6+nYFgTXnuEnatHDszyavB3+73iWTYiYs+CDRqEJ7XUkp+F97W/hUuvFr";
+ sout << "DwePerItRkWZN8Nmnjx0Jibv2wu37zkFVnFIAydvV9GzaFDd75hUwQE5rjGN8jp4dBMNwcK6l7kt";
+ sout << "xLm5hYF2eo9sESGJqR+8uAhGsQZESUpdBvLmf/+28MVfY7Li0lmeq+bz+JxvIa/jF5ptffuTRtZd";
+ sout << "WLbJK+EXlfeorbfh+6di5CPp5/50W5zrOepXHeDfrzCIpz6XTCw4k749ajiMk5y5XUAk/gObe1bG";
+ sout << "JZeGfMezy6bNcosolf5mmg5yDk07bK0JTLGFvYALT8aZFfUTHR8+47krp43r6XxaalqH3JFy3pYz";
+ sout << "3t1IQPC2wz9FiCGBpn+UdULDwpBt0oqbGJYGEHDWgzwv8dv31NEbqez+G+HSEfF7CWs80SH3yYzL";
+ sout << "9EjKo7ANmcYWtr8+3/SpUl0Yyzn/7zDL7191rWglr3N06+bB+bdQy7J1yGueGtTsO5XMqfUEQL5n";
+ sout << "xQ2FOU9c39cruW0Fp8TDkEKM6gowBRTelXMA1w84+L9Dynheg1iLE8MSrzvA0/TlXoFD7Wt1cC/6";
+ sout << "EoYPulG3BRy/3ugzUeiHm4ZWh95YOflbH2p8Ix+5/sBZ+W5xSr/37NRYi1nQ1jFwqPdAHx/OGvfW";
+ sout << "BKhz+PF3146JpvuVbq9fTqjqYeM+qcAu1mS+VqDyJTw3IlvRUsixyJu7/eN01m08yumXP2mm4SwC";
+ sout << "Q0j4aPKRzl4rRx5PJE4qnKxmrxSa5DavIO6ik+DKR+leGGk29zEqY+UcsqcJQLKWVy3k40wL8ouF";
+ sout << "ihEPUISSZDTE/iU/4wntqVXCP9vRNYnrnXVbh6EFFA6nr4dBrj+qnkkpDxrt3tGA0Dml1u68hNhI";
+ sout << "aI2bAca+a1TXb9mn18JTJqK6tkoJmE4JP/T5W03PSGYhRDa0ddNEt06eVPuCrSyDwQ2lQVjHiF7N";
+ sout << "iHVbLJOQ49//SmXdeId3+sJ4NwRM0jquqURHrBAYwuEImtqZDgJ9Ac7iC5FDMuigc3KJ6Y3MRAI9";
+ sout << "irrbcei2XbALqQH4hsHiOvplAaUwqE2f8iaHHQrzSKdNoRpl/0rWgEU4eWcKVl7JmmwLRWqRDg6x";
+ sout << "jYDZ9Bu3KAR3llOBPHN6G6VcVx8TfP6qHb7bYh0+lKqIv7qNYUdbVNZLnfBSrgYkrbf8LmgiUAyS";
+ sout << "d3JNuR7XsICjOPdEgC5OhKHNT8w/wcZBN1T2svgaokoYxI+dV9t84s5h+W1RgCLTnEu5Lz811RlG";
+ sout << "KEAZ+AgqPAk6PneeZs/Ujk1Q8ISZ/K6pL1SBnZN2RaN731UHMYJKK4/yCfdkqGAtFBnS4vvA82tV";
+ sout << "uv2SwPvhD2KHvZDCsAVPXFQWVVkzTWKRcSXe55vgnZjF33ziAFILy9xNklmZZ/tDpZa3I6VDkNUW";
+ sout << "o5+eLuUcJRi/dlrKnNaHSsjV6jQd0ud+JB+Gv2YfV9XcxFZDPavRGkNRdOTe045rJ7QKTI5OqZ5o";
+ sout << "DZ9Olo25TyXVteyt9CwVAlLB2owBoEmni8HFWUS1TiFWmU+DupCJw60bwE9SLhR1KpfUp32Myg2S";
+ sout << "6e7PJVndER84KdqRkPFWZyJsBTXJ6U1g273ty7MLFxZYmCp7SwvtXNrPMJ5yCS5VGa7sX8xavNkJ";
+ sout << "ade8iiSs0Upqbm71iDbMCp4oBRxa1w9zMWZ0VF+Qg4HVyKdyRmtJzBmRvUmujyYFCQwMtmzCgIu7";
+ sout << "mPCbA7EWkoteGnuFIgag9P7kAzEF6pI46gd6UyOOFF44eUytnP1AqCJYHvNexViNZ0FXzgoCzHMs";
+ sout << "m6RPhsqJnQjG+KaefJm52DCfLNOPd6+rj4mv5iQeVEyNUtdHuS3ll1svho7IKMF0KXyjajXV2X8i";
+ sout << "BaHxZ25zSlogQy//I4qdHx2SlUkiGJC8jpRZ9LVmaCWGCnbEnHn3uiiwTLRiPO0G2qgbMjgMvius";
+ sout << "XjOSu1Dlbz4A4XJ5nxRyF+7uG450OJbVF6LUYujHukrHn8LADkKQxY60sddDMcx16lofSVJEF2X8";
+ sout << "pwhXS+EcqDHQ5pT0Z7+fMFSumC6kTWJIfcwXrTJDcWslF+cnQw/qmlnSjK1jnTyU5YpjyByETfWc";
+ sout << "hpFJbt0INtrUIDjOwXha0dlbC/jLXrVbW26wTJrKmi7ftD6CJSE8cZBGLXCX4zakRRPg3DEHsv8D";
+ sout << "BFwPWlX7keygRS8L6o8teh9LHlYSbFqlno7FNqKh1V7D1ERDvsBO+6xoabo+mYXOdtyVWPWxJ0Kg";
+ sout << "1qsnFibHgmqi4RFHoSMkAcHVYj9cHhiFo3DSj8sm8l6FGBQSm+o+oehamFW6xIL9ICgm7gcBQH6N";
+ sout << "BH3f6q0CorAetVw7kyLqNiEt7G3V1DRjJ7bcn0/ziYsn6dZQ6kq9gslBWZe0dp227DpameI0TliG";
+ sout << "pZ+d5+sm6g1jE1H5H6Zlh6jJ9xDMywzNFOitOdcOL0YFuzsxKXVPA1Ye9/wYqtUyJbFvO2ZqDSbC";
+ sout << "k2z1UcbCv9KQhonO7aO6qgkaRIFmPmV82R7JsHqjA7CPQC/hURfB7qWP5mzvyB8oEqtd8LSuMgCC";
+ sout << "2WRsJTQ6GU40AoOcUFDPOo/Gv99uWzDQaGgJJzIxPlKhT3EtJWH7ZiapmojgRgdS2iEAXb4adzRQ";
+ sout << "/NsD80C6PvsuxB6NEravtrBV4nz7fXq2u1mPkjv+/jmp0hHf5SV1YDE+j+EfZqNMgEMaWJfMDbxs";
+ sout << "C70Ji0d8iv7zyevr8fIrVmTbttOxS73NShBRX7mHQnUcDZFDfLJy464v8PEvwX0gr8/ytzdMJTrN";
+ sout << "LEYHS/+ZxEN3CZyr5tzYL+DdqsTY4qFyuoQOQ0brplrGc2E1LubNhGVHXoKTsAE1+D9oq81HG4VH";
+ sout << "Y91M0A3+ckyDvQ+dTtPx+KT0on8OkZzSLy0VUJ6yHVmzlzL2R2B/C8Gp2KHKk2YJLxE8DLPRkOyz";
+ sout << "ZPrWyDUWKir87Cm6xHLScUj0MWujlg/ag1CLz38LKEBzoYxCzqQ/Zmyfp3NZRi+PEhSlSj8AyVPv";
+ sout << "FOFacvTBn+dQwWq6uLbOBXJiycohjgBGxrLc7ILTp2FHmNfpELCTCowepYKuO8VnEBGlIUJAH7+I";
+ sout << "zTtM2bvgr7k0b8dwh4SMdCBSOWPfYxuY8kVy/vcW0xKeAkrwKlVuL1+HJAtUTK0p9BOWl3bi7E4q";
+ sout << "Ur5UH6CSaVzbn7QvGh1wSKbr48XdeOx/MlqvH6yKeXxT+iogx7BdmQAMOsndXWH+T+HBbmqIF6Z8";
+ sout << "qA68Q6qcooNpHu13lqIojIbbf6VUT9eG9RqGPYXxKBAdWuRDbMmklXYoM7PI4VlmBcIgq+4U9Bps";
+ sout << "Yge/Oqf6DPFxPaqG+FnYPl7GpcXHNHAsBh5JwaDAM5uKCq4Y720KHLKqG1buziPjRwUr/5LhXyu6";
+ sout << "CF1ZzEKb6w26NDjD47myrFKRihB1zVZBxb6uIlHVcEqFZhJqC0uV/QQjRUfAQrG9G804Bcf7nuGM";
+ sout << "6u+nXr1Xl+oQIRiEYueMD1T0WbiIYIIO6lT0ICkSQQjoLQXq/8KP21Yrutv3H4PjDZXOPi9Z6Rcl";
+ sout << "40Vfdt9jyqeRgBhkhloRYFOBwXLYxlb6umBcmXoT3cv/Zlh4SfQrpG/LrfAHtYFyfwoDzXcV8Xuq";
+ sout << "qAtTkj97UcAVnfC0lNrBnWCT2SieQUl8nNLoIQl9uhzRBFHH7U9ey6uX3zop4esABGV/DL1ypLd2";
+ sout << "Yu2Q6XU+1j7wh3Rn5zttubqvYR8P5jMEqMUaX2Fith42jHy0U1NuNzX/MPCm6DFLvah/G/0sPKiP";
+ sout << "lboucbKekedV/knGV3kx+h6MmMgBRC4OkkPf3HG5ewpVpe4I3SZ/auTk1ZnePM5RL5wskdGCdbND";
+ sout << "njzEv5sKtTE1TxCAgVh6iFUtvr7c+3DNDkTilcm+ILRMwEr9vCF/OGcOT3YWr9cW/eP6fy/801bU";
+ sout << "nhotDzcKSPQbbN7jd3WIe6kaeXjsBk6LIHmPjtGPT3UdRFhEJZrnABhQrasLPeY0/+dODC5Y3GcC";
+ sout << "fV68kmensUcZ/5hFSEpLsScx/OtGvFI7hLNBAAJ6bhEuvgId7OXxQgWjQ6wKDhwOJmEB1Ra/Bztw";
+ sout << "4bGdRMl6c99nGcc9GmSWI1CmQZqbNfxL7AKO1oSkwpWbK9E9Cl+ZV5XyP28AQky4KQlYBvwUDH57";
+ sout << "LvtOvuIe/HzSagQjG8Sxf7wGF+MAz7MLRtFYGkHviL+mRvtp0NSHSYDEOqaIsnk1HbTWHUj0FUQt";
+ sout << "AD14VkX9ReVi1LnJGJNwmjeLYk7keYTkksi/kKEWCkh/8Av2Rlk7oZCwODxsAcnAyXStJV0tEAHN";
+ sout << "lq7MaAJz5zWZ4ZsETA5mn+3Qj4jIi5vbqJScxOh5KCvyOYfio9Eo+DPvyJtvyjGnqVkoFFk7v+8H";
+ sout << "e4d20Aq0vfYGKPzVNoXAtOPuA6/d/2yizrl9gEPBGVw2I3Q1/XlFMIHHZ3yT1TGJfGXcVkrYRhDY";
+ sout << "s8k6oEKbp0QQUiSkercwxbLpSMblFDaGeudPcWzKVwPE52rMEzh1/jr72s6ZUd/+Os2b3+CEg+lj";
+ sout << "AlnkU1tNZVnanmzOb5xBnk1H+3Q+FppUe9MgOK73344O7QlbBklWclKizZIoRVaOeP7WvjXSJK/A";
+ sout << "PPXXyKLMwP8eYDDPKOTA7MvBH3q4r7j7Au5av2sn1ZGwxBHLH4aYFH54kXJa5rwl0TgsTAHsQ2+5";
+ sout << "0/OWixbr8Ysl2JxunMjcewhJFI3mLHFmsYfVYDwC2+hIubW231HCxvqI5BgWZZRU+9DZ2hdQ+orw";
+ sout << "qocOI3yK5uHs6Qaxv9HPgWzNXLG4a1h2C+RbhIFMGNUcDAGrCgBCgepruLT0RLuaSYrsZY/gWRwj";
+ sout << "krk1M69/lnwvZlasFruau26RLyMAbELixpeeGpztv/tJx5fqStsjMGLHUsRj0SKgXgdawRdgUXUs";
+ sout << "Hepp9HsmJ/HR+5xAv8ORiEQekO02b/2wgqGTYXyHGT+RP9f76A90kVJ6Px4njrOLrICyTUcIFrIh";
+ sout << "0AvbAEysPnxbtYcp5tdNbG2RuarygtgscOxERIa1NaunqnBouzW88mOLeTgBV9Ofh2fvX3qalfcs";
+ sout << "xYY4//taBvyKWHiCanf5drYyOBPm1m8uzNP9ew6K6IYC6S4m8aklO+NrkoCqL2sq6Ged5y1nsWws";
+ sout << "BLeSaCakjVDU3ysh+J8YamF1Tzp+cOfh9dAquG4sk2jJmSfgl31jtG2lCBkBRQf+Cybgr7QfJYGd";
+ sout << "ZoX4RLzmSK3BXrufSNHntV7dXKCDT+2NvrW4n1EUegwonALERWRTMRDmA2NPZYgqr5cy8v0KQr0V";
+ sout << "hvEMjqP/9SPJdTtbjJY7WtdTNU5Er4KdzasYcJjphAPU3zC5PtCHbtTVUGXNz1UneXnzrN13zlK4";
+ sout << "WyJ9UcXeWpjtztobQZ/8NwXiYOndZ+/qF3BdYjHYDYhSjod4JyCnmLuy+cVG9yB3e21XVeUVC8sh";
+ sout << "7yg2oBx+9yn1ZK5ResSEAmX+m6Jq0etJWVnYvnl6TSN0XAFbZRk5n6r9w443Frw8RVRfYyISbuTm";
+ sout << "oLMaClZFf/XxxvCIQAd6+IDMzukUDvEKObp2+rbf7IWZtDUeR3VpCp3CJhTMr2UBRC68fwB3mx/n";
+ sout << "C2pAyPNX8WbZ8ZpAbtW3ax8U+yh2rH+hEK6zJSXFZk0Ea9Yc4MNhtDqGEXLgBvPX/OMH4E5wJxTP";
+ sout << "H070u7+OQWwA+Aeup//kJ+jm8EqF2RTAyGCiiYl8JXAMNBdIGaG+kzb9RzRD+YiyOrV/WyXn2pFv";
+ sout << "68NBt9sunIBnzNMLCPzT3lb4DcoOExmoOjcrziN6G2HjRsrkmyVnjvgoWO2wILYkm+yGy+ZEM1U4";
+ sout << "MpVYZgXikHOos/FR0GG9H65gUCPthPn448mtVNFvuYJhhenvuuSkUWI2IFW+rhEeOUXg+r4ZISpZ";
+ sout << "F26+EzKiDYKJWbJQNrVJv7CygE6SZ6E4VnKuXbzKHDct3O5EHoE3NDCO7J5kD7RyFE6HgUXusSgb";
+ sout << "kH2DOgJrK8KXhwcDIH4AmrPFukCmgS0/yTWOjiZKW8dmp9q9UHKQlRf1rqgmkph2fGXXBbi+AMLd";
+ sout << "qtk8gsm5Gb/a/fHIg8zOJ295ZShCqWm87z9jK5lx0/kFZKraB2JleJ0ryPFXCOp60hUbSvzzfxJ7";
+ sout << "JkH1dWoRvWr3wpXrwZucfHH11AeFRe3ShRaAKy42+CclzPkWDFLnJQt8NXSkZeuoPjcx9A7lI6UP";
+ sout << "WfwYZowJkoGDUq//TSqzSK6XAvQBUHuviclx/4/C3BvCwu40wXDadUtt+2xWJjNNvMlhnoPY3/i0";
+ sout << "yKS/BSg4TsdTceqfkt0nU5FeTdyfpLW/1TOAFERXkV2bTkEnoJ8g5LXHqj58hDZRnO45lxDXDYvX";
+ sout << "/3G7ja5OfBA+8flE+MDTtRnWAxgUBDxlFEKD2T2RBgQlLcy0Xa23LgAD3qIWYH+o5UVAmVoI5HjL";
+ sout << "mJ4FxQdvJOH7VwyIubukaioVoB63Ls1ySl1Ysk7EdgCpdiugM6I39QsFlYeHE3AlLeoHn4caWBFE";
+ sout << "lWIhgKer7kBvPFG7M6bN5flAtzAbEXFAH4M35L/6Cl+O2h+BvCJFiK1W/LZw2pkc+Ie0xC0YNWSn";
+ sout << "5nAPPxgGYtxCHLVl3R2XJvBcPpAHzbc05fAp2G4AKscWc75SZwfayKKYarX2U+zu3PVJIYbioSqD";
+ sout << "B1b3Wq4vO7MDdD7AMZjJO1ZrX3+8piEVo1MeVbqZBy3VTbjsnkkErQfyOulxlQfdcblMj2Zn9NRg";
+ sout << "EQwOnGEsWulIurqEe7bQEZldvTTz4pTOZ2W+uv8TqfnvtKgNESaqYvofF72ZFSJ4rMU+Z18Fqaep";
+ sout << "5MitEMIKwX3ttFQlgM4/CWeaygkhyNMbFA2C/T3jIopsy3bPo7AwwhHxaZIo9ghkvt0nTP80dgtU";
+ sout << "DhP2a2tg+fNUkvWFSfdP9/NKyH5Fahlm9UPNozDehlI0dzeHvqSpdYFYuTYtf1jggNVeC3r/Zfq8";
+ sout << "+B7U2qq+e4iGW1KVkOlimC3tYTRfPj4ZILtHAUIwHh4HY7pTxxpErDLpxyEZJNkBUxgN0hPLm8UN";
+ sout << "hTu3qFXzB149UpZCv/JmjmmEgLAWMT3PbBCebNVPEjfO6I59rTOvl+fVD3X2UJM1It4mh3NJPHzj";
+ sout << "4O7qLNm+S+A00bZx/g4ncCxrCvFkTpBmt21ZMAI/H4swFDlJ0gzKTq6+nhI/nbiWsuRY8z+GNqcE";
+ sout << "Snutk54hYXlQ3tVAxhRY4srAcEQMcHjX8pWrJUZjRiiYWnVfnmUf6U3rki8P25851GwTUGB6/lL+";
+ sout << "6MWKQ+s1Sa71hc5icdJtxTJSWADg41KO8n1xQ0Pu0KJzvdNMM9G4AswY87XnASsaF9EXKDTSW4q/";
+ sout << "s601Ghu8vEYKQVBNfJdiZ0wPsaLM2SvRwSS4Ji7agaJRLBvcv/cK0Hrxml20CNIGS2Q81xGS0Hdk";
+ sout << "ykMptYe+8pClFugfpJ+ETSqFgclYc0XoucOUxFh0vyg5CVE7WX6chpBmdhwWNnoyJz+WvNcoPIu9";
+ sout << "UZdYaseFLVhArb8cVaA9Mfk7tmFxaxNeXsIBFjiE4dcWKJbZXMnkhCaqG4k3SIHsswxjtj9hFoTB";
+ sout << "BzmbIwFhxPg7sZVrG7Af26CYC01P2PGqNqJnZRZQt1LwtPVQHzvMk2v0r1I3DDD6ugGlva+PrHIP";
+ sout << "D8FCY/mc0poDzTANlwAx28hkPTbtCpJIrWScZ/Vu/KYJo3F1Gk4C4CwcghgkwTYLhe2eMwlnA+Ww";
+ sout << "vOlw5SD9jca7GrTA2tkTvsUnlskDgGlkAvEwc6N5DkS6clO1XRbh1mwhr4UwRhkR9Z8sQLLUv0yt";
+ sout << "N/Wq4i6xuXuROdy78DSiSh9Gis5XQbqCvf1VKUOKkaA+/H0Y+XsrHrRCcqE5uaY0iIZtc62XgVHW";
+ sout << "QyrlDHsh4Lt9qGD93Dx2EZQqyl8KoyhQ/WbgnG/s77zdSNGTkoJDEQKIrKLRk9ReptdqQzjLPJvf";
+ sout << "GwN7wgO2N68B9XbX8hfXUmHX+G7kVnncugQzg0DS1qQ0Hbp4ibZHKEAUHtkoPEVgzGcsVowGYCqB";
+ sout << "KCGaq6FqyyIYp+UO5j00xyftvh6uta3atuu+JHkWAPGKY7uooV3MGmdcnnF8umE1NDEBfrcWrrNi";
+ sout << "IzQcnZe31Nxqo3WsknqYpRvCDwUZiU1f/EidxhykoQ1NCo1CY6ociQna6kpsM52E9ALvYlUyobN+";
+ sout << "7iJQcfagbJw6OpT+P54HFsedkIAlBgTIfZfag1u4lKZnibEZ5SEBfOGeI4pemf/ST+4a+2GalVHc";
+ sout << "4W5T3xhuJjQiIeE8X2/nEwqEPQRHdTH109EW4BCv1rlh+XWqRXQVHPOYY3lVV4ONn7/iWN0fslOG";
+ sout << "z2/MghTZb7Z1MhiWLxTcwpUXoyy/fexdBySIUf2ukM45ZPLr7T1aCrc+pW+XKfTewxqhWnKSjpVR";
+ sout << "o4doFMv+eVk8mQPFjAP98MdIdYEgTakSQPoEdHijyq31ID0lI3qDPElGdFEq/yKpd9Tg0i2OaOvk";
+ sout << "BiaBfldrXL+7jlCSTB14Vo0RCiGIGuxqmEgQhUrT2iHV1baKVWUuKERauwS3jVNz+xq6PsAU9lWn";
+ sout << "mC1WbmHOWyHpvoxE/X3X53njIVVIbRnGwEma/1THUaFIZ+qb4UVfVqBXHCiCtyAnf7+aH2/T52Pd";
+ sout << "gayp9h72ofbN/eBLbc0qXqECKcWLvwQlgiUkV5490CkjKnf49ubdFKRRH0PKCi0CR8T8lfhrHYnf";
+ sout << "j29CDEgwX3LkIPw9GXt93ua7XgOilyqxhPm5vDXFpfgR98Y4iMUyheHms8U81xbRmM7SqAPXDJsO";
+ sout << "vjoNtHmc20hgCb46BUQbMQMe/hhDjQNFMQ+HA83cvTyCqsXVdKcWIw3vPfD3RUqZ1R6LGgmLe+Xl";
+ sout << "vS2S/uDJjBohq5pJ7xCSI2+ylMf47qtS1pYd1bUXssnpJOu3kplTwqW5idkz8l01FlR+duXcn4we";
+ sout << "Wop+0ncMd5AsBpk2j868TAPVEJyC2lyQ9Qm6OouPEtMg4x0jyANl7JhyfEesTTvxIm3GNFe0fKTT";
+ sout << "cFwOojUje215zqgOqwE0b7JkipU4ssQ6j1lSHuW8lWEVKa17Wr2B77lj21JxubQJ99iifjoarWGs";
+ sout << "5NX5nxCQKJGDEPagDVmL7TZaQj+Cp34MR9mCJtRTyLmA4g2rwARMsbq1btaKXUG91484ipcu2jot";
+ sout << "/F3wLGdAhXQFWrxhIJS0ORxalPcSZ01zkECazBTIMokBrneVaX0RulGEMrs+CIGiSE6pFJKX18xs";
+ sout << "6dGqxF6TkPk3GalrQNn0al/8FgetsD8zUX+d5PdqOpb7qNqEfHAClOcNxhv341nNrGnR7YuO+r8R";
+ sout << "67gAjqnan2NY02JmVhT7xqN1Uz9jCWqmbLNAPDUWPvVLvFtHAqqp556xhpVhJchWi81mIZJ3S2L2";
+ sout << "TKd+jGHjojiYNyjQxz3DD1S6ZN4mPgzK6JMC4txgvcQ9qoWhoqBDRDrj9s3lT5kkjZnHRPO0RE+s";
+ sout << "W6m51Z6sKMT1/yoneuJEnrwLxlD3Efd/cq8zmmUKREZgqR0q7WkXC1Tfs5sPVwHDlc45p2ejbze2";
+ sout << "iDEv4ULiYrNnj/wdOSwVnNt/1DXvQ1FLtJ7K/Bt0qUszCNnNu6ZOwhOM2bqVugcOvONb6fLNOzY3";
+ sout << "WYA+2RGS0nuHC/pFuzX/AQiRveFsq3IqtYxyexv8uxmlg4sPDb7Q8i3pT5fQDYAh6CIFX7mCITvy";
+ sout << "2JnjSfiVqGdITtvsHyJtUVO0YlQKUETRBzsUQ7bvQ9gTETskkHNPe3h/VR/Dpa6ukhyspf1G9XI5";
+ sout << "bobOkgu1/3ExBOTiCjbcxgGrgXw7VdmFURpq+FAQwuNxLDSoNfwFw6ISgP80lgBDV8/5l24p517f";
+ sout << "fvNPTCus2I2A//vLxMIQqU6cz+kIeViDq0fcTxoBiD7SoISwJiRTqciiJ005uVYFAisBpU06k0NJ";
+ sout << "NQf1DfYz5JKAiP05ehMhQLCnJaHjbC3yILIfxXYu4wEV2lfLWWZ2u7/0oDBaZKNHv5JLzjQglqBr";
+ sout << "o+0GHf/hhu1EEqAyFRPWruxhJ1XzvbLt8sHc6wsxkQCYXPGfxlz5+WMrSdNP6jPoEleH7xcL/b1r";
+ sout << "mw5oP8fsoppeoxiK63Td0Ut6WtCI9grmCKOYJt9UfzTYI4TLDsI7mtofZPUvX8IeOOr8LySclV+m";
+ sout << "AcwU6BJepTxRmINnck9tCe7m5unJI8nBy+uy6a9NvSILGvuoJ6bidAqS0dojvAV9m1smC//ZJPLJ";
+ sout << "8UMVkhMUHEgx8n/Ss08rQXBFFqao1rCOvbUR1XC6g31GAju2dLm8Zyk9eomFmdOdtRTobmK4XbX4";
+ sout << "mNwCOlbKSZt1aBDN8JLYBR84cEBB4sjsPEXupIrJyKAe306c3RrnYj5lMxvnnknMqIfllkKr1BOm";
+ sout << "kBYJ91aDe10mUd7zHQwW8KAjHI4pmDhLwusleQcl0RL2Q6CgOB3x6ZaYnXLtuWsOf958+niS6dg1";
+ sout << "RA2Lv57ajSAOHzvuy7h/WV6uWhfsikjw6TEci8rQ6v5gY2DEMjrtSCbJeiJzcIqIC8bz4lvmEfiV";
+ sout << "QpZPhGDgfS73qeZV7ljrfBcjvSuN9MPbMFQfkr5v9lTNJ+/AosXAjqM6aJ4TTrMq3XAYMcbuEaDt";
+ sout << "89cI2TabaBe8B7cniD+JBOu73fndg6YglAy0Cl41GFjU0u/xq0xHIim7e9TLVtTJE47CjTWRrBEX";
+ sout << "ZAJLPDlnhevjsz+0vEuLeqGnsX5yjbmtZzpkDvh+R6eZAJiBVq9uOLHbKqplwbjU31y2OO+Gf2Sg";
+ sout << "C8nczxwSt6JT8ktG3CuGhuEIGi4l5LAjIjYd4LpQVcsUM/vP5cbzaC7/XyLmSY92KfBD8OuUL6FJ";
+ sout << "kMNyHHEtawNcKlUWsW6N7ybRpvZEwRXhQP+5Q3QXHDbiQ9YJhvnnLWHFJx1TmMlhAQlyJBnuGD9q";
+ sout << "Re5kqVi5ztXFHLY+yHAFHgCVbGq7WS3xxwQD+jeuLEnvbtNF5qUePj7u7psaPpYkQEj2DvRBoBGm";
+ sout << "eEheol5Gc8NcW8wRw0hLr3Syw/8b+1uMl7c/Sxx08yvnB7e8JNQ4kQMkxVWFCWE+OXKul5nI5mBN";
+ sout << "HRbSO5bBCuYU/9P7DvvK2U/WmcnOvNuFFax6mmBc7E8nqao9wKmpypDvddWJ+aStllc1xZt875G+";
+ sout << "xrYOX7Iv0MSfk41j+VwrawGyVYKLrssBQoft2lbeDqOsJBgxS1wFmxQXgFl5pW2ynnye7GuR9A+t";
+ sout << "pbgKYqFhYg8v2ZIYJ/SDJey5H6TImYlPNriXZv/ocn888rvlfOBjDcG9KiJ/CJDLs6RdMMI67EOE";
+ sout << "5VBRISu61OMlhv/zp2bOsyvONztlHXERM+N6zViIlmbewanZ7ujo969A8Edx1YiD7eBR9AkB2VxW";
+ sout << "nlnX2RzojsFbKIQjuE5NoFZdE49PsauvRhzYBid47XuPLu48IGlVWuX8dUENwWe8d3jUb3c8BPUv";
+ sout << "2E2TUEOcbEx/b+q5ex28LRk3CL2AolCUHMhw/aFKOj+P+AJMC2qAgzhEkVUWLfD5oHnN0gFzIr28";
+ sout << "EONuinHzkLIW/FcvxqVjNBn9O6svsz1RGQ7dOUStEjwpKDyYWboawNC12lNgh6+JO5uA9jGqmSyF";
+ sout << "uBacrq7GoETHISP5/7Mc6PkdsT3IQ7iy+suT5vTybsqlATrvJCR2/6S+M8kZCJdkIN3tVrbKqKHf";
+ sout << "O5sHBb6MgPVWotsXQyZ5m5t5x2bquGfonmlEdOcjBpYJSEM9Gd+8zvt36NgjRRr7Zgs8DU691FOE";
+ sout << "eoOOYDfc5yzjHClH/b2k5po2saebjyh6Zf2/EC5mQuqkHEeLtRbxSDSa0u5jbh80vEZ//J/3nhsr";
+ sout << "lL7Am8nXsj6aqyEFVYnomyDo9Uk+KLWD11PqLqJiCNTnxZbsCwiqZDo6qo4JngJbvdbGZBQMpgzU";
+ sout << "kj5cDYxSZLDARNLr+/M49l7esBKTtJUVtgJX2+nejrlgeNAc2BwGQPecI9Zx9v7qHOhxSsCwSrGG";
+ sout << "ZrHrTT3Pb08AKSzLQLmuijTWBRw9SmApDxTcUnSCI3RGTmqueOtpq8hGzD07OFhFU62SAdx8cW8w";
+ sout << "oB/FAeZGOpIHY3hi47JXjZGTw3oGMtUkbq6RN6wQ5NJzij/HHjLzXyXKpd6CveDuxM24zzyvhEwY";
+ sout << "Q3FNWhhgi8RcIFb5VutGDNZ4cd+jqf/veLh9Q1110CuDpR7MUUWm+MlG//rMcScEZDOZORL/qMuw";
+ sout << "+aKT4+x2LXKb8G5uWwNovb3eZOJDTnb7Mvp4RfFql4X38BUjUgjZEAO5QXRPv+OfUi3PeeZBDQO9";
+ sout << "yjqeCNsecLNlu4cf222h7Et1rdp8zYPwMbdfLrI3pH69VxM5eSJ81OtprdbsO36svGkPOOB8/rNq";
+ sout << "Lu/xFCLKvWnHCPoGT5ng1vmcUl6Lkz3r+yD3+Uy0X96++sXHnOVdiWxKyRhew+3/gSQuM6zoPumc";
+ sout << "xzAkEjOiKghShOmqPKQzrsEUi+sZ3NnCBwRyUwlaFw7/Fp1pi/N/9bCX30AEfaD7ucsxJGtC78W1";
+ sout << "jWxi5K4qPpgaM+qU9VORlnJRFpYFBgu/nqvd8LkFrLkLsFDxkEO7bQDjEKX6xTltM7nbF/KVSasE";
+ sout << "OXiMk4x9aOa8buelZrRiANp56cidcnJ8Ayv1GGAaA+hNYC+BNmTXZUal3NyMvxJOzFlKfaSUbgNB";
+ sout << "EsYQusXQuVqkD6lW2Odx1Lt3hHzF0MYgMY3cuw5oeMKzarQoA97JDf74Lnp63MeuU+pjBanY5pNF";
+ sout << "9D3/l1UIsrFh2ZyC0vS4i+5SGDI0Fza9ZsopX14fbUOG5FFSkrHPL1ArMvIXAG7B3OPSRF9ChX1u";
+ sout << "uPCaRK+0JYxLdACdJB4x+Mhnv3fS8w5dDetUO2zy6swM2OonEXqln7lETWsBk5yQpXBo9sR8RamA";
+ sout << "nt4HQhwaK0IAX4L6p2LMKJ27YUq2F3Xe2PWf9gMXdrUQPTvpq+MdEnDIJYf7sTkT+lyqLJJZap2y";
+ sout << "jmuv5U6uQyaeJ2ayTd2TcvryYtrYg4dUEfTvjdlE1QdJPxUPCzjbkM/y9/SrfPCXsWVlDZDtytup";
+ sout << "tSK5GBlCH8JOpiiwgyBR5nAHNZ8SKP/e0K2Ps+mDx6xcSz4WL6loJUN5+lPuMHYxXbT53LPOqwNX";
+ sout << "nS+OCDPFclYZO/0TxEerFNvxvjXLU9VxB87zLZRV24VAztABk3d5zrY48iSZ8WDUudPF0mf39geE";
+ sout << "4j9zhG2/3+rc1yYFhAQFN0ch0G+Gu/mEEkSkxSZy/PK/v+ITne9JIqao8JfEVgwmQY2VCSyy4Rj/";
+ sout << "KHb5SOY6CDIWIRMRj7GsqxI=";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'frame_000101.bmp'
+ static const std::string get_decoded_string_frame_000101()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file '..\..\examples\video_frames\frame_000101.bmp' we want to decode and return.
+ sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ";
+ sout << "fX3y/9RxoGs2WfSxuQ88PN3jZemvW7PfY1s1vwY1Govp+hD6LGyg6qbJF5iSr2edUHLWGq/00D7J";
+ sout << "dhKHu/QsLpEM3w0yjnufc6hORX4GNjJ2a/k5xFnHf29G4ztGVp6OAlsQa4tcLsh3cOK/eyCjevG6";
+ sout << "7GaMl/sfkK4TsjNfqdXM8sWxBCw/r6JTLrvlzVxDTmu6Sm2MpQ+IrRXFn4gK5vwPmc/mEp6etj4w";
+ sout << "YeFI6w2jGVTFIg3R3ixaPUtvv+gaIOtPsF+4StbI3AlmD1kDBt0t2xEyPOsmjOBuh+oy/si3xWt7";
+ sout << "RlCBakbDZnIihbzDMrLwTJvgoiQDgT3U4JRpH+5cL8E/3M5JAOuKWJU9nfdtNVIf1NuHCd1Ajsvt";
+ sout << "LjRLI3+rqQErHsKre5sId70rx+sYSqTWaiiGi0Nwg+JdCRuA1OwFmK4rahCAEjFr55py0iZxFuo/";
+ sout << "r9DLOVGqKKIWDX2HQQrQ8N1gATa5Wh9ypd1RWKz0s1wghRS8nzIpFqtvXxJxYz39gjai7qUOP25M";
+ sout << "zEL4RlnFsqP7GJaT7YnpV4PLkKAzYoq2SMTq4nlK2MGohKoPNj1JCfVcq3syGMJNYtNNOtN1m/E2";
+ sout << "3AKUnpPX0CAEgA43Y5yMLR5ca4+KGlvO/qWwex/bcU6uZSc6uEh2pcVaFvlnyUYdZlBg5huymL56";
+ sout << "HojkUQWiv48vxAoYtiEc2xsgsD/HeJxF7z+kWJ5nDeyxLhCLpDZfrc0g76TVEOVy68CtdGLuhDe0";
+ sout << "dhFhj+tlkmbNse+X7PlhhnI3Oq8kt8tFJly0YH9RXYd2eQDSK3Wz88jGhe9wcBEwcFir1foxdRa/";
+ sout << "rj6qz9DeAo6bVEZkIEXPSXL9EmiyTxwqjHgIySK1il+csxW2Gpwb1hwlKQWr2x7bRf826W/DGplc";
+ sout << "XNLAy0HLzQh+8E1iJZhpCwYthe751RQLrYaBxlFNZtlHDKIfuGg6ByiFdr8y3T9dx1e4xJHHAEhE";
+ sout << "Aem48B79Zp5xD/ZZOgAmJx/hrXJCnYXdH1/5FNhTyySNQKG2fChxHwALxDoOsB5BwFP77RkEuXkO";
+ sout << "QWLOWFUtu8XYAvf2re0ZafPDonS/GoPZkMBWo4S6KNoczfDskSXQ4PqKAW2stt1K3BoxO5LXfQYG";
+ sout << "qepdQBcZfL1Ir2WmA0L2HC3UmFr4bPwVkzaYCs4uQJqivpw1lD7CcNKi4BfBqWFTs0uWXQrwY1Yv";
+ sout << "rmhEavMquYjJRdo2QnufXkSjrW8VpVEthiQK895ABfggCWg4F3yiqBlCR+aGcCetNoJZ9pS3sb3Z";
+ sout << "32tmtN4rrQ4jMb02QNnXPQ+vAavQaWok/dELaGYCG7PefjjcFqX48WNVmHQx+wLW4m0xo0mqXSxz";
+ sout << "SQakpCGETfxuzyIlmYVRh41QcDdd/NH/CIwIMFVyga44Ewpcuew5mbz7bXP3k+FlxzXJ4e0ABVn9";
+ sout << "hO2sOG1V8teDymKMNPQOBuec5AenB+kchCtCDL8ZtglpBBqmijUCegyxiVpjnCzdGQvIq1XdfVmg";
+ sout << "e/hWiXT1Vi0+1wtdL3CayA1h+zvUiIJspWD5QOHj8O6buKKWky+V2UZDPK24AuIl125ajH481E6H";
+ sout << "sjS0/FqKjHWfw/SayHcCdbqIt6/8AouwlkzGOK4C14HxAg1jmvGnRn+46SCXNjlHBDH8/mJ0Z1uw";
+ sout << "C2Fv6G6t/IX1hcjJ8NBQOY/3gXAtDRmkNUc3KImzaKNDUkS9N8U5QLFVqrBQ9Dw2kjJfc28QJcXy";
+ sout << "OOItYP1RQmG942DKVwW+QQV/fvK8BFpRP5yuNgmzjGZI8gP5cqIOkW+D4eBFBb8PYsa15fOF9AWB";
+ sout << "RXDztTbO3uXWcTpa5w/dSbawwBqfeOY5wh3lnYd16Q+Yorn9D0O9VGFWr2JqKwl/NcRtH5FFsNC4";
+ sout << "LpqcX12SoO3gsql2isJzJlyTmeSDB5C7o0mBE23CsBx9uqDwBJn+YAucIFxB6I4sxgF6EOilkA0B";
+ sout << "5CpLkspGnHs66pvrLeByoDMCpkt56bXRlcvFhjRo5nljY+GjBh6Ux3ZJnMdUG4Qded7oVaY65JVX";
+ sout << "MpoHUUG/VSR8FWuHUP+SXh71QiEjsrAAVQ+Fm+fAH3xQ9Vjiv1ffRvxkb72G9O/LjF6NWWgowFcj";
+ sout << "QBhuD/mTI2q+6spFLWlZk9PmCWTBkQGsWouyZMuVQqr56sP0UEKANEGwpq+h3pq9t4+eN8WcGhyd";
+ sout << "kKgOgUvt8aY7Ib1HdckQt9cbtOrYyjxrQ7o7CXR8eg1SLPDJoWc05YevckkECeF4nQcZvoRxOlEV";
+ sout << "rQwpnngQ9oNoREpNnGlZGtBr97bvWTw3UWwiEkqDJWTTbgoDSzjEroYV3RjVyVKu7BZRTvWDFz+/";
+ sout << "xohrwbZ/NFA8wGZpgm0qtJRMCyDUIAJuQSEjxa925hWfdkVJxhExxXgWp7NQzhawF7/16Fa1tyMK";
+ sout << "62nnOYWV/OWcNOrRCt96RlQooSf0DbUqtLt0f6AzkjKobn3qpkaMxg4Fwzru+BaLmId2Cnsqdlvu";
+ sout << "jCjsgpY9cfv7NdcnPvhLSPyykMPmBRe2XZycAzmQdpY04DpVI/A7gpkVyH3IvYP0ZC4SXNwt3Iht";
+ sout << "CnmDh+VF9+iVjlaG/JJ2cSXG1mDSfjPgQqEO2DtotkjN3cHmqqSkLqbmA2PMnbzCn5FJbdhES/C7";
+ sout << "uMRFLhsEyH7BF/UGHCyIkzbGQhF1q3PE1PP7A6UOtWr9tmKvctlDnsiEpGvxBN/8wfDERqImT7Ex";
+ sout << "f/0fF7UbHD1+lJc9VzXKeV2ZsFhFRp7r89yVcejSwiAcgHenFuvMEtZ3lgRqy7mFxwN7KNAh3O6Y";
+ sout << "RqGFSz9w0gb2gmX/1f7kVvoKTgibMHYXyHUSvusrjX9betzQ4NysjILBta51/ZbeBIZPF8V7mhNE";
+ sout << "5c6iBo4j1i+jEzBGCbH4jGwNTRoD2S+uWFrDS2x3hYJL6uBoEWHN0oG5f7H3I82zkju7jkuUCM7V";
+ sout << "bd02Suu/hjwDlr4YaQUTIo2aThnXT7ZgCaNhkWODTVU3BqX6M8WP3zFZrMWGtbuXrE6WdweK5ENk";
+ sout << "+aV+0CDcPHertb4LmJCOY3yXtPe15OsDHwedxxBJCt4y1UdamFtFo57XjHeAr5c17B/MWQ82/ZlW";
+ sout << "1Rwrm+gNysxVGgfmzvmT0hpIiE//g/HqI737ovjOo6yRfkjil8AqIlOT8Gw8mBbnME2OyjX/TPOw";
+ sout << "wX4F52tAhs3oi9HmaKHW//GlC7ILPqFN1Yg1azmvTgqVeYq3l76JbgZwNmPUEM5z/EfOcWu1GKT1";
+ sout << "Pc0Hz6WcPYYyANIlmv32CkBiZQwMi5AIhrUJvpGaD16hPW8f5lXO+oG3FSJlhZhiIQVKFX/8UqZk";
+ sout << "AzFhmyPjUHNHOsJo+1c+5FTnPsjHDzYwW5fUQRSz3L3/rv6AY7g0BXKLmxCJYzIwxyQ5i7z810d0";
+ sout << "jTPy1DANcssoEqdA0P2u6udO3d0rtL1AY7dKoH5QYyKrOjx+HkLZcfxX2lpjp+kbq4uy89JYfo25";
+ sout << "M4tX0B45FKrUDDtIPKR9GRfb6VgsH2bi7aediwVXCoaTxPNgUyKaFWSf+y9gPBiD060TccAZfCqQ";
+ sout << "HqI2fQUwEfsiv1XeTLGH8mNIHETMKP78LdWLhATM3ejl8aGcNxsRAbMXm9JgJmuHxaNViUdWmfPU";
+ sout << "hJMksO1JBS66hcn0jBJhZEYF8qUwG855G65k8Vvo8tihDIbJhxGMDByP1beBLs6uoohknWTtDU0K";
+ sout << "Q70nziEj+VDbaexOU+dF/b/ywH7Q+KUAFvHH/xObyc0/maoRS/e0pfyhmB4OeGsbWAs9aGQNIaAn";
+ sout << "Toh1zF+xQBFMKGurc3P4aPvh3Vo/cw2hQTMYJpUJQvhYx8XAnA+hRtnXZtHD2Jas88G0JXCCX5EG";
+ sout << "9te/j6MEP9Xwb/TjVLDGOc4IeMYqVftVFcmL9hrX6nk3TYDyUkvLjgn8vXEmJ7qB8CthBF7UUyjU";
+ sout << "mcz3agjf1/6xrIotky+zbEAy45LvmZkcTfVVJt5nVyir5hvcGyXWUOpqp/9NQr0ClEVhzfD2d1cd";
+ sout << "xSSLaeTKsr0jTWFV9GDeIhprg4lJhxL/HbxFu3iOocrysDTTxsbLSTjO33ndrGiz08uW96K1QY2F";
+ sout << "lqmK8FOhR8C1eaUdUFfIa/cINoThxQRkoYq3tUr313+VzGTRO0I2SSJMIHgFBslLDHwX5AbkpJPs";
+ sout << "7jCQESWTp+LDoN6g2X+RmiLlPQiU7iJXxF7Y2SU09kmHOg7HFPKWZQ8bI+m6CG7kvBF9lWNuzfr0";
+ sout << "lWz08lEMExIEAAngM7G0LqqOuLJO+Dpk+lLWjO5OICoxQ/M7akvSJHtc2mVOhXNskHOYG7ZtEEEd";
+ sout << "ceDN4bzzJ6qEuYMljZcKkh3NduZbETGnW7D7Ec/UqU6WNX01iMGt+4lCWtu8NHySWGqIXcX+8ITJ";
+ sout << "euDFM4B86RSKk38VpVUXXLWd9GPUYa795WcoVxHlFRPULnHJK0G2/AUuKe/K3CxHNqnxsk1aGdxS";
+ sout << "NGVyQfhXJBWBA22f0Eclu/Z/UdzpKZjCZFFZWt/4ppIWGyMvxheOSqjA55keu3QoFj+xhE0TUNlv";
+ sout << "cYMT/b9kHJnHQ+j9X7ReicHMNdtWYmzybolMFV5P8fBwlReR6BS7nAk7hu37K4fpKoKujfMUfE76";
+ sout << "PYTfRoTxKKY0atJg87ZhID745jym89xIqFqex1vqb3Ysw0+dZDWHbwiidUZjHza33U07hl66qo7M";
+ sout << "iHWjOPzrMEhad9fsvLGQkA3WNjj4fjx5TLN0mfieIPgAjrQvZhCO9DspEWW+jwRP63BYUU3rB2sV";
+ sout << "oEJfsyRSoam9ujGE2LF6ePYZhOOs41OHGsbUweduJ47XGdt8Z8+wxnZ0ykwvxc4eVsNbNERVJ0pz";
+ sout << "5hULsfB5BCs0jc2cz6N+1MLS397qKGNwnim8OBuU90wU0vMy+QWF78OnVb9jlg3asOK5riTvCWMb";
+ sout << "qgj1qPo0GfWRojb6L2JaKiQlU7r5LiJSmmCQLdy7qnbtc2ul9kKn7iQZqxZVlHJklsrsOCdA3/Lb";
+ sout << "kzpWUbCp21oyQKgp8PRkMv4NhWt8JkushMklTjuvDhvqVTqrzFQE/UJoxWEBEjeZWRgTG9gYISJz";
+ sout << "9mnT+89BXF+oSahu0YO8TKgNMqIDQ0d5GVrc/lakm/jQpZ3nl2tVbN5vLMtCuwbkeYRiFuaQKt8n";
+ sout << "+0HpFoECCJZdk6vyVFpIHNPSP7bfnsrZJHRiprvhz41LowOajR20mGW/+mlqo5K5eVxrW4I2Uumc";
+ sout << "Xxp+DK5yhhIdYxY1SkRYi4CynSNUPLqoD3RaTFKfo+aGwd+N1abtMJpWmE/Xp5k3NNHXWi95ltoF";
+ sout << "tU4BxGQWWDjG3UB9t6eUDoV/WXwCy5pxs4rbXLb2O9CM2HaBC+lDaW9RjxpOjI1jvAXmcQrs/MeT";
+ sout << "ex8n5RxFcyRzjbaWhd/V4vZ5+qY6eLoT/cUpYVOp0w3lQEBaGz0W8bcetjroDYGSV1U+3nBgcKPb";
+ sout << "rW1wp9l+x/ihlfM4yHdApD1WUTYfdIgnO6YAw3tlm0UxARb3WxyVjIQ7JMr1Xgle4UbFSWrjI8pD";
+ sout << "U3arZPJzf7oN1eFDdS/V4fZ7I/B+j580QOmGI5LGvSBVR3Ic2iVSbZuk2mjWgK/YW4vg3kp7LjZe";
+ sout << "yIxulHxWGysi//Oe7bqcJsM8Jndw3sId9UYV9r9GaBGBf+E2YWibCRJCiDjzO4PizCVOnTGJlOMI";
+ sout << "Y7BvFi0d0eLHGAfIm+3d1bUqD1XcPblx0IYVa5AHY90JfMS7gBJHmJiY9Gv9+WtQmLZmNkU6Tcy6";
+ sout << "yNyIMaIfI3UO03ApmhM9babcNpX8xhxYotZXzuxf19nnSdyOebzl40qU4BPzwwQmQMISK35A80xS";
+ sout << "dvnVDMbcRQIC4fjwo/OKNWtk2oLY9ERJ6ixMMxG532tF12jN4ZbWn1X1HGFNXvSU2bFhuOLMSgDC";
+ sout << "1poVi4Cjjgd4n49BJmlkeBoL88z83sPikchLnW1kxHubG/0VBwyHYBmocPrK46PL1rx1+dSwc+Da";
+ sout << "3E+edMWBVe/ca2Z2lBDFIvo29fjwwumf5Ljg/M2YybU9jZC4vLR9VcdPJ9J/gi53iNOktmDcBcp8";
+ sout << "ThCOSXuaDyKBtQ6nvd3AUGdNdLADPuNXTRNybM+Lqc4k7cApgM6DZV3Elpz38473d3HDAW4pCOD3";
+ sout << "K/y2pYyog+/27OSOW9SGoVkOuqOBKwqrFAuD1jIbI/yDq/LYajcJIPqhkUv3srTxHkPiOBbFy4pI";
+ sout << "6NAftTHf4GGy8VTGIzeD+z7L1qJToCogS+FoBG8ixRGUYKvUYEQyc1EhXdoPOYY6pZiMsA4xSGyN";
+ sout << "aU9RrfweZH3ld9QU9Y2kjSqVOhQavzIPtD2wQBIDWxk3/bmH1m76qrcEfY9WKCb7Sl//1oILVf/N";
+ sout << "b0/TZEqVSAFOMoTzTXO1ClXymBTA9b1bJKlQL/8DRayUN0NUMllrwHT1PGOmpoJ+AyhUOjEmWREv";
+ sout << "mKkaxiRkpyLdgKyphJhtYGid41FAAKacNN4CMl9W+fZnydgR0SDvYvpOwveSXr66xfDZlQti8ZR5";
+ sout << "i8Il6sq3+2ybJj/oakohjWPxMAA2rEvsOYkuUbviTAYQOM3jMXAVPIkgAYgwwXhvTpHjeRurdxho";
+ sout << "5dz5zGTwSRsxCUn6QSTtt4DhroBs4xBCmYa8BFsUAzG9nG+SP+ejIgc6HsrLCxZ9Orlgn5nkX26K";
+ sout << "bI2dWHqUL8DIgM9OjrXeWath283tUxYQlnQ8XJ2/IsmRkbq7eNJfBATvOmvk6CTJJdGrK3iX/xaI";
+ sout << "WTJ8J4qY19ehkQoD8fLA0W9FitmfQyLGzxFSI4b3YOcl5KQUvSDjPpG053Ok4LhcQCILRhoHk5jl";
+ sout << "usMw7nWyw97HV8XHrQkjsGyvkT0G0LyeNV5iYo6o8yFtF/s44/mfI3sh+b5AzhdYOEyRcnPSlr7y";
+ sout << "axPlZnAfZRfqccQnEricrTtaPvYaJIh/dbtnlM5crcub0tJ5WNWAlLdS6IShjiIivEzqakfqg1GD";
+ sout << "w4TzFl4VeqaTR7JWeGT+yyRwTfqkIEFWp4/7DyRSXF1hbiUifVaBwy4mpar2tHuSL/eknEI2Hd9A";
+ sout << "oIFg0F1y0Jg4uAewz1MvmFsfoHpVAT8ejEF5OdFjkeUL8fjSBlX2SwMNaV07UlWkDDGckuHch6r0";
+ sout << "FIDtx71pldsL2ALaQgKk+85QC8YqJdzgtfK/lRT0MDYOh4NlCZoFqQAImSYqasRtsyeNYWVi7Hh5";
+ sout << "HUcaH4dpUrWFu8HFXN20rzWwkplYAGTeG1/S6T/sBXzAcv78VXwd8ec3+Soa/ZZNXY+yWuHziyrP";
+ sout << "K7FFd/fcYHML+2bB+E2Trm4MVpzv7n0a+Kuh9sy0SVe0IuVJMo68R2Ftl7tGgF+dTKVKmEHLHqcR";
+ sout << "fSjxUAto2wxtM3Lormd/1Yz3loH8CL83JrK1f699TQTZee6voQPJvlvKJlctr3NeWB0uwk3TIACf";
+ sout << "198pytEpHGhlLByCzpwV68aPjs9EV5yekrCJzEp0arZSIzfI/cgwYAZ35Ukff/bp6jxg6AB6yUL8";
+ sout << "Muq0rqGDqwpfTfTDaeBJAhwevEMyapDwrbzwHhnJLi7ZT5l67Q6Xjo/A8/U60t9m2Sdm5ULwzBdo";
+ sout << "KdSBd9S1UUueBxrx109Yh1gjvbk6k6d8L/Tuqjue77ZCUV5mOJ/rmuTi6OEDIKdzZuBgooisaIg+";
+ sout << "CargeAu0U/JS54e+b+Emr/56cogvgeJo2QeIKT5yhqHxtoEvHjhA+7q8MGuosAeQtp/9yYt5kPiM";
+ sout << "j0d5GjKTviHQqgISzuskcAWAhjIfXSyVrFhiL2tZ9hqS0u3juXdBoUy3nXcx9WF1tLkyxONcILs9";
+ sout << "+dOGRg+VpsTcjRqQyJoKm6OzVoN5J8iWKFkdZGLm0IM+p7F+jRasFIgJd4iPjHMsxYBlFX/aNcxT";
+ sout << "dt1W97L7uw42LMYVG+w3wnHdkO/ddrsp8DemUQzsT7yGDOhg5LMBqK+Lh1tcizQrb73NDpqjGmeh";
+ sout << "wmSb9GCzxDxw9uxfaZBBbO7M4KXJiANLFF86djhSpYgmApyhO1QcTSl2UR4XptqikYbPGFFQzKIC";
+ sout << "qjvmsXWnECWWSZK6pnNnHWZOMto8VUmPx64HIPe1hW7KeRM4ra7J5Imw6F5APG4+Jg8IUf/sH86g";
+ sout << "v3F35nnPnYeVvkQuYP3iLNapQqaKR+pQizKxXj/wgPagwSRrzlSDbKM7KqwbvdBGVMGfeW4PTpuP";
+ sout << "FxtmMKzWMWAGUtPSBHktXyt2yn/Gp7HVfirM1FYaIOsWvNPtBfsQ7HGA3dvKbP9f6ZvRSonW/+kU";
+ sout << "S+jp/UP+sIfihswVxrP8TO0YqQCwz6+X3nCT1pW7sAJKggb+DZyEisyl0jwK8fVttFraIdhqO2Ko";
+ sout << "o2Mfc8+V5jVTIC/VPtuiK0Wjv0+Eov2U1UgvFM0jsGIPfVZV5UDra0rWjp/vnzo5C9MBy/MxfCyZ";
+ sout << "5tC8AnEXpI6V5JSHSb8xRAHWZg4HUMmoZG9u4mX+fW9Lc1OET8xVMzfN68kndi25/bnfx6SpR6f1";
+ sout << "30Nhy8rd93qy7DpYUHhLLyhBgujuJiT4scogc2iDNwP/Fsam9RR/EeVjKv8gwDHrDSbbUbRbzfUD";
+ sout << "fM7oph68ce4/sQ0rCh1WbtPrSQWbNU225oOYZr1lgrfxlYnI6ROd1M8nWS623ofggs4Wfh5CMYXJ";
+ sout << "OYe0XlSNliVnU1CX4MXuX8dIXaEmU2HUjfKWif6YAQY9eSLikPWVgEYvpthn6dur/TP1jC+W3a/Q";
+ sout << "vigZwfynqIl/NC8Pe7tpm2qe/K6TcOgru5ojdaPeyLUsVx/wqvvVlT8ZOgVV2vKOXAUQ8bLXAH6N";
+ sout << "kccB9Aw6NLtXi63VjdwUrgfBwT3orFckTuYA4u60vxs6e2ocbkJa8YsPyN71Q13O7t5q7aYP3O4H";
+ sout << "JbDKncX6fZddc26LFxgRQem44nxcab0yoiP6H3cp8mHB86+vkBuSQHWGSWN5uNK03BP0rlUntMT8";
+ sout << "zis2iDvOY2gDKYHPj57HszSQ8/SZx7MFKiFRTIKMK5P39lNz0ylKhskAfZ4KDzLp6xzwjwxDsy70";
+ sout << "LQatqB5ojtaMcglatNYvdU48iHz7T/KPIU7fs/vySYdx4EHZ+Oei/dKFvpdK+y9lFyRJUoBXpb/L";
+ sout << "LCs202cmGv01lonaN1Qe0QEL8OTPqvPwGb/rqsf6CNobSZd7mmiMwlQ4dARKyk3PWgYvalT7nm7w";
+ sout << "aZvPlDUnws05FXtATiepdeNlffsCK+G/TEzH4vxOzbFsNwhwqL2tVf0t+1UJwD+NQaF4p0mSoA8B";
+ sout << "TZCslymygnhgF9Tu0jHPigJ99Nj36RhdYkndHiry2OlwpTACuBv2CFv9F8SDhfwuRw4sr3xQpomb";
+ sout << "vpTNyGrXXmSVrsTnGP506nCBg4IQOi5zXA6MgVNXiy62mh6u0lceyZXfbvCKkGJi6wUj05DdrCvY";
+ sout << "3TrTH7DxcwTZ01O1e2A+6A4v7u5/GyN2vEQb/p+lSZ9XAWYLywIId2pFMJ6nDuD4HKuoVljrA4+A";
+ sout << "wrwACO6QAn0H15I4nIBRzxgEhL+/U9Cmhmudtl5iKeZCifsTrIda80CJ9rTl/w6knmll4/9rLceK";
+ sout << "QmMKBXjd4sdpcDh8YNAUa4lWNsgL7ElkbTzmZlPe9ScYHAfcJGEtwrOuYDp5LfJOh76nM/LbGAkF";
+ sout << "qR596/Sqf23C+pfe4JH2lerREvbvkx+N9b6lRHEEU6tN5qx4AVPUlVDi2/1qo4wcIFJEKYS4bfKc";
+ sout << "jZdriM9msBQPzO1GDoVTfUxcY7vVtkVSjVONgpjlM4Bsr8YdqtRH7xAEuMFG5y97X5PK6KkIAqAc";
+ sout << "fXc8BSroGGBcXmIx5dAh/U6oQAu8gOe8mSvy84/dz2sFbioZSlVWR8asbFLDjrfHHpo4d4EX3t1D";
+ sout << "XsGHOG0QPuj5/lH2QZerKkfyMZiD7rSJTnpf8oUyY+gw/e0aZMHk510G4ybRKgY+9tCGnu/A3cJc";
+ sout << "5wi4HcD1PUKmZwKZI2PIfoIv/VUYQSZt7U26rf0wUZEMcuUldIxiU583HZ58fUGjmNhi5uWdDeIF";
+ sout << "pgIvY8tvmwUNT3W7H3iw+idsfn1w1CjWmWiCRo+yjAB+TiOxHD4JCCNAcM5PCWDa/NXLBA2tO1h9";
+ sout << "Tk8h96lBlAtu4IQVJaB6+o7iPk2GNq6FAiFQwYC2F/Oh5A3wKWrd2kqQ8JeZtNdk4L2+eHg3aDAU";
+ sout << "4aXaYzfxDSOYHVHLgDj4VY6ulSZx9DBst5UPsHbaLaANf98uIZuji2hXOcg/FEnchb7A5bb+UIyV";
+ sout << "jJH9d2D4R80e5GwPF/vVZK2GQr7h/pEgdGTh6hr8piMsR5fzJpEI9dS2jTJnet4FtnhDo90fImU7";
+ sout << "O07J4AT+uO6k0B+lKuc5QlmaXg5Hm3hCqq5CggvnHNKWHV4f1ux0GNn19RiXwX18EdkxSEDIX1W9";
+ sout << "CmdCn7gQtqznt3pNq3apLeITu3phYapWi3FlHhxoxOE8gKjF204FYOBeKztMqCqSlYa3ALQiac1W";
+ sout << "bk5HyFaFvWEG5Rb2gs7bHSgHuDFlJwErbiDfHM8415fDkbWAOMJFTDh3YY/tuNz4vU4bNC7TcJHn";
+ sout << "o5gak21470Iy5oZ9WHOgQm07tsFNxyJg/98rl1fQ0lcjNF+2nrhj8O3bXRpgt0eh9AQeu6QVTC9A";
+ sout << "d7DzZeQ7N3HCWJl+VbwFCW8Jyo5umAkO5qFn+c15SO7QZ0KEjEORTiUnBDXzAtQjDHURr4yEjtiS";
+ sout << "LkLTQsXwiwCmBOnoTdEPegCHIEeVSIye1t0x+dR2DeqK3WA/hB8mFawdIlhtmdEOZXov1mutXVDB";
+ sout << "LBN80uBR/EypdQyctL5IR7IZIS8p1vG9kY2yceXlszfRZ/t1H3jx3sRSJEdVoMOP6Es4IpKj5uw4";
+ sout << "agZ/POEGbn9qKdL1N2MM0R3y6D6YjIXhQF6QO3XsYlzU/OeOWxoemfC3/qn4yxwbaYyEWjICn/9d";
+ sout << "e+z3q8F3Kh2JC5VsuIVaKZ7MqXxvtn9SBuUQ2vsx/VqH/LyRi9glQZom7H7XYefzQ6neF1DobC/p";
+ sout << "HsFSf3Q8QV215kXq7lm0FUxQHcA2p4F037be3CN4V8aa4NqL/ChPqdR/EMtB4C72nOJClAGY0BsL";
+ sout << "4ApuEZj/SfshLKcnMdTcJTsqBDrm2ykdu8JEg5PQT3vpG72eUbAm3Mbm3CYImISFQiZ07XYDnN6J";
+ sout << "DUOjaiwfjJ4ydnkaAQbDnChIS6JXZV1LEMS19DCB0TLoRl6jnX3mNu4S+fugfFHqXYxgKIAen83P";
+ sout << "YdaCotX2Q4qj7YjNhknIksW1OPA+8VxgBhrO/xBocrJY+uEi4+iVFJE/8Ge8sXmYr4AcY8q6/CuF";
+ sout << "jYWYkmj8UZp69kE9QGJdiwmHBFOLIgAdnKMXuRS2NFZgFcoGMWMxeasH12WFGRmtuoS+eBi7RzUn";
+ sout << "CC7RpHf++XNJClYpMSPnEYafiDQOcMmjHcaBYkEBB6kRV6dPtwrb0ZQPGVFYnhhcOtcM8tkLD0k6";
+ sout << "Ek1+UYaG75go18OFlshMOzU3Rq7SuuNwtSNQ6drsWWcGYEOWtQG1b6DniQ4+e9hkP1sHI9GE3jMV";
+ sout << "tbuSi8Q7pRY+vzyHyNIOH5FmnWCFPYLMkHt+6aenSR4bh9b8Bx7khop+XymNuNOZg/hu863yOx7P";
+ sout << "D/OCXH8BflcACMxVMFKcPMxYZHrxC4cINAO2wY7ooge06LSlM/3gNNwfizSn5miDw+3d4UsnnD6G";
+ sout << "+oMA9AJTuPcy0tMqBsGqmBZMAIO3yHWroF9G0dS2TVJqMn8Yt5shS7TKoVn2ognWInejfbOSP8nl";
+ sout << "wliaBSd4XCMRoacNMgoPilfwQsqwntNn1jsNRVG3YwQwJnSWXo0YUvbCeijqzP0pX7IezmCY8qgI";
+ sout << "VG7spKv3J900W4CCvbAlNsfHnNAnjU/5MWjs7I2j+WhuFv+b++3vDsedHI5nUWjTUYYy7N8O48LX";
+ sout << "XH042vXLucXjT4inm+IWDpj+br+GmppCY/bZzqO/IwzMt8pWiReyU9NNtawN8ag6FYA9fxrqGwCf";
+ sout << "j98DdzKzgBgo28R8Y2Al0oC52pFovY0Ym/ormPTSwtehSaq3lTFgbCKBhYR9QPfx3vMbYZupWWCD";
+ sout << "FjnZYkKtJltasN+SKs7Wp/cCU+U+6zPAqHmv+ZSWiNhC4lPgYDGSdyhPhp/sHz9rW582bdo0iUT0";
+ sout << "8QRPz1UGheorTDK0K+V0sR6ltaIS4/5P7/QzmVvGnnilWTFmqdBViBZRFv+3wp5xMqMjtZthbZtI";
+ sout << "3YmEvxUwnPvJCu3+veeErSbvB9AKDj0PJGmqLx8Cw+SbQgN874XSSq+w4WVXCT+GOUrspzZR1LNi";
+ sout << "59SSCEs6AKkKmXmjk7sdl6FocyINLIGBUSPc+nwCgESD4xSaR1cx2SjSf0CynPF/YprCxvqRvQEZ";
+ sout << "dcKWvwgeU6oJP/0tz55yIZjiIi47ucY4ySXSel5iPoofoh4Gg21tguOQovIAbzqYrkN44bmBy+2S";
+ sout << "Bpxa7JRdQr+QXwMMEbsnjrXbdHL7J5fZa/80FkwOHY8Fvz1wCRoZ5y0EyAZCq1s5H4Vd4yHMaZmW";
+ sout << "ZNJzXvADrDSO7NxyTDGo4pVX4Dxh/v854SiHkO12eDvoizsraCz2ENtE6cbpHIdRq0Tt/K1HEH1E";
+ sout << "V9l7eNcD6XJrvSe+6yBEJKD7srWbyC3MPSzfl7cFEZPbVzfkjSP8H7AnZl9mFnEv8AbuLvHJTVzq";
+ sout << "LdwmDV35e4jFpDbf61+kzeO0F9ABrphKOY2YP01f0sKoIicMWOcXyBVIMCtB/9Bav/Pn6bH0pyn9";
+ sout << "fyVa22J4yLzvKdYKluBaKUYVVUS5GnvTT1Qxdtlv0AWceJApOg4sHx9zZmoUMDlQQqH54S9sikKk";
+ sout << "NnW4TD+CbdPYFH7kgXAkoFveAOIKcZZhTUD3FLbarLFqwcU37NSuqM841MlPW4Koa5Jrc5ySoL2U";
+ sout << "FU/Xu07j/PvrpxdVHbok0aRtPgNm44+vPDAWmuG8zuxqluSW2xxzPi/YVo7nsWPvJoDKjZA1aRMr";
+ sout << "LehSLFoQa66fCq39/XXbQp83pL9b3tFGRndFnP8FxvMcCaYkG/SDevUg+cG1ysNbBSBJ1plCJe4S";
+ sout << "loGs/jQUQCNIX2/mR9L9PHM7FS+KfQn0tHgjipj5DRGiFInscy8cXn/44uYSHLvZUTvdcg3+I9hi";
+ sout << "qdV0luQoUCZm9ooT/SBeEo4VzB6+L+fsuOlYCQloHNtmJ3KrHtRtbk3caJcrbTcoQETMbSGN5awK";
+ sout << "XueIC5Ar7zL30I4/DTjrhho+/6Kf4nRhplqPpt+0lyLOSk9KjeNxWktMwnnplhggnxLqb9laXNAM";
+ sout << "juhmMFkTHN21kH6F3TDmBE6rl8iMHHn1+kSKhhYjDeb+n35qu3kYYEg651ow0NVdq+aPC4YYCvt8";
+ sout << "44xqYd6LxPeXF66jNdNl+oOBdds5MLOyW8uIFdQbmh514Q41qZ3TEMieqT08Rp7qW4/mqiPjRzp8";
+ sout << "DZgOaUXJnQnUKkVbqqaiBv+ot+ArYe5JB+e7nvzgyf2zxCZLk9/Szpqn+JbGGEHFixV99XjJQRTU";
+ sout << "q5DcOkzNjd2/Da4VEXCfjQPWgYlYu1u4hSgasO4GVpGFnyuj76gnzPGbkSVq5GMJ4KzYtl7/sta6";
+ sout << "sijpcScgVlvvL6ff7fhEVnwQfa4mCZn9/umsDB5ZbG+VJufprMVSb5CMYcMyOU2KYDyvWHE6jSP3";
+ sout << "Za1W9Fjo1Fkzx99+pV5Hex/GeiEOxyYCPjNEvCNpn9xJWjU8fx6/6BbN2mGET9hO18lol0OFeBio";
+ sout << "88h5F6msca4plFbxu7eLv1Mx4kyQNmvupkPaufis95NgiPhNwOUSPsGffK8WGIaJJx7+g/SkgN1h";
+ sout << "RqaZGgxZdrnY9nOOB4TsqNd1dG4p8ybarjysGHg0JQEFRcxj22DlC2PoAMEEILAjYb3cIX0xxy18";
+ sout << "D6TPca+XaHi6LoYixk1zz+S8FdUGweRHS43IaWzBa9KDJ5vEIxAAXpyJ5Uv5PRcBhAfdjqjeT8DN";
+ sout << "bVN0R0D5QQA=";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin, sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin, sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'frame_000102.bmp'
+ static const std::string get_decoded_string_frame_000102()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file '..\..\examples\video_frames\frame_000102.bmp' we want to decode and return.
+ sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ";
+ sout << "fX3y/9RxoGwtXxwFXCxbBrLRaklamboDoK5N06vzzd+CS5m2gEc6/c0nITY0hfeRURvJI6cvb8OM";
+ sout << "W3UbsFYtU1ntq1umGkQDDu2ukl7sV9LfHj5eh+n3zjy7VIX/RAghwQA44rcpnh2d9icwDh3G2dCo";
+ sout << "KxHo3W4qqxVLis7IIGg0dzwguoB8ZGmlvOg7hdMz20beAel/LlU8TMghQk2Aj+4NTKQRVwcfXfdd";
+ sout << "AyOhcT9wYaLabnDZgL+7yq8Ji3VHvlGqgVD14WJj87s83dgMGikK03QgTzhyY4dwADT1aQ+xqeec";
+ sout << "K5d37u4xE+btNpcqFhHx1YCp1r6+O9kjWRE08rfb1qWZB4qyCwTkxc1C5buXaeNRrw6pHHqSLWd0";
+ sout << "9IFAbB0ukrEWmdHr8lVKAfBzL+fTFrAO8gjgT/vf15asPfqSIl+zJyHK+YJ1p0F/UwIzPkKl5nx2";
+ sout << "O3u06yBtIHKdflKGjRSXomDYGjrjeoLNjqF//7zBBXyNXXv9TsWwekzsX7A+qRdm/j/C0lCHU+sL";
+ sout << "g6wWh7L0kL4ARqClFQfanO+87/EPhnu9LuovQzTY8PhLB7a0R2fhAncpTEPCHBm/0bTEL5F7V+W3";
+ sout << "8yzUfcgi1NOCAUBN1tVwWBDsVv/NUIl1x+FTKd1BFAsN7a0LLZO3IGnCikPyU/Tn+TCslyvvnU7C";
+ sout << "SxgTwVZIeDC4sq0l7kk3Jc3UWrDI77EQHjxNFKG+lnJxMLXrNMy/lJ621WGz8uVPvpkcTajH/YdN";
+ sout << "eaRRgbdqnkqmoSyX7cJeaZG4NE0n+isv3j7WU3cM6N8WKhYCH0ReG7XjbdBgrPHaUdFD1dV20eC1";
+ sout << "2PSv5D2jEJkDUVNdMZSmwOtF/UDg4WRMY80YLGSKhxqLqoszhKWXltFdm1M2+WPtSFOdr4i3wZRs";
+ sout << "W1At1+bmMTFCM18KnBg0mQv6zyjKRjzBPHLiQZQVBmv3al6bQnyLsbpNDM79oVd4cdEv+l6Cze9Q";
+ sout << "gPwotlVgN8hi9gcF3lH+KTQ3Nmd7WK3+125inZjFasNLUVJZNLshz0DVnKWeaO4RGLx/Nb2ceMwG";
+ sout << "YknvrhsARhzLZe8ZFcdvQs8JnlXwuNpN9OKtBNQMpQHf6ytCWCV/ZCsrFCT+GnKMQxCvR4/LNIFU";
+ sout << "vWKvSIe1iIDoOhUCgJknwUH1M1Ug/VLDpOxT4ylapmV+YoKSYnJANOQ3+3FhqHRc2KG8vGMTNzab";
+ sout << "3rs5TEQVOdP+fEtU8oynfcb/0WK1BAKgmDRhKcz8XG28EX15aWyr6uJkAxqYE8UXHOr+kn27TuWA";
+ sout << "49KqCFNsNcjhvnjWQEBvwTwGRKofE4ASdnQi/fF6Ozk321H0GJYpXETW69XJiNzT0f6lDUYWR+Bd";
+ sout << "HC1sW256R9dn/MBAcSH2FRUf0Mv7z0yhkb4aUwHNvGhHPX5g0Bhy6x1TCOqHtjubkHkxcq6YpwW0";
+ sout << "f+9DXG2M/sK+IPqqgy4tY/fcJeYxq3Vd4+voDPYfVxcWpmy/4X5kzf+fFwtXEu2MBdYKQEKviV0Y";
+ sout << "WixgGguhWig6d2vAbUKgAlJd9yixtBum62nEK1vMZmwzfokP80fimNdw7Wony6EA+v1oB+zHQojG";
+ sout << "1fNaBTacRa+NsW58U848DFrsJ94S/9ss/RUx8GuaWwKP6WImvkhyHP+YdbtkykVPVzm3s798VWAX";
+ sout << "nvjz9R2NlupUdBVgJg03BAH4DoJdYWOxnLtUChjW++StA7OGgWoay7DDs/6AJYycFh+TFW8D/HhA";
+ sout << "sEQCUAYDnpze0aSX6CBZ7vHMr4z9Gr6V1zE4dbJ27Yv7ZFBetDLTN3MN1vk8SYDFeFHtzyMBfqor";
+ sout << "hHXu4rKtGW2j084a3BQvlDVZ5bQX8mh6QBENJ+C3kBUhE4QthRIUwGsuarKCnOUAAL2Zt64CgGKi";
+ sout << "gRF4wZ3V3Tt/FZ725pUJe2Td65CddbBd2mRqEzdN+qqn24au9ucznuEr7E2j0T3ey1RWEcU7Xsx9";
+ sout << "dEX6LoyFqowIEvyGOwdTlycOq4v53Z2cJTaitjGSvtaL1h1ASDW11jvsX0OHueMJJyLFwKOjj5JL";
+ sout << "JrtfkGvWRrycw3UdT6g2n8IPKk8j96hds+UlDvgwIzmlromWvPfwixA9mt8MRqJlBIo94lKjOvuw";
+ sout << "sS5FUcr31Dqmaz7uQcrwmxh35jAvW3EbMObKNdbaojaZl3u2Szt4LA0yFs7UJpkkJgEO9E7lGte3";
+ sout << "WAUC09lpDNq1CLvEfOEsWQqzXayT9AU86CCnFgi6iErKoA3uNToSaqQVOxOi8RWzQWuNfmkluBMF";
+ sout << "6NmgipWLj1nlLLBUOpQD6nVlwYqZEeiaQOhHRYUiSUFbJ1EVdBPoRi9CPOquM1WX5H9qbUOBchB1";
+ sout << "/oI2wsh86ZqOld75VfD5QsHXmrFsQ74mO4MyvAojodFWADFmRloe6UoEBFYhjXSlj2sUSTyAMfPh";
+ sout << "jgRrJwdMb0/53hm/Emqz9AhG9FpTI5P6doDAm19dGp09NvQyYoZnCCWHC6fPpl+BS9WANNfsSZnl";
+ sout << "dtI/JmBbtlx51QDTA+Dc/RqGZZFzIIDnBoIUnyFY6FBb/S03rQadSUdeuBInghBu3bw1UH0AImFZ";
+ sout << "uy0gKklA51AkE3xVxtyGaCniGugrWwcGYl+FeiisBNKtZtbo3spTO/Rr2Fjg/mbJ1I0dCDf2gJt5";
+ sout << "qkjF7zRl35+gl58FjkBxMdzM64289w8wqjKO6P1JBnJV2gsbU6QYGN/w3Xy9BC6Yq7WaHU2eUves";
+ sout << "tyuWAxa2+pl8SzrrvubhGUV0Rx4wVcRCLilUKHQ9VoCEJeou1y15h/hl8BkmuU65Do4a5txsH+Xt";
+ sout << "5zVkkmGnvrvq8xMd15UIz867U7lt+BwgV8UoDz5ZQnQXvwt4+NLK8dLaRfr3Tb7KpLPOdymoOW1H";
+ sout << "7ZnfmR75beprkaMn29qygGhlTmVTxNC7CppNNfS97Yc3ki+g14rkFZMmIpETfUnM1eqp9O1+y5Lt";
+ sout << "1p1DrkrZduSwZtg4DmJt/2g0eV5xO7LPishn8ezxLTcMLCQYZDM4gjhiWt26U0XEQClxxouxQKiV";
+ sout << "jpcwn99nBEj5r7tkQLfJ4clQW8gOuBidgJ3gWiAlR/Pwkys4L7cdQybo8WegBvXRJyNi0mDGsM4k";
+ sout << "cD8A38+zxN/mI7Z1Gv7y5zWUMrILQJvPVXCCWXo6gikdJpkLrprcp0RUXo0b1iUgGgOXVvY6PGA9";
+ sout << "HrM5/u0t40+Yh7+3XqzLWWU9kuOgDLvjwYeuRu7oniMmTWOuRtkWyqKmTXlrwjBz+yqIQJyiAa0h";
+ sout << "NF+n6CNX4Bk1X2h7DIbZf/hCbYAsEQ/zHAudSiEfzew7T8r/h4YG2huCTjV6NrBNOzxk2Bu2A2kd";
+ sout << "PbIOFbLWphpDX3pQtfSvx46IL0uX4O03VgTiz2NCvNKaH6oZZ8mTxkV7QmVnRdHG2my+oFsqtkv0";
+ sout << "uRF3h9xXKDij2r3lfNRhHrM3OAsySepTuXa0hN2bU0DIiX+BlxVjBKicOLIQUI2k9ws3akGMpicU";
+ sout << "vO/fKIwOs7D5L2oUOqDcycceYM3DCTjyjdeP9hq9+w3nxnuvtoNuFnE67nR0FA/XSYLjZOxq5EYD";
+ sout << "N4LdNhDWzYFjEHXgmyvc6IrCcltsisuRN1GNZmXDTd++/8CdA9ThbqH/AZvMeUt3zZeKMmb4Ad3o";
+ sout << "T9in73sligZv9ed32/6rc3w96aoiGtHgv5szIJ+h49hCNfuVoSR6GB89zKr3cHcsQ1PyI9+S7ON/";
+ sout << "/lTAQtwIUSGSQxM7xmawtKVSit1zv/gw307SfoWAZMmz3TzrMzKbe3mmNNwW91a6GLqe1iyaQ9Lf";
+ sout << "KdIVsCHzqq3DKgex8X0XPnP699SlqSR9gjHWZvuVOzZkA5WE8oatWQNS/W/bhBaaAiT5/781O1s2";
+ sout << "9y+RfF/lQh6tQbhPthbY/ALQDR07LM43qqvdaiinxT1gYh3pJJHs6oEQEAZwNvbjUZURmwbJ+hE7";
+ sout << "fJnWbYtftZn7pxlJXzXueWe7IPmYuzrn6jJEh4xbnANe2gpzkGPi2mHIlJ8ij9vp5KwkaZ+SkmEB";
+ sout << "El1hUdhsJoXXErzKm76RD0oLriVUGd2yvI/pM5Mm/3/pykt8cARjatq2CVc2YE8P2M0cC3ItKD0Z";
+ sout << "YqnMdxmJpxys5JoN7q9yemZFpiV2OMJ24kPc8BGJtgbVQBWpMUnNEOZOI0P17D3QRehHWf543/X5";
+ sout << "PqzfHj9J1tUcJMNCpp3GFxejtpOvKD8J7JB7+0QzqCCOE8SQIC3Bs6IB3IeZiYj70f3XU1UPF028";
+ sout << "Kbkis0YuvkL79pWPoqNAoA7YVVGhnb7ed4p2/lqFw1NHG0+H4YQO1gz8U6MeiUrjcxA/IkwNFAZX";
+ sout << "5WsC2QSadmv7OtUT6L9ifTCs0sogqcC5UETke+cLIfprQ/c4qIMP6JWMqRasR4+0qjxZbSxovEme";
+ sout << "0oIBoRGE2za5L0/Ulwcx4URU5/iUu9u7IPEcUVQzABaMIN4wMA4Lbd2evs1Xjo6xcFLK/DBtOUdM";
+ sout << "JTmiOUwloSGDd8MSx25d3BM1HdXfUc5uYihDFkk/0AUSLa3x0GbKimowJYGAkr6Px/5MsCnFW+Ix";
+ sout << "anv9LFErILEMTwCvETg2LuRDSC/2uyGNnndivYpOu8QGplJbW9n4ALWpvS/dc6PJKjZo6GhqqK/X";
+ sout << "U8Q39jn5hSYIGb/LXS0mMJZ7CPTO1AcNCAZb2r7Fc3orKeMoBDA+9R2x+ZQnlGWV8PdAzfpnfsIa";
+ sout << "O+JL5Nh2NimtPQsqEVykbF1RF3caiYPCcEj+SAUFSDZ2pe+YfXPOpe2EOZ/BnGTQWdn8u87mw68G";
+ sout << "Y2zcXQAjUHXuu8iF3IXR0bbRR4bDKp1MJ3YW3nlk2E37v0SEkyX7EpXg8hvNB0uReBLOk8pxw4Ha";
+ sout << "8R6ekjMAfMnOQynT0gi3tOoy6/CNXBQFAp7/u+Owz+DO+pRX9StkBjFUWe6+8ZeG9YJLnNJZUGkY";
+ sout << "Ki0EEoWpBbW4zvGa0pOPJ/OceaYIzhL2/1DohO4dO9jJkOU6UxpYaHMOqxr56K7MCpJF6k3ii/q5";
+ sout << "yLDM1Ee3peTb6MfWrrkeYc9PhEtD6TdlJwQpHkivrFvBvp5tz4Tg2zoW4O27SZZn9Hr0MPAKCT9E";
+ sout << "Vouc/LBstfyV8cwXm+nqNGg+f9ZLgns7No8isMXQTPso65SeRw1K48aIwSqeZ/LZ3fQZ+n8l0M0B";
+ sout << "BSRAF11nGpSNKY/G+7vKuJ+FVV1z1MXmc9wp4tjnYRMEBsmhTnL9fFpoHvDlOAgI7ZN+YTKFOP8R";
+ sout << "RzeXkz6ne3/ztgWIS6QZ1xFzl3rEPwvcADeMAT5C/j9COs8YXQUafQ4o2I42PM8PloMpz73nTNqR";
+ sout << "yot/tF65hOa7X6sYFISZxTXvAq1qaV4j/bOyv7Rr4IMQkvPhZRnKuOKd1lEtEleREDKdNA+3Un3u";
+ sout << "3czlYhhMp358CRLKGyMNsLc0oVmrc3Zf7AljUJg3j/WmyZdSWCfCFAUzh4y+OF0r0LSxy566EC0f";
+ sout << "EvoeoqDl6Dn20my8VVm7Ek7kqtjGX8OPXh+OYw7q7GDHCLn+gSQKRPZEMAVwFPfY5J0qUdn2EhBi";
+ sout << "bIdpDLJlX5A9Tuvsfk6/TgWJWrd1ljnk5LRU1JdLpy9Gr3i7RuUzHVmxutyDezdNhJQ2a/dGJoyL";
+ sout << "wqtRLh+sAldKv7SyftQMUr3mVgcuViK24HOW55optbEe8LiDnZVyjx0fWPnvUF3jrOu9iKpCzEQb";
+ sout << "UcV/ItitPx76PWI/HUsXW9nPkyHBkVhy5n4A5OFbK4O7r1eiQBRttalZs9UPIX/zGjhL+mKsmDPZ";
+ sout << "5W1ejsVxJT54FOwFmCD6yxgeWyGirMHJ8iIwL12Vf2WG3NvnLTIMokCOIrJIGl2nsgpYWFh66aLF";
+ sout << "BY52VplzlZ7uvieN8T1gFkoD1GKDTPvxJdbUIyCQG09lZhFsqWDT1h+sZMk2cE4UGFa9SUZY8cC3";
+ sout << "n8nDCxxmmQybJp70Ig8cQPy9US51O+PL8ZjyRJirlSwqhSXndjyvgm4xk9uOSk7AooavjRJYlO7r";
+ sout << "Mm+U11XEN5mZJL10W3nMa3Ls99DQIrifOcWVAaY0pEjmeBWHizxOwwqt3X9CY+Sa/X35awGgaQCe";
+ sout << "CaaIXBkxwkRLD8/Of5fYQ8Dwg+LWZnkMjs1IBdauSjdWvT7B9CXK7O+YJC5kjduIGOAu8qfJqG6o";
+ sout << "HP/uOEvCq4vT6Fg/gLbOjNenVhZSq6bAVSgG3C8/WonUeNQiWnEI0xLm+WOU1L6bcqBNHNBOOiOB";
+ sout << "yn1YojmgQszuBRj2yC4rCAkddymD/yzJ99g8QmhpzlakTRtpjf9LTErhIBDIQSd19U8mNdu1u+cT";
+ sout << "PeIl7vlVmhhnojOHfLGRvD5Zy5dC9XZiXRRcNWHj9vOTdrmtwy0mPIZZMPSBLJnur3djgCLfMniS";
+ sout << "Q8iBSo8Uzf00zSZBUBUmTekAAHFvf+we0/ERcXVkJ2ikfZ0v3sxGf4zIVGAiI0Nx6PZTWgW1xH9a";
+ sout << "3OYLo3mzzz7TVy/Yueh2awJyZ6CWq/ysu9x3tSJG8DbBBc+4segNq+IKeu52mYTcLmASIR9PT12z";
+ sout << "ifAJWWtKNPR4T/FK1pDTkaVkULRPzqHP8jxerpBRMnt6FStshZiZOs+DZVrXJxSP64ZxFthgqMd2";
+ sout << "UvzcqeI8yD8ndIyd69Os5t9OoeWseLPWq95+5OUX5v13wiS+EHbJBNnZLLt4K8udeohSNUXEBDND";
+ sout << "tj7TDkzLyXYpcm3KYkJgtRY88kwidchhjvMQRDNw15ozLnR01t7UsNLVLuBij9sUBDA81i4YXyrr";
+ sout << "5se81KvEZG5kLLkFCUnPKRU/3aRb40F3zvY9pcBU6KeJsjBQzKRX/EqFise10qyClgAyvADqyW8P";
+ sout << "jMDslvAq1mY3kJJU74AQNRuFP0iBJWTX5B3BinqVoE+VJzZPbziFVs0riZccQIoff2DiHS49f06c";
+ sout << "NvWgHb04xNXz1skoINdhFVRSsjik+qmxbFE90F+h9eshTMxwQVobiRLdLv1pGkVOpsRxl0eDsPMI";
+ sout << "NECgDvYSyRtSxXz+SUGk4dokSgkrTPa6NDx82FYiyITDu7wcgtQBOTj+SXRWrrG+KB1MR5cRHWP6";
+ sout << "nmELArZ2JquWF90mbylFHbzKZNB27/TLN3X4/0cVqLNxyzow44+Z8f26lOFFnh9qfzqS874839gB";
+ sout << "IRvAiKnuSo7KCh9BbsvAHk8a23Ei/KNxJFH745cvLab2oVcPuFdwsDvuYgPNT22tuLb+/QN7djB+";
+ sout << "6a84/73zWFCnfeMlPgtGbxUE9yns4nYTNx+jLDUpKAmYBWeiPEiGhsvJYwy/dKvXF+CVdUHFbzJO";
+ sout << "r6rnCYGBY4Urxy+S6KPldLloAnA2SgsqfiU3B59g18OuY2NyPQ2fj5/Ytpum62hQo3SlZTVGXFq7";
+ sout << "pInYYEueHtHIyJA7xpqX7TG8HrdRJRbz0Obs5X0P1vbpELhaaI0vppEIcfzZrmrqQfOT9fUILia4";
+ sout << "DOCR6e6TWVirmzFa8hilcMLjG89m/+h34mO+pwqNsURvWoJ7oixSe9arMPUl2pKyA/wICP8AjQak";
+ sout << "XXaacTp0DMIbcOsf1NY14cuKUgl7tcst714GMXCHfPCnYMUzoawg4o6LVOzHwHRgys0aWz4aNmQq";
+ sout << "5UJDqc1UWVweo8LD1++5BCSMO2pNJaKdbHqVhT9yJP5NxIkV1UxvH1TowNN4SI7q7BMH2fxRBB5e";
+ sout << "bu2XU2s/5Vl7snSkKk0PxmbiiwWGiN1Dx9YIUCEZGJcomUh7phIlPaoT44MEDMT96z7neW110We6";
+ sout << "/FUCAtq2aeF887UH/ARplwM/elwZyI2BIDGAWF5udKyRtCrVznJdXgtlNND4a8MM2OQicC3D85Z4";
+ sout << "suz+m7cwYIW1o0xBcuM2ASLhhAdATCHoyD7A8fIm3dD0vMKM7vA8QQi8ETjiN7GzRqoyHOLt1+4U";
+ sout << "wM/BQdskv+ufGD3KADK5CrKPyKzaosnPyGYyzclqSBbt//QUdYBQO4AP2A3hqh/QySkGID1r4fUC";
+ sout << "cM+fXd4DhwvNjAZHRPfGPD2xPguPOmmuCxkHU1CpcKdBSbxR2q4FTdaQ+1HBbU6KZ6358kqT+5by";
+ sout << "3LP0eS7VemMfP0d9rshtgbK/+nG30EGTtSIJe74inTJ63k3mkHb7jRu0rn+vmO2aO60HQM4SUiK5";
+ sout << "eKZUz9LGB6lD49IqbEwj55SpSUmLsITyO9HdZ2hhAAxRTQeMz2xatUny/fFOGSUiBSeYvyrmnWjs";
+ sout << "R93uXWixgiS7DC15ARLHAkvkhT76lGJkxfdyhdkneQlImBYUecqxHmGRAEOQEUFB7o7ARObjdaFr";
+ sout << "Cii//3Mj+6/YSOAaYlS1YInt2Wo2+gin8+eeqpBp88j4NIH7qyKiGw/fco3w0Yi/x/2tbc+OR79D";
+ sout << "HylA5tAU38lKI9ZuW3/8nYGw1mXE133YDK8ZVzGsU5JnTnCuIW1+aVLf5pzzldfg3z49X2S11QnW";
+ sout << "SNtd3gZ6cY2sWHnQ0gm/EIQCF8696kwZlp/J7fZonWd5Fr3P57QepBjktJs2ReoUXGQGAYABNJAK";
+ sout << "VdbPpCSu/kWlizZFudp+2MxUJnGMg72fGFn2AYhlW47Luuu3/DNQiHhmrxaqXbnpqk+jn52j2rQT";
+ sout << "BRiIujM4aCQk8bXqFCvgD2pcFpKMKtwgkKeLlLMhqX4H4kKZ2h2KpJPSHsTg4EW6p0p1boVMoD8H";
+ sout << "ZFZWTuDKtJOorF/hCTAC7paeeY6EqeEs1S+5ujoI2XjTNcai4TMAOe/pAdgpVIRIybPZP8kWAc44";
+ sout << "P56yfuv2RYxcqAvmpEjvPqhaP0lG+bOUzxAelJy0hu0xe14CdvwWZseSDq+mDaFrrM2fEqGp0eBw";
+ sout << "ysDRYDp0HJaY+QM7mS1chGIw3iTOg5IfI+wgz4BqRz6ILSX1Ci93QzzMsr+t8AqDssxzDX9kMNsW";
+ sout << "oSoRYt29qfxPr7wjBdWURj2KxpW5yCygTMfAh53XuqjyEmuiSu30o3fCOzdQFOppq/LUul6A/Xbj";
+ sout << "anZBHBWlmaIpPtvtdvjcWrflDrwdsJVZENu3qKyEmY+wpE2MMqXUtctJRBr1eauRU/Kcn+3JfSji";
+ sout << "PtgwJ2W6ztTzjqr+6+TRCqJ7xnzIZ2Cj9jxjcD5VCEcqAVg1/rlP50cHMlBrW+UIySQQRG1FtX7k";
+ sout << "8zqh30BJhM6Qff6NrvLkpzAl6qKkgcSe3tsbmovCf5EmB0irBvKg15LiCDxI0HyYT+KfKv5NBgjB";
+ sout << "NQDDNZzqulGPMIqzBGaZc29SfIHaMxYihQS5mONL345HTVNQg7MkcHIYalGwbpIbq3GKv5MumOVj";
+ sout << "SP1kZ2BGFt2U1A0ytDfQLhHan2a7v4+7T5JgUkRrvP8hKHRm0JTMFqFxZpzGOEQ9dwl6atFoNdQw";
+ sout << "f+5ZTWFfzYzWxfvZT0j18yQ4fo8VQzCqlOc5kLYatapswTogEQ1PGd3UH9gTs4SfMjmvrfre/kBL";
+ sout << "nTISbSjaS3Z7N81jEQkExFHqRfxiMIK+h5VRZ/FwviueAATnsIY6vChZYueN9Y8m92Ou81EHj6Je";
+ sout << "o4mKrrtkUETSWP7IkTd4wD/EC7hMWvA7c9cWpY7Q0nAcruSrfcUgks3mW//cNgLmMo0UE4+0W8LM";
+ sout << "H5Ib3drj/Ha1OAZK3NRhtUN2TRgJEaLWrvJ3fo/xDyIdmm/Ap0jf6jXGlXtdWl7KaMElGsmdnFyb";
+ sout << "xgIsGgh9k6hyZ2uTFcTzeRFcsAVOBkQWOBhyqtqh2asmPFjOro/sNvstEVD4+SrHSPG8rlqK3nr6";
+ sout << "SzNq0Is07At+PS14wWqp9ZseQSLr/pWwK7a/CyyAwxnbdJqhhsnxmODTz7nIWFgQittN2Z5l0RhP";
+ sout << "C06ZoojtRQVQPyjCr8C0BnFGyKapnGPGWVbuzlu6xMhqC9C4T+4zjsl/hVT5BaBKvGgjvxUJ9pWi";
+ sout << "fs120RVW4ra/T6X1oAXvYrjp+2i8QVcOPAC/jfE3zT1UpsTrt4IKYR1ILwbWsOi9ORJUqpQgQKpb";
+ sout << "ZrMseH1LAmp1RzFN7M3Xo+wWx1fu9S0W3ZYizaVrkbWrGJyv99zuiCeIY8Ns7oqUfQAfBbJM2TPb";
+ sout << "6jmkCyD1U5a1tir6vGc3YDkBryvw1oXRdeQEuKPSF3YVsp7bHf03ALIn5glbDJwh+8FvJVX0rkPa";
+ sout << "yK4n0dU6mBDfyCt8yZ7/uFqPHVr9Y3d7Ec6RJPEouC3MOlKhYzow/lkVnCVteEaGqhGh42TsKnHt";
+ sout << "I4MUSgA6n2QVTcFfD/hsQdtjzWoFupA0v3PsXobN27vz/aEGrtHjaPXKNDLfbRyW0OQN1Rfj/nRG";
+ sout << "+TbEr9Y17f87ar0qnd0aLd/hgEJ65lW41HWWj3bjRP/RL/X5HpKVzj5zOPBfnOH04vCT2XtjAK2o";
+ sout << "J9gFPdwz+6hWrKXGuQlrIJS+7RpVwZG4woY4cgBv95rFjfWwLrLuuX7PYeCJNLFZKpoAz3WpzRYQ";
+ sout << "PepcF0/AzD6U/dLKAaI3pr/g3kItwzrhGIOz8ZLN4IrVYrSQ+kw+R9NuRWq6Wxg84hoRBWN3hfxt";
+ sout << "A/58J+s/+DnmBilHfrYMhysYMVe0TaN0fM/Am5VIQW+lQJOVY8nVFqxxxt3BTNDAQtVl1BwaRQw1";
+ sout << "7PVWmpFsBKX/cIkJRfuUo47InS4SvzV6JEgNSJb4Jp7BHtNHQpwKDkCjfwcajIrde2nIEXCZ7CPX";
+ sout << "6TmNL3i/ys3IU0zW7v3uDZja60mepUlqvb3mJsiPNu0OkX26rIO+K/E9roqWnwp4HdzTCgv4Lmuz";
+ sout << "5jaVBx3ZVxdjB0Eb3Kkt96pZflyLdI6WSR5RVCfpl8ov/QOKAjZoihm3/YZJJvytsUsj+Wi9CS3G";
+ sout << "If3aps72fkcCfzlzB/2H7HS2Xmfor9P+hBdfEjBiz0oMlkXo0+JuamAk7ueQ5RO7YpnI61Q190x4";
+ sout << "vgeGRm588X3qNb5IXuhumYUZJ1ATmHc9mMkNSNWEGdFy13DJmhHl+fOrADwMAy2+J8g0L2yHloaO";
+ sout << "rSycaN7DgGe0YJD8IUbgS8QMOz8Kq1/Dy2/Uinix2smPobySnP+uQjRJ6ilbHnZCAa5t9KB1fNRs";
+ sout << "HExRV0YMN0tZ/njB/hUsI+N/sg3lvPtSOOpXb8wAQQIcj5v7rCaCBbiGdfjWcgUbHfLDgFLTxq2q";
+ sout << "LND3J3HmONGoE7kwPwNRBMD06KUY9PYwSqjQ7spkdBOYtuBVqCbnnaT6hnh/DsSTqG6DkymArX6G";
+ sout << "Y4SesPr5KhQtUvkrHHQAAmQDWALzA7w0fKLi2unHSsRmteT1ctL4lGwVdEb/YFmNzxiD7BM3Zqvq";
+ sout << "JtEDUScoB94K8R2BBOy/CQqEcln+DCoukLW8BdLAJ6b3mDaw8tLIevYutQwrQln0pc6HwXKmVTeb";
+ sout << "AKKBMmZaAAClk5amuEkyI5YPen9gSZDX45yRohZcef7NSVOmI1ma3mc0rFYJYkJkC7T9bcpcHqAb";
+ sout << "WONeENewbaLz1onBbrVYBdqPeBWjrGt4kOi+YSXtjm6WdZRtLf0cy2uRq1gvzg5oNIDGer12uAAA";
+ sout << "AIqr189l1IrG/3MtrYrKK/Km9wuukr7dJe5Qf/96FXuft+BUsbt8WYQKOxeMZJgWfoZVUzKwkzEN";
+ sout << "/+uWsF3fYne2EclBulQem6566qaX95eQMQv+mzrBVL7+PG6lweDQ6r6ZETzX5ogxLqyeqLSue8qJ";
+ sout << "xmVPMI9fr4rDFcRrt5XgwH2j/QB8ymhsIsqPVnZKbq9+jrPUoRJDs1X6VCycaMS9bG36PFJwxW9y";
+ sout << "Be+LAmVi2N4wlBCmNwXclqK1URJRvocKWAKrt0Kn4dILU1Y7hXNPjcUx/0oG5Ffoh0f12glRsyva";
+ sout << "FJjtV7ePY1hClgBK9v7owd1YZUz9e10XcWOzS6L8HOpX9pocEiBdfO/oTDYBPd6LNk/E5V032Usw";
+ sout << "XdqReOtJL/zOBPW1vJ5xc1UMRyGxDeqUQFxCKLkFAQj46cG89dM3bhj+Dh4U5J5MPhQltLLhIVbZ";
+ sout << "qHvKkLr8cCYqSS5h/SOF9HWJKjNfXOgpRbR2l3/m8t71h6sSi5XR5o0NCi5V3/VCgAs/CfXCYH/O";
+ sout << "WEDkfZKw1IgO15FUghj5y1vzejXv1logti/MHvh3WqMlQ5i6B6SgaGmIbOdqNHdYUbQnNg41m+8u";
+ sout << "k5VPCiPWtGkbzzuzaYSx/Na14VdvAbZXDztFPbN3l1fMrAolDKlexOXohzVZhYOsqyu0xdGka67Q";
+ sout << "i/Bz6DbYM/U4rF1lWoX1Xh9eslMlBlJjaAPdicvaJF7YArjzLKOHB9p+EZ5tfrotIKGfdGIbt/Yt";
+ sout << "MSsyTx9PBINYD1u/JeI4SqXTwvWcfF8+ZmnspgnEabm8NKufPXhaJODzrxL7filfb0/JbL6qmf/b";
+ sout << "5ksG2XT2sPxBM3TjQE3DAtnS/Y3psQUGsXUuY8oGNF/PgRWd8e+Ek9IoAnMLuvW7S4D4MP/wmWTX";
+ sout << "kZzlpZBnBBvlh7H5yDHGYCwGUofjCiOk+4hU1LUNU4LGgZYUsfjNW+zbj81/JSoK8AMlmB/ERH+S";
+ sout << "yuBnOIDJNY7BG1n0q/YoDIU/YuM09fyzW4tAHNqfbxaYif2QldKZWbxUEiiOWaItpW4QTsQe0MaC";
+ sout << "h6ppkigA6/7DUgUy3PKon+JB/jdrUdW/2Q1qgLqV22zBLHJoq76n4QYZFmpLN7FkLFtabiXZigdm";
+ sout << "7y4GEh33CQIlETdes59sn/0qIUj36ICMh1T/ujQOe8hlG1EIruBdZk9CtprIsVYdpuiftqJHRLmz";
+ sout << "271Lti8HdM/KSVqrupdYb1GO00XOGVU4/kujCa/nzhCukpCpESV65sKSuIaW+VQlvYFATzX74WWL";
+ sout << "bKfIvaa0TLR8pMz30JVDT7f+I568gRLEMsBdtMnLyJ8bgkPGdngpnz4G9FBOqkDeZdb/Ji0aKN8v";
+ sout << "gwRZxrwUrjBE4EjfcRQy886hvReIP6GvKax4bwVcGfmYGdNuliSsCsPC6pLXSH3fnmYBHD8DbGct";
+ sout << "Q7dlu1Z/aFTO/uY3yup3clLe89WTrPTK/N1lsUDULj/gPNwbWY5l3ay7YRoTFt49EjYPmx6wz3kc";
+ sout << "YXYzOroiVWGspN9WfmHL4XN6pXWhiCOrMRG2Z43vbdUnqNk1992xy1q/ZHTw2ikxIwKy77wygWfx";
+ sout << "28l1Iqj1k8Km2Ci8sBy/QCLnql6olazxv2fxn6vj+4QNrOcJb3sDuLbEm6BkUHIoiFzZEpc5nIpw";
+ sout << "5YBSARDuvqMt+01+ed7yBUTms08EFh6EwPIsJ5R/6YE/Qt1aP/hfB7nGWlpGb+yqycdZbUw5XyZm";
+ sout << "pOEHQHvW1j30bk27jCRHzktRYXRXQ3BpP+gVcAlXgxeoARjFfC3hYPFtKSQsmKpGKOFFMdvdO+qI";
+ sout << "lwrru0fgpd0ctBiPsWR4loLRLpt/GHedyVRF29GtPPfP+JekSMDzQj4/73RZfIdHlc2b7zX+JdLz";
+ sout << "Jk/4yNb2NmUnum55ezFqq7sv4HPFvMBuOth14QPJsrT1f0NyUBP3qYocsTAY7yo/fg2gYqLFMBHK";
+ sout << "+N9IvtHqxyn4actl6t8DIHeBTsM48AJVza+LkOrmuK6k1x61G5r/qyEVTbWVMqKagbySMYtIyJxd";
+ sout << "G87XVnI5bbsfSdTdTSQ2EOry9U2H3HIrKIFzxOcj2NxJ+dNqpiLtreoN8XsfLvfM/46NeNaynG17";
+ sout << "jIqcV8TL7MVqJ74/OYv6Mek/0vjLysUDQwPniEkka91ZCJboXzsuW41M5QkSfU/nxsD05TvQVbFj";
+ sout << "GBNApTN8Oe8miBahONkbfBB2ks536yBwG8RXWGSHR5Jdl6yhUIw4mWb0WupVEd90sZmXEYeP1q4K";
+ sout << "MvdHRuEQbNZfyXLskOtwKsFAqa9ebN6yVdM2tIOOp12Km3mabcUct0nOq///UyfEclOfrC5km0on";
+ sout << "WWXFk++6vwChLbG+5WCZl5SNeuxEWu1odbh2XZ5ng+XSz/T2P2l3Fa7ujEKilHAYYbS3foWHO+ev";
+ sout << "sZlG35o8e5wuKiqnf7gGs8tAIFFyONtfRZuGzH1qy5W7BysZaohtU3ihEzJdomSI1xKHE7+Dn94W";
+ sout << "ZDi7S+Qp6VKbepW6JcdWctmtXyncC5AW+UzquyYuJzjoDRdVe7uEAA44Fhc7nHFHx7snTuZpaYcE";
+ sout << "OfIBgzr61aCpVMKX7YJ6NOzRNgfTBIMl24JybjFG9Mcbk2qQVzKQ+w1StjfGTEexWIfYRDxhoUzB";
+ sout << "I+2cAA==";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin, sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin, sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'frame_000103.bmp'
+ static const std::string get_decoded_string_frame_000103()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file '..\..\examples\video_frames\frame_000103.bmp' we want to decode and return.
+ sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ";
+ sout << "fX3y/9RxoGwtXxwFXCxbBrLRaklZ0wCJRoLkFEbV5T2V0aJZQJTKWkAqP6JwJ7zPTyCiJkDRbGA9";
+ sout << "1cIxGeLInzTUAlhbGO8eexmWzXHz6KhgZ87K5c4n+YcRjVfwTo3Bl2AB4QPv/bs5da4g6jZd7iwv";
+ sout << "PYnF6pRCiwuiCobwvi8eG43pWnNEX2bIUtfMVGP2zA3e9Qg7Q7w6/KNILIYSXeI4KjLlnmC9YPB2";
+ sout << "2qPH0Q7DVlBLBBUHO8lZaFg1gR2Jb+4rgSvuUg6JaoUbqAUMaFFnQCjw9bwdN4bJ1GFXCu4G2afO";
+ sout << "/izIyq+QxXHIy/Ez1TZWTtkJfWRkKAjj59IBD7fRSqyvL2tSatSiYQffAcsyQJzcPMTHahKq3XeO";
+ sout << "pFu5XOr3OsyxVr7ME4/GAOcUjNpdW5pAMPNddRfL6/Jy6S9ICHxppdwhMAtbxsaHiKv2xgcc/t3I";
+ sout << "J0t4NPThn/QTNed+F4iFR30+lwrS1EyOUiqRT7q51LFDQE+/ZyqCrq8vEGp+BRAxMEoL0ekA2B9W";
+ sout << "vceyzfAlymEivr89KSSC9L+VZjE8qylGshgeXM30qladBEU4CKtcTlfJFWSHZX9Gm8qDXmp/5bX1";
+ sout << "lT9/GNd2E1hBWWQtmQ50DMlInp8oFsoURaMgzNu8zOelG1WAD663IxutR2wJY6ILYQfXNND70XrM";
+ sout << "O2PI0ZVRWenZjLzXtSefvwSXCG1G2u5sjq45P4Uk35wtetzV+N2WzVinM3yYtcABa51N3Se2nDWP";
+ sout << "huxbCEVYT0c5sfaNCy7XRQ0ijcrMdqMDbecOuTMSlE8FEu12Z0cNFQqao0R8q3601D39vmJoxIPM";
+ sout << "wRIiHH6qVcCMBV3oVPWIiD8T96pUZoN3V9BpyQx43MEC6ZdVFqUVQtiYu/PgWqK2pwk4Lh4vpou3";
+ sout << "Jad9ENAsLdoVKZrtHlD8k70MVFOXNXJELqJarxwDy35JAha1+dHBvcgU3q/8pOnhS6xZ6judsmXX";
+ sout << "LVBdUnKu4w/tbXgJA7lkLtwtODiXCL6edRCSkuXmTRLqXhrzUYzOVO3lZQSNWLLagHa73ce5B6Wu";
+ sout << "rVTjMEVaNC3Guz5/cOkZrveG7MTk/eKR56x7MhkDS3H9pbNiCoqO7Ub52x7j/aicE48SYYkXFTrM";
+ sout << "ETpxS8VFa1i5VpaiDxhfZKKV03S/fA0QoUcakVISiSN3VgSrhYsu9KvVjD3kjW9IMKGMlcVH040t";
+ sout << "1cs0rHtLRVh/FjpMZlQxWxhI9ZMH0Ccfzpr7SbTN0Y4WfKaO1O/JfqYkzCnFTThvkrUMAeA1+zHJ";
+ sout << "1HS44rFnT06RSAGFX6KaqrRFBOC3V3ljXEHyyrGJJe0OHWTUAwov+JzoXuB8onm10lcxUpV2Wd5R";
+ sout << "RZ3dx/u9pEYgBolzQIXVkcE0YBYmkXMJfgXK3Rp6V+1pIR9bTWUU76/0CgMHqVBKZOtpl1ykTQ90";
+ sout << "h/n2Mcjhmo9BtRIi+R7rBhr2a+oNj7GZHVPaaItTcdSvns42OxGjc8R6ZdzLHzNNPYx0NkbVeKnU";
+ sout << "FSV4qd+iYq+N5hiPeTRfwdLU1wRL84Gvm+JfR3Y6eK8dBmMGCp0skspHGgpf8xTXXpYmOuZ1FmXV";
+ sout << "YMKmrYMuTm0IigxtpmCmHPmRZ8U9xbe4Oc8UDxkkmCUvcuvMLS9SJKFOQiUJwxuXgf2IzzzePLcx";
+ sout << "ePs3jrlOoD5GzQzSvDrJRIeSxjwTrVFiFT+qwMIIAR7TyTjBOWV51+kM0+5mU9g89OwLmO4yxXes";
+ sout << "F++dusUFHv+DGSCxk2iGNwiNDQA8PktmdVKux3PpVBxWZ6P0W4Kyso6RNIXk9iAVZj59zCr9uPwK";
+ sout << "MzIZ1ASN4a7lrfvv4cD+Uxovvt1s8Kck2wmIiWg95RiLdnopqSg8qZY/3U5u6XmXCqnEFTYuKKVY";
+ sout << "dWsDem2GxTTNM3DtxFN9fYbHN6MdbuyiVRPovBz/T+9AYpx0PVaFst9WVLbf8Ph2Hs1S1vgrW5sS";
+ sout << "qM3lBgjEYl4CvDzL5HbmqBV1gTOZrzkrEdr00ChKzRURvTBn7NlvPZewMhL+Hc0SVXYtzCLSl+p4";
+ sout << "C3tIOZaI2l8JIt8k3znBFs3UrWx1WGCv3GY8HEsTa6uX984B5dVvdlUqmstnqK0FgMEkYxiPWoKV";
+ sout << "eYlx3k/sV8GChiZWY/0k15KIY5c1ZUxBQezRc6a8IwYeuCK4dcfkMtn6AUpWu3W/vY/sx4dwNNaC";
+ sout << "eWhKda7lBCMsfRdeOZvWgvFKoIknjnkh27uEJu0AskTF+xhdXdKwIfM2ThYFipLGAImNXGuHbtlN";
+ sout << "CJ4FBcpn1BFWEY6E4Q+UeOlU1rv//LgKYxgNC2GX64oLOAoPGgEOG6dTKmnad/7y42k+kmQnvLvT";
+ sout << "YNkffPIcqqBvhZCmvIg4eMR+cB5Kw5IN1tik/XNmrI2FP8y+7OhNlFah1F+slgX9oHVKoRLtJFPT";
+ sout << "e4rQDWqAnWMAoSqsITbHCq9Em1lP6Cq3anR7zsCNH79KJCsbZOZ+J9ccoLm7uLJQUl9BW4irlZH8";
+ sout << "fYZypsN59YE3kzHivG1xH0QsDLXrw4G9E2Z6CFojySfluCPiHXO9V9XxiMaU2X4uYHVb7ehs+/TX";
+ sout << "h8uPIngi2rIspfThIzUPC8nl9k8xIH1BoWhryXHkzPi4M+ZAI9afBNr6dB9qa8O5gPvcfpeWPRvB";
+ sout << "TGAw+OW3GX40CxJhqQ6n1G101LjjbFvWWdHTm94cR3cS4B2BYzfAqy9Fe2woNh8OT9520VCEMgWc";
+ sout << "4xYjf8rbQFq3kBgPDVEyu86VZCdVBaAs91vTnfd1i7tKCN8JECyr04O4qJ5IM1+5Mtv1Z6wlgT4r";
+ sout << "NmW5FaCZIKwfORsXG55kuEeKBmvbnBUd6iO+jVtIn3D3RjWCHS05cBsw5uCvixNIHXIF48lTiHpR";
+ sout << "hGITp+vcydTLUEeVRxnlKIn0RJX0Ht2PWFNNbU0RV2WFS4K8QGQxekj6i3XFsB0UWJn7ipi4eHJl";
+ sout << "CIP6hG5bQ0qG1tvhC3a+Ve4et41gI/DZu9LbSqg3oNWUPjEgFSFadjJ6ZByhJ3QTm8Pl0uVdpzra";
+ sout << "myKEafoz0tbaRgJgqAGpQwkIPgp1ocPZBplMSbI9zjBhrTlI6fXfH8aOw4FgB6F436yGuVl3/Rjn";
+ sout << "UuvOT27jH6TvUOGxWBbzXXcxU6250E9S45QSh17Ly1/pfal5GdAcemaHKLRBeCLMwMszIr5ZOM8M";
+ sout << "LxeGCXvWuMuoLIPmElzolFuKKEht5zAr+KGbhsILlLcVpQBnQRM3NxXfh6bO86V/lEOZxhy1b+Js";
+ sout << "1bQrn1Uj90QnhqDoI8t1P6aSFDWKvbsnzC30gPZ8E+FFTwsvvKlzijcoroFQ6I5rSdVehv8/YNvf";
+ sout << "1yYpXA+hIK5wbbyIAXxxmi+5OP2viwGEihe5rwkuHRTmtJJ09rHWFT2TbVQAF7775ZFNpSogY57b";
+ sout << "GJVCa9pBGN8qzsahk99Z3tjeP+fP+4F4Si8JyPJq+xmTO62ciKxHrHt9b6sE9H3N62huURtpP90g";
+ sout << "91QbEbFJU/5JSPfJjJZIp299UhafpBSUACpOnnjBJT4xhVU1E3k5x6TKdrZYzAC/dsQ7BEwZgg+W";
+ sout << "5wfLTnkJsuvMLQgS5e+Ot/EUNd5+16kkR/wFeH6vtzG0EyxkWK4V7GrKtsrLWV7jOC6yY2mGwfnB";
+ sout << "iWLdaI9Wj9pIeTk7w+PcyTZZfdDldyOF9TwYKvHAz6ea5s7sUe569Y+iDqOOXOwpjX7TMUVgyTCW";
+ sout << "4VC8ixQp+UNTdlKLoeS+2fw2QupH395hLQ1DHBcO5Tuil8Yqxxw6nV+j27THym0CaB3j0JgA/4Tq";
+ sout << "vbqsEonwHgSaP55u/j+0QRpycKoXfEkS4gh/qMhbLAPQyTNY7QLjE5cVUqzLG1twetMMcXZL2Dqf";
+ sout << "dP2uDCswsX+RJdIAN8BFoa+J+uXHgBMgbxO4DrGz8GxuSr8RfS0JeEFReLiLpE3n+SrnLwKAUr8y";
+ sout << "qWjNrqRD6XKVWpX4SY1I6wcE9/3jX0c27mGX+87XnUZdIOZbNgXmntyS8J1P4uFaEZIy2rC2ahgg";
+ sout << "1mdwALsK29JJ5QEaSg/qd9UEmm6jnWLfquCgBIkacJgjG1wc/029kL0/Br9t7tuI4jPiKKp+XRM8";
+ sout << "ZZJGwjcB0DbJPB25JQ7zp01lbnEnNiGpNkNXUQtf7a9Z/F530W0M6wG4fOl3lW1HP2zF1Ad0GJxK";
+ sout << "y9g9bj+tCsfJQfeymW/5xaC1C+9sEDOsg6hgVH6avh1a0ILiY/5MddtCHDQKQF5tgkuiKXG1+ihQ";
+ sout << "CUlFov0Q7HG9UruG+eydcYSdyfKkdTqYQ6F8PjLP/rpiypX1jiZAodzEuvZwQcZLAgp3gqV1swZw";
+ sout << "6rm2JpyS2Kqc4yCanywpHAU8LkVsjbPTanJiOgDguguVRtYGpAh4fTje2Kofw3rbPjznNu+CC6+4";
+ sout << "8QxPZtcc8USR2EShYVlE90d5uu13F+epClENFc5hww6ymXysLDCIop2vM0PuA6csUxJqf6YFZHBm";
+ sout << "c1h+7m2ULTr7DEVeRCbDkLvrXrL3zzqESvN2x7QDUW1lQRG+A+hLKS/nuibs6tI8uCZ0I3/VDQKX";
+ sout << "hAhNjD9x82FqKKyPGFikNASTaFg7I1DAHm2iHm1Rbff4I5S57UUIk5zr4o/pfQBh12XxfpLtvSfn";
+ sout << "SGXRAUcTO5jYoftuSmX9tiMfjjSXopedLWjMKSUXE6m2yggIzxfAzd60Zyr9lNMBVzeXTSkrReAU";
+ sout << "QjMxJ9zsAF/86q2lsxr4gJKNmg9QkunRbsVlh3OXE7MkxBHlb2MkULioU83Gf8LiosiEEmYSB5vY";
+ sout << "H8D6cgnXtB64BWoEJpQuG/R+zvnrWNHKmcBd3YkEi1j0dtL5G50SlMOJuai2xsD4RosbPEdV6Qx7";
+ sout << "ESWrAO1naNHp0fewaVl3d8qXh2rFZAKNZXP3O3H7h3BAoMOsxWMFNJwQ0xZshhRvASlOr3Yn3HB4";
+ sout << "fOqqLiPecX/o5r5FRa6qIfP75QaKqcu+iA23psV4/bszP8L4wREKioaSFncsUwsMU5NyZn3hPDUo";
+ sout << "kdW+ZGGQk1F6Md46VYMIvSxC0xer/NcISrDfWmae5WCoGlJGhv1rSnSjRmMrj213eEh2qpDNhz2C";
+ sout << "8dlByOZtSTOPejh42nYMZKCdBTK1lftdSeYdsHRGt51TQOauYYNjCwQkh87aW32Z5dsUi6H96k/c";
+ sout << "NTnEoTZbmyqFdmov0QyQBgIfPmIUV/z+3x9PnfuVHJr2PBdElqqgp7h0Z+Z/7ESnDV+G/0/dx6nf";
+ sout << "hfktS6/8axrd2xTKr7VrfBLvasV0dTC/GibZlQ1hvowffGgH8dgooP+Nc/S2UqhkYuMZz/Df1r5g";
+ sout << "jYMY3rLQt3RqpbZmwU3p9MGe74SVOSCLgYXIDQiGyoGGPuwRwj4uhI5f26fC0UzG/0RCe2g4dDe1";
+ sout << "vpDGatzd+3ixmyzEf6K+4UU4l7V7XAuUYs3lGEmtkz9li9mF3huDAUgOvvw+jErQrMnUGeyfzqDF";
+ sout << "xyGfB50sK3Y1q/NSxYKQ1oHsyNYvnTr/aD5OPgMzbjv1Ii1mnwtP3m/EWmdSmAsvsjEq6Dy/yBuv";
+ sout << "WZ9JO/XtxbusYXBNUxb6lU00DpJK5NGZVxv6UiQmrRqZvg531ratwkkaH0ZJSOwOzWlRAeVKiDPM";
+ sout << "mxtnqoUGNcyr1SqIHBHvhFpr2BJKEeZ6MRwNHoQYCA1dQ4l2YJtZVITuL/o0SHNq1EXC+5sWB00T";
+ sout << "UFfHRWFOYH6/baG2GJSvgzQrxBH9miEG4WzXQcdTQlmup0RZWjnwEoyHRMj4pc9cujSsTMaGktXD";
+ sout << "ZzkAX52IncX0jqq2ZE4epQnx1UHg5IXANgb0Ed0y/FQjS24SZ57uIvl396ubCO1LQ04vGH2Mvs/m";
+ sout << "bSM7jKGtcvjQ4uvVaCyk29Ek4EKJM+gUAuEj3IcKnoRNL/P7L2NEBlJuQBlIEgczCVG68kkIOFLZ";
+ sout << "CqGgbf9IaBULEjthq4Pr+B5bUYmoBGtkUDM0KewVRJ+YA6A62AAAeStbx+myMgTTqq2b4idEllZn";
+ sout << "0Hc4kztLuBgtHfwaizv9gYEUbgKbXQWfTQdUNAGSyKyeY5dQ425q+eyHezNiXZ7C+tDiM2UN+Lg9";
+ sout << "t49U+7urUylh189bDcZcGmbx10unLEwYGT0CgDPGA8DJRlkHUkXSlImHpz8+3hjRtgljVcz+kFXi";
+ sout << "jP8F2VcEoiUDRpaTwdJTi4pRsWSZF6pDEvHWntpnhtBI51zoDzEpbkHJbjRbvBA2zl0b2MqmwGXW";
+ sout << "2qDOXOfSFVUVzTr7++omJ+UXxq9awcr/LVXz2tsuMyIvREj0Fx3c9jVozDSbOPzn97QD9LNh8Pau";
+ sout << "bvPgfmMGz1Xj5f1UuKvUiCfvfP8ZZs5l08ChMadB64ipgmdWKK4adqE0ES0cTqcg+pJDDL8FgLYV";
+ sout << "susPETHqp58vc13TMcBaqMAa44/xA98rFy+KZpYNxFf+l2U1eq44OH/zdytXhOg6y8TWrKYIdX1A";
+ sout << "C91FZWXj0t5KfTahGZE5t0nP5iKl1IdXnE+dOIvIHuzrTnyjM+IiHmj1DxClG1VQXhZcwvfBzIY6";
+ sout << "MRwFxUJXmt7O55uesEbZ/anmQB6LHdy7hMuvvCgeLBhoRnnyiRJYp3OAwNH5MLy7LXK84ZKTkodw";
+ sout << "BuyRw5oJ9KOB13Buxj2yCRhLky572BT61cw0rXXkw2ZOGQgzHbO+cjikNk5LsVyMYfE49ClKM4w9";
+ sout << "6pVtXDNzSPvQK0KSO0ZMKPDAr1/MrV0PD5Y2v1JIVPOuKSjq1s2vP7ypzarrFHAJnsD1FrigYXME";
+ sout << "9amdaYhCmPIiOjj+9xSfmuBrrsHJ/THe82pg0fO3mpsF7TVY3zQCpiMDY/UqDyRrgc8NQu6A2WRc";
+ sout << "oGCRQ5QCWCTFhbIV1HQ/KXjiCeFrCd88KkZDLDwrWMqeeGDTM8eWQcTIqoMMtu/JiCrts9c8yiTj";
+ sout << "IFiHTXnpZMYwxDeT9hvNkifueXVjFYbEs6hHySwX0kcUSgtRP3fUKyQe+Iz+u/iOi1tFHQgPsV5W";
+ sout << "gltpMLE72aQmlQ9C1lYkJQoKrQJRFiJmiQn5S2hQKXurP2HOIsAYSpW4mpeeFJpxsF8lekzeniil";
+ sout << "96nrrpQlpc1vPdDhbmTY07tGoNlYRjoKTYaelkktnWA2a4n69x+kD8iBfVLANU+EFHNfCyQgmdJe";
+ sout << "fWMYQwqwgxMWk88peAftdx7bDJ1ZJQ4q0zzkot7w7NB4oe0Q66awUwj2n4DYvDAum3SRz2LwsaiR";
+ sout << "Ke8N8dMtRYN2mosZ/108MktIKlnJ08bkI6eDnO5vHlbeBfLOfEP++jGsgdS1KHrZHdIygtioQ9Zj";
+ sout << "RzWb1fKOpoIbxV4mgh5MwpliqwHmG0DIq1UPIEqj8W/Vbd1T49FhWRJs8cyYraimx9MPklhyxlGA";
+ sout << "HvHpQfuihMzLT7nNAVceWz6GzYfHy5uvje2MYBVHkoQ6RaMnJUCfFd64pbFTmsB+/+Teoo6q4I5G";
+ sout << "fbeIMh26cs37X5MwocXcq7jSunVLgPJYDlBcbpIdhDt+Uuf5VPo1GXRAk80ouu1E9mrYvlZ4sanS";
+ sout << "5EJ61PNevEOSDPPdYWwfMR0rKZuoefUCNfvy7XyXpEnZpNLHKsO7Cs6ZyL6zQykBk+KQCfdRnTHa";
+ sout << "SVSpKvmD34O1gGkrq7kSfyWO6eU0TU8HImJY1tpI8fcMTheZ24hxx2QsV6IbsM0wbxS3d2Qo15XC";
+ sout << "oP9xQxVcn17iiHUg70awt7j2VFEuCb1QjMbNnd5lUvjSJ0jH6gi+n/on65AWEJYm73Akn9i3fP+v";
+ sout << "oLTDxIOmjRZsDgzsAbgtxb0Ozu0rNutnPRH0Am9C+GGJX78X3u5qDVn0q2+qD7eQkPOmW6oQiIYq";
+ sout << "8rWM8ya2N940j+H5HsvyW0FKyrqjSSXQ5fF2LJE27Wc5ob0tVuvWAzOQ8COZWMMapu/S2C61f+2s";
+ sout << "dFiCjJQqvr3Ai3pePmVrODBGROVHow3PXekC8iQLRwOYfNHwiM1jGwPOiMBZ3v+LghoBYHhmUh0Y";
+ sout << "kZUb1j5Mncmgz3pctYS7KFhlHq1RHTDlF0pevtVLDXyy0QNXFoOqGydJaqLN2eILSl/eOGmsVlQz";
+ sout << "lgyoCzPW0YsuN7Z4m0z0CfdP7WTpL4s2jlB7QpdaF3JOaip2nVZzX9xb6qXPosQ1WF9yixcznoAZ";
+ sout << "bfWnipe5Lnxm+KHX8OUclbWUC6Twaxoq9Uo9jXuwkU6PZ5VzM5lao7D0JCPfVo7b55q1258j1AnN";
+ sout << "9JNYYcZ4ZjTEKvbXGdtb943Wmmr7CnI4IyKv7AFpeiMAYPVRgxIvn0RjWpJ6lzV7LutAjv0HXwZ7";
+ sout << "5XHtqvvuIe+IIIG7Iqj+JBW/pgQsbb7CKXiRmrWwl1MTtDIrQuh6UP6mstikBOF/d+yOSqCvU71B";
+ sout << "A/Bj8gEqiPEqB7uih1Kakz9wRraRAYo6VZphC8Q5LdkFQG45Bj7KnACHbBs7e9xe7GDrP4OVbQHh";
+ sout << "m5QCONkNpd+0umv74Adr64LDpsMr6CLBgHRrXukUXUGmFPqUCD7lC+quhu8psnZZrvSNPSa+goRp";
+ sout << "PnO4wDlSeBSkuSdyC7utuC40jQtIZrkJmvkq5UURj4VcltYQ+5KiypO7ixAK4fvKyBTgxV2WNVeY";
+ sout << "QvxNeJcNck+IPEhYLyplcLzyoekhnP4ZDD6ALG2jWZ8QoKpShA8+HdRSJoFguoERQRS7ePlMxywb";
+ sout << "kcLPkcyEOlBu7lSjCRvjvdROmLAidz2GrlWvkuLrRe6clvs+U8JfBTjOON6P5nq41ElWNA33r9KN";
+ sout << "kcCmEICyFl/Ier/YdjOUtXaHsKwHaMjmrXHQOTP7ELs7jouCbb4cTFVtgB20smKj3paJgoQysxmq";
+ sout << "12Ur3QnHaHjaoxzIQnlaD69TF4O7T6U68/C6BwQDuZiD6iB+mkCQ7Uz0zFsiVMbF+RgFJKpjN4oo";
+ sout << "SuR2gUL8sPlixeYmannuPwhCkJ1swN7NiVeUwTGFUDeQt2eJvJc2CllOpUZwF9PbyHSkxkQLs9kP";
+ sout << "9DkvDhKkqJ2aEHDlPGBjETFYjNGdoBfMIcr1yvWshl9fTGlEi8c7mxMW/5Ed4sSdV3N6qRj3KZB7";
+ sout << "zdQJ4Ql6tH7zkYWfo/ZtTEggQQAe6FzyQ12/xgFgdAulwF1IJZ/JlzY/SKePQPX0/89AooSUAl2q";
+ sout << "/BfEFEI5XJhXDp2VJevsfFAjlmf0pbC5DkyBlUwdStZAoZpvfZPXcQ1NshCuD/hT4FinYmLl7GaW";
+ sout << "LJSXkHegN09TBLWDdWe3DfKDZEWYqbIvtgt5Evgtw/5Qvy5DtefyHPq21BoIT8C1zxkgfOqEhH+l";
+ sout << "E/fLaDsROj+RUJ055ycF9Wa8mzCOZMnaqsD13eTXtEY7LQ7MEo/15S1ny5JxaGNHxy7YJ3JfrshB";
+ sout << "0pZ5vL3jvb9iYwocPnlmFpli+Md6qEdonEN6K86WwcRNqV57mMlF/EhBZXi8VnW0nqglR3mT2xTz";
+ sout << "CKiwPK1W/zOnR2K0SsQrHV5kaLCfRi4vM3Jv/bGjlgJSUYBkB1QC7jBKE+uA3H5EmspwnDKPAb+b";
+ sout << "kYXzpMWfQbvMvzpumEs+nY/xbCf2Hr1vhZAAFR1L4F49xrxDr8rjP3DryRgtBquQH6/5qMwdg9pR";
+ sout << "4ACieRQts0S8jC77fxLAFW2GqifC5DBhwIcdhOhxR1FGLo5RzNktZkPob/fXEcC6u/Uz4v4Gy0j/";
+ sout << "ZM783wye39lB2eymWQ5XPGC9FKR+ZLodyKJK+NXDBiFBXOjY4WNKEghllE7jAEU6VVmAN8NDkQdf";
+ sout << "u6MiVlBZLK1cepTPDj6G/TGH2FK0I4O/z+Gb2WQdZ2VUObGhf3GpYNR4XEoBsJe0I8F+TKpHYpdY";
+ sout << "guyaXCEMW9XpK0hTOBXfaooIQu4+Yb9e18fKnkrW74Aj45mdseL80RH2RJGm8LeZa/Nluqm71cqu";
+ sout << "9zJkwZjxlfICDgLhkOXtt/Lihzoqav/abehnuZjlyWwJ34eIhCmEv9D7C0P9VE/7B4P1S/demat+";
+ sout << "W+rmgFScowUvs1lfxNr1I5uCyLEJ9VphPUXnykyJ0XM6zm2Za8jD17Nh82k/eCX+JFMWv5znDNQe";
+ sout << "aDzkIRjTQhXWt4e6GNZ5zU6g9rr/TIePeMJyRE4D09xOF6orHWMBh/TFAL6PQnJaI6RUqCg8le0+";
+ sout << "ySd9qWCRpZsYBK+bZBEV9A1iJyvhLyKLgSGC8Q5Y0tQmE5zWSlYkRr4YtfASRn7e4l+Hz+hhcReD";
+ sout << "F6s8jWCAHjsq4uZi8qf17LwZu2xpbSLmgCE1sij59Mli4sWlPDkX86ehGqvH2D2x3hKlXco1v+q5";
+ sout << "dbNe/zvwcng5KoCwL0pmT3itlXfQs08+1lXbpoqUf2wqmQFEi3/TIWq+1lF/zV+6ICVx1d1eKarh";
+ sout << "ZgdmD0C3QLlPNlUAG1j6qWQ4K32ICyKu2XWRjcaElm8uRxPBBJ2DKWJFyDwP8caoB80A2ApnnFCH";
+ sout << "mrijgRO7LeQ6bQJJjOBSjeppPm4jTsVd5tlAefLV25uRwx4Reif0e4x+24fxtPj9udHRWggwGu/1";
+ sout << "Zs4+wsJ7KmX2ekZ2LR5aotwtBCo8X1J9GaHal4WCzh9G7EqAeSjjrD63DoSgJx8PAUoH32QXUMxw";
+ sout << "b6JR8Czkimh47Uv+aIqZsr6GDskW7xfxxgmzIPwSA0cJkulJKiOYEmjcfgzqaKHrcjMRb4LhJRTE";
+ sout << "F7Xgdk4nJGT2hRv3Abqz4725TLPUsfd1DFTM8BJuJSWjGNa2cyQ6f3CTQPEZYFiyYMCtd4Ycytnq";
+ sout << "N4djAycInnDJV/+XSjC94UB04UDvkq5lds+tgGMDf5YWClAz5EO0ztmL5PIQrA7qx1ykF7xiRHVB";
+ sout << "Jy1Ln0oTEV9yf8jlIYILg3V/j716p61dkyJIvRtO6XhFdyE+4ia0WdXaNWNwvx4wkzqRHD4IxMD0";
+ sout << "imBJe7IIMVztDzJ7T4K61P7nzwgGJgXHwfdh4xxYZC+4Qf7LVyGOeH30AgAiZBzEFu6kCHmPMxgZ";
+ sout << "VQIAvbW0/84nBgbqTvUTQTVytf9ufnhuVZSJfXbMlFGTxeCtHG9C6rgAAk2QNPHb8up8NL4ZnVcA";
+ sout << "AWk0OHyU1DZrLFb26nmHBxwVZySrSppYbldjEYxMdewD2OIYBYrCBCKIJzZpTMP8D7QOIX7B05Vq";
+ sout << "zPH2i9cAAAWjZjAkQxG25u2KhBSufKHwxVZ9xgi82dPIUZJRMhNOyUQ3uVhK4MilNqxi3Giw2KtP";
+ sout << "lO1IWReKc0m1pXxv70r03E0i7kf19iQ6AAAABRagQbu/yzAW7Stqw18BH3pVGZSoj8no2Sn/A9QF";
+ sout << "E49JB+qp6BejcC0vGwH+ogKo7pk9fecrkHEc/6aH1Dt7ciMaDXfFII5rE9AAAAAB5kmFhDYODCNn";
+ sout << "qZy5bb542mHiW4BxDLW3mQ4uKyXFV56IuntOj0rsjRZdrOATffLebKWyn1anGbGdQvJcOZcHvV6h";
+ sout << "1Aih5pmFfPxyhA2ba90hh+kPYjhT2QAAAIYcLHepOiMlf+gULTiFNjbeyaz5vgWcGTiwkv4q4ED3";
+ sout << "ZLE0lb+l/7qt620m8kgPvUAOLnK6ysWdLpoTtevH7MY5LafF3R/YxEPOo9xgr7zDLrMKLkZILeIb";
+ sout << "yegmBAnGeWonZpv3eidKXsbBFhHGw7DoSakVKAT7Jahlmyk6howH8x8VEVpK1QNbaAXt/pFXGrzr";
+ sout << "dAO6LjOcrwxvW8iYPjPChEHqwYloPizhrN7k0u0F/oq/urAZxyJ/MP8Tx9afuHY4aocW8FIGmN8f";
+ sout << "NSozaWDqqw209VkydVcnibQweQucq14zNKxGZgvnuw3SKeLPAc40Z3fSWsEfG/DImcOJ6QdmY8S8";
+ sout << "34MKBXUjR8W0w1P0gs407lSQFRZBEmiOpU3HW+A4BisSe4fcFlzFt+lhP9VNDr7lBHpbAqkg1Ob4";
+ sout << "8/uu5w5qut8RiDJp2dYQFY1OieuY1lEN1BFYl4hoUuBQqGBpq3MDn6V40ZzXHl2Rbl+u2tXhNICL";
+ sout << "FPOFo0Y+7b0ppLIQGPAd5fBGTfMnibWFP9Bw8KDo9wUbvGslyV0ny303hYXO4TzIliPUUu+nsJRc";
+ sout << "pBCeJdyRvQUjr6BjCXrNVSldiB6oCr7e0AfMBU1BErZAhzseZYzzlIavra3tEOdW4xweFxOO1JNp";
+ sout << "BqoG5a9gozEd6VQvhurdttGa5jsSm/tEmGDZl0t14nN18hjVW6KCb0u3+kCbcCRRcFz6tZ3+eomJ";
+ sout << "TQIBUJu40uNzrLmur+zJyT7+JRRIq/xXSv2R1oJhIiYGe9P99wNd13TuTi5Oyx1b3hHJdysPjIt1";
+ sout << "/I8UGCpIQ/FCIEEplSCkhDgedfL9OtDOnU/+bI2cipB8tWSwpwbczIslv49e3+xIqBzLf/4gBDd2";
+ sout << "4ZYwe2lHWGL7OXGYCAQbC9ELg2KuRDpLrG2ad4fhUuAnizCLb+q+3Nd7qgkPXweeEMAbSHhp0vuN";
+ sout << "XSbVB+cDaHkw7DxxJR0nn4IhdbkuQ54G975t+Wo5uPunpyZ5QvDXijG8NEhIAytpMmhXjkiyUbk/";
+ sout << "XmaEnaSTOxOEVh/fBjeGfxHmjn0TgkQlw7GLHa1rB6Lp3y7hSWoZW0RSn0TmVQbWBFHen0cfhsp+";
+ sout << "zc7gTzMeFxhMf3ZmBS9pmigcVHPAmU+rOLTtAzaklKumgxf0B+Su8YLZpcvc6Jy8t6rJgCYwIhR+";
+ sout << "BiKEN1l6CxFeQtJp+2Do6oEZ4Lfk0DoAnJdV8l5DhkiXiuwX4zPFiKZQtKSBsrp7bq6dqjDphbgL";
+ sout << "1A6JahmxkTYcJVOnVWG6LvfHYCIJ6hwDV4J23P+tNZtLCzfHPnDDdyJCmBExsOV1xXzNeGhqZhb3";
+ sout << "kSWXntkMAstuTSdgb8vPGYSyEMsTx+pTeXFGv3TN+YCdQTQIovV2f0KTZMxmjoTmlZKueVeFwzSP";
+ sout << "S+G65EMEcgQUGloWsFPdIaRzPLYHUktOcJc4bqEDvwIUDLqfjH4zr7kgF6i0KfRUMP4Jrje78S9G";
+ sout << "C7o9XJLGWbS32yrdKVlIG+9f+2YQRrZECzhex6WjV69/JSe0jbilwXwsJSKtY6jFik8eE6RCZqN9";
+ sout << "c1vOxMcd+5twMMkmnu4yvBEV6Y207UtegGSHaESpEW1MTvbDk/75X46WMSECbeEe8lzlmNaaWyWY";
+ sout << "C+siHamjyTgHAvJLEwiNlQ26LS3mJr4HmpOVizbTykoe2WGV/jmUrhhUg2AN0CrS2LqriZolW/JX";
+ sout << "jJeehNlJ9Jy4e4NZE2Xn/Q5UC3zVJm3bpbC3JKre7sALksQSQInLf7OoYArfwgYmAdMWhgxRakIv";
+ sout << "0M0acqOuNUxmW0k4YEzEncfjSAQ3cAUkaL1XS55I5vm3o1RPiKBjjypzdi8GI7migOK5dRrIctP4";
+ sout << "U0S4L6uojFipn9VGhIjGiv0Rg6OcaZO2FDq9SOGmLZ4b92sE7v4rtEt8oxlbbxPdgT6CLQkVCCJh";
+ sout << "uvjCIeAmV0wwhaKYzE2TpSkPFrtCPWhBjSulx5OwEGvdgzUq1NAnHBqlc6Oazt+6Hb/b+zpziei2";
+ sout << "huprTtScOqJcYzqJYGIwYebjxpVM0f/05a3LodBUWUNUIRNyzJpyEZluHCupDqfEgF7ObUDvtTDO";
+ sout << "Zl1aCRWbztt2W0JE7SSCAZbzY+3Dq+m0VUJj1L7NZkg5py0mUHHnz2lDveZjtHr9UBqETzN2xCuw";
+ sout << "8LQqbhDKR5I5ThHilGjZjxSN4VUKgBiBNUHXu8D2VN9RHr9xjQzmnSOJPHtcFuf9ioGO43OVtV+I";
+ sout << "DvE3YjEt53F37uZBmvTQD3zcBvgRoNF24j6jjTeHE3I/MROKotCIBkxk5AZTF9Awzha4XwP3xOAI";
+ sout << "FyhCle92a3102sdU/azu+1n3Jq0ZeifCSicBjrAHgQpVM7afn/yha/YKXqiwingX4pvrN6IRKADv";
+ sout << "edJG6C4OttSovas7ayKdaTWURKwQWQ2NQF/24DUEmfGkWgk3e1uFjfdBNFez7/omGkRqvSAoT5pe";
+ sout << "rXpbWzb7nigtdMRUdVMdzAHooSObROsbHc7ngUn6sJ6rD88EJZH94kk3adWN6plDDTyRe/eu5Ah0";
+ sout << "XnBPmYIgVSYSZ/X6BvS5AZDElImmABuvuLEjOQ3Ydqc1NQIGOgMzWhn/Dc94YMxSOdmHCo5RyGjo";
+ sout << "Zn0wntZDa+gKaR8PwzsIi+Q7RvReVKP4xKNPbI3W47v3sh1PeidJXt75OoHyKpoxJJcjwB8me4iZ";
+ sout << "ctKBz0QZ5ZBuywtn2Lq69I5nmjHQbpIB+zmPqIxEJOworB7Qg6thkda97KcHsRbMJAA=";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin, sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin, sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ };
+
+ correlation_tracker_tester a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/test/crc32.cpp b/ml/dlib/dlib/test/crc32.cpp
new file mode 100644
index 000000000..32f67ed3a
--- /dev/null
+++ b/ml/dlib/dlib/test/crc32.cpp
@@ -0,0 +1,74 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <cmath>
+#include <dlib/crc32.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.crc32");
+
+
+ class crc32_tester : public tester
+ {
+ public:
+ crc32_tester (
+ ) :
+ tester ("test_crc32",
+ "Runs tests on the crc32 component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ DLIB_TEST(crc32("davis").get_checksum() == 0x0445527C);
+
+ crc32 c, c2;
+ DLIB_TEST(c.get_checksum() == 0);
+ c.add("davis");
+ DLIB_TEST(c.get_checksum() == 0x0445527C);
+ DLIB_TEST(c2.get_checksum() == 0);
+ c2 = c;
+ DLIB_TEST(c2.get_checksum() == 0x0445527C);
+ crc32 c3(c);
+ DLIB_TEST(c3.get_checksum() == 0x0445527C);
+ c.add('a');
+ c2.add('a');
+ c3.add('a');
+ DLIB_TEST(c.get_checksum() == 0xB100C606);
+ DLIB_TEST(c2.get_checksum() == 0xB100C606);
+ DLIB_TEST(c3.get_checksum() == 0xB100C606);
+
+
+ crc32::kernel_1a cold;
+ DLIB_TEST(cold.get_checksum() == 0);
+ cold.add("davis");
+ DLIB_TEST(cold.get_checksum() == 0x0445527C);
+
+ c.clear();
+ DLIB_TEST(c.get_checksum() == 0);
+ c.add("davis");
+ DLIB_TEST(c.get_checksum() == 0x0445527C);
+
+ std::vector<char> buf;
+ for (int i = 0; i < 4000; ++i)
+ buf.push_back(i);
+ DLIB_TEST(crc32(buf) == 492662731);
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/create_iris_datafile.cpp b/ml/dlib/dlib/test/create_iris_datafile.cpp
new file mode 100644
index 000000000..1e19d2aac
--- /dev/null
+++ b/ml/dlib/dlib/test/create_iris_datafile.cpp
@@ -0,0 +1,65 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <fstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+
+namespace
+{
+ // This function returns the contents of the file 'iris.scale'
+ const std::string get_decoded_string()
+ {
+ dlib::base64::kernel_1a base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'iris.scale' we want to decode and return.
+ sout << "MU66cCmT9lCXWJXwhdfOGELwlyExClbHEF1s9XoqxNDV7o8AdVVHws/C9oIKO5EShH1lI/QFTWk3";
+ sout << "8EUdVpSw/NpZCUa7O9nq5uO6SE0gfRAyryH+pfIVL9jPiQi8rBdagDf4kUd4eggz9glwYnKEE+US";
+ sout << "K4GUBnW33YDf/jMF2GIBLNvz69yGJj8RC5rOUeJxR4mlHrDmnEfgRSdFIfXk4OQ4V/XbOsE1bnhG";
+ sout << "fmACcu7nYv6M/043Z6o8oaBeoJ2XK/9UqOWGFOwfVpQ46fz1a0oTlOzyDbbzMiniLr8z5P/VYwYd";
+ sout << "iAE70MwxHXs6Ga3zMmD/h1WxB/uRRph39B1lPN1UXC7U6SIatmtGWY+JYpwBk6raAnR3sblTFBNs";
+ sout << "UdPW+1a7AxinR0NZO6YEiCFy8lbpfPRZNAr5ENqPbD2DZtkHk3L8ARxSoFBgqPa8aO3fFow7rVxF";
+ sout << "xIJ2TxcHS84+BtH7KvtWfH7kUPOZLQ+Ohqghn9I57IeMl7E3aoTRTiVv3P2twAbP5Y+ZaAUoU7CK";
+ sout << "c9FptjKgMClUkuWxA7tGUEp069PqGT8NbI+yxorh/iVhkVhuGAzgjjXYS/D26OGj4bzF6mtRbnms";
+ sout << "Y2OYlF7QqhawZaHLtmZ6xLhR2F8p/0nrbpAz2brQLNKgQAMvU9rTZ0XYpuJNbRSsARkRDorPopDO";
+ sout << "kKNUORfkh2zfIytVToQ9tZ9W2LkfGZdWjJu/wEKjPDAU55q3bCfKOUk12tjq0sq/7qjUWJRcLSCu";
+ sout << "bqo8EzaKJj3cTXVgXXLHP6WEOPZ9vShuxQUu1JWkh8YEinjwFSyA6UnAKqPtN/HsBgv8YbnfnY/q";
+ sout << "e5JvUYWbs3Lk9enlhcI0vEVTV5f0GMjdkW87l3cWgmXJqiljJDREWEdKZJQ0rGBU/gW5kO3SAS1W";
+ sout << "OETVJG2kJD8Ib7hT15Mu2lOVNQYFri6O3yWtp5/NLHsYXoDKIYrxoJtM9+GkprVwRuhDcwxE+eQa";
+ sout << "pp5nC8qj38ameQHaJR2hJCuW2nvr4Wwm0ploF00ZP9cS9YznCO52cueUQX0+zil7bU++jghqSGP5";
+ sout << "+JyRzWUWWbDhnCyanej2Y3sqfZ3o2kuUjaAgZFz5pLqK64uACjztp4bQFsaMRdc+OCV2uItqoaRg";
+ sout << "a6u7/VrvS+ZigwcGWDjXSKev334f8ZqQQIR5hljdeseGuw7/5XySzUrgc8lCOvMa0pKNn9Nl8W/W";
+ sout << "vbKz1VwA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+}
+
+namespace dlib
+{
+ void create_iris_datafile (
+ )
+ {
+ std::ofstream fout("iris.scale");
+ fout << get_decoded_string();
+ }
+}
+
diff --git a/ml/dlib/dlib/test/create_iris_datafile.h b/ml/dlib/dlib/test/create_iris_datafile.h
new file mode 100644
index 000000000..805291f07
--- /dev/null
+++ b/ml/dlib/dlib/test/create_iris_datafile.h
@@ -0,0 +1,19 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CREATE_IRIS_DAtAFILE_Hh_
+#define DLIB_CREATE_IRIS_DAtAFILE_Hh_
+
+namespace dlib
+{
+ void create_iris_datafile (
+ );
+ /*!
+ ensures
+ - Creates a local file called iris.scale that contains the
+ 150 samples from the 3-class Iris dataset from the UCI
+ repository. The file will be in LIBSVM format (it was
+ originally downloaded from the LIBSVM website).
+ !*/
+}
+
+#endif // DLIB_CREATE_IRIS_DAtAFILE_Hh_
diff --git a/ml/dlib/dlib/test/cublas.cpp b/ml/dlib/dlib/test/cublas.cpp
new file mode 100644
index 000000000..3f3cb09dc
--- /dev/null
+++ b/ml/dlib/dlib/test/cublas.cpp
@@ -0,0 +1,198 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../dnn/tensor_tools.h"
+
+#include "tester.h"
+
+// We only do these tests if CUDA is available to test in the first place.
+#ifdef DLIB_USE_CUDA
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.cublas");
+
+
+ void test_inv()
+ {
+ tt::tensor_rand rnd;
+ dlib::tt::inv tinv;
+ dlib::cuda::inv cinv;
+ resizable_tensor minv1, minv2;
+ for (int n = 1; n < 20; ++n)
+ {
+ print_spinner();
+ resizable_tensor m(n,n);
+ rnd.fill_uniform(m);
+
+ tinv(m, minv1);
+ cinv(m, minv2);
+ matrix<float> mref = inv(mat(m));
+ DLIB_TEST_MSG(mean(abs(mref-mat(minv1)))/mean(abs(mref)) < 1e-5, mean(abs(mref-mat(minv1)))/mean(abs(mref)) <<" n: " << n);
+ DLIB_TEST_MSG(mean(abs(mref-mat(minv2)))/mean(abs(mref)) < 1e-5, mean(abs(mref-mat(minv2)))/mean(abs(mref)) <<" n: " << n);
+ }
+ }
+
+
+ class cublas_tester : public tester
+ {
+ public:
+ cublas_tester (
+ ) :
+ tester ("test_cublas",
+ "Runs tests on the cuBLAS bindings.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_inv();
+ {
+ resizable_tensor a(4,3), b(3,4), c(3,3);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+trans(mat(a))*trans(mat(b));
+
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+ cuda::gemm(2, c, 1, a, true, b, true);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(4,3), b(4,3), c(3,3);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+trans(mat(a))*mat(b);
+
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+ cuda::gemm(2, c, 1, a, true, b, false);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(3,4), b(3,4), c(3,3);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+mat(a)*trans(mat(b));
+
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+ cuda::gemm(2, c, 1, a, false, b, true);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(3,4), b(3,4), c(3,3);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = mat(c)+mat(a)*trans(mat(b));
+
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+ cuda::gemm(1, c, 1, a, false, b, true);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(3,4), b(4,3), c(3,3);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+mat(a)*mat(b);
+
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+ cuda::gemm(2, c, 1, a, false, b, false);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(3,4), b(4,3), c(3,3);
+
+ c = std::numeric_limits<float>::infinity();
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+ a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device();
+
+ matrix<float> truth = mat(a)*mat(b);
+
+ cuda::gemm(0, c, 1, a, false, b, false);
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(3,4), b(4,4), c(3,4);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+mat(a)*mat(b);
+
+ cuda::gemm(2, c, 1, a, false, b, false);
+ DLIB_TEST(get_rect(truth) == get_rect(mat(c)));
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(4,3), b(4,4), c(3,4);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+trans(mat(a))*mat(b);
+
+ cuda::gemm(2, c, 1, a, true, b, false);
+ DLIB_TEST(get_rect(truth) == get_rect(mat(c)));
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(4,3), b(4,5), c(3,5);
+
+ c = 1;
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = 2*mat(c)+trans(mat(a))*mat(b);
+
+ cuda::gemm(2, c, 1, a, true, b, false);
+ DLIB_TEST(get_rect(truth) == get_rect(mat(c)));
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ {
+ resizable_tensor a(4,3), b(4,5), c(3,5);
+
+ c = std::numeric_limits<float>::infinity();
+ a = matrix_cast<float>(gaussian_randm(a.num_samples(),a.size()/a.num_samples()));
+ b = matrix_cast<float>(gaussian_randm(b.num_samples(),b.size()/b.num_samples()));
+
+ matrix<float> truth = trans(mat(a))*mat(b);
+
+ cuda::gemm(0, c, 1, a, true, b, false);
+ DLIB_TEST(get_rect(truth) == get_rect(mat(c)));
+ DLIB_TEST(max(abs(truth-mat(c))) < 1e-6);
+ }
+ }
+ } a;
+
+}
+
+#endif // DLIB_USE_CUDA
+
diff --git a/ml/dlib/dlib/test/data_io.cpp b/ml/dlib/dlib/test/data_io.cpp
new file mode 100644
index 000000000..8ced88008
--- /dev/null
+++ b/ml/dlib/dlib/test/data_io.cpp
@@ -0,0 +1,227 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/data_io.h>
+#include <dlib/sparse_vector.h>
+#include "create_iris_datafile.h"
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.data_io");
+
+
+ class test_data_io : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ test_data_io (
+ ) :
+ tester (
+ "test_data_io", // the command line argument name for this test
+ "Run tests on the data_io stuff.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+ template <typename sample_type>
+ void run_test()
+ {
+ print_spinner();
+
+ typedef typename sample_type::value_type::second_type scalar_type;
+
+ std::vector<sample_type> samples;
+ std::vector<scalar_type> labels;
+
+ load_libsvm_formatted_data("iris.scale",samples, labels);
+ save_libsvm_formatted_data("iris.scale2", samples, labels);
+
+ DLIB_TEST(samples.size() == 150);
+ DLIB_TEST(labels.size() == 150);
+ DLIB_TEST(max_index_plus_one(samples) == 5);
+ fix_nonzero_indexing(samples);
+ DLIB_TEST(max_index_plus_one(samples) == 4);
+
+ load_libsvm_formatted_data("iris.scale2",samples, labels);
+
+ DLIB_TEST(samples.size() == 150);
+ DLIB_TEST(labels.size() == 150);
+
+ DLIB_TEST(max_index_plus_one(samples) == 5);
+ fix_nonzero_indexing(samples);
+ DLIB_TEST(max_index_plus_one(samples) == 4);
+
+ one_vs_one_trainer<any_trainer<sample_type,scalar_type>,scalar_type> trainer;
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+ trainer.set_trainer(krr_trainer<kernel_type>());
+
+ randomize_samples(samples, labels);
+ matrix<double> cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4);
+
+ dlog << LINFO << "confusion matrix: \n" << cv;
+ const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
+ dlog << LINFO << "cv accuracy: " << cv_accuracy;
+ DLIB_TEST(cv_accuracy > 0.97);
+
+
+
+
+ {
+ print_spinner();
+ typedef matrix<scalar_type,0,1> dsample_type;
+ std::vector<dsample_type> dsamples = sparse_to_dense(samples);
+ DLIB_TEST(dsamples.size() == 150);
+ DLIB_TEST(dsamples[0].size() == 4);
+ DLIB_TEST(max_index_plus_one(dsamples) == 4);
+
+ one_vs_one_trainer<any_trainer<dsample_type,scalar_type>,scalar_type> trainer;
+
+ typedef linear_kernel<dsample_type> kernel_type;
+ trainer.set_trainer(rr_trainer<kernel_type>());
+
+ cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4);
+
+ dlog << LINFO << "dense confusion matrix: \n" << cv;
+ const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
+ dlog << LINFO << "dense cv accuracy: " << cv_accuracy;
+ DLIB_TEST(cv_accuracy > 0.97);
+ }
+
+ }
+
+
+ void test_sparse_to_dense()
+ {
+ {
+ std::map<unsigned long, double> temp;
+
+ matrix<double,0,1> m, m2;
+
+ m = sparse_to_dense(m);
+ DLIB_TEST(m.size() == 0);
+ m.set_size(2,1);
+ m = 1, 2;
+ m2 = sparse_to_dense(m);
+ DLIB_TEST(m == m2);
+ m2 = sparse_to_dense(m,1);
+ DLIB_TEST(m2.size() == 1);
+ DLIB_TEST(m2(0,0) == 1);
+ m2 = sparse_to_dense(m,0);
+ DLIB_TEST(m2.size() == 0);
+
+ temp[3] = 2;
+ temp[5] = 4;
+ m2 = sparse_to_dense(temp);
+ m.set_size(6);
+ m = 0,0,0,2,0,4;
+ DLIB_TEST(m2 == m);
+
+ m2 = sparse_to_dense(temp, 5);
+ m.set_size(5);
+ m = 0,0,0,2,0;
+ DLIB_TEST(m2 == m);
+
+ m2 = sparse_to_dense(temp, 7);
+ m.set_size(7);
+ m = 0,0,0,2,0,4,0;
+ DLIB_TEST(m2 == m);
+
+ std::vector<std::vector<std::pair<unsigned long,double> > > vects;
+
+ std::vector<std::pair<unsigned long,double> > v;
+ v.push_back(make_pair(5,2));
+ v.push_back(make_pair(3,1));
+ v.push_back(make_pair(5,2));
+ v.push_back(make_pair(3,1));
+ v = make_sparse_vector(v);
+ vects.push_back(v);
+ vects.push_back(v);
+ vects.push_back(v);
+ vects.push_back(v);
+ DLIB_TEST(max_index_plus_one(v) == 6);
+ m2 = sparse_to_dense(v);
+ m.set_size(6);
+ m = 0,0,0,2,0,4;
+ DLIB_TEST_MSG(m2 == m, m2 << "\n\n" << m );
+
+ m2 = sparse_to_dense(v,7);
+ m.set_size(7);
+ m = 0,0,0,2,0,4,0;
+ DLIB_TEST(m2 == m);
+
+ m2 = sparse_to_dense(v,5);
+ m.set_size(5);
+ m = 0,0,0,2,0;
+ DLIB_TEST(m2 == m);
+
+ v.clear();
+ m2 = sparse_to_dense(v);
+ DLIB_TEST(m2.size() == 0);
+
+
+ std::vector<matrix<double,0,1> > mvects = sparse_to_dense(vects);
+ DLIB_TEST(mvects.size() == 4);
+ m.set_size(6);
+ m = 0,0,0,2,0,4;
+ DLIB_TEST(mvects[0] == m);
+ DLIB_TEST(mvects[1] == m);
+ DLIB_TEST(mvects[2] == m);
+ DLIB_TEST(mvects[3] == m);
+
+
+ mvects = sparse_to_dense(vects, 7);
+ DLIB_TEST(mvects.size() == 4);
+ m.set_size(7);
+ m = 0,0,0,2,0,4,0;
+ DLIB_TEST(mvects[0] == m);
+ DLIB_TEST(mvects[1] == m);
+ DLIB_TEST(mvects[2] == m);
+ DLIB_TEST(mvects[3] == m);
+
+ mvects = sparse_to_dense(vects, 5);
+ DLIB_TEST(mvects.size() == 4);
+ m.set_size(5);
+ m = 0,0,0,2,0;
+ DLIB_TEST(mvects[0] == m);
+ DLIB_TEST(mvects[1] == m);
+ DLIB_TEST(mvects[2] == m);
+ DLIB_TEST(mvects[3] == m);
+
+ }
+ }
+
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ create_iris_datafile();
+
+ test_sparse_to_dense();
+
+ run_test<std::map<unsigned int, double> >();
+ run_test<std::map<unsigned int, float> >();
+ run_test<std::vector<std::pair<unsigned int, float> > >();
+ run_test<std::vector<std::pair<unsigned long, double> > >();
+ }
+ };
+
+ test_data_io a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/directed_graph.cpp b/ml/dlib/dlib/test/directed_graph.cpp
new file mode 100644
index 000000000..b97976b96
--- /dev/null
+++ b/ml/dlib/dlib/test/directed_graph.cpp
@@ -0,0 +1,541 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/directed_graph.h>
+#include <dlib/graph.h>
+#include <dlib/graph_utils.h>
+#include <dlib/set.h>
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything inside this file "private"
+// so that everything you declare will have static linkage. Thus we won't have any multiply
+// defined symbol errors coming out of the linker when we try to compile the test suite.
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.directed_graph");
+
+ template <
+ typename directed_graph
+ >
+ void directed_graph_test (
+ )
+ /*!
+ requires
+ - directed_graph is an implementation of directed_graph/directed_graph_kernel_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on directed_graph for compliance with the specs
+ !*/
+ {
+ print_spinner();
+
+ COMPILE_TIME_ASSERT(is_directed_graph<directed_graph>::value == true);
+ directed_graph a, b;
+ dlib::set<unsigned long>::compare_1b_c s;
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ DLIB_TEST(a.number_of_nodes() == 0);
+
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+
+ a.set_number_of_nodes(5);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+ DLIB_TEST(graph_is_connected(a) == false);
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+ DLIB_TEST(a.number_of_nodes() == 5);
+
+ for (int i = 0; i < 5; ++i)
+ {
+ a.node(i).data = i;
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ a.remove_node(1);
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+
+
+ // make sure that only the number with data == 1 was remove
+ int count = 0;
+ for (int i = 0; i < 4; ++i)
+ {
+ count += a.node(i).data;
+ DLIB_TEST(a.node(i).number_of_children() == 0);
+ DLIB_TEST(a.node(i).number_of_parents() == 0);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ DLIB_TEST(count == 9);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+
+ a.add_edge(1,1);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == true);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == true);
+
+ a.add_edge(1,2);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == true);
+
+ DLIB_TEST(a.node(1).number_of_children() == 2);
+ DLIB_TEST(a.node(1).number_of_parents() == 1);
+ DLIB_TEST_MSG(a.node(1).parent(0).index() == 1,"");
+
+ DLIB_TEST_MSG(a.node(1).child(0).index() + a.node(1).child(1).index() == 3,"");
+ DLIB_TEST(a.node(2).number_of_children() == 0);
+ DLIB_TEST(a.node(2).number_of_parents() == 1);
+ DLIB_TEST(a.node(2).index() == 2);
+
+ int val = a.node(1).data;
+ a.remove_node(1);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+
+ DLIB_TEST(a.number_of_nodes() == 3);
+
+ count = 0;
+ for (int i = 0; i < 3; ++i)
+ {
+ count += a.node(i).data;
+ DLIB_TEST(a.node(i).number_of_children() == 0);
+ DLIB_TEST(a.node(i).number_of_parents() == 0);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+ DLIB_TEST(count == 9-val);
+
+
+ val = a.add_node();
+ DLIB_TEST(val == 3);
+ DLIB_TEST(a.number_of_nodes() == 4);
+
+ for (int i = 0; i < 4; ++i)
+ {
+ a.node(i).data = i;
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ for (int i = 0; i < 4; ++i)
+ {
+ DLIB_TEST(a.node(i).data == i);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ a.add_edge(0, 1);
+ a.add_edge(0, 2);
+ DLIB_TEST(graph_is_connected(a) == false);
+ a.add_edge(1, 3);
+ DLIB_TEST(graph_is_connected(a) == true);
+ a.add_edge(2, 3);
+ DLIB_TEST(graph_is_connected(a) == true);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+
+ DLIB_TEST(a.has_edge(0, 1));
+ DLIB_TEST(a.has_edge(0, 2));
+ DLIB_TEST(a.has_edge(1, 3));
+ DLIB_TEST(a.has_edge(2, 3));
+
+ DLIB_TEST(!a.has_edge(1, 0));
+ DLIB_TEST(!a.has_edge(2, 0));
+ DLIB_TEST(!a.has_edge(3, 1));
+ DLIB_TEST(!a.has_edge(3, 2));
+
+ DLIB_TEST(a.node(0).number_of_parents() == 0);
+ DLIB_TEST(a.node(0).number_of_children() == 2);
+
+ DLIB_TEST(a.node(1).number_of_parents() == 1);
+ DLIB_TEST(a.node(1).number_of_children() == 1);
+ DLIB_TEST(a.node(1).child(0).index() == 3);
+ DLIB_TEST(a.node(1).parent(0).index() == 0);
+
+ DLIB_TEST(a.node(2).number_of_parents() == 1);
+ DLIB_TEST(a.node(2).number_of_children() == 1);
+ DLIB_TEST(a.node(2).child(0).index() == 3);
+ DLIB_TEST(a.node(2).parent(0).index() == 0);
+
+ DLIB_TEST(a.node(3).number_of_parents() == 2);
+ DLIB_TEST(a.node(3).number_of_children() == 0);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+ a.remove_edge(0,1);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+
+ DLIB_TEST(!a.has_edge(0, 1));
+ DLIB_TEST(a.has_edge(0, 2));
+ DLIB_TEST(a.has_edge(1, 3));
+ DLIB_TEST(a.has_edge(2, 3));
+
+ DLIB_TEST(!a.has_edge(1, 0));
+ DLIB_TEST(!a.has_edge(2, 0));
+ DLIB_TEST(!a.has_edge(3, 1));
+ DLIB_TEST(!a.has_edge(3, 2));
+
+
+ DLIB_TEST(a.node(0).number_of_parents() == 0);
+ DLIB_TEST(a.node(0).number_of_children() == 1);
+
+ DLIB_TEST(a.node(1).number_of_parents() == 0);
+ DLIB_TEST(a.node(1).number_of_children() == 1);
+ DLIB_TEST(a.node(1).child(0).index() == 3);
+
+ DLIB_TEST(a.node(2).number_of_parents() == 1);
+ DLIB_TEST(a.node(2).number_of_children() == 1);
+ DLIB_TEST(a.node(2).child(0).index() == 3);
+ DLIB_TEST(a.node(2).parent(0).index() == 0);
+
+ DLIB_TEST(a.node(3).number_of_parents() == 2);
+ DLIB_TEST(a.node(3).number_of_children() == 0);
+
+ for (int i = 0; i < 4; ++i)
+ {
+ DLIB_TEST(a.node(i).data == i);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+
+
+ swap(a,b);
+
+ DLIB_TEST(a.number_of_nodes() == 0);
+ DLIB_TEST(b.number_of_nodes() == 4);
+ DLIB_TEST(b.node(0).number_of_parents() == 0);
+ DLIB_TEST(b.node(0).number_of_children() == 1);
+
+ DLIB_TEST(b.node(1).number_of_parents() == 0);
+ DLIB_TEST(b.node(1).number_of_children() == 1);
+ DLIB_TEST(b.node(1).child(0).index() == 3);
+
+ DLIB_TEST(b.node(2).number_of_parents() == 1);
+ DLIB_TEST(b.node(2).number_of_children() == 1);
+ DLIB_TEST(b.node(2).child(0).index() == 3);
+ DLIB_TEST(b.node(2).parent(0).index() == 0);
+
+ DLIB_TEST(b.node(3).number_of_parents() == 2);
+ DLIB_TEST(b.node(3).number_of_children() == 0);
+ b.node(0).child_edge(0) = static_cast<unsigned short>(b.node(0).child(0).index()+1);
+ b.node(1).child_edge(0) = static_cast<unsigned short>(b.node(1).child(0).index()+1);
+ b.node(2).child_edge(0) = static_cast<unsigned short>(b.node(2).child(0).index()+1);
+
+ DLIB_TEST_MSG(b.node(0).child_edge(0) == b.node(0).child(0).index()+1,
+ b.node(0).child_edge(0) << " " << b.node(0).child(0).index()+1);
+ DLIB_TEST_MSG(b.node(1).child_edge(0) == b.node(1).child(0).index()+1,
+ b.node(1).child_edge(0) << " " << b.node(1).child(0).index()+1);
+ DLIB_TEST_MSG(b.node(2).child_edge(0) == b.node(2).child(0).index()+1,
+ b.node(2).child_edge(0) << " " << b.node(2).child(0).index()+1);
+
+ DLIB_TEST_MSG(b.node(2).parent_edge(0) == 2+1,
+ b.node(2).parent_edge(0) << " " << 2+1);
+ DLIB_TEST_MSG(b.node(3).parent_edge(0) == 3+1,
+ b.node(3).parent_edge(0) << " " << 3+1);
+ DLIB_TEST_MSG(b.node(3).parent_edge(1) == 3+1,
+ b.node(3).parent_edge(1) << " " << 3+1);
+
+ ostringstream sout;
+
+ serialize(b, sout);
+
+ istringstream sin(sout.str());
+
+ a.set_number_of_nodes(20);
+ DLIB_TEST(a.number_of_nodes() == 20);
+ deserialize(a, sin);
+ DLIB_TEST(a.number_of_nodes() == 4);
+
+ DLIB_TEST(!a.has_edge(0, 1));
+ DLIB_TEST(a.has_edge(0, 2));
+ DLIB_TEST(a.has_edge(1, 3));
+ DLIB_TEST(a.has_edge(2, 3));
+
+ DLIB_TEST(!a.has_edge(1, 0));
+ DLIB_TEST(!a.has_edge(2, 0));
+ DLIB_TEST(!a.has_edge(3, 1));
+ DLIB_TEST(!a.has_edge(3, 2));
+
+ DLIB_TEST_MSG(a.node(0).child_edge(0) == a.node(0).child(0).index()+1,
+ a.node(0).child_edge(0) << " " << a.node(0).child(0).index()+1);
+ DLIB_TEST_MSG(a.node(1).child_edge(0) == a.node(1).child(0).index()+1,
+ a.node(1).child_edge(0) << " " << a.node(1).child(0).index()+1);
+ DLIB_TEST_MSG(a.node(2).child_edge(0) == a.node(2).child(0).index()+1,
+ a.node(2).child_edge(0) << " " << a.node(2).child(0).index()+1);
+ DLIB_TEST_MSG(a.node(2).parent_edge(0) == 2+1,
+ a.node(2).parent_edge(0) << " " << 2+1);
+ DLIB_TEST_MSG(a.node(3).parent_edge(0) == 3+1,
+ a.node(3).parent_edge(0) << " " << 3+1);
+ DLIB_TEST_MSG(a.node(3).parent_edge(1) == 3+1,
+ a.node(3).parent_edge(1) << " " << 3+1);
+
+
+
+ for (int i = 0; i < 4; ++i)
+ {
+ DLIB_TEST(a.node(i).data == i);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ DLIB_TEST(b.number_of_nodes() == 4);
+ DLIB_TEST(b.node(0).number_of_parents() == 0);
+ DLIB_TEST(b.node(0).number_of_children() == 1);
+
+ DLIB_TEST(b.node(1).number_of_parents() == 0);
+ DLIB_TEST(b.node(1).number_of_children() == 1);
+ DLIB_TEST(b.node(1).child(0).index() == 3);
+
+ DLIB_TEST(b.node(2).number_of_parents() == 1);
+ DLIB_TEST(b.node(2).number_of_children() == 1);
+ DLIB_TEST(b.node(2).child(0).index() == 3);
+ DLIB_TEST(b.node(2).parent(0).index() == 0);
+
+ DLIB_TEST(b.node(3).number_of_parents() == 2);
+ DLIB_TEST(b.node(3).number_of_children() == 0);
+
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+ DLIB_TEST(a.node(0).number_of_parents() == 0);
+ DLIB_TEST(a.node(0).number_of_children() == 1);
+
+ DLIB_TEST(a.node(1).number_of_parents() == 0);
+ DLIB_TEST(a.node(1).number_of_children() == 1);
+ DLIB_TEST(a.node(1).child(0).index() == 3);
+
+ DLIB_TEST(a.node(2).number_of_parents() == 1);
+ DLIB_TEST(a.node(2).number_of_children() == 1);
+ DLIB_TEST(a.node(2).child(0).index() == 3);
+ DLIB_TEST(a.node(2).parent(0).index() == 0);
+
+ DLIB_TEST(a.node(3).number_of_parents() == 2);
+ DLIB_TEST(a.node(3).number_of_children() == 0);
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+ a.clear();
+ DLIB_TEST(a.number_of_nodes() == 0);
+
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+
+ a.set_number_of_nodes(10);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+
+ a.add_edge(0,1);
+ a.add_edge(1,2);
+ a.add_edge(1,3);
+ a.add_edge(2,4);
+ a.add_edge(3,4);
+ a.add_edge(4,5);
+ a.add_edge(5,1);
+
+ DLIB_TEST(graph_contains_directed_cycle(a) == true);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+ a.remove_edge(5,1);
+
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ a.add_edge(7,8);
+ DLIB_TEST(graph_contains_directed_cycle(a) == false);
+ a.add_edge(8,7);
+ DLIB_TEST(graph_contains_directed_cycle(a) == true);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+
+ a.clear();
+ /*
+ Make a graph that looks like:
+ 0 1
+ \ /
+ 2
+ |
+ 3
+ */
+ a.set_number_of_nodes(4);
+ a.add_edge(0,2);
+ a.add_edge(1,2);
+ a.add_edge(2,3);
+ for (unsigned long i = 0; i < 4; ++i)
+ a.node(i).data = i;
+
+ graph<int>::kernel_1a_c g;
+ create_moral_graph(a,g);
+
+ graph<dlib::set<unsigned long>::compare_1b_c, dlib::set<unsigned long>::compare_1a_c>::kernel_1a_c join_tree;
+ dlib::set<dlib::set<unsigned long>::compare_1b_c>::kernel_1b_c sos;
+
+ create_join_tree(g, join_tree);
+ DLIB_TEST(is_join_tree(g, join_tree));
+ DLIB_TEST(join_tree.number_of_nodes() == 2);
+ DLIB_TEST(graph_contains_undirected_cycle(join_tree) == false);
+ DLIB_TEST(graph_is_connected(join_tree) == true);
+
+ unsigned long temp;
+ triangulate_graph_and_find_cliques(g,sos);
+
+ temp = 2; s.add(temp);
+ temp = 3; s.add(temp);
+ DLIB_TEST(sos.is_member(s));
+ s.clear();
+ temp = 0; s.add(temp);
+ temp = 1; s.add(temp);
+ temp = 2; s.add(temp);
+ DLIB_TEST(sos.is_member(s));
+ DLIB_TEST(sos.size() == 2);
+ DLIB_TEST(sos.is_member(join_tree.node(0).data));
+ DLIB_TEST(sos.is_member(join_tree.node(1).data));
+
+
+ s.clear();
+ temp = 0; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == true);
+ DLIB_TEST(is_maximal_clique(g,s) == false);
+ temp = 3; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == false);
+ s.destroy(3);
+ DLIB_TEST(is_clique(g,s) == true);
+ temp = 2; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == true);
+ DLIB_TEST(is_maximal_clique(g,s) == false);
+ temp = 1; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == true);
+ DLIB_TEST(is_maximal_clique(g,s) == true);
+ s.clear();
+ DLIB_TEST(is_clique(g,s) == true);
+ temp = 3; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == true);
+ temp = 2; s.add(temp);
+ DLIB_TEST(is_clique(g,s) == true);
+ DLIB_TEST(is_maximal_clique(g,s) == true);
+
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+ DLIB_TEST(g.number_of_nodes() == 4);
+ for (unsigned long i = 0; i < 4; ++i)
+ DLIB_TEST( a.node(i).data == (int)i);
+ DLIB_TEST(g.has_edge(0,1));
+ DLIB_TEST(g.has_edge(0,2));
+ DLIB_TEST(g.has_edge(1,2));
+ DLIB_TEST(g.has_edge(3,2));
+ DLIB_TEST(g.has_edge(0,3) == false);
+ DLIB_TEST(g.has_edge(1,3) == false);
+
+ }
+
+
+ void test_copy()
+ {
+ {
+ directed_graph<int,int>::kernel_1a_c a,b;
+
+ a.set_number_of_nodes(3);
+ a.node(0).data = 1;
+ a.node(1).data = 2;
+ a.node(2).data = 3;
+ a.add_edge(0,1);
+ a.add_edge(1,0);
+ a.add_edge(0,2);
+ edge(a,0,1) = 4;
+ edge(a,1,0) = 3;
+ edge(a,0,2) = 5;
+
+ a.add_edge(0,0);
+ edge(a,0,0) = 9;
+ copy_graph(a, b);
+
+ DLIB_TEST(b.number_of_nodes() == 3);
+ DLIB_TEST(b.node(0).data == 1);
+ DLIB_TEST(b.node(1).data == 2);
+ DLIB_TEST(b.node(2).data == 3);
+ DLIB_TEST(edge(b,0,1) == 4);
+ DLIB_TEST(edge(b,1,0) == 3);
+ DLIB_TEST(edge(b,0,2) == 5);
+ DLIB_TEST(edge(b,0,0) == 9);
+ }
+ {
+ directed_graph<int,int>::kernel_1a_c a,b;
+
+ a.set_number_of_nodes(4);
+ a.node(0).data = 1;
+ a.node(1).data = 2;
+ a.node(2).data = 3;
+ a.node(3).data = 8;
+ a.add_edge(0,1);
+ a.add_edge(0,2);
+ a.add_edge(2,3);
+ a.add_edge(3,2);
+ edge(a,0,1) = 4;
+ edge(a,0,2) = 5;
+ edge(a,2,3) = 6;
+ edge(a,3,2) = 3;
+
+ copy_graph(a, b);
+
+ DLIB_TEST(b.number_of_nodes() == 4);
+ DLIB_TEST(b.node(0).data == 1);
+ DLIB_TEST(b.node(1).data == 2);
+ DLIB_TEST(b.node(2).data == 3);
+ DLIB_TEST(b.node(3).data == 8);
+ DLIB_TEST(edge(b,0,1) == 4);
+ DLIB_TEST(edge(b,0,2) == 5);
+ DLIB_TEST(edge(b,2,3) == 6);
+ DLIB_TEST(edge(b,3,2) == 3);
+ }
+ }
+
+
+
+ class directed_graph_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the directed_graph object. When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_directed_graph by passing that string to the tester constructor.
+ !*/
+ public:
+ directed_graph_tester (
+ ) :
+ tester ("test_directed_graph",
+ "Runs tests on the directed_graph component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_copy();
+
+ dlog << LINFO << "testing kernel_1a_c";
+ directed_graph_test<directed_graph<int,unsigned short>::kernel_1a_c>();
+
+ dlog << LINFO << "testing kernel_1a";
+ directed_graph_test<directed_graph<int,unsigned short>::kernel_1a>();
+ }
+ } a;
+
+
+}
+
+
diff --git a/ml/dlib/dlib/test/discriminant_pca.cpp b/ml/dlib/dlib/test/discriminant_pca.cpp
new file mode 100644
index 000000000..2a7aa61d1
--- /dev/null
+++ b/ml/dlib/dlib/test/discriminant_pca.cpp
@@ -0,0 +1,365 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.discriminant_pca");
+
+ using dlib::equal;
+
+ class discriminant_pca_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ discriminant_pca_tester (
+ ) :
+ tester (
+ "test_discriminant_pca", // the command line argument name for this test
+ "Run tests on the discriminant_pca object.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ thetime = 1407805946;// time(0);
+ }
+
+ time_t thetime;
+ dlib::rand rnd;
+
+ template <typename dpca_type>
+ void test1()
+ {
+
+ dpca_type dpca, dpca2, dpca3;
+
+ DLIB_TEST(dpca.in_vector_size() == 0);
+ DLIB_TEST(dpca.between_class_weight() == 1);
+ DLIB_TEST(dpca.within_class_weight() == 1);
+
+ // generate a bunch of 4 dimensional vectors and compute the normal PCA transformation matrix
+ // and just make sure it is a unitary matrix as it should be.
+ for (int i = 0; i < 5000; ++i)
+ {
+ dpca.add_to_total_variance(randm(4,1,rnd));
+ DLIB_TEST(dpca.in_vector_size() == 4);
+ }
+
+
+ matrix<double> mat = dpca.dpca_matrix(1);
+
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+
+ mat = dpca.dpca_matrix(0.9);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(mat.nr())));
+
+ matrix<double> eig;
+ dpca.dpca_matrix(mat, eig, 1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ // check that all eigen values are grater than 0
+ DLIB_TEST(min(eig > 0) == 1);
+ DLIB_TEST(eig.size() == mat.nr());
+ DLIB_TEST(is_col_vector(eig));
+ // check that the eigenvalues are sorted
+ double last = eig(0);
+ for (long i = 1; i < eig.size(); ++i)
+ {
+ DLIB_TEST(last >= eig(i));
+ }
+
+ {
+ matrix<double> mat = dpca.dpca_matrix_of_size(4);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ }
+ {
+ matrix<double> mat = dpca.dpca_matrix_of_size(3);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));
+ }
+
+
+ dpca.set_within_class_weight(5);
+ dpca.set_between_class_weight(6);
+
+ DLIB_TEST(dpca.in_vector_size() == 4);
+ DLIB_TEST(dpca.within_class_weight() == 5);
+ DLIB_TEST(dpca.between_class_weight() == 6);
+
+
+ ostringstream sout;
+ serialize(dpca, sout);
+ istringstream sin(sout.str());
+ deserialize(dpca2, sin);
+
+ // now make sure the serialization worked
+ DLIB_TEST(dpca.in_vector_size() == 4);
+ DLIB_TEST(dpca.within_class_weight() == 5);
+ DLIB_TEST(dpca.between_class_weight() == 6);
+ DLIB_TEST(dpca2.in_vector_size() == 4);
+ DLIB_TEST(dpca2.within_class_weight() == 5);
+ DLIB_TEST(dpca2.between_class_weight() == 6);
+ DLIB_TEST(equal(dpca.dpca_matrix(), dpca2.dpca_matrix(), 1e-10));
+ DLIB_TEST(equal(mat, dpca2.dpca_matrix(1), 1e-10));
+ DLIB_TEST(equal(dpca.dpca_matrix(1), mat, 1e-10));
+
+ // now test swap
+ dpca2.swap(dpca3);
+ DLIB_TEST(dpca2.in_vector_size() == 0);
+ DLIB_TEST(dpca2.between_class_weight() == 1);
+ DLIB_TEST(dpca2.within_class_weight() == 1);
+
+ DLIB_TEST(dpca3.in_vector_size() == 4);
+ DLIB_TEST(dpca3.within_class_weight() == 5);
+ DLIB_TEST(dpca3.between_class_weight() == 6);
+ DLIB_TEST(equal(mat, dpca3.dpca_matrix(1), 1e-10));
+ DLIB_TEST((dpca3 + dpca3).in_vector_size() == 4);
+ DLIB_TEST((dpca3 + dpca3).within_class_weight() == 5);
+ DLIB_TEST((dpca3 + dpca3).between_class_weight() == 6);
+
+ dpca.clear();
+
+ DLIB_TEST(dpca.in_vector_size() == 0);
+ DLIB_TEST(dpca.between_class_weight() == 1);
+ DLIB_TEST(dpca.within_class_weight() == 1);
+ }
+
+ template <typename dpca_type>
+ void test2()
+ {
+ dpca_type dpca, dpca2, dpca3;
+
+ typename dpca_type::column_matrix samp1(4), samp2(4);
+
+ for (int i = 0; i < 5000; ++i)
+ {
+ dpca.add_to_total_variance(randm(4,1,rnd));
+ DLIB_TEST(dpca.in_vector_size() == 4);
+
+ // do this to subtract out the variance along the 3rd axis
+ samp1 = 0,0,0,0;
+ samp2 = 0,0,1,0;
+ dpca.add_to_within_class_variance(samp1, samp2);
+ }
+
+ matrix<double> mat;
+
+ dpca.set_within_class_weight(0);
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
+ dpca.set_within_class_weight(1000);
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 3);
+
+ // the 3rd column of the transformation matrix should be all zero since
+ // we killed all the variation long the 3rd axis
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5);
+
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));
+
+
+ }
+
+ template <typename dpca_type>
+ void test3()
+ {
+ dpca_type dpca, dpca2, dpca3;
+
+ typename dpca_type::column_matrix samp1(4), samp2(4);
+
+ for (int i = 0; i < 5000; ++i)
+ {
+ dpca.add_to_total_variance(randm(4,1,rnd));
+ DLIB_TEST(dpca.in_vector_size() == 4);
+
+ // do this to subtract out the variance along the 3rd axis
+ samp1 = 0,0,0,0;
+ samp2 = 0,0,1,0;
+ dpca.add_to_within_class_variance(samp1, samp2);
+
+ // do this to subtract out the variance along the 1st axis
+ samp1 = 0,0,0,0;
+ samp2 = 1,0,0,0;
+ dpca.add_to_within_class_variance(samp1, samp2);
+ }
+
+ matrix<double> mat;
+
+ dpca.set_within_class_weight(0);
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
+ dpca.set_within_class_weight(10000);
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 2);
+
+ // the 1st and 3rd columns of the transformation matrix should be all zero since
+ // we killed all the variation long the 1st and 3rd axes
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5);
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5);
+
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(2)));
+
+
+ }
+
+ template <typename dpca_type>
+ void test4()
+ {
+ dpca_type dpca, dpca2, dpca3;
+
+ dpca_type add_dpca1, add_dpca2, add_dpca3, add_dpca4, sum_dpca;
+
+ typename dpca_type::column_matrix samp1(4), samp2(4), samp;
+
+ for (int i = 0; i < 5000; ++i)
+ {
+ samp = randm(4,1,rnd);
+ dpca.add_to_total_variance(samp);
+ add_dpca4.add_to_total_variance(samp);
+ DLIB_TEST(dpca.in_vector_size() == 4);
+
+ // do this to subtract out the variance along the 3rd axis
+ samp1 = 0,0,0,0;
+ samp2 = 0,0,1,0;
+ dpca.add_to_within_class_variance(samp1, samp2);
+ add_dpca1.add_to_within_class_variance(samp1, samp2);
+
+ // do this to subtract out the variance along the 1st axis
+ samp1 = 0,0,0,0;
+ samp2 = 1,0,0,0;
+ dpca.add_to_within_class_variance(samp1, samp2);
+ add_dpca2.add_to_within_class_variance(samp1, samp2);
+
+ // do this to add the variance along the 3rd axis back in
+ samp1 = 0,0,0,0;
+ samp2 = 0,0,1,0;
+ dpca.add_to_between_class_variance(samp1, samp2);
+ add_dpca3.add_to_between_class_variance(samp1, samp2);
+ }
+
+ matrix<double> mat, mat2;
+
+ sum_dpca += dpca_type() + dpca_type() + add_dpca1 + dpca_type() + add_dpca2 + add_dpca3 + add_dpca4;
+ dpca.set_within_class_weight(0);
+ dpca.set_between_class_weight(0);
+ sum_dpca.set_within_class_weight(0);
+ sum_dpca.set_between_class_weight(0);
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1), 1e-10));
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
+ dpca.set_within_class_weight(10000);
+ sum_dpca.set_within_class_weight(10000);
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 2);
+
+ // the 1st and 3rd columns of the transformation matrix should be all zero since
+ // we killed all the variation long the 1st and 3rd axes
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-4);
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-4);
+
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(2)));
+ DLIB_TEST_MSG(equal(mat, mat2=sum_dpca.dpca_matrix(1), 1e-9), max(abs(mat - mat2)));
+
+
+ // now add the variance back in using the between class weight
+ dpca.set_within_class_weight(0);
+ dpca.set_between_class_weight(1);
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(4)));
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 4);
+ dpca.set_within_class_weight (10000);
+ dpca.set_between_class_weight(100000);
+ sum_dpca.set_within_class_weight (10000);
+ sum_dpca.set_between_class_weight(100000);
+ DLIB_TEST(dpca.dpca_matrix(1).nr() == 3);
+
+ // the first column should be all zeros
+ DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5);
+
+ mat = dpca.dpca_matrix(1);
+ DLIB_TEST(equal(mat*trans(mat), identity_matrix<double>(3)));
+ DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1)));
+
+
+ }
+
+ template <typename dpca_type>
+ void test5()
+ {
+ dpca_type dpca, dpca2;
+ typename dpca_type::column_matrix samp1(4), samp2(4);
+
+ samp1 = 0,0,0,0;
+ samp2 = 0,0,1,0;
+
+ for (int i = 0; i < 5000; ++i)
+ {
+ dpca.add_to_between_class_variance(samp1, samp2);
+ dpca2.add_to_total_variance(samp1);
+ dpca2.add_to_total_variance(samp2);
+ }
+
+ matrix<double> mat, eig;
+ dpca.dpca_matrix(mat, eig, 1);
+
+ // make sure the eigenvalues come out the way they should for this simple data set
+ DLIB_TEST(eig.size() == 1);
+ DLIB_TEST_MSG(abs(eig(0) - 1) < 1e-10, abs(eig(0) - 1));
+
+ dpca2.dpca_matrix(mat, eig, 1);
+
+ // make sure the eigenvalues come out the way they should for this simple data set
+ DLIB_TEST(eig.size() == 1);
+ DLIB_TEST(abs(eig(0) - 0.25) < 1e-10);
+
+ }
+
+ void perform_test (
+ )
+ {
+ ++thetime;
+ typedef matrix<double,0,1> sample_type;
+ typedef discriminant_pca<sample_type> dpca_type;
+
+ dlog << LINFO << "time seed: " << thetime;
+ rnd.set_seed(cast_to_string(thetime));
+
+ test5<dpca_type>();
+
+ for (int i = 0; i < 10; ++i)
+ {
+ print_spinner();
+ test1<dpca_type>();
+ print_spinner();
+ test2<dpca_type>();
+ print_spinner();
+ test3<dpca_type>();
+ print_spinner();
+ test4<dpca_type>();
+ }
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ discriminant_pca_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/disjoint_subsets.cpp b/ml/dlib/dlib/test/disjoint_subsets.cpp
new file mode 100644
index 000000000..2545219cd
--- /dev/null
+++ b/ml/dlib/dlib/test/disjoint_subsets.cpp
@@ -0,0 +1,102 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/disjoint_subsets.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.disjoint_subsets");
+
+ void test_disjoint_subset()
+ {
+ print_spinner();
+ disjoint_subsets s;
+
+ DLIB_TEST(s.size() == 0);
+
+ s.set_size(5);
+ DLIB_TEST(s.size() == 5);
+
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == 1);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == 3);
+ DLIB_TEST(s.find_set(4) == 4);
+
+ unsigned long id = s.merge_sets(1,3);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == 4);
+
+ id = s.merge_sets(s.find_set(1),4);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+
+ unsigned long id2 = s.merge_sets(0,2);
+ DLIB_TEST(s.find_set(0) == id2);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == id2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+
+ id = s.merge_sets(s.find_set(1),s.find_set(0));
+ DLIB_TEST(s.find_set(0) == id);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == id);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+
+ DLIB_TEST(s.size() == 5);
+ s.set_size(1);
+ DLIB_TEST(s.size() == 1);
+ DLIB_TEST(s.find_set(0) == 0);
+ s.set_size(2);
+ DLIB_TEST(s.size() == 2);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == 1);
+ id = s.merge_sets(0,1);
+ DLIB_TEST(s.size() == 2);
+ DLIB_TEST(id == s.find_set(0));
+ DLIB_TEST(id == s.find_set(1));
+ DLIB_TEST(s.size() == 2);
+ s.clear();
+ DLIB_TEST(s.size() == 0);
+
+ }
+
+
+ class tester_disjoint_subsets : public tester
+ {
+ public:
+ tester_disjoint_subsets (
+ ) :
+ tester ("test_disjoint_subsets",
+ "Runs tests on the disjoint_subsets component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_disjoint_subset();
+ }
+ } a;
+
+
+}
diff --git a/ml/dlib/dlib/test/disjoint_subsets_sized.cpp b/ml/dlib/dlib/test/disjoint_subsets_sized.cpp
new file mode 100644
index 000000000..57ced68e6
--- /dev/null
+++ b/ml/dlib/dlib/test/disjoint_subsets_sized.cpp
@@ -0,0 +1,143 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/disjoint_subsets.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.disjoint_subsets_sized");
+
+ void test_disjoint_subsets_sized()
+ {
+ print_spinner();
+ disjoint_subsets_sized s;
+
+ DLIB_TEST(s.size() == 0);
+ DLIB_TEST(s.get_number_of_sets() == 0);
+
+ s.set_size(5);
+ DLIB_TEST(s.size() == 5);
+ DLIB_TEST(s.get_number_of_sets() == 5);
+
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == 1);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == 3);
+ DLIB_TEST(s.find_set(4) == 4);
+
+ DLIB_TEST(s.get_size_of_set(0) == 1);
+ DLIB_TEST(s.get_size_of_set(1) == 1);
+ DLIB_TEST(s.get_size_of_set(2) == 1);
+ DLIB_TEST(s.get_size_of_set(3) == 1);
+ DLIB_TEST(s.get_size_of_set(4) == 1);
+
+ unsigned long id = s.merge_sets(1,3);
+ DLIB_TEST(s.get_number_of_sets() == 4);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == 4);
+ DLIB_TEST(s.get_size_of_set(0) == 1);
+ DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 2);
+ DLIB_TEST(s.get_size_of_set(2) == 1);
+ DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 2);
+ DLIB_TEST(s.get_size_of_set(4) == 1);
+
+ id = s.merge_sets(s.find_set(1),4);
+ DLIB_TEST(s.get_number_of_sets() == 3);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == 2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+ DLIB_TEST(s.get_size_of_set(0) == 1);
+ DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 3);
+ DLIB_TEST(s.get_size_of_set(2) == 1);
+ DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 3);
+ DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 3);
+
+ unsigned long id2 = s.merge_sets(0,2);
+ DLIB_TEST(s.get_number_of_sets() == 2);
+ DLIB_TEST(s.find_set(0) == id2);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == id2);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+ DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 2);
+ DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 3);
+ DLIB_TEST(s.get_size_of_set(s.find_set(2)) == 2);
+ DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 3);
+ DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 3);
+
+ id = s.merge_sets(s.find_set(1),s.find_set(0));
+ DLIB_TEST(s.get_number_of_sets() == 1);
+ DLIB_TEST(s.find_set(0) == id);
+ DLIB_TEST(s.find_set(1) == id);
+ DLIB_TEST(s.find_set(2) == id);
+ DLIB_TEST(s.find_set(3) == id);
+ DLIB_TEST(s.find_set(4) == id);
+ DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 5);
+ DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 5);
+ DLIB_TEST(s.get_size_of_set(s.find_set(2)) == 5);
+ DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 5);
+ DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 5);
+
+ DLIB_TEST(s.size() == 5);
+ s.set_size(1);
+ DLIB_TEST(s.size() == 1);
+ DLIB_TEST(s.get_number_of_sets() == 1);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.get_size_of_set(0) == 1);
+ s.set_size(2);
+ DLIB_TEST(s.size() == 2);
+ DLIB_TEST(s.get_number_of_sets() == 2);
+ DLIB_TEST(s.find_set(0) == 0);
+ DLIB_TEST(s.find_set(1) == 1);
+ DLIB_TEST(s.get_size_of_set(0) == 1);
+ DLIB_TEST(s.get_size_of_set(1) == 1);
+ id = s.merge_sets(0,1);
+ DLIB_TEST(s.size() == 2);
+ DLIB_TEST(s.get_number_of_sets() == 1);
+ DLIB_TEST(id == s.find_set(0));
+ DLIB_TEST(id == s.find_set(1));
+ DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 2);
+ DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 2);
+ DLIB_TEST(s.size() == 2);
+ s.clear();
+ DLIB_TEST(s.size() == 0);
+ DLIB_TEST(s.get_number_of_sets() == 0);
+
+ }
+
+
+ class tester_disjoint_subsets_sized : public tester
+ {
+ public:
+ tester_disjoint_subsets_sized (
+ ) :
+ tester ("test_disjoint_subsets_sized",
+ "Runs tests on the disjoint_subsets_sized component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_disjoint_subsets_sized();
+ }
+ } a;
+
+
+}
diff --git a/ml/dlib/dlib/test/dnn.cpp b/ml/dlib/dlib/test/dnn.cpp
new file mode 100644
index 000000000..9d3258b70
--- /dev/null
+++ b/ml/dlib/dlib/test/dnn.cpp
@@ -0,0 +1,3261 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include <random>
+#include <numeric>
+#include "../dnn.h"
+
+#include "tester.h"
+
+#ifndef __INTELLISENSE__
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.dnn");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ float compare_gradients (
+ const tensor& t,
+ T grad
+ )
+ {
+ float max_error = 0;
+ auto p = t.host();
+ for (size_t i = 0; i < t.size(); ++i)
+ {
+ max_error = std::max(max_error, std::abs(p[i]-grad(i)));
+ }
+ return max_error;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_tanh()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor src, dest, gradient_input;
+ src = matrix_cast<float>(gaussian_randm(5,5, 0));
+ dest = matrix_cast<float>(gaussian_randm(5,5, 1));
+ gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
+
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ tanh(dest, src);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+ resizable_tensor src_grad;
+ src_grad.copy_size(src);
+ src_grad = 0;
+
+ tanh(dest, src);
+ tanh_gradient(src_grad, dest, gradient_input);
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+ }
+
+ void test_sigmoid()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor src, dest, gradient_input;
+ src = matrix_cast<float>(gaussian_randm(5,5, 0));
+ dest = matrix_cast<float>(gaussian_randm(5,5, 1));
+ gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
+
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ sigmoid(dest, src);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+ resizable_tensor src_grad;
+ src_grad.copy_size(src);
+ src_grad = 0;
+
+ sigmoid(dest, src);
+ sigmoid_gradient(src_grad, dest, gradient_input);
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+ }
+
+ void test_softmax()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ const long nr = 3;
+ const long nc = 3;
+ resizable_tensor src(5,5,nr,nr), dest(5,5,nr,nc), gradient_input(5,5,nr,nc);
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(dest);
+ // fill like this as a test of the assignment operator.
+ gradient_input = matrix_cast<float>(gaussian_randm(5,5*nr*nc, 2));
+
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ tt::softmax(dest, src);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+ resizable_tensor src_grad;
+ src_grad.copy_size(src);
+ src_grad = 0;
+
+ tt::softmax(dest, src);
+ softmax_gradient(src_grad, dest, gradient_input);
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+#ifdef DLIB_USE_CUDA
+ resizable_tensor src1 = src;
+ resizable_tensor src2 = src;
+ resizable_tensor dest1, dest2;
+ dest1.copy_size(src);
+ dest2.copy_size(src);
+ cuda::softmax_all(dest1, src1);
+ cpu::softmax_all(dest2, src2);
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2))));
+#endif
+ }
+
+ void test_softmax_all()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ const long nr = 3;
+ const long nc = 3;
+ resizable_tensor src(5,5,nr,nr), dest(5,5,nr,nc), gradient_input(5,5,nr,nc);
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(dest);
+ // fill like this as a test of the assignment operator.
+ gradient_input = matrix_cast<float>(gaussian_randm(5,5*nr*nc, 2));
+
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ tt::softmax_all(dest, src);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+ resizable_tensor src_grad;
+ src_grad.copy_size(src);
+ src_grad = 0;
+
+ tt::softmax_all(dest, src);
+ softmax_all_gradient(src_grad, dest, gradient_input);
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+#ifdef DLIB_USE_CUDA
+ resizable_tensor src1 = src;
+ resizable_tensor src2 = src;
+ resizable_tensor dest1, dest2;
+ dest1.copy_size(src);
+ dest2.copy_size(src);
+ cuda::softmax_all(dest1, src1);
+ cpu::softmax_all(dest2, src2);
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2))));
+#endif
+ }
+
+ void test_batch_normalize()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor src, gamma, beta, dest, dest2, dest3, means, vars, gradient_input;
+ src = matrix_cast<float>(gaussian_randm(5,5, 0));
+ gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
+ beta = matrix_cast<float>(gaussian_randm(1,5, 2));
+ gradient_input = matrix_cast<float>(gaussian_randm(5,5, 3));
+
+ gamma = 1;
+ beta = 0;
+
+ resizable_tensor running_means;
+ resizable_tensor running_variances;
+ batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ const double scale = (src.num_samples())/(src.num_samples()-1.0);
+ // Turn back into biased variance estimate because that's how batch_normalize() works, so if we want to match it this is necessary.
+ running_variances = mat(running_variances)/scale;
+ batch_normalize_inference(DEFAULT_BATCH_NORM_EPS,dest2, src, gamma, beta, running_means, running_variances);
+ DLIB_TEST_MSG(max(abs(mat(dest2)-mat(dest))) < 1e-5, max(abs(mat(dest2)-mat(dest))));
+ cpu::batch_normalize_inference(DEFAULT_BATCH_NORM_EPS,dest3, src, gamma, beta, running_means, running_variances);
+ DLIB_TEST_MSG(max(abs(mat(dest3)-mat(dest))) < 1e-5, max(abs(mat(dest3)-mat(dest))));
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+ auto grad_gamma = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = gamma.host()[idx];
+ gamma.host()[idx] += eps;
+ batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ gamma.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+ auto grad_beta = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = beta.host()[idx];
+ beta.host()[idx] += eps;
+ batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ beta.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+ resizable_tensor src_grad, gamma_grad, beta_grad;
+ src_grad.copy_size(src);
+ gamma_grad.copy_size(gamma);
+ beta_grad.copy_size(beta);
+ src_grad = 0;
+ gamma_grad = 8;
+ beta_grad = 8;
+
+ batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, vars, src, gamma, src_grad, gamma_grad, beta_grad);
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+ grad_error = compare_gradients(gamma_grad, grad_gamma);
+ dlog << LINFO << "gamma error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+ grad_error = compare_gradients(beta_grad, grad_beta);
+ dlog << LINFO << "beta error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+ }
+
+ void test_batch_normalize_conv()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor src(5,5,4,4), gamma, beta, dest, dest2, dest3, means, vars, gradient_input(5,5,4,4);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(src);
+ rnd.fill_gaussian(gradient_input);
+ gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
+ beta = matrix_cast<float>(gaussian_randm(1,5, 2));
+
+ gamma = 1;
+ beta = 0;
+
+ resizable_tensor running_means;
+ resizable_tensor running_variances;
+ batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ const double scale = (src.num_samples()*src.nr()*src.nc())/(src.num_samples()*src.nr()*src.nc()-1.0);
+ // Turn back into biased variance estimate because that's how
+ // batch_normalize_conv() works, so if we want to match it this is necessary.
+ running_variances = mat(running_variances)/scale;
+ batch_normalize_conv_inference(DEFAULT_BATCH_NORM_EPS,dest2, src, gamma, beta, running_means, running_variances);
+ DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
+ cpu::batch_normalize_conv_inference(DEFAULT_BATCH_NORM_EPS,dest3, src, gamma, beta, running_means, running_variances);
+ DLIB_TEST(max(abs(mat(dest3)-mat(dest))) < 1e-5);
+
+
+ auto grad_src = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = src.host()[idx];
+ src.host()[idx] += eps;
+ batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ src.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+ auto grad_gamma = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = gamma.host()[idx];
+ gamma.host()[idx] += eps;
+ batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ gamma.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+ auto grad_beta = [&](long idx) {
+ auto f = [&](float eps) {
+ const float old = beta.host()[idx];
+ beta.host()[idx] += eps;
+ batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta);
+ float result = dot(gradient_input, dest);
+ beta.host()[idx] = old;
+ return result;
+ };
+ const float eps = 0.01;
+ return (f(+eps)-f(-eps))/(2*eps);
+ };
+
+
+ resizable_tensor src_grad, gamma_grad, beta_grad;
+ src_grad.copy_size(src);
+ gamma_grad.copy_size(gamma);
+ beta_grad.copy_size(beta);
+ src_grad = 0;
+ gamma_grad = 9;
+ beta_grad = 9;
+
+ batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, vars, src, gamma, src_grad, gamma_grad, beta_grad);
+
+
+ auto grad_error = compare_gradients(src_grad, grad_src);
+ dlog << LINFO << "src error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+ grad_error = compare_gradients(gamma_grad, grad_gamma);
+ dlog << LINFO << "gamma error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+ grad_error = compare_gradients(beta_grad, grad_beta);
+ dlog << LINFO << "beta error: " << grad_error;
+ DLIB_TEST(grad_error < 0.001);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_basic_tensor_ops()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor dest, src(3,4), A(1,4), B(1,4);
+ src = 2;
+ dest.copy_size(src);
+ affine_transform(dest, src, 2, 3);
+ dlog << LINFO << mat(dest);
+ matrix<float> truth1(3,4), truth2(3,4);
+
+ truth1 = 2;
+ DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5);
+ src *= 2;
+ truth1 = 4;
+ DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5);
+ src = 2;
+
+
+ truth1 = 7;
+ truth2 = 7, 10, 7, 7,
+ 7, 10, 7, 7,
+ 7, 10, 7, 7;
+ DLIB_TEST(max(abs(truth1-mat(dest))) < 1e-5);
+
+ A = 2;
+ B = 3;
+ A.host()[1] = 3;
+ B.host()[1] = 4;
+ dest = 0;
+ affine_transform(dest, src, A, B);
+ dlog << LINFO << mat(dest);
+ DLIB_TEST(max(abs(truth2-mat(dest))) < 1e-5);
+
+ A = matrix_cast<float>(gaussian_randm(3,4, 1));
+ B = matrix_cast<float>(gaussian_randm(3,4, 2));
+ affine_transform(dest, src, A, B);
+ dlog << LINFO << mat(dest);
+ matrix<float> truth3 = pointwise_multiply(mat(src), mat(A)) + mat(B);
+ DLIB_TEST(max(abs(truth3-mat(dest))) < 1e-5);
+
+ matrix<float> truth4 = pointwise_multiply(mat(A), mat(B));
+ tt::multiply(false, A, A, B);
+ DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5);
+ truth4 = pointwise_multiply(mat(A), mat(B)) + mat(A);
+ tt::multiply(true, A, A, B);
+ DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5);
+
+ matrix<float> truth5 = mat(B) > 0.1;
+ dlog << LINFO << truth5;
+ threshold(B, 0.1);
+ DLIB_TEST(max(abs(truth5-mat(B))) < 1e-5);
+
+ int cnt = 0;
+ for(auto& x : A)
+ x = cnt++;
+
+ truth1.set_size(2,2);
+ truth2.set_size(2,2);
+ truth3.set_size(2,2);
+ truth1 = 0,1,2,3;
+ truth2 = 4,5,6,7;
+ truth3 = 8,9,10,11;
+
+ alias_tensor at(2,2);
+ auto A0 = at(A,0);
+ auto A4 = at(A,4);
+ auto A8 = at(const_cast<const resizable_tensor&>(A),8);
+ DLIB_TEST(mat(A0) == truth1);
+ DLIB_TEST(mat(at(A,4)) == truth2);
+ DLIB_TEST(mat(A8) == truth3);
+
+ A4 += uniform_matrix<float>(2,2,2);
+ truth2 += 2;
+ DLIB_TEST(mat(A4) == truth2);
+ truth1 = trans(reshape_to_column_vector(truth1));
+ truth2 = trans(reshape_to_column_vector(truth2));
+ truth3 = trans(reshape_to_column_vector(truth3));
+
+ DLIB_TEST(mat(A) == join_cols(truth1,join_cols(truth2,truth3)));
+
+ affine_transform(A,A,1,2);
+ truth1 += 2;
+ truth2 += 2;
+ truth3 += 2;
+ DLIB_TEST(mat(at(A,4)) == reshape(truth2,2,2));
+ DLIB_TEST(mat(A) == join_cols(truth1,join_cols(truth2,truth3)));
+
+ {
+ resizable_tensor dest(3,4);
+ resizable_tensor A, B;
+ A = dest;
+ B = dest;
+
+ tensor_rand rnd;
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+
+ dest.set_size(1,4);
+
+ tt::multiply(false, dest, A, B);
+ DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B))))) < 1e-6);
+
+ A.set_size(1,4);
+ rnd.fill_uniform(A);
+ matrix<float> AA = join_cols(mat(A),mat(A)); AA = join_cols(mat(A),AA);
+
+ tt::multiply(false, dest, A, B);
+ DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+
+ tt::multiply(false, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+ matrix<float> prevdest = mat(dest);
+ tt::multiply(true, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-prevdest-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+
+ dest.set_size(3,4);
+ tt::multiply(false, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6);
+ prevdest = mat(dest);
+ tt::multiply(true, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6);
+
+ tt::multiply(false, dest, A, B);
+ DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6);
+ prevdest = mat(dest);
+ tt::multiply(true, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6);
+ }
+
+ {
+ resizable_tensor A, B, truth;
+ A.set_size(2,3,4,5);
+ truth.copy_size(A);
+ B.copy_size(A);
+
+ A = 4;
+ B = 1;
+ truth = 1;
+ DLIB_TEST(max(abs(mat(B)- mat(truth))) < 1e-5);
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+
+ A = 4;
+ A.host();
+ B.host();
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+
+#ifdef DLIB_USE_CUDA
+ A = 4;
+ A.device();
+ B.host();
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+
+ A = 4;
+ A.device();
+ B.device();
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+
+ A = 4;
+ A.host();
+ B.device();
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+
+ A = 4;
+ A.host_write_only();
+ B.device();
+ memcpy(A, truth);
+ DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5);
+#endif
+ }
+
+ {
+ resizable_tensor A, B;
+ A.set_size(11);
+ B.copy_size(A);
+
+ A = 4;
+ B = 1;
+ matrix<float> truth;
+
+
+ alias_tensor at(5);
+ A = 4;
+ A.host();
+ B.host();
+ {
+ // non-aliasing test
+ auto aA = at(A,5);
+ auto aB = at(B,5);
+ memcpy(aA, aB);
+ truth = {4,4,4,4,4, 1,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+ {
+ // aliasing test
+ auto aA = at(A,1);
+ auto aB = at(A,6);
+ memcpy(aA, aB);
+ truth = {4,1,1,1,1, 4,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+
+
+#ifdef DLIB_USE_CUDA
+ A = 4;
+ A.device();
+ B.host();
+ {
+ // non-aliasing test
+ auto aA = at(A,5);
+ auto aB = at(B,5);
+ memcpy(aA, aB);
+ truth = {4,4,4,4,4, 1,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+ {
+ // aliasing test
+ auto aA = at(A,1);
+ auto aB = at(A,6);
+ memcpy(aA, aB);
+ truth = {4,1,1,1,1, 4,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+
+
+ A = 4;
+ A.device();
+ B.device();
+ {
+ // non-aliasing test
+ auto aA = at(A,5);
+ auto aB = at(B,5);
+ memcpy(aA, aB);
+ truth = {4,4,4,4,4, 1,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+ {
+ // aliasing test
+ auto aA = at(A,1);
+ auto aB = at(A,6);
+ memcpy(aA, aB);
+ truth = {4,1,1,1,1, 4,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+
+ A = 4;
+ A.host();
+ B.device();
+ {
+ // non-aliasing test
+ auto aA = at(A,5);
+ auto aB = at(B,5);
+ memcpy(aA, aB);
+ truth = {4,4,4,4,4, 1,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+ {
+ // aliasing test
+ auto aA = at(A,1);
+ auto aB = at(A,6);
+ memcpy(aA, aB);
+ truth = {4,1,1,1,1, 4,1,1,1,1, 4};
+ DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5);
+ }
+
+#endif
+ }
+
+ {
+ resizable_tensor A(4,5), B(4);
+
+ tensor_rand rnd;
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+
+ float alpha = 1.4;
+ float beta = 0.5;
+
+ matrix<float> a(mat(A)), b(mat(B));
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ set_colm(a,c) = beta*colm(a,c) + alpha*b;
+ }
+
+ tt::add(beta, A, alpha, B);
+ DLIB_TEST_MSG(max(abs(mat(A)-a)) < 1e-6, max(abs(mat(A)-a)));
+
+ beta = 0;
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ set_colm(a,c) = beta*colm(a,c) + alpha*b;
+ }
+
+ tt::add(beta, A, alpha, B);
+ DLIB_TEST(max(abs(mat(A)-a)) < 1e-6);
+ }
+
+ {
+ resizable_tensor A, B;
+ A.set_size(2,3,4,5);
+ B.set_size(2,3,4,5);
+
+ tensor_rand rnd;
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+
+ matrix<float> truth;
+
+ truth = 2*mat(A) + 3*mat(B);
+ tt::add(2, A, 3, B);
+ DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6);
+
+
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ truth = 0*mat(A) + 3*mat(B);
+ tt::add(0, A, 3, B);
+ DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6);
+
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ truth = 1*mat(A) + 0*mat(B);
+ tt::add(1, A, 0, B);
+ DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6);
+
+
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ truth = 0*mat(A) + 0*mat(B);
+ tt::add(0, A, 0, B);
+ DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6);
+
+
+ B.set_size(1,3,4,5);
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ truth = 2*mat(A) + 3*join_cols(mat(B), mat(B));
+ tt::add(2, A, 3, B);
+ DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6);
+ DLIB_TEST(A.num_samples()==2);
+
+ B.set_size(1,1,4,5);
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ matrix<float> temp = join_rows(mat(B), join_rows(mat(B),mat(B)));
+ truth = 2*mat(A) + 3*join_cols(temp,temp);
+ tt::add(2, A, 3, B);
+ DLIB_TEST_MSG(max(abs(mat(A)-truth )) < 1e-6, max(abs(mat(A)-truth )));
+
+ B.set_size(1,3,1,1);
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+ resizable_tensor AA(A), BB(B);
+ tt::add(2, A, 3, B);
+ cpu::add(2, AA, 3, BB);
+ DLIB_TEST_MSG(max(abs(mat(A)-mat(AA) )) < 1e-6, max(abs(mat(A)-mat(AA) )));
+ }
+
+ {
+ print_spinner();
+ resizable_tensor dest1(123,456), dest2(123,456);
+ resizable_tensor src1(123,456), src2(123,456);
+
+ tt::tensor_rand rnd;
+
+ rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3]
+ dest1 = exp(mat(src1));
+ tt::exp(dest2, src2);
+ tt::exp(src2, src2); // should work in place
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2))));
+ DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5);
+
+ rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3]
+ dest1 = log(mat(src1));
+ tt::log(dest2, src2);
+ tt::log(src2, src2); // should work in place
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5);
+ DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5);
+
+ rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3]
+ dest1 = log10(mat(src1));
+ tt::log10(dest2, src2);
+ tt::log10(src2, src2); // should work in place
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5);
+ DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5);
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef DLIB_USE_CUDA
+
+ void test_scale_channels()
+ {
+ tt::tensor_rand rnd;
+
+ resizable_tensor dest1(2,3,4,5), dest2;
+ rnd.fill_gaussian(dest1);
+ dest2 = dest1;
+
+ resizable_tensor src(2,3,4,5);
+ resizable_tensor scales(2,3);
+ rnd.fill_gaussian(src);
+ rnd.fill_gaussian(scales);
+
+
+ cpu::scale_channels(true, dest1, src, scales);
+ cuda::scale_channels(true, dest2, src, scales);
+
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-6);
+
+ cpu::scale_channels(false, dest1, src, scales);
+ cuda::scale_channels(false, dest2, src, scales);
+
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-6);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_affine_rect()
+ {
+ dlib::rand rnd;
+
+ for (int iter = 0; iter < 20; ++iter)
+ {
+
+ long nr = 1 + rnd.get_random_32bit_number()%10;
+ long nc = 1 + rnd.get_random_32bit_number()%10;
+
+ resizable_tensor dest1(nr,nc), dest2(nr,nc), src1(nr,nc), src2(nr,nc), src3(nr,nc);
+ matrix<float> dest3;
+
+ dest1 = 1;
+ dest2 = 1;
+ dest3 = mat(dest1);
+ src1 = 2;
+ src2 = 3;
+ src3 = 4;
+
+ point p1(rnd.get_random_32bit_number()%nc, rnd.get_random_32bit_number()%nr);
+ point p2(rnd.get_random_32bit_number()%nc, rnd.get_random_32bit_number()%nr);
+ rectangle rect(p1,p2);
+
+ cuda::affine_transform(rect, dest1, src1, src2, src3, 2,3,4);
+
+ cpu::affine_transform(rect, dest2, src1, src2, src3, 2,3,4);
+
+ DLIB_TEST(mat(dest1) == mat(dest2));
+
+ set_subm(dest3,rect) = 2*subm(mat(src1),rect) + 3*subm(mat(src2),rect) + 4*subm(mat(src3),rect);
+ DLIB_TEST(dest3 == mat(dest1));
+
+ dest1 = 1;
+ tt::affine_transform(rect, dest1, src1, src2, src3, 2,3,4);
+ DLIB_TEST(dest3 == mat(dest1));
+ }
+ }
+
+ void test_conv()
+ {
+ cuda::tensor_conv conv1;
+ cpu::tensor_conv conv2;
+
+ dlib::rand prnd;
+ for (int iter = 0; iter < 400; ++iter)
+ {
+ print_spinner();
+
+ resizable_tensor data(prnd.get_random_32bit_number()%5+1,
+ prnd.get_random_32bit_number()%5+1,
+ prnd.get_random_32bit_number()%25+1,
+ prnd.get_random_32bit_number()%25+1
+ );
+ resizable_tensor filters(
+ prnd.get_random_32bit_number()%5+1,
+ data.k(),
+ prnd.get_random_32bit_number()%6+1,
+ prnd.get_random_32bit_number()%6+1
+ );
+
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(data);
+ rnd.fill_uniform(filters);
+
+
+ resizable_tensor output1, output2;
+
+
+ const int stride_y = prnd.get_random_32bit_number()%5+1;
+ const int stride_x = prnd.get_random_32bit_number()%5+1;
+ int padding_y = prnd.get_random_32bit_number()%(filters.nr()/2+1);
+ int padding_x = prnd.get_random_32bit_number()%(filters.nc()/2+1);
+ if (!(filters.nr() <= data.nr() + 2*padding_y))
+ padding_y = (filters.nr()-data.nr()+1)/2;
+ if (!(filters.nc() <= data.nc() + 2*padding_x))
+ padding_x = (filters.nc()-data.nc()+1)/2;
+ conv1.setup(data,filters,stride_y,stride_x,padding_y,padding_x);
+ conv1(false, output1, data, filters);
+ conv2.setup(data,filters,stride_y,stride_x,padding_y,padding_x);
+ conv2(false, output2, data, filters);
+ dlog << LINFO << "forward error: "<< max(abs(mat(output1)-mat(output2)));
+ DLIB_TEST_MSG(max(abs(mat(output1)-mat(output2))) < 1e-3, max(abs(mat(output1)-mat(output2)))
+ <<"\n\t padding_y: "<< padding_y
+ <<"\n\t padding_x: "<< padding_x
+ );
+
+ conv1(true, output1, data, filters);
+ conv2(true, output2, data, filters);
+ dlog << LINFO << "forward error: "<< max(abs(mat(output1)-mat(output2)));
+ DLIB_TEST_MSG(max(abs(mat(output1)-mat(output2))) < 1e-3, max(abs(mat(output1)-mat(output2)))
+ <<"\n\t padding_y: "<< padding_y
+ <<"\n\t padding_x: "<< padding_x
+ );
+
+
+
+ resizable_tensor gi, data_gradient1, data_gradient2;
+ gi.copy_size(output1);
+ rnd.fill_uniform(gi);
+
+ data_gradient1.copy_size(data);
+ data_gradient2.copy_size(data);
+ data_gradient1 = 1;
+ data_gradient2 = 1;
+
+ conv1.get_gradient_for_data(true, gi, filters, data_gradient1);
+ conv2.get_gradient_for_data(true, gi, filters, data_gradient2);
+
+ dlog << LINFO << "data gradient error: "<< max(abs(mat(data_gradient1)-mat(data_gradient2)));
+ DLIB_TEST(max(abs(mat(data_gradient1)-mat(data_gradient2))) < 1e-3);
+
+ conv1.get_gradient_for_data(false, gi, filters, data_gradient1);
+ conv2.get_gradient_for_data(false, gi, filters, data_gradient2);
+
+ dlog << LINFO << "data gradient error: "<< max(abs(mat(data_gradient1)-mat(data_gradient2)));
+ DLIB_TEST(max(abs(mat(data_gradient1)-mat(data_gradient2))) < 1e-3);
+
+
+ resizable_tensor filter_gradient1, filter_gradient2;
+ gi.copy_size(output1);
+ rnd.fill_uniform(gi);
+
+ filter_gradient1.copy_size(filters);
+ filter_gradient2.copy_size(filters);
+ filter_gradient1 = 1;
+ filter_gradient2 = 1;
+
+ conv1.get_gradient_for_filters(false, gi, data, filter_gradient1);
+ conv2.get_gradient_for_filters(false, gi, data, filter_gradient2);
+
+ dlog << LINFO << "filter gradient error: "<< max(abs(mat(filter_gradient1)-mat(filter_gradient2)));
+ DLIB_TEST_MSG(max(abs(mat(filter_gradient1)-mat(filter_gradient2))) < 1e-3, max(abs(mat(filter_gradient1)-mat(filter_gradient2))));
+
+
+ conv1.get_gradient_for_filters(true, gi, data, filter_gradient1);
+ conv2.get_gradient_for_filters(true, gi, data, filter_gradient2);
+
+ dlog << LINFO << "filter gradient error: "<< max(abs(mat(filter_gradient1)-mat(filter_gradient2)));
+ DLIB_TEST_MSG(max(abs(mat(filter_gradient1)-mat(filter_gradient2))) < 2e-3, max(abs(mat(filter_gradient1)-mat(filter_gradient2))));
+ }
+ }
+
+ void compare_adam()
+ {
+ float t = 2;
+ tt::tensor_rand rnd;
+ resizable_tensor s, m, v, params, params_grad;
+ s.set_size(89,90,60,73);
+ m.copy_size(s);
+ v.copy_size(s);
+ params.copy_size(s);
+ params_grad.copy_size(s);
+
+ rnd.fill_uniform(s);
+ rnd.fill_uniform(m);
+ rnd.fill_uniform(v);
+ rnd.fill_uniform(params);
+ rnd.fill_uniform(params_grad);
+
+ resizable_tensor mm(m), vv(v);
+ cpu::compute_adam_update(0,params.size(),s, mm, vv, t, 0.01, 0.001, 0.9, 0.99, params, params_grad);
+ matrix<float> s1 = mat(s);
+
+ rnd.fill_uniform(s);
+ cuda::compute_adam_update(0,params.size(),s, m, v, t, 0.01, 0.001, 0.9, 0.99, params, params_grad);
+ matrix<float> s2 = mat(s);
+
+ DLIB_TEST_MSG(max(abs(s1-s2)) < 1e-6, max(abs(s1-s2)));
+ DLIB_TEST_MSG(max(abs(mat(m)-mat(mm))) < 1e-6, max(abs(mat(m)-mat(mm))));
+ DLIB_TEST_MSG(max(abs(mat(v)-mat(vv))) < 1e-6, max(abs(mat(v)-mat(vv))));
+ }
+
+ void test_multiply_zero_padded()
+ {
+ print_spinner();
+ dlib::rand rnd;
+ tt::tensor_rand trnd;
+ for (int iter = 0; iter < 300; ++iter)
+ {
+ resizable_tensor dest1(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+ resizable_tensor dest2;
+ dest2.copy_size(dest1);
+ resizable_tensor src1(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+ resizable_tensor src2(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+
+ trnd.fill_uniform(dest1);
+ trnd.fill_uniform(dest2);
+ trnd.fill_uniform(src1);
+ trnd.fill_uniform(src2);
+ cpu::multiply_zero_padded(false, dest1, src1, src2);
+ cuda::multiply_zero_padded(false, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+
+ cpu::multiply_zero_padded(true, dest1, src1, src2);
+ cuda::multiply_zero_padded(true, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+ }
+
+ // make sure we have a test for the case where all tensors have the same
+ // dimensions.
+ resizable_tensor dest1(3,4,5,6);
+ resizable_tensor dest2;
+ resizable_tensor src1;
+ resizable_tensor src2;
+ dest2.copy_size(dest1);
+ src1.copy_size(dest1);
+ src2.copy_size(dest1);
+
+ trnd.fill_uniform(dest1);
+ trnd.fill_uniform(dest2);
+ trnd.fill_uniform(src1);
+ trnd.fill_uniform(src2);
+ cpu::multiply_zero_padded(false, dest1, src1, src2);
+ cuda::multiply_zero_padded(false, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+
+ cpu::multiply_zero_padded(true, dest1, src1, src2);
+ cuda::multiply_zero_padded(true, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+ }
+
+ void test_add()
+ {
+ print_spinner();
+ dlib::rand rnd;
+ tt::tensor_rand trnd;
+ for (int iter = 0; iter < 300; ++iter)
+ {
+ resizable_tensor dest1(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+ resizable_tensor dest2;
+ dest2.copy_size(dest1);
+ resizable_tensor src1(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+ resizable_tensor src2(rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1,
+ rnd.get_random_32bit_number()%4+1);
+
+ trnd.fill_uniform(dest1);
+ trnd.fill_uniform(dest2);
+ trnd.fill_uniform(src1);
+ trnd.fill_uniform(src2);
+ cpu::add(dest1, src1, src2);
+ cuda::add(dest2, src1, src2);
+
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+ }
+
+ // make sure we have a test for the case where all tensors have the same
+ // dimensions.
+ resizable_tensor dest1(3,4,5,6);
+ resizable_tensor dest2;
+ resizable_tensor src1;
+ resizable_tensor src2;
+ dest2.copy_size(dest1);
+ src1.copy_size(dest1);
+ src2.copy_size(dest1);
+
+ trnd.fill_uniform(dest1);
+ trnd.fill_uniform(dest2);
+ trnd.fill_uniform(src1);
+ trnd.fill_uniform(src2);
+
+ cpu::add(dest1, src1, src2);
+ cuda::add(dest2, src1, src2);
+
+ DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
+ }
+
+ void test_more_ops(const long nr, const long nc)
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ // We are going to make sure that the CPU implementation of these things matches
+ // the CUDA implementation.
+
+ tensor_rand rnd;
+
+ resizable_tensor dest(nr,nc), src(nr,nc), dest2, src2;
+ resizable_tensor srcb(nr,nc), srcc(nr,nc), srcb2, srcc2;
+
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ dest2 = dest; src2 = src;
+ cuda::multiply(false, dest, dest, src);
+ cpu::multiply(false, dest2, dest2, src2);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+ cuda::multiply(true, dest, dest, src);
+ cpu::multiply(true, dest2, dest2, src2);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ dest2 = dest; src2 = src;
+ cuda::affine_transform(dest, src, 2, 3);
+ cpu::affine_transform(dest2, src2, 2, 3);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(srcb);
+ dest2 = dest; src2 = src; srcb2 = srcb;
+ cuda::affine_transform(dest, src, srcb, 2, 3, 4);
+ cpu::affine_transform(dest2, src2, srcb2, 2, 3, 4);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(srcb);
+ rnd.fill_uniform(srcc);
+ dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc;
+ cuda::affine_transform(dest, src, srcb, srcc, 2, 3, 4, 5);
+ cpu::affine_transform(dest2, src2, srcb2, srcc2, 2, 3, 4, 5);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ cuda::affine_transform(dest, src, srcb, srcc, 2, 3, 4, 0);
+ cpu::affine_transform(dest2, src2, srcb2, srcc2, 2, 3, 4, 0);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ cuda::affine_transform_range(0, dest.size(), dest, src, srcb, srcc, 2, 3, 4);
+ cpu::affine_transform_range(0, dest2.size(), dest2, src2, srcb2, srcc2, 2, 3, 4);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ if (3 < dest.size())
+ {
+ dest = 999;
+ dest2 = 999;
+ cuda::affine_transform_range(3, dest.size()-1, dest, src, srcb, srcc, 2, 3, 4);
+ cpu::affine_transform_range(3, dest2.size()-1, dest2, src2, srcb2, srcc2, 2, 3, 4);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+ cuda::affine_transform_range(dest.size(), dest.size(), dest, src, srcb, srcc, 2, 3, 4);
+ cpu::affine_transform_range(dest2.size(), dest2.size(), dest2, src2, srcb2, srcc2, 2, 3, 4);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+ }
+
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(srcb);
+ rnd.fill_uniform(srcc);
+ dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc;
+ cuda::affine_transform(dest, src, srcb, srcc);
+ cpu::affine_transform(dest2, src2, srcb2, srcc2);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+ // now exercise code path where the A/B tensors have num_samples()==1
+ srcb.set_size(1,nc);
+ srcc.set_size(1,nc);
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(src);
+ rnd.fill_uniform(srcb);
+ rnd.fill_uniform(srcc);
+ dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc;
+ cuda::affine_transform(dest, src, srcb, srcc);
+ cpu::affine_transform(dest2, src2, srcb2, srcc2);
+ DLIB_TEST(equal(mat(dest),mat(dest2)));
+
+
+ rnd.fill_uniform(src);
+ src2 = src;
+ cuda::threshold(src, 0.5);
+ cpu::threshold(src2, 0.5);
+ DLIB_TEST(equal(mat(src),mat(src2)));
+
+ {
+ resizable_tensor dest(3,4);
+ resizable_tensor A, B;
+ A = dest;
+ B = dest;
+
+ rnd.fill_uniform(dest);
+ rnd.fill_uniform(A);
+ rnd.fill_uniform(B);
+
+ dest.set_size(1,4);
+
+ cuda::multiply(false, dest, A, B);
+ DLIB_TEST_MSG(max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B))))) < 1e-6, max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B))))));
+
+ A.set_size(1,4);
+ rnd.fill_uniform(A);
+ matrix<float> AA = join_cols(mat(A),mat(A)); AA = join_cols(mat(A),AA);
+
+ cuda::multiply(false, dest, A, B);
+ DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+
+ cuda::multiply(false, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+ matrix<float> prevdest = mat(dest);
+ cuda::multiply(true, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-prevdest-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6);
+
+ dest.set_size(3,4);
+ cuda::multiply(false, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6);
+ prevdest = mat(dest);
+ cuda::multiply(true, dest, B, A);
+ DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6);
+
+ cuda::multiply(false, dest, A, B);
+ DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6);
+ }
+
+ {
+ resizable_tensor invnorms1, invnorms2;
+ resizable_tensor data(4,5), out1, out2;
+ rnd.fill_uniform(data);
+
+ const double eps = 0.1;
+
+ invnorms2 = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps));
+ tt::inverse_norms(invnorms1, data, eps);
+ DLIB_TEST(max(abs(mat(invnorms1)-mat(invnorms2))) < 1e-6);
+
+ out1.copy_size(data);
+ tt::scale_rows(out1, data, invnorms1);
+ out2 = scale_rows(mat(data), mat(invnorms1));
+ DLIB_TEST(max(abs(mat(out1)-mat(out2))) < 1e-6);
+ }
+
+ {
+ resizable_tensor a(123,432), b(123,432);
+ rnd.fill_gaussian(a);
+ rnd.fill_gaussian(b);
+
+ resizable_tensor out;
+ dot_prods(out, a,b);
+ const matrix<float> truth = sum_cols(pointwise_multiply(mat(a), mat(b)));
+ DLIB_TEST(max(abs(mat(out) - truth)) < 1e-4);
+ out = 0;
+ DLIB_TEST(max(abs(mat(out) - truth)) > 1e-2);
+ dot_prods(false, out, a,b);
+ DLIB_TEST(max(abs(mat(out) - truth)) < 1e-4);
+ dot_prods(true, out, a,b);
+ DLIB_TEST(max(abs(mat(out)/2 - truth)) < 1e-4);
+ DLIB_TEST(max(abs(mat(out) - truth)) > 1e-2);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void compare_bn_gpu_and_cpu()
+ {
+ print_spinner();
+ resizable_tensor dest, dest2;
+ resizable_tensor means, means2;
+ resizable_tensor invstds, invstds2;
+ resizable_tensor running_means, running_means2;
+ resizable_tensor running_variances, running_variances2;
+ resizable_tensor src(64,20,100,100);
+ resizable_tensor gamma(1,20,100,100);
+ resizable_tensor beta(1,20,100,100);
+ gamma = 2;
+ beta = 3;
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(src);
+
+
+ cpu::batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, invstds, 1, running_means, running_variances, src, gamma, beta);
+ cuda::batch_normalize(DEFAULT_BATCH_NORM_EPS,dest2,means2,invstds2, 1, running_means2, running_variances2, src, gamma, beta);
+
+ dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2)));
+ dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2)));
+ dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2)));
+ dlog << LINFO << "running_means error: "<< max(abs(mat(running_means) -mat(running_means2)));
+ dlog << LINFO << "running_variances error: "<< max(abs(mat(running_variances) -mat(running_variances2)));
+
+ DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(running_means) -mat(running_means2))) < 1e-4);
+ DLIB_TEST_MSG(max(abs(mat(running_variances) -mat(running_variances2))) < 1e-4,
+ mean(mat(running_variances))
+ << "\n" << mean(mat(running_variances2))
+ << "\n" << max(abs(mat(running_variances) -mat(running_variances2)))
+ << "\n" << mean(abs(mat(running_variances) -mat(running_variances2)))
+ );
+
+
+ // now check that the gradients match as well
+ resizable_tensor gradient_input;
+ resizable_tensor src_grad, gamma_grad, beta_grad;
+ resizable_tensor src_grad2, gamma_grad2, beta_grad2;
+ gradient_input.copy_size(dest);
+ src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad;
+ gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad;
+ beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad;
+ rnd.fill_uniform(gradient_input);
+
+
+ cpu::batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+ cuda::batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2);
+
+ dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2)));
+ dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2)));
+ dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2)));
+ DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-4);
+ }
+
+ void compare_bn_conv_gpu_and_cpu()
+ {
+ print_spinner();
+ resizable_tensor dest, dest2;
+ resizable_tensor means, means2;
+ resizable_tensor invstds, invstds2;
+ resizable_tensor running_means, running_means2;
+ resizable_tensor running_variances, running_variances2;
+ resizable_tensor src(2,8,10,9);
+ resizable_tensor gamma(1,8);
+ resizable_tensor beta(1,8);
+ gamma = 2;
+ beta = 3;
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(src);
+
+ cpu::batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest,means,invstds,1,running_means,running_variances, src, gamma, beta);
+ cuda::batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest2,means2,invstds2,1,running_means2,running_variances2, src, gamma, beta);
+
+ dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2)));
+ dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2)));
+ dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2)));
+ dlog << LINFO << "running_means error: "<< max(abs(mat(running_means) -mat(running_means2)));
+ dlog << LINFO << "running_variances error: "<< max(abs(mat(running_variances) -mat(running_variances2)));
+
+ DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(running_means) -mat(running_means2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(running_variances) -mat(running_variances2))) < 1e-4);
+
+ resizable_tensor gradient_input;
+ resizable_tensor src_grad, gamma_grad, beta_grad;
+ resizable_tensor src_grad2, gamma_grad2, beta_grad2;
+ gradient_input.copy_size(dest);
+ src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad;
+ gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad;
+ beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad;
+ rnd.fill_uniform(gradient_input);
+
+
+ cpu::batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
+ cuda::batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2);
+
+ dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2)));
+ dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2)));
+ dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2)));
+ DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-4);
+ DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-4);
+ }
+
+
+ void test_more_ops2()
+ {
+ dlib::rand rnd;
+ tt::tensor_rand trand;
+
+ for (int iter = 0; iter < 100; ++iter)
+ {
+ print_spinner();
+ resizable_tensor dest1, dest2, src1, src2;
+ src1.set_size(rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1);
+ dest1.copy_size(src1);
+ dest2.copy_size(src1);
+ src2.set_size(1,src1.k(),1,1);
+
+ trand.fill_uniform(dest1);
+ trand.fill_uniform(dest2);
+ trand.fill_uniform(src1);
+ trand.fill_uniform(src2);
+
+ cpu::multiply_conv(false, dest1, src1, src2);
+ cuda::multiply_conv(false, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5);
+ cpu::multiply_conv(true, dest1, src1, src2);
+ cuda::multiply_conv(true, dest2, src1, src2);
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5);
+
+
+ // now try it using the other mode of multiply_conv
+ src2.copy_size(src1);
+ dest1.set_size(1,src1.k(),1,1);
+ dest2.set_size(1,src1.k(),1,1);
+ trand.fill_uniform(dest1);
+ trand.fill_uniform(dest2);
+ trand.fill_uniform(src1);
+ trand.fill_uniform(src2);
+ cpu::multiply_conv(false, dest1, src1, src2);
+ cuda::multiply_conv(false, dest2, src1, src2);
+ float scale = max(abs(mat(dest1)));
+ float scalem = mean(abs(mat(dest1)));
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)))/scale);
+ DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)))/scalem);
+ matrix<float> prevd2 = mat(dest2);
+ cpu::multiply_conv(false, dest1, src1, src2);
+ cuda::multiply_conv(true, dest2, src1, src2);
+ scale = max(abs(mat(dest1)));
+ scalem = mean(abs(mat(dest1)));
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)+prevd2))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)+prevd2))/scale);
+ DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)+prevd2))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)+prevd2))/scalem);
+ }
+
+ for (int iter = 0; iter < 100; ++iter)
+ {
+ print_spinner();
+ resizable_tensor dest1, dest2, src, A, B;
+ src.set_size(rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1);
+ dest1.copy_size(src);
+ dest2.copy_size(src);
+ A.set_size(1,src.k(),1,1);
+ B.set_size(1,src.k(),1,1);
+
+ trand.fill_uniform(dest1);
+ trand.fill_uniform(dest2);
+ trand.fill_uniform(src);
+ trand.fill_uniform(A);
+ trand.fill_uniform(B);
+
+ cpu::affine_transform_conv(dest1, src, A, B);
+ cuda::affine_transform_conv(dest2, src, A, B);
+ DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5);
+ }
+
+ for (int iter = 0; iter < 100; ++iter)
+ {
+ print_spinner();
+ resizable_tensor dest1, dest2, g;
+ g.set_size(rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1,
+ rnd.get_random_32bit_number()%30+1);
+ dest1.set_size(1,g.k(),1,1);
+ dest2.set_size(1,g.k(),1,1);
+
+ trand.fill_uniform(dest1);
+ trand.fill_uniform(dest2);
+ trand.fill_uniform(g);
+
+ cpu::assign_conv_bias_gradient(dest1, g);
+ cuda::assign_conv_bias_gradient(dest2, g);
+ const float scale = max(abs(mat(dest1)));
+ const float scalem = mean(abs(mat(dest1)));
+ DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)))/scale);
+ DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)))/scalem);
+ }
+
+ }
+
+#endif // DLIB_USE_CUDA
+
+// ----------------------------------------------------------------------------------------
+
+ void test_max_pool(
+ const int window_height,
+ const int window_width,
+ const int stride_y,
+ const int stride_x,
+ const int padding_y,
+ const int padding_x
+ )
+ {
+ print_spinner();
+ resizable_tensor A, B, gradient_input;
+ A.set_size(4,5,16,7);
+ B.copy_size(A);
+ gradient_input.copy_size(A);
+
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(A,0,1);
+ rnd.fill_gaussian(B,0,1);
+ rnd.fill_gaussian(gradient_input,0,1);
+
+
+ tt::pooling mp;
+
+ mp.setup_max_pooling(window_height,window_width,stride_y,stride_x,padding_y,padding_x);
+ mp(A, B);
+
+ // make sure max pooling does what it's spec says it should.
+ DLIB_TEST( A.num_samples() == B.num_samples());
+ DLIB_TEST( A.k() == B.k());
+
+ DLIB_TEST( A.nr() == 1+(B.nr()+2*padding_y-window_height)/stride_y);
+ DLIB_TEST( A.nc() == 1+(B.nc()+2*padding_x-window_width)/stride_x);
+
+ const long x_offset = window_width/2 - padding_x;
+ const long y_offset = window_height/2 - padding_y;
+ for (long s = 0; s < A.num_samples(); ++s)
+ {
+ for (long k = 0; k < A.k(); ++k)
+ {
+ for (long r = 0; r < A.nr(); ++r)
+ {
+ for (long c = 0; c < A.nc(); ++c)
+ {
+ DLIB_TEST_MSG(image_plane(A,s,k)(r,c) == max(subm_clipped(image_plane(B,s,k),
+ centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height))),
+ "padding: "<< padding_x << " " << padding_y
+ << " window size: " << window_width << " " << window_height
+ << " stride: " << stride_x << " " << stride_y
+ );
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_avg_pool(
+ const int window_height,
+ const int window_width,
+ const int stride_y,
+ const int stride_x,
+ const int padding_y,
+ const int padding_x
+ )
+ {
+ print_spinner();
+ resizable_tensor A, B, gradient_input;
+ A.set_size(4,5,16,7);
+ B.copy_size(A);
+ gradient_input.copy_size(A);
+
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(A,0,1);
+ rnd.fill_gaussian(B,0,1);
+ rnd.fill_gaussian(gradient_input,0,1);
+
+
+ tt::pooling mp;
+
+ mp.setup_avg_pooling(window_height,window_width,stride_y,stride_x,padding_y,padding_x);
+ mp(A, B);
+
+ // make sure avg pooling does what it's spec says it should.
+ DLIB_TEST( A.num_samples() == B.num_samples());
+ DLIB_TEST( A.k() == B.k());
+ DLIB_TEST( A.nr() == 1+(B.nr()+2*padding_y-window_height)/stride_y);
+ DLIB_TEST( A.nc() == 1+(B.nc()+2*padding_x-window_width)/stride_x);
+
+ const long x_offset = window_width/2 - padding_x;
+ const long y_offset = window_height/2 - padding_y;
+ for (long s = 0; s < A.num_samples(); ++s)
+ {
+ for (long k = 0; k < A.k(); ++k)
+ {
+ for (long r = 0; r < A.nr(); ++r)
+ {
+ for (long c = 0; c < A.nc(); ++c)
+ {
+ float expected = mean(subm_clipped(image_plane(B,s,k),
+ centered_rect(c*stride_x+x_offset,
+ r*stride_y+y_offset,
+ window_width,
+ window_height)));
+ float err = abs(image_plane(A,s,k)(r,c) - expected);
+ DLIB_TEST_MSG(err < 1e-5, err << " " << expected << " " << image_plane(A,s,k)(r,c));
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_layers()
+ {
+ {
+ print_spinner();
+ extract_<0,2,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ extract_<3,2,1,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ extract_<0,2,1,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ upsample_<1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ upsample_<2,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ upsample_<2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ upsample_<3,3> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ l2normalize_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ multiply_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ max_pool_<3,3,1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ avg_pool_<3,3,1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ affine_ l(CONV_MODE);
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ affine_ l(FC_MODE);
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ bn_<CONV_MODE> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ bn_<FC_MODE> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ cont_<3,3,3,2,2,0,0> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ cont_<3,3,3,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ cont_<3,3,3,1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ cont_<3,3,3,1,1,0,0> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ cont_<3,2,2,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,2,2,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,3,3,1,1>l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,3,2,1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<2,1,1,1,1> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,0,2,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,2,0,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ con_<3,0,0,2,2> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ fc_<1,FC_HAS_BIAS> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ fc_<5,FC_HAS_BIAS> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ fc_<4,FC_NO_BIAS> l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ relu_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ prelu_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ sig_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ htan_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ softmax_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ {
+ print_spinner();
+ softmax_all_ l;
+ auto res = test_layer(l);
+ DLIB_TEST_MSG(res, res);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <unsigned long n, typename SUBNET> using rcon = max_pool<2,2,2,2,relu<bn_con<con<n,5,5,1,1,SUBNET>>>>;
+ template <unsigned long n, typename SUBNET> using rfc = relu<bn_fc<fc<n,SUBNET>>>;
+
+ void test_tagging(
+ )
+ {
+ typedef loss_multiclass_log<rfc<10,skip1<rfc<84,rfc<120,tag1<rcon<16,rcon<6,input<matrix<unsigned char>>>>>>>>>> net_type;
+
+ net_type net;
+ net_type net2(num_fc_outputs(4));
+
+ DLIB_TEST(layer<tag1>(net).num_computational_layers == 8);
+ DLIB_TEST(layer<skip1>(net).num_computational_layers == 8+3+3);
+ DLIB_TEST(layer<tag1>(net).num_layers == 10);
+ DLIB_TEST(layer<skip1>(net).num_layers == 10+3+3+1);
+ DLIB_TEST(&layer<skip1>(net).get_output() == &layer<tag1>(net).get_output());
+ DLIB_TEST(&layer<skip1>(net).get_output() != &layer<tag1>(net).subnet().subnet().get_output());
+ DLIB_TEST(net.subnet().subnet().subnet().layer_details().get_num_outputs() == 10);
+ DLIB_TEST(net2.subnet().subnet().subnet().layer_details().get_num_outputs() == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ int N,
+ template <typename> class BN,
+ int stride,
+ typename SUBNET
+ >
+ using block = BN<con<N,3,3,1,1,relu<BN<con<N,3,3,stride,stride,SUBNET>>>>>;
+
+ template <
+ template <int,template<typename>class,int,typename> class block,
+ int N,
+ template<typename>class BN,
+ typename SUBNET
+ >
+ using residual = add_prev1<block<N,BN,1,tag1<SUBNET>>>;
+
+ template <
+ template <int,template<typename>class,int,typename> class block,
+ int N,
+ template<typename>class BN,
+ typename SUBNET
+ >
+ using residual_down = add_prev2<avg_pool<2,2,2,2,skip1<tag2<block<N,BN,2,tag1<SUBNET>>>>>>;
+
+
+ template <typename SUBNET> using res = relu<residual<block,8,bn_con,SUBNET>>;
+ template <typename SUBNET> using ares = relu<residual<block,8,affine,SUBNET>>;
+ template <typename SUBNET> using res_down = relu<residual_down<block,8,bn_con,SUBNET>>;
+ template <typename SUBNET> using ares_down = relu<residual_down<block,8,affine,SUBNET>>;
+
+ template <typename SUBNET>
+ using pres = prelu<add_prev1<bn_con<con<8,3,3,1,1,prelu<bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;
+
+ void test_visit_funcions()
+ {
+ using net_type2 = loss_multiclass_log<fc<10,
+ avg_pool_everything<
+ pres<res<res<res_down< // 2 prelu layers here
+ tag4<repeat<9,pres, // 9 groups, each containing 2 prelu layers
+ res_down<
+ res<
+ input<matrix<unsigned char>>
+ >>>>>>>>>>>;
+
+ net_type2 pnet;
+
+ DLIB_TEST_MSG(pnet.num_layers == 131, pnet.num_layers);
+ DLIB_TEST_MSG(pnet.num_computational_layers == 109, pnet.num_computational_layers);
+
+ std::vector<bool> hit(pnet.num_computational_layers, false);
+ size_t count = 0;
+ visit_layer_parameter_gradients(pnet, [&](size_t i, tensor& ){hit[i] = true; ++count; });
+ for (auto x : hit)
+ DLIB_TEST(x);
+ DLIB_TEST(count == pnet.num_computational_layers);
+
+ count = 0;
+ std::vector<bool> hit2(pnet.num_computational_layers, false);
+ visit_layer_parameters(pnet, [&](size_t i, tensor& ){hit2[i] = true; ++count; });
+ for (auto x : hit2)
+ DLIB_TEST(x);
+ DLIB_TEST(count == pnet.num_computational_layers);
+ }
+
+ float tensor_read_cpu(const tensor& t, long i, long k, long r, long c)
+ {
+ const float* p = t.host() + t.k() * t.nr() * t.nc() * i +
+ t.nr() * t.nc() * k + t.nc() * r + c;
+ return *p;
+ }
+ void test_copy_tensor_cpu()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor dest(10, 9, 7, 15);
+ resizable_tensor src1(10, 3, 7, 15);
+ resizable_tensor src2(10, 3, 7, 15);
+ resizable_tensor src3(10, 9, 7, 15);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(dest);
+ rnd.fill_gaussian(src1);
+ rnd.fill_gaussian(src2);
+ rnd.fill_gaussian(src3);
+
+ cpu::copy_tensor(false, dest, 0, src1, 0, src1.k()); //full copy src1->dest
+ cpu::copy_tensor(false, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1
+ cpu::copy_tensor(false, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest
+
+
+ for (long i = 0; i < dest.num_samples(); ++i)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float dest_value = tensor_read_cpu(dest, i, k, r, c);
+ // first part is from src1
+ if (k < src1.k())
+ {
+ float src_value = tensor_read_cpu(src1, i, k, r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ // second part is from src2
+ else if (k < src1.k() + src2.k())
+ {
+ float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ // third part is from src3
+ else
+ {
+ float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ }
+ }
+ }
+ }
+ }
+ void test_copy_tensor_add_to_cpu()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor dest(10, 9, 7, 15);
+ resizable_tensor src1(10, 3, 7, 15);
+ resizable_tensor src2(10, 3, 7, 15);
+ resizable_tensor src3(10, 9, 7, 15);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(dest);
+ rnd.fill_gaussian(src1);
+ rnd.fill_gaussian(src2);
+ rnd.fill_gaussian(src3);
+
+ const resizable_tensor old_dest = dest;
+
+ cpu::copy_tensor(true, dest, 0, src1, 0, src1.k()); //full copy src1->dest
+ cpu::copy_tensor(true, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1
+ cpu::copy_tensor(true, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest
+
+
+ for (long i = 0; i < dest.num_samples(); ++i)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+ float dest_value = tensor_read_cpu(dest, i, k, r, c);
+ // first part is from src1
+ if (k < src1.k())
+ {
+ float src_value = tensor_read_cpu(src1, i, k, r, c)+old_dest_value;
+ DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+ }
+ // second part is from src2
+ else if (k < src1.k() + src2.k())
+ {
+ float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c)+old_dest_value;
+ DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+ }
+ // third part is from src3
+ else
+ {
+ float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c)+old_dest_value;
+ DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+ }
+ }
+ }
+ }
+ }
+ }
+#ifdef DLIB_USE_CUDA
+ void test_copy_tensor_gpu()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor dest(10, 9, 7, 15);
+ resizable_tensor src1(10, 3, 7, 15);
+ resizable_tensor src2(10, 3, 7, 15);
+ resizable_tensor src3(10, 9, 7, 15);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(dest);
+ rnd.fill_gaussian(src1);
+ rnd.fill_gaussian(src2);
+ rnd.fill_gaussian(src3);
+ cuda::copy_tensor(false, dest, 0, src1, 0, src1.k()); //full copy src1->dest
+ cuda::copy_tensor(false, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1
+ cuda::copy_tensor(false, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest
+
+
+ for (long i = 0; i < dest.num_samples(); ++i)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float dest_value = tensor_read_cpu(dest, i, k, r, c);
+ // first part is from src1
+ if (k < src1.k())
+ {
+ float src_value = tensor_read_cpu(src1, i, k, r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ // second part is from src2
+ else if (k < src1.k() + src2.k())
+ {
+ float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ // third part is from src3
+ else
+ {
+ float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c);
+ DLIB_TEST(src_value == dest_value);
+ }
+ }
+ }
+ }
+ }
+ }
+ void test_copy_tensor_add_to_gpu()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+ resizable_tensor dest(10, 9, 7, 15);
+ resizable_tensor src1(10, 3, 7, 15);
+ resizable_tensor src2(10, 3, 7, 15);
+ resizable_tensor src3(10, 9, 7, 15);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(dest);
+ rnd.fill_gaussian(src1);
+ rnd.fill_gaussian(src2);
+ rnd.fill_gaussian(src3);
+
+ const resizable_tensor old_dest = dest;
+
+ cuda::copy_tensor(true, dest, 0, src1, 0, src1.k()); //full copy src1->dest
+ cuda::copy_tensor(true, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1
+ cuda::copy_tensor(true, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest
+
+
+ for (long i = 0; i < dest.num_samples(); ++i)
+ {
+ for (long k = 0; k < dest.k(); ++k)
+ {
+ for (long r = 0; r < dest.nr(); ++r)
+ {
+ for (long c = 0; c < dest.nc(); ++c)
+ {
+ float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+ float dest_value = tensor_read_cpu(dest, i, k, r, c);
+ // first part is from src1
+ if (k < src1.k())
+ {
+ float src_value = tensor_read_cpu(src1, i, k, r, c)+old_dest_value;
+ DLIB_TEST_MSG(std::abs(src_value - dest_value) < 1e-6, std::abs(src_value - dest_value));
+ }
+ // second part is from src2
+ else if (k < src1.k() + src2.k())
+ {
+ float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c)+old_dest_value;
+ DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+ }
+ // third part is from src3
+ else
+ {
+ float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c)+old_dest_value;
+ DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+ }
+ }
+ }
+ }
+ }
+ }
+#endif//DLIB_USE_CUDA
+
+ template <typename SUBNET> using concat_block1 = con<5,1,1,1,1,SUBNET>;
+ template <typename SUBNET> using concat_block2 = con<8,3,3,1,1,SUBNET>;
+ template <typename SUBNET> using concat_block3 = max_pool<3,3,1,1,SUBNET>;
+ template <typename SUBNET> using concat_incept = inception3<concat_block1,concat_block2,concat_block3,SUBNET>;
+
+ void test_concat()
+ {
+ using namespace dlib::tt;
+ print_spinner();
+
+ using net_type = concat_incept<input<matrix<float>>>;
+
+ resizable_tensor data(10, 1, 111, 222);
+ tt::tensor_rand rnd;
+ rnd.fill_gaussian(data);
+
+ net_type net;
+
+
+ auto& out = net.forward(data);
+
+ auto& b1o = layer<itag1>(net).get_output();
+ auto& b2o = layer<itag2>(net).get_output();
+ auto& b3o = layer<itag3>(net).get_output();
+
+ resizable_tensor dest(10, 14, 111, 222);
+ copy_tensor(false, dest, 0, b1o, 0, b1o.k());
+ copy_tensor(false, dest, b1o.k(), b2o, 0, b2o.k());
+ copy_tensor(false, dest, b1o.k() + b2o.k(), b3o, 0, b3o.k());
+
+ DLIB_TEST(dest.size() == out.size());
+ int error = memcmp(dest.host(), out.host(), dest.size());
+ DLIB_TEST(error == 0);
+
+ resizable_tensor gr(10, 14, 111, 222);
+ rnd.fill_gaussian(gr);
+
+ resizable_tensor params;
+ net.layer_details().backward(gr, net, params);
+
+ auto& b1g = layer<itag1>(net).subnet().get_gradient_input();
+ auto& b2g = layer<itag2>(net).subnet().get_gradient_input();
+ auto& b3g = layer<itag3>(net).subnet().get_gradient_input();
+
+ resizable_tensor g1(10, 5, 111, 222);
+ resizable_tensor g2(10, 8, 111, 222);
+ resizable_tensor g3(10, 1, 111, 222);
+
+ copy_tensor(false, g1, 0, gr, 0, g1.k());
+ copy_tensor(false, g2, 0, gr, g1.k(), g2.k());
+ copy_tensor(false, g3, 0, gr, g1.k() + g2.k(), g3.k());
+ DLIB_TEST(g1.size() == b1g.size());
+ error = memcmp(g1.host(), b1g.host(), b1g.size());
+ DLIB_TEST(error == 0);
+ DLIB_TEST(g2.size() == b2g.size());
+ error = memcmp(g2.host(), b2g.host(), b2g.size());
+ DLIB_TEST(error == 0);
+ DLIB_TEST(g3.size() == b3g.size());
+ error = memcmp(g3.host(), b3g.host(), b3g.size());
+ DLIB_TEST(error == 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_simple_linear_regression()
+ {
+ const int num_samples = 1000;
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<float> y(num_samples);
+ ::std::default_random_engine generator(16);
+ ::std::normal_distribution<float> distribution(0,0.1);
+ const float true_intercept = 50.0;
+ const float true_slope = 10.0;
+ for ( int ii = 0; ii < num_samples; ++ii )
+ {
+ const double val = static_cast<double>(ii)/10;
+ matrix<double> tmp(1,1);
+ tmp = val;
+ x[ii] = tmp;
+ y[ii] = (true_intercept + true_slope*static_cast<float>(val) + distribution(generator));
+ }
+
+ using net_type = loss_mean_squared<fc<1, input<matrix<double>>>>;
+ net_type net;
+ layer<1>(net).layer_details().set_bias_learning_rate_multiplier(300);
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(1e-5);
+ trainer.set_min_learning_rate(1e-6);
+ trainer.set_mini_batch_size(50);
+ trainer.set_max_num_epochs(170);
+ trainer.train(x, y);
+
+ const float slope = layer<1>(net).layer_details().get_weights().host()[0];
+ const float slope_error = abs(true_slope - slope);
+ const float intercept = layer<1>(net).layer_details().get_biases().host()[0];
+ const float intercept_error = abs(true_intercept - intercept);
+ const float eps_slope = 0.05, eps_intercept = 0.1;
+
+ DLIB_TEST_MSG(slope_error <= eps_slope,
+ "Expected slope = " << true_slope << " Estimated slope = " << slope << " Error limit = " << eps_slope);
+ DLIB_TEST_MSG(intercept_error <= eps_intercept,
+ "Expected intercept = " << true_intercept << " Estimated intercept = " << intercept << " Error limit = " << eps_intercept);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_simple_linear_regression_eil()
+ {
+ print_spinner();
+ const int num_samples = 1000;
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<float> y(num_samples);
+ ::std::default_random_engine generator(16);
+ ::std::normal_distribution<float> distribution(0,0.0001);
+ const float true_intercept = 50.0;
+ const float true_slope = 10.0;
+ for ( int ii = 0; ii < num_samples; ++ii )
+ {
+ const double val = static_cast<double>(ii)/10;
+ matrix<double> tmp(1,1);
+ tmp = val;
+ x[ii] = tmp;
+ y[ii] = (true_intercept + true_slope*static_cast<float>(val) + distribution(generator));
+ }
+
+ using net_type = loss_epsilon_insensitive<fc<1, input<matrix<double>>>>;
+ net_type net(0.01);
+ layer<1>(net).layer_details().set_bias_learning_rate_multiplier(300);
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(1e-5);
+ trainer.set_min_learning_rate(1e-8);
+ trainer.set_mini_batch_size(50);
+ trainer.set_max_num_epochs(570);
+ trainer.train(x, y);
+
+ const float slope = layer<1>(net).layer_details().get_weights().host()[0];
+ const float slope_error = abs(true_slope - slope);
+ const float intercept = layer<1>(net).layer_details().get_biases().host()[0];
+ const float intercept_error = abs(true_intercept - intercept);
+ const float eps_slope = 0.01, eps_intercept = 0.1;
+
+ dlog << LINFO << "slope_error: "<< slope_error;
+ dlog << LINFO << "intercept_error: "<< intercept_error;
+ DLIB_TEST_MSG(slope_error <= eps_slope,
+ "Expected slope = " << true_slope << " Estimated slope = " << slope << " Error limit = " << eps_slope);
+ DLIB_TEST_MSG(intercept_error <= eps_intercept,
+ "Expected intercept = " << true_intercept << " Estimated intercept = " << intercept << " Error limit = " << eps_intercept);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_simple_linear_regression_with_mult_prev()
+ {
+ srand(1234);
+ print_spinner();
+ const int num_samples = 1000;
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<float> y(num_samples);
+ const float true_slope = 2.0;
+ for ( int ii = 0; ii < num_samples; ++ii )
+ {
+ const double val = static_cast<double>(ii-500)/100;
+ matrix<double> tmp(1,1);
+ tmp = val;
+ x[ii] = tmp;
+ y[ii] = ( true_slope*static_cast<float>(val*val));
+ }
+
+ randomize_samples(x,y);
+
+ using net_type = loss_mean_squared<fc<1, mult_prev1<fc<2,tag1<fc<2,input<matrix<double>>>>>>>>;
+ net_type net;
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(1e-5);
+ trainer.set_min_learning_rate(1e-11);
+ trainer.set_mini_batch_size(50);
+ trainer.set_max_num_epochs(2000);
+ trainer.train(x, y);
+
+ running_stats<double> rs;
+ for (size_t i = 0; i < x.size(); ++i)
+ {
+ double val = y[i];
+ double out = net(x[i]);
+ rs.add(std::abs(val-out));
+ }
+ dlog << LINFO << "rs.mean(): " << rs.mean();
+ dlog << LINFO << "rs.stddev(): " << rs.stddev();
+ dlog << LINFO << "rs.max(): " << rs.max();
+ DLIB_TEST(rs.mean() < 0.1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_multioutput_linear_regression()
+ {
+ const int num_outputs = 2;
+ const int num_samples = 1000;
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<matrix<float>> y(num_samples);
+ ::std::default_random_engine generator(16);
+ ::std::normal_distribution<float> distribution(0,0.1);
+ ::std::normal_distribution<float> slope_distribution(10,5);
+ ::std::normal_distribution<float> intercept_distribution(50,10);
+ ::std::vector<float> true_intercepts(num_outputs);
+ ::std::vector<float> true_slopes(num_outputs);
+ for ( int jj = 0; jj < num_outputs; ++jj )
+ {
+ true_slopes[jj] = slope_distribution(generator);
+ true_intercepts[jj] = intercept_distribution(generator);
+ }
+ matrix<float> ytmp(num_outputs, 1);
+ for ( int ii = 0; ii < num_samples; ++ii )
+ {
+ const double val = static_cast<double>(ii)/10;
+ matrix<double> tmp(1,1);
+ tmp = val;
+ x[ii] = tmp;
+ for ( int jj = 0; jj < num_outputs; ++jj )
+ ytmp(jj, 0) = (true_intercepts[jj] + true_slopes[jj]*static_cast<float>(val) + distribution(generator));
+
+ y[ii] = ytmp;
+ }
+
+ using net_type = loss_mean_squared_multioutput<fc<num_outputs, input<matrix<double>>>>;
+ net_type net;
+ layer<1>(net).layer_details().set_bias_learning_rate_multiplier(900);
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(1e-5);
+ trainer.set_min_learning_rate(1e-6);
+ trainer.set_mini_batch_size(50);
+ trainer.set_max_num_epochs(170);
+ trainer.train(x, y);
+
+ float slope_error = 0.0;
+ float intercept_error = 0.0;
+ const float eps_slope = 0.05, eps_intercept = 0.1;
+
+ for ( int jj = 0; jj < num_outputs; ++jj )
+ {
+ slope_error += abs(layer<1>(net).layer_details().get_weights().host()[jj] - true_slopes[jj]);
+ intercept_error += abs(layer<1>(net).layer_details().get_biases().host()[jj] - true_intercepts[jj]);
+ }
+
+ slope_error /= float(num_outputs);
+ intercept_error /= float(num_outputs);
+
+ DLIB_TEST_MSG(slope_error <= eps_slope,
+ "Average absolute slope error = " << slope_error << " Error limit = " << eps_slope);
+ DLIB_TEST_MSG(intercept_error <= eps_intercept,
+ "Average absolute intercept error = " << intercept_error << " Error limit = " << eps_intercept);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_simple_autoencoder()
+ {
+ print_spinner();
+
+ srand(1234);
+
+ const int output_width = 7;
+ const int output_height = 7;
+ const int num_samples = 100;
+ ::std::vector<matrix<float>> x(num_samples);
+
+ matrix<float> tmp(output_width, output_height);
+ for (int i = 0; i < num_samples; ++i)
+ {
+ const int model = i % 4;
+
+ for (int r = 0; r < output_height; ++r)
+ for (int c = 0; c < output_width; ++c)
+ switch (model) {
+ case 0: tmp(r, c) = r / output_height; break;
+ case 1: tmp(r, c) = c / output_width; break;
+ case 2: tmp(r, c) = 1.0 - r / output_height; break;
+ case 3: tmp(r, c) = 1.0 - c / output_width; break;
+ default: DLIB_TEST_MSG(false, "Invalid model: " << model << " (should be between 0 and 3)");
+ }
+
+ x[i] = tmp;
+ }
+
+ using net_type = loss_mean_squared_per_pixel<
+ cont<1,output_height,output_width,2,2,
+ relu<con<4,output_height,output_width,2,2,
+ input<matrix<float>>>>>>;
+ net_type net;
+
+ const auto autoencoder_error = [&x, &net, &output_height, &output_width]()
+ {
+ const auto y = net(x);
+ double error = 0.0;
+ for (size_t i = 0; i < x.size(); ++i)
+ for (int r = 0; r < output_height; ++r)
+ for (int c = 0; c < output_width; ++c)
+ error += fabs(y[i](r, c) - x[i](r, c));
+
+ return error / (x.size() * output_height * output_width);
+ };
+
+ // The autoencoder can't be very good before it's been trained
+ // (or at least the probability of the reconstruction error
+ // being small should be super low; in fact, the error ought to
+ // be much higher than 0.01, however since the initialization
+ // is random, putting the limit below too high could make the
+ // tests fail when other, unrelated tests are added into the
+ // sequence)
+ const double error_before = autoencoder_error();
+ DLIB_TEST_MSG(error_before > 0.01, "Autoencoder error before training = " << error_before);
+
+ // Make sure there's an information bottleneck, as intended
+ const auto& output2 = dlib::layer<2>(net).get_output();
+ DLIB_TEST(output2.nr() == 1);
+ DLIB_TEST(output2.nc() == 1);
+ DLIB_TEST(output2.k() == 4);
+
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(0.01);
+ trainer.set_max_num_epochs(1000);
+ trainer.train(x, x);
+
+ // Now we should have learned everything there is to it
+ const double error_after = autoencoder_error();
+ DLIB_TEST_MSG(error_after < 1e-6, "Autoencoder error after training = " << error_after);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task()
+ {
+ print_spinner();
+
+ constexpr uint16_t num_classes = 7;
+ constexpr uint16_t true_label = num_classes / 2;
+
+ ::std::vector<matrix<float>> x({ matrix<float,1,1>({ 1 }) });
+ ::std::vector<matrix<uint16_t>> y({ matrix<uint16_t,1,1>({ true_label }) });
+
+ using net_type = loss_multiclass_log_per_pixel<con<num_classes,1,1,1,1,input<matrix<float>>>>;
+ net_type net;
+
+ dnn_trainer<net_type> trainer(net, sgd(0,0));
+ trainer.set_learning_rate(1e7);
+ trainer.set_max_num_epochs(1);
+ trainer.train(x, y);
+
+ const tensor& learned_params = layer<1>(net).layer_details().get_layer_params();
+ const float* learned_params_data = learned_params.host();
+
+ for (int is_bias = 0; is_bias <= 1; ++is_bias) {
+ for (uint16_t k = 0; k < num_classes; ++k) {
+ size_t index = k + is_bias * num_classes;
+ DLIB_TEST(index < learned_params.size());
+ if (k == true_label) {
+ DLIB_TEST(learned_params_data[index] > 1e5);
+ }
+ else {
+ DLIB_TEST(learned_params_data[index] < -1e5);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task()
+ {
+ print_spinner();
+
+ constexpr int input_height = 35;
+ constexpr int input_width = 27;
+ constexpr int output_height = input_height;
+ constexpr int output_width = input_width;
+ constexpr int num_samples = 7;
+ constexpr int num_classes = 5;
+
+ ::std::vector<matrix<float>> x(num_samples);
+ ::std::vector<matrix<uint16_t>> y(num_samples);
+
+ matrix<float> xtmp(input_height, input_width);
+ matrix<uint16_t> ytmp(output_height, output_width);
+
+ ::std::default_random_engine generator(16);
+ ::std::bernoulli_distribution coinflip(0.5);
+
+ using filter_type = con<num_classes,1,1,1,1,input<matrix<float>>>;
+
+ // Define a "truth" filter
+ filter_type truth_filter;
+ truth_filter(xtmp); // Set up the convolutional layer
+
+ // Generate training data
+ for (int ii = 0; ii < num_samples; ++ii) {
+ // Generate random inputs x
+ for (int jj = 0; jj < input_height; ++jj)
+ for (int kk = 0; kk < input_width; ++kk)
+ xtmp(jj, kk) = coinflip(generator) ? 1.f : -1.f;
+ x[ii] = xtmp;
+
+ // Generate target output y by applying the truth filter on x
+ const tensor& output = truth_filter(xtmp);
+ const float* const out_data = output.host();
+
+ const auto out_element = [&](int row, int column, int k) {
+ return out_data[(k * output.nr() + row) * output.nc() + column];
+ };
+
+ for (int jj = 0; jj < output_height; ++jj) {
+ for (int kk = 0; kk < output_width; ++kk) {
+ uint16_t label = 0;
+ float max_value = out_element(jj, kk, 0);
+ for (long k = 1; k < num_classes; ++k) {
+ const float value = out_element(jj, kk, k);
+ if (value > max_value) {
+ label = static_cast<uint16_t>(k);
+ max_value = value;
+ }
+ }
+ ytmp(jj, kk) = label;
+ }
+ }
+ y[ii] = ytmp;
+ }
+
+ using net_type = loss_multiclass_log_per_pixel<filter_type>;
+ net_type net;
+
+ dnn_trainer<net_type> trainer(net, sgd(0,0));
+ trainer.set_learning_rate(1e6);
+ trainer.set_max_num_epochs(1);
+ trainer.train(x, y);
+
+ // Feed forward the training samples.
+ resizable_tensor temp_tensor;
+ net.subnet().to_tensor(&x[0], &x[0] + num_samples, temp_tensor);
+ net.subnet().forward(temp_tensor);
+ const dimpl::subnet_wrapper<filter_type> wsub(net.subnet());
+ const tensor& output_tensor = wsub.get_output();
+ const float* const out_data = output_tensor.host();
+
+ // Let's have a look at the activations before softmax. They should be pretty high
+ // (in terms of absolute value), because the learning task is trivial.
+ for (int ii = 0; ii < num_samples; ++ii) {
+ for (int jj = 0; jj < output_height; ++jj) {
+ for (int kk = 0; kk < output_width; ++kk) {
+ const uint16_t true_label = y[ii](jj, kk);
+
+ for (long k = 0; k < num_classes; ++k) {
+ const size_t index = ((ii * output_tensor.k() + k) * output_tensor.nr() + jj) * output_tensor.nc() + kk;
+ DLIB_TEST(index < output_tensor.size());
+
+ if (k == true_label) {
+ DLIB_TEST(out_data[index] > 1e4);
+ }
+ else {
+ DLIB_TEST(out_data[index] < -1e4);
+ }
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multiclass_per_pixel_outputs_on_trivial_task()
+ {
+ print_spinner();
+
+ constexpr int input_height = 7;
+ constexpr int input_width = 5;
+ constexpr int output_height = input_height;
+ constexpr int output_width = input_width;
+ constexpr int num_samples = 7;
+ constexpr int num_classes = 5;
+ constexpr int filter_height = 3;
+ constexpr int filter_width = 3;
+
+ ::std::vector<matrix<float>> x(num_samples);
+ ::std::vector<matrix<uint16_t>> y(num_samples);
+
+ matrix<float> xtmp(input_height, input_width);
+ matrix<uint16_t> ytmp(output_height, output_width);
+
+ ::std::default_random_engine generator(16);
+ ::std::bernoulli_distribution coinflip(0.5);
+
+ using filter_type = con<num_classes, filter_height, filter_width, 1, 1, input<matrix<float>>>;
+
+ // Define a "truth" filter
+ filter_type truth_filter;
+ truth_filter(xtmp); // Set up the convolutional layer
+
+ // Generate training data
+ for (int ii = 0; ii < num_samples; ++ii) {
+ // Generate random inputs x
+ for (int jj = 0; jj < input_height; ++jj)
+ for (int kk = 0; kk < input_width; ++kk)
+ xtmp(jj, kk) = coinflip(generator) ? 1.f : -1.f;
+ x[ii] = xtmp;
+
+ // Generate target output y by applying the truth filter on x
+ const tensor& output = truth_filter(xtmp);
+ const float* const out_data = output.host();
+
+ const auto out_element = [&](int row, int column, int k) {
+ return out_data[(k * output.nr() + row) * output.nc() + column];
+ };
+
+ for (int jj = 0; jj < output_height; ++jj) {
+ for (int kk = 0; kk < output_width; ++kk) {
+ uint16_t label = 0;
+ float max_value = out_element(jj, kk, 0);
+ for (long k = 1; k < num_classes; ++k) {
+ const float value = out_element(jj, kk, k);
+ if (value > max_value) {
+ label = static_cast<uint16_t>(k);
+ max_value = value;
+ }
+ }
+ ytmp(jj, kk) = label;
+ }
+ }
+ y[ii] = ytmp;
+ }
+
+ using net_type = loss_multiclass_log_per_pixel<filter_type>;
+ net_type net;
+
+ dnn_trainer<net_type> trainer(net, sgd(0, 0.9));
+ trainer.set_learning_rate(1);
+ trainer.set_max_num_epochs(2000);
+ trainer.train(x, y);
+
+ // The learning task is separable, so the net should have no problem
+ // getting all the outputs right.
+ DLIB_TEST(net(x) == y);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore()
+ {
+ // "Semantic segmentation" - see https://github.com/davisking/dlib/issues/288
+ // Test learning when some pixels are to be ignored, etc.
+
+ print_spinner();
+
+ constexpr int input_height = 5;
+ constexpr int input_width = 7;
+ constexpr int output_height = input_height;
+ constexpr int output_width = input_width;
+ const int num_samples = 1000;
+ const int num_classes = 6;
+ const double ignore_probability = 0.5;
+ const double noise_probability = 0.05;
+
+ ::std::default_random_engine generator(16);
+ ::std::bernoulli_distribution ignore(ignore_probability);
+ ::std::bernoulli_distribution noise_occurrence(noise_probability);
+ ::std::uniform_int_distribution<uint16_t> noisy_label(0, num_classes - 1);
+
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<matrix<uint16_t>> y(num_samples);
+
+ ::std::vector<int> truth_histogram(num_classes);
+
+ matrix<double> xtmp(input_height, input_width);
+ matrix<uint16_t> ytmp(output_height, output_width);
+
+ // The function to be learned.
+ const auto ground_truth = [num_classes](const matrix<double>& x, int row, int column) {
+ double sum = 0.0;
+ const int first_column = std::max(0, column - 1);
+ const int last_column = std::min(static_cast<int>(x.nc() - 1), column + 1);
+ for (int c = first_column; c <= last_column; ++c) {
+ sum += x(row, c);
+ }
+ DLIB_TEST(sum < num_classes);
+ return static_cast<uint16_t>(sum);
+ };
+
+ for ( int ii = 0; ii < num_samples; ++ii ) {
+ for ( int jj = 0; jj < input_height; ++jj ) {
+ for ( int kk = 0; kk < input_width; ++kk ) {
+ // Generate numbers between 0 and 2.
+ double value = static_cast<double>(ii + jj + kk) / 10.0;
+ value -= (static_cast<int>(value) / 2) * 2;
+ DLIB_TEST(value >= 0.0 && value < 2.0);
+ xtmp(jj, kk) = value;
+ }
+ }
+ x[ii] = xtmp;
+
+ for ( int jj = 0; jj < output_height; ++jj ) {
+ for ( int kk = 0; kk < output_width; ++kk ) {
+ uint16_t truth = ground_truth(x[ii], jj, kk);
+ DLIB_TEST(truth < num_classes);
+ ++truth_histogram[truth];
+ if (ignore(generator)) {
+ ytmp(jj, kk) = loss_multiclass_log_per_pixel_::label_to_ignore;
+ }
+ else if (noise_occurrence(generator)) {
+ ytmp(jj, kk) = noisy_label(generator);
+ }
+ else {
+ ytmp(jj, kk) = truth;
+ }
+ }
+ }
+
+ y[ii] = ytmp;
+ }
+
+ const int num_total_elements = num_samples * output_height * output_width;
+
+ { // Require a reasonably balanced truth histogram in order to make sure that a trivial classifier is not enough
+ const int required_min_histogram_value = static_cast<int>(::std::ceil(num_total_elements / num_classes * 0.375));
+ for (auto histogram_value : truth_histogram) {
+ DLIB_TEST_MSG(histogram_value >= required_min_histogram_value,
+ "Histogram value = " << histogram_value << ", required = " << required_min_histogram_value);
+ }
+ }
+
+ using net_type = loss_multiclass_log_per_pixel<bn_con<con<num_classes,1,input_width,1,1,input<matrix<double>>>>>;
+ net_type net;
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(0.1);
+ trainer.set_min_learning_rate(0.01);
+ trainer.set_mini_batch_size(50);
+ trainer.set_max_num_epochs(170);
+ trainer.train(x, y);
+
+ const ::std::vector<matrix<uint16_t>> predictions = net(x);
+
+ int num_correct = 0;
+
+ for ( int ii = 0; ii < num_samples; ++ii ) {
+ const matrix<uint16_t>& prediction = predictions[ii];
+ DLIB_TEST(prediction.nr() == output_height);
+ DLIB_TEST(prediction.nc() == output_width);
+ for ( int jj = 0; jj < output_height; ++jj )
+ for ( int kk = 0; kk < output_width; ++kk )
+ if ( prediction(jj, kk) == ground_truth(x[ii], jj, kk) )
+ ++num_correct;
+ }
+
+ // First some sanity checks.
+ const int num_correct_max = num_total_elements;
+ DLIB_TEST(num_correct_max == ::std::accumulate(truth_histogram.begin(), truth_histogram.end(), 0));
+ DLIB_TEST_MSG(num_correct <= num_correct_max,
+ "Number of correctly classified elements = " << num_correct << ", max = " << num_correct_max);
+
+ // This is the real test, verifying that we have actually learned something.
+ const int num_correct_required = static_cast<int>(::std::ceil(0.9 * num_correct_max));
+ DLIB_TEST_MSG(num_correct >= num_correct_required,
+ "Number of correctly classified elements = " << num_correct << ", required = " << num_correct_required);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multiclass_per_pixel_weighted()
+ {
+ // Train with pixel-specific weights
+
+ print_spinner();
+
+ constexpr int input_height = 5;
+ constexpr int input_width = 7;
+ constexpr int output_height = input_height;
+ constexpr int output_width = input_width;
+ const int num_samples = 1000;
+ const int num_classes = 6;
+
+ ::std::default_random_engine generator(16);
+ ::std::uniform_real_distribution<double> u01(0.0, 1.0);
+ ::std::uniform_int_distribution<uint16_t> noisy_label(0, num_classes - 1);
+
+ ::std::vector<matrix<double>> x(num_samples);
+ ::std::vector<matrix<uint16_t>> y(num_samples);
+
+ matrix<double> xtmp(input_height, input_width);
+ matrix<uint16_t> ytmp(output_height, output_width);
+
+ // Generate input data
+ for (int ii = 0; ii < num_samples; ++ii) {
+ for (int jj = 0; jj < input_height; ++jj) {
+ for (int kk = 0; kk < input_width; ++kk) {
+ xtmp(jj, kk) = u01(generator);
+ ytmp(jj, kk) = noisy_label(generator);
+ }
+ }
+ x[ii] = xtmp;
+ y[ii] = ytmp;
+ }
+
+ using net_type = loss_multiclass_log_per_pixel_weighted<con<num_classes,1,1,1,1,input<matrix<double>>>>;
+ using weighted_label = loss_multiclass_log_per_pixel_weighted_::weighted_label;
+
+ ::std::vector<matrix<weighted_label>> y_weighted(num_samples);
+
+ for (int weighted_class = 0; weighted_class < num_classes; ++weighted_class) {
+
+ print_spinner();
+
+ // Assign weights
+ for (int ii = 0; ii < num_samples; ++ii) {
+ if (weighted_class == 0) {
+ y_weighted[ii].set_size(input_height, input_width);
+ }
+ for (int jj = 0; jj < input_height; ++jj) {
+ for (int kk = 0; kk < input_width; ++kk) {
+ const uint16_t label = y[ii](jj, kk);
+ const float weight
+ = label == weighted_class
+ ? 1.1f
+ : 0.9f;
+ y_weighted[ii](jj, kk) = weighted_label(label, weight);
+ }
+ }
+ }
+
+ net_type net;
+ sgd defsolver(0,0.9);
+ dnn_trainer<net_type> trainer(net, defsolver);
+ trainer.set_learning_rate(0.1);
+ trainer.set_min_learning_rate(0.01);
+ trainer.set_mini_batch_size(10);
+ trainer.set_max_num_epochs(10);
+ trainer.train(x, y_weighted);
+
+ const ::std::vector<matrix<uint16_t>> predictions = net(x);
+
+ int num_weighted_class = 0;
+ int num_not_weighted_class = 0;
+
+ for ( int ii = 0; ii < num_samples; ++ii ) {
+ const matrix<uint16_t>& prediction = predictions[ii];
+ DLIB_TEST(prediction.nr() == output_height);
+ DLIB_TEST(prediction.nc() == output_width);
+ for ( int jj = 0; jj < output_height; ++jj )
+ for ( int kk = 0; kk < output_width; ++kk )
+ if ( prediction(jj, kk) == weighted_class )
+ ++num_weighted_class;
+ else
+ ++num_not_weighted_class;
+ }
+
+ DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class,
+ "The weighted class (" << weighted_class << ") does not dominate: "
+ << num_weighted_class << " <= " << num_not_weighted_class);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_tensor_resize_bilinear(long samps, long k, long nr, long nc, long onr, long onc)
+ {
+ resizable_tensor img(samps,k,nr,nc);
+ resizable_tensor out(samps,k,onr,onc);
+ resizable_tensor out2(samps,k,onr,onc);
+
+ dlib::rand rnd;
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ print_spinner();
+
+ const size_t idx = rnd.get_random_64bit_number()%img.size();
+
+ img = 1;
+ img.host()[idx] = 2;
+ cpu::resize_bilinear(out, img);
+#ifdef DLIB_USE_CUDA
+ cuda::resize_bilinear(out2, img);
+ DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5);
+#endif
+
+ resizable_tensor gradient_input;
+ gradient_input.copy_size(out);
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(gradient_input);
+
+ const float h = 1e-2;
+
+ img.host()[idx] = 2;
+ cpu::resize_bilinear(out, img);
+ float f1 = dot(out, gradient_input);
+
+ img.host()[idx] = 2+h;
+ cpu::resize_bilinear(out, img);
+ float f2 = dot(out, gradient_input);
+
+ const float numerical_grad = (f2-f1)/h;
+ dlog << LINFO << "numerical grad: " << numerical_grad;
+
+
+ resizable_tensor grad, grad2;
+ grad.copy_size(img);
+ grad = 0.1;
+ grad2.copy_size(img);
+ grad2 = 0.1;
+
+ cpu::resize_bilinear_gradient(grad2, gradient_input);
+ dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1;
+ DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+
+#ifdef DLIB_USE_CUDA
+ cuda::resize_bilinear_gradient(grad, gradient_input);
+ dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1;
+ DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+ DLIB_TEST(max(abs(mat(grad)-mat(grad2))) < 1e-5);
+#endif
+
+ }
+
+
+ // now test with strided/sub-window calls
+ alias_tensor aimg(samps, k, nr-2,nc-2);
+ alias_tensor aout(samps, k, onr-2,onc-2);
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ print_spinner();
+
+ const size_t idx = rnd.get_random_64bit_number()%img.size();
+
+ img = 1;
+ img.host()[idx] = 2;
+ out = 9;
+ out2 = 9;
+ auto wout = aout(out, out.nc()*1+1);
+ auto wimg = aimg(img, img.nc()*1+1);
+ cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+#ifdef DLIB_USE_CUDA
+ auto wout2 = aout(out2, out2.nc()*1+1);
+ cuda::resize_bilinear(wout2,out2.nc(),out2.nr()*out2.nc(), wimg,img.nc(),img.nr()*img.nc());
+ DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5);
+#endif
+
+
+ resizable_tensor gradient_input;
+ gradient_input.copy_size(out);
+ tt::tensor_rand rnd;
+ rnd.fill_uniform(gradient_input);
+
+ const float h = 1e-2;
+
+ img.host()[idx] = 2;
+ out = 0;
+ wout = aout(out, out.nc()*1+1);
+ wimg = aimg(img, img.nc()*1+1);
+ cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+ float f1 = dot(out, gradient_input);
+
+ img.host()[idx] = 2+h;
+ out = 0;
+ cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc());
+ float f2 = dot(out, gradient_input);
+
+ const float numerical_grad = (f2-f1)/h;
+ dlog << LINFO << "numerical grad: " << numerical_grad;
+
+
+ resizable_tensor grad, grad2;
+ grad.copy_size(img);
+ grad = 0.1;
+ grad2.copy_size(img);
+ grad2 = 0.1;
+
+ auto wgrad2 = aimg(grad2, grad2.nc()*1+1);
+ auto wgradient_input = aout(gradient_input, gradient_input.nc()*1+1);
+ cpu::resize_bilinear_gradient(wgrad2,grad2.nc(),grad2.nr()*grad2.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc());
+ dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1;
+ DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+
+#ifdef DLIB_USE_CUDA
+ wgrad2 = aimg(grad, grad.nc()*1+1);
+ wgradient_input = aout(gradient_input, gradient_input.nc()*1+1);
+ cuda::resize_bilinear_gradient(wgrad2,grad.nc(),grad.nr()*grad.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc());
+ dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1;
+ DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
+ DLIB_TEST_MSG(max(abs(mat(grad)-mat(grad2))) < 1e-5, max(abs(mat(grad)-mat(grad2))));
+#endif
+
+
+ }
+ }
+
+
+ void test_serialization()
+ {
+ print_spinner();
+
+ using net_type = loss_mean_squared<fc<1, input<matrix<double>>>>;
+ net_type net, net2;
+
+ std::ostringstream out;
+ serialize(net, out);
+ const std::string serialized = out.str();
+ std::istringstream in(serialized);
+ dlib::deserialize(net2, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_dot()
+ {
+ print_spinner();
+
+ std::vector<matrix<float,0,1>> samples;
+ std::vector<matrix<float,0,1>> labels;
+
+ const matrix<float> proj = matrix_cast<float>(randm(2,3));
+ for (int i = 0; i < 128; ++i)
+ {
+ // The task is going to be to learn the matrix proj. So we make our
+ // training data thusly:
+ matrix<float,0,1> x = matrix_cast<float>(randm(3,1));
+ matrix<float,0,1> y = normalize(proj*x);
+ samples.push_back(x);
+ labels.push_back(y);
+ }
+
+ using net_type = loss_dot<
+ l2normalize<fc_no_bias<2,
+ input<matrix<float,0,1>>
+ >>>;
+
+ net_type net;
+ dnn_trainer<net_type> trainer(net, sgd(1e-4, 0.9));
+ trainer.set_learning_rate(0.01);
+ trainer.set_min_learning_rate(0.0000001);
+ trainer.set_mini_batch_size(128);
+ trainer.set_max_num_epochs(50000);
+ trainer.train(samples, labels);
+
+
+ for (size_t i = 0; i < samples.size(); ++i)
+ {
+ DLIB_TEST(std::abs(1-dot(net(samples[i]),labels[i])) < 0.001);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_loss_multimulticlass_log()
+ {
+ print_spinner();
+ std::map<string,std::vector<string>> all_labels;
+ all_labels["c1"] = {"a", "b", "c"};
+ all_labels["c2"] = {"d", "e", "f"};
+
+ // make training data
+ std::vector<matrix<float>> samples;
+ std::vector<std::map<string,string>> labels;
+ for (int i = 0; i < 3; ++i)
+ {
+ for (int j = 0; j < 3; ++j)
+ {
+ matrix<float> samp(2,3);
+ samp = 0;
+ samp(0,i) = 1;
+ samp(1,j) = 1;
+ samples.push_back(samp);
+
+ std::map<string,string> l;
+ if (i == 0) l["c1"] = "a";
+ if (i == 1) l["c1"] = "b";
+ if (i == 2) l["c1"] = "c";
+ if (j == 0) l["c2"] = "d";
+ if (j == 1) l["c2"] = "e";
+ if (j == 2) l["c2"] = "f";
+ labels.push_back(l);
+ }
+ }
+
+ using net_type = loss_multimulticlass_log<
+ fc<1,
+ input<matrix<float>>
+ >>;
+
+ net_type net(all_labels);
+ net.subnet().layer_details().set_num_outputs(net.loss_details().number_of_labels());
+
+ dnn_trainer<net_type> trainer(net, sgd(0.1));
+ trainer.set_learning_rate(0.1);
+ trainer.set_min_learning_rate(0.00001);
+ trainer.set_iterations_without_progress_threshold(500);
+
+ trainer.train(samples, labels);
+
+ auto predicted_labels = net(samples);
+
+ // make sure the network predicts the right labels
+ for (size_t i = 0; i < samples.size(); ++i)
+ {
+ DLIB_TEST(predicted_labels[i]["c1"] == labels[i]["c1"]);
+ DLIB_TEST(predicted_labels[i]["c2"] == labels[i]["c2"]);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class dnn_tester : public tester
+ {
+ public:
+ dnn_tester (
+ ) :
+ tester ("test_dnn",
+ "Runs tests on the deep neural network tools.")
+ {}
+
+ void run_tests (
+ )
+ {
+ // make the tests repeatable
+ srand(1234);
+
+ test_tagging();
+#ifdef DLIB_USE_CUDA
+ test_affine_rect();
+ test_conv();
+ test_more_ops2();
+ test_more_ops(1,1);
+ test_more_ops(3,4);
+ test_more_ops(4,3);
+ test_more_ops(4,1);
+ test_more_ops(1,4);
+ test_more_ops(10000,4);
+ compare_bn_gpu_and_cpu();
+ compare_bn_conv_gpu_and_cpu();
+ test_add();
+ test_multiply_zero_padded();
+ compare_adam();
+ test_copy_tensor_gpu();
+ test_copy_tensor_add_to_gpu();
+ test_scale_channels();
+#endif
+ test_tensor_resize_bilinear(2, 3, 6,6, 11, 11);
+ test_tensor_resize_bilinear(2, 3, 6,6, 3, 4);
+ test_tensor_resize_bilinear(2, 3, 5,6, 12, 21);
+ test_max_pool(1,1,2,3,0,0);
+ test_max_pool(3,3,1,1,0,0);
+ test_max_pool(3,3,2,2,0,0);
+ test_max_pool(2,2,2,2,0,0);
+ test_max_pool(4,5,3,1,0,0);
+ test_avg_pool(1,1,2,3,0,0);
+ test_avg_pool(3,3,1,1,0,0);
+ test_avg_pool(3,3,2,2,0,0);
+ test_avg_pool(2,2,2,2,0,0);
+ test_avg_pool(4,5,3,1,0,0);
+ test_avg_pool(4,4,2,2,0,0);
+ test_avg_pool(4,5,40,50,0,0);
+ test_max_pool(2,2,2,3,1,1);
+ test_max_pool(3,3,1,1,1,1);
+ test_max_pool(3,3,2,2,2,1);
+ test_max_pool(2,2,2,2,1,0);
+ test_max_pool(4,5,3,1,2,3);
+ test_avg_pool(1,1,2,3,0,0);
+ test_avg_pool(3,3,1,1,1,2);
+ test_avg_pool(3,3,2,2,2,1);
+ test_avg_pool(2,2,2,2,1,0);
+ test_avg_pool(4,5,3,1,2,4);
+ test_avg_pool(4,4,2,2,1,3);
+ test_avg_pool(4,5,40,50,0,1);
+ test_tanh();
+ test_softmax();
+ test_softmax_all();
+ test_sigmoid();
+ test_batch_normalize();
+ test_batch_normalize_conv();
+ test_basic_tensor_ops();
+ test_layers();
+ test_visit_funcions();
+ test_copy_tensor_cpu();
+ test_copy_tensor_add_to_cpu();
+ test_concat();
+ test_simple_linear_regression();
+ test_simple_linear_regression_eil();
+ test_simple_linear_regression_with_mult_prev();
+ test_multioutput_linear_regression();
+ test_simple_autoencoder();
+ test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task();
+ test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task();
+ test_loss_multiclass_per_pixel_outputs_on_trivial_task();
+ test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore();
+ test_loss_multiclass_per_pixel_weighted();
+ test_serialization();
+ test_loss_dot();
+ test_loss_multimulticlass_log();
+ }
+
+ void perform_test()
+ {
+ dlog << LINFO << "NOW RUNNING TESTS WITH set_dnn_prefer_fastest_algorithms()";
+ set_dnn_prefer_fastest_algorithms();
+ run_tests();
+
+ dlog << LINFO << "NOW RUNNING TESTS WITH set_dnn_prefer_smallest_algorithms()";
+ set_dnn_prefer_smallest_algorithms();
+ run_tests();
+ }
+ } a;
+}
+
+#endif // __INTELLISENSE__
+
diff --git a/ml/dlib/dlib/test/ekm_and_lisf.cpp b/ml/dlib/dlib/test/ekm_and_lisf.cpp
new file mode 100644
index 000000000..b6f410177
--- /dev/null
+++ b/ml/dlib/dlib/test/ekm_and_lisf.cpp
@@ -0,0 +1,306 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.ekm_and_lisf");
+
+
+ class empirical_kernel_map_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ empirical_kernel_map_tester (
+ ) :
+ tester (
+ "test_ekm_and_lisf", // the command line argument name for this test
+ "Run tests on the empirical_kernel_map and linearly_independent_subset_finder objects.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ thetime = time(0);
+ }
+
+ time_t thetime;
+ dlib::rand rnd;
+
+ template <typename T>
+ void validate (
+ const T& ekm_small,
+ const T& ekm_big
+ )
+ {
+ matrix<double> tmat;
+ projection_function<typename T::kernel_type> proj;
+
+ ekm_small.get_transformation_to(ekm_big, tmat, proj);
+ DLIB_TEST(tmat.nr() == ekm_big.out_vector_size());
+ DLIB_TEST(tmat.nc() == ekm_small.out_vector_size());
+ DLIB_TEST((unsigned long)proj.basis_vectors.size() == ekm_big.basis_size() - ekm_small.basis_size());
+ for (unsigned long i = 0; i < 6; ++i)
+ {
+ const typename T::sample_type temp = randm(4,1,rnd);
+ DLIB_TEST(length(ekm_big.project(temp) - (tmat*ekm_small.project(temp) + proj(temp))) < 1e-10);
+ }
+ }
+
+
+ void test_transformation_stuff()
+ {
+ typedef matrix<double,0,1> sample_type;
+ typedef radial_basis_kernel<sample_type> kernel_type;
+ const kernel_type kern(1);
+
+
+ for (unsigned long n = 1; n < 6; ++n)
+ {
+ print_spinner();
+ for (unsigned long extra = 1; extra < 10; ++extra)
+ {
+ std::vector<sample_type> samps_small, samps_big;
+ linearly_independent_subset_finder<kernel_type> lisf_small(kern, 1000);
+ linearly_independent_subset_finder<kernel_type> lisf_big(kern, 1000);
+ for (unsigned long i = 0; i < n; ++i)
+ {
+ samps_small.push_back(randm(4,1,rnd));
+ samps_big.push_back(samps_small.back());
+ lisf_big.add(samps_small.back());
+ lisf_small.add(samps_small.back());
+ }
+ for (unsigned long i = 0; i < extra; ++i)
+ {
+ samps_big.push_back(randm(4,1,rnd));
+ lisf_big.add(samps_big.back());
+ }
+
+
+ // test no lisf
+ {
+ empirical_kernel_map<kernel_type> ekm_small, ekm_big;
+ ekm_small.load(kern, samps_small);
+ ekm_big.load(kern, samps_big);
+
+ validate(ekm_small, ekm_big);
+ }
+
+ // test with lisf
+ {
+ empirical_kernel_map<kernel_type> ekm_small, ekm_big;
+ ekm_small.load(lisf_small);
+ ekm_big.load(lisf_big);
+
+ validate(ekm_small, ekm_big);
+ }
+
+ // test with partly lisf
+ {
+ empirical_kernel_map<kernel_type> ekm_small, ekm_big;
+ ekm_small.load(kern, samps_small);
+ ekm_big.load(lisf_big);
+
+ validate(ekm_small, ekm_big);
+ }
+
+ // test with partly lisf
+ {
+ empirical_kernel_map<kernel_type> ekm_small, ekm_big;
+ ekm_small.load(lisf_small);
+ ekm_big.load(kern, samps_big);
+
+ validate(ekm_small, ekm_big);
+ }
+
+ }
+ }
+
+
+ // test what happens if the bigger ekm only has repeated basis vectors
+ {
+ empirical_kernel_map<kernel_type> ekm_big, ekm_small;
+ std::vector<sample_type> samps_big, samps_small;
+
+ sample_type temp = randm(4,1,rnd);
+
+ samps_small.push_back(temp);
+ samps_big.push_back(temp);
+ samps_big.push_back(temp);
+
+ ekm_big.load(kern, samps_big);
+ ekm_small.load(kern, samps_small);
+
+ validate(ekm_small, ekm_big);
+
+ }
+ {
+ empirical_kernel_map<kernel_type> ekm_big, ekm_small;
+ linearly_independent_subset_finder<kernel_type> lisf_small(kern, 1000);
+ std::vector<sample_type> samps_big;
+
+ sample_type temp = randm(4,1,rnd);
+
+ lisf_small.add(temp);
+ samps_big.push_back(temp);
+ samps_big.push_back(temp);
+
+ ekm_big.load(kern, samps_big);
+ ekm_small.load(lisf_small);
+
+ validate(ekm_small, ekm_big);
+
+ }
+ {
+ empirical_kernel_map<kernel_type> ekm_big, ekm_small;
+ std::vector<sample_type> samps_big, samps_small;
+
+ sample_type temp = randm(4,1,rnd);
+ sample_type temp2 = randm(4,1,rnd);
+
+ samps_small.push_back(temp);
+ samps_small.push_back(temp2);
+ samps_big.push_back(temp);
+ samps_big.push_back(temp2);
+ samps_big.push_back(randm(4,1,rnd));
+
+ ekm_big.load(kern, samps_big);
+ ekm_small.load(kern, samps_small);
+
+ validate(ekm_small, ekm_big);
+
+ }
+ {
+ empirical_kernel_map<kernel_type> ekm_big, ekm_small;
+ linearly_independent_subset_finder<kernel_type> lisf_small(kern, 1000);
+ std::vector<sample_type> samps_big;
+
+ sample_type temp = randm(4,1,rnd);
+ sample_type temp2 = randm(4,1,rnd);
+
+ lisf_small.add(temp);
+ lisf_small.add(temp2);
+ samps_big.push_back(temp);
+ samps_big.push_back(temp2);
+ samps_big.push_back(temp);
+
+ ekm_big.load(kern, samps_big);
+ ekm_small.load(lisf_small);
+
+ validate(ekm_small, ekm_big);
+
+ }
+
+
+ }
+
+
+
+ void perform_test (
+ )
+ {
+ ++thetime;
+ typedef matrix<double,0,1> sample_type;
+ //dlog << LINFO << "time seed: " << thetime;
+ //rnd.set_seed(cast_to_string(thetime));
+
+
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+
+ for (int n = 1; n < 10; ++n)
+ {
+ print_spinner();
+ dlog << LINFO << "matrix size " << n;
+
+ std::vector<sample_type> samples;
+ // make some samples
+ for (int i = 0; i < n; ++i)
+ {
+ samples.push_back(randm(4,1,rnd));
+ // double up the samples just to mess with the lisf
+ if (n > 5)
+ samples.push_back(samples.back());
+ }
+
+ dlog << LINFO << "samples.size(): "<< samples.size();
+
+ const kernel_type kern(1);
+
+ linearly_independent_subset_finder<kernel_type> lisf(kern, 100, 1e-4);
+ unsigned long count = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (lisf.add(samples[i]))
+ {
+ DLIB_TEST(equal(lisf[lisf.size()-1], samples[i]));
+ ++count;
+ }
+ }
+ DLIB_TEST(count == lisf.size());
+
+ DLIB_TEST(lisf.size() == (unsigned int)n);
+
+
+ dlog << LINFO << "lisf.size(): "<< lisf.size();
+
+ // make sure the kernel matrices coming out of the lisf are correct
+ DLIB_TEST(dlib::equal(lisf.get_kernel_matrix(), kernel_matrix(kern, lisf), 1e-8));
+ DLIB_TEST(dlib::equal(lisf.get_inv_kernel_marix(), inv(kernel_matrix(kern, lisf.get_dictionary())), 1e-8));
+
+ empirical_kernel_map<kernel_type> ekm;
+ ekm.load(lisf);
+ DLIB_TEST(ekm.basis_size() == lisf.size());
+
+ std::vector<sample_type> proj_samples;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ double err;
+ proj_samples.push_back(ekm.project(samples[i], err));
+ DLIB_TEST(err <= 1e-4);
+ const double error_agreement = std::abs(err - lisf.projection_error(samples[i]));
+ dlog << LTRACE << "err: " << err << " error_agreement: "<< error_agreement;
+ DLIB_TEST(error_agreement < 1e-11);
+ }
+
+ for (int i = 0; i < 5; ++i)
+ {
+ sample_type temp = randm(4,1,rnd);
+ double err;
+ ekm.project(temp, err);
+ const double error_agreement = std::abs(err - lisf.projection_error(temp));
+ dlog << LTRACE << "err: " << err << " error_agreement: "<< error_agreement;
+ DLIB_TEST(error_agreement < 1e-11);
+ }
+
+ // make sure the EKM did the projection correctly
+ DLIB_TEST(dlib::equal(kernel_matrix(kern, samples), kernel_matrix(linear_kernel<sample_type>(), proj_samples), 1e-5));
+ }
+
+
+ test_transformation_stuff();
+
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ empirical_kernel_map_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/elastic_net.cpp b/ml/dlib/dlib/test/elastic_net.cpp
new file mode 100644
index 000000000..0e0501639
--- /dev/null
+++ b/ml/dlib/dlib/test/elastic_net.cpp
@@ -0,0 +1,122 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <dlib/optimization/elastic_net.h>
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.elastic_net");
+
+// ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> basic_elastic_net(
+ const matrix<double>& X,
+ const matrix<double,0,1>& Y,
+ double ridge_lambda,
+ double lasso_budget,
+ double eps
+ )
+ {
+ DLIB_CASSERT(X.nc() == Y.nr(),"");
+
+
+ typedef matrix<double,0,1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+ svm_c_linear_dcd_trainer<kernel_type> trainer;
+ trainer.solve_svm_l2_problem(true);
+ const double C = 1/(2*ridge_lambda);
+ trainer.set_c(C);
+ trainer.set_epsilon(eps);
+ trainer.enable_shrinking(true);
+ trainer.include_bias(false);
+
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+ for (long r = 0; r < X.nr(); ++r)
+ {
+ sample_type temp = trans(rowm(X,r));
+
+ const double xmul = (1/lasso_budget);
+ samples.push_back(temp - xmul*Y);
+ labels.push_back(+1);
+ samples.push_back(temp + xmul*Y);
+ labels.push_back(-1);
+ }
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+ auto df = trainer.train(samples, labels, state);
+ auto&& alpha = state.get_alpha();
+
+ matrix<double,0,1> betas(alpha.size()/2);
+ for (long i = 0; i < betas.size(); ++i)
+ betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]);
+ betas /= sum(mat(alpha));
+ return betas;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_elastic_net : public tester
+ {
+ public:
+ test_elastic_net (
+ ) :
+ tester (
+ "test_elastic_net",
+ "Run tests on the elastic_net object.",
+ 0
+ )
+ {
+ }
+
+ void perform_test (
+ )
+ {
+ matrix<double> w = {1,2,0,4, 0,0,0,0,0, 6, 7,8,0, 9, 0};
+
+ matrix<double> X = randm(w.size(),1000);
+ matrix<double> Y = trans(X)*w;
+ Y += 0.1*(randm(Y.nr(), Y.nc())-0.5);
+
+
+ double ridge_lambda = 0.1;
+ double lasso_budget = sum(abs(w));
+ double eps = 0.0000001;
+
+ dlib::elastic_net solver(X*trans(X),X*Y);
+ solver.set_epsilon(eps);
+
+
+ matrix<double,0,1> results;
+ matrix<double,0,1> results2;
+ for (double s = 1.2; s > 0.10; s *= 0.9)
+ {
+ print_spinner();
+ dlog << LINFO << "s: "<< s;
+ // make sure the two solvers agree.
+ results = basic_elastic_net(X, Y, ridge_lambda, lasso_budget*s, eps);
+ results2 = solver(ridge_lambda, lasso_budget*s);
+ dlog << LINFO << "error: "<< max(abs(results - results2));
+ DLIB_TEST(max(abs(results - results2)) < 1e-3);
+ }
+ }
+ } a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/empirical_kernel_map.cpp b/ml/dlib/dlib/test/empirical_kernel_map.cpp
new file mode 100644
index 000000000..95b085ab3
--- /dev/null
+++ b/ml/dlib/dlib/test/empirical_kernel_map.cpp
@@ -0,0 +1,444 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.empirical_kernel_map");
+
+
+ class empirical_kernel_map_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ empirical_kernel_map_tester (
+ ) :
+ tester (
+ "test_empirical_kernel_map", // the command line argument name for this test
+ "Run tests on the empirical_kernel_map object.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ // always use the same time so that tests are repeatable
+ thetime = 0;//time(0);
+ }
+
+ time_t thetime;
+ dlib::rand rnd;
+
+ void test_projection_error()
+ {
+ for (int runs = 0; runs < 10; ++runs)
+ {
+ print_spinner();
+ typedef matrix<double,0,1> sample_type;
+ typedef radial_basis_kernel<sample_type> kernel_type;
+ const kernel_type kern(0.2);
+
+ empirical_kernel_map<kernel_type> ekm;
+
+ // generate samples
+ const int num = rnd.get_random_8bit_number()%50 + 1;
+ std::vector<sample_type> samples;
+ for (int i = 0; i < num; ++i)
+ {
+ samples.push_back(randm(5,1,rnd));
+ }
+
+
+ ekm.load(kern, samples);
+ DLIB_TEST(ekm.basis_size() == samples.size());
+
+ double err;
+
+ // the samples in the basis should have zero projection error
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ ekm.project(samples[i], err);
+ DLIB_TEST_MSG(abs(err) < 1e-13, abs(err));
+
+ }
+
+ // Do some sanity tests on the conversion to distance functions while we are at it.
+ for (int i = 0; i < 30; ++i)
+ {
+ // pick two random samples
+ const sample_type samp1 = samples[rnd.get_random_32bit_number()%samples.size()];
+ const sample_type samp2 = samples[rnd.get_random_32bit_number()%samples.size()];
+
+ const matrix<double,0,1> proj1 = ekm.project(samp1);
+ const matrix<double,0,1> proj2 = ekm.project(samp2);
+
+ distance_function<kernel_type> df1 = ekm.convert_to_distance_function(proj1);
+ distance_function<kernel_type> df2 = ekm.convert_to_distance_function(proj2);
+
+ DLIB_TEST(df1.get_kernel() == kern);
+ DLIB_TEST(df2.get_kernel() == kern);
+
+ // make sure the norms are correct
+ DLIB_TEST(std::abs(df1.get_squared_norm() -
+ trans(df1.get_alpha())*kernel_matrix(df1.get_kernel(),df1.get_basis_vectors())*df1.get_alpha()) < 1e-10);
+ DLIB_TEST(std::abs(df2.get_squared_norm() -
+ trans(df2.get_alpha())*kernel_matrix(df2.get_kernel(),df2.get_basis_vectors())*df2.get_alpha()) < 1e-10);
+
+
+ const double true_dist = std::sqrt(kern(samp1,samp1) + kern(samp2,samp2) - 2*kern(samp1,samp2));
+ DLIB_TEST_MSG(abs(df1(df2) - true_dist) < 1e-7, abs(df1(df2) - true_dist));
+ DLIB_TEST_MSG(abs(length(proj1-proj2) - true_dist) < 1e-7, abs(length(proj1-proj2) - true_dist));
+
+
+ // test distance function operators
+ const decision_function<kernel_type> dec1 = ekm.convert_to_decision_function(proj1);
+ const decision_function<kernel_type> dec2 = ekm.convert_to_decision_function(proj2);
+ DLIB_TEST(dec1.kernel_function == kern);
+ DLIB_TEST(dec2.kernel_function == kern);
+
+ distance_function<kernel_type> temp;
+ temp = dec1;
+ DLIB_TEST(std::abs(temp.get_squared_norm() - df1.get_squared_norm()) < 1e-10);
+ temp = dec2;
+ DLIB_TEST(std::abs(temp.get_squared_norm() - df2.get_squared_norm()) < 1e-10);
+ temp = distance_function<kernel_type>(dec1.alpha, dec1.kernel_function, dec1.basis_vectors);
+ DLIB_TEST(std::abs(temp.get_squared_norm() - df1.get_squared_norm()) < 1e-10);
+
+ df1 = dec1;
+
+ temp = df1 + df2;
+ decision_function<kernel_type> dec3(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors());
+ DLIB_TEST(std::abs(temp.get_squared_norm() -
+ trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10);
+ for (unsigned long j = 0; j < samples.size(); ++j)
+ {
+ DLIB_TEST(std::abs(dec3(samples[j]) - (dec1(samples[j]) + dec2(samples[j]))) < 1e-10);
+ }
+
+
+ temp = df1 - df2;
+ dec3 = decision_function<kernel_type>(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors());
+ DLIB_TEST(std::abs(temp.get_squared_norm() -
+ trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10);
+ for (unsigned long j = 0; j < samples.size(); ++j)
+ {
+ DLIB_TEST(std::abs(dec3(samples[j]) - (dec1(samples[j]) - dec2(samples[j]))) < 1e-10);
+ }
+
+ temp = 3*(df1 - df2)*2;
+ dec3 = decision_function<kernel_type>(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors());
+ DLIB_TEST(std::abs(temp.get_squared_norm() -
+ trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10);
+ for (unsigned long j = 0; j < samples.size(); ++j)
+ {
+ DLIB_TEST(std::abs(dec3(samples[j]) - 6*(dec1(samples[j]) - dec2(samples[j]))) < 1e-10);
+ }
+
+ distance_function<kernel_type> df_empty(kern);
+
+ temp = df_empty + (df1 + df2)/2 + df_empty - df_empty + (df_empty + df_empty) - (df_empty - df_empty);
+ dec3 = decision_function<kernel_type>(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors());
+ DLIB_TEST(std::abs(temp.get_squared_norm() -
+ trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10);
+ for (unsigned long j = 0; j < samples.size(); ++j)
+ {
+ DLIB_TEST(std::abs(dec3(samples[j]) - 0.5*(dec1(samples[j]) + dec2(samples[j]))) < 1e-10);
+ }
+ }
+ // Do some sanity tests on the conversion to distance functions while we are at it. This
+ // time multiply one of the projections by 30 and see that it still all works out right.
+ for (int i = 0; i < 30; ++i)
+ {
+ // pick two random samples
+ const sample_type samp1 = samples[rnd.get_random_32bit_number()%samples.size()];
+ const sample_type samp2 = samples[rnd.get_random_32bit_number()%samples.size()];
+
+ matrix<double,0,1> proj1 = ekm.project(samp1);
+ matrix<double,0,1> proj2 = 30*ekm.project(samp2);
+
+ distance_function<kernel_type> df1 = ekm.convert_to_distance_function(proj1);
+ distance_function<kernel_type> df2 = ekm.convert_to_distance_function(proj2);
+
+ DLIB_TEST_MSG(abs(length(proj1-proj2) - df1(df2)) < 1e-7, abs(length(proj1-proj2) - df1(df2)));
+ }
+
+
+ // now generate points with projection error
+ for (double i = 1; i < 10; ++i)
+ {
+ sample_type test_point = i*randm(5,1,rnd);
+ ekm.project(test_point, err);
+ // turn into normal distance rather than squared distance
+ err = sqrt(err);
+ dlog << LTRACE << "projection error: " << err;
+
+ distance_function<kernel_type> df = ekm.convert_to_distance_function(ekm.project(test_point));
+
+ // the projection error should be the distance between the test_point and the point it gets
+ // projected onto
+ DLIB_TEST_MSG(abs(df(test_point) - err) < 1e-10, abs(df(test_point) - err));
+ // while we are at it make sure the squared norm in the distance function is right
+ double df_error = abs(df.get_squared_norm() - trans(df.get_alpha())*kernel_matrix(kern, samples)*df.get_alpha());
+ DLIB_TEST_MSG( df_error < 1e-10, df_error);
+ }
+
+
+
+ }
+ }
+
+ template <typename kernel_type>
+ void test_with_kernel(const kernel_type& kern)
+ {
+ typedef typename kernel_type::sample_type sample_type;
+
+ empirical_kernel_map<kernel_type> ekm, ekm2, ekm3;
+
+ for (int j = 0; j < 10; ++j)
+ {
+ sample_type samp;
+ std::vector<sample_type> samples;
+ std::vector<sample_type> proj_samples;
+ print_spinner();
+ const int num = rnd.get_random_8bit_number()%200 + 1;
+ // make some random samples
+ for (int i = 0; i < num; ++i)
+ {
+ samples.push_back(randm(4,1,rnd));
+ }
+ // add on a little bit to make sure there is at least one non-zero sample. If all the
+ // samples are zero then empirical_kernel_map_error will be thrown and we don't want that.
+ samples.front()(0) += 0.001;
+
+ ekm2.load(kern, samples);
+ DLIB_TEST(ekm2.basis_size() == samples.size());
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ DLIB_TEST(dlib::equal(ekm2[i] , samples[i]));
+
+ // test serialization
+ ostringstream sout;
+ serialize(ekm2, sout);
+ ekm2.clear();
+ istringstream sin(sout.str());
+ deserialize(ekm3, sin);
+ // also test swap
+ ekm3.swap(ekm);
+ DLIB_TEST(ekm.get_kernel() == kern);
+ DLIB_TEST(ekm.out_vector_size() != 0);
+ DLIB_TEST(ekm2.out_vector_size() == 0);
+ DLIB_TEST(ekm3.out_vector_size() == 0);
+
+
+
+ // project all the samples into kernel space
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ proj_samples.push_back(ekm.project(samples[i]));
+ }
+
+ DLIB_TEST(max(abs(kernel_matrix(kern, samples) - kernel_matrix(linear_kernel<sample_type>(), proj_samples))) < 1e-12);
+ DLIB_TEST(ekm.out_vector_size() == proj_samples[0].size());
+
+ for (int i = 0; i < 30; ++i)
+ {
+ const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
+ const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
+ decision_function<kernel_type> dec_funct = ekm.convert_to_decision_function(proj_samples[idx1]);
+ distance_function<kernel_type> dist_funct = ekm.convert_to_distance_function(proj_samples[idx1]);
+
+ // make sure the distances match
+ const double dist_error = abs(length(proj_samples[idx1] - proj_samples[idx2]) - dist_funct(samples[idx2]));
+ DLIB_TEST_MSG( dist_error < 1e-6, dist_error);
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],proj_samples[idx2]) - dec_funct(samples[idx2])) < 1e-10);
+
+ // also try the dec_funct with samples that weren't in the original set
+ samp = 100*randm(4,1,rnd);
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],ekm.project(samp)) - dec_funct(samp)) < 1e-10);
+ samp = randm(4,1,rnd);
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],ekm.project(samp)) - dec_funct(samp)) < 1e-10);
+ }
+
+
+
+ proj_samples.clear();
+
+
+ // now do the projection but use the projection_function returned by get_projection_function()
+ projection_function<kernel_type> proj2 = ekm.get_projection_function();
+ projection_function<kernel_type> proj;
+ sout.clear();
+ sout.str("");
+ sin.clear();
+ sin.str("");
+ // test serialization
+ serialize(proj2, sout);
+ sin.str(sout.str());
+ deserialize(proj, sin);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ proj_samples.push_back(proj(samples[i]));
+ }
+
+ DLIB_TEST(max(abs(kernel_matrix(kern, samples) - kernel_matrix(linear_kernel<sample_type>(), proj_samples))) < 1e-12);
+ DLIB_TEST(ekm.out_vector_size() == proj_samples[0].size());
+ DLIB_TEST(proj.out_vector_size() == proj_samples[0].size());
+
+ ekm.clear();
+ DLIB_TEST(ekm.out_vector_size() == 0);
+ DLIB_TEST(ekm2.out_vector_size() == 0);
+ DLIB_TEST(ekm3.out_vector_size() == 0);
+
+
+ for (int i = 0; i < 30; ++i)
+ {
+ const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
+ const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
+ decision_function<kernel_type> dec_funct = convert_to_decision_function(proj,proj_samples[idx1]);
+
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],proj_samples[idx2]) - dec_funct(samples[idx2])) < 1e-10);
+
+ // also try the dec_funct with samples that weren't in the original set
+ samp = 100*randm(4,1,rnd);
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],proj(samp)) - dec_funct(samp)) < 1e-10);
+ samp = randm(4,1,rnd);
+ // make sure the dot products match
+ DLIB_TEST(abs(dot(proj_samples[idx1],proj(samp)) - dec_funct(samp)) < 1e-10);
+ }
+
+
+
+
+
+ }
+
+ for (int j = 1; j <= 20; ++j)
+ {
+ dlog << LTRACE << "j: " << j;
+ sample_type samp, samp2;
+ std::vector<sample_type> samples1;
+ std::vector<sample_type> samples2;
+ print_spinner();
+ // make some random samples. At the end samples1 will be a subset of samples2
+ for (int i = 0; i < 5*j; ++i)
+ {
+ samples1.push_back(randm(10,1,rnd));
+ samples2.push_back(samples1.back());
+ }
+ for (int i = 0; i < 5*j; ++i)
+ {
+ samples2.push_back(randm(10,1,rnd));
+ }
+ // add on a little bit to make sure there is at least one non-zero sample. If all the
+ // samples are zero then empirical_kernel_map_error will be thrown and we don't want that.
+ samples1.front()(0) += 0.001;
+ samples2.front()(0) += 0.001;
+
+ ekm.load(kern, samples1);
+ for (unsigned long i = 0; i < samples1.size(); ++i)
+ DLIB_TEST(dlib::equal(ekm[i] , samples1[i]));
+ DLIB_TEST(ekm.basis_size() == samples1.size());
+ ekm2.load(kern, samples2);
+ DLIB_TEST(ekm2.basis_size() == samples2.size());
+
+ dlog << LTRACE << "ekm.out_vector_size(): " << ekm.out_vector_size();
+ dlog << LTRACE << "ekm2.out_vector_size(): " << ekm2.out_vector_size();
+ const double eps = 1e-6;
+
+ matrix<double> transform;
+ // Make sure transformations back to yourself work right. Note that we can't just
+ // check that transform is the identity matrix since it might be an identity transform
+ // for only a subspace of vectors (this happens if the ekm maps points into a subspace of
+ // all possible ekm.out_vector_size() vectors).
+ transform = ekm.get_transformation_to(ekm);
+ DLIB_TEST(transform.nr() == ekm.out_vector_size());
+ DLIB_TEST(transform.nc() == ekm.out_vector_size());
+ for (unsigned long i = 0; i < samples1.size(); ++i)
+ {
+ samp = ekm.project(samples1[i]);
+ DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp));
+ samp = ekm.project((samples1[0] + samples1[i])/2);
+ DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp));
+ }
+
+ transform = ekm2.get_transformation_to(ekm2);
+ DLIB_TEST(transform.nr() == ekm2.out_vector_size());
+ DLIB_TEST(transform.nc() == ekm2.out_vector_size());
+ for (unsigned long i = 0; i < samples2.size(); ++i)
+ {
+ samp = ekm2.project(samples2[i]);
+ DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp));
+ samp = ekm2.project((samples2[0] + samples2[i])/2);
+ DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp));
+ //dlog << LTRACE << "mapping error: " << length(samp - transform*samp);
+ }
+
+
+ // now test the transform from ekm to ekm2
+ transform = ekm.get_transformation_to(ekm2);
+ DLIB_TEST(transform.nr() == ekm2.out_vector_size());
+ DLIB_TEST(transform.nc() == ekm.out_vector_size());
+ for (unsigned long i = 0; i < samples1.size(); ++i)
+ {
+ samp = ekm.project(samples1[i]);
+ distance_function<kernel_type> df1 = ekm.convert_to_distance_function(samp);
+ distance_function<kernel_type> df2 = ekm2.convert_to_distance_function(transform*samp);
+ DLIB_TEST_MSG(df1(df2) < eps, df1(df2));
+ //dlog << LTRACE << "mapping error: " << df1(df2);
+
+
+ samp = ekm.project((samples1[0] + samples1[i])/2);
+ df1 = ekm.convert_to_distance_function(samp);
+ df2 = ekm2.convert_to_distance_function(transform*samp);
+ DLIB_TEST_MSG(df1(df2) < eps, df1(df2));
+ }
+
+
+ }
+ }
+
+ void perform_test (
+ )
+ {
+ ++thetime;
+ typedef matrix<double,0,1> sample_type;
+ dlog << LINFO << "time seed: " << thetime;
+ rnd.set_seed(cast_to_string(thetime));
+
+ print_spinner();
+ test_projection_error();
+ print_spinner();
+ dlog << LINFO << "test with linear kernel";
+ test_with_kernel(linear_kernel<sample_type>());
+ print_spinner();
+ dlog << LINFO << "test with rbf kernel";
+ test_with_kernel(radial_basis_kernel<sample_type>(0.2));
+ print_spinner();
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ empirical_kernel_map_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/entropy_coder.cpp b/ml/dlib/dlib/test/entropy_coder.cpp
new file mode 100644
index 000000000..12a9a3305
--- /dev/null
+++ b/ml/dlib/dlib/test/entropy_coder.cpp
@@ -0,0 +1,587 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/entropy_encoder.h>
+#include <dlib/entropy_decoder.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.entropy_coder");
+
+ namespace entropy_coder_kernel_test_helpers
+ {
+ template <
+ typename encoder,
+ typename decoder
+ >
+ std::string test (
+ const std::string& input
+ )
+ /*!
+ ensures
+ - encodes the data from in and then tries to decode it and returns
+ "" if it was successfully decoded else it returns the decoded string
+ !*/
+ {
+ ostringstream sout;
+ istringstream sin;
+ istringstream in;
+
+
+ in.str(input);
+
+ const unsigned long max_total = 65535;
+
+
+
+
+ unsigned long counts[256];
+ for (int i = 0; i < 256; ++i)
+ {
+ counts[i] = 1;
+ }
+
+
+ encoder e;
+
+
+
+ DLIB_TEST(e.stream_is_set() == false);
+
+ e.set_stream(sout);
+
+ DLIB_TEST(e.stream_is_set() == true);
+ DLIB_TEST(&e.get_stream() == &sout);
+
+ unsigned char ch;
+
+ unsigned long total = 256;
+
+ while (in.read((char*)&ch,1))
+ {
+ if (total > max_total)
+ {
+ total = 0;
+ for (int j = 0; j<256; ++j)
+ {
+ counts[j] >>= 1;
+ if (counts[j] == 0)
+ counts[j] = 1;
+ total += counts[j];
+
+ }
+ }
+
+ unsigned long low_count = 0;
+ unsigned long high_count;
+ for (int i = 0; i < ch; ++i)
+ low_count += counts[i];
+ high_count = low_count + counts[ch];
+
+ e.encode(low_count,high_count,total);
+
+
+ ++total;
+ counts[ch] += 1;
+ }
+
+ DLIB_TEST(e.stream_is_set() == true);
+ DLIB_TEST(&e.get_stream() == &sout);
+
+
+
+ e.clear();
+
+ DLIB_TEST(e.stream_is_set() == false);
+
+
+ // *****************************************
+
+
+ decoder d;
+
+
+ DLIB_TEST(d.stream_is_set() == false);
+ DLIB_TEST(d.get_target_called() == false);
+
+ sin.str(sout.str());
+ sout.str("");
+
+ d.set_stream(sin);
+
+ DLIB_TEST(d.get_target_called() == false);
+
+ DLIB_TEST(d.stream_is_set() == true);
+ DLIB_TEST(&d.get_stream() == &sin);
+
+ for (int i = 0; i < 256; ++i)
+ {
+ counts[i] = 1;
+ }
+
+ total = 256;
+
+ for (string::size_type i = 0; i < input.size() ; ++i)
+ {
+ if (total > max_total)
+ {
+ total = 0;
+ for (int j = 0; j<256; ++j)
+ {
+ counts[j] >>= 1;
+ if (counts[j] == 0)
+ counts[j] = 1;
+ total += counts[j];
+
+ }
+ }
+
+ DLIB_TEST(d.get_target_called() == false);
+
+ unsigned long target = d.get_target(total);
+
+ DLIB_TEST(target < total);
+
+ DLIB_TEST(d.get_target_called() == true);
+
+
+ unsigned long low_count;
+ unsigned long high_count = 0;
+
+ unsigned long j;
+ for (j = 0; high_count <= target; ++j)
+ {
+ high_count += counts[j];
+ }
+ --j;
+ low_count = high_count - counts[j];
+
+
+ ch = static_cast<unsigned char>(j);
+
+
+ sout.rdbuf()->sputn((char*)&ch,1);
+
+
+
+ d.decode(low_count,high_count);
+ DLIB_TEST(d.get_target_called() == false);
+ ++total;
+ counts[ch] += 1;
+
+ }
+
+ DLIB_TEST(d.stream_is_set() == true);
+ DLIB_TEST(&d.get_stream() == &sin);
+
+ d.clear();
+
+ DLIB_TEST(d.stream_is_set() == false);
+ DLIB_TEST_MSG(sout.str().size() == input.size(),"the test script is buggy");
+
+
+ if (sout.str() == input)
+ return "";
+ else
+ return sout.str();
+
+ }
+
+ }
+
+
+
+
+ template <
+ typename encoder,
+ typename decoder
+ >
+ void entropy_coder_kernel_test (
+ )
+ /*!
+ requires
+ - encoder is an implementation of entropy_encoder/entropy_encoder_kernel_abstract.h
+ - decoder is an implementation of entropy_decoder/entropy_decoder_kernel_abstract.h
+ ensures
+ - runs tests on encoder and decoder for compliance with the specs
+ !*/
+ {
+ using namespace entropy_coder_kernel_test_helpers;
+
+ dlog << LTRACE << 1;
+
+ print_spinner();
+ string temp, temp2;
+
+ srand(static_cast<int>(time(0)));
+
+ for (int k = 0; k < 10000; ++k)
+ {
+ string temp;
+ istringstream sin;
+ ostringstream sout;
+ decoder d;
+ encoder e;
+
+ e.set_stream(sout);
+
+ int num = ::rand() %200;
+ unsigned long total[200];
+ unsigned long high_count[200];
+ unsigned long low_count[200];
+ for (int i = 0; i < num; ++i)
+ {
+ total[i] = ::rand()%256 + 20;
+ high_count[i] = ::rand()%total[i] + 1;
+ low_count[i] = ::rand()%high_count[i];
+
+ e.encode(low_count[i],high_count[i],total[i]);
+ }
+
+ e.clear();
+
+ sout.rdbuf()->sputc('a');
+
+ sin.str(sout.str());
+
+
+ d.set_stream(sin);
+
+
+ for (int i = 0; i < num; ++i)
+ {
+ unsigned long N = d.get_target(total[i]);
+ DLIB_TEST(low_count[i] <= N && N < high_count[i]);
+ d.decode(low_count[i],high_count[i]);
+ }
+
+
+
+
+
+
+ DLIB_TEST_MSG(sin.rdbuf()->sgetc() != EOF,"num: " << num);
+ DLIB_TEST_MSG(sin.rdbuf()->sgetc() == 'a',
+ "sin.rdbuf()->sgetc() == " << (char)sin.rdbuf()->sgetc() <<
+ "\nnum: " << num
+ );
+ DLIB_TEST(sin.rdbuf()->sbumpc() == 'a');
+ DLIB_TEST(sin.rdbuf()->sgetc() == EOF);
+
+ } // for (int k = 0; k < 10000; ++k)
+
+ dlog << LTRACE << 2;
+
+ print_spinner();
+
+ // the point of this block is to make sure that the return value
+ // from decoder.get_target(total) is a always less than total
+ for (int k = 0; k < 20; ++k)
+ {
+ string temp;
+ temp.push_back(static_cast<char>(::rand()&0xff));
+ istringstream sin(temp);
+ decoder d;
+ d.set_stream(sin);
+ unsigned long total = ::rand()%256 + 20;
+ unsigned long target = d.get_target(total);
+ DLIB_TEST(target<total);
+
+ for (int i = 0; i < 30; ++i)
+ {
+ unsigned long high_count = ::rand()%total + 1;
+ unsigned long low_count = ::rand()%high_count;
+ if (high_count <= target)
+ high_count = target+1;
+ if (low_count > target)
+ low_count = target;
+
+ d.decode(low_count,high_count);
+ target = d.get_target(total);
+ DLIB_TEST_MSG(target<total,"target: " << target << " total: " << total);
+ }
+ }
+
+ print_spinner();
+
+
+ dlog << LTRACE << 3;
+
+ for (int k = 0; k < 10; ++k)
+ {
+ unsigned long seed1 = 1064644658, seed2 = 1064543921;
+ //unsigned long seed1 = 1064682621, seed2 = 1064543921;
+
+ // make array be an array with each element in the range 0 to 255
+ // and have the probability of seeing each number in the array
+ // not be the same
+ //seed1 = static_cast<unsigned long>(time(0));
+ srand(seed1 );
+ int array[65536];
+ for (int i = 0; i < 65536; ++i)
+ {
+ array[i] = ::rand()%256;
+ }
+ for (int i = 0; i < 60; ++i)
+ {
+ int idx = ::rand()%65536;
+ int radius = 35;
+ if (idx > radius && idx <65536-radius)
+ {
+ for (int j = idx-radius; j < idx+radius; ++j)
+ array[j] = array[idx];
+ }
+ }
+
+ // test with 3 random strings of length 10000
+ // but use the above array to bias the random numbers
+ for (int j = 0; j < 3; ++j)
+ {
+ print_spinner();
+ temp = "";
+ //seed2 = static_cast<unsigned long>(time(0));
+ srand(seed2 );
+ for ( int i = 0; i < 10000; ++i)
+ {
+ int a = array[::rand()%65536];
+ temp += (unsigned char)a;
+ }
+ string temp2;
+ temp2 = test<encoder,decoder>(temp);
+ if (temp2 != "")
+ {
+
+ int k = 0;
+ DLIB_TEST(temp != temp2);
+ while (temp[k] == temp2[k])++k;
+ }
+
+
+ DLIB_TEST_MSG(temp2 == "","");
+ }
+ }
+
+ print_spinner();
+
+
+ dlog << LTRACE << 4;
+
+
+
+
+ // test with a large string which contains all the same character
+ temp = "eeeeeeeeee";
+ for (int i = 0; i < 13; ++i)
+ {
+ temp = temp + temp;
+ }
+ temp = test<encoder,decoder>(temp);
+ if (temp != "")
+ {
+ // crop off all the e's until we find the part that is messed up
+ string::size_type pos = temp.find_first_not_of("e");
+ temp = temp.substr(pos);
+ }
+ DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); /**/
+
+
+ dlog << LTRACE << 5;
+
+ print_spinner();
+
+ temp = "davis";
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");
+
+ temp = "";
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");
+
+ // test for each single character
+ for ( int i = 0; i <= 255; ++i)
+ {
+ temp = (unsigned char)i;
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");
+ }
+
+ dlog << LTRACE << 6;
+
+ // test with a long string with the same thing repeated many times
+ temp = "davis ";
+ for (int i = 0; i < 10; ++i)
+ {
+ temp = temp + temp;
+ }
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");
+
+ dlog << LTRACE << 7;
+
+ // test with 10 random strings of length 1000
+ for (int j = 0; j < 10; ++j)
+ {
+ temp = "";
+ srand(static_cast<unsigned int>(time(0)));
+ for ( int i = 0; i < 1000; ++i)
+ {
+ int a = ::rand()%256;
+ temp += (unsigned char)a;
+ }
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");
+ }
+
+
+ dlog << LTRACE << 8;
+
+ print_spinner();
+
+ // test with 15 random strings of length 30000
+ for (int j = 0; j < 15; ++j)
+ {
+ print_spinner();
+ temp = "";
+ unsigned long seed = static_cast<unsigned int>(time(0));
+ srand(seed);
+ for ( int i = 0; i < 30000; ++i)
+ {
+ int a = ::rand()%256;
+ temp += (unsigned char)a;
+ }
+ temp = test<encoder,decoder>(temp); DLIB_TEST_MSG(temp == "","seed: " << seed);
+ }
+
+
+ dlog << LTRACE << 9;
+
+ print_spinner();
+
+ // test with a large string which contains all the same character
+ temp = " ";
+ for (int i = 0; i < 10; ++i)
+ {
+ temp = temp + temp;
+ }
+ temp = test<encoder,decoder>(temp);
+ if (temp != "")
+ {
+ // crop off all the spacess until we find the part that is messed up
+ string::size_type pos = temp.find_first_not_of(" ");
+ temp = temp.substr(pos);
+ }
+ DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");/**/
+
+
+
+
+
+
+ dlog << LTRACE << 10;
+
+
+
+
+
+
+ // test with a large string which contains a bunch of a's followed by a
+ // bunch of z's
+ temp = "aaaaaaaa";
+ temp2 = "zzzzzzzz";
+ for (int i = 0; i < 12; ++i)
+ {
+ temp = temp + temp;
+ temp2 = temp2 + temp2;
+ }
+ temp += temp2;
+ print_spinner();
+ temp = test<encoder,decoder>(temp);
+ DLIB_TEST(temp == "");
+
+
+
+ dlog << LTRACE << 11;
+
+
+
+ }
+
+
+
+
+ class entropy_coder_tester : public tester
+ {
+ public:
+ entropy_coder_tester (
+ ) :
+ tester ("test_entropy_coder",
+ "Runs tests on the entropy_coder component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a 1";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_1a,
+ entropy_decoder::kernel_1a
+ >();
+
+ dlog << LINFO << "testing kernel_1a_c 2";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_1a_c,
+ entropy_decoder::kernel_1a_c
+ >();
+
+ dlog << LINFO << "testing kernel_1a 3";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_2a,
+ entropy_decoder::kernel_2a
+ >();
+
+ dlog << LINFO << "testing kernel_1a_c 4";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_2a_c,
+ entropy_decoder::kernel_2a_c
+ >();
+
+ dlog << LINFO << "testing kernel_1a 5";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_1a,
+ entropy_decoder::kernel_1a_c
+ >();
+
+ dlog << LINFO << "testing kernel_1a_c 6";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_1a_c,
+ entropy_decoder::kernel_1a
+ >();
+
+ dlog << LINFO << "testing kernel_1a 7";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_2a,
+ entropy_decoder::kernel_2a_c
+ >();
+
+ dlog << LINFO << "testing kernel_1a_c 8";
+ entropy_coder_kernel_test<
+ entropy_encoder::kernel_2a_c,
+ entropy_decoder::kernel_2a
+ >();
+
+ }
+ } a;
+
+
+
+
+}
+
diff --git a/ml/dlib/dlib/test/entropy_encoder_model.cpp b/ml/dlib/dlib/test/entropy_encoder_model.cpp
new file mode 100644
index 000000000..0276cbe77
--- /dev/null
+++ b/ml/dlib/dlib/test/entropy_encoder_model.cpp
@@ -0,0 +1,198 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/entropy_encoder_model.h>
+#include <dlib/entropy_decoder_model.h>
+#include <dlib/entropy_encoder.h>
+#include <dlib/entropy_decoder.h>
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.entropy_coder_model");
+
+ template <
+ typename ee,
+ typename ed
+ >
+ void entropy_encoder_model_kernel_test (
+ )
+ /*!
+ requires
+ - ee is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
+ the alphabet_size for ee is 256
+ - ed is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
+ the alphabet_size for ed is 256
+ - ee and ed must share the same kernel number
+ ensures
+ - runs tests on ee and ed for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+ srand(static_cast<unsigned int>(time(0)));
+
+ typedef typename ee::entropy_encoder_type ee_type;
+ typedef typename ed::entropy_decoder_type ed_type;
+
+
+
+ {
+
+ ee_type ecoder;
+ ed_type dcoder;
+
+ ee elen(ecoder);
+ ed dlen(dcoder);
+ ee elit(ecoder);
+ ed dlit(dcoder);
+
+
+ istringstream sin;
+ ostringstream sout;
+
+ ecoder.set_stream(sout);
+
+
+ unsigned long temp;
+
+
+ elen.encode(0);
+ elit.encode(9);
+
+ elen.encode(0);
+ elit.encode(0);
+
+ elen.encode(0);
+ elit.encode(4);
+
+ elen.encode(0);
+ elit.encode(0);
+
+ elen.encode(0);
+ elit.encode(2);
+
+ elen.encode(0);
+ elit.encode(0);
+
+
+
+
+
+
+
+ ecoder.clear();
+ sin.str(sout.str());
+ dcoder.set_stream(sin);
+
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 9);
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 0);
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 4);
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 0);
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 2);
+
+ dlen.decode(temp);
+ DLIB_TEST(temp == 0);
+ dlit.decode(temp);
+ DLIB_TEST(temp == 0);
+
+
+
+
+ }
+
+ }
+
+
+
+
+ class entropy_encoder_model_tester : public tester
+ {
+ public:
+ entropy_encoder_model_tester (
+ ) :
+ tester ("test_entropy_coder_model",
+ "Runs tests on the entropy_encoder_model and entropy_decoder_model components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ typedef entropy_encoder::kernel_2a_c ee;
+ typedef entropy_decoder::kernel_2a_c ed;
+
+ dlog << LINFO << "testing kernel_1a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_1a,
+ entropy_decoder_model<256,ed>::kernel_1a>();
+
+ dlog << LINFO << "testing kernel_2a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_2a,
+ entropy_decoder_model<256,ed>::kernel_2a>();
+
+ dlog << LINFO << "testing kernel_3a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_3a,
+ entropy_decoder_model<256,ed>::kernel_3a>();
+
+ dlog << LINFO << "testing kernel_4a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_4a,
+ entropy_decoder_model<256,ed>::kernel_4a>();
+
+ dlog << LINFO << "testing kernel_4b";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_4b,
+ entropy_decoder_model<256,ed>::kernel_4b>();
+
+ dlog << LINFO << "testing kernel_5a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_5a,
+ entropy_decoder_model<256,ed>::kernel_5a>();
+
+ dlog << LINFO << "testing kernel_5c";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_5c,
+ entropy_decoder_model<256,ed>::kernel_5c>();
+
+ dlog << LINFO << "testing kernel_6a";
+ entropy_encoder_model_kernel_test<
+ entropy_encoder_model<256,ee>::kernel_6a,
+ entropy_decoder_model<256,ed>::kernel_6a>();
+
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/example.cpp b/ml/dlib/dlib/test/example.cpp
new file mode 100644
index 000000000..4cf927159
--- /dev/null
+++ b/ml/dlib/dlib/test/example.cpp
@@ -0,0 +1,72 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything
+// inside this file "private" so that everything you declare will have static linkage.
+// Thus we won't have any multiply defined symbol errors coming out of the linker when
+// we try to compile the test suite.
+namespace
+{
+ using namespace test;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.example");
+
+
+ class example_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ example_tester (
+ ) :
+ tester (
+ "test_example", // the command line argument name for this test
+ "Run example tests.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {}
+
+ void perform_test (
+ )
+ {
+ // This message gets logged to the file debug.txt if the user has enabled logging by
+ // supplying the -d option on the command line (and they haven't set the logging level
+ // to something higher than LINFO).
+ dlog << dlib::LINFO << "some message you want to log";
+
+ // This test is considered a success if this function doesn't throw an exception.
+ // So we can use the DLIB_TEST_MSG macro to perform our tests since it throws an
+ // exception containing a message if its first argument is false.
+
+ // make sure 3 is bigger than 2
+ DLIB_TEST_MSG(3 > 2,"This message prints if your compiler doesn't know 3 is bigger than 2");
+
+ // make sure 5 is not equal to 9
+ DLIB_TEST_MSG(5 != 9,"This message prints if your compiler thinks 5 is the same as 9");
+
+ // This is a form of test you can use when you don't care about having a message
+ DLIB_TEST(5 != 8);
+
+ // If your test takes a long time to run you can also call print_spinner()
+ // periodically. This will cause a spinning / character to display on the
+ // console to indicate to the user that your test is still running (rather
+ // than hung)
+ print_spinner();
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ example_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/example_args.cpp b/ml/dlib/dlib/test/example_args.cpp
new file mode 100644
index 000000000..573216c79
--- /dev/null
+++ b/ml/dlib/dlib/test/example_args.cpp
@@ -0,0 +1,75 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything
+// inside this file "private" so that everything you declare will have static linkage.
+// Thus we won't have any multiply defined symbol errors coming out of the linker when
+// we try to compile the test suite.
+namespace
+{
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.example_args");
+
+ using namespace test;
+
+ class example_args_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+
+ This particular test requires the user to supply a command line
+ argument when they run it.
+ !*/
+ public:
+ example_args_tester (
+ ) :
+ tester (
+ "test_example_args", // the command line argument name for this test
+ "Run example tests with argument.", // the command line argument description
+ 1 // the number of command line arguments for this test
+ )
+ {}
+
+ void perform_test (
+ const std::string& arg
+ )
+ {
+ // This message gets logged to the file debug.txt if the user has enabled logging by
+ // supplying the -d option on the command line (and they haven't set the logging level
+ // to something higher than LINFO).
+ dlog << dlib::LINFO << "some message you want to log";
+ dlog << dlib::LINFO << "the argument passed to this test was " << arg;
+
+ // This test is considered a success if this function doesn't throw an exception.
+ // So we can use the DLIB_TEST_MSG macro to perform our tests since it throws an
+ // exception containing a message if its first argument is false.
+
+ // make sure 3 is bigger than 2
+ DLIB_TEST_MSG(3 > 2,"This message prints if your compiler doesn't know 3 is bigger than 2");
+
+ // make sure 5 is not equal to 9
+ DLIB_TEST_MSG(5 != 9,"This message prints if your compiler thinks 5 is the same as 9");
+
+ // If your test takes a long time to run you can also call print_spinner()
+ // periodically. This will cause a spinning / character to display on the
+ // console to indicate to the user that your test is still running (rather
+ // than hung)
+ print_spinner();
+ }
+
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ example_args_tester a;
+}
+
+
+
diff --git a/ml/dlib/dlib/test/examples/CMakeLists.txt b/ml/dlib/dlib/test/examples/CMakeLists.txt
new file mode 100644
index 000000000..93bd9a139
--- /dev/null
+++ b/ml/dlib/dlib/test/examples/CMakeLists.txt
@@ -0,0 +1,8 @@
+
+# Disable some warnings from gcc when compiling the examples because fixing them would make the
+# examples harder to read.
+if (CMAKE_COMPILER_IS_GNUCXX)
+ add_definitions("-Wno-comment -Wno-unused-parameter")
+endif()
+
+add_subdirectory(../../../examples examples_build)
diff --git a/ml/dlib/dlib/test/face.cpp b/ml/dlib/dlib/test/face.cpp
new file mode 100644
index 000000000..1bb1a94b6
--- /dev/null
+++ b/ml/dlib/dlib/test/face.cpp
@@ -0,0 +1,360 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/image_processing/frontal_face_detector.h>
+#include <dlib/image_processing.h>
+#include <vector>
+#include <sstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+#include <dlib/image_io.h>
+
+//#include <dlib/gui_widgets.h>
+//#include <dlib/image_processing/render_face_detections.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.face");
+
+
+ class face_tester : public tester
+ {
+ public:
+ face_tester (
+ ) :
+ tester (
+ "test_face", // the command line argument name for this test
+ "Run tests on the face detection/landmarking modules.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+ void get_test_face_landmark_dataset (
+ dlib::array<array2d<unsigned char> >& images,
+ std::vector<std::vector<full_object_detection> >& objects
+ )
+ {
+ istringstream sin(get_decoded_string());
+ images.resize(1);
+ objects.resize(1);
+ load_dng(images[0], sin);
+ pyramid_up(images[0]);
+ deserialize(objects[0], sin);
+ }
+
+ void perform_test()
+ {
+ print_spinner();
+ dlib::array<array2d<unsigned char> > images;
+ std::vector<std::vector<full_object_detection> > objects;
+ get_test_face_landmark_dataset(images, objects);
+
+ frontal_face_detector detector = get_frontal_face_detector();
+
+ print_spinner();
+ shape_predictor_trainer trainer;
+ trainer.set_tree_depth(2);
+ trainer.set_nu(0.05);
+ //trainer.be_verbose();
+
+ shape_predictor sp = trainer.train(images, objects);
+
+ print_spinner();
+
+ // It should have been able to perfectly fit the data
+ DLIB_TEST(test_shape_predictor(sp, images, objects) == 0);
+
+ print_spinner();
+
+ // While we are here, make sure the default face detector works
+ std::vector<rectangle> dets = detector(images[0]);
+ DLIB_TEST(dets.size() == 3);
+
+
+ /*
+ // visualize the detections
+ std::vector<full_object_detection> shapes;
+ for (unsigned long j = 0; j < dets.size(); ++j)
+ {
+ full_object_detection shape = sp(images[0], dets[j]);
+ shapes.push_back(shape);
+ }
+ image_window win(images[0]);
+ win.add_overlay(render_face_detections(shapes));
+ cin.get();
+ */
+
+ }
+
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'test_faces.dat'
+ const std::string get_decoded_string()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'test_faces.dat' we want to decode and return.
+ sout << "RFYXmMn7UA64INJ2Umw+YCh5xX6v+y3bqV/1EKP3ZtvdIWxDCoyZ8oJjj/LXTw3PyEMQReTyXO+8";
+ sout << "423ExTTLMRLxc5gEI4PK8vLMVWdRjRKldYfFLisH5qk7o6TxYfRXqmch/t4c53UfsUuJuVMvA+nf";
+ sout << "05tfgx0Z2VgrlMmvKSxX0VHEjQ3ZVqQl0mLeur1qoMRmcPjYA4QJOvcruwyz/hFpVU8snvsri4QA";
+ sout << "hkCrOtvl/iTlLvcWXDM98OXEYLk4vr6z2v73mGGEdbaJFafXOgutY0NRESpJfingJkuZKXNV7fQ8";
+ sout << "F1DmHwFM0Izoc3fUyN7+xtLo2LSEH8uohv79wjxdbm71n5kgHb6cnpYiiCLOgqKrK+qJUJG1/OBB";
+ sout << "9IORGL6SHrogefDg76m4ayil0lE4pQLa1fKhXGqphdxDMyBXHkOpxA356i5Evk7jZb/AqHUu/hQV";
+ sout << "QYOsW20TRhuFfrasrb/Veq10hMZomosAnNb+Uhgnu2Ip4igX6ozJqmCL4NngjEgoXuCA02KzmIqK";
+ sout << "fcslrBjyg7WZnF2w7UZXcHoZ0AVijb/ANPUqvq9H4k5WCFh53pwG6oj+CG5N7EXLzt9UHtillGJH";
+ sout << "S252/dLvzaNxSq0vDP+Y0IvtKhZgIlmX3duu+L7iJUMinE90OXQFzhwCaOM8oP3Mq+Becgto5Vjb";
+ sout << "vv+2xeyrYL9hBfMyPzbBhVWZdaPyI62MGMPjTfmAUmufWm3/Pxey5Jw6MUNX1rfWMXYQTtamdtcE";
+ sout << "XmuFGhLCWbnJN2GtlOWu+S0LdWAziktr0mdovdFZTwKnd2Fuf1VeMMHBCgF5D56UK6/JTXO1gPlj";
+ sout << "VpavizirUH46rHVnGOFZHWarPKymXyrQ/VhFbVBSOi9UGbsMo9sYNYhvFlzSBN5orNoY+5AP+Y2X";
+ sout << "BDw342Vg3YRIZBqaXZ7oOqyDtjFYNf2g7ColnaFrPiYCKg9VQvVQVXvr4Oa6TTryepXylHk11ijw";
+ sout << "xYgjiENTDKVuXPLVrAEWZD6wq5+LGbmj1cgh59XgLErkoliKxVeEv4ktXqceCFK9YWljqJZSBxFj";
+ sout << "T64N945O47Oek591sfRFYrjebJz3kooaaOQ7lm4BWpW64Am/Od0FIkggSzPUgf+jYrnPdH4oqv3d";
+ sout << "NJIXJO8aeCi/WCqgbdgLwpu1MB8cTdbK2TrCbi3sHhQddE6rqZ1XIg25gYTw72q4ZFUHmRrBEgdz";
+ sout << "i2xh7qHevCWC9Ht7HrUcLl35/UCcNck/ftiDa/xN3rBJp+cJCyS9ZXScFbgFKYfBau8Dx+JA3ygU";
+ sout << "pvYbUP29UeK5hXZoQzC74UKw/liapW8GbR3ZPV+59xmw1b7phyApZfODkDLG5lyTB5Co6vIfmqXp";
+ sout << "DSA1I0ZPIk7hJHWtddgvAg6Yv4GhMZs/cn08Z/ZRXYSMw20rA66BeMD+IPFe3MPFPnKsl/qRvXIC";
+ sout << "gTmLcjPQ5naBo/3haQg4TjJCBCErTC2JktJE0vRfm8v+1kEFoycTsRfhDZtjMnqmqSiNO3VSbnbb";
+ sout << "et/mFvCdd9kqBXQ4VaeYdJwORA7/TocC84G91jlfNRhbY2KW9xHXYYwscfaV5DtPWhyJBEimzLsz";
+ sout << "F59UbkO/WOT7HEwJ3p1/ReH6feLrIYwR8OdSeuOUtK2oboodiUmx/if9pyNONd0If9jFGDXvWP/r";
+ sout << "dL6NdRh68swmPyiCn7sAwaUbd7PF7K+jqgk3jMDMaoWE6p+aAFK3H/JxRO7th2XvX57uA+RKoXEq";
+ sout << "LBJrDkoPGbA5Ctj6KUWknMM62HVIO1SjwZH7ROOlCVbHvDd3JTT6TNs2Lj8g78q6WRMnDVV0L9Q7";
+ sout << "S6awjHePvyj9fDn66jpmRxI9TkqNPx0b8tGfOnUoXGKK/WcInBjVd+FYy1SO5FIVZ4JQdiSMgehu";
+ sout << "EU7X2yjbdgCtVg+kY3ZgYcCMM72NLdm7+U+qVxzfY9bNtzLRsEDNFL2wnq3DkbRRiHKHswEugXjG";
+ sout << "fFfzqjU71sDH3gc1dhOQT456XK0zhmV/62+N+oTxrFY2F/ArALeDiR1q1JOrE1UDzU/ezAZVWS/a";
+ sout << "DJ8O9TVZsbOXOU4hAiQ6mR5hh6i5c/zb7jBl6ehNgixrJqguo9PT10P3S6I+IiTG2vHlws4ZoTmz";
+ sout << "7iOj1ICCN6ivwPj5+afYScxPGLVujxWTW8ksey8RjSIbH+N9wUEAJAcZAGXo2UrP93uRkdLt/cvi";
+ sout << "Q4eIONlUrJ79hlo1xyvPmyNw9Ye1zG6epUK+lsXxtUbtX+uX+sDM/Gkr5QUdzzorfa4Mj9vEglQp";
+ sout << "iK5319kTeymqqV25YhNZZsoqCg+eckmoZP21cAHkPYZWAzg4lxhm71EAp+fE67Y44szeCSRNydRh";
+ sout << "mCl0ZGV59wVfUONdohN/9OQJ8gxifHspscrRszBALwCu2FrniGwRMvIi6lD6tb49dmpPlO1y5SEl";
+ sout << "HPF2I271RVx0YwFPprp+2/fzsnKIIhfgTNRVji2tfjHPoge84e5O02Z/cLlg7EzIfr/JK8Xd6mwp";
+ sout << "t3kIofktLqWbnBrr8qEvz2BmHRlEjYlpb6HfbK82ZHiOusVJdhAxRcRIqmImh5wtunlJOXdQqfmg";
+ sout << "8P0TYVDzQlc6ywbKt/VjTsxI4qisabVvVBkW5VsKU4/JONoNj7fgomwfHqryIIZcgc5vPLawnurD";
+ sout << "7JR1LRiTy7+7riDIxKdYtgbTuOnK6tO103XC+5NT07cKydDgtTu1gYoTTgzY+gkK53lX2Swbbme7";
+ sout << "wIFOXRwSK4ntGt4XXuuPbF8gc0T8ez8/tOEXE6Di8Ckj2O3uz2vGM27PQH9YU3Y3YCdcuaHlUIMG";
+ sout << "s7bgx4nM+xMPmS5baTuAHiPAxp3SepeKpHg0on6Bv99NPSDf7nWRcvvd3+/MHhyDfkbzd88GPnkQ";
+ sout << "Y7a91Tlj8RKG6B8W+zzn+SWzTIz5eJcLJCEKN7fXm1YjJjk3Tdi99PZA/K3Ek2k2lcd2oQPXj6fc";
+ sout << "EUs9zy6sZdSjttwsYlZzLW0Rpiscjqs1XNA+2D5UahufMk4AVpnDPqZ/OvmSZIPpTT4r+uQsEZlJ";
+ sout << "BYf6A9riAPhgR68zPz+i+ffpPcgwURCDviqf370nLIqfbDgJSztxXI2MPsHZqypB+VtvuLTq+M2s";
+ sout << "NVdvA6z5J35d3fk8EHVo8/TIWbsSulqnpJvjIHT9GeTV4EvaJo0kh092cFQ3QkbRIIzEqnT/o06z";
+ sout << "/+gqB0fKl93o50EUVzJbz02rPc/4qVaqfSJLg+HEZUg30PB8wQHb2oWqgL1lYDlqrv5plbK2kJm9";
+ sout << "HW9aQfOyX57rBsYi4aljZl5icy/JElsuNanhnTcTvrHQ/9ntw8QqV2PuVesbbaUQSjDRWG6D1uaK";
+ sout << "HSYB/6fvMZ+8hy+gq4d/tMKpoCzgERJjJOU+N0vHpKmyZgE0EGkPJwlgev/oxTJtsQndgrNviPkr";
+ sout << "ub9ZRkS5uM7Jb+1mKztyVWDdtVvxnvgboNayuS5/VdwlbQezQKEiB8I8UWLsgEJiXg9sgjBaFrv8";
+ sout << "3WgtSQmMRyukOnDPWwDskmUyIHKIye1wOY3H1BUuav52tg8gv1+y2CrVVebRm/8MmhJYe/8DpLeo";
+ sout << "1eEAX2SNtZH4VzpZSuAdANtYXgBaNTc0uWtw9Wc8mwo2hTfbu5nVYJ6vlUFP7HH7L+1idfrz832x";
+ sout << "l20/+8EKd8y2f20iP6m1mnKCQ9PUPPhWMfWMkJ21VhKJJDcpKhvQq20O/yqhfxabl+73QZiCS/eR";
+ sout << "E1ih71FN4x+s952FXOakzYlo5Gge1OCgHE0+YVXBSal4fz6Ye1iRG7+XgLxDIx6AGbOfemRQbOzW";
+ sout << "e/8pe0kPqkkS5ogdkCGemb3hKTgGFGXgP9IvJ18VuRDxPSHD1e5THwvULz7V0hUWO4aKx8t8tZIK";
+ sout << "flDB/npEo3L/1jU5rLEuX5KQbxJfEY2V22hND1ohUZdU+Uy6BFZ/hdYzFAUyNdLAFS36wB3XThar";
+ sout << "54CQ0RGIxCCtv5ucI7VQuc46jkk6SEmZV8FUwv79ExsrB+rAlBvhNOrmVA2vmikrb/iZfA6z1hC2";
+ sout << "aj8BGiULdcW0YUiN/TXxoVT0qgh87VjPcOfkj8gTRMF8VGAgRs0HpXVXKZf+ncAOKJ59yu3EsOCp";
+ sout << "vSG5zxDxrUYD7TFxrelev/7Jtan4J8ouFsK8ZA4WmBkbWLDBkKvV09c7Jxfmgt27aVI5uuUBJYb0";
+ sout << "TdKIG6J/JC85GrRedIb/kRaMiTP0Els6jGp3C/B9PQq6HbzSCUEkLE8is+uWnTDOynbgTtUVEGxH";
+ sout << "EKZGBtnPfqgZRDOnZTWMO9Hd9qATpI2qRrgxIvHTUhqD3DQQ37AGTcNsMmj/+mXTBV2vbM9H7Q5K";
+ sout << "APzltdgkGc+hZIL8Fy3CXzzRFlXAEoIcnJ3BKT7AdGg3pEaW1YcX4akaOTDmImZelYTCoTGu1R4Y";
+ sout << "ZD/rRCeiGR9txS9x9/ptvxeD3J8wYOXxDzCkyMQy1io+izFuN4kTd5MW3RlvepWT4FE2hvjyTyV+";
+ sout << "G4F+lcZFjCGmJKguZEzH3Qtww3OTGZqOfL9oADqQsUpEl92hm0uNOPHH6+8gWPZgb3EcCRkWeXaA";
+ sout << "DkTbCC5rV04N6Fg3j27K1v7qkykWB4y61W00dFejgPitoYuln4A8lY9RtIKzJYFfSOKnrAqYGMkn";
+ sout << "PReI5ZneFTiklSL6FkhWHvr5oz4EsAoU3fLWjky27E1CtpFkPKHLGyc2mnf2N7iBMGa8j2ipy9xd";
+ sout << "JwInqdDdGIHR5dUmNAy1A4vyGbOE7qa0ErVy4m3riKkyE0ObQTNoBeYa6CUHwNWCLCbiW2VzRY/H";
+ sout << "s6APvI6QIco+Wu9dx36qklvA6/OfM79BsUACnCRG2yo3bHeTeKMwKIx05RkNNJ76eOhjxOiPJBVR";
+ sout << "K5V+G3SutaRM0hSK0ABS5nfveZmowfJr8nZHRBAPHyIdv0bJW8lSNIRtZtgwm1dl13+eaeAJNpJm";
+ sout << "3eVvhP8c43LG938Bjxs/nfDl5GMzPnyLOIYInIt4wTXCZlRbt6pMXs9IX9/DpUt/AF44b+XQLnUT";
+ sout << "DgJa4In4qREDt58wel1oAGe9xpIqdFNPduUuCo/Ly0XHrBFd3jQFgp0JWZ039GtG2YTHpo8rLfu+";
+ sout << "TiTAmKnFbVZa6dlOzmVDV+53ptJxiWNfNMa2ri/YEFI7BpKi1FZvMfpGBzNmAdZBAt7kaSvIzqdl";
+ sout << "OHMqka3GBbzyW1PusVpj6SWBZ7rsfIFPRdVO4PGcOWSIQ/YlZYXtkV82cw+m4D/ScXdlj0VZBB04";
+ sout << "YfeUz1m8tiWsLdHxIKRI1JcLek5PzCQX33RFmfeqGBEa7q/kCzGiJ0QHiJV07Fxm2NKRUZQIWiJO";
+ sout << "n1roUf9C06IT03Wd0rcSlwG8Ji9SJZOlBEi7B9Vos61eXMEnkcNiRVeOxLaOfqr20XMde4kquqrS";
+ sout << "qZi2ZhNVqokhXNswwMtlMDfWBWWI39q4evlS8c60lQjXg39kyKbyVZOzp6KrJ+xDYUxURb2d/7DZ";
+ sout << "+UpikS9DmbUgWqV1pmx84ARtPp8/5teoBM7qd6sRrpJ1Q4cE4WLQfr/nVnVmSfZVzm6yqlxjGvhQ";
+ sout << "Bwz85XvMn4xjWIRDNnuwyJh7PoXkoacU54idHA4k7B3qeW59MKtXD3hA34qKUqqE0amH1Xzi3W84";
+ sout << "HpL4xT9EoNUPv+ufvo/4Yir/YIMIVuj1BFQACONJWezdbHO7ze7DQrceR9ojpVlkM442CArpnYWZ";
+ sout << "3SR0NRB0PFjJcgA3RPVWW591v7Aw3G7/aUh5OLSZsOoxrnE2Hb8wM4wQXutleUjcUCxdliAbLl/k";
+ sout << "RcLSE+IzpwRpnSo05sBx91Q71Ws9NeS0Pruy03wScuztfv15MPoBmMH4Nc+JhF/iokGM7C4IICOb";
+ sout << "Woffq9KzUSWaIEEO+qaKqfnG7+/bW5U/gDW0xAfvTt0IuXtoJH5/Gq/g/anFdecnyCdBa4C02MFJ";
+ sout << "gLqm9Dyiuh5Hny9bAuE08YxgrfpwOx3nGxQ/MrXwGUNjCxJEz4TkmFt/Z6HmOUhobQ4Xue5rK6gp";
+ sout << "qWkg7gkXJmpNaj40TFnxQ3Fvw7S1UPNcY3vEvskuYXlB6mI16FCa6ApkR3+krkCKokxelBU5eMxx";
+ sout << "+j5Lv8hYzlSlULUJrDOmGRxQ65lOJPmxbuukr6Uaeh3/i5m8GLqJ3EAkeIABkck+sT/OCCS+1hmY";
+ sout << "/wRbadX8sd19GjHTLkZ7jwEss0qQj1LdtcEtqXgtzy7+8NhQv9c4KhxFNZdYDezehZYZuf4+UeHM";
+ sout << "UAPE6+DX+AtKRD5lSwcAEoFID5GLbsLa+gGWzhmn5dfTDvaJrQNYzI6K1f0MBsWxuXu5SM6dRehm";
+ sout << "9FR9es1OnhIFP8bg5uR7lfIEsfAH6ysgkJyylceoooXwE+cALB+IYfk3mIHNuDMbrQxWIMZmS1er";
+ sout << "I45Gz53QPLgcSjH9Z1mzrhaZV3ZiJqrTYETtObMRX0136rNLBCaytijF1H4QBJaNEsuJpHAYQHsC";
+ sout << "ko8cTk8lzMmh+ENtUjLvrG+pEx9FKCPIgUIZrAAM39fWS5uKhFaAPEkfD/FxbEMdM/hjzbQzcMHy";
+ sout << "iYywZCaCYwolzrQtlTIL3I4VqzpUk/vWMIZt/PMK8lUGBDxzELkXNTYepRJ+uR6QZKOU+/LpLk8K";
+ sout << "MD6+BZNl2karSYjx9qDDHuADoiNJbSXLV2NqyJiNyTmdKv8E4e52mVFfetoMp+gpd7vMJaAObPuJ";
+ sout << "A7rUsNjWz+Zho52LXUSUC1G1MdZPRZPMfk/KvNnXZX93P/KaBXTNLlRf4KkRXQdrnmP5KgKFOvIX";
+ sout << "KUMJ7UQtP0koYJl8OsFHG5yX+qI4BOJqVT0eUypXtlW/CHRX51Wl11wIDsTqbi7zEkBu3kFLMfnr";
+ sout << "3MEYcVWthG5Y9KecfsdLtRnQVSFRMEKEn1kpVAb7xqZhwKDqREVs+32AawqNxPdO4JOCC/wW0zzH";
+ sout << "QlJ8KhbcrxKLaygvVOrqaZGeCNTrrRxpG792CtX4OqXdkxpqcPNjhQJXMgSm/OvHJ1sJOrIHYaru";
+ sout << "JPDH/J4e0qg28rlRpN/iY9xc5Q/dMxhoEsv87ehP+MPKF5ByxrWHbNc9IaOiz5BQIQWAc1glXkKc";
+ sout << "oyMgoCbZ1Zss8zJ1+k4NwoC59BpnT4KbZHZP7+MbXZidCSosl1P88yC0vj0BX8MK+8x6PA3X9Zc5";
+ sout << "Sg2OMThc6WRR3oi/DeHePYqJgPtJwhZt7N651llY8/YIJD7pqVEiB8KJGcrjZ9tDuU0MDTkKnrr7";
+ sout << "qhlSl/XFBGq0x9zIfYeBt6k2wgrpE2/FjgziM3YuH/e7jfppGx3S3G4O+yrUuAynknZQ6Opq+Qs1";
+ sout << "PYFZMW8Fj+CPMgXGy/2+JPXULK8vf0hN3FfNDeHiGpuGEVaEBHdZdE3CW8cWwyRSdvQHgPbXwEaL";
+ sout << "pLGxPhgGQlPLK/JvDQhtL7gqS1LQEAER1sleOHoTV5zEpCsKUcvisz0V7TyFozYnzaG0IcfeGysR";
+ sout << "Kx7O4dq05dlnXicv2IrhcCp9+QZSHDL/E/t/SEafxZJu+sIhqknM+Sgx3MAhE/U8KyOANAYy1aPB";
+ sout << "H5Qt5RnDbYuFJXHDG9DfEWWcNP8e2EL1CX0/gncYvIz3b5Ge0gv4hwbV23Xzsi3hOki4A/B44nOd";
+ sout << "fAE3Ao1D16+6XowEVC8gUh+y1TSnYPTtnB30NbrcMhHNaJP0pYzG41gNaMJNjK8SGXSnPKiXrsJ2";
+ sout << "1jsvbLaMTEBdCqw+2lgg/QwEVgUef7JhVNKbQKf/HTtuD50Ofa4PXv37Q92/xnHWHwjbWfeNEZWA";
+ sout << "Y5bqQI/MlWBKWSGFjpoQsfbCzS6P3ieD3ID91DmB7tDjTDYm5Wq6LoHAQzx+tTxiq45Qn5tUg0zx";
+ sout << "7VC8Xh0ygItADoEqXpKUXdfv0o6G4zxjePM4dvyf4NBkh27q1S/aQwQp8UOury2Eunij57XSxaOb";
+ sout << "IFwlFJYC6wtmg6oV0QZAburx8qJMLdIvHyqa5YJJ3HGuHJIyiD+tuxWVxDZSJQ1RedYqBPAiloyy";
+ sout << "Z3j6EJKqNh0kq3c0K/B7gpX0P1ZjOFZUadIbEW6dd8bxl0RuyhzdSUMkn4rWgmx7zPvoOEgXzbK+";
+ sout << "mJmBT9lEaYE4sFkXfZqDCGq66jdc4RrlD4IZbGkXwMeTuHsfLL1hviKSOyJCTbmS5dKZUdWuVk8R";
+ sout << "XFqATzyHHdArHR9RvNa2bQk5b59tODpeKCqECv5DJ5T2ap8u21ctDCJiFCHq3gLfw6IumI1L2m6/";
+ sout << "aPGdZ0d4ZMR9ooQvJuhAODKg0ZsA+LjHTakpfUwHpanPWkCbmhuh3oMC+hMX+AqOi55nr2O5uxRL";
+ sout << "9UwixLBmRa8g3lbgUGdteTTbvZ2ePyzwxhWp1RS+aFJqtORmRMkgB6vRc4SC1CywKwwHoR3RKEj7";
+ sout << "uW6oQyLQqPYevLuBrbO6yn7U6DkU6blGF7jg/2y12MpwYPr8l661YMXXsmVHCVSJ0AsPSGhJ8oL3";
+ sout << "Cqk6DJbpyoN8O1xK4zNE0cToRFMhEBjlim0lg56HtbEHfLkqRwnFqfo6vvK6opxJoShMXDa5jrLH";
+ sout << "GkpmE4DzaGtZ/P397TF3Y12c6lXJFDYWslp4tMskzsy9FdW63kRUvl6Q3UsWj72qGE6r/PGVetJ9";
+ sout << "in33oiVHTjYppFrza0ryzPE02V4uobC3y2DtoG/YK/GJFkhm68O0QKxyuBURMfT4j046fBQAbwUo";
+ sout << "B/Ylsb+srIUK6sIEnnzdJ/8ve3f961Yx5Kywf/9kVZnUz5RoiP/bOWXm6jasSq7LEvX2nT2fweup";
+ sout << "/TI83XIDWd0rFQoGTfuIuXFfLQbXkX4oZpmYLdr0kQAKjFtNB/JYV5PTE7PskJKD/AS6iYLd2pOf";
+ sout << "cEYJxPZWSWUSz+EmRNmeDl3lfch9LD6VXgaxY1xPF0/1uQfU+BBikdVQJPlMzVB9QK17ir6rynim";
+ sout << "CP038a8ctWt5RMBsaJPZr7bieh11aTTW2mC65Y3PYQ8WsXofuw4x3xoXI1S/hGCM2QvmgWq29xHp";
+ sout << "5Jkp/Z5YLaGBGxCHM+QQm4LzggVnhYjlguAbfEmWFapwhzmx1L26gv2q2AUiozn2I7dAh1vDRKD/";
+ sout << "u5XMODPPJE+NTsr5DQz7aIdEwRLZynp1FO+VN71nYDa5G8ruF4v320ocRKm/mQ5uzU4X7w69CLdS";
+ sout << "kmag9jfDGvuolYqCooAts5b7tFIcC3WjlXeLPq5y8HmzY69Z72HrISpq0Fyq4vcZaksLSdpv+Pil";
+ sout << "iUu6Vmti4LYLHYsmue1UgMzL0qqdRMz6XmCPBkUXYDS4oOoQuDcH5iJo9aWRoADapKUHqIUnoR3O";
+ sout << "Vx4h5b2/XtCEb8dNF0ubJ+oZGbAize7SWTGnUCQdpz2wxAWFymc5/5jFxfVU2BKSPhrKRYImgnVU";
+ sout << "8GOms14wCx4wyGuRZY3p9s1uPUBNNjDnJqVg6I+7STXJKrYzeP2gP227k2JE3o9eLehe66hcqaPi";
+ sout << "egdpG5RfPN9X87FtePJ6lPjJ8j0Ysgoa6l+DDUDuEZHp6APIG5miY893oac8uo/r2RgRNVv6vLFo";
+ sout << "8VIFR7IwcLj1zXvwriv/Szw3POh7y8svSb4eGKr1c1/5JTJB58Fcjz0AMm2+rW0Twwb3STxn4SZ6";
+ sout << "nXOyP0btY0VCckovcLoFij3lsl21ZGMdfG9cHVlKL88pg4Ip00QmgcQW3QmBBoCjlPlRVSSsDHGY";
+ sout << "8TBBGLuxJi7NPzU1HNtjG3cnYw0t49og2hjrIbF+6fHb9x0pPtnJZwX7SbBYlk4Z8v84fR7cjC+9";
+ sout << "l6SLvaRgqkTj/aaiiHtC17zaxNhP9wHqXmPUZdKM6xs3vsAF1dYOLPlIv+nmMLBPfravTZDkf3p1";
+ sout << "x56EjPHwT1GULNkLBG8iod/cteD0E0GIZf0g5/o0hfI1t18CMGxfMyeaASZugV3+KL1yOwXD07D0";
+ sout << "lp8iLn8FlYgXOgUO7+OweJcIu1IwkzLSM2aNbnH3VDPlu/Ff8ZHL0jiuxhQAVT6jdWJSUEOf2yiB";
+ sout << "8mIGCK7CH9Xv9l/grSdh6GrE+NvvWqKQrhX1k8wxHEM4PnhMSO4R+5dRFWeeh6cYfTNHTYWZ2xzU";
+ sout << "run1L6tULZzpksLtHYDg0vEEq3hDS/3yf+/cvCX4ibt5pUqorTpAtNvbTaguUyBy2TOAy6fSFGh2";
+ sout << "eHil/3JEYZXnfFtcBo+pIvt3LWEPlWCUHNKFLQnpd77Y6wUFn3Ku7o7Nu7zBxemmvxxYUXImAlzt";
+ sout << "354O5G/5G1GUGFf1K2u11fFcuXrfowEE+1eUEpC0KLyxZhOJa5nA6dKtnQbq8wrGXxuGuJGlDSAu";
+ sout << "0sUYLPmELc+kGUyY6A27B0FKFN50bP1U5iWLAtxt0NmPqOnwzvnj5GYQ9R/ZIpzf73N2OoYL4+ba";
+ sout << "6E32ION5IxY2YQ1IqEHxsjzvDW5KoQCb8oz63eKgwLHBz/1yhA0ELpzG9ti5pVE4WbGKOtS/2xTh";
+ sout << "RIgpnpbB7bPUdtw33cjky7t6UAO+QYI1kg8rscd/Ug44hd627JK61SxnGlK5wBRj7aoUxH2yb3Dt";
+ sout << "jMgmcgZYtdIsHjuU63vrN0acMHULcyCRIFuFEtXgnQNIKjPUG3iuN3714Y9sncW5HqDGAYLyRpaA";
+ sout << "69XPtqYEajN2uLF1Q9KIeW701X1diQoHw7TFq0p5x1oTMRzjqcz5lnLIM0DycqCPGoGAnyL0o5A3";
+ sout << "Rw9qaSq1bB5VOdqpZHyN38huATstlHmcO2GqN2r9k2BKqqDYxuzmhB1K5ugoJID0lm6KfR87e+2q";
+ sout << "CQL1tr6ecqFe4LkO6nRR57w6gl48J0tFYuxsgRcktYvkt/BF1JHNnw/BChE9lDBgBZz8TfAqUv5b";
+ sout << "Ofi4WiuGxq5sKIa8a5SRwfVbz1h/MBqBEd9nDlZ1acbg9ZGakzwCfH0WrcArLiN9FqGqmiZClS0d";
+ sout << "cUwftQDUI7yEoiIb18/479c1VU2q5dJddCBaabm8CtGtd8w8KTANBXX4pZoSuOlYUUQaxYsZ/avm";
+ sout << "RYqfU1uOkHlmqm++EvfKEGW/Rmoq/fCIl4mk7YoLpMcub9wlTSVP2W//uvZG0LYwlZScWGOGmo6Z";
+ sout << "xa02FmCXVIyXUfnitTEv2oYj3CV/57nW4+1jkXTcYhH8wt5wX5G2eO7qw3esj1x1kL07xmven4Il";
+ sout << "nmKMri2FNxrWUBzvCwmfKGfPh1IkQ9LAfaJJATfGeBFI33RQdSILTaozNl88dHbUO8I+ZnySvavV";
+ sout << "uX9Ia4Gvrm0nzV6WYbE4SCDjppsaz0CSZy1exA9t/NFoQFjvf25pszQ9JlDtwDm65ssYqCTLjyoJ";
+ sout << "AjMMUygef0nD9WxtWkVCywmyR5HIZWqwV/poAROWBUNL72p0kYxDsm8u8D6LFDQJj+/rSQFr4jzF";
+ sout << "RkH7zemYp52d7nSo7ZV84Pf11aTTWqg1OCwAFz+bcg5wZObUL48+WzPAPePNtlo2ef4hn3tyi/pj";
+ sout << "v62xZax1oHnnB4ozyZqwSd7aH8LGN4G0Pk37PCagVXLEyLEGQXA3NqxxokimT93xORZOF4hZjhUR";
+ sout << "EM6aKChyS4DKkB/HN8IEMrPcKL7zadhDrWX6aeAxBOIbF02jTyCo7rFEO4g3TLuy+aX29SStyzAd";
+ sout << "bESo1w0hmnjboB/cUVh9SU0rWvyHnBveXBU1QFsmKEpEVXbD9iao+VArYlDKQgXS0elIQIJHLJ2v";
+ sout << "8cVk5GtsXQU9Yd5vyzC9R/ZuUD+fuRcoKYwevCjnVegbn4mCK6VICihqWM5etEr9CqdjjzAtemN3";
+ sout << "RhW/4c1v6Gfcb3LQIctJmk08SbUfkImRlULjt3sr+iF/9gMx5AneRDqnq1YRbiusqAfpCl83zBFm";
+ sout << "/txWD4btmU3/Q9TzIYjcDm8JFIpptv1+F6myuN1ElJPj5dcmfBZ2/KQknRf7cFBSFCfeKS2glsIm";
+ sout << "tTj1p8jK+qKf3GS+v/n6VutWgGXAU7bjZSfaWn3wfNX1rJXOX4Czp1dXxmXuxltVQ09bTpEMQL5o";
+ sout << "F9LI3ExjgsokVCsnuAc25hTRWP6bUkWsd3fvbVK/Qrg8sYEpq+3836NJyb2BcFVBqmDJaW+MZ2+P";
+ sout << "dJwzMjQMFfRNnsEBRwHuVTDwa4tyR+1yFOqG514ohI4UKamdebXPlrjm278ztqaN/4ASFsVoQb7O";
+ sout << "nPvqiMT9REV9ZLiK8z2NK4dDr4KUCr/UihBZqX7qQWnFRyy0VooOFAyP9CjfohLthZ8Y0HWFDMpx";
+ sout << "0imNkxXlh3CIDNTRXAmgFupIEQyY08sZX/Oqx4NdzpxWLh2S0VJmIupzuxWTnKBRJXWZEVPKqsak";
+ sout << "zBsNzdtcqrF5dIIY/7d+8LFdEBIZ56wftYkRRvJt9S7mPIY++gohDronAX/ohfwXwu4jCm0OLWzO";
+ sout << "l4FIAdeN0m65Nng2pWnF6M+qB+b5BT3I5cAh7GF4t3woiWOBF7LnjOYmHe7ZkzPzftxeZ820e+eJ";
+ sout << "i1pFSQl3/H+aneIsjEQVGxgynGtvGW3pz/0f5rWrRdUG9ZYPJItWXCjJz2k/Eb3fHDA9JVIT9FEM";
+ sout << "XCsN2FQguDWy7cDxNMKxVRuhomjmVCR50/W9/Wsjg6+tnNNhY5Iukh40jU9uN/WShMpzXNv2QKxR";
+ sout << "bfUtL0zpojBbZOkAd1LZ++bQREzjPBFQ6G18eQa7DphNarttvCzw9qdgvmwpl3CfvW4PeFChxQ2V";
+ sout << "SmfUcuFnESjEN30zsEyFtIJNVd1E+4C6vDxk+FRWl+jGbzNkQ1GK+9z6M0HgUxBXMXcHiJbMEJ3s";
+ sout << "Ri6PqeKjZUAJzORigJmgOO5wURFI+3EVQ0NuPWfaQ/6UZ1n7fdXi8ncmO4jVwY1ptN/4tJL+bAm0";
+ sout << "eu/c6ug2bfNCTx6TnutxBAg5AvxlDrJIUsx+TsM50J4OZmv7pI4dC0UWU3TpzaYHa/cQ+OY61UT/";
+ sout << "j1/QPSquxIlWA46SGt/xrUg9iPIWw+2hdRBkhTBE6PrMMHLl5jxWpJ96YH5WDnifNnnhzRDcLfyf";
+ sout << "9nZYWQaBzm3JangobvnsDPgjF91OfeFJhVPyUwS5u+Sxw+NKPj9wVAaW3gZVS2BOODAud/EvFfQF";
+ sout << "08t3Ab27GGwWNx9ExN/jlq8EHAM+VRcyfn3hBVjLZdt9fs6zZliYdfyNaA/fgO1iLwY4DOSKgALr";
+ sout << "Vs5+KDTmyNf5DkAg/W/d1tA/o76/qXUz3gHdm8Q1TqA3ZuR3g1BE6oS4T9Sg62oYIthP99doW8ll";
+ sout << "FPr0aF+JHwXzG0k6Ao4wGdKmqi+n91pVYI7r2zE1kKtk7GyF0fSmcVnnP42TQfaUbaKB1bMT+a1V";
+ sout << "NMIydv7QD5F2af3El9np7/1S69sGpoNK7PLow71jdlCewGVrq+iYxZaVhe3xHwerQQ1uGfmRxUSK";
+ sout << "exonRQPz0yD5fkjRCqq+Hbys1CXe4iKa2uO0pW+yMlvRnMWpV3Bx6uxYR6pECBJ0x8DQHk4cBwa9";
+ sout << "J/vAlf6dDaVh4PIZy6PF783iAPRusNy0TcCxXTJbIg5YqUb5QAyJ2HEPbIaKeZwSnhytAJKP98JL";
+ sout << "JC+81f3cN/tJKfWQAJz4RgcrbhUfyf5yWK9rPOwvFf5WKfecfPWc6wpT4IqSrlDe3LQqxpFxJwfG";
+ sout << "dzNo23TgOFK+H176FNuW7jXD1sUsh9MA3Yycm3BWZ85cK55DOI3T49K2WB0KnFLkhA9wPJNb20qt";
+ sout << "m1M8QtKwF6jzpMnbBff/vM4oNL0K0QzXovI4gJTDUkHWUlZ3XxMk9PQIqoKZDp9v2JTNW6cs1zcV";
+ sout << "Jcm2XE2ZhTCFmOQ/DIG1AQIzQfOY6s/VRvHPSkIJH76fo1ex2jj5CmbGw2JFNLPDQIIHEjQJGHFw";
+ sout << "KpVGso6XF6CdMtcmlswoRrIvN0odvC0md0D1lf09ZuoyNt68aMSxgeBhoQYW4qC6E86Hef3TmzWO";
+ sout << "esZCvDs2I3UVKgPEkICz2TEHy7dcC1VzU2k/F/cm3+y33lloDem4Dh2igflppvhthcYCLOmFiEW4";
+ sout << "YtxiVJKpIAKbNHqv9qpxNdcbLCsxkENYKwmmG8E3fLSyE7St/z3dvuTuDI35lxANzv1N04YrvfEr";
+ sout << "dIPmvepllDz4Ua9TyYQoZezD6UsivW0gfWdbG9dbYnwAZRKEgQjRD1zvGUt/nGCUKdtBDh/Qu7jD";
+ sout << "c8KMBJx0meoY4jLZBSw3Rccd/KOvzC3UTSsBeTBSkQL7wuSMWZf2cTI3BiWaVTaUXREVmUG0eeSy";
+ sout << "W+Vym1qM2tPWkUm7V3toeS94LU44OlGDPyoHHUOsT663Zew0ao3+rSqS+KASC4L+6oqP3ffSI2WC";
+ sout << "517CtYtPA09gFqdWnN02mJ/+gEWrYUZXJsNh31AGQ4e4N2L1Tupy+L+mgkjTyHeiV6dUsvVQ2J36";
+ sout << "R2VL8EOQcchBMinBo0WKkP6xoTPcCwwMI7T5sHHR+KVUs2DJXTNluNFhYyOic2ImOwhoOHKwfr7I";
+ sout << "bJERGZYKwwBqGPO0mMnB81MZFumFzoo0SNhvNC9a74X8U6gKtDsQCeIHNXWsWxaO2LmjhlZqXEMC";
+ sout << "nrLi4UnXweqUgDscbWdq6fE6Ad/ZimmUJlj5iF6aXWj51B2VIgtYXxBFlVPURZDw9z9KPdyidzM1";
+ sout << "PSitwb6KwWZK8fJ4ZUVOUt1bIki9wAnpyDdAhrPtDVtngS8RJrcCVRrjQ91Vhr95kSTG4b6VtN23";
+ sout << "m3lhkU/T8RAVAphAait/TICzXeXdnFjALDIubsx1e3FMhV/TJHflVQhnahwTaUuIVqb1ZkmP9aMw";
+ sout << "0K1G37NbwIRzyHGVWnvJXxKfLeV3n0OJZ3W5dDmTOZSE1s7JdJ4rdEXpiMVsW07TB6JEtIz0c/AV";
+ sout << "IvDmgF9QH5Ly4Ko171Lg/tVMjBIhCiBU4zQco6brJHx+MoxEsNUXgUCNoMqqjd1a3F5wGsvhlWpn";
+ sout << "iwNpTje2A/ccVQmoJpDSNByMTB/0GsGJTjZ0dWR0KiltaQaKTlRFlGn/2jgbk+i7ujAGj/Gls6DY";
+ sout << "WoR/asND5U/rlyzwJS77EefK16ew8w3Nfy2g/vVvpIUQBG67CnjUzxpnutsI7KvZak2Yehg2NIZQ";
+ sout << "IAuyXMd5zixobwku63RvLT6Fd6D63HnQhEtgtPTcLJT6rQOf2cXAXN3RLOghcpM44flJ0PW679AQ";
+ sout << "pgVrmELVQr2Dvv2fCb3W5k/JdgvTKrJ2YPIY9Y7q5aSZ28VC1u1k6y/KsppJ+6t+TW8PloS+qKNb";
+ sout << "2tawgDRraMsNVUTMhABYEPZ+qoYspJPV4rKfVnELc8otpvkR404ulUsELxml0TndtntTjy/f9lmi";
+ sout << "I7pfDqPsOUfwUtdhuW1XyIhX3RFmavc/VFLHu8gVSmJgdgxjraoe1ldt4F8vgQu22uZu5yBLtxpw";
+ sout << "WgKVifBtJx/lZ4iAPMNJUZt+Fo+1K3uofYmOg78dsM3BSeghNcNI5ki4NrUJ3yWTzfvrUIr90Ee3";
+ sout << "VUaGz0/wtOsFI74ef6BaTCYNAt7psf/dwoow4r4FzHgXNqr9MtWvfp22gGCaRLuUWTHgRzIkCflH";
+ sout << "YyNFYHD+yTaKmSrYay2bBYFGRFe6Vx9b8FR7fmdyqq4HzRwQ2/FVeSQ50OHeT3vWTTg07M31JkZY";
+ sout << "VlTjnw8Ew9N+o+qI9AFaftksVdXr2YjXMpJMwbHcqK6himQLSkzkUEQXsBI6vQSPobBtPM2oNlCs";
+ sout << "oEf4AvHYmHxyyzVdSRjQqEoekWURI6htJxRQh0DaDhqOaRQmUp0AkB5StSD8z6o6Wyf0yvzT8OJo";
+ sout << "NEjaoZbz82eivQUGf3RmYLZIx4xjZ32GA2fUN50RfgyF/KzW1GNLLykJ2NK+saJKGZB7RC2ZlsI6";
+ sout << "CmkHIZeNwG131Opp+ygasp0GBJUzlKnEzmV9CIBigCv2oPPxewEr7T8/OGDPXWhY2LEgVbgyxLdD";
+ sout << "MnZlJC28aanAWGoGMcuXgyacFhOWc7I7TTml8jlQv1oDRJU39PQwZZ93HwhhZ2s87wuRYCQdXSH+";
+ sout << "pdtb9YblyosciNtYJguXl+2iBdGGIqNAp6oxYj/rI5l0pPY7EUkGEnOzQq/U3zCDPL1tdUymaOh8";
+ sout << "+XGYMMo5L34oxvi79Rh4snSaOO5h0fl0aaZES/v/x5QTtQu4H7qwVfWIsOg8ujFFDFeZmmJpb7pc";
+ sout << "ff03fNUsHBfS3yf1JvrXhVGDh7byfy8DTFCy07gXUaKJzQJNOjEqX+iu25I+RNxXbfHsRxsWdNEw";
+ sout << "6aHRUUb+zThVJGkglO5W/S32SJznnjkHdFu1WXBU/DbM7b/quJTUgWnZRLy6CugMu4I/hp5ArKCh";
+ sout << "vekyQySI/9TjOEPEFW37few/H060cZcz1V3DilvOTBrCoItzdjxZB/cmGxPln3mhkPqGr+NsXRdJ";
+ sout << "tPMRc91NxvDP1oBnXXEYD418bmxfvqbXZm4Q+bMq07TT5rWABCZ+NFmkgkXryGhfKQWqidxbRenT";
+ sout << "8teao4iAbwSHWB/Za5+OPYUS4b2u7OULXqkmbQDnTeJX6ou0VFYUpkXStbtLITLJrvAG0BrMq7G9";
+ sout << "09vbgw6e+krF5JHaFXKAgVf3/SlvRrZi6zgPT7wsY1WLGlMAFNAC58UhizDbVV5Xivj/mVxi29s2";
+ sout << "tZHKQRshEnuw1KvEp61O4HEnO2dw27v1000YgRz9HtQp9Ra0SsePIjlh8/h5O2VTqJgMlqgAiy9g";
+ sout << "l7nvX3Ft3a/K7S8GlVMwM8z4DiSzNf0irgnC/+sVu3ZAy13UiviJTv0gZ6iVL78GdPSqhKq6x8IU";
+ sout << "JGX2sbzthVZUpYPm65WjEBKUKsHWIdeAokKh+sS2TGEAo9COawULVwzYttKSY30UlsBLL/ofpjF/";
+ sout << "eCWfzunbpZ3MmkXJYVygXCsGgEuxDmWrhcxIF5/pPZDafmaxqHCml/zxew2twDN9lEJV/jqZrwSL";
+ sout << "I2awqrzxIp/4TfqYj4C8HOQGb4N246snRl3iH2tvzQZrCg1I2mp1s0+xiHESKtfecHfMt8hbb2QZ";
+ sout << "50c/3MAKQb9WatUDqfZuTnnwXMo5vm0Sh9KqLYQj81LFOzKzf1NDil38GSmFGWPsNm7vm8Q6S6BI";
+ sout << "MZafbg2gM+ohtKS/u/36ZADS9/bxf90Fzkn5UEjZUOIRBhYowQNilZzHCABNNXFO/5SUzSJqgLIA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ };
+
+ face_tester a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/fft.cpp b/ml/dlib/dlib/test/fft.cpp
new file mode 100644
index 000000000..9d491e8fb
--- /dev/null
+++ b/ml/dlib/dlib/test/fft.cpp
@@ -0,0 +1,553 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/matrix.h>
+#include <dlib/rand.h>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.fft");
+
+// ----------------------------------------------------------------------------------------
+
+ matrix<complex<double> > rand_complex(long nr, long nc)
+ {
+ static dlib::rand rnd;
+ matrix<complex<double> > m(nr,nc);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = complex<double>(rnd.get_random_gaussian()*10, rnd.get_random_gaussian()*10);
+ }
+ }
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string get_decoded_string();
+ void test_against_saved_good_ffts()
+ {
+ print_spinner();
+ istringstream sin(get_decoded_string());
+ matrix<complex<double> > m1, m2;
+ matrix<complex<float> > fm1, fm2;
+ while (sin.peek() != EOF)
+ {
+ deserialize(m1,sin);
+ deserialize(m2,sin);
+
+ fm1 = matrix_cast<complex<float> >(m1);
+ fm2 = matrix_cast<complex<float> >(m2);
+
+ DLIB_TEST(max(norm(fft(m1)-m2)) < 1e-16);
+ DLIB_TEST(max(norm(m1-ifft(m2))) < 1e-16);
+
+ DLIB_TEST(max(norm(fft(fm1)-fm2)) < 1e-7);
+ DLIB_TEST(max(norm(fm1-ifft(fm2))) < 1e-7);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_random_ffts()
+ {
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ print_spinner();
+ for (int nr = 1; nr <= 128; nr*=2)
+ {
+ for (int nc = 1; nc <= 128; nc *= 2)
+ {
+ const matrix<complex<double> > m1 = rand_complex(nr,nc);
+ const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(rand_complex(nr,nc));
+
+ DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7);
+
+ matrix<complex<double> > temp = m1;
+ matrix<complex<float> > ftemp = fm1;
+ fft_inplace(temp);
+ fft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
+ DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
+ ifft_inplace(temp);
+ ifft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <long nr, long nc>
+ void test_real_compile_time_sized_ffts()
+ {
+ print_spinner();
+ const matrix<complex<double>,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc)));
+ const matrix<complex<float>,nr,nc> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc))));
+
+ DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
+
+ matrix<complex<double>,nr,nc> temp = m1;
+ matrix<complex<float>,nr,nc> ftemp = fm1;
+ fft_inplace(temp);
+ fft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
+ DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
+ ifft_inplace(temp);
+ ifft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
+ }
+
+ void test_random_real_ffts()
+ {
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ print_spinner();
+ for (int nr = 1; nr <= 128; nr*=2)
+ {
+ for (int nc = 1; nc <= 128; nc *= 2)
+ {
+ const matrix<complex<double> > m1 = complex_matrix(real(rand_complex(nr,nc)));
+ const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc))));
+
+ DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
+
+ matrix<complex<double> > temp = m1;
+ matrix<complex<float> > ftemp = fm1;
+ fft_inplace(temp);
+ fft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
+ DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
+ ifft_inplace(temp);
+ ifft_inplace(ftemp);
+ DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
+ DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
+ }
+ }
+ }
+
+ test_real_compile_time_sized_ffts<16,16>();
+ test_real_compile_time_sized_ffts<16,1>();
+ test_real_compile_time_sized_ffts<1,16>();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_fft : public tester
+ {
+ public:
+ test_fft (
+ ) :
+ tester ("test_fft",
+ "Runs tests on the fft routines.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_against_saved_good_ffts();
+ test_random_ffts();
+ test_random_real_ffts();
+ }
+ } a;
+
+// ----------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'fft_test_data.dat'
+ const std::string get_decoded_string()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'fft_test_data.dat' we want to decode and return.
+ sout << "gO1l2wKz8OsyeYMPYcGx6QdBG65vnrB+omgAJ7Bnsuk9vkTw/Y9Y/UZEFXhVf6qnq92QHPLV16Fo";
+ sout << "a+IUHNTjoPAfBOTyfb8QRcTj9SaWpxA65+UCJ+5L6x/TEyPKDtB23S0KRpRSdfxBSW9/rnUrkIv7";
+ sout << "6i6LWcxKzdsw2WGsRCX1k3t0adQW49m/yb8LV9Loqs7/phzY7HkJ4D2PLtpc6Wyk1qG/h6KQ7nkF";
+ sout << "GFkHIoh+xKXhHpqWaSofx8H8m/++H++g0VSPqfQ1ktFz+K8UtiGoyR2GqpP+br47YLXG3WqVU5Km";
+ sout << "Di3+IjQoBH2m4jykD926aRvdRrgUH4gZunokl+U6shv20Zm0NL8j4A46/2f++YPGCVBNJJmcJdI7";
+ sout << "9RlPL9SFbJ8rnH5bbLvZ2pKZmmbeZN78yzLUhdGwn4DGpf/Zo1fU2YPUjVKkwY6olW4w3tiBl05a";
+ sout << "cS1HwBeQjnajqsXNyudbrBkM1Z9XiwM+J5iMsu5ldaJ8iLn30W2Te2RnZhJRHO8MgL7Fn1j0n0Qb";
+ sout << "8dB+6aQYv0l/5LQkr5SX6YSRYX5b5rnqhi8IzJKms6dzoyBm97IGTm8pRxtLXcmsk1MvJcHF2gl2";
+ sout << "CslQazsl5iIS6fMxEodmlMdwdfIpp/6MqmeIydSHwdyJJZnNPl2p5X+Il5egmwdaSoDQNphPfTaQ";
+ sout << "R0Xh3xqsZKgHLKxB14Rsf/R7Eu9ZASTByX3UrEHsSzLSUo9/G+tS3n1iC30Liusksh2Wkt+/QtDy";
+ sout << "A1ZX31H5OlSFwCYC/TYitwyl4U9k7WhHBDoT7MdmVTYQEK1dK48nwOhnZa9prE8n3dD40CCe25q3";
+ sout << "Qo4VVYc5tBWu1TfTbshvkmHAcp3Gyw/caqq6jdq5Z2BD1b67i/bY66xhmowOFS8xeA7v6tKdkvpp";
+ sout << "Rk8FegzVdB72wpw3872T4K+eplMDcCPGkwIieF5pZStWxhGsNOC0p2wvpFvTpQgfNOGUvRt69hsd";
+ sout << "xaUEYlWZcY3sfsiOwPGgBUEEv6b+8W7+8Ddj8Nx4wG+bdWozphfz7THbmOeaDM63imIEHmJbZ47I";
+ sout << "QgoyzFD5WoWtZ1wMEv4LL+a63B3FzBcvPvdPaa2QEmyiK9yN7GEePs2Fv2A3ymhGw5NeR1dOzAjz";
+ sout << "lEQW01p8opk/dpyLO18zj8d+Hn4EnJkKD0p1u+XuLRda8AnRu/WmSOOpyG5EUrUoEyuvbECLbY9X";
+ sout << "3AMgzkbxltmZlkyOOwfCGM0yumGYKdz0aGKdyid7ddLMTpQ908dCNLyRgTybdZG9137PQirgX5+O";
+ sout << "08T/+L4EIyyrslOYxpUaLm2ASnSUgiivoIJvfnu8IeH2W9fPupY89ioXIYuwZU8f9FDCA9z7peQw";
+ sout << "9H6l4PDdDrB7nwQhncpV9FYLkHQLbSgE1VD+eL6Y2k48pI2zUndcoHEZW72NcmK6E8fDvfgbKkYD";
+ sout << "m02RiGuj4tvEEsIVuVa29Q0JGO+37n7Mlz7+RMcUMo1pLnh+jibas6R+1LCy7b4ubiKMFB1gvjut";
+ sout << "gjMABy1dJxSOdb9xUa0K/Alwwu3MxdkrbTxwqkn0C2JnVV7z9S2I+PWcfZKzcpg8Itzh/ON6I/DE";
+ sout << "EGK3s39XhLI2xPg3PE9R9QMaisqxb3FeP1NkBXrLQtuQfrSk+KZk6ArQWVgtem799fxgipQsa5RH";
+ sout << "z2Dq9t+pJzNGUnWg5PWzaAY3lWMscn+BIRhsZfDJ3QBtS9Vmib8r2dtYwXi/Q+FhnAHFfcXbhDC3";
+ sout << "GHn16aP2PY1sw8KMtfPRAcqY8Ylbr9EQXjWoIYUs0YyX2Ks8ZgibunTPFz/Wu98RVYswMtjubFaJ";
+ sout << "jb0pK9S6qoe/w10CAAHqoAfca7uMOxw9trZZmjCf5vF4leH/nDgsNjesYn21rE6rLhSbg8vaZXo5";
+ sout << "I/e1uhZlRz4ZNnMlZSnL70Jt0IjuR0YNphCsGZjmvvZ4ihxrcLrHvAcSTJuqW5EARtvjyQWqBKSP";
+ sout << "5XhlkrI73Ejvy+Lhv6n6O7+VrfWa/tGRuvvAToS1wPOP1T2oniDXsNlD0QbMnCao+dTWgkTDiNTk";
+ sout << "sFxsoN8YjwHqYUAp+hfnu1Vh2ovyemAUmo87vuG7at6f8MgFSuZffmBkGuijKNDDy7OrHoh7+5/+";
+ sout << "aOkcvp0pW3ONZ4l6peRNvzaW5DEBTvcZGvRwVCHWII1eGpzeJKaHWvDfLqjaPkFrG5pR7SGCY/9L";
+ sout << "73W2U0JCe2h7VjWbCM7hdvJEgYi/mEarVQpt+0P834es6Rm9rsMCbgbrWl7uv35+LVMTHU29Oxln";
+ sout << "bDzBUJQs5KIA81IWR3R7D+HuJvpMkAYMF73c1owI7K74SBOsTq1ayC81aNlK7YwOOjZyBqwsQ5sy";
+ sout << "zZi0k9AcKRGmTC323o7Tp/n/gkAU3NObTnqPEJitjGloXqrhPvorixBhHXSZy+wgL5R+04KiF1uU";
+ sout << "LEFOzJ0zKUMstTB+fgC7D6ZnVEtUq3HEYnmaRRwEhRSgMTLXE8VvnOdo802pMVN5GMCkH299rJm5";
+ sout << "Ina8mTwlC9JrNuYHot5KK/Gny4KPyUeS51cifByPwroemwBHe9EmKCkcEJPoDpG3QMaV36aopyJl";
+ sout << "GwhZxaZSqbut9XSWr0IMxHUkFeslRB+n/7Vx+xWpDNjQ7JA5S/B0ZW+YBQPcjA3sRQTey25JD4Jy";
+ sout << "RsULxNY5e3mjn59fI8OpBOYfNPTt2Jzppm1GDpym0LHuz7KZ6xk6QAyogk+HMjC/5RcQA7zJWDRM";
+ sout << "dXC4CXUjrBxVzmm/YHXv76LrsaFdzJgn+/qzlM6IvIgicMhcJl+hA1swTkgcw6JRalJiDqnvapKP";
+ sout << "V+T+/X5PSNMswgZURHQJ2l0PkMrUT909pBOC9t4GCsK8k4rYS2o0I0UYfcpm4jMRU5X34zlT8Qv+";
+ sout << "GV3mA0oGq1U2dJwArlPX3gI5sZ2Jsw7Qa5edvQNG5GoRb2j2Muo4AkZXXjbx0KEa5leLIhVL4BAE";
+ sout << "2GTdbL7T8hUGY3QlRQGwSVAytjUfXg4jCyn9w6ZbxUOu5MDBuCEtrhRSJNKuBLInK3Bh+fr2FshC";
+ sout << "T1eDtIFE2EDEaSbLj4NCNWpTFdKMXZ9CQg2VtoVOIJfgKzqAjjcWX8kqWpMFlQgtdTfIqN7gnFit";
+ sout << "do/FO0OzLghevyexHdl+Ze+MjITKOF0mTPPMkcIYcINIR1za6q3rLDZg03+GouzYhL8lwM3WAnkg";
+ sout << "Qg+NM6reQATKFK3ieOxacZYnIwOR/ZMM/lO/rHY/ZbdAnJHbMBWwRtK1vDi+o+ZgS7EgsDpsmz/l";
+ sout << "PguXPK0Ws51OUhIJJ5YDBv+nVPJabxOYV3dU0z49xFpxNTW9pTISo8mKZvLp2D765kExGJ9YKoAx";
+ sout << "Hfi6WEg3pFS9YQLNhOZjE4bQThugIWXhi+2OPgqUIUoV5ctSnP5Lv+xhbkZfjnQQQQffrrU4peSz";
+ sout << "6CuNEVLuNuG/mc3WEDZwf1HxYv3u9pr7A79QG0EROf23zPzaf5biE9e9xH+ruPApRHM58H2RpxXU";
+ sout << "RlkYnfoAUqyvT3Lhhk6ngv8Axhi4otvz7sRiXQmZO7mtzWzsCTkCJoziwRKlD6P6LYnbm4fRYP1M";
+ sout << "MvOuW3NhsQNrsDtgMuvqiVQpRzg157ES1i1qnTjJxTD5emK1RljuQEetbGksyetctWdWiEd8ZfSh";
+ sout << "DHBJC2FLucmkMt0LHsVPnk4ni055uMRdKPRKjTE2MjpEsxR52xiWR3MtwXiEhH9fZnUl1IdBl3PG";
+ sout << "TfLiZ286m4ePm6JOgNM1chtZir+q8pr4ghk/xycWvHmgkqT9dQcFP8iEtlVLCS22/2mS79cTev2r";
+ sout << "yE90otp9vibcTnpORzrnLrMhhpmYRTxRjRaHGhwdJYluARJFBBVTMEenK2ubdLOJ8skZjLzPv1dt";
+ sout << "9IrO1sNUwrMpEie8PG7D7DzQ7//jdlC/HUZaGKrwj5aMUULi+ZYiBLYoeL4N8ozAK1u3KtXLKlRE";
+ sout << "3Akys4Py8+CmrY5qaaDOXZvwl3FF3skmGhx5KValRXrbndqr3Cks0hXglHgNonZh795galZwu0Jp";
+ sout << "ww/mTQLCV0djTdEfjXBUnP49zyGXWWsEsl2jfqEAfBDcT4+mMzAUtBSwwPJYXXAJQz45R32MThNb";
+ sout << "k21X+rw63QJe0zIbOJepHz3jaedMkj8GKNYBjqzibNqfYelunBUqW0bpi81HYdN5OFY/3GNKgygG";
+ sout << "4R5HJaP+x9e1HxehpI/4pKFC+TAIb29uSV5GtkNTb1fYLm0kjeCZNA5GKtf42gBY52N6STl+lcI0";
+ sout << "gD+jJ/ogknne3sRtEJEtCFFe1c50oikyJamQbeUO1PcDUBt8Phl1rI/p4PTP+H686usJVhSDY+b5";
+ sout << "9CdS6F7XSSDiXlpFl+Esex30fRZ8zAQsTo9oN0sSZUUJKcyVk0dCqw2mHWPpyM6hYKQ3ij1nYjYl";
+ sout << "3PzRfFMlu+dgStcBn70jvlEv5pOWXb2OqrN9nJtb29n8jrB2K2nlbcYoPPiQ3yXk+Wpom82LoT5W";
+ sout << "F9NeNwwAB4EDWtB96OU6noW8NHJj7NiADQJGvQpk/3JzIzeBJQCxULYJMRJdBKf61+24F791reHa";
+ sout << "qrH+rLUrrv05dIPDTUvGW5LQLTTFFa59OmMIu7WJE7Ln6gMIwDw3FXnGFzaWnHlHL/9jJ0zM1FQL";
+ sout << "kfK4wTd++GbeI0gsnXWFK0N0kV/FiHm++J4udWwIXZxH7qZCHtwlT/5oGDVujtAtOPag+txUrjVc";
+ sout << "G4iLeiPbV/2Vfc2D1oV5/yyXDDii9qrLH6SOOfgvdiJZr7X3uMUIDGO75x5wBDSxr9t3I2CrX2dM";
+ sout << "M6kD7U1+bf5QVRbkh3Us4NAhFVnLNEcrm0x9Yx0wRmxPKgJeGGbWi7/BHi8ShIFllizuxuMyfypC";
+ sout << "hhzSlxxbYAQwtcC3cHEnyYZAO7HC6hyke+HQJfxAmKyfguGtzEzsiG18XJVruwz7IoOpZS/O71zy";
+ sout << "Nv+T8trOhy59ZUAgIsCAAQJYEBWl/T/qFtkE+tITbVRKtHjbxHSeN12OnHFRoKguJYaakTo4qLs0";
+ sout << "fr4E4nZUMfjdF7oI7YutegY9TkiJ9ujLJw4pfY1XRtPrRukEl8orypWXq0gErnYO/RVtK3XImrDp";
+ sout << "LY5sXH5pNzkqVH9VCl6lh9sg2HWjNwv9bDcDlIhvTL19Mx9yUtx/iQtG/OKy22tW6ByahPNnMNtA";
+ sout << "tBVB38RLf6eJr68mhn10Qg68cXxVL7/zEIZd9rUaCo8xCzeFblDNErKfG02JJ6fbQ6M6ZdNez7Q0";
+ sout << "x2IYbz2DEk0wHmR7OtA/oTFMvJlyMt+dDWTEpHnvqkbe+veENpxn2WWy3UsumkvhhtzzmLxyD6Sh";
+ sout << "mMbMPwgUjvMG51JfRrgMfJzT49z0sebSfzvid/9QV4lNkR7s9nfUJEwAued4S4klRy3LiFdQhjQR";
+ sout << "FOZZNqUge8vxVOzVCfS+xsjvnGrd7azt7LJg6wPXFgPfeE2bRlx+8AoRFG7SUpudmm/bkNw+uNgS";
+ sout << "YRdaH8p16RyNoMlSfi/7BNDhtKwrl202pVuCqhFey0mPYehYee2HhLZs6ph+HKMYy8lZ/ac1Q17d";
+ sout << "1tcI4WH0Hz0B/3GWl8xWfoq2OO40EIjuCPNhk70MpiytWXggJrKoKPu52GOqTU8+jZ6F+u6U2muZ";
+ sout << "6QZLYXDwPaNz/lq5U4ACw767DkhUHd1/h0g6r/RwtLKxdrzYldQto99TAMmHc+z9aIciTv7kl/Gs";
+ sout << "WA58nI8aODhwjIkOGaExdlR1k/3JR2tAAj5vRzYlJeakhAA82pA+8xMPZr3HRKMPCcCJZiOFUYkA";
+ sout << "AAAIGp0hTQAAAAEAAAABAAO4M25wAAAAAQAAAAEAAAABAAAD22vyAAADGD3aK6jS5oPrMPzwAAjt";
+ sout << "xbxcPWpGahG+GZHRkwABJjSNeG3dJthsOADhIXZ+4e8jAWWTyn1FQrzNkDMdvhTq6iuNBtgaodaU";
+ sout << "LlEAl217cnKFEVX0Gz4mAAAAAQAQ+2CiQuAAAJqBbzxJicET21QADU29xADCbz6wuq0KAV9tgJYC";
+ sout << "z4z1fuKC2IFuACLkvCvxAAAAaRl6AAKB7r5zjcz2HSjBdaDc1QgdOUAzmTdpgegYYD9XVQtaphsW";
+ sout << "sUhXeeIliGybsdhruMlJ8hi2YzzBBZc1GwjNawB2sz18UCIbQKBoDMBo39MbAAAAP1M+WcfB4DGK";
+ sout << "yDXgydAAAACCIiA890S+Coil7foud6zPIspdcqIDxhws3Kiht10f5NDHUhjgTYsAjtAgSoGLQ64J";
+ sout << "iNm0zrLDZkuHC4Hm69wkOD43AHbIxu+k5r5vgCW/m4traQAAbol6XlYIroFESJxqPkH6Zdxu3wnY";
+ sout << "DA1HzRnsBIlQ6SvPrboLtfVXncBN7aM7vLaX177RTxKn2qep5WX3/yG9sAP1QSkqapaSdLpB0Q2N";
+ sout << "9t4O8ryKiLFvPQrbFhK3Pux8X2PWKHAzZfkmUSQh+OiOICNKDbPRY0Es2GaX3Qnl2oIVflDINm0i";
+ sout << "Y6t8x0AZriURrafLgtzuoqaC7rAYwb1iQoFJPwXiZJNIk6W8g0623IGSygd7aZ+xqv5NjU+q0C9d";
+ sout << "0PJs8kZ7Db15tjjTBIoS9gBNefJg2D5TCJjEPSzPcZeTjPzeZJeK7v1KHgunnfi6igrT1efQvxDa";
+ sout << "KlbIBaR5Mk6fxdih+YOFTnVsT1mE8b8nffJ566Rgc6azYkJizsRen8b2g7jzgv2O0BOI7VtdrVpV";
+ sout << "BxoQ58p3EGMCV/mlDhcfVTctvr8hrjUu4OhN7UUoqi94JRXL5XbNDcCbFwXScln/Lm0bkzNsDnlM";
+ sout << "R4OqLiz3ktIpJhAtUemrNZtPa0+/Ge6PILu4jPNom92BcmxjJusoTOrHSTQg6cYtcSB8spjVRdpj";
+ sout << "tFfYfDY/6rxKpE5X+LU8P59b1/cdrNSbqkuRIuGFFzkQDv1BUNfa50aIKo53OmvDWkeEpkpS2xD+";
+ sout << "hz3uN05FteKkZ7kHDPmEEJ2lUHB4oicxgkseb8nWToePhOkr2JcNakTx6yc+ZT7bzPcoT0hueCpg";
+ sout << "Ljzb2AQ0UdAJEAD9eruD1rF9PDEaXD/W4D3ja2EgvEY9wSMR56Ne8LSi0jeFjp8jKxcbmBo848xY";
+ sout << "dofjq919a3V9KDRUZ3d9t2Bmfc4yFoS6nBZCVy9klxK2ZaKePGjeCbENr08cfenUT7kA+ZQURi4u";
+ sout << "WEwCgdui67K4H5NPvbq+QAoKn9d+ZHx2YfullZgBCi34oLzT23yccD4uxP4GbZVXakvv9lLoq+rf";
+ sout << "T2uxZr36V3aJlhqVoSReJ4Oqz5qbTKNH9F3S14GFJOTByKZP4XJwVytHIcuu9JLpQPDt/nkREX4Q";
+ sout << "0cKeNqXqdujkp9XCAZEJ1RkvHE+F4tiUCTKHvoCslgs4x9alnUWVps4Qy5CMaB8+x1eHM3gkuf0Y";
+ sout << "qNIFuIXUybNj0hnKO3IV5k8m2MURJiZsmM/dg0fkwJG0DIAmCFCXdugfHMHpqXZR7IBfA+Q6BICL";
+ sout << "Njxa4BWCRQoeKc9LD1Nits6rbVmqkKPAlwc+yhfYzny/3/ZHfkMuF/s6CBYugf3dYMIch10I/QRi";
+ sout << "9mVUWdCIacS1G42Yaja31j4s+304a/luTAsDKlOObmKTzDV2fDhbKSLToOY9iXxRK6KJY0GydspW";
+ sout << "ak8OvtNCCVvEsA5NccWOn8njo/sryvQgM9yzV8XCI2MDF0blRNFJQR72EDdG9NXKgg1gj/vG51dH";
+ sout << "CHj9E5ffneorEoPjn8pfOny29jcb16D9lVc1zcp7v83pLXuAyUp3lNC2ff/PcIbRMpNns1MnyV2J";
+ sout << "8gjsQTqTL9AAcvz04ohVo1LbZRl6rl2D0PHAOCOJQAJ65OA6BHxfeh7EDDnniKfLlI3CD6W5XZxW";
+ sout << "KPVAa4RGyCcSLbmb5d658vkB9rBKvaqncbiarQVbyLQ5cCRokVd9HTmqYX4Ky+0yPVcmoXAS2B7Z";
+ sout << "S0zXkeuA8Lo2ZYRad7/CBg6foq40dpQ9EPw19H0YquMXeGVEsgSvu7jBCYrtuJzL/wlxe+plrjXy";
+ sout << "E4iwHoZ78KKyOgwTSqKtynp20YkZG9NC59XWvyrd0oTBrGlFpOjer+OtcEgl+3XYl1q66yJTyQie";
+ sout << "x/Q7AjBZWPmlhEkktLtXYW1yIqr+EmYJqJJ4Bbssvsa1/jd/EmZQV9//HRBmber5Kw0C8royzI93";
+ sout << "uzM2t59sG8hnOHXCjAQgV/HuTCTD3hzeDgrv0aHMfIC+mx5Xkt7mLhrhLeODhJxguHb9a4pQwhfK";
+ sout << "gD0DKhyk0RwVZBNF2+3UfK6bJ1zeUgf0FJ6Dvak5i2+BUVtohn2qjbcIRFZKEtE8Oekca+FE+9+S";
+ sout << "mTsAIvR/cbfuDVG4cmQE8sPP2nB75KQJHylHW3ZvR5v1icnLfob/CfYMmfSPYDgDFx9cYroX/luf";
+ sout << "lLKKxxQ+1r4TIUGKcBV0ZJcvjCwzU3C2eG6oN9P99+qndIXYmvRvpLmq5QYyjsRHmXn2MTKSa8nl";
+ sout << "y+vPwWboutMugCUmEfTXhvc33o5/KKWU9cVSP5G3mrGBDu4BfG+GeqP2+DtbHTi9oOYm7723fFF8";
+ sout << "CUxOvoAQkOQQWC7Wfd1QGgsVPxhz7FTsviov+G164lh/4Qlkqy/8pzEJdQK9uYmCsvcip7tKbzS9";
+ sout << "lRX9OoPwQZQPWPosqZjbYBmEVr/e9jtFv+2pH8p+5S0GI4qiNg8n2fZE+isn9XamDpOqNEpKk1gd";
+ sout << "Xg/ombkOajBBxBlpNbXQa4aZylfiz3ANw01leRleTqeHB3HdNnTaJn3mr/lkKg3DYIA1N9/Iv8jQ";
+ sout << "1lQexZCF0jEDg51sUNPF2yDGnpSpZVkDLXCGIlKe4BI1/pR/kiqrEdg/ITP3+tVhND3x/pBWKlNs";
+ sout << "AteO6/IZrK4XOQ1DnZgJSAdz3Qe2uO4wY/7MQBhFO37V5gAQAQVLcXsik3FDcobieQdvvoJM2Z3b";
+ sout << "i+FGg1vcn01gNGDH0PvTKSEN/KFYTpyKDw3sjzgLbfEPOJcGmIFj4JUuaEYB2TfQfzOwzM7QCDQa";
+ sout << "6jbMHpN4VvRPQqZ879vRWzHhg8P2M3rJmcQsrjaJDIL7t41eMOmB3Ey4Vajya50NPPyVjXEtFdYj";
+ sout << "TPc49LO+npIw81WlxPbynhk9lmLeYwr8LALpc5dNFr8BkfCNc3C+IFi+RHSdq6tj7vXqvE+KtH+7";
+ sout << "lFZNdvI4hFiIrDJssynPUFAR7bob0zEq2RxDGJBSe7AmmwhUgGCHRFXlpJg8bFCFs4rCgVrBctry";
+ sout << "AY5TT1WWTTa5P/jPOCTfrUOwjxgGD6ubejoyGIoBPfsRM57XSYv9gQKTi5hk3k4qHQnItrf9p/AO";
+ sout << "LIxqMGF/nky2Z11JWz1krqtE9phBhmXnMk9ap1YWJd83L5Et4wxpud+J2JxJoM4kjtnG+HXqdEsC";
+ sout << "KCg7bNrSyAjtohz2vQvpXsZjWkq7QT+fTzI134PhMzbqnsWdKi/dDsZBNbXB4ua60uqFb/tLtb7n";
+ sout << "sNBfRnWfch0qZLuil5eAWkLlRC4zyaF1zJmchHnVync81bHIJVNj22+ctIbdN+P7aCYBA1n3sl1U";
+ sout << "0LDKWsXycnOxmKFUMm8kh+f/eP3BQu2Pe5W1tYDQnge9rvF1072VXvjpDBIna/VPuDwCFo2jPPl3";
+ sout << "jbxOuvZTOSTWOwyiX5B8aLBfRXm7wXAwRoRy4WLr2dsiGFGSwP+pZlJLfNGY1vmbILPg/9iNNMdW";
+ sout << "8MBRnhh4COlSw5/NxCl3FtRLfO3uLe3Z9+tBpX2qkUHcFyeulVFINjZaNB8gUBPr/Ub6/pk603yY";
+ sout << "YTEqMSo239BJZvCsJmpgZaCA/weLepsvaEzpJvjI3Gvo1Jf/zdm+eY8VVwprkj4WKWKWjHq4miAf";
+ sout << "y4TRxVYOG1lCH/cGcbepOcRMUG6uY0jTP7tqNfd4Re8IXTLCq1P4EOGuvkIe8hT/YdPP2B6W+Rd1";
+ sout << "z37NgjkKJqiSZBxpnj4WZLmdJ2eOUUJOR4Mce1mHEKN6gOhzzFRK98ZB0WEI6jITJ4wNcdDpw7Eh";
+ sout << "1Sg3JvoigX7Onq9eCnb+5uXFQbQrffuNiTCa++zx5Zm89m7ZuTJ5NFhssQ3v/mv8Wesl/9lebi05";
+ sout << "KOVcyMAq9Fzx5Wtj9GMHrTNTyzS33CGVrx3Imh88CXZ79PBLBO9V2Lpjk/yIuCyP9pKezOZDzTED";
+ sout << "rrMRA+AkfStHcPOL8WwGaAQgM+AlF6FVrKzs8TQW3hrBUztSeOFrxpDLod53zDcNaWRe/gpKtB/4";
+ sout << "RKRjjso+LANmNLW/IepM3Jy3sepn4slG9YS/up5puIZY6zrsw8n7nnejBrUBSgZNUaeYLCLhWcWC";
+ sout << "aa/M2BfeOT+X9PXJyzvQxDx/fw2duaVO0yRYb1ZNwtoOXWWXzBmoWKVrlXcegQ1wDzDkLW6lwnww";
+ sout << "4wJ2tMR5bWMUJii/0Ep50BMgAG2TrSK5jrBpiAaGJaSOB7k39YsZ/3/8rozN1IKa2mrK6VqkvY2Y";
+ sout << "AeOhdivfgdgccST6Ymbe81UvjsiVeCQx6tI2RcnR7NTxLC3guqwsSirHXDHTflWI8aP4bPb895Nq";
+ sout << "7JdomRqug/eiaoSfv4AotVU89pWyzC6FAN8UdnEXjZvYNA0gf4plXcMvlLKmRfHj/vs5x3v878/m";
+ sout << "elmdsnYX+sIHEc3hAZSISJoTZkGEptELFMW85Rj87/J7d7cC1q6vTMSNHgpysPX+2M9BNMEGDpJh";
+ sout << "baYrMV6ASpbVikZm97PCzHNnwQAxsOPIFoA/ZtoyXleszvtD2jgNsKiMQo92dIFDAJU+4FIDWG06";
+ sout << "7BoF0nZgUUrfwzC0kL6cN/ui9mBpnAu09t/9dValOs+/Prfp5NfdYesdUpJCqt+o3/VcgGCd5OfD";
+ sout << "1n0lyUjd0g69rM73kEb7jOFV6LhN/5sfmzql9DIqUYtaxs5nJ94eFkp+lSLeXKJqBIrxEql0JL8H";
+ sout << "ISH+HMQBpE/hWIkXJ/RryGxYv/SLm4mfKYNgp8i0KKzpp5fiK9ZlJVmyLAM0eRAMllUa4j7Q0zNl";
+ sout << "p1EO0iv3pD9Lhbb60VlD7ObwhDApO6DpIy0mjj1G31DWJi4uEzYrUttVsSpj6a2+rrI+7nMA1wDO";
+ sout << "OEGWM77EjjTElz2sU+1w9gTNn6j3uOFu5Y4+/ysXnAIVtg0zQluCSflOwSAFyUvpBNoKVnQD+mGH";
+ sout << "870jvWJ/y5uoyjwaxl33+t79t78PC8ycBjvu5M+RFchnLW3QqUSP4l84gYj5oLsz38LzvrRU+lUk";
+ sout << "mFDmrCoDe29pLZXYwnI7HHYPEtPBJX2fdBaudv244Z/a0NPgMSHS2uwdTkdSdWMK1lTBt1GB0+XE";
+ sout << "HqdlnBbMPq9U4y9Uhh1hxdBZHvAcrhYjXDSvbDHDzZRE7acCvLmAOR6aRoGup/WIowNtU/wXfK5c";
+ sout << "irNg4hSWLIjPNB54AkPsPc99IxGPH66PnE2sAWFALd7E3GnAo9N7t5WniKWWI3xtmQLbSOHwdI37";
+ sout << "iAOHsNgX5TjGzUYwwnQndubqfV0NY4/V86wO1Rar1qBsIUSylkids760J0HqyUsWNXxsS6T9Hq8y";
+ sout << "0szlFA/E8IBYgdH2LwkayEbDbKMsonJC+9NhEx8u3dj4ckmmKXOY8NT97XlqxomKyexbsKu9mb1G";
+ sout << "NtxTI3Yt2mDKyt4hzMvhF/DgvKErMfTMTJ/h21uaZXrmwu9G+yQQJVaB/LEfrAenakK5mhmIR0Ne";
+ sout << "+e9cQQSr+iD8oZMy/IkM5rd5FOk8FY9pZb91hXaWNIMDfwyfZZATaJruQfO5cgjMGXSjpq2gFHW2";
+ sout << "9m1zGxyWrEkm2lV6FQlPLHBGybLKxq0VNUqIsdM8VX9Kv+6i8UgN9Ee+hyd4a9In41m0Z50jqgja";
+ sout << "Jojh/Np7WSPqnPhNCv6/K0teVxq3bo3UfaXlrEZtkAi0hgy67XtZoBBQ2Se0ZWzOntgP3Yo3OI1S";
+ sout << "02Ta7r55Ox/WYR5NoFoT5P4ihVcZVmD8DtuFLtjWeWKrHAWPpDDe0RYpN9Ma/DOOlx+vf1Ir7kzz";
+ sout << "oXXOiVjM/hgKR9QjX1N05clcmkG8uSoxtaOU6Zov8BKIZAFfTMdQrhKBrW+Xgi5sKef2mBhPGwyx";
+ sout << "pnudRmwMFKk1D8UXtuoEHYQX03cvYEDmWberg887C1ca4UdAg7b5/mysr2g5Md+vGHrqRtJE5Zhx";
+ sout << "p2BFSzeFV83Zpe7PYf2uTLffLleJSKV5l6ohKDh4V2P4ztKvNwTKZYnYFdWlfMPUz3svuvihxG2h";
+ sout << "OXNZJ9+Byx7v6RrU/42lLyY6p2078QYHif06BBp2267VkKcRN4pP9LlS2JECoex4vg43X/dE/48f";
+ sout << "DbRfW8KeR5kSOnH9dDWcwdoco8MBwOV47KIkte6i3vj0BpAQZD+GFuR5EfIBxXuklMYFr7KRS9xK";
+ sout << "oHMGA00uxrV5VWrzCm5lV4l4oMQcv9/hqFTLKo7nlQP5yu/TyoF+OSXP/qKX7N+CASLfNtwL1Ux5";
+ sout << "fweiUkaKnLZh8f+bxh4x0/UO3H0LWyq07/1evSYUBQhHkzSYPhkTq2msJ+eFBc0+gpbOWntK37JY";
+ sout << "udFJvL50jNoZf9clWcqzxnK5M/rqMjVCi1zzboiC1vyxWPhR8QMvEMRZ8XpVW/cAdLz5R+M3DGms";
+ sout << "KCJtIpxrduXr5+Bq09jn/1oi4qro2/ikBRTVTLj4js+yaI2t0uE6XQX1PO2JwSHW3V9krhJ/7JJh";
+ sout << "sLH3xidX1mf+O/hgu6NU1p4yrsGjz6qTYOQSjU7cVsCMRxW8y9vCWh2PcumONvuWeApgNKkkQj5d";
+ sout << "c+8ftbo70YFtVWhKjTG4BHpNh6fXDC+0ZmkhTFEyHAwx0hF5k7217oTa7aKEQp2qGEROYxe/wBMO";
+ sout << "Kq/lLPLQpVEldVSoOzmFNxGTkOo0j/bjAtwqwEd6mA87f8PvCiAWLTRpPFla4UbU8X0t82Ur943f";
+ sout << "m+hm5CdEBtyV9P1uxMO6ez1SI5YezMPfMCkhpFqozJsPrjkPMIUSDc+WtF8frIfNsWWmjukFwFyB";
+ sout << "3UU7r1AHWLJlEo4g9XPsMks3LkFeQ30cwm8Lnlj4wAistdLd3HNBydGSxO48Uaya2M6dfm3Xeu8Z";
+ sout << "77+7God1knNpAnL7GFhTTTV7XcFxEr76Mfdt+KJTXHIf1M8ofe67hLrsT9FunpPmopFrNp69v2TD";
+ sout << "H0/SQbseKjykzwvJkK46UeJFxYabRHSgL67prx6jwl2Cp9JnNIco/hWNPGbFVaEcMbZx5lkinzvT";
+ sout << "KgwrJbVJ3oKnY2EJAnNDC2f1F8UpC1qQyEvhkA7g3yaJUKMF74TXqjz6BOhGmzn2c2GQBqzZ6d0/";
+ sout << "Ko5f97NNO360xzWgTIiLbg1sPrA8OllufpaQRxyqFzTlU/kSU2BDLWXM2Iy6tGUAWCUGIkMgOUia";
+ sout << "cXdp07NMyxlTzWBux4nmXRndTn9y53qZXjkeF87G+ZG2Q4lUmp2hunx+dyYVdOFWCePCYJ8TcORQ";
+ sout << "kPI7Jcacu7QjoOz652vsf+PQPsarO1KlyKqX4rFudTs79TVOYdHsf53Y7RIMqp1NAFDjAsyFPdq3";
+ sout << "YpOr3UrJu2RZm+eumrTK/JMHSRcbJuAyHJsBITQgvmisRy1wlcDqzw3k8OcfpkjFsJNPNaQ1kL7t";
+ sout << "HiNZMtrv3t0ER3+NZL6PnWWp/tN+ASeE5p75S2UyKtgLONa6xg+gNfU1PzxuqgX6CU1me3GLx/xs";
+ sout << "9pO0eUq57cy+gFcDzvTPw51WUrsaQhD+ayfUyRw9vpBiw34NQY1vMNReMgl6EbvRxFULAkx/oDFy";
+ sout << "wnPkvTkMobxW3JnVGL2mb1iHUKmqQzBsMk6zQKiqtasPGrzg9Reb2DYSVdLpZWCha1BI8ySpOQy/";
+ sout << "ndBNtz6DKHgIQ0lp2pzmJbU7JgqZD3UmfnKzHcjskto87KRz15aWYTnIFg7INMFMNUXqUYizX4V/";
+ sout << "I6UpRr7GvKsITthFaIdwb1AGzKXRRmv7NrsH0x8ip+p6mDp/Sb9dUG0N4ODEYfLlGzA1U3KbZAWg";
+ sout << "tfH2AF5+vVKtgUuuMH7XPQ9FBT8HwvqygWub9K6kLC1qwH9pK8YObWtrfOdTz0yfOwhaMo/kZCzS";
+ sout << "+CWcit3PkXtOuNXlq3oO7fYcRDPdCOFbUs/KV522grIRqZ7mgludKZqP9b702FVEGFT+7TbxjS44";
+ sout << "sx/F5YjT/XEfCP4ivZNjSYdEsQvSWfvnb5IBkjJafE8xnihUwcUNYYY9gUvKqZNm5XjKiTFLhPJR";
+ sout << "SlaNrK/pzAMEDwVzHWPeF0ZG9y22WsIFfsYgDqJwjmXzd7yLkKBBe+NuRPlvhh/Adtzr+P/4NMoz";
+ sout << "f7dcJzr/+VJMCqoq3tiRo6j5nOm3sylEvA7/HTethnRyHx494FMdSwJ1t4AXxrH1dSFNGOXWJ8Tn";
+ sout << "WMi25e99RMpe+Fy3fBygtCiXgQ5/sozmxR6LEjv80uoqL0sopOWKSKe+aGZVzzPHhkfJV/HY9N9O";
+ sout << "vJuNA1Cwp8X9OeFYzusBXdrWMfzTxfeYg6Qj7coKEToRF0SRuskob0+oWzufMPJVCCRRevu2BodH";
+ sout << "QWE5HiZMuR4SYbEgwQBnH7RbF9vW/DB57H3HRvRc+NBpbYZcWipQyi3cy1RwmSNOtexX6XdqoXg/";
+ sout << "iIBLmFENMiByy8AljjARQPUiH6OBU7xAY4zgiBrM0JCtsNlhnDuF1n1udIMnjmUcjAlJ4OkG9Q0L";
+ sout << "w4RvKt6/d7xijIvnX+i0+jH898Jhe+fX+vu9prCfPxnhDPCeDOre7g5xcTAxpX9PyoFww93kdkUq";
+ sout << "FAt7//v+bzkV8FvhUzso67ANCtS9w+ZmeEoU/ePml8oW2xGaOCx51hzfSzXhcvi6DQP70QhWGFKE";
+ sout << "lqjV3s5Ra5WC4P0Ipqq6PKuho8bJn6hz4nlhlLaSF5GfDpqQKzHxvxvDw0//PJQGNv7LoCu5xtmf";
+ sout << "u/CnU9wUfWie3YRDs6023yhdgyuY2nbSWSadumgv/zBNJH8zdP156zru/LLvr05hbn5QMlvdg3H6";
+ sout << "jS5SxCi/FXaW5sOvngedtvRwZuN20nHHZQIFLr0JebQGxOS3Pceh6CeSmslzrvflpnrFJJk7wKDw";
+ sout << "ponOE1ChzU50pPPl1Cr6vtz3mQ+cSSUfy/yj3jiIjVzs1ZRa73+iC3ibUKWzUcuZFeJtZ63UFYTL";
+ sout << "I2oz6V2ylX4Swqr6INVp1abOAWFWba12EP97P36KMGN+Srm5E/BoowWKhpy9uR+AqQVP+NP5BW/m";
+ sout << "02UQyGCthZMnbw7lSD+0Ihu4E/j0Zfh6K2vBz72vxGa9BW3aDUgNvLRyU4CyLf/X8Q5+73iT6Cwl";
+ sout << "dlFumHdywzarpmRS01qfzCT7WN4Pd5HKKNvq9KaaUOUQqmkXrrGaaIoNTAKHh04hhi6BF5rfueW2";
+ sout << "rOFlukiABYz2pLHwd8JCTUUH+l5ly6G5NhNWGdM9AH2tPxUpmqW2D3hmaF7k5I+ehdNQKHxnzXUB";
+ sout << "Wv7O540zfZEVkF8cGAuaF9rvd4GTvTSU/0hO9JerDpYPDcwu7V7JT0lDjjlVyf8Rzr/5QtLAvsmR";
+ sout << "ptFG+VFBQFzS1oSd0yQ5sSLnJbZcABQq5zK9kYMlSVg+DERu0yqSaxk7f5Kjm+KHLjoSR+9th9lB";
+ sout << "AydFXCQdhbntrBkbSjG5t71xVSyhxcazUYoeulumJbiDMo2Nz/GBvviJL3ZAAyq2D0PFXJsVwk+o";
+ sout << "0sSWItsGGLiiw55G5NN5KmQ/hxhbjxCMjzQ5ALsREiye7MPDdg9GgwA7xa7NdVdjprN8RxNlCS9a";
+ sout << "bQRBO+z+P+2MoN7tk6JIWn2KU7ex6R6FKjwHz+kU4lJRaJF8LzpujYtTw78zAgqS2uTPDGpHdfJ9";
+ sout << "uxW9x2IvyYg63TG7xBKS4+iFmSJGHsR5naozx3D72vCZ3jTe8D/fv8FlHOoPyYO4gOZgyC/cOIdC";
+ sout << "jM/EFJKHkL99pSetGn8KALo7QbrosHpZER2s1nhIXc/kfG2rq+scF0ECChECnN9sVYEzerYuNXzk";
+ sout << "nsAPu/3W0xYMg+TE0xQTdO8OTHsfnbGAm5ELN0wxTVfdXIE9QpYoSSRGtHphyFoKcOgRkyFHXMmb";
+ sout << "zPuwR4bRhvT/HiW/bPnNrOBL2qpeoMKmhRyhU/8FpgoYANV+tuh2DCWGSu+b/xgzGO9kuqoekBaf";
+ sout << "PGQjh+kV5tWT5u+NTO2SxkBaGLZkBK8j0a/h7CtCmwpK+7Hq6WyiFUgAUenY8FiZAgd/4lPXvcJf";
+ sout << "JV+e2P83P/iAnnfFH9rt48Bq54rTjuhgDJ1FGHmW6uzHqX1XTINTidVSukSpy8+hpZvAiiNrQfTc";
+ sout << "clOqzIsuuJVBivD5t9/BwKOp3dee9ZyDc9qTH9fjqkq8dKSjZjwTil0meMxI3EEhCOssrXql/+jH";
+ sout << "JjPHBao4OnlnblFyPUFDp0MyFP3OE3o7Wa/RYrhLuq+edYM9aKgdWzvSbJAn8/LGHDt3/iH5joVH";
+ sout << "UrMD9vp93N8RBsdbMY6hDoHEmHWXJvQAyVQpS/urgZbjbdtKT+vQKvwrRlUY4osNJ5fVGVDfY1qx";
+ sout << "+3u5BWq4Fd2UwsKRM8Ut6m4dm0yog+pr3ZVcPfgGsyXG5HZdyDpdg0AJSb/wOAJiNZQGtvZQ87jJ";
+ sout << "2Vk3fIVSb0gk2p4vINajCZksZCWordROndVW9kKviwD4QBpuvxTCgW1ww38M+0osHHzptdC/h98m";
+ sout << "tFXeW1AOwfG41mRwckQs0MWVn5Be4cUFiZQ+qkgH99p+3ZoNiGdk2CCS5ABvejuAiJ07wbiAzi5S";
+ sout << "2kdjxM6GgpBjmT3gOz3jvc4dgcXA6RZC+sbh9k3U7LakNE4PfZ/WME3qbN1nuj+9/xhHoa5/4gVF";
+ sout << "qM571ayMhfBoKsoChNxVxiQDwMnMkIsbD+1xUEr/9OzIf1fnYKU26pX94koCtsdxN5l0HH5Url3J";
+ sout << "+cSXI/XX9lOo/uUM5VW4eZ9aB9zfOx0/fuZTqcVxqkzUz8SH1gfOS2QzGBkoSDATOxVU++tTGsVS";
+ sout << "jvCtloGsRzz7cSOdVEyGH7Pf1PweTT5MIyrqdmVDRyYiYcHNhxJvTS7iEcL6wdwXS0iX8NdIp9uq";
+ sout << "NCrhqBcAlQZ8vKgXLe+FJ/oT0lwZQaF46SFEWjbvP8fm0xMVGm+pkaniNAVn2m7D1FADdIfcNTZt";
+ sout << "QjWrboY5ZfHbOwOl06gTV9JTQGRjFDUjy98D4st3sQUZBq2JaFJ32Qg+fy7knNSrpu7sHWloP+vT";
+ sout << "PvoJGbK/hFP0Zn5ZwN5V5UEbE2JzoJE91+VHhUwbh9TzNBxwnDroOxwpzHvJnCxnb8SPDKKrdT21";
+ sout << "n+U+kVkf+0MiLo/GEQMa/365TlPqVNJDKzbtOwFzAV4/001Z8oHQwWqOeMZtOGkgbLx+mSs/sfCC";
+ sout << "fMbMfcXXyjRgPYsG0JiQHXnB1AlQ7jdTslLgpTLra1uWB7PdFydbSd9xBrGJliq+s1VLML5vtHOx";
+ sout << "zIkRy3z/6F985ZiZN991fTL8O+dnIr7/bhRAfnGz9zjdjVhfMCSZ1qbNg+pDIFre1eRezr7Vi/sf";
+ sout << "/EAvYapgYXBpO2YHMivon5jd8BHIiDuMxvO3k8gtoc3N7cS4gRFbBX3KF4OY7st3c2TVBeXFGfQ4";
+ sout << "pdfpb7uwAfOLY02qrouv7egOIQkfvVGGtFHHWyF7Ua+rBmdUgSN7qyAoX5ImUCkJBfYWKpDyV+qE";
+ sout << "sl2Zv6TAlmoV1Ejb5zpSRGKTOZQeOL32IC4ObnmNN8Tt6PF2jWHKfF+E2EvDUnkvuryJgdaat1/e";
+ sout << "ROv+l2tnyxLcsudPd5IC6PX1PtyQ/VRXVjPdgALubsid997mfwVUEKqN6lWCnyxWevrjqybNaNGy";
+ sout << "TKR9jTz6c1lS/pvGyI6/tRau4N7yCm17IOPL4IS5LQxIrmZM5u4CJbXb66Jc+ugcFY1tLbh/2cOy";
+ sout << "+XriVPTbEGDQ4A5uj+7xOo5ZRdoTrdVVPtDk+dlVfKzXzxpb25S81YFqhjO2mhIAQUVimhDvfzJI";
+ sout << "3STEBOwpTG27aw//24pQ/l2/HG5fmjCFKnU27+lJU4OZWLVC4xyNFcC45PPCOcMbbGOv89uwVAzB";
+ sout << "grsyHpXH8piJ+i48nWyjJYRcMY7QruLJj0XwI/zyfostfynEtzCQ4z7izYem68epGW8hJIno9YC9";
+ sout << "ILlnQ3D1Q6ZuLS+DZbbYX5KtL51EGpOXIcURfvVEgUPpJDRszh3+2ftzLqq2kBq7/fp1qBVlX7V7";
+ sout << "FpB+Atuo0nScYLh2qkckGVz6FD+sd55R1X87c2nyoq7Mneo23ed3+fOWL08vOS8k/TRYm8sah4Mw";
+ sout << "jBMphehu83zxUzzmCxgyw/YqBpNeIvKOl0aG0QKJl+d/B6ZUujG2c/d21wxvGfhUdmPq0fIDY4f6";
+ sout << "C+j6/lZmgeQKQtYoAN6ElDRpTXd1A+3QqqcI4B6stZQR4JS+yXVg9cbekGvrt4QQs4JX28naqY1a";
+ sout << "2FUNMAiLZ2HiAOR65wSUstVLm7BfKhWmjwtsaC/0EW20Dk2UgnVOrBLnBQ9qAN4b7NmU0WonmxEY";
+ sout << "ojTy/G3EUpNlGyA8/vW01VWvnHXY6vXYFIqbr5vVrcva9uTYUR55ZqOoEwieXtBtjMbrfsnHYv8P";
+ sout << "tR/v9PeZbgs9t/KMYR/uMBEJ9AZiee7WVJ4h0btOaXK/46+XzeUtI5WcyLNL//5ryxuloF8d4dv+";
+ sout << "Cc9x78QO31N65ELT/k1U0egzJxgrw/3Hki+p9JU8IkTsW0KpGjEGZBOV2pgikBdJvtj6FAa3wRUi";
+ sout << "PJoSIfsvxzrV2luFEsFgl5yMEW90eOsiLKXcWDhqEEbulmLx2ij8UrCVCY56umeqc+LyPlDfKgrJ";
+ sout << "EZjImfmAt2Ygq1eC4diAMOmE1UxpFq6naGmN7qqepdQluKbFsqDlYTpOaqfems/zLJqkcaw70LGr";
+ sout << "VMp8Pqv5ii/0GDRXDBhkijv5WzyxkR3EypXXxNR/+1AiALmzJmi+es37MAcJrzTJXLa9VBqivhbR";
+ sout << "8dN5+0P+2LwViHb55+k9sXzIjOGnOk0MEWlYHOGjMRTEGFRTN7A9nCyYXy9GNfPK2JcHVf87/hHY";
+ sout << "5K2i4i7tUrCe2Csag8f30XMI6neY3ftCMT4Ig5IaZ1sFOq/T7tq9itnCcp4mwlEYPMyMIkG/F/LU";
+ sout << "hee2A7h1pyXJfcGezwMznUz7W2ul6nPiyNiukKwkyscniZUpLlaY4QEqucllRuJ/68AaZ4b/Oej2";
+ sout << "jyS1Ic2KncAmPZ+1Bg0neMcNSWquSJgYzaDClIpXV8f3PkHrr2uIsC9W3eNshopJHpqLMzuaQSHY";
+ sout << "2vAiF94VOhVSyUWpQ/1azlaZt+to+uWlh913iWJJ1W6/ny5AwzWh7M0ikd65/vX/nmzcWDNIe4FB";
+ sout << "dXtUgIbWZGSvTnb1Q0ff0C4F+XGSG/27sLrdSbmNv7IjZcqQJkNsqlIjOMUXKRPmfMmOD68qQomE";
+ sout << "m02IDYew+Ah0vholXhlgVIa1V5eHItbFm0krwNfQsxBESR4dEQMJbkQAE17x52EneJzQRaNjkDdB";
+ sout << "dCl8MkWvCwi4iNl6Mn5jQsO+3K2cC6fYHDRReXmp2OmWXlR600U+oSNyuQjNiUM6CD+q7IV0SWS9";
+ sout << "txg/d76GC3ELUb9MtBhtm0dUBT4hCjyuLqMAVhzzPLfuR8ko2hc9l8VziIHDGnusZKle7mQWr84t";
+ sout << "L/ERhHo94fQVR+BuogcMUEcFq0cAqYHWzwI41KVB/N77Tud7X5KuhGc5ulVJaPukTGGvqIrBDbBV";
+ sout << "HXFCuuptn8cr123muSyZ9qBcXcSs7DRnZ1mFO1LoxW7j+SsuyE/c0bUN7ndjGTUC8OiS+NC8Bjdu";
+ sout << "u6oxAUIqlNYIoY7EAPF0WphyKvId9I8knRWQUQhD4L2kk8yNvtYvjdv4udkJAUQydaZut/lnijSg";
+ sout << "Ph21q7/c5IlPo22ZIiPI+a3O3wkV07i6vAPoC3zYn5Vu6bFRHSSyolNsOrEvCkWwQ0oTCcm0Mp2S";
+ sout << "eiwjZU6/k1AjV5sbx86VksQ/AvKxU7RyMbeuElprxKJIMLdkghibR3x4clnKG1uRzM5c6eINHY97";
+ sout << "vYFAN7XsTbpmtA5GgHUtBWVJl6OYhBplfM/ZECAqOp0ZbBpEpfEntR/Go9qLPlXT7uFTcqxKePSv";
+ sout << "pYwsS8xLTovG/ELnF7/7NxL1ZPloXw7JXSPoSWEdbQ77qG6nOOYyeYf8i8MCkSC8tBM6rcORlcfl";
+ sout << "GTUIeRALSs6B63QHlms4Ev//nz2bSKv7BHMmrA1+2EI23xeZgiuPoKBfdCe/YxF5kXZ0PSlxsj2N";
+ sout << "7OCHsUTBiQrVCUoAO5hXZWAlSz93FgBGkc8bh/d6dLW98dXvWPxLdq6vx7URXleswpb3FPFNm8R5";
+ sout << "Q6WOEu2+Y7ilKgIq4yv2259Heqthy3z/t3QuTfa+lmhTEpLDyFFZ1o6ERsqYmMrkPM5DoPn5K1BS";
+ sout << "8svZoTTyRFw6yxjoVC3rk3ICe904M/cuuPxOaV6D3/7/k2/9Qn2+dDWV/Kv/HnH9kR5L/YjZ924u";
+ sout << "yI4vfT98p5Jbxd+m8wguelLfXVTz2dNjBQ9fxfZbVWMJj2Eo1BjHIAXpFjE7g9fog+NbIqnlIb3J";
+ sout << "rBh8jEGWy2lNrNZXVzYly0z4d7qmrcbGI5FhlWOkKBFAdp39ju66kI39BZiM/RoTydqIQ2iA/eZX";
+ sout << "wqgzqMPQ/LrpImeroHcY3tMd334UHnzDDq2rJcqreapzK+YM/saPsBmiYWs4joiDCjTJcGuHOroE";
+ sout << "PIDb3yGry6Hov+GNqI/NpjFMQIrEJlR25yBiunN3t4sAcHQJzT8OHxrDVV2BWGRbagCMqWqeN7KY";
+ sout << "ePwovpts9a5QbgncQzbGT8X9p5WQ/uY9miOJACWeHr/kZXjQt2ngGBQdOm9i3RfYW8NOjBDGtBpJ";
+ sout << "Ys31Wj9mwuC7+o+4DLfqtfXrjW/Or4zJca7EkuKJbRjlicYSyMNSrQPjGxq8qYkHds1dEQM0JydM";
+ sout << "NjFG5vLb+nqNRrABnSVkuEl2W83udqnjDE+mJVEnfwHnPxF2k1Kr2nmxqO54ZUranYnyggZxrvuu";
+ sout << "HGkg62IpJXv6u1R+7rF8XRErh36r2lkmps1cCkmW7PrqaB+z/wExOH7kDZG4cMJitrTDywtYsDVu";
+ sout << "xltXSqRBUTH/nxrq/9EiQOM7j/ITCoiZJRfA+G2hW4srArUW8EXEy1dKfOKrooxrZnGeEh2E5sHi";
+ sout << "Dib1M9iQlTCc+BZFPqvOv8lM8uIe4ylVvh5DtyzcKgg1ta5hj4nTKL11vLDBXyqBCVz288zqy+Ra";
+ sout << "seXn0vPNCk9fsxOtFyjUH9UiE7sYcVB/SWYa662DLAvV832uOjFVlOEESlVhhYWb4nCEDZVnOnPH";
+ sout << "1ilf7LOgVi+qHmMm3FS2F55YclALS9CpVmVbt4ADpgKaE8V0TlI84mglvHOJfqE5duRDMfdQkjBV";
+ sout << "o4vxT6Y2xv49ARNgDI2J8yt0bxnQ257afzSUxvvlROxBrYYNNkQ5lPivD+zeWH5j/ojyxqF5iSNq";
+ sout << "Arp6ClK878ubKMvRPjdmrifFhSfs9vhfL3Ox0YFPBq3zKLPbTJEalN3Eeiuzd0dDSnUatLnV38V+";
+ sout << "wyxNsDDywi83kR+dulf2SytyhnCCKOHHstUa3M5klOVLCsrj5lL3CarIo89fqpqqY0H3gzosW+jw";
+ sout << "e+Qd/PTgX7FQd08oasrwptSX41xYXtTH/K/ndPg6i/YzmdIBvGuV+A4MltyzGXXRYxulYZv+OLOE";
+ sout << "9AIl10pH5amcenHV9Q97qZ2NAtGeBGcWQNpQRl7titOfka4CHpNBXjg1xm/go6SHCy2Lb220mw13";
+ sout << "x2kdIEa7b+bRZrXwYJ7CNHE/OuKclyzDvZHfwOrqGl2mfYnG78oi+Gr482vOBZBr5sYf2GxUT4lr";
+ sout << "7YPENdvyYarw/liMmmWE7Td9rTkHXTt3KvCaeJRddK63LIjQrOJNCalDytkGT2LfsY9hu5r6a7Of";
+ sout << "yspHZhOqBBOyjGXe4/D5nQqwpFfccywWBEtnC/+DINOvqpfqx9CxFAZzdjJJRlIPtK6erv1dSTEn";
+ sout << "AodulJz+SoMYUrOSyKPqQlVdN0ycHvLFTwspnZv/6NqNzV8iqYuIb/iylAww3BkmtbmHJBKpeRKc";
+ sout << "DVNWSOiFRb83QaDQLh+nfH3J0p8OcEQgN0XSQIamAOGuRPcTILCNsn9gg2qEwpgpo/1gL50XfpvH";
+ sout << "2SNQBJqD6DaZ1VwqzEyQXAw26niGJ83+UEB915mZR10FYo8mM3Et5WYyg3/0Y3u3+B0GHtF1lImL";
+ sout << "wrpLFhUdRBijcxzCYwuPwrbC6tIHnfV06rwehfR5Anq0cwqip1O+uw2Af3IfjWh0wImoQx3Cm9k5";
+ sout << "Z3iRoZ0+MttTVyFuWjHGrwVkSItkwzvzRL/UdeyTkRqSQvDM9zZiwNygcp8FrpdEBuxMmfC4zAyW";
+ sout << "y0yNR75+2NpI4MPxmEh6jvDJrsKQ+hxl7teqfPRRi8leqIvz305ZBVcrP5bHMcuj7Iun9znHdqnR";
+ sout << "ngddgfi7+SRsy5pQEtmhMoTI5Q3q7OB4I8IrttJPJRu5Y2HCUeCfmTlTVQD0iu/eXYeAbt9GXcb2";
+ sout << "LptTexWOkchMYcVh5qFWVmnPWjRa3X1nl3ildvrkqjTjA96J4G9sd0TwldCyZ/GglrM+SP55tvLv";
+ sout << "4MOevxl2gpNHipkZh19MiJpBArgBT1uyScDXrty46tWdIC3eKAmVfZprFbjswk6L9f+DmzDyeDlD";
+ sout << "v3mfZy7c0rZt/AIrReTIJYb/aWRCno4vyWww6AbvH1AYm0nWD3/jcZazTEsLGk1LkZwQ6hpXOHa7";
+ sout << "47lDu5kdU6hprzMEhGuNLdYDKf+Bk8I7sYp7fBn0FvpnD+w11jhINFJZHL+4MG53CzMqo0iiAbX0";
+ sout << "9ppZbFaOVCrNfqJaRWKYQnkcyCFFOsF8MCKeujNtUynXLAz258d669Fr8c4jl/nzSdi6Y5SMqdJw";
+ sout << "tIBdmkAHxdv6sCRys4gyeyFjVHDeUNrszqG34rEV/905CJn8JNTOtUB/kk+WJaxrKyLkGUMHQHWO";
+ sout << "N8whx74uMJhsg4W+Qpogy9uopqF7Unk/kCbuBrDzXR9+XjYwhFE6uteAbUJ7RgrWCP8929150lOz";
+ sout << "dk86E6B9KZjyIke0jrxH9AwH1hiOLkF8kIWVXdjTPgqF4Q5BjjIdSWhXWS55t2rpD7Hxg3cvOILF";
+ sout << "c6IoU/uRuxpoMMnSuzq+weYnEjTBW6SlnjXjkIuZaK1+2U9rWDaqEgTn2ip8dgS+YA7vA7+dh4we";
+ sout << "vR308ToEwblhv8vSWJICn4iyfaUDcFCXVegEzl4oJcLY0u4NAAxEVj8O4YN9XQcA51dADKZD3+70";
+ sout << "yfVPjt4xF1StFZAt8zZuiuh17/SxoyRZXSOZ4lrX1ihQI+mVIXAblu+LpJM+ugowwdnnBS/poJwJ";
+ sout << "b2W5jv3q+IFy6VJjYkvzUpckAnHGlQyDhYpyYARW6S8BcAox4++WgK6K7iYjgJVUxTKdMf3LmbWm";
+ sout << "UhoR8Xy3P6NBq7lzZRddjFQLZ1wtKBgvqjSdtXDFg8C4oUG3s1UKF5rLQ8ZC4EXn9kHCJI6U4IdT";
+ sout << "nVsNrF8RTfItEy3ji64JnvVGmFDPb3V0CPSlSrGwmiF9TeN4iT1OOWCXNfJyLL947Mpd2hOx1jpa";
+ sout << "52OVo+GofOhImcAXIryA9BoupDVR//+7iFB8qffCuz2ZyhAeKTvVg5/TEXsItr/9Mc4kbKUo31E7";
+ sout << "7cd47sykXhs+vdWU1qVpVkjDOmsHdnXLxkrL23+UUu0gUoE/zdEayg+yJ5nckV0EULpnOUdI6kD2";
+ sout << "r7PFc0abFaw0FrIV61f0AHLAYuFGL4/pZNXoHbMPcTjIOr5MvuneN2n32lqhHQyoZr7er4gm4n6s";
+ sout << "fPmgp/ezVB0uhb5Bovr7uKBjqCSSxwpyWLIioAUIy+J/bGQMqvC447PgH3Kkvkhmp33i0pEp1utN";
+ sout << "hXTHtnu2vdNNwZNrdANwSs8/hzRjI4qdcFVjCpnRyrcIATvkalWB6345vTSyg77XPO0m7q3M41sN";
+ sout << "me9p4AeLxlgRDIfkJrq27zPlYT+w+o27s3jXBS2MvMdSwl4CW6UyHCqrouaXy51PrWtFLY1LZjx9";
+ sout << "3R0xGbp9XZC9WRqiL91VMa9l/sTS92bw7urPeRj1FHh5HNyDSpaLEDHQiihXJaz4xxGbdEveyHW7";
+ sout << "4yhgD4sR0slooQhw9O6IaL7gjw0auFf+ty8jedYY/LfLHhbRGYvmdL1NiskwM6mdgPs/w1q++HFt";
+ sout << "u613gA7BRTxRw9MIiriCtGyk4U8uO6PnGGqRoXNpnJXE9fgl9Qdoy6N+uw2JTIMcE6btB1lbjdQD";
+ sout << "nlUalgZMbXzm07z+MNNNGSc5Lc3ylAJ3v19GL6qPYYvDtXcXC8lPRSmeZu3k4tDZ7gVlokrSewYR";
+ sout << "iQznd/Cq4w/MBXFRy5nbtKMeAfQOEU0KTgcOPVYfd4sUJiQixYOyDpRjWHLuEfCXoXEbLIBLeJ8j";
+ sout << "6YHwyOMMzxBepwBSxTNRlNnY5GF22kwhk/KUIB12mmkegsiTSfA4Kz3q05i7KF93WZB1gDRSkpJ5";
+ sout << "DgtN6fALotQJlAdI7CgKxiqkTD6YD+wIA7g/bu7EwH5anE+r9mBEimDxrDNyhCSVwfYPnFmi/HLk";
+ sout << "CtH5msyO61QCth9RCpOKMsnqHH2oqh6XrH+D6FjhmFxqlzp5PJ3SEybtCY25xxxnmwyJh0uDcAzr";
+ sout << "a1EupoKzlsY8U8kFdcmhOtPOnSgQRqUoLQxZ0RUixPiKFI3ikbai5/ZNNlt7Cl6XcOWxnQn1XLGh";
+ sout << "7KewlqwiweypCVOk1D6NHgTdfCGzW0NWYrs8cqzc47cWDSMQbn8Zzr9LMHdjq1t5VoT+0pyasBhf";
+ sout << "Wxih23uQlVfaNO5SlxpMPuW3TFpASdhxiXVmJbIl/j02lo5w2MLYUUouGKT/WeTr2f2R1kNZkCcN";
+ sout << "44DdifeZYVqoUrAV3etS4Py0H6OHDtAnFmU7983u/bZ7bpkfvFl97zvms5D5JWuFQ8HZTZikTzIM";
+ sout << "uxDpDcxkFz1U24O/WnovgMCipjI6uV39rjWr3qXb3qzc9D7Fp+jh3eg3F9C/SeXY9Ru077P9mAI2";
+ sout << "oRgh+2hJ86+1/YdCetlEVgmIVVVGVbcZWLNoHISNo2lnH3fSDcfWTRekNBlvUgiAE49tH+av0SN7";
+ sout << "cxjm3vWlWv9Mdt1BvZORAk2sYugBnDJZbB+F53rKioLbUNdlCG2HcjDboLrRnY//cxOlOARg9UYB";
+ sout << "N4+HueNJApuqkGhItKHRMKXWXWXKYkyN4PdPicVFpdlxwoeXNWMtNvUbOv9oAox2HwWqFERO9JGa";
+ sout << "wSXE2oYyE0+BUUUPq4Bf0HobSR9yOVj2YmgvvEv61uS/7rDiSHElMnyrGnOpy1Jk9Qc7BwDJ4iJt";
+ sout << "LH0qQR8w6qXJcDajjLT2dYkzFYj8Hl4iglGCRuFTKjhQLp/E1nium2AS5aSlvk2GwX4SHEnk6nAP";
+ sout << "zYbc2wAbehdVW48RzUWWSW62Zckj2YpDSMoHlPJhysfxN5NhGzoBag5D28vAsbmuOYKw+aDDg7/Q";
+ sout << "kmGxNMjXuZ8z2dUHAXA04UxRx7P4kIDYHrOPyd1s+V/Nhd2WSodXW+X0N1pC5Wv8EYYeNaUke7h/";
+ sout << "8w83XCJ9cHSUO/PTRY8VLd1GX0UFkPb4m3Vgbv2ARqK+ZoZ7t09USLMz9JG9MeD3+J0vL9jShQW3";
+ sout << "+4j4G9ZuiIW3VlOu5Qm7jdcLEMzd54mqjC6gYRqTYNP+1P3LtfVY7nhRdxR6S8Zt16B1z+hHtLsY";
+ sout << "APWNCR21V6k3R3MjFIit1mBJZDLFFM2OEAYZW+SQ+TkcSueb5xCyFK2gwynzU6l5/CSbaudjWjsE";
+ sout << "hpcqZTmHsv+ertehndIaMes/Ihe9ZWf8KJU8sYDn4gV/A/ZXesRLMDtGsTKTywN99GdKi1yX1vpv";
+ sout << "ugSgxjMoplA6lZHmz3+sEUepfpLVQxpBQ7OvrZrhBuf8GxgKRC2C9LbdKdq8+qM/qjuNUf4CKz13";
+ sout << "OZfp6K89LPJx4Qua2uEKtoa1i4LN3Yt5urZz6CZmlyZR8/jo2e/PBqsKc5zP7SzJdQ3fELngYdUm";
+ sout << "HJC7uE6ElkbQWqTUg1Tjm8W2kGmbahkM6eVYps1yVrWqWQvFG/m3IUUrIhtvK5JwzlRfH13GIevg";
+ sout << "cyQl2750lWOaewIa3Bni58GNVuR3icSmcf6hZ2PbeEdnnxMW8oHUp/cWMVHmZdG+AA==";
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/fhog.cpp b/ml/dlib/dlib/test/fhog.cpp
new file mode 100644
index 000000000..88377bcc1
--- /dev/null
+++ b/ml/dlib/dlib/test/fhog.cpp
@@ -0,0 +1,684 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/image_transforms.h>
+#include <vector>
+#include <sstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+#include <dlib/image_io.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.fhog");
+
+
+ class fhog_tester : public tester
+ {
+ public:
+ fhog_tester (
+ ) :
+ tester (
+ "test_fhog", // the command line argument name for this test
+ "Run tests on the fhog functions.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ template <typename image_type>
+ void test_fhog_interlaced(
+ const image_type& img,
+ const int sbin,
+ const array2d<matrix<float,31,1> >& ref_hog
+ )
+ {
+ array2d<matrix<float,31,1> > hog;
+ extract_fhog_features(img, hog, sbin);
+
+ DLIB_TEST(hog.nr() == ref_hog.nr());
+ DLIB_TEST(hog.nc() == ref_hog.nc());
+ for (long r = 0; r < hog.nr(); ++r)
+ {
+ for (long c = 0; c < hog.nc(); ++c)
+ {
+ DLIB_TEST_MSG(max(abs(hog[r][c] - ref_hog[r][c])) < 1e-6, max(abs(hog[r][c] - ref_hog[r][c])));
+ }
+ }
+ }
+
+ template <typename image_type>
+ void test_fhog_planar(
+ const image_type& img,
+ const int sbin,
+ const array2d<matrix<float,31,1> >& ref_hog
+ )
+ {
+ dlib::array<array2d<float> > hog;
+ extract_fhog_features(img, hog, sbin);
+ DLIB_TEST(hog.size() == 31);
+ DLIB_TEST_MSG(hog[0].nr() == max(static_cast<int>(img.nr()/(double)sbin+0.5)-2,0),
+ hog[0].nr() << " " << max(static_cast<int>(img.nr()/(double)sbin+0.5)-2,0));
+ DLIB_TEST(hog[0].nc() == max(static_cast<int>(img.nc()/(double)sbin+0.5)-2,0));
+
+ DLIB_TEST(hog.size() == 31);
+ for (long o = 0; o < (long)hog.size(); ++o)
+ {
+ DLIB_TEST(hog[o].nr() == ref_hog.nr());
+ DLIB_TEST(hog[o].nc() == ref_hog.nc());
+ for (long r = 0; r < hog[o].nr(); ++r)
+ {
+ for (long c = 0; c < hog[o].nc(); ++c)
+ {
+ DLIB_TEST_MSG(std::abs(hog[o][r][c] - ref_hog[r][c](o)) < 1e-6, std::abs(hog[o][r][c] - ref_hog[r][c](o)));
+ }
+ }
+ }
+ }
+
+ void test_on_small()
+ {
+ print_spinner();
+ array2d<unsigned char> img;
+ dlib::array<array2d<float> > hog;
+
+ // do this just to make sure it doesn't crash on small images
+ for (int i = 0; i < 10; ++i)
+ {
+ img.set_size(i,i);
+ assign_all_pixels(img, i);
+ extract_fhog_features(img, hog);
+
+ DLIB_TEST(hog.size() == 31);
+ DLIB_TEST(hog[0].nr() == max(static_cast<int>(img.nr()/8.0+0.5)-2,0));
+ DLIB_TEST(hog[0].nc() == max(static_cast<int>(img.nc()/8.0+0.5)-2,0));
+ }
+ for (int i = 1; i < 10; ++i)
+ {
+ img.set_size(i,i+1);
+ assign_all_pixels(img, i);
+ extract_fhog_features(img, hog);
+ DLIB_TEST(hog.size() == 31);
+ DLIB_TEST(hog[0].nr() == max(static_cast<int>(img.nr()/8.0+0.5)-2,0));
+ DLIB_TEST(hog[0].nc() == max(static_cast<int>(img.nc()/8.0+0.5)-2,0));
+ }
+ for (int i = 1; i < 10; ++i)
+ {
+ img.set_size(i+1,i);
+ assign_all_pixels(img, i);
+ extract_fhog_features(img, hog);
+ DLIB_TEST(hog.size() == 31);
+ DLIB_TEST(hog[0].nr() == max(static_cast<int>(img.nr()/8.0+0.5)-2,0));
+ DLIB_TEST(hog[0].nc() == max(static_cast<int>(img.nc()/8.0+0.5)-2,0));
+ }
+ }
+
+ void test_point_transforms()
+ {
+ dlib::rand rnd;
+ for (int iter = 0; iter < 100; ++iter)
+ {
+ for (int cell_size = 1; cell_size < 10; ++cell_size)
+ {
+ print_spinner();
+ for (long i = -10; i <= 10; ++i)
+ {
+ for (long j = -10; j <= 10; ++j)
+ {
+ for (long k = -10; k <= 10; ++k)
+ {
+ for (long l = -10; l <= 10; ++l)
+ {
+ rectangle rect(point(i,j), point(k,l));
+ const int rows = rnd.get_random_32bit_number()%11+1;
+ const int cols = rnd.get_random_32bit_number()%11+1;
+ DLIB_TEST_MSG(rect == image_to_fhog(fhog_to_image(rect,cell_size,rows,cols),cell_size,rows,cols),
+ " rows: "<< rows <<
+ " cols: "<< cols <<
+ " cell_size: "<< cell_size <<
+ " rect: "<< rect <<
+ " irect: "<<fhog_to_image(rect,cell_size,rows,cols) <<
+ " frect: "<< image_to_fhog(fhog_to_image(rect,cell_size,rows,cols),cell_size,rows,cols)
+ );
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ void perform_test (
+ )
+ {
+ test_point_transforms();
+ test_on_small();
+
+ print_spinner();
+ // load the testing data
+ array2d<rgb_pixel> img;
+ array2d<unsigned char> gimg;
+ dlog << LINFO << "get_decoded_string_face_dng()";
+ istringstream sin(get_decoded_string_face_dng());
+ load_dng(img, sin);
+ assign_image(gimg, img);
+ dlog << LINFO << "get_decoded_string_fhog_feats()";
+ sin.str(get_decoded_string_fhog_feats());
+ int sbin1, sbin2, gsbin1;
+ array2d<matrix<float,31,1> > vhog1, vhog2, gvhog1;
+ deserialize(sbin1, sin);
+ deserialize(vhog1, sin);
+ deserialize(sbin2, sin);
+ deserialize(vhog2, sin);
+ dlog << LINFO << "get_decoded_string_fhog_grayscale()";
+ sin.str(get_decoded_string_fhog_grayscale());
+ deserialize(gsbin1, sin);
+ deserialize(gvhog1, sin);
+
+ /*
+ // code used to generate the saved feature data.
+ ofstream fout1("feats1.dat", ios::binary);
+ extract_fhog_features(img, vhog1, sbin1);
+ extract_fhog_features(img, vhog2, sbin2);
+ serialize(sbin1,fout1);
+ serialize(vhog1,fout1);
+ serialize(sbin2,fout1);
+ serialize(vhog2,fout1);
+ ofstream fout2("feats2.dat", ios::binary);
+ extract_fhog_features(gimg, gvhog1, gsbin1);
+ serialize(gsbin1,fout2);
+ serialize(gvhog1,fout2);
+ */
+
+ // make sure the feature extractor always outputs the same answer
+ dlog << LINFO << "1";
+ test_fhog_planar(img, sbin1, vhog1);
+ dlog << LINFO << "2";
+ test_fhog_planar(img, sbin2, vhog2);
+ dlog << LINFO << "3";
+ test_fhog_planar(gimg, gsbin1, gvhog1);
+ dlog << LINFO << "4";
+ test_fhog_interlaced(img, sbin1, vhog1);
+ dlog << LINFO << "5";
+ test_fhog_interlaced(img, sbin2, vhog2);
+ dlog << LINFO << "6";
+ test_fhog_interlaced(gimg, gsbin1, gvhog1);
+
+ }
+
+ // This function returns the contents of the file 'face.dng'
+ const std::string get_decoded_string_face_dng()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'face.dng' we want to decode and return.
+ sout << "RFYXmMpdiStV6dVZSJkJX8t7GVavYwTD+Fn11ZjivhnFQyvVJ5t39yJYrK6Qh6K58ovzMlgPiBLV";
+ sout << "nd+ZR0JYVCVvwRapp+dznB9SG9dJbBrsYH68k04uOs9ehP8aK4/EXvcZG6s+rL0rnhVAf7SxL7PT";
+ sout << "r/11jIuASMa1daKZjAm5Sc1icGXG2FJjO6CxM8mOzWJ1ze69MPD1bz/QYAWtMUqUUIAM0qOPHY0x";
+ sout << "T8tdU+Vo6S6E+8dJpV6a6iDdocbp91meDQcT0/kadhC2tmn0eZoNulTn5MtmsmEeuPI2lLLcRJ9P";
+ sout << "yt3c/OJIzI8FaDzYG6aWJ/yBQx/DJF0avAlh7V1UmbD8O/dMoF9nUFDwnhGyS6DYfTXxCYgVgoj+";
+ sout << "Ik5RLHY0U/DhNTciFaLX41/MyIt0xcGtxhoVcvwkfIigKnYQsYfNpRdUWseRlZ1KYaR4Oc5B2tie";
+ sout << "kH3e5AhrY/HtffCah0sf6MBWJEi7CH9AnVLDQefL8Ph+qCWJGf7cGnM/oAaHQCzHIHVi+mK6EBnN";
+ sout << "1NDrzbdXmikwYneB3LUZxCLKZmxsFduB2HgiS0A+tTK6IYc+jqCHqz8N6Gw0sSjAK7rrPDTvxhSN";
+ sout << "lX3f6E2IDfVmyvk0l3RhuA1PNEh/nlKR+YxcXHyYW4wGf+UfWScAzKGxrHLxLC7LQycCEaCMkU92";
+ sout << "SQV5NSSlwKYKACabK6UJ3gGIpvuQK2Aw7VWmC0iLczqgWsX0GKJR0FAcVL9Ed3nV0Wd0s5BkjBsr";
+ sout << "RbUKzw11Qu0toj6BNfwXo/5cY2dtjj93a+CBfNrSEuFyJzZU7cn890c9m+q8C41p+wQdf4pFpjcV";
+ sout << "8Kz40Fyt8KtxItWSsACIwmUO9h7DGnyGskWBYrxgDV2VVlvuPAnnSCFPkbdsa/pfnohUq0C5a/ii";
+ sout << "BjASduPdaBHpjZ64f+TIaXNAGdrFiN61W6e3fOx4fLFlzPQ8szyWuuDh2hIz1FMflbmu6UOEkQji";
+ sout << "w+bwDDJ5OUFmY/00+3B0XAFmj7Pt8OQ70lAVLcX5fC553diQJzrrlJ5p9/8+ILln+oleUVJhtp2q";
+ sout << "VCZ9XknXLjkQik30M7orOj+tZt7HDgC5sz/wHU5arOL3nIX5IuIHBJRlB8dERZgoPNQlB090rItP";
+ sout << "MuT+Hyr/eR7Kcux7Fy2CoMxcIfWEXvxQoolLKC66q4+SFirdMRjXuwbRXrUBbenmBfMMNDAOkQKO";
+ sout << "Bi7d8t1wI9ulNbACtqLbmPjW6iabc0yM4g69cZjRx/JYhV5AaykROJxCP6ZKTw+3ddAht8xoyHLN";
+ sout << "40rB40fwEXIvv7qxCdCa3h6l6IRV26fOLdcew1G0qjPORcKK1TPhzmneYvhPZ1m0r6KxWNEnYcFq";
+ sout << "WhDGNxj/05eBy2qIiIU/KUPhxKyipF0ekgPoCsWT3S8edYjoaIl5cI0TNpNEwKGRLQeOLBDt+MEh";
+ sout << "z1yKm0i2jrtxDBpYf5FW/Fln/XJAK6z9yqSDDTDwQleabvRHDH9bc54bLc0TL9g7/eIj9xcshhaB";
+ sout << "zbi7baB/yUnI1+0N6CZ45gV3BcD5n2QLiHME8wELveiMxps6MTf3SdKSRZJcnoVfN4AGYqQV42ec";
+ sout << "dFB9y8FLZVL3/8rmB+XEu6YoiGcNK6iATLYQfF0GFRrur0Q6bQgdvXv1uZKtNfYfznsAAu/KBdxX";
+ sout << "8qskZBMGA3LxJC3j41VW6Fviy+XUxxcmG9ykbf0COJWDul6ZQ7iRI7rn9EpFIYBM1lKzjdC0UFTW";
+ sout << "yDWEE+mf9y+RZdlxHdROFj93FNwzSdzNr1yjqHvZHBZYiuArHEuDPXdxqVRePcID4EHzmpDgWFwR";
+ sout << "o5qqDxU8e9UYfS8SG545SPZv69SJVJKld5fQLZ4FbcCjv7wTwrOKTROvurKkxopKt1n69BdDA14H";
+ sout << "mViSyK22xK/F7/ydjLoqx6aJ8xyNpoUk6XIeJ5Ei2Lhk84VQk9dxzULVy3KsfRUrZCTTi4YiXkHJ";
+ sout << "SmQx4NQKqHR2IOgnJBZuNG9J3Fzv3NKhQpmKL0ZbYLXWdKP9FHWUR0x7y8f74Su+GrplBsjh9NIm";
+ sout << "QdaKLa3NvJB1TML1/GNcdJVZUuSaX0cQn4bbumvtcENVbC9u99fGnsaS5FOu4AHd3338zLUMy34C";
+ sout << "OpJjU1c/IElgyKmaGYlAolgqGU3lixhxPGBhUlXGfHmST2ZWq/l6NxD//HQXRaRUiQGQzWCvhzOO";
+ sout << "ywUlVzl9eJ5e5cdLWvffsPRzBgRMrdHJG4TbSuLAREsSD9QEGab3a6y+qa8T3Si/1Ut+Sn2QvPh2";
+ sout << "meqqk9g0fRWtrWxcnbUDU6zMlk36L/o/y5inrHdGY+ixIewhI0n4/Nl3wD96SRITcEVSx6K/BVot";
+ sout << "+qIP78I6uk+miUF6MW4AnFyCe1wRNyhT48638KIphSQSKdu7TaBndi2DNgFFvWrm6/cPOqkmCzGC";
+ sout << "O22uwHyuY9XhafswKLH02+VD24PIS/Fw7JMP+KzvfCHQd4XxxdsISe0/cjwg26ZfGcnULLY2E+dX";
+ sout << "LjdgCxNyFBzFTQ4gB4QExF0dHu+SPoo5T3VAojJbYZqIelFY+u2yQDuS4HCUISPkwuLHXHbcBuwg";
+ sout << "5TeuFhyBrlwxHQC/OPACmQJREImiqpzrjmh5QipeEgYHK3Zc72tYSeY7eTzS4jj0eRQ8KiNIGSi2";
+ sout << "2LjzAfN2Zm7HGbiBtKZVen96E8HLcrd3nSWnizfaLLWTWB3zu9zz9/vFdaa3TlO6BidYsKomTCgB";
+ sout << "wy8yMeykE2qbxgrpRqEqmOkrOI9XtTTJIycfAlwBwoFwuqvGIPtFrYmC/MwRMCphg7acSRcjZg81";
+ sout << "5IEqpoq9ca7Zc3s4foteVMCJT1A+qmNAJ/j7IoyeX7GnlM3jsqpYt9BmKfbw5Dr2JB9vzroPV++x";
+ sout << "UN2VXRPbahjbIvrTULpeBdmlHU0i3Ya8H/C9RY6c2DhImZ1gDjgn0jQ9GC+CsZpiM2xBvfZZGOEu";
+ sout << "c8N8pdo2owD8s5q2G5ZCGNdME/AG+iIlb0P00AX+XR8FYhxKb3y50i1giM41mnkKM/WMGFAnpiuo";
+ sout << "YordYSi5plePBnxBfd1Iq46PpsD/n/uUTZMHs6TGp1hM6QriyEhOO261HNHoU+n8m1Omz2cfRJyx";
+ sout << "AuFLwHSEqvGCSmslmoDpSg2qOaIWK1LWlN+1sYJj18iL4GRM0A5QzXaS0RThqEgmPjeBOkFBjfSO";
+ sout << "hB7mb3sDbY49qbN6P48bGV+yF6y34gYAiVkm2NksHzN4ovwg4O6WMQZwEhNk+4gTIzG69jIm6Hbn";
+ sout << "2l48A3CYmn8gcjZw39nrlSxpMf7KPkRsdvGmc5Qx9RjP71zH/KJ2TXP0xxzsaGgmqzXfey5l0Hih";
+ sout << "XZtfZw8Y28fHBfm3bnIncS4w9S91no+RYMv0aqc9ty7l+Pa28ELwSgQj9eP4u/i5iq/GPmmSxiTd";
+ sout << "Si/eeyK1RFJEP4Tv4f3PkV9Js+azu8BbtU+BLO1FBlVg3CzXH5Pc5FMujLdmlqa495hTmi8YW6Et";
+ sout << "Fx8dkC80mYFGpVjS+B6pcQLbLBL9gmKzJf4L94/gXZ25BEDob66+XOaRnJ4RkSAN2g6gFJB9lJDh";
+ sout << "rLerp3kP/ubPCvcFywuGx3UjJuwFNHE9m62uiaXFU4m04Kc4n7ccHc6hYUkhkY53v2Qb5SDx2qCf";
+ sout << "Yg+PWVXujfYrqxRHSwqtV3yX5kMrtYsYpygb7crweOt58BWUa3duyo23UGJHaCwhGwXat6PEC5DQ";
+ sout << "2Oe3LVJmc8eYtD97mHKFPhptBl5u2Bztb3zis/oNj1NdMjnDrNuscEAnrpk1CetvHKLglK63Zo/D";
+ sout << "rf6SJcmGR2h9g6wAeV7UdsfD6AvteiPj5sl4UuY9x55pP3CTTYklBO1MaDd/XO3A66uMh95RZVGr";
+ sout << "VWDd/uKL+rIuI+vKjz8rt80nv3SyUrY9fbftPdK4pBaVnIt73yZrrqv4Zr28H8XpFFQAV9BPlC9o";
+ sout << "a8G+AFx/+W2cSfo9r1Uw7npVvRTe6TtIiKagYUmWpx5BfX0VH/VAW0FUh9oiVfx5rm9eaxfSQnD6";
+ sout << "7qBINPxsKq+ZDSXni7qfC3J043Le/uL+3XUqsccvEMoU65akKC3lmw1txoUukv92oxyqPX0eOGsB";
+ sout << "AU4JdXCldqjU9K3QhyCvv80ZWotGfUr0TlN1LVZqcF2iq3pX1UDOBsPwz9v0QNg8Bmlqy0Vs+MUj";
+ sout << "nMCwU9xErzkXLsuVaG+Llk7mmAl7C34BF9O9qSl2kCmbQYoQ87zS7gm/pK7aKGNsICHrar6vlsKo";
+ sout << "BJA++/8XKL3nseNZHzq7hKHnOTzagP52MRf+TPXbTVjQPKnCKVAZJcsOlkmuZc7iDnLn4muHDRjg";
+ sout << "y09EYcYlFWhLAgsWmatQBsT028ytgMNrQGHDJdjuNkxYfPo+/91ijaaBiey+DgrUVn0fm20k6/Nm";
+ sout << "colrwPwHrK3uOdgBn2ysDeUXU8NLMtR94fIL7etQ9tlUuufwrxEL9zYUM8tpks8HDR51xgTwUOVo";
+ sout << "DyGFzOdYQRzwi+kkEPEwkpNQbB258d5w9G5eR00P8B/aSjm+w4FU0MsXM0GgPxnQ+gTpS1cezLTn";
+ sout << "eelvJYiq/IInLLxoCXycZFPt3WFQqOBpcs6TV/QucjI/5xMZtP3JHUFv16UKPTFI7p9DF+8Ch5HN";
+ sout << "gWXCnRSPdYR4ZRid+Xfzi0TvQsXV6u6PaE+H5MpyNMBWhCwxb6FdiLUW0BswGNpHBaFxjB26Qbmv";
+ sout << "OW+s0OuXDvKigjQRkeaYawjRAIAN/+CEYR3oUad2HyJ5Ybr/lRlybQuuIqBhuvpkYzszS7BqrxOh";
+ sout << "FJYaivT6r3HbHjaJ+Yz/zNW4KsL80zYkPMP7QgcbbSfE2mAavr+ciXdZBqMMUR50sDNLxep9+hoa";
+ sout << "ys9wl75QMdx1jn1qn7f04JMSjCyZ7M4bWSyTW7VEr+NBBLmiMzhI6Ufh1iCUpvrIDSQwSDUL88wt";
+ sout << "oSiouRbqizt36TldsvFV6afdLgjRrp2cb4vOQBiltwnY06JraGZnsrb4UCfHZhxd8sq/invK9tUd";
+ sout << "D3z8hYyLGbS+3LBCK85r74IYvCuhoUp+KobIZPhvWuvdjmmq3SAxIKHNdLC5hnLVMhGJUrckc18H";
+ sout << "9zK53uB3QXX6zGKK62Jph4aOdJoDQaPL0K/yHgn9UayEhH/N1uj3Ao39c05puaxzcSotfBeS5+6K";
+ sout << "WYyOOMtt5ikKz79qfj6dVWge22fxXUc6yHYfdga0IbYWRocIx+DuyUZnrRQHihNKgYpvF2vhCX/o";
+ sout << "R097oHI4ojZFAX/ZWJ7igJvX7ChiwTjK8KDk+vJ4SUd3IHXaiLkkkd9p6tCuc9Lw5jqWiGrrQKuI";
+ sout << "7AmGsPFU2EsfOzmwdZctDGXq2/IutVDmwGiucpBKsRN4y12Q1FWKpceVj2q761LfDx2qJoeZKTPZ";
+ sout << "jHPdXnGKcWy+DM6GoH9e5jP4CW+HfdHe474bHfLDbP4NE1oF1vdh4NcLy6woi47hg3FS60z+wePD";
+ sout << "bWq29WsSwU5oXq58nMxKOBiMcbGFrkOme/Z59Ybi7Cw1+U3nGE3evCFyVMC6g4f/jvCyWF5I3Nm3";
+ sout << "OqmkO6fmZ4ahql6C+RwfdRM8A3FllNPgO5riBNX6RA5xKj+JS+OZUrSSN+tUqcgN18IlmLBExEUt";
+ sout << "rdG06PKy+WM8Cju3gtbOFX43H4URr9CQcDxWbN6NoqgF8k5a/4+xf7DilfGJg0E2Vu8GG7tmFSU/";
+ sout << "LS6gtfLOFyEnQkTzqK8OhVPVLT62cEfCZN9ZY3iKQyZ+VLQhxwarUAgqeNAMXM40NBJqnUIaaTKa";
+ sout << "ryhHefUHazhfVgx2+GikVF9wvMobCvvP1qYONlL9EH+ufuLEw1V35BYIIClbrC2uMrnF0H3QbuJQ";
+ sout << "ma69tq8TDPkyDLiaczKuAxUzJoj9reJOYGYTxzP7AQKmmEmEZ2cX6+2klWcRXv23XXN8Ypjjnj+d";
+ sout << "fTdzxV4kzcHwOYMsI92tadahezCm9uOR0d8p9IH61QQSlUlJw8tX0TpGkNhZpv23STjQhb+uxzAX";
+ sout << "1vYdYbPOenr5vCyhnpp33QezQj9cLhSv1WweplUmZjHcJTkPBdflRA9AuqxDVVnbbofXd4EJDC6k";
+ sout << "u7xBoD7EKC+kCEkx7ygj8Gv5GVKbgy1js4gLuYwhJ5aqdNpqm881kkxntfRMluVcdH3IGAUzWR5w";
+ sout << "26eq7Je4Ttr1cC/xy452i3pJocbhCqrNUG85RyB5FXHAv6GMvm0rUIa6IyC/kfis+sQsdYkQ8GMQ";
+ sout << "wL2s8fdDT6l38N9JbRNwdRv8Xa9QAjwcGNbP2v7tAzM5MyhHW7FImYVAaNAaLbzE8v95zeGpT8Cl";
+ sout << "CWronhkcJRab8AKP2UcqAD+mW1hVEAqyDe7oWoZziKa2G5aW2vs/WG0z+NqL1zGvUekDcmJ5L4SK";
+ sout << "XgdQgxMb/1k48YqYQZFtQrIqoBbYn4qPeB7i378T5TLcCgB6SldsFdlNzs/czN0doroozh3W+sli";
+ sout << "d15Qnv5WMjOinjh0Ybt13wcUzeT2p0ovTtYLoiYAhDeAibydJETLdcozfpXFIJNUSoH6TcLge3tr";
+ sout << "0uVP92B1O+n0MibJvLsLUKQ9ueIiHgZb6bUSUixAg89QCDRLCZgkg24DLZ4MMTg7IRfFm2eR2lmJ";
+ sout << "Erpe62rf2+JE5JTqU88yn6kLK3bQ9vmaGRZ1NUxibTcdpo4hH91qIldLT+jrdmhrawRjYcRduYGO";
+ sout << "WXgjTbgKRTxnqrXwRD4Hl/B1EV6ggYC+jn3LQHaT6bYd1hORmtuLKy9duSHVCNBAZvnto27l+h6g";
+ sout << "VbUF8eZasmk+q8Fn83bx7C3eKoHjr6acEUyQxtWVCbmeaMd7h48Mt3Z3r8TyX3DkwQmpClciwpyC";
+ sout << "E+pbYEWMZXGOuXPmcHTM/Iky8jNSWyw2lLVQQUzPOJ0v0dtNipYRZqBQCDtSE0JuA3Jo8l00uox2";
+ sout << "bH11ErfGplsGZJejPGL8ba90e6xeLwH5oe/GduQ0/Xk4+faqBhy/7TFeexQcFDRCCTrC8+jATm82";
+ sout << "vHo+NWJjjDlEI4+F2FOhpRg7MtrrzNP/e+cD++wYeGjkRplbxd5PeyALnjZJ6kghJqFLL3NJ1E4Q";
+ sout << "gKmRErg1xWWQzyuDbbXPr5dwwRyZU0gkG0WTwyUy2dFV4KRyn2IbMH6STm+0af96YF0joZzkUroH";
+ sout << "ztMN8dWtmQESq6EYQfGlhQzNoBKLXjN3LK4TMWBE+1N5ilXkgv3cnN74RZHdLhEXRJnF29x/DmVQ";
+ sout << "qZQ4s31Vk0kqKdQ8tW0rs82+dMMtFv8+P2rYA1GZJQV4P5/TBU36BVetlN+swvULk4XpoqhTTMbx";
+ sout << "Oj9tONiyIiJitC3YiCU+G5uL0YETB8nSKrtHRiBD8k7nYj4fbUtSbu9+lKRsVK7kU41mKdBImON1";
+ sout << "6Qk0XqAx2DEK4w59khYMRRxOD4u2zZWDVp+Nl7Sd7ihas/vQx5yLXHKmIpjCK3SQYjJz09txIErQ";
+ sout << "0wJJZoSxH8efhGsTPuVrbQpGcHLD7bIkWf5kjR9MmOBCmUGgeeGOyi45x0k6Cx+z9oaaTXYcvRtY";
+ sout << "M+R8tW5gCLaOPjfbq4QjP6yfYogoaTiSEKOPcMgiOQKrXNiv2ahVBT/lvkm6Q8+IdesGWJtD6xqo";
+ sout << "+CC486Du6CFDzAHcnLMk5c3CqDfFGl5Yf68bV4aGm28BaM4vikeKRhm2tULeM7PipfQiI9R9Gy/L";
+ sout << "1yZB26qciwCalP4CA2NVjiJut7FZgTF2bO/g0qfvyKsAxMetRTmqALBJi8QvKqAE4i/8gRlTuwgV";
+ sout << "x6EGsUPCIcQmD8aJkZgy8+erSAY7MLcnUXu90AC37BLyaOt0tzJKfVRb6cP8wfZHqJneoGSNAA==";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'feats1.dat'
+ const std::string get_decoded_string_fhog_feats()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'fhog.feats' we want to decode and return.
+ sout << "AXAWQhjCEZzu/U+RFPgnRfCFsyRzQjyOA8TMshjEOL3FZhUZD4amYHfLuNN8NZrJyy2CdiXOWrFH";
+ sout << "dTTKsMQkSz9f2VGKE6GCsTcZ+fpbk+d7fBnS2DEUkw5ttiSxj/VsrigkeV0MEUR//EP3/w1uvJad";
+ sout << "96PFVeqVf6tEmcuuUbYzViC13lfnps7ftXmic5CwLZ97llqMYpLKgfEIXZpelB2al6PyVAg2FOfe";
+ sout << "2L/SzviQ001tYYgx+L2v057L/lJdx2+uOQ7tovbsY0zEBEXTr14ZQW5Wd2lWnj+CrBNQ/4wMQZIV";
+ sout << "86X5JV6nl/fxrM3eF5+U6U5cZoZFnn9vRONE0kf6/g6W4Z8qQNL2mJ/M8oKoLGtuaaEjymk0ENRM";
+ sout << "0M9GMLrO/qfcH/DHUb/HaK851L5sPuf150ecpIbPmQpxTOpI5XwmvwDAP9BCmLazQWQV85wJJFom";
+ sout << "LjT9dNi7eRsV0mNCZDwO9iwUTv6KfAR3Vn4t30X3n3+x+Yhnq5PU/4vu7GJpje25uxICm6+PC9bz";
+ sout << "QJC3wiw7jrlF8eRb2kZywIJHWtrN39R1tddENrGZkU8LlPTDcqTqfuwC1HXSt0GKlDRBW42DkSxm";
+ sout << "E7yKype2G4bgbSnuf/LKftuWlNXzMP3cDFdChtKJNOws6btgvS014C4UGWe3uNVp9SNAmAFyh+UE";
+ sout << "rRTl1w3RNom0yvlee9IP1sB8mOEZn5Rrzo2lCyTm6w49PlLfnS293FxvnpRSSlvqUt1lhzzEh7Wd";
+ sout << "YjK90oPWGqvNn/bP1vQW35KlVrc+JT9VlewcZRYJAq2OigVxa9Ao8Gs+fFyh8/Aym6Y3Z6+Y7Xna";
+ sout << "fsmYdGaUtfniJemtzHatTuhbuo979fqkIDQ6bhHlMN35IyzK26QKzGrMCaHenZEKck9qR8NTkhc0";
+ sout << "urO85952cv4aI/cJ5tz8S6f1v71F1gIz76zL0UNlkULQf3tqEleTqVF5Q+YKAtov0Tig9xiSuCZN";
+ sout << "btjVVpg06zlRqmnRqWpZU4W7uu19MKe6BGoZ08l5K/3UNgucyuUBTGtTGt35QaML+N4WLcTYj90J";
+ sout << "41Tlt+OD4C8nZ1ID982h7X+YxmnSdMNnaeM2Lc/h8twOwg11vHf6OSo6st4yjK201O8lH0Lemips";
+ sout << "LsO2qhAsF4AcWVG7HKUNh5jMr7TTQ/SPP2qiJ6eglJYLFAlwtCSXRSJFuK93Y3eUtzDpVDh44tDJ";
+ sout << "y+Bsa9oWJrBeBouo9DHG20N+wJB3rOVDprmq1ZGfhkSzqmVzPk6motDY/jkPGZ2E8/NIWpXpyU8z";
+ sout << "cUbc8Er2cAVshJ9uGCg0tUh6lGWsrWucpgCS5lI44PJkbSuyFZMbIZrgxrVmkgOa70pmmuHebs+c";
+ sout << "GW0FheFD0IHXVxaJtTuPeDuMML/YrAqdyzibareHbg2Hn4aS+qhRUcMyemOzrvjhGF23oSXbJqaO";
+ sout << "sJw/DWkt33eFNqWxofXw2pWMMuF8akESryBDyGEfk9nofHoJhTciXGNkuxIiSZtVYsVnKNU9C4ei";
+ sout << "31HWlbHjMtG9RD3zqkIXwerCkhElWfxG2M9ja1S7M4+0VrS9nT2ngc/n/KZDpinuB3GIOnnaqiqQ";
+ sout << "8ASsOEl9/Ni9Lflomns/CdxWns3OlcU8KVhWHYz5hqV+MI6SQLlS9j39WFKd1IPaer4y1fd99x44";
+ sout << "jkyP4ekjoVVfVxpJlb4favfwI1AnFD2K2TaUeULaALTuBwNT5DnRcLNnwqD+6r33rEv94Nk+2+bB";
+ sout << "0QPLkojfPlDJocwcGon+z2EHQO4t+RPAafjQb42dNCymI/bIdgwL64vldhl9KfZtXhE4llYhYuLf";
+ sout << "xliN92ELGSt8o5IdAsoOoCPdqQh5NujdDs3p8kVSETq1HLilGRnuxyVwrTqJEUC4G5ntPfqXr1Fp";
+ sout << "revo1FOr4SRa8+c2GgA7K8/fvXwN59GONFMKxt+sI/xko8hbOQwOu1VEQ2Ak1aS0cc5MoUrhUFhF";
+ sout << "gCTqGd+9bWwg95ELl944dD7MIauKl1wHy5iI7u7EDlhvlGhmtU+6lzCmBiDbbPnoki+yZLc0V0UL";
+ sout << "jvaYU1WlKZhjWcA7fHkUyGchIe+KJE8DcW+huCO5iz6gPWwy3ZURQEW+wYPD4Sp6szOTOPsNKcN/";
+ sout << "j9PxKGDbNO6/Rou+TEfznVuu4ltLKZSDfYMK9g37+XjMLHTU5jCoPm9KJDjFvCTRAcmRoQfwuXyY";
+ sout << "o7w4QwcevrjFdb3/IYMrygb57S0iMkFJPUUiCF/bOfQA8tpLePYYtg2ILGGuH2UFwOLszxguBLLD";
+ sout << "ziSVWU+xmCi46kKVuNE9gyeT2OCPtn/U+qEu2B89UkbynI5v/FVpJhJf6MjLc1jfDrEi5NflvvCQ";
+ sout << "l2QLJGjtfDRJgcXFuuNzWBMiVglMOhT4n3bta7tV2KraK7Yc3Pc0GIZv/zVHf3BqeDEikXdw73tt";
+ sout << "aomM2RQstiy71sMmGkmCvUnEbgrY8O60g5nCSmMZIFbbLit9dLyBjHFUELrLxXup7wmkxu+ZVEzT";
+ sout << "FxshL754iYXXkXTqVGirp3NNNhGKPGc4g77Yne+nxpkZ4MwOi3wQ7YqPihwIkIdYewBMiQEJ4Y8W";
+ sout << "0+Os7iD8OYrccbHKqvKMTE84QOKGZSsGIaCD+iIvv9F9/EIB2ZDv+2aEs/3ix85vtg8L2f6WmGdq";
+ sout << "fOdoKVuU2pPIkzlyTNQAwai3NsjcnI++lEpVC8s0K7fpIN8uWDLRnGGuq/G2gFQEoqN8eP0E944k";
+ sout << "l/YTW8+baMZVp/wQNA4mo3v2UvDdHrDcZmIvVuwEaAUW4dwylXMazecCG1SMkusYSXGoorB615oF";
+ sout << "DaqndLnQSfyJrsXdBdZDAPmsssTuOvCjrjTPnb36+WebHGuLUNYZ1kqjhZt8hnxTFuYN0sTWb+UU";
+ sout << "2tL72LNTK66l8LuuQKUMimV2uHT7Dsv4VX6XE+YXczrC0HycSaVmtshxRh99fnNTEtDKo2bXbSTt";
+ sout << "s57BqGWi3/orxqSXecUPFYvQBjrIfDrinany5uht/8FK5JH+c4PMTsbiEeQ95RX7eOaBo6IoF7PB";
+ sout << "zaXhJH5TW/qwdO6K/Caqvjp+no08tUTdn/hRdGMyQ8cIYXsMaiQKbkpXlGnhoBK34XA12T6pXa+E";
+ sout << "QocfgTw773eQRFWN6vmhvuv8Pd3KPJAJO80slizTFSOvxVO68aM7gZdnDFTgfibe/v+2N1xIUq12";
+ sout << "B+YWn9yGP232QOnNq0nuAvFLhuzlau3U+qR2n8DThKWTboF02vsqThaQzF+0EPFk/3we6AeAiatk";
+ sout << "dcvXEbGk/TkGI1V5ICcpGS/fvivZlqYhAIL+yi5/5M3wPX14KwriXpFVMGKozUYaW05+27adupOM";
+ sout << "p4/0EvfeM8T2m+MQ3GcPLA8njXEDbLLWnoZ+YMC3l2OVRMV/yFXkZvQd6tAQUymv1xB03lNv1M1K";
+ sout << "tPE1Ps7ucMH3cOjff0fZOEYabEl81VbmYUCYfntdWlApBLrs3gBWiT0uLoiV3cJkq2VWtgpeyAcJ";
+ sout << "PiZ4L7AblENmUS9gr/7gYdn3uridNqfoos0uvUCFrIS4a7siub0FpCMwiwSiyZwtUoD5vcq3Khza";
+ sout << "DVGJoijIBo/yEgUTho23FJqMaOYyRnVen4i8GH1H7PUhJe7KThuQYk4UQP+XO1qdLITUwBlbNvks";
+ sout << "ciB3IIN6QKQTcoDXEEQaMcPYRNhaaGYFDeSZ4yIRhLV5JyPhOiLFj5EOzrC2B14Op4lOkerAY2J5";
+ sout << "cx0CN9xEUrm80GJGtKudSd0JKscXIDTBj9lxngSScCmKQRn4AWJ/acRm/fyc5Gpg5PLx+o0jCI97";
+ sout << "hm6qOSqslV4GS3BozqP1x18yqrC++IJvOISjjfSvSAkW2s+qv4ba9gfNYIhsJgAan1vaAgSPDXOD";
+ sout << "hf9RJHBE7Xi86Ux4uK/o/0GK3R4QsKa1/t6qij0XAlc6lmt6MXbSr/Tnjs4ykAbSjzmiwON9Jnzo";
+ sout << "KYQiY46ULY+o+UjbHTMuxkTJQjCKtyertjpISD1yNYBxItA703l4ZLK0iklv0ZrFMyov8y4ySVmF";
+ sout << "Tj/eFWy/PpTEQdyzGXrotmYbb+V+BG5e04bq40FsUhhALgSkcGEYoQtxLCZzkbyWQmEN/uXC9Gdk";
+ sout << "wG6Iln2vWqzZSRMeGZ61VMm+dzr5CN2iHtFcNQnjwWHgr8C726YO5j5eWLvHqLouU8ufiojzvsvI";
+ sout << "ycQ9aUfkcr3AXu8hv0+SOUK8mT5JdQ5aPXc9WOe4c6mdgfOK9Jq2ZJCLgQvj2swIoGQ4OVXn2D4t";
+ sout << "MNaKJu0/5ujgOg/634gfxJbicVJLzn30TgWYXX6Ixj/JQW7yxM2iYyHyOl+/SltruQS+NnF+llzd";
+ sout << "rYuXYejtciq7Sf+DFNOJPSaTJuI12jSl4yH4fOMMkkxWMM5QkFzgWGK7A90BJFo4Uhy05TELqU8J";
+ sout << "FtUWdrngONpIU6haQcDpoYtXVQ8h06yZYJnVQBToS4szMcsPPUVYxITcGnbT470DR5+Rrsm4MSkO";
+ sout << "ZXmf4sC2G/k5at6a8o2OKsohg34bNzdo+R6rpJLcMLmUogDdHj7uNjrvMXgZgN2VS33jpsKaVuww";
+ sout << "0fm1I1pD6ltIfO+iF5oEOQz6Ml8TSxcGei4xi9Si7g4KGTZ6+KhS/KcdGQv7HsGH1+auYumQvKIJ";
+ sout << "mWY1hPH9tLSAB5t6vh6YUg0Fx9G8Wf5Twz8JXsoyRbOs2AAvXcBQ5EPNbNiKz71rzPb0s4Kd3TKJ";
+ sout << "cNZexkD8z0J8U/KVzDHA4kyJU9tKB2ZArGPkIYAWw+u9d9VGdxXgst9WtboxuCOy0+y0XiXez1nq";
+ sout << "8Tib6FgDNCMKD4uk1GvVs697TYCphk9MlHPUevFlgBJuVI2uOjsnGjBtzQOeuWHyXoz8xQhyLtI6";
+ sout << "MnW6c6lCs8REqpAWwuIPF6YLzAZd4uhpOyKFTWI5jus7I2Rkr7RmFDcOCXcnHw7M60Bvcgsa5xy7";
+ sout << "AoMab6/pVJr6EjwK2JEmkLaxPUU/F0WqXpI8roFPbIZ0sfuUzZ7tZmkZtelboyTtuEZTxbKagOWR";
+ sout << "78qwMM9feTEwicGFlvyiYdUGaIwiIckil3oQO9w6bQAaPnygDcykVMFK1fZaZztYQ8AKiEjyr/V5";
+ sout << "dvc04CEShL8uRKDZNY8Y5cDGVHOSR2g/u0t3PDuzMCfmQeKNJQgc0uF3ozXP4xvTvopiK58Y3656";
+ sout << "m2ZUfjDIf/g9gRXmrmte452CbaU4HQzlIblKaEJr153rXTQbZPbSHyHOHWRuczQAtnyP6k8YSC+3";
+ sout << "zqBtSIvq637hKi+7Ov8NgiYhw78ehwHiPtenLDlB/YMCwqqPNmo+8eYGmtOoRgaQecIoYHYWfaUE";
+ sout << "NMtwvnfU5g3JS0j6VZloZxGNmrtuDqWoXclUftoQdsDlE//Q/5+KHzyZjf60WrXx1Ix35UmF9IEI";
+ sout << "jEvtLp8KOoWs077NCkXs0HVwuKxbpZx0v3qVb7HwsgcaoypbjhWaMGYflEKvEbJt+TD4MN8kXDEE";
+ sout << "VHFdCSOzUplHRdcy4aOkQdzfEYvNTgaiTSO2CautGYgS2m+l6Cd8B4K5PZs99xYFa7t80L1d2Zpe";
+ sout << "M9rYCt6RKakTTFDSR94nxxzdyfbA9sVYu6MCD1G0lN+zrEisA6hjqSAqUXsooaW4WTOd3HoNkZeS";
+ sout << "tGviSThwRe0awN4pH5dXfHtxaRBkzOBVw12FB6urObgv+3jcnTzRZWvF14ioEmMuKkHb3Ienbv83";
+ sout << "71BeuTrUarJazVmIzbH4ulyWwLxEHeKL0r4PIEfbUg7tiU44lMnfxmrFFR6FiwuBxrUvyv/yEmiP";
+ sout << "AbMnRjX+4MVrIHr001qxVFOaeK5FehfvmWotuUW63TvubKCRp++5uQDUh/LeMp5ZSX/RRVpEKEwA";
+ sout << "It9wUHj/cafFqSwt603pOOtAC+xL0ozsVF1kYyl9vBEow6mBe66WBwekd+DjlzLbehk5oIXSaYsp";
+ sout << "zu0iBnukS7YoO/2Ho5JMdkJuHf5AtuC1bDt51FL9McuIHboobI3/K0YrpwPmFORwSUSroXle6XK8";
+ sout << "LeX4Fcp6YaicZoFYpLVbhVvrJSW3F4zBERbSqBrHBMxoXMblajhf58RcPKwZPvPnSL+yB5V0VP5X";
+ sout << "RJtdo0ir42WND1VCIfVqCCrSg/m+/R2sffd/CzqQ/JuJnwPtKrAvuaZ8zbmq9UkK0WAWY9Wql+gy";
+ sout << "AYQNgM6Nd9TNJ7LEyEJQl3nK2qjg67rBCizIdwRcPFSWoWE9DjVz5eFPIJjPG//dpt296BsBeW7U";
+ sout << "NhV2Ig9EstFK3+/GC3annnXsT7OQt7QCnx8BvzbVHiFU0n0yikhA1uU9iYm5qsVwdJ23NWimhTeo";
+ sout << "XOWG/MVCrPHpV2qs/4PljiPW2eVdzodj+nfDPVwkHwmFm5Y6TGDRBfJWd+ZiXMkrDa2CEMrMvzEL";
+ sout << "/DymDMQfMI9HgMiLWmiXS4DjcRL1Fp7DTHcQa9swLcZwbyE8r14L+5xXzDySF/EKJ9LwmldUaN+t";
+ sout << "qjmQ7DecED+MkTHMYXnID/Gl3a1/wnrfxa9IuRAwHiIhuwb8siPJT1S2mncAqk7YwUyaY0hl6TB1";
+ sout << "GJqoaPTaeFqzTD1zuz5N3cFxx1O0R7OBzHuBtV/JXrZlyzNKZhUaopWCyQ64+RGa/K/lQcyJNVo9";
+ sout << "lErihKgf/LNzd7f94T3CmYMRUJZanOLNUWcOg2iJLGR3hy1pnlmc34IgQ5SHZisUBy2jsdK0xoKA";
+ sout << "Kb/cPiaZj6ab/4sq+Owh5I66r7FjACv3uILtt69Fz1o4yaOb8z1ZSvFei4CPrikpxYXltEj+aPZq";
+ sout << "pAyL3+qrlJ7eOYE2svEpEAiYTztUAn9/ukCZZ71IafGHnhPggc+eizHPuoqfdUa1NWeM7H/XcFHH";
+ sout << "b5PQ2mQPCY1gD4SrBqPS7I1QvfMuOSRba9YFbJAecSC9bPHlQQ+c/rma+R3yhyLyXN2j2xzHpgJr";
+ sout << "oT2cEn9yT2fHRiZIWqMAJa0KRJgRN0hLjhGQ1VKN4WZfKy8uyqCrAbzBqR0eYcDjkpvD/syBg81K";
+ sout << "8acv+5OjBrHgCofewj2DFs0lC3mMlo8qe+GNjxu/CZufpnNabIY4ggwBxdMWZvZQ7j27VwMmvkdd";
+ sout << "Uei0foug3YA/Gy/XUTRLhuyYKy9dqdv6VTSiek0aFQyM4kz4KqKlAbGixu17SFgvZ33fP6ZMZ5Pd";
+ sout << "eWC+Z1vRafLRpsABf1DJujXUEUZ8HJfWBuuaMkJkxFApcPwctu4SMpZDjnoc/7gblzc0yLr8BRxD";
+ sout << "gygNRNu4uVAT3sbaRzYOvF3tq7wr+R7qS2YmPEzhoZyfubEHdV48wmE3XD/RnZJ8maanzDjEH/Q9";
+ sout << "yX/a34y6YSlrtNim3R8pI/AddLi86cAcEKfK3V9QXCGFqO7GfQuyZHO75grKpL+6MwRPUWigaLGN";
+ sout << "9JfWrc5tpUu8NPY0ACmztq5p8iNqylVgIYOC9LRIl6TKJQXPSY1k73NpzNNzOJhVCCLY2IaCMUtY";
+ sout << "pbxZs+VLufZmNFhq1D5qxUDfKII7x9ocFYqtwLnjZztYT1HhDTcwmB+5uQWHoXKRvsTgvSfR1l33";
+ sout << "w4mqzDAI4ykfuNbLmymuKaVV2ST6W7pI6xfZd/YD1em6OQYF+F44tj2G846Ry1IsdVK8AiprcFHA";
+ sout << "QFxuwxcY9eUgjSvhbK4BKkQgFSzx7yDZTtDxXSWrGsPoHUCh0GKU757B4ubLvDV197o/No9EKfGh";
+ sout << "3tlmEwvvwQd7yTgwdbksgVieGZBJqLK5eRe3CqYXpHryPcxydrO+2LPpzUiiHPY5FDRZjnYY+as0";
+ sout << "IC4dtlX5BicqVB2gbHiKHbjLf0pob27d/WqQKKfA/7h2wY/jYyckEiX8g9M6I5QeABzABxyyAMlx";
+ sout << "zoi3bd/RD35ijUysmIl/+qi3GHaYKK9Bfu1SRb9oPgmFYKKKRrYxm/cypmJXp+GUjlNokWN1u7OM";
+ sout << "Nm6oqRRuqYWiCwZge4HHVQkQs6BuH/Nqvd3Bqkt9U1eH4TnkS5MhnYVdHt0hxmG6Py5A446SERn1";
+ sout << "gJRsUDXk/kRkird01o1UqCrlhwX9WG1cjY2I9nFsokudgnYryxn+d+hVdM4E4MO5ZHmOXicEn+hZ";
+ sout << "N5eyfU7B75rBT/yZzzot1k+B07BE1x8K7N+S5JX07utAi8/htK/a4vxKhJiyx7p8etztBI06LKss";
+ sout << "grLKdJtqZvX3kwFEIXqpJn5W2bijJNeQYnJxfFY+5D/k5POGuBYjf9lw7ilnIkbEDkaMWf9Mrt3W";
+ sout << "A30zlSEMksIByxqL5/VFx7oM1oSx872AYwXDNNnfITgNwKBwniRHuU68SjWBO37CYYmDhcB7Ug/M";
+ sout << "FMX7oBgPrld/Lg/ut/xbWwFOiM1M6L0TmT5HX6RcFSdKg3Sda6adAkSK3Ux2HTZAUVaK8zG2Mm1e";
+ sout << "W+/4SSqK29MAZUBK+oEPrNB7An6NgWZ0TRu8sGeMcZCm2Zb+3+ZVUbsTFoG5pFhzzvMGr9fAsh3Z";
+ sout << "4Ngvc76mkiqT7xDZw72BUnrPz+eO+RGqMG+oGWJlXd5PYD4XIkW25kBfOr+iK/gCFUsfZbFoHEFa";
+ sout << "M9EFExjXiZYOz4iuSZnpeChRQ5wbl/56iTGSMHpB3KKU9Rfyv4dIDGS0KmnSH7Q7mdn2lXerjDmF";
+ sout << "zlc7IbX35O0kHeSXaH08ZFf9vBweKQvEW6Qxs59FWPuS47FpXRAPFClwYskLxByN2ux9kxHprAFr";
+ sout << "zuNKRyilr/NlCEJn5SiYU2/fAif/YrNNiyxXx7sWwQW2mf9Emqzsrkb+eCqAPs1NmwuHO8JXIyKO";
+ sout << "0EbIUaVuENKa+DLtcrfI9HlxhrF8vd8m8s4ZeBHWr3jkbVdcjX9mmtcSrHyHXcWVImc8CJOiR7jw";
+ sout << "0HTAydarj2G229/7kLX6RncrRqY6/Z2e0YEbfTt8RDnwriUmKXuf7VUMljP+tWx8aXMlbm47WMaY";
+ sout << "nnjZ4wE7UTgrYUce7ehuGDBEM4VaKYrp4n782gXFdo9VhUZ7JEaveARi8Di1SI5MuTQ8N2hUfPKn";
+ sout << "JbK8mEjSmcN1sOZBOXxxWX+e9RRsO5t5ujCSy+UBr4gaqMHAlwxtBmif+kl7s9o+UjHOlWLn7V74";
+ sout << "PiLxVUzCTHz5A0rX7MyGXMBai8BT5XjmazTZz8YPIiq/ZmnyULon+uTrmBdivNEjDq4M476YfAma";
+ sout << "rcqxp/picZEanq8yctrujPHilXH1WuXhkPM+gn9Gkjp/sGfA13JiFeXZO9GJHkqLwBKApqS1HJIf";
+ sout << "bhomnqcdp6+hM/IEaNS3dVZ1aTSefIM/OTR2fNVbmzwfENzVke3dCJvw5B8zTsgEXeKuRE58pYaQ";
+ sout << "H+ffv486NoO8gXKrUF47ptIj/Tt5apc5Pt1AVFsuaxvry3UrUmY7MLbfM4xB29ah1PY9iYx/OssX";
+ sout << "DGvnDskGHiQJlz1s3saJl9qm3BZr3//6D1/CgGvXw0DffJYL9yIqlzoBRKAue/lQbWagNArmsccI";
+ sout << "WxVk2TqYgTHwhl40fWSGIuY1kfj6FEAUxBBxGul8y9Cv3e/hMnE0gDQz2tZs7EGqP8WqKPF4BpDh";
+ sout << "Ep+6vvj67y2Hf36GALRYNLlhCB5+HygzOW4jwtQvCHxbmtPWvQkRQ8iQ7XlZH0OOATz1FsccBZXr";
+ sout << "3+O26cf7zcoZIE28MIQHIJqsDFHCJVr8JVcYITcv46R1z9ZgQQ+z7KK0M6FY6SOgCyu5kyQR6YJd";
+ sout << "HEYkNp9J4N21iRr08DUYycJDtrf7xoyYfj2L5QjpPVEKrEhIkjb1kr8uZp/KfCjCkedPBrMsPuOJ";
+ sout << "28Acpne2exkkKaunrJZr64Axfgl5t0OIBP79Cqy8PRnMYQTtqHf/pxaHNyrYaXO9vjtTVoV++kaM";
+ sout << "svy9Ol/R4LxEXrgtlSYhKF4o6iCJUSiLhE6j/xzyOtRrCAHZ56WCJ7YTm0oa22TtH0DvoFMY11Tr";
+ sout << "SZ04ig5Tl8At1jXKymUSnK7EXANVHZmZ+01xTE4efnFIf33N4uWK+c/hEqmS6KCVjuHT9FwTKEIm";
+ sout << "4cKum8uhL11rodikW09dYfIyQV9yO0k6EJpeUSQMhPBqbgTWHnVQoHoot+c1uBoOdK/+bRuz5vip";
+ sout << "l6+0nmpZZoO+OjdEap1pQqpcTEhmfBuUGXCibggZhHEvHvFGQo5an4N48OQA2B8CccDHMJIDP9+j";
+ sout << "0JfmziBnF+ZOLfKLwuuE4uZg87iSWFkwSynsWwoUQ1Cy3u/URW620WLmkx0GDoGPkcsxJ0LVu9dh";
+ sout << "OrWWaFesEANPRMCKuuU55hs913KzoOKdgZPzM8dheJZaZB15wi+u/RTm+obSWZVwibTDyLPQ55mz";
+ sout << "FSv7Lyp1EDFvxyel+7osJa5ifhLrU5f6CAcKwS5t2IwaZaBxrVaFgX5lQmifY1Gd2new1mkfWYmb";
+ sout << "n4AaURVZOBC4U1Dx65ch0PNeEYxk0DLAGOsvdBDbWbFMNc0LiGEF6GiCMBYVsw/cYnjY07cEZg5N";
+ sout << "M9RxGfLfVlyy4MW9ek5ov4+NVLV+vaosZvA9gP92vaiS7jBj8qCb2uYIK1mfrjHcOUFqcxgzltBR";
+ sout << "eSU3ewoVRJIaVreWX4RSXTomL1GOyfcFQ7gZghSJLKlwgkK6+yvqig0jyuvmiRFJGDYOxQtDykwM";
+ sout << "EKOOZrckRdktXs6aFQOla+IRiN3XVUfY2wvx3egfOYPKrJS2HBEdXeIelB7LK6vRx8nwuFAAwZWQ";
+ sout << "iB4OvgM8U9N2H1wa0xXWc1897unEm1KxAP1iS44BmZ0PPASPa1RrltkwWMwRtMp9fD08mwLdGe88";
+ sout << "fsGIrnwHA0HIShoZaBmCSvX0yfaJ34nchnMukP0B/8v7f2j0t1F6hr5qJNSV3cTNcc2QToBy03ro";
+ sout << "GphkM0UhwwoYTlFppF/8LV4QmtsNK9wuziTmv8ghXUgvFRTtVcBMlHY4SFcRKoagFF+B1UfCuVZt";
+ sout << "G8fNilXM8zc2/eP3g5FTU3YXRgw2CAheUNiekne53EE4DmDJ1f6n4OyjZFP8PZ85Nzkkco9FW/8/";
+ sout << "31JIClFqSE1VGbDXleZMOrdfoknF55krQB7v6Wsu82gzQU2AvvhbuFlPodDKH1xUuUT7gERKv4ka";
+ sout << "fdS1xf/PRafedvPX7k8fFZQerqvvUlgO6PaEAxK5MqZqHiUgUh9A/pqBZNBl1kQLP9+G70i1CY9x";
+ sout << "tTETsJJTgYz/HP6yXsqeFm52yO24myKHMURYBmOZU6AFyqFCsOdJG/GA5otaO9MG0A3hPdC3Py+2";
+ sout << "GWbnl11yPDVImktL+LYbx2oolBWZrelLRuKAtDp4p/Svt5R6fOvYM9XxZnnpR/MNaTt7I8iEZSQe";
+ sout << "w57IL33ZfHgKlxEr/ouJgT3vlxqXsuXnP49ytEVKGja0JAOzlShLYqB3GIY9Gb3EtN0h/id1jqQm";
+ sout << "QwR8y5q54W8pdYflhgi92YWYroUUIFRxheAgPwUSqrOUNiN6xpSwr2usNY1zvGGsRqFKuhKgh7P7";
+ sout << "cvGc2sOj3izPgl4NlR2DaoFTbXd6uDM3IrYGSCFJdbMlonX5cvO91ySJpKwPOnqjAbzBjnZCfepb";
+ sout << "4px8fSH5I6LJVU1R+sGRCoewaHTFlsnaCQrsy9BTGMIWwAKCAYCbWN0T4ItZaXhJWarghYGsT57P";
+ sout << "MagSlMpKeiToHWnWhPhtxkhm1ZVRuYekcpwrlWwsHBV6O5pEML3wmWZDWJNZWh/GkmhPhf0aF78F";
+ sout << "2lOVKBWLYq+4Xt3lVNvqCr7m2rQNH6KzZUNfeoIJH48SgirJQiNNE2iQOsReTQTbCW87NCN4GKGh";
+ sout << "Vf5A2JU03N+5fT+dorN/LTQmeKddK0MR2nshO1m0kZSQ9TDUE6Da7ITtIjGpKK1QqohIx9BEMoML";
+ sout << "t3BcLTNkK1SaaYRE9Fm8AZXr4z9AILXUKktv5bytIRBZncs8078FdpF+O2JWEzELgG2s/FTHTyjU";
+ sout << "NU3A5q6+8LoeSXuoqOZz2QwEMhKwkz7AlujJU/CJ3+uZBAUBythODaVBlaZM6/f8dSY3489Xiu9z";
+ sout << "L917CSsYpq4JQWq7I7pOkPjy0t7Y7QmKPITQv4QQmF3P6SJXiainjWDQlZx59RTzg2Z2MDZ0dcWx";
+ sout << "f7f44IkngOVEUHi3iwrSygMTySBYdDxei/kBt7oXAAkOMLucZ24VE26nLD1hMAwRz83ENc0sOZAz";
+ sout << "F7He/e0H2I53NAYlk0s55wntUcdp3sHh3UOcGyBGcRUipg6NWy/LrWzxWqJdo1DsBpV4iazYLRfB";
+ sout << "JdKobsmPGmWSyABV8tea/IuVgqUm5RrWfBUa6Kw2TqVuos2PsGqsRi4cHKl3XQ0GJjbV7qV7Nacn";
+ sout << "mRVqIAetzUUZkuyp6Q8OMLX36zwwrlP0rPJaqk24VBEnx2FKjaR/LkP5lUFenOHJuPXSh89pYWA7";
+ sout << "7/y1i2ejWuMrpElkJ7qfpTGD/A5ZkNWCyla2VXOXUVjK2t+cBYByjEwlMCwr+tXUXQYVMPf+UzzT";
+ sout << "EVv45u14EhFg5XOu7urPZWyO4EbejgEvyyx0c6Zgt7RaQgXn+akwMHI3oDn2NcZYGo8Vwg/fSs6j";
+ sout << "Ggzyse3ShPpUr6qb9TT3Vo6hXegWwd6tDPsqDlc62JhQ0MYdw2A1x8aGK1iUCz4pqsWmXWhGPO1g";
+ sout << "rXyqBaIrVHDMim3cZO2LY/TOwIJZGuHNvVnQQ0TDYAQEgyej9mVLuWaFuCE/JZlw/8U2VngPVq0w";
+ sout << "EePssUQnewMHYTfceXOETEBkU0xl4WLdwaJbvIXG+pPcVzYW8z/D3yh2LU7KHbklffbjWhYyAm3b";
+ sout << "nTTR+YTZoF5PwWA9DmsxIbLQJn3Ejss6pOj1YSqq1cD2r1+TZjWefNmH/nncwjU0T6e/iBxkgQLc";
+ sout << "NXouKJEemknL/HOvd5sUS6YYbIGg7Pa3ur3Qn0quKT8QSHRPTxZeTP47rSp+Koibf/XfZOdG5RzB";
+ sout << "U2Kbz2vi24JkfzbNCG3tjNvzPJVcZrMde1TNVNQVFbXCzMZgO0Za7o3IRyxy05rmieBcGryGX1ln";
+ sout << "hnhwSzxSSxnuJ3szdlfw1j68i8LsX+iaZHonWIw21afDPwXaoflD4cZIZI2uPLAKeBoOYMxKyFGj";
+ sout << "2KhIVlr+poDcu9gZBMF3CwwxSRZpwVhpUllBfOJDJTLbM3UOyNRMs/U2T6TdxH1fRbgeQcfQ4CqU";
+ sout << "D196crANUlHv6VL0bU8CFS5f85YwKO4JliKHZhvQRDtcBrwKIXttWPt1f3OONAscmsl+JtVgVt+r";
+ sout << "h1R+X0b/puYAa4tqBvUBiiokN21cQR40Pp5PPLMszGskImO0XM6Fx9spo80xxDLhOXZsV4rJKs3q";
+ sout << "lub2BoFFGa/nbd88uBheoBR8lh5d3sWciz8Y6eMMlEASpLmXFg/nqq1jzTlC7lGmAhyqSTAL2dHJ";
+ sout << "+8ACp+IgMI8v7DR1TEp0qgwrVC7B6+8bozJlpH73HdqQYkyjAnVcOBTn6KO1pjAJ+YcXqe2ioiBt";
+ sout << "OMAlNQVAJZgxpEnwwd16yXDmK+Fp7v8aGEx1EECVaQYW7ZVvKRC9249ioFaEgFjpo+8veoWFd9G3";
+ sout << "1E4aBC8rOXrrsI6U3QqWbYlusqvogtKGwT2pMYSidnPneM9iaJ5ixPSlB7VPp4MwdZvPIpOOR8Hq";
+ sout << "cFqQfZAa7sDq5idYKPUrRo5XD6szBNbb9Oob6y3G9RQndEEt2363luFr6Nv80r8yuyaUUlRcmPmU";
+ sout << "/VCExotSxcx6B4uWyd6UoA1wZv1hY+OCUL0ahArlLzkR2ZxowfeDCn5eoawdQnazyx9wSAEmqGTT";
+ sout << "lrT+MyINZeyxp9DTtvjFwVw1aDVT87PF7E4Yw2H5V/55elbLzn+Rlr4AgU7W+RnuwOk52LCKrSud";
+ sout << "JTAq/vix0svgRHAs3E8wsLzBsFXc9SZ6/AWptBKuix0lkoDObM8rLJMjxpQPmIyVvB3jWXITDn+n";
+ sout << "8ou5GgmZUDeGu3ZiOWbtPhO90akFKl1XDQJ+k9DGrlQd0dbsDa2lTgPbL+3wOD9fdjFzqfmGA2vI";
+ sout << "kg75VqCCEOUXbQ8N2U2IALiOEuYblJk7WVsbmOwmEw7T57QVJvivVJ9Z36TJeWfpzIiHCeReso1y";
+ sout << "3RGCo1qlsgfwtz5e3/ycu5aEq+Pg1W4EtSpeuIEZyH7zQaUILmAFP23JDmHyZNh9ewxSRfZGETHg";
+ sout << "KKN32MMx6taVzf/8+RQhxI3JJ+PSk0vrvUEX4L/2Bri8meJbN5UtWuRN5gha/O8jPDixa4XrMAiq";
+ sout << "etOumapH/QhEJBy5NHmcjLH7XpZ72qxTIL6VS6m2GjFUyekOdKZOHDyEKA0FR6wgEnoki7gS0L6N";
+ sout << "5plyHCOmd1aBWEXf5+P8EJeh1nmu6AXtDful0kn8G9nNnWrbC+iIj2QQ+XZPxdZIGuJKd6MrXf0y";
+ sout << "zW5dznFYvB8R+LTy2SW4WFvMWzJAlggi1N7w25/rvtnt57E+I5TimCjAKJ5Vcjj/R0DOqX4BNnOK";
+ sout << "MqeMDogP+DE9NesTswgfFngZBjomZsx8f2Wdzzi/UW0xBZa3yPk+CV+wwTlWNxgBKFCYFC8GlJ41";
+ sout << "MQt+TPDsDZJuBdbexuS/PA/zzOE/wZPVXgzLZFmTKsZAfz9894HEHalRlKLgTlvgW65XDihdoh71";
+ sout << "HDwdA9Knb9r2qH0dwsOTpoU0uABJkND+V1Ezr3oi35dJ2zw2gR0omEnVXZM9dW8XIp8ln8zegt+S";
+ sout << "dyMbNDzX9RClWSIVIGuRYGCahBqXMCJG/AAKvkXM87mZiQsIz6uLvJpkeGiG1ah7kTSjVFIKSYES";
+ sout << "73oz5l7AIAIReUzdvMLk5ZbmqdW4JHBR5XBA3rTpAyNEunRi2Ddy1eBt4i8I7GlXaaZ7sVS028ze";
+ sout << "VkY4WR/6FJKO8ccbofsJb8BiM8UZdPmkRdEKGgFv13dkq6inmo3S52P5S50mYapa1YKvWCvMHaov";
+ sout << "z8BOe6WqptQRGc+vZ4v1vgq40yaWBa9pSTyvntLUKCkQX0qi3ifKhRykXP6SBylckIHs8DfuJqXP";
+ sout << "X06djSx/F7IyvBQg90SLa1h5hYRTAchYX1ZgvG7LlSvPp7i+sWCx5D7KpTbF6O7AU8kMPpaRm2wE";
+ sout << "RvZ4DCgXdh9Lgw69wRlMfWJEJcV3mOIGDRoBZToyJqgHMDJigCet6NIsD2xvvza5JWPf3DURWRHz";
+ sout << "wdVqLXlUivD6/9r290ZcIbdWFMbz2aY9x/ojLrJmAGX/kySItIXsJlDbcvquJWNc6UYrpxurQEku";
+ sout << "ejm7E6HjXUPmi9gh2kRkw29A5xIFaugoph9dZSwyCEx2BbqY21hwQ5elfdcvdJD/fe7iiWmfrtxQ";
+ sout << "BqVebUOrY5+/vEvMf7EupYK8cxtCb3mvViCR0rGokTwuw7NVuMH6DEJ/zXB7BO/2bPXnr48Q2pmM";
+ sout << "pThaXSCJ5Ta1SZqGZRG5pBsUWg8MCQBc/KmDWv5csQlqDJOLvEULRsKdxxcm2wBthRXx97JVHMwe";
+ sout << "Tlb1TkjoSIQonMsvU+dRJCW3qhgbR4i0t7wGfbg9YdXG8LiKJHvkpu/IPKHZtaMxOp3O6kl7Lcb+";
+ sout << "Zr412Aanu8BN5GP+rW8/C+TjEBG/WeR23iZr6SByC7TCzZwJO25J6mtC1nFxdAUizSmxUfwDwvdE";
+ sout << "5WyGMq8TTIHZ9KvKMw/jODKXcviOmLHY9CsPtYXbChnF/fWgQx/ykshrLmtSqFEaOt4YCiPFKQHh";
+ sout << "/IjsYqcs0UQeGTk/6tnp8Cewe8A1DAaBEI1RTllobQKLXEUBWAV0VcrVXILhemvApUW0uqu6dSqg";
+ sout << "cfv9uvu1BblNEfNmfTcUssCnDQ5c5vFjEB1KESkrBTL+p6bEn2b1bbxQrhiqnVO6mi9anbtBgBU4";
+ sout << "lTjSG4KMnsD63xnrLapoU7ReEZxGjsQgm1Af8/lewJajhNOZi3FmDgkb8Lhq2rz8OYy6pwXCEM07";
+ sout << "JwMzOWFBZDUA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'feats2.dat'
+ const std::string get_decoded_string_fhog_grayscale()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'fhog.feats' we want to decode and return.
+ sout << "AXE4jk6QRzUCNtyVtAwaCkqQk/DMJKm1t48e6FXsZJ/Zfw0Utm0AVwNlJU7O2+XVsftV/yE3zO4S";
+ sout << "YG9DY33gIQWV6sw7AGva02FE7zlkRbW3IOyZeG5LYs2r4vGYmZvZuYQJ8CVmkJpGrYioIuqWLyoD";
+ sout << "IfAFmt0z7HREwTsGxP9BG6UIXb95jQCGocuoO+MxQNpq3qtMAr+C2xnN3Na+ITiKYUX+zEP+FrWJ";
+ sout << "uqsYqhyN4H+1rlYaLhrj0nhUs4+Zp+fn3LZecDsZGeq6KDEIM68rOZAY8WWA/o4x30cQ0P299z2m";
+ sout << "7Pl1vRxpN3MdgzkS2zlOpGnrexfA3TK07nFcTK4Lc97t/75JULM85uEo6yYUBjyeY7MPYTHaM0tH";
+ sout << "+TGdAqQBf6TbnYya1MK0BM0nVLp4TkMeRZj49erCtZIVaFWOYmufs3LywApX8AQZAsBhU0QA5cQu";
+ sout << "G6HKcdTiqxUxbjOIueAxWBbDbGIwUj8URJVh/WVW4TAXnzZu48JvRzif++Xeswz6O8+PXdPVp4BG";
+ sout << "TaJv9bBwqV1y7lM/T9FxktP3YIcnbW+ezBUh/logAVeLzCAbYaqChRMUkBizO54DEiA+NBDV5ndu";
+ sout << "Iax8NL9DFp1zmO8gdJBZZLgzf6K8ZbRwwKNlLEkVEZGKei3aVIwo7ed0tjG4aEIodgAuoXD8oa4p";
+ sout << "t7v3sYbn/Wz94AyvAwFV0M1mAMaXqXhpfJSb5CJsSOyKqUtBBhpS58V6Gsq9M9weuwVHee5JI8PI";
+ sout << "bY8WlfQNSVo4JcToIdKPhehz4Ywad3gdX8DhSd+WBT2aMomB71jr3m9CcC9da3oxXjh/jVaN/boQ";
+ sout << "oY5UqFGZ7T1b42x79PiIoOY339sKYR9vr7WPDHTouVtas/3hJBkBKwXqKmBbRPU2N4DTgpt+1zFq";
+ sout << "nig8prSmjiAh9ofp20Pbl5FbYoIdoOfLgYIluAOuT0XhKxLAwWoNGLBA37bGagSxzXnzlDPpFDZ2";
+ sout << "WVdrBwehuD4WhBRxMiTZkzJwhzRkyvsBnO/JqFxXAD5b2CvH9p+3P/czeRv2hRjNZaKHEfdMtIWn";
+ sout << "lPR1zWmZ3LTMBR8R2kmAUt50Sikzj2TpzLlBhemvUVAcv0HEdiILnp3eilOx/ee7GzEroOMHuOGe";
+ sout << "moxWRB1EHj6h2Ya6TVmMb47SzMwKo7mf1PtZ5JLDk61k6Bv9M0WCDiwxlBwC/PjTxHDsTOWRrSJG";
+ sout << "roZ/csy/RQF//nA0f8uKC9HWmGTOsrA6VAVteglGi9k6E9etxvmGCQD2ud2tyz2HSnklJAMutXd+";
+ sout << "flvArk1avpvwXGoNgT3EsEicUGOg93vXl1A873vbHtwnycMzAa8NYjZGW/GPROg8yKkbpoxK+Zel";
+ sout << "+VP14SvqQDKEYDevoGuQWhsQHhW7bYeSa0bSamm9DFzqK2Ld4/aU7JUHYdaAvAJPIHe3F/N8NcJ1";
+ sout << "RE7VoQlrVZXBy33Ly6wwiv1rm8yK4sMdBnNXcxzCyG1NkPsmC/16A08Dy8RV7nJhd1iOJ3gy9BUG";
+ sout << "Y4ofrw/XzITd6vL8mIJW2jyrVzW2OPqMWy6IO/7CHz7d3k5+7KTKWkksTRsuAYWDgZtU5Umph0kN";
+ sout << "/RRnuBRIe/6KYj5/thAY7yLmGFoy8jhDmIikSP/l/pJhUjjgKk1R+oYHGDJ4FSD2KNfRDQ/xHZb2";
+ sout << "++ObYSCHAuhB1HoJBwPcmxkFtxjS047iGXnXo+2nchItcrtifnQTeI0qyV6conJnskM4jfhnKsj7";
+ sout << "A7y4owQjOvkDwu0GueuDOTo/9mW2NjAaiHCp8yarSV1dPI1b/XR231+p5IEp1zWzqFr+O3pDAL8F";
+ sout << "+eU/vj8taOxtT0CcT0gW2rRr8oTigRWCUGP1hxDBdhJtAa7Q+CuFvmtpjm/4fmbroJ5ZVNA/71FN";
+ sout << "yL9DKBmiVTwljf+yRpTGhpMG+xky30zS2R150N93YDDXVT9StjKaOrtLap+9w7BtCvXGdPiyR2QN";
+ sout << "g7gqzrz31poE36fAwM4san1jbbTC+eYcTErQ2wXCQkVne3kbVOwErB0ayl7rNqkw5b3gAME4+DnN";
+ sout << "IM13kdtxY2WeBND96g2enTmFizxHWFzW2asW4/XqGt8EVmzTB4XM/Ytd+XXaNadRUHh44wiNIhPp";
+ sout << "txmIbtkSesyQ4YYSmUYEEcYkZwfcRxHUAGCcnQbSKGLq+N5IcyiLVwhXDfK4fqxH7Oi5DmOwAVmV";
+ sout << "kdDQ6GP7wDcrkAQS9s6fL04bNf7rodacAPdX4HwIKa2YykdXWOiQhp6BRxjUG44AVV5fiTGP1WVn";
+ sout << "MJKXYzSyY5oN3ADAT5em+cIYYVPnsnZBuUzAjHAw3WYj4VlRIrmP+oPdKPFncEyfTn1G7DbmyaPd";
+ sout << "TL3DdB8EDImfZ+A2UMr2i7jynH/fFXzsi2PM9cFKxCsEqG2LGr7KZDP2FFEVvWwIDnvUClj9nrHd";
+ sout << "zyioiild+DW6PoYvqFQX4LUf9Jr09RiJuIvz35I/15pBQbSnTxD1RFJwR92k0yA3514jYBAtfLmf";
+ sout << "/1refalgDjOyD1ntgp3INSYSbpR3TqNIynWwSJGxq9I5aFJypV3Nq8w2Rn3/kld9nqDaTG79ns13";
+ sout << "dwfPOiIZyfdrcxkYOtS4+iijs+YE3JhRKVWMf6ub0cVXttJGORgpzgkSoDWNVZ6O3hVydCziYzoS";
+ sout << "RgHUH2Oas+aZK5IF3Z3aaRYc0A+wy/PEx84LRHsAYdPUQJr+bGseC4wScP1Dyc3xpeJSR1V9t2tc";
+ sout << "AMbhtwcKGX4j5nGnhxSevKzbCDteEnB23TnSQwbuWYmzhB77V3jpu1Cm2h7FcKGM8vPbAQeRBr1b";
+ sout << "+RY95s8hsZGi/USiu4TQ/wHZfEGYcHuVBI920d84pPRVe5EJ8+FzZj+Qy7JwriqLN+7WyUCFdwxn";
+ sout << "4B+WXHTe2epBJMQzlE25kKDHKb2lDDv5HzVUlrZK4rEShkPNv8SB3U9u20GTLlHJbM8RDvPkNjmu";
+ sout << "U4ZLDSJylDaRHWqgchgMnmX08aprI/o1HbgZ5aiByXAUoHSSGHanyYmW/S0LOW7YZH+jOgxzW68U";
+ sout << "lheNnX6Z26RdMe5Xtkd5jx2jXgIT7HADCN8wWdZEVT7FvXRoxtO3nz30cbOim/+IvB2lt//OcTJU";
+ sout << "BwkyweuhJtiXZV1yY+X2z3dWBjBXVFWidPMMWjTUIwalo9A91RL6ZS25kuBXKm/BV6X/zAHkY+jB";
+ sout << "A7qJGLPe5h6SO3GKPSLv0wE+9G6VIhH2TPLfAd+PpqM+xuUNlWlo5JR6ItOgGHuFdjZeDlISYOID";
+ sout << "CJn4zfaU6M4Dmw+m1wsIiyQy4Cw0DcZYUjwStEfLJu7BRyFI55wEbXJr36PRWpvB2wzkI4z1u2yM";
+ sout << "7bFvguH0teVxtMwQy9S4lT9JlfX8QEqNRsuxzQprPJqj0ie2cFBrvJ2E0FtcuarikbcYAmC1iYdz";
+ sout << "Tc2CRVe60taBHR92AsvuTv5OhCIssVc5W+Lbv5e3S0vLwUprWPhfbth/zNevKNiIeogTJOOTQexo";
+ sout << "4EhJ2pf8QSz+ixlWi5ffYKLKSx03wYBZ3fNIhJDP+Y0mz2pd0IyGhxDrMAxYHhPjoEGCfYX6fYz2";
+ sout << "XQGQ1eTGvNEATqz2v9HxYBicYAcW4mjEcTqsMHMYcboVgqiLxOD+jOHJqNrw+eiuC2Ucu6muQ2es";
+ sout << "Zhbo1UNZO/J9XbsZIKa6t0PX0CSczL02Dti8r5rCQ7yeUgwZFzy5oJSShAuUvBDILXBeDhnuU4W3";
+ sout << "eooTvk0lkb8sNAgkUlvHXgdN9jkfQuHqX453IOffB9TnxopLuxS3R6uuRa5ED2pWkL0L5ceftRv3";
+ sout << "0T/sMlHmTNoQhcJSoINMf7f+WmHa95iD1HolQNFJaDEUzLV1DfD2mntncFZoJ4zlr+b/94qp8s/n";
+ sout << "NM0CyoI4ansElbkkjUs7QQp+43pGXu8sRgg7tva7/3vUwLp3Md+8WkX+uppIhj25nlwfkD6IHxdk";
+ sout << "uFNPNgdIikcM463W56CKek7LufT8wQG13mJOMJCObvhW/EU/yb/88hegDxeVJKrWHRyebllsFbZ9";
+ sout << "UeEWkJPKdA+YHilgfSX0aFgofCDmmh1k/cO2RPqIGIqZSfT//6lFEDdzPcCLp/kd4oq7qh5Ko7lo";
+ sout << "bg4N9n6F90oJI+JFJRLy6bEhZC/7obQlP6FFUUqSpizj18+zTLPUaDpt3/eg9QeXUeKcNlkHYt23";
+ sout << "AqzS6PqB9t4bUr+E81QUxpegh5V0M4OPQJbFQ4/rna/AwZVDmGRjJLKzkWUwpnIv7f2+Gl+91ly2";
+ sout << "ve9Ube6Jf9h9M/j7kppHRARY4QbXauYAcSp3wRaIueJWx36dMsqgzuYfwHIyuhGS3CbC1EHUZwM5";
+ sout << "hgCHFvPGHLM7T7p4w6k0e7n8RZJsDybyDAydW8SfSI6LIeX2st4LOHrdLrpUNPyE9JyuIMCSnPaT";
+ sout << "1XnjdWev4jjMeXxa1XWbzy0AQpeQ0UYA4OpqRSu1DyoPWf4IA+bf012m5Im/1BiGF972Ie/6CNvS";
+ sout << "niYOfmxxLWTihdMtslxiy53y2MW2iDzLxX5nvUSyKXfO1XUlDJGTd/zfUywZcY8cgI/f1IPzr2Lk";
+ sout << "3/YKqkJ2De4IpbixDEkbgAroIaoOZXEGD+yNzXVcyISIhMiKHJERdwVxp4j3duc3M04wrwJMZvtK";
+ sout << "lk5jnn+ILhTxcpboq8gge4CbWziUUj9du6mklxZaeYBBpB6CzlVLitesXieA/zl2JnWzRMD7Ho3r";
+ sout << "LLSqxh9UeFgjFtb3ilKSHNZH6x1DS0hFhEIA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ };
+
+ fhog_tester a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/test/filtering.cpp b/ml/dlib/dlib/test/filtering.cpp
new file mode 100644
index 000000000..61dc88440
--- /dev/null
+++ b/ml/dlib/dlib/test/filtering.cpp
@@ -0,0 +1,166 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/filtering.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/matrix.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.filtering");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename filter_type>
+ double test_filter (
+ filter_type kf,
+ int size
+ )
+ {
+ // This test has a point moving in a circle around the origin. The point
+ // also gets a random bump in a random direction at each time step.
+
+ running_stats<double> rs;
+
+ dlib::rand rnd;
+ int count = 0;
+ const dlib::vector<double,3> z(0,0,1);
+ dlib::vector<double,2> p(10,10), temp;
+ for (int i = 0; i < size; ++i)
+ {
+ // move the point around in a circle
+ p += z.cross(p).normalize()/0.5;
+ // randomly drop measurements
+ if (rnd.get_random_double() < 0.7 || count < 4)
+ {
+ // make a random bump
+ dlib::vector<double,2> pp;
+ pp.x() = rnd.get_random_gaussian()/3;
+ pp.y() = rnd.get_random_gaussian()/3;
+
+ ++count;
+ kf.update(p+pp);
+ }
+ else
+ {
+ kf.update();
+ dlog << LTRACE << "MISSED MEASUREMENT";
+ }
+ // figure out the next position
+ temp = (p+z.cross(p).normalize()/0.5);
+ const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1)));
+ rs.add(error);
+
+ dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state());
+
+ // test the serialization a few times.
+ if (count < 10)
+ {
+ ostringstream sout;
+ serialize(kf, sout);
+ istringstream sin(sout.str());
+ filter_type temp;
+ deserialize(temp, sin);
+ kf = temp;
+ }
+ }
+
+
+ return rs.mean();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_kalman_filter()
+ {
+ matrix<double,2,2> R;
+ R = 0.3, 0,
+ 0, 0.3;
+
+ // the variables in the state are
+ // x,y, x velocity, y velocity, x acceleration, and y acceleration
+ matrix<double,6,6> A;
+ A = 1, 0, 1, 0, 0, 0,
+ 0, 1, 0, 1, 0, 0,
+ 0, 0, 1, 0, 1, 0,
+ 0, 0, 0, 1, 0, 1,
+ 0, 0, 0, 0, 1, 0,
+ 0, 0, 0, 0, 0, 1;
+
+ // the measurements only tell us the positions
+ matrix<double,2,6> H;
+ H = 1, 0, 0, 0, 0, 0,
+ 0, 1, 0, 0, 0, 0;
+
+
+ kalman_filter<6,2> kf;
+ kf.set_measurement_noise(R);
+ matrix<double> pn = 0.01*identity_matrix<double,6>();
+ kf.set_process_noise(pn);
+ kf.set_observation_model(H);
+ kf.set_transition_model(A);
+
+ DLIB_TEST(equal(kf.get_observation_model() , H));
+ DLIB_TEST(equal(kf.get_transition_model() , A));
+ DLIB_TEST(equal(kf.get_measurement_noise() , R));
+ DLIB_TEST(equal(kf.get_process_noise() , pn));
+ DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn)));
+
+ double kf_error = test_filter(kf, 300);
+
+ dlog << LINFO << "kf error: "<< kf_error;
+ DLIB_TEST_MSG(kf_error < 0.75, kf_error);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_rls_filter()
+ {
+
+ rls_filter rls(10, 0.99, 0.1);
+
+ DLIB_TEST(rls.get_window_size() == 10);
+ DLIB_TEST(rls.get_forget_factor() == 0.99);
+ DLIB_TEST(rls.get_c() == 0.1);
+
+ double rls_error = test_filter(rls, 1000);
+
+ dlog << LINFO << "rls error: "<< rls_error;
+ DLIB_TEST_MSG(rls_error < 0.75, rls_error);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class filtering_tester : public tester
+ {
+ public:
+ filtering_tester (
+ ) :
+ tester ("test_filtering",
+ "Runs tests on the filtering stuff (rls and kalman filters).")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_rls_filter();
+ test_kalman_filter();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp
new file mode 100644
index 000000000..2260e92a1
--- /dev/null
+++ b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp
@@ -0,0 +1,787 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/optimization.h>
+#include <dlib/unordered_pair.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.find_max_factor_graph_nmplp");
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+
+ template <bool fully_connected>
+ class map_problem
+ {
+ /*
+ This is a simple 8 node problem with two cycles in it unless fully_connected is true
+ and then it's a fully connected 8 note graph.
+ */
+
+ public:
+
+ mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
+ map_problem()
+ {
+ for (int i = 0; i < 8; ++i)
+ {
+ for (int j = i; j < 8; ++j)
+ {
+ weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian();
+ }
+ }
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (fully_connected)
+ {
+ if (count < home_node)
+ return count;
+ else
+ return count+1;
+ }
+
+ if (home_node < 4)
+ {
+ if (count == 0)
+ return (home_node + 4 + 1)%4;
+ else if (count == 1)
+ return (home_node + 4 - 1)%4;
+ else
+ return 8; // one past the end
+ }
+ else
+ {
+ if (count == 0)
+ return (home_node + 4 + 1)%4 + 4;
+ else if (count == 1)
+ return (home_node + 4 - 1)%4 + 4;
+ else
+ return 8; // one past the end
+ }
+ }
+
+ unsigned long home_node;
+ unsigned long count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 8;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 8;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.nid;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.node_id();
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 9;
+ temp.count = 8;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 9;
+ temp.count = 8;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+ return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class map_problem_chain
+ {
+ /*
+ This is a chain structured 8 node graph (so no cycles).
+ */
+
+ public:
+
+ mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
+ map_problem_chain()
+ {
+ for (int i = 0; i < 7; ++i)
+ {
+ weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian();
+ }
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (count >= 2)
+ return 8;
+ return nid[count];
+ }
+
+ unsigned long nid[2];
+ unsigned int count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 8;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 8;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ if (it.nid == 0)
+ {
+ temp.nid[0] = it.nid+1;
+ temp.nid[1] = 8;
+ }
+ else if (it.nid == 7)
+ {
+ temp.nid[0] = it.nid-1;
+ temp.nid[1] = 8;
+ }
+ else
+ {
+ temp.nid[0] = it.nid-1;
+ temp.nid[1] = it.nid+1;
+ }
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ const unsigned long nid = it.node_id();
+ neighbor_iterator temp;
+ if (nid == 0)
+ {
+ temp.nid[0] = nid+1;
+ temp.nid[1] = 8;
+ }
+ else if (nid == 7)
+ {
+ temp.nid[0] = nid-1;
+ temp.nid[1] = 8;
+ }
+ else
+ {
+ temp.nid[0] = nid-1;
+ temp.nid[1] = nid+1;
+ }
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.nid[0] = 8;
+ temp.nid[1] = 8;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.nid[0] = 8;
+ temp.nid[1] = 8;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+ return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ class map_problem2
+ {
+ /*
+ This is a simple tree structured graph. In particular, it is a star made
+ up of 6 nodes.
+ */
+ public:
+ matrix<double> numbers;
+
+ map_problem2()
+ {
+ numbers = randm(5,3,rnd);
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (home_node == 6)
+ return 6;
+
+ if (home_node < 5)
+ {
+ // all the nodes are connected to node 5 and nothing else
+ if (count == 0)
+ return 5;
+ else
+ return 6; // the number returned by the end() functions.
+ }
+ else if (count < 5)
+ {
+ return count;
+ }
+ else
+ {
+ return 6;
+ }
+
+ }
+
+ unsigned long home_node;
+ unsigned long count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 6;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 6;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.nid;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.node_id();
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 6;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 6;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 3;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 3;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+
+
+ // basically ignore the other node in this factor. The node we
+ // are ignoring is the center node of this star graph. So we basically
+ // let it always have a value of 1.
+ if (s2 == 1)
+ return numbers(n1,s1) + 1;
+ else
+ return numbers(n1,s1);
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ double find_total_score (
+ const map_problem& prob,
+ const std::vector<unsigned long>& map_assignment
+ )
+ {
+ typedef typename map_problem::node_iterator node_iterator;
+ typedef typename map_problem::neighbor_iterator neighbor_iterator;
+
+ double score = 0;
+ for (node_iterator i = prob.begin(); i != prob.end(); ++i)
+ {
+ const unsigned long id_i = prob.node_id(i);
+ for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
+ {
+ const unsigned long id_j = prob.node_id(j);
+ score += prob.factor_value(i,j, map_assignment[id_i], map_assignment[id_j]);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename map_problem
+ >
+ void brute_force_find_max_factor_graph_nmplp (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment
+ )
+ {
+ std::vector<unsigned long> temp_assignment;
+ temp_assignment.resize(prob.number_of_nodes(),0);
+
+ double best_score = -std::numeric_limits<double>::infinity();
+
+ for (unsigned long i = 0; i < 255; ++i)
+ {
+ temp_assignment[0] = (i&0x01)!=0;
+ temp_assignment[1] = (i&0x02)!=0;
+ temp_assignment[2] = (i&0x04)!=0;
+ temp_assignment[3] = (i&0x08)!=0;
+ temp_assignment[4] = (i&0x10)!=0;
+ temp_assignment[5] = (i&0x20)!=0;
+ temp_assignment[6] = (i&0x40)!=0;
+ temp_assignment[7] = (i&0x80)!=0;
+
+ double score = find_total_score(prob,temp_assignment);
+ if (score > best_score)
+ {
+ best_score = score;
+ map_assignment = temp_assignment;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ void do_test(
+ )
+ {
+ print_spinner();
+ std::vector<unsigned long> map_assignment1, map_assignment2;
+ map_problem prob;
+ find_max_factor_graph_nmplp(prob, map_assignment1, 1000, 1e-8);
+
+ const double score1 = find_total_score(prob, map_assignment1);
+
+ brute_force_find_max_factor_graph_nmplp(prob, map_assignment2);
+ const double score2 = find_total_score(prob, map_assignment2);
+
+ dlog << LINFO << "score NMPLP: " << score1;
+ dlog << LINFO << "score MAP: " << score2;
+
+ DLIB_TEST(std::abs(score1 - score2) < 1e-10);
+ DLIB_TEST(mat(map_assignment1) == mat(map_assignment2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ void do_test2(
+ )
+ {
+ print_spinner();
+ std::vector<unsigned long> map_assignment1, map_assignment2;
+ map_problem prob;
+ find_max_factor_graph_nmplp(prob, map_assignment1, 10, 1e-8);
+
+ const double score1 = find_total_score(prob, map_assignment1);
+
+ map_assignment2.resize(6);
+ map_assignment2[0] = index_of_max(rowm(prob.numbers,0));
+ map_assignment2[1] = index_of_max(rowm(prob.numbers,1));
+ map_assignment2[2] = index_of_max(rowm(prob.numbers,2));
+ map_assignment2[3] = index_of_max(rowm(prob.numbers,3));
+ map_assignment2[4] = index_of_max(rowm(prob.numbers,4));
+ map_assignment2[5] = 1;
+ const double score2 = find_total_score(prob, map_assignment2);
+
+ dlog << LINFO << "score NMPLP: " << score1;
+ dlog << LINFO << "score MAP: " << score2;
+ dlog << LINFO << "MAP assignment: "<< trans(mat(map_assignment1));
+
+ DLIB_TEST(std::abs(score1 - score2) < 1e-10);
+ DLIB_TEST(mat(map_assignment1) == mat(map_assignment2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_find_max_factor_graph_nmplp : public tester
+ {
+ public:
+ test_find_max_factor_graph_nmplp (
+ ) :
+ tester ("test_find_max_factor_graph_nmplp",
+ "Runs tests on the find_max_factor_graph_nmplp routine.")
+ {}
+
+ void perform_test (
+ )
+ {
+ rnd.clear();
+
+ dlog << LINFO << "test on a chain structured graph";
+ for (int i = 0; i < 30; ++i)
+ do_test<map_problem_chain>();
+
+ dlog << LINFO << "test on a 2 cycle graph";
+ for (int i = 0; i < 30; ++i)
+ do_test<map_problem<false> >();
+
+ dlog << LINFO << "test on a fully connected graph";
+ for (int i = 0; i < 5; ++i)
+ do_test<map_problem<true> >();
+
+ dlog << LINFO << "test on a tree structured graph";
+ for (int i = 0; i < 10; ++i)
+ do_test2<map_problem2>();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp b/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp
new file mode 100644
index 000000000..82754aefd
--- /dev/null
+++ b/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp
@@ -0,0 +1,217 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/optimization.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.find_max_factor_graph_viterbi");
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long O,
+ unsigned long NS,
+ unsigned long num_nodes,
+ bool all_negative
+ >
+ class map_problem
+ {
+ public:
+ unsigned long order() const { return O; }
+ unsigned long num_states() const { return NS; }
+
+ map_problem()
+ {
+ data = randm(number_of_nodes(),(long)std::pow(num_states(),(double)order()+1), rnd);
+ if (all_negative)
+ data = -data;
+ }
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return num_nodes;
+ }
+
+ template <
+ typename EXP
+ >
+ double factor_value (
+ unsigned long node_id,
+ const matrix_exp<EXP>& node_states
+ ) const
+ {
+ if (node_states.size() == 1)
+ return data(node_id, node_states(0));
+ else if (node_states.size() == 2)
+ return data(node_id, node_states(0) + node_states(1)*NS);
+ else if (node_states.size() == 3)
+ return data(node_id, (node_states(0) + node_states(1)*NS)*NS + node_states(2));
+ else
+ return data(node_id, ((node_states(0) + node_states(1)*NS)*NS + node_states(2))*NS + node_states(3));
+ }
+
+ matrix<double> data;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename map_problem
+ >
+ void brute_force_find_max_factor_graph_viterbi (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment
+ )
+ {
+ using namespace dlib::impl;
+ const int order = prob.order();
+ const int num_states = prob.num_states();
+
+ map_assignment.resize(prob.number_of_nodes());
+ double best_score = -std::numeric_limits<double>::infinity();
+ matrix<unsigned long,1,0> node_states;
+ node_states.set_size(prob.number_of_nodes());
+ node_states = 0;
+ do
+ {
+ double score = 0;
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ score += prob.factor_value(i, (colm(node_states,range(i,i-std::min<int>(order,i)))));
+ }
+
+ if (score > best_score)
+ {
+ for (unsigned long i = 0; i < map_assignment.size(); ++i)
+ map_assignment[i] = node_states(i);
+ best_score = score;
+ }
+
+ } while(advance_state(node_states,num_states));
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ unsigned long order,
+ unsigned long num_states,
+ unsigned long num_nodes,
+ bool all_negative
+ >
+ void do_test_()
+ {
+ dlog << LINFO << "order: "<< order
+ << " num_states: " << num_states
+ << " num_nodes: " << num_nodes
+ << " all_negative: " << all_negative;
+
+ for (int i = 0; i < 25; ++i)
+ {
+ print_spinner();
+ map_problem<order,num_states,num_nodes,all_negative> prob;
+ std::vector<unsigned long> assign, assign2;
+ brute_force_find_max_factor_graph_viterbi(prob, assign);
+ find_max_factor_graph_viterbi(prob, assign2);
+
+ DLIB_TEST_MSG(mat(assign) == mat(assign2),
+ trans(mat(assign))
+ << trans(mat(assign2))
+ );
+ }
+ }
+
+ template <
+ unsigned long order,
+ unsigned long num_states,
+ unsigned long num_nodes
+ >
+ void do_test()
+ {
+ do_test_<order,num_states,num_nodes,false>();
+ }
+
+ template <
+ unsigned long order,
+ unsigned long num_states,
+ unsigned long num_nodes
+ >
+ void do_test_negative()
+ {
+ do_test_<order,num_states,num_nodes,true>();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_find_max_factor_graph_viterbi : public tester
+ {
+ public:
+ test_find_max_factor_graph_viterbi (
+ ) :
+ tester ("test_find_max_factor_graph_viterbi",
+ "Runs tests on the find_max_factor_graph_viterbi routine.")
+ {}
+
+ void perform_test (
+ )
+ {
+ do_test<1,3,0>();
+ do_test<1,3,1>();
+ do_test<1,3,2>();
+ do_test<0,3,2>();
+ do_test_negative<0,3,2>();
+
+ do_test<1,3,8>();
+ do_test<2,3,7>();
+ do_test_negative<2,3,7>();
+ do_test<3,3,8>();
+ do_test<4,3,8>();
+ do_test_negative<4,3,8>();
+ do_test<0,3,8>();
+ do_test<4,3,1>();
+ do_test<4,3,0>();
+
+ do_test<3,2,1>();
+ do_test<3,2,0>();
+ do_test<3,2,2>();
+ do_test<2,2,1>();
+ do_test_negative<3,2,1>();
+ do_test_negative<3,2,0>();
+ do_test_negative<3,2,2>();
+ do_test_negative<2,2,1>();
+
+ do_test<0,3,0>();
+ do_test<1,2,8>();
+ do_test<2,2,7>();
+ do_test<3,2,8>();
+ do_test<0,2,8>();
+
+ do_test<1,1,8>();
+ do_test<2,1,8>();
+ do_test<3,1,8>();
+ do_test<0,1,8>();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/find_optimal_parameters.cpp b/ml/dlib/dlib/test/find_optimal_parameters.cpp
new file mode 100644
index 000000000..9f2f5b348
--- /dev/null
+++ b/ml/dlib/dlib/test/find_optimal_parameters.cpp
@@ -0,0 +1,58 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization/find_optimal_parameters.h>
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.find_optimal_parameters");
+
+// ----------------------------------------------------------------------------------------
+
+
+ class find_optimal_parameters : public tester
+ {
+ public:
+ find_optimal_parameters (
+ ) :
+ tester ("test_find_optimal_parameters",
+ "Runs tests on find_optimal_parameters().")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ matrix<double,0,1> params = {0.5, 0.5};
+ dlib::find_optimal_parameters(4, 0.001, 100, params, {-0.1, -0.01}, {5, 5}, [](const matrix<double,0,1>& params) {
+ cout << ".";
+ return sum(squared(params));
+ });
+
+ matrix<double,0,1> true_params = {0,0};
+
+ DLIB_TEST(max(abs(true_params - params)) < 1e-10);
+
+ params = {0.1};
+ dlib::find_optimal_parameters(4, 0.001, 100, params, {-0.01}, {5}, [](const matrix<double,0,1>& params) {
+ cout << ".";
+ return sum(squared(params));
+ });
+
+ true_params = {0};
+ DLIB_TEST(max(abs(true_params - params)) < 1e-10);
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/geometry.cpp b/ml/dlib/dlib/test/geometry.cpp
new file mode 100644
index 000000000..505a83d95
--- /dev/null
+++ b/ml/dlib/dlib/test/geometry.cpp
@@ -0,0 +1,883 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/geometry.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/string.h>
+#include <dlib/matrix.h>
+#include <dlib/rand.h>
+#include <dlib/array2d.h>
+#include <dlib/image_transforms.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.geometry");
+
+ void geometry_test (
+ )
+ /*!
+ ensures
+ - runs tests on the geometry stuff compliance with the specs
+ !*/
+ {
+ print_spinner();
+
+ point p1;
+ point p2(2,3);
+
+ DLIB_TEST(p1.x() == 0);
+ DLIB_TEST(p1.y() == 0);
+ DLIB_TEST(p2.x() == 2);
+ DLIB_TEST(p2.y() == 3);
+
+ DLIB_TEST((-p2).x() == -2);
+ DLIB_TEST((-p2).y() == -3);
+
+
+ p2 += p2;
+ DLIB_TEST(p2.x() == 4);
+ DLIB_TEST(p2.y() == 6);
+
+ dlib::vector<double> v1 = point(1,0);
+ dlib::vector<double> v2(0,0,1);
+
+ p1 = v2.cross(v1);
+ DLIB_TEST(p1 == point(0,1));
+ DLIB_TEST(p1 != point(1,1));
+ DLIB_TEST(p1 != point(1,0));
+
+ p1 = point(2,3);
+ rectangle rect1 = p1;
+ DLIB_TEST(rect1.width() == 1);
+ DLIB_TEST(rect1.height() == 1);
+ p2 = point(1,1);
+
+ rect1 += p2;
+ DLIB_TEST(rect1.left() == 1);
+ DLIB_TEST(rect1.top() == 1);
+ DLIB_TEST(rect1.right() == 2);
+ DLIB_TEST(rect1.bottom() == 3);
+
+ DLIB_TEST(rect1.width() == 2);
+ DLIB_TEST(rect1.height() == 3);
+
+ // test the iostream << and >> operators (via string_cast and cast_to_string)
+ DLIB_TEST(string_cast<point>(" (1, 2 )") == point(1,2));
+ DLIB_TEST(string_cast<point>(" ( -1, 2 )") == point(-1,2));
+ DLIB_TEST(string_cast<rectangle>(" [(1, 2 )(3,4)]") == rectangle(1,2,3,4));
+ DLIB_TEST(string_cast<dlib::vector<double> >(" (1, 2 , 3.5)") == dlib::vector<double>(1,2,3.5));
+
+ DLIB_TEST(string_cast<rectangle>(cast_to_string(rect1)) == rect1);
+ DLIB_TEST(string_cast<point>(cast_to_string(p1)) == p1);
+ DLIB_TEST(string_cast<dlib::vector<double> >(cast_to_string(v1)) == v1);
+
+ rectangle rect2;
+
+ // test the serialization code
+ ostringstream sout;
+ serialize(rect1,sout);
+ serialize(p1,sout);
+ serialize(v1,sout);
+ serialize(rect1,sout);
+ serialize(p1,sout);
+ serialize(v1,sout);
+
+ istringstream sin(sout.str());
+
+ deserialize(rect2,sin);
+ deserialize(p2,sin);
+ deserialize(v2,sin);
+ DLIB_TEST(rect2 == rect1);
+ DLIB_TEST(p2 == p1);
+ DLIB_TEST(v2 == v1);
+ deserialize(rect2,sin);
+ deserialize(p2,sin);
+ deserialize(v2,sin);
+ DLIB_TEST(rect2 == rect1);
+ DLIB_TEST(p2 == p1);
+ DLIB_TEST(v2 == v1);
+ DLIB_TEST(sin.good());
+ DLIB_TEST(sin.get() == EOF);
+
+
+ v1.x() = 1;
+ v1.y() = 2;
+ v1.z() = 3;
+
+ matrix<double> mv = v1;
+ DLIB_TEST(mv.nr() == 3);
+ DLIB_TEST(mv.nc() == 1);
+ DLIB_TEST(mv(0) == 1);
+ DLIB_TEST(mv(1) == 2);
+ DLIB_TEST(mv(2) == 3);
+
+ set_all_elements(mv,0);
+ DLIB_TEST(mv(0) == 0);
+ DLIB_TEST(mv(1) == 0);
+ DLIB_TEST(mv(2) == 0);
+
+ mv(0) = 5;
+ mv(1) = 6;
+ mv(2) = 7;
+
+ v1 = mv;
+ DLIB_TEST(v1.x() == 5);
+ DLIB_TEST(v1.y() == 6);
+ DLIB_TEST(v1.z() == 7);
+
+
+ {
+ dlib::vector<double,2> vd2;
+ dlib::vector<double,3> vd3;
+ dlib::vector<long,2> vl2;
+ dlib::vector<long,3> vl3;
+
+ vd2.x() = 2.3;
+ vd2.y() = 4.7;
+
+ vd3.z() = 9;
+
+ vd3 = vd2;
+
+
+
+ vl2 = vd3;
+ vl3 = vd3;
+
+
+ DLIB_TEST(vd2.z() == 0);
+ DLIB_TEST(vd3.z() == 0);
+ DLIB_TEST(vl2.z() == 0);
+ DLIB_TEST(vl3.z() == 0);
+
+ DLIB_TEST(vl2.x() == 2);
+ DLIB_TEST(vl3.x() == 2);
+ DLIB_TEST(vl2.y() == 5);
+ DLIB_TEST(vl3.y() == 5);
+
+
+ DLIB_TEST(abs(vd2.cross(vd3).dot(vd2)) < 1e-7);
+ DLIB_TEST(abs(vd3.cross(vd2).dot(vd2)) < 1e-7);
+ DLIB_TEST(abs(vd2.cross(vd3).dot(vd3)) < 1e-7);
+ DLIB_TEST(abs(vd3.cross(vd2).dot(vd3)) < 1e-7);
+
+ DLIB_TEST(abs(vl2.cross(vl3).dot(vl2)) == 0);
+ DLIB_TEST(abs(vl3.cross(vl2).dot(vl2)) == 0);
+ DLIB_TEST(abs(vl2.cross(vl3).dot(vl3)) == 0);
+ DLIB_TEST(abs(vl3.cross(vl2).dot(vl3)) == 0);
+
+
+ DLIB_TEST((vd2-vd3).length() < 1e-7);
+
+ DLIB_TEST(vl2 == vl3);
+
+
+ vl2.x() = 0;
+ vl2.y() = 0;
+ vl3 = vl2;
+
+ vl2.x() = 4;
+ vl3.y() = 3;
+
+ DLIB_TEST(vl2.cross(vl3).length() == 12);
+ DLIB_TEST(vl3.cross(vl2).length() == 12);
+
+
+ matrix<double> m(3,3);
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ vd3.x() = 2;
+ vd3.y() = 3;
+ vd3.z() = 4;
+
+ vd3 = m*vd3;
+
+ DLIB_TEST_MSG(vd3.x() == 1*2 + 2*3 + 3*4,vd3.x() << " == " << (1*2 + 2*3 + 3*4));
+ DLIB_TEST(vd3.y() == 4*2 + 5*3 + 6*4);
+ DLIB_TEST(vd3.z() == 7*2 + 8*3 + 9*4);
+
+ (vd3*2).dot(vd3);
+ (vd2*2).dot(vd3);
+ (vd3*2).dot(vd2);
+ (vd2*2).dot(vd2);
+ (2*vd3*2).dot(vd3);
+ (2*vd2*2).dot(vd3);
+ (2*vd3*2).dot(vd2);
+ (2*vd2*2).dot(vd2);
+
+ (vd2 + vd3).dot(vd2);
+ (vd2 - vd3).dot(vd2);
+ (vd2/2).dot(vd2);
+ (vd3/2).dot(vd2);
+ }
+
+ {
+ dlib::vector<double,2> vd2;
+ dlib::vector<long,3> vl3;
+
+ vl3.x() = 1;
+ vl3.y() = 2;
+ vl3.z() = 3;
+
+ vd2.x() = 6.5;
+ vd2.y() = 7.5;
+
+ DLIB_TEST((vl3 + vd2).x() == 1+6.5);
+ DLIB_TEST((vl3 + vd2).y() == 2+7.5);
+ DLIB_TEST((vl3 + vd2).z() == 3+0);
+
+ DLIB_TEST((vl3 - vd2).x() == 1-6.5);
+ DLIB_TEST((vl3 - vd2).y() == 2-7.5);
+ DLIB_TEST((vl3 - vd2).z() == 3-0);
+
+ }
+
+ {
+ dlib::vector<double> v(3,4,5);
+ DLIB_TEST((-v).x() == -3.0);
+ DLIB_TEST((-v).y() == -4.0);
+ DLIB_TEST((-v).z() == -5.0);
+ }
+
+ {
+ rectangle rect;
+
+ point tl(2,3);
+ point tr(8,3);
+ point bl(2,9);
+ point br(8,9);
+
+ rect += tl;
+ rect += tr;
+ rect += bl;
+ rect += br;
+
+ DLIB_TEST(rect.tl_corner() == tl);
+ DLIB_TEST(rect.tr_corner() == tr);
+ DLIB_TEST(rect.bl_corner() == bl);
+ DLIB_TEST(rect.br_corner() == br);
+
+ }
+
+ {
+ point p1, center;
+
+ center = point(3,4);
+ p1 = point(10,4);
+
+ DLIB_TEST(rotate_point(center, p1, pi/2) == point(3,7+4));
+
+ center = point(3,3);
+ p1 = point(10,3);
+
+ DLIB_TEST(rotate_point(center, p1, pi/4) == point(8,8));
+ DLIB_TEST(rotate_point(center, p1, -pi/4) == point(8,-2));
+
+ DLIB_TEST(rotate_point(center, p1, pi/4 + 10*pi) == point(8,8));
+ DLIB_TEST(rotate_point(center, p1, -pi/4 + 10*pi) == point(8,-2));
+ DLIB_TEST(rotate_point(center, p1, pi/4 - 10*pi) == point(8,8));
+ DLIB_TEST(rotate_point(center, p1, -pi/4 - 10*pi) == point(8,-2));
+
+ point_rotator rot(pi/2);
+ DLIB_TEST(rot(point(1,0)) == point(0,1));
+ DLIB_TEST(rot(point(0,1)) == point(-1,0));
+ DLIB_TEST(point(rot.get_m()*(dlib::vector<double,2>(1,0))) == point(0,1));
+ DLIB_TEST(point(rot.get_m()*(dlib::vector<double,2>(0,1))) == point(-1,0));
+ }
+
+ {
+ rectangle rect;
+
+ rect = grow_rect(rect,1);
+ DLIB_TEST(rect.width() == 2);
+ DLIB_TEST(rect.height() == 2);
+ DLIB_TEST(rect.left() == -1);
+ DLIB_TEST(rect.top() == -1);
+ DLIB_TEST(rect.right() == 0);
+ DLIB_TEST(rect.bottom() == 0);
+ }
+ {
+ rectangle rect;
+
+ rect = grow_rect(rect,2);
+ DLIB_TEST(rect.width() == 4);
+ DLIB_TEST(rect.height() == 4);
+ DLIB_TEST(rect.left() == -2);
+ DLIB_TEST(rect.top() == -2);
+ DLIB_TEST(rect.right() == 1);
+ DLIB_TEST(rect.bottom() == 1);
+
+ rect = shrink_rect(rect,1);
+ DLIB_TEST(rect.width() == 2);
+ DLIB_TEST(rect.height() == 2);
+ DLIB_TEST(rect.left() == -1);
+ DLIB_TEST(rect.top() == -1);
+ DLIB_TEST(rect.right() == 0);
+ DLIB_TEST(rect.bottom() == 0);
+ }
+ {
+ std::vector< dlib::vector<double> > a;
+
+ dlib::vector<double> v;
+ dlib::rand rnd;
+
+ for (int i = 0; i < 10; ++i)
+ {
+ v.x() = rnd.get_random_double();
+ v.y() = rnd.get_random_double();
+ v.z() = rnd.get_random_double();
+ a.push_back(v);
+
+ }
+
+ // This test is just to make sure the covariance function can compile when used
+ // on a dlib::vector. The actual test doesn't matter.
+ DLIB_TEST(sum(covariance(mat(a))) < 10);
+
+ }
+
+
+ DLIB_TEST(rectangle() + point(5,4) + point(10,10) == rectangle(5,4,10,10));
+
+ // make sure the center of a centered rectangle is always right
+ for (long x = -10; x <= 10; ++x)
+ {
+ for (long y = -10; y <= 10; ++y)
+ {
+ for (long w = 0; w < 10; ++w)
+ {
+ for (long h = 0; h < 10; ++h)
+ {
+ DLIB_TEST(center(centered_rect(x,y,w,h)) == point(x,y));
+ }
+ }
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_border_enumerator()
+ {
+
+
+
+ border_enumerator be;
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 0);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.at_start() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.at_start() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 0);
+
+ be = border_enumerator(rectangle(4,4,4,4),1);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 1);
+ be = border_enumerator(rectangle(4,4,4,4),3);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 1);
+ be = border_enumerator(rectangle(4,4,4,4),0);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 0);
+ be = border_enumerator(rectangle(4,4,5,5),0);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 0);
+ be = border_enumerator(rectangle(4,4,5,5),1);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.size() == 4);
+ be = border_enumerator(rectangle(4,4,5,5),2);
+ DLIB_TEST(be.size() == 4);
+ be = border_enumerator(rectangle(4,4,6,6),1);
+ DLIB_TEST(be.size() == 8);
+ be = border_enumerator(rectangle(4,4,6,6),2);
+ DLIB_TEST(be.size() == 9);
+ be = border_enumerator(rectangle(4,4,6,6),3);
+ DLIB_TEST(be.size() == 9);
+ DLIB_TEST(be.at_start() == true);
+
+ array2d<unsigned char> img, img2;
+ for (int size = 1; size < 10; ++size)
+ {
+ for (int bs = 0; bs < 4; ++bs)
+ {
+ img.set_size(size,size);
+ img2.set_size(size,size);
+
+ assign_all_pixels(img, 1);
+ assign_all_pixels(img2, 1);
+
+ zero_border_pixels(img2, bs,bs);
+
+ be = border_enumerator(get_rect(img),bs);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.current_element_valid() == false);
+ while (be.move_next())
+ {
+ DLIB_TEST(be.at_start() == false);
+ DLIB_TEST(be.current_element_valid() == true);
+ DLIB_TEST_MSG(get_rect(img).contains(be.element()) == true,
+ get_rect(img) << " " << be.element()
+ );
+ const point p = be.element();
+ img[p.y()][p.x()] = 0;
+ }
+ DLIB_TEST(be.at_start() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.at_start() == false);
+
+ DLIB_TEST(mat(img) == mat(img2));
+
+ }
+ }
+
+ for (int size = 1; size < 10; ++size)
+ {
+ for (int bs = 0; bs < 4; ++bs)
+ {
+ img.set_size(size,size+5);
+ img2.set_size(size,size+5);
+
+ assign_all_pixels(img, 1);
+ assign_all_pixels(img2, 1);
+
+ zero_border_pixels(img2, bs,bs);
+
+ const point shift(4,5);
+
+ be = border_enumerator(translate_rect(get_rect(img),shift),bs);
+ DLIB_TEST(be.at_start() == true);
+ DLIB_TEST(be.current_element_valid() == false);
+ while (be.move_next())
+ {
+ DLIB_TEST(be.current_element_valid() == true);
+ DLIB_TEST(be.at_start() == false);
+ DLIB_TEST_MSG(get_rect(img).contains(be.element()-shift) == true,
+ get_rect(img) << " " << be.element()
+ );
+ const point p = be.element()-shift;
+ img[p.y()][p.x()] = 0;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.at_start() == false);
+
+ DLIB_TEST(mat(img) == mat(img2));
+
+ }
+ }
+
+ for (int size = 1; size < 10; ++size)
+ {
+ for (int bs = 0; bs < 4; ++bs)
+ {
+ img.set_size(size+2,size);
+ img2.set_size(size+2,size);
+
+ assign_all_pixels(img, 1);
+ assign_all_pixels(img2, 1);
+
+ zero_border_pixels(img2, bs,bs);
+
+ const point shift(-4,5);
+
+ be = border_enumerator(translate_rect(get_rect(img),shift),bs);
+ DLIB_TEST(be.current_element_valid() == false);
+ while (be.move_next())
+ {
+ DLIB_TEST(be.current_element_valid() == true);
+ DLIB_TEST_MSG(get_rect(img).contains(be.element()-shift) == true,
+ get_rect(img) << " " << be.element()
+ );
+ const point p = be.element()-shift;
+ img[p.y()][p.x()] = 0;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.move_next() == false);
+ DLIB_TEST(be.current_element_valid() == false);
+
+ DLIB_TEST(mat(img) == mat(img2));
+
+ }
+ }
+
+ {
+ matrix<bool,4,5> hits, truth;
+ const rectangle rect = rectangle(1,1,4,3);
+
+ border_enumerator be(rect, rectangle(2,2, 3, 3));
+ DLIB_TEST(be.size() == 8);
+ hits = false;
+ while (be.move_next())
+ {
+ DLIB_TEST(rect.contains(be.element()));
+ hits(be.element().y(), be.element().x()) = true;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 8);
+ truth = false;
+ truth(1,1) = truth(1,2) = truth(1,3) = truth(1,4) = truth(2,1) =
+ truth(3,1) = truth(2,4) = truth(3,4) = true;
+ DLIB_TEST_MSG(truth == hits, truth << endl << hits);
+
+
+
+
+ be = border_enumerator(rect, rectangle(0,0, 9, 9));
+ DLIB_TEST(be.size() == 0);
+ hits = false;
+ while (be.move_next())
+ {
+ DLIB_TEST(rect.contains(be.element()));
+ hits(be.element().y(), be.element().x()) = true;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 0);
+ truth = false;
+ DLIB_TEST(truth == hits);
+
+
+
+ be = border_enumerator(rect, rectangle(0,0, 3, 9));
+ DLIB_TEST(be.size() == 3);
+ hits = false;
+ while (be.move_next())
+ {
+ DLIB_TEST(rect.contains(be.element()));
+ hits(be.element().y(), be.element().x()) = true;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 3);
+ truth = false;
+ truth(1,4) = truth(2,4) = truth(3,4) = true;
+ DLIB_TEST(truth == hits);
+
+
+
+
+ be = border_enumerator(rect, rectangle(2,1, 4, 3));
+ DLIB_TEST(be.size() == 3);
+ hits = false;
+ while (be.move_next())
+ {
+ DLIB_TEST(rect.contains(be.element()));
+ hits(be.element().y(), be.element().x()) = true;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 3);
+ truth = false;
+ truth(1,1) = truth(2,1) = truth(3,1) = true;
+ DLIB_TEST(truth == hits);
+
+
+
+ be = border_enumerator(rect, rectangle(1,1, 5, 2));
+ DLIB_TEST(be.size() == 4);
+ hits = false;
+ while (be.move_next())
+ {
+ DLIB_TEST(rect.contains(be.element()));
+ hits(be.element().y(), be.element().x()) = true;
+ }
+ DLIB_TEST(be.current_element_valid() == false);
+ DLIB_TEST(be.size() == 4);
+ truth = false;
+ truth(3,1) = truth(3,2) = truth(3,3) = truth(3,4) = true;
+ DLIB_TEST(truth == hits);
+
+
+
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_find_affine_transform()
+ {
+ //typedef dlib::vector<double,2> vect;
+ typedef point vect;
+ std::vector<vect> from, to;
+
+ from.push_back(vect(0,0));
+ to.push_back(vect(0,1));
+
+ from.push_back(vect(0,1));
+ to.push_back(vect(1,1));
+
+ from.push_back(vect(1,1));
+ to.push_back(vect(1,0));
+
+ from.push_back(vect(1,0));
+ to.push_back(vect(0,0));
+
+ point_transform_affine t = find_affine_transform(from,to);
+ point_transform_affine tinv = inv(t);
+
+ for (unsigned long i = 0; i < from.size(); ++i)
+ {
+ dlog << LINFO << "affine transformation error: "<< length(t(from[i])-to[i]);
+ DLIB_TEST(length(t(from[i])-to[i]) < 1e-14);
+ DLIB_TEST(length(tinv(t(from[i]))-from[i]) < 1e-14);
+ DLIB_TEST(length(t(tinv(from[i]))-from[i]) < 1e-14);
+
+ point_transform_affine temp = t*inv(t);
+ DLIB_TEST(length(temp.get_b()) < 1e-14);
+ DLIB_TEST(max(abs(temp.get_m() - identity_matrix<double>(2))) < 1e-14);
+ }
+
+ ostringstream sout;
+ serialize(t, sout);
+ istringstream sin(sout.str());
+ point_transform_affine t2;
+ DLIB_TEST(length(t2(point(2,3)) - point(2,3)) < 1e-14);
+ deserialize(t2, sin);
+ DLIB_TEST(max(abs(t2.get_m()-t.get_m())) < 1e-14);
+ DLIB_TEST(max(abs(t2.get_b()-t.get_b())) < 1e-14);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double projective_transform_pass_rate(const double error_rate)
+ {
+ print_spinner();
+ dlog << LINFO << "projective_transform_pass_rate, error_rate: "<< error_rate;
+ dlib::rand rnd;
+ running_stats<double> pass_rate;
+ for (int rounds = 0; rounds < 1000; ++rounds)
+ {
+ running_stats<double> rs, rs_true;
+ matrix<double> H = 2*(randm(3,3,rnd)-0.5);
+
+ H(0,2) = rnd.get_random_gaussian()*10;
+ H(1,2) = rnd.get_random_gaussian()*10;
+
+
+ H(2,0) = rnd.get_random_double()*2.1;
+ H(2,1) = rnd.get_random_double()*2.1;
+ H(2,2) = 1 + rnd.get_random_gaussian()*3.1;
+
+ point_transform_projective tran(H);
+ point_transform_projective traninv = inv(tran);
+
+ const int num = rnd.get_random_32bit_number()%8 + 4;
+
+ std::vector<dlib::vector<double,2> > from_points, to_points;
+ for (int i = 0; i < num; ++i)
+ {
+ dlib::vector<double,2> p = randm(2,1,rnd)*1000;
+ from_points.push_back(p);
+ to_points.push_back(tran(p) + (randm(2,1,rnd)-0.5)*error_rate);
+ DLIB_TEST(length(traninv(tran(p))-p) <= 1e-5);
+ DLIB_TEST(length(tran(traninv(p))-p) <= 1e-5);
+
+ point_transform_projective temp = tran*traninv;
+ DLIB_TEST_MSG(max(abs(temp.get_m() - identity_matrix<double>(3))) < 1e-10, temp.get_m());
+ temp = traninv*tran;
+ DLIB_TEST_MSG(max(abs(temp.get_m() - identity_matrix<double>(3))) < 1e-10, temp.get_m());
+ }
+
+
+ point_transform_projective tran2 = find_projective_transform(from_points, to_points);
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ const double err = length_squared(tran2(from_points[i]) - to_points[i]);
+ rs.add(err);
+ const double err_true = length_squared(tran(from_points[i]) - to_points[i]);
+ rs_true.add(err_true);
+ }
+
+ if ( rs.mean() < 0.01)
+ {
+ pass_rate.add(1);
+ }
+ else
+ {
+ dlog << LINFO << " errors: mean/max: " << rs.mean() << " " << rs.max();
+ pass_rate.add(0);
+ }
+
+ ostringstream sout;
+ serialize(tran, sout);
+ istringstream sin(sout.str());
+ point_transform_projective tran3;
+ DLIB_TEST(length(tran3(point(2,3)) - point(2,3)) < 1e-14);
+ deserialize(tran3, sin);
+ DLIB_TEST(max(abs(tran3.get_m()-tran.get_m())) < 1e-14);
+ }
+
+ dlog << LINFO << " pass_rate.mean(): "<< pass_rate.mean();
+ return pass_rate.mean();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void test_find_similarity_transform()
+ {
+ print_spinner();
+ std::vector<dlib::vector<T,2> > from_points, to_points;
+
+ from_points.push_back(dlib::vector<T,2>(0,0));
+ from_points.push_back(dlib::vector<T,2>(0,1));
+ from_points.push_back(dlib::vector<T,2>(1,0));
+
+ to_points.push_back(dlib::vector<T,2>(8,0));
+ to_points.push_back(dlib::vector<T,2>(6,0));
+ to_points.push_back(dlib::vector<T,2>(8,2));
+
+ point_transform_affine tform = find_similarity_transform(from_points, to_points);
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ DLIB_TEST(length(tform(from_points[i]) - to_points[i]) < 1e-14);
+ }
+ }
+
+ template <typename T>
+ void test_find_similarity_transform2()
+ {
+ print_spinner();
+ std::vector<dlib::vector<T,2> > from_points, to_points;
+
+ from_points.push_back(dlib::vector<T,2>(0,0));
+ from_points.push_back(dlib::vector<T,2>(0,1));
+
+ to_points.push_back(dlib::vector<T,2>(8,0));
+ to_points.push_back(dlib::vector<T,2>(6,0));
+
+ point_transform_affine tform = find_similarity_transform(from_points, to_points);
+
+ for (unsigned long i = 0; i < from_points.size(); ++i)
+ {
+ DLIB_TEST(length(tform(from_points[i]) - to_points[i]) < 1e-14);
+ }
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ void test_rect_to_drect()
+ {
+ print_spinner();
+ dlib::rand rnd;
+ for (int i = 0; i < 5000; ++i)
+ {
+ rectangle rect = centered_rect(rnd.get_random_32bit_number()%100,
+ rnd.get_random_32bit_number()%100,
+ rnd.get_random_32bit_number()%100,
+ rnd.get_random_32bit_number()%100);
+
+ drectangle drect = rect;
+ rectangle rect2 = drect;
+ DLIB_TEST(rect2 == rect);
+ DLIB_TEST(rect.width() == drect.width());
+ DLIB_TEST(rect.height() == drect.height());
+ DLIB_TEST(dcenter(rect) == dcenter(drect));
+ DLIB_TEST(rect.is_empty() == drect.is_empty());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_affine3d()
+ {
+ const dlib::vector<double> x(1,0,0);
+ const dlib::vector<double> y(0,1,0);
+ const dlib::vector<double> z(0,0,1);
+ const dlib::vector<double> e(1,1,1);
+ const dlib::vector<double> ex(-1,1,1);
+ const dlib::vector<double> ey(1,-1,1);
+ const dlib::vector<double> ez(1,1,-1);
+
+ dlib::vector<double> w;
+
+ w = rotate_around_z(pi/2)(x);
+ DLIB_TEST(length(w-y) < 1e-12);
+ w = rotate_around_z(pi/2)(e);
+ DLIB_TEST(length(w-ex) < 1e-12);
+
+ w = rotate_around_y(-pi/2)(x);
+ DLIB_TEST(length(w-z) < 1e-12);
+ w = rotate_around_y(pi/2)(e);
+ DLIB_TEST(length(w-ez) < 1e-12);
+
+ w = rotate_around_x(pi/2)(y);
+ DLIB_TEST(length(w-z) < 1e-12);
+ w = rotate_around_x(pi/2)(e);
+ DLIB_TEST(length(w-ey) < 1e-12);
+
+ w = translate_point(x)(y);
+ DLIB_TEST(length(w-x-y) < 1e-12);
+
+ point_transform_affine3d tform;
+ tform = rotate_around_x(pi/2)*rotate_around_z(pi/2)*translate_point(x);
+ DLIB_TEST(length(tform(dlib::vector<double>())-z) < 1e-12);
+ DLIB_TEST(length(inv(tform)(z)) < 1e-12);
+
+ point_transform_affine tform2;
+ tform = tform*tform2;// the default tform is the identity mapping so this shouldn't do anything different
+ DLIB_TEST(length(tform(dlib::vector<double>())-z) < 1e-12);
+ DLIB_TEST(length(inv(tform)(z)) < 1e-12);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class geometry_tester : public tester
+ {
+ public:
+ geometry_tester (
+ ) :
+ tester ("test_geometry",
+ "Runs tests on the geometry stuff.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_affine3d();
+ test_rect_to_drect();
+ geometry_test();
+ test_border_enumerator();
+ test_find_affine_transform();
+ DLIB_TEST(projective_transform_pass_rate(0.1) > 0.99);
+ DLIB_TEST(projective_transform_pass_rate(0.0) == 1);
+
+ test_find_similarity_transform<double>();
+ test_find_similarity_transform2<double>();
+ test_find_similarity_transform<float>();
+ test_find_similarity_transform2<float>();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/global_optimization.cpp b/ml/dlib/dlib/test/global_optimization.cpp
new file mode 100644
index 000000000..fee80c81f
--- /dev/null
+++ b/ml/dlib/dlib/test/global_optimization.cpp
@@ -0,0 +1,302 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/global_optimization.h>
+#include <dlib/statistics.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.global_optimization");
+
+// ----------------------------------------------------------------------------------------
+
+ void test_upper_bound_function(double relative_noise_magnitude, double solver_eps)
+ {
+ print_spinner();
+
+ dlog << LINFO << "test_upper_bound_function, relative_noise_magnitude="<< relative_noise_magnitude << ", solver_eps=" << solver_eps;
+
+ auto rosen = [](const matrix<double,0,1>& x) { return -1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); };
+
+ dlib::rand rnd;
+ auto make_rnd = [&rnd]() { matrix<double,0,1> x(2); x = 2*rnd.get_random_double(), 2*rnd.get_random_double(); return x; };
+
+
+ std::vector<function_evaluation> evals;
+ for (int i = 0; i < 100; ++i)
+ {
+ auto x = make_rnd();
+ evals.emplace_back(x,rosen(x));
+ }
+
+ upper_bound_function ub(evals, relative_noise_magnitude, solver_eps);
+ DLIB_TEST(ub.num_points() == (long)evals.size());
+ DLIB_TEST(ub.dimensionality() == 2);
+ for (auto& ev : evals)
+ {
+ dlog << LINFO << ub(ev.x) - ev.y;
+ DLIB_TEST_MSG(ub(ev.x) - ev.y > -1e10, ub(ev.x) - ev.y);
+ }
+
+
+ for (int i = 0; i < 100; ++i)
+ {
+ auto x = make_rnd();
+ evals.emplace_back(x,rosen(x));
+ ub.add(evals.back());
+ }
+
+ DLIB_TEST(ub.num_points() == (long)evals.size());
+ DLIB_TEST(ub.dimensionality() == 2);
+
+ for (auto& ev : evals)
+ {
+ dlog << LINFO << ub(ev.x) - ev.y;
+ DLIB_TEST_MSG(ub(ev.x) - ev.y > -1e10, ub(ev.x) - ev.y);
+ }
+
+
+ if (solver_eps < 0.001)
+ {
+ dlog << LINFO << "out of sample points: ";
+ for (int i = 0; i < 10; ++i)
+ {
+ auto x = make_rnd();
+ dlog << LINFO << ub(x) - rosen(x);
+ DLIB_TEST_MSG(ub(x) - rosen(x) > 1e-10, ub(x) - rosen(x));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double complex_holder_table ( double x0, double x1)
+ {
+ // The regular HolderTable function
+ //return -std::abs(sin(x0)*cos(x1)*exp(std::abs(1-std::sqrt(x0*x0+x1*x1)/pi)));
+
+ // My more complex version of it with discontinuities and more local minima.
+ double sign = 1;
+ for (double j = -4; j < 9; j += 0.5)
+ {
+ if (j < x0 && x0 < j+0.5)
+ x0 += sign*0.25;
+ sign *= -1;
+ }
+ // HolderTable function tilted towards 10,10
+ return -std::abs(sin(x0)*cos(x1)*exp(std::abs(1-std::sqrt(x0*x0+x1*x1)/pi))) +(x0+x1)/10 + sin(x0*10)*cos(x1*10);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_global_function_search()
+ {
+
+ function_spec spec{{-10,-10}, {10,10}};
+ function_spec spec2{{-10,-10, -50}, {10,10, 50}};
+ global_function_search opt({spec, spec, spec2});
+
+ dlib::rand rnd;
+ bool found_optimal_point = false;
+ for (int i = 0; i < 400 && !found_optimal_point; ++i)
+ {
+ print_spinner();
+ std::vector<function_evaluation_request> nexts;
+ for (int k = 0; k < rnd.get_integer_in_range(1,4); ++k)
+ nexts.emplace_back(opt.get_next_x());
+
+ for (auto& next : nexts)
+ {
+ switch (next.function_idx())
+ {
+ case 0: next.set( -complex_holder_table(next.x()(0), next.x()(1))); break;
+ case 1: next.set( -10*complex_holder_table(next.x()(0), next.x()(1))); break;
+ case 2: next.set( -2*complex_holder_table(next.x()(0), next.x()(1))); break;
+ default: DLIB_TEST(false); break;
+ }
+
+ matrix<double,0,1> x;
+ double y;
+ size_t function_idx;
+ opt.get_best_function_eval(x,y,function_idx);
+ /*
+ cout << "\ni: "<< i << endl;
+ cout << "best eval x: "<< trans(x);
+ cout << "best eval y: "<< y << endl;
+ cout << "best eval function index: "<< function_idx << endl;
+ */
+
+ if (std::abs(y - 10*21.9210397) < 0.0001)
+ {
+ found_optimal_point = true;
+ break;
+ }
+ }
+ }
+
+ DLIB_TEST(found_optimal_point);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_find_max_global(
+ )
+ {
+ print_spinner();
+ auto rosen = [](const matrix<double,0,1>& x) { return -1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); };
+
+ auto result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0);
+ matrix<double,0,1> true_x = {1,1};
+
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_max_global(rosen, {0.1, 0.1}, {0.9, 0.9}, {false,false}, max_function_calls(140));
+ true_x = {0.9, 0.81};
+ dlog << LINFO << "rosen, bounded at 0.9: " << trans(result.x);
+ DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 10, max_function_calls(10), 0);
+ dlog << LINFO << "(x-2)^2: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 2) < 1e-9);
+ print_spinner();
+
+ result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, max_function_calls(10));
+ dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 1) < 1e-9);
+ print_spinner();
+
+ result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, std::chrono::seconds(2));
+ dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 1) < 1e-9);
+ print_spinner();
+
+
+ result = find_max_global([](double a, double b){ return -complex_holder_table(a,b);},
+ {-10, -10}, {10, 10}, max_function_calls(400), 0);
+ dlog << LINFO << "complex_holder_table y: "<< result.y;
+ DLIB_TEST_MSG(std::abs(result.y - 21.9210397) < 0.0001, std::abs(result.y - 21.9210397));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_find_min_global(
+ )
+ {
+ print_spinner();
+ auto rosen = [](const matrix<double,0,1>& x) { return +1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); };
+
+ auto result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0);
+ matrix<double,0,1> true_x = {1,1};
+
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100));
+ dlog << LINFO << "rosen: " << trans(result.x);
+ DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_min_global(rosen, {0, 0}, {0.9, 0.9}, {false,false}, max_function_calls(100));
+ true_x = {0.9, 0.81};
+ dlog << LINFO << "rosen, bounded at 0.9: " << trans(result.x);
+ DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x)));
+ print_spinner();
+
+ result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 10, max_function_calls(10), 0);
+ dlog << LINFO << "(x-2)^2: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 2) < 1e-9);
+ print_spinner();
+
+ result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 1, max_function_calls(10));
+ dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 1) < 1e-9);
+ print_spinner();
+
+ result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 1, std::chrono::seconds(2));
+ dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x);
+ DLIB_TEST(result.x.size()==1);
+ DLIB_TEST(std::abs(result.x - 1) < 1e-9);
+ print_spinner();
+
+
+ result = find_min_global([](double a, double b){ return complex_holder_table(a,b);},
+ {-10, -10}, {10, 10}, max_function_calls(400), 0);
+ dlog << LINFO << "complex_holder_table y: "<< result.y;
+ DLIB_TEST_MSG(std::abs(result.y + 21.9210397) < 0.0001, std::abs(result.y + 21.9210397));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class global_optimization_tester : public tester
+ {
+ public:
+ global_optimization_tester (
+ ) :
+ tester ("test_global_optimization",
+ "Runs tests on the global optimization components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_upper_bound_function(0.01, 1e-6);
+ test_upper_bound_function(0.0, 1e-6);
+ test_upper_bound_function(0.0, 1e-1);
+ test_global_function_search();
+ test_find_max_global();
+ test_find_min_global();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/graph.cpp b/ml/dlib/dlib/test/graph.cpp
new file mode 100644
index 000000000..0651f3049
--- /dev/null
+++ b/ml/dlib/dlib/test/graph.cpp
@@ -0,0 +1,414 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/graph.h>
+#include <dlib/graph_utils.h>
+#include <dlib/set.h>
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything inside this file "private"
+// so that everything you declare will have static linkage. Thus we won't have any multiply
+// defined symbol errors coming out of the linker when we try to compile the test suite.
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.graph");
+
+ template <
+ typename graph
+ >
+ void graph_test (
+ )
+ /*!
+ requires
+ - graph is an implementation of graph/graph_kernel_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on graph for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ COMPILE_TIME_ASSERT(is_graph<graph>::value);
+
+ graph a, b;
+ dlib::set<unsigned long>::compare_1b_c s;
+
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ DLIB_TEST(a.number_of_nodes() == 0);
+
+ a.set_number_of_nodes(5);
+ DLIB_TEST(graph_is_connected(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+ DLIB_TEST(a.number_of_nodes() == 5);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+
+ for (int i = 0; i < 5; ++i)
+ {
+ a.node(i).data = i;
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ a.remove_node(1);
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+
+
+ // make sure that only the number with data == 1 was removed
+ int count = 0;
+ for (int i = 0; i < 4; ++i)
+ {
+ count += a.node(i).data;
+ DLIB_TEST(a.node(i).number_of_neighbors() == 0);
+ DLIB_TEST(a.node(i).index() == (unsigned int)i);
+ }
+
+ DLIB_TEST(count == 9);
+
+
+ a.add_edge(1,1);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == true);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+ DLIB_TEST(a.has_edge(1,1));
+ DLIB_TEST(a.node(1).number_of_neighbors() == 1);
+
+ a.add_edge(1,3);
+ DLIB_TEST(a.node(1).number_of_neighbors() == 2);
+ DLIB_TEST(a.node(2).number_of_neighbors() == 0);
+ DLIB_TEST(a.node(3).number_of_neighbors() == 1);
+ DLIB_TEST(a.has_edge(1,1));
+ DLIB_TEST(a.has_edge(1,3));
+ DLIB_TEST(a.has_edge(3,1));
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+ a.remove_edge(1,1);
+ DLIB_TEST(graph_contains_length_one_cycle(a) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+ DLIB_TEST(a.node(1).number_of_neighbors() == 1);
+ DLIB_TEST(a.node(2).number_of_neighbors() == 0);
+ DLIB_TEST(a.node(3).number_of_neighbors() == 1);
+ DLIB_TEST(a.has_edge(1,1) == false);
+ DLIB_TEST(a.has_edge(1,3));
+ DLIB_TEST(a.has_edge(3,1));
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ swap(a,b);
+
+
+ DLIB_TEST(graph_contains_undirected_cycle(b) == false);
+ DLIB_TEST(b.node(1).number_of_neighbors() == 1);
+ DLIB_TEST(b.node(2).number_of_neighbors() == 0);
+ DLIB_TEST(b.node(3).number_of_neighbors() == 1);
+ DLIB_TEST(b.has_edge(1,1) == false);
+ DLIB_TEST(b.has_edge(1,3));
+ DLIB_TEST(b.has_edge(3,1));
+ DLIB_TEST(graph_contains_undirected_cycle(b) == false);
+
+ DLIB_TEST(a.number_of_nodes() == 0);
+ DLIB_TEST(b.number_of_nodes() == 4);
+
+ copy_graph_structure(b,b);
+ DLIB_TEST(b.number_of_nodes() == 4);
+
+ b.add_edge(1,2);
+ DLIB_TEST(graph_contains_undirected_cycle(b) == false);
+ DLIB_TEST(graph_contains_undirected_cycle(b) == false);
+ b.add_edge(3,2);
+ DLIB_TEST(graph_contains_undirected_cycle(b) == true);
+ b.add_edge(1,1);
+ DLIB_TEST(graph_is_connected(b) == false);
+ b.add_edge(0,2);
+ DLIB_TEST(graph_is_connected(b) == true);
+
+ DLIB_TEST(graph_contains_undirected_cycle(b) == true);
+
+ DLIB_TEST(a.number_of_nodes() == 0);
+
+ for (unsigned long i = 0; i < b.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < b.node(i).number_of_neighbors(); ++j)
+ {
+ b.node(i).edge(j) = 'c';
+ }
+ }
+
+ b.node(1).edge(0) = 'a';
+ const unsigned long e1 = b.node(1).neighbor(0).index();
+ b.node(0).edge(0) = 'n';
+ const unsigned long e2 = b.node(0).neighbor(0).index();
+
+ ostringstream sout;
+ serialize(b, sout);
+ istringstream sin(sout.str());
+
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ a.set_number_of_nodes(10);
+ deserialize(a, sin);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+ for (unsigned long i = 0; i < a.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < a.node(i).number_of_neighbors(); ++j)
+ {
+ if ((i == 0 && a.node(i).neighbor(j).index() == e2) ||
+ (i == e2 && a.node(i).neighbor(j).index() == 0) )
+ {
+ DLIB_TEST(a.node(i).edge(j) == 'n');
+ }
+ else if ((i == 1 && a.node(i).neighbor(j).index() == e1) ||
+ (i == e1 && a.node(i).neighbor(j).index() == 1))
+ {
+ DLIB_TEST(a.node(i).edge(j) == 'a');
+ }
+ else
+ {
+ DLIB_TEST(i != 0 || a.node(i).neighbor(j).index() != e2);
+ DLIB_TEST_MSG(a.node(i).edge(j) == 'c',a.node(i).edge(j));
+ }
+ }
+ }
+
+ DLIB_TEST(a.number_of_nodes() == 4);
+ DLIB_TEST(a.has_edge(1,2) == true);
+ DLIB_TEST(a.has_edge(3,2) == true);
+ DLIB_TEST(a.has_edge(1,1) == true);
+ DLIB_TEST(a.has_edge(0,2) == true);
+ DLIB_TEST(a.has_edge(1,3) == true);
+ DLIB_TEST(a.has_edge(0,1) == false);
+ DLIB_TEST(a.has_edge(0,3) == false);
+ DLIB_TEST(a.has_edge(0,0) == false);
+ DLIB_TEST(a.has_edge(1,0) == false);
+ DLIB_TEST(a.has_edge(3,0) == false);
+
+
+ for (unsigned long i = 0; i < a.number_of_nodes(); ++i)
+ {
+ a.node(i).data = static_cast<int>(i);
+ }
+
+ a.remove_node(2);
+ DLIB_TEST(a.number_of_nodes() == 3);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+
+ count = 0;
+ for (unsigned long i = 0; i < a.number_of_nodes(); ++i)
+ {
+ if (a.node(i).data == 0)
+ {
+ DLIB_TEST(a.node(i).number_of_neighbors() == 0);
+ }
+ else if (a.node(i).data == 1)
+ {
+ DLIB_TEST(a.node(i).number_of_neighbors() == 2);
+ }
+ else if (a.node(i).data == 3)
+ {
+ DLIB_TEST(a.node(i).number_of_neighbors() == 1);
+ }
+ else
+ {
+ DLIB_TEST_MSG(false,"this is impossible");
+ }
+
+ for (unsigned long j = 0; j < a.number_of_nodes(); ++j)
+ {
+ if ((a.node(i).data == 1 && a.node(j).data == 1) ||
+ (a.node(i).data == 1 && a.node(j).data == 3) ||
+ (a.node(i).data == 3 && a.node(j).data == 1))
+ {
+ DLIB_TEST(a.has_edge(i,j) == true);
+ ++count;
+ }
+ else
+ {
+ DLIB_TEST(a.has_edge(i,j) == false);
+ }
+ }
+ }
+ DLIB_TEST_MSG(count == 3,count);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == true);
+ a.remove_edge(1,1);
+ DLIB_TEST(graph_contains_undirected_cycle(a) == false);
+
+ DLIB_TEST(b.number_of_nodes() == 4);
+ b.clear();
+ DLIB_TEST(b.number_of_nodes() == 0);
+
+
+ a.clear();
+
+ /*
+ 1 7
+ | / \
+ 2 6 0
+ \ / |
+ 3 /
+ / \ /
+ 4 5
+ */
+ a.set_number_of_nodes(8);
+ a.add_edge(1,2);
+ a.add_edge(2,3);
+ a.add_edge(3,4);
+ a.add_edge(3,5);
+ a.add_edge(3,6);
+ a.add_edge(6,7);
+ a.add_edge(7,0);
+ a.add_edge(0,5);
+
+ DLIB_TEST(graph_is_connected(a));
+
+ dlib::set<dlib::set<unsigned long>::compare_1b_c>::kernel_1b_c sos;
+
+ dlib::graph<dlib::set<unsigned long>::compare_1b_c, dlib::set<unsigned long>::compare_1b_c>::kernel_1a_c join_tree;
+ unsigned long temp;
+ triangulate_graph_and_find_cliques(a,sos);
+ DLIB_TEST(a.number_of_nodes() == 8);
+
+ create_join_tree(a, join_tree);
+ DLIB_TEST(join_tree.number_of_nodes() == 6);
+ DLIB_TEST(graph_is_connected(join_tree) == true);
+ DLIB_TEST(graph_contains_undirected_cycle(join_tree) == false);
+ DLIB_TEST(is_join_tree(a, join_tree));
+
+ // check old edges
+ DLIB_TEST(a.has_edge(1,2));
+ DLIB_TEST(a.has_edge(2,3));
+ DLIB_TEST(a.has_edge(3,4));
+ DLIB_TEST(a.has_edge(3,5));
+ DLIB_TEST(a.has_edge(3,6));
+ DLIB_TEST(a.has_edge(6,7));
+ DLIB_TEST(a.has_edge(7,0));
+ DLIB_TEST(a.has_edge(0,5));
+
+ DLIB_TEST(graph_is_connected(a));
+
+ DLIB_TEST(sos.size() == 6);
+
+
+ temp = 1; s.add(temp);
+ temp = 2; s.add(temp);
+ DLIB_TEST(sos.is_member(s));
+ s.clear();
+ temp = 2; s.add(temp);
+ temp = 3; s.add(temp);
+ DLIB_TEST(sos.is_member(s));
+ s.clear();
+ temp = 4; s.add(temp);
+ temp = 3; s.add(temp);
+ DLIB_TEST(sos.is_member(s));
+
+ sos.reset();
+ while (sos.move_next())
+ {
+ DLIB_TEST(is_clique(a, sos.element()));
+ DLIB_TEST(is_maximal_clique(a, sos.element()));
+ }
+
+ }
+
+
+ void test_copy()
+ {
+ {
+ graph<int,int>::kernel_1a_c a,b;
+
+ a.set_number_of_nodes(3);
+ a.node(0).data = 1;
+ a.node(1).data = 2;
+ a.node(2).data = 3;
+ a.add_edge(0,1);
+ a.add_edge(0,2);
+ edge(a,0,1) = 4;
+ edge(a,0,2) = 5;
+
+ a.add_edge(0,0);
+ edge(a,0,0) = 9;
+ copy_graph(a, b);
+
+ DLIB_TEST(b.number_of_nodes() == 3);
+ DLIB_TEST(b.node(0).data == 1);
+ DLIB_TEST(b.node(1).data == 2);
+ DLIB_TEST(b.node(2).data == 3);
+ DLIB_TEST(edge(b,0,1) == 4);
+ DLIB_TEST(edge(b,0,2) == 5);
+ DLIB_TEST(edge(b,0,0) == 9);
+ }
+ {
+ graph<int,int>::kernel_1a_c a,b;
+
+ a.set_number_of_nodes(4);
+ a.node(0).data = 1;
+ a.node(1).data = 2;
+ a.node(2).data = 3;
+ a.node(3).data = 8;
+ a.add_edge(0,1);
+ a.add_edge(0,2);
+ a.add_edge(2,3);
+ edge(a,0,1) = 4;
+ edge(a,0,2) = 5;
+ edge(a,2,3) = 6;
+
+ copy_graph(a, b);
+
+ DLIB_TEST(b.number_of_nodes() == 4);
+ DLIB_TEST(b.node(0).data == 1);
+ DLIB_TEST(b.node(1).data == 2);
+ DLIB_TEST(b.node(2).data == 3);
+ DLIB_TEST(b.node(3).data == 8);
+ DLIB_TEST(edge(b,0,1) == 4);
+ DLIB_TEST(edge(b,0,2) == 5);
+ DLIB_TEST(edge(b,2,3) == 6);
+ }
+ }
+
+
+
+ class graph_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the graph object. When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_directed_graph by passing that string to the tester constructor.
+ !*/
+ public:
+ graph_tester (
+ ) :
+ tester ("test_graph",
+ "Runs tests on the graph component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a_c";
+ graph_test<graph<int>::kernel_1a_c>();
+
+ dlog << LINFO << "testing kernel_1a";
+ graph_test<graph<int>::kernel_1a>();
+
+ test_copy();
+ }
+ } a;
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/graph_cuts.cpp b/ml/dlib/dlib/test/graph_cuts.cpp
new file mode 100644
index 000000000..43ba35c16
--- /dev/null
+++ b/ml/dlib/dlib/test/graph_cuts.cpp
@@ -0,0 +1,1217 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/graph_cuts.h>
+#include <dlib/graph_utils.h>
+#include <dlib/directed_graph.h>
+#include <dlib/graph.h>
+#include <dlib/rand.h>
+#include <dlib/hash.h>
+#include <dlib/image_transforms.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.graph_cuts");
+
+// ----------------------------------------------------------------------------------------
+
+ class dense_potts_problem
+ {
+ public:
+ typedef double value_type;
+ private:
+
+ matrix<value_type,0,1> factors1;
+ matrix<value_type> factors2;
+ matrix<node_label,0,1> labels;
+ public:
+
+ dense_potts_problem (
+ unsigned long num_nodes,
+ dlib::rand& rnd
+ )
+ {
+ factors1 = -7*(randm(num_nodes, 1, rnd)-0.5);
+ factors2 = make_symmetric(randm(num_nodes, num_nodes, rnd) > 0.5);
+ labels.set_size(num_nodes);
+ labels = FREE_NODE;
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return factors1.nr(); }
+
+ unsigned long number_of_neighbors (
+ unsigned long // idx
+ ) const { return number_of_nodes()-1; }
+
+ unsigned long get_neighbor_idx (
+ unsigned long node_id1,
+ unsigned long node_id2
+ ) const
+ {
+ if (node_id2 < node_id1)
+ return node_id2;
+ else
+ return node_id2-1;
+ }
+
+ unsigned long get_neighbor (
+ unsigned long node_id,
+ unsigned long idx
+ ) const
+ {
+ DLIB_TEST(node_id < number_of_nodes());
+ DLIB_TEST(idx < number_of_neighbors(node_id));
+ if (idx < node_id)
+ return idx;
+ else
+ return idx+1;
+ }
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ )
+ {
+ labels(idx) = value;
+ }
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const
+ {
+ return labels(idx);
+ }
+
+
+ value_type factor_value (unsigned long idx) const
+ {
+ DLIB_TEST(idx < number_of_nodes());
+
+ return factors1(idx);
+ }
+
+ value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
+ {
+ DLIB_TEST(idx1 != idx2);
+ DLIB_TEST(idx1 < number_of_nodes());
+ DLIB_TEST(idx2 < number_of_nodes());
+ DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1));
+ DLIB_TEST(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2));
+
+ return factors2(idx1, idx2);
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class image_potts_problem
+ {
+ public:
+ typedef double value_type;
+ const static unsigned long max_number_of_neighbors = 4;
+ private:
+
+ matrix<value_type,0,1> factors1;
+ matrix<value_type> factors2;
+ matrix<node_label,0,1> labels;
+ long nr;
+ long nc;
+ rectangle rect, inner_rect;
+ mutable long count;
+ public:
+
+ image_potts_problem (
+ long nr_,
+ long nc_,
+ dlib::rand& rnd
+ ) : nr(nr_), nc(nc_)
+ {
+ rect = rectangle(0,0,nc-1,nr-1);
+ inner_rect = shrink_rect(rect,1);
+ const unsigned long num_nodes = nr*nc;
+ factors1 = -7*(randm(num_nodes, 1, rnd));
+ factors2 = randm(num_nodes, 4, rnd) > 0.5;
+
+ //factors1 = 0;
+ //set_rowm(factors1, range(0, factors1.nr()/2)) = -1;
+
+ labels.set_size(num_nodes);
+ labels = FREE_NODE;
+
+ count = 0;
+ }
+
+ ~image_potts_problem()
+ {
+ dlog << LTRACE << "interface calls: " << count;
+ dlog << LTRACE << "labels hash: "<< murmur_hash3_128bit(&labels(0), labels.size()*sizeof(labels(0)), 0).first;
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return factors1.nr(); }
+
+ unsigned long number_of_neighbors (
+ unsigned long idx
+ ) const
+ {
+ ++count;
+ const point& p = get_loc(idx);
+ if (inner_rect.contains(p))
+ return 4;
+ else if (p == rect.tl_corner() ||
+ p == rect.bl_corner() ||
+ p == rect.tr_corner() ||
+ p == rect.br_corner() )
+ return 2;
+ else
+ return 3;
+ }
+
+ unsigned long get_neighbor_idx (
+ long node_id1,
+ long node_id2
+ ) const
+ {
+ ++count;
+ const point& p = get_loc(node_id1);
+ long ret = 0;
+ if (rect.contains(p + point(1,0)))
+ {
+ if (node_id2-node_id1 == 1)
+ return ret;
+ ++ret;
+ }
+
+ if (rect.contains(p - point(1,0)))
+ {
+ if (node_id2-node_id1 == -1)
+ return ret;
+ ++ret;
+ }
+
+ if (rect.contains(p + point(0,1)))
+ {
+ if (node_id2-node_id1 == nc)
+ return ret;
+ ++ret;
+ }
+
+ return ret;
+ }
+
+ unsigned long get_neighbor (
+ long node_id,
+ long idx
+ ) const
+ {
+ ++count;
+ const point& p = get_loc(node_id);
+ if (rect.contains(p + point(1,0)))
+ {
+ if (idx == 0)
+ return node_id+1;
+ --idx;
+ }
+
+ if (rect.contains(p - point(1,0)))
+ {
+ if (idx == 0)
+ return node_id-1;
+ --idx;
+ }
+
+ if (rect.contains(p + point(0,1)))
+ {
+ if (idx == 0)
+ return node_id+nc;
+ --idx;
+ }
+
+ return node_id-nc;
+ }
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ )
+ {
+ ++count;
+ labels(idx) = value;
+ }
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const
+ {
+ ++count;
+ return labels(idx);
+ }
+
+ value_type factor_value (unsigned long idx) const
+ {
+ ++count;
+ DLIB_TEST(idx < (unsigned long)number_of_nodes());
+
+ return factors1(idx);
+ }
+
+ value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
+ {
+ ++count;
+ DLIB_TEST(idx1 != idx2);
+ DLIB_TEST(idx1 < (unsigned long)number_of_nodes());
+ DLIB_TEST(idx2 < (unsigned long)number_of_nodes());
+
+ // make this function symmetric
+ if (idx1 > idx2)
+ swap(idx1,idx2);
+
+
+ DLIB_TEST(get_neighbor(idx1, get_neighbor_idx(idx1, idx2)) == idx2);
+ DLIB_TEST(get_neighbor(idx2, get_neighbor_idx(idx2, idx1)) == idx1);
+
+ // the neighbor relationship better be symmetric
+ DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1));
+ DLIB_TEST_MSG(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2),
+ "\n idx1: "<< idx1 <<
+ "\n idx2: "<< idx2 <<
+ "\n get_neighbor_idx(idx2,idx1): "<< get_neighbor_idx(idx2,idx1) <<
+ "\n number_of_neighbors(idx2): " << number_of_neighbors(idx2) <<
+ "\n nr: "<< nr <<
+ "\n nc: "<< nc
+ );
+
+ return factors2(idx1, get_neighbor_idx(idx1,idx2));
+ }
+
+ private:
+ point get_loc (
+ const unsigned long& idx
+ ) const
+ {
+ return point(idx%nc, idx/nc);
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename potts_model>
+ void brute_force_potts_model (
+ potts_model& g
+ )
+ {
+ potts_model m(g);
+
+ const unsigned long num = (unsigned long)std::pow(2.0, (double)m.number_of_nodes());
+
+ double best_score = -std::numeric_limits<double>::infinity();
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ for (unsigned long j = 0; j < m.number_of_nodes(); ++j)
+ {
+ unsigned long T = (1)<<j;
+ T = (T&i);
+ if (T != 0)
+ m.set_label(j,SINK_CUT);
+ else
+ m.set_label(j,SOURCE_CUT);
+ }
+
+
+ double score = potts_model_score(m);
+ if (score > best_score)
+ {
+ best_score = score;
+ g = m;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename graph_type>
+ void brute_force_potts_model_on_graph (
+ const graph_type& g,
+ std::vector<node_label>& labels_
+ )
+ {
+ std::vector<node_label> labels;
+ labels.resize(g.number_of_nodes());
+
+ const unsigned long num = (unsigned long)std::pow(2.0, (double)g.number_of_nodes());
+
+ double best_score = -std::numeric_limits<double>::infinity();
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
+ {
+ unsigned long T = (1)<<j;
+ T = (T&i);
+ if (T != 0)
+ labels[j] = SINK_CUT;
+ else
+ labels[j] = SOURCE_CUT;
+ }
+
+
+ double score = potts_model_score(g,labels);
+ if (score > best_score)
+ {
+ best_score = score;
+ labels_ = labels;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename graph_type>
+ void make_random_undirected_graph(
+ dlib::rand& rnd,
+ graph_type& g
+ )
+ {
+ typedef typename graph_type::edge_type edge_weight_type;
+ g.clear();
+ const unsigned int num_nodes = rnd.get_random_32bit_number()%8;
+ g.set_number_of_nodes(num_nodes);
+
+ const unsigned int num_edges = static_cast<unsigned int>(num_nodes*(num_nodes-1)/2*rnd.get_random_double() + 0.5);
+
+ // add the right number of randomly selected edges
+ unsigned int count = 0;
+ while (count < num_edges)
+ {
+ unsigned long i = rnd.get_random_32bit_number()%g.number_of_nodes();
+ unsigned long j = rnd.get_random_32bit_number()%g.number_of_nodes();
+ if (i != j && g.has_edge(i, j) == false)
+ {
+ ++count;
+ g.add_edge(i, j);
+ edge(g, i, j) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+ }
+
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ g.node(i).data = static_cast<edge_weight_type>(rnd.get_random_gaussian()*200);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_graph_potts_model(
+ dlib::rand& rnd
+ )
+ {
+ using namespace std;
+ double brute_force_score;
+ double graph_cut_score;
+
+ graph<double,double>::kernel_1a_c temp;
+ make_random_undirected_graph(rnd,temp);
+
+ {
+ std::vector<node_label> labels;
+
+ brute_force_potts_model_on_graph(temp, labels);
+
+ for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
+ {
+ dlog << LTRACE << "node " << i << ": "<< (int)labels[i];
+ }
+
+ brute_force_score = potts_model_score(temp, labels);
+ dlog << LTRACE << "brute force score: "<< brute_force_score;
+ }
+ dlog << LTRACE << "******************";
+
+ {
+ std::vector<node_label> labels;
+ find_max_factor_graph_potts(temp, labels);
+ DLIB_TEST(temp.number_of_nodes() == labels.size());
+
+ for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
+ {
+ dlog << LTRACE << "node " << i << ": "<< (int)labels[i];
+ }
+ graph_cut_score = potts_model_score(temp, labels);
+ dlog << LTRACE << "graph cut score: "<< graph_cut_score;
+ }
+
+ DLIB_TEST_MSG(graph_cut_score == brute_force_score, std::abs(graph_cut_score - brute_force_score));
+
+ dlog << LTRACE << "##################";
+ dlog << LTRACE << "##################";
+ dlog << LTRACE << "##################";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename potts_prob>
+ void impl_test_potts_model (
+ potts_prob& p
+ )
+ {
+ using namespace std;
+ double brute_force_score;
+ double graph_cut_score;
+
+ {
+ potts_prob temp(p);
+ brute_force_potts_model(temp);
+
+ for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
+ {
+ dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i);
+ }
+ brute_force_score = potts_model_score(temp);
+ dlog << LTRACE << "brute force score: "<< brute_force_score;
+ }
+ dlog << LTRACE << "******************";
+
+ {
+ potts_prob temp(p);
+ find_max_factor_graph_potts(temp);
+
+ for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
+ {
+ dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i);
+ }
+ graph_cut_score = potts_model_score(temp);
+ dlog << LTRACE << "graph cut score: "<< graph_cut_score;
+ }
+
+ DLIB_TEST_MSG(graph_cut_score == brute_force_score, std::abs(graph_cut_score - brute_force_score));
+
+ dlog << LTRACE << "##################";
+ dlog << LTRACE << "##################";
+ dlog << LTRACE << "##################";
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// BASIC MIN CUT STUFF
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename directed_graph>
+ void brute_force_min_cut (
+ directed_graph& g,
+ unsigned long source,
+ unsigned long sink
+ )
+ {
+ typedef typename directed_graph::edge_type edge_weight_type;
+ const unsigned long num = (unsigned long)std::pow(2.0, (double)g.number_of_nodes());
+
+ std::vector<node_label> best_cut(g.number_of_nodes(),FREE_NODE);
+
+ edge_weight_type best_score = std::numeric_limits<edge_weight_type>::max();
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
+ {
+ unsigned long T = (1)<<j;
+ T = (T&i);
+ if (T != 0)
+ g.node(j).data = SINK_CUT;
+ else
+ g.node(j).data = SOURCE_CUT;
+ }
+
+ // ignore cuts that don't label the source or sink node the way we want.
+ if (g.node(source).data != SOURCE_CUT ||
+ g.node(sink).data != SINK_CUT)
+ continue;
+
+ edge_weight_type score = graph_cut_score(g);
+ if (score < best_score)
+ {
+ best_score = score;
+ for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
+ best_cut[j] = g.node(j).data;
+ }
+ }
+
+ for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
+ g.node(j).data = best_cut[j];
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename directed_graph>
+ void print_graph(
+ const directed_graph& g
+ )
+ {
+ using namespace std;
+ dlog << LTRACE << "number of nodes: "<< g.number_of_nodes();
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < g.node(i).number_of_children(); ++n)
+ dlog << LTRACE << i << " -(" << g.node(i).child_edge(n) << ")-> " << g.node(i).child(n).index();
+ }
+ }
+
+ template <typename directed_graph>
+ void copy_edge_weights (
+ directed_graph& dest,
+ const directed_graph& src
+ )
+ {
+ for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < src.node(i).number_of_children(); ++n)
+ {
+ dest.node(i).child_edge(n) = src.node(i).child_edge(n);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename graph_type>
+ void pick_random_source_and_sink (
+ dlib::rand& rnd,
+ const graph_type& g,
+ unsigned long& source,
+ unsigned long& sink
+ )
+ {
+ source = rnd.get_random_32bit_number()%g.number_of_nodes();
+ sink = rnd.get_random_32bit_number()%g.number_of_nodes();
+ while (sink == source)
+ sink = rnd.get_random_32bit_number()%g.number_of_nodes();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename dgraph_type>
+ void make_random_graph(
+ dlib::rand& rnd,
+ dgraph_type& g,
+ unsigned long& source,
+ unsigned long& sink
+ )
+ {
+ typedef typename dgraph_type::edge_type edge_weight_type;
+ g.clear();
+ const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2;
+ g.set_number_of_nodes(num_nodes);
+
+ const unsigned int num_edges = static_cast<unsigned int>(num_nodes*(num_nodes-1)/2*rnd.get_random_double() + 0.5);
+
+ // add the right number of randomly selected edges
+ unsigned int count = 0;
+ while (count < num_edges)
+ {
+ unsigned long parent = rnd.get_random_32bit_number()%g.number_of_nodes();
+ unsigned long child = rnd.get_random_32bit_number()%g.number_of_nodes();
+ if (parent != child && g.has_edge(parent, child) == false)
+ {
+ ++count;
+ g.add_edge(parent, child);
+ edge(g, parent, child) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+
+ // have to have edges both ways
+ swap(parent, child);
+ g.add_edge(parent, child);
+ edge(g, parent, child) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+ }
+
+ pick_random_source_and_sink(rnd, g, source, sink);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename dgraph_type>
+ void make_random_chain_graph(
+ dlib::rand& rnd,
+ dgraph_type& g,
+ unsigned long& source,
+ unsigned long& sink
+ )
+ {
+ typedef typename dgraph_type::edge_type edge_weight_type;
+ g.clear();
+ const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2;
+ g.set_number_of_nodes(num_nodes);
+
+ for (unsigned long i = 1; i < g.number_of_nodes(); ++i)
+ {
+ g.add_edge(i,i-1);
+ g.add_edge(i-1,i);
+ edge(g, i, i-1) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g, i-1, i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+
+ pick_random_source_and_sink(rnd, g, source, sink);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename dgraph_type>
+ void make_random_grid_graph(
+ dlib::rand& rnd,
+ dgraph_type& g,
+ unsigned long& source,
+ unsigned long& sink
+ )
+ /*!
+ ensures
+ - makes a grid graph like the kind used for potts models.
+ !*/
+ {
+ typedef typename dgraph_type::edge_type edge_weight_type;
+ g.clear();
+ const long nr = rnd.get_random_32bit_number()%2 + 2;
+ const long nc = rnd.get_random_32bit_number()%2 + 2;
+ g.set_number_of_nodes(nr*nc+2);
+
+ const rectangle rect(0,0,nc-1,nr-1);
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ const point p(c,r);
+ const unsigned long i = p.y()*nc + p.x();
+
+ const point n2(c-1,r);
+ if (rect.contains(n2))
+ {
+ const unsigned long j = n2.y()*nc + n2.x();
+ g.add_edge(i,j);
+ g.add_edge(j,i);
+ edge(g,i,j) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g,j,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+
+ const point n4(c,r-1);
+ if (rect.contains(n4))
+ {
+ const unsigned long j = n4.y()*nc + n4.x();
+ g.add_edge(i,j);
+ g.add_edge(j,i);
+ edge(g,i,j) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g,j,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+ }
+ }
+
+ // use the last two nodes as source and sink. Also connect them to all the other nodes.
+ source = g.number_of_nodes()-1;
+ sink = g.number_of_nodes()-2;
+ for (unsigned long i = 0; i < g.number_of_nodes()-2; ++i)
+ {
+ g.add_edge(i,source);
+ g.add_edge(source,i);
+ g.add_edge(i,sink);
+ g.add_edge(sink,i);
+
+ edge(g,i,source) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g,source,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g,i,sink) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ edge(g,sink,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename min_cut, typename dgraph_type>
+ void run_test_on_graphs (
+ const min_cut& mc,
+ dgraph_type& g1,
+ dgraph_type& g2,
+ unsigned long source,
+ unsigned long sink
+ )
+ {
+ typedef typename dgraph_type::edge_type edge_weight_type;
+ using namespace std;
+
+
+ dlog << LTRACE << "number of nodes: "<< g1.number_of_nodes();
+ dlog << LTRACE << "is graph connected: "<< graph_is_connected(g1);
+ dlog << LTRACE << "has self loops: "<< graph_contains_length_one_cycle(g1);
+ dlog << LTRACE << "SOURCE_CUT: " << source;
+ dlog << LTRACE << "SINK_CUT: " << sink;
+ mc(g1, source, sink);
+ brute_force_min_cut(g2, source, sink);
+
+ print_graph(g1);
+
+ // make sure the flow residuals are 0 at the cut locations
+ for (unsigned long i = 0; i < g1.number_of_nodes(); ++i)
+ {
+ for (unsigned long j = 0; j < g1.node(i).number_of_children(); ++j)
+ {
+ if ((g1.node(i).data == SOURCE_CUT && g1.node(i).child(j).data != SOURCE_CUT) ||
+ (g1.node(i).data != SINK_CUT && g1.node(i).child(j).data == SINK_CUT)
+ )
+ {
+ DLIB_TEST_MSG(g1.node(i).child_edge(j) == 0, g1.node(i).child_edge(j));
+ }
+ }
+ }
+
+ // copy the edge weights from g2 back to g1 so we can compute cut scores
+ copy_edge_weights(g1, g2);
+
+ DLIB_TEST(g1.number_of_nodes() == g2.number_of_nodes());
+ for (unsigned long i = 0; i < g1.number_of_nodes(); ++i)
+ {
+ dlog << LTRACE << "node " << i << ": " << (int)g1.node(i).data << ", " << (int)g2.node(i).data;
+ if (g1.node(i).data != g2.node(i).data)
+ {
+ edge_weight_type cut_score = graph_cut_score(g1);
+ edge_weight_type brute_force_score = graph_cut_score(g2);
+ dlog << LTRACE << "graph cut score: "<< cut_score;
+ dlog << LTRACE << "brute force score: "<< brute_force_score;
+
+ if (brute_force_score != cut_score)
+ print_graph(g1);
+ DLIB_TEST_MSG(brute_force_score == cut_score,std::abs(brute_force_score-cut_score));
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename min_cut, typename edge_weight_type>
+ void test_graph_cuts(dlib::rand& rnd)
+ {
+ typedef typename dlib::directed_graph<node_label, edge_weight_type>::kernel_1a_c dgraph_type;
+ // we will create two identical graphs.
+ dgraph_type g1, g2;
+ min_cut mc;
+
+ unsigned long source, sink;
+
+ dlib::rand rnd_copy(rnd);
+ make_random_graph(rnd,g1, source, sink);
+ make_random_graph(rnd_copy,g2, source, sink);
+ run_test_on_graphs(mc, g1, g2, source, sink);
+
+ rnd_copy = rnd;
+ make_random_grid_graph(rnd,g1, source, sink);
+ make_random_grid_graph(rnd_copy,g2, source, sink);
+ run_test_on_graphs(mc, g1, g2, source, sink);
+
+ rnd_copy = rnd;
+ make_random_chain_graph(rnd,g1, source, sink);
+ make_random_chain_graph(rnd_copy,g2, source, sink);
+ run_test_on_graphs(mc, g1, g2, source, sink);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_potts_grid_problem
+ {
+ public:
+ test_potts_grid_problem(int seed_) :seed(seed_){}
+ int seed;
+
+ long nr() const { return 3;}
+ long nc() const { return 3;}
+
+ typedef double value_type;
+
+ value_type factor_value(unsigned long idx) const
+ {
+ // Copy idx into a char buffer to avoid warnings about violation of strict aliasing
+ // rules when murmur_hash3() gets inlined into this function.
+ char buf[sizeof(idx)];
+ memcpy(buf,&idx,sizeof(idx));
+ // now hash the buffer rather than idx.
+ return ((double)murmur_hash3(buf, sizeof(buf), seed) - std::numeric_limits<uint32>::max()/2.0)/1000.0;
+ }
+
+ value_type factor_value_disagreement(unsigned long idx1, unsigned long idx2) const
+ {
+ return std::abs(factor_value(idx1+idx2)/10.0);
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename prob_type>
+ void brute_force_potts_grid_problem(
+ const prob_type& prob,
+ array2d<unsigned char>& labels
+ )
+ {
+ const unsigned long num = (unsigned long)std::pow(2.0, (double)prob.nr()*prob.nc());
+
+ array2d<unsigned char> temp(prob.nr(), prob.nc());
+ unsigned char* data = &temp[0][0];
+
+ double best_score = -std::numeric_limits<double>::infinity();
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ for (unsigned long j = 0; j < temp.size(); ++j)
+ {
+ unsigned long T = (1)<<j;
+ T = (T&i);
+ if (T != 0)
+ *(data + j) = SINK_CUT;
+ else
+ *(data + j) = SOURCE_CUT;
+ }
+
+
+ double score = potts_model_score(prob, temp);
+ if (score > best_score)
+ {
+ best_score = score;
+ assign_image(labels, temp);
+ }
+ }
+ }
+
+ void test_inf()
+ {
+ graph<double,double>::kernel_1a_c g;
+ g.set_number_of_nodes(4);
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ g.node(0).data = std::numeric_limits<double>::infinity();
+ g.node(1).data = -std::numeric_limits<double>::infinity();
+ g.node(2).data = std::numeric_limits<double>::infinity();
+ g.node(3).data = -std::numeric_limits<double>::infinity();
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 1;
+ edge(g,3,0) = 1;
+
+ std::vector<node_label> labels;
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] != 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] == 0);
+
+ // --------------------------
+
+ g.node(0).data = std::numeric_limits<double>::infinity();
+ g.node(1).data = 0;
+ g.node(2).data = 0;
+ g.node(3).data = -3;
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 1;
+ edge(g,3,0) = 1;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] != 0);
+ DLIB_TEST(labels[1] != 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] == 0);
+
+ // --------------------------
+
+ g.node(0).data = std::numeric_limits<double>::infinity();
+ g.node(1).data = 0;
+ g.node(2).data = 0;
+ g.node(3).data = -0.1;
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 1;
+ edge(g,3,0) = 1;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] != 0);
+ DLIB_TEST(labels[1] != 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] != 0);
+
+ // --------------------------
+
+ g.node(0).data = std::numeric_limits<double>::infinity();
+ g.node(1).data = 0;
+ g.node(2).data = 0;
+ g.node(3).data = -0.1;
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 0;
+ edge(g,3,0) = 0;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] != 0);
+ DLIB_TEST(labels[1] != 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] == 0);
+
+ // --------------------------
+
+ g.node(0).data = -std::numeric_limits<double>::infinity();
+ g.node(1).data = 0;
+ g.node(2).data = 0;
+ g.node(3).data = 0.1;
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 0;
+ edge(g,3,0) = 0;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(labels[2] == 0);
+ DLIB_TEST(labels[3] != 0);
+
+ // --------------------------
+
+ g.node(0).data = -std::numeric_limits<double>::infinity();
+ g.node(1).data = std::numeric_limits<double>::infinity();
+ g.node(2).data = 0;
+ g.node(3).data = 0.1;
+
+ edge(g,0,1) = 1;
+ edge(g,1,2) = 1;
+ edge(g,2,3) = 0;
+ edge(g,3,0) = 0;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] != 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] != 0);
+
+ // --------------------------
+
+ g.node(0).data = -10;
+ g.node(1).data = std::numeric_limits<double>::infinity();
+ g.node(2).data = 0;
+ g.node(3).data = 0.1;
+
+ edge(g,0,1) = std::numeric_limits<double>::infinity();
+ edge(g,1,2) = std::numeric_limits<double>::infinity();
+ edge(g,2,3) = std::numeric_limits<double>::infinity();
+ edge(g,3,0) = std::numeric_limits<double>::infinity();
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] != 0);
+ DLIB_TEST(labels[1] != 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] != 0);
+
+ // --------------------------
+
+ g.node(0).data = 10;
+ g.node(1).data = -std::numeric_limits<double>::infinity();
+ g.node(2).data = 20.05;
+ g.node(3).data = -0.1;
+
+ edge(g,0,1) = std::numeric_limits<double>::infinity();
+ edge(g,1,2) = 10;
+ edge(g,2,3) = std::numeric_limits<double>::infinity();
+ edge(g,3,0) = 10;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(labels[2] == 0);
+ DLIB_TEST(labels[3] == 0);
+
+ // --------------------------
+
+ g.node(0).data = 10;
+ g.node(1).data = -std::numeric_limits<double>::infinity();
+ g.node(2).data = 20.2;
+ g.node(3).data = -0.1;
+
+ edge(g,0,1) = std::numeric_limits<double>::infinity();
+ edge(g,1,2) = 10;
+ edge(g,2,3) = std::numeric_limits<double>::infinity();
+ edge(g,3,0) = 10;
+
+ find_max_factor_graph_potts(g, labels);
+
+ DLIB_TEST(labels[0] == 0);
+ DLIB_TEST(labels[1] == 0);
+ DLIB_TEST(labels[2] != 0);
+ DLIB_TEST(labels[3] != 0);
+ }
+
+ struct potts_pair_image_model
+ {
+ typedef double value_type;
+
+ template <typename pixel_type1, typename pixel_type2>
+ value_type factor_value (
+ const pixel_type1& ,
+ const pixel_type2& v2
+ ) const
+ {
+ return v2;
+ }
+
+ template <typename pixel_type>
+ value_type factor_value_disagreement (
+ const pixel_type& v1,
+ const pixel_type& v2
+ ) const
+ {
+ if (v1 == v2)
+ return 10;
+ else
+ return 0;
+ }
+ };
+
+ void test_potts_pair_grid()
+ {
+ array2d<int> img1(40,40);
+ array2d<double> img2(40,40);
+
+ assign_all_pixels(img1, -1);
+ assign_all_pixels(img2, -1);
+
+ img1[4][4] = 1000;
+
+ img2[4][3] = 1;
+ img2[4][4] = 1;
+ img2[4][5] = 1;
+ img2[3][3] = 1;
+ img2[3][4] = 1;
+ img2[3][5] = 1;
+ img2[5][3] = 1;
+ img2[5][4] = 1;
+ img2[5][5] = 1;
+
+ array2d<unsigned char> labels;
+ find_max_factor_graph_potts(make_potts_grid_problem(potts_pair_image_model(),img2,img1), labels);
+
+ dlog << LINFO << "num true labels: " << sum(matrix_cast<int>(mat(labels)!=0));
+ DLIB_TEST(sum(matrix_cast<int>(mat(labels)!=0)) == 9);
+ DLIB_TEST(sum(matrix_cast<int>(mat(labels)==0)) == (int)img1.size()-9);
+
+ DLIB_TEST(labels[4][3] != 0);
+ DLIB_TEST(labels[4][4] != 0);
+ DLIB_TEST(labels[4][5] != 0);
+ DLIB_TEST(labels[3][3] != 0);
+ DLIB_TEST(labels[3][4] != 0);
+ DLIB_TEST(labels[3][5] != 0);
+ DLIB_TEST(labels[5][3] != 0);
+ DLIB_TEST(labels[5][4] != 0);
+ DLIB_TEST(labels[5][5] != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class graph_cuts_tester : public tester
+ {
+ public:
+ graph_cuts_tester (
+ ) :
+ tester ("test_graph_cuts",
+ "Runs tests on the graph cuts tools.")
+ {}
+
+ dlib::rand rnd;
+
+ void perform_test (
+ )
+ {
+ test_potts_pair_grid();
+ test_inf();
+
+ for (int i = 0; i < 500; ++i)
+ {
+ array2d<unsigned char> labels, brute_labels;
+ test_potts_grid_problem prob(i);
+ find_max_factor_graph_potts(prob, labels);
+ brute_force_potts_grid_problem(prob, brute_labels);
+
+ DLIB_TEST(labels.nr() == brute_labels.nr());
+ DLIB_TEST(labels.nc() == brute_labels.nc());
+ for (long r = 0; r < labels.nr(); ++r)
+ {
+ for (long c = 0; c < labels.nc(); ++c)
+ {
+ bool normal = (labels[r][c] != 0);
+ bool brute = (brute_labels[r][c] != 0);
+ DLIB_TEST(normal == brute);
+ }
+ }
+ }
+
+ for (int i = 0; i < 1000; ++i)
+ {
+ print_spinner();
+ dlog << LTRACE << "test_grpah_cuts<short> iter: " << i;
+ test_graph_cuts<min_cut,short>(rnd);
+ print_spinner();
+ dlog << LTRACE << "test_grpah_cuts<double> iter: " << i;
+ test_graph_cuts<min_cut,double>(rnd);
+ }
+
+
+ for (int k = 0; k < 300; ++k)
+ {
+ dlog << LTRACE << "image_potts_problem iter " << k;
+ print_spinner();
+ image_potts_problem p(3,3, rnd);
+ impl_test_potts_model(p);
+ }
+ for (int k = 0; k < 300; ++k)
+ {
+ dlog << LTRACE << "dense_potts_problem iter " << k;
+ print_spinner();
+ dense_potts_problem p(6, rnd);
+ impl_test_potts_model(p);
+ }
+
+ for (int k = 0; k < 300; ++k)
+ {
+ dlog << LTRACE << "dense_potts_problem iter " << k;
+ print_spinner();
+ test_graph_potts_model(rnd);
+ }
+ }
+ } a;
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/graph_labeler.cpp b/ml/dlib/dlib/test/graph_labeler.cpp
new file mode 100644
index 000000000..112873896
--- /dev/null
+++ b/ml/dlib/dlib/test/graph_labeler.cpp
@@ -0,0 +1,472 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/svm_threaded.h>
+#include <dlib/data_io.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+
+ logger dlog("test.graph_cuts");
+
+
+ template <
+ typename graph_type,
+ typename samples_type,
+ typename labels_type
+ >
+ void make_data(
+ samples_type& samples,
+ labels_type& labels
+ )
+ {
+ //samples.clear();
+ //labels.clear();
+
+ std::vector<bool> label;
+ graph_type g;
+
+ // ---------------------------
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data = 0, 0, 1; label[0] = true;
+ g.node(1).data = 0, 0, 1; label[1] = true;
+ g.node(2).data = 0, 1, 0; label[2] = false;
+ g.node(3).data = 0, 1, 0; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1) = 1, 1;
+ edge(g,1,2) = 1, 1;
+ edge(g,2,3) = 1, 1;
+ edge(g,3,0) = 1, 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+
+ g.clear();
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data = 0, 0, 1; label[0] = true;
+ g.node(1).data = 0, 0, 0; label[1] = true;
+ g.node(2).data = 0, 1, 0; label[2] = false;
+ g.node(3).data = 0, 0, 0; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1) = 1, 0;
+ edge(g,1,2) = 0, 1;
+ edge(g,2,3) = 1, 0;
+ edge(g,3,0) = 0, 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+
+ g.clear();
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data = 0, 1, 0; label[0] = false;
+ g.node(1).data = 0, 1, 0; label[1] = false;
+ g.node(2).data = 0, 1, 0; label[2] = false;
+ g.node(3).data = 0, 0, 0; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1) = 1, 0;
+ edge(g,1,2) = 0, 1;
+ edge(g,2,3) = 1, 0;
+ edge(g,3,0) = 0, 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+ }
+
+
+
+
+ template <
+ typename graph_type,
+ typename samples_type,
+ typename labels_type
+ >
+ void make_data_sparse(
+ samples_type& samples,
+ labels_type& labels
+ )
+ {
+ //samples.clear();
+ //labels.clear();
+
+ std::vector<bool> label;
+ graph_type g;
+ typename graph_type::edge_type v;
+
+ // ---------------------------
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data[2] = 1; label[0] = true;
+ g.node(1).data[2] = 1; label[1] = true;
+ g.node(2).data[1] = 1; label[2] = false;
+ g.node(3).data[1] = 1; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+ g.add_edge(3,1);
+
+ v[0] = 1; v[1] = 1;
+ edge(g,0,1) = v;
+ edge(g,1,2) = v;
+ edge(g,2,3) = v;
+ edge(g,3,0) = v;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+
+ g.clear();
+ g.set_number_of_nodes(5);
+ label.resize(g.number_of_nodes());
+ g.node(0).data[2] = 1; label[0] = true;
+ g.node(1).data[0] = 0; label[1] = true;
+ g.node(2).data[1] = 1; label[2] = false;
+ g.node(3).data[0] = 0; label[3] = false;
+ label[4] = true;
+
+ g.add_edge(0,1);
+ g.add_edge(1,4);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1)[0] = 1;
+ edge(g,1,4)[0] = 1;
+ edge(g,1,2)[1] = 1;
+ edge(g,2,3)[0] = 1;
+ edge(g,3,0)[1] = 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+
+ g.clear();
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data[1] = 1; label[0] = false;
+ g.node(1).data[1] = 1; label[1] = false;
+ g.node(2).data[1] = 1; label[2] = false;
+ g.node(3).data[1] = 0; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1)[0] = 1;
+ edge(g,1,2)[1] = 1;
+ edge(g,2,3)[0] = 1;
+ edge(g,3,0)[1] = 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+ }
+
+
+
+
+
+
+ template <
+ typename graph_type,
+ typename samples_type,
+ typename labels_type
+ >
+ void make_data2(
+ samples_type& samples,
+ labels_type& labels
+ )
+ {
+ //samples.clear();
+ //labels.clear();
+
+ std::vector<bool> label;
+ graph_type g;
+
+ // ---------------------------
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data = 0, 0, 1; label[0] = true;
+ g.node(1).data = 0, 0, 1; label[1] = true;
+ g.node(2).data = 0, 1, 0; label[2] = false;
+ g.node(3).data = 0, 1, 0; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ edge(g,0,1) = 1, 1;
+ edge(g,1,2) = 1, 1;
+ edge(g,2,3) = 1, 1;
+ edge(g,3,0) = 1, 1;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+ }
+
+
+
+
+ template <
+ typename graph_type,
+ typename samples_type,
+ typename labels_type
+ >
+ void make_data2_sparse(
+ samples_type& samples,
+ labels_type& labels
+ )
+ {
+ //samples.clear();
+ //labels.clear();
+
+ std::vector<bool> label;
+ graph_type g;
+ typename graph_type::edge_type v;
+
+ // ---------------------------
+ g.set_number_of_nodes(4);
+ label.resize(g.number_of_nodes());
+ g.node(0).data[2] = 1; label[0] = true;
+ g.node(1).data[2] = 1; label[1] = true;
+ g.node(2).data[1] = 1; label[2] = false;
+ g.node(3).data[1] = 1; label[3] = false;
+
+ g.add_edge(0,1);
+ g.add_edge(1,2);
+ g.add_edge(2,3);
+ g.add_edge(3,0);
+
+ v[0] = 1; v[1] = 1;
+ edge(g,0,1) = v;
+ edge(g,1,2) = v;
+ edge(g,2,3) = v;
+ edge(g,3,0) = v;
+ samples.push_back(g);
+ labels.push_back(label);
+ // ---------------------------
+
+ }
+
+
+
+
+
+ template <
+ typename node_vector_type,
+ typename edge_vector_type,
+ typename vector_type,
+ typename graph_type
+ >
+ void test1(
+ const dlib::array<graph_type>& samples,
+ const std::vector<std::vector<bool> >& labels
+ )
+ {
+ dlog << LINFO << "begin test1()";
+
+ structural_graph_labeling_trainer<vector_type> trainer;
+ //trainer.be_verbose();
+ trainer.set_epsilon(1e-12);
+ graph_labeler<vector_type> labeler = trainer.train(samples, labels);
+
+
+ // test serialization code for the labeler.
+ std::ostringstream sout;
+ serialize(labeler, sout);
+ std::istringstream sin(sout.str());
+ labeler = graph_labeler<vector_type>();
+ deserialize(labeler, sin);
+
+ std::vector<bool> temp;
+ for (unsigned long k = 0; k < samples.size(); ++k)
+ {
+ temp = labeler(samples[k]);
+ for (unsigned long i = 0; i < temp.size(); ++i)
+ {
+ const bool true_label = (labels[k][i] != 0);
+ const bool pred_label = (temp[i] != 0);
+ DLIB_TEST(true_label == pred_label);
+ }
+ }
+
+ matrix<double> cv;
+
+ cv = test_graph_labeling_function(labeler, samples, labels);
+ DLIB_TEST(sum(cv) == 2);
+ cv = cross_validate_graph_labeling_trainer(trainer, samples, labels, 4);
+ DLIB_TEST(sum(cv) == 2);
+
+ dlog << LINFO << "edge weights: " << trans(sparse_to_dense(labeler.get_edge_weights()));
+ dlog << LINFO << "node weights: " << trans(sparse_to_dense(labeler.get_node_weights()));
+ }
+
+
+
+ class graph_labeling_tester : public tester
+ {
+ public:
+ graph_labeling_tester (
+ ) :
+ tester ("test_graph_labeling",
+ "Runs tests on the graph labeling component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ // test with dense vectors
+ {
+ typedef matrix<double,3,1> node_vector_type;
+ typedef matrix<double,2,1> edge_vector_type;
+ typedef matrix<double,0,1> vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+ print_spinner();
+ // test with dense vectors and sparse vectors together
+ {
+ typedef matrix<double,3,1> node_vector_type;
+ typedef matrix<double,2,1> edge_vector_type;
+ typedef std::map<unsigned long,double> vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+ make_data<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+ print_spinner();
+ // test with sparse vectors
+ {
+ typedef std::vector<std::pair<unsigned long,double> > vector_type;
+ typedef std::map<unsigned long, double> edge_vector_type;
+ typedef std::map<unsigned long, double> node_vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data_sparse<graph_type>(samples, labels);
+ make_data_sparse<graph_type>(samples, labels);
+ make_data_sparse<graph_type>(samples, labels);
+ make_data_sparse<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+
+
+
+ print_spinner();
+ // test with dense vectors
+ {
+ typedef matrix<double,3,1> node_vector_type;
+ typedef matrix<double,2,1> edge_vector_type;
+ typedef matrix<double,0,1> vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data2<graph_type>(samples, labels);
+ make_data2<graph_type>(samples, labels);
+ make_data2<graph_type>(samples, labels);
+ make_data2<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+ print_spinner();
+ // test with sparse vectors
+ {
+ typedef std::vector<std::pair<unsigned long,double> > vector_type;
+ typedef std::map<unsigned long, double> edge_vector_type;
+ typedef std::map<unsigned long, double> node_vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+ print_spinner();
+ // test with sparse vectors and dense mix
+ {
+ typedef matrix<double,0,1> vector_type;
+ typedef std::map<unsigned long, double> edge_vector_type;
+ typedef std::map<unsigned long, double> node_vector_type;
+ typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
+
+ dlib::array<graph_type> samples;
+ std::vector<std::vector<bool> > labels;
+
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+ make_data2_sparse<graph_type>(samples, labels);
+
+
+ test1<node_vector_type,edge_vector_type,vector_type>(samples, labels);
+ }
+ }
+ } a;
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/gui/CMakeLists.txt b/ml/dlib/dlib/test/gui/CMakeLists.txt
new file mode 100644
index 000000000..2ab3c2b47
--- /dev/null
+++ b/ml/dlib/dlib/test/gui/CMakeLists.txt
@@ -0,0 +1,20 @@
+#
+# This is a CMake makefile. You can find the cmake utility and
+# information about it at http://www.cmake.org
+#
+
+# create a variable called target_name and set it to the string "test"
+set (target_name gui)
+
+project(${target_name})
+
+add_subdirectory(../.. dlib_build)
+
+# add all the cpp files we want to compile to this list. This tells
+# cmake that they are part of our target (which is the executable named test)
+add_executable(${target_name} main.cpp )
+
+
+# Tell cmake to link our target executable to dlib.
+target_link_libraries(${target_name} dlib::dlib )
+
diff --git a/ml/dlib/dlib/test/gui/main.cpp b/ml/dlib/dlib/test/gui/main.cpp
new file mode 100644
index 000000000..d61aba8c0
--- /dev/null
+++ b/ml/dlib/dlib/test/gui/main.cpp
@@ -0,0 +1,840 @@
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "dlib/image_io.h"
+#include "dlib/array2d.h"
+#include "dlib/gui_core.h"
+#include "dlib/assert.h"
+#include "dlib/misc_api.h"
+
+#include "dlib/image_transforms.h"
+
+#include "dlib/timer.h"
+
+#include "dlib/gui_widgets.h"
+#include "dlib/queue.h"
+
+using namespace dlib;
+using namespace std;
+
+
+typedef dlib::array2d<hsi_pixel> image;
+
+
+
+
+#include "dlib/base64.h"
+
+
+
+
+class color_box : public draggable
+{
+ unsigned char red, green,blue;
+
+public:
+ color_box (
+ drawable_window& w,
+ rectangle area,
+ unsigned char red_,
+ unsigned char green_,
+ unsigned char blue_
+ ) :
+ draggable(w, MOUSE_WHEEL),
+ red(red_),
+ green(green_),
+ blue(blue_),
+ t(*this,&color_box::action)
+ {
+ rect = area;
+
+ t.set_delay_time(4);
+ // t.start();
+
+ set_draggable_area(rectangle(10,10,500,500));
+
+ enable_events();
+ }
+
+ ~color_box()
+ {
+ disable_events();
+ }
+
+private:
+
+ void action (
+ )
+ {
+ ++red;
+ parent.invalidate_rectangle(rect);
+ }
+
+ void draw (
+ const canvas& c
+ ) const
+ {
+ if (hidden == false )
+ {
+ fill_rect(c,rect,rgb_pixel(red,green,blue));
+ std::vector<point> poly;
+ poly.push_back((rect.tl_corner()+rect.tr_corner())/2);
+ poly.push_back((rect.tr_corner()+rect.br_corner())/2);
+ poly.push_back((rect.br_corner()+rect.bl_corner())/2);
+ poly.push_back((rect.bl_corner()+rect.tl_corner())/2);
+ draw_solid_convex_polygon(c,poly,rgb_alpha_pixel(0,0,0,70));
+ }
+ }
+
+ void on_wheel_up(
+ unsigned long state
+ )
+ {
+ if (state == base_window::NONE)
+ cout << "up scroll, NONE" << endl;
+ else if (state&base_window::LEFT)
+ cout << "up scroll, LEFT" << endl;
+ else if (state&base_window::RIGHT)
+ cout << "up scroll, RIGHT" << endl;
+ else if (state&base_window::MIDDLE)
+ cout << "up scroll, MIDDLE" << endl;
+ else if (state&base_window::SHIFT)
+ cout << "up scroll, SHIFT" << endl;
+ else if (state&base_window::CONTROL)
+ cout << "up scroll, CONTROL" << endl;
+
+ }
+
+ void on_wheel_down(
+ unsigned long state
+ )
+ {
+
+ if (state == base_window::NONE)
+ cout << "down scroll, NONE" << endl;
+ else if (state&base_window::LEFT)
+ cout << "down scroll, LEFT" << endl;
+ else if (state&base_window::RIGHT)
+ cout << "down scroll, RIGHT" << endl;
+ else if (state&base_window::MIDDLE)
+ cout << "down scroll, MIDDLE" << endl;
+ else if (state&base_window::SHIFT)
+ cout << "down scroll, SHIFT" << endl;
+ else if (state&base_window::CONTROL)
+ cout << "down scroll, CONTROL" << endl;
+
+ }
+
+
+ void on_window_resized ()
+ {
+ draggable::on_window_resized();
+ }
+ timer<color_box> t;
+};
+
+
+
+
+
+
+class win : public drawable_window
+{
+
+ label lbl_last_keydown;
+ label lbl_mod_shift;
+ label lbl_mod_control;
+ label lbl_mod_alt;
+ label lbl_mod_meta;
+ label lbl_mod_caps_lock;
+ label lbl_mod_num_lock;
+ label lbl_mod_scroll_lock;
+ void on_keydown (
+ unsigned long key,
+ bool is_printable,
+ unsigned long state
+ )
+ {
+ if (is_printable)
+ lbl_last_keydown.set_text(string("last keydown: ") + (char)key);
+ else
+ lbl_last_keydown.set_text(string("last keydown: nonprintable"));
+
+ if (state&base_window::KBD_MOD_SHIFT)
+ lbl_mod_shift.set_text("shift is on");
+ else
+ lbl_mod_shift.set_text("shift is off");
+
+ if (state&base_window::KBD_MOD_CONTROL)
+ lbl_mod_control.set_text("control is on");
+ else
+ lbl_mod_control.set_text("control is off");
+
+ if (state&base_window::KBD_MOD_ALT)
+ lbl_mod_alt.set_text("alt is on");
+ else
+ lbl_mod_alt.set_text("alt is off");
+
+
+ if (state&base_window::KBD_MOD_META)
+ lbl_mod_meta.set_text("meta is on");
+ else
+ lbl_mod_meta.set_text("meta is off");
+
+ if (state&base_window::KBD_MOD_CAPS_LOCK)
+ lbl_mod_caps_lock.set_text("caps_lock is on");
+ else
+ lbl_mod_caps_lock.set_text("caps_lock is off");
+
+ if (state&base_window::KBD_MOD_NUM_LOCK)
+ lbl_mod_num_lock.set_text("num_lock is on");
+ else
+ lbl_mod_num_lock.set_text("num_lock is off");
+
+
+ if (state&base_window::KBD_MOD_SCROLL_LOCK)
+ lbl_mod_scroll_lock.set_text("scroll_lock is on");
+ else
+ lbl_mod_scroll_lock.set_text("scroll_lock is off");
+
+ drawable_window::on_keydown(key,is_printable,state);
+ }
+
+ void rb_click (
+ )
+ {
+ if (rb.is_checked())
+ rb.set_name("radio button checked");
+ else
+ rb.set_name("radio button");
+ rb.set_checked();
+ }
+
+ void cb_sb_enabled (
+ toggle_button&
+ )
+ {
+ if (sb_enabled.is_checked())
+ {
+ sb.enable();
+ lb.enable();
+ b.enable();
+ }
+ else
+ {
+ lb.disable();
+ sb.disable();
+ b.disable();
+ }
+
+ if (sb_enabled.is_checked())
+ rb.enable();
+ else
+ rb.disable();
+
+ if (sb_enabled.is_checked())
+ tabs.enable();
+ else
+ tabs.disable();
+
+ if (sb_enabled.is_checked())
+ tf.enable();
+ else
+ tf.disable();
+
+ if (sb_enabled.is_checked())
+ tb.enable();
+ else
+ tb.disable();
+
+ }
+
+ void cb_sb_shown (
+ )
+ {
+ if (sb_shown.is_checked())
+ {
+ sb.show();
+ tabs.show();
+ lb.show();
+ }
+ else
+ {
+ sb.hide();
+ tabs.hide();
+ lb.hide();
+ }
+ }
+
+
+ void tab_change (
+ unsigned long new_idx,
+ unsigned long
+ )
+ {
+ tab_label.set_text(tabs.tab_name(new_idx));
+ }
+
+ void scroll_handler (
+ )
+ {
+ ostringstream sout;
+ sout << "scroll bar pos: " << sb.slider_pos();
+ sbl.set_text(sout.str());
+ }
+
+ void scroll2_handler (
+ )
+ {
+ sb.set_length(sb2.slider_pos());
+ ostringstream sout;
+ sout << "scroll bar2 pos: " << sb2.slider_pos();
+ sbl2.set_text(sout.str());
+ scroll_handler();
+ }
+
+ void scroll3_handler (
+ )
+ {
+ sb.set_max_slider_pos(sb3.slider_pos());
+ ostringstream sout;
+ sout << "scroll bar3 pos: " << sb3.slider_pos();
+ sbl3.set_text(sout.str());
+ scroll_handler();
+ }
+
+ void lb_double_click (
+ unsigned long
+ )
+ {
+ dlib::queue<unsigned long>::kernel_2a_c sel;
+ lb.get_selected(sel);
+ sel.reset();
+ while (sel.move_next())
+ {
+ cout << lb[sel.element()] << endl;
+ }
+ //message_box("list_box",lb[idx]);
+ }
+
+ void msg_box (
+ )
+ {
+ message_box("title","you clicked the ok button!\n HURRAY!");
+ }
+
+ static void try_this_junk (
+ void* param
+ )
+ {
+ win& p = *reinterpret_cast<win*>(param);
+ put_on_clipboard(p.tf.text() + "\nfoobar");
+
+
+ }
+
+ void on_set_clipboard (
+ )
+ {
+ create_new_thread(try_this_junk,this);
+ //try_this_junk(this);
+ }
+
+ static void try_this_junk2 (
+ void*
+ )
+ {
+
+ string temp;
+ get_from_clipboard(temp);
+ message_box("clipboard",temp);
+
+ }
+ void on_get_clipboard (
+ )
+ {
+ create_new_thread(try_this_junk2,this);
+ }
+
+
+ void on_show_msg_click (
+ )
+ {
+ message_box("title","This is a test message.",*this,&win::msg_box);
+ }
+
+ void on_menu_help (
+ )
+ {
+ message_box("About","This is the messy dlib gui regression test program");
+ }
+
+public:
+
+ ~win()
+ {
+ close_window();
+ }
+
+ void cbox_clicked (
+ )
+ {
+ if (cbox.is_checked())
+ cbl.set_text(cbox.name() + " box is checked");
+ else
+ cbl.set_text("box NOT is checked");
+ }
+
+ win (
+ ):
+ drawable_window(true),
+ lbl_last_keydown(*this),
+ lbl_mod_shift(*this),
+ lbl_mod_control(*this),
+ lbl_mod_alt(*this),
+ lbl_mod_meta(*this),
+ lbl_mod_caps_lock(*this),
+ lbl_mod_num_lock(*this),
+ lbl_mod_scroll_lock(*this),
+ b(*this),
+ btn_count(*this),
+ btn_get_clipboard(*this),
+ btn_set_clipboard(*this),
+ btn_show_message(*this),
+ cb1(*this,rectangle(100,100,200,200),255,0,0),
+ cb2(*this,rectangle(150,150,250,240),0,255,0),
+ cbl(*this),
+ cbox(*this),
+ group1(*this),
+ group2(*this),
+ group3(*this),
+ keyboard_count(1),
+ keydown(*this),
+ keyup(*this),
+ l1(*this),
+ l2(*this),
+ l3(*this),
+ lb(*this),
+ leave_count(*this),
+ left_down(*this),
+ left_up(*this),
+ middle_down(*this),
+ middle_up(*this),
+ mouse_state(*this),
+ mt(*this),
+ nrect(*this),
+ pos(*this),
+ rb(*this),
+ right_down(*this),
+ right_up(*this),
+ sb2(*this,scroll_bar::VERTICAL),
+ sb3(*this,scroll_bar::VERTICAL),
+ sb_enabled(*this),
+ sbl2(*this),
+ sbl3(*this),
+ sbl(*this),
+ sb_shown(*this),
+ sb(*this,scroll_bar::HORIZONTAL),
+ scroll(*this),
+ tab_label(*this),
+ tabs(*this),
+ tf(*this),
+ tb(*this),
+ mbar(*this)
+ {
+ bool use_bdf_fonts = false;
+
+ std::shared_ptr<bdf_font> f(new bdf_font);
+
+ if (use_bdf_fonts)
+ {
+
+ ifstream fin("/home/davis/source/10x20.bdf");
+ f->read_bdf_file(fin,0xFFFF);
+
+ mt.set_main_font(f);
+ }
+ //mt.hide();
+ mt.set_pos(5,200);
+
+
+ lbl_last_keydown.set_text("?");
+ lbl_mod_shift.set_text("?");
+ lbl_mod_control.set_text("?");
+ lbl_mod_alt.set_text("?");
+ lbl_mod_meta.set_text("?");
+ lbl_mod_caps_lock.set_text("?");
+ lbl_mod_num_lock.set_text("?");
+ lbl_mod_scroll_lock.set_text("?");
+
+ lbl_last_keydown.set_pos(20,420);
+ lbl_mod_shift.set_pos(20,lbl_last_keydown.bottom()+5);
+ lbl_mod_control.set_pos(20,lbl_mod_shift.bottom()+5);
+ lbl_mod_alt.set_pos(20,lbl_mod_control.bottom()+5);
+ lbl_mod_meta.set_pos(20,lbl_mod_alt.bottom()+5);
+ lbl_mod_caps_lock.set_pos(20,lbl_mod_meta.bottom()+5);
+ lbl_mod_num_lock.set_pos(20,lbl_mod_caps_lock.bottom()+5);
+ lbl_mod_scroll_lock.set_pos(20,lbl_mod_num_lock.bottom()+5);
+
+ lb.set_pos(580,200);
+ lb.set_size(200,300);
+ if (use_bdf_fonts)
+ lb.set_main_font(f);
+
+ dlib::queue<string>::kernel_2a_c qos;
+ string a;
+ a = "Davis"; qos.enqueue(a);
+ a = "king"; qos.enqueue(a);
+ a = "one"; qos.enqueue(a);
+ a = "two"; qos.enqueue(a);
+ a = "three"; qos.enqueue(a);
+ a = "yo yo yo alsdkjf asfj lsa jfsf\n this is a long phrase"; qos.enqueue(a);
+ a = "four"; qos.enqueue(a);
+ a = "five"; qos.enqueue(a);
+ a = "six"; qos.enqueue(a);
+ a = "seven"; qos.enqueue(a);
+ a = "eight"; qos.enqueue(a);
+ a = "nine"; qos.enqueue(a);
+ a = "ten"; qos.enqueue(a);
+ a = "eleven"; qos.enqueue(a);
+ a = "twelve"; qos.enqueue(a);
+ for (int i = 0; i < 1000; ++i)
+ {
+ a = "thirteen"; qos.enqueue(a);
+ }
+ lb.load(qos);
+ lb.select(1);
+ lb.select(2);
+ lb.select(3);
+ lb.select(5);
+ lb.enable_multiple_select();
+ lb.set_double_click_handler(*this,&win::lb_double_click);
+ // lb.disable_multiple_select();
+
+ btn_show_message.set_pos(50,350);
+ btn_show_message.set_name("message_box()");
+ mbar.set_number_of_menus(2);
+ mbar.set_menu_name(0,"File",'F');
+ mbar.set_menu_name(1,"Help",'H');
+ mbar.menu(0).add_menu_item(menu_item_text("show msg click",*this,&win::on_show_msg_click,'s'));
+ mbar.menu(0).add_menu_item(menu_item_text("get clipboard",*this,&win::on_get_clipboard,'g'));
+ mbar.menu(0).add_menu_item(menu_item_text("set clipboard",*this,&win::on_set_clipboard,'c'));
+ mbar.menu(0).add_menu_item(menu_item_separator());
+ mbar.menu(0).add_submenu(menu_item_submenu("submenu",'m'), submenu);
+ submenu.add_menu_item(menu_item_separator());
+ submenu.add_menu_item(menu_item_separator());
+ submenu.add_menu_item(menu_item_text("show msg click",*this,&win::on_show_msg_click,'s'));
+ submenu.add_menu_item(menu_item_text("get clipboard",*this,&win::on_get_clipboard,'g'));
+ submenu.add_menu_item(menu_item_text("set clipboard",*this,&win::on_set_clipboard,'c'));
+ submenu.add_menu_item(menu_item_separator());
+ submenu.add_menu_item(menu_item_separator());
+ mbar.menu(1).add_menu_item(menu_item_text("About",*this,&win::on_menu_help,'A'));
+
+ btn_show_message.set_click_handler(*this,&win::on_show_msg_click);
+ btn_get_clipboard.set_pos(btn_show_message.right()+5,btn_show_message.top());
+ btn_get_clipboard.set_name("get_from_clipboard()");
+ btn_get_clipboard.set_click_handler(*this,&win::on_get_clipboard);
+
+ btn_get_clipboard.set_style(button_style_toolbar1());
+ btn_set_clipboard.set_pos(btn_get_clipboard.right()+5,btn_get_clipboard.top());
+ btn_set_clipboard.set_name("put_on_clipboard()");
+ btn_set_clipboard.set_click_handler(*this,&win::on_set_clipboard);
+
+ nrect.set_size(700,500);
+ nrect.set_name("test widgets");
+ nrect.set_pos(2,mbar.bottom()+2);
+
+ //throw dlib::error("holy crap batman");
+ tab_label.set_pos(10,440);
+
+ tabs.set_click_handler(*this,&win::tab_change);
+ tabs.set_pos(5,mbar.bottom()+10);
+ tabs.set_size(280,100);
+ tabs.set_number_of_tabs(3);
+ tabs.set_tab_name(0,"davis");
+ tabs.set_tab_name(1,"edward");
+ tabs.set_tab_name(2,"king alsklsdkfj asfd");
+ tabs.set_tab_group(0,group1);
+ tabs.set_tab_group(1,group2);
+ tabs.set_tab_group(2,group3);
+
+ l1.set_text("group one");
+ l2.set_text("group two");
+ l3.set_text("group three");
+
+ group1.add(l1,0,0);
+ group2.add(l2,20,10);
+ group3.add(l3,0,0);
+
+
+
+ sb_enabled.set_name("enabled");
+ sb_shown.set_name("shown");
+ sb_shown.set_checked();
+ sb_enabled.set_checked();
+ sb_shown.set_click_handler(*this,&win::cb_sb_shown);
+ sb_enabled.set_click_handler(*this,&win::cb_sb_enabled);
+
+ sb_shown.set_tooltip_text("I'm a checkbox");
+
+ rb.set_click_handler(*this,&win::rb_click);
+
+
+ sb3.set_pos(440,mbar.bottom()+10);
+ sb3.set_max_slider_pos(300);
+ sb3.set_slider_pos(150);
+ sb3.set_length(300);
+ sb2.set_pos(470,mbar.bottom()+10);
+ sb2.set_max_slider_pos(300);
+ sb2.set_length(300);
+ sb.set_pos(500,mbar.bottom()+10);
+ sb.set_max_slider_pos(30);
+ sb.set_length(300);
+
+
+ sb.set_scroll_handler(*this,&win::scroll_handler);
+ sb2.set_scroll_handler(*this,&win::scroll2_handler);
+ sb3.set_scroll_handler(*this,&win::scroll3_handler);
+ sbl.set_pos(540,mbar.bottom()+20);
+ sbl2.set_pos(540,mbar.bottom()+40);
+ sbl3.set_pos(540,mbar.bottom()+60);
+
+ cbox.set_pos(300,mbar.bottom()+30);
+ cbox.set_name("davis king");
+ cbox.set_click_handler(*this,&win::cbox_clicked);
+
+ cbl.set_pos(300,cbox.get_rect().bottom()+1);
+ cbox.set_checked();
+ sb_enabled.set_pos(cbox.get_rect().left(),cbox.get_rect().bottom()+20);
+ sb_shown.set_pos(sb_enabled.get_rect().left(),sb_enabled.get_rect().bottom()+2);
+
+
+
+ if (use_bdf_fonts)
+ rb.set_main_font(f);
+ rb.set_name("radio button");
+ rb.set_pos(sb_shown.get_rect().left(),sb_shown.get_rect().bottom()+2);
+
+
+ cb1.set_z_order(10);
+ cb2.set_z_order(20);
+
+ pos.set_pos(50,50);
+ left_up.set_pos(50,70);
+ left_down.set_pos(50,90);
+ middle_up.set_pos(50,110);
+ middle_down.set_pos(50,130);
+ right_up.set_pos(50,150);
+ right_down.set_pos(50,170);
+
+ mouse_state.set_pos(50,190);
+
+ leave_count.set_pos(50,210);
+
+ scroll_count = 0;
+ scroll.set_pos(50,230);
+
+ btn_count.set_pos(50,250);
+
+
+ keydown.set_pos(50,270);
+ keyup.set_pos(50,290);
+
+ tf.set_pos(50,310);
+ tf.set_text("Davis685g@");
+ tf.set_width(500);
+ tf.set_text_color(rgb_pixel(255,0,0));
+ tf.set_enter_key_handler(*this,&win::on_enter_key);
+ tf.set_focus_lost_handler(*this,&win::on_tf_focus_lost);
+
+ tb.set_pos(250,400);
+ tb.set_text("initial test\nstring");
+ tb.set_size(300,300);
+ tb.set_text_color(rgb_pixel(255,0,0));
+ tb.set_enter_key_handler(*this,&win::on_enter_key);
+ tb.set_focus_lost_handler(*this,&win::on_tf_focus_lost);
+
+
+ button_count = 0;
+ count = 0;
+ b.set_name("button");
+ b.set_pos(540,100);
+ b.set_click_handler(*this,&win::on_click);
+ b.set_tooltip_text("hurray i'm a button!");
+ if (use_bdf_fonts)
+ b.set_main_font(f);
+
+
+ set_size(815,730);
+
+ nrect.wrap_around(
+ cbox.get_rect() +
+ rb.get_rect() +
+ sb_enabled.get_rect() +
+ sb_shown.get_rect());
+
+ flip = 0;
+ open_file_box(*this,&win::on_open_file);
+ open_existing_file_box(*this,&win::on_open_file);
+ save_file_box(*this,&win::on_open_file);
+
+ if (use_bdf_fonts)
+ {
+ tf.set_main_font(f);
+ tb.set_main_font(f);
+ }
+ if (use_bdf_fonts)
+ tabs.set_main_font(f);
+
+ }
+
+private:
+
+
+ void on_enter_key()
+ {
+ cout << "enter key pressed" << endl;
+ }
+
+ void on_tf_focus_lost()
+ {
+ cout << "text field/box lost focus" << endl;
+ }
+
+
+ void on_open_file (const std::string& file)
+ {
+ message_box("file opened",file);
+ }
+
+
+
+
+ void on_click (
+ )
+ {
+ ostringstream sout;
+ sout << "text field: " << tf.text();
+ ++button_count;
+ btn_count.set_text(sout.str());
+
+ if (flip == 0)
+ {
+ flip = 1;
+ lb.set_size(200,200);
+ }
+ else if (flip == 1)
+ {
+ flip = 2;
+ lb.set_size(150,200);
+ }
+ else if (flip == 2)
+ {
+ flip = 3;
+ lb.set_size(150,300);
+ }
+ else
+ {
+ flip = 0;
+ lb.set_size(200,300);
+ }
+ }
+
+
+ button b;
+ label btn_count;
+ button btn_get_clipboard;
+ button btn_set_clipboard;
+ button btn_show_message;
+ int button_count;
+ color_box cb1;
+ color_box cb2;
+ label cbl;
+ check_box cbox;
+ int count;
+ int flip;
+ widget_group group1;
+ widget_group group2;
+ widget_group group3;
+ int keyboard_count;
+ label keydown;
+ label keyup;
+ label l1;
+ label l2;
+ label l3;
+ list_box lb;
+ label leave_count;
+ label left_down;
+ label left_up;
+ label middle_down;
+ label middle_up;
+ label mouse_state;
+ mouse_tracker mt;
+ named_rectangle nrect;
+ label pos;
+ radio_button rb;
+ label right_down;
+ label right_up;
+ scroll_bar sb2;
+ scroll_bar sb3;
+ check_box sb_enabled;
+ label sbl2;
+ label sbl3;
+ label sbl;
+ check_box sb_shown;
+ scroll_bar sb;
+ int scroll_count;
+ label scroll;
+ label tab_label;
+ tabbed_display tabs;
+ text_field tf;
+ text_box tb;
+ menu_bar mbar;
+ popup_menu submenu;
+
+};
+
+
+win w;
+
+int main()
+{
+
+ try
+ {
+
+ image_window win;
+
+ array2d<unsigned char> img;
+ img.set_size(100,100);
+ assign_all_pixels(img,0);
+
+ fill_rect(img, rectangle(1,1,1,1), 255);
+ fill_rect(img, rectangle(1,3,2,5), 255);
+ fill_rect(img, rectangle(4,3,5,4), 255);
+ fill_rect(img, rectangle(9,9,13,10), 255);
+
+ win.set_image(img);
+
+ win.add_overlay(image_display::overlay_rect(rectangle(1,1,1,1), rgb_pixel(255,0,0)));
+ win.add_overlay(image_display::overlay_rect(rectangle(1,3,2,5), rgb_pixel(255,0,0)));
+ win.add_overlay(image_display::overlay_rect(rectangle(4,3,5,4), rgb_pixel(255,0,0)));
+ win.add_overlay(image_display::overlay_rect(rectangle(9,9,13,10), rgb_pixel(255,0,0)));
+
+
+
+ w.set_pos (100,200);
+ w.set_title("test window");
+ w.show();
+
+ w.wait_until_closed();
+ }
+ catch (exception& e)
+ {
+ cout << e.what() << endl;
+ }
+
+}
diff --git a/ml/dlib/dlib/test/hash.cpp b/ml/dlib/dlib/test/hash.cpp
new file mode 100644
index 000000000..94930e629
--- /dev/null
+++ b/ml/dlib/dlib/test/hash.cpp
@@ -0,0 +1,369 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/hash.h>
+#include <dlib/rand.h>
+#include <dlib/matrix.h>
+#include <dlib/byte_orderer.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.hash");
+
+
+ template <typename T>
+ void to_little (
+ std::vector<T>& item
+ )
+ {
+ byte_orderer bo;
+ for (unsigned long i = 0; i < item.size(); ++i)
+ bo.host_to_little(item[i]);
+ }
+
+
+ template <typename T>
+ void to_little (
+ matrix<T>& item
+ )
+ {
+ byte_orderer bo;
+ for (long r = 0; r < item.nr(); ++r)
+ {
+ for (long c = 0; c < item.nc(); ++c)
+ {
+ bo.host_to_little(item(r,c));
+ }
+ }
+ }
+
+ // Run the official test for MurmurHash3
+ void murmur_hash_test()
+ {
+ uint8 key[256];
+ uint32 hashes[256];
+ uint32 final = 0;
+
+ memset(key,0,sizeof(key));
+ memset(hashes,0,sizeof(hashes));
+
+ // Hash keys of the form {0}, {0,1}, {0,1,2}... up to N=255,using 256-N as
+ // the seed.
+ for(int i = 0; i < 256; i++)
+ {
+ key[i] = (uint8)i;
+
+ hashes[i] = murmur_hash3(key,i,256-i);
+ }
+
+ byte_orderer bo;
+ bo.host_to_little(hashes);
+ final = murmur_hash3(hashes,sizeof(hashes),0);
+
+ // using ostringstream to avoid compiler error in visual studio 2005
+ ostringstream sout;
+ sout << hex << final;
+ dlog << LINFO << "final: "<< sout.str();
+ DLIB_TEST(final == 0xB0F57EE3);
+ }
+
+ void murmur_hash_128_test()
+ {
+ uint8 key[256];
+ uint64 hashes[256*2];
+ uint32 final = 0;
+
+ memset(key,0,sizeof(key));
+ memset(hashes,0,sizeof(hashes));
+
+ // Hash keys of the form {0}, {0,1}, {0,1,2}... up to N=255,using 256-N as
+ // the seed.
+ for(int i = 0; i < 256; i++)
+ {
+ key[i] = (uint8)i;
+
+ const std::pair<uint64,uint64> temp = murmur_hash3_128bit(key,i,256-i);
+ hashes[2*i] = temp.first;
+ hashes[2*i+1] = temp.second;
+ }
+
+ byte_orderer bo;
+ bo.host_to_little(hashes);
+ final = static_cast<uint32>(murmur_hash3_128bit(hashes,sizeof(hashes),0).first);
+
+ // using ostringstream to avoid compiler error in visual studio 2005
+ ostringstream sout;
+ sout << hex << final;
+ dlog << LINFO << "final 64: "<< sout.str();
+ DLIB_TEST(final == 0x6384BA69);
+ }
+
+ void test_murmur_hash_128_4()
+ {
+ byte_orderer bo;
+ dlib::rand rnd;
+ for (int i = 0; i < 100; ++i)
+ {
+ uint32 buf[4] = { rnd.get_random_32bit_number(),
+ rnd.get_random_32bit_number(),
+ rnd.get_random_32bit_number(),
+ rnd.get_random_32bit_number()
+ };
+
+ bo.host_to_little(buf);
+
+ std::pair<uint64,uint64> temp1, temp2;
+
+ // Make sure the 4 integer version of murmur hash does the same thing
+ // as the memory block version.
+ temp1 = murmur_hash3_128bit(buf, sizeof(buf), 0);
+ temp2 = murmur_hash3_128bit(buf[0], buf[1], buf[2], buf[3]);
+ DLIB_TEST( temp1.first == temp2.first);
+ DLIB_TEST( temp1.second == temp2.second);
+ }
+ }
+
+ void test_murmur_hash_128_3()
+ {
+ byte_orderer bo;
+ dlib::rand rnd;
+ for (int i = 0; i < 100; ++i)
+ {
+ uint64 buf[2] = { rnd.get_random_64bit_number(),
+ rnd.get_random_64bit_number(),
+ };
+
+ const uint32 seed = rnd.get_random_32bit_number();
+
+ bo.host_to_little(buf);
+ std::pair<uint64,uint64> temp1, temp2;
+
+ // Make sure the 3 integer version of murmur hash does the same thing
+ // as the memory block version.
+ temp1 = murmur_hash3_128bit(buf, sizeof(buf), seed);
+ temp2 = murmur_hash3_128bit_3(buf[0], buf[1], seed);
+ DLIB_TEST( temp1.first == temp2.first);
+ DLIB_TEST( temp1.second == temp2.second);
+ }
+ }
+
+ void test_murmur_hash_64_2()
+ {
+ byte_orderer bo;
+ dlib::rand rnd;
+ for (int i = 0; i < 100; ++i)
+ {
+ uint32 val = rnd.get_random_32bit_number();
+ const uint32 seed = rnd.get_random_32bit_number();
+
+
+ bo.host_to_little(val);
+ uint32 temp1, temp2;
+
+ // Make sure the 2 integer version of murmur hash does the same thing
+ // as the memory block version.
+ temp1 = murmur_hash3(&val, sizeof(val), seed);
+ temp2 = murmur_hash3_2(val, seed);
+ DLIB_TEST(temp1 == temp2);
+ }
+ }
+
+ void test_murmur_hash_64_3()
+ {
+ byte_orderer bo;
+ dlib::rand rnd;
+ for (int i = 0; i < 100; ++i)
+ {
+ uint32 buf[2] = {rnd.get_random_32bit_number(),
+ rnd.get_random_32bit_number()};
+ const uint32 seed = rnd.get_random_32bit_number();
+
+
+ bo.host_to_little(buf);
+ uint32 temp1, temp2;
+
+ // Make sure the 2 integer version of murmur hash does the same thing
+ // as the memory block version.
+ temp1 = murmur_hash3(&buf, sizeof(buf), seed);
+ temp2 = murmur_hash3_3(buf[0], buf[1], seed);
+ DLIB_TEST(temp1 == temp2);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 slow_count_bits ( uint64 v)
+ {
+ uint64 count = 0;
+ for (int i = 0; i < 64; ++i)
+ {
+ if (v&1)
+ ++count;
+ v >>= 1;
+ }
+ return count;
+ }
+
+
+ uint32 slow_count_bits ( uint32 v)
+ {
+ uint32 count = 0;
+ for (int i = 0; i < 32; ++i)
+ {
+ if (v&1)
+ ++count;
+ v >>= 1;
+ }
+ return count;
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ void test_hamming_stuff()
+ {
+ dlib::rand rnd;
+ for (int i = 0; i < 10000; ++i)
+ {
+ uint32 v = rnd.get_random_32bit_number();
+ uint64 v2 = rnd.get_random_64bit_number();
+ DLIB_TEST(slow_count_bits(v) == count_bits(v));
+ DLIB_TEST(slow_count_bits(v2) == count_bits(v2));
+ }
+
+ DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x0F) == 1);
+ DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x1F) == 0);
+ DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x19) == 2);
+ DLIB_TEST(hamming_distance((uint32)0x2F, (uint32)0x19) == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_hash : public tester
+ {
+ public:
+ test_hash (
+ ) :
+ tester ("test_hash",
+ "Runs tests on the hash routines.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+
+ test_hamming_stuff();
+
+ murmur_hash_test();
+ murmur_hash_128_test();
+
+ std::string str1 = "some random string";
+ matrix<unsigned char> mat(2,2);
+
+ mat = 1,2,3,4;
+
+ matrix<uint64> mat2(2,3);
+
+ mat2 = 1,2,3,4,5,6;
+
+ to_little(mat2);
+
+ std::vector<unsigned char> v(4);
+ v[0] = 'c';
+ v[1] = 'a';
+ v[2] = 't';
+ v[3] = '!';
+
+ std::vector<uint16> v2(4);
+ v[0] = 'c';
+ v[1] = 'a';
+ v[2] = 't';
+ v[3] = '!';
+ to_little(v2);
+
+ std::map<unsigned char, unsigned char> m;
+ m['c'] = 'C';
+ m['a'] = 'A';
+ m['t'] = 'T';
+
+ dlog << LINFO << "hash(str1): "<< dlib::hash(str1);
+ dlog << LINFO << "hash(v): "<< dlib::hash(v);
+ dlog << LINFO << "hash(v2): "<< dlib::hash(v2);
+ dlog << LINFO << "hash(m): "<< dlib::hash(m);
+ dlog << LINFO << "hash(mat): "<< dlib::hash(mat);
+ dlog << LINFO << "hash(mat2): "<< dlib::hash(mat2);
+
+ uint32 ui1 = 123485393;
+ uint64 ui2 = ui1;
+ ui2 *= ui2;
+ ui2 *= ui2;
+ dlog << LINFO << "hash(ui1): "<< dlib::hash(ui1);
+ dlog << LINFO << "hash(ui2): "<< dlib::hash(ui2);
+ dlog << LINFO << "hash(make_pair(ui2,ui1)): "<< dlib::hash(make_pair(ui2,ui1));
+ dlog << LINFO << "hash(make_pair(ui2,ui2)): "<< dlib::hash(make_pair(ui2,ui2));
+ dlog << LINFO << "hash(make_pair(ui1,ui1)): "<< dlib::hash(make_pair(ui1,ui1));
+ dlog << LINFO << "hash(ui1,3): "<< dlib::hash(ui1,3);
+ dlog << LINFO << "hash(ui2,3): "<< dlib::hash(ui2,3);
+ dlog << LINFO << "hash(make_pair(ui2,ui1),3): "<< dlib::hash(make_pair(ui2,ui1),3);
+ dlog << LINFO << "hash(make_pair(ui2,ui2),3): "<< dlib::hash(make_pair(ui2,ui2),3);
+ dlog << LINFO << "hash(make_pair(ui1,ui1),3): "<< dlib::hash(make_pair(ui1,ui1),3);
+
+ DLIB_TEST(dlib::hash(ui1) == 0x63e272e4);
+ DLIB_TEST(dlib::hash(ui2) == 0xaf55561a);
+ DLIB_TEST(dlib::hash(make_pair(ui2,ui1)) == 0x52685376);
+ DLIB_TEST(dlib::hash(make_pair(ui2,ui2)) == 0xd25d6929);
+ DLIB_TEST(dlib::hash(make_pair(ui1,ui1)) == 0xeea3b63e);
+ DLIB_TEST(dlib::hash(ui1,3) == 0x95d1c4c0);
+ DLIB_TEST(dlib::hash(ui2,3) == 0x6ada728d);
+ DLIB_TEST(dlib::hash(make_pair(ui2,ui1),3) == 0x2f72a0ff);
+ DLIB_TEST(dlib::hash(make_pair(ui2,ui2),3) == 0xac1407f0);
+ DLIB_TEST(dlib::hash(make_pair(ui1,ui1),3) == 0x39ad637a);
+
+
+ DLIB_TEST(dlib::hash(str1) == 0x3ffe6bf6);
+ DLIB_TEST(dlib::hash(v) == 0xf1af2ca6);
+ DLIB_TEST(dlib::hash(v2) == 0x63852afc);
+ DLIB_TEST(dlib::hash(m) == 0xaacc3f6f);
+ DLIB_TEST(dlib::hash(mat) == 0x3e349da5);
+ DLIB_TEST(dlib::hash(mat2) == 0x3a95dc52);
+ DLIB_TEST(murmur_hash3(&str1[0], str1.size(), 0) == 0x3ffe6bf6);
+
+ dlog << LINFO << "hash(str1,1): "<< dlib::hash(str1,1);
+ dlog << LINFO << "hash(v,3): "<< dlib::hash(v,3);
+ dlog << LINFO << "hash(v2,3): "<< dlib::hash(v2,3);
+ dlog << LINFO << "hash(m,4): "<< dlib::hash(m,4);
+ dlog << LINFO << "hash(mat,5): "<< dlib::hash(mat,5);
+ dlog << LINFO << "hash(mat2,6): "<< dlib::hash(mat2,6);
+
+ DLIB_TEST(dlib::hash(str1,1) == 0xb17cea93);
+ DLIB_TEST(dlib::hash(v,3) == 0x7ec9284c);
+ DLIB_TEST(dlib::hash(v2,3) == 0xb2ce147f);
+ DLIB_TEST(dlib::hash(m,4) == 0xfa5e7ac2);
+ DLIB_TEST(dlib::hash(mat,5) == 0x8de27259);
+ DLIB_TEST(dlib::hash(mat2,6) == 0xb8aa7714);
+ DLIB_TEST(murmur_hash3(&str1[0], str1.size(), 1) == 0xb17cea93);
+
+ test_murmur_hash_128_4();
+ test_murmur_hash_128_3();
+ test_murmur_hash_64_2();
+ test_murmur_hash_64_3();
+ }
+ } a;
+
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/hash_map.cpp b/ml/dlib/dlib/test/hash_map.cpp
new file mode 100644
index 000000000..09af09936
--- /dev/null
+++ b/ml/dlib/dlib/test/hash_map.cpp
@@ -0,0 +1,450 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/hash_map.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.hash_map");
+
+ template <
+ typename hash_map
+ >
+ void hash_map_kernel_test (
+ )
+ /*!
+ requires
+ - hash_map is an implementation of hash_map/hash_map_kernel_abstract.h and
+ is instantiated to map int to int
+ ensures
+ - runs tests on hash_map for compliance with the specs
+ !*/
+ {
+
+ srand(static_cast<unsigned int>(time(0)));
+
+ print_spinner();
+
+
+ hash_map test, test2;
+
+ enumerable<map_pair<int,int> >& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ for (int j = 0; j < 4; ++j)
+ {
+ print_spinner();
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+
+
+ int a,b;
+ a = 8;
+ b = 94;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(test.is_in_domain(8) == true);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test[8] == 94);
+ a = 53;
+ b = 4;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_in_domain(53) == true);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test[53] == 4);
+
+
+ swap(test,test2);
+
+
+ DLIB_TEST_MSG(test2.size() == 2,test2.size());
+ DLIB_TEST(test2.is_in_domain(8) == true);
+ DLIB_TEST(test2.is_in_domain(5) == false);
+ DLIB_TEST(test2.is_in_domain(0) == false);
+ DLIB_TEST(test2.is_in_domain(-999) == false);
+ DLIB_TEST(test2.is_in_domain(4999) == false);
+ DLIB_TEST(test2[8] == 94);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(test2.is_in_domain(53) == true);
+ DLIB_TEST(test2.is_in_domain(5) == false);
+ DLIB_TEST(test2.is_in_domain(0) == false);
+ DLIB_TEST(test2.is_in_domain(-999) == false);
+ DLIB_TEST(test2.is_in_domain(4999) == false);
+ DLIB_TEST(test2[53] == 4);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(8) == false);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(53) == false);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+
+
+ test.clear();
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ int count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+
+
+
+ ++count;
+ }
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(count == 10000);
+
+ test.swap(test2);
+
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test2.size() == 10000);
+ count = 0;
+ test2.reset();
+
+ test2.move_next();
+ test2.element().value() = 99;
+ DLIB_TEST(test2[test2.element().key()] == 99);
+ DLIB_TEST(test2.element().value() == 99);
+
+ test2.reset();
+
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2[test2.element().key()] == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+
+ ++count;
+ }
+ DLIB_TEST(test2.size() == 10000);
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+ test2.clear();
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.at_start() == true);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+
+ {
+ int* array1 = new int[test.size()];
+ int* array2 = new int[test.size()];
+
+ int* tmp1 = array1;
+ int* tmp2 = array2;
+
+
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ DLIB_TEST(test.at_start() == true);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+ DLIB_TEST(test.at_start() == true);
+
+
+ count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.current_element_valid() == true);
+ *tmp1 = test.element().key();
+ *tmp2 = test.element().value();
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+
+ tmp1 = array1;
+ tmp2 = array2;
+ for (int i = 0; i < 20000; ++i)
+ {
+ DLIB_TEST(test.is_in_domain(*tmp1) == true);
+ DLIB_TEST(test[*tmp1] == *tmp2);
+ ++tmp1;
+ ++tmp2;
+ }
+
+ DLIB_TEST(test.size() == 20000);
+
+ tmp1 = array1;
+ tmp2 = array2;
+ count = 0;
+ while (test.size() > 10000)
+ {
+ test.remove(*tmp1,a,b);
+ DLIB_TEST(*tmp1 == a);
+ DLIB_TEST(*tmp2 == b);
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().value() == *tmp2);
+ DLIB_TEST(test.element().value() == *tmp2);
+ DLIB_TEST(test.element().value() == *tmp2);
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ test2.swap(test);
+
+ count = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+
+ ++count;
+ }
+
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test2.size() == 20000);
+
+ int c = 0;
+ while (test2.size()>0)
+ {
+ test2.remove_any(b,c);
+
+ }
+
+ DLIB_TEST(test2.size() == 0);
+ delete [] array1;
+ delete [] array2;
+ }
+
+ test.clear();
+ test2.clear();
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ count = 0;
+ while (test.move_next())
+ {
+
+ DLIB_TEST(test[test.element().key()] == test.element().value());
+
+ ++count;
+ if (count == 5000)
+ break;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ test.reset();
+
+ count = 0;
+
+ while (test.move_next())
+ {
+
+ ++count;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ DLIB_TEST(count == 10000);
+
+
+ test.clear();
+ test2.clear();
+ }
+
+
+
+
+ {
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+ int a = 5;
+ int b = 6;
+ test.add(a,b);
+ a = 7;
+ b = 8;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test[7] == 8);
+ DLIB_TEST(test[5] == 6);
+ DLIB_TEST(test.is_in_domain(7));
+ DLIB_TEST(test.is_in_domain(5));
+ test.destroy(7);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(!test.is_in_domain(7));
+ DLIB_TEST(test.is_in_domain(5));
+ test.destroy(5);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(!test.is_in_domain(7));
+ DLIB_TEST(!test.is_in_domain(5));
+ }
+
+
+
+ }
+
+
+
+
+
+ class hash_map_tester : public tester
+ {
+ public:
+ hash_map_tester (
+ ) :
+ tester ("test_hash_map",
+ "Runs tests on the hash_map component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1a>();
+
+ dlog << LINFO << "testing kernel_1b_c";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1a_c>();
+
+ dlog << LINFO << "testing kernel_1b";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1b>();
+
+ dlog << LINFO << "testing kernel_1a_c";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1b_c>();
+
+ dlog << LINFO << "testing kernel_1c";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1c>();
+
+ dlog << LINFO << "testing kernel_1c_c";
+ hash_map_kernel_test<hash_map<int,int,14>::kernel_1c_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/hash_set.cpp b/ml/dlib/dlib/test/hash_set.cpp
new file mode 100644
index 000000000..02b665bdb
--- /dev/null
+++ b/ml/dlib/dlib/test/hash_set.cpp
@@ -0,0 +1,387 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/hash_set.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.hash_set");
+
+ template <
+ typename hash_set
+ >
+ void hash_set_kernel_test (
+ )
+ /*!
+ requires
+ - hash_set is an implementation of hash_set/hash_set_kernel_abstract.h and
+ is instantiated with int
+ ensures
+ - runs tests on hash_set for compliance with the specs
+ !*/
+ {
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+ print_spinner();
+
+ hash_set test, test2;
+
+
+ enumerable<const int>& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+
+ for (int j = 0; j < 4; ++j)
+ {
+ print_spinner();
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ int a,b = 0;
+ a = 8;
+ test.add(a);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(test.is_member(8) == true);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+ a = 53;
+ test.add(a);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_member(53) == true);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ swap(test,test2);
+
+
+
+ DLIB_TEST(test2.is_member(8) == true);
+ DLIB_TEST(test2.is_member(5) == false);
+ DLIB_TEST(test2.is_member(0) == false);
+ DLIB_TEST(test2.is_member(-999) == false);
+ DLIB_TEST(test2.is_member(4999) == false);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(test2.is_member(53) == true);
+ DLIB_TEST(test2.is_member(5) == false);
+ DLIB_TEST(test2.is_member(0) == false);
+ DLIB_TEST(test2.is_member(-999) == false);
+ DLIB_TEST(test2.is_member(4999) == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(8) == false);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(53) == false);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ test.clear();
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ int count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+
+
+ ++count;
+ }
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(count == 10000);
+
+ test.swap(test2);
+
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test2.size() == 10000);
+ count = 0;
+ test2.reset();
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+
+ ++count;
+ }
+ DLIB_TEST(test2.size() == 10000);
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+ test2.clear();
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.at_start() == true);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+
+ {
+ int* array = new int[test.size()];
+ int* tmp = array;
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ DLIB_TEST(test.at_start() == true);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+
+
+
+ count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ *tmp = test.element();
+ ++tmp;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+
+ tmp = array;
+ for (int i = 0; i < 20000; ++i)
+ {
+ DLIB_TEST(test.is_member(*tmp) == true);
+ ++tmp;
+ }
+
+ DLIB_TEST(test.size() == 20000);
+
+ tmp = array;
+ count = 0;
+ while (test.size() > 10000)
+ {
+ test.remove(*tmp,a);
+ DLIB_TEST(*tmp == a);
+ ++tmp;
+ ++count;
+ }
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.move_next())
+ {
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ test2.swap(test);
+
+ count = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+
+ ++count;
+ }
+
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test2.size() == 20000);
+
+
+ while (test2.size()>0)
+ {
+ test2.remove_any(b);
+ }
+
+ DLIB_TEST(test2.size() == 0);
+ delete [] array;
+ }
+
+ test.clear();
+ test2.clear();
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ if (count == 5000)
+ break;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ test.reset();
+
+ count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ DLIB_TEST(count == 10000);
+
+
+ test.clear();
+ test2.clear();
+ }
+
+
+ {
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+ int a = 5;
+ test.add(a);
+ a = 7;
+ test.add(a);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_member(7));
+ DLIB_TEST(test.is_member(5));
+ test.destroy(7);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(!test.is_member(7));
+ DLIB_TEST(test.is_member(5));
+ test.destroy(5);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(!test.is_member(7));
+ DLIB_TEST(!test.is_member(5));
+ }
+
+ }
+
+
+
+
+ class hash_set_tester : public tester
+ {
+ public:
+ hash_set_tester (
+ ) :
+ tester ("test_hash_set",
+ "Runs tests on the hash_set component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1a>();
+ dlog << LINFO << "testing kernel_1a_c";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1a_c>();
+ dlog << LINFO << "testing kernel_1b";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1b>();
+ dlog << LINFO << "testing kernel_1b_c";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1b_c>();
+ dlog << LINFO << "testing kernel_1c";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1c>();
+ dlog << LINFO << "testing kernel_1c_c";
+ hash_set_kernel_test<hash_set<int,14>::kernel_1c_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/hash_table.cpp b/ml/dlib/dlib/test/hash_table.cpp
new file mode 100644
index 000000000..f4754835e
--- /dev/null
+++ b/ml/dlib/dlib/test/hash_table.cpp
@@ -0,0 +1,663 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/hash_table.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.hash_table");
+
+ template <
+ typename hash_table
+ >
+ void hash_table_kernel_test (
+ )
+ /*!
+ requires
+ - hash_table is an implementation of hash_table/hash_table_kernel_abstract.h
+ and is instantiated to map ints to ints
+ ensures
+ - runs tests on hash_table for compliance with the specs
+ !*/
+ {
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+
+
+ {
+ hash_table test(16);
+
+ DLIB_TEST(test.count(3) == 0);
+
+ enumerable<map_pair<int,int> >& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ hash_table test2(16);
+
+ hash_table test3(0);
+ hash_table test4(0);
+
+
+ print_spinner();
+
+ int b;
+ for (int j = 0; j < 4; ++j)
+ {
+ int a = 4;
+ b = 5;
+ test2.add(a,b);
+ DLIB_TEST(test2.size() == 1);
+ DLIB_TEST(*test2[4] == 5);
+ DLIB_TEST(test2[99] == 0);
+
+ DLIB_TEST(test2.move_next());
+ DLIB_TEST(test2.element().key() == 4);
+ DLIB_TEST(test2.element().value() == 5);
+
+ swap(test,test2);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(*test[4] == 5);
+ DLIB_TEST(test[99] == 0);
+
+ test.swap(test2);
+
+ a = 99;
+ b = 35;
+ test2.add(a,b);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(*test2[4] == 5);
+ DLIB_TEST(*test2[99] == 35);
+ DLIB_TEST(test2[99] != 0);
+ DLIB_TEST(test2[949] == 0);
+
+ test2.destroy(4);
+ DLIB_TEST(test2.size() == 1);
+ DLIB_TEST(test2[4] == 0);
+ DLIB_TEST(*test2[99] == 35);
+ DLIB_TEST(test2[99] != 0);
+ DLIB_TEST(test2[949] == 0);
+
+
+
+ test2.destroy(99);
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2[4] == 0);
+ DLIB_TEST(test2[99] == 0);
+ DLIB_TEST(test2[949] == 0);
+
+
+
+ test2.clear();
+ }
+
+
+ print_spinner();
+
+
+
+
+ for (int j = 0; j < 4; ++j)
+ {
+
+ DLIB_TEST(test.count(3) == 0);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ int a;
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()%1000;
+ int temp = a;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(temp) == count+1);
+ }
+
+ {
+ unsigned long count = test.count(3);
+
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ }
+
+
+ test.clear();
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = b = i;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(i) == count+1);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.current_element_valid() == true);
+
+
+ test.reset();
+
+ DLIB_TEST(test.size() == 10000);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ if (test.size() > 0)
+ {
+ int* array = new int[test.size()];
+ int* tmp = array;
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ *tmp = test.element().key();
+ DLIB_TEST(test[*tmp] != 0);
+ DLIB_TEST(*tmp == test.element().key());
+ DLIB_TEST(*tmp == test.element().value());
+ DLIB_TEST(*tmp == test.element().key());
+ DLIB_TEST(test.current_element_valid() == true);
+ ++tmp;
+ }
+
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST(test.size() == 10000);
+
+ swap(test,test2);
+
+
+
+
+ // serialize the state of test2, then clear test2, then
+ // load the state back into test2.
+ ostringstream sout;
+ serialize(test2,sout);
+ DLIB_TEST(test2.at_start() == true);
+ istringstream sin(sout.str());
+ test2.clear();
+ deserialize(test2,sin);
+ DLIB_TEST(test2.at_start() == true);
+
+
+
+
+ tmp = array;
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ ++tmp;
+ }
+
+ test2.swap(test);
+ test.reset();
+
+ DLIB_TEST(test.at_start() == true);
+ count = 0;
+ tmp = array;
+ while (test.size() > 0)
+ {
+ test.remove(*tmp,a,b);
+
+ ++tmp;
+ ++count;
+ }
+
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 0);
+
+
+
+ DLIB_TEST(count == 10000);
+
+
+
+
+
+
+
+ delete [] array;
+ }
+
+ test.move_next();
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand();
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == true);
+
+ DLIB_TEST(test.size() == 10000);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ test.remove_any(a,b);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.size() == 0);
+
+ test.clear();
+
+
+
+
+
+
+
+
+
+ int* dtmp = new int[10000];
+ int* rtmp = new int[10000];
+
+ int* d = dtmp;
+ int* r = rtmp;
+ for (unsigned long i = 0; i < 10000; ++i)
+ {
+ a = ::rand();
+ b = ::rand();
+ *d = a;
+ *r = b;
+ if (test[a] != 0)
+ {
+ --i;
+ continue;
+ }
+ test.add(a,b);
+ ++d;
+ ++r;
+ DLIB_TEST(test.size() == i+1);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(*test[dtmp[i]] == rtmp[i]);
+ }
+
+
+ delete [] dtmp;
+ delete [] rtmp;
+
+ test.clear();
+ }}
+
+
+ print_spinner();
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ // now do the same thing as above but with a much smaller hash table
+ {
+ hash_table test(13);
+
+ DLIB_TEST(test.count(3) == 0);
+
+ enumerable<map_pair<int,int> >& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ hash_table test2(16);
+
+ hash_table test3(0);
+ hash_table test4(0);
+
+
+ int b;
+ for (int j = 0; j < 4; ++j)
+ {
+ int a = 4;
+ b = 5;
+ test2.add(a,b);
+ DLIB_TEST(test2.size() == 1);
+ DLIB_TEST(*test2[4] == 5);
+ DLIB_TEST(test2[99] == 0);
+
+
+ DLIB_TEST(test2.move_next());
+ DLIB_TEST(test2.element().key() == 4);
+ DLIB_TEST(test2.element().value() == 5);
+
+ swap(test,test2);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(*test[4] == 5);
+ DLIB_TEST(test[99] == 0);
+
+ test.swap(test2);
+
+ a = 99;
+ b = 35;
+ test2.add(a,b);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(*test2[4] == 5);
+ DLIB_TEST(*test2[99] == 35);
+ DLIB_TEST(test2[99] != 0);
+ DLIB_TEST(test2[949] == 0);
+
+ test2.destroy(4);
+ DLIB_TEST(test2.size() == 1);
+ DLIB_TEST(test2[4] == 0);
+ DLIB_TEST(*test2[99] == 35);
+ DLIB_TEST(test2[99] != 0);
+ DLIB_TEST(test2[949] == 0);
+
+
+
+ test2.destroy(99);
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2[4] == 0);
+ DLIB_TEST(test2[99] == 0);
+ DLIB_TEST(test2[949] == 0);
+
+
+
+ test2.clear();
+ }
+
+
+ print_spinner();
+
+
+
+
+ for (int j = 0; j < 4; ++j)
+ {
+
+ DLIB_TEST(test.count(3) == 0);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ int a;
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand()%1000;
+ int temp = a;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(temp) == count+1);
+ }
+
+ {
+ unsigned long count = test.count(3);
+
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ a = 3; test.add(a,b); ++count;
+ DLIB_TEST(test.count(3) == count);
+ }
+
+
+ test.clear();
+
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = b = i;
+ unsigned long count = test.count(a);
+ test.add(a,b);
+ DLIB_TEST(test.count(i) == count+1);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.current_element_valid() == true);
+
+
+ test.reset();
+
+ DLIB_TEST(test.size() == 10000);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ if (test.size() > 0)
+ {
+ int* array = new int[test.size()];
+ int* tmp = array;
+
+ int count = 0;
+ while (test.move_next())
+ {
+ ++count;
+ *tmp = test.element().key();
+ DLIB_TEST(test[*tmp] != 0);
+ DLIB_TEST(*tmp == test.element().key());
+ DLIB_TEST(*tmp == test.element().value());
+ DLIB_TEST(*tmp == test.element().key());
+ DLIB_TEST(test.current_element_valid() == true);
+ ++tmp;
+ }
+
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ DLIB_TEST(test.size() == 10000);
+
+ swap(test,test2);
+
+ tmp = array;
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ DLIB_TEST(*test2[*tmp] == *tmp);
+ ++tmp;
+ }
+
+ test2.swap(test);
+ test.reset();
+
+ DLIB_TEST(test.at_start() == true);
+ count = 0;
+ tmp = array;
+ while (test.size() > 0)
+ {
+ test.remove(*tmp,a,b);
+
+ ++tmp;
+ ++count;
+ }
+
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 0);
+
+
+
+ DLIB_TEST(count == 10000);
+
+
+
+
+
+
+
+ delete [] array;
+ }
+
+ test.move_next();
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand();
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == true);
+
+ DLIB_TEST(test.size() == 10000);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ test.remove_any(a,b);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.size() == 0);
+
+ test.clear();
+
+
+
+
+
+
+
+
+ int* dtmp = new int[10000];
+ int* rtmp = new int[10000];
+
+ int* d = dtmp;
+ int* r = rtmp;
+ for (unsigned long i = 0; i < 10000; ++i)
+ {
+ a = ::rand();
+ b = ::rand();
+ *d = a;
+ *r = b;
+ if (test[a] != 0)
+ {
+ --i;
+ continue;
+ }
+ test.add(a,b);
+ ++d;
+ ++r;
+ DLIB_TEST(test.size() == i+1);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ DLIB_TEST(*test[dtmp[i]] == rtmp[i]);
+ }
+
+
+ delete [] dtmp;
+ delete [] rtmp;
+
+ test.clear();
+ }}
+
+ }
+
+
+
+
+ class hash_table_tester : public tester
+ {
+ public:
+ hash_table_tester (
+ ) :
+ tester ("test_hash_table",
+ "Runs tests on the hash_table component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ hash_table_kernel_test<hash_table<int,int>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ hash_table_kernel_test<hash_table<int,int>::kernel_1a_c>();
+ dlog << LINFO << "testing kernel_2a";
+ hash_table_kernel_test<hash_table<int,int>::kernel_2a> ();
+ dlog << LINFO << "testing kernel_2a_c";
+ hash_table_kernel_test<hash_table<int,int>::kernel_2a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/hog_image.cpp b/ml/dlib/dlib/test/hog_image.cpp
new file mode 100644
index 000000000..615e4d2ff
--- /dev/null
+++ b/ml/dlib/dlib/test/hog_image.cpp
@@ -0,0 +1,126 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/image_keypoint.h>
+#include <dlib/array2d.h>
+#include <dlib/rand.h>
+#include <dlib/pixel.h>
+#include <dlib/image_transforms.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.hog_image");
+
+// ----------------------------------------------------------------------------------------
+
+ class test_hog_image : public tester
+ {
+ public:
+ test_hog_image (
+ ) :
+ tester ("test_hog_image",
+ "Runs tests on the hog_image object.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ array2d<unsigned char> img;
+ img.set_size(200,200);
+
+ assign_all_pixels(img, 0);
+
+ hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> hog1, hog1_deserialized;
+ hog_image<4,4,2,4,hog_signed_gradient,hog_full_interpolation> hog2;
+
+ hog1.load(img);
+ hog2.load(img);
+
+
+ // Just test all the coordinate mapping functions.
+
+ DLIB_TEST(hog1.get_block_rect(0,0).width() == 3*3);
+ DLIB_TEST(hog1.get_block_rect(0,0).height() == 3*3);
+ DLIB_TEST(hog2.get_block_rect(0,0).width() == 4*4);
+ DLIB_TEST(hog2.get_block_rect(0,0).height() == 4*4);
+
+ DLIB_TEST(get_rect(img).contains(hog1.get_block_rect(0,0)));
+ DLIB_TEST(get_rect(img).contains(hog1.get_block_rect(hog1.nr()-1,hog1.nc()-1)));
+ DLIB_TEST(get_rect(img).contains(hog2.get_block_rect(0,0)));
+ DLIB_TEST(get_rect(img).contains(hog2.get_block_rect(hog2.nr()-1,hog2.nc()-1)));
+
+ dlib::rand rnd;
+ for (int i = 0; i < 20000; ++i)
+ {
+ point p(rnd.get_random_16bit_number(), rnd.get_random_16bit_number());
+ p.x() -= 20000;
+ p.y() -= 20000;
+
+ DLIB_TEST((hog1.feat_to_image_space(hog1.image_to_feat_space(p)) - p).length() <= 3);
+ DLIB_TEST((hog2.feat_to_image_space(hog2.image_to_feat_space(p)) - p).length() <= 10);
+
+ DLIB_TEST_MSG((hog1.image_to_feat_space(hog1.feat_to_image_space(p)) - p).length() <= 3,
+ p << " " << hog1.feat_to_image_space(p) << " " << hog1.image_to_feat_space(hog1.feat_to_image_space(p)) );
+ DLIB_TEST((hog2.image_to_feat_space(hog2.feat_to_image_space(p)) - p).length() <= 10);
+ }
+
+
+ DLIB_TEST(hog1.feat_to_image_space(point(0,0)) == point(5,5));
+ DLIB_TEST(hog2.feat_to_image_space(point(0,0)) == point(9,9));
+
+ DLIB_TEST(hog1.feat_to_image_space(point(1,1)) == point(8,8));
+ DLIB_TEST(hog2.feat_to_image_space(point(1,1)) == point(17,17));
+
+ DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(0,0))) == point(0,0));
+ DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(0,0))) == point(0,0));
+ DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(1,1))) == point(1,1));
+ DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(1,1))) == point(1,1));
+ DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(1,2))) == point(1,2));
+ DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(1,2))) == point(1,2));
+
+
+
+ DLIB_TEST(hog1_deserialized.size() != hog1.size());
+ DLIB_TEST(hog1_deserialized.nr() != hog1.nr());
+ DLIB_TEST(hog1_deserialized.nc() != hog1.nc());
+ ostringstream sout;
+ serialize(hog1, sout);
+ istringstream sin(sout.str());
+ deserialize(hog1_deserialized, sin);
+
+ DLIB_TEST(hog1_deserialized.size() == hog1.size());
+ DLIB_TEST(hog1_deserialized.nr() == hog1.nr());
+ DLIB_TEST(hog1_deserialized.nc() == hog1.nc());
+ DLIB_TEST(hog1_deserialized(0,2) == hog1(0,2));
+ DLIB_TEST(hog1_deserialized.get_block_rect(1,2) == hog1.get_block_rect(1,2));
+ DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(0,0))) == point(0,0));
+ DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(1,1))) == point(1,1));
+ DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(1,2))) == point(1,2));
+
+
+
+ DLIB_TEST(hog1.size() > 1);
+ DLIB_TEST(hog1.nr() > 1);
+ DLIB_TEST(hog1.nc() > 1);
+ hog1.clear();
+ DLIB_TEST(hog1.size() == 0);
+ DLIB_TEST(hog1.nr() == 0);
+ DLIB_TEST(hog1.nc() == 0);
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/image.cpp b/ml/dlib/dlib/test/image.cpp
new file mode 100644
index 000000000..01f1410cf
--- /dev/null
+++ b/ml/dlib/dlib/test/image.cpp
@@ -0,0 +1,1903 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/pixel.h>
+#include <dlib/array2d.h>
+#include <dlib/image_transforms.h>
+#include <dlib/image_io.h>
+#include <dlib/matrix.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.image");
+
+
+ void image_test (
+ )
+ /*!
+ ensures
+ - runs tests on pixel objects and functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ array2d<unsigned char> img1, img2;
+
+ img1.set_size(100,100);
+
+ assign_all_pixels(img1,7);
+
+ assign_image(img2, img1);
+
+ DLIB_TEST_MSG(img1.nr() == 100 && img1.nc() == 100 &&
+ img2.nr() == 100 && img2.nc() == 100,"");
+
+
+ for (long r = 0; r < img1.nr(); ++r)
+ {
+ for (long c = 0; c < img1.nc(); ++c)
+ {
+ DLIB_TEST(img1[r][c] == 7);
+ DLIB_TEST(img2[r][c] == 7);
+ }
+ }
+
+ img2.clear();
+ DLIB_TEST(img2.size() == 0);
+ DLIB_TEST(img2.nr() == 0);
+ DLIB_TEST(img2.nc() == 0);
+ assign_image(img2, mat(img1));
+
+ DLIB_TEST_MSG(img1.nr() == 100 && img1.nc() == 100 &&
+ img2.nr() == 100 && img2.nc() == 100,"");
+
+
+ for (long r = 0; r < img1.nr(); ++r)
+ {
+ for (long c = 0; c < img1.nc(); ++c)
+ {
+ DLIB_TEST(img1[r][c] == 7);
+ DLIB_TEST(img2[r][c] == 7);
+ }
+ }
+
+
+ threshold_image(img1, img2, 4);
+
+ for (long r = 0; r < img1.nr(); ++r)
+ {
+ for (long c = 0; c < img1.nc(); ++c)
+ {
+ DLIB_TEST(img1[r][c] == 7);
+ DLIB_TEST(img2[r][c] == on_pixel);
+ }
+ }
+
+ {
+ array2d<hsi_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].h = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].s = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].i = static_cast<unsigned char>(r*14 + c + 3);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ istringstream sin(sout.str());
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].h == r*14 + c + 1);
+ DLIB_TEST(img[r][c].s == r*14 + c + 2);
+ DLIB_TEST(img[r][c].i == r*14 + c + 3);
+ }
+ }
+ }
+
+
+
+
+ {
+ array2d<rgb_alpha_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ img[r][c].alpha = static_cast<unsigned char>(r*14 + c + 4);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ istringstream sin(sout.str());
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ DLIB_TEST(img[r][c].alpha == r*14 + c + 4);
+ }
+ }
+ }
+
+#ifdef DLIB_PNG_SUPPORT
+ {
+ array2d<rgb_alpha_pixel> img;
+ array2d<rgb_pixel> img2, img3;
+ img.set_size(14,15);
+ img2.set_size(img.nr(),img.nc());
+ img3.set_size(img.nr(),img.nc());
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ img[r][c].alpha = static_cast<unsigned char>(r*14 + c + 4);
+ }
+ }
+
+ save_png(img, "test.png");
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_png(img, "test.png");
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ assign_all_pixels(img2, 255);
+ assign_all_pixels(img3, 0);
+ load_png(img2, "test.png");
+ assign_image(img3, img);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ DLIB_TEST(img[r][c].alpha == r*14 + c + 4);
+
+ DLIB_TEST(img2[r][c].red == img3[r][c].red);
+ DLIB_TEST(img2[r][c].green == img3[r][c].green);
+ DLIB_TEST(img2[r][c].blue == img3[r][c].blue);
+ }
+ }
+ }
+#endif // DLIB_PNG_SUPPORT
+
+
+
+ {
+ array2d<rgb_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ istringstream sin(sout.str());
+
+ for (int i = 0; i < 2; ++i)
+ {
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_bmp(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST_MSG(img[r][c].red == r*14 + c + 1, "got " << (int)img[r][c].red << " but expected " << r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+ }
+ }
+ {
+ array2d<bgr_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ istringstream sin(sout.str());
+
+ for (int i = 0; i < 2; ++i)
+ {
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_bmp(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST_MSG(img[r][c].red == r*14 + c + 1, "got " << (int)img[r][c].red << " but expected " << r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+ }
+ }
+
+#ifdef DLIB_PNG_SUPPORT
+ {
+ array2d<rgb_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ }
+ }
+
+ save_png(img, "test.png");
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_png(img, "test.png");
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+ }
+ {
+ array2d<bgr_pixel> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c].red = static_cast<unsigned char>(r*14 + c + 1);
+ img[r][c].green = static_cast<unsigned char>(r*14 + c + 2);
+ img[r][c].blue = static_cast<unsigned char>(r*14 + c + 3);
+ }
+ }
+
+ save_png(img, "test.png");
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_png(img, "test.png");
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c].red == r*14 + c + 1);
+ DLIB_TEST(img[r][c].green == r*14 + c + 2);
+ DLIB_TEST(img[r][c].blue == r*14 + c + 3);
+ }
+ }
+ }
+#endif // DLIB_PNG_SUPPORT
+
+
+
+ {
+ array2d<unsigned short> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c] = static_cast<unsigned short>(r*14 + c + 0xF0);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ istringstream sin(sout.str());
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == r*14 + c + 0xF0);
+ }
+ }
+ }
+
+
+#ifdef DLIB_PNG_SUPPORT
+ {
+ array2d<unsigned short> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c] = static_cast<unsigned short>(r*14 + c + 0xF0);
+ }
+ }
+
+ save_png(img, "test.png");
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_png(img, "test.png");
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == r*14 + c + 0xF0);
+ }
+ }
+ }
+#endif // DLIB_PNG_SUPPORT
+
+
+
+ {
+ array2d<unsigned char> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c] = static_cast<unsigned char>(r*14 + c*111);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ istringstream sin(sout.str());
+
+ for (int i = 0; i < 2; ++i)
+ {
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == static_cast<unsigned char>(r*14 + c*111));
+ }
+ }
+
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_bmp(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == static_cast<unsigned char>(r*14 + c*111));
+ }
+ }
+ }
+ }
+
+
+#ifdef DLIB_PNG_SUPPORT
+ {
+ array2d<unsigned char> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c] = static_cast<unsigned char>(r*14 + c);
+ }
+ }
+
+ save_png(img, "test.png");
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_png(img, "test.png");
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == r*14 + c);
+ }
+ }
+
+ }
+#endif // DLIB_PNG_SUPPORT
+
+
+ {
+ // in this test we will only assign pixel values that can be
+ // represented with 8 bits even though we are using a wider pixel type.
+ array2d<unsigned short> img;
+ img.set_size(14,15);
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ img[r][c] = static_cast<unsigned char>(r*14 + c);
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ save_dng(img, sout);
+ save_bmp(img, sout);
+ istringstream sin(sout.str());
+
+ for (int i = 0; i < 2; ++i)
+ {
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_dng(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == r*14 + c);
+ }
+ }
+
+
+ img.clear();
+ DLIB_TEST(img.nr() == 0);
+ DLIB_TEST(img.nc() == 0);
+
+ load_bmp(img, sin);
+
+ DLIB_TEST(img.nr() == 14);
+ DLIB_TEST(img.nc() == 15);
+
+ for (long r = 0; r < 14; ++r)
+ {
+ for (long c = 0; c < 15; ++c)
+ {
+ DLIB_TEST(img[r][c] == r*14 + c);
+ }
+ }
+ }
+ }
+
+ {
+ array2d<unsigned short> img1;
+ array2d<unsigned char> img2;
+ img1.set_size(10,10);
+ assign_all_pixels(img1, 0);
+
+ img1[5][5] = 10000;
+ img1[7][7] = 10000;
+
+ equalize_histogram(img1, img2);
+
+ for (long r = 0; r < img1.nr(); ++r)
+ {
+ for (long c = 0; c < img2.nc(); ++c)
+ {
+ if ((r == 5 && c == 5) ||
+ (r == 7 && c == 7))
+ {
+ DLIB_TEST(img2[r][c] == 255);
+ }
+ else
+ {
+ DLIB_TEST(img2[r][c] == 0);
+ }
+ }
+ }
+
+ }
+
+ {
+ array2d<unsigned char> img;
+ img.set_size(10,10);
+ assign_all_pixels(img, 0);
+
+ assign_border_pixels(img, 2,2, 4);
+
+ DLIB_TEST(zeros_matrix<unsigned char>(6,6) == subm(mat(img), rectangle(2,2,7,7)));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,10, 4) == rowm(mat(img), 0));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,10, 4) == rowm(mat(img), 1));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,10, 4) == rowm(mat(img), 8));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,10, 4) == rowm(mat(img), 9));
+
+ DLIB_TEST(uniform_matrix<unsigned char>(10,1, 4) == colm(mat(img), 0));
+ DLIB_TEST(uniform_matrix<unsigned char>(10,1, 4) == colm(mat(img), 1));
+ DLIB_TEST(uniform_matrix<unsigned char>(10,1, 4) == colm(mat(img), 8));
+ DLIB_TEST(uniform_matrix<unsigned char>(10,1, 4) == colm(mat(img), 9));
+
+
+ assign_border_pixels(img, 7, 7, 5);
+ DLIB_TEST(uniform_matrix<unsigned char>(10,10, 5) == mat(img));
+ assign_border_pixels(img, 37, 47, 5);
+ DLIB_TEST(uniform_matrix<unsigned char>(10,10, 5) == mat(img));
+ }
+
+ {
+ array2d<unsigned char> img;
+ img.set_size(11,11);
+ assign_all_pixels(img, 0);
+
+ assign_border_pixels(img, 2,2, 4);
+
+ DLIB_TEST(zeros_matrix<unsigned char>(7,7) == subm(mat(img), rectangle(2,2,8,8)));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,11, 4) == rowm(mat(img), 0));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,11, 4) == rowm(mat(img), 1));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,11, 4) == rowm(mat(img), 9));
+ DLIB_TEST(uniform_matrix<unsigned char>(1,11, 4) == rowm(mat(img), 10));
+
+ DLIB_TEST(uniform_matrix<unsigned char>(11,1, 4) == colm(mat(img), 0));
+ DLIB_TEST(uniform_matrix<unsigned char>(11,1, 4) == colm(mat(img), 1));
+ DLIB_TEST(uniform_matrix<unsigned char>(11,1, 4) == colm(mat(img), 9));
+ DLIB_TEST(uniform_matrix<unsigned char>(11,1, 4) == colm(mat(img), 10));
+
+ assign_border_pixels(img, 7, 7, 5);
+ DLIB_TEST(uniform_matrix<unsigned char>(11,11, 5) == mat(img));
+ assign_border_pixels(img, 70, 57, 5);
+ DLIB_TEST(uniform_matrix<unsigned char>(11,11, 5) == mat(img));
+ }
+
+
+ }
+
+
+ template <typename T, typename pixel_type>
+ void test_integral_image (
+ )
+ {
+ dlib::rand rnd;
+
+ array2d<pixel_type> img;
+ integral_image_generic<T> int_img;
+
+ int_img.load(img);
+ DLIB_TEST(int_img.nr() == 0);
+ DLIB_TEST(int_img.nc() == 0);
+
+ // make 5 random images
+ for (int i = 0; i < 5; ++i)
+ {
+ print_spinner();
+ img.set_size(rnd.get_random_16bit_number()%200+1, rnd.get_random_16bit_number()%200+1);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = (int)rnd.get_random_8bit_number() - 100;
+ }
+ }
+
+ int_img.load(img);
+ DLIB_TEST(int_img.nr() == img.nr());
+ DLIB_TEST(int_img.nc() == img.nc());
+
+ // make 200 random rectangles
+ for (int j = 0; j < 500; ++j)
+ {
+ point p1(rnd.get_random_32bit_number()%img.nc(), rnd.get_random_32bit_number()%img.nr());
+ point p2(rnd.get_random_32bit_number()%img.nc(), rnd.get_random_32bit_number()%img.nr());
+ rectangle rect(p1,p2);
+ DLIB_TEST(int_img.get_sum_of_area(rect) == sum(subm(matrix_cast<T>(mat(img)), rect)));
+ rect = rectangle(p1,p1);
+ DLIB_TEST(int_img.get_sum_of_area(rect) == sum(subm(matrix_cast<T>(mat(img)), rect)));
+ }
+
+ }
+
+
+ }
+
+ void test_filtering2(int nr, int nc, dlib::rand& rnd)
+ {
+ print_spinner();
+ dlog << LINFO << "test_filtering2(): " << nr << " " << nc;
+ array2d<float> img(302,301);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_gaussian();
+ }
+ }
+ matrix<float> filt = matrix_cast<float>(randm(nr,nc,rnd));
+
+ matrix<float> out = xcorr_same(mat(img),filt);
+ matrix<float> out2 = subm(conv(mat(img),flip(filt)), filt.nr()/2, filt.nc()/2, img.nr(), img.nc());
+ // make sure xcorr_same does exactly what the docs say it should.
+ DLIB_TEST(max(abs(out-out2)) < 1e-7);
+
+ // Now compare the filtering functions to xcorr_same to make sure everything does
+ // filtering in the same way.
+ array2d<float> imout(img.nr(), img.nc());
+ assign_all_pixels(imout, 10);
+ rectangle rect = spatially_filter_image(img, imout, filt);
+ border_enumerator be(get_rect(imout),rect);
+ while (be.move_next())
+ {
+ DLIB_TEST(imout[be.element().y()][be.element().x()] == 0);
+ }
+ DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect))));
+
+
+ assign_all_pixels(imout, 10);
+ out = 10;
+ rect = spatially_filter_image(img, imout, filt,2,true,true);
+ be = border_enumerator(get_rect(imout),rect);
+ while (be.move_next())
+ {
+ DLIB_TEST(imout[be.element().y()][be.element().x()] == 10);
+ }
+ out += abs(xcorr_same(mat(img),filt)/2);
+ DLIB_TEST(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-7);
+
+
+ assign_all_pixels(imout, -10);
+ out = -10;
+ rect = spatially_filter_image(img, imout, filt,2,false,true);
+ be = border_enumerator(get_rect(imout),rect);
+ while (be.move_next())
+ {
+ DLIB_TEST(imout[be.element().y()][be.element().x()] == -10);
+ }
+ out += xcorr_same(mat(img),filt)/2;
+ DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect))));
+
+
+
+
+ matrix<float> row_filt = matrix_cast<float>(randm(nc,1,rnd));
+ matrix<float> col_filt = matrix_cast<float>(randm(nr,1,rnd));
+ assign_all_pixels(imout, 10);
+ rect = spatially_filter_image_separable(img, imout, row_filt, col_filt);
+ out = xcorr_same(tmp(xcorr_same(mat(img),trans(row_filt))), col_filt);
+ DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect))));
+
+ be = border_enumerator(get_rect(imout),rect);
+ while (be.move_next())
+ {
+ DLIB_TEST(imout[be.element().y()][be.element().x()] == 0);
+ }
+
+
+ assign_all_pixels(imout, 10);
+ out = 10;
+ rect = spatially_filter_image_separable(img, imout, row_filt, col_filt,2,true,true);
+ out += abs(xcorr_same(tmp(xcorr_same(mat(img),trans(row_filt))), col_filt)/2);
+ DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-7,
+ max(abs(subm(mat(imout),rect) - subm(out,rect))));
+
+ be = border_enumerator(get_rect(imout),rect);
+ while (be.move_next())
+ {
+ DLIB_TEST(imout[be.element().y()][be.element().x()] == 10);
+ }
+
+ }
+
+ template <typename T>
+ void test_filtering(bool use_abs, unsigned long scale )
+ {
+ print_spinner();
+ dlog << LINFO << "test_filtering(" << use_abs << "," << scale << ")";
+ array2d<T> img, img2, img3;
+ img.set_size(10,11);
+
+ assign_all_pixels(img, 10);
+
+ matrix<int,3,5> filter2;
+ filter2 = 1,1,1,1,1,
+ 1,1,1,1,1,
+ 1,1,1,1,1;
+
+ assign_all_pixels(img2,3);
+ rectangle brect = spatially_filter_image(img, img2, filter2);
+ DLIB_TEST(brect == shrink_rect(get_rect(img), filter2.nc()/2, filter2.nr()/2));
+
+ const rectangle rect(2,1,img.nc()-3,img.nr()-2);
+
+ for (long r = 0; r<img2.nr(); ++r)
+ {
+ for (long c = 0; c<img2.nc(); ++c)
+ {
+ if (rect.contains(c,r))
+ {
+ DLIB_TEST_MSG(img2[r][c] == 150, (int)img2[r][c]);
+ }
+ else
+ {
+ DLIB_TEST_MSG(img2[r][c] == 0,(int)img2[r][c]);
+ }
+ }
+ }
+
+
+ assign_all_pixels(img2,3);
+ assign_all_pixels(img3,3);
+ brect = spatially_filter_image(img, img2, filter2);
+ DLIB_TEST(brect == shrink_rect(get_rect(img), filter2.nc()/2, filter2.nr()/2));
+
+ matrix<int,1,5> row_filter;
+ matrix<int,1,3> col_filter;
+
+ row_filter = 1,1,1,1,1;
+ col_filter = 1,1,1;
+
+ spatially_filter_image_separable(img, img3, row_filter, col_filter);
+
+ DLIB_TEST(mat(img2) == mat(img3));
+
+
+ dlib::rand rnd;
+
+ for (int i = 0; i < 30; ++i)
+ {
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_8bit_number();
+ }
+ }
+
+ row_filter(0) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ row_filter(1) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ row_filter(2) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ row_filter(3) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ row_filter(4) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ col_filter(0) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ col_filter(1) = ((int)rnd.get_random_8bit_number() - 100)/10;
+ col_filter(2) = ((int)rnd.get_random_8bit_number() - 100)/10;
+
+ const matrix<int,3,5> filter = trans(col_filter)*row_filter;
+
+ assign_all_pixels(img2,3);
+ assign_all_pixels(img3,3);
+ // Just make sure both filtering methods give the same results.
+ rectangle brect1, brect2;
+ brect1 = spatially_filter_image(img, img2, filter, scale, use_abs);
+ brect2 = spatially_filter_image_separable(img, img3, row_filter, col_filter, scale, use_abs);
+ DLIB_TEST(mat(img2) == mat(img3));
+
+ DLIB_TEST(brect1 == shrink_rect(get_rect(img), filter.nc()/2, filter.nr()/2));
+ DLIB_TEST(brect1 == brect2);
+ }
+
+ {
+ array2d<int> img, img2;
+ img.set_size(3,4);
+
+ matrix<int> filter(3,3);
+ filter = 1;
+ assign_all_pixels(img,-1);
+
+ spatially_filter_image(img,img2,filter);
+
+ DLIB_TEST(img2[0][0] == 0);
+ DLIB_TEST(img2[0][1] == 0);
+ DLIB_TEST(img2[0][2] == 0);
+ DLIB_TEST(img2[0][3] == 0);
+
+ DLIB_TEST(img2[1][0] == 0);
+ DLIB_TEST(img2[1][1] == -9);
+ DLIB_TEST(img2[1][2] == -9);
+ DLIB_TEST(img2[1][3] == 0);
+
+ DLIB_TEST(img2[2][0] == 0);
+ DLIB_TEST(img2[2][1] == 0);
+ DLIB_TEST(img2[2][2] == 0);
+ DLIB_TEST(img2[2][3] == 0);
+
+ assign_all_pixels(img,-1);
+
+ spatially_filter_image(img,img2,filter,2,true);
+
+ DLIB_TEST(img2[0][0] == 0);
+ DLIB_TEST(img2[0][1] == 0);
+ DLIB_TEST(img2[0][2] == 0);
+ DLIB_TEST(img2[0][3] == 0);
+
+ DLIB_TEST(img2[1][0] == 0);
+ DLIB_TEST(img2[1][1] == 4);
+ DLIB_TEST(img2[1][2] == 4);
+ DLIB_TEST(img2[1][3] == 0);
+
+ DLIB_TEST(img2[2][0] == 0);
+ DLIB_TEST(img2[2][1] == 0);
+ DLIB_TEST(img2[2][2] == 0);
+ DLIB_TEST(img2[2][3] == 0);
+
+ matrix<int> rowf(3,1), colf(3,1);
+ rowf = 1;
+ colf = 1;
+ assign_all_pixels(img,-1);
+
+ spatially_filter_image_separable(img,img2,rowf,colf);
+ DLIB_TEST(img2[0][0] == 0);
+ DLIB_TEST(img2[0][1] == 0);
+ DLIB_TEST(img2[0][2] == 0);
+ DLIB_TEST(img2[0][3] == 0);
+
+ DLIB_TEST(img2[1][0] == 0);
+ DLIB_TEST(img2[1][1] == -9);
+ DLIB_TEST(img2[1][2] == -9);
+ DLIB_TEST(img2[1][3] == 0);
+
+ DLIB_TEST(img2[2][0] == 0);
+ DLIB_TEST(img2[2][1] == 0);
+ DLIB_TEST(img2[2][2] == 0);
+ DLIB_TEST(img2[2][3] == 0);
+
+ spatially_filter_image_separable(img,img2,rowf,colf,1,true);
+ DLIB_TEST(img2[0][0] == 0);
+ DLIB_TEST(img2[0][1] == 0);
+ DLIB_TEST(img2[0][2] == 0);
+ DLIB_TEST(img2[0][3] == 0);
+
+ DLIB_TEST(img2[1][0] == 0);
+ DLIB_TEST(img2[1][1] == 9);
+ DLIB_TEST(img2[1][2] == 9);
+ DLIB_TEST(img2[1][3] == 0);
+
+ DLIB_TEST(img2[2][0] == 0);
+ DLIB_TEST(img2[2][1] == 0);
+ DLIB_TEST(img2[2][2] == 0);
+ DLIB_TEST(img2[2][3] == 0);
+
+ assign_all_pixels(img2, 3);
+ spatially_filter_image_separable(img,img2,rowf,colf,1,true, true);
+ DLIB_TEST(img2[0][0] == 3);
+ DLIB_TEST(img2[0][1] == 3);
+ DLIB_TEST(img2[0][2] == 3);
+ DLIB_TEST(img2[0][3] == 3);
+
+ DLIB_TEST(img2[1][0] == 3);
+ DLIB_TEST_MSG(img2[1][1] == 9+3, img2[1][1] );
+ DLIB_TEST(img2[1][2] == 9+3);
+ DLIB_TEST(img2[1][3] == 3);
+
+ DLIB_TEST(img2[2][0] == 3);
+ DLIB_TEST(img2[2][1] == 3);
+ DLIB_TEST(img2[2][2] == 3);
+ DLIB_TEST(img2[2][3] == 3);
+ }
+ {
+ array2d<double> img, img2;
+ img.set_size(3,4);
+
+ matrix<double> filter(3,3);
+ filter = 1;
+ assign_all_pixels(img,-1);
+
+ spatially_filter_image(img,img2,filter,2);
+
+ DLIB_TEST(img2[0][0] == 0);
+ DLIB_TEST(img2[0][1] == 0);
+ DLIB_TEST(img2[0][2] == 0);
+ DLIB_TEST(img2[0][3] == 0);
+
+ DLIB_TEST(img2[1][0] == 0);
+ DLIB_TEST(std::abs(img2[1][1] - -4.5) < 1e-14);
+ DLIB_TEST(std::abs(img2[1][2] - -4.5) < 1e-14);
+ DLIB_TEST(img2[1][3] == 0);
+
+ DLIB_TEST(img2[2][0] == 0);
+ DLIB_TEST(img2[2][1] == 0);
+ DLIB_TEST(img2[2][2] == 0);
+ DLIB_TEST(img2[2][3] == 0);
+
+ }
+ {
+ array2d<double> img, img2;
+ img.set_size(3,4);
+ img2.set_size(3,4);
+ assign_all_pixels(img2, 8);
+
+ matrix<double> filter(3,3);
+ filter = 1;
+ assign_all_pixels(img,-1);
+
+ spatially_filter_image(img,img2,filter,2, false, true);
+
+ DLIB_TEST(img2[0][0] == 8);
+ DLIB_TEST(img2[0][1] == 8);
+ DLIB_TEST(img2[0][2] == 8);
+ DLIB_TEST(img2[0][3] == 8);
+
+ DLIB_TEST(img2[1][0] == 8);
+ DLIB_TEST(std::abs(img2[1][1] - -4.5 - 8) < 1e-14);
+ DLIB_TEST(std::abs(img2[1][2] - -4.5 - 8) < 1e-14);
+ DLIB_TEST(img2[1][3] == 8);
+
+ DLIB_TEST(img2[2][0] == 8);
+ DLIB_TEST(img2[2][1] == 8);
+ DLIB_TEST(img2[2][2] == 8);
+ DLIB_TEST(img2[2][3] == 8);
+
+ }
+ }
+
+
+ void test_zero_border_pixels(
+ )
+ {
+ array2d<unsigned char> img;
+ img.set_size(4,5);
+
+ assign_all_pixels(img, 1);
+ zero_border_pixels(img, 2,1);
+
+ DLIB_TEST(img[0][0] == 0);
+ DLIB_TEST(img[1][0] == 0);
+ DLIB_TEST(img[2][0] == 0);
+ DLIB_TEST(img[3][0] == 0);
+ DLIB_TEST(img[0][1] == 0);
+ DLIB_TEST(img[1][1] == 0);
+ DLIB_TEST(img[2][1] == 0);
+ DLIB_TEST(img[3][1] == 0);
+
+ DLIB_TEST(img[0][3] == 0);
+ DLIB_TEST(img[1][3] == 0);
+ DLIB_TEST(img[2][3] == 0);
+ DLIB_TEST(img[3][3] == 0);
+ DLIB_TEST(img[0][4] == 0);
+ DLIB_TEST(img[1][4] == 0);
+ DLIB_TEST(img[2][4] == 0);
+ DLIB_TEST(img[3][4] == 0);
+
+ DLIB_TEST(img[0][2] == 0);
+ DLIB_TEST(img[3][2] == 0);
+
+ DLIB_TEST(img[1][2] == 1);
+ DLIB_TEST(img[2][2] == 1);
+
+ rectangle rect = get_rect(img);
+ rect.left()+=2;
+ rect.top()+=1;
+ rect.right()-=2;
+ rect.bottom()-=1;
+ assign_all_pixels(img, 1);
+ zero_border_pixels(img, rect);
+
+ DLIB_TEST(img[0][0] == 0);
+ DLIB_TEST(img[1][0] == 0);
+ DLIB_TEST(img[2][0] == 0);
+ DLIB_TEST(img[3][0] == 0);
+ DLIB_TEST(img[0][1] == 0);
+ DLIB_TEST(img[1][1] == 0);
+ DLIB_TEST(img[2][1] == 0);
+ DLIB_TEST(img[3][1] == 0);
+
+ DLIB_TEST(img[0][3] == 0);
+ DLIB_TEST(img[1][3] == 0);
+ DLIB_TEST(img[2][3] == 0);
+ DLIB_TEST(img[3][3] == 0);
+ DLIB_TEST(img[0][4] == 0);
+ DLIB_TEST(img[1][4] == 0);
+ DLIB_TEST(img[2][4] == 0);
+ DLIB_TEST(img[3][4] == 0);
+
+ DLIB_TEST(img[0][2] == 0);
+ DLIB_TEST(img[3][2] == 0);
+
+ DLIB_TEST(img[1][2] == 1);
+ DLIB_TEST(img[2][2] == 1);
+
+ rect.right()+=1;
+ assign_all_pixels(img, 1);
+ zero_border_pixels(img, rect);
+ DLIB_TEST(img[0][0] == 0);
+ DLIB_TEST(img[1][0] == 0);
+ DLIB_TEST(img[2][0] == 0);
+ DLIB_TEST(img[3][0] == 0);
+ DLIB_TEST(img[0][1] == 0);
+ DLIB_TEST(img[1][1] == 0);
+ DLIB_TEST(img[2][1] == 0);
+ DLIB_TEST(img[3][1] == 0);
+
+ DLIB_TEST(img[0][3] == 0);
+ DLIB_TEST(img[1][3] == 1);
+ DLIB_TEST(img[2][3] == 1);
+ DLIB_TEST(img[3][3] == 0);
+ DLIB_TEST(img[0][4] == 0);
+ DLIB_TEST(img[1][4] == 0);
+ DLIB_TEST(img[2][4] == 0);
+ DLIB_TEST(img[3][4] == 0);
+
+ DLIB_TEST(img[0][2] == 0);
+ DLIB_TEST(img[3][2] == 0);
+
+ DLIB_TEST(img[1][2] == 1);
+ DLIB_TEST(img[2][2] == 1);
+ }
+
+
+ void test_label_connected_blobs()
+ {
+ array2d<unsigned char> img;
+ img.set_size(400,401);
+
+ assign_all_pixels(img,0);
+
+ rectangle rect1, rect2, rect3;
+
+ rect1 = centered_rect(99,120, 50,70);
+ rect2 = centered_rect(199,80, 34,68);
+ rect3 = centered_rect(249,180, 120,78);
+
+ fill_rect(img, rect1, 255);
+ fill_rect(img, rect2, 255);
+ fill_rect(img, rect3, 255);
+
+ array2d<unsigned char> labels;
+ unsigned long num;
+ num = label_connected_blobs(img,
+ zero_pixels_are_background(),
+ neighbors_8(),
+ connected_if_both_not_zero(),
+ labels);
+
+ DLIB_TEST(num == 4);
+ DLIB_TEST(labels.nr() == img.nr());
+ DLIB_TEST(labels.nc() == img.nc());
+
+ const unsigned char l1 = labels[rect1.top()][rect1.left()];
+ const unsigned char l2 = labels[rect2.top()][rect2.left()];
+ const unsigned char l3 = labels[rect3.top()][rect3.left()];
+
+ DLIB_TEST(l1 != 0 && l2 != 0 && l3 != 0);
+ DLIB_TEST(l1 != l2 && l1 != l3 && l2 != l3);
+
+ for (long r = 0; r < labels.nr(); ++r)
+ {
+ for (long c = 0; c < labels.nc(); ++c)
+ {
+ if (rect1.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l1);
+ }
+ else if (rect2.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l2);
+ }
+ else if (rect3.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l3);
+ }
+ else
+ {
+ DLIB_TEST(labels[r][c] == 0);
+ }
+ }
+ }
+ }
+
+ void test_label_connected_blobs2()
+ {
+ array2d<unsigned char> img;
+ img.set_size(400,401);
+
+ assign_all_pixels(img,0);
+
+ rectangle rect1, rect2, rect3;
+
+ rect1 = centered_rect(99,120, 50,70);
+ rect2 = centered_rect(199,80, 34,68);
+ rect3 = centered_rect(249,180, 120,78);
+
+ fill_rect(img, rect1, 255);
+ fill_rect(img, rect2, 253);
+ fill_rect(img, rect3, 255);
+
+ array2d<unsigned char> labels;
+ unsigned long num;
+ num = label_connected_blobs(img,
+ nothing_is_background(),
+ neighbors_4(),
+ connected_if_equal(),
+ labels);
+
+ DLIB_TEST(num == 5);
+ DLIB_TEST(labels.nr() == img.nr());
+ DLIB_TEST(labels.nc() == img.nc());
+
+ const unsigned char l0 = labels[0][0];
+ const unsigned char l1 = labels[rect1.top()][rect1.left()];
+ const unsigned char l2 = labels[rect2.top()][rect2.left()];
+ const unsigned char l3 = labels[rect3.top()][rect3.left()];
+
+ DLIB_TEST(l0 != 0 && l1 != 0 && l2 != 0 && l3 != 0);
+ DLIB_TEST(l1 != l2 && l1 != l3 && l2 != l3 &&
+ l0 != l1 && l0 != l2 && l0 != l3);
+
+ for (long r = 0; r < labels.nr(); ++r)
+ {
+ for (long c = 0; c < labels.nc(); ++c)
+ {
+ if (rect1.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l1);
+ }
+ else if (rect2.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l2);
+ }
+ else if (rect3.contains(c,r))
+ {
+ DLIB_TEST(labels[r][c] == l3);
+ }
+ else
+ {
+ DLIB_TEST(labels[r][c] == l0);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void downsample_image (
+ const unsigned long downsample,
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ bool add_to
+ )
+ {
+ out_img.set_size((in_img.nr()+downsample-1)/downsample,
+ (in_img.nc()+downsample-1)/downsample);
+
+ for (long r = 0; r < out_img.nr(); ++r)
+ {
+ for (long c = 0; c < out_img.nc(); ++c)
+ {
+ if (add_to)
+ out_img[r][c] += in_img[r*downsample][c*downsample];
+ else
+ out_img[r][c] = in_img[r*downsample][c*downsample];
+ }
+ }
+ }
+
+ template <
+ typename in_image_type,
+ typename out_image_type,
+ typename EXP1,
+ typename EXP2,
+ typename T
+ >
+ void test_spatially_filter_image_separable_down_simple (
+ const unsigned long downsample,
+ const in_image_type& in_img,
+ out_image_type& out_img,
+ const matrix_exp<EXP1>& row_filter,
+ const matrix_exp<EXP2>& col_filter,
+ T scale,
+ bool use_abs = false,
+ bool add_to = false
+ )
+ {
+ out_image_type temp;
+ spatially_filter_image_separable(in_img, temp, row_filter, col_filter, scale, use_abs, false);
+ downsample_image(downsample, temp, out_img, add_to);
+ }
+
+
+
+
+ template <unsigned long downsample>
+ void test_downsampled_filtering_helper(long row_filt_size, long col_filt_size)
+ {
+ print_spinner();
+ dlog << LTRACE << "***********************************";
+ dlog << LTRACE << "downsample: " << downsample;
+ dlog << LTRACE << "row_filt_size: "<< row_filt_size;
+ dlog << LTRACE << "col_filt_size: "<< col_filt_size;
+ dlib::rand rnd;
+ array2d<int> out1, out2;
+ for (long nr = 0; nr < 3; ++nr)
+ {
+ for (long nc = 0; nc < 3; ++nc)
+ {
+ dlog << LTRACE << "nr: "<< nr;
+ dlog << LTRACE << "nc: "<< nc;
+ array2d<unsigned char> img(25+nr,25+nc);
+ for (int k = 0; k < 5; ++k)
+ {
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_8bit_number();
+ }
+ }
+
+ matrix<int,0,1> row_filter(row_filt_size);
+ matrix<int,0,1> col_filter(col_filt_size);
+
+ row_filter = matrix_cast<int>(10*randm(row_filt_size,1, rnd));
+ col_filter = matrix_cast<int>(10*randm(col_filt_size,1, rnd));
+
+ row_filter -= 3;
+ col_filter -= 3;
+
+
+ test_spatially_filter_image_separable_down_simple(downsample, img, out1, row_filter, col_filter,1 );
+ spatially_filter_image_separable_down(downsample, img, out2, row_filter, col_filter);
+
+ DLIB_TEST(get_rect(out1) == get_rect(out2));
+ DLIB_TEST(mat(out1) == mat(out2));
+
+ test_spatially_filter_image_separable_down_simple(downsample, img, out1, row_filter, col_filter,3, true, true );
+ spatially_filter_image_separable_down(downsample, img, out2, row_filter, col_filter, 3, true, true);
+
+ DLIB_TEST(get_rect(out1) == get_rect(out2));
+ DLIB_TEST(mat(out1) == mat(out2));
+
+ }
+ }
+ }
+ }
+
+ void test_downsampled_filtering()
+ {
+ test_downsampled_filtering_helper<1>(5,5);
+ test_downsampled_filtering_helper<2>(5,5);
+ test_downsampled_filtering_helper<3>(5,5);
+ test_downsampled_filtering_helper<1>(3,5);
+ test_downsampled_filtering_helper<2>(3,5);
+ test_downsampled_filtering_helper<3>(3,5);
+ test_downsampled_filtering_helper<1>(5,3);
+ test_downsampled_filtering_helper<2>(5,3);
+ test_downsampled_filtering_helper<3>(5,3);
+
+ test_downsampled_filtering_helper<1>(3,3);
+ test_downsampled_filtering_helper<2>(3,3);
+ test_downsampled_filtering_helper<3>(3,3);
+
+ test_downsampled_filtering_helper<1>(1,1);
+ test_downsampled_filtering_helper<2>(1,1);
+ test_downsampled_filtering_helper<3>(1,1);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void test_segment_image()
+ {
+ print_spinner();
+ array2d<T> img(100,100);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ if (c < 50 || r < 50)
+ assign_pixel(img[r][c], 0);
+ else
+ assign_pixel(img[r][c], 255);
+ }
+ }
+
+ array2d<unsigned long> out;
+ segment_image(img, out);
+
+ DLIB_TEST(get_rect(img) == get_rect(out));
+ const unsigned long v1 = out[0][0];
+ const unsigned long v2 = out[90][90];
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ if (c < 50 || r < 50)
+ {
+ DLIB_TEST(out[r][c] == v1);
+ }
+ else
+ {
+ DLIB_TEST(out[r][c] == v2);
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void test_dng_floats(double scale)
+ {
+ dlog << LINFO << "in test_dng_floats";
+ print_spinner();
+ array2d<T> img(100,101);
+
+ dlib::rand rnd;
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ T val = rnd.get_random_double()*scale;
+ img[r][c] = val;
+
+ // Lets the float_details object while we are here doing this stuff.
+ float_details temp = val;
+ T val2 = temp;
+ // for the same type we should exactly reproduce the value (unless
+ // it's long double and then maybe it's slightly different).
+ if (is_same_type<T,long double>::value)
+ {
+ DLIB_TEST(std::abs(val2-val) < scale*std::numeric_limits<T>::epsilon());
+ }
+ else
+ {
+ DLIB_TEST(val2 == val);
+ }
+
+ float valf = temp;
+ double vald = temp;
+ long double vall = temp;
+
+ DLIB_TEST(std::abs(valf-val) < scale*std::numeric_limits<float>::epsilon());
+ DLIB_TEST(std::abs(vald-val) < scale*std::numeric_limits<double>::epsilon());
+ DLIB_TEST(std::abs(vall-val) < scale*std::numeric_limits<long double>::epsilon());
+ }
+ }
+
+ ostringstream sout;
+ save_dng(img, sout);
+ istringstream sin;
+
+ array2d<float> img1;
+ array2d<double> img2;
+ array2d<long double> img3;
+
+ sin.clear(); sin.str(sout.str());
+ load_dng(img1, sin);
+
+ sin.clear(); sin.str(sout.str());
+ load_dng(img2, sin);
+
+ sin.clear(); sin.str(sout.str());
+ load_dng(img3, sin);
+
+ DLIB_TEST(img.nr() == img1.nr());
+ DLIB_TEST(img.nr() == img2.nr());
+ DLIB_TEST(img.nr() == img3.nr());
+ DLIB_TEST(img.nc() == img1.nc());
+ DLIB_TEST(img.nc() == img2.nc());
+ DLIB_TEST(img.nc() == img3.nc());
+
+ DLIB_TEST(max(abs(mat(img) - matrix_cast<T>(mat(img1)))) < scale*std::numeric_limits<float>::epsilon());
+ DLIB_TEST(max(abs(mat(img) - matrix_cast<T>(mat(img2)))) < scale*std::numeric_limits<double>::epsilon());
+ DLIB_TEST(max(abs(mat(img) - matrix_cast<T>(mat(img3)))) < scale*std::numeric_limits<long double>::epsilon());
+ }
+
+ void test_dng_float_int()
+ {
+ dlog << LINFO << "in test_dng_float_int";
+ print_spinner();
+
+ array2d<uint16> img;
+ assign_image(img, gaussian_randm(101,100)*10000);
+
+ ostringstream sout;
+ save_dng(img, sout);
+ istringstream sin(sout.str());
+ array2d<double> img2;
+ load_dng(img2, sin);
+ sout.clear(); sout.str("");
+
+ save_dng(img2, sout);
+ sin.clear(); sin.str(sout.str());
+ array2d<uint16> img3;
+ load_dng(img3, sin);
+
+ // this whole thing should have been totally lossless.
+ DLIB_TEST(mat(img) == mat(img3));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void test_filtering_center (
+ dlib::rand& rnd
+ )
+ {
+ array2d<T> img(rnd.get_random_32bit_number()%100+1,
+ rnd.get_random_32bit_number()%100+1);
+ matrix<T> filt(rnd.get_random_32bit_number()%10+1,
+ rnd.get_random_32bit_number()%10+1);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_32bit_number()%100;
+ }
+ }
+ for (long r = 0; r < filt.nr(); ++r)
+ {
+ for (long c = 0; c < filt.nc(); ++c)
+ {
+ filt(r,c) = rnd.get_random_32bit_number()%100;
+ }
+ }
+
+ array2d<T> out;
+ const rectangle area = spatially_filter_image(img, out, filt);
+
+ for (long r = 0; r < out.nr(); ++r)
+ {
+ for (long c = 0; c < out.nc(); ++c)
+ {
+ const rectangle rect = centered_rect(point(c,r), filt.nc(), filt.nr());
+ if (get_rect(out).contains(rect))
+ {
+ T val = sum(pointwise_multiply(filt, subm(mat(img),rect)));
+ DLIB_TEST_MSG(val == out[r][c],"err: " << val-out[r][c]);
+ DLIB_TEST(area.contains(point(c,r)));
+ }
+ else
+ {
+ DLIB_TEST(!area.contains(point(c,r)));
+ }
+ }
+ }
+ }
+
+ template <typename T>
+ void test_separable_filtering_center (
+ dlib::rand& rnd
+ )
+ {
+ array2d<T> img(rnd.get_random_32bit_number()%100+1,
+ rnd.get_random_32bit_number()%100+1);
+ matrix<T,1,0> row_filt(rnd.get_random_32bit_number()%10+1);
+ matrix<T,0,1> col_filt(rnd.get_random_32bit_number()%10+1);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_32bit_number()%10;
+ }
+ }
+ for (long r = 0; r < row_filt.size(); ++r)
+ {
+ row_filt(r) = rnd.get_random_32bit_number()%10;
+ }
+ for (long r = 0; r < col_filt.size(); ++r)
+ {
+ col_filt(r) = rnd.get_random_32bit_number()%10;
+ }
+
+ array2d<T> out;
+ const rectangle area = spatially_filter_image_separable(img, out, row_filt, col_filt);
+
+ for (long r = 0; r < out.nr(); ++r)
+ {
+ for (long c = 0; c < out.nc(); ++c)
+ {
+ const rectangle rect = centered_rect(point(c,r), row_filt.size(), col_filt.size());
+ if (get_rect(out).contains(rect))
+ {
+ T val = sum(pointwise_multiply(col_filt*row_filt, subm(mat(img),rect)));
+ DLIB_TEST_MSG(val == out[r][c],"err: " << val-out[r][c]);
+
+ DLIB_TEST(area.contains(point(c,r)));
+ }
+ else
+ {
+ DLIB_TEST(!area.contains(point(c,r)));
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_hough_test()
+ {
+ array2d<unsigned char> img(300,300);
+
+
+ for (int k = -2; k <= 2; ++k)
+ {
+ print_spinner();
+ running_stats<double> rs;
+ array2d<int> himg;
+ hough_transform ht(200+k);
+ double angle1 = 0;
+ double angle2 = 0;
+ const int len = 90;
+ // Draw a bunch of random lines, hough transform them, then make sure the hough
+ // transform detects them accurately.
+ for (int i = 0; i < 500; ++i)
+ {
+ point cent = center(get_rect(img));
+ point arc = cent + point(len,0);
+ arc = rotate_point(cent, arc, angle1);
+
+ point l = arc + point(500,0);
+ point r = arc - point(500,0);
+ l = rotate_point(arc, l, angle2);
+ r = rotate_point(arc, r, angle2);
+
+ angle1 += pi/13;
+ angle2 += pi/40;
+
+ assign_all_pixels(img, 0);
+ draw_line(img, l, r, 255);
+ rectangle box = translate_rect(get_rect(ht),point(50,50));
+ ht(img, box, himg);
+
+ point p = max_point(mat(himg));
+ DLIB_TEST(himg[p.y()][p.x()] > 255*3);
+
+ l -= point(50,50);
+ r -= point(50,50);
+ std::pair<point,point> line = ht.get_line(p);
+ // make sure the best scoring hough point matches the line we drew.
+ double dist1 = distance_to_line(make_pair(l,r), line.first);
+ double dist2 = distance_to_line(make_pair(l,r), line.second);
+ //cout << "DIST1: " << dist1 << endl;
+ //cout << "DIST2: " << dist2 << endl;
+ rs.add(dist1);
+ rs.add(dist2);
+ DLIB_TEST(dist1 < 2.5);
+ DLIB_TEST(dist2 < 2.5);
+ }
+ //cout << "rs.mean(): " << rs.mean() << endl;
+ DLIB_TEST(rs.mean() < 0.7);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_extract_image_chips()
+ {
+ dlib::rand rnd;
+
+ // Make sure that cropping a white box out of a larger white image always produces an
+ // exact white box. This should catch any bad border effects from a messed up internal
+ // cropping.
+ for (int iter = 0; iter < 1000; ++iter)
+ {
+ print_spinner();
+ const long nr = rnd.get_random_32bit_number()%100 + 1;
+ const long nc = rnd.get_random_32bit_number()%100 + 1;
+ const long size = rnd.get_random_32bit_number()%10000 + 4;
+ const double angle = rnd.get_random_double() * pi;
+
+ matrix<int> img(501,501), chip;
+ img = 255;
+ chip_details details(centered_rect(center(get_rect(img)),nr,nc), size, angle);
+ extract_image_chip(img, details, chip);
+ DLIB_TEST_MSG(max(abs(chip-255))==0,"nr: " << nr << " nc: "<< nc << " size: " << size << " angle: " << angle
+ << " error: " << max(abs(chip-255)) );
+ }
+
+
+ {
+ // Make sure that the interpolation in extract_image_chip() keeps stuff in the
+ // right places.
+
+ matrix<unsigned char> img(10,10), chip;
+ img = 0;
+ img(1,1) = 255;
+ img(8,8) = 255;
+
+ extract_image_chip(img, chip_details(get_rect(img), 9*9), chip);
+
+ DLIB_TEST(chip(1,1) == 195);
+ DLIB_TEST(chip(7,7) == 195);
+ chip(1,1) -= 195;
+ chip(7,7) -= 195;
+ DLIB_TEST(sum(matrix_cast<int>(chip)) == 0);
+ }
+
+
+
+ // Test the rotation ability of extract_image_chip(). Do this by drawing a line and
+ // then rotating it so it's horizontal. Check that it worked correctly by hough
+ // transforming it.
+ hough_transform ht(151);
+ matrix<unsigned char> img(300,300);
+ for (int iter = 0; iter < 1000; ++iter)
+ {
+ print_spinner();
+ img = 0;
+ const int len = 9000;
+ point cent = center(get_rect(img));
+ point l = cent + point(len,0);
+ point r = cent - point(len,0);
+ const double angle = rnd.get_random_double()*pi*3;
+ l = rotate_point(cent, l, angle);
+ r = rotate_point(cent, r, angle);
+ draw_line(img, l, r, 255);
+
+
+ const long wsize = rnd.get_random_32bit_number()%350 + 150;
+
+ matrix<unsigned char> temp;
+ chip_details details(centered_rect(center(get_rect(img)), wsize,wsize), chip_dims(ht.size(),ht.size()), angle);
+ extract_image_chip(img, details, temp);
+
+
+ matrix<long> tform;
+ ht(temp, get_rect(temp), tform);
+ std::pair<point,point> line = ht.get_line(max_point(tform));
+
+ DLIB_TEST_MSG(line.first.y() == line.second.y()," wsize: " << wsize);
+ DLIB_TEST(length(line.first-line.second) > 100);
+ DLIB_TEST(length((line.first+line.second)/2.0 - center(get_rect(temp))) <= 1);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class image_tester : public tester
+ {
+ public:
+ image_tester (
+ ) :
+ tester ("test_image",
+ "Runs tests on the image processing objects and functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ image_test();
+ run_hough_test();
+ test_extract_image_chips();
+ test_integral_image<long, unsigned char>();
+ test_integral_image<double, int>();
+ test_integral_image<long, unsigned char>();
+ test_integral_image<double, float>();
+
+ test_zero_border_pixels();
+
+ test_filtering<unsigned char>(false,1);
+ test_filtering<unsigned char>(true,1);
+ test_filtering<unsigned char>(false,3);
+ test_filtering<unsigned char>(true,3);
+ test_filtering<int>(false,1);
+ test_filtering<int>(true,1);
+ test_filtering<int>(false,3);
+ test_filtering<int>(true,3);
+
+ test_label_connected_blobs();
+ test_label_connected_blobs2();
+ test_downsampled_filtering();
+
+ test_segment_image<unsigned char>();
+ test_segment_image<unsigned short>();
+ test_segment_image<double>();
+ test_segment_image<int>();
+ test_segment_image<rgb_pixel>();
+ test_segment_image<rgb_alpha_pixel>();
+
+ test_dng_floats<float>(1);
+ test_dng_floats<double>(1);
+ test_dng_floats<long double>(1);
+ test_dng_floats<float>(1e30);
+ test_dng_floats<double>(1e30);
+ test_dng_floats<long double>(1e30);
+
+ test_dng_float_int();
+
+ dlib::rand rnd;
+ for (int i = 0; i < 10; ++i)
+ {
+ // the spatial filtering stuff is the same as xcorr_same when the filter
+ // sizes are odd.
+ test_filtering2(3,3,rnd);
+ test_filtering2(5,5,rnd);
+ test_filtering2(7,7,rnd);
+ }
+
+ for (int i = 0; i < 100; ++i)
+ test_filtering_center<float>(rnd);
+ for (int i = 0; i < 100; ++i)
+ test_filtering_center<int>(rnd);
+ for (int i = 0; i < 100; ++i)
+ test_separable_filtering_center<int>(rnd);
+ for (int i = 0; i < 100; ++i)
+ test_separable_filtering_center<float>(rnd);
+
+ {
+ print_spinner();
+ matrix<unsigned char> img(40,80);
+ assign_all_pixels(img, 255);
+ skeleton(img);
+
+ DLIB_TEST(sum(matrix_cast<int>(mat(img)))/255 == 40);
+ draw_line(img, point(20,19), point(59,19), 00);
+ DLIB_TEST(sum(matrix_cast<int>(mat(img))) == 0);
+ }
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/iosockstream.cpp b/ml/dlib/dlib/test/iosockstream.cpp
new file mode 100644
index 000000000..c68de7bd3
--- /dev/null
+++ b/ml/dlib/dlib/test/iosockstream.cpp
@@ -0,0 +1,181 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/iosockstream.h>
+#include <dlib/server.h>
+#include <vector>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.iosockstream");
+
+// ----------------------------------------------------------------------------------------
+
+ class serv : public server_iostream
+ {
+ virtual void on_connect (
+ std::istream& in,
+ std::ostream& out,
+ const std::string& ,
+ const std::string& ,
+ unsigned short ,
+ unsigned short ,
+ uint64
+ )
+ {
+ try
+ {
+ dlog << LINFO << "serv1: serving connection";
+
+ std::string temp;
+ in >> temp;
+ DLIB_TEST(temp == "word");
+ in >> temp;
+ DLIB_TEST(temp == "another");
+ out << "yay words ";
+ in >> temp;
+ DLIB_TEST(temp == "yep");
+ }
+ catch (error& e)
+ {
+ error_string = e.what();
+ }
+
+ }
+
+
+ public:
+ std::string error_string;
+
+ };
+
+ class serv2 : public server_iostream
+ {
+ virtual void on_connect (
+ std::istream& ,
+ std::ostream& out,
+ const std::string& ,
+ const std::string& ,
+ unsigned short ,
+ unsigned short ,
+ uint64
+ )
+ {
+ try
+ {
+ dlog << LINFO << "serv2: serving connection";
+
+ out << "one two three four five";
+ }
+ catch (error& e)
+ {
+ error_string = e.what();
+ }
+
+ }
+
+
+ public:
+ std::string error_string;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void test1()
+ {
+ dlog << LINFO << "in test1()";
+ serv theserv;
+ theserv.set_listening_port(12345);
+ theserv.start_async();
+
+ // wait a little bit to make sure the server has started listening before we try
+ // to connect to it.
+ dlib::sleep(500);
+
+ for (int i = 0; i < 200; ++i)
+ {
+ dlog << LINFO << "i: " << i;
+ print_spinner();
+ iosockstream stream("localhost:12345");
+
+ stream << "word another ";
+ std::string temp;
+ stream >> temp;
+ DLIB_TEST(temp == "yay");
+ stream >> temp;
+ DLIB_TEST(temp == "words");
+ stream << "yep ";
+ }
+
+ // Just to make sure the server finishes processing the last connection before
+ // we kill it and accidentally trigger a DLIB_TEST().
+ dlib::sleep(500);
+
+ if (theserv.error_string.size() != 0)
+ throw error(theserv.error_string);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test2()
+ {
+ dlog << LINFO << "in test2()";
+ serv2 theserv;
+ theserv.set_listening_port(12345);
+ theserv.start_async();
+
+ // wait a little bit to make sure the server has started listening before we try
+ // to connect to it.
+ dlib::sleep(500);
+
+ for (int i = 0; i < 200; ++i)
+ {
+ dlog << LINFO << "i: " << i;
+ print_spinner();
+ iosockstream stream("localhost:12345");
+
+ std::string temp;
+ stream >> temp; DLIB_TEST(temp == "one");
+ stream >> temp; DLIB_TEST(temp == "two");
+ stream >> temp; DLIB_TEST(temp == "three");
+ stream >> temp; DLIB_TEST(temp == "four");
+ stream >> temp; DLIB_TEST(temp == "five");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_iosockstream : public tester
+ {
+ public:
+ test_iosockstream (
+ ) :
+ tester ("test_iosockstream",
+ "Runs tests on the iosockstream component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test1();
+ test2();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/is_same_object.cpp b/ml/dlib/dlib/test/is_same_object.cpp
new file mode 100644
index 000000000..ed71c5bef
--- /dev/null
+++ b/ml/dlib/dlib/test/is_same_object.cpp
@@ -0,0 +1,141 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.is_same_object");
+
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_booya_template, void, template booya<int>, (std::string)const);
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_booya2_template, void, template booya2<int>, (int)const);
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_int, void, funct, (int));
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_double, void, funct, (double));
+ DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_f, float, funct_f, (int));
+
+ class htest
+ {
+ public:
+ template <typename EXP>
+ void booya(std::string) const {}
+
+ template <typename EXP>
+ void booya2(EXP) const {}
+
+ void funct(double) {}
+ };
+
+ class htest2
+ {
+ public:
+
+ void funct(int) {}
+
+ float funct_f(int) { return 0;}
+ };
+
+ void test_metaprog()
+ {
+ DLIB_TEST(has_booya2_template<htest>::value == true);
+ DLIB_TEST(has_booya2_template<htest2>::value == false);
+
+#if _MSC_VER > 1600 // there is a bug in visual studio 2010 and older that prevents this test from working
+ DLIB_TEST(has_booya_template<htest>::value == true);
+#endif
+
+ DLIB_TEST(has_booya_template<htest2>::value == false);
+
+ DLIB_TEST(has_funct_int<htest>::value == false);
+ DLIB_TEST(has_funct_int<htest2>::value == true);
+ DLIB_TEST(has_funct_double<htest>::value == true);
+ DLIB_TEST(has_funct_double<htest2>::value == false);
+
+ DLIB_TEST(has_funct_f<htest>::value == false);
+ DLIB_TEST(has_funct_f<htest2>::value == true);
+ }
+
+ class is_same_object_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ is_same_object_tester (
+ ) :
+ tester (
+ "test_is_same_object", // the command line argument name for this test
+ "Run tests on the is_same_object function.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ struct base {};
+ struct derived : public base {};
+
+ template <bool truth>
+ void go(const base& a, const base& b)
+ {
+ DLIB_TEST( is_same_object(a,b) == truth) ;
+ DLIB_TEST( is_same_object(b,a) == truth) ;
+ }
+
+
+ template <bool truth>
+ void go2(const base& a, const derived& b)
+ {
+ DLIB_TEST( is_same_object(a,b) == truth) ;
+ DLIB_TEST( is_same_object(b,a) == truth) ;
+ }
+
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+
+ int a, b;
+ double d;
+ DLIB_TEST( is_same_object(a,a) == true) ;
+ DLIB_TEST( is_same_object(a,b) == false) ;
+ DLIB_TEST( is_same_object(d,b) == false) ;
+ DLIB_TEST( is_same_object(d,d) == true) ;
+
+ base sb;
+ derived sd, sd2;
+
+ DLIB_TEST( is_same_object(sb,sd) == false) ;
+ DLIB_TEST( is_same_object(sd,sb) == false) ;
+
+ go<true>(sd, sd);
+ go<false>(sd, sd2);
+ go<true>(sb, sb);
+ go<false>(sd, sb);
+
+ go2<true>(sd, sd);
+ go2<false>(sd2, sd);
+ go2<false>(sd, sd2);
+ go2<false>(sb, sd);
+
+ test_metaprog();
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ is_same_object_tester a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/isotonic_regression.cpp b/ml/dlib/dlib/test/isotonic_regression.cpp
new file mode 100644
index 000000000..2ab46903c
--- /dev/null
+++ b/ml/dlib/dlib/test/isotonic_regression.cpp
@@ -0,0 +1,103 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include <dlib/global_optimization.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.isotonic_regression");
+
+// ----------------------------------------------------------------------------------------
+
+ class optimization_tester : public tester
+ {
+ public:
+ optimization_tester (
+ ) :
+ tester ("test_isotonic_regression",
+ "Runs tests on the isotonic_regression object.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlib::rand rnd;
+
+ for (int round = 0; round < 100; ++round)
+ {
+ print_spinner();
+ std::vector<double> vect;
+ for (int i = 0; i < 5; ++i)
+ vect.push_back(put_in_range(-1,1,rnd.get_random_gaussian()));
+
+
+ auto f = [&](const matrix<double,0,1>& x)
+ {
+ double dist = 0;
+ double sum = 0;
+ for (long i = 0; i < x.size(); ++i)
+ {
+ sum += x(i);
+ dist += (sum-vect[i])*(sum-vect[i]);
+ }
+ return dist;
+ };
+
+ auto objval = [vect](const matrix<double,0,1>& x)
+ {
+ return sum(squared(mat(vect)-x));
+ };
+
+ auto is_monotonic = [](const matrix<double,0,1>& x)
+ {
+ for (long i = 1; i < x.size(); ++i)
+ {
+ if (x(i-1) > x(i))
+ return false;
+ }
+ return true;
+ };
+
+ matrix<double,0,1> lower(5), upper(5);
+ lower = 0;
+ lower(0) = -4;
+ upper = 4;
+ // find the solution with find_min_global() and then check that it matches
+ auto result = find_min_global(f, lower, upper, max_function_calls(40));
+
+ for (long i = 1; i < result.x.size(); ++i)
+ result.x(i) += result.x(i-1);
+
+ isotonic_regression mr;
+ mr(vect);
+
+ dlog << LINFO << "err: "<< objval(mat(vect)) - objval(result.x);
+
+ DLIB_CASSERT(is_monotonic(mat(vect)));
+ DLIB_CASSERT(is_monotonic(result.x));
+ // isotonic_regression should be at least as good as find_min_global().
+ DLIB_CASSERT(objval(mat(vect)) - objval(result.x) < 1e-13);
+ }
+
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/kcentroid.cpp b/ml/dlib/dlib/test/kcentroid.cpp
new file mode 100644
index 000000000..c16ab6eca
--- /dev/null
+++ b/ml/dlib/dlib/test/kcentroid.cpp
@@ -0,0 +1,684 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include <map>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include "checkerboard.h"
+#include <dlib/statistics.h>
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.kcentroid");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct unopt_sparse_linear_kernel
+ {
+ typedef typename T::value_type::second_type scalar_type;
+ typedef T sample_type;
+ typedef default_memory_manager mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return dot(a,b);
+ }
+
+ bool operator== (
+ const unopt_sparse_linear_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ template <typename T>
+ struct unopt_linear_kernel
+ {
+ typedef typename T::type scalar_type;
+ typedef T sample_type;
+ typedef typename T::mem_manager_type mem_manager_type;
+
+ scalar_type operator() (
+ const sample_type& a,
+ const sample_type& b
+ ) const
+ {
+ return trans(a)*b;
+ }
+
+ bool operator== (
+ const unopt_linear_kernel&
+ ) const
+ {
+ return true;
+ }
+ };
+
+ bool approx_equal(double a, double b)
+ {
+ return (std::abs(a-b) < 1000*(std::numeric_limits<double>::epsilon()));
+ }
+
+ bool approx_equal(double a, double b, double eps)
+ {
+ return (std::abs(a-b) < eps);
+ }
+
+ template <typename K>
+ double dist (
+ const K& k,
+ const matrix<double,4,1>& a,
+ const matrix<double,5,1>& b
+ )
+ /*!
+ ensures
+ - returns the distance between the a and b vectors in the
+ feature space defined by the given kernel k.
+ !*/
+ {
+ const double bias = std::sqrt(k.offset);
+ return std::sqrt(length_squared(a-colm(b,0,4)) + std::pow(b(4)-bias,2.0));
+
+ }
+
+ template <typename K>
+ double dist (
+ const K& k,
+ std::map<unsigned long,double> a,
+ std::map<unsigned long,double> b
+ )
+ /*!
+ ensures
+ - returns the distance between the a and b vectors in the
+ feature space defined by the given kernel k.
+ !*/
+ {
+ double temp = 0;
+ const double bias = std::sqrt(k.offset);
+ temp += std::pow(a[0]-b[0],2.0);
+ temp += std::pow(a[1]-b[1],2.0);
+ temp += std::pow(a[2]-b[2],2.0);
+ temp += std::pow(a[3]-b[3],2.0);
+ temp += std::pow(bias-b[4],2.0);
+
+ return std::sqrt(temp);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void test_kcentroid_with_linear_kernel(
+ )
+ /*!
+ requires
+ - kernel_type::sample_type == a matrix<double,5,1>
+ - kernel_type == a kernel that just computes a dot product
+ between its inputs. I.e. a linear kernel
+ ensures
+ - tests the kcentroid object with the given kernel
+ !*/
+ {
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef typename kernel_type::sample_type sample_type;
+
+ kernel_type default_kernel;
+ kcentroid<kernel_type> test(default_kernel,0.001,20);
+
+ sample_type temp, temp2;
+
+ temp = 2,0,0,0,0;
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+
+ DLIB_TEST(approx_equal(test(temp), 2));
+ DLIB_TEST(approx_equal(test.squared_norm(), 0));
+
+ // make test store the point(2,0,0,0,0)
+ test.train(temp, 0, 1);
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+ DLIB_TEST(approx_equal(test(temp), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(temp), 0));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ temp = 0,2,0,0,0;
+ dlog << LDEBUG << test(temp) ;
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ // make test store the point(0,2,0,0,0)
+ test.train(temp, 0, 1);
+
+ dlog << LDEBUG << test(temp) ;
+ DLIB_TEST(approx_equal(test(temp), 0));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ temp = 2,0,0,0,0;
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ // make test store the point(1,1,0,0,0)
+ test.train(temp, 0.5, 0.5);
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 2));
+
+ // make test store the point(1,1,0,3,0)
+ temp = 0,0,0,3,0;
+ temp2 = 1,1,0,3,0;
+ test.train(temp, 1, 1);
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp = 1,2,3,4,5;
+ DLIB_TEST(approx_equal(test(temp), length(temp2-temp)));
+ DLIB_TEST(approx_equal(test.get_distance_function()(temp), length(temp2-temp)));
+
+ // make test store the point(0,1,0,3,-1)
+ temp = 1,0,0,0,1;
+ test.train(temp, 1, -1);
+ temp2 = 0,1,0,3,-1;
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp = 1,2,3,4,5;
+ DLIB_TEST(approx_equal(test(temp), length(temp2-temp)));
+
+
+ // make test store the -1*point(0,1,0,3,-1)
+ temp = 0,0,0,0,0;
+ test.train(temp, -1, 0);
+ temp2 = -temp2;
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp = 1,2,-3,4,5;
+ DLIB_TEST(approx_equal(test(temp), length(temp2-temp)));
+
+
+
+ // make test store the point(0,0,0,0,0)
+ temp = 0,0,0,0,0;
+ test.train(temp, 0, 0);
+ temp2 = 0;
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp = 1,2,-3,4,5;
+ DLIB_TEST(approx_equal(test(temp), length(temp2-temp)));
+
+
+
+ // make test store the point(1,0,0,0,0)
+ temp = 1,0,0,0,0;
+ test.train(temp, 1, 1);
+ temp2 = 1,0,0,0,0;
+
+ temp = 0;
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ DLIB_TEST(approx_equal(test.inner_product(test), length_squared(temp2)));
+ temp = 1,2,-3,4,5;
+ DLIB_TEST(approx_equal(test(temp), length(temp2-temp)));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0));
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void test_kcentroid_with_offset_linear_kernel(
+ )
+ /*!
+ requires
+ - kernel_type::sample_type == a matrix<double,4,1>
+ - kernel_type == a kernel that just computes a dot product
+ between its inputs + some constant. I.e. a linear kernel
+ wrapped by offset_kernel
+ ensures
+ - tests the kcentroid object with the given kernel
+ !*/
+ {
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef typename kernel_type::sample_type sample_type;
+
+ kernel_type k;
+ kcentroid<kernel_type> test(k,0.001,20);
+
+ sample_type temp, temp2, temp3;
+
+ matrix<double,5,1> val, val2;
+
+ const double b = std::sqrt(k.offset);
+
+ temp = 2,0,0,0;
+ temp2 = 0;
+ val = 0;
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+
+
+ temp2 = 0;
+
+ // make test store the point(0,0,0,0,b)
+ val = 0,0,0,0,b;
+ test.train(temp2, 0,1);
+
+ temp = 2,0,0,0;
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val), 1e-6),
+ test.get_distance_function()(temp2) - dist(k,temp2,val) << " compare to: " <<
+ test(temp2) - dist(k,temp2,val)
+ );
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+
+
+ // make test store the point(0,0,0,0,0)
+ val = 0,0,0,0,0;
+ test.train(temp2, 1,-1);
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val)),
+ test.get_distance_function()(temp2) - dist(k,temp2,val) << " compare to: " <<
+ test(temp2) - dist(k,temp2,val)
+ );
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+
+
+
+ val2 = 0,1,0,0,b;
+ val += val2;
+ temp2 = 0,1,0,0;
+ // make test store the point val
+ test.train(temp2, 1,1);
+
+ temp = 1,0,3,0;
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val), 1e-7),
+ test(temp2) - dist(k,temp2,val));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+ DLIB_TEST_MSG(approx_equal(test(test), 0, 1e-7), test(test));
+
+
+ val2 = 0,1,2.6,8,b;
+ val += val2;
+ temp2 = 0,1,2.6,8;
+ // make test store the point val
+ test.train(temp2, 1,1);
+
+ temp = 1,1,3,0;
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val)), test(temp2) - dist(k,temp2,val));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+ DLIB_TEST(approx_equal(test.inner_product(test), length_squared(val)));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST_MSG(approx_equal(test.get_distance_function()(test.get_distance_function()), 0, 1e-6),
+ test.get_distance_function()(test.get_distance_function()));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void test_kcentroid_with_sparse_linear_kernel(
+ )
+ /*!
+ requires
+ - kernel_type::sample_type == a std::map<unsigned long,double>
+ - kernel_type == a kernel that just computes a dot product
+ between its inputs. I.e. a linear kernel
+ ensures
+ - tests the kcentroid object with the given kernel
+ !*/
+ {
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef typename kernel_type::sample_type sample_type;
+
+ kernel_type default_kernel;
+ kcentroid<kernel_type> test(default_kernel,0.001,20);
+
+ dlog << LDEBUG << "AAAA 1" ;
+
+ sample_type temp, temp2;
+
+ temp[0] = 2;
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+
+ DLIB_TEST(approx_equal(test(temp), 2));
+ DLIB_TEST(approx_equal(test.squared_norm(), 0));
+
+ // make test store the point(2,0,0,0,0)
+ test.train(temp, 0, 1);
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+ DLIB_TEST(approx_equal(test(temp), 0));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ dlog << LDEBUG << "AAAA 2" ;
+ temp.clear();
+ temp[1] = 2;
+ dlog << LDEBUG << test(temp) ;
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ // make test store the point(0,2,0,0,0)
+ test.train(temp, 0, 1);
+
+ dlog << LDEBUG << test(temp) ;
+ DLIB_TEST(approx_equal(test(temp), 0));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ temp.clear();
+ temp[0] = 2;
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 4));
+
+ // make test store the point(1,1,0,0,0)
+ test.train(temp, 0.5, 0.5);
+
+ dlog << LDEBUG << "AAAA 3" ;
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), std::sqrt(2.0)));
+ DLIB_TEST(approx_equal(test.squared_norm(), 2));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0));
+
+ dlog << LDEBUG << "AAAA 3.1" ;
+ // make test store the point(1,1,0,3,0)
+ temp.clear(); temp[3] = 3;
+ temp2.clear();
+ temp2[0] = 1;
+ temp2[1] = 1;
+ temp2[3] = 3;
+ test.train(temp, 1, 1);
+
+ dlog << LDEBUG << "AAAA 3.2" ;
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ dlog << LDEBUG << "AAAA 3.3" ;
+ temp[0] = 1;
+ temp[1] = 2;
+ temp[2] = 3;
+ temp[3] = 4;
+ temp[4] = 5;
+ dlog << LDEBUG << "AAAA 3.4" ;
+ double junk = dlib::distance(temp2,temp);
+ dlog << LDEBUG << "AAAA 3.5" ;
+ DLIB_TEST(approx_equal(test(temp), junk) );
+
+ dlog << LDEBUG << "AAAA 4" ;
+ // make test store the point(0,1,0,3,-1)
+ temp.clear();
+ temp[0] = 1;
+ temp[4] = 1;
+ test.train(temp, 1, -1);
+ temp2.clear();
+ temp2[1] = 1;
+ temp2[3] = 3;
+ temp2[4] = -1;
+
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp[0] = 1;
+ temp[1] = 2;
+ temp[2] = 3;
+ temp[3] = 4;
+ temp[4] = 5;
+ DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp)));
+
+
+ // make test store the -1*point(0,1,0,3,-1)
+ temp.clear();
+ test.train(temp, -1, 0);
+ temp2[0] = -temp2[0];
+ temp2[1] = -temp2[1];
+ temp2[2] = -temp2[2];
+ temp2[3] = -temp2[3];
+ temp2[4] = -temp2[4];
+
+ dlog << LDEBUG << "AAAA 5" ;
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp[0] = 1;
+ temp[1] = 2;
+ temp[2] = -3;
+ temp[3] = 4;
+ temp[4] = 5;
+ DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp)));
+
+
+
+ // make test store the point(0,0,0,0,0)
+ temp.clear();
+ test.train(temp, 0, 0);
+ temp2.clear();
+
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ temp[0] = 1;
+ temp[1] = 2;
+ temp[2] = -3;
+ temp[3] = 4;
+ temp[4] = 5;
+ DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp)));
+ DLIB_TEST(approx_equal(test.get_distance_function()(temp), dlib::distance(temp2,temp)));
+
+
+ dlog << LDEBUG << "AAAA 6" ;
+
+ // make test store the point(1,0,0,0,0)
+ temp.clear();
+ temp[0] = 1;
+ test.train(temp, 1, 1);
+ temp2.clear();
+ temp2[0] = 1;
+
+ temp.clear();
+ DLIB_TEST(approx_equal(test(temp), length(temp2)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2)));
+ DLIB_TEST(approx_equal(test.inner_product(test), length_squared(temp2)));
+ temp[0] = 1;
+ temp[1] = 2;
+ temp[2] = -3;
+ temp[3] = 4;
+ temp[4] = 5;
+ DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp)));
+ DLIB_TEST(approx_equal(test.get_distance_function()(temp), dlib::distance(temp2,temp)));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0));
+
+ dlog << LDEBUG << "AAAA 7" ;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ void test_kcentroid_with_offset_sparse_linear_kernel(
+ )
+ /*!
+ requires
+ - kernel_type::sample_type == a std::map<unsigned long,double>
+ - kernel_type == a kernel that just computes a dot product
+ between its inputs + some constant. I.e. a linear kernel
+ wrapped by offset_kernel
+ ensures
+ - tests the kcentroid object with the given kernel
+ !*/
+ {
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef typename kernel_type::sample_type sample_type;
+
+ kernel_type k;
+ kcentroid<kernel_type> test(k,0.001,20);
+
+ sample_type temp, temp2, temp3;
+
+ std::map<unsigned long,double> val, val2;
+
+ const double b = std::sqrt(k.offset);
+
+ temp.clear();
+ temp[0] = 2;
+ temp2.clear();
+ val.clear();
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+
+
+ temp2.clear();
+
+ // make test store the point(0,0,0,0,b)
+ val.clear();
+ val[4] = b;
+ test.train(temp2, 0,1);
+
+ temp.clear();
+ temp[0] = 2;
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val), 1e-7),
+ test.get_distance_function()(temp2) - dist(k,temp2,val)
+ );
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0, 1e-6));
+
+ // make test store the point(0,0,0,0,0)
+ val.clear();
+ test.train(temp2, 1,-1);
+
+ temp.clear();
+ temp[0] = 2;
+ dlog << LDEBUG << test(temp) ;
+ dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ;
+
+ DLIB_TEST_MSG(approx_equal(test(temp), dist(k,temp,val)), test(temp) - dist(k,temp,val));
+ DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val)));
+ DLIB_TEST(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val)));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+ DLIB_TEST(approx_equal(test(test), 0));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0));
+
+ val2.clear();
+ val2[0] = 0;
+ val2[1] = 1;
+ val2[2] = 0;
+ val2[3] = 0;
+ val2[4] = b;
+ for (unsigned int i = 0; i < 5; ++i) val[i] += val2[i];
+ temp2.clear();
+ temp2[1] = 1;
+ // make test store the point val
+ test.train(temp2, 1,1);
+
+ temp.clear();
+ temp[0] = 1;
+ temp[2] = 3;
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val), 1e-7),
+ test(temp2) - dist(k,temp2,val));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+
+
+ val2.clear();
+ val2[0] = 0;
+ val2[1] = 1;
+ val2[2] = 2.6;
+ val2[3] = 8;
+ val2[4] = b;
+ for (unsigned int i = 0; i < 5; ++i) val[i] += val2[i];
+
+ temp2.clear();
+ temp2[0] = 0;
+ temp2[1] = 1;
+ temp2[2] = 2.6;
+ temp2[3] = 8;
+ // make test store the point val
+ test.train(temp2, 1,1);
+
+ temp.clear();
+ temp[0] = 1;
+ temp[1] = 1;
+ temp[2] = 3;
+ temp[3] = 0;
+ DLIB_TEST(approx_equal(test(temp), dist(k,temp,val)));
+ DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val)), test(temp2) - dist(k,temp2,val));
+ DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val)));
+ DLIB_TEST(approx_equal(test.inner_product(test), length_squared(val)));
+ DLIB_TEST_MSG(approx_equal(test(test), 0, 1e-6), test(test));
+ DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class kcentroid_tester : public tester
+ {
+ public:
+ kcentroid_tester (
+ ) :
+ tester ("test_kcentroid",
+ "Runs tests on the kcentroid components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ // The idea here is to exercize all the various overloads of the kcentroid object. We also want
+ // to exercize the non-overloaded default version. That is why we have these unopt_* linear
+ // kernels
+ test_kcentroid_with_linear_kernel<linear_kernel<matrix<double,5,1> > >();
+ test_kcentroid_with_offset_linear_kernel<offset_kernel<linear_kernel<matrix<double,4,1> > > >();
+ test_kcentroid_with_linear_kernel<unopt_linear_kernel<matrix<double,5,1> > >();
+ test_kcentroid_with_offset_linear_kernel<offset_kernel<unopt_linear_kernel<matrix<double,4,1> > > >();
+ test_kcentroid_with_sparse_linear_kernel<sparse_linear_kernel<std::map<unsigned long,double> > >();
+ test_kcentroid_with_offset_sparse_linear_kernel<offset_kernel<sparse_linear_kernel<std::map<unsigned long,double> > > >();
+ test_kcentroid_with_sparse_linear_kernel<unopt_sparse_linear_kernel<std::map<unsigned long,double> > >();
+ test_kcentroid_with_offset_sparse_linear_kernel<offset_kernel<unopt_sparse_linear_kernel<std::map<unsigned long,double> > > >();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/kernel_matrix.cpp b/ml/dlib/dlib/test/kernel_matrix.cpp
new file mode 100644
index 000000000..8fc3a2d2c
--- /dev/null
+++ b/ml/dlib/dlib/test/kernel_matrix.cpp
@@ -0,0 +1,161 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.kernel_matrix");
+
+
+ class kernel_matrix_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ kernel_matrix_tester (
+ ) :
+ tester (
+ "test_kernel_matrix", // the command line argument name for this test
+ "Run tests on the kernel_matrix functions.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+
+ typedef matrix<double,0,1> sample_type;
+ typedef radial_basis_kernel<sample_type> kernel_type;
+ kernel_type kern(0.1);
+
+ std::vector<sample_type> vect1;
+ std::vector<sample_type> vect2;
+
+ const sample_type samp = randm(4,1);
+ sample_type samp2, samp3;
+
+ vect1.push_back(randm(4,1));
+ vect1.push_back(randm(4,1));
+ vect1.push_back(randm(4,1));
+ vect1.push_back(randm(4,1));
+
+ vect2.push_back(randm(4,1));
+ vect2.push_back(randm(4,1));
+ vect2.push_back(randm(4,1));
+ vect2.push_back(randm(4,1));
+ vect2.push_back(randm(4,1));
+
+ matrix<double> K;
+
+ K.set_size(vect1.size(), vect2.size());
+ for (long r = 0; r < K.nr(); ++r)
+ {
+ for (long c = 0; c < K.nc(); ++c)
+ {
+ K(r,c) = kern(vect1[r], vect2[c]);
+ }
+ }
+ DLIB_TEST(equal(K, kernel_matrix(kern, vect1, vect2)));
+ DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), mat(vect2))));
+
+
+ K.set_size(vect2.size(), vect1.size());
+ for (long r = 0; r < K.nr(); ++r)
+ {
+ for (long c = 0; c < K.nc(); ++c)
+ {
+ K(r,c) = kern(vect2[r], vect1[c]);
+ }
+ }
+ DLIB_TEST(equal(K, kernel_matrix(kern, vect2, vect1)));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect2, vect1))));
+ DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect2), mat(vect1))));
+
+
+ K.set_size(vect1.size(), vect1.size());
+ for (long r = 0; r < K.nr(); ++r)
+ {
+ for (long c = 0; c < K.nc(); ++c)
+ {
+ K(r,c) = kern(vect1[r], vect1[c]);
+ }
+ }
+ DLIB_TEST(equal(K, kernel_matrix(kern, vect1, vect1)));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect1, vect1))));
+ DLIB_TEST(equal(K, kernel_matrix(kern, vect1)));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect1))));
+ DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), mat(vect1))));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, mat(vect1), mat(vect1)))));
+ DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1))));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, mat(vect1)))));
+
+
+ K.set_size(vect1.size(),1);
+ for (long r = 0; r < K.nr(); ++r)
+ {
+ for (long c = 0; c < K.nc(); ++c)
+ {
+ K(r,c) = kern(vect1[r], samp);
+ }
+ }
+ DLIB_TEST(equal(K, kernel_matrix(kern, vect1, samp)));
+ DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), samp)));
+
+
+ K.set_size(1, vect1.size());
+ for (long r = 0; r < K.nr(); ++r)
+ {
+ for (long c = 0; c < K.nc(); ++c)
+ {
+ K(r,c) = kern(samp, vect1[c]);
+ }
+ }
+ DLIB_TEST(equal(K, kernel_matrix(kern, samp, vect1)));
+ DLIB_TEST(equal(K, kernel_matrix(kern, samp, mat(vect1))));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, samp, vect1))));
+ DLIB_TEST(equal(K, tmp(kernel_matrix(kern, samp, mat(vect1)))));
+
+
+
+ samp2 = samp;
+ samp3 = samp;
+
+ // test the alias detection
+ samp2 = kernel_matrix(kern, vect1, samp2);
+ DLIB_TEST(equal(samp2, kernel_matrix(kern, vect1, samp)));
+
+ samp3 = trans(kernel_matrix(kern, samp3, vect2));
+ DLIB_TEST(equal(samp3, trans(kernel_matrix(kern, samp, vect2))));
+
+
+ samp2 += kernel_matrix(kern, vect1, samp);
+ DLIB_TEST(equal(samp2, 2*kernel_matrix(kern, vect1, samp)));
+
+ samp3 += trans(kernel_matrix(kern, samp, vect2));
+ DLIB_TEST(equal(samp3, 2*trans(kernel_matrix(kern, samp, vect2))));
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ kernel_matrix_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/kmeans.cpp b/ml/dlib/dlib/test/kmeans.cpp
new file mode 100644
index 000000000..95c037b77
--- /dev/null
+++ b/ml/dlib/dlib/test/kmeans.cpp
@@ -0,0 +1,163 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/svm.h>
+#include <dlib/matrix.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.kmeans");
+
+ dlib::rand rnd;
+
+ template <typename sample_type>
+ void run_test(
+ const std::vector<sample_type>& seed_centers
+ )
+ {
+ print_spinner();
+
+
+ sample_type samp;
+
+ std::vector<sample_type> samples;
+
+
+ for (unsigned long j = 0; j < seed_centers.size(); ++j)
+ {
+ for (int i = 0; i < 250; ++i)
+ {
+ samp = randm(seed_centers[0].size(),1,rnd) - 0.5;
+ samples.push_back(samp + seed_centers[j]);
+ }
+ }
+
+ randomize_samples(samples);
+
+ {
+ std::vector<sample_type> centers;
+ pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
+
+ find_clusters_using_kmeans(samples, centers);
+
+ DLIB_TEST(centers.size() == seed_centers.size());
+
+ std::vector<int> hits(centers.size(),0);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ unsigned long best_idx = 0;
+ double best_dist = 1e100;
+ for (unsigned long j = 0; j < centers.size(); ++j)
+ {
+ if (length(samples[i] - centers[j]) < best_dist)
+ {
+ best_dist = length(samples[i] - centers[j]);
+ best_idx = j;
+ }
+ }
+ hits[best_idx]++;
+ }
+
+ for (unsigned long i = 0; i < hits.size(); ++i)
+ {
+ DLIB_TEST(hits[i] == 250);
+ }
+ }
+ {
+ std::vector<sample_type> centers;
+ pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
+
+ find_clusters_using_angular_kmeans(samples, centers);
+
+ DLIB_TEST(centers.size() == seed_centers.size());
+
+ std::vector<int> hits(centers.size(),0);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ unsigned long best_idx = 0;
+ double best_dist = 1e100;
+ for (unsigned long j = 0; j < centers.size(); ++j)
+ {
+ if (length(samples[i] - centers[j]) < best_dist)
+ {
+ best_dist = length(samples[i] - centers[j]);
+ best_idx = j;
+ }
+ }
+ hits[best_idx]++;
+ }
+
+ for (unsigned long i = 0; i < hits.size(); ++i)
+ {
+ DLIB_TEST(hits[i] == 250);
+ }
+ }
+ }
+
+
+ class test_kmeans : public tester
+ {
+ public:
+ test_kmeans (
+ ) :
+ tester ("test_kmeans",
+ "Runs tests on the find_clusters_using_kmeans() function.")
+ {}
+
+ void perform_test (
+ )
+ {
+ {
+ dlog << LINFO << "test dlib::vector<double,2>";
+ typedef dlib::vector<double,2> sample_type;
+ std::vector<sample_type> seed_centers;
+ seed_centers.push_back(sample_type(10,10));
+ seed_centers.push_back(sample_type(10,-10));
+ seed_centers.push_back(sample_type(-10,10));
+ seed_centers.push_back(sample_type(-10,-10));
+
+ run_test(seed_centers);
+ }
+ {
+ dlog << LINFO << "test dlib::vector<double,2>";
+ typedef dlib::vector<float,2> sample_type;
+ std::vector<sample_type> seed_centers;
+ seed_centers.push_back(sample_type(10,10));
+ seed_centers.push_back(sample_type(10,-10));
+ seed_centers.push_back(sample_type(-10,10));
+ seed_centers.push_back(sample_type(-10,-10));
+
+ run_test(seed_centers);
+ }
+ {
+ dlog << LINFO << "test dlib::matrix<double,3,1>";
+ typedef dlib::matrix<double,3,1> sample_type;
+ std::vector<sample_type> seed_centers;
+ sample_type samp;
+ samp = 10,10,0; seed_centers.push_back(samp);
+ samp = -10,10,1; seed_centers.push_back(samp);
+ samp = -10,-10,2; seed_centers.push_back(samp);
+
+ run_test(seed_centers);
+ }
+
+
+ }
+ } a;
+
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/learning_to_track.cpp b/ml/dlib/dlib/test/learning_to_track.cpp
new file mode 100644
index 000000000..730c01ba8
--- /dev/null
+++ b/ml/dlib/dlib/test/learning_to_track.cpp
@@ -0,0 +1,306 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/rand.h>
+
+
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.learning_to_track");
+
+// ----------------------------------------------------------------------------------------
+
+ struct detection_dense
+ {
+ typedef struct track_dense track_type;
+ matrix<double,0,1> measurements;
+ };
+
+
+ struct track_dense
+ {
+ typedef matrix<double,0,1> feature_vector_type;
+
+ track_dense()
+ {
+ time_since_last_association = 0;
+ }
+
+ void get_similarity_features(const detection_dense det, feature_vector_type& feats) const
+ {
+ feats = abs(last_measurements - det.measurements);
+ }
+
+ void update_track(const detection_dense det)
+ {
+ last_measurements = det.measurements;
+ time_since_last_association = 0;
+ }
+
+ void propagate_track()
+ {
+ ++time_since_last_association;
+ }
+
+ matrix<double,0,1> last_measurements;
+ unsigned long time_since_last_association;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct detection_sparse
+ {
+ typedef struct track_sparse track_type;
+ matrix<double,0,1> measurements;
+ };
+
+
+ struct track_sparse
+ {
+ typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;
+
+ track_sparse()
+ {
+ time_since_last_association = 0;
+ }
+
+ void get_similarity_features(const detection_sparse det, feature_vector_type& feats) const
+ {
+ matrix<double,0,1> temp = abs(last_measurements - det.measurements);
+ feats.clear();
+ for (long i = 0; i < temp.size(); ++i)
+ feats.push_back(make_pair(i, temp(i)));
+ }
+
+ void update_track(const detection_sparse det)
+ {
+ last_measurements = det.measurements;
+ time_since_last_association = 0;
+ }
+
+ void propagate_track()
+ {
+ ++time_since_last_association;
+ }
+
+ matrix<double,0,1> last_measurements;
+ unsigned long time_since_last_association;
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+ const long num_objects = 4;
+ const long num_properties = 6;
+ std::vector<matrix<double,0,1> > object_properties(num_objects);
+
+ void initialize_object_properties()
+ {
+ rnd.set_seed("23ja2oirfjaf");
+ for (unsigned long i = 0; i < object_properties.size(); ++i)
+ object_properties[i] = randm(num_properties,1,rnd);
+ }
+
+ template <typename detection>
+ detection sample_detection_from_sensor(long object_id)
+ {
+ DLIB_CASSERT(object_id < num_objects,
+ "You can't ask to sample a detection from an object that doesn't exist.");
+ detection temp;
+ // Set the measurements equal to the object's true property values plus a little bit of
+ // noise.
+ temp.measurements = object_properties[object_id] + randm(num_properties,1,rnd)*0.1;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <typename detection>
+ std::vector<std::vector<labeled_detection<detection> > > make_random_tracking_data_for_training()
+ {
+ typedef std::vector<labeled_detection<detection> > detections_at_single_time_step;
+ typedef std::vector<detections_at_single_time_step> track_history;
+
+ track_history data;
+
+ // At each time step we get a set of detections from the objects in the world.
+ // Simulate 100 time steps worth of data where there are 3 objects present.
+ const int num_time_steps = 100;
+ for (int i = 0; i < num_time_steps; ++i)
+ {
+ detections_at_single_time_step dets(3);
+ // sample a detection from object 0
+ dets[0].det = sample_detection_from_sensor<detection>(0);
+ dets[0].label = 0;
+
+ // sample a detection from object 1
+ dets[1].det = sample_detection_from_sensor<detection>(1);
+ dets[1].label = 1;
+
+ // sample a detection from object 2
+ dets[2].det = sample_detection_from_sensor<detection>(2);
+ dets[2].label = 2;
+
+ randomize_samples(dets, rnd);
+ data.push_back(dets);
+ }
+
+ // Now let's imagine object 1 and 2 are gone but a new object, object 3 has arrived.
+ for (int i = 0; i < num_time_steps; ++i)
+ {
+ detections_at_single_time_step dets(2);
+ // sample a detection from object 0
+ dets[0].det = sample_detection_from_sensor<detection>(0);
+ dets[0].label = 0;
+
+ // sample a detection from object 3
+ dets[1].det = sample_detection_from_sensor<detection>(3);
+ dets[1].label = 3;
+
+ randomize_samples(dets, rnd);
+ data.push_back(dets);
+ }
+
+ return data;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename detection>
+ std::vector<detection> make_random_detections(long num_dets)
+ {
+ DLIB_CASSERT(num_dets <= num_objects,
+ "You can't ask for more detections than there are objects in our little simulation.");
+
+ std::vector<detection> dets(num_dets);
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ dets[i] = sample_detection_from_sensor<detection>(i);
+ }
+ randomize_samples(dets, rnd);
+ return dets;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename detection>
+ void test_tracking_stuff()
+ {
+ print_spinner();
+
+
+ typedef std::vector<labeled_detection<detection> > detections_at_single_time_step;
+ typedef std::vector<detections_at_single_time_step> track_history;
+ std::vector<track_history> data;
+ data.push_back(make_random_tracking_data_for_training<detection>());
+ data.push_back(make_random_tracking_data_for_training<detection>());
+ data.push_back(make_random_tracking_data_for_training<detection>());
+ data.push_back(make_random_tracking_data_for_training<detection>());
+ data.push_back(make_random_tracking_data_for_training<detection>());
+
+
+ structural_track_association_trainer trainer;
+ trainer.set_c(1000);
+ track_association_function<detection> assoc = trainer.train(data);
+
+ double test_val = test_track_association_function(assoc, data);
+ DLIB_TEST_MSG( test_val == 1, test_val);
+ test_val = cross_validate_track_association_trainer(trainer, data, 5);
+ DLIB_TEST_MSG ( test_val == 1, test_val);
+
+
+
+ typedef typename detection::track_type track;
+ std::vector<track> tracks;
+
+ std::vector<detection> dets = make_random_detections<detection>(3);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 3);
+
+ dets = make_random_detections<detection>(3);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 3);
+
+ dets = make_random_detections<detection>(3);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 3);
+
+ dets = make_random_detections<detection>(4);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 4);
+
+ dets = make_random_detections<detection>(3);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 4);
+ unsigned long total_miss = 0;
+ for (unsigned long i = 0; i < tracks.size(); ++i)
+ total_miss += tracks[i].time_since_last_association;
+ DLIB_TEST(total_miss == 1);
+
+ dets = make_random_detections<detection>(3);
+ assoc(tracks, dets);
+ DLIB_TEST(tracks.size() == 4);
+ total_miss = 0;
+ unsigned long num_zero = 0;
+ for (unsigned long i = 0; i < tracks.size(); ++i)
+ {
+ total_miss += tracks[i].time_since_last_association;
+ if (tracks[i].time_since_last_association == 0)
+ ++num_zero;
+ }
+ DLIB_TEST(total_miss == 2);
+ DLIB_TEST(num_zero == 3);
+
+
+
+ ostringstream sout;
+ serialize(assoc, sout);
+
+ istringstream sin(sout.str());
+ deserialize(assoc, sin);
+ DLIB_TEST( test_track_association_function(assoc, data) == 1);
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ class test_learning_to_track : public tester
+ {
+ public:
+ test_learning_to_track (
+ ) :
+ tester ("test_learning_to_track",
+ "Runs tests on the assignment learning code.")
+ {}
+
+ void perform_test (
+ )
+ {
+ initialize_object_properties();
+ for (int i = 0; i < 3; ++i)
+ {
+ dlog << LINFO << "run test_tracking_stuff<detection_dense>()";
+ test_tracking_stuff<detection_dense>();
+ dlog << LINFO << "run test_tracking_stuff<detection_sparse>()";
+ test_tracking_stuff<detection_sparse>();
+ }
+ }
+ } a;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
diff --git a/ml/dlib/dlib/test/least_squares.cpp b/ml/dlib/dlib/test/least_squares.cpp
new file mode 100644
index 000000000..c4282ad05
--- /dev/null
+++ b/ml/dlib/dlib/test/least_squares.cpp
@@ -0,0 +1,452 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include "optimization_test_functions.h"
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ using namespace dlib::test_functions;
+
+ logger dlog("test.least_squares");
+
+// ----------------------------------------------------------------------------------------
+
+ void test_with_chebyquad()
+ {
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(2);
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));
+
+ DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
+
+ }
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(2);
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
+ dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));
+
+ DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
+
+ }
+
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = chebyquad_start(2);
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));
+
+ DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = chebyquad_start(2);
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
+ dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));
+
+ DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
+
+ }
+
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(4);
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4));
+
+ DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(4);
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "LM chebyquad 4 obj: " << chebyquad(ch);
+ dlog << LINFO << "LM chebyquad 4 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "LM chebyquad 4 error: " << length(ch - chebyquad_solution(4));
+
+ DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
+
+ }
+
+
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(6);
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6));
+
+ // the ch variable contains a permutation of what is in chebyquad_solution(6).
+ // Apparently there is more than one minimum?. Just check that the objective
+ // goes to zero.
+ DLIB_TEST(chebyquad(ch) < 1e-10);
+
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(6);
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "LM chebyquad 6 obj: " << chebyquad(ch);
+ dlog << LINFO << "LM chebyquad 6 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "LM chebyquad 6 error: " << length(ch - chebyquad_solution(6));
+
+ DLIB_TEST(chebyquad(ch) < 1e-10);
+
+ }
+
+
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(8);
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8));
+
+ DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(8);
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ chebyquad_residual,
+ derivative(chebyquad_residual),
+ range(0,ch.size()-1),
+ ch);
+
+ dlog << LINFO << "LM chebyquad 8 obj: " << chebyquad(ch);
+ dlog << LINFO << "LM chebyquad 8 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "LM chebyquad 8 error: " << length(ch - chebyquad_solution(8));
+
+ DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_with_brown()
+ {
+ print_spinner();
+ {
+ matrix<double,4,1> ch;
+
+ ch = brown_start();
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 300),
+ brown_residual,
+ derivative(brown_residual),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "brown obj: " << brown(ch);
+ dlog << LINFO << "brown der: " << length(brown_derivative(ch));
+ dlog << LINFO << "brown error: " << length(ch - brown_solution());
+
+ DLIB_TEST_MSG(length(ch - brown_solution()) < 1e-5,length(ch - brown_solution()) );
+
+ }
+ print_spinner();
+ {
+ matrix<double,4,1> ch;
+
+ ch = brown_start();
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ brown_residual,
+ derivative(brown_residual),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "LM brown obj: " << brown(ch);
+ dlog << LINFO << "LM brown der: " << length(brown_derivative(ch));
+ dlog << LINFO << "LM brown error: " << length(ch - brown_solution());
+
+ DLIB_TEST(length(ch - brown_solution()) < 1e-5);
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// These functions are declared here because wrapping the real rosen functions in this
+// way avoids triggering a bug in visual studio 2005 which prevents this code from compiling.
+ double rosen_residual_double (int i, const matrix<double,2,1>& m)
+ { return rosen_residual(i,m); }
+ float rosen_residual_float (int i, const matrix<float,2,1>& m)
+ { return rosen_residual(i,m); }
+
+ matrix<double,2,1> rosen_residual_derivative_double (int i, const matrix<double,2,1>& m)
+ { return rosen_residual_derivative(i,m); }
+ /*
+ matrix<float,2,1> rosen_residual_derivative_float (int i, const matrix<float,2,1>& m)
+ { return rosen_residual_derivative(i,m); }
+ */
+
+ double rosen_big_residual_double (int i, const matrix<double,2,1>& m)
+ { return rosen_big_residual(i,m); }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_with_rosen()
+ {
+
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_start<double>();
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_double,
+ rosen_residual_derivative_double,
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "rosen obj: " << rosen(ch);
+ dlog << LINFO << "rosen error: " << length(ch - rosen_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_start<double>();
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_double,
+ rosen_residual_derivative_double,
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "lm rosen obj: " << rosen(ch);
+ dlog << LINFO << "lm rosen error: " << length(ch - rosen_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
+
+ }
+
+
+
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_start<double>();
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_double,
+ derivative(rosen_residual_double),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "rosen obj: " << rosen(ch);
+ dlog << LINFO << "rosen error: " << length(ch - rosen_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<float,2,1> ch;
+
+ ch = rosen_start<float>();
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_float,
+ derivative(rosen_residual_float),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "float rosen obj: " << rosen(ch);
+ dlog << LINFO << "float rosen error: " << length(ch - rosen_solution<float>());
+
+ DLIB_TEST(length(ch - rosen_solution<float>()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<float,2,1> ch;
+
+ ch = rosen_start<float>();
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_float,
+ derivative(rosen_residual_float),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "LM float rosen obj: " << rosen(ch);
+ dlog << LINFO << "LM float rosen error: " << length(ch - rosen_solution<float>());
+
+ DLIB_TEST(length(ch - rosen_solution<float>()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_start<double>();
+
+ solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
+ rosen_residual_double,
+ derivative(rosen_residual_double),
+ range(1,20),
+ ch);
+
+ dlog << LINFO << "LM rosen obj: " << rosen(ch);
+ dlog << LINFO << "LM rosen error: " << length(ch - rosen_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_big_start<double>();
+
+ solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
+ rosen_big_residual_double,
+ derivative(rosen_big_residual_double),
+ range(1,2),
+ ch);
+
+ dlog << LINFO << "rosen big obj: " << rosen_big(ch);
+ dlog << LINFO << "rosen big error: " << length(ch - rosen_big_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_big_solution<double>()) < 1e-5);
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class optimization_tester : public tester
+ {
+ public:
+ optimization_tester (
+ ) :
+ tester ("test_least_squares",
+ "Runs tests on the least squares optimization component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_with_chebyquad();
+ test_with_brown();
+ test_with_rosen();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/linear_manifold_regularizer.cpp b/ml/dlib/dlib/test/linear_manifold_regularizer.cpp
new file mode 100644
index 000000000..e73b1c8d3
--- /dev/null
+++ b/ml/dlib/dlib/test/linear_manifold_regularizer.cpp
@@ -0,0 +1,408 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/manifold_regularization.h>
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <dlib/graph_utils_threaded.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.linear_manifold_regularizer");
+
+ template <typename hash_type, typename samples_type>
+ void test_find_k_nearest_neighbors_lsh(
+ const samples_type& samples
+ )
+ {
+ std::vector<sample_pair> edges1, edges2;
+
+ find_k_nearest_neighbors(samples, cosine_distance(), 2, edges1);
+ find_k_nearest_neighbors_lsh(samples, cosine_distance(), hash_type(), 2, 6, edges2, 2);
+
+ std::sort(edges1.begin(), edges1.end(), order_by_index<sample_pair>);
+ std::sort(edges2.begin(), edges2.end(), order_by_index<sample_pair>);
+
+ DLIB_TEST_MSG(edges1.size() == edges2.size(), edges1.size() << " " << edges2.size());
+ for (unsigned long i = 0; i < edges1.size(); ++i)
+ {
+ DLIB_TEST(edges1[i] == edges2[i]);
+ DLIB_TEST_MSG(std::abs(edges1[i].distance() - edges2[i].distance()) < 1e-7,
+ edges1[i].distance() - edges2[i].distance());
+ }
+ }
+
+ template <typename scalar_type>
+ void test_knn_lsh_sparse()
+ {
+ dlib::rand rnd;
+ std::vector<std::map<unsigned long,scalar_type> > samples;
+ samples.resize(20);
+ for (unsigned int i = 0; i < samples.size(); ++i)
+ {
+ samples[i][0] = rnd.get_random_gaussian();
+ samples[i][2] = rnd.get_random_gaussian();
+ }
+
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_64>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_128>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_256>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_512>(samples);
+ }
+
+ template <typename scalar_type>
+ void test_knn_lsh_dense()
+ {
+ dlib::rand rnd;
+ std::vector<matrix<scalar_type,0,1> > samples;
+ samples.resize(20);
+ for (unsigned int i = 0; i < samples.size(); ++i)
+ {
+ samples[i].set_size(2);
+ samples[i](0) = rnd.get_random_gaussian();
+ samples[i](1) = rnd.get_random_gaussian();
+ }
+
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_64>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_128>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_256>(samples);
+ test_find_k_nearest_neighbors_lsh<hash_similar_angles_512>(samples);
+ }
+
+
+
+ class linear_manifold_regularizer_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ linear_manifold_regularizer_tester (
+ ) :
+ tester (
+ "test_linear_manifold_regularizer", // the command line argument name for this test
+ "Run tests on the linear_manifold_regularizer object.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ seed = 1;
+ }
+
+ dlib::rand rnd;
+
+ unsigned long seed;
+
+ typedef matrix<double, 0, 1> sample_type;
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ void do_the_test()
+ {
+ print_spinner();
+ std::vector<sample_type> samples;
+
+ // Declare an instance of the kernel we will be using.
+ const kernel_type kern(0.1);
+
+ const unsigned long num_points = 200;
+
+ // create a large dataset with two concentric circles.
+ generate_circle(samples, 1, num_points); // circle of radius 1
+ generate_circle(samples, 5, num_points); // circle of radius 5
+
+ std::vector<sample_pair> edges;
+ find_percent_shortest_edges_randomly(samples, squared_euclidean_distance(0.1, 4), 1, 10000, "random seed", edges);
+
+ dlog << LTRACE << "number of edges generated: " << edges.size();
+
+ empirical_kernel_map<kernel_type> ekm;
+
+ ekm.load(kern, randomly_subsample(samples, 100));
+
+ // Project all the samples into the span of our 50 basis samples
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ samples[i] = ekm.project(samples[i]);
+
+
+ // Now create the manifold regularizer. The result is a transformation matrix that
+ // embodies the manifold assumption discussed above.
+ linear_manifold_regularizer<sample_type> lmr;
+ lmr.build(samples, edges, use_gaussian_weights(0.1));
+ matrix<double> T = lmr.get_transformation_matrix(10000);
+
+ print_spinner();
+
+ // generate the T matrix manually and make sure it matches. The point of this test
+ // is to make sure that the more complex version of this that happens inside the linear_manifold_regularizer
+ // is correct. It uses a tedious block of loops to do it in a way that is a lot faster for sparse
+ // W matrices but isn't super straight forward.
+ matrix<double> X(samples[0].size(), samples.size());
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ set_colm(X,i) = samples[i];
+
+ matrix<double> W(samples.size(), samples.size());
+ W = 0;
+ for (unsigned long i = 0; i < edges.size(); ++i)
+ {
+ W(edges[i].index1(), edges[i].index2()) = use_gaussian_weights(0.1)(edges[i]);
+ W(edges[i].index2(), edges[i].index1()) = use_gaussian_weights(0.1)(edges[i]);
+ }
+ matrix<double> L = diagm(sum_rows(W)) - W;
+ matrix<double> trueT = inv_lower_triangular(chol(identity_matrix<double>(X.nr()) + (10000.0/sum(lowerm(W)))*X*L*trans(X)));
+
+ dlog << LTRACE << "T error: "<< max(abs(T - trueT));
+ DLIB_TEST(max(abs(T - trueT)) < 1e-7);
+
+
+ print_spinner();
+ // Apply the transformation generated by the linear_manifold_regularizer to
+ // all our samples.
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ samples[i] = T*samples[i];
+
+
+ // For convenience, generate a projection_function and merge the transformation
+ // matrix T into it.
+ projection_function<kernel_type> proj = ekm.get_projection_function();
+ proj.weights = T*proj.weights;
+
+
+ // Pick 2 different labeled points. One on the inner circle and another on the outer.
+ // For each of these test points we will see if using the single plane that separates
+ // them is a good way to separate the concentric circles. Also do this a bunch
+ // of times with different randomly chosen points so we can see how robust the result is.
+ for (int itr = 0; itr < 10; ++itr)
+ {
+ print_spinner();
+ std::vector<sample_type> test_points;
+ // generate a random point from the radius 1 circle
+ generate_circle(test_points, 1, 1);
+ // generate a random point from the radius 5 circle
+ generate_circle(test_points, 5, 1);
+
+ // project the two test points into kernel space. Recall that this projection_function
+ // has the manifold regularizer incorporated into it.
+ const sample_type class1_point = proj(test_points[0]);
+ const sample_type class2_point = proj(test_points[1]);
+
+ double num_wrong = 0;
+
+ // Now attempt to classify all the data samples according to which point
+ // they are closest to. The output of this program shows that without manifold
+ // regularization this test will fail but with it it will perfectly classify
+ // all the points.
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ double distance_to_class1 = length(samples[i] - class1_point);
+ double distance_to_class2 = length(samples[i] - class2_point);
+
+ bool predicted_as_class_1 = (distance_to_class1 < distance_to_class2);
+
+ bool really_is_class_1 = (i < num_points);
+
+ // now count how many times we make a mistake
+ if (predicted_as_class_1 != really_is_class_1)
+ ++num_wrong;
+ }
+
+ DLIB_TEST_MSG(num_wrong == 0, num_wrong);
+ }
+
+ }
+
+ void generate_circle (
+ std::vector<sample_type>& samples,
+ double radius,
+ const long num
+ )
+ {
+ sample_type m(2,1);
+
+ for (long i = 0; i < num; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ samples.push_back(m);
+ }
+ }
+
+
+ void test_knn1()
+ {
+ std::vector<matrix<double,2,1> > samples;
+
+ matrix<double,2,1> test;
+
+ test = 0,0; samples.push_back(test);
+ test = 1,1; samples.push_back(test);
+ test = 1,-1; samples.push_back(test);
+ test = -1,1; samples.push_back(test);
+ test = -1,-1; samples.push_back(test);
+
+ std::vector<sample_pair> edges;
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(0,1,0));
+ DLIB_TEST(edges[1] == sample_pair(0,2,0));
+ DLIB_TEST(edges[2] == sample_pair(0,3,0));
+ DLIB_TEST(edges[3] == sample_pair(0,4,0));
+
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(), 3, edges);
+ DLIB_TEST(edges.size() == 8);
+
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(1,2,0));
+ DLIB_TEST(edges[1] == sample_pair(1,3,0));
+ DLIB_TEST(edges[2] == sample_pair(2,4,0));
+ DLIB_TEST(edges[3] == sample_pair(3,4,0));
+
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(30000, 4.1), 3, edges);
+ DLIB_TEST(edges.size() == 0);
+ }
+
+ void test_knn1_approx()
+ {
+ std::vector<matrix<double,2,1> > samples;
+
+ matrix<double,2,1> test;
+
+ test = 0,0; samples.push_back(test);
+ test = 1,1; samples.push_back(test);
+ test = 1,-1; samples.push_back(test);
+ test = -1,1; samples.push_back(test);
+ test = -1,-1; samples.push_back(test);
+
+ std::vector<sample_pair> edges;
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(0,1,0));
+ DLIB_TEST(edges[1] == sample_pair(0,2,0));
+ DLIB_TEST(edges[2] == sample_pair(0,3,0));
+ DLIB_TEST(edges[3] == sample_pair(0,4,0));
+
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 3, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 8);
+
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(1,2,0));
+ DLIB_TEST(edges[1] == sample_pair(1,3,0));
+ DLIB_TEST(edges[2] == sample_pair(2,4,0));
+ DLIB_TEST(edges[3] == sample_pair(3,4,0));
+
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(30000, 4.1), 3, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 0);
+ }
+
+ void test_knn2()
+ {
+ std::vector<matrix<double,2,1> > samples;
+
+ matrix<double,2,1> test;
+
+ test = 1,1; samples.push_back(test);
+ test = 1,-1; samples.push_back(test);
+ test = -1,1; samples.push_back(test);
+ test = -1,-1; samples.push_back(test);
+
+ std::vector<sample_pair> edges;
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(0,1,0));
+ DLIB_TEST(edges[1] == sample_pair(0,2,0));
+ DLIB_TEST(edges[2] == sample_pair(1,3,0));
+ DLIB_TEST(edges[3] == sample_pair(2,3,0));
+
+ find_k_nearest_neighbors(samples, squared_euclidean_distance(), 200, edges);
+ DLIB_TEST(edges.size() == 4*3/2);
+ }
+
+ void test_knn2_approx()
+ {
+ std::vector<matrix<double,2,1> > samples;
+
+ matrix<double,2,1> test;
+
+ test = 1,1; samples.push_back(test);
+ test = 1,-1; samples.push_back(test);
+ test = -1,1; samples.push_back(test);
+ test = -1,-1; samples.push_back(test);
+
+ std::vector<sample_pair> edges;
+ // For this simple graph and high number of samples we will do we should obtain the exact
+ // knn solution.
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 4);
+
+ std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
+
+ DLIB_TEST(edges[0] == sample_pair(0,1,0));
+ DLIB_TEST(edges[1] == sample_pair(0,2,0));
+ DLIB_TEST(edges[2] == sample_pair(1,3,0));
+ DLIB_TEST(edges[3] == sample_pair(2,3,0));
+
+
+ find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 200, 10000, seed, edges);
+ DLIB_TEST(edges.size() == 4*3/2);
+ }
+
+ void perform_test (
+ )
+ {
+ for (int i = 0; i < 5; ++i)
+ {
+ do_the_test();
+
+ ++seed;
+ test_knn1_approx();
+ test_knn2_approx();
+ }
+ test_knn1();
+ test_knn2();
+ test_knn_lsh_sparse<double>();
+ test_knn_lsh_sparse<float>();
+ test_knn_lsh_dense<double>();
+ test_knn_lsh_dense<float>();
+
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ linear_manifold_regularizer_tester a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/lspi.cpp b/ml/dlib/dlib/test/lspi.cpp
new file mode 100644
index 000000000..013887115
--- /dev/null
+++ b/ml/dlib/dlib/test/lspi.cpp
@@ -0,0 +1,258 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/control.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.lspi");
+
+ template <bool have_prior>
+ struct chain_model
+ {
+ typedef int state_type;
+ typedef int action_type; // 0 is move left, 1 is move right
+ const static bool force_last_weight_to_1 = have_prior;
+
+
+ const static int num_states = 4; // not required in the model interface
+
+ matrix<double,8,1> offset;
+ chain_model()
+ {
+ offset =
+ 2.048 ,
+ 2.56 ,
+ 2.048 ,
+ 3.2 ,
+ 2.56 ,
+ 4 ,
+ 3.2,
+ 5 ;
+ if (!have_prior)
+ offset = 0;
+
+ }
+
+ unsigned long num_features(
+ ) const
+ {
+ if (have_prior)
+ return num_states*2 + 1;
+ else
+ return num_states*2;
+ }
+
+ action_type find_best_action (
+ const state_type& state,
+ const matrix<double,0,1>& w
+ ) const
+ {
+ if (w(state*2)+offset(state*2) >= w(state*2+1)+offset(state*2+1))
+ //if (w(state*2) >= w(state*2+1))
+ return 0;
+ else
+ return 1;
+ }
+
+ void get_features (
+ const state_type& state,
+ const action_type& action,
+ matrix<double,0,1>& feats
+ ) const
+ {
+ feats.set_size(num_features());
+ feats = 0;
+ feats(state*2 + action) = 1;
+ if (have_prior)
+ feats(num_features()-1) = offset(state*2+action);
+ }
+
+ };
+
+ void test_lspi_prior1()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<true> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,1));
+
+
+ lspi<chain_model<true> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0);
+ policy<chain_model<true> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+
+ matrix<double,0,1> w = pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 9);
+ DLIB_TEST(w(w.size()-1) == 1);
+ w(w.size()-1) = 0;
+ DLIB_TEST_MSG(length(w) < 1e-12, length(w));
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 1);
+ }
+
+ void test_lspi_prior2()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<true> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,1));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,0));
+
+
+ lspi<chain_model<true> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0);
+ policy<chain_model<true> > pol = trainer.train(samples);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 0);
+ }
+
+ void test_lspi_noprior1()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<false> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,1));
+
+
+ lspi<chain_model<false> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0.01);
+ policy<chain_model<false> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 8);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 1);
+ }
+ void test_lspi_noprior2()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<false> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,1));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,0));
+
+
+ lspi<chain_model<false> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0.01);
+ policy<chain_model<false> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 8);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 0);
+ DLIB_TEST(pol(3) == 0);
+ }
+
+ class lspi_tester : public tester
+ {
+ public:
+ lspi_tester (
+ ) :
+ tester (
+ "test_lspi", // the command line argument name for this test
+ "Run tests on the lspi object.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ void perform_test (
+ )
+ {
+ test_lspi_prior1();
+ test_lspi_prior2();
+
+ test_lspi_noprior1();
+ test_lspi_noprior2();
+ }
+ };
+
+ lspi_tester a;
+}
+
diff --git a/ml/dlib/dlib/test/lz77_buffer.cpp b/ml/dlib/dlib/test/lz77_buffer.cpp
new file mode 100644
index 000000000..ccbb1a24c
--- /dev/null
+++ b/ml/dlib/dlib/test/lz77_buffer.cpp
@@ -0,0 +1,569 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+#include <dlib/sliding_buffer.h>
+
+#include <dlib/lz77_buffer.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.lz77_buffer");
+
+ template <
+ typename buf
+ >
+ void lz77_buffer_kernel_test (
+ )
+ /*!
+ requires
+ - buf is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h
+ ensures
+ - runs tests on buf for compliance with the specs
+ !*/
+ {
+ typedef dlib::sliding_buffer<unsigned char>::kernel_1a sbuf;
+
+ buf test(8,20);
+ srand(static_cast<unsigned int>(time(0)));
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,test.get_history_buffer_limit());
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+ for (int g = 0; g < 2; ++g)
+ {
+ test.clear();
+
+ for (int i = 0; i < 1000; ++i)
+ {
+ test.add('a');
+ }
+ DLIB_TEST(test.get_lookahead_buffer_size() == 20);
+
+
+ test.shift_buffers(5);
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 15);
+
+
+
+ unsigned long index, length, temp;
+ temp = test.get_lookahead_buffer_size();
+ test.find_match(index,length,5);
+
+
+ DLIB_TEST_MSG(length <= temp,
+ "length: " << length <<
+ "\ntemp: " << temp);
+ DLIB_TEST(test.get_lookahead_buffer_size() <= 15);
+
+
+ }
+
+
+ for (int g = 0; g < 2; ++g)
+ {
+
+
+
+ test.clear();
+
+
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_limit() == 256-20);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+ unsigned long a,b, temp = test.get_lookahead_buffer_size();
+ test.find_match(a,b,0);
+ DLIB_TEST(b <= temp);
+ DLIB_TEST(b == 0);
+
+ test.find_match(a,b,5);
+ DLIB_TEST(b == 0);
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_limit() == 256-20);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+
+ ostringstream sout;
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n";
+ istringstream sin(sout.str());
+
+ sout.str("");
+ sout.clear();
+
+ unsigned char ch;
+ sbuf sbuffer;
+ sbuffer.set_size(8);
+
+
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ DLIB_TEST(test.get_lookahead_buffer_size() == 1);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ DLIB_TEST(test.get_lookahead_buffer_size() == 2);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ DLIB_TEST(test.get_lookahead_buffer_size() == 3);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+
+ // add 17 chars to test so that the lookahead buffer will be full
+ for (int i = 0; i < 17; ++i)
+ {
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ }
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 20);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.lookahead_buffer(0) == sbuffer[20]);
+
+
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ DLIB_TEST(test.get_lookahead_buffer_size() == 20);
+ DLIB_TEST(test.get_history_buffer_size() == 1);
+
+
+
+
+
+
+ // add the above text to test and make sure it gives the correct results
+ ch = sin.get();
+ while (sin)
+ {
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch);
+ DLIB_TEST(test.history_buffer(0) == sbuffer[21]);
+ DLIB_TEST(test.history_buffer(1) == sbuffer[22]);
+
+ ch = sin.get();
+ }
+
+
+
+ // make sure the contents of lookahead_buffer and history_buffer
+ // match what is in sbuffer
+ sbuffer.rotate_right(1);
+ for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i));
+ }
+ for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i));
+ }
+ sbuffer.rotate_left(1);
+
+
+
+
+
+
+
+
+
+
+ sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0]
+
+ unsigned long match_index, match_length;
+ unsigned long ltemp = test.get_lookahead_buffer_size();
+ test.find_match(match_index,match_length,0);
+ DLIB_TEST(match_length <= ltemp);
+
+
+ // verify the match with sbuffer
+ for (unsigned int i = 0; i < match_length; ++i)
+ {
+ DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i);
+ }
+
+
+ sin.str("");
+ sin.clear();
+
+ } // for (int g = 0; g < 2; ++g)
+
+
+ for (int g = 0; g < 8; ++g)
+ {
+ test.clear();
+
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_limit() == 256-20);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+ sbuf sbuffer;
+ sbuffer.set_size(8);
+
+ ostringstream sout;
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n";
+ istringstream sin(sout.str());
+
+ unsigned char ch;
+ for (int i = 0; i < 100; ++i)
+ {
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ }
+
+ // make sure the contents of lookahead_buffer and history_buffer
+ // match what is in sbuffer
+ sbuffer.rotate_right(1);
+ for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i));
+ }
+ for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i));
+ }
+ sbuffer.rotate_left(1);
+
+
+
+
+ unsigned long match_index, match_length;
+ unsigned long ltemp = test.get_lookahead_buffer_size();
+ test.find_match(match_index,match_length,0);
+ DLIB_TEST(match_length <= ltemp);
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 20-match_length);
+
+ sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0]
+ // verify the match with sbuffer
+ for (unsigned int i = 0; i < match_length; ++i)
+ {
+ DLIB_TEST(sbuffer[i+20-match_length] == sbuffer[i+1+match_index+20-match_length]);
+ }
+ sbuffer.rotate_left(1); // free up sbuffer[0] for new data
+
+
+
+
+ for (int i = 0; i < 7+g*2; ++i)
+ {
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ }
+
+ ch = '?';
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ ch = 'a';
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ ch = 'v';
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ ch = 'i';
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ ch = 's';
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+
+
+ // adjust sbuffer due to the last call to test.find_match()
+ // but only if we haven't already added enough (20 or more) chars
+ // to fill the lookahead buffer already.
+ if (match_length > static_cast<unsigned int>(12+g*2))
+ sbuffer.rotate_left(match_length-(12+g*2));
+
+
+
+
+
+ // make sure the contents of lookahead_buffer and history_buffer
+ // match what is in sbuffer
+ sbuffer.rotate_right(1);
+ for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i));
+ }
+ for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i));
+ }
+ sbuffer.rotate_left(1);
+
+
+
+
+ test.find_match(match_index,match_length,10+g);
+
+ if (match_length > 0)
+ DLIB_TEST(match_length >= static_cast<unsigned int>(10+g) );
+
+
+ sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0]
+ // verify the match with sbuffer
+ for (unsigned int i = 0; i < match_length; ++i)
+ {
+ DLIB_TEST(sbuffer[i+20-match_length] == sbuffer[i+1+match_index+20-match_length]);
+ }
+ sbuffer.rotate_left(1); // free up sbuffer[0] for new data
+
+ } // for (int g = 0; g < 8; ++g)
+
+
+
+
+
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+ for (int g = 0; g < 200; ++g)
+ {
+ test.clear();
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_limit() == 256-20);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+ sbuf sbuffer;
+ sbuffer.set_size(8);
+
+ ostringstream sout;
+ int l = ::rand()%500;
+ for (int i = 0; i < l; ++i)
+ {
+ char temp = static_cast<char>(::rand()%256);
+ sout << temp;
+ }
+ istringstream sin(sout.str());
+
+ unsigned char ch;
+ for (int i = 0; i < l; ++i)
+ {
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ }
+
+ // make sure the contents of lookahead_buffer and history_buffer
+ // match what is in sbuffer
+ sbuffer.rotate_right(1);
+
+ // adjust so that sbuffer[19] is the same as lookahead_buffer[0]
+ if (test.get_lookahead_buffer_size() < 20)
+ sbuffer.rotate_left(20-test.get_lookahead_buffer_size());
+
+ for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i));
+ }
+ for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i));
+ }
+ sbuffer.rotate_left(1);
+
+
+
+ unsigned long match_index, match_length;
+ unsigned long lookahead_size_before = test.get_lookahead_buffer_size();
+ test.find_match(match_index,match_length,0);
+ DLIB_TEST(match_length <= lookahead_size_before);
+
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == lookahead_size_before-match_length);
+
+ sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0]
+ // verify the match with sbuffer
+ for (unsigned int i = 0; i < match_length; ++i)
+ {
+ DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i);
+ }
+ sbuffer.rotate_left(1); // free up sbuffer[0] for new data
+
+ } // for (int g = 0; g < 200; ++g)
+
+
+
+
+
+
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+ for (int g = 0; g < 300; ++g)
+ {
+ test.clear();
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_size() == 0);
+ DLIB_TEST(test.get_history_buffer_limit() == 256-20);
+ DLIB_TEST(test.get_lookahead_buffer_limit() == 20);
+
+
+ sbuf sbuffer;
+ sbuffer.set_size(8);
+
+ ostringstream sout;
+ int l = ::rand()%500;
+ for (int i = 0; i < l; ++i)
+ {
+ char temp = static_cast<char>(::rand()%20);
+ sout << temp;
+ sout << temp;
+ sout << temp;
+ sout << temp;
+ sout << temp;
+ sout << temp;
+ }
+ istringstream sin(sout.str());
+
+ unsigned char ch;
+ for (int i = 0; i < l; ++i)
+ {
+ ch = sin.get();
+ sbuffer[0] = ch; sbuffer.rotate_left(1);
+ test.add(ch);
+ }
+
+ // make sure the contents of lookahead_buffer and history_buffer
+ // match what is in sbuffer
+ sbuffer.rotate_right(1);
+
+ // adjust so that sbuffer[19] is the same as lookahead_buffer[0]
+ if (test.get_lookahead_buffer_size() < 20)
+ sbuffer.rotate_left(20-test.get_lookahead_buffer_size());
+
+ for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i));
+ }
+ for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i)
+ {
+ DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i));
+ }
+ sbuffer.rotate_left(1);
+
+
+
+ unsigned long match_index = 0, match_length = 0;
+ unsigned long lookahead_size_before = test.get_lookahead_buffer_size();
+ unsigned long history_size_before = test.get_history_buffer_size();
+ test.find_match(match_index,match_length,2);
+
+ if (match_length != 0)
+ {
+ DLIB_TEST_MSG(match_index < history_size_before,
+ "match_index: " << match_index <<
+ "\nhistory_size_before: " << history_size_before);
+
+ }
+
+
+ DLIB_TEST(test.get_lookahead_buffer_size() == lookahead_size_before-match_length);
+
+ sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0]
+ // verify the match with sbuffer
+ for (unsigned int i = 0; i < match_length; ++i)
+ {
+ DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i);
+ }
+ sbuffer.rotate_left(1); // free up sbuffer[0] for new data
+
+
+
+ } // for (int g = 0; g < 300; ++g)
+
+ }
+
+
+
+
+ class lz77_buffer_tester : public tester
+ {
+ public:
+ lz77_buffer_tester (
+ ) :
+ tester ("test_lz77_buffer",
+ "Runs tests on the lz77_buffer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ lz77_buffer_kernel_test<lz77_buffer::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ lz77_buffer_kernel_test<lz77_buffer::kernel_1a_c>();
+ dlog << LINFO << "testing kernel_2a";
+ lz77_buffer_kernel_test<lz77_buffer::kernel_2a> ();
+ dlog << LINFO << "testing kernel_2a_c";
+ lz77_buffer_kernel_test<lz77_buffer::kernel_2a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/main.cpp b/ml/dlib/dlib/test/main.cpp
new file mode 100644
index 000000000..4800a7211
--- /dev/null
+++ b/ml/dlib/dlib/test/main.cpp
@@ -0,0 +1,217 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <iostream>
+#include <fstream>
+#include <dlib/cmd_line_parser.h>
+#include "tester.h"
+#include <dlib/string.h>
+
+
+using namespace std;
+using namespace dlib;
+using namespace test;
+
+typedef cmd_line_parser<char>::check_1a_c clp;
+
+static logger dlog("test.main");
+
+int main (int argc, char** argv)
+{
+ try
+ {
+ clp parser;
+
+ parser.add_option("runall","Run all the tests that don't take any arguments.");
+ parser.add_option("h","Displays this information.");
+ parser.add_option("n","How many times to run the selected tests. The default is 1.",1);
+ parser.add_option("d","log debugging statements to file debug.txt.");
+ parser.add_option("l","Set the logging level (all, trace, debug, info, warn, error, or fatal), the default is all.",1);
+ parser.add_option("a","Append debugging messages to debug.txt rather than clearing the file at program startup.");
+ parser.add_option("q","Be quiet. Don't print the testing progress or results to standard out.");
+
+ unsigned long num = 1;
+
+ // add the options for all the different tests
+ testers().reset();
+ while (testers().move_next())
+ {
+ tester& test = *testers().element().value();
+ parser.add_option(test.cmd_line_switch(), test.description(), test.num_of_args());
+ }
+
+ parser.parse(argc,argv);
+
+ parser.check_option_arg_range("n",1,1000000000);
+ const char* singles[] = {"d","l","a","n","h","runall","q"};
+ parser.check_one_time_options(singles);
+ const char* d_sub[] = {"l","a"};
+ const char* l_args[] = {"all", "trace", "debug", "info", "warn", "error", "fatal"};
+ parser.check_sub_options("d",d_sub);
+ parser.check_option_arg_range("l",l_args);
+
+
+ if (parser.option("n"))
+ {
+ num = string_cast<unsigned long>(parser.option("n").argument());
+ }
+
+ if (parser.option("q"))
+ {
+ be_verbose = false;
+ }
+
+ if (parser.option("h"))
+ {
+ cout << "Usage: test [options]\n";
+ parser.print_options(cout);
+ cout << "\n\n";
+ return 0;
+ }
+
+ ofstream fout;
+ if (parser.option("d"))
+ {
+ if (parser.option("a"))
+ fout.open("debug.txt",ios::app);
+ else
+ fout.open("debug.txt");
+
+ set_all_logging_output_streams(fout);
+
+ if (parser.option("l").count() == 0)
+ set_all_logging_levels(LALL);
+ else if (parser.option("l").argument() == "all")
+ set_all_logging_levels(LALL);
+ else if (parser.option("l").argument() == "trace")
+ set_all_logging_levels(LTRACE);
+ else if (parser.option("l").argument() == "debug")
+ set_all_logging_levels(LDEBUG);
+ else if (parser.option("l").argument() == "info")
+ set_all_logging_levels(LINFO);
+ else if (parser.option("l").argument() == "warn")
+ set_all_logging_levels(LWARN);
+ else if (parser.option("l").argument() == "error")
+ set_all_logging_levels(LERROR);
+ else if (parser.option("l").argument() == "fatal")
+ set_all_logging_levels(LFATAL);
+ }
+ else
+ {
+ set_all_logging_levels(LNONE);
+ }
+
+ unsigned long num_of_failed_tests = 0;
+ unsigned long num_of_passed_tests = 0;
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ dlog << LINFO << "************ Starting Test Run " << i+1 << " of " << num << ". ************";
+
+ // loop over all the testers and see if they are supposed to run
+ testers().reset();
+ while (testers().move_next())
+ {
+ tester& test= *testers().element().value();
+ const clp::option_type& opt = parser.option(test.cmd_line_switch());
+ // run the test for this option as many times as the user has requested.
+ for (unsigned long j = 0; j < parser.option("runall").count() + opt.count(); ++j)
+ {
+ // quit this loop if this option has arguments and this round through the loop is
+ // from the runall option being present.
+ if (test.num_of_args() > 0 && j == opt.count())
+ break;
+
+ if (be_verbose)
+ cout << "Running " << test.cmd_line_switch() << " " << flush;
+
+ dlog << LINFO << "Running " << test.cmd_line_switch();
+ try
+ {
+ switch (test.num_of_args())
+ {
+ case 0:
+ test.perform_test();
+ break;
+ case 1:
+ test.perform_test(opt.argument(0,j));
+ break;
+ case 2:
+ test.perform_test(opt.argument(0,j), opt.argument(1,j));
+ break;
+ default:
+ cerr << "\n\nThe test '" << test.cmd_line_switch() << "' requested " << test.num_of_args()
+ << " arguments but only 2 are supported." << endl;
+ dlog << LINFO << "The test '" << test.cmd_line_switch() << "' requested " << test.num_of_args()
+ << " arguments but only 2 are supported.";
+ break;
+ }
+ if (be_verbose)
+ cout << "\r \r";
+
+ ++num_of_passed_tests;
+
+ }
+ catch (std::exception& e)
+ {
+ if (be_verbose)
+ {
+ cout << "\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n";
+ cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TEST FAILED: " << test.cmd_line_switch()
+ << " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!";
+ cout << "\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n";
+ cout << "Failure message from test: " << e.what() << endl;
+ }
+
+
+ dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!";
+ dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TEST FAILED: " << test.cmd_line_switch()
+ << " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!";
+ dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!";
+ dlog << LERROR << "Failure message from test: " << e.what();
+ ++num_of_failed_tests;
+ }
+ }
+ }
+ }
+ dlog << LINFO << "Testing Finished";
+ if (num_of_passed_tests == 0 && num_of_failed_tests == 0)
+ {
+ cout << "You didn't select any tests to run.\n";
+ cout << "Try the -h option for more information.\n";
+ }
+ else if (num_of_failed_tests == 0)
+ {
+ if (be_verbose)
+ {
+ cout << "\n\nTesting Finished\n";
+ cout << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed() << endl;
+ cout << "All tests completed successfully\n\n";
+ }
+ dlog << LINFO << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed();
+ dlog << LINFO << "All tests completed successfully";
+ }
+ else
+ {
+ if (be_verbose)
+ {
+ cout << "\n\nTesting Finished\n";
+ cout << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed() << endl;
+ cout << "Number of failed tests: " << num_of_failed_tests << "\n";
+ cout << "Number of passed tests: " << num_of_passed_tests << "\n\n";
+ }
+ dlog << LINFO << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed();
+ dlog << LWARN << "Number of failed tests: " << num_of_failed_tests;
+ dlog << LWARN << "Number of passed tests: " << num_of_passed_tests;
+ }
+
+ return num_of_failed_tests;
+ }
+ catch (exception& e)
+ {
+ cout << e.what() << endl;
+ cout << "\nTry the -h option for more information.\n";
+ cout << endl;
+ }
+}
+
diff --git a/ml/dlib/dlib/test/makefile b/ml/dlib/dlib/test/makefile
new file mode 100644
index 000000000..d4d478705
--- /dev/null
+++ b/ml/dlib/dlib/test/makefile
@@ -0,0 +1,185 @@
+# This is the makefile used to build the dlib C++ library's regression test suite
+# on Debian Linux using the gcc compiler.
+
+# this is the name of the output executable
+TARGET = dtest
+
+# these are the compile time flags passed to gcc
+CXXFLAGS ?= -ggdb -Wall
+CPPFLAGS ?= -std=c++11 -DDEBUG -DDLIB_NO_GUI_SUPPORT -I../..
+
+# These are the link time flags passed to gcc
+LFLAGS = -lpthread -lnsl
+
+# The name of the compiler. If you only have one version of
+# gcc installed then you probably want to change this to just g++
+CXX ?= nice g++
+
+####################################################
+####################################################
+# Here we list all the cpp files we want to compile
+
+SRC = main.cpp
+SRC += tester.cpp
+SRC += ../all/source.cpp
+
+SRC += example.cpp
+SRC += example_args.cpp
+
+SRC += active_learning.cpp
+SRC += any.cpp
+SRC += any_function.cpp
+SRC += array2d.cpp
+SRC += array.cpp
+SRC += assignment_learning.cpp
+SRC += base64.cpp
+SRC += bayes_nets.cpp
+SRC += bigint.cpp
+SRC += binary_search_tree_kernel_1a.cpp
+SRC += binary_search_tree_kernel_2a.cpp
+SRC += binary_search_tree_mm1.cpp
+SRC += binary_search_tree_mm2.cpp
+SRC += bridge.cpp
+SRC += bsp.cpp
+SRC += byte_orderer.cpp
+SRC += cca.cpp
+SRC += clustering.cpp
+SRC += cmd_line_parser.cpp
+SRC += cmd_line_parser_wchar_t.cpp
+SRC += compress_stream.cpp
+SRC += conditioning_class_c.cpp
+SRC += conditioning_class.cpp
+SRC += config_reader.cpp
+SRC += crc32.cpp
+SRC += create_iris_datafile.cpp
+SRC += data_io.cpp
+SRC += directed_graph.cpp
+SRC += discriminant_pca.cpp
+SRC += disjoint_subsets.cpp
+SRC += ekm_and_lisf.cpp
+SRC += empirical_kernel_map.cpp
+SRC += entropy_coder.cpp
+SRC += entropy_encoder_model.cpp
+SRC += face.cpp
+SRC += fft.cpp
+SRC += fhog.cpp
+SRC += filtering.cpp
+SRC += find_max_factor_graph_nmplp.cpp
+SRC += find_max_factor_graph_viterbi.cpp
+SRC += geometry.cpp
+SRC += graph.cpp
+SRC += graph_cuts.cpp
+SRC += graph_labeler.cpp
+SRC += hash.cpp
+SRC += hash_map.cpp
+SRC += hash_set.cpp
+SRC += hash_table.cpp
+SRC += hog_image.cpp
+SRC += image.cpp
+SRC += iosockstream.cpp
+SRC += is_same_object.cpp
+SRC += kcentroid.cpp
+SRC += kernel_matrix.cpp
+SRC += kmeans.cpp
+SRC += learning_to_track.cpp
+SRC += least_squares.cpp
+SRC += linear_manifold_regularizer.cpp
+SRC += lspi.cpp
+SRC += lz77_buffer.cpp
+SRC += map.cpp
+SRC += matrix2.cpp
+SRC += matrix3.cpp
+SRC += matrix4.cpp
+SRC += matrix_chol.cpp
+SRC += matrix.cpp
+SRC += matrix_eig.cpp
+SRC += matrix_lu.cpp
+SRC += matrix_qr.cpp
+SRC += max_cost_assignment.cpp
+SRC += max_sum_submatrix.cpp
+SRC += md5.cpp
+SRC += member_function_pointer.cpp
+SRC += metaprogramming.cpp
+SRC += mpc.cpp
+SRC += multithreaded_object.cpp
+SRC += numerical_integration.cpp
+SRC += object_detector.cpp
+SRC += oca.cpp
+SRC += one_vs_all_trainer.cpp
+SRC += one_vs_one_trainer.cpp
+SRC += optimization.cpp
+SRC += optimization_test_functions.cpp
+SRC += opt_qp_solver.cpp
+SRC += parallel_for.cpp
+SRC += parse.cpp
+SRC += pipe.cpp
+SRC += pixel.cpp
+SRC += probabilistic.cpp
+SRC += pyramid_down.cpp
+SRC += queue.cpp
+SRC += rand.cpp
+SRC += ranking.cpp
+SRC += read_write_mutex.cpp
+SRC += reference_counter.cpp
+SRC += rls.cpp
+SRC += sammon.cpp
+SRC += scan_image.cpp
+SRC += sequence.cpp
+SRC += sequence_labeler.cpp
+SRC += sequence_segmenter.cpp
+SRC += serialize.cpp
+SRC += set.cpp
+SRC += sldf.cpp
+SRC += sliding_buffer.cpp
+SRC += sockets2.cpp
+SRC += sockets.cpp
+SRC += sockstreambuf.cpp
+SRC += sparse_vector.cpp
+SRC += stack.cpp
+SRC += static_map.cpp
+SRC += static_set.cpp
+SRC += statistics.cpp
+SRC += std_vector_c.cpp
+SRC += string.cpp
+SRC += svm_c_linear.cpp
+SRC += svm_c_linear_dcd.cpp
+SRC += svm.cpp
+SRC += svm_multiclass_linear.cpp
+SRC += svm_struct.cpp
+SRC += svr_linear_trainer.cpp
+SRC += symmetric_matrix_cache.cpp
+SRC += thread_pool.cpp
+SRC += threads.cpp
+SRC += timer.cpp
+SRC += tokenizer.cpp
+SRC += trust_region.cpp
+SRC += tuple.cpp
+SRC += type_safe_union.cpp
+SRC += vectorstream.cpp
+
+
+####################################################
+
+TMP = $(SRC:.cpp=.o)
+OBJ = $(TMP:.c=.o)
+
+$(TARGET): $(OBJ)
+ @echo Linking $@
+ @$(CXX) $(LDFLAGS) $(OBJ) $(LFLAGS) -o $@
+ @echo Build Complete
+
+clean:
+ @rm -f $(OBJ) $(TARGET)
+ @echo All object files and binaries removed
+
+dep:
+ @echo Running makedepend
+ @makedepend -- $(CFLAGS) -- $(SRC) 2> /dev/null
+ @echo Completed makedepend
+
+###############################################################################
+########## Stuff from makedepend #####
+########## type make dep at the command line to rebuild the dependencies #####
+########## Also, DON'T edit the contents of this file beyond this line. #####
+###############################################################################
+
diff --git a/ml/dlib/dlib/test/map.cpp b/ml/dlib/dlib/test/map.cpp
new file mode 100644
index 000000000..6901ddf05
--- /dev/null
+++ b/ml/dlib/dlib/test/map.cpp
@@ -0,0 +1,441 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/map.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.map");
+
+ template <
+ typename map
+ >
+ void map_kernel_test (
+ )
+ /*!
+ requires
+ - map is an implementation of map/map_kernel_abstract.h and
+ is instantiated to map int to int
+ ensures
+ - runs tests on map for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+
+ map test, test2;
+
+ enumerable<map_pair<int,int> >& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ for (int j = 0; j < 4; ++j)
+ {
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+
+
+ int a,b;
+ a = 8;
+ b = 94;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(test.is_in_domain(8) == true);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test[8] == 94);
+ a = 53;
+ b = 4;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_in_domain(53) == true);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test[53] == 4);
+
+
+ swap(test,test2);
+
+
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(test2.is_in_domain(8) == true);
+ DLIB_TEST(test2.is_in_domain(5) == false);
+ DLIB_TEST(test2.is_in_domain(0) == false);
+ DLIB_TEST(test2.is_in_domain(-999) == false);
+ DLIB_TEST(test2.is_in_domain(4999) == false);
+ DLIB_TEST(test2[8] == 94);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(test2.is_in_domain(53) == true);
+ DLIB_TEST(test2.is_in_domain(5) == false);
+ DLIB_TEST(test2.is_in_domain(0) == false);
+ DLIB_TEST(test2.is_in_domain(-999) == false);
+ DLIB_TEST(test2.is_in_domain(4999) == false);
+ DLIB_TEST(test2[53] == 4);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(8) == false);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_in_domain(53) == false);
+ DLIB_TEST(test.is_in_domain(5) == false);
+ DLIB_TEST(test.is_in_domain(0) == false);
+ DLIB_TEST(test.is_in_domain(-999) == false);
+ DLIB_TEST(test.is_in_domain(4999) == false);
+
+
+ test.clear();
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ int count = 0;
+ a = -1;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+
+
+ DLIB_TEST(a < test.element().key());
+ a = test.element().key();
+ ++count;
+ }
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(count == 10000);
+
+ test.swap(test2);
+
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test2.size() == 10000);
+ count = 0;
+ a = -1;
+ test2.reset();
+
+ test2.move_next();
+ test2.element().value() = 99;
+ DLIB_TEST(test2[test2.element().key()] == 99);
+ DLIB_TEST(test2.element().value() == 99);
+
+ test2.reset();
+
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2[test2.element().key()] == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+ DLIB_TEST(a < test2.element().key());
+ a = test2.element().key();
+ ++count;
+ }
+ DLIB_TEST(test2.size() == 10000);
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+ test2.clear();
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.at_start() == true);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+
+ DLIB_TEST(test.at_start() == true);
+
+ {
+ int* array1 = new int[test.size()];
+ int* array2 = new int[test.size()];
+
+ int* tmp1 = array1;
+ int* tmp2 = array2;
+
+ count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.element().value() == test.element().value());
+ DLIB_TEST(test.element().key() == test.element().key());
+ DLIB_TEST(test.current_element_valid() == true);
+ *tmp1 = test.element().key();
+ *tmp2 = test.element().value();
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+
+ tmp1 = array1;
+ tmp2 = array2;
+ for (int i = 0; i < 20000; ++i)
+ {
+ DLIB_TEST(test.is_in_domain(*tmp1) == true);
+ DLIB_TEST(test[*tmp1] == *tmp2);
+ ++tmp1;
+ ++tmp2;
+ }
+
+ DLIB_TEST(test.size() == 20000);
+
+ tmp1 = array1;
+ tmp2 = array2;
+ count = 0;
+ while (test.size() > 10000)
+ {
+ test.remove(*tmp1,a,b);
+ DLIB_TEST(*tmp1 == a);
+ DLIB_TEST(*tmp2 == b);
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().key() == *tmp1);
+ DLIB_TEST(test.element().value() == *tmp2);
+ DLIB_TEST(test.element().value() == *tmp2);
+ DLIB_TEST(test.element().value() == *tmp2);
+ ++tmp1;
+ ++tmp2;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ test2.swap(test);
+
+ count = 0;
+ a = -1;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(test2.element().value() == test2.element().value());
+ DLIB_TEST(test2.element().key() == test2.element().key());
+ DLIB_TEST(a < test2.element().key());
+ a = test2.element().key();
+ ++count;
+ }
+
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test2.size() == 20000);
+
+ a = -1;
+ int c = 0;
+ while (test2.size()>0)
+ {
+ test2.remove_any(b,c);
+ DLIB_TEST( a < b);
+ a = b;
+ }
+
+ DLIB_TEST(test2.size() == 0);
+ delete [] array1;
+ delete [] array2;
+ }
+
+ test.clear();
+ test2.clear();
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ b = ::rand();
+ if (!test.is_in_domain(a))
+ test.add(a,b);
+ }
+
+ count = 0;
+ a = -1;
+ while (test.move_next())
+ {
+ DLIB_TEST(a < test.element().key());
+ DLIB_TEST(test[test.element().key()] == test.element().value());
+ a = test.element().key();
+ ++count;
+ if (count == 5000)
+ break;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ test.reset();
+
+ count = 0;
+ a = -1;
+ while (test.move_next())
+ {
+ DLIB_TEST(a < test.element().key());
+ a = test.element().key();
+ ++count;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ DLIB_TEST(count == 10000);
+
+
+ test.clear();
+ test2.clear();
+ }
+
+
+ {
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+ int a = 5;
+ int b = 6;
+ test.add(a,b);
+ a = 7;
+ b = 8;
+ test.add(a,b);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test[7] == 8);
+ DLIB_TEST(test[5] == 6);
+ DLIB_TEST(test.is_in_domain(7));
+ DLIB_TEST(test.is_in_domain(5));
+ test.destroy(7);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(!test.is_in_domain(7));
+ DLIB_TEST(test.is_in_domain(5));
+ test.destroy(5);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(!test.is_in_domain(7));
+ DLIB_TEST(!test.is_in_domain(5));
+ }
+
+ }
+
+
+
+
+ class map_tester : public tester
+ {
+ public:
+ map_tester (
+ ) :
+ tester ("test_map",
+ "Runs tests on the map component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ map_kernel_test<dlib::map<int,int>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ map_kernel_test<dlib::map<int,int>::kernel_1a_c>();
+ dlog << LINFO << "testing kernel_1b";
+ map_kernel_test<dlib::map<int,int>::kernel_1b> ();
+ dlog << LINFO << "testing kernel_1b_c";
+ map_kernel_test<dlib::map<int,int>::kernel_1b_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/matrix.cpp b/ml/dlib/dlib/test/matrix.cpp
new file mode 100644
index 000000000..0a3ea5996
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix.cpp
@@ -0,0 +1,1519 @@
+
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+
+#include "tester.h"
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/array2d.h>
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ dlib::rand rnd;
+
+ logger dlog("test.matrix");
+
+ template <typename type>
+ const matrix<type> rand_sp_banded(long n, long bw)
+ {
+ matrix<type> m = 10 * identity_matrix<type>(n);
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = row; col < min(m.nc(), row + bw); ++col)
+ {
+ type r = rnd.get_random_double();
+ m(row,col) += r;
+ m(col,row) += r;
+ }
+ }
+
+ return m;
+ }
+
+ void matrix_test (
+ )
+ /*!
+ ensures
+ - runs tests on the matrix stuff compliance with the specs
+ !*/
+ {
+ typedef memory_manager_stateless<char>::kernel_2_2a MM;
+ print_spinner();
+
+
+ {
+ matrix<complex<double>,2,2,MM> m;
+ set_all_elements(m,complex<double>(1,2));
+ DLIB_TEST((conj(m) == uniform_matrix<complex<double>,2,2>(conj(m(0,0)))));
+ DLIB_TEST((real(m) == uniform_matrix<double,2,2>(1)));
+ DLIB_TEST((imag(m) == uniform_matrix<double,2,2>(2)));
+ DLIB_TEST_MSG((sum(abs(norm(m) - uniform_matrix<double,2,2>(5))) < 1e-10 ),norm(m));
+
+ }
+
+ {
+ matrix<double,5,5,MM,column_major_layout> m(5,5);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
+
+ mi = pinv(m,1e-12);
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
+
+ m = diagm(diag(m));
+ mi = pinv(diagm(diag(m)),1e-12);
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
+
+ mi = pinv(m,0);
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
+
+ m = diagm(diag(m));
+ mi = pinv(diagm(diag(m)),0);
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
+ DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
+ }
+ {
+ matrix<double,5,0,MM> m(5,5);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ }
+
+ {
+ matrix<double,0,5,MM> m(5,5);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ }
+
+
+ {
+ matrix<double> m(5,5);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
+ }
+
+ {
+ matrix<double,5,2,MM,column_major_layout> m;
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())));
+ }
+
+ {
+ matrix<double> m(5,2);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = pinv(m );
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())));
+ }
+
+ {
+ matrix<double,5,2,MM,column_major_layout> m;
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = trans(pinv(trans(m) ));
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())));
+ }
+
+ {
+ matrix<double> m(5,2);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ m = cos(exp(m));
+
+
+ matrix<double> mi = trans(pinv(trans(m) ));
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())));
+ }
+
+ {
+ matrix<long> a1(5,1);
+ matrix<long,0,0,MM,column_major_layout> a2(1,5);
+ matrix<long,5,1> b1(5,1);
+ matrix<long,1,5> b2(1,5);
+ matrix<long,0,1> c1(5,1);
+ matrix<long,1,0> c2(1,5);
+ matrix<long,0,1,MM,column_major_layout> d1(5,1);
+ matrix<long,1,0,MM> d2(1,5);
+
+ for (long i = 0; i < 5; ++i)
+ {
+ a1(i) = i;
+ a2(i) = i;
+ b1(i) = i;
+ b2(i) = i;
+ c1(i) = i;
+ c2(i) = i;
+ d1(i) = i;
+ d2(i) = i;
+ }
+
+ DLIB_TEST(a1 == trans(a2));
+ DLIB_TEST(a1 == trans(b2));
+ DLIB_TEST(a1 == trans(c2));
+ DLIB_TEST(a1 == trans(d2));
+
+ DLIB_TEST(a1 == b1);
+ DLIB_TEST(a1 == c1);
+ DLIB_TEST(a1 == d1);
+
+ DLIB_TEST(trans(a1) == c2);
+ DLIB_TEST(trans(b1) == c2);
+ DLIB_TEST(trans(c1) == c2);
+ DLIB_TEST(trans(d1) == c2);
+
+ DLIB_TEST(sum(a1) == 10);
+ DLIB_TEST(sum(a2) == 10);
+ DLIB_TEST(sum(b1) == 10);
+ DLIB_TEST(sum(b2) == 10);
+ DLIB_TEST(sum(c1) == 10);
+ DLIB_TEST(sum(c2) == 10);
+ DLIB_TEST(sum(d1) == 10);
+ DLIB_TEST(sum(d2) == 10);
+
+ const matrix<long> orig1 = a1;
+ const matrix<long> orig2 = a2;
+
+ ostringstream sout;
+ serialize(a1,sout);
+ serialize(a2,sout);
+ serialize(b1,sout);
+ serialize(b2,sout);
+ serialize(c1,sout);
+ serialize(c2,sout);
+ serialize(d1,sout);
+ serialize(d2,sout);
+
+ DLIB_TEST(a1 == orig1);
+ DLIB_TEST(a2 == orig2);
+ DLIB_TEST(b1 == orig1);
+ DLIB_TEST(b2 == orig2);
+ DLIB_TEST(c1 == orig1);
+ DLIB_TEST(c2 == orig2);
+ DLIB_TEST(d1 == orig1);
+ DLIB_TEST(d2 == orig2);
+
+ set_all_elements(a1,99);
+ set_all_elements(a2,99);
+ set_all_elements(b1,99);
+ set_all_elements(b2,99);
+ set_all_elements(c1,99);
+ set_all_elements(c2,99);
+ set_all_elements(d1,99);
+ set_all_elements(d2,99);
+
+ DLIB_TEST(a1 != orig1);
+ DLIB_TEST(a2 != orig2);
+ DLIB_TEST(b1 != orig1);
+ DLIB_TEST(b2 != orig2);
+ DLIB_TEST(c1 != orig1);
+ DLIB_TEST(c2 != orig2);
+ DLIB_TEST(d1 != orig1);
+ DLIB_TEST(d2 != orig2);
+
+ istringstream sin(sout.str());
+
+ deserialize(a1,sin);
+ deserialize(a2,sin);
+ deserialize(b1,sin);
+ deserialize(b2,sin);
+ deserialize(c1,sin);
+ deserialize(c2,sin);
+ deserialize(d1,sin);
+ deserialize(d2,sin);
+
+ DLIB_TEST(a1 == orig1);
+ DLIB_TEST(a2 == orig2);
+ DLIB_TEST(b1 == orig1);
+ DLIB_TEST(b2 == orig2);
+ DLIB_TEST(c1 == orig1);
+ DLIB_TEST(c2 == orig2);
+ DLIB_TEST(d1 == orig1);
+ DLIB_TEST(d2 == orig2);
+
+
+ }
+
+ {
+ matrix<double,1,0> a(5);
+ matrix<double,0,1> b(5);
+ matrix<double,1,5> c(5);
+ matrix<double,5,1> d(5);
+ DLIB_TEST(a.nr() == 1);
+ DLIB_TEST(a.nc() == 5);
+ DLIB_TEST(c.nr() == 1);
+ DLIB_TEST(c.nc() == 5);
+
+ DLIB_TEST(b.nc() == 1);
+ DLIB_TEST(b.nr() == 5);
+ DLIB_TEST(d.nc() == 1);
+ DLIB_TEST(d.nr() == 5);
+ }
+
+ {
+ matrix<double,1,0> a;
+ matrix<double,0,1> b;
+ matrix<double,1,5> c;
+ matrix<double,5,1> d;
+
+ a.set_size(5);
+ b.set_size(5);
+ c.set_size(5);
+ d.set_size(5);
+
+ DLIB_TEST(a.nr() == 1);
+ DLIB_TEST(a.nc() == 5);
+ DLIB_TEST(c.nr() == 1);
+ DLIB_TEST(c.nc() == 5);
+
+ DLIB_TEST(b.nc() == 1);
+ DLIB_TEST(b.nr() == 5);
+ DLIB_TEST(d.nc() == 1);
+ DLIB_TEST(d.nr() == 5);
+ }
+
+ {
+ matrix<double> a(1,5);
+ matrix<double> b(5,1);
+
+ set_all_elements(a,1);
+ set_all_elements(b,1);
+
+
+ a = a*b;
+
+ DLIB_TEST(a(0) == 5);
+ }
+
+ {
+ matrix<double,0,0,MM,column_major_layout> a(6,7);
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = r*a.nc() + c;
+ }
+ }
+
+
+
+ DLIB_TEST(rowm(a,1).nr() == 1);
+ DLIB_TEST(rowm(a,1).nc() == a.nc());
+ DLIB_TEST(colm(a,1).nr() == a.nr());
+ DLIB_TEST(colm(a,1).nc() == 1);
+
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ DLIB_TEST( rowm(a,1)(c) == 1*a.nc() + c);
+ }
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ DLIB_TEST( colm(a,1)(r) == r*a.nc() + 1);
+ }
+
+ rectangle rect(2, 1, 3+2-1, 2+1-1);
+ DLIB_TEST(get_rect(a).contains(get_rect(a)));
+ DLIB_TEST(get_rect(a).contains(rect));
+ for (long r = 0; r < 2; ++r)
+ {
+ for (long c = 0; c < 3; ++c)
+ {
+ DLIB_TEST(subm(a,1,2,2,3)(r,c) == (r+1)*a.nc() + c+2);
+ DLIB_TEST(subm(a,1,2,2,3) == subm(a,rect));
+ DLIB_TEST(subm_clipped(a,1,2,2,3) == subm(a,rect));
+ DLIB_TEST(subm_clipped(a,1,2,2,3) == subm_clipped(a,rect));
+ }
+ }
+
+ DLIB_TEST(subm(a,rectangle()).nr() == 0);
+ DLIB_TEST(subm(a,rectangle()).nc() == 0);
+
+ }
+
+ {
+ array2d<double> a;
+ a.set_size(6,7);
+
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a[r][c] = r*a.nc() + c;
+ }
+ }
+
+
+
+ DLIB_TEST(rowm(mat(a),1).nr() == 1);
+ DLIB_TEST(rowm(mat(a),1).nc() == a.nc());
+ DLIB_TEST(colm(mat(a),1).nr() == a.nr());
+ DLIB_TEST(colm(mat(a),1).nc() == 1);
+
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ DLIB_TEST( rowm(mat(a),1)(c) == 1*a.nc() + c);
+ }
+
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ DLIB_TEST( colm(mat(a),1)(r) == r*a.nc() + 1);
+ }
+
+ for (long r = 0; r < 2; ++r)
+ {
+ for (long c = 0; c < 3; ++c)
+ {
+ DLIB_TEST(subm(mat(a),1,2,2,3)(r,c) == (r+1)*a.nc() + c+2);
+ }
+ }
+
+
+ }
+
+ {
+ array2d<double> m;
+ m.set_size(5,5);
+
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m[r][c] = r*c;
+ }
+ }
+
+
+ matrix<double> mi = pinv(cos(exp(mat(m))) );
+ DLIB_TEST(mi.nr() == m.nc());
+ DLIB_TEST(mi.nc() == m.nr());
+ DLIB_TEST((equal(round_zeros(mi*cos(exp(mat(m))),0.000001) , identity_matrix<double,5>())));
+ DLIB_TEST((equal(round_zeros(cos(exp(mat(m)))*mi,0.000001) , identity_matrix<double,5>())));
+ }
+
+ {
+ matrix<long,5,5,MM,column_major_layout> m1, res;
+ matrix<long,2,2> m2;
+
+ set_all_elements(m1,0);
+
+
+ long res_vals[] = {
+ 9, 9, 9, 9, 9,
+ 0, 1, 1, 0, 0,
+ 0, 1, 1, 0, 2,
+ 0, 0, 2, 2, 2,
+ 0, 0, 2, 2, 0
+ };
+
+ res = res_vals;
+
+ set_all_elements(m2, 1);
+ set_subm(m1, range(1,2), range(1,2)) = subm(m2,0,0,2,2);
+ set_all_elements(m2, 2);
+ set_subm(m1, 3,2,2,2) = m2;
+
+ set_colm(m1,4) = trans(rowm(m1,4));
+ set_rowm(m1,0) = 9;
+
+ DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res);
+
+ set_subm(m1,0,0,5,5) = m1*m1;
+ DLIB_TEST_MSG(m1 == res*res, "m1: \n" << m1 << "\nres*res: \n" << res*res);
+
+ m1 = res;
+ set_subm(m1,1,1,2,2) = subm(m1,0,0,2,2);
+
+ long res_vals2[] = {
+ 9, 9, 9, 9, 9,
+ 0, 9, 9, 0, 0,
+ 0, 0, 1, 0, 2,
+ 0, 0, 2, 2, 2,
+ 0, 0, 2, 2, 0
+ };
+
+ res = res_vals2;
+ DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res);
+
+
+ }
+
+ {
+ matrix<long,5,5> m1, res;
+ matrix<long,2,2> m2;
+
+ set_all_elements(m1,0);
+
+
+ long res_vals[] = {
+ 9, 9, 9, 9, 9,
+ 0, 1, 1, 0, 0,
+ 0, 1, 1, 0, 2,
+ 0, 0, 2, 2, 2,
+ 0, 0, 2, 2, 0
+ };
+
+ res = res_vals;
+
+ set_all_elements(m2, 1);
+ set_subm(m1, rectangle(1,1,2,2)) = subm(m2,0,0,2,2);
+ set_all_elements(m2, 2);
+ set_subm(m1, 3,2,2,2) = m2;
+
+ set_colm(m1,4) = trans(rowm(m1,4));
+ set_rowm(m1,0) = 9;
+
+ DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res);
+
+ set_subm(m1,0,0,5,5) = m1*m1;
+ DLIB_TEST_MSG(m1 == res*res, "m1: \n" << m1 << "\nres*res: \n" << res*res);
+
+ m1 = res;
+ set_subm(m1,1,1,2,2) = subm(m1,0,0,2,2);
+
+ long res_vals2[] = {
+ 9, 9, 9, 9, 9,
+ 0, 9, 9, 0, 0,
+ 0, 0, 1, 0, 2,
+ 0, 0, 2, 2, 2,
+ 0, 0, 2, 2, 0
+ };
+
+ res = res_vals2;
+ DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res);
+
+
+ }
+
+ {
+ matrix<long,5,5> m1, res;
+ matrix<long,2,2> m2;
+
+ set_all_elements(m1,0);
+
+
+ long res_vals[] = {
+ 9, 0, 3, 3, 0,
+ 9, 2, 2, 2, 0,
+ 9, 2, 2, 2, 0,
+ 4, 4, 4, 4, 4,
+ 9, 0, 3, 3, 0
+ };
+ long res_vals_c3[] = {
+ 9, 0, 3, 0,
+ 9, 2, 2, 0,
+ 9, 2, 2, 0,
+ 4, 4, 4, 4,
+ 9, 0, 3, 0
+ };
+ long res_vals_r2[] = {
+ 9, 0, 3, 3, 0,
+ 9, 2, 2, 2, 0,
+ 4, 4, 4, 4, 4,
+ 9, 0, 3, 3, 0
+ };
+
+ matrix<long> temp;
+
+ res = res_vals;
+
+ temp = matrix<long,4,5>(res_vals_r2);
+ DLIB_TEST(remove_row<2>(res) == temp);
+ DLIB_TEST(remove_row<2>(res)(3,3) == 3);
+ DLIB_TEST(remove_row<2>(res).nr() == 4);
+ DLIB_TEST(remove_row<2>(res).nc() == 5);
+ DLIB_TEST(remove_row(res,2) == temp);
+ DLIB_TEST(remove_row(res,2)(3,3) == 3);
+ DLIB_TEST(remove_row(res,2).nr() == 4);
+ DLIB_TEST(remove_row(res,2).nc() == 5);
+
+ temp = matrix<long,5,5>(res_vals);
+ temp = remove_row(res,2);
+ DLIB_TEST((temp == matrix<long,4,5>(res_vals_r2)));
+ temp = matrix<long,5,5>(res_vals);
+ temp = remove_col(res,3);
+ DLIB_TEST((temp == matrix<long,5,4>(res_vals_c3)));
+
+ matrix<long,3,1> vect;
+ set_all_elements(vect,1);
+ temp = identity_matrix<long>(3);
+ temp = temp*vect;
+ DLIB_TEST(temp == vect);
+
+ temp = matrix<long,5,4>(res_vals_c3);
+ DLIB_TEST(remove_col(res,3) == temp);
+ DLIB_TEST(remove_col(res,3)(2,3) == 0);
+ DLIB_TEST(remove_col(res,3).nr() == 5);
+ DLIB_TEST(remove_col(res,3).nc() == 4);
+
+ set_all_elements(m2, 1);
+ set_subm(m1, rectangle(1,1,3,2)) = 2;
+ set_all_elements(m2, 2);
+ set_subm(m1, 3,2,2,2) = 3;
+
+ set_colm(m1,0) = 9;
+ set_rowm(m1,0) = rowm(m1,4);
+ set_rowm(m1,3) = 4;
+
+ DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res);
+
+ }
+
+ }
+
+
+ void matrix_test2()
+ {
+ print_spinner();
+
+
+ {
+
+ const double stuff[] = {
+ 1, 2, 3,
+ 6, 3, 3,
+ 7, 3, 9};
+
+ matrix<double,3,3> m(stuff);
+
+ // make m be symmetric
+ m = m*trans(m);
+
+ matrix<double> L = chol(m);
+ DLIB_TEST(equal(L*trans(L), m));
+
+ DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10), identity_matrix<double>(3), 1e-10));
+ DLIB_TEST(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix<double>(3),1e-10));
+
+ }
+
+ {
+
+ const double stuff[] = {
+ 1, 2, 3, 6, 3, 4,
+ 6, 3, 3, 1, 2, 3,
+ 7, 3, 9, 54.3, 5, 3,
+ -6, 3, -3, 1, 2, 3,
+ 1, 2, 3, 5, -3, 3,
+ 7, 3, -9, 54.3, 5, 3
+ };
+
+ matrix<double,6,6> m(stuff);
+
+ // make m be symmetric
+ m = m*trans(m);
+
+ matrix<double> L = chol(m);
+ DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m);
+
+ DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), "");
+ DLIB_TEST_MSG(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10) , identity_matrix<double>(6), 1e-10),
+ round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10));
+ DLIB_TEST_MSG(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix<double>(6), 1e-10),
+ round_zeros(inv_lower_triangular((L))*(L),1e-10));
+
+ }
+
+ {
+ // Test band chol
+ matrix<double> m = rand_sp_banded<double>(10, 3);
+
+ matrix<double> L = chol(m);
+ DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m);
+ DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), "");
+ }
+
+ {
+ // Test band chol in column major layout
+ matrix<double,10,10,default_memory_manager,column_major_layout> m(rand_sp_banded<double>(10, 3));
+
+ matrix<double> L = chol(m);
+ DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m);
+ DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), "");
+ }
+
+ {
+ matrix<int> m(3,4), m2;
+ m = 1,2,3,4,
+ 4,5,6,6,
+ 6,1,8,0;
+ m2 = m;
+ DLIB_TEST(round(m) == m2);
+ DLIB_TEST(round_zeros(m) == m2);
+
+ m2 = 0,2,3,4,
+ 4,5,6,6,
+ 6,0,8,0;
+
+ DLIB_TEST(round_zeros(m,2) == m2);
+ }
+
+
+ {
+
+ matrix<double,6,6> m(identity_matrix<double>(6)*4.5);
+
+ matrix<double> L = chol(m);
+ DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m);
+
+ DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), "");
+ DLIB_TEST_MSG(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10) , identity_matrix<double>(6), 1e-10),
+ round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10));
+ DLIB_TEST_MSG(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix<double>(6), 1e-10),
+ round_zeros(inv_lower_triangular((L))*(L),1e-10));
+
+ }
+
+ {
+
+ matrix<double,6,6> m(identity_matrix<double>(6)*4.5);
+ m(1,4) = 2;
+
+ DLIB_TEST_MSG(dlib::equal(inv_upper_triangular(m), inv(m),1e-10), inv_upper_triangular(m)-inv(m));
+ DLIB_TEST_MSG(dlib::equal(inv_lower_triangular(trans(m)), inv(trans(m)),1e-10), inv_lower_triangular(trans(m))-inv(trans(m)));
+
+ }
+
+ {
+ matrix<double> a;
+ matrix<float> b;
+ matrix<int> i;
+ a.set_size(1000,10);
+ b.set_size(1000,10);
+ i.set_size(1000,10);
+ dlib::rand rnd;
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = rnd.get_random_double();
+ b(r,c) = rnd.get_random_float();
+ i(r,c) = r+c*r;
+ }
+ }
+
+ // make sure the multiply optimizations aren't messing things up
+ DLIB_TEST(trans(i)*i == tmp(trans(i)*i));
+ DLIB_TEST_MSG(equal(trans(a)*a , tmp(trans(a)*a), 1e-11),max(abs(trans(a)*a - tmp(trans(a)*a))));
+ DLIB_TEST_MSG(equal(trans(b)*b , tmp(trans(b)*b), 1e-3f),max(abs(trans(b)*b - tmp(trans(b)*b))));
+ }
+
+ {
+ matrix<int,4> i(4,1);
+ i(0) = 1;
+ i(1) = 2;
+ i(2) = 3;
+ i(3) = 4;
+ matrix<int,4,4> m;
+ set_all_elements(m,0);
+ m(0,0) = 1;
+ m(1,1) = 2;
+ m(2,2) = 3;
+ m(3,3) = 4;
+
+ DLIB_TEST(diagm(i) == m);
+ }
+
+ {
+ matrix<int,1,4> i;
+ i(0) = 1;
+ i(1) = 2;
+ i(2) = 3;
+ i(3) = 4;
+ matrix<int,4,4> m;
+ set_all_elements(m,0);
+ m(0,0) = 1;
+ m(1,1) = 2;
+ m(2,2) = 3;
+ m(3,3) = 4;
+
+ DLIB_TEST(diagm(i) == m);
+ }
+
+ {
+ matrix<int> i(4,1);
+ i(0) = 1;
+ i(1) = 2;
+ i(2) = 3;
+ i(3) = 4;
+ matrix<int> m(4,4);
+ set_all_elements(m,0);
+ m(0,0) = 1;
+ m(1,1) = 2;
+ m(2,2) = 3;
+ m(3,3) = 4;
+
+ DLIB_TEST(diagm(i) == m);
+ }
+
+ {
+ matrix<int> i(1,4);
+ i(0) = 1;
+ i(1) = 2;
+ i(2) = 3;
+ i(3) = 4;
+ matrix<int> m(4,4);
+ set_all_elements(m,0);
+ m(0,0) = 1;
+ m(1,1) = 2;
+ m(2,2) = 3;
+ m(3,3) = 4;
+
+ DLIB_TEST(diagm(i) == m);
+ }
+
+ {
+ DLIB_TEST(range(0,5).nc() == 6);
+ DLIB_TEST(range(1,5).nc() == 5);
+ DLIB_TEST(range(0,5).nr() == 1);
+ DLIB_TEST(range(1,5).nr() == 1);
+ DLIB_TEST(trans(range(0,5)).nr() == 6);
+ DLIB_TEST(trans(range(1,5)).nr() == 5);
+ DLIB_TEST(trans(range(0,5)).nc() == 1);
+ DLIB_TEST(trans(range(1,5)).nc() == 1);
+
+ DLIB_TEST(range(0,2,5).nc() == 3);
+ DLIB_TEST(range(1,2,5).nc() == 3);
+ DLIB_TEST(range(0,2,5).nr() == 1);
+ DLIB_TEST(range(1,2,5).nr() == 1);
+ DLIB_TEST(trans(range(0,2,5)).nr() == 3);
+ DLIB_TEST(trans(range(1,2,5)).nr() == 3);
+ DLIB_TEST(trans(range(0,2,5)).nc() == 1);
+ DLIB_TEST(trans(range(1,2,5)).nc() == 1);
+
+ DLIB_TEST(range(0,3,6).nc() == 3);
+ DLIB_TEST(range(1,3,5).nc() == 2);
+ DLIB_TEST(range(0,3,5).nr() == 1);
+ DLIB_TEST(range(1,3,5).nr() == 1);
+ DLIB_TEST(trans(range(0,3,6)).nr() == 3);
+ DLIB_TEST(trans(range(1,3,5)).nr() == 2);
+ DLIB_TEST(trans(range(0,3,5)).nc() == 1);
+ DLIB_TEST(trans(range(1,3,5)).nc() == 1);
+
+ DLIB_TEST(range(1,9,5).nc() == 1);
+ DLIB_TEST(range(1,9,5).nr() == 1);
+
+ DLIB_TEST(range(0,0).nc() == 1);
+ DLIB_TEST(range(0,0).nr() == 1);
+
+ DLIB_TEST(range(1,1)(0) == 1);
+
+ DLIB_TEST(range(0,5)(0) == 0 && range(0,5)(1) == 1 && range(0,5)(5) == 5);
+ DLIB_TEST(range(1,2,5)(0) == 1 && range(1,2,5)(1) == 3 && range(1,2,5)(2) == 5);
+ DLIB_TEST((range<0,5>()(0) == 0 && range<0,5>()(1) == 1 && range<0,5>()(5) == 5));
+ DLIB_TEST((range<1,2,5>()(0) == 1 && range<1,2,5>()(1) == 3 && range<1,2,5>()(2) == 5));
+
+
+ DLIB_TEST((range<0,5>().nc() == 6));
+ DLIB_TEST((range<1,5>().nc() == 5));
+ DLIB_TEST((range<0,5>().nr() == 1));
+ DLIB_TEST((range<1,5>().nr() == 1));
+ DLIB_TEST((trans(range<0,5>()).nr() == 6));
+ DLIB_TEST((trans(range<1,5>()).nr() == 5));
+ DLIB_TEST((trans(range<0,5>()).nc() == 1));
+ DLIB_TEST((trans(range<1,5>()).nc() == 1));
+
+ DLIB_TEST((range<0,2,5>().nc() == 3));
+ DLIB_TEST((range<1,2,5>().nc() == 3));
+ DLIB_TEST((range<0,2,5>().nr() == 1));
+ DLIB_TEST((range<1,2,5>().nr() == 1));
+ DLIB_TEST((trans(range<0,2,5>()).nr() == 3));
+ DLIB_TEST((trans(range<1,2,5>()).nr() == 3));
+ DLIB_TEST((trans(range<0,2,5>()).nc() == 1));
+ DLIB_TEST((trans(range<1,2,5>()).nc() == 1));
+
+ DLIB_TEST((range<0,3,6>().nc() == 3));
+ DLIB_TEST((range<1,3,5>().nc() == 2));
+ DLIB_TEST((range<0,3,5>().nr() == 1));
+ DLIB_TEST((range<1,3,5>().nr() == 1));
+ DLIB_TEST((trans(range<0,3,6>()).nr() == 3));
+ DLIB_TEST((trans(range<1,3,5>()).nr() == 2));
+ DLIB_TEST((trans(range<0,3,5>()).nc() == 1));
+ DLIB_TEST((trans(range<1,3,5>()).nc() == 1));
+ }
+
+ {
+ DLIB_TEST(range(5,0).nc() == 6);
+ DLIB_TEST(range(5,1).nc() == 5);
+ DLIB_TEST(range(5,0).nr() == 1);
+ DLIB_TEST(range(5,1).nr() == 1);
+ DLIB_TEST(trans(range(5,0)).nr() == 6);
+ DLIB_TEST(trans(range(5,1)).nr() == 5);
+ DLIB_TEST(trans(range(5,0)).nc() == 1);
+ DLIB_TEST(trans(range(5,1)).nc() == 1);
+
+ DLIB_TEST(range(5,2,0).nc() == 3);
+ DLIB_TEST(range(5,2,1).nc() == 3);
+ DLIB_TEST(range(5,2,0).nr() == 1);
+ DLIB_TEST(range(5,2,1).nr() == 1);
+ DLIB_TEST(trans(range(5,2,0)).nr() == 3);
+ DLIB_TEST(trans(range(5,2,1)).nr() == 3);
+ DLIB_TEST(trans(range(5,2,0)).nc() == 1);
+ DLIB_TEST(trans(range(5,2,1)).nc() == 1);
+
+ DLIB_TEST(range(6,3,0).nc() == 3);
+ DLIB_TEST(range(5,3,1).nc() == 2);
+ DLIB_TEST(range(5,3,0).nr() == 1);
+ DLIB_TEST(range(5,3,1).nr() == 1);
+ DLIB_TEST(trans(range(6,3,0)).nr() == 3);
+ DLIB_TEST(trans(range(5,3,1)).nr() == 2);
+ DLIB_TEST(trans(range(5,3,0)).nc() == 1);
+ DLIB_TEST(trans(range(5,3,1)).nc() == 1);
+
+ DLIB_TEST(range(5,9,1).nc() == 1);
+ DLIB_TEST(range(5,9,1).nr() == 1);
+
+ DLIB_TEST(range(0,0).nc() == 1);
+ DLIB_TEST(range(0,0).nr() == 1);
+
+ DLIB_TEST(range(1,1)(0) == 1);
+
+ DLIB_TEST(range(5,0)(0) == 5 && range(5,0)(1) == 4 && range(5,0)(5) == 0);
+ DLIB_TEST(range(5,2,1)(0) == 5 && range(5,2,1)(1) == 3 && range(5,2,1)(2) == 1);
+ DLIB_TEST((range<5,0>()(0) == 5 && range<5,0>()(1) == 4 && range<5,0>()(5) == 0));
+ DLIB_TEST((range<5,2,1>()(0) == 5 && range<5,2,1>()(1) == 3 && range<5,2,1>()(2) == 1));
+
+
+ DLIB_TEST((range<5,0>().nc() == 6));
+ DLIB_TEST((range<5,1>().nc() == 5));
+ DLIB_TEST((range<5,0>().nr() == 1));
+ DLIB_TEST((range<5,1>().nr() == 1));
+ DLIB_TEST((trans(range<5,0>()).nr() == 6));
+ DLIB_TEST((trans(range<5,1>()).nr() == 5));
+ DLIB_TEST((trans(range<5,0>()).nc() == 1));
+ DLIB_TEST((trans(range<5,1>()).nc() == 1));
+
+ DLIB_TEST((range<5,2,0>().nc() == 3));
+ DLIB_TEST((range<5,2,1>().nc() == 3));
+ DLIB_TEST((range<5,2,0>().nr() == 1));
+ DLIB_TEST((range<5,2,1>().nr() == 1));
+ DLIB_TEST((trans(range<5,2,0>()).nr() == 3));
+ DLIB_TEST((trans(range<5,2,1>()).nr() == 3));
+ DLIB_TEST((trans(range<5,2,0>()).nc() == 1));
+ DLIB_TEST((trans(range<5,2,1>()).nc() == 1));
+
+ DLIB_TEST((range<6,3,0>().nc() == 3));
+ DLIB_TEST((range<5,3,1>().nc() == 2));
+ DLIB_TEST((range<5,3,0>().nr() == 1));
+ DLIB_TEST((range<5,3,1>().nr() == 1));
+ DLIB_TEST((trans(range<6,3,0>()).nr() == 3));
+ DLIB_TEST((trans(range<5,3,1>()).nr() == 2));
+ DLIB_TEST((trans(range<5,3,0>()).nc() == 1));
+ DLIB_TEST((trans(range<5,3,1>()).nc() == 1));
+ }
+
+ {
+ matrix<double> m(4,3);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ DLIB_TEST(subm(m,range(0,3),range(0,0)) == colm(m,0));
+ DLIB_TEST(subm(m,range(0,3),range(1,1)) == colm(m,1));
+ DLIB_TEST(subm(m,range(0,3),range(2,2)) == colm(m,2));
+
+ DLIB_TEST(subm(m,range(0,0),range(0,2)) == rowm(m,0));
+ DLIB_TEST(subm(m,range(1,1),range(0,2)) == rowm(m,1));
+ DLIB_TEST(subm(m,range(2,2),range(0,2)) == rowm(m,2));
+ DLIB_TEST(subm(m,range(3,3),range(0,2)) == rowm(m,3));
+
+ DLIB_TEST(subm(m,0,0,2,2) == subm(m,range(0,1),range(0,1)));
+ DLIB_TEST(subm(m,1,1,2,2) == subm(m,range(1,2),range(1,2)));
+
+ matrix<double,2,2> m2 = subm(m,range(0,2,2),range(0,2,2));
+
+ DLIB_TEST(m2(0,0) == m(0,0));
+ DLIB_TEST(m2(0,1) == m(0,2));
+ DLIB_TEST(m2(1,0) == m(2,0));
+ DLIB_TEST(m2(1,1) == m(2,2));
+
+
+ }
+ {
+ matrix<double,4,3> m(4,3);
+ for (long r = 0; r < m.nr(); ++r)
+ {
+ for (long c = 0; c < m.nc(); ++c)
+ {
+ m(r,c) = r*c;
+ }
+ }
+
+ DLIB_TEST(subm(m,range<0,3>(),range<0,0>()) == colm(m,0));
+ DLIB_TEST(subm(m,range<0,3>(),range<1,1>()) == colm(m,1));
+ DLIB_TEST(subm(m,range<0,3>(),range<2,2>()) == colm(m,2));
+
+ DLIB_TEST(subm(m,range<0,0>(),range<0,2>()) == rowm(m,0));
+ DLIB_TEST(subm(m,range<1,1>(),range<0,2>()) == rowm(m,1));
+ DLIB_TEST(subm(m,range<2,2>(),range<0,2>()) == rowm(m,2));
+ DLIB_TEST(subm(m,range<3,3>(),range<0,2>()) == rowm(m,3));
+
+ DLIB_TEST(subm(m,0,0,2,2) == subm(m,range<0,1>(),range<0,1>()));
+ DLIB_TEST(subm(m,1,1,2,2) == subm(m,range<1,2>(),range<1,2>()));
+
+ matrix<double,2,2> m2 = subm(m,range<0,2,2>(),range<0,2,2>());
+
+ DLIB_TEST(m2(0,0) == m(0,0));
+ DLIB_TEST(m2(0,1) == m(0,2));
+ DLIB_TEST(m2(1,0) == m(2,0));
+ DLIB_TEST(m2(1,1) == m(2,2));
+
+
+ }
+
+ {
+ matrix<double> a = randm(3,4);
+ matrix<double> b = randm(3,4);
+
+ matrix<double> m1, m2;
+
+ m1 = max_pointwise(a,b);
+ m2 = min_pointwise(a,b);
+ DLIB_TEST(m1.nr() == a.nr());
+ DLIB_TEST(m1.nc() == a.nc());
+ DLIB_TEST(m2.nr() == a.nr());
+ DLIB_TEST(m2.nc() == a.nc());
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ DLIB_TEST_MSG(m1(r,c) == std::max(a(r,c), b(r,c)), m1(r,c) << " : " << a(r,c) << " " << b(r,c));
+ DLIB_TEST(m2(r,c) == std::min(a(r,c), b(r,c)));
+ }
+ }
+ }
+
+ {
+ matrix<double,4,5> m;
+ set_subm(m, range(0,3), range(0,4)) = 4;
+ DLIB_TEST(min(m) == max(m) && min(m) == 4);
+
+ set_subm(m,range(1,1),range(0,4)) = 7;
+ DLIB_TEST((rowm(m,0) == uniform_matrix<double>(1,5, 4)));
+ DLIB_TEST((rowm(m,1) == uniform_matrix<double>(1,5, 7)));
+ DLIB_TEST((rowm(m,2) == uniform_matrix<double>(1,5, 4)));
+ DLIB_TEST((rowm(m,3) == uniform_matrix<double>(1,5, 4)));
+
+
+ set_subm(m, range(0,2,3), range(0,2,4)) = trans(subm(m,0,0,3,2));
+
+
+ DLIB_TEST(m(0,2) == 7);
+ DLIB_TEST(m(2,2) == 7);
+
+ DLIB_TEST(sum(m) == 7*5+ 7+7 + 4*(4*5 - 7));
+
+ }
+
+ {
+ matrix<double> mat(4,5);
+ DLIB_TEST((uniform_matrix<double>(4,5,1) == ones_matrix<double>(4,5)));
+ DLIB_TEST((uniform_matrix<double>(4,5,1) == ones_matrix(mat)));
+ DLIB_TEST((uniform_matrix<double>(4,5,0) == zeros_matrix<double>(4,5)));
+ DLIB_TEST((uniform_matrix<double>(4,5,0) == zeros_matrix(mat)));
+ DLIB_TEST((uniform_matrix<float>(4,5,1) == ones_matrix<float>(4,5)));
+ DLIB_TEST((uniform_matrix<float>(4,5,0) == zeros_matrix<float>(4,5)));
+ DLIB_TEST((uniform_matrix<complex<double> >(4,5,1) == ones_matrix<complex<double> >(4,5)));
+ DLIB_TEST((uniform_matrix<complex<double> >(4,5,0) == zeros_matrix<complex<double> >(4,5)));
+ DLIB_TEST((uniform_matrix<complex<float> >(4,5,1) == ones_matrix<complex<float> >(4,5)));
+ DLIB_TEST((uniform_matrix<complex<float> >(4,5,0) == zeros_matrix<complex<float> >(4,5)));
+ DLIB_TEST((complex_matrix(ones_matrix<double>(3,3), zeros_matrix<double>(3,3)) == complex_matrix(ones_matrix<double>(3,3))));
+ DLIB_TEST((pointwise_multiply(complex_matrix(ones_matrix<double>(3,3)), ones_matrix<double>(3,3)*2) ==
+ complex_matrix(2*ones_matrix<double>(3,3))));
+ }
+
+ {
+ DLIB_TEST(( uniform_matrix<double>(303,303, 3)*identity_matrix<double>(303) == uniform_matrix<double,303,303>(3) ) );
+ DLIB_TEST(( uniform_matrix<double,303,303>(3)*identity_matrix<double,303>() == uniform_matrix<double,303,303>(3) ));
+ }
+
+ {
+ matrix<double> m(2,3);
+ m = 1,2,3,
+ 5,6,7;
+
+ DLIB_TEST_MSG(m(0,0) == 1 && m(0,1) == 2 && m(0,2) == 3 &&
+ m(1,0) == 5 && m(1,1) == 6 && m(1,2) == 7,"");
+
+ m = 4;
+ DLIB_TEST((m == uniform_matrix<double,2,3>(4)));
+
+ matrix<double,2,3> m2;
+ m2 = 1,2,3,
+ 5,6,7;
+ DLIB_TEST_MSG(m2(0,0) == 1 && m2(0,1) == 2 && m2(0,2) == 3 &&
+ m2(1,0) == 5 && m2(1,1) == 6 && m2(1,2) == 7,"");
+
+ matrix<double,2,1> m3;
+ m3 = 1,
+ 5;
+ DLIB_TEST(m3(0) == 1 && m3(1) == 5 );
+
+ matrix<double,1,2> m4;
+ m4 = 1, 5;
+ DLIB_TEST(m3(0) == 1 && m3(1) == 5 );
+ }
+
+ {
+ matrix<double> m(4,1);
+ m = 3, 1, 5, 2;
+ DLIB_TEST(index_of_min(m) == 1);
+ DLIB_TEST(index_of_max(m) == 2);
+ DLIB_TEST(index_of_min(trans(m)) == 1);
+ DLIB_TEST(index_of_max(trans(m)) == 2);
+ }
+
+ {
+ matrix<double> m1(1,5), m2;
+
+ m1 = 3.0000, 3.7500, 4.5000, 5.2500, 6.0000;
+ m2 = linspace(3, 6, 5);
+
+ DLIB_TEST(equal(m1, m2));
+
+ m1 = pow(10, m1);
+ m2 = logspace(3, 6, 5);
+
+ DLIB_TEST(equal(m1, m2));
+ }
+
+ {
+ matrix<long> m = cartesian_product(range(1,3), range(0,1));
+
+ matrix<long,2,1> c0, c1, c2, c3, c4, c5;
+ c0 = 1, 0;
+ c1 = 1, 1;
+ c2 = 2, 0;
+ c3 = 2, 1;
+ c4 = 3, 0;
+ c5 = 3, 1;
+
+ DLIB_TEST_MSG(colm(m,0) == c0, colm(m,0) << "\n\n" << c0);
+ DLIB_TEST(colm(m,1) == c1);
+ DLIB_TEST(colm(m,2) == c2);
+ DLIB_TEST(colm(m,3) == c3);
+ DLIB_TEST(colm(m,4) == c4);
+ DLIB_TEST(colm(m,5) == c5);
+ }
+
+
+ {
+ matrix<double> m(2,2), mr(2,2), mr_max(2,2);
+
+ m = 1, 2,
+ 0, 4;
+
+ mr = 1, 1.0/2.0,
+ 0, 1.0/4.0;
+
+ mr_max = 1, 1.0/2.0,
+ std::numeric_limits<double>::max(), 1.0/4.0;
+
+ DLIB_TEST(equal(reciprocal(m), mr));
+ DLIB_TEST(equal(reciprocal_max(m), mr_max));
+
+ }
+
+ {
+ matrix<double> m1, m2;
+ m1.set_size(3,1);
+ m2.set_size(1,3);
+
+ m1 = 1,2,3;
+ m2 = 4,5,6;
+ DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
+ }
+
+ {
+ matrix<double,3,1> m1, m2;
+ m1.set_size(3,1);
+ m2.set_size(3,1);
+
+ m1 = 1,2,3;
+ m2 = 4,5,6;
+ DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
+ }
+ {
+ matrix<double,1,3> m1, m2;
+ m1.set_size(1,3);
+ m2.set_size(1,3);
+
+ m1 = 1,2,3;
+ m2 = 4,5,6;
+ DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
+ }
+ {
+ matrix<double,1,3> m1;
+ matrix<double> m2;
+ m1.set_size(1,3);
+ m2.set_size(3,1);
+
+ m1 = 1,2,3;
+ m2 = 4,5,6;
+ DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
+ DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
+ }
+
+ {
+ matrix<double> m1(3,3), m2(3,3);
+
+ m1 = 1;
+ m2 = 1;
+ m1 = m1*subm(m2,0,0,3,3);
+ DLIB_TEST(is_finite(m1));
+ }
+ {
+ matrix<double,3,1> m1;
+ matrix<double> m2(3,3);
+
+ m1 = 1;
+ m2 = 1;
+ m1 = subm(m2,0,0,3,3)*m1;
+ }
+
+ {
+ matrix<int> m(2,1);
+
+ m = 3,3;
+ m /= m(0);
+
+ DLIB_TEST(m(0) == 1);
+ DLIB_TEST(m(1) == 1);
+ }
+ {
+ matrix<int> m(2,1);
+
+ m = 3,3;
+ m *= m(0);
+
+ DLIB_TEST(m(0) == 9);
+ DLIB_TEST(m(1) == 9);
+ }
+ {
+ matrix<int> m(2,1);
+
+ m = 3,3;
+ m -= m(0);
+
+ DLIB_TEST(m(0) == 0);
+ DLIB_TEST(m(1) == 0);
+ }
+ {
+ matrix<int> m(2,1);
+
+ m = 3,3;
+ m += m(0);
+
+ DLIB_TEST(m(0) == 6);
+ DLIB_TEST(m(1) == 6);
+ DLIB_TEST(is_finite(m));
+ }
+
+
+ {
+ matrix<double> m(3,3);
+ m = 3;
+ m(1,1) = std::numeric_limits<double>::infinity();
+ DLIB_TEST(is_finite(m) == false);
+ m(1,1) = -std::numeric_limits<double>::infinity();
+ DLIB_TEST(is_finite(m) == false);
+ m(1,1) = 2;
+ DLIB_TEST(is_finite(m));
+ }
+
+ {
+ matrix<int> m(4,1), mm, mmm;
+
+ mmm = mm = (m = 1,2,3,4);
+ DLIB_TEST(m(0) == 1);
+ DLIB_TEST(m(1) == 2);
+ DLIB_TEST(m(2) == 3);
+ DLIB_TEST(m(3) == 4);
+ DLIB_TEST(mm == m);
+ DLIB_TEST(mmm == m);
+ DLIB_TEST(mm(0) == 1);
+ DLIB_TEST(mm(1) == 2);
+ DLIB_TEST(mm(2) == 3);
+ DLIB_TEST(mm(3) == 4);
+ }
+
+ {
+ const long n = 5;
+ matrix<double> m1, m2, m3, truth;
+ m1 = randm(n,n);
+ m2 = randm(n,n);
+
+ rectangle rect1(1,1,3,3);
+ rectangle rect2(2,1,4,3);
+
+ truth = subm(m1,rect1)*subm(m2,rect2);
+ m3 = mat(&m1(0,0)+6, 3,3, m1.nc()) * mat(&m2(0,0)+7, 3,3, m2.nc());
+
+ DLIB_TEST(max(abs(truth-m3)) < 1e-13);
+ }
+
+ {
+ const long n = 5;
+ matrix<double> m1, m2, m3, truth;
+ m1 = randm(n,n);
+ m2 = randm(n,n);
+ m3 = randm(n,n);
+
+
+ truth = m1*m2;
+ m3 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
+ DLIB_TEST(max(abs(truth-m3)) < 1e-13);
+ m3 = 0;
+ set_ptrm(&m3(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
+ DLIB_TEST(max(abs(truth-m3)) < 1e-13);
+ set_ptrm(&m3(0,0),n,n) = m1*m2;
+ DLIB_TEST(max(abs(truth-m3)) < 1e-13);
+
+ // now make sure it deals with aliasing correctly.
+ truth = m1*m2;
+ m1 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
+ DLIB_TEST(max(abs(truth-m1)) < 1e-13);
+
+ m1 = randm(n,n);
+ truth = m1*m2;
+ set_ptrm(&m1(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
+ DLIB_TEST(max(abs(truth-m1)) < 1e-13);
+
+ m1 = randm(n,n);
+ truth = m1*m2;
+ set_ptrm(&m1(0,0),n,n) = m1*m2;
+ DLIB_TEST(max(abs(truth-m1)) < 1e-13);
+
+ m1 = randm(n,n);
+ truth = m1+m1*m2;
+ set_ptrm(&m1(0,0),n,n) += m1*m2;
+ DLIB_TEST(max(abs(truth-m1)) < 1e-13);
+
+ m1 = randm(n,n);
+ truth = m1-m1*m2;
+ set_ptrm(&m1(0,0),n,n) -= m1*m2;
+ DLIB_TEST(max(abs(truth-m1)) < 1e-13);
+
+ }
+
+ {
+ matrix<double,3,3,default_memory_manager,column_major_layout> a(3,3);
+ matrix<double,3,3,default_memory_manager,column_major_layout> m = randm(3,3);
+ matrix<double,3,1,default_memory_manager,column_major_layout> b = randm(3,1);
+
+ a = 0;
+ set_colm(a,0) = m*b;
+ DLIB_TEST(colm(a,0) == m*b);
+ a = 0;
+ set_rowm(a,0) = trans(m*b);
+ DLIB_TEST(rowm(a,0) == trans(m*b));
+ DLIB_TEST(rowm(a,0) != m*b);
+ }
+ {
+ matrix<double,0,0,default_memory_manager,column_major_layout> a(3,3);
+ matrix<double,0,0,default_memory_manager,column_major_layout> m = randm(3,3);
+ matrix<double,0,0,default_memory_manager,column_major_layout> b = randm(3,1);
+
+ a = 0;
+ set_colm(a,0) = m*b;
+ DLIB_TEST(equal(colm(a,0) , m*b));
+ a = 0;
+ set_rowm(a,0) = trans(m*b);
+ DLIB_TEST(equal(rowm(a,0) , trans(m*b)));
+ DLIB_TEST(!equal(rowm(a,0) , m*b));
+ }
+ {
+ matrix<double> a(3,3);
+ matrix<double> m = randm(3,3);
+ matrix<double> b = randm(3,1);
+
+ a = 0;
+ set_colm(a,0) = m*b;
+ DLIB_TEST(equal(colm(a,0) , m*b));
+ a = 0;
+ set_rowm(a,0) = trans(m*b);
+ DLIB_TEST(equal(rowm(a,0) , trans(m*b)));
+ DLIB_TEST(!equal(rowm(a,0) , m*b));
+ }
+ }
+
+
+
+
+
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix",
+ "Runs tests on the matrix component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ matrix_test();
+ matrix_test2();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/matrix2.cpp b/ml/dlib/dlib/test/matrix2.cpp
new file mode 100644
index 000000000..8de17fc7f
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix2.cpp
@@ -0,0 +1,1158 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+
+#include "tester.h"
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/array2d.h>
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix2");
+
+ dlib::rand rnd;
+
+ void matrix_test1 (
+ )
+ {
+ typedef memory_manager_stateless<char>::kernel_2_2a MM;
+ print_spinner();
+
+ const double ident[] = {
+ 1, 0, 0, 0,
+ 0, 1, 0, 0,
+ 0, 0, 1, 0,
+ 0, 0, 0, 1 };
+
+ const double uniform3[] = {
+ 3, 3, 3, 3,
+ 3, 3, 3, 3,
+ 3, 3, 3, 3,
+ 3, 3, 3, 3
+ };
+
+ const double uniform1[] = {
+ 1, 1, 1, 1,
+ 1, 1, 1, 1,
+ 1, 1, 1, 1,
+ 1, 1, 1, 1
+ };
+
+ const double uniform0[] = {
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0
+ };
+
+ const int array[] = {
+ 42, 58, 9, 1,
+ 9, 5, 8, 2,
+ 98, 28, 4, 77,
+ 9, 2, 44, 88 };
+
+ const int array2[] = {
+ 1, 22, 3,
+ 4, 52, 6,
+ 7, 8, 9 };
+
+ const int array2_r[] = {
+ 52, 6, 4,
+ 8, 9, 7,
+ 22, 3, 1
+ };
+
+ const double array_f[] = {
+ -0.99,
+ 0.99};
+
+
+ matrix<double,2,1,MM> fm(array_f);
+
+ DLIB_TEST(fm.size() == 2);
+ matrix<double> dfm(fm);
+ DLIB_TEST(round(fm)(0) == -1);
+ DLIB_TEST(round(fm)(1) == 1);
+ DLIB_TEST(round(dfm)(0) == -1);
+ DLIB_TEST(round(dfm)(1) == 1);
+ DLIB_TEST(round(dfm).size() == dfm.size());
+
+
+ const int array3[] = { 1, 2, 3, 4 };
+
+ matrix<double,3,3,MM> m3(array2);
+ matrix<double> dm3;
+ DLIB_TEST(dm3.size() == 0);
+ DLIB_TEST(dm3.nr() == 0);
+ DLIB_TEST(dm3.nc() == 0);
+ dm3.set_size(3,4);
+ DLIB_TEST(dm3.nr() == 3);
+ DLIB_TEST(dm3.nc() == 4);
+ DLIB_TEST(dm3.size() == 3*4);
+ dm3.set_size(3,3);
+ DLIB_TEST(dm3.nr() == 3);
+ DLIB_TEST(dm3.nc() == 3);
+ dm3 = m3;
+ dm3(0,0)++;
+ DLIB_TEST( dm3 != m3);
+ dm3 = m3;
+ DLIB_TEST( dm3 == m3);
+ DLIB_TEST( abs(sum(squared(normalize(dm3))) - 1.0) < 1e-10);
+
+ matrix<double,3,4> mrc;
+ mrc.set_size(3,4);
+
+ set_all_elements(mrc,1);
+
+ DLIB_TEST(diag(mrc) == uniform_matrix<double>(3,1,1));
+ DLIB_TEST(diag(matrix<double>(mrc)) == uniform_matrix<double>(3,1,1));
+
+ matrix<double,2,3> mrc2;
+ set_all_elements(mrc2,1);
+ DLIB_TEST((removerc<1,1>(mrc) == mrc2));
+ DLIB_TEST((removerc(mrc,1,1) == mrc2));
+
+ matrix<int,3,3> m4, m5, m6;
+ set_all_elements(m4, 4);
+ set_all_elements(m5, 4);
+ set_all_elements(m6, 1);
+
+ DLIB_TEST(squared(m4) == pointwise_multiply(m4,m4));
+ DLIB_TEST(cubed(m4) == pointwise_multiply(m4,m4,m4));
+ DLIB_TEST(pow(matrix_cast<double>(m4),2) == squared(matrix_cast<double>(m4)));
+ DLIB_TEST(pow(matrix_cast<double>(m4),3) == cubed(matrix_cast<double>(m4)));
+
+ matrix<int> dm4;
+ matrix<int,0,0,memory_manager_stateless<char>::kernel_2_2a> dm5;
+ dm4 = dm4;
+ dm4 = dm5;
+ DLIB_TEST(dm4.nr() == 0);
+ dm4 = m4;
+ dm5 = m5;
+ DLIB_TEST(dm4 == dm5);
+
+
+ DLIB_TEST(m4 == m5);
+ DLIB_TEST(m6 != m5);
+ m4.swap(m6);
+ DLIB_TEST(m6 == m5);
+ DLIB_TEST(m4 != m5);
+
+ DLIB_TEST(m3.nr() == 3);
+ DLIB_TEST(m3.nc() == 3);
+
+ matrix<double,4,1> v(array3), v2;
+ DLIB_TEST(v.nr() == 4);
+ DLIB_TEST(v.nc() == 1);
+
+ std::vector<double> stdv(4);
+ std_vector_c<double> stdv_c(4);
+ dlib::array<double> arr;
+ arr.resize(4);
+ for (long i = 0; i < 4; ++i)
+ stdv[i] = stdv_c[i] = arr[i] = i+1;
+
+ DLIB_TEST(mat(stdv)(0) == 1);
+ DLIB_TEST(mat(stdv)(1) == 2);
+ DLIB_TEST(mat(stdv)(2) == 3);
+ DLIB_TEST(mat(stdv)(3) == 4);
+ DLIB_TEST(mat(stdv).nr() == 4);
+ DLIB_TEST(mat(stdv).nc() == 1);
+ DLIB_TEST(mat(stdv).size() == 4);
+ DLIB_TEST(equal(trans(mat(stdv))*mat(stdv), trans(v)*v));
+ DLIB_TEST(equal(trans(mat(stdv))*mat(stdv), tmp(trans(v)*v)));
+
+ DLIB_TEST(mat(stdv_c)(0) == 1);
+ DLIB_TEST(mat(stdv_c)(1) == 2);
+ DLIB_TEST(mat(stdv_c)(2) == 3);
+ DLIB_TEST(mat(stdv_c)(3) == 4);
+ DLIB_TEST(mat(stdv_c).nr() == 4);
+ DLIB_TEST(mat(stdv_c).nc() == 1);
+ DLIB_TEST(mat(stdv_c).size() == 4);
+ DLIB_TEST(equal(trans(mat(stdv_c))*mat(stdv_c), trans(v)*v));
+
+ DLIB_TEST(mat(arr)(0) == 1);
+ DLIB_TEST(mat(arr)(1) == 2);
+ DLIB_TEST(mat(arr)(2) == 3);
+ DLIB_TEST(mat(arr)(3) == 4);
+ DLIB_TEST(mat(arr).nr() == 4);
+ DLIB_TEST(mat(arr).nc() == 1);
+ DLIB_TEST(mat(arr).size() == 4);
+ DLIB_TEST(equal(trans(mat(arr))*mat(arr), trans(v)*v));
+
+ DLIB_TEST(v(0) == 1);
+ DLIB_TEST(v(1) == 2);
+ DLIB_TEST(v(2) == 3);
+ DLIB_TEST(v(3) == 4);
+ matrix<double> dv = v;
+ DLIB_TEST((trans(v)*v).size() == 1);
+ DLIB_TEST((trans(v)*v).nr() == 1);
+ DLIB_TEST((trans(v)*dv).nr() == 1);
+ DLIB_TEST((trans(dv)*dv).nr() == 1);
+ DLIB_TEST((trans(dv)*v).nr() == 1);
+ DLIB_TEST((trans(v)*v).nc() == 1);
+ DLIB_TEST((trans(v)*dv).nc() == 1);
+ DLIB_TEST((trans(dv)*dv).nc() == 1);
+ DLIB_TEST((trans(dv)*v).nc() == 1);
+ DLIB_TEST((trans(v)*v)(0) == 1*1 + 2*2 + 3*3 + 4*4);
+ DLIB_TEST((trans(dv)*v)(0) == 1*1 + 2*2 + 3*3 + 4*4);
+ DLIB_TEST((trans(dv)*dv)(0) == 1*1 + 2*2 + 3*3 + 4*4);
+ DLIB_TEST((trans(v)*dv)(0) == 1*1 + 2*2 + 3*3 + 4*4);
+
+ dv = trans(dv)*v;
+ DLIB_TEST(dv.nr() == 1);
+ DLIB_TEST(dv.nc() == 1);
+
+ dm3 = m3;
+ DLIB_TEST(floor(det(m3)+0.01) == -444);
+ DLIB_TEST(floor(det(dm3)+0.01) == -444);
+ DLIB_TEST(min(m3) == 1);
+ DLIB_TEST(m3(min_point(m3).y(),min_point(m3).x()) == 1);
+ DLIB_TEST(min(dm3) == 1);
+ DLIB_TEST(max(m3) == 52);
+ DLIB_TEST(m3(max_point(m3).y(),max_point(m3).x()) == 52);
+ DLIB_TEST(max(dm3) == 52);
+ DLIB_TEST(sum(m3) == 112);
+ DLIB_TEST(sum(dm3) == 112);
+ DLIB_TEST(prod(m3) == 41513472);
+ DLIB_TEST(prod(dm3) == 41513472);
+ DLIB_TEST(prod(diag(m3)) == 1*52*9);
+ DLIB_TEST(prod(diag(dm3)) == 1*52*9);
+ DLIB_TEST(sum(diag(m3)) == 1+52+9);
+ DLIB_TEST(sum(diag(dm3)) == 1+52+9);
+ DLIB_TEST(equal(round(10000*m3*inv(m3))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*dm3*inv(m3))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*dm3*inv(dm3))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*m3*inv(dm3))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*tmp(m3*inv(m3)))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*tmp(dm3*inv(m3)))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*tmp(dm3*inv(dm3)))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(equal(round(10000*tmp(m3*inv(dm3)))/10000 , identity_matrix<double,3>()));
+ DLIB_TEST(-1*m3 == -m3);
+ DLIB_TEST(-1*dm3 == -m3);
+ DLIB_TEST(-1*m3 == -dm3);
+ DLIB_TEST(-1*dm3 == -dm3);
+
+ DLIB_TEST(m3 == dm3);
+ m3(1,1) = 99;
+ DLIB_TEST(m3 != dm3);
+ m3 = dm3;
+ DLIB_TEST(m3 == dm3);
+
+ matrix<double,4,4,MM> mident(ident);
+ matrix<double,4,4> muniform0(uniform0);
+ matrix<double,4,4> muniform1(uniform1);
+ matrix<double,4,4> muniform3(uniform3);
+ matrix<double,4,4> m1(array), m2;
+ DLIB_TEST(m1.nr() == 4);
+ DLIB_TEST(m1.nc() == 4);
+
+ DLIB_TEST(muniform1 + muniform1 + muniform1 == muniform3);
+ DLIB_TEST(muniform1*2 + muniform1 + muniform1 - muniform1 == muniform3);
+ DLIB_TEST(2*muniform1 + muniform1 + muniform1 - muniform1 == muniform3);
+ DLIB_TEST(muniform1 + muniform1 + muniform1 - muniform3 == muniform0);
+ DLIB_TEST(equal(muniform3/3 , muniform1));
+ DLIB_TEST(v != m1);
+ DLIB_TEST(v == v);
+ DLIB_TEST(m1 == m1);
+
+ muniform0.swap(muniform1);
+ DLIB_TEST((muniform1 == matrix_cast<double>(uniform_matrix<long,4,4,0>())));
+ DLIB_TEST((muniform0 == matrix_cast<double>(uniform_matrix<long,4,4,1>())));
+ DLIB_TEST((muniform1 == matrix_cast<double>(uniform_matrix<long>(4,4,0))));
+ DLIB_TEST((muniform0 == matrix_cast<double>(uniform_matrix<long>(4,4,1))));
+ swap(muniform0,muniform1);
+
+ DLIB_TEST((mident == identity_matrix<double,4>()));
+ DLIB_TEST((muniform0 == matrix_cast<double>(uniform_matrix<long,4,4,0>())));
+ DLIB_TEST((muniform1 == matrix_cast<double>(uniform_matrix<long,4,4,1>())));
+ DLIB_TEST((muniform3 == matrix_cast<double>(uniform_matrix<long,4,4,3>())));
+ DLIB_TEST((muniform1*8 == matrix_cast<double>(uniform_matrix<long,4,4,8>())));
+
+ set_all_elements(m2,7);
+ DLIB_TEST(m2 == muniform1*7);
+ m2 = array;
+ DLIB_TEST(m2 == m1);
+
+ const double m1inv[] = {
+ -0.00946427624, 0.0593272941, 0.00970564379, -0.00973323731,
+ 0.0249312057, -0.0590122427, -0.00583102756, 0.00616002729,
+ -0.00575431149, 0.110081189, -0.00806792253, 0.00462297692,
+ 0.00327847478, -0.0597669712, 0.00317386196, 0.00990759201
+ };
+
+ m2 = m1inv;
+ DLIB_TEST((round(m2*m1) == identity_matrix<double,4>()));
+ DLIB_TEST((round(tmp(m2*m1)) == identity_matrix<double,4>()));
+
+ DLIB_TEST_MSG(round(m2*10000) == round(inv(m1)*10000),
+ round(m2*10000) - round(inv(m1)*10000)
+ << "\n\n" << round(m2*10000)
+ << "\n\n" << round(inv(m1)*10000)
+ << "\n\n" << m2
+ << "\n\n" << inv(m1)
+ );
+ DLIB_TEST(m1 == abs(-1*m1));
+ DLIB_TEST(abs(m2) == abs(-1*m2));
+
+ DLIB_TEST_MSG(floor(det(m1)+0.01) == 3297875,"\nm1: \n" << m1 << "\ndet(m1): " << det(m1));
+
+
+ ostringstream sout;
+ m1 = m2;
+ serialize(m1,sout);
+ set_all_elements(m1,0);
+ istringstream sin(sout.str());
+ deserialize(m1,sin);
+ DLIB_TEST_MSG(round(100000*m1) == round(100000*m2),"m1: \n" << m1 << endl << "m2: \n" << m2);
+
+
+ set_all_elements(v,2);
+ v2 = pointwise_multiply(v, v*2);
+ set_all_elements(v,8);
+ DLIB_TEST(v == v2);
+ DLIB_TEST(v == tmp(v2));
+ DLIB_TEST((v == rotate<2,0>(v)));
+
+ m4 = array2;
+ m5 = array2_r;
+ DLIB_TEST((m5 == rotate<1,1>(m4)));
+
+ m5 = array2;
+ DLIB_TEST((m5*2 == pointwise_multiply(m5,uniform_matrix<int,3,3,2>())));
+ DLIB_TEST((tmp(m5*2) == tmp(pointwise_multiply(m5,uniform_matrix<int,3,3,2>()))));
+
+ v = tmp(v);
+
+
+
+
+ matrix<double> dm10(10,5);
+ DLIB_TEST(dm10.nr() == 10);
+ DLIB_TEST(dm10.nc() == 5);
+ set_all_elements(dm10,4);
+ DLIB_TEST(dm10.nr() == 10);
+ DLIB_TEST(dm10.nc() == 5);
+ matrix<double,10,5> m10;
+ DLIB_TEST(m10.nr() == 10);
+ DLIB_TEST(m10.nc() == 5);
+ set_all_elements(m10,4);
+ DLIB_TEST(dm10 == m10);
+ DLIB_TEST((clamp<0,3>(dm10) == clamp<0,3>(m10)));
+ DLIB_TEST((clamp<0,3>(dm10)(0,2) == 3));
+
+ set_all_elements(dm10,1);
+ set_all_elements(m10,4);
+ DLIB_TEST(4*dm10 == m10);
+ DLIB_TEST(5*dm10 - dm10 == m10);
+ DLIB_TEST((16*dm10)/4 == m10);
+ DLIB_TEST(dm10+dm10+2*dm10 == m10);
+ DLIB_TEST(dm10+tmp(dm10+2*dm10) == m10);
+ set_all_elements(dm10,4);
+ DLIB_TEST(dm10 == m10);
+ DLIB_TEST_MSG(sum(abs(sigmoid(dm10) -sigmoid(m10))) < 1e-10,sum(abs(sigmoid(dm10) -sigmoid(m10))) );
+
+ {
+ matrix<double,2,1> x, l, u, out;
+ x = 3,4;
+
+ l = 1,1;
+ u = 2,2.2;
+
+ out = 2, 2.2;
+ DLIB_TEST(equal(dlib::clamp(x, l, u) , out));
+ out = 3, 2.2;
+ DLIB_TEST(!equal(dlib::clamp(x, l, u) , out));
+ out = 2, 4.2;
+ DLIB_TEST(!equal(dlib::clamp(x, l, u) , out));
+
+ x = 1.5, 1.5;
+ out = x;
+ DLIB_TEST(equal(dlib::clamp(x, l, u) , out));
+
+ x = 0.5, 1.5;
+ out = 1, 1.5;
+ DLIB_TEST(equal(dlib::clamp(x, l, u) , out));
+
+ x = 1.5, 0.5;
+ out = 1.5, 1.0;
+ DLIB_TEST(equal(dlib::clamp(x, l, u) , out));
+
+ }
+
+ matrix<double, 7, 7,MM,column_major_layout> m7;
+ matrix<double> dm7(7,7);
+ dm7 = randm(7,7, rnd);
+ m7 = dm7;
+
+ DLIB_TEST_MSG(max(abs(dm7*inv(dm7) - identity_matrix<double>(7))) < 1e-12, max(abs(dm7*inv(dm7) - identity_matrix<double>(7))));
+ DLIB_TEST(equal(inv(dm7), inv(m7)));
+ DLIB_TEST(abs(det(dm7) - det(m7)) < 1e-14);
+ DLIB_TEST(abs(min(dm7) - min(m7)) < 1e-14);
+ DLIB_TEST(abs(max(dm7) - max(m7)) < 1e-14);
+ DLIB_TEST_MSG(abs(sum(dm7) - sum(m7)) < 1e-13,sum(dm7) - sum(m7));
+ DLIB_TEST(abs(prod(dm7) -prod(m7)) < 1e-14);
+ DLIB_TEST(equal(diag(dm7) , diag(m7)));
+ DLIB_TEST(equal(trans(dm7) , trans(m7)));
+ DLIB_TEST(equal(abs(dm7) , abs(m7)));
+ DLIB_TEST(equal(round(dm7) , round(m7)));
+ DLIB_TEST(matrix_cast<int>(dm7) == matrix_cast<int>(m7));
+ DLIB_TEST((rotate<2,3>(dm7) == rotate<2,3>(m7)));
+ DLIB_TEST((sum(pointwise_multiply(dm7,dm7) - pointwise_multiply(m7,m7))) < 1e-10);
+ DLIB_TEST((sum(pointwise_multiply(dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7))) < 1e-10);
+ DLIB_TEST_MSG((sum(pointwise_multiply(dm7,dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7,m7))) < 1e-10,
+ (sum(pointwise_multiply(dm7,dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7,m7)))
+ );
+
+
+ matrix<double> temp(5,5);
+ matrix<double> dsm(5,5);
+ matrix<double,5,5,MM> sm;
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,1);
+
+ dsm += dsm;
+ sm += sm;
+ DLIB_TEST(dsm == 2*temp);
+ DLIB_TEST(sm == 2*temp);
+ temp = dsm*sm + dsm;
+ dsm += dsm*sm;
+ DLIB_TEST_MSG(temp == dsm,temp - dsm);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,1);
+
+ dsm += dsm;
+ sm += sm;
+ DLIB_TEST(dsm == 2*temp);
+ DLIB_TEST(sm == 2*temp);
+ temp = dsm*sm + dsm;
+ sm += dsm*sm;
+ DLIB_TEST_MSG(temp == sm,temp - sm);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,1);
+
+ dsm += dsm;
+ sm += sm;
+ DLIB_TEST(dsm == 2*temp);
+ DLIB_TEST(sm == 2*temp);
+ temp = sm - dsm*sm ;
+ sm -= dsm*sm;
+ DLIB_TEST_MSG(temp == sm,temp - sm);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,1);
+
+ dsm += dsm;
+ sm += sm;
+ DLIB_TEST(dsm == 2*temp);
+ DLIB_TEST(sm == 2*temp);
+ temp = dsm - dsm*sm ;
+ dsm -= dsm*sm;
+ DLIB_TEST_MSG(temp == dsm,temp - dsm);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,2);
+
+ dsm *= 2;
+ sm *= 2;
+ DLIB_TEST(dsm == temp);
+ DLIB_TEST(sm == temp);
+ dsm /= 2;
+ sm /= 2;
+ DLIB_TEST(dsm == temp/2);
+ DLIB_TEST(sm == temp/2);
+
+ dsm += dsm;
+ sm += sm;
+ DLIB_TEST(dsm == temp);
+ DLIB_TEST(sm == temp);
+ dsm += sm;
+ sm += dsm;
+ DLIB_TEST(dsm == 2*temp);
+ DLIB_TEST(sm == temp*3);
+ dsm -= sm;
+ sm -= dsm;
+ DLIB_TEST(dsm == -temp);
+ DLIB_TEST(sm == 4*temp);
+ sm -= sm;
+ dsm -= dsm;
+ DLIB_TEST(dsm == 0*temp);
+ DLIB_TEST(sm == 0*temp);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,3);
+ dsm += sm+sm;
+ DLIB_TEST(dsm == temp);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,-1);
+ dsm -= sm+sm;
+ DLIB_TEST(dsm == temp);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,-1);
+ sm -= dsm+dsm;
+ DLIB_TEST(sm == temp);
+
+ set_all_elements(dsm,1);
+ set_all_elements(sm,1);
+ set_all_elements(temp,3);
+ sm += dsm+dsm;
+ DLIB_TEST(sm == temp);
+
+
+
+ // test the implicit conversion to bool stuff
+ {
+ matrix<float> bt1(3,1);
+ matrix<float,3,1> bt2;
+ set_all_elements(bt1,2);
+ set_all_elements(bt2,3);
+
+ float val = trans(bt1)*bt2;
+ DLIB_TEST((float)(trans(bt1)*bt2) == 18);
+ DLIB_TEST((float)(trans(bt1)*bt2) != 19);
+ DLIB_TEST(val == 18);
+ }
+ {
+ matrix<float,3,1> bt1;
+ matrix<float> bt2(3,1);
+ set_all_elements(bt1,2);
+ set_all_elements(bt2,3);
+
+ float val = trans(bt1)*bt2;
+ DLIB_TEST((float)(trans(bt1)*bt2) == 18);
+ DLIB_TEST((float)(trans(bt1)*bt2) != 19);
+ DLIB_TEST(val == 18);
+ }
+ {
+ matrix<float> bt1(3,1);
+ matrix<float> bt2(3,1);
+ set_all_elements(bt1,2);
+ set_all_elements(bt2,3);
+
+ float val = trans(bt1)*bt2;
+ DLIB_TEST((float)(trans(bt1)*bt2) == 18);
+ DLIB_TEST((float)(trans(bt1)*bt2) != 19);
+ DLIB_TEST(val == 18);
+ }
+ {
+ matrix<float,3,1> bt1;
+ matrix<float,3,1> bt2;
+ set_all_elements(bt1,2);
+ set_all_elements(bt2,3);
+
+ float val = trans(bt1)*bt2;
+ DLIB_TEST((float)(trans(bt1)*bt2) == 18);
+ DLIB_TEST((float)(trans(bt1)*bt2) != 19);
+ DLIB_TEST(val == 18);
+ }
+
+
+
+
+ {
+ srand(423452);
+ const long M = 50;
+ const long N = 40;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double> u, u2;
+ matrix<double> q, q2;
+ matrix<double> v, v2;
+
+ matrix<double> a2;
+ a2 = tmp(a/2);
+
+
+ svd2(true,true,a2+a2,u,q,v);
+
+ double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ using dlib::equal;
+ DLIB_TEST((equal(trans(u)*u , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v)*v , identity_matrix<double,N>(), 1e-10)));
+
+ svd2(false,true,a2+a2,u,q,v2);
+ svd2(true,false,a2+a2,u2,q,v);
+ svd2(false,false,a2+a2,u,q2,v);
+
+ err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ DLIB_TEST((equal(trans(u2)*u2 , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v2)*v2 , identity_matrix<double,N>(), 1e-10)));
+
+ }
+
+
+ {
+ srand(423452);
+ const long M = 3;
+ const long N = 3;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,M,M> u, u2;
+ matrix<double> q, q2;
+ matrix<double,N,N> v, v2;
+
+ matrix<double,M,N,MM> a2;
+ a2 = tmp(a/2);
+
+
+ svd2(true,true,a2+a2,u,q,v);
+
+ double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ using dlib::equal;
+ DLIB_TEST((equal(trans(u)*u , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v)*v , identity_matrix<double,N>(), 1e-10)));
+
+ svd2(false,true,a2+a2,u,q,v2);
+ svd2(true,false,a2+a2,u2,q,v);
+ svd2(false,false,a2+a2,u,q2,v);
+
+ err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ DLIB_TEST((equal(trans(u2)*u2 , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v2)*v2 , identity_matrix<double,N>(), 1e-10)));
+
+ }
+
+ {
+ srand(423452);
+ const long M = 3;
+ const long N = 3;
+
+
+ matrix<double,0,0,default_memory_manager, column_major_layout> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,M,M,default_memory_manager, column_major_layout> u, u2;
+ matrix<double,0,0,default_memory_manager, column_major_layout> q, q2;
+ matrix<double,N,N,default_memory_manager, column_major_layout> v, v2;
+
+ matrix<double,M,N,MM, column_major_layout> a2;
+ a2 = tmp(a/2);
+
+
+ svd2(true,true,a2+a2,u,q,v);
+
+ double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ using dlib::equal;
+ DLIB_TEST((equal(trans(u)*u , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v)*v , identity_matrix<double,N>(), 1e-10)));
+
+ svd2(false,true,a2+a2,u,q,v2);
+ svd2(true,false,a2+a2,u2,q,v);
+ svd2(false,false,a2+a2,u,q2,v);
+
+ err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2)));
+ DLIB_TEST_MSG( err < 1e-11,"err: " << err);
+ DLIB_TEST((equal(trans(u2)*u2 , identity_matrix<double,M>(), 1e-10)));
+ DLIB_TEST((equal(trans(v2)*v2 , identity_matrix<double,N>(), 1e-10)));
+
+ }
+
+
+
+ {
+ srand(423452);
+ const long M = 10;
+ const long N = 7;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,M,M> u;
+ matrix<double> q;
+ matrix<double,N,N> v;
+
+ matrix<double,M,N,MM> a2;
+ a2 = tmp(a/2);
+
+
+ svd2(true,true,a2+a2,u,q,v);
+
+ double err = sum(round(1e10*(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v))));
+ DLIB_TEST_MSG( err == 0,"err: " << err);
+ DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix<double,M>()));
+ DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix<double,N>()));
+ }
+
+
+ }
+
+
+ void matrix_test2 (
+ )
+ {
+ typedef memory_manager_stateless<char>::kernel_2_2a MM;
+ {
+ srand(423452);
+ const long M = 10;
+ const long N = 7;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,M> u(M,N);
+ matrix<double> w;
+ matrix<double,N,N> v(N,N);
+
+ matrix<double,M,N,MM> a2;
+ a2 = tmp(a/2);
+
+
+ svd(a2+a2,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix<double,N>()));
+ DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix<double,N>()));
+ }
+
+ {
+ srand(423452);
+ const long M = 1;
+ const long N = 1;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,M,N> u;
+ matrix<double> w;
+ matrix<double,N,N> v;
+
+ matrix<double,M,N> a2;
+ a2 = 0;
+ a2 = tmp(a/2);
+
+
+ svd(a2+a2,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix<double,N>()));
+ DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix<double,N>()));
+ }
+
+
+ {
+ srand(53434);
+ const long M = 5;
+ const long N = 5;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double,0,N> u(M,N);
+ matrix<double,N,N> w;
+ matrix<double> v;
+
+ svd(a,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix<double,N>()));
+ DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix<double,N>()));
+ }
+
+
+ {
+ srand(11234);
+ const long M = 9;
+ const long N = 4;
+
+ matrix<double,0,0,MM> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double> u;
+ matrix<double,0,0,MM> w;
+ matrix<double> v;
+
+ svd(a,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix<double,N>()));
+ DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix<double,N>()));
+ }
+
+
+
+ {
+ srand(53934);
+ const long M = 2;
+ const long N = 4;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double> u;
+ matrix<double> w;
+ matrix<double> v;
+
+ svd(a,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ }
+
+
+ {
+ srand(53234);
+ const long M = 9;
+ const long N = 40;
+
+ matrix<double> a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ matrix<double> u;
+ matrix<double> w;
+ matrix<double> v;
+
+ svd(a,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ }
+
+ {
+ srand(53234);
+ const long M = 9;
+ const long N = 40;
+
+ typedef matrix<double,0,0,default_memory_manager, column_major_layout> mat;
+ mat a(M,N);
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = 10*((double)::rand())/RAND_MAX;
+ }
+ }
+
+ mat u;
+ mat w;
+ mat v;
+
+ svd(a,u,w,v);
+
+ DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0);
+ }
+
+
+ {
+ matrix<double> a(3,3);
+ matrix<double,3,3> b;
+ set_all_elements(a,0);
+
+ a(0,0) = 1;
+ a(1,1) = 2;
+ a(2,2) = 3;
+ b = a;
+
+ DLIB_TEST(diag(a)(0) == 1);
+ DLIB_TEST(diag(a)(1) == 2);
+ DLIB_TEST(diag(a)(2) == 3);
+ DLIB_TEST(diag(a).nr() == 3);
+ DLIB_TEST(diag(a).nc() == 1);
+
+ DLIB_TEST(diag(b)(0) == 1);
+ DLIB_TEST(diag(b)(1) == 2);
+ DLIB_TEST(diag(b)(2) == 3);
+ DLIB_TEST(diag(b).nr() == 3);
+ DLIB_TEST(diag(b).nc() == 1);
+
+ DLIB_TEST(pointwise_multiply(a,b)(0,0) == 1);
+ DLIB_TEST(pointwise_multiply(a,b)(1,1) == 4);
+ DLIB_TEST(pointwise_multiply(a,b)(2,2) == 9);
+ DLIB_TEST(pointwise_multiply(a,b)(1,0) == 0);
+ DLIB_TEST(pointwise_multiply(a,b,a)(1,0) == 0);
+ DLIB_TEST(pointwise_multiply(a,b,a,b)(1,0) == 0);
+
+
+ DLIB_TEST(complex_matrix(a,b)(0,0) == std::complex<double>(1,1));
+ DLIB_TEST(complex_matrix(a,b)(2,2) == std::complex<double>(3,3));
+ DLIB_TEST(complex_matrix(a,b)(2,1) == std::complex<double>(0,0));
+ }
+
+ {
+ matrix<complex<double> > m(2,2), m2(2,2);
+ complex<double> val1(1,2), val2(1.0/complex<double>(1,2));
+ m = val1;
+ m2 = val2;
+
+ DLIB_TEST(equal(reciprocal(m) , m2));
+ }
+ {
+ matrix<complex<float> > m(2,2), m2(2,2);
+ complex<float> val1(1,2), val2(1.0f/complex<float>(1,2));
+ m = val1;
+ m2 = val2;
+
+ DLIB_TEST(equal(reciprocal(m) , m2));
+ }
+
+ {
+ matrix<float,3,1> m1, m2;
+ set_all_elements(m1,2.0);
+ set_all_elements(m2,1.0/2.0);
+ DLIB_TEST(reciprocal(m1) == m2);
+ DLIB_TEST((reciprocal(uniform_matrix<float,3,1>(2.0)) == m2));
+ DLIB_TEST((round_zeros(uniform_matrix<float,3,1>(1e-8f)) == uniform_matrix<float,3,1>(0)) );
+ set_all_elements(m1,2.0);
+ m2 = m1;
+ m1(1,0) = static_cast<float>(1e-8);
+ m2(1,0) = 0;
+ DLIB_TEST(round_zeros(m1) == m2);
+ m1 = round_zeros(m1);
+ DLIB_TEST(m1 == m2);
+ }
+
+ {
+ matrix<matrix<double,2,2> > m;
+ m.set_size(3,3);
+ set_all_elements(m,uniform_matrix<double,2,2>(1));
+ DLIB_TEST((sum(m) == uniform_matrix<double,2,2>(9)));
+ DLIB_TEST((round_zeros(sqrt(sum(m)) - uniform_matrix<double,2,2>(3)) == uniform_matrix<double,2,2>(0)));
+ }
+
+ {
+ matrix<int,2,2> m1;
+ matrix<int> m2;
+ m2.set_size(2,2);
+
+ set_all_elements(m1,2);
+ m2 = uniform_matrix<int,2,2>(2);
+
+ m1 = m1 + m2;
+ DLIB_TEST((m1 == uniform_matrix<int,2,2>(4)));
+
+ set_all_elements(m1,2);
+ set_all_elements(m2,2);
+ m1 = m1*m1;
+ DLIB_TEST((m1 == uniform_matrix<int,2,2>(8)));
+
+ m1(1,0) = 1;
+ set_all_elements(m2,8);
+ m2(0,1) = 1;
+ m1 = trans(m1);
+ DLIB_TEST(m1 == m2);
+ }
+
+ {
+ matrix<double,2,3> m;
+ matrix<double> m2(2,3);
+
+ set_all_elements(m,1);
+ DLIB_TEST(mean(m) == 1);
+ set_all_elements(m,2);
+ DLIB_TEST(mean(m) == 2);
+ m(0,0) = 1;
+ m(0,1) = 1;
+ m(0,2) = 1;
+ DLIB_TEST(abs(mean(m) - 1.5) < 1e-10);
+ DLIB_TEST(abs(variance(m) - 0.3) < 1e-10);
+
+ set_all_elements(m2,1);
+ DLIB_TEST(mean(m2) == 1);
+ set_all_elements(m2,2);
+ DLIB_TEST(mean(m2) == 2);
+ m2(0,0) = 1;
+ m2(0,1) = 1;
+ m2(0,2) = 1;
+ DLIB_TEST(abs(mean(m2) - 1.5) < 1e-10);
+ DLIB_TEST(abs(variance(m2) - 0.3) < 1e-10);
+
+ set_all_elements(m,0);
+ DLIB_TEST(abs(variance(m)) < 1e-10);
+ set_all_elements(m,1);
+ DLIB_TEST(abs(variance(m)) < 1e-10);
+ set_all_elements(m,23.4);
+ DLIB_TEST(abs(variance(m)) < 1e-10);
+ }
+
+ {
+ matrix<matrix<double,3,1,MM>,2,2,MM> m;
+ set_all_elements(m,uniform_matrix<double,3,1>(1));
+ DLIB_TEST((round_zeros(variance(m)) == uniform_matrix<double,3,1>(0)));
+ DLIB_TEST((round_zeros(mean(m)) == uniform_matrix<double,3,1>(1)));
+ m(0,0) = uniform_matrix<double,3,1>(9);
+ DLIB_TEST((round_zeros(variance(m)) == uniform_matrix<double,3,1>(16)));
+ DLIB_TEST((round_zeros(mean(m)) == uniform_matrix<double,3,1>(3)));
+
+ matrix<matrix<double> > m2(2,2);
+ set_all_elements(m2,uniform_matrix<double,3,1>(1));
+ DLIB_TEST((round_zeros(variance(m2)) == uniform_matrix<double,3,1>(0)));
+ DLIB_TEST((round_zeros(mean(m2)) == uniform_matrix<double,3,1>(1)));
+ m2(0,0) = uniform_matrix<double,3,1>(9);
+ DLIB_TEST((round_zeros(variance(m2)) == uniform_matrix<double,3,1>(16)));
+ DLIB_TEST((round_zeros(mean(m2)) == uniform_matrix<double,3,1>(3)));
+ }
+
+
+ {
+ matrix<double> m(4,4), m2;
+ m = 1,2,3,4,
+ 1,2,3,4,
+ 4,6,8,10,
+ 4,6,8,10;
+ m2 = m;
+
+ DLIB_TEST(colm(m,range(0,3)) == m);
+ DLIB_TEST(rowm(m,range(0,3)) == m);
+ DLIB_TEST(colm(m,range(0,0)) == colm(m,0));
+ DLIB_TEST(rowm(m,range(0,0)) == rowm(m,0));
+ DLIB_TEST(colm(m,range(1,1)) == colm(m,1));
+ DLIB_TEST(rowm(m,range(1,1)) == rowm(m,1));
+
+ DLIB_TEST(colm(m,range(2,2)) == colm(m,2));
+ DLIB_TEST(rowm(m,range(2,2)) == rowm(m,2));
+
+ DLIB_TEST(colm(m,range(1,2)) == subm(m,0,1,4,2));
+ DLIB_TEST(rowm(m,range(1,2)) == subm(m,1,0,2,4));
+
+ set_colm(m,range(1,2)) = 9;
+ set_subm(m2,0,1,4,2) = 9;
+ DLIB_TEST(m == m2);
+
+ set_colm(m,range(1,2)) = 11;
+ set_subm(m2,0,1,4,2) = 11;
+ DLIB_TEST(m == m2);
+ }
+
+ {
+ print_spinner();
+ matrix<double,1,1> m1;
+ matrix<double,2,2> m2;
+ matrix<double,3,3> m3;
+ matrix<double,4,4> m4;
+
+ dlib::rand rnd;
+ for (int i = 0; i < 50; ++i)
+ {
+ m1 = randm(1,1,rnd);
+ m2 = randm(2,2,rnd);
+ m3 = randm(3,3,rnd);
+ m4 = randm(4,4,rnd);
+
+ DLIB_TEST(max(abs(m1*inv(m1) - identity_matrix(m1))) < 1e-13);
+ DLIB_TEST(max(abs(m2*inv(m2) - identity_matrix(m2))) < 1e-12);
+ DLIB_TEST(max(abs(m3*inv(m3) - identity_matrix(m3))) < 1e-13);
+ DLIB_TEST_MSG(max(abs(m4*inv(m4) - identity_matrix(m4))) < 1e-12, max(abs(m4*inv(m4) - identity_matrix(m4))));
+ }
+ }
+
+ }
+
+
+
+
+
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix2",
+ "Runs tests on the matrix component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ matrix_test1();
+ matrix_test2();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/matrix3.cpp b/ml/dlib/dlib/test/matrix3.cpp
new file mode 100644
index 000000000..b66af638c
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix3.cpp
@@ -0,0 +1,1134 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+
+#include "tester.h"
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/array2d.h>
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix3");
+
+
+ const double eps_mul = 200;
+
+ template <typename T, typename U>
+ void check_equal (
+ const T& a,
+ const U& b
+ )
+ {
+ DLIB_TEST(a.nr() == b.nr());
+ DLIB_TEST(a.nc() == b.nc());
+ typedef typename T::type type;
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ type error = std::abs(a(r,c) - b(r,c));
+ DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul, "error: " << error <<
+ " eps: " << std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul);
+ }
+ }
+ }
+
+ template <typename T, typename U>
+ void c_check_equal (
+ const T& a,
+ const U& b
+ )
+ {
+ DLIB_TEST(a.nr() == b.nr());
+ DLIB_TEST(a.nc() == b.nc());
+ typedef typename T::type type;
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ typename type::value_type error = std::abs(a(r,c) - b(r,c));
+ DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits<typename type::value_type>::epsilon())*eps_mul, "error: " << error <<
+ " eps: " << std::sqrt(std::numeric_limits<typename type::value_type>::epsilon())*eps_mul);
+ }
+ }
+ }
+
+ template <typename T, typename U>
+ void assign_no_blas (
+ const T& a_,
+ const U& b
+ )
+ {
+ T& a = const_cast<T&>(a_);
+ DLIB_TEST(a.nr() == b.nr());
+ DLIB_TEST(a.nc() == b.nc());
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = b(r,c);
+ }
+ }
+ }
+
+ template <typename type>
+ type rnd_num (dlib::rand& rnd)
+ {
+ return static_cast<type>(10*rnd.get_random_double());
+ }
+
+ template <typename type>
+ void test_blas( long rows, long cols)
+ {
+ // The tests in this function exercise the BLAS bindings located in the matrix/matrix_blas_bindings.h file.
+ // It does this by performing an assignment that is subject to BLAS bindings and comparing the
+ // results directly to an unevaluated matrix_exp that should be equal.
+
+ dlib::rand rnd;
+
+ matrix<type> a(rows,cols), temp, temp2, temp3;
+
+ for (int k = 0; k < 6; ++k)
+ {
+ for (long r= 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a(r,c) = rnd_num<type>(rnd);
+ }
+ }
+ matrix<type> at;
+ at = trans(a);
+
+ matrix<complex<type> > c_a(rows,cols), c_at, c_sqr;
+ for (long r= 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ c_a(r,c) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+ }
+ }
+ c_at = trans(c_a);
+ const int size = max(rows,cols);
+ c_sqr = 10*matrix_cast<complex<type> >(complex_matrix(randm(size,size,rnd), randm(size,size,rnd)));
+
+
+ matrix<complex<type> > c_temp(cols,cols), c_temp2(cols,cols);
+ const complex<type> i(0,1);
+
+ const type one = 1;
+ const type two = 1;
+ const type num1 = static_cast<type>(3.6);
+ const type num2 = static_cast<type>(6.6);
+ const type num3 = static_cast<type>(8.6);
+
+ matrix<complex<type>,0,1> c_cv4(cols), c_cv3(rows);
+ matrix<complex<type>,1,0> c_rv4(cols), c_rv3(rows);
+
+ matrix<type,0,1> cv4(cols);
+
+ for (long idx = 0; idx < cv4.size(); ++idx)
+ cv4(idx) = rnd_num<type>(rnd);
+
+ for (long idx = 0; idx < c_cv4.size(); ++idx)
+ c_cv4(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+
+ matrix<type,1,0> rv3(rows);
+
+ for (long idx = 0; idx < rv3.size(); ++idx)
+ rv3(idx) = rnd_num<type>(rnd);
+
+ for (long idx = 0; idx < c_rv3.size(); ++idx)
+ c_rv3(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+
+ matrix<type,0,1> cv3(rows);
+
+ for (long idx = 0; idx < cv3.size(); ++idx)
+ cv3(idx) = rnd_num<type>(rnd);
+
+ for (long idx = 0; idx < c_cv3.size(); ++idx)
+ c_cv3(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+
+ matrix<type,1,0> rv4(cols);
+ for (long idx = 0; idx < rv4.size(); ++idx)
+ rv4(idx) = rnd_num<type>(rnd);
+
+ for (long idx = 0; idx < c_rv4.size(); ++idx)
+ c_rv4(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+
+
+
+ // GEMM tests
+ dlog << LTRACE << "1.1";
+ check_equal(tmp(at*a), at*a);
+ check_equal(tmp(trans(at*a)), trans(at*a));
+ check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a));
+ dlog << LTRACE << "1.2";
+ check_equal(tmp(trans(a)*a), trans(a)*a);
+ check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a));
+ dlog << LTRACE << "1.3";
+ check_equal(tmp(at*trans(at)), at*trans(at));
+ check_equal(tmp(trans(at*trans(at))), trans(at*trans(at)));
+ dlog << LTRACE << "1.4";
+ check_equal(tmp(trans(at)*trans(a)), a*at);
+ check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at));
+ dlog << LTRACE << "1.5";
+
+ print_spinner();
+ c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a);
+ dlog << LTRACE << "1.5.1";
+ c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a));
+ dlog << LTRACE << "1.5.2";
+ c_check_equal(tmp((conj(trans(c_sqr))*trans(c_sqr))), (trans(conj(c_sqr))*trans(c_sqr)));
+ dlog << LTRACE << "1.5.3";
+ c_check_equal(tmp(trans(conj(trans(c_sqr))*trans(c_sqr))), trans(trans(conj(c_sqr))*trans(c_sqr)));
+ dlog << LTRACE << "1.6";
+ c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at)));
+ dlog << LTRACE << "1.6.1";
+ c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at))));
+ dlog << LTRACE << "1.6.2";
+ c_check_equal(tmp((c_sqr)*trans(conj(c_sqr))), (c_sqr)*conj(trans(c_sqr)));
+ dlog << LTRACE << "1.6.2.1";
+ c_check_equal(tmp(trans(c_sqr)*trans(conj(c_sqr))), trans(c_sqr)*conj(trans(c_sqr)));
+ dlog << LTRACE << "1.6.3";
+ c_check_equal(tmp(trans(trans(c_sqr)*trans(conj(c_sqr)))), trans(trans(c_sqr)*conj(trans(c_sqr))));
+ dlog << LTRACE << "1.7";
+ c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a)));
+ c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a))));
+ dlog << LTRACE << "1.8";
+
+ check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1)));
+ check_equal(tmp(a*colm(at,1)) , a*colm(at,1));
+ check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2));
+
+ dlog << LTRACE << "1.9";
+ check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1))));
+ dlog << LTRACE << "1.10";
+ check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1)));
+ dlog << LTRACE << "1.11";
+ check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2)));
+ dlog << LTRACE << "1.12";
+
+ {
+ temp = at*a;
+ temp2 = temp;
+
+ temp += 3.5*at*a;
+ assign_no_blas(temp2, temp2 + 3.5*at*a);
+ check_equal(temp, temp2);
+
+ temp -= at*3.5*a;
+ assign_no_blas(temp2, temp2 - at*3.5*a);
+ check_equal(temp, temp2);
+
+ temp = temp + 4*at*a;
+ assign_no_blas(temp2, temp2 + 4*at*a);
+ check_equal(temp, temp2);
+
+ temp = temp - 2.4*at*a;
+ assign_no_blas(temp2, temp2 - 2.4*at*a);
+ check_equal(temp, temp2);
+ }
+ dlog << LTRACE << "1.13";
+ {
+ temp = trans(at*a);
+ temp2 = temp;
+ temp3 = temp;
+
+ dlog << LTRACE << "1.14";
+ temp += trans(3.5*at*a);
+ assign_no_blas(temp2, temp2 + trans(3.5*at*a));
+ check_equal(temp, temp2);
+
+ dlog << LTRACE << "1.15";
+ temp -= trans(at*3.5*a);
+ assign_no_blas(temp2, temp2 - trans(at*3.5*a));
+ check_equal(temp, temp2);
+
+ dlog << LTRACE << "1.16";
+ temp = trans(temp + 4*at*a);
+ assign_no_blas(temp3, trans(temp2 + 4*at*a));
+ check_equal(temp, temp3);
+
+ temp2 = temp;
+ dlog << LTRACE << "1.17";
+ temp = trans(temp - 2.4*at*a);
+ assign_no_blas(temp3, trans(temp2 - 2.4*at*a));
+ check_equal(temp, temp3);
+ }
+
+ dlog << LTRACE << "1.17.1";
+ {
+ matrix<type> m1, m2;
+
+ m1 = matrix_cast<type>(randm(rows, cols, rnd));
+ m2 = matrix_cast<type>(randm(cols, rows + 8, rnd));
+ check_equal(tmp(m1*m2), m1*m2);
+ check_equal(tmp(trans(m1*m2)), trans(m1*m2));
+
+ m1 = trans(m1);
+ check_equal(tmp(trans(m1)*m2), trans(m1)*m2);
+ check_equal(tmp(trans(trans(m1)*m2)), trans(trans(m1)*m2));
+
+ m2 = trans(m2);
+ check_equal(tmp(trans(m1)*trans(m2)), trans(m1)*trans(m2));
+ check_equal(tmp(trans(trans(m1)*trans(m2))), trans(trans(m1)*trans(m2)));
+
+ m1 = trans(m1);
+ check_equal(tmp(m1*trans(m2)), m1*trans(m2));
+ check_equal(tmp(trans(m1*trans(m2))), trans(m1*trans(m2)));
+ }
+
+ dlog << LTRACE << "1.17.5";
+ {
+ matrix<type,1,0> r;
+ matrix<type,0,1> c;
+
+ r = matrix_cast<type>(randm(1, rows+9, rnd));
+ c = matrix_cast<type>(randm(rows, 1, rnd));
+
+ check_equal(tmp(c*r), c*r);
+ check_equal(tmp(trans(c*r)), trans(c*r));
+
+ check_equal(tmp(trans(r)*trans(c)), trans(r)*trans(c));
+ check_equal(tmp(trans(trans(r)*trans(c))), trans(trans(r)*trans(c)));
+ }
+
+ dlog << LTRACE << "1.18";
+
+ // GEMV tests
+ check_equal(tmp(a*cv4), a*cv4);
+ check_equal(tmp(trans(a*cv4)), trans(a*cv4));
+ check_equal(tmp(rv3*a), rv3*a);
+ check_equal(tmp(trans(cv4)*at), trans(cv4)*at);
+ check_equal(tmp(a*trans(rv4)), a*trans(rv4));
+ check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4)));
+
+ check_equal(tmp(trans(a)*cv3), trans(a)*cv3);
+ check_equal(tmp(rv4*trans(a)), rv4*trans(a));
+ check_equal(tmp(trans(cv3)*trans(at)), trans(cv3)*trans(at));
+ check_equal(tmp(trans(cv3)*a), trans(cv3)*a);
+ check_equal(tmp(trans(a)*trans(rv3)), trans(a)*trans(rv3));
+
+
+ c_check_equal(tmp(trans(conj(c_a))*c_cv3), trans(conj(c_a))*c_cv3);
+ c_check_equal(tmp(c_rv4*trans(conj(c_a))), c_rv4*trans(conj(c_a)));
+ c_check_equal(tmp(trans(c_cv3)*trans(conj(c_at))), trans(c_cv3)*trans(conj(c_at)));
+ c_check_equal(tmp(conj(trans(c_a))*trans(c_rv3)), trans(conj(c_a))*trans(c_rv3));
+ c_check_equal(tmp(c_rv4*conj(c_at)), c_rv4*conj(c_at));
+ c_check_equal(tmp(trans(c_cv4)*conj(c_at)), trans(c_cv4)*conj(c_at));
+
+ dlog << LTRACE << "2.00";
+
+ c_check_equal(tmp(trans(trans(conj(c_a))*c_cv3)), trans(trans(conj(c_a))*c_cv3));
+ c_check_equal(tmp(trans(c_rv4*trans(conj(c_a)))), trans(c_rv4*trans(conj(c_a))));
+ c_check_equal(tmp(trans(trans(c_cv3)*trans(conj(c_at)))), trans(trans(c_cv3)*trans(conj(c_at))));
+ dlog << LTRACE << "2.20";
+ c_check_equal(tmp(trans(conj(trans(c_a))*trans(c_rv3))), trans(trans(conj(c_a))*trans(c_rv3)));
+ c_check_equal(tmp(trans(c_rv4*conj(c_at))), trans(c_rv4*conj(c_at)));
+ c_check_equal(tmp(trans(trans(c_cv4)*conj(c_at))), trans(trans(c_cv4)*conj(c_at)));
+
+
+
+ dlog << LTRACE << "6";
+ temp = a*at;
+ check_equal(temp, a*at);
+ temp = temp + a*at + trans(at)*at + trans(at)*sin(at);
+ check_equal(temp, a*at + a*at+ trans(at)*at + trans(at)*sin(at));
+
+ dlog << LTRACE << "6.1";
+ temp = a*at;
+ check_equal(temp, a*at);
+ temp = a*at + temp;
+ check_equal(temp, a*at + a*at);
+
+ print_spinner();
+ dlog << LTRACE << "6.2";
+ temp = a*at;
+ check_equal(temp, a*at);
+ dlog << LTRACE << "6.2.3";
+ temp = temp - a*at;
+ dlog << LTRACE << "6.2.4";
+ check_equal(temp, a*at-a*at);
+
+ dlog << LTRACE << "6.3";
+ temp = a*at;
+ dlog << LTRACE << "6.3.5";
+ check_equal(temp, a*at);
+ dlog << LTRACE << "6.3.6";
+ temp = a*at - temp;
+ dlog << LTRACE << "6.4";
+ check_equal(temp, a*at-a*at);
+
+
+
+ const long d = min(rows,cols);
+ rectangle rect(1,1,d,d);
+ temp.set_size(max(rows,cols)+4,max(rows,cols)+4);
+ set_all_elements(temp,4);
+ temp2 = temp;
+
+ dlog << LTRACE << "7";
+ set_subm(temp,rect) = a*at;
+ assign_no_blas( set_subm(temp2,rect) , a*at);
+ check_equal(temp, temp2);
+
+ temp = a;
+ temp2 = a;
+
+ set_colm(temp,1) = a*cv4;
+ assign_no_blas( set_colm(temp2,1) , a*cv4);
+ check_equal(temp, temp2);
+
+ set_rowm(temp,1) = rv3*a;
+ assign_no_blas( set_rowm(temp2,1) , rv3*a);
+ check_equal(temp, temp2);
+
+
+ // Test BLAS GER
+ {
+ temp.set_size(cols,cols);
+ set_all_elements(temp,3);
+ temp2 = temp;
+
+
+ dlog << LTRACE << "8";
+ temp += cv4*rv4;
+ assign_no_blas(temp2, temp2 + cv4*rv4);
+ check_equal(temp, temp2);
+
+ dlog << LTRACE << "8.3";
+ temp = temp + cv4*rv4;
+ assign_no_blas(temp2, temp2 + cv4*rv4);
+ check_equal(temp, temp2);
+ dlog << LTRACE << "8.9";
+ }
+ {
+ temp.set_size(cols,cols);
+ set_all_elements(temp,3);
+ temp2 = temp;
+ temp3 = 0;
+
+ dlog << LTRACE << "8.10";
+
+ temp += trans(cv4*rv4);
+ assign_no_blas(temp3, temp2 + trans(cv4*rv4));
+ check_equal(temp, temp3);
+ temp3 = 0;
+
+ dlog << LTRACE << "8.11";
+ temp2 = temp;
+ temp = trans(temp + cv4*rv4);
+ assign_no_blas(temp3, trans(temp2 + cv4*rv4));
+ check_equal(temp, temp3);
+ dlog << LTRACE << "8.12";
+ }
+ {
+ matrix<complex<type> > temp, temp2, temp3;
+ matrix<complex<type>,0,1 > cv4;
+ matrix<complex<type>,1,0 > rv4;
+ cv4.set_size(cols);
+ rv4.set_size(cols);
+ temp.set_size(cols,cols);
+ set_all_elements(temp,complex<type>(3,5));
+ temp(cols-1, cols-4) = 9;
+ temp2 = temp;
+ temp3.set_size(cols,cols);
+ temp3 = 0;
+
+ for (long i = 0; i < rv4.size(); ++i)
+ {
+ rv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+ cv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
+ }
+
+ dlog << LTRACE << "8.13";
+
+ temp += trans(cv4*rv4);
+ assign_no_blas(temp3, temp2 + trans(cv4*rv4));
+ c_check_equal(temp, temp3);
+ temp3 = 0;
+
+ dlog << LTRACE << "8.14";
+ temp2 = temp;
+ temp = trans(temp + cv4*rv4);
+ assign_no_blas(temp3, trans(temp2 + cv4*rv4));
+ c_check_equal(temp, temp3);
+ dlog << LTRACE << "8.15";
+ }
+
+
+
+
+ set_all_elements(c_temp, one + num1*i);
+ c_temp2 = c_temp;
+ set_all_elements(c_rv4, one + num2*i);
+ set_all_elements(c_cv4, two + num3*i);
+
+
+ dlog << LTRACE << "9";
+ c_temp += c_cv4*c_rv4;
+ assign_no_blas(c_temp2, c_temp2 + c_cv4*c_rv4);
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.1";
+ c_temp += c_cv4*conj(c_rv4);
+ assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.2";
+ c_temp = c_cv4*conj(c_rv4) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.3";
+ c_temp = trans(c_rv4)*trans(conj(c_cv4)) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + trans(c_rv4)*trans(conj(c_cv4)));
+ c_check_equal(c_temp, c_temp2);
+
+
+ dlog << LTRACE << "9.4";
+ c_temp += conj(c_cv4)*c_rv4;
+ assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*c_rv4);
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.5";
+ c_temp += conj(c_cv4)*conj(c_rv4);
+ assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.6";
+ c_temp = conj(c_cv4)*conj(c_rv4) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "9.7";
+ c_temp = conj(trans(c_rv4))*trans(conj(c_cv4)) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + conj(trans(c_rv4))*trans(conj(c_cv4)));
+ c_check_equal(c_temp, c_temp2);
+
+
+ dlog << LTRACE << "10";
+ c_temp += trans(c_cv4*c_rv4);
+ assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.1";
+ c_temp += trans(c_cv4*conj(c_rv4));
+ assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4)));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.2";
+ c_temp = trans(c_cv4*conj(c_rv4)) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4)));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.3";
+ c_temp = trans(trans(c_rv4)*trans(conj(c_cv4))) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + trans(trans(c_rv4)*trans(conj(c_cv4))));
+ c_check_equal(c_temp, c_temp2);
+
+
+ dlog << LTRACE << "10.4";
+ c_temp += trans(conj(c_cv4)*c_rv4);
+ assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*c_rv4));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.5";
+ c_temp += trans(conj(c_cv4)*conj(c_rv4));
+ assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4)));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.6";
+ c_temp = trans(conj(c_cv4)*conj(c_rv4)) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4)));
+ c_check_equal(c_temp, c_temp2);
+ dlog << LTRACE << "10.7";
+ c_temp = trans(conj(trans(c_rv4))*trans(conj(c_cv4))) + c_temp;
+ assign_no_blas(c_temp2, c_temp2 + trans(conj(trans(c_rv4))*trans(conj(c_cv4))));
+ c_check_equal(c_temp, c_temp2);
+
+ dlog << LTRACE << "10.8";
+
+
+ print_spinner();
+
+ // Test DOT
+ check_equal( tmp(rv4*cv4), rv4*cv4);
+ check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4));
+ check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4));
+ check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4);
+ check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4));
+ check_equal( tmp(rv4*cv4*3.9), rv4*3.9*cv4);
+ check_equal( tmp(trans(cv4)*trans(rv4)*3.9), trans(cv4)*3.9*trans(rv4));
+
+
+ check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4));
+ check_equal( tmp(trans(trans(rv4*cv4))), trans(trans(rv4*cv4)));
+ check_equal( tmp(trans(trans(cv4)*trans(rv4))), trans(trans(cv4)*trans(rv4)));
+ check_equal( tmp(trans(rv4*3.9*cv4)), trans(rv4*3.9*cv4));
+ check_equal( tmp(trans(trans(cv4)*3.9*trans(rv4))), trans(trans(cv4)*3.9*trans(rv4)));
+ check_equal( tmp(trans(rv4*cv4*3.9)), trans(rv4*3.9*cv4));
+ check_equal( tmp(trans(trans(cv4)*trans(rv4)*3.9)), trans(trans(cv4)*3.9*trans(rv4)));
+
+
+ temp.set_size(1,1);
+ temp = 4;
+ check_equal( tmp(temp + rv4*cv4), temp + rv4*cv4);
+ check_equal( tmp(temp + trans(cv4)*trans(rv4)), temp + trans(cv4)*trans(rv4));
+
+ dlog << LTRACE << "11";
+
+
+
+ c_check_equal( tmp(conj(c_rv4)*c_cv4), conj(c_rv4)*c_cv4);
+ c_check_equal( tmp(conj(trans(c_cv4))*trans(c_rv4)), trans(conj(c_cv4))*trans(c_rv4));
+
+ c_check_equal( tmp(conj(c_rv4)*i*c_cv4), conj(c_rv4)*i*c_cv4);
+ c_check_equal( tmp(conj(trans(c_cv4))*i*trans(c_rv4)), trans(conj(c_cv4))*i*trans(c_rv4));
+
+ c_temp.set_size(1,1);
+ c_temp = 4;
+ c_check_equal( tmp(c_temp + conj(c_rv4)*c_cv4), c_temp + conj(c_rv4)*c_cv4);
+ c_check_equal( tmp(c_temp + trans(conj(c_cv4))*trans(c_rv4)), c_temp + trans(conj(c_cv4))*trans(c_rv4));
+
+ complex<type> tmp = c_rv4*c_cv4;
+ DLIB_TEST(abs((tmp + i) - ((c_rv4*c_cv4)(0) + i)) < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul );
+ DLIB_TEST(max(abs((rv4*cv4 + 1.0) - ((rv4*cv4)(0) + 1.0))) < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul);
+
+ }
+
+ {
+ matrix<int> m(2,3), m2(6,1);
+
+ m = 1,2,3,
+ 4,5,6;
+
+ m2 = 1,2,3,4,5,6;
+
+ DLIB_TEST(reshape_to_column_vector(m) == m2);
+ DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2);
+
+ }
+ {
+ matrix<int,2,3> m(2,3);
+ matrix<int> m2(6,1);
+
+ m = 1,2,3,
+ 4,5,6;
+
+ m2 = 1,2,3,4,5,6;
+
+ DLIB_TEST(reshape_to_column_vector(m) == m2);
+ DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2);
+
+ }
+ }
+
+
+ void matrix_test (
+ )
+ /*!
+ ensures
+ - runs tests on the matrix stuff compliance with the specs
+ !*/
+ {
+ print_spinner();
+
+
+ {
+ matrix<long> m1(2,2), m2(2,2);
+
+ m1 = 1, 2,
+ 3, 4;
+
+ m2 = 4, 5,
+ 6, 7;
+
+
+ DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(0,1)) == 1*m2);
+ DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(2,3)) == 2*m2);
+ DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(0,1)) == 3*m2);
+ DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(2,3)) == 4*m2);
+ }
+
+ {
+ print_spinner();
+ dlog << LTRACE << "testing blas stuff";
+ dlog << LTRACE << " \nsmall double";
+ test_blas<double>(3,4);
+ print_spinner();
+ dlog << LTRACE << " \nsmall float";
+ test_blas<float>(3,4);
+ print_spinner();
+ dlog << LTRACE << " \nbig double";
+ test_blas<double>(120,131);
+ print_spinner();
+ dlog << LTRACE << " \nbig float";
+ test_blas<float>(120,131);
+ print_spinner();
+ dlog << LTRACE << "testing done";
+ }
+
+
+ {
+ matrix<long> m(3,4), ml(3,4), mu(3,4);
+ m = 1,2,3,4,
+ 4,5,6,7,
+ 7,8,9,0;
+
+ ml = 1,0,0,0,
+ 4,5,0,0,
+ 7,8,9,0;
+
+ mu = 1,2,3,4,
+ 0,5,6,7,
+ 0,0,9,0;
+
+
+ DLIB_TEST(lowerm(m) == ml);
+ DLIB_TEST(upperm(m) == mu);
+
+ ml = 3,0,0,0,
+ 4,3,0,0,
+ 7,8,3,0;
+
+ mu = 4,2,3,4,
+ 0,4,6,7,
+ 0,0,4,0;
+
+ DLIB_TEST(lowerm(m,3) == ml);
+ DLIB_TEST(upperm(m,4) == mu);
+
+ }
+
+ {
+ matrix<long> m(3,4), row(1,3), col(2,1);
+ m = 1,2,3,4,
+ 4,5,6,7,
+ 7,8,9,0;
+
+ row = 4,5,6;
+ col = 3,6;
+
+ DLIB_TEST(rowm(m, 1, 3) == row);
+ DLIB_TEST(colm(m, 2, 2) == col);
+
+ }
+
+
+ {
+ std::vector<double> v(34, 8);
+ std::vector<double> v2(34, 9);
+
+ DLIB_TEST(mat(&v[0], v.size()) == mat(v));
+ DLIB_TEST(mat(&v2[0], v.size()) != mat(v));
+ }
+
+ {
+ std::vector<long> v(1, 3);
+ std::vector<long> v2(1, 2);
+
+ DLIB_TEST(mat(&v[0], v.size()) == mat(v));
+ DLIB_TEST(mat(&v2[0], v.size()) != mat(v));
+ }
+
+ {
+ matrix<double> a(3,3), b(3,3);
+ a = 1, 2.5, 1,
+ 3, 4, 5,
+ 0.5, 2.2, 3;
+
+ b = 0, 1, 0,
+ 1, 1, 1,
+ 0, 1, 1;
+
+ DLIB_TEST((a>1) == b);
+ DLIB_TEST((1<a) == b);
+
+ b = 1, 1, 1,
+ 1, 1, 1,
+ 0, 1, 1;
+
+ DLIB_TEST((a>=1) == b);
+ DLIB_TEST((1<=a) == b);
+
+ b = 0, 0, 0,
+ 0, 0, 0,
+ 0, 1, 0;
+ DLIB_TEST((a==2.2) == b);
+ DLIB_TEST((a!=2.2) == (b==0));
+ DLIB_TEST((2.2==a) == b);
+ DLIB_TEST((2.2!=a) == (0==b));
+
+ b = 0, 0, 0,
+ 0, 0, 0,
+ 1, 0, 0;
+ DLIB_TEST((a<1) == b);
+ DLIB_TEST((1>a) == b);
+
+ b = 1, 0, 1,
+ 0, 0, 0,
+ 1, 0, 0;
+ DLIB_TEST((a<=1) == b);
+ DLIB_TEST((1>=a) == b);
+ }
+
+ {
+ matrix<double> a, b, c;
+ a = randm(4,2);
+
+ b += a;
+ c -= a;
+
+ DLIB_TEST(equal(a, b));
+ DLIB_TEST(equal(-a, c));
+
+ b += a;
+ c -= a;
+
+ DLIB_TEST(equal(2*a, b));
+ DLIB_TEST(equal(-2*a, c));
+
+ b += a + a;
+ c -= a + a;
+
+ DLIB_TEST(equal(4*a, b));
+ DLIB_TEST(equal(-4*a, c));
+
+ b.set_size(0,0);
+ c.set_size(0,0);
+
+
+ b += a + a;
+ c -= a + a;
+
+ DLIB_TEST(equal(2*a, b));
+ DLIB_TEST(equal(-2*a, c));
+ }
+
+ {
+ matrix<int> a, b, c;
+
+ a.set_size(2, 3);
+ b.set_size(2, 6);
+ c.set_size(4, 3);
+
+ a = 1, 2, 3,
+ 4, 5, 6;
+
+ b = 1, 2, 3, 1, 2, 3,
+ 4, 5, 6, 4, 5, 6;
+
+ c = 1, 2, 3,
+ 4, 5, 6,
+ 1, 2, 3,
+ 4, 5, 6;
+
+ DLIB_TEST(join_rows(a,a) == b);
+ DLIB_TEST(join_rows(a,abs(a)) == b);
+ DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b));
+ DLIB_TEST(join_cols(a,a) == c);
+ DLIB_TEST(join_cols(a,abs(a)) == c);
+ DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c));
+ }
+
+ {
+ matrix<int, 2, 3> a;
+ matrix<int, 2, 6> b;
+ matrix<int, 4, 3> c;
+
+ a = 1, 2, 3,
+ 4, 5, 6;
+
+ b = 1, 2, 3, 1, 2, 3,
+ 4, 5, 6, 4, 5, 6;
+
+ c = 1, 2, 3,
+ 4, 5, 6,
+ 1, 2, 3,
+ 4, 5, 6;
+
+ DLIB_TEST(join_rows(a,a) == b);
+ DLIB_TEST(join_rows(a,abs(a)) == b);
+ DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b));
+ DLIB_TEST(join_cols(a,a) == c);
+ DLIB_TEST(join_cols(a,abs(a)) == c);
+ DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c));
+ }
+
+ {
+ matrix<int, 2, 3> a;
+ matrix<int> a2;
+ matrix<int, 2, 6> b;
+ matrix<int, 4, 3> c;
+
+ a = 1, 2, 3,
+ 4, 5, 6;
+
+ a2 = a;
+
+ b = 1, 2, 3, 1, 2, 3,
+ 4, 5, 6, 4, 5, 6;
+
+ c = 1, 2, 3,
+ 4, 5, 6,
+ 1, 2, 3,
+ 4, 5, 6;
+
+ DLIB_TEST(join_rows(a,a2) == b);
+ DLIB_TEST(join_rows(a2,a) == b);
+ DLIB_TEST(join_cols(trans(a2), trans(a)) == trans(b));
+ DLIB_TEST(join_cols(a2,a) == c);
+ DLIB_TEST(join_cols(a,a2) == c);
+ DLIB_TEST(join_rows(trans(a2),trans(a)) == trans(c));
+ }
+
+ {
+ matrix<int> a, b;
+
+ a.set_size(2,3);
+
+ a = 1, 2, 3,
+ 4, 5, 6;
+
+ b.set_size(3,2);
+ b = 1, 2,
+ 3, 4,
+ 5, 6;
+
+ DLIB_TEST(reshape(a, 3, 2) == b);
+
+ b.set_size(2,3);
+ b = 1, 4, 2,
+ 5, 3, 6;
+
+ DLIB_TEST(reshape(trans(a), 2, 3) == b);
+
+ }
+
+ {
+ matrix<int,2,3> a;
+ matrix<int> b;
+
+ a = 1, 2, 3,
+ 4, 5, 6;
+
+ b.set_size(3,2);
+ b = 1, 2,
+ 3, 4,
+ 5, 6;
+
+ DLIB_TEST(reshape(a, 3, 2) == b);
+
+ b.set_size(2,3);
+ b = 1, 4, 2,
+ 5, 3, 6;
+
+ DLIB_TEST(reshape(trans(a), 2, 3) == b);
+
+ }
+
+ {
+ std::vector<int> v(6);
+ for (unsigned long i = 0; i < v.size(); ++i)
+ v[i] = i;
+
+ matrix<int,2,3> a;
+ a = 0, 1, 2,
+ 3, 4, 5;
+
+ DLIB_TEST(mat(&v[0], 2, 3) == a);
+ }
+
+ {
+ matrix<int> a(3,4);
+ matrix<int> b(3,1), c(1,4);
+
+ a = 1, 2, 3, 6,
+ 4, 5, 6, 9,
+ 1, 1, 1, 3;
+
+ b(0) = sum(rowm(a,0));
+ b(1) = sum(rowm(a,1));
+ b(2) = sum(rowm(a,2));
+
+ c(0) = sum(colm(a,0));
+ c(1) = sum(colm(a,1));
+ c(2) = sum(colm(a,2));
+ c(3) = sum(colm(a,3));
+
+ DLIB_TEST(sum_cols(a) == b);
+ DLIB_TEST(sum_rows(a) == c);
+
+ }
+
+ {
+ matrix<int> m(3,3);
+
+ m = 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9;
+
+ DLIB_TEST(make_symmetric(m) == trans(make_symmetric(m)));
+ DLIB_TEST(lowerm(make_symmetric(m)) == lowerm(m));
+ DLIB_TEST(upperm(make_symmetric(m)) == trans(lowerm(m)));
+ }
+
+ {
+ matrix<int,3,4> a;
+ matrix<int> b(3,1), c(1,4);
+
+ a = 1, 2, 3, 6,
+ 4, 5, 6, 9,
+ 1, 1, 1, 3;
+
+ b(0) = sum(rowm(a,0));
+ b(1) = sum(rowm(a,1));
+ b(2) = sum(rowm(a,2));
+
+ c(0) = sum(colm(a,0));
+ c(1) = sum(colm(a,1));
+ c(2) = sum(colm(a,2));
+ c(3) = sum(colm(a,3));
+
+ DLIB_TEST(sum_cols(a) == b);
+ DLIB_TEST(sum_rows(a) == c);
+
+ }
+
+ {
+ matrix<int> m(3,4), s(3,4);
+ m = -2, 1, 5, -5,
+ 5, 5, 5, 5,
+ 9, 0, -4, -2;
+
+ s = -1, 1, 1, -1,
+ 1, 1, 1, 1,
+ 1, 1, -1, -1;
+
+ DLIB_TEST(sign(m) == s);
+ DLIB_TEST(sign(matrix_cast<double>(m)) == matrix_cast<double>(s));
+ }
+
+ }
+
+
+ void test_matrix_IO()
+ {
+ dlib::rand rnd;
+ print_spinner();
+
+ for (int i = 0; i < 400; ++i)
+ {
+ ostringstream sout;
+ sout.precision(20);
+
+ matrix<double> m1, m2, m3;
+
+ const long r = rnd.get_random_32bit_number()%7+1;
+ const long c = rnd.get_random_32bit_number()%7+1;
+ const long num = rnd.get_random_32bit_number()%2+1;
+
+ m1 = randm(r,c,rnd);
+ sout << m1;
+ if (num != 1)
+ sout << "\n" << m1;
+
+ if (rnd.get_random_double() < 0.3)
+ sout << " \n";
+ else if (rnd.get_random_double() < 0.3)
+ sout << " \n\n 3 3 3 3";
+ else if (rnd.get_random_double() < 0.3)
+ sout << " \n \n v 3 3 3 3 3";
+
+ istringstream sin(sout.str());
+ sin >> m2;
+ DLIB_TEST_MSG(equal(m1,m2), m1 << "\n***********\n" << m2);
+
+ if (num != 1)
+ {
+ sin >> m3;
+ DLIB_TEST_MSG(equal(m1,m3), m1 << "\n***********\n" << m3);
+ }
+ }
+
+
+ {
+ istringstream sin(" 1 2\n3");
+ matrix<double> m;
+ DLIB_TEST(sin.good());
+ sin >> m;
+ DLIB_TEST(!sin.good());
+ }
+ {
+ istringstream sin("");
+ matrix<double> m;
+ DLIB_TEST(sin.good());
+ sin >> m;
+ DLIB_TEST(!sin.good());
+ }
+ }
+
+
+ void test_axpy()
+ {
+ const int n = 4;
+ matrix<double> B = dlib::randm(n,n);
+
+ matrix<double> g = dlib::uniform_matrix<double>(n,1,0.0);
+
+ const double tau = 1;
+
+ matrix<double> p = g + tau*dlib::colm(B,0);
+ matrix<double> q = dlib::colm(B,0);
+ DLIB_TEST(max(abs(p-q)) < 1e-14);
+
+ p = tau*dlib::colm(B,0);
+ q = dlib::colm(B,0);
+ DLIB_TEST(max(abs(p-q)) < 1e-14);
+
+
+
+
+ g = dlib::uniform_matrix<double>(n,n,0.0);
+ p = g + tau*B;
+ DLIB_TEST(max(abs(p-B)) < 1e-14);
+
+ p = g + tau*subm(B,get_rect(B));
+ DLIB_TEST(max(abs(p-B)) < 1e-14);
+
+ g = dlib::uniform_matrix<double>(2,2,0.0);
+ p = g + tau*subm(B,1,1,2,2);
+ DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14);
+
+ set_subm(p,0,0,2,2) = g + tau*subm(B,1,1,2,2);
+ DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14);
+ }
+
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix3",
+ "Runs tests on the matrix component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_axpy();
+ test_matrix_IO();
+ matrix_test();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/matrix4.cpp b/ml/dlib/dlib/test/matrix4.cpp
new file mode 100644
index 000000000..d2b83b712
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix4.cpp
@@ -0,0 +1,1119 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+
+#include "tester.h"
+#include <dlib/memory_manager_stateless.h>
+#include <dlib/array2d.h>
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix4");
+
+ void matrix_test (
+ )
+ /*!
+ ensures
+ - runs tests on the matrix stuff compliance with the specs
+ !*/
+ {
+ print_spinner();
+
+ {
+ matrix<double,3,3> m = round(10*randm(3,3));
+ matrix<double,3,1> v = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double,3,3> m = round(10*randm(3,3));
+ matrix<double,1,3> v = round(10*randm(1,3));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double> m = round(10*randm(3,3));
+ matrix<double,1,3> v = round(10*randm(1,3));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double> m = round(10*randm(3,3));
+ matrix<double,0,3> v = round(10*randm(1,3));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+
+ {
+ matrix<double> m = round(10*randm(3,3));
+ matrix<double,1,0> v = round(10*randm(1,3));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double> m = round(10*randm(3,3));
+ matrix<double,3,0> v = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+
+ {
+ matrix<double> m = round(10*randm(3,3));
+ matrix<double,0,1> v = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double,3,3> m = round(10*randm(3,3));
+ matrix<double,3,0> v = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+
+ {
+ matrix<double,3,3> m = round(10*randm(3,3));
+ matrix<double,0,1> v = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
+ DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
+
+ DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
+ DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
+ }
+
+ {
+ matrix<double,3,5> m = round(10*randm(3,5));
+ matrix<double,0,1> v1 = round(10*randm(5,1));
+ matrix<double,0,1> v2 = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) ));
+ DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) ));
+
+ DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m ));
+ DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m ));
+ }
+
+ {
+ matrix<double,3,5> m = round(10*randm(3,5));
+ matrix<double,5,1> v1 = round(10*randm(5,1));
+ matrix<double,3,1> v2 = round(10*randm(3,1));
+
+ DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) ));
+ DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) ));
+
+ DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m ));
+ DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m ));
+ }
+
+ }
+
+ void test_stuff()
+ {
+ print_spinner();
+
+ {
+ matrix<double> m(3,3), lr(3,3), ud(3,3);
+
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+
+ lr = 3,2,1,
+ 6,5,4,
+ 9,8,7;
+
+ ud = 7,8,9,
+ 4,5,6,
+ 1,2,3;
+
+ DLIB_TEST(lr == fliplr(m));
+ DLIB_TEST(ud == flipud(m));
+ }
+ {
+ matrix<double> m(3,2), lr(3,2), ud(3,2);
+
+ m = 1,2,
+ 3,4,
+ 5,6;
+
+ lr = 2,1,
+ 4,3,
+ 6,5;
+
+ ud = 5,6,
+ 3,4,
+ 1,2;
+
+ DLIB_TEST(lr == fliplr(m));
+ DLIB_TEST(ud == flipud(m));
+ }
+
+ {
+ matrix<int> a, b;
+
+ a = matrix_cast<int>(round(10*randm(3,3)));
+ b = a;
+
+ b *= b;
+ DLIB_TEST(b == a*a);
+ }
+
+ {
+ matrix<double> m(2,3), m2(2,3);
+
+ m = 1,2,3,
+ 4,5,6;
+
+
+ m2 = 3,4,5,
+ 6,7,8;
+
+ DLIB_TEST(m + 2 == m2);
+ DLIB_TEST(2 + m == m2);
+
+ m += 2;
+ DLIB_TEST(m == m2);
+ m -= 2;
+
+ m2 = 0,1,2,
+ 3,4,5;
+
+ DLIB_TEST(m - 1 == m2);
+
+ m -= 1;
+ DLIB_TEST(m == m2);
+ m += 1;
+
+
+ m2 = 5,4,3,
+ 2,1,0;
+
+ DLIB_TEST(6 - m == m2);
+ }
+
+ {
+ matrix<float> m(2,3), m2(2,3);
+
+ m = 1,2,3,
+ 4,5,6;
+
+
+ m2 = 3,4,5,
+ 6,7,8;
+
+ DLIB_TEST(m + 2 == m2);
+ DLIB_TEST(2 + m == m2);
+
+ m += 2;
+ DLIB_TEST(m == m2);
+ m -= 2;
+
+ m2 = 0,1,2,
+ 3,4,5;
+
+ DLIB_TEST(m - 1 == m2);
+
+ m -= 1;
+ DLIB_TEST(m == m2);
+ m += 1;
+
+
+ m2 = 5,4,3,
+ 2,1,0;
+
+ DLIB_TEST(6 - m == m2);
+ }
+
+ {
+ matrix<int> m(2,3), m2(2,3);
+
+ m = 1,2,3,
+ 4,5,6;
+
+
+ m2 = 3,4,5,
+ 6,7,8;
+
+ DLIB_TEST(m + 2 == m2);
+ DLIB_TEST(2 + m == m2);
+
+ m += 2;
+ DLIB_TEST(m == m2);
+ m -= 2;
+
+ m2 = 0,1,2,
+ 3,4,5;
+
+ DLIB_TEST(m - 1 == m2);
+
+ m -= 1;
+ DLIB_TEST(m == m2);
+ m += 1;
+
+
+ m2 = 5,4,3,
+ 2,1,0;
+
+ DLIB_TEST(6 - m == m2);
+ }
+
+ {
+ matrix<int,2,3> m, m2;
+
+ m = 1,2,3,
+ 4,5,6;
+
+
+ m2 = 3,4,5,
+ 6,7,8;
+
+ DLIB_TEST(m + 2 == m2);
+ DLIB_TEST(2 + m == m2);
+
+ m += 2;
+ DLIB_TEST(m == m2);
+ m -= 2;
+
+ m2 = 0,1,2,
+ 3,4,5;
+
+ DLIB_TEST(m - 1 == m2);
+
+ m -= 1;
+ DLIB_TEST(m == m2);
+ m += 1;
+
+
+ m2 = 5,4,3,
+ 2,1,0;
+
+ DLIB_TEST(6 - m == m2);
+ }
+
+ {
+ matrix<double> m(2,3), m2(3,2);
+
+ m = 1,2,3,
+ 4,5,6;
+
+ m2 = 2,5,
+ 3,6,
+ 4,7;
+
+ DLIB_TEST(trans(m+1) == m2);
+ DLIB_TEST(trans(m)+1 == m2);
+ DLIB_TEST(1+trans(m) == m2);
+ DLIB_TEST(1+m-1 == m);
+
+ m = trans(m+1);
+ DLIB_TEST(m == m2);
+ m = trans(m-1);
+ DLIB_TEST(trans(m+1) == m2);
+ m = trans(m)+1;
+ DLIB_TEST(m == m2);
+ }
+
+ {
+ matrix<double> d(3,1), di(3,1);
+ matrix<double> m(3,3);
+
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ d = 1,2,3;
+
+ di = 1, 1/2.0, 1/3.0;
+
+ DLIB_TEST(inv(diagm(d)) == diagm(di));
+ DLIB_TEST(pinv(diagm(d)) == diagm(di));
+ DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m);
+ DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di)));
+
+ DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m));
+ DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m));
+
+ DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
+
+ DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
+
+ }
+ {
+ matrix<double,3,1> d(3,1), di(3,1);
+ matrix<double,3,3> m(3,3);
+
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ d = 1,2,3;
+
+ di = 1, 1/2.0, 1/3.0;
+
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m));
+ DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di))));
+
+ DLIB_TEST_MSG(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m),
+ (inv(diagm(d)) + m) - (tmp(diagm(di)) + m) );
+ DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m));
+
+
+ DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
+
+ DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
+ }
+
+ {
+ matrix<double,1,3> d(1,3), di(1,3);
+ matrix<double,3,3> m(3,3);
+
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ d = 1,2,3;
+
+ di = 1, 1/2.0, 1/3.0;
+
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m));
+ DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di))));
+
+ DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m));
+ DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m));
+
+
+ DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
+
+ DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
+ }
+
+ {
+ matrix<double,1,0> d(1,3), di(1,3);
+ matrix<double,0,3> m(3,3);
+
+ m = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ d = 1,2,3;
+
+ di = 1, 1/2.0, 1/3.0;
+
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d)) , diagm(di)));
+ DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m));
+ DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di))));
+
+ DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m));
+ DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m));
+
+
+ DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
+ DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
+
+ DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
+ DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
+ DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
+ DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
+ }
+
+
+ {
+ matrix<double,3,1> d1, d2;
+
+ d1 = 1,2,3;
+
+ d2 = 2,3,4;
+
+ matrix<double,3,3> ans;
+ ans = 2, 0, 0,
+ 0, 6, 0,
+ 0, 0, 12;
+
+ DLIB_TEST(ans == diagm(d1)*diagm(d2));
+ }
+
+
+ dlib::rand rnd;
+ for (int i = 0; i < 1; ++i)
+ {
+ matrix<double> d1 = randm(4,1,rnd);
+ matrix<double,5,1> d2 = randm(5,1,rnd);
+
+ matrix<double,4,5> m = randm(4,5,rnd);
+
+ DLIB_TEST_MSG(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2)),
+ pointwise_multiply(d1*trans(d2), m) - diagm(d1)*m*diagm(d2)
+ );
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2)));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2))));
+ }
+ for (int i = 0; i < 1; ++i)
+ {
+ matrix<double,4,1> d1 = randm(4,1,rnd);
+ matrix<double,5,1> d2 = randm(5,1,rnd);
+
+ matrix<double,4,5> m = randm(4,5,rnd);
+
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2)));
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2)));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2))));
+ }
+ for (int i = 0; i < 1; ++i)
+ {
+ matrix<double,4,1> d1 = randm(4,1,rnd);
+ matrix<double,5,1> d2 = randm(5,1,rnd);
+
+ matrix<double,0,0> m = randm(4,5,rnd);
+
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2)));
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2)));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2))));
+
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2)))));
+ DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2))));
+ }
+
+
+
+ {
+ for (int i = 0; i < 5; ++i)
+ {
+ matrix<double> m = randm(3,4) + 1;
+
+ DLIB_TEST(equal(1.0/m , reciprocal(m)));
+ DLIB_TEST(equal(0.0/m , zeros_matrix<double>(3,4)));
+ }
+ }
+
+ {
+ matrix<int> m(2,3);
+ m = 1,2,3,
+ 4,5,6;
+ matrix<int> M(2,3);
+ M = m;
+
+ DLIB_TEST(upperbound(m,6) == M);
+ DLIB_TEST(upperbound(m,60) == M);
+ DLIB_TEST(lowerbound(m,-2) == M);
+ DLIB_TEST(lowerbound(m,0) == M);
+
+ M = 2,2,3,
+ 4,5,6;
+ DLIB_TEST(lowerbound(m,2) == M);
+
+ M = 0,0,0,
+ 0,0,0;
+ DLIB_TEST(upperbound(m,0) == M);
+
+ M = 1,2,3,
+ 3,3,3;
+ DLIB_TEST(upperbound(m,3) == M);
+ }
+
+ {
+ matrix<double,9,5> A = randm(9,5);
+ matrix<double> B = A;
+
+ orthogonalize(A);
+ orthogonalize(B);
+
+ DLIB_TEST(equal(A,B));
+ }
+ }
+
+
+
+ template <
+ long D1,
+ long D2,
+ long D3,
+ long D4
+ >
+ void test_conv()
+ {
+ dlog << LINFO << D1 << " " << D2 << " " << D3 << " " << D4;
+ matrix<int,D1,D1> a(1,1);
+ matrix<int,D2,D2> b(2,2);
+ matrix<int,D3,D3> c(3,3);
+ matrix<int,D4,D1> d(4,1);
+
+ a = 4;
+
+ b = 1,2,
+ 3,4;
+
+ c = 1,2,3,
+ 4,5,6,
+ 7,8,9;
+
+ d = 1,
+ 2,
+ 3,
+ 4;
+
+ matrix<int> temp(4,4), temp2;
+ temp = 1, 4, 7, 6,
+ 7, 23, 33, 24,
+ 19, 53, 63, 42,
+ 21, 52, 59, 36;
+
+ DLIB_TEST(conv(b,c) == temp);
+ DLIB_TEST(conv(c,b) == temp);
+ DLIB_TEST(xcorr(c,flip(b)) == temp);
+
+ temp.set_size(2,2);
+ temp = 23, 33,
+ 53, 63;
+ DLIB_TEST(conv_same(b,c) == temp);
+ DLIB_TEST(xcorr_same(b,flip(c)) == temp);
+
+ temp2.set_size(2,2);
+ temp2 = 63, 53,
+ 33, 23;
+ DLIB_TEST(flip(temp) == temp2);
+ DLIB_TEST(flip(temp) == fliplr(flipud(temp)));
+
+ DLIB_TEST(conv_valid(b,c).nr() == 0);
+ DLIB_TEST(conv_valid(b,c).nc() == 0);
+
+ DLIB_TEST(conv_valid(c,b) == temp);
+ DLIB_TEST(xcorr_valid(c,flip(b)) == temp);
+
+ temp.set_size(1,1);
+ temp = 16;
+
+ DLIB_TEST(conv(a,a) == temp);
+ DLIB_TEST(conv_same(a,a) == temp);
+ DLIB_TEST(conv_valid(a,a) == temp);
+ DLIB_TEST(xcorr(a,a) == temp);
+ DLIB_TEST(xcorr_same(a,a) == temp);
+ DLIB_TEST(xcorr_valid(a,a) == temp);
+
+ temp.set_size(0,0);
+ DLIB_TEST(conv(temp,temp).nr() == 0);
+ DLIB_TEST(conv(temp,temp).nc() == 0);
+ DLIB_TEST(conv_same(temp,temp).nr() == 0);
+ DLIB_TEST(conv_same(temp,temp).nc() == 0);
+ DLIB_TEST_MSG(conv_valid(temp,temp).nr() == 0, conv_valid(temp,temp).nr());
+ DLIB_TEST(conv_valid(temp,temp).nc() == 0);
+ DLIB_TEST(conv(c,temp).nr() == 0);
+ DLIB_TEST(conv(c,temp).nc() == 0);
+ DLIB_TEST(conv_same(c,temp).nr() == 0);
+ DLIB_TEST(conv_same(c,temp).nc() == 0);
+ DLIB_TEST(conv_valid(c,temp).nr() == 0);
+ DLIB_TEST(conv_valid(c,temp).nc() == 0);
+ DLIB_TEST(conv(temp,c).nr() == 0);
+ DLIB_TEST(conv(temp,c).nc() == 0);
+ DLIB_TEST(conv_same(temp,c).nr() == 0);
+ DLIB_TEST(conv_same(temp,c).nc() == 0);
+ DLIB_TEST(conv_valid(temp,c).nr() == 0);
+ DLIB_TEST(conv_valid(temp,c).nc() == 0);
+
+ temp.set_size(5,2);
+ temp = 1, 2,
+ 5, 8,
+ 9, 14,
+ 13, 20,
+ 12, 16;
+ DLIB_TEST(conv(b,d) == temp);
+ DLIB_TEST(xcorr(b,flip(d)) == temp);
+
+ temp.set_size(2,2);
+ temp = 9, 14,
+ 13, 20;
+ DLIB_TEST(conv_same(b,d) == temp);
+ DLIB_TEST(xcorr_same(b,flip(d)) == temp);
+
+ DLIB_TEST(conv_valid(b,d).nr() == 0);
+ DLIB_TEST(xcorr_valid(b,flip(d)).nr() == 0);
+ DLIB_TEST_MSG(conv_valid(b,d).nc() == 0, conv_valid(b,d).nc());
+ DLIB_TEST(xcorr_valid(b,flip(d)).nc() == 0);
+
+ temp.set_size(5,5);
+ temp = 1, 4, 10, 12, 9,
+ 8, 26, 56, 54, 36,
+ 30, 84, 165, 144, 90,
+ 56, 134, 236, 186, 108,
+ 49, 112, 190, 144, 81;
+
+ DLIB_TEST(conv(c,c) == temp);
+ DLIB_TEST(xcorr(c,flip(c)) == temp);
+ matrix<int> temp3 = c;
+ temp3 = conv(temp3,c);
+ DLIB_TEST(temp3 == temp);
+
+ temp3 = c;
+ temp3 = conv(c,temp3);
+ DLIB_TEST(temp3 == temp);
+
+
+ temp.set_size(3,3);
+ temp = 26, 56, 54,
+ 84, 165, 144,
+ 134, 236, 186;
+ DLIB_TEST(conv_same(c,c) == temp);
+ DLIB_TEST(xcorr_same(c,flip(c)) == temp);
+ temp3 = c;
+ temp3 = conv_same(c,temp3);
+ DLIB_TEST(temp3 == temp);
+ temp3 = c;
+ temp3 = conv_same(temp3,c);
+ DLIB_TEST(temp3 == temp);
+
+ temp.set_size(1,1);
+ temp = 165;
+ DLIB_TEST(conv_valid(c,c) == temp);
+ DLIB_TEST(xcorr_valid(c,flip(c)) == temp);
+ temp3 = c;
+ temp3 = conv_valid(c,temp3);
+ DLIB_TEST(temp3 == temp);
+ temp3 = c;
+ temp3 = conv_valid(temp3,c);
+ DLIB_TEST(temp3 == temp);
+
+
+ dlib::rand rnd;
+ for (int i = 0; i < 3; ++i)
+ {
+ matrix<complex<int> > a, b;
+ a = complex_matrix(matrix_cast<int>(round(20*randm(2,7,rnd))),
+ matrix_cast<int>(round(20*randm(2,7,rnd))));
+ b = complex_matrix(matrix_cast<int>(round(20*randm(3,2,rnd))),
+ matrix_cast<int>(round(20*randm(3,2,rnd))));
+
+ DLIB_TEST(xcorr(a,b) == conv(a, flip(conj(b))));
+ DLIB_TEST(xcorr_valid(a,b) == conv_valid(a, flip(conj(b))));
+ DLIB_TEST(xcorr_same(a,b) == conv_same(a, flip(conj(b))));
+ }
+
+
+ for (int i = 0; i < 30; ++i)
+ {
+ auto nr1 = rnd.get_integer_in_range(1,30);
+ auto nc1 = rnd.get_integer_in_range(1,30);
+ auto nr2 = rnd.get_integer_in_range(1,30);
+ auto nc2 = rnd.get_integer_in_range(1,30);
+ matrix<double> a, b;
+ a = randm(nr1,nc1,rnd);
+ b = randm(nr2,nc2,rnd);
+
+ DLIB_TEST(max(abs(xcorr(a,b) - xcorr_fft(a,b))) < 1e-12);
+ }
+ }
+
+ void test_complex()
+ {
+ matrix<complex<double> > a, b;
+
+ a = complex_matrix(linspace(1,7,7), linspace(2,8,7));
+ b = complex_matrix(linspace(4,10,7), linspace(2,8,7));
+
+ DLIB_TEST(mean(a) == complex<double>(4, 5));
+ }
+
+ void test_setsubs()
+ {
+ {
+ matrix<double> m(3,3);
+ m = 0;
+
+ set_colm(m,0) += 1;
+ set_rowm(m,0) += 1;
+ set_subm(m,1,1,2,2) += 5;
+
+ matrix<double> m2(3,3);
+ m2 = 2, 1, 1,
+ 1, 5, 5,
+ 1, 5, 5;
+
+ DLIB_TEST(m == m2);
+
+ set_colm(m,0) -= 1;
+ set_rowm(m,0) -= 1;
+ set_subm(m,1,1,2,2) -= 5;
+
+ m2 = 0;
+ DLIB_TEST(m == m2);
+
+ matrix<double,1,3> r;
+ matrix<double,3,1> c;
+ matrix<double,2,2> b;
+ r = 1,2,3;
+
+ c = 2,
+ 3,
+ 4;
+
+ b = 2,3,
+ 4,5;
+
+ set_colm(m,1) += c;
+ set_rowm(m,1) += r;
+ set_subm(m,1,1,2,2) += b;
+
+ m2 = 0, 2, 0,
+ 1, 7, 6,
+ 0, 8, 5;
+
+ DLIB_TEST(m2 == m);
+
+ set_colm(m,1) -= c;
+ set_rowm(m,1) -= r;
+ set_subm(m,1,1,2,2) -= b;
+
+ m2 = 0;
+ DLIB_TEST(m2 == m);
+
+
+ // check that the code path for destructive aliasing works right.
+ m = 2*identity_matrix<double>(3);
+ set_colm(m,1) += m*c;
+ m2 = 2, 4, 0,
+ 0, 8, 0,
+ 0, 8, 2;
+ DLIB_TEST(m == m2);
+
+ m = 2*identity_matrix<double>(3);
+ set_colm(m,1) -= m*c;
+ m2 = 2, -4, 0,
+ 0, -4, 0,
+ 0, -8, 2;
+ DLIB_TEST(m == m2);
+
+ m = 2*identity_matrix<double>(3);
+ set_rowm(m,1) += r*m;
+ m2 = 2, 0, 0,
+ 2, 6, 6,
+ 0, 0, 2;
+ DLIB_TEST(m == m2);
+
+ m = 2*identity_matrix<double>(3);
+ set_rowm(m,1) -= r*m;
+ m2 = 2, 0, 0,
+ -2, -2, -6,
+ 0, 0, 2;
+ DLIB_TEST(m == m2);
+
+ m = identity_matrix<double>(3);
+ const rectangle rect(0,0,1,1);
+ set_subm(m,rect) += subm(m,rect)*b;
+ m2 = 3, 3, 0,
+ 4, 6, 0,
+ 0, 0, 1;
+ DLIB_TEST(m == m2);
+
+ m = identity_matrix<double>(3);
+ set_subm(m,rect) -= subm(m,rect)*b;
+ m2 = -1, -3, 0,
+ -4, -4, 0,
+ 0, 0, 1;
+ DLIB_TEST(m == m2);
+
+ }
+
+ {
+ matrix<double,1,1> a, b;
+ a = 2;
+ b = 3;
+ DLIB_TEST(dot(a,b) == 6);
+ }
+ {
+ matrix<double,1,1> a;
+ matrix<double,0,1> b(1);
+ a = 2;
+ b = 3;
+ DLIB_TEST(dot(a,b) == 6);
+ DLIB_TEST(dot(b,a) == 6);
+ }
+ {
+ matrix<double,1,1> a;
+ matrix<double,1,0> b(1);
+ a = 2;
+ b = 3;
+ DLIB_TEST(dot(a,b) == 6);
+ DLIB_TEST(dot(b,a) == 6);
+ }
+ }
+
+ template <typename T>
+ std::vector<int> tovect1(const T& m)
+ {
+ std::vector<int> temp;
+ for (typename T::const_iterator i = m.begin(); i != m.end(); ++i)
+ {
+ temp.push_back(*i);
+ }
+ return temp;
+ }
+
+ template <typename T>
+ std::vector<int> tovect2(const T& m)
+ {
+ std::vector<int> temp;
+ for (typename T::const_iterator i = m.begin(); i != m.end(); i++)
+ {
+ temp.push_back(*i);
+ }
+ return temp;
+ }
+
+ template <typename T>
+ std::vector<int> tovect3(const T& m_)
+ {
+ matrix<int> m(m_);
+ std::vector<int> temp;
+ for (matrix<int>::iterator i = m.begin(); i != m.end(); ++i)
+ {
+ temp.push_back(*i);
+ }
+ return temp;
+ }
+
+ template <typename T>
+ std::vector<int> tovect4(const T& m_)
+ {
+ matrix<int> m(m_);
+ std::vector<int> temp;
+ for (matrix<int>::iterator i = m.begin(); i != m.end(); i++)
+ {
+ temp.push_back(*i);
+ }
+ return temp;
+ }
+
+ void test_iterators()
+ {
+ matrix<int> m(3,2);
+ m = 1,2,3,
+ 4,5,6;
+
+ std::vector<int> v1 = tovect1(m);
+ std::vector<int> v2 = tovect2(m);
+ std::vector<int> v3 = tovect3(m);
+ std::vector<int> v4 = tovect4(m);
+
+ std::vector<int> v5 = tovect1(m+m);
+ std::vector<int> v6 = tovect2(m+m);
+ std::vector<int> v7 = tovect3(m+m);
+ std::vector<int> v8 = tovect4(m+m);
+
+
+ std::vector<int> a1, a2;
+ for (int i = 1; i <= 6; ++i)
+ {
+ a1.push_back(i);
+ a2.push_back(i*2);
+ }
+
+ DLIB_TEST(max(abs(mat(v1) - mat(a1))) == 0);
+ DLIB_TEST(max(abs(mat(v2) - mat(a1))) == 0);
+ DLIB_TEST(max(abs(mat(v3) - mat(a1))) == 0);
+ DLIB_TEST(max(abs(mat(v4) - mat(a1))) == 0);
+
+ DLIB_TEST(max(abs(mat(v5) - mat(a2))) == 0);
+ DLIB_TEST(max(abs(mat(v6) - mat(a2))) == 0);
+ DLIB_TEST(max(abs(mat(v7) - mat(a2))) == 0);
+ DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0);
+ }
+
+ void test_linpiece()
+ {
+ matrix<double,0,1> temp = linpiece(5, linspace(-1, 9, 2));
+ DLIB_CASSERT(temp.size() == 1,"");
+ DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,"");
+
+ temp = linpiece(5, linspace(-1, 9, 6));
+ DLIB_CASSERT(temp.size() == 5,"");
+ DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
+
+ temp = linpiece(4, linspace(-1, 9, 6));
+ DLIB_CASSERT(temp.size() == 5,"");
+ DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
+
+ temp = linpiece(40, linspace(-1, 9, 6));
+ DLIB_CASSERT(temp.size() == 5,"");
+ DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,"");
+
+ temp = linpiece(-40, linspace(-1, 9, 6));
+ DLIB_CASSERT(temp.size() == 5,"");
+ DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
+
+ temp = linpiece(0, linspace(-1, 9, 6));
+ DLIB_CASSERT(temp.size() == 5,"");
+ DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
+ DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
+
+ }
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix4",
+ "Runs tests on the scale_rows and scale_columns functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_iterators();
+ test_setsubs();
+
+ test_conv<0,0,0,0>();
+ test_conv<1,2,3,4>();
+
+ test_stuff();
+ for (int i = 0; i < 10; ++i)
+ matrix_test();
+
+ test_complex();
+ test_linpiece();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/matrix_chol.cpp b/ml/dlib/dlib/test/matrix_chol.cpp
new file mode 100644
index 000000000..b46c4866d
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix_chol.cpp
@@ -0,0 +1,182 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include <dlib/string.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix_chol");
+
+ dlib::rand rnd;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename mat_type>
+ const matrix<typename mat_type::type> symm(const mat_type& m) { return m*trans(m); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename type>
+ const matrix<type> randmat(long r, long c)
+ {
+ matrix<type> m(r,c);
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+ template <typename type, long NR, long NC>
+ const matrix<type,NR,NC> randmat()
+ {
+ matrix<type,NR,NC> m;
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void test_cholesky ( const matrix_type& m)
+ {
+ typedef typename matrix_type::type type;
+ const type eps = 10*max(abs(m))*sqrt(std::numeric_limits<type>::epsilon());
+ dlog << LDEBUG << "test_cholesky(): " << m.nr() << " x " << m.nc() << " eps: " << eps;
+ print_spinner();
+
+
+ cholesky_decomposition<matrix_type> test(m);
+
+ // none of the matrices we should be passing in to test_cholesky() should be non-spd.
+ DLIB_TEST(test.is_spd() == true);
+
+ type temp;
+ DLIB_TEST_MSG( (temp= max(abs(test.get_l()*trans(test.get_l()) - m))) < eps,temp);
+
+ {
+ matrix<type> mat = chol(m);
+ DLIB_TEST_MSG( (temp= max(abs(mat*trans(mat) - m))) < eps,temp);
+ }
+
+
+ matrix<type> m2;
+ matrix<type,0,1> col;
+
+ m2 = identity_matrix<type>(m.nr());
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),5);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ col = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2)));
+
+ // now make us a non-spd matrix
+ if (m.nr() > 2)
+ {
+ matrix<type> sm(lowerm(m));
+ sm(1,1) = 0;
+
+ cholesky_decomposition<matrix_type> test2(sm);
+ DLIB_TEST_MSG(test2.is_spd() == false, test2.get_l());
+
+
+ cholesky_decomposition<matrix_type> test3(sm*trans(sm));
+ DLIB_TEST_MSG(test3.is_spd() == false, test3.get_l());
+
+ sm = sm*trans(sm);
+ sm(1,1) = 5;
+ sm(1,0) -= 1;
+ cholesky_decomposition<matrix_type> test4(sm);
+ DLIB_TEST_MSG(test4.is_spd() == false, test4.get_l());
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_double()
+ {
+
+ test_cholesky(uniform_matrix<double>(1,1,1) + 10*symm(randmat<double>(1,1)));
+ test_cholesky(uniform_matrix<double>(2,2,1) + 10*symm(randmat<double>(2,2)));
+ test_cholesky(uniform_matrix<double>(3,3,1) + 10*symm(randmat<double>(3,3)));
+ test_cholesky(uniform_matrix<double>(4,4,1) + 10*symm(randmat<double>(4,4)));
+ test_cholesky(uniform_matrix<double>(15,15,1) + 10*symm(randmat<double>(15,15)));
+ test_cholesky(uniform_matrix<double>(101,101,1) + 10*symm(randmat<double>(101,101)));
+
+ typedef matrix<double,0,0,default_memory_manager, column_major_layout> mat;
+ test_cholesky(mat(uniform_matrix<double>(101,101,1) + 10*symm(randmat<double>(101,101))));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_float()
+ {
+
+ test_cholesky(uniform_matrix<float>(1,1,1) + 2*symm(randmat<float>(1,1)));
+ test_cholesky(uniform_matrix<float>(2,2,1) + 2*symm(randmat<float>(2,2)));
+ test_cholesky(uniform_matrix<float>(3,3,1) + 2*symm(randmat<float>(3,3)));
+
+ typedef matrix<float,0,0,default_memory_manager, column_major_layout> mat;
+ test_cholesky(mat(uniform_matrix<float>(3,3,1) + 2*symm(randmat<float>(3,3))));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix_chol",
+ "Runs tests on the matrix cholesky component.")
+ {
+ rnd.set_seed(cast_to_string(time(0)));
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "seed string: " << rnd.get_seed();
+
+ dlog << LINFO << "begin testing with double";
+ matrix_test_double();
+ dlog << LINFO << "begin testing with float";
+ matrix_test_float();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/matrix_eig.cpp b/ml/dlib/dlib/test/matrix_eig.cpp
new file mode 100644
index 000000000..9fbce6598
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix_eig.cpp
@@ -0,0 +1,245 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include <dlib/string.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix_eig");
+
+ dlib::rand rnd;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename type>
+ const matrix<type> randm(long r, long c)
+ {
+ matrix<type> m(r,c);
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+ template <typename type, long NR, long NC>
+ const matrix<type,NR,NC> randm()
+ {
+ matrix<type,NR,NC> m;
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type, typename U>
+ void test_eigenvalue_impl ( const matrix_type& m, const eigenvalue_decomposition<U>& test )
+ {
+ typedef typename matrix_type::type type;
+ const type eps = 10*max(abs(m))*sqrt(std::numeric_limits<type>::epsilon());
+ dlog << LDEBUG << "test_eigenvalue(): " << m.nr() << " x " << m.nc() << " eps: " << eps;
+ print_spinner();
+
+
+ DLIB_TEST(test.dim() == m.nr());
+
+ // make sure all the various ways of asking for the eigenvalues are actually returning a
+ // consistent set of eigenvalues.
+ DLIB_TEST(equal(real(test.get_eigenvalues()), test.get_real_eigenvalues(), eps));
+ DLIB_TEST(equal(imag(test.get_eigenvalues()), test.get_imag_eigenvalues(), eps));
+ DLIB_TEST(equal(real(diag(test.get_d())), test.get_real_eigenvalues(), eps));
+ DLIB_TEST(equal(imag(diag(test.get_d())), test.get_imag_eigenvalues(), eps));
+
+ matrix<type> eig1 ( real_eigenvalues(m));
+ matrix<type> eig2 ( test.get_real_eigenvalues());
+ sort(&eig1(0), &eig1(0) + eig1.size());
+ sort(&eig2(0), &eig2(0) + eig2.size());
+ DLIB_TEST(max(abs(eig1 - eig2)) < eps);
+
+ const matrix<type> V = test.get_pseudo_v();
+ const matrix<type> D = test.get_pseudo_d();
+ const matrix<complex<type> > CV = test.get_v();
+ const matrix<complex<type> > CD = test.get_d();
+ const matrix<complex<type> > CM = complex_matrix(m, uniform_matrix<type>(m.nr(),m.nc(),0));
+
+ DLIB_TEST(V.nr() == test.dim());
+ DLIB_TEST(V.nc() == test.dim());
+ DLIB_TEST(D.nr() == test.dim());
+ DLIB_TEST(D.nc() == test.dim());
+
+ // CD is a diagonal matrix
+ DLIB_TEST(diagm(diag(CD)) == CD);
+
+ // verify that these things are actually eigenvalues and eigenvectors of m
+ DLIB_TEST_MSG(max(abs(m*V - V*D)) < eps, max(abs(m*V - V*D)) << " " << eps);
+ DLIB_TEST(max(norm(CM*CV - CV*CD)) < eps);
+
+ // if m is a symmetric matrix
+ if (max(abs(m-trans(m))) < 1e-5)
+ {
+ dlog << LTRACE << "m is symmetric";
+ // there aren't any imaginary eigenvalues
+ DLIB_TEST(max(abs(test.get_imag_eigenvalues())) < eps);
+ DLIB_TEST(diagm(diag(D)) == D);
+
+ // only check the determinant against the eigenvalues for small matrices
+ // because for huge ones the determinant might be so big it overflows a floating point number.
+ if (m.nr() < 50)
+ {
+ const type mdet = det(m);
+ DLIB_TEST_MSG(std::abs(prod(test.get_real_eigenvalues()) - mdet) < std::abs(mdet)*sqrt(std::numeric_limits<type>::epsilon()),
+ std::abs(prod(test.get_real_eigenvalues()) - mdet) <<" eps: " << std::abs(mdet)*sqrt(std::numeric_limits<type>::epsilon())
+ << " mdet: "<< mdet << " prod(eig): " << prod(test.get_real_eigenvalues())
+ );
+ }
+
+ // V is orthogonal
+ DLIB_TEST(equal(V*trans(V), identity_matrix<type>(test.dim()), eps));
+ DLIB_TEST(equal(m , V*D*trans(V), eps));
+ }
+ else
+ {
+ dlog << LTRACE << "m is NOT symmetric";
+ DLIB_TEST_MSG(equal(m , V*D*inv(V), eps), max(abs(m - V*D*inv(V))));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void test_eigenvalue ( const matrix_type& m )
+ {
+ typedef typename matrix_type::type type;
+ typedef typename matrix_type::mem_manager_type MM;
+ matrix<type,matrix_type::NR, matrix_type::NC, MM, row_major_layout> mr(m);
+ matrix<type,matrix_type::NR, matrix_type::NC, MM, column_major_layout> mc(m);
+
+ {
+ eigenvalue_decomposition<matrix_type> test(mr);
+ test_eigenvalue_impl(mr, test);
+
+ eigenvalue_decomposition<matrix_type> test_symm(make_symmetric(mr));
+ test_eigenvalue_impl(make_symmetric(mr), test_symm);
+ }
+
+ {
+ eigenvalue_decomposition<matrix_type> test(mc);
+ test_eigenvalue_impl(mc, test);
+
+ eigenvalue_decomposition<matrix_type> test_symm(make_symmetric(mc));
+ test_eigenvalue_impl(make_symmetric(mc), test_symm);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_double()
+ {
+
+ test_eigenvalue(10*randm<double>(1,1));
+ test_eigenvalue(10*randm<double>(2,2));
+ test_eigenvalue(10*randm<double>(3,3));
+ test_eigenvalue(10*randm<double>(4,4));
+ test_eigenvalue(10*randm<double>(15,15));
+ test_eigenvalue(10*randm<double>(150,150));
+
+ test_eigenvalue(10*randm<double,1,1>());
+ test_eigenvalue(10*randm<double,2,2>());
+ test_eigenvalue(10*randm<double,3,3>());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_float()
+ {
+
+ test_eigenvalue(10*randm<float>(1,1));
+ test_eigenvalue(10*randm<float>(2,2));
+ test_eigenvalue(10*randm<float>(3,3));
+ test_eigenvalue(10*randm<float>(4,4));
+ test_eigenvalue(10*randm<float>(15,15));
+ test_eigenvalue(10*randm<float>(50,50));
+
+ test_eigenvalue(10*randm<float,1,1>());
+ test_eigenvalue(10*randm<float,2,2>());
+ test_eigenvalue(10*randm<float,3,3>());
+ }
+
+ template <int dims>
+ void test_eigenvalue2()
+ {
+ for (int seed = 0; seed < 10; ++seed)
+ {
+ print_spinner();
+ matrix<double> H = gaussian_randm(dims,dims,seed);
+ H = H*trans(H);
+
+ eigenvalue_decomposition<matrix<double> > eig(H);
+ matrix<double> HH = eig.get_pseudo_v()*diagm(eig.get_real_eigenvalues())*trans(eig.get_pseudo_v());
+ DLIB_TEST_MSG(max(abs(H - HH))<1e-12, "dims: " << dims << " error: " << max(abs(H - HH)));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix_eig",
+ "Runs tests on the matrix eigen decomp component.")
+ {
+ //rnd.set_seed(cast_to_string(time(0)));
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "seed string: " << rnd.get_seed();
+
+ dlog << LINFO << "begin testing with double";
+ matrix_test_double();
+ dlog << LINFO << "begin testing with float";
+ matrix_test_float();
+
+ test_eigenvalue2<10>();
+ test_eigenvalue2<11>();
+ test_eigenvalue2<3>();
+ test_eigenvalue2<2>();
+ test_eigenvalue2<1>();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/matrix_lu.cpp b/ml/dlib/dlib/test/matrix_lu.cpp
new file mode 100644
index 000000000..f5425b355
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix_lu.cpp
@@ -0,0 +1,223 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include <dlib/string.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix_lu");
+
+ dlib::rand rnd;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename mat_type>
+ const matrix<typename mat_type::type> symm(const mat_type& m) { return m*trans(m); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename type>
+ const matrix<type> randmat(long r, long c)
+ {
+ matrix<type> m(r,c);
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+ template <typename type, long NR, long NC>
+ const matrix<type,NR,NC> randmat()
+ {
+ matrix<type,NR,NC> m;
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void test_lu ( const matrix_type& m)
+ {
+ typedef typename matrix_type::type type;
+ const type eps = 10*max(abs(m))*sqrt(std::numeric_limits<type>::epsilon());
+ dlog << LDEBUG << "test_lu(): " << m.nr() << " x " << m.nc() << " eps: " << eps;
+ print_spinner();
+
+
+ lu_decomposition<matrix_type> test(m);
+
+ DLIB_TEST(test.is_square() == (m.nr() == m.nc()));
+
+ DLIB_TEST(test.nr() == m.nr());
+ DLIB_TEST(test.nc() == m.nc());
+
+ dlog << LDEBUG << "m.nr(): " << m.nr() << " m.nc(): " << m.nc();
+
+ type temp;
+ DLIB_TEST_MSG( (temp= max(abs(test.get_l()*test.get_u() - rowm(m,test.get_pivot())))) < eps,temp);
+
+ if (test.is_square())
+ {
+ // none of the matrices we should be passing in to test_lu() should be singular.
+ DLIB_TEST_MSG (abs(test.det()) > eps/100, "det: " << test.det() );
+ dlog << LDEBUG << "big det: " << test.det();
+
+ DLIB_TEST(test.is_singular() == false);
+
+ matrix<type> m2;
+ matrix<type,0,1> col;
+
+ m2 = identity_matrix<type>(m.nr());
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),5);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ col = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2)));
+
+ // now make us a singular matrix
+ if (m.nr() > 1)
+ {
+ matrix<type> sm(m);
+ set_colm(sm,0) = colm(sm,1);
+
+ lu_decomposition<matrix_type> test2(sm);
+ DLIB_TEST_MSG( (temp= max(abs(test2.get_l()*test2.get_u() - rowm(sm,test2.get_pivot())))) < eps,temp);
+
+ // these checks are only accurate for small matrices
+ if (test2.nr() < 100)
+ {
+ DLIB_TEST_MSG(test2.is_singular() == true,"det: " << test2.det());
+ DLIB_TEST_MSG(abs(test2.det()) < eps,"det: " << test2.det());
+ }
+
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_double()
+ {
+
+
+ test_lu(10*randmat<double>(2,2));
+ test_lu(10*randmat<double>(1,1));
+ test_lu(10*symm(randmat<double>(2,2)));
+ test_lu(10*randmat<double>(4,4));
+ test_lu(10*randmat<double>(9,4));
+ test_lu(10*randmat<double>(3,8));
+ test_lu(10*randmat<double>(15,15));
+ test_lu(2*symm(randmat<double>(15,15)));
+ test_lu(10*randmat<double>(100,100));
+ test_lu(10*randmat<double>(137,200));
+ test_lu(10*randmat<double>(200,101));
+
+ test_lu(10*randmat<double,2,2>());
+ test_lu(10*randmat<double,1,1>());
+ test_lu(10*randmat<double,4,3>());
+ test_lu(10*randmat<double,4,4>());
+ test_lu(10*randmat<double,9,4>());
+ test_lu(10*randmat<double,3,8>());
+ test_lu(10*randmat<double,15,15>());
+ test_lu(10*randmat<double,100,100>());
+ test_lu(10*randmat<double,137,200>());
+ test_lu(10*randmat<double,200,101>());
+
+ typedef matrix<double,0,0,default_memory_manager, column_major_layout> mat;
+ test_lu(mat(3*randmat<double>(4,4)));
+ test_lu(mat(3*randmat<double>(9,4)));
+ test_lu(mat(3*randmat<double>(3,8)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_float()
+ {
+
+ // -------------------------------
+
+ test_lu(3*randmat<float>(1,1));
+ test_lu(3*randmat<float>(2,2));
+ test_lu(3*randmat<float>(4,4));
+ test_lu(3*randmat<float>(9,4));
+ test_lu(3*randmat<float>(3,8));
+ test_lu(3*randmat<float>(137,200));
+ test_lu(3*randmat<float>(200,101));
+
+ test_lu(3*randmat<float,1,1>());
+ test_lu(3*randmat<float,2,2>());
+ test_lu(3*randmat<float,4,3>());
+ test_lu(3*randmat<float,4,4>());
+ test_lu(3*randmat<float,9,4>());
+ test_lu(3*randmat<float,3,8>());
+ test_lu(3*randmat<float,137,200>());
+ test_lu(3*randmat<float,200,101>());
+
+ typedef matrix<float,0,0,default_memory_manager, column_major_layout> mat;
+ test_lu(mat(3*randmat<float>(4,4)));
+ test_lu(mat(3*randmat<float>(9,4)));
+ test_lu(mat(3*randmat<float>(3,8)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix_lu",
+ "Runs tests on the matrix LU component.")
+ {
+ //rnd.set_seed(cast_to_string(time(0)));
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "seed string: " << rnd.get_seed();
+
+ dlog << LINFO << "begin testing with double";
+ matrix_test_double();
+ dlog << LINFO << "begin testing with float";
+ matrix_test_float();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/matrix_qr.cpp b/ml/dlib/dlib/test/matrix_qr.cpp
new file mode 100644
index 000000000..e3c7c4e42
--- /dev/null
+++ b/ml/dlib/dlib/test/matrix_qr.cpp
@@ -0,0 +1,208 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include <dlib/string.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.matrix_qr");
+
+ dlib::rand rnd;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename mat_type>
+ const matrix<typename mat_type::type> symm(const mat_type& m) { return m*trans(m); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename type>
+ const matrix<type> randmat(long r, long c)
+ {
+ matrix<type> m(r,c);
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+ template <typename type, long NR, long NC>
+ const matrix<type,NR,NC> randmat()
+ {
+ matrix<type,NR,NC> m;
+ for (long row = 0; row < m.nr(); ++row)
+ {
+ for (long col = 0; col < m.nc(); ++col)
+ {
+ m(row,col) = static_cast<type>(rnd.get_random_double());
+ }
+ }
+
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_type>
+ void test_qr ( const matrix_type& m)
+ {
+ typedef typename matrix_type::type type;
+ const type eps = 10*max(abs(m))*sqrt(std::numeric_limits<type>::epsilon());
+ dlog << LDEBUG << "test_qr(): " << m.nr() << " x " << m.nc() << " eps: " << eps;
+ print_spinner();
+
+
+ qr_decomposition<matrix_type> test(m);
+
+
+ DLIB_TEST(test.nr() == m.nr());
+ DLIB_TEST(test.nc() == m.nc());
+
+
+ type temp;
+ DLIB_TEST_MSG( (temp= max(abs(test.get_q()*test.get_r() - m))) < eps,temp);
+
+ // none of the matrices we should be passing in to test_qr() should be non-full rank.
+ DLIB_TEST(test.is_full_rank() == true);
+
+ if (m.nr() == m.nc())
+ {
+ matrix<type> m2;
+ matrix<type,0,1> col;
+
+ m2 = identity_matrix<type>(m.nr());
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),5);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ m2 = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2)));
+ col = randmat<type>(m.nr(),1);
+ DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2)));
+ }
+ else
+ {
+ DLIB_TEST_MSG(dlib::equal(pinv(m), test.solve(identity_matrix<type>(m.nr())), eps),
+ max(abs(pinv(m) - test.solve(identity_matrix<type>(m.nr())))) );
+ }
+
+ // now make us a non-full rank matrix
+ if (m.nc() > 1)
+ {
+ matrix<type> sm(m);
+ set_colm(sm,0) = colm(sm,1);
+
+ qr_decomposition<matrix_type> test2(sm);
+ DLIB_TEST_MSG( (temp= max(abs(test.get_q()*test.get_r() - m))) < eps,temp);
+
+ if (test2.nc() < 100)
+ {
+ DLIB_TEST_MSG(test2.is_full_rank() == false,"eps: " << eps);
+ }
+
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_double()
+ {
+
+ test_qr(10*randmat<double>(1,1));
+ test_qr(10*randmat<double>(2,2));
+ test_qr(10*symm(randmat<double>(2,2)));
+ test_qr(10*randmat<double>(4,4));
+ test_qr(10*randmat<double>(9,4));
+ test_qr(10*randmat<double>(15,15));
+ test_qr(2*symm(randmat<double>(15,15)));
+ test_qr(10*randmat<double>(100,100));
+ test_qr(10*randmat<double>(237,200));
+ test_qr(10*randmat<double>(200,101));
+
+ test_qr(10*randmat<double,1,1>());
+ test_qr(10*randmat<double,2,2>());
+ test_qr(10*randmat<double,4,3>());
+ test_qr(10*randmat<double,4,4>());
+ test_qr(10*randmat<double,9,4>());
+ test_qr(10*randmat<double,15,15>());
+ test_qr(10*randmat<double,100,100>());
+
+ typedef matrix<double,0,0,default_memory_manager, column_major_layout> mat;
+ test_qr(mat(3*randmat<double>(9,4)));
+ test_qr(mat(3*randmat<double>(9,9)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void matrix_test_float()
+ {
+
+
+ test_qr(3*randmat<float>(1,1));
+ test_qr(3*randmat<float>(2,2));
+ test_qr(3*randmat<float>(4,4));
+ test_qr(3*randmat<float>(9,4));
+ test_qr(3*randmat<float>(237,200));
+
+ test_qr(3*randmat<float,1,1>());
+ test_qr(3*randmat<float,2,2>());
+ test_qr(3*randmat<float,4,3>());
+ test_qr(3*randmat<float,4,4>());
+ test_qr(3*randmat<float,5,4>());
+
+ typedef matrix<float,0,0,default_memory_manager, column_major_layout> mat;
+ test_qr(mat(3*randmat<float>(5,4)));
+ test_qr(mat(3*randmat<float>(9,9)));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class matrix_tester : public tester
+ {
+ public:
+ matrix_tester (
+ ) :
+ tester ("test_matrix_qr",
+ "Runs tests on the matrix QR component.")
+ {
+ //rnd.set_seed(cast_to_string(time(0)));
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "seed string: " << rnd.get_seed();
+
+ dlog << LINFO << "begin testing with double";
+ matrix_test_double();
+ dlog << LINFO << "begin testing with float";
+ matrix_test_float();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/max_cost_assignment.cpp b/ml/dlib/dlib/test/max_cost_assignment.cpp
new file mode 100644
index 000000000..852418764
--- /dev/null
+++ b/ml/dlib/dlib/test/max_cost_assignment.cpp
@@ -0,0 +1,157 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.max_cost_assignment");
+
+// ----------------------------------------------------------------------------------------
+
+ std::vector<std::vector<long> > permutations (
+ matrix<long,1,0> vals
+ )
+ {
+ if (vals.size() == 0)
+ {
+ return std::vector<std::vector<long> >();
+ }
+ else if (vals.size() == 1)
+ {
+ return std::vector<std::vector<long> >(1,std::vector<long>(1,vals(0)));
+ }
+
+
+ std::vector<std::vector<long> > temp;
+
+
+ for (long i = 0; i < vals.size(); ++i)
+ {
+ const std::vector<std::vector<long> >& res = permutations(remove_col(vals,i));
+
+ for (unsigned long j = 0; j < res.size(); ++j)
+ {
+ temp.resize(temp.size()+1);
+ std::vector<long>& part = temp.back();
+ part.push_back(vals(i));
+ part.insert(part.end(), res[j].begin(), res[j].end());
+ }
+ }
+
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ std::vector<long> brute_force_max_cost_assignment (
+ matrix<T> cost
+ )
+ {
+ if (cost.size() == 0)
+ return std::vector<long>();
+
+ const std::vector<std::vector<long> >& perms = permutations(range(0,cost.nc()-1));
+
+ T best_cost = std::numeric_limits<T>::min();
+ unsigned long best_idx = 0;
+ for (unsigned long i = 0; i < perms.size(); ++i)
+ {
+ const T temp = assignment_cost(cost, perms[i]);
+ if (temp > best_cost)
+ {
+ best_idx = i;
+ best_cost = temp;
+ }
+ }
+
+ return perms[best_idx];
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class test_max_cost_assignment : public tester
+ {
+ public:
+ test_max_cost_assignment (
+ ) :
+ tester ("test_max_cost_assignment",
+ "Runs tests on the max_cost_assignment function.")
+ {}
+
+ dlib::rand rnd;
+
+ template <typename T>
+ void test_hungarian()
+ {
+ long size = rnd.get_random_32bit_number()%7;
+ long range = rnd.get_random_32bit_number()%100;
+ matrix<T> cost = matrix_cast<T>(randm(size,size,rnd)*range) - range/2;
+
+ // use a uniform cost matrix sometimes
+ if ((rnd.get_random_32bit_number()%100) == 0)
+ cost = rnd.get_random_32bit_number()%100;
+
+ // negate the cost matrix every now and then
+ if ((rnd.get_random_32bit_number()%100) == 0)
+ cost = -cost;
+
+
+ std::vector<long> assign = brute_force_max_cost_assignment(cost);
+ T true_eval = assignment_cost(cost, assign);
+ assign = max_cost_assignment(cost);
+ DLIB_TEST(assignment_cost(cost,assign) == true_eval);
+ assign = max_cost_assignment(matrix_cast<signed char>(cost));
+ DLIB_TEST(assignment_cost(cost,assign) == true_eval);
+
+
+ cost = matrix_cast<T>(randm(size,size,rnd)*range);
+ assign = brute_force_max_cost_assignment(cost);
+ true_eval = assignment_cost(cost, assign);
+ assign = max_cost_assignment(cost);
+ DLIB_TEST(assignment_cost(cost,assign) == true_eval);
+ assign = max_cost_assignment(matrix_cast<unsigned char>(cost));
+ DLIB_TEST(assignment_cost(cost,assign) == true_eval);
+ assign = max_cost_assignment(matrix_cast<typename unsigned_type<T>::type>(cost));
+ DLIB_TEST(assignment_cost(cost,assign) == true_eval);
+ }
+
+ void perform_test (
+ )
+ {
+ for (long i = 0; i < 1000; ++i)
+ {
+ if ((i%100) == 0)
+ print_spinner();
+
+ test_hungarian<short>();
+ test_hungarian<int>();
+ test_hungarian<long>();
+ test_hungarian<int64>();
+ }
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/max_sum_submatrix.cpp b/ml/dlib/dlib/test/max_sum_submatrix.cpp
new file mode 100644
index 000000000..34b4756ff
--- /dev/null
+++ b/ml/dlib/dlib/test/max_sum_submatrix.cpp
@@ -0,0 +1,177 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/optimization.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.max_sum_submatrix");
+
+// ----------------------------------------------------------------------------------------
+
+ bool order_rects (
+ const rectangle& a,
+ const rectangle& b
+ )
+ {
+ if (a.left() < b.left()) return true;
+ else if (a.left() > b.left()) return false;
+
+ if (a.right() < b.right()) return true;
+ else if (a.right() > b.right()) return false;
+
+ if (a.top() < b.top()) return true;
+ else if (a.top() > b.top()) return false;
+
+ if (a.bottom() < b.bottom()) return true;
+ else if (a.bottom() > b.bottom()) return false;
+
+ return false;
+ }
+
+ void run_test(
+ const int num
+ )
+ {
+ static dlib::rand rnd;
+
+ matrix<int> mat, mask;
+
+ mat.set_size(rnd.get_random_32bit_number()%1000 + 1,
+ rnd.get_random_32bit_number()%1000 + 1);
+ mask.set_size(mat.nr(), mat.nc());
+ mask = 0;
+
+ mat = -10000;
+
+ std::vector<rectangle> true_rects;
+
+ for (int i = 0; i < num; ++i)
+ {
+ const int width = rnd.get_random_32bit_number()%100 + 1;
+ const int height = rnd.get_random_32bit_number()%100 + 1;
+
+ rectangle rect = centered_rect(rnd.get_random_16bit_number()%mat.nc(),
+ rnd.get_random_16bit_number()%mat.nr(),
+ width,height);
+ rect = get_rect(mat).intersect(rect);
+
+ // make sure this new rectangle doesn't overlap or abut any others
+ if (sum(subm(mask,grow_rect(rect,1).intersect(get_rect(mask)))) == 0)
+ {
+ set_subm(mat, rect) = rnd.get_random_8bit_number()%100 + 1;
+ set_subm(mask, rect) = 1;
+ true_rects.push_back(rect);
+ }
+ }
+
+
+ std::vector<rectangle> res;
+ res = max_sum_submatrix(mat, true_rects.size()+10, 0);
+
+ DLIB_TEST(res.size() == true_rects.size());
+
+ // make sure big rectangles come first
+ for (unsigned long i = 0; i+1 < res.size(); ++i)
+ {
+ DLIB_TEST(sum(subm(mat,res[i])) >= sum(subm(mat,res[i+1])));
+ }
+
+ // make sure rectangles match
+ sort(true_rects.begin(), true_rects.end(), order_rects);
+ sort(res.begin(), res.end(), order_rects);
+ for (unsigned long i = 0; i < res.size(); ++i)
+ {
+ DLIB_TEST_MSG(res[i] == true_rects[i],
+ "i: " << i << " res[i]: " << res[i] << " true_rects[i]: " << true_rects[i]);
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void run_test2()
+ {
+ matrix<T> mat(100,100);
+ mat = 1;
+ std::vector<rectangle> res = max_sum_submatrix(mat, 0, 0);
+
+ DLIB_TEST(res.size() == 0);
+ res = max_sum_submatrix(mat, 1, 0);
+ DLIB_TEST(res.size() == 1);
+ DLIB_TEST(res[0] == get_rect(mat));
+ res = max_sum_submatrix(mat, 3, 0);
+ DLIB_TEST(res.size() == 1);
+ DLIB_TEST(res[0] == get_rect(mat));
+ res = max_sum_submatrix(mat, 3, 10);
+ DLIB_TEST(res.size() == 1);
+ DLIB_TEST(res[0] == get_rect(mat));
+
+ res = max_sum_submatrix(mat, 3, mat.size());
+ DLIB_TEST(res.size() == 0);
+
+ mat = -1;
+ res = max_sum_submatrix(mat, 1, 0);
+ DLIB_TEST(res.size() == 0);
+
+ const rectangle rect1 = rectangle(10,10,40,40);
+ const rectangle rect2 = rectangle(35,35,80,80);
+
+ set_subm(mat, rect1) = 2;
+ set_subm(mat, rect2) = 1;
+ res = max_sum_submatrix(mat, 3, 0);
+ DLIB_TEST(res.size() == 2);
+ DLIB_TEST(res[0] == rect2);
+ DLIB_TEST(res[1] == rect1);
+
+ res = max_sum_submatrix(mat, 3, 2*rect1.area() - 2*(rect1.intersect(rect2)).area());
+ DLIB_TEST(res.size() == 1);
+ DLIB_TEST(res[0] == rect2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ class test_max_sum_submatrix : public tester
+ {
+ public:
+ test_max_sum_submatrix (
+ ) :
+ tester ("test_max_sum_submatrix",
+ "Runs tests on the max_sum_submatrix() function.")
+ {}
+
+ void perform_test (
+ )
+ {
+ for (int j = 0; j < 5; ++j)
+ {
+ print_spinner();
+ for (int i = 0; i < 40; ++i)
+ run_test(i);
+ }
+
+ run_test2<int>();
+ run_test2<short>();
+ run_test2<signed char>();
+ run_test2<float>();
+ run_test2<double>();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/md5.cpp b/ml/dlib/dlib/test/md5.cpp
new file mode 100644
index 000000000..15fafac3c
--- /dev/null
+++ b/ml/dlib/dlib/test/md5.cpp
@@ -0,0 +1,71 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/md5.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.md5");
+
+ void md5_test (
+ )
+ /*!
+ ensures
+ - runs tests on the md5 stuff compliance with the specs
+ !*/
+ {
+
+ DLIB_TEST(md5 ("") == "d41d8cd98f00b204e9800998ecf8427e");
+ DLIB_TEST(md5 ("a") == "0cc175b9c0f1b6a831c399e269772661");
+ DLIB_TEST(md5 ("abc") == "900150983cd24fb0d6963f7d28e17f72");
+ DLIB_TEST(md5 ("message digest") == "f96b697d7cb7938d525a2f31aaf161d0");
+ DLIB_TEST(md5 ("abcdefghijklmnopqrstuvwxyz") == "c3fcd3d76192e4007dfb496cca67e13b");
+ DLIB_TEST(md5 ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") == "d174ab98d277d9f5a5611c2c9f419d9f");
+ DLIB_TEST(md5 ("12345678901234567890123456789012345678901234567890123456789012345678901234567890") == "57edf4a22be3c955ac49da2e2107b67a");
+
+ // make sure the two versions of md5() always agree
+ for (int num = 0; num < 2000; ++num)
+ {
+ std::string temp;
+ for (int i = 0; i < num; ++i)
+ temp += 'a';
+
+ istringstream str(temp);
+ DLIB_TEST(md5(temp) == md5(str));
+ }
+
+ }
+
+
+ class md5_tester : public tester
+ {
+ public:
+ md5_tester (
+ ) :
+ tester ("test_md5",
+ "Runs tests on the md5 component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ md5_test();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/member_function_pointer.cpp b/ml/dlib/dlib/test/member_function_pointer.cpp
new file mode 100644
index 000000000..72aa3aa35
--- /dev/null
+++ b/ml/dlib/dlib/test/member_function_pointer.cpp
@@ -0,0 +1,553 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/member_function_pointer.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.member_function_pointer");
+
+ class mfp_test_helper_other
+ {
+ public:
+ mfp_test_helper_other (
+ ): i(-1) {}
+
+
+ mutable int i;
+
+
+ void go0 (
+ ) { i = 0; }
+ void go1 (
+ int v1
+ ) { i = 1*v1; }
+ void go2 (
+ int v1,int v2
+ ) { i = 2*v1*v2; }
+ void go3 (
+ int v1,int v2,int v3
+ ) { i = 3*v1*v2*v3; }
+ void go4 (
+ int v1,int v2,int v3,int v4
+ ) { i = 4*v1*v2*v3*v4; }
+
+ };
+
+
+ class mfp_test_helper
+ {
+ public:
+ mfp_test_helper (
+ ): i(-1) {}
+
+
+ mutable int i;
+
+
+ void go0 (
+ ) { i = 0; }
+ void go1 (
+ int v1
+ ) { i = 1*v1; }
+ void go2 (
+ int v1,int v2
+ ) { i = 2*v1*v2; }
+ void go3 (
+ int v1,int v2,int v3
+ ) { i = 3*v1*v2*v3; }
+ void go4 (
+ int v1,int v2,int v3,int v4
+ ) { i = 4*v1*v2*v3*v4; }
+
+ };
+
+ class mfp_test_helper_const
+ {
+ public:
+ mfp_test_helper_const (
+ ): i(-1) {}
+
+
+ mutable int i;
+
+ void go0 (
+ ) const { i = 0; }
+ void go1 (
+ int v1
+ ) const { i = 1*v1; }
+ void go2 (
+ int v1,int v2
+ ) const { i = 2*v1*v2; }
+ void go3 (
+ int v1,int v2,int v3
+ ) const { i = 3*v1*v2*v3; }
+ void go4 (
+ int v1,int v2,int v3,int v4
+ ) const { i = 4*v1*v2*v3*v4; }
+ };
+
+ template <
+ template <typename P1 = void, typename P2 = void, typename P3 = void, typename P4 = void> class mfp,
+ typename test_helper
+ >
+ void member_function_pointer_kernel_test (
+ )
+ /*!
+ requires
+ - mfp is an implementation of member_function_pointer/member_function_pointer_kernel_abstract.h
+ ensures
+ - runs tests on mfp for compliance with the specs
+ !*/
+ {
+
+
+ test_helper helper;
+
+ mfp<> a0, b0;
+ mfp<int> a1, b1;
+ mfp<int,int> a2, b2;
+ mfp<int,int,int> a3, b3;
+ mfp<int,int,int,int> a4, b4;
+
+ mfp<> a0c, b0c;
+ mfp<int> a1c, b1c;
+ mfp<int,int> a2c, b2c;
+ mfp<int,int,int> a3c, b3c;
+ mfp<int,int,int,int> a4c, b4c;
+
+ DLIB_TEST(a0c == b0c);
+ DLIB_TEST(a1c == b1c);
+ DLIB_TEST(a2c == b2c);
+ DLIB_TEST(a3c == b3c);
+ DLIB_TEST(a4c == b4c);
+ DLIB_TEST((a0c != b0c) == false);
+ DLIB_TEST((a1c != b1c) == false);
+ DLIB_TEST((a2c != b2c) == false);
+ DLIB_TEST((a3c != b3c) == false);
+ DLIB_TEST((a4c != b4c) == false);
+
+ DLIB_TEST(a0.is_set() == false);
+ DLIB_TEST(b0.is_set() == false);
+ DLIB_TEST(a0c.is_set() == false);
+ DLIB_TEST(b0c.is_set() == false);
+
+ DLIB_TEST(!a0 );
+ DLIB_TEST(!b0 );
+ DLIB_TEST(!a0c);
+ DLIB_TEST(!b0c);
+
+ DLIB_TEST(a1.is_set() == false);
+ DLIB_TEST(b1.is_set() == false);
+ DLIB_TEST(a1c.is_set() == false);
+ DLIB_TEST(b1c.is_set() == false);
+
+ DLIB_TEST(!a1 );
+ DLIB_TEST(!b1 );
+ DLIB_TEST(!a1c);
+ DLIB_TEST(!b1c);
+
+
+ DLIB_TEST(a2.is_set() == false);
+ DLIB_TEST(b2.is_set() == false);
+ DLIB_TEST(a2c.is_set() == false);
+ DLIB_TEST(b2c.is_set() == false);
+
+ DLIB_TEST(!a2);
+ DLIB_TEST(!b2);
+ DLIB_TEST(!a2c);
+ DLIB_TEST(!b2c);
+
+ DLIB_TEST(a3.is_set() == false);
+ DLIB_TEST(b3.is_set() == false);
+ DLIB_TEST(a3c.is_set() == false);
+ DLIB_TEST(b3c.is_set() == false);
+
+ DLIB_TEST(!a3);
+ DLIB_TEST(!b3);
+ DLIB_TEST(!a3c);
+ DLIB_TEST(!b3c);
+
+ DLIB_TEST(a4.is_set() == false);
+ DLIB_TEST(b4.is_set() == false);
+ DLIB_TEST(a4c.is_set() == false);
+ DLIB_TEST(b4c.is_set() == false);
+
+ DLIB_TEST(!a4);
+ DLIB_TEST(!b4);
+ DLIB_TEST(!a4c);
+ DLIB_TEST(!b4c);
+
+ a0.set(helper,&test_helper::go0);
+ a0c.set(helper,&test_helper::go0);
+ DLIB_TEST(a0.is_set() == true);
+ DLIB_TEST(a0c.is_set() == true);
+ DLIB_TEST(b0.is_set() == false);
+ DLIB_TEST(b0c.is_set() == false);
+
+ DLIB_TEST(a0);
+ DLIB_TEST(a0c);
+ DLIB_TEST(!b0);
+ DLIB_TEST(!b0c);
+
+ a0 = a0;
+ DLIB_TEST(a0 == a0);
+ DLIB_TEST(!(a0 != a0));
+ DLIB_TEST(a0.is_set() == true);
+ DLIB_TEST(a0c.is_set() == true);
+ DLIB_TEST(b0.is_set() == false);
+ DLIB_TEST(b0c.is_set() == false);
+
+ DLIB_TEST(a0);
+ DLIB_TEST(a0c);
+ DLIB_TEST(!b0);
+ DLIB_TEST(!b0c);
+
+ swap(a0,b0);
+ swap(a0c,b0c);
+ DLIB_TEST(a0.is_set() == false);
+ DLIB_TEST(a0c.is_set() == false);
+ DLIB_TEST(b0.is_set() == true);
+ DLIB_TEST(b0c.is_set() == true);
+
+ DLIB_TEST(!a0);
+ DLIB_TEST(!a0c);
+ DLIB_TEST(b0);
+ DLIB_TEST(b0c);
+
+ a0 = b0;
+ DLIB_TEST(a0 == a0);
+ DLIB_TEST(a0 == b0);
+ DLIB_TEST(!(a0 != b0));
+ DLIB_TEST(a0.is_set() == true);
+ DLIB_TEST(a0c.is_set() == false);
+ DLIB_TEST(b0.is_set() == true);
+ DLIB_TEST(b0c.is_set() == true);
+
+ DLIB_TEST(a0 );
+ DLIB_TEST(!a0c);
+ DLIB_TEST(b0);
+ DLIB_TEST(b0c);
+
+
+ a0.clear();
+ a0c.clear();
+ b0.clear();
+ b0c.clear();
+ DLIB_TEST(a0.is_set() == false);
+ DLIB_TEST(a0c.is_set() == false);
+ DLIB_TEST(b0.is_set() == false);
+ DLIB_TEST(b0c.is_set() == false);
+
+
+ a1.set(helper,&test_helper::go1);
+ a1c.set(helper,&test_helper::go1);
+ DLIB_TEST(a1.is_set() == true);
+ DLIB_TEST(a1c.is_set() == true);
+ DLIB_TEST(b1.is_set() == false);
+ DLIB_TEST(b1c.is_set() == false);
+ swap(a1,b1);
+ swap(a1c,b1c);
+ DLIB_TEST(a1.is_set() == false);
+ DLIB_TEST(a1c.is_set() == false);
+ DLIB_TEST(b1.is_set() == true);
+ DLIB_TEST(b1c.is_set() == true);
+
+ DLIB_TEST(!a1);
+ DLIB_TEST(!a1c);
+ DLIB_TEST(b1);
+ DLIB_TEST(b1c);
+
+
+ a1 = b1;
+ DLIB_TEST(a1 == a1);
+ DLIB_TEST(a1 == b1);
+ DLIB_TEST(!(a1 != b1));
+ DLIB_TEST(a1.is_set() == true);
+ DLIB_TEST(a1c.is_set() == false);
+ DLIB_TEST(b1.is_set() == true);
+ DLIB_TEST(b1c.is_set() == true);
+
+
+ a1.clear();
+ a1c.clear();
+ b1.clear();
+ b1c.clear();
+ DLIB_TEST(a1.is_set() == false);
+ DLIB_TEST(a1c.is_set() == false);
+ DLIB_TEST(b1.is_set() == false);
+ DLIB_TEST(b1c.is_set() == false);
+
+
+ a2.set(helper,&test_helper::go2);
+ a2c.set(helper,&test_helper::go2);
+ DLIB_TEST(a2.is_set() == true);
+ DLIB_TEST(a2c.is_set() == true);
+ DLIB_TEST(b2.is_set() == false);
+ DLIB_TEST(b2c.is_set() == false);
+ swap(a2,b2);
+ swap(a2c,b2c);
+ DLIB_TEST(a2.is_set() == false);
+ DLIB_TEST(a2c.is_set() == false);
+ DLIB_TEST(b2.is_set() == true);
+ DLIB_TEST(b2c.is_set() == true);
+
+ DLIB_TEST(!a2);
+ DLIB_TEST(!a2c);
+ DLIB_TEST(b2);
+ DLIB_TEST(b2c);
+ if (b2)
+ {
+ }
+ else
+ {
+ DLIB_TEST(false);
+ }
+
+ if (a2c)
+ {
+ DLIB_TEST(false);
+ }
+ else
+ {
+ DLIB_TEST(true);
+ }
+
+ a2 = b2;
+ DLIB_TEST(a2 == a2);
+ DLIB_TEST(a2 == b2);
+ DLIB_TEST(!(a2 != b2));
+ DLIB_TEST(a2.is_set() == true);
+ DLIB_TEST(a2c.is_set() == false);
+ DLIB_TEST(b2.is_set() == true);
+ DLIB_TEST(b2c.is_set() == true);
+
+ a2.clear();
+ a2c.clear();
+ b2.clear();
+ b2c.clear();
+ DLIB_TEST(a2.is_set() == false);
+ DLIB_TEST(a2c.is_set() == false);
+ DLIB_TEST(b2.is_set() == false);
+ DLIB_TEST(b2c.is_set() == false);
+
+
+ a3.set(helper,&test_helper::go3);
+ a3c.set(helper,&test_helper::go3);
+ DLIB_TEST(a3.is_set() == true);
+ DLIB_TEST(a3c.is_set() == true);
+ DLIB_TEST(b3.is_set() == false);
+ DLIB_TEST(b3c.is_set() == false);
+ swap(a3,b3);
+ swap(a3c,b3c);
+ DLIB_TEST(a3.is_set() == false);
+ DLIB_TEST(a3c.is_set() == false);
+ DLIB_TEST(b3.is_set() == true);
+ DLIB_TEST(b3c.is_set() == true);
+
+ a3 = b3;
+ DLIB_TEST(a3 == a3);
+ DLIB_TEST(a3 == b3);
+ DLIB_TEST(!(a3 != b3));
+ DLIB_TEST(a3.is_set() == true);
+ DLIB_TEST(a3c.is_set() == false);
+ DLIB_TEST(b3.is_set() == true);
+ DLIB_TEST(b3c.is_set() == true);
+
+
+ a3.clear();
+ a3c.clear();
+ b3.clear();
+ b3c.clear();
+ DLIB_TEST(a3.is_set() == false);
+ DLIB_TEST(a3c.is_set() == false);
+ DLIB_TEST(b3.is_set() == false);
+ DLIB_TEST(b3c.is_set() == false);
+
+
+ a4.set(helper,&test_helper::go4);
+ a4c.set(helper,&test_helper::go4);
+ DLIB_TEST(a4.is_set() == true);
+ DLIB_TEST(a4c.is_set() == true);
+ DLIB_TEST(b4.is_set() == false);
+ DLIB_TEST(b4c.is_set() == false);
+ swap(a4,b4);
+ swap(a4c,b4c);
+ DLIB_TEST(a4.is_set() == false);
+ DLIB_TEST(a4c.is_set() == false);
+ DLIB_TEST(b4.is_set() == true);
+ DLIB_TEST(b4c.is_set() == true);
+
+ a4 = b4;
+ a4 = b4;
+ a4 = b4;
+ a4 = b4;
+ DLIB_TEST(a4 == a4);
+ DLIB_TEST(a4 == b4);
+ DLIB_TEST(!(a4 != b4));
+ DLIB_TEST(a4.is_set() == true);
+ DLIB_TEST(a4c.is_set() == false);
+ DLIB_TEST(b4.is_set() == true);
+ DLIB_TEST(b4c.is_set() == true);
+
+
+ a4.clear();
+ a4c.clear();
+ b4.clear();
+ b4c.clear();
+ DLIB_TEST(a4.is_set() == false);
+ DLIB_TEST(a4c.is_set() == false);
+ DLIB_TEST(b4.is_set() == false);
+ DLIB_TEST(b4c.is_set() == false);
+
+
+ a0.set(helper,&test_helper::go0);
+ a0c.set(helper,&test_helper::go0);
+ b0 = a0;
+ b0c = a0c;
+ helper.i = -1;
+ a0();
+ DLIB_TEST(helper.i == 0);
+ helper.i = -1;
+ b0();
+ DLIB_TEST(helper.i == 0);
+ helper.i = -1;
+ a0c();
+ DLIB_TEST(helper.i == 0);
+ helper.i = -1;
+ b0c();
+ DLIB_TEST(helper.i == 0);
+
+
+ a1.set(helper,&test_helper::go1);
+ a1c.set(helper,&test_helper::go1);
+ b1 = a1;
+ b1c = a1c;
+ helper.i = -1;
+ a1(1);
+ DLIB_TEST(helper.i == 1);
+ helper.i = -1;
+ b1(10);
+ DLIB_TEST(helper.i == 1*10);
+ helper.i = -1;
+ a1c(20);
+ DLIB_TEST(helper.i == 1*20);
+ helper.i = -1;
+ b1c(30);
+ DLIB_TEST(helper.i == 1*30);
+
+
+ a2.set(helper,&test_helper::go2);
+ a2c.set(helper,&test_helper::go2);
+ b2 = a2;
+ b2c = a2c;
+ helper.i = -1;
+ a2(1,2);
+ DLIB_TEST(helper.i == 2*1*2);
+ helper.i = -1;
+ b2(3,4);
+ DLIB_TEST(helper.i == 2*3*4);
+ helper.i = -1;
+ a2c(5,6);
+ DLIB_TEST(helper.i == 2*5*6);
+ helper.i = -1;
+ b2c(7,8);
+ DLIB_TEST(helper.i == 2*7*8);
+
+
+ a3.set(helper,&test_helper::go3);
+ a3c.set(helper,&test_helper::go3);
+ b3 = a3;
+ b3c = a3c;
+ helper.i = -1;
+ a3(1,2,3);
+ DLIB_TEST(helper.i == 3*1*2*3);
+ helper.i = -1;
+ b3(4,5,6);
+ DLIB_TEST(helper.i == 3*4*5*6);
+ helper.i = -1;
+ a3c(7,8,9);
+ DLIB_TEST(helper.i == 3*7*8*9);
+ helper.i = -1;
+ b3c(1,2,3);
+ DLIB_TEST(helper.i == 3*1*2*3);
+
+
+ a4.set(helper,&test_helper::go4);
+ a4c.set(helper,&test_helper::go4);
+ DLIB_TEST(a4 == a4c);
+ b4 = a4;
+ b4c = a4c;
+ helper.i = -1;
+ a4(1,2,3,4);
+ DLIB_TEST(helper.i == 4*1*2*3*4);
+ helper.i = -1;
+ b4(5,6,7,8);
+ DLIB_TEST(helper.i == 4*5*6*7*8);
+ helper.i = -1;
+ a4c(9,1,2,3);
+ DLIB_TEST(helper.i == 4*9*1*2*3);
+ helper.i = -1;
+ b4c(4,5,6,7);
+ DLIB_TEST(helper.i == 4*4*5*6*7);
+
+ DLIB_TEST(a4 == b4);
+ DLIB_TEST(a4);
+ DLIB_TEST(a4 == b4);
+ a4.clear();
+ DLIB_TEST(a4 != b4);
+ DLIB_TEST(!a4);
+ DLIB_TEST(a4 == 0);
+ DLIB_TEST(a4 == a4);
+ a4 = a4;
+ DLIB_TEST(a4 != b4);
+ DLIB_TEST(!a4);
+ DLIB_TEST(a4 == a4);
+ mfp_test_helper_other other;
+ a4.set(other,&mfp_test_helper_other::go4);
+ DLIB_TEST(a4 != b4);
+ DLIB_TEST(a4);
+ DLIB_TEST(a4 == a4);
+ a4.set(helper,&test_helper::go4);
+ DLIB_TEST(a4 == b4);
+ DLIB_TEST(a4);
+ DLIB_TEST(a4 == a4);
+
+
+
+ }
+
+
+
+ class member_function_pointer_tester : public tester
+ {
+ public:
+ member_function_pointer_tester (
+ ) :
+ tester ("test_member_function_pointer",
+ "Runs tests on the member_function_pointer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ member_function_pointer_kernel_test<member_function_pointer,mfp_test_helper>();
+ member_function_pointer_kernel_test<member_function_pointer,mfp_test_helper_const>();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/metaprogramming.cpp b/ml/dlib/dlib/test/metaprogramming.cpp
new file mode 100644
index 000000000..344d5ca3a
--- /dev/null
+++ b/ml/dlib/dlib/test/metaprogramming.cpp
@@ -0,0 +1,94 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/algs.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.metaprogramming");
+
+
+ void metaprogramming_test (
+ )
+ /*!
+ ensures
+ - runs tests on template metaprogramming objects and functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ DLIB_TEST(is_signed_type<signed char>::value == true);
+ DLIB_TEST(is_signed_type<signed short>::value == true);
+ DLIB_TEST(is_signed_type<signed int>::value == true);
+ DLIB_TEST(is_signed_type<signed long>::value == true);
+ DLIB_TEST(is_unsigned_type<signed char>::value == false);
+ DLIB_TEST(is_unsigned_type<signed short>::value == false);
+ DLIB_TEST(is_unsigned_type<signed int>::value == false);
+ DLIB_TEST(is_unsigned_type<signed long>::value == false);
+
+ DLIB_TEST(is_unsigned_type<unsigned char>::value == true);
+ DLIB_TEST(is_unsigned_type<unsigned short>::value == true);
+ DLIB_TEST(is_unsigned_type<unsigned int>::value == true);
+ DLIB_TEST(is_unsigned_type<unsigned long>::value == true);
+ DLIB_TEST(is_signed_type<unsigned char>::value == false);
+ DLIB_TEST(is_signed_type<unsigned short>::value == false);
+ DLIB_TEST(is_signed_type<unsigned int>::value == false);
+ DLIB_TEST(is_signed_type<unsigned long>::value == false);
+
+
+ COMPILE_TIME_ASSERT(is_signed_type<signed char>::value == true);
+ COMPILE_TIME_ASSERT(is_signed_type<signed short>::value == true);
+ COMPILE_TIME_ASSERT(is_signed_type<signed int>::value == true);
+ COMPILE_TIME_ASSERT(is_signed_type<signed long>::value == true);
+ COMPILE_TIME_ASSERT(is_unsigned_type<signed char>::value == false);
+ COMPILE_TIME_ASSERT(is_unsigned_type<signed short>::value == false);
+ COMPILE_TIME_ASSERT(is_unsigned_type<signed int>::value == false);
+ COMPILE_TIME_ASSERT(is_unsigned_type<signed long>::value == false);
+
+ COMPILE_TIME_ASSERT(is_unsigned_type<unsigned char>::value == true);
+ COMPILE_TIME_ASSERT(is_unsigned_type<unsigned short>::value == true);
+ COMPILE_TIME_ASSERT(is_unsigned_type<unsigned int>::value == true);
+ COMPILE_TIME_ASSERT(is_unsigned_type<unsigned long>::value == true);
+ COMPILE_TIME_ASSERT(is_signed_type<unsigned char>::value == false);
+ COMPILE_TIME_ASSERT(is_signed_type<unsigned short>::value == false);
+ COMPILE_TIME_ASSERT(is_signed_type<unsigned int>::value == false);
+ COMPILE_TIME_ASSERT(is_signed_type<unsigned long>::value == false);
+
+
+ }
+
+
+
+
+ class metaprogramming_tester : public tester
+ {
+ public:
+ metaprogramming_tester (
+ ) :
+ tester ("test_metaprogramming",
+ "Runs tests on the metaprogramming objects and functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ metaprogramming_test();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/mpc.cpp b/ml/dlib/dlib/test/mpc.cpp
new file mode 100644
index 000000000..c0a98dd17
--- /dev/null
+++ b/ml/dlib/dlib/test/mpc.cpp
@@ -0,0 +1,346 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <string>
+#include <sstream>
+
+#include <dlib/control.h>
+#include <dlib/optimization.h>
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.mpc");
+
+ template <
+ typename EXP1,
+ typename EXP2,
+ typename T, long NR, long NC, typename MM, typename L
+ >
+ unsigned long solve_qp_box_using_smo (
+ const matrix_exp<EXP1>& _Q,
+ const matrix_exp<EXP2>& _b,
+ matrix<T,NR,NC,MM,L>& alpha,
+ matrix<T,NR,NC,MM,L>& lower,
+ matrix<T,NR,NC,MM,L>& upper,
+ T eps,
+ unsigned long max_iter
+ )
+ /*!
+ ensures
+ - solves: 0.5*trans(x)*Q*x + trans(b)*x where x is box constrained.
+ !*/
+ {
+ const_temp_matrix<EXP1> Q(_Q);
+ const_temp_matrix<EXP2> b(_b);
+ //cout << "IN QP SOLVER" << endl;
+ //cout << "max eig: " << max(real_eigenvalues(Q)) << endl;
+ //cout << "min eig: " << min(real_eigenvalues(Q)) << endl;
+ //cout << "Q: \n" << Q << endl;
+ //cout << "b: \n" << b << endl;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(Q.nr() == Q.nc() &&
+ alpha.size() == lower.size() &&
+ alpha.size() == upper.size() &&
+ is_col_vector(b) &&
+ is_col_vector(alpha) &&
+ is_col_vector(lower) &&
+ is_col_vector(upper) &&
+ b.size() == alpha.size() &&
+ b.size() == Q.nr() &&
+ alpha.size() > 0 &&
+ 0 <= min(alpha-lower) &&
+ 0 <= max(upper-alpha) &&
+ eps > 0 &&
+ max_iter > 0,
+ "\t unsigned long solve_qp_box_using_smo()"
+ << "\n\t Invalid arguments were given to this function"
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t Q.nc(): " << Q.nc()
+ << "\n\t is_col_vector(b): " << is_col_vector(b)
+ << "\n\t is_col_vector(alpha): " << is_col_vector(alpha)
+ << "\n\t is_col_vector(lower): " << is_col_vector(lower)
+ << "\n\t is_col_vector(upper): " << is_col_vector(upper)
+ << "\n\t b.size(): " << b.size()
+ << "\n\t alpha.size(): " << alpha.size()
+ << "\n\t lower.size(): " << lower.size()
+ << "\n\t upper.size(): " << upper.size()
+ << "\n\t Q.nr(): " << Q.nr()
+ << "\n\t min(alpha-lower): " << min(alpha-lower)
+ << "\n\t max(upper-alpha): " << max(upper-alpha)
+ << "\n\t eps: " << eps
+ << "\n\t max_iter: " << max_iter
+ );
+
+
+ // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha.
+ matrix<T,NR,NC,MM,L> df = Q*alpha + b;
+ matrix<T,NR,NC,MM,L> QQ = reciprocal_max(diag(Q));
+
+
+ unsigned long iter = 0;
+ for (; iter < max_iter; ++iter)
+ {
+ T max_df = 0;
+ long best_r =0;
+ for (long r = 0; r < Q.nr(); ++r)
+ {
+ if (alpha(r) <= lower(r) && df(r) > 0)
+ ;//alpha(r) = lower(r);
+ else if (alpha(r) >= upper(r) && df(r) < 0)
+ ;//alpha(r) = upper(r);
+ else if (std::abs(df(r)) > max_df)
+ {
+ best_r = r;
+ max_df = std::abs(df(r));
+ }
+ }
+
+ //for (long r = 0; r < Q.nr(); ++r)
+ long r = best_r;
+ {
+
+ const T old_alpha = alpha(r);
+ alpha(r) = -(df(r)-Q(r,r)*alpha(r))*QQ(r);
+ if (alpha(r) < lower(r))
+ alpha(r) = lower(r);
+ else if (alpha(r) > upper(r))
+ alpha(r) = upper(r);
+
+ const T delta = old_alpha-alpha(r);
+
+ // Now update the gradient. We will perform the equivalent of: df = Q*alpha + b;
+ for(long k = 0; k < df.nr(); ++k)
+ df(k) -= Q(r,k)*delta;
+ }
+
+ if (max_df < eps)
+ break;
+ }
+ //cout << "df: \n" << trans(df) << endl;
+ //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl;
+
+ return iter+1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_mpc
+ {
+ template <long N>
+ void pack(
+ matrix<double,0,1>& out,
+ const std::vector<matrix<double,N,1> >& item
+ )
+ {
+ DLIB_CASSERT(item.size() != 0,"");
+ out.set_size(item.size()*item[0].size());
+ long j = 0;
+ for (unsigned long i = 0; i < item.size(); ++i)
+ for (long r = 0; r < item[i].size(); ++r)
+ out(j++) = item[i](r);
+ }
+
+ template <long N>
+ void pack(
+ matrix<double,0,1>& out,
+ const matrix<double,N,1>& item,
+ const long num
+ )
+ {
+ out.set_size(item.size()*num);
+ long j = 0;
+ for (long r = 0; r < num; ++r)
+ for (long i = 0; i < item.size(); ++i)
+ out(j++) = item(i);
+ }
+
+ template <long N>
+ void unpack(
+ std::vector<matrix<double,N,1> >& out,
+ const matrix<double,0,1>& item
+ )
+ {
+ DLIB_CASSERT(out.size() != 0,"");
+ DLIB_CASSERT((long)out.size()*out[0].size() == item.size(),"");
+ long j = 0;
+ for (unsigned long i = 0; i < out.size(); ++i)
+ for (long r = 0; r < out[i].size(); ++r)
+ out[i](r) = item(j++);
+ }
+ }
+
+ template <long S, long I>
+ unsigned long solve_linear_mpc (
+ const matrix<double,S,S>& A,
+ const matrix<double,S,I>& B,
+ const matrix<double,S,1>& C,
+ const matrix<double,S,1>& Q,
+ const matrix<double,I,1>& R,
+ const matrix<double,I,1>& _lower,
+ const matrix<double,I,1>& _upper,
+ const std::vector<matrix<double,S,1> >& target,
+ const matrix<double,S,1>& initial_state,
+ std::vector<matrix<double,I,1> >& controls // input and output
+ )
+ {
+ using namespace impl_mpc;
+ DLIB_CASSERT(target.size() == controls.size(),"");
+
+ matrix<double> K(B.nr()*controls.size(), B.nc()*controls.size());
+ matrix<double,0,1> M(B.nr()*controls.size());
+
+ // compute powers of A: Apow[i] == A^i
+ std::vector<matrix<double,S,S> > Apow(controls.size());
+ Apow[0] = identity_matrix(A);
+ for (unsigned long i = 1; i < Apow.size(); ++i)
+ Apow[i] = A*Apow[i-1];
+
+ // fill in K
+ K = 0;
+ for (unsigned long r = 0; r < controls.size(); ++r)
+ for (unsigned long c = 0; c <= r; ++c)
+ set_subm(K,r*B.nr(),c*B.nc(), B.nr(), B.nc()) = Apow[r-c]*B;
+
+ // fill in M
+ set_subm(M,0*A.nr(),0,A.nr(),1) = A*initial_state + C;
+ for (unsigned long i = 1; i < controls.size(); ++i)
+ set_subm(M,i*A.nr(),0,A.nr(),1) = A*subm(M,(i-1)*A.nr(),0,A.nr(),1) + C;
+
+ //cout << "M: \n" << M << endl;
+ //cout << "K: \n" << K << endl;
+
+ matrix<double,0,1> t, v, lower, upper;
+ pack(t, target);
+ pack(v, controls);
+ pack(lower, _lower, controls.size());
+ pack(upper, _upper, controls.size());
+
+
+ matrix<double> QQ(K.nr(),K.nr()), RR(K.nc(),K.nc());
+ QQ = 0;
+ RR = 0;
+ for (unsigned long c = 0; c < controls.size(); ++c)
+ {
+ set_subm(QQ,c*Q.nr(),c*Q.nr(),Q.nr(),Q.nr()) = diagm(Q);
+ set_subm(RR,c*R.nr(),c*R.nr(),R.nr(),R.nr()) = diagm(R);
+ }
+
+ matrix<double> m1 = trans(K)*QQ*K+RR;
+ matrix<double> m2 = trans(K)*QQ*(M-t);
+
+
+ // run the solver...
+ unsigned long iter;
+ iter = solve_qp_box_using_smo(
+ m1,
+ m2,
+ v,
+ lower,
+ upper,
+ 0.00000001,
+ 100000);
+
+ //cout << "iterations: " << iter << endl;
+
+ unpack(controls, v);
+ return iter;
+ }
+
+
+
+ class test_mpc : public tester
+ {
+ public:
+ test_mpc (
+ ) :
+ tester ("test_mpc",
+ "Runs tests on the mpc object.")
+ {}
+
+ void perform_test (
+ )
+ {
+ // a basic position + velocity model
+ matrix<double,2,2> A;
+ A = 1, 1,
+ 0, 1;
+ matrix<double,2,1> B, C;
+ B = 0,
+ 1;
+
+ C = 0.02,0.1; // no constant bias
+
+ matrix<double,2,1> Q;
+ Q = 2, 0; // only care about getting the position right
+ matrix<double,1,1> R, lower, upper;
+ R = 1;
+
+ lower = -0.2;
+ upper = 0.2;
+
+ std::vector<matrix<double,1,1> > controls(30);
+ std::vector<matrix<double,2,1> > target(30);
+ for (unsigned long i = 0; i < controls.size(); ++i)
+ {
+ controls[i] = 0;
+ target[i] = 0;
+ }
+
+ mpc<2,1,30> solver(A,B,C,Q,R,lower,upper);
+ solver.set_epsilon(0.00000001);
+ solver.set_max_iterations(10000);
+ matrix<double,2,1> initial_state;
+ initial_state = 0;
+ initial_state(0) = 5;
+ for (int i = 0; i < 30; ++i)
+ {
+ print_spinner();
+ matrix<double,1,1> control = solver(initial_state);
+
+ for (unsigned long i = 1; i < controls.size(); ++i)
+ controls[i-1] = controls[i];
+
+ // Compute the correct control via SMO and make sure it matches.
+ solve_linear_mpc(A,B,C,Q,R,lower,upper, target, initial_state, controls);
+ dlog << LINFO << "ERROR: " << length(control-controls[0]);
+ DLIB_TEST(length(control-controls[0]) < 1e-7);
+
+ initial_state = A*initial_state + B*control + C;
+ //cout << control(0) << "\t" << trans(initial_state);
+ }
+
+ {
+ // also just generally test our QP solver.
+ matrix<double,20,20> Q = gaussian_randm(20,20,5);
+ Q = Q*trans(Q);
+
+ matrix<double,20,1> b = randm(20,1)-0.5;
+ matrix<double,20,1> alpha, lower, upper, alpha2;
+ alpha = 0;
+ alpha2 = 0;
+ lower = -4;
+ upper = 3;
+
+ solve_qp_box_using_smo(Q,b,alpha,lower, upper, 0.000000001, 500000);
+ solve_qp_box_constrained(Q,b,alpha2,lower, upper, 0.000000001, 50000);
+ dlog << LINFO << trans(alpha);
+ dlog << LINFO << trans(alpha2);
+ dlog << LINFO << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha;
+ dlog << LINFO << "objective value2: " << 0.5*trans(alpha2)*Q*alpha + trans(b)*alpha2;
+ DLIB_TEST_MSG(max(abs(alpha-alpha2)) < 1e-7, max(abs(alpha-alpha2)));
+ }
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/multithreaded_object.cpp b/ml/dlib/dlib/test/multithreaded_object.cpp
new file mode 100644
index 000000000..96f8ea26e
--- /dev/null
+++ b/ml/dlib/dlib/test/multithreaded_object.cpp
@@ -0,0 +1,321 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <string>
+#include <sstream>
+
+#include <dlib/threads.h>
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.multithreaded_object");
+
+ dlib::mutex cm;
+ int count;
+
+ class test1 : multithreaded_object
+ {
+ public:
+ test1 ()
+ {
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ clear();
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+
+ ~test1 ()
+ {
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ stop();
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ wait();
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+
+ private:
+ };
+
+ class test2 : private multithreaded_object
+ {
+ public:
+ test2()
+ {
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ register_thread(*this,&test2::thread);
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ clear();
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ register_thread(*this,&test2::thread);
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+
+ ~test2()
+ {
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ stop();
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ wait();
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+
+ private:
+
+ void thread()
+ {
+ auto_mutex M(cm);
+ ++count;
+ }
+
+ };
+
+ class test3_c1 : private multithreaded_object
+ {
+ public:
+ test3_c1()
+ {
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ register_thread(*this,&test3_c1::thread);
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(is_running() == true);
+ }
+
+ ~test3_c1()
+ {
+ DLIB_TEST(number_of_threads_registered() == 1);
+ stop();
+ DLIB_TEST(is_running() == false);
+ DLIB_TEST(number_of_threads_registered() == 1);
+ wait();
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+
+ private:
+
+ void thread()
+ {
+ cm.lock();
+ ++count;
+ cm.unlock();
+ // wait until we are supposed to stop
+ while (!should_stop())
+ dlib::sleep(1);
+ }
+
+ };
+
+ class test4_c2 : private multithreaded_object
+ {
+ public:
+ test4_c2()
+ {
+ DLIB_TEST(number_of_threads_registered() == 0);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ register_thread(*this,&test4_c2::thread);
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 1);
+ DLIB_TEST(number_of_threads_alive() == 1);
+ DLIB_TEST(is_running() == true);
+ register_thread(*this,&test4_c2::thread);
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == true);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == true);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == true);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == true);
+ start();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == true);
+ pause();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST(is_running() == false);
+ }
+
+ ~test4_c2()
+ {
+ try
+ {
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 2);
+ DLIB_TEST_MSG(is_running() == false,"is_running(): " << is_running());
+ stop();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(is_running() == false);
+ wait();
+ DLIB_TEST(number_of_threads_registered() == 2);
+ DLIB_TEST(number_of_threads_alive() == 0);
+ DLIB_TEST(is_running() == false);
+ }
+ catch(std::exception& e)
+ {
+ std::cerr << e.what() << std::endl;
+ exit(1);
+ }
+ }
+
+ private:
+
+ void thread()
+ {
+ auto_mutex M(cm);
+ ++count;
+ while (!should_stop())
+ dlib::sleep(10);
+ }
+
+ };
+
+
+ class test5 : private multithreaded_object
+ {
+ public:
+ test5()
+ {
+ register_thread(*this,&test5::thread1);
+ register_thread(*this,&test5::thread2);
+ register_thread(*this,&test5::thread3);
+ register_thread(*this,&test5::thread3);
+ start();
+ }
+
+ ~test5()
+ {
+ stop();
+ wait();
+ }
+
+ private:
+
+ void thread1()
+ {
+ while (!should_stop())
+ dlib::sleep(10);
+ }
+
+ void thread2()
+ {
+ while (!should_stop())
+ dlib::sleep(10);
+ }
+
+ void thread3()
+ {
+ while (!should_stop())
+ dlib::sleep(10);
+ }
+
+ };
+
+
+ void multithreaded_object_test (
+ )
+ /*!
+ ensures
+ - runs tests on dlib::multithreaded_object for compliance with the specs
+ !*/
+ {
+
+ count = 0;
+
+ for (int i = 0; i < 5; ++i)
+ {
+ {
+ test1 a1;
+ test2 a2;
+ test3_c1 a3;
+ test4_c2 a4;
+ test5 a5;
+ }
+ DLIB_TEST(count == (i+1)*3);
+ print_spinner();
+ }
+ count = 0;
+
+ for (int i = 0; i < 5; ++i)
+ {
+ {
+ test1 a1;
+ test2 a2;
+ test3_c1 a3;
+ test4_c2 a4;
+ test5 a5;
+ dlib::sleep(50);
+ }
+ DLIB_TEST(count == (i+1)*3);
+ print_spinner();
+ }
+ }
+
+
+ class multithreaded_object_tester : public tester
+ {
+ public:
+ multithreaded_object_tester (
+ ) :
+ tester ("test_multithreaded_object",
+ "Runs tests on the multithreaded_object component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ multithreaded_object_test();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/numerical_integration.cpp b/ml/dlib/dlib/test/numerical_integration.cpp
new file mode 100644
index 000000000..d0e247623
--- /dev/null
+++ b/ml/dlib/dlib/test/numerical_integration.cpp
@@ -0,0 +1,228 @@
+// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+// This function test battery is given in:
+//
+// Test functions taken from Pedro Gonnet's dissertation at ETH:
+// Adaptive Quadrature Re-Revisited
+// http://e-collection.library.ethz.ch/eserv/eth:65/eth-65-02.pdf
+
+#include <math.h>
+#include <dlib/matrix.h>
+#include <dlib/numeric_constants.h>
+#include <dlib/numerical_integration.h>
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.numerical_integration");
+
+ class numerical_integration_tester : public tester
+ {
+ public:
+ numerical_integration_tester (
+ ) :
+ tester ("test_numerical_integration",
+ "Runs tests on the numerical integration function.",
+ 0
+ )
+ {}
+
+ void perform_test()
+ {
+
+ dlog <<dlib::LINFO << "Testing integrate_function_adapt_simpson";
+
+ matrix<double,23,1> m;
+ double tol = 1e-10;
+ double eps = 1e-8;
+
+ m(0) = integrate_function_adapt_simp(&gg1, 0.0, 1.0, tol);
+ m(1) = integrate_function_adapt_simp(&gg2, 0.0, 1.0, tol);
+ m(2) = integrate_function_adapt_simp(&gg3, 0.0, 1.0, tol);
+ m(3) = integrate_function_adapt_simp(&gg4, 0.0, 1.0, tol);
+ m(4) = integrate_function_adapt_simp(&gg5, -1.0, 1.0, tol);
+ m(5) = integrate_function_adapt_simp(&gg6, 0.0, 1.0, tol);
+ m(6) = integrate_function_adapt_simp(&gg7, 0.0, 1.0, tol);
+ m(7) = integrate_function_adapt_simp(&gg8, 0.0, 1.0, tol);
+ m(8) = integrate_function_adapt_simp(&gg9, 0.0, 1.0, tol);
+ m(9) = integrate_function_adapt_simp(&gg10, 0.0, 1.0, tol);
+ m(10) = integrate_function_adapt_simp(&gg11, 0.0, 1.0, tol);
+ m(11) = integrate_function_adapt_simp(&gg12, 1e-6, 1.0, tol);
+ m(12) = integrate_function_adapt_simp(&gg13, 0.0, 10.0, tol);
+ m(13) = integrate_function_adapt_simp(&gg14, 0.0, 10.0, tol);
+ m(14) = integrate_function_adapt_simp(&gg15, 0.0, 10.0, tol);
+ m(15) = integrate_function_adapt_simp(&gg16, 0.01, 1.0, tol);
+ m(16) = integrate_function_adapt_simp(&gg17, 0.0, pi, tol);
+ m(17) = integrate_function_adapt_simp(&gg18, 0.0, 1.0, tol);
+ m(18) = integrate_function_adapt_simp(&gg19, -1.0, 1.0, tol);
+ m(19) = integrate_function_adapt_simp(&gg20, 0.0, 1.0, tol);
+ m(20) = integrate_function_adapt_simp(&gg21, 0.0, 1.0, tol);
+ m(21) = integrate_function_adapt_simp(&gg22, 0.0, 5.0, tol);
+
+ // Here we compare the approximated integrals against
+ // highly accurate approximations generated either from
+ // the exact integral values or Mathematica's NIntegrate
+ // function using a working precision of 20.
+
+ DLIB_TEST(abs(m(0) - 1.7182818284590452354) < 1e-11);
+ DLIB_TEST(abs(m(1) - 0.7000000000000000000) < eps);
+ DLIB_TEST(abs(m(2) - 0.6666666666666666667) < eps);
+ DLIB_TEST(abs(m(3) - 0.2397141133444008336) < eps);
+ DLIB_TEST(abs(m(4) - 1.5822329637296729331) < 1e-11);
+ DLIB_TEST(abs(m(5) - 0.4000000000000000000) < eps);
+ DLIB_TEST(abs(m(6) - 2.0000000000000000000) < 1e-4);
+ DLIB_TEST(abs(m(7) - 0.8669729873399110375) < eps);
+ DLIB_TEST(abs(m(8) - 1.1547005383792515290) < eps);
+ DLIB_TEST(abs(m(9) - 0.6931471805599453094) < eps);
+ DLIB_TEST(abs(m(10) - 0.3798854930417224753) < eps);
+ DLIB_TEST(abs(m(11) - 0.7775036341124982763) < eps);
+ DLIB_TEST(abs(m(12) - 0.5000000000000000000) < eps);
+ DLIB_TEST(abs(m(13) - 1.0000000000000000000) < eps);
+ DLIB_TEST(abs(m(14) - 0.4993633810764567446) < eps);
+ DLIB_TEST(abs(m(15) - 0.1121393035410217 ) < eps);
+ DLIB_TEST(abs(m(16) - 0.2910187828600526985) < eps);
+ DLIB_TEST(abs(m(17) + 0.4342944819032518276) < 1e-5);
+ DLIB_TEST(abs(m(18) - 1.56439644406905 ) < eps);
+ DLIB_TEST(abs(m(19) - 0.1634949430186372261) < eps);
+ DLIB_TEST(abs(m(20) - 0.0134924856494677726) < eps);
+ }
+
+ static double gg1(double x)
+ {
+ return pow(e,x);
+ }
+
+ static double gg2(double x)
+ {
+ if(x > 0.3)
+ {
+ return 1.0;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ static double gg3(double x)
+ {
+ return pow(x,0.5);
+ }
+
+ static double gg4(double x)
+ {
+ return 23.0/25.0*cosh(x)-cos(x);
+ }
+
+ static double gg5(double x)
+ {
+ return 1/(pow(x,4) + pow(x,2) + 0.9);
+ }
+
+ static double gg6(double x)
+ {
+ return pow(x,1.5);
+ }
+
+ static double gg7(double x)
+ {
+ return pow(x,-0.5);
+ }
+
+ static double gg8(double x)
+ {
+ return 1/(1 + pow(x,4));
+ }
+
+ static double gg9(double x)
+ {
+ return 2/(2 + sin(10*pi*x));
+ }
+
+ static double gg10(double x)
+ {
+ return 1/(1+x);
+ }
+
+ static double gg11(double x)
+ {
+ return 1.0/(1 + pow(e,x));
+ }
+
+ static double gg12(double x)
+ {
+ return x/(pow(e,x)-1.0);
+ }
+
+ static double gg13(double x)
+ {
+ return sqrt(50.0)*pow(e,-50.0*pi*x*x);
+ }
+
+ static double gg14(double x)
+ {
+ return 25.0*pow(e,-25.0*x);
+ }
+
+ static double gg15(double x)
+ {
+ return 50.0/(pi*(2500.0*x*x+1));
+ }
+
+ static double gg16(double x)
+ {
+ return 50.0*pow((sin(50.0*pi*x)/(50.0*pi*x)),2);
+ }
+
+ static double gg17(double x)
+ {
+ return cos(cos(x)+3*sin(x)+2*cos(2*x)+3*cos(3*x));
+ }
+
+ static double gg18(double x)
+ {
+ return log10(x);
+ }
+
+ static double gg19(double x)
+ {
+ return 1/(1.005+x*x);
+ }
+
+ static double gg20(double x)
+ {
+ return 1/cosh(20.0*(x-1.0/5.0)) + 1/cosh(400.0*(x-2.0/5.0))
+ + 1/cosh(8000.0*(x-3.0/5.0));
+ }
+
+ static double gg21(double x)
+ {
+ return 1.0/(1.0+(230.0*x-30.0)*(230.0*x-30.0));
+ }
+
+ static double gg22(double x)
+ {
+ if(x < 1)
+ {
+ return (x + 1.0);
+ }
+ else if(x >= 1 && x <= 3)
+ {
+ return (3.0 - x);
+ }
+ else
+ {
+ return 2.0;
+ }
+ }
+
+ };
+
+ numerical_integration_tester a;
+}
+
diff --git a/ml/dlib/dlib/test/object_detector.cpp b/ml/dlib/dlib/test/object_detector.cpp
new file mode 100644
index 000000000..fdb72f520
--- /dev/null
+++ b/ml/dlib/dlib/test/object_detector.cpp
@@ -0,0 +1,1028 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/statistics.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include "tester.h"
+#include <dlib/pixel.h>
+#include <dlib/svm_threaded.h>
+#include <dlib/array.h>
+#include <dlib/set_utils.h>
+#include <dlib/array2d.h>
+#include <dlib/image_keypoint.h>
+#include <dlib/image_processing.h>
+#include <dlib/image_transforms.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.object_detector");
+
+// ----------------------------------------------------------------------------------------
+
+ struct funny_image
+ {
+ array2d<unsigned char> img;
+ long nr() const { return img.nr(); }
+ long nc() const { return img.nc(); }
+ };
+
+ void swap(funny_image& a, funny_image& b)
+ {
+ a.img.swap(b.img);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type,
+ typename detector_type
+ >
+ void validate_some_object_detector_stuff (
+ const image_array_type& images,
+ detector_type& detector,
+ double eps = 1e-10
+ )
+ {
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ std::vector<rectangle> dets = detector(images[i]);
+ std::vector<std::pair<double,rectangle> > dets2;
+
+ detector(images[i], dets2);
+
+ matrix<double,0,1> psi(detector.get_w().size());
+ matrix<double,0,1> psi2(detector.get_w().size());
+ const double thresh = detector.get_w()(detector.get_w().size()-1);
+
+ DLIB_TEST(dets.size() == dets2.size());
+ for (unsigned long j = 0; j < dets.size(); ++j)
+ {
+ DLIB_TEST(dets[j] == dets2[j].second);
+
+ const full_object_detection fdet = detector.get_scanner().get_full_object_detection(dets[j], detector.get_w());
+ psi = 0;
+ detector.get_scanner().get_feature_vector(fdet, psi);
+
+ double check_score = dot(psi,detector.get_w()) - thresh;
+ DLIB_TEST_MSG(std::abs(check_score - dets2[j].first) < eps, std::abs(check_score - dets2[j].first) << " check_score: "<< check_score);
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class very_simple_feature_extractor : noncopyable
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is a feature extractor which goes to every pixel in an image and
+ produces a 32 dimensional feature vector. This vector is an indicator vector
+ which records the pattern of pixel values in a 4-connected region. So it should
+ be able to distinguish basic things like whether or not a location falls on the
+ corner of a white box, on an edge, in the middle, etc.
+
+
+ Note that this object also implements the interface defined in dlib/image_keypoint/hashed_feature_image_abstract.h.
+ This means all the member functions in this object are supposed to behave as
+ described in the hashed_feature_image specification. So when you define your own
+ feature extractor objects you should probably refer yourself to that documentation
+ in addition to reading this example program.
+ !*/
+
+
+ public:
+
+ inline void load (
+ const funny_image& img_
+ )
+ {
+ const array2d<unsigned char>& img = img_.img;
+
+ feat_image.set_size(img.nr(), img.nc());
+ assign_all_pixels(feat_image,0);
+ for (long r = 1; r+1 < img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < img.nc(); ++c)
+ {
+ unsigned char f = 0;
+ if (img[r][c]) f |= 0x1;
+ if (img[r][c+1]) f |= 0x2;
+ if (img[r][c-1]) f |= 0x4;
+ if (img[r+1][c]) f |= 0x8;
+ if (img[r-1][c]) f |= 0x10;
+
+ // Store the code value for the pattern of pixel values in the 4-connected
+ // neighborhood around this row and column.
+ feat_image[r][c] = f;
+ }
+ }
+ }
+
+ inline void load (
+ const array2d<unsigned char>& img
+ )
+ {
+ feat_image.set_size(img.nr(), img.nc());
+ assign_all_pixels(feat_image,0);
+ for (long r = 1; r+1 < img.nr(); ++r)
+ {
+ for (long c = 1; c+1 < img.nc(); ++c)
+ {
+ unsigned char f = 0;
+ if (img[r][c]) f |= 0x1;
+ if (img[r][c+1]) f |= 0x2;
+ if (img[r][c-1]) f |= 0x4;
+ if (img[r+1][c]) f |= 0x8;
+ if (img[r-1][c]) f |= 0x10;
+
+ // Store the code value for the pattern of pixel values in the 4-connected
+ // neighborhood around this row and column.
+ feat_image[r][c] = f;
+ }
+ }
+ }
+
+ inline size_t size () const { return feat_image.size(); }
+ inline long nr () const { return feat_image.nr(); }
+ inline long nc () const { return feat_image.nc(); }
+
+ inline long get_num_dimensions (
+ ) const
+ {
+ // Return the dimensionality of the vectors produced by operator()
+ return 32;
+ }
+
+ typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
+
+ inline const descriptor_type& operator() (
+ long row,
+ long col
+ ) const
+ /*!
+ requires
+ - 0 <= row < nr()
+ - 0 <= col < nc()
+ ensures
+ - returns a sparse vector which describes the image at the given row and column.
+ In particular, this is a vector that is 0 everywhere except for one element.
+ !*/
+ {
+ feat.clear();
+ const unsigned long only_nonzero_element_index = feat_image[row][col];
+ feat.push_back(make_pair(only_nonzero_element_index,1.0));
+ return feat;
+ }
+
+ // This block of functions is meant to provide a way to map between the row/col space taken by
+ // this object's operator() function and the images supplied to load(). In this example it's trivial.
+ // However, in general, you might create feature extractors which don't perform extraction at every
+ // possible image location (e.g. the hog_image) and thus result in some more complex mapping.
+ inline const rectangle get_block_rect ( long row, long col) const { return centered_rect(col,row,3,3); }
+ inline const point image_to_feat_space ( const point& p) const { return p; }
+ inline const rectangle image_to_feat_space ( const rectangle& rect) const { return rect; }
+ inline const point feat_to_image_space ( const point& p) const { return p; }
+ inline const rectangle feat_to_image_space ( const rectangle& rect) const { return rect; }
+
+ inline friend void serialize ( const very_simple_feature_extractor& item, std::ostream& out) { serialize(item.feat_image, out); }
+ inline friend void deserialize ( very_simple_feature_extractor& item, std::istream& in ) { deserialize(item.feat_image, in); }
+
+ void copy_configuration ( const very_simple_feature_extractor& ){}
+
+ private:
+ array2d<unsigned char> feat_image;
+
+ // This variable doesn't logically contribute to the state of this object. It is here
+ // only to avoid returning a descriptor_type object by value inside the operator() method.
+ mutable descriptor_type feat;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void make_simple_test_data (
+ image_array_type& images,
+ std::vector<std::vector<rectangle> >& object_locations
+ )
+ {
+ images.clear();
+ object_locations.clear();
+
+ images.resize(3);
+ images[0].set_size(400,400);
+ images[1].set_size(400,400);
+ images[2].set_size(400,400);
+
+ // set all the pixel values to black
+ assign_all_pixels(images[0], 0);
+ assign_all_pixels(images[1], 0);
+ assign_all_pixels(images[2], 0);
+
+ // Now make some squares and draw them onto our black images. All the
+ // squares will be 70 pixels wide and tall.
+
+ std::vector<rectangle> temp;
+ temp.push_back(centered_rect(point(100,100), 70,70));
+ fill_rect(images[0],temp.back(),255); // Paint the square white
+ temp.push_back(centered_rect(point(200,300), 70,70));
+ fill_rect(images[0],temp.back(),255); // Paint the square white
+ object_locations.push_back(temp);
+
+ temp.clear();
+ temp.push_back(centered_rect(point(140,200), 70,70));
+ fill_rect(images[1],temp.back(),255); // Paint the square white
+ temp.push_back(centered_rect(point(303,200), 70,70));
+ fill_rect(images[1],temp.back(),255); // Paint the square white
+ object_locations.push_back(temp);
+
+ temp.clear();
+ temp.push_back(centered_rect(point(123,121), 70,70));
+ fill_rect(images[2],temp.back(),255); // Paint the square white
+ object_locations.push_back(temp);
+
+ // corrupt each image with random noise just to make this a little more
+ // challenging
+ dlib::rand rnd;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ for (long r = 0; r < images[i].nr(); ++r)
+ {
+ for (long c = 0; c < images[i].nc(); ++c)
+ {
+ typedef typename image_array_type::type image_type;
+ typedef typename image_type::type type;
+ images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 10*rnd.get_random_gaussian());
+ }
+ }
+ }
+ }
+
+ template <
+ typename image_array_type
+ >
+ void make_simple_test_data (
+ image_array_type& images,
+ std::vector<std::vector<full_object_detection> >& object_locations
+ )
+ {
+ images.clear();
+ object_locations.clear();
+
+
+ images.resize(3);
+ images[0].set_size(400,400);
+ images[1].set_size(400,400);
+ images[2].set_size(400,400);
+
+ // set all the pixel values to black
+ assign_all_pixels(images[0], 0);
+ assign_all_pixels(images[1], 0);
+ assign_all_pixels(images[2], 0);
+
+ // Now make some squares and draw them onto our black images. All the
+ // squares will be 70 pixels wide and tall.
+ const int shrink = 0;
+ std::vector<full_object_detection> temp;
+
+ rectangle rect = centered_rect(point(100,100), 70,71);
+ std::vector<point> movable_parts;
+ movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
+ temp.push_back(full_object_detection(rect, movable_parts));
+ fill_rect(images[0],rect,255); // Paint the square white
+
+ rect = centered_rect(point(200,200), 70,71);
+ movable_parts.clear();
+ movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
+ temp.push_back(full_object_detection(rect, movable_parts));
+ fill_rect(images[0],rect,255); // Paint the square white
+
+ object_locations.push_back(temp);
+ // ------------------------------------
+ temp.clear();
+
+ rect = centered_rect(point(140,200), 70,71);
+ movable_parts.clear();
+ movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
+ temp.push_back(full_object_detection(rect, movable_parts));
+ fill_rect(images[1],rect,255); // Paint the square white
+
+
+ rect = centered_rect(point(303,200), 70,71);
+ movable_parts.clear();
+ movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
+ temp.push_back(full_object_detection(rect, movable_parts));
+ fill_rect(images[1],rect,255); // Paint the square white
+
+ object_locations.push_back(temp);
+ // ------------------------------------
+ temp.clear();
+
+ rect = centered_rect(point(123,121), 70,71);
+ movable_parts.clear();
+ movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
+ movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
+ temp.push_back(full_object_detection(rect, movable_parts));
+ fill_rect(images[2],rect,255); // Paint the square white
+
+ object_locations.push_back(temp);
+
+ // corrupt each image with random noise just to make this a little more
+ // challenging
+ dlib::rand rnd;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ for (long r = 0; r < images[i].nr(); ++r)
+ {
+ for (long c = 0; c < images[i].nc(); ++c)
+ {
+ typedef typename image_array_type::type image_type;
+ typedef typename image_type::type type;
+ images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 40*rnd.get_random_gaussian());
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_fhog_pyramid (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_fhog_pyramid()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef scan_fhog_pyramid<pyramid_down<2> > image_scanner_type;
+ image_scanner_type scanner;
+ scanner.set_detection_window_size(35,35);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector, 1e-6);
+ }
+
+ {
+ std::vector<object_detector<image_scanner_type> > detectors;
+ detectors.push_back(detector);
+ detectors.push_back(detector);
+ detectors.push_back(detector);
+
+ std::vector<rectangle> dets1 = evaluate_detectors(detectors, images[0]);
+ std::vector<rectangle> dets2 = detector(images[0]);
+ DLIB_TEST(dets1.size() > 0);
+ DLIB_TEST(dets2.size()*3 == dets1.size());
+ dlib::set<rectangle>::kernel_1a_c d1, d2;
+ for (unsigned long i = 0; i < dets1.size(); ++i)
+ {
+ if (!d1.is_member(dets1[i]))
+ d1.add(dets1[i]);
+ }
+ for (unsigned long i = 0; i < dets2.size(); ++i)
+ {
+ if (!d2.is_member(dets2[i]))
+ d2.add(dets2[i]);
+ }
+ DLIB_TEST(d1.size() == d2.size());
+ DLIB_TEST(set_intersection_size(d1,d2) == d1.size());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1 (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,35*35);
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1_boxes (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_boxes()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
+ typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1m (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1m()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<full_object_detection> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,35*35);
+ std::vector<rectangle> mboxes;
+ const int mbox_size = 20;
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,1,1), mboxes);
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1_fine_hog (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_fine_hog()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<fine_hog_image<3,3,2,4,hog_signed_gradient> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,35*35);
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1_poly (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_poly()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<poly_image<2> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,35*35);
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1m_poly (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_poly()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<full_object_detection> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef hashed_feature_image<poly_image<2> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,35*35);
+ std::vector<rectangle> mboxes;
+ const int mbox_size = 20;
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2), mboxes);
+ setup_hashed_features(scanner, images, 9);
+ use_uniform_feature_weights(scanner);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ trainer.set_overlap_tester(test_box_overlap(0,0));
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1_poly_nn (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_poly_nn()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
+ typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+
+ setup_grid_detection_templates(scanner, object_locations, 2, 2);
+ feature_extractor_type nnfe;
+ pyramid_down<2> pyr_down;
+ poly_image<5> polyi;
+ nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
+ scanner.copy_configuration(nnfe);
+
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_1_poly_nn_boxes (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_1_poly_nn_boxes()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
+ typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
+ image_scanner_type scanner;
+
+ feature_extractor_type nnfe;
+ pyramid_down<2> pyr_down;
+ poly_image<5> polyi;
+ nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
+ scanner.copy_configuration(nnfe);
+
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_2 (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_2()";
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ grayscale_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images, object_locations);
+
+ typedef scan_image_pyramid<pyramid_down<5>, very_simple_feature_extractor> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,70*70);
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
+ scanner.set_max_pyramid_levels(1);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(0);
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
+ dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+ validate_some_object_detector_stuff(images, detector);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class pyramid_down_funny : noncopyable
+ {
+ pyramid_down<2> pyr;
+ public:
+
+ template <typename T>
+ dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p) const { return pyr.point_down(p); }
+
+ template <typename T>
+ dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p) const { return pyr.point_up(p); }
+
+ template <typename T>
+ dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_down(p,levels); }
+
+ template <typename T>
+ dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_up(p,levels); }
+
+ rectangle rect_up ( const rectangle& rect) const { return pyr.rect_up(rect); }
+
+ rectangle rect_up ( const rectangle& rect, unsigned int levels) const { return pyr.rect_up(rect,levels); }
+
+ rectangle rect_down ( const rectangle& rect) const { return pyr.rect_down(rect); }
+
+ rectangle rect_down ( const rectangle& rect, unsigned int levels) const { return pyr.rect_down(rect,levels); }
+
+ template <
+ typename in_image_type,
+ typename out_image_type
+ >
+ void operator() (
+ const in_image_type& original,
+ out_image_type& down
+ ) const
+ {
+ pyr(original.img, down.img);
+ }
+
+ };
+
+ // make sure everything works even when the image isn't a dlib::array2d.
+ // So test with funny_image.
+ void test_3 (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_3()";
+
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ typedef dlib::array<funny_image> funny_image_array_type;
+ grayscale_image_array_type images_temp;
+ funny_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images_temp, object_locations);
+ images.resize(images_temp.size());
+ for (unsigned long i = 0; i < images_temp.size(); ++i)
+ {
+ images[i].img.swap(images_temp[i]);
+ }
+
+ typedef scan_image_pyramid<pyramid_down_funny, very_simple_feature_extractor> image_scanner_type;
+ image_scanner_type scanner;
+ const rectangle object_box = compute_box_dimensions(1,70*70);
+ scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
+ scanner.set_max_pyramid_levels(1);
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
+ dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class funny_box_generator
+ {
+ public:
+ template <typename image_type>
+ void operator() (
+ const image_type& img,
+ std::vector<rectangle>& rects
+ ) const
+ {
+ rects.clear();
+ find_candidate_object_locations(img.img, rects);
+ dlog << LINFO << "funny_box_generator, rects.size(): "<< rects.size();
+ }
+ };
+
+ inline void serialize(const funny_box_generator&, std::ostream& ) {}
+ inline void deserialize(funny_box_generator&, std::istream& ) {}
+
+
+ // make sure everything works even when the image isn't a dlib::array2d.
+ // So test with funny_image.
+ void test_3_boxes (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test_3_boxes()";
+
+
+ typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
+ typedef dlib::array<funny_image> funny_image_array_type;
+ grayscale_image_array_type images_temp;
+ funny_image_array_type images;
+ std::vector<std::vector<rectangle> > object_locations;
+ make_simple_test_data(images_temp, object_locations);
+ images.resize(images_temp.size());
+ for (unsigned long i = 0; i < images_temp.size(); ++i)
+ {
+ images[i].img.swap(images_temp[i]);
+ }
+
+ typedef scan_image_boxes<very_simple_feature_extractor, funny_box_generator> image_scanner_type;
+ image_scanner_type scanner;
+ structural_object_detection_trainer<image_scanner_type> trainer(scanner);
+ trainer.set_num_threads(4);
+ object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
+
+ matrix<double> res = test_object_detection_function(detector, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
+ dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+
+ {
+ ostringstream sout;
+ serialize(detector, sout);
+ istringstream sin(sout.str());
+ object_detector<image_scanner_type> d2;
+ deserialize(d2, sin);
+ matrix<double> res = test_object_detection_function(d2, images, object_locations);
+ dlog << LINFO << "Test detector (precision,recall): " << res;
+ DLIB_TEST(sum(res) == 3);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class object_detector_tester : public tester
+ {
+ public:
+ object_detector_tester (
+ ) :
+ tester ("test_object_detector",
+ "Runs tests on the structural object detection stuff.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_fhog_pyramid();
+ test_1_boxes();
+ test_1_poly_nn_boxes();
+ test_3_boxes();
+
+ test_1();
+ test_1m();
+ test_1_fine_hog();
+ test_1_poly();
+ test_1m_poly();
+ test_1_poly_nn();
+ test_2();
+ test_3();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/oca.cpp b/ml/dlib/dlib/test/oca.cpp
new file mode 100644
index 000000000..97881a758
--- /dev/null
+++ b/ml/dlib/dlib/test/oca.cpp
@@ -0,0 +1,244 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include <dlib/svm.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.oca");
+
+// ----------------------------------------------------------------------------------------
+
+ class test_oca : public tester
+ {
+
+ public:
+ test_oca (
+ ) :
+ tester ("test_oca",
+ "Runs tests on the oca component.")
+ {
+ }
+
+ void perform_test(
+ )
+ {
+ print_spinner();
+
+ typedef matrix<double,0,1> w_type;
+ w_type w;
+
+ decision_function<linear_kernel<w_type> > df;
+ svm_c_linear_trainer<linear_kernel<w_type> > trainer;
+ trainer.set_c_class1(2);
+ trainer.set_c_class1(3);
+ trainer.set_learns_nonnegative_weights(true);
+ trainer.set_epsilon(1e-12);
+
+ std::vector<w_type> x;
+ w_type temp(2);
+ temp = -1, 1;
+ x.push_back(temp);
+ temp = 1, -1;
+ x.push_back(temp);
+
+ std::vector<double> y;
+ y.push_back(+1);
+ y.push_back(-1);
+
+ w_type true_w(3);
+
+ oca solver;
+
+ // test the version without a non-negativity constraint on w.
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
+ dlog << LINFO << trans(w);
+ true_w = -0.5, 0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ solver.solve_with_elastic_net(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0.5);
+ dlog << LINFO << trans(w);
+ true_w = -0.5, 0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+ w_type prior = true_w;
+ solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
+ dlog << LINFO << trans(w);
+ true_w = -0.5, 0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ prior = 0,0,0;
+ solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
+ dlog << LINFO << trans(w);
+ true_w = -0.5, 0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ prior = -1,1,0;
+ solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
+ dlog << LINFO << trans(w);
+ true_w = -1.0, 1.0, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ prior = -0.2,0.2,0;
+ solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
+ dlog << LINFO << trans(w);
+ true_w = -0.5, 0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ prior = -10.2,-1,0;
+ solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
+ dlog << LINFO << trans(w);
+ true_w = -10.2, -1.0, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+ // test the version with a non-negativity constraint on w.
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 9999);
+ dlog << LINFO << trans(w);
+ true_w = 0, 1, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ df = trainer.train(x,y);
+ w = join_cols(df.basis_vectors(0), uniform_matrix<double>(1,1,-df.b));
+ true_w = 0, 1, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST_MSG(max(abs(w-true_w)) < 1e-9, max(abs(w-true_w)));
+
+
+ print_spinner();
+
+ // test the version with a non-negativity constraint on w.
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
+ dlog << LINFO << trans(w);
+ true_w = 0, 1, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+
+ // test the version with a non-negativity constraint on w.
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
+ dlog << LINFO << trans(w);
+ true_w = 0, 1, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+
+ // switching the labels should change which w weight goes negative.
+ y.clear();
+ y.push_back(-1);
+ y.push_back(+1);
+
+
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
+ dlog << LINFO << trans(w);
+ true_w = 0.5, -0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
+ dlog << LINFO << trans(w);
+ true_w = 0.5, -0.5, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
+ dlog << LINFO << trans(w);
+ true_w = 1, 0, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ print_spinner();
+
+ solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 5);
+ dlog << LINFO << trans(w);
+ true_w = 1, 0, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ df = trainer.train(x,y);
+ w = join_cols(df.basis_vectors(0), uniform_matrix<double>(1,1,-df.b));
+ true_w = 1, 0, 0;
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST_MSG(max(abs(w-true_w)) < 1e-9, max(abs(w-true_w)));
+
+
+
+ x.clear();
+ y.clear();
+ temp = -2, 2;
+ x.push_back(temp);
+ temp = 0, -0;
+ x.push_back(temp);
+
+ y.push_back(+1);
+ y.push_back(-1);
+
+ trainer.set_c(10);
+ df = trainer.train(x,y);
+ w = join_cols(df.basis_vectors(0), uniform_matrix<double>(1,1,-df.b));
+ true_w = 0, 1, -1;
+ dlog << LINFO << "w: " << trans(w);
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+
+ x.clear();
+ y.clear();
+ temp = -2, 2;
+ x.push_back(temp);
+ temp = 0, -0;
+ x.push_back(temp);
+
+ y.push_back(-1);
+ y.push_back(+1);
+
+ trainer.set_c(10);
+ df = trainer.train(x,y);
+ w = join_cols(df.basis_vectors(0), uniform_matrix<double>(1,1,-df.b));
+ true_w = 1, 0, 1;
+ dlog << LINFO << "w: " << trans(w);
+ dlog << LINFO << "error: "<< max(abs(w-true_w));
+ DLIB_TEST(max(abs(w-true_w)) < 1e-10);
+
+ }
+
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/one_vs_all_trainer.cpp b/ml/dlib/dlib/test/one_vs_all_trainer.cpp
new file mode 100644
index 000000000..b928714b2
--- /dev/null
+++ b/ml/dlib/dlib/test/one_vs_all_trainer.cpp
@@ -0,0 +1,305 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.one_vs_all_trainer");
+
+
+ class test_one_vs_all_trainer : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ test_one_vs_all_trainer (
+ ) :
+ tester (
+ "test_one_vs_all_trainer", // the command line argument name for this test
+ "Run tests on the one_vs_all_trainer stuff.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+
+ template <typename sample_type, typename label_type>
+ void generate_data (
+ std::vector<sample_type>& samples,
+ std::vector<label_type>& labels
+ )
+ {
+ const long num = 50;
+
+ sample_type m;
+
+ dlib::rand rnd;
+
+
+ // make some samples near the origin
+ double radius = 0.5;
+ for (long i = 0; i < num+10; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(1);
+ }
+
+ // make some samples in a circle around the origin but far away
+ radius = 10.0;
+ for (long i = 0; i < num+20; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(2);
+ }
+
+ // make some samples in a circle around the point (25,25)
+ radius = 4.0;
+ for (long i = 0; i < num+30; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // translate this point away from the origin
+ m(0) += 25;
+ m(1) += 25;
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(3);
+ }
+ }
+
+ template <typename label_type, typename scalar_type>
+ void run_test (
+ )
+ {
+ print_spinner();
+ typedef matrix<scalar_type,2,1> sample_type;
+
+ std::vector<sample_type> samples, norm_samples;
+ std::vector<label_type> labels;
+
+ // First, get our labeled set of training data
+ generate_data(samples, labels);
+
+ typedef one_vs_all_trainer<any_trainer<sample_type,scalar_type>,label_type > ova_trainer;
+
+
+ ova_trainer trainer;
+
+ typedef polynomial_kernel<sample_type> poly_kernel;
+ typedef radial_basis_kernel<sample_type> rbf_kernel;
+
+ // make the binary trainers and set some parameters
+ krr_trainer<rbf_kernel> rbf_trainer;
+ svm_nu_trainer<poly_kernel> poly_trainer;
+ poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
+ rbf_trainer.set_kernel(rbf_kernel(0.1));
+
+
+ trainer.set_trainer(rbf_trainer);
+ trainer.set_trainer(poly_trainer, 1);
+
+ randomize_samples(samples, labels);
+ matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
+
+ print_spinner();
+
+ matrix<scalar_type> ans(3,3);
+ ans = 60, 0, 0,
+ 0, 70, 0,
+ 0, 0, 80;
+
+ DLIB_TEST_MSG(ans == res, "res: \n" << res);
+
+ // test using a normalized_function with a one_vs_all_decision_function
+ {
+ poly_trainer.set_kernel(poly_kernel(1.1, 1, 2));
+ trainer.set_trainer(poly_trainer, 1);
+ vector_normalizer<sample_type> normalizer;
+ normalizer.train(samples);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ norm_samples.push_back(normalizer(samples[i]));
+ normalized_function<one_vs_all_decision_function<ova_trainer> > ndf;
+ ndf.function = trainer.train(norm_samples, labels);
+ ndf.normalizer = normalizer;
+ DLIB_TEST(ndf(samples[0]) == labels[0]);
+ DLIB_TEST(ndf(samples[40]) == labels[40]);
+ DLIB_TEST(ndf(samples[90]) == labels[90]);
+ DLIB_TEST(ndf(samples[120]) == labels[120]);
+ poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
+ trainer.set_trainer(poly_trainer, 1);
+ print_spinner();
+ }
+
+ one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
+
+ DLIB_TEST(df.number_of_classes() == 3);
+
+ DLIB_TEST(df(samples[0]) == labels[0]);
+ DLIB_TEST(df(samples[90]) == labels[90]);
+
+
+ one_vs_all_decision_function<ova_trainer,
+ decision_function<poly_kernel>, // This is the output of the poly_trainer
+ decision_function<rbf_kernel> // This is the output of the rbf_trainer
+ > df2, df3;
+
+
+ df2 = df;
+ ofstream fout("df.dat", ios::binary);
+ serialize(df2, fout);
+ fout.close();
+
+ // load the function back in from disk and store it in df3.
+ ifstream fin("df.dat", ios::binary);
+ deserialize(df3, fin);
+
+
+ DLIB_TEST(df3(samples[0]) == labels[0]);
+ DLIB_TEST(df3(samples[90]) == labels[90]);
+ res = test_multiclass_decision_function(df3, samples, labels);
+
+ DLIB_TEST(res == ans);
+
+
+ }
+
+ template <typename label_type, typename scalar_type>
+ void run_probabilistic_test (
+ )
+ {
+ print_spinner();
+ typedef matrix<scalar_type,2,1> sample_type;
+
+ std::vector<sample_type> samples;
+ std::vector<label_type> labels;
+
+ // First, get our labeled set of training data
+ generate_data(samples, labels);
+
+ typedef one_vs_all_trainer<any_trainer<sample_type,scalar_type>,label_type > ova_trainer;
+
+
+ ova_trainer trainer;
+
+ typedef polynomial_kernel<sample_type> poly_kernel;
+ typedef radial_basis_kernel<sample_type> rbf_kernel;
+
+ // make the binary trainers and set some parameters
+ krr_trainer<rbf_kernel> rbf_trainer;
+ svm_nu_trainer<poly_kernel> poly_trainer;
+ poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
+ rbf_trainer.set_kernel(rbf_kernel(0.1));
+
+
+ trainer.set_trainer(probabilistic(rbf_trainer, 3));
+ trainer.set_trainer(probabilistic(poly_trainer, 3), 1);
+
+ randomize_samples(samples, labels);
+ matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
+
+ print_spinner();
+
+ matrix<scalar_type> ans(3,3);
+ ans = 60, 0, 0,
+ 0, 70, 0,
+ 0, 0, 80;
+
+ DLIB_TEST_MSG(ans == res, "res: \n" << res);
+
+ one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
+
+ DLIB_TEST(df.number_of_classes() == 3);
+
+ DLIB_TEST(df(samples[0]) == labels[0]);
+ DLIB_TEST(df(samples[90]) == labels[90]);
+
+
+ one_vs_all_decision_function<ova_trainer,
+ probabilistic_function<decision_function<poly_kernel> >, // This is the output of the poly_trainer
+ probabilistic_function<decision_function<rbf_kernel> > // This is the output of the rbf_trainer
+ > df2, df3;
+
+
+ df2 = df;
+ ofstream fout("df.dat", ios::binary);
+ serialize(df2, fout);
+ fout.close();
+
+ // load the function back in from disk and store it in df3.
+ ifstream fin("df.dat", ios::binary);
+ deserialize(df3, fin);
+
+
+ DLIB_TEST(df3(samples[0]) == labels[0]);
+ DLIB_TEST(df3(samples[90]) == labels[90]);
+ res = test_multiclass_decision_function(df3, samples, labels);
+
+ DLIB_TEST(res == ans);
+
+
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "run_test<double,double>()";
+ run_test<double,double>();
+
+ dlog << LINFO << "run_test<int,double>()";
+ run_test<int,double>();
+
+ dlog << LINFO << "run_test<double,float>()";
+ run_test<double,float>();
+
+ dlog << LINFO << "run_test<int,float>()";
+ run_test<int,float>();
+
+ dlog << LINFO << "run_probabilistic_test<double,double>()";
+ run_probabilistic_test<double,double>();
+
+ dlog << LINFO << "run_probabilistic_test<int,double>()";
+ run_probabilistic_test<int,double>();
+
+ dlog << LINFO << "run_probabilistic_test<double,float>()";
+ run_probabilistic_test<double,float>();
+
+ dlog << LINFO << "run_probabilistic_test<int,float>()";
+ run_probabilistic_test<int,float>();
+ }
+ };
+
+ test_one_vs_all_trainer a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/one_vs_one_trainer.cpp b/ml/dlib/dlib/test/one_vs_one_trainer.cpp
new file mode 100644
index 000000000..70cdefaf9
--- /dev/null
+++ b/ml/dlib/dlib/test/one_vs_one_trainer.cpp
@@ -0,0 +1,218 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/statistics.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.one_vs_one_trainer");
+
+
+ class test_one_vs_one_trainer : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ test_one_vs_one_trainer (
+ ) :
+ tester (
+ "test_one_vs_one_trainer", // the command line argument name for this test
+ "Run tests on the one_vs_one_trainer stuff.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+
+ template <typename sample_type, typename label_type>
+ void generate_data (
+ std::vector<sample_type>& samples,
+ std::vector<label_type>& labels
+ )
+ {
+ const long num = 50;
+
+ sample_type m;
+
+ dlib::rand rnd;
+
+
+ // make some samples near the origin
+ double radius = 0.5;
+ for (long i = 0; i < num+10; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(1);
+ }
+
+ // make some samples in a circle around the origin but far away
+ radius = 10.0;
+ for (long i = 0; i < num+20; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(2);
+ }
+
+ // make some samples in a circle around the point (25,25)
+ radius = 4.0;
+ for (long i = 0; i < num+30; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // translate this point away from the origin
+ m(0) += 25;
+ m(1) += 25;
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ labels.push_back(3);
+ }
+ }
+
+ template <typename label_type, typename scalar_type>
+ void run_test (
+ )
+ {
+ print_spinner();
+ typedef matrix<scalar_type,2,1> sample_type;
+
+ std::vector<sample_type> samples, norm_samples;
+ std::vector<label_type> labels;
+
+ // First, get our labeled set of training data
+ generate_data(samples, labels);
+
+ typedef one_vs_one_trainer<any_trainer<sample_type,scalar_type>,label_type > ovo_trainer;
+
+
+ ovo_trainer trainer;
+
+ typedef histogram_intersection_kernel<sample_type> hist_kernel;
+ typedef radial_basis_kernel<sample_type> rbf_kernel;
+
+ // make the binary trainers and set some parameters
+ krr_trainer<rbf_kernel> rbf_trainer;
+ svm_nu_trainer<hist_kernel> hist_trainer;
+ rbf_trainer.set_kernel(rbf_kernel(0.1));
+
+
+ trainer.set_trainer(rbf_trainer);
+ trainer.set_trainer(hist_trainer, 1, 2);
+
+ randomize_samples(samples, labels);
+ matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
+
+ print_spinner();
+
+ matrix<scalar_type> ans(3,3);
+ ans = 60, 0, 0,
+ 0, 70, 0,
+ 0, 0, 80;
+
+ DLIB_TEST_MSG(ans == res, "res: \n" << res);
+
+ // test using a normalized_function with a one_vs_one_decision_function
+ {
+ trainer.set_trainer(hist_trainer, 1, 2);
+ vector_normalizer<sample_type> normalizer;
+ normalizer.train(samples);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ norm_samples.push_back(normalizer(samples[i]));
+ normalized_function<one_vs_one_decision_function<ovo_trainer> > ndf;
+ ndf.function = trainer.train(norm_samples, labels);
+ ndf.normalizer = normalizer;
+ DLIB_TEST(ndf(samples[0]) == labels[0]);
+ DLIB_TEST(ndf(samples[40]) == labels[40]);
+ DLIB_TEST(ndf(samples[90]) == labels[90]);
+ DLIB_TEST(ndf(samples[120]) == labels[120]);
+ trainer.set_trainer(hist_trainer, 1, 2);
+ print_spinner();
+ }
+
+
+
+
+ one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
+
+ DLIB_TEST(df.number_of_classes() == 3);
+
+ DLIB_TEST(df(samples[0]) == labels[0]);
+ DLIB_TEST(df(samples[90]) == labels[90]);
+
+
+ one_vs_one_decision_function<ovo_trainer,
+ decision_function<hist_kernel>, // This is the output of the hist_trainer
+ decision_function<rbf_kernel> // This is the output of the rbf_trainer
+ > df2, df3;
+
+
+ df2 = df;
+ ofstream fout("df.dat", ios::binary);
+ serialize(df2, fout);
+ fout.close();
+
+ // load the function back in from disk and store it in df3.
+ ifstream fin("df.dat", ios::binary);
+ deserialize(df3, fin);
+
+
+ DLIB_TEST(df3(samples[0]) == labels[0]);
+ DLIB_TEST(df3(samples[90]) == labels[90]);
+ res = test_multiclass_decision_function(df3, samples, labels);
+
+ DLIB_TEST(res == ans);
+
+
+ }
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "run_test<double,double>()";
+ run_test<double,double>();
+
+ dlog << LINFO << "run_test<int,double>()";
+ run_test<int,double>();
+
+ dlog << LINFO << "run_test<double,float>()";
+ run_test<double,float>();
+
+ dlog << LINFO << "run_test<int,float>()";
+ run_test<int,float>();
+ }
+ };
+
+ test_one_vs_one_trainer a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/opt_qp_solver.cpp b/ml/dlib/dlib/test/opt_qp_solver.cpp
new file mode 100644
index 000000000..ffa386323
--- /dev/null
+++ b/ml/dlib/dlib/test/opt_qp_solver.cpp
@@ -0,0 +1,813 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <dlib/statistics.h>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.opt_qp_solver");
+
+// ----------------------------------------------------------------------------------------
+
+ class test_smo
+ {
+ public:
+ double penalty;
+ double C;
+
+ double operator() (
+ const matrix<double,0,1>& alpha
+ ) const
+ {
+
+ double obj = 0.5* trans(alpha)*Q*alpha - trans(alpha)*b;
+ double c1 = pow(sum(alpha)-C,2);
+ double c2 = sum(pow(pointwise_multiply(alpha, alpha<0), 2));
+
+ obj += penalty*(c1 + c2);
+
+ return obj;
+ }
+
+ matrix<double> Q, b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class test_smo_derivative
+ {
+ public:
+ double penalty;
+ double C;
+
+ matrix<double,0,1> operator() (
+ const matrix<double,0,1>& alpha
+ ) const
+ {
+
+ matrix<double,0,1> obj = Q*alpha - b;
+ matrix<double,0,1> c1 = uniform_matrix<double>(alpha.size(),1, 2*(sum(alpha)-C));
+ matrix<double,0,1> c2 = 2*pointwise_multiply(alpha, alpha<0);
+
+ return obj + penalty*(c1 + c2);
+ }
+
+ matrix<double> Q, b;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ double compute_objective_value (
+ const matrix<double,0,1>& w,
+ const matrix<double>& A,
+ const matrix<double,0,1>& b,
+ const double C
+ )
+ {
+ return 0.5*dot(w,w) + C*max(trans(A)*w + b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_qp4_test1()
+ {
+ matrix<double> A(3,2);
+ A = 1,2,
+ -3,1,
+ 6,7;
+
+ matrix<double,0,1> b(2);
+ b = 1,
+ 2;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(2), true_alpha(2), d(3), lambda;
+ alpha = C/2, C/2;
+ d = 0;
+
+ solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "w: " << trans(w);
+
+ dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
+ w = 0;
+ dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ true_alpha = 0, 2;
+ dlog << LINFO << "true alpha: "<< trans(true_alpha);
+
+ dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
+ DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_qp4_test2()
+ {
+ matrix<double> A(3,2);
+ A = 1,2,
+ 3,-1,
+ 6,7;
+
+ matrix<double,0,1> b(2);
+ b = 1,
+ 2;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(2), true_alpha(2), d(3), lambda;
+ alpha = C/2, C/2;
+ d = 0;
+
+ solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "w: " << trans(w);
+
+ dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
+ w = 0, 0.25, 0;
+ dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ true_alpha = 0.43750, 1.56250;
+ dlog << LINFO << "true alpha: "<< trans(true_alpha);
+
+ dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
+ DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_qp4_test3()
+ {
+ matrix<double> A(3,2);
+ A = 1,2,
+ -3,-1,
+ 6,7;
+
+ matrix<double,0,1> b(2);
+ b = 1,
+ 2;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(2), true_alpha(2), d(3), lambda;
+ alpha = C/2, C/2;
+ d = 0;
+
+ solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "w: " << trans(w);
+
+ dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
+ w = 0, 2, 0;
+ dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ true_alpha = 0, 2;
+ dlog << LINFO << "true alpha: "<< trans(true_alpha);
+
+ dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
+ DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_qp4_test5()
+ {
+ matrix<double> A(3,3);
+ A = 1,2,4,
+ 3,1,6,
+ 6,7,-2;
+
+ matrix<double,0,1> b(3);
+ b = 1,
+ 2,
+ 3;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(3), true_alpha(3), d(3), lambda;
+ alpha = C/2, C/2, 0;
+ d = 0;
+
+ solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "w: " << trans(w);
+
+ dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
+ w = 0, 0, 0.11111111111111111111;
+ dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ true_alpha = 0, 0.432098765432099, 1.567901234567901;
+ dlog << LINFO << "true alpha: "<< trans(true_alpha);
+
+ dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
+ DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_qp4_test4()
+ {
+ matrix<double> A(3,2);
+ A = 1,2,
+ 3,1,
+ 6,7;
+
+ matrix<double,0,1> b(2);
+ b = 1,
+ 2;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(2), d(3), lambda;
+ alpha = C/2, C/2;
+
+ solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "w: " << trans(w);
+
+ const double computed_obj = compute_objective_value(w,A,b,C);
+ w = 0, 0, 0;
+ const double true_obj = compute_objective_value(w,A,b,C);
+ dlog << LINFO << "computed obj: "<< computed_obj;
+ dlog << LINFO << "with true w obj: "<< true_obj;
+
+ DLIB_TEST_MSG(abs(computed_obj - true_obj) < 1e-8, abs(computed_obj - true_obj));
+ }
+
+ void test_qp4_test6()
+ {
+ matrix<double> A(3,3);
+ A = 1,2,4,
+ 3,1,6,
+ 6,7,-2;
+
+ matrix<double,0,1> b(3);
+ b = -1,
+ -2,
+ -3;
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(3), d(3), lambda;
+ d = 0;
+ alpha = C/2, C/2, 0;
+
+ unsigned long iters = solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 3000);
+ matrix<double,0,1> w = lowerbound(-A*alpha, 0);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ dlog << LINFO << "lambda: " << trans(lambda);
+ dlog << LINFO << "w: " << trans(w);
+
+
+ const double computed_obj = compute_objective_value(w,A,b,C);
+ w = 0, 0, 0;
+ const double true_obj = compute_objective_value(w,A,b,C);
+ dlog << LINFO << "computed obj: "<< computed_obj;
+ dlog << LINFO << "with true w obj: "<< true_obj;
+
+ DLIB_TEST_MSG(abs(computed_obj - true_obj) < 1e-8,
+ "computed_obj: "<< computed_obj << " true_obj: " << true_obj << " delta: "<< abs(computed_obj - true_obj)
+ << " iters: " << iters
+ << "\n alpha: " << trans(alpha)
+ << " lambda: " << trans(lambda)
+ );
+ }
+
+ void test_qp4_test7()
+ {
+ matrix<double> A(3,3);
+ A = -1,2,4,
+ -3,1,6,
+ -6,7,-2;
+
+ matrix<double,0,1> b(3);
+ b = -1,
+ -2,
+ 3;
+
+ matrix<double> Q(3,3);
+ Q = 4,-5,6,
+ 1,-4,2,
+ -9,-4,5;
+ Q = Q*trans(Q);
+
+ const double C = 2;
+
+ matrix<double,0,1> alpha(3), true_alpha(3), d(3), lambda;
+ alpha = C/2, C/2, 0;
+ d = 0;
+
+ solve_qp4_using_smo(A, Q, b, d, alpha, lambda, 1e-9, 800);
+
+ dlog << LINFO << "*******************************************************";
+
+ dlog << LINFO << "alpha: " << trans(alpha);
+ true_alpha = 0, 2, 0;
+ dlog << LINFO << "true alpha: "<< trans(true_alpha);
+
+ dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
+ DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_solve_qp4_using_smo()
+ {
+ test_qp4_test1();
+ test_qp4_test2();
+ test_qp4_test3();
+ test_qp4_test4();
+ test_qp4_test5();
+ test_qp4_test6();
+ test_qp4_test7();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double max_distance_to(
+ const std::vector<matrix<double,0,1>>& a,
+ const std::vector<matrix<double,0,1>>& b
+ )
+ {
+ double best_dist = 0;
+ for (auto&& aa : a)
+ {
+ for (auto&& bb : b)
+ {
+ double dist = length(aa-bb);
+ if (dist > best_dist)
+ best_dist = dist;
+ }
+ }
+ return best_dist;
+ }
+
+ double min_distance_to(
+ const std::vector<matrix<double,0,1>>& a,
+ const std::vector<matrix<double,0,1>>& b
+ )
+ {
+ double best_dist = std::numeric_limits<double>::infinity();
+ for (auto&& aa : a)
+ {
+ for (auto&& bb : b)
+ {
+ double dist = length(aa-bb);
+ if (dist < best_dist)
+ best_dist = dist;
+ }
+ }
+ return best_dist;
+ }
+
+ double min_distance_to(
+ const std::vector<matrix<double,0,1>>& s,
+ const matrix<double,0,1>& v
+ )
+ {
+ double best_dist = std::numeric_limits<double>::infinity();
+ for (auto& x : s)
+ {
+ double dist = length(v-x);
+ if (dist < best_dist)
+ {
+ best_dist = dist;
+ }
+ }
+ return best_dist;
+ }
+
+ double max_distance_to(
+ const std::vector<matrix<double,0,1>>& s,
+ const matrix<double,0,1>& v
+ )
+ {
+ double best_dist = 0;
+ for (auto& x : s)
+ {
+ double dist = length(v-x);
+ if (dist > best_dist)
+ {
+ best_dist = dist;
+ }
+ }
+ return best_dist;
+ }
+
+ void test_find_gap_between_convex_hulls()
+ {
+ print_spinner();
+ std::vector<matrix<double,0,1>> set1, set2;
+
+ const double dist_thresh = 5.47723;
+
+ // generate two groups of points that are pairwise close within each set and
+ // pairwise far apart between each set, according to dist_thresh distance threshold.
+ bool which = true;
+ for (size_t i = 0; i < 10000; ++i)
+ {
+ matrix<double,0,1> v = gaussian_randm(15,1,i);
+ const auto min_dist1 = min_distance_to(set1,v);
+ const auto min_dist2 = min_distance_to(set2,v);
+ const auto max_dist1 = max_distance_to(set1,v);
+ const auto max_dist2 = max_distance_to(set2,v);
+ if (which)
+ {
+ if ((set1.size()==0 || max_dist1 < dist_thresh) && min_dist2 > dist_thresh )
+ {
+ set1.push_back(v);
+ which = !which;
+ }
+ }
+ else
+ {
+ if ((set2.size()==0 || max_dist2 < dist_thresh) && min_dist1 > dist_thresh)
+ {
+ set2.push_back(v);
+ which = !which;
+ }
+ }
+ }
+
+ dlog << LINFO << "set1.size(): "<< set1.size();
+ dlog << LINFO << "set2.size(): "<< set2.size();
+
+
+ // make sure we generated the points correctly.
+ dlog << LINFO << "dist_thresh: "<< dist_thresh;
+ dlog << LINFO << "max distance between set1 and set1: "<< max_distance_to(set1,set1);
+ dlog << LINFO << "max distance between set2 and set2: "<< max_distance_to(set2,set2);
+ DLIB_TEST(max_distance_to(set1,set1) < dist_thresh);
+ DLIB_TEST(max_distance_to(set2,set2) < dist_thresh);
+ dlog << LINFO << "min distance between set2 and set1: "<< min_distance_to(set2,set1);
+ DLIB_TEST(min_distance_to(set2,set1) > dist_thresh);
+
+
+ // It is slightly counterintuitive but true that points picked using the above procedure
+ // will have elements of their convex hulls that are much closer together than
+ // dist_thresh, even though none of the vertices of the hulls are that close
+ // together. This is especially true in high dimensions. So let's use this to
+ // test find_gap_between_convex_hulls(). It should be able to find a pair of
+ // points in the convex hulls of our sets that are a lot closer together than
+ // dist_thresh.
+
+ // First we need to convert the vectors to matrices.
+ matrix<double> A, B;
+ A.set_size(set1[0].size(), set1.size());
+ B.set_size(set2[0].size(), set2.size());
+ for (long c = 0; c < A.nc(); ++c)
+ set_colm(A,c) = set1[c];
+ for (long c = 0; c < B.nc(); ++c)
+ set_colm(B,c) = set2[c];
+
+ matrix<double,0,1> c1, c2;
+ find_gap_between_convex_hulls(A, B, c1, c2, 0.0001);
+ // make sure c1 and c2 are convex combinations.
+ DLIB_TEST(abs(sum(c1)-1) < 1e-8);
+ DLIB_TEST(abs(sum(c2)-1) < 1e-8);
+ DLIB_TEST(min(c1) >= 0);
+ DLIB_TEST(min(c2) >= 0);
+
+ // now test that the points found are close together.
+ dlog << LINFO << "dist: "<< length(A*c1 - B*c2);
+ DLIB_TEST(length(A*c1 - B*c2) < 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_solve_qp_box_constrained_blockdiag()
+ {
+ dlib::rand rnd;
+ for (int iter = 0; iter < 50; ++iter)
+ {
+ print_spinner();
+
+ matrix<double> Q1, Q2;
+ matrix<double,0,1> b1, b2;
+
+ Q1 = randm(4,4,rnd); Q1 = Q1*trans(Q1);
+ Q2 = randm(4,4,rnd); Q2 = Q2*trans(Q2);
+ b1 = gaussian_randm(4,1, iter*2+0);
+ b2 = gaussian_randm(4,1, iter*2+1);
+
+ std::map<unordered_pair<size_t>, matrix<double,0,1>> offdiag;
+
+ if (rnd.get_random_gaussian() > 0)
+ offdiag[make_unordered_pair(0,0)] = randm(4,1,rnd);
+ if (rnd.get_random_gaussian() > 0)
+ offdiag[make_unordered_pair(1,0)] = randm(4,1,rnd);
+ if (rnd.get_random_gaussian() > 0)
+ offdiag[make_unordered_pair(1,1)] = randm(4,1,rnd);
+
+ std::vector<matrix<double>> Q_blocks = {Q1, Q2};
+ std::vector<matrix<double,0,1>> bs = {b1, b2};
+
+
+ // make the single big Q and b
+ matrix<double> Q = join_cols(join_rows(Q1, zeros_matrix(Q1)),
+ join_rows(zeros_matrix(Q2),Q2));
+ matrix<double,0,1> b = join_cols(b1,b2);
+ for (auto& p : offdiag)
+ {
+ long r = p.first.first;
+ long c = p.first.second;
+ set_subm(Q, 4*r,4*c, 4,4) += diagm(p.second);
+ if (c != r)
+ set_subm(Q, 4*c,4*r, 4,4) += diagm(p.second);
+ }
+
+
+ matrix<double,0,1> alpha = zeros_matrix(b);
+ matrix<double,0,1> lower = -10000*ones_matrix(b);
+ matrix<double,0,1> upper = 10000*ones_matrix(b);
+
+ auto iters = solve_qp_box_constrained(Q, b, alpha, lower, upper, 1e-9, 10000);
+ dlog << LINFO << "iters: "<< iters;
+ dlog << LINFO << "alpha: " << trans(alpha);
+
+ dlog << LINFO;
+
+ std::vector<matrix<double,0,1>> alphas(2);
+ alphas[0] = zeros_matrix<double>(4,1); alphas[1] = zeros_matrix<double>(4,1);
+
+ lower = -10000*ones_matrix(alphas[0]);
+ upper = 10000*ones_matrix(alphas[0]);
+ std::vector<matrix<double,0,1>> lowers = {lower,lower}, uppers = {upper, upper};
+ auto iters2 = solve_qp_box_constrained_blockdiag(Q_blocks, bs, offdiag, alphas, lowers, uppers, 1e-9, 10000);
+ dlog << LINFO << "iters2: "<< iters2;
+ dlog << LINFO << "alpha: " << trans(join_cols(alphas[0],alphas[1]));
+
+ dlog << LINFO << "obj1: "<< 0.5*trans(alpha)*Q*alpha + trans(b)*alpha;
+ dlog << LINFO << "obj2: "<< 0.5*trans(join_cols(alphas[0],alphas[1]))*Q*join_cols(alphas[0],alphas[1]) + trans(b)*join_cols(alphas[0],alphas[1]);
+ dlog << LINFO << "obj1-obj2: "<<(0.5*trans(alpha)*Q*alpha + trans(b)*alpha) - (0.5*trans(join_cols(alphas[0],alphas[1]))*Q*join_cols(alphas[0],alphas[1]) + trans(b)*join_cols(alphas[0],alphas[1]));
+
+ DLIB_TEST_MSG(max(abs(alpha - join_cols(alphas[0], alphas[1]))) < 1e-6, max(abs(alpha - join_cols(alphas[0], alphas[1]))));
+
+ DLIB_TEST(iters == iters2);
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_solve_qp_box_constrained_blockdiag_compact(dlib::rand& rnd, double percent_off_diag_present)
+ {
+ print_spinner();
+
+ dlog << LINFO << "test_solve_qp_box_constrained_blockdiag_compact(), percent_off_diag_present==" << percent_off_diag_present;
+
+ std::map<unordered_pair<size_t>, matrix<double,0,1>> offdiag;
+ std::vector<matrix<double>> Q_blocks;
+ std::vector<matrix<double,0,1>> bs;
+
+ const long num_blocks = 20;
+ const long dims = 4;
+ const double lambda = 10;
+ for (long i = 0; i < num_blocks; ++i)
+ {
+ matrix<double> Q1;
+ matrix<double,0,1> b1;
+ Q1 = randm(dims,dims,rnd); Q1 = Q1*trans(Q1);
+ b1 = gaussian_randm(dims,1, i);
+
+ Q_blocks.push_back(Q1);
+ bs.push_back(b1);
+
+ // test with some graph regularization terms
+ for (long j = 0; j < num_blocks; ++j)
+ {
+ if (rnd.get_random_double() < percent_off_diag_present)
+ {
+ if (i==j)
+ offdiag[make_unordered_pair(i,j)] = (num_blocks-1)*lambda*rnd.get_random_double()*ones_matrix<double>(dims,1);
+ else
+ offdiag[make_unordered_pair(i,j)] = -lambda*rnd.get_random_double()*ones_matrix<double>(dims,1);
+ }
+ }
+ }
+
+ // build out the dense version of the QP so we can test it against the dense solver.
+ matrix<double> Q(num_blocks*dims, num_blocks*dims);
+ Q = 0;
+ matrix<double,0,1> b(num_blocks*dims);
+ for (long i = 0; i < num_blocks; ++i)
+ {
+ set_subm(Q,i*dims,i*dims,dims,dims) = Q_blocks[i];
+ set_subm(b,i*dims,0,dims,1) = bs[i];
+ }
+ for (auto& p : offdiag)
+ {
+ long r = p.first.first;
+ long c = p.first.second;
+ set_subm(Q, dims*r,dims*c, dims,dims) += diagm(p.second);
+ if (c != r)
+ set_subm(Q, dims*c,dims*r, dims,dims) += diagm(p.second);
+ }
+
+
+
+ matrix<double,0,1> alpha = zeros_matrix<double>(dims*num_blocks,1);
+ matrix<double,0,1> lower = -10000*ones_matrix<double>(dims*num_blocks,1);
+ matrix<double,0,1> upper = 10000*ones_matrix<double>(dims*num_blocks,1);
+
+ auto iters = solve_qp_box_constrained(Q, b, alpha, lower, upper, 1e-9, 20000);
+ dlog << LINFO << "iters: "<< iters;
+
+
+ matrix<double,0,1> init_alpha = zeros_matrix(bs[0]);
+ lower = -10000*ones_matrix(bs[0]);
+ upper = 10000*ones_matrix(bs[0]);
+
+ std::vector<matrix<double,0,1>> alphas(num_blocks, init_alpha);
+ std::vector<matrix<double,0,1>> lowers(num_blocks, lower);
+ std::vector<matrix<double,0,1>> uppers(num_blocks, upper);
+
+ auto iters2 = solve_qp_box_constrained_blockdiag(Q_blocks, bs, offdiag, alphas, lowers, uppers, 1e-9, 20000);
+ dlog << LINFO << "iters2: "<< iters2;
+
+
+ const matrix<double> refalpha = reshape(alpha, num_blocks, dims);
+
+ // now make sure the two solvers agree on the outputs.
+ for (long r = 0; r < num_blocks; ++r)
+ {
+ for (long c = 0; c < dims; ++c)
+ {
+ DLIB_TEST_MSG(std::abs(refalpha(r,c) - alphas[r](c)) < 1e-6, std::abs(refalpha(r,c) - alphas[r](c)));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class opt_qp_solver_tester : public tester
+ {
+ /*
+ The idea here is just to solve the same problem with two different
+ methods and check that they basically agree. The SMO solver should be
+ very accurate but for this problem the BFGS solver is relatively
+ inaccurate. So this test is really just a sanity check on the SMO
+ solver.
+ */
+ public:
+ opt_qp_solver_tester (
+ ) :
+ tester ("test_opt_qp_solver",
+ "Runs tests on the solve_qp_using_smo component.")
+ {
+ thetime = time(0);
+ }
+
+ time_t thetime;
+ dlib::rand rnd;
+
+ void perform_test(
+ )
+ {
+ print_spinner();
+ test_solve_qp4_using_smo();
+ print_spinner();
+
+ ++thetime;
+ //dlog << LINFO << "time seed: " << thetime;
+ //rnd.set_seed(cast_to_string(thetime));
+
+ running_stats<double> rs;
+
+ for (int i = 0; i < 40; ++i)
+ {
+ for (long dims = 1; dims < 6; ++dims)
+ {
+ rs.add(do_the_test(dims, 1.0));
+ }
+ }
+
+ for (int i = 0; i < 40; ++i)
+ {
+ for (long dims = 1; dims < 6; ++dims)
+ {
+ rs.add(do_the_test(dims, 5.0));
+ }
+ }
+
+ dlog << LINFO << "disagreement mean: " << rs.mean();
+ dlog << LINFO << "disagreement stddev: " << rs.stddev();
+ DLIB_TEST_MSG(rs.mean() < 0.001, rs.mean());
+ DLIB_TEST_MSG(rs.stddev() < 0.001, rs.stddev());
+
+
+ test_find_gap_between_convex_hulls();
+ test_solve_qp_box_constrained_blockdiag();
+
+ // try a range of off diagonal sparseness. We do this to make sure we exercise both
+ // the compact and sparse code paths within the solver.
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.001);
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.01);
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.04);
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.10);
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.50);
+ test_solve_qp_box_constrained_blockdiag_compact(rnd, 1.00);
+ }
+
+ double do_the_test (
+ const long dims,
+ double C
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "dims: " << dims;
+ dlog << LINFO << "testing with C == " << C;
+ test_smo test;
+
+ test.Q = randm(dims, dims, rnd);
+ test.Q = trans(test.Q)*test.Q;
+ test.b = randm(dims,1, rnd);
+ test.C = C;
+
+ test_smo_derivative der;
+ der.Q = test.Q;
+ der.b = test.b;
+ der.C = test.C;
+
+
+ matrix<double,0,1> x(dims), alpha(dims);
+
+
+ test.penalty = 20000;
+ der.penalty = test.penalty;
+
+ alpha = C/alpha.size();
+ x = alpha;
+
+ const unsigned long max_iter = 400000;
+ solve_qp_using_smo(test.Q, test.b, alpha, 0.00000001, max_iter);
+ DLIB_TEST_MSG(abs(sum(alpha) - C) < 1e-13, abs(sum(alpha) - C) );
+ dlog << LTRACE << "alpha: " << alpha;
+ dlog << LINFO << "SMO: true objective: "<< 0.5*trans(alpha)*test.Q*alpha - trans(alpha)*test.b;
+
+
+ double obj = find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(1e-13, 5000),
+ test,
+ der,
+ x,
+ -10);
+
+
+ dlog << LINFO << "BFGS: objective: " << obj;
+ dlog << LINFO << "BFGS: true objective: "<< 0.5*trans(x)*test.Q*x - trans(x)*test.b;
+ dlog << LINFO << "sum(x): " << sum(x);
+ dlog << LINFO << x;
+
+ double disagreement = max(abs(x-alpha));
+ dlog << LINFO << "Disagreement: " << disagreement;
+ return disagreement;
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/optimization.cpp b/ml/dlib/dlib/test/optimization.cpp
new file mode 100644
index 000000000..b47449abe
--- /dev/null
+++ b/ml/dlib/dlib/test/optimization.cpp
@@ -0,0 +1,1231 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include "optimization_test_functions.h"
+#include <dlib/optimization.h>
+#include <dlib/statistics.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.optimization");
+
+// ----------------------------------------------------------------------------------------
+
+ bool approx_equal (
+ double a,
+ double b
+ )
+ {
+ return std::abs(a - b) < 100*std::numeric_limits<double>::epsilon();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long total_count = 0;
+
+
+ template <typename T>
+ double apq ( const T& x)
+ {
+ DLIB_ASSERT(x.nr() > 1 && x.nc() == 1,"");
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ double temp = 0;
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ temp += (r+1)*x(r)*x(r);
+ }
+
+ ++total_count;
+
+ return temp + 1/100.0*(x(0) + x(x.nr()-1))*(x(0) + x(x.nr()-1));
+ }
+
+ template <typename T>
+ T der_apq ( const T& x)
+ {
+ DLIB_ASSERT(x.nr() > 1 && x.nc() == 1,"");
+ COMPILE_TIME_ASSERT(is_matrix<T>::value);
+ T temp(x.nr());
+ for (long r = 0; r < x.nr(); ++r)
+ {
+ temp(r) = 2*(r+1)*x(r) ;
+ }
+
+ temp(0) += 1/50.0*(x(0) + x(x.nr()-1));
+ temp(x.nr()-1) += 1/50.0*(x(0) + x(x.nr()-1));
+
+ ++total_count;
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Rosenbrock's function. minimum at (1,1)
+ double rosen ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ return 100*pow(x(1) - x(0)*x(0),2) + pow(1 - x(0),2);
+ }
+
+ matrix<double,2,1> der_rosen ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ matrix<double,2,1> res;
+ res(0) = -400*x(0)*(x(1)-x(0)*x(0)) - 2*(1-x(0));
+ res(1) = 200*(x(1)-x(0)*x(0));
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // negative of Rosenbrock's function. minimum at (1,1)
+ double neg_rosen ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ return -(100*pow(x(1) - x(0)*x(0),2) + pow(1 - x(0),2));
+ }
+
+ matrix<double,2,1> der_neg_rosen ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ matrix<double,2,1> res;
+ res(0) = -400*x(0)*(x(1)-x(0)*x(0)) - 2*(1-x(0));
+ res(1) = 200*(x(1)-x(0)*x(0));
+ return -res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double simple ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ return 10*x(0)*x(0) + x(1)*x(1);
+ }
+
+ matrix<double,2,1> der_simple ( const matrix<double,2,1>& x)
+ {
+ ++total_count;
+ matrix<double,2,1> res;
+ res(0) = 20*x(0);
+ res(1) = 2*x(1);
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ double powell ( const matrix<double,4,1>& x)
+ {
+ ++total_count;
+ return pow(x(0) + 10*x(1),2) +
+ pow(std::sqrt(5.0)*(x(2) - x(3)),2) +
+ pow((x(1) - 2*x(2))*(x(1) - 2*x(2)),2) +
+ pow(std::sqrt(10.0)*(x(0) - x(3))*(x(0) - x(3)),2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+// a simple function with a minimum at zero
+ double single_variable_function ( double x)
+ {
+ ++total_count;
+ return 3*x*x + 5;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void test_apq (
+ const matrix<double,0,1> p
+ )
+ {
+ typedef matrix<double,0,1> T;
+ const double eps = 1e-12;
+ const double minf = -10;
+ matrix<double,0,1> x(p.nr()), opt(p.nr());
+ set_all_elements(opt, 0);
+ double val = 0;
+
+ if (p.size() < 20)
+ dlog << LINFO << "testing with apq and the start point: " << trans(p);
+ else
+ dlog << LINFO << "testing with apq and a big vector with " << p.size() << " components.";
+
+ // don't use bfgs on really large vectors
+ if (p.size() < 20)
+ {
+ total_count = 0;
+ x = p;
+ val = find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() bgfs: got apq in " << total_count;
+
+ total_count = 0;
+ x = p;
+ find_min(bfgs_search_strategy(),
+ gradient_norm_stop_strategy(),
+ wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ dlog << LINFO << "find_min() bgfs(gn): got apq in " << total_count;
+ }
+
+
+ if (p.size() < 100)
+ {
+ total_count = 0;
+ x = p;
+ val=find_min_bobyqa(wrap_function(apq<T>), x, 2*x.size()+1,
+ uniform_matrix<double>(x.size(),1,-1e100),
+ uniform_matrix<double>(x.size(),1,1e100),
+ (max(abs(x))+1)/10,
+ 1e-6,
+ 10000);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min_bobyqa(): got apq in " << total_count;
+ }
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(10),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() lbgfs-10: got apq in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(1),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() lbgfs-1: got apq in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() cg: got apq in " << total_count;
+
+
+ // don't do approximate derivative tests if the input point is really long
+ if (p.size() < 20)
+ {
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), derivative(wrap_function(apq<T>)), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() bfgs: got apq/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), derivative(wrap_function(apq<T>)), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() cg: got apq/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() bfgs: got apq/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(lbfgs_search_strategy(10),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ dlog << LINFO << "find_min() lbfgs-10: got apq/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ wrap_function(apq<T>), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , apq(x)));
+ dlog << LINFO << "find_min() cg: got apq/noder2 in " << total_count;
+ }
+ }
+
+ void test_powell (
+ const matrix<double,4,1> p
+ )
+ {
+ const double eps = 1e-15;
+ const double minf = -1;
+ matrix<double,4,1> x, opt;
+ opt(0) = 0;
+ opt(1) = 0;
+ opt(2) = 0;
+ opt(3) = 0;
+
+ double val = 0;
+
+ dlog << LINFO << "testing with powell and the start point: " << trans(p);
+
+ /*
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ powell, derivative(powell,1e-8), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-2),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() bfgs: got powell/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ powell, derivative(powell,1e-9), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-2),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() cg: got powell/noder in " << total_count;
+ */
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ powell, x, minf, 1e-10);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() bfgs: got powell/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(lbfgs_search_strategy(4),
+ objective_delta_stop_strategy(eps),
+ powell, x, minf, 1e-10);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() lbfgs-4: got powell/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(lbfgs_search_strategy(4),
+ gradient_norm_stop_strategy(),
+ powell, x, minf, 1e-10);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() lbfgs-4(gn): got powell/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ powell, x, minf, 1e-10);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min() cg: got powell/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_bobyqa(powell, x, 2*x.size()+1,
+ uniform_matrix<double>(x.size(),1,-1e100),
+ uniform_matrix<double>(x.size(),1,1e100),
+ (max(abs(x))+1)/10,
+ 1e-8,
+ 10000);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-3),opt-x);
+ DLIB_TEST(approx_equal(val , powell(x)));
+ dlog << LINFO << "find_min_bobyqa(): got powell in " << total_count;
+
+ }
+
+
+
+ void test_simple (
+ const matrix<double,2,1> p
+ )
+ {
+ const double eps = 1e-12;
+ const double minf = -10000;
+ matrix<double,2,1> x, opt;
+ opt(0) = 0;
+ opt(1) = 0;
+ double val = 0;
+
+ dlog << LINFO << "testing with simple and the start point: " << trans(p);
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, der_simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() bfgs: got simple in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ gradient_norm_stop_strategy(),
+ simple, der_simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() bfgs(gn): got simple in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(3),
+ objective_delta_stop_strategy(eps),
+ simple, der_simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() lbfgs-3: got simple in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, der_simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() cg: got simple in " << total_count;
+
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, derivative(simple), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() bfgs: got simple/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(8),
+ objective_delta_stop_strategy(eps),
+ simple, derivative(simple), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() lbfgs-8: got simple/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, derivative(simple), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() cg: got simple/noder in " << total_count;
+
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() bfgs: got simple/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(lbfgs_search_strategy(6),
+ objective_delta_stop_strategy(eps),
+ simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() lbfgs-6: got simple/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ simple, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min() cg: got simple/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_bobyqa(simple, x, 2*x.size()+1,
+ uniform_matrix<double>(x.size(),1,-1e100),
+ uniform_matrix<double>(x.size(),1,1e100),
+ (max(abs(x))+1)/10,
+ 1e-6,
+ 10000);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , simple(x)));
+ dlog << LINFO << "find_min_bobyqa(): got simple in " << total_count;
+
+ }
+
+
+ void test_rosen (
+ const matrix<double,2,1> p
+ )
+ {
+ const double eps = 1e-15;
+ const double minf = -10;
+ matrix<double,2,1> x, opt;
+ opt(0) = 1;
+ opt(1) = 1;
+
+ double val = 0;
+
+ dlog << LINFO << "testing with rosen and the start point: " << trans(p);
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ rosen, der_rosen, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() bfgs: got rosen in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ gradient_norm_stop_strategy(),
+ rosen, der_rosen, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() bfgs(gn): got rosen in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(20),
+ objective_delta_stop_strategy(eps),
+ rosen, der_rosen, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() lbfgs-20: got rosen in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ rosen, der_rosen, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() cg: got rosen in " << total_count;
+
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ rosen, derivative(rosen,1e-5), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() bfgs: got rosen/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(lbfgs_search_strategy(5),
+ objective_delta_stop_strategy(eps),
+ rosen, derivative(rosen,1e-5), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() lbfgs-5: got rosen/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ rosen, derivative(rosen,1e-5), x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() cg: got rosen/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_min_using_approximate_derivatives(cg_search_strategy(),
+ objective_delta_stop_strategy(eps),
+ rosen, x, minf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min() cg: got rosen/noder2 in " << total_count;
+
+
+ if (max(abs(p)) < 1000)
+ {
+ total_count = 0;
+ x = p;
+ val=find_min_bobyqa(rosen, x, 2*x.size()+1,
+ uniform_matrix<double>(x.size(),1,-1e100),
+ uniform_matrix<double>(x.size(),1,1e100),
+ (max(abs(x))+1)/10,
+ 1e-6,
+ 10000);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , rosen(x)));
+ dlog << LINFO << "find_min_bobyqa(): got rosen in " << total_count;
+ }
+ }
+
+
+ void test_neg_rosen (
+ const matrix<double,2,1> p
+ )
+ {
+ const double eps = 1e-15;
+ const double maxf = 10;
+ matrix<double,2,1> x, opt;
+ opt(0) = 1;
+ opt(1) = 1;
+
+ double val = 0;
+
+ dlog << LINFO << "testing with neg_rosen and the start point: " << trans(p);
+
+ total_count = 0;
+ x = p;
+ val=find_max(
+ bfgs_search_strategy(),
+ objective_delta_stop_strategy(eps), neg_rosen, der_neg_rosen, x, maxf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , neg_rosen(x)));
+ dlog << LINFO << "find_max() bfgs: got neg_rosen in " << total_count;
+
+ total_count = 0;
+ x = p;
+ val=find_max(
+ lbfgs_search_strategy(5),
+ objective_delta_stop_strategy(eps), neg_rosen, der_neg_rosen, x, maxf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , neg_rosen(x)));
+ dlog << LINFO << "find_max() lbfgs-5: got neg_rosen in " << total_count;
+
+ total_count = 0;
+ x = p;
+ val=find_max(
+ lbfgs_search_strategy(5),
+ objective_delta_stop_strategy(eps), neg_rosen, derivative(neg_rosen), x, maxf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , neg_rosen(x)));
+ dlog << LINFO << "find_max() lbfgs-5: got neg_rosen/noder in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_max_using_approximate_derivatives(
+ cg_search_strategy(),
+ objective_delta_stop_strategy(eps), neg_rosen, x, maxf);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
+ DLIB_TEST(approx_equal(val , neg_rosen(x)));
+ dlog << LINFO << "find_max() cg: got neg_rosen/noder2 in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ val=find_max_bobyqa(neg_rosen, x, 2*x.size()+1,
+ uniform_matrix<double>(x.size(),1,-1e100),
+ uniform_matrix<double>(x.size(),1,1e100),
+ (max(abs(x))+1)/10,
+ 1e-6,
+ 10000);
+ DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
+ DLIB_TEST(approx_equal(val , neg_rosen(x)));
+ dlog << LINFO << "find_max_bobyqa(): got neg_rosen in " << total_count;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_single_variable_function (
+ const double p
+ )
+ {
+ const double eps = 1e-7;
+
+
+ dlog << LINFO << "testing with single_variable_function and the start point: " << p;
+ double out, x;
+
+ total_count = 0;
+ x = p;
+ out = find_min_single_variable(single_variable_function, x, -1e100, 1e100, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5);
+ DLIB_TEST_MSG(std::abs(x) < 1e-6, x);
+ dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count;
+
+
+ total_count = 0;
+ x = p;
+ out = -find_max_single_variable(negate_function(single_variable_function), x, -1e100, 1e100, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5);
+ DLIB_TEST_MSG(std::abs(x) < 1e-6, x);
+ dlog << LINFO << "find_max_single_variable(): got single_variable_function in " << total_count;
+
+
+ if (p > 0)
+ {
+ total_count = 0;
+ x = p;
+ out = find_min_single_variable(single_variable_function, x, -1e-4, 1e100, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5);
+ DLIB_TEST_MSG(std::abs(x) < 1e-6, x);
+ dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count;
+
+
+ if (p > 3)
+ {
+ total_count = 0;
+ x = p;
+ out = -find_max_single_variable(negate_function(single_variable_function), x, 3, 1e100, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out - (3*3*3+5)) < 1e-6, out-(3*3*3+5));
+ DLIB_TEST_MSG(std::abs(x-3) < 1e-6, x);
+ dlog << LINFO << "find_max_single_variable(): got single_variable_function in " << total_count;
+ }
+ }
+
+ if (p < 0)
+ {
+ total_count = 0;
+ x = p;
+ out = find_min_single_variable(single_variable_function, x, -1e100, 1e-4, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5);
+ DLIB_TEST_MSG(std::abs(x) < 1e-6, x);
+ dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count;
+
+ if (p < -3)
+ {
+ total_count = 0;
+ x = p;
+ out = find_min_single_variable(single_variable_function, x, -1e100, -3, eps, 1000);
+ DLIB_TEST_MSG(std::abs(out - (3*3*3+5)) < 1e-6, out-(3*3*3+5));
+ DLIB_TEST_MSG(std::abs(x+3) < 1e-6, x);
+ dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count;
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void optimization_test (
+ )
+ /*!
+ ensures
+ - runs tests on the optimization stuff compliance with the specs
+ !*/
+ {
+ matrix<double,0,1> p;
+
+ print_spinner();
+
+ p.set_size(2);
+
+ // test with single_variable_function
+ test_single_variable_function(0);
+ test_single_variable_function(1);
+ test_single_variable_function(-10);
+ test_single_variable_function(-100);
+ test_single_variable_function(900.53);
+
+ // test with the rosen function
+ p(0) = 9;
+ p(1) = -4.9;
+ test_rosen(p);
+ test_neg_rosen(p);
+
+ p(0) = 0;
+ p(1) = 0;
+ test_rosen(p);
+
+ p(0) = 5323;
+ p(1) = 98248;
+ test_rosen(p);
+
+ // test with the simple function
+ p(0) = 1;
+ p(1) = 1;
+ test_simple(p);
+
+ p(0) = 0.5;
+ p(1) = -9;
+ test_simple(p);
+
+ p(0) = 645;
+ p(1) = 839485;
+ test_simple(p);
+
+ print_spinner();
+
+ // test with the apq function
+ p.set_size(5);
+
+ p(0) = 1;
+ p(1) = 1;
+ p(2) = 1;
+ p(3) = 1;
+ p(4) = 1;
+ test_apq(p);
+
+ p(0) = 1;
+ p(1) = 2;
+ p(2) = 3;
+ p(3) = 4;
+ p(4) = 5;
+ test_apq(p);
+
+ p(0) = 1;
+ p(1) = 2;
+ p(2) = -3;
+ p(3) = 4;
+ p(4) = 5;
+ test_apq(p);
+
+ print_spinner();
+
+ p(0) = 1;
+ p(1) = 2324;
+ p(2) = -3;
+ p(3) = 4;
+ p(4) = 534534;
+ test_apq(p);
+
+ p.set_size(10);
+ p(0) = 1;
+ p(1) = 2;
+ p(2) = -3;
+ p(3) = 4;
+ p(4) = 5;
+ p(5) = 1;
+ p(6) = 2;
+ p(7) = -3;
+ p(8) = 4;
+ p(9) = 5;
+ test_apq(p);
+
+ // test apq with a big vector
+ p.set_size(500);
+ dlib::rand rnd;
+ for (long i = 0; i < p.size(); ++i)
+ {
+ p(i) = rnd.get_random_double()*20 - 10;
+ }
+ test_apq(p);
+
+ print_spinner();
+
+ // test with the powell function
+ p.set_size(4);
+
+ p(0) = 3;
+ p(1) = -1;
+ p(2) = 0;
+ p(3) = 1;
+ test_powell(p);
+
+ {
+ matrix<double,2,1> m;
+ m(0) = -0.43;
+ m(1) = 0.919;
+ DLIB_TEST(dlib::equal(der_rosen(m) , derivative(rosen)(m),1e-5));
+
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) -
+ make_line_search_function(derivative(rosen),m,m)(0)) < 1e-5,"");
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) -
+ make_line_search_function(derivative(rosen),m,m)(1)) < 1e-5,"");
+
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) -
+ make_line_search_function(der_rosen,m,m)(0)) < 1e-5,"");
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) -
+ make_line_search_function(der_rosen,m,m)(1)) < 1e-5,"");
+ }
+ {
+ matrix<double,2,1> m;
+ m(0) = 1;
+ m(1) = 2;
+ DLIB_TEST(dlib::equal(der_rosen(m) , derivative(rosen)(m),1e-5));
+
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) -
+ make_line_search_function(derivative(rosen),m,m)(0)) < 1e-5,"");
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) -
+ make_line_search_function(derivative(rosen),m,m)(1)) < 1e-5,"");
+
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) -
+ make_line_search_function(der_rosen,m,m)(0)) < 1e-5,"");
+ DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) -
+ make_line_search_function(der_rosen,m,m)(1)) < 1e-5,"");
+ }
+
+ {
+ matrix<double,2,1> m;
+ m = 1,2;
+ DLIB_TEST(std::abs(neg_rosen(m) - negate_function(rosen)(m) ) < 1e-16);
+ }
+
+ }
+
+ template <typename der_funct, typename T>
+ double unconstrained_gradient_magnitude (
+ const der_funct& grad,
+ const T& x,
+ const T& lower,
+ const T& upper
+ )
+ {
+ T g = grad(x);
+
+ double unorm = 0;
+
+ for (long i = 0; i < g.size(); ++i)
+ {
+ if (lower(i) < x(i) && x(i) < upper(i))
+ unorm += g(i)*g(i);
+ else if (x(i) == lower(i) && g(i) < 0)
+ unorm += g(i)*g(i);
+ else if (x(i) == upper(i) && g(i) > 0)
+ unorm += g(i)*g(i);
+ }
+
+ return unorm;
+ }
+
+ template <typename der_funct, typename T>
+ double unconstrained_gradient_magnitude_neg_funct (
+ const der_funct& grad,
+ const T& x,
+ const T& lower,
+ const T& upper
+ )
+ {
+ T g = grad(x);
+
+ double unorm = 0;
+
+ for (long i = 0; i < g.size(); ++i)
+ {
+ if (lower(i) < x(i) && x(i) < upper(i))
+ unorm += g(i)*g(i);
+ else if (x(i) == lower(i) && g(i) > 0)
+ unorm += g(i)*g(i);
+ else if (x(i) == upper(i) && g(i) < 0)
+ unorm += g(i)*g(i);
+ }
+
+ return unorm;
+ }
+
+ template <typename search_strategy_type>
+ double test_bound_solver_neg_rosen (dlib::rand& rnd, search_strategy_type search_strategy)
+ {
+ using namespace dlib::test_functions;
+ print_spinner();
+ matrix<double,2,1> starting_point, lower, upper, x;
+
+
+ // pick random bounds
+ lower = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1;
+ upper = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1;
+ while (upper(0) < lower(0)) upper(0) = rnd.get_random_gaussian()+1;
+ while (upper(1) < lower(1)) upper(1) = rnd.get_random_gaussian()+1;
+
+ starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0),
+ rnd.get_random_double()*(upper(1)-lower(1))+lower(1);
+
+ dlog << LINFO << "lower: "<< trans(lower);
+ dlog << LINFO << "upper: "<< trans(upper);
+ dlog << LINFO << "starting: "<< trans(starting_point);
+
+ x = starting_point;
+ double val = find_max_box_constrained(
+ search_strategy,
+ objective_delta_stop_strategy(1e-16, 500),
+ neg_rosen, der_neg_rosen, x,
+ lower,
+ upper
+ );
+
+ DLIB_TEST_MSG(std::abs(val - neg_rosen(x)) < 1e-11, std::abs(val - neg_rosen(x)));
+ dlog << LINFO << "neg_rosen solution:\n" << x;
+
+ dlog << LINFO << "neg_rosen gradient: "<< trans(der_neg_rosen(x));
+ const double gradient_residual = unconstrained_gradient_magnitude_neg_funct(der_neg_rosen, x, lower, upper);
+ dlog << LINFO << "gradient_residual: "<< gradient_residual;
+
+ return gradient_residual;
+ }
+
+ template <typename search_strategy_type>
+ double test_bound_solver_rosen (dlib::rand& rnd, search_strategy_type search_strategy)
+ {
+ using namespace dlib::test_functions;
+ print_spinner();
+ matrix<double,2,1> starting_point, lower, upper, x;
+
+
+ // pick random bounds and sometimes put the upper bound at zero so we can have
+ // a test where the optimal value has a bound active at 0 so make sure this case
+ // works properly.
+ if (rnd.get_random_double() > 0.2)
+ {
+ lower = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1;
+ upper = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1;
+ while (upper(0) < lower(0)) upper(0) = rnd.get_random_gaussian()+1;
+ while (upper(1) < lower(1)) upper(1) = rnd.get_random_gaussian()+1;
+ }
+ else
+ {
+ upper = 0,0;
+ if (rnd.get_random_double() > 0.5)
+ upper(0) = -rnd.get_random_double();
+ if (rnd.get_random_double() > 0.5)
+ upper(1) = -rnd.get_random_double();
+
+ lower = rnd.get_random_double()+1, rnd.get_random_double()+1;
+ lower = upper - lower;
+ }
+ const bool pick_uniform_bounds = rnd.get_random_double() > 0.9;
+ if (pick_uniform_bounds)
+ {
+ double x = rnd.get_random_gaussian()*2;
+ double y = rnd.get_random_gaussian()*2;
+ lower = min(x,y);
+ upper = max(x,y);
+ }
+
+ starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0),
+ rnd.get_random_double()*(upper(1)-lower(1))+lower(1);
+
+ dlog << LINFO << "lower: "<< trans(lower);
+ dlog << LINFO << "upper: "<< trans(upper);
+ dlog << LINFO << "starting: "<< trans(starting_point);
+
+ x = starting_point;
+ double val;
+ if (!pick_uniform_bounds)
+ {
+ val = find_min_box_constrained(
+ search_strategy,
+ objective_delta_stop_strategy(1e-16, 500),
+ rosen, der_rosen, x,
+ lower,
+ upper
+ );
+ }
+ else
+ {
+ val = find_min_box_constrained(
+ search_strategy,
+ objective_delta_stop_strategy(1e-16, 500),
+ rosen, der_rosen, x,
+ lower(0),
+ upper(0)
+ );
+ }
+
+
+ DLIB_TEST_MSG(std::abs(val - rosen(x)) < 1e-11, std::abs(val - rosen(x)));
+ dlog << LINFO << "rosen solution:\n" << x;
+
+ dlog << LINFO << "rosen gradient: "<< trans(der_rosen(x));
+ const double gradient_residual = unconstrained_gradient_magnitude(der_rosen, x, lower, upper);
+ dlog << LINFO << "gradient_residual: "<< gradient_residual;
+
+ return gradient_residual;
+ }
+
+ template <typename search_strategy_type>
+ double test_bound_solver_brown (dlib::rand& rnd, search_strategy_type search_strategy)
+ {
+ using namespace dlib::test_functions;
+ print_spinner();
+ matrix<double,4,1> starting_point(4), lower(4), upper(4), x;
+
+ const matrix<double,0,1> solution = brown_solution();
+
+ // pick random bounds
+ lower = rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian();
+ lower = lower*10 + solution;
+ upper = rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian();
+ upper = upper*10 + solution;
+ for (int i = 0; i < lower.size(); ++i)
+ {
+ if (upper(i) < lower(i))
+ swap(upper(i),lower(i));
+ }
+
+ starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0),
+ rnd.get_random_double()*(upper(1)-lower(1))+lower(1),
+ rnd.get_random_double()*(upper(2)-lower(2))+lower(2),
+ rnd.get_random_double()*(upper(3)-lower(3))+lower(3);
+
+ dlog << LINFO << "lower: "<< trans(lower);
+ dlog << LINFO << "upper: "<< trans(upper);
+ dlog << LINFO << "starting: "<< trans(starting_point);
+
+ x = starting_point;
+ double val = find_min_box_constrained(
+ search_strategy,
+ objective_delta_stop_strategy(1e-16, 500),
+ brown, brown_derivative, x,
+ lower,
+ upper
+ );
+
+ DLIB_TEST(std::abs(val - brown(x)) < 1e-14);
+ dlog << LINFO << "brown solution:\n" << x;
+ return unconstrained_gradient_magnitude(brown_derivative, x, lower, upper);
+ }
+
+ template <typename search_strategy_type>
+ void test_box_constrained_optimizers(search_strategy_type search_strategy)
+ {
+ dlib::rand rnd;
+ running_stats<double> rs;
+
+ dlog << LINFO << "test find_min_box_constrained() on rosen";
+ for (int i = 0; i < 10000; ++i)
+ rs.add(test_bound_solver_rosen(rnd, search_strategy));
+ dlog << LINFO << "mean rosen gradient: " << rs.mean();
+ dlog << LINFO << "max rosen gradient: " << rs.max();
+ DLIB_TEST(rs.mean() < 1e-12);
+ DLIB_TEST(rs.max() < 1e-9);
+
+ dlog << LINFO << "test find_min_box_constrained() on brown";
+ rs.clear();
+ for (int i = 0; i < 1000; ++i)
+ rs.add(test_bound_solver_brown(rnd, search_strategy));
+ dlog << LINFO << "mean brown gradient: " << rs.mean();
+ dlog << LINFO << "max brown gradient: " << rs.max();
+ dlog << LINFO << "min brown gradient: " << rs.min();
+ DLIB_TEST(rs.mean() < 4e-5);
+ DLIB_TEST_MSG(rs.max() < 3e-2, rs.max());
+ DLIB_TEST(rs.min() < 1e-10);
+
+ dlog << LINFO << "test find_max_box_constrained() on neg_rosen";
+ rs.clear();
+ for (int i = 0; i < 1000; ++i)
+ rs.add(test_bound_solver_neg_rosen(rnd, search_strategy));
+ dlog << LINFO << "mean neg_rosen gradient: " << rs.mean();
+ dlog << LINFO << "max neg_rosen gradient: " << rs.max();
+ DLIB_TEST(rs.mean() < 1e-12);
+ DLIB_TEST(rs.max() < 1e-9);
+
+ }
+
+ void test_poly_min_extract_2nd()
+ {
+ double off;
+
+ off = 0.0; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.1; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.2; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.3; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.4; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.5; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.6; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.8; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 0.9; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ off = 1.0; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13);
+ }
+
+ void test_solve_trust_region_subproblem_bounded()
+ {
+ print_spinner();
+ matrix<double> H(2,2);
+ H = 1, 0,
+ 0, 1;
+ matrix<double,0,1> g, lower, upper, p, true_p;
+ g = {0, 0};
+
+ double radius = 0.5;
+ lower = {0.5, 0};
+ upper = {10, 10};
+
+
+ solve_trust_region_subproblem_bounded(H,g, radius, p, 0.001, 500, lower, upper);
+ true_p = { 0.5, 0};
+ DLIB_TEST_MSG(length(p-true_p) < 1e-12, p);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class optimization_tester : public tester
+ {
+ public:
+ optimization_tester (
+ ) :
+ tester ("test_optimization",
+ "Runs tests on the optimization component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "test_box_constrained_optimizers(bfgs_search_strategy())";
+ test_box_constrained_optimizers(bfgs_search_strategy());
+ dlog << LINFO << "test_box_constrained_optimizers(lbfgs_search_strategy(5))";
+ test_box_constrained_optimizers(lbfgs_search_strategy(5));
+ test_poly_min_extract_2nd();
+ optimization_test();
+ test_solve_trust_region_subproblem_bounded();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/optimization_test_functions.cpp b/ml/dlib/dlib/test/optimization_test_functions.cpp
new file mode 100644
index 000000000..1cded50ee
--- /dev/null
+++ b/ml/dlib/dlib/test/optimization_test_functions.cpp
@@ -0,0 +1,425 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include "optimization_test_functions.h"
+
+/*
+
+ Most of the code in this file is converted from the set of Fortran 90 routines
+ created by John Burkardt.
+
+ The original Fortran can be found here: http://orion.math.iastate.edu/burkardt/f_src/testopt/testopt.html
+
+*/
+
+
+namespace dlib
+{
+ namespace test_functions
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> chebyquad_residuals(const matrix<double,0,1>& x)
+ {
+ matrix<double,0,1> fvec(x.size());
+ const int n = x.size();
+ int i;
+ int j;
+ double t;
+ double t1;
+ double t2;
+ double th;
+ fvec = 0;
+
+ for (j = 1; j <= n; ++j)
+ {
+ t1 = 1.0E+00;
+ t2 = 2.0E+00 * x(j-1) - 1.0E+00;
+ t = 2.0E+00 * t2;
+ for (i = 1; i <= n; ++i)
+ {
+ fvec(i-1) = fvec(i-1) + t2;
+ th = t * t2 - t1;
+ t1 = t2;
+ t2 = th;
+ }
+ }
+
+ for (i = 1; i <= n; ++i)
+ {
+ fvec(i-1) = fvec(i-1) / (double) ( n );
+ if ( ( i%2 ) == 0 )
+ fvec(i-1) = fvec(i-1) + 1.0E+00 / ( (double)i*i - 1.0E+00 );
+ }
+
+ return fvec;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ double chebyquad_residual(int i, const matrix<double,0,1>& x)
+ {
+ return chebyquad_residuals(x)(i);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ int& chebyquad_calls()
+ {
+ static int count = 0;
+ return count;
+ }
+
+ double chebyquad(const matrix<double,0,1>& x )
+ {
+ chebyquad_calls()++;
+ return sum(squared(chebyquad_residuals(x)));
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> chebyquad_derivative (const matrix<double,0,1>& x)
+ {
+ const int n = x.size();
+ matrix<double,0,1> fvec = chebyquad_residuals(x);
+ matrix<double,0,1> g(n);
+ int i;
+ int j;
+ double s1;
+ double s2;
+ double t;
+ double t1;
+ double t2;
+ double th;
+
+ for (j = 1; j <= n; ++j)
+ {
+ g(j-1) = 0.0E+00;
+ t1 = 1.0E+00;
+ t2 = 2.0E+00 * x(j-1) - 1.0E+00;
+ t = 2.0E+00 * t2;
+ s1 = 0.0E+00;
+ s2 = 2.0E+00;
+ for (i = 1; i <= n; ++i)
+ {
+ g(j-1) = g(j-1) + fvec(i-1) * s2;
+ th = 4.0E+00 * t2 + t * s2 - s1;
+ s1 = s2;
+ s2 = th;
+ th = t * t2 - t1;
+ t1 = t2;
+ t2 = th;
+ }
+ }
+
+ g = 2.0E+00 * g / (double) ( n );
+
+ return g;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> chebyquad_start (int n)
+ {
+ int i;
+ matrix<double,0,1> x(n);
+
+ for (i = 1; i <= n; ++i)
+ x(i-1) = double ( i ) / double ( n + 1 );
+
+ return x;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> chebyquad_solution (int n)
+ {
+ matrix<double,0,1> x(n);
+
+ x = 0;
+ switch (n)
+ {
+ case 2:
+ x = 0.2113249E+00, 0.7886751E+00;
+ break;
+ case 4:
+ x = 0.1026728E+00, 0.4062037E+00, 0.5937963E+00, 0.8973272E+00;
+ break;
+ case 6:
+ x = 0.066877E+00, 0.288741E+00, 0.366682E+00, 0.633318E+00, 0.711259E+00, 0.933123E+00;
+ break;
+ case 8:
+ x = 0.043153E+00, 0.193091E+00, 0.266329E+00, 0.500000E+00, 0.500000E+00, 0.733671E+00, 0.806910E+00, 0.956847E+00;
+ break;
+ default:
+ std::ostringstream sout;
+ sout << "don't know chebyquad solution for n = " << n;
+ throw dlib::error(sout.str());
+ break;
+ }
+
+ return x;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double> chebyquad_hessian(const matrix<double,0,1>& x)
+ {
+ const int lda = x.size();
+ const int n = x.size();
+ double d1;
+ double d2;
+ matrix<double,0,1> fvec = chebyquad_residuals(x);
+ matrix<double,0,1> gvec(n);
+ matrix<double> h(lda,n);
+ int i;
+ int j;
+ int k;
+ double p1;
+ double p2;
+ double s1;
+ double s2;
+ double ss1;
+ double ss2;
+ double t;
+ double t1;
+ double t2;
+ double th;
+ double tt;
+ double tth;
+ double tt1;
+ double tt2;
+ h = 0;
+
+ d1 = 1.0E+00 / double ( n );
+ d2 = 2.0E+00 * d1;
+
+ for (j = 1; j <= n; ++j)
+ {
+
+ h(j-1,j-1) = 4.0E+00 * d1;
+ t1 = 1.0E+00;
+ t2 = 2.0E+00 * x(j-1) - 1.0E+00;
+ t = 2.0E+00 * t2;
+ s1 = 0.0E+00;
+ s2 = 2.0E+00;
+ p1 = 0.0E+00;
+ p2 = 0.0E+00;
+ gvec(0) = s2;
+
+ for (i = 2; i <= n; ++i)
+ {
+ th = 4.0E+00 * t2 + t * s2 - s1;
+ s1 = s2;
+ s2 = th;
+ th = t * t2 - t1;
+ t1 = t2;
+ t2 = th;
+ th = 8.0E+00 * s1 + t * p2 - p1;
+ p1 = p2;
+ p2 = th;
+ gvec(i-1) = s2;
+ h(j-1,j-1) = h(j-1,j-1) + fvec(i-1) * th + d1 * s2*s2;
+ }
+
+ h(j-1,j-1) = d2 * h(j-1,j-1);
+
+ for (k = 1; k <= j-1; ++k)
+ {
+
+ h(j-1,k-1) = 0.0;
+ tt1 = 1.0E+00;
+ tt2 = 2.0E+00 * x(k-1) - 1.0E+00;
+ tt = 2.0E+00 * tt2;
+ ss1 = 0.0E+00;
+ ss2 = 2.0E+00;
+
+ for (i = 1; i <= n; ++i)
+ {
+ h(j-1,k-1) = h(j-1,k-1) + ss2 * gvec(i-1);
+ tth = 4.0E+00 * tt2 + tt * ss2 - ss1;
+ ss1 = ss2;
+ ss2 = tth;
+ tth = tt * tt2 - tt1;
+ tt1 = tt2;
+ tt2 = tth;
+ }
+
+ h(j-1,k-1) = d2 * d1 * h(j-1,k-1);
+
+ }
+
+ }
+
+ h = make_symmetric(h);
+ return h;
+ }
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ double brown_residual (int i, const matrix<double,4,1>& x)
+ /*!
+ requires
+ - 1 <= i <= 20
+ ensures
+ - returns the ith brown residual
+ !*/
+ {
+ double c;
+ double f;
+ double f1;
+ double f2;
+
+ f = 0.0E+00;
+
+
+ c = double ( i ) / 5.0E+00;
+ f1 = x(0) + c * x(1) - std::exp ( c );
+ f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c );
+
+ f = f1*f1 + f2*f2;
+
+ return f;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ double brown ( const matrix<double,4,1>& x)
+ {
+ double f;
+ int i;
+
+ f = 0;
+
+ for (i = 1; i <= 20; ++i)
+ {
+ f += std::pow(brown_residual(i, x), 2);
+ }
+
+ return f;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,4,1> brown_derivative ( const matrix<double,4,1>& x)
+ {
+ double c;
+ double df1dx1;
+ double df1dx2;
+ double df2dx3;
+ double df2dx4;
+ double f1;
+ double f2;
+ matrix<double,4,1> g;
+ int i;
+
+ g = 0;
+
+ for (i = 1; i <= 20; ++i)
+ {
+
+ c = double ( i ) / 5.0E+00;
+
+ f1 = x(0) + c * x(1) - std::exp ( c );
+ f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c );
+
+ df1dx1 = 1.0E+00;
+ df1dx2 = c;
+ df2dx3 = 1.0E+00;
+ df2dx4 = std::sin ( c );
+
+ using std::pow;
+ g(0) = g(0) + 4.0E+00 * ( pow(f1,3) * df1dx1 + f1 * pow(f2,2) * df1dx1 );
+ g(1) = g(1) + 4.0E+00 * ( pow(f1,3) * df1dx2 + f1 * pow(f2,2) * df1dx2 );
+ g(2) = g(2) + 4.0E+00 * ( pow(f1,2) * f2 * df2dx3 + pow(f2,3) * df2dx3 );
+ g(3) = g(3) + 4.0E+00 * ( pow(f1,2) * f2 * df2dx4 + pow(f2,3) * df2dx4 );
+
+ }
+
+ return g;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,4,4> brown_hessian ( const matrix<double,4,1>& x)
+ {
+ double c;
+ double df1dx1;
+ double df1dx2;
+ double df2dx3;
+ double df2dx4;
+ double f1;
+ double f2;
+ matrix<double,4,4> h;
+ int i;
+
+ h = 0;
+
+ for (i = 1; i <= 20; ++i)
+ {
+
+ c = double ( i ) / 5.0E+00;
+
+ f1 = x(0) + c * x(1) - std::exp ( c );
+ f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c );
+
+ df1dx1 = 1.0E+00;
+ df1dx2 = c;
+ df2dx3 = 1.0E+00;
+ df2dx4 = std::sin ( c );
+
+ using std::pow;
+ h(0,0) = h(0,0) + 12.0E+00 * pow(f1,2) * df1dx1 * df1dx1 + 4.0E+00 * pow(f2,2) * df1dx1 * df1dx1;
+ h(0,1) = h(0,1) + 12.0E+00 * pow(f1,2) * df1dx1 * df1dx2 + 4.0E+00 * pow(f2,2) * df1dx1 * df1dx2;
+ h(0,2) = h(0,2) + 8.0E+00 * f1 * f2 * df1dx1 * df2dx3;
+ h(0,3) = h(0,3) + 8.0E+00 * f1 * f2 * df1dx1 * df2dx4;
+
+ h(1,0) = h(1,0) + 12.0E+00 * pow(f1,2) * df1dx2 * df1dx1 + 4.0E+00 * pow(f2,2) * df1dx2 * df1dx1;
+ h(1,1) = h(1,1) + 12.0E+00 * pow(f1,2) * df1dx2 * df1dx2 + 4.0E+00 * pow(f2,2) * df1dx2 * df1dx2;
+ h(1,2) = h(1,2) + 8.0E+00 * f1 * f2 * df1dx2 * df2dx3;
+ h(1,3) = h(1,3) + 8.0E+00 * f1 * f2 * df1dx2 * df2dx4;
+
+ h(2,0) = h(2,0) + 8.0E+00 * f1 * f2 * df2dx3 * df1dx1;
+ h(2,1) = h(2,1) + 8.0E+00 * f1 * f2 * df2dx3 * df1dx2;
+ h(2,2) = h(2,2) + 4.0E+00 * pow(f1,2) * df2dx3 * df2dx3 + 12.0E+00 * pow(f2,2) * df2dx3 * df2dx3;
+ h(2,3) = h(2,3) + 4.0E+00 * pow(f1,2) * df2dx4 * df2dx3 + 12.0E+00 * pow(f2,2) * df2dx3 * df2dx4;
+
+ h(3,0) = h(3,0) + 8.0E+00 * f1 * f2 * df2dx4 * df1dx1;
+ h(3,1) = h(3,1) + 8.0E+00 * f1 * f2 * df2dx4 * df1dx2;
+ h(3,2) = h(3,2) + 4.0E+00 * pow(f1,2) * df2dx3 * df2dx4 + 12.0E+00 * pow(f2,2) * df2dx4 * df2dx3;
+ h(3,3) = h(3,3) + 4.0E+00 * pow(f1,2) * df2dx4 * df2dx4 + 12.0E+00 * pow(f2,2) * df2dx4 * df2dx4;
+
+ }
+
+ return make_symmetric(h);
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,4,1> brown_start ()
+ {
+ matrix<double,4,1> x;
+ x = 25.0E+00, 5.0E+00, -5.0E+00, -1.0E+00;
+ return x;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,4,1> brown_solution ()
+ {
+ matrix<double,4,1> x;
+ // solution from original documentation.
+ //x = -11.5844E+00, 13.1999E+00, -0.406200E+00, 0.240998E+00;
+ x = -11.594439905669450042, 13.203630051593080452, -0.40343948856573402795, 0.23677877338218666914;
+ return x;
+ }
+
+ // ----------------------------------------------------------------------------------------
+
+ }
+}
+
+
diff --git a/ml/dlib/dlib/test/optimization_test_functions.h b/ml/dlib/dlib/test/optimization_test_functions.h
new file mode 100644
index 000000000..a20523eae
--- /dev/null
+++ b/ml/dlib/dlib/test/optimization_test_functions.h
@@ -0,0 +1,310 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_
+#define DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <cmath>
+
+/*
+
+ Most of the code in this file is converted from the set of Fortran 90 routines
+ created by John Burkardt.
+
+ The original Fortran can be found here: http://orion.math.iastate.edu/burkardt/f_src/testopt/testopt.html
+
+*/
+
+// GCC 4.8 gives false alarms about some variables being uninitialized. Disable these
+// false warnings.
+#if ( defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#endif
+
+
+namespace dlib
+{
+ namespace test_functions
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ matrix<double,0,1> chebyquad_residuals(const matrix<double,0,1>& x);
+
+ double chebyquad_residual(int i, const matrix<double,0,1>& x);
+
+ int& chebyquad_calls();
+
+ double chebyquad(const matrix<double,0,1>& x );
+
+ matrix<double,0,1> chebyquad_derivative (const matrix<double,0,1>& x);
+
+ matrix<double,0,1> chebyquad_start (int n);
+
+ matrix<double,0,1> chebyquad_solution (int n);
+
+ matrix<double> chebyquad_hessian(const matrix<double,0,1>& x);
+
+ // ----------------------------------------------------------------------------------------
+
+ class chebyquad_function_model
+ {
+ public:
+
+ // Define the type used to represent column vectors
+ typedef matrix<double,0,1> column_vector;
+ // Define the type used to represent the hessian matrix
+ typedef matrix<double> general_matrix;
+
+ double operator() (
+ const column_vector& x
+ ) const
+ {
+ return chebyquad(x);
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ d = chebyquad_derivative(x);
+ h = chebyquad_hessian(x);
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ double brown_residual (int i, const matrix<double,4,1>& x);
+ /*!
+ requires
+ - 1 <= i <= 20
+ ensures
+ - returns the ith brown residual
+ !*/
+
+ double brown ( const matrix<double,4,1>& x);
+
+ matrix<double,4,1> brown_derivative ( const matrix<double,4,1>& x);
+
+ matrix<double,4,4> brown_hessian ( const matrix<double,4,1>& x);
+
+ matrix<double,4,1> brown_start ();
+
+ matrix<double,4,1> brown_solution ();
+
+ class brown_function_model
+ {
+ public:
+
+ // Define the type used to represent column vectors
+ typedef matrix<double,4,1> column_vector;
+ // Define the type used to represent the hessian matrix
+ typedef matrix<double> general_matrix;
+
+ double operator() (
+ const column_vector& x
+ ) const
+ {
+ return brown(x);
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ d = brown_derivative(x);
+ h = brown_hessian(x);
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ matrix<T,2,1> rosen_big_start()
+ {
+ matrix<T,2,1> x;
+ x = -1.2, -1;
+ return x;
+ }
+
+ // This is a variation on the Rosenbrock test function but with large residuals. The
+ // minimum is at 1, 1 and the objective value is 1.
+ template <typename T>
+ T rosen_big_residual (int i, const matrix<T,2,1>& m)
+ {
+ using std::pow;
+ const T x = m(0);
+ const T y = m(1);
+
+ if (i == 1)
+ {
+ return 100*pow(y - x*x,2)+1.0;
+ }
+ else
+ {
+ return pow(1 - x,2) + 1.0;
+ }
+ }
+
+ template <typename T>
+ T rosen_big ( const matrix<T,2,1>& m)
+ {
+ using std::pow;
+ return 0.5*(pow(rosen_big_residual(1,m),2) + pow(rosen_big_residual(2,m),2));
+ }
+
+ template <typename T>
+ matrix<T,2,1> rosen_big_solution ()
+ {
+ matrix<T,2,1> x;
+ // solution from original documentation.
+ x = 1,1;
+ return x;
+ }
+
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+ // ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ matrix<T,2,1> rosen_start()
+ {
+ matrix<T,2,1> x;
+ x = -1.2, -1;
+ return x;
+ }
+
+ template <typename T>
+ T rosen ( const matrix<T,2,1>& m)
+ {
+ const T x = m(0);
+ const T y = m(1);
+
+ using std::pow;
+ // compute Rosenbrock's function and return the result
+ return 100.0*pow(y - x*x,2) + pow(1 - x,2);
+ }
+
+ template <typename T>
+ T rosen_residual (int i, const matrix<T,2,1>& m)
+ {
+ const T x = m(0);
+ const T y = m(1);
+
+
+ if (i == 1)
+ {
+ return 10*(y - x*x);
+ }
+ else
+ {
+ return 1 - x;
+ }
+ }
+
+ template <typename T>
+ matrix<T,2,1> rosen_residual_derivative (int i, const matrix<T,2,1>& m)
+ {
+ const T x = m(0);
+
+ matrix<T,2,1> d;
+
+ if (i == 1)
+ {
+ d = -20*x, 10;
+ }
+ else
+ {
+ d = -1, 0;
+ }
+ return d;
+ }
+
+ template <typename T>
+ const matrix<T,2,1> rosen_derivative ( const matrix<T,2,1>& m)
+ {
+ const T x = m(0);
+ const T y = m(1);
+
+ // make us a column vector of length 2
+ matrix<T,2,1> res(2);
+
+ // now compute the gradient vector
+ res(0) = -400*x*(y-x*x) - 2*(1-x); // derivative of rosen() with respect to x
+ res(1) = 200*(y-x*x); // derivative of rosen() with respect to y
+ return res;
+ }
+
+ template <typename T>
+ const matrix<T,2,2> rosen_hessian ( const matrix<T,2,1>& m)
+ {
+ const T x = m(0);
+ const T y = m(1);
+
+ // make us a column vector of length 2
+ matrix<T,2,2> res;
+
+ // now compute the gradient vector
+ res(0,0) = -400*y + 3*400*x*x + 2;
+ res(1,1) = 200;
+
+ res(0,1) = -400*x;
+ res(1,0) = -400*x;
+ return res;
+ }
+
+ template <typename T>
+ matrix<T,2,1> rosen_solution ()
+ {
+ matrix<T,2,1> x;
+ // solution from original documentation.
+ x = 1,1;
+ return x;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct rosen_function_model
+ {
+ typedef matrix<T,2,1> column_vector;
+ typedef matrix<T,2,2> general_matrix;
+
+ T operator() ( column_vector x) const
+ {
+ return static_cast<T>(rosen(x));
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ d = rosen_derivative(x);
+ h = rosen_hessian(x);
+ }
+
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_
+
+
+
diff --git a/ml/dlib/dlib/test/parallel_for.cpp b/ml/dlib/dlib/test/parallel_for.cpp
new file mode 100644
index 000000000..5bee40955
--- /dev/null
+++ b/ml/dlib/dlib/test/parallel_for.cpp
@@ -0,0 +1,334 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/threads.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.parallel_for");
+
+ class assign_element
+ {
+ public:
+
+ assign_element(
+ std::vector<int>& vect_
+ ) : vect(vect_){}
+
+ std::vector<int>& vect;
+
+ void go (long i )
+ {
+ DLIB_TEST( 0 <= i && i < (long)vect.size());
+ vect[i] = i;
+ }
+
+ void operator() (long i ) const
+ {
+ DLIB_TEST( 0 <= i && i < (long)vect.size());
+ vect[i] = i;
+ }
+
+ };
+
+ void test_parallel_for(long start)
+ {
+ std::vector<int> vect(200,0);
+
+ parallel_for(4, start, vect.size(), assign_element(vect));
+
+ for (long i = 0; i < start; ++i)
+ {
+ DLIB_TEST(vect[i] == 0);
+ }
+ for (long i = start; i < (long)vect.size(); ++i)
+ {
+ DLIB_TEST(vect[i] == i);
+ }
+ }
+
+ void test_parallel_for2(long start)
+ {
+ std::vector<int> vect(200,0);
+
+ assign_element temp(vect);
+ parallel_for(4, start, vect.size(), temp, &assign_element::go);
+
+ for (long i = 0; i < start; ++i)
+ {
+ DLIB_TEST(vect[i] == 0);
+ }
+ for (long i = start; i < (long)vect.size(); ++i)
+ {
+ DLIB_TEST(vect[i] == i);
+ }
+ }
+
+ struct parfor_test_helper
+ {
+ mutable std::vector<int> test;
+
+ parfor_test_helper() : test(400,100000)
+ {
+ }
+
+ void go(long begin, long end)
+ {
+ for (long i = begin; i < end; ++i)
+ test[i] = i;
+ }
+
+ void operator()(long begin, long end) const
+ {
+ for (long i = begin; i < end; ++i)
+ test[i] = i;
+ }
+
+ void go2(long i)
+ {
+ test[i] = i;
+ }
+
+ };
+
+ struct parfor_test_helper2
+ {
+ mutable std::vector<int> test;
+
+ parfor_test_helper2() : test(400,100000)
+ {
+ }
+
+ void operator()(long i) const
+ {
+ test[i] = i;
+ }
+
+ };
+
+ void test_parallel_for_additional()
+ {
+ {
+ parfor_test_helper helper;
+ parallel_for(4, 0, helper.test.size(), helper, &parfor_test_helper::go2);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for(4, 10, helper.test.size(), helper, &parfor_test_helper::go2);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked(4, 0, helper.test.size(), helper, &parfor_test_helper::go);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked(4, 10, helper.test.size(), helper, &parfor_test_helper::go);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked(4, 0, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked(4, 10, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper2 helper;
+ parallel_for(4, 0, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper2 helper;
+ parallel_for(4, 10, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+
+
+
+
+
+
+ {
+ parfor_test_helper helper;
+ parallel_for_verbose(4, 0, helper.test.size(), helper, &parfor_test_helper::go2);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_verbose(4, 10, helper.test.size(), helper, &parfor_test_helper::go2);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked_verbose(4, 0, helper.test.size(), helper, &parfor_test_helper::go);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked_verbose(4, 10, helper.test.size(), helper, &parfor_test_helper::go);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked_verbose(4, 0, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper helper;
+ parallel_for_blocked_verbose(4, 10, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper2 helper;
+ parallel_for_verbose(4, 0, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ {
+ parfor_test_helper2 helper;
+ parallel_for_verbose(4, 10, helper.test.size(), helper);
+
+ for (unsigned long i = 0; i < 10; ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]);
+ }
+ for (unsigned long i = 10; i < helper.test.size(); ++i)
+ {
+ DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]);
+ }
+ }
+ }
+
+ class test_parallel_for_routines : public tester
+ {
+ public:
+ test_parallel_for_routines (
+ ) :
+ tester (
+ "test_parallel_for", // the command line argument name for this test
+ "Run tests on the parallel_for routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ void perform_test (
+ )
+ {
+ test_parallel_for(0);
+ test_parallel_for(30);
+ test_parallel_for(50);
+ test_parallel_for2(0);
+ test_parallel_for2(30);
+ test_parallel_for2(50);
+
+ test_parallel_for_additional();
+ }
+ };
+
+ test_parallel_for_routines a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/parse.cpp b/ml/dlib/dlib/test/parse.cpp
new file mode 100644
index 000000000..b0ea13b1b
--- /dev/null
+++ b/ml/dlib/dlib/test/parse.cpp
@@ -0,0 +1,233 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <dlib/optimization.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.parse");
+
+// ----------------------------------------------------------------------------------------
+
+ const unsigned long DET = 0;
+ const unsigned long N = 1;
+ const unsigned long V = 2;
+ const unsigned long NP = 3;
+ const unsigned long VP = 4;
+ const unsigned long S = 5;
+ const unsigned long B = 6;
+ const unsigned long G = 7;
+ const unsigned long A = 8;
+
+ typedef unsigned long tags;
+
+ template <bool has_glue_term>
+ void user_defined_ruleset (
+ const std::vector<tags>& words,
+ const constituent<tags>& c,
+ std::vector<std::pair<tags,double> >& possible_ids
+ )
+ {
+ DLIB_TEST(c.begin < c.k && c.k < c.end && c.end <= words.size());
+ DLIB_TEST(possible_ids.size() == 0);
+
+ if (c.left_tag == NP && c.right_tag == VP) possible_ids.push_back(make_pair(S,log(0.80)));
+ else if (c.left_tag == DET && c.right_tag == N) possible_ids.push_back(make_pair(NP,log(0.30)));
+ else if (c.left_tag == VP && c.right_tag == A) possible_ids.push_back(make_pair(VP,log(0.30)));
+ else if (c.left_tag == V && c.right_tag == NP)
+ {
+ possible_ids.push_back(make_pair(VP,log(0.20)));
+ possible_ids.push_back(make_pair(B,0.10));
+ }
+ else if (has_glue_term)
+ {
+ possible_ids.push_back(make_pair(G, log(0.01)));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void dotest1()
+ {
+ print_spinner();
+ dlog << LINFO << "in dotest1()";
+
+ std::vector<std::string> words;
+ std::vector<tags> sequence;
+ for (int i = 0; i < 8; ++i)
+ {
+ sequence.push_back(DET);
+ sequence.push_back(N);
+ sequence.push_back(V);
+ sequence.push_back(DET);
+ sequence.push_back(N);
+ sequence.push_back(A);
+
+ words.push_back("The");
+ words.push_back("flight");
+ words.push_back("includes");
+ words.push_back("a");
+ words.push_back("meal");
+ words.push_back("AWORD");
+ }
+
+ std::vector<parse_tree_element<tags> > parse_tree;
+
+ find_max_parse_cky(sequence, user_defined_ruleset<true>, parse_tree);
+ DLIB_TEST(parse_tree.size() != 0);
+
+
+ std::vector<unsigned long> roots;
+ find_trees_not_rooted_with_tag(parse_tree, G, roots);
+ DLIB_TEST(roots.size() == 8);
+
+ for (unsigned long i = 0; i < roots.size(); ++i)
+ {
+ dlog << LINFO << parse_tree_to_string(parse_tree, words, roots[i]);
+ DLIB_TEST(parse_tree_to_string(parse_tree, words, roots[i]) == "[[The flight] [[includes [a meal]] AWORD]]");
+ dlog << LINFO << parse_tree_to_string_tagged(parse_tree, words, roots[i]);
+ DLIB_TEST(parse_tree_to_string_tagged(parse_tree, words, roots[i]) == "[5 [3 The flight] [4 [4 includes [3 a meal]] AWORD]]");
+ }
+
+
+ words.clear();
+ sequence.clear();
+
+ for (int i = 0; i < 2; ++i)
+ {
+ sequence.push_back(DET);
+ sequence.push_back(N);
+ sequence.push_back(V);
+ sequence.push_back(DET);
+ sequence.push_back(N);
+
+ words.push_back("The");
+ words.push_back("flight");
+ words.push_back("includes");
+ words.push_back("a");
+ words.push_back("meal");
+ }
+
+ find_max_parse_cky(sequence, user_defined_ruleset<true>, parse_tree);
+ DLIB_TEST(parse_tree.size() != 0);
+
+ const std::string str1 = "[[[The flight] [includes [a meal]]] [[The flight] [includes [a meal]]]]";
+ const std::string str2 = "[7 [5 [3 The flight] [4 includes [3 a meal]]] [5 [3 The flight] [4 includes [3 a meal]]]]";
+ dlog << LINFO << parse_tree_to_string(parse_tree, words);
+ DLIB_TEST(parse_tree_to_string(parse_tree, words) == str1);
+ dlog << LINFO << parse_tree_to_string_tagged(parse_tree, words);
+ DLIB_TEST(parse_tree_to_string_tagged(parse_tree, words) == str2);
+
+ const std::string str3 = "[[The flight] [includes [a meal]]] [[The flight] [includes [a meal]]]";
+ const std::string str4 = "[5 [3 The flight] [4 includes [3 a meal]]] [5 [3 The flight] [4 includes [3 a meal]]]";
+ dlog << LINFO << parse_trees_to_string(parse_tree, words, G);
+ DLIB_TEST(parse_trees_to_string(parse_tree, words, G) == str3);
+ dlog << LINFO << parse_trees_to_string_tagged(parse_tree, words, G);
+ DLIB_TEST(parse_trees_to_string_tagged(parse_tree, words, G) == str4);
+
+ sequence.clear();
+ find_max_parse_cky(sequence, user_defined_ruleset<true>, parse_tree);
+ DLIB_TEST(parse_tree.size() == 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void dotest2()
+ {
+ print_spinner();
+ dlog << LINFO << "in dotest2()";
+
+ std::vector<std::string> words;
+ std::vector<tags> sequence;
+ for (int i = 0; i < 8; ++i)
+ {
+ sequence.push_back(DET);
+ sequence.push_back(N);
+ sequence.push_back(V);
+ sequence.push_back(DET);
+ sequence.push_back(N);
+
+ words.push_back("The");
+ words.push_back("flight");
+ words.push_back("includes");
+ words.push_back("a");
+ words.push_back("meal");
+ }
+
+ std::vector<parse_tree_element<tags> > parse_tree;
+
+ find_max_parse_cky(sequence, user_defined_ruleset<false>, parse_tree);
+ DLIB_TEST(parse_tree.size() == 0);
+
+
+ std::vector<unsigned long> roots;
+ find_trees_not_rooted_with_tag(parse_tree, G, roots);
+ DLIB_TEST(roots.size() == 0);
+
+
+ words.clear();
+ sequence.clear();
+
+ for (int i = 0; i < 2; ++i)
+ {
+ sequence.push_back(DET);
+ sequence.push_back(N);
+ sequence.push_back(V);
+ sequence.push_back(DET);
+ sequence.push_back(N);
+
+ words.push_back("The");
+ words.push_back("flight");
+ words.push_back("includes");
+ words.push_back("a");
+ words.push_back("meal");
+ }
+
+ find_max_parse_cky(sequence, user_defined_ruleset<false>, parse_tree);
+ DLIB_TEST(parse_tree.size() == 0);
+
+ sequence.clear();
+ find_max_parse_cky(sequence, user_defined_ruleset<false>, parse_tree);
+ DLIB_TEST(parse_tree.size() == 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class parse_tester : public tester
+ {
+ public:
+ parse_tester (
+ ) :
+ tester ("test_parse",
+ "Runs tests on the parsing tools.")
+ {}
+
+
+ void perform_test (
+ )
+ {
+ dotest1();
+ dotest2();
+ }
+ } a;
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/pipe.cpp b/ml/dlib/dlib/test/pipe.cpp
new file mode 100644
index 000000000..d84dd0d44
--- /dev/null
+++ b/ml/dlib/dlib/test/pipe.cpp
@@ -0,0 +1,688 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/misc_api.h>
+#include <dlib/pipe.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.pipe");
+
+ namespace pipe_kernel_test_helpers
+ {
+ const unsigned long proc1_count = 10000;
+ dlib::mutex m;
+ signaler s(m);
+ unsigned long threads_running = 0;
+ bool found_error;
+
+ inline void add_running_thread (
+ )
+ {
+ auto_mutex M(m);
+ ++threads_running;
+ }
+
+ inline void remove_running_thread (
+ )
+ {
+ auto_mutex M(m);
+ --threads_running;
+ s.broadcast();
+ }
+
+ inline void wait_for_threads (
+ )
+ {
+ auto_mutex M(m);
+ while (threads_running > 0)
+ s.wait();
+ }
+
+ template <
+ typename pipe
+ >
+ void threadproc1 (
+ void* param
+ )
+ {
+ add_running_thread();
+ pipe& p = *static_cast<pipe*>(param);
+ try
+ {
+
+ int last = -1;
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ int cur=0;
+ DLIB_TEST(p.dequeue(cur) == true);
+ DLIB_TEST(last + 1 == cur);
+ last = cur;
+ }
+ DLIB_TEST(p.size() == 0);
+ }
+ catch(exception& e)
+ {
+ auto_mutex M(m);
+ found_error = true;
+ cout << "\n\nERRORS FOUND" << endl;
+ cout << e.what() << endl;
+ dlog << LWARN << "ERRORS FOUND";
+ dlog << LWARN << e.what();
+ p.disable();
+ }
+
+ remove_running_thread();
+ }
+
+
+ template <
+ typename pipe
+ >
+ void threadproc2 (
+ void* param
+ )
+ {
+ add_running_thread();
+ pipe& p = *static_cast<pipe*>(param);
+ try
+ {
+
+ int last = -1;
+ int cur;
+ while (p.dequeue(cur))
+ {
+ DLIB_TEST(last < cur);
+ last = cur;
+ }
+ auto_mutex M(m);
+ }
+ catch(exception& e)
+ {
+ auto_mutex M(m);
+ found_error = true;
+ cout << "\n\nERRORS FOUND" << endl;
+ cout << e.what() << endl;
+ dlog << LWARN << "ERRORS FOUND";
+ dlog << LWARN << e.what();
+ p.disable();
+ }
+ remove_running_thread();
+ }
+
+
+
+ template <
+ typename pipe
+ >
+ void threadproc3 (
+ void* param
+ )
+ {
+ add_running_thread();
+ pipe& p = *static_cast<pipe*>(param);
+ try
+ {
+
+ int last = -1;
+ int cur;
+ while (p.dequeue_or_timeout(cur,100000))
+ {
+ DLIB_TEST(last < cur);
+ last = cur;
+ }
+ auto_mutex M(m);
+ }
+ catch(exception& e)
+ {
+ auto_mutex M(m);
+ found_error = true;
+ cout << "\n\nERRORS FOUND" << endl;
+ cout << e.what() << endl;
+ dlog << LWARN << "ERRORS FOUND";
+ dlog << LWARN << e.what();
+ p.disable();
+ }
+ remove_running_thread();
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template<typename in_type, typename out_type>
+ class PipelineProcessor : private dlib::threaded_object
+ {
+ public:
+ PipelineProcessor(
+ dlib::pipe<in_type> & in,
+ dlib::pipe<out_type> & out) :
+ InPipe(in),
+ OutPipe(out),
+ InMsg(),
+ OutMsg() {
+ start();
+ }
+
+ ~PipelineProcessor() {
+ // signal the thread to stop
+ stop();
+ wait();
+ }
+
+ private:
+ dlib::pipe<in_type> & InPipe;
+ dlib::pipe<out_type> & OutPipe;
+
+ in_type InMsg;
+ out_type OutMsg;
+
+ void thread()
+ {
+ while (!should_stop()) {
+ if(InPipe.dequeue_or_timeout(InMsg, 100))
+ {
+ // if function signals ready to send OutMsg
+ while (!OutPipe.enqueue_or_timeout(OutMsg, 100))
+ {
+ // try to send until should stop
+ if (should_stop())
+ {
+ return;
+ }
+ }
+ }
+ }
+ };
+ };
+
+
+ void do_zero_size_test_with_timeouts()
+ {
+ dlog << LINFO << "in do_zero_size_test_with_timeouts()";
+ // make sure we can get though this without deadlocking
+ for (int k = 0; k < 10; ++k)
+ {
+ dlib::pipe<int> in_pipe(10);
+ dlib::pipe<float> out_pipe(0);
+ {
+ PipelineProcessor<int, float> pp(in_pipe, out_pipe);
+
+ int in = 1;
+ in_pipe.enqueue(in);
+ in = 2;
+ in_pipe.enqueue(in);
+ in = 3;
+ in_pipe.enqueue(in);
+ // sleep to make sure thread enqueued
+ dlib::sleep(100);
+
+ float out = 1.0f;
+ out_pipe.dequeue(out);
+ dlib::sleep(100);
+ }
+ print_spinner();
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pipe
+ >
+ void pipe_kernel_test (
+ )
+ /*!
+ requires
+ - pipe is an implementation of pipe/pipe_kernel_abstract.h and
+ is instantiated with int
+ ensures
+ - runs tests on pipe for compliance with the specs
+ !*/
+ {
+ using namespace pipe_kernel_test_helpers;
+ found_error = false;
+
+
+ print_spinner();
+ pipe test(10), test2(100);
+ pipe test_0(0), test2_0(0);
+ pipe test_1(1), test2_1(1);
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test_0.size() == 0);
+ DLIB_TEST(test2_0.size() == 0);
+ DLIB_TEST(test_1.size() == 0);
+ DLIB_TEST(test2_1.size() == 0);
+
+ DLIB_TEST(test.is_enqueue_enabled() == true);
+ DLIB_TEST(test.is_dequeue_enabled() == true);
+ DLIB_TEST(test.is_enabled() == true);
+
+ test.empty();
+ test2.empty();
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test2.size() == 0);
+
+ test_0.empty();
+ test2_0.empty();
+ DLIB_TEST(test_0.size() == 0);
+ DLIB_TEST(test2_0.size() == 0);
+
+ test_1.empty();
+ test2_1.empty();
+ DLIB_TEST(test_1.size() == 0);
+ DLIB_TEST(test2_1.size() == 0);
+
+
+
+ int a;
+ a = 3;
+ test.enqueue(a);
+ DLIB_TEST(test.size() == 1);
+ a = 5;
+ test.enqueue(a);
+ DLIB_TEST(test.size() == 2);
+
+ a = 0;
+ test.dequeue(a);
+ DLIB_TEST(a == 3);
+ DLIB_TEST(test.size() == 1);
+
+ a = 0;
+ test.dequeue(a);
+ DLIB_TEST(a == 5);
+ DLIB_TEST(test.size() == 0);
+
+
+ print_spinner();
+ {
+ dlog << LINFO << "starting normal length pipe tests";
+ create_new_thread(&threadproc1<pipe>,&test);
+ create_new_thread(&threadproc2<pipe>,&test2);
+ create_new_thread(&threadproc2<pipe>,&test2);
+ create_new_thread(&threadproc2<pipe>,&test2);
+
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test.enqueue(a);
+ }
+ DLIB_TEST(test.is_enqueue_enabled() == true);
+ test.disable_enqueue();
+ DLIB_TEST(test.is_enqueue_enabled() == false);
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test.enqueue(a);
+ }
+
+ for (unsigned long i = 0; i < 100000; ++i)
+ {
+ a = i;
+ if (i%2 == 0)
+ test2.enqueue(a);
+ else
+ test2.enqueue_or_timeout(a,100000);
+ }
+
+ test2.wait_for_num_blocked_dequeues(3);
+ DLIB_TEST(test2.size() == 0);
+ test2.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2.size() == 0);
+
+ test2.enable();
+
+ print_spinner();
+
+ create_new_thread(&threadproc3<pipe>,&test2);
+ create_new_thread(&threadproc3<pipe>,&test2);
+
+
+ for (unsigned long i = 0; i < 100000; ++i)
+ {
+ a = i;
+ if (i%2 == 0)
+ test2.enqueue(a);
+ else
+ test2.enqueue_or_timeout(a,100000);
+ }
+
+ test2.wait_for_num_blocked_dequeues(2);
+ DLIB_TEST(test2.size() == 0);
+ test2.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2.size() == 0);
+
+ }
+
+
+ print_spinner();
+ {
+ dlog << LINFO << "starting 0 length pipe tests";
+ create_new_thread(&threadproc1<pipe>,&test_0);
+ create_new_thread(&threadproc2<pipe>,&test2_0);
+ create_new_thread(&threadproc2<pipe>,&test2_0);
+ create_new_thread(&threadproc2<pipe>,&test2_0);
+ dlog << LTRACE << "0: 1";
+
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test_0.enqueue(a);
+ }
+
+ dlog << LTRACE << "0: 2";
+ DLIB_TEST(test_0.is_enqueue_enabled() == true);
+ test_0.disable_enqueue();
+ DLIB_TEST(test_0.is_enqueue_enabled() == false);
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test_0.enqueue(a);
+ }
+
+ dlog << LTRACE << "0: 3";
+ for (unsigned long i = 0; i < 100000; ++i)
+ {
+ a = i;
+ if (i%2 == 0)
+ test2_0.enqueue(a);
+ else
+ test2_0.enqueue_or_timeout(a,100000);
+ }
+
+ print_spinner();
+ dlog << LTRACE << "0: 4";
+ test2_0.wait_for_num_blocked_dequeues(3);
+ DLIB_TEST(test2_0.size() == 0);
+ test2_0.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2_0.size() == 0);
+
+ dlog << LTRACE << "0: 5";
+ test2_0.enable();
+
+
+ create_new_thread(&threadproc3<pipe>,&test2_0);
+ create_new_thread(&threadproc3<pipe>,&test2_0);
+
+
+ for (unsigned long i = 0; i < 20000; ++i)
+ {
+ if ((i%100) == 0)
+ print_spinner();
+
+ a = i;
+ if (i%2 == 0)
+ test2_0.enqueue(a);
+ else
+ test2_0.enqueue_or_timeout(a,100000);
+ }
+
+ dlog << LTRACE << "0: 6";
+ test2_0.wait_for_num_blocked_dequeues(2);
+ DLIB_TEST(test2_0.size() == 0);
+ test2_0.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2_0.size() == 0);
+
+ dlog << LTRACE << "0: 7";
+ }
+
+ print_spinner();
+ {
+ dlog << LINFO << "starting 1 length pipe tests";
+ create_new_thread(&threadproc1<pipe>,&test_1);
+ create_new_thread(&threadproc2<pipe>,&test2_1);
+ create_new_thread(&threadproc2<pipe>,&test2_1);
+ create_new_thread(&threadproc2<pipe>,&test2_1);
+
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test_1.enqueue(a);
+ }
+ DLIB_TEST(test_1.is_enqueue_enabled() == true);
+ test_1.disable_enqueue();
+ DLIB_TEST(test_1.is_enqueue_enabled() == false);
+ for (unsigned long i = 0; i < proc1_count; ++i)
+ {
+ a = i;
+ test_1.enqueue(a);
+ }
+ print_spinner();
+
+ for (unsigned long i = 0; i < 100000; ++i)
+ {
+ a = i;
+ if (i%2 == 0)
+ test2_1.enqueue(a);
+ else
+ test2_1.enqueue_or_timeout(a,100000);
+ }
+
+ test2_1.wait_for_num_blocked_dequeues(3);
+ DLIB_TEST(test2_1.size() == 0);
+ test2_1.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2_1.size() == 0);
+
+ test2_1.enable();
+
+
+ create_new_thread(&threadproc3<pipe>,&test2_1);
+ create_new_thread(&threadproc3<pipe>,&test2_1);
+
+
+ for (unsigned long i = 0; i < 100000; ++i)
+ {
+ a = i;
+ if (i%2 == 0)
+ test2_1.enqueue(a);
+ else
+ test2_1.enqueue_or_timeout(a,100000);
+ }
+
+ test2_1.wait_for_num_blocked_dequeues(2);
+ DLIB_TEST(test2_1.size() == 0);
+ test2_1.disable();
+
+ wait_for_threads();
+ DLIB_TEST(test2_1.size() == 0);
+
+ }
+
+ test.enable_enqueue();
+ test_0.enable_enqueue();
+ test_1.enable_enqueue();
+
+ DLIB_TEST(test.is_enabled());
+ DLIB_TEST(test.is_enqueue_enabled());
+ DLIB_TEST(test_0.is_enabled());
+ DLIB_TEST(test_0.is_enqueue_enabled());
+ DLIB_TEST(test_1.is_enabled());
+ DLIB_TEST(test_1.is_enqueue_enabled());
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test_0.size() == 0);
+ DLIB_TEST(test_1.size() == 0);
+ DLIB_TEST(test.max_size() == 10);
+ DLIB_TEST(test_0.max_size() == 0);
+ DLIB_TEST(test_1.max_size() == 1);
+
+
+ for (int i = 0; i < 100; ++i)
+ {
+ a = 1;
+ test.enqueue_or_timeout(a,0);
+ a = 1;
+ test_0.enqueue_or_timeout(a,0);
+ a = 1;
+ test_1.enqueue_or_timeout(a,0);
+ }
+
+ DLIB_TEST_MSG(test.size() == 10,"size: " << test.size() );
+ DLIB_TEST_MSG(test_0.size() == 0,"size: " << test.size() );
+ DLIB_TEST_MSG(test_1.size() == 1,"size: " << test.size() );
+
+ for (int i = 0; i < 10; ++i)
+ {
+ a = 0;
+ DLIB_TEST(test.enqueue_or_timeout(a,10) == false);
+ a = 0;
+ DLIB_TEST(test_0.enqueue_or_timeout(a,10) == false);
+ a = 0;
+ DLIB_TEST(test_1.enqueue_or_timeout(a,10) == false);
+ }
+
+ DLIB_TEST_MSG(test.size() == 10,"size: " << test.size() );
+ DLIB_TEST_MSG(test_0.size() == 0,"size: " << test.size() );
+ DLIB_TEST_MSG(test_1.size() == 1,"size: " << test.size() );
+
+ for (int i = 0; i < 10; ++i)
+ {
+ a = 0;
+ DLIB_TEST(test.dequeue_or_timeout(a,0) == true);
+ DLIB_TEST(a == 1);
+ }
+
+ DLIB_TEST(test.max_size() == 10);
+ DLIB_TEST(test_0.max_size() == 0);
+ DLIB_TEST(test_1.max_size() == 1);
+
+ a = 0;
+ DLIB_TEST(test_1.dequeue_or_timeout(a,0) == true);
+
+ DLIB_TEST(test.max_size() == 10);
+ DLIB_TEST(test_0.max_size() == 0);
+ DLIB_TEST(test_1.max_size() == 1);
+
+
+ DLIB_TEST_MSG(a == 1,"a: " << a);
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test_0.size() == 0);
+ DLIB_TEST(test_1.size() == 0);
+
+ DLIB_TEST(test.dequeue_or_timeout(a,0) == false);
+ DLIB_TEST(test_0.dequeue_or_timeout(a,0) == false);
+ DLIB_TEST(test_1.dequeue_or_timeout(a,0) == false);
+ DLIB_TEST(test.dequeue_or_timeout(a,10) == false);
+ DLIB_TEST(test_0.dequeue_or_timeout(a,10) == false);
+ DLIB_TEST(test_1.dequeue_or_timeout(a,10) == false);
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test_0.size() == 0);
+ DLIB_TEST(test_1.size() == 0);
+
+ DLIB_TEST(found_error == false);
+
+
+
+
+ {
+ test.enable();
+ test.enable_enqueue();
+ test.empty();
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_enabled() == true);
+ DLIB_TEST(test.is_enqueue_enabled() == true);
+ DLIB_TEST(test.is_dequeue_enabled() == true);
+ test.disable_dequeue();
+ dlog << LINFO << "Make sure disable_dequeue() works right...";
+ DLIB_TEST(test.is_dequeue_enabled() == false);
+ DLIB_TEST(test.dequeue(a) == false);
+ test.wait_until_empty();
+ a = 4;
+ test.enqueue(a);
+ test.wait_until_empty();
+ test.wait_for_num_blocked_dequeues(4);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(test.dequeue(a) == false);
+ DLIB_TEST(test.dequeue_or_timeout(a,10000) == false);
+ DLIB_TEST(test.size() == 1);
+ a = 0;
+ test.enable_dequeue();
+ DLIB_TEST(test.is_dequeue_enabled() == true);
+ DLIB_TEST(test.dequeue(a) == true);
+ DLIB_TEST(a == 4);
+ test_1.wait_until_empty();
+ }
+ {
+ test_1.enable();
+ test_1.enable_enqueue();
+ test_1.empty();
+ DLIB_TEST(test_1.size() == 0);
+ DLIB_TEST(test_1.is_enabled() == true);
+ DLIB_TEST(test_1.is_enqueue_enabled() == true);
+ DLIB_TEST(test_1.is_dequeue_enabled() == true);
+ test_1.disable_dequeue();
+ dlog << LINFO << "Make sure disable_dequeue() works right...";
+ DLIB_TEST(test_1.is_dequeue_enabled() == false);
+ DLIB_TEST(test_1.dequeue(a) == false);
+ a = 4;
+ test_1.wait_for_num_blocked_dequeues(4);
+ test_1.wait_for_num_blocked_dequeues(0);
+ test_1.enqueue(a);
+ test_1.wait_until_empty();
+ DLIB_TEST(test_1.size() == 1);
+ DLIB_TEST(test_1.dequeue(a) == false);
+ DLIB_TEST(test_1.dequeue_or_timeout(a,10000) == false);
+ DLIB_TEST(test_1.size() == 1);
+ a = 0;
+ test_1.enable_dequeue();
+ DLIB_TEST(test_1.is_dequeue_enabled() == true);
+ DLIB_TEST(test_1.dequeue(a) == true);
+ DLIB_TEST(a == 4);
+ test_1.wait_until_empty();
+ }
+
+ }
+
+
+
+
+ class pipe_tester : public tester
+ {
+ public:
+ pipe_tester (
+ ) :
+ tester ("test_pipe",
+ "Runs tests on the pipe component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ pipe_kernel_test<dlib::pipe<int> >();
+
+ do_zero_size_test_with_timeouts();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/pixel.cpp b/ml/dlib/dlib/test/pixel.cpp
new file mode 100644
index 000000000..40772b130
--- /dev/null
+++ b/ml/dlib/dlib/test/pixel.cpp
@@ -0,0 +1,777 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/pixel.h>
+#include <dlib/matrix.h>
+#include <dlib/image_io.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.pixel");
+
+
+ void pixel_test (
+ )
+ /*!
+ ensures
+ - runs tests on pixel objects and functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ unsigned char p_gray;
+ unsigned short p_gray16;
+ long p_int;
+ float p_float;
+ signed char p_schar;
+ rgb_pixel p_rgb,p_rgb2;
+ hsi_pixel p_hsi, p_hsi2;
+ rgb_alpha_pixel p_rgba;
+ lab_pixel p_lab, p_lab2;
+
+ assign_pixel(p_int, 0.0f);
+ assign_pixel(p_float, 0.0f);
+ assign_pixel(p_schar, 0);
+
+ assign_pixel(p_gray, -2);
+ assign_pixel(p_rgb,0);
+ assign_pixel(p_hsi, -4);
+ assign_pixel(p_rgba, p_int);
+ assign_pixel(p_gray16,0);
+ assign_pixel(p_lab,-400);
+
+ DLIB_TEST(p_int == 0);
+ DLIB_TEST(p_float == 0);
+ DLIB_TEST(p_schar == 0);
+
+ DLIB_TEST(p_gray == 0);
+ DLIB_TEST(p_gray16 == 0);
+
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ DLIB_TEST(p_rgba.red == 0);
+ DLIB_TEST(p_rgba.green == 0);
+ DLIB_TEST(p_rgba.blue == 0);
+ DLIB_TEST(p_rgba.alpha == 255);
+
+ DLIB_TEST(p_hsi.h == 0);
+ DLIB_TEST(p_hsi.s == 0);
+ DLIB_TEST(p_hsi.i == 0);
+
+ DLIB_TEST(p_lab.l == 0);
+ DLIB_TEST(p_lab.a == 128);
+ DLIB_TEST(p_lab.b == 128);
+
+ assign_pixel(p_gray,10);
+ assign_pixel(p_gray16,10);
+ assign_pixel(p_rgb,10);
+ assign_pixel(p_hsi,10);
+ assign_pixel(p_rgba,10);
+ assign_pixel(p_lab,10);
+
+ assign_pixel(p_int, -10);
+ assign_pixel(p_float, -10);
+ assign_pixel(p_schar, -10);
+
+ DLIB_TEST(p_int == -10);
+ DLIB_TEST(p_float == -10);
+ DLIB_TEST(p_schar == -10);
+
+ DLIB_TEST(p_gray == 10);
+ DLIB_TEST(p_gray16 == 10);
+
+ DLIB_TEST(p_rgb.red == 10);
+ DLIB_TEST(p_rgb.green == 10);
+ DLIB_TEST(p_rgb.blue == 10);
+
+ DLIB_TEST(p_rgba.red == 10);
+ DLIB_TEST(p_rgba.green == 10);
+ DLIB_TEST(p_rgba.blue == 10);
+ DLIB_TEST(p_rgba.alpha == 255);
+
+ DLIB_TEST(p_hsi.h == 0);
+ DLIB_TEST(p_hsi.s == 0);
+ DLIB_TEST(p_hsi.i == 10);
+
+ DLIB_TEST(p_lab.l == 10);
+ DLIB_TEST(p_lab.a == 128);
+ DLIB_TEST(p_lab.b == 128);
+
+ assign_pixel(p_gray16,12345);
+ DLIB_TEST(p_gray16 == 12345);
+
+ assign_pixel(p_float,3.141);
+ DLIB_TEST(p_float == 3.141f);
+
+ p_rgb.red = 255;
+ p_rgb.green = 100;
+ p_rgb.blue = 50;
+
+ p_rgba.alpha = 4;
+ assign_pixel(p_gray,p_rgb);
+ assign_pixel(p_rgb,p_rgb);
+ assign_pixel(p_rgba,p_rgb);
+ assign_pixel(p_hsi,p_rgb);
+ assign_pixel(p_lab,p_rgb);
+
+ assign_pixel(p_float,p_rgb);
+ assign_pixel(p_int,p_rgb);
+ assign_pixel(p_schar,p_rgb);
+
+ DLIB_TEST(p_schar == std::numeric_limits<signed char>::max());
+
+ DLIB_TEST(p_int == (255+100+50)/3);
+ DLIB_TEST_MSG(p_float == (255+100+50)/3, p_float - (255+100+50)/3);
+ DLIB_TEST(p_gray == (255+100+50)/3);
+
+ DLIB_TEST(p_rgb.red == 255);
+ DLIB_TEST(p_rgb.green == 100);
+ DLIB_TEST(p_rgb.blue == 50);
+
+ DLIB_TEST(p_rgba.red == 255);
+ DLIB_TEST(p_rgba.green == 100);
+ DLIB_TEST(p_rgba.blue == 50);
+ DLIB_TEST(p_rgba.alpha == 255);
+
+ DLIB_TEST(p_hsi.i > 0);
+ DLIB_TEST(p_hsi.s > 0);
+ DLIB_TEST(p_hsi.h > 0);
+
+ DLIB_TEST(p_lab.l > 0);
+ DLIB_TEST(p_lab.a > 0);
+ DLIB_TEST(p_lab.b > 0);
+
+ assign_pixel(p_rgb,0);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 0);
+ assign_pixel(p_rgb, p_hsi);
+
+ DLIB_TEST_MSG(p_rgb.red > 251 ,(int)p_rgb.green);
+ DLIB_TEST_MSG(p_rgb.green > 96 && p_rgb.green < 104,(int)p_rgb.green);
+ DLIB_TEST_MSG(p_rgb.blue > 47 && p_rgb.blue < 53,(int)p_rgb.green);
+
+ assign_pixel(p_rgb,0);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST_MSG(p_rgb.red > 251 ,(int)p_rgb.green);
+ DLIB_TEST_MSG(p_rgb.green > 96 && p_rgb.green < 104,(int)p_rgb.green);
+ DLIB_TEST_MSG(p_rgb.blue > 47 && p_rgb.blue < 53,(int)p_rgb.green);
+
+ assign_pixel(p_hsi2, p_hsi);
+ DLIB_TEST(p_hsi.h == p_hsi2.h);
+ DLIB_TEST(p_hsi.s == p_hsi2.s);
+ DLIB_TEST(p_hsi.i == p_hsi2.i);
+ assign_pixel(p_hsi,0);
+ DLIB_TEST(p_hsi.h == 0);
+ DLIB_TEST(p_hsi.s == 0);
+ DLIB_TEST(p_hsi.i == 0);
+ assign_pixel(p_hsi, p_rgba);
+
+ DLIB_TEST(p_hsi.h == p_hsi2.h);
+ DLIB_TEST(p_hsi.s == p_hsi2.s);
+ DLIB_TEST(p_hsi.i == p_hsi2.i);
+
+ assign_pixel(p_lab2, p_lab);
+ DLIB_TEST(p_lab.l == p_lab2.l);
+ DLIB_TEST(p_lab.a == p_lab2.a);
+ DLIB_TEST(p_lab.b == p_lab2.b);
+ assign_pixel(p_lab,0);
+ DLIB_TEST(p_lab.l == 0);
+ DLIB_TEST(p_lab.a == 128);
+ DLIB_TEST(p_lab.b == 128);
+ assign_pixel(p_lab, p_rgba);
+
+ DLIB_TEST(p_lab.l == p_lab2.l);
+ DLIB_TEST(p_lab.a == p_lab2.a);
+ DLIB_TEST(p_lab.b == p_lab2.b);
+
+ assign_pixel(p_rgba, 100);
+ assign_pixel(p_gray, 10);
+ assign_pixel(p_rgb, 10);
+ assign_pixel(p_hsi, 10);
+
+ assign_pixel(p_schar, 10);
+ assign_pixel(p_float, 10);
+ assign_pixel(p_int, 10);
+
+ p_rgba.alpha = 0;
+ assign_pixel(p_gray, p_rgba);
+ DLIB_TEST(p_gray == 10);
+ assign_pixel(p_schar, p_rgba);
+ DLIB_TEST(p_schar == 10);
+ assign_pixel(p_int, p_rgba);
+ DLIB_TEST(p_int == 10);
+ assign_pixel(p_float, p_rgba);
+ DLIB_TEST(p_float == 10);
+ assign_pixel(p_rgb, p_rgba);
+ DLIB_TEST(p_rgb.red == 10);
+ DLIB_TEST(p_rgb.green == 10);
+ DLIB_TEST(p_rgb.blue == 10);
+
+ assign_pixel(p_hsi, p_rgba);
+ assign_pixel(p_hsi2, p_rgb);
+ DLIB_TEST(p_hsi.h == 0);
+ DLIB_TEST(p_hsi.s == 0);
+ DLIB_TEST_MSG(p_hsi.i < p_hsi2.i+2 && p_hsi.i > p_hsi2.i -2,(int)p_hsi.i << " " << (int)p_hsi2.i);
+
+ // this value corresponds to RGB(10,10,10)
+ p_lab.l = 7;
+ p_lab.a = 128;
+ p_lab.b = 128;
+
+ assign_pixel(p_lab, p_rgba);
+ assign_pixel(p_lab2, p_rgb);
+ DLIB_TEST(p_lab.a == 128);
+ DLIB_TEST(p_lab.b == 128);
+ DLIB_TEST_MSG(p_lab.l < p_lab2.l+2 && p_lab.l > p_lab2.l -2,(int)p_lab.l << " " << (int)p_lab2.l);
+
+ assign_pixel(p_lab, 128);
+ DLIB_TEST(p_lab.l == 128);
+ DLIB_TEST(p_lab.a == 128);
+ DLIB_TEST(p_lab.b == 128);
+ assign_pixel(p_rgb, p_lab);
+ //Lab midpoint (50,0,0) is not same as RGB midpoint (127,127,127)
+ DLIB_TEST(p_rgb.red == 119);
+ DLIB_TEST(p_rgb.green == 119);
+ DLIB_TEST(p_rgb.blue == 119);
+
+ //Lab limit values test
+ //red, green, blue, yellow, black, white
+ p_lab.l = 84;
+ p_lab.a = 164;
+ p_lab.b = 56;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 64);
+ DLIB_TEST(p_rgb.blue == 194);
+
+ p_lab.l = 255;
+ p_lab.a = 0;
+ p_lab.b = 0;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 255);
+ DLIB_TEST(p_rgb.blue == 255);
+
+ p_lab.l = 0;
+ p_lab.a = 255;
+ p_lab.b = 0;
+
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 195);
+
+ p_lab.l = 0;
+ p_lab.a = 0;
+ p_lab.b = 255;
+
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 45);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ p_lab.l = 255;
+ p_lab.a = 255;
+ p_lab.b = 0;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 255);
+ DLIB_TEST(p_rgb.green == 139);
+ DLIB_TEST(p_rgb.blue == 255);
+
+ p_lab.l = 0;
+ p_lab.a = 255;
+ p_lab.b = 255;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 132);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ p_lab.l = 255;
+ p_lab.a = 0;
+ p_lab.b = 255;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 255);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ p_lab.l = 255;
+ p_lab.a = 255;
+ p_lab.b = 255;
+ assign_pixel(p_rgb, p_lab);
+ DLIB_TEST(p_rgb.red == 255);
+ DLIB_TEST(p_rgb.green == 70);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ //RGB limit tests
+ p_rgb.red = 0;
+ p_rgb.green = 0;
+ p_rgb.blue = 0;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red < 3);
+ DLIB_TEST(p_rgb2.green < 3);
+ DLIB_TEST(p_rgb2.blue < 3);
+
+ p_rgb.red = 255;
+ p_rgb.green = 0;
+ p_rgb.blue = 0;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red > 252);
+ DLIB_TEST(p_rgb2.green < 3);
+ DLIB_TEST(p_rgb2.blue < 3);
+
+ p_rgb.red = 0;
+ p_rgb.green = 255;
+ p_rgb.blue = 0;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red < 8);
+ DLIB_TEST(p_rgb2.green > 252);
+ DLIB_TEST(p_rgb2.blue < 5);
+
+ p_rgb.red = 0;
+ p_rgb.green = 0;
+ p_rgb.blue = 255;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red < 3);
+ DLIB_TEST(p_rgb2.green < 3);
+ DLIB_TEST(p_rgb2.blue > 252);
+
+ p_rgb.red = 255;
+ p_rgb.green = 255;
+ p_rgb.blue = 0;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red > 252);
+ DLIB_TEST(p_rgb2.green > 252);
+ DLIB_TEST(p_rgb2.blue < 9);
+
+ p_rgb.red = 0;
+ p_rgb.green = 255;
+ p_rgb.blue = 255;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red < 5);
+ DLIB_TEST(p_rgb2.green > 252);
+ DLIB_TEST(p_rgb2.blue > 252);
+
+ p_rgb.red = 255;
+ p_rgb.green = 0;
+ p_rgb.blue = 255;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red> 252);
+ DLIB_TEST(p_rgb2.green < 6);
+ DLIB_TEST(p_rgb2.blue > 252);
+
+ p_rgb.red = 255;
+ p_rgb.green = 255;
+ p_rgb.blue = 255;
+ assign_pixel(p_lab, p_rgb);
+ assign_pixel(p_rgb2, p_lab);
+ DLIB_TEST(p_rgb2.red > 252 );
+ DLIB_TEST(p_rgb2.green> 252);
+ DLIB_TEST(p_rgb2.blue > 252);
+
+
+ assign_pixel(p_rgba, 100);
+ assign_pixel(p_gray, 10);
+ assign_pixel(p_schar, 10);
+ assign_pixel(p_float, 10);
+ assign_pixel(p_int, 10);
+
+ assign_pixel(p_rgb, 10);
+ p_rgba.alpha = 128;
+ assign_pixel(p_gray, p_rgba);
+ assign_pixel(p_schar, p_rgba);
+ assign_pixel(p_float, p_rgba);
+ assign_pixel(p_int, p_rgba);
+ assign_pixel(p_rgb, p_rgba);
+ DLIB_TEST(p_gray == (100 + 10)/2);
+ DLIB_TEST(p_schar == (100 + 10)/2);
+ DLIB_TEST(p_int == (100 + 10)/2);
+ DLIB_TEST(p_float == (100 + 10)/2);
+ DLIB_TEST(p_rgb.red == (100 + 10)/2);
+ DLIB_TEST(p_rgb.green == (100 + 10)/2);
+ DLIB_TEST(p_rgb.blue == (100 + 10)/2);
+
+ assign_pixel(p_rgba, 100);
+ assign_pixel(p_gray, 10);
+ assign_pixel(p_schar, 10);
+ assign_pixel(p_int, 10);
+ assign_pixel(p_float, 10);
+ assign_pixel(p_rgb, 10);
+ DLIB_TEST(p_rgba.alpha == 255);
+ assign_pixel(p_gray, p_rgba);
+ assign_pixel(p_schar, p_rgba);
+ assign_pixel(p_int, p_rgba);
+ assign_pixel(p_float, p_rgba);
+ assign_pixel(p_rgb, p_rgba);
+ DLIB_TEST(p_gray == 100);
+ DLIB_TEST(p_schar == 100);
+ DLIB_TEST(p_int == 100);
+ DLIB_TEST(p_float == 100);
+ DLIB_TEST(p_rgb.red == 100);
+ DLIB_TEST(p_rgb.green == 100);
+ DLIB_TEST(p_rgb.blue == 100);
+
+
+ p_rgb.red = 1;
+ p_rgb.green = 2;
+ p_rgb.blue = 3;
+
+ p_rgba.red = 4;
+ p_rgba.green = 5;
+ p_rgba.blue = 6;
+ p_rgba.alpha = 7;
+
+ p_gray = 8;
+ p_schar = 8;
+ p_int = 8;
+ p_float = 8;
+
+ p_hsi.h = 9;
+ p_hsi.s = 10;
+ p_hsi.i = 11;
+
+ p_lab.l = 10;
+ p_lab.a = 9;
+ p_lab.b = 8;
+
+ ostringstream sout;
+ serialize(p_rgb,sout);
+ serialize(p_rgba,sout);
+ serialize(p_gray,sout);
+ serialize(p_schar,sout);
+ serialize(p_int,sout);
+ serialize(p_float,sout);
+ serialize(p_hsi,sout);
+ serialize(p_lab,sout);
+
+ assign_pixel(p_rgb,0);
+ assign_pixel(p_rgba,0);
+ assign_pixel(p_gray,0);
+ assign_pixel(p_schar,0);
+ assign_pixel(p_int,0);
+ assign_pixel(p_float,0);
+ assign_pixel(p_hsi,0);
+ assign_pixel(p_lab,0);
+
+ istringstream sin(sout.str());
+
+ deserialize(p_rgb,sin);
+ deserialize(p_rgba,sin);
+ deserialize(p_gray,sin);
+ deserialize(p_schar,sin);
+ deserialize(p_int,sin);
+ deserialize(p_float,sin);
+ deserialize(p_hsi,sin);
+ deserialize(p_lab,sin);
+
+ DLIB_TEST(p_rgb.red == 1);
+ DLIB_TEST(p_rgb.green == 2);
+ DLIB_TEST(p_rgb.blue == 3);
+
+ DLIB_TEST(p_rgba.red == 4);
+ DLIB_TEST(p_rgba.green == 5);
+ DLIB_TEST(p_rgba.blue == 6);
+ DLIB_TEST(p_rgba.alpha == 7);
+
+ DLIB_TEST(p_gray == 8);
+ DLIB_TEST(p_schar == 8);
+ DLIB_TEST(p_int == 8);
+ DLIB_TEST(p_float == 8);
+
+ DLIB_TEST(p_hsi.h == 9);
+ DLIB_TEST(p_hsi.s == 10);
+ DLIB_TEST(p_hsi.i == 11);
+
+ DLIB_TEST(p_lab.l == 10);
+ DLIB_TEST(p_lab.a == 9);
+ DLIB_TEST(p_lab.b == 8);
+
+ {
+ matrix<double,1,1> m_gray, m_schar, m_int, m_float;
+ matrix<double,3,1> m_rgb, m_hsi, m_lab;
+
+ m_gray = pixel_to_vector<double>(p_gray);
+ m_schar = pixel_to_vector<double>(p_schar);
+ m_int = pixel_to_vector<double>(p_int);
+ m_float = pixel_to_vector<double>(p_float);
+
+ m_hsi = pixel_to_vector<double>(p_hsi);
+ m_rgb = pixel_to_vector<double>(p_rgb);
+ m_lab = pixel_to_vector<double>(p_lab);
+
+ DLIB_TEST(m_gray(0) == p_gray);
+ DLIB_TEST(m_float(0) == p_float);
+ DLIB_TEST(m_int(0) == p_int);
+ DLIB_TEST(m_schar(0) == p_schar);
+
+ DLIB_TEST(m_rgb(0) == p_rgb.red);
+ DLIB_TEST(m_rgb(1) == p_rgb.green);
+ DLIB_TEST(m_rgb(2) == p_rgb.blue);
+ DLIB_TEST(m_hsi(0) == p_hsi.h);
+ DLIB_TEST(m_hsi(1) == p_hsi.s);
+ DLIB_TEST(m_hsi(2) == p_hsi.i);
+ DLIB_TEST(m_lab(0) == p_lab.l);
+ DLIB_TEST(m_lab(1) == p_lab.a);
+ DLIB_TEST(m_lab(2) == p_lab.b);
+
+ DLIB_TEST(p_rgb.red == 1);
+ DLIB_TEST(p_rgb.green == 2);
+ DLIB_TEST(p_rgb.blue == 3);
+
+ DLIB_TEST(p_rgba.red == 4);
+ DLIB_TEST(p_rgba.green == 5);
+ DLIB_TEST(p_rgba.blue == 6);
+ DLIB_TEST(p_rgba.alpha == 7);
+
+ DLIB_TEST(p_gray == 8);
+ DLIB_TEST(p_int == 8);
+ DLIB_TEST(p_float == 8);
+ DLIB_TEST(p_schar == 8);
+
+ DLIB_TEST(p_hsi.h == 9);
+ DLIB_TEST(p_hsi.s == 10);
+ DLIB_TEST(p_hsi.i == 11);
+
+ DLIB_TEST(p_lab.l == 10);
+ DLIB_TEST(p_lab.a == 9);
+ DLIB_TEST(p_lab.b == 8);
+
+ assign_pixel(p_gray,0);
+ assign_pixel(p_hsi,0);
+ assign_pixel(p_rgb,0);
+ assign_pixel(p_lab,0);
+
+ vector_to_pixel(p_gray, m_gray);
+ vector_to_pixel(p_hsi, m_hsi);
+ vector_to_pixel(p_rgb, m_rgb);
+ vector_to_pixel(p_lab, m_lab);
+
+ DLIB_TEST(p_rgb.red == 1);
+ DLIB_TEST(p_rgb.green == 2);
+ DLIB_TEST(p_rgb.blue == 3);
+
+ DLIB_TEST(p_rgba.red == 4);
+ DLIB_TEST(p_rgba.green == 5);
+ DLIB_TEST(p_rgba.blue == 6);
+ DLIB_TEST(p_rgba.alpha == 7);
+
+ DLIB_TEST(p_gray == 8);
+
+ DLIB_TEST(p_hsi.h == 9);
+ DLIB_TEST(p_hsi.s == 10);
+ DLIB_TEST(p_hsi.i == 11);
+
+ DLIB_TEST(p_lab.l == 10);
+ DLIB_TEST(p_lab.a == 9);
+ DLIB_TEST(p_lab.b == 8);
+ }
+
+
+
+
+ {
+ unsigned char p_gray;
+ unsigned short p_gray16;
+ long p_int;
+ float p_float;
+ signed char p_schar;
+ rgb_pixel p_rgb;
+ hsi_pixel p_hsi, p_hsi2;
+ rgb_alpha_pixel p_rgba;
+ lab_pixel p_lab;
+
+
+ assign_pixel(p_gray, 0);
+ assign_pixel(p_gray16, 0);
+ assign_pixel(p_int, 0);
+ assign_pixel(p_float, 0);
+ assign_pixel(p_schar, 0);
+ assign_pixel(p_rgb, 0);
+ assign_pixel(p_hsi, 0);
+ assign_pixel(p_lab, 0);
+
+
+ assign_pixel(p_gray, 100);
+ assign_pixel(p_schar, p_gray);
+ DLIB_TEST(p_schar == 100);
+
+ assign_pixel(p_gray, 200);
+ assign_pixel(p_schar, p_gray);
+ DLIB_TEST(p_schar == std::numeric_limits<signed char>::max());
+
+ assign_pixel(p_int, p_gray);
+ DLIB_TEST(p_int == 200);
+
+ assign_pixel(p_float, p_gray);
+ DLIB_TEST(p_float == 200);
+
+ assign_pixel(p_rgb, p_float);
+ DLIB_TEST(p_rgb.red == 200);
+ DLIB_TEST(p_rgb.green == 200);
+ DLIB_TEST(p_rgb.blue == 200);
+
+ p_schar = 0;
+ assign_pixel(p_schar, p_rgb);
+ DLIB_TEST(p_schar == std::numeric_limits<signed char>::max());
+
+
+ p_schar = -10;
+ assign_pixel(p_float, p_schar);
+ DLIB_TEST(p_float == -10);
+ assign_pixel(p_int, p_schar);
+ DLIB_TEST(p_int == -10);
+ assign_pixel(p_schar, p_schar);
+ DLIB_TEST(p_schar == -10);
+ assign_pixel(p_gray, p_schar);
+ DLIB_TEST(p_gray == 0);
+
+ assign_pixel(p_rgb, p_schar);
+ DLIB_TEST(p_rgb.red == 0);
+ DLIB_TEST(p_rgb.green == 0);
+ DLIB_TEST(p_rgb.blue == 0);
+
+ assign_pixel(p_gray16, p_schar);
+ DLIB_TEST(p_gray16 == 0);
+
+ DLIB_TEST(get_pixel_intensity(p_float) == -10);
+ DLIB_TEST(get_pixel_intensity(p_int) == -10);
+ DLIB_TEST(get_pixel_intensity(p_schar) == -10);
+ DLIB_TEST(get_pixel_intensity(p_rgb) == 0);
+ DLIB_TEST(get_pixel_intensity(p_gray16) == 0);
+
+ p_rgb.red = 100;
+ p_rgb.green = 100;
+ p_rgb.blue = 100;
+ DLIB_TEST(get_pixel_intensity(p_rgb) == 100);
+ p_rgb.red = 1;
+ p_rgb.green = 2;
+ p_rgb.blue = 3;
+ DLIB_TEST(get_pixel_intensity(p_rgb) == 2);
+ p_rgba.alpha = 100;
+ p_rgba.red = 100;
+ p_rgba.green = 100;
+ p_rgba.blue = 100;
+ DLIB_TEST(get_pixel_intensity(p_rgba) == 100);
+ p_rgba.red = 1;
+ p_rgba.green = 2;
+ p_rgba.blue = 3;
+ p_rgba.alpha = 0;
+ DLIB_TEST(get_pixel_intensity(p_rgba) == 2);
+ p_hsi.h = 123;
+ p_hsi.s = 100;
+ p_hsi.i = 84;
+ DLIB_TEST(get_pixel_intensity(p_hsi) == 84);
+
+ p_lab.l = 123;
+ p_lab.a = 100;
+ p_lab.b = 84;
+ DLIB_TEST(get_pixel_intensity(p_lab) == 123);
+
+ p_float = 54.25;
+ DLIB_TEST(get_pixel_intensity(p_float) == 54.25);
+
+ assign_pixel(p_gray, p_float);
+ DLIB_TEST(get_pixel_intensity(p_gray) == 54);
+
+ assign_pixel_intensity(p_float, -1000);
+ assign_pixel_intensity(p_schar, -100);
+ assign_pixel_intensity(p_int, -10000);
+ assign_pixel_intensity(p_gray, -100);
+
+ p_rgba.red = 10;
+ p_rgba.green = 10;
+ p_rgba.blue = 10;
+ p_rgba.alpha = 0;
+ DLIB_TEST_MSG(get_pixel_intensity(p_rgba) == 10, (int)get_pixel_intensity(p_rgba));
+ assign_pixel_intensity(p_rgba, 2);
+ DLIB_TEST_MSG(p_rgba.red == 2, (int)p_rgba.red);
+ DLIB_TEST_MSG(p_rgba.green == 2, (int)p_rgba.green);
+ DLIB_TEST_MSG(p_rgba.blue == 2, (int)p_rgba.blue);
+ DLIB_TEST_MSG(p_rgba.alpha == 0, (int)p_rgba.alpha);
+ DLIB_TEST_MSG(get_pixel_intensity(p_rgba) == 2, (int)get_pixel_intensity(p_rgba));
+
+ DLIB_TEST(p_float == -1000);
+ DLIB_TEST(get_pixel_intensity(p_float) == -1000);
+ DLIB_TEST(p_schar == -100);
+ DLIB_TEST(get_pixel_intensity(p_schar) == -100);
+ DLIB_TEST(p_int == -10000);
+ DLIB_TEST(get_pixel_intensity(p_int) == -10000);
+ DLIB_TEST(p_gray == 0);
+ assign_pixel_intensity(p_gray, 1000);
+ DLIB_TEST(p_gray == 255);
+ DLIB_TEST(get_pixel_intensity(p_gray) == 255);
+
+ assign_pixel_intensity(p_float, p_gray);
+ DLIB_TEST(p_float == 255);
+ DLIB_TEST(get_pixel_intensity(p_float) == 255);
+
+ assign_pixel_intensity(p_int, p_gray);
+ DLIB_TEST(p_int == 255);
+ DLIB_TEST(get_pixel_intensity(p_int) == 255);
+
+
+ p_float = 1e10;
+ assign_pixel(p_schar, p_float);
+ DLIB_TEST(p_schar == std::numeric_limits<signed char>::max());
+
+ p_float = -1e10;
+ assign_pixel(p_schar, p_float);
+ DLIB_TEST(p_schar == std::numeric_limits<signed char>::min());
+
+ double p_double = 1e200;
+ assign_pixel(p_float, p_double);
+ DLIB_TEST(p_float == std::numeric_limits<float>::max());
+
+ p_double = -1e200;
+ assign_pixel(p_float, p_double);
+ DLIB_TEST(p_float == -std::numeric_limits<float>::max());
+ }
+
+
+ }
+
+
+
+
+ class pixel_tester : public tester
+ {
+ public:
+ pixel_tester (
+ ) :
+ tester ("test_pixel",
+ "Runs tests on the pixel objects and functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ pixel_test();
+ }
+ } a;
+
+}
diff --git a/ml/dlib/dlib/test/probabilistic.cpp b/ml/dlib/dlib/test/probabilistic.cpp
new file mode 100644
index 000000000..e8a24829a
--- /dev/null
+++ b/ml/dlib/dlib/test/probabilistic.cpp
@@ -0,0 +1,123 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include "checkerboard.h"
+#include <dlib/statistics.h>
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.probabilistic");
+
+// ----------------------------------------------------------------------------------------
+
+ class test_probabilistic : public tester
+ {
+ public:
+ test_probabilistic (
+ ) :
+ tester ("test_probabilistic",
+ "Runs tests on the probabilistic trainer adapter.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+
+
+ typedef double scalar_type;
+ typedef matrix<scalar_type,2,1> sample_type;
+
+ std::vector<sample_type> x;
+ std::vector<matrix<double,0,1> > x_linearized;
+ std::vector<scalar_type> y;
+
+ get_checkerboard_problem(x,y, 1000, 2);
+
+ random_subset_selector<sample_type> rx;
+ random_subset_selector<scalar_type> ry;
+ rx.set_max_size(x.size());
+ ry.set_max_size(x.size());
+
+ dlog << LINFO << "pos labels: "<< sum(mat(y) == +1);
+ dlog << LINFO << "neg labels: "<< sum(mat(y) == -1);
+
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ rx.add(x[i]);
+ ry.add(y[i]);
+ }
+
+ const scalar_type gamma = 2.0;
+
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ krr_trainer<kernel_type> krr_trainer;
+ krr_trainer.use_classification_loss_for_loo_cv();
+ krr_trainer.set_kernel(kernel_type(gamma));
+ krr_trainer.set_basis(randomly_subsample(x, 100));
+ probabilistic_decision_function<kernel_type> df;
+
+ dlog << LINFO << "cross validation: " << cross_validate_trainer(krr_trainer, rx,ry, 4);
+ print_spinner();
+
+ running_stats<scalar_type> rs_pos, rs_neg;
+
+ print_spinner();
+ df = probabilistic(krr_trainer,3).train(x, y);
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ if (y[i] > 0)
+ rs_pos.add(df(x[i]));
+ else
+ rs_neg.add(df(x[i]));
+ }
+ dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean();
+ dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean();
+ DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean());
+ DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean());
+ rs_pos.clear();
+ rs_neg.clear();
+
+
+ print_spinner();
+ df = probabilistic(krr_trainer,3).train(rx, ry);
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ if (y[i] > 0)
+ rs_pos.add(df(x[i]));
+ else
+ rs_neg.add(df(x[i]));
+ }
+ dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean();
+ dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean();
+ DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean());
+ DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean());
+ rs_pos.clear();
+ rs_neg.clear();
+
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/pyramid_down.cpp b/ml/dlib/dlib/test/pyramid_down.cpp
new file mode 100644
index 000000000..c026a8162
--- /dev/null
+++ b/ml/dlib/dlib/test/pyramid_down.cpp
@@ -0,0 +1,424 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/image_transforms.h>
+//#include <dlib/gui_widgets.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.pyramid_down");
+
+// ----------------------------------------------------------------------------------------
+
+void test_pyramid_down_grayscale()
+{
+ array2d<unsigned char> img, down;
+ pyramid_down<2> pyr;
+
+ img.set_size(300,264);
+
+ assign_all_pixels(img, 10);
+
+ pyr(img, down);
+
+ DLIB_TEST(std::abs(down.nr()*2 - img.nr()) < 5);
+ DLIB_TEST(std::abs(down.nc()*2 - img.nc()) < 5);
+
+ rectangle rect1 = get_rect(img);
+ rectangle rect2 = pyr.rect_up(get_rect(down));
+ double overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area();
+ DLIB_TEST(overlap > 0.95);
+
+ rect1 = get_rect(down);
+ rect2 = pyr.rect_down(get_rect(img));
+ overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area();
+ DLIB_TEST(overlap > 0.95);
+
+ DLIB_TEST(min(mat(down)) == 10);
+ DLIB_TEST(max(mat(down)) == 10);
+}
+
+void test_pyramid_down_rgb()
+{
+ array2d<rgb_pixel> img;
+ array2d<bgr_pixel> down;
+ pyramid_down<2> pyr;
+
+ img.set_size(231, 351);
+
+ assign_all_pixels(img, rgb_pixel(1,2,3));
+
+ pyr(img, down);
+
+ DLIB_TEST(std::abs(down.nr()*2 - img.nr()) < 5);
+ DLIB_TEST(std::abs(down.nc()*2 - img.nc()) < 5);
+
+ rectangle rect1 = get_rect(img);
+ rectangle rect2 = pyr.rect_up(get_rect(down));
+ double overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area();
+ DLIB_TEST(overlap > 0.95);
+
+ rect1 = get_rect(down);
+ rect2 = pyr.rect_down(get_rect(img));
+ overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area();
+ DLIB_TEST(overlap > 0.95);
+
+ bool pixels_match = true;
+ for (long r = 0; r < down.nr(); ++r)
+ {
+ for (long c = 0; c < down.nc(); ++c)
+ {
+ if (down[r][c].red != 1 ||
+ down[r][c].green != 2 ||
+ down[r][c].blue != 3 )
+ {
+ pixels_match = false;
+ }
+ }
+ }
+ DLIB_TEST(pixels_match);
+}
+
+// ----------------------------------------------------------------------------
+
+template <typename image_type>
+rgb_pixel mean_pixel (
+ const image_type& img,
+ const rectangle& rect
+)
+{
+ long red = 0;
+ long green = 0;
+ long blue = 0;
+ for (long r = rect.top(); r <= rect.bottom(); ++r)
+ {
+ for (long c = rect.left(); c <= rect.right(); ++c)
+ {
+ red += img[r][c].red;
+ green += img[r][c].green;
+ blue += img[r][c].blue;
+ }
+ }
+
+ const long n = rect.area();
+ return rgb_pixel(red/n, green/n, blue/n);
+}
+
+// ----------------------------------------------------------------------------
+
+template <typename pyramid_down_type>
+void test_pyramid_down_rgb2()
+{
+ array2d<rgb_pixel> img, img3;
+ array2d<unsigned char> img2, img4;
+
+
+ img.set_size(300,400);
+ assign_all_pixels(img, 0);
+ rectangle rect1 = centered_rect( 10,10, 14, 14);
+ rectangle rect2 = centered_rect( 100,100, 34, 42);
+ rectangle rect3 = centered_rect( 310,215, 65, 21);
+
+ fill_rect(img, rect1, rgb_pixel(255,0,0));
+ fill_rect(img, rect2, rgb_pixel(0,255,0));
+ fill_rect(img, rect3, rgb_pixel(0,0,255));
+
+
+
+ pyramid_down_type pyr;
+
+ pyr(img, img2);
+ pyr(img, img3);
+
+
+ DLIB_TEST(((rect1.tl_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect1.br_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).br_corner()).length()) < 1);
+ DLIB_TEST(((rect2.tl_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect2.br_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).br_corner()).length()) < 1);
+ DLIB_TEST(((rect3.tl_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect3.br_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).br_corner()).length()) < 1);
+
+ rect1 = shrink_rect(pyr.rect_down(rect1),1);
+ rect2 = shrink_rect(pyr.rect_down(rect2),1);
+ rect3 = shrink_rect(pyr.rect_down(rect3),1);
+
+ DLIB_TEST(rect1.area() > 10);
+ DLIB_TEST(rect2.area() > 10);
+ DLIB_TEST(rect3.area() > 10);
+
+ /*
+ image_window my_window(img);
+ image_window win2(img2);
+ image_window win3(img3);
+ win2.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0)));
+ win2.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0)));
+ win2.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0)));
+ win3.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0)));
+ win3.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0)));
+ win3.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0)));
+ */
+
+
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect1)) - 255/3) < 3);
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect2)) - 255/3) < 3);
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect3)) - 255/3) < 3);
+ assign_image(img4, img);
+ DLIB_TEST(std::abs((int)mean(mat(img4)) - mean(mat(img2))) < 2);
+
+
+ rgb_pixel mean1 = mean_pixel(img3, rect1);
+ rgb_pixel mean2 = mean_pixel(img3, rect2);
+ rgb_pixel mean3 = mean_pixel(img3, rect3);
+ rgb_pixel mean_all_true = mean_pixel(img, get_rect(img));
+ rgb_pixel mean_all = mean_pixel(img3, get_rect(img3));
+ DLIB_TEST(mean1.red > 250);
+ DLIB_TEST(mean1.green < 3);
+ DLIB_TEST(mean1.blue < 3);
+
+ DLIB_TEST(mean2.red < 3);
+ DLIB_TEST(mean2.green > 250);
+ DLIB_TEST(mean2.blue < 3);
+
+ DLIB_TEST(mean3.red < 3);
+ DLIB_TEST(mean3.green < 3);
+ DLIB_TEST(mean3.blue > 250);
+
+ DLIB_TEST(std::abs((int)mean_all_true.red - mean_all.red) < 1);
+ DLIB_TEST(std::abs((int)mean_all_true.green - mean_all.green) < 1);
+ DLIB_TEST(std::abs((int)mean_all_true.blue - mean_all.blue) < 1);
+
+ //my_window.wait_until_closed();
+}
+
+
+// ----------------------------------------------------------------------------------------
+
+template <typename pyramid_down_type>
+void test_pyramid_down_grayscale2()
+{
+ array2d<unsigned char> img;
+ array2d<unsigned char> img2, img4;
+
+
+ img.set_size(300,400);
+ assign_all_pixels(img, 0);
+ rectangle rect1 = centered_rect( 10,10, 14, 14);
+ rectangle rect2 = centered_rect( 100,100, 34, 42);
+ rectangle rect3 = centered_rect( 310,215, 65, 21);
+
+ fill_rect(img, rect1, 255);
+ fill_rect(img, rect2, 170);
+ fill_rect(img, rect3, 100);
+
+
+
+ pyramid_down_type pyr;
+
+ pyr(img, img2);
+
+
+ DLIB_TEST(((rect1.tl_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect1.br_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).br_corner()).length()) < 1);
+ DLIB_TEST(((rect2.tl_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect2.br_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).br_corner()).length()) < 1);
+ DLIB_TEST(((rect3.tl_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).tl_corner()).length()) < 1);
+ DLIB_TEST(((rect3.br_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).br_corner()).length()) < 1);
+
+ rect1 = shrink_rect(pyr.rect_down(rect1),1);
+ rect2 = shrink_rect(pyr.rect_down(rect2),1);
+ rect3 = shrink_rect(pyr.rect_down(rect3),1);
+
+ DLIB_TEST(rect1.area() > 10);
+ DLIB_TEST(rect2.area() > 10);
+ DLIB_TEST(rect3.area() > 10);
+
+ /*
+ image_window my_window(img);
+ image_window win2(img2);
+ win2.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0)));
+ win2.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0)));
+ win2.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0)));
+ */
+
+
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect1)) - 255) <= 3);
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect2)) - 170) < 3);
+ DLIB_TEST(std::abs((int)mean(subm(matrix_cast<long>(mat(img2)),rect3)) - 100) < 3);
+ assign_image(img4, img);
+ DLIB_TEST(std::abs((int)mean(mat(img4)) - mean(mat(img2))) < 2);
+
+
+ //my_window.wait_until_closed();
+
+
+
+ // make sure the coordinate mapping is invertible when it should be
+ for (int l = 0; l < 4; ++l)
+ {
+ for (long x = -10; x <= 10; ++x)
+ {
+ for (long y = -10; y <= 10; ++y)
+ {
+ DLIB_TEST_MSG(point(pyr.point_down(pyr.point_up(point(x,y),l),l)) == point(x,y),
+ point(x,y) << " " << pyr.point_up(point(x,y),l) << " " << pyr.point_down(pyr.point_up(point(x,y),l),l));
+ DLIB_TEST_MSG(point(pyr.point_down(point(pyr.point_up(point(x,y),l)),l)) == point(x,y),
+ point(x,y) << " " << pyr.point_up(point(x,y),l) << " " << pyr.point_down(point(pyr.point_up(point(x,y),l)),l));
+ }
+ }
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+template <typename pyramid_down_type>
+void test_pyr_sizes()
+{
+ dlib::rand rnd;
+
+ for (int iter = 0; iter < 20; ++iter)
+ {
+ long nr = rnd.get_random_32bit_number()%10+40;
+ long nc = rnd.get_random_32bit_number()%10+40;
+
+ array2d<unsigned char> img(nr,nc), img2;
+ assign_all_pixels(img,0);
+
+ pyramid_down_type pyr;
+
+ pyr(img, img2);
+ find_pyramid_down_output_image_size(pyr, nr, nc);
+ DLIB_TEST(img2.nr() == nr);
+ DLIB_TEST(img2.nc() == nc);
+ }
+}
+
+
+// ----------------------------------------------------------------------------------------
+
+template <typename pyramid_down_type>
+void test_pyramid_down_small_sizes()
+{
+ print_spinner();
+ // just make sure it doesn't get messed up with small images. This test
+ // is only really useful if asserts are enabled.
+ pyramid_down_type pyr;
+
+ for (int size = 0; size < 20; ++size)
+ {
+ array2d<unsigned char> img1(size,size);
+ array2d<rgb_pixel> img2(size,size);
+
+ array2d<unsigned char> out1;
+ array2d<rgb_pixel> out2;
+
+ assign_all_pixels(img1, 0);
+ assign_all_pixels(img2, 0);
+
+ pyr(img1, out1);
+ pyr(img2, out2);
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+
+ class test_pyramid_down : public tester
+ {
+ public:
+ test_pyramid_down (
+ ) :
+ tester ("test_pyramid_down",
+ "Runs tests on the pyramid_down() function.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ test_pyramid_down_grayscale();
+ print_spinner();
+ test_pyramid_down_rgb();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_down<2> >();";
+ test_pyramid_down_small_sizes<pyramid_down<2> >();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_down<3> >();";
+ test_pyramid_down_small_sizes<pyramid_down<3> >();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_down<4> >();";
+ test_pyramid_down_small_sizes<pyramid_down<4> >();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_down<5> >();";
+ test_pyramid_down_small_sizes<pyramid_down<5> >();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_disable>();";
+ test_pyramid_down_small_sizes<pyramid_disable>();
+ dlog << LINFO << "call test_pyramid_down_small_sizes<pyramid_down<9> >();";
+ test_pyramid_down_small_sizes<pyramid_down<9> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_rgb2<pyramid_down<2> >();";
+ test_pyramid_down_rgb2<pyramid_down<2> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_rgb2<pyramid_down<3> >();";
+ test_pyramid_down_rgb2<pyramid_down<3> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_rgb2<pyramid_down<4> >();";
+ test_pyramid_down_rgb2<pyramid_down<4> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_rgb2<pyramid_down<5> >();";
+ test_pyramid_down_rgb2<pyramid_down<5> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_rgb2<pyramid_down<8> >();";
+ test_pyramid_down_rgb2<pyramid_down<8> >();
+
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_grayscale2<pyramid_down<2> >();";
+ test_pyramid_down_grayscale2<pyramid_down<2> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_grayscale2<pyramid_down<3> >();";
+ test_pyramid_down_grayscale2<pyramid_down<3> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_grayscale2<pyramid_down<4> >();";
+ test_pyramid_down_grayscale2<pyramid_down<4> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_grayscale2<pyramid_down<5> >();";
+ test_pyramid_down_grayscale2<pyramid_down<5> >();
+
+ print_spinner();
+ dlog << LINFO << "call test_pyramid_down_grayscale2<pyramid_down<6> >();";
+ test_pyramid_down_grayscale2<pyramid_down<6> >();
+
+
+ test_pyr_sizes<pyramid_down<1>>();
+ test_pyr_sizes<pyramid_down<2>>();
+ test_pyr_sizes<pyramid_down<3>>();
+ test_pyr_sizes<pyramid_down<4>>();
+ test_pyr_sizes<pyramid_down<5>>();
+ test_pyr_sizes<pyramid_down<6>>();
+ test_pyr_sizes<pyramid_down<7>>();
+ test_pyr_sizes<pyramid_down<8>>();
+ test_pyr_sizes<pyramid_down<28>>();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/queue.cpp b/ml/dlib/dlib/test/queue.cpp
new file mode 100644
index 000000000..efcdf6054
--- /dev/null
+++ b/ml/dlib/dlib/test/queue.cpp
@@ -0,0 +1,426 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/queue.h>
+#include <dlib/memory_manager_global.h>
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything inside this file "private"
+// so that everything you declare will have static linkage. Thus we won't have any multiply
+// defined symbol errors coming out of the linker when we try to compile the test suite.
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.queue");
+
+ template <
+ typename queue
+ >
+ void queue_sort_test (
+ )
+ /*!
+ requires
+ - queue is an implementation of queue/queue_sort_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on queue for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+ srand(static_cast<unsigned int>(time(0)));
+
+ queue q,q2;
+
+ enumerable<int>& e = q;
+
+ // I will use these DLIB_TEST_MSG macros to assert that conditions are true. If they are
+ // false then it means we have detected an error in the queue object. CASSERT
+ // will then throw an exception which we will catch at the end of this function and
+ // report as an error/failed test.
+ DLIB_TEST(e.at_start() == true);
+
+ int a = 0;
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == true);
+ DLIB_TEST(q.current_element_valid() == false);
+
+ q.sort();
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == true);
+ DLIB_TEST(q.current_element_valid() == false);
+
+ DLIB_TEST (q.move_next() == false);
+ DLIB_TEST (q.move_next() == false);
+ DLIB_TEST (q.move_next() == false);
+ DLIB_TEST (q.move_next() == false);
+ DLIB_TEST (q.move_next() == false);
+ DLIB_TEST (q.move_next() == false);
+
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == false);
+ DLIB_TEST(q.current_element_valid() == false);
+
+
+ q.reset();
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == true);
+ DLIB_TEST(q.current_element_valid() == false);
+
+
+
+
+
+
+
+
+
+
+
+ q.clear();
+ q2.clear();
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 0);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ int a = i;
+ q.enqueue(a);
+ }
+
+ q2.cat(q);
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 10000);
+
+ int g = 0;
+ while (q2.move_next())
+ {
+ DLIB_TEST_MSG(q2.element() == g,g);
+ ++g;
+ }
+
+ for (int i = 0;i < 10000; ++i)
+ {
+ int a = 0;
+ q2.dequeue(a);
+ DLIB_TEST(a == i);
+ }
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 0);
+ q.clear();
+ q2.clear();
+
+
+
+
+ print_spinner();
+
+
+ dlog << LTRACE << "creating big pre-sorted queue";
+ q.clear();
+ DLIB_TEST(q.size() == 0);
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ int a = i;
+ q.enqueue(a);
+ }
+
+ dlog << LTRACE << "sorting already sorted queue";
+ q.sort();
+
+
+ dlog << LTRACE << "done sorting, checking the results";
+ for (int i = 0; i < 10000; ++i)
+ {
+ q.dequeue(a);
+ DLIB_TEST(a == i);
+ }
+
+
+ q.clear();
+ dlog << LTRACE << "done with the big pre-sorted queue test";
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ q.clear();
+ q2.clear();
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 0);
+
+ for (int i = 0; i < 1; ++i)
+ {
+ int a = i;
+ q.enqueue(a);
+ }
+
+ q2.cat(q);
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 1);
+
+
+
+ g = 0;
+ while (q2.move_next())
+ {
+ DLIB_TEST_MSG(q2.element() == g,g);
+ ++g;
+ }
+
+ for (int i = 0;i < 1; ++i)
+ {
+ int a = 0;
+ q2.dequeue(a);
+ DLIB_TEST(a == i);
+ }
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q2.size() == 0);
+ q.clear();
+ q2.clear();
+
+
+
+
+
+
+
+ print_spinner();
+
+
+
+
+
+
+
+
+
+
+
+ for (int j = 0; j < 3; ++j)
+ {
+ for (int i = 0; i < 10000; ++i)
+ {
+ a = ::rand();
+ q.enqueue(a);
+ }
+
+ while (q.move_next()) ;
+
+ DLIB_TEST(q.at_start() == false);
+
+ q.sort();
+
+ DLIB_TEST(q.at_start() == true);
+
+ // serialize the state of q, then clear q, then
+ // load the state back into q.
+ ostringstream sout;
+ serialize(q,sout);
+ DLIB_TEST(q.at_start() == true);
+ istringstream sin(sout.str());
+ q.clear();
+ deserialize(q,sin);
+
+
+ DLIB_TEST(q.at_start() == true);
+
+ a = 0;
+ int last = 0;
+ while (q.move_next())
+ {
+ ++a;
+ DLIB_TEST_MSG(last <= q.element(),"items weren't actually sorted");
+ last = q.element();
+ DLIB_TEST(q.current_element_valid() == true);
+ DLIB_TEST(q.at_start() == false);
+ DLIB_TEST(q.current_element_valid() == true);
+
+
+ }
+ DLIB_TEST_MSG(a == 10000,"some items were lost between the sorting and iterating");
+
+
+ DLIB_TEST(q.size() == 10000);
+ swap(q,q2);
+ DLIB_TEST(q2.at_start() == false);
+ DLIB_TEST(q2.current_element_valid() == false);
+
+ DLIB_TEST (q2.move_next() == false);
+ DLIB_TEST (q2.move_next() == false);
+ DLIB_TEST (q2.move_next() == false);
+ DLIB_TEST (q2.move_next() == false);
+ DLIB_TEST (q2.move_next() == false);
+ DLIB_TEST (q2.move_next() == false);
+
+
+ DLIB_TEST(q2.size() == 10000);
+ DLIB_TEST(q2.at_start() == false);
+ DLIB_TEST(q2.current_element_valid() == false);
+
+ q2.clear();
+
+ q.swap(q2);
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == true);
+ DLIB_TEST(q.current_element_valid() == false);
+ }
+
+
+
+ print_spinner();
+
+
+
+ // try the above code but this time with just one element
+ // in the queue
+ for (int j = 0; j < 3; ++j)
+ {
+ for (int i = 0; i < 1; ++i)
+ {
+ a = ::rand();
+ q.enqueue(a);
+ }
+
+ q.sort();
+
+ a = 0;
+ int last = 0;
+ while (q.move_next())
+ {
+ ++a;
+ DLIB_TEST_MSG(last <= q.element(),"items weren't actually sorted");
+ DLIB_TEST(q.current_element_valid() == true);
+
+ }
+ DLIB_TEST_MSG(a == 1,"some items were lost between the sorting and iterating");
+
+
+ DLIB_TEST(q.size() == 1);
+ DLIB_TEST(q.at_start() == false);
+ DLIB_TEST(q.current_element_valid() == false);
+
+ q.clear();
+
+ DLIB_TEST(q.size() == 0);
+ DLIB_TEST(q.at_start() == true);
+ DLIB_TEST(q.current_element_valid() == false);
+ }
+
+
+ print_spinner();
+
+ {
+ q.clear();
+ remover<int>& go = q;
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 3;
+ q.enqueue(a);
+ }
+ DLIB_TEST(go.size() == 100);
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 9;
+ q.remove_any(a);
+ DLIB_TEST(a == 3);
+ }
+ DLIB_TEST(go.size() == 0);
+ }
+
+ }
+
+
+ struct factory
+ {
+ template <typename U>
+ struct return_type {
+ typedef typename memory_manager<U>::kernel_3c type;
+ };
+
+ template <typename U>
+ static typename return_type<U>::type* get_instance (
+ )
+ {
+ static typename return_type<U>::type a;
+ return &a;
+ }
+ };
+
+
+
+
+ class queue_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the queue object. When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_queue by passing that string to the tester constructor.
+ !*/
+ public:
+ queue_tester (
+ ) :
+ tester ("test_queue",
+ "Runs tests on the queue component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ // There are multiple implementations of the queue object so use
+ // the templated function defined above to test them all and report
+ // a failed test if any of them don't pass.
+
+ typedef dlib::memory_manager_global<char,factory>::kernel_1a mm;
+
+
+ dlog << LINFO << "testing sort_1a_c";
+ queue_sort_test<queue<int, mm>::sort_1a_c> ();
+ dlog << LINFO << "testing sort_1a";
+ queue_sort_test<queue<int, mm>::sort_1a>();
+ dlog << LINFO << "testing sort_1b";
+ queue_sort_test<queue<int, mm>::sort_1b> ();
+ dlog << LINFO << "testing sort_1b_c";
+ queue_sort_test<queue<int, mm>::sort_1b_c>();
+ dlog << LINFO << "testing sort_1c";
+ queue_sort_test<queue<int, mm>::sort_1c> ();
+ dlog << LINFO << "testing sort_1c_c";
+ queue_sort_test<queue<int, mm>::sort_1c_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/rand.cpp b/ml/dlib/dlib/test/rand.cpp
new file mode 100644
index 000000000..db051c530
--- /dev/null
+++ b/ml/dlib/dlib/test/rand.cpp
@@ -0,0 +1,436 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <cmath>
+#include <dlib/rand.h>
+#include <dlib/compress_stream.h>
+#include <dlib/hash.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.rand");
+
+ void check_bpp (
+ const std::string str
+ )
+ {
+ istringstream rdata;
+ ostringstream sout;
+ rdata.str(str);
+ double compressed_size;
+ compress_stream::kernel_1a cs1;
+ compress_stream::kernel_2a cs2;
+
+ compress_stream_kernel_1<
+ entropy_encoder_model_kernel_5<257,entropy_encoder::kernel_1a,4000000,4>,
+ entropy_decoder_model_kernel_5<257,entropy_decoder::kernel_1a,4000000,4>,
+ crc32::kernel_1a
+ > cs3;
+
+
+ print_spinner();
+
+ rdata.clear();
+ rdata.seekg(0);
+ sout.clear();
+ sout.str("");
+ cs1.compress(rdata,sout);
+ compressed_size = sout.str().size();
+ compressed_size *= 8;
+ compressed_size /= str.size();
+ DLIB_TEST_MSG(compressed_size >= 8, "order 0 bps: " << compressed_size);
+ dlog << LINFO << "order 0: " << compressed_size;
+
+ print_spinner();
+
+ rdata.clear();
+ rdata.seekg(0);
+ sout.clear();
+ sout.str("");
+ cs2.compress(rdata,sout);
+ compressed_size = sout.str().size();
+ compressed_size *= 8;
+ compressed_size /= str.size();
+ DLIB_TEST_MSG(compressed_size >= 8, "order 1 bps: " << compressed_size);
+ dlog << LINFO << "order 1: " << compressed_size;
+
+ print_spinner();
+
+ rdata.clear();
+ rdata.seekg(0);
+ sout.clear();
+ sout.str("");
+ cs3.compress(rdata,sout);
+ compressed_size = sout.str().size();
+ compressed_size *= 8;
+ compressed_size /= str.size();
+ DLIB_TEST_MSG(compressed_size >= 8, "order 4 bps: " << compressed_size);
+ dlog << LINFO << "order 4: " << compressed_size;
+
+ }
+
+ template <
+ typename rand
+ >
+ void rand_test (
+ )
+ /*!
+ requires
+ - rand is an implementation of rand/rand_kernel_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on rand for compliance with the specs
+ !*/
+ {
+
+ ostringstream seed;
+ seed << (unsigned int)time(0);
+
+ ostringstream sout;
+
+
+ rand r, r2;
+ DLIB_TEST(r.get_seed() == "");
+ r.set_seed(seed.str());
+
+ DLIB_TEST(r.get_seed() == seed.str());
+ r.clear();
+ DLIB_TEST(r.get_seed() == "");
+ swap(r,r2);
+ DLIB_TEST(r.get_seed() == "");
+ r.set_seed(seed.str());
+ DLIB_TEST(r.get_seed() == seed.str());
+ swap(r,r2);
+ DLIB_TEST(r2.get_seed() == seed.str());
+ DLIB_TEST(r.get_seed() == "");
+ swap(r,r2);
+ DLIB_TEST(r.get_seed() == seed.str());
+ DLIB_TEST(r2.get_seed() == "");
+
+ print_spinner();
+ unsigned long size = 100000;
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ uint32 ch = r.get_random_32bit_number();
+ sout.write((char*)&ch,4);
+ }
+
+ check_bpp(sout.str());
+ sout.clear();
+ sout.str("");
+
+ print_spinner();
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ uint16 ch = r.get_random_16bit_number();
+ sout.write((char*)&ch,2);
+ }
+
+ check_bpp(sout.str());
+ sout.clear();
+ sout.str("");
+
+ print_spinner();
+ for (unsigned long i = 0; i < size; ++i)
+ {
+ unsigned char ch = r.get_random_8bit_number();
+ sout.write((char*)&ch,1);
+ }
+
+ check_bpp(sout.str());
+ sout.clear();
+ sout.str("");
+
+
+ // make sure the things can serialize right
+ {
+ r.clear();
+ r2.clear();
+
+
+ for (int i =0; i < 1000; ++i)
+ {
+ r.get_random_32bit_number();
+ r.get_random_gaussian();
+ }
+
+ ostringstream sout;
+ serialize(r, sout);
+
+ istringstream sin(sout.str());
+ deserialize(r2, sin);
+
+
+ for (int i =0; i < 1000; ++i)
+ {
+ DLIB_TEST(r.get_random_32bit_number() == r2.get_random_32bit_number());
+ DLIB_TEST(std::abs(r.get_random_gaussian() - r2.get_random_gaussian()) < 1e-14);
+ }
+ }
+
+
+ // make sure calling clear() and set_seed("") do the same thing
+ {
+ r.clear();
+ r2.set_seed("");
+ rand r3;
+
+
+ DLIB_TEST(r.get_seed() == r2.get_seed());
+ DLIB_TEST(r.get_seed() == r3.get_seed());
+
+
+ for (int i =0; i < 1000; ++i)
+ {
+ const uint32 num1 = r.get_random_32bit_number();
+ const uint32 num2 = r2.get_random_32bit_number();
+ const uint32 num3 = r3.get_random_32bit_number();
+ DLIB_TEST( num1 == num2);
+ DLIB_TEST( num1 == num3);
+ }
+ }
+
+ }
+
+
+ template <typename rand_type>
+ void test_normal_numbers(
+ rand_type& rnd
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test normality";
+ double cnt1 = 0; // num <= -1.2
+ double cnt2 = 0; // num <= -0.5
+ double cnt3 = 0; // num <= 0
+ double cnt4 = 0; // num <= 0.5
+ double cnt5 = 0; // num <= 1.2
+
+ const unsigned long total = 1000000;
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = rnd.get_random_gaussian();
+ if (r <= -1.2) cnt1 += 1;
+ if (r <= -0.5) cnt2 += 1;
+ if (r <= 0) cnt3 += 1;
+ if (r <= 0.5) cnt4 += 1;
+ if (r <= 1.2) cnt5 += 1;
+ }
+
+ cnt1 /= total;
+ cnt2 /= total;
+ cnt3 /= total;
+ cnt4 /= total;
+ cnt5 /= total;
+
+ dlog << LINFO << "cnt1: "<< cnt1;
+ dlog << LINFO << "cnt2: "<< cnt2;
+ dlog << LINFO << "cnt3: "<< cnt3;
+ dlog << LINFO << "cnt4: "<< cnt4;
+ dlog << LINFO << "cnt5: "<< cnt5;
+
+ DLIB_TEST(std::abs(cnt1 - 0.11507) < 0.001);
+ DLIB_TEST(std::abs(cnt2 - 0.30854) < 0.001);
+ DLIB_TEST(std::abs(cnt3 - 0.5) < 0.001);
+ DLIB_TEST(std::abs(cnt4 - 0.69146) < 0.001);
+ DLIB_TEST(std::abs(cnt5 - 0.88493) < 0.001);
+
+ }
+
+ void test_gaussian_random_hash()
+ {
+ print_spinner();
+ dlog << LINFO << "test_gaussian_random_hash()";
+ double cnt1 = 0; // num <= -1.2
+ double cnt2 = 0; // num <= -0.5
+ double cnt3 = 0; // num <= 0
+ double cnt4 = 0; // num <= 0.5
+ double cnt5 = 0; // num <= 1.2
+
+ const unsigned long total = 1000000;
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = gaussian_random_hash(i,0,0);
+ if (r <= -1.2) cnt1 += 1;
+ if (r <= -0.5) cnt2 += 1;
+ if (r <= 0) cnt3 += 1;
+ if (r <= 0.5) cnt4 += 1;
+ if (r <= 1.2) cnt5 += 1;
+ }
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = gaussian_random_hash(0,i,0);
+ if (r <= -1.2) cnt1 += 1;
+ if (r <= -0.5) cnt2 += 1;
+ if (r <= 0) cnt3 += 1;
+ if (r <= 0.5) cnt4 += 1;
+ if (r <= 1.2) cnt5 += 1;
+ }
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = gaussian_random_hash(0,0,i);
+ if (r <= -1.2) cnt1 += 1;
+ if (r <= -0.5) cnt2 += 1;
+ if (r <= 0) cnt3 += 1;
+ if (r <= 0.5) cnt4 += 1;
+ if (r <= 1.2) cnt5 += 1;
+ }
+
+ cnt1 /= total*3;
+ cnt2 /= total*3;
+ cnt3 /= total*3;
+ cnt4 /= total*3;
+ cnt5 /= total*3;
+
+ dlog << LINFO << "cnt1: "<< cnt1;
+ dlog << LINFO << "cnt2: "<< cnt2;
+ dlog << LINFO << "cnt3: "<< cnt3;
+ dlog << LINFO << "cnt4: "<< cnt4;
+ dlog << LINFO << "cnt5: "<< cnt5;
+
+ DLIB_TEST(std::abs(cnt1 - 0.11507) < 0.001);
+ DLIB_TEST(std::abs(cnt2 - 0.30854) < 0.001);
+ DLIB_TEST(std::abs(cnt3 - 0.5) < 0.001);
+ DLIB_TEST(std::abs(cnt4 - 0.69146) < 0.001);
+ DLIB_TEST(std::abs(cnt5 - 0.88493) < 0.001);
+ }
+
+ void test_uniform_random_hash()
+ {
+ print_spinner();
+ dlog << LINFO << "test_uniform_random_hash()";
+ double cnt1 = 0; // num <= 0.2
+ double cnt2 = 0; // num <= 0.4
+ double cnt3 = 0; // num <= 0.6
+ double cnt4 = 0; // num <= 0.8
+ double cnt5 = 0; // num <= 1.0
+
+ double min_val = 10;
+ double max_val = 0;
+
+ const unsigned long total = 1000000;
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = uniform_random_hash(i,0,0);
+ min_val = min(r,min_val);
+ max_val = max(r,max_val);
+
+ if (r <= 0.2) cnt1 += 1;
+ if (r <= 0.4) cnt2 += 1;
+ if (r <= 0.6) cnt3 += 1;
+ if (r <= 0.8) cnt4 += 1;
+ if (r <= 1.0) cnt5 += 1;
+ }
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = uniform_random_hash(0,i,0);
+ min_val = min(r,min_val);
+ max_val = max(r,max_val);
+
+ if (r <= 0.2) cnt1 += 1;
+ if (r <= 0.4) cnt2 += 1;
+ if (r <= 0.6) cnt3 += 1;
+ if (r <= 0.8) cnt4 += 1;
+ if (r <= 1.0) cnt5 += 1;
+ }
+ for (unsigned long i = 0; i < total; ++i)
+ {
+ const double r = uniform_random_hash(0,0,i);
+ min_val = min(r,min_val);
+ max_val = max(r,max_val);
+
+ if (r <= 0.2) cnt1 += 1;
+ if (r <= 0.4) cnt2 += 1;
+ if (r <= 0.6) cnt3 += 1;
+ if (r <= 0.8) cnt4 += 1;
+ if (r <= 1.0) cnt5 += 1;
+ }
+
+ cnt1 /= total*3;
+ cnt2 /= total*3;
+ cnt3 /= total*3;
+ cnt4 /= total*3;
+ cnt5 /= total*3;
+
+ dlog << LINFO << "cnt1: "<< cnt1;
+ dlog << LINFO << "cnt2: "<< cnt2;
+ dlog << LINFO << "cnt3: "<< cnt3;
+ dlog << LINFO << "cnt4: "<< cnt4;
+ dlog << LINFO << "cnt5: "<< cnt5;
+ dlog << LINFO << "min_val: "<< min_val;
+ dlog << LINFO << "max_val: "<< max_val;
+
+ DLIB_TEST(std::abs(cnt1 - 0.2) < 0.001);
+ DLIB_TEST(std::abs(cnt2 - 0.4) < 0.001);
+ DLIB_TEST(std::abs(cnt3 - 0.6) < 0.001);
+ DLIB_TEST(std::abs(cnt4 - 0.8) < 0.001);
+ DLIB_TEST(std::abs(cnt5 - 1.0) < 0.001);
+ DLIB_TEST(std::abs(min_val - 0.0) < 0.001);
+ DLIB_TEST(std::abs(max_val - 1.0) < 0.001);
+ }
+
+ void test_get_integer()
+ {
+
+ print_spinner();
+ dlib::rand rnd;
+
+
+ int big_val = 0;
+ int small_val = 0;
+
+ const long long maxval = (((unsigned long long)1)<<62) + (((unsigned long long)1)<<61);
+ for (int i = 0; i < 10000000; ++i)
+ {
+ if (rnd.get_integer(maxval) > maxval/2)
+ ++big_val;
+ else
+ ++small_val;
+ }
+
+ // make sure there isn't any funny bias
+ DLIB_TEST(std::abs(big_val/(double)small_val - 1) < 0.001);
+
+ //cout << big_val/(double)small_val << endl;
+
+ }
+
+ class rand_tester : public tester
+ {
+ public:
+ rand_tester (
+ ) :
+ tester ("test_rand",
+ "Runs tests on the rand component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ rand_test<dlib::rand>();
+ rand_test<dlib::rand>();
+
+ dlib::rand rnd;
+ test_normal_numbers(rnd);
+ test_gaussian_random_hash();
+ test_uniform_random_hash();
+ test_get_integer();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/random_forest.cpp b/ml/dlib/dlib/test/random_forest.cpp
new file mode 100644
index 000000000..b3447bf5c
--- /dev/null
+++ b/ml/dlib/dlib/test/random_forest.cpp
@@ -0,0 +1,405 @@
+// Copyright (C) 2018 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <dlib/random_forest.h>
+#include <dlib/svm.h>
+#include <dlib/statistics.h>
+
+#include <sstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.random_forest");
+
+ const std::string get_decoded_string();
+
+// ----------------------------------------------------------------------------------------
+
+
+
+ class test_random_forest : public tester
+ {
+ public:
+ test_random_forest (
+ ) :
+ tester ("test_random_forest",
+ "Runs tests on the random forest tools.")
+ {}
+
+
+ void perform_test (
+ )
+ {
+ istringstream sin(get_decoded_string());
+
+ print_spinner();
+
+ typedef matrix<double,0,1> sample_type;
+ std::vector<double> labels;
+ std::vector<sample_type> samples;
+
+ deserialize(samples, sin);
+ deserialize(labels, sin);
+
+ DLIB_TEST(samples.size() == 506);
+
+ random_forest_regression_trainer<dense_feature_extractor> trainer;
+ trainer.set_num_trees(1000);
+ trainer.set_seed("random forest");
+
+ std::vector<double> oobs;
+ auto df = trainer.train(samples, labels, oobs);
+
+ DLIB_TEST(df.get_num_trees() == 1000);
+
+ auto result = test_regression_function(df, samples, labels);
+ // train: 2.239 0.987173 0.970669 1.1399
+ dlog << LINFO << "train: " << result;
+ DLIB_TEST_MSG(result(0) < 2.3, result(0));
+
+ running_stats<double> rs;
+ for (size_t i = 0; i < oobs.size(); ++i)
+ rs.add(std::pow(oobs[i]-labels[i],2.0));
+ dlog << LINFO << "OOB MSE: "<< rs.mean();
+ DLIB_TEST_MSG(rs.mean() < 10.2, rs.mean());
+
+ print_spinner();
+
+ stringstream ss;
+ serialize(df, ss);
+ decltype(df) df2;
+ deserialize(df2, ss);
+ DLIB_TEST(df2.get_num_trees() == 1000);
+ result = test_regression_function(df2, samples, labels);
+ // train: 2.239 0.987173 0.970669 1.1399
+ dlog << LINFO << "serialized train results: " << result;
+ DLIB_TEST_MSG(result(0) < 2.3, result(0));
+ }
+ } a;
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+
+ // This function returns the contents of the file './housing_data.dat'
+ const std::string get_decoded_string()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file './housing_data.dat' we want to decode and return.
+ sout << "AvlDWRK3FGmPtCL8V/RoXQLzsOKukA0zQosjXGZPM6yNf1U5LjdhRVloILZ5baZq5Tj0hLQGf1JY";
+ sout << "ggTd0DEbBez7lzvZ6hOBABkJ6U0aZAEeIMUL/K19h5gHhpwdWvcolJA+VbQVD4acLQFxcIgsgN8N";
+ sout << "2zjQsQCUkHGSGyop7/xrVZAJS4Nwy8qMWJyjgc/9mZXdDsaPWtSVTeoxpb3d94bDYdrEQ8T0Z6N1";
+ sout << "fCXCQo/3bus7NA+FJoOtw23DHMsYnsv+tfvaNzzzX7lc0cRPe8Pi5q9JDBs4Gc5bPh3Hw8QJz3n5";
+ sout << "beGJpU1KRHu9vq1zgFuavXvyZ1HQVHLj/yJRm4lL5eUv7CQeULGP9UmIkvVCJTZ4uw6mIFJyPYjN";
+ sout << "jjygWMfJ1dnlGpI/ZlaVNJaTEB28tNymV0UKeKb/Sg42s9+wNNQzWMBm5TRvqmplf/0Gx+4Tcmcd";
+ sout << "FtUb1pgkz6OB59Ko4L2hsxVQYYZHGpQt6QBMiubW7GVe2g4yWjRhqlQTh2sjceGRi26SFOrD+gnH";
+ sout << "9xZlbyKdlKlcT3nVfcKYziLKAjbmr5QIu+W6WR7M+p90CHkDrjkVK0WTSc23kOhuua5feG54aoht";
+ sout << "hcViWRpASVPXsKKJcl2yTlZ02uFQak5Lid/znCDmWQk0fjfzEZZNzgoCNVi8xCx68lH4Mjm3MdF4";
+ sout << "ExZqX0jsNlxDOwKp7TYIbhfY4/XdzFSi20CtxXbB3knkuggb+ru/u/6lh8rmqwLjsANqb2CWG0EH";
+ sout << "32i1/gmtlY54kAYG58GWB89klTnUwRImQv/QjJBrf8TwK0jUOjkrnOWgKWwNsK0jba54750QPal4";
+ sout << "SJFvmeBIR52/T2ZsC2iAQGQuod1IZtIfI2pn0dpkdWm/Y3JdelJ/EADtNjJch6maVoamxKmoWlVw";
+ sout << "TRoHTIzdsJKtQiewn21H8bFX7HzZ1yE3/iuSEeLPB7T7gnfrtCEEzqAWiUb/t7mXtqBqt6Kdk6By";
+ sout << "JZqE0KdtJ7MJ/yV1MHQ94ExDhYI0qaItZVGwyH8ETCzr6xqmue9hPD6SRq3oy+aQKvJxOqFMcqsj";
+ sout << "cL+/2S2F1frRgZDGCb6tB3TdMCJDhZoChQNmJ3hyoAdWXrPEysL0TW8LFSIItAylIRljMsXgnMRE";
+ sout << "RCkfYPdbweT71k6l7FiaAsqHeKh4w8CxKJkzhJeALEPLz4QvqFt6/DoFmhKTrX4xUk3M/Y+dU/LY";
+ sout << "3B/S+e6v9cWuluEGXDzgj/LKBeruPk7hcnhIilMmd3D8sew3tvOdIowxmM67lqW6fwExcK6oKPlT";
+ sout << "aDZcttWKTndVEKsUvnZr/PQ8sta49+GGSfkXw/MS0TAjTQ0Wck8wSJ2CMoUGBmVSKLSwEWvdAqdo";
+ sout << "lQLDxAVayR3GeKasJshQw69o/3d4JnUOBcU5ZJM0z2D51EDQM3lvmnB9dtiJ2rcypWG53ETvQqYc";
+ sout << "S3suPgaDqKmxZbRNvnfuYbG6+qPeHDN6WmdAt9Iw5XWdjlG6u8BGI6+vqY1C8J6pJ2p7ITUVBVCU";
+ sout << "NRYyXVhz0PQrzy5jFwbVeRZo3IV/bkPlHV4UujcSEOi67fqk4ixOzA2JsxYzei+Rsp1ahpK3Wmuk";
+ sout << "9ZEeqD1/xqqeosV5pvwcXhjjp4UJ0bGY0pEf7w9uDW0aZT+D8prSmTXGjFAQGiSBmfLFw1Yk2CyG";
+ sout << "V8RG7/7uxM6qyj9LYsNTGvdyD8DvQNEJ0v7J9IwCihdJAFhuKgqxlmkJx3mz6MiiIC19CuQKR0NC";
+ sout << "1/PG2wh7zEhnDwfINSR2b41ZDcm8/ky/k+xhJ6fi3ZqpSlkqyPRA/YAID9Dl7Ngn9/xeWQNXzEHD";
+ sout << "pn6aPHclBaDm5keUHlcN9d+vrxrRjV/GdRhs5M2eVRl47djEiYTaQ3Ua9Lg0oWBXYFkC1ncNhsIT";
+ sout << "tFLwx1Fc89lNwN3kHW91X3g1MMDFZoyzwBan2v/3DOZpPH4U1cPr1SbAPu0HITHK6pCUtMRY2/DZ";
+ sout << "9MTmm6FLeJhjiVma1+ALD4wYTNhebWkmX12jeDPT5gMyDhq3dYQKIvq83aVY1MQ2UroMDf9NVEdh";
+ sout << "V94rpEjw0ewZeDKHwPazkOU4q5m69VxwkKSbNzKZ1oe0P5s7j4Z44NYP3o7Qq8MPLi9l7QVKqge8";
+ sout << "6PqEdYoxGr9a/QDB7wcSdpcZTku0MippOZm1uG0zA2K6WlTmC/3lCm4m8EZBZXq6YkTjUpPreeio";
+ sout << "umUsmp+XbtJ1LWK3V5hMl0R67Fa6tBd6x9bP6/FSwjeYPHj8dz3nK0VLX+NC7+NjIPBkKgAcqs+1";
+ sout << "m9u/BA72sNE/cn7NgfgXPlHsho3reV8Iujq+MTN5iayeTH2fyG7XmV0RkpOY770bJEdugB4QlWWL";
+ sout << "nZgYd2yyNFlBcvXRKZpoG9cTWuzdxSbJnYdgKGLuPQ0B0BYwrSLnHEM8pboXCN4uIrALW6ipmKeO";
+ sout << "/S8bW4u73Bqgamnh7/pdLAoWo6EiD5C3uNrW0OYdnnSjhLHkV+krDhPQr6nAfjH/0vo+CXPDbMyW";
+ sout << "DmkJVgJ/cBt+EWNyIOeBLliqbe9zY6hqzGRt4b6cp1fDH/tbYMCsxhp0LIPnDNDouUQK3j+VBB3X";
+ sout << "E8NnCSzBa4SdhMNww7jbJea+RXqe+g1clXqkBf/ZitToTPmzHhYcPRAcvhLcRE0n/uWj7jbv5vOD";
+ sout << "swyUpiJaXQgRG77rh89xNVqz9qio6uP2xGdlyW0IENbxMYsGq+XKxNMAMHvfRh8JwakGw7ZI/Wn7";
+ sout << "8uWdjM2lNmenBCYSS9qe+2DKnqTxOSnm5ugsYr6IXftmlzev0ke2rRBfvllAv8GSY8GTJul+gbZV";
+ sout << "+3Wu8xZawgLFjngRph0aq4PviIwrMS1PhE5M7pC65E40uaY+xJv4rQGNCLF3/+SLvnLfTRdH0QZU";
+ sout << "r/hXG0BCcaWE4hb7HIzon9mNIZf2Eb+IWxAhUQ2/Nhe/hNTRx+DpB/8H2DurZPFK4nrOPvxmmzgA";
+ sout << "3VFL0kJjNfeXGlo2sSQEM8sDecXQkl47KGWROHIaJyRZAoBOBpMHxTXs//3aqWhlOZ88pftZEmSL";
+ sout << "K0sXXxS3BAKB8SLzu4VNPNdvmtT7Z4sHmfYj5NXayXS3V2d2646L2QW4jzzwHjpJ6p2/4mjKniFO";
+ sout << "TSQu1wKSv16L/QVbiwvg6wYgIL8cct+XSRSUGfvVgo9lAt+OHTIL7s/9A66jqdDmK3UnHcV8XOVN";
+ sout << "Wd+DnXOG7nI5shOiGHAqAAtuQZK9sAZZtyLmg68UN02ZroY0pUePwuGaBribKuKLDtcfMcDHwv0x";
+ sout << "lSlCkHu9SjsC1Qswk90Yx8YddYY1ePYaoez6xUAraj+zOLNuFCZtm6hGTQq+pPZ5xn/K6zzOvaos";
+ sout << "sxDaWBSNDEuOfAK2dHgctL6XKH7/kHAZxBa7CbTSe0zFuO1WbicRrqO1NpUyO9P1L82dv1VCSuyi";
+ sout << "Mtj7UNnywrmZmMBf5x5yYBdVwBUKIBy/sNJdrpDHbAi6MJoYzCai8TqTulNL8jGAzkjpXYRqPf99";
+ sout << "fXWRTzN1PMHjvEbNOBIX4OorGR4lj4E7i+d1DKaJjXCDgvJQTUGvRWdu7MOOkXPCmIlxFL9Wv2CB";
+ sout << "LpzzNp+I3NuLg2F30ratBBLoqIrnBBXb390pABYah8bRnARCUJLjFXugVqTWoMwAsrbS6sfdFRf8";
+ sout << "fKt/+Nx2vX8tRJBFFgBEbS2le05ekg7HC6egGCLImh8j8sf4gs+2xdGKXh9mnW8BrqZJvQPkeR4D";
+ sout << "Fro5V/EFe7EAIXpQfMRoNpHUSyn5oPJDFYMjjc1EEO4C6qqJ29nV149m60BjWDuVK1y+mdkCvaDo";
+ sout << "iZfAKI4TiDExfMpdAJM5Ti6G7pauQnW/lxGNX44JOR82strVKwsQxSecUc+uT+kYjcdV9Mx29qJC";
+ sout << "qXJgi2cUhrNsVH0LIe45X3pSzfFnMQ2o+fGAURgQCSW6wToqmtHBsCorr0v32ew524X6tZz11HMC";
+ sout << "7DKppzWxGTBOCBwOPrAjwlUb8zaRpj3bQGXWkvDk7E7ZUMQ6TJu0wgkuBNIPeIAh2tLfGqrqZqxp";
+ sout << "Y2hM/G/qQG+Bukj8yyxitXqIwceSp3v2BnLnL/WriBpSVEm9LnjiPou/BL98WMhp13UKWe3L3XkC";
+ sout << "izri1YMoCyxaQFX6RODu4We1vi/9gjamFSCAc5Tj+CmUasCJpbmkjt7Fp+v3PhXa4qpt32ZR0yuF";
+ sout << "G0dowpb8WYvT3U7nWAOFKBRgj9Ri6QPlCVaUa/LnkjZ+fNzytlkzQ9TTsPpOkEeJo+nCF3cCUYBH";
+ sout << "Y6lIyjcQRk9guLIgk955mBLpyjni8ZFOOsjTsW+LoOvAiZhVTGwA75/g5z6IcYLlcb0nwpZ/O2fS";
+ sout << "QPFcb5V6uhi5TnQHDQGHihSU4MBo5BQfNd+VuSxliK/TVvFU0yYjqPzKxCxgBpDO8qKsPMbc2YKL";
+ sout << "SFY2ygJ7PwksSIEQUum0MSEFf1ZJ3WNTajxSvFLToOkAtLpnvZlWymYkI72/Dgi7jBpfhIw1U1Td";
+ sout << "tuTLc0L6IfALX2I2VL2tOhBcisUL8IRhDipxhBTRBaJYLG2RB6ICKBuAQaXf4ODAPKbLhzfRSss+";
+ sout << "2VTojSwerCyQkKyoUZLR67G2ysWWLERwD1btSNH4IjaPYVEmaWk4I4F1YZhrmN3q5du7t7g3E4C4";
+ sout << "/UVLrCVTQD0CnBVBB5hzMEByG/4ZhIu+JWx+jRx1288XA1k84c28NLfMnqDsLHGtVxOLFDBgwFxs";
+ sout << "vD8S2E1+G4La3DQWc/X6jfkC+dtp0ihh5qQxGaGCKh0mcd3BHnNJYUSqSRQLRhOjiZBxmujGrhJG";
+ sout << "oHPaUCxfgY3vl9y6KAVlcLcKTYZpmukkjxCEVOClOy2pHYivhgkO2HR7okgNGpj8qN9EcVTWPq8u";
+ sout << "dbBjHLQ2GbqHamyaDJFUhJfsibyXR5ZtG2WAZDH3uXlL8AGNriBhwgGVcaRGH4sO/NmWWdM/gnap";
+ sout << "6geVpginIZALN+egxDTQtxTH5qTPfkMg9tdjlX/zB7e1LbVR40waeP5PtanIvb7VU/GbVbQMEDKc";
+ sout << "Lqfj3v6RyK6wX7mDvF7HWFtav8R/j9wlgf75kOiXz+2eN7GeXEmF68LqH6g4n7Ulyhq3uqszT4Jg";
+ sout << "hk7ynKJoLURg5KyJPTUCvadDaiqaLH4hF2bErrQzIIbbKDCq5Cb7n0EhCZ03MjLFs9+HSa3yfaGN";
+ sout << "NUS3wdGiM9x6rNaKDP6/vySXZzBvgtinskFBvb7UqCwmQ2MF1lwr0+nTNfH7R9fw3fi+tHXB6kyh";
+ sout << "PovdaPH+3dfnsbqVSoJLj2OvjsFfTFQXn35xd/IW3UEdBYVSDZP8VGRnXQUSS4BbJ79VUMNOTmwz";
+ sout << "CsoiZzIZNgHShekR2XKv1oXM+BheSAxK+r/d+VdPgjlkByfCwuw8iP/odUoaXzk6iTh8h1pGyESL";
+ sout << "QY8mNIzzPsU39opNlK7JmOzlYG2wtCS+DcG+bw4HLJP3Or9mChHpN+V3xzL5Tsb/5fGeqcQ0hvsA";
+ sout << "aXMhlsnRtxRSDkfE0s1HW0r/O63X5Hm1Yw/vJw8BzEtNYg8h3x7xvECS4vAwwuKLS30rjlVqjPqI";
+ sout << "TNchzWOA98U8AoC3t0asTAaEXce8tPtLqXD0EyycoU3slyfpErU4vySzpCXtkv3BShevfZy9yhX0";
+ sout << "2HG5zTc+l8GdXayf6mVSXaQ2N2OV6gCwd+hwqHjqvYSg4a0Ug+/cEw3zVi5AGiLIzTGDGsfJHE86";
+ sout << "9ohKS4z7yI+doqegx0f7N95Njw315nKSZnSSf6Pa/I20SrcQabMC36H2vdv9gkOlsYlZLyCOL54P";
+ sout << "ZOXlim7GgCt8LPEO6maHmQn0f5mtJAYIxMrJKoMasXvc2ZI2tktbh6bAJNfpSL0KbTfeQtaFJnVX";
+ sout << "C2f8RXf61VY9rNVd+qtpNuiavf8ZuaVbSLsLzF/beAFpS4djI0Nn78CJBhZAnzhPD76byg/vXG42";
+ sout << "nyD/u/FLJ3eccPnvs5umbz/gPiFk6gW+HnTXaYEdwaGWDdlr4QxvDki8Wsr0AWPlzA8D0nmkPCZW";
+ sout << "EZLBUIUjnBN2K86dyqEDW8+C42vuwXfa31wkOX0/8S7FSuT1BET8HdK8fykJ8NxdKlUsIFNr9SPz";
+ sout << "maMVyvkPp6IQ+DG9PBOFIaFy+zHbPzCRNNd3LTBhkQ0K6bP7u4tG6b4fdmmQzSSGsfXqEUiXkjGX";
+ sout << "ge14Qh+f/2KA4TjBZDQWF5NKR6/x4lsHfj8dZDg3+fEwY2fqezjD/jptis7N/VfeSIM/3xD+gF3w";
+ sout << "BqJn4Wz+ohlWucLfS0JRREnPAWfje7RQYatBkLok2Uy2hO7lgfw6ipUHNVPUw6XmK3VW+McnK0Ur";
+ sout << "L4CI9LAFF+kDdBfTs8hnhmLtk6h2Sucjo1ahEBxAyUuRgqMko5Sy8Lr9Eo49KiKO+V8LpA5ZDMq8";
+ sout << "iabdyb1WLFnyvE01K4uKqGHLoeh7heD7/0sAbIVySks0mv5cH46AT288mIVcHrSUurhtxawYZY4P";
+ sout << "/DqW0jHbqIkZWJiOquIfeTbtgRax76gX01JSeEdL0UyPHTmoqkvMdQVwjYcIBdLrCPPWQWNjkmWa";
+ sout << "XBedBCzwmp6fZX8ew5AABNCmlBBlhqNFZQ4yG91NXuiDoS25xckthn+6l7Mn0FHs9418wKqa+3eB";
+ sout << "uGqiAJVwpgWx7CWhWi9MonFdA1nni9AyjubzEaUSbzjL4ghneOGC38FyEQcuKIxrqgY+ManAAdlc";
+ sout << "hVaHl9Rx4r16AITagHNPLwCbbeJ7nbM+arjvU6qmK4Bg4E96IDjrrp2EQJMYZrs5+oRbpdtomWTx";
+ sout << "k3hcUFCMUiulWQs/pc5bXm+Xvx0mNnpu9A4GtFMzzpKO4M9Q/mtKch9H547N8hjV4jrWsJCqplho";
+ sout << "XIp9yBwPKVaEwWYpFcmpfQIJW+I56ewA0xthRBdqqqQSoS3zMJwdQEUm+XibYA9XALC2dkbTH2fo";
+ sout << "H9a4ImxxquSZZpoqUuRsbsgejD2v0ynbipTQ/lNVwswk18Wma+Whg4sOdCSQIMns2QW/3quqHoqb";
+ sout << "DC06jJZuQJnLNly7x48Pcus1gsWc6aVhbpl7cCLq/YUY0ACMS5toED6q+5mqakCg69pK4dm9WHIf";
+ sout << "D0hRK0v/05LheyllMYmQhO98Z0Jz8IJQQsl2sZahUr4Q8oTTFt9rLRKd+onL5HwdJQDAiYB/fex0";
+ sout << "3MHYPjyzK2aH6mN7qOG2VquMGb6LmOPszLSggrzFsVzPqYrzqH/6sGZPV1wMduWw/qYadGPlpzu/";
+ sout << "XIgnQ2Qzb21xqwTnivx30xgWDf8kuQWDdECGT+/GPAqNx1P9RGLru+tSPe06oZ8RuoSs64zDqJ3E";
+ sout << "KMmOeFAt+xQFBg6dEBs1pgfO78AfZDAbadYzvPCEp2zGGeYwhZlF+a+QIbFCyLe4EO6a0wlrPEBn";
+ sout << "3DESwdT7ETielqoEQvKdeSQiks0UGvlXRhxKdH+ylQslmdolzE57pMkquAwiFMXddLGFegrctP9s";
+ sout << "tmsvLPKWDIqiHy+F79eU6vOfwwS7btaRg5zuRKWkQ+B2CU8F/kx4FR4ZxhK8fzGjMUyjAmHZhEXf";
+ sout << "kvnchtB6z0pN7wUf0n+Clxo0DiXlJlRQPo3pZDttbC685azJ3OoH04xS37vxUSx1ir/LWLz/tjkW";
+ sout << "iFYq3qxftzK+jU7XzDx2nif7ZLc/+ecfHdQPXK4YZzJ1x8C7SvC7rBLRxnKqTYgv2bL9G1sCU+x6";
+ sout << "0hQtMba3x42k//w4RtV2KkazHoMTZc9UuNSsaSoAoGauzw0cs99op7HCpOgoyRu5JeY+fimo2H5C";
+ sout << "cXBecQbQdUB0uVxxEQHPwJN7vi94JfbpdnIMLLRjBwRs/2FOmMWbWWcShUYoWDSmJOLaw3Piwtk6";
+ sout << "bg/ppKqGAfrzDJkR0n1OZgKvUbnb8WRyZse0W+tO+PcsL1wvwG+8mMJU+AOBs1P/iVLxW/Y4CuXi";
+ sout << "/e7SckKJ3vsm/pQawrzhDIjOwofxzBWQ4kODfSEWHZvpQD0HNf/qP6IYfqhUu/0JtRJGLhlQ8hQJ";
+ sout << "iJBGtwsCRJWKrBgu6cizrYcA664+XPgjF/FQYLGmPiPrBdrbWjVxSk3tEOgVFOuK+bkI0EX3p0hm";
+ sout << "gYbr3oIec4bKzrSgYsIQtHMo1FnQl1xwHL0vH24KF6V6eyYpgVBfg42MNDk/aaCZ4XVIgH0H0wns";
+ sout << "sRXftElLUVk8yLhqq9kXBmgHvPZMfA5WTP+KhXFRbfxw0A2nWGbztsniRcoA3N0pGdqwDyOE5VGg";
+ sout << "tX94o9eS0eOJzh80SKaHFaX8GtlpVhogNJiMVlzwVoNJASK2Pr7Yp8uIqcUT7+e0VtkdsVlG5wv8";
+ sout << "WLEbqmRXrsKLs9f1p23SelLo78kI9nBujEvDSOCChnNwqNPG85kiz24jL0LMaWHAHlnY6uZypDyM";
+ sout << "TUmjsyosdrCobZRnQFf4UwUhuNtj3f6sQke+GQhzr474hTpfSDqLGMW6IcE3OcU4x3waC87DPSRC";
+ sout << "7PtmJ/+8nWIElEbGJtjS+rL+Ue6faqpkh+dkPC5ZsWHHRvXzyuRNawC15L0kLhhCc9Y+s+fWOppC";
+ sout << "iWtPPQKk4PKDA/g5TRA+KPkFH0B6YchdiEaCMLmleDqF9uo+XNzMnHdOKrkTZ29gPosM9w8CpTSF";
+ sout << "neDroZT/v1ckXkEZ6rhlVkF2pBmqG0DTL3LPclzO3JC6i6noY+kFU0jSARjXgQU3NXrYpgeRhLuv";
+ sout << "hKlC4Fl4xTK8l+p8J8Uk9zKTKsAdyDcAT6rBGfRpmJiN5a+lxuiBSCyygeHDQGhaJFROD53Q/m9M";
+ sout << "dxBlTjqMI7r5M2BNFKaZ4vhFdukvCbu2dPm8t10pc4brs7L5TeBVo5qxaFgRbkpS1mTCoDtFH4fi";
+ sout << "Wl+zpEdF6VpKrmFaCSSduE2tMhr164re7m4P7CxeJWvYdFfWGD+uDhFM7oPVXvvkC5gYmjdAPIYq";
+ sout << "co6IvSMraA9ANQd8b/hO6u+zo+U2Fos9Xe/1u5YWr9JdXZ9oFsdNGQLaLH6VcRrwAyx7+tRjf5Ia";
+ sout << "IHli9+TQmZ4tbtxERYe8TaqFqugpCGtvmcE/DFo0BeWgFRFnAY6nyXGgJCrzbxrMdCOBdvIYzuuB";
+ sout << "+A68idNO9ifsfNxRalfJCQNymwy4MAylkG8ncZ6Bqx1XztI9ckbD7U7TBMHCWt9xnMxQz9G0HmsQ";
+ sout << "pIa0x8tKk5zZ2TyOVe1LjwBXzwhn17i67Ph+NTA2pw8La9KcI2xdlDtAhD+LxRBANo95CeGL9NKp";
+ sout << "cDjWMrqwRlvXL0qroKeJuRqtSPGC+hbFEJgTX7iWDq4QJCyvscIm/lWz0ZQIHyXyh3yV/UyGbMXD";
+ sout << "hc6mVp20J+AGPR0NCEN3mTh01ON3LJI1t1P6OT8oGM2ofce1YsHLdMlY3uu00ErXy0YF6vz7jft1";
+ sout << "0St41Ydx54E5As7cbimxngKnJsFVIdJC8uh8SaAJQJWtQK6DG5sXJVDADSIKXM0TuhRTxRREGQWB";
+ sout << "W6Hd2jP7ZArcBCuB2GGMw4sn8iAYMK1LP12hxoZsBZ8iAbohy3MpWZHiE9MDU5PbGRyNsEnuLmQp";
+ sout << "APmj5sFUcAsA0MSUYZli2jB2WWAWwTaQ1CGm92tdrdflShh5FR8IhAZrwXPzI/w/1vAianD9yheA";
+ sout << "j0cYf/EaB92n9xRSK5zeajIV+DFT/451rNvi0Dqea6cDxRkza31G3d7pWwPkY6WiGdvSzlgy8uJN";
+ sout << "pt2gJoZJ6VzzeYD1bsb5X/FYBtYwFuiGicRbUadEA0b736fEC3AG2OZVh5bFmVArBoUukUoBNf8S";
+ sout << "gWgzfeYNyL25qa5jaeg+X8okmGtNUzfxLtmJiCY4/A9aDh3yoSCIHwH1we8m18DjMTXYzoc2b99i";
+ sout << "18h/UV7FF5xvl5awkZyLjDPPSDtdafb5ufHyNVjblORDSqS2soTzwyoyZm4PTCeiXWaO9Dpz81fM";
+ sout << "+bCkqFrXm7qEJqYSrCGLDlxwVZJeHm/lCNpbO+GYq6Cd5CuaJVactLLRre6s43nyD7IxiRfCmb/f";
+ sout << "LUyVi5sXEIaWiw80Me64uq/s9ADmuDUkX9Gd8WA+7fyyytSuMpooPNEsVBY9LE5nKWOlOy8Hrqmf";
+ sout << "piWc2Pf5nUtyZSsK3p94XbsysthhunMLsiv5j4mcs61xi2IyEgWB5hJ01qk4gQV/8SHPYJ4stRxA";
+ sout << "Ea80306xhLLKQjYSpPKHOvoil9kCHIgBzOp6lZas2vOzK/w50AVekKYXFQK2lWMs8TUBzWy5fYPK";
+ sout << "CZgcrWP0fShh9pthBw0DBEmpGM6fQZ5XQlBC2hG8HEDbbakqvIUkpuL7jlFde+HW31aTQRluTYl3";
+ sout << "huZ21u40SMx86ghv3VUzHpmF5x2b+UIIi+l0yg5FZHXRsth2xRsm4pNO6bspgiL/HrWMfsjwD2Vz";
+ sout << "d/2Kx2Dn9FLRUg/ASedunqth0Ovq3z9Qds7pH6QVdBUmtPokcHoC3KKl1gmY7/cN880Az+h0SMpn";
+ sout << "eqduvQM9adP2tmuybV5zgKGCt1q6cc0fPPBD1DuwAgr832VjU87nVOl13p4TV9NKX6wvnfRcw1bQ";
+ sout << "nJdFr911d2uMjwuPJdKusPo6o86c8YHTOcmUkC2QkMM6gsYp/lK+lv9fwJhvXUKg0aBlNceJ0/eK";
+ sout << "LGzHHzsVCweHXjlVY6Z89uMHZZeqp/wzEfatokYF+jIfD9rP+9AyuMIRuIOuXTCkemsGcAHqpg3F";
+ sout << "EcSZcaimDAyropuc7CYsVhhxKRDQBYjTbnd0dhIIDZ9WVP/MbG7QRmJF77TB1+a6GlNjoOYEuJfm";
+ sout << "RX34p0IQ/ycmc8PcUbFXAC2/epoQKPRprwg2+EbciWSYQe9i8T9gzJVuVHaWF1GjlsNJNvJDnWVW";
+ sout << "2ffDvQuZ/YZ//zqKcA6e9A6tTttCUD4XebQmhT5vIesFMuKNUHBvJZwerszeY+AY1Hs8kwTJNMB/";
+ sout << "DDj71I1sz1vq7X8OczT4vaHqLDg/4MiyHFatIaGMlbegVLtthaj7BdhwxM7xz0iilKncYQ2zYw9S";
+ sout << "wMgYGoTth7eZQe/q0rgzXi25acEvNkbidVbeI+PtUQ1694G/eKRqOYnmaWmhMsCsEUJH5ZI+XhkN";
+ sout << "+94T9Tjb6s/P9z0PisH0UAUDT0Rp+DeikJF1h/yLnxhQ3KxIwt9yB+ZlizVXB+6F7xcOAXuVocD+";
+ sout << "AyoxOZRmI7dRlFB28ki5Bcl/EHXa70EEFyFao+xc66nv7luVhyscR7PydzdbIlYba2tnkr/QS5RC";
+ sout << "kQ+4t8Z2smt7YKo2d/A4Gz3YNk3K3ZbUDWWSClHkcUklQVJDGg4b2da+V9RV8iAuugNuEdDVrs1r";
+ sout << "ixGKhyyEGGfIaFUPEaL5/NrAZqBuX6FSuloE7m/MShmkxRuilG5Ngqy9Gb273F5O5x5T6/koSaTc";
+ sout << "5tLEdNFmpZYmj7vdIAsHeoyjxmmSycfE8lsCFh1yZRp3I58aXVxBoTHGnQYkQoIeBpr9GBPTo9hx";
+ sout << "k05R5M/LAy+Y11NSEW/gRNSiDkmUDGclorU8nz+dLVuyq4ZFX0fGt+IH92B/Ut6oX+S8CaL0iDcf";
+ sout << "L+AtFn+m7o5UUKZx9KN2YEv3EbxEdl3m3BsSsgJr/KnVvd88zCEoILj6zZAHE0tkqrYRDX0rKjGc";
+ sout << "1LNaJQ+TGXjE0btPlj4hVLWwZHJx9JT70uDDDQ/xz4v8p4Q8MQqfyhf5dc6/GeZtMP031eq1R84J";
+ sout << "TeEyGjpF7F4Kngzg1q8mFT8Ay4dmF1rCwvwMl7QGwMAIp0lx+MRC3J9YK5TALpvK2NA70LNDzquo";
+ sout << "vuu2cNlAGDtCBJNQ7n9vFEDy37OeXwbTpoJDOWXx4uDL6HQD4yZKxeFbX9AdyS3pTcl12wDkodRU";
+ sout << "ESXKNoL9DydL+atZpSrTK06OYqoo4s5ihsGUh/CRJe1owWDoCHJEmT4ghhVeU8YVHxxdVEpOtXw2";
+ sout << "csBk3ljCjfpZoXf5yLtEa5Md5JdAP7fVuqDv8sPzg/I0IvXvD24a1RxPcalo5Z5adVfWWGZkC1W/";
+ sout << "oBAEgcYFTCVW7IprKK/JuNv1988z19JHhpqMEDNWr7JszAEQ9KRTNtLYjFb/uDCdUSgqiQbV6tjD";
+ sout << "PeeKTQbxZ4r6fmtEuV75z/0be62g4t+/aHGNWJjJuZ0P1A6of0LOPZwKhRY2kydC8okBVp3TsM76";
+ sout << "9p8yuQAs3WuzaSJR8H2woYQoykHJV/ARapMBuxHlvrhDWFITpyN2LXl5suea3UK1GBJ6HWSiFrIQ";
+ sout << "RvMpY13CsRH7uPdx0svXicHK/GRnOPr6ei7cmMsp+nOKXmE15XfatnD8N6OHLImCrfY+bLS1FO4K";
+ sout << "EOWti+cmcfz06z70BNUnGSHJqWNohvvGVsre5rimgFSRUJrxN1RTrievQuaVB+hya9rL+dKBRjmf";
+ sout << "Uc95nLFFBzuhO/CYEzaGX8JpyyQgh0I38lMww3jK+FRbw3AocP7/rdaNpqi6coY/eFrl8Iv8drWh";
+ sout << "5B59c5boqoOa6ZFvIkqB3oJp7ogpO6zFnSS3rGXt0tMWyj5PkSWeN2Tq9pO2gdSM/p0UgN4Ywcpn";
+ sout << "rU8gtJgD/zct8G5pH4rAETV3vjfKEnlqG47oIDJzi2PY5zuiSlf3z5pY2nnPjhAFhlBAOGxCV+Ch";
+ sout << "i0Y0ziAO3PKo9YXNF8q2hnzroT2o/GXjOZdC56mkRdzYALMv5vkPTQMBqddjahpZDVJLN/jCsnGF";
+ sout << "fVOCW57+dzaFVT5Zlsk9xqIaUim+bDg2IHv832FM49MIQx7sJAS4lRmsZNlS1NjWKHwsOtgLPK7+";
+ sout << "jRIc6qhU1i7l7cWFd6+oM0U3Sv5yBXTLdTGbWbUniUCn0Izv29BjX9KouaFPRNTyKYfNnoE2dDq/";
+ sout << "jNGe+Uxcnbt8vxewpCqRvS7iGBX+Ylf6MW+HkhFu5eKu2pSEK5JyLLS/+kSRHyLhdmhz1PBRh+mr";
+ sout << "duozCqvGZN1cMESXzPAVSgKE2sFz7au28raq7YvYI+Pe/8AbD75HPkYlEdVu6SXzwNGrCksJuE1A";
+ sout << "u/tAl4GZzEzvyQqUtcf3HD1dWV/ihrtPgXbpCR+GeR9zWrj4MjTDm7ZtsL2NDm599UTNNPJgaIxD";
+ sout << "5coKin30t2hOg6LzFoGKpAwGTijauINY/xAqgQBA8vEQ7uYGK2bkPn9llAAG9e2L+KPKVS6nLyFf";
+ sout << "unzr/rPkxU4VITFN6V9GGZoJQZ/QFiCm6kLO+beLPgsPkZAqJNO5Gl9OTa6728Ew05ZnziMsWJaM";
+ sout << "OaAqjBrE92wtFITs3Qdr+CHKgKqsrGXEUK6hfylB0pOhtNq78gtwsQ8rFzwyu3hoDv4YVEt6FBOx";
+ sout << "zb1KBFOMz2x/RZ1qTdO9bMONUe31rjsOTmiFu8/CQyYfjciF8cp367jbzgcWV2aFhQkY2tA7SL/P";
+ sout << "HCU5bSb4qpqJG4zg+RPv2Hx/6DpmIDXwJgSajh3a7O+HfcMKIRvWOXqKkc4LPK1RmMy8aDYzNnav";
+ sout << "frd4Ii8a2KCeVqsGmJylpypEeMyjQCX8CcIYBYOEVQmUQGAoO71Cauftv1pt1yFAUzpDn+gGGSmN";
+ sout << "pp2ZCm/qPb63lb6Kz94Piq6oVw23zqFrr0pqXr2TEi4e7jTLzMGCcXBUj/qNiCly7TmlFzM4obpO";
+ sout << "1ev0yo6ccmAF8H4BOeyX7lqeqlwZHmpjc/8oa7QwuQnBXB7c7HHm+L9F3N9QoPnEqLtSmrJmNPoL";
+ sout << "qk1d0l1174BQPBZCl6dcCHQdefAZAQ5v66WcPAoZNlt0lVL6CBman1pk5p0e4zU1EPrqYIUxzfBG";
+ sout << "mLv2zWim3OpVjpYpB82fhtIlyIwOst+2rkbeCdIm/3X54LiC/hudzp7zUa4pe99DM8jenauzTIR6";
+ sout << "BdqYbHBRQTc5rKaUmRDq/+JPVaG2dAjWVTdPHLs+rFM3MvLdd0wPG2T26uwwQyhAx+PHT9JEhU+t";
+ sout << "pSJpE/s6LlZmqt6RPuGgYuO6jifhECGWmdy6SgT2wYl/9REvPSsMIMiB9DAbdKAFv7ios1KtcjOq";
+ sout << "pIOUs4NKeJ3QMSU3lE5JXf0V45VBkJ5JfO9lyCMgHRGFd89mRf8/HON8GYkSidyFF+d2Z+po7tHS";
+ sout << "Bhfq86T6T4vTUDE3KCIcsir6kE+hyZylWw+fnBRzWYVYMBp2YCKHybxBKdkxpvAnDLibZYdyEtd/";
+ sout << "20RPeXxxk54lJkFAjNy6vtnh6vLomfNALcZ8oqS8iZWX5v4q35b152XHo9lEWbTxokbbXeMmqLdL";
+ sout << "UjBrCPkD1j9ogboDfWD5TNg60dJVikPCyUHbTSTTEU+I9niREiVDdZToXbeaeRgOKYxtnwWiA/FM";
+ sout << "BYXiPb57y+/il8TD1ZT74JRjz6kAmSa3bM7GClr//V3Mdl3TpoP949ZhX1L16IHc8WwgE0rQnQHV";
+ sout << "NxPxVWS0EhAdmlYB73ib5jPua1rtlwRbeYbDF2iNsuw8ss8phioK+ZR6BnZX0XZN2Tw58Fa1e5kV";
+ sout << "y2kHZGIkY6hk/lrA/vAP+QuV8Pxb/P7VXVsgkHglBYuZkkZLgjqAZDvPjvZcBqEDQjCHrG6V4woD";
+ sout << "X7xWGVYXVRt5vJcwOX0hKCcUSv3+1XL6+ID0urNTuFqMw96QKtb3//H6qZvZIO/bYGjJcDVbQHJB";
+ sout << "6Q//OjqKgSiw67SotBAvIVQrgDeHUzaS7JdjyvWHCooClKrAiIVcnKX02r1y+mVpeL36166MC9D7";
+ sout << "dAi/coCiOwcQ6STbBY2PPhWrko31E8o/l3uDeC4cl3TIVjkFjS61GMogqvd2ISZ3Egq1jZkRYMln";
+ sout << "962j1y32UizRikxEK+2d0UREZ70oZzW3UoD6YqshoElP78GkH7HGS8hjpL+UhuCjCura1FK4/qVG";
+ sout << "yp5GneObh4S9DzJV+IZuviLd6drknWn+nKYS4YPWkbOqBu7RBvh3eb4bQMzvVl2thL1b4Ff8Q79d";
+ sout << "mWxu8ajtptUs0OrSor9gnqrLu3K3RWRTWSElrrMHjCD5GZcsrR/qp7Ip4GDBRPCVJNLa1Tmm8TMc";
+ sout << "kTkGtbN0E96kT2qVA4/s0vYCxntAUPunP0k/JtWg5YzVXQi5/hKDdIE3EfpV2PXjveDoBasXAmEi";
+ sout << "6oUcZk1zKqwFygj1793MM02voesHS7BxWz50W6yD+SozDsnliV7fY4S9+8vj8CipRYSD8t7MoczJ";
+ sout << "hg2tEMYVVBo/nV/4l607tvicVT8HYWOkBhuOgeMNrocTp/3zcfiCODeda8rzbGPWjmKOCPcgOmTD";
+ sout << "TO5jx0f568jqAuhFgeg6wv/3uXuY3mmFGWsCx4Sf+2nlWqhCMKmSsRT84UXcFYpp0Kcvx1OUFC/c";
+ sout << "E64GAqCYYhzb3hF7jys36qQefbfm9+t3owVhP0d9udWNBzw/QeMJ6djmHvk11Tl3qvjGUGV0iXh6";
+ sout << "4ZxyCeLQWZA6cNcN8Ovd2vRrtQ8SHlFPhMKqoevmMLkyZKMrD9CUPHmgzlpTuAasa7PEnbaAFHcO";
+ sout << "sVE1KNm7uMU7QjFI8u10VRJ6qBpTx6Z3GXPq7Jslk2V6z0xmH/elkzMAu2Wr3MId0/7GCuheVZhh";
+ sout << "VAf9EWJ2ZfZxGBOXd8Io9eiatd/VdlOh7FBglIGSpx8UHU3jzpSu1fcnCjVRg3XxWmUI0iRrqxQc";
+ sout << "iVT6ttiezDImpPlHxP2OjgIghWD6EZ2Gesunu93a2lep85rEZqN8sACV2sXDy5K7CySNLyhNE7fX";
+ sout << "eV1bU3FdOstad/82yh7TJdIpIdFV1tEOV+gAOMZa+516EjdnDs4WJpfWHPWG6xdVJWAvCHss1B8s";
+ sout << "k11txpDa6vs3+NCx/mF/6ElB7TxkisPo1m+KfgjBGpI/YHlm6c216WwSh1k1hLk0T3s4wFEb8M5w";
+ sout << "BlbxPkt89Y0wWwc/Eg37XquMIme6aBZjU3CZK1NwiAPkKXa9Y6fBTZWLT3W/NpP2Vk8KGxae2fRo";
+ sout << "F83VOgFxb1SUIePgZ2vMS/6OnuRxNqiwkcDEI15uVcK+l1AynHFqODaA31sxqQuHnw/FOrPG5yJR";
+ sout << "OvmgJOJ9ss9QZkvXaTZFc3vfEZElcKdW9K5xEZfiymZWX9Qihiv4PG/L+qo5Tm6cFxi/8MVW7Tgd";
+ sout << "OV5whxO+RWOCj1kyuILneXxrwiYL3tn74Z4+tT6oP2I9l2UXF574JXrOzLLeBhULrD9fpPFlb/lM";
+ sout << "5YfP97MRad6MEZfY5uMUa36kUR8s7FMdTpSyKqhEmzmepGD7JI0uGPNutfSDO5SJsPfK8Yh1uPWA";
+ sout << "5J5d/NEUp/fqvoniWgm/2ye0ApW90EtX9eDE/DyfPQimpqOdFBO7/EvDkDHlZ0u08F7+sCavgzxE";
+ sout << "jHMXJ+uOF5q96vNnGDhyc8WC9NRzX66GChvTCh7nDhBYLMCZzSvWU7wPWC9OLOCvA+lyTTGvFCgs";
+ sout << "qCrk5Hoc5etEIrpyOCe6evI681jAoNI0KK4Opb/vqVtKgLTxJhBbi/EVhJTzdALEB/WduuYpfqqD";
+ sout << "sUbdDlNRdSGgricRZUxD0iN5GZUBDEzVzp6moEW+Q7NxsryW6/88Ow3ky1fm41QPwQ40Rk45cnIU";
+ sout << "5nZxTrYUkBoaWR/Lh43XtKDh3cBhADGiaMIPBVJL/tOUtXTCj93ca73/LqanFOO/rrnROhXUCn1h";
+ sout << "+LWEWn553i9PF0CqFkFjtK/dP68wPQ7y09/NYYw9P86YdgoWwm6ZlqskD2ByyYgR5KimmyJLQyxb";
+ sout << "CmITcgVbucDiDxUUFG4Rq9Sgn4PW6jGaoujCgclaxOgLnM7lqREfW8XlN/yH8Pvyt7hjAIWNOHyg";
+ sout << "hiscgSb1ev+btDSm+aRCjPgRNfBDApl2zINR7ey0wA+wOl6wZh7cxYwbeb55/UqH6OzEXbjJzjki";
+ sout << "6c9Z95PGsu8U+FaBJmWyJWqb/SxBxsf4zPH154KF7wE6b0AoP2Y1tGD2hDN8WdEthVWbztLm66nI";
+ sout << "n0fz38pv7CNJCuVf+jeqvk6bh8/dI0hVF0CsMd5uVe+fOLzK3nmIma5nB/TQvcCqfR1YjDGV3mhd";
+ sout << "QPQ1a7ypy6WplKOU5LamJw8IaYrik3Gg2tLo63FfE8njUkrYdtqiYpgk+okF8vF4S8c8NdWlgTNn";
+ sout << "KE+9+dPQZr174KhRz1JwzspIUqW/d/QjGIgJm5NcXAo07/hgpfI2zfHlkGP5ETxLnquzkCcCsrnK";
+ sout << "IQ+vacYRwHRVSTVWpvQiDXaNK7RJrfW2ei23/bW8YkE72+lZIACjehXSDgKiBg0EIStpOVLwTSln";
+ sout << "fXVW5fHQkyAc4fOVX6bGqcBu40hHN/LRFJY7Jv5qS/AjCtLUR4UCRhMWAo7ITDLgbXoIDKnyrRLF";
+ sout << "vS7LGaJ9IHbFv0Lw8Xiqr1YSpjnLZZDcREIsjCf+2G/7kNBY96DPmp63sGuBfIR0f/I6MzUi8HNG";
+ sout << "R+EBmcCO0AfPzTeoGkIM9+5YmMpXwvMU0zCz71PxeLlnTUVUggnXTBAa92YS6HuhGKIl5aU0OCXW";
+ sout << "76CfhWijfNZCG33EdpOJbXxbtbX97y+Xjn1KyPnNW/Wyk5VO8lqD2ZGC1Wnohpx4q2VtDkVzw9h+";
+ sout << "DDbj2hpyfv2byvYJPUSlZ4bJ77aCkYuZsjMOaz4MYjVfgdi7svHdnIBauBgv+rubJAwx8oB9bk3K";
+ sout << "Wy2EDPfLHGxtUKt9+keHGnx9xZ9vGjLXzaLiGltRQ5YDiVK/vXPUc8eKy7p+NyNu80yopPIjJcgV";
+ sout << "IC0mPwR2BvdGVweS/iLAuMVJw3STNQyl7XVej6lppzt2SWidtjqxYnYjJwRYNnFr0l7SgVA9Tink";
+ sout << "jZz7M4OGdDN521PxFXjOcRRRM1Kr7/56n+VF+LdTRCihmUsOuGi4bBlZ+eKTcUmg7FX6LXKptWLm";
+ sout << "h6QUA5ZXcUZ9XUrJb1AxaHmfP/XXd9r1bsi+F3JdFUPuCOZHrzok6cshx+9r8WH1MNCBBVTeUrFl";
+ sout << "+65D1RufbIVoqZfQ8l3+zghe+U/ujnqEox8ysQAZkhYG6beS6ksj/QhWbGno9W8TeaXr2NvcF2ct";
+ sout << "PAh85hlFNWcxBdGnHXl8fQJ+p5+t8+jftQLoVqd6B9/beAUgp9KYrLugImLUnzw6Q7PjXCiFmo8d";
+ sout << "gqPvKtiqj4mKu5L5L4/ul2zdN4zQtRE6tm24+ENVyfJ00adWwqEGxrCIvUMiNbOzSI9TtoTVrKgx";
+ sout << "lON3nzfcVy3ucT93AtQfa2reu/HwRVNhn0GP/xZGBCy746if7jr5Fa34dAESoJ1mlelIKOion8QJ";
+ sout << "bsgQixF1UWwd1/O4DgHG+U43HYhLWOSw+NOnajpEaVoKB+jf/M0u5+BQvbwCcE6zd+Vhmhzy6x7i";
+ sout << "5ySbdBren80S4wUt97VFIq7HI0KvvWpkbAeQjGOhAWIqRrtbJup9yQa09VJiR0hiMQ8K2UlTx3By";
+ sout << "BFtRAiJO5mn85Q0YxqyogwSE2Kf4F7qu3EBnT/rEta+a0e/UcElFcZPpJfqSwS18mNBPDEyagvBv";
+ sout << "BQiEumckYIVwtFMpfrvELupxCTY51anDR44j+RTJlQMf91CHDZi9eewt9Rfo4ja5HTTdK5Wmjkvt";
+ sout << "DR2SBS/Z4YJX4elAbie5bsrFWR7IH9M8M/vtR++6Sw2SKEAQjS7D23xmoKXD/bC9JWoz9P0yICkJ";
+ sout << "Ckdwh6H+FNzh6Ms8qlcEWtQD28SI271I9tt0Yiw7FvwIYH/+aG7+PnXpyei7Y8OmV2scS2CaAR2e";
+ sout << "+k0qUumoXI/yPunLeJE8BQHjslaw8eglGnb7YjfgmoGPC3lzrc69UX3Okmnb8JJUoC3KEsN0Zf0n";
+ sout << "GPErNUSm7ABFRAQvHvTiDxfGyf8aN6UNYZDC0rNO5MgYpQxvmeTnCMtXz0P9OEpaBs3ZuNoJVR+n";
+ sout << "s1zhiz5JteC9VmCFM8PknTz2tQXBWonqXdHxCPOIi7k0K1Jn8ddygrfUczXeMeZY/H/akoWADYHT";
+ sout << "cGPqpi/81wQUaE4GmV6X27XcuHUcELZTkyTqpkBprMoPJJDspnbFj0aaHiz19Ws0weE3wWhUFiA4";
+ sout << "D4WlUi4uF40OEAhnv0qcSVHil71C7AEpTcS1dxe+GSs02PppfT+PZDexeWk/S/C8E+O/Fr4BVfO0";
+ sout << "G+M+XdcvagPHLZYLTEJhLhvQzPQRbOb9vdPUhiIsodTUzejdQ+zA6k7ynQqNEDAaVwyy82AQRyoi";
+ sout << "4BFGQzQ2hYCUp+In9B/69qyBSvT+q6E1YDUARctcr3h/AR/3J7uG4VdxhFA6YOKAnJGCSCMdunwi";
+ sout << "hb1or9pPHhW5rDfvmtH9iZCsuvvJmFA8FXmAURRDRziSxYR/zcApR0JBT63mtyw3lMCHP1qDg95w";
+ sout << "0ysG3hLdyMI3lnM/g1W0tg1hvXR0C3qjivvuthJHyt3fqSZbVnvNEkD2k+7BoG6KWk5EpoSqv+Ni";
+ sout << "znQvLFggfPOtd6Y8BkmBcNhCTCcyMa2ZkvzUNO7419VR/gNeU56athpUvD21r8d0/MWrEi2yKWsM";
+ sout << "aCVwYr/eAgp9kpn9i8RuSJmxe4NTb7Isqn25vOfUeUY17vS+c6wjCDqbkmukld4nOHl8uHlHdJoK";
+ sout << "elhC/yXSgUyOM8yKFSyw+Ap6v9CISYh3y09g6gBL35LjS7Vb4Ew/ks2t2L9zWrU+IC4BLcQDmMfG";
+ sout << "QUk5tt1n05bmuLmnSboYhPj2MNT9/EiqHY9AfwPhMYmWbcSCshdShSoYs1aciC1gc1H393KbvaTl";
+ sout << "N6TeDKz6bEkvM0mOu+wRDR5u+tnr0wWSW8JkMt+WkjPvqWkGYiKZuhJT/4rTPRhJEKXE34gxeyFA";
+ sout << "16oufb8W9C8Fj30WPJq1bd43y/r0Fs1dwOyZAyLsHv/2AtO5/Jw2KEHsTIgolGr0Qe6//jc/c50R";
+ sout << "TN/aDVTBCbvhcRtV/zOuy7oqb+p0IM3p7eWtKFrZ1wXly7FvvPQ5ODn9CnG5wo59XTikanWOMLKf";
+ sout << "MwQt89ZW60v9bQTgIUnFYpVlp9SF2XPxmeW6x+NVCEOiXZQ8AM3nDiLA5M0ctrI2DzzLeHoEKWem";
+ sout << "TvVZ1MrwJ8jZoRsigx45HqoB353+bXlS+5vMB/M+zUEPu/HHuU5k8zqwU93NFkTR208ZtecG3IPs";
+ sout << "ENQd7J10XsbhbYhAvEkU9ZS/3FS7aK7bDvf7cqD/QcGCfIad4rk1Ks6nuGSeR7hUrH1NK1fe9lpY";
+ sout << "lNk3EaNPlBEBmNkbGpNIUcJ7ntUy+b1q+VoC+q300a4qo3gIOkN3s25NaDJime/eJmlRYZhH4ip8";
+ sout << "6+m+nA8as1/T0/4d7HXiFWQwZ49NZsSESry1yCs97C5IQ6nScahxHDi72AbQeNDB+RtGaQJiOUi7";
+ sout << "NSBNuSlm9G9GpR55HvMi/JUqF09BrOn+49zwYqMbSkf8CjoenLM7UzCoywGlk0nXSBsbANAz7D2B";
+ sout << "65qWXswtvr4xnCSfRLUNHs3AtlfEqpsdlaF9gDQm5Z4IhT1WOXbETcY5S8T/5DDBIoWi9rHnKuxM";
+ sout << "+zmu881jhj3d9p8fbcH65hevrYM49+ZQfWPXTMUY77YbwYTGmYgScAWmqaMugoMWD1ocpYYRM0IJ";
+ sout << "+SEiUb57moAOoEeiYZcPqmckTWuHJhYtgbuBojwXqaK/qvDssM/59aTMWagOYHcapC4gBG+s99No";
+ sout << "pOCnbe2brIm4+6xWs7LzSA38RZHZSdh66V3n+83R0/wAIw9+X35SXMwrXC96OqXF/6AFvqkL2Wnk";
+ sout << "SBbvyq0txWR6b7AaZ418Dmngg3yQh04fwc8xZLy7/1ZYAbGLRRV1mNrpc2Fa1kLjxoRHZMBA75Pt";
+ sout << "HirY4CHOvKaEdlk27BW2px1QCTCkZQ/gojWhiZ1kPUAUiW7VcyFSzjtXzswHEIAnGR2dWHgGZDVT";
+ sout << "OVuBJ0nTPs8itQ2Htelag60Et9jrwDZzYa4Rhy1FjngWN1S/QAp9iGe95SRVXuBtLNgAVp+sx7SU";
+ sout << "VOECSHoLfpSeZPvlm5ibeSN83gFbIG2rsTZ3IlvJjWq82Npzas6p9WVKTEPGS+Ux8nWIBT/enw7o";
+ sout << "7KX9phVWQqcYH5IB2waRO+Ke7h6/y696NQMq0R4Xbki9lmjoWNKFtM+GgLygVqxWWWp9iyQFkUQx";
+ sout << "7tRJT1da0ImlgCXS/uTTRxvcG9d/E5FMotFa6mA7py7P+eraScFdEHL4J0kA";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+
+
+
+}
+
+
+
+
+
diff --git a/ml/dlib/dlib/test/ranking.cpp b/ml/dlib/dlib/test/ranking.cpp
new file mode 100644
index 000000000..83f70d8cc
--- /dev/null
+++ b/ml/dlib/dlib/test/ranking.cpp
@@ -0,0 +1,485 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/dnn.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <map>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.ranking");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void brute_force_count_ranking_inversions (
+ const std::vector<T>& x,
+ const std::vector<T>& y,
+ std::vector<unsigned long>& x_count,
+ std::vector<unsigned long>& y_count
+ )
+ {
+ x_count.assign(x.size(),0);
+ y_count.assign(y.size(),0);
+
+ for (unsigned long i = 0; i < x.size(); ++i)
+ {
+ for (unsigned long j = 0; j < y.size(); ++j)
+ {
+ if (x[i] <= y[j])
+ {
+ x_count[i]++;
+ y_count[j]++;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_count_ranking_inversions()
+ {
+ print_spinner();
+ dlog << LINFO << "in test_count_ranking_inversions()";
+
+ dlib::rand rnd;
+ std::vector<int> x, y;
+ std::vector<unsigned long> x_count, y_count;
+ std::vector<unsigned long> x_count2, y_count2;
+ for (int iter = 0; iter < 5000; ++iter)
+ {
+ x.resize(rnd.get_random_32bit_number()%10);
+ y.resize(rnd.get_random_32bit_number()%10);
+ for (unsigned long i = 0; i < x.size(); ++i)
+ x[i] = ((int)rnd.get_random_32bit_number()%10) - 5;
+ for (unsigned long i = 0; i < y.size(); ++i)
+ y[i] = ((int)rnd.get_random_32bit_number()%10) - 5;
+
+ count_ranking_inversions(x, y, x_count, y_count);
+ brute_force_count_ranking_inversions(x, y, x_count2, y_count2);
+
+ DLIB_TEST(mat(x_count) == mat(x_count2));
+ DLIB_TEST(mat(y_count) == mat(y_count2));
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_prior_test()
+ {
+ print_spinner();
+ typedef matrix<double,3,1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+ svm_rank_trainer<kernel_type> trainer;
+
+ ranking_pair<sample_type> data;
+
+ sample_type samp;
+ samp = 0, 0, 1; data.relevant.push_back(samp);
+ samp = 0, 1, 0; data.nonrelevant.push_back(samp);
+
+ trainer.set_c(10);
+ decision_function<kernel_type> df = trainer.train(data);
+
+ trainer.set_prior(df);
+
+ data.relevant.clear();
+ data.nonrelevant.clear();
+ samp = 1, 0, 0; data.relevant.push_back(samp);
+ samp = 0, 1, 0; data.nonrelevant.push_back(samp);
+
+ df = trainer.train(data);
+
+ dlog << LINFO << trans(df.basis_vectors(0));
+ DLIB_TEST(df.basis_vectors(0)(0) > 0);
+ DLIB_TEST(df.basis_vectors(0)(1) < 0);
+ DLIB_TEST(df.basis_vectors(0)(2) > 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_prior_sparse_test()
+ {
+ print_spinner();
+ typedef std::map<unsigned long,double> sample_type;
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+ svm_rank_trainer<kernel_type> trainer;
+
+ ranking_pair<sample_type> data;
+
+ sample_type samp;
+ samp[0] = 1; data.relevant.push_back(samp); samp.clear();
+ samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
+
+ trainer.set_c(10);
+ decision_function<kernel_type> df = trainer.train(data);
+
+ trainer.set_prior(df);
+
+ data.relevant.clear();
+ data.nonrelevant.clear();
+ samp[2] = 1; data.relevant.push_back(samp); samp.clear();
+ samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
+
+ df = trainer.train(data);
+
+ matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
+ dlog << LINFO << trans(w);
+ DLIB_TEST(w(0) > 0.1);
+ DLIB_TEST(w(1) < -0.1);
+ DLIB_TEST(w(2) > 0.1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void dotest1()
+ {
+ print_spinner();
+ dlog << LINFO << "in dotest1()";
+
+ typedef matrix<double,4,1> sample_type;
+
+ typedef linear_kernel<sample_type> kernel_type;
+
+ svm_rank_trainer<kernel_type> trainer;
+
+
+ std::vector<ranking_pair<sample_type> > samples;
+
+ ranking_pair<sample_type> p;
+ sample_type samp;
+
+ samp = 0, 0, 0, 1; p.relevant.push_back(samp);
+ samp = 1, 0, 0, 0; p.nonrelevant.push_back(samp);
+ samples.push_back(p);
+
+ samp = 0, 0, 1, 0; p.relevant.push_back(samp);
+ samp = 1, 0, 0, 0; p.nonrelevant.push_back(samp);
+ samp = 0, 1, 0, 0; p.nonrelevant.push_back(samp);
+ samp = 0, 1, 0, 0; p.nonrelevant.push_back(samp);
+ samples.push_back(p);
+
+
+ trainer.set_c(10);
+
+ decision_function<kernel_type> df = trainer.train(samples);
+
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ matrix<double,1,2> res;
+ res = 1,1;
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+ DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res));
+
+ trainer.set_epsilon(1e-13);
+ df = trainer.train(samples);
+
+ dlog << LINFO << df.basis_vectors(0);
+ sample_type truew;
+ truew = -0.5, -0.5, 0.5, 0.5;
+ DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10);
+
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+ dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2);
+ DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001);
+
+ trainer.set_learns_nonnegative_weights(true);
+ df = trainer.train(samples);
+ truew = 0, 0, 1.0, 1.0;
+ dlog << LINFO << df.basis_vectors(0);
+ DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10);
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+
+ samples.clear();
+ samples.push_back(p);
+ samples.push_back(p);
+ samples.push_back(p);
+ samples.push_back(p);
+ dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4);
+ DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4) , res));
+
+ df.basis_vectors(0) = 0;
+ dlog << LINFO << "BAD RANKING:" << test_ranking_function(df, samples);
+ DLIB_TEST(test_ranking_function(df, samples)(1) < 0.5);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void dotest_sparse_vectors()
+ {
+ print_spinner();
+ dlog << LINFO << "in dotest_sparse_vectors()";
+
+ typedef std::map<unsigned long,double> sample_type;
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+ svm_rank_trainer<kernel_type> trainer;
+
+
+ std::vector<ranking_pair<sample_type> > samples;
+
+ ranking_pair<sample_type> p;
+ sample_type samp;
+
+ samp[3] = 1; p.relevant.push_back(samp); samp.clear();
+ samp[0] = 1; p.nonrelevant.push_back(samp); samp.clear();
+ samples.push_back(p);
+
+ samp[2] = 1; p.relevant.push_back(samp); samp.clear();
+ samp[0] = 1; p.nonrelevant.push_back(samp); samp.clear();
+ samp[1] = 1; p.nonrelevant.push_back(samp); samp.clear();
+ samp[1] = 1; p.nonrelevant.push_back(samp); samp.clear();
+ samples.push_back(p);
+
+
+ trainer.set_c(10);
+
+ decision_function<kernel_type> df = trainer.train(samples);
+
+ matrix<double,1,2> res;
+ res = 1,1;
+
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+ DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res));
+
+ trainer.set_epsilon(1e-13);
+ df = trainer.train(samples);
+
+ dlog << LINFO << sparse_to_dense(df.basis_vectors(0));
+ sample_type truew;
+ truew[0] = -0.5;
+ truew[1] = -0.5;
+ truew[2] = 0.5;
+ truew[3] = 0.5;
+ DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10);
+
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+ dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2);
+ DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001);
+
+ trainer.set_learns_nonnegative_weights(true);
+ df = trainer.train(samples);
+ truew[0] = 0.0;
+ truew[1] = 0.0;
+ truew[2] = 1.0;
+ truew[3] = 1.0;
+ dlog << LINFO << sparse_to_dense(df.basis_vectors(0));
+ DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10);
+ dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
+ DLIB_TEST(equal(test_ranking_function(df, samples), res));
+
+
+ samples.clear();
+ samples.push_back(p);
+ samples.push_back(p);
+ samples.push_back(p);
+ samples.push_back(p);
+ dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4);
+ DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4) , res) );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K, bool use_dcd_trainer>
+ class simple_rank_trainer
+ {
+ public:
+ template <typename T>
+ decision_function<K> train (
+ const ranking_pair<T>& pair
+ ) const
+ {
+ typedef matrix<double,10,1> sample_type;
+
+ std::vector<sample_type> relevant = pair.relevant;
+ std::vector<sample_type> nonrelevant = pair.nonrelevant;
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+ for (unsigned long i = 0; i < relevant.size(); ++i)
+ {
+ for (unsigned long j = 0; j < nonrelevant.size(); ++j)
+ {
+ samples.push_back(relevant[i] - nonrelevant[j]);
+ labels.push_back(+1);
+ samples.push_back(nonrelevant[i] - relevant[j]);
+ labels.push_back(-1);
+ }
+ }
+
+ if (use_dcd_trainer)
+ {
+ svm_c_linear_dcd_trainer<K> trainer;
+ trainer.set_c(1.0/samples.size());
+ trainer.set_epsilon(1e-10);
+ trainer.force_last_weight_to_1(true);
+ //trainer.be_verbose();
+ return trainer.train(samples, labels);
+ }
+ else
+ {
+ svm_c_linear_trainer<K> trainer;
+ trainer.set_c(1.0);
+ trainer.set_epsilon(1e-13);
+ trainer.force_last_weight_to_1(true);
+ //trainer.be_verbose();
+ decision_function<K> df = trainer.train(samples, labels);
+ DLIB_TEST_MSG(df.b == 0, df.b);
+ return df;
+ }
+ }
+ };
+
+ template <bool use_dcd_trainer>
+ void test_svmrank_weight_force_dense()
+ {
+ print_spinner();
+ dlog << LINFO << "use_dcd_trainer: "<< use_dcd_trainer;
+
+ typedef matrix<double,10,1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+ ranking_pair<sample_type> pair;
+
+ for (int i = 0; i < 20; ++i)
+ {
+ pair.relevant.push_back(abs(gaussian_randm(10,1,i)));
+ }
+
+ for (int i = 0; i < 20; ++i)
+ {
+ pair.nonrelevant.push_back(-abs(gaussian_randm(10,1,i+10000)));
+ pair.nonrelevant.back()(9) += 1;
+ }
+
+
+ svm_rank_trainer<kernel_type> trainer;
+ trainer.force_last_weight_to_1(true);
+ trainer.set_epsilon(1e-13);
+ //trainer.be_verbose();
+ decision_function<kernel_type> df;
+ df = trainer.train(pair);
+
+ matrix<double,1,2> res;
+ res = 1,1;
+ dlog << LINFO << "weights: "<< trans(df.basis_vectors(0));
+ const matrix<double,1,2> acc1 = test_ranking_function(df, pair);
+ dlog << LINFO << "ranking accuracy: " << acc1;
+ DLIB_TEST(equal(acc1,res));
+
+ simple_rank_trainer<kernel_type,use_dcd_trainer> strainer;
+ decision_function<kernel_type> df2;
+ df2 = strainer.train(pair);
+ dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
+ const matrix<double,1,2> acc2 = test_ranking_function(df2, pair);
+ dlog << LINFO << "ranking accuracy: " << acc2;
+ DLIB_TEST(equal(acc2,res));
+
+ dlog << LINFO << "w error: " << max(abs(df.basis_vectors(0) - df2.basis_vectors(0)));
+ dlog << LINFO << "b error: " << abs(df.b - df2.b);
+ DLIB_TEST(std::abs(max(abs(df.basis_vectors(0) - df2.basis_vectors(0)))) < 1e-8);
+ DLIB_TEST(std::abs(abs(df.b - df2.b)) < 1e-8);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_dnn_ranking_loss()
+ {
+ print_spinner();
+ typedef matrix<double,2,1> sample_type;
+
+
+ ranking_pair<sample_type> data;
+ sample_type samp;
+
+ // Make one relevant example.
+ samp = 1, 0;
+ data.relevant.push_back(samp);
+
+ // Now make a non-relevant example.
+ samp = 0, 1;
+ data.nonrelevant.push_back(samp);
+
+
+ using net_type = loss_ranking<fc_no_bias<1,input<matrix<float,2,1>>>>;
+ net_type net;
+ dnn_trainer<net_type> trainer(net, sgd(1.0, 0.9));
+ std::vector<matrix<float,2,1>> x;
+ std::vector<float> y;
+
+ x.push_back(matrix_cast<float>(data.relevant[0])); y.push_back(1);
+ x.push_back(matrix_cast<float>(data.nonrelevant[0])); y.push_back(-1);
+
+ //trainer.be_verbose();
+ trainer.set_learning_rate_schedule(logspace(-1, -7, 4000));
+ trainer.train(x,y);
+
+ matrix<float> params = mat(net.subnet().layer_details().get_layer_params());
+ dlog << LINFO << "params: "<< params;
+ dlog << LINFO << "relevant output score: " << net(x[0]);
+ dlog << LINFO << "nonrelevant output score: " << net(x[1]);
+
+ DLIB_TEST(std::abs(params(0) - 1) < 0.001);
+ DLIB_TEST(std::abs(params(1) + 1) < 0.001);
+ DLIB_TEST(std::abs(net(x[0]) - 1) < 0.001);
+ DLIB_TEST(std::abs(net(x[1]) + 1) < 0.001);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class test_ranking_tools : public tester
+ {
+ public:
+ test_ranking_tools (
+ ) :
+ tester ("test_ranking",
+ "Runs tests on the ranking tools.")
+ {}
+
+
+ void perform_test (
+ )
+ {
+ test_count_ranking_inversions();
+ dotest1();
+ dotest_sparse_vectors();
+ test_svmrank_weight_force_dense<true>();
+ test_svmrank_weight_force_dense<false>();
+ run_prior_test();
+ run_prior_sparse_test();
+ test_dnn_ranking_loss();
+
+ }
+ } a;
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/read_write_mutex.cpp b/ml/dlib/dlib/test/read_write_mutex.cpp
new file mode 100644
index 000000000..fb8bdb84c
--- /dev/null
+++ b/ml/dlib/dlib/test/read_write_mutex.cpp
@@ -0,0 +1,208 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/misc_api.h>
+#include <dlib/threads.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.read_write_mutex");
+
+ class read_write_mutex_tester : public tester, multithreaded_object
+ {
+ public:
+ read_write_mutex_tester (
+ ) :
+ tester ("test_read_write_mutex",
+ "Runs tests on the read_write_mutex component.")
+ {
+ register_thread(*this, &read_write_mutex_tester::thread_write);
+ register_thread(*this, &read_write_mutex_tester::thread_write);
+ register_thread(*this, &read_write_mutex_tester::thread_write);
+
+ register_thread(*this, &read_write_mutex_tester::thread_readonly);
+ register_thread(*this, &read_write_mutex_tester::thread_readonly);
+ register_thread(*this, &read_write_mutex_tester::thread_readonly);
+ register_thread(*this, &read_write_mutex_tester::thread_readonly2);
+ register_thread(*this, &read_write_mutex_tester::thread_readonly2);
+ register_thread(*this, &read_write_mutex_tester::thread_readonly2);
+
+ }
+
+ read_write_mutex m;
+
+ dlib::mutex mut;
+ int num_write;
+ int num_read;
+ int max_read;
+
+ bool failure;
+
+ void thread_write ()
+ {
+ // do this so that the readonly threads can get into their loops first. This way
+ // we can see if the mutex lets many readers into their area
+ dlib::sleep(250);
+ for (int i = 0; i < 6; ++i)
+ {
+ auto_mutex lock(m);
+
+ mut.lock();
+ ++num_write;
+ mut.unlock();
+
+ // only one write thread should ever be active at once
+ if (num_write != 1)
+ {
+ failure = true;
+ dlog << LERROR << "1";
+ }
+
+ dlib::sleep(300);
+
+ // only one write thread should ever be active at once
+ if (num_write != 1)
+ {
+ failure = true;
+ dlog << LERROR << "2";
+ }
+
+ mut.lock();
+ --num_write;
+ mut.unlock();
+
+ print_spinner();
+ }
+ dlog << LINFO << "exit thread_write()";
+ }
+
+ void do_readonly_stuff()
+ {
+ mut.lock();
+ ++num_read;
+ max_read = max(num_read, max_read);
+ mut.unlock();
+
+ if (num_write != 0)
+ {
+ failure = true;
+ dlog << LERROR << "3";
+ }
+
+ dlib::sleep(300);
+
+ if (num_write != 0)
+ {
+ failure = true;
+ dlog << LERROR << "4";
+ }
+
+ mut.lock();
+ max_read = max(num_read, max_read);
+ --num_read;
+ mut.unlock();
+
+ print_spinner();
+ }
+
+ void thread_readonly ()
+ {
+ for (int i = 0; i < 6; ++i)
+ {
+ auto_mutex_readonly lock(m);
+ DLIB_TEST(lock.has_read_lock());
+ DLIB_TEST(!lock.has_write_lock());
+ do_readonly_stuff();
+
+ lock.lock_readonly();
+ DLIB_TEST(lock.has_read_lock());
+ DLIB_TEST(!lock.has_write_lock());
+ lock.unlock();
+ DLIB_TEST(!lock.has_read_lock());
+ DLIB_TEST(!lock.has_write_lock());
+ lock.lock_readonly();
+ DLIB_TEST(lock.has_read_lock());
+ DLIB_TEST(!lock.has_write_lock());
+ lock.lock_write();
+ DLIB_TEST(!lock.has_read_lock());
+ DLIB_TEST(lock.has_write_lock());
+ lock.lock_write();
+ DLIB_TEST(!lock.has_read_lock());
+ DLIB_TEST(lock.has_write_lock());
+ }
+
+ dlog << LINFO << "exit thread_readonly()";
+ }
+
+ void thread_readonly2 ()
+ {
+ for (int i = 0; i < 6; ++i)
+ {
+ m.lock_readonly();
+ auto_unlock_readonly unlock(m);
+
+ do_readonly_stuff();
+ }
+ dlog << LINFO << "exit thread_readonly2()";
+ }
+
+
+ void perform_test (
+ )
+ {
+ num_write = 0;
+ num_read = 0;
+ max_read = 0;
+ failure = false;
+
+ // doing this big block of weird stuff should have no effect.
+ {
+ m.unlock();
+
+ m.lock_readonly();
+ m.lock_readonly();
+ m.unlock();
+ m.unlock_readonly();
+ m.unlock();
+ m.unlock_readonly();
+
+ m.unlock();
+ m.unlock_readonly();
+
+ m.lock();
+ m.unlock_readonly();
+ m.unlock_readonly();
+ m.unlock();
+ }
+
+
+ // start up our testing threads
+ start();
+
+ // wait for the threads to finish
+ wait();
+
+
+ DLIB_TEST(failure == false);
+ DLIB_TEST_MSG(max_read == 6, "max_read: "<< max_read);
+
+ }
+
+ } a;
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/reference_counter.cpp b/ml/dlib/dlib/test/reference_counter.cpp
new file mode 100644
index 000000000..330ceed94
--- /dev/null
+++ b/ml/dlib/dlib/test/reference_counter.cpp
@@ -0,0 +1,122 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/reference_counter.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.reference_counter");
+
+ template <
+ typename ref_counter
+ >
+ void reference_counter_test (
+ )
+ /*!
+ requires
+ - ref_counter is an implementation of reference_counter/reference_counter_kernel_abstract.h
+ and is instantiated to contain an int
+ ensures
+ - runs tests on reference_counter for compliance with the specs
+ !*/
+ {
+
+ ref_counter a, b, c;
+
+ for (long i = 0; i < 10; ++i)
+ {
+ print_spinner();
+ for (long j = 0; j < 10000; ++j)
+ {
+ a.modify() = j;
+ b.modify() = j+1;
+ c.modify() = j+2;
+ DLIB_ASSERT(a.access() == j,"");
+ DLIB_ASSERT(b.access() == j+1,"");
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(a.modify() == j,"");
+ DLIB_ASSERT(b.modify() == j+1,"");
+ DLIB_ASSERT(c.modify() == j+2,"");
+ DLIB_ASSERT(a.access() == j,"");
+ DLIB_ASSERT(b.access() == j+1,"");
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(a.modify() == j,"");
+ DLIB_ASSERT(b.modify() == j+1,"");
+ DLIB_ASSERT(c.modify() == j+2,"");
+ a = c;
+ DLIB_ASSERT(a.access() == j+2,"");
+ DLIB_ASSERT(b.access() == j+1,"");
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(a.modify() == j+2,"");
+ DLIB_ASSERT(b.modify() == j+1,"");
+ DLIB_ASSERT(c.modify() == j+2,"");
+ DLIB_ASSERT(a.access() == j+2,"");
+ DLIB_ASSERT(b.access() == j+1,"");
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(a.modify() == j+2,"");
+ DLIB_ASSERT(b.modify() == j+1,"");
+ DLIB_ASSERT(c.modify() == j+2,"");
+
+ a = b = c;
+ DLIB_ASSERT(a.access() == b.access(),"");
+ DLIB_ASSERT(a.access() == c.access(),"");
+ DLIB_ASSERT(c.access() == b.access(),"");
+ a.modify() = j;
+ DLIB_ASSERT(a.access() == j,"");
+ DLIB_ASSERT(a.access() != b.access(),"");
+ DLIB_ASSERT(a.access() != c.access(),"");
+ DLIB_ASSERT(c.access() == b.access(),"");
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(b.access() == j+2,"");
+
+ DLIB_ASSERT(a.access() == j,"");
+ a = a;
+ DLIB_ASSERT(a.access() == j,"");
+ c = c;
+ DLIB_ASSERT(c.access() == j+2,"");
+ DLIB_ASSERT(b.access() == j+2,"");
+ swap(a,c);
+ DLIB_ASSERT(a.access() == j+2,"");
+ DLIB_ASSERT(c.access() == j,"");
+ DLIB_ASSERT(b.access() == j+2,"");
+ }
+ }
+
+ }
+
+
+
+
+
+ class reference_counter_tester : public tester
+ {
+ public:
+ reference_counter_tester (
+ ) :
+ tester ("test_reference_counter",
+ "Runs tests on the reference_counter component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ reference_counter_test<reference_counter<int>::kernel_1a> ();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/rls.cpp b/ml/dlib/dlib/test/rls.cpp
new file mode 100644
index 000000000..c4516ad74
--- /dev/null
+++ b/ml/dlib/dlib/test/rls.cpp
@@ -0,0 +1,196 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <cmath>
+#include <dlib/svm.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.rls");
+
+
+ void test_rls()
+ {
+ dlib::rand rnd;
+
+ running_stats<double> rs1, rs2, rs3, rs4, rs5;
+
+ for (int k = 0; k < 2; ++k)
+ {
+ for (long num_vars = 1; num_vars < 4; ++num_vars)
+ {
+ print_spinner();
+ for (long size = 1; size < 300; ++size)
+ {
+ {
+ matrix<double> X = randm(size,num_vars,rnd);
+ matrix<double,0,1> Y = randm(size,1,rnd);
+
+
+ const double C = 1000;
+ const double forget_factor = 1.0;
+ rls r(forget_factor, C);
+ for (long i = 0; i < Y.size(); ++i)
+ {
+ r.train(trans(rowm(X,i)), Y(i));
+ }
+
+
+ matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+
+ rs1.add(length(r.get_w() - w));
+ }
+
+ {
+ matrix<double> X = randm(size,num_vars,rnd);
+ matrix<double,0,1> Y = randm(size,1,rnd);
+
+ matrix<double,0,1> G(size,1);
+
+ const double C = 10000;
+ const double forget_factor = 0.8;
+ rls r(forget_factor, C);
+ for (long i = 0; i < Y.size(); ++i)
+ {
+ r.train(trans(rowm(X,i)), Y(i));
+
+ G(i) = std::pow(forget_factor, i/2.0);
+ }
+
+ G = flipud(G);
+
+ X = diagm(G)*X;
+ Y = diagm(G)*Y;
+
+ matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+
+ rs5.add(length(r.get_w() - w));
+ }
+
+ {
+ matrix<double> X = randm(size,num_vars,rnd);
+ matrix<double> Y = colm(X,0)*10;
+
+
+ const double C = 1000000;
+ const double forget_factor = 1.0;
+ rls r(forget_factor, C);
+ for (long i = 0; i < Y.size(); ++i)
+ {
+ r.train(trans(rowm(X,i)), Y(i));
+ }
+
+
+ matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+
+ rs2.add(length(r.get_w() - w));
+ }
+
+ {
+ matrix<double> X = join_rows(randm(size,num_vars,rnd)-0.5, ones_matrix<double>(size,1));
+ matrix<double> Y = uniform_matrix<double>(size,1,10);
+
+
+ const double C = 1e7;
+ const double forget_factor = 1.0;
+
+ matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+
+ rls r(forget_factor, C);
+ for (long i = 0; i < Y.size(); ++i)
+ {
+ r.train(trans(rowm(X,i)), Y(i));
+ rs3.add(std::abs(r(trans(rowm(X,i))) - 10));
+ }
+
+
+ }
+ {
+ matrix<double> X = randm(size,num_vars,rnd)-0.5;
+ matrix<double> Y = colm(X,0)*10;
+
+
+ const double C = 1e6;
+ const double forget_factor = 0.7;
+
+
+ rls r(forget_factor, C);
+ DLIB_TEST(std::abs(r.get_c() - C) < 1e-10);
+ DLIB_TEST(std::abs(r.get_forget_factor() - forget_factor) < 1e-15);
+ DLIB_TEST(r.get_w().size() == 0);
+
+ for (long i = 0; i < Y.size(); ++i)
+ {
+ r.train(trans(rowm(X,i)), Y(i));
+ rs4.add(std::abs(r(trans(rowm(X,i))) - X(i,0)*10));
+ }
+
+ DLIB_TEST(r.get_w().size() == num_vars);
+
+ decision_function<linear_kernel<matrix<double,0,1> > > df = r.get_decision_function();
+ DLIB_TEST(std::abs(df(trans(rowm(X,0))) - r(trans(rowm(X,0)))) < 1e-15);
+ }
+ }
+ }
+ }
+
+ dlog << LINFO << "rs1.mean(): " << rs1.mean();
+ dlog << LINFO << "rs2.mean(): " << rs2.mean();
+ dlog << LINFO << "rs3.mean(): " << rs3.mean();
+ dlog << LINFO << "rs4.mean(): " << rs4.mean();
+ dlog << LINFO << "rs5.mean(): " << rs5.mean();
+ dlog << LINFO << "rs1.max(): " << rs1.max();
+ dlog << LINFO << "rs2.max(): " << rs2.max();
+ dlog << LINFO << "rs3.max(): " << rs3.max();
+ dlog << LINFO << "rs4.max(): " << rs4.max();
+ dlog << LINFO << "rs5.max(): " << rs5.max();
+
+ DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean());
+ DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean());
+ DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean());
+ DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean());
+ DLIB_TEST_MSG(rs5.mean() < 1e-3, rs5.mean());
+
+ DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max());
+ DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max());
+ DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max());
+ DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max());
+ DLIB_TEST_MSG(rs5.max() < 0.1, rs5.max());
+
+ }
+
+
+
+
+ class rls_tester : public tester
+ {
+ public:
+ rls_tester (
+ ) :
+ tester ("test_rls",
+ "Runs tests on the rls component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_rls();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/sammon.cpp b/ml/dlib/dlib/test/sammon.cpp
new file mode 100644
index 000000000..5328bd1f6
--- /dev/null
+++ b/ml/dlib/dlib/test/sammon.cpp
@@ -0,0 +1,211 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <cmath>
+#include <dlib/statistics.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.sammon");
+
+
+ std::vector<matrix<double,4,1> > make_test_data4(
+ )
+ {
+ std::vector<matrix<double,4,1> > data;
+
+ matrix<double,4,1> m;
+
+ m = 0,0,0, 0; data.push_back(m);
+ m = 1,0,0, 0; data.push_back(m);
+ m = 0,1,0, 0; data.push_back(m);
+ m = 0,0,1, 0; data.push_back(m);
+
+ return data;
+ }
+
+ std::vector<matrix<double,3,1> > make_test_data3(
+ )
+ {
+ std::vector<matrix<double,3,1> > data;
+
+ matrix<double,3,1> m;
+
+ m = 0,0,0; data.push_back(m);
+ m = 1,0,0; data.push_back(m);
+ m = 0,1,0; data.push_back(m);
+ m = 0,0,1; data.push_back(m);
+
+ return data;
+ }
+
+ std::vector<matrix<double> > make_test_data3d(
+ )
+ {
+ std::vector<matrix<double> > data;
+
+ matrix<double,3,1> m;
+
+ m = 0,0,0; data.push_back(m);
+ m = 1,0,0; data.push_back(m);
+ m = 0,1,0; data.push_back(m);
+ m = 0,0,1; data.push_back(m);
+
+ return data;
+ }
+
+
+ void runtest()
+ {
+ sammon_projection s;
+ std::vector<matrix<double, 0, 1> > projs = s(make_test_data3(),2);
+ running_stats<double> rs1, rs2;
+
+ rs1.add(length(projs[0] - projs[1]));
+ rs1.add(length(projs[0] - projs[2]));
+ rs1.add(length(projs[0] - projs[3]));
+
+ rs2.add(length(projs[1] - projs[2]));
+ rs2.add(length(projs[2] - projs[3]));
+ rs2.add(length(projs[3] - projs[1]));
+
+ DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4);
+ DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4);
+
+
+
+ projs = s(make_test_data4(),2);
+ rs1.clear();
+ rs2.clear();
+
+ rs1.add(length(projs[0] - projs[1]));
+ rs1.add(length(projs[0] - projs[2]));
+ rs1.add(length(projs[0] - projs[3]));
+
+ rs2.add(length(projs[1] - projs[2]));
+ rs2.add(length(projs[2] - projs[3]));
+ rs2.add(length(projs[3] - projs[1]));
+
+ DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4);
+ DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4);
+
+ projs = s(make_test_data3d(),2);
+ rs1.clear();
+ rs2.clear();
+
+ rs1.add(length(projs[0] - projs[1]));
+ rs1.add(length(projs[0] - projs[2]));
+ rs1.add(length(projs[0] - projs[3]));
+
+ rs2.add(length(projs[1] - projs[2]));
+ rs2.add(length(projs[2] - projs[3]));
+ rs2.add(length(projs[3] - projs[1]));
+
+ DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4);
+ DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4);
+ }
+
+ void runtest2()
+ {
+ sammon_projection s;
+ std::vector<matrix<double, 0, 1> > projs, temp;
+
+ DLIB_TEST(s(projs,3).size() == 0);
+
+ matrix<double,2,1> m;
+ m = 1,2;
+ projs.push_back(m);
+ temp = s(projs,2);
+ DLIB_TEST(temp.size() == 1);
+ DLIB_TEST(temp[0].size() == 2);
+
+ projs.push_back(m);
+ temp = s(projs,1);
+ DLIB_TEST(temp.size() == 2);
+ DLIB_TEST(temp[0].size() == 1);
+ DLIB_TEST(temp[1].size() == 1);
+ }
+
+ void runtest3(int num_dims)
+ {
+ sammon_projection s;
+ std::vector<matrix<double, 0, 1> > projs;
+ matrix<double,3,1> m;
+ m = 1, 1, 1;
+ projs.push_back(m);
+
+ m = 1, 2, 1;
+ projs.push_back(m);
+
+ m = 1, 3, 1;
+ projs.push_back(m);
+
+ projs = s(projs,num_dims);
+
+ const double d1a = length(projs[0] - projs[1]);
+ const double d1b = length(projs[1] - projs[2]);
+ const double d2 = length(projs[0] - projs[2]);
+
+ DLIB_TEST(std::abs(d1a-d1b)/d1a < 1e-8);
+ DLIB_TEST(std::abs(d2/d1a-2) < 1e-8);
+ }
+
+ void runtest4(int num_dims)
+ {
+ sammon_projection s;
+ std::vector<matrix<double, 0, 1> > projs;
+ matrix<double,3,1> m;
+ m = 1, 1, 1;
+ projs.push_back(m);
+
+ m = 1, 2, 1;
+ projs.push_back(m);
+
+
+ projs = s(projs,num_dims);
+
+ DLIB_TEST(length(projs[0] - projs[1]) > 1e-5);
+ }
+
+ class sammon_tester : public tester
+ {
+ public:
+ sammon_tester (
+ ) :
+ tester ("test_sammon",
+ "Runs tests on the sammon_projection component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ runtest();
+ print_spinner();
+ runtest2();
+ print_spinner();
+ runtest3(2);
+ print_spinner();
+ runtest4(2);
+ runtest3(1);
+ print_spinner();
+ runtest4(1);
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/scan_image.cpp b/ml/dlib/dlib/test/scan_image.cpp
new file mode 100644
index 000000000..c3a0115e3
--- /dev/null
+++ b/ml/dlib/dlib/test/scan_image.cpp
@@ -0,0 +1,713 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include "dlib/image_processing.h"
+
+#include "dlib/test/tester.h"
+
+#include "dlib/image_transforms.h"
+#include "dlib/pixel.h"
+#include "dlib/array2d.h"
+#include "dlib/array.h"
+
+// ----------------------------------------------------------------------------------------
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ using dlib::array;
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.scan_image");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename image_type1, typename image_type2>
+ void sum_filter_i (
+ const image_type1& img,
+ image_type2& out,
+ const rectangle& rect
+ )
+ {
+ typedef typename image_type1::type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+ integral_image_generic<ptype> iimg;
+ iimg.load(img);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ const rectangle temp = translate_rect(rect, point(c,r)).intersect(get_rect(iimg));
+ if (temp.is_empty() == false)
+ out[r][c] += iimg.get_sum_of_area(temp);
+ }
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image_i (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const double thresh,
+ const unsigned long max_dets
+ )
+ {
+ typedef typename image_array_type::type::type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+ array<integral_image_generic<ptype> > iimg;
+ iimg.set_max_size(images.size());
+ iimg.set_size(images.size());
+
+ for (unsigned long i = 0; i < iimg.size(); ++i)
+ iimg[i].load(images[i]);
+
+
+ dets.clear();
+
+
+ for (long r = 0; r < images[0].nr(); ++r)
+ {
+ for (long c = 0; c < images[0].nc(); ++c)
+ {
+ ptype temp = 0;
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ rectangle rtemp = translate_rect(rects[i].second,point(c,r)).intersect(get_rect(images[0]));
+ if (rtemp.is_empty() == false)
+ temp += iimg[rects[i].first].get_sum_of_area(rtemp);
+ }
+ if (temp > thresh)
+ {
+ dets.push_back(std::make_pair(temp, point(c,r)));
+
+ if (dets.size() >= max_dets)
+ return;
+
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_array_type
+ >
+ void scan_image_old (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const std::vector<std::pair<unsigned int, rectangle> >& rects,
+ const double thresh,
+ const unsigned long max_dets
+ )
+ {
+ dets.clear();
+ if (max_dets == 0)
+ return;
+
+ typedef typename image_array_type::type::type pixel_type;
+ typedef typename promote<pixel_type>::type ptype;
+
+ std::vector<std::vector<ptype> > column_sums(rects.size());
+ for (unsigned long i = 0; i < column_sums.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[rects[i].first];
+ column_sums[i].resize(img.nc() + rects[i].second.width(),0);
+
+ const long top = -1 + rects[i].second.top();
+ const long bottom = -1 + rects[i].second.bottom();
+ long left = rects[i].second.left()-1;
+
+ // initialize column_sums[i] at row -1
+ for (unsigned long j = 0; j < column_sums[i].size(); ++j)
+ {
+ rectangle strip(left,top,left,bottom);
+ strip = strip.intersect(get_rect(img));
+ if (!strip.is_empty())
+ {
+ column_sums[i][j] = sum(matrix_cast<ptype>(subm(mat(img),strip)));
+ }
+
+ ++left;
+ }
+ }
+
+
+ const rectangle area = get_rect(images[0]);
+
+ for (long r = 0; r < images[0].nr(); ++r)
+ {
+ // set to sum at point(-1,r). i.e. should be equal to sum_of_rects_in_images(images, rects, point(-1,r))
+ // We compute it's value in the next loop.
+ ptype cur_sum = 0;
+
+ // Update the first part of column_sums since we only work on the c+width part of column_sums
+ // in the main loop.
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[rects[i].first];
+ const long top = r + rects[i].second.top() - 1;
+ const long bottom = r + rects[i].second.bottom();
+ const long width = rects[i].second.width();
+ for (long k = 0; k < width; ++k)
+ {
+ const long right = k-width + rects[i].second.right();
+
+ const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0;
+ const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0;
+ // update the sum in this column now that we are on the next row
+ column_sums[i][k] = column_sums[i][k] + br_corner - tr_corner;
+ cur_sum += column_sums[i][k];
+ }
+ }
+
+ for (long c = 0; c < images[0].nc(); ++c)
+ {
+ for (unsigned long i = 0; i < rects.size(); ++i)
+ {
+ const typename image_array_type::type& img = images[rects[i].first];
+ const long top = r + rects[i].second.top() - 1;
+ const long bottom = r + rects[i].second.bottom();
+ const long right = c + rects[i].second.right();
+ const long width = rects[i].second.width();
+
+ const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0;
+ const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0;
+ // update the sum in this column now that we are on the next row
+ column_sums[i][c+width] = column_sums[i][c+width] + br_corner - tr_corner;
+
+
+ // add in the new right side of the rect and subtract the old right side.
+ cur_sum = cur_sum + column_sums[i][c+width] - column_sums[i][c];
+
+ }
+
+ if (cur_sum > thresh)
+ {
+ dets.push_back(std::make_pair(cur_sum, point(c,r)));
+
+ if (dets.size() >= max_dets)
+ return;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test1()
+ {
+ dlog << LINFO << "run_test1()";
+
+ print_spinner();
+ array2d<unsigned char> img, temp_img;
+ img.set_size(600,600);
+ assign_all_pixels(img,0);
+ rectangle rect = centered_rect(10,10,5,5);
+ dlog << LTRACE << "expected: 10,10";
+ fill_rect(img, rect, 255);
+
+
+ array<array2d<unsigned char> > images;
+ std::vector<std::pair<unsigned int, rectangle> > rects;
+ for (int i = 0; i < 10; ++i)
+ {
+ assign_image(temp_img, img);
+ images.push_back(temp_img);
+ rects.push_back(make_pair(i,centered_rect(0,0,5,5)));
+ }
+
+ std::vector<std::pair<double, point> > dets, dets2, dets3;
+
+
+ dlog << LTRACE << "best score: "<< sum_of_rects_in_images(images,rects,point(10,10));
+ scan_image(dets,images,rects,30000, 100);
+ scan_image_i(dets2,images,rects,30000, 100);
+ scan_image_old(dets3,images,rects,30000, 100);
+
+
+
+ dlog << LTRACE << "dets.size(): "<< dets.size();
+ dlog << LTRACE << "dets2.size(): "<< dets2.size();
+ dlog << LTRACE << "dets3.size(): "<< dets3.size();
+
+ DLIB_TEST(dets.size() == dets2.size());
+ DLIB_TEST(dets.size() == dets3.size());
+
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ //dlog << LTRACE << "dets["<<i<<"]: " << dets[i].second << " -> " << dets[i].first;
+ //dlog << LTRACE << "dets2["<<i<<"]: " << dets2[i].second << " -> " << dets2[i].first;
+ //dlog << LTRACE << "dets3["<<i<<"]: " << dets3[i].second << " -> " << dets3[i].first;
+
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets[i].second) == dets[i].first);
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets2[i].second) == dets2[i].first);
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets3[i].second) == dets3[i].first);
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void run_test2()
+ {
+ print_spinner();
+ dlog << LINFO << "run_test2()";
+ array2d<unsigned char> img, temp_img;
+ img.set_size(600,600);
+ assign_all_pixels(img,0);
+ rectangle rect = centered_rect(10,11,5,6);
+ dlog << LTRACE << "expected: 10,11";
+ fill_rect(img, rect, 255);
+
+
+ array<array2d<unsigned char> > images;
+ std::vector<std::pair<unsigned int, rectangle> > rects;
+ for (int i = 0; i < 10; ++i)
+ {
+ assign_image(temp_img, img);
+ images.push_back(temp_img);
+ rects.push_back(make_pair(i,centered_rect(0,0,5,5)));
+ rects.push_back(make_pair(i,centered_rect(3,2,5,6)));
+ }
+
+ std::vector<std::pair<double, point> > dets, dets2, dets3;
+
+
+ scan_image(dets,images,rects,30000, 100);
+ scan_image_i(dets2,images,rects,30000, 100);
+ scan_image_old(dets3,images,rects,30000, 100);
+
+
+
+ dlog << LTRACE << "dets.size(): "<< dets.size();
+ dlog << LTRACE << "dets2.size(): "<< dets2.size();
+ dlog << LTRACE << "dets3.size(): "<< dets3.size();
+
+ DLIB_TEST(dets.size() == dets2.size());
+ DLIB_TEST(dets.size() == dets3.size());
+
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ //dlog << LTRACE << "dets["<<i<<"]: " << dets[i].second << " -> " << dets[i].first;
+ //dlog << LTRACE << "dets2["<<i<<"]: " << dets2[i].second << " -> " << dets2[i].first;
+ //dlog << LTRACE << "dets3["<<i<<"]: " << dets3[i].second << " -> " << dets3[i].first;
+
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets[i].second) == dets[i].first);
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets2[i].second) == dets2[i].first);
+ DLIB_TEST(sum_of_rects_in_images(images, rects, dets3[i].second) == dets3[i].first);
+ }
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void run_test3(const double thresh)
+ {
+ dlog << LINFO << "running run_test3("<<thresh<<")";
+ dlib::rand rnd;
+
+ rnd.set_seed("235");
+
+ array<array2d<pixel_type> > images;
+ images.resize(1);
+ images[0].set_size(200,180);
+
+ for (int iter = 0; iter < 50; ++iter)
+ {
+ print_spinner();
+ assign_all_pixels(images[0], thresh - 0.0001);
+
+ for (int i = 0; i < 20; ++i)
+ {
+ point p1(rnd.get_random_32bit_number()%images[0].nc(),
+ rnd.get_random_32bit_number()%images[0].nr());
+ point p2(rnd.get_random_32bit_number()%images[0].nc(),
+ rnd.get_random_32bit_number()%images[0].nr());
+
+ rectangle rect(p1,p2);
+ fill_rect(images[0], rect, static_cast<pixel_type>(rnd.get_random_double()*10 - 5));
+ }
+
+ std::vector<std::pair<unsigned int, rectangle> > rects;
+ rects.push_back(make_pair(0,centered_rect(0,0,1+rnd.get_random_32bit_number()%40,1+rnd.get_random_32bit_number()%40)));
+ rects.push_back(make_pair(0,centered_rect(0,0,1+rnd.get_random_32bit_number()%40,1+rnd.get_random_32bit_number()%40)));
+
+
+
+
+ std::vector<std::pair<double, point> > dets, dets2, dets3;
+ scan_image(dets,images,rects,thresh, 100);
+ scan_image_i(dets2,images,rects,thresh, 100);
+ scan_image_old(dets3,images,rects,thresh, 100);
+
+ dlog << LTRACE << "dets.size(): "<< dets.size();
+ dlog << LTRACE << "dets2.size(): "<< dets2.size();
+ dlog << LTRACE << "dets3.size(): "<< dets3.size();
+
+ DLIB_TEST(dets.size() == dets2.size());
+ DLIB_TEST(dets.size() == dets3.size());
+
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ //dlog << LTRACE << "dets["<<i<<"]: " << dets[i].second << " -> " << dets[i].first;
+ //dlog << LTRACE << "dets2["<<i<<"]: " << dets2[i].second << " -> " << dets2[i].first;
+ //dlog << LTRACE << "dets3["<<i<<"]: " << dets3[i].second << " -> " << dets3[i].first;
+
+ DLIB_TEST_MSG(std::abs(sum_of_rects_in_images(images, rects, dets[i].second) - dets[i].first) < 1e-6,
+ "error: "<< sum_of_rects_in_images(images, rects, dets[i].second) - dets[i].first
+ << " dets["<<i<<"].second: " << dets[i].second
+ );
+ DLIB_TEST_MSG(std::abs(sum_of_rects_in_images(images, rects, dets2[i].second) - dets2[i].first) < 1e-6,
+ sum_of_rects_in_images(images, rects, dets2[i].second) - dets2[i].first
+ );
+ DLIB_TEST_MSG(std::abs(sum_of_rects_in_images(images, rects, dets3[i].second) - dets3[i].first) < 1e-6,
+ "error: "<< sum_of_rects_in_images(images, rects, dets3[i].second) - dets3[i].first
+ << " dets3["<<i<<"].first: " << dets3[i].first
+ << " dets3["<<i<<"].second: " << dets3[i].second
+ );
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename pixel_type>
+ void test_sum_filter (
+ )
+ {
+ dlib::rand rnd;
+
+ for (int k = 0; k < 20; ++k)
+ {
+ print_spinner();
+
+ array2d<pixel_type> img(1 + rnd.get_random_32bit_number()%100,
+ 1 + rnd.get_random_32bit_number()%100);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = static_cast<pixel_type>(100*(rnd.get_random_double()-0.5));
+ }
+ }
+
+ array2d<long> test1(img.nr(), img.nc());
+ array2d<double> test2(img.nr(), img.nc());
+ array2d<long> test1_i(img.nr(), img.nc());
+ array2d<double> test2_i(img.nr(), img.nc());
+
+ assign_all_pixels(test1, 0);
+ assign_all_pixels(test2, 0);
+ assign_all_pixels(test1_i, 0);
+ assign_all_pixels(test2_i, 0);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ const long width = rnd.get_random_32bit_number()%10 + 1;
+ const long height = rnd.get_random_32bit_number()%10 + 1;
+ const point p(rnd.get_random_32bit_number()%img.nc(),
+ rnd.get_random_32bit_number()%img.nr());
+
+ const rectangle rect = centered_rect(p, width, height);
+ sum_filter(img, test1, rect);
+ sum_filter(img, test2, rect);
+ sum_filter(img, test1_i, rect);
+ sum_filter(img, test2_i, rect);
+
+ DLIB_TEST(mat(test1) == mat(test1_i));
+ DLIB_TEST(mat(test2) == mat(test2_i));
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type1,
+ typename image_type2
+ >
+ void naive_max_filter (
+ const image_type1& img,
+ image_type2& out,
+ const long width,
+ const long height,
+ typename image_type1::type thresh
+ )
+ {
+ const rectangle area = get_rect(img);
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ const rectangle win = centered_rect(point(c,r),width,height).intersect(area);
+ out[r][c] += std::max(dlib::max(subm(mat(img),win)), thresh);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_max_filter(long rows, long cols, long width, long height, dlib::rand& rnd)
+ {
+ array2d<int> img(rows, cols);
+ rectangle rect = centered_rect(0,0, width, height);
+
+ array2d<int> out(img.nr(),img.nc());
+ assign_all_pixels(out, 0);
+ array2d<int> out2(img.nr(),img.nc());
+ assign_all_pixels(out2, 0);
+
+ for (long r = 0; r < img.nr(); ++r)
+ {
+ for (long c = 0; c < img.nc(); ++c)
+ {
+ img[r][c] = rnd.get_random_32bit_number();
+ }
+ }
+
+ const int thresh = rnd.get_random_32bit_number();
+
+ naive_max_filter(img, out2, rect.width(), rect.height(), thresh);
+ max_filter(img, out, rect.width(), rect.height(), thresh);
+
+ DLIB_TEST_MSG(mat(out) == mat(out2),
+ "rows: "<< rows
+ << "\ncols: "<< rows
+ << "\nwidth: "<< width
+ << "\nheight: "<< height );
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_max_filter()
+ {
+ dlib::rand rnd;
+ for (int iter = 0; iter < 300; ++iter)
+ {
+ print_spinner();
+ test_max_filter(0,0,1,1,rnd);
+ test_max_filter(0,0,3,1,rnd);
+ test_max_filter(0,0,3,3,rnd);
+ test_max_filter(0,0,1,3,rnd);
+ test_max_filter(1,1,1,1,rnd);
+ test_max_filter(2,2,1,1,rnd);
+ test_max_filter(3,3,1,1,rnd);
+ test_max_filter(3,3,3,3,rnd);
+ test_max_filter(3,3,2,2,rnd);
+ test_max_filter(3,3,3,5,rnd);
+ test_max_filter(3,3,6,8,rnd);
+ test_max_filter(20,20,901,901,rnd);
+ test_max_filter(5,5,1,5,rnd);
+ test_max_filter(50,50,9,9,rnd);
+ test_max_filter(50,50,9,9,rnd);
+ test_max_filter(50,50,10,10,rnd);
+ test_max_filter(50,50,11,10,rnd);
+ test_max_filter(50,50,10,11,rnd);
+ test_max_filter(50,50,10,21,rnd);
+ test_max_filter(50,50,20,10,rnd);
+ test_max_filter(50,50,20,10,rnd);
+ test_max_filter(50,50,9,9,rnd);
+ test_max_filter(20,20,1,901,rnd);
+ test_max_filter(20,20,3,901,rnd);
+ test_max_filter(20,20,901,1,rnd);
+ }
+
+ for (int iter = 0; iter < 200; ++iter)
+ {
+ print_spinner();
+ test_max_filter((int)rnd.get_random_8bit_number()%100+1,
+ (int)rnd.get_random_8bit_number()%100+1,
+ (int)rnd.get_random_8bit_number()%150+1,
+ (int)rnd.get_random_8bit_number()%150+1,
+ rnd);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_images (
+ dlib::rand& rnd,
+ array<array2d<unsigned char> >& images,
+ long num,
+ long nr,
+ long nc
+ )
+ {
+ images.resize(num);
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ images[i].set_size(nr,nc);
+ }
+
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ for (long r = 0; r < nr; ++r)
+ {
+ for (long c = 0; c < nc; ++c)
+ {
+ images[i][r][c] = rnd.get_random_8bit_number();
+ }
+ }
+ }
+ }
+
+
+ template <
+ typename image_array_type
+ >
+ void brute_force_scan_image_movable_parts (
+ std::vector<std::pair<double, point> >& dets,
+ const image_array_type& images,
+ const rectangle& window,
+ const std::vector<std::pair<unsigned int, rectangle> >& fixed_rects,
+ const std::vector<std::pair<unsigned int, rectangle> >& movable_rects,
+ const double thresh,
+ const unsigned long
+ )
+ {
+ dets.clear();
+ if (movable_rects.size() == 0 && fixed_rects.size() == 0)
+ return;
+
+ for (long r = 0; r < images[0].nr(); ++r)
+ {
+ for (long c = 0; c < images[0].nc(); ++c)
+ {
+ const point p(c,r);
+ double score = sum_of_rects_in_images_movable_parts(images,
+ window,
+ fixed_rects,
+ movable_rects,
+ p);
+
+ if (score >= thresh)
+ {
+ dets.push_back(make_pair(score,p));
+ }
+ }
+ }
+ }
+
+ void test_scan_images_movable_parts()
+ {
+ array<array2d<unsigned char> > images;
+ dlib::rand rnd;
+ for (int iter = 0; iter < 40; ++iter)
+ {
+ print_spinner();
+ const int num_images = rnd.get_random_32bit_number()%4+1;
+
+ make_images(rnd,images, num_images,
+ rnd.get_random_32bit_number()%50+1,
+ rnd.get_random_32bit_number()%50+1
+ );
+
+ std::vector<std::pair<double,point> > dets1, dets2;
+ std::vector<std::pair<unsigned int, rectangle> > fixed_rects, movable_rects;
+
+ double total_area = 0;
+ for (unsigned long i = 0; i < images.size(); ++i)
+ {
+ fixed_rects.push_back(make_pair(i, centered_rect(
+ rnd.get_random_32bit_number()%10-5,
+ rnd.get_random_32bit_number()%10-5,
+ rnd.get_random_32bit_number()%10,
+ rnd.get_random_32bit_number()%10
+ )));
+
+ total_area += fixed_rects.back().second.area();
+
+ movable_rects.push_back(make_pair(i, centered_rect(
+ 0,
+ 0,
+ rnd.get_random_32bit_number()%10+1,
+ rnd.get_random_32bit_number()%10+1
+ )));
+ total_area += movable_rects.back().second.area();
+ }
+
+ const rectangle window = centered_rect(0,0,
+ rnd.get_random_32bit_number()%15+1,
+ rnd.get_random_32bit_number()%15+1);
+ dlog << LINFO << "window size: "<< window.width() << ", " << window.height();
+ const double thresh = total_area*130;
+ const unsigned long max_dets = get_rect(images[0]).area();
+
+ scan_image_movable_parts(dets1,images,window,fixed_rects,movable_rects,thresh, max_dets);
+ brute_force_scan_image_movable_parts(dets2,images,window,fixed_rects,movable_rects,thresh, max_dets);
+
+ dlog << LINFO << "max_possible dets: " << max_dets;
+ dlog << LINFO << "regular dets: " << dets1.size();
+ dlog << LINFO << "brute force: " << dets2.size();
+ DLIB_TEST(dets1.size() == dets2.size());
+
+ array2d<double> check(images[0].nr(), images[0].nc());
+ assign_all_pixels(check, 1e-300);
+ for (unsigned long i = 0; i < dets1.size(); ++i)
+ {
+ const point p = dets1[i].second;
+ check[p.y()][p.x()] = dets1[i].first;
+ }
+ for (unsigned long i = 0; i < dets2.size(); ++i)
+ {
+ const point p = dets2[i].second;
+ DLIB_TEST(std::abs(check[p.y()][p.x()] - dets2[i].first) < 1e-10);
+ }
+ dlog << LINFO << "=======================\n";
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class scan_image_tester : public tester
+ {
+ public:
+ scan_image_tester (
+ ) :
+ tester ("test_scan_image",
+ "Runs tests on the scan_image routine.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_scan_images_movable_parts();
+ test_max_filter();
+
+ run_test1();
+ run_test2();
+ run_test3<unsigned char>(1);
+ run_test3<unsigned char>(-1);
+ run_test3<double>(1);
+ run_test3<double>(-1);
+
+ test_sum_filter<unsigned char>();
+ test_sum_filter<double>();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/sequence.cpp b/ml/dlib/dlib/test/sequence.cpp
new file mode 100644
index 000000000..ffa6efdb8
--- /dev/null
+++ b/ml/dlib/dlib/test/sequence.cpp
@@ -0,0 +1,312 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/sequence.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.sequence");
+
+ template <
+ typename seq
+ >
+ void sequence_sort_test (
+ )
+ /*!
+ requires
+ - seq is an implementation of sequence/sequence_sort_aseqract.h is instantiated
+ with int
+ ensures
+ - runs tests on seq for compliance with the specs
+ !*/
+ {
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+ print_spinner();
+
+
+
+
+
+ {
+ // this test is to make sure that jumping around via
+ // operator[] doesn't corrupt the object
+
+ seq a;
+
+ for (int i = 0; i < 100; ++i)
+ {
+ int x = i;
+ a.add(a.size(),x);
+ }
+
+
+ int x = 0;
+
+ for (int i = 0; i < (int)a.size(); ++i)
+ {
+ DLIB_TEST_MSG(a[i] >= i,"1");
+ // cout << a[i] << endl;
+ }
+
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ for (unsigned long j = i+1; j < a.size(); ++j)
+ {
+ if ((a[j]+a[i])%3 ==0)
+ {
+ a.remove(j,x);
+ --j;
+ }
+ }
+ }
+
+ //cout << endl;
+
+ for (int i = 0; i < (int)a.size(); ++i)
+ {
+ // cout << a[i] << endl;
+ DLIB_TEST_MSG(a[i] >= i,"2");
+ }
+
+ }
+
+
+
+
+
+
+
+ seq test, test2;
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ enumerable<int>& e = test;
+
+ DLIB_TEST(e.at_start() == true);
+ DLIB_TEST(e.current_element_valid() == false);
+
+
+ for (int g = 0; g < 5; ++g)
+ {
+ test.clear();
+ test2.clear();
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(e.at_start() == true);
+ DLIB_TEST(e.current_element_valid() == false);
+
+ DLIB_TEST(e.move_next() == false);
+ DLIB_TEST(e.current_element_valid() == false);
+ DLIB_TEST(e.at_start() == false);
+ DLIB_TEST(test.at_start() == false);
+ swap(test,test2);
+ DLIB_TEST(test.at_start() == true);
+ test.clear();
+ test2.clear();
+
+ int a;
+
+
+ for (int i = 0; i < 100; ++i)
+ {
+ a = i;
+ test.add(i,a);
+ }
+
+ DLIB_TEST(test.size() == 100);
+
+ for (int i = 0; i < static_cast<int>(test.size()); ++i)
+ {
+ DLIB_TEST(test[i] == i);
+ }
+
+ swap(test,test2);
+
+ a = 0;
+ DLIB_TEST(test2.at_start() == true);
+ while(test2.move_next())
+ {
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.element() == a);
+ ++a;
+ }
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ test2.reset();
+
+ DLIB_TEST(test2.at_start() == true);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+ a = 0;
+ while(test2.move_next())
+ {
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == true);
+ DLIB_TEST(test2.element() == a);
+ ++a;
+ }
+
+
+
+
+
+ for (int i = 0; i < 1000; ++i)
+ {
+ a = ::rand();
+ test.add(0,a);
+ }
+ DLIB_TEST(test.size() == 1000);
+
+ test.sort();
+
+
+ for (unsigned long i = 0; i < test.size()-1; ++i)
+ {
+ DLIB_TEST(test[i] <= test[i+1]);
+ }
+
+ a = 0;
+ while(test.move_next())
+ {
+ DLIB_TEST(a <= test.element());
+ a = test.element();
+ }
+
+
+ test.clear();
+ test2.clear();
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test2.size() == 0);
+
+ for (int i = 0; i < 100; ++i)
+ {
+ a = i;
+ test.add(i,a);
+ }
+
+ for (int i = 100; i < 200; ++i)
+ {
+ a = i;
+ test.add(i,a);
+ }
+
+ test.cat(test2);
+ DLIB_TEST(test.size() == 200);
+ DLIB_TEST(test2.size() == 0);
+
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ DLIB_TEST(test.at_start() == true);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+
+
+ for (int i = 0; i < 200; ++i)
+ {
+ DLIB_TEST(test[i] == i);
+ }
+
+ a = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == a);
+ DLIB_TEST(test[0]==0);
+ ++a;
+ }
+
+ DLIB_TEST(a == 200);
+
+ DLIB_TEST(test[9] == 9);
+ test.remove(9,a);
+ DLIB_TEST(a == 9);
+ DLIB_TEST(test[9] == 10);
+ DLIB_TEST(test.size() == 199);
+
+ test.remove(0,a);
+ DLIB_TEST(test[0] == 1);
+ DLIB_TEST(test.size() == 198);
+ DLIB_TEST(a == 0);
+ DLIB_TEST(test[9] == 11);
+ DLIB_TEST(test[20] == 22);
+
+
+
+
+ }
+
+ {
+ test.clear();
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 3;
+ test.add(0,a);
+ }
+ DLIB_TEST(test.size() == 100);
+ remover<int>& go = test;
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 9;
+ go.remove_any(a);
+ DLIB_TEST(a == 3);
+ }
+ DLIB_TEST(go.size() == 0);
+ }
+
+
+ }
+
+
+
+
+ class sequence_tester : public tester
+ {
+ public:
+ sequence_tester (
+ ) :
+ tester ("test_sequence",
+ "Runs tests on the sequence component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing sort_1a";
+ sequence_sort_test<sequence<int>::sort_1a> ();
+ dlog << LINFO << "testing sort_1a_c";
+ sequence_sort_test<sequence<int>::sort_1a_c>();
+ dlog << LINFO << "testing sort_2a";
+ sequence_sort_test<sequence<int>::sort_2a> ();
+ dlog << LINFO << "testing sort_2a_c";
+ sequence_sort_test<sequence<int>::sort_2a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/sequence_labeler.cpp b/ml/dlib/dlib/test/sequence_labeler.cpp
new file mode 100644
index 000000000..4002d6821
--- /dev/null
+++ b/ml/dlib/dlib/test/sequence_labeler.cpp
@@ -0,0 +1,461 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/rand.h>
+
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.sequence_labeler");
+
+// ----------------------------------------------------------------------------------------
+
+ const unsigned long num_label_states = 3; // the "hidden" states
+ const unsigned long num_sample_states = 3;
+
+// ----------------------------------------------------------------------------------------
+
+ struct funny_sequence
+ {
+ std::vector<unsigned long> item;
+ unsigned long size() const { return item.size(); }
+ };
+ funny_sequence make_funny_sequence(const std::vector<unsigned long>& item)
+ {
+ funny_sequence temp;
+ temp.item = item;
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class feature_extractor
+ {
+ public:
+ typedef funny_sequence sequence_type;
+
+ unsigned long num_features() const
+ {
+ return num_label_states*num_label_states + num_label_states*num_sample_states;
+ }
+
+ unsigned long order() const
+ {
+ return 1;
+ }
+
+ unsigned long num_labels() const
+ {
+ return num_label_states;
+ }
+
+ template <typename feature_setter, typename EXP>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const
+ {
+ if (y.size() > 1)
+ set_feature(y(1)*num_label_states + y(0));
+
+ set_feature(num_label_states*num_label_states +
+ y(0)*num_sample_states + x.item[position]);
+ }
+ };
+
+ class feature_extractor_partial
+ {
+ public:
+ typedef funny_sequence sequence_type;
+
+ unsigned long num_features() const
+ {
+ return num_label_states*num_label_states + num_label_states*num_sample_states;
+ }
+
+ unsigned long order() const
+ {
+ return 1;
+ }
+
+ unsigned long num_labels() const
+ {
+ return num_label_states;
+ }
+
+ template <typename feature_setter, typename EXP>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const
+ {
+ if (y.size() > 1)
+ {
+ set_feature(y(1)*num_label_states + y(0), 0.5);
+ set_feature(y(1)*num_label_states + y(0), 0.5);
+ }
+
+ set_feature(num_label_states*num_label_states +
+ y(0)*num_sample_states + x.item[position],0.25);
+ set_feature(num_label_states*num_label_states +
+ y(0)*num_sample_states + x.item[position],0.75);
+ }
+ };
+
+ bool called_rejct_labeling = false;
+ class feature_extractor2
+ {
+ public:
+ typedef funny_sequence sequence_type;
+
+ unsigned long num_features() const
+ {
+ return num_label_states*num_label_states + num_label_states*num_sample_states;
+ }
+
+ unsigned long order() const
+ {
+ return 1;
+ }
+
+ unsigned long num_labels() const
+ {
+ return num_label_states;
+ }
+
+ template <typename EXP>
+ bool reject_labeling (
+ const sequence_type& ,
+ const matrix_exp<EXP>& ,
+ unsigned long
+ ) const
+ {
+ called_rejct_labeling = true;
+ return false;
+ }
+
+ template <typename feature_setter, typename EXP>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ const matrix_exp<EXP>& y,
+ unsigned long position
+ ) const
+ {
+ if (y.size() > 1)
+ set_feature(y(1)*num_label_states + y(0));
+
+ set_feature(num_label_states*num_label_states +
+ y(0)*num_sample_states + x.item[position]);
+ }
+ };
+
+ void serialize(const feature_extractor&, std::ostream&) {}
+ void deserialize(feature_extractor&, std::istream&) {}
+ void serialize(const feature_extractor2&, std::ostream&) {}
+ void deserialize(feature_extractor2&, std::istream&) {}
+
+// ----------------------------------------------------------------------------------------
+
+ void sample_hmm (
+ dlib::rand& rnd,
+ const matrix<double>& transition_probabilities,
+ const matrix<double>& emission_probabilities,
+ unsigned long previous_label,
+ unsigned long& next_label,
+ unsigned long& next_sample
+ )
+ /*!
+ requires
+ - previous_label < transition_probabilities.nr()
+ - transition_probabilities.nr() == transition_probabilities.nc()
+ - transition_probabilities.nr() == emission_probabilities.nr()
+ - The rows of transition_probabilities and emission_probabilities must sum to 1.
+ (i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
+ must evaluate to vectors of all 1s.)
+ ensures
+ - This function randomly samples the HMM defined by transition_probabilities
+ and emission_probabilities assuming that the previous hidden state
+ was previous_label.
+ - The HMM is defined by:
+ - P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
+ - P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
+ - #next_label == the sampled value of the hidden state
+ - #next_sample == the sampled value of the observed state
+ !*/
+ {
+ // sample next_label
+ double p = rnd.get_random_double();
+ for (long c = 0; p >= 0 && c < transition_probabilities.nc(); ++c)
+ {
+ next_label = c;
+ p -= transition_probabilities(previous_label, c);
+ }
+
+ // now sample next_sample
+ p = rnd.get_random_double();
+ for (long c = 0; p >= 0 && c < emission_probabilities.nc(); ++c)
+ {
+ next_sample = c;
+ p -= emission_probabilities(next_label, c);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_dataset (
+ const matrix<double>& transition_probabilities,
+ const matrix<double>& emission_probabilities,
+ std::vector<funny_sequence>& samples,
+ std::vector<std::vector<unsigned long> >& labels,
+ unsigned long dataset_size
+ )
+ /*!
+ requires
+ - transition_probabilities.nr() == transition_probabilities.nc()
+ - transition_probabilities.nr() == emission_probabilities.nr()
+ - The rows of transition_probabilities and emission_probabilities must sum to 1.
+ (i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
+ must evaluate to vectors of all 1s.)
+ ensures
+ - This function randomly samples a bunch of sequences from the HMM defined by
+ transition_probabilities and emission_probabilities.
+ - The HMM is defined by:
+ - The probability of transitioning from hidden state H1 to H2
+ is given by transition_probabilities(H1,H2).
+ - The probability of a hidden state H producing an observed state
+ O is given by emission_probabilities(H,O).
+ - #samples.size() == labels.size() == dataset_size
+ - for all valid i:
+ - #labels[i] is a randomly sampled sequence of hidden states from the
+ given HMM. #samples[i] is its corresponding randomly sampled sequence
+ of observed states.
+ !*/
+ {
+ samples.clear();
+ labels.clear();
+
+ dlib::rand rnd;
+
+ // now randomly sample some labeled sequences from our Hidden Markov Model
+ for (unsigned long iter = 0; iter < dataset_size; ++iter)
+ {
+ const unsigned long sequence_size = rnd.get_random_32bit_number()%20+3;
+ std::vector<unsigned long> sample(sequence_size);
+ std::vector<unsigned long> label(sequence_size);
+
+ unsigned long previous_label = rnd.get_random_32bit_number()%num_label_states;
+ for (unsigned long i = 0; i < sample.size(); ++i)
+ {
+ unsigned long next_label=0, next_sample=0;
+ sample_hmm(rnd, transition_probabilities, emission_probabilities,
+ previous_label, next_label, next_sample);
+
+ label[i] = next_label;
+ sample[i] = next_sample;
+
+ previous_label = next_label;
+ }
+
+ samples.push_back(make_funny_sequence(sample));
+ labels.push_back(label);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename fe_type>
+ void do_test()
+ {
+ called_rejct_labeling = false;
+
+ matrix<double> transition_probabilities(num_label_states, num_label_states);
+ transition_probabilities = 0.05, 0.90, 0.05,
+ 0.05, 0.05, 0.90,
+ 0.90, 0.05, 0.05;
+
+ matrix<double> emission_probabilities(num_label_states,num_sample_states);
+ emission_probabilities = 0.5, 0.5, 0.0,
+ 0.0, 0.5, 0.5,
+ 0.5, 0.0, 0.5;
+
+ print_spinner();
+
+
+ std::vector<funny_sequence> samples;
+ std::vector<std::vector<unsigned long> > labels;
+ make_dataset(transition_probabilities,emission_probabilities,
+ samples, labels, 1000);
+
+ dlog << LINFO << "samples.size(): "<< samples.size();
+
+ // print out some of the randomly sampled sequences
+ for (int i = 0; i < 10; ++i)
+ {
+ dlog << LINFO << "hidden states: " << trans(mat(labels[i]));
+ dlog << LINFO << "observed states: " << trans(mat(samples[i].item));
+ dlog << LINFO << "******************************";
+ }
+
+ print_spinner();
+ structural_sequence_labeling_trainer<fe_type> trainer;
+ trainer.set_c(4);
+ DLIB_TEST(trainer.get_c() == 4);
+ trainer.set_num_threads(4);
+ DLIB_TEST(trainer.get_num_threads() == 4);
+
+
+
+ // Learn to do sequence labeling from the dataset
+ sequence_labeler<fe_type> labeler = trainer.train(samples, labels);
+
+ std::vector<unsigned long> predicted_labels = labeler(samples[0]);
+ dlog << LINFO << "true hidden states: "<< trans(mat(labels[0]));
+ dlog << LINFO << "predicted hidden states: "<< trans(mat(predicted_labels));
+
+ DLIB_TEST(mat(labels[0]) == mat(predicted_labels));
+
+
+ print_spinner();
+
+
+ // We can also do cross-validation
+ matrix<double> confusion_matrix;
+ confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
+ dlog << LINFO << "cross-validation: ";
+ dlog << LINFO << confusion_matrix;
+ double accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
+ dlog << LINFO << "label accuracy: "<< accuracy;
+ DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
+
+ print_spinner();
+
+
+ matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
+ reshape_to_column_vector(emission_probabilities)));
+
+ sequence_labeler<fe_type> labeler_true(true_hmm_model_weights);
+
+ confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
+ dlog << LINFO << "True HMM model: ";
+ dlog << LINFO << confusion_matrix;
+ accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
+ dlog << LINFO << "label accuracy: "<< accuracy;
+ DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
+
+
+
+ print_spinner();
+
+
+
+
+ // Finally, the labeler can be serialized to disk just like most dlib objects.
+ ostringstream sout;
+ serialize(labeler, sout);
+
+ sequence_labeler<fe_type> labeler2;
+ // recall from disk
+ istringstream sin(sout.str());
+ deserialize(labeler2, sin);
+ confusion_matrix = test_sequence_labeler(labeler2, samples, labels);
+ dlog << LINFO << "deserialized labeler: ";
+ dlog << LINFO << confusion_matrix;
+ accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
+ dlog << LINFO << "label accuracy: "<< accuracy;
+ DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test2()
+ {
+ /*
+ The point of this test is to make sure calling set_feature() multiple
+ times works the way it is supposed to.
+ */
+
+ print_spinner();
+ std::vector<funny_sequence> samples;
+ std::vector<std::vector<unsigned long> > labels;
+
+ matrix<double> transition_probabilities(num_label_states, num_label_states);
+ transition_probabilities = 0.05, 0.90, 0.05,
+ 0.05, 0.05, 0.90,
+ 0.90, 0.05, 0.05;
+
+ matrix<double> emission_probabilities(num_label_states,num_sample_states);
+ emission_probabilities = 0.5, 0.5, 0.0,
+ 0.0, 0.5, 0.5,
+ 0.5, 0.0, 0.5;
+
+
+ make_dataset(transition_probabilities,emission_probabilities,
+ samples, labels, 1000);
+
+ dlog << LINFO << "samples.size(): "<< samples.size();
+
+ structural_sequence_labeling_trainer<feature_extractor> trainer;
+ structural_sequence_labeling_trainer<feature_extractor_partial> trainer_part;
+ trainer.set_c(4);
+ trainer_part.set_c(4);
+ trainer.set_num_threads(4);
+ trainer_part.set_num_threads(4);
+ trainer.set_epsilon(1e-8);
+ trainer_part.set_epsilon(1e-8);
+
+
+
+ // Learn to do sequence labeling from the dataset
+ sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
+ sequence_labeler<feature_extractor_partial> labeler_part = trainer_part.train(samples, labels);
+
+ dlog << LINFO << "weight disagreement: "<< max(abs(labeler.get_weights() - labeler_part.get_weights()));
+ dlog << LINFO << "max weight magnitude: "<< max(abs(labeler.get_weights()));
+
+ // Both feature extractors should be equivalent.
+ DLIB_TEST(max(abs(labeler.get_weights() - labeler_part.get_weights())) < 1e-6);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class sequence_labeler_tester : public tester
+ {
+ public:
+ sequence_labeler_tester (
+ ) :
+ tester ("test_sequence_labeler",
+ "Runs tests on the sequence labeling code.")
+ {}
+
+ void perform_test (
+ )
+ {
+ do_test<feature_extractor>();
+ DLIB_TEST(called_rejct_labeling == false);
+ do_test<feature_extractor2>();
+ DLIB_TEST(called_rejct_labeling == true);
+
+ test2();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/sequence_segmenter.cpp b/ml/dlib/dlib/test/sequence_segmenter.cpp
new file mode 100644
index 000000000..acdcd69be
--- /dev/null
+++ b/ml/dlib/dlib/test/sequence_segmenter.cpp
@@ -0,0 +1,294 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/rand.h>
+
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.sequence_segmenter");
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+
+ template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_>
+ class unigram_extractor
+ {
+ public:
+
+ const static bool use_BIO_model = use_BIO_model_;
+ const static bool use_high_order_features = use_high_order_features_;
+ const static bool allow_negative_weights = allow_negative_weights_;
+
+ typedef std::vector<unsigned long> sequence_type;
+
+ std::map<unsigned long, matrix<double,0,1> > feats;
+
+ unigram_extractor()
+ {
+ matrix<double,0,1> v1, v2, v3;
+ v1 = randm(num_features(), 1, rnd);
+ v2 = randm(num_features(), 1, rnd);
+ v3 = randm(num_features(), 1, rnd);
+ v1(0) = 1;
+ v2(1) = 1;
+ v3(2) = 1;
+ v1(3) = -1;
+ v2(4) = -1;
+ v3(5) = -1;
+ for (unsigned long i = 0; i < num_features(); ++i)
+ {
+ if ( i < 3)
+ feats[i] = v1;
+ else if (i < 6)
+ feats[i] = v2;
+ else
+ feats[i] = v3;
+ }
+ }
+
+ unsigned long num_features() const { return 10; }
+ unsigned long window_size() const { return 3; }
+
+ template <typename feature_setter>
+ void get_features (
+ feature_setter& set_feature,
+ const sequence_type& x,
+ unsigned long position
+ ) const
+ {
+ const matrix<double,0,1>& m = feats.find(x[position])->second;
+ for (unsigned long i = 0; i < num_features(); ++i)
+ {
+ set_feature(i, m(i));
+ }
+ }
+
+ };
+
+ template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
+ void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item , std::ostream& out )
+ {
+ serialize(item.feats, out);
+ }
+
+ template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
+ void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item, std::istream& in)
+ {
+ deserialize(item.feats, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_dataset (
+ std::vector<std::vector<unsigned long> >& samples,
+ std::vector<std::vector<unsigned long> >& labels,
+ unsigned long dataset_size
+ )
+ {
+ samples.clear();
+ labels.clear();
+
+ samples.resize(dataset_size);
+ labels.resize(dataset_size);
+
+
+ unigram_extractor<true,true,true> fe;
+ dlib::rand rnd;
+
+ for (unsigned long iter = 0; iter < dataset_size; ++iter)
+ {
+
+ samples[iter].resize(10);
+ labels[iter].resize(10);
+
+ for (unsigned long i = 0; i < samples[iter].size(); ++i)
+ {
+ samples[iter][i] = rnd.get_random_32bit_number()%fe.num_features();
+ if (samples[iter][i] < 3)
+ {
+ labels[iter][i] = impl_ss::BEGIN;
+ }
+ else if (samples[iter][i] < 6)
+ {
+ labels[iter][i] = impl_ss::INSIDE;
+ }
+ else
+ {
+ labels[iter][i] = impl_ss::OUTSIDE;
+ }
+
+ if (i != 0)
+ {
+ // do rejection sampling to avoid impossible labels
+ if (labels[iter][i] == impl_ss::INSIDE &&
+ labels[iter][i-1] == impl_ss::OUTSIDE)
+ {
+ --i;
+ }
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void make_dataset2 (
+ std::vector<std::vector<unsigned long> >& samples,
+ std::vector<std::vector<std::pair<unsigned long, unsigned long> > >& segments,
+ unsigned long dataset_size
+ )
+ {
+ segments.clear();
+ std::vector<std::vector<unsigned long> > labels;
+ make_dataset(samples, labels, dataset_size);
+ segments.resize(samples.size());
+
+ // Convert from BIO tagging to the explicit segments representation.
+ for (unsigned long k = 0; k < labels.size(); ++k)
+ {
+ for (unsigned long i = 0; i < labels[k].size(); ++i)
+ {
+ if (labels[k][i] == impl_ss::BEGIN)
+ {
+ const unsigned long begin = i;
+ ++i;
+ while (i < labels[k].size() && labels[k][i] == impl_ss::INSIDE)
+ ++i;
+
+ segments[k].push_back(std::make_pair(begin, i));
+ --i;
+ }
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights>
+ void do_test()
+ {
+ dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
+ dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
+ dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights;
+
+ std::vector<std::vector<unsigned long> > samples;
+ std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments;
+ make_dataset2( samples, segments, 100);
+
+ print_spinner();
+ typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type;
+
+ fe_type fe_temp;
+ fe_type fe_temp2;
+ structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2);
+ trainer.set_c(5);
+ trainer.set_num_threads(1);
+
+
+ sequence_segmenter<fe_type> labeler = trainer.train(samples, segments);
+
+ print_spinner();
+
+ const std::vector<std::pair<unsigned long, unsigned long> > predicted_labels = labeler(samples[1]);
+ const std::vector<std::pair<unsigned long, unsigned long> > true_labels = segments[1];
+ /*
+ for (unsigned long i = 0; i < predicted_labels.size(); ++i)
+ cout << "["<<predicted_labels[i].first<<","<<predicted_labels[i].second<<") ";
+ cout << endl;
+ for (unsigned long i = 0; i < true_labels.size(); ++i)
+ cout << "["<<true_labels[i].first<<","<<true_labels[i].second<<") ";
+ cout << endl;
+ */
+
+ DLIB_TEST(predicted_labels.size() > 0);
+ DLIB_TEST(predicted_labels.size() == true_labels.size());
+ for (unsigned long i = 0; i < predicted_labels.size(); ++i)
+ {
+ DLIB_TEST(predicted_labels[i].first == true_labels[i].first);
+ DLIB_TEST(predicted_labels[i].second == true_labels[i].second);
+ }
+
+
+ matrix<double> res;
+
+ res = cross_validate_sequence_segmenter(trainer, samples, segments, 3);
+ dlog << LINFO << "cv res: "<< res;
+ DLIB_TEST(min(res) > 0.98);
+ make_dataset2( samples, segments, 100);
+ res = test_sequence_segmenter(labeler, samples, segments);
+ dlog << LINFO << "test res: "<< res;
+ DLIB_TEST(min(res) > 0.98);
+
+ print_spinner();
+
+ ostringstream sout;
+ serialize(labeler, sout);
+ istringstream sin(sout.str());
+ sequence_segmenter<fe_type> labeler2;
+ deserialize(labeler2, sin);
+
+ res = test_sequence_segmenter(labeler2, samples, segments);
+ dlog << LINFO << "test res2: "<< res;
+ DLIB_TEST(min(res) > 0.98);
+
+ long N;
+ if (use_BIO_model)
+ N = 3*3+3;
+ else
+ N = 5*5+5;
+ const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N));
+ const double min_trans_weight = min(labeler2.get_weights());
+ dlog << LINFO << "min_normal_weight: " << min_normal_weight;
+ dlog << LINFO << "min_trans_weight: " << min_trans_weight;
+ if (allow_negative_weights)
+ {
+ DLIB_TEST(min_normal_weight < 0);
+ DLIB_TEST(min_trans_weight < 0);
+ }
+ else
+ {
+ DLIB_TEST(min_normal_weight == 0);
+ DLIB_TEST(min_trans_weight < 0);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ class unit_test_sequence_segmenter : public tester
+ {
+ public:
+ unit_test_sequence_segmenter (
+ ) :
+ tester ("test_sequence_segmenter",
+ "Runs tests on the sequence segmenting code.")
+ {}
+
+ void perform_test (
+ )
+ {
+ do_test<true,true,false>();
+ do_test<true,false,false>();
+ do_test<false,true,false>();
+ do_test<false,false,false>();
+ do_test<true,true,true>();
+ do_test<true,false,true>();
+ do_test<false,true,true>();
+ do_test<false,false,true>();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/serialize.cpp b/ml/dlib/dlib/test/serialize.cpp
new file mode 100644
index 000000000..f8b3384b9
--- /dev/null
+++ b/ml/dlib/dlib/test/serialize.cpp
@@ -0,0 +1,1087 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <dlib/compress_stream.h>
+#include <dlib/base64.h>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/serialize.h>
+#include <dlib/image_transforms.h>
+
+#include "tester.h"
+
+namespace dlib
+{
+ static bool operator!=(const rgb_pixel& a, const rgb_pixel& b)
+ {
+ return !(a.red==b.red && a.green==b.green && a.blue==b.blue);
+ }
+ static bool operator!=(const bgr_pixel& a, const bgr_pixel& b)
+ {
+ return !(a.red==b.red && a.green==b.green && a.blue==b.blue);
+ }
+
+ static bool operator!=(const hsi_pixel& a, const hsi_pixel& b)
+ {
+ return !(a.h==b.h && a.s==b.s && a.i==b.i);
+ }
+ static bool operator!=(const rgb_alpha_pixel& a, const rgb_alpha_pixel& b)
+ {
+ return !(a.red==b.red && a.green==b.green && a.blue==b.blue && a.alpha==b.alpha);
+ }
+
+}
+
+namespace
+{
+
+// ----------------------------------------------------------------------------------------
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ struct test_object
+ {
+ signed char i1;
+ signed short i2;
+ signed long i3;
+ unsigned char i4;
+ unsigned short i5;
+ unsigned long i6;
+ uint64 i7;
+ int64 i8;
+
+ signed char i1_0;
+ signed short i2_0;
+ signed long i3_0;
+ unsigned char i4_0;
+ unsigned short i5_0;
+ unsigned long i6_0;
+ uint64 i7_0;
+ int64 i8_0;
+
+ signed char i1_n;
+ signed short i2_n;
+ signed long i3_n;
+
+
+ float f1;
+ double f2;
+ long double f3;
+ float f1_inf;
+ double f2_inf;
+ long double f3_inf;
+ float f1_ninf;
+ double f2_ninf;
+ long double f3_ninf;
+ float f1_qnan;
+ double f2_qnan;
+ long double f3_qnan;
+ float f1_snan;
+ double f2_snan;
+ long double f3_snan;
+
+ std::string s1;
+ std::wstring s2;
+
+ int array[10];
+
+ bool b_true;
+ bool b_false;
+
+
+ void set_state_1(
+ )
+ {
+ i1 = 1;
+ i2 = 2;
+ i3 = 3;
+ i4 = 4;
+ i5 = 5;
+ i6 = 6;
+ i7 = 7;
+ i8 = 8;
+
+ i1_0 = 0;
+ i2_0 = 0;
+ i3_0 = 0;
+ i4_0 = 0;
+ i5_0 = 0;
+ i6_0 = 0;
+ i7_0 = 0;
+ i8_0 = 0;
+
+ i1_n = -1;
+ i2_n = -2;
+ i3_n = -3;
+
+ f1 = 123.456f;
+ f2 = 543.341;
+ f3 = 5234234.23;
+
+ f1_inf = numeric_limits<float>::infinity();
+ f2_inf = numeric_limits<double>::infinity();
+ f3_inf = numeric_limits<long double>::infinity();
+ f1_ninf = -numeric_limits<float>::infinity();
+ f2_ninf = -numeric_limits<double>::infinity();
+ f3_ninf = -numeric_limits<long double>::infinity();
+ f1_qnan = numeric_limits<float>::quiet_NaN();
+ f2_qnan = numeric_limits<double>::quiet_NaN();
+ f3_qnan = numeric_limits<long double>::quiet_NaN();
+ f1_snan = numeric_limits<float>::signaling_NaN();
+ f2_snan = numeric_limits<double>::signaling_NaN();
+ f3_snan = numeric_limits<long double>::signaling_NaN();
+
+ s1 = "davis";
+ s2 = L"yo yo yo";
+
+ for (int i = 0; i < 10; ++i)
+ array[i] = i;
+
+ b_true = true;
+ b_false = false;
+ }
+
+ void set_state_2(
+ )
+ {
+ i1 = 10;
+ i2 = 20;
+ i3 = 30;
+ i4 = 40;
+ i5 = 50;
+ i6 = 60;
+ i7 = 70;
+ i8 = 80;
+
+ i1_0 = 5;
+ i2_0 = 6;
+ i3_0 = 7;
+ i4_0 = 8;
+ i5_0 = 9;
+ i6_0 = 10;
+ i7_0 = 11;
+ i8_0 = 12;
+
+ i1_n = -13;
+ i2_n = -25;
+ i3_n = -12;
+
+ f1 = 45.3f;
+ f2 = 0.001;
+ f3 = 2.332;
+
+ f1_inf = f1;
+ f2_inf = f2;
+ f3_inf = f3;
+ f1_ninf = f1;
+ f2_ninf = f2;
+ f3_ninf = f3;
+ f1_qnan = f1;
+ f2_qnan = f2;
+ f3_qnan = f3;
+ f1_snan = f1;
+ f2_snan = f2;
+ f3_snan = f3;
+
+ s1 = "";
+ s2 = L"";
+
+ for (int i = 0; i < 10; ++i)
+ array[i] = 10-i;
+
+ b_true = false;
+ b_false = true;
+ }
+
+ void assert_in_state_1 (
+ )
+ {
+ DLIB_TEST (i1 == 1);
+ DLIB_TEST (i2 == 2);
+ DLIB_TEST (i3 == 3);
+ DLIB_TEST (i4 == 4);
+ DLIB_TEST (i5 == 5);
+ DLIB_TEST (i6 == 6);
+ DLIB_TEST (i7 == 7);
+ DLIB_TEST (i8 == 8);
+
+ DLIB_TEST (i1_0 == 0);
+ DLIB_TEST (i2_0 == 0);
+ DLIB_TEST (i3_0 == 0);
+ DLIB_TEST (i4_0 == 0);
+ DLIB_TEST (i5_0 == 0);
+ DLIB_TEST (i6_0 == 0);
+ DLIB_TEST (i7_0 == 0);
+ DLIB_TEST (i8_0 == 0);
+
+ DLIB_TEST (i1_n == -1);
+ DLIB_TEST (i2_n == -2);
+ DLIB_TEST (i3_n == -3);
+
+ DLIB_TEST (abs(f1 -123.456) < 1e-5);
+ DLIB_TEST (abs(f2 - 543.341) < 1e-10);
+ DLIB_TEST (abs(f3 - 5234234.23) < 1e-10);
+
+ DLIB_TEST (f1_inf == numeric_limits<float>::infinity());
+ DLIB_TEST (f2_inf == numeric_limits<double>::infinity());
+ DLIB_TEST (f3_inf == numeric_limits<long double>::infinity());
+ DLIB_TEST (f1_ninf == -numeric_limits<float>::infinity());
+ DLIB_TEST (f2_ninf == -numeric_limits<double>::infinity());
+ DLIB_TEST (f3_ninf == -numeric_limits<long double>::infinity());
+ DLIB_TEST (!(f1_qnan <= numeric_limits<float>::infinity() && f1_qnan >= -numeric_limits<float>::infinity() ));
+ DLIB_TEST (!(f2_qnan <= numeric_limits<double>::infinity() && f1_qnan >= -numeric_limits<double>::infinity() ));
+ DLIB_TEST (!(f3_qnan <= numeric_limits<long double>::infinity() && f1_qnan >= -numeric_limits<long double>::infinity() ));
+ DLIB_TEST (!(f1_snan <= numeric_limits<float>::infinity() && f1_qnan >= -numeric_limits<float>::infinity() ));
+ DLIB_TEST (!(f2_snan <= numeric_limits<double>::infinity() && f1_qnan >= -numeric_limits<double>::infinity() ));
+ DLIB_TEST (!(f3_snan <= numeric_limits<long double>::infinity() && f1_qnan >= -numeric_limits<long double>::infinity() ));
+
+ DLIB_TEST (s1 == "davis");
+ DLIB_TEST (s2 == L"yo yo yo");
+
+ for (int i = 0; i < 10; ++i)
+ {
+ DLIB_TEST (array[i] == i);
+ }
+
+ DLIB_TEST (b_true == true);
+ DLIB_TEST (b_false == false);
+
+ }
+
+ void assert_in_state_2 (
+ )
+ {
+ DLIB_TEST (i1 == 10);
+ DLIB_TEST (i2 == 20);
+ DLIB_TEST (i3 == 30);
+ DLIB_TEST (i4 == 40);
+ DLIB_TEST (i5 == 50);
+ DLIB_TEST (i6 == 60);
+ DLIB_TEST (i7 == 70);
+ DLIB_TEST (i8 == 80);
+
+ DLIB_TEST (i1_0 == 5);
+ DLIB_TEST (i2_0 == 6);
+ DLIB_TEST (i3_0 == 7);
+ DLIB_TEST (i4_0 == 8);
+ DLIB_TEST (i5_0 == 9);
+ DLIB_TEST (i6_0 == 10);
+ DLIB_TEST (i7_0 == 11);
+ DLIB_TEST (i8_0 == 12);
+
+ DLIB_TEST (i1_n == -13);
+ DLIB_TEST (i2_n == -25);
+ DLIB_TEST (i3_n == -12);
+
+ DLIB_TEST (abs(f1 - 45.3) < 1e-5);
+ DLIB_TEST (abs(f2 - 0.001) < 1e-10);
+ DLIB_TEST (abs(f3 - 2.332) < 1e-10);
+ DLIB_TEST (abs(f1_inf - 45.3) < 1e-5);
+ DLIB_TEST (abs(f2_inf - 0.001) < 1e-10);
+ DLIB_TEST (abs(f3_inf - 2.332) < 1e-10);
+ DLIB_TEST (abs(f1_ninf - 45.3) < 1e-5);
+ DLIB_TEST (abs(f2_ninf - 0.001) < 1e-10);
+ DLIB_TEST (abs(f3_ninf - 2.332) < 1e-10);
+ DLIB_TEST (abs(f1_qnan - 45.3) < 1e-5);
+ DLIB_TEST (abs(f2_qnan - 0.001) < 1e-10);
+ DLIB_TEST (abs(f3_qnan - 2.332) < 1e-10);
+ DLIB_TEST (abs(f1_snan - 45.3) < 1e-5);
+ DLIB_TEST (abs(f2_snan - 0.001) < 1e-10);
+ DLIB_TEST (abs(f3_snan - 2.332) < 1e-10);
+
+ DLIB_TEST (s1 == "");
+ DLIB_TEST (s2 == L"");
+
+ for (int i = 0; i < 10; ++i)
+ {
+ DLIB_TEST (array[i] == 10-i);
+ }
+
+ DLIB_TEST (b_true == false);
+ DLIB_TEST (b_false == true);
+
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize (
+ const test_object& item,
+ std::ostream& out
+ )
+ {
+ dlib::serialize(item.i1,out);
+ dlib::serialize(item.i2,out);
+ dlib::serialize(item.i3,out);
+ dlib::serialize(item.i4,out);
+ dlib::serialize(item.i5,out);
+ dlib::serialize(item.i6,out);
+ dlib::serialize(item.i7,out);
+ dlib::serialize(item.i8,out);
+
+ dlib::serialize(item.i1_0,out);
+ dlib::serialize(item.i2_0,out);
+ dlib::serialize(item.i3_0,out);
+ dlib::serialize(item.i4_0,out);
+ dlib::serialize(item.i5_0,out);
+ dlib::serialize(item.i6_0,out);
+ dlib::serialize(item.i7_0,out);
+ dlib::serialize(item.i8_0,out);
+
+ dlib::serialize(item.i1_n,out);
+ dlib::serialize(item.i2_n,out);
+ dlib::serialize(item.i3_n,out);
+
+
+ dlib::serialize(item.f1,out);
+ dlib::serialize(item.f2,out);
+ dlib::serialize(item.f3,out);
+
+ dlib::serialize(item.f1_inf,out);
+ dlib::serialize(item.f2_inf,out);
+ dlib::serialize(item.f3_inf,out);
+ dlib::serialize(item.f1_ninf,out);
+ dlib::serialize(item.f2_ninf,out);
+ dlib::serialize(item.f3_ninf,out);
+ dlib::serialize(item.f1_qnan,out);
+ dlib::serialize(item.f2_qnan,out);
+ dlib::serialize(item.f3_qnan,out);
+ dlib::serialize(item.f1_snan,out);
+ dlib::serialize(item.f2_snan,out);
+ dlib::serialize(item.f3_snan,out);
+
+ dlib::serialize(item.s1,out);
+ dlib::serialize(item.s2,out);
+
+ dlib::serialize(item.array,out);
+
+ dlib::serialize(item.b_true,out);
+ dlib::serialize(item.b_false,out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void deserialize (
+ test_object& item,
+ std::istream& in
+ )
+ {
+ dlib::deserialize(item.i1,in);
+ dlib::deserialize(item.i2,in);
+ dlib::deserialize(item.i3,in);
+ dlib::deserialize(item.i4,in);
+ dlib::deserialize(item.i5,in);
+ dlib::deserialize(item.i6,in);
+ dlib::deserialize(item.i7,in);
+ dlib::deserialize(item.i8,in);
+
+ dlib::deserialize(item.i1_0,in);
+ dlib::deserialize(item.i2_0,in);
+ dlib::deserialize(item.i3_0,in);
+ dlib::deserialize(item.i4_0,in);
+ dlib::deserialize(item.i5_0,in);
+ dlib::deserialize(item.i6_0,in);
+ dlib::deserialize(item.i7_0,in);
+ dlib::deserialize(item.i8_0,in);
+
+ dlib::deserialize(item.i1_n,in);
+ dlib::deserialize(item.i2_n,in);
+ dlib::deserialize(item.i3_n,in);
+
+
+ dlib::deserialize(item.f1,in);
+ dlib::deserialize(item.f2,in);
+ dlib::deserialize(item.f3,in);
+
+ dlib::deserialize(item.f1_inf,in);
+ dlib::deserialize(item.f2_inf,in);
+ dlib::deserialize(item.f3_inf,in);
+ dlib::deserialize(item.f1_ninf,in);
+ dlib::deserialize(item.f2_ninf,in);
+ dlib::deserialize(item.f3_ninf,in);
+ dlib::deserialize(item.f1_qnan,in);
+ dlib::deserialize(item.f2_qnan,in);
+ dlib::deserialize(item.f3_qnan,in);
+ dlib::deserialize(item.f1_snan,in);
+ dlib::deserialize(item.f2_snan,in);
+ dlib::deserialize(item.f3_snan,in);
+
+ dlib::deserialize(item.s1,in);
+ dlib::deserialize(item.s2,in);
+
+ dlib::deserialize(item.array,in);
+
+ dlib::deserialize(item.b_true,in);
+ dlib::deserialize(item.b_false,in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'stuff.bin' but using the old
+ // floating point serialization format.
+ const std::string get_decoded_string()
+ {
+ dlib::base64::kernel_1a base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+
+ // The base64 encoded data from the file 'stuff.bin' we want to decode and return.
+ sout << "AVaifX9zEbXa9aocsrcRuvnNrR3WLuuU5eLWiy0UeXmnKXGLKZz8V44gzT4CM6wnCmAHFQug8G3C";
+ sout << "4cuLdNgp2ApkeLcvwFNJRENE0ShrRaxEBFEA8nah7vm8B2VmgImNblCejuP5IcDt60EaCKlqiit8";
+ sout << "+JGrzYxqBm3xFS4P+qlOROdbxc7pXBmUdh0rqNSEvn0FBPdoqY/5SpHgA2yAcH8XFrM1cdu0xS3P";
+ sout << "8PBcmLMJ7bFdzplwhrjuxtm4NfEOi6Rl9sU44AXycYgJd0+uH+dyoI9X3co5b3YWJtjvdVeztNAr";
+ sout << "BfSPfR6oAVNfiMBG7QA=";
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+
+ // This function returns the contents of the file 'stuff.bin' but using the new
+ // floating point serialization format.
+ const std::string get_decoded_string2()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'stuff.bin' we want to decode and return.
+ sout << "AVaifX9zEbXa9aocsrcRuvnNqzZLptZ5mRd46xScCIfX6sq/46hG9JwIInElG50EtJKJY/+jAWit";
+ sout << "TpDBWrxBz124JRLsBz62h0D3Tqgnd8zygRx7t33Ybw40o07MrhzNEHgYavUukaPje5by78JIWHgk";
+ sout << "l7nb/TK+9ndVLrAThJ4v+GiPT3kh9H1tAAAAAQhbLa06pQjhrnjTXcRox1ZBEAV9/q1zAA==";
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.serialize");
+
+ void serialize_test (
+ )
+ /*!
+ ensures
+ - runs tests on the serialization code for compliance with the specs
+ !*/
+ {
+
+
+ print_spinner();
+
+ ostringstream sout;
+ test_object obj;
+
+ obj.set_state_1();
+ obj.assert_in_state_1();
+ serialize(obj, sout);
+ obj.assert_in_state_1();
+
+ obj.set_state_2();
+ obj.assert_in_state_2();
+ serialize(obj, sout);
+ obj.assert_in_state_2();
+
+
+ istringstream sin(sout.str());
+
+ deserialize(obj,sin);
+ obj.assert_in_state_1();
+ deserialize(obj,sin);
+ obj.assert_in_state_2();
+
+
+ // now do the same thing as above but deserialize from some stored binary
+ // data to make sure the serialized values are portable between different
+ // machines
+
+ sin.clear();
+ sin.str(get_decoded_string());
+ deserialize(obj,sin);
+ obj.assert_in_state_1();
+ deserialize(obj,sin);
+ obj.assert_in_state_2();
+
+
+ sin.clear();
+ sin.str(get_decoded_string2());
+ deserialize(obj,sin);
+ obj.assert_in_state_1();
+ deserialize(obj,sin);
+ obj.assert_in_state_2();
+
+
+ /*
+ // This is the code that produced the encoded data stored in the get_decoded_string() function
+ ofstream fout("stuff.bin",ios::binary);
+ obj.set_state_1();
+ obj.assert_in_state_1();
+ serialize(obj, fout);
+ obj.assert_in_state_1();
+
+ obj.set_state_2();
+ obj.assert_in_state_2();
+ serialize(obj, fout);
+ obj.assert_in_state_2();
+ */
+
+
+ test_object obj2;
+ obj.set_state_1();
+ obj2.set_state_2();
+ dlib::serialize("serialization_test.dat") << obj << obj2;
+ obj.assert_in_state_1();
+ obj2.assert_in_state_2();
+ obj.set_state_2();
+ obj2.set_state_1();
+ obj.assert_in_state_2();
+ obj2.assert_in_state_1();
+ dlib::deserialize("serialization_test.dat") >> obj >> obj2;
+ obj.assert_in_state_1();
+ obj2.assert_in_state_2();
+ }
+
+
+ template <typename T>
+ void test_vector (
+ )
+ {
+ std::vector<T> a, b;
+
+ for (int i = -10; i < 30; ++i)
+ {
+ a.push_back(i);
+ }
+
+ ostringstream sout;
+ dlib::serialize(a, sout);
+ istringstream sin(sout.str());
+
+ dlib::deserialize(b, sin);
+
+
+ DLIB_TEST(a.size() == b.size());
+ DLIB_TEST(a.size() == 40);
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ DLIB_TEST(a[i] == b[i]);
+ }
+
+ std::vector<T> c;
+ sout.str("");
+ dlib::serialize(c, sout);
+ sin.str(sout.str());
+ dlib::deserialize(a, sin);
+ DLIB_TEST(a.size() == 0);
+ DLIB_TEST(c.size() == 0);
+ }
+
+ void test_std_array (
+ )
+ {
+ std::array<int,5> a, b;
+
+ a = {1, 2, 3, 4, 5};
+
+ ostringstream sout;
+ dlib::serialize(a, sout);
+ istringstream sin(sout.str());
+
+ dlib::deserialize(b, sin);
+
+
+ DLIB_TEST(a.size() == b.size());
+ DLIB_TEST(a.size() == 5);
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ DLIB_TEST(a[i] == b[i]);
+ }
+
+ std::array<int,0> aa, bb;
+ sout.str("");
+ dlib::serialize(aa, sout);
+ sin.str(sout.str());
+ dlib::deserialize(bb, sin);
+ DLIB_TEST(bb.size() == 0);
+ }
+
+ void test_vector_bool (
+ )
+ {
+ std::vector<bool> a, b;
+
+ a.push_back(true);
+ a.push_back(true);
+ a.push_back(false);
+ a.push_back(true);
+ a.push_back(false);
+ a.push_back(true);
+
+ ostringstream sout;
+ dlib::serialize(a, sout);
+ istringstream sin(sout.str());
+
+ dlib::deserialize(b, sin);
+
+
+ DLIB_TEST(a.size() == b.size());
+ DLIB_TEST(a.size() == 6);
+ for (unsigned long i = 0; i < a.size(); ++i)
+ {
+ DLIB_TEST(a[i] == b[i]);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // This function returns the contents of the file 'matarray.dat'
+ const std::string get_decoded_string_matarray_old()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'matarray.dat' we want to decode and return.
+ sout << "AW852sEbTIeV+m/wLUcKJKPW+6IclviUWZcFh1daDZ0blDjPNTgPx0Lv56sIEwlG4I6C5OJzJBkZ";
+ sout << "PvczLjS7IEKh6eg7amNOyEexsQSgojL1oMe2gDEfkyInUGPJV90sNS0cvp/hIB134V8JCTYUP6vH";
+ sout << "9qpegLSIIQG+/NjLWyK2472vC88BJfKgkL3CPLMjQwB3tB928FNLbESDLIvpnb6q9ve68iuoyZZt";
+ sout << "z3TTJxHW3MIdgzuhNomvPxfo/Q+7lC/Orj0FewUX90al6DckwzOtLVRidh/ZKpsQsxzJYQGkjdX5";
+ sout << "mDzzXKqQb3Y3DnzEmwtRD9CUON3iRv1r26gHWLYorrYA";
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ // This function returns the contents of the file 'matarray.dat'
+ const std::string get_decoded_string_matarray()
+ {
+ dlib::base64 base64_coder;
+ dlib::compress_stream::kernel_1ea compressor;
+ std::ostringstream sout;
+ std::istringstream sin;
+
+ // The base64 encoded data from the file 'matarray.dat' we want to decode and return.
+ sout << "gO6XH2WGbm8Xaw3a5FJbh3V823W6P2Qk/vHaAAAAARccIppHWdmViaKby7JA5PQvXjYMWUYvXRHv";
+ sout << "xPdURZl1un3CT/rjT11Yry0y3+1W7GBmfBJ0gVFKGdiGuqoNAMtmzL/ll3YfEQ7ED7aB33aDTktw";
+ sout << "AWVkHT+gqTbKwjP+8YvB3s3ziK640ITOAWazAghKDVl7AHGn+fjq29paBZMczuJofl8FinZUhwa9";
+ sout << "Ol5gdAEQa6VZDmJUeo2soTJcEDpkW9LkRmXvjQkyEHfEHQNFDfQq4p2U+dHz4lOKlcj3VzQIeG/s";
+ sout << "oxa9KhJND4aQ5xeNUUHUzFBU3XhQHlyDIn/RNdX/ZwA=";
+
+
+ // Put the data into the istream sin
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decode the base64 text into its compressed binary form
+ base64_coder.decode(sin,sout);
+ sin.clear();
+ sin.str(sout.str());
+ sout.str("");
+
+ // Decompress the data into its original form
+ compressor.decompress(sin,sout);
+
+ // Return the decoded and decompressed data
+ return sout.str();
+ }
+
+ void setup_mats_and_arrays (
+ array2d<int>& a,
+ matrix<int>& m,
+ array2d<unsigned char>& img1,
+ array2d<rgb_pixel>& img2,
+ array2d<bgr_pixel>& img3,
+ array2d<rgb_alpha_pixel>& img4,
+ array2d<hsi_pixel>& img5
+ )
+ {
+ a.set_size(3,5);
+ int cnt = 0;
+ for (long r = 0; r < a.nr(); ++r)
+ {
+ for (long c = 0; c < a.nc(); ++c)
+ {
+ a[r][c] = cnt++;
+ }
+ }
+ m = mat(a);
+
+ img1.set_size(3,5);
+ img2.set_size(3,5);
+ img3.set_size(3,5);
+ img4.set_size(3,5);
+ img5.set_size(3,5);
+
+ assign_all_pixels(img1, 0);
+ assign_all_pixels(img2, 0);
+ assign_all_pixels(img3, 0);
+ assign_all_pixels(img4, 0);
+ assign_all_pixels(img5, 0);
+
+ unsigned char pcnt = 0;
+ for (long r = 0; r < img1.nr(); ++r)
+ {
+ for (long c = 0; c < img1.nc(); ++c)
+ {
+ rgb_alpha_pixel temp;
+ temp.red = pcnt++;
+ temp.green = pcnt++;
+ temp.blue = pcnt++;
+ temp.alpha = 150+pcnt++;
+ assign_pixel(img1[r][c], temp);
+ assign_pixel(img2[r][c], temp);
+ assign_pixel(img3[r][c], temp);
+ assign_pixel(img4[r][c], temp);
+ }
+ }
+
+ for (long r = 0; r < img5.nr(); ++r)
+ {
+ for (long c = 0; c < img5.nc(); ++c)
+ {
+ img5[r][c].h = pcnt++;
+ img5[r][c].s = pcnt++;
+ img5[r][c].i = pcnt++;
+ }
+ }
+ }
+
+
+ void test_deserialize(
+ std::istream& fin
+ )
+ {
+ array2d<int> a;
+ matrix<int> m;
+ array2d<unsigned char> img1;
+ array2d<rgb_pixel> img2;
+ array2d<bgr_pixel> img3;
+ array2d<rgb_alpha_pixel> img4;
+ array2d<hsi_pixel> img5;
+ setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5);
+
+
+ array2d<unsigned char> img1_;
+ array2d<rgb_pixel> img2_;
+ array2d<bgr_pixel> img3_;
+ array2d<rgb_alpha_pixel> img4_;
+ array2d<hsi_pixel> img5_;
+
+ matrix<int> m_;
+ array2d<int> a_;
+
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+
+ deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1));
+ deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2));
+ deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3));
+ deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4));
+ deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5));
+ }
+
+ void test_deserialize_all_array2d(
+ std::istream& fin
+ )
+ {
+ array2d<int> a;
+ matrix<int> m;
+ array2d<unsigned char> img1;
+ array2d<rgb_pixel> img2;
+ array2d<bgr_pixel> img3;
+ array2d<rgb_alpha_pixel> img4;
+ array2d<hsi_pixel> img5;
+ setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5);
+
+
+ array2d<unsigned char> img1_;
+ array2d<rgb_pixel> img2_;
+ array2d<bgr_pixel> img3_;
+ array2d<rgb_alpha_pixel> img4_;
+ array2d<hsi_pixel> img5_;
+
+ array2d<int> m_;
+ array2d<int> a_;
+
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+
+ deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1));
+ deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2));
+ deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3));
+ deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4));
+ deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5));
+ }
+
+ void test_deserialize_all_matrix(
+ std::istream& fin
+ )
+ {
+ array2d<int> a;
+ matrix<int> m;
+ array2d<unsigned char> img1;
+ array2d<rgb_pixel> img2;
+ array2d<bgr_pixel> img3;
+ array2d<rgb_alpha_pixel> img4;
+ array2d<hsi_pixel> img5;
+ setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5);
+
+
+ matrix<unsigned char> img1_;
+ matrix<rgb_pixel> img2_;
+ matrix<bgr_pixel> img3_;
+ matrix<rgb_alpha_pixel> img4_;
+ matrix<hsi_pixel> img5_;
+
+ matrix<int> m_;
+ matrix<int> a_;
+
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+ deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a));
+ deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m));
+
+ deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1));
+ deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2));
+ deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3));
+ deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4));
+ deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5));
+ }
+
+ void test_array2d_and_matrix_serialization()
+ {
+ ostringstream sout;
+ array2d<int> a;
+ matrix<int> m;
+ array2d<unsigned char> img1;
+ array2d<rgb_pixel> img2;
+ array2d<bgr_pixel> img3;
+ array2d<rgb_alpha_pixel> img4;
+ array2d<hsi_pixel> img5;
+ setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5);
+
+ serialize(a, sout);
+ serialize(m, sout);
+ serialize(a, sout);
+ serialize(m, sout);
+
+ serialize(img1, sout);
+ serialize(img2, sout);
+ serialize(img3, sout);
+ serialize(img4, sout);
+ serialize(img5, sout);
+
+ // --------------------
+
+ {
+ istringstream sin(sout.str());
+ test_deserialize(sin);
+ }
+ {
+ istringstream sin(sout.str());
+ test_deserialize_all_array2d(sin);
+ }
+ {
+ istringstream sin(sout.str());
+ test_deserialize_all_matrix(sin);
+ }
+
+
+ {
+ istringstream sin(get_decoded_string_matarray());
+ test_deserialize(sin);
+ }
+ {
+ istringstream sin(get_decoded_string_matarray());
+ test_deserialize_all_array2d(sin);
+ }
+ {
+ istringstream sin(get_decoded_string_matarray());
+ test_deserialize_all_matrix(sin);
+ }
+
+
+ {
+ // Make sure we can still deserialize the serialization
+ // format for array2d and matrix objects used by older versions
+ // of dlib.
+ istringstream sin(get_decoded_string_matarray_old());
+ test_deserialize(sin);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_strings()
+ {
+ string str1 = "stuff";
+ char buf[6];
+ buf[0] = 0;
+ buf[1] = 1;
+ buf[2] = 2;
+ buf[3] = 0;
+ buf[4] = 3;
+ buf[5] = 3;
+
+ dlib::serialize("ser_test_string.dat") << str1 << buf << "morestuff" << "";
+
+ string str2, str3, str4;
+ char buf2[6];
+ memset(buf2,0,sizeof(buf2));
+ dlib::deserialize("ser_test_string.dat") >> str2 >> buf2 >> str3 >> str4;
+ DLIB_TEST(str2 == "stuff");
+ DLIB_TEST(str3 == "morestuff");
+ DLIB_TEST(str4 == "");
+ DLIB_TEST(buf2[0] == 0);
+ DLIB_TEST(buf2[1] == 1);
+ DLIB_TEST(buf2[2] == 2);
+ DLIB_TEST(buf2[3] == 0);
+ DLIB_TEST(buf2[4] == 3);
+ DLIB_TEST(buf2[5] == 3);
+
+
+ ofstream fout("ser_test_string.dat", ios::binary);
+ dlib::serialize(str1, fout);
+ dlib::serialize(buf, fout);
+ dlib::serialize("morestuff", fout);
+ fout.close();
+ ifstream fin("ser_test_string.dat", ios::binary);
+ memset(buf2,0,sizeof(buf2));
+ str2.clear();
+ str3.clear();
+ dlib::deserialize(str2, fin);
+ dlib::deserialize(buf2, fin);
+ dlib::deserialize(str3, fin);
+
+ DLIB_TEST(str2 == "stuff");
+ DLIB_TEST(str3 == "morestuff");
+ DLIB_TEST(buf2[0] == 0);
+ DLIB_TEST(buf2[1] == 1);
+ DLIB_TEST(buf2[2] == 2);
+ DLIB_TEST(buf2[3] == 0);
+ DLIB_TEST(buf2[4] == 3);
+ DLIB_TEST(buf2[5] == 3);
+
+
+
+ // make sure ramdump() overloads compile and work.
+ {
+ matrix<double,2,2> a = {1,2,3,4};
+ const matrix<double,2,2> b = {3,2,3,4};
+ dlib::serialize("ramdump_mat.dat") << ramdump(a) << ramdump(b);
+ matrix<double,2,2> A, B;
+ dlib::deserialize("ramdump_mat.dat") >> ramdump(A) >> ramdump(B);
+
+ DLIB_TEST(A == a);
+ DLIB_TEST(B == b);
+ A = 0;
+ B = 0;
+ DLIB_TEST(A != a);
+ DLIB_TEST(B != b);
+
+ ostringstream sout;
+ dlib::serialize(ramdump(a), sout);
+ dlib::serialize(ramdump(b), sout);
+ istringstream sin(sout.str());
+ dlib::deserialize(ramdump(A), sin);
+ dlib::deserialize(ramdump(B), sin);
+
+ DLIB_TEST(A == a);
+ DLIB_TEST(B == b);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class serialize_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the serialization . When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_serialize by passing that string to the tester constructor.
+ !*/
+ public:
+ serialize_tester (
+ ) :
+ tester ("test_serialize",
+ "Runs tests on the serialization code.")
+ {}
+
+ void perform_test (
+ )
+ {
+ serialize_test();
+ test_vector<char>();
+ test_vector<unsigned char>();
+ test_vector<int>();
+ test_vector_bool();
+ test_array2d_and_matrix_serialization();
+ test_strings();
+ test_std_array();
+ }
+ } a;
+
+
+}
+
+
diff --git a/ml/dlib/dlib/test/set.cpp b/ml/dlib/dlib/test/set.cpp
new file mode 100644
index 000000000..f8d3bd374
--- /dev/null
+++ b/ml/dlib/dlib/test/set.cpp
@@ -0,0 +1,464 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/set.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.set");
+
+ template <
+ typename set
+ >
+ void set_compare_test (
+ )
+ /*!
+ requires
+ - set is an implementation of set/set_compare_abstract.h and
+ is instantiated with int
+ ensures
+ - runs tests on set for compliance with the specs
+ !*/
+ {
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+
+
+ set test, test2;
+
+ enumerable<const int>& e = test;
+ DLIB_TEST(e.at_start() == true);
+
+ for (int j = 0; j < 4; ++j)
+ {
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ int a,b = 0;
+ a = 8;
+ test.add(a);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(test.is_member(8) == true);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+ a = 53;
+ test.add(a);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_member(53) == true);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ swap(test,test2);
+
+
+
+ DLIB_TEST(test2.is_member(8) == true);
+ DLIB_TEST(test2.is_member(5) == false);
+ DLIB_TEST(test2.is_member(0) == false);
+ DLIB_TEST(test2.is_member(-999) == false);
+ DLIB_TEST(test2.is_member(4999) == false);
+ DLIB_TEST(test2.size() == 2);
+ DLIB_TEST(test2.is_member(53) == true);
+ DLIB_TEST(test2.is_member(5) == false);
+ DLIB_TEST(test2.is_member(0) == false);
+ DLIB_TEST(test2.is_member(-999) == false);
+ DLIB_TEST(test2.is_member(4999) == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(8) == false);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.is_member(53) == false);
+ DLIB_TEST(test.is_member(5) == false);
+ DLIB_TEST(test.is_member(0) == false);
+ DLIB_TEST(test.is_member(-999) == false);
+ DLIB_TEST(test.is_member(4999) == false);
+
+
+ test.clear();
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.size() == 10000);
+
+ int count = 0;
+ a = 0;
+ while (test.move_next())
+ {
+ enumerable<const int>& gogo = test;
+ gogo.element();
+
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+
+ DLIB_TEST(a <= test.element());
+ a = test.element();
+ ++count;
+ }
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.move_next() == false);
+
+ DLIB_TEST(count == 10000);
+
+ test.swap(test2);
+
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test2.size() == 10000);
+ count = 0;
+ a = -1;
+ test2.reset();
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(a < test2.element());
+ a = test2.element();
+ ++count;
+ }
+ DLIB_TEST(test2.size() == 10000);
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.move_next() == false);
+
+
+
+ test2.clear();
+ DLIB_TEST(test2.size() == 0);
+ DLIB_TEST(test2.at_start() == true);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ DLIB_TEST(test.at_start() == true);
+
+ {
+ int* array = new int[test.size()];
+ int* tmp = array;
+
+ count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ DLIB_TEST(test.element() == test.element());
+ *tmp = test.element();
+ ++tmp;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+
+ // serialize the state of test, then clear test, then
+ // load the state back into test.
+ ostringstream sout;
+ serialize(test,sout);
+ DLIB_TEST(test.at_start() == true);
+ istringstream sin(sout.str());
+ test.clear();
+ deserialize(test,sin);
+
+
+
+ tmp = array;
+ for (int i = 0; i < 20000; ++i)
+ {
+ DLIB_TEST(test.is_member(*tmp) == true);
+ ++tmp;
+ }
+
+ DLIB_TEST(test.size() == 20000);
+
+ tmp = array;
+ count = 0;
+ while (test.size() > 10000)
+ {
+ test.remove(*tmp,a);
+ DLIB_TEST(*tmp == a);
+ ++tmp;
+ ++count;
+ }
+ DLIB_TEST(count == 10000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == *tmp);
+ DLIB_TEST(test.element() == *tmp);
+ DLIB_TEST(test.element() == *tmp);
+ ++tmp;
+ ++count;
+ }
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test.size() == 10000);
+
+ while (test.size() < 20000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ test2.swap(test);
+
+ count = 0;
+ a = 0;
+ while (test2.move_next())
+ {
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(test2.element() == test2.element());
+ DLIB_TEST(a <= test2.element());
+ a = test2.element();
+ ++count;
+ }
+
+ DLIB_TEST(count == 20000);
+ DLIB_TEST(test2.size() == 20000);
+
+ a = -1;
+ while (test2.size()>0)
+ {
+ test2.remove_any(b);
+ DLIB_TEST( a < b);
+ a = b;
+ }
+
+ DLIB_TEST(test2.size() == 0);
+ delete [] array;
+ }
+
+ test.clear();
+ test2.clear();
+ while (test.size() < 10000)
+ {
+ a = ::rand();
+ if (!test.is_member(a))
+ test.add(a);
+ }
+
+ count = 0;
+ a = -1;
+ while (test.move_next())
+ {
+ DLIB_TEST(a < test.element());
+ a = test.element();
+ ++count;
+ if (count == 5000)
+ break;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ test.reset();
+
+ count = 0;
+ a = -1;
+ while (test.move_next())
+ {
+ DLIB_TEST(a < test.element());
+ a = test.element();
+ ++count;
+ DLIB_TEST(test.current_element_valid() == true);
+ }
+
+ DLIB_TEST(count == 10000);
+
+
+ test.clear();
+ test2.clear();
+ }
+
+
+
+ {
+ DLIB_TEST(test == test2);
+ DLIB_TEST((test < test2) == false);
+ DLIB_TEST((test2 < test) == false);
+
+ int a = 3, b = 3;
+ test.add(a);
+ test2.add(b);
+ test.move_next();
+ DLIB_TEST(test == test2);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test.move_next();
+ DLIB_TEST((test < test2) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test.move_next();
+ DLIB_TEST((test2 < test) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+
+ a = 2; b = 5;
+ test.add(a);
+ test2.add(b);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test == test2) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test < test2) == true);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test2 < test) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+
+
+ a = 8;
+ test.add(a);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test == test2) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test < test2) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test2 < test) == true);
+ DLIB_TEST(test.at_start() && test2.at_start());
+
+ test.clear();
+
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test == test2) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test < test2) == true);
+ DLIB_TEST(test.at_start() && test2.at_start());
+ test2.move_next();
+ DLIB_TEST((test2 < test) == false);
+ DLIB_TEST(test.at_start() && test2.at_start());
+
+
+ }
+
+
+ {
+ test.clear();
+ DLIB_TEST(test.size() == 0);
+ int a = 5;
+ test.add(a);
+ a = 7;
+ test.add(a);
+ DLIB_TEST(test.size() == 2);
+ DLIB_TEST(test.is_member(7));
+ DLIB_TEST(test.is_member(5));
+ test.destroy(7);
+ DLIB_TEST(test.size() == 1);
+ DLIB_TEST(!test.is_member(7));
+ DLIB_TEST(test.is_member(5));
+ test.destroy(5);
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(!test.is_member(7));
+ DLIB_TEST(!test.is_member(5));
+ }
+
+
+ }
+
+
+
+
+ class set_tester : public tester
+ {
+ public:
+ set_tester (
+ ) :
+ tester ("test_set",
+ "Runs tests on the set component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing compare_1a";
+ set_compare_test<dlib::set<int>::compare_1a> ();
+ dlog << LINFO << "testing compare_1a_c";
+ set_compare_test<dlib::set<int>::compare_1a_c>();
+ dlog << LINFO << "testing compare_1b";
+ set_compare_test<dlib::set<int>::compare_1b> ();
+ dlog << LINFO << "testing compare_1b_c";
+ set_compare_test<dlib::set<int>::compare_1b_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/sldf.cpp b/ml/dlib/dlib/test/sldf.cpp
new file mode 100644
index 000000000..4ca9a3dd0
--- /dev/null
+++ b/ml/dlib/dlib/test/sldf.cpp
@@ -0,0 +1,296 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+#include <dlib/data_io.h>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.sldf");
+
+
+ class sldf_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ sldf_tester (
+ ) :
+ tester (
+ "test_sldf", // the command line argument name for this test
+ "Run tests on the simplify_linear_decision_function routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ dlib::rand rnd;
+
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ typedef std::map<unsigned long,double> sample_type;
+
+ typedef matrix<double,0,1> dense_sample_type;
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+ typedef linear_kernel<dense_sample_type> dense_kernel_type;
+
+
+ svm_nu_trainer<kernel_type> linear_trainer;
+ linear_trainer.set_nu(0.2);
+ svm_nu_trainer<dense_kernel_type> dense_linear_trainer;
+ dense_linear_trainer.set_nu(0.2);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ // Now lets go into a loop and randomly generate 300 samples.
+ double label = +1;
+ for (int i = 0; i < 300; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample.clear();
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 10; ++j)
+ {
+ int idx = rnd.get_random_32bit_number()%100;
+ double value = rnd.get_random_double();
+
+ sample[idx] = label*value;
+ }
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+ }
+
+
+ {
+ print_spinner();
+ dlog << LINFO << " test with sparse samples ";
+ decision_function<kernel_type> df = linear_trainer.train(samples, labels);
+
+ dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size();
+ DLIB_TEST(df.basis_vectors.size() > 4);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels);
+
+ // save the outputs of the decision function before we mess with it
+ std::vector<double> prev_vals;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ prev_vals.push_back(df(samples[i]));
+
+ df = simplify_linear_decision_function(df);
+
+ dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size();
+ DLIB_TEST(df.basis_vectors.size() == 1);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels);
+
+ // now check that the simplified decision function still produces the same results
+ std::vector<double> cur_vals;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ cur_vals.push_back(df(samples[i]));
+
+ const double err = max(abs(mat(cur_vals) - mat(prev_vals)));
+ dlog << LINFO << "simplify error: "<< err;
+ DLIB_TEST(err < 1e-13);
+
+ }
+
+
+ // same as above but call simplify_linear_decision_function() two times
+ {
+ print_spinner();
+ dlog << LINFO << " test with sparse samples ";
+ decision_function<kernel_type> df = linear_trainer.train(samples, labels);
+
+ dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size();
+ DLIB_TEST(df.basis_vectors.size() > 4);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels);
+
+ // save the outputs of the decision function before we mess with it
+ std::vector<double> prev_vals;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ prev_vals.push_back(df(samples[i]));
+
+ df = simplify_linear_decision_function(df);
+ df = simplify_linear_decision_function(df);
+
+ dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size();
+ DLIB_TEST(df.basis_vectors.size() == 1);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels);
+
+ // now check that the simplified decision function still produces the same results
+ std::vector<double> cur_vals;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ cur_vals.push_back(df(samples[i]));
+
+ const double err = max(abs(mat(cur_vals) - mat(prev_vals)));
+ dlog << LINFO << "simplify error: "<< err;
+ DLIB_TEST(err < 1e-13);
+
+ }
+
+
+ {
+ print_spinner();
+ dlog << LINFO << " test with dense samples ";
+ std::vector<dense_sample_type> dense_samples(sparse_to_dense(samples));
+
+ // In addition to the rule we learned with the pegasos trainer lets also use our linear_trainer
+ // to learn a decision rule.
+ decision_function<dense_kernel_type> dense_df = dense_linear_trainer.train(dense_samples, labels);
+
+ dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size();
+ DLIB_TEST(dense_df.basis_vectors.size() > 4);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels);
+
+ // save the outputs of the decision function before we mess with it
+ std::vector<double> prev_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ prev_vals.push_back(dense_df(dense_samples[i]));
+
+ dense_df = simplify_linear_decision_function(dense_df);
+
+ dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size();
+ DLIB_TEST(dense_df.basis_vectors.size() == 1);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels);
+
+
+ // now check that the simplified decision function still produces the same results
+ std::vector<double> cur_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ cur_vals.push_back(dense_df(dense_samples[i]));
+
+ const double err = max(abs(mat(cur_vals) - mat(prev_vals)));
+ dlog << LINFO << "simplify error: "<< err;
+ DLIB_TEST(err < 1e-13);
+ }
+
+ // same as above but call simplify_linear_decision_function() two times
+ {
+ print_spinner();
+ dlog << LINFO << " test with dense samples ";
+ std::vector<dense_sample_type> dense_samples(sparse_to_dense(samples));
+
+ // In addition to the rule we learned with the pegasos trainer lets also use our linear_trainer
+ // to learn a decision rule.
+ decision_function<dense_kernel_type> dense_df = dense_linear_trainer.train(dense_samples, labels);
+
+ dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size();
+ DLIB_TEST(dense_df.basis_vectors.size() > 4);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels);
+
+ // save the outputs of the decision function before we mess with it
+ std::vector<double> prev_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ prev_vals.push_back(dense_df(dense_samples[i]));
+
+ dense_df = simplify_linear_decision_function(dense_df);
+ dense_df = simplify_linear_decision_function(dense_df);
+
+ dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size();
+ DLIB_TEST(dense_df.basis_vectors.size() == 1);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels);
+
+
+ // now check that the simplified decision function still produces the same results
+ std::vector<double> cur_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ cur_vals.push_back(dense_df(dense_samples[i]));
+
+ const double err = max(abs(mat(cur_vals) - mat(prev_vals)));
+ dlog << LINFO << "simplify error: "<< err;
+ DLIB_TEST(err < 1e-13);
+ }
+
+ {
+ print_spinner();
+
+ dlog << LINFO << " test with sparse samples and a vector normalizer";
+ std::vector<dense_sample_type> dense_samples(sparse_to_dense(samples));
+ std::vector<dense_sample_type> norm_samples;
+
+ // make a normalizer and normalize everything
+ vector_normalizer<dense_sample_type> normalizer;
+ normalizer.train(dense_samples);
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ norm_samples.push_back(normalizer(dense_samples[i]));
+
+ normalized_function<decision_function<dense_kernel_type> > dense_df;
+
+ dense_df.normalizer = normalizer;
+ dense_df.function = dense_linear_trainer.train(norm_samples, labels);
+
+ dlog << LINFO << "dense_df.function.basis_vectors.size(): "<< dense_df.function.basis_vectors.size();
+ DLIB_TEST(dense_df.function.basis_vectors.size() > 4);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels);
+
+ // save the outputs of the decision function before we mess with it
+ std::vector<double> prev_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ prev_vals.push_back(dense_df(dense_samples[i]));
+
+
+ decision_function<dense_kernel_type> simple_df = simplify_linear_decision_function(dense_df);
+
+ dlog << LINFO << "simple_df.basis_vectors.size(): "<< simple_df.basis_vectors.size();
+ DLIB_TEST(simple_df.basis_vectors.size() == 1);
+
+ dlog << LINFO << "test scores: "<< test_binary_decision_function(simple_df, dense_samples, labels);
+
+
+ // now check that the simplified decision function still produces the same results
+ std::vector<double> cur_vals;
+ for (unsigned long i = 0; i < dense_samples.size(); ++i)
+ cur_vals.push_back(simple_df(dense_samples[i]));
+
+ const double err = max(abs(mat(cur_vals) - mat(prev_vals)));
+ dlog << LINFO << "simplify error: "<< err;
+ DLIB_TEST(err < 1e-13);
+
+ }
+
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ sldf_tester a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/sliding_buffer.cpp b/ml/dlib/dlib/test/sliding_buffer.cpp
new file mode 100644
index 000000000..449ea858e
--- /dev/null
+++ b/ml/dlib/dlib/test/sliding_buffer.cpp
@@ -0,0 +1,439 @@
+// Copyright (C) 2004 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <cstdlib>
+
+#include <dlib/sliding_buffer.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.sliding_buffer");
+
+ template <
+ typename buf
+ >
+ void sliding_buffer_kernel_test (
+ )
+ /*!
+ requires
+ - buf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h
+ - buf is instantiated with T=unsigned char
+ ensures
+ - runs tests on buf for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ buf test;
+
+ DLIB_TEST(test.size() == 0);
+
+ test.set_size(3);
+ buf test2;
+
+ DLIB_TEST(test.size() == 8);
+
+ for (int g = 0; g < 2; ++g)
+ {
+
+ test.clear();
+
+ DLIB_TEST(test.size() == 0);
+ test.set_size(2);
+
+ DLIB_TEST(test.size() == 4);
+
+
+
+ test[0] = 'a';
+ test[1] = 's';
+ test[2] = 'd';
+ test[3] = 'f';
+
+ unsigned long id = test.get_element_id(2);
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+
+
+ DLIB_TEST(test[0] == 'a');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+ DLIB_TEST(test2.size() == 0);
+ swap(test,test2);
+ DLIB_TEST(test2.size() == 4);
+
+ DLIB_TEST(test2[0] == 'a');
+ DLIB_TEST(test2[1] == 's');
+ DLIB_TEST(test2[2] == 'd');
+ DLIB_TEST(test2[3] == 'f');
+
+ swap(test,test2);
+
+ test.rotate_left(4);
+
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+
+ DLIB_TEST(test[0] == 'a');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+ test.rotate_right(1);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+
+ DLIB_TEST(test[0] == 's');
+ DLIB_TEST(test[1] == 'd');
+ DLIB_TEST(test[2] == 'f');
+ DLIB_TEST(test[3] == 'a');
+
+
+ test.rotate_left(1);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+ DLIB_TEST(test[0] == 'a');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+
+ test.rotate_left(16);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+ DLIB_TEST(test[0] == 'a');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+
+ test.rotate_left(2);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+
+ DLIB_TEST(test[0] == 'd');
+ DLIB_TEST(test[1] == 'f');
+ DLIB_TEST(test[2] == 'a');
+ DLIB_TEST(test[3] == 's');
+
+ test.rotate_left(1);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+ DLIB_TEST(test[0] == 's');
+ DLIB_TEST(test[1] == 'd');
+ DLIB_TEST(test[2] == 'f');
+ DLIB_TEST(test[3] == 'a');
+
+ test.rotate_left(1);
+
+ DLIB_TEST(test[test.get_element_index(id)] == 'd');
+ DLIB_TEST(test[0] == 'a');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+ DLIB_TEST(test.size() == 4);
+
+ test[0] = 'x';
+
+ DLIB_TEST(test[0] == 'x');
+ DLIB_TEST(test[1] == 's');
+ DLIB_TEST(test[2] == 'd');
+ DLIB_TEST(test[3] == 'f');
+
+ test.rotate_left(1);
+
+ DLIB_TEST_MSG(test[0] == 'f',test[0]);
+ DLIB_TEST(test[1] == 'x');
+ DLIB_TEST(test[2] == 's');
+ DLIB_TEST(test[3] == 'd');
+
+
+ test[0] = 'x';
+
+ DLIB_TEST(test[0] == 'x');
+ DLIB_TEST(test[1] == 'x');
+ DLIB_TEST(test[2] == 's');
+ DLIB_TEST(test[3] == 'd');
+
+
+ test.rotate_left(1);
+
+
+ DLIB_TEST(test[0] == 'd');
+ DLIB_TEST(test[1] == 'x');
+ DLIB_TEST(test[2] == 'x');
+ DLIB_TEST(test[3] == 's');
+
+
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+
+ test.clear();
+ test2.clear();
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+
+ swap(test,test2);
+
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+ DLIB_TEST(test2.move_next() == false);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test2.current_element_valid() == false);
+
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.at_start() == false);
+
+ test.set_size(3);
+ DLIB_TEST(test.size() == 8);
+
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ test.reset();
+ DLIB_TEST(test.size() == 8);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+
+
+ test.rotate_right(1);
+ DLIB_TEST(test.size() == 8);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+
+ test.rotate_left(1);
+ DLIB_TEST(test.size() == 8);
+ DLIB_TEST(test.at_start() == true);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == true);
+ DLIB_TEST(test.at_start() == false);
+ DLIB_TEST(test.current_element_valid() == true);
+ test.reset();
+
+
+ for (unsigned long i = 0; i < test.size(); ++i)
+ {
+ test[i] = static_cast<unsigned char>(i);
+ }
+
+ unsigned long count = 0;
+ while (test.move_next())
+ {
+ DLIB_TEST(test.element() == count);
+ ++count;
+ }
+
+ DLIB_TEST(count == test.size());
+
+
+ test2.clear();
+ ostringstream sout;
+ istringstream sin;
+
+ serialize(test,sout);
+ sin.str(sout.str());
+ deserialize(test2,sin);
+
+ char ch;
+ sin >> ch;
+ DLIB_TEST( !sin);
+
+ DLIB_TEST(test2.size() == test.size());
+
+
+ for (unsigned long i = 0; i < test.size(); ++i)
+ {
+ DLIB_TEST_MSG(test[i] == test2[i],
+ "\ni: " << i <<
+ "\ntest[i]: " << test[i] <<
+ "\ntest2[i]: " << test2[i]);
+ }
+
+ count = 0;
+ while (test.move_next() && test2.move_next())
+ {
+ DLIB_TEST(test.element() == count);
+ DLIB_TEST(test2.element() == count);
+ ++count;
+ }
+
+ DLIB_TEST(test2.size() == count);
+ DLIB_TEST(test.size() == count);
+
+ test2.clear();
+
+
+ } // for (int g = 0; g < 2; ++g)
+
+
+ }
+
+
+
+ void test_circular_buffer()
+ {
+ circular_buffer<int> buf;
+
+ DLIB_TEST(buf.size() == 0);
+
+ buf.assign(4, 0);
+ DLIB_TEST(buf.size() == 4);
+
+ DLIB_TEST(buf[0] == 0);
+ DLIB_TEST(buf[1] == 0);
+ DLIB_TEST(buf[2] == 0);
+ DLIB_TEST(buf[3] == 0);
+ buf.push_back(1);
+ DLIB_TEST(buf[0] == 0);
+ DLIB_TEST(buf[1] == 0);
+ DLIB_TEST(buf[2] == 0);
+ DLIB_TEST(buf[3] == 1);
+ buf.push_back(2);
+ DLIB_TEST(buf[0] == 0);
+ DLIB_TEST(buf[1] == 0);
+ DLIB_TEST(buf[2] == 1);
+ DLIB_TEST(buf[3] == 2);
+ buf.push_front(3);
+ DLIB_TEST(buf[0] == 3);
+ DLIB_TEST(buf[1] == 0);
+ DLIB_TEST(buf[2] == 0);
+ DLIB_TEST(buf[3] == 1);
+ buf.push_front(4);
+ DLIB_TEST(buf.front() == 4);
+ DLIB_TEST(buf[0] == 4);
+ DLIB_TEST(buf[1] == 3);
+ DLIB_TEST(buf[2] == 0);
+ DLIB_TEST(buf[3] == 0);
+
+ buf.assign(4, 5);
+ DLIB_TEST(buf[0] == 5);
+ DLIB_TEST(buf[1] == 5);
+ DLIB_TEST(buf[2] == 5);
+ DLIB_TEST(buf[3] == 5);
+
+ buf.push_back(3);
+ DLIB_TEST(buf[0] == 5);
+ DLIB_TEST(buf[1] == 5);
+ DLIB_TEST(buf[2] == 5);
+ DLIB_TEST(buf[3] == 3);
+ buf.push_back(2);
+ DLIB_TEST(buf[0] == 5);
+ DLIB_TEST(buf[1] == 5);
+ DLIB_TEST(buf[2] == 3);
+ DLIB_TEST(buf[3] == 2);
+ buf.push_back(1);
+ DLIB_TEST(buf[0] == 5);
+ DLIB_TEST(buf[1] == 3);
+ DLIB_TEST(buf[2] == 2);
+ DLIB_TEST(buf[3] == 1);
+ buf.push_back(0);
+ DLIB_TEST(buf[0] == 3);
+ DLIB_TEST(buf[1] == 2);
+ DLIB_TEST(buf[2] == 1);
+ DLIB_TEST(buf[3] == 0);
+ buf.push_back(-1);
+ DLIB_TEST(buf.back() == -1);
+ DLIB_TEST(buf[0] == 2);
+ DLIB_TEST(buf[1] == 1);
+ DLIB_TEST(buf[2] == 0);
+ DLIB_TEST(buf[3] == -1);
+
+ buf.resize(1);
+ buf[0] = 9;
+ DLIB_TEST(buf.size() == 1);
+ DLIB_TEST(buf[0] == 9);
+ buf.push_back(1);
+ DLIB_TEST(buf[0] == 1);
+ buf.push_back(4);
+ DLIB_TEST(buf[0] == 4);
+ buf.push_front(3);
+ DLIB_TEST(buf[0] == 3);
+
+ buf.clear();
+ DLIB_TEST(buf.size() == 0);
+
+ buf.assign(3, 0);
+
+ circular_buffer<int> buf2, buf3;
+
+ buf.push_back(1);
+ buf.push_back(2);
+
+ ostringstream sout;
+ serialize(buf, sout);
+ istringstream sin(sout.str());
+ deserialize(buf2, sin);
+
+ DLIB_TEST(buf.size() == buf2.size());
+ for (unsigned long i = 0; i < buf.size(); ++i)
+ DLIB_TEST(buf[i] == buf2[i]);
+
+ buf.swap(buf3);
+ DLIB_TEST(buf.size() == 0);
+ DLIB_TEST(buf3.size() == buf2.size());
+ for (unsigned long i = 0; i < buf3.size(); ++i)
+ DLIB_TEST(buf3[i] == buf2[i]);
+
+
+
+ }
+
+
+
+ class sliding_buffer_tester : public tester
+ {
+ public:
+ sliding_buffer_tester (
+ ) :
+ tester ("test_sliding_buffer",
+ "Runs tests on the sliding_buffer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ sliding_buffer_kernel_test<sliding_buffer<unsigned char>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ sliding_buffer_kernel_test<sliding_buffer<unsigned char>::kernel_1a_c>();
+
+ test_circular_buffer();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/smart_pointers.cpp b/ml/dlib/dlib/test/smart_pointers.cpp
new file mode 100644
index 000000000..c2281efda
--- /dev/null
+++ b/ml/dlib/dlib/test/smart_pointers.cpp
@@ -0,0 +1,449 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+// This is a legacy test for old dlib smart pointers which is excluded
+// from CMakeLists.txt. Including this test will pull legacy smart_pointers.h
+// code which is uncompilable on C++17 compilers
+
+#include <dlib/smart_pointers.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include "tester.h"
+
+// Don't warn about auto_ptr
+#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 6) || (__GNUC__ > 4))) || \
+ (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4)))
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+namespace
+{
+ bool used_array_delete;
+ template <typename T>
+ struct test_deleter
+ {
+ void operator() (T* item) const
+ {
+ used_array_delete = false;
+ delete item;
+ }
+ };
+
+ template <typename T>
+ struct test_deleter<T[]>
+ {
+ void operator() (T* item) const
+ {
+ used_array_delete = true;
+ delete [] item;
+ }
+ };
+
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.smart_pointers");
+
+ int counter = 0;
+ struct base
+ {
+ int num;
+ virtual ~base() {}
+ };
+
+ struct derived : public base
+ {
+ derived() { ++counter; }
+ ~derived() { --counter; }
+ };
+
+ int deleter_called = 0;
+ void deleter ( derived* p) { ++deleter_called; delete p; }
+ void deleter_base ( base* p) { ++deleter_called; delete p; }
+ typedef void (*D)(derived*);
+ typedef void (*Db)(base*);
+
+ void smart_pointers_test (
+ )
+ /*!
+ ensures
+ - runs tests on the smart pointers for compliance with the specs
+ !*/
+ {
+ counter = 0;
+ deleter_called = 0;
+
+ {
+ DLIB_TEST_MSG(counter == 0,counter);
+ scoped_ptr<base> p1(new derived);
+ scoped_ptr<derived> p2(new derived);
+ scoped_ptr<derived> p3;
+ DLIB_TEST_MSG(counter == 2,counter);
+ DLIB_TEST(!p3);
+
+ p1->num = 1;
+ p2->num = 2;
+ DLIB_TEST(p1->num == 1);
+ DLIB_TEST(p2->num == 2);
+
+ (*p1).num = 3;
+ (*p2).num = 4;
+ DLIB_TEST(p1->num == 3);
+ DLIB_TEST(p2->num == 4);
+
+ DLIB_TEST_MSG(counter == 2,counter);
+
+ DLIB_TEST(p1);
+ DLIB_TEST(p2);
+
+ DLIB_TEST_MSG(counter == 2,counter);
+ p1.reset();
+ DLIB_TEST_MSG(counter == 1,counter);
+ DLIB_TEST(!p1);
+ DLIB_TEST(p2);
+ p1.reset(new derived);
+ DLIB_TEST_MSG(counter == 2,counter);
+ DLIB_TEST(p1);
+
+
+ DLIB_TEST_MSG(counter == 2,counter);
+ p2.reset();
+ DLIB_TEST_MSG(counter == 1,counter);
+ DLIB_TEST(!p2);
+ derived* d = new derived;
+ p2.reset(d);
+ DLIB_TEST(p2.get() == d);
+ DLIB_TEST_MSG(counter == 2,counter);
+ DLIB_TEST(p2);
+ DLIB_TEST(!p3);
+ p2->num = 9;
+ swap(p2,p3);
+ DLIB_TEST(!p2);
+ DLIB_TEST(p3);
+ DLIB_TEST(p3->num == 9);
+ p2.swap(p3);
+ DLIB_TEST(p2);
+ DLIB_TEST(!p3);
+ DLIB_TEST(p2->num == 9);
+
+
+ DLIB_TEST_MSG(counter == 2,counter);
+
+ }
+ DLIB_TEST_MSG(counter == 0,counter);
+
+ {
+ base* realp1 = new derived;
+ derived* realp2 = new derived;
+ dlib::shared_ptr<base> p1(realp1);
+ dlib::shared_ptr<derived> p2(realp2,&deleter);
+ dlib::shared_ptr<base> p3;
+ dlib::shared_ptr<derived> p4;
+ DLIB_TEST(p4.get() == 0);
+ DLIB_TEST(p1);
+ DLIB_TEST(p2);
+ DLIB_TEST(!p3);
+ DLIB_TEST(!p4);
+ DLIB_TEST(p1.get() == realp1);
+ DLIB_TEST(p2.get() == realp2);
+ p1->num = 1;
+ p2->num = 2;
+ DLIB_TEST((*p1).num == 1);
+ DLIB_TEST((*p2).num == 2);
+
+ p1.swap(p3);
+ DLIB_TEST(!p1);
+ DLIB_TEST(p3);
+ DLIB_TEST((*p3).num == 1);
+ DLIB_TEST(p3->num == 1);
+ swap(p1,p3);
+ DLIB_TEST(p1);
+ DLIB_TEST(!p3);
+ DLIB_TEST((*p1).num == 1);
+ DLIB_TEST(p1->num == 1);
+ DLIB_TEST_MSG(counter == 2,counter);
+
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(p2.unique());
+ DLIB_TEST(!p3.unique());
+ DLIB_TEST(!p4.unique());
+
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(p2.use_count() == 1);
+ DLIB_TEST(p3.use_count() == 0);
+ DLIB_TEST(p4.use_count() == 0);
+
+ dlib::shared_ptr<base> p11(p1);
+
+ DLIB_TEST(!p1.unique());
+ DLIB_TEST(p2.unique());
+ DLIB_TEST(!p3.unique());
+ DLIB_TEST(!p4.unique());
+
+ DLIB_TEST(p1.use_count() == 2);
+ DLIB_TEST(p2.use_count() == 1);
+ DLIB_TEST(p3.use_count() == 0);
+ DLIB_TEST(p4.use_count() == 0);
+
+ dlib::shared_ptr<base> p22(p2);
+
+ DLIB_TEST(!p1.unique());
+ DLIB_TEST(!p2.unique());
+ DLIB_TEST(!p3.unique());
+ DLIB_TEST(!p4.unique());
+
+ DLIB_TEST(p1.use_count() == 2);
+ DLIB_TEST(p2.use_count() == 2);
+ DLIB_TEST(p3.use_count() == 0);
+ DLIB_TEST(p4.use_count() == 0);
+
+ DLIB_TEST(p11.get() == realp1);
+ DLIB_TEST(p11 == p1);
+ DLIB_TEST(p22 == p2);
+ DLIB_TEST(p3 == p4);
+ DLIB_TEST(p11 != p22);
+ DLIB_TEST(p1 != p2);
+ DLIB_TEST(p3 != p1);
+ DLIB_TEST(p3 != p11);
+ DLIB_TEST(p3 != p2);
+
+
+ p1 = p1 = p1;
+ DLIB_TEST(p1.use_count() == 2);
+ DLIB_TEST(p1->num == 1);
+ DLIB_TEST(p11.use_count() == 2);
+ p1.reset();
+ DLIB_TEST(p1.get() == 0);
+ DLIB_TEST(p1.use_count() == 0);
+ DLIB_TEST(p1.unique() == false);
+ DLIB_TEST(p11.use_count() == 1);
+ p11 = p2;
+ DLIB_TEST(p1.use_count() == 0);
+ DLIB_TEST(p1.unique() == false);
+ DLIB_TEST(p11.use_count() == 3);
+ DLIB_TEST(p11.unique() == false);
+
+ // now p11, p2, and p22 all reference the same thing and the rest are null
+ DLIB_TEST_MSG((p11 < p2) == false,"");
+ DLIB_TEST_MSG((p2 < p11) == false,"");
+
+ DLIB_TEST(get_deleter<D>(p4) == 0);
+ p4 = p2;
+ DLIB_TEST(get_deleter<D>(p4) != 0);
+ DLIB_TEST(get_deleter<D>(p4) == get_deleter<D>(p2));
+ DLIB_TEST(get_deleter<D>(p4) == get_deleter<D>(p11));
+ DLIB_TEST(get_deleter<int>(p4) == 0);
+
+ realp1 = new derived;
+ p1.reset(realp1, &deleter_base);
+ DLIB_TEST(p1.get() == realp1);
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(*get_deleter<Db>(p1) == &deleter_base);
+ DLIB_TEST(p1 != p4);
+ p4 = dynamic_pointer_cast<derived>(p1);
+ DLIB_TEST(!p1.unique());
+ DLIB_TEST(p1.use_count() == 2);
+ DLIB_TEST(p1 == p4);
+
+ realp1 = new derived;
+ p1.reset(realp1);
+ DLIB_TEST(p1.get() == realp1);
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(get_deleter<D>(p1) == 0);
+
+
+ auto_ptr<derived> ap1(new derived);
+ auto_ptr<derived> ap2(new derived);
+ ap1->num = 35;
+ ap2->num = 36;
+
+ DLIB_TEST(ap1.get() != 0);
+ DLIB_TEST(ap2.get() != 0);
+ p1 = ap2;
+ p2 = ap1;
+
+ DLIB_TEST(ap1.get() == 0);
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(ap2.get() == 0);
+ DLIB_TEST(p2.unique());
+ DLIB_TEST(p2.use_count() == 1);
+ DLIB_TEST(p1->num == 36);
+ DLIB_TEST(p2->num == 35);
+
+ }
+
+ DLIB_TEST_MSG(counter == 0,counter);
+ DLIB_TEST_MSG(deleter_called == 2,counter);
+
+ dlib::weak_ptr<base> wp4;
+ {
+ dlib::shared_ptr<derived> p1(new derived, &deleter_base);
+ dlib::shared_ptr<derived> p2;
+ dlib::shared_ptr<base> p3;
+
+ dlib::weak_ptr<derived> wp1;
+ dlib::weak_ptr<base> wp2;
+ dlib::weak_ptr<base> wp3;
+
+ dlib::weak_ptr<derived> wp1c(p1);
+ dlib::weak_ptr<base> wp2c(p1);
+ dlib::weak_ptr<base> wp3c(p2);
+
+ DLIB_TEST(wp1c.use_count() == 1);
+ DLIB_TEST(wp1c.lock() == p1);
+ DLIB_TEST(wp1c.expired() == false);
+
+ DLIB_TEST(wp2c.use_count() == 1);
+ DLIB_TEST(wp2c.lock() == p1);
+ DLIB_TEST(wp2c.expired() == false);
+
+ DLIB_TEST(wp3c.use_count() == 0);
+ DLIB_TEST(wp3c.lock() == dlib::shared_ptr<base>());
+ DLIB_TEST(wp3c.expired() == true);
+
+ DLIB_TEST(wp2.use_count() == 0);
+ DLIB_TEST(wp2.expired() == true);
+ DLIB_TEST(wp2.lock().use_count() == 0);
+ DLIB_TEST(wp2.lock().unique() == false);
+
+ wp1 = p1;
+ wp2 = wp1;
+ wp3 = p1;
+
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(wp1.use_count() == 1);
+ DLIB_TEST(wp2.use_count() == 1);
+ DLIB_TEST(wp3.use_count() == 1);
+ DLIB_TEST(wp1.expired() == false);
+ DLIB_TEST(wp2.expired() == false);
+ DLIB_TEST(wp3.expired() == false);
+ DLIB_TEST(wp1.lock() == p1);
+ DLIB_TEST(wp2.lock() == p1);
+ DLIB_TEST(wp3.lock() == p1);
+
+ wp3.reset();
+
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(p1.unique());
+ DLIB_TEST(wp1.use_count() == 1);
+ DLIB_TEST(wp2.use_count() == 1);
+ DLIB_TEST(wp3.use_count() == 0);
+ DLIB_TEST(wp1.expired() == false);
+ DLIB_TEST(wp2.expired() == false);
+ DLIB_TEST(wp3.expired() == true);
+ DLIB_TEST(wp1.lock() == p1);
+ DLIB_TEST(wp2.lock() == p1);
+ DLIB_TEST(wp3.lock() == dlib::shared_ptr<base>());
+
+
+ p1.reset();
+
+ DLIB_TEST(p1.use_count() == 0);
+ DLIB_TEST(p1.unique() == false);
+ DLIB_TEST(wp1.use_count() == 0);
+ DLIB_TEST(wp2.use_count() == 0);
+ DLIB_TEST(wp3.use_count() == 0);
+ DLIB_TEST(wp1.expired() == true);
+ DLIB_TEST(wp2.expired() == true);
+ DLIB_TEST(wp3.expired() == true);
+ DLIB_TEST(wp1.lock() == dlib::shared_ptr<base>());
+ DLIB_TEST(wp2.lock() == dlib::shared_ptr<base>());
+ DLIB_TEST(wp3.lock() == dlib::shared_ptr<base>());
+
+ p1.reset(new derived);
+
+ DLIB_TEST(p1.use_count() == 1);
+ DLIB_TEST(p1.unique() == true);
+ DLIB_TEST(wp1.use_count() == 0);
+ DLIB_TEST(wp2.use_count() == 0);
+ DLIB_TEST(wp3.use_count() == 0);
+ DLIB_TEST(wp1.expired() == true);
+ DLIB_TEST(wp2.expired() == true);
+ DLIB_TEST(wp3.expired() == true);
+ DLIB_TEST(wp1.lock() == dlib::shared_ptr<base>());
+ DLIB_TEST(wp2.lock() == dlib::shared_ptr<base>());
+ DLIB_TEST(wp3.lock() == dlib::shared_ptr<base>());
+
+ DLIB_TEST(wp4.expired() == true);
+ DLIB_TEST(wp4.lock() == dlib::shared_ptr<base>());
+ wp4 = p1;
+ p3 = p1;
+ DLIB_TEST(wp4.expired() == false);
+ DLIB_TEST(wp4.lock() == p3);
+
+
+ bool ok = false;
+ try {
+ dlib::shared_ptr<base> bad_ptr(wp1);
+ } catch (dlib::bad_weak_ptr&)
+ {
+ ok = true;
+ }
+ DLIB_TEST(ok);
+ }
+ DLIB_TEST(wp4.expired() == true);
+ DLIB_TEST(wp4.lock() == dlib::shared_ptr<base>());
+
+
+ DLIB_TEST_MSG(counter == 0,counter);
+ DLIB_TEST_MSG(deleter_called == 3,counter);
+
+ {
+ scoped_ptr<int[]> a(new int[10]);
+
+ {
+ used_array_delete = false;
+ scoped_ptr<int[],test_deleter<int[]> > b(new int[10]);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ a[i] = i;
+ b[i] = i;
+ }
+ }
+ DLIB_TEST(used_array_delete == true);
+
+
+ {
+ used_array_delete = true;
+ scoped_ptr<int,test_deleter<int> > c(new int);
+ }
+ DLIB_TEST(used_array_delete == false);
+
+ scoped_ptr<const int[]> const_a(new int[10]);
+
+ }
+
+ }
+
+
+
+ class smart_pointers_tester : public tester
+ {
+ public:
+ smart_pointers_tester (
+ ) :
+ tester ("test_smart_pointers",
+ "Runs tests on the smart pointers.")
+ {}
+
+ void perform_test (
+ )
+ {
+ smart_pointers_test();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/sockets.cpp b/ml/dlib/dlib/test/sockets.cpp
new file mode 100644
index 000000000..920fa9402
--- /dev/null
+++ b/ml/dlib/dlib/test/sockets.cpp
@@ -0,0 +1,247 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <cstdlib>
+#include <ctime>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include <dlib/sockets.h>
+#include <dlib/server.h>
+#include <dlib/misc_api.h>
+
+#include "tester.h"
+
+namespace {
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ dlib::mutex gm;
+ dlib::signaler gs(gm);
+ const char magic_num = 42;
+ const int min_bytes_sent = 10000;
+ int assigned_port;
+
+ logger dlog("test.sockets");
+
+// ----------------------------------------------------------------------------------------
+
+ class serv : public server::kernel_1a_c
+ {
+ public:
+ serv (
+ ) :
+ error_occurred (false),
+ got_connections(false)
+ {}
+
+ void on_listening_port_assigned (
+ )
+ {
+ auto_mutex M(gm);
+ assigned_port = get_listening_port();
+ gs.broadcast();
+ }
+
+
+ void on_connect (
+ connection& con
+ )
+ {
+ dlog << LINFO << "in serv::on_connect(): got new connection";
+ int status;
+ int count = 0;
+ char buf[100];
+ while ((status = con.read(buf,sizeof(buf))) > 0)
+ {
+ for (int i = 0; i < status; ++i)
+ {
+ if (buf[i] != magic_num)
+ {
+ tag = 4.0;
+ error_occurred = true;
+ }
+ }
+ count += status;
+ }
+ if (count != min_bytes_sent)
+ {
+ tag = 5.0;
+ error_occurred = true;
+ }
+ got_connections = true;
+ dlog << LINFO << "in serv::on_connect(): on_connect ending";
+ }
+
+ bool error_occurred;
+ bool got_connections;
+ double tag;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_container : public multithreaded_object
+ {
+ public:
+
+ serv& srv;
+
+ thread_container (
+ serv& srv_
+ ) : srv(srv_)
+ {
+ for (int i = 0; i < 10; ++i)
+ register_thread(*this, &thread_container::thread_proc);
+
+ // start up the threads
+ start();
+ }
+
+ ~thread_container ()
+ {
+ // wait for all threads to terminate
+ wait();
+ }
+
+ void thread_proc (
+ )
+ {
+ try
+ {
+ dlog << LTRACE << "enter thread";
+ {
+ auto_mutex M(gm);
+ while (assigned_port == 0)
+ gs.wait();
+ }
+
+ int status;
+ std::unique_ptr<connection> con;
+ string hostname;
+ string ip;
+ status = get_local_hostname(hostname);
+ if (status)
+ {
+ srv.tag = 1.0;
+ srv.error_occurred = true;
+ srv.clear();
+ dlog << LERROR << "leaving thread, line: " << __LINE__;
+ dlog << LERROR << "get_local_hostname() failed";
+ return;
+ }
+
+ status = hostname_to_ip(hostname,ip);
+ if (status)
+ {
+ srv.tag = 2.0;
+ srv.error_occurred = true;
+ srv.clear();
+ dlog << LERROR << "leaving thread, line: " << __LINE__;
+ dlog << LERROR << "hostname_to_ip() failed";
+ return;
+ }
+
+ dlog << LTRACE << "try to connect to the server at port " << srv.get_listening_port();
+ status = create_connection(con,srv.get_listening_port(),ip);
+ if (status)
+ {
+ srv.tag = 3.0;
+ srv.error_occurred = true;
+ srv.clear();
+ dlog << LERROR << "leaving thread, line: " << __LINE__;
+ dlog << LERROR << "create_connection() failed";
+ return;
+ }
+
+ dlog << LTRACE << "sending magic_num to server";
+ int i;
+ for (i = 0; i < min_bytes_sent; ++i)
+ {
+ con->write(&magic_num,1);
+ }
+
+ dlog << LTRACE << "shutting down connection to server";
+ close_gracefully(con);
+ dlog << LTRACE << "finished calling close_gracefully() on the connection";
+ }
+ catch (exception& e)
+ {
+ srv.error_occurred = true;
+ dlog << LERROR << "exception thrown in thread_proc(): " << e.what();
+ cout << "exception thrown in thread_proc(): " << e.what();
+ }
+ dlog << LTRACE << "exit thread";
+ }
+ };
+
+ void run_server(serv* srv)
+ {
+ dlog << LTRACE << "calling srv.start()";
+ srv->start();
+ dlog << LTRACE << "srv.start() just ended.";
+ }
+
+ void sockets_test (
+ )
+ /*!
+ requires
+ - sockets is an implementation of sockets/sockets_kernel_abstract.h
+ is instantiated with int
+ ensures
+ - runs tests on sockets for compliance with the specs
+ !*/
+ {
+
+ dlog << LTRACE << "starting test";
+ serv srv;
+
+ assigned_port = 0;
+
+
+ dlog << LTRACE << "spawning threads";
+ thread_container stuff(srv);
+
+
+
+ thread_function thread2(run_server, &srv);
+
+ // wait until all the sending threads have ended
+ stuff.wait();
+
+ if (srv.error_occurred)
+ {
+ dlog << LDEBUG << "tag: " << srv.tag;
+ }
+
+ srv.clear();
+
+ dlog << LTRACE << "ending successful test";
+ DLIB_TEST( !srv.error_occurred);
+ DLIB_TEST( srv.got_connections);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ class sockets_tester : public tester
+ {
+ public:
+ sockets_tester (
+ ) :
+ tester ("test_sockets",
+ "Runs tests on the sockets component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ sockets_test();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/sockets2.cpp b/ml/dlib/dlib/test/sockets2.cpp
new file mode 100644
index 000000000..3521e751d
--- /dev/null
+++ b/ml/dlib/dlib/test/sockets2.cpp
@@ -0,0 +1,204 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <algorithm>
+#include <memory>
+
+#include "tester.h"
+#include <dlib/sockets.h>
+#include <dlib/threads.h>
+#include <dlib/array.h>
+
+// This is called an unnamed-namespace and it has the effect of making everything
+// inside this file "private" so that everything you declare will have static linkage.
+// Thus we won't have any multiply defined symbol errors coming out of the linker when
+// we try to compile the test suite.
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ // Declare the logger we will use in this test. The name of the logger
+ // should start with "test."
+ dlib::logger dlog("test.sockets2");
+
+
+ class sockets2_tester : public tester, private multithreaded_object
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+
+ short port_num;
+ string data_to_send;
+
+ bool test_failed;
+
+ void write_thread (
+ )
+ {
+ try
+ {
+ std::unique_ptr<connection> con(connect("127.0.0.1", port_num));
+
+ // Send a copy of the data down the connection so we can test our the read() function
+ // that uses timeouts in the main thread.
+ if (con->write(data_to_send.data(), data_to_send.size()) != (int)data_to_send.size())
+ {
+ test_failed = true;
+ dlog << LERROR << "failed to send all the data down the connection";
+ }
+
+ close_gracefully(con,300000);
+ }
+ catch (exception& e)
+ {
+ test_failed = true;
+ dlog << LERROR << e.what();
+ }
+ }
+
+ void no_write_thread (
+ )
+ {
+ try
+ {
+ std::unique_ptr<connection> con(connect("127.0.0.1", port_num));
+
+ // just do nothing until the connection closes
+ char ch;
+ con->read(&ch, 1);
+ dlog << LDEBUG << "silent connection finally closing";
+ }
+ catch (exception& e)
+ {
+ test_failed = true;
+ dlog << LERROR << e.what();
+ }
+ }
+
+ public:
+ sockets2_tester (
+ ) :
+ tester (
+ "test_sockets2", // the command line argument name for this test
+ "Run sockets2 tests.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ register_thread(*this, &sockets2_tester::write_thread);
+ register_thread(*this, &sockets2_tester::write_thread);
+ register_thread(*this, &sockets2_tester::write_thread);
+ register_thread(*this, &sockets2_tester::write_thread);
+ register_thread(*this, &sockets2_tester::write_thread);
+ register_thread(*this, &sockets2_tester::no_write_thread);
+ }
+
+ void perform_test (
+ )
+ {
+ run_tests(0);
+ run_tests(40);
+ }
+
+ void run_tests (
+ unsigned long timeout_to_use
+ )
+ {
+ // make sure there aren't any threads running
+ wait();
+
+ port_num = 5000;
+ test_failed = false;
+
+ print_spinner();
+ data_to_send = "oi 2m3ormao2m fo2im3fo23mi o2mi3 foa2m3fao23ifm2o3fmia23oima23iom3giugbiua";
+ // make the block of data much larger
+ for (int i = 0; i < 11; ++i)
+ data_to_send = data_to_send + data_to_send;
+
+ dlog << LINFO << "data block size: " << data_to_send.size();
+
+
+ std::unique_ptr<listener> list;
+ DLIB_TEST(create_listener(list, port_num, "127.0.0.1") == 0);
+ DLIB_TEST(bool(list));
+
+ // kick off the sending threads
+ start();
+
+
+ dlib::array<std::unique_ptr<connection> > cons;
+ std::vector<long> bytes_received(6,0);
+ std::unique_ptr<connection> con_temp;
+
+ // accept the 6 connections we should get
+ for (int i = 0; i < 6; ++i)
+ {
+ DLIB_TEST(list->accept(con_temp) == 0);
+ cons.push_back(con_temp);
+ print_spinner();
+ }
+
+ int finished_cons = 0;
+
+ // now receive all the bytes from the sending threads
+ while (finished_cons < 5)
+ {
+ for (unsigned long i = 0; i < cons.size(); ++i)
+ {
+ if (cons[i])
+ {
+ const int buf_size = 3000;
+ char buf[buf_size];
+
+ int status = cons[i]->read(buf, buf_size, timeout_to_use);
+
+ if (status > 0)
+ {
+ DLIB_TEST(equal(buf, buf+status, data_to_send.begin()+bytes_received[i]));
+ bytes_received[i] += status;
+ }
+ else if (status == 0)
+ {
+ // the connection is closed to kill it
+ cons[i].reset();
+ ++finished_cons;
+ }
+ }
+ }
+ print_spinner();
+ }
+
+ for (unsigned long i = 0; i < bytes_received.size(); ++i)
+ {
+ DLIB_TEST(bytes_received[i] == (long)data_to_send.size() || cons[i]);
+ }
+
+
+ dlog << LINFO << "All data received correctly";
+
+ cons.clear();
+
+
+ print_spinner();
+
+ DLIB_TEST(test_failed == false);
+
+
+ // wait for all the sending threads to terminate
+ wait();
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ sockets2_tester a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/sockstreambuf.cpp b/ml/dlib/dlib/test/sockstreambuf.cpp
new file mode 100644
index 000000000..519feb2b5
--- /dev/null
+++ b/ml/dlib/dlib/test/sockstreambuf.cpp
@@ -0,0 +1,253 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <cstdlib>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <ctime>
+#include <dlib/sockets.h>
+#include <dlib/misc_api.h>
+#include <dlib/sockstreambuf.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ dlib::mutex m;
+ dlib::signaler s(m);
+ bool thread_running;
+
+ logger dlog("test.sockstreambuf");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename ssb>
+ struct thread_proc_struct
+ {
+ static void thread_proc (
+ void* param
+ )
+ {
+
+ listener& list = *static_cast<listener*>(param);
+ connection* con;
+ list.accept(con);
+
+ ssb buf(con);
+ ostream out(&buf);
+
+
+ char ch;
+ char* bigbuf = new char[1000000];
+
+
+ for (int i = 'a'; i < 'z'; ++i)
+ {
+ ch = i;
+ out << ch << " ";
+ }
+
+ out.put('A');
+
+ for (int i = 0; i < 256; ++i)
+ {
+ ch = i;
+ out.write(&ch,1);
+ }
+
+ for (int i = -100; i < 25600; ++i)
+ {
+ out << i << " ";
+ }
+
+ out.put('A');
+
+ for (int i = -100; i < 25600; ++i)
+ {
+ out.write((char*)&i,sizeof(i));
+ }
+
+ for (int i = 0; i < 1000000; ++i)
+ {
+ bigbuf[i] = (i&0xFF);
+ }
+ out.write(bigbuf,1000000);
+
+ out.put('d');
+ out.put('a');
+ out.put('v');
+ out.put('i');
+ out.put('s');
+
+
+ string tstring = "this is a test";
+ int tint = -853;
+ unsigned int tuint = 89;
+ serialize(tstring,out);
+ serialize(tint,out);
+ serialize(tuint,out);
+
+
+ out.flush();
+
+
+ auto_mutex M(m);
+ thread_running = false;
+ s.signal();
+
+ dlib::sleep(300);
+ delete con;
+ delete &list;
+
+ delete [] bigbuf;
+ }
+ };
+
+ template <typename ssb>
+ void sockstreambuf_test (
+ )
+ /*!
+ requires
+ - ssb is an implementation of sockstreambuf/sockstreambuf_kernel_abstract.h
+ ensures
+ - runs tests on ssb for compliance with the specs
+ !*/
+ {
+ char ch;
+ vector<char> vbuf;
+ vbuf.resize(1000000);
+ char* bigbuf = &vbuf[0];
+ connection* con;
+
+ print_spinner();
+ thread_running = true;
+ listener* list;
+ if (create_listener(list,0))
+ {
+ DLIB_TEST_MSG(false, "Unable to create a listener");
+ }
+
+ create_new_thread(&thread_proc_struct<ssb>::thread_proc,list);
+
+ if (create_connection(con,list->get_listening_port(),"127.0.0.1"))
+ {
+ DLIB_TEST_MSG(false, "Unable to create a connection");
+ }
+
+ // make sure con gets deleted
+ std::unique_ptr<connection> del_con(con);
+
+ ssb buf(con);
+ istream in(&buf);
+
+
+
+ for (int i = 'a'; i < 'z'; ++i)
+ {
+ in >> ch;
+ char c = i;
+ DLIB_TEST_MSG(ch == c,"ch: " << (int)ch << " c: " << (int)c);
+ }
+
+ in.get();
+ DLIB_TEST_MSG(in.peek() == 'A', "*" << in.peek() << "*");
+ in.get();
+
+ for (int i = 0; i < 256; ++i)
+ {
+ in.read(&ch,1);
+ char c = i;
+ DLIB_TEST_MSG(ch == c,"ch: " << (int)ch << " c: " << (int)c );
+ }
+
+ for (int i = -100; i < 25600; ++i)
+ {
+ int n = 0;
+ in >> n;
+ DLIB_TEST_MSG(n == i,"n: " << n << " i:" << i);
+ }
+
+ in.get();
+ DLIB_TEST_MSG(in.peek() == 'A', "*" << in.peek() << "*");
+ in.get();
+
+ for (int i = -100; i < 25600; ++i)
+ {
+ int n;
+ in.read((char*)&n,sizeof(n));
+ DLIB_TEST_MSG(n == i,"n: " << n << " i:" << i);
+ }
+
+ in.read(bigbuf,1000000);
+ for (int i = 0; i < 1000000; ++i)
+ {
+ DLIB_TEST(bigbuf[i] == (char)(i&0xFF));
+ }
+
+ DLIB_TEST(in.get() == 'd');
+ DLIB_TEST(in.get() == 'a');
+ DLIB_TEST(in.get() == 'v');
+ DLIB_TEST(in.get() == 'i');
+
+ DLIB_TEST(in.peek() == 's');
+
+ DLIB_TEST(in.get() == 's');
+
+ in.putback('s');
+ DLIB_TEST(in.peek() == 's');
+
+ DLIB_TEST(in.get() == 's');
+
+
+ string tstring;
+ int tint;
+ unsigned int tuint;
+ deserialize(tstring,in);
+ deserialize(tint,in);
+ deserialize(tuint,in);
+
+ DLIB_TEST(tstring == "this is a test");
+ DLIB_TEST(tint == -853);
+ DLIB_TEST(tuint == 89);
+
+
+
+ auto_mutex M(m);
+ while (thread_running)
+ s.wait();
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ class sockstreambuf_tester : public tester
+ {
+ public:
+ sockstreambuf_tester (
+ ) :
+ tester ("test_sockstreambuf",
+ "Runs tests on the sockstreambuf component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing sockstreambuf";
+ sockstreambuf_test<sockstreambuf>();
+ dlog << LINFO << "testing sockstreambuf_unbuffered";
+ sockstreambuf_test<sockstreambuf_unbuffered>();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/sparse_vector.cpp b/ml/dlib/dlib/test/sparse_vector.cpp
new file mode 100644
index 000000000..97b60b7b3
--- /dev/null
+++ b/ml/dlib/dlib/test/sparse_vector.cpp
@@ -0,0 +1,301 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <dlib/sparse_vector.h>
+#include "tester.h"
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.sparse_vector");
+
+ void test_sparse_matrix_vector_multiplies()
+ {
+ dlib::rand rnd;
+
+ const long size = 30;
+
+ for (int iter = 0; iter < 10; ++iter)
+ {
+ print_spinner();
+
+ std::vector<sample_pair> edges;
+ std::vector<ordered_sample_pair> oedges;
+ matrix<double> M(size,size);
+ M = 0;
+ for (long i = 0; i < M.size()/3; ++i)
+ {
+ const long r = rnd.get_random_32bit_number()%M.nr();
+ const long c = rnd.get_random_32bit_number()%M.nc();
+ const double d = rnd.get_random_gaussian()*10;
+ M(r,c) += d;
+ oedges.push_back(ordered_sample_pair(r,c,d));
+ }
+
+ matrix<double> SM(size,size);
+ SM = 0;
+ for (long i = 0; i < SM.size()/3; ++i)
+ {
+ const long r = rnd.get_random_32bit_number()%SM.nr();
+ const long c = rnd.get_random_32bit_number()%SM.nc();
+ const double d = rnd.get_random_gaussian()*10;
+ SM(r,c) += d;
+ if (r != c)
+ SM(c,r) += d;
+ edges.push_back(sample_pair(r,c,d));
+ }
+
+ const matrix<double> v = randm(size,1);
+
+ matrix<double> result;
+
+ sparse_matrix_vector_multiply(oedges, v, result);
+ DLIB_TEST_MSG(length(M*v - result) < 1e-12, length(M*v - result));
+
+ sparse_matrix_vector_multiply(edges, v, result);
+ DLIB_TEST_MSG(length(SM*v - result) < 1e-12, length(SM*v - result));
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_sparse_matrix_vector_multiply1()
+ {
+ print_spinner();
+ std::map<unsigned long,double> sv;
+ sv[2] = 8;
+ sv[6] = 2.3;
+
+ matrix<double,10,1> v;
+ v = 0;
+ v(2) = 8;
+ v(6) = 2.3;
+
+
+ matrix<double,0,1> r1, r2;
+
+ r1 = gaussian_randm(4,10)*v;
+ r2 = sparse_matrix_vector_multiply(gaussian_randm(4,std::numeric_limits<long>::max()),sv);
+
+ DLIB_TEST(max(abs(r1-r2)) < 1e-15);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_sparse_matrix_vector_multiply2()
+ {
+ std::vector<std::pair<unsigned long,double> > sv;
+ sv.push_back(make_pair(6, 1.42));
+ sv.push_back(make_pair(3, 5));
+
+ matrix<double,9,1> v;
+ v = 0;
+ v(3) = 5;
+ v(6) = 1.42;
+
+
+ matrix<double,0,1> r1, r2;
+
+ r1 = gaussian_randm(3,9)*v;
+ r2 = sparse_matrix_vector_multiply(gaussian_randm(3,std::numeric_limits<long>::max()),sv);
+
+ DLIB_TEST(max(abs(r1-r2)) < 1e-15);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_make_sparse_vector_inplace()
+ {
+ std::vector<std::pair<unsigned long,double> > vect;
+ vect.push_back(make_pair(4,1));
+ vect.push_back(make_pair(0,1));
+ vect.push_back(make_pair(4,1));
+ vect.push_back(make_pair(3,1));
+ vect.push_back(make_pair(8,1));
+ vect.push_back(make_pair(8,1));
+ vect.push_back(make_pair(8,1));
+ vect.push_back(make_pair(8,1));
+
+ make_sparse_vector_inplace(vect);
+
+ DLIB_TEST(vect.size() == 4);
+ DLIB_TEST(vect[0].first == 0);
+ DLIB_TEST(vect[1].first == 3);
+ DLIB_TEST(vect[2].first == 4);
+ DLIB_TEST(vect[3].first == 8);
+
+ DLIB_TEST(vect[0].second == 1);
+ DLIB_TEST(vect[1].second == 1);
+ DLIB_TEST(vect[2].second == 2);
+ DLIB_TEST(vect[3].second == 4);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class sparse_vector_tester : public tester
+ {
+ public:
+ sparse_vector_tester (
+ ) :
+ tester (
+ "test_sparse_vector", // the command line argument name for this test
+ "Run tests on the sparse_vector routines.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+ void perform_test (
+ )
+ {
+ test_make_sparse_vector_inplace();
+
+ std::map<unsigned int, double> v;
+ v[4] = 8;
+ v[2] = -4;
+ v[9] = 10;
+
+ DLIB_TEST(max(v) == 10);
+ DLIB_TEST(min(v) == -4);
+
+ v.clear();
+ v[4] = 8;
+ v[9] = 10;
+ DLIB_TEST(max(v) == 10);
+ DLIB_TEST(min(v) == 0);
+
+
+ v.clear();
+ v[4] = -9;
+ v[9] = -4;
+ DLIB_TEST(max(v) == 0);
+ DLIB_TEST(min(v) == -9);
+
+
+ {
+ matrix<double> a(2,2), b(2,2);
+ a = randm(2,2);
+ b = randm(2,2);
+
+ DLIB_TEST(equal(a-b, subtract(a,b)));
+ DLIB_TEST(equal(a+b, add(a,b)));
+ DLIB_TEST(equal(a-(b+b), subtract(a,b+b)));
+ DLIB_TEST(equal(a+b+b, add(a,b+b)));
+ }
+
+ {
+ std::map<unsigned long,double> a, b, c;
+ a[1] = 2;
+ a[3] = 5;
+
+ b[0] = 3;
+ b[1] = 1;
+
+ c = add(a,b);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(c[0] == 3);
+ DLIB_TEST(c[1] == 3);
+ DLIB_TEST(c[3] == 5);
+
+ c = subtract(a,b);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(c[0] == -3);
+ DLIB_TEST(c[1] == 1);
+ DLIB_TEST(c[3] == 5);
+
+ c = add(b,a);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(c[0] == 3);
+ DLIB_TEST(c[1] == 3);
+ DLIB_TEST(c[3] == 5);
+
+ c = subtract(b,a);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(c[0] == 3);
+ DLIB_TEST(c[1] == -1);
+ DLIB_TEST(c[3] == -5);
+
+ std::vector<std::pair<unsigned long,double> > aa, bb, cc;
+
+ aa.assign(a.begin(), a.end());
+ bb.assign(b.begin(), b.end());
+
+ cc = add(aa,bb);
+ DLIB_TEST(cc.size() == 3);
+ DLIB_TEST(cc[0].first == 0);
+ DLIB_TEST(cc[1].first == 1);
+ DLIB_TEST(cc[2].first == 3);
+ DLIB_TEST(cc[0].second == 3);
+ DLIB_TEST(cc[1].second == 3);
+ DLIB_TEST(cc[2].second == 5);
+
+ cc = subtract(aa,bb);
+ DLIB_TEST(cc.size() == 3);
+ DLIB_TEST(cc[0].first == 0);
+ DLIB_TEST(cc[1].first == 1);
+ DLIB_TEST(cc[2].first == 3);
+ DLIB_TEST(cc[0].second == -3);
+ DLIB_TEST(cc[1].second == 1);
+ DLIB_TEST(cc[2].second == 5);
+
+ cc = add(bb,aa);
+ DLIB_TEST(cc.size() == 3);
+ DLIB_TEST(cc[0].first == 0);
+ DLIB_TEST(cc[1].first == 1);
+ DLIB_TEST(cc[2].first == 3);
+ DLIB_TEST(cc[0].second == 3);
+ DLIB_TEST(cc[1].second == 3);
+ DLIB_TEST(cc[2].second == 5);
+
+ cc = subtract(bb,aa);
+ DLIB_TEST(cc.size() == 3);
+ DLIB_TEST(cc[0].first == 0);
+ DLIB_TEST(cc[1].first == 1);
+ DLIB_TEST(cc[2].first == 3);
+ DLIB_TEST(cc[0].second == 3);
+ DLIB_TEST(cc[1].second == -1);
+ DLIB_TEST(cc[2].second == -5);
+
+ }
+
+ test_sparse_matrix_vector_multiplies();
+ test_sparse_matrix_vector_multiply1();
+ test_sparse_matrix_vector_multiply2();
+
+ {
+ matrix<double,0,1> a, b;
+ a = gaussian_randm(6,1, 0);
+ b = gaussian_randm(6,1, 1);
+
+ std::vector<std::pair<unsigned long,double> > aa, bb;
+
+ assign(aa, a);
+ assign(bb, b);
+
+ // dot() does something special when the sparse vectors have entries for
+ // each dimension, which is what happens when they are copied from dense
+ // vectors. So the point of the tests in this block is to make sure dot()
+ // works right in this case.
+ DLIB_TEST(std::abs(dot(a,b) - dot(aa,bb)) < 1e-14);
+ a(3) = 0;
+ assign(aa, a);
+ DLIB_TEST(std::abs(dot(a,b) - dot(aa,bb)) < 1e-14);
+ }
+ }
+ };
+
+ sparse_vector_tester a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/stack.cpp b/ml/dlib/dlib/test/stack.cpp
new file mode 100644
index 000000000..0b92eaeb9
--- /dev/null
+++ b/ml/dlib/dlib/test/stack.cpp
@@ -0,0 +1,294 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/stack.h>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.stack");
+
+ template <
+ typename stack
+ >
+ void stack_kernel_test (
+ )
+ /*!
+ requires
+ - stack is an implementation of stack/stack_sort_abstract.h
+ stack is instantiated with int
+ ensures
+ - runs tests on stack for compliance with the specs
+ !*/
+ {
+
+
+ srand(static_cast<unsigned int>(time(0)));
+
+ print_spinner();
+
+ stack a1, a2;
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.reset();
+ a2.reset();
+
+ for (unsigned long k = 0; k < 4; ++k)
+ {
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.clear();
+ a2.clear();
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ swap(a1,a2);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.move_next() == false);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.size() == 0);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a2.size() == 0);
+
+
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start());
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.move_next() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a1.at_start() == false);
+ DLIB_TEST(a1.size() == 0);
+
+ a1.clear();
+ a2.clear();
+
+
+ for (unsigned long i = 0; i < 100; ++i)
+ {
+ int a = (int)i;
+ a1.push(a);
+ }
+
+ DLIB_TEST(a1.size() == 100);
+
+ int count = 99;
+ while (a1.move_next())
+ {
+ DLIB_TEST_MSG(a1.element() == count,a1.element() << " : " << count);
+ --count;
+ }
+
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == false);
+
+ a1.swap(a2);
+
+ count = 99;
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ DLIB_TEST(a1.current_element_valid() == false);
+ DLIB_TEST(a1.at_start() == true);
+
+ DLIB_TEST(a1.size() == 0);
+ DLIB_TEST(a2.size() == 100);
+ DLIB_TEST(a2.current() == 99);
+
+ a2.reset();
+ while (a2.move_next())
+ {
+ DLIB_TEST(a2.element() == count--);
+ }
+
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == false);
+ int b = 4;
+ a2.push(b);
+ DLIB_TEST(a2.current_element_valid() == false);
+ DLIB_TEST(a2.at_start() == true);
+
+ DLIB_TEST(a2.current() == 4);
+ int c = 0;
+ a2.pop(c);
+ DLIB_TEST(c == 4);
+
+ // serialize the state of a2, then clear a2, then
+ // load the state back into a2.
+ ostringstream sout;
+ serialize(a2,sout);
+ DLIB_TEST(a2.at_start() == true);
+ istringstream sin(sout.str());
+ a2.clear();
+ deserialize(a2,sin);
+
+
+ count = 99;
+ while (a2.size())
+ {
+ int a = 0;
+ DLIB_TEST(a2.current() == count);
+ DLIB_TEST(const_cast<const stack&>(a2).current() == count);
+ a2.pop(a);
+ DLIB_TEST(a == count--);
+ }
+
+
+
+
+
+
+ a1.clear();
+ a2.clear();
+ }
+
+
+ {
+ a1.clear();
+ remover<int>& go = a1;
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 3;
+ a1.push(a);
+ }
+ DLIB_TEST(go.size() == 100);
+ for (int i = 0; i < 100; ++i)
+ {
+ int a = 9;
+ a1.remove_any(a);
+ DLIB_TEST(a == 3);
+ }
+ DLIB_TEST(go.size() == 0);
+ }
+
+ }
+
+
+
+
+ class stack_tester : public tester
+ {
+ public:
+ stack_tester (
+ ) :
+ tester ("test_stack",
+ "Runs tests on the stack component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ stack_kernel_test<stack<int>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ stack_kernel_test<stack<int>::kernel_1a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/static_map.cpp b/ml/dlib/dlib/test/static_map.cpp
new file mode 100644
index 000000000..931ae1fae
--- /dev/null
+++ b/ml/dlib/dlib/test/static_map.cpp
@@ -0,0 +1,323 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/hash_table.h>
+#include <dlib/binary_search_tree.h>
+
+#include <dlib/static_map.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.static_map");
+
+ template <
+ typename map
+ >
+ void static_map_kernel_test (
+ )
+ /*!
+ requires
+ - map is an implementation of static_map/static_map_kernel_abstract.h and
+ is instantiated to map int to int
+ ensures
+ - runs tests on map for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+ srand(static_cast<unsigned int>(time(0)));
+
+ typedef binary_search_tree<int,int>::kernel_2a_c bst;
+ typedef hash_table<int,int>::kernel_1a_c ht;
+
+ const unsigned long table_4_max_size = 100;
+ const unsigned long tree_max_size = 50000;
+ ht table_4(4);
+ ht table_8(8);
+ bst tree;
+
+ ht table_4b(4);
+ ht table_8b(8);
+ bst treeb;
+
+
+ // just do the following to make sure operator[] doesn't hang
+ // under some instances
+ {
+ int g = 1, h = 1;
+ treeb.add(g,h);
+ map test;
+ map test2;
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.move_next() == false);
+ DLIB_TEST(test.current_element_valid() == false);
+ DLIB_TEST(test.at_start() == false);
+
+ swap(test,test2);
+ DLIB_TEST(test2.at_start() == false);
+ DLIB_TEST(test.at_start() == true);
+
+ swap(test,test2);
+ DLIB_TEST(test.at_start() == false);
+
+
+ DLIB_TEST(test.size() == 0);
+ DLIB_TEST(test[1] == 0);
+ DLIB_TEST(test[2] == 0);
+ DLIB_TEST(test[3] == 0);
+ DLIB_TEST(test[0] == 0);
+
+ test.load(treeb);
+ DLIB_TEST(test.at_start());
+ DLIB_TEST(test[1] != 0);
+ DLIB_TEST(test[2] == 0);
+ DLIB_TEST(test[3] == 0);
+ DLIB_TEST(test[0] == 0);
+
+ test2.clear();
+ swap(test2,test);
+ DLIB_TEST(test2[1] != 0);
+ DLIB_TEST(test2[2] == 0);
+ DLIB_TEST(test2[3] == 0);
+ DLIB_TEST(test2[0] == 0);
+ DLIB_TEST(test[1] == 0);
+ DLIB_TEST(test[2] == 0);
+ DLIB_TEST(test[3] == 0);
+ DLIB_TEST(test[0] == 0);
+
+
+ DLIB_TEST(treeb.size() == 0);
+ treeb.clear();
+ }
+
+
+ for (unsigned long i = 0; i < table_4_max_size; ++i)
+ {
+ int a = ::rand()&0xFF;
+ int b = a + 1;
+ int ab = a;
+ int bb = b;
+ table_4.add(a,b);
+ table_4b.add(ab,bb);
+ }
+
+ for (unsigned long i = 0; i < table_4_max_size; ++i)
+ {
+ int a = ::rand()&0xF;
+ int b = a + 1;
+ int ab = a;
+ int bb = b;
+ table_8.add(a,b);
+ table_8b.add(ab,bb);
+ }
+
+ for (unsigned long i = 0; i < tree_max_size; ++i)
+ {
+ int a = ::rand()&0xFFF;
+ int b = a + 1;
+ int ab = a;
+ int bb = b;
+ tree.add(a,b);
+ treeb.add(ab,bb);
+ }
+
+ map m_4;
+ m_4.load(table_4);
+ map m_8;
+ m_8.load(table_8);
+ map m_t;
+ m_t.load(tree);
+ map e;
+ e.load(table_4);
+
+ DLIB_TEST(e.size() == 0);
+ DLIB_TEST(e.at_start() == true);
+ DLIB_TEST(e.current_element_valid() == false);
+ DLIB_TEST(e.move_next() == false);
+ DLIB_TEST(e.at_start() == false);
+ DLIB_TEST(e.current_element_valid() == false);
+
+ DLIB_TEST(m_4.size() == table_4b.size());
+ DLIB_TEST(m_8.size() == table_8b.size());
+ DLIB_TEST(m_t.size() == treeb.size());
+
+ DLIB_TEST(m_4.at_start() == true);
+ DLIB_TEST(m_8.at_start() == true);
+ DLIB_TEST(m_t.at_start() == true);
+ DLIB_TEST(m_4.current_element_valid() == false);
+ DLIB_TEST(m_8.current_element_valid() == false);
+ DLIB_TEST(m_t.current_element_valid() == false);
+
+
+ DLIB_TEST(m_4.move_next() == true);
+ DLIB_TEST(m_4.at_start() == false);
+ DLIB_TEST(m_4.current_element_valid() == true);
+ DLIB_TEST(m_8.move_next() == true);
+ DLIB_TEST(m_8.at_start() == false);
+ DLIB_TEST(m_8.current_element_valid() == true);
+ DLIB_TEST(m_t.move_next() == true);
+ DLIB_TEST(m_t.at_start() == false);
+ DLIB_TEST(m_t.current_element_valid() == true);
+
+ m_4.reset();
+ m_8.reset();
+ m_t.reset();
+
+ while (m_4.move_next())
+ {
+ DLIB_TEST( table_4b[m_4.element().key()] != 0);
+ DLIB_TEST( *table_4b[m_4.element().key()] == m_4.element().value());
+ }
+
+ // serialize the state of m_4, then clear m_4, then
+ // load the state back into m_4.
+ ostringstream sout;
+ serialize(m_4,sout);
+ DLIB_TEST(m_4.at_start() == true);
+ istringstream sin(sout.str());
+ m_4.clear();
+ deserialize(m_4,sin);
+ DLIB_TEST(m_4.at_start() == true);
+
+
+
+ while (table_4b.move_next())
+ {
+ DLIB_TEST( m_4[table_4b.element().key()] != 0);
+ DLIB_TEST( *m_4[table_4b.element().key()] == table_4b.element().value());
+ }
+
+ // serialize the state of m_8, then clear m_8, then
+ // load the state back into m_8.
+ sout.str("");
+ serialize(m_8,sout);
+ DLIB_TEST(m_8.at_start() == true);
+ sin.str(sout.str());
+ m_8.clear();
+ deserialize(m_8,sin);
+ DLIB_TEST(m_8.at_start() == true);
+
+ while (m_8.move_next())
+ {
+ DLIB_TEST( table_8b[m_8.element().key()] != 0);
+ DLIB_TEST( *table_8b[m_8.element().key()] == m_8.element().value());
+ }
+
+ while (table_8b.move_next())
+ {
+ DLIB_TEST( m_8[table_8b.element().key()] != 0);
+ DLIB_TEST( *m_8[table_8b.element().key()] == table_8b.element().value());
+ }
+
+
+ while (m_t.move_next())
+ {
+ DLIB_TEST( treeb[m_t.element().key()] != 0);
+ DLIB_TEST( *treeb[m_t.element().key()] == m_t.element().value());
+ }
+
+ // make sure operator[] doesn't hang
+ for (int l = 1; l < 10000; ++l)
+ {
+ DLIB_TEST(m_t[l+0xFFF] == 0);
+ }
+
+ while (treeb.move_next())
+ {
+ DLIB_TEST( m_t[treeb.element().key()] != 0);
+ DLIB_TEST( *m_t[treeb.element().key()] == treeb.element().value());
+ }
+
+
+
+ m_4.reset();
+ m_8.reset();
+ m_t.reset();
+
+ int last = 0;
+ while (m_4.move_next())
+ {
+ DLIB_TEST(last <= m_4.element().key());
+ DLIB_TEST(m_4.element().key() + 1 == m_4.element().value());
+ last = m_4.element().key();
+ }
+
+ last = 0;
+ while (m_8.move_next())
+ {
+ DLIB_TEST(last <= m_8.element().key());
+ DLIB_TEST(m_8.element().key() + 1 == m_8.element().value());
+ last = m_8.element().key();
+ }
+
+ last = 0;
+ while (m_t.move_next())
+ {
+ DLIB_TEST(last <= m_t.element().key());
+ DLIB_TEST(m_t.element().key() + 1 == m_t.element().value());
+ last = m_t.element().key();
+ }
+
+
+
+
+
+
+ // this is just to test swap
+ m_4.swap(m_8);
+ m_4.reset();
+ table_4b.reset();
+ while (m_8.move_next())
+ {
+ DLIB_TEST( table_4b[m_8.element().key()] != 0);
+ DLIB_TEST( *table_4b[m_8.element().key()] == m_8.element().value());
+ }
+
+ while (table_4b.move_next())
+ {
+ DLIB_TEST( m_8[table_4b.element().key()] != 0);
+ DLIB_TEST( *m_8[table_4b.element().key()] == table_4b.element().value());
+ }
+
+ }
+
+
+
+
+
+ class static_map_tester : public tester
+ {
+ public:
+ static_map_tester (
+ ) :
+ tester ("test_static_map",
+ "Runs tests on the static_map component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ static_map_kernel_test<static_map<int,int>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ static_map_kernel_test<static_map<int,int>::kernel_1a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/static_set.cpp b/ml/dlib/dlib/test/static_set.cpp
new file mode 100644
index 000000000..0ad864e4a
--- /dev/null
+++ b/ml/dlib/dlib/test/static_set.cpp
@@ -0,0 +1,206 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/queue.h>
+#include <dlib/static_set.h>
+#include <dlib/set.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.static_set");
+
+ template <
+ typename set
+ >
+ void static_set_kernel_test (
+ )
+ /*!
+ requires
+ - set is an implementation of static_set/static_set_kernel_abstract.h and
+ is instantiated to hold ints
+ ensures
+ - runs tests on set for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ srand(static_cast<unsigned int>(time(0)));
+
+ typedef queue<int>::kernel_2a_c queue_of_int;
+ typedef dlib::set<int>::kernel_1a_c set_of_int;
+
+ queue_of_int q, qb, qc;
+ set_of_int ds;
+
+ set S;
+ S.load(ds);
+
+ for (int k = 1; k < 1000; ++k)
+ {
+ q.clear();
+ qb.clear();
+ qc.clear();
+ unsigned long num = k;
+ for (unsigned long i = 0; i < num; ++i)
+ {
+ int a = ::rand()&0xFF;
+ int b = a;
+ int c = a;
+ q.enqueue(a);
+ qb.enqueue(b);
+ qc.enqueue(c);
+ }
+
+
+
+ set s;
+
+ DLIB_TEST(s.size() == 0);
+ DLIB_TEST(s.at_start());
+ DLIB_TEST(s.current_element_valid() == false);
+ DLIB_TEST(s.move_next() == false);
+ DLIB_TEST(s.current_element_valid() == false);
+ DLIB_TEST(s.at_start() == false);
+
+ s.load(q);
+ DLIB_TEST(s.at_start());
+ set se;
+ se.load(q);
+
+ DLIB_TEST(se.size() == 0);
+ DLIB_TEST(se.at_start() == true);
+ DLIB_TEST(se.current_element_valid() == false);
+ DLIB_TEST(se.move_next() == false);
+ DLIB_TEST(se.at_start() == false);
+ DLIB_TEST(se.current_element_valid() == false);
+
+
+ DLIB_TEST(s.size() == qb.size());
+ DLIB_TEST(s.at_start() == true);
+ DLIB_TEST(s.current_element_valid() == false);
+ DLIB_TEST(s.move_next() == true);
+ DLIB_TEST(s.at_start() == false);
+ DLIB_TEST(s.current_element_valid() == true);
+ s.reset();
+ se.reset();
+
+ swap(se,s);
+
+ DLIB_TEST(s.size() == 0);
+ DLIB_TEST(s.at_start() == true);
+ DLIB_TEST(s.current_element_valid() == false);
+ DLIB_TEST(s.move_next() == false);
+ DLIB_TEST(s.at_start() == false);
+ DLIB_TEST(s.current_element_valid() == false);
+
+ DLIB_TEST(se.size() == qb.size());
+ DLIB_TEST(se.at_start() == true);
+ DLIB_TEST(se.current_element_valid() == false);
+ DLIB_TEST(se.move_next() == true);
+ DLIB_TEST(se.at_start() == false);
+ DLIB_TEST(se.current_element_valid() == true);
+ s.reset();
+ se.reset();
+
+ swap(se,s);
+
+
+
+ int last = 0;
+ while (s.move_next())
+ {
+ DLIB_TEST(last <= s.element());
+ last = s.element();
+ }
+
+
+
+ while (qb.move_next())
+ {
+ int a;
+ qb.dequeue(a);
+ DLIB_TEST(s.is_member(a));
+ DLIB_TEST(!se.is_member(a));
+
+ // make sure is_member() doesn't hang
+ for (int l = 0; l < 100; ++l)
+ {
+ int a = ::rand();
+ s.is_member(a);
+ }
+ }
+
+ swap(s,se);
+
+ // serialize the state of se, then clear se, then
+ // load the state back into se.
+ ostringstream sout;
+ serialize(se,sout);
+ DLIB_TEST(se.at_start() == true);
+ istringstream sin(sout.str());
+ se.clear();
+ deserialize(se,sin);
+ DLIB_TEST(se.at_start() == true);
+
+
+ last = 0;
+ while (se.move_next())
+ {
+ DLIB_TEST(last <= se.element());
+ last = se.element();
+ }
+
+
+ DLIB_TEST(s.size() == 0);
+ DLIB_TEST(se.size() == qc.size());
+
+ while (qc.move_next())
+ {
+ int a;
+ qc.dequeue(a);
+ DLIB_TEST(se.is_member(a));
+ DLIB_TEST(!s.is_member(a));
+ }
+
+
+ }
+ }
+
+
+
+
+
+ class static_set_tester : public tester
+ {
+ public:
+ static_set_tester (
+ ) :
+ tester ("test_static_set",
+ "Runs tests on the static_set component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ static_set_kernel_test<static_set<int>::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ static_set_kernel_test<static_set<int>::kernel_1a_c>();
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/statistics.cpp b/ml/dlib/dlib/test/statistics.cpp
new file mode 100644
index 000000000..0394286ad
--- /dev/null
+++ b/ml/dlib/dlib/test/statistics.cpp
@@ -0,0 +1,915 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/statistics.h>
+#include <dlib/rand.h>
+#include <dlib/svm.h>
+#include <algorithm>
+#include <dlib/matrix.h>
+#include <cmath>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.statistics");
+
+
+
+ class statistics_tester : public tester
+ {
+ public:
+ statistics_tester (
+ ) :
+ tester ("test_statistics",
+ "Runs tests on the statistics component.")
+ {}
+
+ void test_random_subset_selector ()
+ {
+ random_subset_selector<double> rand_set;
+
+ for (int j = 0; j < 30; ++j)
+ {
+ print_spinner();
+
+ running_stats<double> rs, rs2;
+
+ rand_set.set_max_size(1000);
+
+ for (double i = 0; i < 100000; ++i)
+ {
+ rs.add(i);
+ rand_set.add(i);
+ }
+
+
+ for (unsigned long i = 0; i < rand_set.size(); ++i)
+ rs2.add(rand_set[i]);
+
+
+ dlog << LDEBUG << "true mean: " << rs.mean();
+ dlog << LDEBUG << "true sampled: " << rs2.mean();
+ double ratio = rs.mean()/rs2.mean();
+ DLIB_TEST_MSG(0.96 < ratio && ratio < 1.04, " ratio: " << ratio);
+ }
+
+
+ {
+ random_subset_selector<int> r1, r2;
+ r1.set_max_size(300);
+ for (int i = 0; i < 4000; ++i)
+ r1.add(i);
+
+ ostringstream sout;
+ serialize(r1, sout);
+ istringstream sin(sout.str());
+ deserialize(r2, sin);
+
+ DLIB_TEST(r1.size() == r2.size());
+ DLIB_TEST(r1.max_size() == r2.max_size());
+ DLIB_TEST(r1.next_add_accepts() == r2.next_add_accepts());
+ DLIB_TEST(std::equal(r1.begin(), r1.end(), r2.begin()));
+
+ for (int i = 0; i < 4000; ++i)
+ {
+ r1.add(i);
+ r2.add(i);
+ }
+
+ DLIB_TEST(r1.size() == r2.size());
+ DLIB_TEST(r1.max_size() == r2.max_size());
+ DLIB_TEST(r1.next_add_accepts() == r2.next_add_accepts());
+ DLIB_TEST(std::equal(r1.begin(), r1.end(), r2.begin()));
+ }
+ }
+
+ void test_random_subset_selector2 ()
+ {
+ random_subset_selector<double> rand_set;
+ DLIB_TEST(rand_set.next_add_accepts() == false);
+ DLIB_TEST(rand_set.size() == 0);
+ DLIB_TEST(rand_set.max_size() == 0);
+
+ for (int j = 0; j < 30; ++j)
+ {
+ print_spinner();
+
+ running_stats<double> rs, rs2;
+
+ rand_set.set_max_size(1000);
+ DLIB_TEST(rand_set.next_add_accepts() == true);
+
+ for (double i = 0; i < 100000; ++i)
+ {
+ rs.add(i);
+ if (rand_set.next_add_accepts())
+ rand_set.add(i);
+ else
+ rand_set.add();
+ }
+
+ DLIB_TEST(rand_set.size() == 1000);
+ DLIB_TEST(rand_set.max_size() == 1000);
+
+ for (unsigned long i = 0; i < rand_set.size(); ++i)
+ rs2.add(rand_set[i]);
+
+
+ dlog << LDEBUG << "true mean: " << rs.mean();
+ dlog << LDEBUG << "true sampled: " << rs2.mean();
+ double ratio = rs.mean()/rs2.mean();
+ DLIB_TEST_MSG(0.96 < ratio && ratio < 1.04, " ratio: " << ratio);
+ }
+ }
+
+ void test_running_cross_covariance ()
+ {
+ running_cross_covariance<matrix<double> > rcc1, rcc2;
+
+ matrix<double,0,1> xm, ym;
+ const int num = 40;
+
+ dlib::rand rnd;
+ for (int i = 0; i < num; ++i)
+ {
+ matrix<double,0,1> x = randm(4,1,rnd);
+ matrix<double,0,1> y = randm(4,1,rnd);
+
+ xm += x/num;
+ ym += y/num;
+
+ if (i < 15)
+ rcc1.add(x,y);
+ else
+ rcc2.add(x,y);
+ }
+
+ rnd.clear();
+ matrix<double> cov;
+ for (int i = 0; i < num; ++i)
+ {
+ matrix<double,0,1> x = randm(4,1,rnd);
+ matrix<double,0,1> y = randm(4,1,rnd);
+ cov += (x-xm)*trans(y-ym);
+ }
+ cov /= num-1;
+
+ running_cross_covariance<matrix<double> > rcc = rcc1 + rcc2;
+ DLIB_TEST(max(abs(rcc.covariance_xy()-cov)) < 1e-14);
+ DLIB_TEST(max(abs(rcc.mean_x()-xm)) < 1e-14);
+ DLIB_TEST(max(abs(rcc.mean_y()-ym)) < 1e-14);
+ }
+
+ std::map<unsigned long,double> dense_to_sparse (
+ const matrix<double,0,1>& x
+ )
+ {
+ std::map<unsigned long,double> temp;
+ for (long i = 0; i < x.size(); ++i)
+ temp[i] = x(i);
+ return temp;
+ }
+
+ void test_running_cross_covariance_sparse()
+ {
+ running_cross_covariance<matrix<double> > rcc1, rcc2;
+
+ running_covariance<matrix<double> > rc1, rc2;
+
+ matrix<double,0,1> xm, ym;
+ const int num = 40;
+
+ rc1.set_dimension(4);
+ rc2.set_dimension(4);
+
+ rcc1.set_dimensions(4,5);
+ rcc2.set_dimensions(4,5);
+
+ dlib::rand rnd;
+ for (int i = 0; i < num; ++i)
+ {
+ matrix<double,0,1> x = randm(4,1,rnd);
+ matrix<double,0,1> y = randm(5,1,rnd);
+
+ xm += x/num;
+ ym += y/num;
+
+ if (i < 15)
+ {
+ rcc1.add(x,dense_to_sparse(y));
+ rc1.add(x);
+ }
+ else if (i < 30)
+ {
+ rcc2.add(dense_to_sparse(x),y);
+ rc2.add(dense_to_sparse(x));
+ }
+ else
+ {
+ rcc2.add(dense_to_sparse(x),dense_to_sparse(y));
+ rc2.add(x);
+ }
+ }
+
+ rnd.clear();
+ matrix<double> cov, cov2;
+ for (int i = 0; i < num; ++i)
+ {
+ matrix<double,0,1> x = randm(4,1,rnd);
+ matrix<double,0,1> y = randm(5,1,rnd);
+ cov += (x-xm)*trans(y-ym);
+ cov2 += (x-xm)*trans(x-xm);
+ }
+ cov /= num-1;
+ cov2 /= num-1;
+
+ running_cross_covariance<matrix<double> > rcc = rcc1 + rcc2;
+ DLIB_TEST_MSG(max(abs(rcc.covariance_xy()-cov)) < 1e-14, max(abs(rcc.covariance_xy()-cov)));
+ DLIB_TEST(max(abs(rcc.mean_x()-xm)) < 1e-14);
+ DLIB_TEST(max(abs(rcc.mean_y()-ym)) < 1e-14);
+
+ running_covariance<matrix<double> > rc = rc1 + rc2;
+ DLIB_TEST(max(abs(rc.covariance()-cov2)) < 1e-14);
+ DLIB_TEST(max(abs(rc.mean()-xm)) < 1e-14);
+ }
+
+ void test_running_covariance (
+ )
+ {
+ dlib::rand rnd;
+ std::vector<matrix<double,0,1> > vects;
+
+ running_covariance<matrix<double,0,1> > cov, cov2;
+ DLIB_TEST(cov.in_vector_size() == 0);
+
+ for (unsigned long dims = 1; dims < 5; ++dims)
+ {
+ for (unsigned long samps = 2; samps < 10; ++samps)
+ {
+ vects.clear();
+ cov.clear();
+ DLIB_TEST(cov.in_vector_size() == 0);
+ for (unsigned long i = 0; i < samps; ++i)
+ {
+ vects.push_back(randm(dims,1,rnd));
+ cov.add(vects.back());
+
+ }
+ DLIB_TEST(cov.in_vector_size() == (long)dims);
+
+ DLIB_TEST(equal(mean(mat(vects)), cov.mean()));
+ DLIB_TEST_MSG(equal(covariance(mat(vects)), cov.covariance()),
+ max(abs(covariance(mat(vects)) - cov.covariance()))
+ << " dims = " << dims << " samps = " << samps
+ );
+ }
+ }
+
+ for (unsigned long dims = 1; dims < 5; ++dims)
+ {
+ for (unsigned long samps = 2; samps < 10; ++samps)
+ {
+ vects.clear();
+ cov.clear();
+ cov2.clear();
+ DLIB_TEST(cov.in_vector_size() == 0);
+ for (unsigned long i = 0; i < samps; ++i)
+ {
+ vects.push_back(randm(dims,1,rnd));
+ if ((i%2) == 0)
+ cov.add(vects.back());
+ else
+ cov2.add(vects.back());
+
+ }
+ DLIB_TEST((cov+cov2).in_vector_size() == (long)dims);
+
+ DLIB_TEST(equal(mean(mat(vects)), (cov+cov2).mean()));
+ DLIB_TEST_MSG(equal(covariance(mat(vects)), (cov+cov2).covariance()),
+ max(abs(covariance(mat(vects)) - (cov+cov2).covariance()))
+ << " dims = " << dims << " samps = " << samps
+ );
+ }
+ }
+
+ }
+
+ void test_running_stats()
+ {
+ print_spinner();
+
+ running_stats<double> rs, rs2;
+
+ running_scalar_covariance<double> rsc1, rsc2;
+ running_scalar_covariance_decayed<double> rscd1(1000000), rscd2(1000000);
+
+ for (double i = 0; i < 100; ++i)
+ {
+ rs.add(i);
+
+ rsc1.add(i,i);
+ rsc2.add(i,i);
+ rsc2.add(i,-i);
+
+ rscd1.add(i,i);
+ rscd2.add(i,i);
+ rscd2.add(i,-i);
+ }
+
+ // make sure the running_stats and running_scalar_covariance agree
+ DLIB_TEST_MSG(std::abs(rs.mean() - rsc1.mean_x()) < 1e-10, std::abs(rs.mean() - rsc1.mean_x()));
+ DLIB_TEST(std::abs(rs.mean() - rsc1.mean_y()) < 1e-10);
+ DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_x()) < 1e-10);
+ DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_y()) < 1e-10);
+ DLIB_TEST(std::abs(rs.variance() - rsc1.variance_x()) < 1e-10);
+ DLIB_TEST(std::abs(rs.variance() - rsc1.variance_y()) < 1e-10);
+ DLIB_TEST(rs.current_n() == rsc1.current_n());
+
+ DLIB_TEST(std::abs(rsc1.correlation() - 1) < 1e-10);
+ DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10);
+
+
+ DLIB_TEST_MSG(std::abs(rs.mean() - rscd1.mean_x()) < 1e-2, std::abs(rs.mean() - rscd1.mean_x()) << " " << rscd1.mean_x());
+ DLIB_TEST(std::abs(rs.mean() - rscd1.mean_y()) < 1e-2);
+ DLIB_TEST_MSG(std::abs(rs.stddev() - rscd1.stddev_x()) < 1e-2, std::abs(rs.stddev() - rscd1.stddev_x()));
+ DLIB_TEST(std::abs(rs.stddev() - rscd1.stddev_y()) < 1e-2);
+ DLIB_TEST_MSG(std::abs(rs.variance() - rscd1.variance_x()) < 1e-2, std::abs(rs.variance() - rscd1.variance_x()));
+ DLIB_TEST(std::abs(rs.variance() - rscd1.variance_y()) < 1e-2);
+ DLIB_TEST(std::abs(rscd1.correlation() - 1) < 1e-2);
+ DLIB_TEST(std::abs(rscd2.correlation() - 0) < 1e-2);
+
+
+
+ // test serialization of running_stats
+ ostringstream sout;
+ serialize(rs, sout);
+ istringstream sin(sout.str());
+ deserialize(rs2, sin);
+ // make sure the running_stats and running_scalar_covariance agree
+ DLIB_TEST_MSG(std::abs(rs2.mean() - rsc1.mean_x()) < 1e-10, std::abs(rs2.mean() - rsc1.mean_x()));
+ DLIB_TEST(std::abs(rs2.mean() - rsc1.mean_y()) < 1e-10);
+ DLIB_TEST(std::abs(rs2.stddev() - rsc1.stddev_x()) < 1e-10);
+ DLIB_TEST(std::abs(rs2.stddev() - rsc1.stddev_y()) < 1e-10);
+ DLIB_TEST(std::abs(rs2.variance() - rsc1.variance_x()) < 1e-10);
+ DLIB_TEST(std::abs(rs2.variance() - rsc1.variance_y()) < 1e-10);
+ DLIB_TEST(rs2.current_n() == rsc1.current_n());
+
+ rsc1.clear();
+ rsc1.add(1, -1);
+ rsc1.add(0, 0);
+ rsc1.add(1, -1);
+ rsc1.add(0, 0);
+ rsc1.add(1, -1);
+ rsc1.add(0, 0);
+
+ DLIB_TEST(std::abs(rsc1.covariance() - -0.3) < 1e-10);
+ }
+
+ void test_skewness_and_kurtosis_1()
+ {
+
+ dlib::rand rnum;
+ running_stats<double> rs1;
+
+ double tp = 0;
+
+ rnum.set_seed("DlibRocks");
+
+ for(int i = 0; i< 1000000; i++)
+ {
+ tp = rnum.get_random_gaussian();
+ rs1.add(tp);
+ }
+
+ // check the unbiased skewness and excess kurtosis of one million Gaussian
+ // draws are both near_vects zero.
+ DLIB_TEST(abs(rs1.skewness()) < 0.1);
+ DLIB_TEST(abs(rs1.ex_kurtosis()) < 0.1);
+ }
+
+ void test_skewness_and_kurtosis_2()
+ {
+
+ string str = "DlibRocks";
+
+ for(int j = 0; j<5 ; j++)
+ {
+ matrix<double,1,100000> dat;
+ dlib::rand rnum;
+ running_stats<double> rs1;
+
+ double tp = 0;
+ double n = 100000;
+ double xb = 0;
+
+ double sknum = 0;
+ double skdenom = 0;
+ double unbi_skew = 0;
+
+ double exkurnum = 0;
+ double exkurdenom = 0;
+ double unbi_exkur = 0;
+
+ random_shuffle(str.begin(), str.end());
+ rnum.set_seed(str);
+
+ for(int i = 0; i<n; i++)
+ {
+ tp = rnum.get_random_gaussian();
+ rs1.add(tp);
+ dat(i)=tp;
+ xb += dat(i);
+ }
+
+ xb = xb/n;
+
+ for(int i = 0; i < n; i++ )
+ {
+ sknum += pow(dat(i) - xb,3);
+ skdenom += pow(dat(i) - xb,2);
+ exkurnum += pow(dat(i) - xb,4);
+ exkurdenom += pow(dat(i)-xb,2);
+ }
+
+ sknum = sknum/n;
+ skdenom = pow(skdenom/n,1.5);
+ exkurnum = exkurnum/n;
+ exkurdenom = pow(exkurdenom/n,2);
+
+ unbi_skew = sqrt(n*(n-1))/(n-2)*sknum/skdenom;
+ unbi_exkur = (n-1)*((n+1)*(exkurnum/exkurdenom-3)+6)/((n-2)*(n-3));
+
+ dlog << LINFO << "Skew Diff: " << unbi_skew - rs1.skewness();
+ dlog << LINFO << "Kur Diff: " << unbi_exkur - rs1.ex_kurtosis();
+
+ // Test an alternative implementation of the unbiased skewness and excess
+ // kurtosis against the one in running_stats.
+ DLIB_TEST(abs(unbi_skew - rs1.skewness()) < 1e-10);
+ DLIB_TEST(abs(unbi_exkur - rs1.ex_kurtosis()) < 1e-10);
+ }
+ }
+
+ void test_randomize_samples()
+ {
+ std::vector<unsigned int> t(15),u(15),v(15);
+
+ for (unsigned long i = 0; i < t.size(); ++i)
+ {
+ t[i] = i;
+ u[i] = i+1;
+ v[i] = i+2;
+ }
+ randomize_samples(t,u,v);
+
+ DLIB_TEST(t.size() == 15);
+ DLIB_TEST(u.size() == 15);
+ DLIB_TEST(v.size() == 15);
+
+ for (unsigned long i = 0; i < t.size(); ++i)
+ {
+ const unsigned long val = t[i];
+ DLIB_TEST(u[i] == val+1);
+ DLIB_TEST(v[i] == val+2);
+ }
+ }
+ void test_randomize_samples2()
+ {
+ dlib::matrix<int,15,1> t(15),u(15),v(15);
+
+ for (long i = 0; i < t.size(); ++i)
+ {
+ t(i) = i;
+ u(i) = i+1;
+ v(i) = i+2;
+ }
+ randomize_samples(t,u,v);
+
+ DLIB_TEST(t.size() == 15);
+ DLIB_TEST(u.size() == 15);
+ DLIB_TEST(v.size() == 15);
+
+ for (long i = 0; i < t.size(); ++i)
+ {
+ const long val = t(i);
+ DLIB_TEST(u(i) == val+1);
+ DLIB_TEST(v(i) == val+2);
+ }
+ }
+
+ void another_test()
+ {
+ std::vector<double> a;
+
+ running_stats<double> rs1, rs2;
+
+ for (int i = 0; i < 10; ++i)
+ {
+ rs1.add(i);
+ a.push_back(i);
+ }
+
+ DLIB_TEST(std::abs(variance(mat(a)) - rs1.variance()) < 1e-13);
+ DLIB_TEST(std::abs(stddev(mat(a)) - rs1.stddev()) < 1e-13);
+ DLIB_TEST(std::abs(mean(mat(a)) - rs1.mean()) < 1e-13);
+
+ for (int i = 10; i < 20; ++i)
+ {
+ rs2.add(i);
+ a.push_back(i);
+ }
+
+ DLIB_TEST(std::abs(variance(mat(a)) - (rs1+rs2).variance()) < 1e-13);
+ DLIB_TEST(std::abs(mean(mat(a)) - (rs1+rs2).mean()) < 1e-13);
+ DLIB_TEST((rs1+rs2).current_n() == 20);
+
+ running_scalar_covariance<double> rc1, rc2, rc3;
+ dlib::rand rnd;
+ for (double i = 0; i < 10; ++i)
+ {
+ const double a = i + rnd.get_random_gaussian();
+ const double b = i + rnd.get_random_gaussian();
+ rc1.add(a,b);
+ rc3.add(a,b);
+ }
+ for (double i = 11; i < 20; ++i)
+ {
+ const double a = i + rnd.get_random_gaussian();
+ const double b = i + rnd.get_random_gaussian();
+ rc2.add(a,b);
+ rc3.add(a,b);
+ }
+
+ DLIB_TEST(std::abs((rc1+rc2).mean_x() - rc3.mean_x()) < 1e-13);
+ DLIB_TEST(std::abs((rc1+rc2).mean_y() - rc3.mean_y()) < 1e-13);
+ DLIB_TEST_MSG(std::abs((rc1+rc2).variance_x() - rc3.variance_x()) < 1e-13, std::abs((rc1+rc2).variance_x() - rc3.variance_x()));
+ DLIB_TEST(std::abs((rc1+rc2).variance_y() - rc3.variance_y()) < 1e-13);
+ DLIB_TEST(std::abs((rc1+rc2).covariance() - rc3.covariance()) < 1e-13);
+ DLIB_TEST((rc1+rc2).current_n() == rc3.current_n());
+
+ }
+
+ void test_average_precision()
+ {
+ std::vector<bool> items;
+ DLIB_TEST(average_precision(items) == 1);
+ DLIB_TEST(average_precision(items,1) == 0);
+
+ items.push_back(true);
+ DLIB_TEST(average_precision(items) == 1);
+ DLIB_TEST(std::abs(average_precision(items,1) - 0.5) < 1e-14);
+
+ items.push_back(true);
+ DLIB_TEST(average_precision(items) == 1);
+ DLIB_TEST(std::abs(average_precision(items,1) - 2.0/3.0) < 1e-14);
+
+ items.push_back(false);
+
+ DLIB_TEST(average_precision(items) == 1);
+ DLIB_TEST(std::abs(average_precision(items,1) - 2.0/3.0) < 1e-14);
+
+ items.push_back(true);
+
+ DLIB_TEST(std::abs(average_precision(items) - (2.0+3.0/4.0)/3.0) < 1e-14);
+
+ items.push_back(true);
+
+ DLIB_TEST(std::abs(average_precision(items) - (2.0 + 4.0/5.0 + 4.0/5.0)/4.0) < 1e-14);
+ DLIB_TEST(std::abs(average_precision(items,1) - (2.0 + 4.0/5.0 + 4.0/5.0)/5.0) < 1e-14);
+ }
+
+
+ template <typename sample_type>
+ void check_distance_metrics (
+ const std::vector<frobmetric_training_sample<sample_type> >& samples
+ )
+ {
+ running_stats<double> rs;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+ {
+ const double d1 = length_squared(samples[i].anchor_vect - samples[i].near_vects[j]);
+ for (unsigned long k = 0; k < samples[i].far_vects.size(); ++k)
+ {
+ const double d2 = length_squared(samples[i].anchor_vect - samples[i].far_vects[k]);
+ rs.add(d2-d1);
+ }
+ }
+ }
+
+ dlog << LINFO << "dist gap max: "<< rs.max();
+ dlog << LINFO << "dist gap min: "<< rs.min();
+ dlog << LINFO << "dist gap mean: "<< rs.mean();
+ dlog << LINFO << "dist gap stddev: "<< rs.stddev();
+ DLIB_TEST(rs.min() >= 0.99);
+ DLIB_TEST(rs.mean() >= 0.9999);
+ }
+
+ void test_vector_normalizer_frobmetric(dlib::rand& rnd)
+ {
+ print_spinner();
+ typedef matrix<double,0,1> sample_type;
+ vector_normalizer_frobmetric<sample_type> normalizer;
+
+ std::vector<frobmetric_training_sample<sample_type> > samples;
+ frobmetric_training_sample<sample_type> samp;
+
+ const long key = 1;
+ const long dims = 5;
+ // Lets make some two class training data. Each sample will have dims dimensions but
+ // only the one with index equal to key will be meaningful. In particular, if the key
+ // dimension is > 0 then the sample is class +1 and -1 otherwise.
+
+ long k = 0;
+ for (int i = 0; i < 50; ++i)
+ {
+ samp.clear();
+ samp.anchor_vect = gaussian_randm(dims,1,k++);
+ if (samp.anchor_vect(key) > 0)
+ samp.anchor_vect(key) = rnd.get_random_double() + 5;
+ else
+ samp.anchor_vect(key) = -(rnd.get_random_double() + 5);
+
+ matrix<double,0,1> temp;
+
+ for (int j = 0; j < 5; ++j)
+ {
+ // Don't always put an equal number of near_vects and far_vects vectors into the
+ // training samples.
+ const int numa = rnd.get_random_32bit_number()%2 + 1;
+ const int numb = rnd.get_random_32bit_number()%2 + 1;
+
+ for (int num = 0; num < numa; ++num)
+ {
+ temp = gaussian_randm(dims,1,k++); temp(key) = 0.1;
+ //temp = gaussian_randm(dims,1,k++); temp(key) = std::abs(temp(key));
+ if (samp.anchor_vect(key) > 0) samp.near_vects.push_back(temp);
+ else samp.far_vects.push_back(temp);
+ }
+
+ for (int num = 0; num < numb; ++num)
+ {
+ temp = gaussian_randm(dims,1,k++); temp(key) = -0.1;
+ //temp = gaussian_randm(dims,1,k++); temp(key) = -std::abs(temp(key));
+ if (samp.anchor_vect(key) < 0) samp.near_vects.push_back(temp);
+ else samp.far_vects.push_back(temp);
+ }
+ }
+ samples.push_back(samp);
+ }
+
+ normalizer.set_epsilon(0.0001);
+ normalizer.set_c(100);
+ normalizer.set_max_iterations(6000);
+ normalizer.train(samples);
+
+ dlog << LINFO << "learned transform: \n" << normalizer.transform();
+
+ matrix<double,0,1> total;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ samples[i].anchor_vect = normalizer(samples[i].anchor_vect);
+ total += samples[i].anchor_vect;
+ for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+ samples[i].near_vects[j] = normalizer(samples[i].near_vects[j]);
+ for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j)
+ samples[i].far_vects[j] = normalizer(samples[i].far_vects[j]);
+ }
+ total /= samples.size();
+ dlog << LINFO << "sample transformed means: "<< trans(total);
+ DLIB_TEST(length(total) < 1e-9);
+ check_distance_metrics(samples);
+
+ // make sure serialization works
+ stringstream os;
+ serialize(normalizer, os);
+ vector_normalizer_frobmetric<sample_type> normalizer2;
+ deserialize(normalizer2, os);
+ DLIB_TEST(equal(normalizer.transform(), normalizer2.transform()));
+ DLIB_TEST(equal(normalizer.transformed_means(), normalizer2.transformed_means()));
+ DLIB_TEST(normalizer.in_vector_size() == normalizer2.in_vector_size());
+ DLIB_TEST(normalizer.out_vector_size() == normalizer2.out_vector_size());
+ DLIB_TEST(normalizer.get_max_iterations() == normalizer2.get_max_iterations());
+ DLIB_TEST(std::abs(normalizer.get_c() - normalizer2.get_c()) < 1e-14);
+ DLIB_TEST(std::abs(normalizer.get_epsilon() - normalizer2.get_epsilon()) < 1e-14);
+
+ }
+
+ void prior_frobnorm_test()
+ {
+ frobmetric_training_sample<matrix<double,0,1> > sample;
+ std::vector<frobmetric_training_sample<matrix<double,0,1> > > samples;
+
+ matrix<double,3,1> x, near_, far_;
+ x = 0,0,0;
+ near_ = 1,0,0;
+ far_ = 0,1,0;
+
+ sample.anchor_vect = x;
+ sample.near_vects.push_back(near_);
+ sample.far_vects.push_back(far_);
+
+ samples.push_back(sample);
+
+ vector_normalizer_frobmetric<matrix<double,0,1> > trainer;
+ trainer.set_c(100);
+ print_spinner();
+ trainer.train(samples);
+
+ matrix<double,3,3> correct;
+ correct = 0, 0, 0,
+ 0, 1, 0,
+ 0, 0, 0;
+
+ dlog << LDEBUG << trainer.transform();
+ DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8);
+
+ trainer.set_uses_identity_matrix_prior(true);
+ print_spinner();
+ trainer.train(samples);
+ correct = 1, 0, 0,
+ 0, 2, 0,
+ 0, 0, 1;
+
+ dlog << LDEBUG << trainer.transform();
+ DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8);
+
+ }
+
+ void test_lda ()
+ {
+ // This test makes sure we pick the right direction in a simple 2D -> 1D LDA
+ typedef matrix<double,2,1> sample_type;
+
+ std::vector<unsigned long> labels;
+ std::vector<sample_type> samples;
+ for (int i=0; i<4; i++)
+ {
+ sample_type s;
+ s(0) = i;
+ s(1) = i+1;
+ samples.push_back(s);
+ labels.push_back(1);
+
+ sample_type s1;
+ s1(0) = i+1;
+ s1(1) = i;
+ samples.push_back(s1);
+ labels.push_back(2);
+ }
+
+ matrix<double> X;
+ X.set_size(8,2);
+ for (int i=0; i<8; i++){
+ X(i,0) = samples[i](0);
+ X(i,1) = samples[i](1);
+ }
+
+ matrix<double,0,1> mean;
+
+ dlib::compute_lda_transform(X,mean,labels,1);
+
+ std::vector<double> vals1, vals2;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ double val = X*samples[i]-mean;
+ if (i%2 == 0)
+ vals1.push_back(val);
+ else
+ vals2.push_back(val);
+ dlog << LINFO << "1D LDA output: " << val;
+ }
+
+ if (vals1[0] > vals2[0])
+ swap(vals1, vals2);
+
+ const double err = equal_error_rate(vals1, vals2).first;
+ dlog << LINFO << "LDA ERR: " << err;
+ DLIB_TEST(err == 0);
+ DLIB_TEST(equal_error_rate(vals2, vals1).first == 1);
+ }
+
+ void test_running_stats_decayed()
+ {
+ print_spinner();
+ std::vector<double> tmp(300);
+ std::vector<double> tmp_var(tmp.size());
+ dlib::rand rnd;
+ const int num_rounds = 100000;
+ for (int rounds = 0; rounds < num_rounds; ++rounds)
+ {
+ running_stats_decayed<double> rs(100);
+
+ for (size_t i = 0; i < tmp.size(); ++i)
+ {
+ rs.add(rnd.get_random_gaussian() + 1);
+ tmp[i] += rs.mean();
+ if (i > 0)
+ tmp_var[i] += rs.variance();
+ }
+ }
+
+ // should print all 1s basically since the mean and variance should always be 1.
+ for (size_t i = 0; i < tmp.size(); ++i)
+ {
+ DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001);
+ if (i > 1)
+ DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01);
+ }
+ }
+
+ void test_running_scalar_covariance_decayed()
+ {
+ print_spinner();
+ std::vector<double> tmp(300);
+ std::vector<double> tmp_var(tmp.size());
+ std::vector<double> tmp_covar(tmp.size());
+ dlib::rand rnd;
+ const int num_rounds = 500000;
+ for (int rounds = 0; rounds < num_rounds; ++rounds)
+ {
+ running_scalar_covariance_decayed<double> rs(100);
+
+ for (size_t i = 0; i < tmp.size(); ++i)
+ {
+ rs.add(rnd.get_random_gaussian() + 1, rnd.get_random_gaussian() + 1);
+ tmp[i] += (rs.mean_y()+rs.mean_x())/2;
+ if (i > 0)
+ {
+ tmp_var[i] += (rs.variance_y()+rs.variance_x())/2;
+ tmp_covar[i] += rs.covariance();
+ }
+ }
+ }
+
+ // should print all 1s basically since the mean and variance should always be 1.
+ for (size_t i = 0; i < tmp.size(); ++i)
+ {
+ DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001);
+ if (i > 1)
+ {
+ DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01);
+ DLIB_TEST(std::abs(tmp_covar[i]/num_rounds) < 0.001);
+ }
+ }
+ }
+
+
+ void test_event_corr()
+ {
+ print_spinner();
+ DLIB_TEST(event_correlation(1000,1000,500,2000) == 0);
+ DLIB_TEST(std::abs(event_correlation(1000,1000,300,2000) + 164.565757010104) < 1e-11);
+ DLIB_TEST(std::abs(event_correlation(1000,1000,700,2000) - 164.565757010104) < 1e-11);
+
+ DLIB_TEST(event_correlation(10,1000,5,2000) == 0);
+ DLIB_TEST(event_correlation(1000,10,5,2000) == 0);
+ DLIB_TEST(std::abs(event_correlation(10,1000,1,2000) - event_correlation(1000,10,1,2000)) < 1e-11);
+ DLIB_TEST(std::abs(event_correlation(10,1000,9,2000) - event_correlation(1000,10,9,2000)) < 1e-11);
+
+ DLIB_TEST(std::abs(event_correlation(10,1000,1,2000) + 3.69672251700842) < 1e-11);
+ DLIB_TEST(std::abs(event_correlation(10,1000,9,2000) - 3.69672251700842) < 1e-11);
+ }
+
+ void perform_test (
+ )
+ {
+ prior_frobnorm_test();
+ dlib::rand rnd;
+ for (int i = 0; i < 5; ++i)
+ test_vector_normalizer_frobmetric(rnd);
+
+ test_random_subset_selector();
+ test_random_subset_selector2();
+ test_running_covariance();
+ test_running_cross_covariance();
+ test_running_cross_covariance_sparse();
+ test_running_stats();
+ test_skewness_and_kurtosis_1();
+ test_skewness_and_kurtosis_2();
+ test_randomize_samples();
+ test_randomize_samples2();
+ another_test();
+ test_average_precision();
+ test_lda();
+ test_event_corr();
+ test_running_stats_decayed();
+ test_running_scalar_covariance_decayed();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/std_vector_c.cpp b/ml/dlib/dlib/test/std_vector_c.cpp
new file mode 100644
index 000000000..fe7f82514
--- /dev/null
+++ b/ml/dlib/dlib/test/std_vector_c.cpp
@@ -0,0 +1,101 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/stl_checked.h>
+
+#include "tester.h"
+
+// This is called an unnamed-namespace and it has the effect of making everything inside this file "private"
+// so that everything you declare will have static linkage. Thus we won't have any multiply
+// defined symbol errors coming out of the linker when we try to compile the test suite.
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ // Declare the logger we will use in this test. The name of the tester
+ // should start with "test."
+ logger dlog("test.std_vector_c");
+
+
+ class std_vector_c_tester : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a test for the std_vector_c object. When it is constructed
+ it adds itself into the testing framework. The command line switch is
+ specified as test_std_vector_c by passing that string to the tester constructor.
+ !*/
+ public:
+ std_vector_c_tester (
+ ) :
+ tester ("test_std_vector_c",
+ "Runs tests on the std_vector_c component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ std::vector<int> c;
+ std_vector_c<int> a, b;
+ a.push_back(3);
+ a.push_back(2);
+ a.push_back(1);
+
+ DLIB_TEST(a[0] == 3);
+ DLIB_TEST(a[1] == 2);
+ DLIB_TEST(a[2] == 1);
+ c = a;
+ DLIB_TEST(c[0] == 3);
+ DLIB_TEST(c[1] == 2);
+ DLIB_TEST(c[2] == 1);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(a.size() == 3);
+ DLIB_TEST(b.size() == 0);
+
+ DLIB_TEST(a == c);
+ DLIB_TEST(!(a != c));
+ DLIB_TEST(a <= c);
+ DLIB_TEST(a >= c);
+ DLIB_TEST(!(a < c));
+ DLIB_TEST(!(a > c));
+
+ swap(b,c);
+ DLIB_TEST(b[0] == 3);
+ DLIB_TEST(b[1] == 2);
+ DLIB_TEST(b[2] == 1);
+ DLIB_TEST(c.size() == 0);
+ DLIB_TEST(b.size() == 3);
+ swap(c,b);
+ DLIB_TEST(c[0] == 3);
+ DLIB_TEST(c[1] == 2);
+ DLIB_TEST(c[2] == 1);
+ DLIB_TEST(c.size() == 3);
+ DLIB_TEST(b.size() == 0);
+ swap(a,b);
+ DLIB_TEST(b[0] == 3);
+ DLIB_TEST(b[1] == 2);
+ DLIB_TEST(b[2] == 1);
+ DLIB_TEST(b.size() == 3);
+ DLIB_TEST(a.size() == 0);
+
+
+ swap(b,c);
+ swap(c,c);
+
+
+ std_vector_c<int> h(a);
+ std_vector_c<int> i(c);
+ std::vector<int> j(b);
+ }
+ } a;
+
+}
+
diff --git a/ml/dlib/dlib/test/string.cpp b/ml/dlib/dlib/test/string.cpp
new file mode 100644
index 000000000..18f9035c5
--- /dev/null
+++ b/ml/dlib/dlib/test/string.cpp
@@ -0,0 +1,329 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/string.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.string");
+
+
+ void string_test (
+ )
+ /*!
+ ensures
+ - runs tests on string functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ string a = " davis ";
+ string A = " DAVIS ";
+ string empty = " ";
+
+ dlog << LTRACE << 1;
+
+ double dval;
+ int ival;
+ bool bval;
+
+ DLIB_TEST_MSG(string_cast<int>("5") == 5,string_cast<int>("5"));
+ DLIB_TEST_MSG(string_cast<int>("0x5") == 5,string_cast<int>("0x5"));
+ DLIB_TEST_MSG(string_cast<int>("0xA") == 10,string_cast<int>("0xA"));
+ DLIB_TEST(string_cast<float>("0.5") == 0.5);
+ DLIB_TEST((dval = sa ="0.5") == 0.5);
+ DLIB_TEST(string_cast<std::string>("0.5 !") == "0.5 !");
+ DLIB_TEST(string_cast<bool>("true") == true);
+ DLIB_TEST((bval = sa = "true") == true);
+ DLIB_TEST(string_cast<bool>("false") == false);
+ DLIB_TEST(string_cast<bool>("TRUE") == true);
+ DLIB_TEST(string_cast<bool>("FALSE") == false);
+ DLIB_TEST((bval = sa = "FALSE") == false);
+
+ dlog << LTRACE << 2;
+
+ DLIB_TEST_MSG(string_cast<int>(L"5") == 5,string_cast<int>("5"));
+ DLIB_TEST_MSG((ival = sa = L"5") == 5,string_cast<int>("5"));
+ dlog << LTRACE << 2.1;
+ DLIB_TEST_MSG(string_cast<int>(L"0x5") == 5,string_cast<int>("0x5"));
+ DLIB_TEST_MSG(string_cast<int>(L"0xA") == 10,string_cast<int>("0xA"));
+ DLIB_TEST(string_cast<float>(L"0.5") == 0.5);
+ DLIB_TEST(string_cast<std::string>(L"0.5 !") == "0.5 !");
+ DLIB_TEST(string_cast<bool>(L"true") == true);
+ DLIB_TEST(string_cast<bool>(L"false") == false);
+ DLIB_TEST(string_cast<bool>(L"TRUE") == true);
+ DLIB_TEST((bval = sa = L"TRUE") == true);
+ DLIB_TEST(string_cast<bool>(L"FALSE") == false);
+
+ dlog << LTRACE << 3;
+
+ DLIB_TEST(cast_to_string(5) == "5");
+ DLIB_TEST(cast_to_string(5.5) == "5.5");
+
+ dlog << LTRACE << 4;
+ DLIB_TEST(cast_to_wstring(5) == L"5");
+ DLIB_TEST(cast_to_wstring(5.5) == L"5.5");
+ dlog << LTRACE << 5;
+ DLIB_TEST(toupper(a) == A);
+ DLIB_TEST(toupper(A) == A);
+ DLIB_TEST(tolower(a) == a);
+ DLIB_TEST(tolower(A) == a);
+ DLIB_TEST(trim(a) == "davis");
+ DLIB_TEST(ltrim(a) == "davis ");
+ DLIB_TEST(rtrim(a) == " davis");
+ DLIB_TEST(trim(string_cast<wstring>(a)) == L"davis");
+ DLIB_TEST(ltrim(string_cast<wstring>(a)) == L"davis ");
+ DLIB_TEST(rtrim(string_cast<wstring>(a)) == L" davis");
+ DLIB_TEST(trim(a, " ") == "davis");
+ DLIB_TEST(ltrim(a, " ") == "davis ");
+ DLIB_TEST(rtrim(a, " ") == " davis");
+ DLIB_TEST(trim(empty) == "");
+ DLIB_TEST(ltrim(empty) == "");
+ DLIB_TEST(rtrim(empty) == "");
+ DLIB_TEST(trim(string_cast<wstring>(empty)) == L"");
+ DLIB_TEST(ltrim(string_cast<wstring>(empty)) == L"");
+ DLIB_TEST(rtrim(string_cast<wstring>(empty)) == L"");
+ DLIB_TEST(trim(empty, " ") == "");
+ DLIB_TEST(ltrim(empty, " ") == "");
+ DLIB_TEST(rtrim(empty, " ") == "");
+
+
+ dlog << LTRACE << 6;
+ DLIB_TEST( (lpad(wstring(L"davis"), 10) == L" davis"));
+ DLIB_TEST( (rpad(wstring(L"davis"), 10) == L"davis "));
+ DLIB_TEST( (pad(wstring(L"davis"), 10) == L" davis "));
+
+ DLIB_TEST( (lpad(string("davis"), -10) == "davis"));
+ DLIB_TEST( (rpad(string("davis"), -10) == "davis"));
+ DLIB_TEST( (pad(string("davis"), -10) == "davis"));
+ DLIB_TEST( (lpad(string("davis"), 10) == " davis"));
+ DLIB_TEST( (rpad(string("davis"), 10) == "davis "));
+ DLIB_TEST( (pad(string("davis"), 10) == " davis "));
+ DLIB_TEST( (lpad(string("davis"), 10, string("*")) == "*****davis"));
+ DLIB_TEST( (rpad(string("davis"), 10, string("*")) == "davis*****"));
+ DLIB_TEST( (pad(string("davis"), 10, string("*")) == "**davis***"));
+ DLIB_TEST( (lpad(string("davis"), 10, string("_-")) == "_-_-_davis"));
+ DLIB_TEST( (rpad(string("davis"), 10, string("_-")) == "davis_-_-_"));
+ DLIB_TEST( (pad(string("davis"), 10, string("_-")) == "_-davis_-_"));
+ DLIB_TEST( (lpad(string("davis"), 10, string("willy wanka")) == "willydavis"));
+ DLIB_TEST( (rpad(string("davis"), 10, string("willy wanka")) == "daviswilly"));
+ DLIB_TEST( (pad(string("davis"), 10, string("willy wanka")) == "widaviswil"));
+ DLIB_TEST( (lpad(string("davis"), 10, "*")) == "*****davis");
+ DLIB_TEST( (rpad(string("davis"), 10, "*") == "davis*****"));
+ DLIB_TEST( (pad(string("davis"), 10, "*") == "**davis***"));
+ DLIB_TEST( (lpad(string("davis"), 10, "_-") == "_-_-_davis"));
+ DLIB_TEST( (rpad(string("davis"), 10, "_-") == "davis_-_-_"));
+ DLIB_TEST( (pad(string("davis"), 10, "_-") == "_-davis_-_"));
+ DLIB_TEST( (lpad(string("davis"), 10, "willy wanka") == "willydavis"));
+ DLIB_TEST( (rpad(string("davis"), 10, "willy wanka") == "daviswilly"));
+ DLIB_TEST( (pad(string("davis"), 10, "willy wanka") == "widaviswil"));
+ dlog << LTRACE << 7;
+
+ a = "file.txt";
+ DLIB_TEST( (left_substr(a,string(".")) == "file"));
+ DLIB_TEST( (left_substr(a,".") == "file"));
+ DLIB_TEST( (right_substr(a,string(".")) == "txt"));
+ DLIB_TEST( (right_substr(a,".") == "txt"));
+
+ DLIB_TEST( (left_substr(a," ") == "file.txt"));
+ DLIB_TEST( (right_substr(a," ") == ""));
+
+ DLIB_TEST( (left_substr(a,"") == "file.txt"));
+ DLIB_TEST( (right_substr(a,"") == ""));
+
+ wstring ws = L"file.txt";
+ DLIB_TEST( (left_substr(ws,wstring(L".")) == L"file"));
+ DLIB_TEST_MSG( (left_substr(ws,L".") == L"file"), L"");
+ DLIB_TEST( (right_substr(ws,wstring(L".")) == L"txt"));
+ DLIB_TEST_MSG( (right_substr(ws,L".") == L"txt"), L"");
+
+
+ dlog << LTRACE << 8;
+ {
+ ostringstream sout;
+ wchar_t w = 85;
+ char c = 4;
+ serialize(w,sout);
+ serialize(c,sout);
+ w = static_cast<wchar_t>(-1);
+ serialize(w,sout);
+ c = static_cast<char>(-1);
+ serialize(c,sout);
+
+ istringstream sin(sout.str());
+ w = 0;
+ c = 0;
+ deserialize(w,sin);
+ deserialize(c,sin);
+ DLIB_TEST(w == 85);
+ DLIB_TEST(c == 4);
+ deserialize(w,sin);
+ deserialize(c,sin);
+ DLIB_TEST(w == static_cast<wchar_t>(-1));
+ DLIB_TEST(c == static_cast<char>(-1));
+
+ wstring str = L"test string";
+
+ sout.str("");
+ serialize(str, sout);
+ sin.clear();
+ sin.str(sout.str());
+ str = L"something else";
+ deserialize(str,sin);
+ DLIB_TEST(str == L"test string");
+ }
+ }
+
+
+ void test_split()
+ {
+ std::vector<string> v;
+
+ string str;
+ string delim = " , ";
+
+ v = split(string("one, two,three four")," ,");
+ DLIB_TEST(v.size() == 4);
+ DLIB_TEST(v[0] == "one");
+ DLIB_TEST(v[1] == "two");
+ DLIB_TEST(v[2] == "three");
+ DLIB_TEST(v[3] == "four");
+
+ v = split(string("one, two,three four"),delim);
+ DLIB_TEST(v.size() == 4);
+ DLIB_TEST(v[0] == "one");
+ DLIB_TEST(v[1] == "two");
+ DLIB_TEST(v[2] == "three");
+ DLIB_TEST(v[3] == "four");
+
+ v = split(string(""));
+ DLIB_TEST(v.size() == 0);
+
+ v = split(string(" "));
+ DLIB_TEST(v.size() == 0);
+
+ v = split(string(" one two "));
+ DLIB_TEST(v.size() == 2);
+ DLIB_TEST(v[0] == "one");
+ DLIB_TEST(v[1] == "two");
+
+ v = split(string(" one "));
+ DLIB_TEST(v.size() == 1);
+ DLIB_TEST(v[0] == "one");
+
+ v = split(string("one"));
+ DLIB_TEST(v.size() == 1);
+ DLIB_TEST(v[0] == "one");
+
+ v = split(string("o"));
+ DLIB_TEST(v.size() == 1);
+ DLIB_TEST(v[0] == "o");
+
+
+ std::vector<wstring> wv;
+ wstring wstr = L"test string";
+ wv = split(wstr);
+ DLIB_TEST(wv.size() == 2);
+ DLIB_TEST(wv[0] == L"test");
+ DLIB_TEST(wv[1] == L"string");
+ wv = split(wstr,L" ");
+ DLIB_TEST(wv.size() == 2);
+ DLIB_TEST(wv[0] == L"test");
+ DLIB_TEST(wv[1] == L"string");
+
+ wstr = L"Über alle Maßen\u00A0Öttingenstraße";
+ wv = split(wstr, L" \u00A0\n\r\t");
+ DLIB_TEST(wv.size() == 4);
+ DLIB_TEST(wv[0] == L"Über");
+ DLIB_TEST(wv[1] == L"alle");
+ DLIB_TEST(wv[2] == L"Maßen");
+ DLIB_TEST(wv[3] == L"Öttingenstraße");
+
+ wstr = L"test string hah";
+ DLIB_TEST(split_on_first(wstr).first == L"test");
+ DLIB_TEST(split_on_first(wstr).second == L"string hah");
+ DLIB_TEST(split_on_first(wstr,L"#").first == L"test string hah");
+ DLIB_TEST(split_on_first(wstr,L"#").second == L"");
+ DLIB_TEST(split_on_last(wstr).first == L"test string");
+ DLIB_TEST(split_on_last(wstr).second == L"hah");
+ DLIB_TEST(split_on_last(wstr,L"#").first == L"test string hah");
+ DLIB_TEST(split_on_last(wstr,L"#").second == L"");
+ wstr = L"";
+ DLIB_TEST(split_on_first(wstr).first == L"");
+ DLIB_TEST(split_on_first(wstr).second == L"");
+
+ str = "test string hah";
+ DLIB_TEST(split_on_first(str).first == "test");
+ DLIB_TEST(split_on_first(str).second == "string hah");
+ DLIB_TEST(split_on_first(str,"#").first == "test string hah");
+ DLIB_TEST(split_on_first(str,"#").second == "");
+ DLIB_TEST(split_on_last(str).first == "test string");
+ DLIB_TEST(split_on_last(str).second == "hah");
+ DLIB_TEST(split_on_last(str,"#").first == "test string hah");
+ DLIB_TEST(split_on_last(str,"#").second == "");
+ str = "";
+ DLIB_TEST(split_on_first(str).first == "");
+ DLIB_TEST(split_on_first(str).second == "");
+
+ wstr = L"test.string.hah";
+ DLIB_TEST(split_on_first(wstr,L".").first == L"test");
+ DLIB_TEST(split_on_first(wstr,L".").second == L"string.hah");
+ DLIB_TEST(split_on_first(wstr).first == L"test.string.hah");
+ DLIB_TEST(split_on_first(wstr).second == L"");
+ DLIB_TEST(split_on_last(wstr,L".").first == L"test.string");
+ DLIB_TEST(split_on_last(wstr,L".").second == L"hah");
+ DLIB_TEST(split_on_last(wstr).first == L"test.string.hah");
+ DLIB_TEST(split_on_last(wstr).second == L"");
+ wstr = L"";
+ DLIB_TEST(split_on_first(wstr).first == L"");
+ DLIB_TEST(split_on_first(wstr).second == L"");
+
+ str = "test.string.hah";
+ DLIB_TEST(split_on_first(str,".").first == "test");
+ DLIB_TEST(split_on_first(str,".").second == "string.hah");
+ DLIB_TEST(split_on_first(str).first == "test.string.hah");
+ DLIB_TEST(split_on_first(str).second == "");
+ DLIB_TEST(split_on_last(str,".").first == "test.string");
+ DLIB_TEST(split_on_last(str,".").second == "hah");
+ DLIB_TEST(split_on_last(str).first == "test.string.hah");
+ DLIB_TEST(split_on_last(str).second == "");
+ str = "";
+ DLIB_TEST(split_on_first(str).first == "");
+ DLIB_TEST(split_on_first(str).second == "");
+ }
+
+
+
+ class string_tester : public tester
+ {
+ public:
+ string_tester (
+ ) :
+ tester ("test_string",
+ "Runs tests on the string objects and functions.")
+ {}
+
+ void perform_test (
+ )
+ {
+ string_test();
+ test_split();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/svm.cpp b/ml/dlib/dlib/test/svm.cpp
new file mode 100644
index 000000000..b46d44331
--- /dev/null
+++ b/ml/dlib/dlib/test/svm.cpp
@@ -0,0 +1,661 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include "checkerboard.h"
+#include <dlib/statistics.h>
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.svm");
+
+// ----------------------------------------------------------------------------------------
+
+ void test_clutering (
+ )
+ {
+ dlog << LINFO << " being test_clutering()";
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef matrix<double,2,1> sample_type;
+
+ // Now we are making a typedef for the kind of kernel we want to use. I picked the
+ // radial basis kernel because it only has one parameter and generally gives good
+ // results without much fiddling.
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ // Here we declare an instance of the kcentroid object. The first argument to the constructor
+ // is the kernel we wish to use. The second is a parameter that determines the numerical
+ // accuracy with which the object will perform part of the learning algorithm. Generally
+ // smaller values give better results but cause the algorithm to run slower. You just have
+ // to play with it to decide what balance of speed and accuracy is right for your problem.
+ // Here we have set it to 0.01.
+ kcentroid<kernel_type> kc(kernel_type(0.1),0.01);
+
+ // Now we make an instance of the kkmeans object and tell it to use kcentroid objects
+ // that are configured with the parameters from the kc object we defined above.
+ kkmeans<kernel_type> test(kc);
+
+ std::vector<sample_type> samples;
+ std::vector<sample_type> initial_centers;
+
+ sample_type m;
+
+ dlib::rand rnd;
+
+ print_spinner();
+ // we will make 50 points from each class
+ const long num = 50;
+
+ // make some samples near the origin
+ double radius = 0.5;
+ for (long i = 0; i < num; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ }
+
+ // make some samples in a circle around the origin but far away
+ radius = 10.0;
+ for (long i = 0; i < num; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ }
+
+ // make some samples in a circle around the point (25,25)
+ radius = 4.0;
+ for (long i = 0; i < num; ++i)
+ {
+ double sign = 1;
+ if (rnd.get_random_double() < 0.5)
+ sign = -1;
+ m(0) = 2*radius*rnd.get_random_double()-radius;
+ m(1) = sign*sqrt(radius*radius - m(0)*m(0));
+
+ // translate this point away from the origin
+ m(0) += 25;
+ m(1) += 25;
+
+ // add this sample to our set of samples we will run k-means
+ samples.push_back(m);
+ }
+ print_spinner();
+
+ // tell the kkmeans object we made that we want to run k-means with k set to 3.
+ // (i.e. we want 3 clusters)
+ test.set_number_of_centers(3);
+
+ // You need to pick some initial centers for the k-means algorithm. So here
+ // we will use the dlib::pick_initial_centers() function which tries to find
+ // n points that are far apart (basically).
+ pick_initial_centers(3, initial_centers, samples, test.get_kernel());
+
+ print_spinner();
+ // now run the k-means algorithm on our set of samples.
+ test.train(samples,initial_centers);
+ print_spinner();
+
+ const unsigned long class1 = test(samples[0]);
+ const unsigned long class2 = test(samples[num]);
+ const unsigned long class3 = test(samples[2*num]);
+ // now loop over all our samples and print out their predicted class. In this example
+ // all points are correctly identified.
+ for (unsigned long i = 0; i < samples.size()/3; ++i)
+ {
+ DLIB_TEST(test(samples[i]) == class1);
+ DLIB_TEST(test(samples[i+num]) == class2);
+ DLIB_TEST(test(samples[i+2*num]) == class3);
+ }
+
+ dlog << LINFO << " end test_clutering()";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Here is the sinc function we will be trying to learn with the krls
+ // object.
+ double sinc(double x)
+ {
+ if (x == 0)
+ return 1;
+ return sin(x)/x;
+ }
+
+
+ void test_regression (
+ )
+ {
+ dlog << LINFO << " being test_regression()";
+ // Here we declare that our samples will be 1 dimensional column vectors. The reason for
+ // using a matrix here is that in general you can use N dimensional vectors as inputs to the
+ // krls object. But here we only have 1 dimension to make the example simple.
+ typedef matrix<double,1,1> sample_type;
+
+ // Now we are making a typedef for the kind of kernel we want to use. I picked the
+ // radial basis kernel because it only has one parameter and generally gives good
+ // results without much fiddling.
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ // Here we declare an instance of the krls object. The first argument to the constructor
+ // is the kernel we wish to use. The second is a parameter that determines the numerical
+ // accuracy with which the object will perform part of the regression algorithm. Generally
+ // smaller values give better results but cause the algorithm to run slower. You just have
+ // to play with it to decide what balance of speed and accuracy is right for your problem.
+ // Here we have set it to 0.001.
+ krls<kernel_type> test(kernel_type(0.1),0.001);
+ rvm_regression_trainer<kernel_type> rvm_test;
+ rvm_test.set_kernel(test.get_kernel());
+
+ krr_trainer<kernel_type> krr_test;
+ krr_test.set_kernel(test.get_kernel());
+
+ svr_trainer<kernel_type> svr_test;
+ svr_test.set_kernel(test.get_kernel());
+ svr_test.set_epsilon_insensitivity(0.0001);
+ svr_test.set_c(10);
+
+ rbf_network_trainer<kernel_type> rbf_test;
+ rbf_test.set_kernel(test.get_kernel());
+ rbf_test.set_num_centers(13);
+
+ print_spinner();
+ std::vector<sample_type> samples;
+ std::vector<sample_type> samples2;
+ std::vector<double> labels;
+ std::vector<double> labels2;
+ // now we train our object on a few samples of the sinc function.
+ sample_type m;
+ for (double x = -10; x <= 5; x += 0.6)
+ {
+ m(0) = x;
+ test.train(m, sinc(x));
+
+ samples.push_back(m);
+ samples2.push_back(m);
+ labels.push_back(sinc(x));
+ labels2.push_back(2);
+ }
+
+ print_spinner();
+ decision_function<kernel_type> test2 = rvm_test.train(samples, labels);
+ print_spinner();
+ decision_function<kernel_type> test3 = rbf_test.train(samples, labels);
+ print_spinner();
+ decision_function<kernel_type> test4 = krr_test.train(samples, labels);
+ print_spinner();
+ decision_function<kernel_type> test5 = svr_test.train(samples, labels);
+ print_spinner();
+
+ // now we output the value of the sinc function for a few test points as well as the
+ // value predicted by krls object.
+ m(0) = 2.5; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01);
+ m(0) = 0.1; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01);
+ m(0) = -4; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01);
+ m(0) = 5.0; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01);
+
+ m(0) = 2.5; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01);
+ m(0) = 0.1; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01);
+ m(0) = -4; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01);
+ m(0) = 5.0; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01);
+
+ m(0) = 2.5; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
+ m(0) = 0.1; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
+ m(0) = -4; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
+ m(0) = 5.0; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
+
+ m(0) = 2.5; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+ m(0) = 0.1; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+ m(0) = -4; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+ m(0) = 5.0; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+
+ m(0) = 2.5; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
+ m(0) = 0.1; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
+ m(0) = -4; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
+ m(0) = 5.0; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
+
+
+ randomize_samples(samples, labels);
+ dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
+ dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
+ matrix<double,1,4> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
+ DLIB_TEST(cv(0) < 1e-4);
+ DLIB_TEST(cv(1) > 0.99);
+ cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);
+ DLIB_TEST(cv(0) < 1e-4);
+ DLIB_TEST(cv(1) > 0.99);
+
+
+
+
+ randomize_samples(samples2, labels2);
+ dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
+ dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
+ cv = cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
+ DLIB_TEST(cv(0) < 1e-4);
+ cv = cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
+ DLIB_TEST(cv(0) < 1e-4);
+
+ dlog << LINFO << " end test_regression()";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_anomaly_detection (
+ )
+ {
+ dlog << LINFO << " begin test_anomaly_detection()";
+ // Here we declare that our samples will be 2 dimensional column vectors.
+ typedef matrix<double,2,1> sample_type;
+
+ // Now we are making a typedef for the kind of kernel we want to use. I picked the
+ // radial basis kernel because it only has one parameter and generally gives good
+ // results without much fiddling.
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ // Here we declare an instance of the kcentroid object. The first argument to the constructor
+ // is the kernel we wish to use. The second is a parameter that determines the numerical
+ // accuracy with which the object will perform part of the learning algorithm. Generally
+ // smaller values give better results but cause the algorithm to run slower. You just have
+ // to play with it to decide what balance of speed and accuracy is right for your problem.
+ // Here we have set it to 0.01.
+ kcentroid<kernel_type> test(kernel_type(0.1),0.01);
+
+
+ svm_one_class_trainer<kernel_type> one_class_trainer;
+ one_class_trainer.set_nu(0.4);
+ one_class_trainer.set_kernel(kernel_type(0.2));
+
+ std::vector<sample_type> samples;
+
+ // now we train our object on a few samples of the sinc function.
+ sample_type m;
+ for (double x = -15; x <= 8; x += 1)
+ {
+ m(0) = x;
+ m(1) = sinc(x);
+ test.train(m);
+ samples.push_back(m);
+ }
+
+ decision_function<kernel_type> df = one_class_trainer.train(samples);
+
+ running_stats<double> rs;
+
+ // Now lets output the distance from the centroid to some points that are from the sinc function.
+ // These numbers should all be similar. We will also calculate the statistics of these numbers
+ // by accumulating them into the running_stats object called rs. This will let us easily
+ // find the mean and standard deviation of the distances for use below.
+ dlog << LDEBUG << "Points that are on the sinc function:\n";
+ m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -0; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -0.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -4.1; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+ m(0) = -0.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m));
+
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -0; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -4.1; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+ m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
+
+ const double thresh = 0.01;
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -0; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -4.1; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+ m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
+
+ dlog << LDEBUG;
+ // Lets output the distance from the centroid to some points that are NOT from the sinc function.
+ // These numbers should all be significantly bigger than previous set of numbers. We will also
+ // use the rs.scale() function to find out how many standard deviations they are away from the
+ // mean of the test points from the sinc function. So in this case our criterion for "significantly bigger"
+ // is > 3 or 4 standard deviations away from the above points that actually are on the sinc function.
+ dlog << LDEBUG << "Points that are NOT on the sinc function:\n";
+ m(0) = -1.5; m(1) = sinc(m(0))+4;
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -1.5; m(1) = sinc(m(0))+3;
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -0; m(1) = -sinc(m(0));
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -0.5; m(1) = -sinc(m(0));
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -4.1; m(1) = sinc(m(0))+2;
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -1.5; m(1) = sinc(m(0))+0.9;
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ m(0) = -0.5; m(1) = sinc(m(0))+1;
+ dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
+ DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
+ DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
+
+ dlog << LINFO << " end test_anomaly_detection()";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void unittest_binary_classification (
+ )
+ /*!
+ ensures
+ - runs tests on the svm stuff compliance with the specs
+ !*/
+ {
+ dlog << LINFO << " begin unittest_binary_classification()";
+ print_spinner();
+
+
+ typedef double scalar_type;
+ typedef matrix<scalar_type,2,1> sample_type;
+
+ std::vector<sample_type> x;
+ std::vector<matrix<double,0,1> > x_linearized;
+ std::vector<scalar_type> y;
+
+ get_checkerboard_problem(x,y, 300, 2);
+ const scalar_type gamma = 1;
+
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ rbf_network_trainer<kernel_type> rbf_trainer;
+ rbf_trainer.set_kernel(kernel_type(gamma));
+ rbf_trainer.set_num_centers(100);
+
+ rvm_trainer<kernel_type> rvm_trainer;
+ rvm_trainer.set_kernel(kernel_type(gamma));
+
+ krr_trainer<kernel_type> krr_trainer;
+ krr_trainer.use_classification_loss_for_loo_cv();
+ krr_trainer.set_kernel(kernel_type(gamma));
+
+ svm_pegasos<kernel_type> pegasos_trainer;
+ pegasos_trainer.set_kernel(kernel_type(gamma));
+ pegasos_trainer.set_lambda(0.00001);
+
+
+ svm_c_ekm_trainer<kernel_type> ocas_ekm_trainer;
+ ocas_ekm_trainer.set_kernel(kernel_type(gamma));
+ ocas_ekm_trainer.set_c(100000);
+
+ svm_nu_trainer<kernel_type> trainer;
+ trainer.set_kernel(kernel_type(gamma));
+ trainer.set_nu(0.05);
+
+ svm_c_trainer<kernel_type> c_trainer;
+ c_trainer.set_kernel(kernel_type(gamma));
+ c_trainer.set_c(100);
+
+ svm_c_linear_trainer<linear_kernel<matrix<double,0,1> > > lin_trainer;
+ lin_trainer.set_c(100000);
+ // use an ekm to linearize this dataset so we can use it with the lin_trainer
+ empirical_kernel_map<kernel_type> ekm;
+ ekm.load(kernel_type(gamma), x);
+ for (unsigned long i = 0; i < x.size(); ++i)
+ x_linearized.push_back(ekm.project(x[i]));
+
+
+ print_spinner();
+ matrix<scalar_type> rvm_cv = cross_validate_trainer_threaded(rvm_trainer, x,y, 4, 2);
+ print_spinner();
+ matrix<scalar_type> krr_cv = cross_validate_trainer_threaded(krr_trainer, x,y, 4, 2);
+ print_spinner();
+ matrix<scalar_type> svm_cv = cross_validate_trainer(trainer, x,y, 4);
+ print_spinner();
+ matrix<scalar_type> svm_c_cv = cross_validate_trainer(c_trainer, x,y, 4);
+ print_spinner();
+ matrix<scalar_type> rbf_cv = cross_validate_trainer_threaded(rbf_trainer, x,y, 10, 2);
+ print_spinner();
+ matrix<scalar_type> lin_cv = cross_validate_trainer_threaded(lin_trainer, x_linearized, y, 4, 2);
+ print_spinner();
+ matrix<scalar_type> ocas_ekm_cv = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
+ print_spinner();
+ ocas_ekm_trainer.set_basis(randomly_subsample(x, 300));
+ matrix<scalar_type> ocas_ekm_cv2 = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
+ print_spinner();
+ matrix<scalar_type> peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2);
+ print_spinner();
+ matrix<scalar_type> peg_c_cv = cross_validate_trainer_threaded(batch_cached(pegasos_trainer,1.0), x,y, 4, 2);
+ print_spinner();
+
+ dlog << LDEBUG << "rvm cv: " << rvm_cv;
+ dlog << LDEBUG << "krr cv: " << krr_cv;
+ dlog << LDEBUG << "nu-svm cv: " << svm_cv;
+ dlog << LDEBUG << "C-svm cv: " << svm_c_cv;
+ dlog << LDEBUG << "rbf cv: " << rbf_cv;
+ dlog << LDEBUG << "lin cv: " << lin_cv;
+ dlog << LDEBUG << "ocas_ekm cv: " << ocas_ekm_cv;
+ dlog << LDEBUG << "ocas_ekm cv2: " << ocas_ekm_cv2;
+ dlog << LDEBUG << "peg cv: " << peg_cv;
+ dlog << LDEBUG << "peg cached cv: " << peg_c_cv;
+
+ // make sure the cached version of pegasos computes the same result
+ DLIB_TEST_MSG(sum(abs(peg_cv - peg_c_cv)) < std::sqrt(std::numeric_limits<double>::epsilon()),
+ sum(abs(peg_cv - peg_c_cv)) << " \n" << peg_cv << peg_c_cv );
+
+ DLIB_TEST_MSG(mean(rvm_cv) > 0.9, rvm_cv);
+ DLIB_TEST_MSG(mean(krr_cv) > 0.9, krr_cv);
+ DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv);
+ DLIB_TEST_MSG(mean(svm_c_cv) > 0.9, svm_c_cv);
+ DLIB_TEST_MSG(mean(rbf_cv) > 0.9, rbf_cv);
+ DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv);
+ DLIB_TEST_MSG(mean(peg_cv) > 0.9, peg_cv);
+ DLIB_TEST_MSG(mean(peg_c_cv) > 0.9, peg_c_cv);
+ DLIB_TEST_MSG(mean(ocas_ekm_cv) > 0.9, ocas_ekm_cv);
+ DLIB_TEST_MSG(mean(ocas_ekm_cv2) > 0.9, ocas_ekm_cv2);
+
+ const long num_sv = trainer.train(x,y).basis_vectors.size();
+ print_spinner();
+ const long num_rv = rvm_trainer.train(x,y).basis_vectors.size();
+ print_spinner();
+ dlog << LDEBUG << "num sv: " << num_sv;
+ dlog << LDEBUG << "num rv: " << num_rv;
+ print_spinner();
+ ocas_ekm_trainer.clear_basis();
+ const long num_bv = ocas_ekm_trainer.train(x,y).basis_vectors.size();
+ dlog << LDEBUG << "num ekm bv: " << num_bv;
+
+ DLIB_TEST(num_rv <= 17);
+ DLIB_TEST_MSG(num_sv <= 45, num_sv);
+ DLIB_TEST_MSG(num_bv <= 45, num_bv);
+
+ decision_function<kernel_type> df = reduced2(trainer, 19).train(x,y);
+ print_spinner();
+
+ matrix<scalar_type> svm_reduced_error = test_binary_decision_function(df, x, y);
+ print_spinner();
+ dlog << LDEBUG << "svm reduced test error: " << svm_reduced_error;
+ dlog << LDEBUG << "svm reduced num sv: " << df.basis_vectors.size();
+ DLIB_TEST(mean(svm_reduced_error) > 0.9);
+
+ svm_cv = cross_validate_trainer(reduced(trainer,30), x,y, 4);
+ dlog << LDEBUG << "svm reduced cv: " << svm_cv;
+ DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv);
+
+ DLIB_TEST(df.basis_vectors.size() <= 19);
+ dlog << LINFO << " end unittest_binary_classification()";
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename kernel_type>
+ struct kernel_der_obj
+ {
+ typename kernel_type::sample_type x;
+ kernel_type k;
+
+ double operator()(const typename kernel_type::sample_type& y) const { return k(x,y); }
+ };
+
+
+ template <typename kernel_type>
+ void test_kernel_derivative (
+ const kernel_type& k,
+ const typename kernel_type::sample_type& x,
+ const typename kernel_type::sample_type& y
+ )
+ {
+ kernel_der_obj<kernel_type> obj;
+ obj.x = x;
+ obj.k = k;
+ kernel_derivative<kernel_type> der(obj.k);
+ DLIB_TEST(dlib::equal(derivative(obj)(y) , der(obj.x,y), 1e-5));
+ }
+
+ void test_kernel_derivative (
+ )
+ {
+ typedef matrix<double, 2, 1> sample_type;
+
+ sigmoid_kernel<sample_type> k1;
+ radial_basis_kernel<sample_type> k2;
+ linear_kernel<sample_type> k3;
+ polynomial_kernel<sample_type> k4(2,3,4);
+
+ offset_kernel<sigmoid_kernel<sample_type> > k5;
+ offset_kernel<radial_basis_kernel<sample_type> > k6;
+
+ dlib::rand rnd;
+
+ sample_type x, y;
+ for (int i = 0; i < 10; ++i)
+ {
+ x = randm(2,1,rnd);
+ y = randm(2,1,rnd);
+ test_kernel_derivative(k1, x, y);
+ test_kernel_derivative(k2, x, y);
+ test_kernel_derivative(k3, x, y);
+ test_kernel_derivative(k4, x, y);
+ test_kernel_derivative(k5, x, y);
+ test_kernel_derivative(k6, x, y);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_svm_trainer2()
+ {
+ typedef matrix<double, 2, 1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ sample_type samp;
+ samp(0) = 1;
+ samp(1) = 1;
+ samples.push_back(samp);
+ labels.push_back(+1);
+
+ samp(0) = 1;
+ samp(1) = 2;
+ samples.push_back(samp);
+ labels.push_back(-1);
+
+ svm_c_trainer<kernel_type> trainer;
+
+ decision_function<kernel_type> df = trainer.train(samples, labels);
+
+ samp(0) = 1;
+ samp(1) = 1;
+ dlog << LINFO << "test +1 : "<< df(samp);
+ DLIB_TEST(df(samp) > 0);
+ samp(0) = 1;
+ samp(1) = 2;
+ dlog << LINFO << "test -1 : "<< df(samp);
+ DLIB_TEST(df(samp) < 0);
+
+ svm_nu_trainer<kernel_type> trainer2;
+ df = trainer2.train(samples, labels);
+
+ samp(0) = 1;
+ samp(1) = 1;
+ dlog << LINFO << "test +1 : "<< df(samp);
+ DLIB_TEST(df(samp) > 0);
+ samp(0) = 1;
+ samp(1) = 2;
+ dlog << LINFO << "test -1 : "<< df(samp);
+ DLIB_TEST(df(samp) < 0);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class svm_tester : public tester
+ {
+ public:
+ svm_tester (
+ ) :
+ tester ("test_svm",
+ "Runs tests on the svm/kernel algorithm components.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_kernel_derivative();
+ unittest_binary_classification();
+ test_clutering();
+ test_regression();
+ test_anomaly_detection();
+ test_svm_trainer2();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/svm_c_linear.cpp b/ml/dlib/dlib/test/svm_c_linear.cpp
new file mode 100644
index 000000000..9e30d81f7
--- /dev/null
+++ b/ml/dlib/dlib/test/svm_c_linear.cpp
@@ -0,0 +1,392 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../stl_checked.h"
+#include "../array.h"
+#include "../rand.h"
+#include "checkerboard.h"
+#include <dlib/statistics.h>
+
+#include "tester.h"
+#include <dlib/svm.h>
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.svm_c_linear");
+
+ typedef matrix<double, 0, 1> sample_type;
+ typedef std::vector<std::pair<unsigned int, double> > sparse_sample_type;
+
+// ----------------------------------------------------------------------------------------
+
+ void run_prior_test()
+ {
+ typedef matrix<double,3,1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+ svm_c_linear_trainer<kernel_type> trainer;
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ sample_type samp;
+ samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
+ samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
+
+ trainer.set_c(10);
+ decision_function<kernel_type> df = trainer.train(samples, labels);
+
+ trainer.set_prior(df);
+
+ samples.clear();
+ labels.clear();
+ samp = 1, 0, 0; samples.push_back(samp); labels.push_back(+1);
+ samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
+
+ df = trainer.train(samples, labels);
+
+ samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
+ matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
+ dlog << LINFO << rs;
+ DLIB_TEST(rs(0) == 1);
+ DLIB_TEST(rs(1) == 1);
+
+ dlog << LINFO << trans(df.basis_vectors(0));
+ DLIB_TEST(df.basis_vectors(0)(0) > 0);
+ DLIB_TEST(df.basis_vectors(0)(1) < 0);
+ DLIB_TEST(df.basis_vectors(0)(2) > 0);
+ }
+
+ void run_prior_sparse_test()
+ {
+ typedef std::map<unsigned long,double> sample_type;
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+ svm_c_linear_trainer<kernel_type> trainer;
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ sample_type samp;
+ samp[0] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear();
+ samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear();
+
+ trainer.set_c(10);
+ decision_function<kernel_type> df = trainer.train(samples, labels);
+
+ trainer.set_prior(df);
+
+ samples.clear();
+ labels.clear();
+ samp[2] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear();
+ samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear();
+
+ df = trainer.train(samples, labels);
+
+ matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
+ dlog << LINFO << rs;
+ DLIB_TEST(rs(0) == 1);
+ DLIB_TEST(rs(1) == 1);
+
+ matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
+ dlog << LINFO << trans(w);
+ DLIB_TEST(w(0) > 0.1);
+ DLIB_TEST(w(1) < -0.1);
+ DLIB_TEST(w(2) > 0.1);
+ }
+
+ void get_simple_points (
+ std::vector<sample_type>& samples,
+ std::vector<double>& labels
+ )
+ {
+ samples.clear();
+ labels.clear();
+ sample_type samp(2);
+
+ samp = 0,0;
+ samples.push_back(samp);
+ labels.push_back(-1);
+
+ samp = 0,1;
+ samples.push_back(samp);
+ labels.push_back(-1);
+
+ samp = 3,0;
+ samples.push_back(samp);
+ labels.push_back(+1);
+
+ samp = 3,1;
+ samples.push_back(samp);
+ labels.push_back(+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void get_simple_points_sparse (
+ std::vector<sparse_sample_type>& samples,
+ std::vector<double>& labels
+ )
+ {
+ samples.clear();
+ labels.clear();
+ sparse_sample_type samp;
+
+ samp.push_back(make_pair(0, 0.0));
+ samp.push_back(make_pair(1, 0.0));
+ samples.push_back(samp);
+ labels.push_back(-1);
+
+ samp.clear();
+ samp.push_back(make_pair(0, 0.0));
+ samp.push_back(make_pair(1, 1.0));
+ samples.push_back(samp);
+ labels.push_back(-1);
+
+ samp.clear();
+ samp.push_back(make_pair(0, 3.0));
+ samp.push_back(make_pair(1, 0.0));
+ samples.push_back(samp);
+ labels.push_back(+1);
+
+ samp.clear();
+ samp.push_back(make_pair(0, 3.0));
+ samp.push_back(make_pair(1, 1.0));
+ samples.push_back(samp);
+ labels.push_back(+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_sparse (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test with sparse vectors";
+ std::vector<sparse_sample_type> samples;
+ std::vector<double> labels;
+
+ sample_type samp;
+
+ get_simple_points_sparse(samples,labels);
+
+ svm_c_linear_trainer<sparse_linear_kernel<sparse_sample_type> > trainer;
+ trainer.set_c(1e4);
+ //trainer.be_verbose();
+ trainer.set_epsilon(1e-11);
+
+
+ double obj;
+ decision_function<sparse_linear_kernel<sparse_sample_type> > df = trainer.train(samples, labels, obj);
+ dlog << LDEBUG << "obj: "<< obj;
+ DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-7, obj);
+
+ DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
+
+
+ // While we are at it, make sure the krr_trainer works with sparse samples
+ krr_trainer<sparse_linear_kernel<sparse_sample_type> > krr;
+
+ df = krr.train(samples, labels);
+ DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
+
+
+ // Now test some of the sparse helper functions
+ DLIB_TEST(max_index_plus_one(samples) == 2);
+ DLIB_TEST(max_index_plus_one(samples[0]) == 2);
+
+ matrix<double,3,1> m;
+ m = 1;
+ add_to(m, samples[3]);
+ DLIB_TEST(m(0) == 1 + samples[3][0].second);
+ DLIB_TEST(m(1) == 1 + samples[3][1].second);
+ DLIB_TEST(m(2) == 1);
+
+ m = 1;
+ subtract_from(m, samples[3]);
+ DLIB_TEST(m(0) == 1 - samples[3][0].second);
+ DLIB_TEST(m(1) == 1 - samples[3][1].second);
+ DLIB_TEST(m(2) == 1);
+
+ m = 1;
+ add_to(m, samples[3], 2);
+ DLIB_TEST(m(0) == 1 + 2*samples[3][0].second);
+ DLIB_TEST(m(1) == 1 + 2*samples[3][1].second);
+ DLIB_TEST(m(2) == 1);
+
+ m = 1;
+ subtract_from(m, samples[3], 2);
+ DLIB_TEST(m(0) == 1 - 2*samples[3][0].second);
+ DLIB_TEST(m(1) == 1 - 2*samples[3][1].second);
+ DLIB_TEST(m(2) == 1);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_dense (
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "test with dense vectors";
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ sample_type samp;
+
+ get_simple_points(samples,labels);
+
+ svm_c_linear_trainer<linear_kernel<sample_type> > trainer;
+ trainer.set_c(1e4);
+ //trainer.be_verbose();
+ trainer.set_epsilon(1e-11);
+
+
+ double obj;
+ decision_function<linear_kernel<sample_type> > df = trainer.train(samples, labels, obj);
+ dlog << LDEBUG << "obj: "<< obj;
+ DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-7, abs(obj - 0.72222222222));
+ // There shouldn't be any margin violations since this dataset is so trivial. So that means the objective
+ // should be exactly the squared norm of the decision plane (times 0.5).
+ DLIB_TEST_MSG(abs(length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5 - 0.72222222222) < 1e-7,
+ length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5);
+
+ DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
+ DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class tester_svm_c_linear : public tester
+ {
+ public:
+ tester_svm_c_linear (
+ ) :
+ tester ("test_svm_c_linear",
+ "Runs tests on the svm_c_linear_trainer.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_dense();
+ test_sparse();
+ run_prior_test();
+ run_prior_sparse_test();
+
+ // test mixed sparse and dense dot products
+ {
+ std::map<unsigned int, double> sv;
+ matrix<double,0,1> dv(4);
+
+ dv = 1,2,3,4;
+
+ sv[0] = 1;
+ sv[3] = 1;
+
+
+ DLIB_TEST(dot(sv,dv) == 5);
+ DLIB_TEST(dot(dv,sv) == 5);
+ DLIB_TEST(dot(dv,dv) == 30);
+ DLIB_TEST(dot(sv,sv) == 2);
+
+ sv[10] = 9;
+ DLIB_TEST(dot(sv,dv) == 5);
+ }
+
+ // test mixed sparse dense assignments
+ {
+ std::map<unsigned int, double> sv, sv2;
+ std::vector<std::pair<unsigned int, double> > sv3;
+ matrix<double,0,1> dv(4), dv2;
+
+ dv = 1,2,3,4;
+
+ sv[0] = 1;
+ sv[3] = 1;
+
+
+ assign(dv2, dv);
+
+ DLIB_TEST(dv2.size() == 4);
+ DLIB_TEST(dv2(0) == 1);
+ DLIB_TEST(dv2(1) == 2);
+ DLIB_TEST(dv2(2) == 3);
+ DLIB_TEST(dv2(3) == 4);
+
+ assign(sv2, dv);
+ DLIB_TEST(sv2.size() == 4);
+ DLIB_TEST(sv2[0] == 1);
+ DLIB_TEST(sv2[1] == 2);
+ DLIB_TEST(sv2[2] == 3);
+ DLIB_TEST(sv2[3] == 4);
+
+ assign(sv2, sv);
+ DLIB_TEST(sv2.size() == 2);
+ DLIB_TEST(sv2[0] == 1);
+ DLIB_TEST(sv2[1] == 0);
+ DLIB_TEST(sv2[2] == 0);
+ DLIB_TEST(sv2[3] == 1);
+
+ assign(sv3, sv);
+ DLIB_TEST(sv3.size() == 2);
+ DLIB_TEST(sv3[0].second == 1);
+ DLIB_TEST(sv3[1].second == 1);
+ DLIB_TEST(sv3[0].first == 0);
+ DLIB_TEST(sv3[1].first == 3);
+
+ assign(sv3, dv);
+ DLIB_TEST(sv3.size() == 4);
+ DLIB_TEST(sv3[0].second == 1);
+ DLIB_TEST(sv3[1].second == 2);
+ DLIB_TEST(sv3[2].second == 3);
+ DLIB_TEST(sv3[3].second == 4);
+ DLIB_TEST(sv3[0].first == 0);
+ DLIB_TEST(sv3[1].first == 1);
+ DLIB_TEST(sv3[2].first == 2);
+ DLIB_TEST(sv3[3].first == 3);
+
+ assign(sv3, sv);
+ DLIB_TEST(sv3.size() == 2);
+ DLIB_TEST(sv3[0].second == 1);
+ DLIB_TEST(sv3[1].second == 1);
+ DLIB_TEST(sv3[0].first == 0);
+ DLIB_TEST(sv3[1].first == 3);
+
+ sv.clear();
+ assign(sv, sv3);
+ DLIB_TEST(sv.size() == 2);
+ DLIB_TEST(sv[0] == 1);
+ DLIB_TEST(sv[1] == 0);
+ DLIB_TEST(sv[2] == 0);
+ DLIB_TEST(sv[3] == 1);
+
+ }
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/svm_c_linear_dcd.cpp b/ml/dlib/dlib/test/svm_c_linear_dcd.cpp
new file mode 100644
index 000000000..93c99db30
--- /dev/null
+++ b/ml/dlib/dlib/test/svm_c_linear_dcd.cpp
@@ -0,0 +1,545 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/statistics.h>
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.svm_c_linear_dcd");
+
+// ----------------------------------------------------------------------------------------
+
+ void test_sparse()
+ {
+ typedef std::map<unsigned long,double> sample_type;
+
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+
+
+ svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 0.2;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer_cpa.set_epsilon(1e-10);
+
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ decision_function<kernel_type> df, df2, df3;
+
+ dlib::rand rnd;
+ // Now lets go into a loop and randomly generate 10000 samples.
+ double label = +1;
+ for (int i = 0; i < 100; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample.clear();
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 5; ++j)
+ {
+ int idx = rnd.get_random_32bit_number()%10;
+ double value = rnd.get_random_double();
+
+ sample[idx] = label*value;
+ }
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ if (samples.size() > 1)
+ {
+ linear_trainer_cpa.set_c_class1(C);
+ linear_trainer_cpa.set_c_class2(1.5*C);
+ linear_trainer.set_c_class1(C/samples.size());
+ linear_trainer.set_c_class2(1.5*C/samples.size());
+
+ df = linear_trainer.train(samples, labels, state);
+ df2 = linear_trainer_cpa.train(samples, labels);
+ df3 = linear_trainer.train(samples, labels);
+
+ DLIB_TEST_MSG( dlib::distance(df.basis_vectors(0), df2.basis_vectors(0)) < 1e-8, dlib::distance(df.basis_vectors(0), df2.basis_vectors(0)));
+ DLIB_TEST( std::abs(df.b - df2.b) < 1e-8);
+ DLIB_TEST( dlib::distance(df.basis_vectors(0), df3.basis_vectors(0)) < 1e-8);
+ DLIB_TEST( std::abs(df.b - df3.b) < 1e-8);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_normal_no_bias()
+ {
+ typedef matrix<double,10,1> sample_type;
+
+
+ typedef linear_kernel<sample_type> kernel_type;
+
+
+
+ svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 1.0;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer_cpa.set_epsilon(1e-10);
+
+ linear_trainer.include_bias(false);
+
+
+ std::vector<sample_type> samples, samples_explict_bias;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ decision_function<kernel_type> df, df2, df3;
+
+ dlib::rand rnd;
+ // Now lets go into a loop and randomly generate 10000 samples.
+ double label = +1;
+ for (int i = 0; i < 100; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample = 0;
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 5; ++j)
+ {
+ int idx = rnd.get_random_32bit_number()%9;
+ double value = rnd.get_random_double();
+
+ sample(idx) = label*value;
+ }
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ sample(9) = -1;
+ samples_explict_bias.push_back(sample);
+
+ if (samples.size() > 1)
+ {
+ linear_trainer_cpa.set_c_class1(C);
+ linear_trainer_cpa.set_c_class2(1.5*C);
+ linear_trainer.set_c_class1(C/samples.size());
+ linear_trainer.set_c_class2(1.5*C/samples.size());
+
+ df = linear_trainer.train(samples_explict_bias, labels, state);
+ df2 = linear_trainer_cpa.train(samples, labels);
+ df3 = linear_trainer.train(samples_explict_bias, labels);
+
+ DLIB_TEST( std::abs(df2.basis_vectors(0)(9)) < 1e-7);
+ DLIB_TEST_MSG( max(abs(colm(df.basis_vectors(0),0,9) - colm(df2.basis_vectors(0),0,9))) < 1e-6, max(abs(colm(df.basis_vectors(0),0,9) - colm(df2.basis_vectors(0),0,9))));
+ DLIB_TEST( std::abs(df.basis_vectors(0)(9) - df2.b) < 1e-6);
+ DLIB_TEST( max(abs(df.basis_vectors(0) - df3.basis_vectors(0))) < 1e-6);
+ DLIB_TEST( std::abs(df.b - df3.b) < 1e-7);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_normal()
+ {
+ typedef matrix<double,10,1> sample_type;
+
+
+ typedef linear_kernel<sample_type> kernel_type;
+
+
+
+ svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 1;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer_cpa.set_epsilon(1e-10);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ decision_function<kernel_type> df, df2, df3;
+
+ dlib::rand rnd;
+ // Now lets go into a loop and randomly generate 10000 samples.
+ double label = +1;
+ for (int i = 0; i < 100; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample = 0;
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 5; ++j)
+ {
+ int idx = rnd.get_random_32bit_number()%10;
+ double value = rnd.get_random_double();
+
+ sample(idx) = label*value;
+ }
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ if (samples.size() > 1)
+ {
+ linear_trainer_cpa.set_c_class1(C);
+ linear_trainer_cpa.set_c_class2(1.5*C);
+ linear_trainer.set_c_class1(C/samples.size());
+ linear_trainer.set_c_class2(1.5*C/samples.size());
+
+ df = linear_trainer.train(samples, labels, state);
+ df2 = linear_trainer_cpa.train(samples, labels);
+ df3 = linear_trainer.train(samples, labels);
+
+ DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0))));
+ DLIB_TEST( std::abs(df.b - df2.b) < 1e-7);
+ DLIB_TEST( max(abs(df.basis_vectors(0) - df3.basis_vectors(0))) < 1e-7);
+ DLIB_TEST( std::abs(df.b - df3.b) < 1e-7);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_normal_force_last_weight(bool have_bias, bool force_weight)
+ {
+ typedef matrix<double,10,1> sample_type;
+ dlog << LINFO << "have_bias: "<< have_bias << " force_weight: "<< force_weight;
+
+
+ typedef linear_kernel<sample_type> kernel_type;
+
+
+ svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
+
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 1;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer_cpa.set_epsilon(1e-11);
+
+ linear_trainer_cpa.force_last_weight_to_1(force_weight);
+
+ linear_trainer.force_last_weight_to_1(force_weight);
+ linear_trainer.include_bias(have_bias);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ decision_function<kernel_type> df, df2;
+
+ running_stats<double> rs;
+
+ dlib::rand rnd;
+ // Now lets go into a loop and randomly generate 10000 samples.
+ double label = +1;
+ for (int i = 0; i < 40; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample = 0;
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 5; ++j)
+ {
+ int idx = rnd.get_random_32bit_number()%9;
+ double value = rnd.get_random_double();
+
+ sample(idx) = label*value + label;
+ }
+
+ sample(9) = 4;
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ linear_trainer.set_c(C);
+ linear_trainer_cpa.set_c(C*samples.size());
+
+ df = linear_trainer.train(samples, labels, state);
+
+ if (force_weight)
+ {
+ DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8);
+ DLIB_TEST(std::abs(df.b) < 1e-8);
+
+ if (samples.size() > 1)
+ {
+ df2 = linear_trainer_cpa.train(samples, labels);
+ DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0))));
+ DLIB_TEST( std::abs(df.b - df2.b) < 1e-7);
+ }
+ }
+
+ if (!have_bias)
+ DLIB_TEST(std::abs(df.b) < 1e-8);
+
+
+ for (unsigned long k = 0; k < samples.size(); ++k)
+ {
+ //cout << "pred: "<< labels[k]*df(samples[k]) << endl;
+ rs.add(labels[k]*df(samples[k]));
+ }
+ }
+ DLIB_TEST_MSG(std::abs(rs.min()-1) < 1e-7, std::abs(rs.min()-1));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_normal_1_sample(double label)
+ {
+ typedef matrix<double,10,1> sample_type;
+
+
+ typedef linear_kernel<sample_type> kernel_type;
+
+
+
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 10;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer.set_c(C);
+
+
+ linear_trainer.force_last_weight_to_1(true);
+ linear_trainer.include_bias(false);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ sample = 0;
+ sample(0) = -1;
+ sample(1) = -1;
+ sample(9) = 4;
+
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ for (int i = 0; i < 4; ++i)
+ {
+ decision_function<kernel_type> df;
+ df = linear_trainer.train(samples, labels);
+
+ if (label > 0)
+ {
+ DLIB_TEST(std::abs(df(samples[0])-4) < 1e-8);
+ }
+ else
+ {
+ DLIB_TEST(std::abs(df(samples[0])+1) < 1e-8);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_sparse_1_sample(double label)
+ {
+ typedef std::vector<std::pair<unsigned long,double> > sample_type;
+
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+
+
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+
+ svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+
+ const double C = 10;
+ linear_trainer.set_epsilon(1e-10);
+ linear_trainer.set_c(C);
+
+
+ linear_trainer.force_last_weight_to_1(true);
+ linear_trainer.include_bias(false);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+ sample.push_back(make_pair(0,-1));
+ sample.push_back(make_pair(1,1));
+ sample.push_back(make_pair(9,4));
+
+ for (int i = 0; i < 4; ++i)
+ {
+ samples.push_back(sample);
+ labels.push_back(label);
+
+ decision_function<kernel_type> df;
+ df = linear_trainer.train(samples, labels);
+
+
+ if (label > 0)
+ {
+ DLIB_TEST(std::abs(df(samples[0])-4) < 1e-8);
+ }
+ else
+ {
+ DLIB_TEST(std::abs(df(samples[0])+1) < 1e-8);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_l2_version ()
+ {
+ typedef std::map<unsigned long,double> sample_type;
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+ svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
+ linear_trainer.set_c(10);
+ linear_trainer.set_epsilon(1e-5);
+
+ std::vector<sample_type> samples;
+ std::vector<double> labels;
+
+ // make an instance of a sample vector so we can use it below
+ sample_type sample;
+
+
+ // Now let's go into a loop and randomly generate 10000 samples.
+ double label = +1;
+ for (int i = 0; i < 1000; ++i)
+ {
+ // flip this flag
+ label *= -1;
+
+ sample.clear();
+
+ // now make a random sparse sample with at most 10 non-zero elements
+ for (int j = 0; j < 10; ++j)
+ {
+ int idx = std::rand()%100;
+ double value = static_cast<double>(std::rand())/RAND_MAX;
+
+ sample[idx] = label*value;
+ }
+
+ // Also save the samples we are generating so we can let the svm_c_linear_trainer
+ // learn from them below.
+ samples.push_back(sample);
+ labels.push_back(label);
+ }
+
+ decision_function<kernel_type> df = linear_trainer.train(samples, labels);
+
+ sample.clear();
+ sample[4] = 0.3;
+ sample[10] = 0.9;
+ DLIB_TEST(df(sample) > 0);
+
+ sample.clear();
+ sample[83] = -0.3;
+ sample[26] = -0.9;
+ sample[58] = -0.7;
+ DLIB_TEST(df(sample) < 0);
+
+ sample.clear();
+ sample[0] = -0.2;
+ sample[9] = -0.8;
+ DLIB_TEST(df(sample) < 0);
+ }
+
+ class tester_svm_c_linear_dcd : public tester
+ {
+ public:
+ tester_svm_c_linear_dcd (
+ ) :
+ tester ("test_svm_c_linear_dcd",
+ "Runs tests on the svm_c_linear_dcd_trainer.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test_normal();
+ print_spinner();
+ test_normal_no_bias();
+ print_spinner();
+ test_sparse();
+ print_spinner();
+ test_normal_force_last_weight(false,false);
+ print_spinner();
+ test_normal_force_last_weight(false,true);
+ print_spinner();
+ test_normal_force_last_weight(true,false);
+ print_spinner();
+ test_normal_force_last_weight(true,true);
+ print_spinner();
+ test_normal_1_sample(+1);
+ print_spinner();
+ test_normal_1_sample(-1);
+ print_spinner();
+ test_sparse_1_sample(+1);
+ print_spinner();
+ test_sparse_1_sample(-1);
+ print_spinner();
+
+ test_l2_version();
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/svm_multiclass_linear.cpp b/ml/dlib/dlib/test/svm_multiclass_linear.cpp
new file mode 100644
index 000000000..e01d48892
--- /dev/null
+++ b/ml/dlib/dlib/test/svm_multiclass_linear.cpp
@@ -0,0 +1,226 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/data_io.h>
+#include "create_iris_datafile.h"
+#include <vector>
+#include <map>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.svm_multiclass_trainer");
+
+
+ class test_svm_multiclass_trainer : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ test_svm_multiclass_trainer (
+ ) :
+ tester (
+ "test_svm_multiclass_trainer", // the command line argument name for this test
+ "Run tests on the svm_multiclass_linear_trainer stuff.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+
+ void test_prior ()
+ {
+ print_spinner();
+ typedef matrix<double,4,1> sample_type;
+ typedef linear_kernel<sample_type> kernel_type;
+
+ std::vector<sample_type> samples;
+ std::vector<int> labels;
+
+ for (int i = 0; i < 4; ++i)
+ {
+ if (i==2)
+ ++i;
+ for (int iter = 0; iter < 5; ++iter)
+ {
+ sample_type samp;
+ samp = 0;
+ samp(i) = 1;
+ samples.push_back(samp);
+ labels.push_back(i);
+ }
+ }
+
+
+ svm_multiclass_linear_trainer<kernel_type,int> trainer;
+
+ multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
+
+ //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
+ //cout << df.weights << endl;
+ //cout << df.b << endl;
+
+ std::vector<sample_type> samples2;
+ std::vector<int> labels2;
+ int i = 2;
+ for (int iter = 0; iter < 5; ++iter)
+ {
+ sample_type samp;
+ samp = 0;
+ samp(i) = 1;
+ samples2.push_back(samp);
+ labels2.push_back(i);
+ samples.push_back(samp);
+ labels.push_back(i);
+ }
+
+ trainer.set_prior(df);
+ trainer.set_c(0.1);
+ df = trainer.train(samples2, labels2);
+
+ matrix<double> res = test_multiclass_decision_function(df, samples, labels);
+ dlog << LINFO << "test: \n" << res;
+ dlog << LINFO << df.weights;
+ dlog << LINFO << df.b;
+ DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
+ }
+
+ void test_prior_sparse ()
+ {
+ print_spinner();
+ typedef std::map<unsigned long,double> sample_type;
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+
+ std::vector<sample_type> samples;
+ std::vector<int> labels;
+
+ for (int i = 0; i < 4; ++i)
+ {
+ if (i==2)
+ ++i;
+ for (int iter = 0; iter < 5; ++iter)
+ {
+ sample_type samp;
+ samp[i] = 1;
+ samples.push_back(samp);
+ labels.push_back(i);
+ }
+ }
+
+
+ svm_multiclass_linear_trainer<kernel_type,int> trainer;
+
+ multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
+
+ //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
+ //cout << df.weights << endl;
+ //cout << df.b << endl;
+
+ std::vector<sample_type> samples2;
+ std::vector<int> labels2;
+ int i = 2;
+ for (int iter = 0; iter < 5; ++iter)
+ {
+ sample_type samp;
+ samp[i] = 1;
+ samp[i+10] = 1;
+ samples2.push_back(samp);
+ labels2.push_back(i);
+ samples.push_back(samp);
+ labels.push_back(i);
+ }
+
+ trainer.set_prior(df);
+ trainer.set_c(0.1);
+ df = trainer.train(samples2, labels2);
+
+ matrix<double> res = test_multiclass_decision_function(df, samples, labels);
+ dlog << LINFO << "test: \n" << res;
+ dlog << LINFO << df.weights;
+ dlog << LINFO << df.b;
+ DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
+ }
+
+ template <typename sample_type>
+ void run_test()
+ {
+ print_spinner();
+
+ typedef typename sample_type::value_type::second_type scalar_type;
+
+ std::vector<sample_type> samples;
+ std::vector<scalar_type> labels;
+
+ load_libsvm_formatted_data("iris.scale",samples, labels);
+
+ DLIB_TEST(samples.size() == 150);
+ DLIB_TEST(labels.size() == 150);
+
+ typedef sparse_linear_kernel<sample_type> kernel_type;
+ svm_multiclass_linear_trainer<kernel_type> trainer;
+ trainer.set_c(100);
+ trainer.set_epsilon(0.000001);
+
+ randomize_samples(samples, labels);
+ matrix<double> cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4);
+
+ dlog << LINFO << "confusion matrix: \n" << cv;
+ const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
+ dlog << LINFO << "cv accuracy: " << cv_accuracy;
+ DLIB_TEST(cv_accuracy > 0.97);
+
+
+
+
+ {
+ print_spinner();
+ typedef matrix<scalar_type,0,1> dsample_type;
+ std::vector<dsample_type> dsamples = sparse_to_dense(samples);
+ DLIB_TEST(dsamples.size() == 150);
+
+ typedef linear_kernel<dsample_type> kernel_type;
+ svm_multiclass_linear_trainer<kernel_type> trainer;
+ trainer.set_c(100);
+
+ cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4);
+
+ dlog << LINFO << "dense confusion matrix: \n" << cv;
+ const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
+ dlog << LINFO << "dense cv accuracy: " << cv_accuracy;
+ DLIB_TEST(cv_accuracy > 0.97);
+ }
+
+ }
+
+
+
+
+ void perform_test (
+ )
+ {
+ print_spinner();
+ create_iris_datafile();
+
+ run_test<std::map<unsigned int, double> >();
+ run_test<std::map<unsigned int, float> >();
+ run_test<std::vector<std::pair<unsigned int, float> > >();
+ run_test<std::vector<std::pair<unsigned long, double> > >();
+
+ test_prior();
+ test_prior_sparse();
+ }
+ };
+
+ test_svm_multiclass_trainer a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/svm_struct.cpp b/ml/dlib/dlib/test/svm_struct.cpp
new file mode 100644
index 000000000..00208c48d
--- /dev/null
+++ b/ml/dlib/dlib/test/svm_struct.cpp
@@ -0,0 +1,641 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/svm_threaded.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.svm_struct");
+
+
+ template <
+ typename matrix_type,
+ typename sample_type,
+ typename label_type
+ >
+ class test_multiclass_svm_problem : public structural_svm_problem_threaded<matrix_type,
+ std::vector<std::pair<unsigned long,typename matrix_type::type> > >
+ {
+
+ public:
+ typedef typename matrix_type::type scalar_type;
+ typedef std::vector<std::pair<unsigned long,scalar_type> > feature_vector_type;
+
+ test_multiclass_svm_problem (
+ const std::vector<sample_type>& samples_,
+ const std::vector<label_type>& labels_
+ ) :
+ structural_svm_problem_threaded<matrix_type,
+ std::vector<std::pair<unsigned long,typename matrix_type::type> > >(2),
+ samples(samples_),
+ labels(labels_),
+ dims(10+1) // +1 for the bias
+ {
+ for (int i = 0; i < 10; ++i)
+ {
+ distinct_labels.push_back(i);
+ }
+ }
+
+ virtual long get_num_dimensions (
+ ) const
+ {
+ return dims*10;
+ }
+
+ virtual long get_num_samples (
+ ) const
+ {
+ return static_cast<long>(samples.size());
+ }
+
+ virtual void get_truth_joint_feature_vector (
+ long idx,
+ feature_vector_type& psi
+ ) const
+ {
+ assign(psi, samples[idx]);
+ // Add a constant -1 to account for the bias term.
+ psi.push_back(std::make_pair(dims-1,static_cast<scalar_type>(-1)));
+
+ // Find which distinct label goes with this psi.
+ const long label_idx = index_of_max(mat(distinct_labels) == labels[idx]);
+
+ offset_feature_vector(psi, dims*label_idx);
+ }
+
+ virtual void separation_oracle (
+ const long idx,
+ const matrix_type& current_solution,
+ scalar_type& loss,
+ feature_vector_type& psi
+ ) const
+ {
+ scalar_type best_val = -std::numeric_limits<scalar_type>::infinity();
+ unsigned long best_idx = 0;
+
+ // Figure out which label is the best. That is, what label maximizes
+ // LOSS(idx,y) + F(x,y). Note that y in this case is given by distinct_labels[i].
+ for (unsigned long i = 0; i < distinct_labels.size(); ++i)
+ {
+ // Compute the F(x,y) part:
+ // perform: temp == dot(relevant part of current solution, samples[idx]) - current_bias
+ scalar_type temp = dot(rowm(current_solution, range(i*dims, (i+1)*dims-2)), samples[idx]) - current_solution((i+1)*dims-1);
+
+ // Add the LOSS(idx,y) part:
+ if (labels[idx] != distinct_labels[i])
+ temp += 1;
+
+ // Now temp == LOSS(idx,y) + F(x,y). Check if it is the biggest we have seen.
+ if (temp > best_val)
+ {
+ best_val = temp;
+ best_idx = i;
+ }
+ }
+
+ assign(psi, samples[idx]);
+ // add a constant -1 to account for the bias term
+ psi.push_back(std::make_pair(dims-1,static_cast<scalar_type>(-1)));
+
+ offset_feature_vector(psi, dims*best_idx);
+
+ if (distinct_labels[best_idx] == labels[idx])
+ loss = 0;
+ else
+ loss = 1;
+ }
+
+ private:
+
+ void offset_feature_vector (
+ feature_vector_type& sample,
+ const unsigned long val
+ ) const
+ {
+ if (val != 0)
+ {
+ for (typename feature_vector_type::iterator i = sample.begin(); i != sample.end(); ++i)
+ {
+ i->first += val;
+ }
+ }
+ }
+
+
+ const std::vector<sample_type>& samples;
+ const std::vector<label_type>& labels;
+ std::vector<label_type> distinct_labels;
+ const long dims;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class test_svm_multiclass_linear_trainer2
+ {
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+
+ test_svm_multiclass_linear_trainer2 (
+ ) :
+ C(10),
+ eps(1e-4),
+ verbose(false)
+ {
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ scalar_type svm_objective = 0;
+ return train(all_samples, all_labels, svm_objective);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type test_svm_multiclass_linear_trainer2::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type weights;
+ std::vector<sample_type> samples1(all_samples.begin(), all_samples.begin()+all_samples.size()/2);
+ std::vector<sample_type> samples2(all_samples.begin()+all_samples.size()/2, all_samples.end());
+
+ std::vector<label_type> labels1(all_labels.begin(), all_labels.begin()+all_labels.size()/2);
+ std::vector<label_type> labels2(all_labels.begin()+all_labels.size()/2, all_labels.end());
+ test_multiclass_svm_problem<w_type, sample_type, label_type> problem1(samples1, labels1);
+ test_multiclass_svm_problem<w_type, sample_type, label_type> problem2(samples2, labels2);
+ problem1.set_max_cache_size(3);
+ problem2.set_max_cache_size(0);
+
+ svm_struct_processing_node node1(problem1, 12345, 3);
+ svm_struct_processing_node node2(problem2, 12346, 0);
+
+ solver.set_inactive_plane_threshold(50);
+ solver.set_subproblem_epsilon(1e-4);
+
+ svm_struct_controller_node controller;
+ controller.set_c(C);
+ controller.set_epsilon(eps);
+ if (verbose)
+ controller.be_verbose();
+ controller.add_processing_node("127.0.0.1", 12345);
+ controller.add_processing_node("localhost:12346");
+ svm_objective = controller(solver, weights);
+
+
+
+ trained_function_type df;
+
+ const long dims = max_index_plus_one(all_samples);
+ df.labels = select_all_distinct_labels(all_labels);
+ df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
+ df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
+ return df;
+ }
+
+ private:
+ scalar_type C;
+ scalar_type eps;
+ bool verbose;
+ mutable oca solver;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class test_svm_multiclass_linear_trainer3
+ {
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+
+ test_svm_multiclass_linear_trainer3 (
+ ) :
+ C(10),
+ eps(1e-4),
+ verbose(false)
+ {
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ scalar_type svm_objective = 0;
+ return train(all_samples, all_labels, svm_objective);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type test_svm_multiclass_linear_trainer3::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type weights;
+ test_multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels);
+ problem.set_max_cache_size(0);
+
+ problem.set_c(C);
+ problem.set_epsilon(eps);
+
+ if (verbose)
+ problem.be_verbose();
+
+ solver.set_inactive_plane_threshold(50);
+ solver.set_subproblem_epsilon(1e-4);
+ svm_objective = solver(problem, weights);
+
+
+ trained_function_type df;
+
+ const long dims = max_index_plus_one(all_samples);
+ df.labels = select_all_distinct_labels(all_labels);
+ df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
+ df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
+ return df;
+ }
+
+ private:
+ scalar_type C;
+ scalar_type eps;
+ bool verbose;
+ mutable oca solver;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class test_svm_multiclass_linear_trainer4
+ {
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+
+ test_svm_multiclass_linear_trainer4 (
+ ) :
+ C(10),
+ eps(1e-4),
+ verbose(false)
+ {
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ scalar_type svm_objective = 0;
+ return train(all_samples, all_labels, svm_objective);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type test_svm_multiclass_linear_trainer4::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type weights;
+ test_multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels);
+ problem.set_max_cache_size(3);
+
+ problem.set_c(C);
+ problem.set_epsilon(eps);
+
+ if (verbose)
+ problem.be_verbose();
+
+ solver.set_inactive_plane_threshold(50);
+ solver.set_subproblem_epsilon(1e-4);
+ svm_objective = solver(problem, weights);
+
+
+ trained_function_type df;
+
+ const long dims = max_index_plus_one(all_samples);
+ df.labels = select_all_distinct_labels(all_labels);
+ df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
+ df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
+ return df;
+ }
+
+ private:
+ scalar_type C;
+ scalar_type eps;
+ bool verbose;
+ mutable oca solver;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K,
+ typename label_type_ = typename K::scalar_type
+ >
+ class test_svm_multiclass_linear_trainer5
+ {
+ public:
+ typedef label_type_ label_type;
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+
+ typedef multiclass_linear_decision_function<kernel_type, label_type> trained_function_type;
+
+
+ test_svm_multiclass_linear_trainer5 (
+ ) :
+ C(10),
+ eps(1e-4),
+ verbose(false)
+ {
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels
+ ) const
+ {
+ scalar_type svm_objective = 0;
+ return train(all_samples, all_labels, svm_objective);
+ }
+
+ trained_function_type train (
+ const std::vector<sample_type>& all_samples,
+ const std::vector<label_type>& all_labels,
+ scalar_type& svm_objective
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(all_samples,all_labels),
+ "\t trained_function_type test_svm_multiclass_linear_trainer5::train(all_samples,all_labels)"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t all_samples.size(): " << all_samples.size()
+ << "\n\t all_labels.size(): " << all_labels.size()
+ );
+
+ typedef matrix<scalar_type,0,1> w_type;
+ w_type weights;
+ const long dims = max_index_plus_one(all_samples);
+ trained_function_type df;
+ df.labels = select_all_distinct_labels(all_labels);
+ multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels, df.labels, dims, 4);
+ problem.set_max_cache_size(3);
+
+ problem.set_c(C);
+ problem.set_epsilon(eps);
+
+ if (verbose)
+ problem.be_verbose();
+
+ solver.set_inactive_plane_threshold(50);
+ solver.set_subproblem_epsilon(1e-4);
+ svm_objective = solver(problem, weights);
+
+
+
+ df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
+ df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
+ return df;
+ }
+
+ private:
+ scalar_type C;
+ scalar_type eps;
+ bool verbose;
+ mutable oca solver;
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ typedef matrix<double,10,1> sample_type;
+ typedef double scalar_type;
+
+ void make_dataset (
+ std::vector<sample_type>& samples,
+ std::vector<scalar_type>& labels,
+ int num,
+ dlib::rand& rnd
+ )
+ {
+ samples.clear();
+ labels.clear();
+ for (int i = 0; i < 10; ++i)
+ {
+ for (int j = 0; j < num; ++j)
+ {
+ sample_type samp;
+ samp = 0;
+ samp(i) = 10*rnd.get_random_double()+1;
+
+ samples.push_back(samp);
+ labels.push_back(i);
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_svm_struct : public tester
+ {
+ public:
+ test_svm_struct (
+ ) :
+ tester ("test_svm_struct",
+ "Runs tests on the structural svm components.")
+ {}
+
+ void run_test (
+ const std::vector<sample_type>& samples,
+ const std::vector<scalar_type>& labels,
+ const double true_obj
+ )
+ {
+ typedef linear_kernel<sample_type> kernel_type;
+ svm_multiclass_linear_trainer<kernel_type> trainer1;
+ test_svm_multiclass_linear_trainer2<kernel_type> trainer2;
+ test_svm_multiclass_linear_trainer3<kernel_type> trainer3;
+ test_svm_multiclass_linear_trainer4<kernel_type> trainer4;
+ test_svm_multiclass_linear_trainer5<kernel_type> trainer5;
+
+ trainer1.set_epsilon(1e-4);
+ trainer1.set_c(10);
+
+
+ multiclass_linear_decision_function<kernel_type,double> df1, df2, df3, df4, df5;
+ double obj1, obj2, obj3, obj4, obj5;
+
+ // Solve a multiclass SVM a whole bunch of different ways and make sure
+ // they all give the same answer.
+ print_spinner();
+ df1 = trainer1.train(samples, labels, obj1);
+ print_spinner();
+ df2 = trainer2.train(samples, labels, obj2);
+ print_spinner();
+ df3 = trainer3.train(samples, labels, obj3);
+ print_spinner();
+ df4 = trainer4.train(samples, labels, obj4);
+ print_spinner();
+ df5 = trainer5.train(samples, labels, obj5);
+ print_spinner();
+
+ dlog << LINFO << "obj1: "<< obj1;
+ dlog << LINFO << "obj2: "<< obj2;
+ dlog << LINFO << "obj3: "<< obj3;
+ dlog << LINFO << "obj4: "<< obj4;
+ dlog << LINFO << "obj5: "<< obj5;
+ DLIB_TEST(std::abs(obj1 - obj2) < 1e-2);
+ DLIB_TEST(std::abs(obj1 - obj3) < 1e-2);
+ DLIB_TEST(std::abs(obj1 - obj4) < 1e-2);
+ DLIB_TEST(std::abs(obj1 - obj5) < 1e-2);
+ DLIB_TEST(std::abs(obj1 - true_obj) < 1e-2);
+ DLIB_TEST(std::abs(obj2 - true_obj) < 1e-2);
+ DLIB_TEST(std::abs(obj3 - true_obj) < 1e-2);
+ DLIB_TEST(std::abs(obj4 - true_obj) < 1e-2);
+ DLIB_TEST(std::abs(obj5 - true_obj) < 1e-2);
+
+ dlog << LINFO << "weight error: "<< max(abs(df1.weights - df2.weights));
+ dlog << LINFO << "weight error: "<< max(abs(df1.weights - df3.weights));
+ dlog << LINFO << "weight error: "<< max(abs(df1.weights - df4.weights));
+ dlog << LINFO << "weight error: "<< max(abs(df1.weights - df5.weights));
+
+ DLIB_TEST(max(abs(df1.weights - df2.weights)) < 1e-2);
+ DLIB_TEST(max(abs(df1.weights - df3.weights)) < 1e-2);
+ DLIB_TEST(max(abs(df1.weights - df4.weights)) < 1e-2);
+ DLIB_TEST(max(abs(df1.weights - df5.weights)) < 1e-2);
+
+ dlog << LINFO << "b error: "<< max(abs(df1.b - df2.b));
+ dlog << LINFO << "b error: "<< max(abs(df1.b - df3.b));
+ dlog << LINFO << "b error: "<< max(abs(df1.b - df4.b));
+ dlog << LINFO << "b error: "<< max(abs(df1.b - df5.b));
+ DLIB_TEST(max(abs(df1.b - df2.b)) < 1e-2);
+ DLIB_TEST(max(abs(df1.b - df3.b)) < 1e-2);
+ DLIB_TEST(max(abs(df1.b - df4.b)) < 1e-2);
+ DLIB_TEST(max(abs(df1.b - df5.b)) < 1e-2);
+
+ matrix<double> res = test_multiclass_decision_function(df1, samples, labels);
+ dlog << LINFO << res;
+ dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res);
+ DLIB_TEST(sum(diag(res)) == samples.size());
+
+ res = test_multiclass_decision_function(df2, samples, labels);
+ dlog << LINFO << res;
+ dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res);
+ DLIB_TEST(sum(diag(res)) == samples.size());
+
+ res = test_multiclass_decision_function(df3, samples, labels);
+ dlog << LINFO << res;
+ dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res);
+ DLIB_TEST(sum(diag(res)) == samples.size());
+
+ res = test_multiclass_decision_function(df4, samples, labels);
+ dlog << LINFO << res;
+ dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res);
+ DLIB_TEST(sum(diag(res)) == samples.size());
+
+ res = test_multiclass_decision_function(df5, samples, labels);
+ dlog << LINFO << res;
+ dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res);
+ DLIB_TEST(sum(diag(res)) == samples.size());
+ }
+
+ void perform_test (
+ )
+ {
+ std::vector<sample_type> samples;
+ std::vector<scalar_type> labels;
+
+ dlib::rand rnd;
+
+ dlog << LINFO << "test with 100 samples per class";
+ make_dataset(samples, labels, 100, rnd);
+ run_test(samples, labels, 1.155);
+
+ dlog << LINFO << "test with 1 sample per class";
+ make_dataset(samples, labels, 1, rnd);
+ run_test(samples, labels, 0.251);
+
+ dlog << LINFO << "test with 2 sample per class";
+ make_dataset(samples, labels, 2, rnd);
+ run_test(samples, labels, 0.444);
+ }
+ } a;
+
+
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/svr_linear_trainer.cpp b/ml/dlib/dlib/test/svr_linear_trainer.cpp
new file mode 100644
index 000000000..ca1a5442f
--- /dev/null
+++ b/ml/dlib/dlib/test/svr_linear_trainer.cpp
@@ -0,0 +1,161 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/matrix.h>
+#include <sstream>
+#include <string>
+#include <ctime>
+#include <vector>
+#include <dlib/statistics.h>
+
+#include "tester.h"
+#include <dlib/svm.h>
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.svr_linear_trainer");
+
+ typedef matrix<double, 0, 1> sample_type;
+ typedef std::vector<std::pair<unsigned int, double> > sparse_sample_type;
+
+// ----------------------------------------------------------------------------------------
+
+ double sinc(double x)
+ {
+ if (x == 0)
+ return 1;
+ return sin(x)/x;
+ }
+
+ template <typename scalar_type>
+ void test1()
+ {
+ typedef matrix<scalar_type,0,1> sample_type;
+
+ typedef radial_basis_kernel<sample_type> kernel_type;
+
+ print_spinner();
+
+ std::vector<sample_type> samples;
+ std::vector<scalar_type> targets;
+
+ // The first thing we do is pick a few training points from the sinc() function.
+ sample_type m(1);
+ for (scalar_type x = -10; x <= 4; x += 1)
+ {
+ m(0) = x;
+
+ samples.push_back(m);
+ targets.push_back(sinc(x)+1.1);
+ }
+
+ randomize_samples(samples, targets);
+
+ empirical_kernel_map<kernel_type> ekm;
+ ekm.load(kernel_type(0.1), samples);
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ samples[i] = ekm.project(samples[i]);
+
+ svr_linear_trainer<linear_kernel<sample_type> > linear_trainer;
+ linear_trainer.set_epsilon(0.0001);
+ linear_trainer.set_c(30);
+ linear_trainer.set_epsilon_insensitivity(0.001);
+
+ matrix<double> res = cross_validate_regression_trainer(linear_trainer, samples, targets, 5);
+ dlog << LINFO << "MSE and R-Squared: "<< res;
+ DLIB_TEST(res(0) < 1e-4);
+ DLIB_TEST(res(1) > 0.99);
+
+ dlib::rand rnd;
+
+ samples.clear();
+ targets.clear();
+ std::vector<scalar_type> noisefree_targets;
+ for (scalar_type x = 0; x <= 5; x += 0.1)
+ {
+ m(0) = x;
+ samples.push_back(matrix_cast<scalar_type>(linpiece(m, linspace(0,5,20))));
+ targets.push_back(x*x + rnd.get_random_gaussian());
+ noisefree_targets.push_back(x*x);
+ }
+ linear_trainer.set_learns_nonnegative_weights(true);
+ linear_trainer.set_epsilon_insensitivity(1.0);
+ decision_function<linear_kernel<sample_type> > df2 = linear_trainer.train(samples, targets);
+
+ print_spinner();
+ res = test_regression_function(df2, samples, noisefree_targets);
+ dlog << LINFO << "MSE and R-Squared: "<< res;
+ DLIB_TEST(res(0) < 0.15);
+ DLIB_TEST(res(1) > 0.98);
+ DLIB_TEST(df2.basis_vectors.size()==1);
+ DLIB_TEST(max(df2.basis_vectors(0)) >= 0);
+
+ linear_trainer.force_last_weight_to_1(true);
+ df2 = linear_trainer.train(samples, targets);
+ DLIB_TEST(std::abs(df2.basis_vectors(0)(samples[0].size()-1) - 1.0) < 1e-14);
+
+ res = test_regression_function(df2, samples, noisefree_targets);
+ dlog << LINFO << "MSE and R-Squared: "<< res;
+ DLIB_TEST(res(0) < 0.20);
+ DLIB_TEST(res(1) > 0.98);
+
+
+ // convert into sparse vectors and try it out
+ typedef std::vector<std::pair<unsigned long, scalar_type> > sparse_samp;
+ std::vector<sparse_samp> ssamples;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ sparse_samp s;
+ for (long j = 0; j < samples[i].size(); ++j)
+ s.push_back(make_pair(j,samples[i](j)));
+ ssamples.push_back(s);
+ }
+
+ svr_linear_trainer<sparse_linear_kernel<sparse_samp> > strainer;
+ strainer.set_learns_nonnegative_weights(true);
+ strainer.set_epsilon_insensitivity(1.0);
+ strainer.set_c(30);
+ decision_function<sparse_linear_kernel<sparse_samp> > df;
+ df = strainer.train(ssamples, targets);
+ res = test_regression_function(df, ssamples, noisefree_targets);
+ dlog << LINFO << "MSE and R-Squared: "<< res;
+ DLIB_TEST(res(0) < 0.15);
+ DLIB_TEST(res(1) > 0.98);
+ DLIB_TEST(df2.basis_vectors.size()==1);
+ DLIB_TEST(max(sparse_to_dense(df2.basis_vectors(0))) >= 0);
+ }
+
+
+// ----------------------------------------------------------------------------------------
+
+ class tester_svr_linear_trainer : public tester
+ {
+ public:
+ tester_svr_linear_trainer (
+ ) :
+ tester ("test_svr_linear_trainer",
+ "Runs tests on the svr_linear_trainer.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "TEST double";
+ test1<double>();
+ dlog << LINFO << "TEST float";
+ test1<float>();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/symmetric_matrix_cache.cpp b/ml/dlib/dlib/test/symmetric_matrix_cache.cpp
new file mode 100644
index 000000000..6d93a4daa
--- /dev/null
+++ b/ml/dlib/dlib/test/symmetric_matrix_cache.cpp
@@ -0,0 +1,212 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/matrix.h>
+#include <dlib/rand.h>
+#include <vector>
+#include <sstream>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.symmetric_matrix_cache");
+
+
+ class test_symmetric_matrix_cache : public tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a unit test. When it is constructed
+ it adds itself into the testing framework.
+ !*/
+ public:
+ test_symmetric_matrix_cache (
+ ) :
+ tester (
+ "test_symmetric_matrix_cache", // the command line argument name for this test
+ "Run tests on the symmetric_matrix_cache function.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ dlib::rand rnd;
+
+ // -----------------------------------
+
+ template <typename EXP1, typename EXP2>
+ void test_colm_exp (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ for (long i = 0; i < m1.nc(); ++i)
+ {
+
+ typename colm_exp<EXP1>::type c1 = colm(m1,i);
+ typename colm_exp<EXP2>::type c2 = colm(m2,i);
+
+ DLIB_TEST(equal(c1 , c2));
+ DLIB_TEST(equal(colm(m1,i) , c2));
+ DLIB_TEST(equal(c1 , colm(m2,i)));
+ DLIB_TEST(equal(colm(m1,i) , colm(m2,i)));
+ }
+
+
+ // Get a bunch of columns at once to test out the reference
+ // counting and automatic cache expansion built into the symmetric_matrix_cache.
+ // This test verifies that, for example, getting column 3 doesn't stomp on
+ // any of the previous columns.
+ typename colm_exp<EXP1>::type c1_0 = colm(m1,0);
+ typename colm_exp<EXP1>::type c1_1 = colm(m1,1);
+ typename colm_exp<EXP1>::type c1_2 = colm(m1,2);
+ typename colm_exp<EXP1>::type c1_3 = colm(m1,3);
+ typename colm_exp<EXP1>::type c1_4 = colm(m1,4);
+ typename colm_exp<EXP1>::type c1_5 = colm(m1,5);
+
+ typename colm_exp<EXP2>::type c2_0 = colm(m2,0);
+ typename colm_exp<EXP2>::type c2_1 = colm(m2,1);
+ typename colm_exp<EXP2>::type c2_2 = colm(m2,2);
+ typename colm_exp<EXP2>::type c2_3 = colm(m2,3);
+ typename colm_exp<EXP2>::type c2_4 = colm(m2,4);
+ typename colm_exp<EXP2>::type c2_5 = colm(m2,5);
+
+ DLIB_TEST(equal(c1_0, c2_0));
+ DLIB_TEST(equal(c1_1, c2_1));
+ DLIB_TEST(equal(c1_2, c2_2));
+ DLIB_TEST(equal(c1_3, c2_3));
+ DLIB_TEST(equal(c1_4, c2_4));
+ DLIB_TEST(equal(c1_5, c2_5));
+ }
+
+ // -----------------------------------
+
+ template <typename EXP1, typename EXP2>
+ void test_rowm_exp (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+ for (long i = 0; i < m1.nc(); ++i)
+ {
+
+ typename rowm_exp<EXP1>::type r1 = rowm(m1,i);
+ typename rowm_exp<EXP2>::type r2 = rowm(m2,i);
+
+ DLIB_TEST(equal(r1 , r2));
+ DLIB_TEST(equal(rowm(m1,i) , r2));
+ DLIB_TEST(equal(r1 , rowm(m2,i)));
+ DLIB_TEST(equal(rowm(m1,i) , rowm(m2,i)));
+ }
+
+
+ // Get a bunch of rows at once to test out the reference
+ // counting and automatic cache expansion built into the symmetric_matrix_cache.
+ // This test verifies that, for example, getting row 3 doesn't stomp on
+ // any of the previous rows.
+ typename rowm_exp<EXP1>::type r1_0 = rowm(m1,0);
+ typename rowm_exp<EXP1>::type r1_1 = rowm(m1,1);
+ typename rowm_exp<EXP1>::type r1_2 = rowm(m1,2);
+ typename rowm_exp<EXP1>::type r1_3 = rowm(m1,3);
+ typename rowm_exp<EXP1>::type r1_4 = rowm(m1,4);
+ typename rowm_exp<EXP1>::type r1_5 = rowm(m1,5);
+
+ typename rowm_exp<EXP2>::type r2_0 = rowm(m2,0);
+ typename rowm_exp<EXP2>::type r2_1 = rowm(m2,1);
+ typename rowm_exp<EXP2>::type r2_2 = rowm(m2,2);
+ typename rowm_exp<EXP2>::type r2_3 = rowm(m2,3);
+ typename rowm_exp<EXP2>::type r2_4 = rowm(m2,4);
+ typename rowm_exp<EXP2>::type r2_5 = rowm(m2,5);
+
+ DLIB_TEST(equal(r1_0, r2_0));
+ DLIB_TEST(equal(r1_1, r2_1));
+ DLIB_TEST(equal(r1_2, r2_2));
+ DLIB_TEST(equal(r1_3, r2_3));
+ DLIB_TEST(equal(r1_4, r2_4));
+ DLIB_TEST(equal(r1_5, r2_5));
+ }
+
+ // -----------------------------------
+
+ template <typename EXP1, typename EXP2>
+ void test_diag_exp (
+ const matrix_exp<EXP1>& m1,
+ const matrix_exp<EXP2>& m2
+ )
+ {
+
+ typename diag_exp<EXP1>::type c1 = diag(m1);
+ typename diag_exp<EXP2>::type c2 = diag(m2);
+
+ DLIB_TEST(equal(c1 , c2));
+ DLIB_TEST(equal(diag(m1) , c2));
+ DLIB_TEST(equal(c1 , diag(m2)));
+ DLIB_TEST(equal(diag(m1) , diag(m2)));
+ }
+
+ // -----------------------------------
+
+ void test_stuff (
+ long csize
+ )
+ {
+ print_spinner();
+ dlog << LINFO << "csize: "<< csize;
+ matrix<double> m = randm(10,10,rnd);
+
+ m = make_symmetric(m);
+
+ DLIB_TEST(equal(symmetric_matrix_cache<float>(m, csize), matrix_cast<float>(m)));
+ DLIB_TEST(equal(symmetric_matrix_cache<double>(m, csize), matrix_cast<double>(m)));
+
+ dlog << LINFO << "test colm/rowm";
+
+
+ for (long i = 0; i < m.nr(); ++i)
+ {
+ DLIB_TEST(equal(colm(symmetric_matrix_cache<float>(m, csize),i), colm(matrix_cast<float>(m),i)));
+ DLIB_TEST(equal(rowm(symmetric_matrix_cache<float>(m, csize),i), rowm(matrix_cast<float>(m),i)));
+ // things are supposed to be symmetric
+ DLIB_TEST(equal(colm(symmetric_matrix_cache<float>(m, csize),i), trans(rowm(matrix_cast<float>(m),i))));
+ DLIB_TEST(equal(rowm(symmetric_matrix_cache<float>(m, csize),i), trans(colm(matrix_cast<float>(m),i))));
+ }
+
+ dlog << LINFO << "test diag";
+ DLIB_TEST(equal(diag(symmetric_matrix_cache<float>(m,csize)), diag(matrix_cast<float>(m))));
+
+ test_colm_exp(symmetric_matrix_cache<float>(m,csize), matrix_cast<float>(m));
+ test_rowm_exp(symmetric_matrix_cache<float>(m,csize), matrix_cast<float>(m));
+ test_diag_exp(symmetric_matrix_cache<float>(m,csize), matrix_cast<float>(m));
+
+ test_colm_exp(tmp(symmetric_matrix_cache<float>(m,csize)), tmp(matrix_cast<float>(m)));
+ test_rowm_exp(symmetric_matrix_cache<float>(m,csize), tmp(matrix_cast<float>(m)));
+ test_diag_exp(tmp(symmetric_matrix_cache<float>(m,csize)), tmp(matrix_cast<float>(m)));
+ }
+
+
+ void perform_test (
+ )
+ {
+
+ for (int itr = 0; itr < 5; ++itr)
+ {
+ test_stuff(0);
+ test_stuff(1);
+ test_stuff(2);
+ }
+
+ }
+ };
+
+ // Create an instance of this object. Doing this causes this test
+ // to be automatically inserted into the testing framework whenever this cpp file
+ // is linked into the project. Note that since we are inside an unnamed-namespace
+ // we won't get any linker errors about the symbol a being defined multiple times.
+ test_symmetric_matrix_cache a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/tester.cpp b/ml/dlib/dlib/test/tester.cpp
new file mode 100644
index 000000000..2fb4d41ac
--- /dev/null
+++ b/ml/dlib/dlib/test/tester.cpp
@@ -0,0 +1,175 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <string>
+#include "tester.h"
+#include <cstdlib>
+#include <dlib/threads.h>
+
+namespace test
+{
+// -----------------------------------------------------------------------------
+
+ bool be_verbose = true;
+
+// -----------------------------------------------------------------------------
+
+ static dlib::mutex spinner_mutex;
+ static dlib::mutex test_count_mutex;
+ dlib::uint64 test_count = 0;
+
+// -----------------------------------------------------------------------------
+
+ dlib::uint64 number_of_testing_statements_executed (
+ )
+ {
+ dlib::auto_mutex lock(test_count_mutex);
+ return test_count;
+ }
+
+ void increment_test_count (
+ )
+ {
+ test_count_mutex.lock();
+ ++test_count;
+ test_count_mutex.unlock();
+ }
+
+// -----------------------------------------------------------------------------
+
+ void check_test (
+ bool _exp,
+ long line,
+ const char* file,
+ const char* _exp_str
+ )
+ {
+ test_count_mutex.lock();
+ ++test_count;
+ test_count_mutex.unlock();
+ if ( !(_exp) )
+ {
+ std::ostringstream dlib_o_out;
+ dlib_o_out << "\n\nError occurred at line " << line << ".\n";
+ dlib_o_out << "Error occurred in file " << file << ".\n";
+ dlib_o_out << "Failing expression was " << _exp_str << ".\n";
+ throw dlib::error(dlib_o_out.str());
+ }
+ }
+
+// -----------------------------------------------------------------------------
+
+ map_of_testers& testers (
+ )
+ {
+ static map_of_testers t;
+ return t;
+ }
+
+// -----------------------------------------------------------------------------
+
+ tester::
+ tester (
+ const std::string& switch_name_x,
+ const std::string& description_x,
+ unsigned long num_of_args_x
+ ) :
+ switch_name(switch_name_x),
+ description_(description_x),
+ num_of_args_(num_of_args_x)
+ {
+ using namespace std;
+ if (testers().is_in_domain(switch_name))
+ {
+ cerr << "ERROR: More than one tester has been defined with the switch '" << switch_name << "'." << endl;
+ exit(1);
+ }
+
+ string temp(switch_name);
+ tester* t = this;
+ testers().add(temp,t);
+ }
+
+// -----------------------------------------------------------------------------
+
+ const std::string& tester::
+ cmd_line_switch (
+ ) const
+ {
+ return switch_name;
+ }
+
+// -----------------------------------------------------------------------------
+
+ const std::string& tester::
+ description (
+ ) const
+ {
+ return description_;
+ }
+
+// -----------------------------------------------------------------------------
+
+ unsigned long tester::
+ num_of_args (
+ ) const
+ {
+ return num_of_args_;
+ }
+
+// -----------------------------------------------------------------------------
+
+ void tester::
+ perform_test (
+ )
+ {
+ }
+
+// -----------------------------------------------------------------------------
+
+ void tester::
+ perform_test (
+ const std::string&
+ )
+ {
+ }
+
+// -----------------------------------------------------------------------------
+
+ void tester::
+ perform_test (
+ const std::string&,
+ const std::string&
+ )
+ {
+ }
+
+// -----------------------------------------------------------------------------
+
+ void print_spinner (
+ )
+ {
+ if (be_verbose)
+ {
+ using namespace std;
+ dlib::auto_mutex M(spinner_mutex);
+ static int i = 0;
+ cout << "\b\b";
+ switch (i)
+ {
+ case 0: cout << '|'; break;
+ case 1: cout << '/'; break;
+ case 2: cout << '-'; break;
+ case 3: cout << '\\'; break;
+ }
+ cout << " " << flush;
+ i = (i+1)%4;
+ }
+ }
+
+// -----------------------------------------------------------------------------
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/tester.h b/ml/dlib/dlib/test/tester.h
new file mode 100644
index 000000000..e16647cf5
--- /dev/null
+++ b/ml/dlib/dlib/test/tester.h
@@ -0,0 +1,187 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TESTEr_
+#define DLIB_TESTEr_
+
+#include <iostream>
+#include <string>
+#include <dlib/map.h>
+#include <dlib/logger.h>
+#include <dlib/assert.h>
+#include <dlib/algs.h>
+#include <typeinfo>
+
+#ifdef __INTEL_COMPILER
+// ignore the bogus warning about not overloading perform_test() all the way
+#pragma warning (disable: 654)
+#endif
+
+
+#define DLIB_TEST(_exp) check_test(bool(_exp), __LINE__, __FILE__, #_exp)
+
+#define DLIB_TEST_MSG(_exp,_message) \
+ do{increment_test_count(); if ( !(_exp) ) \
+ { \
+ std::ostringstream dlib_o_out; \
+ dlib_o_out << "\n\nError occurred at line " << __LINE__ << ".\n"; \
+ dlib_o_out << "Error occurred in file " << __FILE__ << ".\n"; \
+ dlib_o_out << "Failing expression was " << #_exp << ".\n"; \
+ dlib_o_out << _message << "\n"; \
+ throw dlib::error(dlib_o_out.str()); \
+ }}while(0)
+
+namespace test
+{
+ class tester;
+ typedef dlib::map<std::string,tester*>::kernel_1a_c map_of_testers;
+
+ map_of_testers& testers (
+ );
+
+// -----------------------------------------------------------------------------
+
+ void check_test (
+ bool _exp,
+ long line,
+ const char* file,
+ const char* _exp_str
+ );
+
+// -----------------------------------------------------------------------------
+
+// This bool controls any cout statements in this program. Only print to
+// standard out if we should be verbose. The default is true
+ extern bool be_verbose;
+
+// -----------------------------------------------------------------------------
+
+ dlib::uint64 number_of_testing_statements_executed (
+ );
+ /*!
+ ensures
+ - returns the total number of DLIB_TEST and DLIB_TEST_MSG
+ statements executed since program startup.
+ !*/
+
+ void increment_test_count (
+ );
+ /*!
+ ensures
+ - increments number_of_testing_statements_executed()
+ !*/
+
+// -----------------------------------------------------------------------------
+
+ void print_spinner (
+ );
+ /*!
+ ensures
+ - reprints the spinner
+ !*/
+
+// -----------------------------------------------------------------------------
+
+ class tester
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a generic regression test.
+ !*/
+
+ public:
+
+ tester (
+ const std::string& switch_name,
+ const std::string& description_,
+ unsigned long num_of_args_ = 0
+ );
+ /*!
+ requires
+ - testers().is_in_domain(switch_name) == false
+ ensures
+ - #cmd_line_switch() == switch_name
+ - #description() == description_
+ - #num_of_args() == num_of_args_
+ - adds this tester to the testers() map.
+ !*/
+
+ virtual ~tester (
+ ){}
+
+ const std::string& cmd_line_switch (
+ ) const;
+ /*!
+ ensures
+ - returns the name of the command line switch for this tester.
+ !*/
+
+ const std::string& description (
+ ) const;
+ /*!
+ ensures
+ - returns the description of what this tester tests.
+ !*/
+
+ unsigned long num_of_args (
+ ) const;
+ /*!
+ ensures
+ - returns the number of arguments this test expects
+ !*/
+
+ virtual void perform_test (
+ );
+ /*!
+ requires
+ - is invoked when number_of_args() == 0
+ ensures
+ - performs the test and throws an exception
+ derived from std::exception if the test fails.
+ !*/
+
+ virtual void perform_test (
+ const std::string& arg
+ );
+ /*!
+ requires
+ - is invoked when number_of_args() == 1
+ ensures
+ - performs the test and throws an exception
+ derived from std::exception if the test fails.
+ !*/
+
+ virtual void perform_test (
+ const std::string& arg1,
+ const std::string& arg2
+ );
+ /*!
+ requires
+ - is invoked when number_of_args() == 2
+ ensures
+ - performs the test and throws an exception
+ derived from std::exception if the test fails.
+ !*/
+
+ private:
+
+ // ---------------------------------------------------------------------------
+ // Implementation Details
+ // ---------------------------------------------------------------------------
+
+ /*!
+ CONVENTION
+ - switch_name == cmd_line_switch()
+ - description_ == description()
+ - num_of_args_ == num_of_args()
+ - test::tester[switch_name] == this
+ !*/
+
+ const std::string switch_name;
+ const std::string description_;
+ const unsigned long num_of_args_;
+ };
+
+}
+
+#endif // DLIB_TESTEr_
+
diff --git a/ml/dlib/dlib/test/thread_pool.cpp b/ml/dlib/dlib/test/thread_pool.cpp
new file mode 100644
index 000000000..73ccb346e
--- /dev/null
+++ b/ml/dlib/dlib/test/thread_pool.cpp
@@ -0,0 +1,428 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/misc_api.h>
+#include <dlib/threads.h>
+#include <dlib/any.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.thread_pool");
+
+
+ struct some_struct : noncopyable
+ {
+ float val;
+ };
+
+ int global_var = 0;
+
+ struct add_functor
+ {
+ add_functor() { var = 1;}
+ add_functor(int v):var(v) {}
+
+ template <typename T, typename U, typename V>
+ void operator()(T a, U b, V& res)
+ {
+ dlib::sleep(20);
+ res = a + b;
+ }
+
+ void set_global_var() { global_var = 9; }
+ void set_global_var_const() const { global_var = 9; }
+
+ void set_global_var_arg1(int val) { global_var = val; }
+ void set_global_var_const_arg1(int val) const { global_var = val; }
+ void set_global_var_arg2(int val, int val2) { global_var = val+val2; }
+ void set_global_var_const_arg2(int val, int val2) const { global_var = val+val2; }
+
+ void operator()()
+ {
+ global_var = 9;
+ }
+
+ // use an any just so that if this object goes out of scope
+ // then var will get all messed up.
+ any var;
+ void operator()(int& a) { dlib::sleep(100); a = var.get<int>(); }
+ void operator()(int& a, int& b) { dlib::sleep(100); a = var.get<int>(); b = 2; }
+ void operator()(int& a, int& b, int& c) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; }
+ void operator()(int& a, int& b, int& c, int& d) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; d = 4; }
+ };
+
+
+ void set_global_var() { global_var = 9; }
+
+ void gset_struct_to_zero (some_struct& a) { a.val = 0; }
+ void gset_to_zero (int& a) { a = 0; }
+ void gincrement (int& a) { ++a; }
+ void gadd (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; }
+ void gadd1(int& a, int& res) { res += a; }
+ void gadd2 (int c, int a, const int& b, int& res) { dlib::sleep(20); res = a + b + c; }
+
+ class thread_pool_tester : public tester
+ {
+ public:
+ thread_pool_tester (
+ ) :
+ tester ("test_thread_pool",
+ "Runs tests on the thread_pool component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ add_functor f;
+ for (int num_threads= 0; num_threads < 4; ++num_threads)
+ {
+ dlib::future<int> a, b, c, res, d;
+ thread_pool tp(num_threads);
+ print_spinner();
+
+ dlib::future<some_struct> obj;
+
+
+ for (int i = 0; i < 4; ++i)
+ {
+ a = 1;
+ b = 2;
+ c = 3;
+ res = 4;
+
+
+ DLIB_TEST(a==a);
+ DLIB_TEST(a!=b);
+ DLIB_TEST(a==1);
+
+ tp.add_task(gset_to_zero, a);
+ tp.add_task(gset_to_zero, b);
+ tp.add_task(*this, &thread_pool_tester::set_to_zero, c);
+ tp.add_task(gset_to_zero, res);
+ DLIB_TEST(a == 0);
+ DLIB_TEST(b == 0);
+ DLIB_TEST(c == 0);
+ DLIB_TEST(res == 0);
+
+
+ tp.add_task(gincrement, a);
+ tp.add_task(*this, &thread_pool_tester::increment, b);
+ tp.add_task(*this, &thread_pool_tester::increment, c);
+ tp.add_task(gincrement, res);
+
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 1);
+ DLIB_TEST(c == 1);
+ DLIB_TEST(res == 1);
+
+ tp.add_task(&gincrement, a);
+ tp.add_task(*this, &thread_pool_tester::increment, b);
+ tp.add_task(*this, &thread_pool_tester::increment, c);
+ tp.add_task(&gincrement, res);
+ tp.add_task(gincrement, a);
+ tp.add_task(*this, &thread_pool_tester::increment, b);
+ tp.add_task(*this, &thread_pool_tester::increment, c);
+ tp.add_task(gincrement, res);
+
+ DLIB_TEST(a == 3);
+ DLIB_TEST(b == 3);
+ DLIB_TEST(c == 3);
+ DLIB_TEST(res == 3);
+
+ tp.add_task(*this, &thread_pool_tester::increment, c);
+ tp.add_task(gincrement, res);
+ DLIB_TEST(c == 4);
+ DLIB_TEST(res == 4);
+
+
+ tp.add_task(gadd, a, b, res);
+ DLIB_TEST(res == a+b);
+ DLIB_TEST(res == 6);
+ a = 3;
+ b = 4;
+ res = 99;
+ DLIB_TEST(res == 99);
+ tp.add_task(*this, &thread_pool_tester::add, a, b, res);
+ DLIB_TEST(res == a+b);
+ DLIB_TEST(res == 7);
+
+ a = 1;
+ b = 2;
+ c = 3;
+ res = 88;
+ DLIB_TEST(res == 88);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+
+ tp.add_task(gadd2, a, b, c, res);
+ DLIB_TEST(res == 6);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+
+ a = 1;
+ b = 2;
+ c = 3;
+ res = 88;
+ DLIB_TEST(res == 88);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+ tp.add_task(*this, &thread_pool_tester::add2, a, b, c, res);
+ DLIB_TEST(res == 6);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+
+ a = 1;
+ b = 2;
+ c = 3;
+ res = 88;
+ tp.add_task(gadd1, a, b);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 3);
+ a = 2;
+ tp.add_task(*this, &thread_pool_tester::add1, a, b);
+ DLIB_TEST(a == 2);
+ DLIB_TEST(b == 5);
+
+
+ val = 4;
+ uint64 id = tp.add_task(*this, &thread_pool_tester::zero_val);
+ tp.wait_for_task(id);
+ DLIB_TEST(val == 0);
+ id = tp.add_task(*this, &thread_pool_tester::accum2, 1,2);
+ tp.wait_for_all_tasks();
+ DLIB_TEST(val == 3);
+ id = tp.add_task(*this, &thread_pool_tester::accum1, 3);
+ tp.wait_for_task(id);
+ DLIB_TEST(val == 6);
+
+
+ obj.get().val = 8;
+ DLIB_TEST(obj.get().val == 8);
+ tp.add_task(gset_struct_to_zero, obj);
+ DLIB_TEST(obj.get().val == 0);
+ obj.get().val = 8;
+ DLIB_TEST(obj.get().val == 8);
+ tp.add_task(*this,&thread_pool_tester::set_struct_to_zero, obj);
+ DLIB_TEST(obj.get().val == 0);
+
+ a = 1;
+ b = 2;
+ res = 0;
+ tp.add_task(f, a, b, res);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(res == 3);
+
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(&set_global_var);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+ global_var = 0;
+ a = 4;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var_arg1, a);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 4);
+
+ global_var = 0;
+ a = 4;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var_arg1, a);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 4);
+
+
+
+ global_var = 0;
+ a = 4;
+ b = 3;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var_arg2, a, b);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 7);
+
+ global_var = 0;
+ a = 4;
+ b = 3;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var_arg2, a, b);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 7);
+
+ global_var = 0;
+ a = 4;
+ b = 3;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var_const_arg2, a, b);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 7);
+
+ global_var = 0;
+ a = 4;
+ b = 3;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg2, a, b);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 7);
+
+
+
+
+
+
+ global_var = 0;
+ a = 4;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var_const_arg1, a);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 4);
+
+ global_var = 0;
+ a = 4;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg1, a);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 4);
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task(f, &add_functor::set_global_var_const);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+
+ global_var = 0;
+ DLIB_TEST(global_var == 0);
+ id = tp.add_task_by_value(f, &add_functor::set_global_var_const);
+ tp.wait_for_task(id);
+ DLIB_TEST(global_var == 9);
+
+
+
+ }
+
+ // add this task just to to perterb the thread pool before it goes out of scope
+ tp.add_task(f, a, b, res);
+
+ for (int k = 0; k < 3; ++k)
+ {
+ print_spinner();
+ global_var = 0;
+ tp.add_task_by_value(add_functor());
+ tp.wait_for_all_tasks();
+ DLIB_TEST(global_var == 9);
+
+ a = 0; b = 0; c = 0; d = 0;
+ tp.add_task_by_value(add_functor(), a);
+ DLIB_TEST(a == 1);
+ a = 0; b = 0; c = 0; d = 0;
+ tp.add_task_by_value(add_functor(8), a, b);
+ DLIB_TEST(a == 8);
+ DLIB_TEST(b == 2);
+ a = 0; b = 0; c = 0; d = 0;
+ tp.add_task_by_value(add_functor(), a, b, c);
+ DLIB_TEST(a == 1);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+ a = 0; b = 0; c = 0; d = 0;
+ tp.add_task_by_value(add_functor(5), a, b, c, d);
+ DLIB_TEST(a == 5);
+ DLIB_TEST(b == 2);
+ DLIB_TEST(c == 3);
+ DLIB_TEST(d == 4);
+ }
+
+
+ tp.wait_for_all_tasks();
+
+ // make sure exception propagation from tasks works correctly.
+ auto f_throws = []() { throw dlib::error("test exception");};
+ bool got_exception = false;
+ try
+ {
+ tp.add_task_by_value(f_throws);
+ tp.wait_for_all_tasks();
+ }
+ catch(dlib::error& e)
+ {
+ DLIB_TEST(e.info == "test exception");
+ got_exception = true;
+ }
+ DLIB_TEST(got_exception);
+
+ dlib::future<int> aa;
+ auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");};
+ got_exception = false;
+ try
+ {
+ tp.add_task(f_throws2, aa);
+ aa.get();
+ }
+ catch(dlib::error& e)
+ {
+ DLIB_TEST(e.info == "test exception");
+ got_exception = true;
+ }
+ DLIB_TEST(got_exception);
+
+ }
+ }
+
+ long val;
+ void accum1(long a) { val += a; }
+ void accum2(long a, long b) { val += a + b; }
+ void zero_val() { dlib::sleep(20); val = 0; }
+
+
+ void set_struct_to_zero (some_struct& a) { a.val = 0; }
+ void set_to_zero (int& a) { dlib::sleep(20); a = 0; }
+ void increment (int& a) const { dlib::sleep(20); ++a; }
+ void add (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; }
+ void add1(int& a, int& res) const { res += a; }
+ void add2 (int c, int a, const int& b, int& res) { res = a + b + c; }
+
+
+ } a;
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/threads.cpp b/ml/dlib/dlib/test/threads.cpp
new file mode 100644
index 000000000..1aeb1a3f9
--- /dev/null
+++ b/ml/dlib/dlib/test/threads.cpp
@@ -0,0 +1,158 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/misc_api.h>
+#include <dlib/threads.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.threads");
+
+ void test_async()
+ {
+#if __cplusplus >= 201103
+ print_spinner();
+ auto v1 = dlib::async([]() { dlib::sleep(500); return 1; }).share();
+ auto v2 = dlib::async([v1]() { dlib::sleep(400); return v1.get()+1; }).share();
+ auto v3 = dlib::async([v2](int a) { dlib::sleep(300); return v2.get()+a; },2).share();
+ auto v4 = dlib::async([v3]() { dlib::sleep(200); return v3.get()+1; });
+
+ DLIB_TEST(v4.get() == 5);
+
+ print_spinner();
+ auto except = dlib::async([](){ dlib::sleep(300); throw error("oops"); });
+ bool got_exception = false;
+ try
+ {
+ except.get();
+ }
+ catch (error&e)
+ {
+ got_exception = true;
+ DLIB_TEST(e.what() == string("oops"));
+ }
+ DLIB_TEST(got_exception);
+#endif
+ }
+
+ class threads_tester : public tester
+ {
+ public:
+ threads_tester (
+ ) :
+ tester ("test_threads",
+ "Runs tests on the threads component."),
+ sm(cm)
+ {}
+
+ thread_specific_data<int> tsd;
+ rmutex cm;
+ rsignaler sm;
+ int count;
+ bool failure;
+
+ void perform_test (
+ )
+ {
+ failure = false;
+ print_spinner();
+
+
+ count = 10;
+ if (!create_new_thread<threads_tester,&threads_tester::thread1>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread2>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread3>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread4>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread5>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread6>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread7>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread8>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread9>(*this)) failure = true;
+ if (!create_new_thread<threads_tester,&threads_tester::thread10>(*this)) failure = true;
+
+ thread(66);
+
+ // this should happen in the main program thread
+ if (is_dlib_thread())
+ failure = true;
+
+ auto_mutex M(cm);
+ while (count > 0 && !failure)
+ sm.wait();
+
+
+ DLIB_TEST(!failure);
+
+ test_async();
+ }
+
+ void thread_end_handler (
+ )
+ {
+ auto_mutex M(cm);
+ --count;
+ if (count == 0)
+ sm.signal();
+ }
+
+ void thread1() { thread(1); }
+ void thread2()
+ {
+ thread(2);
+ if (is_dlib_thread() == false)
+ failure = true;
+ }
+ void thread3() { thread(3); }
+ void thread4() { thread(4); }
+ void thread5() { thread(5); }
+ void thread6() { thread(6); }
+ void thread7() { thread(7); }
+ void thread8() { thread(8); }
+ void thread9() { thread(9); }
+ void thread10() { thread(10); }
+
+ void thread (
+ int num
+ )
+ {
+ dlog << LTRACE << "starting thread num " << num;
+ if (is_dlib_thread())
+ register_thread_end_handler(*this,&threads_tester::thread_end_handler);
+ tsd.data() = num;
+ for (int i = 0; i < 0x3FFFF; ++i)
+ {
+ if ((i&0xFFF) == 0)
+ {
+ print_spinner();
+ dlib::sleep(10);
+ }
+ // if this isn't equal to num then there is a problem with the thread specific data stuff
+ if (tsd.data() != num)
+ {
+ auto_mutex M(cm);
+ failure = true;
+ sm.signal();
+ }
+ }
+ dlog << LTRACE << "ending of thread num " << num;
+
+
+ }
+ } a;
+
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/timer.cpp b/ml/dlib/dlib/test/timer.cpp
new file mode 100644
index 000000000..ae004a55d
--- /dev/null
+++ b/ml/dlib/dlib/test/timer.cpp
@@ -0,0 +1,347 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+
+#include <dlib/timer.h>
+#include <dlib/timeout.h>
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.timer");
+
+ class timer_test_helper
+ {
+ public:
+ dlib::mutex m;
+ int count;
+ dlib::uint64 timestamp;
+ dlib::timestamper ts;
+
+ timer_test_helper():count(0), timestamp(0){}
+ void add()
+ {
+ m.lock();
+ ++count;
+ m.unlock();
+ }
+
+ void delayed_add()
+ {
+ dlib::sleep(1000);
+ print_spinner();
+ add();
+ }
+
+ void set_timestamp()
+ {
+ m.lock();
+ timestamp = ts.get_timestamp();
+ dlog << LTRACE << "in set_timestamp(), time is " << timestamp;
+ dlib::sleep(1);
+ print_spinner();
+ m.unlock();
+ }
+ };
+
+ template <
+ typename timer_t
+ >
+ void timer_test2 (
+ )
+ /*!
+ requires
+ - timer_t is an implementation of dlib/timer/timer_abstract.h is instantiated
+ timer_test_helper
+ ensures
+ - runs tests on timer_t for compliance with the specs
+ !*/
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ print_spinner();
+ timer_test_helper h;
+
+ timer_t t1(h,&timer_test_helper::set_timestamp);
+ t1.set_delay_time(0);
+ dlog << LTRACE << "t1.start()";
+ t1.start();
+
+ dlib::sleep(60);
+ print_spinner();
+ t1.stop_and_wait();
+
+ dlib::uint64 cur_time = h.ts.get_timestamp();
+ dlog << LTRACE << "get current time: " << cur_time;
+
+ // make sure the action function has been called recently
+ DLIB_TEST_MSG((cur_time-h.timestamp)/1000 < 30, (cur_time-h.timestamp)/1000);
+
+ }
+ }
+
+ template <
+ typename timer_t
+ >
+ void timer_test (
+ )
+ /*!
+ requires
+ - timer_t is an implementation of dlib/timer/timer_abstract.h is instantiated
+ timer_test_helper
+ ensures
+ - runs tests on timer_t for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+ for (int j = 0; j < 3; ++j)
+ {
+ timer_test_helper h;
+
+ timer_t t1(h,&timer_test_helper::add);
+ timer_t t2(h,&timer_test_helper::add);
+ timer_t t3(h,&timer_test_helper::add);
+
+ DLIB_TEST(t1.delay_time() == 1000);
+ DLIB_TEST(t2.delay_time() == 1000);
+ DLIB_TEST(t3.delay_time() == 1000);
+ DLIB_TEST(t1.is_running() == false);
+ DLIB_TEST(t2.is_running() == false);
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t1.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t2.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t3.action_function() == &timer_test_helper::add);
+ DLIB_TEST(&t1.action_object() == &h);
+ DLIB_TEST(&t2.action_object() == &h);
+ DLIB_TEST(&t3.action_object() == &h);
+
+ t1.set_delay_time(1000);
+ t2.set_delay_time(500);
+ t3.set_delay_time(1500);
+
+ DLIB_TEST(t1.delay_time() == 1000);
+ DLIB_TEST(t2.delay_time() == 500);
+ DLIB_TEST(t3.delay_time() == 1500);
+ DLIB_TEST(t1.is_running() == false);
+ DLIB_TEST(t2.is_running() == false);
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t1.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t2.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t3.action_function() == &timer_test_helper::add);
+ DLIB_TEST(&t1.action_object() == &h);
+ DLIB_TEST(&t2.action_object() == &h);
+ DLIB_TEST(&t3.action_object() == &h);
+ dlib::sleep(1100);
+ print_spinner();
+ DLIB_TEST(h.count == 0);
+
+ t1.stop_and_wait();
+ t2.stop_and_wait();
+ t3.stop_and_wait();
+
+ dlib::sleep(1100);
+ print_spinner();
+ DLIB_TEST(h.count == 0);
+ DLIB_TEST(t1.delay_time() == 1000);
+ DLIB_TEST(t2.delay_time() == 500);
+ DLIB_TEST(t3.delay_time() == 1500);
+ DLIB_TEST(t1.is_running() == false);
+ DLIB_TEST(t2.is_running() == false);
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t1.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t2.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t3.action_function() == &timer_test_helper::add);
+ DLIB_TEST(&t1.action_object() == &h);
+ DLIB_TEST(&t2.action_object() == &h);
+ DLIB_TEST(&t3.action_object() == &h);
+
+ t1.start();
+ t2.start();
+ t3.start();
+
+ DLIB_TEST(t1.delay_time() == 1000);
+ DLIB_TEST(t2.delay_time() == 500);
+ DLIB_TEST(t3.delay_time() == 1500);
+ DLIB_TEST(t1.is_running() == true);
+ DLIB_TEST(t2.is_running() == true);
+ DLIB_TEST(t3.is_running() == true);
+ DLIB_TEST(t1.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t2.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t3.action_function() == &timer_test_helper::add);
+ DLIB_TEST(&t1.action_object() == &h);
+ DLIB_TEST(&t2.action_object() == &h);
+ DLIB_TEST(&t3.action_object() == &h);
+
+ t1.stop();
+ t2.stop();
+ t3.stop();
+
+ DLIB_TEST(t1.delay_time() == 1000);
+ DLIB_TEST(t2.delay_time() == 500);
+ DLIB_TEST(t3.delay_time() == 1500);
+ DLIB_TEST(t1.is_running() == false);
+ DLIB_TEST(t2.is_running() == false);
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t1.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t2.action_function() == &timer_test_helper::add);
+ DLIB_TEST(t3.action_function() == &timer_test_helper::add);
+ DLIB_TEST(&t1.action_object() == &h);
+ DLIB_TEST(&t2.action_object() == &h);
+ DLIB_TEST(&t3.action_object() == &h);
+
+ DLIB_TEST(h.count == 0);
+ dlib::sleep(1100);
+ print_spinner();
+ DLIB_TEST(h.count == 0);
+
+ for (int i = 1; i <= 3; ++i)
+ {
+ t1.start();
+ t2.start();
+ t3.start();
+
+ DLIB_TEST(t1.is_running() == true);
+ DLIB_TEST(t2.is_running() == true);
+ DLIB_TEST(t3.is_running() == true);
+
+ dlib::sleep(1800);
+ print_spinner();
+ // this should allow the timers to trigger 5 times
+ t1.stop();
+ t2.stop();
+ t3.stop();
+
+ DLIB_TEST_MSG(h.count == 5*i,"h.count: " << h.count << " i: " << i);
+ dlib::sleep(1100);
+ DLIB_TEST_MSG(h.count == 5*i,"h.count: " << h.count << " i: " << i);
+ }
+
+
+ t1.stop_and_wait();
+
+ h.count = 0;
+ t1.start();
+ dlib::sleep(300);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ t1.set_delay_time(400);
+ dlib::sleep(200);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 1,h.count);
+ dlib::sleep(250);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 1,h.count);
+ dlib::sleep(100);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 2,h.count);
+ t1.set_delay_time(2000);
+ DLIB_TEST_MSG(h.count == 2,h.count);
+ dlib::sleep(1000);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 2,h.count);
+ t1.clear();
+
+ h.count = 0;
+ t3.start();
+ DLIB_TEST(t3.is_running() == true);
+ DLIB_TEST(t3.delay_time() == 1500);
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ t3.clear();
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t3.delay_time() == 1000);
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ dlib::sleep(200);
+ print_spinner();
+ DLIB_TEST(t3.is_running() == false);
+ DLIB_TEST(t3.delay_time() == 1000);
+ DLIB_TEST_MSG(h.count == 0,h.count);
+
+
+ {
+ h.count = 0;
+ timer_t t4(h,&timer_test_helper::delayed_add);
+ t4.set_delay_time(100);
+ t4.start();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ dlib::sleep(400);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ t4.stop_and_wait();
+ DLIB_TEST_MSG(h.count == 1,h.count);
+ DLIB_TEST(t4.is_running() == false);
+ }
+
+ {
+ h.count = 0;
+ timer_t t4(h,&timer_test_helper::delayed_add);
+ t4.set_delay_time(100);
+ t4.start();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ dlib::sleep(400);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ t4.clear();
+ DLIB_TEST(t4.is_running() == false);
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ t4.stop_and_wait();
+ DLIB_TEST_MSG(h.count == 1,h.count);
+ DLIB_TEST(t4.is_running() == false);
+ }
+
+ {
+ h.count = 0;
+ timer_t t5(h,&timer_test_helper::delayed_add);
+ t5.set_delay_time(100);
+ t5.start();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ dlib::sleep(400);
+ print_spinner();
+ DLIB_TEST_MSG(h.count == 0,h.count);
+ }
+ DLIB_TEST_MSG(h.count == 1,h.count);
+
+ }
+
+ }
+
+
+
+
+ class timer_tester : public tester
+ {
+ public:
+ timer_tester (
+ ) :
+ tester ("test_timer",
+ "Runs tests on the timer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing timer_heavy with test_timer";
+ timer_test<timer_heavy<timer_test_helper> > ();
+ dlog << LINFO << "testing timer_heavy with test_timer2";
+ timer_test2<timer_heavy<timer_test_helper> > ();
+
+ dlog << LINFO << "testing timer with test_timer";
+ timer_test<timer<timer_test_helper> > ();
+ dlog << LINFO << "testing timer with test_timer2";
+ timer_test2<timer<timer_test_helper> > ();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/tokenizer.cpp b/ml/dlib/dlib/test/tokenizer.cpp
new file mode 100644
index 000000000..95a95a7e1
--- /dev/null
+++ b/ml/dlib/dlib/test/tokenizer.cpp
@@ -0,0 +1,378 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <string>
+#include <sstream>
+
+#include <dlib/tokenizer.h>
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace std;
+ using namespace dlib;
+
+ logger dlog("test.tokenizer");
+
+ template <
+ typename tok
+ >
+ void tokenizer_kernel_test (
+ )
+ /*!
+ requires
+ - tok is an implementation of tokenizer_kernel_abstract.h
+ ensures
+ - runs tests on tok for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ tok test;
+
+ DLIB_TEST(test.numbers() == "0123456789");
+ DLIB_TEST(test.uppercase_letters() == "ABCDEFGHIJKLMNOPQRSTUVWXYZ");
+ DLIB_TEST(test.lowercase_letters() == "abcdefghijklmnopqrstuvwxyz");
+
+ DLIB_TEST_MSG(test.get_identifier_body() == "_" + test.lowercase_letters() +
+ test.uppercase_letters() + test.numbers(),"");
+ DLIB_TEST_MSG(test.get_identifier_head() == "_" + test.lowercase_letters() +
+ test.uppercase_letters(),"");
+
+ DLIB_TEST(test.stream_is_set() == false);
+ test.clear();
+ DLIB_TEST(test.stream_is_set() == false);
+
+ DLIB_TEST_MSG(test.get_identifier_body() == "_" + test.lowercase_letters() +
+ test.uppercase_letters() + test.numbers(),"");
+ DLIB_TEST_MSG(test.get_identifier_head() == "_" + test.lowercase_letters() +
+ test.uppercase_letters(),"");
+
+ tok test2;
+
+ ostringstream sout;
+ istringstream sin;
+ test2.set_stream(sin);
+
+ DLIB_TEST(test2.stream_is_set());
+ DLIB_TEST(&test2.get_stream() == &sin);
+
+ int type;
+ string token;
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+
+
+ sin.clear();
+ sin.str(" The cat 123asdf1234 ._ \n test.");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ DLIB_TEST(test2.peek_type() == tok::IDENTIFIER);
+ DLIB_TEST(test2.peek_token() == "The");
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "The");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "cat");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::NUMBER);
+ DLIB_TEST_MSG(token == "123","token: " << token);
+
+ DLIB_TEST(test2.peek_type() == tok::IDENTIFIER);
+ DLIB_TEST(test2.peek_token() == "asdf1234");
+ DLIB_TEST(test2.peek_type() == tok::IDENTIFIER);
+ DLIB_TEST(test2.peek_token() == "asdf1234");
+ DLIB_TEST(test2.peek_type() == tok::IDENTIFIER);
+ DLIB_TEST(test2.peek_token() == "asdf1234");
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "asdf1234");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "_");
+
+ DLIB_TEST(test2.peek_type() == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(test2.peek_token() == " ","token: \"" << token << "\"" <<
+ "\ntoken size: " << (unsigned int)token.size());
+
+ swap(test,test2);
+
+ DLIB_TEST(test2.stream_is_set() == false);
+
+ DLIB_TEST(test.peek_type() == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(test.peek_token() == " ","token: \"" << token << "\"" <<
+ "\ntoken size: " << (unsigned int)token.size());
+ test.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" <<
+ "\ntoken size: " << (unsigned int)token.size());
+
+ test.get_token(type,token);
+ DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token);
+ DLIB_TEST_MSG(token == "\n","token: " << token);
+
+ swap(test,test2);
+ DLIB_TEST(test.stream_is_set() == false);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST_MSG(token == "test","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+
+
+
+
+
+
+
+
+
+
+ test2.set_identifier_token("_" + test.uppercase_letters() +
+ test.lowercase_letters(),test.numbers() + "_" + test.uppercase_letters()
+ +test.lowercase_letters());
+
+
+ sin.clear();
+ sin.str(" The cat 123asdf1234 ._ \n\r test.");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "The");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "cat");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::NUMBER);
+ DLIB_TEST_MSG(token == "123","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "asdf1234");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "_");
+
+ swap(test,test2);
+
+ DLIB_TEST(test2.stream_is_set() == false);
+
+ test.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" <<
+ "\ntoken size: " << (unsigned int)token.size());
+
+ test.get_token(type,token);
+ DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token);
+ DLIB_TEST_MSG(token == "\n","token: " << token);
+
+ swap(test,test2);
+ DLIB_TEST(test.stream_is_set() == false);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == "\r ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST_MSG(token == "test","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+
+
+
+
+
+
+
+
+
+
+
+
+
+ test2.set_identifier_token(test.uppercase_letters() +
+ test.lowercase_letters(),test.numbers() + test.uppercase_letters()
+ +test.lowercase_letters());
+
+
+ sin.clear();
+ sin.str(" The cat 123as_df1234 ._ \n test.");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "The");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "cat");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST(token == " ");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::NUMBER);
+ DLIB_TEST_MSG(token == "123","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "as");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == "_","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST(token == "df1234");
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST(token == "_");
+
+ swap(test,test2);
+
+ DLIB_TEST(test2.stream_is_set() == false);
+
+ test.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" <<
+ "\ntoken size: " << (unsigned int)token.size());
+
+ test.get_token(type,token);
+ DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token);
+ DLIB_TEST_MSG(token == "\n","token: " << token);
+
+ swap(test,test2);
+ DLIB_TEST(test.stream_is_set() == false);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::WHITE_SPACE);
+ DLIB_TEST_MSG(token == " ","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::IDENTIFIER);
+ DLIB_TEST_MSG(token == "test","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::CHAR);
+ DLIB_TEST_MSG(token == ".","token: " << token);
+
+ test2.get_token(type,token);
+ DLIB_TEST(type == tok::END_OF_FILE);
+
+
+ }
+
+
+
+
+
+ class tokenizer_tester : public tester
+ {
+ public:
+ tokenizer_tester (
+ ) :
+ tester ("test_tokenizer",
+ "Runs tests on the tokenizer component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "testing kernel_1a";
+ tokenizer_kernel_test<tokenizer::kernel_1a> ();
+ dlog << LINFO << "testing kernel_1a_c";
+ tokenizer_kernel_test<tokenizer::kernel_1a_c>();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/tools/CMakeLists.txt b/ml/dlib/dlib/test/tools/CMakeLists.txt
new file mode 100644
index 000000000..adbd43cb9
--- /dev/null
+++ b/ml/dlib/dlib/test/tools/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 2.8.12)
+
+add_subdirectory(../../../tools/imglab imglab_build)
+add_subdirectory(../../../tools/htmlify htmlify_build)
+add_subdirectory(../../../tools/convert_dlib_nets_to_caffe convert_dlib_nets_to_caffe_build)
diff --git a/ml/dlib/dlib/test/trust_region.cpp b/ml/dlib/dlib/test/trust_region.cpp
new file mode 100644
index 000000000..aa2775b9c
--- /dev/null
+++ b/ml/dlib/dlib/test/trust_region.cpp
@@ -0,0 +1,329 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/optimization.h>
+#include "optimization_test_functions.h"
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+#include "../rand.h"
+
+#include "tester.h"
+
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ using namespace dlib::test_functions;
+
+ logger dlog("test.trust_region");
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ struct neg_rosen_model
+ {
+ typedef matrix<T,0,1> column_vector;
+ typedef matrix<T,0,0> general_matrix;
+
+ T operator() ( column_vector x) const
+ {
+ return -static_cast<T>(rosen<T>(x));
+ }
+
+ void get_derivative_and_hessian (
+ const column_vector& x,
+ column_vector& d,
+ general_matrix& h
+ ) const
+ {
+ d = -matrix_cast<T>(rosen_derivative<T>(x));
+ h = -matrix_cast<T>(rosen_hessian<T>(x));
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+
+ template <typename T>
+ void test_with_rosen()
+ {
+ print_spinner();
+
+ matrix<T,2,1> ans;
+ ans = 1,1;
+
+ matrix<T,2,1> p = 100*matrix_cast<T>(randm(2,1,rnd)) - 50;
+
+ T obj = find_min_trust_region(objective_delta_stop_strategy(1e-12, 100), rosen_function_model<T>(), p);
+
+ DLIB_TEST_MSG(std::abs(obj) < 1e-10, "obj: " << obj);
+ DLIB_TEST_MSG(length(p-ans) < 1e-5, "length(p): " << length(p-ans));
+
+ matrix<T,0,1> p2 = 100*matrix_cast<T>(randm(2,1,rnd)) - 50;
+ obj = find_max_trust_region(objective_delta_stop_strategy(1e-12, 100), neg_rosen_model<T>(), p2);
+
+ DLIB_TEST_MSG(std::abs(obj) < 1e-10, "obj: " << obj);
+ DLIB_TEST_MSG(length(p-ans) < 1e-5, "length(p): " << length(p-ans));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_trust_region_sub_problem()
+ {
+ dlog << LINFO << "subproblem test 1";
+ {
+ matrix<double,2,2> B;
+ B = 1, 0,
+ 0, 1;
+
+ matrix<double,2,1> g, p, ans;
+ g = 0;
+
+ ans = 0;
+
+ solve_trust_region_subproblem(B,g,1,p, 0.001, 10);
+
+ DLIB_TEST(length(p-ans) < 1e-10);
+ solve_trust_region_subproblem(B,g,1,p, 0.001, 1);
+ DLIB_TEST(length(p-ans) < 1e-10);
+ }
+
+ dlog << LINFO << "subproblem test 2";
+ {
+ matrix<double,2,2> B;
+ B = 1, 0,
+ 0, 1;
+
+ B *= 0.1;
+
+ matrix<double,2,1> g, p, ans;
+ g = 1;
+
+ ans = -g / length(g);
+
+ solve_trust_region_subproblem(B,g,1,p, 1e-6, 20);
+
+ DLIB_TEST(length(p-ans) < 1e-4);
+ }
+
+ dlog << LINFO << "subproblem test 3";
+ {
+ matrix<double,2,2> B;
+ B = 0, 0,
+ 0, 0;
+
+ matrix<double,2,1> g, p, ans;
+ g = 1;
+
+ ans = -g / length(g);
+
+ solve_trust_region_subproblem(B,g,1,p, 1e-6, 20);
+
+ dlog << LINFO << "ans: " << trans(ans);
+ dlog << LINFO << "p: " << trans(p);
+ DLIB_TEST(length(p-ans) < 1e-4);
+ }
+ return;
+
+ dlog << LINFO << "subproblem test 4";
+ {
+ matrix<double,2,2> B;
+ B = 2, 0,
+ 0, -1;
+
+
+ matrix<double,2,1> g, p, ans;
+ g = 0;
+
+ ans = 0, -1;
+
+ solve_trust_region_subproblem(B,g,1,p, 1e-6, 20);
+
+ DLIB_TEST(length(p-ans) < 1e-4);
+ }
+
+
+ dlog << LINFO << "subproblem test 5";
+ {
+ matrix<double,2,2> B;
+ B = 2, 0,
+ 0, -1;
+
+
+ matrix<double,2,1> g, p, ans;
+ g = 0, 1;
+
+ ans = 0, -1;
+
+ solve_trust_region_subproblem(B,g,1,p, 1e-6, 20);
+
+ DLIB_TEST(length(p-ans) < 1e-4);
+ }
+
+ dlog << LINFO << "subproblem test 6";
+ for (int i = 0; i < 10; ++i)
+ {
+ matrix<double,10,10> B;
+
+ B = randm(10,10, rnd);
+
+ B = 0.01*B*trans(B);
+
+
+ matrix<double,10,1> g, p, ans;
+ g = 1;
+
+ solve_trust_region_subproblem(B,g,1,p, 1e-6, 20);
+
+ DLIB_TEST(std::abs(length(p) - 1) < 1e-4);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void test_problems()
+ {
+ print_spinner();
+ {
+ matrix<double,4,1> ch;
+
+ ch = brown_start();
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-7, 80),
+ brown_function_model(),
+ ch);
+
+ dlog << LINFO << "brown obj: " << brown(ch);
+ dlog << LINFO << "brown der: " << length(brown_derivative(ch));
+ dlog << LINFO << "brown error: " << length(ch - brown_solution());
+
+ DLIB_TEST(length(ch - brown_solution()) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,2,1> ch;
+
+ ch = rosen_start<double>();
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-7, 80),
+ rosen_function_model<double>(),
+ ch);
+
+ dlog << LINFO << "rosen obj: " << rosen(ch);
+ dlog << LINFO << "rosen der: " << length(rosen_derivative(ch));
+ dlog << LINFO << "rosen error: " << length(ch - rosen_solution<double>());
+
+ DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
+ }
+
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(2);
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-7, 80),
+ chebyquad_function_model(),
+ ch);
+
+ dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));
+
+ DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(4);
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-7, 80),
+ chebyquad_function_model(),
+ ch);
+
+ dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4));
+
+ DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(6);
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-12, 80),
+ chebyquad_function_model(),
+ ch);
+
+ dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6));
+
+ DLIB_TEST(length(ch - chebyquad_solution(6)) < 1e-5);
+
+ }
+ print_spinner();
+ {
+ matrix<double,0,1> ch;
+
+ ch = chebyquad_start(8);
+
+ find_min_trust_region(objective_delta_stop_strategy(1e-10, 80),
+ chebyquad_function_model(),
+ ch);
+
+ dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch);
+ dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch));
+ dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8));
+
+ DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
+ }
+
+ }
+
+
+
+ class optimization_tester : public tester
+ {
+ public:
+ optimization_tester (
+ ) :
+ tester ("test_trust_region",
+ "Runs tests on the trust region optimization component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ dlog << LINFO << "test with rosen<float>";
+ for (int i = 0; i < 50; ++i)
+ test_with_rosen<float>();
+
+ dlog << LINFO << "test with rosen<double>";
+ for (int i = 0; i < 50; ++i)
+ test_with_rosen<double>();
+
+
+ test_trust_region_sub_problem();
+
+ test_problems();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test/tuple.cpp b/ml/dlib/dlib/test/tuple.cpp
new file mode 100644
index 000000000..da7a18ec8
--- /dev/null
+++ b/ml/dlib/dlib/test/tuple.cpp
@@ -0,0 +1,186 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/tuple.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.tuple");
+
+ struct s_nil
+ {
+ template <typename T>
+ void operator() (
+ const T&
+ ) const
+ {
+ }
+ };
+
+
+ struct inc
+ {
+ template <typename T>
+ void operator() (
+ T& a
+ ) const
+ {
+ a += 1;
+ }
+ };
+
+
+ template <typename T>
+ void check_const (
+ const T& t
+ )
+ {
+ t.template get<0>();
+
+ typedef typename T::template get_type<0>::type type0;
+ t.template get<type0>();
+ t.template index<type0>();
+ }
+
+ template <typename T>
+ void check_nonconst (
+ T& t
+ )
+ {
+ t.template get<0>();
+
+ typedef typename T::template get_type<0>::type type0;
+ t.template get<type0>();
+ t.template index<type0>();
+ }
+
+ void tuple_test (
+ )
+ /*!
+ ensures
+ - runs tests on tuple functions for compliance with the specs
+ !*/
+ {
+
+ print_spinner();
+
+ using dlib::tuple;
+
+ tuple<> a;
+ tuple<int> b;
+ tuple<int, float> c;
+
+
+ a.get<1>();
+ a.get<2>();
+ a.get<3>();
+ a.get<4>();
+ a.get<5>();
+
+ check_nonconst(b);
+ check_nonconst(c);
+ check_const(b);
+ check_const(c);
+
+ COMPILE_TIME_ASSERT((is_same_type<tuple<>::get_type<0>::type, null_type>::value));
+ COMPILE_TIME_ASSERT((is_same_type<tuple<int>::get_type<0>::type, int>::value));
+ COMPILE_TIME_ASSERT((is_same_type<tuple<int,float>::get_type<0>::type, int>::value));
+ COMPILE_TIME_ASSERT((is_same_type<tuple<int,float>::get_type<1>::type, float>::value));
+ COMPILE_TIME_ASSERT((is_same_type<tuple<int,float>::get_type<2>::type, null_type>::value));
+
+ b.get<0>() = 8;
+ DLIB_TEST(b.get<int>() == 8);
+ DLIB_TEST(b.index<int>() == 0);
+
+ c.get<0>() = 9;
+ DLIB_TEST(c.get<int>() == 9);
+ DLIB_TEST(c.index<int>() == 0);
+ c.get<1>() = 3.0;
+ DLIB_TEST(c.get<float>() == 3.0);
+ DLIB_TEST(c.index<float>() == 1);
+
+
+
+ {
+ typedef tuple<int, short, long> T;
+ T a, b;
+ a.get<0>() = 1;
+ a.get<1>() = 3;
+ a.get<2>() = 2;
+
+ b = a;
+
+ inc i;
+ s_nil n;
+ a.for_each(inc());
+ a.for_each(i);
+ const_cast<const T&>(a).for_each(s_nil());
+ const_cast<const T&>(a).for_each(n);
+
+ DLIB_TEST(a.get<0>() == b.get<0>()+2);
+ DLIB_TEST(a.get<1>() == b.get<1>()+2);
+ DLIB_TEST(a.get<2>() == b.get<2>()+2);
+
+ ostringstream sout;
+
+ serialize(a,sout);
+ istringstream sin(sout.str());
+ deserialize(b,sin);
+
+ DLIB_TEST(a.get<0>() == b.get<0>());
+ DLIB_TEST(a.get<1>() == b.get<1>());
+ DLIB_TEST(a.get<2>() == b.get<2>());
+
+ a.for_index(i,0);
+ a.for_index(inc(),1);
+ const_cast<const T&>(a).for_index(n,2);
+ const_cast<const T&>(a).for_index(s_nil(),0);
+
+ DLIB_TEST(a.get<0>() == b.get<0>()+1);
+ DLIB_TEST(a.get<1>() == b.get<1>()+1);
+ DLIB_TEST(a.get<2>() == b.get<2>()+0);
+
+ swap(a,b);
+
+ DLIB_TEST(b.get<0>() == a.get<0>()+1);
+ DLIB_TEST(b.get<1>() == a.get<1>()+1);
+ DLIB_TEST(b.get<2>() == a.get<2>()+0);
+ }
+
+
+ }
+
+
+
+
+ class tuple_tester : public tester
+ {
+ public:
+ tuple_tester (
+ ) :
+ tester ("test_tuple",
+ "Runs tests on the tuple object")
+ {}
+
+ void perform_test (
+ )
+ {
+ tuple_test();
+ }
+ } a;
+
+}
+
+
+
diff --git a/ml/dlib/dlib/test/type_safe_union.cpp b/ml/dlib/dlib/test/type_safe_union.cpp
new file mode 100644
index 000000000..6a18fa8e1
--- /dev/null
+++ b/ml/dlib/dlib/test/type_safe_union.cpp
@@ -0,0 +1,455 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/type_safe_union.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.type_safe_union");
+
+ struct can_not_copy: noncopyable {};
+ void serialize(const can_not_copy&, std::ostream&) {}
+ void deserialize(can_not_copy&, std::istream&) {}
+
+ void swap(can_not_copy&, can_not_copy&) {}
+
+ class test
+ {
+
+ private:
+
+ enum kind
+ {
+ FLOAT, DOUBLE, CHAR, STRING, NONE
+ };
+
+ void operator() (float val)
+ {
+ DLIB_TEST(val == f_val);
+ last_kind = FLOAT;
+ }
+
+ void operator() (double val)
+ {
+ DLIB_TEST(val == d_val);
+ last_kind = DOUBLE;
+ }
+
+ void operator() (char val)
+ {
+ DLIB_TEST(val == c_val);
+ last_kind = CHAR;
+ }
+
+ void operator()(std::string& val)
+ {
+ DLIB_TEST(val == s_val);
+ last_kind = STRING;
+ }
+
+ void operator()(const std::string& val)
+ {
+ DLIB_TEST(val == s_val);
+ last_kind = STRING;
+ }
+
+ // ------------------------------
+
+ friend class type_safe_union<float, double, char, std::string>;
+ typedef type_safe_union<float, double, char, std::string> tsu;
+ tsu a, b, c;
+
+ float f_val;
+ double d_val;
+ char c_val;
+ std::string s_val;
+
+ kind last_kind;
+
+ public:
+ void test_stuff()
+ {
+ DLIB_TEST(a.is_empty() == true);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(a.contains<std::string>() == false);
+ DLIB_TEST(a.contains<long>() == false);
+
+ DLIB_TEST(a.get_type_id<int>() == -1);
+ DLIB_TEST(a.get_type_id<float>() == 1);
+ DLIB_TEST(a.get_type_id<double>() == 2);
+ DLIB_TEST(a.get_type_id<char>() == 3);
+ DLIB_TEST(a.get_type_id<std::string>() == 4);
+ DLIB_TEST(a.get_type_id<tsu>() == -1);
+
+
+ f_val = 4.345f;
+ a.get<float>() = f_val;
+ DLIB_TEST(a.cast_to<float>() == f_val);
+ DLIB_TEST(const_cast<const tsu&>(a).cast_to<float>() == f_val);
+ bool exception_thrown = false;
+ try {a.cast_to<char>(); }
+ catch (bad_type_safe_union_cast&) { exception_thrown = true;}
+ DLIB_TEST(exception_thrown);
+
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(a.contains<float>() == true);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(a.contains<std::string>() == false);
+ DLIB_TEST(a.contains<long>() == false);
+
+
+ last_kind = NONE;
+ const_cast<const tsu&>(a).apply_to_contents(*this);
+ DLIB_TEST(last_kind == FLOAT);
+
+ // -----------
+
+ d_val = 4.345;
+ a.get<double>() = d_val;
+ last_kind = NONE;
+ a.apply_to_contents(*this);
+ DLIB_TEST(last_kind == DOUBLE);
+
+ // -----------
+
+ c_val = 'a';
+ a.get<char>() = c_val;
+ last_kind = NONE;
+ const_cast<const tsu&>(a).apply_to_contents(*this);
+ DLIB_TEST(last_kind == CHAR);
+
+ // -----------
+
+ s_val = "test string";
+ a.get<std::string>() = s_val;
+ last_kind = NONE;
+ a.apply_to_contents(*this);
+ DLIB_TEST(last_kind == STRING);
+
+ DLIB_TEST(a.cast_to<std::string>() == s_val);
+ exception_thrown = false;
+ try {a.cast_to<float>(); }
+ catch (bad_type_safe_union_cast&) { exception_thrown = true;}
+ DLIB_TEST(exception_thrown);
+
+ // -----------
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(a.contains<std::string>() == true);
+ DLIB_TEST(a.contains<long>() == false);
+ // -----------
+
+ a.swap(b);
+
+ DLIB_TEST(a.is_empty() == true);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(a.contains<std::string>() == false);
+ DLIB_TEST(a.contains<long>() == false);
+
+ DLIB_TEST(b.is_empty() == false);
+ DLIB_TEST(b.contains<char>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<double>() == false);
+ DLIB_TEST(b.contains<std::string>() == true);
+ DLIB_TEST(b.contains<long>() == false);
+
+
+ last_kind = NONE;
+ b.apply_to_contents(*this);
+ DLIB_TEST(last_kind == STRING);
+
+ // -----------
+
+ b.swap(a);
+
+ DLIB_TEST(b.is_empty() == true);
+ DLIB_TEST(b.contains<char>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<double>() == false);
+ DLIB_TEST(b.contains<std::string>() == false);
+ DLIB_TEST(b.contains<long>() == false);
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<double>() == false);
+ DLIB_TEST(a.contains<std::string>() == true);
+ DLIB_TEST(a.contains<long>() == false);
+
+
+ last_kind = NONE;
+ a.apply_to_contents(*this);
+ DLIB_TEST(last_kind == STRING);
+ last_kind = NONE;
+ b.apply_to_contents(*this);
+ DLIB_TEST(last_kind == NONE);
+
+
+ a.get<char>() = 'a';
+ b.get<char>() = 'b';
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(b.is_empty() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+
+
+ DLIB_TEST(a.get<char>() == 'a');
+ DLIB_TEST(b.get<char>() == 'b');
+
+ swap(a,b);
+
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(b.is_empty() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+
+ DLIB_TEST(a.get<char>() == 'b');
+ DLIB_TEST(b.get<char>() == 'a');
+
+ // -----------
+
+ a.get<char>() = 'a';
+ b.get<std::string>() = "a string";
+
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(b.is_empty() == false);
+ DLIB_TEST(b.contains<char>() == false);
+ DLIB_TEST(a.contains<std::string>() == false);
+ DLIB_TEST(b.contains<std::string>() == true);
+
+
+ DLIB_TEST(a.get<char>() == 'a');
+ DLIB_TEST(b.get<std::string>() == "a string");
+
+ swap(a,b);
+
+ DLIB_TEST(b.is_empty() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(a.is_empty() == false);
+ DLIB_TEST(a.contains<char>() == false);
+ DLIB_TEST(b.contains<std::string>() == false);
+ DLIB_TEST(a.contains<std::string>() == true);
+
+
+ DLIB_TEST(b.get<char>() == 'a');
+ DLIB_TEST(a.get<std::string>() == "a string");
+
+
+
+
+ {
+ type_safe_union<char, float, std::string> a, b, empty_union;
+
+ ostringstream sout;
+ istringstream sin;
+
+ a.get<char>() = 'd';
+
+ serialize(a, sout);
+
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(b.get<char>() == 'd');
+
+ DLIB_TEST(a.contains<int>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(a.get<char>() == 'd');
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+ a.get<std::string>() = "davis";
+
+ serialize(a, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<std::string>() == true);
+ DLIB_TEST(b.get<std::string>() == "davis");
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+ serialize(empty_union, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.is_empty() == true);
+
+ }
+
+ {
+ type_safe_union<char, float, std::string> a, b, empty_union;
+
+ ostringstream sout;
+ istringstream sin;
+
+ a = 'd';
+
+ serialize(a, sout);
+
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(b.get<char>() == 'd');
+
+ DLIB_TEST(a.contains<int>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(a.get<char>() == 'd');
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+ a = std::string("davis");
+
+ serialize(a, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<std::string>() == true);
+ DLIB_TEST(b.get<std::string>() == "davis");
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+ serialize(empty_union, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.is_empty() == true);
+
+ }
+
+ {
+ typedef type_safe_union<char, float, std::string, can_not_copy> tsu_type;
+ tsu_type a('d'), aa(std::string("davis")), b, empty_union;
+
+ ostringstream sout;
+ istringstream sin;
+
+
+ serialize(a, sout);
+
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<char>() == true);
+ DLIB_TEST(b.get<char>() == 'd');
+
+ DLIB_TEST(a.contains<int>() == false);
+ DLIB_TEST(a.contains<float>() == false);
+ DLIB_TEST(a.contains<char>() == true);
+ DLIB_TEST(a.get<char>() == 'd');
+
+ DLIB_TEST(aa.contains<int>() == false);
+ DLIB_TEST(aa.contains<float>() == false);
+ DLIB_TEST(aa.contains<char>() == false);
+ DLIB_TEST(aa.contains<std::string>() == true);
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+
+ serialize(aa, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+
+ DLIB_TEST(b.contains<int>() == false);
+ DLIB_TEST(b.contains<float>() == false);
+ DLIB_TEST(b.contains<std::string>() == true);
+ DLIB_TEST(b.get<std::string>() == "davis");
+
+ sin.clear();
+ sout.clear();
+ sout.str("");
+
+ serialize(empty_union, sout);
+ sin.str(sout.str());
+ deserialize(b, sin);
+
+ DLIB_TEST(b.is_empty() == true);
+
+ a.get<can_not_copy>();
+ DLIB_TEST(a.contains<can_not_copy>() == true);
+
+ }
+ }
+
+ };
+
+
+
+ class type_safe_union_tester : public tester
+ {
+ public:
+ type_safe_union_tester (
+ ) :
+ tester ("test_type_safe_union",
+ "Runs tests on the type_safe_union object")
+ {}
+
+ void perform_test (
+ )
+ {
+ for (int i = 0; i < 10; ++i)
+ {
+ test a;
+ a.test_stuff();
+ }
+ }
+ } a;
+
+}
+
+
+
+
diff --git a/ml/dlib/dlib/test/vectorstream.cpp b/ml/dlib/dlib/test/vectorstream.cpp
new file mode 100644
index 000000000..a955961a2
--- /dev/null
+++ b/ml/dlib/dlib/test/vectorstream.cpp
@@ -0,0 +1,142 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <dlib/vectorstream.h>
+
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <vector>
+
+#include "tester.h"
+
+namespace
+{
+
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+
+ logger dlog("test.vectorstream");
+
+// ----------------------------------------------------------------------------------------
+
+ void test1()
+ {
+ print_spinner();
+
+ std::vector<char> buf;
+ vectorstream s(buf);
+
+ for (int i = -1000; i <= 1000; ++i)
+ {
+ char ch = i;
+ s.put(ch);
+ }
+
+ DLIB_TEST(buf.size() == 2001);
+
+ int cnt = -1000;
+ for (unsigned long i = 0; i < buf.size(); ++i)
+ {
+ char ch = cnt;
+ DLIB_TEST(buf[i] == ch);
+ ++cnt;
+ }
+
+ for (int i = -1000; i <= 1000; ++i)
+ {
+ DLIB_TEST(s.peek() != EOF);
+ char ch1 = i;
+ char ch2 = s.get();
+ DLIB_TEST(ch1 == ch2);
+ }
+
+ DLIB_TEST(s.peek() == EOF);
+ DLIB_TEST(s.get() == EOF);
+
+ s.clear();
+ s.seekg(6);
+
+ for (int i = -1000+6; i <= 1000; ++i)
+ {
+ DLIB_TEST(s.peek() != EOF);
+ char ch1 = i;
+ char ch2 = s.get();
+ DLIB_TEST(ch1 == ch2);
+ }
+
+ DLIB_TEST(s.peek() == EOF);
+ DLIB_TEST(s.get() == EOF);
+
+ std::string temp;
+ temp = "one two three!";
+
+ s.seekg(0);
+ buf.clear();
+ s.clear();
+
+ serialize(temp, s);
+ std::string temp2;
+ deserialize(temp2, s);
+ DLIB_TEST(temp2 == temp);
+
+ s.put('1');
+ s.put('2');
+ s.put('3');
+ s.put('4');
+ DLIB_TEST(s.get() == '1');
+ DLIB_TEST(s.get() == '2');
+ DLIB_TEST(s.get() == '3');
+ DLIB_TEST(s.get() == '4');
+
+ s.putback('4');
+ DLIB_TEST(s.get() == '4');
+ s.putback('4');
+ s.putback('3');
+ s.putback('2');
+ s.putback('1');
+ DLIB_TEST(s.get() == '1');
+ DLIB_TEST(s.get() == '2');
+ DLIB_TEST(s.get() == '3');
+ DLIB_TEST(s.get() == '4');
+ DLIB_TEST(s.good() == true);
+ DLIB_TEST(s.get() == EOF);
+ DLIB_TEST(s.good() == false);
+
+ // make sure seeking to a crazy offset doesn't mess things up
+ s.clear();
+ s.seekg(1000000);
+ DLIB_TEST(s.get() == EOF);
+ DLIB_TEST(s.good() == false);
+ s.clear();
+ s.seekg(1000000);
+ char sbuf[100];
+ s.read(sbuf, sizeof(sbuf));
+ DLIB_TEST(s.good() == false);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_vectorstream : public tester
+ {
+ public:
+ test_vectorstream (
+ ) :
+ tester ("test_vectorstream",
+ "Runs tests on the vectorstream component.")
+ {}
+
+ void perform_test (
+ )
+ {
+ test1();
+ }
+ } a;
+
+}
+
+
diff --git a/ml/dlib/dlib/test_for_odr_violations.cpp b/ml/dlib/dlib/test_for_odr_violations.cpp
new file mode 100644
index 000000000..fcf785995
--- /dev/null
+++ b/ml/dlib/dlib/test_for_odr_violations.cpp
@@ -0,0 +1,47 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_
+#define DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_
+
+#include "test_for_odr_violations.h"
+
+extern "C"
+{
+// The point of this block of code is to cause a link time error that will prevent a user
+// from compiling part of their application with DLIB_ASSERT enabled and part with them
+// disabled since doing that would be a violation of C++'s one definition rule.
+#ifdef ENABLE_ASSERTS
+ const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1 = 0;
+#else
+ const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_ = 0;
+#endif
+
+
+// The point of this block of code is to cause a link time error if someone builds dlib via
+// cmake as a separately installable library, and therefore generates a dlib/config.h from
+// cmake, but then proceeds to use the default unconfigured dlib/config.h from version
+// control. It should be obvious why this is bad, if it isn't you need to read a book
+// about C++. Moreover, it can only happen if someone manually copies files around and
+// messes things up. If instead they run `make install` or `cmake --build . --target
+// install` things will be setup correctly, which is what they should do. To summarize: DO
+// NOT BUILD A STANDALONE DLIB AND THEN GO CHERRY PICKING FILES FROM THE BUILD FOLDER AND
+// MIXING THEM WITH THE SOURCE FROM GITHUB. USE CMAKE'S INSTALL SCRIPTS TO INSTALL DLIB.
+// Or even better, don't install dlib at all and instead build your program as shown in
+// examples/CMakeLists.txt
+#if defined(DLIB_NOT_CONFIGURED) && !defined(DLIB__CMAKE_GENERATED_A_CONFIG_H_FILE)
+ const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2 = 0;
+#endif
+
+
+
+
+
+#ifdef DLIB_CHECK_FOR_VERSION_MISMATCH
+ const int DLIB_CHECK_FOR_VERSION_MISMATCH = 0;
+#endif
+
+}
+
+
+#endif // DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_
+
diff --git a/ml/dlib/dlib/test_for_odr_violations.h b/ml/dlib/dlib/test_for_odr_violations.h
new file mode 100644
index 000000000..5fa5111ba
--- /dev/null
+++ b/ml/dlib/dlib/test_for_odr_violations.h
@@ -0,0 +1,57 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TEST_FOR_ODR_VIOLATIONS_H_
+#define DLIB_TEST_FOR_ODR_VIOLATIONS_H_
+
+#include "assert.h"
+#include "config.h"
+
+extern "C"
+{
+// =========================>>> WHY YOU ARE GETTING AN ERROR HERE <<<=========================
+// The point of this block of code is to cause a link time error that will prevent a user
+// from compiling part of their application with DLIB_ASSERT enabled and part with it
+// disabled since doing that would be a violation of C++'s one definition rule. So if you
+// are getting an error here then you are either not enabling DLIB_ASSERT consistently
+// (e.g. by compiling part of your program in a debug mode and part in a release mode) or
+// you have simply forgotten to compile dlib/all/source.cpp into your application.
+// =========================>>> WHY YOU ARE GETTING AN ERROR HERE <<<=========================
+#ifdef ENABLE_ASSERTS
+ const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1;
+ const int DLIB_NO_WARN_UNUSED dlib_check_assert_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1;
+#else
+ const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_;
+ const int DLIB_NO_WARN_UNUSED dlib_check_assert_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_;
+#endif
+
+
+
+// The point of this block of code is to cause a link time error if someone builds dlib via
+// cmake as a separately installable library, and therefore generates a dlib/config.h from
+// cmake, but then proceeds to use the default unconfigured dlib/config.h from version
+// control. It should be obvious why this is bad, if it isn't you need to read a book
+// about C++. Moreover, it can only happen if someone manually copies files around and
+// messes things up. If instead they run `make install` or `cmake --build . --target
+// install` things will be setup correctly, which is what they should do. To summarize: DO
+// NOT BUILD A STANDALONE DLIB AND THEN GO CHERRY PICKING FILES FROM THE BUILD FOLDER AND
+// MIXING THEM WITH THE SOURCE FROM GITHUB. USE CMAKE'S INSTALL SCRIPTS TO INSTALL DLIB.
+// Or even better, don't install dlib at all and instead build your program as shown in
+// examples/CMakeLists.txt
+#if defined(DLIB_NOT_CONFIGURED) && !defined(DLIB__CMAKE_GENERATED_A_CONFIG_H_FILE)
+ const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2;
+ const int DLIB_NO_WARN_UNUSED dlib_check_not_configured_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2;
+#endif
+
+
+
+// Cause the user to get a linker error if they try to use header files from one version of
+// dlib with the compiled binary from a different version of dlib.
+#ifdef DLIB_CHECK_FOR_VERSION_MISMATCH
+ const extern int DLIB_CHECK_FOR_VERSION_MISMATCH;
+ const int DLIB_NO_WARN_UNUSED dlib_check_for_version_mismatch = DLIB_CHECK_FOR_VERSION_MISMATCH;
+#endif
+
+}
+
+#endif // DLIB_TEST_FOR_ODR_VIOLATIONS_H_
+
diff --git a/ml/dlib/dlib/threads.h b/ml/dlib/dlib/threads.h
new file mode 100644
index 000000000..371a317e0
--- /dev/null
+++ b/ml/dlib/dlib/threads.h
@@ -0,0 +1,28 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifdef DLIB_ALL_SOURCE_END
+#include "dlib_basic_cpp_build_tutorial.txt"
+#endif
+
+#ifndef DLIB_THREADs_
+#define DLIB_THREADs_
+
+#include "threads/threads_kernel.h"
+
+#include "threads/auto_mutex_extension.h"
+#include "threads/auto_unlock_extension.h"
+#include "threads/create_new_thread_extension.h"
+#include "threads/multithreaded_object_extension.h"
+#include "threads/rmutex_extension.h"
+#include "threads/rsignaler_extension.h"
+#include "threads/threaded_object_extension.h"
+#include "threads/thread_specific_data_extension.h"
+#include "threads/thread_function_extension.h"
+#include "threads/thread_pool_extension.h"
+#include "threads/read_write_mutex_extension.h"
+#include "threads/parallel_for_extension.h"
+#include "threads/async.h"
+
+#endif // DLIB_THREADs_
+
diff --git a/ml/dlib/dlib/threads/async.cpp b/ml/dlib/dlib/threads/async.cpp
new file mode 100644
index 000000000..6aa947bcb
--- /dev/null
+++ b/ml/dlib/dlib/threads/async.cpp
@@ -0,0 +1,48 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AsYNC_CPP_
+#define DLIB_AsYNC_CPP_
+
+// C++11 things don't work in old versions of visual studio
+#if !defined( _MSC_VER) || _MSC_VER >= 1900
+
+#include "async.h"
+#include <stdlib.h>
+#include "../string.h"
+#include <thread>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ unsigned long default_num_threads()
+ {
+ try
+ {
+ char* nt = getenv("DLIB_NUM_THREADS");
+ if (nt)
+ return string_cast<unsigned long>(nt);
+ } catch(string_cast_error&) {}
+ return std::thread::hardware_concurrency();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool& default_thread_pool()
+ {
+ static thread_pool tp(impl::default_num_threads());
+ return tp;
+ }
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif
+
+#endif // DLIB_AsYNC_CPP_
+
+
diff --git a/ml/dlib/dlib/threads/async.h b/ml/dlib/dlib/threads/async.h
new file mode 100644
index 000000000..bc6fe5575
--- /dev/null
+++ b/ml/dlib/dlib/threads/async.h
@@ -0,0 +1,105 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AsYNC_Hh_
+#define DLIB_AsYNC_Hh_
+
+// C++11 things don't work in old versions of visual studio
+#if !defined( _MSC_VER) || _MSC_VER >= 1900
+
+#include "async_abstract.h"
+#include "thread_pool_extension.h"
+#include <future>
+#include <functional>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T> struct selector {};
+
+ template <typename T, typename U, typename V>
+ void call_prom_set_value(
+ T& prom,
+ U& fun,
+ selector<V>
+ )
+ {
+ prom.set_value(fun());
+ }
+
+ template <typename T, typename U>
+ void call_prom_set_value(
+ T& prom,
+ U& fun,
+ selector<void>
+ )
+ {
+ fun();
+ prom.set_value();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool& default_thread_pool();
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Function,
+ typename ...Args
+ >
+ std::future<typename std::result_of<Function(Args...)>::type> async(
+ thread_pool& tp,
+ Function&& f,
+ Args&&... args
+ )
+ {
+ auto prom = std::make_shared<std::promise<typename std::result_of<Function(Args...)>::type>>();
+ std::future<typename std::result_of<Function(Args...)>::type> ret = prom->get_future();
+ using bind_t = decltype(std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
+ auto fun = std::make_shared<bind_t>(std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
+ tp.add_task_by_value([fun, prom]()
+ {
+ try
+ {
+ impl::call_prom_set_value(*prom, *fun, impl::selector<typename std::result_of<Function(Args...)>::type>());
+ }
+ catch(...)
+ {
+ prom->set_exception(std::current_exception());
+ }
+ });
+ return std::move(ret);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Function,
+ typename ...Args
+ >
+ std::future<typename std::result_of<Function(Args...)>::type> async(
+ Function&& f,
+ Args&&... args
+ )
+ {
+ return async(default_thread_pool(), std::forward<Function>(f), std::forward<Args>(args)...);
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "async.cpp"
+#endif
+
+#endif
+#endif // DLIB_AsYNC_Hh_
+
+
+
diff --git a/ml/dlib/dlib/threads/async_abstract.h b/ml/dlib/dlib/threads/async_abstract.h
new file mode 100644
index 000000000..a9fa1e458
--- /dev/null
+++ b/ml/dlib/dlib/threads/async_abstract.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AsYNC_ABSTRACT_Hh_
+#ifdef DLIB_AsYNC_ABSTRACT_Hh_
+
+#include "thread_pool_extension_abstract.h"
+#include <future>
+#include <functional>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool& default_thread_pool(
+ );
+ /*!
+ ensures
+ - returns a reference to a global thread_pool. If the DLIB_NUM_THREADS
+ environment variable is set to an integer then the thread pool will contain
+ DLIB_NUM_THREADS threads, otherwise it will contain
+ std::thread::hardware_concurrency() threads.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Function,
+ typename ...Args
+ >
+ std::future<typename std::result_of<Function(Args...)>::type> async(
+ thread_pool& tp,
+ Function&& f,
+ Args&&... args
+ );
+ /*!
+ requires
+ - f must be a function and f(args...) must be a valid expression.
+ ensures
+ - This function behaves just like std::async(std::launch::async, f, args)
+ except that instead of spawning a new thread to process each task it submits
+ the task to the provided dlib::thread_pool. Therefore, dlib::async() is
+ guaranteed to use a bounded number of threads unlike std::async(). This also
+ means that calls to dlib::async() will block if there aren't any free threads
+ in the thread pool.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename Function,
+ typename ...Args
+ >
+ std::future<typename std::result_of<Function(Args...)>::type> async(
+ Function&& f,
+ Args&&... args
+ );
+ /*!
+ ensures
+ - Calling this function is equivalent to directly calling async(default_thread_pool(), f, args...)
+ !*/
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_AsYNC_ABSTRACT_Hh_
+
diff --git a/ml/dlib/dlib/threads/auto_mutex_extension.h b/ml/dlib/dlib/threads/auto_mutex_extension.h
new file mode 100644
index 000000000..595c1b176
--- /dev/null
+++ b/ml/dlib/dlib/threads/auto_mutex_extension.h
@@ -0,0 +1,180 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AUTO_MUTEX_EXTENSIOn_
+#define DLIB_AUTO_MUTEX_EXTENSIOn_
+
+#include "threads_kernel.h"
+#include "rmutex_extension.h"
+#include "read_write_mutex_extension.h"
+#include "auto_mutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_mutex
+ {
+ /*!
+ INITIAL VALUE
+ - if (m != 0) then
+ - the mutex pointed to by m is locked
+ - if (r != 0) then
+ - the mutex pointed to by r is locked
+ - if (rw != 0) then
+ - the mutex pointed to by rw is locked
+ - exactly one of r, m, or rw is not 0.
+
+ CONVENTION
+ - if (m != 0) then
+ - the mutex pointed to by m is locked
+ - if (r != 0) then
+ - the mutex pointed to by r is locked
+ - if (rw != 0) then
+ - the mutex pointed to by rw is locked
+ - exactly one of r, m, or rw is not 0.
+ !*/
+ public:
+
+ explicit auto_mutex (
+ const mutex& m_
+ ) : m(&m_),
+ r(0),
+ rw(0)
+ {
+ m->lock();
+ }
+
+ explicit auto_mutex (
+ const rmutex& r_
+ ) : m(0),
+ r(&r_),
+ rw(0)
+ {
+ r->lock();
+ }
+
+ explicit auto_mutex (
+ const read_write_mutex& rw_
+ ) : m(0),
+ r(0),
+ rw(&rw_)
+ {
+ rw->lock();
+ }
+
+ void unlock()
+ {
+ if (m != 0)
+ {
+ m->unlock();
+ m = 0;
+ }
+ else if (r != 0)
+ {
+ r->unlock();
+ r = 0;
+ }
+ else if (rw != 0)
+ {
+ rw->unlock();
+ rw = 0;
+ }
+ }
+
+ ~auto_mutex (
+ )
+ {
+ unlock();
+ }
+
+ private:
+
+ const mutex* m;
+ const rmutex* r;
+ const read_write_mutex* rw;
+
+ // restricted functions
+ auto_mutex(auto_mutex&); // copy constructor
+ auto_mutex& operator=(auto_mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_mutex_readonly
+ {
+ public:
+
+ explicit auto_mutex_readonly (
+ const read_write_mutex& rw_
+ ) : rw(rw_), _has_write_lock(false), _has_read_lock(true)
+ {
+ rw.lock_readonly();
+ }
+
+ ~auto_mutex_readonly (
+ )
+ {
+ unlock();
+ }
+
+ void lock_readonly (
+ )
+ {
+ if (!_has_read_lock)
+ {
+ unlock();
+ rw.lock_readonly();
+ _has_read_lock = true;
+ }
+ }
+
+ void lock_write (
+ )
+ {
+ if (!_has_write_lock)
+ {
+ unlock();
+ rw.lock();
+ _has_write_lock = true;
+ }
+ }
+
+ void unlock (
+ )
+ {
+ if (_has_write_lock)
+ {
+ rw.unlock();
+ _has_write_lock = false;
+ }
+ else if (_has_read_lock)
+ {
+ rw.unlock_readonly();
+ _has_read_lock = false;
+ }
+ }
+
+ bool has_read_lock (
+ ) { return _has_read_lock; }
+
+ bool has_write_lock (
+ ) { return _has_write_lock; }
+
+ private:
+
+ const read_write_mutex& rw;
+ bool _has_write_lock;
+ bool _has_read_lock;
+
+ // restricted functions
+ auto_mutex_readonly(auto_mutex_readonly&); // copy constructor
+ auto_mutex_readonly& operator=(auto_mutex_readonly&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AUTO_MUTEX_EXTENSIOn_
+
diff --git a/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h b/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h
new file mode 100644
index 000000000..1990c834e
--- /dev/null
+++ b/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h
@@ -0,0 +1,185 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+#include "rmutex_extension_abstract.h"
+#include "read_write_mutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_mutex
+ {
+ /*!
+ INITIAL VALUE
+ The mutex given in the constructor is locked and associated with this
+ object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mechanism for automatically locking and unlocking
+ a mutex object.
+ !*/
+ public:
+
+ explicit auto_mutex (
+ const mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - m will be locked
+ !*/
+
+ explicit auto_mutex (
+ const rmutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - m will be locked
+ !*/
+
+ explicit auto_mutex (
+ const read_write_mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - m will be locked via m.lock() (i.e. a write lock will be obtained)
+ !*/
+
+ void unlock(
+ );
+ /*!
+ ensures
+ - if (unlock() has not already been called) then
+ - The mutex associated with *this has been unlocked. This is useful if
+ you want to unlock a mutex before the auto_mutex destructor executes.
+ !*/
+
+ ~auto_mutex (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ - calls unlock()
+ !*/
+
+ private:
+ // restricted functions
+ auto_mutex(auto_mutex&); // copy constructor
+ auto_mutex& operator=(auto_mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_mutex_readonly
+ {
+ /*!
+ INITIAL VALUE
+ The mutex given in the constructor is locked using a read-only lock and
+ associated with this object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mechanism for automatically locking and unlocking
+ a read_write_mutex object. In particular, a readonly lock is used.
+ !*/
+ public:
+
+ explicit auto_mutex_readonly (
+ const read_write_mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - a readonly lock will be obtained on m using m.lock_readonly()
+ - #has_read_lock() == true
+ !*/
+
+ ~auto_mutex_readonly (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ - the mutex associated with *this has been unlocked
+ !*/
+
+ bool has_read_lock (
+ );
+ /*!
+ ensures
+ - returns true if this object has called read_write_mutex::lock_readonly()
+ on its associated mutex and has yet to release that lock.
+ !*/
+
+ bool has_write_lock (
+ );
+ /*!
+ ensures
+ - returns true if this object has called read_write_mutex::lock() on its
+ associated mutex and has yet to release that lock.
+ !*/
+
+ void lock_readonly (
+ );
+ /*!
+ ensures
+ - This function converts the lock on the associated mutex into a readonly lock.
+ Specifically:
+ if (!has_read_lock()) then
+ - if (has_write_lock()) then
+ - unlocks the associated mutex and then relocks it by calling
+ read_write_mutex::lock_readonly()
+ - else
+ - locks the associated mutex by calling read_write_mutex::lock_readonly()
+ - #has_read_lock() == true
+ - Note that the lock switch is not atomic. This means that whatever
+ resource is protected by the mutex might have been modified during the
+ call to lock_readonly().
+ !*/
+
+ void lock_write (
+ );
+ /*!
+ ensures
+ - This function converts the lock on the associated mutex into a write lock.
+ Specifically:
+ if (!has_write_lock()) then
+ - if (has_read_lock()) then
+ - unlocks the associated mutex and then relocks it by calling
+ read_write_mutex::lock()
+ - else
+ - locks the associated mutex by calling read_write_mutex::lock()
+ - #has_write_lock() == true
+ - Note that the lock switch is not atomic. This means that whatever
+ resource is protected by the mutex might have been modified during the
+ call to lock_write().
+ !*/
+
+ void unlock (
+ );
+ /*!
+ ensures
+ - if (has_read_lock() || has_write_lock()) then
+ - unlocks the associated mutex. This is useful if you want to unlock a
+ mutex before the auto_mutex_readonly destructor executes.
+ - #has_read_lock() == false
+ - #has_write_lock() == false
+ !*/
+
+ private:
+ // restricted functions
+ auto_mutex_readonly(auto_mutex_readonly&); // copy constructor
+ auto_mutex_readonly& operator=(auto_mutex_readonly&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_
+
diff --git a/ml/dlib/dlib/threads/auto_unlock_extension.h b/ml/dlib/dlib/threads/auto_unlock_extension.h
new file mode 100644
index 000000000..cd1d4db9a
--- /dev/null
+++ b/ml/dlib/dlib/threads/auto_unlock_extension.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_AUTO_UNLOCK_EXTENSIOn_
+#define DLIB_AUTO_UNLOCK_EXTENSIOn_
+
+#include "threads_kernel.h"
+#include "rmutex_extension.h"
+#include "read_write_mutex_extension.h"
+#include "auto_unlock_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_unlock
+ {
+ /*!
+ INITIAL VALUE
+ - if (m != 0) then
+ - the mutex pointed to by m is locked
+ - if (r != 0) then
+ - the mutex pointed to by r is locked
+ - if (rw != 0) then
+ - the mutex pointed to by rw is locked
+ - exactly one of r, m, or rw is not 0.
+
+ CONVENTION
+ - if (m != 0) then
+ - the mutex pointed to by m is locked
+ - if (r != 0) then
+ - the mutex pointed to by r is locked
+ - if (rw != 0) then
+ - the mutex pointed to by rw is locked
+ - exactly one of r, m, or rw is not 0.
+ !*/
+ public:
+
+ explicit auto_unlock (
+ const mutex& m_
+ ) : m(&m_),
+ r(0),
+ rw(0)
+ {}
+
+ explicit auto_unlock (
+ const rmutex& r_
+ ) : m(0),
+ r(&r_),
+ rw(0)
+ {}
+
+ explicit auto_unlock (
+ const read_write_mutex& rw_
+ ) : m(0),
+ r(0),
+ rw(&rw_)
+ {}
+
+ ~auto_unlock (
+ )
+ {
+ if (m != 0)
+ m->unlock();
+ else if (r != 0)
+ r->unlock();
+ else
+ rw->unlock();
+ }
+
+ private:
+
+ const mutex* m;
+ const rmutex* r;
+ const read_write_mutex* rw;
+
+ // restricted functions
+ auto_unlock(auto_unlock&); // copy constructor
+ auto_unlock& operator=(auto_unlock&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_unlock_readonly
+ {
+
+ public:
+
+ explicit auto_unlock_readonly (
+ const read_write_mutex& rw_
+ ) :
+ rw(rw_)
+ {}
+
+ ~auto_unlock_readonly (
+ )
+ {
+ rw.unlock_readonly();
+ }
+
+ private:
+
+ const read_write_mutex& rw;
+
+ // restricted functions
+ auto_unlock_readonly(auto_unlock_readonly&); // copy constructor
+ auto_unlock_readonly& operator=(auto_unlock_readonly&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AUTO_UNLOCK_EXTENSIOn_
+
+
diff --git a/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h b/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h
new file mode 100644
index 000000000..f947d4879
--- /dev/null
+++ b/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+#include "rmutex_extension_abstract.h"
+#include "read_write_mutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_unlock
+ {
+ /*!
+ INITIAL VALUE
+ The mutex given in the constructor is associated with this object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mechanism for automatically unlocking
+ a mutex object. It is useful when you already have a locked mutex
+ and want to make sure it gets unlocked even if an exception is thrown
+ or you quit the function at a weird spot.
+ !*/
+ public:
+
+ explicit auto_unlock (
+ const mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - does not modify m in any way
+ !*/
+
+ explicit auto_unlock (
+ const rmutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - does not modify m in any way
+ !*/
+
+ explicit auto_unlock (
+ const read_write_mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - does not modify m in any way
+ !*/
+
+ ~auto_unlock (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ - calls unlock() on the mutex associated with *this
+ !*/
+
+ private:
+ // restricted functions
+ auto_unlock(auto_unlock&); // copy constructor
+ auto_unlock& operator=(auto_unlock&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class auto_unlock_readonly
+ {
+ /*!
+ INITIAL VALUE
+ The mutex given in the constructor is associated with this object.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mechanism for automatically unlocking
+ a read_write_mutex object. It is useful when you already have a locked mutex
+ and want to make sure it gets unlocked even if an exception is thrown
+ or you quit the function at a weird spot. Note that the mutex
+ is unlocked by calling unlock_readonly() on it.
+ !*/
+ public:
+
+ explicit auto_unlock_readonly (
+ const read_write_mutex& m
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - does not modify m in any way
+ !*/
+
+ ~auto_unlock_readonly (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ - calls unlock_readonly() on the mutex associated with *this
+ !*/
+
+ private:
+ // restricted functions
+ auto_unlock_readonly(auto_unlock_readonly&); // copy constructor
+ auto_unlock_readonly& operator=(auto_unlock_readonly&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/threads/create_new_thread_extension.h b/ml/dlib/dlib/threads/create_new_thread_extension.h
new file mode 100644
index 000000000..8f419b6be
--- /dev/null
+++ b/ml/dlib/dlib/threads/create_new_thread_extension.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CREATE_NEW_THREAD_EXTENSIOn_
+#define DLIB_CREATE_NEW_THREAD_EXTENSIOn_
+
+#include "threads_kernel_abstract.h"
+#include "create_new_thread_extension_abstract.h"
+#include "../threads.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ void (T::*funct)()
+ >
+ inline void dlib_create_new_thread_helper (
+ void* obj
+ )
+ {
+ T* o = static_cast<T*>(obj);
+ (o->*funct)();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ void (T::*funct)()
+ >
+ inline bool create_new_thread (
+ T& obj
+ )
+ {
+ return create_new_thread(dlib_create_new_thread_helper<T,funct>,&obj);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CREATE_NEW_THREAD_EXTENSIOn_
+
+
diff --git a/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h b/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h
new file mode 100644
index 000000000..43fbc474d
--- /dev/null
+++ b/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h
@@ -0,0 +1,33 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T,
+ void (T::*funct)()
+ >
+ bool create_new_thread (
+ T& obj
+ );
+ /*!
+ ensures
+ - creates a new thread and calls obj.*funct() from it.
+ - returns true upon success and false upon failure to create the new thread.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_
+
+
+
diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension.cpp b/ml/dlib/dlib/threads/multithreaded_object_extension.cpp
new file mode 100644
index 000000000..def4af5f2
--- /dev/null
+++ b/ml/dlib/dlib/threads/multithreaded_object_extension.cpp
@@ -0,0 +1,241 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP
+#define DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP
+
+#include "multithreaded_object_extension.h"
+#include "create_new_thread_extension.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ multithreaded_object::
+ multithreaded_object (
+ ):
+ s(m_),
+ is_running_(false),
+ should_stop_(false),
+ threads_started(0)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ multithreaded_object::
+ ~multithreaded_object (
+ )
+ {
+ try
+ {
+ DLIB_ASSERT(number_of_threads_alive() == 0,
+ "\tmultithreaded_object::~multithreaded_object()"
+ << "\n\tYou have let a multithreaded object destruct itself before terminating its threads"
+ << "\n\tthis: " << this
+ );
+ }
+ catch (std::exception& e)
+ {
+ std::cerr << e.what() << std::endl;
+ assert(false);
+ abort();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ clear (
+ )
+ {
+ auto_mutex M(m_);
+ stop();
+ wait();
+ dead_threads.clear();
+ is_running_ = false;
+ should_stop_ = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool multithreaded_object::
+ is_running (
+ ) const
+ {
+ auto_mutex M(m_);
+ return is_running_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long multithreaded_object::
+ number_of_threads_registered (
+ ) const
+ {
+ auto_mutex M(m_);
+ return thread_ids.size() + dead_threads.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long multithreaded_object::
+ number_of_threads_alive (
+ ) const
+ {
+ auto_mutex M(m_);
+ return threads_started;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ wait (
+ ) const
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(thread_ids.is_in_domain(get_thread_id()) == false,
+ "\tvoid multithreaded_object::wait()"
+ << "\n\tYou can NOT call this function from one of the threads registered in this object"
+ << "\n\tthis: " << this
+ );
+
+ while (threads_started > 0)
+ s.wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ start (
+ )
+ {
+ auto_mutex M(m_);
+ const unsigned long num_threads_registered = dead_threads.size() + thread_ids.size();
+ // start any dead threads
+ for (unsigned long i = threads_started; i < num_threads_registered; ++i)
+ {
+ if (create_new_thread<multithreaded_object,&multithreaded_object::thread_helper>(*this) == false)
+ {
+ should_stop_ = true;
+ is_running_ = false;
+ throw thread_error();
+ }
+ ++threads_started;
+ }
+ is_running_ = true;
+ should_stop_ = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ pause (
+ )
+ {
+ auto_mutex M(m_);
+ is_running_ = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ stop (
+ )
+ {
+ auto_mutex M(m_);
+ should_stop_ = true;
+ is_running_ = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool multithreaded_object::
+ should_stop (
+ ) const
+ {
+ auto_mutex M(m_);
+ DLIB_ASSERT(thread_ids.is_in_domain(get_thread_id()),
+ "\tbool multithreaded_object::should_stop()"
+ << "\n\tYou can only call this function from one of the registered threads in this object"
+ << "\n\tthis: " << this
+ );
+ while (is_running_ == false && should_stop_ == false)
+ s.wait();
+ return should_stop_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ multithreaded_object::raii_thread_helper::
+ raii_thread_helper(
+ multithreaded_object& self_,
+ thread_id_type id_
+ ) : self(self_), id(id_){}
+
+ multithreaded_object::raii_thread_helper::
+ ~raii_thread_helper()
+ {
+ auto_mutex M(self.m_);
+ if (self.thread_ids.is_in_domain(id))
+ {
+ mfp temp;
+ thread_id_type id_temp;
+ self.thread_ids.remove(id,id_temp,temp);
+ // put this thread's registered function back into the dead_threads queue
+ self.dead_threads.enqueue(temp);
+ }
+
+ --self.threads_started;
+ // If this is the last thread to terminate then
+ // signal that that is the case.
+ if (self.threads_started == 0)
+ {
+ self.is_running_ = false;
+ self.should_stop_ = false;
+ self.s.broadcast();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void multithreaded_object::
+ thread_helper(
+ )
+ {
+ mfp mf;
+ thread_id_type id = get_thread_id();
+
+ // this guy's destructor does all the necessary cleanup in this function
+ raii_thread_helper raii(*this, id);
+
+ // if there is a dead_thread sitting around then pull it
+ // out and put it into mf
+ {
+ auto_mutex M(m_);
+ if (dead_threads.size() > 0)
+ {
+ dead_threads.dequeue(mf);
+ mfp temp(mf);
+ thread_ids.add(id,temp);
+ }
+ }
+
+ if (mf.is_set())
+ {
+ // call the registered thread function
+ mf();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP
+
+
diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension.h b/ml/dlib/dlib/threads/multithreaded_object_extension.h
new file mode 100644
index 000000000..9dd37fdcc
--- /dev/null
+++ b/ml/dlib/dlib/threads/multithreaded_object_extension.h
@@ -0,0 +1,153 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_
+#define DLIB_MULTITHREADED_OBJECT_EXTENSIOn_
+
+#include "multithreaded_object_extension_abstract.h"
+#include "threads_kernel.h"
+#include "auto_mutex_extension.h"
+#include "rmutex_extension.h"
+#include "rsignaler_extension.h"
+#include "../algs.h"
+#include "../assert.h"
+#include "../map.h"
+#include "../member_function_pointer.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class multithreaded_object
+ {
+ /*!
+ INITIAL VALUE
+ - is_running_ == false
+ - should_stop_ == false
+ - thread_ids.size() == 0
+ - dead_threads.size() == 0
+ - threads_started == 0
+
+ CONVENTION
+ - number_of_threads_registered() == thread_ids.size() + dead_threads.size()
+ - number_of_threads_alive() == threads_started
+
+ - is_running() == is_running_
+ - should_stop() == should_stop_
+
+ - thread_ids == a map of current thread ids to the member function
+ pointers that that thread runs.
+ - threads_started == the number of threads that have been spawned to run
+ thread_helper but haven't ended yet.
+
+ - dead_threads == a queue that contains all the member function pointers
+ for threads that are currently registered but not running
+
+ - m_ == the mutex used to protect all our variables
+ - s == the signaler for m_
+ !*/
+
+ public:
+
+ multithreaded_object (
+ );
+
+ virtual ~multithreaded_object (
+ ) = 0;
+
+ void clear (
+ );
+
+ bool is_running (
+ ) const;
+
+ unsigned long number_of_threads_alive (
+ ) const;
+
+ unsigned long number_of_threads_registered (
+ ) const;
+
+ void wait (
+ ) const;
+
+ void start (
+ );
+
+ void pause (
+ );
+
+ void stop (
+ );
+
+ protected:
+
+ bool should_stop (
+ ) const;
+
+ template <
+ typename T
+ >
+ void register_thread (
+ T& object,
+ void (T::*thread)()
+ )
+ {
+ auto_mutex M(m_);
+ try
+ {
+ mfp mf;
+ mf.set(object,thread);
+ dead_threads.enqueue(mf);
+ if (is_running_)
+ start();
+ }
+ catch (...)
+ {
+ is_running_ = false;
+ should_stop_ = true;
+ s.broadcast();
+ throw;
+ }
+ }
+
+ private:
+
+ class raii_thread_helper
+ {
+ public:
+ raii_thread_helper(multithreaded_object& self_, thread_id_type id_);
+ ~raii_thread_helper();
+
+ multithreaded_object& self;
+ thread_id_type id;
+ };
+
+ void thread_helper(
+ );
+
+ typedef member_function_pointer<> mfp;
+
+ rmutex m_;
+ rsignaler s;
+ map<thread_id_type,mfp,memory_manager<char>::kernel_2a>::kernel_1a thread_ids;
+ queue<mfp,memory_manager<char>::kernel_2a>::kernel_1a dead_threads;
+
+ bool is_running_;
+ bool should_stop_;
+ unsigned long threads_started;
+
+ // restricted functions
+ multithreaded_object(multithreaded_object&); // copy constructor
+ multithreaded_object& operator=(multithreaded_object&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "multithreaded_object_extension.cpp"
+#endif
+
+#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_
+
diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h b/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h
new file mode 100644
index 000000000..e7862b78f
--- /dev/null
+++ b/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h
@@ -0,0 +1,186 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class multithreaded_object
+ {
+ /*!
+ INITIAL VALUE
+ - is_running() == false
+ - number_of_threads_alive() == 0
+ - number_of_threads_registered() == 0
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a multithreaded object. It is similar to
+ the threaded_object except it allows you to have many threads in a
+ single object rather than just one. To use it you inherit from it
+ and register the member functions in your new class that you want
+ to run in their own threads by calling register_thread(). Then when
+ you call start() it will spawn all the registered functions
+ in their own threads.
+ !*/
+
+ public:
+
+ multithreaded_object (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ virtual ~multithreaded_object (
+ ) = 0;
+ /*!
+ requires
+ - number_of_threads_alive() == 0
+ (i.e. in the destructor for the object you derive from this one you
+ must wait for all the threads to end.)
+ ensures
+ - all resources allocated by *this have been freed.
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - blocks until all threads have terminated
+ throws
+ - std::bad_alloc or dlib::thread_error
+ if an exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ bool is_running (
+ ) const;
+ /*!
+ ensures
+ - if (number_of_threads_alive() > 0 && the threads are currently supposed to be executing) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unsigned long number_of_threads_alive (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads that are currently alive (i.e.
+ the number of threads that have started but not yet terminated)
+ !*/
+
+ unsigned long number_of_threads_registered (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads that have been registered by
+ calls to register_thread()
+ !*/
+
+ void wait (
+ ) const;
+ /*!
+ requires
+ - is not called from one of this object's threads
+ ensures
+ - if (number_of_threads_alive() > 0) then
+ - blocks until all the threads in this object have terminated
+ (i.e. blocks until number_of_threads_alive() == 0)
+ !*/
+
+ void start (
+ );
+ /*!
+ ensures
+ - #number_of_threads_alive() == number_of_threads_registered()
+ - #is_running() == true
+ - #should_stop() == false
+ - all the threads registered are up and running.
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then
+ #is_running() == false and should_stop() == true
+ !*/
+
+ void pause (
+ );
+ /*!
+ ensures
+ - #is_running() == false
+ !*/
+
+ void stop (
+ );
+ /*!
+ ensures
+ - #should_stop() == true
+ - #is_running() == false
+ !*/
+
+ protected:
+
+ template <
+ typename T
+ >
+ void register_thread (
+ T& object,
+ void (T::*thread)()
+ );
+ /*!
+ requires
+ - (object.*thread)() forms a valid function call
+ - the thread function does not throw
+ ensures
+ - registers the member function pointed to by thread as one of the threads
+ that runs when is_running() == true
+ - #number_of_threads_registered() == number_of_threads_registered() + 1
+ - if (is_running() == true)
+ - spawns this new member function in its own thread
+ - #number_of_threads_alive() += number_of_threads_alive() + 1
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then
+ #is_running() == false and should_stop() == true
+ !*/
+
+ bool should_stop (
+ ) const;
+ /*!
+ requires
+ - is only called from one of the registered threads in this object
+ ensures
+ - if (is_running() == false && should_stop() == false) then
+ - blocks until (#is_running() == true || #should_stop() == true)
+ - if (this thread is supposed to terminate) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ private:
+
+ // restricted functions
+ multithreaded_object(multithreaded_object&); // copy constructor
+ multithreaded_object& operator=(multithreaded_object&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_
+
diff --git a/ml/dlib/dlib/threads/parallel_for_extension.h b/ml/dlib/dlib/threads/parallel_for_extension.h
new file mode 100644
index 000000000..60b64b1b4
--- /dev/null
+++ b/ml/dlib/dlib/threads/parallel_for_extension.h
@@ -0,0 +1,676 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_PARALLEL_FoR_Hh_
+#define DLIB_PARALLEL_FoR_Hh_
+
+#include "parallel_for_extension_abstract.h"
+#include "thread_pool_extension.h"
+#include "../console_progress_indicator.h"
+#include "async.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <typename T>
+ class helper_parallel_for
+ {
+ public:
+ helper_parallel_for (
+ T& obj_,
+ void (T::*funct_)(long)
+ ) :
+ obj(obj_),
+ funct(funct_)
+ {}
+
+ T& obj;
+ void (T::*funct)(long);
+
+ void process_block (long begin, long end)
+ {
+ for (long i = begin; i < end; ++i)
+ (obj.*funct)(i);
+ }
+ };
+
+ template <typename T>
+ class helper_parallel_for_funct
+ {
+ public:
+ helper_parallel_for_funct (
+ const T& funct_
+ ) : funct(funct_) {}
+
+ const T& funct;
+
+ void run(long i)
+ {
+ funct(i);
+ }
+ };
+
+ template <typename T>
+ class helper_parallel_for_funct2
+ {
+ public:
+ helper_parallel_for_funct2 (
+ const T& funct_
+ ) : funct(funct_) {}
+
+ const T& funct;
+
+ void run(long begin, long end)
+ {
+ funct(begin, end);
+ }
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long, long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ if (tp.num_threads_in_pool() != 0)
+ {
+ const long num = end-begin;
+ const long num_workers = static_cast<long>(tp.num_threads_in_pool());
+ // How many samples to process in a single task (aim for chunks_per_thread jobs per worker)
+ const long block_size = std::max(1L, num/(num_workers*chunks_per_thread));
+ for (long i = 0; i < num; i+=block_size)
+ {
+ tp.add_task(obj, funct, begin+i, begin+std::min(i+block_size, num));
+ }
+ tp.wait_for_all_tasks();
+ }
+ else
+ {
+ // Since there aren't any threads in the pool we might as well just invoke
+ // the function directly since that's all the thread_pool object would do.
+ // But doing it ourselves skips a mutex lock.
+ (obj.*funct)(begin, end);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long, long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ thread_pool tp(num_threads);
+ parallel_for_blocked(tp, begin, end, obj, funct, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::helper_parallel_for_funct2<T> helper(funct);
+ parallel_for_blocked(tp, begin, end, helper, &impl::helper_parallel_for_funct2<T>::run, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ thread_pool tp(num_threads);
+ parallel_for_blocked(tp, begin, end, funct, chunks_per_thread);
+ }
+
+ template <typename T>
+ void parallel_for_blocked (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::helper_parallel_for<T> helper(obj, funct);
+ parallel_for_blocked(tp, begin, end, helper, &impl::helper_parallel_for<T>::process_block, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ thread_pool tp(num_threads);
+ parallel_for(tp, begin, end, obj, funct, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::helper_parallel_for_funct<T> helper(funct);
+ parallel_for(tp, begin, end, helper, &impl::helper_parallel_for_funct<T>::run, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ thread_pool tp(num_threads);
+ parallel_for(tp, begin, end, funct, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <typename T>
+ class parfor_verbose_helper
+ {
+ public:
+ parfor_verbose_helper(T& obj_, void (T::*funct_)(long), long begin, long end) :
+ obj(obj_), funct(funct_), pbar(end-begin)
+ {
+ count = 0;
+ wrote_to_screen = pbar.print_status(0);
+ }
+
+ ~parfor_verbose_helper()
+ {
+ if (wrote_to_screen)
+ std::cout << std::endl;
+ }
+
+ mutable long count;
+ T& obj;
+ void (T::*funct)(long);
+ mutable console_progress_indicator pbar;
+ mutable bool wrote_to_screen;
+ mutex m;
+
+ void operator()(long i) const
+ {
+ (obj.*funct)(i);
+ {
+ auto_mutex lock(m);
+ wrote_to_screen = pbar.print_status(++count) || wrote_to_screen;
+ }
+ }
+
+ };
+
+ template <typename T>
+ class parfor_verbose_helper3
+ {
+ public:
+ parfor_verbose_helper3(T& obj_, void (T::*funct_)(long,long), long begin, long end) :
+ obj(obj_), funct(funct_), pbar(end-begin)
+ {
+ count = 0;
+ wrote_to_screen = pbar.print_status(0);
+ }
+
+ ~parfor_verbose_helper3()
+ {
+ if (wrote_to_screen)
+ std::cout << std::endl;
+ }
+
+ mutable long count;
+ T& obj;
+ void (T::*funct)(long,long);
+ mutable console_progress_indicator pbar;
+ mutable bool wrote_to_screen;
+ mutex m;
+
+ void operator()(long begin, long end) const
+ {
+ (obj.*funct)(begin, end);
+ {
+ auto_mutex lock(m);
+ count += end-begin;
+ wrote_to_screen = pbar.print_status(count) || wrote_to_screen;
+ }
+ }
+ };
+
+ template <typename T>
+ class parfor_verbose_helper2
+ {
+ public:
+ parfor_verbose_helper2(const T& obj_, long begin, long end) : obj(obj_), pbar(end-begin)
+ {
+ count = 0;
+ wrote_to_screen = pbar.print_status(0);
+ }
+
+ ~parfor_verbose_helper2()
+ {
+ if (wrote_to_screen)
+ std::cout << std::endl;
+ }
+
+ mutable long count;
+ const T& obj;
+ mutable console_progress_indicator pbar;
+ mutable bool wrote_to_screen;
+ mutex m;
+
+ void operator()(long i) const
+ {
+ obj(i);
+ {
+ auto_mutex lock(m);
+ wrote_to_screen = pbar.print_status(++count) || wrote_to_screen;
+ }
+ }
+
+ void operator()(long begin, long end) const
+ {
+ obj(begin, end);
+ {
+ auto_mutex lock(m);
+ count += end-begin;
+ wrote_to_screen = pbar.print_status(count) || wrote_to_screen;
+ }
+ }
+ };
+ }
+
+ template <typename T>
+ void parallel_for_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper<T> helper(obj, funct, begin, end);
+ parallel_for(tp, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper<T> helper(obj, funct, begin, end);
+ parallel_for(num_threads, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for(tp, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for(num_threads, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for(begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long,long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper3<T> helper(obj, funct, begin, end);
+ parallel_for_blocked(tp, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long,long),
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper3<T> helper(obj, funct, begin, end);
+ parallel_for_blocked(num_threads, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for_blocked(tp, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for_blocked(num_threads, begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(begin <= end && chunks_per_thread > 0,
+ "\t void parallel_for_blocked_verbose()"
+ << "\n\t Invalid inputs were given to this function"
+ << "\n\t begin: " << begin
+ << "\n\t end: " << end
+ << "\n\t chunks_per_thread: " << chunks_per_thread
+ );
+
+ impl::parfor_verbose_helper2<T> helper(funct, begin, end);
+ parallel_for_blocked(begin, end, helper, chunks_per_thread);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PARALLEL_FoR_Hh_
+
diff --git a/ml/dlib/dlib/threads/parallel_for_extension_abstract.h b/ml/dlib/dlib/threads/parallel_for_extension_abstract.h
new file mode 100644
index 000000000..ffd2e0c44
--- /dev/null
+++ b/ml/dlib/dlib/threads/parallel_for_extension_abstract.h
@@ -0,0 +1,469 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_PARALLEL_FoR_ABSTRACT_Hh_
+#ifdef DLIB_PARALLEL_FoR_ABSTRACT_Hh_
+
+#include "thread_pool_extension_abstract.h"
+#include "async_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long, long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This is a convenience function for submitting a block of jobs to a thread_pool.
+ In particular, given the half open range [begin, end), this function will
+ split the range into approximately tp.num_threads_in_pool()*chunks_per_thread
+ blocks, which it will then submit to the thread_pool. The given thread_pool
+ will then call (obj.*funct)() on each of the subranges.
+ - To be precise, suppose we have broken the range [begin, end) into the
+ following subranges:
+ - [begin[0], end[0])
+ - [begin[1], end[1])
+ - [begin[2], end[2])
+ ...
+ - [begin[n], end[n])
+ Then parallel_for_blocked() submits each of these subranges to tp for
+ processing such that (obj.*funct)(begin[i], end[i]) is invoked for all valid
+ values of i. Moreover, the subranges are non-overlapping and completely
+ cover the total range of [begin, end).
+ - This function will not perform any memory allocations or create any system
+ resources such as mutex objects.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long, long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ thread_pool tp(num_threads);
+ parallel_for_blocked(tp, begin, end, obj, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - chunks_per_thread > 0
+ - begin <= end
+ ensures
+ - This is a convenience function for submitting a block of jobs to a
+ thread_pool. In particular, given the range [begin, end), this function will
+ split the range into approximately tp.num_threads_in_pool()*chunks_per_thread
+ blocks, which it will then submit to the thread_pool. The given thread_pool
+ will then call funct() on each of the subranges.
+ - To be precise, suppose we have broken the range [begin, end) into the
+ following subranges:
+ - [begin[0], end[0])
+ - [begin[1], end[1])
+ - [begin[2], end[2])
+ ...
+ - [begin[n], end[n])
+ Then parallel_for_blocked() submits each of these subranges to tp for
+ processing such that funct(begin[i], end[i]) is invoked for all valid values
+ of i.
+ - This function will not perform any memory allocations or create any system
+ resources such as mutex objects.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ thread_pool tp(num_threads);
+ parallel_for_blocked(tp, begin, end, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following function call:
+ parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub)
+ {
+ for (long i = begin_sub; i < end_sub; ++i)
+ (obj.*funct)(i);
+ }, chunks_per_thread);
+ - Therefore, this routine invokes (obj.*funct)(i) for all i in the range
+ [begin, end). However, it does so using tp.num_threads_in_pool() parallel
+ threads.
+ - This function will not perform any memory allocations or create any system
+ resources such as mutex objects.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ thread_pool tp(num_threads);
+ parallel_for(tp, begin, end, obj, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following function call:
+ parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub)
+ {
+ for (long i = begin_sub; i < end_sub; ++i)
+ funct(i);
+ }, chunks_per_thread);
+ - Therefore, this routine invokes funct(i) for all i in the range [begin, end).
+ However, it does so using tp.num_threads_in_pool() parallel threads.
+ - This function will not perform any memory allocations or create any system
+ resources such as mutex objects.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ thread_pool tp(num_threads);
+ parallel_for(tp, begin, end, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is equivalent to the following block of code:
+ parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread);
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for() routine defined above except
+ that it will print messages to cout showing the progress in executing the
+ parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for() routine defined above except
+ that it will print messages to cout showing the progress in executing the
+ parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for() routine defined above except
+ that it will print messages to cout showing the progress in executing the
+ parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for() routine defined above except
+ that it will print messages to cout showing the progress in executing the
+ parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_verbose (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for() routine defined above except
+ that it will print messages to cout showing the progress in executing the
+ parallel for loop.
+ - It will also use the default_thread_pool().
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long,long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for_blocked() routine defined
+ above except that it will print messages to cout showing the progress in
+ executing the parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ T& obj,
+ void (T::*funct)(long,long),
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for_blocked() routine defined
+ above except that it will print messages to cout showing the progress in
+ executing the parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ thread_pool& tp,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for_blocked() routine defined
+ above except that it will print messages to cout showing the progress in
+ executing the parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ unsigned long num_threads,
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for_blocked() routine defined
+ above except that it will print messages to cout showing the progress in
+ executing the parallel for loop.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void parallel_for_blocked_verbose (
+ long begin,
+ long end,
+ const T& funct,
+ long chunks_per_thread = 8
+ );
+ /*!
+ requires
+ - begin <= end
+ - chunks_per_thread > 0
+ ensures
+ - This function is identical to the parallel_for_blocked() routine defined
+ above except that it will print messages to cout showing the progress in
+ executing the parallel for loop.
+ - It will also use the default_thread_pool()
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_PARALLEL_FoR_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/threads/posix.h b/ml/dlib/dlib/threads/posix.h
new file mode 100644
index 000000000..7226743e1
--- /dev/null
+++ b/ml/dlib/dlib/threads/posix.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEl_1_
+#include "threads_kernel_2.h"
+#endif
+
diff --git a/ml/dlib/dlib/threads/read_write_mutex_extension.h b/ml/dlib/dlib/threads/read_write_mutex_extension.h
new file mode 100644
index 000000000..20e5d5ed8
--- /dev/null
+++ b/ml/dlib/dlib/threads/read_write_mutex_extension.h
@@ -0,0 +1,177 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_READ_WRITE_MUTEX_EXTENSIOn_
+#define DLIB_READ_WRITE_MUTEX_EXTENSIOn_
+
+#include "threads_kernel.h"
+#include "read_write_mutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class read_write_mutex
+ {
+ /*!
+ INITIAL VALUE
+ - max_locks == defined by constructor
+ - available_locks == max_locks
+ - write_lock_in_progress == false
+ - write_lock_active == false
+
+ CONVENTION
+ - Each time someone gets a read only lock they take one of the "available locks"
+ and each write lock takes all possible locks (i.e. max_locks). The number of
+ available locks is recorded in available_locks. Any time you try to lock this
+ object and there aren't available locks you have to wait.
+
+ - max_locks == max_readonly_locks()
+
+ - if (some thread is on the process of obtaining a write lock) then
+ - write_lock_in_progress == true
+ - else
+ - write_lock_in_progress == false
+
+ - if (some thread currently has a write lock on this mutex) then
+ - write_lock_active == true
+ - else
+ - write_lock_active == false
+ !*/
+
+ public:
+
+ read_write_mutex (
+ ) : s(m),
+ max_locks(0xFFFFFFFF),
+ available_locks(max_locks),
+ write_lock_in_progress(false),
+ write_lock_active(false)
+ {}
+
+ explicit read_write_mutex (
+ unsigned long max_locks_
+ ) : s(m),
+ max_locks(max_locks_),
+ available_locks(max_locks_),
+ write_lock_in_progress(false),
+ write_lock_active(false)
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(max_locks > 0,
+ "\t read_write_mutex::read_write_mutex(max_locks)"
+ << "\n\t You must give a non-zero value for max_locks"
+ << "\n\t this: " << this
+ );
+ }
+
+ ~read_write_mutex (
+ )
+ {}
+
+ void lock (
+ ) const
+ {
+ m.lock();
+
+ // If another write lock is already in progress then wait for it to finish
+ // before we start trying to grab all the available locks. This way we
+ // don't end up fighting over the locks.
+ while (write_lock_in_progress)
+ s.wait();
+
+ // grab the right to perform a write lock
+ write_lock_in_progress = true;
+
+ // now start grabbing all the locks
+ unsigned long locks_obtained = available_locks;
+ available_locks = 0;
+ while (locks_obtained != max_locks)
+ {
+ s.wait();
+ locks_obtained += available_locks;
+ available_locks = 0;
+ }
+
+ write_lock_in_progress = false;
+ write_lock_active = true;
+
+ m.unlock();
+ }
+
+ void unlock (
+ ) const
+ {
+ m.lock();
+
+ // only do something if there really was a lock in place
+ if (write_lock_active)
+ {
+ available_locks = max_locks;
+ write_lock_active = false;
+ s.broadcast();
+ }
+
+ m.unlock();
+ }
+
+ void lock_readonly (
+ ) const
+ {
+ m.lock();
+
+ while (available_locks == 0)
+ s.wait();
+
+ --available_locks;
+
+ m.unlock();
+ }
+
+ void unlock_readonly (
+ ) const
+ {
+ m.lock();
+
+ // If this condition is false then it means there are no more readonly locks
+ // to free. So we don't do anything.
+ if (available_locks != max_locks && !write_lock_active)
+ {
+ ++available_locks;
+
+ // only perform broadcast when there is another thread that might be listening
+ if (available_locks == 1 || write_lock_in_progress)
+ {
+ s.broadcast();
+ }
+ }
+
+ m.unlock();
+ }
+
+ unsigned long max_readonly_locks (
+ ) const
+ {
+ return max_locks;
+ }
+
+ private:
+ mutex m;
+ signaler s;
+ const unsigned long max_locks;
+ mutable unsigned long available_locks;
+ mutable bool write_lock_in_progress;
+ mutable bool write_lock_active;
+
+ // restricted functions
+ read_write_mutex(read_write_mutex&); // copy constructor
+ read_write_mutex& operator=(read_write_mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_READ_WRITE_MUTEX_EXTENSIOn_
+
+
diff --git a/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h b/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h
new file mode 100644
index 000000000..18672b057
--- /dev/null
+++ b/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class read_write_mutex
+ {
+ /*!
+ INITIAL VALUE
+ read_write_mutex is in the fully unlocked state
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mutex intended to be used for synchronous
+ thread control of shared data. When a thread wants to access some
+ shared data it locks out other threads by calling lock() and calls
+ unlock() when it is finished.
+
+ This mutex also has the additional ability to distinguish between
+ a lock for the purposes of modifying some shared data, a write lock,
+ and a lock for the purposes of only reading shared data, a readonly
+ lock. The lock() and unlock() functions are used for write locks while
+ the lock_readonly() and unlock_readonly() are for readonly locks.
+
+ The difference between a readonly and write lock can be understood as
+ follows. The read_write_mutex will allow many threads to obtain simultaneous
+ readonly locks but will only allow a single thread to obtain a write lock.
+ Moreover, while the write lock is obtained no other threads are allowed
+ to have readonly locks.
+ !*/
+ public:
+
+ read_write_mutex (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - max_readonly_locks() == 0xFFFFFFFF
+ (i.e. about 4 billion)
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the read_write_mutex.
+ !*/
+
+ explicit read_write_mutex (
+ unsigned long max_locks
+ );
+ /*!
+ requires
+ - max_locks > 0
+ ensures
+ - #*this is properly initialized
+ - max_readonly_locks() == max_locks
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the read_write_mutex.
+ !*/
+
+ ~read_write_mutex (
+ );
+ /*!
+ requires
+ - *this is not locked
+ ensures
+ - all resources allocated by *this have been freed
+ !*/
+
+ void lock (
+ ) const;
+ /*!
+ requires
+ - The thread calling this function does not have any kind of lock on this
+ object
+ ensures
+ - if (there is any kind of lock on *this) then
+ - the calling thread is put to sleep until a write lock becomes available.
+ Once available, a write lock is obtained on this mutex and this function
+ terminates.
+ - else
+ - a write lock is obtained on this mutex and the calling thread is not put to sleep
+ !*/
+
+ void unlock (
+ ) const;
+ /*!
+ ensures
+ - if (there is a write lock on *this) then
+ - #*this is unlocked (i.e. other threads may now lock this object)
+ - else
+ - the call to unlock() has no effect
+ !*/
+
+ unsigned long max_readonly_locks (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of concurrent readonly locks this object will allow.
+ !*/
+
+ void lock_readonly (
+ ) const;
+ /*!
+ requires
+ - The thread calling this function does not already have a write
+ lock on this object
+ ensures
+ - if (there is a write lock on *this or there are no free readonly locks) then
+ - the calling thread is put to sleep until there is no longer a write lock
+ and a free readonly lock is available. Once this is the case, a readonly
+ lock is obtained and this function terminates.
+ - else
+ - a readonly lock is obtained on *this and the calling thread is not put
+ to sleep. Note that multiple readonly locks can be obtained at once.
+ !*/
+
+ void unlock_readonly (
+ ) const;
+ /*!
+ ensures
+ - if (there is a readonly lock on *this) then
+ - one readonly lock is removed from *this.
+ - else
+ - the call to unlock_readonly() has no effect.
+ !*/
+
+ private:
+ // restricted functions
+ read_write_mutex(read_write_mutex&); // copy constructor
+ read_write_mutex& operator=(read_write_mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/threads/rmutex_extension.h b/ml/dlib/dlib/threads/rmutex_extension.h
new file mode 100644
index 000000000..b7bf998be
--- /dev/null
+++ b/ml/dlib/dlib/threads/rmutex_extension.h
@@ -0,0 +1,109 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RMUTEX_EXTENSIOn_
+#define DLIB_RMUTEX_EXTENSIOn_
+
+#include "threads_kernel.h"
+#include "rmutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rmutex
+ {
+ /*!
+ INITIAL VALUE
+ count == 0
+ thread_id == 0
+
+ CONVENTION
+ - count == lock_count()
+
+ - if (no thread currently has a lock on this mutex) then
+ - count == 0
+ - else
+ - count == the number of times the thread that owns this mutex has
+ called lock()
+ - thread_id == the id of this thread.
+ !*/
+ public:
+
+ rmutex (
+ ) : s(m),
+ thread_id(0),
+ count(0)
+ {}
+
+ ~rmutex (
+ )
+ {}
+
+ unsigned long lock_count (
+ ) const
+ {
+ return count;
+ }
+
+ void lock (
+ unsigned long times = 1
+ ) const
+ {
+ const thread_id_type current_thread_id = get_thread_id();
+ m.lock();
+ if (thread_id == current_thread_id)
+ {
+ // we already own this mutex in this case
+ count += times;
+ }
+ else
+ {
+ // wait for our turn to claim this rmutex
+ while (count != 0)
+ s.wait();
+
+ count = times;
+ thread_id = current_thread_id;
+ }
+ m.unlock();
+ }
+
+ void unlock (
+ unsigned long times = 1
+ ) const
+ {
+ const thread_id_type current_thread_id = get_thread_id();
+ m.lock();
+ if (thread_id == current_thread_id)
+ {
+ if (count <= times)
+ {
+ count = 0;
+ s.signal();
+ }
+ else
+ {
+ count -= times;
+ }
+ }
+ m.unlock();
+ }
+
+ private:
+ mutex m;
+ signaler s;
+ mutable thread_id_type thread_id;
+ mutable unsigned long count;
+
+ // restricted functions
+ rmutex(rmutex&); // copy constructor
+ rmutex& operator=(rmutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RMUTEX_EXTENSIOn_
+
diff --git a/ml/dlib/dlib/threads/rmutex_extension_abstract.h b/ml/dlib/dlib/threads/rmutex_extension_abstract.h
new file mode 100644
index 000000000..144dbf4d7
--- /dev/null
+++ b/ml/dlib/dlib/threads/rmutex_extension_abstract.h
@@ -0,0 +1,107 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RMUTEX_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_RMUTEX_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rmutex
+ {
+ /*!
+ INITIAL VALUE
+ rmutex is in the unlocked state
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a recursive mutex intended to be used for synchronous
+ thread control of shared data. When a thread wants to access some
+ shared data it locks out other threads by calling lock() and calls
+ unlock() when it is finished.
+
+ The difference between this and the normal mutex object is that it is safe to
+ call lock() from a thread that already has a lock on this mutex. Doing
+ so just increments a counter but otherwise has no effect on the mutex.
+ Note that unlock() must be called for each call to lock() to release the
+ mutex.
+ !*/
+ public:
+
+ rmutex (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the rmutex.
+ !*/
+
+ ~rmutex (
+ );
+ /*!
+ requires
+ - *this is not locked
+ ensures
+ - all resources allocated by *this have been freed
+ !*/
+
+ unsigned long lock_count (
+ ) const;
+ /*!
+ requires
+ - the calling thread has a lock on this mutex
+ ensures
+ - returns the number of times the thread has called lock()
+ !*/
+
+ void lock (
+ unsigned long times = 1
+ ) const;
+ /*!
+ ensures
+ - if (*this is currently locked by another thread) then
+ - the thread that called lock() on *this is put to sleep until
+ it becomes available.
+ - #lock_count() == times
+ - if (*this is currently unlocked) then
+ - #*this becomes locked and the current thread is NOT put to sleep
+ but now "owns" #*this
+ - #lock_count() == times
+ - if (*this is locked and owned by the current thread) then
+ - the calling thread retains its lock on *this and isn't put to sleep.
+ - #lock_count() == lock_count() + times
+ !*/
+
+ void unlock (
+ unsigned long times = 1
+ ) const;
+ /*!
+ ensures
+ - if (*this is currently locked and owned by the thread calling unlock) then
+ - if (lock_count() <= times ) then
+ - #*this is unlocked (i.e. other threads may now lock this object)
+ - else
+ - #*this will remain locked
+ - #lock_count() == lock_count() - times
+ - else
+ - the call to unlock() has no effect
+ !*/
+
+
+ private:
+ // restricted functions
+ rmutex(rmutex&); // copy constructor
+ rmutex& operator=(rmutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RMUTEX_EXTENSIOn_ABSTRACT_
+
diff --git a/ml/dlib/dlib/threads/rsignaler_extension.h b/ml/dlib/dlib/threads/rsignaler_extension.h
new file mode 100644
index 000000000..bfb5a7ecb
--- /dev/null
+++ b/ml/dlib/dlib/threads/rsignaler_extension.h
@@ -0,0 +1,90 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RSIGNALER_EXTENSIOn_
+#define DLIB_RSIGNALER_EXTENSIOn_
+
+#include "rsignaler_extension_abstract.h"
+#include "rmutex_extension.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rsignaler
+ {
+ public:
+ rsignaler (
+ const rmutex& associated_mutex
+ ) :
+ assoc_mutex(associated_mutex),
+ s(m)
+ {}
+
+ ~rsignaler (
+ )
+ {}
+
+ void wait (
+ ) const
+ {
+ m.lock();
+ const unsigned long lock_count = assoc_mutex.lock_count();
+ assoc_mutex.unlock(lock_count);
+ s.wait();
+ m.unlock();
+ assoc_mutex.lock(lock_count);
+ }
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const
+ {
+ m.lock();
+ const unsigned long lock_count = assoc_mutex.lock_count();
+ assoc_mutex.unlock(lock_count);
+ bool res = s.wait_or_timeout(milliseconds);
+ m.unlock();
+ assoc_mutex.lock(lock_count);
+ return res;
+ }
+
+ void signal (
+ ) const
+ {
+ m.lock();
+ s.signal();
+ m.unlock();
+ }
+
+ void broadcast (
+ ) const
+ {
+ m.lock();
+ s.broadcast();
+ m.unlock();
+ }
+
+ const rmutex& get_mutex (
+ ) const { return assoc_mutex; }
+
+ private:
+
+ const rmutex& assoc_mutex;
+ mutex m;
+ signaler s;
+
+
+ // restricted functions
+ rsignaler(rsignaler&); // copy constructor
+ rsignaler& operator=(rsignaler&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RSIGNALER_EXTENSIOn_
+
+
+
diff --git a/ml/dlib/dlib/threads/rsignaler_extension_abstract.h b/ml/dlib/dlib/threads/rsignaler_extension_abstract.h
new file mode 100644
index 000000000..ae5f450d7
--- /dev/null
+++ b/ml/dlib/dlib/threads/rsignaler_extension_abstract.h
@@ -0,0 +1,123 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+#include "rmutex_extension_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class rsignaler
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an event signaling system for threads. It gives
+ a thread the ability to wake up other threads that are waiting for a
+ particular signal.
+
+ Each rsignaler object is associated with one and only one rmutex object.
+ More than one rsignaler object may be associated with a single rmutex
+ but a signaler object may only be associated with a single rmutex.
+
+ NOTE:
+ You must guard against spurious wakeups. This means that a thread
+ might return from a call to wait even if no other thread called
+ signal. This is rare but must be guarded against.
+
+ Also note that this object is identical to the signaler object
+ except that it works with rmutex objects rather than mutex objects.
+ !*/
+
+ public:
+
+ rsignaler (
+ const rmutex& associated_mutex
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #get_mutex() == associated_mutex
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the signaler.
+ !*/
+
+
+ ~rsignaler (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ !*/
+
+ void wait (
+ ) const;
+ /*!
+ requires
+ - get_mutex() is locked and owned by the calling thread
+ ensures
+ - atomically unlocks get_mutex() and blocks the calling thread
+ - calling thread may wake if another thread calls signal() or broadcast()
+ on *this
+ - when wait() returns the calling thread again has a lock on get_mutex()
+ !*/
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const;
+ /*!
+ requires
+ - get_mutex() is locked and owned by the calling thread
+ ensures
+ - atomically unlocks get_mutex() and blocks the calling thread
+ - calling thread may wake if another thread calls signal() or broadcast()
+ on *this
+ - after the specified number of milliseconds has elapsed the calling thread
+ will wake once get_mutex() is free
+ - when wait returns the calling thread again has a lock on get_mutex()
+
+ - returns false if the call to wait_or_timeout timed out
+ - returns true if the call did not time out
+ !*/
+
+ void signal (
+ ) const;
+ /*!
+ ensures
+ - if (at least one thread is waiting on *this) then
+ - at least one of the waiting threads will wake
+ !*/
+
+ void broadcast (
+ ) const;
+ /*!
+ ensures
+ - any and all threads waiting on *this will wake
+ !*/
+
+ const rmutex& get_mutex (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the rmutex associated with *this
+ !*/
+
+
+ private:
+ // restricted functions
+ rsignaler(rsignaler&); // copy constructor
+ rsignaler& operator=(rsignaler&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/threads/thread_function_extension.h b/ml/dlib/dlib/threads/thread_function_extension.h
new file mode 100644
index 000000000..7ecdd6520
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_function_extension.h
@@ -0,0 +1,215 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREAD_FUNCTIOn_
+#define DLIB_THREAD_FUNCTIOn_
+
+#include <memory>
+
+#include "thread_function_extension_abstract.h"
+#include "threads_kernel.h"
+#include "auto_mutex_extension.h"
+#include "threaded_object_extension.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_function : private threaded_object
+ {
+
+ class base_funct
+ {
+ public:
+ virtual void go() = 0;
+ virtual ~base_funct() {}
+ };
+
+ template <typename F, typename T1, typename T2, typename T3, typename T4>
+ class super_funct_4 : public base_funct
+ {
+ public:
+ super_funct_4 ( F funct, T1 arg1, T2 arg2, T3 arg3, T4 arg4) :
+ f(funct),
+ a1(arg1),
+ a2(arg2),
+ a3(arg3),
+ a4(arg4)
+ {
+ }
+
+ void go() { f(a1, a2, a3, a4); }
+
+
+ F f;
+ T1 a1;
+ T2 a2;
+ T3 a3;
+ T4 a4;
+ };
+
+ template <typename F, typename T1, typename T2, typename T3>
+ class super_funct_3 : public base_funct
+ {
+ public:
+ super_funct_3 ( F funct, T1 arg1, T2 arg2, T3 arg3):
+ f(funct),
+ a1(arg1),
+ a2(arg2),
+ a3(arg3)
+ {
+ }
+
+ void go() { f(a1, a2, a3); }
+
+
+ F f;
+ T1 a1;
+ T2 a2;
+ T3 a3;
+ };
+
+ template <typename F, typename T1, typename T2>
+ class super_funct_2 : public base_funct
+ {
+ public:
+ super_funct_2 ( F funct, T1 arg1, T2 arg2) :
+ f(funct),
+ a1(arg1),
+ a2(arg2)
+ {
+ }
+
+ void go() { f(a1, a2); }
+
+
+ F f;
+ T1 a1;
+ T2 a2;
+ };
+
+ template <typename F, typename T>
+ class super_funct_1 : public base_funct
+ {
+ public:
+ super_funct_1 ( F funct, T arg) : f(funct), a(arg)
+ {
+ }
+
+ void go() { f(a); }
+
+
+ F f;
+ T a;
+ };
+
+ template <typename F>
+ class super_funct_0 : public base_funct
+ {
+ public:
+ super_funct_0 ( F funct) : f(funct)
+ {
+ }
+
+ void go() { f(); }
+
+ F f;
+ };
+
+ public:
+
+ template <typename F>
+ thread_function (
+ F funct
+ )
+ {
+ f.reset(new super_funct_0<F>(funct));
+ start();
+ }
+
+ template <typename F, typename T>
+ thread_function (
+ F funct,
+ T arg
+ )
+ {
+ f.reset(new super_funct_1<F,T>(funct,arg));
+ start();
+ }
+
+ template <typename F, typename T1, typename T2>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2
+ )
+ {
+ f.reset(new super_funct_2<F,T1,T2>(funct, arg1, arg2));
+ start();
+ }
+
+ template <typename F, typename T1, typename T2, typename T3>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2,
+ T3 arg3
+ )
+ {
+ f.reset(new super_funct_3<F,T1,T2,T3>(funct, arg1, arg2, arg3));
+ start();
+ }
+
+ template <typename F, typename T1, typename T2, typename T3, typename T4>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2,
+ T3 arg3,
+ T4 arg4
+ )
+ {
+ f.reset(new super_funct_4<F,T1,T2,T3,T4>(funct, arg1, arg2, arg3, arg4));
+ start();
+ }
+
+ ~thread_function (
+ )
+ {
+ threaded_object::wait();
+ }
+
+ bool is_alive (
+ ) const
+ {
+ return threaded_object::is_alive();
+ }
+
+ void wait (
+ ) const
+ {
+ threaded_object::wait();
+ }
+
+ private:
+
+ void thread ()
+ {
+ f->go();
+ }
+
+ std::unique_ptr<base_funct> f;
+
+ // restricted functions
+ thread_function(thread_function&); // copy constructor
+ thread_function& operator=(thread_function&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREAD_FUNCTIOn_
+
+
+
diff --git a/ml/dlib/dlib/threads/thread_function_extension_abstract.h b/ml/dlib/dlib/threads/thread_function_extension_abstract.h
new file mode 100644
index 000000000..65ea998ac
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_function_extension_abstract.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THREAD_FUNCTIOn_ABSTRACT_
+#ifdef DLIB_THREAD_FUNCTIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_function
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a thread on a global C++ function or function
+ object. That is, it allows you to run a function in its own thread.
+ !*/
+ public:
+
+ template <typename F>
+ thread_function (
+ F funct
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - the function funct has been started in its own thread
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ template <typename F, typename T1>
+ thread_function (
+ F funct,
+ T1 arg1
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - A thread has been created and it will call funct(arg1)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ template <typename F, typename T1, typename T2>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - A thread has been created and it will call funct(arg1, arg2)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ template <typename F, typename T1, typename T2, typename T3>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2,
+ T3 arg3
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - A thread has been created and it will call funct(arg1, arg2, arg3)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ template <typename F, typename T1, typename T2, typename T3, typename T4>
+ thread_function (
+ F funct,
+ T1 arg1,
+ T2 arg2,
+ T3 arg3,
+ T4 arg4
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - A thread has been created and it will call funct(arg1, arg2, arg3, arg4)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ ~thread_function (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed.
+ - blocks until is_alive() == false
+ !*/
+
+ bool is_alive (
+ ) const;
+ /*!
+ ensures
+ - if (this object's thread has yet to terminate) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void wait (
+ ) const;
+ /*!
+ ensures
+ - if (is_alive() == true) then
+ - blocks until this object's thread terminates
+ !*/
+
+ private:
+
+ // restricted functions
+ thread_function(thread_function&); // copy constructor
+ thread_function& operator=(thread_function&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREAD_FUNCTIOn_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/threads/thread_pool_extension.cpp b/ml/dlib/dlib/threads/thread_pool_extension.cpp
new file mode 100644
index 000000000..00d99b910
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_pool_extension.cpp
@@ -0,0 +1,347 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREAD_POOl_CPPh_
+#define DLIB_THREAD_POOl_CPPh_
+
+#include "thread_pool_extension.h"
+#include <memory>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool_implementation::
+ thread_pool_implementation (
+ unsigned long num_threads
+ ) :
+ task_done_signaler(m),
+ task_ready_signaler(m),
+ we_are_destructing(false)
+ {
+ tasks.resize(num_threads);
+ threads.resize(num_threads);
+ for (unsigned long i = 0; i < num_threads; ++i)
+ {
+ threads[i] = std::thread([&](){this->thread();});
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ shutdown_pool (
+ )
+ {
+ {
+ auto_mutex M(m);
+
+ // first wait for all pending tasks to finish
+ bool found_task = true;
+ while (found_task)
+ {
+ found_task = false;
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ // If task bucket i has a task that is currently supposed to be processed
+ if (tasks[i].is_empty() == false)
+ {
+ found_task = true;
+ break;
+ }
+ }
+
+ if (found_task)
+ task_done_signaler.wait();
+ }
+
+ // now tell the threads to kill themselves
+ we_are_destructing = true;
+ task_ready_signaler.broadcast();
+ }
+
+ // wait for all threads to terminate
+ for (auto& t : threads)
+ t.join();
+ threads.clear();
+
+ // Throw any unhandled exceptions. Since shutdown_pool() is only called in the
+ // destructor this will kill the program.
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ thread_pool_implementation::
+ ~thread_pool_implementation()
+ {
+ shutdown_pool();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long thread_pool_implementation::
+ num_threads_in_pool (
+ ) const
+ {
+ auto_mutex M(m);
+ return tasks.size();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ wait_for_task (
+ uint64 task_id
+ ) const
+ {
+ auto_mutex M(m);
+ if (tasks.size() != 0)
+ {
+ const unsigned long idx = task_id_to_index(task_id);
+ while (tasks[idx].task_id == task_id)
+ task_done_signaler.wait();
+
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ wait_for_all_tasks (
+ ) const
+ {
+ const thread_id_type thread_id = get_thread_id();
+
+ auto_mutex M(m);
+ bool found_task = true;
+ while (found_task)
+ {
+ found_task = false;
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ // If task bucket i has a task that is currently supposed to be processed
+ // and it originated from the calling thread
+ if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id)
+ {
+ found_task = true;
+ break;
+ }
+ }
+
+ if (found_task)
+ task_done_signaler.wait();
+ }
+
+ // throw any exceptions generated by the tasks
+ for (auto&& task : tasks)
+ task.propagate_exception();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool thread_pool_implementation::
+ is_worker_thread (
+ const thread_id_type id
+ ) const
+ {
+ for (unsigned long i = 0; i < worker_thread_ids.size(); ++i)
+ {
+ if (worker_thread_ids[i] == id)
+ return true;
+ }
+
+ // if there aren't any threads in the pool then we consider all threads
+ // to be worker threads
+ if (tasks.size() == 0)
+ return true;
+ else
+ return false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void thread_pool_implementation::
+ thread (
+ )
+ {
+ {
+ // save the id of this worker thread into worker_thread_ids
+ auto_mutex M(m);
+ thread_id_type id = get_thread_id();
+ worker_thread_ids.push_back(id);
+ }
+
+ task_state_type task;
+ while (we_are_destructing == false)
+ {
+ long idx = 0;
+
+ // wait for a task to do
+ { auto_mutex M(m);
+ while ( (idx = find_ready_task()) == -1 && we_are_destructing == false)
+ task_ready_signaler.wait();
+
+ if (we_are_destructing)
+ break;
+
+ tasks[idx].is_being_processed = true;
+ task = tasks[idx];
+ }
+
+ std::exception_ptr eptr = nullptr;
+ try
+ {
+ // now do the task
+ if (task.bfp)
+ task.bfp();
+ else if (task.mfp0)
+ task.mfp0();
+ else if (task.mfp1)
+ task.mfp1(task.arg1);
+ else if (task.mfp2)
+ task.mfp2(task.arg1, task.arg2);
+ }
+ catch(...)
+ {
+ eptr = std::current_exception();
+ }
+
+ // Now let others know that we finished the task. We do this
+ // by clearing out the state of this task
+ { auto_mutex M(m);
+ tasks[idx].is_being_processed = false;
+ tasks[idx].task_id = 0;
+ tasks[idx].bfp.clear();
+ tasks[idx].mfp0.clear();
+ tasks[idx].mfp1.clear();
+ tasks[idx].mfp2.clear();
+ tasks[idx].arg1 = 0;
+ tasks[idx].arg2 = 0;
+ tasks[idx].eptr = eptr;
+ task_done_signaler.broadcast();
+ }
+
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long thread_pool_implementation::
+ find_empty_task_slot (
+ ) const
+ {
+ for (auto&& task : tasks)
+ task.propagate_exception();
+
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ if (tasks[i].is_empty())
+ return i;
+ }
+
+ return -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ long thread_pool_implementation::
+ find_ready_task (
+ ) const
+ {
+ for (unsigned long i = 0; i < tasks.size(); ++i)
+ {
+ if (tasks[i].is_ready())
+ return i;
+ }
+
+ return -1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 thread_pool_implementation::
+ make_next_task_id (
+ long idx
+ )
+ {
+ uint64 id = tasks[idx].next_task_id * tasks.size() + idx;
+ tasks[idx].next_task_id += 1;
+ return id;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ unsigned long thread_pool_implementation::
+ task_id_to_index (
+ uint64 id
+ ) const
+ {
+ return static_cast<unsigned long>(id%tasks.size());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ uint64 thread_pool_implementation::
+ add_task_internal (
+ const bfp_type& bfp,
+ std::shared_ptr<function_object_copy>& item
+ )
+ {
+ auto_mutex M(m);
+ const thread_id_type my_thread_id = get_thread_id();
+
+ // find a thread that isn't doing anything
+ long idx = find_empty_task_slot();
+ if (idx == -1 && is_worker_thread(my_thread_id))
+ {
+ // this function is being called from within a worker thread and there
+ // aren't any other worker threads free so just perform the task right
+ // here
+
+ M.unlock();
+ bfp();
+
+ // return a task id that is both non-zero and also one
+ // that is never normally returned. This way calls
+ // to wait_for_task() will never block given this id.
+ return 1;
+ }
+
+ // wait until there is a thread that isn't doing anything
+ while (idx == -1)
+ {
+ task_done_signaler.wait();
+ idx = find_empty_task_slot();
+ }
+
+ tasks[idx].thread_id = my_thread_id;
+ tasks[idx].task_id = make_next_task_id(idx);
+ tasks[idx].bfp = bfp;
+ tasks[idx].function_copy.swap(item);
+
+ task_ready_signaler.signal();
+
+ return tasks[idx].task_id;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool thread_pool_implementation::
+ is_task_thread (
+ ) const
+ {
+ auto_mutex M(m);
+ return is_worker_thread(get_thread_id());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+#endif // DLIB_THREAD_POOl_CPPh_
+
diff --git a/ml/dlib/dlib/threads/thread_pool_extension.h b/ml/dlib/dlib/threads/thread_pool_extension.h
new file mode 100644
index 000000000..bc2e1782c
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_pool_extension.h
@@ -0,0 +1,1392 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREAD_POOl_Hh_
+#define DLIB_THREAD_POOl_Hh_
+
+#include <exception>
+#include <memory>
+#include <thread>
+
+#include "thread_pool_extension_abstract.h"
+#include "multithreaded_object_extension.h"
+#include "../member_function_pointer.h"
+#include "../bound_function_pointer.h"
+#include "threads_kernel.h"
+#include "auto_mutex_extension.h"
+#include "../uintn.h"
+#include "../array.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_pool_implementation;
+
+ template <
+ typename T
+ >
+ class future
+ {
+ /*!
+ INITIAL VALUE
+ - task_id == 0
+ - tp.get() == 0
+
+ CONVENTION
+ - is_ready() == (tp.get() == 0)
+ - get() == var
+
+ - if (tp.get() != 0)
+ - tp == a pointer to the thread_pool_implementation that is using this future object
+ - task_id == the task id of the task in the thread pool tp that is using
+ this future object.
+ !*/
+ public:
+
+ future (
+ ) : task_id(0) {}
+
+ future (
+ const T& item
+ ) : task_id(0), var(item) {}
+
+ future (
+ const future& item
+ ) :task_id(0), var(item.get()) {}
+
+ ~future (
+ ) { wait(); }
+
+ future& operator=(
+ const T& item
+ ) { get() = item; return *this; }
+
+ future& operator=(
+ const future& item
+ ) { get() = item.get(); return *this; }
+
+ operator T& (
+ ) { return get(); }
+
+ operator const T& (
+ ) const { return get(); }
+
+ T& get (
+ ) { wait(); return var; }
+
+ const T& get (
+ ) const { wait(); return var; }
+
+ bool is_ready (
+ ) const { return tp.get() == 0; }
+
+ private:
+
+ friend class thread_pool;
+
+ inline void wait () const;
+
+ mutable uint64 task_id;
+ mutable std::shared_ptr<thread_pool_implementation> tp;
+
+ T var;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ inline void swap (
+ future<T>& a,
+ future<T>& b
+ ) { dlib::exchange(a.get(), b.get()); }
+ // Note that dlib::exchange() just calls std::swap. I'm only using it because
+ // this works around some bugs in certain compilers.
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T> bool operator== (const future<T>& a, const future<T>& b) { return a.get() == b.get(); }
+ template <typename T> bool operator!= (const future<T>& a, const future<T>& b) { return a.get() != b.get(); }
+ template <typename T> bool operator<= (const future<T>& a, const future<T>& b) { return a.get() <= b.get(); }
+ template <typename T> bool operator>= (const future<T>& a, const future<T>& b) { return a.get() >= b.get(); }
+ template <typename T> bool operator< (const future<T>& a, const future<T>& b) { return a.get() < b.get(); }
+ template <typename T> bool operator> (const future<T>& a, const future<T>& b) { return a.get() > b.get(); }
+
+ template <typename T> bool operator== (const future<T>& a, const T& b) { return a.get() == b; }
+ template <typename T> bool operator== (const T& a, const future<T>& b) { return a == b.get(); }
+ template <typename T> bool operator!= (const future<T>& a, const T& b) { return a.get() != b; }
+ template <typename T> bool operator!= (const T& a, const future<T>& b) { return a != b.get(); }
+ template <typename T> bool operator<= (const future<T>& a, const T& b) { return a.get() <= b; }
+ template <typename T> bool operator<= (const T& a, const future<T>& b) { return a <= b.get(); }
+ template <typename T> bool operator>= (const future<T>& a, const T& b) { return a.get() >= b; }
+ template <typename T> bool operator>= (const T& a, const future<T>& b) { return a >= b.get(); }
+ template <typename T> bool operator< (const future<T>& a, const T& b) { return a.get() < b; }
+ template <typename T> bool operator< (const T& a, const future<T>& b) { return a < b.get(); }
+ template <typename T> bool operator> (const future<T>& a, const T& b) { return a.get() > b; }
+ template <typename T> bool operator> (const T& a, const future<T>& b) { return a > b.get(); }
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_pool_implementation
+ {
+ /*!
+ CONVENTION
+ - num_threads_in_pool() == tasks.size()
+ - if (the destructor has been called) then
+ - we_are_destructing == true
+ - else
+ - we_are_destructing == false
+
+ - is_task_thread() == is_worker_thread(get_thread_id())
+
+ - m == the mutex used to protect everything in this object
+ - worker_thread_ids == an array that contains the thread ids for
+ all the threads in the thread pool
+ !*/
+ typedef bound_function_pointer::kernel_1a_c bfp_type;
+
+ friend class thread_pool;
+ explicit thread_pool_implementation (
+ unsigned long num_threads
+ );
+
+ public:
+ ~thread_pool_implementation(
+ );
+
+ void wait_for_task (
+ uint64 task_id
+ ) const;
+
+ unsigned long num_threads_in_pool (
+ ) const;
+
+ void wait_for_all_tasks (
+ ) const;
+
+ bool is_task_thread (
+ ) const;
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)()
+ )
+ {
+ auto_mutex M(m);
+ const thread_id_type my_thread_id = get_thread_id();
+
+ // find a thread that isn't doing anything
+ long idx = find_empty_task_slot();
+ if (idx == -1 && is_worker_thread(my_thread_id))
+ {
+ // this function is being called from within a worker thread and there
+ // aren't any other worker threads free so just perform the task right
+ // here
+
+ M.unlock();
+ (obj.*funct)();
+
+ // return a task id that is both non-zero and also one
+ // that is never normally returned. This way calls
+ // to wait_for_task() will never block given this id.
+ return 1;
+ }
+
+ // wait until there is a thread that isn't doing anything
+ while (idx == -1)
+ {
+ task_done_signaler.wait();
+ idx = find_empty_task_slot();
+ }
+
+ tasks[idx].thread_id = my_thread_id;
+ tasks[idx].task_id = make_next_task_id(idx);
+ tasks[idx].mfp0.set(obj,funct);
+
+ task_ready_signaler.signal();
+
+ return tasks[idx].task_id;
+ }
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long),
+ long arg1
+ )
+ {
+ auto_mutex M(m);
+ const thread_id_type my_thread_id = get_thread_id();
+
+ // find a thread that isn't doing anything
+ long idx = find_empty_task_slot();
+ if (idx == -1 && is_worker_thread(my_thread_id))
+ {
+ // this function is being called from within a worker thread and there
+ // aren't any other worker threads free so just perform the task right
+ // here
+
+ M.unlock();
+ (obj.*funct)(arg1);
+
+ // return a task id that is both non-zero and also one
+ // that is never normally returned. This way calls
+ // to wait_for_task() will never block given this id.
+ return 1;
+ }
+
+ // wait until there is a thread that isn't doing anything
+ while (idx == -1)
+ {
+ task_done_signaler.wait();
+ idx = find_empty_task_slot();
+ }
+
+ tasks[idx].thread_id = my_thread_id;
+ tasks[idx].task_id = make_next_task_id(idx);
+ tasks[idx].mfp1.set(obj,funct);
+ tasks[idx].arg1 = arg1;
+
+ task_ready_signaler.signal();
+
+ return tasks[idx].task_id;
+ }
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long,long),
+ long arg1,
+ long arg2
+ )
+ {
+ auto_mutex M(m);
+ const thread_id_type my_thread_id = get_thread_id();
+
+ // find a thread that isn't doing anything
+ long idx = find_empty_task_slot();
+ if (idx == -1 && is_worker_thread(my_thread_id))
+ {
+ // this function is being called from within a worker thread and there
+ // aren't any other worker threads free so just perform the task right
+ // here
+
+ M.unlock();
+ (obj.*funct)(arg1, arg2);
+
+ // return a task id that is both non-zero and also one
+ // that is never normally returned. This way calls
+ // to wait_for_task() will never block given this id.
+ return 1;
+ }
+
+ // wait until there is a thread that isn't doing anything
+ while (idx == -1)
+ {
+ task_done_signaler.wait();
+ idx = find_empty_task_slot();
+ }
+
+ tasks[idx].thread_id = my_thread_id;
+ tasks[idx].task_id = make_next_task_id(idx);
+ tasks[idx].mfp2.set(obj,funct);
+ tasks[idx].arg1 = arg1;
+ tasks[idx].arg2 = arg2;
+
+ task_ready_signaler.signal();
+
+ return tasks[idx].task_id;
+ }
+
+ struct function_object_copy
+ {
+ virtual ~function_object_copy(){}
+ };
+
+ template <typename T>
+ struct function_object_copy_instance : function_object_copy
+ {
+ function_object_copy_instance(const T& item_) : item(item_) {}
+ T item;
+ virtual ~function_object_copy_instance(){}
+ };
+
+ uint64 add_task_internal (
+ const bfp_type& bfp,
+ std::shared_ptr<function_object_copy>& item
+ );
+ /*!
+ ensures
+ - adds a task to call the given bfp object.
+ - swaps item into the internal task object which will have a lifetime
+ at least as long as the running task.
+ - returns the task id for this new task
+ !*/
+
+ uint64 add_task_internal (
+ const bfp_type& bfp
+ ) { std::shared_ptr<function_object_copy> temp; return add_task_internal(bfp, temp); }
+ /*!
+ ensures
+ - adds a task to call the given bfp object.
+ - returns the task id for this new task
+ !*/
+
+ void shutdown_pool (
+ );
+ /*!
+ ensures
+ - causes all threads to terminate and blocks the
+ caller until this happens.
+ !*/
+
+ private:
+
+ bool is_worker_thread (
+ const thread_id_type id
+ ) const;
+ /*!
+ requires
+ - m is locked
+ ensures
+ - if (thread with given id is one of the thread pool's worker threads or num_threads_in_pool() == 0) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void thread (
+ );
+ /*!
+ this is the function that executes the threads in the thread pool
+ !*/
+
+ long find_empty_task_slot (
+ ) const;
+ /*!
+ requires
+ - m is locked
+ ensures
+ - if (there is currently a empty task slot) then
+ - returns the index of that task slot in tasks
+ - there is a task slot
+ - else
+ - returns -1
+ !*/
+
+ long find_ready_task (
+ ) const;
+ /*!
+ requires
+ - m is locked
+ ensures
+ - if (there is currently a task to do) then
+ - returns the index of that task in tasks
+ - else
+ - returns -1
+ !*/
+
+ uint64 make_next_task_id (
+ long idx
+ );
+ /*!
+ requires
+ - m is locked
+ - 0 <= idx < tasks.size()
+ ensures
+ - returns the next index to be used for tasks that are placed in
+ tasks[idx]
+ !*/
+
+ unsigned long task_id_to_index (
+ uint64 id
+ ) const;
+ /*!
+ requires
+ - m is locked
+ - num_threads_in_pool() != 0
+ ensures
+ - returns the index in tasks corresponding to the given id
+ !*/
+
+ struct task_state_type
+ {
+ task_state_type() : is_being_processed(false), task_id(0), next_task_id(2), arg1(0), arg2(0), eptr(nullptr) {}
+
+ bool is_ready () const
+ /*!
+ ensures
+ - if (is_empty() == false && no thread is currently processing this task) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ {
+ return !is_being_processed && !is_empty();
+ }
+
+ bool is_empty () const
+ /*!
+ ensures
+ - if (this task state is empty. i.e. it doesn't contain a task to be processed) then
+ - returns true
+ - else
+ - returns false
+ !*/
+ {
+ return task_id == 0;
+ }
+
+ bool is_being_processed; // true when a thread is working on this task
+ uint64 task_id; // the id of this task. 0 means this task is empty
+ thread_id_type thread_id; // the id of the thread that requested this task
+
+ uint64 next_task_id;
+
+ long arg1;
+ long arg2;
+
+ member_function_pointer<> mfp0;
+ member_function_pointer<long> mfp1;
+ member_function_pointer<long,long> mfp2;
+ bfp_type bfp;
+
+ std::shared_ptr<function_object_copy> function_copy;
+ mutable std::exception_ptr eptr; // non-null if the task threw an exception
+
+ void propagate_exception() const
+ {
+ if (eptr)
+ {
+ auto tmp = eptr;
+ eptr = nullptr;
+ std::rethrow_exception(tmp);
+ }
+ }
+
+ };
+
+ array<task_state_type> tasks;
+ array<thread_id_type> worker_thread_ids;
+
+ mutex m;
+ signaler task_done_signaler;
+ signaler task_ready_signaler;
+ bool we_are_destructing;
+
+ std::vector<std::thread> threads;
+
+ // restricted functions
+ thread_pool_implementation(thread_pool_implementation&); // copy constructor
+ thread_pool_implementation& operator=(thread_pool_implementation&); // assignment operator
+
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_pool
+ {
+ /*!
+ This object is just a shell that holds a std::shared_ptr
+ to the real thread_pool_implementation object. The reason for doing
+ it this way is so that we can allow any mixture of destruction orders
+ between thread_pool objects and futures. Whoever gets destroyed
+ last cleans up the thread_pool_implementation resources.
+ !*/
+ typedef bound_function_pointer::kernel_1a_c bfp_type;
+
+ public:
+ explicit thread_pool (
+ unsigned long num_threads
+ )
+ {
+ impl.reset(new thread_pool_implementation(num_threads));
+ }
+
+ ~thread_pool (
+ )
+ {
+ try
+ {
+ impl->shutdown_pool();
+ }
+ catch (std::exception& e)
+ {
+ std::cerr << "An unhandled exception was inside a dlib::thread_pool when it was destructed." << std::endl;
+ std::cerr << "It's what string is: \n" << e.what() << std::endl;
+ using namespace std;
+ assert(false);
+ abort();
+ }
+ catch (...)
+ {
+ std::cerr << "An unhandled exception was inside a dlib::thread_pool when it was destructed." << std::endl;
+ using namespace std;
+ assert(false);
+ abort();
+ }
+ }
+
+ void wait_for_task (
+ uint64 task_id
+ ) const { impl->wait_for_task(task_id); }
+
+ unsigned long num_threads_in_pool (
+ ) const { return impl->num_threads_in_pool(); }
+
+ void wait_for_all_tasks (
+ ) const { impl->wait_for_all_tasks(); }
+
+ bool is_task_thread (
+ ) const { return impl->is_task_thread(); }
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)()
+ )
+ {
+ return impl->add_task(obj, funct);
+ }
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long),
+ long arg1
+ )
+ {
+ return impl->add_task(obj, funct, arg1);
+ }
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long,long),
+ long arg1,
+ long arg2
+ )
+ {
+ return impl->add_task(obj, funct, arg1, arg2);
+ }
+
+ // --------------------
+
+ template <typename F>
+ uint64 add_task (
+ F& function_object
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ bfp_type temp;
+ temp.set(function_object);
+ uint64 id = impl->add_task_internal(temp);
+
+ return id;
+ }
+
+ template <typename F>
+ uint64 add_task_by_value (
+ const F& function_object
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<F>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<F>(function_object);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+
+ bfp_type temp;
+ temp.set(ptr->item);
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ return id;
+ }
+
+ template <typename T>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)() const
+ )
+ {
+ bfp_type temp;
+ temp.set(obj,funct);
+ uint64 id = impl->add_task_internal(temp);
+
+ return id;
+ }
+
+ template <typename T>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)() const
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<const T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<const T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item,funct);
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ return id;
+ }
+
+ template <typename T>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)()
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item,funct);
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ return id;
+ }
+
+ uint64 add_task (
+ void (*funct)()
+ )
+ {
+ bfp_type temp;
+ temp.set(funct);
+ uint64 id = impl->add_task_internal(temp);
+
+ return id;
+ }
+
+ // --------------------
+
+ template <typename F, typename A1>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ bfp_type temp;
+ temp.set(function_object,arg1.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ template <typename F, typename A1>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<F>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<F>(function_object);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, arg1.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1),
+ future<A1>& arg1
+ )
+ {
+ bfp_type temp;
+ temp.set(obj,funct,arg1.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1),
+ future<A1>& arg1
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item,funct,arg1.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1) const,
+ future<A1>& arg1
+ )
+ {
+ bfp_type temp;
+ temp.set(obj,funct,arg1.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1) const,
+ future<A1>& arg1
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<const T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<const T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item,funct,arg1.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ template <typename T1, typename A1>
+ uint64 add_task (
+ void (*funct)(T1),
+ future<A1>& arg1
+ )
+ {
+ bfp_type temp;
+ temp.set(funct,arg1.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ return id;
+ }
+
+ // --------------------
+
+ template <typename F, typename A1, typename A2>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ bfp_type temp;
+ temp.set(function_object, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename F, typename A1, typename A2>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<F>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<F>(function_object);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2) const,
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2) const,
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<const T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<const T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ void (*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ )
+ {
+ bfp_type temp;
+ temp.set(funct, arg1.get(), arg2.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ return id;
+ }
+
+ // --------------------
+
+ template <typename F, typename A1, typename A2, typename A3>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ bfp_type temp;
+ temp.set(function_object, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename F, typename A1, typename A2, typename A3>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<F>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<F>(function_object);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<const T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<const T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ void (*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ )
+ {
+ bfp_type temp;
+ temp.set(funct, arg1.get(), arg2.get(), arg3.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ return id;
+ }
+
+ // --------------------
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ COMPILE_TIME_ASSERT(is_function<F>::value == false);
+ COMPILE_TIME_ASSERT(is_pointer_type<F>::value == false);
+
+ bfp_type temp;
+ temp.set(function_object, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<F>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<F>(function_object);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the future to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ bfp_type temp;
+ temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ thread_pool_implementation::function_object_copy_instance<const T>* ptr = 0;
+ ptr = new thread_pool_implementation::function_object_copy_instance<const T>(obj);
+ std::shared_ptr<thread_pool_implementation::function_object_copy> function_copy(ptr);
+
+ bfp_type temp;
+ temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp, function_copy);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ void (*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ )
+ {
+ bfp_type temp;
+ temp.set(funct, arg1.get(), arg2.get(), arg3.get(), arg4.get());
+ uint64 id = impl->add_task_internal(temp);
+
+ // tie the futures to this task
+ arg1.task_id = id;
+ arg1.tp = impl;
+ arg2.task_id = id;
+ arg2.tp = impl;
+ arg3.task_id = id;
+ arg3.tp = impl;
+ arg4.task_id = id;
+ arg4.tp = impl;
+ return id;
+ }
+
+ private:
+
+ std::shared_ptr<thread_pool_implementation> impl;
+
+ // restricted functions
+ thread_pool(thread_pool&); // copy constructor
+ thread_pool& operator=(thread_pool&); // assignment operator
+
+ };
+
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void future<T>::
+ wait (
+ ) const
+ {
+ if (tp)
+ {
+ tp->wait_for_task(task_id);
+ tp.reset();
+ task_id = 0;
+ }
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#ifdef NO_MAKEFILE
+#include "thread_pool_extension.cpp"
+#endif
+
+#endif // DLIB_THREAD_POOl_Hh_
+
+
diff --git a/ml/dlib/dlib/threads/thread_pool_extension_abstract.h b/ml/dlib/dlib/threads/thread_pool_extension_abstract.h
new file mode 100644
index 000000000..ba54a7546
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_pool_extension_abstract.h
@@ -0,0 +1,842 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THREAD_POOl_ABSTRACT_Hh_
+#ifdef DLIB_THREAD_POOl_ABSTRACT_Hh_
+
+#include "threads_kernel_abstract.h"
+#include "../uintn.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class future
+ {
+ /*!
+ INITIAL VALUE
+ - is_ready() == true
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a container that allows you to safely pass objects
+ into the tasks performed by the thread_pool object defined below. An
+ example will make it clear:
+
+ // Suppose you have a global function defined as follows
+ void add (int a, int b, int& result) { result = a + b; }
+
+ // Also suppose you have a thread_pool named tp defined somewhere.
+ // Then you could do the following.
+ future<int> a, b, result;
+ a = 3;
+ b = 4;
+ // this function call causes another thread to execute a call to the add() function
+ // and passes in the int objects contained in a, b, and result
+ tp.add_task(add,a,b,result);
+ // This line will wait for the task in the thread pool to finish and then print the
+ // value in the result integer. So it will print a 7.
+ cout << result << endl;
+ !*/
+
+ public:
+ future (
+ );
+ /*!
+ ensures
+ - The object of type T contained in this future has
+ an initial value for its type.
+ - #is_ready() == true
+ !*/
+
+ future (
+ const T& item
+ );
+ /*!
+ ensures
+ - #get() == item
+ - #is_ready() == true
+ !*/
+
+ future (
+ const future& item
+ );
+ /*!
+ ensures
+ - if (item.is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to the item future has finished.
+ - #is_ready() == true
+ - #item.is_ready() == true
+ - #get() == item.get()
+ !*/
+
+ ~future (
+ );
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ !*/
+
+ bool is_ready (
+ ) const;
+ /*!
+ ensures
+ - if (the value of this future may not yet be ready to be accessed because it
+ is in use by a task in a thread_pool) then
+ - returns false
+ - else
+ - returns true
+ !*/
+
+ future& operator=(
+ const T& item
+ );
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ - #is_ready() == true
+ - #get() == item
+ - returns *this
+ !*/
+
+ future& operator=(
+ const future& item
+ );
+ /*!
+ ensures
+ - if (is_ready() == false || item.is_ready() == false) then
+ - the call to this function blocks until the threads processing the tasks related
+ to this future and the item future have finished.
+ - #is_ready() == true
+ - #item.is_ready() == true
+ - #get() == item.get()
+ - returns *this
+ !*/
+
+ operator T& (
+ );
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ - #is_ready() == true
+ - returns get()
+ !*/
+
+ operator const T& (
+ ) const;
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ - #is_ready() == true
+ - returns get()
+ !*/
+
+ T& get (
+ );
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ - #is_ready() == true
+ - returns a non-const reference to the object of type T contained inside this future
+ !*/
+
+ const T& get (
+ ) const;
+ /*!
+ ensures
+ - if (is_ready() == false) then
+ - the call to this function blocks until the thread processing the task related
+ to this future has finished.
+ - #is_ready() == true
+ - returns a const reference to the object of type T contained inside this future
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ inline void swap (
+ future<T>& a,
+ future<T>& b
+ ) { std::swap(a.get(), b.get()); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+
+// The future object comes with overloads for all the usual comparison operators.
+
+ template <typename T> bool operator== (const future<T>& a, const future<T>& b) { return a.get() == b.get(); }
+ template <typename T> bool operator!= (const future<T>& a, const future<T>& b) { return a.get() != b.get(); }
+ template <typename T> bool operator<= (const future<T>& a, const future<T>& b) { return a.get() <= b.get(); }
+ template <typename T> bool operator>= (const future<T>& a, const future<T>& b) { return a.get() >= b.get(); }
+ template <typename T> bool operator< (const future<T>& a, const future<T>& b) { return a.get() < b.get(); }
+ template <typename T> bool operator> (const future<T>& a, const future<T>& b) { return a.get() > b.get(); }
+
+ template <typename T> bool operator== (const future<T>& a, const T& b) { return a.get() == b; }
+ template <typename T> bool operator== (const T& a, const future<T>& b) { return a == b.get(); }
+ template <typename T> bool operator!= (const future<T>& a, const T& b) { return a.get() != b; }
+ template <typename T> bool operator!= (const T& a, const future<T>& b) { return a != b.get(); }
+ template <typename T> bool operator<= (const future<T>& a, const T& b) { return a.get() <= b; }
+ template <typename T> bool operator<= (const T& a, const future<T>& b) { return a <= b.get(); }
+ template <typename T> bool operator>= (const future<T>& a, const T& b) { return a.get() >= b; }
+ template <typename T> bool operator>= (const T& a, const future<T>& b) { return a >= b.get(); }
+ template <typename T> bool operator< (const future<T>& a, const T& b) { return a.get() < b; }
+ template <typename T> bool operator< (const T& a, const future<T>& b) { return a < b.get(); }
+ template <typename T> bool operator> (const future<T>& a, const T& b) { return a.get() > b; }
+ template <typename T> bool operator> (const T& a, const future<T>& b) { return a > b.get(); }
+
+// ----------------------------------------------------------------------------------------
+
+ class thread_pool
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a fixed size group of threads which you can
+ submit tasks to and then wait for those tasks to be completed.
+
+ Note that setting the number of threads to 0 is a valid way to
+ use this object. It causes it to not contain any threads
+ at all. When tasks are submitted to the object in this mode
+ the tasks are processed within the calling thread. So in this
+ mode any thread that calls add_task() is considered to be
+ a thread_pool thread capable of executing tasks.
+
+ This object is also implemented such that no memory allocations occur
+ after the thread_pool has been constructed so long as the user doesn't
+ call any of the add_task_by_value() routines. The future object also
+ doesn't perform any memory allocations or contain any system resources
+ such as mutex objects.
+
+ EXCEPTIONS
+ Note that if an exception is thrown inside a task thread and is not caught
+ then the exception will be trapped inside the thread pool and rethrown at a
+ later time when someone calls one of the add task or wait member functions
+ of the thread pool. This allows exceptions to propagate out of task threads
+ and into the calling code where they can be handled.
+ !*/
+
+ public:
+ explicit thread_pool (
+ unsigned long num_threads
+ );
+ /*!
+ ensures
+ - #num_threads_in_pool() == num_threads
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ ~thread_pool(
+ );
+ /*!
+ ensures
+ - blocks until all tasks in the pool have finished.
+ - If one of the threads has generated an exception but it hasn't yet been
+ rethrown to the caller (e.g. by calling wait_for_all_tasks()) then the
+ program will be terminated. So make sure you handle all the possible
+ exceptions from your tasks.
+ !*/
+
+ bool is_task_thread (
+ ) const;
+ /*!
+ ensures
+ - if (the thread calling this function is one of the threads in this
+ thread pool or num_threads_in_pool() == 0) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unsigned long num_threads_in_pool (
+ ) const;
+ /*!
+ ensures
+ - returns the number of threads contained in this thread pool. That is, returns
+ the maximum number of tasks that this object will process concurrently.
+ !*/
+
+ template <typename F>
+ uint64 add_task_by_value (
+ const F& function_object
+ );
+ /*!
+ requires
+ - function_object() is a valid expression
+ ensures
+ - makes a copy of function_object, call it FCOPY.
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls FCOPY() within the calling thread and returns when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls FCOPY().
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)()
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - obj will not go out of scope until after the task has completed (i.e.
+ this function passes obj to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (obj.*funct)() within the calling thread and returns
+ when it finishes.
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (obj.*funct)()
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)()
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ ensures
+ - makes a copy of obj, call it OBJ_COPY.
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (OBJ_COPY.*funct)() within the calling thread and returns
+ when it finishes.
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (OBJ_COPY.*funct)().
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long),
+ long arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - obj will not go out of scope until after the task has completed (i.e.
+ this function passes obj to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (obj.*funct)(arg1) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (obj.*funct)(arg1)
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(long,long),
+ long arg1,
+ long arg2
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - obj will not go out of scope until after the task has completed (i.e.
+ this function passes obj to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (obj.*funct)(arg1,arg2) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (obj.*funct)(arg1,arg2)
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ void wait_for_task (
+ uint64 task_id
+ ) const;
+ /*!
+ ensures
+ - if (there is currently a task with the given id being executed in the thread pool) then
+ - the call to this function blocks until the task with the given id is complete
+ - else
+ - the call to this function returns immediately
+ !*/
+
+ void wait_for_all_tasks (
+ ) const;
+ /*!
+ ensures
+ - the call to this function blocks until all tasks which were submitted
+ to the thread pool by the thread that is calling this function have
+ finished.
+ !*/
+
+ // --------------------
+
+ template <typename F, typename A1>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - function_object(arg1.get()) is a valid expression
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function object)
+ - function_object will not go out of scope until after the task has completed (i.e.
+ this function passes function_object to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls function_object(arg1.get()) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls function_object(arg1.get()).
+ - #arg1.is_ready() == false
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename F, typename A1>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - function_object(arg1.get()) is a valid expression
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function object)
+ ensures
+ - makes a copy of function_object, call it FCOPY.
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls FCOPY(arg1.get()) within the calling thread and returns when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls FCOPY(arg1.get()).
+ - #arg1.is_ready() == false
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1),
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - (obj.*funct)(arg1.get()) must be a valid expression.
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function)
+ - obj will not go out of scope until after the task has completed (i.e.
+ this function passes obj to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (obj.*funct)(arg1.get()) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (obj.*funct)(arg1.get()).
+ - #arg1.is_ready() == false
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1),
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - (obj.*funct)(arg1.get()) must be a valid expression.
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function)
+ ensures
+ - makes a copy of obj, call it OBJ_COPY.
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (OBJ_COPY.*funct)(arg1.get()) within the calling thread and returns
+ when it finishes.
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (OBJ_COPY.*funct)(arg1.get()).
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1) const,
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - (obj.*funct)(arg1.get()) must be a valid expression.
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function)
+ - obj will not go out of scope until after the task has completed (i.e.
+ this function passes obj to the task by reference. If you want to avoid
+ this restriction then use add_task_by_value())
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (obj.*funct)(arg1.get()) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (obj.*funct)(arg1.get()).
+ - #arg1.is_ready() == false
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T, typename T1, typename A1>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1) const,
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - funct == a valid member function pointer for class T
+ - (obj.*funct)(arg1.get()) must be a valid expression.
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function)
+ ensures
+ - makes a copy of obj, call it OBJ_COPY.
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls (OBJ_COPY.*funct)(arg1.get()) within the calling thread and returns
+ when it finishes.
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls (OBJ_COPY.*funct)(arg1.get()).
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ template <typename T1, typename A1>
+ uint64 add_task (
+ void (*funct)(T1),
+ future<A1>& arg1
+ );
+ /*!
+ requires
+ - funct == a valid function pointer
+ - (funct)(arg1.get()) must be a valid expression.
+ (i.e. The A1 type stored in the future must be a type that can be passed into the given function)
+ ensures
+ - if (is_task_thread() == true and there aren't any free threads available) then
+ - calls funct(arg1.get()) within the calling thread and returns
+ when it finishes
+ - else
+ - the call to this function blocks until there is a free thread in the pool
+ to process this new task. Once a free thread is available the task
+ is handed off to that thread which then calls funct(arg1.get()).
+ - #arg1.is_ready() == false
+ - returns a task id that can be used by this->wait_for_task() to wait
+ for the submitted task to finish.
+ !*/
+
+ // --------------------------------------------------------------------------------
+ // The remainder of this class just contains overloads for add_task() and add_task_by_value()
+ // that take up to 4 futures (as well as 0 futures). Their behavior is identical to the above
+ // add_task() and add_task_by_value() functions.
+ // --------------------------------------------------------------------------------
+
+ template <typename F, typename A1, typename A2>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ template <typename F, typename A1, typename A2>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2) const,
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2) const,
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ template <typename T1, typename A1,
+ typename T2, typename A2>
+ uint64 add_task (
+ void (*funct)(T1,T2),
+ future<A1>& arg1,
+ future<A2>& arg2
+ );
+
+ // --------------------
+
+ template <typename F, typename A1, typename A2, typename A3>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename F, typename A1, typename A2, typename A3>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3>
+ uint64 add_task (
+ void (*funct)(T1,T2,T3),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3
+ );
+
+ // --------------------
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ uint64 add_task (
+ F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename F, typename A1, typename A2, typename A3, typename A4>
+ uint64 add_task_by_value (
+ const F& function_object,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ T& obj,
+ void (T::*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename T, typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)(T1,T2,T3,T4) const,
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ template <typename T1, typename A1,
+ typename T2, typename A2,
+ typename T3, typename A3,
+ typename T4, typename A4>
+ uint64 add_task (
+ void (*funct)(T1,T2,T3,T4),
+ future<A1>& arg1,
+ future<A2>& arg2,
+ future<A3>& arg3,
+ future<A4>& arg4
+ );
+
+ // --------------------
+
+ template <typename F>
+ uint64 add_task (
+ F& function_object
+ );
+
+ template <typename T>
+ uint64 add_task (
+ const T& obj,
+ void (T::*funct)() const,
+ );
+
+ template <typename T>
+ uint64 add_task_by_value (
+ const T& obj,
+ void (T::*funct)() const
+ );
+
+ uint64 add_task (
+ void (*funct)()
+ );
+
+ // --------------------
+
+ private:
+
+ // restricted functions
+ thread_pool(thread_pool&); // copy constructor
+ thread_pool& operator=(thread_pool&); // assignment operator
+ };
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_THREAD_POOl_ABSTRACT_Hh_
+
+
+
diff --git a/ml/dlib/dlib/threads/thread_specific_data_extension.h b/ml/dlib/dlib/threads/thread_specific_data_extension.h
new file mode 100644
index 000000000..0b5339200
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_specific_data_extension.h
@@ -0,0 +1,141 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_
+#define DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_
+
+#include "thread_specific_data_extension_abstract.h"
+#include "threads_kernel_abstract.h"
+#include "../binary_search_tree.h"
+#include "auto_mutex_extension.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class thread_specific_data
+ {
+ /*!
+ CONVENTION
+ - for all valid ID:
+ (*items[ID]) == pointer to the data for thread with id ID
+ !*/
+ public:
+
+ thread_specific_data (
+ )
+ {
+ thread_end_handler_calls_left = 0;
+ }
+
+ ~thread_specific_data (
+ )
+ {
+ // We should only call the unregister_thread_end_handler function if there are
+ // some outstanding callbacks we expect to get. Otherwise lets avoid calling it
+ // since the dlib state that maintains the registered thread end handlers may have
+ // been destructed already (since the program might be in the process of terminating).
+ bool call_unregister = false;
+ m.lock();
+ if (thread_end_handler_calls_left > 0)
+ call_unregister = true;
+ m.unlock();
+
+ if (call_unregister)
+ unregister_thread_end_handler(const_cast<thread_specific_data&>(*this),&thread_specific_data::thread_end_handler);
+
+ auto_mutex M(m);
+ items.reset();
+ while (items.move_next())
+ {
+ delete items.element().value();
+ }
+ }
+
+ inline T& data (
+ ) { return get_data(); }
+
+ inline const T& data (
+ ) const { return get_data(); }
+
+ private:
+
+ T& get_data (
+ ) const
+ {
+ thread_id_type id = get_thread_id();
+ auto_mutex M(m);
+
+ T** item = items[id];
+ if (item)
+ {
+ return **item;
+ }
+ else
+ {
+ // register an end handler for this thread so long as it is a dlib created thread.
+ T* new_item = new T;
+
+ bool in_tree = false;
+ try
+ {
+ T* temp_item = new_item;
+ thread_id_type temp_id = id;
+ items.add(temp_id,temp_item);
+ in_tree = true;
+
+ if (is_dlib_thread(id))
+ {
+ register_thread_end_handler(const_cast<thread_specific_data&>(*this),&thread_specific_data::thread_end_handler);
+ ++thread_end_handler_calls_left;
+ }
+ }
+ catch (...)
+ {
+ if (in_tree)
+ {
+ items.destroy(id);
+ }
+ delete new_item;
+ throw;
+ }
+
+ return *new_item;
+ }
+ }
+
+ void thread_end_handler (
+ )
+ {
+ const thread_id_type id = get_thread_id();
+ thread_id_type junk = 0;
+ T* item = 0;
+ auto_mutex M(m);
+ --thread_end_handler_calls_left;
+ if (items[id])
+ {
+ items.remove(id,junk,item);
+ delete item;
+ }
+ }
+
+ mutable typename binary_search_tree<thread_id_type,T*>::kernel_2a items;
+ mutex m;
+ mutable long thread_end_handler_calls_left;
+
+ // restricted functions
+ thread_specific_data(thread_specific_data&); // copy constructor
+ thread_specific_data& operator=(thread_specific_data&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_
+
+
+
diff --git a/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h b/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h
new file mode 100644
index 000000000..03fb9ddaa
--- /dev/null
+++ b/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h
@@ -0,0 +1,87 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class thread_specific_data
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a container of thread specific data. When
+ a thread calls the data() member function it gets a reference to a T object
+ that is specific to its own thread. Each subsequent call to data() from that
+ thread returns the same instance. Also note that when a thread ends
+ the instance of its data() object gets destroyed and freed (if the thread
+ was created by the dlib library). So any pointers or references to the object
+ will be invalid after the thread has ended.
+ !*/
+ public:
+
+ thread_specific_data (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ !*/
+
+ ~thread_specific_data (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed. This includes
+ all the thread specific data returned by the data() functions.
+ !*/
+
+ T& data (
+ );
+ /*!
+ ensures
+ - if (the calling thread has NOT called this->data() before) then
+ - constructs an instance of T that is specific to the calling
+ thread.
+ - returns a reference to the T instance that was constructed for
+ the calling thread.
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ If an exception is thrown then the call to data() will have
+ no effect on *this.
+ !*/
+
+ const T& data (
+ ) const;
+ /*!
+ ensures
+ - if (the calling thread has NOT called this->data() before) then
+ - constructs an instance of T that is specific to the calling
+ thread.
+ - returns a const reference to the T instance that was constructed for
+ the calling thread.
+ throws
+ - std::bad_alloc or any exception thrown by T's constructor
+ If an exception is thrown then the call to data() will have
+ no effect on *this.
+ !*/
+
+ private:
+ // restricted functions
+ thread_specific_data(thread_specific_data&); // copy constructor
+ thread_specific_data& operator=(thread_specific_data&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/threads/threaded_object_extension.cpp b/ml/dlib/dlib/threads/threaded_object_extension.cpp
new file mode 100644
index 000000000..a7326c11d
--- /dev/null
+++ b/ml/dlib/dlib/threads/threaded_object_extension.cpp
@@ -0,0 +1,290 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADED_OBJECT_EXTENSIOn_CPP
+#define DLIB_THREADED_OBJECT_EXTENSIOn_CPP
+
+#include "threaded_object_extension.h"
+#include "create_new_thread_extension.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ threaded_object::
+ threaded_object (
+ ):
+ s(m_),
+ id1(0),
+ is_running_(false),
+ is_alive_(false),
+ should_stop_(false),
+ id_valid(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ threaded_object::
+ ~threaded_object (
+ )
+ {
+ try
+ {
+ DLIB_ASSERT(is_alive() == false,
+ "\tthreaded_object::~threaded_object()"
+ << "\n\tYou have let a threaded object destruct itself before terminating its thread"
+ << "\n\tthis: " << this
+ );
+ }
+ catch (std::exception& e)
+ {
+ std::cerr << e.what() << std::endl;
+ assert(false);
+ abort();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool threaded_object::
+ is_running (
+ ) const
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tbool threaded_object::is_running()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ return is_running_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool threaded_object::
+ is_alive (
+ ) const
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tbool threaded_object::is_alive()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ return is_alive_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ wait (
+ ) const
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::wait()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ while (is_alive_)
+ s.wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ start (
+ )
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::start()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ if (is_alive_ == false)
+ {
+ if (create_new_thread<threaded_object,&threaded_object::thread_helper>(*this) == false)
+ {
+ is_running_ = false;
+ throw thread_error();
+ }
+ }
+ is_alive_ = true;
+ is_running_ = true;
+ should_stop_ = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ restart (
+ )
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::restart()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ if (is_alive_ == false)
+ {
+ if (create_new_thread<threaded_object,&threaded_object::thread_helper>(*this) == false)
+ {
+ is_running_ = false;
+ throw thread_error();
+ }
+ should_respawn_ = false;
+ }
+ else
+ {
+ should_respawn_ = true;
+ }
+ is_alive_ = true;
+ is_running_ = true;
+ should_stop_ = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ set_respawn (
+ )
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::set_respawn()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ should_respawn_ = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool threaded_object::
+ should_respawn (
+ ) const
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tbool threaded_object::should_respawn()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ return should_respawn_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ pause (
+ )
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::pause()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ is_running_ = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ stop (
+ )
+ {
+ auto_mutex M(m_);
+
+ DLIB_ASSERT(id1 != get_thread_id() || id_valid == false,
+ "\tvoid threaded_object::stop()"
+ << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+
+ should_stop_ = true;
+ is_running_ = false;
+ should_respawn_ = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool threaded_object::
+ should_stop (
+ ) const
+ {
+ auto_mutex M(m_);
+ DLIB_ASSERT(is_alive_ && id1 == get_thread_id() && id_valid == true,
+ "\tbool threaded_object::should_stop()"
+ << "\n\tYou can only call this function from the thread that executes threaded_object::thread"
+ << "\n\tthis: " << this
+ );
+ while (is_running_ == false && should_stop_ == false)
+ s.wait();
+ return should_stop_;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threaded_object::
+ thread_helper(
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ id1 = get_thread_id();
+ id_valid = true;
+#endif
+ while (true)
+ {
+ m_.lock();
+ should_respawn_ = false;
+ m_.unlock();
+
+ thread();
+
+ auto_mutex M(m_);
+
+ if (should_respawn_)
+ continue;
+
+#ifdef ENABLE_ASSERTS
+ id_valid = false;
+#endif
+
+ is_alive_ = false;
+ is_running_ = false;
+ should_stop_ = false;
+ s.broadcast();
+
+ return;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREADED_OBJECT_EXTENSIOn_CPP
+
diff --git a/ml/dlib/dlib/threads/threaded_object_extension.h b/ml/dlib/dlib/threads/threaded_object_extension.h
new file mode 100644
index 000000000..dcf00daea
--- /dev/null
+++ b/ml/dlib/dlib/threads/threaded_object_extension.h
@@ -0,0 +1,123 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADED_OBJECT_EXTENSIOn_
+#define DLIB_THREADED_OBJECT_EXTENSIOn_
+
+#include "threaded_object_extension_abstract.h"
+#include "threads_kernel.h"
+#include "auto_mutex_extension.h"
+#include "../algs.h"
+#include "../assert.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class threaded_object
+ {
+ /*!
+ INITIAL VALUE
+ - is_running_ == false
+ - is_alive_ == false
+ - should_stop_ == false
+ - should_respawn_ == false
+
+#ifdef ENABLE_ASSERTS
+ - id_valid == false
+ - id1 == get_main_thread_id()
+#endif
+
+ CONVENTION
+ - is_running() == is_running_
+ - is_alive() == is_alive_
+ - should_stop() == should_stop_
+ - should_respawn() == should_respawn_
+
+
+#ifdef ENABLE_ASSERTS
+ - if (when thread() is executing) then
+ - id1 == the id of the running thread
+ - id_valid == true
+ - else
+ - id1 == an undefined value
+ - id_valid == false
+#endif
+
+ - m_ == the mutex used to protect all our variables
+ - s == the signaler for m_
+ !*/
+
+ public:
+
+ threaded_object (
+ );
+
+ virtual ~threaded_object (
+ );
+
+ bool is_running (
+ ) const;
+
+ bool is_alive (
+ ) const;
+
+ void wait (
+ ) const;
+
+ void start (
+ );
+
+ void restart (
+ );
+
+ void set_respawn (
+ );
+
+ bool should_respawn (
+ ) const;
+
+ void pause (
+ );
+
+ void stop (
+ );
+
+ protected:
+
+ bool should_stop (
+ ) const;
+
+ private:
+
+ void thread_helper(
+ );
+
+ virtual void thread (
+ ) = 0;
+
+ mutex m_;
+ signaler s;
+ thread_id_type id1;
+ bool is_running_;
+ bool is_alive_;
+ bool should_stop_;
+ bool should_respawn_;
+ bool id_valid;
+
+ // restricted functions
+ threaded_object(threaded_object&); // copy constructor
+ threaded_object& operator=(threaded_object&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "threaded_object_extension.cpp"
+#endif
+
+#endif // DLIB_THREADED_OBJECT_EXTENSIOn_
+
+
diff --git a/ml/dlib/dlib/threads/threaded_object_extension_abstract.h b/ml/dlib/dlib/threads/threaded_object_extension_abstract.h
new file mode 100644
index 000000000..32a8fbc31
--- /dev/null
+++ b/ml/dlib/dlib/threads/threaded_object_extension_abstract.h
@@ -0,0 +1,199 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_
+#ifdef DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_
+
+#include "threads_kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class threaded_object
+ {
+ /*!
+ INITIAL VALUE
+ - is_running() == false
+ - is_alive() == false
+ - should_respawn() == false
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple threaded object. To use it you inherit
+ from it and define the thread() function. Then when you call start()
+ it will spawn a thread that calls this->thread().
+ !*/
+ public:
+
+ threaded_object (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create threading objects.
+ !*/
+
+ virtual ~threaded_object (
+ );
+ /*!
+ requires
+ - is_alive() == false
+ (i.e. in the destructor for the object you derive from this one you
+ must wait for this->thread() to end.)
+ ensures
+ - all resources allocated by *this have been freed.
+ !*/
+
+ bool is_running (
+ ) const;
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - if (is_alive() && this->thread() is currently supposed to be executing) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_alive (
+ ) const;
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - if (this->thread() has been called by some thread and has yet to terminate) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ void wait (
+ ) const;
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - if (is_alive() == true) then
+ - blocks until this->thread() terminates
+ !*/
+
+ void start (
+ );
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - #is_alive() == true
+ - #is_running() == true
+ - #should_stop() == false
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then
+ #is_alive() == false and #is_running() == false
+ !*/
+
+ void set_respawn (
+ );
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - #should_respawn() == true
+ !*/
+
+ bool should_respawn (
+ ) const;
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - returns true if the thread will automatically restart upon termination and
+ false otherwise. Note that every time a thread starts it sets should_respawn()
+ back to false. Therefore, a single call to set_respawn() can cause at most
+ one respawn to occur.
+ !*/
+
+ void restart (
+ );
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - This function atomically executes set_respawn() and start(). The precise meaning of this
+ is defined below.
+ - if (is_alive()) then
+ - #should_respawn() == true
+ - else
+ - #should_respawn() == false
+ - #is_alive() == true
+ - #is_running() == true
+ - #should_stop() == false
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then
+ #is_alive() == false and #is_running() == false
+ !*/
+
+ void pause (
+ );
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - #is_running() == false
+ !*/
+
+ void stop (
+ );
+ /*!
+ requires
+ - is not called from this->thread()
+ ensures
+ - #should_stop() == true
+ - #is_running() == false
+ - #should_respawn() == false
+ !*/
+
+ protected:
+
+ bool should_stop (
+ ) const;
+ /*!
+ requires
+ - is only called from the thread that executes this->thread()
+ ensures
+ - calls to this function block until (#is_running() == true || #should_stop() == true)
+ - if (this thread is supposed to terminate) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ private:
+
+ virtual void thread (
+ ) = 0;
+ /*!
+ requires
+ - is executed in its own thread
+ - is only executed in one thread at a time
+ throws
+ - does not throw any exceptions
+ !*/
+
+ // restricted functions
+ threaded_object(threaded_object&); // copy constructor
+ threaded_object& operator=(threaded_object&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel.h b/ml/dlib/dlib/threads/threads_kernel.h
new file mode 100644
index 000000000..77cb16d92
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel.h
@@ -0,0 +1,18 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADs_KERNEL_
+#define DLIB_THREADs_KERNEL_
+
+#include "../platform.h"
+
+#ifdef WIN32
+#include "windows.h"
+#endif
+
+#ifndef WIN32
+#include "posix.h"
+#endif
+
+#endif // DLIB_THREADs_KERNEL_
+
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_1.cpp b/ml/dlib/dlib/threads/threads_kernel_1.cpp
new file mode 100644
index 000000000..cb36b8d3f
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_1.cpp
@@ -0,0 +1,83 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEL_1_CPp_
+#define DLIB_THREADS_KERNEL_1_CPp_
+
+#include "../platform.h"
+
+#ifdef WIN32
+
+#include "threads_kernel_1.h"
+
+#include <process.h>
+
+
+namespace dlib
+{
+ namespace threads_kernel_shared_helpers
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ struct info
+ {
+ void* param;
+ void (*funct)(void*);
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ unsigned int __stdcall thread_starter (
+ void* param
+ )
+ {
+ info* alloc_p = static_cast<info*>(param);
+ info p = *alloc_p;
+ delete alloc_p;
+
+ p.funct(p.param);
+ return 0;
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ bool spawn_thread (
+ void (*funct)(void*),
+ void* param
+ )
+ {
+ info* p;
+ try { p = new info; }
+ catch (...) { return false; }
+
+ p->funct = funct;
+ p->param = param;
+
+
+ unsigned int garbage;
+
+ HANDLE thandle = (HANDLE)_beginthreadex (NULL,0,thread_starter,p,0,&garbage);
+ // make thread and add it to the pool
+
+ // return false if _beginthreadex didn't work
+ if ( thandle == 0)
+ {
+ delete p;
+ return false;
+ }
+
+ // throw away the thread handle
+ CloseHandle(thandle);
+ return true;
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ }
+
+}
+
+#endif // WIN32
+
+#endif // DLIB_THREADS_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_1.h b/ml/dlib/dlib/threads/threads_kernel_1.h
new file mode 100644
index 000000000..586a21b7e
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_1.h
@@ -0,0 +1,158 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEl_1_
+#define DLIB_THREADS_KERNEl_1_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#include "threads_kernel_abstract.h"
+
+#include "../windows_magic.h"
+#include <windows.h>
+#include "../algs.h"
+#include <condition_variable>
+#include <mutex>
+#include <chrono>
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ typedef DWORD thread_id_type;
+
+ inline thread_id_type get_thread_id (
+ )
+ {
+ return GetCurrentThreadId();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // mutex object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ // forward declaration of signaler
+ class signaler;
+
+ class mutex
+ {
+ public:
+
+ mutex (
+ )
+ {
+ }
+
+ ~mutex (
+ ) { }
+
+ void lock (
+ ) const { cs.lock(); }
+
+ void unlock (
+ ) const { cs.unlock(); }
+
+ private:
+
+ friend class signaler;
+
+ mutable std::mutex cs;
+
+ // restricted functions
+ mutex(mutex&); // copy constructor
+ mutex& operator=(mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // signaler object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class signaler
+ {
+
+ public:
+ signaler (
+ const mutex& associated_mutex
+ ) :
+ m(associated_mutex)
+ {
+
+ }
+
+ ~signaler (
+ ) { }
+
+ void wait (
+ ) const
+ {
+ std::unique_lock<std::mutex> cs(m.cs, std::defer_lock);
+ cv.wait(cs);
+ }
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const
+ {
+ std::unique_lock<std::mutex> cs(m.cs, std::defer_lock);
+ auto status = cv.wait_until(cs, std::chrono::system_clock::now() + std::chrono::milliseconds(milliseconds));
+ return status == std::cv_status::no_timeout;
+ }
+
+ void signal (
+ ) const
+ {
+ cv.notify_one();
+ }
+
+ void broadcast (
+ ) const
+ {
+ cv.notify_all();
+ }
+
+ const mutex& get_mutex (
+ ) const { return m; }
+
+ private:
+
+ mutable std::condition_variable cv;
+
+ const mutex& m;
+
+ // restricted functions
+ signaler(signaler&); // copy constructor
+ signaler& operator=(signaler&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace threads_kernel_shared_helpers
+ {
+ bool spawn_thread (
+ void (*funct)(void*),
+ void* param
+ );
+ /*!
+ is identical to create_new_thread() but just doesn't use any thread pooling.
+ !*/
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#include "threads_kernel_shared.h"
+
+#ifdef NO_MAKEFILE
+#include "threads_kernel_1.cpp"
+#endif
+
+#endif // DLIB_THREADS_KERNEl_1_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_2.cpp b/ml/dlib/dlib/threads/threads_kernel_2.cpp
new file mode 100644
index 000000000..06fb80d00
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_2.cpp
@@ -0,0 +1,75 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEL_2_CPp_
+#define DLIB_THREADS_KERNEL_2_CPp_
+
+#include "../platform.h"
+
+#ifdef POSIX
+
+#include "threads_kernel_2.h"
+
+
+namespace dlib
+{
+ namespace threads_kernel_shared_helpers
+ {
+
+ // -----------------------------------------------------------------------------------
+
+ struct info
+ {
+ void* param;
+ void (*funct)(void*);
+ };
+
+ // -----------------------------------------------------------------------------------
+
+ void* thread_starter (
+ void* param
+ )
+ {
+ info* alloc_p = static_cast<info*>(param);
+ info p = *alloc_p;
+ delete alloc_p;
+
+ // detach self
+ pthread_detach(pthread_self());
+
+ p.funct(p.param);
+ return 0;
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ bool spawn_thread (
+ void (*funct)(void*),
+ void* param
+ )
+ {
+ info* p;
+ try { p = new info; }
+ catch (...) { return false; }
+
+ p->funct = funct;
+ p->param = param;
+
+ pthread_t thread_id;
+ if ( pthread_create (&thread_id, 0, thread_starter, p) )
+ {
+ delete p;
+ return false;
+ }
+ return true;
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ }
+
+}
+
+#endif // POSIX
+
+#endif // DLIB_THREADS_KERNEL_2_CPp_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_2.h b/ml/dlib/dlib/threads/threads_kernel_2.h
new file mode 100644
index 000000000..209142131
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_2.h
@@ -0,0 +1,180 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEl_2_
+#define DLIB_THREADS_KERNEl_2_
+
+#ifdef DLIB_ISO_CPP_ONLY
+#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it."
+#endif
+
+#include "threads_kernel_abstract.h"
+#include <pthread.h>
+#include <errno.h>
+#include <sys/time.h>
+#include "../algs.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ typedef pthread_t thread_id_type;
+
+ inline thread_id_type get_thread_id (
+ )
+ {
+ return pthread_self();
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // mutex object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ // forward declaration of signaler
+ class signaler;
+
+ class mutex
+ {
+ // give signaler access to hMutex
+ friend class signaler;
+ public:
+
+ mutex (
+ )
+ {
+ if (pthread_mutex_init(&myMutex,0))
+ {
+ throw dlib::thread_error(ECREATE_MUTEX,
+ "in function mutex::mutex() an error occurred making the mutex"
+ );
+ }
+ }
+
+ ~mutex (
+ ) { pthread_mutex_destroy(&myMutex); }
+
+ void lock (
+ ) const { pthread_mutex_lock(&myMutex); }
+
+ void unlock (
+ ) const { pthread_mutex_unlock(&myMutex); }
+
+ private:
+
+ mutable pthread_mutex_t myMutex;
+
+ // restricted functions
+ mutex(mutex&); // copy constructor
+ mutex& operator=(mutex&); // assignement opertor
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // signaler object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class signaler
+ {
+
+ public:
+
+
+ signaler (
+ const mutex& assoc_mutex
+ ) :
+ associated_mutex(&assoc_mutex.myMutex),
+ m(assoc_mutex)
+ {
+ if (pthread_cond_init(&cond,0))
+ {
+ throw dlib::thread_error(ECREATE_SIGNALER,
+ "in function signaler::signaler() an error occurred making the signaler"
+ );
+ }
+ }
+
+ ~signaler (
+ ) { pthread_cond_destroy(&cond); }
+
+ void wait (
+ ) const
+ {
+ pthread_cond_wait(&cond,associated_mutex);
+ }
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const
+ {
+ timespec time_to_wait;
+
+ timeval curtime;
+ gettimeofday(&curtime,0);
+
+ // get the time and adjust the timespec object by the appropriate amount
+ time_to_wait.tv_sec = milliseconds/1000 + curtime.tv_sec;
+ time_to_wait.tv_nsec = curtime.tv_usec;
+ time_to_wait.tv_nsec *= 1000;
+ time_to_wait.tv_nsec += (milliseconds%1000)*1000000;
+
+ time_to_wait.tv_sec += time_to_wait.tv_nsec/1000000000;
+ time_to_wait.tv_nsec = time_to_wait.tv_nsec%1000000000;
+
+ if ( pthread_cond_timedwait(&cond,associated_mutex,&time_to_wait) == ETIMEDOUT)
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ void signal (
+ ) const { pthread_cond_signal(&cond); }
+
+ void broadcast (
+ ) const { pthread_cond_broadcast(&cond); }
+
+ const mutex& get_mutex (
+ ) const { return m; }
+
+ private:
+
+ pthread_mutex_t* const associated_mutex;
+ mutable pthread_cond_t cond;
+ const mutex& m;
+
+ // restricted functions
+ signaler(signaler&); // copy constructor
+ signaler& operator=(signaler&); // assignement opertor
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace threads_kernel_shared_helpers
+ {
+ bool spawn_thread (
+ void (*funct)(void*),
+ void* param
+ );
+ /*!
+ is identical to create_new_thread() but just doesn't use any thread pooling.
+ !*/
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#include "threads_kernel_shared.h"
+
+#ifdef NO_MAKEFILE
+#include "threads_kernel_2.cpp"
+#endif
+
+#endif // DLIB_THREADS_KERNEl_2_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_abstract.h b/ml/dlib/dlib/threads/threads_kernel_abstract.h
new file mode 100644
index 000000000..d88d37dad
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_abstract.h
@@ -0,0 +1,302 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_THREADS_KERNEl_ABSTRACT_
+#ifdef DLIB_THREADS_KERNEl_ABSTRACT_
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+ THREAD POOLING
+ When threads end they go into a global thread pool and each waits there
+ for 30 seconds before timing out and having its resources returned to the
+ operating system. When create_new_thread() is called it first looks in the
+ thread pool to see if there are any threads it can snatch from the pool, if
+ not then it makes a new one.
+
+ Note that whenever I say something happens when a thread "terminates" or "ends"
+ I mean "when it returns to the thread pool." From the client programmer point
+ of view a thread terminates/ends when it returns to the dlib thread pool and you
+ shouldn't and indeed don't need to know when it actually gets its resources
+ reclaimed by the operating system.
+
+ If you want to change the timeout to a different value you can #define
+ DLIB_THREAD_POOL_TIMEOUT to whatever value (in milliseconds) that you like.
+
+ EXCEPTIONS
+ Unless specified otherwise, nothing in this file throws exceptions.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ thread_id_type get_thread_id (
+ );
+ /*!
+ ensures
+ - returns a unique id for the calling thread. Note that while the id is unique
+ among all currently existing threads it may have been used by a previous
+ thread that has terminated.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_dlib_thread (
+ thread_id_type id = get_thread_id()
+ );
+ /*!
+ ensures
+ - if (the thread with the given id was spawned by a call to
+ dlib::create_new_thread) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void register_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ );
+ /*!
+ requires
+ - handler == a valid member function pointer for class T
+ - handler does not throw
+ - handler does not call register_thread_end_handler()
+ - handler does not block
+ - is_dlib_thread() == true (i.e. the calling thread was spawned by dlib::create_new_thread())
+ ensures
+ - let ID == the thread id for the thread calling register_thread_end_handler()
+ - (obj.*handler)() will be called when the thread with thread id ID is
+ terminating and it will be called from within that terminating thread.
+ (i.e. inside the handler function get_thread_id() == ID == the id of the
+ thread that is terminating. )
+ - each call to this function adds another handler that will be called when
+ the given thread terminates. This means that if you call it a bunch of
+ times then you will end up registering multiple handlers (or single
+ handlers multiple times) that will be called when the thread ends.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function had no effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void unregister_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ );
+ /*!
+ requires
+ - handler == a valid member function pointer for class T
+ ensures
+ - Undoes all previous calls to register_thread_end_handler(obj,handler).
+ So the given handler won't be called when any threads end.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function had no effect.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ bool create_new_thread (
+ void (*funct)(void*),
+ void* param
+ );
+ /*!
+ ensures
+ - creates a new thread for the function pointed to by funct
+ - passes it param as its parameter. (i.e. calls funct(param) from the new thread)
+ - returns true upon success and false upon failure to create the new thread
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // mutex object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class mutex
+ {
+ /*!
+ INITIAL VALUE
+ mutex is in the unlocked state
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a mutex intended to be used for synchronous
+ thread control of shared data. When a thread wants to access some
+ shared data it locks out other threads by calling lock() and calls
+ unlock() when it is finished.
+ !*/
+ public:
+
+ mutex (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the mutex.
+ !*/
+
+ ~mutex (
+ );
+ /*!
+ requires
+ - *this is not locked
+ ensures
+ - all resources allocated by *this have been freed
+ !*/
+
+ void lock (
+ ) const;
+ /*!
+ requires
+ - the thread calling lock() does not already have a lock on *this
+ ensures
+ - if (*this is currently locked by another thread) then
+ - the thread that called lock() on *this is put to sleep until
+ it becomes available
+ - if (*this is currently unlocked) then
+ - #*this becomes locked and the current thread is NOT put to sleep
+ but now "owns" #*this
+ !*/
+
+ void unlock (
+ ) const;
+ /*!
+ requires
+ - the thread calling unlock() already has a lock on *this
+ ensures
+ - #*this is unlocked (i.e. other threads may now lock this object)
+ !*/
+
+
+ private:
+ // restricted functions
+ mutex(mutex&); // copy constructor
+ mutex& operator=(mutex&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // signaler object
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class signaler
+ {
+ /*!
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an event signaling system for threads. It gives
+ a thread the ability to wake up other threads that are waiting for a
+ particular signal.
+
+ Each signaler object is associated with one and only one mutex object.
+ More than one signaler object may be associated with a single mutex
+ but a signaler object may only be associated with a single mutex.
+
+ NOTE:
+ You must guard against spurious wakeups. This means that a thread
+ might return from a call to wait even if no other thread called
+ signal. This is rare but must be guarded against.
+ !*/
+ public:
+
+ signaler (
+ const mutex& associated_mutex
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ - #get_mutex() == associated_mutex
+ throws
+ - dlib::thread_error
+ the constructor may throw this exception if there is a problem
+ gathering resources to create the signaler.
+ !*/
+
+
+ ~signaler (
+ );
+ /*!
+ ensures
+ - all resources allocated by *this have been freed
+ !*/
+
+ void wait (
+ ) const;
+ /*!
+ requires
+ - get_mutex() is locked and owned by the calling thread
+ ensures
+ - atomically unlocks get_mutex() and blocks the calling thread
+ - calling thread may wake if another thread calls signal() or broadcast()
+ on *this
+ - when wait() returns the calling thread again has a lock on get_mutex()
+ !*/
+
+ bool wait_or_timeout (
+ unsigned long milliseconds
+ ) const;
+ /*!
+ requires
+ - get_mutex() is locked and owned by the calling thread
+ ensures
+ - atomically unlocks get_mutex() and blocks the calling thread
+ - calling thread may wake if another thread calls signal() or broadcast()
+ on *this
+ - after the specified number of milliseconds has elapsed the calling thread
+ will wake once get_mutex() is free
+ - when wait returns the calling thread again has a lock on get_mutex()
+
+ - returns false if the call to wait_or_timeout timed out
+ - returns true if the call did not time out
+ !*/
+
+
+ void signal (
+ ) const;
+ /*!
+ ensures
+ - if (at least one thread is waiting on *this) then
+ - at least one of the waiting threads will wake
+ !*/
+
+ void broadcast (
+ ) const;
+ /*!
+ ensures
+ - any and all threads waiting on *this will wake
+ !*/
+
+ const mutex& get_mutex (
+ ) const;
+ /*!
+ ensures
+ - returns a const reference to the mutex associated with *this
+ !*/
+
+ private:
+ // restricted functions
+ signaler(signaler&); // copy constructor
+ signaler& operator=(signaler&); // assignment operator
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREADS_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_shared.cpp b/ml/dlib/dlib/threads/threads_kernel_shared.cpp
new file mode 100644
index 000000000..8e81193e9
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_shared.cpp
@@ -0,0 +1,318 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEL_SHARED_CPp_
+#define DLIB_THREADS_KERNEL_SHARED_CPp_
+
+#include "threads_kernel_shared.h"
+#include "../assert.h"
+#include "../platform.h"
+#include <iostream>
+
+
+#ifndef DLIB_THREAD_POOL_TIMEOUT
+// default to 30000 milliseconds
+#define DLIB_THREAD_POOL_TIMEOUT 30000
+#endif
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// threader functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ namespace threads_kernel_shared
+ {
+
+ bool thread_pool_has_been_destroyed = false;
+
+// ----------------------------------------------------------------------------------------
+
+ struct threader_destruct_helper
+ {
+ // cause the thread pool to begin its destruction process when
+ // global objects start to be destroyed
+ ~threader_destruct_helper()
+ {
+ thread_pool().destruct_if_ready();
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ threader& thread_pool (
+ )
+ {
+ static threader* thread_pool = new threader;
+ static threader_destruct_helper a;
+ return *thread_pool;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool threader::
+ is_dlib_thread (
+ thread_id_type id
+ )
+ {
+ auto_mutex M(data_mutex);
+ return thread_ids.is_member(id);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ threader::
+ threader (
+ ) :
+ total_count(0),
+ function_pointer(0),
+ pool_count(0),
+ data_ready(data_mutex),
+ data_empty(data_mutex),
+ destruct(false),
+ destructed(data_mutex),
+ do_not_ever_destruct(false)
+ {
+#ifdef WIN32
+ // Trying to destroy the global thread pool when we are part of a DLL and the
+ // DLL is being unloaded can sometimes lead to weird behavior. For example, in
+ // the python interpreter you will get the interpreter to hang. Or if we are
+ // part of a MATLAB mex file and the file is being unloaded there can also be
+ // similar weird issues. So when we are using dlib on windows we just disable
+ // the destruction of the global thread pool since it doesn't matter anyway.
+ // It's resources will just get freed by the OS. This is even the recommended
+ // thing to do by Microsoft (http://blogs.msdn.com/b/oldnewthing/archive/2012/01/05/10253268.aspx).
+ //
+ // As an aside, it's worth pointing out that the reason we try and free
+ // resources on program shutdown on other operating systems is so we can have
+ // clean reports from tools like valgrind which check for memory leaks. But
+ // trying to do this on windows is a lost cause so we give up in this case and
+ // follow the Microsoft recommendation.
+ do_not_ever_destruct = true;
+#endif // WIN32
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ threader::
+ ~threader (
+ )
+ {
+ data_mutex.lock();
+ destruct = true;
+ data_ready.broadcast();
+
+ // wait for all the threads to end
+ while (total_count > 0)
+ destructed.wait();
+
+ thread_pool_has_been_destroyed = true;
+ data_mutex.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threader::
+ destruct_if_ready (
+ )
+ {
+ if (do_not_ever_destruct)
+ return;
+
+ data_mutex.lock();
+
+ // if there aren't any active threads, just maybe some sitting around
+ // in the pool then just destroy the threader
+ if (total_count == pool_count)
+ {
+ destruct = true;
+ data_ready.broadcast();
+ data_mutex.unlock();
+ delete this;
+ }
+ else
+ {
+ // There are still some user threads running so there isn't
+ // much we can really do. Just let the program end without
+ // cleaning up threading resources.
+ data_mutex.unlock();
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void threader::
+ call_end_handlers (
+ )
+ {
+ reg.m.lock();
+ const thread_id_type id = get_thread_id();
+ thread_id_type id_copy;
+ member_function_pointer<> mfp;
+
+ // Remove all the member function pointers for this thread from the tree
+ // and call them.
+ while (reg.reg[id] != 0)
+ {
+ reg.reg.remove(id,id_copy,mfp);
+ reg.m.unlock();
+ mfp();
+ reg.m.lock();
+ }
+ reg.m.unlock();
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ bool threader::
+ create_new_thread (
+ void (*funct)(void*),
+ void* param
+ )
+ {
+
+ // get a lock on the data mutex
+ auto_mutex M(data_mutex);
+
+ // loop to ensure that the new function pointer is in the data
+ while (true)
+ {
+ // if the data is empty then add new data and quit loop
+ if (function_pointer == 0)
+ {
+ parameter = param;
+ function_pointer = funct;
+ break;
+ }
+ else
+ {
+ // wait for data to become empty
+ data_empty.wait();
+ }
+ }
+
+
+ // get a thread for this new data
+ // if a new thread must be created
+ if (pool_count == 0)
+ {
+ // make thread and add it to the pool
+ if ( threads_kernel_shared_helpers::spawn_thread(thread_starter, this) == false )
+ {
+ function_pointer = 0;
+ parameter = 0;
+ data_empty.signal();
+ return false;
+ }
+ ++total_count;
+ }
+ // wake up a thread from the pool
+ else
+ {
+ data_ready.signal();
+ }
+
+ return true;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ void thread_starter (
+ void* object
+ )
+ {
+ // get a reference to the calling threader object
+ threader& self = *static_cast<threader*>(object);
+
+
+ {
+ auto_mutex M(self.data_mutex);
+
+ // add this thread id
+ thread_id_type thread_id = get_thread_id();
+ self.thread_ids.add(thread_id);
+
+ // indicate that this thread is now in the thread pool
+ ++self.pool_count;
+
+ while (self.destruct == false)
+ {
+ // if data is ready then process it and launch the thread
+ // if its not ready then go back into the pool
+ while (self.function_pointer != 0)
+ {
+ // indicate that this thread is now out of the thread pool
+ --self.pool_count;
+
+ // get the data for the function call
+ void (*funct)(void*) = self.function_pointer;
+ void* param = self.parameter;
+ self.function_pointer = 0;
+
+ // signal that the data is now empty
+ self.data_empty.signal();
+
+ self.data_mutex.unlock();
+ // Call funct with its intended parameter. If this function throws then
+ // we intentionally let the exception escape the thread and result in whatever
+ // happens when it gets caught by the OS (generally the program is terminated).
+ funct(param);
+ self.call_end_handlers();
+
+ self.data_mutex.lock();
+
+ // indicate that this thread is now back in the thread pool
+ ++self.pool_count;
+ }
+
+ if (self.destruct == true)
+ break;
+
+ // if we timed out and there isn't any work to do then
+ // this thread will quit this loop and end.
+ if (self.data_ready.wait_or_timeout(DLIB_THREAD_POOL_TIMEOUT) == false &&
+ self.function_pointer == 0)
+ break;
+
+ }
+
+ // remove this thread id from thread_ids
+ thread_id = get_thread_id();
+ self.thread_ids.destroy(thread_id);
+
+ // indicate that this thread is now out of the thread pool
+ --self.pool_count;
+ --self.total_count;
+
+ self.destructed.signal();
+
+ } // end of auto_mutex M(self.data_mutex) block
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_dlib_thread (
+ thread_id_type id
+ )
+ {
+ return threads_kernel_shared::thread_pool().is_dlib_thread(id);
+ }
+
+ bool is_dlib_thread (
+ )
+ {
+ return is_dlib_thread(get_thread_id());
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_THREADS_KERNEL_SHARED_CPp_
+
diff --git a/ml/dlib/dlib/threads/threads_kernel_shared.h b/ml/dlib/dlib/threads/threads_kernel_shared.h
new file mode 100644
index 000000000..b4526e8db
--- /dev/null
+++ b/ml/dlib/dlib/threads/threads_kernel_shared.h
@@ -0,0 +1,274 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEl_SHARED_
+#define DLIB_THREADS_KERNEl_SHARED_
+
+// this file should be included at the bottom of one of the thread kernel headers for a
+// specific platform.
+//#include "../threads.h"
+#include "auto_mutex_extension.h"
+#include "../binary_search_tree.h"
+#include "../member_function_pointer.h"
+#include "../memory_manager.h"
+#include "../queue.h"
+#include "../set.h"
+#include "../test_for_odr_violations.h"
+
+
+
+
+
+namespace dlib
+{
+
+
+// ----------------------------------------------------------------------------------------
+
+ namespace threads_kernel_shared
+ {
+ void thread_starter (
+ void*
+ );
+
+ class threader
+ {
+ /*!
+ INITIAL VALUE
+ - pool_count == 0 and
+ - data_ready is associated with the mutex data_mutex
+ - data_empty is associated with the mutex data_mutex
+ - destructed is associated with the mutex data_mutex
+ - destruct == false
+ - total_count == 0
+ - function_pointer == 0
+ - do_not_ever_destruct == false
+
+ CONVENTION
+ - data_ready is associated with the mutex data_mutex
+ - data_empty is associated with the mutex data_mutex
+ - data_ready == a signaler used signal when there is new data waiting
+ to start a thread with.
+ - data_empty == a signaler used to signal when the data is now empty
+ - pool_count == the number of suspended threads in the thread pool
+ - total_count == the number of threads that are executing anywhere. i.e.
+ pool_count + the ones that are currently running some user function.
+ - if (function_pointer != 0) then
+ - parameter == a void pointer pointing to the parameter which
+ should be used to start the next thread
+ - function_pointer == a pointer to the next function to make a
+ new thread with
+
+ - if (the destructor is running) then
+ - destruct == true
+ - else
+ - destruct == false
+
+ - thread_ids is locked by the data_mutex
+ - thread_ids == a set that contains the thread id for each thread spawned by this
+ object.
+ !*/
+
+
+ public:
+ threader (
+ );
+
+ ~threader (
+ );
+
+ void destruct_if_ready (
+ );
+ /*!
+ ensures
+ - if (there are no threads currently running and we haven't set do_not_ever_destruct) then
+ - calls delete this
+ - else
+ - does nothing
+ !*/
+
+ bool create_new_thread (
+ void (*funct)(void*),
+ void* param
+ );
+
+ template <
+ typename T
+ >
+ void unregister_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ )
+ {
+ member_function_pointer<> mfp, junk_mfp;
+ mfp.set(obj,handler);
+
+ thread_id_type junk_id;
+
+ // find any member function pointers in the registry that point to the same
+ // thing as mfp and remove them
+ auto_mutex M(reg.m);
+ reg.reg.reset();
+ while (reg.reg.move_next())
+ {
+ while (reg.reg.current_element_valid() && reg.reg.element().value() == mfp)
+ {
+ reg.reg.remove_current_element(junk_id, junk_mfp);
+ }
+ }
+ }
+
+ template <
+ typename T
+ >
+ void register_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ )
+ {
+ thread_id_type id = get_thread_id();
+ member_function_pointer<> mfp;
+ mfp.set(obj,handler);
+
+ auto_mutex M(reg.m);
+ reg.reg.add(id,mfp);
+ }
+
+ bool is_dlib_thread (
+ thread_id_type id
+ );
+
+ private:
+
+ friend void thread_starter (
+ void*
+ );
+
+ void call_end_handlers (
+ );
+ /*!
+ ensures
+ - calls the registered end handlers for the calling thread and
+ then removes them from reg.reg
+ !*/
+
+
+ // private data
+ set<thread_id_type,memory_manager<char>::kernel_2b>::kernel_1b_c thread_ids;
+ unsigned long total_count;
+ void* parameter;
+ void (*function_pointer)(void*);
+ unsigned long pool_count;
+ mutex data_mutex; // mutex to protect the above data
+ signaler data_ready; // signaler to signal when there is new data
+ signaler data_empty; // signaler to signal when the data is empty
+ bool destruct;
+ signaler destructed; // signaler to signal when a thread has ended
+ bool do_not_ever_destruct;
+
+ struct registry_type
+ {
+ mutex m;
+ binary_search_tree<
+ thread_id_type,
+ member_function_pointer<>,
+ memory_manager<char>::kernel_2a
+ >::kernel_2a_c reg;
+ };
+
+ // stuff for the register_thread_end_handler
+ registry_type reg;
+
+
+ // restricted functions
+ threader(threader&); // copy constructor
+ threader& operator=(threader&); // assignement opertor
+
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ threader& thread_pool (
+ );
+ /*!
+ ensures
+ - returns a reference to the global threader object
+ !*/
+
+ // ------------------------------------------------------------------------------------
+
+ extern bool thread_pool_has_been_destroyed;
+ }
+
+ bool is_dlib_thread (
+ thread_id_type id
+ );
+
+ bool is_dlib_thread (
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ inline bool create_new_thread (
+ void (*funct)(void*),
+ void* param
+ )
+ {
+ try
+ {
+ // now make this thread
+ return threads_kernel_shared::thread_pool().create_new_thread(funct,param);
+ }
+ catch (std::bad_alloc&)
+ {
+ return false;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void register_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ )
+ {
+ DLIB_ASSERT(is_dlib_thread(),
+ "\tvoid register_thread_end_handler"
+ << "\n\tYou can't register a thread end handler for a thread dlib didn't spawn."
+ );
+
+ threads_kernel_shared::thread_pool().register_thread_end_handler(obj,handler);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ inline void unregister_thread_end_handler (
+ T& obj,
+ void (T::*handler)()
+ )
+ {
+ // Check if the thread pool has been destroyed and if it has then don't do anything.
+ // This bool here is always true except when the program has started to terminate and
+ // the thread pool object has been destroyed. This if is here to catch other global
+ // objects that have destructors that try to call unregister_thread_end_handler().
+ // Without this check we get into trouble if the thread pool is destroyed before these
+ // objects.
+ if (threads_kernel_shared::thread_pool_has_been_destroyed == false)
+ threads_kernel_shared::thread_pool().unregister_thread_end_handler(obj,handler);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "threads_kernel_shared.cpp"
+#endif
+
+#endif // DLIB_THREADS_KERNEl_SHARED_
+
diff --git a/ml/dlib/dlib/threads/windows.h b/ml/dlib/dlib/threads/windows.h
new file mode 100644
index 000000000..f7c775950
--- /dev/null
+++ b/ml/dlib/dlib/threads/windows.h
@@ -0,0 +1,6 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_THREADS_KERNEl_2_
+#include "threads_kernel_1.h"
+#endif
+
diff --git a/ml/dlib/dlib/time_this.h b/ml/dlib/dlib/time_this.h
new file mode 100644
index 000000000..aec0d2de8
--- /dev/null
+++ b/ml/dlib/dlib/time_this.h
@@ -0,0 +1,36 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIME_THIs_
+#define DLIB_TIME_THIs_
+
+
+#include <chrono>
+
+// ----------------------------------------------------------------------------------------
+
+#define TIME_THIS_TO(_tt_op,_tt_out) \
+ { \
+ auto _tt_start = std::chrono::high_resolution_clock::now(); \
+ {_tt_op;} \
+ auto _tt_stop = std::chrono::high_resolution_clock::now(); \
+ auto _tt_thetime = _tt_stop-_tt_start; \
+ using std::chrono::duration_cast; \
+ using std::chrono::duration; \
+ if (_tt_thetime >= std::chrono::minutes(1)) \
+ _tt_out << "\ntime: " << duration_cast<duration<double,std::ratio<60>>>(_tt_thetime).count() << "min\n"; \
+ else if (_tt_thetime >= std::chrono::seconds(1)) \
+ _tt_out << "\ntime: " << duration_cast<duration<double>>(_tt_thetime).count() << "sec\n"; \
+ else if (_tt_thetime >= std::chrono::milliseconds(1)) \
+ _tt_out << "\ntime: " << duration_cast<duration<double,std::milli>>(_tt_thetime).count() << "ms\n"; \
+ else if (_tt_thetime >= std::chrono::microseconds(1)) \
+ _tt_out << "\ntime: " << duration_cast<duration<double,std::micro>>(_tt_thetime).count() << "us\n"; \
+ else \
+ _tt_out << "\ntime: " << duration_cast<duration<double,std::nano>>(_tt_thetime).count() << "ns\n"; \
+ }
+
+#define TIME_THIS(_tt_op) TIME_THIS_TO(_tt_op,std::cout)
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_TIME_THIs_
+
diff --git a/ml/dlib/dlib/timeout.h b/ml/dlib/dlib/timeout.h
new file mode 100644
index 000000000..2c3590e46
--- /dev/null
+++ b/ml/dlib/dlib/timeout.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMEOUt_
+#define DLIB_TIMEOUt_
+
+#include "timeout/timeout.h"
+
+#endif // DLIB_TIMEOUt_
+
+
diff --git a/ml/dlib/dlib/timeout/timeout.h b/ml/dlib/dlib/timeout/timeout.h
new file mode 100644
index 000000000..663e0ca65
--- /dev/null
+++ b/ml/dlib/dlib/timeout/timeout.h
@@ -0,0 +1,200 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMEOUT_KERNEl_1_
+#define DLIB_TIMEOUT_KERNEl_1_
+
+#include "../threads.h"
+#include "../algs.h"
+#include "../misc_api.h"
+#include "timeout_abstract.h"
+#include "../uintn.h"
+#include "../timer.h"
+
+#ifdef _MSC_VER
+// this is to disable the "'this' : used in base member initializer list"
+// warning you get from some of the GUI objects since all the objects
+// require that their parent class be passed into their constructor.
+// In this case though it is totally safe so it is ok to disable this warning.
+#pragma warning(disable : 4355)
+#endif // _MSC_VER
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class timeout
+ {
+ /*!
+ INITIAL VALUE
+ - b == a pointer to some kind of bind object
+
+ CONVENTION
+ - b == a pointer to some kind of bind object
+ !*/
+
+ class bind
+ {
+ public:
+ virtual void go() = 0;
+ virtual ~bind() {}
+ };
+
+ template <typename T>
+ class functor : public bind
+ {
+ public:
+ functor(const T& f) : function(f) {}
+ T function;
+ void go() { function(); }
+ };
+
+ template <typename T, typename R>
+ class zero : public bind
+ {
+ public:
+ T* object;
+ R (T::*callback_function)();
+ void go() { (object->*callback_function)(); }
+
+ };
+
+ template <typename T, typename R, typename U>
+ class one : public bind
+ {
+ public:
+ T* object;
+ R (T::*callback_function)(U);
+ U val;
+ void go() { (object->*callback_function)(val); }
+ };
+
+ public:
+
+ // This typedef is here for backwards compatibility with previous versions of dlib.
+ typedef timeout kernel_1a;
+
+ template <
+ typename T
+ >
+ timeout (
+ T callback_function,
+ unsigned long ms_to_timeout
+ ) :
+ t(*this,&timeout::trigger_timeout)
+ {
+ b = new functor<T>(callback_function);
+ t.set_delay_time(ms_to_timeout);
+ t.start();
+ }
+
+ template <
+ typename T
+ >
+ timeout (
+ T& object,
+ void (T::*callback_function)(),
+ unsigned long ms_to_timeout
+ ):
+ t(*this,&timeout::trigger_timeout)
+ {
+ zero<T,void>* B = new zero<T,void>;
+ b = B;
+ B->object = &object;
+ B->callback_function = callback_function;
+ t.set_delay_time(ms_to_timeout);
+ t.start();
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ timeout (
+ T& object,
+ void (T::*callback_function)(U callback_function_argument),
+ unsigned long ms_to_timeout,
+ U callback_function_argument
+ ):
+ t(*this,&timeout::trigger_timeout)
+ {
+ one<T,void,U>* B = new one<T,void,U>;
+ b = B;
+ B->object = &object;
+ B->callback_function = callback_function;
+ B->val = callback_function_argument;
+ t.set_delay_time(ms_to_timeout);
+ t.start();
+ }
+
+ template <
+ typename T
+ >
+ timeout (
+ T& object,
+ int (T::*callback_function)(),
+ unsigned long ms_to_timeout
+ ):
+ t(*this,&timeout::trigger_timeout)
+ {
+ zero<T,int>* B = new zero<T,int>;
+ b = B;
+ B->object = &object;
+ B->callback_function = callback_function;
+ t.set_delay_time(ms_to_timeout);
+ t.start();
+ }
+
+ template <
+ typename T,
+ typename U
+ >
+ timeout (
+ T& object,
+ int (T::*callback_function)(U callback_function_argument),
+ unsigned long ms_to_timeout,
+ U callback_function_argument
+ ):
+ t(*this,&timeout::trigger_timeout)
+ {
+ one<T,int,U>* B = new one<T,int,U>;
+ b = B;
+ B->object = &object;
+ B->callback_function = callback_function;
+ B->val = callback_function_argument;
+ t.set_delay_time(ms_to_timeout);
+ t.start();
+ }
+
+ virtual ~timeout (
+ )
+ {
+ t.stop_and_wait();
+ delete b;
+ }
+
+ private:
+
+ void trigger_timeout ()
+ {
+ b->go();
+ t.stop();
+ }
+
+ dlib::timer<timeout> t;
+ bind* b;
+
+ // restricted functions
+ timeout(const timeout&); // copy constructor
+ timeout& operator=(const timeout&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TIMEOUT_KERNEl_1_
+
+
+
diff --git a/ml/dlib/dlib/timeout/timeout_abstract.h b/ml/dlib/dlib/timeout/timeout_abstract.h
new file mode 100644
index 000000000..c2401bb68
--- /dev/null
+++ b/ml/dlib/dlib/timeout/timeout_abstract.h
@@ -0,0 +1,188 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TIMEOUT_KERNEl_ABSTRACT_
+#ifdef DLIB_TIMEOUT_KERNEl_ABSTRACT_
+
+#include "../threads.h"
+
+namespace dlib
+{
+
+ class timeout
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object provides a simple way to implement a timeout. An example will
+ make its use clear. Suppose we want to read from a socket but we want to
+ terminate the connection if the read takes longer than 10 seconds. This
+ could be accomplished as follows:
+
+ connection* con = a connection from somewhere;
+ {
+ // setup a timer that will call con->shutdown() in 10 seconds
+ timeout t(*con,&connection::shutdown,10000);
+ // Now call read on the connection. If this call to read() takes more
+ // than 10 seconds then the t timeout will trigger and shutdown the
+ // connection. If read completes in less than 10 seconds then the t
+ // object will be destructed on the next line due to the } and then the
+ // timeout won't trigger.
+ con->read(buf,100);
+ }
+
+
+ Alternatively, if you have a compiler capable of using C++11 lambda
+ functions, you can use a syntax like this:
+ {
+ timeout t([con](){ con->shutdown(); }, 10000);
+ con->read(buf,100);
+ }
+
+ More generally, you can use this with things other than sockets. For
+ example, the following statement will print "Hello world!" after 1000ms:
+ timeout t([](){ cout << "Hello world!" << endl; }, 1000);
+
+
+
+ THREAD SAFETY
+ All methods of this class are thread safe.
+ !*/
+
+ public:
+
+ template <
+ typename T
+ >
+ timeout (
+ T callback_function,
+ unsigned long ms_to_timeout
+ );
+ /*!
+ requires
+ - callback_function does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - if (this object isn't destructed in ms_to_timeout milliseconds) then
+ - callback_function() will be called in ms_to_timeout milliseconds.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ template <
+ typename T
+ >
+ timeout (
+ T& object,
+ void (T::*callback_function)(),
+ unsigned long ms_to_timeout
+ );
+ /*!
+ requires
+ - callback_function does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - if (this object isn't destructed in ms_to_timeout milliseconds) then
+ - (object.*callback_function)() will be called in ms_to_timeout
+ milliseconds.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ template <
+ typename T,
+ typename U
+ >
+ timeout (
+ T& object,
+ void (T::*callback_function)(U callback_function_argument),
+ unsigned long ms_to_timeout,
+ U callback_function_argument
+ );
+ /*!
+ requires
+ - callback_function does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - if (this object isn't destructed in ms_to_timeout milliseconds) then
+ - (object.*callback_function)(callback_function_argument) will be
+ called in ms_to_timeout milliseconds.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ template <
+ typename T
+ >
+ timeout (
+ T& object,
+ int (T::*callback_function)(),
+ unsigned long ms_to_timeout
+ );
+ /*!
+ requires
+ - callback_function does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - if (this object isn't destructed in ms_to_timeout milliseconds) then
+ - (object.*callback_function)() will be called in ms_to_timeout
+ milliseconds.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ template <
+ typename T,
+ typename U
+ >
+ timeout (
+ T& object,
+ int (T::*callback_function)(U callback_function_argument),
+ unsigned long ms_to_timeout,
+ U callback_function_argument
+ );
+ /*!
+ requires
+ - callback_function does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - if (this object isn't destructed in ms_to_timeout milliseconds) then
+ - (object.*callback_function)(callback_function_argument) will be
+ called in ms_to_timeout milliseconds.
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~timeout (
+ );
+ /*!
+ requires
+ - is not called from inside the callback_function given to the
+ constructor.
+ ensures
+ - any resources associated with *this have been released
+ - if (the callback_function hasn't been called yet) then
+ - the callback_function specified in the constructor will not be called
+ !*/
+
+ private:
+
+ // restricted functions
+ timeout(const timeout&); // copy constructor
+ timeout& operator=(const timeout&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_TIMEOUT_KERNEl_ABSTRACT_
+
+
diff --git a/ml/dlib/dlib/timer.h b/ml/dlib/dlib/timer.h
new file mode 100644
index 000000000..918078e4d
--- /dev/null
+++ b/ml/dlib/dlib/timer.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMEr_
+#define DLIB_TIMEr_
+
+#include "timer/timer.h"
+#include "timer/timer_heavy.h"
+
+#endif // DLIB_TIMEr_
+
diff --git a/ml/dlib/dlib/timer/timer.cpp b/ml/dlib/dlib/timer/timer.cpp
new file mode 100644
index 000000000..d12aa6f29
--- /dev/null
+++ b/ml/dlib/dlib/timer/timer.cpp
@@ -0,0 +1,235 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMER_cPPh_
+#define DLIB_TIMER_cPPh_
+
+#include "timer.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ timer_global_clock::
+ timer_global_clock(
+ ):
+ s(m),
+ shutdown(false),
+ running(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ timer_global_clock::
+ ~timer_global_clock()
+ {
+ // The only time this destructor is called is when
+ //
+ // a) the process terminates
+ // b) the dynamic library(.so/.dll) is unloaded (could be a part of a))
+ //
+ // in case of a)
+ // windows: the process termination is especially painful, since threads are killed
+ // before destructors of the process image .dll's are called.
+ // Thus, for the windows platform, there is no threads running, so the only thing
+ // to do here is just let the standard memberwise destructors run
+ // linux: it's ok to just signal shutdown and wait for the running thread, to exit
+ //
+ // in case of b)
+ // windows:
+ // if it's part of the termination process, a) applies
+ // if its part of user doing manual load_library/unload_library
+ // there is no (safe/robust)solution, but best practices are described here
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/dn633971.aspx
+ // to support such a clean shutdown, you are required to make a call prior to
+ // unload dll, that shutdown all the threads in the contained dll.
+ // This could be done in this module by providing a global_delete_clock()
+ //
+ // linux: the destructor for linux will do it's usual job regardless.
+ //
+
+ #ifndef _WIN32
+ m.lock();
+ shutdown = true;
+ s.signal();
+ m.unlock();
+ wait();
+ #endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void timer_global_clock::
+ add (
+ timer_base* r
+ )
+ {
+ if (r->in_global_clock == false)
+ {
+ // if the thread isn't running then start it up
+ if (!running)
+ {
+ start();
+ running = true;
+ }
+
+ uint64 t = ts.get_timestamp() + r->delay*1000;
+ tm.reset();
+ if (!tm.move_next() || t < tm.element().key())
+ {
+ // we need to make the thread adjust its next time to
+ // trigger if this new event occurrs sooner than the
+ // next event in tm
+ s.signal();
+ }
+ timer_base* rtemp = r;
+ uint64 ttemp = t;
+ tm.add(ttemp,rtemp);
+ r->next_time_to_run = t;
+ r->in_global_clock = true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void timer_global_clock::
+ remove (
+ timer_base* r
+ )
+ {
+ if (r->in_global_clock)
+ {
+ tm.position_enumerator(r->next_time_to_run-1);
+ do
+ {
+ if (tm.element().value() == r)
+ {
+ uint64 t;
+ timer_base* rtemp;
+ tm.remove_current_element(t,rtemp);
+ r->in_global_clock = false;
+ break;
+ }
+ } while (tm.move_next());
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void timer_global_clock::
+ adjust_delay (
+ timer_base* r,
+ unsigned long new_delay
+ )
+ {
+ if (r->in_global_clock)
+ {
+ remove(r);
+ // compute the new next_time_to_run and store it in t
+ uint64 t = r->next_time_to_run;
+ t -= r->delay*1000;
+ t += new_delay*1000;
+
+ tm.reset();
+ if (!tm.move_next() || t < tm.element().key())
+ {
+ // we need to make the thread adjust its next time to
+ // trigger if this new event occurrs sooner than the
+ // next event in tm
+ s.signal();
+ }
+
+ // set this incase add throws
+ r->running = false;
+ r->delay = new_delay;
+
+ timer_base* rtemp = r;
+ uint64 ttemp = t;
+ tm.add(ttemp,rtemp);
+ r->next_time_to_run = t;
+ r->in_global_clock = true;
+
+ // put this back now that we know add didn't throw
+ r->running = true;
+
+ }
+ else
+ {
+ r->delay = new_delay;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void timer_global_clock::
+ thread()
+ {
+ auto_mutex M(m);
+ while (!shutdown)
+ {
+ unsigned long delay = 100000;
+
+ tm.reset();
+ tm.move_next();
+ // loop and start all the action functions for timers that should have
+ // triggered.
+ while(tm.current_element_valid())
+ {
+ const uint64 cur_time = ts.get_timestamp();
+ uint64 t = tm.element().key();
+ // if the next event in tm is ready to trigger
+ if (t <= cur_time + 999)
+ {
+ // remove this event from the tm map
+ timer_base* r = tm.element().value();
+ timer_base* rtemp;
+ tm.remove_current_element(t,rtemp);
+ r->in_global_clock = false;
+
+ // if this timer is still "running" then start its action function
+ if (r->running)
+ {
+ r->restart();
+ }
+ }
+ else
+ {
+ // there aren't any more timers that should trigger so we compute
+ // the delay to the next timer event.
+ delay = static_cast<unsigned long>((t - cur_time)/1000);
+ break;
+ }
+ }
+
+ s.wait_or_timeout(delay);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::shared_ptr<timer_global_clock> get_global_clock()
+ {
+ static std::shared_ptr<timer_global_clock> d(new timer_global_clock);
+ return d;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // do this just to make sure get_global_clock() gets called at program startup
+ class timer_global_clock_helper
+ {
+ public:
+ timer_global_clock_helper()
+ {
+ get_global_clock();
+ }
+ };
+ static timer_global_clock_helper call_get_global_clock;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TIMER_cPPh_
+
diff --git a/ml/dlib/dlib/timer/timer.h b/ml/dlib/dlib/timer/timer.h
new file mode 100644
index 000000000..c8b6c0af6
--- /dev/null
+++ b/ml/dlib/dlib/timer/timer.h
@@ -0,0 +1,427 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMEr_Hh_
+#define DLIB_TIMEr_Hh_
+
+#include <memory>
+
+#include "../threads.h"
+#include "../algs.h"
+#include "../misc_api.h"
+#include "timer_abstract.h"
+#include "../uintn.h"
+#include "../binary_search_tree.h"
+#include "timer_heavy.h"
+
+namespace dlib
+{
+
+ struct timer_base : public threaded_object
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object contains the base members of the timer object.
+ It exists so that we can access them from outside any templated functions.
+ !*/
+
+ unsigned long delay;
+ // these are only modified by the global_clock
+ uint64 next_time_to_run;
+ timestamper ts;
+ bool running;
+ bool in_global_clock;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class timer_global_clock : private threaded_object
+ {
+ /*!
+ This object sets up a timer that triggers the action function
+ for timer objects that are tracked inside this object.
+ INITIAL VALUE
+ - shutdown == false
+ - running == false
+
+ CONVENTION
+ - if (shutdown) then
+ - thread() should terminate
+ - else (running) then
+ - thread() is running
+
+ - tm[time] == pointer to a timer_base object
+ !*/
+ typedef binary_search_tree<uint64,timer_base*,memory_manager<char>::kernel_2b>::kernel_2a_c time_map;
+ public:
+
+ ~timer_global_clock();
+
+ void add (
+ timer_base* r
+ );
+ /*!
+ requires
+ - m is locked
+ ensures
+ - starts the thread if it isn't already started
+ - adds r to tm
+ - #r->in_global_clock == true
+ - updates r->next_time_to_run appropriately according to
+ r->delay
+ !*/
+
+ void remove (
+ timer_base* r
+ );
+ /*!
+ requires
+ - m is locked
+ ensures
+ - if (r is in tm) then
+ - removes r from tm
+ - #r->in_global_clock == false
+ !*/
+
+ void adjust_delay (
+ timer_base* r,
+ unsigned long new_delay
+ );
+ /*!
+ requires
+ - m is locked
+ ensures
+ - #r->delay == new_delay
+ - if (r->in_global_clock) then
+ - the time to the next event will have been appropriately adjusted
+ !*/
+
+ mutex m;
+
+ friend std::shared_ptr<timer_global_clock> get_global_clock();
+
+ private:
+ timer_global_clock();
+
+ time_map tm;
+ signaler s;
+ bool shutdown;
+ bool running;
+ timestamper ts;
+
+ void thread();
+ /*!
+ ensures
+ - spawns timer tasks as is appropriate
+ !*/
+ };
+ std::shared_ptr<timer_global_clock> get_global_clock();
+ /*!
+ ensures
+ - returns the global instance of the timer_global_clock object
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ class timer : private timer_base
+ {
+ /*!
+ INITIAL VALUE
+ - running == false
+ - delay == 1000
+ - ao == a pointer to the action_object()
+ - af == a pointer to the action_function()
+ - in_global_clock == false
+ - next_time_to_run == 0
+ - gc == get_global_clock()
+
+ CONVENTION
+ - the mutex used to lock everything is gc->m
+ - running == is_running()
+ - delay == delay_time()
+ - *ao == action_object()
+ - af == action_function()
+ - if (!running) then
+ - in_global_clock == false
+ - else
+ - next_time_to_run == the next time this timer should run according
+ to the timestamper in the global_clock
+ !*/
+
+ public:
+
+ // These typedefs are here for backwards compatibility with previous versions of
+ // dlib.
+ typedef timer_heavy<T> kernel_1a;
+ typedef timer kernel_2a;
+
+ typedef void (T::*af_type)();
+
+ timer(
+ T& ao_,
+ af_type af_
+ );
+
+ virtual ~timer(
+ );
+
+ void clear(
+ );
+
+ af_type action_function (
+ ) const;
+
+ const T& action_object (
+ ) const;
+
+ T& action_object (
+ );
+
+ bool is_running (
+ ) const;
+
+ unsigned long delay_time (
+ ) const;
+
+ void set_delay_time (
+ unsigned long milliseconds
+ );
+
+ void start (
+ );
+
+ void stop (
+ );
+
+ void stop_and_wait (
+ );
+
+ private:
+
+ void thread (
+ );
+ /*!
+ ensures
+ - calls the action function
+ !*/
+
+ // data members
+ T& ao;
+ const af_type af;
+ std::shared_ptr<timer_global_clock> gc;
+
+ // restricted functions
+ timer(const timer&); // copy constructor
+ timer& operator=(const timer&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ timer<T>::
+ timer(
+ T& ao_,
+ af_type af_
+ ) :
+ ao(ao_),
+ af(af_),
+ gc(get_global_clock())
+ {
+ delay = 1000;
+ next_time_to_run = 0;
+ running = false;
+ in_global_clock = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ timer<T>::
+ ~timer(
+ )
+ {
+ clear();
+ wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ clear(
+ )
+ {
+ auto_mutex M(gc->m);
+ running = false;
+ gc->remove(this);
+ delay = 1000;
+ next_time_to_run = 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename timer<T>::af_type timer<T>::
+ action_function (
+ ) const
+ {
+ return af;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const T& timer<T>::
+ action_object (
+ ) const
+ {
+ return ao;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T& timer<T>::
+ action_object (
+ )
+ {
+ return ao;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool timer<T>::
+ is_running (
+ ) const
+ {
+ auto_mutex M(gc->m);
+ return running;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long timer<T>::
+ delay_time (
+ ) const
+ {
+ auto_mutex M(gc->m);
+ return delay;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ set_delay_time (
+ unsigned long milliseconds
+ )
+ {
+ auto_mutex M(gc->m);
+ gc->adjust_delay(this,milliseconds);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ start (
+ )
+ {
+ auto_mutex M(gc->m);
+ if (!running)
+ {
+ gc->add(this);
+ running = true;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ stop (
+ )
+ {
+ gc->m.lock();
+ running = false;
+ gc->remove(this);
+ gc->m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ thread (
+ )
+ {
+ // call the action function
+ (ao.*af)();
+ auto_mutex M(gc->m);
+ if (running)
+ {
+ gc->remove(this);
+ gc->add(this);
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer<T>::
+ stop_and_wait (
+ )
+ {
+ gc->m.lock();
+ running = false;
+ gc->remove(this);
+ gc->m.unlock();
+ wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "timer.cpp"
+#endif
+
+#endif // DLIB_TIMEr_Hh_
+
+
diff --git a/ml/dlib/dlib/timer/timer_abstract.h b/ml/dlib/dlib/timer/timer_abstract.h
new file mode 100644
index 000000000..180cd490d
--- /dev/null
+++ b/ml/dlib/dlib/timer/timer_abstract.h
@@ -0,0 +1,190 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TIMER_KERNEl_ABSTRACT_
+#ifdef DLIB_TIMER_KERNEl_ABSTRACT_
+
+#include "../threads.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class timer
+ {
+ /*!
+ INITIAL VALUE
+ is_running() == false
+ delay_time() == 1000
+ action_object() == The object that is passed into the constructor
+ action_function() == The member function pointer that is passed to
+ the constructor.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a timer that will call a given member function
+ (the action function) repeatedly at regular intervals and in its own
+ thread.
+
+ Note that the delay_time() is measured in milliseconds but you are not
+ guaranteed to have that level of resolution. The actual resolution
+ is implementation dependent.
+
+ THREAD SAFETY
+ All methods of this class are thread safe.
+ !*/
+
+ public:
+
+ typedef void (T::*af_type)();
+
+ timer (
+ T& ao,
+ af_type af
+ );
+ /*!
+ requires
+ - af does not throw
+ ensures
+ - does not block.
+ - #*this is properly initialized
+ - #action_object() == ao
+ - #action_function() == af
+ (af is a member function pointer to a member in the class T)
+ throws
+ - std::bad_alloc
+ - dlib::thread_error
+ !*/
+
+ virtual ~timer (
+ );
+ /*!
+ requires
+ - is not called from inside the action_function()
+ ensures
+ - any resources associated with *this have been released
+ - will not call the action_function() anymore.
+ - if (the action function is currently executing) then
+ - blocks until it finishes
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ - does not block
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ af_type action_function (
+ ) const;
+ /*!
+ ensures
+ - does not block.
+ - returns a pointer to the member function of action_object() that is
+ called by *this.
+ !*/
+
+ const T& action_object (
+ ) const;
+ /*!
+ ensures
+ - does not block.
+ - returns a const reference to the object used to call the member
+ function pointer action_function()
+ !*/
+
+ T& action_object (
+ );
+ /*!
+ ensures
+ - does not block.
+ - returns a non-const reference to the object used to call the member
+ function pointer action_function()
+ !*/
+
+ bool is_running (
+ ) const;
+ /*!
+ ensures
+ - does not block.
+ - if (*this is currently scheduled to call the action_function()) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unsigned long delay_time (
+ ) const;
+ /*!
+ ensures
+ - does not block.
+ - returns the amount of time, in milliseconds, that *this will wait between
+ the return of one call to the action_function() and the beginning of the
+ next call to the action_function().
+ !*/
+
+ void set_delay_time (
+ unsigned long milliseconds
+ );
+ /*!
+ ensures
+ - does not block.
+ - #delay_time() == milliseconds
+ throws
+ - std::bad_alloc or dlib::thread_error
+ If either of these exceptions are thrown then #is_running() == false
+ but otherwise this function succeeds
+ !*/
+
+ void start (
+ );
+ /*!
+ ensures
+ - does not block.
+ - if (is_running() == false) then
+ - #is_running() == true
+ - The action_function() will run in another thread.
+ - The first call to the action_function() will occur in roughly
+ delay_time() milliseconds.
+ - else
+ - this call to start() has no effect
+ throws
+ - dlib::thread_error or std::bad_alloc
+ If this exception is thrown then #is_running() == false but
+ otherwise this call to start() has no effect.
+ !*/
+
+ void stop (
+ );
+ /*!
+ ensures
+ - #is_running() == false
+ - does not block.
+ !*/
+
+ void stop_and_wait (
+ );
+ /*!
+ ensures
+ - #is_running() == false
+ - if (the action function is currently executing) then
+ - blocks until it finishes
+ !*/
+
+ private:
+
+ // restricted functions
+ timer(const timer<T>&); // copy constructor
+ timer<T>& operator=(const timer<T>&); // assignment operator
+
+ };
+
+}
+
+#endif // DLIB_TIMER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/timer/timer_heavy.h b/ml/dlib/dlib/timer/timer_heavy.h
new file mode 100644
index 000000000..693b91ad9
--- /dev/null
+++ b/ml/dlib/dlib/timer/timer_heavy.h
@@ -0,0 +1,392 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TIMER_KERNEl_1_
+#define DLIB_TIMER_KERNEl_1_
+
+#include "../threads.h"
+#include "../algs.h"
+#include "../misc_api.h"
+#include "timer_abstract.h"
+
+namespace dlib
+{
+
+ template <
+ typename T
+ >
+ class timer_heavy
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an implementation of the timer_abstract.h interface. It is very
+ simple and uses only one thread which is always alive in a timer_heavy.
+ The reason this object exists is for historical reasons. Originally, the
+ dlib::timer was a multi-implementation component and the timer_heavy was
+ its first implementation. It was superseded later by the more efficient
+ dlib::timer. However, timer_heavy is still around so that
+ dlib::timer::kernel_1a has something to refer to. This way, old client
+ code which somehow depends on the same thread always calling a timer action
+ function isn't going to be disrupted.
+
+
+ INITIAL VALUE
+ - running == false
+ - delay == 1000
+ - ao == a pointer to the action_object()
+ - af == a pointer to the action_function()
+ - m == a mutex that locks everything in this class
+ - s == a signaler for mutex m
+ - stop_running == false
+
+ CONVENTION
+ - running && !stop_running == is_running()
+ - delay == delay_time()
+ - *ao == action_object()
+ - af == action_function()
+
+ - if (running) then
+ - there is a thread running
+ - if (is_running()) then
+ - next_time_to_run == the time when the next execution of the action
+ function should occur. (the time is given by ts.get_timestamp())
+
+ - stop_running is used to tell the thread to quit. If it is
+ set to true then the thread should end.
+ !*/
+
+ public:
+
+ typedef void (T::*af_type)();
+
+ timer_heavy(
+ T& ao_,
+ af_type af_
+ );
+
+ virtual ~timer_heavy(
+ );
+
+ void clear(
+ );
+
+ af_type action_function (
+ ) const;
+
+ const T& action_object (
+ ) const;
+
+ T& action_object (
+ );
+
+ bool is_running (
+ ) const;
+
+ unsigned long delay_time (
+ ) const;
+
+ void set_delay_time (
+ unsigned long milliseconds
+ );
+
+ void start (
+ );
+
+ void stop (
+ );
+
+ void stop_and_wait (
+ );
+
+ private:
+
+ void thread (
+ );
+ /*!
+ requires
+ - is run in its own thread
+ ensures
+ - calls the action function for the given timer object in the manner
+ specified by timer_kernel_abstract.h
+ !*/
+
+ // data members
+ T& ao;
+ const af_type af;
+ unsigned long delay;
+ mutex m;
+ signaler s;
+
+ bool running;
+ bool stop_running;
+ timestamper ts;
+ uint64 next_time_to_run;
+
+ // restricted functions
+ timer_heavy(const timer_heavy<T>&); // copy constructor
+ timer_heavy<T>& operator=(const timer_heavy<T>&); // assignment operator
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ timer_heavy<T>::
+ timer_heavy(
+ T& ao_,
+ af_type af_
+ ) :
+ ao(ao_),
+ af(af_),
+ delay(1000),
+ s(m),
+ running(false),
+ stop_running(false)
+ {
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ timer_heavy<T>::
+ ~timer_heavy(
+ )
+ {
+ stop_and_wait();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ clear(
+ )
+ {
+ m.lock();
+ stop_running = true;
+ delay = 1000;
+ s.broadcast();
+ m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename timer_heavy<T>::af_type timer_heavy<T>::
+ action_function (
+ ) const
+ {
+ return af;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ const T& timer_heavy<T>::
+ action_object (
+ ) const
+ {
+ return ao;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ T& timer_heavy<T>::
+ action_object (
+ )
+ {
+ return ao;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ bool timer_heavy<T>::
+ is_running (
+ ) const
+ {
+ auto_mutex M(m);
+ return running && !stop_running;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long timer_heavy<T>::
+ delay_time (
+ ) const
+ {
+ auto_mutex M(m);
+ return delay;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ set_delay_time (
+ unsigned long milliseconds
+ )
+ {
+ m.lock();
+
+ // if (is_running()) then we should adjust next_time_to_run
+ if (running && !stop_running)
+ {
+ next_time_to_run -= delay*1000;
+ next_time_to_run += milliseconds*1000;
+ }
+
+ delay = milliseconds;
+ s.broadcast();
+ m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ start (
+ )
+ {
+ auto_mutex M(m);
+
+ // if (is_running() == false) then reset the countdown to the next call
+ // to the action_function()
+ if ( (running && !stop_running) == false)
+ next_time_to_run = ts.get_timestamp() + delay*1000;
+
+ stop_running = false;
+ if (running == false)
+ {
+ running = true;
+
+ // start the thread
+ if (create_new_thread<timer_heavy,&timer_heavy::thread>(*this) == false)
+ {
+ running = false;
+ throw dlib::thread_error("error creating new thread in timer_heavy::start");
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ stop (
+ )
+ {
+ m.lock();
+ stop_running = true;
+ s.broadcast();
+ m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ thread (
+ )
+ {
+ auto_mutex M(m);
+ unsigned long delay_remaining;
+ uint64 current_time = ts.get_timestamp();
+
+ if (current_time < next_time_to_run)
+ delay_remaining = static_cast<unsigned long>((next_time_to_run-current_time)/1000);
+ else
+ delay_remaining = 0;
+
+ while (stop_running == false)
+ {
+ if (delay_remaining > 0)
+ s.wait_or_timeout(delay_remaining);
+
+ if (stop_running)
+ break;
+
+ current_time = ts.get_timestamp();
+ if (current_time < next_time_to_run)
+ {
+ // then we woke up too early so we should keep waiting
+ delay_remaining = static_cast<unsigned long>((next_time_to_run-current_time)/1000);
+
+ // rounding might make this be zero anyway. So if it is
+ // then we will say we have hit the next time to run.
+ if (delay_remaining > 0)
+ continue;
+ }
+
+ // call the action function
+ m.unlock();
+ (ao.*af)();
+ m.lock();
+
+ current_time = ts.get_timestamp();
+ next_time_to_run = current_time + delay*1000;
+ delay_remaining = delay;
+ }
+ running = false;
+ stop_running = false;
+ s.broadcast();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ void timer_heavy<T>::
+ stop_and_wait (
+ )
+ {
+ m.lock();
+ if (running)
+ {
+ // make the running thread terminate
+ stop_running = true;
+
+ s.broadcast();
+ // wait for the thread to quit
+ while (running)
+ s.wait();
+ }
+ m.unlock();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TIMER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/timing.h b/ml/dlib/dlib/timing.h
new file mode 100644
index 000000000..2d5116ade
--- /dev/null
+++ b/ml/dlib/dlib/timing.h
@@ -0,0 +1,196 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TImING_Hh_
+#define DLIB_TImING_Hh_
+
+#include <chrono>
+#include <atomic>
+#include <cstring>
+#include "string.h"
+
+#include <iostream>
+
+// ----------------------------------------------------------------------------------------
+
+/*!A timing
+
+ This set of functions is useful for determining how much time is spent
+ executing blocks of code. Consider the following example:
+
+ int main()
+ {
+ using namespace dlib::timing;
+ for (int i = 0; i < 10; ++i)
+ {
+ // timing block #1
+ start(1,"block #1");
+ dlib::sleep(500);
+ stop(1);
+
+ // timing block #2
+ start(2,"block #2");
+ dlib::sleep(1000);
+ stop(2);
+ }
+
+ print();
+ }
+
+ This program would output:
+ Timing report:
+ block #1: 5.0 seconds
+ block #2: 10.0 seconds
+
+ So we spent 5 seconds in block #1 and 10 seconds in block #2
+
+
+
+ Additionally, note that you can use an RAII style timing block object. For
+ example, if we wanted to find out how much time we spent in a loop a convenient
+ way to do this would be as follows:
+
+ int main()
+ {
+ using namespace dlib::timing;
+ for (int i = 0; i < 10; ++i)
+ {
+ block tb(1, "main loop");
+
+ dlib::sleep(1500);
+ }
+
+ print();
+ }
+
+ This program would output:
+ Timing report:
+ block main loop: 15.0 seconds
+
+!*/
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+ namespace timing
+ {
+ const int TIME_SLOTS = 500;
+ const int NAME_LENGTH = 40;
+
+ inline std::atomic<uint64_t>* time_buf()
+ {
+ static std::atomic<uint64_t> buf[TIME_SLOTS];
+ return buf;
+ }
+
+ inline char* name_buf(int i, const char* name)
+ {
+ static char buf[TIME_SLOTS][NAME_LENGTH] = {{0}};
+ // if this name buffer is empty then copy name into it
+ if (buf[i][0] == '\0')
+ {
+ using namespace std;
+ strncpy(buf[i], name, NAME_LENGTH-1);
+ buf[i][NAME_LENGTH-1] = '\0';
+ }
+ // return the name buffer
+ return buf[i];
+ }
+
+ inline uint64_t ts()
+ {
+ using namespace std::chrono;
+ return duration_cast<duration<double,std::nano>>(high_resolution_clock::now().time_since_epoch()).count();
+ }
+
+ inline void start(int i)
+ {
+ time_buf()[i] -= ts();
+ }
+
+ inline void start(int i, const char* name)
+ {
+ time_buf()[i] -= ts();
+ name_buf(i,name);
+ }
+
+ inline void stop(int i)
+ {
+ time_buf()[i] += ts();
+ }
+
+ inline void print()
+ {
+ using namespace std;
+ cout << "Timing report: " << endl;
+
+ // figure out how long the longest name is going to be.
+ unsigned long max_name_length = 0;
+ for (int i = 0; i < TIME_SLOTS; ++i)
+ {
+ string name;
+ // Check if the name buffer is empty. Use the name it contains if it isn't.
+ if (name_buf(i,"")[0] != '\0')
+ name = cast_to_string(i) + ": " + name_buf(i,"");
+ else
+ name = cast_to_string(i);
+ max_name_length = std::max<unsigned long>(max_name_length, name.size());
+ }
+
+ for (int i = 0; i < TIME_SLOTS; ++i)
+ {
+ if (time_buf()[i] != 0)
+ {
+ double time = time_buf()[i]/1000.0/1000.0;
+ string name;
+ // Check if the name buffer is empty. Use the name it contains if it isn't.
+ if (name_buf(i,"")[0] != '\0')
+ name = cast_to_string(i) + ": " + name_buf(i,"");
+ else
+ name = cast_to_string(i);
+
+ // make sure the name is always the same length. Do so by padding with spaces
+ if (name.size() < max_name_length)
+ name += string(max_name_length-name.size(),' ');
+
+ if (time < 1000)
+ cout << " " << name << ": " << time << " milliseconds" << endl;
+ else if (time < 1000*60)
+ cout << " " << name << ": " << time/1000.0 << " seconds" << endl;
+ else if (time < 1000*60*60)
+ cout << " " << name << ": " << time/1000.0/60.0 << " minutes" << endl;
+ else
+ cout << " " << name << ": " << time/1000.0/60.0/60.0 << " hours" << endl;
+ }
+ }
+ }
+
+ inline void clear()
+ {
+ for (int i = 0; i < TIME_SLOTS; ++i)
+ {
+ // clear timing buffer
+ time_buf()[i] = 0;
+ // clear name buffer
+ name_buf(i,"")[0] = '\0';
+ }
+ }
+
+ struct block
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an RAII tool for calling start() and stop()
+ !*/
+
+ block(int i):idx(i) {start(idx);}
+ block(int i, const char* str):idx(i) {start(idx,str);}
+ ~block() { stop(idx); }
+ const int idx;
+ };
+ }
+}
+
+
+#endif // DLIB_TImING_Hh_
+
diff --git a/ml/dlib/dlib/tokenizer.h b/ml/dlib/dlib/tokenizer.h
new file mode 100644
index 000000000..01b6fcf83
--- /dev/null
+++ b/ml/dlib/dlib/tokenizer.h
@@ -0,0 +1,33 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TOKENIZEr_
+#define DLIB_TOKENIZEr_
+
+#include "tokenizer/tokenizer_kernel_1.h"
+#include "tokenizer/tokenizer_kernel_c.h"
+
+
+namespace dlib
+{
+
+ class tokenizer
+ {
+ tokenizer() {}
+
+
+ public:
+
+ //----------- kernels ---------------
+
+ // kernel_1a
+ typedef tokenizer_kernel_1
+ kernel_1a;
+ typedef tokenizer_kernel_c<kernel_1a>
+ kernel_1a_c;
+
+
+ };
+}
+
+#endif // DLIB_TOKENIZEr_
+
diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp
new file mode 100644
index 000000000..daa83184c
--- /dev/null
+++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp
@@ -0,0 +1,295 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TOKENIZER_KERNEL_1_CPp_
+#define DLIB_TOKENIZER_KERNEL_1_CPp_
+#include "tokenizer_kernel_1.h"
+
+#include <iostream>
+#include <cstdio>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ tokenizer_kernel_1::
+ tokenizer_kernel_1 (
+ ) :
+ headset(0),
+ bodyset(0),
+ have_peeked(false)
+ {
+ try
+ {
+ headset = new bool[UCHAR_MAX];
+ bodyset = new bool[UCHAR_MAX];
+
+ clear();
+ }
+ catch (...)
+ {
+ if (headset) delete [] headset;
+ if (bodyset) delete [] bodyset;
+ throw;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ tokenizer_kernel_1::
+ ~tokenizer_kernel_1 (
+ )
+ {
+ delete [] bodyset;
+ delete [] headset;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tokenizer_kernel_1::
+ clear(
+ )
+ {
+ using namespace std;
+
+ in = 0;
+ streambuf = 0;
+ have_peeked = false;
+
+ head = "_" + lowercase_letters() + uppercase_letters();
+ body = "_" + lowercase_letters() + uppercase_letters() + numbers();
+
+ for (unsigned long i = 0; i < UCHAR_MAX; ++i)
+ {
+ headset[i] = false;
+ bodyset[i] = false;
+ }
+
+ for (string::size_type i = 0; i < head.size(); ++i)
+ headset[static_cast<unsigned char>(head[i])] = true;
+ for (string::size_type i = 0; i < body.size(); ++i)
+ bodyset[static_cast<unsigned char>(body[i])] = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tokenizer_kernel_1::
+ set_stream (
+ std::istream& in_
+ )
+ {
+ in = &in_;
+ streambuf = in_.rdbuf();
+ have_peeked = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool tokenizer_kernel_1::
+ stream_is_set (
+ ) const
+ {
+ return (in != 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& tokenizer_kernel_1::
+ get_stream (
+ ) const
+ {
+ return *in;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tokenizer_kernel_1::
+ get_token (
+ int& type,
+ std::string& token
+ )
+ {
+ if (!have_peeked)
+ {
+ std::streambuf::int_type ch;
+ ch = streambuf->sbumpc();
+
+ switch (ch)
+ {
+ case EOF:
+ type = END_OF_FILE;
+ token.clear();
+ return;
+
+ case '\n':
+ type = END_OF_LINE;
+ token = "\n";
+ return;
+
+ case '\r':
+ case ' ':
+ case '\t':
+ type = WHITE_SPACE;
+ token = static_cast<char>(ch);
+ ch = streambuf->sgetc();
+ while ((ch == ' ' || ch == '\t' || ch == '\r') && ch != EOF)
+ {
+ token += static_cast<char>(ch);
+ ch = streambuf->snextc();
+ }
+ return;
+
+ default:
+ if (headset[static_cast<unsigned char>(ch)])
+ {
+ type = IDENTIFIER;
+ token = static_cast<char>(ch);
+ ch = streambuf->sgetc();
+ while ( bodyset[static_cast<unsigned char>(ch)] && ch != EOF )
+ {
+ token += static_cast<char>(ch);
+ ch = streambuf->snextc();
+ }
+ }
+ else if ('0' <= ch && ch <= '9')
+ {
+ type = NUMBER;
+ token = static_cast<char>(ch);
+ ch = streambuf->sgetc();
+ while (('0' <= ch && ch <= '9') && ch != EOF)
+ {
+ token += static_cast<char>(ch);
+ ch = streambuf->snextc();
+ }
+ }
+ else
+ {
+ type = CHAR;
+ token = static_cast<char>(ch);
+ }
+ return;
+ } // switch (ch)
+ }
+
+ // if we get this far it means we have peeked so we should
+ // return the peek data.
+ type = next_type;
+ token = next_token;
+ have_peeked = false;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int tokenizer_kernel_1::
+ peek_type (
+ ) const
+ {
+ const_cast<tokenizer_kernel_1*>(this)->get_token(next_type,next_token);
+ have_peeked = true;
+ return next_type;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string& tokenizer_kernel_1::
+ peek_token (
+ ) const
+ {
+ const_cast<tokenizer_kernel_1*>(this)->get_token(next_type,next_token);
+ have_peeked = true;
+ return next_token;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tokenizer_kernel_1::
+ swap (
+ tokenizer_kernel_1& item
+ )
+ {
+ exchange(in,item.in);
+ exchange(streambuf,item.streambuf);
+ exchange(head,item.head);
+ exchange(body,item.body);
+ exchange(bodyset,item.bodyset);
+ exchange(headset,item.headset);
+ exchange(have_peeked,item.have_peeked);
+ exchange(next_type,item.next_type);
+ exchange(next_token,item.next_token);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void tokenizer_kernel_1::
+ set_identifier_token (
+ const std::string& head_,
+ const std::string& body_
+ )
+ {
+ using namespace std;
+
+ head = head_;
+ body = body_;
+
+ for (unsigned long i = 0; i < UCHAR_MAX; ++i)
+ {
+ headset[i] = false;
+ bodyset[i] = false;
+ }
+
+ for (string::size_type i = 0; i < head.size(); ++i)
+ headset[static_cast<unsigned char>(head[i])] = true;
+ for (string::size_type i = 0; i < body.size(); ++i)
+ bodyset[static_cast<unsigned char>(body[i])] = true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tokenizer_kernel_1::
+ get_identifier_head (
+ ) const
+ {
+ return head;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tokenizer_kernel_1::
+ get_identifier_body (
+ ) const
+ {
+ return body;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tokenizer_kernel_1::
+ lowercase_letters (
+ ) const
+ {
+ return std::string("abcdefghijklmnopqrstuvwxyz");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tokenizer_kernel_1::
+ uppercase_letters (
+ ) const
+ {
+ return std::string("ABCDEFGHIJKLMNOPQRSTUVWXYZ");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string tokenizer_kernel_1::
+ numbers (
+ ) const
+ {
+ return std::string("0123456789");
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+#endif // DLIB_TOKENIZER_KERNEL_1_CPp_
+
diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h
new file mode 100644
index 000000000..d67ae278f
--- /dev/null
+++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h
@@ -0,0 +1,155 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TOKENIZER_KERNEl_1_
+#define DLIB_TOKENIZER_KERNEl_1_
+
+#include <string>
+#include <iosfwd>
+#include <climits>
+#include "../algs.h"
+#include "tokenizer_kernel_abstract.h"
+
+namespace dlib
+{
+
+ class tokenizer_kernel_1
+ {
+ /*!
+ INITIAL VALUE
+ - in == 0
+ - streambuf == 0
+ - have_peeked == false
+ - head == "_" + lowercase_letters() + uppercase_letters()
+ - body == "_" + lowercase_letters() + uppercase_letters() + numbers()
+ - headset == pointer to an array of UCHAR_MAX bools and set according
+ to the CONVENTION.
+ - bodyset == pointer to an array of UCHAR_MAX bools and set according
+ to the CONVENTION.
+
+ CONVENTION
+ - if (stream_is_set()) then
+ - get_stream() == *in
+ - streambuf == in->rdbuf()
+ - else
+ - in == 0
+ - streambuf == 0
+
+ - body == get_identifier_body()
+ - head == get_identifier_head()
+
+ - if (the char x appears in head) then
+ - headset[static_cast<unsigned char>(x)] == true
+ - else
+ - headset[static_cast<unsigned char>(x)] == false
+
+ - if (the char x appears in body) then
+ - bodyset[static_cast<unsigned char>(x)] == true
+ - else
+ - bodyset[static_cast<unsigned char>(x)] == false
+
+ - if (have_peeked) then
+ - next_token == the next token to be returned from get_token()
+ - next_type == the type of token in peek_token
+ !*/
+
+ public:
+
+ // The name of this enum is irrelevant but on some compilers (gcc on MAC OS X) not having it named
+ // causes an error for whatever reason
+ enum some_random_name
+ {
+ END_OF_LINE,
+ END_OF_FILE,
+ IDENTIFIER,
+ CHAR,
+ NUMBER,
+ WHITE_SPACE
+ };
+
+ tokenizer_kernel_1 (
+ );
+
+ virtual ~tokenizer_kernel_1 (
+ );
+
+ void clear(
+ );
+
+ void set_stream (
+ std::istream& in
+ );
+
+ bool stream_is_set (
+ ) const;
+
+ std::istream& get_stream (
+ ) const;
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+
+ void swap (
+ tokenizer_kernel_1& item
+ );
+
+ void set_identifier_token (
+ const std::string& head,
+ const std::string& body
+ );
+
+ int peek_type (
+ ) const;
+
+ const std::string& peek_token (
+ ) const;
+
+ const std::string get_identifier_head (
+ ) const;
+
+ const std::string get_identifier_body (
+ ) const;
+
+ const std::string lowercase_letters (
+ ) const;
+
+ const std::string uppercase_letters (
+ ) const;
+
+ const std::string numbers (
+ ) const;
+
+ private:
+
+ // restricted functions
+ tokenizer_kernel_1(const tokenizer_kernel_1&); // copy constructor
+ tokenizer_kernel_1& operator=(const tokenizer_kernel_1&); // assignment operator
+
+
+ // data members
+ std::istream* in;
+ std::streambuf* streambuf;
+ std::string head;
+ std::string body;
+ bool* headset;
+ bool* bodyset;
+
+ mutable std::string next_token;
+ mutable int next_type;
+ mutable bool have_peeked;
+ };
+
+ inline void swap (
+ tokenizer_kernel_1& a,
+ tokenizer_kernel_1& b
+ ) { a.swap(b); }
+
+}
+
+#ifdef NO_MAKEFILE
+#include "tokenizer_kernel_1.cpp"
+#endif
+
+#endif // DLIB_TOKENIZER_KERNEl_1
+
diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h
new file mode 100644
index 000000000..f534b8f7f
--- /dev/null
+++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h
@@ -0,0 +1,289 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TOKENIZER_KERNEl_ABSTRACT_
+#ifdef DLIB_TOKENIZER_KERNEl_ABSTRACT_
+
+#include <string>
+#include <ioswfd>
+
+namespace dlib
+{
+
+ class tokenizer
+ {
+ /*!
+ INITIAL VALUE
+ stream_is_set() == false
+ get_identifier_head() == "_" + lowercase_letters() + uppercase_letters()
+ get_identifier_body() == "_" + lowercase_letters() + uppercase_letters() +
+ numbers()
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple tokenizer for textual data.
+
+ BUFFERING
+ This object is allowed to buffer data from the input stream.
+ Thus if you clear it or switch streams (via calling set_stream())
+ any buffered data will be lost.
+
+ TOKENS
+ When picking out tokens the tokenizer will always extract the
+ longest token it can. For example, if faced with the string
+ "555" it will consider the three 5s to be a single NUMBER
+ token not three smaller NUMBER tokens.
+
+ Also note that no characters in the input stream are discarded.
+ They will all be returned in the text of some token.
+ Additionally, each character will never be returned more than once.
+ This means that if you concatenated all returned tokens it would exactly
+ reproduce the contents of the input stream.
+
+ The tokens are defined as follows:
+
+ END_OF_LINE
+ This is a single character token and is always the '\n'
+ character.
+
+ END_OF_FILE
+ This token represents the end of file. It doesn't have any
+ actual characters associated with it.
+
+ IDENTIFIER
+ This is a multi-character token. It is defined as a string that
+ begins with a character from get_identifier_head() and is
+ followed by any number of characters from get_identifier_body().
+
+ NUMBER
+ This is a multi-character token. It is defined as a sequence of
+ numbers.
+
+ WHITE_SPACE
+ This is a multi character token. It is defined as a sequence of
+ one or more spaces, carrage returns, and tabs. I.e. It is
+ composed of characters from the following string " \r\t".
+
+ CHAR
+ This is a single character token. It matches anything that isn't
+ part of one of the above tokens.
+ !*/
+
+ public:
+
+ enum
+ {
+ END_OF_LINE,
+ END_OF_FILE,
+ IDENTIFIER,
+ CHAR,
+ NUMBER,
+ WHITE_SPACE
+ };
+
+ tokenizer (
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~tokenizer (
+ );
+ /*!
+ ensures
+ - any resources associated with *this have been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ If this exception is thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ void set_stream (
+ std::istream& in
+ );
+ /*!
+ ensures
+ - #*this will read data from in and tokenize it
+ - #stream_is_set() == true
+ - #get_stream() == in
+ !*/
+
+ bool stream_is_set (
+ ) const;
+ /*!
+ ensures
+ - returns true if a stream has been associated with *this by calling
+ set_stream()
+ !*/
+
+ std::istream& get_stream (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns a reference to the istream object that *this is reading
+ from.
+ !*/
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - #token == the next token from the input stream get_stream()
+ - #type == the type of the token in #token
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this but the values of #type and #token will be
+ undefined. Additionally, some characters may have been read
+ from the stream get_stream() and lost.
+ !*/
+
+ int peek_type (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns the type of the token that will be returned from
+ the next call to get_token()
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this. However, some characters may have been
+ read from the stream get_stream() and lost.
+ !*/
+
+ const std::string& peek_token (
+ ) const;
+ /*!
+ requires
+ - stream_is_set() == true
+ ensures
+ - returns the text of the token that will be returned from
+ the next call to get_token()
+ throws
+ - bad_alloc
+ If this exception is thrown then the call to this function will
+ have no effect on *this. However, some characters may have been
+ read from the stream get_stream() and lost.
+ !*/
+
+ void set_identifier_token (
+ const std::string& head,
+ const std::string& body
+ );
+ /*!
+ requires
+ - head.find_first_of(" \r\t\n0123456789") == std::string::npos
+ (i.e. head doesn't contain any characters from the string
+ " \r\t\n0123456789").
+ - body.find_frst_of(" \r\t\n") == std::string::npos
+ (i.e. body doesn't contain any characters from the string " \r\t\n").
+ ensures
+ - #get_identifier_head() == head
+ - #get_identifier_body() == body
+ throws
+ - std::bad_alloc
+ If this exception is thrown then #*this is unusable
+ until clear() is called and succeeds.
+ !*/
+
+ const std::string get_identifier_head (
+ ) const;
+ /*!
+ ensures
+ - returns a string containing the characters that can be the start
+ of an IDENTIFIER token.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function
+ has no effect.
+ !*/
+
+ const std::string get_identifier_body (
+ ) const;
+ /*!
+ ensures
+ - returns a string containing the characters that can appear in the
+ body of an IDENTIFIER token.
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function
+ has no effect.
+ !*/
+
+ const std::string lowercase_letters (
+ ) const;
+ /*!
+ ensures
+ - returns "abcdefghijklmnopqrstuvwxyz"
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function
+ has no effect.
+ !*/
+
+ const std::string uppercase_letters (
+ ) const;
+ /*!
+ ensures
+ - returns "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function
+ has no effect.
+ !*/
+
+ const std::string numbers (
+ ) const;
+ /*!
+ ensures
+ - returns "0123456789"
+ throws
+ - std::bad_alloc
+ If this exception is thrown then the call to this function
+ has no effect.
+ !*/
+
+ void swap (
+ tokenizer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ private:
+
+ // restricted functions
+ tokenizer(const tokenizer&); // copy constructor
+ tokenizer& operator=(const tokenizer&); // assignment operator
+
+ };
+
+ inline void swap (
+ tokenizer& a,
+ tokenizer& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+}
+
+#endif // DLIB_TOKENIZER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h
new file mode 100644
index 000000000..f9604809d
--- /dev/null
+++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h
@@ -0,0 +1,167 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TOKENIZER_KERNEl_C_
+#define DLIB_TOKENIZER_KERNEl_C_
+
+#include "tokenizer_kernel_abstract.h"
+#include "../assert.h"
+#include <string>
+#include <iostream>
+
+namespace dlib
+{
+
+ template <
+ typename tokenizer
+ >
+ class tokenizer_kernel_c : public tokenizer
+ {
+
+ public:
+ std::istream& get_stream (
+ ) const;
+
+ void get_token (
+ int& type,
+ std::string& token
+ );
+
+ void set_identifier_token (
+ const std::string& head,
+ const std::string& body
+ );
+
+ int peek_type (
+ ) const;
+
+ const std::string& peek_token (
+ ) const;
+ };
+
+ template <
+ typename tokenizer
+ >
+ inline void swap (
+ tokenizer_kernel_c<tokenizer>& a,
+ tokenizer_kernel_c<tokenizer>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ void tokenizer_kernel_c<tokenizer>::
+ set_identifier_token (
+ const std::string& head,
+ const std::string& body
+ )
+ {
+ using namespace std;
+ // make sure requires clause is not broken
+ DLIB_CASSERT( head.find_first_of(" \r\t\n0123456789") == string::npos &&
+ body.find_first_of(" \r\t\n") == string::npos ,
+ "\tvoid tokenizer::set_identifier_token()"
+ << "\n\tyou can't define the IDENTIFIER token this way."
+ << "\n\thead: " << head
+ << "\n\tbody: " << body
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ tokenizer::set_identifier_token(head,body);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ std::istream& tokenizer_kernel_c<tokenizer>::
+ get_stream (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tstd::istream& tokenizer::get_stream()"
+ << "\n\tyou must set a stream for this object before you can get it"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::get_stream();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ int tokenizer_kernel_c<tokenizer>::
+ peek_type (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tint tokenizer::peek_type()"
+ << "\n\tyou must set a stream for this object before you peek at what it contains"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::peek_type();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ const std::string& tokenizer_kernel_c<tokenizer>::
+ peek_token (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tint tokenizer::peek_token()"
+ << "\n\tyou must set a stream for this object before you peek at what it contains"
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ return tokenizer::peek_token();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename tokenizer
+ >
+ void tokenizer_kernel_c<tokenizer>::
+ get_token (
+ int& type,
+ std::string& token
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_CASSERT( this->stream_is_set() == true,
+ "\tvoid tokenizer::get_token()"
+ << "\n\tyou must set a stream for this object before you can get tokens from it."
+ << "\n\tthis: " << this
+ );
+
+ // call the real function
+ tokenizer::get_token(type,token);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TOKENIZER_KERNEl_C_
+
+
diff --git a/ml/dlib/dlib/travis/build-and-test.sh b/ml/dlib/dlib/travis/build-and-test.sh
new file mode 100755
index 000000000..4ee74e36b
--- /dev/null
+++ b/ml/dlib/dlib/travis/build-and-test.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+# Exit if anything fails.
+set -eux
+
+# execute the contents of MATRIX_EVAL if it's set
+if [[ -v MATRIX_EVAL ]]; then
+ eval "${MATRIX_EVAL}"
+fi
+
+# build dlib and tests
+if [ "$VARIANT" = "test" ]; then
+ mkdir build
+ cd build
+ cmake ../dlib/test
+ cmake --build . --target dtest -- -j 2
+ ./dtest --runall
+fi
+
+if [ "$VARIANT" = "dlib_all_source_cpp" ]; then
+ mkdir build
+ cd build
+ cmake ../dlib/test
+ cmake --build . --target dlib_all_source_cpp -- -j 2
+fi
+
+if [ "$VARIANT" = "tools" ]; then
+ mkdir build
+ cd build
+ cmake ../dlib/test/tools
+ cmake --build . -- -j 2
+fi
+
+if [ "$VARIANT" = "examples" ]; then
+ mkdir build
+ cd build
+ cmake ../examples
+ cmake --build . -- -j 1
+fi
+
+if [ "$VARIANT" = "python-api" ]; then
+ python setup.py test --clean
+ pip uninstall numpy -y
+ python setup.py test --clean
+fi
+
diff --git a/ml/dlib/dlib/tuple.h b/ml/dlib/dlib/tuple.h
new file mode 100644
index 000000000..2cecddeef
--- /dev/null
+++ b/ml/dlib/dlib/tuple.h
@@ -0,0 +1,10 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TUPLe_TOP_
+#define DLIB_TUPLe_TOP_
+
+#include "tuple/tuple.h"
+
+#endif // DLIB_TUPLe_TOPh_
+
+
diff --git a/ml/dlib/dlib/tuple/tuple.h b/ml/dlib/dlib/tuple/tuple.h
new file mode 100644
index 000000000..cac2b4ade
--- /dev/null
+++ b/ml/dlib/dlib/tuple/tuple.h
@@ -0,0 +1,410 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TUPLe_H_
+#define DLIB_TUPLe_H_
+
+#include "../enable_if.h"
+#include "../algs.h"
+#include "../serialize.h"
+#include "tuple_abstract.h"
+
+// ----------------------------------------------------------------------------------------
+
+#define DLIB_TUPLE_GLOBAL_HELPERS(N) \
+ DLIB_TUPLE_GH(N##0) DLIB_TUPLE_GH(N##1) DLIB_TUPLE_GH(N##2) DLIB_TUPLE_GH(N##3) \
+ DLIB_TUPLE_GH(N##4) DLIB_TUPLE_GH(N##5) DLIB_TUPLE_GH(N##6) DLIB_TUPLE_GH(N##7)
+
+#define DLIB_TUPLE_GH(N) DLIB_TUPLE_GET_INDEX(N) DLIB_TUPLE_GET_ITEM(N) DLIB_TUPLE_GET_HELPER_STRUCT(N)
+
+#define DLIB_TUPLE_MEMBER_GET(N) \
+ DLIB_TUPLE_MG(N##0) DLIB_TUPLE_MG(N##1) DLIB_TUPLE_MG(N##2) DLIB_TUPLE_MG(N##3) \
+ DLIB_TUPLE_MG(N##4) DLIB_TUPLE_MG(N##5) DLIB_TUPLE_MG(N##6) DLIB_TUPLE_MG(N##7)
+
+#define DLIB_TUPLE_GET_INDEX(N) \
+ template <class Q, class T> const typename enable_if<is_same_type<typename T::type##N,Q>, long>::type get_index (const T&) {return N;}
+
+#define DLIB_TUPLE_GET_ITEM(N) \
+ template <class Q,class T> const typename enable_if<is_same_type<typename T::type##N,Q>,Q>::type& get_item_const (const T& t) {return t.v##N;}\
+ template <class Q,class T> typename enable_if<is_same_type<typename T::type##N,Q>,Q>::type& get_item ( T& t) {return t.v##N;}
+
+
+#define DLIB_TUPLE_GET_HELPER_STRUCT(N) \
+ template <class T> struct get_helper<N,T> \
+ { \
+ typedef typename T::type##N type; \
+ static const type& get(const T& t) { return t.v##N; } \
+ static type& get( T& t) { return t.v##N; } \
+ };
+
+#define DLIB_TUPLE_TEMPLATE_LIST(N) \
+ class T##N##0 = null_type, class T##N##1 = null_type, class T##N##2 = null_type, class T##N##3 = null_type, \
+ class T##N##4 = null_type, class T##N##5 = null_type, class T##N##6 = null_type, class T##N##7 = null_type
+
+#define DLIB_TUPLE_VARIABLE_LIST(N) \
+ T##N##0 v##N##0; T##N##1 v##N##1; T##N##2 v##N##2; T##N##3 v##N##3; \
+ T##N##4 v##N##4; T##N##5 v##N##5; T##N##6 v##N##6; T##N##7 v##N##7; \
+ typedef T##N##0 type##N##0; typedef T##N##1 type##N##1; typedef T##N##2 type##N##2; \
+ typedef T##N##3 type##N##3; typedef T##N##4 type##N##4; typedef T##N##5 type##N##5; \
+ typedef T##N##6 type##N##6; typedef T##N##7 type##N##7;
+
+// ----------------------------------------------------------------------------------------
+
+namespace dlib
+{
+
+ struct null_type{};
+
+ // provide default serialization for the null_type
+ inline void serialize (
+ const null_type& ,
+ std::ostream&
+ ){}
+ inline void deserialize (
+ null_type& ,
+ std::istream&
+ ){}
+
+// ----------------------------------------------------------------------------------------
+
+ namespace tuple_helpers
+ {
+ template <long idx, class T> struct get_helper;
+
+ // use these preprocessor macros to declare all the global stuff used by the
+ // tuple member functions.
+ DLIB_TUPLE_GLOBAL_HELPERS(0)
+ DLIB_TUPLE_GLOBAL_HELPERS(01)
+ DLIB_TUPLE_GLOBAL_HELPERS(02)
+ DLIB_TUPLE_GLOBAL_HELPERS(03)
+
+ // ------------------------------------------------------------------------------------
+
+ // use templates to recursively enumerate everything in the tuple that isn't a null_type
+ template <
+ typename T,
+ typename F,
+ long i = 0,
+ typename enabled = void
+ >
+ struct for_each
+ {
+ static void go(
+ T& a,
+ F& funct
+ )
+ {
+ funct(a.template get<i>());
+ for_each<T,F,i+1>::go(a,funct);
+ }
+
+ static bool go(
+ T& a,
+ F& funct,
+ long idx
+ )
+ /*!
+ ensures
+ - returns true if the function was applied to the given index
+ - returns false if the index is invalid so the function wasn't
+ applied to anything
+ !*/
+ {
+ if (idx == i)
+ {
+ funct(a.template get<i>());
+ return true;
+ }
+ else
+ {
+ return for_each<T,F,i+1>::go(a,funct,idx);
+ }
+ }
+ };
+
+ template <bool v1, bool v2> struct template_or { const static bool value = true; };
+ template <> struct template_or<false,false> { const static bool value = false; };
+
+ // the base case of the recursion
+ template <
+ typename T,
+ typename F,
+ long i
+ >
+ struct for_each<T,F,i,typename enable_if<template_or<i == T::max_fields , is_same_type<null_type,typename T::template get_type<i>::type >::value> >::type >
+ {
+ static void go( T&, F& ) { }
+ static bool go( T&, F&, long ) { return false; }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ // use templates to recursively enumerate everything in the tuple that isn't a null_type
+ template <
+ typename T,
+ long i = 0,
+ typename enabled = void
+ >
+ struct tuple_swap
+ {
+ static void go(
+ T& a,
+ T& b
+ )
+ {
+ exchange(a.template get<i>(), b.template get<i>());
+ tuple_swap<T,i+1>::go(a,b);
+ }
+ };
+
+ template <typename T, long i>
+ struct at_base_case
+ {
+
+ };
+
+ // the base case of the recursion
+ template <
+ typename T,
+ long i
+ >
+ struct tuple_swap<T,i,typename enable_if<template_or<i == T::max_fields, is_same_type<null_type,typename T::template get_type<i>::type >::value > >::type >
+ { static void go( T&, T& ) { } };
+
+ // ------------------------------------------------------------------------------------
+
+ struct tuple_serialize
+ {
+ tuple_serialize (std::ostream& out_) : out(out_){}
+ std::ostream& out;
+
+ template <typename T>
+ void operator() (
+ T& a
+ ) const { serialize(a,out); }
+ };
+
+ // ------------------------------------------------------------------------------------
+
+ struct tuple_deserialize
+ {
+ tuple_deserialize (std::istream& in_) : in(in_){}
+ std::istream& in;
+ template <typename T>
+ void operator() (
+ T& a
+ ) const { deserialize(a,in); }
+ };
+
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // use these preprocessor macros to declare 4*8 template arguments (below we count them in octal)
+ template <
+ DLIB_TUPLE_TEMPLATE_LIST(0), // args 00-07
+ DLIB_TUPLE_TEMPLATE_LIST(01), // args 010-017
+ DLIB_TUPLE_TEMPLATE_LIST(02), // args 020-027
+ DLIB_TUPLE_TEMPLATE_LIST(03) // args 030-037
+ >
+ class tuple
+ {
+ public:
+
+ // use these macros to declare 8*4 member variables
+ DLIB_TUPLE_VARIABLE_LIST(0)
+ DLIB_TUPLE_VARIABLE_LIST(01)
+ DLIB_TUPLE_VARIABLE_LIST(02)
+ DLIB_TUPLE_VARIABLE_LIST(03)
+
+ const static long max_fields = 4*8;
+
+ template < long idx >
+ struct get_type
+ {
+ typedef typename tuple_helpers::get_helper<idx,tuple>::type type;
+ };
+
+ template < long idx >
+ const typename tuple_helpers::get_helper<idx,tuple>::type& get (
+ ) const { return tuple_helpers::get_helper<idx,tuple>::get(*this); }
+
+ template < long idx >
+ typename tuple_helpers::get_helper<idx,tuple>::type& get (
+ ) { return tuple_helpers::get_helper<idx,tuple>::get(*this); }
+
+ template < class Q>
+ long index (
+ ) const { return tuple_helpers::get_index<Q>(*this); }
+
+ template <class Q>
+ const Q& get (
+ ) const {return tuple_helpers::get_item_const<Q>(*this);}
+
+ template <class Q>
+ Q& get (
+ ) {return tuple_helpers::get_item<Q>(*this);}
+
+
+
+
+ template <typename F>
+ void for_index (
+ F& funct,
+ long idx
+ )
+ {
+ // do this #ifdef stuff to avoid getting a warning about valid_idx not being
+ // used when ENABLE_ASSERTS isn't defined.
+#ifdef ENABLE_ASSERTS
+ const bool valid_idx = tuple_helpers::for_each<tuple,F>::go(*this,funct,idx);
+#else
+ tuple_helpers::for_each<tuple,F>::go(*this,funct,idx);
+#endif
+ DLIB_ASSERT(valid_idx,
+ "\tvoid tuple::for_index()"
+ << "\n\tYou have attempted to call for_index() with an index out of the valid range"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+ }
+
+ template <typename F>
+ void for_index (
+ F& funct,
+ long idx
+ ) const
+ {
+ // do this #ifdef stuff to avoid getting a warning about valid_idx not being
+ // used when ENABLE_ASSERTS isn't defined.
+#ifdef ENABLE_ASSERTS
+ const bool valid_idx = tuple_helpers::for_each<const tuple,F>::go(*this,funct,idx);
+#else
+ tuple_helpers::for_each<const tuple,F>::go(*this,funct,idx);
+#endif
+ DLIB_ASSERT(valid_idx,
+ "\tvoid tuple::for_index()"
+ << "\n\tYou have attempted to call for_index() with an index out of the valid range"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+ }
+
+ template <typename F>
+ void for_index (
+ const F& funct,
+ long idx
+ )
+ {
+ // do this #ifdef stuff to avoid getting a warning about valid_idx not being
+ // used when ENABLE_ASSERTS isn't defined.
+#ifdef ENABLE_ASSERTS
+ const bool valid_idx = tuple_helpers::for_each<tuple,const F>::go(*this,funct,idx);
+#else
+ tuple_helpers::for_each<tuple,const F>::go(*this,funct,idx);
+#endif
+ DLIB_ASSERT(valid_idx,
+ "\tvoid tuple::for_index()"
+ << "\n\tYou have attempted to call for_index() with an index out of the valid range"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+ }
+
+ template <typename F>
+ void for_index (
+ const F& funct,
+ long idx
+ ) const
+ {
+ // do this #ifdef stuff to avoid getting a warning about valid_idx not being
+ // used when ENABLE_ASSERTS isn't defined.
+#ifdef ENABLE_ASSERTS
+ const bool valid_idx = tuple_helpers::for_each<const tuple,const F>::go(*this,funct,idx);
+#else
+ tuple_helpers::for_each<const tuple,const F>::go(*this,funct,idx);
+#endif
+ DLIB_ASSERT(valid_idx,
+ "\tvoid tuple::for_index()"
+ << "\n\tYou have attempted to call for_index() with an index out of the valid range"
+ << "\n\tidx: " << idx
+ << "\n\tthis: " << this
+ );
+ }
+
+
+
+
+ template <typename F>
+ void for_each (
+ F& funct
+ ) { tuple_helpers::for_each<tuple,F>::go(*this,funct); }
+
+ template <typename F>
+ void for_each (
+ F& funct
+ ) const { tuple_helpers::for_each<const tuple,F>::go(*this,funct); }
+
+ template <typename F>
+ void for_each (
+ const F& funct
+ ) const { tuple_helpers::for_each<const tuple,const F>::go(*this,funct); }
+
+ template <typename F>
+ void for_each (
+ const F& funct
+ ) { tuple_helpers::for_each<tuple,const F>::go(*this,funct); }
+
+
+
+
+ inline friend void serialize (
+ tuple& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ item.for_each(tuple_helpers::tuple_serialize(out));
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type dlib::tuple<>");
+ }
+ }
+
+ inline friend void deserialize (
+ tuple& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ item.for_each(tuple_helpers::tuple_deserialize(in));
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type dlib::tuple<>");
+ }
+ }
+
+ inline friend void swap (
+ tuple& a,
+ tuple& b
+ )
+ {
+ tuple_helpers::tuple_swap<tuple>::go(a,b);
+ }
+
+ inline void swap(
+ tuple& item
+ ) { tuple_helpers::tuple_swap<tuple>::go(item,*this); }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TUPLe_H_
+
diff --git a/ml/dlib/dlib/tuple/tuple_abstract.h b/ml/dlib/dlib/tuple/tuple_abstract.h
new file mode 100644
index 000000000..aff9b8122
--- /dev/null
+++ b/ml/dlib/dlib/tuple/tuple_abstract.h
@@ -0,0 +1,302 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TUPLe_ABSTRACT_H_
+#ifdef DLIB_TUPLe_ABSTRACT_H_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "tuple_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct null_type
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object is the default type used as the default
+ template argument to the tuple object's template arguments.
+
+ Also note that it has no state associated with it.
+ !*/
+ };
+
+ inline void serialize (
+ const null_type& ,
+ std::ostream&
+ ){}
+ inline void deserialize (
+ null_type& ,
+ std::istream&
+ ){}
+ /*!
+ Serialization support is provided for null_type because in some cases
+ it makes your code a little more concise and easier to deal with
+ when using tuple objects and serialization. The serialization literally
+ does nothing though.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T0 = null_type,
+ typename T1 = null_type,
+ typename T2 = null_type,
+ typename T3 = null_type,
+ ...
+ typename T31 = null_type
+ >
+ class tuple
+ {
+ /*!
+ INITIAL VALUE
+ Each object in the tuple is default initialized by its own constructor.
+ The tuple object itself does not modify them or add any additional
+ state.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a container of between 0 and 31 objects
+ where the objects contained are specified in the template
+ arguments.
+
+ EXAMPLE
+ We can declare a tuple that contains an int, a float, and a char like so:
+ tuple<int,float,char> ex;
+
+ Then we can access each of these by their index number. The index number
+ is just the order each type has in the template argument list. So we have:
+ ex.get<0>() = 5; // assign the int the value 5
+ ex.get<1>() = 3.14; // assign the float the value 3.14
+ ex.get<2>() = 'c'; // assign the char the value 'c'
+
+ Also, since we only have one of each type in this example tuple we can
+ unambiguously access each field in the tuple by their types. So for
+ example, we can use this syntax to access our fields:
+ ex.get<int>() // returns 5
+ ex.get<float>() // returns 3.14
+ ex.get<char>() // returns 'c'
+
+ We can also get the indexes of each of these fields like so:
+ ex.index<int>() // returns 0
+ ex.index<float>() // returns 1
+ ex.index<char>() // returns 2
+ !*/
+
+ public:
+ // the maximum number of items this tuple template can contain
+ const static long max_fields = 32;
+
+ template <long index>
+ struct get_type
+ {
+ typedef (the type of the Tindex template argument) type;
+ };
+
+ template <long index>
+ const get_type<index>::type& get (
+ ) const;
+ /*!
+ requires
+ - 0 <= index <= 31
+ ensures
+ - returns a const reference to the index(th) object contained
+ inside this tuple
+ !*/
+
+ template <long index>
+ get_type<index>::type& get (
+ );
+ /*!
+ requires
+ - 0 <= index <= 31
+ ensures
+ - returns a non-const reference to the index(th) object contained
+ inside this tuple
+ !*/
+
+ template <typename Q>
+ const long index (
+ ) const;
+ /*!
+ requires
+ - Q is a type of object contained in this tuple and there is
+ only one object of that type in the tuple
+ ensures
+ - returns the index of the object in this tuple with type Q
+ !*/
+
+ template <typename Q>
+ const Q& get (
+ ) const;
+ /*!
+ requires
+ - Q is a type of object contained in this tuple and there is
+ only one object of that type in the tuple
+ ensures
+ - returns a const reference to the object in this tuple
+ with type Q
+ !*/
+
+ template <typename Q>
+ Q& get (
+ );
+ /*!
+ requires
+ - Q is a type of object contained in this tuple and there is
+ only one object of that type in the tuple
+ ensures
+ - returns a non-const reference to the object in this tuple
+ with type Q
+ !*/
+
+ template <typename F>
+ void for_each (
+ F& funct
+ );
+ /*!
+ requires
+ - funct is a templated function object
+ ensures
+ - for each item X in this tuple that isn't a null_type object:
+ - calls funct(X);
+ !*/
+
+ template <typename F>
+ void for_each (
+ F& funct
+ ) const;
+ /*!
+ requires
+ - funct is a templated function object
+ ensures
+ - for each item X in this tuple that isn't a null_type object:
+ - calls funct(X);
+ !*/
+
+ template <typename F>
+ void for_each (
+ const F& funct
+ );
+ /*!
+ requires
+ - funct is a templated function object
+ ensures
+ - for each item X in this tuple that isn't a null_type object:
+ - calls funct(X);
+ !*/
+
+ template <typename F>
+ void for_each (
+ const F& funct
+ ) const;
+ /*!
+ requires
+ - funct is a templated function object
+ ensures
+ - for each item X in this tuple that isn't a null_type object:
+ - calls funct(X);
+ !*/
+
+ template <typename F>
+ void for_index (
+ F& funct,
+ long idx
+ );
+ /*!
+ requires
+ - funct is a templated function object
+ - 0 <= idx < max_fields && get_type<idx>::type != null_type
+ (i.e. idx must be the index of a non-null_type object in this tuple)
+ ensures
+ - calls funct(this->get<idx>());
+ !*/
+
+ template <typename F>
+ void for_index (
+ F& funct,
+ long idx
+ ) const;
+ /*!
+ requires
+ - funct is a templated function object
+ - 0 <= idx < max_fields && get_type<idx>::type != null_type
+ (i.e. idx must be the index of a non-null_type object in this tuple)
+ ensures
+ - calls funct(this->get<idx>());
+ !*/
+
+ template <typename F>
+ void for_index (
+ const F& funct,
+ long idx
+ );
+ /*!
+ requires
+ - funct is a templated function object
+ - 0 <= idx < max_fields && get_type<idx>::type != null_type
+ (i.e. idx must be the index of a non-null_type object in this tuple)
+ ensures
+ - calls funct(this->get<idx>());
+ !*/
+
+ template <typename F>
+ void for_index (
+ const F& funct,
+ long idx
+ ) const;
+ /*!
+ requires
+ - funct is a templated function object
+ - 0 <= idx < max_fields && get_type<idx>::type != null_type
+ (i.e. idx must be the index of a non-null_type object in this tuple)
+ ensures
+ - calls funct(this->get<idx>());
+ !*/
+
+ void swap (
+ tuple& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ // -------------------------------------------------
+ // global functions for tuple objects
+ // -------------------------------------------------
+
+ friend void swap (
+ tuple& a,
+ tuple& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+ friend void serialize (
+ const tuple& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+ friend void deserialize (
+ tuple& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TUPLe_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/type_safe_union.h b/ml/dlib/dlib/type_safe_union.h
new file mode 100644
index 000000000..d04bf0d17
--- /dev/null
+++ b/ml/dlib/dlib/type_safe_union.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TYPE_SAFE_UNIOn_TOP_
+#define DLIB_TYPE_SAFE_UNIOn_TOP_
+
+#include "type_safe_union/type_safe_union_kernel.h"
+
+#endif // DLIB_TYPE_SAFE_UNIOn_TOP_
+
+
+
diff --git a/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h
new file mode 100644
index 000000000..76a171286
--- /dev/null
+++ b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h
@@ -0,0 +1,711 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_TYPE_SAFE_UNIOn_h_
+#define DLIB_TYPE_SAFE_UNIOn_h_
+
+#include "type_safe_union_kernel_abstract.h"
+#include "../algs.h"
+#include "../noncopyable.h"
+#include "../serialize.h"
+#include <new>
+#include <iostream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_type_safe_union_cast : public std::bad_cast
+ {
+ public:
+ virtual const char * what() const throw()
+ {
+ return "bad_type_safe_union_cast";
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ struct _void{};
+ inline void serialize( const _void&, std::ostream&){}
+ inline void deserialize( _void&, std::istream&){}
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T1,
+ typename T2 = _void,
+ typename T3 = _void,
+ typename T4 = _void,
+ typename T5 = _void,
+ typename T6 = _void,
+ typename T7 = _void,
+ typename T8 = _void,
+ typename T9 = _void,
+ typename T10 = _void,
+
+ typename T11 = _void,
+ typename T12 = _void,
+ typename T13 = _void,
+ typename T14 = _void,
+ typename T15 = _void,
+ typename T16 = _void,
+ typename T17 = _void,
+ typename T18 = _void,
+ typename T19 = _void,
+ typename T20 = _void
+ >
+ class type_safe_union : noncopyable
+ {
+ /*!
+ CONVENTION
+ - is_empty() == (type_identity == 0)
+ - contains<T>() == (type_identity == get_type_id<T>())
+ - mem.get() == the block of memory on the stack which is
+ where objects in the union are stored
+ !*/
+
+ private:
+
+ template <typename T, typename U>
+ void invoke_on (
+ T& obj,
+ U& item
+ ) const
+ {
+ obj(item);
+ }
+
+ template <typename T>
+ void invoke_on (
+ T& ,
+ _void
+ ) const
+ {
+ }
+
+
+ const static size_t max_size = tmax<tmax<tmax<tmax<tmax<tmax<tmax<tmax<tmax<tmax<
+ tmax<tmax<tmax<tmax<tmax<tmax<tmax<tmax<tmax<sizeof(T1),
+ sizeof(T2)>::value,
+ sizeof(T3)>::value,
+ sizeof(T4)>::value,
+ sizeof(T5)>::value,
+ sizeof(T6)>::value,
+ sizeof(T7)>::value,
+ sizeof(T8)>::value,
+ sizeof(T9)>::value,
+ sizeof(T10)>::value,
+ sizeof(T11)>::value,
+ sizeof(T12)>::value,
+ sizeof(T13)>::value,
+ sizeof(T14)>::value,
+ sizeof(T15)>::value,
+ sizeof(T16)>::value,
+ sizeof(T17)>::value,
+ sizeof(T18)>::value,
+ sizeof(T19)>::value,
+ sizeof(T20)>::value;
+
+ // --------------------------------------------
+
+ // member data
+ stack_based_memory_block<max_size> mem;
+ int type_identity;
+
+ // --------------------------------------------
+
+ template <typename T>
+ void validate_type() const
+ {
+ // ERROR: You are trying to get a type of object that isn't
+ // representable by this type_safe_union. I.e. The given
+ // type T isn't one of the ones given to this object's template
+ // arguments.
+ COMPILE_TIME_ASSERT(( is_same_type<T,T1>::value ||
+ is_same_type<T,T2>::value ||
+ is_same_type<T,T3>::value ||
+ is_same_type<T,T4>::value ||
+ is_same_type<T,T5>::value ||
+ is_same_type<T,T6>::value ||
+ is_same_type<T,T7>::value ||
+ is_same_type<T,T8>::value ||
+ is_same_type<T,T9>::value ||
+ is_same_type<T,T10>::value ||
+
+ is_same_type<T,T11>::value ||
+ is_same_type<T,T12>::value ||
+ is_same_type<T,T13>::value ||
+ is_same_type<T,T14>::value ||
+ is_same_type<T,T15>::value ||
+ is_same_type<T,T16>::value ||
+ is_same_type<T,T17>::value ||
+ is_same_type<T,T18>::value ||
+ is_same_type<T,T19>::value ||
+ is_same_type<T,T20>::value
+ ));
+
+ }
+
+
+ struct destruct_helper
+ {
+ template <typename T>
+ void operator() (T& item) const
+ {
+ item.~T();
+ }
+ };
+
+ void destruct (
+ )
+ /*!
+ ensures
+ - #is_empty() == true
+ !*/
+ {
+ // destruct whatever is in this object
+ apply_to_contents(destruct_helper());
+
+ // mark this object as being empty
+ type_identity = 0;
+ }
+
+ template <typename T>
+ void construct (
+ )
+ {
+ if (type_identity != get_type_id<T>())
+ {
+ destruct();
+ new(mem.get()) T();
+ type_identity = get_type_id<T>();
+ }
+ }
+
+ template <typename T>
+ void construct (
+ const T& item
+ )
+ {
+ if (type_identity != get_type_id<T>())
+ {
+ destruct();
+ new(mem.get()) T(item);
+ type_identity = get_type_id<T>();
+ }
+ }
+
+ template <typename T>
+ T& unchecked_get(
+ )
+ /*!
+ requires
+ - contains<T>() == true
+ ensures
+ - returns a non-const reference to the T object
+ !*/
+ {
+ return *static_cast<T*>(mem.get());
+ }
+
+ template <typename T>
+ const T& unchecked_get(
+ ) const
+ /*!
+ requires
+ - contains<T>() == true
+ ensures
+ - returns a const reference to the T object
+ !*/
+ {
+ return *static_cast<const T*>(mem.get());
+ }
+
+ template <typename T>
+ void operator() (T& item)
+ /*
+ This function is used by the swap function of this class. See that
+ function to see how this works.
+ */
+ {
+ exchange(get<T>(), item);
+ }
+
+ public:
+
+ typedef T1 type1;
+ typedef T2 type2;
+ typedef T3 type3;
+ typedef T4 type4;
+ typedef T5 type5;
+ typedef T6 type6;
+ typedef T7 type7;
+ typedef T8 type8;
+ typedef T9 type9;
+ typedef T10 type10;
+ typedef T11 type11;
+ typedef T12 type12;
+ typedef T13 type13;
+ typedef T14 type14;
+ typedef T15 type15;
+ typedef T16 type16;
+ typedef T17 type17;
+ typedef T18 type18;
+ typedef T19 type19;
+ typedef T20 type20;
+
+
+ type_safe_union() : type_identity(0)
+ {
+ }
+
+ template <typename T>
+ type_safe_union (
+ const T& item
+ ) : type_identity(0)
+ {
+ validate_type<T>();
+ construct(item);
+ }
+
+ ~type_safe_union()
+ {
+ destruct();
+ }
+
+ template <typename T>
+ static int get_type_id (
+ )
+ {
+ if (is_same_type<T,T1>::value) return 1;
+ if (is_same_type<T,T2>::value) return 2;
+ if (is_same_type<T,T3>::value) return 3;
+ if (is_same_type<T,T4>::value) return 4;
+ if (is_same_type<T,T5>::value) return 5;
+
+ if (is_same_type<T,T6>::value) return 6;
+ if (is_same_type<T,T7>::value) return 7;
+ if (is_same_type<T,T8>::value) return 8;
+ if (is_same_type<T,T9>::value) return 9;
+ if (is_same_type<T,T10>::value) return 10;
+
+ if (is_same_type<T,T11>::value) return 11;
+ if (is_same_type<T,T12>::value) return 12;
+ if (is_same_type<T,T13>::value) return 13;
+ if (is_same_type<T,T14>::value) return 14;
+ if (is_same_type<T,T15>::value) return 15;
+
+ if (is_same_type<T,T16>::value) return 16;
+ if (is_same_type<T,T17>::value) return 17;
+ if (is_same_type<T,T18>::value) return 18;
+ if (is_same_type<T,T19>::value) return 19;
+ if (is_same_type<T,T20>::value) return 20;
+
+ // return a number that doesn't match any of the
+ // valid states of type_identity
+ return -1;
+ }
+
+ template <typename T>
+ bool contains (
+ ) const
+ {
+ return type_identity == get_type_id<T>();
+ }
+
+ bool is_empty (
+ ) const
+ {
+ return type_identity == 0;
+ }
+
+
+ public:
+
+ template <
+ typename t1, typename t2, typename t3, typename t4, typename t5,
+ typename t6, typename t7, typename t8, typename t9, typename t10,
+ typename t11, typename t12, typename t13, typename t14, typename t15,
+ typename t16, typename t17, typename t18, typename t19, typename t20
+ >
+ friend void serialize (
+ const type_safe_union<t1,t2,t3,t4,t5,t6,t7,t8,t9,t10, t11,t12,t13,t14,t15,t16,t17,t18,t19,t20>& item,
+ std::ostream& out
+ );
+
+
+ template <
+ typename T
+ >
+ void apply_to_contents (
+ T& obj
+ )
+ {
+ switch (type_identity)
+ {
+ // do nothing because we are empty
+ case 0: break;
+
+ case 1: invoke_on(obj,unchecked_get<T1>()); break;
+ case 2: invoke_on(obj,unchecked_get<T2>()); break;
+ case 3: invoke_on(obj,unchecked_get<T3>()); break;
+ case 4: invoke_on(obj,unchecked_get<T4>()); break;
+ case 5: invoke_on(obj,unchecked_get<T5>()); break;
+
+ case 6: invoke_on(obj,unchecked_get<T6>()); break;
+ case 7: invoke_on(obj,unchecked_get<T7>()); break;
+ case 8: invoke_on(obj,unchecked_get<T8>()); break;
+ case 9: invoke_on(obj,unchecked_get<T9>()); break;
+ case 10: invoke_on(obj,unchecked_get<T10>()); break;
+
+ case 11: invoke_on(obj,unchecked_get<T11>()); break;
+ case 12: invoke_on(obj,unchecked_get<T12>()); break;
+ case 13: invoke_on(obj,unchecked_get<T13>()); break;
+ case 14: invoke_on(obj,unchecked_get<T14>()); break;
+ case 15: invoke_on(obj,unchecked_get<T15>()); break;
+
+ case 16: invoke_on(obj,unchecked_get<T16>()); break;
+ case 17: invoke_on(obj,unchecked_get<T17>()); break;
+ case 18: invoke_on(obj,unchecked_get<T18>()); break;
+ case 19: invoke_on(obj,unchecked_get<T19>()); break;
+ case 20: invoke_on(obj,unchecked_get<T20>()); break;
+ }
+ }
+
+ template <
+ typename T
+ >
+ void apply_to_contents (
+ const T& obj
+ )
+ {
+ switch (type_identity)
+ {
+ // do nothing because we are empty
+ case 0: break;
+
+ case 1: invoke_on(obj,unchecked_get<T1>()); break;
+ case 2: invoke_on(obj,unchecked_get<T2>()); break;
+ case 3: invoke_on(obj,unchecked_get<T3>()); break;
+ case 4: invoke_on(obj,unchecked_get<T4>()); break;
+ case 5: invoke_on(obj,unchecked_get<T5>()); break;
+
+ case 6: invoke_on(obj,unchecked_get<T6>()); break;
+ case 7: invoke_on(obj,unchecked_get<T7>()); break;
+ case 8: invoke_on(obj,unchecked_get<T8>()); break;
+ case 9: invoke_on(obj,unchecked_get<T9>()); break;
+ case 10: invoke_on(obj,unchecked_get<T10>()); break;
+
+ case 11: invoke_on(obj,unchecked_get<T11>()); break;
+ case 12: invoke_on(obj,unchecked_get<T12>()); break;
+ case 13: invoke_on(obj,unchecked_get<T13>()); break;
+ case 14: invoke_on(obj,unchecked_get<T14>()); break;
+ case 15: invoke_on(obj,unchecked_get<T15>()); break;
+
+ case 16: invoke_on(obj,unchecked_get<T16>()); break;
+ case 17: invoke_on(obj,unchecked_get<T17>()); break;
+ case 18: invoke_on(obj,unchecked_get<T18>()); break;
+ case 19: invoke_on(obj,unchecked_get<T19>()); break;
+ case 20: invoke_on(obj,unchecked_get<T20>()); break;
+ }
+ }
+
+ template <
+ typename T
+ >
+ void apply_to_contents (
+ T& obj
+ ) const
+ {
+ switch (type_identity)
+ {
+ // do nothing because we are empty
+ case 0: break;
+
+ case 1: invoke_on(obj,unchecked_get<T1>()); break;
+ case 2: invoke_on(obj,unchecked_get<T2>()); break;
+ case 3: invoke_on(obj,unchecked_get<T3>()); break;
+ case 4: invoke_on(obj,unchecked_get<T4>()); break;
+ case 5: invoke_on(obj,unchecked_get<T5>()); break;
+
+ case 6: invoke_on(obj,unchecked_get<T6>()); break;
+ case 7: invoke_on(obj,unchecked_get<T7>()); break;
+ case 8: invoke_on(obj,unchecked_get<T8>()); break;
+ case 9: invoke_on(obj,unchecked_get<T9>()); break;
+ case 10: invoke_on(obj,unchecked_get<T10>()); break;
+
+ case 11: invoke_on(obj,unchecked_get<T11>()); break;
+ case 12: invoke_on(obj,unchecked_get<T12>()); break;
+ case 13: invoke_on(obj,unchecked_get<T13>()); break;
+ case 14: invoke_on(obj,unchecked_get<T14>()); break;
+ case 15: invoke_on(obj,unchecked_get<T15>()); break;
+
+ case 16: invoke_on(obj,unchecked_get<T16>()); break;
+ case 17: invoke_on(obj,unchecked_get<T17>()); break;
+ case 18: invoke_on(obj,unchecked_get<T18>()); break;
+ case 19: invoke_on(obj,unchecked_get<T19>()); break;
+ case 20: invoke_on(obj,unchecked_get<T20>()); break;
+ }
+ }
+
+ template <
+ typename T
+ >
+ void apply_to_contents (
+ const T& obj
+ ) const
+ {
+ switch (type_identity)
+ {
+ // do nothing because we are empty
+ case 0: break;
+
+ case 1: invoke_on(obj,unchecked_get<T1>()); break;
+ case 2: invoke_on(obj,unchecked_get<T2>()); break;
+ case 3: invoke_on(obj,unchecked_get<T3>()); break;
+ case 4: invoke_on(obj,unchecked_get<T4>()); break;
+ case 5: invoke_on(obj,unchecked_get<T5>()); break;
+
+ case 6: invoke_on(obj,unchecked_get<T6>()); break;
+ case 7: invoke_on(obj,unchecked_get<T7>()); break;
+ case 8: invoke_on(obj,unchecked_get<T8>()); break;
+ case 9: invoke_on(obj,unchecked_get<T9>()); break;
+ case 10: invoke_on(obj,unchecked_get<T10>()); break;
+
+ case 11: invoke_on(obj,unchecked_get<T11>()); break;
+ case 12: invoke_on(obj,unchecked_get<T12>()); break;
+ case 13: invoke_on(obj,unchecked_get<T13>()); break;
+ case 14: invoke_on(obj,unchecked_get<T14>()); break;
+ case 15: invoke_on(obj,unchecked_get<T15>()); break;
+
+ case 16: invoke_on(obj,unchecked_get<T16>()); break;
+ case 17: invoke_on(obj,unchecked_get<T17>()); break;
+ case 18: invoke_on(obj,unchecked_get<T18>()); break;
+ case 19: invoke_on(obj,unchecked_get<T19>()); break;
+ case 20: invoke_on(obj,unchecked_get<T20>()); break;
+ }
+ }
+
+ void swap (
+ type_safe_union& item
+ )
+ {
+ // if both *this and item contain the same type of thing
+ if (type_identity == item.type_identity)
+ {
+ // swap the things in this and item.
+ item.apply_to_contents(*this);
+ }
+ else if (type_identity == 0)
+ {
+ // *this doesn't contain anything. So swap this and item and
+ // then destruct item.
+ item.apply_to_contents(*this);
+ item.destruct();
+ }
+ else if (item.type_identity == 0)
+ {
+ // *this doesn't contain anything. So swap this and item and
+ // then destruct this.
+ apply_to_contents(item);
+ destruct();
+ }
+ else
+ {
+ type_safe_union temp;
+ // swap *this into temp
+ apply_to_contents(temp);
+ // swap item into *this
+ item.apply_to_contents(*this);
+ // swap temp into item
+ temp.apply_to_contents(item);
+ }
+ }
+
+ template <typename T>
+ T& get(
+ )
+ {
+ validate_type<T>();
+ construct<T>();
+ return *static_cast<T*>(mem.get());
+ }
+
+ template <typename T>
+ const T& cast_to (
+ ) const
+ {
+ validate_type<T>();
+ if (contains<T>())
+ return *static_cast<const T*>(mem.get());
+ else
+ throw bad_type_safe_union_cast();
+ }
+
+ template <typename T>
+ T& cast_to (
+ )
+ {
+ validate_type<T>();
+ if (contains<T>())
+ return *static_cast<T*>(mem.get());
+ else
+ throw bad_type_safe_union_cast();
+ }
+
+ template <typename T>
+ type_safe_union& operator= ( const T& item) { get<T>() = item; return *this; }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11, typename T12, typename T13, typename T14, typename T15,
+ typename T16, typename T17, typename T18, typename T19, typename T20
+ >
+ inline void swap (
+ type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20>& a,
+ type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20>& b
+ ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename from,
+ typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11, typename T12, typename T13, typename T14, typename T15,
+ typename T16, typename T17, typename T18, typename T19, typename T20
+ >
+ struct is_convertible<from,
+ type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20> >
+ {
+ const static bool value = is_convertible<from,T1>::value ||
+ is_convertible<from,T2>::value ||
+ is_convertible<from,T3>::value ||
+ is_convertible<from,T4>::value ||
+ is_convertible<from,T5>::value ||
+ is_convertible<from,T6>::value ||
+ is_convertible<from,T7>::value ||
+ is_convertible<from,T8>::value ||
+ is_convertible<from,T9>::value ||
+ is_convertible<from,T10>::value ||
+ is_convertible<from,T11>::value ||
+ is_convertible<from,T12>::value ||
+ is_convertible<from,T13>::value ||
+ is_convertible<from,T14>::value ||
+ is_convertible<from,T15>::value ||
+ is_convertible<from,T16>::value ||
+ is_convertible<from,T17>::value ||
+ is_convertible<from,T18>::value ||
+ is_convertible<from,T19>::value ||
+ is_convertible<from,T20>::value;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl_tsu
+ {
+ struct serialize_helper
+ {
+ /*
+ This is a function object to help us serialize type_safe_unions
+ */
+
+ std::ostream& out;
+ serialize_helper(std::ostream& out_): out(out_) {}
+ template <typename T>
+ void operator() (const T& item) const { serialize(item, out); }
+ };
+ }
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11, typename T12, typename T13, typename T14, typename T15,
+ typename T16, typename T17, typename T18, typename T19, typename T20
+ >
+ void serialize (
+ const type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ // save the type_identity
+ serialize(item.type_identity, out);
+ item.apply_to_contents(dlib::impl_tsu::serialize_helper(out));
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while serializing an object of type type_safe_union");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11, typename T12, typename T13, typename T14, typename T15,
+ typename T16, typename T17, typename T18, typename T19, typename T20
+ >
+ void deserialize (
+ type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ typedef type_safe_union<T1,T2,T3,T4,T5,T6,T7,T8,T9,T10, T11,T12,T13,T14,T15,T16,T17,T18,T19,T20> tsu_type;
+
+ int type_identity;
+ deserialize(type_identity, in);
+ switch (type_identity)
+ {
+ // swap an empty type_safe_union into item since it should be in the empty state
+ case 0: tsu_type().swap(item); break;
+
+ case 1: deserialize(item.template get<T1>(), in); break;
+ case 2: deserialize(item.template get<T2>(), in); break;
+ case 3: deserialize(item.template get<T3>(), in); break;
+ case 4: deserialize(item.template get<T4>(), in); break;
+ case 5: deserialize(item.template get<T5>(), in); break;
+
+ case 6: deserialize(item.template get<T6>(), in); break;
+ case 7: deserialize(item.template get<T7>(), in); break;
+ case 8: deserialize(item.template get<T8>(), in); break;
+ case 9: deserialize(item.template get<T9>(), in); break;
+ case 10: deserialize(item.template get<T10>(), in); break;
+
+ case 11: deserialize(item.template get<T11>(), in); break;
+ case 12: deserialize(item.template get<T12>(), in); break;
+ case 13: deserialize(item.template get<T13>(), in); break;
+ case 14: deserialize(item.template get<T14>(), in); break;
+ case 15: deserialize(item.template get<T15>(), in); break;
+
+ case 16: deserialize(item.template get<T16>(), in); break;
+ case 17: deserialize(item.template get<T17>(), in); break;
+ case 18: deserialize(item.template get<T18>(), in); break;
+ case 19: deserialize(item.template get<T19>(), in); break;
+ case 20: deserialize(item.template get<T20>(), in); break;
+
+ default: throw serialization_error("Corrupt data detected while deserializing type_safe_union");
+ }
+ }
+ catch (serialization_error& e)
+ {
+ throw serialization_error(e.info + "\n while deserializing an object of type type_safe_union");
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TYPE_SAFE_UNIOn_h_
+
diff --git a/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h
new file mode 100644
index 000000000..3d041e6a3
--- /dev/null
+++ b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h
@@ -0,0 +1,329 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_
+#ifdef DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_
+
+#include "../algs.h"
+#include "../noncopyable.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class bad_type_safe_union_cast : public std::bad_cast
+ {
+ /*!
+ This is the exception object thrown by type_safe_union::cast_to() if the
+ type_safe_union does not contain the type of object being requested.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T1,
+ typename T2 = _void, // _void indicates parameter not used.
+ typename T3 = _void,
+ typename T4 = _void,
+ typename T5 = _void,
+ typename T6 = _void,
+ typename T7 = _void,
+ typename T8 = _void,
+ typename T9 = _void,
+ typename T10 = _void,
+ typename T11 = _void,
+ typename T12 = _void,
+ typename T13 = _void,
+ typename T14 = _void,
+ typename T15 = _void,
+ typename T16 = _void,
+ typename T17 = _void,
+ typename T18 = _void,
+ typename T19 = _void,
+ typename T20 = _void
+ >
+ class type_safe_union : noncopyable
+ {
+ /*!
+ REQUIREMENTS ON ALL TEMPLATE ARGUMENTS
+ All template arguments must be default constructable and have
+ a global swap.
+
+ INITIAL VALUE
+ - is_empty() == true
+ - contains<U>() == false, for all possible values of U
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a type safe analogue of the classic C union object.
+ The type_safe_union, unlike a union, can contain non-POD types such
+ as std::string.
+
+ For example:
+ union my_union
+ {
+ int a;
+ std::string b; // Error, std::string isn't a POD
+ };
+
+ type_safe_union<int,std::string> my_type_safe_union; // No error
+ !*/
+
+ public:
+
+ typedef T1 type1;
+ typedef T2 type2;
+ typedef T3 type3;
+ typedef T4 type4;
+ typedef T5 type5;
+ typedef T6 type6;
+ typedef T7 type7;
+ typedef T8 type8;
+ typedef T9 type9;
+ typedef T10 type10;
+ typedef T11 type11;
+ typedef T12 type12;
+ typedef T13 type13;
+ typedef T14 type14;
+ typedef T15 type15;
+ typedef T16 type16;
+ typedef T17 type17;
+ typedef T18 type18;
+ typedef T19 type19;
+ typedef T20 type20;
+
+ type_safe_union(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ template <typename T>
+ type_safe_union (
+ const T& item
+ );
+ /*!
+ requires
+ - T must be one of the types given to this object's template arguments
+ ensures
+ - this object is properly initialized
+ - #get<T>() == item
+ (i.e. this object will contain a copy of item)
+ !*/
+
+ ~type_safe_union(
+ );
+ /*!
+ ensures
+ - all resources associated with this object have been freed
+ !*/
+
+ template <typename T>
+ static int get_type_id (
+ );
+ /*!
+ ensures
+ - if (T is the same type as one of the template arguments) then
+ - returns a number indicating which template argument it is.
+ (e.g. if T is the same type as T3 then this function returns 3)
+ - else
+ - returns -1
+ !*/
+
+ template <typename T>
+ bool contains (
+ ) const;
+ /*!
+ ensures
+ - if (this type_safe_union currently contains an object of type T) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_empty (
+ ) const;
+ /*!
+ ensures
+ - if (this type_safe_union currently contains any object at all) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ template <typename T>
+ void apply_to_contents (
+ T& obj
+ );
+ /*!
+ requires
+ - obj is a function object capable of operating on all the types contained
+ in this type_safe_union. I.e. obj(this->get<U>()) must be a valid
+ expression for all the possible U types.
+ ensures
+ - if (is_empty() == false) then
+ - Let U denote the type of object currently contained in this type_safe_union
+ - calls obj(this->get<U>())
+ - The object returned by this->get<U>() will be non-const
+ !*/
+
+ template <typename T>
+ void apply_to_contents (
+ const T& obj
+ );
+ /*!
+ requires
+ - obj is a function object capable of operating on all the types contained
+ in this type_safe_union. I.e. obj(this->get<U>()) must be a valid
+ expression for all the possible U types.
+ ensures
+ - if (is_empty() == false) then
+ - Let U denote the type of object currently contained in this type_safe_union
+ - calls obj(this->get<U>())
+ - The object returned by this->get<U>() will be non-const
+ !*/
+
+ template <typename T>
+ void apply_to_contents (
+ T& obj
+ ) const;
+ /*!
+ requires
+ - obj is a function object capable of operating on all the types contained
+ in this type_safe_union. I.e. obj(this->get<U>()) must be a valid
+ expression for all the possible U types.
+ ensures
+ - if (is_empty() == false) then
+ - Let U denote the type of object currently contained in this type_safe_union
+ - calls obj(this->get<U>())
+ - The object returned by this->get<U>() will be const
+ !*/
+
+ template <typename T>
+ void apply_to_contents (
+ const T& obj
+ ) const;
+ /*!
+ requires
+ - obj is a function object capable of operating on all the types contained
+ in this type_safe_union. I.e. obj(this->get<U>()) must be a valid
+ expression for all the possible U types.
+ ensures
+ - if (is_empty() == false) then
+ - Let U denote the type of object currently contained in this type_safe_union
+ - calls obj(this->get<U>())
+ - The object returned by this->get<U>() will be const
+ !*/
+
+ template <typename T>
+ T& get(
+ );
+ /*!
+ requires
+ - T must be one of the types given to this object's template arguments
+ ensures
+ - #is_empty() == false
+ - #contains<T>() == true
+ - if (contains<T>() == true)
+ - returns a non-const reference to the object contained in this type_safe_union.
+ - else
+ - Constructs an object of type T inside *this
+ - Any previous object stored in this type_safe_union is destructed and its
+ state is lost.
+ - returns a non-const reference to the newly created T object.
+ !*/
+
+ template <typename T>
+ const T& cast_to (
+ ) const;
+ /*!
+ requires
+ - T must be one of the types given to this object's template arguments
+ ensures
+ - if (contains<T>() == true) then
+ - returns a const reference to the object contained in this type_safe_union.
+ - else
+ - throws bad_type_safe_union_cast
+ !*/
+
+ template <typename T>
+ T& cast_to (
+ );
+ /*!
+ requires
+ - T must be one of the types given to this object's template arguments
+ ensures
+ - if (contains<T>() == true) then
+ - returns a non-const reference to the object contained in this type_safe_union.
+ - else
+ - throws bad_type_safe_union_cast
+ !*/
+
+ template <typename T>
+ type_safe_union& operator= (
+ const T& item
+ );
+ /*!
+ requires
+ - T must be one of the types given to this object's template arguments
+ ensures
+ - #get<T>() == item
+ (i.e. this object will contain a copy of item)
+ - returns *this
+ !*/
+
+ void swap (
+ type_safe_union& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template < ... >
+ inline void swap (
+ type_safe_union<...>& a,
+ type_safe_union<...>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template < ... >
+ void serialize (
+ const type_safe_union<...>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+
+ Note that type_safe_union objects are serialized as follows:
+ - if (item.is_empty()) then
+ - perform: serialize(0, out)
+ - else
+ - perform: serialize(item.get_type_id<type_of_object_in_item>(), out);
+ serialize(item.get<type_of_object_in_item>(), out);
+ !*/
+
+ template < ... >
+ void deserialize (
+ type_safe_union<...>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/uintn.h b/ml/dlib/dlib/uintn.h
new file mode 100644
index 000000000..8e2726546
--- /dev/null
+++ b/ml/dlib/dlib/uintn.h
@@ -0,0 +1,96 @@
+// Copyright (C) 2005 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#ifndef DLIB_UINtn_
+#define DLIB_UINtn_
+
+#include "assert.h"
+
+namespace dlib
+{
+
+ /*!
+ uint64 is a typedef for an unsigned integer that is exactly 64 bits wide.
+ uint32 is a typedef for an unsigned integer that is exactly 32 bits wide.
+ uint16 is a typedef for an unsigned integer that is exactly 16 bits wide.
+ uint8 is a typedef for an unsigned integer that is exactly 8 bits wide.
+
+ int64 is a typedef for an integer that is exactly 64 bits wide.
+ int32 is a typedef for an integer that is exactly 32 bits wide.
+ int16 is a typedef for an integer that is exactly 16 bits wide.
+ int8 is a typedef for an integer that is exactly 8 bits wide.
+ !*/
+
+
+#ifdef __GNUC__
+ typedef unsigned long long uint64;
+ typedef long long int64;
+#elif defined(__BORLANDC__)
+ typedef unsigned __int64 uint64;
+ typedef __int64 int64;
+#elif defined(_MSC_VER)
+ typedef unsigned __int64 uint64;
+ typedef __int64 int64;
+#else
+ typedef unsigned long long uint64;
+ typedef long long int64;
+#endif
+
+ typedef unsigned short uint16;
+ typedef unsigned int uint32;
+ typedef unsigned char uint8;
+
+ typedef short int16;
+ typedef int int32;
+ typedef char int8;
+
+
+ // make sure these types have the right sizes on this platform
+ COMPILE_TIME_ASSERT(sizeof(uint8) == 1);
+ COMPILE_TIME_ASSERT(sizeof(uint16) == 2);
+ COMPILE_TIME_ASSERT(sizeof(uint32) == 4);
+ COMPILE_TIME_ASSERT(sizeof(uint64) == 8);
+
+ COMPILE_TIME_ASSERT(sizeof(int8) == 1);
+ COMPILE_TIME_ASSERT(sizeof(int16) == 2);
+ COMPILE_TIME_ASSERT(sizeof(int32) == 4);
+ COMPILE_TIME_ASSERT(sizeof(int64) == 8);
+
+
+
+ template <typename T, size_t s = sizeof(T)>
+ struct unsigned_type;
+ template <typename T>
+ struct unsigned_type<T,1> { typedef uint8 type; };
+ template <typename T>
+ struct unsigned_type<T,2> { typedef uint16 type; };
+ template <typename T>
+ struct unsigned_type<T,4> { typedef uint32 type; };
+ template <typename T>
+ struct unsigned_type<T,8> { typedef uint64 type; };
+ /*!
+ ensures
+ - sizeof(unsigned_type<T>::type) == sizeof(T)
+ - unsigned_type<T>::type is an unsigned integral type
+ !*/
+
+ template <typename T, typename U>
+ T zero_extend_cast(
+ const U val
+ )
+ /*!
+ requires
+ - U and T are integral types
+ ensures
+ - let ut be a typedef for unsigned_type<U>::type
+ - return static_cast<T>(static_cast<ut>(val));
+ !*/
+ {
+ typedef typename unsigned_type<U>::type ut;
+ return static_cast<T>(static_cast<ut>(val));
+ }
+
+}
+
+#endif // DLIB_UINtn_
+
diff --git a/ml/dlib/dlib/unicode.h b/ml/dlib/dlib/unicode.h
new file mode 100644
index 000000000..3c1598811
--- /dev/null
+++ b/ml/dlib/dlib/unicode.h
@@ -0,0 +1,9 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_UNICODe_TOP_
+#define DLIB_UNICODe_TOP_
+
+#include "unicode/unicode.h"
+
+#endif // DLIB_UNICODe_TOP_
+
diff --git a/ml/dlib/dlib/unicode/unicode.cpp b/ml/dlib/dlib/unicode/unicode.cpp
new file mode 100644
index 000000000..2facc919c
--- /dev/null
+++ b/ml/dlib/dlib/unicode/unicode.cpp
@@ -0,0 +1,175 @@
+// Copyright (C) 2008 Keita Mochizuki, Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_UNICODe_CPp_
+#define DLIB_UNICODe_CPp_
+#include "unicode.h"
+#include <cwchar>
+#include "../string.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ static const unichar SURROGATE_FIRST_TOP = 0xD800;
+ static const unichar SURROGATE_SECOND_TOP = 0xDC00;
+ static const unichar SURROGATE_CLEARING_MASK = 0x03FF;
+ static const unichar SURROGATE_TOP = SURROGATE_FIRST_TOP;
+ static const unichar SURROGATE_END = 0xE000;
+ static const unichar SMP_TOP = 0x10000;
+ static const int VALID_BITS = 10;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T> bool is_surrogate(T ch)
+ {
+ return (zero_extend_cast<unichar>(ch) >= SURROGATE_TOP &&
+ zero_extend_cast<unichar>(ch) < SURROGATE_END);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T> unichar surrogate_pair_to_unichar(T first, T second)
+ {
+ return ((first & SURROGATE_CLEARING_MASK) << VALID_BITS) | ((second & SURROGATE_CLEARING_MASK) + SMP_TOP);
+ }
+ //110110 0000000000
+ //110111 0000000000
+
+// ----------------------------------------------------------------------------------------
+
+ void unichar_to_surrogate_pair(unichar input, unichar &first, unichar &second)
+ {
+ first = ((input - SMP_TOP) >> VALID_BITS) | SURROGATE_FIRST_TOP;
+ second = (input & SURROGATE_CLEARING_MASK) | SURROGATE_SECOND_TOP;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <int N> void wstr2ustring_t(const wchar_t *src, size_t src_len, ustring &dest);
+
+ template <> void wstr2ustring_t<4>(const wchar_t *src, size_t , ustring &dest)
+ {
+ dest.assign((const unichar *)(src));
+ }
+
+ template <> void wstr2ustring_t<2>(const wchar_t *src, size_t src_len, ustring &dest)
+ {
+ size_t wlen = 0;
+ for (size_t i = 0; i < src_len; i++)
+ {
+ is_surrogate(src[i]) ? i++, wlen++ : wlen++;
+ }
+ dest.resize(wlen);
+ for (size_t i = 0, ii = 0; ii < src_len; ++i)
+ {
+ if (is_surrogate(src[ii]))
+ {
+ dest[i] = surrogate_pair_to_unichar(src[ii], src[ii+1]);
+ ii += 2;
+ }else
+ {
+ dest[i] = zero_extend_cast<unichar>(src[ii]);
+ ii++;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const ustring convert_wstring_to_utf32(const std::wstring &src)
+ {
+ ustring dest;
+ wstr2ustring_t<sizeof(wchar_t)>(src.c_str(), src.size(), dest);
+ return dest;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <int N> struct ustring2wstr
+ {
+ };
+
+ // for the environment of sizeof(wchar_t) == 2 (i.e. Win32)
+ template <> struct ustring2wstr<2>
+ {
+ wchar_t *wstr;
+ size_t wlen;
+ ustring2wstr(const ustring &src){
+ wlen = 0;
+ for (size_t i = 0; i < src.length(); ++i)
+ {
+ if (src[i] < SMP_TOP) wlen++;
+ else wlen += 2;
+ }
+ wstr = new wchar_t[wlen+1];
+ wstr[wlen] = L'\0';
+
+ size_t wi = 0;
+ for (size_t i = 0; i < src.length(); ++i)
+ {
+ if (src[i] < SMP_TOP)
+ {
+ wstr[wi++] = (wchar_t)src[i];
+ }else
+ {
+ unichar high, low;
+ unichar_to_surrogate_pair(src[i], high, low);
+ wstr[wi++] = (wchar_t)high;
+ wstr[wi++] = (wchar_t)low;
+ }
+ }
+ }
+ ~ustring2wstr()
+ {
+ delete[] wstr;
+ }
+ };
+
+ // for the environment of sizeof(wchar_t) == 4 (i.e. Unix gcc)
+ template <> struct ustring2wstr<4>
+ {
+ const wchar_t *wstr;
+ size_t wlen;
+ ustring2wstr(const ustring &src){
+ wstr = (const wchar_t *)(src.c_str());
+ wlen = src.size();
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring convert_utf32_to_wstring(const ustring &src)
+ {
+ ustring2wstr<sizeof(wchar_t)> conv(src);
+ std::wstring dest(conv.wstr);
+ return dest;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring convert_mbstring_to_wstring(const std::string &src)
+ {
+ std::vector<wchar_t> wstr(src.length()+5);
+ std::mbstowcs(&wstr[0], src.c_str(), src.length()+1);
+ return std::wstring(&wstr[0]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string convert_wstring_to_mbstring(const std::wstring &src)
+ {
+ using namespace std;
+ std::string str;
+ str.resize((src.length() + 1) * MB_CUR_MAX);
+ wcstombs(&str[0], src.c_str(), str.size());
+ return std::string(&str[0]);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_UNICODe_CPp_
+
diff --git a/ml/dlib/dlib/unicode/unicode.h b/ml/dlib/dlib/unicode/unicode.h
new file mode 100644
index 000000000..d7510e34a
--- /dev/null
+++ b/ml/dlib/dlib/unicode/unicode.h
@@ -0,0 +1,622 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_UNICODe_H_
+#define DLIB_UNICODe_H_
+
+#include "../uintn.h"
+#include "../algs.h"
+#include "unicode_abstract.h"
+#include <string>
+#include <cstring>
+
+#include <fstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ typedef uint32 unichar;
+
+#if defined(__GNUC__) && __GNUC__ < 4 && __GNUC_MINOR__ < 4
+ struct unichar_traits
+ {
+ typedef dlib::unichar char_type;
+ typedef dlib::unichar int_type;
+ typedef std::streamoff off_type;
+ typedef std::streampos pos_type;
+ typedef std::mbstate_t state_type;
+
+ static void assign(char_type& c1, const char_type& c2) { c1 = c2; }
+ static bool eq(const char_type& c1, const char_type& c2) { return c1 == c2; }
+ static bool lt(const char_type& c1, const char_type& c2) { return c1 < c2; }
+ static int compare(const char_type* s1, const char_type* s2, size_t n)
+ {
+ for (size_t i = 0; i < n; ++i)
+ {
+ if (s1[i] < s2[i])
+ return -1;
+ else if (s1[i] > s2[i])
+ return 1;
+ }
+ return 0;
+ }
+
+ static size_t length(const char_type* s)
+ {
+ size_t i = 0;
+ while (s[i] != 0)
+ ++i;
+ return i;
+ }
+
+ static const char_type* find(const char_type* s, size_t n,
+ const char_type& a)
+ {
+ for (size_t i = 0; i < n; ++i)
+ {
+ if (s[i] == a)
+ {
+ return s+i;
+ }
+ }
+ return 0;
+ }
+
+ static char_type* move(char_type* s1, const char_type* s2, size_t n)
+ {
+ return static_cast<char_type*>(std::memmove(s1, s2, sizeof(char_type)*n));
+ }
+
+ static char_type* copy(char_type* s1, const char_type* s2, size_t n)
+ {
+ for (size_t i = 0; i < n; ++i)
+ s1[i] = s2[i];
+
+ return s1;
+ }
+
+ static char_type* assign(char_type* s, size_t n, char_type a)
+ {
+ for (size_t i = 0; i < n; ++i)
+ s[i] = a;
+
+ return s;
+ }
+
+
+ static int_type not_eof(const int_type& c)
+ {
+ if (!eq_int_type(c,eof()))
+ return to_int_type(c);
+ else
+ return 0;
+ }
+
+ static char_type to_char_type(const int_type& c) { return static_cast<char_type>(c); }
+ static int_type to_int_type(const char_type& c) { return zero_extend_cast<int_type>(c); }
+
+ static bool eq_int_type(const int_type& c1, const int_type& c2) { return c1 == c2; }
+
+ static int_type eof() { return static_cast<int_type>(EOF); }
+ };
+
+ typedef std::basic_string<unichar, unichar_traits> ustring;
+#else
+ typedef std::basic_string<unichar> ustring;
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ namespace unicode_helpers
+ {
+
+ template <
+ typename charT
+ >
+ int u8_to_u32(
+ charT& result,
+ std::istream& in
+ )
+ /*!
+ ensures
+ - if (there just wasn't any more data and we hit EOF) then
+ - returns 0
+ - else if (we decoded another character without error) then
+ - #result == the decoded character
+ - returns the number of bytes consumed to make this character
+ - else
+ - some error occurred
+ - returns -1
+ !*/
+ {
+ int val = in.get();
+ if (val == EOF)
+ return 0;
+
+ unichar ch[4];
+ ch[0] = zero_extend_cast<unichar>(val);
+ if ( ch[0] < 0x80 )
+ {
+ result = static_cast<charT>(ch[0]);
+ return 1;
+ }
+ if ( ( ch[0] & ~0x3F ) == 0x80 )
+ {
+ // invalid leading byte
+ return -1;
+ }
+ if ( ( ch[0] & ~0x1F ) == 0xC0 )
+ {
+ val = in.get();
+ if ( val == EOF )
+ return -1;
+
+ ch[1] = zero_extend_cast<unichar>(val);
+ if ( ( ch[1] & ~0x3F ) != 0x80 )
+ return -1; // invalid tail
+ if ( ( ch[0] & ~0x01 ) == 0xC0 )
+ return -1; // overlong form
+ ch[0] &= 0x1F;
+ ch[1] &= 0x3F;
+ result = static_cast<charT>(( ch[0] << 6 ) | ch[1]);
+ return 2;
+ }
+ if ( ( ch[0] & ~0x0F ) == 0xE0 )
+ {
+ for ( unsigned n = 1;n < 3;n++ )
+ {
+ val = in.get();
+ if ( val == EOF )
+ return -1;
+ ch[n] = zero_extend_cast<unichar>(val);
+ if ( ( ch[n] & ~0x3F ) != 0x80 )
+ return -1; // invalid tail
+ ch[n] &= 0x3F;
+ }
+ ch[0] &= 0x0F;
+ result = static_cast<charT>(( ch[0] << 12 ) | ( ch[1] << 6 ) | ch[2]);
+ if ( result < 0x0800 )
+ return -1; // overlong form
+ if ( result >= 0xD800 && result < 0xE000 )
+ return -1; // invalid character (UTF-16 surrogate pairs)
+ if ( result >= 0xFDD0 && result <= 0xFDEF )
+ return -1; // noncharacter
+ if ( result >= 0xFFFE )
+ return -1; // noncharacter
+ return 3;
+ }
+ if ( ( ch[0] & ~0x07 ) == 0xF0 )
+ {
+ for ( unsigned n = 1;n < 4;n++ )
+ {
+ val = in.get();
+ if ( val == EOF )
+ return -1;
+ ch[n] = zero_extend_cast<unichar>(val);
+ if ( ( ch[n] & ~0x3F ) != 0x80 )
+ return -1; // invalid tail
+ ch[n] &= 0x3F;
+ }
+ if ( ( ch[0] ^ 0xF6 ) < 4 )
+ return -1;
+ ch[0] &= 0x07;
+ result = static_cast<charT>(( ch[0] << 18 ) | ( ch[1] << 12 ) | ( ch[2] << 6 ) | ch[3]);
+ if ( result < 0x10000 )
+ return -1; // overlong form
+ if ( (result & 0xFFFF) >= 0xFFFE )
+ return -1; // noncharacter
+ return 4;
+ }
+ return -1;
+ }
+
+ // ------------------------------------------------------------------------------------
+
+ template <typename charT>
+ class basic_utf8_streambuf : public std::basic_streambuf<charT>
+ {
+ public:
+ basic_utf8_streambuf (
+ std::ifstream& fin_
+ ) :
+ fin(fin_)
+ {
+ this->setg(in_buffer+max_putback,
+ in_buffer+max_putback,
+ in_buffer+max_putback);
+ }
+
+ protected:
+
+ typedef typename std::basic_streambuf<charT>::int_type int_type;
+
+ // input functions
+ int_type underflow(
+ )
+ {
+ if (this->gptr() < this->egptr())
+ {
+ return zero_extend_cast<int_type>(*this->gptr());
+ }
+
+ int num_put_back = static_cast<int>(this->gptr() - this->eback());
+ if (num_put_back > max_putback)
+ {
+ num_put_back = max_putback;
+ }
+
+ // copy the putback characters into the putback end of the in_buffer
+ std::memmove(in_buffer+(max_putback-num_put_back), this->gptr()-num_put_back, num_put_back);
+
+
+ // fill the buffer with characters
+ int n = in_buffer_size-max_putback;
+ int i;
+ for (i = 0; i < n; ++i)
+ {
+ charT ch;
+ if (unicode_helpers::u8_to_u32(ch,fin) > 0)
+ {
+ (in_buffer+max_putback)[i] = ch;
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ if (i == 0)
+ {
+ // an error occurred or we hit EOF
+ return EOF;
+ }
+
+ // reset in_buffer pointers
+ this->setg (in_buffer+(max_putback-num_put_back),
+ in_buffer+max_putback,
+ in_buffer+max_putback+i);
+
+ return zero_extend_cast<int_type>(*this->gptr());
+ }
+
+ private:
+ std::ifstream& fin;
+ static const int max_putback = 4;
+ static const int in_buffer_size = 10;
+ charT in_buffer[in_buffer_size];
+ };
+ }
+
+// ----------------------------------------------------------------------------------------
+#if defined(__GNUC__) && __GNUC__ >= 6
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmisleading-indentation"
+#endif
+ template <typename T>
+ bool is_combining_char(
+ const T ch_
+ )
+ {
+ const unichar ch = zero_extend_cast<unichar>(ch_);
+ if ( ch < 0x300 ) return false;
+ if ( ch < 0x370 ) return true;
+
+ if ( ch < 0x800 )
+ {
+ if ( ch < 0x483 )return false;if ( ch < 0x48A )return true;
+
+ if ( ch < 0x591 )return false;if ( ch < 0x5D0 )
+ {
+ if ( ch == 0x5C0 )return false;
+ if ( ch == 0x5C3 )return false;
+ if ( ch == 0x5C6 )return false;
+ return true;
+ }
+ if ( ch < 0x610 )return false;if ( ch < 0x616 )return true;
+ if ( ch < 0x64B )return false;if ( ch < 0x660 )return true;
+
+ if ( ch == 0x670 )return true;
+
+ if ( ch < 0x6D6 )return false;if ( ch < 0x6EE )
+ {
+ if ( ch == 0x6DD )return false;
+ if ( ch == 0x6E5 )return false;
+ if ( ch == 0x6E6 )return false;
+ if ( ch == 0x6E9 )return false;
+ return true;
+ }
+ if ( ch == 0x711 )return true;
+
+ if ( ch < 0x730 )return false;if ( ch < 0x74B )return true;
+ if ( ch < 0x7A6 )return false;if ( ch < 0x7B1 )return true;
+ if ( ch < 0x7EB )return false;if ( ch < 0x7F4 )return true;
+ return false;
+ }
+ if ( ch < 0xA00 )
+ {
+ if ( ch < 0x901 )return false;if ( ch < 0x904 )return true;
+ if ( ch < 0x93C )return false;if ( ch < 0x955 )
+ {
+ if ( ch == 0x93D )return false;
+ if ( ch == 0x950 )return false;
+ return true;
+ }
+ if ( ch < 0x962 )return false;if ( ch < 0x964 )return true;
+ if ( ch < 0x981 )return false;if ( ch < 0x984 )return true;
+ if ( ch < 0x9BC )return false;if ( ch < 0x9D8 )
+ {
+ if ( ch == 0x9BD )return false;
+ if ( ch == 0x9CE )return false;
+ return true;
+ }
+ if ( ch < 0x9E2 )return false;if ( ch < 0x9E4 )return true;
+ return false;
+ }
+ if ( ch < 0xC00 )
+ {
+ if ( ch < 0xA01 )return false;if ( ch < 0xA04 )return true;
+ if ( ch < 0xA3C )return false;if ( ch < 0xA4E )return true;
+ if ( ch < 0xA70 )return false;if ( ch < 0xA72 )return true;
+ if ( ch < 0xA81 )return false;if ( ch < 0xA84 )return true;
+ if ( ch < 0xABC )return false;if ( ch < 0xACE )
+ {
+ if ( ch == 0xABD )return false;
+ return true;
+ }
+ if ( ch < 0xAE2 )return false;if ( ch < 0xAE4 )return true;
+ if ( ch < 0xB01 )return false;if ( ch < 0xB04 )return true;
+ if ( ch < 0xB3C )return false;if ( ch < 0xB58 )
+ {
+ if ( ch == 0xB3D )return false;
+ return true;
+ }
+ if ( ch == 0xB82 )return true;
+
+ if ( ch < 0xBBE )return false;if ( ch < 0xBD8 )return true;
+
+ if ( ch == 0xBF4 )return true;
+ if ( ch == 0xBF8 )return true;
+ return false;
+ }
+ if(ch < 0xE00)
+ {
+ if ( ch < 0xC01 )return false;if ( ch < 0xC04 )return true;
+ if ( ch < 0xC3E )return false;if ( ch < 0xC57 )return true;
+ if ( ch < 0xC82 )return false;if ( ch < 0xC84 )return true;
+ if ( ch < 0xCBC )return false;if ( ch < 0xCD7 )
+ {
+ if ( ch == 0xCBD )return false;
+ return true;
+ }
+ if ( ch < 0xCE2 )return false;if ( ch < 0xCE4 )return true;
+ if ( ch < 0xD02 )return false;if ( ch < 0xD04 )return true;
+ if ( ch < 0xD3E )return false;if ( ch < 0xD58 )return true;
+ if ( ch < 0xD82 )return false;if ( ch < 0xD84 )return true;
+ if ( ch < 0xDCA )return false;if ( ch < 0xDF4 )return true;
+ return false;
+ }
+ if(ch < 0x1000)
+ {
+ if ( ch == 0xE31 )return true;
+
+ if ( ch < 0xE34 )return false;if ( ch < 0xE3B )return true;
+ if ( ch < 0xE47 )return false;if ( ch < 0xE4F )return true;
+
+ if ( ch == 0xEB1 )return true;
+
+ if ( ch < 0xEB4 )return false;if ( ch < 0xEBD )return true;
+ if ( ch < 0xEC8 )return false;if ( ch < 0xECE )return true;
+ if ( ch < 0xF18 )return false;if ( ch < 0xF1A )return true;
+
+ if ( ch == 0xF35 )return true;
+ if ( ch == 0xF37 )return true;
+ if ( ch == 0xF39 )return true;
+
+ if ( ch < 0xF3E )return false;if ( ch < 0xF40 )return true;
+ if ( ch < 0xF71 )return false;if ( ch < 0xF88 )
+ {
+ if ( ch == 0xF85 )return false;
+ return true;
+ }
+ if ( ch < 0xF90 )return false;if ( ch < 0xFBD )return true;
+
+ if ( ch == 0xFC6 )return true;
+ return false;
+ }
+ if ( ch < 0x1800 )
+ {
+ if ( ch < 0x102C )return false;if ( ch < 0x1040 )return true;
+ if ( ch < 0x1056 )return false;if ( ch < 0x105A )return true;
+
+ if ( ch == 0x135F )return true;
+
+ if ( ch < 0x1712 )return false;if ( ch < 0x1715 )return true;
+ if ( ch < 0x1732 )return false;if ( ch < 0x1735 )return true;
+ if ( ch < 0x1752 )return false;if ( ch < 0x1754 )return true;
+ if ( ch < 0x1772 )return false;if ( ch < 0x1774 )return true;
+ if ( ch < 0x17B6 )return false;if ( ch < 0x17D4 )return true;
+
+ if ( ch == 0x17DD )return true;
+ return false;
+ }
+ if(ch < 0x2000)
+ {
+ if ( ch < 0x180B )return false;if ( ch < 0x180E )return true;
+
+ if ( ch == 0x18A9 )return true;
+
+ if ( ch < 0x1920 )return false;if ( ch < 0x193C )return true;
+ if ( ch < 0x19B0 )return false;if ( ch < 0x19C1 )return true;
+ if ( ch < 0x19C8 )return false;if ( ch < 0x19CA )return true;
+ if ( ch < 0x1A17 )return false;if ( ch < 0x1A1C )return true;
+ if ( ch < 0x1B00 )return false;if ( ch < 0x1B05 )return true;
+ if ( ch < 0x1B34 )return false;if ( ch < 0x1B45 )return true;
+ if ( ch < 0x1B6B )return false;if ( ch < 0x1B74 )return true;
+ if ( ch < 0x1DC0 )return false;if ( ch < 0x1E00 )return true;
+ return false;
+ }
+ if ( ch < 0x20D0 )return false;if ( ch < 0x2100 )return true;
+ if ( ch < 0x302A )return false;if ( ch < 0x3030 )return true;
+ if ( ch < 0x3099 )return false;if ( ch < 0x309B )return true;
+
+ if ( ch == 0xA802 )return true;
+ if ( ch == 0xA806 )return true;
+ if ( ch == 0xA80B )return true;
+
+ if ( ch < 0xA823 )return false;if ( ch < 0xA828 )return true;
+
+ if ( ch == 0xFB1E )return true;
+
+ if ( ch < 0xFE00 )return false;if ( ch < 0xFE10 )return true;
+ if ( ch < 0xFE20 )return false;if ( ch < 0xFE30 )return true;
+ if ( ch < 0x10A01 )return false;if ( ch < 0x10A10 )return true;
+ if ( ch < 0x10A38 )return false;if ( ch < 0x10A40 )return true;
+ if ( ch < 0x1D165 )return false;if ( ch < 0x1D16A )return true;
+ if ( ch < 0x1D16D )return false;if ( ch < 0x1D173 )return true;
+ if ( ch < 0x1D17B )return false;if ( ch < 0x1D183 )return true;
+ if ( ch < 0x1D185 )return false;if ( ch < 0x1D18C )return true;
+ if ( ch < 0x1D1AA )return false;if ( ch < 0x1D1AE )return true;
+ if ( ch < 0x1D242 )return false;if ( ch < 0x1D245 )return true;
+ if ( ch < 0xE0100 )return false;if ( ch < 0xE01F0 )return true;
+ return false;
+ }
+#if defined(__GNUC__) && __GNUC__ >= 6
+#pragma GCC diagnostic pop
+#endif
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_utf8_error : public error
+ {
+ public:
+ invalid_utf8_error():error(EUTF8_TO_UTF32) {}
+ };
+
+ inline const ustring convert_utf8_to_utf32 (
+ const std::string& str
+ )
+ {
+ using namespace unicode_helpers;
+ ustring temp;
+ std::istringstream sin(str);
+
+ temp.reserve(str.size());
+
+ int status;
+ unichar ch;
+ while ( (status = u8_to_u32(ch,sin)) > 0)
+ temp.push_back(ch);
+
+ if (status < 0)
+ throw invalid_utf8_error();
+
+ return temp;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_surrogate(unichar ch);
+
+ unichar surrogate_pair_to_unichar(unichar first, unichar second);
+
+ void unichar_to_surrogate_pair(unichar unicode, unichar &first, unichar &second);
+
+
+ const ustring convert_wstring_to_utf32 (
+ const std::wstring &wstr
+ );
+
+ const std::wstring convert_utf32_to_wstring (
+ const ustring &src
+ );
+
+ const std::wstring convert_mbstring_to_wstring (
+ const std::string &src
+ );
+
+ const std::string convert_wstring_to_mbstring(
+ const std::wstring &src
+ );
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename charT>
+ class basic_utf8_ifstream : public std::basic_istream<charT>
+ {
+ public:
+
+ basic_utf8_ifstream (
+ ) : std::basic_istream<charT>(&buf), buf(fin) {}
+
+ basic_utf8_ifstream (
+ const char* file_name,
+ std::ios_base::openmode mode = std::ios::in
+ ) :
+ std::basic_istream<charT>(&buf),
+ buf(fin)
+ {
+ fin.open(file_name,mode);
+ // make this have the same error state as fin
+ this->clear(fin.rdstate());
+ }
+
+ basic_utf8_ifstream (
+ const std::string& file_name,
+ std::ios_base::openmode mode = std::ios::in
+ ) :
+ std::basic_istream<charT>(&buf),
+ buf(fin)
+ {
+ fin.open(file_name.c_str(),mode);
+ // make this have the same error state as fin
+ this->clear(fin.rdstate());
+ }
+
+ void open(
+ const std::string& file_name,
+ std::ios_base::openmode mode = std::ios::in
+ )
+ {
+ open(file_name.c_str(),mode);
+ }
+
+ void open (
+ const char* file_name,
+ std::ios_base::openmode mode = std::ios::in
+ )
+ {
+ fin.close();
+ fin.clear();
+ fin.open(file_name,mode);
+ // make this have the same error state as fin
+ this->clear(fin.rdstate());
+ }
+
+ void close (
+ )
+ {
+ fin.close();
+ // make this have the same error state as fin
+ this->clear(fin.rdstate());
+ }
+
+ private:
+
+ std::ifstream fin;
+ unicode_helpers::basic_utf8_streambuf<charT> buf;
+ };
+
+ typedef basic_utf8_ifstream<unichar> utf8_uifstream;
+ typedef basic_utf8_ifstream<wchar_t> utf8_wifstream;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#ifdef NO_MAKEFILE
+#include "unicode.cpp"
+#endif
+
+#endif // DLIB_UNICODe_H_
+
diff --git a/ml/dlib/dlib/unicode/unicode_abstract.h b/ml/dlib/dlib/unicode/unicode_abstract.h
new file mode 100644
index 000000000..ed5b9ab4e
--- /dev/null
+++ b/ml/dlib/dlib/unicode/unicode_abstract.h
@@ -0,0 +1,233 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net), and Nils Labugt
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_UNICODe_ABSTRACT_H_
+#ifdef DLIB_UNICODe_ABSTRACT_H_
+
+#include "../uintn.h"
+#include "../error.h"
+#include <string>
+#include <fstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // a typedef for an unsigned 32bit integer to hold our UNICODE characters
+ typedef uint32 unichar;
+
+ // a typedef for a string object to hold our UNICODE strings
+ typedef std::basic_string<unichar> ustring;
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool is_combining_char(
+ const T ch_
+ );
+ /*!
+ ensures
+ - if (ch_ is a unicode combining character) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ bool is_surrogate(
+ unichar ch
+ );
+ /*!
+ ensures
+ - if (ch is a unicode surrogate character) then
+ - returns true
+ - else
+ - returns false
+ !*/
+
+ unichar surrogate_pair_to_unichar(
+ unichar first,
+ unichar second
+ );
+ /*!
+ requires
+ - 0xD800 <= first < 0xDC00
+ - 0xDC00 <= second < 0xE000
+ - is_surrogate(first) == true
+ - is_surrogate(second) == true
+ ensures
+ - converts two surrogates into one unicode character
+ !*/
+
+ void unichar_to_surrogate_pair(
+ unichar ch,
+ unichar& first,
+ unichar& second
+ );
+ /*!
+ requires
+ - ch >= 0x10000 (i.e. is not in Basic Multilingual Plane)
+ ensures
+ - surrogate_pair_to_unichar(#first,#second) == ch
+ (i.e. converts ch into two surrogate characters)
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ class invalid_utf8_error : public error
+ {
+ public:
+ invalid_utf8_error():error(EUTF8_TO_UTF32) {}
+ };
+
+ const ustring convert_utf8_to_utf32 (
+ const std::string& str
+ );
+ /*!
+ ensures
+ - if (str is a valid UTF-8 encoded string) then
+ - returns a copy of str that has been converted into a
+ unichar string
+ - else
+ - throws invalid_utf8_error
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const ustring convert_wstring_to_utf32 (
+ const std::wstring &wstr
+ );
+ /*!
+ requires
+ - wstr is a valid UTF-16 string when sizeof(wchar_t) == 2
+ - wstr is a valid UTF-32 string when sizeof(wchar_t) == 4
+ ensures
+ - converts wstr into UTF-32 string
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring convert_utf32_to_wstring (
+ const ustring &str
+ );
+ /*!
+ requires
+ - str is a valid UTF-32 encoded string
+ ensures
+ - converts str into wstring whose encoding is UTF-16 when sizeof(wchar_t) == 2
+ - converts str into wstring whose encoding is UTF-32 when sizeof(wchar_t) == 4
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::wstring convert_mbstring_to_wstring (
+ const std::string &str
+ );
+ /*!
+ requires
+ - str is a valid multibyte string whose encoding is same as current locale setting
+ ensures
+ - converts str into wstring whose encoding is UTF-16 when sizeof(wchar_t) == 2
+ - converts str into wstring whose encoding is UTF-32 when sizeof(wchar_t) == 4
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ const std::string convert_wstring_to_mbstring (
+ const std::wstring &src
+ );
+ /*!
+ requires
+ - str is a valid wide character string string whose encoding is same as current
+ locale setting
+ ensures
+ - returns a multibyte encoded version of the given string
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename charT
+ >
+ class basic_utf8_ifstream : public std::basic_istream<charT>
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This object represents an input file stream much like the
+ normal std::ifstream except that it knows how to read UTF-8
+ data. So when you read characters out of this stream it will
+ automatically convert them from the UTF-8 multibyte encoding
+ into a fixed width wide character encoding.
+ !*/
+
+ public:
+
+ basic_utf8_ifstream (
+ );
+ /*!
+ ensures
+ - constructs an input stream that isn't yet associated with
+ a file.
+ !*/
+
+ basic_utf8_ifstream (
+ const char* file_name,
+ std::ios_base::openmode mode = std::ios::in
+ );
+ /*!
+ ensures
+ - tries to open the given file for reading by this stream
+ - mode is interpreted exactly the same was as the open mode
+ argument used by std::ifstream.
+ !*/
+
+ basic_utf8_ifstream (
+ const std::string& file_name,
+ std::ios_base::openmode mode = std::ios::in
+ );
+ /*!
+ ensures
+ - tries to open the given file for reading by this stream
+ - mode is interpreted exactly the same was as the open mode
+ argument used by std::ifstream.
+ !*/
+
+ void open(
+ const std::string& file_name,
+ std::ios_base::openmode mode = std::ios::in
+ );
+ /*!
+ ensures
+ - tries to open the given file for reading by this stream
+ - mode is interpreted exactly the same was as the open mode
+ argument used by std::ifstream.
+ !*/
+
+ void open (
+ const char* file_name,
+ std::ios_base::openmode mode = std::ios::in
+ );
+ /*!
+ ensures
+ - tries to open the given file for reading by this stream
+ - mode is interpreted exactly the same was as the open mode
+ argument used by std::ifstream.
+ !*/
+
+ void close (
+ );
+ /*!
+ ensures
+ - any file opened by this stream has been closed
+ !*/
+ };
+
+ typedef basic_utf8_ifstream<unichar> utf8_uifstream;
+ typedef basic_utf8_ifstream<wchar_t> utf8_wifstream;
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_UNICODe_ABSTRACT_H_
+
+
diff --git a/ml/dlib/dlib/unordered_pair.h b/ml/dlib/dlib/unordered_pair.h
new file mode 100644
index 000000000..9ea75b912
--- /dev/null
+++ b/ml/dlib/dlib/unordered_pair.h
@@ -0,0 +1,176 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_UNORDERED_PAiR_Hh_
+#define DLIB_UNORDERED_PAiR_Hh_
+
+#include "serialize.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct unordered_pair
+ {
+ /*!
+ REQUIREMENTS ON T
+ T must be default constructable, copyable, and comparable using
+ operator < and ==
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is very similar to the std::pair struct except unordered_pair
+ is only capable of representing an unordered set of two items rather than
+ an ordered list of two items like std::pair.
+
+ This is best illustrated by example. Suppose we have the following
+ five variables:
+ std::pair<int,int> p1(1, 5), p2(5,1);
+ unordered_pair<int> up1(1,5), up2(5,1), up3(6,7);
+
+ Then it is the case that:
+ up1 == up2
+ up1 != up3
+ p1 != p2
+
+ So the unordered_pair doesn't care about the order of the arguments.
+ In this case, up1 and up2 are both equivalent.
+
+ !*/
+
+ typedef T type;
+ typedef T first_type;
+ typedef T second_type;
+
+ const T first;
+ const T second;
+
+ unordered_pair() : first(), second()
+ /*!
+ ensures
+ - #first and #second are default initialized
+ !*/ {}
+
+ unordered_pair(
+ const T& a,
+ const T& b
+ ) :
+ first( a < b ? a : b),
+ second(a < b ? b : a)
+ /*!
+ ensures
+ - #first <= #second
+ - #first and #second contain copies of the items a and b.
+ !*/ {}
+
+ template <typename U>
+ unordered_pair (
+ const unordered_pair <U>& p
+ ) :
+ first(p.first),
+ second(p.second)
+ /*!
+ ensures
+ - #*this is a copy of p
+ !*/ {}
+
+ unordered_pair& operator= (
+ const unordered_pair& item
+ )
+ /*!
+ ensures
+ - #*this == item
+ !*/
+ {
+ const_cast<T&>(first) = item.first;
+ const_cast<T&>(second) = item.second;
+ return *this;
+ }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ bool operator==(const unordered_pair<T>& a, const unordered_pair <T>& b)
+ {
+ return a.first == b.first && a.second == b.second;
+ }
+
+ template <typename T>
+ bool operator!=(const unordered_pair<T>& a, const unordered_pair <T>& b)
+ {
+ return !(a == b);
+ }
+
+ template <typename T>
+ bool operator<(const unordered_pair<T>& a, const unordered_pair<T>& b)
+ {
+ return (a.first < b.first || (!(b.first < a.first) && a.second < b.second));
+ }
+
+ template <typename T>
+ bool operator>(const unordered_pair<T>& a, const unordered_pair <T>& b)
+ {
+ return b < a;
+ }
+
+ template <typename T>
+ bool operator<=(const unordered_pair<T>& a, const unordered_pair <T>& b)
+ {
+ return !(b < a);
+ }
+
+ template <typename T>
+ bool operator>=(const unordered_pair<T>& a, const unordered_pair <T>& b)
+ {
+ return !(a < b);
+ }
+
+ template <typename T>
+ unordered_pair<T> make_unordered_pair (const T& a, const T& b)
+ {
+ return unordered_pair<T>(a,b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void serialize (
+ const unordered_pair<T>& item,
+ std::ostream& out
+ )
+ {
+ try
+ {
+ serialize(item.first,out);
+ serialize(item.second,out);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while serializing object of type unordered_pair"); }
+ }
+
+ template <typename T>
+ void deserialize (
+ unordered_pair<T>& item,
+ std::istream& in
+ )
+ {
+ try
+ {
+ T a, b;
+ deserialize(a,in);
+ deserialize(b,in);
+ item = make_unordered_pair(a,b);
+ }
+ catch (serialization_error& e)
+ { throw serialization_error(e.info + "\n while deserializing object of type unordered_pair"); }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_UNORDERED_PAiR_Hh_
+
diff --git a/ml/dlib/dlib/vectorstream.h b/ml/dlib/dlib/vectorstream.h
new file mode 100644
index 000000000..afeb247de
--- /dev/null
+++ b/ml/dlib/dlib/vectorstream.h
@@ -0,0 +1,11 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_VECTORSTReAMh_
+#define DLIB_VECTORSTReAMh_
+
+#include "vectorstream/vectorstream.h"
+#include "vectorstream/unserialize.h"
+
+
+#endif // DLIB_VECTORSTReAMh_
+
diff --git a/ml/dlib/dlib/vectorstream/unserialize.h b/ml/dlib/dlib/vectorstream/unserialize.h
new file mode 100644
index 000000000..dbfe5584b
--- /dev/null
+++ b/ml/dlib/dlib/vectorstream/unserialize.h
@@ -0,0 +1,98 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_uNSERIALIZE_Hh_
+#define DLIB_uNSERIALIZE_Hh_
+
+#include "unserialize_abstract.h"
+
+#include "../serialize.h"
+#include "../algs.h"
+#include "vectorstream.h"
+
+
+
+namespace dlib
+{
+ class unserialize : public std::istream
+ {
+ class mystreambuf : public std::streambuf
+ {
+ typedef std::vector<char>::size_type size_type;
+ size_type read_pos; // buffer[read_pos] == next byte to read from buffer
+ public:
+ std::vector<char> buffer;
+ std::istream& str;
+
+ template <typename T>
+ mystreambuf(
+ const T& item,
+ std::istream& str_
+ ) :
+ read_pos(0),
+ str(str_)
+ {
+ // put the item into our buffer.
+ vectorstream vstr(buffer);
+ serialize(item, vstr);
+ }
+
+
+ // ------------------------ INPUT FUNCTIONS ------------------------
+
+ int_type underflow(
+ )
+ {
+ if (read_pos < buffer.size())
+ return static_cast<unsigned char>(buffer[read_pos]);
+ else
+ return str.peek();
+ }
+
+ int_type uflow(
+ )
+ {
+ if (read_pos < buffer.size())
+ return static_cast<unsigned char>(buffer[read_pos++]);
+ else
+ return str.get();
+ }
+
+ std::streamsize xsgetn (
+ char* s,
+ std::streamsize n
+ )
+ {
+ if (read_pos < buffer.size())
+ {
+ const size_type num = std::min<size_type>(n, buffer.size()-read_pos);
+ std::memcpy(s, &buffer[read_pos], num);
+ read_pos += num;
+ return num;
+ }
+ else
+ {
+ return str.rdbuf()->sgetn(s,n);
+ }
+ return 0;
+ }
+
+ };
+
+ public:
+
+ template <typename T>
+ unserialize (
+ const T& item,
+ std::istream& str
+ ) :
+ std::istream(&buf),
+ buf(item, str)
+ {}
+
+ private:
+ mystreambuf buf;
+ };
+}
+
+#endif // DLIB_uNSERIALIZE_Hh_
+
diff --git a/ml/dlib/dlib/vectorstream/unserialize_abstract.h b/ml/dlib/dlib/vectorstream/unserialize_abstract.h
new file mode 100644
index 000000000..b7d67836c
--- /dev/null
+++ b/ml/dlib/dlib/vectorstream/unserialize_abstract.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_uNSERIALIZE_ABSTRACT_Hh_
+#ifdef DLIB_uNSERIALIZE_ABSTRACT_Hh_
+
+#include "../serialize.h"
+#include <iostream>
+
+namespace dlib
+{
+ class unserialize : public std::istream
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is a tool that allows you to effectively put an object you just
+ deserialized from a stream back into the stream. Its use is best
+ illustrated via an example.
+
+ void example(std::istream& in)
+ {
+ // Suppose that in contains serialized copies of three "some_type"
+ // objects. You could read them as follows:
+ some_type obj1, obj2, obj3;
+
+ deserialize(obj1, in); // reads obj1 from stream.
+ deserialize(obj2, in); // reads obj2 from stream.
+
+ unserialize in2(obj2, in); // make the in2 stream that has obj2 at its front.
+ deserialize(obj2, in2); // reads obj2 from stream again.
+ deserialize(obj3, in2); // reads obj3 from stream.
+ }
+
+ The reason unserialize is useful is because it allows you to peek at the
+ next object in a stream and potentially do something different based on
+ what object is coming next, but still allowing subsequent deserialize()
+ statements to be undisturbed by the fact that you peeked at the data.
+ !*/
+
+ public:
+
+ template <typename T>
+ unserialize (
+ const T& item,
+ std::istream& in
+ );
+ /*!
+ requires
+ - T must be serializable
+ ensures
+ - The bytes in this stream begin with a serialized copy of item followed
+ immediately by the bytes in the given istream.
+ !*/
+ };
+}
+
+#endif // DLIB_uNSERIALIZE_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/vectorstream/vectorstream.h b/ml/dlib/dlib/vectorstream/vectorstream.h
new file mode 100644
index 000000000..a1b7b8b07
--- /dev/null
+++ b/ml/dlib/dlib/vectorstream/vectorstream.h
@@ -0,0 +1,138 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_VECTORStREAM_Hh_
+#define DLIB_VECTORStREAM_Hh_
+
+#include "vectorstream_abstract.h"
+
+#include <cstring>
+#include <iostream>
+#include <streambuf>
+#include <vector>
+#include <cstdio>
+#include "../algs.h"
+
+
+#ifdef _MSC_VER
+// Disable the warning about inheriting from std::iostream 'via dominance' since this warning is a warning about
+// visual studio conforming to the standard and is ignorable. See http://connect.microsoft.com/VisualStudio/feedback/details/733720/inheriting-from-std-fstream-produces-c4250-warning
+// for further details if interested.
+#pragma warning(disable : 4250)
+#endif // _MSC_VER
+
+namespace dlib
+{
+ class vectorstream : public std::iostream
+ {
+ class vector_streambuf : public std::streambuf
+ {
+ typedef std::vector<char>::size_type size_type;
+ size_type read_pos; // buffer[read_pos] == next byte to read from buffer
+ public:
+ std::vector<char>& buffer;
+
+ vector_streambuf(
+ std::vector<char>& buffer_
+ ) :
+ read_pos(0),
+ buffer(buffer_)
+ {}
+
+
+ void seekg(size_type pos)
+ {
+ read_pos = pos;
+ }
+
+ // ------------------------ OUTPUT FUNCTIONS ------------------------
+
+ int_type overflow ( int_type c)
+ {
+ if (c != EOF) buffer.push_back(static_cast<char>(c));
+ return c;
+ }
+
+ std::streamsize xsputn ( const char* s, std::streamsize num)
+ {
+ buffer.insert(buffer.end(), s, s+num);
+ return num;
+ }
+
+ // ------------------------ INPUT FUNCTIONS ------------------------
+
+ int_type underflow(
+ )
+ {
+ if (read_pos < buffer.size())
+ return static_cast<unsigned char>(buffer[read_pos]);
+ else
+ return EOF;
+ }
+
+ int_type uflow(
+ )
+ {
+ if (read_pos < buffer.size())
+ return static_cast<unsigned char>(buffer[read_pos++]);
+ else
+ return EOF;
+ }
+
+ int_type pbackfail(
+ int_type c
+ )
+ {
+ // if they are trying to push back a character that they didn't read last
+ // that is an error
+ const unsigned long prev = read_pos-1;
+ if (c != EOF && prev < buffer.size() &&
+ c != static_cast<unsigned char>(buffer[prev]))
+ {
+ return EOF;
+ }
+
+ read_pos = prev;
+ return 1;
+ }
+
+ std::streamsize xsgetn (
+ char* s,
+ std::streamsize n
+ )
+ {
+ if (read_pos < buffer.size())
+ {
+ const size_type num = std::min<size_type>(n, buffer.size()-read_pos);
+ std::memcpy(s, &buffer[read_pos], num);
+ read_pos += num;
+ return num;
+ }
+ return 0;
+ }
+
+ };
+
+ public:
+
+ vectorstream (
+ std::vector<char>& buffer
+ ) :
+ std::iostream(&buf),
+ buf(buffer)
+ {}
+
+ std::istream& seekg (
+ std::streampos pos
+ )
+ {
+ buf.seekg(pos);
+ return *this;
+ }
+
+ private:
+ vector_streambuf buf;
+ };
+}
+
+#endif // DLIB_VECTORStREAM_Hh_
+
diff --git a/ml/dlib/dlib/vectorstream/vectorstream_abstract.h b/ml/dlib/dlib/vectorstream/vectorstream_abstract.h
new file mode 100644
index 000000000..f5fe1004c
--- /dev/null
+++ b/ml/dlib/dlib/vectorstream/vectorstream_abstract.h
@@ -0,0 +1,62 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_VECTORStREAM_ABSTRACT_Hh_
+#ifdef DLIB_VECTORStREAM_ABSTRACT_Hh_
+
+#include <iostream>
+#include <vector>
+
+namespace dlib
+{
+ class vectorstream : public std::iostream
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an iostream object that reads and writes from an in-memory buffer.
+ It functions very much the same way as the std::stringstream object.
+ However, while the std::stringstream holds its buffer internally and it can
+ only be accessed by copying it out, the vectorstream uses an external
+ std::vector<char> as its buffer. That is, it holds a reference to an
+ external vector and does not contain any internal buffers of its own.
+
+ This object is useful as a slightly more efficient alternative to the
+ std::stringstream since you can avoid the overhead of copying buffer
+ contents to and from the stream. This is particularly useful when used as
+ a source or target for serialization routines.
+ !*/
+
+ public:
+
+ vectorstream (
+ std::vector<char>& buffer
+ );
+ /*!
+ ensures
+ - This object will use the given vector as its read/write buffer. That is:
+ - Any data written to this stream will be appended to the given buffer
+ - Any data read from this stream is read from the given buffer,
+ starting with buffer[0], then buffer[1], and so on. Just like
+ std::stringstream, writes to the stream do not move the position of
+ the next byte that will be read from the buffer.
+ - This constructor does not copy the buffer. Only a reference to it will
+ be used. Therefore, any time data is written to this stream it will
+ immediately show up in the buffer.
+ !*/
+
+ std::istream& seekg (
+ std::streampos pos
+ );
+ /*!
+ ensures
+ - The next read from this object will read from the position buffer[pos],
+ where buffer is the std::vector given to this object's constructor. Note
+ that if pos >= buffer.size() then the next read will simply return EOF.
+ - returns *this
+ !*/
+
+ };
+}
+
+#endif // DLIB_VECTORStREAM_ABSTRACT_Hh_
+
+
diff --git a/ml/dlib/dlib/windows_magic.h b/ml/dlib/dlib/windows_magic.h
new file mode 100644
index 000000000..cfb7f22ed
--- /dev/null
+++ b/ml/dlib/dlib/windows_magic.h
@@ -0,0 +1,50 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_WINDOWS_MAGIc_
+#define DLIB_WINDOWS_MAGIc_
+
+#include "platform.h"
+
+#ifdef WIN32
+
+// This file contains all the magical #defines you have to setup before you
+// include the windows header files.
+
+#ifndef NOMINMAX
+#define NOMINMAX // prevent windows from messing with std::min and std::max
+#endif
+
+// Prevent windows from #defining IN or OUT
+#ifndef _NO_W32_PSEUDO_MODIFIERS
+#define _NO_W32_PSEUDO_MODIFIERS
+#endif
+
+// now just for good measure undefine min and max if they are defined
+#ifdef min
+#undef min
+#endif
+
+#ifdef max
+#undef max
+#endif
+
+#ifdef NO_MAKEFILE
+// only define this if all the cpp files are going to be sucked into the headers
+// because otherwise we don't need it since everything is isolated in the sockets
+// cpp file and this declaration for _WINSOCKAPI_ appears there also.
+#ifndef _WINSOCKAPI_
+#define _WINSOCKAPI_ /* Prevent inclusion of winsock.h in windows.h */
+#endif
+#endif
+
+// This is something stupid you have to do to make visual studio include the right
+// stuff. I don't really know what the deal is with this.
+#if _WIN32_WINNT < 0x0500
+#undef _WIN32_WINNT
+#define _WIN32_WINNT 0x0500
+#endif
+
+#endif // WIN32
+
+#endif // DLIB_WINDOWS_MAGIc_
+
diff --git a/ml/dlib/dlib/xml_parser.h b/ml/dlib/dlib/xml_parser.h
new file mode 100644
index 000000000..4e60ed68f
--- /dev/null
+++ b/ml/dlib/dlib/xml_parser.h
@@ -0,0 +1,13 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_XML_PARSEr_
+#define DLIB_XML_PARSEr_
+
+#include <string>
+
+#include "xml_parser/xml_parser_kernel_interfaces.h"
+#include "xml_parser/xml_parser_kernel_1.h"
+
+
+#endif // DLIB_XML_PARSEr_
+
diff --git a/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h b/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h
new file mode 100644
index 000000000..e1854bc26
--- /dev/null
+++ b/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h
@@ -0,0 +1,1532 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_XML_PARSER_KERNEl_1_
+#define DLIB_XML_PARSER_KERNEl_1_
+
+
+#include "xml_parser_kernel_abstract.h"
+
+#include <sstream>
+#include <string>
+#include <fstream>
+#include <iostream>
+#include "xml_parser_kernel_interfaces.h"
+#include "../algs.h"
+#include <cstdio>
+#include "../map.h"
+#include "../stack.h"
+#include "../sequence.h"
+#include "../memory_manager.h"
+
+namespace dlib
+{
+
+ class xml_parser
+ {
+ typedef dlib::map<std::string,std::string,memory_manager<char>::kernel_2a>::kernel_1b map;
+ typedef dlib::stack<std::string,memory_manager<char>::kernel_2a>::kernel_1a stack;
+ typedef sequence<document_handler*>::kernel_2a seq_dh;
+ typedef sequence<error_handler*>::kernel_2a seq_eh;
+
+ /*!
+ INITIAL VALUE
+ dh_list.size() == 0
+ eh_list.size() == 0
+
+ CONVENTION
+ dh_list == a sequence of pointers to all the document_handlers that
+ have been added to the xml_parser
+ eh_list == a sequence of pointers to all the error_handlers that
+ have been added to the xml_parser
+
+ map is used to implement the attribute_list interface
+ stack is used just inside the parse function
+ seq_dh is used to make the dh_list member variable
+ seq_eh is used to make the eh_list member variable
+ !*/
+
+
+
+ public:
+
+ // These typedefs are here for backwards compatibly with previous versions of
+ // dlib.
+ typedef xml_parser kernel_1a;
+ typedef xml_parser kernel_1a_c;
+
+ xml_parser(
+ ) {}
+
+ virtual ~xml_parser(
+ ){}
+
+ inline void clear(
+ );
+
+ inline void parse (
+ std::istream& in
+ );
+
+ inline void add_document_handler (
+ document_handler& item
+ );
+
+ inline void add_error_handler (
+ error_handler& item
+ );
+
+
+ inline void swap (
+ xml_parser& item
+ );
+
+
+ private:
+
+ // -----------------------------------
+
+ // attribute_list interface implementation
+ class attrib_list : public attribute_list
+ {
+ public:
+ // the list of attribute name/value pairs
+ map list;
+
+ bool is_in_list (
+ const std::string& key
+ ) const
+ {
+ return list.is_in_domain(key);
+ }
+
+ const std::string& operator[] (
+ const std::string& key
+ ) const
+ {
+ if (is_in_list(key))
+ return list[key];
+ else
+ throw xml_attribute_list_error("No XML attribute named " + key + " is present in tag.");
+ }
+
+ bool at_start (
+ ) const { return list.at_start(); }
+
+ void reset (
+ ) const { return list.reset(); }
+
+ bool current_element_valid (
+ ) const { return list.current_element_valid(); }
+
+ const type& element (
+ ) const { return list.element(); }
+
+ type& element (
+ ) { return list.element(); }
+
+ bool move_next (
+ ) const { return list.move_next(); }
+
+ size_t size (
+ ) const { return list.size(); }
+ };
+
+
+ // -----------------------------------
+
+ enum token_type
+ {
+ element_start, // the first tag of an element
+ element_end, // the last tag of an element
+ empty_element, // the singular tag of an empty element
+ pi, // processing instruction
+ chars, // the non-markup data between tags
+ chars_cdata, // the data from a CDATA section
+ eof, // this token is returned when we reach the end of input
+ error, // this token indicates that the tokenizer couldn't
+ // determine which category the next token fits into
+ dtd, // this token is for an entire dtd
+ comment // this is a token for comments
+ };
+ /*
+ notes about the tokens:
+ the tokenizer guarantees that the following tokens to not
+ contain the '<' character except as the first character of the token
+ element_start, element_end, empty_element, and pi. they also only
+ contain the '>' characer as their last character.
+
+ it is also guaranteed that pi is at least of the form <??>. that
+ is to say that it always always begins with <? and ends with ?>.
+
+ it is also guaranteed that all markup tokens will begin with the '<'
+ character and end with the '>'. there won't be any leading or
+ trailing whitespaces. this whitespace is considered a chars token.
+ */
+
+
+ // private member functions
+ inline void get_next_token(
+ std::istream& in,
+ std::string& token_text,
+ int& token_kind,
+ unsigned long& line_number
+ );
+ /*!
+ ensures
+ gets the next token from in and puts it in token_text and
+ token_kind == the kind of the token found and
+ line_number is incremented every time a '\n' is encountered and
+ entity references are translated into the characters they represent
+ only for chars tokens
+ !*/
+
+ inline int parse_element (
+ const std::string& token,
+ std::string& name,
+ attrib_list& atts
+ );
+ /*!
+ requires
+ token is a token of kind start_element or empty_element
+ ensures
+ gets the element name and puts it into the string name and
+ parses out the attributes and puts them into the attribute_list atts
+
+ return 0 upon success or
+ returns -1 if it failed to parse token
+ !*/
+
+ inline int parse_pi (
+ const std::string& token,
+ std::string& target,
+ std::string& data
+ );
+ /*!
+ requires
+ token is a token of kind pi
+ ensures
+ the target from the processing instruction is put into target and
+ the data from the processing instruction is put into data
+
+ return 0 upon success or
+ returns -1 if it failed to parse token
+ !*/
+
+ inline int parse_element_end (
+ const std::string& token,
+ std::string& name
+ );
+ /*!
+ requires
+ token is a token of kind element_end
+ ensures
+ the name from the ending element tag is put into the string name
+
+ return 0 upon success or
+ returns -1 if it failed to parse token
+ !*/
+
+ inline int change_entity (
+ std::istream& in
+ );
+ /*!
+ ensures
+ performs the following translations and returns the new character
+ amp; -> &
+ lt; -> <
+ gt; -> >
+ apos; -> '
+ quot; -> "
+
+ or returns -1 if we hit an undefined entity reference or EOF.
+ (i.e. it was not one of the entities listed above)
+
+ !*/
+
+ // -----------------------------------
+
+ // private member data
+ seq_dh dh_list;
+ seq_eh eh_list;
+
+ // -----------------------------------
+
+ // restricted functions: assignment and copy construction
+ xml_parser(xml_parser&);
+ xml_parser& operator= (
+ xml_parser&
+ );
+
+ };
+
+ inline void swap (
+ xml_parser& a,
+ xml_parser& b
+ ) { a.swap(b); }
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ clear(
+ )
+ {
+ // unregister all event handlers
+ eh_list.clear();
+ dh_list.clear();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ parse (
+ std::istream& in
+ )
+ {
+ DLIB_CASSERT ( in.fail() == false ,
+ "\tvoid xml_parser::parse"
+ << "\n\tthe input stream must not be in the fail state"
+ << "\n\tthis: " << this
+ );
+
+
+ // save which exceptions in will throw and make it so it won't throw any
+ // for the life of this function
+ std::ios::iostate old_exceptions = in.exceptions();
+ // set it to not throw anything
+ in.exceptions(std::ios::goodbit);
+
+
+ try
+ {
+ unsigned long line_number = 1;
+
+ // skip any whitespace before the start of the document
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n' || in.peek() == '\r' )
+ {
+ if (in.peek() == '\n')
+ ++line_number;
+ in.get();
+ }
+
+
+
+ stack tags; // this stack contains the last start tag seen
+ bool seen_fatal_error = false;
+ bool seen_root_tag = false; // this is true after we have seen the root tag
+
+
+
+ // notify all the document_handlers that we are about to being parsing
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->start_document();
+ }
+
+
+ std::string chars_buf; // used to collect chars data between consecutive
+ // chars and chars_cdata tokens so that
+ // document_handlers receive all chars data between
+ // tags in one call
+
+ // variables to be used with the parsing functions
+ attrib_list atts;
+ std::string name;
+ std::string target;
+ std::string data;
+
+
+
+ // variables to use with the get_next_token() function
+ std::string token_text;
+ int token_kind;
+
+ get_next_token(in,token_text,token_kind,line_number);
+
+
+ while (token_kind != eof)
+ {
+ bool is_empty = false; // this becomes true when this token is an empty_element
+
+ switch (token_kind)
+ {
+
+
+ case empty_element: is_empty = true;
+ // fall through
+ case element_start:
+ {
+ seen_root_tag = true;
+
+ int status = parse_element(token_text,name,atts);
+ // if there was no error parsing the element
+ if (status == 0)
+ {
+ // notify all the document_handlers
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->start_element(line_number,name,atts);
+ if (is_empty)
+ dh_list[i]->end_element(line_number,name);
+ }
+ }
+ else
+ {
+ seen_fatal_error = true;
+ }
+
+ // if this is an element_start token then push the name of
+ // the element on to the stack
+ if (token_kind == element_start)
+ {
+ tags.push(name);
+ }
+
+ }break;
+
+ // ----------------------------------------
+
+ case element_end:
+ {
+
+ int status = parse_element_end (token_text,name);
+
+ // if there was no error parsing the element
+ if (status == 0)
+ {
+ // make sure this ending element tag matches the last start
+ // element tag we saw
+ if ( tags.size() == 0 || name != tags.current())
+ {
+ // they don't match so signal a fatal error
+ seen_fatal_error = true;
+ }
+ else
+ {
+ // notify all the document_handlers
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->end_element(line_number,name);
+ }
+
+ // they match so throw away this element name
+ tags.pop(name);
+ }
+ }
+ else
+ {
+ seen_fatal_error = true;
+ }
+
+
+ }break;
+
+ // ----------------------------------------
+
+ case pi:
+ {
+
+ int status = parse_pi (token_text,target,data);
+ // if there was no error parsing the element
+ if (status == 0)
+ {
+ // notify all the document_handlers
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->processing_instruction(line_number,target,data);
+ }
+ }
+ else
+ {
+ // notify all the error_handlers
+ for (unsigned long i = 0; i < eh_list.size(); ++i)
+ {
+ eh_list[i]->error(line_number);
+ }
+ }
+ while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n' || in.peek() == '\r' )
+ {
+ if (in.peek() == '\n')
+ ++line_number;
+ in.get();
+ }
+
+
+ }break;
+
+ // ----------------------------------------
+
+ case chars:
+ {
+ if (tags.size() != 0)
+ {
+ chars_buf += token_text;
+ }
+ else if (token_text.find_first_not_of(" \t\r\n") != std::string::npos)
+ {
+ // you can't have non whitespace chars data outside the root element
+ seen_fatal_error = true;
+ }
+ }break;
+
+ // ----------------------------------------
+
+ case chars_cdata:
+ {
+ if (tags.size() != 0)
+ {
+ chars_buf += token_text;
+ }
+ else
+ {
+ // you can't have chars_data outside the root element
+ seen_fatal_error = true;
+ }
+ }break;
+
+ // ----------------------------------------
+
+ case eof:
+ break;
+
+ // ----------------------------------------
+
+ case error:
+ {
+ seen_fatal_error = true;
+ }break;
+
+ // ----------------------------------------
+
+ case dtd: // fall though
+ case comment: // do nothing
+ break;
+
+ // ----------------------------------------
+
+
+ }
+
+ // if there was a fatal error then quit loop
+ if (seen_fatal_error)
+ break;
+
+ // if we have seen the last tag then quit the loop
+ if (tags.size() == 0 && seen_root_tag)
+ break;
+
+
+ get_next_token(in,token_text,token_kind,line_number);
+
+ // if the next token is not a chars or chars_cdata token then flush
+ // the chars_buf to the document_handlers
+ if ( (token_kind != chars) &&
+ (token_kind != chars_cdata) &&
+ (token_kind != dtd) &&
+ (token_kind != comment) &&
+ (chars_buf.size() != 0)
+ )
+ {
+ // notify all the document_handlers
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->characters(chars_buf);
+ }
+ chars_buf.erase();
+ }
+
+
+ } //while (token_kind != eof)
+
+
+
+
+ // you can't have any unmatched tags or any fatal erros
+ if (tags.size() != 0 || seen_fatal_error)
+ {
+ // notify all the error_handlers
+ for (unsigned long i = 0; i < eh_list.size(); ++i)
+ {
+ eh_list[i]->fatal_error(line_number);
+ }
+
+ }
+
+
+ // notify all the document_handlers that we have ended parsing
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->end_document();
+ }
+
+ }
+ catch (...)
+ {
+ // notify all the document_handlers that we have ended parsing
+ for (unsigned long i = 0; i < dh_list.size(); ++i)
+ {
+ dh_list[i]->end_document();
+ }
+
+ // restore the old exception settings to in
+ in.exceptions(old_exceptions);
+
+ // don't forget to rethrow the exception
+ throw;
+ }
+
+ // restore the old exception settings to in
+ in.exceptions(old_exceptions);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ add_document_handler (
+ document_handler& item
+ )
+ {
+ document_handler* temp = &item;
+ dh_list.add(dh_list.size(),temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ add_error_handler (
+ error_handler& item
+ )
+ {
+ error_handler* temp = &item;
+ eh_list.add(eh_list.size(),temp);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ swap (
+ xml_parser& item
+ )
+ {
+ dh_list.swap(item.dh_list);
+ eh_list.swap(item.eh_list);
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+ // private member function definitions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ void xml_parser::
+ get_next_token(
+ std::istream& in,
+ std::string& token_text,
+ int& token_kind,
+ unsigned long& line_number
+ )
+ {
+
+ token_text.erase();
+
+ std::istream::int_type ch1 = in.get();
+ std::istream::int_type ch2;
+
+
+ switch (ch1)
+ {
+
+ // -----------------------------------------
+
+ // this is the start of some kind of a tag
+ case '<':
+ {
+ ch2 = in.get();
+ switch (ch2)
+ {
+
+ // ---------------------------------
+
+ // this is a dtd, comment, or chars_cdata token
+ case '!':
+ {
+ // if this is a CDATA section *******************************
+ if ( in.peek() == '[')
+ {
+ token_kind = chars_cdata;
+
+ // throw away the '['
+ in.get();
+
+ // make sure the next chars are CDATA[
+ std::istream::int_type ch = in.get();
+ if (ch != 'C')
+ token_kind = error;
+ ch = in.get();
+ if (ch != 'D')
+ token_kind = error;
+ ch = in.get();
+ if (ch != 'A')
+ token_kind = error;
+ ch = in.get();
+ if (ch != 'T')
+ token_kind = error;
+ ch = in.get();
+ if (ch != 'A')
+ token_kind = error;
+ ch = in.get();
+ if (ch != '[')
+ token_kind = error;
+ // if this is an error token then end
+ if (token_kind == error)
+ break;
+
+
+ // get the rest of the chars and put them into token_text
+ int brackets_seen = 0; // this is the number of ']' chars
+ // we have seen in a row
+ bool seen_closing = false; // true if we have seen ]]>
+ do
+ {
+ ch = in.get();
+
+ if (ch == '\n')
+ ++line_number;
+
+ token_text += ch;
+
+ // if this is the closing
+ if (brackets_seen == 2 && ch == '>')
+ seen_closing = true;
+ // if we are seeing a bracket
+ else if (ch == ']')
+ ++brackets_seen;
+ // if we didn't see a bracket
+ else
+ brackets_seen = 0;
+
+
+ } while ( (!seen_closing) && (ch != EOF) );
+
+ // check if this is an error token
+ if (ch == EOF)
+ {
+ token_kind = error;
+ }
+ else
+ {
+ token_text.erase(token_text.size()-3);
+ }
+
+
+
+ }
+ // this is a comment token ****************************
+ else if (in.peek() == '-')
+ {
+
+ token_text += ch1;
+ token_text += ch2;
+ token_text += '-';
+
+ token_kind = comment;
+
+ // throw away the '-' char
+ in.get();
+
+ // make sure the next char is another '-'
+ std::istream::int_type ch = in.get();
+ if (ch != '-')
+ {
+ token_kind = error;
+ break;
+ }
+
+ token_text += '-';
+
+
+ // get the rest of the chars and put them into token_text
+ int hyphens_seen = 0; // this is the number of '-' chars
+ // we have seen in a row
+ bool seen_closing = false; // true if we have seen ]]>
+ do
+ {
+ ch = in.get();
+
+ if (ch == '\n')
+ ++line_number;
+
+ token_text += ch;
+
+ // if this should be a closing block
+ if (hyphens_seen == 2)
+ {
+ if (ch == '>')
+ seen_closing = true;
+ else // this isn't a closing so make it signal error
+ ch = EOF;
+ }
+ // if we are seeing a hyphen
+ else if (ch == '-')
+ ++hyphens_seen;
+ // if we didn't see a hyphen
+ else
+ hyphens_seen = 0;
+
+
+ } while ( (!seen_closing) && (ch != EOF) );
+
+ // check if this is an error token
+ if (ch == EOF)
+ {
+ token_kind = error;
+ }
+
+
+
+
+
+ }
+ else // this is a dtd token *************************
+ {
+
+ token_text += ch1;
+ token_text += ch2;
+ int bracket_depth = 1; // this is the number of '<' chars seen
+ // minus the number of '>' chars seen
+
+ std::istream::int_type ch;
+ do
+ {
+ ch = in.get();
+ if (ch == '>')
+ --bracket_depth;
+ else if (ch == '<')
+ ++bracket_depth;
+ else if (ch == '\n')
+ ++line_number;
+
+ token_text += ch;
+
+ } while ( (bracket_depth > 0) && (ch != EOF) );
+
+ // make sure we didn't just hit EOF
+ if (bracket_depth == 0)
+ {
+ token_kind = dtd;
+ }
+ else
+ {
+ token_kind = error;
+ }
+ }
+ }
+ break;
+
+ // ---------------------------------
+
+ // this is a pi token
+ case '?':
+ {
+ token_text += ch1;
+ token_text += ch2;
+ std::istream::int_type ch;
+
+ do
+ {
+ ch = in.get();
+ token_text += ch;
+ if (ch == '\n')
+ ++line_number;
+ // else if we hit a < then thats an error
+ else if (ch == '<')
+ ch = EOF;
+ } while (ch != '>' && ch != EOF);
+ // if we hit the end of the pi
+ if (ch == '>')
+ {
+ // make sure there was a trailing '?'
+ if ( (token_text.size() > 3) &&
+ (token_text[token_text.size()-2] != '?')
+ )
+ {
+ token_kind = error;
+ }
+ else
+ {
+ token_kind = pi;
+ }
+ }
+ // if we hit EOF unexpectidely then error
+ else
+ {
+ token_kind = error;
+ }
+ }
+ break;
+
+ // ---------------------------------
+
+ // this is an error token
+ case EOF:
+ {
+ token_kind = error;
+ }
+ break;
+
+ // ---------------------------------
+ // this is an element_end token
+ case '/':
+ {
+ token_kind = element_end;
+ token_text += ch1;
+ token_text += ch2;
+ std::istream::int_type ch;
+ do
+ {
+ ch = in.get();
+ if (ch == '\n')
+ ++line_number;
+ // else if we hit a < then thats an error
+ else if (ch == '<')
+ ch = EOF;
+ token_text += ch;
+ } while ( (ch != '>') && (ch != EOF));
+
+ // check if this is an error token
+ if (ch == EOF)
+ {
+ token_kind = error;
+ }
+ }
+ break;
+
+
+ // ---------------------------------
+
+ // this is an element_start or empty_element token
+ default:
+ {
+
+ token_text += ch1;
+ token_text += ch2;
+ std::istream::int_type ch = '\0';
+ std::istream::int_type last;
+ do
+ {
+ last = ch;
+ ch = in.get();
+ if (ch == '\n')
+ ++line_number;
+ // else if we hit a < then thats an error
+ else if (ch == '<')
+ ch = EOF;
+ token_text += ch;
+ } while ( (ch != '>') && (ch != EOF));
+
+ // check if this is an error token
+ if (ch == EOF)
+ {
+ token_kind = error;
+ }
+ // if this is an empty_element
+ else if (last == '/')
+ {
+ token_kind = empty_element;
+ }
+ else
+ {
+ token_kind = element_start;
+ }
+
+
+ }
+ break;
+
+ // ---------------------------------
+
+ }
+
+ }
+ break;
+
+ // -----------------------------------------
+
+ // this is an eof token
+ case EOF:
+ {
+ token_kind = eof;
+ }
+ break;
+
+ // -----------------------------------------
+
+ // this is a chars token
+ default:
+ {
+ if (ch1 == '\n')
+ {
+ ++line_number;
+ token_text += ch1;
+ }
+ // if the first thing in this chars token is an entity reference
+ else if (ch1 == '&')
+ {
+
+ int temp = change_entity(in);
+ if (temp == -1)
+ {
+ token_kind = error;
+ break;
+ }
+ else
+ {
+ token_text += temp;
+ }
+ }
+ else
+ {
+ token_text += ch1;
+ }
+
+
+ token_kind = chars;
+
+ std::istream::int_type ch = 0;
+ while (in.peek() != '<' && in.peek() != EOF)
+ {
+ ch = in.get();
+
+ if (ch == '\n')
+ ++line_number;
+
+ // if this is one of the predefined entity references then change it
+ if (ch == '&')
+ {
+ int temp = change_entity(in);
+ if (temp == -1)
+ {
+ ch = EOF;
+ break;
+ }
+ else
+ token_text += temp;
+ }
+ else
+ {
+ token_text += ch;
+ }
+ }
+
+ // if this is an error token
+ if (ch == EOF)
+ {
+ token_kind = error;
+ }
+
+ }
+ break;
+
+ // -----------------------------------------
+
+ }
+
+
+ }
+
+
+
+// ----------------------------------------------------------------------------------------
+
+ int xml_parser::
+ parse_element (
+ const std::string& token,
+ std::string& name,
+ attrib_list& atts
+ )
+ {
+ name.erase();
+ atts.list.clear();
+
+ // there must be at least one character between the <>
+ if (token[1] == '>')
+ return -1;
+
+ std::string::size_type i;
+ std::istream::int_type ch = token[1];
+ i = 2;
+
+ // fill out name. the name can not contain any of the following characters
+ while ( (ch != '>') &&
+ (ch != ' ') &&
+ (ch != '=') &&
+ (ch != '/') &&
+ (ch != '\t') &&
+ (ch != '\r') &&
+ (ch != '\n')
+ )
+ {
+ name += ch;
+ ch = token[i];
+ ++i;
+ }
+
+ // skip any whitespaces
+ while ( ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' )
+ {
+ ch = token[i];
+ ++i;
+ }
+
+ // find any attributes
+ while (ch != '>' && ch != '/')
+ {
+ std::string attribute_name;
+ std::string attribute_value;
+
+ // fill out attribute_name
+ while ( (ch != '=') &&
+ (ch != ' ') &&
+ (ch != '\t') &&
+ (ch != '\r') &&
+ (ch != '\n') &&
+ (ch != '>')
+ )
+ {
+ attribute_name += ch;
+ ch = token[i];
+ ++i;
+ }
+
+ // you can't have empty attribute names
+ if (attribute_name.size() == 0)
+ return -1;
+
+ // if we hit > too early then return error
+ if (ch == '>')
+ return -1;
+
+ // skip any whitespaces
+ while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r')
+ {
+ ch = token[i];
+ ++i;
+ }
+
+ // the next char should be a '=', error if it's not
+ if (ch != '=')
+ return -1;
+
+ // get the next char
+ ch = token[i];
+ ++i;
+
+ // skip any whitespaces
+ while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r')
+ {
+ ch = token[i];
+ ++i;
+ }
+
+
+ // get the delimiter for the attribute value
+ std::istream::int_type delimiter = ch; // this should be either a ' or " character
+ ch = token[i]; // get the next char
+ ++i;
+ if (delimiter != '\'' && delimiter!='"')
+ return -1;
+
+
+ // fill out attribute_value
+ while ( (ch != delimiter) &&
+ (ch != '>')
+ )
+ {
+ attribute_value += ch;
+ ch = token[i];
+ ++i;
+ }
+
+
+ // if there was no delimiter then this is an error
+ if (ch == '>')
+ {
+ return -1;
+ }
+
+ // go to the next char
+ ch = token[i];
+ ++i;
+
+ // the next char must be either a '>' or '/' (denoting the end of the tag)
+ // or a white space character
+ if (ch != '>' && ch != ' ' && ch != '/' && ch != '\t' && ch !='\n' && ch !='\r')
+ return -1;
+
+ // skip any whitespaces
+ while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r')
+ {
+ ch = token[i];
+ ++i;
+ }
+
+
+ // add attribute_value and attribute_name to atts
+ if (atts.list.is_in_domain(attribute_name))
+ {
+ // attributes may not be multiply defined
+ return -1;
+ }
+ else
+ {
+ atts.list.add(attribute_name,attribute_value);
+ }
+
+
+ }
+
+ // you can't have an element with no name
+ if (name.size() == 0)
+ return -1;
+
+ return 0;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int xml_parser::
+ parse_pi (
+ const std::string& token,
+ std::string& target,
+ std::string& data
+ )
+ {
+ target.erase();
+ data.erase();
+
+ std::istream::int_type ch = token[2];
+ std::string::size_type i = 3;
+ while (ch != ' ' && ch != '?' && ch != '\t' && ch != '\n' && ch!='\r')
+ {
+ target += ch;
+ ch = token[i];
+ ++i;
+ }
+ if (target.size() == 0)
+ return -1;
+
+ // if we aren't at a ? character then go to the next character
+ if (ch != '?' )
+ {
+ ch = token[i];
+ ++i;
+ }
+
+ // if we still aren't at the end of the processing instruction then
+ // set this stuff in the data section
+ while (ch != '?')
+ {
+ data += ch;
+ ch = token[i];
+ ++i;
+ }
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int xml_parser::
+ parse_element_end (
+ const std::string& token,
+ std::string& name
+ )
+ {
+ name.erase();
+ std::string::size_type end = token.size()-1;
+ for (std::string::size_type i = 2; i < end; ++i)
+ {
+ if (token[i] == ' ' || token[i] == '\t' || token[i] == '\n'|| token[i] == '\r')
+ break;
+ name += token[i];
+ }
+
+ if (name.size() == 0)
+ return -1;
+
+ return 0;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ int xml_parser::
+ change_entity (
+ std::istream& in
+ )
+ {
+
+ std::istream::int_type buf[6];
+
+
+ buf[1] = in.get();
+
+ // if this is an undefined entity reference then return error
+ if (buf[1] != 'a' &&
+ buf[1] != 'l' &&
+ buf[1] != 'g' &&
+ buf[1] != 'q'
+ )
+ return -1;
+
+
+ buf[2] = in.get();
+ // if this is an undefined entity reference then return error
+ if (buf[2] != 'm' &&
+ buf[2] != 't' &&
+ buf[2] != 'p' &&
+ buf[2] != 'u'
+ )
+ return -1;
+
+
+ buf[3] = in.get();
+ // if this is an undefined entity reference then return error
+ if (buf[3] != 'p' &&
+ buf[3] != ';' &&
+ buf[3] != 'o'
+ )
+ return -1;
+
+ // check if this is &lt; or &gt;
+ if (buf[3] == ';')
+ {
+ if (buf[2] != 't')
+ return -1;
+
+ // if this is &lt; then return '<'
+ if (buf[1] == 'l')
+ return '<';
+ // if this is &gt; then return '>'
+ if (buf[1] == 'g')
+ return '>';
+
+ // it is neither so it must be an undefined entity reference
+ return -1;
+ }
+
+
+ buf[4] = in.get();
+ // if this should be &amp;
+ if (buf[4] == ';')
+ {
+ // if this is not &amp; then return error
+ if (buf[1] != 'a' ||
+ buf[2] != 'm' ||
+ buf[3] != 'p'
+ )
+ return -1;
+
+ return '&';
+ }
+
+ buf[5] = in.get();
+
+ // if this should be &apos;
+ if (buf[1] == 'a' &&
+ buf[2] == 'p' &&
+ buf[3] == 'o' &&
+ buf[4] == 's' &&
+ buf[5] == ';'
+ )
+ return '\'';
+
+
+ // if this should be &quot;
+ if (buf[1] == 'q' &&
+ buf[2] == 'u' &&
+ buf[3] == 'o' &&
+ buf[4] == 't' &&
+ buf[5] == ';'
+ )
+ return '"';
+
+
+ // it was an undefined entity reference
+ return -1;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class xml_parse_error : public error
+ {
+ public:
+ xml_parse_error(
+ const std::string& a
+ ): error(a) {}
+ };
+
+ namespace impl
+ {
+ class default_xml_error_handler : public error_handler
+ {
+ std::string filename;
+
+ public:
+
+ default_xml_error_handler (
+ ) {}
+
+ default_xml_error_handler (
+ const std::string& filename_
+ ) :filename(filename_) {}
+
+ virtual void error (
+ const unsigned long
+ )
+ {
+ // just ignore non-fatal errors
+ }
+
+ virtual void fatal_error (
+ const unsigned long line_number
+ )
+ {
+ std::ostringstream sout;
+ if (filename.size() != 0)
+ sout << "There is a fatal error on line " << line_number << " in the XML file '"<<filename<<"'.";
+ else
+ sout << "There is a fatal error on line " << line_number << " in the XML being processed.";
+
+ throw xml_parse_error(sout.str());
+ }
+ };
+ }
+
+ inline void parse_xml (
+ std::istream& in,
+ document_handler& dh,
+ error_handler& eh
+ )
+ {
+ if (!in)
+ throw xml_parse_error("Unexpected end of file during xml parsing.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ std::istream& in,
+ error_handler& eh,
+ document_handler& dh
+ )
+ {
+ if (!in)
+ throw xml_parse_error("Unexpected end of file during xml parsing.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ std::istream& in,
+ error_handler& eh
+ )
+ {
+ if (!in)
+ throw xml_parse_error("Unexpected end of file during xml parsing.");
+ xml_parser parser;
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ std::istream& in,
+ document_handler& dh
+ )
+ {
+ if (!in)
+ throw xml_parse_error("Unexpected end of file during xml parsing.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ impl::default_xml_error_handler eh;
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ inline void parse_xml (
+ const std::string& filename,
+ document_handler& dh,
+ error_handler& eh
+ )
+ {
+ std::ifstream in(filename.c_str());
+ if (!in)
+ throw xml_parse_error("Unable to open file '" + filename + "'.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ const std::string& filename,
+ error_handler& eh,
+ document_handler& dh
+ )
+ {
+ std::ifstream in(filename.c_str());
+ if (!in)
+ throw xml_parse_error("Unable to open file '" + filename + "'.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ const std::string& filename,
+ error_handler& eh
+ )
+ {
+ std::ifstream in(filename.c_str());
+ if (!in)
+ throw xml_parse_error("Unable to open file '" + filename + "'.");
+ xml_parser parser;
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+ inline void parse_xml (
+ const std::string& filename,
+ document_handler& dh
+ )
+ {
+ std::ifstream in(filename.c_str());
+ if (!in)
+ throw xml_parse_error("Unable to open file '" + filename + "'.");
+ xml_parser parser;
+ parser.add_document_handler(dh);
+ impl::default_xml_error_handler eh(filename);
+ parser.add_error_handler(eh);
+ parser.parse(in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_XML_PARSER_KERNEl_1_
+
diff --git a/ml/dlib/dlib/xml_parser/xml_parser_kernel_abstract.h b/ml/dlib/dlib/xml_parser/xml_parser_kernel_abstract.h
new file mode 100644
index 000000000..45b513e55
--- /dev/null
+++ b/ml/dlib/dlib/xml_parser/xml_parser_kernel_abstract.h
@@ -0,0 +1,276 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_XML_PARSER_KERNEl_ABSTRACT_
+#ifdef DLIB_XML_PARSER_KERNEl_ABSTRACT_
+
+#include <string>
+#include <iosfwd>
+#include "xml_parser_kernel_interfaces.h"
+
+namespace dlib
+{
+
+ class xml_parser
+ {
+
+ /*!
+ INITIAL VALUE
+ no objects are registered to receive events
+
+
+ WHAT THIS OBJECT REPRESENTS
+ This object represents a simple SAX style event driven XML parser.
+ It takes its input from an input stream object and sends events to all
+ registered document_handler and error_handler objects.
+
+ note that this xml parser ignores all DTD related XML markup. It will
+ parse XML documents with DTD's but it just won't check if the document
+ is valid. This also means that entity references may not be used except
+ for the predefined ones which are as follows:
+ &amp;
+ &lt;
+ &gt;
+ &apos;
+ &quot;
+
+ also note that there is no interpreting of entity references inside
+ a CDATA section or inside of tags, they are only interpreted inside
+ normal non-markup data.
+
+ This parser considers the end of the xml document to be the closing
+ tag of the root tag (as opposed to using EOF as the end of the
+ document). This is a deviation from the xml standard.
+
+ Aside from ignoring DTD stuff and entity references everywhere but
+ data, and the above comment regarding EOF, this parser should conform
+ to the rest of the XML standard.
+ !*/
+
+ public:
+
+
+ xml_parser(
+ );
+ /*!
+ ensures
+ - #*this is properly initialized
+ throws
+ - std::bad_alloc
+ !*/
+
+ virtual ~xml_parser(
+ );
+ /*!
+ ensures
+ - all memory associated with *this has been released
+ !*/
+
+ void clear(
+ );
+ /*!
+ ensures
+ - #*this has its initial value
+ throws
+ - std::bad_alloc
+ if this exception is thrown then *this is unusable
+ until clear() is called and succeeds
+ !*/
+
+ void parse (
+ std::istream& in
+ );
+ /*!
+ requires
+ - in.fail() == false
+ ensures
+ - the data from the input stream in will be parsed and the appropriate
+ events will be generated
+ - parsing will stop when the parser has reached the closing tag
+ for the xml document or EOF (which ever comes first). Note that
+ hitting EOF first is a fatal error.
+ throws
+ - std::bad_alloc
+ if parse() throws then it will be unusable until clear() is
+ called and succeeds
+ - other exceptions
+ document_handlers and error_handlers my throw any exception. If
+ they throw while parse() is running then parse() will let the
+ exception propagate out and the xml_parser object will be unusable
+ until clear() is called and succeeds. note that end_document()
+ is still called.
+ !*/
+
+ void add_document_handler (
+ document_handler& item
+ );
+ /*!
+ ensures
+ - item will now receive document events from the parser
+ throws
+ - std::bad_alloc
+ if add_document_handler() throws then it has no effect
+ !*/
+
+ void add_error_handler (
+ error_handler& item
+ );
+ /*!
+ ensures
+ - item will now receive error events from the parser
+ throws
+ - std::bad_alloc
+ if add_error_handler() throws then it has no effect
+ !*/
+
+
+ void swap (
+ xml_parser& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+
+ private:
+
+ // restricted functions
+ xml_parser(xml_parser&); // copy constructor
+ xml_parser& operator=(xml_parser&); // assignment operator
+
+ };
+
+
+ inline void swap (
+ xml_parser& a,
+ xml_parser& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap function
+ !*/
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class xml_parse_error : public error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is the exception object thrown by the parse_xml() routines defined
+ below.
+ !*/
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ void parse_xml (
+ std::istream& in,
+ document_handler& dh,
+ error_handler& eh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input stream using the
+ supplied document_handler and error_handler.
+ !*/
+
+ void parse_xml (
+ std::istream& in,
+ error_handler& eh,
+ document_handler& dh
+ )
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input stream using the
+ supplied document_handler and error_handler.
+ !*/
+
+ void parse_xml (
+ std::istream& in,
+ error_handler& eh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input stream using the
+ supplied error_handler.
+ !*/
+
+ void parse_xml (
+ std::istream& in,
+ document_handler& dh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input stream using the
+ supplied document_handler.
+ - Uses a default error handler that will throw an xml_parse_error exception
+ if a fatal parsing error is encountered.
+ throws
+ - xml_parse_error
+ Thrown if a fatal parsing error is encountered.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ void parse_xml (
+ const std::string& filename,
+ document_handler& dh,
+ error_handler& eh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input file using the
+ supplied error_handler and document_handler.
+ throws
+ - xml_parse_error
+ Thrown if there is a problem parsing the input file.
+ !*/
+
+ void parse_xml (
+ const std::string& filename,
+ error_handler& eh,
+ document_handler& dh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input file using the
+ supplied error_handler and document_handler.
+ throws
+ - xml_parse_error
+ Thrown if there is a problem parsing the input file.
+ !*/
+
+ void parse_xml (
+ const std::string& filename,
+ error_handler& eh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input file using the
+ supplied error_handler.
+ throws
+ - xml_parse_error
+ Thrown if there is a problem parsing the input file.
+ !*/
+
+ void parse_xml (
+ const std::string& filename,
+ document_handler& dh
+ );
+ /*!
+ ensures
+ - makes an xml_parser and tells it to parse the given input file using the
+ supplied document_handler.
+ - Uses a default error handler that will throw an xml_parse_error exception
+ if a fatal parsing error is encountered.
+ throws
+ - xml_parse_error
+ Thrown if there is a problem parsing the input file.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_XML_PARSER_KERNEl_ABSTRACT_
+
diff --git a/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h b/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h
new file mode 100644
index 000000000..a0edf3317
--- /dev/null
+++ b/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h
@@ -0,0 +1,244 @@
+// Copyright (C) 2003 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_XML_PARSER_KERNEl_INTERFACES_
+#define DLIB_XML_PARSER_KERNEl_INTERFACES_
+
+#include <string>
+#include "../interfaces/enumerable.h"
+#include "../interfaces/map_pair.h"
+#include "../error.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ class xml_attribute_list_error : public dlib::error
+ {
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ This is an exception object thrown by attribute_list objects if you try to
+ access a non-existent attribute.
+ !*/
+ public:
+ xml_attribute_list_error(const std::string& msg) : dlib::error(msg){}
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class attribute_list : public enumerable<map_pair<std::string,std::string> >
+ {
+
+ /*!
+ WHAT THIS OBJECT REPRESENTS
+ this object represents a list of the attributes found in
+ an XML element. each attribute is associated with a value.
+ !*/
+
+
+ public:
+
+ inline virtual ~attribute_list (
+ ) =0;
+
+
+ virtual bool is_in_list (
+ const std::string& key
+ ) const =0;
+ /*!
+ ensures
+ - returns true if there is an attribute named key in the list
+ - returns false
+ !*/
+
+ virtual const std::string& operator[] (
+ const std::string& key
+ ) const =0;
+ /*!
+ ensures
+ if (is_in_list(key) == true) then
+ - returns a const reference to the value associated with the attribute
+ named key.
+ - else
+ - throws xml_attribute_list_error
+ !*/
+
+ protected:
+
+ // restricted functions
+ attribute_list& operator=(attribute_list&) {return *this;}
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class document_handler
+ {
+ /*!
+ EXCEPTIONS
+ a document_handler is allowed to throw any exception
+
+
+ WHAT THIS OBJECT REPRESENTS
+ this object is an interface for handling the basic events
+ generated by an XML parser
+ !*/
+
+
+ public:
+
+ inline virtual ~document_handler (
+ ) =0;
+
+ virtual void start_document (
+ )=0;
+ /*!
+ requires
+ - is called when the document parsing begins
+ !*/
+
+ virtual void end_document (
+ )=0;
+ /*!
+ requires
+ - is called after the document parsing has ended. note that this
+ is always called, even if an error occurs.
+ !*/
+
+ virtual void start_element (
+ const unsigned long line_number,
+ const std::string& name,
+ const dlib::attribute_list& atts
+ )=0;
+ /*!
+ requires
+ - is called when an opening element tag is encountered.
+ - line_number == the line number where the opening tag for this element
+ was encountered.
+ - name == the name of the element encountered
+ - atts == a list containing all the attributes in this element and their
+ associated values
+ !*/
+
+ virtual void end_element (
+ const unsigned long line_number,
+ const std::string& name
+ )=0;
+ /*!
+ requires
+ - is called when a closing element tag is encountered. (note that this
+ includes tags such as <example_tag/>. I.e. the previous tag would
+ trigger a start_element() callback as well as an end_element() callback)
+ - line_number == the line number where the closing tag for this
+ element was encountered and
+ - name == the name of the element encountered
+ !*/
+
+ virtual void characters (
+ const std::string& data
+ )=0;
+ /*!
+ requires
+ - is called just before we encounter a start_element, end_element, or
+ processing_instruction tag but only if there was data between the
+ last and next tag.
+ (i.e. data will never be "")
+ - data == all the normal non-markup data and CDATA between the next and
+ last tag in the document.
+ !*/
+
+ virtual void processing_instruction (
+ const unsigned long line_number,
+ const std::string& target,
+ const std::string& data
+ )=0;
+ /*!
+ requires
+ - is called when a processing instruction is encountered
+ - line_number == the line number where this processing instruction
+ was encountered
+ - target == the target value for this processing instruction
+ - data == the data value for this processing instruction
+ !*/
+
+ protected:
+
+ // restricted functions
+ document_handler& operator=(document_handler&) { return *this; }
+ };
+
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ class error_handler
+ {
+ /*!
+ EXCEPTIONS
+ an error_handler is allowed to throw any exception
+
+
+ WHAT THIS OBJECT REPRESENTS
+ this object is an interface for handling the error/warning
+ events generated by an XML parser
+ !*/
+
+ public:
+
+ inline virtual ~error_handler (
+ ) =0;
+
+ virtual void error (
+ const unsigned long line_number
+ )=0;
+ /*!
+ requires
+ - is called when an error that does NOT require the parser to halt
+ is encountered. (i.e. somewhat minor errors in the input)
+ - line_number == the line number where this error was encountered
+
+ the following events trigger an error:
+ an invalid processing instruction
+ !*/
+
+ virtual void fatal_error (
+ const unsigned long line_number
+ )=0;
+ /*!
+ requires
+ - is called when an error that requires the parser to abort its parsing
+ is encountered (i.e. fatal errors in the input)
+ - line_number == the line number where this fatal error was encountered
+
+ the following events trigger a fatal_error:
+ Everything other than the events listed above for error.
+ Also note that encountering an entity reference other than the
+ predefined ones listed in xml_parser_kernel_abstract is a fatal_error.
+ Hitting EOF before the closing tag for the document is also a fatal_error.
+ !*/
+
+ protected:
+
+ // restricted functions
+ error_handler& operator=(error_handler&) { return *this;}
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ document_handler::~document_handler (
+ ){}
+ attribute_list::~attribute_list (
+ ){}
+ error_handler::~error_handler (
+ ){}
+
+}
+
+#endif // DLIB_XML_PARSER_KERNEl_INTERFACES_
+